"""caskade.decorators: the ``forward`` and ``active_cache`` decorators."""

import inspect
import functools
from contextlib import ExitStack

from .context import ActiveContext, OverrideParam
from .param import Param

# Public API of this module.
__all__ = (
    "forward",
    "active_cache",
)


def forward(method):
    """
    Decorator to define a forward method for a module.

    The decorated method receives the current values of the module's
    parameters as keyword arguments (obtained from ``self.fill_kwargs``), so
    the method body can use them like ordinary arguments.

    Parameters
    ----------
    method: (Callable)
        The forward method to be decorated. Its signature is inspected once,
        when the decorator is applied, and the parameter names are handed to
        ``self.fill_kwargs`` on every call.

    Returns
    -------
    Callable
        The decorated forward method. When the module is not online, parameter
        values may be supplied either as a ``params=`` keyword argument or as
        the last positional argument.

    Notes
    -----
    - If ``self.online`` is truthy, a keyword argument whose name matches a
      ``Param`` child of the module overrides that parameter for this call
      only (through ``OverrideParam``) and is not forwarded to ``method``.
    - Otherwise the call runs inside ``ActiveContext(self)``. Values given as
      ``params=...`` or as the trailing positional argument are loaded with
      ``self.fill_params`` before ``method`` is invoked.
    - If the module has no dynamic parameters, no ``params`` argument is
      needed. When the supplied arguments do not bind to the signature of
      ``method``, the trailing positional argument is assumed to be an empty
      ``params`` container and is dropped.
    - If the module has dynamic parameters but none are supplied, ``method`` is
      called inside ``ActiveContext(self)`` without calling ``fill_params``.
    - Explicit keyword arguments always take precedence over the values
      returned by ``self.fill_kwargs``.

    Examples
    --------
    Standard usage of the forward decorator::

        class ExampleSim(Module):
            def __init__(self, a, b, c):
                super().__init__("example_sim")
                self.a = a
                self.b = Param("b", b)
                self.c = Param("c", c)

            @forward
            def example_func(self, x, b=None):
                return x + self.a + b

        E = ExampleSim(a=1, b=None, c=3)
        print(E.example_func(4, params=[5]))
        # Output: 10
    """
    signature = inspect.signature(method)
    arg_names = tuple(signature.parameters.keys())

    def _invoke(module, positional, explicit):
        # Module-supplied Param values first; explicit keyword arguments win.
        merged = {**module.fill_kwargs(arg_names), **explicit}
        return method(module, *positional, **merged)

    @functools.wraps(method)
    def wrapped(self, *args, **kwargs):
        if self.online:
            # Keyword arguments naming a Param child become temporary
            # overrides for this single call; everything else is forwarded.
            with ExitStack() as stack:
                remaining = {}
                for name, value in kwargs.items():
                    if name in self.children and isinstance(self.children[name], Param):
                        stack.enter_context(OverrideParam(self.children[name], value))
                    else:
                        remaining[name] = value
                return _invoke(self, args, remaining)

        if "params" in kwargs:
            params = kwargs.pop("params")
            with ActiveContext(self):
                self.fill_params(params)
                return _invoke(self, args, kwargs)

        if len(self.dynamic_params) == 0:
            with ActiveContext(self):
                merged = {**self.fill_kwargs(arg_names), **kwargs}
                try:
                    signature.bind(self, *args, **merged)
                except TypeError:
                    # Arguments do not fit the signature: the caller is assumed
                    # to have passed an empty params container last.
                    args = args[:-1]
                return method(self, *args, **merged)

        if not args:
            # Dynamic params exist but none were supplied.
            with ActiveContext(self):
                return _invoke(self, args, kwargs)

        # The trailing positional argument carries the parameter values.
        params, args = args[-1], args[:-1]
        with ActiveContext(self):
            self.fill_params(params)
            return _invoke(self, args, kwargs)

    return wrapped
class active_cache:
    """
    Caches the first evaluated result of a Module method for the duration of a
    simulation.

    This decorator ensures that an expensive method is executed exactly once
    per active simulation run. Once calculated, subsequent calls to the
    decorated method return the stored value, ignoring any arguments passed
    to it. When the instance is not active (``instance.active`` is falsy), the
    method is called directly and nothing is cached.

    Defining a cached method on a class also gives that class a
    ``reset_active_cache()`` method (unless the class already defines one in
    its own body). It deletes every cached result on the instance so the next
    call recomputes it. This decorator never calls it by itself; presumably
    the simulation context does so between runs (NOTE(review): confirm).

    **WARNING**: If the method is called multiple times with different
    arguments in one simulation, the cached result will still be returned,
    which may lead to unexpected behavior. Use with caution!

    Note: If you are stacking multiple decorators on a method (such as
    `@forward` or `@jax.jit`), `@active_cache` MUST be the outermost (top)
    decorator.

    Example::

        class FluxModel(Module):
            def __init__(self, nodes, x, M):
                super().__init__()
                self.nodes = nodes
                self.x = Param("x", x)
                self.M = Param("M", M)

            @active_cache  # Notice active_cache is placed at the top
            @jax.jit
            @forward
            def compute_intrinsic_sed(self, w, x, M):
                print("Computing SED...")
                return jnp.interp(w, self.nodes, x * M)

            @forward
            def compute_flux(self, wavelengths):
                sed = self.compute_intrinsic_sed(wavelengths)  # Cached after first call
                flux = jnp.sum(sed)
                sed = self.compute_intrinsic_sed(wavelengths)  # Returns cached result, no print
                peak = jnp.max(sed)
                return flux, peak

        model = FluxModel(np.linspace(400, 700, 10), x=1.0, M=np.random.rand(10))
        wavelengths = np.linspace(400, 700, 100)
        # compute_flux only calls compute_intrinsic_sed once due to caching
        flux, peak = model.compute_flux(wavelengths)
    """

    def __init__(self, func):
        self.func = func
        # Name of the per-instance attribute holding the single cached result.
        # This default is used only if the descriptor is never bound to a
        # class attribute; ``__set_name__`` replaces it with a name derived
        # from the attribute, so two cached methods wrapping functions that
        # share a ``__name__`` (factories, lambdas) do not overwrite each other.
        self.cache_attr = f"_active_cache_{func.__name__}"
        # Update wrapper to preserve function metadata
        functools.update_wrapper(self, func)

    def __set_name__(self, owner, name):
        """Register this cache on ``owner`` and inject ``reset_active_cache``."""
        # Key the cache on the attribute name, which is unique within a class.
        self.cache_attr = f"_active_cache_{name}"

        if "_cache_attrs" not in owner.__dict__:
            # Give each class its own set, seeded from its parents, so adding
            # a cached method to a subclass never mutates a parent's set.
            inherited_attrs = set()
            for base in owner.__bases__:
                inherited_attrs.update(getattr(base, "_cache_attrs", ()))
            owner._cache_attrs = inherited_attrs
        owner._cache_attrs.add(self.cache_attr)

        if "reset_active_cache" not in owner.__dict__:

            def reset_active_cache(instance):
                """Deletes the cached attributes to force a recalculation."""
                for attr in instance.__class__._cache_attrs:
                    if hasattr(instance, attr):
                        delattr(instance, attr)

            owner.reset_active_cache = reset_active_cache

    def __get__(self, instance, owner):
        # Accessed on the class itself: expose the descriptor.
        if instance is None:
            return self

        @functools.wraps(self.func)
        def wrapper(*args, **kwargs):
            # If not in simulation, just call the function without caching
            if not instance.active:
                return self.func(instance, *args, **kwargs)
            # If we already have the attribute, return it immediately
            if hasattr(instance, self.cache_attr):
                return getattr(instance, self.cache_attr)
            # Run the function, save the output, and return it
            result = self.func(instance, *args, **kwargs)
            setattr(instance, self.cache_attr, result)
            return result

        return wrapper