Custom loss functions and metrics.

class sleap.nn.losses.OHKMLoss(hard_to_easy_ratio: float = 2.0, min_hard_keypoints: int = 2, max_hard_keypoints: int = -1, loss_scale: float = 5.0, name='ohkm', **kwargs)

Online hard keypoint mining loss.

This loss dynamically reweights the MSE of the top-K worst channels in each batch. This is useful when fine-tuning a model to improve performance on a part that is hard to optimize for (e.g., because it is small, hard to see, or often not visible).

Note: This works with any type of channel, so it can also be used for part affinity fields (PAFs).

  • hard_to_easy_ratio – The minimum ratio of an individual keypoint's loss to the lowest keypoint loss for that keypoint to be considered “hard”. This helps shift the focus across groups of keypoints during training.

  • min_hard_keypoints – The minimum number of keypoints that will be considered “hard”, even if their loss ratio is below hard_to_easy_ratio.

  • max_hard_keypoints – The maximum number of hard keypoints to apply scaling to. This helps when a few keypoints are very easy: they can skew the ratio so that loss scaling is applied to most keypoints, which reduces the impact of hard mining altogether.

  • loss_scale – Factor by which to scale the hard keypoint losses.
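The following plain-Python sketch shows how these parameters might interact when choosing the number of hard keypoints (the K in top-K) from per-channel losses. It is not SLEAP's implementation. The following are all assumptions made for illustration:

  • the helper name num_hard_keypoints
  • treating a negative max_hard_keypoints as "no cap"
  • applying the minimum before the maximum
  • the guard against a zero lowest loss

```python
from typing import Sequence


def num_hard_keypoints(
    channel_losses: Sequence[float],
    hard_to_easy_ratio: float = 2.0,
    min_hard_keypoints: int = 2,
    max_hard_keypoints: int = -1,
) -> int:
    """Hypothetical helper: how many channels to treat as "hard" (the K in top-K)."""
    # Assumed guard so that a perfect (zero-loss) channel does not divide by zero.
    easiest = max(min(channel_losses), 1e-12)

    # A channel is hard if its loss is at least `hard_to_easy_ratio` times the
    # lowest channel loss.
    n_hard = sum(1 for loss in channel_losses if loss / easiest >= hard_to_easy_ratio)

    # Assumption: a negative maximum means "no cap", i.e. every channel may be hard.
    n_channels = len(channel_losses)
    if max_hard_keypoints < 0:
        cap = n_channels
    else:
        cap = min(max_hard_keypoints, n_channels)

    # Apply the floor first, then the cap (assumed ordering).
    return min(max(n_hard, min_hard_keypoints), cap)


losses = [1.0, 1.5, 3.0, 8.0]
print(num_hard_keypoints(losses))                        # 2
print(num_hard_keypoints(losses, min_hard_keypoints=3))  # 3
print(num_hard_keypoints(losses, max_hard_keypoints=1))  # 1
```

With the default ratio of 2.0, only the channels with losses 3.0 and 8.0 count as hard. min_hard_keypoints can raise that count, and max_hard_keypoints can lower it.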

call(y_gt, y_pr, sample_weight=None)

Computes the loss for a batch of predictions.

  • y_gt – Ground truth values, with the same shape as y_pr.

  • y_pr – The predicted values.

classmethod from_config(config: sleap.nn.config.optimization.HardKeypointMiningConfig) → sleap.nn.losses.OHKMLoss

Instantiates an OHKMLoss from a hard keypoint mining configuration.

  • config – A HardKeypointMiningConfig with the hard keypoint mining settings.

Returns an OHKMLoss instance.
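from_config is an alternate constructor: it copies settings from a config object into the constructor arguments. The sketch below illustrates that pattern with hypothetical stand-ins. MiningConfigSketch and OHKMLossSketch are not SLEAP classes. It also assumes, without verification, that the config's field names match the constructor's keyword arguments.

```python
from dataclasses import dataclass


@dataclass
class MiningConfigSketch:
    """Hypothetical stand-in for HardKeypointMiningConfig (field names assumed)."""

    hard_to_easy_ratio: float = 2.0
    min_hard_keypoints: int = 2
    max_hard_keypoints: int = -1
    loss_scale: float = 5.0


class OHKMLossSketch:
    """Hypothetical stand-in for OHKMLoss that only stores its hyperparameters."""

    def __init__(
        self,
        hard_to_easy_ratio: float = 2.0,
        min_hard_keypoints: int = 2,
        max_hard_keypoints: int = -1,
        loss_scale: float = 5.0,
        name: str = "ohkm",
    ):
        self.hard_to_easy_ratio = hard_to_easy_ratio
        self.min_hard_keypoints = min_hard_keypoints
        self.max_hard_keypoints = max_hard_keypoints
        self.loss_scale = loss_scale
        self.name = name

    @classmethod
    def from_config(cls, config: MiningConfigSketch) -> "OHKMLossSketch":
        # Alternate constructor: copy each config field into the matching argument.
        return cls(
            hard_to_easy_ratio=config.hard_to_easy_ratio,
            min_hard_keypoints=config.min_hard_keypoints,
            max_hard_keypoints=config.max_hard_keypoints,
            loss_scale=config.loss_scale,
        )


loss = OHKMLossSketch.from_config(MiningConfigSketch(loss_scale=3.0))
print(loss.loss_scale, loss.max_hard_keypoints)  # 3.0 -1
```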

class sleap.nn.losses.PartLoss(*args, **kwargs)

Compute channelwise loss.

Useful for monitoring the MSE for specific body parts (channels).

Its arguments are the index of the channel to compute the MSE for and the name of the loss tensor.

result()

Computes and returns the metric value tensor.

Result computation is an idempotent operation that simply calculates the metric value from the state variables.

update_state(y_gt, y_pr, sample_weight=None)

Accumulates statistics for the metric.

Note: This function is executed as a graph function in graph mode. This means:

  1. Operations on the same resource are executed in textual order. This makes it easier to, for example, add the updated value of one variable to another.

  2. You don’t need to worry about collecting the update ops to execute. All update ops added to the graph by this function will be executed.

As a result, code should generally work the same way with graph or eager execution.

Please use tf.config.experimental_run_functions_eagerly(True) (tf.config.run_functions_eagerly(True) in TensorFlow 2.3 and later) to execute this function eagerly for debugging or profiling.

  • y_gt – Ground truth values, with the same shape as y_pr.

  • y_pr – The predicted values.
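The update_state/result split can be illustrated without TensorFlow. This minimal NumPy sketch of a single-channel MSE metric is hypothetical. ChannelMSESketch, its reset_state method and the element-count bookkeeping are assumptions, not the actual internals of PartLoss. Here update_state accumulates the squared error of one channel across batches. result divides the running total by the number of elements seen, so calling it repeatedly returns the same value.

```python
import numpy as np


class ChannelMSESketch:
    """Hypothetical NumPy stand-in for a stateful single-channel MSE metric."""

    def __init__(self, channel_index: int, name: str = "part_loss"):
        self.channel_index = channel_index
        self.name = name
        self.reset_state()

    def reset_state(self) -> None:
        # Running totals across batches.
        self.sum_sq_error = 0.0
        self.n_elements = 0

    def update_state(self, y_gt: np.ndarray, y_pr: np.ndarray) -> None:
        # Pick one channel out of (batch, height, width, channels) arrays.
        gt = y_gt[..., self.channel_index]
        pr = y_pr[..., self.channel_index]
        # Accumulate the summed squared error and the element count.
        self.sum_sq_error += float(np.sum((gt - pr) ** 2))
        self.n_elements += gt.size

    def result(self) -> float:
        # Idempotent: reads the accumulated state without modifying it.
        if self.n_elements == 0:
            return 0.0
        return self.sum_sq_error / self.n_elements


metric = ChannelMSESketch(channel_index=1)
y_gt = np.zeros((2, 2, 2, 3))
y_pr = np.zeros((2, 2, 2, 3))
y_pr[..., 1] = 2.0  # every element of channel 1 is off by 2
metric.update_state(y_gt, y_pr)
print(metric.result())  # 4.0
metric.update_state(np.zeros((1, 2, 2, 3)), np.zeros((1, 2, 2, 3)))
print(metric.result())  # 32 / 12 ≈ 2.667
```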

sleap.nn.losses.compute_ohkm_loss(y_gt: tensorflow.python.framework.ops.Tensor, y_pr: tensorflow.python.framework.ops.Tensor, hard_to_easy_ratio: float = 2.0, min_hard_keypoints: int = 2, max_hard_keypoints: int = -1, loss_scale: float = 5.0) → tensorflow.python.framework.ops.Tensor

Compute the online hard keypoint mining loss.
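Below is a minimal NumPy sketch of the OHKM computation, assuming inputs of shape (batch, height, width, channels). It follows the OHKMLoss parameter descriptions:

  1. Compute the squared error per channel.
  2. Split channels into hard and easy relative to the lowest channel error.
  3. Select the top K channels, with K clamped by the minimum and maximum.
  4. Scale the selected errors by loss_scale.

The following are assumptions rather than SLEAP's exact code: the function name ohkm_loss_sketch, the divide-by-zero guard, and the final normalization (dividing by the number of batch and spatial elements times K).

```python
import numpy as np


def ohkm_loss_sketch(
    y_gt: np.ndarray,
    y_pr: np.ndarray,
    hard_to_easy_ratio: float = 2.0,
    min_hard_keypoints: int = 2,
    max_hard_keypoints: int = -1,
    loss_scale: float = 5.0,
) -> float:
    """NumPy sketch of OHKM on (batch, height, width, channels) arrays."""
    # Squared error summed over every axis except the channel axis.
    channel_sse = np.sum((y_gt - y_pr) ** 2, axis=(0, 1, 2))
    n_channels = channel_sse.shape[0]

    # Hard channels: error at least `hard_to_easy_ratio` times the easiest channel's.
    easiest = max(float(channel_sse.min()), 1e-12)  # assumed divide-by-zero guard
    n_hard = int(np.sum(channel_sse / easiest >= hard_to_easy_ratio))

    # Clamp K to [min_hard_keypoints, cap]; a negative max is assumed to mean no cap.
    cap = n_channels if max_hard_keypoints < 0 else min(max_hard_keypoints, n_channels)
    k = min(max(n_hard, min_hard_keypoints), cap)
    if k == 0:
        return 0.0

    # Keep the K largest per-channel errors and scale them up.
    hardest = np.sort(channel_sse)[::-1][:k]
    scaled = hardest * loss_scale

    # Assumed normalization: mean over batch, spatial and selected-channel elements.
    batch, height, width = y_gt.shape[:3]
    return float(scaled.sum() / (batch * height * width * k))


y_gt = np.zeros((1, 2, 2, 3))
y_pr = np.zeros((1, 2, 2, 3))
y_pr[..., 0] = 0.1  # easy channel
y_pr[..., 1] = 0.2
y_pr[..., 2] = 1.0  # hardest channel
print(ohkm_loss_sketch(y_gt, y_pr))  # ≈ 2.6
```

In this example, channel 0 sets the baseline. Channels 1 and 2 have 4 and 100 times its error, so both are selected and scaled by 5 before averaging. Whether the real loss normalizes by K, as here, or by some other count is not stated in this reference.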