1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/core/framework/rendezvous.h"

#include <deque>
#include <functional>
#include <memory>
#include <utility>
#include <vector>

#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
33
34 namespace tensorflow {
35
operator =(const ParsedKey & b)36 Rendezvous::ParsedKey& Rendezvous::ParsedKey::operator=(const ParsedKey& b) {
37 const char* b_base = b.buf_.data();
38 buf_ = b.buf_;
39 src_device = StringPiece(buf_.data() + (b.src_device.data() - b_base),
40 b.src_device.size());
41 src = b.src;
42 src_incarnation = b.src_incarnation;
43 dst_device = StringPiece(buf_.data() + (b.dst_device.data() - b_base),
44 b.dst_device.size());
45 dst = b.dst;
46 edge_name = StringPiece(buf_.data() + (b.edge_name.data() - b_base),
47 b.edge_name.size());
48 return *this;
49 }
50
51 /* static */
CreateKey(const string & src_device,uint64 src_incarnation,const string & dst_device,const string & name,const FrameAndIter & frame_iter)52 string Rendezvous::CreateKey(const string& src_device, uint64 src_incarnation,
53 const string& dst_device, const string& name,
54 const FrameAndIter& frame_iter) {
55 // NOTE: ';' is not used in the device name's job name.
56 //
57 // We include both sender and receiver in the key to facilitate
58 // debugging. For correctness, we only need to encode the receiver.
59 //
60 // "src_incarnation" is used to distinguish a worker when it
61 // restarts.
62 char buf[strings::kFastToBufferSize];
63 return strings::StrCat(
64 src_device, ";", strings::Uint64ToHexString(src_incarnation, buf), ";",
65 dst_device, ";", name, ";", frame_iter.frame_id, ":", frame_iter.iter_id);
66 }
67
68 // Return the prefix of "*s" up to the next occurrence of "delim", or
69 // the whole remaining string if "delim" is not found. "*s" is advanced
70 // past the string returned plus the delimiter (if found).
ConsumeNextPart(StringPiece * s,char delim)71 static StringPiece ConsumeNextPart(StringPiece* s, char delim) {
72 for (size_t offset = 0; offset < s->size(); offset++) {
73 if ((*s)[offset] == delim) {
74 StringPiece result(s->data(), offset);
75 s->remove_prefix(offset + 1); // +1: remove delim, as well
76 return result;
77 }
78 }
79 // No delimiter found: return rest of string
80 StringPiece result(s->data(), s->size());
81 s->remove_prefix(s->size());
82 return result;
83 }
84
85 /* static */
ParseKey(StringPiece key,ParsedKey * out)86 Status Rendezvous::ParseKey(StringPiece key, ParsedKey* out) {
87 if (key.data() == out->buf_.data()) {
88 // Caller used our buf_ string directly, so we don't need to copy. (The
89 // SendOp and RecvOp implementations do this, for example).
90 DCHECK_EQ(key.size(), out->buf_.size());
91 } else {
92 // Make a copy that our StringPieces can point at a copy that will persist
93 // for the lifetime of the ParsedKey object.
94 out->buf_.assign(key.data(), key.size());
95 }
96 StringPiece s(out->buf_);
97 StringPiece parts[5];
98 for (int i = 0; i < 5; i++) {
99 parts[i] = ConsumeNextPart(&s, ';');
100 }
101 if (s.empty() && // Consumed the whole string
102 !parts[4].empty() && // Exactly five parts
103 DeviceNameUtils::ParseFullName(parts[0], &out->src) &&
104 strings::HexStringToUint64(parts[1], &out->src_incarnation) &&
105 DeviceNameUtils::ParseFullName(parts[2], &out->dst) &&
106 !parts[3].empty()) {
107 out->src_device = StringPiece(parts[0].data(), parts[0].size());
108 out->dst_device = StringPiece(parts[2].data(), parts[2].size());
109 out->edge_name = StringPiece(parts[3].data(), parts[3].size());
110 return Status::OK();
111 }
112 return errors::InvalidArgument("Invalid rendezvous key: ", key);
113 }
114
~Rendezvous()115 Rendezvous::~Rendezvous() {}
116
Recv(const ParsedKey & key,const Args & recv_args,Tensor * val,bool * is_dead,int64 timeout_ms)117 Status Rendezvous::Recv(const ParsedKey& key, const Args& recv_args,
118 Tensor* val, bool* is_dead, int64 timeout_ms) {
119 Status ret;
120 Notification n;
121 RecvAsync(key, recv_args,
122 [&ret, &n, val, is_dead](const Status& s, const Args& send_args,
123 const Args& recv_args, const Tensor& v,
124 const bool dead) {
125 ret = s;
126 *val = v;
127 *is_dead = dead;
128 n.Notify();
129 });
130 if (timeout_ms > 0) {
131 int64 timeout_us = timeout_ms * 1000;
132 bool notified = WaitForNotificationWithTimeout(&n, timeout_us);
133 if (!notified) {
134 return Status(error::DEADLINE_EXCEEDED,
135 "Timed out waiting for notification");
136 }
137 } else {
138 n.WaitForNotification();
139 }
140 return ret;
141 }
142
Recv(const ParsedKey & key,const Args & args,Tensor * val,bool * is_dead)143 Status Rendezvous::Recv(const ParsedKey& key, const Args& args, Tensor* val,
144 bool* is_dead) {
145 const int64 no_timeout = 0;
146 return Recv(key, args, val, is_dead, no_timeout);
147 }
148
149 class LocalRendezvousImpl : public Rendezvous {
150 public:
LocalRendezvousImpl()151 explicit LocalRendezvousImpl() {}
152
Send(const ParsedKey & key,const Args & send_args,const Tensor & val,const bool is_dead)153 Status Send(const ParsedKey& key, const Args& send_args, const Tensor& val,
154 const bool is_dead) override {
155 uint64 key_hash = KeyHash(key.FullKey());
156 VLOG(2) << "Send " << this << " " << key_hash << " " << key.FullKey();
157
158 mu_.lock();
159 if (!status_.ok()) {
160 // Rendezvous has been aborted.
161 Status s = status_;
162 mu_.unlock();
163 return s;
164 }
165
166 ItemQueue* queue = &table_[key_hash];
167 if (queue->empty() || queue->front()->IsSendValue()) {
168 // There is no waiter for this message. Append the message
169 // into the queue. The waiter will pick it up when arrives.
170 // Only send-related fields need to be filled.
171 Item* item = new Item;
172 item->value = val;
173 item->is_dead = is_dead;
174 item->send_args = send_args;
175 if (item->send_args.device_context) {
176 item->send_args.device_context->Ref();
177 }
178 queue->push_back(item);
179 mu_.unlock();
180 return Status::OK();
181 }
182
183 // There is an earliest waiter to consume this message.
184 Item* item = queue->front();
185 queue->pop_front();
186 mu_.unlock();
187
188 // Notify the waiter by invoking its done closure, outside the
189 // lock.
190 DCHECK(!item->IsSendValue());
191 item->waiter(Status::OK(), send_args, item->recv_args, val, is_dead);
192 delete item;
193 return Status::OK();
194 }
195
RecvAsync(const ParsedKey & key,const Args & recv_args,DoneCallback done)196 void RecvAsync(const ParsedKey& key, const Args& recv_args,
197 DoneCallback done) override {
198 uint64 key_hash = KeyHash(key.FullKey());
199 VLOG(2) << "Recv " << this << " " << key_hash << " " << key.FullKey();
200
201 mu_.lock();
202 if (!status_.ok()) {
203 // Rendezvous has been aborted.
204 Status s = status_;
205 mu_.unlock();
206 done(s, Args(), recv_args, Tensor(), false);
207 return;
208 }
209
210 ItemQueue* queue = &table_[key_hash];
211 if (queue->empty() || !queue->front()->IsSendValue()) {
212 // There is no message to pick up.
213 // Only recv-related fields need to be filled.
214 Item* item = new Item;
215 item->waiter = std::move(done);
216 item->recv_args = recv_args;
217 if (item->recv_args.device_context) {
218 item->recv_args.device_context->Ref();
219 }
220 queue->push_back(item);
221 mu_.unlock();
222 return;
223 }
224
225 // A message has already arrived and is queued in the table under
226 // this key. Consumes the message and invokes the done closure.
227 Item* item = queue->front();
228 queue->pop_front();
229 mu_.unlock();
230
231 // Invokes the done() by invoking its done closure, outside scope
232 // of the table lock.
233 DCHECK(item->IsSendValue());
234 done(Status::OK(), item->send_args, recv_args, item->value, item->is_dead);
235 delete item;
236 }
237
StartAbort(const Status & status)238 void StartAbort(const Status& status) override {
239 CHECK(!status.ok());
240 Table table;
241 {
242 mutex_lock l(mu_);
243 status_.Update(status);
244 table_.swap(table);
245 }
246 for (auto& p : table) {
247 for (Item* item : p.second) {
248 if (!item->IsSendValue()) {
249 item->waiter(status, Args(), Args(), Tensor(), false);
250 }
251 delete item;
252 }
253 }
254 }
255
256 private:
257 typedef LocalRendezvousImpl ME;
258
259 struct Item {
260 DoneCallback waiter = nullptr;
261 Tensor value;
262 bool is_dead = false;
263 Args send_args;
264 Args recv_args;
265
~Itemtensorflow::LocalRendezvousImpl::Item266 ~Item() {
267 if (send_args.device_context) {
268 send_args.device_context->Unref();
269 }
270 if (recv_args.device_context) {
271 recv_args.device_context->Unref();
272 }
273 }
274
275 // Returns true iff this item represents a value being sent.
IsSendValuetensorflow::LocalRendezvousImpl::Item276 bool IsSendValue() const { return this->waiter == nullptr; }
277 };
278
279 // We key the hash table by KeyHash of the Rendezvous::CreateKey string
KeyHash(const StringPiece & k)280 static uint64 KeyHash(const StringPiece& k) {
281 return Hash64(k.data(), k.size());
282 }
283
284 // By invariant, the item queue under each key is of the form
285 // [item.IsSendValue()]* meaning each item is a sent message.
286 // or
287 // [!item.IsSendValue()]* meaning each item is a waiter.
288 //
289 // TODO(zhifengc): consider a better queue impl than std::deque.
290 typedef std::deque<Item*> ItemQueue;
291 typedef gtl::FlatMap<uint64, ItemQueue> Table;
292
293 // TODO(zhifengc): shard table_.
294 mutex mu_;
295 Table table_ GUARDED_BY(mu_);
296 Status status_ GUARDED_BY(mu_);
297
~LocalRendezvousImpl()298 ~LocalRendezvousImpl() override {
299 if (!table_.empty()) {
300 StartAbort(errors::Cancelled("LocalRendezvousImpl deleted"));
301 }
302 }
303
304 TF_DISALLOW_COPY_AND_ASSIGN(LocalRendezvousImpl);
305 };
306
NewLocalRendezvous()307 Rendezvous* NewLocalRendezvous() { return new LocalRendezvousImpl(); }
308
309 } // end namespace tensorflow
310