From 0e2f038fe6b7196833c74ccb43fa453bc9afa915 Mon Sep 17 00:00:00 2001 From: Milan Rother Date: Thu, 19 Feb 2026 15:42:22 +0100 Subject: [PATCH 01/29] Zero target inputs on remove_connection to avoid stale values --- src/pathsim/simulation.py | 5 +++++ src/pathsim/subsystem.py | 5 +++++ tests/pathsim/test_simulation.py | 15 +++++++++++++++ 3 files changed, 25 insertions(+) diff --git a/src/pathsim/simulation.py b/src/pathsim/simulation.py index f2814cc8..7f9307db 100644 --- a/src/pathsim/simulation.py +++ b/src/pathsim/simulation.py @@ -446,6 +446,11 @@ def remove_connection(self, connection): self.logger.error(_msg) raise ValueError(_msg) + #zero out target input ports to avoid stale values + for target in connection.targets: + for port in target.ports: + target.block.inputs[port] = 0.0 + #remove from global connection list self.connections.discard(connection) diff --git a/src/pathsim/subsystem.py b/src/pathsim/subsystem.py index 3f0cabc4..a46415ff 100644 --- a/src/pathsim/subsystem.py +++ b/src/pathsim/subsystem.py @@ -337,6 +337,11 @@ def remove_connection(self, connection): if connection not in self.connections: raise ValueError(f"{connection} not part of subsystem") + #zero out target input ports to avoid stale values + for target in connection.targets: + for port in target.ports: + target.block.inputs[port] = 0.0 + self.connections.discard(connection) if self.graph: diff --git a/tests/pathsim/test_simulation.py b/tests/pathsim/test_simulation.py index c1ae6b8a..e17e3fb9 100644 --- a/tests/pathsim/test_simulation.py +++ b/tests/pathsim/test_simulation.py @@ -884,6 +884,21 @@ def test_remove_connection_error(self): with self.assertRaises(ValueError): self.Sim.remove_connection(C) + def test_remove_connection_zeroes_inputs(self): + """Removing a connection zeroes the target block's input ports""" + # Run a step so the connection pushes data + self.Src.function = lambda t: 5.0 + self.Sim.step(0.01) + + # Int should have received input from Src + 
self.assertNotEqual(self.Int.inputs[0], 0.0) + + # Remove the connection + self.Sim.remove_connection(self.C1) + + # Target input should now be zero + self.assertEqual(self.Int.inputs[0], 0.0) + def test_remove_event(self): """Adding and removing events works""" evt = Event(func_evt=lambda t: t - 1.0, func_act=lambda t: None) From 655ed3a36066adc5ca1ba033cbc581404e7aaad3 Mon Sep 17 00:00:00 2001 From: Milan Rother Date: Thu, 19 Feb 2026 15:53:09 +0100 Subject: [PATCH 02/29] Reset block inputs in _assemble_graph instead of remove_connection --- src/pathsim/simulation.py | 11 +++++------ src/pathsim/subsystem.py | 10 +++++----- tests/pathsim/test_simulation.py | 8 +++++--- 3 files changed, 15 insertions(+), 14 deletions(-) diff --git a/src/pathsim/simulation.py b/src/pathsim/simulation.py index 7f9307db..306c7e4c 100644 --- a/src/pathsim/simulation.py +++ b/src/pathsim/simulation.py @@ -446,11 +446,6 @@ def remove_connection(self, connection): self.logger.error(_msg) raise ValueError(_msg) - #zero out target input ports to avoid stale values - for target in connection.targets: - for port in target.ports: - target.block.inputs[port] = 0.0 - #remove from global connection list self.connections.discard(connection) @@ -504,10 +499,14 @@ def remove_event(self, event): # system assembly ------------------------------------------------------------- def _assemble_graph(self): - """Build the internal graph representation for fast system function + """Build the internal graph representation for fast system function evaluation and algebraic loop resolution. 
""" + #reset all block inputs to clear stale values from removed connections + for block in self.blocks: + block.inputs.reset() + #time the graph construction with Timer(verbose=False) as T: self.graph = Graph(self.blocks, self.connections) diff --git a/src/pathsim/subsystem.py b/src/pathsim/subsystem.py index a46415ff..2dd3640c 100644 --- a/src/pathsim/subsystem.py +++ b/src/pathsim/subsystem.py @@ -337,11 +337,6 @@ def remove_connection(self, connection): if connection not in self.connections: raise ValueError(f"{connection} not part of subsystem") - #zero out target input ports to avoid stale values - for target in connection.targets: - for port in target.ports: - target.block.inputs[port] = 0.0 - self.connections.discard(connection) if self.graph: @@ -386,6 +381,11 @@ def _assemble_graph(self): """Assemble internal graph of subsystem for fast algebraic evaluation during simulation. """ + + #reset all block inputs to clear stale values from removed connections + for block in self.blocks: + block.inputs.reset() + self.graph = Graph({*self.blocks, self.interface}, self.connections) self._graph_dirty = False diff --git a/tests/pathsim/test_simulation.py b/tests/pathsim/test_simulation.py index e17e3fb9..beda6aae 100644 --- a/tests/pathsim/test_simulation.py +++ b/tests/pathsim/test_simulation.py @@ -885,7 +885,7 @@ def test_remove_connection_error(self): self.Sim.remove_connection(C) def test_remove_connection_zeroes_inputs(self): - """Removing a connection zeroes the target block's input ports""" + """Removing a connection zeroes target inputs on next graph rebuild""" # Run a step so the connection pushes data self.Src.function = lambda t: 5.0 self.Sim.step(0.01) @@ -893,10 +893,12 @@ def test_remove_connection_zeroes_inputs(self): # Int should have received input from Src self.assertNotEqual(self.Int.inputs[0], 0.0) - # Remove the connection + # Remove the connection — graph is dirty but not rebuilt yet self.Sim.remove_connection(self.C1) + 
self.assertTrue(self.Sim._graph_dirty) - # Target input should now be zero + # Next step triggers graph rebuild which resets inputs + self.Sim.step(0.01) self.assertEqual(self.Int.inputs[0], 0.0) def test_remove_event(self): From 80a3810651f2ae927a37d9df8dfb2b44e0c392b5 Mon Sep 17 00:00:00 2001 From: kwmcbride Date: Tue, 24 Feb 2026 06:55:41 -0800 Subject: [PATCH 03/29] adding divider block --- src/pathsim/blocks/__init__.py | 1 + src/pathsim/blocks/divider.py | 215 +++++++++++++++++++ tests/pathsim/blocks/test_divider.py | 304 +++++++++++++++++++++++++++ 3 files changed, 520 insertions(+) create mode 100644 src/pathsim/blocks/divider.py create mode 100644 tests/pathsim/blocks/test_divider.py diff --git a/src/pathsim/blocks/__init__.py b/src/pathsim/blocks/__init__.py index c8580e25..a196856e 100644 --- a/src/pathsim/blocks/__init__.py +++ b/src/pathsim/blocks/__init__.py @@ -1,6 +1,7 @@ from .differentiator import * from .integrator import * from .multiplier import * +from .divider import * from .converters import * from .comparator import * from .samplehold import * diff --git a/src/pathsim/blocks/divider.py b/src/pathsim/blocks/divider.py new file mode 100644 index 00000000..ae6329a0 --- /dev/null +++ b/src/pathsim/blocks/divider.py @@ -0,0 +1,215 @@ +######################################################################################### +## +## REDUCTION BLOCKS (blocks/divider.py) +## +## This module defines static 'Divider' block +## +######################################################################################### + +# IMPORTS =============================================================================== + +import numpy as np + +from math import prod + +from ._block import Block +from ..utils.register import Register +from ..optim.operator import Operator + + +# MISO BLOCKS =========================================================================== + +_ZERO_DIV_OPTIONS = ("warn", "raise", "clamp") + + +class Divider(Block): + """Multiplies and 
divides input signals (MISO). + + This is the default behavior (multiply all): + + .. math:: + + y(t) = \\prod_i u_i(t) + + and this is the behavior with an operations string: + + .. math:: + + y(t) = \\frac{\\prod_{i \\in M} u_i(t)}{\\prod_{j \\in D} u_j(t)} + + where :math:`M` is the set of inputs with ``*`` and :math:`D` the set with ``/``. + + + Example + ------- + Default initialization multiplies all inputs (same as :class:`Multiplier`): + + .. code-block:: python + + D = Divider() + + Multiply the first two inputs and divide by the third: + + .. code-block:: python + + D = Divider('**/') + + Raise an error instead of producing ``inf`` when a denominator input is zero: + + .. code-block:: python + + D = Divider('**/', zero_div='raise') + + Clamp the denominator to machine epsilon so the output stays finite: + + .. code-block:: python + + D = Divider('**/', zero_div='clamp') + + + Note + ---- + This block is purely algebraic and its operation (``op_alg``) will be called + multiple times per timestep, each time when ``Simulation._update(t)`` is + called in the global simulation loop. + + + Parameters + ---------- + operations : str, optional + String of ``*`` and ``/`` characters indicating which inputs are + multiplied (``*``) or divided (``/``). Inputs beyond the length of + the string default to ``*``. ``None`` multiplies all inputs. + zero_div : str, optional + Behaviour when a denominator input is zero. One of: + + ``'warn'`` *(default)* + Propagates ``inf`` and emits a ``RuntimeWarning`` — numpy's + standard behaviour. + ``'raise'`` + Raises ``ZeroDivisionError``. + ``'clamp'`` + Clamps the denominator magnitude to machine epsilon + (``numpy.finfo(float).eps``), preserving sign, so the output + stays large-but-finite rather than ``inf``. + + + Attributes + ---------- + _ops : dict + Maps operation characters to exponent values (``+1`` or ``-1``). + _ops_array : numpy.ndarray + Exponents (+1 for ``*``, -1 for ``/``) converted to an array. 
+ op_alg : Operator + Internal algebraic operator. + """ + + input_port_labels = None + output_port_labels = {"out": 0} + + def __init__(self, operations=None, zero_div="warn"): + super().__init__() + + # validate zero_div + if zero_div not in _ZERO_DIV_OPTIONS: + raise ValueError( + f"'zero_div' must be one of {_ZERO_DIV_OPTIONS}, got '{zero_div}'" + ) + self.zero_div = zero_div + + # allowed arithmetic operations mapped to exponents + self._ops = {"*": 1, "/": -1} + self.operations = operations + + if self.operations is None: + + # Default: multiply all inputs — identical to Multiplier + self.op_alg = Operator( + func=prod, + jac=lambda x: np.array([[ + prod(np.delete(x, i)) for i in range(len(x)) + ]]) + ) + + else: + + # input validation + if not isinstance(self.operations, str): + raise ValueError("'operations' must be a string or None") + for op in self.operations: + if op not in self._ops: + raise ValueError( + f"operation '{op}' not in {set(self._ops)}" + ) + + self._ops_array = np.array( + [self._ops[op] for op in self.operations], dtype=float + ) + + # capture for closures + _ops_array = self._ops_array + _zero_div = zero_div + _eps = np.finfo(float).eps + + def _safe_den(d): + """Apply zero_div policy to a denominator value.""" + if d == 0: + if _zero_div == "raise": + raise ZeroDivisionError( + "Divider: denominator is zero. " + "Use zero_div='warn' or 'clamp' to suppress." 
+ ) + elif _zero_div == "clamp": + return _eps + return d + + def prod_ops(X): + n = len(X) + no = len(_ops_array) + ops = np.ones(n) + ops[:min(n, no)] = _ops_array[:min(n, no)] + num = prod(X[i] for i in range(n) if ops[i] > 0) + den = _safe_den(prod(X[i] for i in range(n) if ops[i] < 0)) + return num / den + + def jac_ops(X): + n = len(X) + no = len(_ops_array) + ops = np.ones(n) + ops[:min(n, no)] = _ops_array[:min(n, no)] + X = np.asarray(X, dtype=float) + # Apply zero_div policy to all denominator inputs up front so + # both the direct division and the rest-product stay consistent. + X_safe = X.copy() + for i in range(n): + if ops[i] < 0: + X_safe[i] = _safe_den(float(X[i])) + row = [] + for k in range(n): + rest = np.prod( + np.power(np.delete(X_safe, k), np.delete(ops, k)) + ) + if ops[k] > 0: # multiply: dy/du_k = prod of rest + row.append(rest) + else: # divide: dy/du_k = -rest / u_k^2 + row.append(-rest / X_safe[k] ** 2) + return np.array([row]) + + self.op_alg = Operator(func=prod_ops, jac=jac_ops) + + + def __len__(self): + """Purely algebraic block.""" + return 1 + + + def update(self, t): + """Update system equation. + + Parameters + ---------- + t : float + Evaluation time. 
+ """ + u = self.inputs.to_array() + self.outputs.update_from_array(self.op_alg(u)) diff --git a/tests/pathsim/blocks/test_divider.py b/tests/pathsim/blocks/test_divider.py new file mode 100644 index 00000000..f304dbb3 --- /dev/null +++ b/tests/pathsim/blocks/test_divider.py @@ -0,0 +1,304 @@ +######################################################################################## +## +## TESTS FOR +## 'blocks.divider.py' +## +######################################################################################## + +# IMPORTS ============================================================================== + +import unittest +import numpy as np + +from pathsim.blocks.divider import Divider + +from tests.pathsim.blocks._embedding import Embedding + + +# TESTS ================================================================================ + +class TestDivider(unittest.TestCase): + """ + Test the implementation of the 'Divider' block class + """ + + def test_init(self): + + # default initialization + D = Divider() + self.assertIsNone(D.operations) + + # valid ops strings + for ops in ["*", "/", "*/", "/*", "**/", "/**"]: + D = Divider(ops) + self.assertEqual(D.operations, ops) + + # non-string types are rejected + for bad in [0.4, 3, [1, -1], True]: + with self.assertRaises(ValueError): + Divider(bad) + + # strings with invalid characters are rejected + for bad in ["+/", "*-", "a", "**0", "+-"]: + with self.assertRaises(ValueError): + Divider(bad) + + + def test_embedding(self): + """Test algebraic output against reference via Embedding.""" + + # default: multiply all (identical to Multiplier) + D = Divider() + + def src(t): return np.cos(t), np.sin(t) + 2, 3.0, t + 1 + def ref(t): return np.cos(t) * (np.sin(t) + 2) * 3.0 * (t + 1) + + E = Embedding(D, src, ref) + for t in range(10): self.assertEqual(*E.check_MIMO(t)) + + # '**/' : multiply first two, divide by third + D = Divider('**/') + + def src(t): return np.cos(t) + 2, np.sin(t) + 2, t + 1 + def ref(t): return 
(np.cos(t) + 2) * (np.sin(t) + 2) / (t + 1) + + E = Embedding(D, src, ref) + for t in range(10): self.assertEqual(*E.check_MIMO(t)) + + # '*/' : u0 / u1 + D = Divider('*/') + + def src(t): return t + 1, np.cos(t) + 2 + def ref(t): return (t + 1) / (np.cos(t) + 2) + + E = Embedding(D, src, ref) + for t in range(10): self.assertEqual(*E.check_MIMO(t)) + + # ops string shorter than number of inputs: extra inputs default to '*' + # '/' with 3 inputs → y = u1 * u2 / u0 + D = Divider('/') + + def src(t): return t + 1, np.cos(t) + 2, 3.0 + def ref(t): return (np.cos(t) + 2) * 3.0 / (t + 1) + + E = Embedding(D, src, ref) + for t in range(10): self.assertEqual(*E.check_MIMO(t)) + + # single input, default: passes through unchanged + D = Divider() + + def src(t): return np.cos(t) + def ref(t): return np.cos(t) + + E = Embedding(D, src, ref) + for t in range(10): self.assertEqual(*E.check_SISO(t)) + + + def test_linearization(self): + """Test linearize / delinearize round-trip.""" + + # default (multiply all) — nonlinear, so only check at linearization point + D = Divider() + + def src(t): return np.cos(t) + 2, t + 1, 3.0 + def ref(t): return (np.cos(t) + 2) * (t + 1) * 3.0 + + E = Embedding(D, src, ref) + + for t in range(10): self.assertEqual(*E.check_MIMO(t)) + + # linearize at the current operating point (inputs set to src(9) by the loop) + D.linearize(t) + a, b = E.check_MIMO(t) + self.assertAlmostEqual(np.linalg.norm(a - b), 0, 8) + + D.delinearize() + for t in range(10): self.assertEqual(*E.check_MIMO(t)) + + # with ops string + D = Divider('**/') + + def src(t): return np.cos(t) + 2, np.sin(t) + 2, t + 1 + def ref(t): return (np.cos(t) + 2) * (np.sin(t) + 2) / (t + 1) + + E = Embedding(D, src, ref) + + for t in range(10): self.assertEqual(*E.check_MIMO(t)) + + D.linearize(t) + a, b = E.check_MIMO(t) + self.assertAlmostEqual(np.linalg.norm(a - b), 0, 8) + + D.delinearize() + for t in range(10): self.assertEqual(*E.check_MIMO(t)) + + + def test_update_single(self): + + 
D = Divider() + + D.inputs[0] = 5.0 + D.update(None) + + self.assertEqual(D.outputs[0], 5.0) + + + def test_update_multi(self): + + D = Divider() + + D.inputs[0] = 2.0 + D.inputs[1] = 3.0 + D.inputs[2] = 4.0 + D.update(None) + + self.assertEqual(D.outputs[0], 24.0) + + + def test_update_ops(self): + + # '**/' : 2 * 3 / 4 = 1.5 + D = Divider('**/') + D.inputs[0] = 2.0 + D.inputs[1] = 3.0 + D.inputs[2] = 4.0 + D.update(None) + self.assertAlmostEqual(D.outputs[0], 1.5) + + # '*/' : 6 / 2 = 3 + D = Divider('*/') + D.inputs[0] = 6.0 + D.inputs[1] = 2.0 + D.update(None) + self.assertAlmostEqual(D.outputs[0], 3.0) + + # '/' with extra inputs: 2 * 3 / 4 = 1.5 (u0 divides, u1 u2 multiply) + D = Divider('/') + D.inputs[0] = 4.0 + D.inputs[1] = 2.0 + D.inputs[2] = 3.0 + D.update(None) + self.assertAlmostEqual(D.outputs[0], 1.5) + + # '/**' : u1 * u2 / u0 + D = Divider('/**') + D.inputs[0] = 4.0 + D.inputs[1] = 2.0 + D.inputs[2] = 3.0 + D.update(None) + self.assertAlmostEqual(D.outputs[0], 1.5) + + + def test_jacobian(self): + """Verify analytical Jacobian against central finite differences.""" + + eps = 1e-6 + + def numerical_jac(func, u): + n = len(u) + J = np.zeros((1, n)) + for k in range(n): + u_p = u.copy(); u_p[k] += eps + u_m = u.copy(); u_m[k] -= eps + J[0, k] = (func(u_p) - func(u_m)) / (2 * eps) + return J + + # default (all multiply) + D = Divider() + u = np.array([2.0, 3.0, 4.0]) + np.testing.assert_allclose( + D.op_alg.jac(u), + numerical_jac(D.op_alg._func, u), + rtol=1e-5, + ) + + # '**/' : u0 * u1 / u2 + D = Divider('**/') + u = np.array([2.0, 3.0, 4.0]) + np.testing.assert_allclose( + D.op_alg.jac(u), + numerical_jac(D.op_alg._func, u), + rtol=1e-5, + ) + + # '*/' : u0 / u1 + D = Divider('*/') + u = np.array([6.0, 2.0]) + np.testing.assert_allclose( + D.op_alg.jac(u), + numerical_jac(D.op_alg._func, u), + rtol=1e-5, + ) + + # '/**' : u1 * u2 / u0 + D = Divider('/**') + u = np.array([4.0, 2.0, 3.0]) + np.testing.assert_allclose( + D.op_alg.jac(u), + 
numerical_jac(D.op_alg._func, u), + rtol=1e-5, + ) + + # ops shorter than inputs: '/' with 3 inputs → u1 * u2 / u0 + D = Divider('/') + u = np.array([4.0, 2.0, 3.0]) + np.testing.assert_allclose( + D.op_alg.jac(u), + numerical_jac(D.op_alg._func, u), + rtol=1e-5, + ) + + + def test_zero_div(self): + + # 'warn' (default): produces inf, no exception + D = Divider('*/', zero_div='warn') + D.inputs[0] = 6.0 + D.inputs[1] = 0.0 + import warnings + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter('always') + D.update(None) + self.assertTrue(np.isinf(D.outputs[0])) + + # 'raise': ZeroDivisionError on zero denominator + D = Divider('*/', zero_div='raise') + D.inputs[0] = 6.0 + D.inputs[1] = 0.0 + with self.assertRaises(ZeroDivisionError): + D.update(None) + + # 'raise': no error when denominator is nonzero + D = Divider('*/', zero_div='raise') + D.inputs[0] = 6.0 + D.inputs[1] = 2.0 + D.update(None) + self.assertAlmostEqual(D.outputs[0], 3.0) + + # 'clamp': output is large-but-finite + D = Divider('*/', zero_div='clamp') + D.inputs[0] = 1.0 + D.inputs[1] = 0.0 + D.update(None) + self.assertTrue(np.isfinite(D.outputs[0])) + self.assertGreater(abs(D.outputs[0]), 1.0) + + # 'raise' invalid zero_div value + with self.assertRaises(ValueError): + Divider('*/', zero_div='ignore') + + # Jacobian: 'raise' on zero denominator input + D = Divider('*/', zero_div='raise') + with self.assertRaises(ZeroDivisionError): + D.op_alg.jac(np.array([6.0, 0.0])) + + # Jacobian: 'clamp' stays finite + D = Divider('*/', zero_div='clamp') + J = D.op_alg.jac(np.array([1.0, 0.0])) + self.assertTrue(np.all(np.isfinite(J))) + + +# RUN TESTS LOCALLY ==================================================================== + +if __name__ == '__main__': + unittest.main(verbosity=2) From 7d152136071f6751568b7b21c255b693e8c18ab4 Mon Sep 17 00:00:00 2001 From: kwmcbride Date: Tue, 24 Feb 2026 07:11:04 -0800 Subject: [PATCH 04/29] I had this default to multiplier - changed it to */ and 
updated tests --- src/pathsim/blocks/divider.py | 7 ++++--- tests/pathsim/blocks/test_divider.py | 23 ++++++++++++----------- 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/src/pathsim/blocks/divider.py b/src/pathsim/blocks/divider.py index ae6329a0..a0d684cc 100644 --- a/src/pathsim/blocks/divider.py +++ b/src/pathsim/blocks/divider.py @@ -42,7 +42,7 @@ class Divider(Block): Example ------- - Default initialization multiplies all inputs (same as :class:`Multiplier`): + Default initialization multiplies the first input and divides by the second: .. code-block:: python @@ -79,7 +79,8 @@ class Divider(Block): operations : str, optional String of ``*`` and ``/`` characters indicating which inputs are multiplied (``*``) or divided (``/``). Inputs beyond the length of - the string default to ``*``. ``None`` multiplies all inputs. + the string default to ``*``. Defaults to ``'*/'`` (divide first + input by second). zero_div : str, optional Behaviour when a denominator input is zero. One of: @@ -107,7 +108,7 @@ class Divider(Block): input_port_labels = None output_port_labels = {"out": 0} - def __init__(self, operations=None, zero_div="warn"): + def __init__(self, operations="*/", zero_div="warn"): super().__init__() # validate zero_div diff --git a/tests/pathsim/blocks/test_divider.py b/tests/pathsim/blocks/test_divider.py index f304dbb3..a60b7edc 100644 --- a/tests/pathsim/blocks/test_divider.py +++ b/tests/pathsim/blocks/test_divider.py @@ -26,7 +26,7 @@ def test_init(self): # default initialization D = Divider() - self.assertIsNone(D.operations) + self.assertEqual(D.operations, "*/") # valid ops strings for ops in ["*", "/", "*/", "/*", "**/", "/**"]: @@ -47,11 +47,11 @@ def test_init(self): def test_embedding(self): """Test algebraic output against reference via Embedding.""" - # default: multiply all (identical to Multiplier) + # default: '*/' — u0 * u2 * ... 
/ u1 D = Divider() - def src(t): return np.cos(t), np.sin(t) + 2, 3.0, t + 1 - def ref(t): return np.cos(t) * (np.sin(t) + 2) * 3.0 * (t + 1) + def src(t): return t + 1, np.cos(t) + 2, 3.0 + def ref(t): return (t + 1) * 3.0 / (np.cos(t) + 2) E = Embedding(D, src, ref) for t in range(10): self.assertEqual(*E.check_MIMO(t)) @@ -97,11 +97,11 @@ def ref(t): return np.cos(t) def test_linearization(self): """Test linearize / delinearize round-trip.""" - # default (multiply all) — nonlinear, so only check at linearization point + # default ('*/') — nonlinear, so only check at linearization point D = Divider() - def src(t): return np.cos(t) + 2, t + 1, 3.0 - def ref(t): return (np.cos(t) + 2) * (t + 1) * 3.0 + def src(t): return np.cos(t) + 2, t + 1 + def ref(t): return (np.cos(t) + 2) / (t + 1) E = Embedding(D, src, ref) @@ -145,14 +145,15 @@ def test_update_single(self): def test_update_multi(self): + # default '*/' with 3 inputs: ops=[*, /, *] → (u0 * u2) / u1 D = Divider() - D.inputs[0] = 2.0 - D.inputs[1] = 3.0 - D.inputs[2] = 4.0 + D.inputs[0] = 6.0 + D.inputs[1] = 2.0 + D.inputs[2] = 3.0 D.update(None) - self.assertEqual(D.outputs[0], 24.0) + self.assertEqual(D.outputs[0], 9.0) def test_update_ops(self): From 7bc8876c10f39229860bdf1937c64af84f1544f8 Mon Sep 17 00:00:00 2001 From: Milan Rother Date: Thu, 19 Feb 2026 21:47:52 +0100 Subject: [PATCH 05/29] Add mutable class decorator for runtime parameter reinitialization --- src/pathsim/utils/__init__.py | 1 + src/pathsim/utils/mutable.py | 171 ++++++++++++++++++++++++++++++++++ 2 files changed, 172 insertions(+) create mode 100644 src/pathsim/utils/mutable.py diff --git a/src/pathsim/utils/__init__.py b/src/pathsim/utils/__init__.py index 4e586b8c..ce95ea7c 100644 --- a/src/pathsim/utils/__init__.py +++ b/src/pathsim/utils/__init__.py @@ -1 +1,2 @@ from .deprecation import deprecated +from .mutable import mutable diff --git a/src/pathsim/utils/mutable.py b/src/pathsim/utils/mutable.py new file mode 100644 index 
00000000..686ec00c --- /dev/null +++ b/src/pathsim/utils/mutable.py @@ -0,0 +1,171 @@ +######################################################################################### +## +## MUTABLE PARAMETER DECORATOR +## (utils/mutable.py) +## +## Class decorator that enables runtime parameter mutation with automatic +## reinitialization. When a decorated parameter is changed, the block's +## __init__ is re-run with updated values while preserving engine state. +## +######################################################################################### + +# IMPORTS =============================================================================== + +import inspect +import functools + +import numpy as np + + +# DECORATOR ============================================================================= + +def mutable(*params): + """Class decorator that makes listed parameters trigger automatic reinitialization. + + When a parameter declared as mutable is changed at runtime, the block's ``__init__`` + is re-executed with the updated parameter values. The integration engine state is + preserved across the reinitialization, ensuring continuity during simulation. + + A ``set(**kwargs)`` method is also generated for batched parameter updates that + triggers only a single reinitialization. + + Parameters + ---------- + params : str + names of the mutable parameters (must match ``__init__`` argument names) + + Example + ------- + .. 
code-block:: python + + @mutable("K", "T") + class PT1(StateSpace): + def __init__(self, K=1.0, T=1.0): + self.K = K + self.T = T + super().__init__( + A=np.array([[-1.0 / T]]), + B=np.array([[K / T]]), + C=np.array([[1.0]]), + D=np.array([[0.0]]) + ) + + pt1 = PT1(K=2.0, T=0.5) + pt1.K = 5.0 # auto reinitializes + pt1.set(K=5.0, T=0.3) # single reinitialization + """ + + def decorator(cls): + + original_init = cls.__init__ + + # get all __init__ parameter names for reinit + init_params = [ + name for name in inspect.signature(original_init).parameters + if name != "self" + ] + + # validate that declared mutable params exist in __init__ + for p in params: + if p not in init_params: + raise ValueError( + f"Mutable parameter '{p}' not found in " + f"{cls.__name__}.__init__ signature {init_params}" + ) + + # -- install property descriptors for mutable params --------------------------- + + for name in params: + storage = f"_p_{name}" + + def _make_property(s): + def getter(self): + return getattr(self, s) + + def setter(self, value): + setattr(self, s, value) + if getattr(self, '_param_locked', False): + _reinit(self) + + return property(getter, setter) + + setattr(cls, name, _make_property(storage)) + + # -- reinit function ----------------------------------------------------------- + + def _reinit(self): + """Re-run __init__ with current parameter values, preserving engine state.""" + + # gather current values for all init params + kwargs = {} + for name in init_params: + if hasattr(self, name): + kwargs[name] = getattr(self, name) + + # save engine state + engine = self.engine if hasattr(self, 'engine') else None + + # re-run init (unlock to prevent recursive reinit) + self._param_locked = False + original_init(self, **kwargs) + self._param_locked = True + + # restore engine + if engine is not None: + old_dim = len(engine) + new_dim = len(np.atleast_1d(self.initial_value)) if hasattr(self, 'initial_value') else 0 + + if old_dim == new_dim: + # same dimension - 
restore the engine directly + self.engine = engine + else: + # dimension changed - create new engine inheriting settings + self.engine = type(engine).create( + self.initial_value, + parent=engine.parent, + from_engine=None + ) + # inherit tolerances manually since from_engine=None + self.engine.tolerance_lte_abs = engine.tolerance_lte_abs + self.engine.tolerance_lte_rel = engine.tolerance_lte_rel + + # -- wrap __init__ to flip the lock after construction ------------------------- + + @functools.wraps(original_init) + def new_init(self, *args, **kwargs): + original_init(self, *args, **kwargs) + self._param_locked = True + + cls.__init__ = new_init + + # -- generate batched set() method --------------------------------------------- + + def set(self, **kwargs): + """Set multiple parameters and reinitialize once. + + Parameters + ---------- + kwargs : dict + parameter names and their new values + + Example + ------- + .. code-block:: python + + block.set(K=5.0, T=0.3) + """ + self._param_locked = False + for key, value in kwargs.items(): + setattr(self, key, value) + self._param_locked = True + _reinit(self) + + cls.set = set + + # -- store metadata for introspection ------------------------------------------ + + cls._mutable_params = params + + return cls + + return decorator From bc1c62e8ddc11197b6c3b7d47b7445d36c70b331 Mon Sep 17 00:00:00 2001 From: Milan Rother Date: Thu, 19 Feb 2026 21:53:57 +0100 Subject: [PATCH 06/29] Apply mutable decorator to ctrl, lti, and filter blocks --- src/pathsim/blocks/ctrl.py | 6 + src/pathsim/blocks/filters.py | 6 + src/pathsim/blocks/lti.py | 4 + src/pathsim/utils/mutable.py | 214 +++++++++++++++++----------------- 4 files changed, 126 insertions(+), 104 deletions(-) diff --git a/src/pathsim/blocks/ctrl.py b/src/pathsim/blocks/ctrl.py index 5f65cdc6..70fa72f6 100644 --- a/src/pathsim/blocks/ctrl.py +++ b/src/pathsim/blocks/ctrl.py @@ -14,10 +14,12 @@ from .dynsys import DynamicalSystem from ..optim.operator import Operator, 
DynamicOperator +from ..utils.mutable import mutable # LTI CONTROL BLOCKS (StateSpace subclasses) ============================================ +@mutable class PT1(StateSpace): """First-order lag element (PT1). @@ -65,6 +67,7 @@ def __init__(self, K=1.0, T=1.0): ) +@mutable class PT2(StateSpace): """Second-order lag element (PT2). @@ -124,6 +127,7 @@ def __init__(self, K=1.0, T=1.0, d=1.0): ) +@mutable class LeadLag(StateSpace): """Lead-Lag compensator. @@ -180,6 +184,7 @@ def __init__(self, K=1.0, T1=1.0, T2=1.0): ) +@mutable class PID(StateSpace): """Proportional-Integral-Differentiation (PID) controller. @@ -253,6 +258,7 @@ def __init__(self, Kp=0, Ki=0, Kd=0, f_max=100): ) +@mutable class AntiWindupPID(PID): """Proportional-Integral-Differentiation (PID) controller with anti-windup mechanism (back-calculation). diff --git a/src/pathsim/blocks/filters.py b/src/pathsim/blocks/filters.py index ceff9cdc..6015df89 100644 --- a/src/pathsim/blocks/filters.py +++ b/src/pathsim/blocks/filters.py @@ -14,10 +14,12 @@ from .lti import StateSpace from ..utils.register import Register +from ..utils.mutable import mutable # FILTER BLOCKS ========================================================================= +@mutable class ButterworthLowpassFilter(StateSpace): """Direct implementation of a low pass butterworth filter block. @@ -52,6 +54,7 @@ def __init__(self, Fc=100, n=2): super().__init__(omega_c*A, omega_c*B, C, D) +@mutable class ButterworthHighpassFilter(StateSpace): """Direct implementation of a high pass butterworth filter block. @@ -85,6 +88,7 @@ def __init__(self, Fc=100, n=2): super().__init__(omega_c*A, omega_c*B, C, D) +@mutable class ButterworthBandpassFilter(StateSpace): """Direct implementation of a bandpass butterworth filter block. @@ -119,6 +123,7 @@ def __init__(self, Fc=[50, 100], n=2): super().__init__(*tf2ss(num, den)) +@mutable class ButterworthBandstopFilter(StateSpace): """Direct implementation of a bandstop butterworth filter block. 
@@ -153,6 +158,7 @@ def __init__(self, Fc=[50, 100], n=2): super().__init__(*tf2ss(num, den)) +@mutable class AllpassFilter(StateSpace): """Direct implementation of a first order allpass filter, or a cascade of n 1st order allpass filters diff --git a/src/pathsim/blocks/lti.py b/src/pathsim/blocks/lti.py index de85e99c..6a9cd23b 100644 --- a/src/pathsim/blocks/lti.py +++ b/src/pathsim/blocks/lti.py @@ -22,6 +22,7 @@ from ..utils.deprecation import deprecated from ..optim.operator import DynamicOperator +from ..utils.mutable import mutable # LTI BLOCKS ============================================================================ @@ -169,6 +170,7 @@ def step(self, t, dt): return self.engine.step(f, dt) +@mutable class TransferFunctionPRC(StateSpace): """This block defines a LTI (MIMO for pole residue) transfer function. @@ -227,6 +229,7 @@ class TransferFunction(TransferFunctionPRC): pass +@mutable class TransferFunctionZPG(StateSpace): """This block defines a LTI (SISO) transfer function. @@ -281,6 +284,7 @@ def __init__(self, Zeros=[], Poles=[-1], Gain=1.0): super().__init__(sp_SS.A, sp_SS.B, sp_SS.C, sp_SS.D) +@mutable class TransferFunctionNumDen(StateSpace): """This block defines a LTI (SISO) transfer function. diff --git a/src/pathsim/utils/mutable.py b/src/pathsim/utils/mutable.py index 686ec00c..97888908 100644 --- a/src/pathsim/utils/mutable.py +++ b/src/pathsim/utils/mutable.py @@ -17,28 +17,79 @@ import numpy as np +# REINIT HELPER ========================================================================= + +def _do_reinit(block): + """Re-run __init__ with current parameter values, preserving engine state. + + Uses ``type(block).__init__`` to always reinit from the most derived class, + ensuring that subclass overrides (e.g. operator replacements) are preserved. 
+ + Parameters + ---------- + block : Block + the block instance to reinitialize + """ + + actual_cls = type(block) + + # gather current values for ALL init params of the actual class + sig = inspect.signature(actual_cls.__init__) + kwargs = {} + for name in sig.parameters: + if name == "self": + continue + if hasattr(block, name): + kwargs[name] = getattr(block, name) + + # save engine + engine = block.engine if hasattr(block, 'engine') else None + + # re-run init through the wrapped __init__ (handles depth counting) + block._param_locked = False + actual_cls.__init__(block, **kwargs) + # _param_locked is set to True by the outermost new_init wrapper + + # restore engine + if engine is not None: + old_dim = len(engine) + new_dim = len(np.atleast_1d(block.initial_value)) if hasattr(block, 'initial_value') else 0 + + if old_dim == new_dim: + # same dimension - restore the entire engine + block.engine = engine + else: + # dimension changed - create new engine inheriting settings + block.engine = type(engine).create( + block.initial_value, + parent=engine.parent, + ) + block.engine.tolerance_lte_abs = engine.tolerance_lte_abs + block.engine.tolerance_lte_rel = engine.tolerance_lte_rel + + # DECORATOR ============================================================================= -def mutable(*params): - """Class decorator that makes listed parameters trigger automatic reinitialization. +def mutable(cls): + """Class decorator that makes all ``__init__`` parameters trigger automatic + reinitialization when changed at runtime. - When a parameter declared as mutable is changed at runtime, the block's ``__init__`` - is re-executed with the updated parameter values. The integration engine state is - preserved across the reinitialization, ensuring continuity during simulation. + Parameters are auto-detected from the ``__init__`` signature. When any parameter + is changed at runtime, the block's ``__init__`` is re-executed with updated values. 
+ The integration engine state is preserved across reinitialization. A ``set(**kwargs)`` method is also generated for batched parameter updates that triggers only a single reinitialization. - Parameters - ---------- - params : str - names of the mutable parameters (must match ``__init__`` argument names) + Supports inheritance: if both a parent and child class use ``@mutable``, the init + guard uses a depth counter to ensure reinitialization only triggers after the + outermost ``__init__`` completes. Example ------- .. code-block:: python - @mutable("K", "T") + @mutable class PT1(StateSpace): def __init__(self, K=1.0, T=1.0): self.K = K @@ -55,117 +106,72 @@ def __init__(self, K=1.0, T=1.0): pt1.set(K=5.0, T=0.3) # single reinitialization """ - def decorator(cls): + original_init = cls.__init__ - original_init = cls.__init__ + # auto-detect all __init__ parameters + params = [ + name for name in inspect.signature(original_init).parameters + if name != "self" + ] - # get all __init__ parameter names for reinit - init_params = [ - name for name in inspect.signature(original_init).parameters - if name != "self" - ] + # -- install property descriptors for all params ------------------------------- - # validate that declared mutable params exist in __init__ - for p in params: - if p not in init_params: - raise ValueError( - f"Mutable parameter '{p}' not found in " - f"{cls.__name__}.__init__ signature {init_params}" - ) + for name in params: + storage = f"_p_{name}" - # -- install property descriptors for mutable params --------------------------- + def _make_property(s): + def getter(self): + return getattr(self, s) - for name in params: - storage = f"_p_{name}" + def setter(self, value): + setattr(self, s, value) + if getattr(self, '_param_locked', False): + _do_reinit(self) - def _make_property(s): - def getter(self): - return getattr(self, s) + return property(getter, setter) - def setter(self, value): - setattr(self, s, value) - if getattr(self, '_param_locked', 
False): - _reinit(self) + setattr(cls, name, _make_property(storage)) - return property(getter, setter) + # -- wrap __init__ with depth counter ------------------------------------------ - setattr(cls, name, _make_property(storage)) - - # -- reinit function ----------------------------------------------------------- - - def _reinit(self): - """Re-run __init__ with current parameter values, preserving engine state.""" - - # gather current values for all init params - kwargs = {} - for name in init_params: - if hasattr(self, name): - kwargs[name] = getattr(self, name) - - # save engine state - engine = self.engine if hasattr(self, 'engine') else None - - # re-run init (unlock to prevent recursive reinit) - self._param_locked = False - original_init(self, **kwargs) - self._param_locked = True - - # restore engine - if engine is not None: - old_dim = len(engine) - new_dim = len(np.atleast_1d(self.initial_value)) if hasattr(self, 'initial_value') else 0 - - if old_dim == new_dim: - # same dimension - restore the engine directly - self.engine = engine - else: - # dimension changed - create new engine inheriting settings - self.engine = type(engine).create( - self.initial_value, - parent=engine.parent, - from_engine=None - ) - # inherit tolerances manually since from_engine=None - self.engine.tolerance_lte_abs = engine.tolerance_lte_abs - self.engine.tolerance_lte_rel = engine.tolerance_lte_rel - - # -- wrap __init__ to flip the lock after construction ------------------------- - - @functools.wraps(original_init) - def new_init(self, *args, **kwargs): + @functools.wraps(original_init) + def new_init(self, *args, **kwargs): + self._init_depth = getattr(self, '_init_depth', 0) + 1 + try: original_init(self, *args, **kwargs) - self._param_locked = True - - cls.__init__ = new_init + finally: + self._init_depth -= 1 + if self._init_depth == 0: + self._param_locked = True - # -- generate batched set() method --------------------------------------------- + cls.__init__ = 
new_init - def set(self, **kwargs): - """Set multiple parameters and reinitialize once. + # -- generate batched set() method --------------------------------------------- - Parameters - ---------- - kwargs : dict - parameter names and their new values + def set(self, **kwargs): + """Set multiple parameters and reinitialize once. - Example - ------- - .. code-block:: python + Parameters + ---------- + kwargs : dict + parameter names and their new values - block.set(K=5.0, T=0.3) - """ - self._param_locked = False - for key, value in kwargs.items(): - setattr(self, key, value) - self._param_locked = True - _reinit(self) + Example + ------- + .. code-block:: python - cls.set = set + block.set(K=5.0, T=0.3) + """ + self._param_locked = False + for key, value in kwargs.items(): + setattr(self, key, value) + _do_reinit(self) - # -- store metadata for introspection ------------------------------------------ + cls.set = set - cls._mutable_params = params + # -- store metadata for introspection ------------------------------------------ - return cls + existing = getattr(cls, '_mutable_params', ()) + cls._mutable_params = existing + tuple(p for p in params if p not in existing) - return decorator + return cls From 4a8691f651964d6d909e681c08d2e2c3a8cc2714 Mon Sep 17 00:00:00 2001 From: Milan Rother Date: Thu, 19 Feb 2026 21:56:08 +0100 Subject: [PATCH 07/29] Apply mutable decorator to sources, delay, spectrum, FIR, converters, samplehold --- src/pathsim/blocks/converters.py | 3 +++ src/pathsim/blocks/delay.py | 2 ++ src/pathsim/blocks/fir.py | 6 ++++-- src/pathsim/blocks/samplehold.py | 2 ++ src/pathsim/blocks/sources.py | 7 +++++++ src/pathsim/blocks/spectrum.py | 2 ++ 6 files changed, 20 insertions(+), 2 deletions(-) diff --git a/src/pathsim/blocks/converters.py b/src/pathsim/blocks/converters.py index c2726930..b069031c 100644 --- a/src/pathsim/blocks/converters.py +++ b/src/pathsim/blocks/converters.py @@ -12,10 +12,12 @@ from ._block import Block from ..utils.register 
import Register from ..events.schedule import Schedule +from ..utils.mutable import mutable # MIXED SIGNAL BLOCKS =================================================================== +@mutable class ADC(Block): """Models an ideal Analog-to-Digital Converter (ADC). @@ -104,6 +106,7 @@ def __len__(self): return 0 +@mutable class DAC(Block): """Models an ideal Digital-to-Analog Converter (DAC). diff --git a/src/pathsim/blocks/delay.py b/src/pathsim/blocks/delay.py index 88baad10..174ec8a8 100644 --- a/src/pathsim/blocks/delay.py +++ b/src/pathsim/blocks/delay.py @@ -12,10 +12,12 @@ from ._block import Block from ..utils.adaptivebuffer import AdaptiveBuffer +from ..utils.mutable import mutable # BLOCKS ================================================================================ +@mutable class Delay(Block): """Delays the input signal by a time constant 'tau' in seconds. diff --git a/src/pathsim/blocks/fir.py b/src/pathsim/blocks/fir.py index 5cdebd66..8db1a8a3 100644 --- a/src/pathsim/blocks/fir.py +++ b/src/pathsim/blocks/fir.py @@ -10,13 +10,15 @@ import numpy as np from collections import deque -from ._block import Block +from ._block import Block from ..utils.register import Register -from ..events.schedule import Schedule +from ..events.schedule import Schedule +from ..utils.mutable import mutable # FIR FILTER BLOCK ====================================================================== +@mutable class FIR(Block): """Models a discrete-time Finite-Impulse-Response (FIR) filter. 
diff --git a/src/pathsim/blocks/samplehold.py b/src/pathsim/blocks/samplehold.py index ae3e5b96..a292877c 100644 --- a/src/pathsim/blocks/samplehold.py +++ b/src/pathsim/blocks/samplehold.py @@ -9,10 +9,12 @@ from ._block import Block from ..events.schedule import Schedule +from ..utils.mutable import mutable # MIXED SIGNAL BLOCKS =================================================================== +@mutable class SampleHold(Block): """Samples the inputs periodically and produces them at the output. diff --git a/src/pathsim/blocks/sources.py b/src/pathsim/blocks/sources.py index 170514b9..a2abfcea 100644 --- a/src/pathsim/blocks/sources.py +++ b/src/pathsim/blocks/sources.py @@ -14,6 +14,7 @@ from ._block import Block from ..utils.register import Register from ..utils.deprecation import deprecated +from ..utils.mutable import mutable from ..events.schedule import Schedule, ScheduleList from .._constants import TOLERANCE @@ -169,6 +170,7 @@ def update(self, t): # SPECIAL CONTINUOUS SOURCE BLOCKS ====================================================== +@mutable class TriangleWaveSource(Source): """Source block that generates an analog triangle wave @@ -214,6 +216,7 @@ def _triangle_wave(self, t, f): return 2 * abs(t*f - np.floor(t*f + 0.5)) - 1 +@mutable class SinusoidalSource(Source): """Source block that generates a sinusoid wave @@ -289,6 +292,7 @@ def _gaussian(self, t, f_max): return np.exp(-(t/tau)**2) +@mutable class SinusoidalPhaseNoiseSource(Block): """Sinusoidal source with cumulative and white phase noise. @@ -703,6 +707,7 @@ class ChirpSource(ChirpPhaseNoiseSource): # SPECIAL DISCRETE SOURCE BLOCKS ======================================================== +@mutable class PulseSource(Block): """Generates a periodic pulse waveform with defined rise and fall times. @@ -909,6 +914,7 @@ class Pulse(PulseSource): pass +@mutable class ClockSource(Block): """Discrete time clock source block. 
@@ -970,6 +976,7 @@ class Clock(ClockSource): +@mutable class SquareWaveSource(Block): """Discrete time square wave source. diff --git a/src/pathsim/blocks/spectrum.py b/src/pathsim/blocks/spectrum.py index 24268fe9..7b3a0878 100644 --- a/src/pathsim/blocks/spectrum.py +++ b/src/pathsim/blocks/spectrum.py @@ -15,12 +15,14 @@ from ..utils.realtimeplotter import RealtimePlotter from ..utils.deprecation import deprecated +from ..utils.mutable import mutable from .._constants import COLORS_ALL # BLOCKS FOR DATA RECORDING ============================================================= +@mutable class Spectrum(Block): """Block for fourier spectrum analysis (spectrum analyzer). From 766b2a5c7ea6854e8385aa3a1852a6d63df4a34d Mon Sep 17 00:00:00 2001 From: Milan Rother Date: Thu, 19 Feb 2026 21:59:21 +0100 Subject: [PATCH 08/29] Add tests for mutable decorator --- tests/pathsim/utils/test_mutable.py | 267 ++++++++++++++++++++++++++++ 1 file changed, 267 insertions(+) create mode 100644 tests/pathsim/utils/test_mutable.py diff --git a/tests/pathsim/utils/test_mutable.py b/tests/pathsim/utils/test_mutable.py new file mode 100644 index 00000000..e43d36fd --- /dev/null +++ b/tests/pathsim/utils/test_mutable.py @@ -0,0 +1,267 @@ +######################################################################################## +## +## TESTS FOR +## 'utils.mutable.py' +## +######################################################################################## + +# IMPORTS ============================================================================== + +import unittest +import numpy as np + +from pathsim.blocks._block import Block +from pathsim.blocks.lti import StateSpace +from pathsim.blocks.ctrl import PT1, PT2, LeadLag, PID, AntiWindupPID +from pathsim.blocks.lti import TransferFunctionNumDen, TransferFunctionZPG +from pathsim.blocks.filters import ButterworthLowpassFilter +from pathsim.blocks.sources import SinusoidalSource, ClockSource +from pathsim.blocks.delay import Delay +from 
pathsim.blocks.fir import FIR +from pathsim.blocks.samplehold import SampleHold + +from pathsim.utils.mutable import mutable + + +# TESTS FOR DECORATOR ================================================================== + +class TestMutableDecorator(unittest.TestCase): + """Test the @mutable decorator mechanics.""" + + def test_basic_construction(self): + """Block construction should work as before.""" + pt1 = PT1(K=2.0, T=0.5) + self.assertEqual(pt1.K, 2.0) + self.assertEqual(pt1.T, 0.5) + np.testing.assert_array_almost_equal(pt1.A, [[-2.0]]) + np.testing.assert_array_almost_equal(pt1.B, [[4.0]]) + + def test_param_mutation_triggers_reinit(self): + """Changing a mutable param should update derived state.""" + pt1 = PT1(K=1.0, T=1.0) + np.testing.assert_array_almost_equal(pt1.A, [[-1.0]]) + np.testing.assert_array_almost_equal(pt1.B, [[1.0]]) + + pt1.K = 5.0 + np.testing.assert_array_almost_equal(pt1.A, [[-1.0]]) + np.testing.assert_array_almost_equal(pt1.B, [[5.0]]) + + pt1.T = 0.5 + np.testing.assert_array_almost_equal(pt1.A, [[-2.0]]) + np.testing.assert_array_almost_equal(pt1.B, [[10.0]]) + + def test_batched_set(self): + """set() should update multiple params with a single reinit.""" + pt1 = PT1(K=1.0, T=1.0) + pt1.set(K=3.0, T=0.2) + self.assertEqual(pt1.K, 3.0) + self.assertEqual(pt1.T, 0.2) + np.testing.assert_array_almost_equal(pt1.A, [[-5.0]]) + np.testing.assert_array_almost_equal(pt1.B, [[15.0]]) + + def test_mutable_params_introspection(self): + """_mutable_params should list all init params.""" + self.assertEqual(PT1._mutable_params, ("K", "T")) + self.assertEqual(PT2._mutable_params, ("K", "T", "d")) + self.assertEqual(PID._mutable_params, ("Kp", "Ki", "Kd", "f_max")) + + def test_mutable_params_inherited(self): + """AntiWindupPID should accumulate parent and own params.""" + self.assertIn("Kp", AntiWindupPID._mutable_params) + self.assertIn("Ks", AntiWindupPID._mutable_params) + self.assertIn("limits", AntiWindupPID._mutable_params) + # no duplicates 
+ self.assertEqual( + len(AntiWindupPID._mutable_params), + len(set(AntiWindupPID._mutable_params)) + ) + + def test_no_reinit_during_construction(self): + """Properties should not trigger reinit during __init__.""" + # If this doesn't hang or error, the init guard works + pt1 = PT1(K=2.0, T=0.5) + self.assertTrue(pt1._param_locked) + + +# TESTS FOR ENGINE PRESERVATION ========================================================= + +class TestEnginePreservation(unittest.TestCase): + """Test that engine state is preserved across reinit.""" + + def test_engine_preserved_same_dimension(self): + """Engine should be preserved when state dimension doesn't change.""" + from pathsim.solvers.euler import EUF + + pt1 = PT1(K=1.0, T=1.0) + pt1.set_solver(EUF, None) + pt1.engine.state = np.array([42.0]) + + # Mutate parameter + pt1.K = 5.0 + + # Engine should be preserved with same state + self.assertIsNotNone(pt1.engine) + np.testing.assert_array_equal(pt1.engine.state, [42.0]) + + def test_engine_recreated_on_dimension_change(self): + """Engine should be recreated when state dimension changes.""" + from pathsim.solvers.euler import EUF + + filt = ButterworthLowpassFilter(Fc=100, n=2) + filt.set_solver(EUF, None) + + old_state_dim = len(filt.engine) + self.assertEqual(old_state_dim, 2) + + # Change filter order -> dimension change + filt.n = 4 + + # Engine should exist but with new dimension + self.assertIsNotNone(filt.engine) + self.assertEqual(len(filt.engine), 4) + + +# TESTS FOR INHERITANCE ================================================================= + +class TestInheritance(unittest.TestCase): + """Test that @mutable works with class hierarchies.""" + + def test_antiwinduppid_construction(self): + """AntiWindupPID should construct correctly with both decorators.""" + awpid = AntiWindupPID(Kp=2, Ki=0.5, Kd=0.1, f_max=1e3, Ks=10, limits=[-5, 5]) + self.assertEqual(awpid.Kp, 2) + self.assertEqual(awpid.Ks, 10) + + def test_antiwinduppid_parent_param_mutation(self): + 
"""Mutating inherited param should reinit from most derived class.""" + awpid = AntiWindupPID(Kp=2, Ki=0.5, Kd=0.1, f_max=1e3, Ks=10, limits=[-5, 5]) + + # Mutate inherited param + awpid.Kp = 5.0 + + # op_dyn should still be the antiwindup version (not plain PID) + x = np.array([0.0, 0.0]) + u = np.array([1.0]) + result = awpid.op_dyn(x, u, 0) + # For AntiWindupPID with these params, dx1 = f_max*(u-x1), dx2 = u - w + self.assertEqual(len(result), 2) + + def test_antiwinduppid_own_param_mutation(self): + """Mutating AntiWindupPID's own param should work.""" + awpid = AntiWindupPID(Kp=2, Ki=0.5, Kd=0.1, f_max=1e3, Ks=10, limits=[-5, 5]) + awpid.Ks = 20 + self.assertEqual(awpid.Ks, 20) + + +# TESTS FOR SPECIFIC BLOCKS ============================================================= + +class TestSpecificBlocks(unittest.TestCase): + """Test @mutable on various block types.""" + + def test_pt2(self): + pt2 = PT2(K=1.0, T=1.0, d=0.5) + A_before = pt2.A.copy() + pt2.d = 0.7 + # A matrix should have changed + self.assertFalse(np.allclose(pt2.A, A_before)) + + def test_leadlag(self): + ll = LeadLag(K=1.0, T1=0.5, T2=0.1) + ll.K = 2.0 + self.assertEqual(ll.K, 2.0) + # C and D should reflect new K + np.testing.assert_array_almost_equal(ll.D, [[2.0 * 0.5 / 0.1]]) + + def test_transfer_function_numden(self): + tf = TransferFunctionNumDen(Num=[1], Den=[1, 1]) + np.testing.assert_array_almost_equal(tf.A, [[-1.0]]) + tf.Den = [1, 2] + np.testing.assert_array_almost_equal(tf.A, [[-2.0]]) + + def test_transfer_function_dimension_change(self): + """Changing denominator order should change state dimension.""" + tf = TransferFunctionNumDen(Num=[1], Den=[1, 1]) + self.assertEqual(tf.A.shape, (1, 1)) + tf.Den = [1, 3, 2] # second order + self.assertEqual(tf.A.shape, (2, 2)) + + def test_sinusoidal_source(self): + s = SinusoidalSource(frequency=10, amplitude=2, phase=0.5) + self.assertAlmostEqual(s._omega, 2*np.pi*10) + s.frequency = 20 + self.assertAlmostEqual(s._omega, 2*np.pi*20) + + def 
test_delay(self): + d = Delay(tau=0.01) + self.assertEqual(d._buffer.delay, 0.01) + d.tau = 0.05 + self.assertEqual(d._buffer.delay, 0.05) + + def test_clock_source(self): + c = ClockSource(T=1.0, tau=0.0) + self.assertEqual(c.events[0].t_period, 1.0) + c.T = 2.0 + self.assertEqual(c.events[0].t_period, 2.0) + + def test_fir(self): + f = FIR(coeffs=[0.5, 0.5], T=0.1) + self.assertEqual(f.T, 0.1) + f.T = 0.2 + self.assertEqual(f.T, 0.2) + self.assertEqual(f.events[0].t_period, 0.2) + + def test_samplehold(self): + sh = SampleHold(T=0.5, tau=0.0) + sh.T = 1.0 + self.assertEqual(sh.T, 1.0) + + def test_butterworth_filter_mutation(self): + filt = ButterworthLowpassFilter(Fc=100, n=2) + A_before = filt.A.copy() + filt.Fc = 200 + # Matrices should change + self.assertFalse(np.allclose(filt.A, A_before)) + + def test_butterworth_filter_order_change(self): + filt = ButterworthLowpassFilter(Fc=100, n=2) + self.assertEqual(filt.A.shape, (2, 2)) + filt.n = 4 + self.assertEqual(filt.A.shape, (4, 4)) + + +# INTEGRATION TEST ====================================================================== + +class TestMutableInSimulation(unittest.TestCase): + """Test parameter mutation in an actual simulation context.""" + + def test_pt1_mutation_mid_simulation(self): + """Mutating PT1 gain mid-simulation should affect output.""" + from pathsim import Simulation, Connection + from pathsim.blocks.sources import Constant + + src = Constant(value=1.0) + pt1 = PT1(K=1.0, T=0.1) + + sim = Simulation( + blocks=[src, pt1], + connections=[Connection(src, pt1)], + dt=0.01 + ) + + # Run for a bit + sim.run(duration=1.0) + output_before = pt1.outputs[0] + + # Change gain + pt1.K = 5.0 + + # Run more + sim.run(duration=1.0) + output_after = pt1.outputs[0] + + # With K=5 and enough settling time, output should approach 5.0 + self.assertGreater(output_after, output_before) + + +if __name__ == "__main__": + unittest.main() From 314b616c16c136c6786d0ae52dfc950705f05ea3 Mon Sep 17 00:00:00 2001 From: Milan 
Rother Date: Thu, 26 Feb 2026 10:45:56 +0100 Subject: [PATCH 09/29] Apply mutable decorator to Divider block --- src/pathsim/blocks/divider.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/pathsim/blocks/divider.py b/src/pathsim/blocks/divider.py index a0d684cc..4e623170 100644 --- a/src/pathsim/blocks/divider.py +++ b/src/pathsim/blocks/divider.py @@ -15,6 +15,7 @@ from ._block import Block from ..utils.register import Register from ..optim.operator import Operator +from ..utils.mutable import mutable # MISO BLOCKS =========================================================================== @@ -22,6 +23,7 @@ _ZERO_DIV_OPTIONS = ("warn", "raise", "clamp") +@mutable class Divider(Block): """Multiplies and divides input signals (MISO). From d08ab8563a2393156e0a2e1db5913d6536b18a21 Mon Sep 17 00:00:00 2001 From: Milan Rother Date: Thu, 26 Feb 2026 10:35:36 +0100 Subject: [PATCH 10/29] Add missing XPRT IR blocks: logic ops, Atan2, MapLin, Alias, discrete Delay --- src/pathsim/blocks/__init__.py | 1 + src/pathsim/blocks/delay.py | 125 ++++++++++++---- src/pathsim/blocks/logic.py | 229 +++++++++++++++++++++++++++++ src/pathsim/blocks/math.py | 149 +++++++++++++++++++ tests/pathsim/blocks/test_delay.py | 91 ++++++++++++ tests/pathsim/blocks/test_logic.py | 217 +++++++++++++++++++++++++++ tests/pathsim/blocks/test_math.py | 126 ++++++++++++++++ 7 files changed, 908 insertions(+), 30 deletions(-) create mode 100644 src/pathsim/blocks/logic.py create mode 100644 tests/pathsim/blocks/test_logic.py diff --git a/src/pathsim/blocks/__init__.py b/src/pathsim/blocks/__init__.py index a196856e..dd67152b 100644 --- a/src/pathsim/blocks/__init__.py +++ b/src/pathsim/blocks/__init__.py @@ -21,6 +21,7 @@ from .noise import * from .table import * from .relay import * +from .logic import * from .math import * from .ctrl import * from .lti import * diff --git a/src/pathsim/blocks/delay.py b/src/pathsim/blocks/delay.py index 174ec8a8..6bcbed8b 100644 --- 
a/src/pathsim/blocks/delay.py +++ b/src/pathsim/blocks/delay.py @@ -1,6 +1,6 @@ ######################################################################################### ## -## TIME DOMAIN DELAY BLOCK +## TIME DOMAIN DELAY BLOCK ## (blocks/delay.py) ## ######################################################################################### @@ -9,9 +9,12 @@ import numpy as np +from collections import deque + from ._block import Block from ..utils.adaptivebuffer import AdaptiveBuffer +from ..events.schedule import Schedule from ..utils.mutable import mutable @@ -19,60 +22,107 @@ @mutable class Delay(Block): - """Delays the input signal by a time constant 'tau' in seconds. + """Delays the input signal by a time constant 'tau' in seconds. + + Supports two modes of operation: - Mathematically this block creates a time delay of the input signal like this: + **Continuous mode** (default, ``sampling_period=None``): + Uses an adaptive interpolating buffer for continuous-time delay. .. math:: - - y(t) = + + y(t) = \\begin{cases} x(t - \\tau) & , t \\geq \\tau \\\\ 0 & , t < \\tau \\end{cases} + **Discrete mode** (``sampling_period`` provided): + Uses a ring buffer with scheduled sampling events for N-sample delay, + where ``N = round(tau / sampling_period)``. + + .. math:: + + y[k] = x[k - N] + Note ---- - The internal adaptive buffer uses interpolation for the evaluation. This is - required to be compatible with variable step solvers. It has a drawback however. - The order of the ode solver used will degrade when this block is used, due to - the interpolation. + In continuous mode, the internal adaptive buffer uses interpolation for + the evaluation. This is required to be compatible with variable step solvers. + It has a drawback however. The order of the ode solver used will degrade + when this block is used, due to the interpolation. 
+ - Note ---- - This block supports vector input, meaning we can have multiple parallel + This block supports vector input, meaning we can have multiple parallel delay paths through this block. Example ------- - The block is initialized like this: + Continuous-time delay: .. code-block:: python - + #5 time units delay D = Delay(tau=5) - + + Discrete-time N-sample delay (10 samples): + + .. code-block:: python + + D = Delay(tau=0.01, sampling_period=0.001) + Parameters ---------- tau : float - delay time constant + delay time constant in seconds + sampling_period : float, None + sampling period for discrete mode, default is continuous mode Attributes ---------- _buffer : AdaptiveBuffer - internal interpolatable adaptive rolling buffer + internal interpolatable adaptive rolling buffer (continuous mode) + _ring : deque + internal ring buffer for N-sample delay (discrete mode) """ - def __init__(self, tau=1e-3): + def __init__(self, tau=1e-3, sampling_period=None): super().__init__() - #time delay in seconds + #time delay in seconds self.tau = tau - #create adaptive buffer - self._buffer = AdaptiveBuffer(self.tau) + #params for sampling + self.sampling_period = sampling_period + + if sampling_period is None: + + #continuous mode: adaptive buffer with interpolation + self._buffer = AdaptiveBuffer(self.tau) + + else: + + #discrete mode: ring buffer with N-sample delay + self._n = max(1, round(self.tau / self.sampling_period)) + self._ring = deque([0.0] * self._n, maxlen=self._n + 1) + + #flag to indicate this is a timestep to sample + self._sample_next_timestep = False + + #internal scheduled event for periodic sampling + def _sample(t): + self._sample_next_timestep = True + + self.events = [ + Schedule( + t_start=0, + t_period=sampling_period, + func_act=_sample + ) + ] def __len__(self): @@ -83,13 +133,18 @@ def __len__(self): def reset(self): super().reset() - #clear the buffer - self._buffer.clear() + if self.sampling_period is None: + #clear the adaptive buffer + 
self._buffer.clear() + else: + #clear the ring buffer + self._ring.clear() + self._ring.extend([0.0] * self._n) def update(self, t): - """Evaluation of the buffer at different times - via interpolation. + """Evaluation of the buffer at different times + via interpolation (continuous) or ring buffer lookup (discrete). Parameters ---------- @@ -97,13 +152,17 @@ def update(self, t): evaluation time """ - #retrieve value from buffer - y = self._buffer.get(t) - self.outputs.update_from_array(y) + if self.sampling_period is None: + #continuous mode: retrieve value from buffer + y = self._buffer.get(t) + self.outputs.update_from_array(y) + else: + #discrete mode: output the oldest value in the ring buffer + self.outputs[0] = self._ring[0] def sample(self, t, dt): - """Sample input values and time of sampling + """Sample input values and time of sampling and add them to the buffer. Parameters @@ -114,5 +173,11 @@ def sample(self, t, dt): integration timestep """ - #add new value to buffer - self._buffer.add(t, self.inputs.to_array()) \ No newline at end of file + if self.sampling_period is None: + #continuous mode: add new value to buffer + self._buffer.add(t, self.inputs.to_array()) + else: + #discrete mode: only sample on scheduled events + if self._sample_next_timestep: + self._ring.append(self.inputs[0]) + self._sample_next_timestep = False \ No newline at end of file diff --git a/src/pathsim/blocks/logic.py b/src/pathsim/blocks/logic.py new file mode 100644 index 00000000..31315b6f --- /dev/null +++ b/src/pathsim/blocks/logic.py @@ -0,0 +1,229 @@ +######################################################################################### +## +## COMPARISON AND LOGIC BLOCKS +## (blocks/logic.py) +## +## definitions of comparison and boolean logic blocks +## +######################################################################################### + +# IMPORTS =============================================================================== + +import numpy as np + +from 
._block import Block + +from ..optim.operator import Operator + + +# BASE LOGIC BLOCK ====================================================================== + +class Logic(Block): + """Base logic block. + + Note + ---- + This block doesn't implement any functionality itself. + It's intended to be used as a base for the comparison and logic blocks. + It's **not** intended to be used directly! + + """ + + def __len__(self): + """Purely algebraic block""" + return 1 + + + def update(self, t): + """update algebraic component of system equation + + Parameters + ---------- + t : float + evaluation time + """ + u = self.inputs.to_array() + y = self.op_alg(u) + self.outputs.update_from_array(y) + + +# COMPARISON BLOCKS ===================================================================== + +class GreaterThan(Logic): + """Greater-than comparison block. + + Compares two inputs and outputs 1.0 if a > b, else 0.0. + + .. math:: + + y = + \\begin{cases} + 1 & , a > b \\\\ + 0 & , a \\leq b + \\end{cases} + + Attributes + ---------- + op_alg : Operator + internal algebraic operator + """ + + input_port_labels = {"a":0, "b":1} + output_port_labels = {"y":0} + + def __init__(self): + super().__init__() + + self.op_alg = Operator( + func=lambda x: float(x[0] > x[1]), + jac=lambda x: np.zeros((1, 2)) + ) + + +class LessThan(Logic): + """Less-than comparison block. + + Compares two inputs and outputs 1.0 if a < b, else 0.0. + + .. math:: + + y = + \\begin{cases} + 1 & , a < b \\\\ + 0 & , a \\geq b + \\end{cases} + + Attributes + ---------- + op_alg : Operator + internal algebraic operator + """ + + input_port_labels = {"a":0, "b":1} + output_port_labels = {"y":0} + + def __init__(self): + super().__init__() + + self.op_alg = Operator( + func=lambda x: float(x[0] < x[1]), + jac=lambda x: np.zeros((1, 2)) + ) + + +class Equal(Logic): + """Equality comparison block. + + Compares two inputs and outputs 1.0 if |a - b| <= tolerance, else 0.0. + + ..
math:: + + y = + \\begin{cases} + 1 & , |a - b| \\leq \\epsilon \\\\ + 0 & , |a - b| > \\epsilon + \\end{cases} + + Parameters + ---------- + tolerance : float + comparison tolerance for floating point equality + + Attributes + ---------- + op_alg : Operator + internal algebraic operator + """ + + input_port_labels = {"a":0, "b":1} + output_port_labels = {"y":0} + + def __init__(self, tolerance=1e-12): + super().__init__() + + self.tolerance = tolerance + + self.op_alg = Operator( + func=lambda x: float(abs(x[0] - x[1]) <= self.tolerance), + jac=lambda x: np.zeros((1, 2)) + ) + + +# BOOLEAN LOGIC BLOCKS ================================================================== + +class LogicAnd(Logic): + """Logical AND block. + + Outputs 1.0 if both inputs are nonzero, else 0.0. + + .. math:: + + y = a \\land b + + Attributes + ---------- + op_alg : Operator + internal algebraic operator + """ + + input_port_labels = {"a":0, "b":1} + output_port_labels = {"y":0} + + def __init__(self): + super().__init__() + + self.op_alg = Operator( + func=lambda x: float(bool(x[0]) and bool(x[1])), + jac=lambda x: np.zeros((1, 2)) + ) + + +class LogicOr(Logic): + """Logical OR block. + + Outputs 1.0 if either input is nonzero, else 0.0. + + .. math:: + + y = a \\lor b + + Attributes + ---------- + op_alg : Operator + internal algebraic operator + """ + + input_port_labels = {"a":0, "b":1} + output_port_labels = {"y":0} + + def __init__(self): + super().__init__() + + self.op_alg = Operator( + func=lambda x: float(bool(x[0]) or bool(x[1])), + jac=lambda x: np.zeros((1, 2)) + ) + + +class LogicNot(Logic): + """Logical NOT block. + + Outputs 1.0 if input is zero, else 0.0. + + .. 
math:: + + y = \\lnot x + + Attributes + ---------- + op_alg : Operator + internal algebraic operator + """ + + def __init__(self): + super().__init__() + + self.op_alg = Operator( + func=lambda x: float(not bool(x[0])), + jac=lambda x: np.zeros((1, 1)) + ) diff --git a/src/pathsim/blocks/math.py b/src/pathsim/blocks/math.py index 6a08f927..a3d89ef2 100644 --- a/src/pathsim/blocks/math.py +++ b/src/pathsim/blocks/math.py @@ -574,4 +574,153 @@ def __init__(self, A=np.eye(1)): self.op_alg = Operator( func=lambda u: np.dot(self.A, u), jac=lambda u: self.A + ) + + +class Atan2(Block): + """Two-argument arctangent block. + + Computes the four-quadrant arctangent of two inputs: + + .. math:: + + y = \\mathrm{atan2}(a, b) + + Note + ---- + This block takes exactly two inputs (a, b) and produces one output. + The first input is the y-coordinate, the second is the x-coordinate, + matching the convention of ``numpy.arctan2(y, x)``. + + Attributes + ---------- + op_alg : Operator + internal algebraic operator + """ + + input_port_labels = {"a":0, "b":1} + output_port_labels = {"y":0} + + def __init__(self): + super().__init__() + + def _atan2_jac(x): + a, b = x[0], x[1] + denom = a**2 + b**2 + if denom == 0: + return np.zeros((1, 2)) + return np.array([[b / denom, -a / denom]]) + + self.op_alg = Operator( + func=lambda x: np.arctan2(x[0], x[1]), + jac=_atan2_jac + ) + + + def __len__(self): + """Purely algebraic block""" + return 1 + + + def update(self, t): + """update algebraic component of system equation + + Parameters + ---------- + t : float + evaluation time + """ + u = self.inputs.to_array() + y = self.op_alg(u) + self.outputs.update_from_array(y) + + +class MapLin(Math): + """Linear mapping / interpolation block. + + Maps the input linearly from range ``[i0, i1]`` to range ``[o0, o1]``. + Optionally saturates the output to ``[o0, o1]``. + + .. math:: + + y = o_0 + \\frac{(x - i_0) \\cdot (o_1 - o_0)}{i_1 - i_0} + + This block supports vector inputs. 
+ + Parameters + ---------- + i0 : float + input range lower bound + i1 : float + input range upper bound + o0 : float + output range lower bound + o1 : float + output range upper bound + saturate : bool + if True, clamp output to [min(o0,o1), max(o0,o1)] + + Attributes + ---------- + op_alg : Operator + internal algebraic operator + """ + + def __init__(self, i0=0.0, i1=1.0, o0=0.0, o1=1.0, saturate=False): + super().__init__() + + self.i0 = i0 + self.i1 = i1 + self.o0 = o0 + self.o1 = o1 + self.saturate = saturate + + #precompute gain + self._gain = (o1 - o0) / (i1 - i0) + + def _maplin(x): + y = self.o0 + (x - self.i0) * self._gain + if self.saturate: + lo, hi = min(self.o0, self.o1), max(self.o0, self.o1) + y = np.clip(y, lo, hi) + return y + + def _maplin_jac(x): + if self.saturate: + lo, hi = min(self.o0, self.o1), max(self.o0, self.o1) + y = self.o0 + (x - self.i0) * self._gain + mask = (y >= lo) & (y <= hi) + return np.diag(mask.astype(float) * self._gain) + return np.diag(np.full_like(x, self._gain)) + + self.op_alg = Operator( + func=_maplin, + jac=_maplin_jac + ) + + +class Alias(Math): + """Signal alias / pass-through block. + + Passes the input directly to the output without modification. + This is useful for signal renaming in model composition. + + .. math:: + + y = x + + This block supports vector inputs. 
+ + Attributes + ---------- + op_alg : Operator + internal algebraic operator + """ + + def __init__(self): + super().__init__() + + self.op_alg = Operator( + func=lambda x: x, + jac=lambda x: np.eye(len(x)) ) \ No newline at end of file diff --git a/tests/pathsim/blocks/test_delay.py b/tests/pathsim/blocks/test_delay.py index 5a98236a..f4a86cec 100644 --- a/tests/pathsim/blocks/test_delay.py +++ b/tests/pathsim/blocks/test_delay.py @@ -105,6 +105,97 @@ def test_update(self): self.assertEqual(D.outputs[0], max(0, t-10.5)) +class TestDelayDiscrete(unittest.TestCase): + """ + Test the discrete-time (sampling_period) mode of the 'Delay' block class + """ + + def test_init_discrete(self): + + D = Delay(tau=0.01, sampling_period=0.001) + + self.assertEqual(D._n, 10) + self.assertEqual(len(D._ring), 10) + self.assertTrue(hasattr(D, 'events')) + self.assertEqual(len(D.events), 1) + + + def test_n_computation(self): + + #exact multiple + D = Delay(tau=0.05, sampling_period=0.01) + self.assertEqual(D._n, 5) + + #rounding + D = Delay(tau=0.015, sampling_period=0.01) + self.assertEqual(D._n, 2) + + #minimum of 1 + D = Delay(tau=0.001, sampling_period=0.01) + self.assertEqual(D._n, 1) + + + def test_len(self): + + D = Delay(tau=0.01, sampling_period=0.001) + + #no passthrough + self.assertEqual(len(D), 0) + + + def test_reset(self): + + D = Delay(tau=0.01, sampling_period=0.001) + + #push some values + D._sample_next_timestep = True + D.inputs[0] = 42.0 + D.sample(0, 0.001) + + D.reset() + + #ring buffer should be all zeros + self.assertTrue(all(v == 0.0 for v in D._ring)) + self.assertEqual(len(D._ring), D._n) + + + def test_discrete_delay(self): + + n = 3 + D = Delay(tau=0.003, sampling_period=0.001) + + self.assertEqual(D._n, n) + + #push values through the ring buffer + outputs = [] + for k in range(10): + D.inputs[0] = float(k) + D._sample_next_timestep = True + D.sample(k * 0.001, 0.001) + D.update(k * 0.001) + outputs.append(D.outputs[0]) + + #first n outputs should be 
0 (initial buffer fill) + for k in range(n): + self.assertEqual(outputs[k], 0.0, f"output[{k}] should be 0.0") + + #after that, output should be delayed by n samples + for k in range(n, 10): + self.assertEqual(outputs[k], float(k - n), f"output[{k}] should be {k-n}") + + + def test_no_sample_without_flag(self): + + D = Delay(tau=0.003, sampling_period=0.001) + + #push a value without the flag set + D.inputs[0] = 42.0 + D.sample(0, 0.001) + + #ring buffer should be unchanged (all zeros) + self.assertTrue(all(v == 0.0 for v in D._ring)) + + # RUN TESTS LOCALLY ==================================================================== if __name__ == '__main__': diff --git a/tests/pathsim/blocks/test_logic.py b/tests/pathsim/blocks/test_logic.py new file mode 100644 index 00000000..e684205c --- /dev/null +++ b/tests/pathsim/blocks/test_logic.py @@ -0,0 +1,217 @@ +######################################################################################## +## +## TESTS FOR +## 'blocks.logic.py' +## +######################################################################################## + +# IMPORTS ============================================================================== + +import unittest +import numpy as np + +from pathsim.blocks.logic import ( + GreaterThan, + LessThan, + Equal, + LogicAnd, + LogicOr, + LogicNot, +) + +from tests.pathsim.blocks._embedding import Embedding + + +# TESTS ================================================================================ + +class TestGreaterThan(unittest.TestCase): + """ + Test the implementation of the 'GreaterThan' block class + """ + + def test_embedding(self): + """test algebraic components via embedding""" + + B = GreaterThan() + + #test a > b + def src(t): return t, 5.0 + def ref(t): return float(t > 5.0) + E = Embedding(B, src, ref) + + for t in range(10): self.assertTrue(np.allclose(*E.check_MIMO(t))) + + def test_equal_values(self): + """test that equal values return 0""" + + B = GreaterThan() + B.inputs[0] = 5.0 + 
B.inputs[1] = 5.0 + B.update(0) + self.assertEqual(B.outputs[0], 0.0) + + def test_less_than(self): + """test that a < b returns 0""" + + B = GreaterThan() + B.inputs[0] = 3.0 + B.inputs[1] = 5.0 + B.update(0) + self.assertEqual(B.outputs[0], 0.0) + + +class TestLessThan(unittest.TestCase): + """ + Test the implementation of the 'LessThan' block class + """ + + def test_embedding(self): + """test algebraic components via embedding""" + + B = LessThan() + + #test a < b + def src(t): return t, 5.0 + def ref(t): return float(t < 5.0) + E = Embedding(B, src, ref) + + for t in range(10): self.assertTrue(np.allclose(*E.check_MIMO(t))) + + def test_equal_values(self): + """test that equal values return 0""" + + B = LessThan() + B.inputs[0] = 5.0 + B.inputs[1] = 5.0 + B.update(0) + self.assertEqual(B.outputs[0], 0.0) + + +class TestEqual(unittest.TestCase): + """ + Test the implementation of the 'Equal' block class + """ + + def test_equal(self): + """test that equal values return 1""" + + B = Equal() + B.inputs[0] = 5.0 + B.inputs[1] = 5.0 + B.update(0) + self.assertEqual(B.outputs[0], 1.0) + + def test_not_equal(self): + """test that different values return 0""" + + B = Equal() + B.inputs[0] = 5.0 + B.inputs[1] = 6.0 + B.update(0) + self.assertEqual(B.outputs[0], 0.0) + + def test_tolerance(self): + """test tolerance parameter""" + + B = Equal(tolerance=0.1) + B.inputs[0] = 5.0 + B.inputs[1] = 5.05 + B.update(0) + self.assertEqual(B.outputs[0], 1.0) + + B.inputs[1] = 5.2 + B.update(0) + self.assertEqual(B.outputs[0], 0.0) + + +class TestLogicAnd(unittest.TestCase): + """ + Test the implementation of the 'LogicAnd' block class + """ + + def test_truth_table(self): + """test all combinations of boolean inputs""" + + B = LogicAnd() + + cases = [ + (0.0, 0.0, 0.0), + (0.0, 1.0, 0.0), + (1.0, 0.0, 0.0), + (1.0, 1.0, 1.0), + ] + + for a, b, expected in cases: + B.inputs[0] = a + B.inputs[1] = b + B.update(0) + self.assertEqual(B.outputs[0], expected, f"AND({a}, {b}) should be 
{expected}") + + def test_nonzero_is_true(self): + """test that nonzero values are treated as true""" + + B = LogicAnd() + B.inputs[0] = 5.0 + B.inputs[1] = -3.0 + B.update(0) + self.assertEqual(B.outputs[0], 1.0) + + +class TestLogicOr(unittest.TestCase): + """ + Test the implementation of the 'LogicOr' block class + """ + + def test_truth_table(self): + """test all combinations of boolean inputs""" + + B = LogicOr() + + cases = [ + (0.0, 0.0, 0.0), + (0.0, 1.0, 1.0), + (1.0, 0.0, 1.0), + (1.0, 1.0, 1.0), + ] + + for a, b, expected in cases: + B.inputs[0] = a + B.inputs[1] = b + B.update(0) + self.assertEqual(B.outputs[0], expected, f"OR({a}, {b}) should be {expected}") + + +class TestLogicNot(unittest.TestCase): + """ + Test the implementation of the 'LogicNot' block class + """ + + def test_true_to_false(self): + """test that nonzero input gives 0""" + + B = LogicNot() + B.inputs[0] = 1.0 + B.update(0) + self.assertEqual(B.outputs[0], 0.0) + + def test_false_to_true(self): + """test that zero input gives 1""" + + B = LogicNot() + B.inputs[0] = 0.0 + B.update(0) + self.assertEqual(B.outputs[0], 1.0) + + def test_nonzero_is_true(self): + """test that arbitrary nonzero values are treated as true""" + + B = LogicNot() + B.inputs[0] = 42.0 + B.update(0) + self.assertEqual(B.outputs[0], 0.0) + + +# RUN TESTS LOCALLY ==================================================================== + +if __name__ == '__main__': + unittest.main(verbosity=2) diff --git a/tests/pathsim/blocks/test_math.py b/tests/pathsim/blocks/test_math.py index a9c950b6..37cb8427 100644 --- a/tests/pathsim/blocks/test_math.py +++ b/tests/pathsim/blocks/test_math.py @@ -544,6 +544,132 @@ def ref(t): +class TestAtan2(unittest.TestCase): + """ + Test the implementation of the 'Atan2' block class + """ + + def test_embedding(self): + """test algebraic components via embedding""" + + B = Atan2() + + def src(t): return np.sin(t + 0.1), np.cos(t + 0.1) + def ref(t): return np.arctan2(np.sin(t + 0.1), 
np.cos(t + 0.1)) + E = Embedding(B, src, ref) + + for t in range(10): self.assertTrue(np.allclose(*E.check_MIMO(t))) + + def test_quadrants(self): + """test all four quadrants""" + + B = Atan2() + + cases = [ + ( 1.0, 1.0, np.arctan2(1.0, 1.0)), + ( 1.0, -1.0, np.arctan2(1.0, -1.0)), + (-1.0, -1.0, np.arctan2(-1.0, -1.0)), + (-1.0, 1.0, np.arctan2(-1.0, 1.0)), + ] + + for a, b, expected in cases: + B.inputs[0] = a + B.inputs[1] = b + B.update(0) + self.assertAlmostEqual(B.outputs[0], expected) + + +class TestMapLin(unittest.TestCase): + """ + Test the implementation of the 'MapLin' block class + """ + + def test_default_identity(self): + """test default mapping [0,1] -> [0,1] is identity""" + + B = MapLin() + + def src(t): return t * 0.1 + def ref(t): return t * 0.1 + E = Embedding(B, src, ref) + + for t in range(10): self.assertEqual(*E.check_SISO(t)) + + def test_custom_mapping(self): + """test custom linear mapping""" + + B = MapLin(i0=0.0, i1=10.0, o0=0.0, o1=100.0) + + def src(t): return float(t) + def ref(t): return float(t) * 10.0 + E = Embedding(B, src, ref) + + for t in range(10): self.assertAlmostEqual(*E.check_SISO(t)) + + def test_saturate(self): + """test saturation clamping""" + + B = MapLin(i0=0.0, i1=1.0, o0=0.0, o1=10.0, saturate=True) + + #input beyond range + B.inputs[0] = 2.0 + B.update(0) + self.assertEqual(B.outputs[0], 10.0) + + #input below range + B.inputs[0] = -1.0 + B.update(0) + self.assertEqual(B.outputs[0], 0.0) + + def test_no_saturate(self): + """test that without saturation, output can exceed range""" + + B = MapLin(i0=0.0, i1=1.0, o0=0.0, o1=10.0, saturate=False) + + B.inputs[0] = 2.0 + B.update(0) + self.assertEqual(B.outputs[0], 20.0) + + def test_vector_input(self): + """test with vector inputs""" + + B = MapLin(i0=0.0, i1=10.0, o0=-1.0, o1=1.0) + + def src(t): return float(t), float(t) * 2 + def ref(t): return -1.0 + float(t) * 0.2, -1.0 + float(t) * 2 * 0.2 + E = Embedding(B, src, ref) + + for t in range(5): 
self.assertTrue(np.allclose(*E.check_MIMO(t))) + + +class TestAlias(unittest.TestCase): + """ + Test the implementation of the 'Alias' block class + """ + + def test_passthrough_siso(self): + """test that input passes through unchanged""" + + B = Alias() + + def src(t): return float(t) + def ref(t): return float(t) + E = Embedding(B, src, ref) + + for t in range(10): self.assertEqual(*E.check_SISO(t)) + + def test_passthrough_mimo(self): + """test that vector input passes through unchanged""" + + B = Alias() + + def src(t): return float(t), float(t) * 2 + def ref(t): return float(t), float(t) * 2 + E = Embedding(B, src, ref) + + for t in range(10): self.assertTrue(np.allclose(*E.check_MIMO(t))) + + # RUN TESTS LOCALLY ==================================================================== if __name__ == '__main__': unittest.main(verbosity=2) \ No newline at end of file From 250e699a30b24ebeeed8961a3b2e390e4423c659 Mon Sep 17 00:00:00 2001 From: Milan Rother Date: Thu, 26 Feb 2026 10:40:31 +0100 Subject: [PATCH 11/29] Rename MapLin to Rescale for consistency with pathsim naming conventions --- src/pathsim/blocks/math.py | 6 ++++-- tests/pathsim/blocks/test_math.py | 14 +++++++------- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/src/pathsim/blocks/math.py b/src/pathsim/blocks/math.py index a3d89ef2..aa134a28 100644 --- a/src/pathsim/blocks/math.py +++ b/src/pathsim/blocks/math.py @@ -15,6 +15,7 @@ from ..utils.register import Register from ..optim.operator import Operator +from ..utils.mutable import mutable # BASE MATH BLOCK ======================================================================= @@ -635,8 +636,9 @@ def update(self, t): self.outputs.update_from_array(y) -class MapLin(Math): - """Linear mapping / interpolation block. +@mutable +class Rescale(Math): + """Linear rescaling / mapping block. Maps the input linearly from range ``[i0, i1]`` to range ``[o0, o1]``. Optionally saturates the output to ``[o0, o1]``. 
diff --git a/tests/pathsim/blocks/test_math.py b/tests/pathsim/blocks/test_math.py index 37cb8427..075e831f 100644 --- a/tests/pathsim/blocks/test_math.py +++ b/tests/pathsim/blocks/test_math.py @@ -579,15 +579,15 @@ def test_quadrants(self): self.assertAlmostEqual(B.outputs[0], expected) -class TestMapLin(unittest.TestCase): +class TestRescale(unittest.TestCase): """ - Test the implementation of the 'MapLin' block class + Test the implementation of the 'Rescale' block class """ def test_default_identity(self): """test default mapping [0,1] -> [0,1] is identity""" - B = MapLin() + B = Rescale() def src(t): return t * 0.1 def ref(t): return t * 0.1 @@ -598,7 +598,7 @@ def ref(t): return t * 0.1 def test_custom_mapping(self): """test custom linear mapping""" - B = MapLin(i0=0.0, i1=10.0, o0=0.0, o1=100.0) + B = Rescale(i0=0.0, i1=10.0, o0=0.0, o1=100.0) def src(t): return float(t) def ref(t): return float(t) * 10.0 @@ -609,7 +609,7 @@ def ref(t): return float(t) * 10.0 def test_saturate(self): """test saturation clamping""" - B = MapLin(i0=0.0, i1=1.0, o0=0.0, o1=10.0, saturate=True) + B = Rescale(i0=0.0, i1=1.0, o0=0.0, o1=10.0, saturate=True) #input beyond range B.inputs[0] = 2.0 @@ -624,7 +624,7 @@ def test_saturate(self): def test_no_saturate(self): """test that without saturation, output can exceed range""" - B = MapLin(i0=0.0, i1=1.0, o0=0.0, o1=10.0, saturate=False) + B = Rescale(i0=0.0, i1=1.0, o0=0.0, o1=10.0, saturate=False) B.inputs[0] = 2.0 B.update(0) @@ -633,7 +633,7 @@ def test_no_saturate(self): def test_vector_input(self): """test with vector inputs""" - B = MapLin(i0=0.0, i1=10.0, o0=-1.0, o1=1.0) + B = Rescale(i0=0.0, i1=10.0, o0=-1.0, o1=1.0) def src(t): return float(t), float(t) * 2 def ref(t): return -1.0 + float(t) * 0.2, -1.0 + float(t) * 2 * 0.2 From f02e3cb5a6de809d3a9dcc81ac3836e7641cc231 Mon Sep 17 00:00:00 2001 From: Milan Rother Date: Thu, 26 Feb 2026 11:09:33 +0100 Subject: [PATCH 12/29] Add eval tests for logic, Rescale, Atan2, Alias, 
and discrete Delay --- tests/evals/test_logic_system.py | 265 ++++++++++++++++++++++ tests/evals/test_rescale_delay_system.py | 277 +++++++++++++++++++++++ 2 files changed, 542 insertions(+) create mode 100644 tests/evals/test_logic_system.py create mode 100644 tests/evals/test_rescale_delay_system.py diff --git a/tests/evals/test_logic_system.py b/tests/evals/test_logic_system.py new file mode 100644 index 00000000..9d6f56cf --- /dev/null +++ b/tests/evals/test_logic_system.py @@ -0,0 +1,265 @@ +######################################################################################## +## +## Testing logic and comparison block systems +## +## Verifies comparison (GreaterThan, LessThan, Equal) and boolean logic +## (LogicAnd, LogicOr, LogicNot) blocks in full simulation context. +## +######################################################################################## + +# IMPORTS ============================================================================== + +import unittest +import numpy as np + +from pathsim import Simulation, Connection +from pathsim.blocks import ( + Source, + Constant, + Scope, + ) + +from pathsim.blocks.logic import ( + GreaterThan, + LessThan, + Equal, + LogicAnd, + LogicOr, + LogicNot, + ) + + +# TESTCASE ============================================================================= + +class TestComparisonSystem(unittest.TestCase): + """ + Test comparison blocks in a simulation that compares a sine wave + against a constant threshold. 
+ + System: Source(sin(t)) → GT/LT/EQ → Scope + Constant(0) ↗ + + Verify: GT outputs 1 when sin(t) > 0, LT outputs 1 when sin(t) < 0 + """ + + def setUp(self): + + Src = Source(lambda t: np.sin(2 * np.pi * t)) + Thr = Constant(0.0) + + self.GT = GreaterThan() + self.LT = LessThan() + + self.Sco = Scope(labels=["signal", "gt_zero", "lt_zero"]) + + blocks = [Src, Thr, self.GT, self.LT, self.Sco] + + connections = [ + Connection(Src, self.GT["a"], self.LT["a"], self.Sco[0]), + Connection(Thr, self.GT["b"], self.LT["b"]), + Connection(self.GT, self.Sco[1]), + Connection(self.LT, self.Sco[2]), + ] + + self.Sim = Simulation( + blocks, + connections, + dt=0.01, + log=False + ) + + + def test_gt_lt_complementary(self): + """GT and LT should be complementary (sum to 1) away from zero crossings""" + + self.Sim.run(duration=3.0, reset=True) + + time, [sig, gt, lt] = self.Sco.read() + + #away from zero crossings, GT + LT should be 1 (exactly one is true) + mask = np.abs(sig) > 0.1 + result = gt[mask] + lt[mask] + + self.assertTrue(np.allclose(result, 1.0), + "GT and LT should be complementary away from zero crossings") + + + def test_gt_matches_positive(self): + """GT output should be 1 when signal is clearly positive""" + + self.Sim.run(duration=3.0, reset=True) + + time, [sig, gt, lt] = self.Sco.read() + + mask_pos = sig > 0.2 + self.assertTrue(np.all(gt[mask_pos] == 1.0), + "GT should be 1 when signal is positive") + + mask_neg = sig < -0.2 + self.assertTrue(np.all(gt[mask_neg] == 0.0), + "GT should be 0 when signal is negative") + + +class TestLogicGateSystem(unittest.TestCase): + """ + Test logic gates combining two comparison outputs. + + System: Two sine waves at different frequencies compared against 0, + then combined with AND/OR/NOT. + + Verify: Logic truth tables hold across the simulation. 
+ """ + + def setUp(self): + + #two signals with different frequencies so they go in and out of phase + Src1 = Source(lambda t: np.sin(2 * np.pi * 1.0 * t)) + Src2 = Source(lambda t: np.sin(2 * np.pi * 1.5 * t)) + Zero = Constant(0.0) + + GT1 = GreaterThan() + GT2 = GreaterThan() + + self.AND = LogicAnd() + self.OR = LogicOr() + self.NOT = LogicNot() + + self.Sco = Scope(labels=["gt1", "gt2", "and", "or", "not1"]) + + blocks = [Src1, Src2, Zero, GT1, GT2, + self.AND, self.OR, self.NOT, self.Sco] + + connections = [ + Connection(Src1, GT1["a"]), + Connection(Src2, GT2["a"]), + Connection(Zero, GT1["b"], GT2["b"]), + Connection(GT1, self.AND["a"], self.OR["a"], self.NOT, self.Sco[0]), + Connection(GT2, self.AND["b"], self.OR["b"], self.Sco[1]), + Connection(self.AND, self.Sco[2]), + Connection(self.OR, self.Sco[3]), + Connection(self.NOT, self.Sco[4]), + ] + + self.Sim = Simulation( + blocks, + connections, + dt=0.01, + log=False + ) + + + def test_and_gate(self): + """AND should only be 1 when both inputs are 1""" + + self.Sim.run(duration=5.0, reset=True) + + time, [gt1, gt2, and_out, or_out, not_out] = self.Sco.read() + + #where both are 1, AND should be 1 + both_true = (gt1 == 1.0) & (gt2 == 1.0) + if np.any(both_true): + self.assertTrue(np.all(and_out[both_true] == 1.0)) + + #where either is 0, AND should be 0 + either_false = (gt1 == 0.0) | (gt2 == 0.0) + if np.any(either_false): + self.assertTrue(np.all(and_out[either_false] == 0.0)) + + + def test_or_gate(self): + """OR should be 1 when either input is 1""" + + self.Sim.run(duration=5.0, reset=True) + + time, [gt1, gt2, and_out, or_out, not_out] = self.Sco.read() + + #where both are 0, OR should be 0 + both_false = (gt1 == 0.0) & (gt2 == 0.0) + if np.any(both_false): + self.assertTrue(np.all(or_out[both_false] == 0.0)) + + #where either is 1, OR should be 1 + either_true = (gt1 == 1.0) | (gt2 == 1.0) + if np.any(either_true): + self.assertTrue(np.all(or_out[either_true] == 1.0)) + + + def test_not_gate(self): 
+ """NOT should invert its input""" + + self.Sim.run(duration=5.0, reset=True) + + time, [gt1, gt2, and_out, or_out, not_out] = self.Sco.read() + + #NOT should be inverse of GT1 + self.assertTrue(np.allclose(not_out + gt1, 1.0), + "NOT should invert its input") + + +class TestEqualSystem(unittest.TestCase): + """ + Test Equal block detecting when two signals are close. + + System: Source(sin(t)) → Equal ← Source(sin(t + small_offset)) + """ + + def test_equal_detects_match(self): + """Equal should output 1 when signals match within tolerance""" + + Src1 = Constant(3.14) + Src2 = Constant(3.14) + + Eq = Equal(tolerance=0.01) + Sco = Scope() + + Sim = Simulation( + blocks=[Src1, Src2, Eq, Sco], + connections=[ + Connection(Src1, Eq["a"]), + Connection(Src2, Eq["b"]), + Connection(Eq, Sco), + ], + dt=0.1, + log=False + ) + + Sim.run(duration=1.0, reset=True) + + time, [eq_out] = Sco.read() + + self.assertTrue(np.all(eq_out == 1.0), + "Equal should output 1 for identical signals") + + + def test_equal_detects_mismatch(self): + """Equal should output 0 when signals differ""" + + Src1 = Constant(1.0) + Src2 = Constant(2.0) + + Eq = Equal(tolerance=0.01) + Sco = Scope() + + Sim = Simulation( + blocks=[Src1, Src2, Eq, Sco], + connections=[ + Connection(Src1, Eq["a"]), + Connection(Src2, Eq["b"]), + Connection(Eq, Sco), + ], + dt=0.1, + log=False + ) + + Sim.run(duration=1.0, reset=True) + + time, [eq_out] = Sco.read() + + self.assertTrue(np.all(eq_out == 0.0), + "Equal should output 0 for different signals") + + +# RUN TESTS LOCALLY ==================================================================== + +if __name__ == '__main__': + unittest.main(verbosity=2) diff --git a/tests/evals/test_rescale_delay_system.py b/tests/evals/test_rescale_delay_system.py new file mode 100644 index 00000000..60c97be4 --- /dev/null +++ b/tests/evals/test_rescale_delay_system.py @@ -0,0 +1,277 @@ +######################################################################################## +## +## 
Testing Rescale, Atan2, Alias, and discrete Delay systems +## +## Verifies new math blocks and discrete delay mode in full simulation context. +## +######################################################################################## + +# IMPORTS ============================================================================== + +import unittest +import numpy as np + +from pathsim import Simulation, Connection +from pathsim.blocks import ( + Source, + SinusoidalSource, + Constant, + Delay, + Scope, + ) + +from pathsim.blocks.math import Atan2, Rescale, Alias + + +# TESTCASE ============================================================================= + +class TestRescaleSystem(unittest.TestCase): + """ + Test Rescale block mapping a sine wave from [-1, 1] to [0, 10]. + + System: Source(sin(t)) → Rescale → Scope + Verify: output is linearly mapped to target range + """ + + def setUp(self): + + Src = SinusoidalSource(amplitude=1.0, frequency=1.0) + + self.Rsc = Rescale(i0=-1.0, i1=1.0, o0=0.0, o1=10.0) + self.Sco = Scope(labels=["input", "rescaled"]) + + blocks = [Src, self.Rsc, self.Sco] + + connections = [ + Connection(Src, self.Rsc, self.Sco[0]), + Connection(self.Rsc, self.Sco[1]), + ] + + self.Sim = Simulation( + blocks, + connections, + dt=0.01, + log=False + ) + + + def test_rescale_range(self): + """Output should be in [0, 10] for input in [-1, 1]""" + + self.Sim.run(duration=3.0, reset=True) + + time, [inp, rsc] = self.Sco.read() + + #check output stays within target range (with small tolerance) + self.assertTrue(np.all(rsc >= -0.1), "Rescaled output below lower bound") + self.assertTrue(np.all(rsc <= 10.1), "Rescaled output above upper bound") + + + def test_rescale_linearity(self): + """Output should be linear mapping of input""" + + self.Sim.run(duration=3.0, reset=True) + + time, [inp, rsc] = self.Sco.read() + + #expected: 5 + 5 * sin(t) + expected = 5.0 + 5.0 * inp + error = np.max(np.abs(rsc - expected)) + + self.assertLess(error, 0.01, f"Rescale 
linearity error: {error:.4f}") + + +class TestRescaleSaturationSystem(unittest.TestCase): + """ + Test Rescale with saturation enabled. + + System: Source(ramp) → Rescale(saturate=True) → Scope + Verify: output is clamped to target range + """ + + def test_saturation_clamps_output(self): + + #ramp from -2 to 2 over 4 seconds, mapped [0,1] -> [0,10] + Src = Source(lambda t: t - 2.0) + Rsc = Rescale(i0=0.0, i1=1.0, o0=0.0, o1=10.0, saturate=True) + Sco = Scope(labels=["input", "rescaled"]) + + Sim = Simulation( + blocks=[Src, Rsc, Sco], + connections=[ + Connection(Src, Rsc, Sco[0]), + Connection(Rsc, Sco[1]), + ], + dt=0.01, + log=False + ) + + Sim.run(duration=4.0, reset=True) + + time, [inp, rsc] = Sco.read() + + #output should never exceed [0, 10] + self.assertTrue(np.all(rsc >= -0.01), "Saturated output below 0") + self.assertTrue(np.all(rsc <= 10.01), "Saturated output above 10") + + #input in valid range [0, 1] should map normally + mask_valid = (inp >= 0.0) & (inp <= 1.0) + if np.any(mask_valid): + expected = 10.0 * inp[mask_valid] + error = np.max(np.abs(rsc[mask_valid] - expected)) + self.assertLess(error, 0.1) + + +class TestAtan2System(unittest.TestCase): + """ + Test Atan2 block computing the angle of a rotating vector. 
+ + System: Source(sin(t)) → Atan2 ← Source(cos(t)) + Verify: output recovers the angle t (mod 2pi) + """ + + def setUp(self): + + self.SrcY = Source(lambda t: np.sin(t)) + self.SrcX = Source(lambda t: np.cos(t)) + + self.At2 = Atan2() + self.Sco = Scope(labels=["angle"]) + + blocks = [self.SrcY, self.SrcX, self.At2, self.Sco] + + connections = [ + Connection(self.SrcY, self.At2["a"]), + Connection(self.SrcX, self.At2["b"]), + Connection(self.At2, self.Sco), + ] + + self.Sim = Simulation( + blocks, + connections, + dt=0.01, + log=False + ) + + + def test_atan2_recovers_angle(self): + """atan2(sin(t), cos(t)) should equal t for t in [0, pi)""" + + self.Sim.run(duration=3.0, reset=True) + + time, [angle] = self.Sco.read() + + #check in first half period where atan2 is monotonic + mask = time < np.pi - 0.1 + expected = time[mask] + actual = angle[mask] + + error = np.max(np.abs(actual - expected)) + self.assertLess(error, 0.02, + f"Atan2 angle recovery error: {error:.4f}") + + +class TestAliasSystem(unittest.TestCase): + """ + Test Alias block as a transparent pass-through. + + System: Source(sin(t)) → Alias → Scope + Verify: output is identical to input + """ + + def test_alias_transparent(self): + + Src = SinusoidalSource(amplitude=1.0, frequency=2.0) + Als = Alias() + Sco = Scope(labels=["input", "alias"]) + + Sim = Simulation( + blocks=[Src, Als, Sco], + connections=[ + Connection(Src, Als, Sco[0]), + Connection(Als, Sco[1]), + ], + dt=0.01, + log=False + ) + + Sim.run(duration=2.0, reset=True) + + time, [inp, als] = Sco.read() + + self.assertTrue(np.allclose(inp, als), + "Alias output should be identical to input") + + +class TestDiscreteDelaySystem(unittest.TestCase): + """ + Test discrete-time delay using sampling_period parameter. 
+ + System: Source(ramp) → Delay(tau, sampling_period) → Scope + Verify: output is a staircase-delayed version of input + """ + + def setUp(self): + + self.tau = 0.1 + self.T = 0.01 + + Src = Source(lambda t: t) + self.Dly = Delay(tau=self.tau, sampling_period=self.T) + self.Sco = Scope(labels=["input", "delayed"]) + + blocks = [Src, self.Dly, self.Sco] + + connections = [ + Connection(Src, self.Dly, self.Sco[0]), + Connection(self.Dly, self.Sco[1]), + ] + + self.Sim = Simulation( + blocks, + connections, + dt=0.001, + log=False + ) + + + def test_discrete_delay_offset(self): + """Delayed signal should trail input by approximately tau""" + + self.Sim.run(duration=1.0, reset=True) + + time, [inp, delayed] = self.Sco.read() + + #after initial fill (t > tau + settling), check delay offset + mask = time > self.tau + 0.2 + t_check = time[mask] + delayed_check = delayed[mask] + + #the delayed ramp should be approximately (t - tau) + #with staircase quantization from sampling + expected = t_check - self.tau + error = np.mean(np.abs(delayed_check - expected)) + + self.assertLess(error, self.T + 0.01, + f"Discrete delay mean error: {error:.4f}") + + + def test_discrete_delay_zero_initial(self): + """Output should be zero during initial fill period""" + + self.Sim.run(duration=0.5, reset=True) + + time, [inp, delayed] = self.Sco.read() + + #during first tau seconds, output should be 0 + mask = time < self.tau * 0.5 + early_output = delayed[mask] + + self.assertTrue(np.all(early_output == 0.0), + "Discrete delay output should be zero before buffer fills") + + +# RUN TESTS LOCALLY ==================================================================== + +if __name__ == '__main__': + unittest.main(verbosity=2) From bbd0f7755de160c310689059e85d656a9fc86ca6 Mon Sep 17 00:00:00 2001 From: Milan Rother Date: Wed, 4 Mar 2026 09:18:20 +0100 Subject: [PATCH 13/29] Deprecate RFNetwork in favour of pathsim-rf package --- src/pathsim/blocks/rf.py | 16 +++++++++------- 1 file changed, 9 
insertions(+), 7 deletions(-) diff --git a/src/pathsim/blocks/rf.py b/src/pathsim/blocks/rf.py index 51ac9e00..4a82c5db 100644 --- a/src/pathsim/blocks/rf.py +++ b/src/pathsim/blocks/rf.py @@ -8,13 +8,8 @@ ## ######################################################################################### -# TODO LIST -# class RFAmplifier Model amplifier in RF systems -# class Resistor/Capacitor/Inductor -# class RFMixer for mixer in RF systems? - - # IMPORTS =============================================================================== + from __future__ import annotations import numpy as np @@ -38,10 +33,17 @@ from .lti import StateSpace +from ..utils.deprecation import deprecated + # BLOCK DEFINITIONS ===================================================================== +@deprecated( + version="1.0.0", + replacement="pathsim_rf.RFNetwork", + reason="This block has moved to the pathsim-rf package: pip install pathsim-rf", +) class RFNetwork(StateSpace): """N-port RF network linear time invariant (LTI) multi input multi output (MIMO) state-space model. 
@@ -78,7 +80,7 @@ def __init__(self, ntwk: NetworkType | str | Path, auto_fit: bool = True, **kwar _msg = "The scikit-rf package is required to use this block -> 'pip install scikit-rf'" raise ImportError(_msg) - if isinstance(ntwk, Path) or isinstance(ntwk, str): + if isinstance(ntwk, (Path, str)): ntwk = rf.Network(ntwk) # Select the vector fitting function from scikit-rf From 0d52b5421d1edc964fec2912b09894832a48808a Mon Sep 17 00:00:00 2001 From: Milan Rother Date: Sun, 15 Mar 2026 20:59:57 +0100 Subject: [PATCH 14/29] Add checkpoint save/load system with JSON+NPZ format --- src/pathsim/blocks/_block.py | 91 ++++++++++ src/pathsim/blocks/delay.py | 35 ++++ src/pathsim/blocks/scope.py | 42 ++++- src/pathsim/blocks/spectrum.py | 18 ++ src/pathsim/blocks/switch.py | 9 + src/pathsim/events/_event.py | 63 ++++++- src/pathsim/simulation.py | 134 +++++++++++++++ src/pathsim/solvers/_solver.py | 64 +++++++ src/pathsim/solvers/gear.py | 44 +++++ src/pathsim/utils/adaptivebuffer.py | 45 ++++- tests/pathsim/test_checkpoint.py | 256 ++++++++++++++++++++++++++++ 11 files changed, 796 insertions(+), 5 deletions(-) create mode 100644 tests/pathsim/test_checkpoint.py diff --git a/src/pathsim/blocks/_block.py b/src/pathsim/blocks/_block.py index 4b4275fb..597195a4 100644 --- a/src/pathsim/blocks/_block.py +++ b/src/pathsim/blocks/_block.py @@ -11,6 +11,7 @@ # IMPORTS =============================================================================== import inspect +from uuid import uuid4 from functools import lru_cache from ..utils.deprecation import deprecated @@ -84,6 +85,9 @@ class definition for other blocks to be inherited. 
def __init__(self): + #unique identifier for checkpointing and diagnostics + self.id = uuid4().hex + #registers to hold input and output values self.inputs = Register( mapping=self.input_port_labels and self.input_port_labels.copy() @@ -524,6 +528,93 @@ def state(self, val): self.engine.state = val + # checkpoint methods ---------------------------------------------------------------- + + def to_checkpoint(self, recordings=False): + """Serialize block state for checkpointing. + + Parameters + ---------- + recordings : bool + include recording data (for Scope blocks) + + Returns + ------- + json_data : dict + JSON-serializable metadata + npz_data : dict + numpy arrays keyed by path + """ + prefix = self.id + + json_data = { + "id": self.id, + "type": self.__class__.__name__, + "active": self._active, + } + + npz_data = { + f"{prefix}/inputs": self.inputs.to_array(), + f"{prefix}/outputs": self.outputs.to_array(), + } + + #solver state + if self.engine: + e_json, e_npz = self.engine.to_checkpoint(f"{prefix}/engine") + json_data["engine"] = e_json + npz_data.update(e_npz) + + #internal events + if self.events: + evt_jsons = [] + for event in self.events: + e_json, e_npz = event.to_checkpoint() + evt_jsons.append(e_json) + npz_data.update(e_npz) + json_data["events"] = evt_jsons + + return json_data, npz_data + + + def load_checkpoint(self, json_data, npz): + """Restore block state from checkpoint. 
+ + Parameters + ---------- + json_data : dict + block metadata from checkpoint JSON + npz : dict-like + numpy arrays from checkpoint NPZ + """ + prefix = json_data["id"] + + #verify type + if json_data["type"] != self.__class__.__name__: + raise ValueError( + f"Checkpoint type mismatch: expected '{self.__class__.__name__}', " + f"got '{json_data['type']}'" + ) + + self._active = json_data["active"] + + #restore registers + inp_key = f"{prefix}/inputs" + out_key = f"{prefix}/outputs" + if inp_key in npz: + self.inputs.update_from_array(npz[inp_key]) + if out_key in npz: + self.outputs.update_from_array(npz[out_key]) + + #restore solver state + if self.engine and "engine" in json_data: + self.engine.load_checkpoint(json_data["engine"], npz, f"{prefix}/engine") + + #restore internal events + if self.events and "events" in json_data: + for event, evt_data in zip(self.events, json_data["events"]): + event.load_checkpoint(evt_data, npz) + + # methods for block output and state updates ---------------------------------------- def update(self, t): diff --git a/src/pathsim/blocks/delay.py b/src/pathsim/blocks/delay.py index 6bcbed8b..6b42614c 100644 --- a/src/pathsim/blocks/delay.py +++ b/src/pathsim/blocks/delay.py @@ -142,6 +142,41 @@ def reset(self): self._ring.extend([0.0] * self._n) + def to_checkpoint(self, recordings=False): + """Serialize Delay state including buffer data.""" + json_data, npz_data = super().to_checkpoint(recordings=recordings) + prefix = self.id + + json_data["sampling_period"] = self.sampling_period + + if self.sampling_period is None: + #continuous mode: adaptive buffer + npz_data.update(self._buffer.to_checkpoint(f"{prefix}/buffer")) + else: + #discrete mode: ring buffer + npz_data[f"{prefix}/ring"] = np.array(list(self._ring)) + json_data["_sample_next_timestep"] = self._sample_next_timestep + + return json_data, npz_data + + + def load_checkpoint(self, json_data, npz): + """Restore Delay state including buffer data.""" + 
super().load_checkpoint(json_data, npz) + prefix = json_data["id"] + + if self.sampling_period is None: + #continuous mode + self._buffer.load_checkpoint(npz, f"{prefix}/buffer") + else: + #discrete mode + ring_key = f"{prefix}/ring" + if ring_key in npz: + self._ring.clear() + self._ring.extend(npz[ring_key].tolist()) + self._sample_next_timestep = json_data.get("_sample_next_timestep", False) + + def update(self, t): """Evaluation of the buffer at different times via interpolation (continuous) or ring buffer lookup (discrete). diff --git a/src/pathsim/blocks/scope.py b/src/pathsim/blocks/scope.py index 4997f772..57854526 100644 --- a/src/pathsim/blocks/scope.py +++ b/src/pathsim/blocks/scope.py @@ -448,13 +448,49 @@ def save(self, path="scope.csv"): wrt.writerow(sample) + def to_checkpoint(self, recordings=False): + """Serialize Scope state including optional recording data.""" + json_data, npz_data = super().to_checkpoint(recordings=recordings) + prefix = self.id + + json_data["_incremental_idx"] = self._incremental_idx + if hasattr(self, '_sample_next_timestep'): + json_data["_sample_next_timestep"] = self._sample_next_timestep + + if recordings and self.recording_time: + npz_data[f"{prefix}/recording_time"] = np.array(self.recording_time) + npz_data[f"{prefix}/recording_data"] = np.array(self.recording_data) + + return json_data, npz_data + + + def load_checkpoint(self, json_data, npz): + """Restore Scope state including optional recording data.""" + super().load_checkpoint(json_data, npz) + prefix = json_data["id"] + + self._incremental_idx = json_data.get("_incremental_idx", 0) + if hasattr(self, '_sample_next_timestep'): + self._sample_next_timestep = json_data.get("_sample_next_timestep", False) + + #restore recordings if present + rt_key = f"{prefix}/recording_time" + rd_key = f"{prefix}/recording_data" + if rt_key in npz and rd_key in npz: + self.recording_time = npz[rt_key].tolist() + self.recording_data = [row for row in npz[rd_key]] + else: + 
self.recording_time = [] + self.recording_data = [] + + def update(self, t): - """update system equation for fixed point loop, + """update system equation for fixed point loop, here just setting the outputs - + Note ---- - Scope has no passthrough, so the 'update' method + Scope has no passthrough, so the 'update' method is optimized for this case (does nothing) Parameters diff --git a/src/pathsim/blocks/spectrum.py b/src/pathsim/blocks/spectrum.py index 7b3a0878..b4d37fed 100644 --- a/src/pathsim/blocks/spectrum.py +++ b/src/pathsim/blocks/spectrum.py @@ -283,6 +283,24 @@ def step(self, t, dt): return True, 0.0, None + def to_checkpoint(self, recordings=False): + """Serialize Spectrum state including integration time.""" + json_data, npz_data = super().to_checkpoint(recordings=recordings) + + json_data["time"] = self.time + json_data["t_sample"] = self.t_sample + + return json_data, npz_data + + + def load_checkpoint(self, json_data, npz): + """Restore Spectrum state including integration time.""" + super().load_checkpoint(json_data, npz) + + self.time = json_data.get("time", 0.0) + self.t_sample = json_data.get("t_sample", 0.0) + + def sample(self, t, dt): """sample time of successfull timestep for waiting period diff --git a/src/pathsim/blocks/switch.py b/src/pathsim/blocks/switch.py index dc60d887..3ce04b07 100644 --- a/src/pathsim/blocks/switch.py +++ b/src/pathsim/blocks/switch.py @@ -82,6 +82,15 @@ def select(self, switch_state=0): self.switch_state = switch_state + def to_checkpoint(self, recordings=False): + json_data, npz_data = super().to_checkpoint(recordings=recordings) + json_data["switch_state"] = self.switch_state + return json_data, npz_data + + def load_checkpoint(self, json_data, npz): + super().load_checkpoint(json_data, npz) + self.switch_state = json_data.get("switch_state", None) + def update(self, t): """Update switch output depending on inputs and switch state. 
diff --git a/src/pathsim/events/_event.py b/src/pathsim/events/_event.py index fd911a5b..124c99d1 100644 --- a/src/pathsim/events/_event.py +++ b/src/pathsim/events/_event.py @@ -11,6 +11,8 @@ import numpy as np +from uuid import uuid4 + from .. _constants import EVT_TOLERANCE @@ -64,6 +66,9 @@ def __init__( tolerance=EVT_TOLERANCE ): + #unique identifier for checkpointing and diagnostics + self.id = uuid4().hex + #event detection function self.func_evt = func_evt @@ -201,4 +206,60 @@ def resolve(self, t): #action function for event resolution if self.func_act is not None: - self.func_act(t) \ No newline at end of file + self.func_act(t) + + + # checkpoint methods ---------------------------------------------------------------- + + def to_checkpoint(self): + """Serialize event state for checkpointing. + + Returns + ------- + json_data : dict + JSON-serializable metadata + npz_data : dict + numpy arrays keyed by path + """ + prefix = self.id + + #extract history eval value + hist_eval, hist_time = self._history + if hist_eval is not None and hasattr(hist_eval, 'item'): + hist_eval = float(hist_eval) + + json_data = { + "id": self.id, + "type": self.__class__.__name__, + "active": self._active, + "history_eval": hist_eval, + "history_time": hist_time, + } + + npz_data = {} + if self._times: + npz_data[f"{prefix}/times"] = np.array(self._times) + + return json_data, npz_data + + + def load_checkpoint(self, json_data, npz): + """Restore event state from checkpoint. 
+ + Parameters + ---------- + json_data : dict + event metadata from checkpoint JSON + npz : dict-like + numpy arrays from checkpoint NPZ + """ + prefix = json_data["id"] + + self._active = json_data["active"] + self._history = json_data["history_eval"], json_data["history_time"] + + times_key = f"{prefix}/times" + if times_key in npz: + self._times = npz[times_key].tolist() + else: + self._times = [] \ No newline at end of file diff --git a/src/pathsim/simulation.py b/src/pathsim/simulation.py index 306c7e4c..ed6edbfb 100644 --- a/src/pathsim/simulation.py +++ b/src/pathsim/simulation.py @@ -331,6 +331,140 @@ def plot(self, *args, **kwargs): if block: block.plot(*args, **kwargs) + # checkpoint methods ---------------------------------------------------------- + + def save_checkpoint(self, path, recordings=False): + """Save simulation state to checkpoint files (JSON + NPZ). + + Creates two files: {path}.json (structure/metadata) and + {path}.npz (numerical data). + + Parameters + ---------- + path : str + base path without extension + recordings : bool + include scope/spectrum recording data (default: False) + """ + import json + + #strip extension if provided + if path.endswith('.json') or path.endswith('.npz'): + path = path.rsplit('.', 1)[0] + + #simulation metadata + checkpoint = { + "version": "1.0.0", + "pathsim_version": __version__, + "created": datetime.datetime.now(datetime.timezone.utc).isoformat(), + "simulation": { + "time": self.time, + "dt": self.dt, + "dt_min": self.dt_min, + "dt_max": self.dt_max, + "solver": self.Solver.__name__, + "tolerance_fpi": self.tolerance_fpi, + "iterations_max": self.iterations_max, + }, + "blocks": {}, + "events": {}, + } + + npz_data = {} + + #checkpoint all blocks (keyed by UUID) + for block in self.blocks: + b_json, b_npz = block.to_checkpoint(recordings=recordings) + checkpoint["blocks"][block.id] = b_json + npz_data.update(b_npz) + + #checkpoint external events (keyed by UUID) + for event in self.events: + e_json, 
e_npz = event.to_checkpoint() + checkpoint["events"][event.id] = e_json + npz_data.update(e_npz) + + #write files + with open(f"{path}.json", "w", encoding="utf-8") as f: + json.dump(checkpoint, f, indent=2, ensure_ascii=False) + + np.savez(f"{path}.npz", **npz_data) + + + def load_checkpoint(self, path): + """Load simulation state from checkpoint files (JSON + NPZ). + + Restores simulation time and all block/event states from a + previously saved checkpoint. The simulation must have the same + blocks and events as when the checkpoint was saved. + + Parameters + ---------- + path : str + base path without extension + """ + import json + import warnings + + #strip extension if provided + if path.endswith('.json') or path.endswith('.npz'): + path = path.rsplit('.', 1)[0] + + #read files + with open(f"{path}.json", "r", encoding="utf-8") as f: + checkpoint = json.load(f) + + npz = np.load(f"{path}.npz", allow_pickle=False) + + try: + #version check + cp_version = checkpoint.get("pathsim_version", "unknown") + if cp_version != __version__: + warnings.warn( + f"Checkpoint was saved with PathSim {cp_version}, " + f"current version is {__version__}" + ) + + #restore simulation state + sim_data = checkpoint["simulation"] + self.time = sim_data["time"] + self.dt = sim_data["dt"] + self.dt_min = sim_data["dt_min"] + self.dt_max = sim_data["dt_max"] + + #solver type check + if sim_data["solver"] != self.Solver.__name__: + warnings.warn( + f"Checkpoint solver '{sim_data['solver']}' differs from " + f"current solver '{self.Solver.__name__}'" + ) + + #restore blocks + block_data = checkpoint.get("blocks", {}) + for block in self.blocks: + if block.id in block_data: + block.load_checkpoint(block_data[block.id], npz) + else: + warnings.warn( + f"Block {block.__class__.__name__} (id={block.id[:8]}...) 
" + f"not found in checkpoint" + ) + + #restore external events + event_data = checkpoint.get("events", {}) + for event in self.events: + if event.id in event_data: + event.load_checkpoint(event_data[event.id], npz) + else: + warnings.warn( + f"Event {event.__class__.__name__} (id={event.id[:8]}...) " + f"not found in checkpoint" + ) + + finally: + npz.close() + + # adding system components ---------------------------------------------------- def add_block(self, block): diff --git a/src/pathsim/solvers/_solver.py b/src/pathsim/solvers/_solver.py index 9cf00de9..d235856e 100644 --- a/src/pathsim/solvers/_solver.py +++ b/src/pathsim/solvers/_solver.py @@ -353,6 +353,70 @@ def create(cls, initial_value, parent=None, from_engine=None, **solver_kwargs): return cls(initial_value, parent, **solver_kwargs) + # checkpoint methods --------------------------------------------------------------- + + def to_checkpoint(self, prefix): + """Serialize solver state for checkpointing. + + Parameters + ---------- + prefix : str + NPZ key prefix for this solver's arrays + + Returns + ------- + json_data : dict + JSON-serializable metadata + npz_data : dict + numpy arrays keyed by path + """ + json_data = { + "type": self.__class__.__name__, + "is_adaptive": self.is_adaptive, + "n": self.n, + "history_len": len(self.history), + "history_maxlen": self.history.maxlen, + } + + npz_data = { + f"{prefix}/x": np.atleast_1d(self.x), + f"{prefix}/initial_value": np.atleast_1d(self.initial_value), + } + + for i, h in enumerate(self.history): + npz_data[f"{prefix}/history_{i}"] = np.atleast_1d(h) + + return json_data, npz_data + + + def load_checkpoint(self, json_data, npz, prefix): + """Restore solver state from checkpoint. 
+ + Parameters + ---------- + json_data : dict + solver metadata from checkpoint JSON + npz : dict-like + numpy arrays from checkpoint NPZ + prefix : str + NPZ key prefix for this solver's arrays + """ + self.x = npz[f"{prefix}/x"].copy() + self.initial_value = npz[f"{prefix}/initial_value"].copy() + + #restore scalar format if needed + if self._scalar_initial and self.initial_value.size == 1: + self.initial_value = self.initial_value.item() + + #restore history + maxlen = json_data.get("history_maxlen", self.history.maxlen) + self.history = deque([], maxlen=maxlen) + for i in range(json_data.get("history_len", 0)): + key = f"{prefix}/history_{i}" + if key in npz: + self.history.append(npz[key].copy()) + + # methods for adaptive timestep solvers -------------------------------------------- def error_controller(self): diff --git a/src/pathsim/solvers/gear.py b/src/pathsim/solvers/gear.py index e8cf269f..6f745371 100644 --- a/src/pathsim/solvers/gear.py +++ b/src/pathsim/solvers/gear.py @@ -210,6 +210,50 @@ def create(cls, initial_value, parent=None, from_engine=None, **solver_kwargs): return cls(initial_value, parent, **solver_kwargs) + def to_checkpoint(self, prefix): + """Serialize GEAR solver state including startup solver and timestep history.""" + json_data, npz_data = super().to_checkpoint(prefix) + + json_data["_needs_startup"] = self._needs_startup + json_data["history_dt_len"] = len(self.history_dt) + + #timestep history + for i, dt in enumerate(self.history_dt): + npz_data[f"{prefix}/history_dt_{i}"] = np.atleast_1d(dt) + + #startup solver state + if self.startup: + s_json, s_npz = self.startup.to_checkpoint(f"{prefix}/startup") + json_data["startup"] = s_json + npz_data.update(s_npz) + + return json_data, npz_data + + + def load_checkpoint(self, json_data, npz, prefix): + """Restore GEAR solver state including startup solver and timestep history.""" + super().load_checkpoint(json_data, npz, prefix) + + self._needs_startup = json_data.get("_needs_startup", 
True) + + #restore timestep history + self.history_dt.clear() + for i in range(json_data.get("history_dt_len", 0)): + key = f"{prefix}/history_dt_{i}" + if key in npz: + self.history_dt.append(npz[key].item()) + + #restore startup solver + if self.startup and "startup" in json_data: + self.startup.load_checkpoint(json_data["startup"], npz, f"{prefix}/startup") + + #recompute BDF coefficients from restored history + if not self._needs_startup and len(self.history_dt) > 0: + self.F, self.K = {}, {} + for n, _ in enumerate(self.history_dt, 1): + self.F[n], self.K[n] = compute_bdf_coefficients(n, np.array(self.history_dt)) + + def stages(self, t, dt): """Generator that yields the intermediate evaluation time during the timestep 't + ratio * dt'. diff --git a/src/pathsim/utils/adaptivebuffer.py b/src/pathsim/utils/adaptivebuffer.py index b24e2e5a..05dd82fa 100644 --- a/src/pathsim/utils/adaptivebuffer.py +++ b/src/pathsim/utils/adaptivebuffer.py @@ -10,6 +10,8 @@ # IMPORTS ============================================================================== +import numpy as np + from collections import deque from bisect import bisect_left @@ -120,4 +122,45 @@ def get(self, t): def clear(self): """clear the buffer, reset everything""" self.buffer_t.clear() - self.buffer_v.clear() \ No newline at end of file + self.buffer_v.clear() + + + def to_checkpoint(self, prefix): + """Serialize buffer state for checkpointing. + + Parameters + ---------- + prefix : str + NPZ key prefix + + Returns + ------- + npz_data : dict + numpy arrays keyed by path + """ + npz_data = {} + if self.buffer_t: + npz_data[f"{prefix}/buffer_t"] = np.array(list(self.buffer_t)) + npz_data[f"{prefix}/buffer_v"] = np.array(list(self.buffer_v)) + return npz_data + + + def load_checkpoint(self, npz, prefix): + """Restore buffer state from checkpoint. 
+ + Parameters + ---------- + npz : dict-like + numpy arrays from checkpoint NPZ + prefix : str + NPZ key prefix + """ + self.clear() + t_key = f"{prefix}/buffer_t" + v_key = f"{prefix}/buffer_v" + if t_key in npz and v_key in npz: + times = npz[t_key] + values = npz[v_key] + for t, v in zip(times, values): + self.buffer_t.append(float(t)) + self.buffer_v.append(v) \ No newline at end of file diff --git a/tests/pathsim/test_checkpoint.py b/tests/pathsim/test_checkpoint.py new file mode 100644 index 00000000..b0bcc470 --- /dev/null +++ b/tests/pathsim/test_checkpoint.py @@ -0,0 +1,256 @@ +"""Tests for checkpoint save/load functionality.""" + +import os +import json +import tempfile + +import numpy as np +import pytest + +from pathsim import Simulation, Connection +from pathsim.blocks import ( + Source, Integrator, Amplifier, Scope, Constant +) +from pathsim.blocks.delay import Delay +from pathsim.blocks.switch import Switch + + +class TestBlockCheckpoint: + """Test block-level checkpoint methods.""" + + def test_basic_block_to_checkpoint(self): + """Block produces valid checkpoint data.""" + b = Integrator(1.0) + b.inputs[0] = 3.14 + json_data, npz_data = b.to_checkpoint() + + assert json_data["type"] == "Integrator" + assert json_data["id"] == b.id + assert json_data["active"] is True + assert f"{b.id}/inputs" in npz_data + assert f"{b.id}/outputs" in npz_data + + def test_block_has_uuid(self): + """Each block gets a unique UUID.""" + b1 = Integrator() + b2 = Integrator() + assert b1.id != b2.id + assert len(b1.id) == 32 # hex UUID without dashes + + def test_block_checkpoint_roundtrip(self): + """Block state survives save/load cycle.""" + b = Integrator(2.5) + b.inputs[0] = 1.0 + b.outputs[0] = 2.5 + + json_data, npz_data = b.to_checkpoint() + + #reset block + b.reset() + assert b.inputs[0] == 0.0 + + #restore + b.load_checkpoint(json_data, npz_data) + assert np.isclose(b.inputs[0], 1.0) + assert np.isclose(b.outputs[0], 2.5) + + def 
test_block_type_mismatch_raises(self): + """Loading checkpoint with wrong type raises ValueError.""" + b = Integrator() + json_data, npz_data = b.to_checkpoint() + + b2 = Amplifier(1.0) + with pytest.raises(ValueError, match="type mismatch"): + b2.load_checkpoint(json_data, npz_data) + + +class TestEventCheckpoint: + """Test event-level checkpoint methods.""" + + def test_event_has_uuid(self): + from pathsim.events import ZeroCrossing + e = ZeroCrossing(func_evt=lambda t: t - 1.0) + assert len(e.id) == 32 + + def test_event_checkpoint_roundtrip(self): + from pathsim.events import ZeroCrossing + e = ZeroCrossing(func_evt=lambda t: t - 1.0) + e._history = (0.5, 0.99) + e._times = [1.0, 2.0, 3.0] + e._active = False + + json_data, npz_data = e.to_checkpoint() + + e.reset() + assert e._active is True + assert len(e._times) == 0 + + e.load_checkpoint(json_data, npz_data) + assert e._active is False + assert e._times == [1.0, 2.0, 3.0] + assert e._history == (0.5, 0.99) + + +class TestSwitchCheckpoint: + """Test Switch block checkpoint.""" + + def test_switch_state_preserved(self): + s = Switch(switch_state=2) + json_data, npz_data = s.to_checkpoint() + + s.select(None) + assert s.switch_state is None + + s.load_checkpoint(json_data, npz_data) + assert s.switch_state == 2 + + +class TestSimulationCheckpoint: + """Test simulation-level checkpoint save/load.""" + + def test_save_load_simple(self): + """Simple simulation checkpoint round-trip.""" + src = Source(lambda t: np.sin(2 * np.pi * t)) + integ = Integrator() + scope = Scope() + + sim = Simulation( + blocks=[src, integ, scope], + connections=[ + Connection(src, integ, scope), + ], + dt=0.01 + ) + + #run for 1 second + sim.run(1.0) + time_after_run = sim.time + state_after_run = integ.state.copy() + + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "checkpoint") + sim.save_checkpoint(path) + + #verify files exist + assert os.path.exists(f"{path}.json") + assert os.path.exists(f"{path}.npz") 
+ + #verify JSON structure + with open(f"{path}.json") as f: + data = json.load(f) + assert data["version"] == "1.0.0" + assert data["simulation"]["time"] == time_after_run + assert integ.id in data["blocks"] + + #reset and reload + sim.time = 0.0 + integ.state = np.array([0.0]) + + sim.load_checkpoint(path) + assert sim.time == time_after_run + assert np.allclose(integ.state, state_after_run) + + def test_continue_after_load(self): + """Simulation continues correctly after checkpoint load.""" + #run continuously for 2 seconds + src1 = Source(lambda t: 1.0) + integ1 = Integrator() + sim1 = Simulation( + blocks=[src1, integ1], + connections=[Connection(src1, integ1)], + dt=0.01 + ) + sim1.run(2.0) + reference_state = integ1.state.copy() + + #run for 1 second, save, load, run 1 more second + src2 = Source(lambda t: 1.0) + integ2 = Integrator() + sim2 = Simulation( + blocks=[src2, integ2], + connections=[Connection(src2, integ2)], + dt=0.01 + ) + sim2.run(1.0) + + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "cp") + sim2.save_checkpoint(path) + sim2.load_checkpoint(path) + sim2.run(1.0) # run 1 more second (t=1 -> t=2) + + #compare results + assert np.allclose(integ2.state, reference_state, rtol=1e-6) + + def test_scope_recordings(self): + """Scope recordings are saved when recordings=True.""" + src = Source(lambda t: t) + scope = Scope() + sim = Simulation( + blocks=[src, scope], + connections=[Connection(src, scope)], + dt=0.1 + ) + sim.run(1.0) + + with tempfile.TemporaryDirectory() as tmpdir: + #without recordings + path1 = os.path.join(tmpdir, "no_rec") + sim.save_checkpoint(path1, recordings=False) + npz1 = np.load(f"{path1}.npz") + assert f"{scope.id}/recording_time" not in npz1 + npz1.close() + + #with recordings + path2 = os.path.join(tmpdir, "with_rec") + sim.save_checkpoint(path2, recordings=True) + npz2 = np.load(f"{path2}.npz") + assert f"{scope.id}/recording_time" in npz2 + npz2.close() + + def 
test_delay_continuous_checkpoint(self): + """Continuous delay block preserves buffer.""" + src = Source(lambda t: np.sin(t)) + delay = Delay(tau=0.1) + scope = Scope() + sim = Simulation( + blocks=[src, delay, scope], + connections=[ + Connection(src, delay, scope), + ], + dt=0.01 + ) + sim.run(0.5) + + #capture delay output + delay_output = delay.outputs[0] + + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "cp") + sim.save_checkpoint(path) + + #reset delay buffer + delay._buffer.clear() + + sim.load_checkpoint(path) + assert np.isclose(delay.outputs[0], delay_output) + + def test_delay_discrete_checkpoint(self): + """Discrete delay block preserves ring buffer.""" + src = Source(lambda t: float(t > 0)) + delay = Delay(tau=0.05, sampling_period=0.01) + sim = Simulation( + blocks=[src, delay], + connections=[Connection(src, delay)], + dt=0.01 + ) + sim.run(0.1) + + ring_before = list(delay._ring) + + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "cp") + sim.save_checkpoint(path) + delay._ring.clear() + sim.load_checkpoint(path) + assert list(delay._ring) == ring_before From 93c065cdfefa44732a63096c10855f3ab51dfb51 Mon Sep 17 00:00:00 2001 From: Milan Rother Date: Tue, 17 Mar 2026 09:13:55 +0100 Subject: [PATCH 15/29] Replace set with ordered list + shadow set for blocks, connections, events --- src/pathsim/simulation.py | 92 +++++++++++++++++++------------- src/pathsim/subsystem.py | 49 +++++++++-------- tests/pathsim/test_simulation.py | 16 +++--- 3 files changed, 91 insertions(+), 66 deletions(-) diff --git a/src/pathsim/simulation.py b/src/pathsim/simulation.py index ed6edbfb..924b4817 100644 --- a/src/pathsim/simulation.py +++ b/src/pathsim/simulation.py @@ -153,10 +153,10 @@ class Simulation: get attributes and access to intermediate evaluation stages logger : logging.Logger global simulation logger - _blocks_dyn : set[Block] - blocks with internal ´Solver´ instances (stateful) - _blocks_evt : 
set[Block] - blocks with internal events (discrete time, eventful) + _blocks_dyn : list[Block] + blocks with internal ´Solver´ instances (stateful) + _blocks_evt : list[Block] + blocks with internal events (discrete time, eventful) _active : bool flag for setting the simulation as active, used for interrupts """ @@ -176,10 +176,13 @@ def __init__( **solver_kwargs ): - #system definition - self.blocks = set() - self.connections = set() - self.events = set() + #system definition (ordered lists with shadow sets for O(1) lookup) + self.blocks = [] + self._block_set = set() + self.connections = [] + self._conn_set = set() + self.events = [] + self._event_set = set() #simulation timestep and bounds self.dt = dt @@ -215,10 +218,12 @@ def __init__( self.time = 0.0 #collection of blocks with internal ODE solvers - self._blocks_dyn = set() + self._blocks_dyn = [] + self._blocks_dyn_set = set() #collection of blocks with internal events - self._blocks_evt = set() + self._blocks_evt = [] + self._blocks_evt_set = set() #flag for setting the simulation active self._active = True @@ -269,9 +274,9 @@ def __contains__(self, other): bool """ return ( - other in self.blocks or - other in self.connections or - other in self.events + other in self._block_set or + other in self._conn_set or + other in self._event_set ) @@ -480,7 +485,7 @@ def add_block(self, block): """ #check if block already in block list - if block in self.blocks: + if block in self._block_set: _msg = f"block {block} already part of simulation" self.logger.error(_msg) raise ValueError(_msg) @@ -490,14 +495,17 @@ def add_block(self, block): #add to dynamic list if solver was initialized if block.engine: - self._blocks_dyn.add(block) + self._blocks_dyn.append(block) + self._blocks_dyn_set.add(block) #add to eventful list if internal events if block.events: - self._blocks_evt.add(block) + self._blocks_evt.append(block) + self._blocks_evt_set.add(block) #add block to global blocklist - self.blocks.add(block) + 
self.blocks.append(block) + self._block_set.add(block) #mark graph for rebuild if self.graph: @@ -517,19 +525,24 @@ def remove_block(self, block): """ #check if block is in block list - if block not in self.blocks: + if block not in self._block_set: _msg = f"block {block} not part of simulation" self.logger.error(_msg) raise ValueError(_msg) #remove from global blocklist - self.blocks.discard(block) + self.blocks.remove(block) + self._block_set.discard(block) #remove from dynamic list - self._blocks_dyn.discard(block) + if block in self._blocks_dyn_set: + self._blocks_dyn.remove(block) + self._blocks_dyn_set.discard(block) #remove from eventful list - self._blocks_evt.discard(block) + if block in self._blocks_evt_set: + self._blocks_evt.remove(block) + self._blocks_evt_set.discard(block) #mark graph for rebuild if self.graph: @@ -549,13 +562,14 @@ def add_connection(self, connection): """ #check if connection already in connection list - if connection in self.connections: + if connection in self._conn_set: _msg = f"{connection} already part of simulation" self.logger.error(_msg) raise ValueError(_msg) #add connection to global connection list - self.connections.add(connection) + self.connections.append(connection) + self._conn_set.add(connection) #mark graph for rebuild if self.graph: @@ -575,13 +589,14 @@ def remove_connection(self, connection): """ #check if connection is in connection list - if connection not in self.connections: + if connection not in self._conn_set: _msg = f"{connection} not part of simulation" self.logger.error(_msg) raise ValueError(_msg) #remove from global connection list - self.connections.discard(connection) + self.connections.remove(connection) + self._conn_set.discard(connection) #mark graph for rebuild if self.graph: @@ -600,13 +615,14 @@ def add_event(self, event): """ #check if event already in event list - if event in self.events: + if event in self._event_set: _msg = f"{event} already part of simulation" self.logger.error(_msg) 
raise ValueError(_msg) #add event to global event list - self.events.add(event) + self.events.append(event) + self._event_set.add(event) def remove_event(self, event): @@ -621,13 +637,14 @@ def remove_event(self, event): """ #check if event is in event list - if event not in self.events: + if event not in self._event_set: _msg = f"{event} not part of simulation" self.logger.error(_msg) raise ValueError(_msg) #remove from global event list - self.events.discard(event) + self.events.remove(event) + self._event_set.discard(event) # system assembly ------------------------------------------------------------- @@ -685,10 +702,11 @@ def _check_blocks_are_managed(self): conn_blocks.update(conn.get_blocks()) # Check subset actively managed - if not conn_blocks.issubset(self.blocks): - self.logger.warning( - f"{blk} in 'connections' but not in 'blocks'!" - ) + for blk in conn_blocks: + if blk not in self._block_set: + self.logger.warning( + f"{blk} in 'connections' but not in 'blocks'!" + ) # solver management ----------------------------------------------------------- @@ -719,13 +737,15 @@ def _set_solver(self, Solver=None, **solver_kwargs): self.engine = self.Solver() #iterate all blocks and set integration engines with tolerances - self._blocks_dyn = set() + self._blocks_dyn = [] + self._blocks_dyn_set = set() for block in self.blocks: block.set_solver(self.Solver, self.engine, **self.solver_kwargs) - + #add dynamic blocks to list if block.engine: - self._blocks_dyn.add(block) + self._blocks_dyn.append(block) + self._blocks_dyn_set.add(block) #logging message self.logger.info( diff --git a/src/pathsim/subsystem.py b/src/pathsim/subsystem.py index 2dd3640c..eedecbfe 100644 --- a/src/pathsim/subsystem.py +++ b/src/pathsim/subsystem.py @@ -181,27 +181,28 @@ def __init__(self, #internal algebraic loop solvers -> initialized later self.boosters = None - #internal connecions - self.connections = set() - if connections: - self.connections.update(connections) - + #internal 
connecions (ordered list with shadow set for O(1) lookup) + self.connections = list(connections) if connections else [] + self._conn_set = set(self.connections) + #collect and organize internal blocks - self.blocks = set() - self.interface = None + self.blocks = [] + self._block_set = set() + self.interface = None if blocks: for block in blocks: - if isinstance(block, Interface): - + if isinstance(block, Interface): + if self.interface is not None: #interface block is already defined raise ValueError("Subsystem can only have one 'Interface' block!") - + self.interface = block - else: + else: #regular blocks - self.blocks.add(block) + self.blocks.append(block) + self._block_set.add(block) #check if interface is defined if self.interface is None: @@ -252,7 +253,7 @@ def __contains__(self, other): ------- bool """ - return other in self.blocks or other in self.connections + return other in self._block_set or other in self._conn_set # adding and removing system components --------------------------------------------------- @@ -267,7 +268,7 @@ def add_block(self, block): block : Block block to add to the subsystem """ - if block in self.blocks: + if block in self._block_set: raise ValueError(f"block {block} already part of subsystem") #initialize solver if available @@ -276,7 +277,8 @@ def add_block(self, block): if block.engine: self._blocks_dyn.append(block) - self.blocks.add(block) + self.blocks.append(block) + self._block_set.add(block) if self.graph: self._graph_dirty = True @@ -292,10 +294,11 @@ def remove_block(self, block): block : Block block to remove from the subsystem """ - if block not in self.blocks: + if block not in self._block_set: raise ValueError(f"block {block} not part of subsystem") - self.blocks.discard(block) + self.blocks.remove(block) + self._block_set.discard(block) #remove from dynamic list if hasattr(self, '_blocks_dyn') and block in self._blocks_dyn: @@ -315,10 +318,11 @@ def add_connection(self, connection): connection : Connection 
connection to add to the subsystem """ - if connection in self.connections: + if connection in self._conn_set: raise ValueError(f"{connection} already part of subsystem") - self.connections.add(connection) + self.connections.append(connection) + self._conn_set.add(connection) if self.graph: self._graph_dirty = True @@ -334,10 +338,11 @@ def remove_connection(self, connection): connection : Connection connection to remove from the subsystem """ - if connection not in self.connections: + if connection not in self._conn_set: raise ValueError(f"{connection} not part of subsystem") - self.connections.discard(connection) + self.connections.remove(connection) + self._conn_set.discard(connection) if self.graph: self._graph_dirty = True @@ -386,7 +391,7 @@ def _assemble_graph(self): for block in self.blocks: block.inputs.reset() - self.graph = Graph({*self.blocks, self.interface}, self.connections) + self.graph = Graph([*self.blocks, self.interface], self.connections) self._graph_dirty = False #create boosters for loop closing connections diff --git a/tests/pathsim/test_simulation.py b/tests/pathsim/test_simulation.py index beda6aae..d370b473 100644 --- a/tests/pathsim/test_simulation.py +++ b/tests/pathsim/test_simulation.py @@ -52,9 +52,9 @@ def test_init_default(self): #test default initialization Sim = Simulation(log=False) - self.assertEqual(Sim.blocks, set()) - self.assertEqual(Sim.connections, set()) - self.assertEqual(Sim.events, set()) + self.assertEqual(Sim.blocks, []) + self.assertEqual(Sim.connections, []) + self.assertEqual(Sim.events, []) self.assertEqual(Sim.dt, SIM_TIMESTEP) self.assertEqual(Sim.dt_min, SIM_TIMESTEP_MIN) self.assertEqual(Sim.dt_max, SIM_TIMESTEP_MAX) @@ -130,12 +130,12 @@ def test_add_block(self): Sim = Simulation(log=False) - self.assertEqual(Sim.blocks, set()) + self.assertEqual(Sim.blocks, []) #test adding a block B1 = Block() Sim.add_block(B1) - self.assertEqual(Sim.blocks, {B1}) + self.assertEqual(Sim.blocks, [B1]) #test adding the same 
block again with self.assertRaises(ValueError): @@ -153,17 +153,17 @@ def test_add_connection(self): log=False ) - self.assertEqual(Sim.connections, {C1}) + self.assertEqual(Sim.connections, [C1]) #test adding a connection C2 = Connection(B2, B3) Sim.add_connection(C2) - self.assertEqual(Sim.connections, {C1, C2}) + self.assertEqual(Sim.connections, [C1, C2]) #test adding the same connection again with self.assertRaises(ValueError): Sim.add_connection(C2) - self.assertEqual(Sim.connections, {C1, C2}) + self.assertEqual(Sim.connections, [C1, C2]) def test_set_solver(self): From 0f6a970f43da2b3e4e6a13f50ed0215466f401ec Mon Sep 17 00:00:00 2001 From: Milan Rother Date: Tue, 17 Mar 2026 09:20:36 +0100 Subject: [PATCH 16/29] Use type+index matching for checkpoints instead of UUIDs --- src/pathsim/blocks/_block.py | 22 ++++----- src/pathsim/blocks/delay.py | 10 ++-- src/pathsim/blocks/scope.py | 13 ++--- src/pathsim/blocks/spectrum.py | 8 ++-- src/pathsim/blocks/switch.py | 8 ++-- src/pathsim/events/_event.py | 16 ++++--- src/pathsim/simulation.py | 82 +++++++++++++++++++++++--------- tests/pathsim/test_checkpoint.py | 46 ++++++++---------- 8 files changed, 114 insertions(+), 91 deletions(-) diff --git a/src/pathsim/blocks/_block.py b/src/pathsim/blocks/_block.py index 597195a4..347be870 100644 --- a/src/pathsim/blocks/_block.py +++ b/src/pathsim/blocks/_block.py @@ -530,11 +530,13 @@ def state(self, val): # checkpoint methods ---------------------------------------------------------------- - def to_checkpoint(self, recordings=False): + def to_checkpoint(self, prefix, recordings=False): """Serialize block state for checkpointing. 
Parameters ---------- + prefix : str + key prefix for NPZ arrays (assigned by simulation) recordings : bool include recording data (for Scope blocks) @@ -545,10 +547,7 @@ def to_checkpoint(self, recordings=False): npz_data : dict numpy arrays keyed by path """ - prefix = self.id - json_data = { - "id": self.id, "type": self.__class__.__name__, "active": self._active, } @@ -567,8 +566,9 @@ def to_checkpoint(self, recordings=False): #internal events if self.events: evt_jsons = [] - for event in self.events: - e_json, e_npz = event.to_checkpoint() + for i, event in enumerate(self.events): + evt_prefix = f"{prefix}/evt_{i}" + e_json, e_npz = event.to_checkpoint(evt_prefix) evt_jsons.append(e_json) npz_data.update(e_npz) json_data["events"] = evt_jsons @@ -576,18 +576,18 @@ def to_checkpoint(self, recordings=False): return json_data, npz_data - def load_checkpoint(self, json_data, npz): + def load_checkpoint(self, prefix, json_data, npz): """Restore block state from checkpoint. Parameters ---------- + prefix : str + key prefix for NPZ arrays (assigned by simulation) json_data : dict block metadata from checkpoint JSON npz : dict-like numpy arrays from checkpoint NPZ """ - prefix = json_data["id"] - #verify type if json_data["type"] != self.__class__.__name__: raise ValueError( @@ -611,8 +611,8 @@ def load_checkpoint(self, json_data, npz): #restore internal events if self.events and "events" in json_data: - for event, evt_data in zip(self.events, json_data["events"]): - event.load_checkpoint(evt_data, npz) + for i, (event, evt_data) in enumerate(zip(self.events, json_data["events"])): + event.load_checkpoint(f"{prefix}/evt_{i}", evt_data, npz) # methods for block output and state updates ---------------------------------------- diff --git a/src/pathsim/blocks/delay.py b/src/pathsim/blocks/delay.py index 6b42614c..4e6d0a4f 100644 --- a/src/pathsim/blocks/delay.py +++ b/src/pathsim/blocks/delay.py @@ -142,10 +142,9 @@ def reset(self): self._ring.extend([0.0] * self._n) - 
def to_checkpoint(self, recordings=False): + def to_checkpoint(self, prefix, recordings=False): """Serialize Delay state including buffer data.""" - json_data, npz_data = super().to_checkpoint(recordings=recordings) - prefix = self.id + json_data, npz_data = super().to_checkpoint(prefix, recordings=recordings) json_data["sampling_period"] = self.sampling_period @@ -160,10 +159,9 @@ def to_checkpoint(self, recordings=False): return json_data, npz_data - def load_checkpoint(self, json_data, npz): + def load_checkpoint(self, prefix, json_data, npz): """Restore Delay state including buffer data.""" - super().load_checkpoint(json_data, npz) - prefix = json_data["id"] + super().load_checkpoint(prefix, json_data, npz) if self.sampling_period is None: #continuous mode diff --git a/src/pathsim/blocks/scope.py b/src/pathsim/blocks/scope.py index 57854526..ec980785 100644 --- a/src/pathsim/blocks/scope.py +++ b/src/pathsim/blocks/scope.py @@ -448,10 +448,9 @@ def save(self, path="scope.csv"): wrt.writerow(sample) - def to_checkpoint(self, recordings=False): + def to_checkpoint(self, prefix, recordings=False): """Serialize Scope state including optional recording data.""" - json_data, npz_data = super().to_checkpoint(recordings=recordings) - prefix = self.id + json_data, npz_data = super().to_checkpoint(prefix, recordings=recordings) json_data["_incremental_idx"] = self._incremental_idx if hasattr(self, '_sample_next_timestep'): @@ -464,10 +463,9 @@ def to_checkpoint(self, recordings=False): return json_data, npz_data - def load_checkpoint(self, json_data, npz): + def load_checkpoint(self, prefix, json_data, npz): """Restore Scope state including optional recording data.""" - super().load_checkpoint(json_data, npz) - prefix = json_data["id"] + super().load_checkpoint(prefix, json_data, npz) self._incremental_idx = json_data.get("_incremental_idx", 0) if hasattr(self, '_sample_next_timestep'): @@ -479,9 +477,6 @@ def load_checkpoint(self, json_data, npz): if rt_key in npz and 
rd_key in npz: self.recording_time = npz[rt_key].tolist() self.recording_data = [row for row in npz[rd_key]] - else: - self.recording_time = [] - self.recording_data = [] def update(self, t): diff --git a/src/pathsim/blocks/spectrum.py b/src/pathsim/blocks/spectrum.py index b4d37fed..0dec61fe 100644 --- a/src/pathsim/blocks/spectrum.py +++ b/src/pathsim/blocks/spectrum.py @@ -283,9 +283,9 @@ def step(self, t, dt): return True, 0.0, None - def to_checkpoint(self, recordings=False): + def to_checkpoint(self, prefix, recordings=False): """Serialize Spectrum state including integration time.""" - json_data, npz_data = super().to_checkpoint(recordings=recordings) + json_data, npz_data = super().to_checkpoint(prefix, recordings=recordings) json_data["time"] = self.time json_data["t_sample"] = self.t_sample @@ -293,9 +293,9 @@ def to_checkpoint(self, recordings=False): return json_data, npz_data - def load_checkpoint(self, json_data, npz): + def load_checkpoint(self, prefix, json_data, npz): """Restore Spectrum state including integration time.""" - super().load_checkpoint(json_data, npz) + super().load_checkpoint(prefix, json_data, npz) self.time = json_data.get("time", 0.0) self.t_sample = json_data.get("t_sample", 0.0) diff --git a/src/pathsim/blocks/switch.py b/src/pathsim/blocks/switch.py index 3ce04b07..8ee707be 100644 --- a/src/pathsim/blocks/switch.py +++ b/src/pathsim/blocks/switch.py @@ -82,13 +82,13 @@ def select(self, switch_state=0): self.switch_state = switch_state - def to_checkpoint(self, recordings=False): - json_data, npz_data = super().to_checkpoint(recordings=recordings) + def to_checkpoint(self, prefix, recordings=False): + json_data, npz_data = super().to_checkpoint(prefix, recordings=recordings) json_data["switch_state"] = self.switch_state return json_data, npz_data - def load_checkpoint(self, json_data, npz): - super().load_checkpoint(json_data, npz) + def load_checkpoint(self, prefix, json_data, npz): + super().load_checkpoint(prefix, json_data, 
npz) self.switch_state = json_data.get("switch_state", None) def update(self, t): diff --git a/src/pathsim/events/_event.py b/src/pathsim/events/_event.py index 124c99d1..85a14625 100644 --- a/src/pathsim/events/_event.py +++ b/src/pathsim/events/_event.py @@ -211,9 +211,14 @@ def resolve(self, t): # checkpoint methods ---------------------------------------------------------------- - def to_checkpoint(self): + def to_checkpoint(self, prefix): """Serialize event state for checkpointing. + Parameters + ---------- + prefix : str + key prefix for NPZ arrays (assigned by simulation) + Returns ------- json_data : dict @@ -221,15 +226,12 @@ def to_checkpoint(self): npz_data : dict numpy arrays keyed by path """ - prefix = self.id - #extract history eval value hist_eval, hist_time = self._history if hist_eval is not None and hasattr(hist_eval, 'item'): hist_eval = float(hist_eval) json_data = { - "id": self.id, "type": self.__class__.__name__, "active": self._active, "history_eval": hist_eval, @@ -243,18 +245,18 @@ def to_checkpoint(self): return json_data, npz_data - def load_checkpoint(self, json_data, npz): + def load_checkpoint(self, prefix, json_data, npz): """Restore event state from checkpoint. Parameters ---------- + prefix : str + key prefix for NPZ arrays (assigned by simulation) json_data : dict event metadata from checkpoint JSON npz : dict-like numpy arrays from checkpoint NPZ """ - prefix = json_data["id"] - self._active = json_data["active"] self._history = json_data["history_eval"], json_data["history_time"] diff --git a/src/pathsim/simulation.py b/src/pathsim/simulation.py index 924b4817..279820ae 100644 --- a/src/pathsim/simulation.py +++ b/src/pathsim/simulation.py @@ -338,11 +338,34 @@ def plot(self, *args, **kwargs): # checkpoint methods ---------------------------------------------------------- + @staticmethod + def _checkpoint_key(type_name, type_counts): + """Generate a deterministic checkpoint key from block/event type + and occurrence index (e.g. 
'Integrator_0', 'Scope_1'). + + Parameters + ---------- + type_name : str + class name of the block or event + type_counts : dict + running counter per type name, mutated in place + + Returns + ------- + key : str + deterministic checkpoint key + """ + idx = type_counts.get(type_name, 0) + type_counts[type_name] = idx + 1 + return f"{type_name}_{idx}" + + def save_checkpoint(self, path, recordings=False): """Save simulation state to checkpoint files (JSON + NPZ). Creates two files: {path}.json (structure/metadata) and - {path}.npz (numerical data). + {path}.npz (numerical data). Blocks and events are keyed by + type and insertion order for deterministic cross-instance matching. Parameters ---------- @@ -371,22 +394,28 @@ def save_checkpoint(self, path, recordings=False): "tolerance_fpi": self.tolerance_fpi, "iterations_max": self.iterations_max, }, - "blocks": {}, - "events": {}, + "blocks": [], + "events": [], } npz_data = {} - #checkpoint all blocks (keyed by UUID) + #checkpoint all blocks (keyed by type + insertion index) + type_counts = {} for block in self.blocks: - b_json, b_npz = block.to_checkpoint(recordings=recordings) - checkpoint["blocks"][block.id] = b_json + key = self._checkpoint_key(block.__class__.__name__, type_counts) + b_json, b_npz = block.to_checkpoint(key, recordings=recordings) + b_json["_key"] = key + checkpoint["blocks"].append(b_json) npz_data.update(b_npz) - #checkpoint external events (keyed by UUID) + #checkpoint external events (keyed by type + insertion index) + type_counts = {} for event in self.events: - e_json, e_npz = event.to_checkpoint() - checkpoint["events"][event.id] = e_json + key = self._checkpoint_key(event.__class__.__name__, type_counts) + e_json, e_npz = event.to_checkpoint(key) + e_json["_key"] = key + checkpoint["events"].append(e_json) npz_data.update(e_npz) #write files @@ -400,8 +429,9 @@ def load_checkpoint(self, path): """Load simulation state from checkpoint files (JSON + NPZ). 
Restores simulation time and all block/event states from a - previously saved checkpoint. The simulation must have the same - blocks and events as when the checkpoint was saved. + previously saved checkpoint. Matching is based on block/event + type and insertion order, so the simulation must be constructed + with the same block types in the same order. Parameters ---------- @@ -444,26 +474,32 @@ def load_checkpoint(self, path): f"current solver '{self.Solver.__name__}'" ) - #restore blocks - block_data = checkpoint.get("blocks", {}) + #index checkpoint blocks by key + block_data = {b["_key"]: b for b in checkpoint.get("blocks", [])} + + #restore blocks by type + insertion order + type_counts = {} for block in self.blocks: - if block.id in block_data: - block.load_checkpoint(block_data[block.id], npz) + key = self._checkpoint_key(block.__class__.__name__, type_counts) + if key in block_data: + block.load_checkpoint(key, block_data[key], npz) else: warnings.warn( - f"Block {block.__class__.__name__} (id={block.id[:8]}...) " - f"not found in checkpoint" + f"Block '{key}' not found in checkpoint" ) - #restore external events - event_data = checkpoint.get("events", {}) + #index checkpoint events by key + event_data = {e["_key"]: e for e in checkpoint.get("events", [])} + + #restore external events by type + insertion order + type_counts = {} for event in self.events: - if event.id in event_data: - event.load_checkpoint(event_data[event.id], npz) + key = self._checkpoint_key(event.__class__.__name__, type_counts) + if key in event_data: + event.load_checkpoint(key, event_data[key], npz) else: warnings.warn( - f"Event {event.__class__.__name__} (id={event.id[:8]}...) 
" - f"not found in checkpoint" + f"Event '{key}' not found in checkpoint" ) finally: diff --git a/tests/pathsim/test_checkpoint.py b/tests/pathsim/test_checkpoint.py index b0bcc470..b8af6b0d 100644 --- a/tests/pathsim/test_checkpoint.py +++ b/tests/pathsim/test_checkpoint.py @@ -22,70 +22,61 @@ def test_basic_block_to_checkpoint(self): """Block produces valid checkpoint data.""" b = Integrator(1.0) b.inputs[0] = 3.14 - json_data, npz_data = b.to_checkpoint() + prefix = "Integrator_0" + json_data, npz_data = b.to_checkpoint(prefix) assert json_data["type"] == "Integrator" - assert json_data["id"] == b.id assert json_data["active"] is True - assert f"{b.id}/inputs" in npz_data - assert f"{b.id}/outputs" in npz_data - - def test_block_has_uuid(self): - """Each block gets a unique UUID.""" - b1 = Integrator() - b2 = Integrator() - assert b1.id != b2.id - assert len(b1.id) == 32 # hex UUID without dashes + assert f"{prefix}/inputs" in npz_data + assert f"{prefix}/outputs" in npz_data def test_block_checkpoint_roundtrip(self): """Block state survives save/load cycle.""" b = Integrator(2.5) b.inputs[0] = 1.0 b.outputs[0] = 2.5 + prefix = "Integrator_0" - json_data, npz_data = b.to_checkpoint() + json_data, npz_data = b.to_checkpoint(prefix) #reset block b.reset() assert b.inputs[0] == 0.0 #restore - b.load_checkpoint(json_data, npz_data) + b.load_checkpoint(prefix, json_data, npz_data) assert np.isclose(b.inputs[0], 1.0) assert np.isclose(b.outputs[0], 2.5) def test_block_type_mismatch_raises(self): """Loading checkpoint with wrong type raises ValueError.""" b = Integrator() - json_data, npz_data = b.to_checkpoint() + prefix = "Integrator_0" + json_data, npz_data = b.to_checkpoint(prefix) b2 = Amplifier(1.0) with pytest.raises(ValueError, match="type mismatch"): - b2.load_checkpoint(json_data, npz_data) + b2.load_checkpoint(prefix, json_data, npz_data) class TestEventCheckpoint: """Test event-level checkpoint methods.""" - def test_event_has_uuid(self): - from 
pathsim.events import ZeroCrossing - e = ZeroCrossing(func_evt=lambda t: t - 1.0) - assert len(e.id) == 32 - def test_event_checkpoint_roundtrip(self): from pathsim.events import ZeroCrossing e = ZeroCrossing(func_evt=lambda t: t - 1.0) e._history = (0.5, 0.99) e._times = [1.0, 2.0, 3.0] e._active = False + prefix = "ZeroCrossing_0" - json_data, npz_data = e.to_checkpoint() + json_data, npz_data = e.to_checkpoint(prefix) e.reset() assert e._active is True assert len(e._times) == 0 - e.load_checkpoint(json_data, npz_data) + e.load_checkpoint(prefix, json_data, npz_data) assert e._active is False assert e._times == [1.0, 2.0, 3.0] assert e._history == (0.5, 0.99) @@ -96,12 +87,13 @@ class TestSwitchCheckpoint: def test_switch_state_preserved(self): s = Switch(switch_state=2) - json_data, npz_data = s.to_checkpoint() + prefix = "Switch_0" + json_data, npz_data = s.to_checkpoint(prefix) s.select(None) assert s.switch_state is None - s.load_checkpoint(json_data, npz_data) + s.load_checkpoint(prefix, json_data, npz_data) assert s.switch_state == 2 @@ -140,7 +132,7 @@ def test_save_load_simple(self): data = json.load(f) assert data["version"] == "1.0.0" assert data["simulation"]["time"] == time_after_run - assert integ.id in data["blocks"] + assert any(b["_key"] == "Integrator_0" for b in data["blocks"]) #reset and reload sim.time = 0.0 @@ -198,14 +190,14 @@ def test_scope_recordings(self): path1 = os.path.join(tmpdir, "no_rec") sim.save_checkpoint(path1, recordings=False) npz1 = np.load(f"{path1}.npz") - assert f"{scope.id}/recording_time" not in npz1 + assert "Scope_0/recording_time" not in npz1 npz1.close() #with recordings path2 = os.path.join(tmpdir, "with_rec") sim.save_checkpoint(path2, recordings=True) npz2 = np.load(f"{path2}.npz") - assert f"{scope.id}/recording_time" in npz2 + assert "Scope_0/recording_time" in npz2 npz2.close() def test_delay_continuous_checkpoint(self): From 11697a068f56c09812ec8db64a82447288020586 Mon Sep 17 00:00:00 2001 From: Milan Rother 
Date: Tue, 17 Mar 2026 09:22:49 +0100 Subject: [PATCH 17/29] Move json/warnings imports to top-level, fix missing trailing newlines --- src/pathsim/events/_event.py | 2 +- src/pathsim/simulation.py | 8 +++----- src/pathsim/utils/adaptivebuffer.py | 2 +- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/src/pathsim/events/_event.py b/src/pathsim/events/_event.py index 85a14625..58f657b5 100644 --- a/src/pathsim/events/_event.py +++ b/src/pathsim/events/_event.py @@ -264,4 +264,4 @@ def load_checkpoint(self, prefix, json_data, npz): if times_key in npz: self._times = npz[times_key].tolist() else: - self._times = [] \ No newline at end of file + self._times = [] diff --git a/src/pathsim/simulation.py b/src/pathsim/simulation.py index 279820ae..64402bd7 100644 --- a/src/pathsim/simulation.py +++ b/src/pathsim/simulation.py @@ -10,6 +10,9 @@ # IMPORTS =============================================================================== +import json +import warnings + import numpy as np import time @@ -374,8 +377,6 @@ def save_checkpoint(self, path, recordings=False): recordings : bool include scope/spectrum recording data (default: False) """ - import json - #strip extension if provided if path.endswith('.json') or path.endswith('.npz'): path = path.rsplit('.', 1)[0] @@ -438,9 +439,6 @@ def load_checkpoint(self, path): path : str base path without extension """ - import json - import warnings - #strip extension if provided if path.endswith('.json') or path.endswith('.npz'): path = path.rsplit('.', 1)[0] diff --git a/src/pathsim/utils/adaptivebuffer.py b/src/pathsim/utils/adaptivebuffer.py index 05dd82fa..5b37fa05 100644 --- a/src/pathsim/utils/adaptivebuffer.py +++ b/src/pathsim/utils/adaptivebuffer.py @@ -163,4 +163,4 @@ def load_checkpoint(self, npz, prefix): values = npz[v_key] for t, v in zip(times, values): self.buffer_t.append(float(t)) - self.buffer_v.append(v) \ No newline at end of file + self.buffer_v.append(v) From 
13189ff62023fb1470b3d01fecb7c91003fdc2bd Mon Sep 17 00:00:00 2001 From: Milan Rother Date: Tue, 17 Mar 2026 09:30:51 +0100 Subject: [PATCH 18/29] Add checkpoint overrides for FIR, KalmanFilter, noise blocks, RNG, Subsystem --- src/pathsim/blocks/fir.py | 15 ++++++ src/pathsim/blocks/kalman.py | 16 ++++++ src/pathsim/blocks/noise.py | 27 ++++++++++ src/pathsim/blocks/rng.py | 14 ++++++ src/pathsim/solvers/gear.py | 4 +- src/pathsim/subsystem.py | 96 ++++++++++++++++++++++++++++++++++++ 6 files changed, 170 insertions(+), 2 deletions(-) diff --git a/src/pathsim/blocks/fir.py b/src/pathsim/blocks/fir.py index 8db1a8a3..c2766bc1 100644 --- a/src/pathsim/blocks/fir.py +++ b/src/pathsim/blocks/fir.py @@ -114,6 +114,21 @@ def reset(self): self._buffer = deque([0.0]*n, maxlen=n) + def to_checkpoint(self, prefix, recordings=False): + """Serialize FIR state including input buffer.""" + json_data, npz_data = super().to_checkpoint(prefix, recordings=recordings) + npz_data[f"{prefix}/fir_buffer"] = np.array(list(self._buffer)) + return json_data, npz_data + + def load_checkpoint(self, prefix, json_data, npz): + """Restore FIR state including input buffer.""" + super().load_checkpoint(prefix, json_data, npz) + key = f"{prefix}/fir_buffer" + if key in npz: + self._buffer.clear() + self._buffer.extend(npz[key].tolist()) + + def __len__(self): """This block has no direct passthrough""" return 0 \ No newline at end of file diff --git a/src/pathsim/blocks/kalman.py b/src/pathsim/blocks/kalman.py index 783ae537..98374a38 100644 --- a/src/pathsim/blocks/kalman.py +++ b/src/pathsim/blocks/kalman.py @@ -143,6 +143,22 @@ def __len__(self): return 0 + def to_checkpoint(self, prefix, recordings=False): + """Serialize Kalman filter state estimate and covariance.""" + json_data, npz_data = super().to_checkpoint(prefix, recordings=recordings) + npz_data[f"{prefix}/kf_x"] = self.x + npz_data[f"{prefix}/kf_P"] = self.P + return json_data, npz_data + + def load_checkpoint(self, prefix, json_data, 
npz): + """Restore Kalman filter state estimate and covariance.""" + super().load_checkpoint(prefix, json_data, npz) + if f"{prefix}/kf_x" in npz: + self.x = npz[f"{prefix}/kf_x"] + if f"{prefix}/kf_P" in npz: + self.P = npz[f"{prefix}/kf_P"] + + def _kf_update(self): """Perform one Kalman filter update step.""" diff --git a/src/pathsim/blocks/noise.py b/src/pathsim/blocks/noise.py index 101828ea..555eb5be 100644 --- a/src/pathsim/blocks/noise.py +++ b/src/pathsim/blocks/noise.py @@ -44,6 +44,17 @@ class WhiteNoise(Block): random seed for reproducibility """ + def to_checkpoint(self, prefix, recordings=False): + """Serialize WhiteNoise state including current sample.""" + json_data, npz_data = super().to_checkpoint(prefix, recordings=recordings) + json_data["_current_sample"] = float(self._current_sample) + return json_data, npz_data + + def load_checkpoint(self, prefix, json_data, npz): + """Restore WhiteNoise state including current sample.""" + super().load_checkpoint(prefix, json_data, npz) + self._current_sample = json_data.get("_current_sample", 0.0) + input_port_labels = {} output_port_labels = {"out": 0} @@ -156,6 +167,22 @@ class PinkNoise(Block): random seed for reproducibility """ + def to_checkpoint(self, prefix, recordings=False): + """Serialize PinkNoise state including algorithm state.""" + json_data, npz_data = super().to_checkpoint(prefix, recordings=recordings) + json_data["n_samples"] = self.n_samples + json_data["_current_sample"] = float(self._current_sample) + npz_data[f"{prefix}/octave_values"] = self.octave_values + return json_data, npz_data + + def load_checkpoint(self, prefix, json_data, npz): + """Restore PinkNoise state including algorithm state.""" + super().load_checkpoint(prefix, json_data, npz) + self.n_samples = json_data.get("n_samples", 0) + self._current_sample = json_data.get("_current_sample", 0.0) + if f"{prefix}/octave_values" in npz: + self.octave_values = npz[f"{prefix}/octave_values"] + input_port_labels = {} 
output_port_labels = {"out": 0} diff --git a/src/pathsim/blocks/rng.py b/src/pathsim/blocks/rng.py index 5841b5a5..974e181e 100644 --- a/src/pathsim/blocks/rng.py +++ b/src/pathsim/blocks/rng.py @@ -96,6 +96,20 @@ def sample(self, t, dt): self._sample = np.random.rand() + def to_checkpoint(self, prefix, recordings=False): + """Serialize RNG state including current sample.""" + json_data, npz_data = super().to_checkpoint(prefix, recordings=recordings) + if self.sampling_period is None: + json_data["_sample"] = float(self._sample) + return json_data, npz_data + + def load_checkpoint(self, prefix, json_data, npz): + """Restore RNG state including current sample.""" + super().load_checkpoint(prefix, json_data, npz) + if self.sampling_period is None: + self._sample = json_data.get("_sample", 0.0) + + def __len__(self): """Essentially a source-like block without passthrough""" return 0 diff --git a/src/pathsim/solvers/gear.py b/src/pathsim/solvers/gear.py index 6f745371..22968194 100644 --- a/src/pathsim/solvers/gear.py +++ b/src/pathsim/solvers/gear.py @@ -250,8 +250,8 @@ def load_checkpoint(self, json_data, npz, prefix): #recompute BDF coefficients from restored history if not self._needs_startup and len(self.history_dt) > 0: self.F, self.K = {}, {} - for n, _ in enumerate(self.history_dt, 1): - self.F[n], self.K[n] = compute_bdf_coefficients(n, np.array(self.history_dt)) + for k, _ in enumerate(self.history_dt, 1): + self.F[k], self.K[k] = compute_bdf_coefficients(k, np.array(self.history_dt)) def stages(self, t, dt): diff --git a/src/pathsim/subsystem.py b/src/pathsim/subsystem.py index eedecbfe..b515bd34 100644 --- a/src/pathsim/subsystem.py +++ b/src/pathsim/subsystem.py @@ -462,6 +462,102 @@ def reset(self): block.reset() + def to_checkpoint(self, prefix, recordings=False): + """Serialize subsystem state by recursively checkpointing internal blocks. 
+ + Parameters + ---------- + prefix : str + key prefix for NPZ arrays (assigned by simulation) + recordings : bool + include recording data (for Scope blocks) + + Returns + ------- + json_data : dict + JSON-serializable metadata + npz_data : dict + numpy arrays keyed by path + """ + json_data = { + "type": self.__class__.__name__, + "active": self._active, + "blocks": [], + } + npz_data = {} + + #checkpoint interface block + if_json, if_npz = self.interface.to_checkpoint(f"{prefix}/interface", recordings=recordings) + json_data["interface"] = if_json + npz_data.update(if_npz) + + #checkpoint internal blocks by type + insertion order + type_counts = {} + for block in self.blocks: + type_name = block.__class__.__name__ + idx = type_counts.get(type_name, 0) + type_counts[type_name] = idx + 1 + key = f"{prefix}/{type_name}_{idx}" + b_json, b_npz = block.to_checkpoint(key, recordings=recordings) + b_json["_key"] = key + json_data["blocks"].append(b_json) + npz_data.update(b_npz) + + #checkpoint subsystem-level events + if self._events: + evt_jsons = [] + for i, event in enumerate(self._events): + evt_prefix = f"{prefix}/evt_{i}" + e_json, e_npz = event.to_checkpoint(evt_prefix) + evt_jsons.append(e_json) + npz_data.update(e_npz) + json_data["events"] = evt_jsons + + return json_data, npz_data + + + def load_checkpoint(self, prefix, json_data, npz): + """Restore subsystem state by recursively loading internal blocks. 
+ + Parameters + ---------- + prefix : str + key prefix for NPZ arrays (assigned by simulation) + json_data : dict + subsystem metadata from checkpoint JSON + npz : dict-like + numpy arrays from checkpoint NPZ + """ + #verify type + if json_data["type"] != self.__class__.__name__: + raise ValueError( + f"Checkpoint type mismatch: expected '{self.__class__.__name__}', " + f"got '{json_data['type']}'" + ) + + self._active = json_data["active"] + + #restore interface block + if "interface" in json_data: + self.interface.load_checkpoint(f"{prefix}/interface", json_data["interface"], npz) + + #restore internal blocks by type + insertion order + block_data = {b["_key"]: b for b in json_data.get("blocks", [])} + type_counts = {} + for block in self.blocks: + type_name = block.__class__.__name__ + idx = type_counts.get(type_name, 0) + type_counts[type_name] = idx + 1 + key = f"{prefix}/{type_name}_{idx}" + if key in block_data: + block.load_checkpoint(key, block_data[key], npz) + + #restore subsystem-level events + if self._events and "events" in json_data: + for i, (event, evt_data) in enumerate(zip(self._events, json_data["events"])): + event.load_checkpoint(f"{prefix}/evt_{i}", evt_data, npz) + + def on(self): """Activate the subsystem and all internal blocks, sets the boolean evaluation flag to 'True'. 
From 00e4f30c9b3d34330b5bbc36597c434b0f64d04e Mon Sep 17 00:00:00 2001 From: Milan Rother Date: Tue, 17 Mar 2026 09:33:02 +0100 Subject: [PATCH 19/29] Add comprehensive checkpoint tests: cross-instance, FIR, Kalman, noise, Subsystem --- tests/pathsim/test_checkpoint.py | 280 ++++++++++++++++++++++++++++++- 1 file changed, 278 insertions(+), 2 deletions(-) diff --git a/tests/pathsim/test_checkpoint.py b/tests/pathsim/test_checkpoint.py index b8af6b0d..ac4df056 100644 --- a/tests/pathsim/test_checkpoint.py +++ b/tests/pathsim/test_checkpoint.py @@ -7,12 +7,16 @@ import numpy as np import pytest -from pathsim import Simulation, Connection +from pathsim import Simulation, Connection, Subsystem, Interface from pathsim.blocks import ( - Source, Integrator, Amplifier, Scope, Constant + Source, Integrator, Amplifier, Scope, Constant, Function ) from pathsim.blocks.delay import Delay from pathsim.blocks.switch import Switch +from pathsim.blocks.fir import FIR +from pathsim.blocks.kalman import KalmanFilter +from pathsim.blocks.noise import WhiteNoise, PinkNoise +from pathsim.blocks.rng import RandomNumberGenerator class TestBlockCheckpoint: @@ -246,3 +250,275 @@ def test_delay_discrete_checkpoint(self): delay._ring.clear() sim.load_checkpoint(path) assert list(delay._ring) == ring_before + + def test_cross_instance_load(self): + """Checkpoint loads into a freshly constructed simulation (different UUIDs).""" + src1 = Source(lambda t: 1.0) + integ1 = Integrator() + sim1 = Simulation( + blocks=[src1, integ1], + connections=[Connection(src1, integ1)], + dt=0.01 + ) + sim1.run(1.0) + saved_time = sim1.time + saved_state = integ1.state.copy() + + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "cp") + sim1.save_checkpoint(path) + + #create entirely new simulation (new block objects, new UUIDs) + src2 = Source(lambda t: 1.0) + integ2 = Integrator() + sim2 = Simulation( + blocks=[src2, integ2], + connections=[Connection(src2, integ2)], + dt=0.01 + ) + + 
#UUIDs differ + assert src1.id != src2.id + assert integ1.id != integ2.id + + sim2.load_checkpoint(path) + assert sim2.time == saved_time + assert np.allclose(integ2.state, saved_state) + + def test_scope_recordings_preserved_without_flag(self): + """Loading without recordings flag does not erase existing recordings.""" + src = Source(lambda t: t) + scope = Scope() + sim = Simulation( + blocks=[src, scope], + connections=[Connection(src, scope)], + dt=0.1 + ) + sim.run(1.0) + + #scope has recordings + assert len(scope.recording_time) > 0 + rec_len = len(scope.recording_time) + + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "cp") + sim.save_checkpoint(path, recordings=False) + sim.load_checkpoint(path) + + #recordings should still be intact + assert len(scope.recording_time) == rec_len + + def test_multiple_same_type_blocks(self): + """Multiple blocks of the same type are matched by insertion order.""" + src = Source(lambda t: 1.0) + i1 = Integrator(1.0) + i2 = Integrator(2.0) + sim = Simulation( + blocks=[src, i1, i2], + connections=[Connection(src, i1), Connection(src, i2)], + dt=0.01 + ) + sim.run(0.5) + + state1 = i1.state.copy() + state2 = i2.state.copy() + assert not np.allclose(state1, state2) # different initial conditions + + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "cp") + sim.save_checkpoint(path) + + i1.state = np.array([0.0]) + i2.state = np.array([0.0]) + + sim.load_checkpoint(path) + assert np.allclose(i1.state, state1) + assert np.allclose(i2.state, state2) + + +class TestFIRCheckpoint: + """Test FIR block checkpoint.""" + + def test_fir_buffer_preserved(self): + """FIR filter buffer survives checkpoint round-trip.""" + fir = FIR(coeffs=[0.25, 0.5, 0.25], T=0.01) + prefix = "FIR_0" + + #simulate some input to fill the buffer + fir._buffer.appendleft(1.0) + fir._buffer.appendleft(2.0) + buffer_before = list(fir._buffer) + + json_data, npz_data = fir.to_checkpoint(prefix) + + 
fir._buffer.clear() + fir._buffer.extend([0.0] * 3) + + fir.load_checkpoint(prefix, json_data, npz_data) + assert list(fir._buffer) == buffer_before + + +class TestKalmanFilterCheckpoint: + """Test KalmanFilter block checkpoint.""" + + def test_kalman_state_preserved(self): + """Kalman filter state and covariance survive checkpoint.""" + F = np.array([[1.0, 0.1], [0.0, 1.0]]) + H = np.array([[1.0, 0.0]]) + Q = np.eye(2) * 0.01 + R = np.array([[0.1]]) + + kf = KalmanFilter(F, H, Q, R) + prefix = "KalmanFilter_0" + + #set some state + kf.x = np.array([3.14, -1.0]) + kf.P = np.array([[0.5, 0.1], [0.1, 0.3]]) + + json_data, npz_data = kf.to_checkpoint(prefix) + + kf.x = np.zeros(2) + kf.P = np.eye(2) + + kf.load_checkpoint(prefix, json_data, npz_data) + assert np.allclose(kf.x, [3.14, -1.0]) + assert np.allclose(kf.P, [[0.5, 0.1], [0.1, 0.3]]) + + +class TestNoiseCheckpoint: + """Test noise block checkpoints.""" + + def test_white_noise_sample_preserved(self): + """WhiteNoise current sample survives checkpoint.""" + wn = WhiteNoise(standard_deviation=2.0) + wn._current_sample = 1.234 + prefix = "WhiteNoise_0" + + json_data, npz_data = wn.to_checkpoint(prefix) + wn._current_sample = 0.0 + + wn.load_checkpoint(prefix, json_data, npz_data) + assert wn._current_sample == pytest.approx(1.234) + + def test_pink_noise_state_preserved(self): + """PinkNoise algorithm state survives checkpoint.""" + pn = PinkNoise(num_octaves=8, seed=42) + prefix = "PinkNoise_0" + + #advance the noise state + for _ in range(10): + pn._generate_sample(0.01) + + n_samples_before = pn.n_samples + octaves_before = pn.octave_values.copy() + sample_before = pn._current_sample + + json_data, npz_data = pn.to_checkpoint(prefix) + + pn.reset() + assert pn.n_samples == 0 + + pn.load_checkpoint(prefix, json_data, npz_data) + assert pn.n_samples == n_samples_before + assert np.allclose(pn.octave_values, octaves_before) + + +class TestRNGCheckpoint: + """Test RandomNumberGenerator checkpoint.""" + + def 
test_rng_sample_preserved(self): + """RNG current sample survives checkpoint (continuous mode).""" + rng = RandomNumberGenerator(sampling_period=None) + prefix = "RandomNumberGenerator_0" + sample_before = rng._sample + + json_data, npz_data = rng.to_checkpoint(prefix) + rng._sample = 0.0 + + rng.load_checkpoint(prefix, json_data, npz_data) + assert rng._sample == pytest.approx(sample_before) + + +class TestSubsystemCheckpoint: + """Test Subsystem checkpoint.""" + + def test_subsystem_roundtrip(self): + """Subsystem with internal blocks survives checkpoint round-trip.""" + #build a simple subsystem: two integrators in series + If = Interface() + I1 = Integrator(1.0) + I2 = Integrator(0.0) + + sub = Subsystem( + blocks=[If, I1, I2], + connections=[ + Connection(If, I1), + Connection(I1, I2), + Connection(I2, If), + ] + ) + + #embed in a simulation + src = Source(lambda t: 1.0) + scope = Scope() + sim = Simulation( + blocks=[src, sub, scope], + connections=[ + Connection(src, sub), + Connection(sub, scope), + ], + dt=0.01 + ) + + sim.run(0.5) + state_I1 = I1.state.copy() + state_I2 = I2.state.copy() + + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "cp") + sim.save_checkpoint(path) + + #zero out states + I1.state = np.array([0.0]) + I2.state = np.array([0.0]) + + sim.load_checkpoint(path) + assert np.allclose(I1.state, state_I1) + assert np.allclose(I2.state, state_I2) + + def test_subsystem_cross_instance(self): + """Subsystem checkpoint loads into a fresh simulation instance.""" + If1 = Interface() + I1 = Integrator(1.0) + sub1 = Subsystem( + blocks=[If1, I1], + connections=[Connection(If1, I1), Connection(I1, If1)] + ) + src1 = Source(lambda t: 1.0) + sim1 = Simulation( + blocks=[src1, sub1], + connections=[Connection(src1, sub1)], + dt=0.01 + ) + sim1.run(0.5) + state_before = I1.state.copy() + + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "cp") + sim1.save_checkpoint(path) + + #new instance + If2 = 
Interface() + I2 = Integrator(1.0) + sub2 = Subsystem( + blocks=[If2, I2], + connections=[Connection(If2, I2), Connection(I2, If2)] + ) + src2 = Source(lambda t: 1.0) + sim2 = Simulation( + blocks=[src2, sub2], + connections=[Connection(src2, sub2)], + dt=0.01 + ) + sim2.load_checkpoint(path) + assert np.allclose(I2.state, state_before) From 528b26ed6b352d2bd73990b60e3d9f565da57e65 Mon Sep 17 00:00:00 2001 From: Milan Rother Date: Tue, 17 Mar 2026 09:42:02 +0100 Subject: [PATCH 20/29] Clean up: remove dead UUID code, fix method placement and spacing, deduplicate checkpoint key logic --- src/pathsim/blocks/_block.py | 4 --- src/pathsim/blocks/fir.py | 3 +- src/pathsim/blocks/kalman.py | 1 + src/pathsim/blocks/noise.py | 60 +++++++++++++++++--------------- src/pathsim/blocks/rng.py | 1 + src/pathsim/blocks/switch.py | 1 + src/pathsim/events/_event.py | 5 --- src/pathsim/subsystem.py | 20 ++++++----- tests/pathsim/test_checkpoint.py | 4 --- 9 files changed, 49 insertions(+), 50 deletions(-) diff --git a/src/pathsim/blocks/_block.py b/src/pathsim/blocks/_block.py index 347be870..6a3c3e3c 100644 --- a/src/pathsim/blocks/_block.py +++ b/src/pathsim/blocks/_block.py @@ -11,7 +11,6 @@ # IMPORTS =============================================================================== import inspect -from uuid import uuid4 from functools import lru_cache from ..utils.deprecation import deprecated @@ -85,9 +84,6 @@ class definition for other blocks to be inherited. 
def __init__(self): - #unique identifier for checkpointing and diagnostics - self.id = uuid4().hex - #registers to hold input and output values self.inputs = Register( mapping=self.input_port_labels and self.input_port_labels.copy() diff --git a/src/pathsim/blocks/fir.py b/src/pathsim/blocks/fir.py index c2766bc1..c9dc8ff7 100644 --- a/src/pathsim/blocks/fir.py +++ b/src/pathsim/blocks/fir.py @@ -120,6 +120,7 @@ def to_checkpoint(self, prefix, recordings=False): npz_data[f"{prefix}/fir_buffer"] = np.array(list(self._buffer)) return json_data, npz_data + def load_checkpoint(self, prefix, json_data, npz): """Restore FIR state including input buffer.""" super().load_checkpoint(prefix, json_data, npz) @@ -131,4 +132,4 @@ def load_checkpoint(self, prefix, json_data, npz): def __len__(self): """This block has no direct passthrough""" - return 0 \ No newline at end of file + return 0 diff --git a/src/pathsim/blocks/kalman.py b/src/pathsim/blocks/kalman.py index 98374a38..c835a7cc 100644 --- a/src/pathsim/blocks/kalman.py +++ b/src/pathsim/blocks/kalman.py @@ -150,6 +150,7 @@ def to_checkpoint(self, prefix, recordings=False): npz_data[f"{prefix}/kf_P"] = self.P return json_data, npz_data + def load_checkpoint(self, prefix, json_data, npz): """Restore Kalman filter state estimate and covariance.""" super().load_checkpoint(prefix, json_data, npz) diff --git a/src/pathsim/blocks/noise.py b/src/pathsim/blocks/noise.py index 555eb5be..0206a1b6 100644 --- a/src/pathsim/blocks/noise.py +++ b/src/pathsim/blocks/noise.py @@ -44,17 +44,6 @@ class WhiteNoise(Block): random seed for reproducibility """ - def to_checkpoint(self, prefix, recordings=False): - """Serialize WhiteNoise state including current sample.""" - json_data, npz_data = super().to_checkpoint(prefix, recordings=recordings) - json_data["_current_sample"] = float(self._current_sample) - return json_data, npz_data - - def load_checkpoint(self, prefix, json_data, npz): - """Restore WhiteNoise state including current 
sample.""" - super().load_checkpoint(prefix, json_data, npz) - self._current_sample = json_data.get("_current_sample", 0.0) - input_port_labels = {} output_port_labels = {"out": 0} @@ -135,6 +124,19 @@ def update(self, t): pass + def to_checkpoint(self, prefix, recordings=False): + """Serialize WhiteNoise state including current sample.""" + json_data, npz_data = super().to_checkpoint(prefix, recordings=recordings) + json_data["_current_sample"] = float(self._current_sample) + return json_data, npz_data + + + def load_checkpoint(self, prefix, json_data, npz): + """Restore WhiteNoise state including current sample.""" + super().load_checkpoint(prefix, json_data, npz) + self._current_sample = json_data.get("_current_sample", 0.0) + + class PinkNoise(Block): """Pink noise (1/f noise) source using the Voss-McCartney algorithm. @@ -167,22 +169,6 @@ class PinkNoise(Block): random seed for reproducibility """ - def to_checkpoint(self, prefix, recordings=False): - """Serialize PinkNoise state including algorithm state.""" - json_data, npz_data = super().to_checkpoint(prefix, recordings=recordings) - json_data["n_samples"] = self.n_samples - json_data["_current_sample"] = float(self._current_sample) - npz_data[f"{prefix}/octave_values"] = self.octave_values - return json_data, npz_data - - def load_checkpoint(self, prefix, json_data, npz): - """Restore PinkNoise state including algorithm state.""" - super().load_checkpoint(prefix, json_data, npz) - self.n_samples = json_data.get("n_samples", 0) - self._current_sample = json_data.get("_current_sample", 0.0) - if f"{prefix}/octave_values" in npz: - self.octave_values = npz[f"{prefix}/octave_values"] - input_port_labels = {} output_port_labels = {"out": 0} @@ -295,4 +281,22 @@ def sample(self, t, dt): def update(self, t): - pass \ No newline at end of file + pass + + + def to_checkpoint(self, prefix, recordings=False): + """Serialize PinkNoise state including algorithm state.""" + json_data, npz_data = 
super().to_checkpoint(prefix, recordings=recordings) + json_data["n_samples"] = self.n_samples + json_data["_current_sample"] = float(self._current_sample) + npz_data[f"{prefix}/octave_values"] = self.octave_values + return json_data, npz_data + + + def load_checkpoint(self, prefix, json_data, npz): + """Restore PinkNoise state including algorithm state.""" + super().load_checkpoint(prefix, json_data, npz) + self.n_samples = json_data.get("n_samples", 0) + self._current_sample = json_data.get("_current_sample", 0.0) + if f"{prefix}/octave_values" in npz: + self.octave_values = npz[f"{prefix}/octave_values"] \ No newline at end of file diff --git a/src/pathsim/blocks/rng.py b/src/pathsim/blocks/rng.py index 974e181e..72824107 100644 --- a/src/pathsim/blocks/rng.py +++ b/src/pathsim/blocks/rng.py @@ -103,6 +103,7 @@ def to_checkpoint(self, prefix, recordings=False): json_data["_sample"] = float(self._sample) return json_data, npz_data + def load_checkpoint(self, prefix, json_data, npz): """Restore RNG state including current sample.""" super().load_checkpoint(prefix, json_data, npz) diff --git a/src/pathsim/blocks/switch.py b/src/pathsim/blocks/switch.py index 8ee707be..f89f28ca 100644 --- a/src/pathsim/blocks/switch.py +++ b/src/pathsim/blocks/switch.py @@ -87,6 +87,7 @@ def to_checkpoint(self, prefix, recordings=False): json_data["switch_state"] = self.switch_state return json_data, npz_data + def load_checkpoint(self, prefix, json_data, npz): super().load_checkpoint(prefix, json_data, npz) self.switch_state = json_data.get("switch_state", None) diff --git a/src/pathsim/events/_event.py b/src/pathsim/events/_event.py index 58f657b5..1ebd1a9e 100644 --- a/src/pathsim/events/_event.py +++ b/src/pathsim/events/_event.py @@ -11,8 +11,6 @@ import numpy as np -from uuid import uuid4 - from .. 
_constants import EVT_TOLERANCE @@ -66,9 +64,6 @@ def __init__( tolerance=EVT_TOLERANCE ): - #unique identifier for checkpointing and diagnostics - self.id = uuid4().hex - #event detection function self.func_evt = func_evt diff --git a/src/pathsim/subsystem.py b/src/pathsim/subsystem.py index b515bd34..cea9740c 100644 --- a/src/pathsim/subsystem.py +++ b/src/pathsim/subsystem.py @@ -462,6 +462,16 @@ def reset(self): block.reset() + @staticmethod + def _checkpoint_key(type_name, type_counts): + """Generate a deterministic checkpoint key from block/event type + and occurrence index (e.g. 'Integrator_0', 'Scope_1'). + """ + idx = type_counts.get(type_name, 0) + type_counts[type_name] = idx + 1 + return f"{type_name}_{idx}" + + def to_checkpoint(self, prefix, recordings=False): """Serialize subsystem state by recursively checkpointing internal blocks. @@ -494,10 +504,7 @@ def to_checkpoint(self, prefix, recordings=False): #checkpoint internal blocks by type + insertion order type_counts = {} for block in self.blocks: - type_name = block.__class__.__name__ - idx = type_counts.get(type_name, 0) - type_counts[type_name] = idx + 1 - key = f"{prefix}/{type_name}_{idx}" + key = f"{prefix}/{self._checkpoint_key(block.__class__.__name__, type_counts)}" b_json, b_npz = block.to_checkpoint(key, recordings=recordings) b_json["_key"] = key json_data["blocks"].append(b_json) @@ -545,10 +552,7 @@ def load_checkpoint(self, prefix, json_data, npz): block_data = {b["_key"]: b for b in json_data.get("blocks", [])} type_counts = {} for block in self.blocks: - type_name = block.__class__.__name__ - idx = type_counts.get(type_name, 0) - type_counts[type_name] = idx + 1 - key = f"{prefix}/{type_name}_{idx}" + key = f"{prefix}/{self._checkpoint_key(block.__class__.__name__, type_counts)}" if key in block_data: block.load_checkpoint(key, block_data[key], npz) diff --git a/tests/pathsim/test_checkpoint.py b/tests/pathsim/test_checkpoint.py index ac4df056..6a930198 100644 --- 
a/tests/pathsim/test_checkpoint.py +++ b/tests/pathsim/test_checkpoint.py @@ -277,10 +277,6 @@ def test_cross_instance_load(self): dt=0.01 ) - #UUIDs differ - assert src1.id != src2.id - assert integ1.id != integ2.id - sim2.load_checkpoint(path) assert sim2.time == saved_time assert np.allclose(integ2.state, saved_state) From 18e6ef10343f4d17c82612bf55787dd80252edea Mon Sep 17 00:00:00 2001 From: Milan Rother Date: Tue, 17 Mar 2026 10:21:46 +0100 Subject: [PATCH 21/29] Fix solver n restoration, add GEAR/Spectrum/Scope/event coverage tests --- src/pathsim/solvers/_solver.py | 1 + tests/pathsim/test_checkpoint.py | 223 +++++++++++++++++++++++++++++++ 2 files changed, 224 insertions(+) diff --git a/src/pathsim/solvers/_solver.py b/src/pathsim/solvers/_solver.py index d235856e..d10bf16e 100644 --- a/src/pathsim/solvers/_solver.py +++ b/src/pathsim/solvers/_solver.py @@ -403,6 +403,7 @@ def load_checkpoint(self, json_data, npz, prefix): """ self.x = npz[f"{prefix}/x"].copy() self.initial_value = npz[f"{prefix}/initial_value"].copy() + self.n = json_data.get("n", self.n) #restore scalar format if needed if self._scalar_initial and self.initial_value.size == 1: diff --git a/tests/pathsim/test_checkpoint.py b/tests/pathsim/test_checkpoint.py index 6a930198..db480d44 100644 --- a/tests/pathsim/test_checkpoint.py +++ b/tests/pathsim/test_checkpoint.py @@ -518,3 +518,226 @@ def test_subsystem_cross_instance(self): ) sim2.load_checkpoint(path) assert np.allclose(I2.state, state_before) + + +class TestGEARCheckpoint: + """Test GEAR solver checkpoint round-trip.""" + + def test_gear_solver_roundtrip(self): + """GEAR solver state survives checkpoint including BDF coefficients.""" + from pathsim.solvers import GEAR32 + + src = Source(lambda t: np.sin(2 * np.pi * t)) + integ = Integrator() + sim = Simulation( + blocks=[src, integ], + connections=[Connection(src, integ)], + dt=0.01, + Solver=GEAR32 + ) + + #run long enough for GEAR to exit startup phase + sim.run(0.5) + state_after 
= integ.state.copy() + time_after = sim.time + + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "cp") + sim.save_checkpoint(path) + + #reset state + integ.state = np.array([0.0]) + sim.time = 0.0 + + sim.load_checkpoint(path) + assert sim.time == time_after + assert np.allclose(integ.state, state_after) + + def test_gear_continue_after_load(self): + """GEAR simulation continues correctly after checkpoint load.""" + from pathsim.solvers import GEAR32 + + #reference: run 2s continuously + src1 = Source(lambda t: 1.0) + integ1 = Integrator() + sim1 = Simulation( + blocks=[src1, integ1], + connections=[Connection(src1, integ1)], + dt=0.01, + Solver=GEAR32 + ) + sim1.run(2.0) + reference = integ1.state.copy() + + #split: run 1s, save, load, run 1s more + src2 = Source(lambda t: 1.0) + integ2 = Integrator() + sim2 = Simulation( + blocks=[src2, integ2], + connections=[Connection(src2, integ2)], + dt=0.01, + Solver=GEAR32 + ) + sim2.run(1.0) + + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "cp") + sim2.save_checkpoint(path) + sim2.load_checkpoint(path) + sim2.run(1.0) + + assert np.allclose(integ2.state, reference, rtol=1e-6) + + +class TestSpectrumCheckpoint: + """Test Spectrum block checkpoint.""" + + def test_spectrum_roundtrip(self): + """Spectrum block state survives checkpoint round-trip.""" + from pathsim.blocks.spectrum import Spectrum + + src = Source(lambda t: np.sin(2 * np.pi * 10 * t)) + spec = Spectrum(freq=[5, 10, 15], t_wait=0.0) + sim = Simulation( + blocks=[src, spec], + connections=[Connection(src, spec)], + dt=0.001 + ) + sim.run(0.1) + + time_before = spec.time + t_sample_before = spec.t_sample + + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "cp") + sim.save_checkpoint(path) + + spec.time = 0.0 + spec.t_sample = 0.0 + + sim.load_checkpoint(path) + assert spec.time == pytest.approx(time_before) + assert spec.t_sample == pytest.approx(t_sample_before) + + +class 
TestScopeCheckpointExtended: + """Extended Scope checkpoint tests for coverage.""" + + def test_scope_with_sampling_period(self): + """Scope with sampling_period preserves _sample_next_timestep.""" + src = Source(lambda t: t) + scope = Scope(sampling_period=0.1) + sim = Simulation( + blocks=[src, scope], + connections=[Connection(src, scope)], + dt=0.01 + ) + sim.run(0.5) + + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "cp") + sim.save_checkpoint(path) + sim.load_checkpoint(path) + + #verify scope still works after load + sim.run(0.1) + assert len(scope.recording_time) > 0 + + def test_scope_recordings_roundtrip(self): + """Scope recording data round-trips with recordings=True.""" + src = Source(lambda t: t) + scope = Scope() + sim = Simulation( + blocks=[src, scope], + connections=[Connection(src, scope)], + dt=0.1 + ) + sim.run(1.0) + + rec_time = scope.recording_time.copy() + rec_data = [row.copy() for row in scope.recording_data] + + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "cp") + sim.save_checkpoint(path, recordings=True) + + #clear recordings + scope.recording_time = [] + scope.recording_data = [] + + sim.load_checkpoint(path) + assert len(scope.recording_time) == len(rec_time) + assert np.allclose(scope.recording_time, rec_time) + + +class TestSimulationCheckpointExtended: + """Extended simulation checkpoint tests for coverage.""" + + def test_save_load_with_extension(self): + """Path with .json extension is handled correctly.""" + src = Source(lambda t: 1.0) + integ = Integrator() + sim = Simulation( + blocks=[src, integ], + connections=[Connection(src, integ)], + dt=0.01 + ) + sim.run(0.1) + + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "cp.json") + sim.save_checkpoint(path) + + assert os.path.exists(os.path.join(tmpdir, "cp.json")) + assert os.path.exists(os.path.join(tmpdir, "cp.npz")) + + sim.load_checkpoint(path) + assert sim.time == pytest.approx(0.1, 
abs=0.01) + + def test_checkpoint_with_events(self): + """Simulation with external events checkpoints correctly.""" + from pathsim.events import Schedule + + src = Source(lambda t: 1.0) + integ = Integrator() + + event_fired = [False] + def on_event(t): + event_fired[0] = True + + evt = Schedule(t_start=0.5, t_period=1.0, func_act=on_event) + + sim = Simulation( + blocks=[src, integ], + connections=[Connection(src, integ)], + events=[evt], + dt=0.01 + ) + sim.run(1.0) + + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "cp") + sim.save_checkpoint(path) + + #verify events in JSON + with open(f"{path}.json") as f: + data = json.load(f) + assert len(data["events"]) == 1 + assert data["events"][0]["type"] == "Schedule" + + sim.load_checkpoint(path) + + def test_event_numpy_history(self): + """Event with numpy scalar in history serializes correctly.""" + from pathsim.events import ZeroCrossing + + e = ZeroCrossing(func_evt=lambda t: t - 1.0) + e._history = (np.float64(0.5), 0.99) + prefix = "ZeroCrossing_0" + + json_data, npz_data = e.to_checkpoint(prefix) + assert isinstance(json_data["history_eval"], float) + + e.reset() + e.load_checkpoint(prefix, json_data, npz_data) + assert e._history[0] == pytest.approx(0.5) From 1cb68470aacac59cc7130ef7b7852b086f63e803 Mon Sep 17 00:00:00 2001 From: Milan Rother Date: Tue, 17 Mar 2026 10:43:22 +0100 Subject: [PATCH 22/29] Add checkpoint example notebook for docs --- docs/source/examples/checkpoints.ipynb | 211 +++++++++++++++++++++++++ 1 file changed, 211 insertions(+) create mode 100644 docs/source/examples/checkpoints.ipynb diff --git a/docs/source/examples/checkpoints.ipynb b/docs/source/examples/checkpoints.ipynb new file mode 100644 index 00000000..e77eb6b8 --- /dev/null +++ b/docs/source/examples/checkpoints.ipynb @@ -0,0 +1,211 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Checkpoints\n", + "\n", + "PathSim supports saving and loading simulation state 
via checkpoints. This allows you to pause a simulation, save its complete state to disk, and resume it later from exactly where you left off.\n", + "\n", + "Checkpoints use a split format: a JSON file for metadata and structure, and an NPZ file for numerical data (block states, solver histories, etc.)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup\n", + "\n", + "We'll use a damped harmonic oscillator as our test system. First, let's run it continuously for 25 seconds as our reference." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "\n", + "from pathsim import Simulation, Connection\n", + "from pathsim.blocks import Integrator, Amplifier, Adder, Scope" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# System parameters\n", + "x0, v0 = 2, 5\n", + "m, c, k = 0.8, 0.2, 1.5\n", + "\n", + "def make_system():\n", + " \"\"\"Helper to create a fresh harmonic oscillator simulation.\"\"\"\n", + " I1 = Integrator(v0)\n", + " I2 = Integrator(x0)\n", + " A1 = Amplifier(c)\n", + " A2 = Amplifier(k)\n", + " A3 = Amplifier(-1/m)\n", + " P1 = Adder()\n", + " Sc = Scope(labels=[\"velocity\", \"position\"])\n", + "\n", + " blocks = [I1, I2, A1, A2, A3, P1, Sc]\n", + " connections = [\n", + " Connection(I1, I2, A1, Sc), \n", + " Connection(I2, A2, Sc[1]),\n", + " Connection(A1, P1), \n", + " Connection(A2, P1[1]), \n", + " Connection(P1, A3),\n", + " Connection(A3, I1)\n", + " ]\n", + "\n", + " sim = Simulation(blocks, connections, dt=0.01)\n", + " return sim, Sc" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Reference Run\n", + "\n", + "Run the full simulation continuously for 25 seconds. This is our ground truth." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sim_ref, scope_ref = make_system()\n", + "sim_ref.run(25)\n", + "\n", + "time_ref, data_ref = scope_ref.read()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Save Checkpoint\n", + "\n", + "Now let's run a second simulation, but only for the first 10 seconds. Then we save a checkpoint." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sim_a, scope_a = make_system()\n", + "sim_a.run(10)\n", + "\n", + "# Save checkpoint (creates checkpoint.json and checkpoint.npz)\n", + "sim_a.save_checkpoint(\"checkpoint\")\n", + "print(f\"Saved checkpoint at t = {sim_a.time:.1f}s\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load Checkpoint and Resume\n", + "\n", + "Create an entirely new simulation with fresh block objects, load the checkpoint, and continue for the remaining 15 seconds. The key point is that the new simulation has completely different Python objects, yet the checkpoint restores the exact state by matching blocks by type and insertion order." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sim_b, scope_b = make_system()\n", + "\n", + "# Load checkpoint into the fresh simulation\n", + "sim_b.load_checkpoint(\"checkpoint\")\n", + "print(f\"Resumed from t = {sim_b.time:.1f}s\")\n", + "\n", + "# Continue the simulation for the remaining 15 seconds\n", + "sim_b.run(15)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Compare Results\n", + "\n", + "Now let's overlay the reference (continuous run) with the checkpointed run (first 10s + resumed 15s). If checkpointing works correctly, they should be identical." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": "# Read data from both scopes\ntime_a, data_a = scope_a.read()\ntime_b, data_b = scope_b.read()\n\nfig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 5), sharex=True)\n\n# Position (channel 1)\nax1.plot(time_ref, data_ref[1], \"k-\", alpha=0.3, lw=3, label=\"reference (continuous)\")\nax1.plot(time_a, data_a[1], \"C0-\", label=\"first half (0-10s)\")\nax1.plot(time_b, data_b[1], \"C1--\", label=\"resumed (10-25s)\")\nax1.axvline(10, color=\"gray\", ls=\":\", alpha=0.5, label=\"checkpoint\")\nax1.set_ylabel(\"position\")\nax1.legend(loc=\"upper right\", fontsize=8)\n\n# Velocity (channel 0)\nax2.plot(time_ref, data_ref[0], \"k-\", alpha=0.3, lw=3, label=\"reference (continuous)\")\nax2.plot(time_a, data_a[0], \"C0-\", label=\"first half (0-10s)\")\nax2.plot(time_b, data_b[0], \"C1--\", label=\"resumed (10-25s)\")\nax2.axvline(10, color=\"gray\", ls=\":\", alpha=0.5)\nax2.set_ylabel(\"velocity\")\nax2.set_xlabel(\"time [s]\")\n\nfig.suptitle(\"Checkpoint Save / Load\")\nfig.tight_layout()\nplt.show()" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The resumed simulation (dashed) seamlessly continues the reference (gray), confirming that the complete simulation state was correctly saved and restored across different Python objects." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Checkpoint File Contents\n", + "\n", + "The JSON file contains human-readable metadata about the simulation state. Let's inspect it." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import json\n", + "\n", + "with open(\"checkpoint.json\") as f:\n", + " cp = json.load(f)\n", + "\n", + "print(f\"PathSim version: {cp['pathsim_version']}\")\n", + "print(f\"Simulation time: {cp['simulation']['time']:.1f}s\")\n", + "print(f\"Solver: {cp['simulation']['solver']}\")\n", + "print(f\"Blocks saved: {len(cp['blocks'])}\")\n", + "for b in cp[\"blocks\"]:\n", + " print(f\" {b['_key']} ({b['type']})\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Blocks are matched by type and insertion order (`Integrator_0`, `Integrator_1`, etc.), which means the checkpoint can be loaded into any simulation with the same block structure, regardless of the specific Python objects." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.12.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} \ No newline at end of file From 456e7cef3c14cb7a2294d15c48a23f44e5a49a44 Mon Sep 17 00:00:00 2001 From: Milan Rother Date: Tue, 17 Mar 2026 10:55:55 +0100 Subject: [PATCH 23/29] Rework checkpoint notebook: driven oscillator with rollback scenarios --- docs/source/examples/checkpoints.ipynb | 100 +++---------------------- 1 file changed, 12 insertions(+), 88 deletions(-) diff --git a/docs/source/examples/checkpoints.ipynb b/docs/source/examples/checkpoints.ipynb index e77eb6b8..aaa86177 100644 --- a/docs/source/examples/checkpoints.ipynb +++ b/docs/source/examples/checkpoints.ipynb @@ -3,22 +3,12 @@ { "cell_type": "markdown", "metadata": {}, - "source": [ - "# Checkpoints\n", - "\n", - "PathSim supports saving and loading simulation state via checkpoints. 
This allows you to pause a simulation, save its complete state to disk, and resume it later from exactly where you left off.\n", - "\n", - "Checkpoints use a split format: a JSON file for metadata and structure, and an NPZ file for numerical data (block states, solver histories, etc.)." - ] + "source": "# Checkpoints\n\nPathSim supports saving and loading simulation state via checkpoints. This allows you to pause a simulation, save its complete state to disk, and resume it later from exactly where you left off. \n\nCheckpoints also enable **rollback**, where you return to a saved state and explore different what-if scenarios by changing parameters.\n\nCheckpoints use a split format: a JSON file for metadata and structure, and an NPZ file for numerical data (block states, solver histories, etc.)." }, { "cell_type": "markdown", "metadata": {}, - "source": [ - "## Setup\n", - "\n", - "We'll use a damped harmonic oscillator as our test system. First, let's run it continuously for 25 seconds as our reference." - ] + "source": "## Setup\n\nWe'll simulate a driven harmonic oscillator — a mass-spring system excited by an external sinusoidal force. The system produces a sustained periodic response, making it easy to visually verify that checkpoints preserve continuity." 
}, { "cell_type": "code", @@ -38,126 +28,60 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": [ - "# System parameters\n", - "x0, v0 = 2, 5\n", - "m, c, k = 0.8, 0.2, 1.5\n", - "\n", - "def make_system():\n", - " \"\"\"Helper to create a fresh harmonic oscillator simulation.\"\"\"\n", - " I1 = Integrator(v0)\n", - " I2 = Integrator(x0)\n", - " A1 = Amplifier(c)\n", - " A2 = Amplifier(k)\n", - " A3 = Amplifier(-1/m)\n", - " P1 = Adder()\n", - " Sc = Scope(labels=[\"velocity\", \"position\"])\n", - "\n", - " blocks = [I1, I2, A1, A2, A3, P1, Sc]\n", - " connections = [\n", - " Connection(I1, I2, A1, Sc), \n", - " Connection(I2, A2, Sc[1]),\n", - " Connection(A1, P1), \n", - " Connection(A2, P1[1]), \n", - " Connection(P1, A3),\n", - " Connection(A3, I1)\n", - " ]\n", - "\n", - " sim = Simulation(blocks, connections, dt=0.01)\n", - " return sim, Sc" - ] + "source": "import numpy as np\n\n# System parameters\nm = 1.0 # mass\nc = 0.1 # light damping\nk = 4.0 # spring stiffness\nF0 = 1.0 # forcing amplitude\nw = 1.8 # forcing frequency (near resonance for k/m=4 -> w0=2)\n\ndef make_system(damping=c, stiffness=k):\n \"\"\"Create a driven harmonic oscillator with configurable parameters.\"\"\"\n from pathsim.blocks import Source, Integrator, Amplifier, Adder, Scope\n\n Src = Source(lambda t: F0/m * np.sin(w * t)) # external acceleration\n I1 = Integrator(0.0) # velocity\n I2 = Integrator(0.5) # position (start displaced)\n Ac = Amplifier(-damping/m)\n Ak = Amplifier(-stiffness/m)\n P1 = Adder()\n Sc = Scope(labels=[\"position\"])\n\n blocks = [Src, I1, I2, Ac, Ak, P1, Sc]\n connections = [\n Connection(I1, I2, Ac), # velocity -> position integrator, damper\n Connection(I2, Ak, Sc), # position -> spring, scope\n Connection(Ac, P1), # -c/m * v -> adder\n Connection(Ak, P1[1]), # -k/m * x -> adder\n Connection(Src, P1[2]), # F/m -> adder\n Connection(P1, I1), # acceleration -> velocity integrator\n ]\n\n sim = Simulation(blocks, connections, dt=0.01)\n 
return sim, Sc" }, { "cell_type": "markdown", "metadata": {}, - "source": [ - "## Reference Run\n", - "\n", - "Run the full simulation continuously for 25 seconds. This is our ground truth." - ] + "source": "## Save Checkpoint\n\nRun the simulation for 20 seconds, then save a checkpoint. The system will be in a sustained oscillation by this point." }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], - "source": [ - "sim_ref, scope_ref = make_system()\n", - "sim_ref.run(25)\n", - "\n", - "time_ref, data_ref = scope_ref.read()" - ] + "source": "sim, scope = make_system()\nsim.run(20)\n\n# Save checkpoint\nsim.save_checkpoint(\"checkpoint\")\nprint(f\"Saved checkpoint at t = {sim.time:.1f}s\")" }, { "cell_type": "markdown", "metadata": {}, - "source": [ - "## Save Checkpoint\n", - "\n", - "Now let's run a second simulation, but only for the first 10 seconds. Then we save a checkpoint." - ] + "source": "## Resume from Checkpoint\n\nLoad the checkpoint into a fresh simulation and continue for another 20 seconds. The new simulation has completely different Python objects, yet the checkpoint restores the exact state by matching blocks by type and insertion order." }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], - "source": [ - "sim_a, scope_a = make_system()\n", - "sim_a.run(10)\n", - "\n", - "# Save checkpoint (creates checkpoint.json and checkpoint.npz)\n", - "sim_a.save_checkpoint(\"checkpoint\")\n", - "print(f\"Saved checkpoint at t = {sim_a.time:.1f}s\")" - ] + "source": "sim_resumed, scope_resumed = make_system()\nsim_resumed.load_checkpoint(\"checkpoint\")\nprint(f\"Resumed from t = {sim_resumed.time:.1f}s\")\n\nsim_resumed.run(20)" }, { "cell_type": "markdown", "metadata": {}, - "source": [ - "## Load Checkpoint and Resume\n", - "\n", - "Create an entirely new simulation with fresh block objects, load the checkpoint, and continue for the remaining 15 seconds. 
The key point is that the new simulation has completely different Python objects, yet the checkpoint restores the exact state by matching blocks by type and insertion order." - ] + "source": "## Rollback: What-If Scenarios\n\nThis is where checkpoints really shine. We reload the same checkpoint but with **different parameters** — increasing the damping significantly. Both branches start from the exact same state at t=20, but evolve differently." }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], - "source": [ - "sim_b, scope_b = make_system()\n", - "\n", - "# Load checkpoint into the fresh simulation\n", - "sim_b.load_checkpoint(\"checkpoint\")\n", - "print(f\"Resumed from t = {sim_b.time:.1f}s\")\n", - "\n", - "# Continue the simulation for the remaining 15 seconds\n", - "sim_b.run(15)" - ] + "source": "# Scenario A: same parameters (continuation)\nsim_a, scope_a = make_system(damping=0.1)\nsim_a.load_checkpoint(\"checkpoint\")\nsim_a.run(20)\n\n# Scenario B: increased damping (what-if)\nsim_b, scope_b = make_system(damping=1.5)\nsim_b.load_checkpoint(\"checkpoint\")\nsim_b.run(20)\n\n# Scenario C: stiffer spring (what-if)\nsim_c, scope_c = make_system(stiffness=9.0)\nsim_c.load_checkpoint(\"checkpoint\")\nsim_c.run(20)" }, { "cell_type": "markdown", "metadata": {}, - "source": [ - "## Compare Results\n", - "\n", - "Now let's overlay the reference (continuous run) with the checkpointed run (first 10s + resumed 15s). If checkpointing works correctly, they should be identical." - ] + "source": "## Compare Results\n\nThe plot shows the original simulation (0–20s), followed by three different futures branching from the same checkpoint." 
}, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], - "source": "# Read data from both scopes\ntime_a, data_a = scope_a.read()\ntime_b, data_b = scope_b.read()\n\nfig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 5), sharex=True)\n\n# Position (channel 1)\nax1.plot(time_ref, data_ref[1], \"k-\", alpha=0.3, lw=3, label=\"reference (continuous)\")\nax1.plot(time_a, data_a[1], \"C0-\", label=\"first half (0-10s)\")\nax1.plot(time_b, data_b[1], \"C1--\", label=\"resumed (10-25s)\")\nax1.axvline(10, color=\"gray\", ls=\":\", alpha=0.5, label=\"checkpoint\")\nax1.set_ylabel(\"position\")\nax1.legend(loc=\"upper right\", fontsize=8)\n\n# Velocity (channel 0)\nax2.plot(time_ref, data_ref[0], \"k-\", alpha=0.3, lw=3, label=\"reference (continuous)\")\nax2.plot(time_a, data_a[0], \"C0-\", label=\"first half (0-10s)\")\nax2.plot(time_b, data_b[0], \"C1--\", label=\"resumed (10-25s)\")\nax2.axvline(10, color=\"gray\", ls=\":\", alpha=0.5)\nax2.set_ylabel(\"velocity\")\nax2.set_xlabel(\"time [s]\")\n\nfig.suptitle(\"Checkpoint Save / Load\")\nfig.tight_layout()\nplt.show()" + "source": "time_orig, data_orig = scope.read()\ntime_a, data_a = scope_a.read()\ntime_b, data_b = scope_b.read()\ntime_c, data_c = scope_c.read()\n\nfig, ax = plt.subplots(figsize=(10, 4))\n\n# Original run (0-20s)\nax.plot(time_orig, data_orig[0], \"k-\", lw=1.5, label=\"original (c=0.1, k=4)\")\n\n# Three futures from checkpoint\nax.plot(time_a, data_a[0], \"C0-\", alpha=0.8, label=\"resumed (c=0.1, k=4)\")\nax.plot(time_b, data_b[0], \"C1-\", alpha=0.8, label=\"what-if: heavy damping (c=1.5)\")\nax.plot(time_c, data_c[0], \"C2-\", alpha=0.8, label=\"what-if: stiffer spring (k=9)\")\n\nax.axvline(20, color=\"gray\", ls=\":\", alpha=0.5, lw=2, label=\"checkpoint (t=20s)\")\nax.set_xlabel(\"time [s]\")\nax.set_ylabel(\"position\")\nax.set_title(\"Checkpoint Rollback: Three Futures from the Same State\")\nax.legend(loc=\"upper right\", 
fontsize=8)\nfig.tight_layout()\nplt.show()" }, { "cell_type": "markdown", "metadata": {}, - "source": [ - "The resumed simulation (dashed) seamlessly continues the reference (gray), confirming that the complete simulation state was correctly saved and restored across different Python objects." - ] + "source": "All three scenarios start from the exact same state at t=20s. The blue continuation matches the original trajectory perfectly, while the heavy damping scenario (orange) decays rapidly and the stiffer spring scenario (green) shifts to a higher natural frequency." }, { "cell_type": "markdown", From c88fc29ef9e391e1baa2d8c3acc1eebcf1f255c3 Mon Sep 17 00:00:00 2001 From: Milan Rother Date: Tue, 17 Mar 2026 10:59:27 +0100 Subject: [PATCH 24/29] Rewrite checkpoint notebook: coupled oscillators, flat style, rollback demo --- docs/source/examples/checkpoints.ipynb | 229 ++++++++++++++++++++++--- 1 file changed, 201 insertions(+), 28 deletions(-) diff --git a/docs/source/examples/checkpoints.ipynb b/docs/source/examples/checkpoints.ipynb index aaa86177..7d7f4eb7 100644 --- a/docs/source/examples/checkpoints.ipynb +++ b/docs/source/examples/checkpoints.ipynb @@ -3,12 +3,31 @@ { "cell_type": "markdown", "metadata": {}, - "source": "# Checkpoints\n\nPathSim supports saving and loading simulation state via checkpoints. This allows you to pause a simulation, save its complete state to disk, and resume it later from exactly where you left off. \n\nCheckpoints also enable **rollback**, where you return to a saved state and explore different what-if scenarios by changing parameters.\n\nCheckpoints use a split format: a JSON file for metadata and structure, and an NPZ file for numerical data (block states, solver histories, etc.)." + "source": [ + "# Checkpoints\n", + "\n", + "PathSim supports saving and loading simulation state via checkpoints. 
This allows you to pause a simulation, save its complete state to disk, and resume it later from exactly where you left off.\n", + "\n", + "Checkpoints also enable **rollback** — returning to a saved state and exploring different what-if scenarios by changing parameters.\n", + "\n", + "Checkpoints use a split format: a JSON file for metadata and structure, and an NPZ file for numerical data (block states, solver histories, etc.)." + ] }, { "cell_type": "markdown", "metadata": {}, - "source": "## Setup\n\nWe'll simulate a driven harmonic oscillator — a mass-spring system excited by an external sinusoidal force. The system produces a sustained periodic response, making it easy to visually verify that checkpoints preserve continuity." + "source": [ + "## Building the System\n", + "\n", + "We'll use the coupled oscillators system to demonstrate checkpoints. The energy exchange between the two oscillators produces a sustained, non-trivial response that makes it easy to visually verify checkpoint continuity." 
+ ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "First let's import the :class:`.Simulation` and :class:`.Connection` classes and the required blocks:" + ] }, { "cell_type": "code", @@ -20,7 +39,14 @@ "import matplotlib.pyplot as plt\n", "\n", "from pathsim import Simulation, Connection\n", - "from pathsim.blocks import Integrator, Amplifier, Adder, Scope" + "from pathsim.blocks import ODE, Function, Scope" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Define the system parameters:" ] }, { @@ -28,68 +54,207 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": "import numpy as np\n\n# System parameters\nm = 1.0 # mass\nc = 0.1 # light damping\nk = 4.0 # spring stiffness\nF0 = 1.0 # forcing amplitude\nw = 1.8 # forcing frequency (near resonance for k/m=4 -> w0=2)\n\ndef make_system(damping=c, stiffness=k):\n \"\"\"Create a driven harmonic oscillator with configurable parameters.\"\"\"\n from pathsim.blocks import Source, Integrator, Amplifier, Adder, Scope\n\n Src = Source(lambda t: F0/m * np.sin(w * t)) # external acceleration\n I1 = Integrator(0.0) # velocity\n I2 = Integrator(0.5) # position (start displaced)\n Ac = Amplifier(-damping/m)\n Ak = Amplifier(-stiffness/m)\n P1 = Adder()\n Sc = Scope(labels=[\"position\"])\n\n blocks = [Src, I1, I2, Ac, Ak, P1, Sc]\n connections = [\n Connection(I1, I2, Ac), # velocity -> position integrator, damper\n Connection(I2, Ak, Sc), # position -> spring, scope\n Connection(Ac, P1), # -c/m * v -> adder\n Connection(Ak, P1[1]), # -k/m * x -> adder\n Connection(Src, P1[2]), # F/m -> adder\n Connection(P1, I1), # acceleration -> velocity integrator\n ]\n\n sim = Simulation(blocks, connections, dt=0.01)\n return sim, Sc" + "source": [ + "# Mass parameters\n", + "m1 = 1.0\n", + "m2 = 1.5\n", + "\n", + "# Spring constants\n", + "k1 = 2.0\n", + "k2 = 3.0\n", + "k12 = 0.5 # coupling spring\n", + "\n", + "# Damping coefficients\n", + "c1 = 0.02\n", + "c2 = 0.03\n", + 
"\n", + "# Initial conditions [position, velocity]\n", + "x1_0 = np.array([2.0, 0.0]) # oscillator 1 displaced\n", + "x2_0 = np.array([0.0, 0.0]) # oscillator 2 at rest" + ] }, { - "cell_type": "markdown", + "cell_type": "raw", "metadata": {}, - "source": "## Save Checkpoint\n\nRun the simulation for 20 seconds, then save a checkpoint. The system will be in a sustained oscillation by this point." + "source": [ + "Define the differential equations for each oscillator using :class:`.ODE` blocks and the coupling force using a :class:`.Function` block:" + ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], - "source": "sim, scope = make_system()\nsim.run(20)\n\n# Save checkpoint\nsim.save_checkpoint(\"checkpoint\")\nprint(f\"Saved checkpoint at t = {sim.time:.1f}s\")" + "source": [ + "# Oscillator 1: m1*x1'' = -k1*x1 - c1*x1' - k12*(x1 - x2)\n", + "def osc1_func(x1, u, t):\n", + " f_e = u[0]\n", + " return np.array([x1[1], (-k1*x1[0] - c1*x1[1] - f_e) / m1])\n", + "\n", + "# Oscillator 2: m2*x2'' = -k2*x2 - c2*x2' + k12*(x1 - x2)\n", + "def osc2_func(x2, u, t):\n", + " f_e = u[0]\n", + " return np.array([x2[1], (-k2*x2[0] - c2*x2[1] - f_e) / m2])\n", + "\n", + "# Coupling force\n", + "def coupling_func(x1, x2):\n", + " f = k12 * (x1 - x2)\n", + " return f, -f\n", + "\n", + "# Blocks\n", + "osc1 = ODE(osc1_func, x1_0)\n", + "osc2 = ODE(osc2_func, x2_0)\n", + "fn = Function(coupling_func)\n", + "sc = Scope(labels=[r\"$x_1(t)$ - Oscillator 1\", r\"$x_2(t)$ - Oscillator 2\"])\n", + "\n", + "blocks = [osc1, osc2, fn, sc]\n", + "\n", + "# Connections\n", + "connections = [\n", + " Connection(osc1[0], fn[0], sc[0]),\n", + " Connection(osc2[0], fn[1], sc[1]),\n", + " Connection(fn[0], osc1[0]),\n", + " Connection(fn[1], osc2[0]),\n", + "]" + ] }, { - "cell_type": "markdown", + "cell_type": "raw", "metadata": {}, - "source": "## Resume from Checkpoint\n\nLoad the checkpoint into a fresh simulation and continue for another 20 seconds. 
The new simulation has completely different Python objects, yet the checkpoint restores the exact state by matching blocks by type and insertion order." + "source": [ + "Create the :class:`.Simulation` and run for 60 seconds:" + ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], - "source": "sim_resumed, scope_resumed = make_system()\nsim_resumed.load_checkpoint(\"checkpoint\")\nprint(f\"Resumed from t = {sim_resumed.time:.1f}s\")\n\nsim_resumed.run(20)" + "source": [ + "sim = Simulation(blocks, connections, dt=0.01)\n", + "\n", + "sim.run(60)\n", + "\n", + "fig, ax = sc.plot()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The two oscillators exchange energy through the coupling spring, producing a characteristic beat pattern." + ] }, { "cell_type": "markdown", "metadata": {}, - "source": "## Rollback: What-If Scenarios\n\nThis is where checkpoints really shine. We reload the same checkpoint but with **different parameters** — increasing the damping significantly. Both branches start from the exact same state at t=20, but evolve differently." + "source": [ + "## Saving a Checkpoint\n", + "\n", + "Now let's save the simulation state at t=60s. This creates two files: `coupled.json` (metadata) and `coupled.npz` (numerical data)." 
+ ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], - "source": "# Scenario A: same parameters (continuation)\nsim_a, scope_a = make_system(damping=0.1)\nsim_a.load_checkpoint(\"checkpoint\")\nsim_a.run(20)\n\n# Scenario B: increased damping (what-if)\nsim_b, scope_b = make_system(damping=1.5)\nsim_b.load_checkpoint(\"checkpoint\")\nsim_b.run(20)\n\n# Scenario C: stiffer spring (what-if)\nsim_c, scope_c = make_system(stiffness=9.0)\nsim_c.load_checkpoint(\"checkpoint\")\nsim_c.run(20)" + "source": [ + "sim.save_checkpoint(\"coupled\")\n", + "print(f\"Checkpoint saved at t = {sim.time:.1f}s\")" + ] }, { "cell_type": "markdown", "metadata": {}, - "source": "## Compare Results\n\nThe plot shows the original simulation (0–20s), followed by three different futures branching from the same checkpoint." + "source": [ + "We can inspect the JSON file to see what was saved:" + ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], - "source": "time_orig, data_orig = scope.read()\ntime_a, data_a = scope_a.read()\ntime_b, data_b = scope_b.read()\ntime_c, data_c = scope_c.read()\n\nfig, ax = plt.subplots(figsize=(10, 4))\n\n# Original run (0-20s)\nax.plot(time_orig, data_orig[0], \"k-\", lw=1.5, label=\"original (c=0.1, k=4)\")\n\n# Three futures from checkpoint\nax.plot(time_a, data_a[0], \"C0-\", alpha=0.8, label=\"resumed (c=0.1, k=4)\")\nax.plot(time_b, data_b[0], \"C1-\", alpha=0.8, label=\"what-if: heavy damping (c=1.5)\")\nax.plot(time_c, data_c[0], \"C2-\", alpha=0.8, label=\"what-if: stiffer spring (k=9)\")\n\nax.axvline(20, color=\"gray\", ls=\":\", alpha=0.5, lw=2, label=\"checkpoint (t=20s)\")\nax.set_xlabel(\"time [s]\")\nax.set_ylabel(\"position\")\nax.set_title(\"Checkpoint Rollback: Three Futures from the Same State\")\nax.legend(loc=\"upper right\", fontsize=8)\nfig.tight_layout()\nplt.show()" + "source": [ + "import json\n", + "\n", + "with open(\"coupled.json\") as f:\n", + " cp = json.load(f)\n", 
+ "\n", + "print(f\"PathSim version: {cp['pathsim_version']}\")\n", + "print(f\"Simulation time: {cp['simulation']['time']:.1f}s\")\n", + "print(f\"Solver: {cp['simulation']['solver']}\")\n", + "print(f\"Blocks saved:\")\n", + "for b in cp[\"blocks\"]:\n", + " print(f\" {b['_key']} ({b['type']})\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Blocks are identified by type and insertion order (``ODE_0``, ``ODE_1``, etc.), so the checkpoint can be loaded into any simulation with the same block structure, regardless of the specific Python objects." + ] }, { "cell_type": "markdown", "metadata": {}, - "source": "All three scenarios start from the exact same state at t=20s. The blue continuation matches the original trajectory perfectly, while the heavy damping scenario (orange) decays rapidly and the stiffer spring scenario (green) shifts to a higher natural frequency." + "source": [ + "## Rollback: What-If Scenarios\n", + "\n", + "This is where checkpoints really shine. We'll load the same checkpoint three times with different coupling strengths to explore how the system evolves from the exact same state.\n", + "\n", + "Since the checkpoint restores all block states by type and insertion order, we just need to rebuild the simulation with the same block structure but different parameters." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def run_scenario(k12_new, duration=60):\n", + " \"\"\"Load checkpoint and continue with a different coupling constant.\"\"\"\n", + " def coupling_new(x1, x2):\n", + " f = k12_new * (x1 - x2)\n", + " return f, -f\n", + "\n", + " o1 = ODE(osc1_func, x1_0)\n", + " o2 = ODE(osc2_func, x2_0)\n", + " f = Function(coupling_new)\n", + " s = Scope()\n", + "\n", + " sim = Simulation(\n", + " [o1, o2, f, s],\n", + " [Connection(o1[0], f[0], s[0]),\n", + " Connection(o2[0], f[1], s[1]),\n", + " Connection(f[0], o1[0]),\n", + " Connection(f[1], o2[0])],\n", + " dt=0.01\n", + " )\n", + " sim.load_checkpoint(\"coupled\")\n", + " sim.run(duration)\n", + " return s.read()\n", + "\n", + "# Original coupling (continuation)\n", + "t_a, d_a = run_scenario(k12_new=0.5)\n", + "\n", + "# Stronger coupling\n", + "t_b, d_b = run_scenario(k12_new=2.0)\n", + "\n", + "# Decoupled\n", + "t_c, d_c = run_scenario(k12_new=0.0)" + ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Checkpoint File Contents\n", + "## Comparing the Scenarios\n", "\n", - "The JSON file contains human-readable metadata about the simulation state. Let's inspect it." + "The plot shows the original run (0-60s) followed by three different futures branching from the checkpoint at t=60s. We show oscillator 1 for clarity." 
] }, { @@ -98,24 +263,32 @@ "metadata": {}, "outputs": [], "source": [ - "import json\n", + "time_orig, data_orig = sc.read()\n", "\n", - "with open(\"checkpoint.json\") as f:\n", - " cp = json.load(f)\n", + "fig, ax = plt.subplots(figsize=(10, 4))\n", "\n", - "print(f\"PathSim version: {cp['pathsim_version']}\")\n", - "print(f\"Simulation time: {cp['simulation']['time']:.1f}s\")\n", - "print(f\"Solver: {cp['simulation']['solver']}\")\n", - "print(f\"Blocks saved: {len(cp['blocks'])}\")\n", - "for b in cp[\"blocks\"]:\n", - " print(f\" {b['_key']} ({b['type']})\")" + "# Original run\n", + "ax.plot(time_orig, data_orig[0], \"k-\", lw=1.5, label=r\"original ($k_{12}=0.5$)\")\n", + "\n", + "# Three futures from checkpoint\n", + "ax.plot(t_a, d_a[0], \"C0-\", alpha=0.8, label=r\"continued ($k_{12}=0.5$)\")\n", + "ax.plot(t_b, d_b[0], \"C1-\", alpha=0.8, label=r\"stronger coupling ($k_{12}=2.0$)\")\n", + "ax.plot(t_c, d_c[0], \"C2-\", alpha=0.8, label=r\"decoupled ($k_{12}=0$)\")\n", + "\n", + "ax.axvline(60, color=\"gray\", ls=\":\", lw=2, alpha=0.5, label=\"checkpoint\")\n", + "ax.set_xlabel(\"time [s]\")\n", + "ax.set_ylabel(r\"$x_1(t)$\")\n", + "ax.set_title(\"Checkpoint Rollback: Three Futures from the Same State\")\n", + "ax.legend(loc=\"upper right\", fontsize=8)\n", + "fig.tight_layout()\n", + "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Blocks are matched by type and insertion order (`Integrator_0`, `Integrator_1`, etc.), which means the checkpoint can be loaded into any simulation with the same block structure, regardless of the specific Python objects." + "All three scenarios start from the exact same state at t=60s. The blue continuation matches the original trajectory perfectly, confirming checkpoint fidelity. The stronger coupling (orange) produces faster energy exchange, while the decoupled system (green) oscillates independently at its natural frequency." 
] } ], @@ -132,4 +305,4 @@ }, "nbformat": 4, "nbformat_minor": 4 -} \ No newline at end of file +} From 81f2cff62196a3b9ae0372fc9b88b575da41c2d5 Mon Sep 17 00:00:00 2001 From: Milan Rother Date: Tue, 17 Mar 2026 11:14:34 +0100 Subject: [PATCH 25/29] Include scope recordings in checkpoints by default --- src/pathsim/simulation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/pathsim/simulation.py b/src/pathsim/simulation.py index 64402bd7..ed50a5ae 100644 --- a/src/pathsim/simulation.py +++ b/src/pathsim/simulation.py @@ -363,7 +363,7 @@ class name of the block or event return f"{type_name}_{idx}" - def save_checkpoint(self, path, recordings=False): + def save_checkpoint(self, path, recordings=True): """Save simulation state to checkpoint files (JSON + NPZ). Creates two files: {path}.json (structure/metadata) and @@ -375,7 +375,7 @@ def save_checkpoint(self, path, recordings=False): path : str base path without extension recordings : bool - include scope/spectrum recording data (default: False) + include scope/spectrum recording data (default: True) """ #strip extension if provided if path.endswith('.json') or path.endswith('.npz'): From 64efb9168ae5eacd34bf0890ee515467b22c82fc Mon Sep 17 00:00:00 2001 From: Milan Rother Date: Tue, 17 Mar 2026 11:18:35 +0100 Subject: [PATCH 26/29] Add test verifying recordings are included by default --- tests/pathsim/test_checkpoint.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/tests/pathsim/test_checkpoint.py b/tests/pathsim/test_checkpoint.py index db480d44..94fed3d6 100644 --- a/tests/pathsim/test_checkpoint.py +++ b/tests/pathsim/test_checkpoint.py @@ -669,6 +669,32 @@ def test_scope_recordings_roundtrip(self): assert len(scope.recording_time) == len(rec_time) assert np.allclose(scope.recording_time, rec_time) + def test_scope_recordings_included_by_default(self): + """Default save_checkpoint includes recordings.""" + src = Source(lambda t: t) + scope = Scope() + 
sim = Simulation( + blocks=[src, scope], + connections=[Connection(src, scope)], + dt=0.1 + ) + sim.run(1.0) + + rec_time = scope.recording_time.copy() + assert len(rec_time) > 0 + + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "cp") + sim.save_checkpoint(path) # no recordings kwarg — default + + #clear recordings + scope.recording_time = [] + scope.recording_data = [] + + sim.load_checkpoint(path) + assert len(scope.recording_time) == len(rec_time) + assert np.allclose(scope.recording_time, rec_time) + class TestSimulationCheckpointExtended: """Extended simulation checkpoint tests for coverage.""" From 9fe1d2617a21a8016d7a64ca5ad571caf17ae0e1 Mon Sep 17 00:00:00 2001 From: Milan Rother Date: Thu, 19 Mar 2026 11:36:49 +0100 Subject: [PATCH 27/29] Drop shadow sets, use plain lists for blocks/connections/events --- src/pathsim/simulation.py | 43 ++++++++++++--------------------------- src/pathsim/subsystem.py | 23 ++++++++------------- 2 files changed, 21 insertions(+), 45 deletions(-) diff --git a/src/pathsim/simulation.py b/src/pathsim/simulation.py index ed50a5ae..6f6dc9ed 100644 --- a/src/pathsim/simulation.py +++ b/src/pathsim/simulation.py @@ -179,13 +179,10 @@ def __init__( **solver_kwargs ): - #system definition (ordered lists with shadow sets for O(1) lookup) + #system definition self.blocks = [] - self._block_set = set() self.connections = [] - self._conn_set = set() self.events = [] - self._event_set = set() #simulation timestep and bounds self.dt = dt @@ -222,11 +219,9 @@ def __init__( #collection of blocks with internal ODE solvers self._blocks_dyn = [] - self._blocks_dyn_set = set() #collection of blocks with internal events self._blocks_evt = [] - self._blocks_evt_set = set() #flag for setting the simulation active self._active = True @@ -277,9 +272,9 @@ def __contains__(self, other): bool """ return ( - other in self._block_set or - other in self._conn_set or - other in self._event_set + other in self.blocks or + other 
in self.connections or + other in self.events ) @@ -519,7 +514,7 @@ def add_block(self, block): """ #check if block already in block list - if block in self._block_set: + if block in self.blocks: _msg = f"block {block} already part of simulation" self.logger.error(_msg) raise ValueError(_msg) @@ -530,16 +525,13 @@ def add_block(self, block): #add to dynamic list if solver was initialized if block.engine: self._blocks_dyn.append(block) - self._blocks_dyn_set.add(block) #add to eventful list if internal events if block.events: self._blocks_evt.append(block) - self._blocks_evt_set.add(block) #add block to global blocklist self.blocks.append(block) - self._block_set.add(block) #mark graph for rebuild if self.graph: @@ -559,24 +551,21 @@ def remove_block(self, block): """ #check if block is in block list - if block not in self._block_set: + if block not in self.blocks: _msg = f"block {block} not part of simulation" self.logger.error(_msg) raise ValueError(_msg) #remove from global blocklist self.blocks.remove(block) - self._block_set.discard(block) #remove from dynamic list - if block in self._blocks_dyn_set: + if block in self._blocks_dyn: self._blocks_dyn.remove(block) - self._blocks_dyn_set.discard(block) #remove from eventful list - if block in self._blocks_evt_set: + if block in self._blocks_evt: self._blocks_evt.remove(block) - self._blocks_evt_set.discard(block) #mark graph for rebuild if self.graph: @@ -596,14 +585,13 @@ def add_connection(self, connection): """ #check if connection already in connection list - if connection in self._conn_set: + if connection in self.connections: _msg = f"{connection} already part of simulation" self.logger.error(_msg) raise ValueError(_msg) #add connection to global connection list self.connections.append(connection) - self._conn_set.add(connection) #mark graph for rebuild if self.graph: @@ -623,14 +611,13 @@ def remove_connection(self, connection): """ #check if connection is in connection list - if connection not in 
self._conn_set: + if connection not in self.connections: _msg = f"{connection} not part of simulation" self.logger.error(_msg) raise ValueError(_msg) #remove from global connection list self.connections.remove(connection) - self._conn_set.discard(connection) #mark graph for rebuild if self.graph: @@ -649,14 +636,13 @@ def add_event(self, event): """ #check if event already in event list - if event in self._event_set: + if event in self.events: _msg = f"{event} already part of simulation" self.logger.error(_msg) raise ValueError(_msg) #add event to global event list self.events.append(event) - self._event_set.add(event) def remove_event(self, event): @@ -671,14 +657,13 @@ def remove_event(self, event): """ #check if event is in event list - if event not in self._event_set: + if event not in self.events: _msg = f"{event} not part of simulation" self.logger.error(_msg) raise ValueError(_msg) #remove from global event list self.events.remove(event) - self._event_set.discard(event) # system assembly ------------------------------------------------------------- @@ -737,7 +722,7 @@ def _check_blocks_are_managed(self): # Check subset actively managed for blk in conn_blocks: - if blk not in self._block_set: + if blk not in self.blocks: self.logger.warning( f"{blk} in 'connections' but not in 'blocks'!" 
) @@ -772,14 +757,12 @@ def _set_solver(self, Solver=None, **solver_kwargs): #iterate all blocks and set integration engines with tolerances self._blocks_dyn = [] - self._blocks_dyn_set = set() for block in self.blocks: block.set_solver(self.Solver, self.engine, **self.solver_kwargs) #add dynamic blocks to list if block.engine: self._blocks_dyn.append(block) - self._blocks_dyn_set.add(block) #logging message self.logger.info( diff --git a/src/pathsim/subsystem.py b/src/pathsim/subsystem.py index cea9740c..6dec9691 100644 --- a/src/pathsim/subsystem.py +++ b/src/pathsim/subsystem.py @@ -181,14 +181,12 @@ def __init__(self, #internal algebraic loop solvers -> initialized later self.boosters = None - #internal connecions (ordered list with shadow set for O(1) lookup) + #internal connecions self.connections = list(connections) if connections else [] - self._conn_set = set(self.connections) #collect and organize internal blocks - self.blocks = [] - self._block_set = set() - self.interface = None + self.blocks = [] + self.interface = None if blocks: for block in blocks: @@ -202,7 +200,6 @@ def __init__(self, else: #regular blocks self.blocks.append(block) - self._block_set.add(block) #check if interface is defined if self.interface is None: @@ -253,7 +250,7 @@ def __contains__(self, other): ------- bool """ - return other in self._block_set or other in self._conn_set + return other in self.blocks or other in self.connections # adding and removing system components --------------------------------------------------- @@ -268,7 +265,7 @@ def add_block(self, block): block : Block block to add to the subsystem """ - if block in self._block_set: + if block in self.blocks: raise ValueError(f"block {block} already part of subsystem") #initialize solver if available @@ -278,7 +275,6 @@ def add_block(self, block): self._blocks_dyn.append(block) self.blocks.append(block) - self._block_set.add(block) if self.graph: self._graph_dirty = True @@ -294,11 +290,10 @@ def 
remove_block(self, block): block : Block block to remove from the subsystem """ - if block not in self._block_set: + if block not in self.blocks: raise ValueError(f"block {block} not part of subsystem") self.blocks.remove(block) - self._block_set.discard(block) #remove from dynamic list if hasattr(self, '_blocks_dyn') and block in self._blocks_dyn: @@ -318,11 +313,10 @@ def add_connection(self, connection): connection : Connection connection to add to the subsystem """ - if connection in self._conn_set: + if connection in self.connections: raise ValueError(f"{connection} already part of subsystem") self.connections.append(connection) - self._conn_set.add(connection) if self.graph: self._graph_dirty = True @@ -338,11 +332,10 @@ def remove_connection(self, connection): connection : Connection connection to remove from the subsystem """ - if connection not in self._conn_set: + if connection not in self.connections: raise ValueError(f"{connection} not part of subsystem") self.connections.remove(connection) - self._conn_set.discard(connection) if self.graph: self._graph_dirty = True From e198b21b2d9ef5a8072704a7ef7a0e3eb2a75f7d Mon Sep 17 00:00:00 2001 From: Milan Rother Date: Wed, 25 Mar 2026 09:56:21 +0100 Subject: [PATCH 28/29] Add auto-redirect from RTD to docs.pathsim.org --- docs/source/_static/redirect.js | 8 ++++++++ docs/source/conf.py | 2 ++ docs/source/index.rst | 6 +++--- 3 files changed, 13 insertions(+), 3 deletions(-) create mode 100644 docs/source/_static/redirect.js diff --git a/docs/source/_static/redirect.js b/docs/source/_static/redirect.js new file mode 100644 index 00000000..93887caa --- /dev/null +++ b/docs/source/_static/redirect.js @@ -0,0 +1,8 @@ +// Redirect visitors from RTD to docs.pathsim.org after a brief delay. +// The banner is shown immediately; redirect fires after 3 seconds +// so users understand what's happening. Click the link to go immediately. 
+(function () { + if (window.location.hostname.indexOf('readthedocs') === -1) return; + var target = 'https://docs.pathsim.org'; + setTimeout(function () { window.location.replace(target); }, 3000); +})(); diff --git a/docs/source/conf.py b/docs/source/conf.py index 49d68f97..1fc03c44 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -40,6 +40,8 @@ html_favicon = "logos/pathsim_icon.png" html_title = "PathSim Documentation" html_css_files = ['custom.css'] # Add custom CSS for link previews and styling +html_js_files = ['redirect.js'] # Auto-redirect RTD visitors to docs.pathsim.org +html_baseurl = 'https://docs.pathsim.org/' # Canonical URL for SEO html_theme_options = { "light_css_variables": { diff --git a/docs/source/index.rst b/docs/source/index.rst index 129aa25c..bb32c065 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -8,11 +8,11 @@ PathSim

- 📢 We've moved to a new documentation site! + 📢 Redirecting to the new documentation site...

- This legacy documentation will remain available but is no longer updated. - Visit docs.pathsim.org for the latest docs with interactive examples, and pathsim.org for the new homepage. + This legacy documentation is no longer updated. You will be redirected to + docs.pathsim.org in a few seconds.

From 3db0ca6e7a1fdca4274b1c629276dba1b25bea9a Mon Sep 17 00:00:00 2001 From: Milan Rother Date: Sat, 21 Mar 2026 09:25:08 +0100 Subject: [PATCH 29/29] Add ConvergenceTracker/StepTracker classes with optional diagnostics snapshots --- src/pathsim/simulation.py | 141 +++++++----- src/pathsim/utils/diagnostics.py | 244 +++++++++++++++++++++ tests/pathsim/test_diagnostics.py | 353 ++++++++++++++++++++++++++++++ 3 files changed, 679 insertions(+), 59 deletions(-) create mode 100644 src/pathsim/utils/diagnostics.py create mode 100644 tests/pathsim/test_diagnostics.py diff --git a/src/pathsim/simulation.py b/src/pathsim/simulation.py index 6f6dc9ed..11f08aff 100644 --- a/src/pathsim/simulation.py +++ b/src/pathsim/simulation.py @@ -37,6 +37,7 @@ from .utils.deprecation import deprecated from .utils.portreference import PortReference from .utils.progresstracker import ProgressTracker +from .utils.diagnostics import Diagnostics, ConvergenceTracker, StepTracker from .utils.logger import LoggerManager from .solvers import SSPRK22, SteadyState @@ -165,17 +166,18 @@ class Simulation: """ def __init__( - self, - blocks=None, - connections=None, + self, + blocks=None, + connections=None, events=None, - dt=SIM_TIMESTEP, - dt_min=SIM_TIMESTEP_MIN, - dt_max=SIM_TIMESTEP_MAX, - Solver=SSPRK22, - tolerance_fpi=SIM_TOLERANCE_FPI, - iterations_max=SIM_ITERATIONS_MAX, + dt=SIM_TIMESTEP, + dt_min=SIM_TIMESTEP_MIN, + dt_max=SIM_TIMESTEP_MAX, + Solver=SSPRK22, + tolerance_fpi=SIM_TOLERANCE_FPI, + iterations_max=SIM_ITERATIONS_MAX, log=LOG_ENABLE, + diagnostics=False, **solver_kwargs ): @@ -226,6 +228,17 @@ def __init__( #flag for setting the simulation active self._active = True + #convergence trackers for the three solver loops + self._loop_tracker = ConvergenceTracker() + self._solve_tracker = ConvergenceTracker() + self._step_tracker = StepTracker() + + #diagnostics snapshot (None when disabled) + self.diagnostics = Diagnostics() if diagnostics else None + + #diagnostics history (list of 
snapshots per timestep) + self._diagnostics_history = [] if diagnostics == "history" else None + #initialize logging logger_mgr = LoggerManager( enabled=bool(self.log), @@ -815,6 +828,15 @@ def reset(self, time=0.0): for event in self.events: event.reset() + #reset convergence trackers and diagnostics + self._loop_tracker.reset() + self._solve_tracker.reset() + self._step_tracker.reset() + if self.diagnostics is not None: + self.diagnostics = Diagnostics() + if self._diagnostics_history is not None: + self._diagnostics_history.clear() + #evaluate system function self._update(self.time) @@ -1026,19 +1048,20 @@ def _loops(self, t): if connection: connection.update() #step boosters of loop closing connections - max_err = 0.0 + self._loop_tracker.begin_iteration() for con_booster in self.boosters: - err = con_booster.update() - if err > max_err: - max_err = err - + self._loop_tracker.record(con_booster, con_booster.update()) + #check convergence - if max_err <= self.tolerance_fpi: + if self._loop_tracker.converged(self.tolerance_fpi): + self._loop_tracker.iterations = iteration return - #not converged -> error - _msg = "algebraic loop not converged (iters: {}, err: {})".format( - self.iterations_max, max_err + #not converged -> error with per-connection details + self._loop_tracker.iterations = self.iterations_max + details = self._loop_tracker.details(lambda b: str(b.connection)) + _msg = "algebraic loop not converged (iters: {}, err: {:.2e})\n{}".format( + self.iterations_max, self._loop_tracker.max_error, "\n".join(details) ) self.logger.error(_msg) raise RuntimeError(_msg) @@ -1080,26 +1103,21 @@ def _solve(self, t, dt): #evaluate system equation (this is a fixed point loop) self._update(t) - total_evals += 1 + total_evals += 1 #advance solution of implicit solver - max_error = 0.0 + self._solve_tracker.begin_iteration() for block in self._blocks_dyn: - - #skip inactive blocks - if not block: + if not block: continue - - #advance solution (internal optimizer) - 
error = block.solve(t, dt) - if error > max_error: - max_error = error + self._solve_tracker.record(block, block.solve(t, dt)) - #check for convergence (only error) - if max_error <= self.tolerance_fpi: - return True, total_evals, it+1 + #check for convergence + if self._solve_tracker.converged(self.tolerance_fpi): + self._solve_tracker.iterations = it + 1 + return True, total_evals, it + 1 - #not converged in 'self.iterations_max' steps + self._solve_tracker.iterations = self.iterations_max return False, total_evals, self.iterations_max @@ -1156,8 +1174,9 @@ def steadystate(self, reset=False): #catch non convergence if not success: - _msg = "STEADYSTATE -> FINISHED (success: {}, evals: {}, iters: {}, runtime: {})".format( - success, evals, iters, T) + details = self._solve_tracker.details(lambda b: b.__class__.__name__) + _msg = "STEADYSTATE -> FAILED (evals: {}, iters: {}, runtime: {})\n{}".format( + evals, iters, T, "\n".join(details)) self.logger.error(_msg) raise RuntimeError(_msg) @@ -1276,32 +1295,14 @@ def _step(self, t, dt): rescale factor for timestep """ - #initial timestep rescale and error estimate - success, max_error_norm, min_scale = True, 0.0, None + self._step_tracker.reset() - #step blocks and get error estimates if available for block in self._blocks_dyn: - - #skip inactive blocks if not block: continue - - #step the block suc, err_norm, scl = block.step(t, dt) + self._step_tracker.record(block, suc, err_norm, scl) - #check solver stepping success - if not suc: - success = False - - #update error tracking - if err_norm > max_error_norm: - max_error_norm = err_norm - - #track minimum relevant scale directly (avoids list allocation) - if scl is not None: - if min_scale is None or scl < min_scale: - min_scale = scl - - return success, max_error_norm, min_scale if min_scale is not None else 1.0 + return self._step_tracker.success, self._step_tracker.max_error, self._step_tracker.scale # timestepping 
---------------------------------------------------------------- @@ -1469,10 +1470,19 @@ def timestep(self, dt=None, adaptive=True): total_evals += evals total_solver_its += solver_its - #adaptive implicit: revert if solver didn't converge - if not success and is_adaptive: - self._revert(self.time) - return False, 0.0, 0.5, total_evals + 1, total_solver_its + #implicit solver didn't converge + if not success: + details = self._solve_tracker.details(lambda b: b.__class__.__name__) + if is_adaptive: + self.logger.warning( + "implicit solver not converged, reverting step at t={:.6f}\n{}".format( + time_stage, "\n".join(details))) + self._revert(self.time) + return False, 0.0, 0.5, total_evals + 1, total_solver_its + else: + self.logger.warning( + "implicit solver not converged at t={:.6f} (iters: {})\n{}".format( + time_stage, solver_its, "\n".join(details))) else: #explicit: evaluate system equation self._update(time_stage) @@ -1511,6 +1521,19 @@ def timestep(self, dt=None, adaptive=True): self._update(time_dt) total_evals += 1 + #update diagnostics snapshot for this timestep + if self.diagnostics is not None: + self.diagnostics = Diagnostics( + time=time_dt, + loop_residuals=dict(self._loop_tracker.errors), + loop_iterations=self._loop_tracker.iterations, + solve_residuals=dict(self._solve_tracker.errors), + solve_iterations=self._solve_tracker.iterations, + step_errors=dict(self._step_tracker.errors), + ) + if self._diagnostics_history is not None: + self._diagnostics_history.append(self.diagnostics) + #sample data after successful timestep self._sample(time_dt, dt) diff --git a/src/pathsim/utils/diagnostics.py b/src/pathsim/utils/diagnostics.py new file mode 100644 index 00000000..b58c86da --- /dev/null +++ b/src/pathsim/utils/diagnostics.py @@ -0,0 +1,244 @@ +######################################################################################### +## +## CONVERGENCE TRACKING AND DIAGNOSTICS +## (utils/diagnostics.py) +## +## Convergence tracker classes for the 
simulation solver loops +## and optional per-timestep diagnostics snapshot. +## +######################################################################################### + +# IMPORTS =============================================================================== + +from dataclasses import dataclass, field + + +# CONVERGENCE TRACKER =================================================================== + +class ConvergenceTracker: + """Tracks per-object scalar errors and convergence for fixed-point loops. + + Used by the algebraic loop solver (keyed by ConnectionBooster) and + the implicit ODE solver (keyed by Block). + + Attributes + ---------- + errors : dict + object -> float, per-object error from most recent iteration + max_error : float + maximum error across all objects in current iteration + iterations : int + number of iterations taken + """ + + __slots__ = ('errors', 'max_error', 'iterations') + + def __init__(self): + self.errors = {} + self.max_error = 0.0 + self.iterations = 0 + + + def reset(self): + """Clear all state.""" + self.errors.clear() + self.max_error = 0.0 + self.iterations = 0 + + + def begin_iteration(self): + """Reset per-iteration state before sweeping objects.""" + self.errors.clear() + self.max_error = 0.0 + + + def record(self, obj, error): + """Record a single object's error and update the running max.""" + self.errors[obj] = error + if error > self.max_error: + self.max_error = error + + + def converged(self, tolerance): + """Check if max error is within tolerance.""" + return self.max_error <= tolerance + + + def details(self, label_fn): + """Format per-object error breakdown for error messages. 
+ + Parameters + ---------- + label_fn : callable + obj -> str, produces a human-readable label + + Returns + ------- + list[str] + formatted lines like " Integrator: 1.23e-04" + """ + return [f" {label_fn(obj)}: {err:.2e}" for obj, err in self.errors.items()] + + +# STEP TRACKER ========================================================================== + +class StepTracker: + """Tracks per-block adaptive step results. + + Used by the adaptive error control loop. Each block produces a tuple + (success, err_norm, scale) and this tracker aggregates them. + + Attributes + ---------- + errors : dict + block -> (success, err_norm, scale) from most recent step + success : bool + AND of all block successes + max_error : float + maximum error norm across all blocks + min_scale : float | None + minimum scale factor (None if no block provides one) + """ + + __slots__ = ('errors', 'success', 'max_error', 'min_scale') + + def __init__(self): + self.errors = {} + self.success = True + self.max_error = 0.0 + self.min_scale = None + + + def reset(self): + """Clear state for a new step.""" + self.errors.clear() + self.success = True + self.max_error = 0.0 + self.min_scale = None + + + def record(self, block, success, err_norm, scale): + """Record a single block's step result.""" + self.errors[block] = (success, err_norm, scale) + if not success: + self.success = False + if err_norm > self.max_error: + self.max_error = err_norm + if scale is not None: + if self.min_scale is None or scale < self.min_scale: + self.min_scale = scale + + + @property + def scale(self): + """Effective scale factor (1.0 when no block provides one).""" + return self.min_scale if self.min_scale is not None else 1.0 + + +# DIAGNOSTICS SNAPSHOT ================================================================== + +@dataclass +class Diagnostics: + """Per-timestep convergence diagnostics snapshot. + + Populated by the simulation engine after each successful timestep + from the three convergence trackers. 
Provides read-only accessors + for the worst offending block or connection. + + Attributes + ---------- + time : float + simulation time + loop_residuals : dict + per-booster algebraic loop residuals (booster -> residual) + loop_iterations : int + number of algebraic loop iterations taken + solve_residuals : dict + per-block implicit solver residuals (block -> residual) + solve_iterations : int + number of implicit solver iterations taken + step_errors : dict + per-block adaptive step data (block -> (success, err_norm, scale)) + """ + time: float = 0.0 + loop_residuals: dict = field(default_factory=dict) + loop_iterations: int = 0 + solve_residuals: dict = field(default_factory=dict) + solve_iterations: int = 0 + step_errors: dict = field(default_factory=dict) + + + @staticmethod + def _label(obj): + """Human-readable label for a block or booster.""" + if hasattr(obj, 'connection'): + return str(obj.connection) + return obj.__class__.__name__ + + + def worst_block(self): + """Block with the highest residual across solve and step errors. + + Returns + ------- + tuple[str, float] or None + (label, error) or None if no data + """ + worst, worst_err = None, -1.0 + + for obj, err in self.solve_residuals.items(): + if err > worst_err: + worst, worst_err = obj, err + + for obj, (_, err_norm, _) in self.step_errors.items(): + if err_norm > worst_err: + worst, worst_err = obj, err_norm + + if worst is None: + return None + return self._label(worst), worst_err + + + def worst_booster(self): + """Connection booster with the highest algebraic loop residual. + + Returns + ------- + tuple[str, float] or None + (label, residual) or None if no data + """ + if not self.loop_residuals: + return None + + worst = max(self.loop_residuals, key=self.loop_residuals.get) + return self._label(worst), self.loop_residuals[worst] + + + def summary(self): + """Formatted summary of this diagnostics snapshot. 
+ + Returns + ------- + str + human-readable diagnostics summary + """ + lines = [f"Diagnostics at t = {self.time:.6f}"] + + if self.step_errors: + lines.append(f"\n Adaptive step errors:") + for obj, (suc, err, scl) in self.step_errors.items(): + status = "OK" if suc else "FAIL" + scl_str = f"{scl:.3f}" if scl is not None else "N/A" + lines.append(f" {status} {self._label(obj)}: err={err:.2e}, scale={scl_str}") + + if self.solve_residuals: + lines.append(f"\n Implicit solver residuals ({self.solve_iterations} iterations):") + for obj, err in self.solve_residuals.items(): + lines.append(f" {self._label(obj)}: {err:.2e}") + + if self.loop_residuals: + lines.append(f"\n Algebraic loop residuals ({self.loop_iterations} iterations):") + for obj, err in self.loop_residuals.items(): + lines.append(f" {self._label(obj)}: {err:.2e}") + + return "\n".join(lines) diff --git a/tests/pathsim/test_diagnostics.py b/tests/pathsim/test_diagnostics.py new file mode 100644 index 00000000..9f9514d2 --- /dev/null +++ b/tests/pathsim/test_diagnostics.py @@ -0,0 +1,353 @@ +"""Tests for simulation diagnostics.""" + +import unittest +import numpy as np + +from pathsim import Simulation, Connection +from pathsim.blocks import Source, Integrator, Amplifier, Adder, Scope +from pathsim.utils.diagnostics import Diagnostics, ConvergenceTracker, StepTracker + + +class TestDiagnosticsOff(unittest.TestCase): + """Verify diagnostics=False (default) has no side effects.""" + + def test_diagnostics_none_by_default(self): + src = Source(lambda t: 1.0) + integ = Integrator() + sim = Simulation( + blocks=[src, integ], + connections=[Connection(src, integ)], + dt=0.01 + ) + self.assertIsNone(sim.diagnostics) + sim.run(0.1) + self.assertIsNone(sim.diagnostics) + + +class TestDiagnosticsExplicitSolver(unittest.TestCase): + """Diagnostics with an explicit solver (step errors only).""" + + def test_snapshot_after_run(self): + src = Source(lambda t: np.sin(t)) + integ = Integrator() + sim = Simulation( + 
blocks=[src, integ], + connections=[Connection(src, integ)], + dt=0.01, + diagnostics=True + ) + sim.run(0.1) + + diag = sim.diagnostics + self.assertIsInstance(diag, Diagnostics) + self.assertAlmostEqual(diag.time, sim.time, places=6) + + #explicit solver: step errors should be populated + self.assertGreater(len(diag.step_errors), 0) + first_key = list(diag.step_errors.keys())[0] + self.assertEqual(first_key.__class__.__name__, "Integrator") + + #no implicit solver or algebraic loops + self.assertEqual(len(diag.solve_residuals), 0) + self.assertEqual(len(diag.loop_residuals), 0) + + def test_worst_block(self): + src = Source(lambda t: 1.0) + i1 = Integrator() + i2 = Integrator() + sim = Simulation( + blocks=[src, i1, i2], + connections=[Connection(src, i1), Connection(i1, i2)], + dt=0.01, + diagnostics=True + ) + sim.run(0.1) + + result = sim.diagnostics.worst_block() + self.assertIsNotNone(result) + label, err = result + self.assertIn("Integrator", label) + + def test_summary_string(self): + src = Source(lambda t: 1.0) + integ = Integrator() + sim = Simulation( + blocks=[src, integ], + connections=[Connection(src, integ)], + dt=0.01, + diagnostics=True + ) + sim.run(0.1) + + summary = sim.diagnostics.summary() + self.assertIn("Diagnostics at t", summary) + self.assertIn("Integrator", summary) + + def test_reset_clears_diagnostics(self): + src = Source(lambda t: 1.0) + integ = Integrator() + sim = Simulation( + blocks=[src, integ], + connections=[Connection(src, integ)], + dt=0.01, + diagnostics=True + ) + sim.run(0.1) + self.assertGreater(sim.diagnostics.time, 0) + + sim.reset() + self.assertEqual(sim.diagnostics.time, 0.0) + + +class TestDiagnosticsAdaptiveSolver(unittest.TestCase): + """Diagnostics with an adaptive solver.""" + + def test_adaptive_step_errors(self): + from pathsim.solvers import RKCK54 + + src = Source(lambda t: np.sin(10 * t)) + integ = Integrator() + sim = Simulation( + blocks=[src, integ], + connections=[Connection(src, integ)], + dt=0.1, + 
Solver=RKCK54, + tolerance_lte_abs=1e-6, + tolerance_lte_rel=1e-4, + diagnostics=True + ) + sim.run(1.0) + + diag = sim.diagnostics + self.assertIsInstance(diag, Diagnostics) + self.assertGreater(len(diag.step_errors), 0) + + +class TestDiagnosticsImplicitSolver(unittest.TestCase): + """Diagnostics with an implicit solver (solve residuals).""" + + def test_implicit_solve_residuals(self): + from pathsim.solvers import ESDIRK32 + + src = Source(lambda t: np.sin(t)) + integ = Integrator() + sim = Simulation( + blocks=[src, integ], + connections=[Connection(src, integ)], + dt=0.01, + Solver=ESDIRK32, + diagnostics=True + ) + sim.run(0.1) + + diag = sim.diagnostics + self.assertIsInstance(diag, Diagnostics) + self.assertGreater(len(diag.solve_residuals), 0) + self.assertGreater(diag.solve_iterations, 0) + + #worst_block should find the block from solve residuals + result = diag.worst_block() + self.assertIsNotNone(result) + + #summary should include implicit solver section + summary = diag.summary() + self.assertIn("Implicit solver residuals", summary) + + +class TestDiagnosticsAlgebraicLoop(unittest.TestCase): + """Diagnostics with algebraic loops (loop residuals).""" + + def test_algebraic_loop_residuals(self): + src = Source(lambda t: 1.0) + P1 = Adder() + A1 = Amplifier(0.5) + sco = Scope() + + sim = Simulation( + blocks=[src, P1, A1, sco], + connections=[ + Connection(src, P1), + Connection(P1, A1, sco), + Connection(A1, P1[1]), + ], + dt=0.01, + diagnostics=True + ) + + self.assertTrue(sim.graph.has_loops) + sim.run(0.05) + + diag = sim.diagnostics + self.assertGreater(len(diag.loop_residuals), 0) + + result = diag.worst_booster() + self.assertIsNotNone(result) + + #summary should include algebraic loop section + summary = diag.summary() + self.assertIn("Algebraic loop residuals", summary) + + +class TestDiagnosticsHistory(unittest.TestCase): + """Diagnostics history recording.""" + + def test_no_history_by_default(self): + src = Source(lambda t: 1.0) + integ = 
Integrator() + sim = Simulation( + blocks=[src, integ], + connections=[Connection(src, integ)], + dt=0.01, + diagnostics=True + ) + sim.run(0.1) + + self.assertIsNone(sim._diagnostics_history) + self.assertIsInstance(sim.diagnostics, Diagnostics) + + def test_history_enabled(self): + src = Source(lambda t: 1.0) + integ = Integrator() + sim = Simulation( + blocks=[src, integ], + connections=[Connection(src, integ)], + dt=0.01, + diagnostics="history" + ) + sim.run(0.1) + + #should have ~10 snapshots (0.1s / 0.01 dt) + self.assertGreater(len(sim._diagnostics_history), 5) + + #each snapshot should have a time + times = [s.time for s in sim._diagnostics_history] + self.assertEqual(times, sorted(times)) + + def test_history_reset(self): + src = Source(lambda t: 1.0) + integ = Integrator() + sim = Simulation( + blocks=[src, integ], + connections=[Connection(src, integ)], + dt=0.01, + diagnostics="history" + ) + sim.run(0.05) + self.assertGreater(len(sim._diagnostics_history), 0) + + sim.reset() + self.assertEqual(len(sim._diagnostics_history), 0) + + +class TestDiagnosticsUnit(unittest.TestCase): + """Unit tests for the Diagnostics dataclass.""" + + def test_defaults(self): + d = Diagnostics() + self.assertEqual(d.time, 0.0) + self.assertIsNone(d.worst_block()) + self.assertIsNone(d.worst_booster()) + + def test_worst_block_from_step_errors(self): + + class FakeBlock: + pass + + b1, b2 = FakeBlock(), FakeBlock() + d = Diagnostics(step_errors={b1: (True, 1e-3, 0.9), b2: (True, 5e-3, 0.7)}) + + label, err = d.worst_block() + self.assertAlmostEqual(err, 5e-3) + + def test_worst_block_from_solve_residuals(self): + + class FakeBlock: + pass + + b1, b2 = FakeBlock(), FakeBlock() + d = Diagnostics(solve_residuals={b1: 1e-4, b2: 3e-3}) + + label, err = d.worst_block() + self.assertAlmostEqual(err, 3e-3) + + def test_summary_with_all_data(self): + + class FakeBlock: + pass + + class FakeBooster: + class connection: + def __str__(self): + return "A -> B" + connection = 
connection() + + b = FakeBlock() + bst = FakeBooster() + d = Diagnostics( + time=1.0, + step_errors={b: (True, 1e-4, 0.9)}, + solve_residuals={b: 1e-8}, + solve_iterations=3, + loop_residuals={bst: 1e-12}, + loop_iterations=2, + ) + + summary = d.summary() + self.assertIn("Diagnostics at t", summary) + self.assertIn("Adaptive step errors", summary) + self.assertIn("Implicit solver residuals", summary) + self.assertIn("Algebraic loop residuals", summary) + + +class TestConvergenceTrackerUnit(unittest.TestCase): + """Unit tests for ConvergenceTracker.""" + + def test_record_and_converge(self): + t = ConvergenceTracker() + t.record("a", 1e-5) + t.record("b", 1e-8) + self.assertAlmostEqual(t.max_error, 1e-5) + self.assertTrue(t.converged(1e-4)) + self.assertFalse(t.converged(1e-6)) + + def test_begin_iteration_clears(self): + t = ConvergenceTracker() + t.record("a", 1.0) + t.begin_iteration() + self.assertEqual(len(t.errors), 0) + self.assertEqual(t.max_error, 0.0) + + def test_details(self): + t = ConvergenceTracker() + t.record("block_a", 1e-3) + t.record("block_b", 2e-4) + lines = t.details(lambda obj: f"name:{obj}") + self.assertEqual(len(lines), 2) + self.assertIn("name:block_a", lines[0]) + + +class TestStepTrackerUnit(unittest.TestCase): + """Unit tests for StepTracker.""" + + def test_record_aggregation(self): + t = StepTracker() + t.record("a", True, 1e-4, 0.9) + t.record("b", False, 2e-3, 0.5) + t.record("c", True, 1e-5, None) + + self.assertFalse(t.success) + self.assertAlmostEqual(t.max_error, 2e-3) + self.assertAlmostEqual(t.scale, 0.5) + + def test_scale_default(self): + t = StepTracker() + t.record("a", True, 0.0, None) + self.assertEqual(t.scale, 1.0) + + def test_reset(self): + t = StepTracker() + t.record("a", False, 1.0, 0.1) + t.reset() + self.assertTrue(t.success) + self.assertEqual(t.max_error, 0.0) + self.assertEqual(len(t.errors), 0)