1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/clusters/single_machine.h"
17
18 #include "tensorflow/cc/framework/scope.h"
19 #include "tensorflow/cc/ops/resource_variable_ops.h"
20 #include "tensorflow/cc/ops/standard_ops.h"
21 #include "tensorflow/core/common_runtime/device.h"
22 #include "tensorflow/core/common_runtime/device_factory.h"
23 #include "tensorflow/core/framework/cost_graph.pb.h"
24 #include "tensorflow/core/framework/tensor_shape.pb.h"
25 #include "tensorflow/core/grappler/grappler_item.h"
26 #include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
27 #include "tensorflow/core/grappler/utils.h"
28 #include "tensorflow/core/platform/protobuf.h"
29 #include "tensorflow/core/platform/test.h"
30 #include "tensorflow/core/protobuf/error_codes.pb.h"
31 #include "tensorflow/core/protobuf/queue_runner.pb.h"
32
33 namespace tensorflow {
34 namespace grappler {
35 namespace {
36
// Test fixture that provisions an in-process SingleMachine cluster with
// 3 CPU cores and no GPUs, with peak memory stats collection enabled.
// The cluster is torn down after each test.
class SingleMachineTest : public ::testing::Test {
 public:
  void SetUp() override {
    // Provision a single machine with 3 cpu cores, and a short timeout of 5
    // seconds: since there isn't much work to process a test graph that should
    // be plenty.
#if TENSORFLOW_USE_ROCM
    // ROCm takes longer to start up
    int timeout_s = 10;
#else
    int timeout_s = 5;
#endif
#ifdef THREAD_SANITIZER
    // Thread sanitizer instrumentation slows execution down considerably;
    // scale the timeout to compensate.
    timeout_s *= 5;
#endif
    cluster_.reset(
        new SingleMachine(timeout_s, 3 /* num_cpu_cores */, 0 /* num_gpus */));
    TF_CHECK_OK(cluster_->EnablePeakMemoryStats());
    TF_CHECK_OK(cluster_->Provision());
  }

  void TearDown() override {
    // Some tests shut down or reset cluster_ themselves, so guard the
    // shutdown call.
    if (cluster_) {
      TF_CHECK_OK(cluster_->Shutdown());
    }
    cluster_.reset();
  }

 protected:
  // The cluster under test, owned by the fixture.
  std::unique_ptr<SingleMachine> cluster_;
};
68
TEST_F(SingleMachineTest, ClusterType) {
  // Use EXPECT_EQ rather than CHECK_EQ: a CHECK failure aborts the whole test
  // binary, whereas EXPECT_EQ reports the mismatch as a regular test failure
  // and lets the remaining tests run.
  EXPECT_EQ("single_machine", cluster_->type());
}
72
// Verifies that running a trivial graph populates the cost model with
// plausible per-node shapes, sizes, and compute costs.
TEST_F(SingleMachineTest, CostModel) {
  TrivialTestGraphInputYielder fake_input(4, 1, 10, false,
                                          cluster_->GetDeviceNames());
  GrapplerItem item;
  CHECK(fake_input.NextItem(&item));

  TF_CHECK_OK(cluster_->Initialize(item));

  RunMetadata metadata;
  // Time the run so that per-node compute costs can be bounded by the
  // overall wall time below.
  const int64 start_micros = Env::Default()->NowMicros();
  TF_CHECK_OK(cluster_->Run(item.graph, item.feed, item.fetch, &metadata));
  const int64 run_duration_micros = Env::Default()->NowMicros() - start_micros;

  // There should be at least 4 nodes corresponding to the 4 stages we created
  // in the fake input.
  EXPECT_LE(4, metadata.cost_graph().node_size());
  for (const auto& node : metadata.cost_graph().node()) {
    // Skip the special nodes inserted by TF: these are prefixed with an
    // underscore.
    if (node.name()[0] == '_' || node.name().find("/_") != string::npos) {
      continue;
    }
#ifndef INTEL_MKL
    // The output size of MKL op is 2, and cannot filter out the MKL op
    // with the OP name (no op name here), so just disable this check in
    // TF_MKL build.
    EXPECT_EQ(1, node.output_info_size());
#endif  // !INTEL_MKL
    // Each remaining node produces a 10x1 tensor, which must occupy at least
    // 8 bytes.
    EXPECT_LE(8, node.output_info(0).size());
    const TensorShapeProto& shape = node.output_info(0).shape();
    EXPECT_EQ(2, shape.dim_size());
    EXPECT_EQ(10, shape.dim(0).size());
    EXPECT_EQ(1, shape.dim(1).size());
    // Per-node cost is non-negative and can't exceed the total run duration.
    EXPECT_LE(0, node.compute_cost());
    EXPECT_GE(run_duration_micros, node.compute_cost());
  }
}
110
// Checks that a graph whose input flows through a queue runs to completion.
TEST_F(SingleMachineTest, Queue) {
  // Build a trivial 4-stage test graph with queue-based input enabled.
  TrivialTestGraphInputYielder input_yielder(4, 1, 10, true,
                                             cluster_->GetDeviceNames());
  GrapplerItem test_item;
  CHECK(input_yielder.NextItem(&test_item));

  TF_CHECK_OK(cluster_->Initialize(test_item));
  RunMetadata run_metadata;
  TF_CHECK_OK(cluster_->Run(test_item.graph, test_item.feed, test_item.fetch,
                            &run_metadata));
}
121
// Runs three successive items on the same cluster and checks that two runs of
// the same item produce identical metadata once non-deterministic fields
// (compute costs, step stats) are zeroed out.
TEST_F(SingleMachineTest, MultipleItems) {
  TrivialTestGraphInputYielder fake_input(4, 1, 10, false,
                                          cluster_->GetDeviceNames());

  for (int i = 0; i < 3; ++i) {
    GrapplerItem item;
    CHECK(fake_input.NextItem(&item));
    TF_CHECK_OK(cluster_->Initialize(item));
    RunMetadata metadata1;
    TF_CHECK_OK(cluster_->Run(item.graph, item.feed, item.fetch, &metadata1));
    RunMetadata metadata2;
    TF_CHECK_OK(cluster_->Run(item.graph, item.feed, item.fetch, &metadata2));

    // There should be at least 4 nodes corresponding to the 4 stages we
    // created in the fake input, plus 1 enqueue and 1 dequeue node.
    EXPECT_LE(6, metadata1.cost_graph().node_size());
    for (const auto& node : metadata1.cost_graph().node()) {
      // Skip the special nodes inserted by TF as well as the queue node.
      if (node.name()[0] == '_' || node.name().find("/_") != string::npos ||
          node.name() == "queue") {
        continue;
      }
#ifndef INTEL_MKL
      EXPECT_EQ(1, node.output_info_size());
#endif  // !INTEL_MKL
      const TensorShapeProto& shape = node.output_info(0).shape();
      EXPECT_EQ(2, shape.dim_size());
      EXPECT_EQ(10, shape.dim(0).size());
      EXPECT_EQ(1, shape.dim(1).size());
    }

    // Zero out the non-deterministic parts of the metadata before comparing
    // the two runs for equality. Note: the inner index is deliberately not
    // named `i` (it previously shadowed the outer loop variable), and
    // clear_step_stats() is called once per metadata rather than once per
    // node as before.
    for (int node_idx = 0; node_idx < metadata1.cost_graph().node_size();
         ++node_idx) {
      metadata1.mutable_cost_graph()->mutable_node(node_idx)->set_compute_cost(
          0);
    }
    metadata1.clear_step_stats();
    for (int node_idx = 0; node_idx < metadata2.cost_graph().node_size();
         ++node_idx) {
      metadata2.mutable_cost_graph()->mutable_node(node_idx)->set_compute_cost(
          0);
    }
    metadata2.clear_step_stats();

    // Two runs of the same deterministic graph must yield identical metadata.
    string s1;
    ::tensorflow::protobuf::TextFormat::PrintToString(metadata1, &s1);
    string s2;
    ::tensorflow::protobuf::TextFormat::PrintToString(metadata2, &s2);
    EXPECT_EQ(s1, s2);
  }
}
167
// Verifies that with the optimizer disabled, every node of a fully
// constant-foldable graph still shows up in the cost model.
TEST_F(SingleMachineTest, GraphOptimizations) {
  // Create a graph that can be fully precomputed
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  auto zero = ops::Const(root.WithOpName("zero"), 0.0f, {2, 3});
  auto one = ops::Const(root.WithOpName("one"), 1.0f, {2, 3});
  auto add = ops::Add(root.WithOpName("add"), zero, one);
  auto square = ops::Square(root.WithOpName("square"), add);

  // Reshape the {2, 3} result into {3, 2} (the -1 dimension is inferred) and
  // assert at runtime that the final shape matches.
  auto new_shape = ops::Const(root.WithOpName("new_shape"), {3, -1}, {2});
  auto reshaped = ops::Reshape(root.WithOpName("reshaped"), square, new_shape);
  auto final_shape = ops::Shape(root.WithOpName("final_shape"), reshaped);

  auto expected_shape =
      ops::Const(root.WithOpName("expected_shape"), {3, 2}, {2});
  auto valid =
      ops::Equal(root.WithOpName("valid"), final_shape, expected_shape);
  auto all_dims = ops::Const(root.WithOpName("all_dims"), {0}, {1});

  auto all_valid = ops::All(root.WithOpName("all_valid"), valid, all_dims);
  auto assert_valid = ops::Assert(root.WithOpName("assert_valid"), all_valid,
                                  {final_shape.output});

  GrapplerItem item;
  TF_CHECK_OK(root.ToGraphDef(&item.graph));
  item.fetch.push_back("assert_valid");

  // Force the placement of all the nodes on CPU since TF attempts to use a GPU
  // when possible even though we created the session to have a single CPU.
  for (auto& node : *item.graph.mutable_node()) {
    node.set_device("/cpu:0");
  }

  // With optimizations turned on, some nodes could have been optimized away,
  // and the cost model could be partial. Restart the cluster with optimizations
  // disabled and make sure we have all the information we're looking for.
  TF_CHECK_OK(cluster_->Shutdown());
  cluster_->DisableOptimizer(true);
  TF_CHECK_OK(cluster_->Provision());

  RunMetadata metadata;
  TF_CHECK_OK(cluster_->Initialize(item));
  TF_CHECK_OK(cluster_->Run(item.graph, item.feed, item.fetch, &metadata));
  std::set<string> cost_nodes;
  for (const auto& node : metadata.cost_graph().node()) {
#ifdef INTEL_MKL
    // Skip the special nodes inserted by TF (and MKL): these are either
    // prefixed with an underscore or contain "/_".
    if (node.name()[0] == '_' || node.name().find("/_") != string::npos) {
      continue;
    }
    cost_nodes.insert(node.name());
#else
    // Skip nodes added by TF internally.
    if (node.name()[0] != '_') {
      cost_nodes.insert(node.name());
    }
#endif
  }
  // Since optimizations are off, every node built above must be present.
  const std::set<string> expected_cost_nodes = {
      "zero",      "one",      "add",         "square",
      "new_shape", "reshaped", "final_shape", "expected_shape",
      "valid",     "all_dims", "all_valid",   "assert_valid"};
  EXPECT_EQ(expected_cost_nodes, cost_nodes);
}
232
// Checks that a run which can never complete fails with DEADLINE_EXCEEDED,
// and keeps failing the same way on a retry.
TEST_F(SingleMachineTest, TimeOuts) {
  // Build a graph that blocks forever: it dequeues from a queue into which
  // nothing is ever enqueued.
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  auto fifo = ops::FIFOQueue(root.WithOpName("queue"), {DataType::DT_INT32});
  auto dequeue_op =
      ops::QueueDequeue(root.WithOpName("dequeue"), fifo, {DataType::DT_INT32});

  GrapplerItem item;
  TF_CHECK_OK(root.ToGraphDef(&item.graph));
  item.fetch.push_back("dequeue");

  TF_CHECK_OK(cluster_->Initialize(item));
  RunMetadata metadata;
  Status first_run =
      cluster_->Run(item.graph, item.feed, item.fetch, &metadata);
  EXPECT_TRUE(errors::IsDeadlineExceeded(first_run));
  Status second_run =
      cluster_->Run(item.graph, item.feed, item.fetch, &metadata);
  EXPECT_TRUE(errors::IsDeadlineExceeded(second_run));
}
252
RunInfiniteTFLoop()253 static void RunInfiniteTFLoop() {
254 // Create a while(true) loop
255 GrapplerItem item;
256
257 NodeDef* shp = item.graph.add_node();
258 shp->set_name("shape");
259 shp->set_op("Const");
260 (*shp->mutable_attr())["dtype"].set_type(DT_INT32);
261 Tensor shp_tensor(DT_INT32, TensorShape({1}));
262 shp_tensor.flat<int32>()(0) = 1;
263 shp_tensor.AsProtoTensorContent(
264 (*shp->mutable_attr())["value"].mutable_tensor());
265
266 NodeDef* r = item.graph.add_node();
267 r->set_name("random");
268 r->set_op("RandomUniform");
269 (*r->mutable_attr())["dtype"].set_type(DT_FLOAT);
270 (*r->mutable_attr())["T"].set_type(DT_INT32);
271 *r->add_input() = "shape";
272
273 NodeDef* e = item.graph.add_node();
274 e->set_name("while/Enter");
275 e->set_op("Enter");
276 (*e->mutable_attr())["T"].set_type(DT_FLOAT);
277 (*e->mutable_attr())["frame_name"].set_s("while/while/");
278 *e->add_input() = "random";
279
280 NodeDef* m = item.graph.add_node();
281 m->set_name("while/Merge");
282 m->set_op("Merge");
283 (*m->mutable_attr())["T"].set_type(DT_FLOAT);
284 (*m->mutable_attr())["N"].set_i(2);
285 *m->add_input() = "while/Enter";
286 *m->add_input() = "while/NextIteration";
287
288 NodeDef* t = item.graph.add_node();
289 t->set_name("always_true");
290 t->set_op("Const");
291 (*t->mutable_attr())["dtype"].set_type(DT_BOOL);
292 *t->add_input() = "^while/Merge";
293 Tensor true_tensor(DT_BOOL, TensorShape());
294 true_tensor.flat<bool>()(0) = true;
295 true_tensor.AsProtoTensorContent(
296 (*t->mutable_attr())["value"].mutable_tensor());
297
298 NodeDef* c = item.graph.add_node();
299 c->set_name("while/LoopCond");
300 c->set_op("LoopCond");
301 *c->add_input() = "always_true";
302
303 NodeDef* s = item.graph.add_node();
304 s->set_name("while/Switch");
305 (*s->mutable_attr())["T"].set_type(DT_FLOAT);
306 s->set_op("Switch");
307 *s->add_input() = "while/Merge";
308 *s->add_input() = "while/LoopCond";
309
310 NodeDef* i = item.graph.add_node();
311 i->set_name("while/Identity");
312 i->set_op("Identity");
313 (*i->mutable_attr())["T"].set_type(DT_FLOAT);
314 *i->add_input() = "while/Switch:1";
315
316 NodeDef* n = item.graph.add_node();
317 n->set_name("while/NextIteration");
318 n->set_op("NextIteration");
319 (*n->mutable_attr())["T"].set_type(DT_FLOAT);
320 *n->add_input() = "while/Identity";
321
322 NodeDef* x = item.graph.add_node();
323 x->set_name("while/Exit");
324 x->set_op("Exit");
325 (*x->mutable_attr())["T"].set_type(DT_FLOAT);
326 *x->add_input() = "while/Switch";
327
328 item.fetch.push_back("while/Exit");
329
330 // Create our own cluster to run it
331 SingleMachine cluster(5, 3, 0);
332 TF_CHECK_OK(cluster.Provision());
333 TF_CHECK_OK(cluster.Initialize(item));
334
335 Status s1 = cluster.Run(item.graph, item.feed, item.fetch, nullptr);
336 if (!errors::IsDeadlineExceeded(s1)) {
337 LOG(ERROR) << "Expected 'deadline exceeded' error, got " << s1;
338 // Exit to break the infinite loop
339 _exit(1);
340 }
341
342 // Attempt to shutdown the cluster and make sure we get the proper error code.
343 Status s2 = cluster.Shutdown();
344 if (!errors::IsUnavailable(s2)) {
345 LOG(ERROR) << "Expected 'unavailable' error, got " << s2;
346 // Exit to break the infinite loop
347 _exit(2);
348 }
349
350 // The isn't much we can do at this point. Exit with the error code 0 to
351 // indicate everything went according to plan.
352 _exit(0);
353 }
354
// Runs RunInfiniteTFLoop() in a child process; exit code 0 indicates the
// infinite loop was correctly interrupted by the cluster timeout.
TEST_F(SingleMachineTest, InfiniteLoops) {
#if !(TENSORFLOW_USE_ROCM)  // fails with ROCm (investigate)
  // The RunInfiniteTFLoop function creates its own cluster.
  TF_CHECK_OK(cluster_->Shutdown());
  EXPECT_EXIT(RunInfiniteTFLoop(), ::testing::ExitedWithCode(0), ".*");
#endif
}
362
// Checks that variable-initialization ops are accounted for in the cost model.
TEST_F(SingleMachineTest, InitializationMemory) {
  // Create a variable together with the graph that initializes it.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  const int batch_size = 10;
  Output random_values = ops::RandomNormal(
      scope.WithOpName("x"), {batch_size, 1}, DataType::DT_FLOAT);
  Output variable = ops::Variable(
      scope.WithOpName("v"), TensorShape({batch_size, 1}), DataType::DT_FLOAT);
  Output init_op = ops::Assign(scope.WithOpName("init"), variable,
                               random_values);

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  item.init_ops.push_back(init_op.name());
  item.fetch.push_back(variable.name());

  TF_CHECK_OK(cluster_->Initialize(item));
  RunMetadata metadata;
  TF_CHECK_OK(cluster_->Run(item.graph, item.feed, item.fetch, &metadata));

  // The initialization op must show up in the cost model.
  bool init_found = false;
  for (const auto& node : metadata.cost_graph().node()) {
    if (node.name() == NodeName(init_op.name())) {
      init_found = true;
    }
  }
  EXPECT_TRUE(init_found);
}
389
namespace {

// Sets attribute `key` of `node` to `value` for any type accepted by
// SetAttrValue (ints, strings, DataType, protos, ...).
template <class T>
inline void SetNodeAttr(const string& key, const T& value, NodeDef* node) {
  AttrValue attr_value;
  SetAttrValue(value, &attr_value);
  auto* attr_map = node->mutable_attr();
  (*attr_map)[key] = attr_value;
}
// Specialization for Tensor: serialize the tensor content into a TensorProto
// first, then delegate to the generic version.
template <>
inline void SetNodeAttr(const string& key, const Tensor& tensor,
                        NodeDef* node) {
  TensorProto tensor_proto;
  tensor.AsProtoTensorContent(&tensor_proto);
  SetNodeAttr(key, tensor_proto, node);
}

}  // namespace
408
// Builds a HashTable by hand and verifies that its persistent memory usage is
// attributed to the initialize_table op rather than the table op itself.
TEST_F(SingleMachineTest, PersistentMemory) {
  // Build a hashtable and its initialization graph.
  GrapplerItem item;
  const DataType key_dtype = DT_INT64;
  const DataType data_dtype = DT_INT64;

  NodeDef* hashtable_node = item.graph.add_node();
  hashtable_node->set_op("HashTable");
  hashtable_node->set_name("hash_table");
  SetNodeAttr("key_dtype", key_dtype, hashtable_node);
  SetNodeAttr("value_dtype", data_dtype, hashtable_node);

  // Initial hashtable keys and values
  NodeDef* keys_node = item.graph.add_node();
  keys_node->set_op("Const");
  keys_node->set_name("table_keys");
  SetNodeAttr("dtype", key_dtype, keys_node);
  Tensor keys(key_dtype, TensorShape{2});
  keys.vec<int64>()(0) = 123;
  keys.vec<int64>()(1) = 321;
  SetNodeAttr("value", keys, keys_node);

  NodeDef* values_node = item.graph.add_node();
  values_node->set_op("Const");
  values_node->set_name("table_values");
  SetNodeAttr("dtype", data_dtype, values_node);
  Tensor values(data_dtype, TensorShape{2});
  values.vec<int64>()(0) = 789;
  values.vec<int64>()(1) = 987;
  SetNodeAttr("value", values, values_node);

  // InitializeTable node
  NodeDef* init_table_node = item.graph.add_node();
  init_table_node->set_op("InitializeTable");
  init_table_node->set_name("initialize_table");
  SetNodeAttr("Tkey", key_dtype, init_table_node);
  SetNodeAttr("Tval", data_dtype, init_table_node);
  *init_table_node->add_input() = "hash_table";
  *init_table_node->add_input() = "table_keys";
  *init_table_node->add_input() = "table_values";
  item.init_ops.push_back(init_table_node->name());

  // Key to lookup. Key 0 is absent from the table, so the lookup below
  // returns the default value.
  NodeDef* query_node = item.graph.add_node();
  query_node->set_op("Const");
  query_node->set_name("query");
  SetNodeAttr("dtype", key_dtype, query_node);
  Tensor query(key_dtype, TensorShape({}));
  query.flat<int64>()(0) = 0;
  SetNodeAttr("value", query, query_node);

  // Default return value of hashtable lookup
  NodeDef* default_value_node = item.graph.add_node();
  default_value_node->set_op("Const");
  default_value_node->set_name("default_table_value");
  SetNodeAttr("dtype", data_dtype, default_value_node);
  Tensor dflt(data_dtype, TensorShape({}));
  dflt.flat<int64>()(0) = 456;
  SetNodeAttr("value", dflt, default_value_node);

  // HashTable lookup node
  NodeDef* lookup_node = item.graph.add_node();
  lookup_node->set_op("LookupTableFind");
  lookup_node->set_name("table_lookup");
  SetNodeAttr("Tin", key_dtype, lookup_node);
  SetNodeAttr("Tout", data_dtype, lookup_node);
  *lookup_node->add_input() = "hash_table";
  *lookup_node->add_input() = "query";
  *lookup_node->add_input() = "default_table_value";
  item.fetch.push_back(lookup_node->name());

  // Run the graph
  TF_CHECK_OK(cluster_->Initialize(item));
  RunMetadata metadata;
  TF_CHECK_OK(cluster_->Run(item.graph, item.feed, item.fetch, &metadata));

  // Check the cost model.
  bool found_table_init = false;
  bool found_hashtable = false;
  for (const auto& node : metadata.cost_graph().node()) {
    if (node.name() == "hash_table") {
      found_hashtable = true;
      // Persistent memory usage should be 0 since it's recorded as part of the
      // initialize_table op.
      EXPECT_EQ(0, node.persistent_memory_size());
    } else if (node.name() == "initialize_table") {
      found_table_init = true;
      // Persistent memory should hold 2 keys and 2 values.
      EXPECT_LE(4 * sizeof(int64), node.persistent_memory_size());
    }
  }
  EXPECT_TRUE(found_table_init);
  EXPECT_TRUE(found_hashtable);
}
503
CreateGrapplerItemWithResourceMemory()504 GrapplerItem CreateGrapplerItemWithResourceMemory() {
505 tensorflow::Scope s = tensorflow::Scope::NewRootScope();
506
507 // Add a variable and initializer.
508 Output a = ops::Variable(s.WithOpName("a"), TensorShape({128, 256}),
509 DataType::DT_FLOAT);
510 Output a_init =
511 ops::RandomNormal(s.WithOpName("a/init"), {128, 256}, DataType::DT_FLOAT);
512 Output a_init_assign = ops::Assign(s.WithOpName("a/init/assign"), a, a_init);
513
514 // Add a resource variable.
515 Output b =
516 ops::VarHandleOp(s.WithOpName("b"), DataType::DT_FLOAT, {256, 512});
517 Output b_read =
518 ops::ReadVariableOp(s.WithOpName("b/read"), b, DataType::DT_FLOAT);
519 Output b_init =
520 ops::RandomNormal(s.WithOpName("b/init"), {256, 512}, DataType::DT_FLOAT);
521 auto b_init_assign =
522 ops::AssignVariableOp(s.WithOpName("b/init/assign"), b, b_init);
523
524 // Add a queue.
525 ops::FIFOQueue queue(s.WithOpName("queue"), {DataType::DT_STRING});
526 Output some_string =
527 ops::Const(s.WithOpName("some_string"), string("nothing"));
528 ops::QueueEnqueue enqueue(s.WithOpName("enqueue"), queue, {some_string});
529 ops::QueueDequeue dequeue(s.WithOpName("dequeue"), queue,
530 {DataType::DT_STRING});
531
532 // Add a IdentityReader.
533 ops::IdentityReader reader(s.WithOpName("identity_reader"));
534 ops::ReaderRead read(s.WithOpName("read_from_queue"), reader, queue);
535
536 Output var_mul = ops::MatMul(s.WithOpName("var_matmul"), a, b_read);
537
538 GrapplerItem item;
539 TF_CHECK_OK(s.ToGraphDef(&item.graph));
540
541 QueueRunnerDef queue_runner;
542 queue_runner.set_queue_name("queue");
543 *queue_runner.add_enqueue_op_name() = "enqueue";
544 item.queue_runners.push_back(queue_runner);
545
546 item.init_ops.push_back("a/init/assign");
547 item.init_ops.push_back("b/init/assign");
548 item.fetch.push_back("var_matmul");
549 item.fetch.push_back("dequeue");
550
551 return item;
552 }
553
554 #if defined(PLATFORM_GOOGLE)
// Verifies that resource memory held by a session is released when the
// cluster is shut down and re-provisioned.
TEST_F(SingleMachineTest, ReleaseMemoryAfterDestruction) {
  GrapplerItem item = CreateGrapplerItemWithResourceMemory();
  TF_CHECK_OK(cluster_->Initialize(item));

  std::unordered_map<string, uint64> device_peak_memory_before;
  TF_CHECK_OK(cluster_->GetPeakMemoryUsage(&device_peak_memory_before));
  EXPECT_EQ(device_peak_memory_before.size(), 1);
  // There might be a bit memory used before session's running anything.
  EXPECT_LT(device_peak_memory_before.begin()->second, 400);

  RunMetadata metadata;
  TF_CHECK_OK(cluster_->Run(item.graph, item.feed, item.fetch, &metadata));

  // Check there is memory that is not released.
  std::unordered_map<string, uint64> device_peak_memory;
  TF_CHECK_OK(cluster_->GetPeakMemoryUsage(&device_peak_memory));
  EXPECT_EQ(device_peak_memory.size(), 1);
  EXPECT_GT(device_peak_memory.begin()->second, 0);

  // Reprovisioning the cluster would release all memory.
  TF_CHECK_OK(cluster_->Shutdown());
  TF_CHECK_OK(cluster_->Provision());
  std::unordered_map<string, uint64> device_peak_memory_after;
  TF_CHECK_OK(cluster_->GetPeakMemoryUsage(&device_peak_memory_after));
  TF_CHECK_OK(cluster_->Shutdown());

  // Check memory used by resources are released after cluster destruction.
  // The 400-byte bound is a loose threshold for "essentially no memory in
  // use"; a fresh session may allocate a small amount before running.
  EXPECT_EQ(device_peak_memory_before.size(), 1);
  EXPECT_EQ(device_peak_memory_after.size(), 1);
  EXPECT_LT(device_peak_memory_before.begin()->second, 400);
  EXPECT_LT(device_peak_memory_after.begin()->second, 400);
}
587
// Verifies that peak memory is recorded per device, and is reset to almost
// nothing after the cluster is re-provisioned.
TEST_F(SingleMachineTest, PeakMemory) {
  GrapplerItem item = CreateGrapplerItemWithResourceMemory();
  TF_CHECK_OK(cluster_->Initialize(item));

  RunMetadata metadata;
  TF_CHECK_OK(cluster_->Run(item.graph, item.feed, item.fetch, &metadata));

  // After running the graph, the CPU device must have recorded a non-zero
  // peak memory usage.
  const string cpu_device = "/job:localhost/replica:0/task:0/device:CPU:0";
  std::unordered_map<string, uint64> peak_memory_by_device;
  TF_CHECK_OK(cluster_->GetPeakMemoryUsage(&peak_memory_by_device));
  ASSERT_NE(peak_memory_by_device.find(cpu_device),
            peak_memory_by_device.end());
  uint64 cpu_memory_usage = peak_memory_by_device[cpu_device];
  EXPECT_GT(cpu_memory_usage, 0);

  // A freshly re-provisioned cluster should report (nearly) zero peak memory.
  TF_CHECK_OK(cluster_->Shutdown());
  TF_CHECK_OK(cluster_->Provision());
  peak_memory_by_device.clear();
  TF_CHECK_OK(cluster_->GetPeakMemoryUsage(&peak_memory_by_device));
  TF_CHECK_OK(cluster_->Shutdown());
  ASSERT_NE(peak_memory_by_device.find(cpu_device),
            peak_memory_by_device.end());
  cpu_memory_usage = peak_memory_by_device[cpu_device];
  EXPECT_LT(cpu_memory_usage, 200);
}
616
// Verifies that querying peak memory on a cluster that never enabled peak
// memory stats fails with INVALID_ARGUMENT.
TEST_F(SingleMachineTest, PeakMemoryStatsNotEnabled) {
  GrapplerItem item = CreateGrapplerItemWithResourceMemory();

  // Replace the fixture's cluster (which enables peak memory stats) with a
  // fresh one that never calls EnablePeakMemoryStats().
  TF_CHECK_OK(cluster_->Shutdown());
  cluster_.reset();
  SingleMachine plain_cluster(60 /* timeout_s */, 3 /* num_cpu_cores */,
                              0 /* num_gpus */);

  TF_CHECK_OK(plain_cluster.Provision());
  TF_CHECK_OK(plain_cluster.Initialize(item));

  std::unordered_map<string, uint64> peak_memory;
  Status status = plain_cluster.GetPeakMemoryUsage(&peak_memory);
  TF_CHECK_OK(plain_cluster.Shutdown());
  ASSERT_FALSE(status.ok());
  EXPECT_EQ(status.code(), errors::Code::INVALID_ARGUMENT);
}
634 #endif
635
636 } // namespace
637 } // namespace grappler
638 } // namespace tensorflow
639