diff --git a/Lib/bdb.py b/Lib/bdb.py index f256b56daaa..79da4bab9c9 100644 --- a/Lib/bdb.py +++ b/Lib/bdb.py @@ -2,7 +2,9 @@ import fnmatch import sys +import threading import os +import weakref from contextlib import contextmanager from inspect import CO_GENERATOR, CO_COROUTINE, CO_ASYNC_GENERATOR @@ -15,6 +17,166 @@ class BdbQuit(Exception): """Exception to give up completely.""" +E = sys.monitoring.events + +class _MonitoringTracer: + EVENT_CALLBACK_MAP = { + E.PY_START: 'call', + E.PY_RESUME: 'call', + E.PY_THROW: 'call', + E.LINE: 'line', + E.JUMP: 'jump', + E.PY_RETURN: 'return', + E.PY_YIELD: 'return', + E.PY_UNWIND: 'unwind', + E.RAISE: 'exception', + E.STOP_ITERATION: 'exception', + E.INSTRUCTION: 'opcode', + } + + GLOBAL_EVENTS = E.PY_START | E.PY_RESUME | E.PY_THROW | E.PY_UNWIND | E.RAISE + LOCAL_EVENTS = E.LINE | E.JUMP | E.PY_RETURN | E.PY_YIELD | E.STOP_ITERATION + + def __init__(self): + self._tool_id = sys.monitoring.DEBUGGER_ID + self._name = 'bdbtracer' + self._tracefunc = None + self._disable_current_event = False + self._tracing_thread = None + self._enabled = False + + def start_trace(self, tracefunc): + self._tracefunc = tracefunc + self._tracing_thread = threading.current_thread() + curr_tool = sys.monitoring.get_tool(self._tool_id) + if curr_tool is None: + sys.monitoring.use_tool_id(self._tool_id, self._name) + elif curr_tool == self._name: + sys.monitoring.clear_tool_id(self._tool_id) + else: + raise ValueError('Another debugger is using the monitoring tool') + E = sys.monitoring.events + all_events = 0 + for event, cb_name in self.EVENT_CALLBACK_MAP.items(): + callback = self.callback_wrapper(getattr(self, f'{cb_name}_callback'), event) + sys.monitoring.register_callback(self._tool_id, event, callback) + if event != E.INSTRUCTION: + all_events |= event + self.update_local_events() + sys.monitoring.set_events(self._tool_id, self.GLOBAL_EVENTS) + self._enabled = True + + def stop_trace(self): + self._enabled = False + self._tracing_thread = None + curr_tool = sys.monitoring.get_tool(self._tool_id) + if curr_tool != self._name: + return + sys.monitoring.clear_tool_id(self._tool_id) + sys.monitoring.free_tool_id(self._tool_id) + + def disable_current_event(self): + self._disable_current_event = True + + def restart_events(self): + if sys.monitoring.get_tool(self._tool_id) == self._name: + sys.monitoring.restart_events() + + def callback_wrapper(self, func, event): + import functools + + @functools.wraps(func) + def wrapper(*args): + if self._tracing_thread != threading.current_thread(): + return + try: + frame = sys._getframe().f_back + ret = func(frame, *args) + if self._enabled and frame.f_trace: + self.update_local_events() + if ( + self._disable_current_event + and event not in (E.PY_THROW, E.PY_UNWIND, E.RAISE) + ): + return sys.monitoring.DISABLE + else: + return ret + except BaseException: + self.stop_trace() + sys._getframe().f_back.f_trace = None + raise + finally: + self._disable_current_event = False + + return wrapper + + def call_callback(self, frame, code, *args): + local_tracefunc = self._tracefunc(frame, 'call', None) + if local_tracefunc is not None: + frame.f_trace = local_tracefunc + if self._enabled: + sys.monitoring.set_local_events(self._tool_id, code, self.LOCAL_EVENTS) + + def return_callback(self, frame, code, offset, retval): + if frame.f_trace: + frame.f_trace(frame, 'return', retval) + + def unwind_callback(self, frame, code, *args): + if frame.f_trace: + frame.f_trace(frame, 'return', None) + + def line_callback(self, frame, code, *args): + if frame.f_trace and frame.f_trace_lines: + frame.f_trace(frame, 'line', None) + + def jump_callback(self, frame, code, inst_offset, dest_offset): + if dest_offset > inst_offset: + return sys.monitoring.DISABLE + inst_lineno = self._get_lineno(code, inst_offset) + dest_lineno = self._get_lineno(code, dest_offset) + if inst_lineno != dest_lineno: + return sys.monitoring.DISABLE + if frame.f_trace and frame.f_trace_lines: + frame.f_trace(frame, 'line', None) + + def exception_callback(self, frame, code, offset, exc): + if frame.f_trace: + if exc.__traceback__ and hasattr(exc.__traceback__, 'tb_frame'): + tb = exc.__traceback__ + while tb: + if tb.tb_frame.f_locals.get('self') is self: + return + tb = tb.tb_next + frame.f_trace(frame, 'exception', (type(exc), exc, exc.__traceback__)) + + def opcode_callback(self, frame, code, offset): + if frame.f_trace and frame.f_trace_opcodes: + frame.f_trace(frame, 'opcode', None) + + def update_local_events(self, frame=None): + if sys.monitoring.get_tool(self._tool_id) != self._name: + return + if frame is None: + frame = sys._getframe().f_back + while frame is not None: + if frame.f_trace is not None: + if frame.f_trace_opcodes: + events = self.LOCAL_EVENTS | E.INSTRUCTION + else: + events = self.LOCAL_EVENTS + sys.monitoring.set_local_events(self._tool_id, frame.f_code, events) + frame = frame.f_back + + def _get_lineno(self, code, offset): + import dis + last_lineno = None + for start, lineno in dis.findlinestarts(code): + if offset < start: + return last_lineno + last_lineno = lineno + return last_lineno + + class Bdb: """Generic Python debugger base class. @@ -29,7 +191,7 @@ class Bdb: is determined by the __name__ in the frame globals. """ - def __init__(self, skip=None): + def __init__(self, skip=None, backend='settrace'): self.skip = set(skip) if skip else None self.breaks = {} self.fncache = {} @@ -39,6 +201,14 @@ def __init__(self, skip=None): self.enterframe = None self.cmdframe = None self.cmdlineno = None + self.code_linenos = weakref.WeakKeyDictionary() + self.backend = backend + if backend == 'monitoring': + self.monitoring_tracer = _MonitoringTracer() + elif backend == 'settrace': + self.monitoring_tracer = None + else: + raise ValueError(f"Invalid backend '{backend}'") self._load_breaks() @@ -59,6 +229,18 @@ def canonic(self, filename): self.fncache[filename] = canonic return canonic + def start_trace(self): + if self.monitoring_tracer: + self.monitoring_tracer.start_trace(self.trace_dispatch) + else: + sys.settrace(self.trace_dispatch) + + def stop_trace(self): + if self.monitoring_tracer: + self.monitoring_tracer.stop_trace() + else: + sys.settrace(None) + def reset(self): """Set values of attributes as ready to start debugging.""" import linecache @@ -133,7 +315,10 @@ def dispatch_line(self, frame): self.cmdframe == frame and self.cmdlineno == frame.f_lineno ): self.user_line(frame) + self.restart_events() if self.quitting: raise BdbQuit + elif not self.get_break(frame.f_code.co_filename, frame.f_lineno): + self.disable_current_event() return self.trace_dispatch def dispatch_call(self, frame, arg): @@ -149,12 +334,18 @@ def dispatch_call(self, frame, arg): self.botframe = frame.f_back # (CT) Note that this may also be None! return self.trace_dispatch if not (self.stop_here(frame) or self.break_anywhere(frame)): - # No need to trace this function + # We already know there's no breakpoint in this function + # If it's a next/until/return command, we don't need any CALL event + # and we don't need to set the f_trace on any new frame. + # If it's a step command, it must either hit stop_here, or skip the + # whole module. Either way, we don't need the CALL event here. + self.disable_current_event() return # None # Ignore call events in generator except when stepping. if self.stopframe and frame.f_code.co_flags & GENERATOR_AND_COROUTINE_FLAGS: return self.trace_dispatch self.user_call(frame, arg) + self.restart_events() if self.quitting: raise BdbQuit return self.trace_dispatch @@ -168,10 +359,14 @@ def dispatch_return(self, frame, arg): if self.stop_here(frame) or frame == self.returnframe: # Ignore return events in generator except when stepping. if self.stopframe and frame.f_code.co_flags & GENERATOR_AND_COROUTINE_FLAGS: + # It's possible to trigger a StopIteration exception in + # the caller so we must set the trace function in the caller + self._set_caller_tracefunc(frame) return self.trace_dispatch try: self.frame_returning = frame self.user_return(frame, arg) + self.restart_events() finally: self.frame_returning = None if self.quitting: raise BdbQuit @@ -199,6 +394,7 @@ def dispatch_exception(self, frame, arg): if not (frame.f_code.co_flags & GENERATOR_AND_COROUTINE_FLAGS and arg[0] is StopIteration and arg[2] is None): self.user_exception(frame, arg) + self.restart_events() if self.quitting: raise BdbQuit # Stop at the StopIteration or GeneratorExit exception when the user # has set stopframe in a generator by issuing a return command, or a @@ -208,6 +404,7 @@ def dispatch_exception(self, frame, arg): and self.stopframe.f_code.co_flags & GENERATOR_AND_COROUTINE_FLAGS and arg[0] in (StopIteration, GeneratorExit)): self.user_exception(frame, arg) + self.restart_events() if self.quitting: raise BdbQuit return self.trace_dispatch @@ -217,10 +414,14 @@ def dispatch_opcode(self, frame, arg): If the debugger stops on the current opcode, invoke self.user_opcode(). Raise BdbQuit if self.quitting is set. Return self.trace_dispatch to continue tracing in this scope. + + Opcode event will always trigger the user callback. For now the only + opcode event is from an inline set_trace() and we want to stop there + unconditionally. """ - if self.stop_here(frame) or self.break_here(frame): - self.user_opcode(frame) - if self.quitting: raise BdbQuit + self.user_opcode(frame) + self.restart_events() + if self.quitting: raise BdbQuit return self.trace_dispatch # Normally derived classes don't override the following @@ -286,9 +487,25 @@ def do_clear(self, arg): raise NotImplementedError("subclass of bdb must implement do_clear()") def break_anywhere(self, frame): - """Return True if there is any breakpoint for frame's filename. + """Return True if there is any breakpoint in that frame + """ + filename = self.canonic(frame.f_code.co_filename) + if filename not in self.breaks: + return False + for lineno in self.breaks[filename]: + if self._lineno_in_frame(lineno, frame): + return True + return False + + def _lineno_in_frame(self, lineno, frame): + """Return True if the line number is in the frame's code object. """ - return self.canonic(frame.f_code.co_filename) in self.breaks + code = frame.f_code + if lineno < code.co_firstlineno: + return False + if code not in self.code_linenos: + self.code_linenos[code] = set(lineno for _, _, lineno in code.co_lines()) + return lineno in self.code_linenos[code] # Derived classes should override the user_* methods # to gain control. @@ -322,6 +539,8 @@ def _set_trace_opcodes(self, trace_opcodes): if frame is self.botframe: break frame = frame.f_back + if self.monitoring_tracer: + self.monitoring_tracer.update_local_events() def _set_stopinfo(self, stopframe, returnframe, stoplineno=0, opcode=False, cmdframe=None, cmdlineno=None): @@ -381,7 +600,7 @@ def set_next(self, frame): def set_return(self, frame): """Stop when returning from the given frame.""" if frame.f_code.co_flags & GENERATOR_AND_COROUTINE_FLAGS: - self._set_stopinfo(frame, None, -1) + self._set_stopinfo(frame, frame, -1) else: self._set_stopinfo(frame.f_back, frame) @@ -390,6 +609,7 @@ def set_trace(self, frame=None): If frame is not specified, debugging starts from caller's frame. """ + self.stop_trace() if frame is None: frame = sys._getframe().f_back self.reset() @@ -402,7 +622,8 @@ def set_trace(self, frame=None): frame.f_trace_lines = True frame = frame.f_back self.set_stepinstr() - sys.settrace(self.trace_dispatch) + self.enterframe = None + self.start_trace() def set_continue(self): """Stop only at breakpoints or when finished. @@ -413,13 +634,15 @@ def set_continue(self): self._set_stopinfo(self.botframe, None, -1) if not self.breaks: # no breakpoints; run without debugger overhead - sys.settrace(None) + self.stop_trace() frame = sys._getframe().f_back while frame and frame is not self.botframe: del frame.f_trace frame = frame.f_back for frame, (trace_lines, trace_opcodes) in self.frame_trace_lines_opcodes.items(): frame.f_trace_lines, frame.f_trace_opcodes = trace_lines, trace_opcodes + if self.backend == 'monitoring': + self.monitoring_tracer.update_local_events() self.frame_trace_lines_opcodes = {} def set_quit(self): @@ -430,7 +653,7 @@ def set_quit(self): self.stopframe = self.botframe self.returnframe = None self.quitting = True - sys.settrace(None) + self.stop_trace() # Derived classes and clients can call the following methods # to manipulate breakpoints. These methods return an @@ -658,6 +881,16 @@ def format_stack_entry(self, frame_lineno, lprefix=': '): s += f'{lprefix}Warning: lineno is None' return s + def disable_current_event(self): + """Disable the current event.""" + if self.backend == 'monitoring': + self.monitoring_tracer.disable_current_event() + + def restart_events(self): + """Restart all events.""" + if self.backend == 'monitoring': + self.monitoring_tracer.restart_events() + # The following methods can be called by clients to use # a debugger to debug a statement or an expression. # Both can be given as a string, or a code object. @@ -675,14 +908,14 @@ def run(self, cmd, globals=None, locals=None): self.reset() if isinstance(cmd, str): cmd = compile(cmd, "", "exec") - sys.settrace(self.trace_dispatch) + self.start_trace() try: exec(cmd, globals, locals) except BdbQuit: pass finally: self.quitting = True - sys.settrace(None) + self.stop_trace() def runeval(self, expr, globals=None, locals=None): """Debug an expression executed via the eval() function. @@ -695,14 +928,14 @@ def runeval(self, expr, globals=None, locals=None): if locals is None: locals = globals self.reset() - sys.settrace(self.trace_dispatch) + self.start_trace() try: return eval(expr, globals, locals) except BdbQuit: pass finally: self.quitting = True - sys.settrace(None) + self.stop_trace() def runctx(self, cmd, globals, locals): """For backwards-compatibility. Defers to run().""" @@ -717,7 +950,7 @@ def runcall(self, func, /, *args, **kwds): Return the result of the function call. """ self.reset() - sys.settrace(self.trace_dispatch) + self.start_trace() res = None try: res = func(*args, **kwds) @@ -725,7 +958,7 @@ def runcall(self, func, /, *args, **kwds): pass finally: self.quitting = True - sys.settrace(None) + self.stop_trace() return res diff --git a/Lib/test/test_bdb.py b/Lib/test/test_bdb.py index eb1a7710c5a..f1077d91fdd 100644 --- a/Lib/test/test_bdb.py +++ b/Lib/test/test_bdb.py @@ -614,7 +614,7 @@ def test_step_next_on_last_statement(self): with TracerRun(self) as tracer: tracer.runcall(tfunc_main) - @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: All paired tuples have not been processed, the last one was number 1 [('next',), ('quit',)] + @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: All paired tuples have not been processed, the last one was number 1 [('next',), ('quit',)] def test_stepinstr(self): self.expect_set = [ ('line', 2, 'tfunc_main'), ('stepinstr', ), @@ -762,7 +762,6 @@ def test_skip_with_no_name_module(self): bdb = Bdb(skip=['anything*']) self.assertIs(bdb.is_skipped_module(None), False) - @unittest.expectedFailure # TODO: RUSTPYTHON; Error in atexit._run_exitfuncs def test_down(self): # Check that set_down() raises BdbError at the newest frame. self.expect_set = [ @@ -784,7 +783,6 @@ def test_up(self): class BreakpointTestCase(BaseTestCase): """Test the breakpoint set method.""" - @unittest.expectedFailure # TODO: RUSTPYTHON; Error in atexit._run_exitfuncs def test_bp_on_non_existent_module(self): self.expect_set = [ ('line', 2, 'tfunc_import'), ('break', ('/non/existent/module.py', 1)) @@ -792,7 +790,6 @@ def test_bp_on_non_existent_module(self): with TracerRun(self) as tracer: self.assertRaises(BdbError, tracer.runcall, tfunc_import) - @unittest.expectedFailure # TODO: RUSTPYTHON; Error in atexit._run_exitfuncs def test_bp_after_last_statement(self): code = """ def main(): @@ -969,7 +966,6 @@ def main(): with TracerRun(self) as tracer: tracer.runcall(tfunc_import) - @unittest.expectedFailure # TODO: RUSTPYTHON; Error in atexit._run_exitfuncs def test_clear_at_no_bp(self): self.expect_set = [ ('line', 2, 'tfunc_import'), ('clear', (__file__, 1)) @@ -1051,8 +1047,9 @@ def main(): ('return', 1, ''), ('quit', ), ] import test_module_for_bdb + ns = {'test_module_for_bdb': test_module_for_bdb} with TracerRun(self) as tracer: - tracer.runeval('test_module_for_bdb.main()', globals(), locals()) + tracer.runeval('test_module_for_bdb.main()', ns, ns) class IssuesTestCase(BaseTestCase): """Test fixed bdb issues.""" @@ -1087,7 +1084,7 @@ def func(): with TracerRun(self) as tracer: tracer.runcall(tfunc_import) - @unittest.expectedFailure # TODO: RUSTPYTHON; Error in atexit._run_exitfuncs + @unittest.expectedFailure # TODO: RUSTPYTHON; Error in atexit._run_exitfuncs def test_next_until_return_in_generator(self): # Issue #16596. # Check that set_next(), set_until() and set_return() do not treat the @@ -1129,7 +1126,7 @@ def main(): with TracerRun(self) as tracer: tracer.runcall(tfunc_import) - @unittest.expectedFailure # TODO: RUSTPYTHON; Error in atexit._run_exitfuncs + @unittest.expectedFailure # TODO: RUSTPYTHON; Error in atexit._run_exitfuncs def test_next_command_in_generator_for_loop(self): # Issue #16596. code = """ @@ -1161,7 +1158,7 @@ def main(): with TracerRun(self) as tracer: tracer.runcall(tfunc_import) - @unittest.expectedFailure # TODO: RUSTPYTHON; Error in atexit._run_exitfuncs + @unittest.expectedFailure # TODO: RUSTPYTHON; Error in atexit._run_exitfuncs def test_next_command_in_generator_with_subiterator(self): # Issue #16596. code = """ @@ -1193,7 +1190,7 @@ def main(): with TracerRun(self) as tracer: tracer.runcall(tfunc_import) - @unittest.expectedFailure # TODO: RUSTPYTHON; Error in atexit._run_exitfuncs + @unittest.expectedFailure # TODO: RUSTPYTHON; Error in atexit._run_exitfuncs def test_return_command_in_generator_with_subiterator(self): # Issue #16596. code = """ diff --git a/Lib/test/test_sys_settrace.py b/Lib/test/test_sys_settrace.py index 65193d4f8c8..a98b4d22760 100644 --- a/Lib/test/test_sys_settrace.py +++ b/Lib/test/test_sys_settrace.py @@ -1218,8 +1218,6 @@ def test_return(self): def test_exception(self): self.run_test_for_event('exception') - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_trash_stack(self): def f(): for i in range(5): @@ -1785,15 +1783,11 @@ async def test_jump_over_async_for_block_before_else(output): # The second set of 'jump' tests are for things that are not allowed: - # TODO: RUSTPYTHON - @unittest.expectedFailure @jump_test(2, 3, [1], (ValueError, 'after')) def test_no_jump_too_far_forwards(output): output.append(1) output.append(2) - # TODO: RUSTPYTHON - @unittest.expectedFailure @jump_test(2, -2, [1], (ValueError, 'before')) def test_no_jump_too_far_backwards(output): output.append(1) @@ -1840,8 +1834,6 @@ def test_no_jump_to_except_4(output): output.append(4) raise e - # TODO: RUSTPYTHON - @unittest.expectedFailure @jump_test(1, 3, [], (ValueError, 'into')) def test_no_jump_forwards_into_for_block(output): output.append(1) @@ -1857,8 +1849,6 @@ async def test_no_jump_forwards_into_async_for_block(output): output.append(3) pass - # TODO: RUSTPYTHON - @unittest.expectedFailure @jump_test(3, 2, [2, 2], (ValueError, 'into')) def test_no_jump_backwards_into_for_block(output): for i in 1, 2: @@ -2020,8 +2010,6 @@ def test_no_jump_into_bare_except_block_from_try_block(output): raise output.append(8) - # TODO: RUSTPYTHON - @unittest.expectedFailure @jump_test(3, 6, [2], (ValueError, "into an 'except'")) def test_no_jump_into_qualified_except_block_from_try_block(output): try: @@ -2087,8 +2075,6 @@ def test_no_jump_over_return_out_of_finally_block(output): return output.append(7) - # TODO: RUSTPYTHON - @unittest.expectedFailure @jump_test(7, 4, [1, 6], (ValueError, 'into')) def test_no_jump_into_for_block_before_else(output): output.append(1) diff --git a/crates/vm/src/protocol/callable.rs b/crates/vm/src/protocol/callable.rs index 9a621dee4f8..316ed36dd19 100644 --- a/crates/vm/src/protocol/callable.rs +++ b/crates/vm/src/protocol/callable.rs @@ -96,37 +96,67 @@ impl core::fmt::Display for TraceEvent { impl VirtualMachine { /// Call registered trace function. + /// + /// Returns the trace function's return value: + /// - `Some(obj)` if the trace function returned a non-None value + /// - `None` if it returned Python None or no trace function was active + /// + /// In CPython's trace protocol: + /// - For 'call' events: the return value determines the per-frame `f_trace` + /// - For 'line'/'return' events: the return value can update `f_trace` #[inline] - pub(crate) fn trace_event(&self, event: TraceEvent, arg: Option) -> PyResult<()> { + pub(crate) fn trace_event( + &self, + event: TraceEvent, + arg: Option, + ) -> PyResult> { if self.use_tracing.get() { self._trace_event_inner(event, arg) } else { - Ok(()) + Ok(None) } } - fn _trace_event_inner(&self, event: TraceEvent, arg: Option) -> PyResult<()> { + fn _trace_event_inner( + &self, + event: TraceEvent, + arg: Option, + ) -> PyResult> { let trace_func = self.trace_func.borrow().to_owned(); let profile_func = self.profile_func.borrow().to_owned(); if self.is_none(&trace_func) && self.is_none(&profile_func) { - return Ok(()); + return Ok(None); } let Some(frame_ref) = self.current_frame() else { - return Ok(()); + return Ok(None); }; let frame: PyObjectRef = frame_ref.into(); let event = self.ctx.new_str(event.to_string()).into(); let args = vec![frame, event, arg.unwrap_or_else(|| self.ctx.none())]; + let mut trace_result = None; + // temporarily disable tracing, during the call to the // tracing function itself. if !self.is_none(&trace_func) { self.use_tracing.set(false); let res = trace_func.call(args.clone(), self); self.use_tracing.set(true); - if res.is_err() { - *self.trace_func.borrow_mut() = self.ctx.none(); + match res { + Ok(result) => { + if !self.is_none(&result) { + trace_result = Some(result); + } + } + Err(e) => { + // trace_trampoline behavior: clear per-frame f_trace + // and propagate the error. + if let Some(frame_ref) = self.current_frame() { + *frame_ref.trace.lock() = self.ctx.none(); + } + return Err(e); + } } } @@ -138,6 +168,6 @@ impl VirtualMachine { *self.profile_func.borrow_mut() = self.ctx.none(); } } - Ok(()) + Ok(trace_result) } } diff --git a/crates/vm/src/vm/mod.rs b/crates/vm/src/vm/mod.rs index ead41170229..3ee421e04ff 100644 --- a/crates/vm/src/vm/mod.rs +++ b/crates/vm/src/vm/mod.rs @@ -1084,19 +1084,25 @@ impl VirtualMachine { // Fire 'call' trace event after pushing frame // (current_frame() now returns the callee's frame) + // + // trace_dispatch protocol (matching CPython's trace_trampoline): + // - For 'call' events, the global trace function is called. + // If it returns non-None, set f_trace to that value (trace this frame). + // If it returns None, leave f_trace unset (skip tracing this frame). + // - For 'return' events, fire if this frame has f_trace set OR if + // a profile function is active (profiling is independent of f_trace). match self.trace_event(TraceEvent::Call, None) { - Ok(()) => { - // Set per-frame trace function so line events fire for this frame. - // Frames entered before sys.settrace() keep trace=None and skip line events. - if self.use_tracing.get() { - let trace_func = self.trace_func.borrow().clone(); - if !self.is_none(&trace_func) { - *frame.trace.lock() = trace_func; - } + Ok(trace_result) => { + if let Some(local_trace) = trace_result { + *frame.trace.lock() = local_trace; } let result = f(frame.clone()); - // Fire 'return' trace event on success - if result.is_ok() { + // Fire 'return' event if frame is being traced or profiled + if result.is_ok() + && self.use_tracing.get() + && (!self.is_none(&frame.trace.lock()) + || !self.is_none(&self.profile_func.borrow())) + { let _ = self.trace_event(TraceEvent::Return, None); } result @@ -1155,9 +1161,18 @@ impl VirtualMachine { use crate::protocol::TraceEvent; match self.trace_event(TraceEvent::Call, None) { - Ok(()) => { + Ok(trace_result) => { + // Update per-frame trace if trace function returned a new local trace + if let Some(local_trace) = trace_result { + *frame.trace.lock() = local_trace; + } let result = f(frame); - if result.is_ok() { + // Fire 'return' event if frame is being traced or profiled + if result.is_ok() + && self.use_tracing.get() + && (!self.is_none(&frame.trace.lock()) + || !self.is_none(&self.profile_func.borrow())) + { let _ = self.trace_event(TraceEvent::Return, None); } result