1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/common_runtime/buf_rendezvous.h"
16 
17 #include "absl/strings/numbers.h"
18 #include "absl/strings/str_cat.h"
19 #include "absl/strings/string_view.h"
20 #include "tensorflow/core/common_runtime/device.h"
21 #include "tensorflow/core/common_runtime/device_mgr.h"
22 #include "tensorflow/core/common_runtime/process_util.h"
23 #include "tensorflow/core/framework/cancellation.h"
24 #include "tensorflow/core/lib/core/errors.h"
25 #include "tensorflow/core/lib/core/notification.h"
26 
27 namespace tensorflow {
28 namespace {
DeregisterCancellation(BufRendezvous::Hook * h)29 void DeregisterCancellation(BufRendezvous::Hook* h) {
30   if (h->cancellation_manager != nullptr) {
31     h->cancellation_manager->DeregisterCallback(h->cancellation_token);
32     h->cancellation_manager = nullptr;
33     h->cancellation_token = CancellationManager::kInvalidToken;
34   }
35 }
36 }  // namespace
37 
~BufRendezvous()38 BufRendezvous::~BufRendezvous() {
39   mutex_lock l(mu_);
40   if (!hook_table_.empty()) {
41     PurgeTable(errors::Internal("Delete called on non-empty BufRendezvous"),
42                &hook_table_);
43   }
44 }
45 
StartAbort(const Status & s)46 void BufRendezvous::StartAbort(const Status& s) {
47   CHECK(!s.ok());
48   HookTable dummy_table;
49   {
50     mutex_lock l(mu_);
51     // Use a "derived" status as the status for the rendezvous. Derived
52     // status messages are ignored when aggregating errors across devices: this
53     // allows us to prefer our original status message over any cancellation
54     // related errors.
55     status_.Update(StatusGroup::MakeDerived(s));
56     hook_table_.swap(dummy_table);
57   }
58   PurgeTable(s, &dummy_table);
59 }
60 
PurgeTable(const Status & s,HookTable * table)61 void BufRendezvous::PurgeTable(const Status& s, HookTable* table) {
62   for (auto& it : *table) {
63     Hook* h = it.second;
64     if (h->cancellation_manager != nullptr) {
65       h->cancellation_manager->TryDeregisterCallback(h->cancellation_token);
66     }
67     if (h->cons_cb != nullptr) {
68       h->cons_cb(s, nullptr);
69     }
70     if (h->prod_cb != nullptr) {
71       h->prod_cb(s);
72     }
73     delete h;
74   }
75   table->clear();
76 }
77 
DebugString() const78 string BufRendezvous::Hook::DebugString() const {
79   return absl::StrCat("[dev:", (prod_dev ? prod_dev->name() : "none"),
80                       ", ctx:", reinterpret_cast<uint64>(prod_ctx),
81                       ", val:", reinterpret_cast<uint64>(prod_value),
82                       ", pcb:", reinterpret_cast<uint64>(&prod_cb),
83                       ", ccb:", reinterpret_cast<uint64>(&cons_cb), "]");
84 }
85 
ProvideBuf(const string & key,Device * dev,DeviceContext * dev_ctx,const Tensor * v,const AllocatorAttributes & attr,const ProducerCallback & done,CancellationManager * cancellation_manager)86 void BufRendezvous::ProvideBuf(const string& key, Device* dev,
87                                DeviceContext* dev_ctx, const Tensor* v,
88                                const AllocatorAttributes& attr,
89                                const ProducerCallback& done,
90                                CancellationManager* cancellation_manager) {
91   Hook* h = nullptr;
92   Status providebuf_status;
93   do {
94     mutex_lock l(mu_);
95     if (!status_.ok()) {
96       providebuf_status = status_;
97       break;
98     } else {
99       CancellationToken cancellation_token = CancellationManager::kInvalidToken;
100       auto it = hook_table_.find(key);
101       if (it == hook_table_.end()) {
102         if (cancellation_manager != nullptr) {
103           cancellation_token = cancellation_manager->get_cancellation_token();
104         }
105         h = new Hook(cancellation_manager, cancellation_token);
106         it = hook_table_.insert(std::make_pair(key, h)).first;
107       } else {
108         if (it->second->prod_cb != nullptr) {
109           providebuf_status = errors::Internal(
110               "BufRendezvous::ProvideBuf already called for key ", key);
111           break;
112         }
113         h = it->second;
114       }
115       // Populate Hook with all of the prod values.
116       h->prod_dev = dev;
117       h->prod_ctx = dev_ctx;
118       h->prod_value = v;
119       h->prod_attr = attr;
120       h->prod_cb = done;
121       if (h->cons_cb != nullptr) {
122         // If consumer is waiting, kick off right away, removing Hook from
123         // table.
124         hook_table_.erase(it);
125       } else {
126         if (cancellation_manager != nullptr &&
127             !cancellation_manager->RegisterCallback(
128                 cancellation_token, [this, key]() { CancelHook(key); })) {
129           // Register cancellation callback with CancellationManager.  If it is
130           // already cancelled, call done immediately with cancelled status.
131           providebuf_status = errors::Cancelled(
132               "Operation was cancelled for BufRendezvous key ", key);
133           hook_table_.erase(it);
134           delete h;
135         }
136         h = nullptr;
137       }
138     }
139   } while (false);
140   if (h) {
141     DeregisterCancellation(h);
142     h->cons_cb(Status::OK(), h);
143   }
144   if (!providebuf_status.ok()) {
145     done(providebuf_status);
146   }
147 }
148 
ConsumeBuf(const string & key,const string & device_name,const uint64 device_incarnation,const ConsumerCallback & done,CancellationManager * cancellation_manager)149 void BufRendezvous::ConsumeBuf(const string& key, const string& device_name,
150                                const uint64 device_incarnation,
151                                const ConsumerCallback& done,
152                                CancellationManager* cancellation_manager) {
153   // Check the incarnation in the request matches the current device
154   // incarnation of the producer.
155   Device* device;
156   Status consumebuf_status = dev_mgr_->LookupDevice(device_name, &device);
157   if (consumebuf_status.ok() &&
158       device->attributes().incarnation() != device_incarnation) {
159     consumebuf_status = errors::FailedPrecondition(
160         "RecvBuf expects a different device incarnation: ", device_incarnation,
161         " vs. ", device->attributes().incarnation(),
162         ". Your worker job that contains the device (\"", device_name,
163         "\") was probably restarted. Check your "
164         "worker job for the reason why it was restarted.");
165   }
166   if (!consumebuf_status.ok()) {
167     done(consumebuf_status, nullptr);
168     return;
169   }
170 
171   Hook* existing_hook = nullptr;
172   do {
173     mutex_lock l(mu_);
174     if (!status_.ok()) {
175       consumebuf_status = status_;
176       break;
177     }
178     auto it = hook_table_.find(key);
179     if (it != hook_table_.end()) {
180       // Prepare to consume immediately.
181       if (it->second->cons_cb) {
182         consumebuf_status =
183             errors::Internal("Second consumer arrived for key ", key);
184         break;
185       }
186       existing_hook = it->second;
187       hook_table_.erase(it);
188       existing_hook->cons_cb = done;
189     } else {
190       // Hang consumer callback on the Hook.
191       CancellationToken cancellation_token = CancellationManager::kInvalidToken;
192       bool already_cancelled = false;
193       if (cancellation_manager != nullptr) {
194         cancellation_token = cancellation_manager->get_cancellation_token();
195         already_cancelled = !cancellation_manager->RegisterCallback(
196             cancellation_token, [this, key]() { CancelHook(key); });
197       }
198       if (already_cancelled) {
199         consumebuf_status = errors::Cancelled(
200             "Operation was cancelled for BufRendezvous key ", key);
201       } else {
202         Hook* h = new Hook(cancellation_manager, cancellation_token);
203         h->cons_cb = done;
204         it = hook_table_.insert(std::make_pair(key, h)).first;
205         return;
206       }
207     }
208   } while (false);
209   if (existing_hook) {
210     DeregisterCancellation(existing_hook);
211     existing_hook->cons_cb(Status::OK(), existing_hook);
212     return;
213   }
214   if (!consumebuf_status.ok()) {
215     done(consumebuf_status, nullptr);
216     return;
217   }
218 }
219 
CancelHook(const string & key)220 void BufRendezvous::CancelHook(const string& key) {
221   Hook* h = nullptr;
222   {
223     mutex_lock l(mu_);
224     auto it = hook_table_.find(key);
225     if (it == hook_table_.end()) return;
226     h = it->second;
227     hook_table_.erase(it);
228   }
229   if (h != nullptr) {
230     auto s = errors::Cancelled("Operation was cancelled for BufRendezvous key ",
231                                key);
232     if (h->prod_cb != nullptr) {
233       h->prod_cb(s);
234     }
235     if (h->cons_cb != nullptr) {
236       h->cons_cb(s, /*Hook=*/nullptr);
237     }
238     delete h;
239   }
240 }
241 
242 /*static*/
DoneWithHook(Hook * h)243 void BufRendezvous::DoneWithHook(Hook* h) {
244   h->prod_cb(Status::OK());
245   delete h;
246 }
247 
LogContents()248 void BufRendezvous::LogContents() {
249   mutex_lock l(mu_);
250   LOG(INFO) << strings::StrCat("BufRendezvous ",
251                                strings::Hex(reinterpret_cast<uint64>(this)),
252                                " step_id=", step_id_, " current contents:");
253   for (const auto& it : hook_table_) {
254     LOG(INFO) << it.first << ":" << it.second->DebugString();
255   }
256 }
257 
258 }  // namespace tensorflow
259