"""
Logic for determining what color/width to draw instance nodes/edges.
The color can be determined by the current color palette as well as settings
on the `ColorManager` object:
* distinctly_color: "instances", "nodes", or "edges"
* color_predicted: whether to use colors for predicted instances, or just plot
them in yellow/grey
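The available palettes are loaded from the `colors.yaml` config file. The
initial palette is passed in when the object is created, and pen widths
depend on the "bold lines" user preference; these settings can be changed
after the object is created.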
"""
from typing import Any, Iterable, Optional, Union, Text, Tuple
import yaml
from sleap.util import get_config_file
from sleap.instance import Instance, Track, Node
from sleap.io.dataset import Labels
from sleap.prefs import prefs
# A color written as an "r,g,b" string, e.g. "255,0,0".
ColorTupleStringType = str
# A color as an (r, g, b) tuple of ints.
ColorTupleType = Tuple[int, int, int]
class ColorManager:
    """Class to determine color to use for track.

    The color depends on the order of the tracks in `Labels` object,
    so we need to initialize with `Labels`.

    Besides track colors, this also determines colors and pen widths for
    nodes and edges, depending on `distinctly_color` ("instances", "nodes",
    or "edges") and `color_predicted`.

    Args:
        labels: The :class:`Labels` dataset which contains the tracks for
            which we want colors.
        palette: Name of the color palette to use (looked up among the
            palettes loaded from `colors.yaml`), or a list of "r,g,b"
            strings / (r, g, b) tuples.
    """

    def __init__(self, labels: Labels = None, palette: str = "standard"):
        self.labels = labels

        # Named palettes loaded from the config file; looked up by name in
        # the `palette` setter.
        with open(get_config_file("colors.yaml"), "r") as f:
            self._palettes = yaml.load(f, Loader=yaml.SafeLoader)

        self._color_map = []
        self.distinctly_color = "instances"
        self.color_predicted = True

        # How item indices beyond the palette length are mapped onto it:
        # "cycle" wraps around, "clip" repeats the last color.
        self.index_mode = "cycle"
        self._index_mode_functions = dict(
            cycle=lambda i, c: i % c, clip=lambda i, c: min(i, c - 1)
        )

        self.set_palette(palette)

        # Used for predicted nodes when `color_predicted` is False.
        self.uncolored_prediction_color = (250, 250, 10)

        if prefs["bold lines"]:
            self.thick_pen_width = 6
        else:
            self.thick_pen_width = 3
        self.medium_pen_width = self.thick_pen_width // 2
        self.default_pen_width = max(1, self.thick_pen_width // 4)

    @property
    def labels(self):
        """Gets or sets labels dataset for which we are coloring tracks."""
        return self._labels

    @labels.setter
    def labels(self, val):
        self._labels = val

    @property
    def palette(self):
        """Gets or sets palette (by name, or as a list of colors)."""
        return self._palette

    @palette.setter
    def palette(self, palette: Union[Text, Iterable[ColorTupleStringType]]):
        self._palette = palette
        if isinstance(palette, Text):
            # A trailing "+" in the palette name selects "clip" indexing.
            self.index_mode = "clip" if palette.endswith("+") else "cycle"
            if palette in self._palettes:
                self._color_map = self._palettes[palette]
            else:
                # Can't find palette by name so just use standard palette.
                self._color_map = self._palettes["standard"]
        else:
            # If palette is not given by name, it should be list of
            # "r,g,b" strings (or (r, g, b) tuples).
            self._color_map = palette

    @property
    def palette_names(self) -> Iterable[Text]:
        """Gets list of palette names."""
        return self._palettes.keys()

    @property
    def tracks(self) -> Iterable[Track]:
        """Gets tracks for project (empty list when `labels` is not set)."""
        if self.labels:
            return self.labels.tracks
        return []

    def set_palette(self, palette: Union[Text, Iterable[ColorTupleStringType]]):
        """Functional alias for palette property setter."""
        self.palette = palette

    def fix_index(self, idx: int) -> int:
        """Returns an index within range of color palette.

        Out-of-range indices are mapped according to `index_mode`.

        Raises:
            ValueError: If the current color palette is empty.
        """
        n_colors = len(self._color_map)
        if n_colors == 0:
            # Otherwise "cycle" raises ZeroDivisionError and "clip" yields an
            # index into an empty palette.
            raise ValueError("Color palette is empty.")
        return self._index_mode_functions[self.index_mode](idx, n_colors)

    def get_color_by_idx(self, idx: int) -> ColorTupleType:
        """Returns color tuple corresponding to item index."""
        color_idx = self.fix_index(idx)
        return self.color_to_tuple(self._color_map[color_idx])

    @staticmethod
    def color_to_tuple(color: Union[Text, Iterable[int]]) -> ColorTupleType:
        """Convert and ensure color is (r, g, b)-tuple.

        Args:
            color: An "r,g,b" string, or an iterable of three values that can
                be converted with `int`.

        Returns:
            (r, g, b)-tuple of ints.

        Raises:
            ValueError: If color cannot be interpreted as three ints.
        """
        if isinstance(color, Text):
            message = f"Color '{color}' is not 'r,g,b' string."
            components = color.split(",")
        else:
            message = f"Color '{color}' is not (r,g,b) tuple."
            try:
                # Accept any iterable, not only sized sequences.
                components = tuple(color)
            except TypeError as err:
                raise ValueError(message) from err

        if len(components) != 3:
            raise ValueError(message)
        try:
            return tuple(int(component) for component in components)
        except (TypeError, ValueError) as err:
            raise ValueError(message) from err

    def get_pseudo_track_index(self, instance: "Instance") -> Union[Track, int]:
        """
        Returns an index for giving track colors to instances without track.

        Instances that have a track just return that track. Untracked
        instances get an index past the last real track, offset by their
        position among the untracked instances shown in their frame, so that
        untracked instances in the same frame get different indices.
        """
        if instance.track:
            return instance.track

        if not instance.frame:
            return 0

        non_track_instances = [
            inst for inst in instance.frame.instances_to_show if inst.track is None
        ]

        return len(self.tracks) + non_track_instances.index(instance)

    def get_track_color(self, track: Union[Track, int]) -> ColorTupleType:
        """Returns the color to use for a given track.

        Args:
            track: `Track` object or an int (e.g., an index from
                `get_pseudo_track_index`).

        Returns:
            (r, g, b)-tuple. Black when track is None, or is a `Track` that
            is not among the current tracks.
        """
        if isinstance(track, Track):
            try:
                track_idx = self.tracks.index(track)
            except ValueError:
                # Track is not in the labels (e.g., no labels set yet), so
                # it has no position in the palette.
                return (0, 0, 0)
        else:
            track_idx = track

        if track_idx is None:
            return (0, 0, 0)

        return self.get_color_by_idx(track_idx)

    @classmethod
    def is_sequence(cls, item) -> bool:
        """Returns whether item is a tuple or list."""
        return isinstance(item, (tuple, list))

    @classmethod
    def is_edge(cls, item) -> bool:
        """Returns whether item is an edge, i.e., pair of nodes."""
        return cls.is_sequence(item) and len(item) == 2 and cls.is_node(item[0])

    @staticmethod
    def is_node(item) -> bool:
        """Returns whether item is a node, i.e., Node or node name."""
        return isinstance(item, (Node, str))

    @staticmethod
    def is_predicted(instance) -> bool:
        """Returns whether instance is predicted (has a `score` attribute)."""
        return hasattr(instance, "score")

    def get_item_pen_width(
        self, item: Any, parent_instance: Optional[Instance] = None
    ) -> float:
        """Gets width of pen to use for drawing item.

        Nodes are thick when distinctly coloring nodes; otherwise, for
        predicted instances the first skeleton node is thick and the others
        medium, and for other instances nodes are medium. Edges are thick
        when distinctly coloring edges. Everything else uses the default
        width.
        """
        if self.is_node(item):
            if self.distinctly_color == "nodes":
                return self.thick_pen_width

            if self.is_predicted(parent_instance):
                is_first_node = item == parent_instance.skeleton.nodes[0]
                return self.thick_pen_width if is_first_node else self.medium_pen_width

            return self.medium_pen_width

        if self.is_edge(item):
            if self.distinctly_color == "edges":
                return self.thick_pen_width

        return self.default_pen_width

    def get_item_type_pen_width(self, item_type: str) -> float:
        """Gets pen width to use for given item type ("node" or "edge")."""
        if item_type == "node":
            if self.distinctly_color == "nodes":
                return self.thick_pen_width
            return self.medium_pen_width

        if item_type == "edge":
            if self.distinctly_color == "edges":
                return self.thick_pen_width

        return self.default_pen_width

    def get_item_color(
        self,
        item: Any,
        parent_instance: Optional[Instance] = None,
        parent_skeleton: Optional["Skeleton"] = None,
    ) -> ColorTupleType:
        """Gets (r, g, b) tuple of color to use for drawing item.

        Args:
            item: An instance, a node (or node name), or an edge (pair of
                nodes).
            parent_instance: Instance the item belongs to; inferred when item
                is itself an `Instance`.
            parent_skeleton: Skeleton used to index nodes/edges; taken from
                parent_instance when not given.

        Returns:
            (r, g, b)-tuple; black when no coloring rule applies.
        """
        if not parent_instance and isinstance(item, Instance):
            parent_instance = item

        if not parent_skeleton and hasattr(parent_instance, "skeleton"):
            parent_skeleton = parent_instance.skeleton

        is_predicted = False
        if parent_instance and self.is_predicted(parent_instance):
            is_predicted = True

        if is_predicted and not self.color_predicted:
            if isinstance(item, Node):
                return self.uncolored_prediction_color
            return (128, 128, 128)

        if self.distinctly_color == "instances" or hasattr(item, "track"):
            track = None
            if hasattr(item, "track"):
                track = item.track
            elif parent_instance:
                track = parent_instance.track

            if track is None and parent_instance:
                # Get an index for items without track
                track = self.get_pseudo_track_index(parent_instance)

            return self.get_track_color(track=track)

        if self.distinctly_color == "nodes" and parent_skeleton:
            node = None
            if isinstance(item, Node):
                node = item
            elif self.is_edge(item):
                # use dst node for coloring edge
                node = item[1]

            if node:
                node_idx = parent_skeleton.node_to_index(node)
                return self.get_color_by_idx(node_idx)

        if self.distinctly_color == "edges" and parent_skeleton:
            edge_idx = 0
            if self.is_edge(item):
                edge_idx = parent_skeleton.edge_to_index(*item)
            elif self.is_node(item):
                # Color a node like the first edge that ends at it.
                for i, (_src, dst) in enumerate(parent_skeleton.edges):
                    if dst == item:
                        edge_idx = i
                        break

            return self.get_color_by_idx(edge_idx)

        return (0, 0, 0)