1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/cancellation.h"
17 
18 #include "tensorflow/core/lib/core/errors.h"
19 #include "tensorflow/core/platform/logging.h"
20 
21 namespace tensorflow {
22 
23 const CancellationToken CancellationManager::kInvalidToken = -1;
24 
CancellationManager()25 CancellationManager::CancellationManager()
26     : is_cancelling_(false),
27       is_cancelled_(false),
28       next_cancellation_token_(0) {}
29 
Reset()30 void CancellationManager::Reset() {
31   mutex_lock l(mu_);
32   is_cancelling_ = false;
33   is_cancelled_.store(false);
34 }
35 
StartCancel()36 void CancellationManager::StartCancel() {
37   gtl::FlatMap<CancellationToken, CancelCallback> callbacks_to_run;
38   {
39     mutex_lock l(mu_);
40     if (is_cancelled_.load(std::memory_order_relaxed) || is_cancelling_) {
41       return;
42     }
43     is_cancelling_ = true;
44     std::swap(callbacks_, callbacks_to_run);
45   }
46   // We call these callbacks without holding mu_, so that concurrent
47   // calls to DeregisterCallback, which can happen asynchronously, do
48   // not block. The callbacks remain valid because any concurrent call
49   // to DeregisterCallback will block until the
50   // cancelled_notification_ is notified.
51   for (auto key_and_value : callbacks_to_run) {
52     key_and_value.second();
53   }
54   {
55     mutex_lock l(mu_);
56     is_cancelling_ = false;
57     is_cancelled_.store(true, std::memory_order_release);
58   }
59   cancelled_notification_.Notify();
60 }
61 
get_cancellation_token()62 CancellationToken CancellationManager::get_cancellation_token() {
63   mutex_lock l(mu_);
64   return next_cancellation_token_++;
65 }
66 
RegisterCallback(CancellationToken token,CancelCallback callback)67 bool CancellationManager::RegisterCallback(CancellationToken token,
68                                            CancelCallback callback) {
69   mutex_lock l(mu_);
70   CHECK_LT(token, next_cancellation_token_) << "Invalid cancellation token";
71   bool should_register = !is_cancelled_ && !is_cancelling_;
72   if (should_register) {
73     std::swap(callbacks_[token], callback);
74   }
75   return should_register;
76 }
77 
DeregisterCallback(CancellationToken token)78 bool CancellationManager::DeregisterCallback(CancellationToken token) {
79   mu_.lock();
80   if (is_cancelled_) {
81     mu_.unlock();
82     return false;
83   } else if (is_cancelling_) {
84     mu_.unlock();
85     // Wait for all of the cancellation callbacks to be called. This
86     // wait ensures that the caller of DeregisterCallback does not
87     // return immediately and free objects that may be used in the
88     // execution of any currently pending callbacks in StartCancel.
89     cancelled_notification_.WaitForNotification();
90     return false;
91   } else {
92     callbacks_.erase(token);
93     mu_.unlock();
94     return true;
95   }
96 }
97 
TryDeregisterCallback(CancellationToken token)98 bool CancellationManager::TryDeregisterCallback(CancellationToken token) {
99   mutex_lock lock(mu_);
100   if (is_cancelled_ || is_cancelling_) {
101     return false;
102   } else {
103     callbacks_.erase(token);
104     return true;
105   }
106 }
107 
~CancellationManager()108 CancellationManager::~CancellationManager() {
109   if (!callbacks_.empty()) {
110     StartCancel();
111   }
112 }
113 
114 }  // end namespace tensorflow
115