Source code for caskade.backend

import os
import importlib
from typing import Annotated

from torch import Tensor
import numpy as np
from . import utils

# Annotation alias for array-valued arguments and returns. It is declared in
# terms of torch.Tensor; the attached metadata string records that the concrete
# array type depends on the chosen backend.
ArrayLike = Annotated[
    Tensor,
    "One of: torch.Tensor, numpy.ndarray, jax.numpy.ndarray depending on the chosen backend.",
]


class Backend:
    """Dispatch a small set of array operations to torch, jax or numpy.

    Assigning the ``backend`` property imports the matching array module into
    ``self.module`` and rebinds the public operations (``make_array``,
    ``concatenate``, ``broadcast_cat``, ``tolist``, ``view``, ``detach``,
    ``as_array``, ``to``, ``to_numpy``, ``logit``, ``sigmoid``) to the
    implementation for that backend.

    Parameters
    ----------
    backend : str or None
        One of ``"torch"``, ``"jax"`` or ``"numpy"``. When ``None`` the name is
        read from the ``CASKADE_BACKEND`` environment variable, falling back to
        ``"torch"``.

    Raises
    ------
    ValueError
        If the backend name is not one of the supported values.
    """

    def __init__(self, backend=None):
        self.backend = backend

    @property
    def backend(self):
        """Name of the currently active backend."""
        return self._backend

    @backend.setter
    def backend(self, backend):
        if backend is None:
            backend = os.getenv("CASKADE_BACKEND", "torch")
        # Load first so that an unsupported name raises before the stored
        # backend name is changed.
        self.module = self._load_backend(backend)
        self._backend = backend

    def _load_backend(self, backend):
        """Bind the operations for ``backend`` and return its array module."""
        if backend == "torch":
            self.setup_torch()
            return importlib.import_module("torch")
        elif backend == "jax":
            self.setup_jax()
            return importlib.import_module("jax.numpy")
        elif backend == "numpy":
            self.setup_numpy()
            return importlib.import_module("numpy")
        else:
            raise ValueError(f"Unsupported backend: {backend}")

    def setup_torch(self):
        """Bind the public operations to their torch implementations."""
        self.make_array = self._make_array_torch
        self._array_type = self._array_type_torch
        self.concatenate = self._concatenate_torch
        self.broadcast_cat = utils.broadcast_cat_torch
        self.tolist = self._tolist_torch
        self.view = self._view_torch
        self.detach = self._detach_torch
        self.as_array = self._as_array_torch
        self.to = self._to_torch
        self.to_numpy = self._to_numpy_torch
        self.logit = self._logit_torch
        self.sigmoid = self._sigmoid_torch

    def setup_jax(self):
        """Bind the public operations to their jax implementations.

        Also imports the top-level ``jax`` package into ``self.jax``; it is
        used by ``_to_jax``, ``_sigmoid_jax`` and ``_logit_jax``.
        """
        self.jax = importlib.import_module("jax")
        self.make_array = self._make_array_jax
        self._array_type = self._array_type_jax
        self.concatenate = self._concatenate_jax
        self.broadcast_cat = utils.broadcast_cat_jax
        self.tolist = self._tolist_jax
        self.view = self._view_jax
        self.detach = self._detach_jax
        self.as_array = self._as_array_jax
        self.to = self._to_jax
        self.to_numpy = self._to_numpy_jax
        self.logit = self._logit_jax
        self.sigmoid = self._sigmoid_jax

    def setup_numpy(self):
        """Bind the public operations to their numpy implementations."""
        self.make_array = self._make_array_numpy
        self._array_type = self._array_type_numpy
        self.concatenate = self._concatenate_numpy
        self.broadcast_cat = utils.broadcast_cat_numpy
        self.tolist = self._tolist_numpy
        self.view = self._view_numpy
        self.detach = self._detach_numpy
        self.as_array = self._as_array_numpy
        self.to = self._to_numpy
        self.to_numpy = self._to_numpy_numpy
        self.logit = self._logit_numpy
        self.sigmoid = self._sigmoid_numpy

    @property
    def array_type(self):
        """Array class of the active backend (e.g. for ``isinstance`` checks)."""
        return self._array_type()

    # --- array construction -------------------------------------------------
    # The jax/numpy variants accept and ignore extra keywords (such as
    # ``device``) so that all three backends can be called the same way.

    def _make_array_torch(self, array, dtype=None, device=None):
        return self.module.tensor(array, dtype=dtype, device=device)

    def _make_array_jax(self, array, dtype=None, **kwargs):
        return self.module.array(array, dtype=dtype)

    def _make_array_numpy(self, array, dtype=None, **kwargs):
        return self.module.array(array, dtype=dtype)

    # --- array type ---------------------------------------------------------

    def _array_type_torch(self):
        return self.module.Tensor

    def _array_type_jax(self):
        return self.module.ndarray

    def _array_type_numpy(self):
        return self.module.ndarray

    # --- concatenate --------------------------------------------------------

    def _concatenate_torch(self, arrays, axis=0):
        return self.module.cat(arrays, dim=axis)

    def _concatenate_jax(self, arrays, axis=0):
        return self.module.concatenate(arrays, axis=axis)

    def _concatenate_numpy(self, arrays, axis=0):
        return self.module.concatenate(arrays, axis=axis)

    # --- detach -------------------------------------------------------------
    # Only torch has a detach step here; jax and numpy return the input as is.

    def _detach_torch(self, array):
        return array.detach()

    def _detach_jax(self, array):
        return array

    def _detach_numpy(self, array):
        return array

    # --- tolist -------------------------------------------------------------

    def _tolist_torch(self, array):
        return array.detach().cpu().tolist()

    def _tolist_jax(self, array):
        return array.block_until_ready().tolist()

    def _tolist_numpy(self, array):
        return array.tolist()

    # --- view (implemented with reshape on every backend) -------------------

    def _view_torch(self, array, shape):
        return array.reshape(shape)

    def _view_jax(self, array, shape):
        return array.reshape(shape)

    def _view_numpy(self, array, shape):
        return array.reshape(shape)

    # --- as_array -----------------------------------------------------------

    def _as_array_torch(self, array, dtype=None, device=None):
        return self.module.as_tensor(array, dtype=dtype, device=device)

    def _as_array_jax(self, array, dtype=None, **kwargs):
        return self.module.asarray(array, dtype=dtype)

    def _as_array_numpy(self, array, dtype=None, **kwargs):
        return self.module.asarray(array, dtype=dtype)

    # --- to (dtype / device conversion) -------------------------------------

    def _to_torch(self, array, dtype=None, device=None):
        return array.to(dtype=dtype, device=device)

    def _to_jax(self, array, dtype=None, device=None):
        # Cast only on request: ``astype(None)`` would convert to the default
        # float type instead of keeping the dtype, unlike the torch path.
        if dtype is not None:
            array = array.astype(dtype)
        return self.jax.device_put(array, device=device)

    def _to_numpy(self, array, dtype=None, **kwargs):
        # ``ndarray.astype(None)`` casts to the default float dtype, so with
        # no dtype given the array is returned unchanged to keep its dtype.
        if dtype is None:
            return array
        return array.astype(dtype)

    # --- to_numpy -----------------------------------------------------------

    def _to_numpy_torch(self, array):
        return array.detach().cpu().numpy()

    def _to_numpy_jax(self, array):
        return np.array(array.block_until_ready())

    def _to_numpy_numpy(self, array):
        return array

    # --- operations forwarded directly to the backend module ----------------

    def any(self, array):
        """Return ``self.module.any(array)``."""
        return self.module.any(array)

    def all(self, array):
        """Return ``self.module.all(array)``."""
        return self.module.all(array)

    def log(self, array):
        """Return ``self.module.log(array)``."""
        return self.module.log(array)

    def exp(self, array):
        """Return ``self.module.exp(array)``."""
        return self.module.exp(array)

    def sum(self, array, axis=None):
        """Return ``self.module.sum(array, axis=axis)``."""
        return self.module.sum(array, axis=axis)

    # --- sigmoid / logit ----------------------------------------------------

    def _sigmoid_torch(self, array):
        return self.module.sigmoid(array)

    def _sigmoid_jax(self, array):
        return self.jax.nn.sigmoid(array)

    def _sigmoid_numpy(self, array):
        return 1 / (1 + self.module.exp(-array))

    def _logit_torch(self, array):
        return self.module.logit(array)

    def _logit_jax(self, array):
        return self.jax.scipy.special.logit(array)

    def _logit_numpy(self, array):
        return np.log(array / (1 - array))
# Module-level Backend instance. It is built without an explicit backend name,
# so the name is read from the CASKADE_BACKEND environment variable and falls
# back to "torch" when that variable is unset.
backend = Backend()