Utilities & Testing

caskade.utils.broadcast_cat_jax(arrays, dim=-1)

Concatenates JAX arrays with broadcasting.

Behaves like jnp.concatenate, but first broadcasts the arrays to match on all dimensions EXCEPT the concatenation dimension.

Parameters:
  • arrays (sequence of jnp.ndarray) – Arrays to concatenate.

  • dim (int) – The dimension along which to concatenate.

Returns:

The concatenated array.

Return type:

jnp.ndarray

caskade.utils.broadcast_cat_numpy(arrays, dim=-1)

Concatenates NumPy arrays with broadcasting.

Behaves like np.concatenate, but first broadcasts the arrays to match on all dimensions EXCEPT the concatenation dimension.

Parameters:
  • arrays (sequence of np.ndarray) – Arrays to concatenate.

  • dim (int) – The dimension along which to concatenate.

Returns:

The concatenated array.

Return type:

np.ndarray
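The broadcast-then-concatenate behavior described above can be sketched in plain NumPy. This is a minimal illustration, not caskade's actual implementation: the name broadcast_cat_sketch is hypothetical, and the sketch assumes standard NumPy broadcasting rules (lower-rank inputs are left-padded with size-1 dimensions) and a negative dim.

```python
import numpy as np


def broadcast_cat_sketch(arrays, dim=-1):
    """Hypothetical stand-in for illustration only; not caskade's code."""
    if dim >= 0:
        # Assumption of this sketch: only negative indices are supported.
        raise ValueError("this sketch expects a negative dim")
    arrays = [np.asarray(a) for a in arrays]
    ndim = max(a.ndim for a in arrays)
    if -dim > ndim:
        raise ValueError("dim is out of range for the given arrays")

    # Left-pad every shape with 1s so all inputs share the same rank.
    padded = [a.reshape((1,) * (ndim - a.ndim) + a.shape) for a in arrays]
    axis = ndim + dim  # position of the concatenation axis after padding

    # Find the common shape while ignoring the concatenation axis.
    masked = [p.shape[:axis] + (1,) + p.shape[axis + 1:] for p in padded]
    common = np.broadcast_shapes(*masked)

    # Broadcast each input on every axis except the concatenation axis.
    expanded = [
        np.broadcast_to(p, common[:axis] + (p.shape[axis],) + common[axis + 1:])
        for p in padded
    ]
    return np.concatenate(expanded, axis=axis)


x = np.zeros((4, 2))
y = np.array([1.0, 2.0, 3.0])
out = broadcast_cat_sketch([x, y], dim=-1)
print(out.shape)  # (4, 5)
```

Here y is repeated along the leading dimension to match the four rows of x, and only then are the two arrays joined along the last axis. A direct np.concatenate call on x and y would fail because their ranks differ.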

caskade.utils.broadcast_cat_torch(tensors, dim=-1)

Concatenates PyTorch tensors with broadcasting.

Behaves like torch.cat, but first broadcasts the tensors to match on all dimensions EXCEPT the concatenation dimension.

Parameters:
  • tensors (sequence of torch.Tensor) – Tensors to concatenate.

  • dim (int) – The dimension along which to concatenate. Must be a negative index to ensure consistency across tensors of different ranks (e.g., -1 for the last dimension).

Returns:

The concatenated tensor.

Return type:

torch.Tensor
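The reason for requiring a negative dim can be illustrated with a small shape calculation. The helper below is hypothetical (broadcast_cat_shape is not part of caskade) and uses only the standard library. It assumes the usual broadcasting convention of aligning shapes from the right, so a negative index names the same trailing axis regardless of each input's rank.

```python
def broadcast_cat_shape(shapes, dim=-1):
    """Hypothetical helper: predict the shape of a broadcasting concatenation."""
    if dim >= 0:
        raise ValueError("dim must be a negative index")
    ndim = max(len(s) for s in shapes)
    if -dim > ndim:
        raise ValueError("dim is out of range for the given shapes")

    # Align shapes from the right by left-padding with 1s.
    padded = [(1,) * (ndim - len(s)) + tuple(s) for s in shapes]
    axis = ndim + dim  # concatenation axis after padding

    result = []
    for i in range(ndim):
        sizes = [p[i] for p in padded]
        if i == axis:
            # Sizes add up along the concatenation axis.
            result.append(sum(sizes))
        else:
            # Elsewhere, sizes must agree or be 1 (standard broadcasting).
            distinct = {n for n in sizes if n != 1}
            if len(distinct) > 1:
                raise ValueError(f"cannot broadcast sizes {sizes} on axis {i}")
            result.append(distinct.pop() if distinct else 1)
    return tuple(result)


print(broadcast_cat_shape([(5, 3, 2), (2,)], dim=-1))    # (5, 3, 4)
print(broadcast_cat_shape([(5, 3, 2), (4, 2)], dim=-2))  # (5, 7, 2)
```

In the first call, a non-negative dim=2 would be valid for the rank-3 input but not for the rank-1 input, whereas dim=-1 refers to the last axis of both.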

caskade.test()

Run a basic integration test to verify caskade is installed and working.

Exercises core functionality including Module and Param creation, parameter linking, and forward method execution.

Examples

import caskade
caskade.test()
# Output: Success!