1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/rendezvous.h"
17 
#include <deque>
#include <functional>
#include <memory>
#include <utility>
#include <vector>
22 
23 #include "tensorflow/core/lib/core/errors.h"
24 #include "tensorflow/core/lib/core/notification.h"
25 #include "tensorflow/core/lib/gtl/flatmap.h"
26 #include "tensorflow/core/lib/hash/hash.h"
27 #include "tensorflow/core/lib/strings/str_util.h"
28 #include "tensorflow/core/platform/logging.h"
29 #include "tensorflow/core/platform/macros.h"
30 #include "tensorflow/core/platform/mutex.h"
31 #include "tensorflow/core/platform/thread_annotations.h"
32 #include "tensorflow/core/platform/types.h"
33 
34 namespace tensorflow {
35 
operator =(const ParsedKey & b)36 Rendezvous::ParsedKey& Rendezvous::ParsedKey::operator=(const ParsedKey& b) {
37   const char* b_base = b.buf_.data();
38   buf_ = b.buf_;
39   src_device = StringPiece(buf_.data() + (b.src_device.data() - b_base),
40                            b.src_device.size());
41   src = b.src;
42   src_incarnation = b.src_incarnation;
43   dst_device = StringPiece(buf_.data() + (b.dst_device.data() - b_base),
44                            b.dst_device.size());
45   dst = b.dst;
46   edge_name = StringPiece(buf_.data() + (b.edge_name.data() - b_base),
47                           b.edge_name.size());
48   return *this;
49 }
50 
51 /*  static */
CreateKey(const string & src_device,uint64 src_incarnation,const string & dst_device,const string & name,const FrameAndIter & frame_iter)52 string Rendezvous::CreateKey(const string& src_device, uint64 src_incarnation,
53                              const string& dst_device, const string& name,
54                              const FrameAndIter& frame_iter) {
55   // NOTE: ';' is not used in the device name's job name.
56   //
57   // We include both sender and receiver in the key to facilitate
58   // debugging. For correctness, we only need to encode the receiver.
59   //
60   // "src_incarnation" is used to distinguish a worker when it
61   // restarts.
62   char buf[strings::kFastToBufferSize];
63   return strings::StrCat(
64       src_device, ";", strings::Uint64ToHexString(src_incarnation, buf), ";",
65       dst_device, ";", name, ";", frame_iter.frame_id, ":", frame_iter.iter_id);
66 }
67 
68 // Return the prefix of "*s" up to the next occurrence of "delim", or
69 // the whole remaining string if "delim" is not found.  "*s" is advanced
70 // past the string returned plus the delimiter (if found).
ConsumeNextPart(StringPiece * s,char delim)71 static StringPiece ConsumeNextPart(StringPiece* s, char delim) {
72   for (size_t offset = 0; offset < s->size(); offset++) {
73     if ((*s)[offset] == delim) {
74       StringPiece result(s->data(), offset);
75       s->remove_prefix(offset + 1);  // +1: remove delim, as well
76       return result;
77     }
78   }
79   // No delimiter found: return rest of string
80   StringPiece result(s->data(), s->size());
81   s->remove_prefix(s->size());
82   return result;
83 }
84 
85 /* static */
ParseKey(StringPiece key,ParsedKey * out)86 Status Rendezvous::ParseKey(StringPiece key, ParsedKey* out) {
87   if (key.data() == out->buf_.data()) {
88     // Caller used our buf_ string directly, so we don't need to copy.  (The
89     // SendOp and RecvOp implementations do this, for example).
90     DCHECK_EQ(key.size(), out->buf_.size());
91   } else {
92     // Make a copy that our StringPieces can point at a copy that will persist
93     // for the lifetime of the ParsedKey object.
94     out->buf_.assign(key.data(), key.size());
95   }
96   StringPiece s(out->buf_);
97   StringPiece parts[5];
98   for (int i = 0; i < 5; i++) {
99     parts[i] = ConsumeNextPart(&s, ';');
100   }
101   if (s.empty() &&          // Consumed the whole string
102       !parts[4].empty() &&  // Exactly five parts
103       DeviceNameUtils::ParseFullName(parts[0], &out->src) &&
104       strings::HexStringToUint64(parts[1], &out->src_incarnation) &&
105       DeviceNameUtils::ParseFullName(parts[2], &out->dst) &&
106       !parts[3].empty()) {
107     out->src_device = StringPiece(parts[0].data(), parts[0].size());
108     out->dst_device = StringPiece(parts[2].data(), parts[2].size());
109     out->edge_name = StringPiece(parts[3].data(), parts[3].size());
110     return Status::OK();
111   }
112   return errors::InvalidArgument("Invalid  rendezvous key: ", key);
113 }
114 
~Rendezvous()115 Rendezvous::~Rendezvous() {}
116 
Recv(const ParsedKey & key,const Args & recv_args,Tensor * val,bool * is_dead,int64 timeout_ms)117 Status Rendezvous::Recv(const ParsedKey& key, const Args& recv_args,
118                         Tensor* val, bool* is_dead, int64 timeout_ms) {
119   Status ret;
120   Notification n;
121   RecvAsync(key, recv_args,
122             [&ret, &n, val, is_dead](const Status& s, const Args& send_args,
123                                      const Args& recv_args, const Tensor& v,
124                                      const bool dead) {
125               ret = s;
126               *val = v;
127               *is_dead = dead;
128               n.Notify();
129             });
130   if (timeout_ms > 0) {
131     int64 timeout_us = timeout_ms * 1000;
132     bool notified = WaitForNotificationWithTimeout(&n, timeout_us);
133     if (!notified) {
134       return Status(error::DEADLINE_EXCEEDED,
135                     "Timed out waiting for notification");
136     }
137   } else {
138     n.WaitForNotification();
139   }
140   return ret;
141 }
142 
Recv(const ParsedKey & key,const Args & args,Tensor * val,bool * is_dead)143 Status Rendezvous::Recv(const ParsedKey& key, const Args& args, Tensor* val,
144                         bool* is_dead) {
145   const int64 no_timeout = 0;
146   return Recv(key, args, val, is_dead, no_timeout);
147 }
148 
149 class LocalRendezvousImpl : public Rendezvous {
150  public:
LocalRendezvousImpl()151   explicit LocalRendezvousImpl() {}
152 
Send(const ParsedKey & key,const Args & send_args,const Tensor & val,const bool is_dead)153   Status Send(const ParsedKey& key, const Args& send_args, const Tensor& val,
154               const bool is_dead) override {
155     uint64 key_hash = KeyHash(key.FullKey());
156     VLOG(2) << "Send " << this << " " << key_hash << " " << key.FullKey();
157 
158     mu_.lock();
159     if (!status_.ok()) {
160       // Rendezvous has been aborted.
161       Status s = status_;
162       mu_.unlock();
163       return s;
164     }
165 
166     ItemQueue* queue = &table_[key_hash];
167     if (queue->empty() || queue->front()->IsSendValue()) {
168       // There is no waiter for this message. Append the message
169       // into the queue. The waiter will pick it up when arrives.
170       // Only send-related fields need to be filled.
171       Item* item = new Item;
172       item->value = val;
173       item->is_dead = is_dead;
174       item->send_args = send_args;
175       if (item->send_args.device_context) {
176         item->send_args.device_context->Ref();
177       }
178       queue->push_back(item);
179       mu_.unlock();
180       return Status::OK();
181     }
182 
183     // There is an earliest waiter to consume this message.
184     Item* item = queue->front();
185     queue->pop_front();
186     mu_.unlock();
187 
188     // Notify the waiter by invoking its done closure, outside the
189     // lock.
190     DCHECK(!item->IsSendValue());
191     item->waiter(Status::OK(), send_args, item->recv_args, val, is_dead);
192     delete item;
193     return Status::OK();
194   }
195 
RecvAsync(const ParsedKey & key,const Args & recv_args,DoneCallback done)196   void RecvAsync(const ParsedKey& key, const Args& recv_args,
197                  DoneCallback done) override {
198     uint64 key_hash = KeyHash(key.FullKey());
199     VLOG(2) << "Recv " << this << " " << key_hash << " " << key.FullKey();
200 
201     mu_.lock();
202     if (!status_.ok()) {
203       // Rendezvous has been aborted.
204       Status s = status_;
205       mu_.unlock();
206       done(s, Args(), recv_args, Tensor(), false);
207       return;
208     }
209 
210     ItemQueue* queue = &table_[key_hash];
211     if (queue->empty() || !queue->front()->IsSendValue()) {
212       // There is no message to pick up.
213       // Only recv-related fields need to be filled.
214       Item* item = new Item;
215       item->waiter = std::move(done);
216       item->recv_args = recv_args;
217       if (item->recv_args.device_context) {
218         item->recv_args.device_context->Ref();
219       }
220       queue->push_back(item);
221       mu_.unlock();
222       return;
223     }
224 
225     // A message has already arrived and is queued in the table under
226     // this key.  Consumes the message and invokes the done closure.
227     Item* item = queue->front();
228     queue->pop_front();
229     mu_.unlock();
230 
231     // Invokes the done() by invoking its done closure, outside scope
232     // of the table lock.
233     DCHECK(item->IsSendValue());
234     done(Status::OK(), item->send_args, recv_args, item->value, item->is_dead);
235     delete item;
236   }
237 
StartAbort(const Status & status)238   void StartAbort(const Status& status) override {
239     CHECK(!status.ok());
240     Table table;
241     {
242       mutex_lock l(mu_);
243       status_.Update(status);
244       table_.swap(table);
245     }
246     for (auto& p : table) {
247       for (Item* item : p.second) {
248         if (!item->IsSendValue()) {
249           item->waiter(status, Args(), Args(), Tensor(), false);
250         }
251         delete item;
252       }
253     }
254   }
255 
256  private:
257   typedef LocalRendezvousImpl ME;
258 
259   struct Item {
260     DoneCallback waiter = nullptr;
261     Tensor value;
262     bool is_dead = false;
263     Args send_args;
264     Args recv_args;
265 
~Itemtensorflow::LocalRendezvousImpl::Item266     ~Item() {
267       if (send_args.device_context) {
268         send_args.device_context->Unref();
269       }
270       if (recv_args.device_context) {
271         recv_args.device_context->Unref();
272       }
273     }
274 
275     // Returns true iff this item represents a value being sent.
IsSendValuetensorflow::LocalRendezvousImpl::Item276     bool IsSendValue() const { return this->waiter == nullptr; }
277   };
278 
279   // We key the hash table by KeyHash of the Rendezvous::CreateKey string
KeyHash(const StringPiece & k)280   static uint64 KeyHash(const StringPiece& k) {
281     return Hash64(k.data(), k.size());
282   }
283 
284   // By invariant, the item queue under each key is of the form
285   //   [item.IsSendValue()]* meaning each item is a sent message.
286   // or
287   //   [!item.IsSendValue()]* meaning each item is a waiter.
288   //
289   // TODO(zhifengc): consider a better queue impl than std::deque.
290   typedef std::deque<Item*> ItemQueue;
291   typedef gtl::FlatMap<uint64, ItemQueue> Table;
292 
293   // TODO(zhifengc): shard table_.
294   mutex mu_;
295   Table table_ GUARDED_BY(mu_);
296   Status status_ GUARDED_BY(mu_);
297 
~LocalRendezvousImpl()298   ~LocalRendezvousImpl() override {
299     if (!table_.empty()) {
300       StartAbort(errors::Cancelled("LocalRendezvousImpl deleted"));
301     }
302   }
303 
304   TF_DISALLOW_COPY_AND_ASSIGN(LocalRendezvousImpl);
305 };
306 
NewLocalRendezvous()307 Rendezvous* NewLocalRendezvous() { return new LocalRendezvousImpl(); }
308 
309 }  // end namespace tensorflow
310