1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/distributed_runtime/master.h"
17
18 #include <map>
19 #include <memory>
20
21 #include "grpcpp/grpcpp.h"
22
23 #include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
24 #include "tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h"
25 #include "tensorflow/core/distributed_runtime/rpc/grpc_testlib.h"
26 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
27 #include "tensorflow/core/framework/allocator.h"
28 #include "tensorflow/core/framework/tensor.h"
29 #include "tensorflow/core/framework/tensor_testutil.h"
30 #include "tensorflow/core/graph/testlib.h"
31 #include "tensorflow/core/lib/core/errors.h"
32 #include "tensorflow/core/lib/core/notification.h"
33 #include "tensorflow/core/lib/core/status_test_util.h"
34 #include "tensorflow/core/lib/core/threadpool.h"
35 #include "tensorflow/core/lib/gtl/map_util.h"
36 #include "tensorflow/core/platform/logging.h"
37 #include "tensorflow/core/platform/mutex.h"
38 #include "tensorflow/core/platform/test.h"
39 #include "tensorflow/core/platform/types.h"
40 #include "tensorflow/core/protobuf/master.pb.h"
41
42 namespace tensorflow {
43
44 class MasterTest : public ::testing::Test {
45 protected:
MasterTest()46 MasterTest() {
47 std::vector<string> targets;
48 SessionOptions options;
49 (*options.config.mutable_device_count())["CPU"] = 1;
50 (*options.config.mutable_device_count())["GPU"] = 0;
51 TF_CHECK_OK(test::TestCluster::MakeTestCluster(options, 2, &cluster_));
52 SharedGrpcChannelPtr channel_ptr;
53 TF_CHECK_OK(NewHostPortGrpcChannel(
54 cluster_->targets()[0], &options.config.rpc_options(), &channel_ptr));
55 master_ = grpc::MasterService::NewStub(channel_ptr);
56 }
57
58 std::unique_ptr<test::TestCluster> cluster_;
59 std::unique_ptr<grpc::MasterService::Stub> master_;
60
61 // Helpers for MasterService.{CreateSession,RunStep,CloseSession}
62 // rpc calls.
63
CreateSession(const GraphDef & def,string * handle,int64 * initial_version)64 Status CreateSession(const GraphDef& def, string* handle,
65 int64* initial_version) {
66 ::grpc::ClientContext ctx;
67 CreateSessionRequest req;
68 *(req.mutable_graph_def()) = def;
69 // Invokes placement frequently.
70 req.mutable_config()->set_placement_period(1);
71 CreateSessionResponse resp;
72 const Status s = FromGrpcStatus(master_->CreateSession(&ctx, req, &resp));
73 if (s.ok()) {
74 *handle = resp.session_handle();
75 *initial_version = resp.graph_version();
76 }
77 return s;
78 }
79
ExtendSession(const string & handle,const GraphDef & def,int64 current_version,int64 * new_version)80 Status ExtendSession(const string& handle, const GraphDef& def,
81 int64 current_version, int64* new_version) {
82 ::grpc::ClientContext ctx;
83 ExtendSessionRequest req;
84 req.set_session_handle(handle);
85 *(req.mutable_graph_def()) = def;
86 req.set_current_graph_version(current_version);
87 ExtendSessionResponse resp;
88 const Status s = FromGrpcStatus(master_->ExtendSession(&ctx, req, &resp));
89 if (s.ok()) {
90 *new_version = resp.new_graph_version();
91 }
92 return s;
93 }
94
RunStep(const string & handle,const std::vector<std::pair<string,const Tensor * >> & feed,const std::map<string,Tensor * > & fetch)95 Status RunStep(const string& handle,
96 const std::vector<std::pair<string, const Tensor*> >& feed,
97 const std::map<string, Tensor*>& fetch) {
98 ::grpc::ClientContext ctx;
99 RunStepRequest req;
100 req.set_session_handle(handle);
101 for (const auto& p : feed) {
102 const string& feed_name = p.first;
103 const Tensor* feed_tensor = p.second;
104 auto f = req.add_feed();
105 f->set_name(feed_name);
106 feed_tensor->AsProtoTensorContent(f->mutable_tensor());
107 }
108 for (const auto& p : fetch) {
109 const string& fetch_name = p.first;
110 req.add_fetch(fetch_name);
111 }
112 RunStepResponse resp;
113 const Status s = FromGrpcStatus(master_->RunStep(&ctx, req, &resp));
114 if (s.ok()) {
115 for (const auto& fetch_resp : resp.tensor()) {
116 auto it = fetch.find(fetch_resp.name());
117 CHECK(it != fetch.end());
118 CHECK(it->second->FromProto(fetch_resp.tensor()));
119 }
120 }
121 return s;
122 }
123
CloseSession(const string & handle)124 Status CloseSession(const string& handle) {
125 ::grpc::ClientContext ctx;
126 CloseSessionRequest req;
127 req.set_session_handle(handle);
128 CloseSessionResponse resp;
129 return FromGrpcStatus(master_->CloseSession(&ctx, req, &resp));
130 }
131
Reset()132 Status Reset() {
133 ::grpc::ClientContext ctx;
134 ResetRequest req;
135 ResetResponse resp;
136 return FromGrpcStatus(master_->Reset(&ctx, req, &resp));
137 }
138 };
139
// A session can be created and closed; closing an unknown handle is reported
// as Aborted.
TEST_F(MasterTest, CreateClose) {
  GraphDef def;  // Empty.
  string handle;
  int64 initial_version;
  TF_ASSERT_OK(CreateSession(def, &handle, &initial_version));
  EXPECT_TRUE(errors::IsAborted(CloseSession("randombits")));
  // TF_EXPECT_OK (rather than EXPECT_TRUE(...ok())) so a failure prints the
  // returned status.
  TF_EXPECT_OK(CloseSession(handle));
}
148
// ListDevices reports exactly one local device, of type "CPU".
TEST_F(MasterTest, ListDevices) {
  ::grpc::ClientContext ctx;
  ListDevicesRequest req;
  ListDevicesResponse resp;
  const Status s = FromGrpcStatus(master_->ListDevices(&ctx, req, &resp));
  // Fatal: there is nothing meaningful to inspect in `resp` if the RPC failed.
  TF_ASSERT_OK(s);
  // Fatal: local_device(0) below must never be read from an empty list.
  ASSERT_EQ(1, resp.local_device_size());
  EXPECT_EQ("CPU", resp.local_device(0).device_type());
}
158
// After Reset, closing either of two previously created sessions is reported
// as Aborted.
TEST_F(MasterTest, Reset) {
  GraphDef def;  // Empty.
  string s1, s2;
  int64 initial_version1, initial_version2;
  TF_ASSERT_OK(CreateSession(def, &s1, &initial_version1));
  TF_ASSERT_OK(CreateSession(def, &s2, &initial_version2));
  // TF_EXPECT_OK (rather than EXPECT_TRUE(...ok())) so a failure prints the
  // returned status.
  TF_EXPECT_OK(Reset());
  EXPECT_TRUE(errors::IsAborted(CloseSession(s1)));
  EXPECT_TRUE(errors::IsAborted(CloseSession(s2)));
}
169
// Extends an initially empty session twice (first with constant "A", then
// with constant "x") and checks that the graph version grows each time and
// that the added constants can be fetched.
TEST_F(MasterTest, Extend) {
  GraphDef empty_def;  // Empty.
  string session;
  int64 base_version;
  TF_ASSERT_OK(CreateSession(empty_def, &session, &base_version));

  // Values the two constants are expected to evaluate to.
  Tensor want_a(DT_FLOAT, TensorShape({2, 2}));
  test::FillValues<float>(&want_a, {3.0, 2.0, -1.0, 0.0});

  Tensor want_x(DT_FLOAT, TensorShape({2, 1}));
  test::FillValues<float>(&want_x, {2.0, 2.0});

  // First extension: add "A" and fetch it.
  Graph graph_with_a(OpRegistry::Global());
  test::graph::Constant(&graph_with_a, want_a, "A");
  GraphDef def_with_a;
  test::graph::ToGraphDef(&graph_with_a, &def_with_a);
  int64 version_after_a;
  TF_ASSERT_OK(
      ExtendSession(session, def_with_a, base_version, &version_after_a));
  EXPECT_GT(version_after_a, base_version);
  Tensor got_a(DT_FLOAT, TensorShape({2, 2}));
  TF_ASSERT_OK(RunStep(session, {}, {{"A:0", &got_a}}));
  test::ExpectTensorEqual<float>(got_a, want_a);

  // Second extension: add "x". An unknown handle must be rejected as Aborted;
  // the real handle must succeed.
  Graph graph_with_x(OpRegistry::Global());
  test::graph::Constant(&graph_with_x, want_x, "x");
  GraphDef def_with_x;
  test::graph::ToGraphDef(&graph_with_x, &def_with_x);
  int64 version_after_x;
  const Status bogus_handle_status =
      ExtendSession("randombits", def_with_x, version_after_a,
                    &version_after_x);
  EXPECT_TRUE(errors::IsAborted(bogus_handle_status));
  TF_ASSERT_OK(
      ExtendSession(session, def_with_x, version_after_a, &version_after_x));
  EXPECT_GT(version_after_x, version_after_a);

  // Both constants are now fetchable in a single step.
  Tensor got_x(DT_FLOAT, TensorShape({2, 1}));
  TF_ASSERT_OK(RunStep(session, {}, {{"A:0", &got_a}, {"x:0", &got_x}}));
  test::ExpectTensorEqual<float>(got_a, want_a);
  test::ExpectTensorEqual<float>(got_x, want_x);

  TF_ASSERT_OK(CloseSession(session));
}
210
// Extending a session a second time with the same variable-bearing graph, at
// the up-to-date version, is rejected as InvalidArgument.
TEST_F(MasterTest, ExtendUpdateStatefulFails) {
  GraphDef empty_def;  // Empty.
  string session;
  int64 base_version;
  TF_ASSERT_OK(CreateSession(empty_def, &session, &base_version));

  // Graph holding a single float variable of shape {512}.
  Graph var_graph(OpRegistry::Global());
  test::graph::Var(&var_graph, DT_FLOAT, TensorShape({512}));
  GraphDef var_def;
  test::graph::ToGraphDef(&var_graph, &var_def);

  // The first extension is accepted and advances the version.
  int64 first_version, second_version;
  TF_ASSERT_OK(ExtendSession(session, var_def, base_version, &first_version));
  EXPECT_GT(first_version, base_version);

  // Re-sending the same definition against the new version must fail.
  const Status repeat_status =
      ExtendSession(session, var_def, first_version, &second_version);
  EXPECT_TRUE(errors::IsInvalidArgument(repeat_status));

  TF_ASSERT_OK(CloseSession(session));
}
229
// Extending a session a second time against the stale initial version is
// rejected as Aborted.
TEST_F(MasterTest, ExtendTwiceFails) {
  GraphDef empty_def;  // Empty.
  string session;
  int64 base_version;
  TF_ASSERT_OK(CreateSession(empty_def, &session, &base_version));

  // Graph holding a single float variable of shape {512}.
  Graph var_graph(OpRegistry::Global());
  test::graph::Var(&var_graph, DT_FLOAT, TensorShape({512}));
  GraphDef var_def;
  test::graph::ToGraphDef(&var_graph, &var_def);

  // The first extension is accepted and advances the version.
  int64 extended_version;
  TF_ASSERT_OK(
      ExtendSession(session, var_def, base_version, &extended_version));
  EXPECT_GT(extended_version, base_version);

  // Retrying with the now-outdated base version must fail.
  const Status stale_status =
      ExtendSession(session, var_def, base_version, &extended_version);
  EXPECT_TRUE(errors::IsAborted(stale_status));

  TF_ASSERT_OK(CloseSession(session));
}
248
// Issues many ExtendSession calls concurrently, all against the same initial
// graph version, and checks that exactly one of them succeeds while every
// other call is rejected as Aborted.
TEST_F(MasterTest, ConcurrentExtendOnlyOneSucceeds) {
  // Number of concurrent Extend calls; also the size of the thread pool so
  // that every call can block on the start notification at the same time.
  const int kNumExtendCalls = 100;

  GraphDef def_0;  // Empty.
  string handle;
  int64 initial_version;
  TF_ASSERT_OK(CreateSession(def_0, &handle, &initial_version));

  Graph graph_1(OpRegistry::Global());
  test::graph::Var(&graph_1, DT_FLOAT, TensorShape({512}));
  GraphDef def_1;
  test::graph::ToGraphDef(&graph_1, &def_1);

  Notification n;
  mutex mu;
  int succeeded = 0;  // Guarded by `mu` while the pool is alive.
  int failed = 0;     // Guarded by `mu` while the pool is alive.
  auto extend_fn = [this, handle, def_1, initial_version, &n, &mu, &succeeded,
                    &failed]() {
    // Hold every call back until all of them have been scheduled.
    n.WaitForNotification();
    int64 new_version;
    Status s = ExtendSession(handle, def_1, initial_version, &new_version);
    // Print the status so an unexpected error code is visible in the log.
    EXPECT_TRUE(s.ok() || errors::IsAborted(s)) << s.ToString();
    {
      mutex_lock l(mu);
      if (s.ok()) {
        ++succeeded;
      } else {
        ++failed;
      }
    }
  };

  // Run the concurrent Extend calls and expect only one to succeed. The pool
  // lives in its own scope so that all scheduled work has finished before the
  // counters are read below.
  {
    thread::ThreadPool thread_pool(Env::Default(), "extend_pool",
                                   kNumExtendCalls);
    for (int i = 0; i < kNumExtendCalls; ++i) {
      thread_pool.Schedule(extend_fn);
    }
    n.Notify();
  }

  EXPECT_EQ(failed, kNumExtendCalls - 1);
  EXPECT_EQ(succeeded, 1);
  TF_ASSERT_OK(CloseSession(handle));
}
293
// Runs steps concurrently with an ExtendSession call: fetching "A" (present
// from the start) must always work, while fetching "B" (added by the Extend)
// must fail with NotFound before the Extend and succeed after it.
TEST_F(MasterTest, ConcurrentExtendAndRun) {
  Graph graph_0(OpRegistry::Global());
  Tensor a_tensor(DT_FLOAT, TensorShape({2, 2}));
  test::FillValues<float>(&a_tensor, {3, 2, -1, 0});
  test::graph::Constant(&graph_0, a_tensor, "A");
  GraphDef def_0;
  test::graph::ToGraphDef(&graph_0, &def_0);

  string handle;
  int64 initial_version;
  TF_ASSERT_OK(CreateSession(def_0, &handle, &initial_version));

  Graph graph_1(OpRegistry::Global());
  Tensor b_tensor(DT_FLOAT, TensorShape({2, 2}));
  test::FillValues<float>(&b_tensor, {1, 0, 0, 1});
  test::graph::Constant(&graph_1, b_tensor, "B");
  GraphDef def_1;
  test::graph::ToGraphDef(&graph_1, &def_1);

  Notification extend_done;
  Notification extend_can_start;

  // Repeatedly fetches "A" until the Extend has finished, then once more.
  auto get_a_fn = [this, handle, &extend_done]() {
    Tensor A(DT_FLOAT, TensorShape({2, 2}));
    while (!extend_done.HasBeenNotified()) {
      TF_ASSERT_OK(RunStep(handle, {}, {{"A:0", &A}}));
    }
    // Run at least once after the Extend has completed.
    TF_ASSERT_OK(RunStep(handle, {}, {{"A:0", &A}}));
  };

  // Fetches "A" and "B" before, during and after the Extend, and releases the
  // Extend once the "before" attempt has been made.
  auto get_a_and_b_fn = [this, handle, &extend_done, &extend_can_start]() {
    Tensor A(DT_FLOAT, TensorShape({2, 2}));
    Tensor B(DT_FLOAT, TensorShape({2, 2}));

    // Run at least once before the Extend has completed.
    EXPECT_TRUE(
        errors::IsNotFound(RunStep(handle, {}, {{"A:0", &A}, {"B:0", &B}})));
    extend_can_start.Notify();

    // Concurrent with the Extend, we will either fail (as above), or
    // succeed (as below).
    while (!extend_done.HasBeenNotified()) {
      Status s = RunStep(handle, {}, {{"A:0", &A}, {"B:0", &B}});
      EXPECT_TRUE(errors::IsNotFound(s) || s.ok());
    }

    // Run at least once after the Extend has completed.
    TF_ASSERT_OK(RunStep(handle, {}, {{"A:0", &A}, {"B:0", &B}}));
  };

  // Performs the Extend once released, then signals completion.
  auto extend_fn = [this, handle, def_1, initial_version, &extend_done,
                    &extend_can_start]() {
    extend_can_start.WaitForNotification();
    int64 version_1;
    // Deliberately non-fatal: a fatal assertion would return from this lambda
    // before extend_done.Notify(), leaving both reader threads spinning on
    // `extend_done` forever and turning a test failure into a hang.
    TF_EXPECT_OK(ExtendSession(handle, def_1, initial_version, &version_1));
    extend_done.Notify();
  };

  // Three threads so that all three closures can run (and wait on each other)
  // at the same time; leaving the scope waits for them to finish.
  {
    thread::ThreadPool thread_pool(Env::Default(), "extend_pool", 3);
    thread_pool.Schedule(get_a_fn);
    thread_pool.Schedule(get_a_and_b_fn);
    thread_pool.Schedule(extend_fn);
  }

  TF_ASSERT_OK(CloseSession(handle));
}
362
// Estimates the largest eigenvalue of A = [3 2; -1 0] by repeatedly running
// y = A * x on the master and re-normalizing on the client side.
TEST_F(MasterTest, EigenProblem) {
  // A = [3 2; -1 0]; x = rand(2, 1);
  // repeat x = normalize(A * x) until x stops changing.
  // We'll try to compute the largest eigenvalue for A.
  Graph graph(OpRegistry::Global());
  Tensor a_tensor(DT_FLOAT, TensorShape({2, 2}));
  // Store rows [3, 2] and [-1, 0] in row major format.
  test::FillValues<float>(&a_tensor, {3, 2, -1, 0});
  Node* a_node = test::graph::Constant(&graph, a_tensor);

  // x is from the feed; the constant's own value is only a placeholder.
  Tensor x_tensor(DT_FLOAT, TensorShape({2, 1}));
  test::FillValues<float>(&x_tensor, {0, 0});
  Node* x_node = test::graph::Constant(&graph, x_tensor);

  // y = A * x
  Node* y_node = test::graph::Matmul(&graph, a_node, x_node, false, false);

  GraphDef def;
  test::graph::ToGraphDef(&graph, &def);

  string handle;
  int64 initial_version;
  // Fail only this test (not the whole test binary) if the session cannot be
  // created, matching the other tests in this file.
  TF_ASSERT_OK(CreateSession(def, &handle, &initial_version));

  // Temps supporting the computation of the convergence condition.
  const Eigen::array<Eigen::DenseIndex, 1> sum_along_dim(0);
  const Eigen::array<Eigen::DenseIndex, 2> matrix_transpose({1, 0});
  Tensor x(DT_FLOAT, TensorShape({2, 1}));
  Tensor y(DT_FLOAT, TensorShape({2, 1}));
  Eigen::Tensor<float, 1, Eigen::RowMajor> y_square_sum;
  Eigen::Tensor<float, 2, Eigen::RowMajor> y_normalized(2, 1);
  y_normalized.setRandom();
  Eigen::Tensor<float, 1, Eigen::RowMajor> error_square_sum;
  float lambda = 0.0f;

  // The computation loop.
  bool converged = false;
  while (!converged) {
    // Run one step of the graph, feeding the current normalized estimate.
    auto x_matrix = x.matrix<float>();
    x_matrix = y_normalized;
    // Stop on an RPC failure: `y` would not hold a fresh result, so
    // continuing the iteration would only produce meaningless numbers.
    TF_ASSERT_OK(
        RunStep(handle, {{x_node->name(), &x}}, {{y_node->name() + ":0", &y}}));
    auto y_matrix = y.matrix<float>();

    // Client code computes the convergence condition.
    {
      lambda = y_matrix(0, 0) / x_matrix(0, 0);
      y_square_sum = y.matrix<float>().square().sum(sum_along_dim);
      const float norm = static_cast<float>(sqrt(y_square_sum(0)));
      y_normalized = y_matrix * (1 / norm);
      error_square_sum = (x_matrix - y_normalized).square().sum(sum_along_dim);
      VLOG(1) << "x = [" << x_matrix.shuffle(matrix_transpose) << "] y = ["
              << y_matrix.shuffle(matrix_transpose) << "] lambda = " << lambda;
      converged = sqrt(error_square_sum(0)) < 1e-10;
    }
  }
  EXPECT_NEAR(lambda, 2.0, 0.01);
  TF_EXPECT_OK(CloseSession(handle));
}
424
425 } // namespace tensorflow
426