Source code for sleap.nn.architectures.upsampling

"""This module defines common upsampling layer stack configurations.

The generic upsampling stack consists of:
    - transposed convolution or bilinear upsampling with stride > 1
    - skip connections
    - 0 or more 3x3 convolutions for refinement

Configuring these components suffices to define the "decoder" portion of most canonical
"encoder-decoder"-like architectures (e.g., LEAP CNN, UNet, Hourglass, etc.), as well as
simpler patterns like shallow or direct upsampling (e.g., DLC).
"""

import numpy as np
import tensorflow as tf
import attr
from typing import Union, Sequence, List, Tuple, Optional

from sleap.nn.architectures.common import IntermediateFeature
from sleap.nn.config import UpsamplingConfig


@attr.s(auto_attribs=True)
class UpsamplingStack:
    """Standard stack of upsampling layers with refinement and skip connections.

    Attributes:
        output_stride: The desired final stride of the output tensor of the stack.
            Must be positive.
        upsampling_stride: The striding of the upsampling *layer* (not tensor). This
            is typically set to 2, such that the tensor doubles in size with each
            upsampling step, but can be set higher to upsample to the desired
            `output_stride` directly in fewer steps. See the notes in the
            `make_stack` method for examples. Must be greater than 1.
        transposed_conv: If True, use a strided transposed convolution to perform
            learnable upsampling. If False, bilinear upsampling will be used
            instead.
        transposed_conv_filters: Integer that specifies the base number of filters
            in each transposed convolution layer. This will be scaled by the
            `transposed_conv_filters_rate` at every upsampling step. No effect if
            bilinear upsampling is used.
        transposed_conv_filters_rate: Factor to scale the number of filters in the
            transposed convolution layer after each upsampling step. If set to 1,
            the number of filters won't change. No effect if bilinear upsampling is
            used.
        transposed_conv_kernel_size: Size of the kernel for the transposed
            convolution. No effect if bilinear upsampling is used.
        transposed_conv_batchnorm: Specifies whether batch norm should be applied
            after the transposed convolution (and before the ReLU activation). No
            effect if bilinear upsampling is used.
        make_skip_connection: If True, incoming feature tensors form skip
            connection with upsampled features. If False, no skip connection will
            be formed.
        skip_add: If True, incoming feature tensors form skip connection with
            upsampled features via element-wise addition. Height/width are matched
            via stride and a 1x1 linear conv is applied if the channel counts do
            not match up. If False, the skip connection is formed via channel-wise
            concatenation.
        refine_convs: If greater than 0, specifies the number of 3x3 convolutions
            that will be applied after the upsampling step for refinement. These
            layers can serve the purpose of "mixing" the skip connection fused
            features, or to refine the current feature map after upsampling, which
            can help to prevent aliasing and checkerboard effects. If 0, no
            additional convolutions will be applied.
        refine_convs_filters: Similar to `transposed_conv_filters`, specifies the
            number of filters to use for the refinement convolutions in each
            upsampling step. No effect if `refine_convs` is 0.
        refine_convs_filters_rate: Factor to scale the number of filters in the
            refine conv layers after each upsampling step. The same number of
            filters are used for all convs within the same upsampling step. If set
            to 1, the number of filters won't change. No effect if `refine_convs`
            is 0.
        refine_convs_batchnorm: Specifies whether batch norm should be applied
            after each 3x3 convolution and before the ReLU activation. No effect if
            `refine_convs` is 0.
    """

    output_stride: int
    upsampling_stride: int = 2
    transposed_conv: bool = True
    transposed_conv_filters: int = 64
    transposed_conv_filters_rate: float = 1
    transposed_conv_kernel_size: int = 4
    transposed_conv_batchnorm: bool = True
    make_skip_connection: bool = True
    skip_add: bool = False
    refine_convs: int = 2
    refine_convs_filters: int = 64
    refine_convs_filters_rate: float = 1
    refine_convs_batchnorm: bool = True

    @classmethod
    def from_config(
        cls, config: UpsamplingConfig, output_stride: int
    ) -> "UpsamplingStack":
        """Create a model from a set of configuration parameters.

        Args:
            config: An `UpsamplingConfig` instance with the desired parameters.
            output_stride: The desired final stride of the output tensor of the
                stack.

        Returns:
            An instance of this class with the specified configuration.
        """
        return cls(
            output_stride=output_stride,
            upsampling_stride=config.block_stride,
            transposed_conv=config.method == "transposed_conv",
            transposed_conv_filters=config.filters,
            transposed_conv_filters_rate=config.filters_rate,
            transposed_conv_kernel_size=config.transposed_conv_kernel_size,
            transposed_conv_batchnorm=config.batch_norm,
            make_skip_connection=config.skip_connections is not None,
            skip_add=config.skip_connections is not None
            and config.skip_connections == "add",
            refine_convs=config.refine_convs,
            refine_convs_filters=config.filters,
            refine_convs_filters_rate=config.filters_rate,
            refine_convs_batchnorm=config.batch_norm,
        )

    def _num_upsampling_blocks(self, current_stride: int) -> int:
        """Return the number of upsampling steps from `current_stride` toward `output_stride`.

        This is the largest integer `k` such that
        `output_stride * upsampling_stride ** k <= current_stride`, i.e.
        `floor(log(current_stride / output_stride) / log(upsampling_stride))`.
        It is computed with exact arithmetic rather than floating-point logs,
        because `int()` of a log ratio can truncate an exact power (e.g.
        `log(1000) / log(10)` evaluates to `2.9999999999999996`). Returns 0 when
        `current_stride` is already below `output_stride * upsampling_stride`.

        Raises:
            ValueError: If `upsampling_stride` is not greater than 1 or
                `output_stride` is not positive (the count would be undefined or
                the search would never terminate).
        """
        if self.upsampling_stride <= 1:
            raise ValueError(
                f"upsampling_stride must be greater than 1, got {self.upsampling_stride}."
            )
        if self.output_stride <= 0:
            raise ValueError(
                f"output_stride must be positive, got {self.output_stride}."
            )
        num_blocks = 0
        while (
            self.output_stride * self.upsampling_stride ** (num_blocks + 1)
            <= current_stride
        ):
            num_blocks += 1
        return num_blocks

    def _upsample(self, x: tf.Tensor, block: int, block_prefix: str) -> tf.Tensor:
        """Upsample `x` by `upsampling_stride` with a transposed conv or bilinear interp."""
        if not self.transposed_conv:
            # Upsample via (non-learnable) bilinear interpolation.
            return tf.keras.layers.UpSampling2D(
                size=self.upsampling_stride,
                interpolation="bilinear",
                name=block_prefix + "_interp",
            )(x)

        # Upsample via strided transposed convolution; the filter count is scaled
        # by `transposed_conv_filters_rate` once per completed block.
        block_trans_conv_filters = int(
            self.transposed_conv_filters * self.transposed_conv_filters_rate ** block
        )
        x = tf.keras.layers.Conv2DTranspose(
            filters=block_trans_conv_filters,
            kernel_size=self.transposed_conv_kernel_size,
            strides=self.upsampling_stride,
            padding="same",
            name=block_prefix + "_trans_conv",
        )(x)
        if self.transposed_conv_batchnorm:
            x = tf.keras.layers.BatchNormalization(name=block_prefix + "_bn")(x)
        x = tf.keras.layers.Activation("relu", name=block_prefix + "_relu")(x)
        return x

    def _fuse_skip(
        self, x: tf.Tensor, source_x: tf.Tensor, block_prefix: str
    ) -> tf.Tensor:
        """Fuse skip tensor `source_x` into `x` by addition or channel concatenation."""
        if not self.skip_add:
            # Simple channel-wise concatenation (skip source first).
            return tf.keras.layers.Concatenate(name=block_prefix + "_skip_concat")(
                [source_x, x]
            )

        if source_x.shape[-1] != x.shape[-1]:
            # Element-wise addition requires equal channel counts, so project the
            # skip source with a 1x1 linear conv when they differ.
            source_x = tf.keras.layers.Conv2D(
                filters=x.shape[-1],
                kernel_size=1,
                strides=1,
                padding="same",
                name=block_prefix + "_skip_conv1x1",
            )(source_x)
        # Fuse via element-wise addition.
        return tf.keras.layers.Add(name=block_prefix + "_skip_add")([source_x, x])

    def _refine(self, x: tf.Tensor, block: int, block_prefix: str) -> tf.Tensor:
        """Apply `refine_convs` 3x3 conv (+ optional batch norm) + ReLU layers to `x`."""
        # All refine convs within one block share the same (rate-scaled) width.
        block_refine_filters = int(
            self.refine_convs_filters * self.refine_convs_filters_rate ** block
        )
        for i in range(self.refine_convs):
            x = tf.keras.layers.Conv2D(
                filters=block_refine_filters,
                kernel_size=3,
                strides=1,
                padding="same",
                name=block_prefix + f"_refine{i}_conv",
            )(x)
            if self.refine_convs_batchnorm:
                x = tf.keras.layers.BatchNormalization(
                    name=block_prefix + f"_refine{i}_bn"
                )(x)
            x = tf.keras.layers.Activation(
                "relu", name=block_prefix + f"_refine{i}_relu"
            )(x)
        return x

    def make_stack(
        self,
        x: tf.Tensor,
        current_stride: int,
        skip_sources: Optional[Sequence[IntermediateFeature]] = None,
    ) -> Tuple[tf.Tensor, List[IntermediateFeature]]:
        """Create the stack of upsampling layers.

        Args:
            x: Feature tensor at low resolution, typically from encoder/backbone.
            current_stride: Stride of the feature tensor relative to the network
                input.
            skip_sources: If a list of `IntermediateFeature`s are provided, they will
                be searched to find source tensors with matching stride after each
                upsampling step. The first element of this list with a matching
                stride will be selected as the source at each level. Skip connection
                will be a concatenation or addition, depending on the `skip_add`
                class attribute.

        Returns:
            A tuple of the resulting upsampled tensor at the stride specified by the
            `output_stride` class attribute, and a list of intermediate tensors
            after each upsampling step (the first element is the input `x` at
            `current_stride`).

            The intermediate features are useful when creating multi-head
            architectures with different output strides for the heads.

        Raises:
            ValueError: If `upsampling_stride` is not greater than 1 or
                `output_stride` is not positive.

        Note:
            The computed number of upsampling steps will be determined by the
            `current_stride` argument, and `output_stride` and `upsampling_stride`
            class attributes.

            Specifically, the number of upsampling steps is equal to the floor of:
                `log(current_stride) - log(output_stride)`
            where the log base is equal to the `upsampling_stride`.

            These can be used to control the number of upsampling steps indirectly,
            for example:
                Start with `current_stride = 16` and want to get to
                `output_stride = 4`; with `upsampling_stride = 2` this will take 2
                upsampling steps, and with `upsampling_stride = 4` this will take 1
                upsampling step.
        """
        num_blocks = self._num_upsampling_blocks(current_stride)

        # Initialize intermediate features with the output of the encoder stack.
        intermediate_feats = [IntermediateFeature(tensor=x, stride=current_stride)]

        for block in range(num_blocks):
            new_stride = current_stride // self.upsampling_stride
            block_prefix = f"upsample_s{current_stride}_to_s{new_stride}"

            x = self._upsample(x, block, block_prefix)

            # Tensor is now upsampled to the updated stride.
            current_stride = new_stride

            # Form a skip connection with the first source at this stride, if any.
            if skip_sources is not None and self.make_skip_connection:
                skip_source = next(
                    (src for src in skip_sources if src.stride == current_stride),
                    None,
                )
                if skip_source is not None:
                    x = self._fuse_skip(x, skip_source.tensor, block_prefix)

            # Add further convolutions to refine after upsampling and/or skip.
            x = self._refine(x, block, block_prefix)

            intermediate_feats.append(
                IntermediateFeature(tensor=x, stride=current_stride)
            )

        return x, intermediate_feats