Skip to content

Commit f46fe64

Browse files
mrry authored and tensorflower-gardener committed
Enable reuse of a DirectSession after a run times out.
Fixes tensorflow#5115. Change: 138676008
1 parent a771598 commit f46fe64

3 files changed

Lines changed: 50 additions & 12 deletions

File tree

tensorflow/core/common_runtime/direct_session.cc

Lines changed: 29 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -407,6 +407,7 @@ Status DirectSession::Run(const RunOptions& run_options,
407407
// Create a run state and start execution.
408408
RunState run_state(input_tensor_names, output_names);
409409
run_state.rendez = new IntraProcessRendezvous(device_mgr_.get());
410+
CancellationManager step_cancellation_manager;
410411

411412
// Send inputs.
412413
TF_RETURN_IF_ERROR(SendInputs(inputs, executors_and_keys, run_state.rendez));
@@ -425,7 +426,7 @@ Status DirectSession::Run(const RunOptions& run_options,
425426
Executor::Args args;
426427
args.step_id = step_id_counter_.fetch_add(1);
427428
args.rendezvous = run_state.rendez;
428-
args.cancellation_manager = cancellation_manager_;
429+
args.cancellation_manager = &step_cancellation_manager;
429430
args.runner = [this, pool](Executor::Args::Closure c) {
430431
SchedClosure(pool, std::move(c));
431432
};
@@ -464,13 +465,33 @@ Status DirectSession::Run(const RunOptions& run_options,
464465
}
465466
#endif // GOOGLE_CUDA
466467

468+
// Register this step with session's cancellation manager, so that
469+
// `Session::Close()` will cancel the step.
470+
CancellationToken cancellation_token =
471+
cancellation_manager_->get_cancellation_token();
472+
bool already_cancelled = !cancellation_manager_->RegisterCallback(
473+
cancellation_token, [&step_cancellation_manager]() {
474+
step_cancellation_manager.StartCancel();
475+
});
476+
if (already_cancelled) {
477+
// NOTE(mrry): If we don't explicitly notify
478+
// `run_state.executors_done`, the RunState destructor would
479+
// block on this notification.
480+
run_state.executors_done.Notify();
481+
delete barrier;
482+
return errors::Cancelled("Run call was cancelled");
483+
}
484+
467485
for (const auto& item : executors_and_keys->items) {
468486
item.executor->RunAsync(args, barrier->Get());
469487
}
470488

471-
WaitForNotification(&run_state, run_options.timeout_in_ms() > 0
472-
? run_options.timeout_in_ms()
473-
: operation_timeout_in_ms_);
489+
WaitForNotification(&run_state, &step_cancellation_manager,
490+
run_options.timeout_in_ms() > 0
491+
? run_options.timeout_in_ms()
492+
: operation_timeout_in_ms_);
493+
494+
cancellation_manager_->DeregisterCallback(cancellation_token);
474495

475496
#if GOOGLE_CUDA
476497
if (tracer) {
@@ -687,7 +708,8 @@ Status DirectSession::PRun(const string& handle, const NamedTensorList& inputs,
687708
run_state->pending_outputs.size() == 0);
688709
}
689710
if (done) {
690-
WaitForNotification(run_state, operation_timeout_in_ms_);
711+
WaitForNotification(run_state, cancellation_manager_,
712+
operation_timeout_in_ms_);
691713
partial_runs_.erase(handle);
692714
}
693715
}
@@ -1158,6 +1180,7 @@ DirectSession::RunState::~RunState() {
11581180
}
11591181

11601182
void DirectSession::WaitForNotification(RunState* run_state,
1183+
CancellationManager* cm,
11611184
int64 timeout_in_ms) {
11621185
if (timeout_in_ms > 0) {
11631186
bool notified = WaitForNotificationWithTimeout(&run_state->executors_done,
@@ -1168,12 +1191,7 @@ void DirectSession::WaitForNotification(RunState* run_state,
11681191
run_state->status.Update(Status(error::DEADLINE_EXCEEDED,
11691192
"Timed out waiting for notification"));
11701193
}
1171-
// TODO(sherrym): This cancels all steps in the session, even ones that
1172-
// have not exceeded their deadline. An alternative would be to use a
1173-
// two-level cancellation manager with a Session-global one containing
1174-
// several step-local ones. Probably the RunState should have its own
1175-
// CancellationManager.
1176-
cancellation_manager_->StartCancel();
1194+
cm->StartCancel();
11771195
}
11781196
} else {
11791197
run_state->executors_done.WaitForNotification();

tensorflow/core/common_runtime/direct_session.h

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -209,7 +209,10 @@ class DirectSession : public Session {
209209

210210
// Use the appropriate WaitForNotification function based on whether
211211
// operation_timeout_in_ms is greater than 0.
212-
void WaitForNotification(RunState* run_state, int64 timeout_in_ms);
212+
//
213+
// If the timeout expires, the `cm->StartCancel()` will be called.
214+
void WaitForNotification(RunState* run_state, CancellationManager* cm,
215+
int64 timeout_in_ms);
213216

214217
::tensorflow::Status CheckNotClosed() {
215218
mutex_lock l(closed_lock_);

tensorflow/python/kernel_tests/fifo_queue_test.py

Lines changed: 17 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1512,6 +1512,23 @@ def testDequeueWithTimeout(self):
15121512
"Timed out waiting for notification"):
15131513
sess.run(dequeued_t)
15141514

1515+
def testReusableAfterTimeout(self):
1516+
with self.test_session() as sess:
1517+
q = tf.FIFOQueue(10, tf.float32)
1518+
dequeued_t = q.dequeue()
1519+
enqueue_op = q.enqueue(37)
1520+
1521+
with self.assertRaisesRegexp(tf.errors.DeadlineExceededError,
1522+
"Timed out waiting for notification"):
1523+
sess.run(dequeued_t, options=tf.RunOptions(timeout_in_ms=10))
1524+
1525+
with self.assertRaisesRegexp(tf.errors.DeadlineExceededError,
1526+
"Timed out waiting for notification"):
1527+
sess.run(dequeued_t, options=tf.RunOptions(timeout_in_ms=10))
1528+
1529+
sess.run(enqueue_op)
1530+
self.assertEqual(37, sess.run(dequeued_t))
1531+
15151532

15161533
class QueueContainerTest(tf.test.TestCase):
15171534

0 commit comments

Comments
 (0)