Source code for sleap.nn.data.dataset_ops
"""Transformers for dataset (multi-example) operations, e.g., shuffling and batching.
These are mostly wrappers for standard tf.data.Dataset ops.
"""
import numpy as np
import tensorflow as tf
from sleap.nn.data.utils import expand_to_rank
import attr
from typing import List, Text, Optional, Any, Callable, Dict
@attr.s(auto_attribs=True)
class Shuffler:
    """Pipeline transformer that randomizes the order of dataset elements.

    This is a thin wrapper around `tf.data.Dataset.shuffle`.

    Feed this transformer a dataset that has not been repeated or batched yet:

    - A repeated input never reaches the end of an "epoch"/"iteration" of the
      underlying dataset, so shuffling cannot respect those boundaries.
    - A batched input technically works and respects epoch boundaries, but the same
      examples would then always be optimized together within a mini-batch. That is
      less effective for generalization than element-wise shuffling, which yields
      new combinations of elements in every mini-batch.

    The ideal pipeline follows the order:
        shuffle -> batch -> repeat

    Attributes:
        shuffle: If False, returns the input dataset unmodified.
        buffer_size: Number of examples to keep in a buffer to sample uniformly from.
            If set too high, it may take a long time to fill the initial buffer,
            especially if it resets every epoch.
        reshuffle_each_iteration: If True, resets the sampling buffer every iteration
            through the underlying dataset.
    """

    shuffle: bool = True
    buffer_size: int = 64
    reshuffle_each_iteration: bool = True

    @property
    def input_keys(self) -> List[Text]:
        """Return the keys that incoming elements are expected to have.

        This transformer always reports an empty list.
        """
        return []

    @property
    def output_keys(self) -> List[Text]:
        """Return the keys that outgoing elements will have.

        This transformer always reports an empty list.
        """
        return []

    def transform_dataset(self, ds_input: tf.data.Dataset) -> tf.data.Dataset:
        """Create a dataset with shuffled element order.

        Args:
            ds_input: Any dataset.

        Returns:
            A `tf.data.Dataset` whose elements have the same keys as the input but
            are produced in a shuffled order. When `shuffle` is False, `ds_input`
            itself is returned.

            If the input dataset is repeated, this doesn't really respect epoch
            boundaries since it never reaches the end of the iterator.
        """
        # Shuffling disabled: pass the input straight through.
        if not self.shuffle:
            return ds_input

        return ds_input.shuffle(
            buffer_size=self.buffer_size,
            reshuffle_each_iteration=self.reshuffle_each_iteration,
        )
@attr.s(auto_attribs=True)
class Batcher:
    """Pipeline transformer that groups consecutive elements into batches.

    Keys whose values have variable length across examples can still be batched:
    elements are first batched as ragged tensors and then converted back to dense
    tensors with padding.

    When training, see the notes in the `Shuffler` and `Repeater` transformers. When
    running inference, this transformer is used on its own without dropping
    remainders.

    The ideal (training) pipeline follows the order:
        shuffle -> batch -> repeat

    Attributes:
        batch_size: Number of elements within a batch. Every key will be stacked
            within their first axis (with expansion) such that it has `batch_size`
            length.
        drop_remainder: If True, final elements with fewer than `batch_size` examples
            will be dropped once the end of the input dataset iteration is reached.
            This should be True for training and False for inference.
    """

    batch_size: int = 8
    drop_remainder: bool = False

    @property
    def input_keys(self) -> List[Text]:
        """Return the keys that incoming elements are expected to have.

        This transformer always reports an empty list.
        """
        return []

    @property
    def output_keys(self) -> List[Text]:
        """Return the keys that outgoing elements will have.

        This transformer always reports an empty list.
        """
        return []

    def transform_dataset(self, ds_input: tf.data.Dataset) -> tf.data.Dataset:
        """Create a dataset with batched elements.

        Args:
            ds_input: Any dataset that produces dictionaries keyed by strings and
                values with any rank tensors.

        Returns:
            A `tf.data.Dataset` with elements containing the same keys, but with each
            tensor promoted to 1 rank higher (except for scalars with rank 0, which
            will be promoted to rank 2).

            The keys of each element will contain `batch_size` individual elements
            stacked along the axis 0, such that length (`.shape[0]`) is equal to
            `batch_size`.

            Any keys that had variable length elements within the batch will be
            padded with NaNs to the size of the largest element's length for that
            key.
        """
        autotune = tf.data.experimental.AUTOTUNE

        def expand(example):
            """Promote every value in the example to a minimum rank of 1."""
            for key, value in list(example.items()):
                example[key] = expand_to_rank(value, target_rank=1, prepend=True)
            return example

        def unrag(example):
            """Convert ragged values back to dense tensors padded with NaN."""
            for key, value in list(example.items()):
                if not isinstance(value, tf.RaggedTensor):
                    # Already dense; nothing to pad.
                    continue
                # NOTE(review): the padding value is NaN cast to the value's dtype;
                # presumably ragged keys are floating point -- confirm for others.
                padding = tf.cast(np.nan, value.dtype)
                example[key] = value.to_tensor(default_value=padding)
            return example

        # Scalars must be at least rank 1 before they can be batched as ragged.
        expanded = ds_input.map(expand, num_parallel_calls=autotune)

        # Stack elements into ragged batches so variable lengths are allowed.
        ragged_batcher = tf.data.experimental.dense_to_ragged_batch(
            batch_size=self.batch_size, drop_remainder=self.drop_remainder
        )
        batched = expanded.apply(ragged_batcher)

        # Return dense tensors, padding the shorter elements of each batch.
        return batched.map(unrag, num_parallel_calls=autotune)
@attr.s(auto_attribs=True)
class Unbatcher:
    """Pipeline transformer that splits batched elements into single elements.

    This is a thin wrapper around `tf.data.Dataset.unbatch`.
    """

    @property
    def input_keys(self) -> List[Text]:
        """Return the keys that incoming elements are expected to have.

        This transformer always reports an empty list.
        """
        return []

    @property
    def output_keys(self) -> List[Text]:
        """Return the keys that outgoing elements will have.

        This transformer always reports an empty list.
        """
        return []

    def transform_dataset(self, ds_input: tf.data.Dataset) -> tf.data.Dataset:
        """Create a dataset with unbatched elements.

        Args:
            ds_input: A dataset to which `unbatch()` can be applied.

        Returns:
            The `tf.data.Dataset` produced by calling `unbatch()` on `ds_input`.
        """
        ds_output = ds_input.unbatch()
        return ds_output
@attr.s(auto_attribs=True)
class Repeater:
    """Pipeline transformer that loops over the elements of a dataset.

    The underlying elements are repeated either indefinitely or for a fixed number
    of "iterations"/"epochs". This is a thin wrapper around
    `tf.data.Dataset.repeat`.

    Placement relative to batching matters:

    - Before batching, mini-batches may contain examples from across epoch
      boundaries.
    - After batching, examples that are dropped as remainders may never be reached
      if not shuffling.

    The ideal pipeline follows the order:
        shuffle -> batch -> repeat

    Attributes:
        repeat: If False, returns the input dataset unmodified.
        epochs: If -1, repeats the input dataset elements infinitely. Otherwise,
            loops through the elements of the input dataset this number of times.
    """

    repeat: bool = True
    epochs: int = -1

    @property
    def input_keys(self) -> List[Text]:
        """Return the keys that incoming elements are expected to have.

        This transformer always reports an empty list.
        """
        return []

    @property
    def output_keys(self) -> List[Text]:
        """Return the keys that outgoing elements will have.

        This transformer always reports an empty list.
        """
        return []

    def transform_dataset(self, ds_input: tf.data.Dataset) -> tf.data.Dataset:
        """Create a dataset with repeated loops over the input elements.

        Args:
            ds_input: Any dataset.

        Returns:
            A `tf.data.Dataset` with elements containing the same keys, but repeated
            for `epochs` iterations. When `repeat` is False, `ds_input` itself is
            returned.
        """
        # Repeating disabled: pass the input straight through.
        if not self.repeat:
            return ds_input

        return ds_input.repeat(count=self.epochs)
@attr.s(auto_attribs=True)
class Prefetcher:
    """Pipeline transformer that keeps a buffer of elements ready ahead of time.

    Prefetching elements from the input dataset minimizes the processing bottleneck
    when elements are requested, since the prefetching can occur in parallel. This
    is a thin wrapper around `tf.data.Dataset.prefetch`.

    Attributes:
        prefetch: If False, returns the input dataset unmodified.
        buffer_size: Keep `buffer_size` elements loaded in the buffer. If set to -1
            (`tf.data.experimental.AUTOTUNE`), this value will be optimized
            automatically to decrease latency.
    """

    prefetch: bool = True
    buffer_size: int = tf.data.experimental.AUTOTUNE

    @property
    def input_keys(self) -> List[Text]:
        """Return the keys that incoming elements are expected to have.

        This transformer always reports an empty list.
        """
        return []

    @property
    def output_keys(self) -> List[Text]:
        """Return the keys that outgoing elements will have.

        This transformer always reports an empty list.
        """
        return []

    def transform_dataset(self, ds_input: tf.data.Dataset) -> tf.data.Dataset:
        """Create a dataset with prefetching to maintain a buffer during iteration.

        Args:
            ds_input: Any dataset.

        Returns:
            A `tf.data.Dataset` with identical elements. Processing that occurs with
            the elements that are produced can be done in parallel (e.g., training
            on the GPU) while new elements are generated from the pipeline. When
            `prefetch` is False, `ds_input` itself is returned.
        """
        # Prefetching disabled: pass the input straight through.
        if not self.prefetch:
            return ds_input

        return ds_input.prefetch(buffer_size=self.buffer_size)
@attr.s(auto_attribs=True)
class Preloader:
    """Preload elements of the underlying dataset to generate in-memory examples.

    This transformer can lead to considerable performance improvements at the cost of
    memory consumption.

    This is functionally equivalent to `tf.data.Dataset.cache`, except the cached
    examples are accessible directly via the `examples` attribute.

    Attributes:
        examples: Stored list of preloaded elements. Not an init argument; it starts
            empty and is overwritten on every call to `transform_dataset`.
    """

    examples: List[Any] = attr.ib(init=False, factory=list)

    @property
    def input_keys(self) -> List[Text]:
        """Return the keys that incoming elements are expected to have.

        This transformer always reports an empty list.
        """
        return []

    @property
    def output_keys(self) -> List[Text]:
        """Return the keys that outgoing elements will have.

        This transformer always reports an empty list.
        """
        return []

    def transform_dataset(self, ds_input: tf.data.Dataset) -> tf.data.Dataset:
        """Create a dataset that generates preloaded elements.

        Args:
            ds_input: Any `tf.data.Dataset` that generates examples as a dictionary
                of tensors. Should not be repeating infinitely, since every element
                is read into memory before this method returns.

        Returns:
            A dataset that generates the same examples. If `ds_input` produces no
            elements, the returned dataset is empty as well.

            This is similar to prefetching, except that examples are yielded through
            a generator and loaded when this method is called rather than during
            pipeline iteration.
        """
        # Preload examples from the input dataset.
        self.examples = list(iter(ds_input))

        # Read the example metadata from the dataset's declared structure rather
        # than from the first preloaded example, so an input that yields no
        # elements does not fail on `self.examples[0]`.
        element_spec = ds_input.element_spec
        keys = list(element_spec.keys())
        dtypes = tuple(element_spec[key].dtype for key in keys)

        def gen():
            # Reads `self.examples` at iteration time, not a snapshot taken here.
            for example in self.examples:
                yield tuple(example[key] for key in keys)

        ds_output = tf.data.Dataset.from_generator(gen, output_types=dtypes)

        # The generator yields flat tuples; restore the dictionary structure.
        ds_output = ds_output.map(lambda *example: dict(zip(keys, example)))
        return ds_output
@attr.s(auto_attribs=True)
class LambdaFilter:
    """Pipeline transformer that drops examples rejected by a user-provided function.

    This class is useful for eliminating examples that fail to meet some criteria,
    e.g., when no peaks are found. It is a thin wrapper around
    `tf.data.Dataset.filter`.

    Attributes:
        filter_fn: Callable that takes an example dictionary as input and returns
            True if the element should be kept.
    """

    filter_fn: Callable[[Dict[Text, tf.Tensor]], bool]

    @property
    def input_keys(self) -> List[Text]:
        """Return the keys that incoming elements are expected to have.

        This transformer always reports an empty list.
        """
        return []

    @property
    def output_keys(self) -> List[Text]:
        """Return the keys that outgoing elements will have.

        This transformer always reports an empty list.
        """
        return []

    def transform_dataset(self, ds_input: tf.data.Dataset) -> tf.data.Dataset:
        """Create a dataset with filtering applied.

        Args:
            ds_input: Any dataset that produces dictionaries keyed by strings and
                values with any rank tensors.

        Returns:
            A `tf.data.Dataset` with elements containing the same keys, but with
            potentially fewer elements.
        """
        # Keep only the elements for which `filter_fn` evaluates to True.
        return ds_input.filter(self.filter_fn)