Decorators

caskade.forward(method)

Decorator to define a forward method for a module.

Manages parameter passing and activation for the decorated method. When called, it automatically fills keyword arguments from the module’s Param children, applies parameter overrides (such as values passed through the params argument), and manages the module’s active context.
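To make the parameter-filling step concrete, here is a toy pure-Python sketch. It is not caskade's implementation: ToyParam, ToyModule, and toy_forward are hypothetical names, and the sketch assumes that a parameter whose value is None is dynamic, that the params list supplies dynamic values in the order the parameters were defined, and that keyword arguments passed explicitly take precedence.

```python
import functools
import inspect


class ToyParam:
    """Hypothetical stand-in for Param: a value of None marks it as dynamic."""

    def __init__(self, name, value=None):
        self.name = name
        self.value = value

    @property
    def dynamic(self):
        return self.value is None


class ToyModule:
    """Hypothetical stand-in for Module that owns ToyParam attributes."""

    def toy_params(self):
        # Collect ToyParam attributes in the order they were assigned.
        return [v for v in vars(self).values() if isinstance(v, ToyParam)]


def toy_forward(method):
    """Fill missing keyword arguments from the module's ToyParam children."""
    sig = inspect.signature(method)

    @functools.wraps(method)
    def wrapper(self, *args, params=None, **kwargs):
        supplied = iter(params or [])
        for p in self.toy_params():
            # Assumption: every dynamic parameter consumes one entry of params,
            # even if this particular method does not use it.
            value = next(supplied) if p.dynamic else p.value
            # Only fill arguments the method accepts and the caller did not pass.
            if p.name in sig.parameters and p.name not in kwargs:
                kwargs[p.name] = value
        return method(self, *args, **kwargs)

    return wrapper


class ToySim(ToyModule):
    def __init__(self, a, b, c):
        self.a = a                  # plain attribute, not a parameter
        self.b = ToyParam("b", b)
        self.c = ToyParam("c", c)

    @toy_forward
    def example_func(self, x, b=None):
        return x + self.a + b


sim = ToySim(a=1, b=None, c=3)
print(sim.example_func(4, params=[5]))  # b is dynamic and taken from params: 4 + 1 + 5
```

The toy mirrors the ExampleSim example in this section: b has no value, so it is taken from params, while the static c is skipped because example_func does not accept it.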

Parameters:

method (Callable) – The forward method to be decorated.

Returns:

The decorated forward method.

Return type:

Callable

Examples

Standard usage of the forward decorator:

from caskade import Module, Param, forward

class ExampleSim(Module):
    def __init__(self, a, b, c):
        super().__init__("example_sim")
        self.a = a
        self.b = Param("b", b)
        self.c = Param("c", c)

    @forward
    def example_func(self, x, b=None):
        return x + self.a + b

E = ExampleSim(a=1, b=None, c=3)  # b has no value, so it is supplied at call time
print(E.example_func(4, params=[5]))  # b=5, so 4 + 1 + 5
# Output: 10

class caskade.active_cache(func)

Bases: object

Caches the first evaluated result of a Module method for the duration of a simulation.

This decorator ensures that an expensive method is executed exactly once per active simulation run. Once calculated, subsequent calls to the decorated method will return the stored value, ignoring any arguments passed to it.
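The caching rule can be illustrated with a toy pure-Python sketch. ToyActiveModule, run, and toy_active_cache are hypothetical names, not caskade's API, and the sketch assumes the cache is cleared when the run ends and that calls made outside an active run are not cached.

```python
import contextlib
import functools


class ToyActiveModule:
    """Hypothetical module with an 'active' flag and a per-run cache."""

    def __init__(self):
        self.active = False
        self._cache = {}

    @contextlib.contextmanager
    def run(self):
        # Entering a run activates the module; leaving it clears cached results.
        self.active = True
        try:
            yield self
        finally:
            self.active = False
            self._cache.clear()


def toy_active_cache(method):
    """Return the first result computed during an active run, ignoring arguments."""

    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        if not self.active:
            # Assumption: outside an active run nothing is cached.
            return method(self, *args, **kwargs)
        if method not in self._cache:
            self._cache[method] = method(self, *args, **kwargs)
        return self._cache[method]

    return wrapper


class Expensive(ToyActiveModule):
    def __init__(self):
        super().__init__()
        self.calls = 0

    @toy_active_cache
    def compute(self, x):
        self.calls += 1
        return x * 2


m = Expensive()
with m.run():
    first = m.compute(3)    # computed: 6
    second = m.compute(10)  # cached: still 6, the new argument is ignored
print(first, second, m.calls)  # 6 6 1

with m.run():
    third = m.compute(10)   # new run, fresh cache: 20
print(third)
```

The second call inside the first run shows the pitfall described in the warning: a different argument does not trigger recomputation.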

Warning

If the method is called multiple times with different arguments during one simulation, the first cached result is still returned, even when the new arguments would produce a different value. This can lead to silent errors. Use with caution!

Notes

If you are stacking multiple decorators on a method (such as @forward or @jax.jit), @active_cache MUST be the outermost (top) decorator.
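Python applies stacked decorators bottom-up, so the top decorator wraps everything beneath it and sees each call first. The toy decorators below (hypothetical, not caskade or JAX code) show one consequence of that ordering: with the cache on top, later calls skip the inner wrappers entirely, while with the cache underneath, the outer wrapper still runs on every call. This only illustrates what "outermost" means; it is not a description of caskade's internal reasons for the rule.

```python
import functools


def counted(label, log):
    """Toy wrapper that records each time it is entered."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            log.append(label)
            return func(*args, **kwargs)
        return wrapper
    return decorator


def cache_once(func):
    """Toy cache: keep the first result and ignore later arguments."""
    store = []

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if not store:
            store.append(func(*args, **kwargs))
        return store[0]
    return wrapper


log_top, log_bottom = [], []


@cache_once                    # outermost: sees every call first
@counted("inner", log_top)
def cache_on_top(x):
    return x + 1


@counted("outer", log_bottom)  # outermost: sees every call first
@cache_once
def cache_on_bottom(x):
    return x + 1


for _ in range(3):
    cache_on_top(1)
    cache_on_bottom(1)

print(len(log_top), len(log_bottom))  # 1 3
```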

Examples

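import jax
import jax.numpy as jnp
import numpy as np
from caskade import Module, Param, forward, active_cache

class FluxModel(Module):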
    def __init__(self, nodes, x, M):
        super().__init__()
        self.nodes = nodes
        self.x = Param("x", x)
        self.M = Param("M", M)

    @active_cache
    @jax.jit  # Notice active_cache is placed at the top
    @forward
    def compute_intrinsic_sed(self, w, x, M):
        print("Computing SED...")
        return jnp.interp(w, self.nodes, x * M)

    @forward
    def compute_flux(self, wavelengths):
        sed = self.compute_intrinsic_sed(wavelengths)  # Cached after first call
        flux = jnp.sum(sed)
        sed = self.compute_intrinsic_sed(wavelengths)  # Returns cached result, no print
        peak = jnp.max(sed)
        return flux, peak

model = FluxModel(np.linspace(400, 700, 10), x=1.0, M=np.random.rand(10))

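wavelengths = np.linspace(450, 650, 50)  # illustrative query wavelengths

# compute_flux calls compute_intrinsic_sed twice, but the SED is computed only once due to caching
flux, peak = model.compute_flux(wavelengths)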