Open In Colab

Interactive and resumable training#

Most of the time, you will be training models through the GUI or using the sleap-train CLI.

If you’d like to customize the training process, however, you can use SLEAP’s low-level training functionality interactively. This allows you to define scripts that train models according to your own workflow, for example, to resume training on an already trained model. Another possible application would be to train a model using transfer learning, where a pretrained model can be used to initialize the weights of the new model.

In this notebook we will explore how to set up a training job and train a model for multiple rounds without the GUI or CLI.

1. Setup SLEAP#

Run this cell first to install SLEAP. If you get a dependency error in subsequent cells, just click Runtime → Restart runtime to reload the packages.

Don’t forget to set Runtime → Change runtime type → GPU as the accelerator.

# This should take care of all the dependencies on colab:
!pip uninstall -qqq -y opencv-python opencv-contrib-python
!pip install -qqq "sleap[pypi]>=1.3.3"


# But to do it locally, we'd recommend the conda package (available on Windows + Linux):
# conda create -n sleap -c sleap -c conda-forge -c nvidia sleap
ERROR: Cannot uninstall opencv-python 4.6.0, RECORD file not found. Hint: The package was installed by conda.
ERROR: Cannot uninstall shiboken2 5.15.6, RECORD file not found. You might be able to recover from this via: 'pip install --force-reinstall --no-deps shiboken2==5.15.6'.

Import SLEAP to make sure it installed correctly and print out some information about the system:

import sleap
sleap.versions()
sleap.system_summary()
SLEAP: 1.3.2
TensorFlow: 2.7.0
Numpy: 1.21.5
Python: 3.7.12
OS: Linux-5.15.0-78-generic-x86_64-with-debian-bookworm-sid
GPUs: 1/1 available
  Device: /physical_device:GPU:0
         Available: True
        Initalized: False
     Memory growth: None

2. Setup training data#

Here we will download an existing training dataset package. This is an .slp file that contains both the labeled poses and the image data for the labeled frames.

If you’re running on Google Colab, you’ll want to replace this step by mounting the Google Drive folder containing your own data (see the Drive-mounting sketch below); if you’re running locally, simply change the path to your labels in TRAINING_SLP_FILE below.

# !curl -L --output labels.pkg.slp https://www.dropbox.com/s/b990gxjt3d3j3jh/210205.sleap_wt_gold.13pt.pkg.slp?dl=1
!curl -L --output labels.pkg.slp https://storage.googleapis.com/sleap-data/datasets/wt_gold.13pt/tracking_split2/train.pkg.slp
!ls -lah
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  619M  100  619M    0     0  32.9M      0  0:00:18  0:00:18 --:--:-- 34.4M
total 622M
drwxrwxr-x  3 talmolab talmolab 4.0K Sep  1 14:23 .
drwxrwxr-x 10 talmolab talmolab 4.0K Aug 31 15:43 ..
drwxrwxr-x  2 talmolab talmolab 4.0K Jun 20 10:00 analysis_example
-rw-rw-r--  1 talmolab talmolab 713K Jun 20 10:00 Analysis_examples.ipynb
-rw-rw-r--  1 talmolab talmolab 481K Sep  1 14:02 Data_structures.ipynb
-rw-rw-r--  1 talmolab talmolab 4.1K Jun 20 10:00 index.rst
-rw-rw-r--  1 talmolab talmolab 179K Sep  1 13:58 Interactive_and_realtime_inference.ipynb
-rw-rw-r--  1 talmolab talmolab 120K Sep  1 14:21 Interactive_and_resumable_training.ipynb
-rw-rw-r--  1 talmolab talmolab 620M Sep  1 14:24 labels.pkg.slp
-rw-rw-r--  1 talmolab talmolab 157K Sep  1 14:15 Model_evaluation.ipynb
-rw-rw-r--  1 talmolab talmolab 132K Sep  1 14:18 Post_inference_tracking.ipynb
-rw-rw-r--  1 talmolab talmolab  94K Sep  1 13:44 Training_and_inference_on_an_example_dataset.ipynb
-rw-rw-r--  1 talmolab talmolab  12K Aug 31 11:39 Training_and_inference_using_Google_Drive.ipynb
TRAINING_SLP_FILE = "labels.pkg.slp"
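If your data lives in your own Google Drive instead, a minimal sketch of mounting the drive in Colab and pointing TRAINING_SLP_FILE at your own labels package might look like this (the folder layout below is a placeholder):

from google.colab import drive

# Mount your Google Drive into the Colab filesystem.
drive.mount("/content/drive")

# Point this at your own labels package (placeholder path -- adjust to your folder layout):
TRAINING_SLP_FILE = "/content/drive/MyDrive/sleap_project/labels.pkg.slp"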

3. Setup training job#

A SLEAP TrainingJobConfig is a structure that contains all of the hyperparameters needed to train a SLEAP model. This is typically saved out to initial_config.json and training_config.json in the model folder so that training runs can be reproduced if needed, as well as to store metadata necessary for inference.

Normally, these are generated interactively through the GUI, or by manually editing an existing JSON file in a text editor. Here, we will define a configuration entirely in Python.

from sleap.nn.config import *

# Initialize the default training job configuration.
cfg = TrainingJobConfig()

# Update path to training data we just downloaded.
cfg.data.labels.training_labels = TRAINING_SLP_FILE
cfg.data.labels.validation_fraction = 0.1

# Preprocessing and training parameters.
cfg.data.instance_cropping.center_on_part = "thorax"
cfg.optimization.augmentation_config.rotate = True
cfg.optimization.epochs = 10  # This is the maximum number of training rounds.

# These configure the neural network backbone and the model type:
cfg.model.backbone.unet = UNetConfig(
    filters=16,
    output_stride=4
)
cfg.model.heads.centered_instance = CenteredInstanceConfmapsHeadConfig(
    anchor_part="thorax",
    sigma=1.5,
    output_stride=4
)

# Set up how we want to save the trained model.
cfg.outputs.run_name = "baseline_model.topdown"

Existing configs can also be loaded from a .json file with:

cfg = sleap.load_config("training_config.json")
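Conversely, a configuration built in Python can be written back out to JSON so the same job can be reproduced later, for example with the sleap-train CLI. A minimal sketch, assuming TrainingJobConfig exposes a save_json method:

# Write the in-memory configuration to disk (assumes cfg.save_json is available):
cfg.save_json("my_custom_config.json")

# The saved file can then be reused outside of this notebook, e.g.:
# !sleap-train my_custom_config.json labels.pkg.slp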

4. Training#

Next we will create a SLEAP Trainer from the configuration we just specified. This handles all of the nitty-gritty mechanics necessary to set up training in the backend.

trainer = sleap.nn.training.Trainer.from_config(cfg)
INFO:sleap.nn.training:Loading training labels from: labels.pkg.slp
INFO:sleap.nn.training:Creating training and validation splits from validation fraction: 0.1
INFO:sleap.nn.training:  Splits: Training = 1440 / Validation = 160.

Great, now we’re ready to do the first round of training. This is when the model will actually start to improve over time:

trainer.train()
INFO:sleap.nn.training:Setting up for training...
INFO:sleap.nn.training:Setting up pipeline builders...
INFO:sleap.nn.training:Setting up model...
INFO:sleap.nn.training:Building test pipeline...
2023-09-01 14:24:11.775633: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-09-01 14:24:11.776555: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:939] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2023-09-01 14:24:11.777493: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:939] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2023-09-01 14:24:11.778196: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:939] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2023-09-01 14:24:12.055738: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:939] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2023-09-01 14:24:12.056597: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:939] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2023-09-01 14:24:12.057389: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:939] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2023-09-01 14:24:12.058046: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1525] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 21261 MB memory:  -> device: 0, name: NVIDIA RTX A5000, pci bus id: 0000:01:00.0, compute capability: 8.6
INFO:sleap.nn.training:Loaded test example. [1.799s]
INFO:sleap.nn.training:  Input shape: (160, 160, 1)
INFO:sleap.nn.training:Created Keras model.
INFO:sleap.nn.training:  Backbone: UNet(stacks=1, filters=16, filters_rate=2, kernel_size=3, stem_kernel_size=7, convs_per_block=2, stem_blocks=0, down_blocks=4, middle_block=True, up_blocks=2, up_interpolate=False, block_contraction=False)
INFO:sleap.nn.training:  Max stride: 16
INFO:sleap.nn.training:  Parameters: 2,101,501
INFO:sleap.nn.training:  Heads: 
INFO:sleap.nn.training:    [0] = CenteredInstanceConfmapsHead(part_names=['head', 'thorax', 'abdomen', 'wingL', 'wingR', 'forelegL4', 'forelegR4', 'midlegL4', 'midlegR4', 'hindlegL4', 'hindlegR4', 'eyeL', 'eyeR'], anchor_part='thorax', sigma=1.5, output_stride=4, loss_weight=1.0)
INFO:sleap.nn.training:  Outputs: 
INFO:sleap.nn.training:    [0] = KerasTensor(type_spec=TensorSpec(shape=(None, 40, 40, 13), dtype=tf.float32, name=None), name='CenteredInstanceConfmapsHead/BiasAdd:0', description="created by layer 'CenteredInstanceConfmapsHead'")
INFO:sleap.nn.training:Training from scratch
INFO:sleap.nn.training:Setting up data pipelines...
INFO:sleap.nn.training:Training set: n = 1440
INFO:sleap.nn.training:Validation set: n = 160
INFO:sleap.nn.training:Setting up optimization...
INFO:sleap.nn.training:  Learning rate schedule: LearningRateScheduleConfig(reduce_on_plateau=True, reduction_factor=0.5, plateau_min_delta=1e-06, plateau_patience=5, plateau_cooldown=3, min_learning_rate=1e-08)
INFO:sleap.nn.training:  Early stopping: EarlyStoppingConfig(stop_training_on_plateau=True, plateau_min_delta=1e-06, plateau_patience=10)
INFO:sleap.nn.training:Setting up outputs...
INFO:sleap.nn.training:Created run path: models/baseline_model.topdown
INFO:sleap.nn.training:Setting up visualization...
INFO:sleap.nn.training:Finished trainer set up. [3.3s]
INFO:sleap.nn.training:Creating tf.data.Datasets for training data generation...
INFO:sleap.nn.training:Finished creating training datasets. [16.2s]
INFO:sleap.nn.training:Starting training loop...
Epoch 1/10
2023-09-01 14:24:32.586040: I tensorflow/stream_executor/cuda/cuda_dnn.cc:366] Loaded cuDNN version 8201
2023-09-01 14:24:42.104556: I tensorflow/stream_executor/cuda/cuda_blas.cc:1774] TensorFloat-32 will be used for the matrix multiplication. This will only be logged once.
360/360 - 12s - loss: 0.0037 - head: 0.0030 - thorax: 0.0030 - abdomen: 0.0036 - wingL: 0.0040 - wingR: 0.0040 - forelegL4: 0.0037 - forelegR4: 0.0038 - midlegL4: 0.0041 - midlegR4: 0.0041 - hindlegL4: 0.0039 - hindlegR4: 0.0040 - eyeL: 0.0035 - eyeR: 0.0035 - val_loss: 0.0033 - val_head: 0.0020 - val_thorax: 0.0029 - val_abdomen: 0.0030 - val_wingL: 0.0033 - val_wingR: 0.0034 - val_forelegL4: 0.0037 - val_forelegR4: 0.0036 - val_midlegL4: 0.0039 - val_midlegR4: 0.0039 - val_hindlegL4: 0.0037 - val_hindlegR4: 0.0038 - val_eyeL: 0.0029 - val_eyeR: 0.0027 - lr: 1.0000e-04 - 12s/epoch - 32ms/step
Epoch 2/10
360/360 - 7s - loss: 0.0028 - head: 0.0013 - thorax: 0.0018 - abdomen: 0.0026 - wingL: 0.0027 - wingR: 0.0028 - forelegL4: 0.0032 - forelegR4: 0.0033 - midlegL4: 0.0038 - midlegR4: 0.0038 - hindlegL4: 0.0037 - hindlegR4: 0.0038 - eyeL: 0.0015 - eyeR: 0.0015 - val_loss: 0.0025 - val_head: 9.7323e-04 - val_thorax: 0.0011 - val_abdomen: 0.0026 - val_wingL: 0.0024 - val_wingR: 0.0026 - val_forelegL4: 0.0030 - val_forelegR4: 0.0030 - val_midlegL4: 0.0036 - val_midlegR4: 0.0037 - val_hindlegL4: 0.0038 - val_hindlegR4: 0.0037 - val_eyeL: 0.0012 - val_eyeR: 0.0012 - lr: 1.0000e-04 - 7s/epoch - 21ms/step
Epoch 3/10
360/360 - 7s - loss: 0.0022 - head: 8.0630e-04 - thorax: 6.7199e-04 - abdomen: 0.0022 - wingL: 0.0020 - wingR: 0.0021 - forelegL4: 0.0027 - forelegR4: 0.0027 - midlegL4: 0.0033 - midlegR4: 0.0035 - hindlegL4: 0.0034 - hindlegR4: 0.0035 - eyeL: 8.7345e-04 - eyeR: 8.4145e-04 - val_loss: 0.0020 - val_head: 8.6439e-04 - val_thorax: 5.9914e-04 - val_abdomen: 0.0020 - val_wingL: 0.0019 - val_wingR: 0.0020 - val_forelegL4: 0.0025 - val_forelegR4: 0.0024 - val_midlegL4: 0.0030 - val_midlegR4: 0.0031 - val_hindlegL4: 0.0030 - val_hindlegR4: 0.0031 - val_eyeL: 8.9466e-04 - val_eyeR: 9.5174e-04 - lr: 1.0000e-04 - 7s/epoch - 20ms/step
Epoch 4/10
360/360 - 7s - loss: 0.0018 - head: 6.7854e-04 - thorax: 4.6945e-04 - abdomen: 0.0020 - wingL: 0.0017 - wingR: 0.0018 - forelegL4: 0.0023 - forelegR4: 0.0023 - midlegL4: 0.0026 - midlegR4: 0.0027 - hindlegL4: 0.0028 - hindlegR4: 0.0029 - eyeL: 7.4546e-04 - eyeR: 6.9585e-04 - val_loss: 0.0018 - val_head: 7.7640e-04 - val_thorax: 5.3180e-04 - val_abdomen: 0.0020 - val_wingL: 0.0018 - val_wingR: 0.0018 - val_forelegL4: 0.0022 - val_forelegR4: 0.0022 - val_midlegL4: 0.0024 - val_midlegR4: 0.0025 - val_hindlegL4: 0.0026 - val_hindlegR4: 0.0026 - val_eyeL: 9.2650e-04 - val_eyeR: 9.0064e-04 - lr: 1.0000e-04 - 7s/epoch - 20ms/step
Epoch 5/10
360/360 - 7s - loss: 0.0015 - head: 5.8714e-04 - thorax: 4.0531e-04 - abdomen: 0.0017 - wingL: 0.0015 - wingR: 0.0015 - forelegL4: 0.0020 - forelegR4: 0.0019 - midlegL4: 0.0020 - midlegR4: 0.0021 - hindlegL4: 0.0023 - hindlegR4: 0.0024 - eyeL: 6.7827e-04 - eyeR: 6.2254e-04 - val_loss: 0.0015 - val_head: 6.5523e-04 - val_thorax: 4.4019e-04 - val_abdomen: 0.0016 - val_wingL: 0.0016 - val_wingR: 0.0015 - val_forelegL4: 0.0019 - val_forelegR4: 0.0020 - val_midlegL4: 0.0021 - val_midlegR4: 0.0020 - val_hindlegL4: 0.0021 - val_hindlegR4: 0.0021 - val_eyeL: 7.9871e-04 - val_eyeR: 7.8608e-04 - lr: 1.0000e-04 - 7s/epoch - 20ms/step
Epoch 6/10
360/360 - 7s - loss: 0.0013 - head: 5.3215e-04 - thorax: 3.5232e-04 - abdomen: 0.0016 - wingL: 0.0014 - wingR: 0.0014 - forelegL4: 0.0017 - forelegR4: 0.0018 - midlegL4: 0.0017 - midlegR4: 0.0018 - hindlegL4: 0.0020 - hindlegR4: 0.0021 - eyeL: 5.9826e-04 - eyeR: 5.6906e-04 - val_loss: 0.0013 - val_head: 5.3776e-04 - val_thorax: 3.7946e-04 - val_abdomen: 0.0014 - val_wingL: 0.0014 - val_wingR: 0.0013 - val_forelegL4: 0.0017 - val_forelegR4: 0.0018 - val_midlegL4: 0.0016 - val_midlegR4: 0.0017 - val_hindlegL4: 0.0017 - val_hindlegR4: 0.0018 - val_eyeL: 6.6378e-04 - val_eyeR: 6.5611e-04 - lr: 1.0000e-04 - 7s/epoch - 19ms/step
Epoch 7/10
360/360 - 7s - loss: 0.0012 - head: 4.8557e-04 - thorax: 3.1089e-04 - abdomen: 0.0014 - wingL: 0.0012 - wingR: 0.0012 - forelegL4: 0.0016 - forelegR4: 0.0016 - midlegL4: 0.0015 - midlegR4: 0.0016 - hindlegL4: 0.0018 - hindlegR4: 0.0019 - eyeL: 5.6096e-04 - eyeR: 5.3123e-04 - val_loss: 0.0012 - val_head: 5.2092e-04 - val_thorax: 3.4376e-04 - val_abdomen: 0.0014 - val_wingL: 0.0012 - val_wingR: 0.0012 - val_forelegL4: 0.0015 - val_forelegR4: 0.0017 - val_midlegL4: 0.0015 - val_midlegR4: 0.0015 - val_hindlegL4: 0.0017 - val_hindlegR4: 0.0017 - val_eyeL: 6.4288e-04 - val_eyeR: 6.0581e-04 - lr: 1.0000e-04 - 7s/epoch - 19ms/step
Epoch 8/10
360/360 - 7s - loss: 0.0011 - head: 4.3752e-04 - thorax: 2.7513e-04 - abdomen: 0.0013 - wingL: 0.0011 - wingR: 0.0011 - forelegL4: 0.0015 - forelegR4: 0.0015 - midlegL4: 0.0014 - midlegR4: 0.0014 - hindlegL4: 0.0017 - hindlegR4: 0.0017 - eyeL: 5.1807e-04 - eyeR: 4.9554e-04 - val_loss: 0.0011 - val_head: 5.6743e-04 - val_thorax: 3.5883e-04 - val_abdomen: 0.0014 - val_wingL: 0.0012 - val_wingR: 0.0011 - val_forelegL4: 0.0015 - val_forelegR4: 0.0016 - val_midlegL4: 0.0014 - val_midlegR4: 0.0014 - val_hindlegL4: 0.0015 - val_hindlegR4: 0.0015 - val_eyeL: 6.2925e-04 - val_eyeR: 6.5965e-04 - lr: 1.0000e-04 - 7s/epoch - 19ms/step
Epoch 9/10
360/360 - 7s - loss: 0.0011 - head: 4.2635e-04 - thorax: 2.4829e-04 - abdomen: 0.0012 - wingL: 0.0010 - wingR: 0.0010 - forelegL4: 0.0015 - forelegR4: 0.0014 - midlegL4: 0.0013 - midlegR4: 0.0013 - hindlegL4: 0.0016 - hindlegR4: 0.0017 - eyeL: 5.0197e-04 - eyeR: 4.8384e-04 - val_loss: 0.0011 - val_head: 4.8699e-04 - val_thorax: 3.5631e-04 - val_abdomen: 0.0013 - val_wingL: 0.0011 - val_wingR: 0.0011 - val_forelegL4: 0.0014 - val_forelegR4: 0.0016 - val_midlegL4: 0.0013 - val_midlegR4: 0.0015 - val_hindlegL4: 0.0014 - val_hindlegR4: 0.0015 - val_eyeL: 6.1692e-04 - val_eyeR: 5.8370e-04 - lr: 1.0000e-04 - 7s/epoch - 19ms/step
Epoch 10/10
360/360 - 7s - loss: 9.8454e-04 - head: 3.9611e-04 - thorax: 2.2278e-04 - abdomen: 0.0012 - wingL: 9.4893e-04 - wingR: 9.5555e-04 - forelegL4: 0.0014 - forelegR4: 0.0014 - midlegL4: 0.0012 - midlegR4: 0.0012 - hindlegL4: 0.0015 - hindlegR4: 0.0016 - eyeL: 4.7396e-04 - eyeR: 4.4770e-04 - val_loss: 0.0010 - val_head: 4.9330e-04 - val_thorax: 2.9460e-04 - val_abdomen: 0.0013 - val_wingL: 9.5190e-04 - val_wingR: 9.9289e-04 - val_forelegL4: 0.0014 - val_forelegR4: 0.0015 - val_midlegL4: 0.0012 - val_midlegR4: 0.0012 - val_hindlegL4: 0.0014 - val_hindlegR4: 0.0014 - val_eyeL: 5.5512e-04 - val_eyeR: 5.3737e-04 - lr: 1.0000e-04 - 7s/epoch - 19ms/step
INFO:sleap.nn.training:Finished training loop. [1.3 min]
INFO:sleap.nn.training:Deleting visualization directory: models/baseline_model.topdown/viz
INFO:sleap.nn.training:Saving evaluation metrics to model folder...


INFO:sleap.nn.evals:Saved predictions: models/baseline_model.topdown/labels_pr.train.slp
INFO:sleap.nn.evals:Saved metrics: models/baseline_model.topdown/metrics.train.npz
INFO:sleap.nn.evals:OKS mAP: 0.508754


INFO:sleap.nn.evals:Saved predictions: models/baseline_model.topdown/labels_pr.val.slp
INFO:sleap.nn.evals:Saved metrics: models/baseline_model.topdown/metrics.val.npz
INFO:sleap.nn.evals:OKS mAP: 0.477220
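The evaluation artifacts logged above are written into the model folder, so you can inspect them after training. A minimal sketch using sleap.load_metrics (the exact metric key names, such as "oks_voc.mAP" and "dist.avg", are assumptions based on the metrics files SLEAP saves):

# Load the validation-split metrics saved in the model folder.
metrics = sleap.load_metrics("models/baseline_model.topdown", split="val")

# Key names assumed; inspect metrics.keys() to see everything that was saved.
print("OKS mAP:", metrics["oks_voc.mAP"])
print("Mean localization error (px):", metrics["dist.avg"])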

5. Continuing training#

If we still have the trainer in memory, we can continue training by simply calling trainer.train() again, optionally with a different number of epochs:

trainer.config.optimization.epochs = 3
trainer.train()
INFO:sleap.nn.training:Creating tf.data.Datasets for training data generation...
INFO:sleap.nn.training:Finished creating training datasets. [17.1s]
INFO:sleap.nn.training:Starting training loop...
Epoch 1/3
360/360 - 7s - loss: 9.3201e-04 - head: 3.7118e-04 - thorax: 2.0303e-04 - abdomen: 0.0011 - wingL: 8.9319e-04 - wingR: 9.0134e-04 - forelegL4: 0.0013 - forelegR4: 0.0013 - midlegL4: 0.0011 - midlegR4: 0.0011 - hindlegL4: 0.0014 - hindlegR4: 0.0015 - eyeL: 4.4919e-04 - eyeR: 4.2012e-04 - val_loss: 9.4680e-04 - val_head: 3.9131e-04 - val_thorax: 2.4191e-04 - val_abdomen: 0.0010 - val_wingL: 8.9155e-04 - val_wingR: 8.9295e-04 - val_forelegL4: 0.0013 - val_forelegR4: 0.0014 - val_midlegL4: 0.0012 - val_midlegR4: 0.0012 - val_hindlegL4: 0.0013 - val_hindlegR4: 0.0013 - val_eyeL: 5.3658e-04 - val_eyeR: 5.0085e-04 - lr: 1.0000e-04 - 7s/epoch - 20ms/step
Epoch 2/3
360/360 - 7s - loss: 8.8906e-04 - head: 3.6015e-04 - thorax: 1.9128e-04 - abdomen: 0.0010 - wingL: 8.5054e-04 - wingR: 8.5352e-04 - forelegL4: 0.0013 - forelegR4: 0.0013 - midlegL4: 0.0010 - midlegR4: 0.0011 - hindlegL4: 0.0014 - hindlegR4: 0.0014 - eyeL: 4.3093e-04 - eyeR: 4.0690e-04 - val_loss: 8.9501e-04 - val_head: 4.1907e-04 - val_thorax: 2.3487e-04 - val_abdomen: 0.0010 - val_wingL: 8.6145e-04 - val_wingR: 8.4151e-04 - val_forelegL4: 0.0013 - val_forelegR4: 0.0014 - val_midlegL4: 0.0010 - val_midlegR4: 0.0011 - val_hindlegL4: 0.0013 - val_hindlegR4: 0.0012 - val_eyeL: 5.2130e-04 - val_eyeR: 4.9293e-04 - lr: 1.0000e-04 - 7s/epoch - 19ms/step
Epoch 3/3
360/360 - 7s - loss: 8.5396e-04 - head: 3.4440e-04 - thorax: 1.7180e-04 - abdomen: 9.9867e-04 - wingL: 8.1743e-04 - wingR: 8.2288e-04 - forelegL4: 0.0012 - forelegR4: 0.0012 - midlegL4: 9.7110e-04 - midlegR4: 0.0010 - hindlegL4: 0.0013 - hindlegR4: 0.0014 - eyeL: 4.1497e-04 - eyeR: 3.9294e-04 - val_loss: 8.8076e-04 - val_head: 3.7130e-04 - val_thorax: 2.4712e-04 - val_abdomen: 0.0010 - val_wingL: 8.2889e-04 - val_wingR: 8.5931e-04 - val_forelegL4: 0.0012 - val_forelegR4: 0.0014 - val_midlegL4: 9.9400e-04 - val_midlegR4: 0.0011 - val_hindlegL4: 0.0012 - val_hindlegR4: 0.0012 - val_eyeL: 4.9486e-04 - val_eyeR: 4.6961e-04 - lr: 1.0000e-04 - 7s/epoch - 19ms/step
INFO:sleap.nn.training:Finished training loop. [0.4 min]
INFO:sleap.nn.training:Deleting visualization directory: models/baseline_model.topdown/viz
INFO:sleap.nn.training:Saving evaluation metrics to model folder...


INFO:sleap.nn.evals:Saved predictions: models/baseline_model.topdown/labels_pr.train.slp
INFO:sleap.nn.evals:Saved metrics: models/baseline_model.topdown/metrics.train.npz
INFO:sleap.nn.evals:OKS mAP: 0.559100


INFO:sleap.nn.evals:Saved predictions: models/baseline_model.topdown/labels_pr.val.slp
INFO:sleap.nn.evals:Saved metrics: models/baseline_model.topdown/metrics.val.npz
INFO:sleap.nn.evals:OKS mAP: 0.529680

As you can see, the loss and accuracy pick up from where they left off in the previous round of training.

More often, though, you’ll be continuing training because you’re starting from an already trained model rather than from a trainer that is still in memory.

In this case, all you need to do is create a new Trainer from the saved model configuration and load the trained weights before calling train():

# Load config.
cfg = sleap.load_config("models/baseline_model.topdown")
# cfg.outputs.run_name = "new_folder"  # Set the run_name to a new value if you want the model to be saved to a different folder.

# Create and initialize the trainer.
trainer = sleap.nn.training.Trainer.from_config(cfg)
trainer.setup()

# Replace the randomly initialized weights with the saved weights.
trainer.keras_model.load_weights("models/baseline_model.topdown/best_model.h5")
INFO:sleap.nn.training:Loading training labels from: labels.pkg.slp
INFO:sleap.nn.training:Creating training and validation splits from validation fraction: 0.1
INFO:sleap.nn.training:  Splits: Training = 1440 / Validation = 160.
INFO:sleap.nn.training:Setting up for training...
INFO:sleap.nn.training:Setting up pipeline builders...
INFO:sleap.nn.training:Setting up model...
INFO:sleap.nn.training:Building test pipeline...
INFO:sleap.nn.training:Loaded test example. [0.925s]
INFO:sleap.nn.training:  Input shape: (160, 160, 1)
INFO:sleap.nn.training:Created Keras model.
INFO:sleap.nn.training:  Backbone: UNet(stacks=1, filters=16, filters_rate=2.0, kernel_size=3, stem_kernel_size=7, convs_per_block=2, stem_blocks=0, down_blocks=4, middle_block=True, up_blocks=2, up_interpolate=False, block_contraction=False)
INFO:sleap.nn.training:  Max stride: 16
INFO:sleap.nn.training:  Parameters: 2,101,501
INFO:sleap.nn.training:  Heads: 
INFO:sleap.nn.training:    [0] = CenteredInstanceConfmapsHead(part_names=['head', 'thorax', 'abdomen', 'wingL', 'wingR', 'forelegL4', 'forelegR4', 'midlegL4', 'midlegR4', 'hindlegL4', 'hindlegR4', 'eyeL', 'eyeR'], anchor_part='thorax', sigma=1.5, output_stride=4, loss_weight=1.0)
INFO:sleap.nn.training:  Outputs: 
INFO:sleap.nn.training:    [0] = KerasTensor(type_spec=TensorSpec(shape=(None, 40, 40, 13), dtype=tf.float32, name=None), name='CenteredInstanceConfmapsHead/BiasAdd:0', description="created by layer 'CenteredInstanceConfmapsHead'")
INFO:sleap.nn.training:Training from scratch
INFO:sleap.nn.training:Setting up data pipelines...
INFO:sleap.nn.training:Training set: n = 1440
INFO:sleap.nn.training:Validation set: n = 160
INFO:sleap.nn.training:Setting up optimization...
INFO:sleap.nn.training:  Learning rate schedule: LearningRateScheduleConfig(reduce_on_plateau=True, reduction_factor=0.5, plateau_min_delta=1e-06, plateau_patience=5, plateau_cooldown=3, min_learning_rate=1e-08)
INFO:sleap.nn.training:  Early stopping: EarlyStoppingConfig(stop_training_on_plateau=True, plateau_min_delta=1e-06, plateau_patience=10)
INFO:sleap.nn.training:Setting up outputs...
INFO:sleap.nn.training:Created run path: models/baseline_model.topdown
INFO:sleap.nn.training:Setting up visualization...
INFO:sleap.nn.training:Finished trainer set up. [2.2s]
trainer.config.optimization.epochs = 3
trainer.train()
INFO:sleap.nn.training:Creating tf.data.Datasets for training data generation...
INFO:sleap.nn.training:Finished creating training datasets. [17.7s]
INFO:sleap.nn.training:Starting training loop...
Epoch 1/3
360/360 - 9s - loss: 8.3664e-04 - head: 3.5190e-04 - thorax: 1.7037e-04 - abdomen: 9.8467e-04 - wingL: 7.9929e-04 - wingR: 8.0385e-04 - forelegL4: 0.0012 - forelegR4: 0.0012 - midlegL4: 9.5228e-04 - midlegR4: 9.8510e-04 - hindlegL4: 0.0013 - hindlegR4: 0.0013 - eyeL: 4.0772e-04 - eyeR: 3.9413e-04 - val_loss: 8.7351e-04 - val_head: 4.0943e-04 - val_thorax: 1.7453e-04 - val_abdomen: 9.4413e-04 - val_wingL: 8.3617e-04 - val_wingR: 8.4860e-04 - val_forelegL4: 0.0012 - val_forelegR4: 0.0012 - val_midlegL4: 9.4441e-04 - val_midlegR4: 0.0011 - val_hindlegL4: 0.0014 - val_hindlegR4: 0.0014 - val_eyeL: 4.4847e-04 - val_eyeR: 4.4179e-04 - lr: 1.0000e-04 - 9s/epoch - 24ms/step
Epoch 2/3
360/360 - 7s - loss: 8.0541e-04 - head: 3.4627e-04 - thorax: 1.6070e-04 - abdomen: 9.4325e-04 - wingL: 7.7257e-04 - wingR: 7.7434e-04 - forelegL4: 0.0012 - forelegR4: 0.0012 - midlegL4: 8.9573e-04 - midlegR4: 9.3483e-04 - hindlegL4: 0.0013 - hindlegR4: 0.0013 - eyeL: 4.0939e-04 - eyeR: 3.8417e-04 - val_loss: 8.2339e-04 - val_head: 3.9561e-04 - val_thorax: 1.2637e-04 - val_abdomen: 8.6513e-04 - val_wingL: 7.1751e-04 - val_wingR: 7.5540e-04 - val_forelegL4: 0.0012 - val_forelegR4: 0.0012 - val_midlegL4: 8.5588e-04 - val_midlegR4: 0.0010 - val_hindlegL4: 0.0013 - val_hindlegR4: 0.0014 - val_eyeL: 4.8189e-04 - val_eyeR: 4.2402e-04 - lr: 1.0000e-04 - 7s/epoch - 20ms/step
Epoch 3/3
360/360 - 7s - loss: 7.7741e-04 - head: 3.2087e-04 - thorax: 1.4398e-04 - abdomen: 9.1826e-04 - wingL: 7.4005e-04 - wingR: 7.5282e-04 - forelegL4: 0.0011 - forelegR4: 0.0011 - midlegL4: 8.6551e-04 - midlegR4: 8.9726e-04 - hindlegL4: 0.0012 - hindlegR4: 0.0013 - eyeL: 3.8423e-04 - eyeR: 3.7468e-04 - val_loss: 8.4657e-04 - val_head: 3.5649e-04 - val_thorax: 1.2162e-04 - val_abdomen: 8.9171e-04 - val_wingL: 7.9007e-04 - val_wingR: 8.2471e-04 - val_forelegL4: 0.0013 - val_forelegR4: 0.0013 - val_midlegL4: 8.1375e-04 - val_midlegR4: 9.8217e-04 - val_hindlegL4: 0.0014 - val_hindlegR4: 0.0013 - val_eyeL: 4.7370e-04 - val_eyeR: 4.2098e-04 - lr: 1.0000e-04 - 7s/epoch - 19ms/step
INFO:sleap.nn.training:Finished training loop. [0.4 min]
INFO:sleap.nn.training:Deleting visualization directory: models/baseline_model.topdown/viz
INFO:sleap.nn.training:Saving evaluation metrics to model folder...


INFO:sleap.nn.evals:Saved predictions: models/baseline_model.topdown/labels_pr.train.slp
INFO:sleap.nn.evals:Saved metrics: models/baseline_model.topdown/metrics.train.npz
INFO:sleap.nn.evals:OKS mAP: 0.585451


INFO:sleap.nn.evals:Saved predictions: models/baseline_model.topdown/labels_pr.val.slp
INFO:sleap.nn.evals:Saved metrics: models/baseline_model.topdown/metrics.val.npz
INFO:sleap.nn.evals:OKS mAP: 0.574921

Again, the loss and accuracy pick up from where they left off prior to this round of training.

The resulting model can be used as usual for inference on new data.
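For example, a minimal sketch of running inference with the trained model (the centroid model path and video path below are placeholders, since this notebook only trained the centered-instance half of the top-down pipeline):

# Load a top-down predictor. Pairing with a centroid model is the usual setup;
# the centroid model path here is hypothetical.
predictor = sleap.load_model([
    "models/my_centroid_model",        # hypothetical centroid model
    "models/baseline_model.topdown",   # the centered-instance model trained above
])

# Predict on a new video (placeholder path) and save the results.
video = sleap.load_video("my_video.mp4")
predictions = predictor.predict(video)
predictions.save("predictions.slp")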