caskade package

Submodules

caskade.backend module

class caskade.backend.Backend(backend=None)

Bases: object

all(array)
any(array)
property array_type
property backend
exp(array)
log(array)
setup_jax()
setup_numpy()
setup_torch()
sum(array, axis=None)
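
The method list above suggests a thin layer that forwards a few array operations to whichever library was selected. The sketch below shows that dispatch pattern under stated assumptions. The class name MiniBackend, its string argument, and the name-to-module table are hypothetical. NumPy must be installed, and torch or jax are only imported if requested. It is not caskade's implementation.

```python
import importlib

import numpy as np


class MiniBackend:
    """Hypothetical stand-in that forwards array calls to one chosen library."""

    # Backend name -> importable module that provides the array functions
    MODULES = {"numpy": "numpy", "torch": "torch", "jax": "jax.numpy"}

    def __init__(self, backend="numpy"):
        self.backend = backend
        self.module = importlib.import_module(self.MODULES[backend])

    def sum(self, array, axis=None):
        # torch names the reduction axis `dim`; numpy and jax.numpy use `axis`
        if self.backend == "torch":
            return self.module.sum(array) if axis is None else self.module.sum(array, dim=axis)
        return self.module.sum(array, axis=axis)

    def exp(self, array):
        return self.module.exp(array)

    def log(self, array):
        return self.module.log(array)

    def all(self, array):
        return self.module.all(array)

    def any(self, array):
        return self.module.any(array)


b = MiniBackend("numpy")
print(b.sum(np.array([[1, 2], [3, 4]]), axis=0))  # [4 6]
```

Keeping the backend choice in one object means model code can call b.sum(...) without knowing which array library is active.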

caskade.base module

class caskade.base.Node(name: str | None = None, link: Node | tuple[Node] | None = None, description: str = '')

Bases: object

Base graph node class for caskade objects.

The Node object is the base class for all caskade objects and is used to construct the directed acyclic graph (DAG). Its primary function is to manage the parent-child relationships between nodes in the graph. The Node object has limited functionality of its own, though it implements the base versions of the active state and of the to and update_graph methods. The active state communicates through the graph that the simulator is currently running. The to method moves and/or casts the values held by the node. The update_graph method signals all parents that the graph below them has changed.

Examples

Example of making some Node objects and then linking/unlinking them:

from caskade.base import Node

n1 = Node()
n2 = Node()
n1.link("subnode", n2) # link n2 as a child of n1, may use any str as the key
n1.unlink("subnode") # alternately n1.unlink(n2) to unlink by object
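
To make the parent-child bookkeeping concrete, here is a self-contained toy version of the idea. Each node keeps a dict of children keyed by name and a set of parents, and link/unlink update both sides. ToyNode is hypothetical and much simpler than caskade's Node: it has no active state, no attribute access, and no cycle checks.

```python
class ToyNode:
    """Hypothetical, minimal parent/child bookkeeping for a DAG node."""

    def __init__(self, name):
        self.name = name
        self.children = {}  # key -> child node
        self.parents = set()  # nodes that link to this one

    def link(self, key, child=None):
        # Mirror the documented shortcut: link(node) uses node.name as the key
        if child is None:
            key, child = key.name, key
        self.children[key] = child
        child.parents.add(self)

    def unlink(self, key):
        # Accept either the key or the child object itself
        if not isinstance(key, str):
            key = next(k for k, v in self.children.items() if v is key)
        child = self.children.pop(key)
        child.parents.discard(self)


n1, n2 = ToyNode("n1"), ToyNode("n2")
n1.link("subnode", n2)
print(list(n1.children), n1 in n2.parents)  # ['subnode'] True
n1.unlink(n2)
print(n1.children, n2.parents)  # {} set()
```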

property active: bool
add_memo(memo)
append_state(saveto: str | File)

Append the state of the node and its children to an existing HDF5 file.

property children: dict[str, Node]
graph_dict() dict[str, dict]

Return a dictionary representation of the graph below the current node.

graph_print(dag: dict, depth: int = 0, indent: int = 4, result: str = '') str

Print the graph dictionary in a human-readable format.
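
The nested dict returned by graph_dict lends itself to a recursive formatter. The sketch below assumes the dict has the form {node_label: {child_label: {...}}}. That layout and the function name format_graph are assumptions, not caskade's documented format.

```python
def format_graph(dag, depth=0, indent=4):
    """Hypothetical formatter: one line per node, indented by its depth."""
    lines = []
    for node, subgraph in dag.items():
        lines.append(" " * (indent * depth) + str(node))
        if subgraph:
            # Recurse into the children one level deeper
            lines.append(format_graph(subgraph, depth + 1, indent))
    return "\n".join(lines)


dag = {"sim": {"b": {}, "other": {"c": {}}}}
print(format_graph(dag))
```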

graphviz(saveto: str | None = None) graphviz.Digraph

Return a graphviz object representing the graph below the current node in the DAG.

Parameters:
  • top_down (bool, optional) – Whether to draw the graph top-down (current node at top) or bottom-up (current node at bottom). Defaults to True.

  • saveto (Optional[str], optional) – If provided, save the graph to this file. The file extension determines the format (e.g. '.pdf', '.png'). Defaults to None.

property graphviz_style

Link the current Node object to another Node object as a child in a hierarchical manner. See link for more detail on linking. A hierarchical link allows batching internally within the simulator.

Parameters:
  • key (str) – The key to link the child node with.

  • child (Node) – The child Node object to link to.

Link the current Node object to another Node object as a child.

Parameters:
  • key (Union[str, Node]) – The key to link the child node with. This also becomes the attribute used to access the child node. After linking, node.key == child.

  • child (Optional[Node], optional) – The child Node object to link to. Defaults to None, in which case the key is used as the child and child.name is used as the key.

Examples

Example of making some Node objects and then linking/unlinking them, demonstrating multiple ways to link/unlink:

from caskade.base import Node

n1 = Node()
n2 = Node()

n1.link("subnode", n2) # may use any str as the key
n1.unlink("subnode")

# Alternately, link by object
n1.link(n2)
n1.unlink(n2)

load_state(loadfrom: str | File, index: int = -1, **kwargs)

Load the state of the node and its children.

property memos: set[str]
property name: str
property node_str
property online: bool
property parents: set[Node]
remove_memo(memo)
save_state(saveto: str | File, appendable: bool = False)

Save the state of the node and its children. This currently only works for HDF5 files (.h5 and .hdf5).

The "state" of a node is considered to be the value of its params. However, it is also possible to save other attributes of the node by adding them to the Node.saveattrs set. Simply call Node.saveattrs.add('attribute'), and Node.attribute will then be saved if possible. The HDF5 file is created with the same structure as the graph, even if there are multiple paths to the same node. For example, if N1 has children N2 and N3, and both N2 and N3 have the child N4, the HDF5 file will reflect this: when inspecting the HDF5 file manually, the N4 params can be found under both 'N1/N2/N4' and 'N1/N3/N4'. Specifically, if N4 has the param P1 then you could access its value like this:

If the save was made with appendable=True, the value will have an extra dimension for the number of samples, and this will always be the first dimension. If appendable was False, the value will simply equal the param value.

Note

You need the optional h5py package installed to use this method.

Parameters:
  • saveto (Union[str, File]) – The file to save the state to. If a string, it should be the path to an HDF5 file (ending in '.h5' or '.hdf5'). If a File object, it should be an open HDF5 file.

  • appendable (bool, optional) – Whether to save the state in an appendable format. If True, the values will have an extra dimension for the number of samples. Defaults to False.
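
The path layout described above can be reproduced with a plain depth-first walk that records one path per route, so a shared child such as N4 appears once under each parent. This is a stand-alone sketch over a dict-of-lists graph, which is an assumption. It is not caskade's saving code and does not touch h5py.

```python
def hdf5_paths(graph, root):
    """Hypothetical: list every root-to-node path, the way a save mirrors the graph."""
    paths = {}

    def walk(node, prefix):
        path = f"{prefix}/{node}" if prefix else node
        # A node reachable along several routes gets one entry per route
        paths.setdefault(node, []).append(path)
        for child in graph.get(node, ()):
            walk(child, path)

    walk(root, "")
    return paths


graph = {"N1": ["N2", "N3"], "N2": ["N4"], "N3": ["N4"]}
print(hdf5_paths(graph, "N1")["N4"])  # ['N1/N2/N4', 'N1/N3/N4']
```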

property subgraphs: set[Node]
to(device=None, dtype=None)

Moves and/or casts the values of the Node to a particular device and/or dtype.

Parameters:
  • device (Optional[torch.device], optional) – The device to move the values to. Defaults to None.

  • dtype (Optional[torch.dtype], optional) – The desired data type. Defaults to None.

topological_ordering() tuple[Node]

Return a topological ordering of the graph below the current node. Uses Iterative Deepening DFS (Post-Order) to resolve dependencies.
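
Post-order depth-first search is what makes such an ordering dependency-safe: a node is emitted only after everything below it. The following is a minimal recursive sketch over a dict-of-lists graph, with some caveats. It is a plain DFS rather than the iterative-deepening variant named above, so very deep graphs can hit Python's recursion limit. Whether caskade lists the root first or last is not stated here; reverse the tuple if parents must come first.

```python
def post_order(graph, root):
    """Return nodes so that every node comes after all of its descendants."""
    order, seen = [], set()

    def visit(node):
        if node in seen:
            return  # shared children are emitted only once
        seen.add(node)
        for child in graph.get(node, ()):
            visit(child)
        order.append(node)  # appended only after all its children are placed

    visit(root)
    return tuple(order)


graph = {"N1": ["N2", "N3"], "N2": ["N4"], "N3": ["N4"]}
print(post_order(graph, "N1"))  # ('N4', 'N2', 'N3', 'N1')
```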

Unlink the current Node object from another Node object which is a child.

update_graph()

Notifies all parents that the graph below them has been updated. The base Node object does nothing with this information, but other node types may use it to update internal state.

caskade.collection module

class caskade.collection.NodeCollection(name: str | None = None, link: Node | tuple[Node] | None = None, description: str = '')

Bases: Node, GetSetValues

copy()
deepcopy()
property dynamic
property dynamic_param_groups: tuple[int]
property dynamic_params: tuple[Param]
property pointer_params: tuple[Param]
property static
property static_params: tuple[Param]
to_dynamic(children_only=True)

Change all parameters to dynamic parameters.

Parameters:

children_only (bool, optional) – If True, only convert the children of this module to dynamic. If False, convert all parameters in the graph below this module. Defaults to True.

to_static(children_only=True)

Change all parameters to static parameters.

Parameters:

children_only (bool, optional) – If True, only convert the children of this module to static. If False, convert all parameters in the graph below this module. Defaults to True.

class caskade.collection.NodeList(iterable=(), name=None)

Bases: NodeCollection, list

append(node)

Append object to the end of the list.

clear()

Remove all items from list.

extend(iterable)

Extend list by appending elements from the iterable.

property graphviz_style
insert(index, node)

Insert object before index.

pop(index=-1)

Remove and return item at index (default last).

Raises IndexError if list is empty or index is out of range.

remove(value)

Remove first occurrence of value.

Raises ValueError if the value is not present.

class caskade.collection.NodeTuple(iterable=None, name=None)

Bases: NodeCollection, tuple

property graphviz_style

caskade.context module

class caskade.context.ActiveContext(module: Module)

Bases: object

Context manager to activate a module for a simulation. Only inside an ActiveContext is it possible to fill/clear the dynamic and live parameters.

class caskade.context.OverrideParam(param: Param, value)

Bases: object

Context manager to override a parameter value. Only inside an OverrideParam will the parameter be set to the new value.

class caskade.context.ValidContext(module: Module)

Bases: object

Context manager to set valid values for parameters. Only inside a ValidContext will parameters automatically be assumed valid.
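
The override behaviour described for OverrideParam follows the usual save/restore context-manager pattern. Below is a minimal stand-alone sketch using contextlib. It assumes a parameter is just an object with a value attribute. SimpleParam and override are hypothetical names, and caskade's OverrideParam is a class that may do more.

```python
from contextlib import contextmanager


class SimpleParam:
    """Hypothetical stand-in: a named object with a plain `value` attribute."""

    def __init__(self, name, value):
        self.name = name
        self.value = value


@contextmanager
def override(param, value):
    # Swap in the new value, and restore the old one even if the block raises
    old = param.value
    param.value = value
    try:
        yield param
    finally:
        param.value = old


p = SimpleParam("b", 1.0)
with override(p, 5.0):
    print(p.value)  # 5.0
print(p.value)  # 1.0
```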

caskade.decorators module

class caskade.decorators.active_cache(func)

Bases: object

Caches the first evaluated result of a Module method for the duration of a simulation.

This decorator ensures that an expensive method is executed exactly once per active simulation run. Once calculated, subsequent calls to the decorated method return the stored value, ignoring any arguments passed to it.

WARNING:

If the method is called multiple times with different arguments in one simulation, the cached result will still be returned, which may lead to unexpected behavior. Use with caution!

Note

If you are stacking multiple decorators on a method (such as @forward or @jax.jit), @active_cache MUST be the outermost (top) decorator.

Example:

import numpy as np
import jax
import jax.numpy as jnp
from caskade.module import Module
from caskade.param import Param
from caskade.decorators import active_cache, forward

class FluxModel(Module):
    def __init__(self, nodes, x, M):
        super().__init__()
        self.nodes = nodes
        self.x = Param("x", x)
        self.M = Param("M", M)

    @active_cache  # Notice active_cache is placed at the top
    @jax.jit
    @forward
    def compute_intrinsic_sed(self, w, x, M):
        print("Computing SED...")
        return jnp.interp(w, self.nodes, x * M)

    @forward
    def compute_flux(self, wavelengths):
        sed = self.compute_intrinsic_sed(wavelengths)  # Cached after first call
        flux = jnp.sum(sed)
        sed = self.compute_intrinsic_sed(wavelengths)  # Returns cached result, no print
        peak = jnp.max(sed)
        return flux, peak

model = FluxModel(np.linspace(400, 700, 10), x=1.0, M=np.random.rand(10))
wavelengths = np.linspace(400, 700, 50)  # example input grid

# compute_flux only calls compute_intrinsic_sed once due to caching
flux, peak = model.compute_flux(wavelengths)
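
The caching rule above can be illustrated without caskade: the first result wins for the rest of the run and arguments are ignored. In this sketch, the active flag, the per-instance _cache dict, and the names toy_active_cache and Demo are all hypothetical. They only mimic the documented behaviour, including the pitfall from the warning.

```python
import functools


def toy_active_cache(method):
    """Hypothetical: cache the first result per active run, ignoring arguments."""

    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        if not self.active:
            return method(self, *args, **kwargs)  # outside a run: no caching
        if method.__name__ not in self._cache:
            # First call in this run; the arguments are deliberately NOT part of the key
            self._cache[method.__name__] = method(self, *args, **kwargs)
        return self._cache[method.__name__]

    return wrapper


class Demo:
    calls = 0  # counts how often the expensive method really runs

    def __init__(self):
        self.active = False
        self._cache = {}

    @toy_active_cache
    def expensive(self, x):
        Demo.calls += 1
        return x * 2

    def run(self, x):
        # Mimic one active simulation: switch on, compute, then clear the cache
        self.active = True
        try:
            return self.expensive(x) + self.expensive(x + 100)  # second call is a cache hit
        finally:
            self.active = False
            self._cache.clear()


d = Demo()
print(d.run(1))  # 4, not 206: the second call reuses expensive(1) == 2
```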

caskade.decorators.forward(method)

Decorator to define a forward method for a module.

Parameters:

method (Callable) – The forward method to be decorated.

Examples

Standard usage of the forward decorator:

from caskade.module import Module
from caskade.param import Param
from caskade.decorators import forward

class ExampleSim(Module):
    def __init__(self, a, b, c):
        super().__init__("example_sim")
        self.a = a
        self.b = Param("b", b)
        self.c = Param("c", c)

    @forward
    def example_func(self, x, b=None):
        return x + self.a + b

E = ExampleSim(a=1, b=None, c=3)  # b=None makes b dynamic; c=3 is static
print(E.example_func(4, params=[5]))
# Output: 10

Returns:

The decorated forward method.

Return type:

Callable

caskade.errors module

exception caskade.errors.ActiveStateError

Bases: CaskadeException

Class for exceptions related to the active state of a node in caskade.

exception caskade.errors.BackendError

Bases: CaskadeException

Class for exceptions related to the backend in caskade.

exception caskade.errors.CaskadeException

Bases: Exception

Base class for all exceptions in caskade.

exception caskade.errors.FillParamsArrayError(name, input_params, params)

Bases: FillParamsError

Class for exceptions related to filling parameters with ArrayLike objects in caskade.

exception caskade.errors.FillParamsError

Bases: CaskadeException

Class for exceptions related to filling parameters in caskade.

exception caskade.errors.FillParamsMappingError(name, children, missing_key=None)

Bases: FillParamsError

Class for exceptions related to filling parameters with a mapping (dict) in caskade.

exception caskade.errors.FillParamsSequenceError(name, input_params, dynamic_params)

Bases: FillParamsError

Class for exceptions related to filling parameters with a sequence (list, tuple, etc.) in caskade.

exception caskade.errors.GraphError

Bases: CaskadeException

Class for graph exceptions in caskade.

exception caskade.errors.LinkToAttributeError

Bases: GraphError

Class for exceptions related to linking to an attribute in caskade.

exception caskade.errors.NodeConfigurationError

Bases: CaskadeException

Class for node configuration exceptions in caskade.

exception caskade.errors.ParamConfigurationError

Bases: NodeConfigurationError

Class for parameter configuration exceptions in caskade.

exception caskade.errors.ParamTypeError

Bases: CaskadeException

Class for exceptions related to the type of a parameter in caskade.
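
Every error derives from CaskadeException, and the three FillParams* errors share FillParamsError, so callers can choose how broadly to catch. The sketch below mirrors a few of the documented classes locally so it runs without caskade. The fill function is hypothetical. Unlike the real FillParamsSequenceError, which takes (name, input_params, dynamic_params), the mirrored class just takes a message.

```python
# Hypothetical local mirror of part of the documented hierarchy (NOT imported from caskade)
class CaskadeException(Exception):
    pass


class FillParamsError(CaskadeException):
    pass


class FillParamsSequenceError(FillParamsError):
    pass


class GraphError(CaskadeException):
    pass


class LinkToAttributeError(GraphError):
    pass


def fill(values, expected):
    # Toy length check standing in for a real fill step
    if len(values) != expected:
        raise FillParamsSequenceError(f"expected {expected} values, got {len(values)}")
    return list(values)


try:
    fill([1.0, 2.0], expected=3)
except FillParamsError as err:  # also catches any other FillParams* subclass
    print(type(err).__name__, "-", err)
```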

caskade.mixins module

class caskade.mixins.GetSetValues

Bases: object

find_index(param: Param | tuple[Param] | Module, scheme: str = 'array') int | slice

Identify which index is associated with a param in the dynamic params array.

Parameters:
  • param (Union[Param, tuple[Param], Module]) – The param for which to find the associated index.

  • scheme (str) – Whether to search the array (default) or list version of the params. dict is currently unsupported.

Returns:

param_info – An int giving the index associated with the provided Param object. If the param is multi-dimensional, the result will be a slice over all indices associated with that param.

Return type:

Union[int, slice]

find_param(idx: int | tuple[int], group: int | None = None, scheme: str = 'array') tuple[Param, tuple[int]]

Identify which param is associated with the provided index in the dynamic params array.

Parameters:
  • idx (Union[int, tuple[int]]) – The index in the params array at which to find the associated param.

  • group (Optional[int]) – If the dynamic params have multiple group values, this argument specifies which group to check.

  • scheme (str) – Whether to search the array (default) or list version of the params. dict is currently unsupported.

Returns:

param_info – A tuple with the Param object and the index within the Param value associated with idx (an empty tuple if scalar). If idx is a tuple, the result is a tuple of these results.

Return type:

tuple[Param, tuple[int]]
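
Both methods rely on the dynamic params being packed, in order, into one flat array. The sketch below reproduces that bookkeeping for a list of (name, shape) pairs with row-major flattening. The packing order, the flattening convention, and all function names here are assumptions, not caskade internals, and groups are ignored.

```python
from math import prod


def layout(spec):
    """Hypothetical: assign each (name, shape) a slice of one flat params array, in order."""
    slices, start = {}, 0
    for name, shape in spec:
        size = prod(shape)  # prod(()) == 1, so a scalar takes one slot
        slices[name] = slice(start, start + size)
        start += size
    return slices


def find_index(spec, name):
    # Scalars map to a single int; any other shape maps to a slice
    s = layout(spec)[name]
    return s.start if dict(spec)[name] == () else s


def find_param(spec, idx):
    shapes = dict(spec)
    for name, s in layout(spec).items():
        if s.start <= idx < s.stop:
            # Unravel the flat offset into a row-major index inside the param's shape
            offset, multi = idx - s.start, []
            for dim in reversed(shapes[name]):
                multi.append(offset % dim)
                offset //= dim
            return name, tuple(reversed(multi))
    raise IndexError(idx)


spec = [("b", ()), ("M", (2, 3)), ("c", (2,))]
print(find_index(spec, "M"))  # slice(1, 7, None)
print(find_param(spec, 5))  # ('M', (1, 1))
```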

from_valid(valid_params: Annotated[Tensor, 'One of: torch.Tensor, numpy.ndarray, jax.numpy.ndarray depending on the chosen backend.'] | Sequence | Mapping, param_list=None, group=None) Annotated[Tensor, 'One of: torch.Tensor, numpy.ndarray, jax.numpy.ndarray depending on the chosen backend.'] | Sequence | Mapping

Convert valid params to input params.

get_values(scheme='array', dynamic=True, attribute='value', group: int | None = None) Annotated[Tensor, 'One of: torch.Tensor, numpy.ndarray, jax.numpy.ndarray depending on the chosen backend.'] | list[Annotated[Tensor, 'One of: torch.Tensor, numpy.ndarray, jax.numpy.ndarray depending on the chosen backend.']] | dict[str, dict | Annotated[Tensor, 'One of: torch.Tensor, numpy.ndarray, jax.numpy.ndarray depending on the chosen backend.']]
set_values(params: Annotated[Tensor, 'One of: torch.Tensor, numpy.ndarray, jax.numpy.ndarray depending on the chosen backend.'] | Sequence | Mapping, dynamic=True, attribute='value')

Fill the dynamic values of the module with the input values from params.

to_valid(params: Annotated[Tensor, 'One of: torch.Tensor, numpy.ndarray, jax.numpy.ndarray depending on the chosen backend.'] | Sequence | Mapping, param_list=None, group=None) Annotated[Tensor, 'One of: torch.Tensor, numpy.ndarray, jax.numpy.ndarray depending on the chosen backend.'] | Sequence | Mapping

Convert input params to valid params.

property valid_context: bool

Return True if the module is in a valid context.

caskade.module module

class caskade.module.Module(name: str | None = None, **kwargs)

Bases: Node, GetSetValues

Node to represent a simulation module in the graph.

The Module object is used to represent a simulation module in the graph. Modules are Python objects that contain the calculations for a simulation; they also hold the Param objects used in those calculations. The Module object has additional functionality to manage the Param objects below it in the graph: it keeps track of all dynamic Param objects so that their values may be filled at runtime. The Module object manages its links to other nodes through attribute assignment.

Examples

Example of a nested pair of Module objects and how their @forward methods are called:

import torch
from caskade.module import Module
from caskade.param import Param
from caskade.decorators import forward

class MySim(Module):
    def __init__(self, a, b=None):
        super().__init__()
        self.a = a
        self.b = Param("b", b)

    @forward
    def myfunc(self, x, b=None):
        return x * self.a.otherfun(x) + b

class OtherSim(Module):
    def __init__(self, c=None):
        super().__init__()
        self.c = Param("c", c)

    @forward
    def otherfun(self, x, c=None):
        return x + c

othersim = OtherSim()
mysim = MySim(a=othersim)
#                       b                         c
params = [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])]
result = mysim.myfunc(3.0, params=params)
# result is tensor([19.0, 23.0])

property all_params
clear_state()

Clear the active state _value for all params below this Module in the DAG. This should not be used by a user under normal circumstances.

property dynamic: bool

Return True if the module has dynamic parameters as direct children.

fill_kwargs(keys: tuple[str]) dict[str, Annotated[Tensor, 'One of: torch.Tensor, numpy.ndarray, jax.numpy.ndarray depending on the chosen backend.']]

Fill the kwargs for an @forward method with the values of the dynamic parameters. The requested keys are matched to the names of Param objects owned by the Module. This should not be used by the user under normal circumstances.

fill_params(params: Annotated[Tensor, 'One of: torch.Tensor, numpy.ndarray, jax.numpy.ndarray depending on the chosen backend.'] | Sequence | Mapping, dynamic=True)

Fill the dynamic/static parameters of the module with the input values from params.

Parameters:
  • params (Union[ArrayLike, Sequence, Mapping]) – The input values to fill the parameters with. The input can be an ArrayLike, a Sequence, or a Mapping.

  • dynamic (bool) – Operate on dynamic parameters (True, default) or static parameters (False).
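
The three accepted input forms differ only in how values are matched to parameters:

- by position within one flat array,
- by position in a sequence with one entry per param,
- by name in a mapping.

Below is a stand-alone sketch of that normalisation. It uses the standard-library array type as a stand-in for a backend array. The function name, the spec format, and the error types are hypothetical (caskade's own errors are the FillParams* exceptions in caskade.errors). Real mappings may also be nested by child module, whereas this sketch uses a flat mapping.

```python
from array import array
from collections.abc import Mapping
from math import prod


def to_named_values(params, spec):
    """Hypothetical: normalise array / sequence / mapping input to {name: value}.

    spec is an ordered list of (name, shape) pairs for the dynamic params.
    array.array stands in for a flat backend array; blocks are kept flat for brevity.
    """
    names = [name for name, _ in spec]
    if isinstance(params, array):
        # Flat array: consecutive blocks, one per param, sized by its shape
        total = sum(prod(shape) for _, shape in spec)
        if len(params) != total:
            raise ValueError(f"expected {total} values, got {len(params)}")
        out, start = {}, 0
        for name, shape in spec:
            size = prod(shape)
            block = list(params[start:start + size])
            out[name] = block[0] if shape == () else block
            start += size
        return out
    if isinstance(params, Mapping):
        # Mapping: look each param up by name
        missing = [n for n in names if n not in params]
        if missing:
            raise KeyError(f"missing values for {missing}")
        return {n: params[n] for n in names}
    # Sequence: exactly one entry per dynamic param, in order
    if len(params) != len(names):
        raise ValueError(f"expected {len(names)} entries, got {len(params)}")
    return dict(zip(names, params))


spec = [("b", (2,)), ("c", ())]
print(to_named_values(array("d", [1.0, 2.0, 3.0]), spec))  # {'b': [1.0, 2.0], 'c': 3.0}
```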

property graphviz_style
property node_str: str

Returns a string representation of the node for graph visualization.

param_order()
remove_memo(memo)
property static: bool
to_dynamic(children_only=True)

Change all parameters to dynamic parameters.

Parameters:

children_only (bool, optional) – If True, only convert the children of this module to dynamic. If False, convert all parameters in the graph below this module. Defaults to True.

to_static(children_only=True)

Change all parameters to static parameters.

Parameters:

children_only (bool, optional) – If True, only convert the children of this module to static. If False, convert all parameters in the graph below this module. Defaults to True.

update_graph()

Maintain tuples of the dynamic, static, and pointer parameters found at all points lower in the DAG.

caskade.param module

class caskade.param.Param(name: str, value: Annotated[Tensor, 'One of: torch.Tensor, numpy.ndarray, jax.numpy.ndarray depending on the chosen backend.'] | float | int | None = None, shape: tuple[int, ...] | None = None, cyclic: bool = False, valid: tuple[Annotated[Tensor, 'One of: torch.Tensor, numpy.ndarray, jax.numpy.ndarray depending on the chosen backend.'] | float | int | None] | None = None, units: str | None = None, dynamic: bool | None = None, group: int = 0, batch_shape: tuple[int] | None = None, dtype: Any | None = None, device: Any | None = None, **kwargs)

Bases: Node

Node to represent a parameter in the graph.

The Param object is used to represent a parameter in the graph. At runtime it represents a value that can be used in various calculations. A Param can be set to:

  • a constant value (static);
  • None, meaning the value is provided at runtime (dynamic);
  • another Param object, meaning it takes on that value at runtime (pointer);
  • a function of other Param objects computed at runtime (also a pointer, see the user guides).

These options allow users to flexibly set the behavior of the simulator.

Examples

Example of making some Param objects:

from math import pi
from caskade.param import Param

p1 = Param("test", (1.0, 2.0)) # constant value, length 2 vector
p2 = Param("p2", None, (2,2)) # dynamic 2x2 matrix value
p3 = Param("p3", p1) # pointer to another parameter
p4 = Param("p4", lambda p: p.children["other"].value * 2) # arbitrary function of another parameter; expects a child linked under the key "other"
p5 = Param("p5", valid=(0.0,2*pi), units="radians", cyclic=True) # parameter with metadata

Parameters:
  • name (str) – The name of the parameter.

  • value (Optional[Union[ArrayLike, float, int]], optional) – The value of the parameter. Defaults to None, meaning dynamic.

  • shape (Optional[tuple[int, ...]], optional) – The shape of the parameter. Defaults to (), meaning scalar.

  • cyclic (bool, optional) – Whether the parameter is cyclic, imposing periodic boundary conditions (such as a rotation angle from 0 to 2pi). Defaults to False.

  • valid (Optional[tuple[Union[ArrayLike, float, int, None]]], optional) – The valid range of the parameter. Defaults to None, meaning the whole range from -inf to inf is valid.

  • units (Optional[str], optional) – The units of the parameter. Defaults to None.

  • dynamic (bool, optional) – Force the param to be dynamic if True. If a value is also provided, the dynamic param uses it as a default value at call time.

  • batched (bool, optional) – If True, the param is assumed batched and the shape may now take the form (*B, *D) where *D is the shape of the value.

  • dtype (Optional[Any], optional) – The data type of the parameter. Defaults to None, meaning the data type will be inferred from the value.

  • device (Optional[Any], optional) – The device of the parameter. Defaults to None, meaning the device will be inferred from the value.

property batch_shape: tuple[int, ...]
property batched: bool
property cyclic: bool
property device: str | None
property dtype: str | None
property dynamic: bool
property graphviz_style
property group: int
is_valid(value=None) bool

Check whether a given value is valid given this parameter's allowed (valid) range.
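
The valid range and the cyclic flag can be read as two small rules:

- A value is valid if it lies within (low, high), with None meaning unbounded on that side.
- A cyclic value is wrapped back into the range.

The sketch below implements both under those assumptions. This reference does not say whether caskade treats the bounds as inclusive or how it wraps cyclic values internally; the inclusive check here is an assumption.

```python
import math


def is_valid(value, valid=None):
    """Hypothetical check against an inclusive (low, high) range; None means unbounded."""
    if valid is None:
        return True  # no range given: anything from -inf to inf is valid
    low, high = valid
    return (low is None or value >= low) and (high is None or value <= high)


def wrap_cyclic(value, valid):
    """Map a value back into [low, high) using periodic boundary conditions."""
    low, high = valid
    return low + (value - low) % (high - low)


print(is_valid(7.0, (0.0, 2 * math.pi)))  # False
print(round(wrap_cyclic(3 * math.pi, (0.0, 2 * math.pi)), 6))  # 3.141593
```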

property node_str: str

Returns a string representation of the node for graph visualization.

property node_type
property npvalue: ndarray
property pointer: bool
property shape: tuple[int, ...]
property static: bool
to(device=None, dtype=None) Param

Moves and/or casts the values of the parameter.

Parameters:
  • device (Optional[torch.device], optional) – The device to move the values to. Defaults to None.

  • dtype (Optional[torch.dtype], optional) – The desired data type. Defaults to None.

to_dynamic(value=<object object>)

Change this parameter to a dynamic parameter. If a value is provided, it will be set as the dynamic value.

to_pointer(value, link=())

Change this parameter to a pointer parameter. If a value is provided, it will be set as the pointer. Either provide a Param object to point to its value, or provide a callable function to be called at runtime. It is also possible to provide a tuple of nodes to link to while creating the pointer.

to_static(value=<object object>)

Change this parameter to a static parameter. If a value is provided, it will be set as the static value.
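
The three modes described for Param can be sketched as a small resolution rule:

- static: a stored value,
- dynamic: a value supplied at call time,
- pointer: another Param, or a function.

ToyParam and its resolve method are hypothetical. caskade's real pointer functions receive the Param itself (as in the lambda p: p.children[...] example in the Param docstring), whereas this sketch passes a dict of runtime values for brevity.

```python
class ToyParam:
    """Hypothetical param that is static, dynamic, or a pointer."""

    def __init__(self, name, value=None, dynamic=False):
        self.name = name
        self._value = value
        # value=None means dynamic; dynamic=True forces it even when a value is given
        self._dynamic = dynamic or value is None

    @property
    def pointer(self):
        return isinstance(self._value, ToyParam) or callable(self._value)

    def resolve(self, runtime):
        """Return this param's value; `runtime` maps dynamic param names to values."""
        if self.pointer:
            target = self._value
            if isinstance(target, ToyParam):
                return target.resolve(runtime)  # follow the pointer
            return target(runtime)  # call the function pointer
        if self._dynamic:
            return runtime.get(self.name, self._value)  # stored value acts as a default
        return self._value  # static


a = ToyParam("a", 2.0)  # static
b = ToyParam("b")  # dynamic, supplied at call time
c = ToyParam("c", b)  # pointer to b
d = ToyParam("d", lambda rt: a.resolve(rt) * c.resolve(rt))  # function pointer
print(d.resolve({"b": 3.0}))  # 6.0
```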

property valid: tuple[Annotated[Tensor, 'One of: torch.Tensor, numpy.ndarray, jax.numpy.ndarray depending on the chosen backend.'] | None, Annotated[Tensor, 'One of: torch.Tensor, numpy.ndarray, jax.numpy.ndarray depending on the chosen backend.'] | None]
property value: Annotated[Tensor, 'One of: torch.Tensor, numpy.ndarray, jax.numpy.ndarray depending on the chosen backend.'] | None
caskade.param.valid_shape(batch_shape, shape, value_shape)

caskade.tests module

caskade.tests.test()

Basic integration test of caskade to ensure that the library is functioning correctly.

caskade.utils module

caskade.utils.broadcast_cat_jax(arrays, dim=-1)

Concatenates JAX arrays with broadcasting.

Behaves like jnp.concatenate, but first broadcasts the arrays to match on all dimensions EXCEPT the concatenation dimension.

Parameters:
  • arrays (sequence of jnp.ndarray) – Arrays to concatenate.

  • dim (int) – The dimension along which to concatenate.

Returns:

The concatenated array.

Return type:

jnp.ndarray

caskade.utils.broadcast_cat_numpy(arrays, dim=-1)

Concatenates NumPy arrays with broadcasting.

Behaves like np.concatenate, but first broadcasts the arrays to match on all dimensions EXCEPT the concatenation dimension.

Parameters:
  • arrays (sequence of np.ndarray) – Arrays to concatenate.

  • dim (int) – The dimension along which to concatenate.

Returns:

The concatenated array.

Return type:

np.ndarray

caskade.utils.broadcast_cat_torch(tensors, dim=-1)

Concatenates tensors with broadcasting.

It behaves like torch.cat, but first broadcasts the tensors to match on all dimensions EXCEPT the concatenation dimension.

Parameters:
  • tensors (sequence of Tensors) – Tensors to concatenate.

  • dim (int) – The dimension along which to concatenate. Must be a negative index to ensure consistency across tensors of different ranks (e.g., -1 for the last dimension).

Returns:

The concatenated tensor.

Return type:

Tensor
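
The broadcast-then-concatenate behaviour shared by the three functions can be written in a few lines of NumPy. This sketch is not caskade's code and rests on three assumptions:

- dim is negative, as the torch variant requires;
- lower-rank arrays are left-padded with size-1 axes;
- np.broadcast_shapes is available (NumPy 1.20 or later).

```python
import numpy as np


def broadcast_cat(arrays, dim=-1):
    """Sketch: broadcast every axis except `dim`, then concatenate along `dim`."""
    if dim >= 0:
        raise ValueError("use a negative dim so it refers to the same axis for every rank")
    arrays = [np.asarray(a) for a in arrays]
    ndim = max(a.ndim for a in arrays)
    # Left-pad with size-1 axes so all arrays share the same rank
    arrays = [a.reshape((1,) * (ndim - a.ndim) + a.shape) for a in arrays]
    axis = ndim + dim
    # Set the concat axis to 1 so it does not take part in broadcasting
    target = np.broadcast_shapes(*(a.shape[:axis] + (1,) + a.shape[axis + 1:] for a in arrays))
    expanded = [
        np.broadcast_to(a, target[:axis] + (a.shape[axis],) + target[axis + 1:])
        for a in arrays
    ]
    return np.concatenate(expanded, axis=axis)


out = broadcast_cat([np.zeros((2, 3)), np.ones(4)], dim=-1)
print(out.shape)  # (2, 7)
```

Setting the concatenation axis to 1 before computing the common shape is what lets arrays of different lengths along dim still broadcast on every other axis.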

caskade.warnings module

exception caskade.warnings.CaskadeWarning

Bases: Warning

Base warning for caskade.

exception caskade.warnings.InvalidValueWarning(name, value, valid)

Bases: CaskadeWarning

Warning for values that fall outside the valid range.

exception caskade.warnings.SaveStateWarning

Bases: CaskadeWarning

Warning issued when a problem occurs while saving a state.

Module contents

class caskade.ActiveContext(module: Module) – documented above as caskade.context.ActiveContext.

exception caskade.ActiveStateError – documented above as caskade.errors.ActiveStateError.

exception caskade.BackendError – documented above as caskade.errors.BackendError.

exception caskade.CaskadeException – documented above as caskade.errors.CaskadeException.

exception caskade.CaskadeWarning – documented above as caskade.warnings.CaskadeWarning.

exception caskade.FillParamsArrayError(name, input_params, params) – documented above as caskade.errors.FillParamsArrayError.

exception caskade.FillParamsError – documented above as caskade.errors.FillParamsError.

exception caskade.FillParamsMappingError(name, children, missing_key=None) – documented above as caskade.errors.FillParamsMappingError.

exception caskade.FillParamsSequenceError(name, input_params, dynamic_params) – documented above as caskade.errors.FillParamsSequenceError.

exception caskade.GraphError – documented above as caskade.errors.GraphError.

exception caskade.InvalidValueWarning(name, value, valid) – documented above as caskade.warnings.InvalidValueWarning.

exception caskade.LinkToAttributeError – documented above as caskade.errors.LinkToAttributeError.

class caskade.Memo(module: Node, memo: str)

Bases: object

Sends a "memo" (a small message) to all nodes below the current one in the graph. This can be used to communicate state changes to all lower nodes in the graph. By default, the message skips any subgraphs (hierarchical graphs), but this can be changed to ensure all nodes receive the message.

Note that memos are stored in a Python set, so duplicates will be merged. Depending on your use case, it may be wise to ensure that your memo is unique.

Parameters:
  • module (Module) – The caskade Module object that will propagate the memo.

  • memo (str) – The message to send down the graph.

  • skip_subgraphs (bool) – If True (default), any subgraphs, otherwise known as hierarchical graphs, will not receive the memo.
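
The propagation rule can be sketched as a breadth-first walk that stops at hierarchical (subgraph) links unless skip_subgraphs is False. The graph encoding (a dict of child lists plus a set of hierarchical edges) and the function name send_memo are hypothetical. caskade's Node exposes add_memo, remove_memo, and a memos set, which this sketch models with a returned dict of per-node sets.

```python
from collections import deque


def send_memo(root, memo, children, subgraph_edges=frozenset(), skip_subgraphs=True):
    """Hypothetical breadth-first memo broadcast to every node below `root`.

    children:       {node: [child, ...]} describing the DAG
    subgraph_edges: set of (parent, child) pairs that are hierarchical links
    Returns {node: set_of_memos} for the nodes that received the memo.
    """
    received = {}
    queue, seen = deque([root]), {root}
    while queue:
        node = queue.popleft()
        for child in children.get(node, ()):
            if skip_subgraphs and (node, child) in subgraph_edges:
                continue  # do not descend into a hierarchical subgraph
            if child not in seen:  # a shared child is notified only once
                seen.add(child)
                received.setdefault(child, set()).add(memo)
                queue.append(child)
    return received


g = {"sim": ["a", "sub"], "a": ["p"], "sub": ["q"]}
edges = {("sim", "sub")}
print(sorted(send_memo("sim", "update", g, edges)))  # ['a', 'p']
print(sorted(send_memo("sim", "update", g, edges, skip_subgraphs=False)))  # ['a', 'p', 'q', 'sub']
```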

class caskade.Module(name: str | None = None, **kwargs) – documented above as caskade.module.Module.

class caskade.Node(name: str | None = None, link: Node | tuple[Node] | None = None, description: str = '')[source]#

Bases: object

Base graph node class for caskade objects.

The Node object is the base class for all caskade objects. It is used to construct the directed acyclic graph (DAG). The primary function of the Node object is to manage the parent-child relationships between nodes in the graph. There is limited functionality for the Node object, though it implements the base versions of the active state and to / update_graph methods. The active state is used to communicate through the graph that the simulator is currently running. The to method is used to move and/or cast the values of the parameter. The update_graph method is used to signal all parents that the graph below them has changed.

Examples

Example making some Node objects and then linking/unlinking them:

from caskade import Node

n1 = Node()
n2 = Node()
n1.link("subnode", n2) # link n2 as a child of n1, may use any str as the key
n1.unlink("subnode") # alternately n1.unlink(n2) to unlink by object

property active: bool#
add_memo(memo)[source]#
append_state(saveto: str | File)[source]#

Append the state of the node and its children to an existing HDF5 file.

property children: dict[str, Node]#
graph_dict() dict[str, dict][source]#

Return a dictionary representation of the graph below the current node.

graph_print(dag: dict, depth: int = 0, indent: int = 4, result: str = '') str[source]#

Print the graph dictionary in a human-readable format.

graphviz(saveto: str | None = None) graphviz.Digraph[source]#

Return a graphviz object representing the graph below the current node in the DAG.

Parameters:
  • top_down ((bool, optional)) – Whether to draw the graph top-down (current node at top) or bottom-up (current node at bottom). Defaults to True.

  • saveto ((Optional[str], optional)) – If provided, save the graph to this file. The file extension determines the format (e.g. ‘.pdf’, ‘.png’). Defaults to None.

property graphviz_style#

Link the current Node object to another Node object as a child in a hierarchical manner. See link for more detail on linking. A hierarchical link enables batching internal to the simulator.

Parameters:
  • key ((str)) – The key to link the child node with.

  • child ((Node)) – The child Node object to link to.

Link the current Node object to another Node object as a child.

Parameters:
  • key ((Union[str, Node])) – The key to link the child node with. This will also become the attribute used to access the child node. After linking you will have node.key == child.

  • child ((Optional[Node], optional)) – The child Node object to link to. Defaults to None, in which case the key is used as the child and child.name is used as the key.

Examples

Example making some Node objects and then linking/unlinking them, demonstrating multiple ways to link/unlink:

from caskade import Node

n1 = Node()
n2 = Node()

n1.link("subnode", n2) # may use any str as the key
n1.unlink("subnode")

# Alternately, link by object
n1.link(n2)
n1.unlink(n2)

load_state(loadfrom: str | File, index: int = -1, **kwargs)[source]#

Load the state of the node and its children.

property memos: set[str]#
property name: str#
property node_str#
property online: bool#
property parents: set[Node]#
remove_memo(memo)[source]#
save_state(saveto: str | File, appendable: bool = False)[source]#

Save the state of the node and its children. Currently this only works for HDF5 files (.h5 and .hdf5).

The “state” of a node is considered to be the values of its params. However, it is also possible to save other attributes of the node by adding them to the Node.saveattrs set: simply call Node.saveattrs.add(‘attribute’), and Node.attribute will then be saved if possible. The HDF5 file is created with the same structure as the graph, even if there are multiple paths to the same node. For example, if N1 has children N2 and N3, and both N2 and N3 have the child N4, the HDF5 file will reflect this: when inspecting the file manually, the N4 params can be found under both ‘N1/N2/N4’ and ‘N1/N3/N4’. Specifically, if N4 has the param P1 then you could access its value like this:

If the save was made appendable, the value will have an extra dimension for the number of samples, which is always the first dimension. If appendable was False, the value will simply equal the param value.

Note

You need the optional h5py package installed to use this method.

Parameters:
  • saveto ((Union[str, File])) – The file to save the state to. If a string, it should be the path to an HDF5 file (ending in ‘.h5’ or ‘.hdf5’). If a File object, it should be an open HDF5 file.

  • appendable ((bool, optional)) – Whether to save the state in an appendable format. If True, the values will have an extra dimension for the number of samples. Defaults to False.
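The path layout and the appendable sample axis described for save_state can be mimicked without h5py. This sketch uses a plain dictionary as the graph (a hypothetical layout) and enumerates one path per route through the DAG, reproducing the N1/N2/N4 and N1/N3/N4 duplication from the example above. It then stacks samples along a new leading axis. It does not write an HDF5 file and does not reflect caskade's exact on-disk format.

```python
import numpy as np


def hdf5_paths(graph, node, prefix=""):
    """Yield one path per route from `node` to each descendant in the DAG.

    graph -- dict mapping a node name to a list of its child names (hypothetical layout)
    """
    path = f"{prefix}/{node}" if prefix else node
    yield path
    for child in graph.get(node, []):
        yield from hdf5_paths(graph, child, path)


# N4 is shared by N2 and N3, so its param P1 appears under two paths
graph = {"N1": ["N2", "N3"], "N2": ["N4"], "N3": ["N4"], "N4": ["P1"]}
paths = list(hdf5_paths(graph, "N1"))

# Appendable saves: each saved sample adds a row along a new leading axis
p1_value = np.array([0.5, 1.5])           # a param value with shape (2,)
samples = np.empty((0, *p1_value.shape))  # no samples yet: shape (0, 2)
for step in range(3):
    samples = np.concatenate([samples, (p1_value + step)[None]], axis=0)
```

After the loop, `samples` has shape `(3, 2)`: the sample count is the first dimension and the remaining dimensions match the param value.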

property subgraphs: set[Node]#
to(device=None, dtype=None)[source]#

Moves and/or casts the values of the Node to a particular device and/or dtype.

Parameters:
  • device ((Optional[torch.device], optional)) – The device to move the values to. Defaults to None.

  • dtype ((Optional[torch.dtype], optional)) – The desired data type. Defaults to None.

topological_ordering() tuple[Node][source]#

Return a topological ordering of the graph below the current node. Uses Iterative Deepening DFS (Post-Order) to resolve dependencies.
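A post-order depth-first search is the core of this kind of ordering: a node is emitted only after all of its children, so dependencies come first. The sketch below is a plain recursive post-order DFS over a hypothetical dict-of-children graph, not the iterative-deepening variant mentioned above, and it assumes the graph is acyclic (there is no cycle detection).

```python
def topological_ordering(graph, root):
    """Return nodes below `root` so that every child appears before its parents.

    graph -- dict mapping a node to a list of its children (hypothetical layout)
    """
    order, visited = [], set()

    def visit(node):
        if node in visited:
            return  # shared children (multiple parents) are emitted only once
        visited.add(node)
        for child in graph.get(node, []):
            visit(child)
        order.append(node)  # post-order: appended after all of its children

    visit(root)
    return tuple(order)
```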

Unlink the current Node object from another Node object which is a child.

update_graph()[source]#

Triggers a call to all parents that the graph below them has been updated. The base Node object does nothing with this information, but other node types may use this to update internal state.
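The parent/child bookkeeping and upward update_graph notification described for Node can be sketched with a toy class. `ToyNode` and its `updates` counter are hypothetical; the real Node tracks more state. In a DAG with shared children, an ancestor may be notified once per path, as it is here.

```python
class ToyNode:
    """Hypothetical minimal node: tracks parents/children and propagates updates upward."""

    def __init__(self, name):
        self.name = name
        self.children = {}
        self.parents = set()
        self.updates = 0  # counts how often this node was told its subgraph changed

    def link(self, key, child):
        self.children[key] = child
        child.parents.add(self)
        self.update_graph()

    def unlink(self, key):
        child = self.children.pop(key)
        child.parents.discard(self)
        self.update_graph()

    def update_graph(self):
        # Record the change locally, then notify every parent in turn
        self.updates += 1
        for parent in self.parents:
            parent.update_graph()
```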

class caskade.NodeCollection(name: str | None = None, link: Node | tuple[Node] | None = None, description: str = '')[source]#

Bases: Node, GetSetValues

copy()[source]#
deepcopy()[source]#
property dynamic#
property dynamic_param_groups: tuple[int]#
property dynamic_params: tuple[Param]#
property pointer_params: tuple[Param]#
property static#
property static_params: tuple[Param]#
to_dynamic(children_only=True)[source]#

Change all parameters to dynamic parameters.

Parameters:

children_only ((bool, optional)) – If True, only convert the children of this module to dynamic. If False, convert all parameters in the graph below this module. Defaults to True.

to_static(children_only=True)[source]#

Change all parameters to static parameters.

Parameters:

children_only ((bool, optional)) – If True, only convert children of this module. If False, convert all parameters in the graph below this module. Defaults to True.

exception caskade.NodeConfigurationError[source]#

Bases: CaskadeException

Class for node configuration exceptions in caskade.

class caskade.NodeList(iterable=(), name=None)[source]#

Bases: NodeCollection, list

append(node)[source]#

Append object to the end of the list.

clear()[source]#

Remove all items from list.

extend(iterable)[source]#

Extend list by appending elements from the iterable.

property graphviz_style#
insert(index, node)[source]#

Insert object before index.

pop(index=-1)[source]#

Remove and return item at index (default last).

Raises IndexError if list is empty or index is out of range.

remove(value)[source]#

Remove first occurrence of value.

Raises ValueError if the value is not present.

class caskade.NodeTuple(iterable=None, name=None)[source]#

Bases: NodeCollection, tuple

property graphviz_style#
class caskade.OverrideParam(param: Param, value)[source]#

Bases: object

Context manager to override a parameter value. Only inside an OverrideParam will the parameter be set to the new value.
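The temporary-override behavior can be sketched with `contextlib.contextmanager`. `ToyParam` and `override` are hypothetical stand-ins; the real OverrideParam is a class taking `(param, value)`, and how it interacts with dynamic values is not shown here. The key point is that the original value is restored on exit, even if the body raises.

```python
from contextlib import contextmanager


class ToyParam:
    """Hypothetical stand-in for a parameter holding a single value."""

    def __init__(self, name, value):
        self.name = name
        self.value = value


@contextmanager
def override(param, value):
    """Temporarily replace `param.value`, restoring the original on exit."""
    original = param.value
    param.value = value
    try:
        yield param
    finally:
        # Restore even if the body of the with-block raises
        param.value = original
```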

class caskade.Param(name: str, value: Annotated[Tensor, 'One of: torch.Tensor, numpy.ndarray, jax.numpy.ndarray depending on the chosen backend.'] | float | int | None = None, shape: tuple[int, ...] | None = None, cyclic: bool = False, valid: tuple[Annotated[Tensor, 'One of: torch.Tensor, numpy.ndarray, jax.numpy.ndarray depending on the chosen backend.'] | float | int | None] | None = None, units: str | None = None, dynamic: bool | None = None, group: int = 0, batch_shape: tuple[int] | None = None, dtype: Any | None = None, device: Any | None = None, **kwargs)[source]#

Bases: Node

Node to represent a parameter in the graph.

The Param object is used to represent a parameter in the graph. During runtime this will represent a value which can be used in various calculations. The Param object can be set to a constant value (static); None meaning the value is to be provided at runtime (dynamic); another Param object meaning it will take on that value at runtime (pointer); or a function of other Param objects to be computed at runtime (also pointer, see user guides). These options allow users to flexibly set the behavior of the simulator.

Examples

Example making some Param objects:

from math import pi
from caskade import Param

p1 = Param("p1", (1.0, 2.0)) # constant value, length 2 vector
p2 = Param("p2", None, (2,2)) # dynamic 2x2 matrix value
p3 = Param("p3", p1) # pointer to another parameter
p4 = Param("p4", lambda p: p.children["other"].value * 2) # arbitrary function of another parameter (assumes a child param linked under the key "other")
p5 = Param("p5", valid=(0.0,2*pi), units="radians", cyclic=True) # parameter with metadata

Parameters:
  • name ((str)) – The name of the parameter.

  • value ((Optional[Union[ArrayLike, float, int]], optional)) – The value of the parameter. Defaults to None meaning dynamic.

  • shape ((Optional[tuple[int, ...]], optional)) – The shape of the parameter. Defaults to () meaning scalar.

  • cyclic ((bool, optional)) – Whether the parameter is cyclic, imposing periodic boundary conditions, such as a rotation from 0 to 2pi. Defaults to False.

  • valid ((Optional[tuple[Union[ArrayLike, float, int, None]]], optional)) – The valid range of the parameter. Defaults to None, meaning the entire range from -inf to inf is valid.

  • units ((Optional[str], optional)) – The units of the parameter. Defaults to None.

  • dynamic ((bool, optional)) – If True, force the param to be dynamic. If a value is also provided and the param is dynamic, that value serves as a default at call time.

  • batched ((bool, optional)) – If True, the param is assumed batched and the shape may now take the form (*B, *D), where *D is the shape of the value.

  • dtype ((Optional[Any], optional)) – The data type of the parameter. Defaults to None meaning the data type will be inferred from the value.

  • device ((Optional[Any], optional)) – The device of the parameter. Defaults to None meaning the device will be inferred from the value.
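One common way to impose the periodic boundary condition described for `cyclic` is to wrap values into the valid interval with a modulo. This is an assumed approach for illustration, not necessarily what caskade does internally; `wrap_cyclic` is a hypothetical name.

```python
import numpy as np


def wrap_cyclic(value, valid):
    """Map `value` into the half-open interval [lo, hi) given by `valid`."""
    lo, hi = valid
    # np.mod returns a result with the sign of the divisor, so negatives wrap correctly
    return lo + np.mod(np.asarray(value) - lo, hi - lo)
```

For a rotation angle with `valid=(0.0, 2*pi)`, an input of `2*pi + 0.5` wraps to `0.5` and `-0.5` wraps to `2*pi - 0.5`.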

property batch_shape: tuple[int, ...]#
property batched: bool#
property cyclic: bool#
property device: str | None#
property dtype: str | None#
property dynamic: bool#
property graphviz_style#
property group: int#
is_valid(value=None) bool[source]#

Check if a given value is valid given this parameter's allowed (valid) range.
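A minimal sketch of such a range check, assuming inclusive bounds and that either bound may be None (unbounded on that side), consistent with the `valid` parameter description. The `is_valid` below is a standalone hypothetical function, not the method itself.

```python
import numpy as np


def is_valid(value, valid=None):
    """Return True if every element of `value` lies within `valid` = (lo, hi).

    Either bound may be None, meaning unbounded on that side (assumed convention).
    """
    if valid is None:
        return True  # no range given: all of -inf to inf is valid
    lo, hi = valid
    value = np.asarray(value)
    if lo is not None and np.any(value < lo):
        return False
    if hi is not None and np.any(value > hi):
        return False
    return True
```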

property node_str: str#

Returns a string representation of the node for graph visualization.

property node_type#
property npvalue: ndarray#
property pointer: bool#
property shape: tuple[int, ...]#
property static: bool#
to(device=None, dtype=None) Param[source]#

Moves and/or casts the values of the parameter.

Parameters:
  • device ((Optional[torch.device], optional)) – The device to move the values to. Defaults to None.

  • dtype ((Optional[torch.dtype], optional)) – The desired data type. Defaults to None.

to_dynamic(value=<object object>)[source]#

Change this parameter to a dynamic parameter. If a value is provided, this will be set as the dynamic value.

to_pointer(value, link=())[source]#

Change this parameter to a pointer parameter. If a value is provided this will be set as the pointer. Either provide a Param object to point to its value, or provide a callable function to be called at runtime. It is also possible to provide a tuple of nodes to link to while creating the pointer.

to_static(value=<object object>)[source]#

Change this parameter to a static parameter. If a value is provided this will be set as the static value.
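The static, dynamic, and pointer states can be summarized by how each one produces a value at call time. In this toy sketch (`ToyParam`, `kind`, and `resolve` are hypothetical), a pointer either follows another parameter or calls a function. For brevity the function receives a mapping of runtime values, whereas in caskade's own examples the function receives the Param itself.

```python
class ToyParam:
    """Hypothetical parameter that is static, dynamic, or a pointer."""

    def __init__(self, name, value=None):
        self.name = name
        self._value = value

    @property
    def kind(self):
        if callable(self._value) or isinstance(self._value, ToyParam):
            return "pointer"
        return "dynamic" if self._value is None else "static"

    def resolve(self, runtime):
        """Return the value for one call; `runtime` maps names to dynamic values."""
        if isinstance(self._value, ToyParam):
            return self._value.resolve(runtime)  # follow the pointer to another param
        if callable(self._value):
            return self._value(runtime)          # compute from runtime values
        if self._value is None:
            return runtime[self.name]            # dynamic: supplied at call time
        return self._value                       # static: constant value
```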

property valid: tuple[Annotated[Tensor, 'One of: torch.Tensor, numpy.ndarray, jax.numpy.ndarray depending on the chosen backend.'] | None, Annotated[Tensor, 'One of: torch.Tensor, numpy.ndarray, jax.numpy.ndarray depending on the chosen backend.'] | None]#
property value: Annotated[Tensor, 'One of: torch.Tensor, numpy.ndarray, jax.numpy.ndarray depending on the chosen backend.'] | None#
exception caskade.ParamConfigurationError[source]#

Bases: NodeConfigurationError

Class for parameter configuration exceptions in caskade.

exception caskade.ParamTypeError[source]#

Bases: CaskadeException

Class for exceptions related to the type of a parameter in caskade.

exception caskade.SaveStateWarning[source]#

Bases: CaskadeWarning

Warning raised when an issue occurs while saving a state.

class caskade.ValidContext(module: Module)[source]#

Bases: object

Context manager to set valid values for parameters. Only inside a ValidContext will parameters automatically be assumed valid.

class caskade.active_cache(func)[source]#

Bases: object

Caches the first evaluated result of a Module method for the duration of a simulation.

This decorator ensures that an expensive method is executed exactly once per active simulation run. Once calculated, subsequent calls to the decorated method will return the stored value, ignoring any arguments passed to it.

WARNING:

If the method is called multiple times with different arguments in one simulation, the cached result will still be returned, which may lead to unexpected behavior. Use with caution!
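The caching rule above, including the pitfall that later arguments are ignored, can be reproduced with a small descriptor-based decorator. `first_result_cache`, its `reset` method, and `Model` are hypothetical; in caskade the cache lifetime is tied to the active simulation, which the explicit `reset` call stands in for here.

```python
import functools

_MISSING = object()  # sentinel meaning "nothing cached yet"


class first_result_cache:
    """Hypothetical decorator: cache a method's first result until `reset` is called.

    Later calls return the cached value and ignore their arguments.
    """

    def __init__(self, func):
        self.func = func
        self.attr = f"_cache_{func.__name__}"

    def __get__(self, obj, objtype=None):
        if obj is None:
            return self  # accessed on the class: expose the decorator (and reset)

        @functools.wraps(self.func)
        def wrapper(*args, **kwargs):
            cached = obj.__dict__.get(self.attr, _MISSING)
            if cached is _MISSING:
                cached = self.func(obj, *args, **kwargs)
                obj.__dict__[self.attr] = cached
            return cached

        return wrapper

    def reset(self, obj):
        """Clear the cached result, e.g. when a new simulation run starts."""
        obj.__dict__.pop(self.attr, None)


class Model:
    def __init__(self):
        self.calls = 0

    @first_result_cache
    def expensive(self, x):
        self.calls += 1
        return x * 2
```

Calling `m.expensive(1)` and then `m.expensive(5)` returns `2` both times, which is exactly the behavior the warning describes.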

Note

If you are stacking multiple decorators on a method (such as @forward or @jax.jit), @active_cache MUST be the outermost (top) decorator.

Example:

import jax
import jax.numpy as jnp
import numpy as np
from caskade import Module, Param, forward, active_cache

class FluxModel(Module):
    def __init__(self, nodes, x, M):
        super().__init__()
        self.nodes = nodes
        self.x = Param("x", x)
        self.M = Param("M", M)

    @active_cache
    @jax.jit  # Notice active_cache is placed at the top
    @forward
    def compute_intrinsic_sed(self, w, x, M):
        print("Computing SED...")
        return jnp.interp(w, self.nodes, x * M)

    @forward
    def compute_flux(self, wavelengths):
        sed = self.compute_intrinsic_sed(wavelengths)  # Cached after first call
        flux = jnp.sum(sed)
        sed = self.compute_intrinsic_sed(wavelengths)  # Returns cached result, no print
        peak = jnp.max(sed)
        return flux, peak

model = FluxModel(np.linspace(400, 700, 10), x=1.0, M=np.random.rand(10))
wavelengths = np.linspace(400, 700, 100)  # example wavelength grid

# Compute flux only calls compute_intrinsic_sed once due to caching
flux, peak = model.compute_flux(wavelengths)

caskade.forward(method)[source]#

Decorator to define a forward method for a module.

Parameters:

method ((Callable)) – The forward method to be decorated.

Examples

Standard usage of the forward decorator:

from caskade import Module, Param, forward

class ExampleSim(Module):
    def __init__(self, a, b, c):
        super().__init__("example_sim")
        self.a = a
        self.b = Param("b", b)
        self.c = Param("c", c)

    @forward
    def example_func(self, x, b=None):
        return x + self.a + b

E = ExampleSim(a=1, b=None, c=3)  # b is dynamic, so its value is supplied via params
print(E.example_func(4, params=[5]))
# Output: 10

Returns:

The decorated forward method.

Return type:

Callable

caskade.test()[source]#

Basic integration test of caskade to ensure that the library is functioning correctly.