"""Transformers for finding instance centroids."""
import tensorflow as tf
import attr
from typing import Optional, List, Text, Union
import sleap
from sleap.nn.data.utils import ensure_list
from sleap.nn.config import InstanceCroppingConfig
def find_points_bbox_midpoint(points: tf.Tensor) -> tf.Tensor:
    """Find the midpoint of the bounding box of a set of points.

    Args:
        points: A tensor of a floating point dtype (e.g. tf.float32) and of shape
            (..., n_points, 2), i.e., rank >= 2.

    Returns:
        The midpoints between the bounds of each set of points. The output will be
        of shape (..., 2), reducing the rank of the input by 1. NaNs are ignored in
        the calculation. If every value of a coordinate within a set is NaN, the
        midpoint for that coordinate is NaN.

    Notes:
        The midpoint is calculated as:
            xy_mid = xy_min + ((xy_max - xy_min) / 2)
                   = ((2 * xy_min) / 2) + ((xy_max - xy_min) / 2)
                   = (2 * xy_min + xy_max - xy_min) / 2
                   = (xy_min + xy_max) / 2
    """
    points = tf.convert_to_tensor(points)

    # tf.reduce_min / tf.reduce_max do not reliably skip NaNs, so replace them
    # with values that can never win the reduction: +inf for the min and -inf
    # for the max. When all values are NaN the result is (inf + -inf) * 0.5,
    # which is NaN and signals "no valid points" instead of a spurious midpoint.
    is_missing = tf.math.is_nan(points)
    inf = tf.constant(float("inf"), dtype=points.dtype)
    pts_min = tf.reduce_min(tf.where(is_missing, inf, points), axis=-2)
    pts_max = tf.reduce_max(tf.where(is_missing, -inf, points), axis=-2)
    return (pts_max + pts_min) * 0.5
def get_instance_anchors(instances: tf.Tensor, anchor_inds: tf.Tensor) -> tf.Tensor:
    """Gather the anchor points of a set of instances.

    Args:
        instances: A tensor of shape (n_instances, n_nodes, 2) containing instance
            points. This must be rank-3 even if a single instance is present.
        anchor_inds: A tensor of shape (n_instances,) with an integer dtype. These
            specify the index of the anchor node for each instance. Values are cast
            to tf.int32 before use, so int64 indices are also accepted.

    Returns:
        A tensor of shape (n_instances, 2) containing the anchor points for each
        instance. This is basically a slice along the nodes axis, where each
        instance may potentially have a different node to use as an anchor.
    """
    # tf.stack requires both index columns to share a dtype, and tf.range over
    # tf.shape(...) yields int32, so normalize the node indices to int32.
    anchor_inds = tf.cast(anchor_inds, tf.int32)
    instance_inds = tf.range(tf.shape(anchor_inds)[0])

    # Each row is an (instance_index, node_index) pair for tf.gather_nd.
    inds = tf.stack([instance_inds, anchor_inds], axis=-1)
    return tf.gather_nd(instances, inds)
@attr.s(auto_attribs=True)
class InstanceCentroidFinder:
    """Data transformer to add centroid information to instances.

    This is useful as a transformation to data streams that will be used in centroid
    networks or for instance cropping.

    Attributes:
        center_on_anchor_part: If True, specifies that centering should be done relative
            to a body part rather than the midpoint of the instance bounding box. If
            False, the midpoint of the bounding box of all points will be used.
        anchor_part_names: List of strings specifying the body part name in each
            skeleton to use as anchors for centering. If `center_on_anchor_part` is
            False, this has no effect and does not need to be specified. Non-None
            values are passed through `ensure_list` on assignment.
        skeletons: List of `sleap.Skeleton`s to use for looking up the index of the
            anchor body parts. If `center_on_anchor_part` is False, this has no effect
            and does not need to be specified. Non-None values are passed through
            `ensure_list` on assignment.
    """

    center_on_anchor_part: bool = False
    anchor_part_names: Optional[List[Text]] = attr.ib(
        default=None, converter=attr.converters.optional(ensure_list)
    )
    skeletons: Optional[List[sleap.Skeleton]] = attr.ib(
        default=None, converter=attr.converters.optional(ensure_list)
    )

    @classmethod
    def from_config(
        cls,
        config: InstanceCroppingConfig,
        skeletons: Optional[Union[sleap.Skeleton, List[sleap.Skeleton]]] = None,
    ) -> "InstanceCentroidFinder":
        """Build an instance of this class from its configuration options.

        Args:
            config: An `InstanceCroppingConfig` instance with the desired parameters.
            skeletons: List of skeletons to use. This must be provided if doing instance
                cropping centered on an anchor part.

        Returns:
            An instance of this class.

        Raises:
            ValueError: If the skeletons are not provided in the arguments and the
                config specifies an anchor part name.

        Notes:
            `center_on_anchor_part` is set to False whenever `config.center_on_part`
            is not a string. In that case no anchor part names or skeletons are
            stored, even if `skeletons` was provided.
        """
        if not isinstance(config.center_on_part, str):
            # Bounding-box midpoint centering: anchor lookups are not needed.
            return cls(
                center_on_anchor_part=False, anchor_part_names=None, skeletons=None
            )

        if skeletons is None:
            raise ValueError(
                "Skeletons must be provided when the config specifies an anchor "
                "part (config.center_on_part = "
                f"{config.center_on_part})."
            )
        return cls(
            center_on_anchor_part=True,
            anchor_part_names=config.center_on_part,
            skeletons=skeletons,
        )

    @property
    def input_keys(self) -> List[Text]:
        """Return the keys that incoming elements are expected to have."""
        # Anchor-based centering also needs to know which skeleton each instance
        # belongs to in order to look up the anchor node.
        if self.center_on_anchor_part:
            return ["instances", "skeleton_inds"]
        return ["instances"]

    @property
    def output_keys(self) -> List[Text]:
        """Return the keys that outgoing elements will have."""
        return self.input_keys + ["centroids"]