Source code for sleap.nn.heads

"""Model head definitions for defining model output types."""

from typing import Optional, Text, List, Sequence, Tuple

import attr

from sleap.nn.config import (
    CentroidsHeadConfig,
    SingleInstanceConfmapsHeadConfig,
    CenteredInstanceConfmapsHeadConfig,
    MultiInstanceConfmapsHeadConfig,
    PartAffinityFieldsHeadConfig,
)

@attr.s(auto_attribs=True)
class CentroidConfmapsHead:
    """Head for specifying instance centroid confidence maps.

    Attributes:
        anchor_part: Optional name of the body part used as the centroid anchor.
            Defaults to `None`; how `None` is interpreted is decided by the
            consumer of this head (not visible here).
        sigma: Spread parameter for the confidence maps (default 5.0). Presumably
            the Gaussian standard deviation used when rendering maps -- confirm.
        output_stride: Output stride of this head (default 1). Presumably the
            downsampling factor relative to the model input -- confirm.
        loss_weight: Weight of this head's loss term (default 1.0).
    """

    anchor_part: Optional[Text] = None
    sigma: float = 5.0
    output_stride: int = 1
    loss_weight: float = 1.0

    @property
    def channels(self) -> int:
        """Return the number of channels in the tensor output by this head.

        A centroid head always produces a single channel.
        """
        return 1

    @classmethod
    def from_config(cls, config: CentroidsHeadConfig) -> "CentroidConfmapsHead":
        """Create this head from a set of configurations.

        Args:
            config: A `CentroidsHeadConfig` instance specifying the head
                parameters. If it defines `loss_weight`, that value is used;
                otherwise the class default is kept.

        Returns:
            The instantiated head with the specified configuration options.
        """
        head_kwargs = dict(
            anchor_part=config.anchor_part,
            sigma=config.sigma,
            output_stride=config.output_stride,
        )
        # Previously the configured loss weight was silently dropped. The other
        # heads in this module forward it, so do the same here. Stay compatible
        # with config objects that do not define the field.
        if hasattr(config, "loss_weight"):
            head_kwargs["loss_weight"] = config.loss_weight
        return cls(**head_kwargs)
@attr.s(auto_attribs=True)
class SingleInstanceConfmapsHead:
    """Head for specifying single instance confidence maps.

    Attributes:
        part_names: Names of the body parts (nodes). One output channel is
            produced per entry.
        sigma: Spread parameter for the confidence maps (default 5.0). Presumably
            the Gaussian standard deviation used when rendering maps -- confirm.
        output_stride: Output stride of this head (default 1). Presumably the
            downsampling factor relative to the model input -- confirm.
        loss_weight: Weight of this head's loss term (default 1.0).
    """

    part_names: List[Text]
    sigma: float = 5.0
    output_stride: int = 1
    loss_weight: float = 1.0

    @property
    def channels(self) -> int:
        """Return the number of channels in the tensor output by this head."""
        return len(self.part_names)

    @classmethod
    def from_config(
        cls,
        config: SingleInstanceConfmapsHeadConfig,
        part_names: Optional[List[Text]] = None,
    ) -> "SingleInstanceConfmapsHead":
        """Create this head from a set of configurations.

        Args:
            config: A `SingleInstanceConfmapsHeadConfig` instance specifying the
                head parameters. If it defines `loss_weight`, that value is
                used; otherwise the class default is kept.
            part_names: Text name of the body parts (nodes) that the head will
                be configured to produce. The number of parts determines the
                number of channels in the output. This must be provided if the
                `part_names` attribute of the configuration is not set. When the
                configuration does set it, the configuration takes precedence.

        Returns:
            The instantiated head with the specified configuration options.
        """
        if config.part_names is not None:
            part_names = config.part_names
        head_kwargs = dict(
            part_names=part_names,
            sigma=config.sigma,
            output_stride=config.output_stride,
        )
        # Previously the configured loss weight was silently dropped. The other
        # heads in this module forward it, so do the same here. Stay compatible
        # with config objects that do not define the field.
        if hasattr(config, "loss_weight"):
            head_kwargs["loss_weight"] = config.loss_weight
        return cls(**head_kwargs)
@attr.s(auto_attribs=True)
class CenteredInstanceConfmapsHead:
    """Head for specifying centered instance confidence maps.

    Attributes:
        part_names: Names of the body parts (nodes). One output channel is
            produced per entry.
        anchor_part: Optional name of the body part that instances are
            presumably centered on (default `None`) -- confirm with the consumer.
        sigma: Spread parameter for the confidence maps (default 5.0). Presumably
            the Gaussian standard deviation used when rendering maps -- confirm.
        output_stride: Output stride of this head (default 1). Presumably the
            downsampling factor relative to the model input -- confirm.
        loss_weight: Weight of this head's loss term (default 1.0).
    """

    part_names: List[Text]
    anchor_part: Optional[Text] = None
    sigma: float = 5.0
    output_stride: int = 1
    loss_weight: float = 1.0

    @property
    def channels(self) -> int:
        """Return the number of channels in the tensor output by this head."""
        return len(self.part_names)

    @classmethod
    def from_config(
        cls,
        config: CenteredInstanceConfmapsHeadConfig,
        part_names: Optional[List[Text]] = None,
    ) -> "CenteredInstanceConfmapsHead":
        """Create this head from a set of configurations.

        Args:
            config: A `CenteredInstanceConfmapsHeadConfig` instance specifying
                the head parameters. If it defines `loss_weight`, that value is
                used; otherwise the class default is kept.
            part_names: Text name of the body parts (nodes) that the head will
                be configured to produce. The number of parts determines the
                number of channels in the output. This must be provided if the
                `part_names` attribute of the configuration is not set. When the
                configuration does set it, the configuration takes precedence.

        Returns:
            The instantiated head with the specified configuration options.
        """
        if config.part_names is not None:
            part_names = config.part_names
        head_kwargs = dict(
            part_names=part_names,
            anchor_part=config.anchor_part,
            sigma=config.sigma,
            output_stride=config.output_stride,
        )
        # Previously the configured loss weight was silently dropped. The other
        # heads in this module forward it, so do the same here. Stay compatible
        # with config objects that do not define the field.
        if hasattr(config, "loss_weight"):
            head_kwargs["loss_weight"] = config.loss_weight
        return cls(**head_kwargs)
@attr.s(auto_attribs=True)
class MultiInstanceConfmapsHead:
    """Head describing multi-instance confidence map outputs.

    Attributes:
        part_names: Body part (node) names. The head emits one channel for each.
        sigma: Confidence map spread parameter (default 5.0). Presumably a
            Gaussian standard deviation -- confirm against the map renderer.
        output_stride: Stride of this head's output (default 1). Presumably
            relative to the model input -- confirm.
        loss_weight: Scalar weight for this head's loss term (default 1.0).
    """

    part_names: List[Text]
    sigma: float = 5.0
    output_stride: int = 1
    loss_weight: float = 1.0

    @property
    def channels(self) -> int:
        """Number of output channels, i.e. the count of `part_names`."""
        n_parts = len(self.part_names)
        return n_parts

    @classmethod
    def from_config(
        cls,
        config: MultiInstanceConfmapsHeadConfig,
        part_names: Optional[List[Text]] = None,
    ) -> "MultiInstanceConfmapsHead":
        """Build a head from a `MultiInstanceConfmapsHeadConfig`.

        Args:
            config: Configuration supplying `sigma`, `output_stride`,
                `loss_weight` and optionally `part_names`.
            part_names: Fallback body part names, used only when
                `config.part_names` is `None`. This must be supplied in that
                case; the number of names sets the output channel count.

        Returns:
            A new head populated from the configuration.
        """
        resolved_names = (
            part_names if config.part_names is None else config.part_names
        )
        return cls(
            part_names=resolved_names,
            sigma=config.sigma,
            output_stride=config.output_stride,
            loss_weight=config.loss_weight,
        )
@attr.s(auto_attribs=True)
class PartAffinityFieldsHead:
    """Head describing multi-instance part affinity field outputs.

    Attributes:
        edges: Directed graph edges as `(source_node, destination_node)` pairs
            of node names.
        sigma: Field spread parameter (default 15.0). Presumably controls the
            width of the field around each edge -- confirm against the renderer.
        output_stride: Stride of this head's output (default 1). Presumably
            relative to the model input -- confirm.
        loss_weight: Scalar weight for this head's loss term (default 1.0).
    """

    edges: Sequence[Tuple[Text, Text]]
    sigma: float = 15.0
    output_stride: int = 1
    loss_weight: float = 1.0

    @property
    def channels(self) -> int:
        """Number of output channels: two for every edge.

        The pair per edge is presumably the x/y vector components -- confirm.
        """
        return 2 * len(self.edges)

    @classmethod
    def from_config(
        cls,
        config: PartAffinityFieldsHeadConfig,
        edges: Optional[Sequence[Tuple[Text, Text]]] = None,
    ) -> "PartAffinityFieldsHead":
        """Build a head from a `PartAffinityFieldsHeadConfig`.

        Args:
            config: Configuration supplying `sigma`, `output_stride`,
                `loss_weight` and optionally `edges`.
            edges: Fallback list of `(source_node, destination_node)` name
                pairs, used only when `config.edges` is `None`. This must be
                supplied in that case.

        Returns:
            A new head populated from the configuration.
        """
        if config.edges is None:
            resolved_edges = edges
        else:
            resolved_edges = config.edges
        return cls(
            edges=resolved_edges,
            sigma=config.sigma,
            output_stride=config.output_stride,
            loss_weight=config.loss_weight,
        )