Source code for sleap.nn.data.general
"""General purpose transformers for common pipeline processing tasks."""
import tensorflow as tf
import attr
from typing import List, Text

@attr.s(auto_attribs=True)
class KeyRenamer:
    """Transformer for renaming example keys.

    Attributes:
        old_key_names: List of existing key names to rename.
        new_key_names: List of new key names, matched to `old_key_names` by position.
        drop_old: If True, drop the old keys from each example after renaming.
    """

    old_key_names: List[Text] = attr.ib(factory=list)
    new_key_names: List[Text] = attr.ib(factory=list)
    drop_old: bool = True

    @property
    def input_keys(self) -> List[Text]:
        """Return the keys that incoming elements are expected to have."""
        return self.old_key_names

    @property
    def output_keys(self) -> List[Text]:
        """Return the keys that outgoing elements will have."""
        if self.drop_old:
            return self.new_key_names
        else:
            return self.old_key_names + self.new_key_names

    def transform_dataset(self, input_ds: tf.data.Dataset) -> tf.data.Dataset:
        """Create a dataset with the specified keys renamed."""

        def rename_keys(example):
            """Local processing function for dataset mapping."""
            for old_key, new_key in zip(self.old_key_names, self.new_key_names):
                example[new_key] = example[old_key]
            if self.drop_old:
                for old_key in self.old_key_names:
                    example.pop(old_key)
            return example

        # Map the main processing function to each example.
        output_ds = input_ds.map(
            rename_keys, num_parallel_calls=tf.data.experimental.AUTOTUNE
        )
        return output_ds
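
A minimal usage sketch (assuming the classes in this module are in scope; the key names are hypothetical): `KeyRenamer` copies each value from its old key to its new key, pairing the two lists by position, and optionally drops the old keys.

# Example: rename "image" to "instance_image" in a toy dataset of dict examples.
ds = tf.data.Dataset.from_tensors({"image": tf.zeros([4, 4, 1])})
renamer = KeyRenamer(
    old_key_names=["image"], new_key_names=["instance_image"], drop_old=True
)
ds = renamer.transform_dataset(ds)
example = next(iter(ds))
print(list(example.keys()))  # -> ["instance_image"]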

@attr.s(auto_attribs=True)
class KeyFilter:
    """Transformer for filtering example keys.

    Attributes:
        keep_keys: List of keys to retain; any other keys are dropped.
    """

    keep_keys: List[Text] = attr.ib(factory=list)

    @property
    def input_keys(self) -> List[Text]:
        """Return the keys that incoming elements are expected to have."""
        return self.keep_keys

    @property
    def output_keys(self) -> List[Text]:
        """Return the keys that outgoing elements will have."""
        return self.keep_keys

    def transform_dataset(self, input_ds: tf.data.Dataset) -> tf.data.Dataset:
        """Create a dataset containing examples with only the kept keys."""

        def filter_keys(example):
            """Local processing function for dataset mapping."""
            return {key: example[key] for key in self.keep_keys}

        # Map the main processing function to each example.
        output_ds = input_ds.map(
            filter_keys, num_parallel_calls=tf.data.experimental.AUTOTUNE
        )
        return output_ds
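
A brief sketch of `KeyFilter` (hypothetical key names, assuming the classes above are in scope): only the keys in `keep_keys` survive, which is useful for trimming intermediate tensors out of examples before batching.

# Example: keep only "image" and "scale" from a toy example.
ds = tf.data.Dataset.from_tensors({
    "image": tf.zeros([4, 4, 1]),
    "scale": tf.ones([2]),
    "video_ind": tf.constant(0),
})
ds = KeyFilter(keep_keys=["image", "scale"]).transform_dataset(ds)
print(sorted(next(iter(ds)).keys()))  # -> ["image", "scale"]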

@attr.s(auto_attribs=True)
class KeyDeviceMover:
    """Transformer for moving example keys to the CPU.

    Attributes:
        keys: List of keys whose values will be copied to the CPU device.
    """

    keys: List[Text] = attr.ib(factory=list)

    @property
    def input_keys(self) -> List[Text]:
        """Return the keys that incoming elements are expected to have."""
        return self.keys

    @property
    def output_keys(self) -> List[Text]:
        """Return the keys that outgoing elements will have."""
        return self.keys

    def transform_dataset(self, input_ds: tf.data.Dataset) -> tf.data.Dataset:
        """Create a dataset with the specified keys moved to the CPU."""

        def move_keys(example):
            """Local processing function for dataset mapping."""
            with tf.device("/cpu:0"):
                for key in self.keys:
                    if key in example:
                        # Copy the tensor so that it is placed on the CPU.
                        example[key] = tf.identity(example[key])
            return example

        # Map the main processing function to each example.
        output_ds = input_ds.map(
            move_keys, num_parallel_calls=tf.data.experimental.AUTOTUNE
        )
        return output_ds
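
A short sketch of `KeyDeviceMover` (hypothetical key name, assuming the classes above are in scope): `tf.identity` under a `tf.device("/cpu:0")` scope yields a copy placed on the CPU, which is typically used to keep large tensors such as full-frame images in host memory. Note that `tf.data` map functions generally execute on the host anyway, so the exact device string observed when iterating may vary by TensorFlow version.

# Example: pin the "image" tensor to the CPU.
ds = tf.data.Dataset.from_tensors({"image": tf.zeros([4, 4, 1])})
ds = KeyDeviceMover(keys=["image"]).transform_dataset(ds)
example = next(iter(ds))
print(example["image"].device)  # e.g. ".../device:CPU:0"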