1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/core/framework/rendezvous.h"

#include <memory>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/platform/types.h"
35
36 namespace tensorflow {
37 namespace {
38
// Round-trips a rendezvous key: CreateKey() must produce the expected
// "src;incarnation-hex;dst;edge;frame:iter" layout, ParseKey() must recover
// the fields from it, and keys with the wrong number of ';'-separated fields
// must be rejected.
TEST(RendezvousTest, Key) {
  const string key = Rendezvous::CreateKey(
      "/job:mnist/replica:1/task:2/CPU:0", 7890,
      "/job:mnist/replica:1/task:2/device:GPU:0", "var0", FrameAndIter(0, 0));
  EXPECT_EQ(key,
            "/job:mnist/replica:1/task:2/CPU:0;"
            "0000000000001ed2;"  // 7890 = 0x1ed2
            "/job:mnist/replica:1/task:2/device:GPU:0;"
            "var0;"
            "0:0");
  Rendezvous::ParsedKey parsed;
  TF_EXPECT_OK(Rendezvous::ParseKey(key, &parsed));
  EXPECT_EQ(parsed.src_device, "/job:mnist/replica:1/task:2/CPU:0");
  EXPECT_EQ(parsed.src_incarnation, 7890);
  EXPECT_EQ(parsed.src.type, "CPU");
  EXPECT_EQ(parsed.dst_device, "/job:mnist/replica:1/task:2/device:GPU:0");
  EXPECT_EQ(parsed.dst.type, "GPU");
  // The edge name is encoded in the key too; make sure it survives parsing.
  EXPECT_EQ(parsed.edge_name, "var0");

  // Too few fields.
  EXPECT_FALSE(Rendezvous::ParseKey("foo;bar;baz", &parsed).ok());
  // Only the two device fields; incarnation, edge and frame are missing.
  EXPECT_FALSE(Rendezvous::ParseKey("/job:mnist/replica:1/task:2/CPU:0;"
                                    "/job:mnist/replica:1/task:2/device:GPU:0;",
                                    &parsed)
                   .ok());
  // Too many fields: two valid keys glued together.
  EXPECT_FALSE(
      Rendezvous::ParseKey(strings::StrCat(key, ";", key), &parsed).ok());
}
65
66 class LocalRendezvousTest : public ::testing::Test {
67 public:
LocalRendezvousTest()68 LocalRendezvousTest() : threads_(Env::Default(), "test", 16) {
69 rendez_ = NewLocalRendezvous();
70 }
71
~LocalRendezvousTest()72 ~LocalRendezvousTest() override { rendez_->Unref(); }
73
SchedClosure(std::function<void ()> fn)74 void SchedClosure(std::function<void()> fn) {
75 threads_.Schedule(std::move(fn));
76 }
77
78 Rendezvous* rendez_;
79
80 private:
81 thread::ThreadPool threads_;
82 };
83
84 // string -> Tensor<string>
V(const string & content)85 Tensor V(const string& content) {
86 Tensor tensor(DT_STRING, TensorShape({}));
87 tensor.scalar<string>()() = content;
88 return tensor;
89 }
90
91 // Tensor<string> -> string
V(const Tensor & tensor)92 string V(const Tensor& tensor) {
93 CHECK_EQ(tensor.dtype(), DT_STRING);
94 CHECK(TensorShapeUtils::IsScalar(tensor.shape()));
95 return tensor.scalar<string>()();
96 }
97
MakeKey(const string & name)98 Rendezvous::ParsedKey MakeKey(const string& name) {
99 string s = Rendezvous::CreateKey("/job:mnist/replica:1/task:2/CPU:0", 7890,
100 "/job:mnist/replica:1/task:2/device:GPU:0",
101 name, FrameAndIter(0, 0));
102 Rendezvous::ParsedKey k;
103 TF_EXPECT_OK(Rendezvous::ParseKey(s, &k));
104 return k;
105 }
106
KeyFoo()107 const Rendezvous::ParsedKey& KeyFoo() {
108 static auto key = MakeKey("foo");
109 return key;
110 }
111
KeyBar()112 const Rendezvous::ParsedKey& KeyBar() {
113 static auto key = MakeKey("bar");
114 return key;
115 }
116
// A Send() followed by a Recv() on the same thread delivers the sent value.
TEST_F(LocalRendezvousTest, SendRecv) {
  Rendezvous::Args args;
  TF_ASSERT_OK(rendez_->Send(KeyFoo(), args, V("hello"), false));

  Tensor received(DT_STRING);
  bool received_dead = false;
  TF_ASSERT_OK(rendez_->Recv(KeyFoo(), args, &received, &received_dead));
  EXPECT_EQ("hello", V(received));
}
125
// The receiver is (most likely) already waiting when the value is sent: the
// sender closure sleeps 10ms before calling Send().
TEST_F(LocalRendezvousTest, RecvSend) {
  SchedClosure([this]() {
    Env::Default()->SleepForMicroseconds(10000);
    Rendezvous::Args sender_args;
    TF_ASSERT_OK(rendez_->Send(KeyFoo(), sender_args, V("hello"), false));
  });

  Tensor received(DT_STRING);
  bool received_dead = false;
  Rendezvous::Args recver_args;
  TF_ASSERT_OK(
      rendez_->Recv(KeyFoo(), recver_args, &received, &received_dead));
  EXPECT_EQ("hello", V(received));
}
138
// Two-thread round trip. The closure waits for "foo" and echoes whatever it
// gets onto "bar". The main thread sends "foo" after a one-second head start
// and expects to read the same value back from "bar".
TEST_F(LocalRendezvousTest, PingPong) {
  SchedClosure([this]() {
    Rendezvous::Args echo_args;
    Tensor echoed(DT_STRING);
    bool echoed_dead = false;
    TF_ASSERT_OK(rendez_->Recv(KeyFoo(), echo_args, &echoed, &echoed_dead));
    TF_ASSERT_OK(rendez_->Send(KeyBar(), echo_args, echoed, echoed_dead));
  });

  // Give the closure time to block in Recv() before "foo" is sent.
  Env::Default()->SleepForMicroseconds(1000000);

  Rendezvous::Args args;
  bool val_dead = false;
  TF_ASSERT_OK(rendez_->Send(KeyFoo(), args, V("secret msg"), val_dead));
  Tensor val(DT_STRING);
  TF_ASSERT_OK(rendez_->Recv(KeyBar(), args, &val, &val_dead));
  EXPECT_EQ("secret msg", V(val));
}
155
156 // A simple structure that behaves a bit like a blocking counter. The
157 // user that decrements counter to 0 does done.Notify(), and the main
158 // thread waits for done to be notified.
159 struct BlockingState {
160 mutex lock;
161 int counter = 0;
162 Notification done;
163 };
164
// Schedules N Send()/RecvAsync() pairs on distinct keys, each after a random
// delay, and waits until every receive callback has checked the value that
// was sent on its key.
TEST_F(LocalRendezvousTest, RandomSendRecv) {
  // We schedule 2*N closures on the fixture's thread pool, which has only 16
  // threads. Because the pool may run the closures in any order, receivers
  // must use RecvAsync(): blocking Recv() calls could occupy every pool
  // thread before all the Send()s have run, and deadlock.
  static const int N = 100;
  random::PhiloxRandom philox(testing::RandomSeed(), 17);
  random::SimplePhilox rnd(&philox);
  BlockingState state;
  state.counter = N;
  for (int i = 0; i < N; ++i) {
    int micros = 100 + rnd.Uniform(1000);
    SchedClosure([this, i, micros]() {
      Env::Default()->SleepForMicroseconds(micros);
      Rendezvous::Args args;
      TF_ASSERT_OK(rendez_->Send(MakeKey(strings::StrCat(i)), args,
                                 V(strings::StrCat(i)), false));
    });
    auto recv_done = [&state, i](const Status& status,
                                 const Rendezvous::Args& sender_args,
                                 const Rendezvous::Args& recver_args,
                                 const Tensor& val, const bool val_dead) {
      // Only look at the tensor on success. V() CHECK-fails unless `val` is a
      // DT_STRING scalar, which would abort the whole test binary instead of
      // reporting a failure.
      TF_EXPECT_OK(status);
      if (status.ok()) {
        EXPECT_FALSE(val_dead);  // Every sender passes is_dead = false.
        EXPECT_EQ(strings::StrCat(i), V(val));
      }
      // Count this receive as finished even on failure, so the main thread is
      // never left waiting forever.
      bool done = false;
      {
        mutex_lock l(state.lock);
        state.counter--;
        if (state.counter == 0) {
          done = true;
        }
      }
      if (done) {
        state.done.Notify();
      }
    };
    micros = 100 + rnd.Uniform(1000);
    SchedClosure([this, i, micros, recv_done]() {
      Env::Default()->SleepForMicroseconds(micros);
      rendez_->RecvAsync(MakeKey(strings::StrCat(i)), Rendezvous::Args(),
                         recv_done);
    });
  }

  state.done.WaitForNotification();
}
211
RandomSleep()212 void RandomSleep() {
213 if (std::rand() % 10 == 0) {
214 Env::Default()->SleepForMicroseconds(1000);
215 }
216 }
217
// One closure sends N values on the same key while the main thread receives
// N times on that key. Random sleeps on both sides vary the interleaving.
TEST_F(LocalRendezvousTest, MultiSends) {
  static const int N = 100;
  const auto& key_foo = KeyFoo();
  Rendezvous::Args args;
  // Captures spelled out: the key and args are copied into the closure, and
  // `this` gives access to rendez_.
  SchedClosure([this, key_foo, args]() {
    for (int sent = 0; sent < N; ++sent) {
      TF_ASSERT_OK(
          rendez_->Send(key_foo, args, V(strings::StrCat(sent)), false));
      RandomSleep();
    }
  });

  Tensor val;
  bool val_dead = false;
  for (int received = 0; received < N; ++received) {
    TF_ASSERT_OK(rendez_->Recv(key_foo, args, &val, &val_dead));
    RandomSleep();
  }
}
235
// A Recv() racing with StartAbort() issued from another thread returns an
// Aborted status.
TEST_F(LocalRendezvousTest, RecvAbort) {
  // Extra reference owned by the aborting closure, released once it is done.
  rendez_->Ref();
  SchedClosure([this]() {
    rendez_->StartAbort(errors::Aborted(""));
    rendez_->Unref();
  });

  Rendezvous::Args args;
  Tensor val(DT_STRING);
  bool val_dead = false;
  const Status status = rendez_->Recv(KeyFoo(), args, &val, &val_dead);
  EXPECT_TRUE(errors::IsAborted(status));
}
248
// Similar to RecvAbort, but the closure sleeps before calling StartAbort(),
// so the main thread's Recv() is (almost certainly) already blocked when the
// abort arrives.
TEST_F(LocalRendezvousTest, RecvSleepAbort) {
  rendez_->Ref();  // Released by the closure after it has aborted.
  SchedClosure([this]() {
    // Delay the abort for one second so the Recv() below gets there first.
    Env::Default()->SleepForMicroseconds(1000000);
    rendez_->StartAbort(errors::Aborted(""));
    rendez_->Unref();
  });

  Rendezvous::Args args;
  Tensor val(DT_STRING);
  bool val_dead = false;
  const Status status = rendez_->Recv(KeyFoo(), args, &val, &val_dead);
  EXPECT_TRUE(errors::IsAborted(status));
}
264
// After the rendezvous has been aborted, both Send() and Recv() report the
// Aborted status.
TEST_F(LocalRendezvousTest, AbortThenRecvOrSend) {
  rendez_->StartAbort(errors::Aborted(""));

  Rendezvous::Args args;
  Tensor val(DT_STRING);
  bool val_dead = false;
  const Status send_status = rendez_->Send(KeyFoo(), args, val, val_dead);
  EXPECT_TRUE(errors::IsAborted(send_status));
  const Status recv_status = rendez_->Recv(KeyFoo(), args, &val, &val_dead);
  EXPECT_TRUE(errors::IsAborted(recv_status));
}
274
275 class DummyDeviceContext : public DeviceContext {
276 public:
DummyDeviceContext(int stream_id)277 explicit DummyDeviceContext(int stream_id) : stream_id_(stream_id) {}
~DummyDeviceContext()278 ~DummyDeviceContext() override {}
stream_id() const279 int stream_id() const { return stream_id_; }
280
CopyTensorInSameDevice(const Tensor * input_tensor,Device * device,Tensor * output_tensor,StatusCallback done) const281 void CopyTensorInSameDevice(const Tensor* input_tensor, Device* device,
282 Tensor* output_tensor,
283 StatusCallback done) const override {
284 done(Status::OK());
285 }
286
287 private:
288 const int stream_id_;
289 };
290
// The sender's Args, including its device context, must reach the receiver's
// done callback as `send_args`.
TEST_F(LocalRendezvousTest, TransferDummyDeviceContext) {
  Rendezvous::Args args;
  args.device_context = new DummyDeviceContext(123);

  TF_ASSERT_OK(rendez_->Send(KeyFoo(), args, V("hello"), false));

  Notification n;
  Rendezvous::Args args1;
  args1.device_context = new DummyDeviceContext(1);
  rendez_->RecvAsync(
      KeyFoo(), args1,
      [&n](const Status& s, const Rendezvous::Args& send_args,
           const Rendezvous::Args& recv_args, const Tensor& val, bool is_dead) {
        TF_EXPECT_OK(s);
        // Use EXPECT (not CHECK) and guard the cast, so a missing or
        // unexpected context is reported as a test failure instead of
        // aborting or segfaulting the process.
        const auto* send_ctx =
            dynamic_cast<const DummyDeviceContext*>(send_args.device_context);
        EXPECT_TRUE(send_ctx != nullptr);
        if (send_ctx != nullptr) {
          EXPECT_EQ(123, send_ctx->stream_id());
        }
        // Always notify, so the main thread cannot hang on a failure.
        n.Notify();
      });

  n.WaitForNotification();
  args.device_context->Unref();
  args1.device_context->Unref();
}
314
BM_SendRecv(int iters)315 void BM_SendRecv(int iters) {
316 Rendezvous* rendez = NewLocalRendezvous();
317 Tensor orig = V("val");
318 Tensor val(DT_STRING, TensorShape({}));
319 bool is_dead = false;
320 Rendezvous::Args args;
321 if (iters > 0) {
322 while (iters--) {
323 TF_CHECK_OK(rendez->Send(KeyFoo(), args, orig, is_dead));
324 TF_CHECK_OK(rendez->Recv(KeyFoo(), args, &val, &is_dead));
325 }
326 CHECK_EQ(V(val), V(orig));
327 }
328 rendez->Unref();
329 }
330 BENCHMARK(BM_SendRecv);
331
BM_PingPong(int iters)332 void BM_PingPong(int iters) {
333 CHECK_GT(iters, 0);
334 thread::ThreadPool* pool = new thread::ThreadPool(Env::Default(), "test", 1);
335
336 // The main thread sends "foo" for iters times and receives "bar"
337 // for iters times. The other thread sends "bar" for iters times
338 // and receives "foo" for iters times.
339 Rendezvous* rendez = NewLocalRendezvous();
340 pool->Schedule([rendez, iters]() {
341 Tensor bar = V("bar");
342 Tensor foo(DT_STRING, TensorShape({}));
343 bool is_dead = false;
344 Rendezvous::Args args;
345 for (int i = 0; i < iters; ++i) {
346 TF_CHECK_OK(rendez->Recv(KeyFoo(), args, &foo, &is_dead));
347 TF_CHECK_OK(rendez->Send(KeyBar(), args, bar, is_dead));
348 }
349 CHECK_EQ("foo", V(foo));
350 });
351 Tensor foo = V("foo");
352 Tensor bar(DT_STRING, TensorShape({}));
353 bool is_dead = false;
354 Rendezvous::Args args;
355 for (int i = 0; i < iters; ++i) {
356 TF_CHECK_OK(rendez->Send(KeyFoo(), args, foo, is_dead));
357 TF_CHECK_OK(rendez->Recv(KeyBar(), args, &bar, &is_dead));
358 }
359 CHECK_EQ("bar", V(bar));
360 delete pool;
361 }
362 BENCHMARK(BM_PingPong);
363
364 } // namespace
365 } // namespace tensorflow
366