sleap.nn.architectures.unet
This module provides a generalized implementation of UNet.
See the UNet class docstring for more information.
- class sleap.nn.architectures.unet.PoolingBlock(pool: bool = True, pooling_stride: int = 2)
Pooling-only encoder block.
Used to compensate for UNet taking its skip connection source before the pooling step, which means the encoder blocks must end with a convolution rather than a pooling layer. This block is appended to the end of the encoder stack so that the number of down blocks equals the number of pooling steps.
- pool
If True, applies max pooling at the end of the block.
- Type
bool
- pooling_stride
Stride of the max pooling operation. If 1, the output of this block will be at the same stride (== 1/scale) as the input.
- Type
int
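The effect of pool and pooling_stride on resolution can be illustrated with plain integer bookkeeping. This is a minimal sketch, assuming stride is tracked as an integer downsampling factor relative to the input image; pooling_block_output_stride is a hypothetical helper, not part of SLEAP, and it models only the arithmetic, not the Keras layer.

```python
def pooling_block_output_stride(input_stride: int, pool: bool = True, pooling_stride: int = 2) -> int:
    """Return the stride (1/scale) of a pooling-only block's output.

    Hypothetical helper: mirrors the stride bookkeeping described for
    PoolingBlock, not its actual implementation.
    """
    if not pool:
        # No pooling: the resolution is unchanged.
        return input_stride
    # Max pooling with stride s divides height and width by s, so the
    # stride relative to the input grows by a factor of s.
    return input_stride * pooling_stride


# A stride-8 feature map pooled with the default settings ends up at stride 16.
print(pooling_block_output_stride(8))                    # 16
# pooling_stride=1 keeps the output at the same stride as the input.
print(pooling_block_output_stride(8, pooling_stride=1))  # 8
print(pooling_block_output_stride(8, pool=False))        # 8
```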
- class sleap.nn.architectures.unet.UNet(stacks: int = 1, filters: int = 64, filters_rate: float = 2, kernel_size: int = 3, stem_kernel_size: int = 3, convs_per_block: int = 2, stem_blocks: int = 0, down_blocks: int = 4, middle_block: bool = True, up_blocks: int = 4, up_interpolate: bool = False, block_contraction: bool = False)
UNet encoder-decoder architecture for fully convolutional networks.
This is the canonical architecture described in Ronneberger et al., 2015.
The default configuration with 4 down/up blocks and 64 base filters has ~34.5M parameters.
- filters
Base number of filters in the first encoder block. More filters will increase the representational capacity of the network at the cost of memory and runtime.
- Type
int
- filters_rate
Factor by which the number of filters increases in each successive block.
- Type
float
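To make filters and filters_rate concrete, the sketch below assumes a simple geometric schedule in which encoder block i has filters * filters_rate**i filters, truncated to an integer. The encoder_filters helper is hypothetical; how SLEAP offsets the exponent (for example when stem blocks are present) or rounds non-integer counts is not stated here and is an assumption.

```python
from typing import List


def encoder_filters(filters: int = 64, filters_rate: float = 2, down_blocks: int = 4) -> List[int]:
    """Filter count per encoder block under an assumed geometric growth schedule."""
    # Block i multiplies the base count by filters_rate raised to the power i.
    return [int(filters * filters_rate ** i) for i in range(down_blocks)]


print(encoder_filters())            # [64, 128, 256, 512]
print(encoder_filters(32, 1.5, 4))  # [32, 48, 72, 108]
```

With filters_rate=2, each downsampling step doubles the channel count while halving height and width, which is the usual UNet trade-off between spatial resolution and feature depth.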
- kernel_size
Size of convolutional kernels (== height == width).
- Type
int
- stem_kernel_size
Size of convolutional kernels in stem blocks.
- Type
int
- stem_blocks
If >0, will create additional “down” blocks for initial downsampling. These will be configured identically to the down blocks below.
- Type
int
- down_blocks
Number of blocks with pooling in the encoder. More down blocks will increase the effective maximum receptive field.
- Type
int
- convs_per_block
Number of convolutions in each block. More convolutions per block will increase the representational capacity of the network at the cost of memory and runtime.
- Type
int
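The receptive-field effect of down_blocks (and, to a lesser degree, convs_per_block) can be estimated with the standard receptive-field recursion: each layer adds (kernel - 1) times the current jump, and each strided layer multiplies the jump by its stride. This is a minimal sketch, assuming every encoder block is convs_per_block stride-1 convolutions followed by max pooling with a 2 × 2 window and stride 2, and ignoring the stem and middle blocks; encoder_receptive_field is a hypothetical helper.

```python
def encoder_receptive_field(down_blocks: int = 4, convs_per_block: int = 2,
                            kernel_size: int = 3, pool_size: int = 2) -> int:
    """Theoretical receptive field (in input pixels) at the end of the encoder.

    Assumes each block is `convs_per_block` stride-1 convolutions followed by
    a max pooling layer whose window and stride both equal `pool_size`.
    """
    rf = 1    # receptive field of a single input pixel
    jump = 1  # distance, in input pixels, between adjacent feature map positions
    for _ in range(down_blocks):
        for _ in range(convs_per_block):
            rf += (kernel_size - 1) * jump  # stride-1 conv leaves jump unchanged
        rf += (pool_size - 1) * jump        # pooling window widens the field
        jump *= pool_size                   # pooling stride spreads outputs apart
    return rf


print(encoder_receptive_field())                   # 76
print(encoder_receptive_field(down_blocks=5))      # 156
print(encoder_receptive_field(convs_per_block=3))  # 106
```

Adding a down block roughly doubles the receptive field because every later layer operates on a coarser grid, whereas an extra convolution per block adds only a linear increment.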
- up_blocks
Number of blocks with upsampling in the decoder. If this is equal to down_blocks, the output of this network will be at the same stride (scale) as the input.
- Type
int
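The relationship between down_blocks and up_blocks can be checked by counting net resampling steps. The sketch below assumes that every stem block and down block halves the spatial resolution and every up block doubles it; output_stride is a hypothetical helper, not part of SLEAP.

```python
def output_stride(down_blocks: int = 4, up_blocks: int = 4, stem_blocks: int = 0,
                  stride_per_block: int = 2) -> int:
    """Stride (1/scale) of the decoder output relative to the input image.

    Assumes each stem/down block downsamples by `stride_per_block` and each
    up block upsamples by the same factor.
    """
    net_steps = stem_blocks + down_blocks - up_blocks
    if net_steps < 0:
        raise ValueError("up_blocks cannot exceed the number of downsampling steps")
    return stride_per_block ** net_steps


print(output_stride(4, 4))  # 1: same stride as the input
print(output_stride(4, 2))  # 4: output at 1/4 of the input resolution
```

Producing outputs at a coarser stride (fewer up blocks than down blocks) trades spatial precision for lower memory use in the decoder.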
- middle_block
If True, add an additional block at the end of the encoder.
- Type
bool
- up_interpolate
If True, use bilinear interpolation instead of transposed convolutions for upsampling. Interpolation is faster, but transposed convolutions may be able to learn richer or more complex upsampling to recover details from higher scales. If using transposed convolutions, the number of filters is determined by filters and filters_rate to progressively decrease the number of filters at each step.
- Type
bool
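Part of the speed and flexibility trade-off comes from the fact that bilinear interpolation has no trainable weights, while a transposed convolution learns a full kernel. The count below assumes a standard 2D transposed convolution with a kernel_size × kernel_size kernel and an optional bias; upsampling_params and the 512/256 channel counts are illustrative, not taken from SLEAP.

```python
def upsampling_params(in_channels: int, out_channels: int, kernel_size: int = 3,
                      transposed: bool = True, use_bias: bool = True) -> int:
    """Trainable parameters added by one upsampling layer (illustrative)."""
    if not transposed:
        # Bilinear interpolation is a fixed resampling: nothing to learn.
        return 0
    # One kernel_size x kernel_size filter per (input, output) channel pair.
    weights = kernel_size * kernel_size * in_channels * out_channels
    return weights + (out_channels if use_bias else 0)


print(upsampling_params(512, 256))                    # 1179904
print(upsampling_params(512, 256, transposed=False))  # 0
```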
- block_contraction
If True, reduces the number of filters at the end of middle and decoder blocks. This has the effect of introducing an additional bottleneck before each upsampling step. The original implementation does not do this, but the CARE implementation does.
- Type
bool
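The bottleneck introduced by block_contraction can be pictured as a per-block filter schedule. This sketch assumes that decoder block i mirrors encoder block down_blocks - 1 - i in width and that contraction lowers only the block's last convolution by one filters_rate step; decoder_filters is a hypothetical helper, and the exact schedule SLEAP uses is an assumption here.

```python
from typing import List, Tuple


def decoder_filters(filters: int = 64, filters_rate: float = 2, down_blocks: int = 4,
                    up_blocks: int = 4, block_contraction: bool = False) -> List[Tuple[int, int]]:
    """(working filters, output filters) for each decoder block, deepest first.

    Assumes decoder block i mirrors encoder block down_blocks - 1 - i and that
    contraction reduces the final conv of the block by one filters_rate step.
    """
    schedule = []
    for i in range(up_blocks):
        exponent = down_blocks - 1 - i
        working = int(filters * filters_rate ** exponent)
        if block_contraction:
            # Bottleneck: the block ends narrower than it worked internally.
            out = int(filters * filters_rate ** (exponent - 1))
        else:
            out = working
        schedule.append((working, out))
    return schedule


print(decoder_filters())
# [(512, 512), (256, 256), (128, 128), (64, 64)]
print(decoder_filters(block_contraction=True))
# [(512, 256), (256, 128), (128, 64), (64, 32)]
```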
Note
This implementation differs from some others, particularly in the choice of skip connection source tensors in the encoder. In the original, each skip connection is formed from the output of the convolutions in an encoder block, not from the pooling step. As a result, skip connections start at the first stride level and continue at each subsequent one.
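The difference between the two skip source choices shows up in the strides at which skip tensors are taken. This is a minimal sketch, assuming no stem blocks and stride-2 pooling; skip_source_strides and its after_pooling flag are hypothetical, used only to contrast the two conventions.

```python
from typing import List


def skip_source_strides(down_blocks: int = 4, pooling_stride: int = 2,
                        after_pooling: bool = False) -> List[int]:
    """Strides of the encoder tensors used as skip connection sources.

    after_pooling=False: skip taken from each block's conv output before pooling
    (the original UNet convention), so the first skip is at the input stride.
    after_pooling=True: skip taken after each pooling step instead.
    """
    offset = 1 if after_pooling else 0
    return [pooling_stride ** (i + offset) for i in range(down_blocks)]


print(skip_source_strides())                    # [1, 2, 4, 8]
print(skip_source_strides(after_pooling=True))  # [2, 4, 8, 16]
```

Taking skips before pooling gives the decoder access to full-resolution features, which is why the encoder needs a trailing pooling-only block to keep the pooling count aligned with the number of down blocks.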
- property decoder_stack: List[sleap.nn.architectures.encoder_decoder.SimpleUpsamplingBlock]
Define the decoder stack.
- property encoder_stack: List[sleap.nn.architectures.encoder_decoder.SimpleConvBlock]
Define the encoder stack.
- classmethod from_config(config: sleap.nn.config.model.UNetConfig) → sleap.nn.architectures.unet.UNet
Create a model from a set of configuration parameters.
- Parameters
config – A UNetConfig instance with the desired parameters.
- Returns
An instance of this class with the specified configuration.
- property stem_stack: Optional[List[sleap.nn.architectures.encoder_decoder.SimpleConvBlock]]
Define the downsampling stem.