Skip to content

Commit 371e893

Browse files
committed
Add ConvergenceTracker/StepTracker classes with optional diagnostics snapshots
1 parent d2a2f91 commit 371e893

File tree

3 files changed

+679
-59
lines changed

3 files changed

+679
-59
lines changed

src/pathsim/simulation.py

Lines changed: 82 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from .utils.deprecation import deprecated
3838
from .utils.portreference import PortReference
3939
from .utils.progresstracker import ProgressTracker
40+
from .utils.diagnostics import Diagnostics, ConvergenceTracker, StepTracker
4041
from .utils.logger import LoggerManager
4142

4243
from .solvers import SSPRK22, SteadyState
@@ -165,17 +166,18 @@ class Simulation:
165166
"""
166167

167168
def __init__(
168-
self,
169-
blocks=None,
170-
connections=None,
169+
self,
170+
blocks=None,
171+
connections=None,
171172
events=None,
172-
dt=SIM_TIMESTEP,
173-
dt_min=SIM_TIMESTEP_MIN,
174-
dt_max=SIM_TIMESTEP_MAX,
175-
Solver=SSPRK22,
176-
tolerance_fpi=SIM_TOLERANCE_FPI,
177-
iterations_max=SIM_ITERATIONS_MAX,
173+
dt=SIM_TIMESTEP,
174+
dt_min=SIM_TIMESTEP_MIN,
175+
dt_max=SIM_TIMESTEP_MAX,
176+
Solver=SSPRK22,
177+
tolerance_fpi=SIM_TOLERANCE_FPI,
178+
iterations_max=SIM_ITERATIONS_MAX,
178179
log=LOG_ENABLE,
180+
diagnostics=False,
179181
**solver_kwargs
180182
):
181183

@@ -226,6 +228,17 @@ def __init__(
226228
#flag for setting the simulation active
227229
self._active = True
228230

231+
#convergence trackers for the three solver loops
232+
self._loop_tracker = ConvergenceTracker()
233+
self._solve_tracker = ConvergenceTracker()
234+
self._step_tracker = StepTracker()
235+
236+
#diagnostics snapshot (None when disabled)
237+
self.diagnostics = Diagnostics() if diagnostics else None
238+
239+
#diagnostics history (list of snapshots per timestep)
240+
self._diagnostics_history = [] if diagnostics == "history" else None
241+
229242
#initialize logging
230243
logger_mgr = LoggerManager(
231244
enabled=bool(self.log),
@@ -815,6 +828,15 @@ def reset(self, time=0.0):
815828
for event in self.events:
816829
event.reset()
817830

831+
#reset convergence trackers and diagnostics
832+
self._loop_tracker.reset()
833+
self._solve_tracker.reset()
834+
self._step_tracker.reset()
835+
if self.diagnostics is not None:
836+
self.diagnostics = Diagnostics()
837+
if self._diagnostics_history is not None:
838+
self._diagnostics_history.clear()
839+
818840
#evaluate system function
819841
self._update(self.time)
820842

@@ -1026,19 +1048,20 @@ def _loops(self, t):
10261048
if connection: connection.update()
10271049

10281050
#step boosters of loop closing connections
1029-
max_err = 0.0
1051+
self._loop_tracker.begin_iteration()
10301052
for con_booster in self.boosters:
1031-
err = con_booster.update()
1032-
if err > max_err:
1033-
max_err = err
1034-
1053+
self._loop_tracker.record(con_booster, con_booster.update())
1054+
10351055
#check convergence
1036-
if max_err <= self.tolerance_fpi:
1056+
if self._loop_tracker.converged(self.tolerance_fpi):
1057+
self._loop_tracker.iterations = iteration
10371058
return
10381059

1039-
#not converged -> error
1040-
_msg = "algebraic loop not converged (iters: {}, err: {})".format(
1041-
self.iterations_max, max_err
1060+
#not converged -> error with per-connection details
1061+
self._loop_tracker.iterations = self.iterations_max
1062+
details = self._loop_tracker.details(lambda b: str(b.connection))
1063+
_msg = "algebraic loop not converged (iters: {}, err: {:.2e})\n{}".format(
1064+
self.iterations_max, self._loop_tracker.max_error, "\n".join(details)
10421065
)
10431066
self.logger.error(_msg)
10441067
raise RuntimeError(_msg)
@@ -1080,26 +1103,21 @@ def _solve(self, t, dt):
10801103

10811104
#evaluate system equation (this is a fixed point loop)
10821105
self._update(t)
1083-
total_evals += 1
1106+
total_evals += 1
10841107

10851108
#advance solution of implicit solver
1086-
max_error = 0.0
1109+
self._solve_tracker.begin_iteration()
10871110
for block in self._blocks_dyn:
1088-
1089-
#skip inactive blocks
1090-
if not block:
1111+
if not block:
10911112
continue
1092-
1093-
#advance solution (internal optimizer)
1094-
error = block.solve(t, dt)
1095-
if error > max_error:
1096-
max_error = error
1113+
self._solve_tracker.record(block, block.solve(t, dt))
10971114

1098-
#check for convergence (only error)
1099-
if max_error <= self.tolerance_fpi:
1100-
return True, total_evals, it+1
1115+
#check for convergence
1116+
if self._solve_tracker.converged(self.tolerance_fpi):
1117+
self._solve_tracker.iterations = it + 1
1118+
return True, total_evals, it + 1
11011119

1102-
#not converged in 'self.iterations_max' steps
1120+
self._solve_tracker.iterations = self.iterations_max
11031121
return False, total_evals, self.iterations_max
11041122

11051123

@@ -1156,8 +1174,9 @@ def steadystate(self, reset=False):
11561174

11571175
#catch non convergence
11581176
if not success:
1159-
_msg = "STEADYSTATE -> FINISHED (success: {}, evals: {}, iters: {}, runtime: {})".format(
1160-
success, evals, iters, T)
1177+
details = self._solve_tracker.details(lambda b: b.__class__.__name__)
1178+
_msg = "STEADYSTATE -> FAILED (evals: {}, iters: {}, runtime: {})\n{}".format(
1179+
evals, iters, T, "\n".join(details))
11611180
self.logger.error(_msg)
11621181
raise RuntimeError(_msg)
11631182

@@ -1276,32 +1295,14 @@ def _step(self, t, dt):
12761295
rescale factor for timestep
12771296
"""
12781297

1279-
#initial timestep rescale and error estimate
1280-
success, max_error_norm, min_scale = True, 0.0, None
1298+
self._step_tracker.reset()
12811299

1282-
#step blocks and get error estimates if available
12831300
for block in self._blocks_dyn:
1284-
1285-
#skip inactive blocks
12861301
if not block: continue
1287-
1288-
#step the block
12891302
suc, err_norm, scl = block.step(t, dt)
1303+
self._step_tracker.record(block, suc, err_norm, scl)
12901304

1291-
#check solver stepping success
1292-
if not suc:
1293-
success = False
1294-
1295-
#update error tracking
1296-
if err_norm > max_error_norm:
1297-
max_error_norm = err_norm
1298-
1299-
#track minimum relevant scale directly (avoids list allocation)
1300-
if scl is not None:
1301-
if min_scale is None or scl < min_scale:
1302-
min_scale = scl
1303-
1304-
return success, max_error_norm, min_scale if min_scale is not None else 1.0
1305+
return self._step_tracker.success, self._step_tracker.max_error, self._step_tracker.scale
13051306

13061307

13071308
# timestepping ----------------------------------------------------------------
@@ -1469,10 +1470,19 @@ def timestep(self, dt=None, adaptive=True):
14691470
total_evals += evals
14701471
total_solver_its += solver_its
14711472

1472-
#adaptive implicit: revert if solver didn't converge
1473-
if not success and is_adaptive:
1474-
self._revert(self.time)
1475-
return False, 0.0, 0.5, total_evals + 1, total_solver_its
1473+
#implicit solver didn't converge
1474+
if not success:
1475+
details = self._solve_tracker.details(lambda b: b.__class__.__name__)
1476+
if is_adaptive:
1477+
self.logger.warning(
1478+
"implicit solver not converged, reverting step at t={:.6f}\n{}".format(
1479+
time_stage, "\n".join(details)))
1480+
self._revert(self.time)
1481+
return False, 0.0, 0.5, total_evals + 1, total_solver_its
1482+
else:
1483+
self.logger.warning(
1484+
"implicit solver not converged at t={:.6f} (iters: {})\n{}".format(
1485+
time_stage, solver_its, "\n".join(details)))
14761486
else:
14771487
#explicit: evaluate system equation
14781488
self._update(time_stage)
@@ -1511,6 +1521,19 @@ def timestep(self, dt=None, adaptive=True):
15111521
self._update(time_dt)
15121522
total_evals += 1
15131523

1524+
#update diagnostics snapshot for this timestep
1525+
if self.diagnostics is not None:
1526+
self.diagnostics = Diagnostics(
1527+
time=time_dt,
1528+
loop_residuals=dict(self._loop_tracker.errors),
1529+
loop_iterations=self._loop_tracker.iterations,
1530+
solve_residuals=dict(self._solve_tracker.errors),
1531+
solve_iterations=self._solve_tracker.iterations,
1532+
step_errors=dict(self._step_tracker.errors),
1533+
)
1534+
if self._diagnostics_history is not None:
1535+
self._diagnostics_history.append(self.diagnostics)
1536+
15141537
#sample data after successful timestep
15151538
self._sample(time_dt, dt)
15161539

0 commit comments

Comments
 (0)