# Source code for sleap.nn.architectures.common
"""Common utilities for architecture and model building."""
import attr
import tensorflow as tf
@attr.s(auto_attribs=True)
class IntermediateFeature:
    """An intermediate-layer tensor paired with its stride, for skip connections.

    Functionally this acts like a named tuple: it bundles a feature tensor with
    the stride (resolution) metadata describing where it was taken from.

    Attributes:
        tensor: Output tensor of an intermediate layer.
        stride: Stride of `tensor` relative to the network input.
    """

    # Explicit attr.ib() is equivalent to a bare annotation under auto_attribs.
    tensor: tf.Tensor = attr.ib()
    stride: int = attr.ib()

    @property
    def scale(self) -> float:
        """Scale of the tensor relative to the input, as a float.

        This is the reciprocal of the stride; for example, stride 2 gives a
        scale of 0.5.
        """
        stride_as_float = float(self.stride)
        return 1.0 / stride_as_float