1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/distributed_runtime/master.h"
17
18 #include <map>
19 #include <memory>
20
21 #include "grpcpp/grpcpp.h"
22
23 #include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
24 #include "tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h"
25 #include "tensorflow/core/distributed_runtime/rpc/grpc_testlib.h"
26 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
27 #include "tensorflow/core/framework/allocator.h"
28 #include "tensorflow/core/framework/tensor.h"
29 #include "tensorflow/core/framework/tensor_testutil.h"
30 #include "tensorflow/core/graph/testlib.h"
31 #include "tensorflow/core/lib/core/errors.h"
32 #include "tensorflow/core/lib/core/notification.h"
33 #include "tensorflow/core/lib/core/status_test_util.h"
34 #include "tensorflow/core/lib/core/threadpool.h"
35 #include "tensorflow/core/lib/gtl/map_util.h"
36 #include "tensorflow/core/platform/logging.h"
37 #include "tensorflow/core/platform/mutex.h"
38 #include "tensorflow/core/platform/test.h"
39 #include "tensorflow/core/platform/types.h"
40 #include "tensorflow/core/protobuf/master.pb.h"
41
42 namespace tensorflow {
43
44 class MasterTest : public ::testing::Test {
45 protected:
MasterTest()46 MasterTest() {
47 std::vector<string> targets;
48 SessionOptions options;
49 (*options.config.mutable_device_count())["CPU"] = 1;
50 (*options.config.mutable_device_count())["GPU"] = 0;
51 TF_CHECK_OK(test::TestCluster::MakeTestCluster(options, 2, &cluster_));
52 SharedGrpcChannelPtr channel_ptr;
53 TF_CHECK_OK(NewHostPortGrpcChannel(cluster_->targets()[0], &channel_ptr));
54 master_ = grpc::MasterService::NewStub(channel_ptr);
55 }
56
57 std::unique_ptr<test::TestCluster> cluster_;
58 std::unique_ptr<grpc::MasterService::Stub> master_;
59
60 // Helpers for MasterService.{CreateSession,RunStep,CloseSession}
61 // rpc calls.
62
CreateSession(const GraphDef & def,string * handle,int64 * initial_version)63 Status CreateSession(const GraphDef& def, string* handle,
64 int64* initial_version) {
65 ::grpc::ClientContext ctx;
66 CreateSessionRequest req;
67 *(req.mutable_graph_def()) = def;
68 // Invokes placement frequently.
69 req.mutable_config()->set_placement_period(1);
70 CreateSessionResponse resp;
71 const Status s = FromGrpcStatus(master_->CreateSession(&ctx, req, &resp));
72 if (s.ok()) {
73 *handle = resp.session_handle();
74 *initial_version = resp.graph_version();
75 }
76 return s;
77 }
78
ExtendSession(const string & handle,const GraphDef & def,int64 current_version,int64 * new_version)79 Status ExtendSession(const string& handle, const GraphDef& def,
80 int64 current_version, int64* new_version) {
81 ::grpc::ClientContext ctx;
82 ExtendSessionRequest req;
83 req.set_session_handle(handle);
84 *(req.mutable_graph_def()) = def;
85 req.set_current_graph_version(current_version);
86 ExtendSessionResponse resp;
87 const Status s = FromGrpcStatus(master_->ExtendSession(&ctx, req, &resp));
88 if (s.ok()) {
89 *new_version = resp.new_graph_version();
90 }
91 return s;
92 }
93
RunStep(const string & handle,const std::vector<std::pair<string,const Tensor * >> & feed,const std::map<string,Tensor * > & fetch)94 Status RunStep(const string& handle,
95 const std::vector<std::pair<string, const Tensor*> >& feed,
96 const std::map<string, Tensor*>& fetch) {
97 ::grpc::ClientContext ctx;
98 RunStepRequest req;
99 req.set_session_handle(handle);
100 for (const auto& p : feed) {
101 const string& feed_name = p.first;
102 const Tensor* feed_tensor = p.second;
103 auto f = req.add_feed();
104 f->set_name(feed_name);
105 feed_tensor->AsProtoTensorContent(f->mutable_tensor());
106 }
107 for (const auto& p : fetch) {
108 const string& fetch_name = p.first;
109 req.add_fetch(fetch_name);
110 }
111 RunStepResponse resp;
112 const Status s = FromGrpcStatus(master_->RunStep(&ctx, req, &resp));
113 if (s.ok()) {
114 for (const auto& fetch_resp : resp.tensor()) {
115 auto it = fetch.find(fetch_resp.name());
116 CHECK(it != fetch.end());
117 CHECK(it->second->FromProto(fetch_resp.tensor()));
118 }
119 }
120 return s;
121 }
122
CloseSession(const string & handle)123 Status CloseSession(const string& handle) {
124 ::grpc::ClientContext ctx;
125 CloseSessionRequest req;
126 req.set_session_handle(handle);
127 CloseSessionResponse resp;
128 return FromGrpcStatus(master_->CloseSession(&ctx, req, &resp));
129 }
130
Reset()131 Status Reset() {
132 ::grpc::ClientContext ctx;
133 ResetRequest req;
134 ResetResponse resp;
135 return FromGrpcStatus(master_->Reset(&ctx, req, &resp));
136 }
137 };
138
// A session can be created and then closed; closing a handle the master does
// not know about is reported as Aborted.
TEST_F(MasterTest, CreateClose) {
  GraphDef def;  // Empty.
  string handle;
  int64 initial_version;
  TF_ASSERT_OK(CreateSession(def, &handle, &initial_version));
  EXPECT_TRUE(errors::IsAborted(CloseSession("randombits")));
  // TF_EXPECT_OK reports the returned Status when the expectation fails,
  // unlike EXPECT_TRUE(status.ok()).
  TF_EXPECT_OK(CloseSession(handle));
}
147
// ListDevices reports exactly one local device, and it is a CPU.
TEST_F(MasterTest, ListDevices) {
  ::grpc::ClientContext ctx;
  ListDevicesRequest req;
  ListDevicesResponse resp;
  const Status s = FromGrpcStatus(master_->ListDevices(&ctx, req, &resp));
  // Both checks are fatal: indexing local_device(0) below is only valid when
  // the call succeeded and returned at least one local device.
  TF_ASSERT_OK(s);
  ASSERT_EQ(1, resp.local_device_size());
  EXPECT_EQ("CPU", resp.local_device(0).device_type());
}
157
// After Reset, closing previously created sessions is reported as Aborted.
TEST_F(MasterTest, Reset) {
  GraphDef def;  // Empty.
  string s1, s2;
  int64 initial_version1, initial_version2;
  TF_ASSERT_OK(CreateSession(def, &s1, &initial_version1));
  TF_ASSERT_OK(CreateSession(def, &s2, &initial_version2));
  // TF_EXPECT_OK reports the returned Status when the expectation fails,
  // unlike EXPECT_TRUE(status.ok()).
  TF_EXPECT_OK(Reset());
  EXPECT_TRUE(errors::IsAborted(CloseSession(s1)));
  EXPECT_TRUE(errors::IsAborted(CloseSession(s2)));
}
168
// Extends an initially empty session twice, adding one constant each time,
// and checks that the graph version grows and that the constants can be
// fetched back. Extending an unknown handle must be reported as Aborted.
TEST_F(MasterTest, Extend) {
  GraphDef empty_def;
  string handle;
  int64 version_0;
  TF_ASSERT_OK(CreateSession(empty_def, &handle, &version_0));

  // Reference values for the two constants added below.
  Tensor want_a(DT_FLOAT, TensorShape({2, 2}));
  test::FillValues<float>(&want_a, {3.0, 2.0, -1.0, 0.0});

  Tensor want_x(DT_FLOAT, TensorShape({2, 1}));
  test::FillValues<float>(&want_x, {2.0, 2.0});

  const std::vector<std::pair<string, const Tensor*> > no_feeds;

  // First extension: add constant "A" and read it back.
  Graph graph_a(OpRegistry::Global());
  test::graph::Constant(&graph_a, want_a, "A");
  GraphDef def_a;
  test::graph::ToGraphDef(&graph_a, &def_a);
  int64 version_a;
  TF_ASSERT_OK(ExtendSession(handle, def_a, version_0, &version_a));
  EXPECT_GT(version_a, version_0);
  Tensor got_a(DT_FLOAT, TensorShape({2, 2}));
  const std::map<string, Tensor*> fetch_a = {{"A:0", &got_a}};
  TF_ASSERT_OK(RunStep(handle, no_feeds, fetch_a));
  test::ExpectTensorEqual<float>(got_a, want_a);

  // Second extension: add constant "x". A bogus handle is rejected first.
  Graph graph_x(OpRegistry::Global());
  test::graph::Constant(&graph_x, want_x, "x");
  GraphDef def_x;
  test::graph::ToGraphDef(&graph_x, &def_x);
  int64 version_x;
  const Status bogus_extend =
      ExtendSession("randombits", def_x, version_a, &version_x);
  EXPECT_TRUE(errors::IsAborted(bogus_extend));
  TF_ASSERT_OK(ExtendSession(handle, def_x, version_a, &version_x));
  EXPECT_GT(version_x, version_a);

  // Both constants are now fetchable in a single step.
  Tensor got_x(DT_FLOAT, TensorShape({2, 1}));
  const std::map<string, Tensor*> fetch_a_and_x = {{"A:0", &got_a},
                                                   {"x:0", &got_x}};
  TF_ASSERT_OK(RunStep(handle, no_feeds, fetch_a_and_x));
  test::ExpectTensorEqual<float>(got_a, want_a);
  test::ExpectTensorEqual<float>(got_x, want_x);

  TF_ASSERT_OK(CloseSession(handle));
}
209
// Extending a session a second time with the same graph (which contains a
// variable node) at the up-to-date version is rejected as InvalidArgument.
TEST_F(MasterTest, ExtendUpdateStatefulFails) {
  GraphDef empty_def;
  string handle;
  int64 base_version;
  TF_ASSERT_OK(CreateSession(empty_def, &handle, &base_version));

  // A graph holding a single float variable of shape {512}.
  Graph var_graph(OpRegistry::Global());
  test::graph::Var(&var_graph, DT_FLOAT, TensorShape({512}));
  GraphDef var_def;
  test::graph::ToGraphDef(&var_graph, &var_def);

  int64 after_first;
  int64 after_second;
  TF_ASSERT_OK(ExtendSession(handle, var_def, base_version, &after_first));
  EXPECT_GT(after_first, base_version);

  // Same definition again, this time against the current version.
  const Status repeat =
      ExtendSession(handle, var_def, after_first, &after_second);
  EXPECT_TRUE(errors::IsInvalidArgument(repeat));

  TF_ASSERT_OK(CloseSession(handle));
}
228
// Extending a session against a graph version that has already been
// superseded is reported as Aborted.
TEST_F(MasterTest, ExtendTwiceFails) {
  GraphDef empty_def;
  string handle;
  int64 base_version;
  TF_ASSERT_OK(CreateSession(empty_def, &handle, &base_version));

  // A graph holding a single float variable of shape {512}.
  Graph var_graph(OpRegistry::Global());
  test::graph::Var(&var_graph, DT_FLOAT, TensorShape({512}));
  GraphDef var_def;
  test::graph::ToGraphDef(&var_graph, &var_def);

  int64 next_version;
  TF_ASSERT_OK(ExtendSession(handle, var_def, base_version, &next_version));
  EXPECT_GT(next_version, base_version);

  // Second attempt still claims the original (now stale) version.
  const Status stale =
      ExtendSession(handle, var_def, base_version, &next_version);
  EXPECT_TRUE(errors::IsAborted(stale));

  TF_ASSERT_OK(CloseSession(handle));
}
247
// Issues many ExtendSession calls against the same starting version at once
// and checks that exactly one is accepted while every other one is Aborted.
TEST_F(MasterTest, ConcurrentExtendOnlyOneSucceeds) {
  GraphDef empty_def;
  string handle;
  int64 base_version;
  TF_ASSERT_OK(CreateSession(empty_def, &handle, &base_version));

  Graph var_graph(OpRegistry::Global());
  test::graph::Var(&var_graph, DT_FLOAT, TensorShape({512}));
  GraphDef var_def;
  test::graph::ToGraphDef(&var_graph, &var_def);

  Notification start;
  mutex tally_mu;
  int num_ok = 0;
  int num_failed = 0;
  auto try_extend = [this, handle, var_def, base_version, &start, &tally_mu,
                     &num_ok, &num_failed]() {
    // Hold every closure back until all of them have been scheduled.
    start.WaitForNotification();
    int64 ignored_version;
    const Status result =
        ExtendSession(handle, var_def, base_version, &ignored_version);
    const bool accepted = result.ok();
    EXPECT_TRUE(accepted || errors::IsAborted(result));
    mutex_lock lock(tally_mu);
    int& tally = accepted ? num_ok : num_failed;
    ++tally;
  };

  // Run 100 concurrent Extend calls and expect only one to succeed.
  const int kNumCalls = 100;
  {
    thread::ThreadPool pool(Env::Default(), "extend_pool", kNumCalls);
    int scheduled = 0;
    while (scheduled < kNumCalls) {
      pool.Schedule(try_extend);
      ++scheduled;
    }
    start.Notify();
    // Leaving this scope destroys the pool before the tallies are read.
  }

  EXPECT_EQ(num_failed, kNumCalls - 1);
  EXPECT_EQ(num_ok, 1);
  TF_ASSERT_OK(CloseSession(handle));
}
292
// Runs steps concurrently with an ExtendSession call that adds constant "B"
// to a session that starts with only constant "A".
//   * Fetching "A" must succeed before, during and after the extension.
//   * Fetching "A" and "B" together must fail with NotFound before the
//     extension, may fail or succeed while it is in flight, and must succeed
//     once it has completed.
TEST_F(MasterTest, ConcurrentExtendAndRun) {
  Graph graph_0(OpRegistry::Global());
  Tensor a_tensor(DT_FLOAT, TensorShape({2, 2}));
  test::FillValues<float>(&a_tensor, {3, 2, -1, 0});
  test::graph::Constant(&graph_0, a_tensor, "A");
  GraphDef def_0;
  test::graph::ToGraphDef(&graph_0, &def_0);

  string handle;
  int64 initial_version;
  TF_ASSERT_OK(CreateSession(def_0, &handle, &initial_version));

  Graph graph_1(OpRegistry::Global());
  Tensor b_tensor(DT_FLOAT, TensorShape({2, 2}));
  test::FillValues<float>(&b_tensor, {1, 0, 0, 1});
  test::graph::Constant(&graph_1, b_tensor, "B");
  GraphDef def_1;
  test::graph::ToGraphDef(&graph_1, &def_1);

  Notification extend_done;
  Notification extend_can_start;

  // Keeps fetching "A" until the extension has finished, then once more.
  auto get_a_fn = [this, handle, &extend_done]() {
    Tensor A(DT_FLOAT, TensorShape({2, 2}));
    while (!extend_done.HasBeenNotified()) {
      TF_ASSERT_OK(RunStep(handle, {}, {{"A:0", &A}}));
    }
    // Run at least once after the Extend has completed.
    TF_ASSERT_OK(RunStep(handle, {}, {{"A:0", &A}}));
  };

  auto get_a_and_b_fn = [this, handle, &extend_done, &extend_can_start]() {
    Tensor A(DT_FLOAT, TensorShape({2, 2}));
    Tensor B(DT_FLOAT, TensorShape({2, 2}));

    // Run at least once before the Extend has completed.
    EXPECT_TRUE(
        errors::IsNotFound(RunStep(handle, {}, {{"A:0", &A}, {"B:0", &B}})));
    extend_can_start.Notify();

    // Concurrent with the Extend, we will either fail (as above), or
    // succeed (as below).
    while (!extend_done.HasBeenNotified()) {
      Status s = RunStep(handle, {}, {{"A:0", &A}, {"B:0", &B}});
      EXPECT_TRUE(errors::IsNotFound(s) || s.ok());
    }

    // Run at least once after the Extend has completed.
    TF_ASSERT_OK(RunStep(handle, {}, {{"A:0", &A}, {"B:0", &B}}));
  };

  auto extend_fn = [this, handle, def_1, initial_version, &extend_done,
                    &extend_can_start]() {
    extend_can_start.WaitForNotification();
    int64 version_1;
    // Deliberately non-fatal: a fatal assertion would return from this
    // closure before extend_done is notified, and the two polling closures
    // above would then spin forever instead of letting the test fail.
    TF_EXPECT_OK(ExtendSession(handle, def_1, initial_version, &version_1));
    extend_done.Notify();
  };

  {
    // One thread per closure so that all three make progress concurrently.
    thread::ThreadPool thread_pool(Env::Default(), "extend_pool", 3);
    thread_pool.Schedule(get_a_fn);
    thread_pool.Schedule(get_a_and_b_fn);
    thread_pool.Schedule(extend_fn);
  }

  TF_ASSERT_OK(CloseSession(handle));
}
361
// Estimates the largest eigenvalue of A = [3 2; -1 0] by repeatedly running
// y = A * x through the master and feeding the normalized y back in as the
// next x, then checks that the estimate is close to 2.
TEST_F(MasterTest, EigenProblem) {
  // A = [3 2; -1 0]; x = rand(2, 1);
  // for i=1:100; x = A * x; end
  // We'll try to compute the largest eigenvalue for A.
  Graph graph(OpRegistry::Global());
  Tensor a_tensor(DT_FLOAT, TensorShape({2, 2}));
  // Store rows [3, 2] and [-1, 0] in row major format.
  test::FillValues<float>(&a_tensor, {3, 2, -1, 0});
  Node* a_node = test::graph::Constant(&graph, a_tensor);

  // x is from the feed: the zero-valued constant created here is overridden
  // by the value fed under its name on every step.
  Tensor x_tensor(DT_FLOAT, TensorShape({2, 1}));
  test::FillValues<float>(&x_tensor, {0, 0});
  Node* x_node = test::graph::Constant(&graph, x_tensor);

  // y = A * x
  Node* y_node = test::graph::Matmul(&graph, a_node, x_node, false, false);

  GraphDef def;
  test::graph::ToGraphDef(&graph, &def);

  string handle;
  int64 initial_version;
  // Fail only this test (not the whole test binary) if the session cannot be
  // created, as the other tests in this file do.
  TF_ASSERT_OK(CreateSession(def, &handle, &initial_version));

  // Temps supporting the computation of the convergence condition.
  const Eigen::array<Eigen::DenseIndex, 1> sum_along_dim(0);
  const Eigen::array<Eigen::DenseIndex, 2> matrix_transpose({1, 0});
  Tensor x(DT_FLOAT, TensorShape({2, 1}));
  Tensor y(DT_FLOAT, TensorShape({2, 1}));
  Eigen::Tensor<float, 1, Eigen::RowMajor> y_square_sum;
  Eigen::Tensor<float, 2, Eigen::RowMajor> y_normalized(2, 1);
  y_normalized.setRandom();
  Eigen::Tensor<float, 1, Eigen::RowMajor> error_square_sum;
  float lambda = 0.0f;

  // The computation loop.
  bool converged = false;
  while (!converged) {
    // Run one step of the graph.
    auto x_matrix = x.matrix<float>();
    x_matrix = y_normalized;
    // Fatal on failure: after a failed step `y` holds no fresh result, and
    // iterating on it could keep this unbounded loop from ever converging.
    TF_ASSERT_OK(
        RunStep(handle, {{x_node->name(), &x}}, {{y_node->name() + ":0", &y}}));
    auto y_matrix = y.matrix<float>();

    // Client code computes the convergence condition.
    {
      lambda = y_matrix(0, 0) / x_matrix(0, 0);
      y_square_sum = y.matrix<float>().square().sum(sum_along_dim);
      const float norm = static_cast<float>(sqrt(y_square_sum(0)));
      y_normalized = y_matrix * (1 / norm);
      error_square_sum = (x_matrix - y_normalized).square().sum(sum_along_dim);
      VLOG(1) << "x = [" << x_matrix.shuffle(matrix_transpose) << "] y = ["
              << y_matrix.shuffle(matrix_transpose) << "] lambda = " << lambda;
      converged = sqrt(error_square_sum(0)) < 1e-10;
    }
  }
  EXPECT_NEAR(lambda, 2.0, 0.01);
  TF_EXPECT_OK(CloseSession(handle));
}
423
424 } // namespace tensorflow
425