from typing import Optional, Union, Callable, Any, Iterable
from warnings import warn
from math import prod
from numpy import ndarray
import numpy as np
from .backend import backend, ArrayLike
from .base import Node
from .errors import ParamConfigurationError, ParamTypeError, ActiveStateError
from .warnings import InvalidValueWarning
[docs]
def valid_shape(batch_shape, shape, value_shape):
    """
    Check whether ``value_shape`` is compatible with a param shape template.

    Parameters
    ----------
    batch_shape: (Optional[Iterable[Optional[int]]])
        Leading batch dimensions. If None, any leading dimensions of
        ``value_shape`` beyond ``len(shape)`` are ignored.
    shape: (Optional[Iterable[Optional[int]]])
        The shape template. A ``None`` entry is a wildcard that matches any
        size. If the template itself is None, every value shape is accepted.
    value_shape: (Iterable[int])
        The shape of the value being checked.

    Returns
    -------
    bool
        True if the value shape matches the template, False otherwise.
    """
    # No shape to compare
    if shape is None:
        return True

    # Normalise to tuples so lists (or other sequences) can be sliced and
    # concatenated without raising a TypeError
    shape = tuple(shape)
    value_shape = tuple(value_shape)

    # Determine what to compare
    if batch_shape is None:
        # Value has fewer dimensions than the template, cannot match
        if len(value_shape) < len(shape):
            return False
        # Only the trailing dimensions are compared, leading ones are free
        value_shape = value_shape[len(value_shape) - len(shape) :]
    else:
        shape = tuple(batch_shape) + shape

    # Definitely dont match, wrong lengths
    if len(value_shape) != len(shape):
        return False

    # Check for match or None (wildcard)
    return all(s is None or v == s for v, s in zip(value_shape, shape))
# Sentinel meaning "no value was supplied". It is distinct from ``None``, which
# ``Param.to_dynamic``/``Param.to_static`` accept as an explicit (empty) value.
NULL = object()
[docs]
class Param(Node):
    """
    Node to represent a parameter in the graph.

    The ``Param`` object is used to represent a parameter in the graph. During
    runtime this will represent a value which can be used in various
    calculations. The ``Param`` object can be set to a constant value
    (``static``); ``None`` meaning the value is to be provided at runtime
    (``dynamic``); another ``Param`` object meaning it will take on that value
    at runtime (``pointer``); or a function of other ``Param`` objects to be
    computed at runtime (also ``pointer``, see user guides). These options allow
    users to flexibly set the behavior of the simulator.

    Examples
    --------
    Example making some ``Param`` objects::

        p1 = Param("test", (1.0, 2.0)) # constant value, length 2 vector
        p2 = Param("p2", None, (2,2)) # dynamic 2x2 matrix value
        p3 = Param("p3", p1) # pointer to another parameter
        p4 = Param("p4", lambda p: p.children["other"].value * 2) # arbitrary function of another parameter
        p5 = Param("p5", valid=(0.0,2*pi), units="radians", cyclic=True) # parameter with metadata

    Parameters
    ----------
    name: (str)
        The name of the parameter.
    value: (Optional[Union[ArrayLike, float, int]], optional)
        The value of the parameter. Defaults to None meaning dynamic. A
        ``Param`` or a callable makes this parameter a pointer.
    shape: (Optional[tuple[int, ...]], optional)
        The shape of the parameter; ``None`` entries act as wildcards.
        Defaults to None meaning the shape is taken from the value (or ``()``
        when there is no value).
    cyclic: (bool, optional)
        Whether the parameter is cyclic, imposing periodic boundary conditions.
        Such as a rotation from 0 to 2pi. Defaults to False.
    valid: (Optional[tuple[Union[ArrayLike, float, int, None]]], optional)
        The valid range of the parameter as a ``(lower, upper)`` tuple. Defaults
        to None meaning all of -inf to inf is valid.
    units: (Optional[str], optional)
        The units of the parameter. Defaults to None.
    dynamic: (Optional[bool], optional)
        Force param to be dynamic if True. If a value is provided and param is dynamic
        then it has a default value at call time. Defaults to None meaning the
        param is dynamic only when no value is given.
    group: (int, optional)
        Integer group label of the parameter. Defaults to 0.
    batch_shape: (Optional[tuple[int]], optional)
        Leading batch dimensions of the value, so the full value shape takes
        the form (*B, *D) where *D is the shape of the parameter. Defaults to
        None meaning the batch dimensions are inferred from the value.
    dtype: (Optional[Any], optional)
        The data type of the parameter. Defaults to None meaning the data type
        will be inferred from the value.
    device: (Optional[Any], optional)
        The device of the parameter. Defaults to None meaning the device will be
        inferred from the value.
    **kwargs
        Forwarded to the ``Node`` initializer.
    """

    def __init__(
        self,
        name: str,
        value: Optional[Union[ArrayLike, float, int]] = None,
        shape: Optional[tuple[int, ...]] = None,
        cyclic: bool = False,
        valid: Optional[tuple[Union[ArrayLike, float, int, None]]] = None,
        units: Optional[str] = None,
        dynamic: Optional[bool] = None,
        group: int = 0,
        batch_shape: Optional[tuple[int]] = None,
        dtype: Optional[Any] = None,
        device: Optional[Any] = None,
        **kwargs,
    ):
        # Set before the base initializer runs so ``node_type`` is readable
        # during it (presumably required by ``Node.__init__`` - confirm).
        self._node_type = "node"
        super().__init__(name=name, **kwargs)
        # Raw state, filled in below through the validating property setters
        self._shape = None
        self._batch_shape = None
        self._value = None  # live value, only used while the graph is active
        self.__value = None  # stored value, or the pointer function
        self._valid = (None, None)
        self._group = 0
        self._dtype = dtype
        self._device = device
        self._cyclic = cyclic
        self.shape = shape
        if dynamic or (dynamic is None and value is None):
            self.to_dynamic()
        else:
            self.to_static()
        # The value setter dispatches to pointer/dynamic/static as appropriate
        self.value = value
        self.group = group
        self.valid = valid
        self.units = units
        if batch_shape is not None:
            self.batch_shape = batch_shape

    @property
    def dynamic(self) -> bool:
        """True if this parameter is currently a dynamic parameter."""
        return "dynamic" in self.node_type

    @property
    def pointer(self) -> bool:
        """True if this parameter is currently a pointer parameter."""
        return "pointer" in self.node_type

    @property
    def static(self) -> bool:
        """True if this parameter is currently a static parameter."""
        return "static" in self.node_type

    @property
    def graphviz_style(self):
        """Graphviz node attributes, chosen from the current node type."""
        if self.pointer:
            return {
                "style": "filled",
                "color": "lightgrey",
                "fillcolor": "lightgrey",
                "shape": "rarrow",
            }
        if self.dynamic:
            return {
                "style": "solid",
                "color": "black",
                "fillcolor": "white",
                "shape": "box",
            }
        # Remaining node types: distinguish a missing stored value
        if self.__value is None:
            return {
                "style": "filled",
                "color": "black",
                "fillcolor": "grey90",
                "shape": "box",
            }
        return {
            "style": "filled",
            "color": "lightgrey",
            "fillcolor": "lightgrey",
            "shape": "box",
        }

    @property
    def node_type(self):
        """The node type string (e.g. "static", "dynamic" or "pointer")."""
        return self._node_type

    @node_type.setter
    def node_type(self, value):
        pre_type = self.node_type
        self._node_type = value
        # Only notify the graph when the type actually changed
        if pre_type != self.node_type:
            self.update_graph()

    def _freeze_pointer(self):
        """If this is a pointer, replace the pointer function by its current
        result, or by None when the function cannot be evaluated."""
        if self.pointer:
            try:
                self.__value = self.__value(self)
            except Exception:
                self.__value = None

    def _coerce_value(self, value):
        """Convert ``value`` to a backend array and check it against the
        shape of this parameter. ``None`` is passed through unchanged.

        Raises
        ------
        ParamConfigurationError
            If the shape of the value does not match the param shape.
        """
        if value is None:
            return None
        value = backend.as_array(value, dtype=self._dtype, device=self._device)
        if not valid_shape(self._batch_shape, self._shape, tuple(value.shape)):
            if self.batched:
                shape = f"{self._shape} with batch dims {self._batch_shape}"
            else:
                shape = str(self._shape)
            raise ParamConfigurationError(
                f"Value shape {value.shape} does not match param shape {shape}! Cannot update value. ({self.name})"
            )
        return value

    def to_dynamic(self, value=NULL):
        """Change this parameter to a dynamic parameter. If a value is provided,
        this will be set as the dynamic value.

        Raises
        ------
        ActiveStateError
            If the parameter is active.
        ParamTypeError
            If ``value`` is a ``Param`` or a callable.
        ParamConfigurationError
            If the shape of ``value`` does not match the param shape.
        """
        # While active no value can be set
        if self.active:
            raise ActiveStateError(f"Cannot set parameter {self.name} dynamic value while active.")
        # Catch cases where input is invalid
        if isinstance(value, Param) or callable(value):
            raise ParamTypeError(f"Cannot set dynamic value to pointer ({self.name}).")
        # No value given: keep the stored value, only change the type
        if value is NULL:
            self._freeze_pointer()
            self.node_type = "dynamic"
            return
        self.__value = self._coerce_value(value)
        self.node_type = "dynamic"
        self.is_valid()

    def to_static(self, value=NULL):
        """Change this parameter to a static parameter. If a value is provided
        this will be set as the static value.

        Raises
        ------
        ActiveStateError
            If the parameter is active.
        ParamTypeError
            If ``value`` is a ``Param`` or a callable.
        ParamConfigurationError
            If the shape of ``value`` does not match the param shape.
        """
        # While active no value can be set
        if self.active:
            raise ActiveStateError(f"Cannot set parameter {self.name} static value while active.")
        # Catch cases where input is invalid
        if isinstance(value, Param) or callable(value):
            raise ParamTypeError(f"Cannot set static value to pointer ({self.name}).")
        # No value given: keep the stored value, only change the type
        if value is NULL:
            self._freeze_pointer()
            self.node_type = "static"
            return
        self.__value = self._coerce_value(value)
        # Switch type before the range check: ``is_valid`` skips pointers, so
        # checking first would silently skip a pointer -> static conversion.
        self.node_type = "static"
        self.is_valid()

    def to_pointer(self, value, link=()):
        """Change this parameter to a pointer parameter. If a value is provided
        this will be set as the pointer. Either provide a Param object to point
        to its value, or provide a callable function to be called at runtime. It
        is also possible to provide a tuple of nodes to link to while creating
        the pointer.

        Raises
        ------
        ActiveStateError
            If the parameter is active.
        ParamTypeError
            If ``value`` is neither None, a ``Param`` nor a callable.
        """
        # While active no value can be set
        if self.active:
            raise ActiveStateError(f"Cannot set parameter {self.name} to pointer while active")
        if isinstance(value, Param):
            # Point at another param: link it and look it up by name at runtime
            self.link(value)
            p_name = value.name
            value = lambda p: p[p_name].value
        elif value is not None and not callable(value):
            raise ParamTypeError(f"Pointer function must be a Param or callable ({self.name})")
        elif hasattr(value, "params"):
            # Callable that carries the params it depends on
            self.link(value.params)
        self.link(link)
        self.__value = value
        # The shape of a pointer is always taken from the computed value
        self._shape = None
        self.node_type = "pointer"

    @property
    def shape(self) -> tuple[int, ...]:
        """The shape of the parameter, excluding leading batch dimensions when
        a shape template is set. Wildcards in the template are filled from the
        current value where possible."""
        value = self.value
        # 1. Handle cases where no shape template is defined
        if self._shape is None:
            return tuple(value.shape) if value is not None else ()
        # 2. If value is missing, return the template as-is
        if value is None:
            return self._shape
        # 3. Fill wildcards (None) in _shape using the trailing dimensions of value
        # Negative indexing handles the alignment automatically
        n = len(self._shape)
        return tuple(v if s is None else s for s, v in zip(self._shape, value.shape[-n:]))

    @shape.setter
    def shape(self, shape: Optional[Iterable]):
        if self.pointer:
            raise ParamTypeError(
                f"Cannot set shape of parameter {self.name} with node type 'pointer'"
            )
        if shape is None:
            self._shape = None
            return
        value = self.value
        try:
            shape = tuple(shape)
        except TypeError as err:
            raise ParamConfigurationError(
                f"Param shape must be iterable of ints/None, not: {type(shape)}. ({self.name})"
            ) from err
        if value is None or valid_shape(self._batch_shape, shape, tuple(value.shape)):
            self._shape = shape
            return
        raise ValueError(
            f"Shape {shape} does not match the shape of the value {value.shape}! Unable to set shape. ({self.name})"
        )

    @property
    def batched(self) -> bool:
        """True if the parameter has at least one batch dimension."""
        return len(self.batch_shape) > 0

    @property
    def batch_shape(self) -> tuple[int, ...]:
        """The leading batch dimensions. If not set explicitly they are the
        dimensions of the value which precede the parameter shape."""
        if self._batch_shape is not None:
            return self._batch_shape
        try:
            value = self.value
        except Exception:
            # e.g. a pointer that can not be evaluated right now
            value = None
        if value is None:
            return ()
        return tuple(value.shape[: len(value.shape) - len(self.shape)])

    @batch_shape.setter
    def batch_shape(self, batch_shape: tuple[int]):
        if self.pointer:
            raise ParamTypeError(
                f"Cannot set batch_shape of parameter {self.name} with node type 'pointer'"
            )
        if batch_shape is not None:
            # Store a tuple so it can be concatenated with the shape tuple
            try:
                batch_shape = tuple(batch_shape)
            except TypeError as err:
                raise ParamConfigurationError(
                    f"Param batch_shape must be iterable of ints/None, not: {type(batch_shape)}. ({self.name})"
                ) from err
        self._batch_shape = batch_shape

    @property
    def group(self) -> int:
        """Integer group label of the parameter."""
        return self._group

    @group.setter
    def group(self, group: int):
        # NOTE(review): an assert is skipped under ``python -O``; kept so the
        # exception type seen by callers does not change.
        assert isinstance(group, int), f"Group must be an integer ({self.name})"
        pregroup = self._group
        self._group = group
        # Only notify the graph when the group actually changed
        if pregroup != self._group:
            self.update_graph()

    @property
    def dtype(self) -> Optional[Any]:
        """The explicitly set dtype, otherwise the dtype of the value (None if
        there is no value)."""
        if self._dtype is None:
            try:
                return self.value.dtype
            except AttributeError:
                # No value, or the value carries no dtype
                pass
        return self._dtype

    @property
    def device(self) -> Optional[Any]:
        """The explicitly set device, otherwise the device of the value (None
        if there is no value)."""
        if self._device is None:
            try:
                return self.value.device
            except AttributeError:
                # No value, or the value carries no device
                pass
        return self._device

    @property
    def value(self) -> Union[ArrayLike, None]:
        """The current value. For a pointer this evaluates the pointer
        function, caching the result while the parameter is active."""
        # A live value takes precedence over everything else
        if self._value is not None:
            return self._value
        if self.pointer:
            value = self.__value(self)
            if self.active:
                self._value = value
            return value
        return self.__value

    @value.setter
    def value(self, value):
        # While active no value can be set
        if self.active:
            if self.static and self.__value is None:  # static value set live in sim
                self._value = value
                return
            raise ActiveStateError(f"Cannot set value of parameter {self.name} while active")
        if isinstance(value, Param) or callable(value):
            self.to_pointer(value)
        elif self.dynamic:
            self.to_dynamic(value)
        else:
            self.to_static(value)

    @property
    def npvalue(self) -> ndarray:
        """The current value converted to numpy by the backend."""
        return backend.to_numpy(self.value)

    def to(self, device=None, dtype=None) -> "Param":
        """
        Moves and/or casts the values of the parameter.

        Parameters
        ----------
        device: (Optional[Any], optional)
            The device to move the values to. Defaults to None meaning the
            current device is kept.
        dtype: (Optional[Any], optional)
            The desired data type. Defaults to None meaning the current data
            type is kept.

        Returns
        -------
        Param
            This parameter, to allow chaining.
        """
        if device is not None:
            self._device = device
        else:
            device = self.device
        if dtype is not None:
            self._dtype = dtype
        else:
            dtype = self.dtype
        super().to(device=device, dtype=dtype)
        # A pointer stores a function, there is nothing to move
        if not self.pointer and self.__value is not None:
            self.__value = backend.to(self.__value, device=device, dtype=dtype)
        # Keep the valid bounds on the same device/dtype as the value
        lower, upper = self.valid
        if lower is not None:
            lower = backend.to(lower, device=device, dtype=dtype)
        if upper is not None:
            upper = backend.to(upper, device=device, dtype=dtype)
        self.valid = (lower, upper)
        return self

    @property
    def cyclic(self) -> bool:
        """Whether the parameter is cyclic (periodic over its valid range)."""
        return self._cyclic

    @cyclic.setter
    def cyclic(self, cyclic: bool):
        previous = self._cyclic
        self._cyclic = cyclic
        try:
            # Re-run the valid setter to pick the matching transforms
            self.valid = self.valid
        except Exception:
            # Rejected (e.g. cyclic without both bounds): restore the old flag
            self._cyclic = previous
            raise

    def _save_state_hdf5(self, h5group, appendable: bool = False, _done_save: set = None):
        """Save the value and its metadata to a ``"value"`` dataset, unless
        one already exists in the group of this node."""
        super()._save_state_hdf5(h5group, appendable=appendable, _done_save=_done_save)
        if "value" not in self._h5group:
            try:
                value = self.value
            except Exception:
                # e.g. a pointer that can not be evaluated right now
                value = None
            if value is None:
                # Placeholder string since there is no value to store
                value = "None"
            elif appendable:
                # Add a leading axis which grows with every appended state
                value = backend.to_numpy(value.reshape(1, *value.shape))
            else:
                value = backend.to_numpy(value)
            if appendable:
                placeholder = isinstance(value, str)
                self._h5group.create_dataset(
                    "value",
                    data=value,
                    chunks=False if placeholder else True,
                    # Take the trailing dims from the stored array so batch
                    # dimensions are included and the rank matches the data
                    maxshape=None if placeholder else (None,) + tuple(value.shape[1:]),
                    compression=None if placeholder else "gzip",
                )
            else:
                self._h5group.create_dataset(
                    "value",
                    data=value,
                )
            attrs = self._h5group["value"].attrs
            attrs["node_type"] = self.node_type
            attrs["appendable"] = appendable
            attrs["cyclic"] = self.cyclic
            if self.valid[0] is not None:
                attrs["valid_left"] = backend.to_numpy(self.valid[0])
            if self.valid[1] is not None:
                attrs["valid_right"] = backend.to_numpy(self.valid[1])
            attrs["units"] = self.units if self.units is not None else "None"

    def _check_append_state_hdf5(self, h5group):
        """Raise ``IOError`` if the stored value was not saved as appendable."""
        super()._check_append_state_hdf5(h5group)
        if not h5group["value"].attrs["appendable"]:
            raise IOError(
                f"{self.name} is not appendable. Need to save the HDF5 file with `appendable=True`."
            )

    def _append_state_cleanup(self):
        """Remove the marker set by ``_append_state_hdf5``."""
        super()._append_state_cleanup()
        # The marker is absent if nothing was appended for this node
        if hasattr(self, "appended"):
            del self.appended

    def _append_state_hdf5(self, h5group):
        """Append the current value as a new entry along the first axis of
        the stored dataset. Does nothing if already marked as appended."""
        super()._append_state_hdf5(h5group)
        if not hasattr(self, "appended"):
            # Marker so the value is appended only once until cleanup
            self.appended = True
            try:
                value = self.value
            except Exception:
                # e.g. a pointer that can not be evaluated right now
                value = None
            if value is not None:
                dataset = h5group["value"]
                # Grow only the leading axis, the other axes keep the saved size
                dataset.resize(dataset.shape[0] + 1, axis=0)
                dataset[-1] = backend.to_numpy(value)

    def _load_state_hdf5(self, h5group, index: int = -1, _done_load: set = None):
        """Load the value (entry ``index`` if appendable) and its metadata
        from the ``"value"`` dataset."""
        super()._load_state_hdf5(h5group, index=index, _done_load=_done_load)
        # Reset the range first so the loaded value is not checked against
        # the previous valid range
        self.cyclic = False
        self.valid = None
        dataset = h5group["value"]
        attrs = dataset.attrs
        if not self.pointer:
            stored = dataset[()]
            if isinstance(stored, bytes):
                # Placeholder written when there was no value to save
                assert stored == b"None"
                value = None
            elif attrs["appendable"]:
                value = stored[index]
            else:
                value = stored
            if "static" in attrs["node_type"]:
                self.to_static(value)
            elif "dynamic" in attrs["node_type"]:
                self.to_dynamic(value)
        # Missing units are saved as the string "None", map them back
        units = attrs["units"]
        self.units = None if units in ("None", b"None") else units
        # Restore both bounds at once so the value is checked only once
        lower = attrs["valid_left"] if "valid_left" in attrs else None
        upper = attrs["valid_right"] if "valid_right" in attrs else None
        if lower is not None or upper is not None:
            self.valid = (lower, upper)
        self.cyclic = attrs["cyclic"]

    @property
    def valid(self) -> tuple[Optional[ArrayLike], Optional[ArrayLike]]:
        """The valid range as a ``(lower, upper)`` tuple, None meaning unbounded."""
        return self._valid

    @valid.setter
    def valid(self, valid: tuple[Union[ArrayLike, float, int, None]]):
        if valid is None:
            valid = (None, None)
        if not isinstance(valid, tuple):
            raise ParamConfigurationError(f"Valid must be a tuple ({self.name})")
        if len(valid) != 2:
            raise ParamConfigurationError(f"Valid must be a tuple of length 2 ({self.name})")
        if self.cyclic and (valid[0] is None or valid[1] is None):
            raise ParamConfigurationError(f"valid must be set for cyclic parameter ({self.name})")

        # Resolve the bounds and transforms first, only commit them once
        # every check has passed so a failure leaves the param untouched.
        lower, upper = valid
        if lower is None and upper is None:
            to_valid = self._to_valid_base
            from_valid = self._from_valid_base
        elif lower is None:
            to_valid = self._to_valid_rightvalid
            from_valid = self._from_valid_rightvalid
            upper = backend.as_array(upper, dtype=self.dtype, device=self.device)
        elif upper is None:
            to_valid = self._to_valid_leftvalid
            from_valid = self._from_valid_leftvalid
            lower = backend.as_array(lower, dtype=self.dtype, device=self.device)
        else:
            if self.cyclic:
                to_valid = self._to_valid_cyclic
                from_valid = self._from_valid_cyclic
            else:
                to_valid = self._to_valid_fullvalid
                from_valid = self._from_valid_fullvalid
            lower = backend.as_array(lower, dtype=self.dtype, device=self.device)
            upper = backend.as_array(upper, dtype=self.dtype, device=self.device)
            if backend.any(lower >= upper):
                raise ParamConfigurationError(
                    f"Valid range (valid[1] - valid[0]) must be strictly positive ({self.name})"
                )

        self.to_valid = to_valid
        self.from_valid = from_valid
        self._valid = (lower, upper)
        # Warn if the current value is outside the new range
        self.is_valid()

    def is_valid(self, value=None) -> bool:
        """Check if a given value is valid given this parameters allowed (valid) range.

        Uses the current value when ``value`` is None. Cyclic and pointer
        parameters, and missing values, are always reported as valid. An
        ``InvalidValueWarning`` is issued when the value is out of range.
        """
        if self.cyclic or self.pointer:
            return True
        if value is None:
            value = self.value
        if value is None:
            return True
        lower, upper = self.valid
        if lower is not None and backend.any(value < lower):
            warn(InvalidValueWarning(self.name, value, self.valid))
            return False
        if upper is not None and backend.any(value > upper):
            warn(InvalidValueWarning(self.name, value, self.valid))
            return False
        return True

    # Transforms selected by the ``valid`` setter. Each ``_from_valid_*`` is
    # the counterpart of the ``_to_valid_*`` with the same suffix.

    def _to_valid_base(self, value: ArrayLike) -> ArrayLike:
        """No bounds: identity."""
        return value

    def _to_valid_fullvalid(self, value: ArrayLike) -> ArrayLike:
        """Both bounds: logit of the fractional position in the range,
        shifted by the lower bound."""
        value = (
            backend.logit((value - self.valid[0]) / (self.valid[1] - self.valid[0])) + self.valid[0]
        )
        return value

    def _to_valid_cyclic(self, value: ArrayLike) -> ArrayLike:
        """Cyclic: wrap the value into the valid range."""
        return ((value - self.valid[0]) % (self.valid[1] - self.valid[0])) + self.valid[0]

    def _to_valid_leftvalid(self, value: ArrayLike) -> ArrayLike:
        """Lower bound only: log of the distance above the bound."""
        return backend.log(value - self.valid[0])

    def _to_valid_rightvalid(self, value: ArrayLike) -> ArrayLike:
        """Upper bound only: log of the distance below the bound."""
        return backend.log(self.valid[1] - value)

    def _from_valid_base(self, value: ArrayLike) -> ArrayLike:
        """No bounds: identity."""
        return value

    def _from_valid_fullvalid(self, value: ArrayLike) -> ArrayLike:
        """Both bounds: sigmoid of the value (shifted by the lower bound)
        rescaled onto the valid range."""
        value = (
            backend.sigmoid(value - self.valid[0]) * (self.valid[1] - self.valid[0]) + self.valid[0]
        )
        return value

    def _from_valid_cyclic(self, value: ArrayLike) -> ArrayLike:
        """Cyclic: wrap the value into the valid range."""
        value = ((value - self.valid[0]) % (self.valid[1] - self.valid[0])) + self.valid[0]
        return value

    def _from_valid_leftvalid(self, value: ArrayLike) -> ArrayLike:
        """Lower bound only: exponential added to the bound."""
        value = backend.exp(value) + self.valid[0]
        return value

    def _from_valid_rightvalid(self, value: ArrayLike) -> ArrayLike:
        """Upper bound only: exponential subtracted from the bound."""
        value = self.valid[1] - backend.exp(value)
        return value

    @property
    def node_str(self) -> str:
        """
        Returns a string representation of the node for graph visualization.

        Single values and arrays of up to four elements are printed, larger
        (or empty) arrays show their shape instead.
        """
        try:
            value = self.value
        except Exception:
            # e.g. a pointer that can not be evaluated right now
            value = None
        label = f"{self.name}|{self.node_type}"
        if value is None:
            return f"{label}: live" if self.static else label
        value = backend.to_numpy(value)
        size = prod(value.shape)
        if size == 1:
            return f"{label}: {value.item():.3g}"
        if 1 < size <= 4:
            formatted = str(np.char.mod("%.3g", value).tolist()).replace("'", "")
            return f"{label}: {formatted}"
        # Large arrays, and empty ones which have no element to print
        return f"{label}: {self.shape}"

    def __repr__(self) -> str:
        return self.name