Analysis examples#

In this notebook we’ll show examples of how you might use the predictions exported from SLEAP. We’ll work with an Analysis HDF5 file (rather than the .slp predictions file). This HDF5 file can be exported from SLEAP.

We advise building your post-SLEAP analysis pipeline around these HDF5 files rather than trying to work directly with the .slp files used by SLEAP.

Note: You can work with these HDF5 files directly in Python (as we’ll do here) or MATLAB without having SLEAP itself installed.

Example analysis data#

Let’s start by downloading a sample HDF5 file. These predictions were created with models trained on our sample Drosophila melanogaster courtship dataset. We ran inference with these models on a video clip with 3000 frames. The video clip, resulting predictions, and exported HDF5 are all available here.

We’ll just download the predictions.analysis.h5 file:

!wget -O predictions.analysis.h5 https://github.com/talmolab/sleap/raw/main/docs/notebooks/analysis_example/predictions.analysis.h5
--2021-01-04 03:24:05--  https://github.com/talmolab/sleap/raw/main/docs/notebooks/analysis_example/predictions.analysis.h5
Resolving github.com (github.com)... 140.82.113.4
Connecting to github.com (github.com)|140.82.113.4|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/talmolab/sleap/main/docs/notebooks/analysis_example/predictions.analysis.h5 [following]
--2021-01-04 03:24:05--  https://raw.githubusercontent.com/talmolab/sleap/main/docs/notebooks/analysis_example/predictions.analysis.h5
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 151.101.0.133, 151.101.64.133, 151.101.128.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|151.101.0.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 551501 (539K) [application/octet-stream]
Saving to: ‘predictions.analysis.h5’

predictions.analysi 100%[===================>] 538.58K  --.-KB/s    in 0.05s   

2021-01-04 03:24:05 (10.5 MB/s) - ‘predictions.analysis.h5’ saved [551501/551501]

We can set the path and filename to the analysis HDF5. In our case, this is just predictions.analysis.h5. If you’re running analysis code on your local computer this will be the full path and filename of your HDF5. If you’re running analysis code on Colab, then you’ll probably copy files to Colab via Google Drive and should use the path to your file there.

filename = "predictions.analysis.h5"

Loading the data#

We use the h5py package to load data from the HDF5. This is already installed in Colab. If you’re running analysis code on your local machine and have SLEAP installed, then h5py and the other packages we use are already installed in your SLEAP conda environment. Otherwise, you may need to use conda or pip to install h5py as well as numpy, scipy, matplotlib, seaborn, and any other packages you want to use in your analysis code.
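If you are unsure which of these packages are available in your environment, a quick check like the following can help. This is only a sketch: the package list is our own summary of what this notebook imports (including pandas and scikit-learn, which are used further down), and `required` and `missing` are illustrative names.

```python
from importlib.util import find_spec

# Packages imported somewhere in this notebook. scikit-learn is imported
# under the module name "sklearn".
required = ["h5py", "numpy", "scipy", "matplotlib", "seaborn", "pandas", "sklearn"]

# find_spec returns None when a top-level module cannot be found.
missing = [name for name in required if find_spec(name) is None]

if missing:
    print("Missing packages:", ", ".join(missing))
else:
    print("All packages are available.")
```

Anything reported as missing can then be installed with conda or pip before running the rest of the notebook.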

Let’s load the file and take a peek.

import h5py
import numpy as np

with h5py.File(filename, "r") as f:
    dset_names = list(f.keys())
    locations = f["tracks"][:].T
    node_names = [n.decode() for n in f["node_names"][:]]

print("===filename===")
print(filename)
print()

print("===HDF5 datasets===")
print(dset_names)
print()

print("===locations data shape===")
print(locations.shape)
print()

print("===nodes===")
for i, name in enumerate(node_names):
    print(f"{i}: {name}")
print()

===filename===
predictions.analysis.h5

===HDF5 datasets===
['node_names', 'track_names', 'track_occupancy', 'tracks']

===locations data shape===
(3000, 13, 2, 2)

===nodes===
0: head
1: thorax
2: abdomen
3: wingL
4: wingR
5: forelegL4
6: forelegR4
7: midlegL4
8: midlegR4
9: hindlegL4
10: hindlegR4
11: eyeL
12: eyeR

In our example file, the shape of the locations matrix (the tracks dataset) is (3000, 13, 2, 2) after it is transposed (with the .T). We transpose the data when loading it in Python; no transpose is needed when using MATLAB. This is because NumPy reads arrays in row-major order while MATLAB uses column-major order, so the same stored data appears with its axes reversed.
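To make the effect of the transpose concrete, here is a small sketch with a synthetic array standing in for the tracks dataset. The sizes and the variable names `stored` and `as_column_major` are made up for illustration; only the axis reversal itself mirrors what happens when loading the real file.

```python
import numpy as np

# Synthetic stand-in for the "tracks" dataset before the transpose:
# (tracks, xy, nodes, frames). Only 5 frames to keep it small.
stored = np.arange(2 * 2 * 13 * 5, dtype=float).reshape(2, 2, 13, 5)

# .T reverses the axis order: (frames, nodes, xy, tracks).
locations = stored.T
print(stored.shape, "->", locations.shape)

# Reading the same buffer in column-major (MATLAB-style) order gives the
# same array without an explicit transpose.
as_column_major = stored.ravel(order="C").reshape(locations.shape, order="F")
print(np.array_equal(locations, as_column_major))
```

The second check illustrates why MATLAB users see the (frames, nodes, xy, tracks) layout directly.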

Here’s what each dimension of the matrix means:

  • 3000: the number of frames;

  • 13: the number of nodes in the skeleton (we’ve also loaded and displayed the node_names dataset with the names of these 13 nodes);

  • 2: for the x and y coordinates;

  • 2: the number of distinct animal identities which were found (we have 2 flies in the video clip and they were tracked perfectly, so we ended up with exactly 2 tracks, but there may be more tracks than animals if tracking didn’t work as well).

We can get these counts from the shape of the matrix, like so:

frame_count, node_count, _, instance_count = locations.shape

print("frame count:", frame_count)
print("node count:", node_count)
print("instance count:", instance_count)

frame count: 3000
node count: 13
instance count: 2
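With this axis order, a single node or a single animal can be pulled out by plain indexing. The sketch below uses random numbers in place of real predictions and looks up the node index by name instead of hard-coding it; `thorax_idx` and `thorax_fly0` are illustrative names.

```python
import numpy as np

node_names = ["head", "thorax", "abdomen", "wingL", "wingR",
              "forelegL4", "forelegR4", "midlegL4", "midlegR4",
              "hindlegL4", "hindlegR4", "eyeL", "eyeR"]

# Synthetic stand-in for the locations array: (frames, nodes, xy, tracks).
rng = np.random.default_rng(0)
locations = rng.uniform(0, 1024, size=(100, len(node_names), 2, 2))

# Look up a node by name rather than by a hard-coded index.
thorax_idx = node_names.index("thorax")

# All frames, thorax node, both coordinates, first track: shape (frames, 2).
thorax_fly0 = locations[:, thorax_idx, :, 0]

# A single point: frame 10, thorax, second track.
x, y = locations[10, thorax_idx, :, 1]
print(thorax_fly0.shape, x, y)
```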

Now that we’ve loaded the data, let’s see some different things we can do with it…

Fill missing values#

from scipy.interpolate import interp1d

def fill_missing(Y, kind="linear"):
    """Fills missing values independently along each dimension after the first."""

    # Store initial shape.
    initial_shape = Y.shape

    # Flatten after first dim.
    Y = Y.reshape((initial_shape[0], -1))

    # Interpolate along each slice.
    for i in range(Y.shape[-1]):
        y = Y[:, i]

        # Build interpolant.
        x = np.flatnonzero(~np.isnan(y))
        f = interp1d(x, y[x], kind=kind, fill_value=np.nan, bounds_error=False)

        # Fill missing
        xq = np.flatnonzero(np.isnan(y))
        y[xq] = f(xq)
        
        # Fill leading or trailing NaNs with the nearest non-NaN values
        mask = np.isnan(y)
        y[mask] = np.interp(np.flatnonzero(mask), np.flatnonzero(~mask), y[~mask])

        # Save slice
        Y[:, i] = y

    # Restore to initial shape.
    Y = Y.reshape(initial_shape)

    return Y

locations = fill_missing(locations)
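To see what this kind of filling does, it can help to try it on a tiny 1-D example. The helper below is a simplified, hypothetical stand-in (`fill_1d` is not part of SLEAP or of the function above): it uses only `np.interp`, which interpolates linearly between valid samples and holds the nearest valid value at the edges.

```python
import numpy as np

def fill_1d(y):
    """Hypothetical helper: linearly fill NaNs in a 1-D array (returns a copy)."""
    y = np.asarray(y, dtype=float).copy()
    mask = np.isnan(y)
    # np.interp interpolates between valid samples and clamps to the first
    # or last valid value outside their range.
    y[mask] = np.interp(np.flatnonzero(mask), np.flatnonzero(~mask), y[~mask])
    return y

y = np.array([np.nan, 1.0, np.nan, np.nan, 4.0, np.nan])
filled = fill_1d(y)
print(filled)  # [1. 1. 2. 3. 4. 4.]
```

Note that a column made up entirely of NaNs has no valid samples to interpolate from, so it would need to be handled separately.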

Visualize thorax movement across video#

HEAD_INDEX = 0
THORAX_INDEX = 1
ABDO_INDEX = 2

head_loc = locations[:, HEAD_INDEX, :, :]
thorax_loc = locations[:, THORAX_INDEX, :, :]
abdo_loc = locations[:, ABDO_INDEX, :, :]

import seaborn as sns
import matplotlib as mpl
import matplotlib.pyplot as plt
sns.set('notebook', 'ticks', font_scale=1.2)
mpl.rcParams['figure.figsize'] = [15,6]
plt.figure()
plt.plot(thorax_loc[:,0,0], 'y',label='fly-0')
plt.plot(thorax_loc[:,0,1], 'g',label='fly-1')

plt.plot(-1*thorax_loc[:,1,0], 'y')
plt.plot(-1*thorax_loc[:,1,1], 'g')

plt.legend(loc="center right")
plt.title('Thorax locations')


plt.figure(figsize=(7,7))
plt.plot(thorax_loc[:,0,0],thorax_loc[:,1,0], 'y',label='fly-0')
plt.plot(thorax_loc[:,0,1],thorax_loc[:,1,1], 'g',label='fly-1')
plt.legend()

plt.xlim(0,1024)
plt.xticks([])

plt.ylim(0,1024)
plt.yticks([])
plt.title('Thorax tracks')
Text(0.5, 1.0, 'Thorax tracks')

More advanced visualizations#

For some additional analysis, we’ll first smooth and differentiate the data with a Savitzky-Golay filter to extract velocities of each joint.

from scipy.signal import savgol_filter

def smooth_diff(node_loc, win=25, poly=3):
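    """
    Computes the speed of a single node from its smoothed trajectory.

    node_loc is a [frames, 2] array of x-y locations

    win defines the window (in frames) to smooth over

    poly defines the order of the polynomial
    to fit with

    Returns a [frames] array of speeds (the norm of the x and y velocities).
    """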
    node_loc_vel = np.zeros_like(node_loc)
    
    for c in range(node_loc.shape[-1]):
        node_loc_vel[:, c] = savgol_filter(node_loc[:, c], win, poly, deriv=1)
    
    node_vel = np.linalg.norm(node_loc_vel,axis=1)

    return node_vel
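As a sanity check of the filter settings used in smooth_diff, we can apply the same Savitzky-Golay derivative to a synthetic track whose speed is known. The trajectory below is invented for illustration: it moves 3 units per frame in x and 4 units per frame in y, so the speed should come out as 5 units per frame.

```python
import numpy as np
from scipy.signal import savgol_filter

# Synthetic track with constant velocity (3, 4) per frame.
frames = np.arange(200)
xy = np.stack([3.0 * frames, 4.0 * frames], axis=1)  # shape (frames, 2)

# deriv=1 returns the first derivative of the fitted polynomial at each
# sample; delta defaults to 1.0, so the result is per frame.
vel = savgol_filter(xy, window_length=25, polyorder=3, deriv=1, axis=0)

# Speed is the norm of the x and y velocity components.
speed = np.linalg.norm(vel, axis=1)
print(speed.min(), speed.max())
```

On real data the window length controls the trade-off: a longer window smooths more noise but also blurs brief, fast movements.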

There are two flies. Let’s get results for each separately.

thx_vel_fly0 = smooth_diff(thorax_loc[:, :, 0])
thx_vel_fly1 = smooth_diff(thorax_loc[:, :, 1])

Visualizing thorax x-y dynamics and velocity for fly 0#

fig = plt.figure(figsize=(15,7))
ax1 = fig.add_subplot(211)
ax1.plot(thorax_loc[:, 0, 0], 'k', label='x')
ax1.plot(-1*thorax_loc[:, 1, 0], 'k--', label='y')
ax1.legend()
ax1.set_xticks([])
ax1.set_title('Thorax')

ax2 = fig.add_subplot(212, sharex=ax1)
ax2.imshow(thx_vel_fly0[:,np.newaxis].T, aspect='auto', vmin=0, vmax=10)
ax2.set_yticks([])
ax2.set_title('Velocity')
Text(0.5, 1.0, 'Velocity')

Visualize thorax colored by magnitude of fly speed#

fig = plt.figure(figsize=(15,6))
ax1 = fig.add_subplot(121)
ax1.plot(thorax_loc[:, 0, 0], thorax_loc[:, 1, 0], 'k')
ax1.set_xlim(0,1024)
ax1.set_xticks([])
ax1.set_ylim(0,1024)
ax1.set_yticks([])
ax1.set_title('Thorax tracks')

kp = thx_vel_fly0  # use thx_vel_fly1 for other fly
vmin = 0
vmax = 10

ax2 = fig.add_subplot(122)
ax2.scatter(thorax_loc[:,0,0], thorax_loc[:,1,0], c=kp, s=4, vmin=vmin, vmax=vmax)
ax2.set_xlim(0,1024)
ax2.set_xticks([])
ax2.set_ylim(0,1024)
ax2.set_yticks([])
ax2.set_title('Thorax tracks colored by magnitude of fly speed')
Text(0.5, 1.0, 'Thorax tracks colored by magnitude of fly speed')

Find rolling correlation in thorax velocities between fly-0 and fly-1#

import pandas as pd

def corr_roll(datax, datay, win):
    """
    datax, datay are the two timeseries to find correlations between
    
    win sets the number of frames over which the correlation is computed
    
    """
    
    s1 = pd.Series(datax)
    s2 = pd.Series(datay)
    
    return np.array(s2.rolling(win).corr(s1))

win = 50

corr_vel = corr_roll(thx_vel_fly0, thx_vel_fly1,win)

fig, ax = plt.subplots(2, 1, sharex=True, figsize=(15,6))
ax[0].plot(thx_vel_fly0, 'y', label='fly-0')
ax[0].plot(thx_vel_fly1, 'g', label='fly-1')
ax[0].legend()
ax[0].set_title('Thorax speed')

ax[1].plot(corr_vel, 'c', markersize=1)
ax[1].set_ylim(-1.2, 1.2)
ax[1].set_title('Rolling correlation')

fig.tight_layout()
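The rolling computation above relies on the pandas rolling `corr` method, which yields a Pearson correlation bounded between -1 and 1 rather than a covariance. The sketch below illustrates the difference on synthetic series; the series `a`, `b`, and `c` are invented for this example.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
a = pd.Series(rng.normal(size=200))
b = 2.0 * a + 1.0                     # perfectly linearly related to a
c = pd.Series(rng.normal(size=200))   # generated independently of a

win = 50
corr_ab = a.rolling(win).corr(b)  # close to 1 once the window is full
corr_ac = a.rolling(win).corr(c)  # fluctuates, but stays within [-1, 1]
cov_ab = a.rolling(win).cov(b)    # covariance is not bounded this way

# The first win - 1 entries are NaN because the window is not yet full.
print(corr_ab.isna().sum())
print(corr_ab.dropna().min(), cov_ab.dropna().max())
```

The leading NaNs explain why the correlation trace in the plot starts later than the speed traces.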

Clustering#

For an example of clustering the data, we’ll

  1. compute the velocity of each joint,

  2. run simple k-means on the velocities from each frame.

def instance_node_velocities(instance_idx):
    fly_node_locations = locations[:, :, :, instance_idx]
    fly_node_velocities = np.zeros((frame_count, node_count))

    for n in range(0, node_count):
        fly_node_velocities[:, n] = smooth_diff(fly_node_locations[:, n, :])
    
    return fly_node_velocities

def plot_instance_node_velocities(instance_idx, node_velocities):
    plt.figure(figsize=(20,8))
    plt.imshow(node_velocities.T, aspect='auto', vmin=0, vmax=20, interpolation="nearest")
    plt.xlabel('frames')
    plt.ylabel('nodes')
    plt.yticks(np.arange(node_count), node_names, rotation=20);
    plt.title(f'Fly {instance_idx} node velocities')

fly_ID = 0
fly_node_velocities = instance_node_velocities(fly_ID)
plot_instance_node_velocities(fly_ID, fly_node_velocities)
fly_ID = 1
fly_node_velocities = instance_node_velocities(fly_ID)
plot_instance_node_velocities(fly_ID, fly_node_velocities)
from sklearn.cluster import KMeans
nstates = 10

km = KMeans(n_clusters=nstates)

labels = km.fit_predict(fly_node_velocities)

fig = plt.figure(figsize=(20, 12))

ax1 = fig.add_subplot(211)
ax1.imshow(fly_node_velocities.T, aspect="auto", vmin=0, vmax=20, interpolation="nearest")
ax1.set_xlabel("Frames")
ax1.set_ylabel("Nodes")
ax1.set_yticks(np.arange(node_count))
ax1.set_yticklabels(node_names);
ax1.set_title(f"Fly {fly_ID} node velocities")
ax1.set_xlim(0,frame_count)

ax2 = fig.add_subplot(212,sharex=ax1)
ax2.imshow(labels[None, :], aspect="auto", cmap="tab10", interpolation="nearest")
ax2.set_xlabel("Frames")
ax2.set_yticks([])
ax2.set_title("Ethogram (colors = clusters)");
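Once each frame has a cluster label, one possible next step is to summarize the ethogram numerically. The sketch below is not part of the original analysis and uses a short, made-up label sequence in place of the k-means output. It counts the frames spent in each cluster and splits the sequence into bouts of consecutive frames that share a label.

```python
import numpy as np

# Synthetic per-frame cluster labels standing in for the k-means output.
labels = np.array([0, 0, 0, 2, 2, 1, 1, 1, 1, 0])

# Number of frames assigned to each cluster (index = cluster id).
counts = np.bincount(labels)

# Run-length encoding: a new bout starts wherever the label changes.
change = np.flatnonzero(np.diff(labels)) + 1
starts = np.concatenate(([0], change))
lengths = np.diff(np.concatenate((starts, [len(labels)])))

# Each bout is (cluster id, start frame, length in frames).
bouts = list(zip(labels[starts].tolist(), starts.tolist(), lengths.tolist()))
print(counts)  # [4 4 2]
print(bouts)   # [(0, 0, 3), (2, 3, 2), (1, 5, 4), (0, 9, 1)]
```

Keep in mind that k-means cluster ids are arbitrary and can change between runs unless a fixed random seed is used, so bouts are best compared within a single fit.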