Source code for caskade.tests
from caskade import Module, Param, forward, backend
# Public API of this module: only the `test` entry point is exported.
__all__ = ("test",)
def _test_full_integration():
    """Build a small nested caskade simulator and check one forward evaluation.

    Exercises, in one pass:
      * a Module holding a plain attribute (``a``), Params (``b``, ``c``) and a
        child Module (``m1``);
      * a Param whose value is a function of another Param reached through a
        named link (``e`` -> ``"flink"`` -> ``f``);
      * Params re-pointed at other Params (``c`` -> ``b``, ``f`` -> ``c``);
      * supplying the value of ``b`` at call time via ``params=``.

    Raises:
        AssertionError: if the computed result is not the expected 13.0. The
            check is an explicit ``raise`` rather than an ``assert`` statement,
            so it still runs when Python is started with ``-O``.
    """

    class TestSim(Module):
        """Top-level module: plain attribute ``a``, Params ``b``/``c``, child ``m1``."""

        def __init__(self, a, b, c, c_shape, m1):
            super().__init__("test_sim")
            self.a = a
            self.b = Param("b", b)
            self.c = Param("c", c, c_shape)
            self.m1 = m1

        @forward
        def testfun(self, x, b=None):
            # NOTE(review): `b` is presumably filled in by @forward from the Param of
            # the same name; the child module is invoked with no explicit arguments.
            y = self.m1()
            return x + self.a + b + y

    class TestSubSim(Module):
        """Child module with three Params whose forward pass returns d + e + f."""

        def __init__(self, d, e, f):
            super().__init__("test_sub_sim")
            self.d = Param("d", d)
            self.e = Param("e", e)
            self.f = Param("f", f)

        @forward
        def __call__(self, d=None, e=None, f=None):
            return d + e + f

    # `e` is given as a function that reads the value of the child registered under
    # the key "flink"; the `link` call below registers `f` under that key.
    # (Assumption: caskade calls the lambda with the Param `e` itself as `s`.)
    sub1 = TestSubSim(d=1.0, e=lambda s: s.children["flink"].value, f=None)
    sub1.e.link("flink", sub1.f)
    main1 = TestSim(a=2.0, b=None, c=None, c_shape=(), m1=sub1)
    # Chain the Params so that c follows b and f follows c; as a result f (and, via
    # the link above, e) should resolve to the value supplied for b.
    main1.c = main1.b
    sub1.f = main1.c

    b_value = backend.make_array(3.0)
    res = main1.testfun(1.0, params=[b_value])

    # Expected, with b = 3 and the links above resolving as intended:
    #   y = d + e + f = 1 + 3 + 3 = 7
    #   x + a + b + y = 1 + 2 + 3 + 7 = 13
    # Every term is a small integer-valued float, so exact comparison is safe.
    expected = 13.0
    result = res.item()
    if result != expected:
        raise AssertionError(
            f"caskade integration test failed: expected {expected}, got {result}"
        )
[docs]
def test():
    """Run caskade's built-in integration self-check.

    Use this to confirm that the library is functioning correctly. Any exception
    raised by the underlying check propagates to the caller; if the check
    completes, the message ``Success!`` is printed to standard output.

    Returns:
        None
    """
    # Reaching the print below means the integration check finished without raising.
    _test_full_integration()
    print("Success!")