
Interactive and realtime inference

For most workflows, the sleap-track CLI is probably the most convenient option. If you’re developing a custom application, however, you can use SLEAP’s inference API to run your trained models from your own scripts.

In this notebook we will explore how to predict poses from raw images in pure Python, and do some basic benchmarking on a simulated realtime predictor that could be used to enable closed-loop experiments.

1. Set up SLEAP

Run this cell first to install SLEAP. If you get a dependency error in subsequent cells, click Runtime → Restart runtime to reload the packages.

Don’t forget to set the accelerator to GPU under Runtime → Change runtime type.

# This should take care of all the dependencies on colab:
!pip uninstall -y opencv-python opencv-contrib-python && pip install sleap


# But to do it locally, we'd recommend the conda package (available on Windows + Linux):
# conda create -n sleap -c sleap -c conda-forge -c nvidia sleap

Import SLEAP to make sure it installed correctly and print out some information about the system:

import sleap
sleap.disable_preallocation()  # This initializes the GPU and prevents TensorFlow from filling the entire GPU memory
sleap.versions()
sleap.system_summary()
INFO:numexpr.utils:NumExpr defaulting to 2 threads.
SLEAP: 1.2.2
TensorFlow: 2.8.0
Numpy: 1.21.5
Python: 3.7.13
OS: Linux-5.4.144+-x86_64-with-Ubuntu-18.04-bionic
GPUs: 1/1 available
  Device: /physical_device:GPU:0
         Available: True
        Initalized: False
     Memory growth: True

2. Set up data

Before we start, let’s download a raw video and a set of trained top-down ID models that we’ll use to build our application around.

!curl -L --output video.mp4 https://storage.googleapis.com/sleap-data/reference/flies13/190719_090330_wt_18159206_rig1.2%4015000-17560.mp4
!curl -L --output centroid_model.zip https://storage.googleapis.com/sleap-data/reference/flies13/centroid.fast.210504_182918.centroid.n%3D1800.zip
!curl -L --output centered_instance_id_model.zip https://storage.googleapis.com/sleap-data/reference/flies13/td_id.fast.v2.210519_111253.multi_class_topdown.n%3D1800.zip
!ls -lah
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 81.3M  100 81.3M    0     0   119M      0 --:--:-- --:--:-- --:--:--  119M
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 6223k  100 6223k    0     0  23.2M      0 --:--:-- --:--:-- --:--:-- 23.2M
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 32.2M  100 32.2M    0     0  62.4M      0 --:--:-- --:--:-- --:--:-- 62.4M
total 120M
drwxr-xr-x 1 root root 4.0K Apr  3 23:33 .
drwxr-xr-x 1 root root 4.0K Apr  3 23:31 ..
-rw-r--r-- 1 root root  33M Apr  3 23:33 centered_instance_id_model.zip
-rw-r--r-- 1 root root 6.1M Apr  3 23:33 centroid_model.zip
drwxr-xr-x 4 root root 4.0K Mar 23 14:21 .config
drwxr-xr-x 1 root root 4.0K Mar 23 14:22 sample_data
-rw-r--r-- 1 root root  82M Apr  3 23:33 video.mp4

Note: These zip files just have the contents of standard SLEAP model folders that are generated during training.
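If you want to check what a model archive holds before loading it, Python’s standard `zipfile` module is enough. This is a small sketch, not part of SLEAP’s API: the helper names are hypothetical, and the file names in the usage comments (`training_config.json`, `best_model.h5`) are assumptions about what a typical SLEAP model folder contains, so compare them against the listing for your own models.

```python
import zipfile


def list_model_files(zip_file):
    """Return the sorted member names of a zip archive (path or file-like object)."""
    with zipfile.ZipFile(zip_file) as zf:
        return sorted(zf.namelist())


def has_files(zip_file, required):
    """Return True if every name in `required` matches the end of some member name."""
    names = list_model_files(zip_file)
    return all(any(name.endswith(req) for name in names) for req in required)


# Hypothetical usage with the archives downloaded above. The expected file
# names are an assumption about typical SLEAP model folders:
# print(list_model_files("centroid_model.zip"))
# print(has_files("centroid_model.zip", ["training_config.json", "best_model.h5"]))
```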

3. Interactive inference

SLEAP provides a high-level API for performing inference in the form of Predictor classes specific to each approach/model type.

To create one from a set of trained models, we can use the high-level sleap.load_model() function:

predictor = sleap.load_model(["centroid_model.zip", "centered_instance_id_model.zip"], batch_size=16)

This function handles all the logic of loading the trained models, reading the configurations used to train them, and constructing inference models that also include non-trainable operations like peak finding and instance grouping.
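Peak finding is one of those non-trainable steps: it turns each confidence map (one heatmap per body part) into candidate coordinates. The NumPy sketch below is a simplified stand-in written for illustration, not SLEAP’s implementation, which runs inside the TensorFlow inference model and may differ in its details. The `find_local_peaks` name and its 0.2 default threshold are hypothetical.

```python
import numpy as np


def find_local_peaks(cm, threshold=0.2):
    """Find local maxima in a 2-D confidence map (illustrative only).

    A pixel is a peak if it is at least `threshold` and no smaller than any
    pixel in its 3x3 neighbourhood. Returns an (n_peaks, 2) array of
    (row, col) indices and the matching confidence values.
    """
    cm = np.asarray(cm, dtype=float)
    h, w = cm.shape
    # Pad with -inf so border pixels are compared only against real neighbours.
    padded = np.pad(cm, 1, mode="constant", constant_values=-np.inf)
    # Stack the 9 shifted views that together form each pixel's 3x3 window.
    windows = np.stack([padded[dy:dy + h, dx:dx + w]
                        for dy in range(3) for dx in range(3)])
    is_peak = (cm >= windows.max(axis=0)) & (cm >= threshold)
    rows, cols = np.nonzero(is_peak)
    return np.stack([rows, cols], axis=1), cm[rows, cols]


# Toy confidence map with two blobs for a single body part.
cm = np.zeros((5, 5))
cm[1, 3] = 0.9
cm[4, 0] = 0.5
peaks, vals = find_local_peaks(cm)
print(peaks.tolist(), vals.tolist())  # [[1, 3], [4, 0]] [0.9, 0.5]
```

Note that flat plateaus would report every tied pixel as a peak; a production implementation needs a tie-breaking or refinement step.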

Next, we’ll load a video that we want to use for inference. SLEAP Video objects don’t load the whole video into memory; they just provide a common numpy-like interface for reading from different file formats:

video = sleap.load_video("video.mp4")
video.shape, video.dtype
((2560, 1024, 1024, 1), dtype('uint8'))
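To illustrate what a lazy, numpy-like interface means, here is a toy class that reports `shape` and `dtype` up front but only produces frames when they are indexed. It is a hypothetical sketch (the `LazyFrames` class and its `read_frame` callback are our own names), not how `sleap.Video` is implemented; a real backend would decode frames from the file instead of calling a Python function.

```python
import numpy as np


class LazyFrames:
    """Toy array-like object that reads frames on demand (illustration only)."""

    def __init__(self, n_frames, frame_shape, read_frame, dtype=np.uint8):
        self.shape = (n_frames, *frame_shape)
        self.dtype = np.dtype(dtype)
        self._read_frame = read_frame  # read_frame(i) -> array with frame_shape
        self.reads = 0  # number of frames actually produced so far

    def __len__(self):
        return self.shape[0]

    def __getitem__(self, idx):
        # Turn an int (negative allowed) or a slice into concrete frame indices;
        # indexing a range raises IndexError for out-of-bounds ints.
        frame_ids = range(len(self))[idx]
        if isinstance(frame_ids, int):
            frame_ids = [frame_ids]
        frames = []
        for i in frame_ids:
            self.reads += 1
            frames.append(np.asarray(self._read_frame(i), dtype=self.dtype))
        if not frames:
            return np.empty((0, *self.shape[1:]), dtype=self.dtype)
        # Always return a batch: (n_selected, *frame_shape).
        return np.stack(frames)


# A fake 2560-frame "video"; nothing is read until we index it.
fake = LazyFrames(2560, (1024, 1024, 1), lambda i: np.full((1024, 1024, 1), i % 256))
print(fake.shape, fake.dtype, fake.reads)  # (2560, 1024, 1024, 1) uint8 0
batch = fake[:4]
print(batch.shape, fake.reads)  # (4, 1024, 1024, 1) 4
```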

Our predictor is pretty flexible: it accepts a variety of input formats, and in every case it returns a Labels object that contains all of our predictions:

# Load frames to a numpy array.
imgs = video[:100]
print(f"imgs.shape: {imgs.shape}")

# Predict on numpy array.
predictions = predictor.predict(imgs)
predictions
imgs.shape: (100, 1024, 1024, 1)


Labels(labeled_frames=100, videos=1, skeletons=1, tracks=2)
# Predict on the entire video with parallelizable loading/preprocessing:
predictions = predictor.predict(video)
predictions


Labels(labeled_frames=2560, videos=1, skeletons=1, tracks=2)
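Since the predictor was created with `batch_size=16`, it helps to picture the work as a sequence of fixed-size chunks of frames. The helper below is a hypothetical sketch of that bookkeeping only; SLEAP’s own pipeline handles batching for you and, as the comment in the cell above notes, can parallelize loading and preprocessing.

```python
def iter_batches(n_frames, batch_size=16):
    """Yield (start, stop) frame-index pairs that cover range(n_frames) in order."""
    for start in range(0, n_frames, batch_size):
        # The final batch is shorter when n_frames is not a multiple of batch_size.
        yield start, min(start + batch_size, n_frames)


batches = list(iter_batches(2560, batch_size=16))
print(len(batches), batches[0], batches[-1])  # 160 (0, 16) (2544, 2560)

# With the objects from this notebook, each chunk could be predicted separately:
# for start, stop in iter_batches(video.shape[0], 16):
#     labels_chunk = predictor.predict(video[start:stop])
```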

We can then inspect the results of our predictor:

# Visualize a frame.
predictions[100].plot(scale=0.25)
(Figure: frame 100 with its predicted poses overlaid.)
# Inspect the contents of a single frame.
labeled_frame = predictions[100]
labeled_frame.instances
[PredictedInstance(video=Video(filename=video.mp4, shape=(2560, 1024, 1024, 1), backend=MediaVideo), frame_idx=100, points=[head: (212.5, 427.0, 0.94), thorax: (252.0, 433.1, 0.95), abdomen: (288.6, 439.3, 0.68), wingL: (304.5, 443.3, 0.88), wingR: (306.2, 435.8, 0.68), forelegL4: (216.2, 445.5, 0.88), forelegR4: (216.1, 410.0, 0.90), midlegL4: (244.4, 471.3, 0.90), midlegR4: (256.6, 408.9, 0.86), hindlegL4: (275.0, 459.2, 0.89), hindlegR4: (292.3, 412.0, 0.81), eyeL: (220.0, 438.0, 0.84), eyeR: (223.8, 417.5, 0.91)], score=0.99, track=Track(spawned_on=0, name='female'), tracking_score=0.00),
 PredictedInstance(video=Video(filename=video.mp4, shape=(2560, 1024, 1024, 1), backend=MediaVideo), frame_idx=100, points=[head: (313.7, 432.6, 0.87), thorax: (348.9, 427.9, 1.00), abdomen: (378.9, 425.8, 0.83), wingL: (397.0, 428.7, 0.89), wingR: (394.9, 420.7, 0.74), forelegL4: (307.4, 446.4, 0.88), forelegR4: (306.5, 422.5, 0.89), midlegL4: (341.6, 474.2, 0.97), midlegR4: (332.6, 386.3, 0.97), hindlegL4: (378.9, 458.8, 0.92), hindlegR4: (387.7, 394.8, 0.88), eyeL: (323.7, 442.1, 0.96), eyeR: (320.7, 420.8, 0.88)], score=0.99, track=Track(spawned_on=0, name='male'), tracking_score=0.00)]
# Convert an instance to a numpy array:
labeled_frame[0].numpy()
rec.array([[212.51400757, 426.97024536],
           [251.97747803, 433.08648682],
           [288.64355469, 439.3086853 ],
           [304.53396606, 443.33477783],
           [306.20336914, 435.77227783],
           [216.24688721, 445.4755249 ],
           [216.14550781, 409.98342896],
           [244.39497375, 471.31561279],
           [256.61740112, 408.89056396],
           [274.97470093, 459.1831665 ],
           [292.2600708 , 411.95904541],
           [219.98565674, 437.97906494],
           [223.75566101, 417.5496521 ]],
          dtype=float64)
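Because the result is a plain `(n_nodes, 2)` array of (x, y) pixel coordinates, standard NumPy operations apply directly. The sketch below copies the coordinates printed above (rows follow the node order shown in the `PredictedInstance` output: head, thorax, abdomen, ..., eyeR) and computes a head-to-thorax distance and a centroid. In a live session you would use `labeled_frame[0].numpy()` instead of the literal array; the `node_distance` helper is hypothetical, and using `np.nanmean` assumes that undetected nodes are stored as NaN.

```python
import numpy as np

# Node order as listed in the PredictedInstance output above.
NODE_NAMES = ["head", "thorax", "abdomen", "wingL", "wingR",
              "forelegL4", "forelegR4", "midlegL4", "midlegR4",
              "hindlegL4", "hindlegR4", "eyeL", "eyeR"]

# (x, y) coordinates copied from labeled_frame[0].numpy() above.
points = np.array([
    [212.51400757, 426.97024536],
    [251.97747803, 433.08648682],
    [288.64355469, 439.3086853],
    [304.53396606, 443.33477783],
    [306.20336914, 435.77227783],
    [216.24688721, 445.4755249],
    [216.14550781, 409.98342896],
    [244.39497375, 471.31561279],
    [256.61740112, 408.89056396],
    [274.97470093, 459.1831665],
    [292.2600708, 411.95904541],
    [219.98565674, 437.97906494],
    [223.75566101, 417.5496521],
])


def node_distance(pts, a, b, names=NODE_NAMES):
    """Euclidean distance in pixels between two named nodes (hypothetical helper)."""
    return float(np.linalg.norm(pts[names.index(a)] - pts[names.index(b)]))


# Centroid over detected nodes; NaN rows (if any) are ignored by nanmean.
centroid = np.nanmean(points, axis=0)
print(round(node_distance(points, "head", "thorax"), 1))  # 39.9
print(centroid)
```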

What if we don’t want or need the inference results wrapped in the SLEAP structures?

By using the low-level inference model, we can go directly from images to numpy arrays of our results:

imgs = video[:16]  # batch of 16 images

predictions = predictor.inference_model.predict(imgs, numpy=True)
predictions
{'centroid_vals': array([[0.9455479 , 0.8394836 ],
        [0.95911187, 0.85253626],
        [0.9596152 , 0.8630471 ],
        [0.9252076 , 0.9757867 ],
        [0.9740962 , 0.9668303 ],
        [0.98455054, 0.95724756],
        [0.91053814, 0.9752301 ],
        [0.88006395, 0.99431276],
        [0.9113332 , 1.0001038 ],
        [0.9698767 , 0.9948529 ],
        [0.96454954, 0.9799493 ],
        [0.9614236 , 1.0046192 ],
        [0.9535493 , 0.99878174],
        [0.9474647 , 0.98374265],
        [0.9781825 , 0.9867112 ],
        [0.98339975, 0.9842536 ]], dtype=float32),
 'centroids': array([[[271.8735 , 436.4811 ],
         [355.93707, 435.63477]],
 
        [[272.0215 , 436.42197],
         [356.2099 , 435.4682 ]],
 
        [[272.23578, 436.31976],
         [356.61108, 435.4756 ]],
 
        [[356.57007, 433.15857],
         [272.7147 , 435.9847 ]],
 
        [[356.93347, 432.73026],
         [272.7111 , 435.8055 ]],
 
        [[356.86227, 432.03918],
         [272.64484, 435.49347]],
 
        [[357.0275 , 431.29968],
         [272.49817, 435.54977]],
 
        [[359.29578, 431.42874],
         [272.1338 , 435.81354]],
 
        [[359.7555 , 429.4507 ],
         [272.2437 , 435.95605]],
 
        [[359.9807 , 428.4453 ],
         [272.04776, 436.2247 ]],
 
        [[360.3565 , 427.81192],
         [271.94632, 437.30673]],
 
        [[360.8997 , 427.5365 ],
         [272.4532 , 436.9694 ]],
 
        [[361.10843, 427.52646],
         [272.42938, 436.09125]],
 
        [[361.59042, 425.5916 ],
         [272.44873, 435.94284]],
 
        [[364.18994, 425.5058 ],
         [272.18735, 436.0978 ]],
 
        [[364.8356 , 425.49683],
         [272.1019 , 436.49136]]], dtype=float32),
 'instance_peak_vals': array([[[0.9913698 , 0.9798432 , 0.755395  , 0.45440078, 0.49718782,
          0.82649314, 0.8982548 , 0.7941463 , 0.8178157 , 0.05604962,
          0.06407703, 0.8860661 , 0.9635323 ],
         [0.9033977 , 0.25969282, 0.63431203, 0.83960074, 0.76130724,
          0.04938019, 0.8405748 , 0.8820077 , 0.8816873 , 0.8243383 ,
          0.33521542, 0.843406  , 0.8127705 ]],
 
        [[0.9598928 , 0.9734157 , 0.67664635, 0.35409918, 0.49767363,
          0.8832786 , 0.9271228 , 0.79897636, 0.7574272 , 0.04437801,
          0.06204455, 0.86091673, 0.89724076],
         [0.88144   , 0.43337217, 0.6627725 , 0.83882016, 0.7175109 ,
          0.08318386, 0.7553143 , 0.8750135 , 0.89725804, 0.8539097 ,
          0.87049586, 0.84071857, 0.8853135 ]],
 
        [[0.9277582 , 0.9876474 , 0.71884066, 0.36052445, 0.5332413 ,
          0.8968105 , 0.9209892 , 0.8180278 , 0.6177353 , 0.03119754,
          0.07055765, 0.83666456, 0.86083984],
         [0.8386838 , 0.5882865 , 0.7205018 , 0.79034203, 0.70366687,
          0.21814364, 0.7629925 , 0.85078365, 0.88240033, 0.889361  ,
          0.855937  , 0.83885545, 0.9163793 ]],
 
        [[0.9318245 , 1.005442  , 0.70377296, 0.44777974, 0.5514284 ,
          0.8751964 , 0.8788199 , 0.7378154 , 0.60576206, 0.06517099,
          0.145257  , 0.81688404, 0.88855964],
         [0.8562528 , 0.86021775, 0.82891434, 0.5004723 , 0.8896506 ,
          0.1508227 , 0.57128006, 0.8668301 , 0.94244254, 0.8910252 ,
          0.9375358 , 0.92730594, 0.8518941 ]],
 
        [[0.93351734, 0.98755234, 0.6618066 , 0.55908614, 0.5017102 ,
          0.89124554, 0.8839096 , 0.77439624, 0.5733776 , 0.06467963,
          0.12731154, 0.81659895, 0.9002954 ],
         [0.9238624 , 0.8279646 , 0.7274185 , 0.8509916 , 0.91163963,
          0.21640284, 0.41097188, 0.9234465 , 0.8912649 , 0.8676514 ,
          0.91081864, 0.9236754 , 0.9313458 ]],
 
        [[0.96605366, 0.9777925 , 0.67958933, 0.5347009 , 0.49430045,
          0.89868015, 0.88998073, 0.82294536, 0.49898368, 0.1423007 ,
          0.1347502 , 0.846156  , 0.8986051 ],
         [0.8971774 , 0.85703975, 0.74316317, 0.87278455, 0.9055221 ,
          0.19766904, 0.3356636 , 0.89383155, 0.8715803 , 0.8314053 ,
          0.92693067, 0.94992954, 0.8578277 ]],
 
        [[0.92144465, 0.98048437, 0.65757245, 0.4610521 , 0.57402426,
          0.88368344, 0.89460254, 0.8111973 , 0.50101817, 0.24979569,
          0.16411611, 0.83694774, 0.9241577 ],
         [0.89160013, 0.8712998 , 0.72397256, 0.88281846, 0.7020805 ,
          0.16116247, 0.36204454, 0.8973186 , 0.8997571 , 0.5167517 ,
          0.89034295, 0.98887867, 0.8843883 ]],
 
        [[0.89794546, 0.97743154, 0.5481075 , 0.52363163, 0.570176  ,
          0.8288712 , 0.9113766 , 0.9194614 , 0.57585603, 0.07603604,
          0.21255916, 0.90180147, 0.9266095 ],
         [0.9199309 , 0.8616993 , 0.78142613, 0.77502143, 0.8532426 ,
          0.14189675, 0.5463987 , 0.8761284 , 0.9354262 , 0.5091697 ,
          0.8713986 , 0.862072  , 0.91699666]],
 
        [[0.9048965 , 0.96337247, 0.6176863 , 0.6120858 , 0.53412384,
          0.8082984 , 0.914149  , 0.8100912 , 0.7064674 , 0.07797385,
          0.28660813, 0.9255539 , 0.9081667 ],
         [0.9197771 , 0.89081717, 0.769785  , 0.85063875, 0.82405925,
          0.22763878, 0.7375746 , 0.95731395, 0.95667887, 0.7197969 ,
          0.87627506, 0.8575353 , 0.8765893 ]],
 
        [[0.9522317 , 0.96551776, 0.728644  , 0.58902043, 0.56121   ,
          0.7050669 , 0.94214785, 0.39777142, 0.7715537 , 0.617287  ,
          0.06328648, 1.0118883 , 0.8866795 ],
         [0.9031525 , 0.90114677, 0.7290425 , 0.84665924, 0.855581  ,
          0.35440993, 0.8101314 , 0.93183535, 0.91998935, 0.9771715 ,
          0.8836143 , 0.86114466, 0.88294595]],
 
        [[0.9387202 , 0.97103214, 0.6380678 , 0.89064   , 0.6806271 ,
          0.9067394 , 0.89928854, 0.40190598, 0.7516978 , 0.5388293 ,
          0.30325472, 0.8661613 , 0.8647857 ],
         [0.9355016 , 0.9346907 , 0.7350116 , 0.8936991 , 0.7947871 ,
          0.29464447, 0.9174315 , 0.8810758 , 0.89442706, 0.97276264,
          0.92083865, 0.84369785, 0.94922733]],
 
        [[0.914409  , 0.9727311 , 0.64372706, 0.85304916, 0.6125537 ,
          0.89858156, 0.89086455, 0.33406293, 0.76246554, 0.64882785,
          0.18051788, 0.9338125 , 0.903689  ],
         [0.9286875 , 0.93761635, 0.79485124, 0.8181616 , 0.76288086,
          0.3038448 , 0.8355305 , 0.83106405, 0.91892713, 0.9376198 ,
          0.94770956, 0.85123426, 0.9446316 ]],
 
        [[0.94501513, 0.95821375, 0.7855571 , 0.7544449 , 0.58367   ,
          0.8593804 , 0.9449818 , 0.6194321 , 0.7035531 , 0.22808488,
          0.24900919, 0.981288  , 0.92618316],
         [0.93841255, 0.9422814 , 0.80968684, 0.8445455 , 0.7991051 ,
          0.49167132, 0.77814525, 0.6231524 , 0.9319882 , 0.9570072 ,
          0.95540494, 0.9207019 , 0.8778761 ]],
 
        [[0.93817955, 0.9492211 , 0.7767393 , 0.8758958 , 0.38491583,
          0.88775396, 0.9298349 , 0.8082794 , 0.69305503, 0.1668036 ,
          0.26728866, 0.9830228 , 0.9346242 ],
         [0.909315  , 0.9609095 , 0.840956  , 0.83797425, 0.8743328 ,
          0.82546026, 0.32881746, 0.54940474, 0.96532434, 0.98827827,
          0.85375595, 0.95603913, 0.93167067]],
 
        [[0.9048101 , 0.9246041 , 0.7558464 , 0.80823594, 0.47512585,
          0.86846614, 0.9260269 , 0.8822637 , 0.7126984 , 0.15086724,
          0.22018576, 0.9016736 , 0.90536344],
         [0.91812086, 0.9669677 , 0.78534484, 0.88368094, 0.7989964 ,
          0.6972392 , 0.51700455, 0.8321577 , 0.9426196 , 0.9527976 ,
          0.9190021 , 0.9706677 , 0.9077022 ]],
 
        [[0.9391487 , 0.93520033, 0.85189587, 0.72796357, 0.6884538 ,
          0.8768974 , 0.9508925 , 0.6879569 , 0.7112255 , 0.70129263,
          0.6031595 , 0.8761619 , 0.9142955 ],
         [0.8932256 , 0.9750102 , 0.7894063 , 0.8651795 , 0.7224442 ,
          0.8268989 , 0.45971498, 0.93260354, 0.9202294 , 0.94214976,
          0.88344055, 0.9803063 , 0.8976606 ]]], dtype=float32),
 'instance_peaks': array([[[[234.2223 , 430.62558],
          [271.50427, 436.13205],
          [309.87225, 436.65012],
          [324.12576, 438.39148],
          [320.34717, 435.95013],
          [246.42339, 450.67798],
          [242.37634, 413.81458],
          [285.56247, 460.2276 ],
          [273.45126, 406.51892],
          [      nan,       nan],
          [      nan,       nan],
          [241.9709 , 442.32263],
          [245.46785, 421.90225]],
 
         [[319.80017, 435.48407],
          [351.93695, 434.0301 ],
          [369.43228, 431.78564],
          [393.89014, 481.0584 ],
          [398.4241 , 429.79565],
          [      nan,       nan],
          [305.42896, 419.3896 ],
          [325.67926, 475.0098 ],
          [331.97974, 384.30814],
          [363.66406, 473.9354 ],
          [377.3852 , 398.13065],
          [328.40244, 445.51434],
          [328.1667 , 423.94733]]],
 
 
        [[[234.36911, 430.38037],
          [271.65576, 436.0479 ],
          [311.67505, 437.0108 ],
          [324.4831 , 438.1426 ],
          [322.2054 , 435.06854],
          [246.43256, 450.61487],
          [242.39862, 413.8269 ],
          [285.56503, 460.0099 ],
          [273.78204, 406.4644 ],
          [      nan,       nan],
          [      nan,       nan],
          [242.11815, 442.0634 ],
          [245.55441, 421.72803]],
 
         [[320.03793, 435.2389 ],
          [353.87274, 434.77695],
          [370.67218, 432.9711 ],
          [393.91922, 481.09735],
          [399.77133, 431.25983],
          [      nan,       nan],
          [308.409  , 421.48993],
          [325.82016, 474.90115],
          [331.94632, 385.0652 ],
          [363.65408, 473.70728],
          [384.68225, 399.30194],
          [328.72806, 445.15356],
          [328.48532, 423.624  ]]],
 
 
        [[[234.5559 , 430.06238],
          [271.8775 , 435.9898 ],
          [312.13086, 438.16318],
          [324.77222, 437.65994],
          [322.40115, 434.7244 ],
          [246.44681, 450.51874],
          [242.45566, 413.7617 ],
          [285.8958 , 460.56442],
          [273.66855, 406.2377 ],
          [      nan,       nan],
          [      nan,       nan],
          [242.26588, 441.80545],
          [245.77664, 420.7662 ]],
 
         [[320.46982, 435.25452],
          [354.89542, 434.93198],
          [372.2558 , 433.46106],
          [394.40723, 479.57962],
          [400.3011 , 431.9626 ],
          [306.98218, 449.3156 ],
          [308.8817 , 421.52148],
          [325.98843, 474.91672],
          [332.17917, 385.04684],
          [363.03186, 473.50638],
          [391.05493, 396.85666],
          [329.1689 , 445.0495 ],
          [328.89993, 423.52527]]],
 
 
        [[[234.65546, 429.69464],
          [272.38306, 435.6884 ],
          [311.04346, 437.86926],
          [324.80878, 437.3788 ],
          [322.84747, 433.93933],
          [246.71854, 451.2873 ],
          [242.57391, 413.58414],
          [286.16397, 461.83658],
          [272.8733 , 406.21573],
          [      nan,       nan],
          [      nan,       nan],
          [242.4386 , 441.46246],
          [245.25829, 420.48416]],
 
         [[320.7713 , 433.55927],
          [356.25912, 432.81424],
          [372.98462, 432.9266 ],
          [402.0365 , 465.378  ],
          [400.8439 , 431.7685 ],
          [      nan,       nan],
          [310.4258 , 422.7895 ],
          [325.16397, 474.86514],
          [332.16724, 384.9967 ],
          [362.87766, 473.12836],
          [390.43555, 393.69998],
          [330.20596, 443.4066 ],
          [329.0497 , 421.68896]]],
 
 
        [[[234.51591, 429.5735 ],
          [272.3791 , 435.4755 ],
          [310.74457, 436.20264],
          [325.24997, 437.69904],
          [323.1339 , 433.8241 ],
          [246.75269, 451.22192],
          [242.58466, 413.53275],
          [286.0668 , 461.6229 ],
          [272.87787, 406.2068 ],
          [      nan,       nan],
          [      nan,       nan],
          [242.3858 , 441.31342],
          [245.15892, 420.27942]],
 
         [[320.91632, 432.5178 ],
          [356.588  , 432.3604 ],
          [374.51236, 432.42508],
          [405.0515 , 450.2759 ],
          [401.2467 , 432.2713 ],
          [314.74677, 442.78735],
          [312.76758, 422.29553],
          [325.20752, 474.6215 ],
          [332.2873 , 384.86606],
          [362.8446 , 472.95822],
          [388.92188, 394.203  ],
          [329.54233, 442.43842],
          [329.1192 , 420.79416]]],
 
 
        [[[234.54964, 429.56854],
          [272.30457, 435.13345],
          [309.08594, 434.02444],
          [325.13245, 437.11148],
          [324.71674, 431.81714],
          [246.79828, 450.9629 ],
          [242.6766 , 413.53745],
          [286.09372, 461.14362],
          [272.87155, 406.23718],
          [      nan,       nan],
          [      nan,       nan],
          [242.4111 , 441.2425 ],
          [245.13495, 420.83694]],
 
         [[320.7404 , 430.43884],
          [356.4725 , 431.68488],
          [375.05853, 431.87177],
          [404.3775 , 451.92688],
          [401.39508, 431.9776 ],
          [      nan,       nan],
          [312.77365, 421.6409 ],
          [325.17343, 474.26575],
          [331.44904, 384.56747],
          [363.05463, 472.54587],
          [388.72284, 394.13287],
          [330.25458, 440.28958],
          [328.9332 , 419.74493]]],
 
 
        [[[234.15704, 429.3947 ],
          [272.1558 , 435.1859 ],
          [310.46423, 435.5753 ],
          [324.42407, 437.18857],
          [322.80786, 433.41486],
          [246.72241, 450.9671 ],
          [242.64005, 413.65726],
          [285.9537 , 461.01648],
          [272.73447, 406.31635],
          [305.89285, 449.9849 ],
          [      nan,       nan],
          [241.21112, 441.0713 ],
          [244.77327, 419.9886 ]],
 
         [[321.03162, 429.8643 ],
          [356.5856 , 430.9501 ],
          [377.2166 , 431.29108],
          [405.09204, 451.2633 ],
          [402.97113, 431.12497],
          [      nan,       nan],
          [312.74753, 421.16742],
          [325.3774 , 474.7351 ],
          [331.5342 , 384.97403],
          [378.56894, 469.3632 ],
          [388.81372, 393.89886],
          [330.641  , 439.67197],
          [329.04425, 418.99023]]],
 
 
        [[[232.79128, 428.2476 ],
          [271.7884 , 435.45706],
          [310.26096, 437.58252],
          [322.67697, 439.28253],
          [322.35138, 435.4916 ],
          [246.49533, 451.1817 ],
          [242.4297 , 413.56104],
          [286.01126, 461.4526 ],
          [272.72516, 406.3869 ],
          [      nan,       nan],
          [284.4912 , 408.79095],
          [240.58961, 440.1936 ],
          [244.4464 , 420.00543]],
 
         [[322.69318, 430.96204],
          [358.8828 , 430.98035],
          [379.26816, 431.0259 ],
          [405.7312 , 449.5473 ],
          [405.13306, 431.02057],
          [      nan,       nan],
          [309.64542, 421.59024],
          [325.46237, 474.79062],
          [331.63318, 384.9981 ],
          [390.9735 , 466.93915],
          [388.87518, 393.89645],
          [331.4858 , 440.98822],
          [330.72357, 419.30713]]],
 
 
        [[[232.9138 , 428.26993],
          [271.89908, 435.6341 ],
          [310.36536, 437.9696 ],
          [322.63763, 439.87323],
          [322.4065 , 435.7932 ],
          [246.48575, 451.27322],
          [242.48721, 413.6446 ],
          [285.74454, 460.08987],
          [272.75647, 406.338  ],
          [      nan,       nan],
          [320.82465, 422.17297],
          [240.64159, 440.22705],
          [244.54178, 420.04788]],
 
         [[322.2764 , 429.7331 ],
          [359.43756, 429.0462 ],
          [379.8793 , 429.56253],
          [407.32346, 448.95087],
          [405.74594, 429.27792],
          [315.46356, 441.38046],
          [309.48642, 421.8147 ],
          [325.63013, 474.81934],
          [331.73767, 385.03244],
          [399.19778, 461.1395 ],
          [388.32227, 394.00305],
          [331.94138, 439.76627],
          [330.20728, 418.03998]]],
 
 
        [[[232.59995, 427.9426 ],
          [271.68756, 435.92496],
          [309.74353, 438.45377],
          [322.3493 , 441.9495 ],
          [322.39355, 436.099  ],
          [246.09337, 450.45764],
          [242.33101, 413.80396],
          [284.40045, 460.55066],
          [273.6091 , 406.4331 ],
          [286.35364, 459.99496],
          [      nan,       nan],
          [240.04811, 440.10532],
          [244.36139, 419.95685]],
 
         [[322.50397, 428.86414],
          [359.65952, 428.01282],
          [381.80063, 428.2879 ],
          [407.9239 , 446.02728],
          [406.27682, 428.24774],
          [317.4234 , 444.4193 ],
          [308.38232, 422.35754],
          [325.6553 , 474.45853],
          [331.8156 , 384.7812 ],
          [399.62988, 456.58368],
          [388.52002, 394.27118],
          [332.3299 , 438.7801 ],
          [330.43085, 417.03174]]],
 
 
        [[[232.25208, 427.7414 ],
          [271.57523, 436.99503],
          [308.347  , 440.97897],
          [321.64392, 445.52814],
          [322.16394, 439.4637 ],
          [229.9819 , 444.81857],
          [242.35481, 413.535  ],
          [284.59384, 461.70065],
          [273.50806, 406.95544],
          [286.72635, 460.96436],
          [314.3465 , 428.5469 ],
          [239.56883, 440.8733 ],
          [244.04318, 420.60315]],
 
         [[324.36966, 429.4342 ],
          [360.08127, 427.41803],
          [384.283  , 427.4751 ],
          [408.8785 , 443.59448],
          [408.36377, 425.55792],
          [316.73703, 445.6411 ],
          [308.78436, 421.899  ],
          [325.92154, 474.19464],
          [331.91168, 385.32022],
          [399.73245, 457.32578],
          [388.57062, 394.18298],
          [334.3139 , 438.40005],
          [331.89133, 417.64728]]],
 
 
        [[[232.70679, 428.36255],
          [272.08994, 436.64023],
          [310.14267, 440.50543],
          [322.68262, 444.5147 ],
          [322.82147, 438.87054],
          [224.32256, 448.4768 ],
          [242.57848, 413.34476],
          [284.7278 , 461.26282],
          [273.8772 , 406.77335],
          [286.55972, 460.77054],
          [      nan,       nan],
          [239.95602, 440.7761 ],
          [244.31602, 420.40244]],
 
         [[325.2043 , 429.92737],
          [360.62262, 427.1631 ],
          [386.82898, 425.76257],
          [410.35846, 440.0152 ],
          [408.79132, 423.68118],
          [318.88504, 445.35867],
          [308.8374 , 421.72562],
          [326.25244, 474.88055],
          [332.2403 , 385.27567],
          [399.44467, 457.21188],
          [388.84778, 394.12372],
          [335.362  , 439.4058 ],
          [332.62274, 417.9344 ]]],
 
 
        [[[232.79385, 428.12885],
          [272.08496, 435.77728],
          [310.3099 , 437.81348],
          [324.31982, 440.3584 ],
          [324.60254, 434.39813],
          [222.91586, 451.43195],
          [242.6026 , 413.74078],
          [284.7489 , 460.09384],
          [273.8778 , 406.4865 ],
          [287.56982, 459.68353],
          [322.8655 , 421.80096],
          [240.19046, 440.23196],
          [244.46782, 419.99805]],
 
         [[325.60196, 431.36603],
          [360.8261 , 427.19696],
          [387.17218, 425.47867],
          [410.81366, 438.09143],
          [408.99658, 422.15668],
          [318.84363, 445.00012],
          [311.57254, 423.2615 ],
          [333.60617, 467.9318 ],
          [331.6039 , 385.32465],
          [399.44635, 457.16357],
          [388.92133, 394.11078],
          [336.72043, 439.8229 ],
          [332.6642 , 419.31372]]],
 
 
        [[[232.83435, 428.2637 ],
          [272.11572, 435.61078],
          [312.17938, 439.66312],
          [322.83755, 442.15845],
          [324.40564, 435.64343],
          [225.87045, 451.41144],
          [242.64131, 413.59937],
          [285.06653, 460.35504],
          [273.84183, 406.37183],
          [      nan,       nan],
          [322.4148 , 422.6127 ],
          [240.42722, 440.2208 ],
          [244.4097 , 419.95215]],
 
         [[327.3499 , 431.52005],
          [361.313  , 425.36264],
          [389.47607, 423.60114],
          [411.6601 , 435.50894],
          [409.51843, 419.6943 ],
          [319.90283, 445.82428],
          [313.70898, 423.5036 ],
          [345.66882, 473.1757 ],
          [331.79486, 385.46274],
          [399.46533, 457.10553],
          [388.24854, 394.009  ],
          [337.8076 , 440.06436],
          [333.29004, 419.49707]]],
 
 
        [[[232.41422, 429.4673 ],
          [271.8141 , 435.7682 ],
          [310.01324, 439.98956],
          [322.19714, 443.71683],
          [324.71207, 434.39133],
          [224.85786, 451.4593 ],
          [242.5914 , 413.65204],
          [285.67142, 461.77646],
          [273.7307 , 406.5118 ],
          [      nan,       nan],
          [322.71594, 420.21155],
          [239.99216, 440.57278],
          [243.82819, 420.339  ]],
 
         [[328.47983, 431.74188],
          [363.9317 , 425.2397 ],
          [390.49423, 423.05255],
          [413.68115, 433.6671 ],
          [410.5454 , 419.09042],
          [320.30078, 446.80396],
          [313.82977, 421.7456 ],
          [356.64886, 473.89554],
          [331.84995, 385.1559 ],
          [399.78146, 457.11206],
          [388.5744 , 393.94125],
          [339.9305 , 440.99496],
          [334.5468 , 419.42017]]],
 
 
        [[[232.05379, 430.01157],
          [271.71146, 436.17175],
          [310.08688, 438.66077],
          [322.65015, 442.097  ],
          [324.3269 , 434.45065],
          [224.67744, 450.92798],
          [242.56874, 413.94662],
          [285.72803, 462.40347],
          [273.67886, 406.66385],
          [313.6862 , 456.8137 ],
          [318.559  , 416.42374],
          [239.62582, 441.11035],
          [242.73026, 420.8417 ]],
 
         [[329.30188, 431.77295],
          [364.57666, 425.20844],
          [391.32507, 421.96838],
          [414.35016, 433.2262 ],
          [411.04324, 418.17578],
          [320.63538, 445.82654],
          [315.36795, 420.21204],
          [360.9988 , 471.7216 ],
          [332.20065, 385.24988],
          [399.20847, 456.9794 ],
          [388.68896, 394.04962],
          [340.75934, 441.0198 ],
          [335.4428 , 419.33124]]]], dtype=float32),
 'instance_scores': array([[0.9953146 , 0.99476504],
        [0.9959341 , 0.99526805],
        [0.9959078 , 0.99451363],
        [0.99573493, 0.993386  ],
        [0.99603134, 0.99172956],
        [0.99564207, 0.9916197 ],
        [0.9947187 , 0.9915406 ],
        [0.9940315 , 0.98916876],
        [0.99394447, 0.98962784],
        [0.99446183, 0.9910501 ],
        [0.99155337, 0.9933716 ],
        [0.9916019 , 0.9933977 ],
        [0.9932473 , 0.9932013 ],
        [0.99207497, 0.9946308 ],
        [0.991653  , 0.99465877],
        [0.99162734, 0.99486005]], dtype=float32),
 'n_valid': array([2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2], dtype=int32)}
for key, value in predictions.items():
    print(f"'{key}': {value.shape} ({value.dtype})")
'instance_peaks': (16, 2, 13, 2) (float32)
'instance_peak_vals': (16, 2, 13) (float32)
'instance_scores': (16, 2) (float32)
'centroids': (16, 2, 2) (float32)
'centroid_vals': (16, 2) (float32)
'n_valid': (16,) (int32)
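`n_valid` reports how many instances were found in each frame, and undetected nodes show up as `nan` in `instance_peaks` (see the output above). A common next step is to drop the unused instance slots and check which nodes were actually detected. The sketch below assumes that the valid instances occupy the first `n_valid[i]` slots of frame `i`. The helper names `valid_instances` and `visible_node_counts` are hypothetical and not part of SLEAP, and the toy arrays only mimic the layout of the real `predictions` dict:

```python
import numpy as np


def valid_instances(predictions: dict) -> list:
    """Split batched predictions into a per-frame list of valid instance poses.

    Assumes ``instance_peaks`` has shape (frames, max_instances, nodes, 2) and
    that the first ``n_valid[i]`` instance slots of frame ``i`` are real detections.
    """
    peaks = predictions["instance_peaks"]
    n_valid = predictions["n_valid"]
    return [peaks[i, : n_valid[i]] for i in range(len(n_valid))]


def visible_node_counts(instance_peaks: np.ndarray) -> np.ndarray:
    """Count nodes with finite (non-NaN) x and y coordinates for each instance."""
    return np.isfinite(instance_peaks).all(axis=-1).sum(axis=-1)


# Toy predictions with the same layout: 2 frames, up to 2 instances, 3 nodes.
toy = {
    "instance_peaks": np.array(
        [
            [[[1.0, 2.0], [np.nan, np.nan], [3.0, 4.0]],
             [[5.0, 6.0], [7.0, 8.0], [9.0, 1.0]]],
            [[[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]],
             [[np.nan, np.nan], [np.nan, np.nan], [np.nan, np.nan]]],
        ],
        dtype=np.float32,
    ),
    "n_valid": np.array([2, 1], dtype=np.int32),
}

per_frame = valid_instances(toy)
print([p.shape for p in per_frame])  # [(2, 3, 2), (1, 3, 2)]
print(visible_node_counts(per_frame[0]))  # [2 3]
```

Applied to the real outputs above, `valid_instances(predictions)` would return 16 arrays of shape `(2, 13, 2)`, since every frame in this batch has `n_valid == 2`.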

4. Realtime performance

Now that we know how to run inference and interpret the different output types, let’s use it to build a simulated “realtime” application and measure how long each inference call takes.

First, we’ll create a class that simulates a camera grabber API that provides a sequence of pre-loaded frames.

from time import perf_counter
import numpy as np


class SimulatedCamera:
    """Simulated camera class that serves frames from memory continuously.

    Attributes:
        frames: Numpy array with pre-loaded frames.
        frame_counter: Count of frames that have been grabbed.
    """

    frames: np.ndarray
    frame_counter: int

    def __init__(self, frames):
        self.frames = frames
        self.frame_counter = 0
    
    def grab_frame(self):
        idx = self.frame_counter % len(self.frames)
        self.frame_counter += 1
        return self.frames[idx]
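`SimulatedCamera` hands out the next frame the moment it is asked, so the loop below measures inference time alone. A physical camera instead produces frames on its own clock, and frames the application does not collect in time are typically overwritten. To emulate that, here is a sketch of a hypothetical `PacedCamera` class (not part of SLEAP), assuming a camera that always delivers its newest frame and drops older ones. It uses only `time.perf_counter` and `time.sleep`:

```python
import time

import numpy as np


class PacedCamera:
    """Hypothetical camera simulator that produces frames on its own clock.

    Frame ``k`` becomes available ``k / fps`` seconds after the first grab.
    ``grab_frame`` waits if the next frame isn't ready yet and otherwise returns
    the newest available frame, counting any skipped frames as dropped.
    """

    def __init__(self, frames: np.ndarray, fps: float):
        self.frames = frames
        self.period = 1.0 / fps
        self.t_start = None   # set on the first grab
        self.last_index = -1  # index of the most recent frame handed out
        self.dropped = 0

    def grab_frame(self) -> np.ndarray:
        now = time.perf_counter()
        if self.t_start is None:
            self.t_start = now
        # Index of the newest frame the "sensor" has produced so far.
        newest = int((now - self.t_start) / self.period)
        if newest <= self.last_index:
            # Newest frame already consumed: block until the next one is due.
            newest = self.last_index + 1
            due = self.t_start + newest * self.period
            time.sleep(max(0.0, due - time.perf_counter()))
        # Frames produced between the previous grab and this one are lost.
        self.dropped += newest - self.last_index - 1
        self.last_index = newest
        return self.frames[newest % len(self.frames)]


camera_demo = PacedCamera(np.zeros((4, 8, 8), dtype=np.uint8), fps=50)
start = time.perf_counter()
for _ in range(5):
    camera_demo.grab_frame()
elapsed_ms = (time.perf_counter() - start) * 1000
print(f"5 frames in {elapsed_ms:.0f} ms, dropped {camera_demo.dropped}")
```

Substituting something like `PacedCamera(video[:512], fps=30)` for `SimulatedCamera` in the loop below would make `dropped` a direct count of frames that arrived while inference was still busy.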

Then, we’ll define a simple acquisition loop in which we repeatedly grab a frame, run inference on it, and time how long each inference call takes.

recording_duration = 100  # session length in frames

# Pre-load images onto "camera"
camera = SimulatedCamera(video[:512])

# Camera capture loop
inference_times = []
frames_recorded = 0
while frames_recorded < recording_duration:
    # Get the next frame.
    frame = camera.grab_frame()
    frames_recorded += 1

    # Get inference results for the frame and time how long it took.
    t0 = perf_counter()
    frame_predictions = predictor.inference_model.predict_on_batch(np.expand_dims(frame, axis=0))
    dt = perf_counter() - t0
    inference_times.append(dt)

# Convert to milliseconds.
inference_times = np.array(inference_times) * 1000

# Separate out first timing from the rest. The first inference call is much slower as it builds the compute graph.
first_inference_time, inference_times = inference_times[0], inference_times[1:]
print(f"First inference time: {first_inference_time:.1f} ms")
print(f"Inference times: {inference_times.mean():.1f} +- {inference_times.std():.1f} ms")
First inference time: 2181.9 ms
Inference times: 28.8 +- 2.6 ms
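The pattern in this loop (run the first call untimed, time the rest, and keep frame acquisition outside the timed region) can be factored into a small reusable helper. The sketch below is a hypothetical `benchmark` function, not part of SLEAP. `fn` stands in for any inference callable, such as `predictor.inference_model.predict_on_batch`, and `make_input` stands in for frame acquisition. The SVD workload is only a placeholder so the block runs on its own:

```python
from time import perf_counter
from typing import Any, Callable

import numpy as np


def benchmark(
    fn: Callable[[Any], Any],
    make_input: Callable[[], Any],
    n_calls: int,
    n_warmup: int = 1,
) -> np.ndarray:
    """Time ``fn(make_input())`` over ``n_calls`` calls; return latencies in ms.

    The first ``n_warmup`` calls run untimed so one-time setup costs (such as
    building the compute graph on the first inference call) are excluded.
    Producing the input is also kept outside the timed region.
    """
    for _ in range(n_warmup):
        fn(make_input())
    times = np.empty(n_calls)
    for i in range(n_calls):
        x = make_input()  # e.g. grabbing a frame; not timed
        t0 = perf_counter()
        fn(x)
        times[i] = perf_counter() - t0
    return times * 1000.0


# Placeholder workload so the sketch runs on its own: SVD of a random 64x64 matrix.
rng = np.random.default_rng(0)
times_ms = benchmark(np.linalg.svd, lambda: rng.random((64, 64)), n_calls=20)
print(f"{times_ms.mean():.2f} +- {times_ms.std():.2f} ms")

# In this notebook, the equivalent of the loop above would be (not run here):
# times_ms = benchmark(
#     predictor.inference_model.predict_on_batch,
#     lambda: np.expand_dims(camera.grab_frame(), axis=0),
#     n_calls=recording_duration - 1,
# )
```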

After the first (graph-building) call, our inference latencies go way down. Here is how they vary over the rest of the recording:

import matplotlib.pyplot as plt

plt.figure(figsize=(10, 4), dpi=120, facecolor="w")
plt.plot(inference_times, ".")
plt.xlabel("Time (frames)")
plt.ylabel("Inference latency (ms)")
plt.grid(True);
[Figure: inference latency (ms) for each frame over the recording]
plt.figure(figsize=(6, 4), dpi=120, facecolor="w")
plt.hist(inference_times, bins=30, density=True)  # normalize counts so the y-axis is a probability density
plt.xlabel("Inference latency (ms)")
plt.ylabel("PDF");
[Figure: histogram of inference latency (ms)]
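For a closed-loop experiment, the practical question is usually whether inference fits within the interval between camera frames. As a worked example with the numbers measured above: a 30 FPS camera delivers a frame every 1000 / 30 ≈ 33.3 ms, and a mean latency of 28.8 ms allows about 1000 / 28.8 ≈ 34.7 inferences per second in a sequential loop. The average call fits, but with little headroom, so calls in the upper tail of the distribution may not. The hypothetical helper below (not part of SLEAP; the 30 FPS target is only an illustrative assumption) summarizes this for any latency array:

```python
import numpy as np


def latency_report(times_ms: np.ndarray, target_fps: float) -> dict:
    """Compare per-call latencies (ms) with the frame interval at ``target_fps``.

    A camera running at ``target_fps`` delivers a frame every ``1000 / target_fps``
    ms; a sequential grab-then-infer loop keeps up only if inference fits in that
    budget.
    """
    budget_ms = 1000.0 / target_fps
    return {
        "budget_ms": budget_ms,
        "median_ms": float(np.median(times_ms)),
        "p95_ms": float(np.percentile(times_ms, 95)),
        # Throughput if inference calls run back to back with no other work.
        "sustained_fps": 1000.0 / float(np.mean(times_ms)),
        # Fraction of calls slower than the interval between frames.
        "frac_over_budget": float(np.mean(times_ms > budget_ms)),
    }


demo = np.array([25.0, 28.0, 30.0, 35.0])
print(latency_report(demo, target_fps=30))
# In the notebook: latency_report(inference_times, target_fps=30)
```

Reporting a high percentile alongside the mean matters here because a single slow call in a closed-loop setup means a missed frame, even when the average latency is within budget.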