"""
Overlay for part affinity fields.
Currently a `DataOverlay` gets data from a model (i.e., it runs inference on the
current frame) and then uses a `MultiQuiverPlot` object to show the resulting
part affinity fields.
"""
from PySide2 import QtWidgets, QtGui, QtCore
import numpy as np
import itertools
import math
from sleap.io.video import HDF5Video
from sleap.gui.overlays.base import DataOverlay, h5_colors
[docs]class PafOverlay(DataOverlay):
"""Class show pafs saved in HDF5 (not currently used)."""
@classmethod
def from_h5(cls, filename, input_format="channels_last", **kwargs):
return DataOverlay.from_h5(
filename, "/pafs", input_format, overlay_class=MultiQuiverPlot, **kwargs
)
[docs]class MultiQuiverPlot(QtWidgets.QGraphicsObject):
"""QtWidgets.QGraphicsObject to display multiple quiver plots in a QtWidgets.QGraphicsView.
Args:
frame (numpy.array): Data for one frame of quiver plot data.
Shape of array should be (channels, height, width).
show (list, optional): List of channels to show. If None, show all channels.
decimation (int, optional): Decimation factor. If 1, show every arrow.
Returns:
None.
Note:
Each channel corresponds to two (h, w) arrays: x and y for the vector.
When initialized, creates one child QuiverPlot item for each channel.
"""
def __init__(
self,
frame: np.array = None,
show: list = None,
decimation: int = 1,
scale: float = 1.0,
*args,
**kwargs,
):
super(MultiQuiverPlot, self).__init__(*args, **kwargs)
self.frame = frame
self.affinity_field = []
self.decimation = decimation
self.scale = scale
# if data range is outside [-1, 1], assume it's [-255, 255] and scale
if np.ptp(self.frame) > 4:
self.frame = self.frame.astype(np.float64) / 255
if show is None:
self.show_list = range(self.frame.shape[2] // 2)
else:
self.show_list = show
for channel in self.show_list:
if channel < self.frame.shape[-1] // 2:
color_map = h5_colors[channel % len(h5_colors)]
aff_field_item = QuiverPlot(
field_x=self.frame[..., channel * 2],
field_y=self.frame[..., channel * 2 + 1],
color=color_map,
decimation=self.decimation,
scale=self.scale,
parent=self,
)
self.affinity_field.append(aff_field_item)
[docs] def boundingRect(self) -> QtCore.QRectF:
"""Method required by Qt.
"""
return QtCore.QRectF()
[docs] def paint(self, painter, option, widget=None):
"""Method required by Qt.
"""
pass
[docs]class QuiverPlot(QtWidgets.QGraphicsObject):
"""QtWidgets.QGraphicsObject for drawing single quiver plot.
Args:
field_x (numpy.array): (h, w) array of x component of vectors.
field_y (numpy.array): (h, w) array of y component of vectors.
color (list, optional): Arrow color. Format as (r, g, b) array.
decimation (int, optional): Decimation factor. If 1, show every arrow.
Returns:
None.
"""
def __init__(
self,
field_x: np.array = None,
field_y: np.array = None,
color=[255, 255, 255],
decimation=1,
scale=1,
*args,
**kwargs,
):
super(QuiverPlot, self).__init__(*args, **kwargs)
self.field_x, self.field_y = None, None
self.color = color
self.decimation = decimation
self.scale = scale
pen_width = min(4, max(0.1, math.log(self.decimation, 20)))
self.pen = QtGui.QPen(QtGui.QColor(*self.color), pen_width)
self.points = []
self.rect = QtCore.QRectF()
if field_x is not None and field_y is not None:
self.field_x, self.field_y = field_x, field_y
h, w = self.field_x.shape
h, w = int(h * self.scale), int(w * self.scale)
self.rect = QtCore.QRectF(0, 0, w, h)
self._add_arrows()
def _add_arrows(self, min_length=0.01):
points = []
if self.field_x is not None and self.field_y is not None:
raw_delta_yx = np.stack((self.field_y, self.field_x), axis=-1)
dim_0 = self.field_x.shape[0] // self.decimation * self.decimation
dim_1 = self.field_x.shape[1] // self.decimation * self.decimation
grid = np.mgrid[0 : dim_0 : self.decimation, 0 : dim_1 : self.decimation]
loc_yx = np.moveaxis(grid, 0, -1)
# Adjust by scaling factor
loc_yx = loc_yx * self.scale
if self.decimation > 1:
delta_yx = self._decimate(raw_delta_yx, self.decimation)
# Shift locations to midpoint of decimation square
loc_yx += self.decimation // 2
else:
delta_yx = raw_delta_yx
delta_yx = delta_yx * self.scale
# Split into x,y matrices
loc_y, loc_x = loc_yx[..., 0], loc_yx[..., 1]
delta_y, delta_x = delta_yx[..., 0], delta_yx[..., 1]
# Determine vector endpoint
x2 = delta_x * self.decimation + loc_x
y2 = delta_y * self.decimation + loc_y
line_length = (delta_x ** 2 + delta_y ** 2) ** 0.5
# Determine points for arrow
arrow_head_size = line_length / 4
u_dx = np.divide(
delta_x, line_length, out=np.zeros_like(delta_x), where=line_length != 0
)
u_dy = np.divide(
delta_y, line_length, out=np.zeros_like(delta_y), where=line_length != 0
)
p1_x = x2 - u_dx * arrow_head_size - u_dy * arrow_head_size
p1_y = y2 - u_dy * arrow_head_size + u_dx * arrow_head_size
p2_x = x2 - u_dx * arrow_head_size + u_dy * arrow_head_size
p2_y = y2 - u_dy * arrow_head_size - u_dx * arrow_head_size
# Build list of QPointF objects for faster drawing
y_x_pairs = itertools.product(
range(delta_yx.shape[0]), range(delta_yx.shape[1])
)
for y, x in y_x_pairs:
x1, y1 = loc_x[y, x], loc_y[y, x]
if line_length[y, x] > min_length:
points.append((x1, y1))
points.append((x2[y, x], y2[y, x]))
points.append((p1_x[y, x], p1_y[y, x]))
points.append((x2[y, x], y2[y, x]))
points.append((p2_x[y, x], p2_y[y, x]))
points.append((x2[y, x], y2[y, x]))
self.points = list(itertools.starmap(QtCore.QPointF, points))
def _decimate(self, image: np.array, box: int):
height = width = box
# Source: https://stackoverflow.com/questions/48482317/slice-an-image-into-tiles-using-numpy
_nrows, _ncols, depth = image.shape
_size = image.size
_strides = image.strides
nrows, _m = divmod(_nrows, height)
ncols, _n = divmod(_ncols, width)
if _m != 0 or _n != 0:
# if we can't tile whole image, forget about bottom/right edges
image = image[: (nrows + 1) * box, : (ncols + 1) * box]
tiles = np.lib.stride_tricks.as_strided(
np.ravel(image),
shape=(nrows, ncols, height, width, depth),
strides=(height * _strides[0], width * _strides[1], *_strides),
writeable=False,
)
# Since strides accesses the ndarray by memory, we need to swap axes if
# the array is stored column-major (Fortran), which it is from h5py.
if _strides[0] < _strides[1]:
tiles = np.swapaxes(tiles, 0, 1)
return np.mean(tiles, axis=(2, 3))
[docs] def boundingRect(self) -> QtCore.QRectF:
"""Method called by Qt in order to determine whether object is in visible frame."""
return QtCore.QRectF(self.rect)
[docs] def paint(self, painter, option, widget=None):
"""Method called by Qt to draw object."""
if self.pen is not None:
painter.setPen(self.pen)
painter.drawLines(self.points)
pass
def show_pafs_from_h5(filename, input_format="channels_last", standalone=False):
video = HDF5Video(filename, "/box", input_format=input_format)
paf_data = HDF5Video(
filename, "/pafs", input_format=input_format, convert_range=False
)
pafs_ = [paf_data.get_frame(i) for i in range(paf_data.frames)]
pafs = np.stack(pafs_)
return demo_pafs(pafs, video, standalone=standalone)
def demo_pafs(pafs, video, decimation=4, scale=None, standalone=False):
from sleap.gui.widgets.video import QtVideoPlayer
if standalone:
app = QtWidgets.QApplication([])
win = QtVideoPlayer(video=video)
win.setWindowTitle("pafs")
decimation_size_bar = QtWidgets.QSlider(QtCore.Qt.Horizontal)
decimation_size_bar.valueChanged.connect(lambda e: win.plot())
decimation_size_bar.setValue(decimation)
decimation_size_bar.setMinimum(1)
decimation_size_bar.setMaximum(10)
decimation_size_bar.setEnabled(True)
win.layout.addWidget(decimation_size_bar)
win.show()
def plot_fields(parent, frame_idx):
if frame_idx < pafs.shape[0]:
frame_pafs = pafs[frame_idx, ...]
decimation = decimation_size_bar.value()
aff_fields_item = MultiQuiverPlot(
frame_pafs, show=None, decimation=decimation
)
if scale:
aff_fields_item.setScale(scale)
win.view.scene.addItem(aff_fields_item)
win.changedPlot.connect(plot_fields)
win.plot()
if standalone:
app.exec_()
return win
if __name__ == "__main__":
data_path = "tests/data/hdf5_format_v1/training.scale=0.50,sigma=10.h5"
input_format = "channels_first"
show_pafs_from_h5(data_path, input_format=input_format, standalone=True)