Source code for sleap.nn.config.data

import attr
from typing import Optional, Text, List
import sleap


@attr.s(auto_attribs=True)
class LabelsConfig:
    """Labels configuration.

    Attributes:
        training_labels: A filepath to a saved labels file containing user labeled
            frames to use for generating the training set.
        validation_labels: A filepath to a saved labels file containing user
            labeled frames to use for generating validation data. These will not
            be trained on directly, but will be used to tune hyperparameters such
            as learning rate or early stopping. If not specified, the validation
            set will be sampled from the training labels.
        validation_fraction: Float between 0 and 1 specifying the fraction of the
            training set to sample for generating the validation set. The
            remaining labeled frames will be left in the training set. If the
            `validation_labels` are already specified, this has no effect.
        test_labels: A filepath to a saved labels file containing user labeled
            frames to use for generating the test set. This is typically a held
            out set of examples that are never used for training or hyperparameter
            tuning (like the validation set). This is optional, but useful for
            benchmarking as metrics can be computed from these data during model
            optimization. This is also useful to explicitly keep track of the test
            set that should be used when multiple splits are created for training.
        search_path_hints: List of paths to use for searching for missing data.
            This is useful when labels and data are moved across computers,
            network storage, or operating systems that may have different absolute
            paths than those stored in the labels. This has no effect if the
            labels were exported as a package with the user labeled data.
        skeletons: List of `sleap.Skeleton` instances that can be used by the
            model. If not specified, these will be pulled out of the labels during
            training, but must be specified for inference in order to generate
            predicted instances.
    """

    training_labels: Optional[Text] = None
    validation_labels: Optional[Text] = None
    validation_fraction: float = 0.1
    test_labels: Optional[Text] = None
    search_path_hints: List[Text] = attr.ib(factory=list)
    skeletons: List[sleap.Skeleton] = attr.ib(factory=list)
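
# Illustrative usage sketch (not part of the original module): a typical
# LabelsConfig for training with a held-out test set. The .slp filenames and
# search path below are hypothetical placeholders.
example_labels_config = LabelsConfig(
    training_labels="labels.train.slp",  # hypothetical path
    validation_fraction=0.2,  # sample 20% of training labels for validation
    test_labels="labels.test.slp",  # hypothetical path
    search_path_hints=["/data/videos"],  # hypothetical prefix for missing media
)
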
@attr.s(auto_attribs=True)
class PreprocessingConfig:
    """Preprocessing configuration.

    Attributes:
        ensure_rgb: If True, converts the image to RGB if not already.
        ensure_grayscale: If True, converts the image to grayscale if not already.
        imagenet_mode: Specifies an ImageNet-based normalization mode commonly
            used in `tf.keras.applications`-based pretrained models. This has no
            effect if None or not specified. Valid values are:
            "tf": Values will be scaled to [-1, 1], expanded to RGB if grayscale.
                This preprocessing mode is required when using pretrained
                ResNetV2, MobileNetV1, MobileNetV2 and NASNet models.
            "caffe": Values will be scaled to [0, 255], expanded to RGB if
                grayscale, RGB channels flipped to BGR, and subtracted by a fixed
                mean. This preprocessing mode is required when using pretrained
                ResNetV1 models.
            "torch": Values will be scaled to [0, 1], expanded to RGB if
                grayscale, subtracted by a fixed mean, and scaled by fixed
                standard deviation. This preprocessing mode is required when using
                pretrained DenseNet models.
        input_scaling: Scalar float specifying scaling factor to resize raw images
            by. Downscaling can considerably improve performance and reduce memory
            requirements at the cost of accuracy. Generally, it should only be
            used when the raw images are at a much higher resolution than the
            smallest features in the data.
        pad_to_stride: Number of pixels that the image size must be divisible by.
            If > 1, this will pad the bottom and right of the images to ensure
            they meet this divisibility criterion. Padding is applied after the
            scaling specified in the `input_scaling` attribute. If set to None,
            this will be automatically detected from the model architecture. This
            must be divisible by the model's max stride (typically 32). This
            padding will be ignored when instance cropping inputs since the crop
            size should already be divisible by the model's max stride.
    """

    ensure_rgb: bool = False
    ensure_grayscale: bool = False
    imagenet_mode: Optional[Text] = attr.ib(
        default=None,
        validator=attr.validators.optional(
            attr.validators.in_(["tf", "caffe", "torch"])
        ),
    )
    input_scaling: float = 1.0
    pad_to_stride: Optional[int] = None
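
# Illustrative usage sketch (not part of the original module): preprocessing
# for a pretrained ResNetV1-style backbone with images downscaled to half
# resolution. Passing anything other than "tf", "caffe", or "torch" (or None)
# for imagenet_mode raises a ValueError from the attrs validator above.
example_preprocessing_config = PreprocessingConfig(
    ensure_rgb=True,  # "caffe" mode expects 3-channel input
    imagenet_mode="caffe",  # scale to [0, 255], flip RGB -> BGR, subtract mean
    input_scaling=0.5,  # resize raw images to half resolution
    pad_to_stride=32,  # pad height/width to a multiple of the max stride
)
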
@attr.s(auto_attribs=True)
class InstanceCroppingConfig:
    """Instance cropping configuration.

    These are only used in topdown or centroid models.

    Attributes:
        center_on_part: String name of the part to center the instance on. If
            None or not specified, instances will be centered on the centroid of
            their bounding box. This value will be used for both topdown and
            centroid models. It must match the name of a node on the skeleton.
        crop_size: Integer size of bounding box height and width to crop out of
            the full image. This should be greater than the largest size of the
            instances in pixels. The crop is applied after any input scaling, so
            be sure to adjust this to changes in the input image scale. If set to
            None, this will be automatically detected from the data during
            training or from the model input layer during inference. This must be
            divisible by the model's max stride (typically 32).
        crop_size_detection_padding: Integer specifying how much extra padding
            should be applied around the instance bounding boxes when
            automatically detecting the appropriate crop size from the data. No
            effect if the `crop_size` is already specified.
    """

    center_on_part: Optional[Text] = None
    crop_size: Optional[int] = None
    crop_size_detection_padding: int = 16
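
# Illustrative usage sketch (not part of the original module): instance
# cropping for a topdown model, centered on a hypothetical "thorax" node.
# The crop size is chosen to be divisible by a typical max stride of 32.
example_cropping_config = InstanceCroppingConfig(
    center_on_part="thorax",  # hypothetical skeleton node name
    crop_size=160,  # larger than the largest instance; divisible by 32
)
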
@attr.s(auto_attribs=True)
class DataConfig:
    """Data configuration.

    Attributes:
        labels: Configuration options related to user labels for training or
            testing.
        preprocessing: Configuration options related to data preprocessing.
        instance_cropping: Configuration options related to instance cropping for
            centroid and topdown models.
    """

    labels: LabelsConfig = attr.ib(factory=LabelsConfig)
    preprocessing: PreprocessingConfig = attr.ib(factory=PreprocessingConfig)
    instance_cropping: InstanceCroppingConfig = attr.ib(factory=InstanceCroppingConfig)
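
# Illustrative usage sketch (not part of the original module): composing the
# nested configs into a single DataConfig, as a training pipeline would consume
# it. Calling DataConfig() with no arguments yields all-default sub-configs via
# the attr.ib factories above. The .slp path is a hypothetical placeholder.
example_data_config = DataConfig(
    labels=LabelsConfig(training_labels="labels.train.slp"),
    preprocessing=PreprocessingConfig(input_scaling=0.5),
    instance_cropping=InstanceCroppingConfig(crop_size=160),
)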