sleap.nn.architectures.upsampling

This module defines common upsampling layer stack configurations.

The generic upsampling stack consists of:
  • transposed convolution or bilinear upsampling with stride > 1

  • skip connections

  • 0 or more 3x3 convolutions for refinement

Configuring these components suffices to define the “decoder” portion of most canonical “encoder-decoder”-like architectures (e.g., LEAP CNN, UNet, Hourglass, etc.), as well as simpler patterns like shallow or direct upsampling (e.g., DLC).
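The following plain-Python model, which is not SLEAP code, shows roughly how these components compose by listing the operations of one upsampling step in order. The function `describe_upsampling_step` and its `has_skip_source` flag are hypothetical. The ordering (upsample, then skip fusion, then refinement) is inferred from the attribute descriptions below. The sketch also assumes that a ReLU follows every learnable convolution and that no activation follows bilinear upsampling.

```python
from typing import List


def describe_upsampling_step(
    transposed_conv: bool = True,
    transposed_conv_batchnorm: bool = True,
    has_skip_source: bool = True,  # hypothetical: make_skip_connection AND a source exists
    skip_add: bool = False,
    refine_convs: int = 2,
    refine_convs_batchnorm: bool = True,
) -> List[str]:
    """Return the ordered operations of one generic upsampling step (illustrative only)."""
    ops: List[str] = []
    # 1. Upsample by upsampling_stride, either learnably or with fixed interpolation.
    if transposed_conv:
        ops.append("transposed_conv")
        if transposed_conv_batchnorm:
            ops.append("batchnorm")  # batch norm sits before the activation
        ops.append("relu")  # assumption: ReLU follows the transposed conv
    else:
        ops.append("bilinear_upsample")  # assumption: no activation afterwards
    # 2. Fuse a skip connection if a source at the new stride is available.
    if has_skip_source:
        ops.append("skip_add" if skip_add else "skip_concat")
    # 3. Zero or more 3x3 refinement convolutions.
    for _ in range(refine_convs):
        ops.append("conv3x3")
        if refine_convs_batchnorm:
            ops.append("batchnorm")
        ops.append("relu")
    return ops
```

With the defaults, a step reads: transposed conv, batch norm, ReLU, concatenation, then two (conv, batch norm, ReLU) refinement blocks. Setting `transposed_conv=False`, `has_skip_source=False`, and `refine_convs=0` reduces it to a single bilinear upsampling, roughly the "direct upsampling" pattern mentioned above.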

class sleap.nn.architectures.upsampling.UpsamplingStack(output_stride: int, upsampling_stride: int = 2, transposed_conv: bool = True, transposed_conv_filters: int = 64, transposed_conv_filters_rate: float = 1, transposed_conv_kernel_size: int = 4, transposed_conv_batchnorm: bool = True, make_skip_connection: bool = True, skip_add: bool = False, refine_convs: int = 2, refine_convs_filters: int = 64, refine_convs_filters_rate: float = 1, refine_convs_batchnorm: bool = True)

Standard stack of upsampling layers with refinement and skip connections.

output_stride

The desired final stride of the output tensor of the stack.

Type

int

upsampling_stride

The stride of each upsampling layer (not of the tensor relative to the input). This is typically set to 2, so the tensor doubles in height and width with each upsampling step. It can be set higher to reach the desired output_stride directly in fewer steps. See the notes in the make_stack method for examples.

Type

int

transposed_conv

If True, use a strided transposed convolution to perform learnable upsampling. If False, bilinear upsampling will be used instead.

Type

bool

transposed_conv_filters

Integer that specifies the base number of filters in each transposed convolution layer. This will be scaled by the transposed_conv_filters_rate at every upsampling step. No effect if bilinear upsampling is used.

Type

int

transposed_conv_filters_rate

Factor to scale the number of filters in the transposed convolution layer after each upsampling step. If set to 1, the number of filters won’t change. No effect if bilinear upsampling is used.

Type

float

transposed_conv_kernel_size

Size of the kernel for the transposed convolution. No effect if bilinear upsampling is used.

Type

int
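Kernel size interacts with upsampling_stride. Transposed convolutions have a general property, not specific to SLEAP: when the kernel size is divisible by the stride, every interior output position receives the same number of contributions. Otherwise the overlap alternates between positions, which is one known source of checkerboard artifacts. The hypothetical helpers below count these contributions for a 1D transposed convolution without padding. The default pairing of a kernel size of 4 with a stride of 2 gives even overlap.

```python
from typing import List


def transposed_conv_overlap(length: int, stride: int, kernel_size: int) -> List[int]:
    """Count how many (input element, kernel tap) pairs write to each output
    position of a 1D transposed convolution with no padding."""
    out_len = (length - 1) * stride + kernel_size
    counts = [0] * out_len
    for i in range(length):  # each input element...
        for j in range(kernel_size):  # ...is scattered through every kernel tap
            counts[i * stride + j] += 1
    return counts


def interior(counts: List[int], kernel_size: int) -> List[int]:
    """Drop the borders, where fewer inputs overlap regardless of kernel size."""
    return counts[kernel_size:-kernel_size]
```

For example, `interior(transposed_conv_overlap(6, 2, 4), 4)` is all 2s. With a kernel size of 3, the result alternates between 1 and 2. Even overlap reduces checkerboard artifacts but does not by itself guarantee their absence, which is one reason for the refinement convolutions described below.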

transposed_conv_batchnorm

Specifies whether batch norm should be applied after the transposed convolution (and before the ReLU activation). No effect if bilinear upsampling is used.

Type

bool

make_skip_connection

If True, incoming feature tensors form skip connections with the upsampled features. If False, no skip connections are formed.

Type

bool

skip_add

If True, incoming feature tensors form skip connections with the upsampled features via element-wise addition. Height and width are matched by selecting a source with the same stride, and a linear 1x1 convolution is applied if the channel counts do not match. If False, the skip connection is formed via channel-wise concatenation.

Type

bool
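NumPy can illustrate how skip_add affects shapes if feature maps are treated as (height, width, channels) arrays at the same stride. This is not SLEAP's implementation. The function `fuse_skip` is hypothetical. It assumes that the skip source, not the upsampled tensor, is projected by the 1x1 convolution, and random weights stand in for learned ones.

```python
from typing import Optional

import numpy as np


def fuse_skip(upsampled: np.ndarray, skip: np.ndarray, skip_add: bool,
              rng: Optional[np.random.Generator] = None) -> np.ndarray:
    """Fuse a skip source into upsampled features; both are (H, W, C) arrays."""
    if upsampled.shape[:2] != skip.shape[:2]:
        raise ValueError("skip source must match the upsampled height/width")
    if not skip_add:
        # Channel-wise concatenation: the channel counts add up.
        return np.concatenate([upsampled, skip], axis=-1)
    c_up, c_skip = upsampled.shape[-1], skip.shape[-1]
    if c_skip != c_up:
        # A linear 1x1 convolution is a per-pixel matrix multiply over channels.
        # Assumption: the skip source is projected; random weights stand in for learned ones.
        if rng is None:
            rng = np.random.default_rng(0)
        w = rng.standard_normal((c_skip, c_up))
        skip = skip @ w  # (H, W, c_skip) @ (c_skip, c_up) -> (H, W, c_up)
    # Element-wise addition keeps the upsampled channel count.
    return upsampled + skip
```

Fusing a (8, 8, 64) upsampled map with a (8, 8, 32) source yields 96 channels under concatenation and 64 under addition. The concatenation path therefore widens the input to the refinement convolutions, while the addition path keeps it fixed at the cost of an extra projection when the channel counts differ.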

refine_convs

If greater than 0, specifies the number of 3x3 convolutions applied after the upsampling step for refinement. These layers can serve to “mix” the features fused by the skip connection, or to refine the current feature map after upsampling, which can help prevent aliasing and checkerboard effects. If 0, no additional convolutions are applied.

Type

int

refine_convs_filters

Analogous to transposed_conv_filters, specifies the base number of filters for the refinement convolutions in each upsampling step. This is scaled by refine_convs_filters_rate at every upsampling step. No effect if refine_convs is 0.

Type

int

refine_convs_filters_rate

Factor to scale the number of filters in the refinement convolutions after each upsampling step. All convolutions within the same upsampling step use the same number of filters. If set to 1, the number of filters won’t change. No effect if refine_convs is 0.

Type

float
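Together, transposed_conv_filters, refine_convs_filters, and their rate counterparts define a per-step filter schedule. The following minimal sketch assumes that the filter count at step i (counting from 0) is the base count multiplied by rate ** i and truncated to an integer. The helper `filter_schedule` is hypothetical.

```python
from typing import List, Tuple


def filter_schedule(
    num_steps: int,
    transposed_conv_filters: int = 64,
    transposed_conv_filters_rate: float = 1,
    refine_convs: int = 2,
    refine_convs_filters: int = 64,
    refine_convs_filters_rate: float = 1,
) -> List[Tuple[int, List[int]]]:
    """Per step: (transposed conv filters, [filters for each refinement conv])."""
    schedule = []
    for step in range(num_steps):
        # Assumption: scale the base count by rate ** step and truncate to int.
        up_filters = int(transposed_conv_filters * transposed_conv_filters_rate ** step)
        refine_filters = int(refine_convs_filters * refine_convs_filters_rate ** step)
        # Every refinement conv within a step shares the same filter count.
        schedule.append((up_filters, [refine_filters] * refine_convs))
    return schedule
```

Under this assumption, a base of 256 with rates of 0.5 gives 256, 128, and 64 filters over three steps, narrowing the decoder as resolution increases. The defaults (64 and 1) keep every step at 64 filters.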

refine_convs_batchnorm

Specifies whether batch norm should be applied after each 3x3 convolution and before the ReLU activation. No effect if refine_convs is 0.

Type

bool

classmethod from_config(config: sleap.nn.config.model.UpsamplingConfig, output_stride: int) → sleap.nn.architectures.upsampling.UpsamplingStack

Create an instance of this class from a set of configuration parameters.

Parameters
  • config – An UpsamplingConfig instance with the desired parameters.

  • output_stride – The desired final stride of the output tensor of the stack.

Returns

An instance of this class with the specified configuration.

make_stack(x: tensorflow.python.framework.ops.Tensor, current_stride: int, skip_sources: Optional[Sequence[sleap.nn.architectures.common.IntermediateFeature]] = None) → Tuple[tensorflow.python.framework.ops.Tensor, List[sleap.nn.architectures.common.IntermediateFeature]]

Create the stack of upsampling layers.

Parameters
  • x – Feature tensor at low resolution, typically from encoder/backbone.

  • current_stride – Stride of the feature tensor relative to the network input.

  • skip_sources – If a list of IntermediateFeature instances is provided, it is searched after each upsampling step for source tensors with a matching stride. At each level, the first element of the list with a matching stride is selected as the source. The skip connection is a concatenation or an addition, depending on the skip_add class attribute.
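The skip source lookup described for skip_sources can be sketched as a first-match search by stride. `Feature` below is a hypothetical stand-in for IntermediateFeature and is assumed to expose a tensor and the stride at which it was produced. `find_skip_source` is also hypothetical.

```python
from dataclasses import dataclass
from typing import Any, Optional, Sequence


@dataclass
class Feature:
    """Hypothetical stand-in for IntermediateFeature: a tensor plus its stride."""
    tensor: Any
    stride: int


def find_skip_source(skip_sources: Optional[Sequence[Feature]],
                     new_stride: int) -> Optional[Feature]:
    """Return the first feature whose stride equals new_stride, or None."""
    if skip_sources is None:
        return None  # no sources given: no skip connection at any level
    for feat in skip_sources:
        if feat.stride == new_stride:
            return feat  # the first match wins; later duplicates are ignored
    return None  # no source at this stride: this level gets no skip connection
```

Because the first match wins, the order of skip_sources matters whenever several encoder features share a stride.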

Returns

A tuple of the resulting upsampled tensor at the stride specified by the output_stride class attribute, and a list of intermediate tensors after each upsampling step.

The intermediate features are useful when creating multi-head architectures with different output strides for the heads.
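For example, a multi-head model could read each head from the intermediate feature whose stride matches that head. The sketch below represents intermediate features only by their strides and assumes that current_stride / output_stride is an exact power of upsampling_stride. `intermediate_strides` and `assign_heads` are hypothetical helpers.

```python
from typing import Dict, List


def intermediate_strides(current_stride: int, output_stride: int,
                         upsampling_stride: int = 2) -> List[int]:
    """Stride of the feature produced after each upsampling step."""
    strides = []
    stride = current_stride
    while stride > output_stride:
        stride //= upsampling_stride  # each step divides the stride
        strides.append(stride)
    return strides


def assign_heads(strides: List[int], head_strides: Dict[str, int]) -> Dict[str, int]:
    """Map each head name to the index of the intermediate feature at its stride.

    Raises ValueError if a head requests a stride that no step produces.
    """
    return {name: strides.index(s) for name, s in head_strides.items()}
```

Going from stride 32 to stride 4 with upsampling_stride = 2 produces features at strides 16, 8, and 4. A head at stride 8 would then read the second intermediate feature, and a head at stride 4 would read the final output.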

Note

The number of upsampling steps is determined by the current_stride argument and the output_stride and upsampling_stride class attributes.

Specifically, the number of upsampling steps is equal to:

log(current_stride) - log(output_stride)

where the log base is equal to the upsampling_stride.

These can be used to control the number of upsampling steps indirectly, for example:

If current_stride = 16 and the desired output_stride = 4, then upsampling_stride = 2 takes 2 upsampling steps, while upsampling_stride = 4 takes 1 upsampling step.
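The step count in the formula above can be computed exactly with integer arithmetic, which avoids floating-point error from taking logarithms. This is a sketch, not SLEAP's code. `num_upsampling_steps` is hypothetical, and raising an error when current_stride / output_stride is not a power of upsampling_stride is a choice made here. The library may handle that case differently.

```python
def num_upsampling_steps(current_stride: int, output_stride: int,
                         upsampling_stride: int) -> int:
    """Integer form of log_k(current_stride) - log_k(output_stride), k = upsampling_stride."""
    if upsampling_stride < 2:
        raise ValueError("upsampling_stride must be at least 2")
    if current_stride % output_stride != 0:
        raise ValueError("current_stride must be a multiple of output_stride")
    ratio = current_stride // output_stride
    steps = 0
    while ratio > 1:
        if ratio % upsampling_stride != 0:
            # Sketch's choice: reject ratios that are not exact powers of k.
            raise ValueError("current_stride / output_stride is not a power of upsampling_stride")
        ratio //= upsampling_stride
        steps += 1
    return steps
```

This reproduces the example above: `num_upsampling_steps(16, 4, 2)` returns 2 and `num_upsampling_steps(16, 4, 4)` returns 1.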