Skip to content

Commit b547c2f

Browse files
committed
Add fine-grained per-block diagnostics for convergence debugging
1 parent d2a2f91 commit b547c2f

File tree

3 files changed

+447
-19
lines changed

3 files changed

+447
-19
lines changed

src/pathsim/simulation.py

Lines changed: 76 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from .utils.deprecation import deprecated
3838
from .utils.portreference import PortReference
3939
from .utils.progresstracker import ProgressTracker
40+
from .utils.diagnostics import Diagnostics
4041
from .utils.logger import LoggerManager
4142

4243
from .solvers import SSPRK22, SteadyState
@@ -165,17 +166,18 @@ class Simulation:
165166
"""
166167

167168
def __init__(
168-
self,
169-
blocks=None,
170-
connections=None,
169+
self,
170+
blocks=None,
171+
connections=None,
171172
events=None,
172-
dt=SIM_TIMESTEP,
173-
dt_min=SIM_TIMESTEP_MIN,
174-
dt_max=SIM_TIMESTEP_MAX,
175-
Solver=SSPRK22,
176-
tolerance_fpi=SIM_TOLERANCE_FPI,
177-
iterations_max=SIM_ITERATIONS_MAX,
173+
dt=SIM_TIMESTEP,
174+
dt_min=SIM_TIMESTEP_MIN,
175+
dt_max=SIM_TIMESTEP_MAX,
176+
Solver=SSPRK22,
177+
tolerance_fpi=SIM_TOLERANCE_FPI,
178+
iterations_max=SIM_ITERATIONS_MAX,
178179
log=LOG_ENABLE,
180+
diagnostics=False,
179181
**solver_kwargs
180182
):
181183

@@ -226,6 +228,12 @@ def __init__(
226228
#flag for setting the simulation active
227229
self._active = True
228230

231+
#diagnostics snapshot (None when disabled)
232+
self.diagnostics = Diagnostics() if diagnostics else None
233+
234+
#diagnostics history (list of snapshots per timestep)
235+
self._diagnostics_history = [] if diagnostics == "history" else None
236+
229237
#initialize logging
230238
logger_mgr = LoggerManager(
231239
enabled=bool(self.log),
@@ -815,6 +823,12 @@ def reset(self, time=0.0):
815823
for event in self.events:
816824
event.reset()
817825

826+
#reset diagnostics
827+
if self.diagnostics is not None:
828+
self.diagnostics = Diagnostics()
829+
if self._diagnostics_history is not None:
830+
self._diagnostics_history.clear()
831+
818832
#evaluate system function
819833
self._update(self.time)
820834

@@ -1027,18 +1041,24 @@ def _loops(self, t):
10271041

10281042
#step boosters of loop closing connections
10291043
max_err = 0.0
1044+
self._loop_errors = {}
10301045
for con_booster in self.boosters:
10311046
err = con_booster.update()
1047+
self._loop_errors[con_booster] = err
10321048
if err > max_err:
10331049
max_err = err
1034-
1050+
10351051
#check convergence
10361052
if max_err <= self.tolerance_fpi:
1053+
self._loop_iterations = iteration
10371054
return
10381055

1039-
#not converged -> error
1040-
_msg = "algebraic loop not converged (iters: {}, err: {})".format(
1041-
self.iterations_max, max_err
1056+
self._loop_iterations = self.iterations_max
1057+
1058+
#not converged -> error with per-connection details
1059+
details = [f" {b.connection}: {e:.2e}" for b, e in self._loop_errors.items()]
1060+
_msg = "algebraic loop not converged (iters: {}, err: {:.2e})\n{}".format(
1061+
self.iterations_max, max_err, "\n".join(details)
10421062
)
10431063
self.logger.error(_msg)
10441064
raise RuntimeError(_msg)
@@ -1075,31 +1095,36 @@ def _solve(self, t, dt):
10751095
#total evaluations of system equation
10761096
total_evals = 0
10771097

1098+
#per-block residuals (overwritten each iteration, only final kept)
1099+
self._solve_errors = {}
1100+
10781101
#perform fixed-point iterations to solve implicit update equation
10791102
for it in range(self.iterations_max):
10801103

10811104
#evaluate system equation (this is a fixed point loop)
10821105
self._update(t)
1083-
total_evals += 1
1106+
total_evals += 1
10841107

10851108
#advance solution of implicit solver
10861109
max_error = 0.0
10871110
for block in self._blocks_dyn:
10881111

10891112
#skip inactive blocks
1090-
if not block:
1113+
if not block:
10911114
continue
1092-
1115+
10931116
#advance solution (internal optimizer)
10941117
error = block.solve(t, dt)
1118+
self._solve_errors[block] = error
10951119
if error > max_error:
10961120
max_error = error
10971121

10981122
#check for convergence (only error)
10991123
if max_error <= self.tolerance_fpi:
1124+
self._solve_iterations = it + 1
11001125
return True, total_evals, it+1
11011126

1102-
#not converged in 'self.iterations_max' steps
1127+
self._solve_iterations = self.iterations_max
11031128
return False, total_evals, self.iterations_max
11041129

11051130

@@ -1156,8 +1181,10 @@ def steadystate(self, reset=False):
11561181

11571182
#catch non convergence
11581183
if not success:
1159-
_msg = "STEADYSTATE -> FINISHED (success: {}, evals: {}, iters: {}, runtime: {})".format(
1160-
success, evals, iters, T)
1184+
details = [f" {b.__class__.__name__}: {e:.2e}"
1185+
for b, e in self._solve_errors.items()]
1186+
_msg = "STEADYSTATE -> FAILED (evals: {}, iters: {}, runtime: {})\n{}".format(
1187+
evals, iters, T, "\n".join(details))
11611188
self.logger.error(_msg)
11621189
raise RuntimeError(_msg)
11631190

@@ -1278,6 +1305,7 @@ def _step(self, t, dt):
12781305

12791306
#initial timestep rescale and error estimate
12801307
success, max_error_norm, min_scale = True, 0.0, None
1308+
self._step_errors = {}
12811309

12821310
#step blocks and get error estimates if available
12831311
for block in self._blocks_dyn:
@@ -1287,6 +1315,7 @@ def _step(self, t, dt):
12871315

12881316
#step the block
12891317
suc, err_norm, scl = block.step(t, dt)
1318+
self._step_errors[block] = (suc, err_norm, scl)
12901319

12911320
#check solver stepping success
12921321
if not suc:
@@ -1471,8 +1500,23 @@ def timestep(self, dt=None, adaptive=True):
14711500

14721501
#adaptive implicit: revert if solver didn't converge
14731502
if not success and is_adaptive:
1503+
details = [f" {b.__class__.__name__}: {e:.2e}"
1504+
for b, e in self._solve_errors.items()]
1505+
self.logger.warning(
1506+
"implicit solver not converged, reverting step at t={:.6f}\n{}".format(
1507+
time_stage, "\n".join(details))
1508+
)
14741509
self._revert(self.time)
14751510
return False, 0.0, 0.5, total_evals + 1, total_solver_its
1511+
1512+
#fixed implicit: warn if solver didn't converge
1513+
if not success and not is_adaptive:
1514+
details = [f" {b.__class__.__name__}: {e:.2e}"
1515+
for b, e in self._solve_errors.items()]
1516+
self.logger.warning(
1517+
"implicit solver not converged at t={:.6f} (iters: {})\n{}".format(
1518+
time_stage, solver_its, "\n".join(details))
1519+
)
14761520
else:
14771521
#explicit: evaluate system equation
14781522
self._update(time_stage)
@@ -1511,6 +1555,19 @@ def timestep(self, dt=None, adaptive=True):
15111555
self._update(time_dt)
15121556
total_evals += 1
15131557

1558+
#update diagnostics snapshot for this timestep
1559+
if self.diagnostics is not None:
1560+
self.diagnostics = Diagnostics(
1561+
time=time_dt,
1562+
loop_residuals=dict(getattr(self, '_loop_errors', {})),
1563+
loop_iterations=getattr(self, '_loop_iterations', 0),
1564+
solve_residuals=dict(getattr(self, '_solve_errors', {})),
1565+
solve_iterations=getattr(self, '_solve_iterations', 0),
1566+
step_errors=dict(getattr(self, '_step_errors', {})),
1567+
)
1568+
if self._diagnostics_history is not None:
1569+
self._diagnostics_history.append(self.diagnostics)
1570+
15141571
#sample data after successful timestep
15151572
self._sample(time_dt, dt)
15161573

src/pathsim/utils/diagnostics.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
#########################################################################################
2+
##
3+
## DIAGNOSTICS FOR SIMULATION
4+
## (utils/diagnostics.py)
5+
##
6+
## Fine-grained per-block and per-connection convergence metrics
7+
## for debugging non-converging simulations.
8+
##
9+
#########################################################################################
10+
11+
# IMPORTS ===============================================================================
12+
13+
from dataclasses import dataclass, field
14+
15+
16+
# DIAGNOSTICS ===========================================================================
17+
18+
@dataclass
class Diagnostics:
    """Per-timestep convergence diagnostics snapshot.

    Populated by the simulation engine after each successful timestep.
    Stores per-block and per-connection residuals from the three
    convergence loops: algebraic loop solver, implicit ODE solver,
    and adaptive error control.

    Residuals that are NaN are treated as the worst possible value by
    `worst_block` and `worst_booster`, so that a diverged block or
    connection is reported instead of being silently skipped.

    Attributes
    ----------
    time : float
        simulation time
    loop_residuals : dict
        per-booster algebraic loop residuals (booster -> residual)
    loop_iterations : int
        number of algebraic loop iterations taken
    solve_residuals : dict
        per-block implicit solver residuals (block -> residual)
    solve_iterations : int
        number of implicit solver iterations taken
    step_errors : dict
        per-block adaptive step data (block -> (success, err_norm, scale))
    """
    time: float = 0.0
    loop_residuals: dict = field(default_factory=dict)
    loop_iterations: int = 0
    solve_residuals: dict = field(default_factory=dict)
    solve_iterations: int = 0
    step_errors: dict = field(default_factory=dict)


    @staticmethod
    def _label(obj):
        """Human-readable label for a block or booster.

        Objects exposing a 'connection' attribute are labelled by the
        string form of that connection, everything else by class name.
        """
        if hasattr(obj, 'connection'):
            return str(obj.connection)
        return obj.__class__.__name__


    @staticmethod
    def _severity(err):
        """Sort key for residuals that ranks NaN above every other value.

        NaN compares false against everything, so a plain '>' or 'max'
        would never select it (or select it depending on ordering).
        """
        #local import keeps the module level import block untouched
        import math
        return (math.isnan(err), err)


    def worst_block(self):
        """Block with the highest residual across solve and step errors.

        A NaN residual counts as the highest. On ties the first entry
        wins, with solve residuals considered before step errors.

        Returns
        -------
        tuple[str, float] or None
            (label, error) or None if no data
        """

        #collect (object, error) pairs, solve residuals first
        candidates = list(self.solve_residuals.items())
        candidates.extend(
            (obj, err_norm) for obj, (_, err_norm, _) in self.step_errors.items()
            )

        if not candidates:
            return None

        worst, worst_err = max(candidates, key=lambda item: self._severity(item[1]))
        return self._label(worst), worst_err


    def worst_booster(self):
        """Connection booster with the highest algebraic loop residual.

        A NaN residual counts as the highest.

        Returns
        -------
        tuple[str, float] or None
            (label, residual) or None if no data
        """
        if not self.loop_residuals:
            return None

        worst, worst_err = max(
            self.loop_residuals.items(),
            key=lambda item: self._severity(item[1])
            )
        return self._label(worst), worst_err


    def summary(self):
        """Formatted summary of this diagnostics snapshot.

        Sections without data are omitted, the header line with the
        simulation time is always present.

        Returns
        -------
        str
            human-readable diagnostics summary
        """
        lines = [f"Diagnostics at t = {self.time:.6f}"]

        if self.step_errors:
            lines.append("\n Adaptive step errors:")
            for obj, (suc, err, scl) in self.step_errors.items():
                status = "OK" if suc else "FAIL"

                #blocks without adaptive stepping report no scale
                scl_str = f"{scl:.3f}" if scl is not None else "N/A"
                lines.append(f" {status} {self._label(obj)}: err={err:.2e}, scale={scl_str}")

        if self.solve_residuals:
            lines.append(f"\n Implicit solver residuals ({self.solve_iterations} iterations):")
            for obj, err in self.solve_residuals.items():
                lines.append(f" {self._label(obj)}: {err:.2e}")

        if self.loop_residuals:
            lines.append(f"\n Algebraic loop residuals ({self.loop_iterations} iterations):")
            for obj, err in self.loop_residuals.items():
                lines.append(f" {self._label(obj)}: {err:.2e}")

        return "\n".join(lines)

0 commit comments

Comments
 (0)