Source code for sleap.nn.viz

"""Visualization and plotting utilities."""

import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Union, Tuple, Optional, Text


[docs]def imgfig( size: Union[float, Tuple] = 6, dpi: int = 72, scale: float = 1.0 ) -> matplotlib.figure.Figure: """Create a tight figure for image plotting. Args: size: Scalar or 2-tuple specifying the (width, height) of the figure in inches. If scalar, will assume equal width and height. dpi: Dots per inch, controlling the resolution of the image. scale: Factor to scale the size of the figure by. This is a convenience for increasing the size of the plot at the same DPI. Returns: A matplotlib.figure.Figure to use for plotting. """ if not isinstance(size, (tuple, list)): size = (size, size) fig = plt.figure(figsize=(scale * size[0], scale * size[1]), dpi=dpi) ax = fig.add_axes([0, 0, 1, 1], frameon=False) ax.get_xaxis().set_visible(False) ax.get_yaxis().set_visible(False) plt.autoscale(tight=True) ax.set_xticks([]) ax.set_yticks([]) ax.grid(False) return fig
[docs]def plot_img( img: np.ndarray, dpi: int = 72, scale: float = 1.0 ) -> matplotlib.figure.Figure: """Plot an image in a tight figure.""" if img.shape[0] == 1: # Squeeze out batch singleton dimension. img = img.squeeze(axis=0) # Check if image is grayscale (single channel). grayscale = img.shape[-1] == 1 if grayscale: # Squeeze out singleton channel. img = img.squeeze(axis=-1) # Normalize the range of pixel values. img_min = img.min() img_max = img.max() if img_min < 0.0 or img_max > 1.0: img = (img - img_min) / (img_max - img_min) fig = imgfig( size=(float(img.shape[1]) / dpi, float(img.shape[0]) / dpi), dpi=dpi, scale=scale, ) ax = fig.gca() ax.imshow( img, cmap="gray" if grayscale else None, origin="upper", extent=[-0.5, img.shape[1] - 0.5, img.shape[0] - 0.5, -0.5], ) return fig
[docs]def plot_confmaps(confmaps: np.ndarray, output_scale: float = 1.0): """Plot confidence maps reduced over channels.""" ax = plt.gca() return ax.imshow( np.squeeze(confmaps.max(axis=-1)), alpha=0.5, origin="upper", extent=[ -0.5, confmaps.shape[1] / output_scale - 0.5, confmaps.shape[0] / output_scale - 0.5, -0.5, ], )
[docs]def plot_peaks( pts_gt: np.ndarray, pts_pr: Optional[np.ndarray] = None, paired: bool = False ): """Plot ground truth and detected peaks.""" handles = [] ax = plt.gca() if paired and pts_pr is not None: for p_gt, p_pr in zip(pts_gt, pts_pr): handles.append( ax.plot([p_gt[0], p_pr[0]], [p_gt[1], p_pr[1]], "r-", alpha=0.5, lw=2) ) if pts_pr is not None: handles.append( ax.plot( pts_gt[..., 0].ravel(), pts_gt[..., 1].ravel(), "g.", alpha=0.7, ms=10, mew=1, mec="w", ) ) handles.append( ax.plot(pts_pr[:, 0], pts_pr[:, 1], "r.", alpha=0.7, ms=10, mew=1, mec="w") ) else: cmap = sns.color_palette("tab20") for i, pt in enumerate(pts_gt): handles.append( ax.plot( pt[0], pt[1], ".", alpha=0.7, ms=15, mew=1, mfc=cmap[i % len(cmap)], mec="w", ) ) return handles
[docs]def plot_pafs( pafs: np.ndarray, output_scale: float = 1.0, stride: int = 1, scale: float = 4.0, width: float = 1.0, cmap: Optional[Text] = None, ): """Quiver plot for a single frame of pafs.""" if cmap is None: cmap = sns.color_palette("tab20") pafs = pafs.reshape((pafs.shape[0], pafs.shape[1], -1, 2)) h_quivers = [] for k in range(pafs.shape[-2]): pafs_k = pafs[..., k, :] # rank 3 pafs_k = pafs_k[::stride, ::stride, :] h_quivers_k = plt.quiver( np.linspace( 0, (stride * pafs_k.shape[1] / output_scale), pafs_k.shape[1], endpoint=False, ), np.linspace( 0, (stride * pafs_k.shape[0] / output_scale), pafs_k.shape[0], endpoint=False, ), pafs_k[..., 0], pafs_k[..., 1], angles="xy", pivot="mid", units="xy", scale_units="xy", scale=1.0 / scale, color=cmap[k % len(cmap)], minlength=0.1, width=width, alpha=0.8, ) h_quivers.append(h_quivers_k) return h_quivers
[docs]def plot_instance( instance, skeleton=None, cmap=None, color_by_node=False, lw=2, ms=10, bbox=None, scale=1.0, **kwargs ): """Plot a single instance with edge coloring.""" if cmap is None: cmap = sns.color_palette("tab20") if skeleton is None: skeleton = instance.skeleton if len(skeleton.edges) == 0: color_by_node = True h_lines = [] if color_by_node: for k, (x, y) in enumerate(instance.points_array): if bbox is not None: x -= bbox[1] y -= bbox[0] x *= scale y *= scale h_lines_k = plt.plot(x, y, ".", ms=ms, c=cmap[k % len(cmap)], **kwargs) h_lines.append(h_lines_k) else: for k, (src_node, dst_node) in enumerate(skeleton.edges): src_pt = instance.points_array[instance.skeleton.node_to_index(src_node)] dst_pt = instance.points_array[instance.skeleton.node_to_index(dst_node)] x = np.array([src_pt[0], dst_pt[0]]) y = np.array([src_pt[1], dst_pt[1]]) if bbox is not None: x -= bbox[1] y -= bbox[0] x *= scale y *= scale h_lines_k = plt.plot( x, y, ".-", ms=ms, lw=lw, c=cmap[k % len(cmap)], **kwargs ) h_lines.append(h_lines_k) return h_lines
[docs]def plot_instances( instances, skeleton=None, cmap=None, color_by_track=False, tracks=None, **kwargs ): """Plot a list of instances with identity coloring.""" if cmap is None: cmap = sns.color_palette("tab10") if color_by_track and tracks is None: # Infer tracks for ordering if not provided. tracks = set() for instance in instances: tracks.add(instance.track) # Sort by spawned frame. tracks = sorted(list(tracks), key=lambda track: track.name) h_lines = [] for i, instance in enumerate(instances): if color_by_track: if instance.track is None: raise ValueError( "Instances must have a set track when coloring by track." ) if instance.track not in tracks: raise ValueError("Instance has a track not found in specified tracks.") color = cmap[tracks.index(instance.track) % len(cmap)] else: # Color by identity (order in list). color = cmap[i % len(cmap)] h_lines_i = plot_instance(instance, skeleton=skeleton, cmap=[color], **kwargs) h_lines.append(h_lines_i) return h_lines
def plot_bbox(bbox, **kwargs): if hasattr(bbox, "bounding_box"): bbox = bbox.bounding_box y1, x1, y2, x2 = bbox plt.plot([x1, x2, x2, x1, x1], [y1, y1, y2, y2, y1], "-", **kwargs)