Utilities & Testing
- caskade.utils.broadcast_cat_jax(arrays, dim=-1)
Concatenates JAX arrays with broadcasting.
Behaves like jnp.concatenate, but first broadcasts the arrays to match on all dimensions EXCEPT the concatenation dimension.
- Parameters:
arrays (sequence of jnp.ndarray) – Arrays to concatenate.
dim (int) – The dimension along which to concatenate.
- Returns:
The concatenated array.
- Return type:
jnp.ndarray
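The behaviour described above can be approximated in a few lines of jax.numpy. The following is a minimal sketch, not caskade's source. The name `broadcast_cat_jax_sketch` is hypothetical, and the sketch assumes a negative `dim`; the documentation above states that restriction only for the PyTorch variant. The sketch sets the concatenation axis to 1 in every shape, broadcasts the remaining axes with `jnp.broadcast_shapes`, and then restores each array's own length along `dim` before calling `jnp.concatenate`.

```python
import jax.numpy as jnp


def broadcast_cat_jax_sketch(arrays, dim=-1):
    # Hypothetical helper illustrating the idea; not caskade's implementation.
    # Assumption: dim is negative, so it addresses the same axis in arrays of any rank.
    if dim >= 0:
        raise ValueError("this sketch expects a negative dim, e.g. -1")
    arrays = [jnp.asarray(a) for a in arrays]

    # Collapse the concatenation axis to 1 so it does not take part in broadcasting.
    masked = []
    for a in arrays:
        shape = list(a.shape)
        shape[dim] = 1
        masked.append(tuple(shape))
    common = list(jnp.broadcast_shapes(*masked))

    # Broadcast each array to the common shape, keeping its own length along dim.
    expanded = []
    for a in arrays:
        target = list(common)
        target[dim] = a.shape[dim]
        expanded.append(jnp.broadcast_to(a, tuple(target)))
    return jnp.concatenate(expanded, axis=dim)


a = jnp.arange(6.0).reshape(2, 3)  # shape (2, 3)
b = jnp.array([9.0])               # shape (1,), broadcast over the leading axis
out = broadcast_cat_jax_sketch([a, b])
print(out.shape)  # (2, 4)
```

Because array shapes are static in JAX, all of this shape arithmetic happens in plain Python, so the same approach should also work inside `jax.jit`.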
- caskade.utils.broadcast_cat_numpy(arrays, dim=-1)
Concatenates NumPy arrays with broadcasting.
Behaves like np.concatenate, but first broadcasts the arrays to match on all dimensions EXCEPT the concatenation dimension.
- Parameters:
arrays (sequence of np.ndarray) – Arrays to concatenate.
dim (int) – The dimension along which to concatenate.
- Returns:
The concatenated array.
- Return type:
np.ndarray
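Plain `np.concatenate` rejects inputs whose ranks differ or whose sizes disagree on any axis other than the concatenation axis. The minimal NumPy sketch below shows the broadcast-then-concatenate idea. `broadcast_cat_numpy_sketch` is a hypothetical name, caskade's actual implementation may differ, and the sketch assumes a negative `dim`.

```python
import numpy as np


def broadcast_cat_numpy_sketch(arrays, dim=-1):
    """Hypothetical sketch, not caskade's source: broadcast all axes but `dim`, then concatenate."""
    if dim >= 0:
        raise ValueError("this sketch expects a negative dim, e.g. -1")
    arrays = [np.asarray(a) for a in arrays]

    # Shape of each array with the concatenation axis collapsed to 1.
    masked = []
    for a in arrays:
        shape = list(a.shape)
        shape[dim] = 1
        masked.append(tuple(shape))
    common = list(np.broadcast_shapes(*masked))

    # Broadcast every array to the common shape, keeping its own length along dim.
    parts = []
    for a in arrays:
        target = list(common)
        target[dim] = a.shape[dim]
        parts.append(np.broadcast_to(a, tuple(target)))
    return np.concatenate(parts, axis=dim)


params = np.zeros((4, 2))          # e.g. a batch of 4 rows with 2 values each
extra = np.array([1.0, 2.0, 3.0])  # 3 values shared by every row

try:
    np.concatenate([params, extra], axis=-1)
except ValueError:
    print("np.concatenate: ranks differ")  # plain concatenate refuses these inputs

out = broadcast_cat_numpy_sketch([params, extra])
print(out.shape)  # (4, 5)
```

`np.broadcast_to` returns a read-only view, so no data is copied until `np.concatenate` builds the result.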
- caskade.utils.broadcast_cat_torch(tensors, dim=-1)
Concatenates tensors with broadcasting.
It behaves like torch.cat, but first broadcasts the tensors to match on all dimensions EXCEPT the concatenation dimension.
- Parameters:
tensors (sequence of Tensors) – Tensors to concatenate.
dim (int) – The dimension along which to concatenate. Must be a negative index (e.g., -1 for the last dimension). Broadcasting aligns dimensions from the right, so only a negative index refers to the same axis in tensors of different ranks.
- Returns:
The concatenated tensor.
- Return type:
Tensor
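The requirement of a negative `dim` follows from broadcasting aligning shapes from the right. In the sketch below, the last axis of a rank-3 tensor and the only axis of a rank-1 tensor are both `dim=-1`. A positive index such as `2` would not even exist for the rank-1 tensor. This is a minimal illustration, not caskade's source; the name `broadcast_cat_torch_sketch` is hypothetical.

```python
import torch


def broadcast_cat_torch_sketch(tensors, dim=-1):
    # Hypothetical sketch of broadcast-then-concatenate; not caskade's implementation.
    if dim >= 0:
        raise ValueError("dim must be negative, e.g. -1")

    # Collapse the concatenation axis to 1 so it does not take part in broadcasting.
    masked = []
    for t in tensors:
        shape = list(t.shape)
        shape[dim] = 1
        masked.append(tuple(shape))
    common = list(torch.broadcast_shapes(*masked))

    # Expand each tensor (a view, no copy) to the common shape, keeping its own length along dim.
    expanded = []
    for t in tensors:
        target = list(common)
        target[dim] = t.shape[dim]
        expanded.append(t.expand(*target))
    return torch.cat(expanded, dim=dim)


x = torch.zeros(5, 2, 3)  # rank 3
y = torch.ones(4)         # rank 1: its only axis lines up with x's last axis
z = broadcast_cat_torch_sketch([x, y], dim=-1)
print(z.shape)  # torch.Size([5, 2, 7])
```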