Source code for sleap.nn.config.utils

"""Utilities for config building and validation."""

def oneof(attrs_cls, must_be_set: bool = False):
    """Ensure that the decorated attrs class only has a single attribute set.

    This decorator is inspired by the `oneof` protobuffer field behavior. An
    attribute counts as "set" when its value is not `None`.

    Args:
        attrs_cls: An attrs decorated class (must expose `__attrs_attrs__`).
        must_be_set: If True, raise an error if none of the attributes are set.
            If not, error will only be raised if more than one attribute is set.

    Returns:
        The same `attrs_cls` object, modified in place, with:

        * an `__init__` method that checks the number of attributes that are
          set after the attrs-generated `__init__` has run;
        * `which_oneof_attrib_name()`, returning the name of the set attribute
          (or `None` if nothing is set and `must_be_set` is False);
        * `which_oneof()`, returning the value of the set attribute (or `None`).

    Raises:
        ValueError: If `attrs_cls` has no `__attrs_attrs__` attribute. The
            generated `__init__` and `which_oneof*` methods raise `ValueError`
            when more than one attribute is set, or when none is set and
            `must_be_set` is True.
    """
    # Local import keeps this block self-contained with respect to file imports.
    import functools

    # Check if the class is an attrs class at all.
    if not hasattr(attrs_cls, "__attrs_attrs__"):
        raise ValueError("Classes decorated with oneof must also be attr.s decorated.")

    # Pull out attrs generated class attributes and the generated __init__.
    attribs = attrs_cls.__attrs_attrs__
    init_fn = attrs_cls.__init__

    def _attribs_with_value(self):
        """Return the attribs that are set (not None), enforcing the 'at most one' rule."""
        with_value = [
            attrib for attrib in attribs if getattr(self, attrib.name) is not None
        ]
        if len(with_value) > 1:
            # Raise error if more than one attribute is set.
            raise ValueError("Only one attribute of this class can be set (not None).")
        return with_value

    # Define a new __init__ function that wraps the attrs generated one.
    # functools.wraps keeps the original name/docstring/signature discoverable.
    @functools.wraps(init_fn)
    def new_init_fn(self, *args, **kwargs):
        # Execute the standard attrs-generated __init__ first so that defaults,
        # converters and validators have been applied before we inspect values.
        init_fn(self, *args, **kwargs)

        if not _attribs_with_value(self) and must_be_set:
            # Raise error if none are set.
            raise ValueError("At least one attribute of this class must be set.")

    # Define convenience method for getting the name of the set attribute.
    def which_oneof_attrib_name(self):
        """Return the name of the attribute that is set, or None if none is set."""
        # Re-validated on every call since attributes may be mutated after init.
        with_value = _attribs_with_value(self)
        if not with_value:
            if must_be_set:
                # Raise error if none are set.
                raise ValueError("At least one attribute of this class must be set.")
            return None
        return with_value[0].name

    # Define convenience method for getting the value of the set attribute.
    def which_oneof(self):
        """Return the value of the attribute that is set, or None if none is set."""
        attrib_name = self.which_oneof_attrib_name()
        if attrib_name is None:
            return None
        return getattr(self, attrib_name)

    # Replace with wrapped __init__ and attach the convenience methods.
    attrs_cls.__init__ = new_init_fn
    attrs_cls.which_oneof_attrib_name = which_oneof_attrib_name
    attrs_cls.which_oneof = which_oneof

    return attrs_cls