# Source code for sleap.gui.suggestions

"""
Module for generating lists of suggested frames (for labeling or reviewing).
"""

import attr
import numpy as np
import random

from typing import List, Optional, Union

from sleap.io.video import Video
from sleap.info.feature_suggestions import (
    FeatureSuggestionPipeline,
    ParallelFeaturePipeline,
)

# Alias for the optional grouping key attached to a suggestion. In this module
# the sampling method uses the index of the suggestion's video in
# `labels.videos` as its group.
GroupType = int


@attr.s(slots=True)
class SuggestionFrame:
    """A single suggested frame: the video, the frame index, and an optional group."""

    # Video that contains the suggested frame.
    video: Video = attr.ib()
    # Index of the suggested frame within `video`.
    frame_idx: int = attr.ib()
    # Optional grouping key; None when the suggestion is not grouped.
    group: Optional[GroupType] = attr.ib(default=None)
class VideoFrameSuggestions(object):
    """
    Class for generating lists of suggested frames.

    Implements various algorithms as methods:
    * sample (either random or evenly spaced sample frames from each video)
    * image features (raw images/brisk -> pca -> k-means)
    * prediction_score (frames with number of instances below specified score)
    * velocity (frames where a node's displacement is large)

    Each algorithm method should accept `labels`; other parameters will be
    passed from the `params` dict given to :meth:`suggest`.
    """

    @classmethod
    def suggest(cls, params: dict, labels: "Labels" = None) -> List[SuggestionFrame]:
        """
        This is the main entry point for generating lists of suggested frames.

        Args:
            params: A dictionary with all params to control how we generate
                suggestions, minimally this will have a "method" key with the
                name of one of the class methods. Spaces in the method name
                are treated as underscores.
            labels: A `Labels` object for which we are generating suggestions.

        Returns:
            List of `SuggestionFrame` objects; empty if the method is unknown.
        """
        # map from method param value to corresponding class method
        method_functions = dict(
            sample=cls.basic_sample_suggestion_method,
            image_features=cls.image_feature_based_method,
            prediction_score=cls.prediction_score,
            velocity=cls.velocity,
        )

        method = str.replace(params["method"], " ", "_")
        if method_functions.get(method, None) is not None:
            # The whole params dict (including "method") is forwarded; each
            # method absorbs entries it doesn't use via **kwargs.
            return method_functions[method](labels=labels, **params)
        else:
            print(f"No {method} method found for generating suggestions.")
            return []

    # Functions corresponding to "method" param

    @classmethod
    def basic_sample_suggestion_method(
        cls, labels, per_video: int = 20, sampling_method: str = "random", **kwargs
    ):
        """Method to generate suggestions by sampling frames from each video.

        Args:
            labels: `Labels` whose videos are sampled.
            per_video: Maximum number of frames suggested per video. Videos
                with fewer frames than this contribute all of their frames;
                a non-positive value yields no suggestions.
            sampling_method: "stride" for evenly spaced frames starting at
                frame 0; any other value samples frames at random without
                replacement.
            **kwargs: Other entries of the `params` dict; ignored.

        Returns:
            List of `SuggestionFrame` objects whose `group` is the index of
            their video in `labels.videos`.
        """
        suggestions = []
        for group, video in enumerate(labels.videos):
            n_frames = video.frames
            # Clamp to the video length: random.sample() raises and the stride
            # step would be 0 (rejected by range()) when per_video > n_frames.
            n_samples = min(per_video, n_frames)
            if n_samples <= 0:
                continue
            if sampling_method == "stride":
                step = max(1, n_frames // per_video)
                vid_suggestions = list(range(0, n_frames, step))[:per_video]
            else:  # random sampling
                vid_suggestions = random.sample(range(n_frames), n_samples)

            suggestions.extend(
                cls.idx_list_to_frame_list(vid_suggestions, video, group)
            )
        return suggestions

    @classmethod
    def image_feature_based_method(
        cls,
        labels,
        per_video,
        sample_method,
        scale,
        merge_video_features,
        feature_type,
        pca_components,
        n_clusters,
        per_cluster,
        **kwargs,
    ):
        """
        Method to generate suggestions based on image features.

        This is a wrapper for `FeatureSuggestionPipeline` implemented in
        `sleap.info.feature_suggestions`. Optional `brisk_threshold`
        (default 80) and `vocab_size` (default 20) are read from `kwargs`.

        If `merge_video_features` is "across all videos", a single pipeline
        runs over all videos together; otherwise the pipeline is run
        separately for each video via `ParallelFeaturePipeline.run`.
        """
        brisk_threshold = kwargs.get("brisk_threshold", 80)
        vocab_size = kwargs.get("vocab_size", 20)

        pipeline = FeatureSuggestionPipeline(
            per_video=per_video,
            scale=scale,
            sample_method=sample_method,
            feature_type=feature_type,
            brisk_threshold=brisk_threshold,
            vocab_size=vocab_size,
            n_components=pca_components,
            n_clusters=n_clusters,
            per_cluster=per_cluster,
        )

        if merge_video_features == "across all videos":
            # Run single pipeline with all videos
            return pipeline.get_suggestion_frames(videos=labels.videos)
        else:
            # Run pipeline separately (in parallel) for each video
            suggestions = ParallelFeaturePipeline.run(pipeline, labels.videos)
            return suggestions

    @classmethod
    def prediction_score(cls, labels: "Labels", score_limit, instance_limit, **kwargs):
        """
        Method to generate suggestions for proofreading frames with low score.

        Args:
            labels: `Labels` to search.
            score_limit: Scores strictly below this count as low (coerced to
                float).
            instance_limit: Minimum number of low-scoring instances a frame
                needs in order to be suggested (coerced to int).
            **kwargs: Other entries of the `params` dict; ignored.

        Returns:
            List of `SuggestionFrame` objects (no group).
        """
        score_limit = float(score_limit)
        instance_limit = int(instance_limit)

        suggestions = []
        for video in labels.videos:
            suggestions.extend(
                cls._prediction_score_video(video, labels, score_limit, instance_limit)
            )
        return suggestions

    @classmethod
    def _prediction_score_video(
        cls, video: "Video", labels: "Labels", score_limit: float, instance_limit: int
    ):
        """Return suggestions for labeled frames of `video` with low scores.

        Each labeled frame keeps at most `instance_limit` of its lowest
        instance scores. Unfilled slots hold 100.0, presumably above any
        `score_limit` used in practice. A frame is suggested when at least
        `instance_limit` of its kept scores are below `score_limit`.
        """
        lfs = labels.find(video)
        n_lfs = len(lfs)

        # Every row is filled in the loop below, so no initialisation needed.
        idxs = np.empty(n_lfs, dtype="int")
        scores = np.full((n_lfs, instance_limit), 100.0, dtype="float")

        # Build matrix with scores for instances in frames
        for i, lf in enumerate(lfs):
            # Instances without a `score` attribute are skipped.
            frame_scores = [inst.score for inst in lf if hasattr(inst, "score")]

            # Just get the lowest scores
            if len(frame_scores) > instance_limit:
                frame_scores = sorted(frame_scores)[:instance_limit]

            scores[i, : len(frame_scores)] = frame_scores
            idxs[i] = lf.frame_idx

        # Count instances below <score_limit> in each frame
        low_instances = np.sum(scores < score_limit, axis=1)

        # Keep frames with at least <instance_limit> low scoring instances
        result = idxs[low_instances >= instance_limit].tolist()

        return cls.idx_list_to_frame_list(result, video)

    @classmethod
    def velocity(
        cls, labels: "Labels", node: Union[int, str], threshold: float, **kwargs
    ):
        """
        Finds frames for proofreading with high node velocity.

        Args:
            labels: `Labels` to search.
            node: Node name, or an index into the first skeleton's nodes. An
                out-of-range index (or no skeleton) falls back to "".
            threshold: Fraction of each video's displacement range; frames
                whose displacement exceeds the minimum by more than this
                fraction of the range are suggested (coerced to float).
            **kwargs: Other entries of the `params` dict; ignored.

        Returns:
            List of `SuggestionFrame` objects (no group).
        """
        # Coerce like prediction_score does, in case params carry strings.
        threshold = float(threshold)

        if isinstance(node, str):
            node_name = node
        else:
            try:
                # NOTE(review): this is an element of `skeleton.nodes`, which
                # may be a Node object rather than a str; confirm that
                # StatisticSeries accepts either form.
                node_name = labels.skeletons[0].nodes[node]
            except IndexError:
                node_name = ""

        suggestions = []
        for video in labels.videos:
            suggestions.extend(cls._velocity_video(video, labels, node_name, threshold))
        return suggestions

    @classmethod
    def _velocity_video(
        cls, video: "Video", labels: "Labels", node_name: str, threshold: float
    ):
        """Return suggestions for frames of `video` with large node displacement.

        A position in the displacement series is suggested when its value
        exceeds the series minimum by more than `threshold * (max - min)`.
        Assumes the series is 1-D with one value per frame index (TODO
        confirm against StatisticSeries).
        """
        from sleap.info.summary import StatisticSeries

        displacements = np.asarray(
            StatisticSeries(labels).get_primary_point_displacement_series(
                video=video, reduction="sum", primary_node=node_name
            )
        )
        # np.ptp/np.min raise on empty input; there is nothing to suggest.
        if displacements.size == 0:
            return []

        data_range = np.ptp(displacements)
        data_min = np.min(displacements)

        # flatnonzero always returns a 1-D index array. The previous
        # squeeze(argwhere(...)) collapsed a single match into a 0-d array,
        # which cannot be iterated.
        hits = np.flatnonzero(displacements - data_min > data_range * threshold)
        frame_idxs = [int(idx) for idx in hits]
        return cls.idx_list_to_frame_list(frame_idxs, video)

    # Utility functions

    @staticmethod
    def idx_list_to_frame_list(
        idx_list, video: "Video", group: Optional[GroupType] = None
    ) -> List[SuggestionFrame]:
        """Wrap each frame index in `idx_list` as a `SuggestionFrame` for `video`."""
        return [SuggestionFrame(video, frame_idx, group) for frame_idx in idx_list]
def demo_gui():
    """Interactive demo of the "suggestions" form.

    Loads the bundled test predictions and offers the first skeleton's node
    names as options for the "node" field. When the form's main action fires,
    it prints the params and then each generated suggestion's video filename,
    frame index and group.
    """
    from sleap.gui.dialogs.formbuilder import YamlFormWidget
    from sleap import Labels
    from PySide2.QtWidgets import QApplication

    demo_labels = Labels.load_file(
        "tests/data/json_format_v2/centered_pair_predictions.json"
    )
    field_options = {"node": demo_labels.skeletons[0].node_names}

    qt_app = QApplication()
    form = YamlFormWidget.from_name(
        "suggestions", title="Generate Suggestions", field_options_lists=field_options
    )

    def demo_suggestions(params):
        print(params)
        generated = VideoFrameSuggestions.suggest(params=params, labels=demo_labels)
        for item in generated:
            print(item.video.backend.filename, item.frame_idx, item.group)

    form.mainAction.connect(demo_suggestions)
    form.show()
    qt_app.exec_()


if __name__ == "__main__":
    demo_gui()