sleap.nn.data.instance_centroids
Transformers for finding instance centroids.
- class sleap.nn.data.instance_centroids.InstanceCentroidFinder(center_on_anchor_part: bool = False, anchor_part_names: Optional[Any] = None, skeletons: Optional[Any] = None, instances_key: str = 'instances')
Data transformer to add centroid information to instances.
This is useful as a transformation for data streams that feed centroid networks or instance cropping.
- center_on_anchor_part
If True, specifies that centering should be done relative to a body part rather than the midpoint of the instance bounding box. If False, the midpoint of the bounding box of all points will be used.
- Type
bool
- anchor_part_names
List of strings specifying the body part name in each skeleton to use as the anchor for centering. If center_on_anchor_part is False, this has no effect and does not need to be specified.
- Type
Optional[List[str]]
- skeletons
List of sleap.Skeleton objects used to look up the index of each anchor body part. If center_on_anchor_part is False, this has no effect and does not need to be specified.
- Type
Optional[List[sleap.skeleton.Skeleton]]
- instances_key
Name of the example key where the instance points are stored. Defaults to “instances”.
- Type
str
- classmethod from_config(config: sleap.nn.config.data.InstanceCroppingConfig, skeletons: Optional[Union[sleap.skeleton.Skeleton, List[sleap.skeleton.Skeleton]]] = None) → sleap.nn.data.instance_centroids.InstanceCentroidFinder
Build an instance of this class from its configuration options.
- Parameters
config – An InstanceCroppingConfig instance with the desired parameters.
skeletons – A skeleton or list of skeletons to use. This must be provided when cropping instances centered on an anchor part.
- Returns
An instance of this class.
This will assume that center_on_anchor_part is False when the config.center_on_part attribute is not a string.
- Raises
ValueError – If the skeletons are not provided in the arguments and the config specifies an anchor part name.
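The rules above (a non-string center_on_part disables anchor centering; a string center_on_part without skeletons raises ValueError) can be sketched as a small pure-Python helper. This is a hypothetical illustration, not SLEAP's from_config: CroppingConfigStub stands in for InstanceCroppingConfig with only the one field used here, resolve_centering is an invented name, and reusing the same anchor name for every skeleton is an assumption.

```python
from dataclasses import dataclass
from typing import List, Optional, Sequence, Tuple, Union


@dataclass
class CroppingConfigStub:
    """Hypothetical stand-in for InstanceCroppingConfig (only center_on_part)."""

    center_on_part: Optional[str] = None


def resolve_centering(
    config: CroppingConfigStub,
    skeletons: Optional[Union[object, Sequence[object]]] = None,
) -> Tuple[bool, Optional[List[str]]]:
    """Hypothetical helper mirroring the documented from_config rules.

    Returns (center_on_anchor_part, anchor_part_names).
    """
    if not isinstance(config.center_on_part, str):
        # Not a string (e.g. None): center on bounding-box midpoints instead.
        return False, None
    if skeletons is None:
        # An anchor part was named, but there is nothing to look it up in.
        raise ValueError(
            "Skeletons must be provided when centering on an anchor part."
        )
    if not isinstance(skeletons, (list, tuple)):
        # Accept a single skeleton as well as a list of them.
        skeletons = [skeletons]
    # Assumption: the same anchor part name is reused for every skeleton.
    return True, [config.center_on_part] * len(skeletons)
```

For example, resolve_centering(CroppingConfigStub("thorax"), ["fly"]) returns (True, ["thorax"]), while resolve_centering(CroppingConfigStub(None)) returns (False, None) without needing any skeletons.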
- property input_keys: List[str]
Return the keys that incoming elements are expected to have.
- property output_keys: List[str]
Return the keys that outgoing elements will have.
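A plausible way these two properties relate, inferred from the transform_dataset description (instance points under instances_key, plus "skeleton_inds" when centering on anchor parts, and a "centroids" key on output), is sketched below. The function name centroid_finder_keys is hypothetical, and the assumption that outgoing elements keep all incoming keys is not confirmed by this page.

```python
from typing import List, Tuple


def centroid_finder_keys(
    instances_key: str = "instances", center_on_anchor_part: bool = False
) -> Tuple[List[str], List[str]]:
    """Hypothetical sketch of how input/output keys could be derived."""
    input_keys = [instances_key]
    if center_on_anchor_part:
        # Anchor centering needs to know which skeleton each instance uses.
        input_keys.append("skeleton_inds")
    # Assumption: outgoing elements keep their inputs and gain "centroids".
    output_keys = input_keys + ["centroids"]
    return input_keys, output_keys
```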
- transform_dataset(ds_input: tensorflow.python.data.ops.dataset_ops.DatasetV2) → tensorflow.python.data.ops.dataset_ops.DatasetV2
Create a dataset that contains centroids computed from the inputs.
- Parameters
ds_input – A dataset with an “instances” key containing instance points in a tf.float32 tensor of shape (n_instances, n_nodes, 2). If centering on anchor parts, a “skeleton_inds” key of dtype tf.int32 and shape (n_instances,) must also be present to indicate which skeleton is associated with each instance. These indices must match the order of the skeletons attribute of this class.
- Returns
A tf.data.Dataset whose elements contain a “centroids” key holding a tf.float32 tensor of shape (n_instances, 2) with the computed centroids.
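To make the per-element computation concrete, here is a NumPy sketch of what happens to a single example dictionary. It is an illustration under assumptions, not SLEAP's TensorFlow code: add_centroids is a hypothetical name, skeletons are represented by plain lists of node names (skeleton_node_names) instead of sleap.Skeleton objects, and NumPy's nanmin/nanmax stand in for the library's own NaN handling.

```python
from typing import Dict, List, Optional

import numpy as np


def add_centroids(
    example: Dict[str, np.ndarray],
    center_on_anchor_part: bool = False,
    anchor_part_names: Optional[List[str]] = None,
    skeleton_node_names: Optional[List[List[str]]] = None,
    instances_key: str = "instances",
) -> Dict[str, np.ndarray]:
    """Hypothetical NumPy sketch of the per-element centroid computation."""
    # Shape: (n_instances, n_nodes, 2).
    instances = np.asarray(example[instances_key], dtype=np.float32)
    if center_on_anchor_part:
        # Look up each skeleton's anchor node index by name (one per skeleton).
        anchor_ind_per_skeleton = np.array(
            [
                names.index(part)
                for names, part in zip(skeleton_node_names, anchor_part_names)
            ],
            dtype=np.int32,
        )
        # Map every instance to the anchor node of its own skeleton.
        skeleton_inds = np.asarray(example["skeleton_inds"], dtype=np.int32)
        anchor_inds = anchor_ind_per_skeleton[skeleton_inds]
        centroids = instances[np.arange(instances.shape[0]), anchor_inds]
    else:
        # Bounding-box midpoint per instance; NaN (missing) points are ignored.
        centroids = (
            np.nanmin(instances, axis=-2) + np.nanmax(instances, axis=-2)
        ) / 2
    out = dict(example)  # keep all incoming keys
    out["centroids"] = centroids.astype(np.float32)
    return out
```

In a tf.data pipeline the equivalent logic would be applied to every element (for example via Dataset.map); the sketch only shows the math for one element.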
- sleap.nn.data.instance_centroids.find_points_bbox_midpoint(points: tensorflow.python.framework.ops.Tensor) → tensorflow.python.framework.ops.Tensor
Find the midpoint of the bounding box of a set of points.
- Parameters
points – A tf.Tensor of dtype tf.float32 and shape (…, n_points, 2), i.e., of rank >= 2.
- Returns
The midpoints between the bounds of each set of points. The output will be of shape (…, 2), reducing the rank of the input by 1. NaNs will be ignored in the calculation.
Notes
- The midpoint is calculated as:
xy_mid = xy_min + ((xy_max - xy_min) / 2)
       = ((2 * xy_min) / 2) + ((xy_max - xy_min) / 2)
       = (2 * xy_min + xy_max - xy_min) / 2
       = (xy_min + xy_max) / 2
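The final form, (xy_min + xy_max) / 2 with NaNs ignored, maps directly onto a reduction over the points axis. The sketch below uses NumPy rather than TensorFlow and the hypothetical name bbox_midpoint; it is not the library's implementation. Note that a slice made entirely of NaNs yields NaN (NumPy also emits a RuntimeWarning in that case).

```python
import numpy as np


def bbox_midpoint(points: np.ndarray) -> np.ndarray:
    """NumPy sketch of a NaN-ignoring bounding-box midpoint.

    points: array of shape (..., n_points, 2). Returns shape (..., 2).
    """
    # Reduce over the points axis (second to last), skipping NaNs.
    xy_min = np.nanmin(points, axis=-2)
    xy_max = np.nanmax(points, axis=-2)
    # Equivalent to xy_min + (xy_max - xy_min) / 2.
    return (xy_min + xy_max) / 2
```

Because the reduction is over axis -2, any number of leading batch dimensions is handled without reshaping, and the output rank is one less than the input rank.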
- sleap.nn.data.instance_centroids.get_instance_anchors(instances: tensorflow.python.framework.ops.Tensor, anchor_inds: tensorflow.python.framework.ops.Tensor) → tensorflow.python.framework.ops.Tensor
Gather the anchor points of a set of instances.
- Parameters
instances – A tensor of shape (n_instances, n_nodes, 2) containing instance points. This must be rank-3 even if a single instance is present.
anchor_inds – A tensor of shape (n_instances,) and dtype tf.int32. These specify the index of the anchor node for each instance.
- Returns
A tensor of shape (n_instances, 2) containing the anchor point for each instance. This is effectively a slice along the nodes axis, where each instance may use a different node as its anchor.
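The per-instance slice described above is a batched gather: row i is paired with node anchor_inds[i]. The NumPy sketch below uses advanced indexing and the hypothetical name instance_anchors; it is not SLEAP's code. In TensorFlow, one comparable operation is tf.gather(instances, anchor_inds, axis=1, batch_dims=1), though this page does not state which op the library actually uses.

```python
import numpy as np


def instance_anchors(instances: np.ndarray, anchor_inds: np.ndarray) -> np.ndarray:
    """NumPy sketch: pick one node per instance.

    instances: (n_instances, n_nodes, 2); anchor_inds: (n_instances,) ints.
    Returns (n_instances, 2).
    """
    n_instances = instances.shape[0]
    # Pair instance i with node anchor_inds[i] via advanced indexing.
    return instances[np.arange(n_instances), np.asarray(anchor_inds)]
```

As in the documented function, instances must be rank 3 even for a single instance, since the first axis is what gets paired with anchor_inds.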