|
37 | 37 | from .utils.deprecation import deprecated |
38 | 38 | from .utils.portreference import PortReference |
39 | 39 | from .utils.progresstracker import ProgressTracker |
| 40 | +from .utils.diagnostics import Diagnostics, ConvergenceTracker, StepTracker |
40 | 41 | from .utils.logger import LoggerManager |
41 | 42 |
|
42 | 43 | from .solvers import SSPRK22, SteadyState |
@@ -165,17 +166,18 @@ class Simulation: |
165 | 166 | """ |
166 | 167 |
|
167 | 168 | def __init__( |
168 | | - self, |
169 | | - blocks=None, |
170 | | - connections=None, |
| 169 | + self, |
| 170 | + blocks=None, |
| 171 | + connections=None, |
171 | 172 | events=None, |
172 | | - dt=SIM_TIMESTEP, |
173 | | - dt_min=SIM_TIMESTEP_MIN, |
174 | | - dt_max=SIM_TIMESTEP_MAX, |
175 | | - Solver=SSPRK22, |
176 | | - tolerance_fpi=SIM_TOLERANCE_FPI, |
177 | | - iterations_max=SIM_ITERATIONS_MAX, |
| 173 | + dt=SIM_TIMESTEP, |
| 174 | + dt_min=SIM_TIMESTEP_MIN, |
| 175 | + dt_max=SIM_TIMESTEP_MAX, |
| 176 | + Solver=SSPRK22, |
| 177 | + tolerance_fpi=SIM_TOLERANCE_FPI, |
| 178 | + iterations_max=SIM_ITERATIONS_MAX, |
178 | 179 | log=LOG_ENABLE, |
| 180 | + diagnostics=False, |
179 | 181 | **solver_kwargs |
180 | 182 | ): |
181 | 183 |
|
@@ -226,6 +228,17 @@ def __init__( |
226 | 228 | #flag for setting the simulation active |
227 | 229 | self._active = True |
228 | 230 |
|
| 231 | + #convergence trackers for the three solver loops |
| 232 | + self._loop_tracker = ConvergenceTracker() |
| 233 | + self._solve_tracker = ConvergenceTracker() |
| 234 | + self._step_tracker = StepTracker() |
| 235 | + |
| 236 | + #diagnostics snapshot (None when disabled) |
| 237 | + self.diagnostics = Diagnostics() if diagnostics else None |
| 238 | + |
| 239 | + #diagnostics history (list of snapshots per timestep) |
| 240 | + self._diagnostics_history = [] if diagnostics == "history" else None |
| 241 | + |
229 | 242 | #initialize logging |
230 | 243 | logger_mgr = LoggerManager( |
231 | 244 | enabled=bool(self.log), |
@@ -815,6 +828,15 @@ def reset(self, time=0.0): |
815 | 828 | for event in self.events: |
816 | 829 | event.reset() |
817 | 830 |
|
| 831 | + #reset convergence trackers and diagnostics |
| 832 | + self._loop_tracker.reset() |
| 833 | + self._solve_tracker.reset() |
| 834 | + self._step_tracker.reset() |
| 835 | + if self.diagnostics is not None: |
| 836 | + self.diagnostics = Diagnostics() |
| 837 | + if self._diagnostics_history is not None: |
| 838 | + self._diagnostics_history.clear() |
| 839 | + |
818 | 840 | #evaluate system function |
819 | 841 | self._update(self.time) |
820 | 842 |
|
@@ -1026,19 +1048,20 @@ def _loops(self, t): |
1026 | 1048 | if connection: connection.update() |
1027 | 1049 |
|
1028 | 1050 | #step boosters of loop closing connections |
1029 | | - max_err = 0.0 |
| 1051 | + self._loop_tracker.begin_iteration() |
1030 | 1052 | for con_booster in self.boosters: |
1031 | | - err = con_booster.update() |
1032 | | - if err > max_err: |
1033 | | - max_err = err |
1034 | | - |
| 1053 | + self._loop_tracker.record(con_booster, con_booster.update()) |
| 1054 | + |
1035 | 1055 | #check convergence |
1036 | | - if max_err <= self.tolerance_fpi: |
| 1056 | + if self._loop_tracker.converged(self.tolerance_fpi): |
| 1057 | + self._loop_tracker.iterations = iteration |
1037 | 1058 | return |
1038 | 1059 |
|
1039 | | - #not converged -> error |
1040 | | - _msg = "algebraic loop not converged (iters: {}, err: {})".format( |
1041 | | - self.iterations_max, max_err |
| 1060 | + #not converged -> error with per-connection details |
| 1061 | + self._loop_tracker.iterations = self.iterations_max |
| 1062 | + details = self._loop_tracker.details(lambda b: str(b.connection)) |
| 1063 | + _msg = "algebraic loop not converged (iters: {}, err: {:.2e})\n{}".format( |
| 1064 | + self.iterations_max, self._loop_tracker.max_error, "\n".join(details) |
1042 | 1065 | ) |
1043 | 1066 | self.logger.error(_msg) |
1044 | 1067 | raise RuntimeError(_msg) |
@@ -1080,26 +1103,21 @@ def _solve(self, t, dt): |
1080 | 1103 |
|
1081 | 1104 | #evaluate system equation (this is a fixed point loop) |
1082 | 1105 | self._update(t) |
1083 | | - total_evals += 1 |
| 1106 | + total_evals += 1 |
1084 | 1107 |
|
1085 | 1108 | #advance solution of implicit solver |
1086 | | - max_error = 0.0 |
| 1109 | + self._solve_tracker.begin_iteration() |
1087 | 1110 | for block in self._blocks_dyn: |
1088 | | - |
1089 | | - #skip inactive blocks |
1090 | | - if not block: |
| 1111 | + if not block: |
1091 | 1112 | continue |
1092 | | - |
1093 | | - #advance solution (internal optimizer) |
1094 | | - error = block.solve(t, dt) |
1095 | | - if error > max_error: |
1096 | | - max_error = error |
| 1113 | + self._solve_tracker.record(block, block.solve(t, dt)) |
1097 | 1114 |
|
1098 | | - #check for convergence (only error) |
1099 | | - if max_error <= self.tolerance_fpi: |
1100 | | - return True, total_evals, it+1 |
| 1115 | + #check for convergence |
| 1116 | + if self._solve_tracker.converged(self.tolerance_fpi): |
| 1117 | + self._solve_tracker.iterations = it + 1 |
| 1118 | + return True, total_evals, it + 1 |
1101 | 1119 |
|
1102 | | - #not converged in 'self.iterations_max' steps |
| 1120 | + self._solve_tracker.iterations = self.iterations_max |
1103 | 1121 | return False, total_evals, self.iterations_max |
1104 | 1122 |
|
1105 | 1123 |
|
@@ -1156,8 +1174,9 @@ def steadystate(self, reset=False): |
1156 | 1174 |
|
1157 | 1175 | #catch non convergence |
1158 | 1176 | if not success: |
1159 | | - _msg = "STEADYSTATE -> FINISHED (success: {}, evals: {}, iters: {}, runtime: {})".format( |
1160 | | - success, evals, iters, T) |
| 1177 | + details = self._solve_tracker.details(lambda b: b.__class__.__name__) |
| 1178 | + _msg = "STEADYSTATE -> FAILED (evals: {}, iters: {}, runtime: {})\n{}".format( |
| 1179 | + evals, iters, T, "\n".join(details)) |
1161 | 1180 | self.logger.error(_msg) |
1162 | 1181 | raise RuntimeError(_msg) |
1163 | 1182 |
|
@@ -1276,32 +1295,14 @@ def _step(self, t, dt): |
1276 | 1295 | rescale factor for timestep |
1277 | 1296 | """ |
1278 | 1297 |
|
1279 | | - #initial timestep rescale and error estimate |
1280 | | - success, max_error_norm, min_scale = True, 0.0, None |
| 1298 | + self._step_tracker.reset() |
1281 | 1299 |
|
1282 | | - #step blocks and get error estimates if available |
1283 | 1300 | for block in self._blocks_dyn: |
1284 | | - |
1285 | | - #skip inactive blocks |
1286 | 1301 | if not block: continue |
1287 | | - |
1288 | | - #step the block |
1289 | 1302 | suc, err_norm, scl = block.step(t, dt) |
| 1303 | + self._step_tracker.record(block, suc, err_norm, scl) |
1290 | 1304 |
|
1291 | | - #check solver stepping success |
1292 | | - if not suc: |
1293 | | - success = False |
1294 | | - |
1295 | | - #update error tracking |
1296 | | - if err_norm > max_error_norm: |
1297 | | - max_error_norm = err_norm |
1298 | | - |
1299 | | - #track minimum relevant scale directly (avoids list allocation) |
1300 | | - if scl is not None: |
1301 | | - if min_scale is None or scl < min_scale: |
1302 | | - min_scale = scl |
1303 | | - |
1304 | | - return success, max_error_norm, min_scale if min_scale is not None else 1.0 |
| 1305 | + return self._step_tracker.success, self._step_tracker.max_error, self._step_tracker.scale |
1305 | 1306 |
|
1306 | 1307 |
|
1307 | 1308 | # timestepping ---------------------------------------------------------------- |
@@ -1469,10 +1470,19 @@ def timestep(self, dt=None, adaptive=True): |
1469 | 1470 | total_evals += evals |
1470 | 1471 | total_solver_its += solver_its |
1471 | 1472 |
|
1472 | | - #adaptive implicit: revert if solver didn't converge |
1473 | | - if not success and is_adaptive: |
1474 | | - self._revert(self.time) |
1475 | | - return False, 0.0, 0.5, total_evals + 1, total_solver_its |
| 1473 | + #implicit solver didn't converge |
| 1474 | + if not success: |
| 1475 | + details = self._solve_tracker.details(lambda b: b.__class__.__name__) |
| 1476 | + if is_adaptive: |
| 1477 | + self.logger.warning( |
| 1478 | + "implicit solver not converged, reverting step at t={:.6f}\n{}".format( |
| 1479 | + time_stage, "\n".join(details))) |
| 1480 | + self._revert(self.time) |
| 1481 | + return False, 0.0, 0.5, total_evals + 1, total_solver_its |
| 1482 | + else: |
| 1483 | + self.logger.warning( |
| 1484 | + "implicit solver not converged at t={:.6f} (iters: {})\n{}".format( |
| 1485 | + time_stage, solver_its, "\n".join(details))) |
1476 | 1486 | else: |
1477 | 1487 | #explicit: evaluate system equation |
1478 | 1488 | self._update(time_stage) |
@@ -1511,6 +1521,19 @@ def timestep(self, dt=None, adaptive=True): |
1511 | 1521 | self._update(time_dt) |
1512 | 1522 | total_evals += 1 |
1513 | 1523 |
|
| 1524 | + #update diagnostics snapshot for this timestep |
| 1525 | + if self.diagnostics is not None: |
| 1526 | + self.diagnostics = Diagnostics( |
| 1527 | + time=time_dt, |
| 1528 | + loop_residuals=dict(self._loop_tracker.errors), |
| 1529 | + loop_iterations=self._loop_tracker.iterations, |
| 1530 | + solve_residuals=dict(self._solve_tracker.errors), |
| 1531 | + solve_iterations=self._solve_tracker.iterations, |
| 1532 | + step_errors=dict(self._step_tracker.errors), |
| 1533 | + ) |
| 1534 | + if self._diagnostics_history is not None: |
| 1535 | + self._diagnostics_history.append(self.diagnostics) |
| 1536 | + |
1514 | 1537 | #sample data after successful timestep |
1515 | 1538 | self._sample(time_dt, dt) |
1516 | 1539 |
|
|
0 commit comments