# Source code for caskade.collection

from .base import Node
from .param import Param
from .mixins import GetSetValues


class NodeCollection(Node, GetSetValues):
    """Base class for nodes that are also builtin containers of nodes.

    Concrete subclasses combine this class with a builtin container type
    (``tuple`` or ``list``); the container's elements are the nodes of the
    collection.
    """

    def _nodes_to_convert(self, children_only):
        """Return the direct children, or every node in the topological ordering."""
        if children_only:
            return self.children.values()
        return self.topological_ordering()

    def _params_where(self, predicate):
        """Return, in topological order, every ``Param`` for which ``predicate`` is true."""
        return tuple(
            node
            for node in self.topological_ordering()
            if isinstance(node, Param) and predicate(node)
        )

    def to_dynamic(self, children_only=True):
        """Change all parameters to dynamic parameters.

        Pointer parameters are skipped.

        Parameters
        ----------
        children_only: (bool, optional)
            If True, only convert the children of this module to dynamic. If
            False, convert all parameters in the graph below this module.
            Defaults to True.
        """
        for node in self._nodes_to_convert(children_only):
            if isinstance(node, Param) and not node.pointer:
                node.to_dynamic()

    def to_static(self, children_only=True):
        """Change all parameters to static parameters.

        Pointer parameters are skipped.

        Parameters
        ----------
        children_only: (bool, optional)
            If True, only convert children of this module. If False, convert
            all parameters in the graph below this module. Defaults to True.
        """
        for node in self._nodes_to_convert(children_only):
            if isinstance(node, Param) and not node.pointer:
                node.to_static()

    @property
    def dynamic_params(self) -> tuple[Param, ...]:
        """All dynamic ``Param`` nodes in topological order."""
        return self._params_where(lambda param: param.dynamic)

    @property
    def dynamic_param_groups(self) -> tuple[int, ...]:
        """Sorted, de-duplicated ``group`` values of the dynamic parameters."""
        return tuple(sorted({param.group for param in self.dynamic_params}))

    @property
    def static_params(self) -> tuple[Param, ...]:
        """All static ``Param`` nodes in topological order."""
        return self._params_where(lambda param: param.static)

    @property
    def pointer_params(self) -> tuple[Param, ...]:
        """All pointer ``Param`` nodes in topological order."""
        return self._params_where(lambda param: param.pointer)

    def copy(self):
        """Not supported for node collections."""
        raise NotImplementedError

    def deepcopy(self):
        """Not supported for node collections."""
        raise NotImplementedError

    @property
    def dynamic(self):
        """True if any element of the collection is dynamic."""
        return any(node.dynamic for node in self)

    @property
    def static(self):
        """True if no element of the collection is dynamic."""
        return not self.dynamic

    def __mul__(self, other):
        # Repetition would insert the same node several times; not supported.
        raise NotImplementedError

    def __eq__(self, other):
        # Compare as nodes rather than element-wise as a tuple/list: this class
        # precedes the container type in subclasses' MRO.
        return Node.__eq__(self, other)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.name})[{len(self)}]"

    def __hash__(self):
        # Defining __eq__ sets __hash__ to None, and list is unhashable, so
        # hashing is explicitly delegated to Node.
        return Node.__hash__(self)
class NodeTuple(NodeCollection, tuple):
    """Immutable, ordered collection of nodes.

    Every element must be a ``Node``; each one is linked as a child of the
    tuple.

    Parameters
    ----------
    iterable: (iterable of Node, optional)
        The nodes to store. ``None`` (the default) gives an empty tuple.
    name: (str, optional)
        Name passed on to ``Node.__init__``.
    """

    def __new__(cls, iterable=None, name=None):
        # A tuple's contents are fixed when the object is created, so they are
        # supplied here. The default tuple.__new__ accepts at most one
        # positional argument and cannot iterate None; mirroring the __init__
        # signature allows a positional ``name`` and ``iterable=None``.
        return super().__new__(cls, () if iterable is None else iterable)

    def __init__(self, iterable=None, name=None):
        # ``iterable`` has already been consumed by __new__; only the Node
        # state is initialised here.
        Node.__init__(self, name=name)
        self.node_type = "ntuple"
        # Check every element before linking any, so a bad element cannot
        # leave earlier nodes linked to a tuple that failed to construct.
        for node in self:
            if not isinstance(node, Node):
                raise TypeError(f"NodeTuple elements must be Node objects, not {type(node)}")
        for node in self:
            self.link(node)

    @property
    def graphviz_style(self):
        return {"style": "solid", "color": "black", "shape": "tab"}

    def __getitem__(self, key):
        # String keys are resolved by Node; all other keys index the tuple.
        if isinstance(key, str):
            return Node.__getitem__(self, key)
        return tuple.__getitem__(self, key)

    def __add__(self, other):
        # Concatenate as plain tuples, then wrap the result, keeping this
        # collection's name.
        res = super().__add__(other)
        return NodeTuple(res, name=self.name)
class NodeList(NodeCollection, list):
    """Mutable, ordered collection of nodes.

    Every element must be a ``Node``; each one is linked as a child of the
    list. Mutating methods unlink all elements, perform the list operation,
    then relink all elements. New elements are type-checked before the list is
    modified, and relinking happens in a ``finally`` clause, so a failed
    operation (bad index, missing value, non-Node element) does not leave the
    elements unlinked or the list holding a non-Node element.

    Parameters
    ----------
    iterable: (iterable of Node, optional)
        The initial nodes. Defaults to an empty list.
    name: (str, optional)
        Name passed on to ``Node.__init__``.
    """

    def __init__(self, iterable=(), name=None):
        list.__init__(self, iterable)
        Node.__init__(self, name)
        self.node_type = "nlist"
        self._link_nodes()

    @property
    def graphviz_style(self):
        return {"style": "solid", "color": "black", "shape": "folder"}

    @staticmethod
    def _check_nodes(nodes):
        """Raise ``TypeError`` if any element of ``nodes`` is not a ``Node``."""
        for node in nodes:
            if not isinstance(node, Node):
                raise TypeError(f"NodeList elements must be Node objects, not {type(node)}")

    def _unlink_nodes(self):
        """Unlink every element from this list."""
        for node in self:
            self.unlink(node)

    def _link_nodes(self):
        """Link every element to this list, in list order."""
        # Validate everything first so a bad element cannot leave the list
        # only partially linked.
        self._check_nodes(self)
        for node in self:
            self.link(node)

    def _relinking(self, operation, *args):
        """Run ``operation(*args)`` with all elements temporarily unlinked.

        Elements are relinked even if ``operation`` raises; its return value
        is passed through.
        """
        self._unlink_nodes()
        try:
            return operation(*args)
        finally:
            self._link_nodes()

    def append(self, node):
        """Append ``node`` (must be a ``Node``) to the end of the list."""
        self._check_nodes((node,))
        self._relinking(super().append, node)

    def insert(self, index, node):
        """Insert ``node`` (must be a ``Node``) before ``index``."""
        self._check_nodes((node,))
        self._relinking(super().insert, index, node)

    def extend(self, iterable):
        """Append every node from ``iterable``."""
        # Materialise first so one-shot iterators can be validated and consumed.
        items = list(iterable)
        self._check_nodes(items)
        self._relinking(super().extend, items)

    def clear(self):
        """Remove (and unlink) every node."""
        self._relinking(super().clear)

    def pop(self, index=-1):
        """Remove and return the node at ``index`` (default last)."""
        return self._relinking(super().pop, index)

    def remove(self, value):
        """Remove the first occurrence of ``value``."""
        self._relinking(super().remove, value)

    def __getitem__(self, key):
        # String keys are resolved by Node; slices produce a new NodeList with
        # the same name; other keys index the list.
        if isinstance(key, str):
            return Node.__getitem__(self, key)
        if isinstance(key, slice):
            return NodeList(list.__getitem__(self, key), name=self.name)
        return list.__getitem__(self, key)

    def __setitem__(self, key, value):
        if isinstance(key, slice):
            # Slice assignment takes an iterable; materialise it for validation.
            value = list(value)
            self._check_nodes(value)
        else:
            self._check_nodes((value,))
        self._relinking(super().__setitem__, key, value)

    def __delitem__(self, key):
        self._relinking(super().__delitem__, key)

    def __add__(self, other):
        res = super().__add__(other)
        return NodeList(res, name=self.name)

    def __iadd__(self, other):
        # Materialise first so one-shot iterators can be validated and consumed.
        items = list(other)
        self._check_nodes(items)
        return self._relinking(super().__iadd__, items)

    def __imul__(self, other):
        # In-place repetition would duplicate nodes; not supported.
        raise NotImplementedError