1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/rendezvous.h"
17 
18 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
19 #include "tensorflow/core/framework/tensor.h"
20 #include "tensorflow/core/framework/tensor_shape.h"
21 #include "tensorflow/core/framework/tensor_types.h"
22 #include "tensorflow/core/framework/types.pb.h"
23 #include "tensorflow/core/lib/core/errors.h"
24 #include "tensorflow/core/lib/core/notification.h"
25 #include "tensorflow/core/lib/core/status_test_util.h"
26 #include "tensorflow/core/lib/core/threadpool.h"
27 #include "tensorflow/core/lib/random/simple_philox.h"
28 #include "tensorflow/core/lib/strings/strcat.h"
29 #include "tensorflow/core/platform/env.h"
30 #include "tensorflow/core/platform/logging.h"
31 #include "tensorflow/core/platform/mutex.h"
32 #include "tensorflow/core/platform/test.h"
33 #include "tensorflow/core/platform/test_benchmark.h"
34 #include "tensorflow/core/platform/types.h"
35 
36 namespace tensorflow {
37 namespace {
38 
// Checks the string produced by Rendezvous::CreateKey(), that
// Rendezvous::ParseKey() recovers the individual fields from it, and that
// several malformed keys are rejected by ParseKey().
TEST(RendezvousTest, Key) {
  static const char kSrcDevice[] = "/job:mnist/replica:1/task:2/CPU:0";
  static const char kDstDevice[] = "/job:mnist/replica:1/task:2/device:GPU:0";

  const string key = Rendezvous::CreateKey(kSrcDevice, 7890, kDstDevice, "var0",
                                           FrameAndIter(0, 0));
  EXPECT_EQ(key,
            "/job:mnist/replica:1/task:2/CPU:0;"
            "0000000000001ed2;"  // 7890 = 0x1ed2
            "/job:mnist/replica:1/task:2/device:GPU:0;"
            "var0;"
            "0:0");

  // A well-formed key parses and exposes its source/destination fields.
  Rendezvous::ParsedKey pk;
  TF_EXPECT_OK(Rendezvous::ParseKey(key, &pk));
  EXPECT_EQ(pk.src_device, kSrcDevice);
  EXPECT_EQ(pk.src_incarnation, 7890);
  EXPECT_EQ(pk.src.type, "CPU");
  EXPECT_EQ(pk.dst_device, kDstDevice);
  EXPECT_EQ(pk.dst.type, "GPU");

  // Malformed keys must all fail to parse.
  static const char kThreeParts[] = "foo;bar;baz";
  static const char kDevicesOnly[] =
      "/job:mnist/replica:1/task:2/CPU:0;"
      "/job:mnist/replica:1/task:2/device:GPU:0;";
  const string doubled_key = strings::StrCat(key, ";", key);

  EXPECT_FALSE(Rendezvous::ParseKey(kThreeParts, &pk).ok());
  EXPECT_FALSE(Rendezvous::ParseKey(kDevicesOnly, &pk).ok());
  EXPECT_FALSE(Rendezvous::ParseKey(doubled_key, &pk).ok());
}
65 
66 class LocalRendezvousTest : public ::testing::Test {
67  public:
LocalRendezvousTest()68   LocalRendezvousTest() : threads_(Env::Default(), "test", 16) {
69     rendez_ = NewLocalRendezvous();
70   }
71 
~LocalRendezvousTest()72   ~LocalRendezvousTest() override { rendez_->Unref(); }
73 
SchedClosure(std::function<void ()> fn)74   void SchedClosure(std::function<void()> fn) {
75     threads_.Schedule(std::move(fn));
76   }
77 
78   Rendezvous* rendez_;
79 
80  private:
81   thread::ThreadPool threads_;
82 };
83 
84 // string -> Tensor<string>
V(const string & content)85 Tensor V(const string& content) {
86   Tensor tensor(DT_STRING, TensorShape({}));
87   tensor.scalar<string>()() = content;
88   return tensor;
89 }
90 
91 // Tensor<string> -> string
V(const Tensor & tensor)92 string V(const Tensor& tensor) {
93   CHECK_EQ(tensor.dtype(), DT_STRING);
94   CHECK(TensorShapeUtils::IsScalar(tensor.shape()));
95   return tensor.scalar<string>()();
96 }
97 
MakeKey(const string & name)98 Rendezvous::ParsedKey MakeKey(const string& name) {
99   string s = Rendezvous::CreateKey("/job:mnist/replica:1/task:2/CPU:0", 7890,
100                                    "/job:mnist/replica:1/task:2/device:GPU:0",
101                                    name, FrameAndIter(0, 0));
102   Rendezvous::ParsedKey k;
103   TF_EXPECT_OK(Rendezvous::ParseKey(s, &k));
104   return k;
105 }
106 
KeyFoo()107 const Rendezvous::ParsedKey& KeyFoo() {
108   static auto key = MakeKey("foo");
109   return key;
110 }
111 
KeyBar()112 const Rendezvous::ParsedKey& KeyBar() {
113   static auto key = MakeKey("bar");
114   return key;
115 }
116 
// Sends "hello" under KeyFoo() and then receives it back on the same thread.
TEST_F(LocalRendezvousTest, SendRecv) {
  Rendezvous::Args args;
  Tensor received(DT_STRING);
  bool received_dead = false;

  TF_ASSERT_OK(rendez_->Send(KeyFoo(), args, V("hello"), false));
  TF_ASSERT_OK(rendez_->Recv(KeyFoo(), args, &received, &received_dead));

  EXPECT_EQ("hello", V(received));
}
125 
// Issues a blocking Recv on the test thread while a pool thread performs the
// matching Send after a 10ms delay, so the Recv is normally already waiting
// when the value arrives.
TEST_F(LocalRendezvousTest, RecvSend) {
  Rendezvous::Args recv_args;
  Tensor received(DT_STRING);
  bool received_dead = false;

  SchedClosure([this]() {
    Env::Default()->SleepForMicroseconds(10000);
    Rendezvous::Args send_args;
    TF_ASSERT_OK(rendez_->Send(KeyFoo(), send_args, V("hello"), false));
  });

  TF_ASSERT_OK(rendez_->Recv(KeyFoo(), recv_args, &received, &received_dead));
  EXPECT_EQ("hello", V(received));
}
138 
// A pool thread receives on KeyFoo() and echoes whatever it got on KeyBar().
// The test thread waits one second, sends "secret msg" on KeyFoo(), and
// expects to read the same message back on KeyBar().
TEST_F(LocalRendezvousTest, PingPong) {
  // Echo side, running on the pool.
  SchedClosure([this]() {
    Rendezvous::Args echo_args;
    Tensor echoed(DT_STRING);
    bool echoed_dead = false;
    TF_ASSERT_OK(rendez_->Recv(KeyFoo(), echo_args, &echoed, &echoed_dead));
    TF_ASSERT_OK(rendez_->Send(KeyBar(), echo_args, echoed, echoed_dead));
  });

  Env::Default()->SleepForMicroseconds(1000000);

  // Initiating side, running on the test thread.
  Rendezvous::Args args;
  Tensor reply(DT_STRING);
  bool reply_dead = false;
  TF_ASSERT_OK(rendez_->Send(KeyFoo(), args, V("secret msg"), reply_dead));
  TF_ASSERT_OK(rendez_->Recv(KeyBar(), args, &reply, &reply_dead));
  EXPECT_EQ("secret msg", V(reply));
}
155 
156 // A simple structure that behaves a bit like a blocking counter.  The
157 // user that decrements counter to 0 does done.Notify(), and the main
158 // thread waits for done to be notified.
159 struct BlockingState {
160   mutex lock;
161   int counter = 0;
162   Notification done;
163 };
164 
// Schedules N senders and N asynchronous receivers, one pair per distinct key
// "0".."N-1", each after a random short delay, and checks that every receiver
// gets the value sent under its own key.
TEST_F(LocalRendezvousTest, RandomSendRecv) {
  // 2*N closures are scheduled on the fixture's thread pool, which is
  // configured with only 16 threads. Because the pool may execute the
  // closures in an arbitrary order, the receivers must use RecvAsync:
  // blocking Recv() calls could otherwise run before all the Send() calls
  // and deadlock.
  static const int N = 100;
  random::PhiloxRandom philox(testing::RandomSeed(), 17);
  random::SimplePhilox rnd(&philox);
  BlockingState state;
  state.counter = N;
  for (int i = 0; i < N; ++i) {
    int micros = 100 + rnd.Uniform(1000);
    SchedClosure([this, i, micros]() {
      Env::Default()->SleepForMicroseconds(micros);
      Rendezvous::Args args;
      TF_ASSERT_OK(rendez_->Send(MakeKey(strings::StrCat(i)), args,
                                 V(strings::StrCat(i)), false));
    });
    // Completion callback for the i-th receive. It captures `state` by
    // reference; this is safe because the test body blocks on state.done
    // until the last callback has notified it.
    auto recv_done = [&state, i](const Status& status,
                                 const Rendezvous::Args& sender_args,
                                 const Rendezvous::Args& recver_args,
                                 const Tensor& val, const bool val_dead) {
      // Report a failed receive as a test failure, and only inspect the
      // tensor when the receive succeeded: V() CHECK-fails on a tensor that
      // is not a DT_STRING scalar.
      TF_EXPECT_OK(status);
      if (status.ok()) {
        EXPECT_EQ(strings::StrCat(i), V(val));
      }
      bool done = false;
      {
        mutex_lock l(state.lock);
        state.counter--;
        if (state.counter == 0) {
          done = true;
        }
      }
      // Notify outside the lock, and only from the callback that brought the
      // counter to zero.
      if (done) {
        state.done.Notify();
      }
    };
    micros = 100 + rnd.Uniform(1000);
    SchedClosure([this, i, micros, recv_done]() {
      Env::Default()->SleepForMicroseconds(micros);
      rendez_->RecvAsync(MakeKey(strings::StrCat(i)), Rendezvous::Args(),
                         recv_done);
    });
  }

  state.done.WaitForNotification();
}
211 
RandomSleep()212 void RandomSleep() {
213   if (std::rand() % 10 == 0) {
214     Env::Default()->SleepForMicroseconds(1000);
215   }
216 }
217 
// Sends N values under the same key from a pool thread while the test thread
// receives N values under that key, with random short sleeps on both sides.
TEST_F(LocalRendezvousTest, MultiSends) {
  static const int N = 100;
  const auto& key_foo = KeyFoo();
  Rendezvous::Args args;
  // Captures are listed explicitly (rather than `[=]`) so that the use of
  // `this` is visible; `key_foo` and `args` are still captured by copy.
  SchedClosure([this, key_foo, args]() {
    for (int i = 0; i < N; ++i) {
      TF_ASSERT_OK(rendez_->Send(key_foo, args, V(strings::StrCat(i)), false));
      RandomSleep();
    }
  });
  Tensor val;
  // Initialized so it never holds an indeterminate value, even if a Recv
  // fails before writing it.
  bool val_dead = false;
  for (int i = 0; i < N; ++i) {
    TF_ASSERT_OK(rendez_->Recv(key_foo, args, &val, &val_dead));
    RandomSleep();
  }
}
235 
// A pool thread aborts the rendezvous while the test thread calls Recv();
// the Recv must report an Aborted status. The relative order of the two
// calls is not controlled here.
TEST_F(LocalRendezvousTest, RecvAbort) {
  // Extra reference for the closure, which releases it when it is done.
  rendez_->Ref();
  SchedClosure([this]() {
    rendez_->StartAbort(errors::Aborted(""));  // abort
    rendez_->Unref();
  });

  Rendezvous::Args args;
  Tensor received(DT_STRING);
  bool received_dead = false;
  const Status recv_status =
      rendez_->Recv(KeyFoo(), args, &received, &received_dead);
  EXPECT_TRUE(errors::IsAborted(recv_status));
}
248 
249 // Similar to RecvAbort. But this test case ensures the main thread
250 // Recv() call happens after StartAbort().
TEST_F(LocalRendezvousTest,RecvSleepAbort)251 TEST_F(LocalRendezvousTest, RecvSleepAbort) {
252   rendez_->Ref();
253   SchedClosure([this]() {
254     Env::Default()->SleepForMicroseconds(1000000);
255     rendez_->StartAbort(errors::Aborted(""));  // abort
256     rendez_->Unref();
257   });
258   Tensor val(DT_STRING);
259   bool val_dead = false;
260   Rendezvous::Args args;
261   Status status = rendez_->Recv(KeyFoo(), args, &val, &val_dead);
262   EXPECT_TRUE(errors::IsAborted(status));
263 }
264 
// Once the rendezvous has been aborted, both Send() and Recv() must return
// an Aborted status.
TEST_F(LocalRendezvousTest, AbortThenRecvOrSend) {
  rendez_->StartAbort(errors::Aborted(""));

  Rendezvous::Args args;
  Tensor tensor(DT_STRING);
  bool tensor_dead = false;

  const Status send_status =
      rendez_->Send(KeyFoo(), args, tensor, tensor_dead);
  EXPECT_TRUE(errors::IsAborted(send_status));

  const Status recv_status =
      rendez_->Recv(KeyFoo(), args, &tensor, &tensor_dead);
  EXPECT_TRUE(errors::IsAborted(recv_status));
}
274 
275 class DummyDeviceContext : public DeviceContext {
276  public:
DummyDeviceContext(int stream_id)277   explicit DummyDeviceContext(int stream_id) : stream_id_(stream_id) {}
~DummyDeviceContext()278   ~DummyDeviceContext() override {}
stream_id() const279   int stream_id() const { return stream_id_; }
280 
CopyTensorInSameDevice(const Tensor * input_tensor,Device * device,Tensor * output_tensor,StatusCallback done) const281   void CopyTensorInSameDevice(const Tensor* input_tensor, Device* device,
282                               Tensor* output_tensor,
283                               StatusCallback done) const override {
284     done(Status::OK());
285   }
286 
287  private:
288   const int stream_id_;
289 };
290 
// Sends a value with a DummyDeviceContext (stream id 123) attached to the
// sender's Args and checks that the RecvAsync callback sees that same
// context in the sender-side Args it is given.
TEST_F(LocalRendezvousTest, TransferDummyDeviceContext) {
  Rendezvous::Args args;
  args.device_context = new DummyDeviceContext(123);

  TF_ASSERT_OK(rendez_->Send(KeyFoo(), args, V("hello"), false));

  Notification n;
  Rendezvous::Args args1;
  args1.device_context = new DummyDeviceContext(1);
  rendez_->RecvAsync(
      KeyFoo(), args1,
      [&n](const Status& s, const Rendezvous::Args& send_args,
           const Rendezvous::Args& recv_args, const Tensor& val, bool is_dead) {
        // Fail loudly on a receive error instead of going on to inspect
        // arguments that may not be populated.
        TF_CHECK_OK(s);
        // dynamic_cast yields nullptr when the context is missing or of a
        // different type; check before dereferencing.
        const auto* sender_context =
            dynamic_cast<const DummyDeviceContext*>(send_args.device_context);
        CHECK(sender_context != nullptr);
        CHECK_EQ(123, sender_context->stream_id());
        n.Notify();
      });

  n.WaitForNotification();
  // Release the references taken by `new` above.
  args.device_context->Unref();
  args1.device_context->Unref();
}
314 
BM_SendRecv(int iters)315 void BM_SendRecv(int iters) {
316   Rendezvous* rendez = NewLocalRendezvous();
317   Tensor orig = V("val");
318   Tensor val(DT_STRING, TensorShape({}));
319   bool is_dead = false;
320   Rendezvous::Args args;
321   if (iters > 0) {
322     while (iters--) {
323       TF_CHECK_OK(rendez->Send(KeyFoo(), args, orig, is_dead));
324       TF_CHECK_OK(rendez->Recv(KeyFoo(), args, &val, &is_dead));
325     }
326     CHECK_EQ(V(val), V(orig));
327   }
328   rendez->Unref();
329 }
330 BENCHMARK(BM_SendRecv);
331 
BM_PingPong(int iters)332 void BM_PingPong(int iters) {
333   CHECK_GT(iters, 0);
334   thread::ThreadPool* pool = new thread::ThreadPool(Env::Default(), "test", 1);
335 
336   // The main thread sends "foo" for iters times and receives "bar"
337   // for iters times.  The other thread sends "bar" for iters times
338   // and receives "foo" for iters times.
339   Rendezvous* rendez = NewLocalRendezvous();
340   pool->Schedule([rendez, iters]() {
341     Tensor bar = V("bar");
342     Tensor foo(DT_STRING, TensorShape({}));
343     bool is_dead = false;
344     Rendezvous::Args args;
345     for (int i = 0; i < iters; ++i) {
346       TF_CHECK_OK(rendez->Recv(KeyFoo(), args, &foo, &is_dead));
347       TF_CHECK_OK(rendez->Send(KeyBar(), args, bar, is_dead));
348     }
349     CHECK_EQ("foo", V(foo));
350   });
351   Tensor foo = V("foo");
352   Tensor bar(DT_STRING, TensorShape({}));
353   bool is_dead = false;
354   Rendezvous::Args args;
355   for (int i = 0; i < iters; ++i) {
356     TF_CHECK_OK(rendez->Send(KeyFoo(), args, foo, is_dead));
357     TF_CHECK_OK(rendez->Recv(KeyBar(), args, &bar, &is_dead));
358   }
359   CHECK_EQ("bar", V(bar));
360   delete pool;
361 }
362 BENCHMARK(BM_PingPong);
363 
364 }  // namespace
365 }  // namespace tensorflow
366