"""Source code for sleap.gui.learning.utils."""

from typing import Any, Dict, Optional, Text

import attr
import cattr
import numpy as np

from sleap import Skeleton
from sleap.nn.config import TrainingJobConfig, ModelConfig
from sleap.nn.model import Model


@attr.s(auto_attribs=True)
class ScopedKeyDict:
    """Convert between flat dicts with dotted "scoped" keys and nested dicts.

    E.g., ``{"a.b": 1}`` <-> ``{"a": {"b": 1}}``.

    Attributes:
        key_val_dict: Flat dictionary whose keys may contain "." separators
            encoding nesting scopes.
    """

    key_val_dict: Dict[Text, Any]

    @classmethod
    def set_hierarchical_key_val(cls, current_dict: dict, key: Text, val: Any):
        """Recursively insert `val` into `current_dict` at the dotted `key`.

        Keys starting with "_" are treated as private and silently ignored.
        """
        # Ignore "private" keys starting with "_". `startswith` (rather than
        # `key[0]`) also tolerates an empty key without raising IndexError.
        if key.startswith("_"):
            return

        if "." not in key:
            current_dict[key] = val
        else:
            top_key, *subkey_list = key.split(".")
            # Create the intermediate dict for this scope if needed.
            if top_key not in current_dict:
                current_dict[top_key] = dict()
            subkey = ".".join(subkey_list)
            cls.set_hierarchical_key_val(current_dict[top_key], subkey, val)

    def to_hierarchical_dict(self) -> dict:
        """Return a nested dict built from the flat scoped key/value pairs."""
        hierarch_dict = dict()
        for key, val in self.key_val_dict.items():
            self.set_hierarchical_key_val(hierarch_dict, key, val)
        return hierarch_dict

    @classmethod
    def from_hierarchical_dict(cls, hierarch_dict: dict) -> "ScopedKeyDict":
        """Construct a `ScopedKeyDict` by flattening a nested dict."""
        return cls(key_val_dict=cls._make_flattened_dict(hierarch_dict))

    @classmethod
    def _make_flattened_dict(
        cls, hierarch_dicts: dict, scope_string: Text = ""
    ) -> dict:
        """Flatten nested dicts into a single dict with dotted keys."""
        flattened_dict = dict()
        for key, val in hierarch_dicts.items():
            # Check against the builtin `dict`; isinstance with `typing.Dict`
            # is deprecated and unnecessary here.
            if isinstance(val, dict):
                # Dict so recurse, adding this node's key to the scope string.
                flattened_dict.update(
                    cls._make_flattened_dict(
                        hierarch_dicts=val,
                        scope_string=cls._subscope_key(scope_string, key),
                    )
                )
            else:
                # Leaf (non-dict) value: store under its fully scoped key.
                flattened_dict[cls._subscope_key(scope_string, key)] = val
        return flattened_dict

    @staticmethod
    def _subscope_key(scope_string: Text, key: Text) -> Text:
        """Join a scope prefix and a key with "." (no prefix when empty)."""
        return key if not scope_string else f"{scope_string}.{key}"


def apply_cfg_transforms_to_key_val_dict(key_val_dict):
    """Normalize special entries of a flat scoped-key config dict, in place.

    Transforms applied:
        * A comma-separated "outputs.tags" string becomes a list of
          stripped tag strings.
        * The "_ensure_channels" shorthand ("rgb" / "grayscale", any case)
          is expanded into the two boolean
          "data.preprocessing.ensure_rgb" / "ensure_grayscale" keys.
        * An empty-string ResNet skip-connections value is replaced by None.

    Args:
        key_val_dict: Flat dictionary with dotted scoped keys; mutated.

    Returns:
        None. The dict is modified in place.
    """
    tags = key_val_dict.get("outputs.tags")
    if isinstance(tags, str):
        key_val_dict["outputs.tags"] = [tag.strip() for tag in tags.split(",")]

    if "_ensure_channels" in key_val_dict:
        channels = key_val_dict["_ensure_channels"].lower()
        # Exactly one (or neither) of these can be True.
        key_val_dict["data.preprocessing.ensure_rgb"] = channels == "rgb"
        key_val_dict["data.preprocessing.ensure_grayscale"] = (
            channels == "grayscale"
        )

    skip_key = "model.backbone.resnet.upsampling.skip_connections"
    if key_val_dict.get(skip_key) == "":
        key_val_dict[skip_key] = None


def make_training_config_from_key_val_dict(key_val_dict):
    """Build a `TrainingJobConfig` from a flat dict with dotted scoped keys.

    Args:
        key_val_dict: Flat dictionary of config values; normalized in place
            by `apply_cfg_transforms_to_key_val_dict` before structuring.

    Returns:
        The structured `TrainingJobConfig` instance.
    """
    apply_cfg_transforms_to_key_val_dict(key_val_dict)
    hierarchical_cfg = ScopedKeyDict(key_val_dict).to_hierarchical_dict()
    return cattr.structure(hierarchical_cfg, TrainingJobConfig)


def make_model_config_from_key_val_dict(key_val_dict):
    """Build a `ModelConfig` from a flat dict with dotted scoped keys.

    Args:
        key_val_dict: Flat dictionary of config values; normalized in place
            by `apply_cfg_transforms_to_key_val_dict` before structuring.

    Returns:
        The structured `ModelConfig` instance.
    """
    apply_cfg_transforms_to_key_val_dict(key_val_dict)
    hierarchical_cfg = ScopedKeyDict(key_val_dict).to_hierarchical_dict()

    # Keys may be scoped under a top-level "model" section; unwrap if so.
    model_cfg_dict = hierarchical_cfg.get("model", hierarchical_cfg)

    return cattr.structure(model_cfg_dict, ModelConfig)


def compute_rf(
    down_blocks: int, convs_per_block: int = 2, kernel_size: int = 3
) -> int:
    """Compute the receptive field size of a conv-and-pool encoder.

    The backbone is modeled as `down_blocks` repetitions of a block made of
    `convs_per_block` stride-1 convolutions (`kernel_size` x `kernel_size`)
    followed by a 2 x 2, stride-2 pooling layer.

    Args:
        down_blocks: Number of downsampling blocks in the backbone.
        convs_per_block: Convolution layers in each block before pooling.
        kernel_size: Side length of each (square) convolution kernel.

    Returns:
        The receptive field size along one side, in pixels.

    Ref: https://distill.pub/2019/computing-receptive-fields/ (Eq. 2)
    """
    # Define the strides and kernel sizes for a single down block.
    # convs have stride 1, pooling has stride 2:
    block_strides = [1] * convs_per_block + [2]

    # convs have `kernel_size` x `kernel_size` kernels, pooling has 2 x 2 kernels:
    block_kernels = [kernel_size] * convs_per_block + [2]

    # Repeat block parameters by the total number of down blocks.
    strides = np.array(block_strides * down_blocks)
    kernels = np.array(block_kernels * down_blocks)

    # L = Total number of layers
    L = len(strides)

    # Compute the product term of the RF equation:
    #   r = 1 + sum_{l} (k_l - 1) * prod_{i < l} s_i
    rf = 1
    for l in range(L):
        rf += (kernels[l] - 1) * np.prod(strides[:l])

    return int(rf)
def receptive_field_info_from_model_cfg(model_cfg: ModelConfig) -> dict:
    """Inspect a model config and return receptive-field-related info.

    Args:
        model_cfg: The `ModelConfig` to instantiate and inspect.

    Returns:
        Dict with keys "size", "max_stride", "down_blocks",
        "convs_per_block" and "kernel_size". Any value that cannot be
        determined from the config/backbone is left as None; "size" is
        computed via `compute_rf` only when the three parameters it needs
        are all available and non-zero.
    """
    rf_info = dict(
        size=None,
        max_stride=None,
        down_blocks=None,
        convs_per_block=None,
        kernel_size=None,
    )

    try:
        # Build a throwaway model (with an empty skeleton) so the realized
        # backbone attributes can be read.
        model = Model.from_config(model_cfg, Skeleton())
    except ZeroDivisionError:
        # Unable to create model from these config parameters
        return rf_info

    backbone_cfg = model_cfg.backbone.which_oneof()
    if hasattr(backbone_cfg, "max_stride"):
        rf_info["max_stride"] = backbone_cfg.max_stride

    # Different backbones name the per-block conv count differently.
    if hasattr(model.backbone, "down_convs_per_block"):
        rf_info["convs_per_block"] = model.backbone.down_convs_per_block
    elif hasattr(model.backbone, "convs_per_block"):
        rf_info["convs_per_block"] = model.backbone.convs_per_block

    if hasattr(model.backbone, "kernel_size"):
        rf_info["kernel_size"] = model.backbone.kernel_size

    # Guarded for consistency with the attributes above: a backbone without
    # `down_blocks` should yield None rather than raising AttributeError.
    if hasattr(model.backbone, "down_blocks"):
        rf_info["down_blocks"] = model.backbone.down_blocks

    if rf_info["down_blocks"] and rf_info["convs_per_block"] and rf_info["kernel_size"]:
        rf_info["size"] = compute_rf(
            down_blocks=rf_info["down_blocks"],
            convs_per_block=rf_info["convs_per_block"],
            kernel_size=rf_info["kernel_size"],
        )

    return rf_info