Source code for caskade.utils
import torch
import numpy as np
[docs]
def broadcast_cat_torch(tensors, dim=-1):
"""
Concatenates tensors with broadcasting.
It behaves like torch.cat, but first broadcasts the tensors to match
on all dimensions EXCEPT the concatenation dimension.
Args:
tensors (sequence of Tensors): Tensors to concatenate.
dim (int): The dimension along which to concatenate.
Must be a negative index to ensure consistency across
tensors of different ranks (e.g., -1 for the last dimension).
Returns:
Tensor: The concatenated tensor.
"""
if not tensors:
raise ValueError("tensors argument must be a non-empty sequence")
# 1. Align Ranks
# We find the maximum rank (ndim) and pad all smaller tensors with
# leading singleton dimensions (1s) so everyone aligns to the right.
max_dim = max(t.ndim for t in tensors)
# We iterate and unsqueeze; we store them in a new list to avoid mutating inputs
aligned_tensors = []
for t in tensors:
# Prepend 1s until rank matches max_dim
pads = (1,) * (max_dim - t.ndim)
if pads:
# view as (1, 1, ..., *original_shape)
t = t.view(*pads, *t.shape)
aligned_tensors.append(t)
# 2. Normalize dim to positive index based on the new max_dim
# This allows us to index the shape tuple consistently.
if dim < 0:
dim = max_dim + dim
if dim < 0 or dim >= max_dim:
raise ValueError(
f"Dimension {dim - max_dim} is out of bounds for tensors with rank {max_dim}"
)
# 3. Determine Broadcast Shape
# We need a target shape that fits everyone.
# For the concatenation dimension, we don't care about matching sizes.
# For all other dimensions, we take the max size (standard broadcasting).
target_shape = list(aligned_tensors[0].shape)
for t in aligned_tensors[1:]:
for i in range(max_dim):
if i == dim:
continue # Skip checks for the concatenation axis
current_size = t.shape[i]
target_size = target_shape[i]
if current_size != target_size:
if current_size == 1:
continue # This tensor will expand
elif target_size == 1:
target_shape[i] = current_size # Previous tensors will expand
else:
raise RuntimeError(
f"The size of tensor {t.shape} at dimension {i} ({current_size}) "
f"does not match the target size ({target_size}) and neither is 1."
)
# 4. Expand and Concatenate
# We expand every tensor to the target_shape, BUT we must leave the
# concatenation dimension as-is for that specific tensor.
expanded_tensors = []
for t in aligned_tensors:
# Create the specific target shape for THIS tensor
# (Global target shape, but with this tensor's specific size at 'dim')
local_target_shape = list(target_shape)
local_target_shape[dim] = t.shape[dim]
# Expand efficienty (returns a view, no data copy)
expanded_tensors.append(t.expand(*local_target_shape))
return torch.cat(expanded_tensors, dim=dim)
[docs]
def broadcast_cat_jax(arrays, dim=-1):
"""
Concatenates JAX arrays with broadcasting.
Behaves like jnp.concatenate, but first broadcasts the arrays to match
on all dimensions EXCEPT the concatenation dimension.
Args:
arrays (sequence of jnp.ndarray): Arrays to concatenate.
dim (int): The dimension along which to concatenate.
Returns:
jnp.ndarray: The concatenated array.
"""
import jax.numpy as jnp
if not arrays:
raise ValueError("arrays argument must be a non-empty sequence")
# 1. Align Ranks
# Find the maximum rank and left-pad smaller arrays with shape 1
max_ndim = max(a.ndim for a in arrays)
aligned_arrays = []
for a in arrays:
diff = max_ndim - a.ndim
if diff > 0:
# Reshape to (1, 1, ..., *original_shape)
new_shape = (1,) * diff + a.shape
aligned_arrays.append(a.reshape(new_shape))
else:
aligned_arrays.append(a)
# 2. Normalize dim
# Convert negative index to positive based on max_ndim
if dim < 0:
dim += max_ndim
if dim < 0 or dim >= max_ndim:
raise ValueError(f"Dimension {dim - max_ndim} is out of bounds")
# 3. Determine Broadcast Shapes
# We split the shapes into two parts: dimensions *before* the concat axis
# and dimensions *after* the concat axis.
shapes_pre = [a.shape[:dim] for a in aligned_arrays]
shapes_post = [a.shape[dim + 1 :] for a in aligned_arrays]
try:
# Calculate the common broadcast shape for the surrounding dimensions
common_pre = jnp.broadcast_shapes(*shapes_pre)
common_post = jnp.broadcast_shapes(*shapes_post)
except ValueError as e:
raise ValueError("Shapes cannot be broadcast (excluding concatenation axis)") from e
# 4. Expand and Concatenate
expanded_arrays = []
for a in aligned_arrays:
# Construct the target shape for THIS array:
# [Broadcasting Pre] + [Original Dim Size] + [Broadcasting Post]
target_shape = common_pre + (a.shape[dim],) + common_post
# broadcast_to returns a view (no copy) where possible
expanded_arrays.append(jnp.broadcast_to(a, target_shape))
return jnp.concatenate(expanded_arrays, axis=dim)
[docs]
def broadcast_cat_numpy(arrays, dim=-1):
"""
Concatenates NumPy arrays with broadcasting.
Behaves like np.concatenate, but first broadcasts the arrays to match
on all dimensions EXCEPT the concatenation dimension.
Args:
arrays (sequence of np.ndarray): Arrays to concatenate.
dim (int): The dimension along which to concatenate.
Returns:
np.ndarray: The concatenated array.
"""
if not arrays:
raise ValueError("arrays argument must be a non-empty sequence")
# Ensure inputs are actually numpy arrays (handles lists/tuples of numbers)
arrays = [np.asarray(a) for a in arrays]
# 1. Align Ranks
# Find the maximum rank and left-pad smaller arrays with shape 1
max_ndim = max(a.ndim for a in arrays)
aligned_arrays = []
for a in arrays:
diff = max_ndim - a.ndim
if diff > 0:
# Reshape to (1, 1, ..., *original_shape)
new_shape = (1,) * diff + a.shape
aligned_arrays.append(a.reshape(new_shape))
else:
aligned_arrays.append(a)
# 2. Normalize dim
# Convert negative index to positive based on max_ndim
if dim < 0:
dim += max_ndim
if dim < 0 or dim >= max_ndim:
raise ValueError(f"Dimension {dim - max_ndim} is out of bounds")
# 3. Determine Broadcast Shapes
# We split the shapes into two parts: dimensions *before* the concat axis
# and dimensions *after* the concat axis.
shapes_pre = [a.shape[:dim] for a in aligned_arrays]
shapes_post = [a.shape[dim + 1 :] for a in aligned_arrays]
try:
# Calculate the common broadcast shape for the surrounding dimensions
# Note: np.broadcast_shapes was added in NumPy 1.20
common_pre = np.broadcast_shapes(*shapes_pre)
common_post = np.broadcast_shapes(*shapes_post)
except ValueError as e:
raise ValueError("Shapes cannot be broadcast (excluding concatenation axis)") from e
# 4. Expand and Concatenate
expanded_arrays = []
for a in aligned_arrays:
# Construct the target shape for THIS array:
# [Broadcasting Pre] + [Original Dim Size] + [Broadcasting Post]
target_shape = common_pre + (a.shape[dim],) + common_post
# broadcast_to returns a read-only view (efficient memory usage)
expanded_arrays.append(np.broadcast_to(a, target_shape))
return np.concatenate(expanded_arrays, axis=dim)