diff --git a/docs/source/examples/checkpoints.ipynb b/docs/source/examples/checkpoints.ipynb new file mode 100644 index 00000000..7d7f4eb7 --- /dev/null +++ b/docs/source/examples/checkpoints.ipynb @@ -0,0 +1,308 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Checkpoints\n", + "\n", + "PathSim supports saving and loading simulation state via checkpoints. This allows you to pause a simulation, save its complete state to disk, and resume it later from exactly where you left off.\n", + "\n", + "Checkpoints also enable **rollback** — returning to a saved state and exploring different what-if scenarios by changing parameters.\n", + "\n", + "Checkpoints use a split format: a JSON file for metadata and structure, and an NPZ file for numerical data (block states, solver histories, etc.)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Building the System\n", + "\n", + "We'll use the coupled oscillators system to demonstrate checkpoints. The energy exchange between the two oscillators produces a sustained, non-trivial response that makes it easy to visually verify checkpoint continuity." + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "First let's import the :class:`.Simulation` and :class:`.Connection` classes and the required blocks:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "\n", + "from pathsim import Simulation, Connection\n", + "from pathsim.blocks import ODE, Function, Scope" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Define the system parameters:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Mass parameters\n", + "m1 = 1.0\n", + "m2 = 1.5\n", + "\n", + "# Spring constants\n", + "k1 = 2.0\n", + "k2 = 3.0\n", + "k12 = 0.5 # coupling spring\n", + "\n", + "# Damping coefficients\n", + "c1 = 0.02\n", + "c2 = 0.03\n", + "\n", + "# Initial conditions [position, velocity]\n", + "x1_0 = np.array([2.0, 0.0]) # oscillator 1 displaced\n", + "x2_0 = np.array([0.0, 0.0]) # oscillator 2 at rest" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Define the differential equations for each oscillator using :class:`.ODE` blocks and the coupling force using a :class:`.Function` block:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Oscillator 1: m1*x1'' = -k1*x1 - c1*x1' - k12*(x1 - x2)\n", + "def osc1_func(x1, u, t):\n", + " f_e = u[0]\n", + " return np.array([x1[1], (-k1*x1[0] - c1*x1[1] - f_e) / m1])\n", + "\n", + "# Oscillator 2: m2*x2'' = -k2*x2 - c2*x2' + k12*(x1 - x2)\n", + "def osc2_func(x2, u, t):\n", + " f_e = u[0]\n", + " return np.array([x2[1], (-k2*x2[0] - c2*x2[1] - f_e) / m2])\n", + "\n", + "# Coupling force\n", + "def coupling_func(x1, x2):\n", + " f = k12 * (x1 - x2)\n", + " return f, -f\n", + "\n", + "# Blocks\n", + "osc1 = ODE(osc1_func, x1_0)\n", + "osc2 = ODE(osc2_func, x2_0)\n", + "fn = Function(coupling_func)\n", + "sc = Scope(labels=[r\"$x_1(t)$ - Oscillator 1\", r\"$x_2(t)$ - Oscillator 2\"])\n", + "\n", + "blocks = [osc1, osc2, fn, sc]\n", + "\n", + "# Connections\n", + "connections = [\n", + " Connection(osc1[0], fn[0], sc[0]),\n", + " Connection(osc2[0], fn[1], sc[1]),\n", + " Connection(fn[0], osc1[0]),\n", + " Connection(fn[1], osc2[0]),\n", + "]" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Create the :class:`.Simulation` and run for 60 seconds:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sim = Simulation(blocks, connections, dt=0.01)\n", + "\n", + "sim.run(60)\n", + "\n", + "fig, ax = sc.plot()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The two oscillators exchange energy through the coupling spring, producing a characteristic beat pattern." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Saving a Checkpoint\n", + "\n", + "Now let's save the simulation state at t=60s. This creates two files: `coupled.json` (metadata) and `coupled.npz` (numerical data)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sim.save_checkpoint(\"coupled\")\n", + "print(f\"Checkpoint saved at t = {sim.time:.1f}s\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can inspect the JSON file to see what was saved:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import json\n", + "\n", + "with open(\"coupled.json\") as f:\n", + " cp = json.load(f)\n", + "\n", + "print(f\"PathSim version: {cp['pathsim_version']}\")\n", + "print(f\"Simulation time: {cp['simulation']['time']:.1f}s\")\n", + "print(f\"Solver: {cp['simulation']['solver']}\")\n", + "print(f\"Blocks saved:\")\n", + "for b in cp[\"blocks\"]:\n", + " print(f\" {b['_key']} ({b['type']})\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Blocks are identified by type and insertion order (``ODE_0``, ``ODE_1``, etc.), so the checkpoint can be loaded into any simulation with the same block structure, regardless of the specific Python objects." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Rollback: What-If Scenarios\n", + "\n", + "This is where checkpoints really shine. We'll load the same checkpoint three times with different coupling strengths to explore how the system evolves from the exact same state.\n", + "\n", + "Since the checkpoint restores all block states by type and insertion order, we just need to rebuild the simulation with the same block structure but different parameters." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def run_scenario(k12_new, duration=60):\n", + " \"\"\"Load checkpoint and continue with a different coupling constant.\"\"\"\n", + " def coupling_new(x1, x2):\n", + " f = k12_new * (x1 - x2)\n", + " return f, -f\n", + "\n", + " o1 = ODE(osc1_func, x1_0)\n", + " o2 = ODE(osc2_func, x2_0)\n", + " f = Function(coupling_new)\n", + " s = Scope()\n", + "\n", + " sim = Simulation(\n", + " [o1, o2, f, s],\n", + " [Connection(o1[0], f[0], s[0]),\n", + " Connection(o2[0], f[1], s[1]),\n", + " Connection(f[0], o1[0]),\n", + " Connection(f[1], o2[0])],\n", + " dt=0.01\n", + " )\n", + " sim.load_checkpoint(\"coupled\")\n", + " sim.run(duration)\n", + " return s.read()\n", + "\n", + "# Original coupling (continuation)\n", + "t_a, d_a = run_scenario(k12_new=0.5)\n", + "\n", + "# Stronger coupling\n", + "t_b, d_b = run_scenario(k12_new=2.0)\n", + "\n", + "# Decoupled\n", + "t_c, d_c = run_scenario(k12_new=0.0)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Comparing the Scenarios\n", + "\n", + "The plot shows the original run (0-60s) followed by three different futures branching from the checkpoint at t=60s. We show oscillator 1 for clarity." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "time_orig, data_orig = sc.read()\n", + "\n", + "fig, ax = plt.subplots(figsize=(10, 4))\n", + "\n", + "# Original run\n", + "ax.plot(time_orig, data_orig[0], \"k-\", lw=1.5, label=r\"original ($k_{12}=0.5$)\")\n", + "\n", + "# Three futures from checkpoint\n", + "ax.plot(t_a, d_a[0], \"C0-\", alpha=0.8, label=r\"continued ($k_{12}=0.5$)\")\n", + "ax.plot(t_b, d_b[0], \"C1-\", alpha=0.8, label=r\"stronger coupling ($k_{12}=2.0$)\")\n", + "ax.plot(t_c, d_c[0], \"C2-\", alpha=0.8, label=r\"decoupled ($k_{12}=0$)\")\n", + "\n", + "ax.axvline(60, color=\"gray\", ls=\":\", lw=2, alpha=0.5, label=\"checkpoint\")\n", + "ax.set_xlabel(\"time [s]\")\n", + "ax.set_ylabel(r\"$x_1(t)$\")\n", + "ax.set_title(\"Checkpoint Rollback: Three Futures from the Same State\")\n", + "ax.legend(loc=\"upper right\", fontsize=8)\n", + "fig.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "All three scenarios start from the exact same state at t=60s. The blue continuation matches the original trajectory perfectly, confirming checkpoint fidelity. The stronger coupling (orange) produces faster energy exchange, while the decoupled system (green) oscillates independently at its natural frequency." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.12.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/src/pathsim/blocks/_block.py b/src/pathsim/blocks/_block.py index 4b4275fb..6a3c3e3c 100644 --- a/src/pathsim/blocks/_block.py +++ b/src/pathsim/blocks/_block.py @@ -524,6 +524,93 @@ def state(self, val): self.engine.state = val + # checkpoint methods ---------------------------------------------------------------- + + def to_checkpoint(self, prefix, recordings=False): + """Serialize block state for checkpointing. + + Parameters + ---------- + prefix : str + key prefix for NPZ arrays (assigned by simulation) + recordings : bool + include recording data (for Scope blocks) + + Returns + ------- + json_data : dict + JSON-serializable metadata + npz_data : dict + numpy arrays keyed by path + """ + json_data = { + "type": self.__class__.__name__, + "active": self._active, + } + + npz_data = { + f"{prefix}/inputs": self.inputs.to_array(), + f"{prefix}/outputs": self.outputs.to_array(), + } + + #solver state + if self.engine: + e_json, e_npz = self.engine.to_checkpoint(f"{prefix}/engine") + json_data["engine"] = e_json + npz_data.update(e_npz) + + #internal events + if self.events: + evt_jsons = [] + for i, event in enumerate(self.events): + evt_prefix = f"{prefix}/evt_{i}" + e_json, e_npz = event.to_checkpoint(evt_prefix) + evt_jsons.append(e_json) + npz_data.update(e_npz) + json_data["events"] = evt_jsons + + return json_data, npz_data + + + def load_checkpoint(self, prefix, json_data, npz): + """Restore block state from checkpoint. + + Parameters + ---------- + prefix : str + key prefix for NPZ arrays (assigned by simulation) + json_data : dict + block metadata from checkpoint JSON + npz : dict-like + numpy arrays from checkpoint NPZ + """ + #verify type + if json_data["type"] != self.__class__.__name__: + raise ValueError( + f"Checkpoint type mismatch: expected '{self.__class__.__name__}', " + f"got '{json_data['type']}'" + ) + + self._active = json_data["active"] + + #restore registers + inp_key = f"{prefix}/inputs" + out_key = f"{prefix}/outputs" + if inp_key in npz: + self.inputs.update_from_array(npz[inp_key]) + if out_key in npz: + self.outputs.update_from_array(npz[out_key]) + + #restore solver state + if self.engine and "engine" in json_data: + self.engine.load_checkpoint(json_data["engine"], npz, f"{prefix}/engine") + + #restore internal events + if self.events and "events" in json_data: + for i, (event, evt_data) in enumerate(zip(self.events, json_data["events"])): + event.load_checkpoint(f"{prefix}/evt_{i}", evt_data, npz) + + # methods for block output and state updates ---------------------------------------- def update(self, t): diff --git a/src/pathsim/blocks/delay.py b/src/pathsim/blocks/delay.py index 6bcbed8b..4e6d0a4f 100644 --- a/src/pathsim/blocks/delay.py +++ b/src/pathsim/blocks/delay.py @@ -142,6 +142,39 @@ def reset(self): self._ring.extend([0.0] * self._n) + def to_checkpoint(self, prefix, recordings=False): + """Serialize Delay state including buffer data.""" + json_data, npz_data = super().to_checkpoint(prefix, recordings=recordings) + + json_data["sampling_period"] = self.sampling_period + + if self.sampling_period is None: + #continuous mode: adaptive buffer + npz_data.update(self._buffer.to_checkpoint(f"{prefix}/buffer")) + else: + #discrete mode: ring buffer + npz_data[f"{prefix}/ring"] = np.array(list(self._ring)) + json_data["_sample_next_timestep"] = self._sample_next_timestep + + return json_data, npz_data + + + def load_checkpoint(self, prefix, json_data, npz): + """Restore Delay state including buffer data.""" + super().load_checkpoint(prefix, json_data, npz) + + if self.sampling_period is None: + #continuous mode + self._buffer.load_checkpoint(npz, f"{prefix}/buffer") + else: + #discrete mode + ring_key = f"{prefix}/ring" + if ring_key in npz: + self._ring.clear() + self._ring.extend(npz[ring_key].tolist()) + self._sample_next_timestep = json_data.get("_sample_next_timestep", False) + + def update(self, t): """Evaluation of the buffer at different times via interpolation (continuous) or ring buffer lookup (discrete). diff --git a/src/pathsim/blocks/fir.py b/src/pathsim/blocks/fir.py index 8db1a8a3..c9dc8ff7 100644 --- a/src/pathsim/blocks/fir.py +++ b/src/pathsim/blocks/fir.py @@ -114,6 +114,22 @@ def reset(self): self._buffer = deque([0.0]*n, maxlen=n) + def to_checkpoint(self, prefix, recordings=False): + """Serialize FIR state including input buffer.""" + json_data, npz_data = super().to_checkpoint(prefix, recordings=recordings) + npz_data[f"{prefix}/fir_buffer"] = np.array(list(self._buffer)) + return json_data, npz_data + + + def load_checkpoint(self, prefix, json_data, npz): + """Restore FIR state including input buffer.""" + super().load_checkpoint(prefix, json_data, npz) + key = f"{prefix}/fir_buffer" + if key in npz: + self._buffer.clear() + self._buffer.extend(npz[key].tolist()) + + def __len__(self): """This block has no direct passthrough""" - return 0 \ No newline at end of file + return 0 diff --git a/src/pathsim/blocks/kalman.py b/src/pathsim/blocks/kalman.py index 783ae537..c835a7cc 100644 --- a/src/pathsim/blocks/kalman.py +++ b/src/pathsim/blocks/kalman.py @@ -143,6 +143,23 @@ def __len__(self): return 0 + def to_checkpoint(self, prefix, recordings=False): + """Serialize Kalman filter state estimate and covariance.""" + json_data, npz_data = super().to_checkpoint(prefix, recordings=recordings) + npz_data[f"{prefix}/kf_x"] = self.x + npz_data[f"{prefix}/kf_P"] = self.P + return json_data, npz_data + + + def load_checkpoint(self, prefix, json_data, npz): + """Restore Kalman filter state estimate and covariance.""" + super().load_checkpoint(prefix, json_data, npz) + if f"{prefix}/kf_x" in npz: + self.x = npz[f"{prefix}/kf_x"] + if f"{prefix}/kf_P" in npz: + self.P = npz[f"{prefix}/kf_P"] + + def _kf_update(self): """Perform one Kalman filter update step.""" diff --git a/src/pathsim/blocks/noise.py b/src/pathsim/blocks/noise.py index 101828ea..0206a1b6 100644 --- a/src/pathsim/blocks/noise.py +++ b/src/pathsim/blocks/noise.py @@ -124,6 +124,19 @@ def update(self, t): pass + def to_checkpoint(self, prefix, recordings=False): + """Serialize WhiteNoise state including current sample.""" + json_data, npz_data = super().to_checkpoint(prefix, recordings=recordings) + json_data["_current_sample"] = float(self._current_sample) + return json_data, npz_data + + + def load_checkpoint(self, prefix, json_data, npz): + """Restore WhiteNoise state including current sample.""" + super().load_checkpoint(prefix, json_data, npz) + self._current_sample = json_data.get("_current_sample", 0.0) + + class PinkNoise(Block): """Pink noise (1/f noise) source using the Voss-McCartney algorithm. @@ -268,4 +281,22 @@ def sample(self, t, dt): def update(self, t): - pass \ No newline at end of file + pass + + + def to_checkpoint(self, prefix, recordings=False): + """Serialize PinkNoise state including algorithm state.""" + json_data, npz_data = super().to_checkpoint(prefix, recordings=recordings) + json_data["n_samples"] = self.n_samples + json_data["_current_sample"] = float(self._current_sample) + npz_data[f"{prefix}/octave_values"] = self.octave_values + return json_data, npz_data + + + def load_checkpoint(self, prefix, json_data, npz): + """Restore PinkNoise state including algorithm state.""" + super().load_checkpoint(prefix, json_data, npz) + self.n_samples = json_data.get("n_samples", 0) + self._current_sample = json_data.get("_current_sample", 0.0) + if f"{prefix}/octave_values" in npz: + self.octave_values = npz[f"{prefix}/octave_values"] \ No newline at end of file diff --git a/src/pathsim/blocks/rf.py b/src/pathsim/blocks/rf.py index 51ac9e00..4a82c5db 100644 --- a/src/pathsim/blocks/rf.py +++ b/src/pathsim/blocks/rf.py @@ -8,13 +8,8 @@ ## ######################################################################################### -# TODO LIST -# class RFAmplifier Model amplifier in RF systems -# class Resistor/Capacitor/Inductor -# class RFMixer for mixer in RF systems? - - # IMPORTS =============================================================================== + from __future__ import annotations import numpy as np @@ -38,10 +33,17 @@ from .lti import StateSpace +from ..utils.deprecation import deprecated + # BLOCK DEFINITIONS ===================================================================== +@deprecated( + version="1.0.0", + replacement="pathsim_rf.RFNetwork", + reason="This block has moved to the pathsim-rf package: pip install pathsim-rf", +) class RFNetwork(StateSpace): """N-port RF network linear time invariant (LTI) multi input multi output (MIMO) state-space model. @@ -78,7 +80,7 @@ def __init__(self, ntwk: NetworkType | str | Path, auto_fit: bool = True, **kwar _msg = "The scikit-rf package is required to use this block -> 'pip install scikit-rf'" raise ImportError(_msg) - if isinstance(ntwk, Path) or isinstance(ntwk, str): + if isinstance(ntwk, (Path, str)): ntwk = rf.Network(ntwk) # Select the vector fitting function from scikit-rf diff --git a/src/pathsim/blocks/rng.py b/src/pathsim/blocks/rng.py index 5841b5a5..72824107 100644 --- a/src/pathsim/blocks/rng.py +++ b/src/pathsim/blocks/rng.py @@ -96,6 +96,21 @@ def sample(self, t, dt): self._sample = np.random.rand() + def to_checkpoint(self, prefix, recordings=False): + """Serialize RNG state including current sample.""" + json_data, npz_data = super().to_checkpoint(prefix, recordings=recordings) + if self.sampling_period is None: + json_data["_sample"] = float(self._sample) + return json_data, npz_data + + + def load_checkpoint(self, prefix, json_data, npz): + """Restore RNG state including current sample.""" + super().load_checkpoint(prefix, json_data, npz) + if self.sampling_period is None: + self._sample = json_data.get("_sample", 0.0) + + def __len__(self): """Essentially a source-like block without passthrough""" return 0 diff --git a/src/pathsim/blocks/scope.py b/src/pathsim/blocks/scope.py index 4997f772..ec980785 100644 --- a/src/pathsim/blocks/scope.py +++ b/src/pathsim/blocks/scope.py @@ -448,13 +448,44 @@ def save(self, path="scope.csv"): wrt.writerow(sample) + def to_checkpoint(self, prefix, recordings=False): + """Serialize Scope state including optional recording data.""" + json_data, npz_data = super().to_checkpoint(prefix, recordings=recordings) + + json_data["_incremental_idx"] = self._incremental_idx + if hasattr(self, '_sample_next_timestep'): + json_data["_sample_next_timestep"] = self._sample_next_timestep + + if recordings and self.recording_time: + npz_data[f"{prefix}/recording_time"] = np.array(self.recording_time) + npz_data[f"{prefix}/recording_data"] = np.array(self.recording_data) + + return json_data, npz_data + + + def load_checkpoint(self, prefix, json_data, npz): + """Restore Scope state including optional recording data.""" + super().load_checkpoint(prefix, json_data, npz) + + self._incremental_idx = json_data.get("_incremental_idx", 0) + if hasattr(self, '_sample_next_timestep'): + self._sample_next_timestep = json_data.get("_sample_next_timestep", False) + + #restore recordings if present + rt_key = f"{prefix}/recording_time" + rd_key = f"{prefix}/recording_data" + if rt_key in npz and rd_key in npz: + self.recording_time = npz[rt_key].tolist() + self.recording_data = [row for row in npz[rd_key]] + + def update(self, t): - """update system equation for fixed point loop, + """update system equation for fixed point loop, here just setting the outputs - + Note ---- - Scope has no passthrough, so the 'update' method + Scope has no passthrough, so the 'update' method is optimized for this case (does nothing) Parameters diff --git a/src/pathsim/blocks/spectrum.py b/src/pathsim/blocks/spectrum.py index 7b3a0878..0dec61fe 100644 --- a/src/pathsim/blocks/spectrum.py +++ b/src/pathsim/blocks/spectrum.py @@ -283,6 +283,24 @@ def step(self, t, dt): return True, 0.0, None + def to_checkpoint(self, prefix, recordings=False): + """Serialize Spectrum state including integration time.""" + json_data, npz_data = super().to_checkpoint(prefix, recordings=recordings) + + json_data["time"] = self.time + json_data["t_sample"] = self.t_sample + + return json_data, npz_data + + + def load_checkpoint(self, prefix, json_data, npz): + """Restore Spectrum state including integration time.""" + super().load_checkpoint(prefix, json_data, npz) + + self.time = json_data.get("time", 0.0) + self.t_sample = json_data.get("t_sample", 0.0) + + def sample(self, t, dt): """sample time of successfull timestep for waiting period diff --git a/src/pathsim/blocks/switch.py b/src/pathsim/blocks/switch.py index dc60d887..f89f28ca 100644 --- a/src/pathsim/blocks/switch.py +++ b/src/pathsim/blocks/switch.py @@ -82,6 +82,16 @@ def select(self, switch_state=0): self.switch_state = switch_state + def to_checkpoint(self, prefix, recordings=False): + json_data, npz_data = super().to_checkpoint(prefix, recordings=recordings) + json_data["switch_state"] = self.switch_state + return json_data, npz_data + + + def load_checkpoint(self, prefix, json_data, npz): + super().load_checkpoint(prefix, json_data, npz) + self.switch_state = json_data.get("switch_state", None) + def update(self, t): """Update switch output depending on inputs and switch state. diff --git a/src/pathsim/events/_event.py b/src/pathsim/events/_event.py index fd911a5b..1ebd1a9e 100644 --- a/src/pathsim/events/_event.py +++ b/src/pathsim/events/_event.py @@ -201,4 +201,62 @@ def resolve(self, t): #action function for event resolution if self.func_act is not None: - self.func_act(t) \ No newline at end of file + self.func_act(t) + + + # checkpoint methods ---------------------------------------------------------------- + + def to_checkpoint(self, prefix): + """Serialize event state for checkpointing. + + Parameters + ---------- + prefix : str + key prefix for NPZ arrays (assigned by simulation) + + Returns + ------- + json_data : dict + JSON-serializable metadata + npz_data : dict + numpy arrays keyed by path + """ + #extract history eval value + hist_eval, hist_time = self._history + if hist_eval is not None and hasattr(hist_eval, 'item'): + hist_eval = float(hist_eval) + + json_data = { + "type": self.__class__.__name__, + "active": self._active, + "history_eval": hist_eval, + "history_time": hist_time, + } + + npz_data = {} + if self._times: + npz_data[f"{prefix}/times"] = np.array(self._times) + + return json_data, npz_data + + + def load_checkpoint(self, prefix, json_data, npz): + """Restore event state from checkpoint. + + Parameters + ---------- + prefix : str + key prefix for NPZ arrays (assigned by simulation) + json_data : dict + event metadata from checkpoint JSON + npz : dict-like + numpy arrays from checkpoint NPZ + """ + self._active = json_data["active"] + self._history = json_data["history_eval"], json_data["history_time"] + + times_key = f"{prefix}/times" + if times_key in npz: + self._times = npz[times_key].tolist() + else: + self._times = [] diff --git a/src/pathsim/simulation.py b/src/pathsim/simulation.py index 306c7e4c..6f6dc9ed 100644 --- a/src/pathsim/simulation.py +++ b/src/pathsim/simulation.py @@ -10,6 +10,9 @@ # IMPORTS =============================================================================== +import json +import warnings + import numpy as np import time @@ -153,10 +156,10 @@ class Simulation: get attributes and access to intermediate evaluation stages logger : logging.Logger global simulation logger - _blocks_dyn : set[Block] - blocks with internal ´Solver´ instances (stateful) - _blocks_evt : set[Block] - blocks with internal events (discrete time, eventful) + _blocks_dyn : list[Block] + blocks with internal ´Solver´ instances (stateful) + _blocks_evt : list[Block] + blocks with internal events (discrete time, eventful) _active : bool flag for setting the simulation as active, used for interrupts """ @@ -177,9 +180,9 @@ def __init__( ): #system definition - self.blocks = set() - self.connections = set() - self.events = set() + self.blocks = [] + self.connections = [] + self.events = [] #simulation timestep and bounds self.dt = dt @@ -215,10 +218,10 @@ def __init__( self.time = 0.0 #collection of blocks with internal ODE solvers - self._blocks_dyn = set() + self._blocks_dyn = [] #collection of blocks with internal events - self._blocks_evt = set() + self._blocks_evt = [] #flag for setting the simulation active self._active = True @@ -269,8 +272,8 @@ def __contains__(self, other): bool """ return ( - other in self.blocks or - other in self.connections or + other in self.blocks or + other in self.connections or other in self.events ) @@ -331,6 +334,171 @@ def plot(self, *args, **kwargs): if block: block.plot(*args, **kwargs) + # checkpoint methods ---------------------------------------------------------- + + @staticmethod + def _checkpoint_key(type_name, type_counts): + """Generate a deterministic checkpoint key from block/event type + and occurrence index (e.g. 'Integrator_0', 'Scope_1'). + + Parameters + ---------- + type_name : str + class name of the block or event + type_counts : dict + running counter per type name, mutated in place + + Returns + ------- + key : str + deterministic checkpoint key + """ + idx = type_counts.get(type_name, 0) + type_counts[type_name] = idx + 1 + return f"{type_name}_{idx}" + + + def save_checkpoint(self, path, recordings=True): + """Save simulation state to checkpoint files (JSON + NPZ). + + Creates two files: {path}.json (structure/metadata) and + {path}.npz (numerical data). Blocks and events are keyed by + type and insertion order for deterministic cross-instance matching. + + Parameters + ---------- + path : str + base path without extension + recordings : bool + include scope/spectrum recording data (default: True) + """ + #strip extension if provided + if path.endswith('.json') or path.endswith('.npz'): + path = path.rsplit('.', 1)[0] + + #simulation metadata + checkpoint = { + "version": "1.0.0", + "pathsim_version": __version__, + "created": datetime.datetime.now(datetime.timezone.utc).isoformat(), + "simulation": { + "time": self.time, + "dt": self.dt, + "dt_min": self.dt_min, + "dt_max": self.dt_max, + "solver": self.Solver.__name__, + "tolerance_fpi": self.tolerance_fpi, + "iterations_max": self.iterations_max, + }, + "blocks": [], + "events": [], + } + + npz_data = {} + + #checkpoint all blocks (keyed by type + insertion index) + type_counts = {} + for block in self.blocks: + key = self._checkpoint_key(block.__class__.__name__, type_counts) + b_json, b_npz = block.to_checkpoint(key, recordings=recordings) + b_json["_key"] = key + checkpoint["blocks"].append(b_json) + npz_data.update(b_npz) + + #checkpoint external events (keyed by type + insertion index) + type_counts = {} + for event in self.events: + key = self._checkpoint_key(event.__class__.__name__, type_counts) + e_json, e_npz = event.to_checkpoint(key) + e_json["_key"] = key + checkpoint["events"].append(e_json) + npz_data.update(e_npz) + + #write files + with open(f"{path}.json", "w", encoding="utf-8") as f: + json.dump(checkpoint, f, indent=2, ensure_ascii=False) + + np.savez(f"{path}.npz", **npz_data) + + + def load_checkpoint(self, path): + """Load simulation state from checkpoint files (JSON + NPZ). + + Restores simulation time and all block/event states from a + previously saved checkpoint. Matching is based on block/event + type and insertion order, so the simulation must be constructed + with the same block types in the same order. + + Parameters + ---------- + path : str + base path without extension + """ + #strip extension if provided + if path.endswith('.json') or path.endswith('.npz'): + path = path.rsplit('.', 1)[0] + + #read files + with open(f"{path}.json", "r", encoding="utf-8") as f: + checkpoint = json.load(f) + + npz = np.load(f"{path}.npz", allow_pickle=False) + + try: + #version check + cp_version = checkpoint.get("pathsim_version", "unknown") + if cp_version != __version__: + warnings.warn( + f"Checkpoint was saved with PathSim {cp_version}, " + f"current version is {__version__}" + ) + + #restore simulation state + sim_data = checkpoint["simulation"] + self.time = sim_data["time"] + self.dt = sim_data["dt"] + self.dt_min = sim_data["dt_min"] + self.dt_max = sim_data["dt_max"] + + #solver type check + if sim_data["solver"] != self.Solver.__name__: + warnings.warn( + f"Checkpoint solver '{sim_data['solver']}' differs from " + f"current solver '{self.Solver.__name__}'" + ) + + #index checkpoint blocks by key + block_data = {b["_key"]: b for b in checkpoint.get("blocks", [])} + + #restore blocks by type + insertion order + type_counts = {} + for block in self.blocks: + key = self._checkpoint_key(block.__class__.__name__, type_counts) + if key in block_data: + block.load_checkpoint(key, block_data[key], npz) + else: + warnings.warn( + f"Block '{key}' not found in checkpoint" + ) + + #index checkpoint events by key + event_data = {e["_key"]: e for e in checkpoint.get("events", [])} + + #restore external events by type + insertion order + type_counts = {} + for event in self.events: + key = self._checkpoint_key(event.__class__.__name__, type_counts) + if key in event_data: + event.load_checkpoint(key, event_data[key], npz) + else: + warnings.warn( + f"Event '{key}' not found in checkpoint" + ) + + finally: + npz.close() + + # adding system components ---------------------------------------------------- def add_block(self, block): @@ -356,14 +524,14 @@ def add_block(self, block): #add to dynamic list if solver was initialized if block.engine: - self._blocks_dyn.add(block) + self._blocks_dyn.append(block) #add to eventful list if internal events if block.events: - self._blocks_evt.add(block) + self._blocks_evt.append(block) #add block to global blocklist - self.blocks.add(block) + self.blocks.append(block) #mark graph for rebuild if self.graph: @@ -389,13 +557,15 @@ def remove_block(self, block): raise ValueError(_msg) #remove from global blocklist - self.blocks.discard(block) + self.blocks.remove(block) #remove from dynamic list - self._blocks_dyn.discard(block) + if block in self._blocks_dyn: + self._blocks_dyn.remove(block) #remove from eventful list - self._blocks_evt.discard(block) + if block in self._blocks_evt: + self._blocks_evt.remove(block) #mark graph for rebuild if self.graph: @@ -421,7 +591,7 @@ def add_connection(self, connection): raise ValueError(_msg) #add connection to global connection list - self.connections.add(connection) + self.connections.append(connection) #mark graph for rebuild if self.graph: @@ -447,7 +617,7 @@ def remove_connection(self, connection): raise ValueError(_msg) #remove from global connection list - self.connections.discard(connection) + self.connections.remove(connection) #mark graph for rebuild if self.graph: @@ -472,7 +642,7 @@ def add_event(self, event): raise ValueError(_msg) #add event to global event list - self.events.add(event) + self.events.append(event) def remove_event(self, event): @@ -493,7 +663,7 @@ def remove_event(self, event): raise ValueError(_msg) #remove from global event list - self.events.discard(event) + self.events.remove(event) # system assembly ------------------------------------------------------------- @@ -551,10 +721,11 @@ def _check_blocks_are_managed(self): conn_blocks.update(conn.get_blocks()) # Check subset actively managed - if not conn_blocks.issubset(self.blocks): - self.logger.warning( - f"{blk} in 'connections' but not in 'blocks'!" - ) + for blk in conn_blocks: + if blk not in self.blocks: + self.logger.warning( + f"{blk} in 'connections' but not in 'blocks'!" + ) # solver management ----------------------------------------------------------- @@ -585,13 +756,13 @@ def _set_solver(self, Solver=None, **solver_kwargs): self.engine = self.Solver() #iterate all blocks and set integration engines with tolerances - self._blocks_dyn = set() + self._blocks_dyn = [] for block in self.blocks: block.set_solver(self.Solver, self.engine, **self.solver_kwargs) - + #add dynamic blocks to list if block.engine: - self._blocks_dyn.add(block) + self._blocks_dyn.append(block) #logging message self.logger.info( diff --git a/src/pathsim/solvers/_solver.py b/src/pathsim/solvers/_solver.py index 9cf00de9..d10bf16e 100644 --- a/src/pathsim/solvers/_solver.py +++ b/src/pathsim/solvers/_solver.py @@ -353,6 +353,71 @@ def create(cls, initial_value, parent=None, from_engine=None, **solver_kwargs): return cls(initial_value, parent, **solver_kwargs) + # checkpoint methods --------------------------------------------------------------- + + def to_checkpoint(self, prefix): + """Serialize solver state for checkpointing. + + Parameters + ---------- + prefix : str + NPZ key prefix for this solver's arrays + + Returns + ------- + json_data : dict + JSON-serializable metadata + npz_data : dict + numpy arrays keyed by path + """ + json_data = { + "type": self.__class__.__name__, + "is_adaptive": self.is_adaptive, + "n": self.n, + "history_len": len(self.history), + "history_maxlen": self.history.maxlen, + } + + npz_data = { + f"{prefix}/x": np.atleast_1d(self.x), + f"{prefix}/initial_value": np.atleast_1d(self.initial_value), + } + + for i, h in enumerate(self.history): + npz_data[f"{prefix}/history_{i}"] = np.atleast_1d(h) + + return json_data, npz_data + + + def load_checkpoint(self, json_data, npz, prefix): + """Restore solver state from checkpoint. + + Parameters + ---------- + json_data : dict + solver metadata from checkpoint JSON + npz : dict-like + numpy arrays from checkpoint NPZ + prefix : str + NPZ key prefix for this solver's arrays + """ + self.x = npz[f"{prefix}/x"].copy() + self.initial_value = npz[f"{prefix}/initial_value"].copy() + self.n = json_data.get("n", self.n) + + #restore scalar format if needed + if self._scalar_initial and self.initial_value.size == 1: + self.initial_value = self.initial_value.item() + + #restore history + maxlen = json_data.get("history_maxlen", self.history.maxlen) + self.history = deque([], maxlen=maxlen) + for i in range(json_data.get("history_len", 0)): + key = f"{prefix}/history_{i}" + if key in npz: + self.history.append(npz[key].copy()) + + # methods for adaptive timestep solvers -------------------------------------------- def error_controller(self): diff --git a/src/pathsim/solvers/gear.py b/src/pathsim/solvers/gear.py index e8cf269f..22968194 100644 --- a/src/pathsim/solvers/gear.py +++ b/src/pathsim/solvers/gear.py @@ -210,6 +210,50 @@ def create(cls, initial_value, parent=None, from_engine=None, **solver_kwargs): return cls(initial_value, parent, **solver_kwargs) + def to_checkpoint(self, prefix): + """Serialize GEAR solver state including startup solver and timestep history.""" + json_data, npz_data = super().to_checkpoint(prefix) + + json_data["_needs_startup"] = self._needs_startup + json_data["history_dt_len"] = len(self.history_dt) + + #timestep history + for i, dt in enumerate(self.history_dt): + npz_data[f"{prefix}/history_dt_{i}"] = np.atleast_1d(dt) + + #startup solver state + if self.startup: + s_json, s_npz = self.startup.to_checkpoint(f"{prefix}/startup") + json_data["startup"] = s_json + npz_data.update(s_npz) + + return json_data, npz_data + + + def load_checkpoint(self, json_data, npz, prefix): + """Restore GEAR solver state including startup solver and timestep history.""" + super().load_checkpoint(json_data, npz, prefix) + + self._needs_startup = json_data.get("_needs_startup", True) + + #restore timestep history + self.history_dt.clear() + for i in range(json_data.get("history_dt_len", 0)): + key = f"{prefix}/history_dt_{i}" + if key in npz: + self.history_dt.append(npz[key].item()) + + #restore startup solver + if self.startup and "startup" in json_data: + self.startup.load_checkpoint(json_data["startup"], npz, f"{prefix}/startup") + + #recompute BDF coefficients from restored history + if not self._needs_startup and len(self.history_dt) > 0: + self.F, self.K = {}, {} + for k, _ in enumerate(self.history_dt, 1): + self.F[k], self.K[k] = compute_bdf_coefficients(k, np.array(self.history_dt)) + + def stages(self, t, dt): """Generator that yields the intermediate evaluation time during the timestep 't + ratio * dt'. diff --git a/src/pathsim/subsystem.py b/src/pathsim/subsystem.py index 2dd3640c..6dec9691 100644 --- a/src/pathsim/subsystem.py +++ b/src/pathsim/subsystem.py @@ -182,26 +182,24 @@ def __init__(self, self.boosters = None #internal connecions - self.connections = set() - if connections: - self.connections.update(connections) - + self.connections = list(connections) if connections else [] + #collect and organize internal blocks - self.blocks = set() + self.blocks = [] self.interface = None if blocks: for block in blocks: - if isinstance(block, Interface): - + if isinstance(block, Interface): + if self.interface is not None: #interface block is already defined raise ValueError("Subsystem can only have one 'Interface' block!") - + self.interface = block - else: + else: #regular blocks - self.blocks.add(block) + self.blocks.append(block) #check if interface is defined if self.interface is None: @@ -276,7 +274,7 @@ def add_block(self, block): if block.engine: self._blocks_dyn.append(block) - self.blocks.add(block) + self.blocks.append(block) if self.graph: self._graph_dirty = True @@ -295,7 +293,7 @@ def remove_block(self, block): if block not in self.blocks: raise ValueError(f"block {block} not part of subsystem") - self.blocks.discard(block) + self.blocks.remove(block) #remove from dynamic list if hasattr(self, '_blocks_dyn') and block in self._blocks_dyn: @@ -318,7 +316,7 @@ def add_connection(self, connection): if connection in self.connections: raise ValueError(f"{connection} already part of subsystem") - self.connections.add(connection) + self.connections.append(connection) if self.graph: self._graph_dirty = True @@ -337,7 +335,7 @@ def remove_connection(self, connection): if connection not in self.connections: raise ValueError(f"{connection} not part of subsystem") - self.connections.discard(connection) + self.connections.remove(connection) if self.graph: self._graph_dirty = True @@ -386,7 +384,7 @@ def _assemble_graph(self): for block in self.blocks: block.inputs.reset() - self.graph = Graph({*self.blocks, self.interface}, self.connections) + self.graph = Graph([*self.blocks, self.interface], self.connections) self._graph_dirty = False #create boosters for loop closing connections @@ -457,6 +455,106 @@ def reset(self): block.reset() + @staticmethod + def _checkpoint_key(type_name, type_counts): + """Generate a deterministic checkpoint key from block/event type + and occurrence index (e.g. 'Integrator_0', 'Scope_1'). + """ + idx = type_counts.get(type_name, 0) + type_counts[type_name] = idx + 1 + return f"{type_name}_{idx}" + + + def to_checkpoint(self, prefix, recordings=False): + """Serialize subsystem state by recursively checkpointing internal blocks. + + Parameters + ---------- + prefix : str + key prefix for NPZ arrays (assigned by simulation) + recordings : bool + include recording data (for Scope blocks) + + Returns + ------- + json_data : dict + JSON-serializable metadata + npz_data : dict + numpy arrays keyed by path + """ + json_data = { + "type": self.__class__.__name__, + "active": self._active, + "blocks": [], + } + npz_data = {} + + #checkpoint interface block + if_json, if_npz = self.interface.to_checkpoint(f"{prefix}/interface", recordings=recordings) + json_data["interface"] = if_json + npz_data.update(if_npz) + + #checkpoint internal blocks by type + insertion order + type_counts = {} + for block in self.blocks: + key = f"{prefix}/{self._checkpoint_key(block.__class__.__name__, type_counts)}" + b_json, b_npz = block.to_checkpoint(key, recordings=recordings) + b_json["_key"] = key + json_data["blocks"].append(b_json) + npz_data.update(b_npz) + + #checkpoint subsystem-level events + if self._events: + evt_jsons = [] + for i, event in enumerate(self._events): + evt_prefix = f"{prefix}/evt_{i}" + e_json, e_npz = event.to_checkpoint(evt_prefix) + evt_jsons.append(e_json) + npz_data.update(e_npz) + json_data["events"] = evt_jsons + + return json_data, npz_data + + + def load_checkpoint(self, prefix, json_data, npz): + """Restore subsystem state by recursively loading internal blocks. + + Parameters + ---------- + prefix : str + key prefix for NPZ arrays (assigned by simulation) + json_data : dict + subsystem metadata from checkpoint JSON + npz : dict-like + numpy arrays from checkpoint NPZ + """ + #verify type + if json_data["type"] != self.__class__.__name__: + raise ValueError( + f"Checkpoint type mismatch: expected '{self.__class__.__name__}', " + f"got '{json_data['type']}'" + ) + + self._active = json_data["active"] + + #restore interface block + if "interface" in json_data: + self.interface.load_checkpoint(f"{prefix}/interface", json_data["interface"], npz) + + #restore internal blocks by type + insertion order + block_data = {b["_key"]: b for b in json_data.get("blocks", [])} + type_counts = {} + for block in self.blocks: + key = f"{prefix}/{self._checkpoint_key(block.__class__.__name__, type_counts)}" + if key in block_data: + block.load_checkpoint(key, block_data[key], npz) + + #restore subsystem-level events + if self._events and "events" in json_data: + for i, (event, evt_data) in enumerate(zip(self._events, json_data["events"])): + event.load_checkpoint(f"{prefix}/evt_{i}", evt_data, npz) + + def on(self): """Activate the subsystem and all internal blocks, sets the boolean evaluation flag to 'True'. diff --git a/src/pathsim/utils/adaptivebuffer.py b/src/pathsim/utils/adaptivebuffer.py index b24e2e5a..5b37fa05 100644 --- a/src/pathsim/utils/adaptivebuffer.py +++ b/src/pathsim/utils/adaptivebuffer.py @@ -10,6 +10,8 @@ # IMPORTS ============================================================================== +import numpy as np + from collections import deque from bisect import bisect_left @@ -120,4 +122,45 @@ def get(self, t): def clear(self): """clear the buffer, reset everything""" self.buffer_t.clear() - self.buffer_v.clear() \ No newline at end of file + self.buffer_v.clear() + + + def to_checkpoint(self, prefix): + """Serialize buffer state for checkpointing. + + Parameters + ---------- + prefix : str + NPZ key prefix + + Returns + ------- + npz_data : dict + numpy arrays keyed by path + """ + npz_data = {} + if self.buffer_t: + npz_data[f"{prefix}/buffer_t"] = np.array(list(self.buffer_t)) + npz_data[f"{prefix}/buffer_v"] = np.array(list(self.buffer_v)) + return npz_data + + + def load_checkpoint(self, npz, prefix): + """Restore buffer state from checkpoint. + + Parameters + ---------- + npz : dict-like + numpy arrays from checkpoint NPZ + prefix : str + NPZ key prefix + """ + self.clear() + t_key = f"{prefix}/buffer_t" + v_key = f"{prefix}/buffer_v" + if t_key in npz and v_key in npz: + times = npz[t_key] + values = npz[v_key] + for t, v in zip(times, values): + self.buffer_t.append(float(t)) + self.buffer_v.append(v) diff --git a/tests/pathsim/test_checkpoint.py b/tests/pathsim/test_checkpoint.py new file mode 100644 index 00000000..94fed3d6 --- /dev/null +++ b/tests/pathsim/test_checkpoint.py @@ -0,0 +1,769 @@ +"""Tests for checkpoint save/load functionality.""" + +import os +import json +import tempfile + +import numpy as np +import pytest + +from pathsim import Simulation, Connection, Subsystem, Interface +from pathsim.blocks import ( + Source, Integrator, Amplifier, Scope, Constant, Function +) +from pathsim.blocks.delay import Delay +from pathsim.blocks.switch import Switch +from pathsim.blocks.fir import FIR +from pathsim.blocks.kalman import KalmanFilter +from pathsim.blocks.noise import WhiteNoise, PinkNoise +from pathsim.blocks.rng import RandomNumberGenerator + + +class TestBlockCheckpoint: + """Test block-level checkpoint methods.""" + + def test_basic_block_to_checkpoint(self): + """Block produces valid checkpoint data.""" + b = Integrator(1.0) + b.inputs[0] = 3.14 + prefix = "Integrator_0" + json_data, npz_data = b.to_checkpoint(prefix) + + assert json_data["type"] == "Integrator" + assert json_data["active"] is True + assert f"{prefix}/inputs" in npz_data + assert f"{prefix}/outputs" in npz_data + + def test_block_checkpoint_roundtrip(self): + """Block state survives save/load cycle.""" + b = Integrator(2.5) + b.inputs[0] = 1.0 + b.outputs[0] = 2.5 + prefix = "Integrator_0" + + json_data, npz_data = b.to_checkpoint(prefix) + + #reset block + b.reset() + assert b.inputs[0] == 0.0 + + #restore + b.load_checkpoint(prefix, json_data, npz_data) + assert np.isclose(b.inputs[0], 1.0) + assert np.isclose(b.outputs[0], 2.5) + + def test_block_type_mismatch_raises(self): + """Loading checkpoint with wrong type raises ValueError.""" + b = Integrator() + prefix = "Integrator_0" + json_data, npz_data = b.to_checkpoint(prefix) + + b2 = Amplifier(1.0) + with pytest.raises(ValueError, match="type mismatch"): + b2.load_checkpoint(prefix, json_data, npz_data) + + +class TestEventCheckpoint: + """Test event-level checkpoint methods.""" + + def test_event_checkpoint_roundtrip(self): + from pathsim.events import ZeroCrossing + e = ZeroCrossing(func_evt=lambda t: t - 1.0) + e._history = (0.5, 0.99) + e._times = [1.0, 2.0, 3.0] + e._active = False + prefix = "ZeroCrossing_0" + + json_data, npz_data = e.to_checkpoint(prefix) + + e.reset() + assert e._active is True + assert len(e._times) == 0 + + e.load_checkpoint(prefix, json_data, npz_data) + assert e._active is False + assert e._times == [1.0, 2.0, 3.0] + assert e._history == (0.5, 0.99) + + +class TestSwitchCheckpoint: + """Test Switch block checkpoint.""" + + def test_switch_state_preserved(self): + s = Switch(switch_state=2) + prefix = "Switch_0" + json_data, npz_data = s.to_checkpoint(prefix) + + s.select(None) + assert s.switch_state is None + + s.load_checkpoint(prefix, json_data, npz_data) + assert s.switch_state == 2 + + +class TestSimulationCheckpoint: + """Test simulation-level checkpoint save/load.""" + + def test_save_load_simple(self): + """Simple simulation checkpoint round-trip.""" + src = Source(lambda t: np.sin(2 * np.pi * t)) + integ = Integrator() + scope = Scope() + + sim = Simulation( + blocks=[src, integ, scope], + connections=[ + Connection(src, integ, scope), + ], + dt=0.01 + ) + + #run for 1 second + sim.run(1.0) + time_after_run = sim.time + state_after_run = integ.state.copy() + + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "checkpoint") + sim.save_checkpoint(path) + + #verify files exist + assert os.path.exists(f"{path}.json") + assert os.path.exists(f"{path}.npz") + + #verify JSON structure + with open(f"{path}.json") as f: + data = json.load(f) + assert data["version"] == "1.0.0" + assert data["simulation"]["time"] == time_after_run + assert any(b["_key"] == "Integrator_0" for b in data["blocks"]) + + #reset and reload + sim.time = 0.0 + integ.state = np.array([0.0]) + + sim.load_checkpoint(path) + assert sim.time == time_after_run + assert np.allclose(integ.state, state_after_run) + + def test_continue_after_load(self): + """Simulation continues correctly after checkpoint load.""" + #run continuously for 2 seconds + src1 = Source(lambda t: 1.0) + integ1 = Integrator() + sim1 = Simulation( + blocks=[src1, integ1], + connections=[Connection(src1, integ1)], + dt=0.01 + ) + sim1.run(2.0) + reference_state = integ1.state.copy() + + #run for 1 second, save, load, run 1 more second + src2 = Source(lambda t: 1.0) + integ2 = Integrator() + sim2 = Simulation( + blocks=[src2, integ2], + connections=[Connection(src2, integ2)], + dt=0.01 + ) + sim2.run(1.0) + + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "cp") + sim2.save_checkpoint(path) + sim2.load_checkpoint(path) + sim2.run(1.0) # run 1 more second (t=1 -> t=2) + + #compare results + assert np.allclose(integ2.state, reference_state, rtol=1e-6) + + def test_scope_recordings(self): + """Scope recordings are saved when recordings=True.""" + src = Source(lambda t: t) + scope = Scope() + sim = Simulation( + blocks=[src, scope], + connections=[Connection(src, scope)], + dt=0.1 + ) + sim.run(1.0) + + with tempfile.TemporaryDirectory() as tmpdir: + #without recordings + path1 = os.path.join(tmpdir, "no_rec") + sim.save_checkpoint(path1, recordings=False) + npz1 = np.load(f"{path1}.npz") + assert "Scope_0/recording_time" not in npz1 + npz1.close() + + #with recordings + path2 = os.path.join(tmpdir, "with_rec") + sim.save_checkpoint(path2, recordings=True) + npz2 = np.load(f"{path2}.npz") + assert "Scope_0/recording_time" in npz2 + npz2.close() + + def test_delay_continuous_checkpoint(self): + """Continuous delay block preserves buffer.""" + src = Source(lambda t: np.sin(t)) + delay = Delay(tau=0.1) + scope = Scope() + sim = Simulation( + blocks=[src, delay, scope], + connections=[ + Connection(src, delay, scope), + ], + dt=0.01 + ) + sim.run(0.5) + + #capture delay output + delay_output = delay.outputs[0] + + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "cp") + sim.save_checkpoint(path) + + #reset delay buffer + delay._buffer.clear() + + sim.load_checkpoint(path) + assert np.isclose(delay.outputs[0], delay_output) + + def test_delay_discrete_checkpoint(self): + """Discrete delay block preserves ring buffer.""" + src = Source(lambda t: float(t > 0)) + delay = Delay(tau=0.05, sampling_period=0.01) + sim = Simulation( + blocks=[src, delay], + connections=[Connection(src, delay)], + dt=0.01 + ) + sim.run(0.1) + + ring_before = list(delay._ring) + + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "cp") + sim.save_checkpoint(path) + delay._ring.clear() + sim.load_checkpoint(path) + assert list(delay._ring) == ring_before + + def test_cross_instance_load(self): + """Checkpoint loads into a freshly constructed simulation (different UUIDs).""" + src1 = Source(lambda t: 1.0) + integ1 = Integrator() + sim1 = Simulation( + blocks=[src1, integ1], + connections=[Connection(src1, integ1)], + dt=0.01 + ) + sim1.run(1.0) + saved_time = sim1.time + saved_state = integ1.state.copy() + + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "cp") + sim1.save_checkpoint(path) + + #create entirely new simulation (new block objects, new UUIDs) + src2 = Source(lambda t: 1.0) + integ2 = Integrator() + sim2 = Simulation( + blocks=[src2, integ2], + connections=[Connection(src2, integ2)], + dt=0.01 + ) + + sim2.load_checkpoint(path) + assert sim2.time == saved_time + assert np.allclose(integ2.state, saved_state) + + def test_scope_recordings_preserved_without_flag(self): + """Loading without recordings flag does not erase existing recordings.""" + src = Source(lambda t: t) + scope = Scope() + sim = Simulation( + blocks=[src, scope], + connections=[Connection(src, scope)], + dt=0.1 + ) + sim.run(1.0) + + #scope has recordings + assert len(scope.recording_time) > 0 + rec_len = len(scope.recording_time) + + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "cp") + sim.save_checkpoint(path, recordings=False) + sim.load_checkpoint(path) + + #recordings should still be intact + assert len(scope.recording_time) == rec_len + + def test_multiple_same_type_blocks(self): + """Multiple blocks of the same type are matched by insertion order.""" + src = Source(lambda t: 1.0) + i1 = Integrator(1.0) + i2 = Integrator(2.0) + sim = Simulation( + blocks=[src, i1, i2], + connections=[Connection(src, i1), Connection(src, i2)], + dt=0.01 + ) + sim.run(0.5) + + state1 = i1.state.copy() + state2 = i2.state.copy() + assert not np.allclose(state1, state2) # different initial conditions + + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "cp") + sim.save_checkpoint(path) + + i1.state = np.array([0.0]) + i2.state = np.array([0.0]) + + sim.load_checkpoint(path) + assert np.allclose(i1.state, state1) + assert np.allclose(i2.state, state2) + + +class TestFIRCheckpoint: + """Test FIR block checkpoint.""" + + def test_fir_buffer_preserved(self): + """FIR filter buffer survives checkpoint round-trip.""" + fir = FIR(coeffs=[0.25, 0.5, 0.25], T=0.01) + prefix = "FIR_0" + + #simulate some input to fill the buffer + fir._buffer.appendleft(1.0) + fir._buffer.appendleft(2.0) + buffer_before = list(fir._buffer) + + json_data, npz_data = fir.to_checkpoint(prefix) + + fir._buffer.clear() + fir._buffer.extend([0.0] * 3) + + fir.load_checkpoint(prefix, json_data, npz_data) + assert list(fir._buffer) == buffer_before + + +class TestKalmanFilterCheckpoint: + """Test KalmanFilter block checkpoint.""" + + def test_kalman_state_preserved(self): + """Kalman filter state and covariance survive checkpoint.""" + F = np.array([[1.0, 0.1], [0.0, 1.0]]) + H = np.array([[1.0, 0.0]]) + Q = np.eye(2) * 0.01 + R = np.array([[0.1]]) + + kf = KalmanFilter(F, H, Q, R) + prefix = "KalmanFilter_0" + + #set some state + kf.x = np.array([3.14, -1.0]) + kf.P = np.array([[0.5, 0.1], [0.1, 0.3]]) + + json_data, npz_data = kf.to_checkpoint(prefix) + + kf.x = np.zeros(2) + kf.P = np.eye(2) + + kf.load_checkpoint(prefix, json_data, npz_data) + assert np.allclose(kf.x, [3.14, -1.0]) + assert np.allclose(kf.P, [[0.5, 0.1], [0.1, 0.3]]) + + +class TestNoiseCheckpoint: + """Test noise block checkpoints.""" + + def test_white_noise_sample_preserved(self): + """WhiteNoise current sample survives checkpoint.""" + wn = WhiteNoise(standard_deviation=2.0) + wn._current_sample = 1.234 + prefix = "WhiteNoise_0" + + json_data, npz_data = wn.to_checkpoint(prefix) + wn._current_sample = 0.0 + + wn.load_checkpoint(prefix, json_data, npz_data) + assert wn._current_sample == pytest.approx(1.234) + + def test_pink_noise_state_preserved(self): + """PinkNoise algorithm state survives checkpoint.""" + pn = PinkNoise(num_octaves=8, seed=42) + prefix = "PinkNoise_0" + + #advance the noise state + for _ in range(10): + pn._generate_sample(0.01) + + n_samples_before = pn.n_samples + octaves_before = pn.octave_values.copy() + sample_before = pn._current_sample + + json_data, npz_data = pn.to_checkpoint(prefix) + + pn.reset() + assert pn.n_samples == 0 + + pn.load_checkpoint(prefix, json_data, npz_data) + assert pn.n_samples == n_samples_before + assert np.allclose(pn.octave_values, octaves_before) + + +class TestRNGCheckpoint: + """Test RandomNumberGenerator checkpoint.""" + + def test_rng_sample_preserved(self): + """RNG current sample survives checkpoint (continuous mode).""" + rng = RandomNumberGenerator(sampling_period=None) + prefix = "RandomNumberGenerator_0" + sample_before = rng._sample + + json_data, npz_data = rng.to_checkpoint(prefix) + rng._sample = 0.0 + + rng.load_checkpoint(prefix, json_data, npz_data) + assert rng._sample == pytest.approx(sample_before) + + +class TestSubsystemCheckpoint: + """Test Subsystem checkpoint.""" + + def test_subsystem_roundtrip(self): + """Subsystem with internal blocks survives checkpoint round-trip.""" + #build a simple subsystem: two integrators in series + If = Interface() + I1 = Integrator(1.0) + I2 = Integrator(0.0) + + sub = Subsystem( + blocks=[If, I1, I2], + connections=[ + Connection(If, I1), + Connection(I1, I2), + Connection(I2, If), + ] + ) + + #embed in a simulation + src = Source(lambda t: 1.0) + scope = Scope() + sim = Simulation( + blocks=[src, sub, scope], + connections=[ + Connection(src, sub), + Connection(sub, scope), + ], + dt=0.01 + ) + + sim.run(0.5) + state_I1 = I1.state.copy() + state_I2 = I2.state.copy() + + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "cp") + sim.save_checkpoint(path) + + #zero out states + I1.state = np.array([0.0]) + I2.state = np.array([0.0]) + + sim.load_checkpoint(path) + assert np.allclose(I1.state, state_I1) + assert np.allclose(I2.state, state_I2) + + def test_subsystem_cross_instance(self): + """Subsystem checkpoint loads into a fresh simulation instance.""" + If1 = Interface() + I1 = Integrator(1.0) + sub1 = Subsystem( + blocks=[If1, I1], + connections=[Connection(If1, I1), Connection(I1, If1)] + ) + src1 = Source(lambda t: 1.0) + sim1 = Simulation( + blocks=[src1, sub1], + connections=[Connection(src1, sub1)], + dt=0.01 + ) + sim1.run(0.5) + state_before = I1.state.copy() + + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "cp") + sim1.save_checkpoint(path) + + #new instance + If2 = Interface() + I2 = Integrator(1.0) + sub2 = Subsystem( + blocks=[If2, I2], + connections=[Connection(If2, I2), Connection(I2, If2)] + ) + src2 = Source(lambda t: 1.0) + sim2 = Simulation( + blocks=[src2, sub2], + connections=[Connection(src2, sub2)], + dt=0.01 + ) + sim2.load_checkpoint(path) + assert np.allclose(I2.state, state_before) + + +class TestGEARCheckpoint: + """Test GEAR solver checkpoint round-trip.""" + + def test_gear_solver_roundtrip(self): + """GEAR solver state survives checkpoint including BDF coefficients.""" + from pathsim.solvers import GEAR32 + + src = Source(lambda t: np.sin(2 * np.pi * t)) + integ = Integrator() + sim = Simulation( + blocks=[src, integ], + connections=[Connection(src, integ)], + dt=0.01, + Solver=GEAR32 + ) + + #run long enough for GEAR to exit startup phase + sim.run(0.5) + state_after = integ.state.copy() + time_after = sim.time + + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "cp") + sim.save_checkpoint(path) + + #reset state + integ.state = np.array([0.0]) + sim.time = 0.0 + + sim.load_checkpoint(path) + assert sim.time == time_after + assert np.allclose(integ.state, state_after) + + def test_gear_continue_after_load(self): + """GEAR simulation continues correctly after checkpoint load.""" + from pathsim.solvers import GEAR32 + + #reference: run 2s continuously + src1 = Source(lambda t: 1.0) + integ1 = Integrator() + sim1 = Simulation( + blocks=[src1, integ1], + connections=[Connection(src1, integ1)], + dt=0.01, + Solver=GEAR32 + ) + sim1.run(2.0) + reference = integ1.state.copy() + + #split: run 1s, save, load, run 1s more + src2 = Source(lambda t: 1.0) + integ2 = Integrator() + sim2 = Simulation( + blocks=[src2, integ2], + connections=[Connection(src2, integ2)], + dt=0.01, + Solver=GEAR32 + ) + sim2.run(1.0) + + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "cp") + sim2.save_checkpoint(path) + sim2.load_checkpoint(path) + sim2.run(1.0) + + assert np.allclose(integ2.state, reference, rtol=1e-6) + + +class TestSpectrumCheckpoint: + """Test Spectrum block checkpoint.""" + + def test_spectrum_roundtrip(self): + """Spectrum block state survives checkpoint round-trip.""" + from pathsim.blocks.spectrum import Spectrum + + src = Source(lambda t: np.sin(2 * np.pi * 10 * t)) + spec = Spectrum(freq=[5, 10, 15], t_wait=0.0) + sim = Simulation( + blocks=[src, spec], + connections=[Connection(src, spec)], + dt=0.001 + ) + sim.run(0.1) + + time_before = spec.time + t_sample_before = spec.t_sample + + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "cp") + sim.save_checkpoint(path) + + spec.time = 0.0 + spec.t_sample = 0.0 + + sim.load_checkpoint(path) + assert spec.time == pytest.approx(time_before) + assert spec.t_sample == pytest.approx(t_sample_before) + + +class TestScopeCheckpointExtended: + """Extended Scope checkpoint tests for coverage.""" + + def test_scope_with_sampling_period(self): + """Scope with sampling_period preserves _sample_next_timestep.""" + src = Source(lambda t: t) + scope = Scope(sampling_period=0.1) + sim = Simulation( + blocks=[src, scope], + connections=[Connection(src, scope)], + dt=0.01 + ) + sim.run(0.5) + + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "cp") + sim.save_checkpoint(path) + sim.load_checkpoint(path) + + #verify scope still works after load + sim.run(0.1) + assert len(scope.recording_time) > 0 + + def test_scope_recordings_roundtrip(self): + """Scope recording data round-trips with recordings=True.""" + src = Source(lambda t: t) + scope = Scope() + sim = Simulation( + blocks=[src, scope], + connections=[Connection(src, scope)], + dt=0.1 + ) + sim.run(1.0) + + rec_time = scope.recording_time.copy() + rec_data = [row.copy() for row in scope.recording_data] + + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "cp") + sim.save_checkpoint(path, recordings=True) + + #clear recordings + scope.recording_time = [] + scope.recording_data = [] + + sim.load_checkpoint(path) + assert len(scope.recording_time) == len(rec_time) + assert np.allclose(scope.recording_time, rec_time) + + def test_scope_recordings_included_by_default(self): + """Default save_checkpoint includes recordings.""" + src = Source(lambda t: t) + scope = Scope() + sim = Simulation( + blocks=[src, scope], + connections=[Connection(src, scope)], + dt=0.1 + ) + sim.run(1.0) + + rec_time = scope.recording_time.copy() + assert len(rec_time) > 0 + + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "cp") + sim.save_checkpoint(path) # no recordings kwarg — default + + #clear recordings + scope.recording_time = [] + scope.recording_data = [] + + sim.load_checkpoint(path) + assert len(scope.recording_time) == len(rec_time) + assert np.allclose(scope.recording_time, rec_time) + + +class TestSimulationCheckpointExtended: + """Extended simulation checkpoint tests for coverage.""" + + def test_save_load_with_extension(self): + """Path with .json extension is handled correctly.""" + src = Source(lambda t: 1.0) + integ = Integrator() + sim = Simulation( + blocks=[src, integ], + connections=[Connection(src, integ)], + dt=0.01 + ) + sim.run(0.1) + + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "cp.json") + sim.save_checkpoint(path) + + assert os.path.exists(os.path.join(tmpdir, "cp.json")) + assert os.path.exists(os.path.join(tmpdir, "cp.npz")) + + sim.load_checkpoint(path) + assert sim.time == pytest.approx(0.1, abs=0.01) + + def test_checkpoint_with_events(self): + """Simulation with external events checkpoints correctly.""" + from pathsim.events import Schedule + + src = Source(lambda t: 1.0) + integ = Integrator() + + event_fired = [False] + def on_event(t): + event_fired[0] = True + + evt = Schedule(t_start=0.5, t_period=1.0, func_act=on_event) + + sim = Simulation( + blocks=[src, integ], + connections=[Connection(src, integ)], + events=[evt], + dt=0.01 + ) + sim.run(1.0) + + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "cp") + sim.save_checkpoint(path) + + #verify events in JSON + with open(f"{path}.json") as f: + data = json.load(f) + assert len(data["events"]) == 1 + assert data["events"][0]["type"] == "Schedule" + + sim.load_checkpoint(path) + + def test_event_numpy_history(self): + """Event with numpy scalar in history serializes correctly.""" + from pathsim.events import ZeroCrossing + + e = ZeroCrossing(func_evt=lambda t: t - 1.0) + e._history = (np.float64(0.5), 0.99) + prefix = "ZeroCrossing_0" + + json_data, npz_data = e.to_checkpoint(prefix) + assert isinstance(json_data["history_eval"], float) + + e.reset() + e.load_checkpoint(prefix, json_data, npz_data) + assert e._history[0] == pytest.approx(0.5) diff --git a/tests/pathsim/test_simulation.py b/tests/pathsim/test_simulation.py index beda6aae..d370b473 100644 --- a/tests/pathsim/test_simulation.py +++ b/tests/pathsim/test_simulation.py @@ -52,9 +52,9 @@ def test_init_default(self): #test default initialization Sim = Simulation(log=False) - self.assertEqual(Sim.blocks, set()) - self.assertEqual(Sim.connections, set()) - self.assertEqual(Sim.events, set()) + self.assertEqual(Sim.blocks, []) + self.assertEqual(Sim.connections, []) + self.assertEqual(Sim.events, []) self.assertEqual(Sim.dt, SIM_TIMESTEP) self.assertEqual(Sim.dt_min, SIM_TIMESTEP_MIN) self.assertEqual(Sim.dt_max, SIM_TIMESTEP_MAX) @@ -130,12 +130,12 @@ def test_add_block(self): Sim = Simulation(log=False) - self.assertEqual(Sim.blocks, set()) + self.assertEqual(Sim.blocks, []) #test adding a block B1 = Block() Sim.add_block(B1) - self.assertEqual(Sim.blocks, {B1}) + self.assertEqual(Sim.blocks, [B1]) #test adding the same block again with self.assertRaises(ValueError): @@ -153,17 +153,17 @@ def test_add_connection(self): log=False ) - self.assertEqual(Sim.connections, {C1}) + self.assertEqual(Sim.connections, [C1]) #test adding a connection C2 = Connection(B2, B3) Sim.add_connection(C2) - self.assertEqual(Sim.connections, {C1, C2}) + self.assertEqual(Sim.connections, [C1, C2]) #test adding the same connection again with self.assertRaises(ValueError): Sim.add_connection(C2) - self.assertEqual(Sim.connections, {C1, C2}) + self.assertEqual(Sim.connections, [C1, C2]) def test_set_solver(self):