"""This module provides a set of utilities for grouping peaks based on PAFs.
Part affinity fields (PAFs) are a representation used to resolve the peak grouping
problem for multi-instance pose estimation [1].
They are a convenient way to represent directed graphs with support in image space. For
each edge, a PAF can be represented by an image with two channels, corresponding to the
x and y components of a unit vector pointing along the direction of the underlying
directed graph formed by the connections of the landmarks belonging to an instance.
Given a pair of putatively connected landmarks, the agreement between the line segment
that connects them and the PAF vectors found at the coordinates along the same line can
be used as a measure of "connectedness". These scores can then be used to guide the
instance-wise grouping of landmarks.
This image space representation is particularly useful as it is amenable to neural
network-based prediction from unlabeled images.
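
Example:
    A minimal sketch of the line-integral agreement score for one candidate
    connection, using NumPy only (the coordinates and PAF vectors below are made
    up for illustration):

    >>> import numpy as np
    >>> src, dst = np.array([2.0, 3.0]), np.array([10.0, 3.0])
    >>> direction = (dst - src) / np.linalg.norm(dst - src)
    >>> paf_vectors = np.tile(direction, (10, 1))  # PAF in perfect agreement
    >>> float(np.mean(paf_vectors @ direction))
    1.0
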
References:
.. [1] Zhe Cao, Tomas Simon, Shih-En Wei, Yaser Sheikh. Realtime Multi-Person 2D
Pose Estimation using Part Affinity Fields. In _CVPR_, 2017.
"""
import attr
from typing import Dict, List, Union, Tuple, Text
import tensorflow as tf
import numpy as np
from scipy.optimize import linear_sum_assignment
from sleap.nn.config import MultiInstanceConfig
@attr.s(auto_attribs=True, slots=True, frozen=True)
class PeakID:
    """Indexes a specific peak within a frame.

    Attributes:
        node_ind: Index of the node type (channel) of the peak.
        peak_ind: Index of the peak within its node type.
    """

    node_ind: int
    peak_ind: int
@attr.s(auto_attribs=True, slots=True, frozen=True)
class EdgeType:
    """Indexes a directed edge of the skeleton by its source and destination nodes.

    Attributes:
        src_node_ind: Index of the source node type of the edge.
        dst_node_ind: Index of the destination node type of the edge.
    """

    src_node_ind: int
    dst_node_ind: int
@attr.s(auto_attribs=True, slots=True)
class EdgeConnection:
    """A scored connection between two peaks along an edge.

    Attributes:
        src_peak_ind: Index of the source peak within its node type.
        dst_peak_ind: Index of the destination peak within its node type.
        score: Connection score computed from the PAF line integral.
    """

    src_peak_ind: int
    dst_peak_ind: int
    score: float
def assign_connections_to_instances(
connections: Dict[EdgeType, List[EdgeConnection]],
min_instance_peaks: Union[int, float] = 0,
n_nodes: int = None,
) -> Dict[PeakID, int]:
"""Assigns connected edges to instances via greedy graph partitioning.
Args:
connections: A dict that maps EdgeType to a list of EdgeConnections found
through connection scoring. This can be generated by the
filter_connection_candidates function.
min_instance_peaks: If this is greater than 0, grouped instances with fewer
assigned peaks than this threshold will be excluded. If a float in the
range (0., 1.] is provided, this is interpreted as a fraction of the total
number of nodes in the skeleton. If an integer is provided, this is the
absolute minimum number of peaks.
n_nodes: Total node type count. Used to convert min_instance_peaks to an
absolute number when a fraction is specified. If not provided, the node
count is inferred from the unique node inds in connections.
Returns:
instance_assignments: A dict mapping PeakID to a unique instance ID specified
as an integer.
A PeakID is a tuple of (node_type_ind, peak_ind), where the peak_ind is the
index or identifier specified in a EdgeConnection as a src_peak_ind or
dst_peak_ind.
Note:
Instance IDs are not necessarily consecutive since some instances may be
filtered out during the partitioning or filtering.
This function expects connections from a single sample/frame!
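
    Example:
        A minimal single-edge example (higher-level grouping across edges is
        handled by PAFScorer):

        >>> connections = {EdgeType(0, 1): [EdgeConnection(0, 0, 0.9)]}
        >>> assign_connections_to_instances(connections)
        {PeakID(node_ind=0, peak_ind=0): 0, PeakID(node_ind=1, peak_ind=0): 0}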
"""
# Grouping table that maps PeakID(node_ind, peak_ind) to an instance_id.
instance_assignments = dict()
# Loop through edge types.
for edge_type, edge_connections in connections.items():
# Loop through connections for the current edge.
for connection in edge_connections:
# Notation: specific peaks are identified by (node_ind, peak_ind).
src_id = PeakID(edge_type.src_node_ind, connection.src_peak_ind)
dst_id = PeakID(edge_type.dst_node_ind, connection.dst_peak_ind)
# Get instance assignments for the connection peaks.
src_instance = instance_assignments.get(src_id, None)
dst_instance = instance_assignments.get(dst_id, None)
if src_instance is None and dst_instance is None:
# Case 1: Neither peak is assigned to an instance yet. We'll create a
# new instance to hold both.
new_instance = max(instance_assignments.values(), default=-1) + 1
instance_assignments[src_id] = new_instance
instance_assignments[dst_id] = new_instance
elif src_instance is not None and dst_instance is None:
# Case 2: The source peak is assigned already, but not the destination
# peak. We'll assign the destination peak to the same instance as the
# source.
instance_assignments[dst_id] = src_instance
elif src_instance is not None and dst_instance is not None:
# Case 3: Both peaks have been assigned. We'll update the destination
# peak to be a part of the source peak instance.
instance_assignments[dst_id] = src_instance
# We'll also check if they form disconnected subgraphs, in which case
# we'll merge them by assigning all peaks belonging to the destination
# peak's instance to the source peak's instance.
src_instance_nodes = set(
peak_id.node_ind
for peak_id, instance in instance_assignments.items()
if instance == src_instance
)
dst_instance_nodes = set(
peak_id.node_ind
for peak_id, instance in instance_assignments.items()
if instance == dst_instance
)
if len(src_instance_nodes.intersection(dst_instance_nodes)) == 0:
for peak_id in instance_assignments:
if instance_assignments[peak_id] == dst_instance:
instance_assignments[peak_id] = src_instance
if min_instance_peaks > 0:
if isinstance(min_instance_peaks, float):
if n_nodes is None:
# Infer number of nodes if not specified.
all_node_types = set()
for edge_type in connections:
all_node_types.add(edge_type.src_node_ind)
all_node_types.add(edge_type.dst_node_ind)
n_nodes = len(all_node_types)
# Calculate minimum threshold.
min_instance_peaks = int(min_instance_peaks * n_nodes)
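            # For example, min_instance_peaks=0.5 with n_nodes=5 keeps only
            # instances with at least int(0.5 * 5) == 2 assigned peaks.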
# Compute instance peak counts.
instance_ids, instance_peak_counts = np.unique(
list(instance_assignments.values()), return_counts=True
)
instance_peak_counts = {
instance: peaks_count
for instance, peaks_count in zip(instance_ids, instance_peak_counts)
}
# Filter out small instances.
instance_assignments = {
peak_id: instance
for peak_id, instance in instance_assignments.items()
if instance_peak_counts[instance] >= min_instance_peaks
}
return instance_assignments
def make_predicted_instances(peaks, peak_scores, connections, instance_assignments):
    """Gathers grouped peaks into dense per-instance arrays.

    Returns predicted instance points of shape (n_instances, n_nodes, 2), their
    peak scores of shape (n_instances, n_nodes), and summed instance scores of
    shape (n_instances,). Nodes without an assigned peak are filled with NaNs.
    """
# Ensure instance IDs are contiguous.
instance_ids, instance_inds = np.unique(
list(instance_assignments.values()), return_inverse=True
)
for peak_id, instance_ind in zip(instance_assignments.keys(), instance_inds):
instance_assignments[peak_id] = instance_ind
n_instances = len(instance_ids)
# Compute instance scores as the sum of all edge scores.
predicted_instance_scores = np.full((n_instances,), 0.0, dtype="float32")
for edge_type, edge_connections in connections.items():
# Loop over all connections for this edge type.
for edge_connection in edge_connections:
# Look up the source peak.
src_peak_id = PeakID(
node_ind=edge_type.src_node_ind, peak_ind=edge_connection.src_peak_ind
)
if src_peak_id in instance_assignments:
# Add to the total instance score.
instance_ind = instance_assignments[src_peak_id]
predicted_instance_scores[instance_ind] += edge_connection.score
# Sanity check: both peaks in the edge should have been assigned to the
# same instance.
dst_peak_id = PeakID(
node_ind=edge_type.dst_node_ind,
peak_ind=edge_connection.dst_peak_ind,
)
assert instance_ind == instance_assignments[dst_peak_id]
    # Fill in assigned peak data.
    n_nodes = len(peaks)
predicted_instances = np.full((n_instances, n_nodes, 2), np.nan, dtype="float32")
predicted_peak_scores = np.full((n_instances, n_nodes), np.nan, dtype="float32")
for peak_id, instance_ind in instance_assignments.items():
predicted_instances[instance_ind, peak_id.node_ind, :] = peaks[
peak_id.node_ind
][peak_id.peak_ind]
predicted_peak_scores[instance_ind, peak_id.node_ind] = peak_scores[
peak_id.node_ind
][peak_id.peak_ind]
return predicted_instances, predicted_peak_scores, predicted_instance_scores
@attr.s(auto_attribs=True)
class PAFScorer:
    """Scores candidate connections with PAFs and groups peaks into instances.

    Attributes:
        part_names: Names of the node types (channels), in order.
        edges: List of (src_part_name, dst_part_name) tuples defining the skeleton.
        pafs_stride: Output stride of the PAFs relative to the full image.
        max_edge_length: Maximum connection length in full image pixels before a
            distance penalty is applied to the line score.
        min_edge_score: Minimum PAF line score for a sampled point to count as correct.
        n_points: Number of points sampled along each candidate connection.
        min_instance_peaks: Minimum number (or fraction of n_nodes) of peaks an
            instance must have to be kept.
    """

    part_names: List[Text]
    edges: List[Tuple[Text, Text]]
    pafs_stride: int
    max_edge_length: float = 128
    min_edge_score: float = 0.05
    n_points: int = 10
    min_instance_peaks: Union[int, float] = 0
    edge_inds: List[Tuple[int, int]] = attr.ib(init=False)
    edge_types: List[EdgeType] = attr.ib(init=False)
    n_nodes: int = attr.ib(init=False)
    n_edges: int = attr.ib(init=False)
    def __attrs_post_init__(self):
        """Precomputes edge indices, edge types, and node/edge counts."""
self.edge_inds = [
(self.part_names.index(src), self.part_names.index(dst))
for (src, dst) in self.edges
]
self.edge_types = [
EdgeType(src_node, dst_node) for src_node, dst_node in self.edge_inds
]
self.n_nodes = len(self.part_names)
self.n_edges = len(self.edges)
@classmethod
def from_config(
cls,
config: MultiInstanceConfig,
max_edge_length: float = 128,
min_edge_score: float = 0.05,
n_points: int = 10,
min_instance_peaks: Union[int, float] = 0,
) -> "PAFScorer":
return cls(
part_names=config.confmaps.part_names,
edges=config.pafs.edges,
pafs_stride=config.pafs.output_stride,
max_edge_length=max_edge_length,
min_edge_score=min_edge_score,
n_points=n_points,
min_instance_peaks=min_instance_peaks,
)
    def sample_edge_line(self, paf, src_peak, dst_peak):
        """Samples the PAF at evenly spaced points along a candidate connection.

        Returns a (n_points, 2) tensor of the (x, y) PAF components at each point.
        """
paf_x = tf.gather(paf, 0, axis=-1)
paf_y = tf.gather(paf, 1, axis=-1)
max_x = tf.cast(tf.shape(paf_x)[1] - 1, tf.float32)
max_y = tf.cast(tf.shape(paf_x)[0] - 1, tf.float32)
line_x = tf.linspace(src_peak[0], dst_peak[0], self.n_points)
line_y = tf.linspace(src_peak[1], dst_peak[1], self.n_points)
line_x /= tf.cast(self.pafs_stride, tf.float32)
line_y /= tf.cast(self.pafs_stride, tf.float32)
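        # For example, with pafs_stride=4, an image x coordinate of 100.0 maps to
        # PAF grid column round(100.0 / 4) == 25.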
line_x = tf.clip_by_value(tf.round(line_x), 0, max_x)
line_y = tf.clip_by_value(tf.round(line_y), 0, max_y)
line_x = tf.cast(line_x, tf.int32)
line_y = tf.cast(line_y, tf.int32)
line_subs = tf.stack([line_y, line_x], axis=1)
line_paf_x = tf.gather_nd(paf_x, line_subs)
line_paf_y = tf.gather_nd(paf_y, line_subs)
line_paf = tf.stack([line_paf_x, line_paf_y], axis=-1) # (n_points, 2)
return line_paf
    def score_pair(self, line_paf, src_peak, dst_peak):
        """Scores a single candidate connection against its sampled PAF vectors."""
        # Normalize the spatial vector between the two peaks.
spatial_vec = dst_peak - src_peak
spatial_vec_length = tf.norm(spatial_vec)
spatial_vec /= spatial_vec_length
# Compute dot product scores
line_scores = tf.squeeze(
line_paf @ tf.expand_dims(spatial_vec, axis=-1), axis=-1
) # (n_points,)
        # Compute the distance penalty term: negative when the connection is longer
        # than max_edge_length, zero or positive otherwise.
        dist_penalty = (
            tf.cast(self.max_edge_length, tf.float32) / spatial_vec_length
        ) - 1
# Compute overall line score
line_score = tf.reduce_mean(line_scores)
line_score_with_dist_penalty = line_score + tf.minimum(dist_penalty, 0)
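        # For example, with max_edge_length=128, a 256 px connection is penalized
        # by min((128 / 256) - 1, 0) == -0.5, while a 64 px connection gets
        # min((128 / 64) - 1, 0) == 0, i.e. no penalty.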
# Compute fraction of connections above threshold.
fraction_correct = tf.reduce_mean(
tf.cast(line_scores > self.min_edge_score, tf.float32)
)
return line_score_with_dist_penalty, fraction_correct
    def score_edge(self, paf, src_peaks, dst_peaks):
        """Scores all candidate source/destination peak pairs for one edge type.

        Returns (n_src, n_dst) tensors of line scores and of the fraction of
        sampled points scoring above min_edge_score.
        """
line_scores = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True)
fraction_correct = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True)
# Iterate over source peaks.
for i in range(len(src_peaks)):
line_scores_i = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True)
fraction_correct_i = tf.TensorArray(
dtype=tf.float32, size=0, dynamic_size=True
)
# Iterate over destination peaks.
for j in range(len(dst_peaks)):
# Pull out peaks.
src_peak = src_peaks[i]
dst_peak = dst_peaks[j]
# Get line integral from PAF tensor.
line_paf = self.sample_edge_line(paf, src_peak, dst_peak)
# Compute scores from line integral.
line_score_ij, fraction_correct_ij = self.score_pair(
line_paf, src_peak, dst_peak
)
line_scores_i = line_scores_i.write(j, line_score_ij)
fraction_correct_i = fraction_correct_i.write(j, fraction_correct_ij)
line_scores_i = line_scores_i.stack()
fraction_correct_i = fraction_correct_i.stack()
            line_scores = line_scores.write(i, line_scores_i)
            fraction_correct = fraction_correct.write(i, fraction_correct_i)
line_scores = line_scores.stack()
fraction_correct = fraction_correct.stack()
return line_scores, fraction_correct
    def score_and_match_edge(self, paf, src_peaks, dst_peaks):
        """Scores one edge type and finds the optimal one-to-one peak matching."""
# Compute scores from PAF line integrals.
line_scores, fraction_correct = self.score_edge(paf, src_peaks, dst_peaks)
# Replace NaNs with inf since linear_sum_assignment doesn't accept NaNs
line_costs = tf.where(
condition=tf.math.is_nan(line_scores),
x=tf.constant([np.inf]),
y=-line_scores,
)
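        # linear_sum_assignment minimizes total cost, so the scores are negated to
        # obtain a maximum-score matching; NaN entries become +inf cost so those
        # pairings are maximally expensive.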
# Match edge candidates.
src_inds, dst_inds = tf.py_function(
linear_sum_assignment, inp=[line_costs], Tout=[tf.int32, tf.int32]
)
# Pull out matched scores.
match_subs = tf.stack([src_inds, dst_inds], axis=1)
line_scores = tf.gather_nd(line_scores, match_subs)
fraction_correct = tf.gather_nd(fraction_correct, match_subs)
return src_inds, dst_inds, line_scores, fraction_correct
    def match_all_peaks(self, pafs, flat_peaks, flat_channel_inds):
        """Scores and matches candidate connections for every edge type.

        Returns flat tensors of edge indices, matched source and destination peak
        indices, line scores, and fraction-correct values for all matched
        connections.
        """
# Make sure PAFs are unflattened into (..., n_edges, 2).
pafs = tf.reshape(pafs, [tf.shape(pafs)[0], tf.shape(pafs)[1], -1, 2])
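        # For example, a PAF tensor of shape (96, 96, 2 * n_edges) becomes
        # (96, 96, n_edges, 2), pairing the x and y components of each edge.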
# Sort peaks by channel
sort_idx = tf.argsort(flat_channel_inds)
peaks = tf.gather(flat_peaks, sort_idx)
channel_inds = tf.gather(flat_channel_inds, sort_idx)
# Group peaks by channel
peaks = tf.RaggedTensor.from_value_rowids(
values=peaks, value_rowids=channel_inds, nrows=self.n_nodes
)
# Initialize dynamically sized containers.
all_edge_inds = tf.TensorArray(dtype=tf.int32, size=0, dynamic_size=True)
all_src_inds = tf.TensorArray(dtype=tf.int32, size=0, dynamic_size=True)
all_dst_inds = tf.TensorArray(dtype=tf.int32, size=0, dynamic_size=True)
all_line_scores = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True)
all_fraction_correct = tf.TensorArray(
dtype=tf.float32, size=0, dynamic_size=True
)
# Iterate over edges.
for edge_ind in range(self.n_edges):
# Pull out edge data.
paf = tf.gather(pafs, edge_ind, axis=-2)
src_peaks = peaks[self.edge_inds[edge_ind][0]]
dst_peaks = peaks[self.edge_inds[edge_ind][1]]
# Score the edge.
(
src_inds,
dst_inds,
line_scores,
fraction_correct,
) = self.score_and_match_edge(paf, src_peaks, dst_peaks)
# Store edge results.
all_edge_inds = all_edge_inds.write(
edge_ind, tf.broadcast_to(edge_ind, tf.shape(src_inds))
)
all_src_inds = all_src_inds.write(edge_ind, src_inds)
all_dst_inds = all_dst_inds.write(edge_ind, dst_inds)
all_line_scores = all_line_scores.write(edge_ind, line_scores)
all_fraction_correct = all_fraction_correct.write(
edge_ind, fraction_correct
)
# Concatenate dynamic tensors into flat ones. These can be split again by using
# flat_edge_inds as a grouping vector.
flat_edge_inds = all_edge_inds.concat()
flat_src_inds = all_src_inds.concat()
flat_dst_inds = all_dst_inds.concat()
flat_line_scores = all_line_scores.concat()
flat_fraction_correct = all_fraction_correct.concat()
return (
flat_edge_inds,
flat_src_inds,
flat_dst_inds,
flat_line_scores,
flat_fraction_correct,
)
def match_instances(
self,
flat_peaks,
flat_peak_scores,
flat_channel_inds,
flat_edge_inds,
flat_src_peak_inds,
flat_dst_peak_inds,
flat_line_scores,
flat_fraction_correct,
    ):
        """Groups matched connections into distinct instances.

        This runs outside the TensorFlow graph (via tf.py_function), so the inputs
        are converted to numpy arrays first.
        """
# Convert all the data to numpy arrays.
flat_peaks = flat_peaks.numpy()
flat_peak_scores = flat_peak_scores.numpy()
flat_channel_inds = flat_channel_inds.numpy()
flat_edge_inds = flat_edge_inds.numpy()
flat_src_peak_inds = flat_src_peak_inds.numpy()
flat_dst_peak_inds = flat_dst_peak_inds.numpy()
flat_line_scores = flat_line_scores.numpy()
flat_fraction_correct = flat_fraction_correct.numpy()
# Group peaks by channel.
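        # For example, flat_channel_inds == [0, 0, 1] groups the first two peaks
        # under node type 0 and the third under node type 1.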
peaks = []
peak_scores = []
for i in range(self.n_nodes):
in_channel = flat_channel_inds == i
peaks.append(flat_peaks[in_channel])
peak_scores.append(flat_peak_scores[in_channel])
# Group connection data by edge.
src_peak_inds = []
dst_peak_inds = []
line_scores = []
fraction_correct = []
for i in range(self.n_edges):
in_edge = flat_edge_inds == i
src_peak_inds.append(flat_src_peak_inds[in_edge])
dst_peak_inds.append(flat_dst_peak_inds[in_edge])
line_scores.append(flat_line_scores[in_edge])
fraction_correct.append(flat_fraction_correct[in_edge])
# Form connections structure.
connections = dict()
for edge_ind, (src_peak_ind, dst_peak_ind, line_score) in enumerate(
zip(src_peak_inds, dst_peak_inds, line_scores)
):
connections[self.edge_types[edge_ind]] = [
EdgeConnection(src, dst, score)
for src, dst, score in zip(src_peak_ind, dst_peak_ind, line_score)
]
        # Greedy graph partitioning to group connections into instances.
instance_assignments = assign_connections_to_instances(
connections,
min_instance_peaks=self.min_instance_peaks,
n_nodes=self.n_nodes,
)
# Gather the data by instance.
(
predicted_instances,
predicted_peak_scores,
predicted_instance_scores,
) = make_predicted_instances(
peaks, peak_scores, connections, instance_assignments
)
return predicted_instances, predicted_peak_scores, predicted_instance_scores
    def match_with_pafs(self, pafs, flat_peaks, flat_peak_scores, flat_channel_inds):
        """Matches peaks and groups them into instances using the PAFs.

        Returns the predicted instance points, peak scores, and instance scores
        described in make_predicted_instances.
        """
# Match peaks within each edge using PAF scores.
(
flat_edge_inds,
flat_src_peak_inds,
flat_dst_peak_inds,
flat_line_scores,
flat_fraction_correct,
) = self.match_all_peaks(pafs, flat_peaks, flat_channel_inds)
# Given matched peaks, group them into instances.
(
predicted_instances,
predicted_peak_scores,
predicted_instance_scores,
) = tf.py_function(
self.match_instances,
inp=[
flat_peaks,
flat_peak_scores,
flat_channel_inds,
flat_edge_inds,
flat_src_peak_inds,
flat_dst_peak_inds,
flat_line_scores,
flat_fraction_correct,
],
Tout=[tf.float32, tf.float32, tf.float32],
)
return predicted_instances, predicted_peak_scores, predicted_instance_scores
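
# A minimal usage sketch for PAFScorer. The pafs, peaks, peak_confidences, and
# channel_inds tensors are assumed to come from upstream PAF prediction and peak
# finding; the skeleton here is illustrative:
#
#   scorer = PAFScorer(
#       part_names=["head", "thorax"],
#       edges=[("head", "thorax")],
#       pafs_stride=4,
#   )
#   instances, peak_scores, instance_scores = scorer.match_with_pafs(
#       pafs, peaks, peak_confidences, channel_inds
#   )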
@attr.s(auto_attribs=True)
class PartAffinityFieldInstanceGrouper:
    """Dataset transformer that groups predicted peaks into instances using PAFs.

    The *_key attributes name the example keys that inputs are read from and
    outputs are written to. If keep_pafs is False, the PAFs are dropped from the
    output examples.
    """

    paf_scorer: PAFScorer
peaks_key: Text = "predicted_peaks"
peak_scores_key: Text = "predicted_peak_confidences"
channel_inds_key: Text = "predicted_peak_channel_inds"
pafs_key: Text = "predicted_part_affinity_fields"
predicted_instances_key: Text = "predicted_instances"
predicted_peak_scores_key: Text = "predicted_peak_scores"
predicted_instance_scores_key: Text = "predicted_instance_scores"
keep_pafs: bool = False
@classmethod
def from_config(
cls,
config: MultiInstanceConfig,
max_edge_length: float = 128,
min_edge_score: float = 0.05,
n_points: int = 10,
min_instance_peaks: Union[int, float] = 0,
peaks_key: Text = "predicted_peaks",
peak_scores_key: Text = "predicted_peak_confidences",
channel_inds_key: Text = "predicted_peak_channel_inds",
pafs_key: Text = "predicted_part_affinity_fields",
predicted_instances_key: Text = "predicted_instances",
predicted_peak_scores_key: Text = "predicted_peak_scores",
predicted_instance_scores_key: Text = "predicted_instance_scores",
keep_pafs: bool = False,
) -> "PartAffinityFieldInstanceGrouper":
return cls(
paf_scorer=PAFScorer.from_config(
config,
max_edge_length=max_edge_length,
min_edge_score=min_edge_score,
n_points=n_points,
min_instance_peaks=min_instance_peaks,
),
peaks_key=peaks_key,
peak_scores_key=peak_scores_key,
channel_inds_key=channel_inds_key,
pafs_key=pafs_key,
predicted_instances_key=predicted_instances_key,
predicted_peak_scores_key=predicted_peak_scores_key,
predicted_instance_scores_key=predicted_instance_scores_key,
keep_pafs=keep_pafs,
)
    @property
    def input_keys(self) -> List[Text]:
        """Returns the keys that this transformer expects in input examples."""
        return [
self.peaks_key,
self.peak_scores_key,
self.channel_inds_key,
self.pafs_key,
]
    @property
    def output_keys(self) -> List[Text]:
        """Returns the keys present in examples output by this transformer."""
        return self.input_keys + [
self.predicted_instances_key,
self.predicted_peak_scores_key,
self.predicted_instance_scores_key,
]
    def transform_dataset(self, input_ds: tf.data.Dataset) -> tf.data.Dataset:
        """Returns a dataset with instance grouping applied to each example."""
def group_instances(example):
# Pull out example data.
pafs = example[self.pafs_key]
flat_peaks = example[self.peaks_key]
flat_peak_scores = example[self.peak_scores_key]
flat_channel_inds = example[self.channel_inds_key]
# Run matching.
(
predicted_instances,
predicted_peak_scores,
predicted_instance_scores,
) = self.paf_scorer.match_with_pafs(
pafs, flat_peaks, flat_peak_scores, flat_channel_inds
)
# Update example.
example[self.predicted_instances_key] = predicted_instances
example[self.predicted_peak_scores_key] = predicted_peak_scores
example[self.predicted_instance_scores_key] = predicted_instance_scores
if not self.keep_pafs:
# Drop PAFs.
example.pop(self.pafs_key)
return example
output_ds = input_ds.map(
group_instances, num_parallel_calls=tf.data.experimental.AUTOTUNE
)
return output_ds
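
# A minimal end-to-end sketch of how this transformer might be applied (the
# config and upstream dataset are assumed to exist; keys use the defaults above):
#
#   grouper = PartAffinityFieldInstanceGrouper.from_config(config)
#   ds = grouper.transform_dataset(ds)
#   # Each example now contains "predicted_instances", "predicted_peak_scores",
#   # and "predicted_instance_scores".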