"""Inference pipelines and utilities."""
import attr
import logging
import os
import time
from abc import ABC, abstractmethod
from typing import Text, Optional, List, Dict
import tensorflow as tf
import numpy as np
import sleap
from sleap import util
from sleap.nn.config import TrainingJobConfig
from sleap.nn.model import Model
from sleap.nn.tracking import Tracker, run_tracker
from sleap.nn.data.grouping import group_examples_iter
from sleap.nn.data.pipelines import (
Provider,
Pipeline,
LabelsReader,
VideoReader,
Normalizer,
Resizer,
Prefetcher,
LambdaFilter,
KerasModelPredictor,
LocalPeakFinder,
PredictedInstanceCropper,
InstanceCentroidFinder,
InstanceCropper,
GlobalPeakFinder,
MockGlobalPeakFinder,
KeyFilter,
KeyRenamer,
KeyDeviceMover,
PredictedCenterInstanceNormalizer,
PartAffinityFieldInstanceGrouper,
PointsRescaler,
)
logger = logging.getLogger(__name__)
def safely_generate(ds: "tf.data.Dataset", progress: bool = True):
    """Yields examples from dataset, catching and logging exceptions."""
    # Unsafe generating:
    # for example in ds:
    #     yield example
    iterator = iter(ds)
    example_ind = 0
    t0 = time.time()
    exhausted = False
    while not exhausted:
        try:
            yield next(iterator)
        except StopIteration:
            exhausted = True
        except Exception as exc:
            # Log and skip the failing example; keep the stream going.
            logger.info(f"ERROR in sample index {example_ind}")
            logger.info(exc)
            logger.info("")
        finally:
            if not exhausted:
                example_ind += 1
        # Show the current progress (frames, time, fps)
        if progress and ((example_ind and example_ind % 1000 == 0) or exhausted):
            elapsed_time = time.time() - t0
            logger.info(
                f"Finished {example_ind} examples in {elapsed_time:.2f} seconds (inference + postprocessing)"
            )
            if elapsed_time:
                logger.info(f"examples/s = {example_ind/elapsed_time}")
def make_grouped_labeled_frame(
    video_ind: int,
    frame_ind: int,
    frame_examples: List[Dict[Text, tf.Tensor]],
    videos: List[sleap.Video],
    skeleton: "Skeleton",
    points_key: Text,
    point_confidences_key: Text,
    image_key: Optional[Text] = None,
    instance_score_key: Optional[Text] = None,
    tracker: Optional[Tracker] = None,
) -> List[sleap.LabeledFrame]:
    """Assemble a LabeledFrame from the pipeline examples of a single frame.

    Args:
        video_ind: Index into `videos` of the video this frame belongs to.
        frame_ind: Frame index within that video.
        frame_examples: Examples emitted by the inference pipeline for this
            frame. Each example holds point/confidence arrays under the keys
            named below.
        videos: All videos referenced by the provider.
        skeleton: Skeleton to attach to the created PredictedInstances.
        points_key: Example key holding instance point coordinates.
        point_confidences_key: Example key holding per-point confidences.
        image_key: Optional example key holding an image passed to the tracker.
        instance_score_key: Optional example key holding per-instance scores;
            when None, scores are derived by summing point confidences.
        tracker: Optional tracker used to assign tracks to the instances.

    Returns:
        A single-element list with the assembled LabeledFrame, or an empty
        list if no instance had any non-NaN points.
    """
    predicted_frames = []

    # Create predicted instances from examples in the current frame.
    predicted_instances = []
    img = None
    for example in frame_examples:
        if instance_score_key is None:
            # No explicit instance score: sum the point confidences, ignoring
            # NaNs (missing points). Note the .numpy() here — this branch
            # assumes the confidences are an eager tf.Tensor.
            instance_scores = np.nansum(example[point_confidences_key].numpy(), axis=-1)
        else:
            instance_scores = example[instance_score_key]

        if example[points_key].ndim == 3:
            # Multiple instances in this example: iterate instances in
            # lockstep with their confidences and scores.
            for points, confidences, instance_score in zip(
                example[points_key], example[point_confidences_key], instance_scores
            ):
                # Skip instances where every point is NaN (nothing detected).
                if not np.isnan(points).all():
                    predicted_instances.append(
                        sleap.PredictedInstance.from_arrays(
                            points=points,
                            point_confidences=confidences,
                            instance_score=instance_score,
                            skeleton=skeleton,
                        )
                    )
        else:
            # Single instance in this example.
            points = example[points_key]
            confidences = example[point_confidences_key]
            instance_score = instance_scores
            if not np.isnan(points).all():
                predicted_instances.append(
                    sleap.PredictedInstance.from_arrays(
                        points=points,
                        point_confidences=confidences,
                        instance_score=instance_score,
                        skeleton=skeleton,
                    )
                )

        # NOTE(review): img is overwritten each iteration, so the tracker sees
        # the image from the LAST example of the frame — presumably all
        # examples of a frame share the same image; confirm against pipeline.
        if image_key is not None and image_key in example:
            img = example[image_key]
        else:
            img = None

    if len(predicted_instances) > 0:
        if tracker:
            # Set tracks for predicted instances in this frame.
            predicted_instances = tracker.track(
                untracked_instances=predicted_instances, img=img, t=frame_ind
            )

        # Create labeled frame from predicted instances.
        labeled_frame = sleap.LabeledFrame(
            video=videos[video_ind], frame_idx=frame_ind, instances=predicted_instances
        )

        predicted_frames.append(labeled_frame)

    return predicted_frames
def get_keras_model_path(path: Text) -> Text:
    """Return the path to the ``best_model.h5`` file for a model.

    Args:
        path: Path to a model folder, or to a JSON config file inside one
            (in which case the filename is stripped off first).

    Returns:
        Path to ``best_model.h5`` within the model folder.
    """
    model_dir = os.path.dirname(path) if path.endswith(".json") else path
    return os.path.join(model_dir, "best_model.h5")
@attr.s(auto_attribs=True)
class Predictor(ABC):
    """Base interface class for predictors."""

    @classmethod
    @abstractmethod
    def from_trained_models(cls, *args, **kwargs):
        """Construct a predictor instance from saved model files."""
        pass

    @abstractmethod
    def make_pipeline(self):
        """Build the inference pipeline and store it on the instance."""
        pass

    @abstractmethod
    def predict(self, data_provider: Provider):
        """Run inference on the given provider and return the results."""
        pass
@attr.s(auto_attribs=True)
class MockPredictor(Predictor):
    """Predictor that returns pre-existing frames from a labels file.

    Used when no trained models are supplied: "inference" simply looks up the
    frames already stored in the labels file for the requested video,
    optionally (re-)running tracking on them.
    """

    # Labels file whose frames serve as the mock predictions.
    labels: sleap.Labels
    # Tracker to (re-)apply to the returned frames. Declared with a default
    # because predict() reads self.tracker; previously this attribute only
    # existed if a caller assigned it externally (as main() does), so calling
    # predict() without that assignment raised AttributeError. Declared
    # init=False for consistency with the other predictor classes.
    tracker: Optional[Tracker] = attr.ib(default=None, init=False)

    @classmethod
    def from_trained_models(cls, labels_path: Text):
        """Create a MockPredictor by loading the labels file at `labels_path`."""
        labels = sleap.Labels.load_file(labels_path)
        return cls(labels=labels)

    def make_pipeline(self):
        """No pipeline needed; predictions come straight from the labels."""
        pass

    def predict(self, data_provider: Provider):
        """Return labeled frames for the provider's video from the labels file.

        Args:
            data_provider: Provider with a ``video`` attribute (e.g. a
                VideoReader) and optional ``example_indices``.

        Returns:
            The matching LabeledFrames, tracked if a tracker is set.
        """
        prediction_video = None

        # Try to match specified video by its full path
        prediction_video_path = os.path.abspath(data_provider.video.filename)
        for video in self.labels.videos:
            if os.path.abspath(video.filename) == prediction_video_path:
                prediction_video = video
                break

        if prediction_video is None:
            # Try to match on filename (without path)
            prediction_video_path = os.path.basename(data_provider.video.filename)
            for video in self.labels.videos:
                if os.path.basename(video.filename) == prediction_video_path:
                    prediction_video = video
                    break

        if prediction_video is None:
            # Default to first video in labels file
            prediction_video = self.labels.videos[0]

        # Get specified frames from labels file (or use None for all frames)
        frame_idx_list = (
            list(data_provider.example_indices)
            if data_provider.example_indices
            else None
        )
        frames = self.labels.find(video=prediction_video, frame_idx=frame_idx_list)

        # Run tracker as specified
        if self.tracker:
            frames = run_tracker(tracker=self.tracker, frames=frames)
            self.tracker.final_pass(frames)

        # Return frames (there are no "raw" predictions we could return)
        return frames
@attr.s(auto_attribs=True)
class VisualPredictor(Predictor):
    """Predictor class for generating the visual output of model."""

    config: TrainingJobConfig
    model: Model
    # Built lazily by make_pipeline().
    pipeline: Optional[Pipeline] = attr.ib(default=None, init=False)

    @classmethod
    def from_trained_models(cls, model_path: Text) -> "VisualPredictor":
        """Load the config and Keras weights from a model folder path."""
        cfg = TrainingJobConfig.load_json(model_path)
        keras_model_path = get_keras_model_path(model_path)
        model = Model.from_config(cfg.model)
        # compile=False: only forward passes are run, no training state needed.
        model.keras_model = tf.keras.models.load_model(keras_model_path, compile=False)
        return cls(config=cfg, model=model)

    def head_specific_output_keys(self) -> List[Text]:
        """Return the model output key names applicable to this head type."""
        keys = []

        key = self.confidence_maps_key_name
        if key:
            keys.append(key)

        key = self.part_affinity_fields_key_name
        if key:
            keys.append(key)

        return keys

    @property
    def confidence_maps_key_name(self) -> Optional[Text]:
        """Output key for confidence maps, or None if the head has none."""
        head_key = self.config.model.heads.which_oneof_attrib_name()
        if head_key in ("multi_instance", "single_instance"):
            return "predicted_confidence_maps"
        if head_key == "centroid":
            return "predicted_centroid_confidence_maps"
        # todo: centered_instance
        return None

    @property
    def part_affinity_fields_key_name(self) -> Optional[Text]:
        """Output key for part affinity fields, or None for non-PAF heads."""
        head_key = self.config.model.heads.which_oneof_attrib_name()
        if head_key == "multi_instance":
            return "predicted_part_affinity_fields"
        return None

    def make_pipeline(self):
        """Build the normalize -> resize -> model forward pass pipeline."""
        pipeline = Pipeline()
        pipeline += Normalizer.from_config(self.config.data.preprocessing)
        pipeline += Resizer.from_config(
            self.config.data.preprocessing, keep_full_image=False, points_key=None
        )
        pipeline += KerasModelPredictor(
            keras_model=self.model.keras_model,
            model_input_keys="image",
            model_output_keys=self.head_specific_output_keys(),
        )
        self.pipeline = pipeline

    def predict_generator(self, data_provider: Provider):
        """Return a generator over raw pipeline output examples."""
        if self.pipeline is None:
            # Pass in data provider when mocking one of the models.
            self.make_pipeline()
        self.pipeline.providers = [data_provider]
        # Yield each example from dataset, catching and logging exceptions
        return safely_generate(self.pipeline.make_dataset())

    def predict(self, data_provider: Provider):
        """Run inference and return all pipeline output examples as a list."""
        generator = self.predict_generator(data_provider)
        examples = list(generator)
        return examples
@attr.s(auto_attribs=True)
class TopdownPredictor(Predictor):
    """Predictor for topdown models (centroid detection + centered-instance confmaps).

    Either model may be left as None, in which case ground truth data from a
    LabelsReader provider is substituted for that stage's predictions.
    """

    centroid_config: Optional[TrainingJobConfig] = attr.ib(default=None)
    centroid_model: Optional[Model] = attr.ib(default=None)
    confmap_config: Optional[TrainingJobConfig] = attr.ib(default=None)
    confmap_model: Optional[Model] = attr.ib(default=None)
    # Built lazily by make_pipeline().
    pipeline: Optional[Pipeline] = attr.ib(default=None, init=False)
    # Not set by __init__; assigned externally (e.g., by the CLI in main()).
    tracker: Optional[Tracker] = attr.ib(default=None, init=False)
    batch_size: int = 1
    # Minimum confidence for a peak to count as a detection.
    peak_threshold: float = 0.2
    integral_refinement: bool = True
    integral_patch_size: int = 5

    @classmethod
    def from_trained_models(
        cls,
        centroid_model_path: Optional[Text] = None,
        confmap_model_path: Optional[Text] = None,
        batch_size: int = 1,
        peak_threshold: float = 0.2,
        integral_refinement: bool = True,
        integral_patch_size: int = 5,
    ) -> "TopdownPredictor":
        """Create predictor from saved models.

        Args:
            centroid_model_path: Path to centroid model folder.
            confmap_model_path: Path to topdown confidence map model folder.

        Returns:
            An instance of TopdownPredictor with the loaded models.

            One of the two models can be left as None to perform inference with ground
            truth data. This will only work with LabelsReader as the provider.
        """
        if centroid_model_path is None and confmap_model_path is None:
            raise ValueError(
                "Either the centroid or topdown confidence map model must be provided."
            )

        if centroid_model_path is not None:
            # Load centroid model.
            centroid_config = TrainingJobConfig.load_json(centroid_model_path)
            centroid_keras_model_path = get_keras_model_path(centroid_model_path)
            centroid_model = Model.from_config(centroid_config.model)
            # compile=False: inference only, no training state needed.
            centroid_model.keras_model = tf.keras.models.load_model(
                centroid_keras_model_path, compile=False
            )
        else:
            centroid_config = None
            centroid_model = None

        if confmap_model_path is not None:
            # Load confmap model.
            confmap_config = TrainingJobConfig.load_json(confmap_model_path)
            confmap_keras_model_path = get_keras_model_path(confmap_model_path)
            confmap_model = Model.from_config(confmap_config.model)
            confmap_model.keras_model = tf.keras.models.load_model(
                confmap_keras_model_path, compile=False
            )
        else:
            confmap_config = None
            confmap_model = None

        return cls(
            centroid_config=centroid_config,
            centroid_model=centroid_model,
            confmap_config=confmap_config,
            confmap_model=confmap_model,
            batch_size=batch_size,
            peak_threshold=peak_threshold,
            integral_refinement=integral_refinement,
            integral_patch_size=integral_patch_size,
        )

    def make_pipeline(self, data_provider: Optional[Provider] = None) -> Pipeline:
        """Assemble the topdown inference pipeline.

        Args:
            data_provider: Optional provider; needed here (as a LabelsReader)
                when either model is None so ground truth data can stand in.

        Returns:
            The constructed Pipeline (also stored on ``self.pipeline``).
        """
        # Keep the raw frame only if the tracker needs image features.
        keep_original_image = self.tracker and self.tracker.uses_image
        pipeline = Pipeline()
        if data_provider is not None:
            pipeline.providers = [data_provider]
        pipeline += Prefetcher()

        # Preserve the full-frame image/scale under separate keys so the crop
        # stage can still reference them after "image" is transformed.
        pipeline += KeyRenamer(
            old_key_names=["image", "scale"],
            new_key_names=["full_image", "full_image_scale"],
            drop_old=False,
        )
        if keep_original_image:
            pipeline += KeyRenamer(
                old_key_names=["image", "scale"],
                new_key_names=["original_image", "original_image_scale"],
                drop_old=False,
            )
            pipeline += KeyDeviceMover(["original_image"])

        if self.confmap_config is not None:
            # Infer colorspace preprocessing if not explicit.
            if not (
                self.confmap_config.data.preprocessing.ensure_rgb
                or self.confmap_config.data.preprocessing.ensure_grayscale
            ):
                # Decide from the model's input channel count.
                if self.confmap_model.keras_model.inputs[0].shape[-1] == 1:
                    self.confmap_config.data.preprocessing.ensure_grayscale = True
                else:
                    self.confmap_config.data.preprocessing.ensure_rgb = True
            pipeline += Normalizer.from_config(
                self.confmap_config.data.preprocessing, image_key="full_image"
            )
            # When ground truth centroids are used (no centroid model), the
            # labeled instance points must be rescaled along with the image.
            points_key = "instances" if self.centroid_model is None else None
            pipeline += Resizer.from_config(
                self.confmap_config.data.preprocessing,
                points_key=points_key,
                image_key="full_image",
                scale_key="full_image_scale",
            )

        if self.centroid_model is not None:
            # Infer colorspace preprocessing if not explicit.
            if not (
                self.centroid_config.data.preprocessing.ensure_rgb
                or self.centroid_config.data.preprocessing.ensure_grayscale
            ):
                if self.centroid_model.keras_model.inputs[0].shape[-1] == 1:
                    self.centroid_config.data.preprocessing.ensure_grayscale = True
                else:
                    self.centroid_config.data.preprocessing.ensure_rgb = True
            pipeline += Normalizer.from_config(
                self.centroid_config.data.preprocessing, image_key="image"
            )
            pipeline += Resizer.from_config(
                self.centroid_config.data.preprocessing, points_key=None
            )

            # Predict centroids using model.
            pipeline += KerasModelPredictor(
                keras_model=self.centroid_model.keras_model,
                model_input_keys="image",
                model_output_keys="predicted_centroid_confidence_maps",
            )
            pipeline += LocalPeakFinder(
                confmaps_stride=self.centroid_model.heads[0].output_stride,
                peak_threshold=self.peak_threshold,
                confmaps_key="predicted_centroid_confidence_maps",
                peaks_key="predicted_centroids",
                peak_vals_key="predicted_centroid_confidences",
                peak_sample_inds_key="predicted_centroid_sample_inds",
                peak_channel_inds_key="predicted_centroid_channel_inds",
                keep_confmaps=False,
            )

            # Drop frames where no centroid was detected.
            pipeline += LambdaFilter(
                filter_fn=lambda ex: len(ex["predicted_centroids"]) > 0
            )

            if self.confmap_config is not None:
                crop_size = self.confmap_config.data.instance_cropping.crop_size
            else:
                # No confmap config: derive crop size from the labeled data.
                crop_size = sleap.nn.data.instance_cropping.find_instance_crop_size(
                    data_provider.labels
                )
            pipeline += PredictedInstanceCropper(
                crop_width=crop_size,
                crop_height=crop_size,
                centroids_key="predicted_centroids",
                centroid_confidences_key="predicted_centroid_confidences",
                full_image_key="full_image",
                full_image_scale_key="full_image_scale",
                keep_instances_gt=self.confmap_model is None,
                other_keys_to_keep=["original_image"] if keep_original_image else None,
            )
            if keep_original_image:
                pipeline += KeyDeviceMover(["original_image"])
        else:
            # Generate ground truth centroids and crops.
            anchor_part = self.confmap_config.data.instance_cropping.center_on_part
            pipeline += InstanceCentroidFinder(
                center_on_anchor_part=anchor_part is not None,
                anchor_part_names=anchor_part,
                skeletons=data_provider.labels.skeletons,
            )
            pipeline += KeyRenamer(
                old_key_names=["full_image", "full_image_scale"],
                new_key_names=["image", "scale"],
                drop_old=True,
            )
            pipeline += InstanceCropper(
                crop_width=self.confmap_config.data.instance_cropping.crop_size,
                crop_height=self.confmap_config.data.instance_cropping.crop_size,
                mock_centroid_confidence=True,
            )

        if self.confmap_model is not None:
            # Predict confidence maps using model.
            if self.batch_size > 1:
                pipeline += sleap.nn.data.pipelines.Batcher(
                    batch_size=self.batch_size, drop_remainder=False
                )
            pipeline += KerasModelPredictor(
                keras_model=self.confmap_model.keras_model,
                model_input_keys="instance_image",
                model_output_keys="predicted_instance_confidence_maps",
            )
            if self.batch_size > 1:
                pipeline += sleap.nn.data.pipelines.Unbatcher()
            pipeline += GlobalPeakFinder(
                confmaps_key="predicted_instance_confidence_maps",
                peaks_key="predicted_center_instance_points",
                confmaps_stride=self.confmap_model.heads[0].output_stride,
                peak_threshold=self.peak_threshold,
                integral=self.integral_refinement,
                integral_patch_size=self.integral_patch_size,
                keep_confmaps=False,
            )
        else:
            # Generate ground truth instance points.
            pipeline += MockGlobalPeakFinder(
                all_peaks_in_key="instances",
                peaks_out_key="predicted_center_instance_points",
                peak_vals_key="predicted_center_instance_confidences",
                keep_confmaps=False,
            )

        # Keys required downstream for frame assembly (and tracking).
        # NOTE(review): "center_instance_ind" appears twice; harmless but redundant.
        keep_keys = [
            "bbox",
            "center_instance_ind",
            "centroid",
            "centroid_confidence",
            "scale",
            "video_ind",
            "frame_ind",
            "center_instance_ind",
            "predicted_center_instance_points",
            "predicted_center_instance_confidences",
        ]
        if keep_original_image:
            keep_keys.append("original_image")
        pipeline += KeyFilter(keep_keys=keep_keys)

        # Map crop-local peaks back to full-frame coordinates under new keys.
        pipeline += PredictedCenterInstanceNormalizer(
            centroid_key="centroid",
            centroid_confidence_key="centroid_confidence",
            peaks_key="predicted_center_instance_points",
            peak_confidences_key="predicted_center_instance_confidences",
            new_centroid_key="predicted_centroid",
            new_centroid_confidence_key="predicted_centroid_confidence",
            new_peaks_key="predicted_instance",
            new_peak_confidences_key="predicted_instance_confidences",
        )

        self.pipeline = pipeline
        return pipeline

    def predict_generator(self, data_provider: Provider):
        """Return a generator over pipeline output examples for the provider."""
        if self.pipeline is None:
            if self.centroid_config is not None and self.confmap_config is not None:
                self.make_pipeline()
            else:
                # Pass in data provider when mocking one of the models.
                self.make_pipeline(data_provider=data_provider)
        self.pipeline.providers = [data_provider]
        # Yield each example from dataset, catching and logging exceptions
        return safely_generate(self.pipeline.make_dataset())

    def make_labeled_frames_from_generator(self, generator, data_provider):
        """Group examples by (video, frame) and convert them to LabeledFrames."""
        grouped_generator = group_examples_iter(generator)

        # Take the skeleton from whichever config is available.
        if self.confmap_config is not None:
            skeleton = self.confmap_config.data.labels.skeletons[0]
        else:
            skeleton = self.centroid_config.data.labels.skeletons[0]

        def make_lfs(video_ind, frame_ind, frame_examples):
            # Build the LabeledFrame(s) for a single grouped frame.
            return make_grouped_labeled_frame(
                video_ind=video_ind,
                frame_ind=frame_ind,
                frame_examples=frame_examples,
                videos=data_provider.videos,
                skeleton=skeleton,
                image_key="original_image",
                points_key="predicted_instance",
                point_confidences_key="predicted_instance_confidences",
                instance_score_key="predicted_centroid_confidence",
                tracker=self.tracker,
            )

        predicted_frames = []
        for (video_ind, frame_ind), grouped_examples in grouped_generator:
            predicted_frames.extend(make_lfs(video_ind, frame_ind, grouped_examples))

        if self.tracker:
            self.tracker.final_pass(predicted_frames)

        return predicted_frames

    def predict(
        self,
        data_provider: Provider,
        make_instances: bool = True,
        make_labels: bool = False,
    ):
        """Run inference, returning LabeledFrames, Labels, or raw examples.

        Args:
            data_provider: Provider (or a sleap.Labels/sleap.Video, which is
                wrapped in the appropriate reader).
            make_instances: If True, convert examples to LabeledFrames.
            make_labels: If True, wrap the LabeledFrames in a Labels object.
        """
        t0_gen = time.time()
        # Accept raw Labels/Video objects as providers for convenience.
        if isinstance(data_provider, sleap.Labels):
            data_provider = LabelsReader(data_provider)
        elif isinstance(data_provider, sleap.Video):
            data_provider = VideoReader(data_provider)
        generator = self.predict_generator(data_provider)
        if make_instances or make_labels:
            lfs = self.make_labeled_frames_from_generator(generator, data_provider)
            elapsed = time.time() - t0_gen
            logger.info(
                f"Predicted {len(lfs)} labeled frames in {elapsed:.3f} secs [{len(lfs)/elapsed:.1f} FPS]"
            )
            if make_labels:
                return sleap.Labels(lfs)
            else:
                return lfs
        else:
            examples = list(generator)
            elapsed = time.time() - t0_gen
            logger.info(
                f"Predicted {len(examples)} examples in {elapsed:.3f} secs [{len(examples)/elapsed:.1f} FPS]"
            )
            return examples
@attr.s(auto_attribs=True)
class BottomupPredictor(Predictor):
    """Predictor for bottomup models (multi-instance confmaps + part affinity fields)."""

    bottomup_config: TrainingJobConfig
    bottomup_model: Model
    # Built lazily by make_pipeline().
    pipeline: Optional[Pipeline] = attr.ib(default=None, init=False)
    # Not set by __init__; assigned externally (e.g., by the CLI in main()).
    tracker: Optional[Tracker] = attr.ib(default=None, init=False)
    peak_threshold: float = 0.2

    @classmethod
    def from_trained_models(cls, bottomup_model_path: Text) -> "BottomupPredictor":
        """Create predictor from saved models."""
        # Load bottomup model.
        bottomup_config = TrainingJobConfig.load_json(bottomup_model_path)
        bottomup_keras_model_path = get_keras_model_path(bottomup_model_path)
        bottomup_model = Model.from_config(bottomup_config.model)
        # compile=False: inference only, no training state needed.
        bottomup_model.keras_model = tf.keras.models.load_model(
            bottomup_keras_model_path, compile=False
        )
        return cls(bottomup_config=bottomup_config, bottomup_model=bottomup_model)

    def make_pipeline(self, data_provider: Optional[Provider] = None) -> Pipeline:
        """Assemble the bottomup inference pipeline.

        Returns:
            The constructed Pipeline (also stored on ``self.pipeline``).
        """
        pipeline = Pipeline()
        if data_provider is not None:
            pipeline.providers = [data_provider]
        # Infer colorspace preprocessing if not explicit.
        if not (
            self.bottomup_config.data.preprocessing.ensure_rgb
            or self.bottomup_config.data.preprocessing.ensure_grayscale
        ):
            # Decide from the model's input channel count.
            if self.bottomup_model.keras_model.inputs[0].shape[-1] == 1:
                self.bottomup_config.data.preprocessing.ensure_grayscale = True
            else:
                self.bottomup_config.data.preprocessing.ensure_rgb = True
        pipeline += Normalizer.from_config(self.bottomup_config.data.preprocessing)
        pipeline += Resizer.from_config(
            self.bottomup_config.data.preprocessing,
            keep_full_image=False,
            points_key=None,
        )
        pipeline += Prefetcher()

        # Single forward pass produces both confmaps and PAFs.
        pipeline += KerasModelPredictor(
            keras_model=self.bottomup_model.keras_model,
            model_input_keys="image",
            model_output_keys=[
                "predicted_confidence_maps",
                "predicted_part_affinity_fields",
            ],
        )
        pipeline += LocalPeakFinder(
            confmaps_stride=self.bottomup_model.heads[0].output_stride,
            peak_threshold=self.peak_threshold,
            confmaps_key="predicted_confidence_maps",
            peaks_key="predicted_peaks",
            peak_vals_key="predicted_peak_confidences",
            peak_sample_inds_key="predicted_peak_sample_inds",
            peak_channel_inds_key="predicted_peak_channel_inds",
            keep_confmaps=False,
        )

        # Drop frames with no detected peaks.
        pipeline += LambdaFilter(filter_fn=lambda ex: len(ex["predicted_peaks"]) > 0)

        # Group peaks into instances by scoring candidate connections
        # against the predicted part affinity fields.
        pipeline += PartAffinityFieldInstanceGrouper.from_config(
            self.bottomup_config.model.heads.multi_instance,
            max_edge_length=128,
            min_edge_score=0.05,
            n_points=10,
            min_instance_peaks=0,
            peaks_key="predicted_peaks",
            peak_scores_key="predicted_peak_confidences",
            channel_inds_key="predicted_peak_channel_inds",
            pafs_key="predicted_part_affinity_fields",
            predicted_instances_key="predicted_instances",
            predicted_peak_scores_key="predicted_peak_scores",
            predicted_instance_scores_key="predicted_instance_scores",
            keep_pafs=False,
        )

        # Keys required downstream for frame assembly (and tracking).
        keep_keys = [
            "scale",
            "video_ind",
            "frame_ind",
            "predicted_instances",
            "predicted_peak_scores",
            "predicted_instance_scores",
        ]
        if self.tracker and self.tracker.uses_image:
            keep_keys.append("image")
        pipeline += KeyFilter(keep_keys=keep_keys)

        # Undo the input resize so points are in full-frame coordinates.
        pipeline += PointsRescaler(
            points_key="predicted_instances", scale_key="scale", invert=True
        )

        self.pipeline = pipeline
        return pipeline

    def make_labeled_frames_from_generator(self, generator, data_provider):
        """Group examples by (video, frame) and convert them to LabeledFrames."""
        grouped_generator = group_examples_iter(generator)
        skeleton = self.bottomup_config.data.labels.skeletons[0]

        def make_lfs(video_ind, frame_ind, frame_examples):
            # Build the LabeledFrame(s) for a single grouped frame.
            return make_grouped_labeled_frame(
                video_ind=video_ind,
                frame_ind=frame_ind,
                frame_examples=frame_examples,
                videos=data_provider.videos,
                skeleton=skeleton,
                image_key="image",
                points_key="predicted_instances",
                point_confidences_key="predicted_peak_scores",
                instance_score_key="predicted_instance_scores",
                tracker=self.tracker,
            )

        predicted_frames = []
        for (video_ind, frame_ind), grouped_examples in grouped_generator:
            predicted_frames.extend(make_lfs(video_ind, frame_ind, grouped_examples))

        if self.tracker:
            self.tracker.final_pass(predicted_frames)

        return predicted_frames

    def predict_generator(self, data_provider: Provider):
        """Return a generator over pipeline output examples for the provider."""
        if self.pipeline is None:
            self.make_pipeline()
        self.pipeline.providers = [data_provider]
        # Yield each example from dataset, catching and logging exceptions
        return safely_generate(self.pipeline.make_dataset())

    def predict(
        self,
        data_provider: Provider,
        make_instances: bool = True,
        make_labels: bool = False,
    ):
        """Run inference, returning LabeledFrames, Labels, or raw examples."""
        # Accept raw Labels/Video objects as providers for convenience.
        if isinstance(data_provider, sleap.Labels):
            data_provider = LabelsReader(data_provider)
        elif isinstance(data_provider, sleap.Video):
            data_provider = VideoReader(data_provider)
        generator = self.predict_generator(data_provider)
        if make_instances or make_labels:
            lfs = self.make_labeled_frames_from_generator(generator, data_provider)
            if make_labels:
                return sleap.Labels(lfs)
            else:
                return lfs
        return list(generator)
@attr.s(auto_attribs=True)
class SingleInstancePredictor(Predictor):
    """Predictor for single-instance confidence map models."""

    confmap_config: TrainingJobConfig
    confmap_model: Model
    # Built lazily by make_pipeline().
    pipeline: Optional[Pipeline] = attr.ib(default=None, init=False)
    peak_threshold: float = 0.2
    integral_refinement: bool = True
    integral_patch_size: int = 5

    @classmethod
    def from_trained_models(
        cls,
        confmap_model_path: Text,
        peak_threshold: float = 0.2,
        integral_refinement: bool = True,
        integral_patch_size: int = 5,
    ) -> "SingleInstancePredictor":
        """Create predictor from saved models."""
        # Load confmap model.
        confmap_config = TrainingJobConfig.load_json(confmap_model_path)
        confmap_keras_model_path = get_keras_model_path(confmap_model_path)
        confmap_model = Model.from_config(confmap_config.model)
        # compile=False: inference only, no training state needed.
        confmap_model.keras_model = tf.keras.models.load_model(
            confmap_keras_model_path, compile=False
        )
        return cls(
            confmap_config=confmap_config,
            confmap_model=confmap_model,
            peak_threshold=peak_threshold,
            integral_refinement=integral_refinement,
            integral_patch_size=integral_patch_size,
        )

    def make_pipeline(self, data_provider: Optional[Provider] = None) -> Pipeline:
        """Assemble the single-instance inference pipeline.

        Returns:
            The constructed Pipeline (also stored on ``self.pipeline``).
        """
        pipeline = Pipeline()
        if data_provider is not None:
            pipeline.providers = [data_provider]
        # Infer colorspace preprocessing if not explicit.
        if not (
            self.confmap_config.data.preprocessing.ensure_rgb
            or self.confmap_config.data.preprocessing.ensure_grayscale
        ):
            # Decide from the model's input channel count.
            if self.confmap_model.keras_model.inputs[0].shape[-1] == 1:
                self.confmap_config.data.preprocessing.ensure_grayscale = True
            else:
                self.confmap_config.data.preprocessing.ensure_rgb = True
        pipeline += Normalizer.from_config(self.confmap_config.data.preprocessing)
        pipeline += Resizer.from_config(
            self.confmap_config.data.preprocessing, points_key=None
        )
        pipeline += Prefetcher()
        pipeline += KerasModelPredictor(
            keras_model=self.confmap_model.keras_model,
            model_input_keys="image",
            model_output_keys="predicted_instance_confidence_maps",
        )
        # Single instance per frame: take the global maximum per channel.
        pipeline += GlobalPeakFinder(
            confmaps_key="predicted_instance_confidence_maps",
            peaks_key="predicted_instance",
            peak_vals_key="predicted_instance_confidences",
            confmaps_stride=self.confmap_model.heads[0].output_stride,
            peak_threshold=self.peak_threshold,
            integral=self.integral_refinement,
            integral_patch_size=self.integral_patch_size,
        )
        # Keys required downstream for frame assembly.
        pipeline += KeyFilter(
            keep_keys=[
                "scale",
                "video_ind",
                "frame_ind",
                "predicted_instance",
                "predicted_instance_confidences",
            ]
        )
        # Undo the input resize so points are in full-frame coordinates.
        pipeline += PointsRescaler(
            points_key="predicted_instance", scale_key="scale", invert=True
        )
        self.pipeline = pipeline
        return pipeline

    def make_labeled_frames_from_generator(self, generator, data_provider):
        """Group examples by (video, frame) and convert them to LabeledFrames."""
        grouped_generator = group_examples_iter(generator)
        skeleton = self.confmap_config.data.labels.skeletons[0]

        def make_lfs(video_ind, frame_ind, frame_examples):
            # Build the LabeledFrame(s) for a single grouped frame.
            # No tracker/instance score key: one instance per frame, scores
            # fall back to summed point confidences.
            return make_grouped_labeled_frame(
                video_ind=video_ind,
                frame_ind=frame_ind,
                frame_examples=frame_examples,
                videos=data_provider.videos,
                skeleton=skeleton,
                points_key="predicted_instance",
                point_confidences_key="predicted_instance_confidences",
            )

        predicted_frames = []
        for (video_ind, frame_ind), grouped_examples in grouped_generator:
            predicted_frames.extend(make_lfs(video_ind, frame_ind, grouped_examples))

        return predicted_frames

    def predict_generator(self, data_provider: Provider):
        """Return a generator over pipeline output examples for the provider."""
        if self.pipeline is None:
            self.make_pipeline()
        self.pipeline.providers = [data_provider]
        # Yield each example from dataset, catching and logging exceptions
        return safely_generate(self.pipeline.make_dataset())

    def predict(
        self,
        data_provider: Provider,
        make_instances: bool = True,
        make_labels: bool = False,
    ):
        """Run inference, returning LabeledFrames, Labels, or raw examples."""
        # Accept raw Labels/Video objects as providers for convenience.
        if isinstance(data_provider, sleap.Labels):
            data_provider = LabelsReader(data_provider)
        elif isinstance(data_provider, sleap.Video):
            data_provider = VideoReader(data_provider)
        generator = self.predict_generator(data_provider)
        if make_instances or make_labels:
            lfs = self.make_labeled_frames_from_generator(generator, data_provider)
            if make_labels:
                return sleap.Labels(lfs)
            else:
                return lfs
        return list(generator)
# Mapping from CLI scope name to predictor class. The scope names prefix the
# per-predictor CLI args registered in make_cli_parser() (e.g.
# "--topdown.peak_threshold").
CLI_PREDICTORS = {
    "topdown": TopdownPredictor,
    "bottomup": BottomupPredictor,
    "single": SingleInstancePredictor,
}
def make_cli_parser():
    """Build the argparse parser for the inference CLI.

    Registers global pipeline args, per-predictor args scoped by the names in
    CLI_PREDICTORS, and tracker args scoped under "tracking".
    """
    import argparse

    from sleap.util import frame_list

    parser = argparse.ArgumentParser()

    # Add args for entire pipeline
    parser.add_argument(
        "video_path", type=str, nargs="?", default="", help="Path to video file"
    )
    parser.add_argument(
        "-m",
        "--model",
        dest="models",
        action="append",
        help="Path to trained model directory (with training_config.json). "
        "Multiple models can be specified, each preceded by --model.",
    )
    parser.add_argument(
        "--frames",
        type=frame_list,
        default="",
        help="List of frames to predict. Either comma separated list (e.g. 1,2,3) or "
        "a range separated by hyphen (e.g. 1-3, for 1,2,3). (default is entire video)",
    )
    parser.add_argument(
        "--only-labeled-frames",
        action="store_true",
        default=False,
        help="Only run inference on labeled frames (when running on labels dataset file).",
    )
    parser.add_argument(
        "--only-suggested-frames",
        action="store_true",
        default=False,
        help="Only run inference on suggested frames (when running on labels dataset file).",
    )
    parser.add_argument(
        "-o",
        "--output",
        type=str,
        default=None,
        help="The output filename to use for the predicted data.",
    )
    parser.add_argument(
        "--labels",
        type=str,
        default=None,
        help="Path to labels dataset file (for inference on multiple videos or for re-tracking pre-existing predictions).",
    )

    # TODO: better video parameters
    parser.add_argument(
        "--video.dataset", type=str, default="", help="The dataset for HDF5 videos."
    )
    parser.add_argument(
        "--video.input_format",
        type=str,
        default="",
        help="The input_format for HDF5 videos.",
    )

    # Device selection flags are mutually exclusive.
    device_group = parser.add_mutually_exclusive_group(required=False)
    device_group.add_argument(
        "--cpu",
        action="store_true",
        help="Run inference only on CPU. If not specified, will use available GPU.",
    )
    device_group.add_argument(
        "--first-gpu",
        action="store_true",
        help="Run inference on the first GPU, if available.",
    )
    device_group.add_argument(
        "--last-gpu",
        action="store_true",
        help="Run inference on the last GPU, if available.",
    )
    device_group.add_argument(
        "--gpu", type=int, default=0, help="Run inference on the i-th GPU specified."
    )

    # Add args for each predictor class
    for predictor_name, predictor_class in CLI_PREDICTORS.items():
        if "peak_threshold" in attr.fields_dict(predictor_class):
            # get the default value to show in help string, although we'll
            # use None as default so that unspecified vals won't be passed to
            # builder.
            default_val = attr.fields_dict(predictor_class)["peak_threshold"].default
            parser.add_argument(
                f"--{predictor_name}.peak_threshold",
                type=float,
                default=None,
                help=f"Threshold to use when finding peaks in {predictor_class.__name__} (default: {default_val}).",
            )
        if "batch_size" in attr.fields_dict(predictor_class):
            default_val = attr.fields_dict(predictor_class)["batch_size"].default
            parser.add_argument(
                f"--{predictor_name}.batch_size",
                type=int,
                default=None,
                help=f"Batch size to use for model inference in {predictor_class.__name__} (default: {default_val}).",
            )

    # Add args for tracking
    Tracker.add_cli_parser_args(parser, arg_scope="tracking")

    parser.add_argument(
        "--test-pipeline",
        default=False,
        action="store_true",
        help="Test pipeline construction without running anything.",
    )

    return parser
def make_video_readers_from_cli(args) -> List[VideoReader]:
    """Build one VideoReader per video requested by the CLI args.

    Either a single reader for ``args.video_path``, or one reader per video in
    the ``args.labels`` file (optionally restricted to labeled or suggested
    frames).

    Raises:
        ValueError: If neither a video path nor a labels path was given.
    """
    cli_vars = vars(args)

    if args.video_path:
        # TODO: better support for video params
        reader = VideoReader.from_filepath(
            filename=args.video_path,
            example_indices=args.frames,
            dataset=cli_vars.get("video.dataset"),
            input_format=cli_vars.get("video.input_format"),
        )
        return [reader]

    if args.labels:
        labels = sleap.Labels.load_file(args.labels)
        # Only needed when restricting to user-labeled frames.
        labeled_frames = labels.user_labeled_frames if args.only_labeled_frames else []

        readers = []
        for video in labels.videos:
            if args.only_labeled_frames:
                indices = [lf.frame_idx for lf in labeled_frames if lf.video == video]
                readers.append(VideoReader(video=video, example_indices=indices))
            elif args.only_suggested_frames:
                readers.append(
                    VideoReader(
                        video=video, example_indices=labels.get_video_suggestions(video)
                    )
                )
            else:
                readers.append(VideoReader(video=video))
        return readers

    raise ValueError("You must specify either video_path or labels dataset path.")
def make_predictor_from_paths(paths) -> Predictor:
    """Builds predictor object from a list of model paths."""
    heads_to_paths = find_heads_for_model_paths(paths)
    return make_predictor_from_models(heads_to_paths)
def find_heads_for_model_paths(paths) -> Dict[str, str]:
    """Given list of models paths, returns dict with path keyed by head name."""
    if paths is None:
        return dict()

    trained_model_paths = dict()
    for model_path in paths:
        # The config tells us what the model predicts (its head type).
        cfg = TrainingJobConfig.load_json(model_path)
        head_name = cfg.model.heads.which_oneof_attrib_name()

        # If path is to config file json, then get the path to parent dir
        if model_path.endswith(".json"):
            model_path = os.path.dirname(model_path)

        trained_model_paths[head_name] = model_path

    return trained_model_paths
def make_predictor_from_models(
    trained_model_paths: Dict[str, str],
    labels_path: Optional[str] = None,
    policy_args: Optional[dict] = None,
) -> Predictor:
    """Given dict of paths keyed by head name, returns appropriate predictor.

    Raises:
        ValueError: If the combination of heads does not match any predictor.
    """

    def scoped_args(scope):
        # Per-predictor kwargs from the CLI, empty when unspecified.
        if policy_args is None:
            return dict()
        return policy_args.get(scope, dict())

    if "multi_instance" in trained_model_paths:
        return BottomupPredictor.from_trained_models(
            trained_model_paths["multi_instance"], **scoped_args("bottomup")
        )

    if "single_instance" in trained_model_paths:
        return SingleInstancePredictor.from_trained_models(
            confmap_model_path=trained_model_paths["single_instance"],
            **scoped_args("single"),
        )

    if "centroid" in trained_model_paths and "centered_instance" in trained_model_paths:
        return TopdownPredictor.from_trained_models(
            centroid_model_path=trained_model_paths["centroid"],
            confmap_model_path=trained_model_paths["centered_instance"],
            **scoped_args("topdown"),
        )

    if not trained_model_paths and labels_path:
        # No models at all: fall back to returning existing labels.
        return MockPredictor.from_trained_models(labels_path=labels_path)

    raise ValueError(
        f"Unable to run inference with {list(trained_model_paths.keys())} heads."
    )
def make_tracker_from_cli(policy_args):
    """Build a Tracker from the "tracking" scope of the CLI policy args.

    Returns:
        A Tracker instance, or None when no tracking args were given.
    """
    if "tracking" not in policy_args:
        return None
    return Tracker.make_tracker_by_name(**policy_args["tracking"])
def save_predictions_from_cli(args, predicted_frames, prediction_metadata=None):
    """Save predicted frames to a .slp file, deriving the path if not given.

    The output path falls back from ``args.output`` to a ``.predictions.slp``
    file alongside the input video or labels file.
    """
    from sleap import Labels

    if args.output:
        output_path = args.output
    else:
        source_path = args.video_path or args.labels
        if source_path:
            # Save next to the input, with a .predictions.slp suffix.
            output_path = os.path.join(
                os.path.dirname(source_path),
                os.path.basename(source_path) + ".predictions.slp",
            )
        else:
            # We shouldn't ever get here but if we do, just save in working dir.
            output_path = "predictions.slp"

    labels = Labels(labeled_frames=predicted_frames, provenance=prediction_metadata)

    print(f"Saving: {output_path}")
    Labels.save_file(labels, output_path)
def main():
    """CLI for running inference."""
    parser = make_cli_parser()
    args, _ = parser.parse_known_args()
    print(args)

    # Configure compute devices before any TF work happens.
    if args.cpu or not sleap.nn.system.is_gpu_system():
        sleap.nn.system.use_cpu_only()
    else:
        if args.first_gpu:
            sleap.nn.system.use_first_gpu()
        elif args.last_gpu:
            sleap.nn.system.use_last_gpu()
        else:
            sleap.nn.system.use_gpu(args.gpu)
    sleap.nn.system.disable_preallocation()

    print("System:")
    sleap.nn.system.summary()

    video_readers = make_video_readers_from_cli(args)

    # Find the specified models
    model_paths_by_head = find_heads_for_model_paths(args.models)

    # Make a scoped dictionary with args specified from cli
    policy_args = util.make_scoped_dictionary(vars(args), exclude_nones=True)

    # Create appropriate predictor given these models
    predictor = make_predictor_from_models(
        model_paths_by_head, labels_path=args.labels, policy_args=policy_args
    )

    # Make the tracker
    tracker = make_tracker_from_cli(policy_args)
    predictor.tracker = tracker

    if args.test_pipeline:
        # Dry run: print the pipeline structure and exit without inference.
        print()
        print(policy_args)
        print()
        print(predictor)
        print()

        predictor.make_pipeline()
        print("===pipeline transformers===")
        print()
        for transformer in predictor.pipeline.transformers:
            print(transformer.__class__.__name__)
            print(f"\t-> {transformer.input_keys}")
            print(f"\t {transformer.output_keys} ->")
            print()

        print("--test-pipeline arg set so stopping here.")
        return

    # Run inference!
    t0 = time.time()
    predicted_frames = []
    for video_reader in video_readers:
        video_predicted_frames = predictor.predict(video_reader)
        predicted_frames.extend(video_predicted_frames)

    # Create dictionary of metadata we want to save with predictions
    prediction_metadata = dict()
    for head, path in model_paths_by_head.items():
        prediction_metadata[f"model.{head}.path"] = os.path.abspath(path)
    for scope in policy_args.keys():
        for key, val in policy_args[scope].items():
            prediction_metadata[f"{scope}.{key}"] = val
    prediction_metadata["video.path"] = args.video_path
    prediction_metadata["sleap.version"] = sleap.__version__

    save_predictions_from_cli(args, predicted_frames, prediction_metadata)
    print(f"Total Time: {time.time() - t0}")
if __name__ == "__main__":
main()