sleap.nn.architectures.upsampling
This module defines common upsampling layer stack configurations.
The generic upsampling stack consists of:
- transposed convolution or bilinear upsampling with stride > 1
- skip connections
- 0 or more 3x3 convolutions for refinement
Configuring these components suffices to define the “decoder” portion of most canonical “encoder-decoder”-like architectures (e.g., LEAP CNN, UNet, Hourglass, etc.), as well as simpler patterns like shallow or direct upsampling (e.g., DLC).
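The order of operations within a single upsampling step can be sketched in plain Python. This is a minimal, hypothetical sketch: the helper upsampling_step_layers and its layer labels are illustrative and are not part of SLEAP's API. It assumes, consistent with the attribute descriptions below, that batch norm precedes the ReLU, that the skip connection is fused right after upsampling and before refinement, and (as an extra assumption) that bilinear upsampling is not followed by an activation.

```python
from typing import List


def upsampling_step_layers(
    transposed_conv: bool = True,
    transposed_conv_batchnorm: bool = True,
    make_skip_connection: bool = True,
    skip_add: bool = False,
    refine_convs: int = 2,
    refine_convs_batchnorm: bool = True,
) -> List[str]:
    """Return hypothetical, ordered layer labels for one upsampling step."""
    layers: List[str] = []

    # 1. Upsample: learnable transposed conv, or fixed bilinear interpolation.
    if transposed_conv:
        layers.append("Conv2DTranspose")
        if transposed_conv_batchnorm:
            layers.append("BatchNormalization")  # applied before the ReLU
        layers.append("ReLU")
    else:
        # Assumption: no activation follows bilinear upsampling.
        layers.append("UpSampling2D(bilinear)")

    # 2. Fuse the skip connection (a 1x1 projection that may precede "Add"
    #    when channel counts differ is omitted from this sketch).
    if make_skip_connection:
        layers.append("Add" if skip_add else "Concatenate")

    # 3. Refine with zero or more 3x3 convolutions.
    for _ in range(refine_convs):
        layers.append("Conv2D(3x3)")
        if refine_convs_batchnorm:
            layers.append("BatchNormalization")
        layers.append("ReLU")

    return layers
```

For example, upsampling_step_layers(transposed_conv=False, make_skip_connection=False, refine_convs=0) describes the simplest direct-upsampling step: a single bilinear interpolation layer.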
- class sleap.nn.architectures.upsampling.UpsamplingStack(output_stride: int, upsampling_stride: int = 2, transposed_conv: bool = True, transposed_conv_filters: int = 64, transposed_conv_filters_rate: float = 1, transposed_conv_kernel_size: int = 4, transposed_conv_batchnorm: bool = True, make_skip_connection: bool = True, skip_add: bool = False, refine_convs: int = 2, refine_convs_filters: int = 64, refine_convs_filters_rate: float = 1, refine_convs_batchnorm: bool = True)
Standard stack of upsampling layers with refinement and skip connections.
- output_stride
The desired final stride of the output tensor of the stack.
- Type
int
- upsampling_stride
The stride of the upsampling layer (not of the tensor). This is typically set to 2, such that the tensor doubles in size with each upsampling step, but it can be set higher to reach the desired output_stride directly in fewer steps. See the notes in the make_stack method for examples.
- Type
int
- transposed_conv
If True, use a strided transposed convolution to perform learnable upsampling. If False, bilinear upsampling will be used instead.
- Type
bool
- transposed_conv_filters
Integer that specifies the base number of filters in each transposed convolution layer. This will be scaled by the transposed_conv_filters_rate at every upsampling step. No effect if bilinear upsampling is used.
- Type
int
- transposed_conv_filters_rate
Factor to scale the number of filters in the transposed convolution layer after each upsampling step. If set to 1, the number of filters won’t change. No effect if bilinear upsampling is used.
- Type
float
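The interaction between transposed_conv_filters and transposed_conv_filters_rate can be sketched as a per-step schedule. This is a hypothetical helper (filters_per_step is not a SLEAP function) and it assumes that step i, counting from 0, uses transposed_conv_filters * transposed_conv_filters_rate ** i truncated to an integer; whether the first step is already scaled, and how fractional counts are rounded, are assumptions.

```python
from typing import List


def filters_per_step(base_filters: int, rate: float, num_steps: int) -> List[int]:
    """Hypothetical filter schedule: base_filters * rate**i for step i."""
    # Assumption: step 0 uses the base count; int() truncates toward zero.
    return [int(base_filters * rate**i) for i in range(num_steps)]
```

With the defaults (64 filters, rate 1) every step uses 64 filters; a rate of 0.5 would give a decoder that narrows as resolution increases, e.g. 64, 32, 16 over three steps.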
- transposed_conv_kernel_size
Size of the kernel for the transposed convolution. No effect if bilinear upsampling is used.
- Type
int
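To see how the kernel size relates to the output size, the sketch below applies the output-length rule that Keras uses for Conv2DTranspose when output_padding is not set. The helper name is hypothetical, and it is an assumption that this stack uses padding="same"; under that assumption the kernel size changes how many input pixels contribute to each output pixel, but not the output size, which is simply the input size times upsampling_stride.

```python
def transposed_conv_output_length(
    input_length: int, kernel_size: int, stride: int, padding: str = "same"
) -> int:
    """Spatial output length of a transposed conv (Keras rule, no output_padding)."""
    if padding == "same":
        # Kernel size does not affect the output size: input is scaled by stride.
        return input_length * stride
    if padding == "valid":
        # Overlapping kernel extent adds kernel_size - stride extra pixels.
        return input_length * stride + max(kernel_size - stride, 0)
    raise ValueError(f"Unsupported padding: {padding!r}")
```

For the default kernel size of 4 and an upsampling stride of 2, a 32-pixel-wide feature map becomes 64 pixels wide with "same" padding and 66 pixels wide with "valid" padding.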
- transposed_conv_batchnorm
Specifies whether batch norm should be applied after the transposed convolution (and before the ReLU activation). No effect if bilinear upsampling is used.
- Type
bool
- make_skip_connection
If True, incoming feature tensors form a skip connection with the upsampled features. If False, no skip connection will be formed.
- Type
bool
- skip_add
If True, incoming feature tensors form a skip connection with the upsampled features via element-wise addition. Height/width are matched via stride, and a 1x1 linear conv is applied if the channel counts do not match. If False, the skip connection is formed via channel-wise concatenation.
- Type
bool
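The effect of skip_add on the fused channel count can be sketched with simple arithmetic. Both helpers are hypothetical, and which tensor the 1x1 linear conv projects is an assumption here: this sketch assumes the skip tensor is projected to the upsampled tensor's channel count before the addition.

```python
def needs_projection(upsampled_channels: int, skip_channels: int, skip_add: bool) -> bool:
    """Whether a 1x1 linear conv is needed before fusing (hypothetical helper)."""
    # Only addition requires equal channel counts; concatenation does not.
    return skip_add and upsampled_channels != skip_channels


def fused_channels(upsampled_channels: int, skip_channels: int, skip_add: bool) -> int:
    """Channel count after fusing a skip connection (hypothetical helper)."""
    if skip_add:
        # Assumption: the skip tensor is projected to match, so addition
        # preserves the upsampled tensor's channel count.
        return upsampled_channels
    # Channel-wise concatenation stacks both tensors along the channel axis.
    return upsampled_channels + skip_channels
```

Concatenation therefore widens the tensor that the refinement convolutions receive (64 + 128 = 192 channels in this example), while addition keeps it at 64 channels.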
- refine_convs
If greater than 0, specifies the number of 3x3 convolutions that will be applied after the upsampling step for refinement. These layers can serve to "mix" the features fused by the skip connection, or to refine the current feature map after upsampling, which can help prevent aliasing and checkerboard artifacts. If 0, no additional convolutions will be applied.
- Type
int
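One way to quantify the "mixing" done by refine_convs is the receptive field of the stacked 3x3 convolutions on the upsampled feature map. This sketch assumes the refinement convolutions have stride 1 (they refine rather than resample the feature map), and the helper name is hypothetical.

```python
def stacked_conv_receptive_field(num_convs: int, kernel_size: int = 3) -> int:
    """Side length (pixels) of the receptive field of stride-1 stacked convs."""
    # Each stride-1 conv with kernel k widens the receptive field by k - 1.
    return 1 + num_convs * (kernel_size - 1)
```

With the default refine_convs = 2, each output pixel of a step mixes information from a 5x5 window of the fused feature map at that step's stride.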
- refine_convs_filters
Similar to transposed_conv_filters, specifies the base number of filters to use for the refinement convolutions in each upsampling step. No effect if refine_convs is 0.
- Type
int
- refine_convs_filters_rate
Factor to scale the number of filters in the refine conv layers after each upsampling step. The same number of filters is used for all convs within the same upsampling step. If set to 1, the number of filters won't change. No effect if refine_convs is 0.
- Type
float
- refine_convs_batchnorm
Specifies whether batch norm should be applied after each 3x3 convolution and before the ReLU activation. No effect if refine_convs is 0.
- Type
bool
- classmethod from_config(config: sleap.nn.config.model.UpsamplingConfig, output_stride: int) → sleap.nn.architectures.upsampling.UpsamplingStack
Create an upsampling stack from a set of configuration parameters.
- Parameters
config – An UpsamplingConfig instance with the desired parameters.
output_stride – The desired final stride of the output tensor of the stack.
- Returns
An instance of this class with the specified configuration.
- make_stack(x: tensorflow.python.framework.ops.Tensor, current_stride: int, skip_sources: Optional[Sequence[sleap.nn.architectures.common.IntermediateFeature]] = None) → Tuple[tensorflow.python.framework.ops.Tensor, List[sleap.nn.architectures.common.IntermediateFeature]]
Create the stack of upsampling layers.
- Parameters
x – Feature tensor at low resolution, typically from encoder/backbone.
current_stride – Stride of the feature tensor relative to the network input.
skip_sources – If a list of IntermediateFeature instances is provided, they will be searched to find source tensors with matching stride after each upsampling step. The first element of this list with a matching stride will be selected as the source at each level. The skip connection will be a concatenation or addition, depending on the skip_add class attribute.
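The "first element with a matching stride" rule for skip_sources can be sketched as a linear search. The Feature dataclass below is a hypothetical stand-in for IntermediateFeature that only keeps a name and a stride, and returning None when no source matches (i.e. forming no skip connection at that level) is an assumption about the fallback behavior.

```python
from dataclasses import dataclass
from typing import Optional, Sequence


@dataclass
class Feature:
    """Hypothetical stand-in for IntermediateFeature: a label and its stride."""

    name: str
    stride: int


def select_skip_source(
    skip_sources: Optional[Sequence[Feature]], stride: int
) -> Optional[Feature]:
    """Return the first source whose stride matches, or None if there is none."""
    if skip_sources is None:
        return None
    for feature in skip_sources:
        if feature.stride == stride:
            return feature  # earlier entries win when strides repeat
    return None
```

Because the first match wins, list order matters when an encoder exposes several features at the same stride; the earliest one in skip_sources is used.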
- Returns
A tuple of the resulting upsampled tensor at the stride specified by the output_stride class attribute, and a list of intermediate features (IntermediateFeature instances) after each upsampling step.
The intermediate features are useful when creating multi-head architectures with different output strides for the heads.
Note
The number of upsampling steps is determined by the current_stride argument and the output_stride and upsampling_stride class attributes. Specifically, the number of upsampling steps is equal to:
log_b(current_stride) - log_b(output_stride)
where the log base b is equal to upsampling_stride. These parameters can therefore be used to control the number of upsampling steps indirectly. For example, starting with current_stride = 16 and targeting output_stride = 4, upsampling_stride = 2 takes 2 upsampling steps (log_2(16) - log_2(4) = 4 - 2), while upsampling_stride = 4 takes 1 upsampling step (log_4(16) - log_4(4) = 2 - 1).
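The step-count formula in the note can be computed with integer division, which avoids floating-point error from math.log. This is a sketch with a hypothetical helper name; raising ValueError when output_stride cannot be reached exactly by repeated division is an assumption, since the documentation does not say how such configurations are handled.

```python
def num_upsampling_steps(
    current_stride: int, output_stride: int, upsampling_stride: int = 2
) -> int:
    """Steps = log_b(current_stride) - log_b(output_stride), b = upsampling_stride."""
    if upsampling_stride < 2:
        raise ValueError("upsampling_stride must be >= 2")
    steps = 0
    stride = current_stride
    # Each step divides the stride (i.e. multiplies the resolution) by b.
    while stride > output_stride:
        if stride % upsampling_stride != 0:
            raise ValueError("output_stride is not reachable with this upsampling_stride")
        stride //= upsampling_stride
        steps += 1
    if stride != output_stride:
        raise ValueError("output_stride is not reachable with this upsampling_stride")
    return steps
```

This reproduces the example in the note: num_upsampling_steps(16, 4, 2) gives 2 and num_upsampling_steps(16, 4, 4) gives 1, while a combination such as upsampling_stride = 3 from stride 16 is rejected.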