"""
Conversion between flat (form data) and hierarchical (config object) dicts.
"""
from typing import Any, Dict, Optional, Text, Tuple
import attr
import cattr
from sleap.nn.config import TrainingJobConfig, ModelConfig
[docs]@attr.s(auto_attribs=True)
class ScopedKeyDict:
"""
Class to support conversion between flat and hierarchical dictionaries.
Flat dictionaries have scoped keys, e.g., "foo.bar". These typically come
from user-editable forms.
Hierarchical dictionaries have dictionaries as values of other dictionaries,
e.g., `{"foo": {"bar": ... } }`. These are typically used when serializing
and deserializing objects.
Attributes:
key_val_dict: Data stores in a *flat* dictionary with scoped keys.
"""
key_val_dict: Dict[Text, Any]
[docs] @classmethod
def set_hierarchical_key_val(cls, current_dict: dict, key: Text, val: Any):
"""Sets value in hierarchical dict via scoped key."""
# Ignore "private" keys starting with "_"
if key[0] == "_":
return
if "." not in key:
current_dict[key] = val
else:
top_key, *subkey_list = key.split(".")
if top_key not in current_dict:
current_dict[top_key] = dict()
subkey = ".".join(subkey_list)
cls.set_hierarchical_key_val(current_dict[top_key], subkey, val)
[docs] def to_hierarchical_dict(self) -> dict:
"""Converts internal flat dictionary to hierarchical dictionary."""
hierarch_dict = dict()
for key, val in self.key_val_dict.items():
self.set_hierarchical_key_val(hierarch_dict, key, val)
return hierarch_dict
[docs] @classmethod
def from_hierarchical_dict(cls, hierarch_dict: dict):
"""Constructs object (with flat dict) from hierarchical dictionary."""
return cls(key_val_dict=cls._make_flattened_dict(hierarch_dict))
@classmethod
def _make_flattened_dict(
cls, hierarch_dicts: dict, scope_string: Text = ""
) -> dict:
flattened_dict = dict()
for key, val in hierarch_dicts.items():
if isinstance(val, Dict):
# Dict so recurse adding node to scope string
flattened_dict.update(
cls._make_flattened_dict(
hierarch_dicts=val,
scope_string=cls._subscope_key(scope_string, key),
)
)
else:
# Leafs (non-dict)
flattened_dict[cls._subscope_key(scope_string, key)] = val
return flattened_dict
@staticmethod
def _subscope_key(scope_string: Text, key: Text) -> Text:
return key if not scope_string else f"{scope_string}.{key}"
[docs]def find_backbone_name_from_key_val_dict(key_val_dict: dict):
"""Find the backbone model name from the config dictionary."""
backbone_name = None
for key in key_val_dict:
if key.startswith("model.backbone."):
backbone_name = key.split(".")[2]
return backbone_name
[docs]def resolve_strides_from_key_val_dict(
key_val_dict: dict, backbone_name: str
) -> Tuple[int, int]:
"""Find the valid max and output strides from the config dictionary."""
max_stride = key_val_dict.get(f"model.backbone.{backbone_name}.max_stride", None)
output_stride = key_val_dict.get(
f"model.backbone.{backbone_name}.output_stride", None
)
for key in [
"model.heads.centered_instance.output_stride",
"model.heads.centroid.output_stride",
"model.heads.multi_instance.confmaps.output_stride",
"model.heads.multi_instance.pafs.output_stride",
]:
stride = key_val_dict.get(key, None)
if stride is not None:
stride = int(stride)
max_stride = (
max(int(max_stride), stride) if max_stride is not None else stride
)
output_stride = (
min(int(output_stride), stride) if output_stride is not None else stride
)
if output_stride is None:
output_stride = max_stride
return max_stride, output_stride
[docs]def make_training_config_from_key_val_dict(key_val_dict: dict) -> TrainingJobConfig:
"""
Make :py:class:`TrainingJobConfig` object from flat dictionary.
Arguments:
key_val_dict: Flat dictionary from :py:class:`TrainingEditorWidget`.
Returns:
The :py:class:`TrainingJobConfig` object.
"""
apply_cfg_transforms_to_key_val_dict(key_val_dict)
cfg_dict = ScopedKeyDict(key_val_dict).to_hierarchical_dict()
cfg = cattr.structure(cfg_dict, TrainingJobConfig)
return cfg
[docs]def make_model_config_from_key_val_dict(key_val_dict: dict) -> ModelConfig:
"""
Make :py:class:`ModelConfig` object from flat dictionary.
Arguments:
key_val_dict: Flat dictionary from :py:class:`TrainingEditorWidget`.
Returns:
The :py:class:`ModelConfig` object.
"""
apply_cfg_transforms_to_key_val_dict(key_val_dict)
cfg_dict = ScopedKeyDict(key_val_dict).to_hierarchical_dict()
if "model" in cfg_dict:
cfg_dict = cfg_dict["model"]
return cattr.structure(cfg_dict, ModelConfig)