sleap.nn.data.instance_centroids

Transformers for finding instance centroids.

class sleap.nn.data.instance_centroids.InstanceCentroidFinder(center_on_anchor_part: bool = False, anchor_part_names=None, skeletons=None, instances_key: str = 'instances')

Data transformer to add centroid information to instances.

This is useful as a transformation applied to data streams that will feed centroid networks or instance cropping.

center_on_anchor_part

If True, specifies that centering should be done relative to a body part rather than the midpoint of the instance bounding box. If False, the midpoint of the bounding box of all points will be used.

Type

bool

anchor_part_names

List of strings specifying the body part name in each skeleton to use as anchors for centering. If center_on_anchor_part is False, this has no effect and does not need to be specified.

Type

Optional[List[str]]

skeletons

List of sleap.Skeleton objects to use for looking up the index of the anchor body parts. If center_on_anchor_part is False, this has no effect and does not need to be specified.

Type

Optional[List[sleap.skeleton.Skeleton]]

instances_key

Name of the example key where the instance points are stored. Defaults to “instances”.

Type

str
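The anchor lookup that anchor_part_names and skeletons enable can be sketched as follows. This is a minimal illustration, not SLEAP's implementation: each skeleton is modeled as a plain list of node names (an assumption standing in for sleap.Skeleton objects), and anchor_node_inds is a hypothetical helper that pairs each skeleton with its anchor part name.

```python
from typing import List


def anchor_node_inds(
    skeleton_node_names: List[List[str]], anchor_part_names: List[str]
) -> List[int]:
    """Return the index of the anchor node within each skeleton.

    Hypothetical helper: skeletons are modeled as lists of node names.
    """
    if len(skeleton_node_names) != len(anchor_part_names):
        raise ValueError("Expected one anchor part name per skeleton.")
    inds = []
    for node_names, anchor_name in zip(skeleton_node_names, anchor_part_names):
        # list.index raises ValueError if the anchor is not a node of the skeleton.
        inds.append(node_names.index(anchor_name))
    return inds


# Two skeletons that place the same part at different node indices.
skeletons = [["head", "thorax", "abdomen"], ["nose", "tail", "thorax"]]
print(anchor_node_inds(skeletons, ["thorax", "thorax"]))  # [1, 2]
```

Because the lookup is done per skeleton, the same part name can resolve to a different node index in each skeleton.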

classmethod from_config(config: sleap.nn.config.data.InstanceCroppingConfig, skeletons: Optional[Union[sleap.skeleton.Skeleton, List[sleap.skeleton.Skeleton]]] = None) → sleap.nn.data.instance_centroids.InstanceCentroidFinder

Build an instance of this class from its configuration options.

Parameters
  • config – An InstanceCroppingConfig instance with the desired parameters.

  • skeletons – List of skeletons to use. This must be provided if doing instance cropping centered on an anchor part.

Returns

An instance of this class.

This will assume that center_on_anchor_part is False when the config.center_on_part attribute is not a string.

Raises

ValueError – If the skeletons are not provided in the arguments and the config specifies an anchor part name.
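The option resolution described for from_config can be sketched in plain Python. This is a hedged illustration, not the library's code: CroppingConfigStub is a hypothetical stand-in that models only the center_on_part attribute, centroid_finder_kwargs is a hypothetical helper, and repeating the single part name once per skeleton is an assumption.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union


@dataclass
class CroppingConfigStub:
    """Hypothetical stand-in for InstanceCroppingConfig (center_on_part only)."""

    center_on_part: Optional[str] = None


def centroid_finder_kwargs(
    config: CroppingConfigStub,
    skeletons: Optional[Union[Any, List[Any]]] = None,
) -> Dict[str, Any]:
    """Resolve constructor arguments as from_config is documented to."""
    # A single skeleton is accepted and wrapped in a list.
    if skeletons is not None and not isinstance(skeletons, (list, tuple)):
        skeletons = [skeletons]
    # Any non-string center_on_part (e.g. None) means bounding box centering.
    center_on_anchor_part = isinstance(config.center_on_part, str)
    anchor_part_names = None
    if center_on_anchor_part:
        if skeletons is None:
            raise ValueError("Skeletons must be provided to center on an anchor part.")
        # Assumption: the same part name is the anchor in every skeleton.
        anchor_part_names = [config.center_on_part] * len(skeletons)
    return {
        "center_on_anchor_part": center_on_anchor_part,
        "anchor_part_names": anchor_part_names,
        "skeletons": None if skeletons is None else list(skeletons),
    }


print(centroid_finder_kwargs(CroppingConfigStub()))
print(centroid_finder_kwargs(CroppingConfigStub("thorax"), "fly_skeleton"))
```

The returned keys mirror the constructor parameters in the class signature, so in this sketch they could be unpacked into the constructor; whether from_config builds its arguments exactly this way is not confirmed here.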

property input_keys: List[str]

Return the keys that incoming elements are expected to have.

property output_keys: List[str]

Return the keys that outgoing elements will have.

transform_dataset(ds_input: tensorflow.python.data.ops.dataset_ops.DatasetV2) → tensorflow.python.data.ops.dataset_ops.DatasetV2

Create a dataset that contains centroids computed from the inputs.

Parameters

ds_input – A dataset with “instances” key containing instance points in a tf.float32 tensor of shape (n_instances, n_nodes, 2). If centering on anchor parts, a “skeleton_inds” key of dtype tf.int32 and shape (n_instances,) must also be present to indicate which skeleton is associated with each instance. These must match the order in the skeletons attribute of this class.

Returns

A tf.data.Dataset whose elements contain a “centroids” key holding a tf.float32 tensor of shape (n_instances, 2) with the computed centroids.
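The per-element logic of this transform can be sketched in NumPy for a single example dictionary. This only mirrors the documented inputs and outputs; the library applies TensorFlow ops over a tf.data.Dataset. The helper add_centroids and its anchor_inds_per_skeleton argument (the anchor node index for each skeleton) are hypothetical.

```python
from typing import Dict, Optional, Sequence

import numpy as np


def add_centroids(
    example: Dict[str, np.ndarray],
    anchor_inds_per_skeleton: Optional[Sequence[int]] = None,
    instances_key: str = "instances",
) -> Dict[str, np.ndarray]:
    """Return a copy of `example` with a "centroids" array of shape (n_instances, 2)."""
    # (n_instances, n_nodes, 2); NaN marks missing points.
    points = np.asarray(example[instances_key], dtype=np.float32)
    if anchor_inds_per_skeleton is None:
        # Bounding box midpoint of each instance, ignoring NaNs.
        centroids = (np.nanmin(points, axis=1) + np.nanmax(points, axis=1)) / 2
    else:
        # Look up the anchor node of the skeleton each instance belongs to.
        skeleton_inds = np.asarray(example["skeleton_inds"], dtype=np.int32)
        node_inds = np.asarray(anchor_inds_per_skeleton, dtype=np.int32)[skeleton_inds]
        # Pick one node per instance: points[i, node_inds[i]].
        centroids = points[np.arange(points.shape[0]), node_inds]
    out = dict(example)  # shallow copy; the input keys are preserved
    out["centroids"] = centroids.astype(np.float32)
    return out


example = {
    "instances": np.array(
        [[[0, 0], [4, 2], [np.nan, np.nan]],
         [[10, 10], [12, 16], [11, 11]]],
        dtype=np.float32,
    ),
    "skeleton_inds": np.array([0, 0], dtype=np.int32),
}
print(add_centroids(example)["centroids"])  # midpoints (2, 1) and (11, 13)
print(add_centroids(example, anchor_inds_per_skeleton=[1])["centroids"])  # node 1
```

Note that a missing (NaN) anchor point yields a NaN centroid in the anchor branch, whereas the bounding box branch skips missing points.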

sleap.nn.data.instance_centroids.find_points_bbox_midpoint(points: tensorflow.python.framework.ops.Tensor) → tensorflow.python.framework.ops.Tensor

Find the midpoint of the bounding box of a set of points.

Parameters

points – A tf.Tensor of dtype tf.float32 and of shape (…, n_points, 2), i.e., rank >= 2.

Returns

The midpoints between the bounds of each set of points. The output will be of shape (…, 2), reducing the rank of the input by 1. NaNs will be ignored in the calculation.

Notes

The midpoint is calculated as:
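$$
\begin{aligned}
\mathrm{xy}_{\mathrm{mid}} &= \mathrm{xy}_{\min} + \frac{\mathrm{xy}_{\max} - \mathrm{xy}_{\min}}{2} \\
&= \frac{2\,\mathrm{xy}_{\min}}{2} + \frac{\mathrm{xy}_{\max} - \mathrm{xy}_{\min}}{2} \\
&= \frac{2\,\mathrm{xy}_{\min} + \mathrm{xy}_{\max} - \mathrm{xy}_{\min}}{2} \\
&= \frac{\mathrm{xy}_{\min} + \mathrm{xy}_{\max}}{2}
\end{aligned}
$$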
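The midpoint formula in the Notes can be checked with a short NumPy sketch. This is an illustrative re-implementation, not the TensorFlow code; points_bbox_midpoint is a hypothetical name. It reduces over the points axis (axis -2), so any number of leading dimensions is supported, and it uses nanmin/nanmax to ignore NaNs as documented.

```python
import numpy as np


def points_bbox_midpoint(points: np.ndarray) -> np.ndarray:
    """NaN-aware bounding box midpoint over the points axis (axis -2)."""
    xy_min = np.nanmin(points, axis=-2)  # shape (..., 2)
    xy_max = np.nanmax(points, axis=-2)
    # Equivalent to xy_min + (xy_max - xy_min) / 2.
    return (xy_min + xy_max) / 2


single = np.array([[1.0, 5.0], [3.0, 1.0], [np.nan, np.nan]])  # (3, 2)
batch = np.stack([single, single + 10])  # (2, 3, 2)
print(points_bbox_midpoint(single))  # [2. 3.]
print(points_bbox_midpoint(batch))  # midpoints (2, 3) and (12, 13)
```

An input of shape (…, n_points, 2) produces an output of shape (…, 2), matching the rank reduction described in Returns.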

sleap.nn.data.instance_centroids.get_instance_anchors(instances: tensorflow.python.framework.ops.Tensor, anchor_inds: tensorflow.python.framework.ops.Tensor) → tensorflow.python.framework.ops.Tensor

Gather the anchor points of a set of instances.

Parameters
  • instances – A tensor of shape (n_instances, n_nodes, 2) containing instance points. This must be rank-3 even if a single instance is present.

  • anchor_inds – A tensor of shape (n_instances,) and dtype tf.int32. These specify the index of the anchor node for each instance.

Returns

A tensor of shape (n_instances, 2) containing the anchor points for each instance. This is effectively a slice along the nodes axis, where each instance may use a different node as its anchor.
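The per-instance slice can be illustrated with NumPy advanced indexing. This is a hedged sketch; instance_anchors is a hypothetical name, and it is not the library's TensorFlow implementation. In TensorFlow, one equivalent selection (not necessarily what SLEAP uses) is tf.gather(instances, anchor_inds, axis=1, batch_dims=1).

```python
import numpy as np


def instance_anchors(instances: np.ndarray, anchor_inds: np.ndarray) -> np.ndarray:
    """Select instances[i, anchor_inds[i]] for every instance i."""
    instances = np.asarray(instances)
    anchor_inds = np.asarray(anchor_inds, dtype=np.int64)
    if instances.ndim != 3:
        raise ValueError("instances must have shape (n_instances, n_nodes, 2).")
    # Advanced indexing pairs row i with node anchor_inds[i] -> (n_instances, 2).
    return instances[np.arange(instances.shape[0]), anchor_inds]


instances = np.array(
    [[[0, 0], [1, 1], [2, 2]],
     [[10, 10], [11, 11], [12, 12]]],
    dtype=np.float32,
)
print(instance_anchors(instances, [2, 0]))  # rows (2, 2) and (10, 10)
```

The rank check reflects the documented requirement that instances be rank-3 even when only one instance is present.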