Source code for sleap.nn.training

"""Training functionality and high level APIs."""

import os
import re
from datetime import datetime
from time import time
import logging

import tensorflow as tf
import numpy as np

import attr
from typing import Optional, Callable, List, Union, Text, TypeVar
from abc import ABC, abstractmethod

import cattr
import json
import copy

import sleap
from sleap.util import get_package_file

# Config
from sleap.nn.config import (
    TrainingJobConfig,
    SingleInstanceConfmapsHeadConfig,
    CentroidsHeadConfig,
    CenteredInstanceConfmapsHeadConfig,
    MultiInstanceConfig,
)

# Model
from sleap.nn.model import Model

# Data
from sleap.nn.config import LabelsConfig
from sleap.nn.data.pipelines import LabelsReader
from sleap.nn.data.pipelines import (
    Pipeline,
    SingleInstanceConfmapsPipeline,
    CentroidConfmapsPipeline,
    TopdownConfmapsPipeline,
    BottomUpPipeline,
    KeyMapper,
)
from sleap.nn.data.training import split_labels

# Optimization
from sleap.nn.config import OptimizationConfig
from sleap.nn.losses import OHKMLoss, PartLoss
from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping

# Outputs
from sleap.nn.config import (
    OutputsConfig,
    ZMQConfig,
    TensorBoardConfig,
    CheckpointingConfig,
)
from sleap.nn.callbacks import (
    TrainingControllerZMQ,
    ProgressReporterZMQ,
    ModelCheckpointOnEvent,
)
from tensorflow.keras.callbacks import TensorBoard, ModelCheckpoint, CSVLogger

# Visualization
import matplotlib
import matplotlib.pyplot as plt
from sleap.nn.callbacks import TensorBoardMatplotlibWriter, MatplotlibSaver
from sleap.nn.viz import plot_img, plot_confmaps, plot_peaks, plot_pafs


logger = logging.getLogger(__name__)


@attr.s(auto_attribs=True)
class DataReaders:
    """Container class for SLEAP labels that serve as training data sources.

    Attributes:
        training_labels_reader: LabelsReader pipeline provider for training data
            from a sleap.Labels instance.
        validation_labels_reader: LabelsReader pipeline provider for validation
            data from a sleap.Labels instance.
        test_labels_reader: LabelsReader pipeline provider for test set data from
            a sleap.Labels instance. This is not necessary for training.
    """

    training_labels_reader: LabelsReader
    validation_labels_reader: LabelsReader
    test_labels_reader: Optional[LabelsReader] = None

    @classmethod
    def from_config(
        cls,
        labels_config: LabelsConfig,
        training: Union[Text, sleap.Labels],
        validation: Union[Text, sleap.Labels, float],
        test: Optional[Union[Text, sleap.Labels]] = None,
        video_search_paths: Optional[List[Text]] = None,
        update_config: bool = False,
    ) -> "DataReaders":
        """Create data readers from a (possibly incomplete) configuration."""
        # Use config values if not provided in the arguments.
        if training is None:
            training = labels_config.training_labels
        if validation is None:
            if labels_config.validation_labels is not None:
                validation = labels_config.validation_labels
            else:
                validation = labels_config.validation_fraction
        if test is None:
            test = labels_config.test_labels

        # Update the config fields with arguments (if not a full sleap.Labels instance).
        if update_config:
            if isinstance(training, Text):
                labels_config.training_labels = training
            if isinstance(validation, Text):
                labels_config.validation_labels = validation
            elif isinstance(validation, float):
                labels_config.validation_fraction = validation
            if isinstance(test, Text):
                labels_config.test_labels = test

        # Build class.
        # TODO: use labels_config.search_path_hints for loading
        return cls.from_labels(
            training=training,
            validation=validation,
            test=test,
            video_search_paths=video_search_paths,
        )

    @classmethod
    def from_labels(
        cls,
        training: Union[Text, sleap.Labels],
        validation: Union[Text, sleap.Labels, float],
        test: Optional[Union[Text, sleap.Labels]] = None,
        video_search_paths: Optional[List[Text]] = None,
    ) -> "DataReaders":
        """Create data readers from sleap.Labels datasets as data providers."""
        if isinstance(training, str):
            logger.info(f"Video search paths: {video_search_paths}")
            training = sleap.Labels.load_file(
                training, video_search=video_search_paths
            )
            logger.debug(f"Loaded training videos: {training.videos}")
        if isinstance(validation, str):
            validation = sleap.Labels.load_file(
                validation, video_search=video_search_paths
            )
        elif isinstance(validation, float):
            training, validation = split_labels(training, [-1, validation])
        if isinstance(test, str):
            test = sleap.Labels.load_file(test, video_search=video_search_paths)

        test_reader = None
        if test is not None:
            test_reader = LabelsReader.from_user_instances(test)

        return cls(
            training_labels_reader=LabelsReader.from_user_instances(training),
            validation_labels_reader=LabelsReader.from_user_instances(validation),
            test_labels_reader=test_reader,
        )

    @property
    def training_labels(self) -> sleap.Labels:
        """Return the sleap.Labels underlying the training data reader."""
        return self.training_labels_reader.labels

    @property
    def validation_labels(self) -> sleap.Labels:
        """Return the sleap.Labels underlying the validation data reader."""
        return self.validation_labels_reader.labels

    @property
    def test_labels(self) -> sleap.Labels:
        """Return the sleap.Labels underlying the test data reader."""
        if self.test_labels_reader is None:
            raise ValueError("No test labels provided to data reader.")
        return self.test_labels_reader.labels


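# Usage sketch (illustrative; the path is hypothetical): build data readers
# from a labels file, holding out 10% of the labeled frames for validation.
#
#     data_readers = DataReaders.from_labels(
#         training="path/to/labels.train.slp",
#         validation=0.1,  # float -> fraction split off from the training set
#     )
#     n_train = len(data_readers.training_labels)
#     n_val = len(data_readers.validation_labels)

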
def setup_optimizer(config: OptimizationConfig) -> tf.keras.optimizers.Optimizer:
    """Set up model optimizer from config."""
    if config.optimizer.lower() == "adam":
        optimizer = tf.keras.optimizers.Adam(
            learning_rate=config.initial_learning_rate, amsgrad=True
        )
    elif config.optimizer.lower() == "rmsprop":
        optimizer = tf.keras.optimizers.RMSprop(
            learning_rate=config.initial_learning_rate
        )
    elif config.optimizer.lower() == "sgd":
        optimizer = tf.keras.optimizers.SGD(learning_rate=config.initial_learning_rate)
    else:
        # TODO: explicit lookup
        optimizer = config.optimizer
    return optimizer


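# Minimal sketch (assumes an OptimizationConfig constructed with these fields;
# the values are illustrative): select Adam with a custom initial learning rate.
#
#     opt_config = OptimizationConfig(optimizer="adam", initial_learning_rate=1e-4)
#     optimizer = setup_optimizer(opt_config)  # -> tf.keras.optimizers.Adam

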
def setup_losses(
    config: OptimizationConfig,
) -> Callable[[tf.Tensor, tf.Tensor], tf.Tensor]:
    """Set up model loss function from config."""
    losses = [tf.keras.losses.MeanSquaredError()]

    if config.hard_keypoint_mining.online_mining:
        losses.append(OHKMLoss.from_config(config.hard_keypoint_mining))
        logger.info(f"  OHKM enabled: {config.hard_keypoint_mining}")

    def loss_fn(y_gt, y_pr):
        # Sum the contributions of each configured loss term.
        loss = 0
        for fn in losses:
            loss += fn(y_gt, y_pr)
        return loss

    return loss_fn


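# Sketch (illustrative): the returned callable sums each configured term on the
# same (ground truth, prediction) pair. Shapes below are arbitrary examples.
#
#     loss_fn = setup_losses(OptimizationConfig())
#     y_gt = tf.zeros([2, 64, 64, 3])  # (batch, height, width, channels)
#     y_pr = tf.ones([2, 64, 64, 3])
#     loss = loss_fn(y_gt, y_pr)  # scalar: MSE (+ OHKM if enabled in config)

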
def setup_metrics(
    config: OptimizationConfig, part_names: Optional[List[Text]] = None
) -> List[Union[tf.keras.losses.Loss, tf.keras.metrics.Metric]]:
    """Set up training metrics from config."""
    metrics = []

    if config.hard_keypoint_mining.online_mining:
        metrics.append(OHKMLoss.from_config(config.hard_keypoint_mining))

    if part_names is not None:
        for channel_ind, part_name in enumerate(part_names):
            metrics.append(PartLoss(channel_ind=channel_ind, name=part_name))

    return metrics


def setup_optimization_callbacks(
    config: OptimizationConfig,
) -> List[tf.keras.callbacks.Callback]:
    """Set up optimization callbacks from config."""
    callbacks = []
    if config.learning_rate_schedule.reduce_on_plateau:
        callbacks.append(
            ReduceLROnPlateau(
                monitor="val_loss",
                mode="min",
                factor=config.learning_rate_schedule.reduction_factor,
                patience=config.learning_rate_schedule.plateau_patience,
                min_delta=config.learning_rate_schedule.plateau_min_delta,
                cooldown=config.learning_rate_schedule.plateau_cooldown,
                min_lr=config.learning_rate_schedule.min_learning_rate,
                verbose=1,
            )
        )
        logger.info(f"  Learning rate schedule: {config.learning_rate_schedule}")

    if config.early_stopping.stop_training_on_plateau:
        callbacks.append(
            EarlyStopping(
                monitor="val_loss",
                mode="min",
                patience=config.early_stopping.plateau_patience,
                min_delta=config.early_stopping.plateau_min_delta,
                verbose=1,
            )
        )
        logger.info(f"  Early stopping: {config.early_stopping}")

    return callbacks


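# Sketch (illustrative; `job_config`, `model`, `train_ds`, and `val_ds` are
# placeholder names): with both schedules enabled in the config, this returns
# [ReduceLROnPlateau, EarlyStopping], both monitoring "val_loss", which can be
# passed straight to `tf.keras.Model.fit()`.
#
#     callbacks = setup_optimization_callbacks(job_config.optimization)
#     model.fit(train_ds, validation_data=val_ds, callbacks=callbacks)

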
def get_timestamp() -> Text:
    """Return the date and time as a string."""
    return datetime.now().strftime("%y%m%d_%H%M%S")


def setup_new_run_folder(
    config: OutputsConfig, base_run_name: Optional[Text] = None
) -> Text:
    """Create a new run folder from config."""
    run_path = None
    if config.save_outputs:
        # Auto-generate run name.
        if config.run_name is None:
            config.run_name = get_timestamp()
            if isinstance(base_run_name, str):
                config.run_name = config.run_name + "." + base_run_name

        # Find new run name suffix if needed.
        if config.run_name_suffix is None:
            config.run_name_suffix = ""
            run_path = os.path.join(
                config.runs_folder, f"{config.run_name_prefix}{config.run_name}"
            )
            i = 0
            while os.path.exists(run_path):
                i += 1
                config.run_name_suffix = f"_{i}"
                run_path = os.path.join(
                    config.runs_folder,
                    f"{config.run_name_prefix}{config.run_name}{config.run_name_suffix}",
                )

        # Build run path.
        run_path = config.run_path

    return run_path


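# Behavior sketch (assumes an OutputsConfig constructed with these fields and a
# run_name_suffix of None; the values are illustrative): if "models/baseline"
# already exists on disk, the suffix search yields "models/baseline_1", then
# "models/baseline_2", and so on.
#
#     out_config = OutputsConfig(
#         save_outputs=True, runs_folder="models", run_name="baseline"
#     )
#     run_path = setup_new_run_folder(out_config)  # e.g., "models/baseline_1"

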
def setup_zmq_callbacks(zmq_config: ZMQConfig) -> List[tf.keras.callbacks.Callback]:
    """Set up ZeroMQ callbacks from config."""
    callbacks = []

    if zmq_config.subscribe_to_controller:
        callbacks.append(
            TrainingControllerZMQ(
                address=zmq_config.controller_address,
                poll_timeout=zmq_config.controller_polling_timeout,
            )
        )
        logger.info(f"  ZMQ controller subscribed to: {zmq_config.controller_address}")
    if zmq_config.publish_updates:
        callbacks.append(ProgressReporterZMQ(address=zmq_config.publish_address))
        logger.info(
            f"  ZMQ progress reporter publishing on: {zmq_config.publish_address}"
        )

    return callbacks


def setup_checkpointing(
    config: CheckpointingConfig, run_path: Text
) -> List[tf.keras.callbacks.Callback]:
    """Set up model checkpointing callbacks from config."""
    callbacks = []
    if config.initial_model:
        callbacks.append(
            ModelCheckpointOnEvent(
                filepath=os.path.join(run_path, "initial_model.h5"),
                event="train_begin",
            )
        )

    if config.best_model:
        callbacks.append(
            ModelCheckpoint(
                filepath=os.path.join(run_path, "best_model.h5"),
                monitor="val_loss",
                save_best_only=True,
                save_weights_only=False,
                save_freq="epoch",
                verbose=0,
            )
        )

    if config.every_epoch:
        callbacks.append(
            ModelCheckpointOnEvent(
                filepath=os.path.join(run_path, "model.epoch%04d.h5"),
                event="epoch_end",
            )
        )

    if config.latest_model:
        callbacks.append(
            ModelCheckpointOnEvent(
                filepath=os.path.join(run_path, "latest_model.h5"), event="epoch_end"
            )
        )

    if config.final_model:
        callbacks.append(
            ModelCheckpointOnEvent(
                filepath=os.path.join(run_path, "final_model.h5"), event="train_end"
            )
        )

    return callbacks


def setup_tensorboard(
    config: TensorBoardConfig, run_path: Text
) -> List[tf.keras.callbacks.Callback]:
    """Set up TensorBoard callbacks from config."""
    callbacks = []
    if config.write_logs:
        callbacks.append(
            TensorBoard(
                log_dir=run_path,
                histogram_freq=0,
                write_graph=config.architecture_graph,
                update_freq=config.loss_frequency,
                profile_batch=2 if config.profile_graph else 0,
                embeddings_freq=0,
                embeddings_metadata=None,
            )
        )

    return callbacks


def setup_output_callbacks(
    config: OutputsConfig, run_path: Optional[Text] = None
) -> List[tf.keras.callbacks.Callback]:
    """Set up training outputs callbacks from config."""
    callbacks = []
    if config.save_outputs and run_path is not None:
        callbacks.extend(setup_checkpointing(config.checkpointing, run_path))
        callbacks.extend(setup_tensorboard(config.tensorboard, run_path))

        if config.log_to_csv:
            callbacks.append(
                CSVLogger(filename=os.path.join(run_path, "training_log.csv"))
            )
    callbacks.extend(setup_zmq_callbacks(config.zmq))
    return callbacks


def setup_visualization(
    config: OutputsConfig,
    run_path: Text,
    viz_fn: Callable[[], matplotlib.figure.Figure],
    name: Text,
) -> List[tf.keras.callbacks.Callback]:
    """Set up visualization callbacks from config."""
    callbacks = []

    try:
        matplotlib.use("Qt5Agg")
    except ImportError:
        logger.warning(
            "Unable to use Qt backend for matplotlib. "
            "This probably means the system is running headless."
        )

    if config.save_visualizations and config.save_outputs:
        callbacks.append(
            MatplotlibSaver(
                save_folder=os.path.join(run_path, "viz"), plot_fn=viz_fn, prefix=name
            )
        )

    if (
        config.tensorboard.write_logs
        and config.tensorboard.visualizations
        and config.save_outputs
    ):
        callbacks.append(
            TensorBoardMatplotlibWriter(
                log_dir=os.path.join(run_path, name), plot_fn=viz_fn, tag=name
            )
        )

    return callbacks


def sanitize_scope_name(name: Text) -> Text:
    """Sanitize a string for use as a TensorFlow scope name."""
    # Add "." to the beginning if the first character isn't acceptable.
    name = re.sub("^([^A-Za-z0-9.])", ".\\1", name)
    # Replace invalid characters with "_".
    name = re.sub("([^A-Za-z0-9._])", "_", name)
    return name


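# Examples (illustrative part names): characters TensorFlow rejects in scope
# names get prefixed or substituted.
#
#     sanitize_scope_name("head")      # -> "head"     (already valid)
#     sanitize_scope_name("_thorax")   # -> "._thorax" (bad leading character)
#     sanitize_scope_name("wing tip")  # -> "wing_tip" (space replaced)

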
PipelineBuilder = TypeVar(
    "PipelineBuilder",
    CentroidConfmapsPipeline,
    TopdownConfmapsPipeline,
    BottomUpPipeline,
    SingleInstanceConfmapsPipeline,
)


@attr.s(auto_attribs=True)
class Trainer(ABC):
    """Base trainer class that provides general model training functionality.

    This class is intended to be instantiated using the `from_config()` class
    method, which will return the appropriate subclass based on the input
    configuration.

    This class should not be used directly. It is intended to be subclassed by a
    model output type-specific trainer that provides more specific functionality.

    Attributes:
        data_readers: A `DataReaders` instance that contains training data providers.
        model: A `Model` instance describing the SLEAP model to train.
        config: A `TrainingJobConfig` that describes the training parameters.
        initial_config: This attribute will contain a copy of the input configuration
            before any attributes are updated in `config`.
        pipeline_builder: A model output type-specific data pipeline builder to create
            pipelines that generate data used for training. This must be specified in
            subclasses.
        training_pipeline: The data pipeline that generates examples from the training
            set for optimization.
        validation_pipeline: The data pipeline that generates examples from the
            validation set for optimization.
        training_viz_pipeline: The data pipeline that generates examples from the
            training set for visualization.
        validation_viz_pipeline: The data pipeline that generates examples from the
            validation set for visualization.
        optimization_callbacks: Keras callbacks related to optimization.
        output_callbacks: Keras callbacks related to outputs.
        visualization_callbacks: Keras callbacks related to visualization.
        run_path: The path to the run folder that will contain training results, if
            any.
    """

    data_readers: DataReaders
    model: Model
    config: TrainingJobConfig
    initial_config: Optional[TrainingJobConfig] = None

    pipeline_builder: PipelineBuilder = attr.ib(init=False)
    training_pipeline: Pipeline = attr.ib(init=False)
    validation_pipeline: Pipeline = attr.ib(init=False)
    training_viz_pipeline: Pipeline = attr.ib(init=False)
    validation_viz_pipeline: Pipeline = attr.ib(init=False)

    optimization_callbacks: List[tf.keras.callbacks.Callback] = attr.ib(
        factory=list, init=False
    )
    output_callbacks: List[tf.keras.callbacks.Callback] = attr.ib(
        factory=list, init=False
    )
    visualization_callbacks: List[tf.keras.callbacks.Callback] = attr.ib(
        factory=list, init=False
    )

    run_path: Optional[Text] = attr.ib(default=None, init=False)

    @classmethod
    def from_config(
        cls,
        config: TrainingJobConfig,
        training_labels: Optional[Union[Text, sleap.Labels]] = None,
        validation_labels: Optional[Union[Text, sleap.Labels, float]] = None,
        test_labels: Optional[Union[Text, sleap.Labels]] = None,
        video_search_paths: Optional[List[Text]] = None,
    ) -> "Trainer":
        """Initialize the trainer from a training job configuration.

        Args:
            config: A `TrainingJobConfig` instance.
            training_labels: Training labels to use instead of the ones in the config,
                if any. If a path is specified, it will overwrite the one in the
                config.
            validation_labels: Validation labels to use instead of the ones in the
                config, if any. If a path is specified, it will overwrite the one in
                the config.
            test_labels: Test labels to use instead of the ones in the config, if any.
                If a path is specified, it will overwrite the one in the config.
            video_search_paths: Optional list of paths to search for videos referenced
                in the labels files.
        """
        # Copy input config before we make any changes.
        initial_config = copy.deepcopy(config)

        # Create data readers and store loaded skeleton.
        data_readers = DataReaders.from_config(
            config.data.labels,
            training=training_labels,
            validation=validation_labels,
            test=test_labels,
            video_search_paths=video_search_paths,
            update_config=True,
        )
        config.data.labels.skeletons = data_readers.training_labels.skeletons

        # Create model.
        model = Model.from_config(
            config.model, skeleton=config.data.labels.skeletons[0], update_config=True
        )

        # Determine output type to create type-specific model trainer.
        head_config = config.model.heads.which_oneof()
        if isinstance(head_config, CentroidsHeadConfig):
            trainer_cls = CentroidConfmapsModelTrainer
        elif isinstance(head_config, CenteredInstanceConfmapsHeadConfig):
            trainer_cls = TopdownConfmapsModelTrainer
        elif isinstance(head_config, MultiInstanceConfig):
            trainer_cls = BottomUpModelTrainer
        elif isinstance(head_config, SingleInstanceConfmapsHeadConfig):
            trainer_cls = SingleInstanceModelTrainer
        else:
            raise ValueError(
                "Model head not specified or configured. Check the config.model.heads"
                " setting."
            )

        return trainer_cls(
            config=config,
            initial_config=initial_config,
            data_readers=data_readers,
            model=model,
        )

    @abstractmethod
    def _update_config(self):
        """Implement in subclasses."""
        pass

    @abstractmethod
    def _setup_pipeline_builder(self):
        """Implement in subclasses."""
        pass

    @property
    @abstractmethod
    def input_keys(self):
        """Implement in subclasses."""
        pass

    @property
    @abstractmethod
    def output_keys(self):
        """Implement in subclasses."""
        pass

    @abstractmethod
    def _setup_visualization(self):
        """Implement in subclasses."""
        pass

    def _setup_model(self):
        """Set up the keras model."""
        # Infer the input shape by evaluating the data pipeline.
        logger.info("Building test pipeline...")
        t0 = time()
        base_pipeline = self.pipeline_builder.make_base_pipeline(
            self.data_readers.training_labels_reader
        )
        base_example = next(iter(base_pipeline.make_dataset()))
        input_shape = base_example[self.input_keys[0]].shape
        # TODO: extend input shape determination for multi-input
        logger.info(f"Loaded test example. [{time() - t0:.3f}s]")
        logger.info(f"  Input shape: {input_shape}")

        # Create the tf.keras.Model instance.
        self.model.make_model(input_shape)
        logger.info("Created Keras model.")
        logger.info(f"  Backbone: {self.model.backbone}")
        logger.info(f"  Max stride: {self.model.maximum_stride}")
        logger.info(f"  Parameters: {self.model.keras_model.count_params():3,d}")
        logger.info("  Heads: ")
        for i, head in enumerate(self.model.heads):
            logger.info(f"    heads[{i}] = {head}")

    @property
    def keras_model(self) -> tf.keras.Model:
        """Alias for `self.model.keras_model`."""
        return self.model.keras_model

    def _setup_pipelines(self):
        """Set up training data pipelines for consumption by the keras model."""
        # Create the training and validation pipelines with appropriate tensor names.
        key_mapper = KeyMapper(
            [
                {
                    input_key: input_name
                    for input_key, input_name in zip(
                        self.input_keys, self.keras_model.input_names
                    )
                },
                {
                    output_key: output_name
                    for output_key, output_name in zip(
                        self.output_keys, self.keras_model.output_names
                    )
                },
            ]
        )
        self.training_pipeline = (
            self.pipeline_builder.make_training_pipeline(
                self.data_readers.training_labels_reader
            )
            + key_mapper
        )
        logger.info(f"Training set: n = {len(self.data_readers.training_labels)}")
        self.validation_pipeline = (
            self.pipeline_builder.make_training_pipeline(
                self.data_readers.validation_labels_reader
            )
            + key_mapper
        )
        logger.info(f"Validation set: n = {len(self.data_readers.validation_labels)}")

    def _setup_optimization(self):
        """Set up optimizer, loss functions and compile the model."""
        optimizer = setup_optimizer(self.config.optimization)
        loss_fn = setup_losses(self.config.optimization)

        # TODO: Implement general part loss reporting.
        part_names = None
        if isinstance(self.pipeline_builder, TopdownConfmapsPipeline):
            part_names = [
                sanitize_scope_name(name) for name in self.model.heads[0].part_names
            ]
        metrics = setup_metrics(self.config.optimization, part_names=part_names)

        self.optimization_callbacks = setup_optimization_callbacks(
            self.config.optimization
        )

        self.keras_model.compile(
            optimizer=optimizer,
            loss=loss_fn,
            metrics=metrics,
            loss_weights={
                output_name: head.loss_weight
                for output_name, head in zip(
                    self.keras_model.output_names, self.model.heads
                )
            },
        )

    def _setup_outputs(self):
        """Set up output-related functionality."""
        if self.config.outputs.save_outputs:
            # Build path to run folder.
            self.run_path = setup_new_run_folder(
                self.config.outputs, base_run_name=type(self.model.backbone).__name__
            )

        # Setup output callbacks.
        self.output_callbacks = setup_output_callbacks(
            self.config.outputs, run_path=self.run_path
        )

        if self.run_path is not None and self.config.outputs.save_outputs:
            # Create run directory.
            os.makedirs(self.run_path, exist_ok=True)
            logger.info(f"Created run path: {self.run_path}")

            # Save configs.
            if self.initial_config is not None:
                self.initial_config.save_json(
                    os.path.join(self.run_path, "initial_config.json")
                )
            self.config.save_json(os.path.join(self.run_path, "training_config.json"))

            # Save input (ground truth) labels.
            sleap.Labels.save_file(
                self.data_readers.training_labels_reader.labels,
                os.path.join(self.run_path, "labels_gt.train.slp"),
            )
            sleap.Labels.save_file(
                self.data_readers.validation_labels_reader.labels,
                os.path.join(self.run_path, "labels_gt.val.slp"),
            )
            if self.data_readers.test_labels_reader is not None:
                sleap.Labels.save_file(
                    self.data_readers.test_labels_reader.labels,
                    os.path.join(self.run_path, "labels_gt.test.slp"),
                )

    @property
    def callbacks(self) -> List[tf.keras.callbacks.Callback]:
        """Return all callbacks currently configured."""
        callbacks = (
            self.optimization_callbacks
            + self.visualization_callbacks
            + self.output_callbacks
        )

        # Some callbacks should be called after all previous ones since they depend on
        # the state of some shared objects (e.g., tf.keras.Model).
        final_callbacks = []
        for callback in callbacks[::-1]:
            if isinstance(callback, tf.keras.callbacks.EarlyStopping):
                final_callbacks.append(callback)
                callbacks.remove(callback)

        return callbacks + final_callbacks

    def setup(self):
        """Set up data pipeline and model for training."""
        logger.info("Setting up for training...")
        t0 = time()
        self._update_config()
        logger.info("Setting up pipeline builders...")
        self._setup_pipeline_builder()
        logger.info("Setting up model...")
        self._setup_model()
        logger.info("Setting up data pipelines...")
        self._setup_pipelines()
        logger.info("Setting up optimization...")
        self._setup_optimization()
        logger.info("Setting up outputs...")
        self._setup_outputs()
        logger.info("Setting up visualization...")
        self._setup_visualization()
        logger.info(f"Finished trainer set up. [{time() - t0:.1f}s]")

    def train(self):
        """Execute the optimization loop to train the model."""
        if self.keras_model is None:
            self.setup()

        logger.info("Creating tf.data.Datasets for training data generation...")
        t0 = time()
        training_ds = self.training_pipeline.make_dataset()
        validation_ds = self.validation_pipeline.make_dataset()
        logger.info(f"Finished creating training datasets. [{time() - t0:.1f}s]")

        logger.info("Starting training loop...")
        t0 = time()
        self.keras_model.fit(
            training_ds,
            epochs=self.config.optimization.epochs,
            validation_data=validation_ds,
            steps_per_epoch=self.config.optimization.batches_per_epoch,
            validation_steps=self.config.optimization.val_batches_per_epoch,
            callbacks=self.callbacks,
            verbose=2,
        )
        logger.info(f"Finished training loop. [{(time() - t0) / 60:.1f} min]")

        # Save predictions and evaluations.
        if self.config.outputs.save_outputs:
            sleap.nn.evals.evaluate_model(
                cfg=self.config,
                labels_reader=self.data_readers.training_labels_reader,
                model=self.model,
                save=True,
                split_name="train",
            )
            sleap.nn.evals.evaluate_model(
                cfg=self.config,
                labels_reader=self.data_readers.validation_labels_reader,
                model=self.model,
                save=True,
                split_name="val",
            )
            if self.data_readers.test_labels_reader is not None:
                sleap.nn.evals.evaluate_model(
                    cfg=self.config,
                    labels_reader=self.data_readers.test_labels_reader,
                    model=self.model,
                    save=True,
                    split_name="test",
                )


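# End-to-end usage sketch (illustrative; paths are hypothetical): load a
# training job config, build the matching trainer subclass from the head type,
# and run the optimization loop. `train()` calls `setup()` itself if the Keras
# model has not been built yet.
#
#     config = TrainingJobConfig.load_json("path/to/training_config.json")
#     trainer = Trainer.from_config(config, training_labels="labels.train.slp")
#     trainer.train()

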
@attr.s(auto_attribs=True)
class CentroidConfmapsModelTrainer(Trainer):
    """Trainer for models that output centroid confidence maps."""

    pipeline_builder: CentroidConfmapsPipeline = attr.ib(init=False)

    def _update_config(self):
        """Update the configuration with inferred values."""
        if self.config.data.preprocessing.pad_to_stride is None:
            self.config.data.preprocessing.pad_to_stride = self.model.maximum_stride

        if self.config.optimization.batches_per_epoch is None:
            n_training_examples = len(self.data_readers.training_labels)
            n_training_batches = (
                n_training_examples // self.config.optimization.batch_size
            )
            self.config.optimization.batches_per_epoch = max(
                self.config.optimization.min_batches_per_epoch, n_training_batches
            )

        if self.config.optimization.val_batches_per_epoch is None:
            n_validation_examples = len(self.data_readers.validation_labels)
            n_validation_batches = (
                n_validation_examples // self.config.optimization.batch_size
            )
            self.config.optimization.val_batches_per_epoch = max(
                self.config.optimization.min_val_batches_per_epoch,
                n_validation_batches,
            )

    def _setup_pipeline_builder(self):
        """Initialize pipeline builder."""
        self.pipeline_builder = CentroidConfmapsPipeline(
            data_config=self.config.data,
            optimization_config=self.config.optimization,
            centroid_confmap_head=self.model.heads[0],
        )

    @property
    def input_keys(self) -> List[Text]:
        """Return example keys to be mapped to model inputs."""
        return ["image"]

    @property
    def output_keys(self) -> List[Text]:
        """Return example keys to be mapped to model outputs."""
        return ["centroid_confidence_maps"]

    def _setup_visualization(self):
        """Set up visualization pipelines and callbacks."""
        # Create visualization/inference pipelines.
        self.training_viz_pipeline = self.pipeline_builder.make_viz_pipeline(
            self.data_readers.training_labels_reader, self.keras_model
        )
        self.validation_viz_pipeline = self.pipeline_builder.make_viz_pipeline(
            self.data_readers.validation_labels_reader, self.keras_model
        )

        # Create static iterators.
        training_viz_ds_iter = iter(self.training_viz_pipeline.make_dataset())
        validation_viz_ds_iter = iter(self.validation_viz_pipeline.make_dataset())

        def visualize_example(example):
            img = example["image"].numpy()
            cms = example["predicted_centroid_confidence_maps"].numpy()
            pts_gt = example["centroids"].numpy()
            pts_pr = example["predicted_centroids"].numpy()

            scale = 1.0
            if img.shape[0] < 512:
                scale = 2.0
            if img.shape[0] < 256:
                scale = 4.0
            fig = plot_img(img, dpi=72 * scale, scale=scale)
            plot_confmaps(cms, output_scale=cms.shape[0] / img.shape[0])
            plot_peaks(pts_gt, pts_pr, paired=False)
            return fig

        self.visualization_callbacks.extend(
            setup_visualization(
                self.config.outputs,
                run_path=self.run_path,
                viz_fn=lambda: visualize_example(next(training_viz_ds_iter)),
                name="train",
            )
        )
        self.visualization_callbacks.extend(
            setup_visualization(
                self.config.outputs,
                run_path=self.run_path,
                viz_fn=lambda: visualize_example(next(validation_viz_ds_iter)),
                name="validation",
            )
        )


@attr.s(auto_attribs=True)
class SingleInstanceModelTrainer(Trainer):
    """Trainer for models that output single-instance confidence maps."""

    pipeline_builder: SingleInstanceConfmapsPipeline = attr.ib(init=False)

    def _update_config(self):
        """Update the configuration with inferred values."""
        if self.config.data.preprocessing.pad_to_stride is None:
            self.config.data.preprocessing.pad_to_stride = self.model.maximum_stride

        if self.config.optimization.batches_per_epoch is None:
            n_training_examples = len(
                self.data_readers.training_labels_reader.labels.user_instances
            )
            n_training_batches = (
                n_training_examples // self.config.optimization.batch_size
            )
            self.config.optimization.batches_per_epoch = max(
                self.config.optimization.min_batches_per_epoch, n_training_batches
            )

        if self.config.optimization.val_batches_per_epoch is None:
            n_validation_examples = len(
                self.data_readers.validation_labels_reader.labels.user_instances
            )
            n_validation_batches = (
                n_validation_examples // self.config.optimization.batch_size
            )
            self.config.optimization.val_batches_per_epoch = max(
                self.config.optimization.min_val_batches_per_epoch,
                n_validation_batches,
            )

    def _setup_pipeline_builder(self):
        # Initialize pipeline builder.
        self.pipeline_builder = SingleInstanceConfmapsPipeline(
            data_config=self.config.data,
            optimization_config=self.config.optimization,
            single_instance_confmap_head=self.model.heads[0],
        )

    @property
    def input_keys(self) -> List[Text]:
        """Return example keys to be mapped to model inputs."""
        return ["image"]

    @property
    def output_keys(self) -> List[Text]:
        """Return example keys to be mapped to model outputs."""
        return ["confidence_maps"]

    def _setup_visualization(self):
        """Set up visualization pipelines and callbacks."""
        # Create visualization/inference pipelines.
        self.training_viz_pipeline = self.pipeline_builder.make_viz_pipeline(
            self.data_readers.training_labels_reader, self.keras_model
        )
        self.validation_viz_pipeline = self.pipeline_builder.make_viz_pipeline(
            self.data_readers.validation_labels_reader, self.keras_model
        )

        # Create static iterators.
        training_viz_ds_iter = iter(self.training_viz_pipeline.make_dataset())
        validation_viz_ds_iter = iter(self.validation_viz_pipeline.make_dataset())

        def visualize_example(example):
            img = example["image"].numpy()
            cms = example["predicted_confidence_maps"].numpy()
            pts_gt = example["instances"].numpy()[0]
            pts_pr = example["predicted_points"].numpy()

            scale = 1.0
            if img.shape[0] < 512:
                scale = 2.0
            if img.shape[0] < 256:
                scale = 4.0
            fig = plot_img(img, dpi=72 * scale, scale=scale)
            plot_confmaps(cms, output_scale=cms.shape[0] / img.shape[0])
            plot_peaks(pts_gt, pts_pr, paired=True)
            return fig

        self.visualization_callbacks.extend(
            setup_visualization(
                self.config.outputs,
                run_path=self.run_path,
                viz_fn=lambda: visualize_example(next(training_viz_ds_iter)),
                name="train",
            )
        )
        self.visualization_callbacks.extend(
            setup_visualization(
                self.config.outputs,
                run_path=self.run_path,
                viz_fn=lambda: visualize_example(next(validation_viz_ds_iter)),
                name="validation",
            )
        )


@attr.s(auto_attribs=True)
class TopdownConfmapsModelTrainer(Trainer):
    """Trainer for models that output instance centered confidence maps."""

    pipeline_builder: TopdownConfmapsPipeline = attr.ib(init=False)

    def _update_config(self):
        """Update the configuration with inferred values."""
        if self.config.data.preprocessing.pad_to_stride is None:
            self.config.data.preprocessing.pad_to_stride = 1

        if self.config.data.instance_cropping.crop_size is None:
            self.config.data.instance_cropping.crop_size = (
                sleap.nn.data.instance_cropping.find_instance_crop_size(
                    self.data_readers.training_labels,
                    padding=self.config.data.instance_cropping.crop_size_detection_padding,
                    maximum_stride=self.model.maximum_stride,
                    input_scaling=self.config.data.preprocessing.input_scaling,
                )
            )

        if self.config.optimization.batches_per_epoch is None:
            n_training_examples = len(
                self.data_readers.training_labels_reader.labels.user_instances
            )
            n_training_batches = (
                n_training_examples // self.config.optimization.batch_size
            )
            self.config.optimization.batches_per_epoch = max(
                self.config.optimization.min_batches_per_epoch, n_training_batches
            )

        if self.config.optimization.val_batches_per_epoch is None:
            n_validation_examples = len(
                self.data_readers.validation_labels_reader.labels.user_instances
            )
            n_validation_batches = (
                n_validation_examples // self.config.optimization.batch_size
            )
            self.config.optimization.val_batches_per_epoch = max(
                self.config.optimization.min_val_batches_per_epoch,
                n_validation_batches,
            )

    def _setup_pipeline_builder(self):
        # Initialize pipeline builder.
        self.pipeline_builder = TopdownConfmapsPipeline(
            data_config=self.config.data,
            optimization_config=self.config.optimization,
            instance_confmap_head=self.model.heads[0],
        )

    @property
    def input_keys(self) -> List[Text]:
        """Return example keys to be mapped to model inputs."""
        return ["instance_image"]

    @property
    def output_keys(self) -> List[Text]:
        """Return example keys to be mapped to model outputs."""
        return ["instance_confidence_maps"]

    def _setup_visualization(self):
        """Set up visualization pipelines and callbacks."""
        # Create visualization/inference pipelines.
        self.training_viz_pipeline = self.pipeline_builder.make_viz_pipeline(
            self.data_readers.training_labels_reader, self.keras_model
        )
        self.validation_viz_pipeline = self.pipeline_builder.make_viz_pipeline(
            self.data_readers.validation_labels_reader, self.keras_model
        )

        # Create static iterators.
        training_viz_ds_iter = iter(self.training_viz_pipeline.make_dataset())
        validation_viz_ds_iter = iter(self.validation_viz_pipeline.make_dataset())

        def visualize_example(example):
            img = example["instance_image"].numpy()
            cms = example["predicted_instance_confidence_maps"].numpy()
            pts_gt = example["center_instance"].numpy()
            pts_pr = example["predicted_center_instance_points"].numpy()

            scale = 1.0
            if img.shape[0] < 512:
                scale = 2.0
            if img.shape[0] < 256:
                scale = 4.0
            fig = plot_img(img, dpi=72 * scale, scale=scale)
            plot_confmaps(cms, output_scale=cms.shape[0] / img.shape[0])
            plot_peaks(pts_gt, pts_pr, paired=True)
            return fig

        self.visualization_callbacks.extend(
            setup_visualization(
                self.config.outputs,
                run_path=self.run_path,
                viz_fn=lambda: visualize_example(next(training_viz_ds_iter)),
                name="train",
            )
        )
        self.visualization_callbacks.extend(
            setup_visualization(
                self.config.outputs,
                run_path=self.run_path,
                viz_fn=lambda: visualize_example(next(validation_viz_ds_iter)),
                name="validation",
            )
        )


@attr.s(auto_attribs=True)
class BottomUpModelTrainer(Trainer):
    """Trainer for models that output multi-instance confidence maps and PAFs."""

    pipeline_builder: BottomUpPipeline = attr.ib(init=False)

    def _update_config(self):
        """Update the configuration with inferred values."""
        if self.config.data.preprocessing.pad_to_stride is None:
            self.config.data.preprocessing.pad_to_stride = self.model.maximum_stride

        if self.config.optimization.batches_per_epoch is None:
            n_training_examples = len(self.data_readers.training_labels)
            n_training_batches = (
                n_training_examples // self.config.optimization.batch_size
            )
            self.config.optimization.batches_per_epoch = max(
                self.config.optimization.min_batches_per_epoch, n_training_batches
            )

        if self.config.optimization.val_batches_per_epoch is None:
            n_validation_examples = len(self.data_readers.validation_labels)
            n_validation_batches = (
                n_validation_examples // self.config.optimization.batch_size
            )
            self.config.optimization.val_batches_per_epoch = max(
                self.config.optimization.min_val_batches_per_epoch,
                n_validation_batches,
            )

    def _setup_pipeline_builder(self):
        # Initialize pipeline builder.
        self.pipeline_builder = BottomUpPipeline(
            data_config=self.config.data,
            optimization_config=self.config.optimization,
            confmaps_head=self.model.heads[0],
            pafs_head=self.model.heads[1],
        )

    @property
    def input_keys(self) -> List[Text]:
        """Return example keys to be mapped to model inputs."""
        return ["image"]

    @property
    def output_keys(self) -> List[Text]:
        """Return example keys to be mapped to model outputs."""
        return ["confidence_maps", "part_affinity_fields"]

    def _setup_visualization(self):
        """Set up visualization pipelines and callbacks."""
        # Create visualization/inference pipelines.
        self.training_viz_pipeline = self.pipeline_builder.make_viz_pipeline(
            self.data_readers.training_labels_reader, self.keras_model
        )
        self.validation_viz_pipeline = self.pipeline_builder.make_viz_pipeline(
            self.data_readers.validation_labels_reader, self.keras_model
        )

        # Create static iterators.
        training_viz_ds_iter = iter(self.training_viz_pipeline.make_dataset())
        validation_viz_ds_iter = iter(self.validation_viz_pipeline.make_dataset())

        def visualize_confmaps_example(example):
            img = example["image"].numpy()
            cms = example["predicted_confidence_maps"].numpy()
            pts_gt = example["instances"].numpy()
            pts_pr = example["predicted_peaks"].numpy()

            scale = 1.0
            if img.shape[0] < 512:
                scale = 2.0
            if img.shape[0] < 256:
                scale = 4.0
            fig = plot_img(img, dpi=72 * scale, scale=scale)
            plot_confmaps(cms, output_scale=cms.shape[0] / img.shape[0])
            plot_peaks(pts_gt, pts_pr, paired=False)
            return fig

        def visualize_pafs_example(example):
            img = example["image"].numpy()
            pafs = example["predicted_part_affinity_fields"].numpy()

            scale = 1.0
            if img.shape[0] < 512:
                scale = 2.0
            if img.shape[0] < 256:
                scale = 4.0
            fig = plot_img(img, dpi=72 * scale, scale=scale)
            # Plot PAF magnitudes: reshape to (height, width, n_edges, 2) and
            # take the norm over the (x, y) components of each edge.
            pafs = pafs.reshape((pafs.shape[0], pafs.shape[1], -1, 2))
            pafs_mag = np.sqrt(pafs[..., 0] ** 2 + pafs[..., 1] ** 2)
            plot_confmaps(pafs_mag, output_scale=pafs_mag.shape[0] / img.shape[0])
            return fig

        self.visualization_callbacks.extend(
            setup_visualization(
                self.config.outputs,
                run_path=self.run_path,
                viz_fn=lambda: visualize_confmaps_example(next(training_viz_ds_iter)),
                name="train",
            )
        )
        self.visualization_callbacks.extend(
            setup_visualization(
                self.config.outputs,
                run_path=self.run_path,
                viz_fn=lambda: visualize_confmaps_example(
                    next(validation_viz_ds_iter)
                ),
                name="validation",
            )
        )

        # Memory leak:
        self.visualization_callbacks.extend(
            setup_visualization(
                self.config.outputs,
                run_path=self.run_path,
                viz_fn=lambda: visualize_pafs_example(next(training_viz_ds_iter)),
                name="train_pafs_magnitude",
            )
        )
        self.visualization_callbacks.extend(
            setup_visualization(
                self.config.outputs,
                run_path=self.run_path,
                viz_fn=lambda: visualize_pafs_example(next(validation_viz_ds_iter)),
                name="validation_pafs_magnitude",
            )
        )


def main():
    """Create CLI for training and run."""
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "training_job_path", help="Path to training job profile JSON file."
    )
    parser.add_argument("labels_path", help="Path to labels file to use for training.")
    parser.add_argument(
        "--video-paths",
        type=str,
        default="",
        help="List of paths for finding videos in case paths inside the labels file need fixing.",
    )
    parser.add_argument(
        "--val_labels",
        "--val",
        help="Path to labels file to use for validation (overrides training job path if set).",
    )
    parser.add_argument(
        "--test_labels",
        "--test",
        help="Path to labels file to use for test (overrides training job path if set).",
    )
    parser.add_argument(
        "--tensorboard",
        action="store_true",
        help="Enables TensorBoard logging to the run path.",
    )
    parser.add_argument(
        "--save_viz",
        action="store_true",
        help="Enables saving of prediction visualizations to the run folder.",
    )
    parser.add_argument(
        "--zmq", action="store_true", help="Enables ZMQ logging (for GUI)."
    )
    parser.add_argument(
        "--run_name",
        default="",
        help="Run name to use when saving file, overrides other run name settings.",
    )
    parser.add_argument("--prefix", default="", help="Prefix to prepend to run name.")
    parser.add_argument("--suffix", default="", help="Suffix to append to run name.")

    args, _ = parser.parse_known_args()

    # Find job configuration file.
    job_filename = args.training_job_path
    if not os.path.exists(job_filename):
        profile_dir = get_package_file("sleap/training_profiles")

        if os.path.exists(os.path.join(profile_dir, job_filename)):
            job_filename = os.path.join(profile_dir, job_filename)
        else:
            raise FileNotFoundError(f"Could not find training profile: {job_filename}")

    # Load job configuration.
    job_config = TrainingJobConfig.load_json(job_filename)

    # Override config settings for CLI-based training.
    job_config.outputs.save_outputs = True
    job_config.outputs.tensorboard.write_logs = args.tensorboard
    job_config.outputs.zmq.publish_updates = args.zmq
    job_config.outputs.zmq.subscribe_to_controller = args.zmq
    if args.run_name != "":
        job_config.outputs.run_name = args.run_name
    if args.prefix != "":
        job_config.outputs.run_name_prefix = args.prefix
    if args.suffix != "":
        job_config.outputs.run_name_suffix = args.suffix
    job_config.outputs.save_visualizations = args.save_viz

    logger.info(f"Training labels file: {args.labels_path}")
    logger.info(f"Training profile: {job_filename}")
    logger.info("")

    # Log configuration to console.
    logger.info("Arguments:")
    logger.info(json.dumps(vars(args), indent=4))
    logger.info("")
    logger.info("Training job:")
    logger.info(job_config.to_json())
    logger.info("")

    logger.info("System:")
    if sleap.nn.system.is_gpu_system():
        # Disable preallocation to handle Linux/low GPU memory issue.
        sleap.nn.system.disable_preallocation()
    sleap.nn.system.summary()
    logger.info("")

    logger.info("Initializing trainer...")
    # Create a trainer and run!
    trainer = Trainer.from_config(
        job_config,
        training_labels=args.labels_path,
        validation_labels=args.val_labels,
        test_labels=args.test_labels,
        video_search_paths=args.video_paths.split(","),
    )
    trainer.train()


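# Example invocation (illustrative paths): run this module as a script via the
# guard below (typically also exposed as the `sleap-train` console entry point
# wrapping `main()`) to train from a job profile and a labels file, with
# TensorBoard logs and visualizations saved to the run folder:
#
#     python -m sleap.nn.training path/to/job.json path/to/labels.slp \
#         --tensorboard --save_viz --run_name my_run

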
if __name__ == "__main__":
    main()