caskade package
Submodules
caskade.backend module
caskade.base module
- class caskade.base.Node(name: str | None = None, link: Node | tuple[Node] | None = None, description: str = '')
Bases: object
Base graph node class for caskade objects.
The Node object is the base class for all caskade objects and is used to construct the directed acyclic graph (DAG). Its primary function is to manage the parent-child relationships between nodes in the graph. The Node object has limited functionality of its own, though it implements the base versions of the active state and the to/update_graph methods. The active state is used to communicate through the graph that the simulator is currently running. The to method is used to move and/or cast the values of the parameters. The update_graph method is used to signal all parents that the graph below them has changed.
Examples
Example making some Node objects and then linking/unlinking them:

    from caskade import Node

    n1 = Node()
    n2 = Node()
    n1.link("subnode", n2)  # link n2 as a child of n1; any str may be used as the key
    n1.unlink("subnode")  # alternately, n1.unlink(n2) to unlink by object

- property active: bool
- append_state(saveto: str | File)
Append the state of the node and its children to an existing HDF5 file.
- graph_dict() dict[str, dict]
Return a dictionary representation of the graph below the current node.
- graph_print(dag: dict, depth: int = 0, indent: int = 4, result: str = '') str
Print the graph dictionary in a human-readable format.
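As a rough illustration of rendering a nested graph dictionary with one indentation level per depth, here is a minimal recursive sketch. The nested layout ({name: {child_name: {...}}}) and the helper name format_graph are assumptions for illustration; the keys that graph_dict actually produces may differ.

```python
def format_graph(dag: dict, depth: int = 0, indent: int = 4) -> str:
    """Render a nested {name: {child: {...}}} dict as an indented tree (hypothetical helper)."""
    result = ""
    for name, children in dag.items():
        # One line per node, indented according to its depth in the graph
        result += " " * (depth * indent) + name + "\n"
        # Recurse into the children with one more level of indentation
        result += format_graph(children, depth + 1, indent)
    return result


dag = {"sim": {"a": {}, "b": {"c": {}}}}
print(format_graph(dag), end="")
```

A node reachable along several paths is simply printed once per path, which matches how a nested dictionary naturally represents a DAG.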
- graphviz(saveto: str | None = None) graphviz.Digraph
Return a graphviz object representing the graph below the current node in the DAG.
- Parameters:
top_down (bool, optional) – Whether to draw the graph top-down (current node at top) or bottom-up (current node at bottom). Defaults to True.
saveto (Optional[str], optional) – If provided, save the graph to this file. The file extension determines the format (e.g. '.pdf', '.png'). Defaults to None.
- property graphviz_style
- hierarchical_link(key: str, child: Node)
Link the current Node object to another Node object as a child in a hierarchical manner. See link for more detail on linking. A hierarchical link allows batching internally to the simulator.
- Parameters:
key (str) – The key to link the child node with.
child (Node) – The child Node object to link to.
- link(key: str | tuple | Node, child: Node | tuple | None = None)
Link the current Node object to another Node object as a child.
- Parameters:
key (Union[str, Node]) – The key to link the child node with. This will also become the attribute used to access the child node; after linking you will have node.key == child.
child (Optional[Node], optional) – The child Node object to link to. Defaults to None, in which case the key is used as the child and child.name is used as the key.
Examples
Example making some Node objects and then linking/unlinking them, demonstrating multiple ways to link/unlink:

    from caskade import Node

    n1 = Node()
    n2 = Node()
    n1.link("subnode", n2)  # any str may be used as the key
    n1.unlink("subnode")
    # Alternately, link and unlink by object
    n1.link(n2)
    n1.unlink(n2)

- load_state(loadfrom: str | File, index: int = -1, **kwargs)
Load the state of the node and its children.
- property memos: set[str]
- property name: str
- property node_str
- property online: bool
- save_state(saveto: str | File, appendable: bool = False)
Save the state of the node and its children. Currently this only works for HDF5 file types (.h5 and .hdf5).
The “state” of a node is considered to be the values of its params; however, it is also possible to save other attributes of the node by adding them to the Node.saveattrs set. Simply call Node.saveattrs.add('attribute') and Node.attribute will then be saved if possible. The HDF5 file is created with the same structure as the graph, even if there are multiple paths to the same node. For example, if N1 has children N2 and N3, and both N2 and N3 have the child N4, the HDF5 file will reflect this: when inspecting the HDF5 file manually, the N4 params can be found under both 'N1/N2/N4' and 'N1/N3/N4'. Specifically, if N4 has the param P1 then you could access its value like this:
If the state was saved as appendable, the value will have an extra dimension for the number of samples, and this will always be the first dimension. If appendable was False, the value simply equals the param value.
Note
You need the optional h5py package installed to use this method.
- Parameters:
saveto (Union[str, File]) – The file to save the state to. If a string, it should be the path to an HDF5 file (ending in '.h5' or '.hdf5'). If a File object, it should be an open HDF5 file.
appendable (bool, optional) – Whether to save the state in an appendable format. If True, the values will have an extra dimension for the number of samples. Defaults to False.
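To make the path layout concrete, the sketch below enumerates every root-to-node path in the N1/N2/N3/N4 example above and shows how an appendable save stacks samples along a new first dimension. It only models the layout with plain dicts and NumPy; it does not use h5py or caskade, and the helper name hdf5_paths is hypothetical.

```python
import numpy as np

# Hypothetical DAG from the example: N1 -> {N2, N3}, N2 -> N4, N3 -> N4
children = {"N1": ["N2", "N3"], "N2": ["N4"], "N3": ["N4"], "N4": []}


def hdf5_paths(node: str, prefix: str = "") -> list:
    """Return one group path per route from the root, so shared nodes appear more than once."""
    path = f"{prefix}/{node}" if prefix else node
    paths = [path]
    for child in children[node]:
        paths.extend(hdf5_paths(child, path))
    return paths


print(hdf5_paths("N1"))
# ['N1', 'N1/N2', 'N1/N2/N4', 'N1/N3', 'N1/N3/N4']

# Appendable layout: each saved sample adds a row along a new first dimension
samples = [np.array([1.0, 2.0]), np.array([3.0, 4.0]), np.array([5.0, 6.0])]
stacked = np.stack(samples, axis=0)
print(stacked.shape)  # (3, 2): (number of samples, *param shape)
```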
- to(device=None, dtype=None)
Moves and/or casts the values of the Node to a particular device and/or dtype.
- Parameters:
device (Optional[torch.device], optional) – The device to move the values to. Defaults to None.
dtype (Optional[torch.dtype], optional) – The desired data type. Defaults to None.
- topological_ordering() tuple[Node]
Return a topological ordering of the graph below the current node. Uses iterative deepening DFS (post-order) to resolve dependencies.
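The docstring names an iterative deepening, post-order DFS. As a simpler stand-in, the sketch below uses a plain stack-based post-order DFS on a dict of child lists, which yields an order in which every node comes after all the nodes below it. The function name, the dict-based graph, and the children-first direction are assumptions; this is not caskade's implementation.

```python
def topological_ordering(root, children):
    """Post-order DFS: each node appears after every node below it (hypothetical helper)."""
    order = []
    visited = set()
    # (node, expanded) pairs give an iterative post-order traversal without recursion
    stack = [(root, False)]
    while stack:
        node, expanded = stack.pop()
        if expanded:
            order.append(node)  # all of its children have already been emitted
            continue
        if node in visited:
            continue  # shared child already handled via another parent
        visited.add(node)
        stack.append((node, True))
        for child in reversed(children.get(node, [])):
            if child not in visited:
                stack.append((child, False))
    return tuple(order)


graph = {"N1": ["N2", "N3"], "N2": ["N4"], "N3": ["N4"], "N4": []}
print(topological_ordering("N1", graph))  # ('N4', 'N2', 'N3', 'N1')
```

A node shared by two parents (N4 here) is emitted only once, before either parent.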
caskade.collection module
- class caskade.collection.NodeCollection(name: str | None = None, link: Node | tuple[Node] | None = None, description: str = '')
Bases: Node, GetSetValues
- property dynamic
- property dynamic_param_groups: tuple[int]
- property static
- class caskade.collection.NodeList(iterable=(), name=None)
Bases: NodeCollection, list
- property graphviz_style
- class caskade.collection.NodeTuple(iterable=None, name=None)
Bases: NodeCollection, tuple
- property graphviz_style
caskade.context module
- class caskade.context.ActiveContext(module: Module)
Bases: object
Context manager to activate a module for a simulation. Only inside an ActiveContext is it possible to fill/clear the dynamic and live parameters.
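A minimal sketch of the activate-on-enter, clear-on-exit pattern this context manager describes, using hypothetical ToyModule and ToyActiveContext classes rather than caskade's own. Clearing the filled values on exit is an assumption here, suggested by Module.clear_state.

```python
class ToyModule:
    """Hypothetical stand-in for a module holding dynamic parameter values."""

    def __init__(self):
        self.active = False
        self.dynamic_values = {}

    def fill(self, **values):
        # Filling is only allowed while the module is active
        if not self.active:
            raise RuntimeError("cannot fill parameters outside an active context")
        self.dynamic_values.update(values)


class ToyActiveContext:
    """Activate a module on entry; clear its values and deactivate it on exit."""

    def __init__(self, module):
        self.module = module

    def __enter__(self):
        self.module.active = True
        return self.module

    def __exit__(self, exc_type, exc, tb):
        self.module.dynamic_values.clear()
        self.module.active = False
        return False  # do not suppress exceptions


m = ToyModule()
with ToyActiveContext(m):
    m.fill(b=5.0)
    print(m.dynamic_values)  # {'b': 5.0}
print(m.active, m.dynamic_values)  # False {}
```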
caskade.decorators module
- class caskade.decorators.active_cache(func)
Bases: object
Caches the first evaluated result of a Module method for the duration of a simulation.
This decorator ensures that an expensive method is executed exactly once per active simulation run. Once calculated, subsequent calls to the decorated method will return the stored value, ignoring any arguments passed to it.
- WARNING:
If the method is called multiple times with different arguments in one simulation, the cached result will still be returned, which may lead to unexpected behavior. Use with caution!
Note
If you are stacking multiple decorators on a method (such as @forward or @jax.jit), @active_cache MUST be the outermost (top) decorator.
- Example:

    import jax
    import jax.numpy as jnp
    import numpy as np
    from caskade import Module, Param
    from caskade.decorators import active_cache, forward

    class FluxModel(Module):
        def __init__(self, nodes, x, M):
            super().__init__()
            self.nodes = nodes
            self.x = Param("x", x)
            self.M = Param("M", M)

        @active_cache  # Notice active_cache is placed at the top
        @jax.jit
        @forward
        def compute_intrinsic_sed(self, w, x, M):
            print("Computing SED...")
            return jnp.interp(w, self.nodes, x * M)

        @forward
        def compute_flux(self, wavelengths):
            sed = self.compute_intrinsic_sed(wavelengths)  # Cached after first call
            flux = jnp.sum(sed)
            sed = self.compute_intrinsic_sed(wavelengths)  # Returns cached result, no print
            peak = jnp.max(sed)
            return flux, peak

    model = FluxModel(np.linspace(400, 700, 10), x=1.0, M=np.random.rand(10))
    wavelengths = np.linspace(400, 700, 100)  # example wavelength grid

    # compute_flux only calls compute_intrinsic_sed once due to caching
    flux, peak = model.compute_flux(wavelengths)

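The caching semantics can be sketched in plain Python: a descriptor-based decorator that stores the first result per instance while a hypothetical `active` flag is set, and discards it when the run ends. This illustrates the behaviour described above, including the pitfall that later arguments are ignored; it is not caskade's implementation, and the `active` attribute, `deactivate` method, and cache key are assumptions.

```python
import functools


class first_result_cache:
    """Hypothetical decorator: cache the first result while `obj.active` is True."""

    def __init__(self, func):
        self.func = func
        functools.update_wrapper(self, func)

    def __get__(self, obj, objtype=None):
        if obj is None:
            return self

        @functools.wraps(self.func)
        def wrapper(*args, **kwargs):
            if not getattr(obj, "active", False):
                return self.func(obj, *args, **kwargs)  # no caching outside a run
            cache = obj.__dict__.setdefault("_first_result_cache", {})
            if self.func.__name__ not in cache:
                cache[self.func.__name__] = self.func(obj, *args, **kwargs)
            # Later calls ignore their arguments, mirroring the warning above
            return cache[self.func.__name__]

        return wrapper


class Sim:
    def __init__(self):
        self.active = False
        self.calls = 0

    def deactivate(self):
        # End of a simulation run: drop any cached results
        self.active = False
        self.__dict__.pop("_first_result_cache", None)

    @first_result_cache
    def expensive(self, x):
        self.calls += 1
        return x * 2


s = Sim()
s.active = True
print(s.expensive(1), s.expensive(100), s.calls)  # 2 2 1
```

Note how the second call returns the stale result for x=1 even though it was asked for x=100; this is exactly why the decorator should only wrap methods whose result does not depend on their arguments within one run.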
- caskade.decorators.forward(method)
Decorator to define a forward method for a module.
- Parameters:
method (Callable) – The forward method to be decorated.
Examples
Standard usage of the forward decorator:

    from caskade import Module, Param
    from caskade.decorators import forward

    class ExampleSim(Module):
        def __init__(self, a, b, c):
            super().__init__("example_sim")
            self.a = a
            self.b = Param("b", b)
            self.c = Param("c", c)

        @forward
        def example_func(self, x, b=None):
            return x + self.a + b

    E = ExampleSim(a=1, b=None, c=3)
    print(E.example_func(4, params=[5]))  # Output: 10

- Returns:
The decorated forward method.
- Return type:
Callable
caskade.errors module
- exception caskade.errors.ActiveStateError
Bases: CaskadeException
Class for exceptions related to the active state of a node in caskade.
- exception caskade.errors.BackendError
Bases: CaskadeException
Class for exceptions related to the backend in caskade.
- exception caskade.errors.CaskadeException
Bases: Exception
Base class for all exceptions in caskade.
- exception caskade.errors.FillParamsArrayError(name, input_params, params)
Bases: FillParamsError
Class for exceptions related to filling parameters with ArrayLike objects in caskade.
- exception caskade.errors.FillParamsError
Bases: CaskadeException
Class for exceptions related to filling parameters in caskade.
- exception caskade.errors.FillParamsMappingError(name, children, missing_key=None)
Bases: FillParamsError
Class for exceptions related to filling parameters with a mapping (dict) in caskade.
- exception caskade.errors.FillParamsSequenceError(name, input_params, dynamic_params)
Bases: FillParamsError
Class for exceptions related to filling parameters with a sequence (list, tuple, etc.) in caskade.
- exception caskade.errors.GraphError
Bases: CaskadeException
Class for graph exceptions in caskade.
- exception caskade.errors.LinkToAttributeError
Bases: GraphError
Class for exceptions related to linking to an attribute in caskade.
- exception caskade.errors.NodeConfigurationError
Bases: CaskadeException
Class for node configuration exceptions in caskade.
- exception caskade.errors.ParamConfigurationError
Bases: NodeConfigurationError
Class for parameter configuration exceptions in caskade.
- exception caskade.errors.ParamTypeError
Bases: CaskadeException
Class for exceptions related to the type of a parameter in caskade.
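Because every error above derives from CaskadeException, callers can catch the whole family or a narrower branch. The sketch mirrors part of the documented hierarchy with stand-in classes that take no constructor arguments (the real FillParamsSequenceError, for example, takes name, input_params, dynamic_params), so it runs without caskade installed; with caskade itself you would import the real classes from caskade.errors instead.

```python
class CaskadeException(Exception):
    """Stand-in for caskade.errors.CaskadeException."""


class GraphError(CaskadeException):
    """Stand-in for caskade.errors.GraphError."""


class LinkToAttributeError(GraphError):
    """Stand-in for caskade.errors.LinkToAttributeError."""


class FillParamsError(CaskadeException):
    """Stand-in for caskade.errors.FillParamsError."""


class FillParamsSequenceError(FillParamsError):
    """Stand-in for caskade.errors.FillParamsSequenceError."""


def classify(exc: Exception) -> str:
    """Report the most specific documented category that catches `exc`."""
    try:
        raise exc
    except GraphError:
        return "graph"
    except FillParamsError:
        return "fill-params"
    except CaskadeException:
        return "other caskade error"
    except Exception:
        return "not from caskade"


print(classify(LinkToAttributeError()))  # graph
print(classify(ValueError()))  # not from caskade
```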
caskade.mixins module
- class caskade.mixins.GetSetValues
Bases: object
- find_index(param: Param | tuple[Param] | Module, scheme: str = 'array') int | slice
Identify what index is associated with a param in the dynamic params array.
- Parameters:
- Returns:
param_info – An int giving the index associated with the provided Param object. If the param is multi-dimensional, the result is a slice over all indices associated with that param.
- Return type:
Union[int, slice]
- find_param(idx: int | tuple[int], group: int | None = None, scheme: str = 'array') tuple[Param, tuple[int]]
Identify which param is associated with the provided index in the dynamic params array.
- Parameters:
idx (Union[int, tuple[int]]) – The index in the params array at which we wish to find the associated param.
group (Optional[int]) – If the dynamic params have multiple group values, then this argument specifies which group to check.
scheme (str) – Whether to search the array (default) or the list version of the params. The dict scheme is currently unsupported.
- Returns:
param_info – A tuple with the Param object and the index within the Param value associated with idx (empty tuple if scalar). If idx is a tuple then the result is a tuple of these results.
- Return type:
tuple[Param, tuple[int]]
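The relationship between find_index and find_param can be sketched with plain Python and NumPy: dynamic params are laid out back to back in a flat array, a scalar owns one integer position, and a multi-element param owns a contiguous slice. The ordering, the dict of shapes, and the helper names are assumptions for illustration; the sketch ignores groups and the list scheme.

```python
from math import prod

import numpy as np

# Hypothetical ordered dynamic params and their shapes
shapes = {"b": (), "c": (2, 2), "d": (3,)}


def build_index(shapes):
    """Map each param name to an int (scalar) or a slice (multi-element) in the flat array."""
    index, offset = {}, 0
    for name, shape in shapes.items():
        size = prod(shape)  # prod(()) == 1 for scalars
        index[name] = offset if shape == () else slice(offset, offset + size)
        offset += size
    return index, offset


def find_param(i, shapes):
    """Inverse lookup: flat index -> (param name, index within that param's value)."""
    offset = 0
    for name, shape in shapes.items():
        size = prod(shape)
        if i < offset + size:
            inner = tuple(int(k) for k in np.unravel_index(i - offset, shape)) if shape else ()
            return name, inner
        offset += size
    raise IndexError(i)


index, total = build_index(shapes)
params = np.arange(float(total))
print(index["c"], params[index["c"]].reshape(2, 2).tolist())  # slice(1, 5, None) [[1.0, 2.0], [3.0, 4.0]]
print(find_param(3, shapes))  # ('c', (1, 0))
```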
- from_valid(valid_params: Annotated[Tensor, 'One of: torch.Tensor, numpy.ndarray, jax.numpy.ndarray depending on the chosen backend.'] | Sequence | Mapping, param_list=None, group=None) Annotated[Tensor, 'One of: torch.Tensor, numpy.ndarray, jax.numpy.ndarray depending on the chosen backend.'] | Sequence | Mapping
Convert valid params to input params.
- get_values(scheme='array', dynamic=True, attribute='value', group: int | None = None) Annotated[Tensor, 'One of: torch.Tensor, numpy.ndarray, jax.numpy.ndarray depending on the chosen backend.'] | list[Annotated[Tensor, 'One of: torch.Tensor, numpy.ndarray, jax.numpy.ndarray depending on the chosen backend.']] | dict[str, dict | Annotated[Tensor, 'One of: torch.Tensor, numpy.ndarray, jax.numpy.ndarray depending on the chosen backend.']]
- set_values(params: Annotated[Tensor, 'One of: torch.Tensor, numpy.ndarray, jax.numpy.ndarray depending on the chosen backend.'] | Sequence | Mapping, dynamic=True, attribute='value')
Fill the dynamic values of the module with the input values from params.
- to_valid(params: Annotated[Tensor, 'One of: torch.Tensor, numpy.ndarray, jax.numpy.ndarray depending on the chosen backend.'] | Sequence | Mapping, param_list=None, group=None) Annotated[Tensor, 'One of: torch.Tensor, numpy.ndarray, jax.numpy.ndarray depending on the chosen backend.'] | Sequence | Mapping
Convert input params to valid params.
- property valid_context: bool
Return True if the module is in a valid context.
caskade.module module
- class caskade.module.Module(name: str | None = None, **kwargs)
Bases: Node, GetSetValues
Node to represent a simulation module in the graph.
The Module object is used to represent a simulation module in the graph. These are Python objects that contain the calculations for a simulation; they also hold the Param objects that are used in the calculations. The Module object has additional functionality to manage the Param objects below it in the graph: it keeps track of all dynamic Param objects so that their values may be filled at runtime. The Module object manages its links to other nodes through attributes of the class.
Examples
Example of a nested pair of Module objects and how their @forward methods are called:

    import torch
    from caskade import Module, Param
    from caskade.decorators import forward

    class MySim(Module):
        def __init__(self, a, b=None):
            super().__init__()
            self.a = a
            self.b = Param("b", b)

        @forward
        def myfunc(self, x, b=None):
            return x * self.a.otherfun(x) + b

    class OtherSim(Module):
        def __init__(self, c=None):
            super().__init__()
            self.c = Param("c", c)

        @forward
        def otherfun(self, x, c=None):
            return x + c

    othersim = OtherSim()
    mysim = MySim(a=othersim)
    # params are given in the order [b, c]
    params = [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])]
    result = mysim.myfunc(3.0, params=params)
    # result is tensor([19.0, 23.0])

- property all_params
- clear_state()
Clear the active state _value for all params below this Module in the DAG. This should not be used by a user under normal circumstances.
- property dynamic: bool
Return True if the module has dynamic parameters as direct children.
- fill_kwargs(keys: tuple[str]) dict[str, Annotated[Tensor, 'One of: torch.Tensor, numpy.ndarray, jax.numpy.ndarray depending on the chosen backend.']]
Fill the kwargs for an @forward method with the values of the dynamic parameters. The requested keys are matched to the names of Param objects owned by the Module. This should not be used by the user under normal circumstances.
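The name-matching idea behind fill_kwargs can be illustrated with inspect.signature: keyword parameters of the forward method whose names match a param are filled in, and anything the caller passed explicitly is left alone. The helper below, its precedence rule for explicit arguments, and the plain dict of values are assumptions; caskade's own logic may differ.

```python
import inspect


def fill_kwargs(func, param_values: dict, kwargs: dict) -> dict:
    """Add values for parameters of `func` that match a param name and were not passed."""
    filled = dict(kwargs)
    for key in inspect.signature(func).parameters:
        if key in param_values and key not in filled:
            filled[key] = param_values[key]
    return filled


def example_func(x, b=None):
    return x + b


# "c" is not an argument of example_func, so only "b" is filled
kwargs = fill_kwargs(example_func, {"b": 5, "c": 3}, {})
print(kwargs)  # {'b': 5}
print(example_func(4, **kwargs))  # 9
```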
- fill_params(params: Annotated[Tensor, 'One of: torch.Tensor, numpy.ndarray, jax.numpy.ndarray depending on the chosen backend.'] | Sequence | Mapping, dynamic=True)
Fill the dynamic/static parameters of the module with the input values from params.
- Parameters:
params (Union[ArrayLike, Sequence, Mapping]) – The input values to fill the dynamic parameters with. The input can be an ArrayLike, a Sequence, or a Mapping.
dynamic (bool) – Operate on dynamic parameters (True, default) or static parameters (False).
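As a sketch of how the three accepted input forms can describe the same values, assume two dynamic params, b (scalar) and c (length 2), in that order: a flat array is split by size, a sequence supplies one entry per param, and a mapping supplies values by name. The helper normalize_params is hypothetical; caskade's mapping form follows node names in the graph (see FillParamsMappingError) and may be nested, which this flat version does not model.

```python
from collections.abc import Mapping

import numpy as np

# Hypothetical dynamic params in graph order, with their shapes
shapes = {"b": (), "c": (2,)}


def normalize_params(params, shapes):
    """Convert a flat array, a sequence, or a mapping into {name: value} (sketch)."""
    if isinstance(params, Mapping):
        # Mapping: values looked up by param name
        missing = [name for name in shapes if name not in params]
        if missing:
            raise KeyError(f"missing params: {missing}")
        return {name: np.asarray(params[name]) for name in shapes}
    if isinstance(params, np.ndarray):
        # Flat array: consume consecutive elements for each param, in order
        if params.ndim != 1:
            raise ValueError("this sketch only handles unbatched 1-D arrays")
        out, offset = {}, 0
        for name, shape in shapes.items():
            size = int(np.prod(shape))
            out[name] = params[offset:offset + size].reshape(shape)
            offset += size
        if offset != params.size:
            raise ValueError("array length does not match the dynamic values")
        return out
    # Sequence: exactly one entry per param, in order
    if len(params) != len(shapes):
        raise ValueError("sequence length does not match the dynamic params")
    return {name: np.asarray(value) for name, value in zip(shapes, params)}


for form in (np.array([1.0, 2.0, 3.0]), [1.0, [2.0, 3.0]], {"b": 1.0, "c": [2.0, 3.0]}):
    print(normalize_params(form, shapes))
```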
- property graphviz_style
- property node_str: str
Returns a string representation of the node for graph visualization.
- property static: bool
- to_dynamic(children_only=True)
Change all parameters to dynamic parameters.
- Parameters:
children_only (bool, optional) – If True, only convert the children of this module to dynamic. If False, convert all parameters in the graph below this module. Defaults to True.
caskade.param module
- class caskade.param.Param(name: str, value: Annotated[Tensor, 'One of: torch.Tensor, numpy.ndarray, jax.numpy.ndarray depending on the chosen backend.'] | float | int | None = None, shape: tuple[int, ...] | None = None, cyclic: bool = False, valid: tuple[Annotated[Tensor, 'One of: torch.Tensor, numpy.ndarray, jax.numpy.ndarray depending on the chosen backend.'] | float | int | None] | None = None, units: str | None = None, dynamic: bool | None = None, group: int = 0, batch_shape: tuple[int] | None = None, dtype: Any | None = None, device: Any | None = None, **kwargs)
Bases: Node
Node to represent a parameter in the graph.
The Param object is used to represent a parameter in the graph. During runtime this will represent a value which can be used in various calculations. The Param object can be set to a constant value (static); None, meaning the value is to be provided at runtime (dynamic); another Param object, meaning it will take on that value at runtime (pointer); or a function of other Param objects to be computed at runtime (also pointer, see user guides). These options allow users to flexibly set the behavior of the simulator.
Examples
Example making some Param objects:

    from math import pi
    from caskade import Param

    p1 = Param("test", (1.0, 2.0))  # constant value, length 2 vector
    p2 = Param("p2", None, (2, 2))  # dynamic 2x2 matrix value
    p3 = Param("p3", p1)  # pointer to another parameter
    p4 = Param("p4", lambda p: p.children["other"].value * 2)  # arbitrary function of another parameter
    p5 = Param("p5", valid=(0.0, 2 * pi), units="radians", cyclic=True)  # parameter with metadata

- Parameters:
name (str) – The name of the parameter.
value (Optional[Union[ArrayLike, float, int]], optional) – The value of the parameter. Defaults to None, meaning dynamic.
shape (Optional[tuple[int, ...]], optional) – The shape of the parameter. Defaults to (), meaning scalar.
cyclic (bool, optional) – Whether the parameter is cyclic, imposing periodic boundary conditions, such as a rotation from 0 to 2pi. Defaults to False.
valid (Optional[tuple[Union[ArrayLike, float, int, None]]], optional) – The valid range of the parameter. Defaults to None, meaning all of -inf to inf is valid.
units (Optional[str], optional) – The units of the parameter. Defaults to None.
dynamic (bool, optional) – Force the param to be dynamic if True. If a value is provided and the param is dynamic, then it has a default value at call time.
batched (bool, optional) – If True, the param is assumed batched and the shape may now take the form (*B, *D), where *B is the batch shape and *D is the shape of the value.
dtype (Optional[Any], optional) – The data type of the parameter. Defaults to None, meaning the data type will be inferred from the value.
device (Optional[Any], optional) – The device of the parameter. Defaults to None, meaning the device will be inferred from the value.
- property batch_shape: tuple[int, ...]
- property batched: bool
- property cyclic: bool
- property device: str | None
- property dtype: str | None
- property dynamic: bool
- property graphviz_style
- property group: int
- is_valid(value=None) bool
Check if a given value is valid given this parameter's allowed (valid) range.
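The valid/cyclic metadata can be illustrated with two small functions: a range check where None means an open bound, and a modulo wrap that maps a value back into a periodic range such as (0, 2pi). Treating bounds as inclusive and wrapping into [low, high) are assumptions for this sketch, not documented caskade behaviour.

```python
import math


def is_valid(value, valid=None):
    """Return True if value lies inside (low, high); a None bound is open."""
    if valid is None:
        return True  # no range given: all of -inf to inf is valid
    low, high = valid
    if low is not None and value < low:
        return False
    if high is not None and value > high:
        return False
    return True


def wrap_cyclic(value, valid):
    """Map value into [low, high) with periodic boundary conditions."""
    low, high = valid
    return low + (value - low) % (high - low)


print(is_valid(3.0, (0.0, 2 * math.pi)))  # True
print(is_valid(7.0, (0.0, 2 * math.pi)))  # False
print(round(wrap_cyclic(7.0, (0.0, 2 * math.pi)), 4))  # 0.7168
```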
- property node_str: str
Returns a string representation of the node for graph visualization.
- property node_type
- property npvalue: ndarray
- property pointer: bool
- property shape: tuple[int, ...]
- property static: bool
- to(device=None, dtype=None) Param
Moves and/or casts the values of the parameter.
- Parameters:
device (Optional[torch.device], optional) – The device to move the values to. Defaults to None.
dtype (Optional[torch.dtype], optional) – The desired data type. Defaults to None.
- to_dynamic(value=<object object>)
Change this parameter to a dynamic parameter. If a value is provided, it will be set as the dynamic value.
- to_pointer(value, link=())
Change this parameter to a pointer parameter. If a value is provided, it will be set as the pointer: either provide a Param object to point to its value, or provide a callable to be called at runtime. It is also possible to provide a tuple of nodes to link to while creating the pointer.
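The four behaviours of a Param (static, dynamic, pointer to a Param, pointer to a function) can be modelled with a small value-resolution property. The ToyParam class, its link method, and the _runtime slot are hypothetical; the lambda mirrors the p4 example in the Param class documentation, including the need to link the referenced param so that p.children["other"] exists.

```python
class ToyParam:
    """Hypothetical param that is static, dynamic, or a pointer (Param or callable)."""

    def __init__(self, name, value=None):
        self.name = name
        self.children = {}
        self._value = value
        self._runtime = None  # value supplied at call time for dynamic params

    def link(self, other):
        self.children[other.name] = other

    @property
    def value(self):
        if isinstance(self._value, ToyParam):
            return self._value.value  # pointer to another param
        if callable(self._value):
            return self._value(self)  # pointer computed from linked params
        if self._value is None:
            return self._runtime  # dynamic: filled at runtime
        return self._value  # static constant


other = ToyParam("other", 1.5)
p3 = ToyParam("p3", other)
p4 = ToyParam("p4", lambda p: p.children["other"].value * 2)
p4.link(other)
print(p3.value, p4.value)  # 1.5 3.0
```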
- to_static(value=<object object>)
Change this parameter to a static parameter. If a value is provided, it will be set as the static value.
- property valid: tuple[Annotated[Tensor, 'One of: torch.Tensor, numpy.ndarray, jax.numpy.ndarray depending on the chosen backend.'] | None, Annotated[Tensor, 'One of: torch.Tensor, numpy.ndarray, jax.numpy.ndarray depending on the chosen backend.'] | None]
- property value: Annotated[Tensor, 'One of: torch.Tensor, numpy.ndarray, jax.numpy.ndarray depending on the chosen backend.'] | None
caskade.tests module
caskade.utils module
- caskade.utils.broadcast_cat_jax(arrays, dim=-1)
Concatenates JAX arrays with broadcasting.
Behaves like jnp.concatenate, but first broadcasts the arrays to match on all dimensions EXCEPT the concatenation dimension.
- Parameters:
arrays (sequence of jnp.ndarray) – Arrays to concatenate.
dim (int) – The dimension along which to concatenate.
- Returns:
The concatenated array.
- Return type:
jnp.ndarray
- caskade.utils.broadcast_cat_numpy(arrays, dim=-1)
Concatenates NumPy arrays with broadcasting.
Behaves like np.concatenate, but first broadcasts the arrays to match on all dimensions EXCEPT the concatenation dimension.
- Parameters:
arrays (sequence of np.ndarray) – Arrays to concatenate.
dim (int) – The dimension along which to concatenate.
- Returns:
The concatenated array.
- Return type:
np.ndarray
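The broadcasting concatenation described for all three backends can be sketched in NumPy: pad every input to the same rank, broadcast all axes except the concatenation axis, then concatenate. This is an independent sketch under those assumptions (it also adopts the negative-dim requirement stated for broadcast_cat_torch), not caskade's implementation.

```python
import numpy as np


def broadcast_cat(arrays, dim=-1):
    """Concatenate along `dim` after broadcasting every other dimension (sketch)."""
    arrays = [np.asarray(a) for a in arrays]
    ndim = max(a.ndim for a in arrays)
    if not -ndim <= dim < 0:
        raise ValueError("use a negative dim within the largest rank")
    # Left-pad shapes with singleton axes so all arrays share the same rank
    arrays = [a.reshape((1,) * (ndim - a.ndim) + a.shape) for a in arrays]
    axis = ndim + dim
    # Broadcast shape of the non-concatenation axes (concat axis set to 1)
    others = [a.shape[:axis] + (1,) + a.shape[axis + 1:] for a in arrays]
    target = np.broadcast_shapes(*others)
    expanded = [
        np.broadcast_to(a, target[:axis] + (a.shape[axis],) + target[axis + 1:])
        for a in arrays
    ]
    return np.concatenate(expanded, axis=axis)


a = np.ones((3, 2))
b = np.zeros((1,))
out = broadcast_cat([a, b])
print(out.shape)  # (3, 3)
```

Requiring a negative dim keeps "the last axis" meaning the same thing for inputs of different ranks once they are left-padded.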
- caskade.utils.broadcast_cat_torch(tensors, dim=-1)
Concatenates tensors with broadcasting.
It behaves like torch.cat, but first broadcasts the tensors to match on all dimensions EXCEPT the concatenation dimension.
- Parameters:
tensors (sequence of Tensors) – Tensors to concatenate.
dim (int) – The dimension along which to concatenate. Must be a negative index to ensure consistency across tensors of different ranks (e.g., -1 for the last dimension).
- Returns:
The concatenated tensor.
- Return type:
Tensor
caskade.warnings module
- exception caskade.warnings.InvalidValueWarning(name, value, valid)
Bases: CaskadeWarning
Warning for values which fall outside the valid range.
- exception caskade.warnings.SaveStateWarning
Bases: CaskadeWarning
Warning for when an issue occurs while a state is saved.
Module contents
- class caskade.ActiveContext(module: Module)
Documented above as caskade.context.ActiveContext.
- exception caskade.ActiveStateError
Documented above as caskade.errors.ActiveStateError.
- exception caskade.BackendError
Documented above as caskade.errors.BackendError.
- exception caskade.CaskadeException
Documented above as caskade.errors.CaskadeException.
- exception caskade.FillParamsArrayError(name, input_params, params)
Documented above as caskade.errors.FillParamsArrayError.
- exception caskade.FillParamsError
Documented above as caskade.errors.FillParamsError.
- exception caskade.FillParamsMappingError(name, children, missing_key=None)
Documented above as caskade.errors.FillParamsMappingError.
- exception caskade.FillParamsSequenceError(name, input_params, dynamic_params)
Documented above as caskade.errors.FillParamsSequenceError.
- exception caskade.GraphError
Documented above as caskade.errors.GraphError.
- exception caskade.InvalidValueWarning(name, value, valid)
Documented above as caskade.warnings.InvalidValueWarning.
- exception caskade.LinkToAttributeError
Documented above as caskade.errors.LinkToAttributeError.
- class caskade.Memo(module: Node, memo: str)
Bases: object
Sends a “memo” (a small message) to all nodes below the current one in the graph. This can be used to communicate state changes in the graph to all lower nodes. By default, the message will skip any subgraphs (hierarchical graphs), but this can be changed to ensure all nodes hear the message.
Note that memos are stored as a Python set, so duplicates will be merged. Depending on your use case, it may be wise to ensure that your memo is unique.
- Parameters:
module (Module) – The caskade Module object that will propagate the memo.
memo (str) – The message to send down the graph.
skip_subgraphs (bool) – If True (default), any subgraphs, otherwise known as hierarchical graphs, will not get the memo.
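The propagation rule can be sketched as a graph walk that adds the memo to a set on every node below the sender and, by default, does not descend into hierarchical links. The ToyNode class (with separate children and subgraphs lists) and the send_memo function are hypothetical stand-ins, not caskade's Memo implementation.

```python
class ToyNode:
    """Hypothetical node with ordinary and hierarchical (subgraph) children."""

    def __init__(self, name):
        self.name = name
        self.children = []  # ordinary links
        self.subgraphs = []  # hierarchical links
        self.memos = set()


def send_memo(node, memo, skip_subgraphs=True):
    """Add memo to every node below `node`, optionally skipping hierarchical subgraphs."""
    stack = list(node.children) + ([] if skip_subgraphs else list(node.subgraphs))
    seen = set()
    while stack:
        n = stack.pop()
        if id(n) in seen:
            continue  # a node reachable by several paths hears the memo once
        seen.add(id(n))
        n.memos.add(memo)  # stored in a set, so repeated memos merge
        stack.extend(n.children)
        if not skip_subgraphs:
            stack.extend(n.subgraphs)


root, a, sub = ToyNode("root"), ToyNode("a"), ToyNode("sub")
root.children.append(a)
root.subgraphs.append(sub)
send_memo(root, "reset")
print(sorted(a.memos), sorted(sub.memos))  # ['reset'] []
```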
- class caskade.Module(name: str | None = None, **kwargs)[source]#
Bases:
Node,GetSetValuesNode to represent a simulation module in the graph.
The
Moduleobject is used to represent a simulation module in the graph. These are python objects that contain the calculations for a simulation, they also hold theParamobjects that are used in the calculations. TheModuleobject has additional functionality to manage theParamobjects below it in the graph, it keeps track of alldynamicParamobjects so that at runtime their values may be filled. TheModuleobject manages its links to other nodes through attributes of the class.Examples
Example of a nested pair of
Moduleobjects and how their@forwardmethods are called:class MySim(Module): def __init__(self, a, b=None): super().__init__() self.a = a self.b = Param("b", b) @forward def myfunc(self, x, b=None): return x * self.a.otherfun(x) + b class OtherSim(Module): def __init__(self, c=None): super().__init__() self.c = Param("c", c) @forward def otherfun(self, x, c = None): return x + c othersim = OtherSim() mysim = MySim(a=othersim) # b c params = [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])] result = mysim.myfunc(3.0, params=params) # result is tensor([19.0, 23.0])
- property all_params#
- clear_state()[source]#
Clear the active state _value for all params below this Module in the DAG. This should not be used by a user under normal circumstances.
- property dynamic: bool#
Return True if the module has dynamic parameters as direct children.
- fill_kwargs(keys: tuple[str]) dict[str, Annotated[Tensor, 'One of: torch.Tensor, numpy.ndarray, jax.numpy.ndarray depending on the chosen backend.']][source]#
Fill the kwargs for an
@forwardmethod with the values of the dynamic parameters. The requested keys are matched to names ofParamobjects owned by theModule. This should not be used by the user under normal circumstances.
- fill_params(params: Annotated[Tensor, 'One of: torch.Tensor, numpy.ndarray, jax.numpy.ndarray depending on the chosen backend.'] | Sequence | Mapping, dynamic=True)[source]#
Fill the dynamic/static parameters of the module with the input values from params.
- Parameters:
params ((Union[ArrayLike, Sequence, Mapping])) – The input values to fill the dynamic parameters with. The input can be an ArrayLike, a Sequence, or a Mapping.
dynamic (bool) – Operate on dynamic parameters (True, default) or static parameters (False).
- property graphviz_style#
- property node_str: str#
Returns a string representation of the node for graph visualization.
- property static: bool#
- to_dynamic(children_only=True)[source]#
Change all parameters to dynamic parameters.
- Parameters:
children_only ((bool, optional)) – If True, only convert the children of this module to dynamic. If False, convert all parameters in the graph below this module. Defaults to True.
- class caskade.Node(name: str | None = None, link: Node | tuple[Node] | None = None, description: str = '')[source]#
Bases:
objectBase graph node class for
caskadeobjects.The
Nodeobject is the base class for allcaskadeobjects. It is used to construct the directed acyclic graph (DAG). The primary function of theNodeobject is to manage the parent-child relationships between nodes in the graph. There is limited functionality for theNodeobject, though it implements the base versions of theactivestate andto/update_graphmethods. Theactivestate is used to communicate through the graph that the simulator is currently running. Thetomethod is used to move and/or cast the values of the parameter. Theupdate_graphmethod is used signal all parents that the graph below them has changed.Examples
Example making some
Nodeobjects and then linking/unlinking them:n1 = Node() n2 = Node() n1.link("subnode", n2) # link n2 as a child of n1, may use any str as the key n1.unlink("subnode") # alternately n1.unlink(n2) to unlink by object
- property active: bool#
- append_state(saveto: str | File)[source]#
Append the state of the node and its children to an existing HDF5 file.
- graph_dict() dict[str, dict][source]#
Return a dictionary representation of the graph below the current node.
- graph_print(dag: dict, depth: int = 0, indent: int = 4, result: str = '') str[source]#
Print the graph dictionary in a human-readable format.
- graphviz(saveto: str | None = None) graphviz.Digraph[source]#
Return a graphviz object representing the graph below the current node in the DAG.
- Parameters:
top_down ((bool, optional)) – Whether to draw the graph top-down (current node at top) or bottom-up (current node at bottom). Defaults to True.
saveto ((Optional[str], optional)) – If provided, save the graph to this file. The file extension determines the format (e.g. ‘.pdf’, ‘.png’). Defaults to None.
- property graphviz_style#
- hierarchical_link(key: str, child: Node)[source]#
Link the current
Nodeobject to anotherNodeobject as a child in a hierarchical manner. See link for more detail on linking. A hierarchical link will allow batching internally to the simulator.- Parameters:
key ((str)) – The key to link the child node with.
child ((Node)) – The child
Nodeobject to link to.
- link(key: str | tuple | Node, child: Node | tuple | None = None)[source]#
Link the current
Nodeobject to anotherNodeobject as a child.- Parameters:
key ((Union[str, Node])) – The key to link the child node with. This will also become the attribute to access the child node. After linking you will have node.key == child
child ((Optional[Node], optional)) – The child
Nodeobject to link to. Defaults to None in which case the key is used as the child and the child.name is used as the key.
Examples
Example making some
Nodeobjects and then linking/unlinking them. demonstrating multiple ways to link/unlink:n1 = Node() n2 = Node() n1.link("subnode", n2) # may use any str as the key n1.unlink("subnode") # Alternately, link by object n1.link(n2) n1.unlink(n2)
- load_state(loadfrom: str | File, index: int = -1, **kwargs)[source]#
Load the state of the node and its children.
- property memos: set[str]#
- property name: str#
- property node_str#
- property online: bool#
- save_state(saveto: str | File, appendable: bool = False)[source]#
Save the state of the node and its children, currently only works for HDF5 file types (.h5 and .hdf5).
The “state” of a node is considered to be the value of its params, however it is also possible to save other attributes of the node by adding them to the Node.saveattrs set. Simply call Node.saveattrs.add(‘attribute’) and then Node.attribute will be saved if possible. The HDF5 file will be created with the same structure as the graph, even if there are multiple paths to the same node. For example if N1 has children N2 and N3, and both N2 and N3 have the child N4, the HDF5 file will reflect this. It will be possible to find the N4 params under both ‘N1/N2/N4’ and ‘N1/N3/N4’ if inspecting the HDF5 file manually. Specifically, if N4 has the param P1 then you could access its value like this:
If the save had been set as appendable, then the value will have an extra dimension for the number of samples, this will always be the first dimension. If appendable was false then the value will simply equal the param value.
Note
You need the optional h5py package installed to use this method.
- Parameters:
saveto (Union[str, File]) – The file to save the state to. If a string, it should be the path to an HDF5 file (ending in ‘.h5’ or ‘.hdf5’). If a File object, it should be an open HDF5 file.
appendable (bool, optional) – Whether to save the state in an appendable format. If True, the values will have an extra dimension for the number of samples. Defaults to False.
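To see why N4 appears twice in the file, the sketch below enumerates every root-to-node path in the N1 to N4 example graph. It is plain Python with no h5py dependency; `hdf5_paths` is a hypothetical helper, and the dict-of-children graph is an assumed stand-in for caskade's node links. It does not model how param values are stored inside each group.

```python
def hdf5_paths(graph, root):
    """Return every slash-separated path from root to each reachable node.

    graph maps a node name to the list of its child names (a DAG).
    A node reachable along several routes is listed once per route.
    """
    paths = []

    def walk(node, prefix):
        path = f"{prefix}/{node}" if prefix else node
        paths.append(path)
        for child in graph.get(node, []):
            walk(child, path)  # recurse without de-duplicating shared children

    walk(root, "")
    return paths


dag = {"N1": ["N2", "N3"], "N2": ["N4"], "N3": ["N4"]}
print(hdf5_paths(dag, "N1"))
# ['N1', 'N1/N2', 'N1/N2/N4', 'N1/N3', 'N1/N3/N4']
```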
- to(device=None, dtype=None)[source]#
Moves and/or casts the values of the Node to a particular device and/or dtype.
- Parameters:
device (Optional[torch.device], optional) – The device to move the values to. Defaults to None.
dtype (Optional[torch.dtype], optional) – The desired data type. Defaults to None.
- topological_ordering() tuple[Node][source]#
Return a topological ordering of the graph below the current node. Uses Iterative Deepening DFS (Post-Order) to resolve dependencies.
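For intuition about the ordering, here is a plain recursive post-order DFS on a dict-of-children graph. This is a simpler technique than the iterative-deepening variant named above and is not caskade's code. Each node is emitted only after all of its descendants, so dependencies come first; whether caskade returns this order or its reverse is not stated here.

```python
def topological_ordering(graph, root):
    """Post-order DFS: every node appears after all of its descendants.

    graph maps a node to the list of its children (assumed acyclic).
    """
    order, visited = [], set()

    def visit(node):
        if node in visited:  # shared children are emitted only once
            return
        visited.add(node)
        for child in graph.get(node, []):
            visit(child)     # resolve dependencies (children) first
        order.append(node)   # post-order: emit the node after its children

    visit(root)
    return tuple(order)


dag = {"N1": ["N2", "N3"], "N2": ["N4"], "N3": ["N4"]}
print(topological_ordering(dag, "N1"))  # ('N4', 'N2', 'N3', 'N1')
```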
- class caskade.NodeCollection(name: str | None = None, link: Node | tuple[Node] | None = None, description: str = '')[source]#
Bases: Node, GetSetValues
- property dynamic#
- property dynamic_param_groups: tuple[int]#
- property static#
- exception caskade.NodeConfigurationError[source]#
Bases: CaskadeException
Class for node configuration exceptions in caskade.
- class caskade.NodeList(iterable=(), name=None)[source]#
Bases: NodeCollection, list
- property graphviz_style#
- class caskade.NodeTuple(iterable=None, name=None)[source]#
Bases: NodeCollection, tuple
- property graphviz_style#
- class caskade.OverrideParam(param: Param, value)[source]#
Bases: object
Context manager to override a parameter value. Only inside an OverrideParam will the parameter be set to the new value.
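The documented behaviour (the new value applies only inside the context) is the classic save, replace, restore pattern. Below is a library-independent sketch of that pattern using `contextlib.contextmanager`; `SimpleParam` and `override_param` are hypothetical names, and this is not caskade's implementation.

```python
from contextlib import contextmanager


class SimpleParam:
    """Hypothetical stand-in for caskade.Param holding a single value."""

    def __init__(self, name, value):
        self.name = name
        self.value = value


@contextmanager
def override_param(param, value):
    """Temporarily replace param.value, restoring the original on exit."""
    original = param.value
    param.value = value
    try:
        yield param
    finally:
        param.value = original  # restored even if the with-block raises


p = SimpleParam("x", 1.0)
with override_param(p, 5.0):
    print(p.value)  # 5.0
print(p.value)      # 1.0
```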
- class caskade.Param(name: str, value: Annotated[Tensor, 'One of: torch.Tensor, numpy.ndarray, jax.numpy.ndarray depending on the chosen backend.'] | float | int | None = None, shape: tuple[int, ...] | None = None, cyclic: bool = False, valid: tuple[Annotated[Tensor, 'One of: torch.Tensor, numpy.ndarray, jax.numpy.ndarray depending on the chosen backend.'] | float | int | None] | None = None, units: str | None = None, dynamic: bool | None = None, group: int = 0, batch_shape: tuple[int] | None = None, dtype: Any | None = None, device: Any | None = None, **kwargs)[source]#
Bases: Node
Node to represent a parameter in the graph.
The Param object is used to represent a parameter in the graph. During runtime this will represent a value which can be used in various calculations. The Param object can be set to a constant value (static); None, meaning the value is to be provided at runtime (dynamic); another Param object, meaning it will take on that value at runtime (pointer); or a function of other Param objects to be computed at runtime (also pointer, see user guides). These options allow users to flexibly set the behavior of the simulator.
Examples
Example making some Param objects:

    from math import pi
    from caskade import Param

    p1 = Param("test", (1.0, 2.0))  # constant value, length 2 vector
    p2 = Param("p2", None, (2, 2))  # dynamic 2x2 matrix value
    p3 = Param("p3", p1)  # pointer to another parameter
    p4 = Param("p4", lambda p: p.children["other"].value * 2)  # arbitrary function of another parameter
    p5 = Param("p5", valid=(0.0, 2 * pi), units="radians", cyclic=True)  # parameter with metadata
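The four modes described above can be illustrated with a toy resolver. This is a hypothetical sketch, not caskade's code: `ToyParam`, its `mode` property and its `resolve(inputs)` method are invented for illustration, and here a function pointer receives the already-resolved values of its linked children rather than the param object used in the example above.

```python
class ToyParam:
    """Hypothetical stand-in for caskade.Param illustrating its value modes."""

    def __init__(self, name, value=None, children=None):
        self.name = name
        self.value = value  # constant, None, another ToyParam, or a callable
        self.children = dict(children or {})

    @property
    def mode(self):
        if self.value is None:
            return "dynamic"
        if isinstance(self.value, ToyParam) or callable(self.value):
            return "pointer"
        return "static"

    def resolve(self, inputs):
        """Return the runtime value; inputs maps dynamic param names to values."""
        if self.mode == "dynamic":
            return inputs[self.name]            # supplied at call time
        if isinstance(self.value, ToyParam):
            return self.value.resolve(inputs)   # pointer to another param
        if callable(self.value):
            # function pointer: evaluate on the resolved values of linked children
            kids = {k: c.resolve(inputs) for k, c in self.children.items()}
            return self.value(kids)
        return self.value                       # static constant


p1 = ToyParam("p1", 2.0)                                                     # static
p2 = ToyParam("p2")                                                          # dynamic
p3 = ToyParam("p3", p1)                                                      # pointer
p4 = ToyParam("p4", lambda kids: kids["other"] * 2, children={"other": p2})  # function
inputs = {"p2": 3.0}
print([p.mode for p in (p1, p2, p3, p4)])  # ['static', 'dynamic', 'pointer', 'pointer']
print(p3.resolve(inputs), p4.resolve(inputs))  # 2.0 6.0
```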
- Parameters:
name (str) – The name of the parameter.
value (Optional[Union[ArrayLike, float, int]], optional) – The value of the parameter. Defaults to None, meaning dynamic.
shape (Optional[tuple[int, ...]], optional) – The shape of the parameter. Defaults to (), meaning scalar.
cyclic (bool, optional) – Whether the parameter is cyclic, imposing periodic boundary conditions, such as a rotation from 0 to 2pi. Defaults to False.
valid (Optional[tuple[Union[ArrayLike, float, int, None]]], optional) – The valid range of the parameter. Defaults to None, meaning the whole range from -inf to inf is valid.
units (Optional[str], optional) – The units of the parameter. Defaults to None.
dynamic (bool, optional) – Force the param to be dynamic if True. If a value is provided and the param is dynamic, that value serves as a default at call time.
batched (bool, optional) – If True, the param is assumed batched and the shape may take the form (*B, *D), where *B are the batch dimensions and *D is the shape of the value.
dtype (Optional[Any], optional) – The data type of the parameter. Defaults to None, meaning the data type will be inferred from the value.
device (Optional[Any], optional) – The device of the parameter. Defaults to None, meaning the device will be inferred from the value.
- property batch_shape: tuple[int, ...]#
- property batched: bool#
- property cyclic: bool#
- property device: str | None#
- property dtype: str | None#
- property dynamic: bool#
- property graphviz_style#
- property group: int#
- is_valid(value=None) bool[source]#
Check whether a given value lies within this parameter's allowed (valid) range.
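A minimal NumPy sketch of a range check in the spirit of `is_valid`, together with periodic wrapping for a `cyclic` parameter such as an angle in (0, 2pi). Both functions are hypothetical; treating the bounds as inclusive and wrapping into [low, high) are assumptions, not documented caskade behaviour.

```python
import numpy as np


def is_valid(value, valid=None):
    """Return True if every element of value lies within valid = (low, high).

    valid=None means everything is valid; a None bound means unbounded on
    that side. Bounds are treated as inclusive (an assumption).
    """
    if valid is None:
        return True
    low, high = valid
    value = np.asarray(value)
    if low is not None and np.any(value < low):
        return False
    if high is not None and np.any(value > high):
        return False
    return True


def wrap_cyclic(value, valid):
    """Map value into [low, high) by periodic wrapping, e.g. for an angle."""
    low, high = valid
    return np.mod(np.asarray(value) - low, high - low) + low


print(is_valid(1.0, (0.0, 2.0)), is_valid([-1.0, 1.0], (0.0, None)))  # True False
print(wrap_cyclic(7.0, (0.0, 2 * np.pi)))  # 7 - 2pi, about 0.7168
```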
- property node_str: str#
Returns a string representation of the node for graph visualization.
- property node_type#
- property npvalue: ndarray#
- property pointer: bool#
- property shape: tuple[int, ...]#
- property static: bool#
- to(device=None, dtype=None) Param[source]#
Moves and/or casts the values of the parameter.
- Parameters:
device (Optional[torch.device], optional) – The device to move the values to. Defaults to None.
dtype (Optional[torch.dtype], optional) – The desired data type. Defaults to None.
- to_dynamic(value=<object object>)[source]#
Change this parameter to a dynamic parameter. If a value is provided, this will be set as the dynamic value.
- to_pointer(value, link=())[source]#
Change this parameter to a pointer parameter. The given value is set as the pointer: either provide a Param object to point to its value, or provide a callable function to be called at runtime. It is also possible to provide a tuple of nodes to link to while creating the pointer.
- to_static(value=<object object>)[source]#
Change this parameter to a static parameter. If a value is provided, this will be set as the static value.
- property valid: tuple[Annotated[Tensor, 'One of: torch.Tensor, numpy.ndarray, jax.numpy.ndarray depending on the chosen backend.'] | None, Annotated[Tensor, 'One of: torch.Tensor, numpy.ndarray, jax.numpy.ndarray depending on the chosen backend.'] | None]#
- property value: Annotated[Tensor, 'One of: torch.Tensor, numpy.ndarray, jax.numpy.ndarray depending on the chosen backend.'] | None#
- exception caskade.ParamConfigurationError[source]#
Bases: NodeConfigurationError
Class for parameter configuration exceptions in caskade.
- exception caskade.ParamTypeError[source]#
Bases: CaskadeException
Class for exceptions related to the type of a parameter in caskade.
- exception caskade.SaveStateWarning[source]#
Bases: CaskadeWarning
Warning for when an issue occurs while a state is being saved.
- class caskade.ValidContext(module: Module)[source]#
Bases: object
Context manager to set valid values for parameters. Only inside a ValidContext will parameters automatically be assumed valid.
- class caskade.active_cache(func)[source]#
Bases: object
Caches the first evaluated result of a Module method for the duration of a simulation.
This decorator ensures that an expensive method is executed exactly once per active simulation run. Once calculated, subsequent calls to the decorated method will return the stored value, ignoring any arguments passed to it.
- WARNING:
If the method is called multiple times with different arguments in one simulation, the cached result will still be returned, which may lead to unexpected behavior. Use with caution!
Note
If you are stacking multiple decorators on a method (such as @forward or @jax.jit), @active_cache MUST be the outermost (top) decorator.
- Example::

    import numpy as np
    import jax
    import jax.numpy as jnp
    from caskade import Module, Param, forward, active_cache

    class FluxModel(Module):
        def __init__(self, nodes, x, M):
            super().__init__()
            self.nodes = nodes
            self.x = Param("x", x)
            self.M = Param("M", M)

        @active_cache
        @jax.jit  # Notice active_cache is placed at the top
        @forward
        def compute_intrinsic_sed(self, w, x, M):
            print("Computing SED...")
            return jnp.interp(w, self.nodes, x * M)

        @forward
        def compute_flux(self, wavelengths):
            sed = self.compute_intrinsic_sed(wavelengths)  # Cached after first call
            flux = jnp.sum(sed)
            sed = self.compute_intrinsic_sed(wavelengths)  # Returns cached result, no print
            peak = jnp.max(sed)
            return flux, peak

    model = FluxModel(np.linspace(400, 700, 10), x=1.0, M=np.random.rand(10))
    wavelengths = np.linspace(400, 700, 100)  # example input grid

    # compute_flux only calls compute_intrinsic_sed once due to caching
    flux, peak = model.compute_flux(wavelengths)
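The warning above follows from the caching rule: the first result is stored and the arguments of later calls are ignored. The pure-Python decorator below reproduces just that rule so the pitfall is easy to see. `cache_once` and its `clear()` hook are hypothetical; the hook stands in for the end of an active simulation run, and this sketch keeps a single cache per decorated function rather than being tied to any Module.

```python
import functools


def cache_once(func):
    """Hypothetical cache-the-first-result decorator; later arguments are ignored."""
    sentinel = object()
    state = {"result": sentinel}

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if state["result"] is sentinel:
            state["result"] = func(*args, **kwargs)  # evaluated only once
        return state["result"]

    def clear():
        # Stand-in for the end of an active simulation run
        state["result"] = sentinel

    wrapper.clear = clear
    return wrapper


calls = []


@cache_once
def expensive(x):
    calls.append(x)
    return x * 10


print(expensive(1), expensive(2))  # 10 10 (the second argument is ignored)
expensive.clear()
print(expensive(2))                # 20
```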
- caskade.forward(method)[source]#
Decorator to define a forward method for a module.
- Parameters:
method (Callable) – The forward method to be decorated.
Examples
Standard usage of the forward decorator:

    from caskade import Module, Param, forward

    class ExampleSim(Module):
        def __init__(self, a, b, c):
            super().__init__("example_sim")
            self.a = a
            self.b = Param("b", b)
            self.c = Param("c", c)

        @forward
        def example_func(self, x, b=None):
            return x + self.a + b

    E = ExampleSim(a=1, b=None, c=3)
    print(E.example_func(4, params=[5]))  # Output: 10
- Returns:
The decorated forward method.
- Return type:
Callable
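The example above shows `@forward` filling the keyword argument `b` from the `params` list because `b` is dynamic, while `self.a` is a plain attribute. The sketch below is a hypothetical, self-contained illustration of that idea only: `toy_forward`, `ToySim` and the `param_values` dict are invented names, and caskade's real mechanism is not reproduced here.

```python
import functools
import inspect


def toy_forward(method):
    """Hypothetical sketch: fill keyword arguments named after the instance's params.

    Static params (non-None in param_values) use their stored value; dynamic
    ones (None) are taken, in order, from the ``params`` sequence at call time.
    """
    sig = inspect.signature(method)

    @functools.wraps(method)
    def wrapper(self, *args, params=(), **kwargs):
        supplied = iter(params)
        resolved = {}
        for name, value in self.param_values.items():
            resolved[name] = next(supplied) if value is None else value
        for name, value in resolved.items():
            # Only inject params the method asks for, and never override explicit kwargs
            if name in sig.parameters and name not in kwargs:
                kwargs[name] = value
        return method(self, *args, **kwargs)

    return wrapper


class ToySim:
    def __init__(self, a, b, c):
        self.a = a
        self.param_values = {"b": b, "c": c}  # None marks a dynamic param

    @toy_forward
    def example_func(self, x, b=None):
        return x + self.a + b


E = ToySim(a=1, b=None, c=3)
print(E.example_func(4, params=[5]))  # 10
```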