Source code for sleap.io.format.sleap_analysis
"""
Adaptor to read and write analysis HDF5 files.
These contain location and track data, but lack other metadata included in a
full SLEAP dataset file.
Note that this adaptor uses default track names (`track_0`, `track_1`, ...) and
skeleton node names (`node 0`, `node 1`, ...) when the HDF5 file has no
`track_names` or `node_names` dataset (some files include these, some don't).
To determine whether this adaptor can read a file, we check that it has an
`.h5`/`.hdf5` extension, is an HDF5 file, and contains a `track_occupancy`
dataset.
"""
import numpy as np
from typing import Union
from sleap import Labels, Video, Skeleton
from sleap.instance import PredictedInstance, LabeledFrame, Track
from .adaptor import Adaptor, SleapObjectType
from .filehandle import FileHandle
class SleapAnalysisAdaptor(Adaptor):
    """Adaptor for reading and writing SLEAP analysis HDF5 files.

    Reading turns the ``tracks`` dataset (plus the optional ``track_names``
    and ``node_names`` datasets) into a :class:`Labels` object holding one
    :class:`PredictedInstance` per (frame, track) pair that has at least one
    non-NaN value. Writing delegates to ``sleap.info.write_tracking_h5``.
    """

    @property
    def handles(self):
        """The kind of SLEAP object this adaptor handles (labels)."""
        return SleapObjectType.labels

    @property
    def default_ext(self):
        """Extension used by default for this format."""
        return "h5"

    @property
    def all_exts(self):
        """All extensions recognized for this format."""
        return ["h5", "hdf5"]

    @property
    def name(self):
        """Human-readable name of this format."""
        return "SLEAP Analysis HDF5"

    def can_read_file(self, file: FileHandle):
        """Return whether ``file`` looks like a SLEAP analysis HDF5 file.

        The file must have a matching extension, be an HDF5 file, and contain
        a ``track_occupancy`` dataset.
        """
        # Cheap checks first; only inspect HDF5 contents once we know the
        # file is HDF5.
        if not self.does_match_ext(file.filename):
            return False
        if not file.is_hdf5:
            return False
        return "track_occupancy" in file.file

    def can_write_filename(self, filename: str):
        """Return whether ``filename`` has an extension this adaptor writes."""
        return self.does_match_ext(filename)

    def does_read(self) -> bool:
        """This adaptor supports reading."""
        return True

    def does_write(self) -> bool:
        """This adaptor supports writing."""
        return True

    @staticmethod
    def _decode_name(value) -> str:
        """Return a name read from HDF5 as ``str``.

        Depending on how the dataset was stored and on the h5py version,
        string values may come back as ``bytes`` or as ``str``; accept both.
        """
        if isinstance(value, bytes):
            return value.decode()
        return str(value)

    @classmethod
    def read(
        cls,
        file: FileHandle,
        video: Union[Video, str, None],
        *args,
        connect_adj_nodes: bool = False,
        **kwargs,
    ) -> Labels:
        """Read an analysis HDF5 file into a :class:`Labels` object.

        Args:
            file: Handle to the analysis HDF5 file.
            video: The :class:`Video` the data belongs to, or a filename from
                which a :class:`Video` is created via ``Video.from_filename``.
            connect_adj_nodes: If True, add a skeleton edge between each pair
                of consecutive nodes (in node-name order). Defaults to False,
                so the returned skeleton has no edges.
            *args: Accepted for adaptor-interface compatibility; ignored.
            **kwargs: Accepted for adaptor-interface compatibility; ignored.

        Returns:
            Labels with one :class:`LabeledFrame` per frame that has at least
            one non-NaN value. Every instance is a :class:`PredictedInstance`
            with all point confidences and the instance score set to 1.

        Raises:
            ValueError: If ``video`` is None.
        """
        if video is None:
            raise ValueError("Cannot read analysis hdf5 if no video specified.")
        if not isinstance(video, Video):
            video = Video.from_filename(video)

        f = file.file

        # Transposed on read so the layout is frames x nodes x 2 x tracks.
        tracks_matrix = f["tracks"][:].T
        frame_count, node_count, _, track_count = tracks_matrix.shape

        if "track_names" in f:
            track_names = [cls._decode_name(n) for n in f["track_names"][:].T]
        else:
            track_names = []
        # Use default names for any track that has no stored name, so that
        # indexing ``tracks`` by track index below can never go out of range.
        track_names += [f"track_{i}" for i in range(len(track_names), track_count)]
        tracks = [Track(0, track_name) for track_name in track_names]

        if "node_names" in f:
            node_names = [cls._decode_name(n) for n in f["node_names"][:].T]
        else:
            node_names = [f"node {i}" for i in range(node_count)]

        skeleton = Skeleton()
        prev_node_name = None
        for node_name in node_names:
            skeleton.add_node(node_name)
            # Compare against None rather than truthiness so that an
            # empty-string node name still gets connected.
            if connect_adj_nodes and prev_node_name is not None:
                skeleton.add_edge(prev_node_name, node_name)
            prev_node_name = node_name

        # occupied[frame, track] is True when that track has at least one
        # non-NaN value in that frame; computed once with NumPy instead of a
        # Python-level test for every (frame, track) pair.
        occupied = ~np.all(np.isnan(tracks_matrix), axis=(1, 2))

        frames = []
        for frame_idx in np.flatnonzero(occupied.any(axis=1)).tolist():
            # Everything is a PredictedInstance since the usual use case is
            # to export predictions for analysis.
            instances = [
                PredictedInstance.from_arrays(
                    points=tracks_matrix[frame_idx, ..., track_idx],
                    point_confidences=np.ones(node_count),
                    skeleton=skeleton,
                    track=tracks[track_idx],
                    instance_score=1,
                )
                for track_idx in np.flatnonzero(occupied[frame_idx]).tolist()
            ]
            frames.append(
                LabeledFrame(video=video, frame_idx=frame_idx, instances=instances)
            )

        return Labels(labeled_frames=frames)

    @classmethod
    def write(cls, filename: str, source_object: Labels):
        """Write ``source_object`` to ``filename`` as an analysis HDF5 file.

        Delegates to ``sleap.info.write_tracking_h5.main`` with
        ``all_frames=True``.
        """
        # Imported here rather than at module level, presumably to defer
        # loading the writer (or avoid an import cycle) until it is needed.
        from sleap.info.write_tracking_h5 import main as write_analysis

        write_analysis(source_object, output_path=filename, all_frames=True)