
Interactive and realtime inference

For most workflows, using the sleap-track CLI is probably the most convenient option, but if you’re developing a custom application you can take advantage of SLEAP’s inference API to use your trained models in your own custom scripts.

In this notebook we will explore how to predict poses from raw images in pure Python, and do some basic benchmarking on a simulated realtime predictor that could be used to enable closed-loop experiments.
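Here is a minimal sketch of what "basic benchmarking" can look like: time repeated calls to a per-frame function with `time.perf_counter` and summarize the latency and throughput. The names `benchmark` and `fake_predict` are hypothetical. In a real application you would pass a function that runs your trained model on one frame instead of the sleeping stand-in used here.

```python
import statistics
import time
from typing import Callable, Dict, List


def benchmark(fn: Callable[[int], object], n_frames: int = 100, warmup: int = 5) -> Dict[str, float]:
    """Call `fn(frame_idx)` once per frame and summarize per-call latency."""
    if n_frames < 1:
        raise ValueError("n_frames must be >= 1")
    # Warm-up calls are not timed: the first calls into a GPU model are often
    # slower than steady state (e.g. graph tracing, memory allocation).
    for i in range(warmup):
        fn(i)

    latencies: List[float] = []
    for i in range(n_frames):
        t0 = time.perf_counter()
        fn(i)
        latencies.append(time.perf_counter() - t0)

    ordered = sorted(latencies)
    total = sum(latencies)
    return {
        "mean_ms": 1000 * statistics.mean(latencies),
        # Simple approximate 95th percentile (no interpolation).
        "p95_ms": 1000 * ordered[int(0.95 * (len(ordered) - 1))],
        "fps": n_frames / total if total > 0 else float("inf"),
    }


def fake_predict(frame_idx: int) -> None:
    """Hypothetical stand-in for a model call; sleeps ~1 ms to simulate work."""
    time.sleep(0.001)


stats = benchmark(fake_predict, n_frames=50)
print(f"{stats['mean_ms']:.2f} ms/frame (p95 {stats['p95_ms']:.2f} ms), {stats['fps']:.1f} FPS")
```

For closed-loop work, the tail latency (p95) usually matters more than the mean, since a single slow frame can miss a stimulus deadline.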

1. Set up SLEAP

Run this cell first to install SLEAP. If you get a dependency error in subsequent cells, just click Runtime > Restart runtime to reload the packages.

Don’t forget to set the accelerator to GPU under Runtime > Change runtime type.

# This should take care of all the dependencies on Colab:
!pip uninstall -qqq -y opencv-python opencv-contrib-python
!pip install -qqq "sleap[pypi]>=1.3.3"


# But to do it locally, we'd recommend the conda package (available on Windows + Linux):
# conda create -n sleap -c sleap -c conda-forge -c nvidia sleap
ERROR: Cannot uninstall opencv-python 4.6.0, RECORD file not found. Hint: The package was installed by conda.
ERROR: Cannot uninstall shiboken2 5.15.6, RECORD file not found. You might be able to recover from this via: 'pip install --force-reinstall --no-deps shiboken2==5.15.6'.

Import SLEAP to make sure it installed correctly and print out some information about the system:

import sleap
sleap.disable_preallocation()  # This initializes the GPU and prevents TensorFlow from filling the entire GPU memory
sleap.versions()
sleap.system_summary()
SLEAP: 1.3.2
TensorFlow: 2.7.0
Numpy: 1.21.5
Python: 3.7.12
OS: Linux-5.15.0-78-generic-x86_64-with-debian-bookworm-sid
GPUs: 1/1 available
  Device: /physical_device:GPU:0
         Available: True
        Initalized: False
     Memory growth: True

2. Set up data

Before we start, let’s download a raw video and a set of trained top-down ID models that we’ll use to build our application around.

!curl -L --output video.mp4 https://storage.googleapis.com/sleap-data/reference/flies13/190719_090330_wt_18159206_rig1.2%4015000-17560.mp4
!curl -L --output centroid_model.zip https://storage.googleapis.com/sleap-data/reference/flies13/centroid.fast.210504_182918.centroid.n%3D1800.zip
!curl -L --output centered_instance_id_model.zip https://storage.googleapis.com/sleap-data/reference/flies13/td_id.fast.v2.210519_111253.multi_class_topdown.n%3D1800.zip
!ls -lah
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 81.3M  100 81.3M    0     0  23.7M      0  0:00:03  0:00:03 --:--:-- 23.7M
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 6223k  100 6223k    0     0  30.2M      0 --:--:-- --:--:-- --:--:-- 30.3M
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 32.2M  100 32.2M    0     0  14.5M      0  0:00:02  0:00:02 --:--:-- 14.5M
total 1.1G
drwxrwxr-x  5 talmolab talmolab 4.0K Sep  1 13:56  .
drwxrwxr-x 10 talmolab talmolab 4.0K Aug 31 15:43  ..
-rw-rw-r--  1 talmolab talmolab  82M May 20  2021  190719_090330_wt_18159206_rig1.2@15000-17560.mp4.1
-rw-rw-r--  1 talmolab talmolab 1.6M May 20  2021  190719_090330_wt_18159206_rig1.2@15000-17560.slp
-rw-rw-r--  1 talmolab talmolab 1.6M May 20  2021  190719_090330_wt_18159206_rig1.2@15000-17560.slp.1
drwxrwxr-x  2 talmolab talmolab 4.0K Jun 20 10:00  analysis_example
-rw-rw-r--  1 talmolab talmolab 713K Jun 20 10:00  Analysis_examples.ipynb
-rw-rw-r--  1 talmolab talmolab  33M Sep  1 13:56  centered_instance_id_model.zip
-rw-rw-r--  1 talmolab talmolab 6.1M May 20  2021 'centroid.fast.210504_182918.centroid.n=1800.zip'
-rw-rw-r--  1 talmolab talmolab 6.1M May 20  2021 'centroid.fast.210504_182918.centroid.n=1800.zip.1'
-rw-rw-r--  1 talmolab talmolab 6.1M Sep  1 13:56  centroid_model.zip
drwxrwxr-x  4 talmolab talmolab 4.0K Sep  1 13:30  dataset
-rw-rw-r--  1 talmolab talmolab 481K Sep  1 13:49  Data_structures.ipynb
-rw-rw-r--  1 talmolab talmolab 661K Aug 31 12:52  fly_clip.mp4
-rw-rw-r--  1 talmolab talmolab 4.1K Jun 20 10:00  index.rst
-rw-rw-r--  1 talmolab talmolab 197K Sep  1 13:53  Interactive_and_realtime_inference.ipynb
-rw-rw-r--  1 talmolab talmolab 120K Aug 31 12:25  Interactive_and_resumable_training.ipynb
-rw-rw-r--  1 talmolab talmolab 620M Aug 31 12:14  labels.pkg.slp
-rw-rw-r--  1 talmolab talmolab 1.6M Aug 31 12:05  labels_with_images.pkg.slp
-rw-rw-r--  1 talmolab talmolab 158K Aug 31 12:35  Model_evaluation.ipynb
drwxrwxr-x  4 talmolab talmolab 4.0K Sep  1 13:39  models
-rw-rw-r--  1 talmolab talmolab 157K Aug 31 12:52  Post_inference_tracking.ipynb
-rw-rw-r--  1 talmolab talmolab 412K Aug 31 12:52  predictions.slp
-rw-rw-r--  1 talmolab talmolab 422K Aug 31 12:52  retracked.slp
-rw-rw-r--  1 talmolab talmolab  30M May 20  2021 'td_fast.210505_012601.centered_instance.n=1800.zip'
-rw-rw-r--  1 talmolab talmolab  30M May 20  2021 'td_fast.210505_012601.centered_instance.n=1800.zip.1'
-rw-rw-r--  1 talmolab talmolab  30M May 20  2021 'td_fast.210505_012601.centered_instance.n=1800.zip.2'
-rw-rw-r--  1 talmolab talmolab  78M May  6  2021  test.pkg.slp
-rw-rw-r--  1 talmolab talmolab  89M Sep  1 13:42  trained_models.zip
-rw-rw-r--  1 talmolab talmolab  94K Sep  1 13:44  Training_and_inference_on_an_example_dataset.ipynb
-rw-rw-r--  1 talmolab talmolab  12K Aug 31 11:39  Training_and_inference_using_Google_Drive.ipynb
-rw-rw-r--  1 talmolab talmolab  82M Sep  1 13:56  video.mp4

Note: These zip files just have the contents of standard SLEAP model folders that are generated during training.
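Because the models are ordinary zip archives, you can inspect them with Python's standard `zipfile` module before loading them. The sketch below lists the archive members and loads the first file named `training_config.json`. That file name is an assumption based on typical SLEAP model folders, so check the listing for your own models. `list_model_zip` and `read_training_config` are hypothetical helper names.

```python
import json
import zipfile
from typing import List, Optional


def list_model_zip(path) -> List[str]:
    """Return the names of all files stored in a model zip archive."""
    with zipfile.ZipFile(path) as zf:
        return zf.namelist()


def read_training_config(path) -> Optional[dict]:
    """Load the first member whose file name is 'training_config.json' (assumed name)."""
    with zipfile.ZipFile(path) as zf:
        for name in zf.namelist():
            # Zip member names always use '/' as the separator.
            if name.rsplit("/", 1)[-1] == "training_config.json":
                return json.loads(zf.read(name))
    return None


# Usage with the files downloaded above:
# print(list_model_zip("centroid_model.zip"))
# cfg = read_training_config("centroid_model.zip")
```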

3. Interactive inference

SLEAP provides a high-level API for performing inference in the form of Predictor classes specific to each approach/model type.

To create one from a set of trained models, we can use the high-level sleap.load_model() function:

predictor = sleap.load_model(["centroid_model.zip", "centered_instance_id_model.zip"], batch_size=16)
2023-09-01 13:57:05.106019: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1525] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 21129 MB memory:  -> device: 0, name: NVIDIA RTX A5000, pci bus id: 0000:01:00.0, compute capability: 8.6

This function handles all the logic of loading the trained models, reading the configurations used to train them, and constructing inference models that also include non-trainable operations like peak finding and instance grouping.
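To make "peak finding" concrete: in a top-down pipeline like this one, the centered-instance model predicts a confidence map for each body part, and each point is read off at that map's maximum. Below is a simplified numpy sketch of this idea. It assumes a confidence-map array of shape `(height, width, n_nodes)` and a hypothetical threshold of 0.2. Parts whose peak falls below the threshold become NaN, which is one way to arrive at the `nan` rows in the low-level outputs later in this notebook. SLEAP's own implementation does more than this, for example subpixel refinement and mapping crop coordinates back to the full frame.

```python
import numpy as np


def find_global_peaks(cms: np.ndarray, threshold: float = 0.2):
    """Locate the maximum of each confidence-map channel (simplified sketch).

    cms: array of shape (height, width, n_nodes).
    Returns (points, vals): points is (n_nodes, 2) in (x, y) pixel order and
    vals is (n_nodes,) peak confidences. Points below `threshold` are set to NaN.
    """
    height, width, n_nodes = cms.shape
    flat = cms.reshape(height * width, n_nodes)
    idx = flat.argmax(axis=0)                 # flat pixel index of the max per channel
    vals = flat[idx, np.arange(n_nodes)]      # peak value per channel
    rows, cols = np.unravel_index(idx, (height, width))
    points = np.stack([cols, rows], axis=-1).astype("float32")  # (x, y) order
    points[vals < threshold] = np.nan         # drop low-confidence detections
    return points, vals


# Toy example: node 0 has a strong peak at (x=20, y=10), node 1 only a weak one.
cms = np.zeros((64, 64, 2), dtype="float32")
cms[10, 20, 0] = 0.9
cms[5, 5, 1] = 0.1
points, vals = find_global_peaks(cms)
print(points)  # node 1 comes back as [nan, nan]
```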

Next, we’ll load a video that we want to use for inference. SLEAP Video objects don’t load the whole video into memory; they provide a common numpy-like interface for reading frames from different file formats:

video = sleap.load_video("video.mp4")
video.shape, video.dtype
((2560, 1024, 1024, 1), dtype('uint8'))
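The lazy, numpy-like interface can be illustrated with a small class that reads frames only when they are indexed. This is a hypothetical sketch, not SLEAP's actual `Video` implementation. `LazyFrames` and its `read_frame` callback are made-up names, and a real backend would decode frames from the video file instead of generating them.

```python
from typing import Callable

import numpy as np


class LazyFrames:
    """Array-like view over frames that are read only on access (hypothetical)."""

    def __init__(self, n_frames: int, frame_shape, read_frame: Callable[[int], np.ndarray]):
        self.n_frames = n_frames
        self.frame_shape = tuple(frame_shape)
        self.read_frame = read_frame  # callback returning one (H, W, C) uint8 frame

    @property
    def shape(self):
        return (self.n_frames, *self.frame_shape)

    def __len__(self):
        return self.n_frames

    def __getitem__(self, key):
        if isinstance(key, slice):
            indices = range(*key.indices(self.n_frames))
            if len(indices) == 0:
                return np.empty((0, *self.frame_shape), dtype=np.uint8)
            # Read only the requested frames and stack them into (N, H, W, C).
            return np.stack([self.read_frame(i) for i in indices])
        idx = int(key)
        if idx < 0:
            idx += self.n_frames  # support negative indexing like numpy
        return self.read_frame(idx)


# Example: a fake reader that records which frames were requested.
requested = []


def fake_read_frame(i):
    requested.append(i)
    return np.full((4, 4, 1), i % 256, dtype=np.uint8)


frames = LazyFrames(2560, (4, 4, 1), fake_read_frame)
batch = frames[:16]  # only frames 0..15 are read
print(frames.shape, batch.shape)
```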

Our predictor is pretty flexible: it can take a variety of input formats, and for each of them it returns a Labels object containing all of our predictions:

# Load frames to a numpy array.
imgs = video[:100]
print(f"imgs.shape: {imgs.shape}")

# Predict on numpy array.
predictions = predictor.predict(imgs)
predictions
imgs.shape: (100, 1024, 1024, 1)
2023-09-01 13:57:13.455046: I tensorflow/stream_executor/cuda/cuda_dnn.cc:366] Loaded cuDNN version 8201
2023-09-01 13:57:15.358483: I tensorflow/stream_executor/cuda/cuda_blas.cc:1774] TensorFloat-32 will be used for the matrix multiplication. This will only be logged once.


Labels(labeled_frames=100, videos=1, skeletons=1, tracks=2)
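The `batch_size=16` passed to `sleap.load_model()` above sets how many frames are processed per model call. The generic helper below is a sketch of how a stack of frames splits into such batches, not SLEAP's internal code. For the 100 frames predicted here it gives 7 batches, and the last one holds the remaining 4 frames.

```python
from typing import Iterator, Sequence


def iter_batches(frames: Sequence, batch_size: int = 16) -> Iterator[Sequence]:
    """Yield consecutive slices of at most `batch_size` frames (works for lists and numpy arrays)."""
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    for start in range(0, len(frames), batch_size):
        yield frames[start:start + batch_size]


sizes = [len(b) for b in iter_batches(range(100), batch_size=16)]
print(sizes)  # [16, 16, 16, 16, 16, 16, 4]
```

Larger batches generally improve throughput, but a realtime loop that processes one frame at a time gains nothing from them.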

We can then inspect the results of our predictor:

# Visualize a frame.
predictions[100].plot(scale=0.25)
[Figure: predicted poses plotted over the video frame]
# Inspect the contents of a single frame.
labeled_frame = predictions[100]
labeled_frame.instances
[PredictedInstance(video=Video(filename=video.mp4, shape=(2560, 1024, 1024, 1), backend=MediaVideo), frame_idx=100, points=[head: (212.5, 427.0, 0.94), thorax: (252.0, 433.1, 0.95), abdomen: (288.6, 439.3, 0.68), wingL: (304.5, 443.3, 0.88), wingR: (306.2, 435.8, 0.68), forelegL4: (216.2, 445.5, 0.88), forelegR4: (216.1, 410.0, 0.90), midlegL4: (244.4, 471.3, 0.90), midlegR4: (256.6, 408.9, 0.86), hindlegL4: (275.0, 459.2, 0.89), hindlegR4: (292.3, 412.0, 0.81), eyeL: (220.0, 438.0, 0.84), eyeR: (223.8, 417.5, 0.91)], score=0.99, track=Track(spawned_on=0, name='female'), tracking_score=0.00),
 PredictedInstance(video=Video(filename=video.mp4, shape=(2560, 1024, 1024, 1), backend=MediaVideo), frame_idx=100, points=[head: (313.7, 432.6, 0.87), thorax: (348.9, 427.9, 1.00), abdomen: (378.9, 425.8, 0.83), wingL: (397.0, 428.7, 0.89), wingR: (394.9, 420.7, 0.74), forelegL4: (307.4, 446.4, 0.88), forelegR4: (306.5, 422.5, 0.89), midlegL4: (341.6, 474.2, 0.97), midlegR4: (332.6, 386.3, 0.97), hindlegL4: (378.9, 458.8, 0.92), hindlegR4: (387.7, 394.8, 0.88), eyeL: (323.7, 442.1, 0.96), eyeR: (320.7, 420.8, 0.88)], score=0.99, track=Track(spawned_on=0, name='male'), tracking_score=0.00)]
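Each point in the listing above carries a confidence score as its third value (e.g. `head: (212.5, 427.0, 0.94)`). In a closed-loop application you may want to ignore low-confidence body parts. The sketch below works on plain `(name, x, y, score)` tuples copied from the female instance above, not on SLEAP objects. The 0.7 cutoff and the `confident_points` name are arbitrary choices for illustration.

```python
from typing import Dict, List, Tuple

Point = Tuple[str, float, float, float]  # (node name, x, y, score)


def confident_points(points: List[Point], min_score: float = 0.7) -> Dict[str, Tuple[float, float]]:
    """Keep only points whose score is at least `min_score`, keyed by node name."""
    return {name: (x, y) for name, x, y, score in points if score >= min_score}


# A few points copied from the 'female' instance printed above.
female = [
    ("head", 212.5, 427.0, 0.94),
    ("thorax", 252.0, 433.1, 0.95),
    ("abdomen", 288.6, 439.3, 0.68),
    ("wingR", 306.2, 435.8, 0.68),
]
print(confident_points(female))  # {'head': (212.5, 427.0), 'thorax': (252.0, 433.1)}
```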
# Convert an instance to a numpy array:
labeled_frame[0].numpy()
rec.array([[212.51400757, 426.97024536],
           [251.97747803, 433.08648682],
           [288.64355469, 439.3086853 ],
           [304.53396606, 443.33477783],
           [306.20336914, 435.77227783],
           [216.24688721, 445.47549438],
           [216.14550781, 409.98342896],
           [244.39497375, 471.31561279],
           [256.61740112, 408.89056396],
           [274.97470093, 459.1831665 ],
           [292.2600708 , 411.95904541],
           [219.98565674, 437.97906494],
           [223.75566101, 417.5496521 ]],
          dtype=float64)
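Once instances are plain arrays, simple behavioral measurements take only a few lines of code. As an illustration, the sketch below computes the distance between the two flies' thorax points using values taken from the instance listing above, rounded to one decimal. In a live application these coordinates would come from `instance.numpy()` instead. The `THORAX = 1` index is read off the node order printed above, where thorax is the second row. `node_distance` is a hypothetical helper, and it uses `math.hypot` so it also runs on the Python 3.7 shown in the system summary.

```python
import math

THORAX = 1  # row index of 'thorax' in the node order printed above

# Rows are (x, y) per node; only the first two nodes (head, thorax) are copied here.
female = [(212.5, 427.0), (252.0, 433.1)]
male = [(313.7, 432.6), (348.9, 427.9)]


def node_distance(a, b, node: int) -> float:
    """Euclidean distance in pixels between the same node on two instances."""
    (xa, ya), (xb, yb) = a[node], b[node]
    if any(math.isnan(v) for v in (xa, ya, xb, yb)):
        return float("nan")  # undefined if either point was not detected
    return math.hypot(xb - xa, yb - ya)


d = node_distance(female, male, THORAX)
print(f"thorax-thorax distance: {d:.1f} px")  # thorax-thorax distance: 97.0 px
```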

What if we don’t want or need the inference results wrapped in the SLEAP structures?

Using the low-level inference model, we can go directly from images to numpy arrays of results:

imgs = video[:16]  # batch of 16 images

predictions = predictor.inference_model.predict(imgs, numpy=True)
predictions
4/4 [==============================] - 2s 176ms/step
{'instance_peaks': array([[[[234.2224 , 430.62598],
          [271.5043 , 436.13202],
          [309.87125, 436.64966],
          [324.12512, 438.3908 ],
          [320.3458 , 435.9504 ],
          [246.42352, 450.67786],
          [242.37636, 413.81458],
          [285.5624 , 460.22766],
          [273.45117, 406.51895],
          [      nan,       nan],
          [      nan,       nan],
          [241.9716 , 442.32303],
          [245.46788, 421.90228]],
 
         [[319.80017, 435.48407],
          [351.93695, 434.0301 ],
          [369.43228, 431.78564],
          [393.89014, 481.0584 ],
          [398.4241 , 429.79565],
          [      nan,       nan],
          [305.42896, 419.3896 ],
          [325.67926, 475.0098 ],
          [331.97974, 384.30814],
          [363.66406, 473.9354 ],
          [377.3852 , 398.13065],
          [328.40244, 445.51434],
          [328.1667 , 423.94733]]],
 
 
        [[[234.36913, 430.38037],
          [271.65576, 436.0479 ],
          [311.6751 , 437.00995],
          [324.48315, 438.1421 ],
          [322.20544, 435.06784],
          [246.43257, 450.61487],
          [242.3986 , 413.8269 ],
          [285.565  , 460.00977],
          [273.78204, 406.46442],
          [      nan,       nan],
          [      nan,       nan],
          [242.11816, 442.0634 ],
          [245.55441, 421.7281 ]],
 
         [[320.03793, 435.2389 ],
          [353.87274, 434.77695],
          [370.67218, 432.9711 ],
          [393.91922, 481.09735],
          [399.77133, 431.25983],
          [      nan,       nan],
          [308.409  , 421.48993],
          [325.82016, 474.90115],
          [331.94632, 385.0652 ],
          [363.65408, 473.70728],
          [384.68225, 399.30194],
          [328.72806, 445.15356],
          [328.48532, 423.624  ]]],
 
 
        [[[234.5559 , 430.06238],
          [271.8775 , 435.9898 ],
          [312.13086, 438.16318],
          [324.77222, 437.65994],
          [322.40115, 434.7244 ],
          [246.44681, 450.51874],
          [242.45566, 413.7617 ],
          [285.8958 , 460.56442],
          [273.66855, 406.2377 ],
          [      nan,       nan],
          [      nan,       nan],
          [242.26588, 441.80545],
          [245.77664, 420.7662 ]],
 
         [[320.46994, 435.2546 ],
          [354.89484, 434.93176],
          [372.25574, 433.46127],
          [394.40717, 479.5797 ],
          [400.30173, 431.96054],
          [306.9821 , 449.3157 ],
          [308.8817 , 421.52148],
          [325.98843, 474.9167 ],
          [332.17917, 385.04684],
          [363.0318 , 473.50616],
          [391.05493, 396.85666],
          [329.16904, 445.04953],
          [328.89996, 423.52533]]],
 
 
        [[[234.65547, 429.6946 ],
          [272.38303, 435.68842],
          [311.04352, 437.86963],
          [324.80847, 437.3792 ],
          [322.84747, 433.93973],
          [246.71852, 451.2873 ],
          [242.57388, 413.58414],
          [286.164  , 461.83655],
          [272.8726 , 406.21753],
          [      nan,       nan],
          [      nan,       nan],
          [242.43861, 441.46246],
          [245.25829, 420.48416]],
 
         [[320.7713 , 433.55927],
          [356.25912, 432.81424],
          [372.98462, 432.9266 ],
          [402.0365 , 465.378  ],
          [400.8439 , 431.7685 ],
          [      nan,       nan],
          [310.4258 , 422.7895 ],
          [325.16397, 474.86514],
          [332.16724, 384.9967 ],
          [362.87766, 473.12836],
          [390.43555, 393.69998],
          [330.20596, 443.4066 ],
          [329.0497 , 421.68896]]],
 
 
        [[[234.51591, 429.5735 ],
          [272.3791 , 435.4755 ],
          [310.74457, 436.20264],
          [325.24997, 437.69904],
          [323.1339 , 433.8241 ],
          [246.75269, 451.22192],
          [242.58466, 413.53275],
          [286.0668 , 461.6229 ],
          [272.87787, 406.2068 ],
          [      nan,       nan],
          [      nan,       nan],
          [242.3858 , 441.31342],
          [245.15892, 420.27942]],
 
         [[320.91632, 432.5178 ],
          [356.588  , 432.3604 ],
          [374.51236, 432.42508],
          [405.0515 , 450.2759 ],
          [401.2467 , 432.2713 ],
          [314.74677, 442.78735],
          [312.76758, 422.29553],
          [325.20752, 474.6215 ],
          [332.2873 , 384.86606],
          [362.8446 , 472.95822],
          [388.92188, 394.203  ],
          [329.54233, 442.43842],
          [329.1192 , 420.79416]]],
 
 
        [[[234.54964, 429.56854],
          [272.30457, 435.13345],
          [309.08594, 434.02444],
          [325.13245, 437.11148],
          [324.71674, 431.81714],
          [246.79828, 450.9629 ],
          [242.6766 , 413.53745],
          [286.09372, 461.14362],
          [272.87155, 406.23718],
          [      nan,       nan],
          [      nan,       nan],
          [242.4111 , 441.2425 ],
          [245.13495, 420.83694]],
 
         [[320.7404 , 430.43884],
          [356.4725 , 431.68488],
          [375.05853, 431.87177],
          [404.3775 , 451.92688],
          [401.39508, 431.9776 ],
          [      nan,       nan],
          [312.77365, 421.6409 ],
          [325.17343, 474.26575],
          [331.44904, 384.56747],
          [363.05463, 472.54587],
          [388.72284, 394.13287],
          [330.25458, 440.28958],
          [328.9332 , 419.74493]]],
 
 
        [[[234.15704, 429.3947 ],
          [272.1558 , 435.1859 ],
          [310.46423, 435.5753 ],
          [324.42407, 437.18854],
          [322.80786, 433.41486],
          [246.72241, 450.9671 ],
          [242.64005, 413.65726],
          [285.9537 , 461.01648],
          [272.73447, 406.31635],
          [305.89285, 449.9849 ],
          [      nan,       nan],
          [241.21112, 441.0713 ],
          [244.77327, 419.9886 ]],
 
         [[321.03162, 429.8643 ],
          [356.5856 , 430.9501 ],
          [377.2166 , 431.29108],
          [405.09204, 451.2633 ],
          [402.97113, 431.12497],
          [      nan,       nan],
          [312.74753, 421.16742],
          [325.3774 , 474.73508],
          [331.5342 , 384.97403],
          [378.56894, 469.3632 ],
          [388.81372, 393.89886],
          [330.641  , 439.67194],
          [329.04425, 418.99023]]],
 
 
        [[[232.79128, 428.2476 ],
          [271.7884 , 435.45706],
          [310.26096, 437.58252],
          [322.67697, 439.28253],
          [322.35138, 435.4916 ],
          [246.49533, 451.1817 ],
          [242.4297 , 413.56104],
          [286.01126, 461.4526 ],
          [272.72516, 406.3869 ],
          [      nan,       nan],
          [284.4912 , 408.79095],
          [240.58961, 440.1936 ],
          [244.4464 , 420.00543]],
 
         [[322.69318, 430.96207],
          [358.88284, 430.98035],
          [379.26816, 431.0259 ],
          [405.7312 , 449.5473 ],
          [405.13306, 431.02057],
          [      nan,       nan],
          [309.64542, 421.59024],
          [325.46237, 474.79062],
          [331.63318, 384.9981 ],
          [390.9735 , 466.93915],
          [388.87518, 393.89645],
          [331.4858 , 440.98822],
          [330.72357, 419.30713]]],
 
 
        [[[232.9138 , 428.26993],
          [271.89908, 435.6341 ],
          [310.36536, 437.9696 ],
          [322.63763, 439.87323],
          [322.4065 , 435.7932 ],
          [246.48575, 451.27322],
          [242.48721, 413.6446 ],
          [285.74454, 460.08987],
          [272.75647, 406.338  ],
          [      nan,       nan],
          [320.82465, 422.17297],
          [240.64159, 440.22705],
          [244.54178, 420.04788]],
 
         [[322.2764 , 429.7331 ],
          [359.43756, 429.0462 ],
          [379.8793 , 429.56253],
          [407.32346, 448.95087],
          [405.74594, 429.27792],
          [315.46356, 441.38046],
          [309.48642, 421.8147 ],
          [325.63016, 474.81934],
          [331.73767, 385.03244],
          [399.19778, 461.1395 ],
          [388.32227, 394.00305],
          [331.94138, 439.76627],
          [330.20728, 418.03998]]],
 
 
        [[[232.59984, 427.94275],
          [271.68756, 435.925  ],
          [309.74356, 438.45367],
          [322.3493 , 441.94934],
          [322.39355, 436.09885],
          [246.09349, 450.45755],
          [242.331  , 413.8041 ],
          [284.40057, 460.55066],
          [273.6091 , 406.43307],
          [286.35394, 459.9949 ],
          [      nan,       nan],
          [240.04814, 440.10544],
          [244.36105, 419.95673]],
 
         [[322.50397, 428.86414],
          [359.65952, 428.01282],
          [381.80063, 428.2879 ],
          [407.9239 , 446.02728],
          [406.27682, 428.24774],
          [317.42343, 444.4193 ],
          [308.38232, 422.35754],
          [325.6553 , 474.45853],
          [331.8156 , 384.7812 ],
          [399.62988, 456.58368],
          [388.52002, 394.27118],
          [332.3299 , 438.78006],
          [330.43085, 417.03174]]],
 
 
        [[[232.25208, 427.7414 ],
          [271.57523, 436.99503],
          [308.347  , 440.97897],
          [321.64392, 445.52814],
          [322.16394, 439.4637 ],
          [229.9819 , 444.81857],
          [242.35481, 413.535  ],
          [284.59384, 461.70065],
          [273.50806, 406.95544],
          [286.72635, 460.96436],
          [314.3465 , 428.5469 ],
          [239.56883, 440.8733 ],
          [244.04318, 420.60315]],
 
         [[324.36966, 429.4342 ],
          [360.08127, 427.41803],
          [384.283  , 427.4751 ],
          [408.8785 , 443.59448],
          [408.36377, 425.55792],
          [316.73703, 445.6411 ],
          [308.78436, 421.899  ],
          [325.92154, 474.19464],
          [331.91168, 385.32022],
          [399.73245, 457.32578],
          [388.57062, 394.18298],
          [334.3139 , 438.40005],
          [331.89133, 417.64728]]],
 
 
        [[[232.70679, 428.36255],
          [272.08994, 436.64023],
          [310.14267, 440.50543],
          [322.68262, 444.5147 ],
          [322.82147, 438.87054],
          [224.32256, 448.4768 ],
          [242.57848, 413.34476],
          [284.7278 , 461.26282],
          [273.8772 , 406.77335],
          [286.55972, 460.77054],
          [      nan,       nan],
          [239.95602, 440.7761 ],
          [244.31602, 420.40244]],
 
         [[325.2043 , 429.92737],
          [360.62262, 427.1631 ],
          [386.82898, 425.76257],
          [410.35846, 440.0152 ],
          [408.79132, 423.68118],
          [318.88504, 445.35867],
          [308.8374 , 421.72562],
          [326.25244, 474.88055],
          [332.2403 , 385.27567],
          [399.44467, 457.21188],
          [388.84778, 394.12372],
          [335.362  , 439.4058 ],
          [332.62274, 417.9344 ]]],
 
 
        [[[232.79385, 428.12885],
          [272.08496, 435.77728],
          [310.3099 , 437.81348],
          [324.31982, 440.3584 ],
          [324.60254, 434.39813],
          [222.91586, 451.43195],
          [242.6026 , 413.74078],
          [284.7489 , 460.09384],
          [273.8778 , 406.4865 ],
          [287.56982, 459.68353],
          [322.8655 , 421.80096],
          [240.19046, 440.23196],
          [244.46782, 419.99805]],
 
         [[325.60196, 431.36603],
          [360.8261 , 427.19696],
          [387.17218, 425.47867],
          [410.81366, 438.09143],
          [408.99658, 422.15668],
          [318.84363, 445.00012],
          [311.57254, 423.2615 ],
          [333.60617, 467.9318 ],
          [331.6039 , 385.32465],
          [399.44635, 457.16357],
          [388.92133, 394.11078],
          [336.72043, 439.8229 ],
          [332.6642 , 419.31372]]],
 
 
        [[[232.83435, 428.26373],
          [272.11572, 435.61078],
          [312.17926, 439.66278],
          [322.83746, 442.15924],
          [324.40552, 435.6441 ],
          [225.87045, 451.41144],
          [242.64131, 413.59937],
          [285.06647, 460.35507],
          [273.84183, 406.3719 ],
          [      nan,       nan],
          [322.41534, 422.61237],
          [240.42723, 440.2208 ],
          [244.4097 , 419.95218]],
 
         [[327.3499 , 431.52005],
          [361.313  , 425.36267],
          [389.47607, 423.60114],
          [411.6601 , 435.50894],
          [409.51843, 419.6943 ],
          [319.90283, 445.82428],
          [313.70898, 423.5036 ],
          [345.66882, 473.1757 ],
          [331.79486, 385.46274],
          [399.46533, 457.10553],
          [388.24854, 394.009  ],
          [337.8076 , 440.06436],
          [333.29004, 419.49707]]],
 
 
        [[[232.41422, 429.4673 ],
          [271.8141 , 435.7682 ],
          [310.01324, 439.98956],
          [322.19714, 443.71683],
          [324.71207, 434.39133],
          [224.85786, 451.4593 ],
          [242.5914 , 413.65207],
          [285.67142, 461.77646],
          [273.7307 , 406.5118 ],
          [      nan,       nan],
          [322.71594, 420.21155],
          [239.99216, 440.57278],
          [243.82819, 420.339  ]],
 
         [[328.47983, 431.74188],
          [363.93173, 425.2397 ],
          [390.49423, 423.05255],
          [413.68115, 433.6671 ],
          [410.5454 , 419.09042],
          [320.30078, 446.80396],
          [313.82977, 421.7456 ],
          [356.64886, 473.89554],
          [331.84995, 385.1559 ],
          [399.78146, 457.11206],
          [388.5744 , 393.94125],
          [339.9305 , 440.99496],
          [334.5468 , 419.42017]]],
 
 
        [[[232.05379, 430.01157],
          [271.71146, 436.17175],
          [310.08688, 438.66077],
          [322.65015, 442.097  ],
          [324.3269 , 434.45065],
          [224.67744, 450.92798],
          [242.56874, 413.94662],
          [285.72803, 462.40347],
          [273.67886, 406.66385],
          [313.6862 , 456.8137 ],
          [318.559  , 416.42374],
          [239.62582, 441.11035],
          [242.73026, 420.8417 ]],
 
         [[329.30188, 431.77295],
          [364.57666, 425.20844],
          [391.32507, 421.96838],
          [414.35016, 433.2262 ],
          [411.04324, 418.17578],
          [320.63538, 445.82654],
          [315.36795, 420.21204],
          [360.9988 , 471.7216 ],
          [332.20065, 385.24988],
          [399.20847, 456.9794 ],
          [388.68896, 394.04962],
          [340.75934, 441.0198 ],
          [335.4428 , 419.33124]]]], dtype=float32),
 'instance_peak_vals': array([[[0.9914025 , 0.9798533 , 0.7552497 , 0.45417705, 0.49756864,
          0.8265212 , 0.89824754, 0.7941327 , 0.81785023, 0.05611448,
          0.06403984, 0.88647026, 0.96359974],
         [0.9033977 , 0.25969282, 0.6343123 , 0.8396003 , 0.7613073 ,
          0.04938014, 0.84057474, 0.8820076 , 0.8816869 , 0.8243384 ,
          0.33521563, 0.8434063 , 0.8127704 ]],
 
        [[0.9598888 , 0.97341204, 0.6766811 , 0.35414153, 0.49778372,
          0.883279  , 0.9271338 , 0.7989652 , 0.7574282 , 0.04437362,
          0.06203796, 0.8609162 , 0.89723104],
         [0.8814398 , 0.43337214, 0.6627722 , 0.8388201 , 0.71751094,
          0.08318384, 0.7553143 , 0.8750135 , 0.8972577 , 0.85390973,
          0.87049603, 0.84071857, 0.8853136 ]],
 
        [[0.9277581 , 0.9876475 , 0.71884066, 0.36052382, 0.53324103,
          0.89681005, 0.92098916, 0.8180281 , 0.6177351 , 0.0311976 ,
          0.07055778, 0.83666444, 0.8608399 ],
         [0.8386477 , 0.58817774, 0.72051835, 0.7902795 , 0.7041355 ,
          0.2181147 , 0.76299024, 0.8507803 , 0.8824023 , 0.8892915 ,
          0.8559173 , 0.83882904, 0.9163557 ]],
 
        [[0.9318335 , 1.0054291 , 0.7037247 , 0.44776785, 0.55141157,
          0.8751741 , 0.8788193 , 0.7378067 , 0.6061791 , 0.06516132,
          0.145283  , 0.81688696, 0.88854957],
         [0.85625255, 0.86021763, 0.82891417, 0.5004723 , 0.8896506 ,
          0.15082283, 0.57127994, 0.86683005, 0.94244254, 0.8910252 ,
          0.9375356 , 0.92730576, 0.8518939 ]],
 
        [[0.9335175 , 0.98755246, 0.66180676, 0.5590857 , 0.5017098 ,
          0.89124495, 0.8839093 , 0.77439654, 0.5733776 , 0.0646795 ,
          0.12731166, 0.816599  , 0.90029544],
         [0.9238624 , 0.8279644 , 0.7274184 , 0.8509916 , 0.9116395 ,
          0.21640316, 0.4109717 , 0.92344654, 0.8912647 , 0.8676515 ,
          0.91081876, 0.9236755 , 0.9313457 ]],
 
        [[0.9660537 , 0.97779256, 0.6795893 , 0.5347014 , 0.49429995,
          0.89868015, 0.88998085, 0.82294524, 0.49898362, 0.14230077,
          0.13475017, 0.8461558 , 0.89860517],
         [0.8971772 , 0.85703963, 0.743163  , 0.87278444, 0.90552235,
          0.19766915, 0.33566353, 0.89383173, 0.87157995, 0.83140534,
          0.92693084, 0.9499294 , 0.85782766]],
 
        [[0.9214447 , 0.9804845 , 0.6575725 , 0.46105212, 0.5740245 ,
          0.88368326, 0.89460224, 0.81119704, 0.50101817, 0.24979575,
          0.16411652, 0.83694774, 0.9241573 ],
         [0.8916    , 0.87129986, 0.7239725 , 0.8828186 , 0.7020806 ,
          0.16116264, 0.36204475, 0.8973187 , 0.8997571 , 0.51675177,
          0.89034307, 0.98887885, 0.88438815]],
 
        [[0.8979453 , 0.97743154, 0.5481076 , 0.523632  , 0.570176  ,
          0.8288708 , 0.9113763 , 0.9194614 , 0.575856  , 0.07603623,
          0.21255928, 0.9018014 , 0.9266098 ],
         [0.91993105, 0.8616991 , 0.781426  , 0.7750215 , 0.85324234,
          0.14189687, 0.5463986 , 0.8761287 , 0.93542594, 0.50916994,
          0.87139845, 0.8620718 , 0.9169966 ]],
 
        [[0.90489644, 0.9633726 , 0.6176859 , 0.6120859 , 0.53412354,
          0.8082982 , 0.9141492 , 0.8100913 , 0.7064677 , 0.07797408,
          0.28660768, 0.9255538 , 0.9081669 ],
         [0.9197768 , 0.89081717, 0.7697851 , 0.850639  , 0.8240589 ,
          0.2276387 , 0.7375747 , 0.9573141 , 0.95667875, 0.7197965 ,
          0.8762751 , 0.8575352 , 0.8765895 ]],
 
        [[0.9522048 , 0.96551245, 0.72864616, 0.5890152 , 0.561211  ,
          0.7051566 , 0.9421855 , 0.39786857, 0.7715297 , 0.6171893 ,
          0.06328589, 1.0118455 , 0.886791  ],
         [0.9031525 , 0.9011465 , 0.7290425 , 0.84665924, 0.85558087,
          0.35440978, 0.8101312 , 0.931835  , 0.91998947, 0.9771716 ,
          0.88361436, 0.8611444 , 0.88294595]],
 
        [[0.93872   , 0.97103214, 0.63806784, 0.89063996, 0.68062663,
          0.9067393 , 0.89928836, 0.40190646, 0.75169766, 0.5388288 ,
          0.30325472, 0.86616135, 0.864786  ],
         [0.9355017 , 0.93469065, 0.73501164, 0.89369905, 0.794787  ,
          0.29464462, 0.91743165, 0.88107586, 0.89442694, 0.97276276,
          0.9208387 , 0.8436978 , 0.9492276 ]],
 
        [[0.91440874, 0.97273135, 0.64372706, 0.85304886, 0.6125536 ,
          0.89858156, 0.89086473, 0.33406225, 0.7624657 , 0.64882857,
          0.18051867, 0.93381244, 0.90368915],
         [0.9286875 , 0.93761605, 0.7948513 , 0.81816167, 0.7628807 ,
          0.30384466, 0.83553046, 0.83106405, 0.9189269 , 0.93762034,
          0.94770956, 0.8512343 , 0.9446315 ]],
 
        [[0.9450149 , 0.9582136 , 0.78555703, 0.7544447 , 0.58366936,
          0.85938   , 0.94498163, 0.6194322 , 0.7035529 , 0.22808443,
          0.24900974, 0.981288  , 0.92618316],
         [0.93841267, 0.9422818 , 0.80968696, 0.8445456 , 0.7991047 ,
          0.4916717 , 0.77814513, 0.6231525 , 0.93198806, 0.9570074 ,
          0.95540506, 0.9207018 , 0.8778759 ]],
 
        [[0.9381855 , 0.94920886, 0.77673894, 0.87591183, 0.3847992 ,
          0.88775337, 0.92982674, 0.8082221 , 0.6930795 , 0.16653292,
          0.26732486, 0.9830136 , 0.93462956],
         [0.9093149 , 0.96090955, 0.8409559 , 0.83797425, 0.8743328 ,
          0.82546026, 0.32881752, 0.5494046 , 0.9653242 , 0.9882784 ,
          0.85375595, 0.95603913, 0.9316707 ]],
 
        [[0.9048104 , 0.92460406, 0.75584614, 0.8082359 , 0.47512543,
          0.8684657 , 0.9260271 , 0.8822638 , 0.71269846, 0.1508674 ,
          0.22018598, 0.9016738 , 0.90536344],
         [0.918121  , 0.96696764, 0.78534484, 0.883681  , 0.798996  ,
          0.69723856, 0.5170047 , 0.8321578 , 0.9426196 , 0.9527973 ,
          0.91900206, 0.9706679 , 0.90770215]],
 
        [[0.9391487 , 0.9352003 , 0.85189575, 0.72796327, 0.6884535 ,
          0.8768972 , 0.9508924 , 0.6879568 , 0.71122557, 0.7012927 ,
          0.6031595 , 0.87616193, 0.91429555],
         [0.8932258 , 0.97501004, 0.78940654, 0.8651793 , 0.72244436,
          0.82689875, 0.4597148 , 0.93260366, 0.9202296 , 0.94214964,
          0.8834407 , 0.98030627, 0.8976605 ]]], dtype=float32),
 'instance_scores': array([[0.9953135 , 0.99476504],
        [0.99593395, 0.99526805],
        [0.9959078 , 0.9945123 ],
        [0.99573624, 0.993386  ],
        [0.99603134, 0.99172956],
        [0.99564207, 0.9916197 ],
        [0.9947187 , 0.9915406 ],
        [0.9940315 , 0.98916876],
        [0.99394447, 0.98962784],
        [0.9944642 , 0.9910501 ],
        [0.99155337, 0.9933716 ],
        [0.9916019 , 0.9933976 ],
        [0.9932473 , 0.9932013 ],
        [0.9920751 , 0.9946308 ],
        [0.991653  , 0.99465877],
        [0.99162734, 0.99486005]], dtype=float32),
 'centroids': array([[[271.8735 , 436.4811 ],
         [355.93707, 435.63477]],
 
        [[272.0215 , 436.42197],
         [356.2099 , 435.4682 ]],
 
        [[272.23578, 436.31976],
         [356.61108, 435.4756 ]],
 
        [[356.57007, 433.15857],
         [272.7147 , 435.9847 ]],
 
        [[356.93347, 432.73026],
         [272.7111 , 435.8055 ]],
 
        [[356.86227, 432.03918],
         [272.64484, 435.49347]],
 
        [[357.0275 , 431.29968],
         [272.49817, 435.54977]],
 
        [[359.29578, 431.42874],
         [272.1338 , 435.81354]],
 
        [[359.7555 , 429.4507 ],
         [272.2437 , 435.95605]],
 
        [[359.9807 , 428.4453 ],
         [272.04776, 436.2247 ]],
 
        [[360.3565 , 427.81192],
         [271.94632, 437.30673]],
 
        [[360.8997 , 427.5365 ],
         [272.4532 , 436.9694 ]],
 
        [[361.10843, 427.52646],
         [272.42938, 436.09125]],
 
        [[361.59042, 425.5916 ],
         [272.44873, 435.94284]],
 
        [[364.18994, 425.5058 ],
         [272.18735, 436.0978 ]],
 
        [[364.8356 , 425.49683],
         [272.1019 , 436.49136]]], dtype=float32),
 'centroid_vals': array([[0.94554764, 0.83948356],
        [0.9591119 , 0.8525362 ],
        [0.95961505, 0.86304706],
        [0.9252076 , 0.97578657],
        [0.974096  , 0.9668305 ],
        [0.9845507 , 0.9572475 ],
        [0.9105379 , 0.97522974],
        [0.880064  , 0.9943127 ],
        [0.911333  , 1.0001038 ],
        [0.9698766 , 0.9948527 ],
        [0.96454924, 0.9799493 ],
        [0.96142364, 1.0046191 ],
        [0.95354944, 0.9987816 ],
        [0.94746464, 0.98374254],
        [0.97818244, 0.98671097],
        [0.9833999 , 0.98425347]], dtype=float32),
 'n_valid': array([2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])}
for key, value in predictions.items():
    print(f"'{key}': {value.shape} ({value.dtype})")
'instance_peaks': (16, 2, 13, 2) (float32)
'instance_peak_vals': (16, 2, 13) (float32)
'instance_scores': (16, 2) (float32)
'centroids': (16, 2, 2) (float32)
'centroid_vals': (16, 2) (float32)
'n_valid': (16,) (int32)
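All arrays share a leading frame axis, and the instance axis is padded to a fixed size (2 here). The sketch below assumes that `n_valid[i]` counts the real, non-padding instances in frame `i`, as its name suggests. Under that assumption it slices the padding off and returns one small dict per frame. `unpack_predictions` is a hypothetical helper, not part of SLEAP, and the example feeds it random stand-in arrays with the shapes printed above.

```python
import numpy as np


def unpack_predictions(predictions):
    """Split batched, padded predictions into a per-frame list of instances.

    Hypothetical helper: assumes axis 0 is the frame index, axis 1 is the
    (padded) instance index, and n_valid[i] is the number of real instances
    in frame i.
    """
    per_frame = []
    for i, n in enumerate(predictions["n_valid"]):
        per_frame.append(
            {
                "peaks": predictions["instance_peaks"][i, :n],          # (n, nodes, 2)
                "peak_vals": predictions["instance_peak_vals"][i, :n],  # (n, nodes)
                "scores": predictions["instance_scores"][i, :n],        # (n,)
                "centroids": predictions["centroids"][i, :n],           # (n, 2)
            }
        )
    return per_frame


# Random stand-ins with the shapes printed above: 16 frames, up to 2 instances, 13 nodes.
rng = np.random.default_rng(0)
fake_predictions = {
    "instance_peaks": rng.random((16, 2, 13, 2), dtype=np.float32),
    "instance_peak_vals": rng.random((16, 2, 13), dtype=np.float32),
    "instance_scores": rng.random((16, 2), dtype=np.float32),
    "centroids": rng.random((16, 2, 2), dtype=np.float32),
    "n_valid": np.array([2] * 15 + [1], dtype=np.int32),  # last frame has one instance
}
per_frame_preds = unpack_predictions(fake_predictions)
print(len(per_frame_preds), per_frame_preds[0]["peaks"].shape, per_frame_preds[-1]["peaks"].shape)
```

With the real outputs, you would call `unpack_predictions(predictions)` on the dictionary returned above.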

4. Realtime performance

Now that we know how to run inference and what each output contains, let’s use the model in a simulated “realtime” application and measure how long each inference call takes.

First, we’ll create a class that simulates a camera frame-grabber API by serving a sequence of pre-loaded frames, looping back to the start when it runs out.

from time import perf_counter
import numpy as np


class SimulatedCamera:
    """Simulated camera class that serves frames from memory continuously.

    Attributes:
        frames: Numpy array with pre-loaded frames.
        frame_counter: Count of frames that have been grabbed.
    """

    frames: np.ndarray
    frame_counter: int

    def __init__(self, frames):
        self.frames = frames
        self.frame_counter = 0
    
    def grab_frame(self):
        """Return the next frame, wrapping around to the first frame after the last."""
        idx = self.frame_counter % len(self.frames)
        self.frame_counter += 1
        return self.frames[idx]
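The same round-robin behavior can be sketched with the standard library’s `itertools.cycle` instead of a modulo index. `CyclingCamera` is a hypothetical alternative for illustration, not a SLEAP class. The example uses three tiny frames whose pixel values equal their index, so you can see the wrap-around directly.

```python
import itertools

import numpy as np


class CyclingCamera:
    """Alternative sketch of SimulatedCamera built on itertools.cycle.

    Serves pre-loaded frames round-robin, restarting at the first frame after
    the last one, and counts how many frames have been grabbed.
    """

    def __init__(self, frames: np.ndarray):
        self._frames = itertools.cycle(frames)  # iterates over the first (frame) axis
        self.frame_counter = 0

    def grab_frame(self) -> np.ndarray:
        self.frame_counter += 1
        return next(self._frames)


# Three 2x2 single-channel "frames" whose pixel value equals the frame index.
demo_frames = np.stack([np.full((2, 2, 1), i, dtype=np.uint8) for i in range(3)])
demo_camera = CyclingCamera(demo_frames)
grabbed = [int(demo_camera.grab_frame()[0, 0, 0]) for _ in range(7)]
print(grabbed)  # [0, 1, 2, 0, 1, 2, 0]
```

`itertools.cycle` keeps references to the frame views it has already yielded, so it does not copy pixel data. The trade-off compared with the modulo version is that you can no longer jump to an arbitrary frame by index.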

Then, we’ll define a simple acquisition loop that repeatedly grabs a frame and times how long inference on that frame takes.

recording_duration = 100  # session length in frames

# Pre-load images onto "camera"
camera = SimulatedCamera(video[:512])

# Camera capture loop
inference_times = []
frames_recorded = 0
while frames_recorded < recording_duration:
    # Get the next frame.
    frame = camera.grab_frame()
    frames_recorded += 1

    # Get inference results for the frame and time how long it took.
    t0 = perf_counter()
    frame_predictions = predictor.inference_model.predict_on_batch(np.expand_dims(frame, axis=0))
    dt = perf_counter() - t0
    inference_times.append(dt)

# Convert to milliseconds.
inference_times = np.array(inference_times) * 1000

# Separate out first timing from the rest. The first inference call is much slower as it builds the compute graph.
first_inference_time, inference_times = inference_times[0], inference_times[1:]
print(f"First inference time: {first_inference_time:.1f} ms")
print(f"Inference times: {inference_times.mean():.1f} +- {inference_times.std():.1f} ms")
First inference time: 886.2 ms
Inference times: 63.1 +- 1.2 ms
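The loop above treats only the first call as warmup. A reusable version of the same pattern might look like the sketch below. `time_calls` is a hypothetical helper, not part of SLEAP. It makes a configurable number of untimed warmup calls, then times each subsequent call with `perf_counter`. Here it is demonstrated on a toy workload in place of the model.

```python
from time import perf_counter
from typing import Callable

import numpy as np


def time_calls(fn: Callable[[], object], n_warmup: int = 1, n_iters: int = 100) -> np.ndarray:
    """Call fn n_warmup times untimed, then n_iters times timed.

    Returns the per-call latencies in milliseconds.
    """
    for _ in range(n_warmup):
        fn()  # untimed; for a TensorFlow model this absorbs one-off graph-building cost
    latencies = np.empty(n_iters)
    for i in range(n_iters):
        t0 = perf_counter()
        fn()
        latencies[i] = (perf_counter() - t0) * 1000
    return latencies


# Toy workload standing in for model inference.
latencies_ms = time_calls(lambda: sum(range(10_000)), n_warmup=2, n_iters=50)
print(f"{latencies_ms.mean():.3f} +- {latencies_ms.std():.3f} ms over {len(latencies_ms)} calls")
```

To time the model itself, you could grab one frame first and pass something like `lambda: predictor.inference_model.predict_on_batch(np.expand_dims(frame, axis=0))`. Grabbing the frame outside the lambda keeps acquisition out of the timed region, as in the loop above.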

After the first call, inference latency drops substantially. Here is how it varies over time:

import matplotlib.pyplot as plt

plt.figure(figsize=(10, 4), dpi=120, facecolor="w")
plt.plot(inference_times, ".")
plt.xlabel("Time (frames)")
plt.ylabel("Inference latency (ms)")
plt.grid(True);
Figure: inference latency (ms) for each frame over time.
plt.figure(figsize=(6, 4), dpi=120, facecolor="w")
plt.hist(inference_times, bins=30, density=True)  # normalize counts so the y-axis is a PDF
plt.xlabel("Inference latency (ms)")
plt.ylabel("PDF");
Figure: histogram of inference latencies (ms).
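For closed-loop experiments, the mean latency hides the occasional slow frame, so it can help to look at the tail of this distribution as well. The sketch below assumes a hypothetical target camera rate `fps`, which gives a per-frame budget of `1000 / fps` ms. It reports percentiles, the fraction of frames that would miss that budget, and the highest frame rate the mean latency could sustain. For the 63.1 ms mean measured above, that rate is about 1000 / 63.1 ≈ 15.8 frames per second. `latency_report` is a hypothetical helper, and the example latencies are illustrative values, not measurements.

```python
import numpy as np


def latency_report(latencies_ms: np.ndarray, fps: float) -> dict:
    """Summarize per-frame latencies against a per-frame budget of 1000 / fps ms."""
    budget_ms = 1000.0 / fps
    p50, p95, p99 = np.percentile(latencies_ms, [50, 95, 99])
    return {
        "budget_ms": budget_ms,
        "p50_ms": float(p50),
        "p95_ms": float(p95),
        "p99_ms": float(p99),
        "max_ms": float(latencies_ms.max()),
        "frac_over_budget": float(np.mean(latencies_ms > budget_ms)),
        "max_sustainable_fps": 1000.0 / float(latencies_ms.mean()),
    }


# Illustrative latencies in ms (not measured data): 8 fast frames and 2 slow ones.
example_ms = np.array([10, 10, 10, 10, 10, 10, 10, 10, 40, 40], dtype=float)
report = latency_report(example_ms, fps=30)
for key, value in report.items():
    print(f"{key}: {value:.3f}")
```

Applied to this notebook, `latency_report(inference_times, fps=...)` works directly, because `inference_times` is already in milliseconds.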