from typing import Optional, Mapping, Sequence, Union
from math import prod
import numpy as np
from .param import Param
from .errors import (
FillParamsArrayError,
FillParamsMappingError,
FillParamsSequenceError,
ActiveStateError,
ParamConfigurationError,
)
from .backend import backend, ArrayLike
from .base import Node, Memo
[docs]
class GetSetValues:
@property
def valid_context(self) -> bool:
"""Return True if the module is in a valid context."""
try:
return self._valid_context
except AttributeError:
return False
@valid_context.setter
def valid_context(self, value: bool):
"""Set the valid context of the module."""
self._valid_context = value
for node in self.topological_ordering():
if isinstance(node, GetSetValues):
node._valid_context = value
# Set Values
#################################################################
def _set_values_dict(self, node, params, params_list, attribute="_value"):
for key in params:
if key in node.children and isinstance(params[key], dict):
self._set_values_dict(node[key], params[key], params_list, attribute=attribute)
elif key in node.children and isinstance(node[key], Param):
setattr(node[key], attribute, params[key])
elif key in node.children:
sublist = tuple(p for p in params_list if p in node[key].all_params)
node[key]._set_values(params[key], sublist, attribute=attribute)
else:
raise FillParamsMappingError(self.name, self.children, missing_key=key)
def _set_values(
self,
params: Union[ArrayLike, Sequence, Mapping],
param_list: tuple[Param],
attribute="_value",
):
"""
Fill the dynamic parameters of the module with the input values from
params.
Parameters
----------
params: (Union[ArrayLike, Sequence, Mapping])
The input values to fill the dynamic parameters with. The input can
be an ArrayLike, a Sequence, or a Mapping. If the input is
array-like, the values are filled in order of the dynamic
parameters. `params` should be a flattened array-like object with
all parameters concatenated in the order of the dynamic parameters.
If `len(params.shape)>1` then all dimensions but the last one are
considered batch dimensions. If the input is a Sequence, the values
are filled in order of the dynamic parameters. If the input is a
Mapping, the values are filled by matching the keys of the Mapping
to the names of the dynamic parameters. Note that the system does
not check for missing keys in the dictionary, but you will get an
error eventually if a value is missing.
"""
if isinstance(params, backend.array_type):
if params.shape[-1] == 0:
return # No parameters to fill
# check for batch dimension
batch = len(params.shape) > 1
B = tuple(params.shape[:-1]) if batch else ()
pos = 0
for param in param_list:
if param.online:
shape = param.shape
else:
depth = max(memo.count("|") for memo in param.memos)
shape = param.batch_shape[-depth:] + param.shape
# Handle scalar parameters
size = max(1, prod(shape))
try:
val = backend.view(params[..., pos : pos + size], B + shape)
setattr(param, attribute, val)
except (RuntimeError, IndexError, ValueError, TypeError):
raise FillParamsArrayError(self.name, params, param_list)
pos += size
if pos != params.shape[-1]:
raise FillParamsArrayError(self.name, params, param_list)
elif isinstance(params, Sequence):
if len(params) == 0:
return
elif len(params) == len(param_list):
for param, value in zip(param_list, params):
setattr(param, attribute, value)
else:
raise FillParamsSequenceError(self.name, params, param_list)
elif isinstance(params, Mapping):
self._set_values_dict(self, params, param_list, attribute=attribute)
else:
raise TypeError(
f"Input params type {type(params)} not supported. Should be {backend.array_type.__name__}, Sequence, or Mapping."
)
[docs]
def set_values(
self, params: Union[ArrayLike, Sequence, Mapping], dynamic=True, attribute="value"
):
"""Fill the dynamic values of the module with the input values from params."""
if self.active:
raise ActiveStateError(f"Cannot fill dynamic values when Module {self.name} is active")
param_list = self.dynamic_params if dynamic else self.static_params
with Memo(self, self.name + ":semi_set_active"):
if len(self.dynamic_param_groups) > 1:
for group, params_g in zip(self.dynamic_param_groups, params):
param_list_g = tuple(p for p in param_list if p.group == group)
if self.valid_context:
params_g = self.from_valid(params_g, param_list_g, group=group)
self._set_values(params_g, param_list_g, attribute=attribute)
else:
if self.valid_context:
params = self.from_valid(params, param_list)
self._set_values(params, param_list, attribute=attribute)
# Get Values
#################################################################
def _check_values(self, param_list, scheme):
"""Check if all dynamic values are set."""
bad_params = []
for param in param_list:
if param.value is None:
bad_params.append(param.name)
if len(bad_params) > 0:
raise ParamConfigurationError(
f"{self.name} Param(s) {bad_params} have no value, so the params {scheme} cannot be built. Set their value to use this feature."
)
[docs]
def get_values(
self, scheme="array", dynamic=True, attribute="value", group: Optional[int] = None
) -> Union[ArrayLike, list[ArrayLike], dict[str, Union[dict, ArrayLike]]]:
if len(self.dynamic_param_groups) > 1 and group is None:
values = []
for g in self.dynamic_param_groups:
values.append(
self.get_values(scheme=scheme, dynamic=dynamic, attribute=attribute, group=g)
)
return values
param_list = self.dynamic_params if dynamic else self.static_params
param_list = tuple(p for p in param_list if (group is None or p.group == group))
self._check_values(param_list, scheme)
x = []
if scheme.lower() in ["array", "tensor"]:
with Memo(self, self.name + ":semi_get_active"):
for param in param_list:
if param.online:
B = param.batch_shape
else:
depth = max(memo.count("|") for memo in param.memos)
B = param.batch_shape[:-depth]
x.append(getattr(param, attribute).reshape(B + (-1,)))
if len(x) == 0:
return backend.make_array([])
x = backend.detach(backend.broadcast_cat(x, dim=-1))
elif scheme.lower() == "list":
for param in param_list:
x.append(getattr(param, attribute))
elif scheme.lower() == "dict":
unique_params = set()
x = self._recursive_build_params_dict(
self, unique_params=unique_params, param_list=param_list, attribute=attribute
)
if self.valid_context:
x = self.to_valid(x, group=group)
return x
def _recursive_build_params_dict(
self, node: Node, unique_params: set, param_list, attribute="value"
):
params = {}
for link, child in node.children.items():
if isinstance(child, Param) and child in param_list and child not in unique_params:
unique_params.add(child)
params[link] = getattr(child, attribute)
else:
params[link] = self._recursive_build_params_dict(
child, unique_params=unique_params, param_list=param_list, attribute=attribute
)
if len(params[link]) == 0:
del params[link]
return params
def _array_inspection(self, group: Optional[int] = None):
param_list = self.dynamic_params
param_list = tuple(p for p in param_list if (group is None or p.group == group))
self._check_values(param_list, "array")
x = []
with Memo(self, self.name + ":semi_findidx_active"):
for param in param_list:
if param.online:
shape = param.shape
else:
depth = max(memo.count("|") for memo in param.memos)
shape = param.batch_shape[-depth:] + param.shape
if shape == ():
x.append((param, ()))
else:
for i in range(prod(shape)):
x.append((param, tuple(itm.item() for itm in np.unravel_index(i, shape))))
return x
# Finders
#################################################################
[docs]
def find_param(
self, idx: Union[int, tuple[int]], group: Optional[int] = None, scheme: str = "array"
) -> tuple[Param, tuple[int]]:
"""
Identify which param is associated with the provided index in the
dynamic params array.
Parameters
----------
idx: Union[int, tuple[int]]
The index in the params array at which we wish to find the
associated param.
group: Optional[int]
If the dynamic params have multiple group values, then this argument
specifies which group to check.
scheme: str
Whether to search the array (default) params or list version of
params. dict is currently unsupported.
Returns
-------
param_info: tuple[Param, tuple[int]]
A tuple with the Param object and the index within the Param value
associated with idx (empty tuple if scalar). If idx is a tuple then
the result is a tuple of these results.
"""
if not isinstance(idx, int):
return tuple(self.find_param(i, group, scheme) for i in idx)
if scheme == "array":
x = self._array_inspection(group)
return x[idx]
elif scheme == "list":
param_list = tuple(p for p in self.dynamic_params if group is None or p.group == group)
return param_list[idx]
elif scheme == "dict":
raise NotImplementedError(
"find_param is not implemented for the dict scheme. The dict has the same structure as the graph and so may be inspected in a variety of other ways."
)
else:
raise ValueError(f"unrecognized scheme: {scheme}")
[docs]
def find_index(
self, param: Union[Param, tuple[Param], "Module"], scheme: str = "array"
) -> Union[int, slice]:
"""
Identify what index is associated with a param in the dynamic params
array.
Parameters
----------
param: Union[Param, tuple[Param], Module]
The param for which to find the associated index.
scheme: str
Whether to search the array (default) params or list version of
params. dict is currently unsupported.
Returns
-------
param_info: Union[int, slice]
A int giving the index associated with the provided Param object. If
the param is multi-dimensional then the result will be a slice over
all indices associated with that param.
"""
# 1. Handle recursive structures
if isinstance(param, (list, tuple)):
return tuple(self.find_index(p, scheme) for p in param)
if isinstance(param, GetSetValues):
return tuple(
self.find_index(c, scheme)
for c in param.children.values()
if isinstance(c, Param) and c.dynamic
)
groups = self.dynamic_param_groups if len(self.dynamic_param_groups) > 1 else [None]
for group in groups:
if scheme in ["array", "tensor"]:
inspection = self._array_inspection(group)
matches = [i for i, item in enumerate(inspection) if item[0] is param]
if not matches:
continue
idx = matches[0] if len(matches) == 1 else slice(min(matches), max(matches) + 1)
elif scheme == "list":
param_list = [p for p in self.dynamic_params if group is None or p.group == group]
if param not in param_list:
continue
idx = param_list.index(param)
elif scheme == "dict":
raise NotImplementedError("find_index is not implemented for the dict scheme.")
else:
raise ValueError(f"unrecognized scheme: {scheme}")
# Return with group prefix if we are in multi-group mode
return (group, idx) if len(self.dynamic_param_groups) > 1 else idx
raise ValueError(f"Param {param.name} could not be found in dynamic params.")
# To/From Valid
#################################################################
def _transform_params(self, node, init_params, param_list, transform_attr):
if isinstance(init_params, backend.array_type):
trans_params = []
batch = len(init_params.shape) > 1
B = tuple(init_params.shape[:-1]) if batch else ()
pos = 0
with Memo(self, self.name + ":semi_trans_active"):
for param in param_list:
if param.online:
shape = param.shape
else:
depth = max(memo.count("|") for memo in param.memos)
shape = param.batch_shape[-depth:] + param.shape
size = max(1, prod(shape)) # Handle scalar parameters
return_shape = (*B, size)
val = getattr(param, transform_attr)(
backend.view(init_params[..., pos : pos + size], B + shape)
)
trans_params.append(backend.view(val, return_shape))
pos += size
trans_params = backend.concatenate(trans_params, axis=-1)
elif isinstance(init_params, Sequence):
trans_params = []
if len(init_params) == len(param_list):
for param, value in zip(param_list, init_params):
trans_params.append(getattr(param, transform_attr)(value))
else:
raise FillParamsSequenceError(self.name, init_params, param_list)
elif isinstance(init_params, Mapping):
trans_params = {}
for key in init_params:
if key in node.children and isinstance(node[key], Param):
trans_params[key] = getattr(node[key], transform_attr)(init_params[key])
elif key in node.children:
sublist = tuple(p for p in param_list if p in node[key].children.values())
trans_params[key] = self._transform_params(
node[key], init_params[key], sublist, transform_attr
)
else:
raise FillParamsMappingError(self.name, self.children, missing_key=key)
else:
raise TypeError(
f"Input params type {type(init_params)} not supported. Should be {backend.array_type.__name__}, Sequence, or Mapping."
)
return trans_params
[docs]
def to_valid(
self, params: Union[ArrayLike, Sequence, Mapping], param_list=None, group=None
) -> Union[ArrayLike, Sequence, Mapping]:
"""Convert input params to valid params."""
if param_list is None:
param_list = self.dynamic_params
if len(self.dynamic_param_groups) > 1:
if group is None:
valid_params = []
for g, params_g in zip(self.dynamic_param_groups, params):
param_list_g = tuple(p for p in param_list if p.group == g)
valid_params.append(
self._transform_params(self, params_g, param_list_g, "to_valid")
)
return valid_params
else:
param_list = tuple(p for p in param_list if p.group == group)
return self._transform_params(self, params, param_list, "to_valid")
[docs]
def from_valid(
self, valid_params: Union[ArrayLike, Sequence, Mapping], param_list=None, group=None
) -> Union[ArrayLike, Sequence, Mapping]:
"""Convert valid params to input params."""
if param_list is None:
param_list = self.dynamic_params
if len(self.dynamic_param_groups) > 1:
if group is None:
params = []
for g, valid_params_g in zip(self.dynamic_param_groups, valid_params):
param_list_g = tuple(p for p in param_list if p.group == g)
params.append(
self._transform_params(self, valid_params_g, param_list_g, "from_valid")
)
return params
else:
param_list = tuple(p for p in param_list if p.group == group)
return self._transform_params(self, valid_params, param_list, "from_valid")