1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/cancellation.h"
17 
18 #include <forward_list>
19 
20 #include "absl/memory/memory.h"
21 #include "tensorflow/core/lib/core/errors.h"
22 #include "tensorflow/core/platform/logging.h"
23 
24 namespace tensorflow {
25 
26 const CancellationToken CancellationManager::kInvalidToken = -1;
27 
CancellationManager()28 CancellationManager::CancellationManager()
29     : is_cancelling_(false),
30       is_cancelled_(false),
31       next_cancellation_token_(0) {}
32 
CancellationManager(CancellationManager * parent)33 CancellationManager::CancellationManager(CancellationManager* parent)
34     : is_cancelling_(false), next_cancellation_token_(0), parent_(parent) {
35   is_cancelled_ = parent->RegisterChild(this);
36 }
37 
StartCancel()38 void CancellationManager::StartCancel() {
39   gtl::FlatMap<CancellationToken, CancelCallback> callbacks_to_run;
40   std::forward_list<CancellationManager*> children_to_cancel;
41   Notification* cancelled_notification = nullptr;
42   {
43     mutex_lock l(mu_);
44     if (is_cancelled_.load(std::memory_order_relaxed) || is_cancelling_) {
45       return;
46     }
47     is_cancelling_ = true;
48     if (state_) {
49       std::swap(state_->callbacks, callbacks_to_run);
50 
51       // Remove all children from the list of children.
52       CancellationManager* child = state_->first_child;
53       while (child != nullptr) {
54         children_to_cancel.push_front(child);
55         child->is_removed_from_parent_ = true;
56         child = child->next_sibling_;
57       }
58       state_->first_child = nullptr;
59 
60       cancelled_notification = &state_->cancelled_notification;
61     }
62   }
63   // We call these callbacks without holding mu_, so that concurrent
64   // calls to DeregisterCallback, which can happen asynchronously, do
65   // not block. The callbacks remain valid because any concurrent call
66   // to DeregisterCallback will block until the
67   // cancelled_notification_ is notified.
68   for (auto key_and_value : callbacks_to_run) {
69     key_and_value.second();
70   }
71   for (CancellationManager* child : children_to_cancel) {
72     child->StartCancel();
73   }
74   {
75     mutex_lock l(mu_);
76     is_cancelling_ = false;
77     is_cancelled_.store(true, std::memory_order_release);
78   }
79   if (cancelled_notification) {
80     cancelled_notification->Notify();
81   }
82 }
83 
RegisterCallback(CancellationToken token,CancelCallback callback)84 bool CancellationManager::RegisterCallback(CancellationToken token,
85                                            CancelCallback callback) {
86   DCHECK_LT(token, next_cancellation_token_) << "Invalid cancellation token";
87   mutex_lock l(mu_);
88   bool should_register = !is_cancelled_ && !is_cancelling_;
89   if (should_register) {
90     if (!state_) {
91       state_ = absl::make_unique<State>();
92     }
93     std::swap(state_->callbacks[token], callback);
94   }
95   return should_register;
96 }
97 
DeregisterCallback(CancellationToken token)98 bool CancellationManager::DeregisterCallback(CancellationToken token) {
99   mu_.lock();
100   if (is_cancelled_) {
101     mu_.unlock();
102     return false;
103   } else if (is_cancelling_) {
104     Notification* cancelled_notification =
105         state_ ? &state_->cancelled_notification : nullptr;
106     mu_.unlock();
107     // Wait for all of the cancellation callbacks to be called. This
108     // wait ensures that the caller of DeregisterCallback does not
109     // return immediately and free objects that may be used in the
110     // execution of any currently pending callbacks in StartCancel.
111     if (cancelled_notification) {
112       cancelled_notification->WaitForNotification();
113     }
114     return false;
115   } else {
116     if (state_) {
117       state_->callbacks.erase(token);
118     }
119     mu_.unlock();
120     return true;
121   }
122 }
123 
RegisterChild(CancellationManager * child)124 bool CancellationManager::RegisterChild(CancellationManager* child) {
125   mutex_lock l(mu_);
126   if (is_cancelled_.load(std::memory_order_relaxed) || is_cancelling_) {
127     child->is_removed_from_parent_ = true;
128     return true;
129   }
130 
131   if (!state_) {
132     state_ = absl::make_unique<State>();
133   }
134 
135   // Push `child` onto the front of the list of children.
136   CancellationManager* current_head = state_->first_child;
137   state_->first_child = child;
138   child->prev_sibling_ = nullptr;
139   child->next_sibling_ = current_head;
140   if (current_head) {
141     current_head->prev_sibling_ = child;
142   }
143 
144   return false;
145 }
146 
DeregisterChild(CancellationManager * child)147 void CancellationManager::DeregisterChild(CancellationManager* child) {
148   DCHECK_EQ(child->parent_, this);
149   Notification* cancelled_notification = nullptr;
150   {
151     mutex_lock l(mu_);
152     if (!child->is_removed_from_parent_) {
153       // Remove the child from this manager's list of children.
154       DCHECK(state_);
155 
156       if (child->prev_sibling_ == nullptr) {
157         // The child was at the head of the list.
158         DCHECK_EQ(state_->first_child, child);
159         state_->first_child = child->next_sibling_;
160       } else {
161         child->prev_sibling_->next_sibling_ = child->next_sibling_;
162       }
163 
164       if (child->next_sibling_ != nullptr) {
165         child->next_sibling_->prev_sibling_ = child->prev_sibling_;
166       }
167 
168       child->is_removed_from_parent_ = true;
169     }
170     if (is_cancelling_) {
171       cancelled_notification = &state_->cancelled_notification;
172     }
173   }
174 
175   // Wait for an ongoing call to StartCancel() to finish. This wait ensures that
176   // the caller of DeregisterChild does not return immediately and free a child
177   // that may currently be being cancelled by StartCancel().
178   if (cancelled_notification) {
179     cancelled_notification->WaitForNotification();
180   }
181 }
182 
TryDeregisterCallback(CancellationToken token)183 bool CancellationManager::TryDeregisterCallback(CancellationToken token) {
184   mutex_lock lock(mu_);
185   if (is_cancelled_ || is_cancelling_) {
186     return false;
187   } else {
188     if (state_) {
189       state_->callbacks.erase(token);
190     }
191     return true;
192   }
193 }
194 
~CancellationManager()195 CancellationManager::~CancellationManager() {
196   if (parent_) {
197     parent_->DeregisterChild(this);
198   }
199   if (state_) {
200     StartCancel();
201   }
202 }
203 
IsCancelling()204 bool CancellationManager::IsCancelling() {
205   mutex_lock lock(mu_);
206   return is_cancelling_;
207 }
208 
RegisterCancellationCallback(CancellationManager * cancellation_manager,std::function<void ()> callback,std::function<void ()> * deregister_fn)209 Status RegisterCancellationCallback(CancellationManager* cancellation_manager,
210                                     std::function<void()> callback,
211                                     std::function<void()>* deregister_fn) {
212   if (cancellation_manager) {
213     CancellationToken token = cancellation_manager->get_cancellation_token();
214     if (!cancellation_manager->RegisterCallback(token, std::move(callback))) {
215       return errors::Cancelled("Operation was cancelled");
216     }
217     *deregister_fn = [cancellation_manager, token]() {
218       cancellation_manager->DeregisterCallback(token);
219     };
220   } else {
221     VLOG(1) << "Cancellation manager is not set. Cancellation callback will "
222                "not be registered.";
223     *deregister_fn = []() {};
224   }
225   return Status::OK();
226 }
227 
228 }  // end namespace tensorflow
229