"""
Functions/classes used by multiple trackers.
Main types of functions:
1. Calculate pair-wise instance similarity; used for populating similarity/cost
matrix.
2. Pick matches based on cost matrix.
3. Other clean-up (e.g., cull instances, connect track breaks).
"""
import operator
from collections import defaultdict
from typing import List, Tuple, Optional, TypeVar, Callable
import attr
import numpy as np
from scipy.optimize import linear_sum_assignment
from sleap import PredictedInstance, Instance, Track
from sleap.nn import utils
InstanceType = TypeVar("InstanceType", Instance, PredictedInstance)
def instance_similarity(
    ref_instance: "InstanceType", query_instance: "InstanceType"
) -> float:
    """Computes similarity between instances.

    Each pair of corresponding points contributes exp(-squared distance);
    the sum (ignoring NaNs from missing points) is normalized by the number
    of visible points in the reference instance.
    """
    ref_points = ref_instance.points_array
    query_points = query_instance.points_array

    # Count reference points with no NaN coordinate.
    visible_count = np.sum(~np.isnan(ref_points).any(axis=1))

    # Per-point squared Euclidean distance between the two instances.
    squared_dists = ((query_points - ref_points) ** 2).sum(axis=1)

    return np.nansum(np.exp(-squared_dists)) / visible_count
def centroid_distance(
    ref_instance: "InstanceType", query_instance: "InstanceType", cache: dict = dict()
) -> float:
    """Returns the negative distance between the centroids of two instances.

    NOTE: the mutable default `cache` is deliberate — the dict is created
    once with the function so it persists between calls; without the cache
    this method is significantly slower than the other similarity functions.
    """
    # Memoize each instance's centroid on first sight.
    for inst in (ref_instance, query_instance):
        if inst not in cache:
            cache[inst] = inst.centroid

    # Negate so that closer centroids score as more similar.
    return -np.linalg.norm(cache[ref_instance] - cache[query_instance])
def instance_iou(
    ref_instance: "InstanceType", query_instance: "InstanceType", cache: dict = dict()
) -> float:
    """Computes IOU between bounding boxes of instances.

    NOTE: the mutable default `cache` is deliberate — it memoizes bounding
    boxes across calls for speed (same pattern as `centroid_distance`).
    """
    # Memoize each instance's bounding box on first sight.
    for inst in (ref_instance, query_instance):
        if inst not in cache:
            cache[inst] = inst.bounding_box

    return utils.compute_iou(cache[ref_instance], cache[query_instance])
def hungarian_matching(cost_matrix: np.ndarray) -> List[Tuple[int, int]]:
    """Wrapper for Hungarian matching algorithm in scipy.

    Returns the optimal assignment as a list of (row, column) index tuples.
    """
    matched_rows, matched_cols = linear_sum_assignment(cost_matrix)
    return [(row, col) for row, col in zip(matched_rows, matched_cols)]
def greedy_matching(cost_matrix: np.ndarray) -> List[Tuple[int, int]]:
    """Performs greedy bipartite matching.

    Edges are taken in ascending cost order; an edge is assigned if and only
    if neither its row nor its column has already been assigned. This is the
    same greedy rule as before, but tracking used rows/columns in sets makes
    each edge a constant-time check instead of rescanning and deleting from
    the remaining edge list (which was O(E^2) in the worst case).

    Args:
        cost_matrix: 2-D array of pairwise costs (lower is better).

    Returns:
        List of (row, column) index tuples for the assigned edges.
    """
    # Sort all edges by ascending cost.
    rows, cols = np.unravel_index(np.argsort(cost_matrix, axis=None), cost_matrix.shape)

    assignments = []
    used_rows, used_cols = set(), set()
    for row_ind, col_ind in zip(rows, cols):
        # Skip edges touching a node that is already matched.
        if row_ind in used_rows or col_ind in used_cols:
            continue
        assignments.append((row_ind, col_ind))
        used_rows.add(row_ind)
        used_cols.add(col_ind)

    return assignments
def nms_instances(
    instances, iou_threshold, target_count=None
) -> Tuple[List["PredictedInstance"], List["PredictedInstance"]]:
    """Partitions instances into (kept, removed) via non-maximum suppression.

    Bounding boxes and scores are gathered from the instances and passed to
    `nms_fast`; instances whose index is picked are kept, the rest removed.
    """
    boxes = np.array([inst.bounding_box for inst in instances])
    scores = np.array([inst.score for inst in instances])

    keep_inds = set(nms_fast(boxes, scores, iou_threshold, target_count))

    kept, removed = [], []
    for idx, inst in enumerate(instances):
        (kept if idx in keep_inds else removed).append(inst)
    return kept, removed
def nms_fast(boxes, scores, iou_threshold, target_count=None) -> List[int]:
    """Non-maximum suppression: greedily picks high-scoring, non-overlapping boxes.

    Adapted from
    https://www.pyimagesearch.com/2015/02/16/faster-non-maximum-suppression-python/

    Args:
        boxes: (n, 4) array of boxes as [x1, y1, x2, y2].
        scores: (n,) array of scores, one per box.
        iou_threshold: a box is suppressed when its overlap ratio with an
            already-picked higher-scoring box exceeds this value.
        target_count: if set, suppressed boxes are added back (best scores
            first) so that at least `target_count` boxes are returned when
            enough boxes exist.

    Returns:
        List of indices of the picked boxes.
    """
    # if there are no boxes, return an empty list
    if len(boxes) == 0:
        return []
    # if we already have fewer boxes than the target count, return all boxes
    if target_count and len(boxes) < target_count:
        return list(range(len(boxes)))
    # if the bounding box coordinates are integers, convert them to floats --
    # this is important since we'll be doing a bunch of divisions
    if boxes.dtype.kind == "i":
        boxes = boxes.astype("float")
    # initialize the list of picked indexes
    picked_idxs = []
    # init list of boxes removed by nms
    nms_idxs = []
    # grab the coordinates of the bounding boxes
    x1 = boxes[:, 0]
    y1 = boxes[:, 1]
    x2 = boxes[:, 2]
    y2 = boxes[:, 3]
    # compute the area of the bounding boxes and sort the bounding
    # boxes by their scores (ascending, so the best box is last)
    area = (x2 - x1 + 1) * (y2 - y1 + 1)
    idxs = np.argsort(scores)
    # keep looping while some indexes still remain in the indexes list
    while len(idxs) > 0:
        # the best remaining box is the last one in the sorted list
        picked_box_idx = idxs[-1]
        picked_idxs.append(picked_box_idx)
        # find the largest (x, y) coordinates for the start of
        # the bounding box and the smallest (x, y) coordinates
        # for the end of the bounding box
        xx1 = np.maximum(x1[picked_box_idx], x1[idxs[:-1]])
        yy1 = np.maximum(y1[picked_box_idx], y1[idxs[:-1]])
        xx2 = np.minimum(x2[picked_box_idx], x2[idxs[:-1]])
        yy2 = np.minimum(y2[picked_box_idx], y2[idxs[:-1]])
        # compute the width and height of the intersection
        w = np.maximum(0, xx2 - xx1 + 1)
        h = np.maximum(0, yy2 - yy1 + 1)
        # compute the ratio of overlap (relative to each remaining box's area)
        overlap = (w * h) / area[idxs[:-1]]
        # find boxes with overlap over threshold
        nms_for_new_box = np.where(overlap > iou_threshold)[0]
        nms_idxs.extend(list(idxs[nms_for_new_box]))
        # delete suppressed boxes plus the picked box (last in list)
        idxs = np.delete(idxs, nms_for_new_box)[:-1]
    # if we're below the target number of boxes, add some back
    if target_count and nms_idxs and len(picked_idxs) < target_count:
        # sort suppressed boxes by descending score so the best come back first
        nms_idxs.sort(key=lambda idx: -scores[idx])
        # BUG FIX: the deficit is target_count - len(picked_idxs); the
        # original computed len(picked_idxs) - target_count, which is
        # negative under the guard above and made the slice below return
        # the wrong number of boxes.
        add_back_count = min(len(nms_idxs), target_count - len(picked_idxs))
        picked_idxs.extend(nms_idxs[:add_back_count])
    # return the list of picked boxes
    return picked_idxs
def cull_instances(
    frames: List["LabeledFrame"],
    instance_count: int,
    iou_threshold: Optional[float] = None,
):
    """
    Removes instances from frames over instance per frame threshold.

    Args:
        frames: The list of `LabeledFrame` objects with predictions.
        instance_count: The maximum number of instances we want per frame.
        iou_threshold: Intersection over Union (IOU) threshold to use when
            removing overlapping instances over target count; if None, then
            only use score to determine which instances to remove.

    Returns:
        None; modifies frames in place.
    """
    if not frames:
        return

    frames.sort(key=lambda lf: lf.frame_idx)

    # (frame, instance) pairs marked for removal.
    to_remove = []

    for lf in frames:
        # Only frames over the per-frame threshold need culling.
        if len(lf.predicted_instances) <= instance_count:
            continue

        remaining = lf.predicted_instances

        # First pass: NMS removes overlapping instances down toward target.
        if iou_threshold:
            remaining, overlapping = nms_instances(
                remaining,
                iou_threshold=iou_threshold,
                target_count=instance_count,
            )
            to_remove.extend((lf, inst) for inst in overlapping)

        # Second pass: if still over target, drop the lowest-scoring extras.
        if len(remaining) > instance_count:
            by_score = sorted(remaining, key=operator.attrgetter("score"))
            to_remove.extend((lf, inst) for inst in by_score[:-instance_count])

    # Remove everything we marked above.
    for lf, inst in to_remove:
        lf.instances.remove(inst)
def cull_frame_instances(
    instances_list: List["InstanceType"],
    instance_count: int,
    iou_threshold: Optional[float] = None,
) -> Optional[List["InstanceType"]]:
    """
    Removes instances (for single frame) over instance per frame threshold.

    Args:
        instances_list: The list of instances for a single frame.
        instance_count: The maximum number of instances we want per frame.
        iou_threshold: Intersection over Union (IOU) threshold to use when
            removing overlapping instances over target count; if None, then
            only use score to determine which instances to remove.

    Returns:
        Updated list of instances (modified in place), or None when the
        input list is empty.

    Note:
        BUG FIX (docs): the return annotation previously claimed
        ``List["LabeledFrame"]``, but this function returns the culled
        instance list, not frames.
    """
    if not instances_list:
        return
    if len(instances_list) > instance_count:
        # List of instances which we'll pare down
        keep_instances = instances_list
        # Use NMS to remove overlapping instances over target count
        if iou_threshold:
            keep_instances, extra_instances = nms_instances(
                keep_instances,
                iou_threshold=iou_threshold,
                target_count=instance_count,
            )
            # Remove the extra instances
            for inst in extra_instances:
                instances_list.remove(inst)
        # Use lower score to remove instances over target count
        if len(keep_instances) > instance_count:
            # Sort by ascending score, get target number of instances
            # from the end of list (i.e., with highest score)
            extra_instances = sorted(keep_instances, key=operator.attrgetter("score"))[
                :-instance_count
            ]
            # Remove the extra instances
            for inst in extra_instances:
                instances_list.remove(inst)
    return instances_list
def connect_single_track_breaks(
    frames: List["LabeledFrame"], instance_count: int
) -> List["LabeledFrame"]:
    """
    Merges breaks in tracks by connecting single lost with single new track.

    When, relative to the last frame that had the full set of tracks,
    exactly one track disappears and exactly one new track appears, the new
    track is treated as a continuation of the lost track and its instances
    are moved onto the lost track.

    Args:
        frames: The list of `LabeledFrame` objects with predictions.
        instance_count: The maximum number of instances we want per frame.

    Returns:
        Updated list of frames, also modifies frames in place.
    """
    if not frames:
        return frames
    # Move instances in new tracks into tracks that disappeared on previous frame
    # Maps each spurious "new" track -> the earlier track it continues.
    fix_track_map = dict()
    # Baseline: tracks on the most recent frame with a full set of tracks.
    # NOTE(review): assumes the first frame's tracks are a good baseline —
    # this is not validated here.
    last_good_frame_tracks = {inst.track for inst in frames[0].instances}
    for lf in frames:
        frame_tracks = {inst.track for inst in lf.instances}
        # Re-apply fixes discovered on earlier frames so a remapped track
        # stays remapped on every later frame (unless the target track is
        # already present on this frame, which would create a collision).
        tracks_fixed_before = frame_tracks.intersection(set(fix_track_map.keys()))
        if tracks_fixed_before:
            for inst in lf.instances:
                if (
                    inst.track in fix_track_map
                    and fix_track_map[inst.track] not in frame_tracks
                ):
                    inst.track = fix_track_map[inst.track]
            # Recompute since the remapping above may have changed tracks.
            frame_tracks = {inst.track for inst in lf.instances}
        extra_tracks = frame_tracks - last_good_frame_tracks
        missing_tracks = last_good_frame_tracks - frame_tracks
        # Only fix the unambiguous case: exactly one lost, exactly one new.
        if len(extra_tracks) == 1 and len(missing_tracks) == 1:
            for inst in lf.instances:
                if inst.track in extra_tracks:
                    old_track = inst.track
                    new_track = missing_tracks.pop()
                    # Remember the mapping for subsequent frames.
                    fix_track_map[old_track] = new_track
                    inst.track = new_track
                    break
        else:
            # Update the baseline only when this frame has the full set.
            if len(frame_tracks) == instance_count:
                last_good_frame_tracks = frame_tracks
    return frames
@attr.s(auto_attribs=True, slots=True)
class Match:
    """Stores a match between a specific instance and specific track."""

    # The track the instance was matched to.
    track: Track
    # The instance that was matched.
    instance: Instance
    # Similarity score for the match, if available.
    score: Optional[float] = None
    # Whether this instance would have matched to the same track even if
    # no other instances had been competing for tracks.
    is_first_choice: bool = False
@attr.s(auto_attribs=True)
class FrameMatches:
    """
    Calculates (and stores) matches for a frame.

    This class encapsulates the logic to generate matches (using a custom
    matching function) from a cost matrix. One key feature is that it retains
    additional information, such as whether all the matches were first-choice
    (i.e., if each instance got the instance it would have matched to if there
    weren't other instances).

    Typically this will be created using the `from_candidate_instances` method
    which creates the cost matrix and then uses the matching function to find
    matches.

    Attributes:
        matches: the list of `Match` objects.
        cost_matrix: the cost matrix, shape is
            (number of untracked instances, number of candidate tracks).
        unmatched_instances: the instances for which we are finding matches.
    """

    matches: List[Match]
    cost_matrix: np.ndarray
    unmatched_instances: List[InstanceType] = attr.ib(factory=list)

    @property
    def has_only_first_choice_matches(self) -> bool:
        """Whether all the matches were first-choice.

        A match is a 'first-choice' for an instance if that instance would have
        matched to the same track even if there were no other instances.
        """
        return all(match.is_first_choice for match in self.matches)

    @classmethod
    def from_candidate_instances(
        cls,
        untracked_instances: List[InstanceType],
        candidate_instances: List[InstanceType],
        similarity_function: Callable,
        matching_function: Callable,
    ) -> "FrameMatches":
        """Builds the cost matrix from candidates and computes the matches.

        Args:
            untracked_instances: the instances for which we want matches.
            candidate_instances: instances that already have tracks; they are
                grouped by track and the best-scoring candidate per track is
                used for matching.
            similarity_function: callable taking (ref_instance, query_instance)
                and returning a similarity score (higher is more similar).
            matching_function: callable taking a cost matrix and returning a
                list of (row, column) match index tuples.

        Returns:
            A new `FrameMatches` object.
        """
        # Placeholder empty cost matrix used when there are no candidates.
        cost = np.ndarray((0,))
        candidate_tracks = []
        if candidate_instances:
            # Group candidate instances by track.
            candidate_instances_by_track = defaultdict(list)
            for instance in candidate_instances:
                candidate_instances_by_track[instance.track].append(instance)
            # Compute similarity matrix between untracked instances and best
            # candidate for each track.
            candidate_tracks = list(candidate_instances_by_track.keys())
            matching_similarities = np.full(
                (len(untracked_instances), len(candidate_tracks)), np.nan
            )
            for i, untracked_instance in enumerate(untracked_instances):
                for j, candidate_track in enumerate(candidate_tracks):
                    # Compute similarity between untracked instance and all track
                    # candidates.
                    track_instances = candidate_instances_by_track[candidate_track]
                    track_matching_similarities = [
                        similarity_function(untracked_instance, candidate_instance,)
                        for candidate_instance in track_instances
                    ]
                    # Keep the best scoring instance for this track.
                    best_ind = np.argmax(track_matching_similarities)
                    # Use the best similarity score for matching.
                    best_similarity = track_matching_similarities[best_ind]
                    matching_similarities[i, j] = best_similarity
            # Perform matching between untracked instances and candidates.
            # Negate similarities so the matching function minimizes cost;
            # NaN (no valid score) becomes infinite cost.
            cost = -matching_similarities
            cost[np.isnan(cost)] = np.inf
        return cls.from_cost_matrix(
            cost, untracked_instances, candidate_tracks, matching_function
        )

    @classmethod
    def from_cost_matrix(
        cls,
        cost_matrix: np.ndarray,
        instances: List[InstanceType],
        tracks: List[Track],
        matching_function: Callable,
    ) -> "FrameMatches":
        """Computes matches from a pre-built cost matrix.

        Args:
            cost_matrix: array of shape (len(instances), len(tracks)) where
                lower values indicate better matches.
            instances: the instances corresponding to the matrix rows.
            tracks: the candidate tracks corresponding to the matrix columns.
            matching_function: callable taking the cost matrix and returning
                a list of (row, column) match index tuples.

        Returns:
            A new `FrameMatches` object; instances not matched to any track
            are collected in `unmatched_instances`.
        """
        matches = []
        match_instance_inds = []
        if instances and tracks:
            match_inds = matching_function(cost_matrix)
            # Determine the first-choice match for each instance since we want
            # to know whether or not all the matches in the frame were
            # uncontested.
            best_matches_vector = cost_matrix.argmin(axis=1)
            # Assign each matched instance.
            for i, j in match_inds:
                match_instance_inds.append(i)
                match_instance = instances[i]
                match_track = tracks[j]
                # Cost was negated similarity, so negate back for the score.
                match_similarity = -cost_matrix[i, j]
                is_first_choice = best_matches_vector[i] == j
                # Record the match as a `Match` object.
                matches.append(
                    Match(
                        instance=match_instance,
                        track=match_track,
                        score=match_similarity,
                        is_first_choice=is_first_choice,
                    )
                )
        # Make list of untracked instances which we didn't match to anything
        unmatched_instances = [
            untracked_instance
            for i, untracked_instance in enumerate(instances)
            if i not in match_instance_inds
        ]
        return cls(
            cost_matrix=cost_matrix,
            matches=matches,
            unmatched_instances=unmatched_instances,
        )
def first_choice_matching(cost_matrix: np.ndarray) -> List[Tuple[int, int]]:
    """
    Returns match indices where each row gets matched to its best column.

    This means that multiple rows might be matched to the same column.
    """
    best_columns = cost_matrix.argmin(axis=1)
    return [(row, col) for row, col in enumerate(best_columns)]