# Source code for sleap.nn.data.augmentation
"""Transformers for applying data augmentation."""
# Monkey patch for: https://github.com/aleju/imgaug/issues/537
# TODO: Fix when PyPI/conda packages are available for version fencing.
import numpy
if hasattr(numpy.random, "_bit_generator"):
numpy.random.bit_generator = numpy.random._bit_generator
import numpy as np
import tensorflow as tf
import attr
from typing import List, Text
import imgaug as ia
import imgaug.augmenters as iaa
from sleap.nn.config import AugmentationConfig
@attr.s(auto_attribs=True)
class ImgaugAugmenter:
    """Data transformer based on the `imgaug` library.

    This class can generate a `tf.data.Dataset` from an existing one that
    generates image and instance data. Elements of the output dataset will have
    a set of augmentation transformations applied.

    Attributes:
        augmenter: An instance of `imgaug.augmenters.Sequential` that will be
            applied to each element of the input dataset.
    """

    augmenter: iaa.Sequential

    @classmethod
    def from_config(cls, config: AugmentationConfig) -> "ImgaugAugmenter":
        """Create an augmenter from a set of configuration parameters.

        Args:
            config: An `AugmentationConfig` instance with the desired parameters.

        Returns:
            An instance of this class with the specified augmentation
            configuration.
        """
        # Each enabled augmentation type contributes one imgaug augmenter to
        # the stack; they are applied in order by `iaa.Sequential`.
        aug_stack = []
        if config.rotate:
            aug_stack.append(
                iaa.Affine(
                    rotate=(config.rotation_min_angle, config.rotation_max_angle)
                )
            )
        if config.translate:
            aug_stack.append(
                iaa.Affine(
                    translate_px={
                        "x": (config.translate_min, config.translate_max),
                        "y": (config.translate_min, config.translate_max),
                    }
                )
            )
        if config.scale:
            aug_stack.append(iaa.Affine(scale=(config.scale_min, config.scale_max)))
        if config.uniform_noise:
            aug_stack.append(
                iaa.AddElementwise(
                    value=(config.uniform_noise_min_val, config.uniform_noise_max_val)
                )
            )
        if config.gaussian_noise:
            aug_stack.append(
                iaa.AdditiveGaussianNoise(
                    loc=config.gaussian_noise_mean, scale=config.gaussian_noise_stddev
                )
            )
        if config.contrast:
            aug_stack.append(
                iaa.GammaContrast(
                    gamma=(config.contrast_min_gamma, config.contrast_max_gamma)
                )
            )
        if config.brightness:
            aug_stack.append(
                iaa.Add(value=(config.brightness_min_val, config.brightness_max_val))
            )
        return cls(augmenter=iaa.Sequential(aug_stack))

    @property
    def input_keys(self) -> List[Text]:
        """Return the keys that incoming elements are expected to have."""
        return ["image", "instances"]

    @property
    def output_keys(self) -> List[Text]:
        """Return the keys that outgoing elements will have."""
        return self.input_keys

    def transform_dataset(self, input_ds: tf.data.Dataset) -> tf.data.Dataset:
        """Create a `tf.data.Dataset` with elements containing augmented data.

        Args:
            input_ds: A dataset with elements that contain the keys "image" and
                "instances". This is typically raw data from a data provider.

        Returns:
            A `tf.data.Dataset` with the same keys as the input, but with images
            and instance points updated with the applied augmentations.

        Notes:
            The "scale" key in examples are not modified when scaling
            augmentation is applied.
        """
        # Define augmentation function to map over each sample.
        def py_augment(image, instances):
            """Local processing function that will not be autographed."""
            # Ensure that the transformations applied to all data within this
            # example are kept consistent.
            aug_det = self.augmenter.to_deterministic()

            # Augment the image.
            aug_img = aug_det.augment_image(image.numpy())

            # Edge case: an example with zero instances would make np.stack
            # below raise ValueError on an empty list. There are no points to
            # transform, so pass the (empty) instances array through unchanged.
            if instances.shape[0] == 0:
                return aug_img, instances.numpy()

            # Augment each set of points for each instance.
            aug_instances = []
            for instance in instances:
                kps = ia.KeypointsOnImage.from_xy_array(
                    instance.numpy(), tuple(image.shape)
                )
                aug_instance = aug_det.augment_keypoints(kps).to_xy_array()
                aug_instances.append(aug_instance)

            # This will get converted to a rank 3 tensor
            # (n_instances, n_nodes, 2) when returned through tf.py_function.
            aug_instances = np.stack(aug_instances, axis=0)

            return aug_img, aug_instances

        def augment(frame_data):
            """Wrap local processing function for dataset mapping."""
            image, instances = tf.py_function(
                py_augment,
                [frame_data["image"], frame_data["instances"]],
                [frame_data["image"].dtype, frame_data["instances"].dtype],
            )
            # tf.py_function loses static shape info; restore it from the
            # (unchanged) input shapes so downstream ops can rely on it.
            image.set_shape(frame_data["image"].get_shape())
            instances.set_shape(frame_data["instances"].get_shape())
            frame_data.update({"image": image, "instances": instances})
            return frame_data

        # Apply the augmentation to each element.
        # Note: We map sequentially since imgaug gets slower with tf.data
        # parallelism.
        output_ds = input_ds.map(augment)
        return output_ds