Source code for sleap.skeleton

"""
Implementation of skeleton data structure and API.

This module implements an API for creating animal skeletons. The goal
is to provide a common interface for defining the parts of the animal,
their connection to each other, and needed meta-data.
"""

import attr
import cattr
import numpy as np
import jsonpickle
import json
import h5py
import copy

from enum import Enum
from itertools import count
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, Text

import networkx as nx
from networkx.readwrite import json_graph
from scipy.io import loadmat

NodeRef = Union[str, "Node"]
H5FileRef = Union[str, h5py.File]


class EdgeType(Enum):
    """
    Kinds of edges that can be stored in the skeleton graph.

    Every edge in the graph is tagged with one of these types:

    * BODY - a connection between two parts or landmarks.
    * SYMMETRY - a symmetrical relationship between two parts
      (e.g. left and right arms).
    """

    BODY = 1
    SYMMETRY = 2
@attr.s(auto_attribs=True, slots=True, eq=False, order=False)
class Node:
    """
    The class :class:`Node` represents a potential skeleton node.
    (But note that nodes can exist without being part of a skeleton.)

    The class is declared with ``eq=False``, so attrs does not generate
    ``__eq__``/``__hash__``; nodes therefore compare and hash by identity.
    Use :meth:`matches` to compare two nodes by their attribute values.

    Attributes:
        name: The name of the node (body part).
        weight: A weight associated with the node; defaults to 1.0.
    """

    name: str
    weight: float = 1.0

    @staticmethod
    def from_names(name_list: Iterable[str]) -> List["Node"]:
        """Convert list of node names to list of nodes objects.

        Args:
            name_list: The node names, one new `Node` is created per name.

        Returns:
            A list of newly created `Node` objects, in the order given.
        """
        return [Node(name) for name in name_list]

    @classmethod
    def as_node(cls, node: NodeRef) -> "Node":
        """Convert given `node` to `Node` object (if not already).

        Args:
            node: Either an existing `Node` or a node name.

        Returns:
            `node` itself when it is already an instance of this class,
            otherwise a new instance constructed from `node`.
        """
        if isinstance(node, cls):
            return node
        return cls(node)

    def matches(self, other: "Node") -> bool:
        """
        Check whether all attributes match between two nodes.

        Args:
            other: The `Node` to compare to this one.

        Returns:
            True if all attributes match, False otherwise.
        """
        return other.name == self.name and other.weight == self.weight
class Skeleton:
    """
    The main object for representing animal skeletons.

    The skeleton represents the constituent parts of the animal whose
    pose is being estimated. It wraps a `networkx.MultiDiGraph` whose nodes
    are :class:`Node` objects and whose edges carry a ``type`` attribute
    holding an :class:`EdgeType` (BODY or SYMMETRY).
    """

    # An index variable used to give skeletons a default name that should
    # be unique across all skeletons.
    _skeleton_idx = count(0)

    def __init__(self, name: Optional[str] = None):
        """
        Initialize an empty skeleton object.

        Skeleton objects, once created, can be modified by adding nodes and
        edges.

        Args:
            name: A name for this skeleton. If it is missing, empty or not a
                string, a name of the form ``Skeleton-<n>`` is generated.
        """
        # If no usable name was given, create a unique name for this Skeleton.
        if not isinstance(name, str) or not name:
            name = "Skeleton-" + str(next(self._skeleton_idx))

        # Since networkx does not keep edges in the order we insert them we
        # need to keep track of how many edges have been inserted so we can
        # number them as they are inserted and sort them by this numbering
        # when the edge list is returned.
        self._graph: nx.MultiDiGraph = nx.MultiDiGraph(name=name, num_edges_inserted=0)

    def matches(self, other: "Skeleton") -> bool:
        """
        Compare this `Skeleton` to another, ignoring skeleton name and
        the identities of the `Node` objects in each graph.

        Args:
            other: The other skeleton.

        Returns:
            True if match, False otherwise.
        """

        def dict_match(dict1, dict2):
            return dict1 == dict2

        # Check if the graphs are isomorphic.
        if not nx.is_isomorphic(self._graph, other._graph, node_match=dict_match):
            return False

        # Now check that the nodes have the same names in the same order.
        # Node weights are not compared.
        return all(
            node1.name == node2.name
            for node1, node2 in zip(self._graph.nodes, other._graph.nodes)
        )

    @property
    def is_arborescence(self) -> bool:
        """Whether the full underlying graph is an arborescence."""
        return nx.algorithms.tree.recognition.is_arborescence(self._graph)

    @property
    def in_degree_over_one(self) -> List[Node]:
        """Nodes of the underlying graph with more than one incoming edge."""
        return [node for node, in_degree in self._graph.in_degree if in_degree > 1]

    @property
    def root_nodes(self) -> List[Node]:
        """Nodes of the underlying graph with no incoming edges."""
        return [node for node, in_degree in self._graph.in_degree if in_degree == 0]

    @property
    def cycles(self) -> List[List[Node]]:
        """Simple cycles found in the underlying graph."""
        return list(nx.algorithms.simple_cycles(self._graph))

    def _typed_edge_keys(self, edge_type: EdgeType) -> List[Tuple[Node, Node, Any]]:
        """Return (src, dst, key) for every edge whose type is `edge_type`."""
        return [
            (src, dst, key)
            for src, dst, key, kind in self._graph.edges(keys=True, data="type")
            if kind == edge_type
        ]

    @property
    def graph(self):
        """Returns subgraph of BODY edges for skeleton."""
        # TODO: properly induce subgraph for MultiDiGraph
        #   Currently, NetworkX will just return the nodes in the subgraph.
        #   See: https://stackoverflow.com/questions/16150557/networkxcreating-a-subgraph-induced-from-edges
        return self._graph.edge_subgraph(self._typed_edge_keys(EdgeType.BODY))

    @property
    def graph_symmetry(self):
        """Returns subgraph of symmetric edges for skeleton."""
        return self._graph.edge_subgraph(self._typed_edge_keys(EdgeType.SYMMETRY))

    @staticmethod
    def find_unique_nodes(skeletons: List["Skeleton"]) -> List[Node]:
        """
        Find all unique nodes from a list of skeletons.

        Args:
            skeletons: The list of skeletons.

        Returns:
            A list of unique `Node` objects.
        """
        return list({node for skeleton in skeletons for node in skeleton.nodes})

    @staticmethod
    def make_cattr(idx_to_node: Optional[Dict[int, Node]] = None) -> cattr.Converter:
        """
        Make cattr.Convert() for `Skeleton`.

        Make a cattr.Converter() that registers structure/unstructure
        hooks for Skeleton objects to handle serialization of skeletons.

        Args:
            idx_to_node: A dict that maps node index to Node objects.

        Returns:
            A cattr.Converter() instance for skeleton serialization
            and deserialization.
        """
        # Invert the mapping for the unstructure (save) direction.
        node_to_idx = (
            {node: idx for idx, node in idx_to_node.items()}
            if idx_to_node is not None
            else None
        )

        _cattr = cattr.Converter()
        _cattr.register_unstructure_hook(
            Skeleton, lambda x: Skeleton.to_dict(x, node_to_idx)
        )
        _cattr.register_structure_hook(
            Skeleton, lambda x, cls: Skeleton.from_dict(x, idx_to_node)
        )
        return _cattr

    @property
    def name(self) -> str:
        """Get the name of the skeleton.

        Returns:
            A string representing the name of the skeleton.
        """
        return self._graph.name

    @name.setter
    def name(self, name: str):
        """
        A skeleton object cannot change its name.

        This property is immutable. If you want to rename a Skeleton you
        must use the class method :code:`rename_skeleton`:

        >>> new_skeleton = Skeleton.rename_skeleton(
        >>>     skeleton=old_skeleton, name="New Name")

        Args:
            name: The name of the Skeleton.

        Raises:
            NotImplementedError: Error is always raised.
        """
        raise NotImplementedError(
            "Cannot change Skeleton name, it is immutable since "
            "it is used for hashing. Create a copy of the skeleton "
            "with new name using "
            f"new_skeleton = Skeleton.rename_skeleton(skeleton, '{name}')"
        )

    @classmethod
    def rename_skeleton(cls, skeleton: "Skeleton", name: str) -> "Skeleton":
        """
        Make copy of skeleton with new name.

        The `name` property cannot be assigned to. If you want to rename a
        Skeleton you must use this class method.

        >>> new_skeleton = Skeleton.rename_skeleton(
        >>>     skeleton=old_skeleton, name="New Name")

        Args:
            skeleton: The skeleton to copy.
            name: The name of the new skeleton.

        Returns:
            A new deep copied skeleton with the changed name.
        """
        new_skeleton = cls(name)
        new_skeleton._graph = copy.deepcopy(skeleton._graph)
        new_skeleton._graph.name = name
        return new_skeleton

    @property
    def nodes(self) -> List[Node]:
        """Get a list of :class:`Node`s.

        Returns:
            A list of :class:`Node`s
        """
        return list(self._graph.nodes)

    @property
    def node_names(self) -> List[str]:
        """Get a list of node names.

        Returns:
            A list of node names.
        """
        return [node.name for node in self.nodes]

    @property
    def edges(self) -> List[Tuple[Node, Node]]:
        """Get a list of edge tuples.

        Only BODY edges are included, ordered by insertion.

        Returns:
            list of (src_node, dst_node)
        """
        indexed_edges = [
            (data["edge_insert_idx"], src, dst)
            for src, dst, _key, data in self._graph.edges(keys=True, data=True)
            if data["type"] == EdgeType.BODY
        ]

        # We don't want to return the edge list in the order it is stored. We
        # want to use the insertion order. Sort by the insertion index only
        # (so that a tie never falls through to comparing Node objects, which
        # define no ordering), then drop the index.
        indexed_edges.sort(key=lambda item: item[0])
        return [(src, dst) for _, src, dst in indexed_edges]

    @property
    def edge_names(self) -> List[Tuple[str, str]]:
        """Get a list of edge name tuples.

        Returns:
            list of (src_node.name, dst_node.name), in edge insertion order.
        """
        return [(src.name, dst.name) for src, dst in self.edges]

    @property
    def edge_inds(self) -> List[Tuple[int, int]]:
        """Get a list of edges as node indices.

        Returns:
            A list of (src_node_ind, dst_node_ind), where indices are
            subscripts into the Skeleton.nodes list.
        """
        # Build the node -> index lookup once instead of rebuilding the node
        # list and scanning it for every edge endpoint.
        node_index = {node: idx for idx, node in enumerate(self.nodes)}
        return [(node_index[src], node_index[dst]) for src, dst in self.edges]

    @property
    def edges_full(self) -> List[Tuple[Node, Node, Any, Any]]:
        """Get a list of edge tuples with keys and attributes.

        Returns:
            list of (src_node, dst_node, key, attributes)
        """
        return [
            (src, dst, key, data)
            for src, dst, key, data in self._graph.edges(keys=True, data=True)
            if data["type"] == EdgeType.BODY
        ]

    @property
    def symmetries(self) -> List[Tuple[Node, Node]]:
        """Get a list of all symmetries without duplicates.

        Returns:
            list of (node1, node2)
        """
        # Each symmetry is stored as two directed edges; keep the first
        # orientation encountered for every unordered pair.
        seen_pairs = set()
        unique_pairs = []
        for src, dst, _key in self._typed_edge_keys(EdgeType.SYMMETRY):
            pair = frozenset((src, dst))
            if pair not in seen_pairs:
                seen_pairs.add(pair)
                unique_pairs.append((src, dst))
        return unique_pairs

    @property
    def symmetries_full(self) -> List[Tuple[Node, Node, Any, Any]]:
        """Get a list of all symmetries with keys and attributes.

        Note: The returned list will contain duplicates (node1, node2)
        and (node2, node1).

        Returns:
            list of (node1, node2, key, attr)
        """
        return [
            (src, dst, key, data)
            for src, dst, key, data in self._graph.edges(keys=True, data=True)
            if data["type"] == EdgeType.SYMMETRY
        ]

    def node_to_index(self, node: NodeRef) -> int:
        """
        Return the index of the node, accepts either `Node` or name.

        Args:
            node: The name of the node or the Node object.

        Raises:
            ValueError if node cannot be found in skeleton.

        Returns:
            The index of the node in the graph.
        """
        node_list = list(self._graph.nodes)

        # First try the reference as-is (a Node that is in the graph).
        try:
            return node_list.index(node)
        except ValueError:
            pass

        # Otherwise fall back to a lookup by name.
        found = self.find_node(node)
        if found is None:
            node_name = node.name if isinstance(node, Node) else node
            raise ValueError(f"Skeleton does not have node named '{node_name}'.")
        return node_list.index(found)

    def edge_to_index(self, source: NodeRef, destination: NodeRef) -> int:
        """Returns the index of edge from source to destination.

        Args:
            source: The source node (by `Node` object or name).
            destination: The destination node (by `Node` object or name).

        Returns:
            The position of the edge in `Skeleton.edges`, or -1 if there is
            no such BODY edge.
        """
        edge = (self.find_node(source), self.find_node(destination))
        try:
            return self.edges.index(edge)
        except ValueError:
            return -1

    def add_node(self, name: str):
        """Add a node representing an animal part to the skeleton.

        Args:
            name: The name of the node to add to the skeleton.
                This name must be unique within the skeleton.

        Raises:
            TypeError: If name is not a string.
            ValueError: If name is not unique.

        Returns:
            None
        """
        if not isinstance(name, str):
            raise TypeError("Cannot add nodes to the skeleton that are not str")

        if name in self.node_names:
            raise ValueError("Skeleton already has a node named ({})".format(name))

        self._graph.add_node(Node(name))

    def add_nodes(self, name_list: List[str]):
        """
        Add a list of nodes representing animal parts to the skeleton.

        Args:
            name_list: List of strings representing the nodes.

        Returns:
            None
        """
        for node in name_list:
            self.add_node(node)

    def delete_node(self, name: str):
        """Remove a node from the skeleton.

        The method removes a node from the skeleton and any edge that is
        connected to it.

        Args:
            name: The name of the node to remove

        Raises:
            ValueError: If node cannot be found.

        Returns:
            None
        """
        missing_message = (
            "The node named ({}) does not exist, cannot remove it.".format(name)
        )

        node = self.find_node(name)
        if node is None:
            raise ValueError(missing_message)

        try:
            self._graph.remove_node(node)
        except nx.NetworkXError as err:
            raise ValueError(missing_message) from err

    def find_node(self, name: NodeRef) -> Optional[Node]:
        """Find node in skeleton by name of node.

        Args:
            name: The name of the :class:`Node` (or a :class:`Node`)

        Raises:
            ValueError: If more than one node has the given name.

        Returns:
            `Node`, or None if no match found
        """
        if isinstance(name, Node):
            name = name.name

        nodes = [node for node in self.nodes if node.name == name]

        if len(nodes) == 1:
            return nodes[0]

        if len(nodes) > 1:
            raise ValueError("Found multiple nodes named ({}).".format(name))

        return None

    def _resolve_endpoint(self, ref: NodeRef) -> Tuple[Optional[Node], Any]:
        """Return (node, name) for an edge endpoint given as `Node` or name.

        A `Node` instance is used as given; a name is looked up with
        `find_node`, which yields None when the skeleton has no such node.
        """
        if isinstance(ref, Node):
            return ref, ref.name
        return self.find_node(ref), ref

    def add_edge(self, source: str, destination: str):
        """Add an edge between two nodes.

        Args:
            source: The name of the source node.
            destination: The name of the destination node.

        Raises:
            ValueError: If source or destination nodes cannot be found,
                or if edge already exists between those nodes.

        Returns:
            None.
        """
        source_node, source = self._resolve_endpoint(source)
        destination_node, destination = self._resolve_endpoint(destination)

        if source_node is None:
            raise ValueError(
                "Skeleton does not have source node named ({})".format(source)
            )

        if destination_node is None:
            raise ValueError(
                "Skeleton does not have destination node named ({})".format(destination)
            )

        if self._graph.has_edge(source_node, destination_node):
            raise ValueError(
                "Skeleton already has an edge between ({}) and ({}).".format(
                    source, destination
                )
            )

        # Tag the edge with its insertion number so `edges` can return the
        # edges in insertion order.
        insert_idx = self._graph.graph["num_edges_inserted"]
        self._graph.add_edge(
            source_node,
            destination_node,
            type=EdgeType.BODY,
            edge_insert_idx=insert_idx,
        )
        self._graph.graph["num_edges_inserted"] = insert_idx + 1

    def delete_edge(self, source: str, destination: str):
        """Delete an edge between two nodes.

        Args:
            source: The name of the source node.
            destination: The name of the destination node.

        Raises:
            ValueError: If skeleton does not have either source node,
                destination node, or edge between them.

        Returns:
            None
        """
        source_node, source = self._resolve_endpoint(source)
        destination_node, destination = self._resolve_endpoint(destination)

        if source_node is None:
            raise ValueError(
                "Skeleton does not have source node named ({})".format(source)
            )

        if destination_node is None:
            raise ValueError(
                "Skeleton does not have destination node named ({})".format(destination)
            )

        if not self._graph.has_edge(source_node, destination_node):
            raise ValueError(
                "Skeleton has no edge between ({}) and ({}).".format(
                    source, destination
                )
            )

        self._graph.remove_edge(source_node, destination_node)

    def clear_edges(self):
        """Deletes all edges in skeleton."""
        # `self.edges` returns a fresh list, so deleting while looping is safe.
        for src, dst in self.edges:
            self.delete_edge(src, dst)

    def add_symmetry(self, node1: str, node2: str):
        """Specify that two parts (nodes) in skeleton are symmetrical.

        Certain parts of an animal body can be related as symmetrical
        parts in a pair. For example, left and right hands of a person.

        Args:
            node1: The name of the first part in the symmetric pair
            node2: The name of the second part in the symmetric pair

        Raises:
            ValueError: If either node cannot be found, if node1 and node2
                match, or if there is already a symmetry between them.

        Returns:
            None

        """
        node1_node, node2_node = self.find_node(node1), self.find_node(node2)

        if node1_node is None:
            raise ValueError(f"Skeleton does not have node named ({node1})")
        if node2_node is None:
            raise ValueError(f"Skeleton does not have node named ({node2})")

        # Compare the resolved nodes so a name and a Node object referring to
        # the same part are recognized as identical.
        if node1_node is node2_node:
            raise ValueError("Cannot add symmetry to the same node.")

        if self.get_symmetry(node1) is not None:
            raise ValueError(
                f"{node1} is already symmetric with {self.get_symmetry(node1)}."
            )

        if self.get_symmetry(node2) is not None:
            raise ValueError(
                f"{node2} is already symmetric with {self.get_symmetry(node2)}."
            )

        # We represent symmetric pairs in the skeleton via additional edges in
        # the _graph (one per direction). These edges have a special type
        # signifying they are not part of the skeleton itself.
        self._graph.add_edge(node1_node, node2_node, type=EdgeType.SYMMETRY)
        self._graph.add_edge(node2_node, node1_node, type=EdgeType.SYMMETRY)

    def delete_symmetry(self, node1: NodeRef, node2: NodeRef):
        """
        Deletes a previously established symmetry between two nodes.

        Args:
            node1: One node (by `Node` object or name) in symmetric pair.
            node2: Other node in symmetric pair.

        Raises:
            ValueError: If there's no symmetry between node1 and node2.

        Returns:
            None
        """
        node1_node = self.find_node(node1)
        node2_node = self.find_node(node2)

        # Unknown nodes cannot be symmetric; checking explicitly prevents a
        # None lookup result from being mistaken for a matching "no symmetry".
        if (
            node1_node is None
            or node2_node is None
            or self.get_symmetry(node1) != node2_node
            or self.get_symmetry(node2) != node1_node
        ):
            raise ValueError(f"Nodes {node1}, {node2} are not symmetric.")

        edges = [
            (src, dst, key)
            for src, dst, key, edge_type in self._graph.edges(
                [node1_node, node2_node], keys=True, data="type"
            )
            if edge_type == EdgeType.SYMMETRY
        ]
        self._graph.remove_edges_from(edges)

    def get_symmetry(self, node: NodeRef) -> Optional[Node]:
        """
        Returns the node symmetric with the specified node.

        Args:
            node: Node (by `Node` object or name) to query.

        Raises:
            ValueError: If node has more than one symmetry.

        Returns:
            The symmetric :class:`Node`, None if no symmetry (including when
            the node is not part of the skeleton).
        """
        node_node = self.find_node(node)

        # Passing None to `edges()` would select the edges of *every* node,
        # so an unknown node must be handled before querying the graph.
        if node_node is None:
            return None

        symmetry = [
            dst
            for src, dst, edge_type in self._graph.edges(node_node, data="type")
            if edge_type == EdgeType.SYMMETRY
        ]

        if len(symmetry) == 0:
            return None
        elif len(symmetry) == 1:
            return symmetry[0]
        else:
            raise ValueError(f"{node} has more than one symmetry.")

    def get_symmetry_name(self, node: NodeRef) -> Optional[str]:
        """
        Returns the name of the node symmetric with the specified node.

        Args:
            node: Node (by `Node` object or name) to query.

        Returns:
            Name of symmetric node, None if no symmetry.
        """
        symmetric_node = self.get_symmetry(node)
        return None if symmetric_node is None else symmetric_node.name

    def __getitem__(self, node_name: str) -> dict:
        """
        Retrieves the node data associated with skeleton node.

        Args:
            node_name: The name from which to retrieve data.

        Raises:
            ValueError: If node cannot be found.

        Returns:
            A dictionary of data associated with this node.

        """
        node = self.find_node(node_name)
        if node is None:
            raise ValueError(f"Skeleton does not have node named '{node_name}'.")

        return self._graph.nodes.data()[node]

    def __contains__(self, node_name: str) -> bool:
        """
        Checks if specified node exists in skeleton.

        Args:
            node_name: the node name to query

        Returns:
            True if node is in the skeleton.
        """
        return self.has_node(node_name)

    def relabel_node(self, old_name: str, new_name: str):
        """
        Relabel a single node to a new name.

        Args:
            old_name: The old name of the node.
            new_name: The new name of the node.

        Returns:
            None
        """
        self.relabel_nodes({old_name: new_name})

    def relabel_nodes(self, mapping: Dict[str, str]):
        """
        Relabel the nodes of the skeleton.

        Entries are applied one at a time in the mapping's iteration order;
        the `Node` objects are renamed in place. Entries whose old name is
        not found are ignored.

        Args:
            mapping: A dictionary with the old labels as keys and new
                labels as values. A partial mapping is allowed.

        Raises:
            ValueError: If node already present with one of the new names.

        Returns:
            None
        """
        for old_name, new_name in mapping.items():
            if self.has_node(new_name):
                raise ValueError("Cannot relabel a node to an existing name.")
            node = self.find_node(old_name)
            if node is not None:
                node.name = new_name

    def has_node(self, name: str) -> bool:
        """
        Check whether the skeleton has a node.

        Args:
            name: The name of the node to check for.

        Returns:
            True for yes, False for no.
        """
        return name in self.node_names

    def has_nodes(self, names: Iterable[str]) -> bool:
        """
        Check whether the skeleton has a list of nodes.

        Args:
            names: The list names of the nodes to check for.

        Returns:
            True for yes, False for no.
        """
        current_node_names = self.node_names
        return all(name in current_node_names for name in names)

    def has_edge(self, source_name: str, dest_name: str) -> bool:
        """
        Check whether the skeleton has an edge.

        Args:
            source_name: The name of the source node for the edge.
            dest_name: The name of the destination node for the edge.

        Returns:
            True is yes, False if no.

        """
        source_node = self.find_node(source_name)
        destination_node = self.find_node(dest_name)
        return self._graph.has_edge(source_node, destination_node)

    @staticmethod
    def to_dict(obj: "Skeleton", node_to_idx: Optional[Dict[Node, int]] = None) -> Dict:
        """
        Convert skeleton to dict; used for saving as JSON.

        Args:
            obj: the :object:`Skeleton` to convert
            node_to_idx: optional dict which maps :class:`Node` objects to
                index in some list. This is used when saving
                :class:`Labels` where we want to serialize the
                :class:`Nodes` outside the :class:`Skeleton` object.
                If given, then we replace each :class:`Node` with
                specified index before converting :class:`Skeleton`.
                Otherwise, we convert :class:`Node` objects with the rest of
                the :class:`Skeleton`.

        Returns:
            dict with data from skeleton
        """
        # This is a weird hack to serialize the whole _graph into a dict.
        # I use the underlying to_json and parse it.
        return json.loads(obj.to_json(node_to_idx))

    @classmethod
    def from_dict(cls, d: Dict, node_to_idx: Dict[Node, int] = None) -> "Skeleton":
        """
        Create skeleton from dict; used for loading from JSON.

        Args:
            d: the dict from which to deserialize
            node_to_idx: optional mapping forwarded unchanged as the
                `idx_to_node` argument of :meth:`from_json`, i.e. it is used
                to replace the node indices stored in `d` with already
                deserialized :class:`Node` objects. (The parameter name is
                kept for backward compatibility.) If not given, the
                :class:`Node` objects are taken from `d` itself.

        Returns:
            :class:`Skeleton`.

        """
        return cls.from_json(json.dumps(d), node_to_idx)

    @classmethod
    def from_names_and_edge_inds(
        cls, node_names: List[Text], edge_inds: Optional[List[Tuple[int, int]]] = None
    ) -> "Skeleton":
        """Create skeleton from a list of node names and edge indices.

        Args:
            node_names: List of strings defining the nodes.
            edge_inds: List of tuples in the form (src_node_ind, dst_node_ind). If not
                specified, the resulting skeleton will have no edges.

        Returns:
            The instantiated skeleton.
        """
        skeleton = cls()
        skeleton.add_nodes(node_names)
        for src, dst in edge_inds or ():
            skeleton.add_edge(node_names[src], node_names[dst])
        return skeleton

    def to_json(self, node_to_idx: Optional[Dict[Node, int]] = None) -> str:
        """
        Convert the :class:`Skeleton` to a JSON representation.

        Args:
            node_to_idx: optional dict which maps :class:`Node` objects to
                index in some list. This is used when saving
                :class:`Labels` where we want to serialize the
                :class:`Nodes` outside the :class:`Skeleton` object.
                If given, then we replace each :class:`Node` with
                specified index before converting :class:`Skeleton`.
                Otherwise, we convert :class:`Node` objects with the rest of
                the :class:`Skeleton`.

        Returns:
            A string containing the JSON representation of the skeleton.
        """
        jsonpickle.set_encoder_options("simplejson", sort_keys=True, indent=4)

        if node_to_idx is not None:
            # Map nodes to ints.
            indexed_node_graph = nx.relabel_nodes(G=self._graph, mapping=node_to_idx)
        else:
            indexed_node_graph = self._graph

        # Encode to JSON
        return jsonpickle.encode(json_graph.node_link_data(indexed_node_graph))

    def save_json(self, filename: str, node_to_idx: Optional[Dict[Node, int]] = None):
        """
        Save the :class:`Skeleton` as JSON file.

        Output the complete skeleton to a file in JSON format.

        Args:
            filename: The filename to save the JSON to.
            node_to_idx: optional dict which maps :class:`Node` objects to
                index in some list. This is used when saving
                :class:`Labels` where we want to serialize the
                :class:`Nodes` outside the :class:`Skeleton` object.
                If given, then we replace each :class:`Node` with
                specified index before converting :class:`Skeleton`.
                Otherwise, we convert :class:`Node` objects with the rest of
                the :class:`Skeleton`.

        Returns:
            None
        """
        json_str = self.to_json(node_to_idx)

        with open(filename, "w") as file:
            file.write(json_str)

    @classmethod
    def from_json(
        cls, json_str: str, idx_to_node: Optional[Dict[int, Node]] = None
    ) -> "Skeleton":
        """
        Instantiate :class:`Skeleton` from JSON string.

        Args:
            json_str: The JSON encoded Skeleton.
            idx_to_node: optional dict which maps an int (indexing a
                list of :class:`Node` objects) to the already
                deserialized :class:`Node`.
                This should invert `node_to_idx` we used when saving.
                If not given, then we'll assume each :class:`Node` was
                left in the :class:`Skeleton` when it was saved.

        Returns:
            An instance of the `Skeleton` object decoded from the JSON.
        """
        # NOTE(security): jsonpickle.decode can instantiate arbitrary Python
        # objects named in the input; only load skeletons from trusted sources.
        graph = json_graph.node_link_graph(jsonpickle.decode(json_str))

        # Replace graph node indices with corresponding nodes from node_map
        if idx_to_node is not None:
            graph = nx.relabel_nodes(G=graph, mapping=idx_to_node)

        skeleton = cls()
        skeleton._graph = graph

        return skeleton

    @classmethod
    def load_json(
        cls, filename: str, idx_to_node: Optional[Dict[int, Node]] = None
    ) -> "Skeleton":
        """
        Load a skeleton from a JSON file.

        This method will load the Skeleton from JSON file saved
        with; :meth:`~Skeleton.save_json`

        Args:
            filename: The file that contains the JSON.
            idx_to_node: optional dict which maps an int (indexing a
                list of :class:`Node` objects) to the already
                deserialized :class:`Node`.
                This should invert `node_to_idx` we used when saving.
                If not given, then we'll assume each :class:`Node` was
                left in the :class:`Skeleton` when it was saved.

        Returns:
            The `Skeleton` object stored in the JSON filename.

        """
        with open(filename, "r") as file:
            return cls.from_json(file.read(), idx_to_node)

    @classmethod
    def load_hdf5(cls, file: H5FileRef, name: str) -> "Skeleton":
        """
        Load a specific skeleton (by name) from the HDF5 file.

        Args:
            file: The file name or open h5py.File
            name: The name of the skeleton.

        Raises:
            KeyError: If the file holds no skeleton with the given name.

        Returns:
            The specified `Skeleton` instance stored in the HDF5 file.
        """
        # All skeletons are loaded, then the requested one is selected.
        return cls.load_all_hdf5(file, return_dict=True)[name]

    @classmethod
    def load_all_hdf5(
        cls, file: H5FileRef, return_dict: bool = False
    ) -> Union[List["Skeleton"], Dict[str, "Skeleton"]]:
        """
        Load all skeletons found in the HDF5 file.

        Args:
            file: The file name or open h5py.File
            return_dict: Whether the the return value should be a dict
                where the keys are skeleton names and values the
                corresponding skeleton. If False, then method will
                return just a list of the skeletons.

        Returns:
            The skeleton instances stored in the HDF5 file.
            Either in List or Dict form.
        """
        if isinstance(file, str):
            with h5py.File(file, "r") as _file:
                skeletons = cls._load_hdf5(_file)  # Load all skeletons
        else:
            skeletons = cls._load_hdf5(file)

        if return_dict:
            return skeletons

        return list(skeletons.values())

    @classmethod
    def _load_hdf5(cls, file: h5py.File) -> Dict[str, "Skeleton"]:
        """Decode every attribute of the file's "skeleton" group.

        Args:
            file: The open h5py.File to read from.

        Returns:
            A dict mapping attribute name to the decoded `Skeleton`.
        """
        return {
            name: cls.from_json(json_str)
            for name, json_str in file["skeleton"].attrs.items()
        }

    @classmethod
    def save_all_hdf5(cls, file: H5FileRef, skeletons: List["Skeleton"]):
        """
        Convenience method to save a list of skeletons to HDF5 file.

        Skeletons are saved as attributes of a /skeleton group in the
        file.

        Args:
            file: The filename or the open h5py.File object.
            skeletons: The list of skeletons to save.

        Raises:
            ValueError: If multiple skeletons have the same name.

        Returns:
            None
        """
        # Make sure no skeleton has the same name
        unique_names = {s.name for s in skeletons}

        if len(unique_names) != len(skeletons):
            raise ValueError("Cannot save multiple Skeleton's with the same name.")

        for skeleton in skeletons:
            skeleton.save_hdf5(file)

    def save_hdf5(self, file: H5FileRef):
        """
        Wrapper for HDF5 saving which takes either filename or h5py.File.

        Args:
            file: can be filename (string) or `h5py.File` object

        Returns:
            None
        """
        if isinstance(file, str):
            # Open in append mode so other content of the file is kept.
            with h5py.File(file, "a") as _file:
                self._save_hdf5(_file)
        else:
            self._save_hdf5(file)

    def _save_hdf5(self, file: h5py.File):
        """
        Actual implementation of HDF5 saving.

        Args:
            file: The open h5py.File to write the skeleton data to.

        Returns:
            None
        """
        # All skeletons are stored as attributes of the "skeleton" group.
        if "skeleton" not in file:
            all_sk_group = file.create_group("skeleton", track_order=True)
        else:
            all_sk_group = file.require_group("skeleton")

        # Serialize to a JSON string, then store it in a byte-string
        # attribute. `np.bytes_` is used because the `np.string_` alias was
        # removed in NumPy 2.0.
        all_sk_group.attrs[self.name] = np.bytes_(self.to_json())

    @classmethod
    def load_mat(cls, filename: str) -> "Skeleton":
        """
        Load the skeleton from a Matlab MAT file.

        This is to support backwards compatibility with old LEAP
        MATLAB code and datasets.

        Args:
            filename: The name of the skeleton file

        Returns:
            An instance of the skeleton.
        """
        # Lets create a skeleton object, use the filename for the name since
        # old LEAP skeletons did not have names.
        skeleton = cls(name=filename)

        skel_mat = loadmat(filename)
        skel_mat["nodes"] = skel_mat["nodes"][0][0]  # convert to scalar
        skel_mat["edges"] = skel_mat["edges"] - 1  # convert to 0-based indexing

        node_names = [str(n[0][0]) for n in skel_mat["nodeNames"]]
        skeleton.add_nodes(node_names)

        for edge in skel_mat["edges"]:
            skeleton.add_edge(
                source=node_names[edge[0]], destination=node_names[edge[1]]
            )

        return skeleton

    def __str__(self):
        """Return ``ClassName(name='...')``."""
        return "%s(name=%r)" % (self.__class__.__name__, self.name)

    def __hash__(self):
        """
        Construct a hash from skeleton id.
        """
        return id(self)
# Register Skeleton hooks on the global cattr converter: skeletons are
# unstructured with Skeleton.to_dict and structured with Skeleton.from_dict,
# both without any node index mapping.
cattr.register_unstructure_hook(Skeleton, lambda sk: Skeleton.to_dict(sk))
cattr.register_structure_hook(
    Skeleton, lambda data, _target_cls: Skeleton.from_dict(data)
)