1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/core/grappler/costs/virtual_scheduler.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor_description.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
#include "tensorflow/core/grappler/costs/virtual_placer.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
24
25 namespace tensorflow {
26 namespace grappler {
27 // Class for testing virtual scheduler.
28 class TestVirtualScheduler : public VirtualScheduler {
29 public:
TestVirtualScheduler(const GrapplerItem * grappler_item,const bool use_static_shapes,Cluster * cluster)30 TestVirtualScheduler(const GrapplerItem* grappler_item,
31 const bool use_static_shapes, Cluster* cluster)
32 : VirtualScheduler(grappler_item, use_static_shapes, cluster,
33 &ready_node_manager_) {}
34
35 FRIEND_TEST(VirtualSchedulerTest, CalculateOutputSize);
36 FRIEND_TEST(VirtualSchedulerTest, MemoryUsage);
37 FRIEND_TEST(VirtualSchedulerTest, ControlDependency);
38 FRIEND_TEST(VirtualSchedulerTest, ComplexDependency);
39 FRIEND_TEST(VirtualSchedulerTest, Variable);
40 FRIEND_TEST(VirtualSchedulerTest, InterDeviceTransfer);
41
42 protected:
43 FirstReadyManager ready_node_manager_;
44 };
45
46 class VirtualSchedulerTest : public ::testing::Test {
47 protected:
48 NodeDef node1_, node2_, node3_, node4_, node5_, node6_;
49 std::unordered_map<const NodeDef*, NodeState> node_states_;
50
51 // Device names:
52 const string kCPU0 = "/job:localhost/replica:0/task:0/cpu:0";
53 const string kCPU1 = "/job:localhost/replica:0/task:0/cpu:1";
54 const string kChannelFrom0To1 = "Channel from CPU0 to CPU1";
55 const string kChannelFrom1To0 = "Channel from CPU1 to CPU0";
56 // Op names:
57 const string kSend = "_Send";
58 const string kRecv = "_Recv";
59 const string kConv2D = "Conv2D";
60
GetDummyCPUDevice()61 DeviceProperties GetDummyCPUDevice() {
62 // Create CPU with 2 cores, 4 Ghz freq, 2 GB/s mem bandwidth.
63 // - 8 Gflops
64 // - 2 GB/s
65 DeviceProperties cpu_device;
66 cpu_device.set_type("CPU");
67 cpu_device.set_frequency(4000);
68 cpu_device.set_num_cores(2);
69 cpu_device.set_bandwidth(2000000);
70 return cpu_device;
71 }
72
NodeSetUp(const string & name,const string & op_name,const string & device_name,const uint64 time_ready,NodeDef * node)73 void NodeSetUp(const string& name, const string& op_name,
74 const string& device_name, const uint64 time_ready,
75 NodeDef* node) {
76 node->set_name(name);
77 node->set_op(op_name);
78 node->set_device(device_name);
79
80 node_states_[node] = NodeState();
81 node_states_[node].time_ready = time_ready;
82 node_states_[node].device_name = device_name;
83 }
84
SetUp()85 void SetUp() override {
86 // node1_ to node6_ on kCPU0, with time_ready in reverse_order.
87 NodeSetUp("Node1", kConv2D, kCPU0, 6000, &node1_);
88 NodeSetUp("Node2", kConv2D, kCPU0, 5000, &node2_);
89 NodeSetUp("Node3", kConv2D, kCPU0, 4000, &node3_);
90 NodeSetUp("Node4", kConv2D, kCPU0, 3000, &node4_);
91 NodeSetUp("Node5", kConv2D, kCPU0, 2000, &node5_);
92 NodeSetUp("Node6", kConv2D, kCPU0, 1000, &node6_);
93
94 // Initializes cluster_ and placer_.
95 std::unordered_map<string, DeviceProperties> devices;
96
97 // Set some dummy CPU properties
98 DeviceProperties cpu_device = GetDummyCPUDevice();
99
100 // IMPORTANT: Device is not actually ever used in the test case since
101 // force_cpu_type is defaulted to "Haswell"
102 devices[kCPU0] = cpu_device;
103 devices[kCPU1] = cpu_device;
104 cluster_.reset(new VirtualCluster(devices));
105 placer_.reset(new VirtualPlacer(cluster_.get()));
106 }
107
108 // Three Conv2Ds with only two in fetch nodes.
CreateGrapplerItemWithConv2Ds()109 void CreateGrapplerItemWithConv2Ds() {
110 Scope s = Scope::NewRootScope().WithDevice(kCPU0);
111 auto x = ops::RandomUniform(
112 s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
113 auto y = ops::RandomUniform(
114 s.WithOpName("y"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
115 auto z = ops::RandomUniform(
116 s.WithOpName("z"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
117 auto f = ops::RandomUniform(
118 s.WithOpName("f"), {kernel_, kernel_, depth_in_, depth_out_}, DT_FLOAT);
119 std::vector<int> strides = {1, 1, 1, 1};
120 auto c0 = ops::Conv2D(s.WithOpName("c0"), x, f, strides, "SAME");
121 auto c1 = ops::Conv2D(s.WithOpName("c1"), y, f, strides, "SAME");
122 auto c2 = ops::Conv2D(s.WithOpName("c2"), z, f, strides, "SAME");
123 GraphDef def;
124 TF_CHECK_OK(s.ToGraphDef(&def));
125
126 grappler_item_.reset(new GrapplerItem);
127 grappler_item_->id = "test_conv2d_graph";
128 grappler_item_->graph = def;
129 grappler_item_->fetch = {"c0", "c1"};
130
131 dependency_["c0"] = {"x", "f"};
132 dependency_["c1"] = {"y", "f"};
133 }
134
135 // A Conv2D with a variable.
CreateGrapplerItemWithConv2DAndVariable()136 void CreateGrapplerItemWithConv2DAndVariable() {
137 Scope s = Scope::NewRootScope().WithDevice(kCPU0);
138 auto x = ops::RandomUniform(
139 s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
140 auto f = ops::Variable(s.WithOpName("f"),
141 {kernel_, kernel_, depth_in_, depth_out_}, DT_FLOAT);
142 std::vector<int> strides = {1, 1, 1, 1};
143 auto y = ops::Conv2D(s.WithOpName("y"), x, f, strides, "SAME");
144 GraphDef def;
145 TF_CHECK_OK(s.ToGraphDef(&def));
146
147 grappler_item_.reset(new GrapplerItem);
148 grappler_item_->id = "test_conv2d_var_graph";
149 grappler_item_->graph = def;
150 grappler_item_->fetch = {"y"};
151
152 dependency_["y"] = {"x", "f"};
153 }
154
CreateGrapplerItemWithMatmulChain()155 void CreateGrapplerItemWithMatmulChain() {
156 Scope s = Scope::NewRootScope().WithDevice(kCPU0);
157 // Add control dependencies to ensure tests do not rely on specific
158 // manager and the order remains consistent for the test.
159 auto a = ops::RandomUniform(s.WithOpName("a"), {3200, 3200}, DT_FLOAT);
160 auto b = ops::RandomUniform(s.WithOpName("b").WithControlDependencies(a),
161 {3200, 3200}, DT_FLOAT);
162 auto c = ops::RandomUniform(s.WithOpName("c").WithControlDependencies(b),
163 {3200, 3200}, DT_FLOAT);
164 auto d = ops::RandomUniform(s.WithOpName("d").WithControlDependencies(c),
165 {3200, 3200}, DT_FLOAT);
166 auto e = ops::RandomUniform(s.WithOpName("e").WithControlDependencies(d),
167 {3200, 3200}, DT_FLOAT);
168
169 auto ab = ops::MatMul(s.WithOpName("ab").WithControlDependencies(e), a, b);
170 auto abc = ops::MatMul(s.WithOpName("abc"), ab, c);
171 auto abcd = ops::MatMul(s.WithOpName("abcd"), abc, d);
172 auto abcde = ops::MatMul(s.WithOpName("abcde"), abcd, e);
173
174 GraphDef def;
175 TF_CHECK_OK(s.ToGraphDef(&def));
176
177 grappler_item_.reset(new GrapplerItem);
178 grappler_item_->id = "test_matmul_sequence_graph";
179 grappler_item_->graph = def;
180 grappler_item_->fetch = {"abcde"};
181
182 dependency_["ab"] = {"a", "b"};
183 dependency_["abc"] = {"ab", "c"};
184 dependency_["abcd"] = {"abc", "d"};
185 dependency_["abcde"] = {"abcd", "e"};
186 }
187
188 // AddN that takes 4 tensors with 10x10x10x10.
CreateGrapplerItemWithAddN()189 void CreateGrapplerItemWithAddN() {
190 Scope s = Scope::NewRootScope().WithDevice(kCPU0);
191 auto x = ops::RandomUniform(s.WithOpName("x"), {10, 10, 10, 10}, DT_FLOAT);
192 auto y = ops::RandomUniform(s.WithOpName("y"), {10, 10, 10, 10}, DT_FLOAT);
193 auto z = ops::RandomUniform(s.WithOpName("z"), {10, 10, 10, 10}, DT_FLOAT);
194 auto w = ops::RandomUniform(s.WithOpName("w"), {10, 10, 10, 10}, DT_FLOAT);
195 OutputList input_tensors = {x, y, z, w};
196 auto out = ops::AddN(s.WithOpName("out"), input_tensors);
197 GraphDef def;
198 TF_CHECK_OK(s.ToGraphDef(&def));
199
200 grappler_item_.reset(new GrapplerItem);
201 grappler_item_->id = "test_addn_graph";
202 grappler_item_->graph = def;
203 grappler_item_->fetch = {"out"};
204
205 dependency_["out"] = {"x", "y", "z", "w"};
206 }
207
208 // NoOp that takes 7 NoOps as control dependency.
CreateGrapplerItemWithControlDependency()209 void CreateGrapplerItemWithControlDependency() {
210 Scope s = Scope::NewRootScope().WithDevice(kCPU0);
211 std::vector<string> input_noop_names = {"x", "y", "z", "w", "u", "v", "t"};
212 std::vector<Operation> input_tensors;
213 for (const auto& input : input_noop_names) {
214 auto x = ops::NoOp(s.WithOpName(input));
215 input_tensors.push_back(x.operation);
216 }
217 auto out =
218 ops::NoOp(s.WithControlDependencies(input_tensors).WithOpName("out"));
219 GraphDef def;
220 TF_CHECK_OK(s.ToGraphDef(&def));
221
222 grappler_item_.reset(new GrapplerItem);
223 grappler_item_->id = "test_control_dependency_graph";
224 grappler_item_->graph = def;
225 grappler_item_->fetch = {"out"};
226
227 dependency_["out"] = input_noop_names;
228 }
229
230 // FusedBN [an op with multiple outputs] with multiple consumers (including
231 // control dependency).
CreateGrapplerItemWithBatchNorm()232 void CreateGrapplerItemWithBatchNorm() {
233 Scope s = Scope::NewRootScope().WithDevice(kCPU0);
234 auto x = ops::RandomUniform(
235 s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
236 auto scale =
237 ops::RandomUniform(s.WithOpName("scale"), {depth_in_}, DT_FLOAT);
238 auto offset =
239 ops::RandomUniform(s.WithOpName("offset"), {depth_in_}, DT_FLOAT);
240 auto mean = ops::RandomUniform(s.WithOpName("mean"), {0}, DT_FLOAT);
241 auto var = ops::RandomUniform(s.WithOpName("var"), {0}, DT_FLOAT);
242
243 auto batch_norm = ops::FusedBatchNorm(
244 s.WithOpName("bn"), x, scale, offset, mean, var,
245 ops::FusedBatchNorm::IsTraining(true).Epsilon(0.1f));
246 auto y = batch_norm.y;
247 auto batch_mean = batch_norm.batch_mean;
248 auto batch_var = batch_norm.batch_variance;
249
250 auto z1 = ops::Add(s.WithOpName("z1"), x, y);
251 auto z2 = ops::Add(s.WithOpName("z2"), batch_var, batch_var);
252 auto z3 = ops::Add(s.WithOpName("z3"), batch_var, batch_var);
253 std::vector<Operation> input_tensors = {
254 batch_mean.op(),
255 z1.z.op(),
256 z2.z.op(),
257 z3.z.op(),
258 };
259 auto z4 = ops::NoOp(s.WithControlDependencies(batch_var).WithOpName("z4"));
260
261 GraphDef def;
262 TF_CHECK_OK(s.ToGraphDef(&def));
263
264 grappler_item_.reset(new GrapplerItem);
265 grappler_item_->id = "test_complex_dependency_graph";
266 grappler_item_->graph = def;
267 grappler_item_->fetch = {"z1", "z2", "z3", "z4"};
268
269 dependency_["bn"] = {"x", "scale", "offset", "mean", "var"};
270 dependency_["z1"] = {"x", "bn"};
271 dependency_["z2"] = {"bn"};
272 dependency_["z3"] = {"bn"};
273 dependency_["z4"] = {"bn"};
274 }
275
CreateGrapplerItemWithSendRecv()276 void CreateGrapplerItemWithSendRecv() {
277 const string gdef_ascii = R"EOF(
278 node {
279 name: "Const"
280 op: "Const"
281 device: "/job:localhost/replica:0/task:0/device:CPU:0"
282 attr {
283 key: "dtype"
284 value {
285 type: DT_FLOAT
286 }
287 }
288 attr {
289 key: "value"
290 value {
291 tensor {
292 dtype: DT_FLOAT
293 tensor_shape {
294 }
295 float_val: 3.1415
296 }
297 }
298 }
299 }
300 node {
301 name: "Send"
302 op: "_Send"
303 input: "Const"
304 device: "/job:localhost/replica:0/task:0/device:CPU:0"
305 attr {
306 key: "T"
307 value {
308 type: DT_FLOAT
309 }
310 }
311 attr {
312 key: "client_terminated"
313 value {
314 b: false
315 }
316 }
317 attr {
318 key: "recv_device"
319 value {
320 s: "/job:localhost/replica:0/task:0/device:CPU:0"
321 }
322 }
323 attr {
324 key: "send_device"
325 value {
326 s: "/job:localhost/replica:0/task:0/device:CPU:0"
327 }
328 }
329 attr {
330 key: "send_device_incarnation"
331 value {
332 i: 0
333 }
334 }
335 attr {
336 key: "tensor_name"
337 value {
338 s: "test"
339 }
340 }
341 }
342 node {
343 name: "Recv"
344 op: "_Recv"
345 device: "/job:localhost/replica:0/task:0/device:CPU:0"
346 attr {
347 key: "client_terminated"
348 value {
349 b: false
350 }
351 }
352 attr {
353 key: "recv_device"
354 value {
355 s: "/job:localhost/replica:0/task:0/device:CPU:0"
356 }
357 }
358 attr {
359 key: "send_device"
360 value {
361 s: "/job:localhost/replica:0/task:0/device:CPU:0"
362 }
363 }
364 attr {
365 key: "send_device_incarnation"
366 value {
367 i: 0
368 }
369 }
370 attr {
371 key: "tensor_name"
372 value {
373 s: "test"
374 }
375 }
376 attr {
377 key: "tensor_type"
378 value {
379 type: DT_FLOAT
380 }
381 }
382 }
383 library {
384 }
385 versions {
386 producer: 24
387 }
388 )EOF";
389
390 grappler_item_.reset(new GrapplerItem);
391 CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii,
392 &grappler_item_->graph));
393 grappler_item_->id = "test_graph";
394 grappler_item_->fetch = {"Recv"};
395 }
396
397 // A simple while loop
CreateGrapplerItemWithLoop()398 void CreateGrapplerItemWithLoop() {
399 // Test graph produced in python using:
400 /*
401 with tf.Graph().as_default():
402 i0 = tf.constant(0)
403 m0 = tf.ones([2, 2])
404 c = lambda i, m: i < 10
405 b = lambda i, m: [i+1, tf.concat([m, m], axis=0)]
406 r = tf.while_loop(
407 c, b, loop_vars=[i0, m0],
408 shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])])
409 with open('/tmp/graph.pbtxt', 'w') as f:
410 f.write(str(tf.get_default_graph().as_graph_def()))
411 */
412 const string gdef_ascii = R"EOF(
413 node {
414 name: "Const"
415 op: "Const"
416 attr {
417 key: "dtype"
418 value {
419 type: DT_INT32
420 }
421 }
422 attr {
423 key: "value"
424 value {
425 tensor {
426 dtype: DT_INT32
427 tensor_shape {
428 }
429 int_val: 0
430 }
431 }
432 }
433 }
434 node {
435 name: "ones"
436 op: "Const"
437 attr {
438 key: "dtype"
439 value {
440 type: DT_FLOAT
441 }
442 }
443 attr {
444 key: "value"
445 value {
446 tensor {
447 dtype: DT_FLOAT
448 tensor_shape {
449 dim {
450 size: 2
451 }
452 dim {
453 size: 2
454 }
455 }
456 float_val: 1.0
457 }
458 }
459 }
460 }
461 node {
462 name: "while/Enter"
463 op: "Enter"
464 input: "Const"
465 attr {
466 key: "T"
467 value {
468 type: DT_INT32
469 }
470 }
471 attr {
472 key: "frame_name"
473 value {
474 s: "while/while/"
475 }
476 }
477 attr {
478 key: "is_constant"
479 value {
480 b: false
481 }
482 }
483 attr {
484 key: "parallel_iterations"
485 value {
486 i: 10
487 }
488 }
489 }
490 node {
491 name: "while/Enter_1"
492 op: "Enter"
493 input: "ones"
494 attr {
495 key: "T"
496 value {
497 type: DT_FLOAT
498 }
499 }
500 attr {
501 key: "frame_name"
502 value {
503 s: "while/while/"
504 }
505 }
506 attr {
507 key: "is_constant"
508 value {
509 b: false
510 }
511 }
512 attr {
513 key: "parallel_iterations"
514 value {
515 i: 10
516 }
517 }
518 }
519 node {
520 name: "while/Merge"
521 op: "Merge"
522 input: "while/Enter"
523 input: "while/NextIteration"
524 attr {
525 key: "N"
526 value {
527 i: 2
528 }
529 }
530 attr {
531 key: "T"
532 value {
533 type: DT_INT32
534 }
535 }
536 }
537 node {
538 name: "while/Merge_1"
539 op: "Merge"
540 input: "while/Enter_1"
541 input: "while/NextIteration_1"
542 attr {
543 key: "N"
544 value {
545 i: 2
546 }
547 }
548 attr {
549 key: "T"
550 value {
551 type: DT_FLOAT
552 }
553 }
554 }
555 node {
556 name: "while/Less/y"
557 op: "Const"
558 input: "^while/Merge"
559 attr {
560 key: "dtype"
561 value {
562 type: DT_INT32
563 }
564 }
565 attr {
566 key: "value"
567 value {
568 tensor {
569 dtype: DT_INT32
570 tensor_shape {
571 }
572 int_val: 10
573 }
574 }
575 }
576 }
577 node {
578 name: "while/Less"
579 op: "Less"
580 input: "while/Merge"
581 input: "while/Less/y"
582 attr {
583 key: "T"
584 value {
585 type: DT_INT32
586 }
587 }
588 }
589 node {
590 name: "while/LoopCond"
591 op: "LoopCond"
592 input: "while/Less"
593 }
594 node {
595 name: "while/Switch"
596 op: "Switch"
597 input: "while/Merge"
598 input: "while/LoopCond"
599 attr {
600 key: "T"
601 value {
602 type: DT_INT32
603 }
604 }
605 attr {
606 key: "_class"
607 value {
608 list {
609 s: "loc:@while/Merge"
610 }
611 }
612 }
613 }
614 node {
615 name: "while/Switch_1"
616 op: "Switch"
617 input: "while/Merge_1"
618 input: "while/LoopCond"
619 attr {
620 key: "T"
621 value {
622 type: DT_FLOAT
623 }
624 }
625 attr {
626 key: "_class"
627 value {
628 list {
629 s: "loc:@while/Merge_1"
630 }
631 }
632 }
633 }
634 node {
635 name: "while/Identity"
636 op: "Identity"
637 input: "while/Switch:1"
638 attr {
639 key: "T"
640 value {
641 type: DT_INT32
642 }
643 }
644 }
645 node {
646 name: "while/Identity_1"
647 op: "Identity"
648 input: "while/Switch_1:1"
649 attr {
650 key: "T"
651 value {
652 type: DT_FLOAT
653 }
654 }
655 }
656 node {
657 name: "while/add/y"
658 op: "Const"
659 input: "^while/Identity"
660 attr {
661 key: "dtype"
662 value {
663 type: DT_INT32
664 }
665 }
666 attr {
667 key: "value"
668 value {
669 tensor {
670 dtype: DT_INT32
671 tensor_shape {
672 }
673 int_val: 1
674 }
675 }
676 }
677 }
678 node {
679 name: "while/add"
680 op: "Add"
681 input: "while/Identity"
682 input: "while/add/y"
683 attr {
684 key: "T"
685 value {
686 type: DT_INT32
687 }
688 }
689 }
690 node {
691 name: "while/concat/axis"
692 op: "Const"
693 input: "^while/Identity"
694 attr {
695 key: "dtype"
696 value {
697 type: DT_INT32
698 }
699 }
700 attr {
701 key: "value"
702 value {
703 tensor {
704 dtype: DT_INT32
705 tensor_shape {
706 }
707 int_val: 0
708 }
709 }
710 }
711 }
712 node {
713 name: "while/concat"
714 op: "ConcatV2"
715 input: "while/Identity_1"
716 input: "while/Identity_1"
717 input: "while/concat/axis"
718 attr {
719 key: "N"
720 value {
721 i: 2
722 }
723 }
724 attr {
725 key: "T"
726 value {
727 type: DT_FLOAT
728 }
729 }
730 attr {
731 key: "Tidx"
732 value {
733 type: DT_INT32
734 }
735 }
736 }
737 node {
738 name: "while/NextIteration"
739 op: "NextIteration"
740 input: "while/add"
741 attr {
742 key: "T"
743 value {
744 type: DT_INT32
745 }
746 }
747 }
748 node {
749 name: "while/NextIteration_1"
750 op: "NextIteration"
751 input: "while/concat"
752 attr {
753 key: "T"
754 value {
755 type: DT_FLOAT
756 }
757 }
758 }
759 node {
760 name: "while/Exit"
761 op: "Exit"
762 input: "while/Switch"
763 attr {
764 key: "T"
765 value {
766 type: DT_INT32
767 }
768 }
769 }
770 node {
771 name: "while/Exit_1"
772 op: "Exit"
773 input: "while/Switch_1"
774 attr {
775 key: "T"
776 value {
777 type: DT_FLOAT
778 }
779 }
780 }
781 versions {
782 producer: 21
783 }
784 )EOF";
785
786 grappler_item_.reset(new GrapplerItem);
787 CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii,
788 &grappler_item_->graph));
789 grappler_item_->id = "test_graph";
790 grappler_item_->fetch = {"while/Exit", "while/Exit_1"};
791 }
792
CreateGrapplerItemWithInterDeviceTransfers()793 void CreateGrapplerItemWithInterDeviceTransfers() {
794 tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice(kCPU0);
795
796 // Create a FusedBatchNorm op that has multiple output ports.
797 auto x = ops::RandomUniform(
798 s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
799 auto scale =
800 ops::RandomUniform(s.WithOpName("scale"), {depth_in_}, DT_FLOAT);
801 auto offset =
802 ops::RandomUniform(s.WithOpName("offset"), {depth_in_}, DT_FLOAT);
803 auto mean = ops::RandomUniform(s.WithOpName("mean"), {0}, DT_FLOAT);
804 auto var = ops::RandomUniform(s.WithOpName("var"), {0}, DT_FLOAT);
805
806 auto batch_norm = ops::FusedBatchNorm(
807 s.WithOpName("bn"), x, scale, offset, mean, var,
808 ops::FusedBatchNorm::IsTraining(true).Epsilon(0.1f));
809 auto y = batch_norm.y;
810 auto batch_mean = batch_norm.batch_mean;
811 auto batch_var = batch_norm.batch_variance;
812 // y1 and y2 take the same tensor, so there should be only 1 Send and Recv.
813 auto y1 = ops::Identity(s.WithOpName("y1").WithDevice(kCPU1), y);
814 auto y2 = ops::Identity(s.WithOpName("y2").WithDevice(kCPU1), y);
815 // batch_mean1 and batch_var1 take different output ports, so each will
816 // initiate Send/Recv.
817 auto batch_mean1 = ops::Identity(
818 s.WithOpName("batch_mean1").WithDevice(kCPU1), batch_mean);
819 auto batch_var1 =
820 ops::Identity(s.WithOpName("batch_var1").WithDevice(kCPU1), batch_var);
821 // This is control dependency.
822 auto control_dep = ops::NoOp(s.WithOpName("control_dep")
823 .WithControlDependencies(y)
824 .WithDevice(kCPU1));
825
826 GraphDef def;
827 TF_CHECK_OK(s.ToGraphDef(&def));
828
829 grappler_item_.reset(new GrapplerItem);
830 grappler_item_->id = "test_conv2d_graph";
831 grappler_item_->graph = def;
832 grappler_item_->fetch = {"y1", "y2", "batch_mean1", "batch_var1",
833 "control_dep"};
834
835 dependency_["bn"] = {"x", "mean", "var"};
836 dependency_["y1"] = {"bn"};
837 dependency_["y2"] = {"bn"};
838 dependency_["batch_mean1"] = {"bn"};
839 dependency_["batch_var1"] = {"bn"};
840 dependency_["control_dep"] = {"bn"};
841 }
842
843 // Call this after creating grappler_item_ and setting up dependency_.
InitScheduler()844 void InitScheduler() {
845 scheduler_.reset(new TestVirtualScheduler(
846 grappler_item_.get(), true /* use_static_shapes */, cluster_.get()));
847 TF_CHECK_OK(scheduler_->Init());
848 }
849
850 // Returns cost based on op.
SimplePredictCosts(const OpContext & op_context) const851 Costs SimplePredictCosts(const OpContext& op_context) const {
852 Costs c;
853 int64 exec_cost = 0;
854 if (op_context.op_info.op() == "MatMul") {
855 exec_cost = 2000000000;
856 } else if (op_context.op_info.op() == "RandomUniform") {
857 exec_cost = 1000000000;
858 } else {
859 exec_cost = 1000;
860 }
861 c.execution_time = Costs::NanoSeconds(exec_cost);
862 return c;
863 }
864
865 // Call this after init scheduler_. Scheduler stops after executing
866 // target_node.
RunScheduler(const string & target_node)867 std::unordered_map<string, OpContext> RunScheduler(
868 const string& target_node) {
869 Costs zero_costs = Costs::ZeroCosts();
870 std::unordered_map<string, OpContext> ops_executed;
871 bool more_nodes = true;
872 do {
873 OpContext op_context = scheduler_->GetCurrNode();
874 ops_executed[op_context.name] = op_context;
875 std::cout << op_context.name << std::endl;
876
877 Costs node_costs = SimplePredictCosts(op_context);
878
879 // Check scheduling order.
880 auto it = dependency_.find(op_context.name);
881 if (it != dependency_.end()) {
882 for (const auto& preceding_node : it->second) {
883 EXPECT_GT(ops_executed.count(preceding_node), 0);
884 }
885 }
886 more_nodes = scheduler_->MarkCurrNodeExecuted(node_costs);
887
888 if (op_context.name == target_node) {
889 // Scheduler has the state after executing the target node.
890 break;
891 }
892 } while (more_nodes);
893 return ops_executed;
894 }
895
896 // Helper method for validating a vector.
897 template <typename T>
ExpectVectorEq(const std::vector<T> & expected,const std::vector<T> & test_elements)898 void ExpectVectorEq(const std::vector<T>& expected,
899 const std::vector<T>& test_elements) {
900 // Set of expected elements for an easy comparison.
901 std::set<T> expected_set(expected.begin(), expected.end());
902 for (const auto& element : test_elements) {
903 EXPECT_GT(expected_set.count(element), 0);
904 }
905 EXPECT_EQ(expected.size(), test_elements.size());
906 }
907
908 // Helper method that checks the name of nodes.
ValidateNodeDefs(const std::vector<string> & expected,const std::vector<const NodeDef * > & node_defs)909 void ValidateNodeDefs(const std::vector<string>& expected,
910 const std::vector<const NodeDef*>& node_defs) {
911 std::vector<string> node_names;
912 std::transform(node_defs.begin(), node_defs.end(),
913 std::back_inserter(node_names),
914 [](const NodeDef* node) { return node->name(); });
915 ExpectVectorEq(expected, node_names);
916 }
917
918 // Helper method for validating a set.
919 template <typename T>
ExpectSetEq(const std::set<T> & expected,const std::set<T> & test_elements)920 void ExpectSetEq(const std::set<T>& expected,
921 const std::set<T>& test_elements) {
922 for (const auto& element : test_elements) {
923 EXPECT_GT(expected.count(element), 0);
924 }
925 EXPECT_EQ(expected.size(), test_elements.size());
926 }
927
928 // Helper method tthat checks name - port pairs.
ValidateMemoryUsageSnapshot(const std::vector<string> & expected_names,const int port_num_expected,const std::unordered_set<std::pair<const NodeDef *,int>,DeviceState::NodePairHash> & mem_usage_snapshot)929 void ValidateMemoryUsageSnapshot(
930 const std::vector<string>& expected_names, const int port_num_expected,
931 const std::unordered_set<std::pair<const NodeDef*, int>,
932 DeviceState::NodePairHash>& mem_usage_snapshot) {
933 std::set<std::pair<string, int>> nodes_at_peak_mem_usage;
934 std::transform(
935 mem_usage_snapshot.begin(), mem_usage_snapshot.end(),
936 std::inserter(nodes_at_peak_mem_usage, nodes_at_peak_mem_usage.begin()),
937 [](const std::pair<const NodeDef*, int>& node_port) {
938 return std::make_pair(node_port.first->name(), node_port.second);
939 });
940 std::set<std::pair<string, int>> expected;
941 std::transform(expected_names.begin(), expected_names.end(),
942 std::inserter(expected, expected.begin()),
943 [port_num_expected](const string& name) {
944 return std::make_pair(name, port_num_expected);
945 });
946 ExpectSetEq(expected, nodes_at_peak_mem_usage);
947 }
948
949 // Helper method for checking nodes dependency.
ValidateDependencyChain(const std::unordered_map<string,int64> & start_times,const std::vector<string> & nodes_in_dependency_order)950 void ValidateDependencyChain(
951 const std::unordered_map<string, int64>& start_times,
952 const std::vector<string>& nodes_in_dependency_order) {
953 int64 prev_node_time = -1;
954 for (const auto& node : nodes_in_dependency_order) {
955 int64 curr_node_time = start_times.at(node);
956 EXPECT_GE(curr_node_time, prev_node_time);
957 prev_node_time = curr_node_time;
958 }
959 }
960
961 // Helper method for converting shape vector to TensorProperty.
ShapeToTensorProperty(const std::vector<int> shape,const DataType & data_type) const962 OpInfo::TensorProperties ShapeToTensorProperty(
963 const std::vector<int> shape, const DataType& data_type) const {
964 OpInfo::TensorProperties tensor_property;
965 tensor_property.set_dtype(data_type);
966 for (const auto& x : shape) {
967 tensor_property.mutable_shape()->add_dim()->set_size(x);
968 }
969 return tensor_property;
970 }
971
972 // SetUp() inits cluster_ and placer_.
973 std::unique_ptr<VirtualCluster> cluster_;
974 std::unique_ptr<VirtualPlacer> placer_;
975
976 // grappler_item_ and scheduler_ will be initialized differently for each test
977 // case.
978 std::unique_ptr<GrapplerItem> grappler_item_;
979 std::unique_ptr<TestVirtualScheduler> scheduler_;
980 // Node name -> its preceding nodes map for testing scheduling order.
981 std::unordered_map<string, std::vector<string>> dependency_;
982
983 // Shared params for Conv2D related graphs:
984 const int batch_size_ = 4;
985 const int width_ = 10;
986 const int height_ = 10;
987 const int depth_in_ = 8;
988 const int kernel_ = 3;
989 const int depth_out_ = 16;
990 };
991
992 // Test that FIFOManager correctly returns the current node with only 1 node.
TEST_F(VirtualSchedulerTest,GetSingleNodeFIFOManager)993 TEST_F(VirtualSchedulerTest, GetSingleNodeFIFOManager) {
994 // Init.
995 FIFOManager manager = FIFOManager();
996
997 // Add the node to FIFOManager.
998 manager.AddNode(&node1_);
999 EXPECT_EQ("Node1", manager.GetCurrNode()->name());
1000 }
1001
1002 // Test that FIFOManager removes the only node contained within.
TEST_F(VirtualSchedulerTest,RemoveSingleNodeFIFOManager)1003 TEST_F(VirtualSchedulerTest, RemoveSingleNodeFIFOManager) {
1004 // Init.
1005 FIFOManager manager = FIFOManager();
1006
1007 // Add the node to FIFOManager.
1008 manager.AddNode(&node1_);
1009
1010 // Remove the only node in FIFOManager.
1011 manager.RemoveCurrNode();
1012 EXPECT_TRUE(manager.Empty());
1013 }
1014
// Test that FIFOManager can remove multiple nodes and return the current node
// in the right order.
TEST_F(VirtualSchedulerTest,GetAndRemoveMultipleFIFOManager)1017 TEST_F(VirtualSchedulerTest, GetAndRemoveMultipleFIFOManager) {
1018 // Init.
1019 FIFOManager manager = FIFOManager();
1020
1021 // Add the nodes to FIFOManager.
1022 manager.AddNode(&node1_);
1023 manager.AddNode(&node2_);
1024 manager.AddNode(&node3_);
1025 manager.AddNode(&node4_);
1026
1027 // Keep checking current node while removing nodes from manager.
1028 EXPECT_EQ("Node1", manager.GetCurrNode()->name());
1029 manager.RemoveCurrNode();
1030 EXPECT_EQ("Node2", manager.GetCurrNode()->name());
1031 manager.RemoveCurrNode();
1032 EXPECT_EQ("Node3", manager.GetCurrNode()->name());
1033 manager.RemoveCurrNode();
1034 EXPECT_EQ("Node4", manager.GetCurrNode()->name());
1035 manager.RemoveCurrNode();
1036 EXPECT_TRUE(manager.Empty());
1037 }
1038
// Test that FIFOManager can remove multiple nodes and add more nodes, still
// returning the current node in the right order.
TEST_F(VirtualSchedulerTest,AddAndRemoveMultipleFIFOManager)1041 TEST_F(VirtualSchedulerTest, AddAndRemoveMultipleFIFOManager) {
1042 // Init.
1043 FIFOManager manager = FIFOManager();
1044
1045 // Add the nodes to FIFOManager.
1046 manager.AddNode(&node1_);
1047 manager.AddNode(&node2_);
1048 manager.AddNode(&node3_);
1049 manager.AddNode(&node4_);
1050
1051 // Keep checking current node as nodes are removed and added.
1052 EXPECT_EQ("Node1", manager.GetCurrNode()->name());
1053 manager.RemoveCurrNode();
1054 EXPECT_EQ("Node2", manager.GetCurrNode()->name());
1055 manager.AddNode(&node5_);
1056 // GetCurrNode() should return the same node even if some nodes are added,
1057 // until RemoveCurrNode() is called.
1058 EXPECT_EQ("Node2", manager.GetCurrNode()->name());
1059 manager.RemoveCurrNode();
1060 EXPECT_EQ("Node3", manager.GetCurrNode()->name());
1061 manager.RemoveCurrNode();
1062 EXPECT_EQ("Node4", manager.GetCurrNode()->name());
1063 manager.AddNode(&node6_);
1064 EXPECT_EQ("Node4", manager.GetCurrNode()->name());
1065 manager.RemoveCurrNode();
1066 EXPECT_EQ("Node5", manager.GetCurrNode()->name());
1067 manager.RemoveCurrNode();
1068 EXPECT_EQ("Node6", manager.GetCurrNode()->name());
1069 manager.RemoveCurrNode();
1070 EXPECT_TRUE(manager.Empty());
1071 }
1072
1073 // Test that LIFOManager correctly returns the current node with only 1 node.
TEST_F(VirtualSchedulerTest,GetSingleNodeLIFOManager)1074 TEST_F(VirtualSchedulerTest, GetSingleNodeLIFOManager) {
1075 // Init.
1076 LIFOManager manager = LIFOManager();
1077
1078 // Add the node to LIFOManager.
1079 manager.AddNode(&node1_);
1080 EXPECT_EQ("Node1", manager.GetCurrNode()->name());
1081 }
1082
1083 // Test that LIFOManager removes the only node contained within.
TEST_F(VirtualSchedulerTest,RemoveSingleNodeLIFOManager)1084 TEST_F(VirtualSchedulerTest, RemoveSingleNodeLIFOManager) {
1085 // Init.
1086 LIFOManager manager = LIFOManager();
1087
1088 // Add the node to LIFOManager.
1089 manager.AddNode(&node1_);
1090
1091 // Remove the only node in LIFOManager.
1092 manager.RemoveCurrNode();
1093 EXPECT_TRUE(manager.Empty());
1094 }
1095
// Test that LIFOManager can remove multiple nodes and return the current node
// in the right order.
TEST_F(VirtualSchedulerTest,GetAndRemoveMultipleLIFOManager)1098 TEST_F(VirtualSchedulerTest, GetAndRemoveMultipleLIFOManager) {
1099 // Init.
1100 LIFOManager manager = LIFOManager();
1101
1102 // Add the nodes to LIFOManager.
1103 manager.AddNode(&node1_);
1104 manager.AddNode(&node2_);
1105 manager.AddNode(&node3_);
1106 manager.AddNode(&node4_);
1107
1108 // Keep checking current node while removing nodes from manager.
1109 EXPECT_EQ("Node4", manager.GetCurrNode()->name());
1110 manager.RemoveCurrNode();
1111 EXPECT_EQ("Node3", manager.GetCurrNode()->name());
1112 manager.RemoveCurrNode();
1113 EXPECT_EQ("Node2", manager.GetCurrNode()->name());
1114 manager.RemoveCurrNode();
1115 EXPECT_EQ("Node1", manager.GetCurrNode()->name());
1116 manager.RemoveCurrNode();
1117 EXPECT_TRUE(manager.Empty());
1118 }
1119
1120 // Test that LIFOManager can remove multiple nodes (must be removing the current
1121 // node) and add more nodes, still returning the current node in the right order
TEST_F(VirtualSchedulerTest, AddAndRemoveMultipleLIFOManager) {
  LIFOManager manager{};
  for (const auto* node : {&node1_, &node2_, &node3_, &node4_}) {
    manager.AddNode(node);
  }

  // Checks that `expected` is the current node, then removes it.
  auto expect_and_pop = [&manager](const char* expected) {
    EXPECT_EQ(expected, manager.GetCurrNode()->name());
    manager.RemoveCurrNode();
  };

  expect_and_pop("Node4");

  // Adding a node must not change GetCurrNode() until RemoveCurrNode() runs.
  EXPECT_EQ("Node3", manager.GetCurrNode()->name());
  manager.AddNode(&node5_);
  expect_and_pop("Node3");

  // node5_ is now the most recently added node, so it is served next.
  expect_and_pop("Node5");

  EXPECT_EQ("Node2", manager.GetCurrNode()->name());
  manager.AddNode(&node6_);
  expect_and_pop("Node2");

  expect_and_pop("Node6");
  expect_and_pop("Node1");
  EXPECT_TRUE(manager.Empty());
}
1153
// With a single queued node, FirstReadyManager must report it as current.
TEST_F(VirtualSchedulerTest, GetSingleNodeFirstReadyManager) {
  FirstReadyManager manager;
  manager.Init(&node_states_);

  manager.AddNode(&node1_);
  const auto* current = manager.GetCurrNode();
  EXPECT_EQ("Node1", current->name());
}
1161
// Removing the only queued node leaves FirstReadyManager empty.
TEST_F(VirtualSchedulerTest, RemoveSingleNodeFirstReadyManager) {
  FirstReadyManager ready_manager;
  ready_manager.Init(&node_states_);

  ready_manager.AddNode(&node1_);
  ready_manager.RemoveCurrNode();

  EXPECT_TRUE(ready_manager.Empty());
}
1169
// FirstReadyManager must serve nodes by time_ready, independent of the order
// in which they were added.
TEST_F(VirtualSchedulerTest, GetAndRemoveMultipleFirstReadyManager) {
  FirstReadyManager manager;
  manager.Init(&node_states_);

  // Deliberately scrambled insertion order.
  for (const auto* node :
       {&node2_, &node1_, &node4_, &node5_, &node3_, &node6_}) {
    manager.AddNode(node);
  }

  // Expected service order, driven purely by each node's time_ready.
  for (const char* expected :
       {"Node6", "Node5", "Node4", "Node3", "Node2", "Node1"}) {
    EXPECT_EQ(expected, manager.GetCurrNode()->name());
    manager.RemoveCurrNode();
  }
  EXPECT_TRUE(manager.Empty());
}
1197
// GetCurrNode() must keep returning the same node until RemoveCurrNode() is
// called, even when nodes with an earlier time_ready are added in between.
TEST_F(VirtualSchedulerTest, GetCurrNodeFirstReadyManager) {
  FirstReadyManager manager;
  manager.Init(&node_states_);
  // Insert nodes in some random order.
  manager.AddNode(&node2_);
  manager.AddNode(&node1_);
  manager.AddNode(&node4_);
  manager.AddNode(&node5_);
  manager.AddNode(&node3_);
  manager.AddNode(&node6_);

  // Among these nodes, node6 has the smallest time_ready, hence, GetCurrNode()
  // should return it.
  EXPECT_EQ("Node6", manager.GetCurrNode()->name());
  // Now insert a few other nodes whose time_ready values are even smaller than
  // that of Node6. Before RemoveCurrNode() is called, GetCurrNode() should
  // keep returning the same node, Node6 in this case.

  NodeDef node7;
  NodeDef node8;
  NodeDef node9;
  NodeSetUp("Node7", kConv2D, kCPU0, 5, &node7);
  NodeSetUp("Node8", kConv2D, kCPU0, 4, &node8);
  NodeSetUp("Node9", kConv2D, kCPU0, 3, &node9);

  manager.AddNode(&node7);
  EXPECT_EQ("Node6", manager.GetCurrNode()->name());

  manager.AddNode(&node8);
  EXPECT_EQ("Node6", manager.GetCurrNode()->name());

  manager.RemoveCurrNode();
  // Now Node6 is removed, and GetCurrNode() will return Node8.
  EXPECT_EQ("Node8", manager.GetCurrNode()->name());

  // Again, AddNode shouldn't change GetCurrNode().
  manager.AddNode(&node9);
  EXPECT_EQ("Node8", manager.GetCurrNode()->name());

  manager.RemoveCurrNode();
  // Drain the remaining nodes; they must come out in time_ready order.
  for (const char* expected :
       {"Node9", "Node7", "Node5", "Node4", "Node3", "Node2", "Node1"}) {
    EXPECT_EQ(expected, manager.GetCurrNode()->name());
    manager.RemoveCurrNode();
  }
  EXPECT_TRUE(manager.Empty());
}
1254
// Two FirstReadyManagers fed the same equally-ready nodes in different orders
// must hand them out in the same sequence (deterministic tie-breaking).
TEST_F(VirtualSchedulerTest, DeterminismInFirstReadyManager) {
  FirstReadyManager manager1;
  manager1.Init(&node_states_);
  FirstReadyManager manager2;
  manager2.Init(&node_states_);

  // Six nodes sharing one time_ready, so only tie-breaking decides the order.
  NodeDef node7;
  NodeDef node8;
  NodeDef node9;
  NodeDef node10;
  NodeDef node11;
  NodeDef node12;
  NodeSetUp("Node7", kConv2D, kCPU0, 1000, &node7);
  NodeSetUp("Node8", kConv2D, kCPU0, 1000, &node8);
  NodeSetUp("Node9", kConv2D, kCPU0, 1000, &node9);
  NodeSetUp("Node10", kConv2D, kCPU0, 1000, &node10);
  NodeSetUp("Node11", kConv2D, kCPU0, 1000, &node11);
  NodeSetUp("Node12", kConv2D, kCPU0, 1000, &node12);

  // Same nodes, two different insertion orders.
  for (const auto* node :
       {&node7, &node8, &node9, &node10, &node11, &node12}) {
    manager1.AddNode(node);
  }
  for (const auto* node :
       {&node8, &node11, &node9, &node10, &node7, &node12}) {
    manager2.AddNode(node);
  }

  // Step both managers in lockstep; each step must yield the same node.
  for (int step = 0; step < 6; ++step) {
    EXPECT_EQ(manager1.GetCurrNode()->name(), manager2.GetCurrNode()->name());
    manager1.RemoveCurrNode();
    manager2.RemoveCurrNode();
  }

  EXPECT_TRUE(manager1.Empty());
  EXPECT_TRUE(manager2.Empty());
}
1320
// Test that CompositeNodeManager correctly returns the current node when it
// holds only one node. (This replaces an exact duplicate of the
// RemoveSingleNode test below, mirroring the GetSingleNode/RemoveSingleNode
// pairs used for the LIFO and FirstReady managers.)
TEST_F(VirtualSchedulerTest, GetSingleNodeCompositeNodeManager) {
  CompositeNodeManager manager;
  manager.Init(&node_states_);
  manager.AddNode(&node1_);
  EXPECT_EQ("Node1", manager.GetCurrNode()->name());
}
1328
// Removing the only queued node leaves CompositeNodeManager empty.
TEST_F(VirtualSchedulerTest, RemoveSingleNodeComopsiteNodeManager) {
  CompositeNodeManager composite_manager;
  composite_manager.Init(&node_states_);

  composite_manager.AddNode(&node1_);
  composite_manager.RemoveCurrNode();
  EXPECT_TRUE(composite_manager.Empty());
}
1337
// Test that CompositeNodeManager returns the current node in the expected
// order while nodes are interleavingly added and removed.
TEST_F(VirtualSchedulerTest, GetAndRemoveMultipleComopsiteNodeManager) {
  CompositeNodeManager manager;
  manager.Init(&node_states_);

  // Add the nodes to CompositeNodeManager. The expectations below follow LIFO
  // order for these nodes.
  manager.AddNode(&node1_);
  manager.AddNode(&node2_);
  manager.AddNode(&node3_);
  manager.AddNode(&node4_);

  // Keep checking current node as nodes are removed and added.
  EXPECT_EQ("Node4", manager.GetCurrNode()->name());
  manager.RemoveCurrNode();
  EXPECT_EQ("Node3", manager.GetCurrNode()->name());
  manager.AddNode(&node5_);
  // GetCurrNode() should return the same node even if some nodes are added,
  // until RemoveCurrNode() is called.
  EXPECT_EQ("Node3", manager.GetCurrNode()->name());
  manager.RemoveCurrNode();
  EXPECT_EQ("Node5", manager.GetCurrNode()->name());
  manager.RemoveCurrNode();
  EXPECT_EQ("Node2", manager.GetCurrNode()->name());
  manager.AddNode(&node6_);
  EXPECT_EQ("Node2", manager.GetCurrNode()->name());
  manager.RemoveCurrNode();
  EXPECT_EQ("Node6", manager.GetCurrNode()->name());
  manager.RemoveCurrNode();
  EXPECT_EQ("Node1", manager.GetCurrNode()->name());
  manager.RemoveCurrNode();
  EXPECT_TRUE(manager.Empty());
}
1369
// CompositeNodeManager across two devices plus Send/Recv nodes: at each step
// the node with the earliest time_ready among the per-device candidates and
// the Send/Recv nodes must be chosen.
TEST_F(VirtualSchedulerTest, MultiDeviceSendRecvComopsiteNodeManager) {
  CompositeNodeManager manager;
  manager.Init(&node_states_);
  // Additional nodes on kCPU1
  NodeDef node7;
  NodeDef node8;
  NodeDef node9;
  NodeSetUp("Node7", kConv2D, kCPU1, 1001, &node7);
  NodeSetUp("Node8", kConv2D, kCPU1, 2001, &node8);
  NodeSetUp("Node9", kConv2D, kCPU1, 3001, &node9);

  // Send and Recv nodes.
  NodeDef send1;
  NodeDef send2;
  NodeDef recv1;
  NodeDef recv2;
  NodeSetUp("Send1", kSend, kChannelFrom0To1, 2002, &send1);
  NodeSetUp("Send2", kSend, kChannelFrom1To0, 2005, &send2);
  NodeSetUp("Recv1", kRecv, kCPU0, 2003, &recv1);
  NodeSetUp("Recv2", kRecv, kCPU1, 2004, &recv2);

  // Insert nodes.
  manager.AddNode(&node1_);
  manager.AddNode(&node2_);
  manager.AddNode(&node3_);
  manager.AddNode(&node4_);
  manager.AddNode(&node5_);
  manager.AddNode(&node6_);
  manager.AddNode(&node7);
  manager.AddNode(&node8);
  manager.AddNode(&node9);
  manager.AddNode(&send1);
  manager.AddNode(&send2);
  manager.AddNode(&recv1);
  manager.AddNode(&recv2);

  // On kCPU0 the last added node is node6_; on kCPU1 it is node9. So choose
  // the one with the earliest time_ready among node6_, node9, Send1, Send2,
  // Recv1, and Recv2.
  EXPECT_EQ("Node6", manager.GetCurrNode()->name());
  manager.RemoveCurrNode();
  // Then, the next one on kCPU0 is node5_; choose the earliest time_ready node
  // among node5_, node9, Send1, Send2, Recv1, and Recv2.
  EXPECT_EQ("Node5", manager.GetCurrNode()->name());
  manager.RemoveCurrNode();
  // Next, choose among node4_, node9, Send1, Send2, Recv1, and Recv2.
  EXPECT_EQ("Send1", manager.GetCurrNode()->name());
  manager.RemoveCurrNode();
  // Next, choose among node4_, node9, Send2, Recv1, and Recv2.
  EXPECT_EQ("Recv1", manager.GetCurrNode()->name());
  manager.RemoveCurrNode();
  // Next, choose among node4_, node9, Send2, and Recv2.
  EXPECT_EQ("Recv2", manager.GetCurrNode()->name());
  manager.RemoveCurrNode();
  // Next, choose among node4_, node9, and Send2.
  EXPECT_EQ("Send2", manager.GetCurrNode()->name());
  manager.RemoveCurrNode();
  // Next, choose between node4_, node9.
  EXPECT_EQ("Node4", manager.GetCurrNode()->name());
  manager.RemoveCurrNode();
  // Next, choose between node3_, node9.
  EXPECT_EQ("Node9", manager.GetCurrNode()->name());
  manager.RemoveCurrNode();
  // Next, choose between node3_, node8.
  EXPECT_EQ("Node8", manager.GetCurrNode()->name());
  manager.RemoveCurrNode();
  // Next, choose between node3_, node7.
  EXPECT_EQ("Node7", manager.GetCurrNode()->name());
  manager.RemoveCurrNode();
  // Only the remaining kCPU0 nodes are left; they come out in LIFO order.
  EXPECT_EQ("Node3", manager.GetCurrNode()->name());
  manager.RemoveCurrNode();
  EXPECT_EQ("Node2", manager.GetCurrNode()->name());
  manager.RemoveCurrNode();
  EXPECT_EQ("Node1", manager.GetCurrNode()->name());
  manager.RemoveCurrNode();
  EXPECT_TRUE(manager.Empty());
}
1448
// CompositeNodeManager must break time_ready ties deterministically: _Send
// before _Recv before other ops, and independent of insertion order.
TEST_F(VirtualSchedulerTest, DeterminismInCompositeNodeManager) {
  CompositeNodeManager manager;
  manager.Init(&node_states_);
  CompositeNodeManager manager2;
  manager2.Init(&node_states_);

  // Node7, Node8, Node9 and Node12 share time_ready 1000; Node10 and Node11
  // are ready slightly earlier (999). Node12 is the only node on kCPU1.
  NodeDef node7;
  NodeDef node8;
  NodeDef node9;
  NodeDef node10;
  NodeDef node11;
  NodeDef node12;
  NodeSetUp("Node7", kConv2D, kCPU0, 1000, &node7);
  NodeSetUp("Node8", kSend, kCPU0, 1000, &node8);
  NodeSetUp("Node9", kRecv, kCPU0, 1000, &node9);
  NodeSetUp("Node10", kConv2D, kCPU0, 999, &node10);
  NodeSetUp("Node11", kRecv, kCPU0, 999, &node11);
  NodeSetUp("Node12", kConv2D, kCPU1, 1000, &node12);

  // Add Nodes 7 to 9 to manager.
  manager.AddNode(&node7);
  manager.AddNode(&node8);
  manager.AddNode(&node9);

  // When candidate nodes have the same time_ready, the expected order is
  // _Send first, then _Recv, then any other op.
  EXPECT_EQ("Node8", manager.GetCurrNode()->name());
  EXPECT_EQ(kSend, manager.GetCurrNode()->op());
  manager.RemoveCurrNode();
  EXPECT_EQ("Node9", manager.GetCurrNode()->name());
  EXPECT_EQ(kRecv, manager.GetCurrNode()->op());
  manager.RemoveCurrNode();
  EXPECT_EQ("Node7", manager.GetCurrNode()->name());
  EXPECT_EQ(kConv2D, manager.GetCurrNode()->op());
  manager.RemoveCurrNode();
  EXPECT_TRUE(manager.Empty());

  // Add Nodes 7 to 9 to manager, but in a different order.
  manager.AddNode(&node9);
  manager.AddNode(&node8);
  manager.AddNode(&node7);

  // Expect same order (_Send, _Recv, and the other op), regardless of Add
  // order.
  EXPECT_EQ("Node8", manager.GetCurrNode()->name());
  EXPECT_EQ(kSend, manager.GetCurrNode()->op());
  manager.RemoveCurrNode();
  EXPECT_EQ("Node9", manager.GetCurrNode()->name());
  EXPECT_EQ(kRecv, manager.GetCurrNode()->op());
  manager.RemoveCurrNode();
  EXPECT_EQ("Node7", manager.GetCurrNode()->name());
  EXPECT_EQ(kConv2D, manager.GetCurrNode()->op());
  manager.RemoveCurrNode();
  EXPECT_TRUE(manager.Empty());

  // Conv2D's time_ready < Send's time_ready; expect Conv2D first.
  manager.AddNode(&node8);
  manager.AddNode(&node10);
  EXPECT_EQ("Node10", manager.GetCurrNode()->name());
  EXPECT_EQ(kConv2D, manager.GetCurrNode()->op());
  manager.RemoveCurrNode();
  EXPECT_EQ("Node8", manager.GetCurrNode()->name());
  EXPECT_EQ(kSend, manager.GetCurrNode()->op());
  manager.RemoveCurrNode();
  EXPECT_TRUE(manager.Empty());

  // Recv's time_ready < Send's time_ready; expect Recv first.
  manager.AddNode(&node11);
  manager.AddNode(&node8);
  EXPECT_EQ("Node11", manager.GetCurrNode()->name());
  EXPECT_EQ(kRecv, manager.GetCurrNode()->op());
  manager.RemoveCurrNode();
  EXPECT_EQ("Node8", manager.GetCurrNode()->name());
  EXPECT_EQ(kSend, manager.GetCurrNode()->op());
  manager.RemoveCurrNode();
  EXPECT_TRUE(manager.Empty());

  // Node7 and 12 are normal ops with the same time_ready, placed on different
  // devices. These two nodes are added to manager and manager2, but in
  // different orders; expect GetCurrNode() to return the nodes in the same
  // order.
  manager.AddNode(&node7);
  manager.AddNode(&node12);

  manager2.AddNode(&node12);
  manager2.AddNode(&node7);

  EXPECT_EQ(manager.GetCurrNode()->name(), manager2.GetCurrNode()->name());
  manager.RemoveCurrNode();
  manager2.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), manager2.GetCurrNode()->name());
  manager.RemoveCurrNode();
  manager2.RemoveCurrNode();
  // Both managers must be fully drained; previously only `manager` was
  // checked.
  EXPECT_TRUE(manager.Empty());
  EXPECT_TRUE(manager2.Empty());
}
1544
1545 // Create small graph, run predict costs on it, make sure the costs from the
1546 // summary match the hand-calculated costs.
TEST_F(VirtualSchedulerTest, SummaryCostTest) {
  // Run matmul test.
  CreateGrapplerItemWithMatmulChain();
  InitScheduler();
  // Only the aggregate Summary() is inspected, so the per-op map returned by
  // RunScheduler() is not kept.
  RunScheduler("");
  const Costs c = scheduler_->Summary();

  // RandomUniform - 5 * 1s = 5s
  // Matmuls - 4 * 2s = 8s
  // Misc - 5 * 1us
  // Total: 13000005 us
  EXPECT_EQ(13000005, c.execution_time.asMicroSeconds().count());
}
1560
1561 // Like the above SummaryCostTest, but makes sure the stepstats timeline is
1562 // correct.
TEST_F(VirtualSchedulerTest, SummaryCostStepStatsTest) {
  // Run matmul test.
  CreateGrapplerItemWithMatmulChain();
  InitScheduler();
  RunScheduler("");
  RunMetadata metadata;
  const Costs c = scheduler_->Summary(&metadata);
  // Read-only access, so bind by reference instead of copying the proto.
  const StepStats& stepstats = metadata.step_stats();
  EXPECT_EQ(13000005, c.execution_time.asMicroSeconds().count());

  // Should only be 1 device!
  EXPECT_EQ(1, stepstats.dev_stats().size());

  // Create a map of op name -> start and end times (micros).
  std::map<string, std::pair<int64, int64>> start_end_times;
  for (const auto& device_step_stats : stepstats.dev_stats()) {
    for (const auto& stats : device_step_stats.node_stats()) {
      const int64 start = stats.all_start_micros();
      const int64 end = start + stats.all_end_rel_micros();
      start_end_times[stats.node_name()] = std::make_pair(start, end);

      // Make sure that the output properties are correct for
      // MatMul and RandomUniform operations.
      // We only check for dtype, and shape (excluding alloc)
      // since alloc is not set by the virtual scheduler.
      if (stats.timeline_label() == "MatMul" ||
          stats.timeline_label() == "RandomUniform") {
        EXPECT_EQ(1, stats.output().size());
        for (const auto& output : stats.output()) {
          EXPECT_EQ(DT_FLOAT, output.tensor_description().dtype());
          EXPECT_EQ(2, output.tensor_description().shape().dim().size());
          for (const auto& dim : output.tensor_description().shape().dim()) {
            EXPECT_EQ(3200, dim.size());
          }
        }
      }
    }
  }

  // The base start_time is the time to compute RandomUniforms
  int64 cur_time = static_cast<int64>(5000005);
  // The increment is the execution time of one matmul. See
  // CreateGrapplerItemWithMatmulChain for details.
  const int64 increment = static_cast<int64>(2000000);
  auto op_names = {"ab", "abc", "abcd", "abcde"};
  for (const auto& op_name : op_names) {
    // Use find() rather than operator[]: operator[] would silently insert a
    // {0, 0} entry for a missing op and hide the real cause of the failure.
    const auto it = start_end_times.find(op_name);
    ASSERT_TRUE(it != start_end_times.end())
        << "No step stats recorded for " << op_name;
    EXPECT_EQ(cur_time, it->second.first);
    EXPECT_EQ(cur_time + increment, it->second.second);
    cur_time += increment;
  }
}
1618
// Runs the whole Conv2D graph and checks which ops were scheduled, plus the
// number of inputs/outputs recorded for them.
TEST_F(VirtualSchedulerTest, InitAndBasicScheduling) {
  CreateGrapplerItemWithConv2Ds();
  InitScheduler();

  // Run all the nodes.
  auto ops_executed = RunScheduler("");

  // [const and rand] * (x, y, f), and c0 and c1. c2 and z shouldn't be
  // executed.
  EXPECT_EQ(8, ops_executed.size());

  // Ops that must have been scheduled.
  for (const char* name : {"x", "y", "f", "c0", "c1"}) {
    EXPECT_GT(ops_executed.count(name), 0);
  }
  // Ops that must not have been scheduled.
  for (const char* name : {"z", "c2"}) {
    EXPECT_EQ(ops_executed.count(name), 0);
  }

  // Inputs x, y and f each expose a single output; c0 and c1 take two inputs.
  for (const char* name : {"x", "y", "f"}) {
    EXPECT_EQ(1, ops_executed[name].op_info.outputs_size());
  }
  for (const char* name : {"c0", "c1"}) {
    EXPECT_EQ(2, ops_executed[name].op_info.inputs_size());
  }
}
1649
TEST_F(VirtualSchedulerTest, CalculateOutputSize) {
  // Init.
  CreateGrapplerItemWithAddN();
  InitScheduler();

  // Create a set of tensor properties. The trailing comment on each entry is
  // its index, i.e. the port_num used to query it below.
  std::vector<OpInfo::TensorProperties> output;
  output.push_back(ShapeToTensorProperty({4, 4}, DT_FLOAT));           // 0
  output.push_back(ShapeToTensorProperty({1}, DT_FLOAT));              // 1
  output.push_back(ShapeToTensorProperty({10, 10, 10}, DT_HALF));      // 2
  output.push_back(ShapeToTensorProperty({100, 7, 8, 99}, DT_FLOAT));  // 3
  output.push_back(ShapeToTensorProperty({-1, 7, 8, 99}, DT_FLOAT));   // 4
  output.push_back(ShapeToTensorProperty({-1, 7, -1, 99}, DT_FLOAT));  // 5

  // port_num -1 is for control dependency: hard coded 4B.
  EXPECT_EQ(4, scheduler_->CalculateOutputSize(output, -1));

  // Test valid outputs (4 bytes per DT_FLOAT element, 2 per DT_HALF).
  EXPECT_EQ(4 * 4 * 4, scheduler_->CalculateOutputSize(output, 0));
  EXPECT_EQ(4 * 1, scheduler_->CalculateOutputSize(output, 1));
  EXPECT_EQ(2 * 10 * 10 * 10, scheduler_->CalculateOutputSize(output, 2));
  EXPECT_EQ(4 * 100 * 7 * 8 * 99, scheduler_->CalculateOutputSize(output, 3));

  // Any unknown shape (-1) shall yield zero output size.
  EXPECT_EQ(0, scheduler_->CalculateOutputSize(output, 4));
  EXPECT_EQ(0, scheduler_->CalculateOutputSize(output, 5));

  // Invalid port_num (though it may be an error) shall yield zero
  // output size.
  EXPECT_EQ(0, scheduler_->CalculateOutputSize(output, 6));
}
1681
TEST_F(VirtualSchedulerTest, MemoryUsage) {
  // Init.
  CreateGrapplerItemWithAddN();
  InitScheduler();

  // Run the scheduler.
  RunScheduler("");

  const auto* device_states = scheduler_->GetDeviceStates();
  const auto& cpu_state = device_states->at(kCPU0);

  // out node adds 4 tensors, each with 10x10x10x10, so the peak memory usage
  // is 4 x the input tensor size while executing the out node.
  const int64 one_input_node_size = 4 * 10 * 10 * 10 * 10;
  const std::vector<string> expected_names = {"x", "y", "z", "w"};
  // Cast size() so the expected value stays in signed int64 arithmetic;
  // size_t * int64 is unsigned on typical 64-bit platforms and would be
  // compared against max_memory_usage (presumably int64) with mixed signedness.
  EXPECT_EQ(static_cast<int64>(expected_names.size()) * one_input_node_size,
            cpu_state.max_memory_usage);
  ValidateMemoryUsageSnapshot(expected_names, 0 /* port_num_expected */,
                              cpu_state.mem_usage_snapshot_at_peak);
}
1702
TEST_F(VirtualSchedulerTest, ControlDependency) {
  // Init.
  CreateGrapplerItemWithControlDependency();
  InitScheduler();

  // Run the scheduler.
  RunScheduler("");

  const auto* device_states = scheduler_->GetDeviceStates();
  const auto& cpu_state = device_states->at(kCPU0);

  // The graph has a NoOp that takes control dependency from 7 NoOps. The peak
  // memory usage is when executing the final NoOp.
  const int64 one_input_node_size = 4;  // control dependency
  const std::vector<string> expected_names = {"x", "y", "z", "w",
                                              "u", "v", "t"};
  // Cast size() so the expected value stays in signed int64 arithmetic
  // rather than becoming unsigned via size_t * int64.
  EXPECT_EQ(static_cast<int64>(expected_names.size()) * one_input_node_size,
            cpu_state.max_memory_usage);
  ValidateMemoryUsageSnapshot(expected_names, -1 /* port_num_expected */,
                              cpu_state.mem_usage_snapshot_at_peak);
}
1724
TEST_F(VirtualSchedulerTest, ComplexDependency) {
  // Init.
  CreateGrapplerItemWithBatchNorm();
  InitScheduler();

  // Run the scheduler.
  RunScheduler("bn");

  const auto* device_states = scheduler_->GetDeviceStates();
  const auto& cpu_state = device_states->at(kCPU0);

  // The graph is
  //   bn = FusedBatchNorm(x, scale, offset, mean, var)
  //   z1 = bn.y + x
  //   z2 = bn.var + bn.var
  //   z3 = bn.var + bn.var
  //   z4 = control dependency from bn.
  // Note that bn.mean doesn't have any consumer.
  const int x_size = batch_size_ * width_ * height_ * depth_in_;
  int64 expected_size =
      4 * (2 * x_size /* x and bn.y */ + depth_in_ /* bn.var */ +
           1 /* control dependency */);
  EXPECT_EQ(expected_size, cpu_state.memory_usage);

  // Nodes currently in memory: bn's port -1, 0, and 2, and x's port 0.
  std::set<std::pair<string, int>> nodes_in_memory;
  std::transform(
      cpu_state.nodes_in_memory.begin(), cpu_state.nodes_in_memory.end(),
      std::inserter(nodes_in_memory, nodes_in_memory.begin()),
      [](const std::pair<const NodeDef*, int>& node_port) {
        return std::make_pair(node_port.first->name(), node_port.second);
      });
  std::set<std::pair<string, int>> expected = {
      std::make_pair("bn", -1),
      std::make_pair("bn", 0),
      std::make_pair("bn", 2),
      std::make_pair("x", 0),
  };
  ExpectSetEq(expected, nodes_in_memory);

  const auto* node_states = scheduler_->GetNodeStates();
  const NodeState* bn_node = nullptr;
  const NodeState* x_node = nullptr;
  for (const auto& nodedef_node_state : *node_states) {
    const NodeDef* node = nodedef_node_state.first;
    const NodeState& node_state = nodedef_node_state.second;
    if (node->name() == "bn") {
      bn_node = &node_state;
    }
    if (node->name() == "x") {
      x_node = &node_state;
    }
  }
  // Use gtest assertions rather than CHECK_NOTNULL so a missing node fails
  // this test instead of aborting the whole test binary.
  ASSERT_TRUE(bn_node != nullptr);
  ASSERT_TRUE(x_node != nullptr);

  ValidateNodeDefs({"bn", "z1"}, x_node->outputs.at(0));
  ValidateNodeDefs({"z4"}, bn_node->outputs.at(-1));
  ValidateNodeDefs({"z1"}, bn_node->outputs.at(0));
  // z2 and z3 are bn.var + bn.var, so they appear twice in bn's output port 2.
  ValidateNodeDefs({"z2", "z3", "z2", "z3"}, bn_node->outputs.at(2));
}
1787
TEST_F(VirtualSchedulerTest, Variable) {
  CreateGrapplerItemWithConv2DAndVariable();
  InitScheduler();
  RunScheduler("");

  const auto& cpu_state = scheduler_->GetDeviceStates()->at(kCPU0);

  // The graph is a single Conv2D consuming x and f. Since f is a variable, it
  // must be tracked among the persistent nodes.
  ValidateMemoryUsageSnapshot({"f"}, 0 /* port_num_expected */,
                              cpu_state.persistent_nodes);
  // Consequently only x appears in the peak memory usage snapshot.
  ValidateMemoryUsageSnapshot({"x"}, 0 /* port_num_expected */,
                              cpu_state.mem_usage_snapshot_at_peak);
}
1808
TEST_F(VirtualSchedulerTest, WhileLoop) {
  // Init.
  CreateGrapplerItemWithLoop();
  InitScheduler();

  // Run the scheduler.
  RunScheduler("");

  // Check the timeline
  RunMetadata metadata;
  scheduler_->Summary(&metadata);

  // Nodes in topological order:
  // * const, ones
  // * while/Enter, while/Enter_1
  // * while/Merge, while/Merge_1
  // * while/Less/y
  // * while/Less
  // * while/LoopCond
  // * while/Switch, while/Switch_1
  // * while/Identity, while/Identity_1, while/Exit, while/Exit_1
  // * while/add/y, while/concat/axis
  // * while/add, while/concat
  // * while/NextIteration, while/NextIteration_1

  int num_next_iteration = 0;
  int num_next_iteration_1 = 0;
  int num_exit = 0;
  int num_exit_1 = 0;
  // Initialized to a sentinel: if a node never shows up in the step stats, the
  // EXPECT_NE checks below must not read an indeterminate value (UB).
  int64 next_iter_start_micro = -1;
  int64 next_iter_1_start_micro = -1;
  int64 exit_start_micro = -1;
  int64 exit_1_start_micro = -1;

  std::unordered_map<string, int64> start_times;
  for (const auto& device_step_stats : metadata.step_stats().dev_stats()) {
    for (const auto& stats : device_step_stats.node_stats()) {
      start_times[stats.node_name()] = stats.all_start_micros();
      if (stats.node_name() == "while/NextIteration") {
        ++num_next_iteration;
        next_iter_start_micro = stats.all_start_micros();
      } else if (stats.node_name() == "while/NextIteration_1") {
        ++num_next_iteration_1;
        next_iter_1_start_micro = stats.all_start_micros();
      } else if (stats.node_name() == "while/Exit") {
        ++num_exit;
        exit_start_micro = stats.all_start_micros();
      } else if (stats.node_name() == "while/Exit_1") {
        ++num_exit_1;
        exit_1_start_micro = stats.all_start_micros();
      }
    }
  }

  // Make sure we went through the body of the loop once, and that the output
  // of the loop was scheduled as well.
  EXPECT_EQ(1, num_next_iteration);
  EXPECT_EQ(1, num_next_iteration_1);
  EXPECT_EQ(1, num_exit);
  EXPECT_EQ(1, num_exit_1);

  // Start times of while/NextIteration and while/NextIteration_1 should be
  // different, so should be those of while/Exit and while/Exit_1.
  EXPECT_NE(next_iter_start_micro, next_iter_1_start_micro);
  EXPECT_NE(exit_start_micro, exit_1_start_micro);

  // Check dependency among the nodes; no matter what scheduling mechanism we
  // use, the scheduled ops should follow these dependency chains.
  // Note that currently, VirtualScheduler executes while/Merge twice; hence,
  // we're not testing dependency chains related to while/Merge.
  // TODO(dyoon): after fixing while loop behavior correctly (run nodes in the
  // order of Enter, Merge, ...loop condition ..., ... loop body ...,
  // NextIteration, Merge, ... loop condition ..., Exit), re-enable dependency
  // chaining test w/ Merge nodes.
  ValidateDependencyChain(
      start_times,
      {"Const", "while/Enter",  // "while/Merge",
       "while/Less/y", "while/Less", "while/LoopCond", "while/Switch",
       "while/Identity", "while/add/y", "while/add", "while/NextIteration"});
  // ValidateDependencyChain(start_times, {"while/Merge", "while/Less"});
  ValidateDependencyChain(start_times,
                          {"ones", "while/Enter_1",  // "while/Merge_1",
                           "while/Switch_1", "while/Identity_1", "while/concat",
                           "while/NextIteration_1"});
  ValidateDependencyChain(start_times, {"while/Switch", "while/Exit"});
  ValidateDependencyChain(
      start_times, {"while/Identity", "while/concat/axis", "while/concat"});
  ValidateDependencyChain(start_times, {"while/Identity", "while/add"});
  ValidateDependencyChain(start_times, {"while/Switch_1", "while/Exit_1"});
}
1899
TEST_F(VirtualSchedulerTest, InterDeviceTransfer) {
  // Init.
  CreateGrapplerItemWithInterDeviceTransfers();
  InitScheduler();

  // Run the scheduler.
  auto ops_executed = RunScheduler("");

  // Helper lambda to extract the port number encoded in a _Send / _Recv op
  // name: "bn_0" -> 0, "bn_1" -> 1, "bn_2" -> 2, "bn_minus1" -> -1 (used for
  // the control dependency). Returns -999 when none of the markers match.
  auto get_port_num = [](const string& name) -> int {
    if (name.find("bn_0") != string::npos) {
      return 0;
    } else if (name.find("bn_1") != string::npos) {
      return 1;
    } else if (name.find("bn_2") != string::npos) {
      return 2;
    } else if (name.find("bn_minus1") != string::npos) {
      return -1;
    }
    return -999;
  };

  // Reorganize ops_executed for further testing: count executed ops by op
  // type, and index the _Send / _Recv op names by their port number.
  std::unordered_map<string, int> op_count;
  std::unordered_map<int, string> recv_op_names;
  std::unordered_map<int, string> send_op_names;
  for (const auto& x : ops_executed) {
    const auto& name = x.first;
    const auto& op = x.second.op_info.op();
    if (op == kRecv) {
      recv_op_names[get_port_num(name)] = name;
    } else if (op == kSend) {
      send_op_names[get_port_num(name)] = name;
    }
    op_count[op]++;
  }

  // Same number of _Send and _Recv.
  EXPECT_EQ(op_count.at(kSend), op_count.at(kRecv));

  // Expect 4 Send and Recvs each: ports 0, 1, and 2, and control dependency.
  EXPECT_EQ(op_count.at(kRecv), 4);
  EXPECT_EQ(op_count.at(kSend), 4);

  // Helper lambda returning the size of output 0 of an executed op, as
  // computed by the scheduler. ops_executed outlives this lambda, so it is
  // captured by reference; a by-value capture would copy the whole map.
  auto get_output_size = [this, &ops_executed](const string& name) -> int64 {
    const auto& outputs = ops_executed.at(name).op_info.outputs();
    const std::vector<OpInfo::TensorProperties> output_properties(
        outputs.begin(), outputs.end());
    return scheduler_->CalculateOutputSize(output_properties, 0);
  };

  // Validate transfer size.
  // Batchnorm output y is a 4D tensor: batch x width x height x depth, with
  // 4 bytes per element. Computed in int64 to match get_output_size().
  const int64 input_size =
      int64{4} * batch_size_ * width_ * height_ * depth_in_;
  EXPECT_EQ(get_output_size(recv_op_names[0]), input_size);
  EXPECT_EQ(get_output_size(send_op_names[0]), input_size);
  // Mean and vars are 1-D vectors with size depth_in_.
  EXPECT_EQ(get_output_size(recv_op_names[1]), 4 * depth_in_);
  EXPECT_EQ(get_output_size(send_op_names[1]), 4 * depth_in_);
  EXPECT_EQ(get_output_size(recv_op_names[2]), 4 * depth_in_);
  EXPECT_EQ(get_output_size(send_op_names[2]), 4 * depth_in_);
  // Control dependency size is 4B.
  EXPECT_EQ(get_output_size(recv_op_names[-1]), 4);
  EXPECT_EQ(get_output_size(send_op_names[-1]), 4);
}
1969
TEST_F(VirtualSchedulerTest, GraphWithSendRecv) {
  // Build a graph containing an explicit Send/Recv pair and schedule it.
  CreateGrapplerItemWithSendRecv();
  InitScheduler();
  const auto ops_executed = RunScheduler("");

  // Every node of the input graph must show up among the executed ops.
  for (const char* expected_op : {"Const", "Send", "Recv"}) {
    EXPECT_GT(ops_executed.count(expected_op), 0) << expected_op;
  }
}
1982
TEST_F(VirtualSchedulerTest, GraphWithSendRecvDifferentDevice) {
  // Build the Send/Recv graph, then move the Recv node onto a second CPU so
  // that the Send and Recv endpoints live on different devices.
  CreateGrapplerItemWithSendRecv();
  const string recv_device = kCPU1;
  for (auto& node : *grappler_item_->graph.mutable_node()) {
    const bool is_recv = node.name() == "Recv";
    const bool is_send = node.name() == "Send";
    if (!is_recv && !is_send) continue;
    if (is_recv) node.set_device(recv_device);
    // Both endpoints have to agree on the receiving device.
    (*node.mutable_attr())["recv_device"].set_s(recv_device);
  }
  InitScheduler();

  const auto ops_executed = RunScheduler("");

  // Besides the original Const/Send/Recv, the scheduler is expected to insert
  // its own Send and Recv ops for the cross-device edge.
  const string scheduler_send =
      "Send_Send_0_from_/job_localhost/replica_0/task_0/cpu_0"
      "_to_/job_localhost/replica_0/task_0/cpu_1";
  const string scheduler_recv =
      "Recv_Send_0_on_/job_localhost/replica_0/task_0/cpu_1";
  for (const string& expected_op :
       {string("Const"), string("Send"), scheduler_send, scheduler_recv,
        string("Recv")}) {
    EXPECT_GT(ops_executed.count(expected_op), 0) << expected_op;
  }
}
2018 } // end namespace grappler
2019 } // end namespace tensorflow
2020