@@ -407,6 +407,7 @@ Status DirectSession::Run(const RunOptions& run_options,
407407 // Create a run state and start execution.
408408 RunState run_state (input_tensor_names, output_names);
409409 run_state.rendez = new IntraProcessRendezvous (device_mgr_.get ());
410+ CancellationManager step_cancellation_manager;
410411
411412 // Send inputs.
412413 TF_RETURN_IF_ERROR (SendInputs (inputs, executors_and_keys, run_state.rendez ));
@@ -425,7 +426,7 @@ Status DirectSession::Run(const RunOptions& run_options,
425426 Executor::Args args;
426427 args.step_id = step_id_counter_.fetch_add (1 );
427428 args.rendezvous = run_state.rendez ;
428- args.cancellation_manager = cancellation_manager_ ;
429+ args.cancellation_manager = &step_cancellation_manager ;
429430 args.runner = [this , pool](Executor::Args::Closure c) {
430431 SchedClosure (pool, std::move (c));
431432 };
@@ -464,13 +465,33 @@ Status DirectSession::Run(const RunOptions& run_options,
464465 }
465466#endif // GOOGLE_CUDA
466467
468+ // Register this step with session's cancellation manager, so that
469+ // `Session::Close()` will cancel the step.
470+ CancellationToken cancellation_token =
471+ cancellation_manager_->get_cancellation_token ();
472+ bool already_cancelled = !cancellation_manager_->RegisterCallback (
473+ cancellation_token, [&step_cancellation_manager]() {
474+ step_cancellation_manager.StartCancel ();
475+ });
476+ if (already_cancelled) {
477+ // NOTE(mrry): If we don't explicitly notify
478+ // `run_state.executors_done`, the RunState destructor would
479+ // block on this notification.
480+ run_state.executors_done .Notify ();
481+ delete barrier;
482+ return errors::Cancelled (" Run call was cancelled" );
483+ }
484+
467485 for (const auto & item : executors_and_keys->items ) {
468486 item.executor ->RunAsync (args, barrier->Get ());
469487 }
470488
471- WaitForNotification (&run_state, run_options.timeout_in_ms () > 0
472- ? run_options.timeout_in_ms ()
473- : operation_timeout_in_ms_);
489+ WaitForNotification (&run_state, &step_cancellation_manager,
490+ run_options.timeout_in_ms () > 0
491+ ? run_options.timeout_in_ms ()
492+ : operation_timeout_in_ms_);
493+
494+ cancellation_manager_->DeregisterCallback (cancellation_token);
474495
475496#if GOOGLE_CUDA
476497 if (tracer) {
@@ -687,7 +708,8 @@ Status DirectSession::PRun(const string& handle, const NamedTensorList& inputs,
687708 run_state->pending_outputs .size () == 0 );
688709 }
689710 if (done) {
690- WaitForNotification (run_state, operation_timeout_in_ms_);
711+ WaitForNotification (run_state, cancellation_manager_,
712+ operation_timeout_in_ms_);
691713 partial_runs_.erase (handle);
692714 }
693715 }
@@ -1158,6 +1180,7 @@ DirectSession::RunState::~RunState() {
11581180}
11591181
11601182void DirectSession::WaitForNotification (RunState* run_state,
1183+ CancellationManager* cm,
11611184 int64 timeout_in_ms) {
11621185 if (timeout_in_ms > 0 ) {
11631186 bool notified = WaitForNotificationWithTimeout (&run_state->executors_done ,
@@ -1168,12 +1191,7 @@ void DirectSession::WaitForNotification(RunState* run_state,
11681191 run_state->status .Update (Status (error::DEADLINE_EXCEEDED,
11691192 " Timed out waiting for notification" ));
11701193 }
1171- // TODO(sherrym): This cancels all steps in the session, even ones that
1172- // have not exceeded their deadline. An alternative would be to use a
1173- // two-level cancellation manager with a Session-global one containing
1174- // several step-local ones. Probably the RunState should have its own
1175- // CancellationManager.
1176- cancellation_manager_->StartCancel ();
1194+ cm->StartCancel ();
11771195 }
11781196 } else {
11791197 run_state->executors_done .WaitForNotification ();
0 commit comments