Decorators
- caskade.forward(method)
Decorator to define a forward method for a module.
Manages parameter passing and activation for the decorated method. When called, it automatically fills keyword arguments from the module's Param children and handles parameter overrides and the active context.
- Parameters:
method (Callable) – The forward method to be decorated.
- Returns:
The decorated forward method.
- Return type:
Callable
Examples
Standard usage of the forward decorator:
```python
from caskade import Module, Param, forward


class ExampleSim(Module):
    def __init__(self, a, b, c):
        super().__init__("example_sim")
        self.a = a
        self.b = Param("b", b)
        self.c = Param("c", c)

    @forward
    def example_func(self, x, b=None):
        return x + self.a + b


E = ExampleSim(a=1, b=None, c=3)
print(E.example_func(4, params=[5]))  # Output: 10 (b is filled from params: 4 + 1 + 5)
```
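To make the keyword-filling behavior concrete, here is a minimal, stdlib-only sketch of a forward-style decorator. It is not caskade's implementation: ToyParam, ToyModule, param_children and toy_forward are hypothetical names. Mirroring the example above, it assumes that a parameter whose value is None is dynamic and is filled, in order, from the params list. It also omits active-context handling entirely.

```python
import functools
import inspect


class ToyParam:
    """Hypothetical stand-in for a parameter; value=None marks it as dynamic."""

    def __init__(self, name, value=None):
        self.name = name
        self.value = value


class ToyModule:
    def param_children(self):
        # Hypothetical helper: ToyParam attributes in the order they were assigned.
        return [v for v in vars(self).values() if isinstance(v, ToyParam)]


def toy_forward(method):
    """Sketch of a forward-style decorator (not caskade's actual code)."""
    sig = inspect.signature(method)

    @functools.wraps(method)
    def wrapper(self, *args, params=None, **kwargs):
        supplied = list(params or [])
        for p in self.param_children():
            if p.value is None:
                # Dynamic parameter: consume the next supplied value, if any.
                value = supplied.pop(0) if supplied else None
            else:
                # Static parameter: use the stored value.
                value = p.value
            # Only fill arguments the method accepts; explicit kwargs act as overrides.
            if p.name in sig.parameters and p.name not in kwargs:
                kwargs[p.name] = value
        return method(self, *args, **kwargs)

    return wrapper


class ExampleSim(ToyModule):
    def __init__(self, a, b, c):
        self.a = a
        self.b = ToyParam("b", b)
        self.c = ToyParam("c", c)

    @toy_forward
    def example_func(self, x, b=None):
        return x + self.a + b


E = ExampleSim(a=1, b=None, c=3)
print(E.example_func(4, params=[5]))  # 10
```

Explicit keyword arguments take precedence in this sketch, so E.example_func(4, b=2) returns 7 without any params.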
- class caskade.active_cache(func)
Bases: object
Caches the first evaluated result of a Module method for the duration of a simulation.
This decorator ensures that an expensive method is executed exactly once per active simulation run. Once calculated, subsequent calls to the decorated method will return the stored value, ignoring any arguments passed to it.
Warning
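If the method is called multiple times with different arguments in one simulation, the cached result from the first call will still be returned, which may lead to unexpected behavior. Use with caution: only apply it to methods whose result does not depend on their arguments within a single simulation.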
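The caching behavior, including the pitfall in the warning above, can be illustrated with a small stdlib-only sketch. It is a hypothetical illustration, not caskade's code: ToyModule, its active flag, run, and toy_active_cache are invented names. The decorator is written as a plain function rather than a class, and it assumes that outside a run the method is simply evaluated without caching.

```python
import functools


class ToyModule:
    """Hypothetical module with an 'active' flag marking a running simulation."""

    def __init__(self):
        self.active = False
        self._cache = {}

    def run(self, fn, *args):
        # Start a simulation run with an empty cache; clear it again afterwards.
        self.active = True
        self._cache = {}
        try:
            return fn(*args)
        finally:
            self.active = False
            self._cache = {}


def toy_active_cache(func):
    """Sketch: cache the first result per active run, ignoring arguments."""

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        if not self.active:
            return func(self, *args, **kwargs)  # assumption: no caching outside a run
        key = func.__name__
        if key not in self._cache:
            self._cache[key] = func(self, *args, **kwargs)
        return self._cache[key]  # arguments of later calls are ignored

    return wrapper


class Sim(ToyModule):
    def __init__(self):
        super().__init__()
        self.calls = 0

    @toy_active_cache
    def expensive(self, x):
        self.calls += 1
        return x * 2

    def simulate(self):
        first = self.expensive(1)    # computed: 2
        second = self.expensive(10)  # cached: still 2, argument ignored
        return first, second


sim = Sim()
print(sim.run(sim.simulate))  # (2, 2)
print(sim.calls)              # 1
```

The second call passes 10 but still receives 2, which is the behavior the warning describes. A new run starts with an empty cache, so the method is computed again.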
Notes
If you are stacking multiple decorators on a method (such as @forward or @jax.jit), @active_cache MUST be the outermost (top) decorator.
Examples
```python
import jax
import jax.numpy as jnp
import numpy as np
from caskade import Module, Param, active_cache, forward


class FluxModel(Module):
    def __init__(self, nodes, x, M):
        super().__init__()
        self.nodes = nodes
        self.x = Param("x", x)
        self.M = Param("M", M)

    @active_cache  # Notice active_cache is placed at the top
    @jax.jit
    @forward
    def compute_intrinsic_sed(self, w, x, M):
        print("Computing SED...")
        return jnp.interp(w, self.nodes, x * M)

    @forward
    def compute_flux(self, wavelengths):
        sed = self.compute_intrinsic_sed(wavelengths)  # Cached after first call
        flux = jnp.sum(sed)
        sed = self.compute_intrinsic_sed(wavelengths)  # Returns cached result, no print
        peak = jnp.max(sed)
        return flux, peak


model = FluxModel(np.linspace(400, 700, 10), x=1.0, M=np.random.rand(10))
wavelengths = np.linspace(400, 700, 100)  # Example wavelength grid to evaluate

# compute_flux computes the SED only once due to caching
flux, peak = model.compute_flux(wavelengths)
```
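The ordering rule in the Notes rests on how Python stacks decorators: they are applied bottom-up, so the top decorator wraps all the others and is the first to see each call. One consequence is that, with @active_cache on top, a cached result can be returned before the inner wrappers (here jax.jit and forward) run at all. The generic sketch below only demonstrates that Python ordering; it does not show caskade internals, and the tag decorator and its labels are hypothetical.

```python
def tag(label, log):
    """Hypothetical decorator factory that records the order in which wrappers run."""

    def decorator(func):
        def wrapper(*args, **kwargs):
            log.append(label)
            return func(*args, **kwargs)

        return wrapper

    return decorator


log = []


# Equivalent to: f = tag("top")(tag("middle")(tag("bottom")(f)))
@tag("top", log)
@tag("middle", log)
@tag("bottom", log)
def f(x):
    return x + 1


f(1)
print(log)  # ['top', 'middle', 'bottom']
```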