
"""Custom loss functions and metrics."""

import tensorflow as tf
from sleap.nn.config import HardKeypointMiningConfig


def compute_ohkm_loss(
    y_gt: tf.Tensor,
    y_pr: tf.Tensor,
    hard_to_easy_ratio: float = 2.0,
    min_hard_keypoints: int = 2,
    max_hard_keypoints: int = -1,
    loss_scale: float = 5.0,
) -> tf.Tensor:
    """Compute the online hard keypoint mining loss."""
    # Compute elementwise squared difference.
    loss = tf.math.squared_difference(y_gt, y_pr)  # rank 4

    # Store initial shape for normalization.
    batch_shape = tf.shape(loss)

    # Reduce over everything but the channels axis.
    loss = tf.reduce_sum(loss, axis=[0, 1, 2])

    # Compute the loss for the "easy" (lowest loss) keypoint.
    best_loss = tf.math.reduce_min(loss)

    # Find the number of hard keypoints.
    is_hard_keypoint = (loss / best_loss) >= hard_to_easy_ratio
    n_hard_keypoints = tf.reduce_sum(tf.cast(is_hard_keypoint, tf.int32))

    # Work out the actual final number of keypoints to consider as hard.
    if max_hard_keypoints < 0:
        max_hard_keypoints = tf.shape(loss)[0]
    else:
        max_hard_keypoints = tf.minimum(max_hard_keypoints, tf.shape(loss)[0])
    k = tf.minimum(
        tf.maximum(n_hard_keypoints, min_hard_keypoints), max_hard_keypoints
    )

    # Pull out the top hard values.
    k_vals, k_inds = tf.math.top_k(loss, k=k, sorted=False)

    # Apply weights.
    k_loss = k_vals * loss_scale

    # Reduce over all channels.
    n_elements = tf.cast(
        batch_shape[0] * batch_shape[1] * batch_shape[2] * k, tf.float32
    )
    k_loss = tf.reduce_sum(k_loss) / n_elements

    return k_loss
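
As a quick sanity check, compute_ohkm_loss can be called directly on rank-4
tensors of shape (samples, height, width, channels). A minimal sketch using
random data (the shapes and values here are illustrative only):

    y_gt = tf.random.uniform([2, 64, 64, 5])  # "ground truth" confidence maps
    y_pr = tf.random.uniform([2, 64, 64, 5])  # "predicted" confidence maps

    # Scalar loss: the rescaled sum over the k hardest (highest loss) channels,
    # normalized by samples * height * width * k.
    loss = compute_ohkm_loss(y_gt, y_pr, hard_to_easy_ratio=2.0, min_hard_keypoints=2)
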
class OHKMLoss(tf.keras.losses.Loss):
    """Online hard keypoint mining loss.

    This loss serves to dynamically reweight the MSE of the top-K worst
    performing channels in each batch. This is useful when fine-tuning a model
    to improve performance on a part that is hard to optimize for (e.g., small,
    hard to see, or often not visible).

    Note: This works with any type of channel, so it can work for PAFs as well.

    Attributes:
        hard_to_easy_ratio: The minimum ratio of an individual keypoint's loss
            to the lowest keypoint loss in order for it to be considered as
            "hard". This helps to switch focus across groups of keypoints
            during training.
        min_hard_keypoints: The minimum number of keypoints that will be
            considered as "hard", even if their loss ratio does not exceed the
            `hard_to_easy_ratio`.
        max_hard_keypoints: The maximum number of hard keypoints to apply
            scaling to. This can help when there are a few very easy keypoints
            which may skew the ratio and result in loss scaling being applied
            to most keypoints, which can reduce the impact of hard mining
            altogether.
        loss_scale: Factor to scale the hard keypoint losses by.
    """

    def __init__(
        self,
        hard_to_easy_ratio: float = 2.0,
        min_hard_keypoints: int = 2,
        max_hard_keypoints: int = -1,
        loss_scale: float = 5.0,
        name="ohkm",
        **kwargs
    ):
        super(OHKMLoss, self).__init__(name=name, **kwargs)
        self.hard_to_easy_ratio = hard_to_easy_ratio
        self.min_hard_keypoints = min_hard_keypoints
        self.max_hard_keypoints = max_hard_keypoints
        self.loss_scale = loss_scale

    @classmethod
    def from_config(cls, config: HardKeypointMiningConfig) -> "OHKMLoss":
        """Create an instance of this loss from a `HardKeypointMiningConfig`.

        A `max_hard_keypoints` of `None` in the config is mapped to -1 (no limit).
        """
        return cls(
            hard_to_easy_ratio=config.hard_to_easy_ratio,
            min_hard_keypoints=config.min_hard_keypoints,
            max_hard_keypoints=config.max_hard_keypoints
            if config.max_hard_keypoints is not None
            else -1,
            loss_scale=config.loss_scale,
        )

    def call(self, y_gt, y_pr, sample_weight=None):
        return compute_ohkm_loss(
            y_gt,
            y_pr,
            hard_to_easy_ratio=self.hard_to_easy_ratio,
            min_hard_keypoints=self.min_hard_keypoints,
            max_hard_keypoints=self.max_hard_keypoints,
            loss_scale=self.loss_scale,
        )
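
In training code, an instance of this loss can be passed directly to Keras at
compile time. A hedged sketch, assuming model is some tf.keras.Model whose
output is a rank-4 confidence map tensor matching the targets (the model
itself is hypothetical here):

    loss_fn = OHKMLoss(hard_to_easy_ratio=2.0, min_hard_keypoints=2, loss_scale=5.0)
    model.compile(optimizer="adam", loss=loss_fn)

Equivalently, OHKMLoss.from_config can build the instance from a
HardKeypointMiningConfig produced by the training configuration.
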
class PartLoss(tf.keras.metrics.Metric):
    """Compute channelwise loss.

    Useful for monitoring the MSE for specific body parts (channels).

    Attributes:
        channel_ind: Index of the channel to compute the MSE for.
        name: Name of the loss tensor.
    """

    def __init__(self, channel_ind, name="part_loss", **kwargs):
        super(PartLoss, self).__init__(name=name, **kwargs)
        self.channel_ind = channel_ind
        self.channel_mse = self.add_weight(
            name=name + ".mse", initializer="zeros", dtype=tf.float32
        )
        self.n_samples = self.add_weight(
            name=name + ".n_samples", initializer="zeros", dtype=tf.int32
        )
        self.height = self.add_weight(
            name=name + ".height", initializer="zeros", dtype=tf.int32
        )
        self.width = self.add_weight(
            name=name + ".width", initializer="zeros", dtype=tf.int32
        )

    def update_state(self, y_gt, y_pr, sample_weight=None):
        shape = tf.shape(y_gt)  # inputs are rank 4: (samples, height, width, channels)
        n_samples = shape[0]
        channel_mse = tf.reduce_sum(
            tf.math.squared_difference(
                tf.gather(y_gt, self.channel_ind, axis=3),
                tf.gather(y_pr, self.channel_ind, axis=3),
            )
        )
        self.height.assign(shape[1])
        self.width.assign(shape[2])
        self.n_samples.assign_add(n_samples)
        self.channel_mse.assign_add(channel_mse)

    def result(self):
        return self.channel_mse / tf.cast(
            self.n_samples * self.height * self.width, tf.float32
        )
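
The metric can be driven manually or attached at compile time. A minimal
sketch of manual use, reusing the illustrative y_gt and y_pr tensors from
above (channel index 0 is arbitrary):

    part_mse = PartLoss(channel_ind=0, name="part0_mse")
    part_mse.update_state(y_gt, y_pr)  # accumulate sums for this batch
    print(float(part_mse.result()))  # running MSE for channel 0

To monitor several body parts during training, one PartLoss per channel can
be included in the metrics list passed to model.compile, e.g.
metrics=[PartLoss(i, name=f"part{i}") for i in range(n_channels)], where
n_channels is a hypothetical channel count.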