Custom loss functions and metrics.
- class sleap.nn.losses.OHKMLoss(hard_to_easy_ratio: float = 2.0, min_hard_keypoints: int = 2, max_hard_keypoints: int = -1, loss_scale: float = 5.0, name='ohkm', **kwargs)
Online hard keypoint mining loss.
This loss dynamically reweights the MSE of the top-K worst channels in each batch. This is useful when fine-tuning a model to improve performance on body parts that are hard to optimize for (e.g., parts that are small, hard to see, or often not visible).
Note: This works with any type of channel, so it can also be used for part affinity fields (PAFs).
hard_to_easy_ratio – The minimum ratio of an individual keypoint's loss to the lowest keypoint loss for that keypoint to be considered “hard”. This helps shift the focus across groups of keypoints during training.
min_hard_keypoints – The minimum number of keypoints that will be considered “hard”, even if they do not meet the hard_to_easy_ratio threshold.
max_hard_keypoints – The maximum number of hard keypoints to apply scaling to. This can help when a few very easy keypoints skew the ratio so that loss scaling is applied to most keypoints, which can reduce the impact of hard mining altogether.
loss_scale – The factor by which the hard keypoint losses are scaled.
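The parameters above interact as follows. This is a minimal pure-Python sketch, not SLEAP's implementation: it assumes the per-channel losses have already been computed, treats a non-positive max_hard_keypoints (such as the default -1) as “no cap”, and guards against a zero lowest loss with a small epsilon. The function name select_hard_keypoints is hypothetical.

```python
from typing import List, Sequence


def select_hard_keypoints(
    channel_losses: Sequence[float],
    hard_to_easy_ratio: float = 2.0,
    min_hard_keypoints: int = 2,
    max_hard_keypoints: int = -1,
) -> List[int]:
    """Return indices of the channels treated as "hard", worst first (sketch)."""
    # Lowest channel loss; fall back to a tiny value to avoid dividing by zero.
    best = min(channel_losses) or 1e-6
    # A channel is hard if its loss is at least `hard_to_easy_ratio` x the lowest loss.
    n_hard = sum(1 for loss in channel_losses if loss / best >= hard_to_easy_ratio)
    # Always treat at least `min_hard_keypoints` channels as hard.
    n_hard = max(n_hard, min_hard_keypoints)
    # Assumption: a non-positive maximum means "no upper limit".
    if max_hard_keypoints > 0:
        n_hard = min(n_hard, max_hard_keypoints)
    n_hard = min(n_hard, len(channel_losses))
    # Take the n_hard channels with the largest losses.
    order = sorted(
        range(len(channel_losses)), key=lambda i: channel_losses[i], reverse=True
    )
    return order[:n_hard]
```

For losses [1.0, 5.0, 2.5, 1.2], only channels 1 and 2 reach twice the lowest loss, so they are selected; raising min_hard_keypoints to 3 also pulls in channel 3.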
- call(y_gt, y_pr, sample_weight=None)
y_gt – Ground truth values, with shape [batch_size, d0, .. dN], except for sparse loss functions such as sparse categorical crossentropy, where the shape is [batch_size, d0, .. dN-1].
y_pr – The predicted values, with shape [batch_size, d0, .. dN].
Returns: Loss values with shape [batch_size, d0, .. dN-1].
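The generic shape convention above reduces over the last axis. A hypothetical pure-Python sketch for rank-3 inputs ([batch_size, d0, d1] to [batch_size, d0]), using MSE as the example loss; it illustrates the convention only and is not what OHKMLoss.call itself computes.

```python
from typing import List

Matrix = List[List[float]]


def mse_last_axis(y_true: List[Matrix], y_pred: List[Matrix]) -> Matrix:
    """Mean squared error over the last axis: [batch, d0, d1] -> [batch, d0]."""
    return [
        [
            # Average the squared errors along the final axis of each row.
            sum((t - p) ** 2 for t, p in zip(row_t, row_p)) / len(row_t)
            for row_t, row_p in zip(sample_t, sample_p)
        ]
        for sample_t, sample_p in zip(y_true, y_pred)
    ]
```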
- classmethod from_config(config: sleap.nn.config.optimization.HardKeypointMiningConfig) → sleap.nn.losses.OHKMLoss
Instantiates an OHKMLoss from its config.
config – A sleap.nn.config.optimization.HardKeypointMiningConfig instance.
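A sketch of the from_config pattern, assuming a hypothetical MiningConfig dataclass whose field names match the OHKMLoss constructor arguments and where None means “no maximum”; the real HardKeypointMiningConfig fields and defaults may differ. OHKMLossSketch is likewise a hypothetical stand-in that only stores its settings.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class MiningConfig:
    """Hypothetical stand-in for HardKeypointMiningConfig."""

    hard_to_easy_ratio: float = 2.0
    min_hard_keypoints: int = 2
    max_hard_keypoints: Optional[int] = None  # assumption: None = no cap
    loss_scale: float = 5.0


class OHKMLossSketch:
    """Hypothetical stand-in for OHKMLoss that only stores its settings."""

    def __init__(self, hard_to_easy_ratio=2.0, min_hard_keypoints=2,
                 max_hard_keypoints=-1, loss_scale=5.0):
        self.hard_to_easy_ratio = hard_to_easy_ratio
        self.min_hard_keypoints = min_hard_keypoints
        self.max_hard_keypoints = max_hard_keypoints
        self.loss_scale = loss_scale

    @classmethod
    def from_config(cls, config: MiningConfig) -> "OHKMLossSketch":
        # Map an unset maximum (None) to the constructor's "no cap" value of -1.
        max_hard = -1 if config.max_hard_keypoints is None else config.max_hard_keypoints
        return cls(
            hard_to_easy_ratio=config.hard_to_easy_ratio,
            min_hard_keypoints=config.min_hard_keypoints,
            max_hard_keypoints=max_hard,
            loss_scale=config.loss_scale,
        )
```

Using a classmethod keeps the config-to-constructor mapping in one place, so subclasses constructed through cls(...) inherit it unchanged.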
- class sleap.nn.losses.PartLoss(*args, **kwargs)
Compute channelwise loss.
Useful for monitoring the MSE for specific body parts (channels).
The index of the channel to compute the MSE for.
The name of the loss tensor.
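The channelwise MSE that PartLoss monitors can be sketched as follows, assuming inputs flattened to [n_pixels][n_channels] (batch, height and width merged into one axis). channel_mse is a hypothetical helper, not part of sleap.nn.losses.

```python
from typing import Sequence


def channel_mse(
    y_gt: Sequence[Sequence[float]],
    y_pr: Sequence[Sequence[float]],
    channel_ind: int,
) -> float:
    """MSE of one channel, with inputs flattened to [n_pixels][n_channels]."""
    # Squared error of the selected channel at every pixel.
    diffs = [(g[channel_ind] - p[channel_ind]) ** 2 for g, p in zip(y_gt, y_pr)]
    return sum(diffs) / len(diffs)
```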
Computes and returns the metric value tensor.
Result computation is an idempotent operation that simply calculates the metric value using the state variables.
- update_state(y_gt, y_pr, sample_weight=None)
Accumulates statistics for the metric.
Note: This function is executed as a graph function in graph mode. This means:
a) Operations on the same resource are executed in textual order. This should make it easier to do things like add the updated value of a variable to another, for example.
b) You don’t need to worry about collecting the update ops to execute. All update ops added to the graph by this function will be executed.
As a result, code should generally work the same way with graph or eager execution.
y_gt, y_pr, sample_weight – A mini-batch of inputs to the Metric.
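The update_state/result split can be illustrated with a hypothetical pure-Python accumulator: update_state adds each mini-batch to running totals, and result only reads them, so calling it repeatedly returns the same value. This is a sketch of the stateful-metric pattern under the same flattened-input assumption ([n_pixels][n_channels]), not SLEAP's PartLoss.

```python
from typing import Sequence


class ChannelMSEAccumulator:
    """Hypothetical running-MSE metric for one channel (not SLEAP's PartLoss)."""

    def __init__(self, channel_ind: int):
        self.channel_ind = channel_ind
        self.total_sq_error = 0.0  # state: running sum of squared errors
        self.count = 0  # state: number of elements seen so far

    def update_state(self, y_gt: Sequence[Sequence[float]],
                     y_pr: Sequence[Sequence[float]]) -> None:
        # Accumulate statistics for one mini-batch.
        for g, p in zip(y_gt, y_pr):
            self.total_sq_error += (g[self.channel_ind] - p[self.channel_ind]) ** 2
            self.count += 1

    def result(self) -> float:
        # Only reads the state, so repeated calls return the same value.
        return self.total_sq_error / self.count if self.count else 0.0

    def reset_state(self) -> None:
        self.total_sq_error = 0.0
        self.count = 0
```

Because the totals are pooled, the result after several batches is the MSE over all elements seen, not the mean of the per-batch MSEs.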
- sleap.nn.losses.compute_ohkm_loss(y_gt: tensorflow.python.framework.ops.Tensor, y_pr: tensorflow.python.framework.ops.Tensor, hard_to_easy_ratio: float = 2.0, min_hard_keypoints: int = 2, max_hard_keypoints: int = -1, loss_scale: float = 5.0) → tensorflow.python.framework.ops.Tensor
Compute the online hard keypoint mining loss.
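The computation can be sketched end to end in pure Python. This is an illustration under stated assumptions, not SLEAP's TensorFlow code: inputs are flattened to [n_pixels][n_channels], per-channel losses are sums of squared errors, only the selected hard channels contribute, a non-positive max_hard_keypoints means “no cap”, and the result is normalized by the total number of elements. compute_ohkm_loss_sketch is a hypothetical name.

```python
from typing import Sequence


def compute_ohkm_loss_sketch(
    y_gt: Sequence[Sequence[float]],
    y_pr: Sequence[Sequence[float]],
    hard_to_easy_ratio: float = 2.0,
    min_hard_keypoints: int = 2,
    max_hard_keypoints: int = -1,
    loss_scale: float = 5.0,
) -> float:
    """OHKM sketch on inputs flattened to [n_pixels][n_channels]."""
    n_channels = len(y_gt[0])
    # Per-channel sum of squared errors over all pixels (and samples).
    channel_loss = [0.0] * n_channels
    for row_gt, row_pr in zip(y_gt, y_pr):
        for c in range(n_channels):
            channel_loss[c] += (row_gt[c] - row_pr[c]) ** 2

    # Count channels whose loss is at least `hard_to_easy_ratio` x the lowest loss.
    best = min(channel_loss) or 1e-6  # avoid dividing by zero
    n_hard = sum(1 for loss in channel_loss if loss / best >= hard_to_easy_ratio)
    n_hard = max(n_hard, min_hard_keypoints)
    if max_hard_keypoints > 0:  # assumption: non-positive means "no cap"
        n_hard = min(n_hard, max_hard_keypoints)
    n_hard = min(n_hard, n_channels)

    # Sum the n_hard largest channel losses, scale, and normalize by element count.
    hard = sorted(channel_loss, reverse=True)[:n_hard]
    n_elements = len(y_gt) * n_channels
    return loss_scale * sum(hard) / n_elements
```

With two pixels and errors of 1.0, 2.0 and 0.5 in three channels, the channel losses are 2.0, 8.0 and 0.5; the first two are hard, giving 5.0 × 10.0 / 6 ≈ 8.33.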