1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/common_runtime/buf_rendezvous.h"
16
17 #include "absl/strings/numbers.h"
18 #include "absl/strings/str_cat.h"
19 #include "absl/strings/string_view.h"
20 #include "tensorflow/core/common_runtime/device.h"
21 #include "tensorflow/core/common_runtime/device_mgr.h"
22 #include "tensorflow/core/common_runtime/process_util.h"
23 #include "tensorflow/core/framework/cancellation.h"
24 #include "tensorflow/core/lib/core/errors.h"
25 #include "tensorflow/core/lib/core/notification.h"
26
27 namespace tensorflow {
28 namespace {
DeregisterCancellation(BufRendezvous::Hook * h)29 void DeregisterCancellation(BufRendezvous::Hook* h) {
30 if (h->cancellation_manager != nullptr) {
31 h->cancellation_manager->DeregisterCallback(h->cancellation_token);
32 h->cancellation_manager = nullptr;
33 h->cancellation_token = CancellationManager::kInvalidToken;
34 }
35 }
36 } // namespace
37
~BufRendezvous()38 BufRendezvous::~BufRendezvous() {
39 mutex_lock l(mu_);
40 if (!hook_table_.empty()) {
41 PurgeTable(errors::Internal("Delete called on non-empty BufRendezvous"),
42 &hook_table_);
43 }
44 }
45
StartAbort(const Status & s)46 void BufRendezvous::StartAbort(const Status& s) {
47 CHECK(!s.ok());
48 HookTable dummy_table;
49 {
50 mutex_lock l(mu_);
51 // Use a "derived" status as the status for the rendezvous. Derived
52 // status messages are ignored when aggregating errors across devices: this
53 // allows us to prefer our original status message over any cancellation
54 // related errors.
55 status_.Update(StatusGroup::MakeDerived(s));
56 hook_table_.swap(dummy_table);
57 }
58 PurgeTable(s, &dummy_table);
59 }
60
PurgeTable(const Status & s,HookTable * table)61 void BufRendezvous::PurgeTable(const Status& s, HookTable* table) {
62 for (auto& it : *table) {
63 Hook* h = it.second;
64 if (h->cancellation_manager != nullptr) {
65 h->cancellation_manager->TryDeregisterCallback(h->cancellation_token);
66 }
67 if (h->cons_cb != nullptr) {
68 h->cons_cb(s, nullptr);
69 }
70 if (h->prod_cb != nullptr) {
71 h->prod_cb(s);
72 }
73 delete h;
74 }
75 table->clear();
76 }
77
DebugString() const78 string BufRendezvous::Hook::DebugString() const {
79 return absl::StrCat("[dev:", (prod_dev ? prod_dev->name() : "none"),
80 ", ctx:", reinterpret_cast<uint64>(prod_ctx),
81 ", val:", reinterpret_cast<uint64>(prod_value),
82 ", pcb:", reinterpret_cast<uint64>(&prod_cb),
83 ", ccb:", reinterpret_cast<uint64>(&cons_cb), "]");
84 }
85
ProvideBuf(const string & key,Device * dev,DeviceContext * dev_ctx,const Tensor * v,const AllocatorAttributes & attr,const ProducerCallback & done,CancellationManager * cancellation_manager)86 void BufRendezvous::ProvideBuf(const string& key, Device* dev,
87 DeviceContext* dev_ctx, const Tensor* v,
88 const AllocatorAttributes& attr,
89 const ProducerCallback& done,
90 CancellationManager* cancellation_manager) {
91 Hook* h = nullptr;
92 Status providebuf_status;
93 do {
94 mutex_lock l(mu_);
95 if (!status_.ok()) {
96 providebuf_status = status_;
97 break;
98 } else {
99 CancellationToken cancellation_token = CancellationManager::kInvalidToken;
100 auto it = hook_table_.find(key);
101 if (it == hook_table_.end()) {
102 if (cancellation_manager != nullptr) {
103 cancellation_token = cancellation_manager->get_cancellation_token();
104 }
105 h = new Hook(cancellation_manager, cancellation_token);
106 it = hook_table_.insert(std::make_pair(key, h)).first;
107 } else {
108 if (it->second->prod_cb != nullptr) {
109 providebuf_status = errors::Internal(
110 "BufRendezvous::ProvideBuf already called for key ", key);
111 break;
112 }
113 h = it->second;
114 }
115 // Populate Hook with all of the prod values.
116 h->prod_dev = dev;
117 h->prod_ctx = dev_ctx;
118 h->prod_value = v;
119 h->prod_attr = attr;
120 h->prod_cb = done;
121 if (h->cons_cb != nullptr) {
122 // If consumer is waiting, kick off right away, removing Hook from
123 // table.
124 hook_table_.erase(it);
125 } else {
126 if (cancellation_manager != nullptr &&
127 !cancellation_manager->RegisterCallback(
128 cancellation_token, [this, key]() { CancelHook(key); })) {
129 // Register cancellation callback with CancellationManager. If it is
130 // already cancelled, call done immediately with cancelled status.
131 providebuf_status = errors::Cancelled(
132 "Operation was cancelled for BufRendezvous key ", key);
133 hook_table_.erase(it);
134 delete h;
135 }
136 h = nullptr;
137 }
138 }
139 } while (false);
140 if (h) {
141 DeregisterCancellation(h);
142 h->cons_cb(Status::OK(), h);
143 }
144 if (!providebuf_status.ok()) {
145 done(providebuf_status);
146 }
147 }
148
ConsumeBuf(const string & key,const string & device_name,const uint64 device_incarnation,const ConsumerCallback & done,CancellationManager * cancellation_manager)149 void BufRendezvous::ConsumeBuf(const string& key, const string& device_name,
150 const uint64 device_incarnation,
151 const ConsumerCallback& done,
152 CancellationManager* cancellation_manager) {
153 // Check the incarnation in the request matches the current device
154 // incarnation of the producer.
155 Device* device;
156 Status consumebuf_status = dev_mgr_->LookupDevice(device_name, &device);
157 if (consumebuf_status.ok() &&
158 device->attributes().incarnation() != device_incarnation) {
159 consumebuf_status = errors::FailedPrecondition(
160 "RecvBuf expects a different device incarnation: ", device_incarnation,
161 " vs. ", device->attributes().incarnation(),
162 ". Your worker job that contains the device (\"", device_name,
163 "\") was probably restarted. Check your "
164 "worker job for the reason why it was restarted.");
165 }
166 if (!consumebuf_status.ok()) {
167 done(consumebuf_status, nullptr);
168 return;
169 }
170
171 Hook* existing_hook = nullptr;
172 do {
173 mutex_lock l(mu_);
174 if (!status_.ok()) {
175 consumebuf_status = status_;
176 break;
177 }
178 auto it = hook_table_.find(key);
179 if (it != hook_table_.end()) {
180 // Prepare to consume immediately.
181 if (it->second->cons_cb) {
182 consumebuf_status =
183 errors::Internal("Second consumer arrived for key ", key);
184 break;
185 }
186 existing_hook = it->second;
187 hook_table_.erase(it);
188 existing_hook->cons_cb = done;
189 } else {
190 // Hang consumer callback on the Hook.
191 CancellationToken cancellation_token = CancellationManager::kInvalidToken;
192 bool already_cancelled = false;
193 if (cancellation_manager != nullptr) {
194 cancellation_token = cancellation_manager->get_cancellation_token();
195 already_cancelled = !cancellation_manager->RegisterCallback(
196 cancellation_token, [this, key]() { CancelHook(key); });
197 }
198 if (already_cancelled) {
199 consumebuf_status = errors::Cancelled(
200 "Operation was cancelled for BufRendezvous key ", key);
201 } else {
202 Hook* h = new Hook(cancellation_manager, cancellation_token);
203 h->cons_cb = done;
204 it = hook_table_.insert(std::make_pair(key, h)).first;
205 return;
206 }
207 }
208 } while (false);
209 if (existing_hook) {
210 DeregisterCancellation(existing_hook);
211 existing_hook->cons_cb(Status::OK(), existing_hook);
212 return;
213 }
214 if (!consumebuf_status.ok()) {
215 done(consumebuf_status, nullptr);
216 return;
217 }
218 }
219
CancelHook(const string & key)220 void BufRendezvous::CancelHook(const string& key) {
221 Hook* h = nullptr;
222 {
223 mutex_lock l(mu_);
224 auto it = hook_table_.find(key);
225 if (it == hook_table_.end()) return;
226 h = it->second;
227 hook_table_.erase(it);
228 }
229 if (h != nullptr) {
230 auto s = errors::Cancelled("Operation was cancelled for BufRendezvous key ",
231 key);
232 if (h->prod_cb != nullptr) {
233 h->prod_cb(s);
234 }
235 if (h->cons_cb != nullptr) {
236 h->cons_cb(s, /*Hook=*/nullptr);
237 }
238 delete h;
239 }
240 }
241
242 /*static*/
DoneWithHook(Hook * h)243 void BufRendezvous::DoneWithHook(Hook* h) {
244 h->prod_cb(Status::OK());
245 delete h;
246 }
247
LogContents()248 void BufRendezvous::LogContents() {
249 mutex_lock l(mu_);
250 LOG(INFO) << strings::StrCat("BufRendezvous ",
251 strings::Hex(reinterpret_cast<uint64>(this)),
252 " step_id=", step_id_, " current contents:");
253 for (const auto& it : hook_table_) {
254 LOG(INFO) << it.first << ":" << it.second->DebugString();
255 }
256 }
257
258 } // namespace tensorflow
259