sleap.nn.identity
Utilities for models that learn identity.
These functions implement the inference logic for classifying peaks using class maps or classification vectors.
- sleap.nn.identity.classify_peaks_from_maps(class_maps: Tensor, peak_points: Tensor, peak_vals: Tensor, peak_sample_inds: Tensor, peak_channel_inds: Tensor, n_channels: int) -> Tuple[Tensor, Tensor, Tensor]
Classify and group local peaks by their class map probability.
- Parameters:
class_maps – Class maps for a batch as a tf.Tensor of dtype tf.float32 and shape (n_samples, height, width, n_classes).
peak_points – Local peak coordinates as a tf.Tensor of dtype tf.float32 and shape (n_peaks, 2). These should be in the same scale as the class maps.
peak_vals – Confidence map value of each peak as a tf.Tensor of dtype tf.float32 and shape (n_peaks,).
peak_sample_inds – Sample index for each peak as a tf.Tensor of dtype tf.int32 and shape (n_peaks,).
peak_channel_inds – Channel index for each peak as a tf.Tensor of dtype tf.int32 and shape (n_peaks,).
n_channels – Integer number of channels (nodes) the instances should have.
- Returns:
A tuple of (points, point_vals, class_probs) containing the grouped peaks.
points: Class-grouped peaks as a tf.Tensor of dtype tf.float32 and shape (n_samples, n_classes, n_channels, 2). Missing points will be denoted by NaNs.
point_vals: The confidence map values for each point as a tf.Tensor of dtype tf.float32 and shape (n_samples, n_classes, n_channels).
class_probs: Classification probabilities for matched points as a tf.Tensor of dtype tf.float32 and shape (n_samples, n_classes, n_channels).
See also: group_class_peaks
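The first step, getting a class probability vector for each peak from the class maps, can be sketched in plain Python. This is only an illustration, not the library's implementation: the helper name sample_class_probs is hypothetical, and the nearest-pixel lookup and the (x, y) coordinate order are assumptions.

```python
def sample_class_probs(class_maps, peak_points, peak_sample_inds):
    """Read the class-map vector under each peak.

    class_maps: nested lists indexed as [sample][row][col][class].
    peak_points: list of (x, y) coordinates in class-map scale (assumed order).
    peak_sample_inds: sample index for each peak.
    Returns one class probability vector per peak, i.e. (n_peaks, n_classes).
    """
    probs = []
    for (x, y), sample in zip(peak_points, peak_sample_inds):
        # Assumption: nearest-pixel lookup by rounding the coordinates.
        row, col = int(round(y)), int(round(x))
        probs.append(list(class_maps[sample][row][col]))
    return probs


# One sample, a 2x2 class map, two classes.
class_maps = [[
    [[0.9, 0.1], [0.2, 0.8]],
    [[0.5, 0.5], [0.3, 0.7]],
]]
peak_points = [(1.2, 0.1), (0.0, 0.9)]
peak_class_probs = sample_class_probs(class_maps, peak_points, [0, 0])
```

The resulting per-peak vectors have the (n_peaks, n_classes) layout that group_class_peaks expects for peak_class_probs.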
- sleap.nn.identity.classify_peaks_from_vectors(peak_points: Tensor, peak_vals: Tensor, peak_class_probs: Tensor, crop_sample_inds: Tensor, n_samples: int) -> Tuple[Tensor, Tensor, Tensor]
Group peaks by classification probabilities.
This is used in top-down classification models.
- Parameters:
peak_points –
peak_vals –
peak_class_probs –
crop_sample_inds –
n_samples – Number of samples in the batch.
- Returns:
A tuple of (points, point_vals, class_probs).
points: Class-grouped peaks as a tf.Tensor of dtype tf.float32 and shape (n_samples, n_classes, n_channels, 2). Missing points will be denoted by NaNs.
point_vals: The confidence map values for each point as a tf.Tensor of dtype tf.float32 and shape (n_samples, n_classes, n_channels).
class_probs: Classification probabilities for matched points as a tf.Tensor of dtype tf.float32 and shape (n_samples, n_classes, n_channels).
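The return layout documented above can be pictured as a NaN-filled grid that matched peaks are written into. This is a minimal sketch in plain Python rather than TensorFlow; the helper names scatter_grouped_points and is_missing are hypothetical and do not exist in sleap.nn.identity.

```python
import math


def scatter_grouped_points(points, sample_inds, class_inds, channel_inds,
                           n_samples, n_classes, n_channels):
    """Place matched peaks into a grid indexed [sample][class][channel][xy].

    Slots that receive no peak keep their NaN fill, which marks a missing point.
    """
    nan = float("nan")
    grid = [
        [[[nan, nan] for _ in range(n_channels)] for _ in range(n_classes)]
        for _ in range(n_samples)
    ]
    for (x, y), s, c, ch in zip(points, sample_inds, class_inds, channel_inds):
        grid[s][c][ch] = [x, y]
    return grid


def is_missing(point):
    """A point is missing if any of its coordinates is NaN."""
    return any(math.isnan(v) for v in point)


# One sample, two classes, two channels; only one peak was matched.
grid = scatter_grouped_points([(1.0, 2.0)], [0], [1], [0], 1, 2, 2)
```

When consuming the real outputs, the same NaN check is what separates matched points from empty slots.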
- sleap.nn.identity.group_class_peaks(peak_class_probs: Tensor, peak_sample_inds: Tensor, peak_channel_inds: Tensor, n_samples: int, n_channels: int) -> Tuple[Tensor, Tensor]
Group local peaks using class probabilities.
This is useful for matching peaks that span multiple samples and channels into classes (e.g., instance identities) by their class probability.
- Parameters:
peak_class_probs – Class probabilities for each peak as a tf.Tensor of dtype tf.float32 and shape (n_peaks, n_classes).
peak_sample_inds – Sample index for each peak as a tf.Tensor of dtype tf.int32 and shape (n_peaks,).
peak_channel_inds – Channel index for each peak as a tf.Tensor of dtype tf.int32 and shape (n_peaks,).
n_samples – Integer number of samples in the batch.
n_channels – Integer number of channels (nodes) the instances should have.
- Returns:
A tuple of (peak_inds, class_inds).
peak_inds: Indices of class-grouped peaks within [0, n_peaks). Will be at most n_classes long.
class_inds: Indices of the corresponding class for each peak within [0, n_classes). Will be at most n_classes long.
Notes
Peaks will be assigned to classes by their probability using the Hungarian algorithm. Peaks that are assigned to classes that are not the highest probability for each class are removed from the matches.
See also: classify_peaks_from_maps, classify_peaks_from_vectors
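The matching described in the Notes can be sketched for a single sample and channel as follows. This is a plain-Python illustration under stated assumptions, not the library's code: the helper name match_peaks_to_classes is hypothetical, a brute-force search over all one-to-one assignments stands in for the Hungarian algorithm (it finds the same optimum but only scales to tiny inputs), and the filtering step reflects one reading of the note, namely that a match is kept only when the assigned peak has the highest probability for that class. In practice an optimized solver such as scipy.optimize.linear_sum_assignment would replace the brute-force loop.

```python
from itertools import combinations, permutations


def match_peaks_to_classes(probs):
    """Match peaks to classes for one (sample, channel) group.

    probs[p][c] is the probability that peak p belongs to class c.
    Returns (peak_inds, class_inds) as two parallel lists.
    """
    n_peaks = len(probs)
    n_classes = len(probs[0]) if n_peaks else 0
    k = min(n_peaks, n_classes)
    if k == 0:
        return [], []

    # Brute-force stand-in for the Hungarian algorithm: try every one-to-one
    # pairing of k peaks with k classes and keep the highest total probability.
    best_pairs, best_score = [], float("-inf")
    for peaks in combinations(range(n_peaks), k):
        for classes in permutations(range(n_classes), k):
            pairs = list(zip(peaks, classes))
            score = sum(probs[p][c] for p, c in pairs)
            if score > best_score:
                best_pairs, best_score = pairs, score

    # Assumed filter: drop a match if another peak has a higher probability
    # for the same class.
    kept = sorted(
        (p, c) for p, c in best_pairs
        if probs[p][c] == max(row[c] for row in probs)
    )
    return [p for p, _ in kept], [c for _, c in kept]


# Peak 1 is assigned to class 0, but peak 0 scores higher for class 0,
# so that match is dropped and only (peak 0, class 1) survives.
peak_inds, class_inds = match_peaks_to_classes([[0.9, 0.6], [0.8, 0.1]])
```

Applying this per sample and channel, and offsetting the local peak indices back into [0, n_peaks), would give index lists of the kind the function returns.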