• Home
  • History
  • Annotate
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #include "tensorflow/core/distributed_runtime/master.h"
17 
18 #include <map>
19 #include <memory>
20 
21 #include "grpcpp/grpcpp.h"
22 
23 #include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
24 #include "tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h"
25 #include "tensorflow/core/distributed_runtime/rpc/grpc_testlib.h"
26 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
27 #include "tensorflow/core/framework/allocator.h"
28 #include "tensorflow/core/framework/tensor.h"
29 #include "tensorflow/core/framework/tensor_testutil.h"
30 #include "tensorflow/core/graph/testlib.h"
31 #include "tensorflow/core/lib/core/errors.h"
32 #include "tensorflow/core/lib/core/notification.h"
33 #include "tensorflow/core/lib/core/status_test_util.h"
34 #include "tensorflow/core/lib/core/threadpool.h"
35 #include "tensorflow/core/lib/gtl/map_util.h"
36 #include "tensorflow/core/platform/logging.h"
37 #include "tensorflow/core/platform/mutex.h"
38 #include "tensorflow/core/platform/test.h"
39 #include "tensorflow/core/platform/types.h"
40 #include "tensorflow/core/protobuf/master.pb.h"
41 
42 namespace tensorflow {
43 
44 class MasterTest : public ::testing::Test {
45  protected:
MasterTest()46   MasterTest() {
47     std::vector<string> targets;
48     SessionOptions options;
49     (*options.config.mutable_device_count())["CPU"] = 1;
50     (*options.config.mutable_device_count())["GPU"] = 0;
51     TF_CHECK_OK(test::TestCluster::MakeTestCluster(options, 2, &cluster_));
52     SharedGrpcChannelPtr channel_ptr;
53     TF_CHECK_OK(NewHostPortGrpcChannel(
54         cluster_->targets()[0], &options.config.rpc_options(), &channel_ptr));
55     master_ = grpc::MasterService::NewStub(channel_ptr);
56   }
57 
58   std::unique_ptr<test::TestCluster> cluster_;
59   std::unique_ptr<grpc::MasterService::Stub> master_;
60 
61   // Helpers for MasterService.{CreateSession,RunStep,CloseSession}
62   // rpc calls.
63 
CreateSession(const GraphDef & def,string * handle,int64 * initial_version)64   Status CreateSession(const GraphDef& def, string* handle,
65                        int64* initial_version) {
66     ::grpc::ClientContext ctx;
67     CreateSessionRequest req;
68     *(req.mutable_graph_def()) = def;
69     // Invokes placement frequently.
70     req.mutable_config()->set_placement_period(1);
71     CreateSessionResponse resp;
72     const Status s = FromGrpcStatus(master_->CreateSession(&ctx, req, &resp));
73     if (s.ok()) {
74       *handle = resp.session_handle();
75       *initial_version = resp.graph_version();
76     }
77     return s;
78   }
79 
ExtendSession(const string & handle,const GraphDef & def,int64 current_version,int64 * new_version)80   Status ExtendSession(const string& handle, const GraphDef& def,
81                        int64 current_version, int64* new_version) {
82     ::grpc::ClientContext ctx;
83     ExtendSessionRequest req;
84     req.set_session_handle(handle);
85     *(req.mutable_graph_def()) = def;
86     req.set_current_graph_version(current_version);
87     ExtendSessionResponse resp;
88     const Status s = FromGrpcStatus(master_->ExtendSession(&ctx, req, &resp));
89     if (s.ok()) {
90       *new_version = resp.new_graph_version();
91     }
92     return s;
93   }
94 
RunStep(const string & handle,const std::vector<std::pair<string,const Tensor * >> & feed,const std::map<string,Tensor * > & fetch)95   Status RunStep(const string& handle,
96                  const std::vector<std::pair<string, const Tensor*> >& feed,
97                  const std::map<string, Tensor*>& fetch) {
98     ::grpc::ClientContext ctx;
99     RunStepRequest req;
100     req.set_session_handle(handle);
101     for (const auto& p : feed) {
102       const string& feed_name = p.first;
103       const Tensor* feed_tensor = p.second;
104       auto f = req.add_feed();
105       f->set_name(feed_name);
106       feed_tensor->AsProtoTensorContent(f->mutable_tensor());
107     }
108     for (const auto& p : fetch) {
109       const string& fetch_name = p.first;
110       req.add_fetch(fetch_name);
111     }
112     RunStepResponse resp;
113     const Status s = FromGrpcStatus(master_->RunStep(&ctx, req, &resp));
114     if (s.ok()) {
115       for (const auto& fetch_resp : resp.tensor()) {
116         auto it = fetch.find(fetch_resp.name());
117         CHECK(it != fetch.end());
118         CHECK(it->second->FromProto(fetch_resp.tensor()));
119       }
120     }
121     return s;
122   }
123 
CloseSession(const string & handle)124   Status CloseSession(const string& handle) {
125     ::grpc::ClientContext ctx;
126     CloseSessionRequest req;
127     req.set_session_handle(handle);
128     CloseSessionResponse resp;
129     return FromGrpcStatus(master_->CloseSession(&ctx, req, &resp));
130   }
131 
Reset()132   Status Reset() {
133     ::grpc::ClientContext ctx;
134     ResetRequest req;
135     ResetResponse resp;
136     return FromGrpcStatus(master_->Reset(&ctx, req, &resp));
137   }
138 };
139 
// Creating a session from an empty graph succeeds. Closing an unknown handle
// is reported as Aborted, and closing the real handle succeeds.
TEST_F(MasterTest, CreateClose) {
  GraphDef def;  // Empty.
  string handle;
  int64 initial_version;
  TF_ASSERT_OK(CreateSession(def, &handle, &initial_version));
  EXPECT_TRUE(errors::IsAborted(CloseSession("randombits")));
  // TF_EXPECT_OK (rather than EXPECT_TRUE(...ok())) prints the failing status.
  TF_EXPECT_OK(CloseSession(handle));
}
148 
// ListDevices on the master reports exactly one local device, of type "CPU".
TEST_F(MasterTest, ListDevices) {
  ::grpc::ClientContext ctx;
  ListDevicesRequest req;
  ListDevicesResponse resp;
  const Status s = FromGrpcStatus(master_->ListDevices(&ctx, req, &resp));
  // Fatal: the checks below read `resp`, which is meaningless if the RPC
  // failed.
  TF_ASSERT_OK(s);
  // Fatal: local_device(0) below must not be evaluated on an empty repeated
  // field (out-of-range access).
  ASSERT_EQ(1, resp.local_device_size());
  EXPECT_EQ("CPU", resp.local_device(0).device_type());
}
158 
// After Reset, previously created sessions are gone: closing either of them
// is reported as Aborted.
TEST_F(MasterTest, Reset) {
  GraphDef def;  // Empty.
  string s1, s2;
  int64 initial_version1, initial_version2;
  TF_ASSERT_OK(CreateSession(def, &s1, &initial_version1));
  TF_ASSERT_OK(CreateSession(def, &s2, &initial_version2));
  // TF_EXPECT_OK (rather than EXPECT_TRUE(...ok())) prints the failing status.
  TF_EXPECT_OK(Reset());
  EXPECT_TRUE(errors::IsAborted(CloseSession(s1)));
  EXPECT_TRUE(errors::IsAborted(CloseSession(s2)));
}
169 
// Grows an initially empty session twice. Each successful ExtendSession must
// return a larger graph version and make the newly added constant fetchable.
// Extending an unknown handle is rejected as Aborted.
TEST_F(MasterTest, Extend) {
  GraphDef empty_graph;  // No nodes.
  string session_handle;
  int64 version_0;
  TF_ASSERT_OK(CreateSession(empty_graph, &session_handle, &version_0));

  Tensor want_a(DT_FLOAT, TensorShape({2, 2}));
  test::FillValues<float>(&want_a, {3.0, 2.0, -1.0, 0.0});
  Tensor want_x(DT_FLOAT, TensorShape({2, 1}));
  test::FillValues<float>(&want_x, {2.0, 2.0});

  // Step 1: add constant "A" and read it back.
  GraphDef a_def;
  {
    Graph g(OpRegistry::Global());
    test::graph::Constant(&g, want_a, "A");
    test::graph::ToGraphDef(&g, &a_def);
  }
  int64 version_1;
  TF_ASSERT_OK(ExtendSession(session_handle, a_def, version_0, &version_1));
  EXPECT_GT(version_1, version_0);
  Tensor got_a(DT_FLOAT, TensorShape({2, 2}));
  TF_ASSERT_OK(RunStep(session_handle, {}, {{"A:0", &got_a}}));
  test::ExpectTensorEqual<float>(got_a, want_a);

  // Step 2: add constant "x"; a bogus handle must be refused first.
  GraphDef x_def;
  {
    Graph g(OpRegistry::Global());
    test::graph::Constant(&g, want_x, "x");
    test::graph::ToGraphDef(&g, &x_def);
  }
  int64 version_2;
  EXPECT_TRUE(errors::IsAborted(
      ExtendSession("randombits", x_def, version_1, &version_2)));
  TF_ASSERT_OK(ExtendSession(session_handle, x_def, version_1, &version_2));
  EXPECT_GT(version_2, version_1);

  // Both constants are now reachable from a single step.
  Tensor got_x(DT_FLOAT, TensorShape({2, 1}));
  TF_ASSERT_OK(
      RunStep(session_handle, {}, {{"A:0", &got_a}, {"x:0", &got_x}}));
  test::ExpectTensorEqual<float>(got_a, want_a);
  test::ExpectTensorEqual<float>(got_x, want_x);

  TF_ASSERT_OK(CloseSession(session_handle));
}
210 
// Extending a session with a graph containing a Var node succeeds once.
// Sending that same graph again, even against the current version, is
// expected to fail with InvalidArgument.
TEST_F(MasterTest, ExtendUpdateStatefulFails) {
  string session_handle;
  int64 version_0;
  {
    GraphDef empty_graph;  // No nodes.
    TF_ASSERT_OK(CreateSession(empty_graph, &session_handle, &version_0));
  }

  GraphDef var_def;
  {
    Graph g(OpRegistry::Global());
    test::graph::Var(&g, DT_FLOAT, TensorShape({512}));
    test::graph::ToGraphDef(&g, &var_def);
  }

  int64 version_1;
  TF_ASSERT_OK(ExtendSession(session_handle, var_def, version_0, &version_1));
  EXPECT_GT(version_1, version_0);

  // Re-send the identical nodes against the up-to-date version.
  int64 version_2;
  const Status repeated =
      ExtendSession(session_handle, var_def, version_1, &version_2);
  EXPECT_TRUE(errors::IsInvalidArgument(repeated));

  TF_ASSERT_OK(CloseSession(session_handle));
}
229 
// After a successful Extend, a second Extend that still claims the original
// graph version is stale and must be rejected as Aborted.
TEST_F(MasterTest, ExtendTwiceFails) {
  string session_handle;
  int64 version_0;
  {
    GraphDef empty_graph;  // No nodes.
    TF_ASSERT_OK(CreateSession(empty_graph, &session_handle, &version_0));
  }

  GraphDef var_def;
  {
    Graph g(OpRegistry::Global());
    test::graph::Var(&g, DT_FLOAT, TensorShape({512}));
    test::graph::ToGraphDef(&g, &var_def);
  }

  int64 version_1;
  TF_ASSERT_OK(ExtendSession(session_handle, var_def, version_0, &version_1));
  EXPECT_GT(version_1, version_0);

  // version_0 is no longer current, so this replay must be refused.
  const Status stale =
      ExtendSession(session_handle, var_def, version_0, &version_1);
  EXPECT_TRUE(errors::IsAborted(stale));

  TF_ASSERT_OK(CloseSession(session_handle));
}
248 
// Releases many ExtendSession calls at once, all claiming the same initial
// graph version. Exactly one may succeed; every other call must be rejected
// as Aborted.
TEST_F(MasterTest, ConcurrentExtendOnlyOneSucceeds) {
  // Number of concurrent Extend calls (and worker threads).
  constexpr int kNumExtends = 100;

  GraphDef def_0;  // Empty.
  string handle;
  int64 initial_version;
  TF_ASSERT_OK(CreateSession(def_0, &handle, &initial_version));

  Graph graph_1(OpRegistry::Global());
  test::graph::Var(&graph_1, DT_FLOAT, TensorShape({512}));
  GraphDef def_1;
  test::graph::ToGraphDef(&graph_1, &def_1);

  Notification n;
  mutex mu;
  int succeeded = 0;
  int failed = 0;
  // `handle` and `def_1` are captured by reference. The thread pool below is
  // destroyed (and its closures finished) before they go out of scope, so
  // this is safe. It also avoids copying the GraphDef into every scheduled
  // closure.
  auto extend_fn = [this, &handle, &def_1, initial_version, &n, &mu,
                    &succeeded, &failed]() {
    n.WaitForNotification();
    int64 new_version;
    const Status s =
        ExtendSession(handle, def_1, initial_version, &new_version);
    EXPECT_TRUE(s.ok() || errors::IsAborted(s));
    {
      mutex_lock l(mu);
      if (s.ok()) {
        ++succeeded;
      } else {
        ++failed;
      }
    }
  };

  // Run kNumExtends concurrent Extend calls and expect only one to succeed.
  {
    thread::ThreadPool thread_pool(Env::Default(), "extend_pool", kNumExtends);
    for (int i = 0; i < kNumExtends; ++i) {
      thread_pool.Schedule(extend_fn);
    }
    // Release all workers together to maximize contention.
    n.Notify();
  }  // The pool's destructor is relied on to finish every closure before
     // the counters are read below.

  EXPECT_EQ(kNumExtends - 1, failed);
  EXPECT_EQ(1, succeeded);
  TF_ASSERT_OK(CloseSession(handle));
}
293 
TEST_F(MasterTest,ConcurrentExtendAndRun)294 TEST_F(MasterTest, ConcurrentExtendAndRun) {
295   Graph graph_0(OpRegistry::Global());
296   Tensor a_tensor(DT_FLOAT, TensorShape({2, 2}));
297   test::FillValues<float>(&a_tensor, {3, 2, -1, 0});
298   test::graph::Constant(&graph_0, a_tensor, "A");
299   GraphDef def_0;
300   test::graph::ToGraphDef(&graph_0, &def_0);
301 
302   string handle;
303   int64 initial_version;
304   TF_ASSERT_OK(CreateSession(def_0, &handle, &initial_version));
305 
306   Graph graph_1(OpRegistry::Global());
307   Tensor b_tensor(DT_FLOAT, TensorShape({2, 2}));
308   test::FillValues<float>(&b_tensor, {1, 0, 0, 1});
309   test::graph::Constant(&graph_1, b_tensor, "B");
310   GraphDef def_1;
311   test::graph::ToGraphDef(&graph_1, &def_1);
312 
313   Notification extend_done;
314   Notification extend_can_start;
315 
316   auto get_a_fn = [this, handle, &extend_done]() {
317     Tensor A(DT_FLOAT, TensorShape({2, 2}));
318     while (!extend_done.HasBeenNotified()) {
319       TF_ASSERT_OK(RunStep(handle, {}, {{"A:0", &A}}));
320     }
321     // Run at least once after the Extend has completed.
322     TF_ASSERT_OK(RunStep(handle, {}, {{"A:0", &A}}));
323   };
324 
325   auto get_a_and_b_fn = [this, handle, &extend_done, &extend_can_start]() {
326     Tensor A(DT_FLOAT, TensorShape({2, 2}));
327     Tensor B(DT_FLOAT, TensorShape({2, 2}));
328 
329     // Run at least once before the Extend has completed.
330     EXPECT_TRUE(
331         errors::IsNotFound(RunStep(handle, {}, {{"A:0", &A}, {"B:0", &B}})));
332     extend_can_start.Notify();
333 
334     // Concurrent with the Extend, we will either fail (as above), or
335     // succeed (as below).
336     while (!extend_done.HasBeenNotified()) {
337       Status s = RunStep(handle, {}, {{"A:0", &A}, {"B:0", &B}});
338       EXPECT_TRUE(errors::IsNotFound(s) || s.ok());
339     }
340 
341     // Run at least once after the Extend has completed.
342     TF_ASSERT_OK(RunStep(handle, {}, {{"A:0", &A}, {"B:0", &B}}));
343   };
344 
345   auto extend_fn = [this, handle, def_1, initial_version, &extend_done,
346                     &extend_can_start]() {
347     extend_can_start.WaitForNotification();
348     int64 version_1;
349     TF_ASSERT_OK(ExtendSession(handle, def_1, initial_version, &version_1));
350     extend_done.Notify();
351   };
352 
353   {
354     thread::ThreadPool thread_pool(Env::Default(), "extend_pool", 3);
355     thread_pool.Schedule(get_a_fn);
356     thread_pool.Schedule(get_a_and_b_fn);
357     thread_pool.Schedule(extend_fn);
358   }
359 
360   TF_ASSERT_OK(CloseSession(handle));
361 }
362 
TEST_F(MasterTest,EigenProblem)363 TEST_F(MasterTest, EigenProblem) {
364   // A = [3 2; -1 0]; x = rand(2, 1);
365   // for i=1:100; x = A * x; end
366   // We'll try to compute the largest eigenvalue for A.
367   Graph graph(OpRegistry::Global());
368   Tensor a_tensor(DT_FLOAT, TensorShape({2, 2}));
369   // Store rows [3, 2] and [-1, 0] in row major format.
370   test::FillValues<float>(&a_tensor, {3, 2, -1, 0});
371   Node* a_node = test::graph::Constant(&graph, a_tensor);
372 
373   // x is from the feed.
374   Tensor x_tensor(DT_FLOAT, TensorShape({2, 1}));
375   test::FillValues<float>(&x_tensor, {0, 0});
376   Node* x_node = test::graph::Constant(&graph, x_tensor);
377 
378   // y = A * x
379   Node* y_node = test::graph::Matmul(&graph, a_node, x_node, false, false);
380 
381   GraphDef def;
382   test::graph::ToGraphDef(&graph, &def);
383 
384   string handle;
385   int64 initial_version;
386   TF_CHECK_OK(CreateSession(def, &handle, &initial_version));
387 
388   // Temps supporting the computation of the convergence condition.
389   const Eigen::array<Eigen::DenseIndex, 1> sum_along_dim(0);
390   const Eigen::array<Eigen::DenseIndex, 2> matrix_transpose({1, 0});
391   Tensor x(DT_FLOAT, TensorShape({2, 1}));
392   Tensor y(DT_FLOAT, TensorShape({2, 1}));
393   Eigen::Tensor<float, 1, Eigen::RowMajor> y_square_sum;
394   Eigen::Tensor<float, 2, Eigen::RowMajor> y_normalized(2, 1);
395   y_normalized.setRandom();
396   Eigen::Tensor<float, 1, Eigen::RowMajor> error_square_sum;
397   float lambda;
398 
399   // The computation loop.
400   bool converged = false;
401   while (!converged) {
402     // Run one step of the graph.
403     auto x_matrix = x.matrix<float>();
404     x_matrix = y_normalized;
405     TF_EXPECT_OK(
406         RunStep(handle, {{x_node->name(), &x}}, {{y_node->name() + ":0", &y}}));
407     auto y_matrix = y.matrix<float>();
408 
409     // Client code computes the convergence condition.
410     {
411       lambda = y_matrix(0, 0) / x_matrix(0, 0);
412       y_square_sum = y.matrix<float>().square().sum(sum_along_dim);
413       const float norm = static_cast<float>(sqrt(y_square_sum(0)));
414       y_normalized = y_matrix * (1 / norm);
415       error_square_sum = (x_matrix - y_normalized).square().sum(sum_along_dim);
416       VLOG(1) << "x = [" << x_matrix.shuffle(matrix_transpose) << "] y = ["
417               << y_matrix.shuffle(matrix_transpose) << "] lambda = " << lambda;
418       converged = sqrt(error_square_sum(0)) < 1e-10;
419     }
420   }
421   EXPECT_NEAR(lambda, 2.0, 0.01);
422   TF_EXPECT_OK(CloseSession(handle));
423 }
424 
425 }  // namespace tensorflow
426