# Source code for sleap.gui.learning.receptivefield

"""
Widget for previewing receptive field on sample image using model hyperparams.
"""
from sleap import Video
from sleap.nn.config import ModelConfig
from sleap.gui.widgets.video import GraphicsView

import numpy as np

from sleap import Skeleton
from sleap.nn.model import Model

from typing import Optional, Text

from PySide2 import QtWidgets, QtGui, QtCore


def compute_rf(down_blocks: int, convs_per_block: int = 2, kernel_size: int = 3) -> int:
    """
    Computes receptive field for specified model architecture.

    Each down block is modeled as `convs_per_block` convolutions (stride 1,
    `kernel_size` x `kernel_size` kernels) followed by one pooling layer
    (stride 2, 2 x 2 kernel).

    Ref: https://distill.pub/2019/computing-receptive-fields/ (Eq. 2)

    Args:
        down_blocks: Number of down blocks in the architecture.
        convs_per_block: Number of convolutions in each down block.
        kernel_size: Side length of the convolution kernels.

    Returns:
        The receptive field size as an int. With zero down blocks this is 1.
    """
    # Strides and kernel sizes for a single down block:
    # convs have stride 1, pooling has stride 2.
    block_strides = [1] * convs_per_block + [2]
    # convs have `kernel_size` x `kernel_size` kernels, pooling has 2 x 2.
    block_kernels = [kernel_size] * convs_per_block + [2]

    # Repeat block parameters by the total number of down blocks.
    strides = block_strides * down_blocks
    kernels = block_kernels * down_blocks

    # Sum (kernel - 1) * (product of all preceding strides) over the layers.
    # The stride product is carried along as a running value rather than
    # being recomputed from scratch for every layer.
    rf = 1
    preceding_stride_product = 1
    for kernel, stride in zip(kernels, strides):
        rf += (kernel - 1) * preceding_stride_product
        preceding_stride_product *= stride

    return int(rf)
def receptive_field_info_from_model_cfg(model_cfg: ModelConfig) -> dict:
    """Gets receptive field information given specific model configuration.

    A model is built from `model_cfg` (with an empty `Skeleton`) and its
    backbone is inspected for the attributes that determine the receptive
    field.

    Args:
        model_cfg: The model configuration to inspect.

    Returns:
        A dict with the keys "size", "max_stride", "down_blocks",
        "convs_per_block" and "kernel_size". Any value that could not be
        determined is None. All values are None if building the model raised
        `ZeroDivisionError`.
    """
    rf_info = dict(
        size=None,
        max_stride=None,
        down_blocks=None,
        convs_per_block=None,
        kernel_size=None,
    )

    try:
        model = Model.from_config(model_cfg, Skeleton())
    except ZeroDivisionError:
        # Unable to create model from these config parameters
        return rf_info

    backbone_cfg = model_cfg.backbone.which_oneof()
    if hasattr(backbone_cfg, "max_stride"):
        rf_info["max_stride"] = backbone_cfg.max_stride

    backbone = model.backbone

    # Prefer the down-path specific attribute when the backbone has one.
    if hasattr(backbone, "down_convs_per_block"):
        rf_info["convs_per_block"] = backbone.down_convs_per_block
    elif hasattr(backbone, "convs_per_block"):
        rf_info["convs_per_block"] = backbone.convs_per_block

    if hasattr(backbone, "kernel_size"):
        rf_info["kernel_size"] = backbone.kernel_size

    # Read defensively, like the other backbone attributes: a backbone without
    # `down_blocks` yields None here instead of raising AttributeError.
    rf_info["down_blocks"] = getattr(backbone, "down_blocks", None)

    # The size is only computed when all three inputs are known and non-zero.
    if rf_info["down_blocks"] and rf_info["convs_per_block"] and rf_info["kernel_size"]:
        rf_info["size"] = compute_rf(
            down_blocks=rf_info["down_blocks"],
            convs_per_block=rf_info["convs_per_block"],
            kernel_size=rf_info["kernel_size"],
        )

    return rf_info
class ReceptiveFieldWidget(QtWidgets.QWidget):
    """
    Captioned preview of a receptive field drawn over a sample image.

    Args:
        head_name: If non-empty, included in the caption header to show which
            model the preview is for.

    Usage:
        Construct the widget, then call `setImage` and `setModelConfig`.
    """

    def __init__(self, head_name: Text = "", *args, **kwargs):
        super().__init__(*args, **kwargs)

        # First line of the caption; mentions the head only when one is given.
        if head_name:
            self._info_text_header = f"<p>Receptive Field for {head_name}:</p>"
        else:
            self._info_text_header = "<p>Receptive Field:</p>"

        self._field_image_widget = ReceptiveFieldImageWidget()
        self._info_widget = QtWidgets.QLabel("")

        # Image on top, caption below, and a stretch to absorb spare space.
        self.layout = QtWidgets.QVBoxLayout()
        for child in (self._field_image_widget, self._info_widget):
            self.layout.addWidget(child)
        self.layout.addStretch()
        self.setLayout(self.layout)

    def _get_info_text(
        self, size, scale, max_stride, down_blocks, convs_per_block, kernel_size
    ) -> Text:
        """Builds the rich-text caption describing what sets the field size."""
        if size:
            size_line = f"<p><i>{size} pixels</i></p>"
        else:
            size_line = "<p><i>Unable to determine size</i></p>"

        explanation = f"""
            <p>Receptive field size is a function<br />
            of the number of down blocks ({down_blocks}), the<br />
            number of convolutions per block ({convs_per_block}),<br />
            and the convolution kernel size ({kernel_size}).</p>

            <p>You can control the number of down<br />
            blocks by setting the <b>Max Stride</b> ({max_stride}).</p>

            <p>The number of convolutions per block<br />
            and the kernel size are currently fixed<br />
            by your choice of backbone.</p>

            <p>You can also control the receptive<br />
            field size relative to the original<br />
            image by adjusting the <b>Input Scaling</b> ({scale}).</p>
            """

        return self._info_text_header + size_line + explanation

    def setModelConfig(self, model_cfg: ModelConfig, scale: float):
        """Refreshes both the caption and the preview box from a model config."""
        info = receptive_field_info_from_model_cfg(model_cfg)
        field_size = info["size"]

        caption = self._get_info_text(
            size=field_size,
            scale=scale,
            max_stride=info["max_stride"],
            down_blocks=info["down_blocks"],
            convs_per_block=info["convs_per_block"],
            kernel_size=info["kernel_size"],
        )
        self._info_widget.setText(caption)

        # An unknown size is passed as 0, which makes the image widget hide
        # its box.
        self._field_image_widget._set_field_size(field_size or 0, scale)

    def setImage(self, *args, **kwargs):
        """Forwards the image to the widget that draws the receptive field box."""
        self._field_image_widget.setImage(*args, **kwargs)
class ReceptiveFieldImageWidget(GraphicsView):
    """Image view that overlays a box showing the receptive field."""

    def __init__(self, *args, **kwargs):
        # NOTE(review): the state and the box item are created before the
        # base initializer runs, presumably so they already exist if
        # `viewportEvent` fires during construction -- confirm before
        # reordering.
        self._widget_size = 200
        self._pen_width = 4
        self._box_size = None
        self._scale = None

        outline = QtGui.QPen(QtGui.QColor("blue"), self._pen_width)
        outline.setCosmetic(True)

        self.box = QtWidgets.QGraphicsRectItem()
        self.box.setPen(outline)

        super().__init__(*args, **kwargs)

        self.setFixedSize(self._widget_size, self._widget_size)
        self.scene.addItem(self.box)

        # TODO: zoom around bounding box for labeled instance
        # self.zoomToRect(QtCore.QRectF(0, 0, 1, 1))

    def viewportEvent(self, event):
        """
        Overrides the QGraphicsView method to refresh the box before painting.
        """
        # Re-position and re-size the box on every paint event.
        is_paint = isinstance(event, QtGui.QPaintEvent)
        if is_paint:
            self._set_field_size()

        # Then let the base class handle the event.
        return super().viewportEvent(event)

    def _set_field_size(self, size: Optional[int] = None, scale: float = 1.0):
        """Positions the receptive field box, storing new size/scale if given.

        When `size` is None the previously stored size and scale are reused.
        A stored size that is falsy (None or 0) hides the box.
        """
        if size is not None:
            self._box_size, self._scale = size, scale

        if not self._box_size:
            self.box.hide()
            return

        self.box.show()

        # Adjust box relative to scaling on image that will happen in training
        side = self._box_size // self._scale

        # Size of the box as it appears in view coordinates.
        on_screen = self.mapFromScene(0, 0, side, side).boundingRect()

        # Top-left corner, in scene coordinates, that centers the box in the
        # fixed-size widget.
        half_widget = self._widget_size / 2
        top_left = self.mapToScene(
            half_widget - (on_screen.width() / 2),
            half_widget - (on_screen.height() / 2),
        )

        self.box.setRect(top_left.x(), top_left.y(), side, side)
def demo_receptive_field():
    """Shows a stand-alone preview window with a size-50 field on a test frame."""
    qt_app = QtWidgets.QApplication([])

    sample_video = Video.from_filename("tests/data/videos/centered_pair_small.mp4")

    preview = ReceptiveFieldImageWidget()
    first_frame = sample_video.get_frame(0)
    preview.setImage(first_frame)
    preview._set_field_size(50)
    preview.show()

    qt_app.exec_()


if __name__ == "__main__":
    demo_receptive_field()