Source code for sleap.info.metrics

"""
Module for producing prediction metrics for SLEAP datasets.
"""
from inspect import signature
import numpy as np
from scipy.optimize import linear_sum_assignment
from typing import Callable, List, Optional, Union, Tuple

from sleap.instance import Instance, PredictedInstance
from sleap.io.dataset import Labels


[docs]def matched_instance_distances( labels_gt: Labels, labels_pr: Labels, match_lists_function: Callable, frame_range: Optional[range] = None, ) -> Tuple[List[int], np.ndarray, np.ndarray, np.ndarray]: """ Distances between ground truth and predicted nodes over a set of frames. Args: labels_gt: the `Labels` object with ground truth data labels_pr: the `Labels` object with predicted data match_lists_function: function for determining corresponding instances Takes two lists of instances and returns "sorted" lists. frame_range (optional): range of frames for which to compare data If None, we compare every frame in labels_gt with corresponding frame in labels_pr. Returns: Tuple: * frame indices map: instance idx (for other matrices) -> frame idx * distance matrix: (instances * nodes) * ground truth points matrix: (instances * nodes * 2) * predicted points matrix: (instances * nodes * 2) """ frame_idxs = [] points_gt = [] points_pr = [] for lf_gt in labels_gt.find(labels_gt.videos[0]): frame_idx = lf_gt.frame_idx # Get instances from ground truth/predicted labels instances_gt = lf_gt.instances lfs_pr = labels_pr.find(labels_pr.videos[0], frame_idx=frame_idx) if len(lfs_pr): instances_pr = lfs_pr[0].instances else: instances_pr = [] # Sort ground truth and predicted instances. # We'll then compare points between corresponding items in lists. # We can use different "match" functions depending on what we want. sorted_gt, sorted_pr = match_lists_function(instances_gt, instances_pr) # Convert lists of instances to (instances, nodes, 2) matrices. # This allows match_lists_function to return data as either # a list of Instances or a (instances, nodes, 2) matrix. if type(sorted_gt[0]) != np.ndarray: sorted_gt = list_points_array(sorted_gt) if type(sorted_pr[0]) != np.ndarray: sorted_pr = list_points_array(sorted_pr) points_gt.append(sorted_gt) points_pr.append(sorted_pr) frame_idxs.extend([frame_idx] * len(sorted_gt)) # Convert arrays to numpy matrixes # instances * nodes * (x,y) points_gt = np.concatenate(points_gt) points_pr = np.concatenate(points_pr) # Calculate distances between corresponding nodes for all corresponding # ground truth and predicted instances. D = np.linalg.norm(points_gt - points_pr, axis=2) return frame_idxs, D, points_gt, points_pr
[docs]def match_instance_lists( instances_a: List[Union[Instance, PredictedInstance]], instances_b: List[Union[Instance, PredictedInstance]], cost_function: Callable, ) -> Tuple[ List[Union[Instance, PredictedInstance]], List[Union[Instance, PredictedInstance]] ]: """Sorts two lists of Instances to find best overall correspondence for a given cost function (e.g., total distance between points).""" pairwise_distance_matrix = calculate_pairwise_cost( instances_a, instances_b, cost_function ) match_a, match_b = linear_sum_assignment(pairwise_distance_matrix) sorted_a = list(map(lambda idx: instances_a[idx], match_a)) sorted_b = list(map(lambda idx: instances_b[idx], match_b)) return sorted_a, sorted_b
[docs]def calculate_pairwise_cost( instances_a: List[Union[Instance, PredictedInstance]], instances_b: List[Union[Instance, PredictedInstance]], cost_function: Callable, ) -> np.ndarray: """Calculate (a * b) matrix of pairwise costs using cost function.""" matrix_size = (len(instances_a), len(instances_b)) pairwise_cost_matrix = np.full(matrix_size, np.inf) for idx_a, inst_a in enumerate(instances_a): for idx_b, inst_b in enumerate(instances_b): # cost_function can either take a single input or two inputs # single input: ndarray of distances between corresponding nodes # two inputs: the pair of instances if len(signature(cost_function).parameters) == 1: point_dist_array = point_dist(inst_a, inst_b) cost = cost_function(point_dist_array) else: cost = cost_function(inst_a, inst_b) pairwise_cost_matrix[idx_a, idx_b] = cost return pairwise_cost_matrix
[docs]def match_instance_lists_nodewise( instances_a: List[Union[Instance, PredictedInstance]], instances_b: List[Union[Instance, PredictedInstance]], thresh: float = 5, ) -> Tuple[ List[Union[Instance, PredictedInstance]], List[Union[Instance, PredictedInstance]] ]: """For each node for each instance in the first list, pairs it with the closest corresponding node from *any* instance in the second list.""" node_count = len(instances_a[0].skeleton.nodes) b_points_arrays = list_points_array(instances_b) best_points_array = [] for inst_a in instances_a: # Calculate distance from nodes in A to nodes for each B dist_array = [] for inst_b in instances_b: dist_array.append(point_dist(inst_a, inst_b)) dist_array = np.stack(dist_array) # Construct matrix with closest node from any B closest_point_array = np.full((node_count, 2), np.nan) for node_idx in range(node_count): # Make sure there's some prediction for this node if any(~np.isnan(dist_array[:, node_idx])): best_idx = np.nanargmin(dist_array[:, node_idx]) # Ignore closest point if distance is beyond threshold if dist_array[best_idx, node_idx] <= thresh: closest_point_array[node_idx] = b_points_arrays[best_idx, node_idx] # Add matrix of points to compare against ground truth instance best_points_array.append(closest_point_array) return instances_a, best_points_array
[docs]def point_dist( inst_a: Union[Instance, PredictedInstance], inst_b: Union[Instance, PredictedInstance], ) -> np.ndarray: """Given two instances, returns array of distances for corresponding nodes.""" points_a = inst_a.points_array points_b = inst_b.points_array point_dist = np.linalg.norm(points_a - points_b, axis=1) return point_dist
[docs]def nodeless_point_dist( inst_a: Union[Instance, PredictedInstance], inst_b: Union[Instance, PredictedInstance], ) -> np.ndarray: """Given two instances, returns array of distances for closest points ignoring node identities.""" matrix_size = (len(inst_a.skeleton.nodes), len(inst_b.skeleton.nodes)) pairwise_distance_matrix = np.full(matrix_size, 0) points_a = inst_a.points_array points_b = inst_b.points_array # Calculate the distance between any pair of inst A and inst B points for idx_a in range(points_a.shape[0]): for idx_b in range(points_b.shape[0]): pair_distance = np.linalg.norm(points_a[idx_a] - points_b[idx_b]) if not np.isnan(pair_distance): pairwise_distance_matrix[idx_a, idx_b] = pair_distance # Match A and B points to sum of distances match_a, match_b = linear_sum_assignment(pairwise_distance_matrix) # Sort points by this match and calculate overall distance sorted_points_a = points_a[match_a, :] sorted_points_b = points_b[match_b, :] point_dist = np.linalg.norm(points_a - points_b, axis=1) return point_dist
[docs]def compare_instance_lists( instances_a: List[Union[Instance, PredictedInstance]], instances_b: List[Union[Instance, PredictedInstance]], ) -> np.ndarray: """Given two lists of corresponding Instances, returns (instances * nodes) matrix of distances between corresponding nodes.""" paired_points_array_distances = [] for inst_a, inst_b in zip(instances_a, instances_b): paired_points_array_distances.append(point_dist(inst_a, inst_b)) return np.stack(paired_points_array_distances)
[docs]def list_points_array( instances: List[Union[Instance, PredictedInstance]] ) -> np.ndarray: """Given list of Instances, returns (instances * nodes * 2) matrix.""" points_arrays = list(map(lambda inst: inst.points_array, instances)) return np.stack(points_arrays)
[docs]def point_match_count(dist_array: np.ndarray, thresh: float = 5) -> int: """Given an array of distances, returns number which are <= threshold.""" return np.sum(dist_array[~np.isnan(dist_array)] <= thresh)
[docs]def point_nonmatch_count(dist_array: np.ndarray, thresh: float = 5) -> int: """Given an array of distances, returns number which are not <= threshold.""" return dist_array.shape[0] - point_match_count(dist_array, thresh)
if __name__ == "__main__": labels_gt = Labels.load_json("tests/data/json_format_v1/centered_pair.json") labels_pr = Labels.load_json( "tests/data/json_format_v2/centered_pair_predictions.json" ) # OPTION 1 # Match each ground truth instance node to the closest corresponding node # from any predicted instance in the same frame. nodewise_matching_func = match_instance_lists_nodewise # OPTION 2 # Match each ground truth instance to a distinct predicted instance: # We want to maximize the number of "matching" points between instances, # where "match" means the points are within some threshold distance. # Note that each sorted list will be as long as the shorted input list. instwise_matching_func = lambda gt_list, pr_list: match_instance_lists( gt_list, pr_list, point_nonmatch_count ) # PICK THE FUNCTION inst_matching_func = nodewise_matching_func # inst_matching_func = instwise_matching_func # Calculate distances frame_idxs, D, points_gt, points_pr = matched_instance_distances( labels_gt, labels_pr, inst_matching_func ) # Show mean difference for each node node_names = labels_gt.skeletons[0].node_names for node_idx, node_name in enumerate(node_names): mean_d = np.nanmean(D[..., node_idx]) print(f"{node_name}\t\t{mean_d}")