1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/core/grappler/clusters/single_machine.h"

#include <memory>

#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/resource_variable_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/framework/cost_graph.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/core/protobuf/queue_runner.pb.h"
32 
33 namespace tensorflow {
34 namespace grappler {
35 namespace {
36 
37 class SingleMachineTest : public ::testing::Test {
38  public:
SetUp()39   void SetUp() override {
40     // Provision a single machine with 3 cpu cores, and a short timeout of 5
41     // seconds: since there isn't much work to process a test graph that should
42     // be plenty.
43 #if TENSORFLOW_USE_ROCM
44     // ROCm takes longer to start up
45     int timeout_s = 10;
46 #else
47     int timeout_s = 5;
48 #endif
49 #ifdef THREAD_SANITIZER
50     timeout_s *= 5;
51 #endif
52     cluster_.reset(
53         new SingleMachine(timeout_s, 3 /* num_cpu_cores */, 0 /* num_gpus */));
54     TF_CHECK_OK(cluster_->EnablePeakMemoryStats());
55     TF_CHECK_OK(cluster_->Provision());
56   }
57 
TearDown()58   void TearDown() override {
59     if (cluster_) {
60       TF_CHECK_OK(cluster_->Shutdown());
61     }
62     cluster_.reset();
63   }
64 
65  protected:
66   std::unique_ptr<SingleMachine> cluster_;
67 };
68 
// Sanity check on the cluster's reported type.
TEST_F(SingleMachineTest, ClusterType) {
  // Use a gtest assertion rather than CHECK_EQ: a mismatch should fail this
  // test with a report instead of aborting the whole test binary.
  EXPECT_EQ("single_machine", cluster_->type());
}
72 
// Runs a trivial multi-stage graph and validates the cost graph returned in
// the run metadata: node count, per-node output shapes, and that each node's
// compute cost is bounded by the measured wall time of the whole run.
TEST_F(SingleMachineTest, CostModel) {
  TrivialTestGraphInputYielder fake_input(4, 1, 10, false,
                                          cluster_->GetDeviceNames());
  GrapplerItem item;
  CHECK(fake_input.NextItem(&item));

  TF_CHECK_OK(cluster_->Initialize(item));

  RunMetadata metadata;
  // Time the run so we can sanity-check the reported per-node compute costs.
  const int64 start_micros = Env::Default()->NowMicros();
  TF_CHECK_OK(cluster_->Run(item.graph, item.feed, item.fetch, &metadata));
  const int64 run_duration_micros = Env::Default()->NowMicros() - start_micros;

  // There should be at least 4 nodes corresponding to the 4 stages we created
  // in the fake input.
  EXPECT_LE(4, metadata.cost_graph().node_size());
  for (const auto& node : metadata.cost_graph().node()) {
    // Skip the special nodes inserted by TF: these are prefixed with an
    // underscore.
    if (node.name()[0] == '_' || node.name().find("/_") != string::npos) {
      continue;
    }
#ifndef INTEL_MKL
    // The output size of MKL op is 2, and cannot filter out the MKL op
    // with the OP name (no op name here), so just disable this check in
    // TF_MKL build.
    EXPECT_EQ(1, node.output_info_size());
#endif  // !INTEL_MKL
    // Each recorded output should occupy at least 8 bytes.
    EXPECT_LE(8, node.output_info(0).size());
    // All user nodes in the fake input produce 10x1 outputs.
    const TensorShapeProto& shape = node.output_info(0).shape();
    EXPECT_EQ(2, shape.dim_size());
    EXPECT_EQ(10, shape.dim(0).size());
    EXPECT_EQ(1, shape.dim(1).size());
    // Compute cost is non-negative and cannot exceed the total run duration.
    EXPECT_LE(0, node.compute_cost());
    EXPECT_GE(run_duration_micros, node.compute_cost());
  }
}
110 
// Smoke test: a trivial graph built with a queue (insert_queue = true) can be
// initialized and run on the cluster without error.
TEST_F(SingleMachineTest, Queue) {
  TrivialTestGraphInputYielder input_yielder(4, 1, 10, true,
                                             cluster_->GetDeviceNames());
  GrapplerItem queue_item;
  CHECK(input_yielder.NextItem(&queue_item));
  TF_CHECK_OK(cluster_->Initialize(queue_item));

  RunMetadata run_metadata;
  TF_CHECK_OK(cluster_->Run(queue_item.graph, queue_item.feed, queue_item.fetch,
                            &run_metadata));
}
121 
// Runs three successive items through the same cluster and checks that
// (a) the cost graph has the expected structure, and (b) two runs of the same
// item produce identical metadata once the timing-dependent fields (compute
// costs and step stats) are scrubbed.
TEST_F(SingleMachineTest, MultipleItems) {
  TrivialTestGraphInputYielder fake_input(4, 1, 10, false,
                                          cluster_->GetDeviceNames());

  for (int i = 0; i < 3; ++i) {
    GrapplerItem item;
    CHECK(fake_input.NextItem(&item));
    TF_CHECK_OK(cluster_->Initialize(item));
    RunMetadata metadata1;
    TF_CHECK_OK(cluster_->Run(item.graph, item.feed, item.fetch, &metadata1));
    RunMetadata metadata2;
    TF_CHECK_OK(cluster_->Run(item.graph, item.feed, item.fetch, &metadata2));

    // There should be at least 4 nodes corresponding to the 4 stages we
    // created in the fake input, plus 1 enqueue and 1 dequeue node.
    EXPECT_LE(6, metadata1.cost_graph().node_size());
    for (const auto& node : metadata1.cost_graph().node()) {
      // Skip TF-internal nodes (underscore-prefixed or containing "/_") and
      // the queue node itself.
      if (node.name()[0] == '_' || node.name().find("/_") != string::npos ||
          node.name() == "queue") {
        continue;
      }
#ifndef INTEL_MKL
      // MKL ops can have 2 outputs and cannot be filtered by name here, so
      // this check is disabled in MKL builds.
      EXPECT_EQ(1, node.output_info_size());
#endif  // !INTEL_MKL
      const TensorShapeProto& shape = node.output_info(0).shape();
      EXPECT_EQ(2, shape.dim_size());
      EXPECT_EQ(10, shape.dim(0).size());
      EXPECT_EQ(1, shape.dim(1).size());
    }

    // Zero out the timing-dependent fields so the two runs can be compared
    // textually. Fixes from the original: the scrub loops used `i`, shadowing
    // the outer loop variable, and clear_step_stats() was redundantly invoked
    // once per cost-graph node instead of once per metadata.
    for (int j = 0; j < metadata1.cost_graph().node_size(); ++j) {
      metadata1.mutable_cost_graph()->mutable_node(j)->set_compute_cost(0);
    }
    metadata1.clear_step_stats();
    for (int j = 0; j < metadata2.cost_graph().node_size(); ++j) {
      metadata2.mutable_cost_graph()->mutable_node(j)->set_compute_cost(0);
    }
    metadata2.clear_step_stats();

    string s1;
    ::tensorflow::protobuf::TextFormat::PrintToString(metadata1, &s1);
    string s2;
    ::tensorflow::protobuf::TextFormat::PrintToString(metadata2, &s2);
    EXPECT_EQ(s1, s2);
  }
}
167 
// Builds a graph whose result is fully precomputable, runs it with the
// optimizer disabled, and verifies that every user-defined node shows up in
// the cost graph. (With optimizations on, constant folding could remove nodes
// and leave only a partial cost model.)
TEST_F(SingleMachineTest, GraphOptimizations) {
  // Create a graph that can be fully precomputed
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  auto zero = ops::Const(root.WithOpName("zero"), 0.0f, {2, 3});
  auto one = ops::Const(root.WithOpName("one"), 1.0f, {2, 3});
  auto add = ops::Add(root.WithOpName("add"), zero, one);
  auto square = ops::Square(root.WithOpName("square"), add);

  // Reshape the 2x3 result to 3x2 (the -1 dimension is inferred).
  auto new_shape = ops::Const(root.WithOpName("new_shape"), {3, -1}, {2});
  auto reshaped = ops::Reshape(root.WithOpName("reshaped"), square, new_shape);
  auto final_shape = ops::Shape(root.WithOpName("final_shape"), reshaped);

  // Assert that the resulting shape matches {3, 2} across all dimensions.
  auto expected_shape =
      ops::Const(root.WithOpName("expected_shape"), {3, 2}, {2});
  auto valid =
      ops::Equal(root.WithOpName("valid"), final_shape, expected_shape);
  auto all_dims = ops::Const(root.WithOpName("all_dims"), {0}, {1});

  auto all_valid = ops::All(root.WithOpName("all_valid"), valid, all_dims);
  auto assert_valid = ops::Assert(root.WithOpName("assert_valid"), all_valid,
                                  {final_shape.output});

  GrapplerItem item;
  TF_CHECK_OK(root.ToGraphDef(&item.graph));
  item.fetch.push_back("assert_valid");

  // Force the placement of all the nodes on CPU since TF attempts to use a GPU
  // when possible even though we created the session to have a single CPU.
  for (auto& node : *item.graph.mutable_node()) {
    node.set_device("/cpu:0");
  }

  // With optimizations turned on, some nodes could have been optimized away,
  // and the cost model could be partial. Restart the cluster with optimizations
  // disabled and make sure we have all the information we're looking for.
  TF_CHECK_OK(cluster_->Shutdown());
  cluster_->DisableOptimizer(true);
  TF_CHECK_OK(cluster_->Provision());

  RunMetadata metadata;
  TF_CHECK_OK(cluster_->Initialize(item));
  TF_CHECK_OK(cluster_->Run(item.graph, item.feed, item.fetch, &metadata));
  std::set<string> cost_nodes;
  for (const auto& node : metadata.cost_graph().node()) {
#ifdef INTEL_MKL
    // Skip the special nodes inserted by TF (and MKL): these are either
    // prefixed with an underscore or contain "/_".
    if (node.name()[0] == '_' || node.name().find("/_") != string::npos) {
      continue;
    }
    cost_nodes.insert(node.name());
#else
    // Skip nodes added by TF internally.
    if (node.name()[0] != '_') {
      cost_nodes.insert(node.name());
    }
#endif
  }
  // Every node we built above must be accounted for in the cost graph.
  const std::set<string> expected_cost_nodes = {
      "zero",      "one",      "add",         "square",
      "new_shape", "reshaped", "final_shape", "expected_shape",
      "valid",     "all_dims", "all_valid",   "assert_valid"};
  EXPECT_EQ(expected_cost_nodes, cost_nodes);
}
232 
// Dequeueing from a queue that is never fed must time out: both run attempts
// should come back with a "deadline exceeded" status.
TEST_F(SingleMachineTest, TimeOuts) {
  // Build a graph that blocks forever waiting on an empty queue.
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  auto queue = ops::FIFOQueue(root.WithOpName("queue"), {DataType::DT_INT32});
  auto dequeue = ops::QueueDequeue(root.WithOpName("dequeue"), queue,
                                   {DataType::DT_INT32});

  GrapplerItem item;
  TF_CHECK_OK(root.ToGraphDef(&item.graph));
  item.fetch.push_back("dequeue");
  TF_CHECK_OK(cluster_->Initialize(item));

  RunMetadata metadata;
  const Status first_run =
      cluster_->Run(item.graph, item.feed, item.fetch, &metadata);
  EXPECT_TRUE(errors::IsDeadlineExceeded(first_run));
  // A second attempt on the same item should fail the same way.
  const Status second_run =
      cluster_->Run(item.graph, item.feed, item.fetch, &metadata);
  EXPECT_TRUE(errors::IsDeadlineExceeded(second_run));
}
252 
// Builds a while(true) TF loop by hand (Enter/Merge/LoopCond/Switch/Identity/
// NextIteration/Exit), runs it on a freshly provisioned cluster, and calls
// _exit(0) only if the run fails with DEADLINE_EXCEEDED and the subsequent
// Shutdown reports UNAVAILABLE. Intended to be executed in a child process
// via EXPECT_EXIT, since the loop never terminates on its own.
static void RunInfiniteTFLoop() {
  // Create a while(true) loop
  GrapplerItem item;

  // Const "shape" = [1]: the shape argument for RandomUniform below.
  NodeDef* shp = item.graph.add_node();
  shp->set_name("shape");
  shp->set_op("Const");
  (*shp->mutable_attr())["dtype"].set_type(DT_INT32);
  Tensor shp_tensor(DT_INT32, TensorShape({1}));
  shp_tensor.flat<int32>()(0) = 1;
  shp_tensor.AsProtoTensorContent(
      (*shp->mutable_attr())["value"].mutable_tensor());

  // The loop-carried value: a random float vector of shape [1].
  NodeDef* r = item.graph.add_node();
  r->set_name("random");
  r->set_op("RandomUniform");
  (*r->mutable_attr())["dtype"].set_type(DT_FLOAT);
  (*r->mutable_attr())["T"].set_type(DT_INT32);
  *r->add_input() = "shape";

  // Enter moves the value into the "while/while/" execution frame.
  NodeDef* e = item.graph.add_node();
  e->set_name("while/Enter");
  e->set_op("Enter");
  (*e->mutable_attr())["T"].set_type(DT_FLOAT);
  (*e->mutable_attr())["frame_name"].set_s("while/while/");
  *e->add_input() = "random";

  // Merge joins the initial value with the back edge from NextIteration.
  NodeDef* m = item.graph.add_node();
  m->set_name("while/Merge");
  m->set_op("Merge");
  (*m->mutable_attr())["T"].set_type(DT_FLOAT);
  (*m->mutable_attr())["N"].set_i(2);
  *m->add_input() = "while/Enter";
  *m->add_input() = "while/NextIteration";

  // The loop condition is a constant `true`, so the loop never exits. The
  // control dependency on Merge keeps the const inside the loop frame.
  NodeDef* t = item.graph.add_node();
  t->set_name("always_true");
  t->set_op("Const");
  (*t->mutable_attr())["dtype"].set_type(DT_BOOL);
  *t->add_input() = "^while/Merge";
  Tensor true_tensor(DT_BOOL, TensorShape());
  true_tensor.flat<bool>()(0) = true;
  true_tensor.AsProtoTensorContent(
      (*t->mutable_attr())["value"].mutable_tensor());

  NodeDef* c = item.graph.add_node();
  c->set_name("while/LoopCond");
  c->set_op("LoopCond");
  *c->add_input() = "always_true";

  // Switch routes the loop value: output 1 (condition true) continues the
  // loop, output 0 feeds Exit.
  NodeDef* s = item.graph.add_node();
  s->set_name("while/Switch");
  (*s->mutable_attr())["T"].set_type(DT_FLOAT);
  s->set_op("Switch");
  *s->add_input() = "while/Merge";
  *s->add_input() = "while/LoopCond";

  NodeDef* i = item.graph.add_node();
  i->set_name("while/Identity");
  i->set_op("Identity");
  (*i->mutable_attr())["T"].set_type(DT_FLOAT);
  *i->add_input() = "while/Switch:1";

  // NextIteration closes the back edge into Merge.
  NodeDef* n = item.graph.add_node();
  n->set_name("while/NextIteration");
  n->set_op("NextIteration");
  (*n->mutable_attr())["T"].set_type(DT_FLOAT);
  *n->add_input() = "while/Identity";

  // Exit is fetched below but is never reached, since the condition is
  // always true.
  NodeDef* x = item.graph.add_node();
  x->set_name("while/Exit");
  x->set_op("Exit");
  (*x->mutable_attr())["T"].set_type(DT_FLOAT);
  *x->add_input() = "while/Switch";

  item.fetch.push_back("while/Exit");

  // Create our own cluster to run it
  SingleMachine cluster(5, 3, 0);
  TF_CHECK_OK(cluster.Provision());
  TF_CHECK_OK(cluster.Initialize(item));

  // The run must hit the 5s cluster timeout rather than complete.
  Status s1 = cluster.Run(item.graph, item.feed, item.fetch, nullptr);
  if (!errors::IsDeadlineExceeded(s1)) {
    LOG(ERROR) << "Expected 'deadline exceeded' error, got " << s1;
    // Exit to break the infinite loop
    _exit(1);
  }

  // Attempt to shutdown the cluster and make sure we get the proper error code.
  Status s2 = cluster.Shutdown();
  if (!errors::IsUnavailable(s2)) {
    LOG(ERROR) << "Expected 'unavailable' error, got " << s2;
    // Exit to break the infinite loop
    _exit(2);
  }

  // There isn't much we can do at this point. Exit with the error code 0 to
  // indicate everything went according to plan.
  _exit(0);
}
354 
// Runs the hand-built infinite while loop in a child process (via EXPECT_EXIT)
// and expects exit code 0, meaning RunInfiniteTFLoop saw both the expected
// timeout error and the expected shutdown error.
TEST_F(SingleMachineTest, InfiniteLoops) {
#if !(TENSORFLOW_USE_ROCM)  // fails with ROCm (investigate)
  // The RunInfiniteTFLoop function creates its own cluster.
  TF_CHECK_OK(cluster_->Shutdown());
  EXPECT_EXIT(RunInfiniteTFLoop(), ::testing::ExitedWithCode(0), ".*");
#endif
}
362 
// Verifies that running a variable's initialization op leaves a trace of that
// op in the cost model.
TEST_F(SingleMachineTest, InitializationMemory) {
  // Build a variable together with its initialization graph.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  int batch_size = 10;
  Output x = ops::RandomNormal(scope.WithOpName("x"), {batch_size, 1},
                               DataType::DT_FLOAT);
  Output v = ops::Variable(scope.WithOpName("v"), TensorShape({batch_size, 1}),
                           DataType::DT_FLOAT);
  Output init = ops::Assign(scope.WithOpName("init"), v, x);

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  item.init_ops.push_back(init.name());
  item.fetch.push_back(v.name());

  TF_CHECK_OK(cluster_->Initialize(item));
  RunMetadata metadata;
  TF_CHECK_OK(cluster_->Run(item.graph, item.feed, item.fetch, &metadata));

  // The initialization op must show up in the cost model.
  bool init_op_found = false;
  for (const auto& node : metadata.cost_graph().node()) {
    if (node.name() == NodeName(init.name())) {
      init_op_found = true;
      break;
    }
  }
  EXPECT_TRUE(init_op_found);
}
389 
390 namespace {
391 
392 template <class T>
SetNodeAttr(const string & key,const T & value,NodeDef * node)393 inline void SetNodeAttr(const string& key, const T& value, NodeDef* node) {
394   AttrValue attr_value;
395   SetAttrValue(value, &attr_value);
396   auto* attr_map = node->mutable_attr();
397   (*attr_map)[key] = attr_value;
398 }
399 template <>
SetNodeAttr(const string & key,const Tensor & tensor,NodeDef * node)400 inline void SetNodeAttr(const string& key, const Tensor& tensor,
401                         NodeDef* node) {
402   TensorProto tensor_proto;
403   tensor.AsProtoTensorContent(&tensor_proto);
404   SetNodeAttr(key, tensor_proto, node);
405 }
406 
407 }  // namespace
408 
// Builds a HashTable plus its InitializeTable op by hand, runs a lookup, and
// checks how persistent memory is attributed in the cost graph: the table's
// storage is charged to the initialize_table op rather than to the hash_table
// op itself.
TEST_F(SingleMachineTest, PersistentMemory) {
  // Build a hashtable and its initialization graph.
  GrapplerItem item;
  const DataType key_dtype = DT_INT64;
  const DataType data_dtype = DT_INT64;

  NodeDef* hashtable_node = item.graph.add_node();
  hashtable_node->set_op("HashTable");
  hashtable_node->set_name("hash_table");
  SetNodeAttr("key_dtype", key_dtype, hashtable_node);
  SetNodeAttr("value_dtype", data_dtype, hashtable_node);

  // Initial hashtable keys and values
  NodeDef* keys_node = item.graph.add_node();
  keys_node->set_op("Const");
  keys_node->set_name("table_keys");
  SetNodeAttr("dtype", key_dtype, keys_node);
  Tensor keys(key_dtype, TensorShape{2});
  keys.vec<int64>()(0) = 123;
  keys.vec<int64>()(1) = 321;
  SetNodeAttr("value", keys, keys_node);

  NodeDef* values_node = item.graph.add_node();
  values_node->set_op("Const");
  values_node->set_name("table_values");
  SetNodeAttr("dtype", data_dtype, values_node);
  Tensor values(data_dtype, TensorShape{2});
  values.vec<int64>()(0) = 789;
  values.vec<int64>()(1) = 987;
  SetNodeAttr("value", values, values_node);

  // InitializeTable node. Running it as an init op is what allocates the
  // table's persistent storage.
  NodeDef* init_table_node = item.graph.add_node();
  init_table_node->set_op("InitializeTable");
  init_table_node->set_name("initialize_table");
  SetNodeAttr("Tkey", key_dtype, init_table_node);
  SetNodeAttr("Tval", data_dtype, init_table_node);
  *init_table_node->add_input() = "hash_table";
  *init_table_node->add_input() = "table_keys";
  *init_table_node->add_input() = "table_values";
  item.init_ops.push_back(init_table_node->name());

  // Key to lookup. 0 is not one of the table keys, so the lookup falls back
  // to the default value below.
  NodeDef* query_node = item.graph.add_node();
  query_node->set_op("Const");
  query_node->set_name("query");
  SetNodeAttr("dtype", key_dtype, query_node);
  Tensor query(key_dtype, TensorShape({}));
  query.flat<int64>()(0) = 0;
  SetNodeAttr("value", query, query_node);

  // Default return value of hashtable lookup
  NodeDef* default_value_node = item.graph.add_node();
  default_value_node->set_op("Const");
  default_value_node->set_name("default_table_value");
  SetNodeAttr("dtype", data_dtype, default_value_node);
  Tensor dflt(data_dtype, TensorShape({}));
  dflt.flat<int64>()(0) = 456;
  SetNodeAttr("value", dflt, default_value_node);

  // HashTable lookup node
  NodeDef* lookup_node = item.graph.add_node();
  lookup_node->set_op("LookupTableFind");
  lookup_node->set_name("table_lookup");
  SetNodeAttr("Tin", key_dtype, lookup_node);
  SetNodeAttr("Tout", data_dtype, lookup_node);
  *lookup_node->add_input() = "hash_table";
  *lookup_node->add_input() = "query";
  *lookup_node->add_input() = "default_table_value";
  item.fetch.push_back(lookup_node->name());

  // Run the graph
  TF_CHECK_OK(cluster_->Initialize(item));
  RunMetadata metadata;
  TF_CHECK_OK(cluster_->Run(item.graph, item.feed, item.fetch, &metadata));

  // Check the cost model.
  bool found_table_init = false;
  bool found_hashtable = false;
  for (const auto& node : metadata.cost_graph().node()) {
    if (node.name() == "hash_table") {
      found_hashtable = true;
      // Persistent memory usage should be 0 since it's recorded as part of the
      // initialize_table op.
      EXPECT_EQ(0, node.persistent_memory_size());
    } else if (node.name() == "initialize_table") {
      found_table_init = true;
      // Persistent memory should hold 2 keys and 2 values.
      EXPECT_LE(4 * sizeof(int64), node.persistent_memory_size());
    }
  }
  EXPECT_TRUE(found_table_init);
  EXPECT_TRUE(found_hashtable);
}
503 
// Builds a GrapplerItem whose graph holds several kinds of resource-backed
// memory: a ref variable (128x256 floats), a resource variable (256x512
// floats), a string FIFO queue driven by a queue runner, and an
// IdentityReader. Init ops assign both variables; the item fetches the
// matmul result and a dequeued element.
GrapplerItem CreateGrapplerItemWithResourceMemory() {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  // Add a variable and initializer.
  Output a = ops::Variable(s.WithOpName("a"), TensorShape({128, 256}),
                           DataType::DT_FLOAT);
  Output a_init =
      ops::RandomNormal(s.WithOpName("a/init"), {128, 256}, DataType::DT_FLOAT);
  Output a_init_assign = ops::Assign(s.WithOpName("a/init/assign"), a, a_init);

  // Add a resource variable.
  Output b =
      ops::VarHandleOp(s.WithOpName("b"), DataType::DT_FLOAT, {256, 512});
  Output b_read =
      ops::ReadVariableOp(s.WithOpName("b/read"), b, DataType::DT_FLOAT);
  Output b_init =
      ops::RandomNormal(s.WithOpName("b/init"), {256, 512}, DataType::DT_FLOAT);
  auto b_init_assign =
      ops::AssignVariableOp(s.WithOpName("b/init/assign"), b, b_init);

  // Add a queue.
  ops::FIFOQueue queue(s.WithOpName("queue"), {DataType::DT_STRING});
  Output some_string =
      ops::Const(s.WithOpName("some_string"), string("nothing"));
  ops::QueueEnqueue enqueue(s.WithOpName("enqueue"), queue, {some_string});
  ops::QueueDequeue dequeue(s.WithOpName("dequeue"), queue,
                            {DataType::DT_STRING});

  // Add a IdentityReader.
  ops::IdentityReader reader(s.WithOpName("identity_reader"));
  ops::ReaderRead read(s.WithOpName("read_from_queue"), reader, queue);

  // Consume both variables so they are live during the run.
  Output var_mul = ops::MatMul(s.WithOpName("var_matmul"), a, b_read);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  // The queue runner keeps the queue fed via the enqueue op.
  QueueRunnerDef queue_runner;
  queue_runner.set_queue_name("queue");
  *queue_runner.add_enqueue_op_name() = "enqueue";
  item.queue_runners.push_back(queue_runner);

  item.init_ops.push_back("a/init/assign");
  item.init_ops.push_back("b/init/assign");
  item.fetch.push_back("var_matmul");
  item.fetch.push_back("dequeue");

  return item;
}
553 
554 #if defined(PLATFORM_GOOGLE)
// After shutting down and re-provisioning the cluster, memory held by
// resources (variables, queues, readers) created by a previous run must have
// been released.
TEST_F(SingleMachineTest, ReleaseMemoryAfterDestruction) {
  GrapplerItem item = CreateGrapplerItemWithResourceMemory();
  TF_CHECK_OK(cluster_->Initialize(item));

  // Before running anything, only a small amount of bookkeeping memory should
  // be in use.
  std::unordered_map<string, uint64> peak_before_run;
  TF_CHECK_OK(cluster_->GetPeakMemoryUsage(&peak_before_run));
  EXPECT_EQ(peak_before_run.size(), 1);
  EXPECT_LT(peak_before_run.begin()->second, 400);

  RunMetadata metadata;
  TF_CHECK_OK(cluster_->Run(item.graph, item.feed, item.fetch, &metadata));

  // After the run, resource memory is still held.
  std::unordered_map<string, uint64> peak_after_run;
  TF_CHECK_OK(cluster_->GetPeakMemoryUsage(&peak_after_run));
  EXPECT_EQ(peak_after_run.size(), 1);
  EXPECT_GT(peak_after_run.begin()->second, 0);

  // Re-provisioning the cluster releases all of it.
  TF_CHECK_OK(cluster_->Shutdown());
  TF_CHECK_OK(cluster_->Provision());
  std::unordered_map<string, uint64> peak_after_reprovision;
  TF_CHECK_OK(cluster_->GetPeakMemoryUsage(&peak_after_reprovision));
  TF_CHECK_OK(cluster_->Shutdown());

  // Memory used by the resources is back near the pre-run baseline.
  EXPECT_EQ(peak_before_run.size(), 1);
  EXPECT_EQ(peak_after_reprovision.size(), 1);
  EXPECT_LT(peak_before_run.begin()->second, 400);
  EXPECT_LT(peak_after_reprovision.begin()->second, 400);
}
587 
// Checks that peak-memory accounting reports nonzero usage for the CPU device
// after running a graph with resources, and near-zero usage after the cluster
// has been re-provisioned.
TEST_F(SingleMachineTest, PeakMemory) {
  // Single copy of the device name instead of four inline string literals.
  static constexpr char kCpuDevice[] =
      "/job:localhost/replica:0/task:0/device:CPU:0";

  GrapplerItem item = CreateGrapplerItemWithResourceMemory();
  TF_CHECK_OK(cluster_->Initialize(item));

  RunMetadata metadata;
  TF_CHECK_OK(cluster_->Run(item.graph, item.feed, item.fetch, &metadata));

  std::unordered_map<string, uint64> device_peak_memory;
  TF_CHECK_OK(cluster_->GetPeakMemoryUsage(&device_peak_memory));
  // Look the device up once and reuse the iterator: operator[] would
  // default-insert a zero entry on a miss and mask a reporting bug.
  auto it = device_peak_memory.find(kCpuDevice);
  ASSERT_NE(it, device_peak_memory.end());
  EXPECT_GT(it->second, 0);

  // After re-provisioning, the accounting is reset and only a small amount of
  // bookkeeping memory should be reported.
  TF_CHECK_OK(cluster_->Shutdown());
  TF_CHECK_OK(cluster_->Provision());
  device_peak_memory.clear();
  TF_CHECK_OK(cluster_->GetPeakMemoryUsage(&device_peak_memory));
  TF_CHECK_OK(cluster_->Shutdown());
  it = device_peak_memory.find(kCpuDevice);
  ASSERT_NE(it, device_peak_memory.end());
  EXPECT_LT(it->second, 200);
}
616 
// GetPeakMemoryUsage must fail with INVALID_ARGUMENT when peak memory stats
// were never enabled on the cluster.
TEST_F(SingleMachineTest, PeakMemoryStatsNotEnabled) {
  GrapplerItem item = CreateGrapplerItemWithResourceMemory();

  // Replace the fixture's cluster (which enables the stats in SetUp) with a
  // fresh one that does not enable them.
  TF_CHECK_OK(cluster_->Shutdown());
  cluster_.reset();
  SingleMachine stats_free_cluster(60 /* timeout_s */, 3 /* num_cpu_cores */,
                                   0 /* num_gpus */);
  TF_CHECK_OK(stats_free_cluster.Provision());
  TF_CHECK_OK(stats_free_cluster.Initialize(item));

  std::unordered_map<string, uint64> device_peak_memory;
  const Status status =
      stats_free_cluster.GetPeakMemoryUsage(&device_peak_memory);
  TF_CHECK_OK(stats_free_cluster.Shutdown());
  ASSERT_FALSE(status.ok());
  EXPECT_EQ(status.code(), errors::Code::INVALID_ARGUMENT);
}
634 #endif
635 
636 }  // namespace
637 }  // namespace grappler
638 }  // namespace tensorflow
639