"""This module defines the main SLEAP model class for defining a trainable model.
This is a higher level wrapper around `tf.keras.Model` that holds all the configuration
parameters required to construct the actual model. This allows for easy querying of the
model configuration without actually instantiating the model itself.
"""
import tensorflow as tf
import attr
from typing import List, TypeVar, Optional, Text, Tuple
import sleap
from sleap.nn.architectures import (
LeapCNN,
UNet,
Hourglass,
ResNetv1,
ResNet50,
ResNet101,
ResNet152,
IntermediateFeature,
)
from sleap.nn.heads import (
CentroidConfmapsHead,
SingleInstanceConfmapsHead,
CenteredInstanceConfmapsHead,
MultiInstanceConfmapsHead,
PartAffinityFieldsHead,
)
from sleap.nn.config import (
LEAPConfig,
UNetConfig,
HourglassConfig,
ResNetConfig,
SingleInstanceConfmapsHeadConfig,
CentroidsHeadConfig,
CenteredInstanceConfmapsHeadConfig,
MultiInstanceConfig,
BackboneConfig,
HeadsConfig,
ModelConfig,
)
from sleap.nn.data.utils import ensure_list
# Backbone architecture classes that a `Model` can be built on.
ARCHITECTURES = [
    LeapCNN,
    UNet,
    Hourglass,
    ResNetv1,
    ResNet50,
    ResNet101,
    ResNet152,
]

# Class names of the supported backbones, in the same order as `ARCHITECTURES`.
ARCHITECTURE_NAMES = [architecture.__name__ for architecture in ARCHITECTURES]

# Type variable constrained to the supported backbone classes.
Architecture = TypeVar("Architecture", *ARCHITECTURES)

# Lookup from a backbone configuration type to the architecture class that is
# constructed from it (see `Model.from_config`).
BACKBONE_CONFIG_TO_CLS = {
    LEAPConfig: LeapCNN,
    UNetConfig: UNet,
    HourglassConfig: Hourglass,
    ResNetConfig: ResNetv1,
}

# Output head classes that can be attached to a backbone.
HEADS = [
    CentroidConfmapsHead,
    SingleInstanceConfmapsHead,
    CenteredInstanceConfmapsHead,
    MultiInstanceConfmapsHead,
    PartAffinityFieldsHead,
]

# Type variable constrained to the supported head classes.
Head = TypeVar("Head", *HEADS)
@attr.s(auto_attribs=True)
class Model:
    """SLEAP model that describes an architecture and output types.

    Attributes:
        backbone: An `Architecture` class that provides methods for building a
            tf.keras.Model given an input.
        heads: List of `Head`s that define the outputs of the network.
        keras_model: The current `tf.keras.Model` instance if one has been created.
    """

    backbone: Architecture
    heads: List[Head] = attr.ib(converter=ensure_list)
    keras_model: Optional[tf.keras.Model] = None

    @staticmethod
    def _require_skeleton(skeleton: Optional[sleap.Skeleton]) -> sleap.Skeleton:
        """Return `skeleton`, raising `ValueError` if it was not provided."""
        if skeleton is None:
            raise ValueError(
                "Skeleton must be provided when the head configuration is "
                "incomplete."
            )
        return skeleton

    @classmethod
    def _resolve_part_names(
        cls,
        confmaps_config,
        skeleton: Optional[sleap.Skeleton],
        update_config: bool,
    ):
        """Return the part names of a confidence map head config.

        Falls back to the skeleton node names when the config does not specify
        them, optionally writing the fallback back into the config.
        """
        part_names = confmaps_config.part_names
        if part_names is None:
            part_names = cls._require_skeleton(skeleton).node_names
            if update_config:
                confmaps_config.part_names = part_names
        return part_names

    @classmethod
    def from_config(
        cls,
        config: ModelConfig,
        skeleton: Optional[sleap.Skeleton] = None,
        update_config: bool = False,
    ) -> "Model":
        """Create a SLEAP model from configurations.

        Args:
            config: The configurations as a `ModelConfig` instance.
            skeleton: A `sleap.Skeleton` to use if not provided in the config.
            update_config: If `True`, part names and edges that had to be filled
                in from the skeleton are written back into the head
                configuration. The backbone configuration's `output_stride` is
                always overwritten with the stride required by the heads,
                regardless of this flag.

        Returns:
            An instance of `Model` built with the specified configurations.

        Raises:
            ValueError: If the head configuration is missing part names or edges
                and no skeleton was provided, or if the head configuration type
                is not supported.
        """
        # Figure out which backbone class to use.
        backbone_config = config.backbone.which_oneof()
        backbone_cls = BACKBONE_CONFIG_TO_CLS[type(backbone_config)]

        # Figure out which head class(es) to use.
        head_config = config.heads.which_oneof()
        if isinstance(head_config, SingleInstanceConfmapsHeadConfig):
            part_names = cls._resolve_part_names(head_config, skeleton, update_config)
            heads = SingleInstanceConfmapsHead.from_config(
                head_config, part_names=part_names
            )
            output_stride = heads.output_stride

        elif isinstance(head_config, CentroidsHeadConfig):
            heads = CentroidConfmapsHead.from_config(head_config)
            output_stride = heads.output_stride

        elif isinstance(head_config, CenteredInstanceConfmapsHeadConfig):
            part_names = cls._resolve_part_names(head_config, skeleton, update_config)
            heads = CenteredInstanceConfmapsHead.from_config(
                head_config, part_names=part_names
            )
            output_stride = heads.output_stride

        elif isinstance(head_config, MultiInstanceConfig):
            part_names = cls._resolve_part_names(
                head_config.confmaps, skeleton, update_config
            )

            edges = head_config.pafs.edges
            if edges is None:
                edges = cls._require_skeleton(skeleton).edge_names
                if update_config:
                    head_config.pafs.edges = edges

            heads = [
                MultiInstanceConfmapsHead.from_config(
                    head_config.confmaps, part_names=part_names
                ),
                PartAffinityFieldsHead.from_config(head_config.pafs, edges=edges),
            ]
            # The backbone must reach the finest stride requested by either head.
            output_stride = min(heads[0].output_stride, heads[1].output_stride)

        else:
            # Without this branch `heads`/`output_stride` would be unbound below.
            raise ValueError(
                f"Unsupported head configuration type: {type(head_config).__name__}."
            )

        # The backbone is built to output at the stride required by the heads.
        backbone_config.output_stride = output_stride

        return cls(backbone=backbone_cls.from_config(backbone_config), heads=heads)

    @property
    def maximum_stride(self) -> int:
        """Return the maximum stride of the model backbone."""
        return self.backbone.maximum_stride

    @staticmethod
    def _make_output_layers(head, tensors) -> list:
        """Apply a 1x1 convolution with `head.channels` filters to each tensor."""
        return [
            tf.keras.layers.Conv2D(
                filters=head.channels,
                kernel_size=1,
                strides=1,
                padding="same",
                name=f"{type(head).__name__}_{i}",
            )(x)
            for i, x in enumerate(tensors)
        ]

    def make_model(self, input_shape: Tuple[int, int, int]) -> tf.keras.Model:
        """Create a trainable model from the configuration.

        The created model is also stored in the `keras_model` attribute.

        Args:
            input_shape: Tuple of (height, width, channels) specifying the shape of the
                inputs before preprocessing.

        Returns:
            An instantiated `tf.keras.Model`.

        Raises:
            ValueError: If no backbone output has the stride required by one of
                the heads.
        """
        # Create input layer.
        x_in = tf.keras.layers.Input(input_shape, name="input")

        # Create backbone.
        x_main, x_mid = self.backbone.make_backbone(x_in=x_in)

        # Make sure main and intermediate feature outputs are lists. The check is
        # on the container type rather than `tf.Tensor` because Keras symbolic
        # tensors are not `tf.Tensor` instances in newer TensorFlow releases
        # (e.g. `KerasTensor` in TF >= 2.4).
        if not isinstance(x_main, (list, tuple)):
            x_main = [x_main]
        if len(x_mid) > 0 and isinstance(x_mid[0], IntermediateFeature):
            x_mid = [x_mid]

        # Build output layers for each head.
        x_outs = []
        for output in self.heads:
            x_head = []
            if output.output_stride == self.backbone.output_stride:
                # The main output has the same stride as the head, so build output
                # layers from those tensors.
                x_head = self._make_output_layers(output, x_main)
            else:
                # Look for an intermediate activation that has the correct stride.
                for feats in zip(*x_mid):
                    # TODO: Test for this assumption?
                    assert all(
                        feat.stride == feats[0].stride for feat in feats
                    ), "Intermediate features at the same level differ in stride."
                    if feats[0].stride == output.output_stride:
                        x_head = self._make_output_layers(
                            output, [feat.tensor for feat in feats]
                        )
                        break

            if not x_head:
                raise ValueError(
                    f"Could not find a feature activation for output at stride "
                    f"{output.output_stride}."
                )
            x_outs.append(x_head)
        # TODO: Warn/error if x_main was not connected to any heads?

        # Create model.
        self.keras_model = tf.keras.Model(inputs=x_in, outputs=x_outs)
        return self.keras_model