sleap.nn.callbacks
Training-related tf.keras callbacks.
- class sleap.nn.callbacks.MatplotlibSaver(save_folder: str, plot_fn: Callable[[], Figure], prefix: str | None = None)
Callback for saving images rendered with matplotlib during training.
This is useful for saving visualizations of the training to disk. The callback runs at the end of each epoch.
- plot_fn
Function with no arguments that returns a matplotlib figure handle. See sleap.nn.training.Trainer.visualize_predictions for an example.
- save_folder
Path to a directory to save images to. This folder will be created if it does not exist.
- prefix
String that will be prepended to the filenames. This is useful for indicating which dataset the visualization was sampled from, for example.
Notes
- This will save images with the naming pattern “{save_folder}/{prefix}.{epoch}.png”, or “{save_folder}/{epoch}.png” if a prefix is not specified.
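The naming pattern in the notes can be sketched as a small helper. This is only an illustration, not SLEAP's implementation: `expected_image_path` is a hypothetical name, and because the notes do not say how the epoch number is formatted (for example, whether it is zero-padded), the sketch inserts the plain integer.

```python
from typing import Optional


def expected_image_path(save_folder: str, epoch: int, prefix: Optional[str] = None) -> str:
    """Hypothetical helper that mirrors the documented naming pattern.

    Assumption: the epoch is written as a plain integer; the real callback
    may format it differently.
    """
    if prefix is None:
        # No prefix: "{save_folder}/{epoch}.png"
        return f"{save_folder}/{epoch}.png"
    # With prefix: "{save_folder}/{prefix}.{epoch}.png"
    return f"{save_folder}/{prefix}.{epoch}.png"


print(expected_image_path("viz", 3, prefix="train"))  # viz/train.3.png
print(expected_image_path("viz", 3))                  # viz/3.png
```

A helper like this can be handy when a script needs to locate the image for a given epoch after training.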
- class sleap.nn.callbacks.ModelCheckpointOnEvent(filepath: str, event: str = 'train_end')
Callback for model checkpointing on a fixed event.
- filepath
Path to save model to.
- event
Event to trigger model saving (“train_start” or “train_end”).
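The idea of saving on one fixed event can be illustrated with a stand-in class that does not depend on tf.keras. Everything here is an assumption made for illustration: `CheckpointOnEventSketch` and `save_fn` are hypothetical names, and mapping “train_start” to the Keras `on_train_begin` hook is a guess about how the real callback is wired, not something this page states.

```python
from typing import Callable, List


class CheckpointOnEventSketch:
    """Stand-in (not the SLEAP class): saves once when the chosen event fires."""

    def __init__(self, filepath: str, save_fn: Callable[[str], None], event: str = "train_end"):
        self.filepath = filepath
        self.save_fn = save_fn  # hypothetical: a real callback would save its model
        self.event = event

    def on_train_begin(self, logs=None):
        # Keras names this hook "on_train_begin"; assumed to match "train_start".
        if self.event == "train_start":
            self.save_fn(self.filepath)

    def on_train_end(self, logs=None):
        if self.event == "train_end":
            self.save_fn(self.filepath)


# Record "saves" in a list instead of writing a model to disk.
saved: List[str] = []
callback = CheckpointOnEventSketch("final_model.h5", saved.append, event="train_end")
callback.on_train_begin()  # ignored: event is "train_end"
callback.on_train_end()    # triggers the save
print(saved)  # ['final_model.h5']
```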
- class sleap.nn.callbacks.ProgressReporterZMQ(address='tcp://127.0.0.1:9001', what='not_set')
- on_epoch_begin(epoch, logs=None)
Called at the start of an epoch. Subclasses should override for any actions to run. This function should only be called during train mode.
Arguments:
- epoch: integer, index of epoch.
- logs: dict, currently no data is passed to this argument for this method, but that may change in the future.
- on_epoch_end(epoch, logs=None)
Called at the end of an epoch. Subclasses should override for any actions to run. This function should only be called during train mode.
Arguments:
- epoch: integer, index of epoch.
- logs: dict, metric results for this training epoch, and for the validation epoch if validation is performed. Validation result keys are prefixed with val_.
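Because validation keys in `logs` carry the `val_` prefix, code that consumes the dictionary passed to `on_epoch_end` often wants to separate the two groups. The following is a minimal sketch; `split_logs` is a hypothetical helper, not part of SLEAP or Keras.

```python
from typing import Dict, Optional, Tuple


def split_logs(logs: Optional[Dict[str, float]]) -> Tuple[Dict[str, float], Dict[str, float]]:
    """Split an epoch-end logs dict into (training, validation) metrics.

    Validation keys have their "val_" prefix removed so that both dicts
    use the same metric names.
    """
    logs = logs or {}  # logs may be None, matching the default argument
    prefix = "val_"
    train = {k: v for k, v in logs.items() if not k.startswith(prefix)}
    val = {k[len(prefix):]: v for k, v in logs.items() if k.startswith(prefix)}
    return train, val


train_metrics, val_metrics = split_logs({"loss": 0.5, "val_loss": 0.7})
print(train_metrics)  # {'loss': 0.5}
print(val_metrics)    # {'loss': 0.7}
```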
- class sleap.nn.callbacks.TensorBoardMatplotlibWriter(log_dir: str, plot_fn: Callable[[], Figure], tag: str = 'viz')
Callback for writing image summaries with visualizations during training.
- log_dir
Path to log directory.
- plot_fn
Function with no arguments that returns a matplotlib figure handle.
- tag
Text to append to the summary label in TensorBoard.
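Both `MatplotlibSaver` and `TensorBoardMatplotlibWriter` take a `plot_fn` that accepts no arguments, whereas plotting code usually needs data. One common way to bridge the gap is to bind the arguments ahead of time with `functools.partial`. The sketch below is illustrative only: `make_plot_fn` and `draw_summary` are hypothetical names, and `draw_summary` returns a string as a stand-in where real code would build and return a matplotlib figure.

```python
from functools import partial
from typing import Any, Callable


def make_plot_fn(draw: Callable[..., Any], *args: Any, **kwargs: Any) -> Callable[[], Any]:
    """Bind arguments up front so the result can be called with none."""
    return partial(draw, *args, **kwargs)


def draw_summary(values, title="untitled"):
    # Stand-in for real plotting code, which would build and return a
    # matplotlib Figure instead of a string.
    return f"{title}: {len(values)} points"


plot_fn = make_plot_fn(draw_summary, [0.1, 0.4, 0.9], title="validation")
print(plot_fn())  # validation: 3 points
```

Note that `partial` keeps a reference to the bound list rather than a copy, so changes made to that list between epochs are visible on the next call.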