sleap.nn.callbacks

Training-related tf.keras callbacks.

class sleap.nn.callbacks.MatplotlibSaver(save_folder: str, plot_fn: Callable[[], Figure], prefix: str | None = None)

Callback for saving images rendered with matplotlib during training.

This is useful for saving visualizations of training progress to disk. A new figure is rendered and saved at the end of each epoch.

plot_fn

Function with no arguments that returns a matplotlib figure handle. See sleap.nn.training.Trainer.visualize_predictions for an example.

save_folder

Path to a directory to save images to. This folder will be created if it does not exist.

prefix

String that will be prepended to the filenames. This is useful for indicating which dataset the visualization was sampled from, for example.

Notes

This will save images with the naming pattern:

“{save_folder}/{prefix}.{epoch}.png”

or:

“{save_folder}/{epoch}.png”

if a prefix is not specified.
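A minimal sketch of that naming rule in plain Python. `epoch_image_path` is a hypothetical helper, not part of `sleap.nn.callbacks`, and it assumes the epoch number is inserted as-is; the actual callback may format it differently (for example with zero padding).

```python
import os
from typing import Optional


def epoch_image_path(save_folder: str, epoch: int, prefix: Optional[str] = None) -> str:
    """Build the image path for one epoch (hypothetical helper, not the SLEAP API)."""
    # With a prefix: "{save_folder}/{prefix}.{epoch}.png"; without: "{save_folder}/{epoch}.png".
    stem = f"{prefix}.{epoch}" if prefix is not None else f"{epoch}"
    return os.path.join(save_folder, stem + ".png")


print(epoch_image_path("viz", 3, prefix="validation"))  # viz/validation.3.png (POSIX separator)
print(epoch_image_path("viz", 3))  # viz/3.png (POSIX separator)
```

Using the prefix to name the data split (for example "train" vs. "validation") keeps images from several savers apart when they share one folder.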

on_epoch_end(epoch, logs=None)

Save figure at the end of each epoch.

class sleap.nn.callbacks.ModelCheckpointOnEvent(filepath: str, event: str = 'train_end')

Callback for model checkpointing on a fixed event.

filepath

Path to save model to.

event

Event to trigger model saving (“train_start” or “train_end”).
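The event logic can be sketched without TensorFlow. `CheckpointOnEventSketch` is a hypothetical stand-in: the real callback subclasses `tf.keras.callbacks.Callback` and saves its attached model to `filepath`, whereas here an injected `save_fn` records the call so the behavior can be checked.

```python
from typing import Callable, List


class CheckpointOnEventSketch:
    """Hypothetical stand-in for ModelCheckpointOnEvent's dispatch logic (stdlib only).

    save_fn(filepath) replaces the model-saving call of the real callback.
    """

    def __init__(self, filepath: str, save_fn: Callable[[str], None], event: str = "train_end"):
        self.filepath = filepath
        self.save_fn = save_fn
        self.event = event

    def on_train_begin(self, logs=None):
        # Runs once, before the first epoch.
        if self.event == "train_start":
            self.save_fn(self.filepath)

    def on_train_end(self, logs=None):
        # Runs once, after the last epoch.
        if self.event == "train_end":
            self.save_fn(self.filepath)


saved: List[str] = []
cb = CheckpointOnEventSketch("models/final.h5", saved.append)
cb.on_train_begin()
cb.on_train_end()
print(saved)  # ['models/final.h5']
```

In actual training the callback is passed to `model.fit(..., callbacks=[...])` like any other Keras callback, and Keras invokes `on_train_begin` and `on_train_end` once per `fit` call, so the model is saved exactly once at the chosen event.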

on_epoch_end(epoch, logs=None)

Called at the end of each epoch.

on_train_begin(logs=None)

Called at the start of training.

on_train_end(logs=None)

Called at the end of training.

class sleap.nn.callbacks.ProgressReporterZMQ(address='tcp://127.0.0.1:9001', what='not_set')

Callback for reporting training progress over ZeroMQ to the given address.

on_batch_begin(batch, logs=None)

A backwards compatibility alias for on_train_batch_begin.

on_batch_end(batch, logs=None)

A backwards compatibility alias for on_train_batch_end.

on_epoch_begin(epoch, logs=None)

Called at the start of an epoch. Subclasses should override for any actions to run. This function should only be called during train mode.

Arguments

epoch: integer, index of epoch.

logs: dict, currently no data is passed to this argument for this method but that may change in the future.

on_epoch_end(epoch, logs=None)

Called at the end of an epoch. Subclasses should override for any actions to run. This function should only be called during train mode.

Arguments

epoch: integer, index of epoch.

logs: dict, metric results for this training epoch, and for the validation epoch if validation is performed. Validation result keys are prefixed with val_.
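Because validation metrics share names with training metrics apart from the `val_` prefix, a callback can separate them by key. `split_epoch_logs` is a hypothetical helper and the metric values below are illustrative only.

```python
from typing import Dict, Tuple


def split_epoch_logs(logs: Dict[str, float]) -> Tuple[Dict[str, float], Dict[str, float]]:
    """Split an on_epoch_end logs dict into training and validation metrics."""
    train, val = {}, {}
    for key, value in logs.items():
        if key.startswith("val_"):
            # Strip the "val_" prefix so both dicts use the same metric names.
            val[key[len("val_"):]] = value
        else:
            train[key] = value
    return train, val


train, val = split_epoch_logs({"loss": 0.12, "val_loss": 0.15})
print(train, val)  # {'loss': 0.12} {'loss': 0.15}
```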

on_train_begin(logs=None)

Called at the beginning of training. Subclasses should override for any actions to run.

Arguments

logs: dict, currently no data is passed to this argument for this method but that may change in the future.

on_train_end(logs=None)

Called at the end of training. Subclasses should override for any actions to run.

Arguments

logs: dict, currently no data is passed to this argument for this method but that may change in the future.
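This reference does not document the message format `ProgressReporterZMQ` sends. The sketch below assumes each hook serializes a JSON object containing the `what` label, the hook's event name, and its `logs`; the schema and `encode_progress_event` are hypothetical. In a real reporter the resulting string would be published on a ZeroMQ socket, for example with pyzmq's `Socket.send_string`.

```python
import json
from typing import Any, Dict, Optional


def encode_progress_event(what: str, event: str, logs: Optional[Dict[str, Any]] = None, **extra: Any) -> str:
    """Serialize one progress event to a JSON string (hypothetical schema)."""
    message: Dict[str, Any] = {"what": what, "event": event, "logs": logs}
    # Hook-specific fields, such as the epoch or batch index.
    message.update(extra)
    return json.dumps(message)


msg = encode_progress_event("not_set", "epoch_end", logs={"loss": 0.12}, epoch=4)
print(msg)  # {"what": "not_set", "event": "epoch_end", "logs": {"loss": 0.12}, "epoch": 4}
```

Sending plain JSON text keeps the receiving side (for example a GUI plotting loss curves) independent of TensorFlow.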

class sleap.nn.callbacks.TensorBoardMatplotlibWriter(log_dir: str, plot_fn: Callable[[], Figure], tag: str = 'viz')

Callback for writing image summaries with visualizations during training.

log_dir

Path to log directory.

plot_fn

Function with no arguments that returns a matplotlib figure handle.

tag

Text to append to the summary label in TensorBoard.
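Since `plot_fn` takes no arguments (here and in `MatplotlibSaver`), any data it needs must be bound before the callback is constructed. A minimal sketch with `functools.partial`; `plot_sample` is hypothetical and returns a plain dict so the example runs without matplotlib, whereas a real version would build and return a matplotlib `Figure`.

```python
import functools
from typing import Callable, Sequence


def plot_sample(points: Sequence[float], title: str) -> dict:
    """Hypothetical plotting function that needs arguments.

    A real version would draw `points` and return a matplotlib Figure;
    this one returns a dict so the sketch runs without matplotlib.
    """
    return {"title": title, "n_points": len(points)}


# Bind the data now so the callback can later call plot_fn() with no arguments.
plot_fn: Callable[[], dict] = functools.partial(plot_sample, [0.1, 0.4, 0.9], title="val sample")
print(plot_fn())  # {'title': 'val sample', 'n_points': 3}
```

The bound function could then be passed as `TensorBoardMatplotlibWriter(log_dir="logs", plot_fn=plot_fn)`, where the `"logs"` path is illustrative. A closure works the same way and can also capture objects whose state changes during training, such as the model being trained.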

on_epoch_end(epoch, logs=None)

Called at the end of each epoch.

class sleap.nn.callbacks.TrainingControllerZMQ(address='tcp://127.0.0.1:9000', topic='', poll_timeout=10)

Callback for receiving training control messages over ZeroMQ.

on_batch_end(batch, logs=None)

Called at the end of a training batch.

set_lr(lr)

Adjust the model learning rate.

This is based on the implementation used in the built-in Keras learning rate scheduling callbacks.
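The reference documents only that `on_batch_end` runs after each training batch and that `set_lr` adjusts the learning rate; it does not specify the message format. The sketch below assumes JSON messages carrying a `"command"` of `"stop"` or `"set_lr"`, applied to stand-in model and optimizer objects. The schema, `handle_control_message`, `StubModel`, and `StubOptimizer` are all hypothetical. On a real Keras model, setting `model.stop_training = True` asks `fit` to stop, and in TF 2.x Keras's own `LearningRateScheduler` updates the rate with `tf.keras.backend.set_value(model.optimizer.lr, value)`.

```python
import json
from dataclasses import dataclass, field


@dataclass
class StubOptimizer:
    # Hypothetical stand-in for a Keras optimizer's learning rate.
    lr: float = 1e-3


@dataclass
class StubModel:
    # Mirrors the Keras attribute that fit() checks to end training early.
    stop_training: bool = False
    optimizer: StubOptimizer = field(default_factory=StubOptimizer)


def handle_control_message(model: StubModel, raw: str) -> None:
    """Apply one control message (hypothetical JSON schema) to the model."""
    msg = json.loads(raw)
    if msg.get("command") == "stop":
        model.stop_training = True
    elif msg.get("command") == "set_lr":
        # Coerce to a Python float so the stored rate is a plain scalar.
        model.optimizer.lr = float(msg["lr"])
    # Unknown commands are ignored.


model = StubModel()
handle_control_message(model, json.dumps({"command": "set_lr", "lr": 1e-4}))
handle_control_message(model, json.dumps({"command": "stop"}))
print(model.optimizer.lr, model.stop_training)  # 0.0001 True
```

Checking for messages once per batch lets a controlling process react quickly, at the cost of a short poll on every batch.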