"""Transformers for image resizing and padding."""
import tensorflow as tf
import attr
from typing import List, Text, Optional, Tuple
from sleap.nn.config import PreprocessingConfig
from sleap.nn.data.utils import expand_to_rank
[docs]def find_padding_for_stride(
image_height: int, image_width: int, max_stride: int
) -> Tuple[int, int]:
"""Compute padding required to ensure image is divisible by a stride.
This function is useful for determining how to pad images such that they will not
have issues with divisibility after repeated pooling steps.
Args:
image_height: Scalar integer specifying the image height (rows).
image_width: Scalar integer specifying the image height (columns).
max_stride: Scalar integer specifying the maximum stride that the image must be
divisible by.
Returns:
A tuple of (pad_bottom, pad_right), integers with the number of pixels that the
image would need to be padded by to meet the divisibility requirement.
"""
pad_bottom = (max_stride - (image_height % max_stride)) % max_stride
pad_right = (max_stride - (image_width % max_stride)) % max_stride
return pad_bottom, pad_right
[docs]def pad_to_stride(image: tf.Tensor, max_stride: int) -> tf.Tensor:
"""Pad an image to meet a max stride constraint.
This is useful for ensuring there is no size mismatch between an image and the
output tensors after multiple downsampling and upsampling steps.
Args:
image: Single image tensor of shape (height, width, channels).
max_stride: Scalar integer specifying the maximum stride that the image must be
divisible by. This is the ratio between the length of the image and the
length of the smallest tensor it is converted to. This is typically
`2 ** n_down_blocks`, where `n_down_blocks` is the number of 2-strided
reduction layers in the model.
Returns:
The input image with 0-padding applied to the bottom and/or right such that the
new shape's height and width are both divisible by `max_stride`.
"""
pad_bottom, pad_right = find_padding_for_stride(
image_height=tf.shape(image)[0],
image_width=tf.shape(image)[1],
max_stride=max_stride,
)
if pad_bottom > 0 or pad_right > 0:
paddings = tf.cast([[0, pad_bottom], [0, pad_right], [0, 0]], tf.int32)
image = tf.pad(image, paddings, mode="CONSTANT", constant_values=0)
return image
[docs]def resize_image(image: tf.Tensor, scale: tf.Tensor) -> tf.Tensor:
"""Rescale an image by a scale factor.
This function is primarily a convenience wrapper for `tf.image.resize` that
calculates the new shape from the scale factor.
Args:
image: Single image tensor of shape (height, width, channels).
scale: Factor to resize the image dimensions by, specified as either a float
scalar or as a 2-tuple of [scale_x, scale_y]. If a scalar is provided, both
dimensions are resized by the same factor.
Returns:
The resized image tensor of the same dtype but scaled height and width.
See also: tf.image.resize
"""
height = tf.shape(image)[0]
width = tf.shape(image)[1]
new_size = tf.reverse(
tf.cast(
tf.cast([width, height], tf.float32) * tf.cast(scale, tf.float32), tf.int32
),
[0],
)
return tf.cast(
tf.image.resize(
image,
size=new_size,
method="bilinear",
preserve_aspect_ratio=False,
antialias=False,
),
image.dtype,
)
[docs]@attr.s(auto_attribs=True)
class Resizer:
"""Data transformer to resize or pad images.
This is useful as a transformation to data streams that require resizing or padding
in order to be downsampled or meet divisibility criteria.
Attributes:
image_key: String name of the key containing the images to resize.
scale_key: String name of the key containing the scale of the images.
points_key: String name of the key containing points to adjust for the resizing
operation.
scale: Scalar float specifying scaling factor to resize images by.
pad_to_stride: Maximum stride in a model that the images must be divisible by.
If > 1, this will pad the bottom and right of the images to ensure they meet
this divisibility criteria. Padding is applied after the scaling specified
in the `scale` attribute.
keep_full_image: If True, keeps the (original size) full image in the examples.
This is useful for multi-scale inference.
full_image_key: String name of the key containing the full images.
"""
image_key: Text = "image"
scale_key: Text = "scale"
points_key: Optional[Text] = "instances"
scale: float = 1.0
pad_to_stride: int = 1
keep_full_image: bool = False
full_image_key: Text = "full_image"
[docs] @classmethod
def from_config(
cls,
config: PreprocessingConfig,
image_key: Text = "image",
scale_key: Text = "scale",
pad_to_stride: Optional[int] = None,
keep_full_image: bool = False,
full_image_key: Text = "full_image",
points_key: Optional[Text] = "instances"
) -> "Resizer":
"""Build an instance of this class from its configuration options.
Args:
config: An `PreprocessingConfig` instance with the desired parameters. If
`config.pad_to_stride` is not an explicit integer, the `pad_to_stride`
parameter must be provided.
image_key: String name of the key containing the images to resize.
scale_key: String name of the key containing the scale of the images.
pad_to_stride: An integer specifying the `pad_to_stride` if
`config.pad_to_stride` is not an explicit integer (e.g., set to None).
keep_full_image: If True, keeps the (original size) full image in the
examples. This is useful for multi-scale inference.
full_image_key: String name of the key containing the full images.
points_key: String name of the key containing points to adjust for the
resizing operation.
Returns:
An instance of this class.
Raises:
ValueError: If `config.pad_to_stride` is not set to an integer and the
`pad_to_stride` argument is not provided.
"""
if isinstance(config.pad_to_stride, int):
pad_to_stride = config.pad_to_stride
if not isinstance(pad_to_stride, int):
raise ValueError(
"Pad to stride must be specified in the arguments if not explicitly "
f"set to an integer (config.pad_to_stride = {config.pad_to_stride})."
)
return cls(
image_key=image_key,
points_key=points_key,
scale_key=scale_key,
scale=config.input_scaling,
pad_to_stride=pad_to_stride,
keep_full_image=keep_full_image,
full_image_key=full_image_key,
)
@property
def input_keys(self) -> List[Text]:
"""Return the keys that incoming elements are expected to have."""
input_keys = [self.image_key, self.scale_key]
if self.points_key is not None:
input_keys.append(self.points_key)
return input_keys
@property
def output_keys(self) -> List[Text]:
"""Return the keys that outgoing elements will have."""
output_keys = self.input_keys
if self.keep_full_image:
output_keys.append(self.full_image_key)
return output_keys
[docs]@attr.s(auto_attribs=True)
class PointsRescaler:
"""Transformer to apply or invert scaling operations on points."""
points_key: Text = "predicted_instances"
scale_key: Text = "scale"
invert: bool = True
@property
def input_keys(self) -> List[Text]:
"""Return the keys that incoming elements are expected to have."""
return [self.points_key, self.scale_key]
@property
def output_keys(self) -> List[Text]:
"""Return the keys that outgoing elements will have."""
return self.input_keys