1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/costs/virtual_scheduler.h"
17 
18 #include "tensorflow/cc/ops/standard_ops.h"
19 #include "tensorflow/core/framework/tensor_description.pb.h"
20 #include "tensorflow/core/framework/tensor_shape.pb.h"
21 #include "tensorflow/core/grappler/clusters/virtual_cluster.h"
22 #include "tensorflow/core/grappler/costs/utils.h"
23 #include "tensorflow/core/grappler/costs/virtual_placer.h"
24 #include "tensorflow/core/lib/core/status_test_util.h"
25 #include "tensorflow/core/platform/test.h"
26 
27 namespace tensorflow {
28 namespace grappler {
29 namespace {
30 
31 // Device names:
32 constexpr char kCPU0[] = "/job:localhost/replica:0/task:0/cpu:0";
33 constexpr char kCPU1[] = "/job:localhost/replica:0/task:0/cpu:1";
34 constexpr char kChannelFrom0To1[] = "Channel from CPU0 to CPU1";
35 constexpr char kChannelFrom1To0[] = "Channel from CPU1 to CPU0";
36 // Op names:
37 constexpr char kConv2D[] = "Conv2D";
38 constexpr char kSend[] = "_Send";
39 constexpr char kRecv[] = "_Recv";
40 
// Test fixture shared by the ReadyNodeManager tests (FIFO, LIFO, FirstReady,
// Priority, and Composite managers). Prepares six Conv2D nodes on kCPU0 whose
// time_ready values are in reverse order of their names.
class ReadyNodeManagerTest : public ::testing::Test {
 protected:
  ReadyNodeManagerTest() {
    // node1_ to node6_ on kCPU0, with time_ready in reverse_order.
    NodeSetUp("Node1", kConv2D, kCPU0, 6000, &node1_);
    NodeSetUp("Node2", kConv2D, kCPU0, 5000, &node2_);
    NodeSetUp("Node3", kConv2D, kCPU0, 4000, &node3_);
    NodeSetUp("Node4", kConv2D, kCPU0, 3000, &node4_);
    NodeSetUp("Node5", kConv2D, kCPU0, 2000, &node5_);
    NodeSetUp("Node6", kConv2D, kCPU0, 1000, &node6_);
  }

  // Initializes `node` with the given name/op/device and records a NodeState
  // for it (keyed by node pointer); time_ready is what ready-node managers
  // use to order nodes.
  void NodeSetUp(const string& name, const string& op_name,
                 const string& device_name, const uint64 time_ready,
                 NodeDef* node) {
    node->set_name(name);
    node->set_op(op_name);
    node->set_device(device_name);

    node_states_[node] = NodeState();
    node_states_[node].time_ready = time_ready;
    node_states_[node].device_name = device_name;
  }

  NodeDef node1_, node2_, node3_, node4_, node5_, node6_;
  // Per-node scheduling state; passed to managers via Init().
  std::unordered_map<const NodeDef*, NodeState> node_states_;
};
68 
69 // Tests that FIFOManager correctly returns the current node with only 1 node.
TEST_F(ReadyNodeManagerTest,GetSingleNodeFIFOManager)70 TEST_F(ReadyNodeManagerTest, GetSingleNodeFIFOManager) {
71   FIFOManager manager = FIFOManager();
72   manager.AddNode(&node1_);
73   EXPECT_EQ(manager.GetCurrNode()->name(), "Node1");
74 }
75 
76 // Tests that FIFOManager removes the only node contained within.
TEST_F(ReadyNodeManagerTest,RemoveSingleNodeFIFOManager)77 TEST_F(ReadyNodeManagerTest, RemoveSingleNodeFIFOManager) {
78   FIFOManager manager = FIFOManager();
79   manager.AddNode(&node1_);
80 
81   // Removes the only node in FIFOManager.
82   manager.RemoveCurrNode();
83   EXPECT_TRUE(manager.Empty());
84 }
85 
86 // Tests that FIFOManager can remove multiple nodes and returns the current node
87 // in the right order.
TEST_F(ReadyNodeManagerTest,GetAndRemoveMultipleFIFOManager)88 TEST_F(ReadyNodeManagerTest, GetAndRemoveMultipleFIFOManager) {
89   FIFOManager manager = FIFOManager();
90   manager.AddNode(&node1_);
91   manager.AddNode(&node2_);
92   manager.AddNode(&node3_);
93   manager.AddNode(&node4_);
94 
95   // Keeps checking current node while removing nodes from manager.
96   EXPECT_EQ(manager.GetCurrNode()->name(), "Node1");
97   manager.RemoveCurrNode();
98   EXPECT_EQ(manager.GetCurrNode()->name(), "Node2");
99   manager.RemoveCurrNode();
100   EXPECT_EQ(manager.GetCurrNode()->name(), "Node3");
101   manager.RemoveCurrNode();
102   EXPECT_EQ(manager.GetCurrNode()->name(), "Node4");
103   manager.RemoveCurrNode();
104   EXPECT_TRUE(manager.Empty());
105 }
106 
107 // Tests that FIFOManager can remove multiple nodes and add more nodes, still
108 // returning the current node in the right order.
TEST_F(ReadyNodeManagerTest,AddAndRemoveMultipleFIFOManager)109 TEST_F(ReadyNodeManagerTest, AddAndRemoveMultipleFIFOManager) {
110   FIFOManager manager = FIFOManager();
111   manager.AddNode(&node1_);
112   manager.AddNode(&node2_);
113   manager.AddNode(&node3_);
114   manager.AddNode(&node4_);
115 
116   // Keeps checking current node as nodes are removed and added.
117   EXPECT_EQ(manager.GetCurrNode()->name(), "Node1");
118   manager.RemoveCurrNode();
119   EXPECT_EQ(manager.GetCurrNode()->name(), "Node2");
120   manager.AddNode(&node5_);
121   // GetCurrNode() should return the same node even if some nodes are added,
122   // until RemoveCurrNode() is called.
123   EXPECT_EQ(manager.GetCurrNode()->name(), "Node2");
124   manager.RemoveCurrNode();
125   EXPECT_EQ(manager.GetCurrNode()->name(), "Node3");
126   manager.RemoveCurrNode();
127   EXPECT_EQ(manager.GetCurrNode()->name(), "Node4");
128   manager.AddNode(&node6_);
129   EXPECT_EQ(manager.GetCurrNode()->name(), "Node4");
130   manager.RemoveCurrNode();
131   EXPECT_EQ(manager.GetCurrNode()->name(), "Node5");
132   manager.RemoveCurrNode();
133   EXPECT_EQ(manager.GetCurrNode()->name(), "Node6");
134   manager.RemoveCurrNode();
135   EXPECT_TRUE(manager.Empty());
136 }
137 
138 // Tests that LIFOManager correctly returns the current node with only 1 node.
TEST_F(ReadyNodeManagerTest,GetSingleNodeLIFOManager)139 TEST_F(ReadyNodeManagerTest, GetSingleNodeLIFOManager) {
140   LIFOManager manager = LIFOManager();
141   manager.AddNode(&node1_);
142   EXPECT_EQ(manager.GetCurrNode()->name(), "Node1");
143 }
144 
145 // Tests that LIFOManager removes the only node contained within.
TEST_F(ReadyNodeManagerTest,RemoveSingleNodeLIFOManager)146 TEST_F(ReadyNodeManagerTest, RemoveSingleNodeLIFOManager) {
147   LIFOManager manager = LIFOManager();
148   manager.AddNode(&node1_);
149 
150   // Removes the only node in LIFOManager.
151   manager.RemoveCurrNode();
152   EXPECT_TRUE(manager.Empty());
153 }
154 
155 // Tests that LIFOManager can remove multiple nodes and returns the current node
156 // in the right order.
TEST_F(ReadyNodeManagerTest,GetAndRemoveMultipleLIFOManager)157 TEST_F(ReadyNodeManagerTest, GetAndRemoveMultipleLIFOManager) {
158   LIFOManager manager = LIFOManager();
159   manager.AddNode(&node1_);
160   manager.AddNode(&node2_);
161   manager.AddNode(&node3_);
162   manager.AddNode(&node4_);
163 
164   // Keeps checking current node while removing nodes from manager.
165   EXPECT_EQ(manager.GetCurrNode()->name(), "Node4");
166   manager.RemoveCurrNode();
167   EXPECT_EQ(manager.GetCurrNode()->name(), "Node3");
168   manager.RemoveCurrNode();
169   EXPECT_EQ(manager.GetCurrNode()->name(), "Node2");
170   manager.RemoveCurrNode();
171   EXPECT_EQ(manager.GetCurrNode()->name(), "Node1");
172   manager.RemoveCurrNode();
173   EXPECT_TRUE(manager.Empty());
174 }
175 
176 // Tests that LIFOManager can remove multiple nodes (must be removing the
177 // current node) and add more nodes, still returning the current node in the
178 // right order.
TEST_F(ReadyNodeManagerTest,AddAndRemoveMultipleLIFOManager)179 TEST_F(ReadyNodeManagerTest, AddAndRemoveMultipleLIFOManager) {
180   LIFOManager manager = LIFOManager();
181   manager.AddNode(&node1_);
182   manager.AddNode(&node2_);
183   manager.AddNode(&node3_);
184   manager.AddNode(&node4_);
185 
186   // Keeps checking current node as nodes are removed and added.
187   EXPECT_EQ(manager.GetCurrNode()->name(), "Node4");
188   manager.RemoveCurrNode();
189   EXPECT_EQ(manager.GetCurrNode()->name(), "Node3");
190   manager.AddNode(&node5_);
191   // GetCurrNode()  should return the same node even if some nodes are added,
192   // until RemoveCurrNode() is called.
193   EXPECT_EQ(manager.GetCurrNode()->name(), "Node3");
194   manager.RemoveCurrNode();
195   EXPECT_EQ(manager.GetCurrNode()->name(), "Node5");
196   manager.RemoveCurrNode();
197   EXPECT_EQ(manager.GetCurrNode()->name(), "Node2");
198   manager.AddNode(&node6_);
199   EXPECT_EQ(manager.GetCurrNode()->name(), "Node2");
200   manager.RemoveCurrNode();
201   EXPECT_EQ(manager.GetCurrNode()->name(), "Node6");
202   manager.RemoveCurrNode();
203   EXPECT_EQ(manager.GetCurrNode()->name(), "Node1");
204   manager.RemoveCurrNode();
205   EXPECT_TRUE(manager.Empty());
206 }
207 
// Tests that LIFOManager defers a Merge node to the end of the schedule, even
// though it was not the first node added.
TEST_F(ReadyNodeManagerTest, MergeOrderInLIFOManager) {
  // Default-construct directly; copy-initializing from a temporary is
  // redundant.
  LIFOManager manager;
  node3_.set_op("Merge");
  manager.AddNode(&node1_);
  manager.AddNode(&node2_);
  manager.AddNode(&node3_);
  manager.AddNode(&node4_);

  // Merge node (node3) will be scheduled at the end (even though it's added
  // after node2).
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node4");
  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node2");
  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node1");
  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node3");
  manager.RemoveCurrNode();
  // All nodes, including the deferred Merge, should now be consumed.
  EXPECT_TRUE(manager.Empty());
}
227 
// Tests that FirstReadyManager correctly returns the current node with only
// one node.
TEST_F(ReadyNodeManagerTest, GetSingleNodeFirstReadyManager) {
  FirstReadyManager manager;
  TF_EXPECT_OK(manager.Init(&node_states_));
  manager.AddNode(&node1_);
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node1");
}
234 
// Tests that FirstReadyManager removes the only node contained within.
TEST_F(ReadyNodeManagerTest, RemoveSingleNodeFirstReadyManager) {
  FirstReadyManager manager;
  TF_EXPECT_OK(manager.Init(&node_states_));
  manager.AddNode(&node1_);
  manager.RemoveCurrNode();
  EXPECT_TRUE(manager.Empty());
}
242 
// Tests that FirstReadyManager returns nodes in increasing time_ready order,
// regardless of insertion order.
TEST_F(ReadyNodeManagerTest, GetAndRemoveMultipleFirstReadyManager) {
  FirstReadyManager manager;
  TF_EXPECT_OK(manager.Init(&node_states_));
  // Inserts nodes in some random order.
  manager.AddNode(&node2_);
  manager.AddNode(&node1_);
  manager.AddNode(&node4_);
  manager.AddNode(&node5_);
  manager.AddNode(&node3_);
  manager.AddNode(&node6_);

  // In whatever order we insert nodes, we get the same order based on nodes'
  // time_ready.
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node6");
  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node5");
  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node4");
  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node3");
  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node2");
  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node1");
  manager.RemoveCurrNode();
  EXPECT_TRUE(manager.Empty());
}
270 
// Tests that FirstReadyManager keeps returning the same current node across
// AddNode() calls, even when the newly added nodes have smaller time_ready,
// and only advances on RemoveCurrNode().
TEST_F(ReadyNodeManagerTest, GetCurrNodeFirstReadyManager) {
  FirstReadyManager manager;
  TF_EXPECT_OK(manager.Init(&node_states_));

  // Inserts nodes in some random order.
  manager.AddNode(&node2_);
  manager.AddNode(&node1_);
  manager.AddNode(&node4_);
  manager.AddNode(&node5_);
  manager.AddNode(&node3_);
  manager.AddNode(&node6_);

  // Among these nodes, node6 has the smallest time_ready, hence, GetCurrNode()
  // should return it. (Argument order normalized to match the rest of this
  // file: actual first, expected second.)
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node6");

  // Now inserts a few other nodes, but their time_ready's are even smaller than
  // that of Node6. Before calling RemoveCurrNode(), GetCurrNode() should return
  // the same node, Node6, in this case.
  NodeDef node7;
  NodeDef node8;
  NodeDef node9;
  NodeSetUp("Node7", kConv2D, kCPU0, 5, &node7);
  NodeSetUp("Node8", kConv2D, kCPU0, 4, &node8);
  NodeSetUp("Node9", kConv2D, kCPU0, 3, &node9);

  manager.AddNode(&node7);
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node6");

  manager.AddNode(&node8);
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node6");

  manager.RemoveCurrNode();
  // Now Node6 is removed, and GetCurrNode() will return Node8.
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node8");

  // Again, AddNode shouldn't change GetCurrNode().
  manager.AddNode(&node9);
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node8");

  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node9");
  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node7");
  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node5");
  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node4");
  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node3");
  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node2");
  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node1");
  manager.RemoveCurrNode();
  EXPECT_TRUE(manager.Empty());
}
328 
TEST_F(ReadyNodeManagerTest,DeterminismInFirstReadyManager)329 TEST_F(ReadyNodeManagerTest, DeterminismInFirstReadyManager) {
330   FirstReadyManager manager1;
331   TF_EXPECT_OK(manager1.Init(&node_states_));
332   FirstReadyManager manager2;
333   TF_EXPECT_OK(manager2.Init(&node_states_));
334 
335   // 6 nodes with same time_ready.
336   NodeDef node7;
337   NodeDef node8;
338   NodeDef node9;
339   NodeDef node10;
340   NodeDef node11;
341   NodeDef node12;
342   NodeSetUp("Node7", kConv2D, kCPU0, 1000, &node7);
343   NodeSetUp("Node8", kConv2D, kCPU0, 1000, &node8);
344   NodeSetUp("Node9", kConv2D, kCPU0, 1000, &node9);
345   NodeSetUp("Node10", kConv2D, kCPU0, 1000, &node10);
346   NodeSetUp("Node11", kConv2D, kCPU0, 1000, &node11);
347   NodeSetUp("Node12", kConv2D, kCPU0, 1000, &node12);
348 
349   // Adds the above 6 nodes to manager1.
350   manager1.AddNode(&node7);
351   manager1.AddNode(&node8);
352   manager1.AddNode(&node9);
353   manager1.AddNode(&node10);
354   manager1.AddNode(&node11);
355   manager1.AddNode(&node12);
356 
357   // Adds the above 6 nodes to manager2, but in a different order.
358   manager2.AddNode(&node8);
359   manager2.AddNode(&node11);
360   manager2.AddNode(&node9);
361   manager2.AddNode(&node10);
362   manager2.AddNode(&node7);
363   manager2.AddNode(&node12);
364 
365   // Expects both managers return the same nodes for deterministic node
366   // scheduling.
367   EXPECT_EQ(manager1.GetCurrNode()->name(), manager2.GetCurrNode()->name());
368   manager1.RemoveCurrNode();
369   manager2.RemoveCurrNode();
370 
371   EXPECT_EQ(manager1.GetCurrNode()->name(), manager2.GetCurrNode()->name());
372   manager1.RemoveCurrNode();
373   manager2.RemoveCurrNode();
374 
375   EXPECT_EQ(manager1.GetCurrNode()->name(), manager2.GetCurrNode()->name());
376   manager1.RemoveCurrNode();
377   manager2.RemoveCurrNode();
378 
379   EXPECT_EQ(manager1.GetCurrNode()->name(), manager2.GetCurrNode()->name());
380   manager1.RemoveCurrNode();
381   manager2.RemoveCurrNode();
382 
383   EXPECT_EQ(manager1.GetCurrNode()->name(), manager2.GetCurrNode()->name());
384   manager1.RemoveCurrNode();
385   manager2.RemoveCurrNode();
386 
387   EXPECT_EQ(manager1.GetCurrNode()->name(), manager2.GetCurrNode()->name());
388   manager1.RemoveCurrNode();
389   manager2.RemoveCurrNode();
390 
391   EXPECT_TRUE(manager1.Empty());
392   EXPECT_TRUE(manager2.Empty());
393 }
394 
// Tests that PriorityReadyManager schedules nodes by ascending priority value,
// with missing nodes defaulting to the lowest (first-scheduled) priority and
// equal-priority ties broken by time_ready.
TEST_F(ReadyNodeManagerTest, GetAndRemoveMultiplePriorityReadyManager) {
  PriorityReadyManager manager;
  TF_EXPECT_OK(manager.Init(&node_states_));

  // Sets up node priorities.
  std::unordered_map<string, int> node_priority = {
      {"Node1", 1}, {"Node2", 2}, {"Node3", 2}, {"Node4", 4}, {"Node5", 5}};
  TF_EXPECT_OK(manager.SetPriority(node_priority));

  // Inserts nodes in some random order.
  manager.AddNode(&node3_);
  manager.AddNode(&node1_);
  manager.AddNode(&node4_);
  manager.AddNode(&node5_);
  manager.AddNode(&node2_);
  manager.AddNode(&node6_);

  // Expects nodes scheduled based on priority.
  // Node6 should default to lowest priority, since it is not found.
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node6");
  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node1");
  manager.RemoveCurrNode();
  // Nodes 2 and 3 have equal priority and so should be scheduled ready-first.
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node3");
  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node2");
  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node4");
  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node5");
  manager.RemoveCurrNode();
  EXPECT_TRUE(manager.Empty());
}
429 
// Tests that CompositeNodeManager removes the only node contained within.
TEST_F(ReadyNodeManagerTest, RemoveSingleNodeCompositeNodeManager) {
  CompositeNodeManager manager;
  TF_EXPECT_OK(manager.Init(&node_states_));
  manager.AddNode(&node1_);
  manager.RemoveCurrNode();
  EXPECT_TRUE(manager.Empty());
}
437 
// Tests that CompositeNodeManager schedules same-device normal ops LIFO,
// keeping the current node stable across AddNode() calls.
TEST_F(ReadyNodeManagerTest, GetAndRemoveMultipleCompositeNodeManager) {
  CompositeNodeManager manager;
  TF_EXPECT_OK(manager.Init(&node_states_));
  manager.AddNode(&node1_);
  manager.AddNode(&node2_);
  manager.AddNode(&node3_);
  manager.AddNode(&node4_);

  // Keeps checking current node as nodes are removed and added.
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node4");
  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node3");
  manager.AddNode(&node5_);
  // GetCurrNode() should return the same node even if some nodes are added,
  // until RemoveCurrNode() is called.
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node3");
  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node5");
  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node2");
  manager.AddNode(&node6_);
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node2");
  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node6");
  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node1");
  manager.RemoveCurrNode();
  EXPECT_TRUE(manager.Empty());
}
467 
// Tests CompositeNodeManager across two devices plus Send/Recv channels: per
// device, normal ops are LIFO; across devices and channels, the candidate with
// the earliest time_ready wins.
TEST_F(ReadyNodeManagerTest, MultiDeviceSendRecvCompositeNodeManager) {
  CompositeNodeManager manager;
  TF_EXPECT_OK(manager.Init(&node_states_));
  // Additional nodes on kCPU1.
  NodeDef node7;
  NodeDef node8;
  NodeDef node9;
  NodeSetUp("Node7", kConv2D, kCPU1, 1001, &node7);
  NodeSetUp("Node8", kConv2D, kCPU1, 2001, &node8);
  NodeSetUp("Node9", kConv2D, kCPU1, 3001, &node9);

  // Send and Recv nodes.
  NodeDef send1;
  NodeDef send2;
  NodeDef recv1;
  NodeDef recv2;
  NodeSetUp("Send1", kSend, kChannelFrom0To1, 2002, &send1);
  NodeSetUp("Send2", kSend, kChannelFrom1To0, 2005, &send2);
  NodeSetUp("Recv1", kRecv, kCPU0, 2003, &recv1);
  NodeSetUp("Recv2", kRecv, kCPU1, 2004, &recv2);

  // Inserts nodes.
  manager.AddNode(&node1_);
  manager.AddNode(&node2_);
  manager.AddNode(&node3_);
  manager.AddNode(&node4_);
  manager.AddNode(&node5_);
  manager.AddNode(&node6_);
  manager.AddNode(&node7);
  manager.AddNode(&node8);
  manager.AddNode(&node9);
  manager.AddNode(&send1);
  manager.AddNode(&send2);
  manager.AddNode(&recv1);
  manager.AddNode(&recv2);

  // On kCPU0; last one is node6_, on kCPU1: last one is node9;
  // so choose one that has earliest time_ready among node6_, node9,
  // Send1, Send2, Recv1, and Recv2.
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node6");
  manager.RemoveCurrNode();
  // Then, the next one on kCPU0 is node5_; choose the earliest time_ready node
  // among node5_, node9, Send1, Send2, Recv1, and Recv2.
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node5");
  manager.RemoveCurrNode();
  // Next, choose among node4_, node9, Send1, Send2, Recv1, and Recv2.
  EXPECT_EQ(manager.GetCurrNode()->name(), "Send1");
  manager.RemoveCurrNode();
  // Next, choose among node4_, node9, Send2, Recv1, and Recv2.
  EXPECT_EQ(manager.GetCurrNode()->name(), "Recv1");
  manager.RemoveCurrNode();
  // Next, choose among node4_, node9, Send2, and Recv2.
  EXPECT_EQ(manager.GetCurrNode()->name(), "Recv2");
  manager.RemoveCurrNode();
  // Next, choose among node4_, node9, and Send2.
  EXPECT_EQ(manager.GetCurrNode()->name(), "Send2");
  manager.RemoveCurrNode();
  // Next, choose between node4_, node9.
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node4");
  manager.RemoveCurrNode();
  // Next, choose between node3_, node9.
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node9");
  manager.RemoveCurrNode();
  // Next, choose between node3_, node8.
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node8");
  manager.RemoveCurrNode();
  // Next, choose between node3_, node7.
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node7");
  manager.RemoveCurrNode();
  // Then, just the nodes on kCPU0 -- LIFO. (Node3/Node2/Node1 live on kCPU0;
  // the original comment said kCPU1, which was a typo.)
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node3");
  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node2");
  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node1");
  manager.RemoveCurrNode();
  EXPECT_TRUE(manager.Empty());
}
546 
// Tests that CompositeNodeManager is deterministic: ties on time_ready are
// broken by op type (_Send before _Recv before other ops), and insertion
// order never changes the schedule.
TEST_F(ReadyNodeManagerTest, DeterminismInCompositeNodeManager) {
  CompositeNodeManager manager;
  TF_EXPECT_OK(manager.Init(&node_states_));
  CompositeNodeManager manager2;
  TF_EXPECT_OK(manager2.Init(&node_states_));

  // 6 nodes with same time_ready.
  NodeDef node7;
  NodeDef node8;
  NodeDef node9;
  NodeDef node10;
  NodeDef node11;
  NodeDef node12;
  NodeSetUp("Node7", kConv2D, kCPU0, 1000, &node7);
  NodeSetUp("Node8", kSend, kCPU0, 1000, &node8);
  NodeSetUp("Node9", kRecv, kCPU0, 1000, &node9);
  NodeSetUp("Node10", kConv2D, kCPU0, 999, &node10);
  NodeSetUp("Node11", kRecv, kCPU0, 999, &node11);
  NodeSetUp("Node12", kConv2D, kCPU1, 1000, &node12);

  // Adds Nodes 7 to 9 to manager.
  manager.AddNode(&node7);
  manager.AddNode(&node8);
  manager.AddNode(&node9);

  // It should return in _Send, _Recv, and the other op order, when the
  // candidate nodes have the same time_ready.
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node8");
  EXPECT_EQ(manager.GetCurrNode()->op(), kSend);
  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node9");
  EXPECT_EQ(manager.GetCurrNode()->op(), kRecv);
  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node7");
  EXPECT_EQ(manager.GetCurrNode()->op(), kConv2D);
  manager.RemoveCurrNode();
  EXPECT_TRUE(manager.Empty());

  // Adds Nodes 7 to 9 to manager, but in a different order.
  manager.AddNode(&node9);
  manager.AddNode(&node8);
  manager.AddNode(&node7);

  // Expects same order (_Send, _Recv, and the other op), regardless of Add
  // order.
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node8");
  EXPECT_EQ(manager.GetCurrNode()->op(), kSend);
  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node9");
  EXPECT_EQ(manager.GetCurrNode()->op(), kRecv);
  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node7");
  EXPECT_EQ(manager.GetCurrNode()->op(), kConv2D);
  manager.RemoveCurrNode();
  EXPECT_TRUE(manager.Empty());

  // Conv2D's time_ready < Send's time_ready; Expects Conv2D first.
  manager.AddNode(&node8);
  manager.AddNode(&node10);
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node10");
  EXPECT_EQ(manager.GetCurrNode()->op(), kConv2D);
  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node8");
  EXPECT_EQ(manager.GetCurrNode()->op(), kSend);
  manager.RemoveCurrNode();
  EXPECT_TRUE(manager.Empty());

  // Recv's time_ready < Send's time_ready; Expects Recv first.
  manager.AddNode(&node11);
  manager.AddNode(&node8);
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node11");
  EXPECT_EQ(manager.GetCurrNode()->op(), kRecv);
  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node8");
  EXPECT_EQ(manager.GetCurrNode()->op(), kSend);
  manager.RemoveCurrNode();
  EXPECT_TRUE(manager.Empty());

  // Node7 and 12 are normal ops with the same time_ready, placed on different
  // devices. These two nodes are added to manager and manager2, but in
  // different orders; Expects GetCurrNode() returns the nodes in the same
  // order.
  manager.AddNode(&node7);
  manager.AddNode(&node12);

  manager2.AddNode(&node12);
  manager2.AddNode(&node7);

  EXPECT_EQ(manager.GetCurrNode()->name(), manager2.GetCurrNode()->name());
  manager.RemoveCurrNode();
  manager2.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), manager2.GetCurrNode()->name());
  manager.RemoveCurrNode();
  manager2.RemoveCurrNode();
  EXPECT_TRUE(manager.Empty());
}
643 
// Class for testing virtual scheduler.
class TestVirtualScheduler : public VirtualScheduler {
 public:
  // Constructs a VirtualScheduler backed by a VirtualPlacer built from the
  // cluster's devices, then enables memory-usage tracking so tests can
  // inspect per-device memory state.
  TestVirtualScheduler(const bool use_static_shapes,
                       const bool use_aggressive_shape_inference,
                       ReadyNodeManager* ready_node_manager, Cluster* cluster)
      : VirtualScheduler(
            use_static_shapes, use_aggressive_shape_inference, cluster,
            ready_node_manager,
            absl::make_unique<VirtualPlacer>(cluster->GetDevices())) {
    enable_mem_usage_tracking();
  }

  // These tests need access to protected VirtualScheduler internals.
  FRIEND_TEST(VirtualSchedulerTest, MemoryUsage);
  FRIEND_TEST(VirtualSchedulerTest, ControlDependency);
  FRIEND_TEST(VirtualSchedulerTest, ComplexDependency);
  FRIEND_TEST(VirtualSchedulerTest, Variable);
  FRIEND_TEST(VirtualSchedulerTest, InterDeviceTransfer);
};
663 
664 class VirtualSchedulerTest : public ::testing::Test {
665  protected:
  // Initializes cluster_ and scheduler_: a two-CPU VirtualCluster (both
  // devices share the same dummy properties) driving a TestVirtualScheduler
  // that uses first_ready_manager_ with static shapes and aggressive shape
  // inference enabled.
  VirtualSchedulerTest() {
    std::unordered_map<string, DeviceProperties> devices;

    // Set some dummy CPU properties
    DeviceProperties cpu_device = GetDummyCPUDevice();

    // IMPORTANT: Device is not actually ever used in the test case since
    // force_cpu_type is defaulted to "Haswell"
    devices[kCPU0] = cpu_device;
    devices[kCPU1] = cpu_device;
    cluster_ = absl::make_unique<VirtualCluster>(devices);
    scheduler_ = absl::make_unique<TestVirtualScheduler>(
        /*use_static_shapes=*/true,
        /*use_aggressive_shape_inference=*/true, &first_ready_manager_,
        cluster_.get());
  }
683 
  // Returns a DeviceProperties describing the dummy CPU used by the fixture.
  DeviceProperties GetDummyCPUDevice() {
    // Create CPU with 2 cores, 4 Ghz freq, 2 GB/s mem bandwidth.
    // - 8 Gflops
    // - 2 GB/s
    DeviceProperties cpu_device;
    cpu_device.set_type("CPU");
    cpu_device.set_frequency(4000);  // presumably MHz — 4000 => 4 GHz
    cpu_device.set_num_cores(2);
    // NOTE(review): 2000000 with the "2 GB/s" note above implies bandwidth is
    // expressed in KB/s — confirm against the DeviceProperties proto.
    cpu_device.set_bandwidth(2000000);
    return cpu_device;
  }
695 
  // Three Conv2Ds with only two in fetch nodes.
  // Builds grappler_item_ with c0/c1/c2 sharing one filter `f`; only c0 and c1
  // are fetched, so c2 (and its input z) fall outside the fetch fan-in.
  void CreateGrapplerItemWithConv2Ds() {
    Scope s = Scope::NewRootScope().WithDevice(kCPU0);
    auto x = ops::RandomUniform(
        s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
    auto y = ops::RandomUniform(
        s.WithOpName("y"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
    auto z = ops::RandomUniform(
        s.WithOpName("z"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
    auto f = ops::RandomUniform(
        s.WithOpName("f"), {kernel_, kernel_, depth_in_, depth_out_}, DT_FLOAT);
    std::vector<int> strides = {1, 1, 1, 1};
    auto c0 = ops::Conv2D(s.WithOpName("c0"), x, f, strides, "SAME");
    auto c1 = ops::Conv2D(s.WithOpName("c1"), y, f, strides, "SAME");
    auto c2 = ops::Conv2D(s.WithOpName("c2"), z, f, strides, "SAME");

    grappler_item_ = absl::make_unique<GrapplerItem>();
    TF_CHECK_OK(s.ToGraphDef(&grappler_item_->graph));
    grappler_item_->id = "test_conv2d_graph";
    grappler_item_->fetch = {"c0", "c1"};

    // Expected input dependencies of each fetched node.
    dependency_["c0"] = {"x", "f"};
    dependency_["c1"] = {"y", "f"};
  }
720 
  // A Conv2D with a variable.
  // Builds grappler_item_ where the filter `f` is a Variable rather than a
  // generated tensor; fetches the single Conv2D output `y`.
  void CreateGrapplerItemWithConv2DAndVariable() {
    Scope s = Scope::NewRootScope().WithDevice(kCPU0);
    auto x = ops::RandomUniform(
        s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
    auto f = ops::Variable(s.WithOpName("f"),
                           {kernel_, kernel_, depth_in_, depth_out_}, DT_FLOAT);
    std::vector<int> strides = {1, 1, 1, 1};
    auto y = ops::Conv2D(s.WithOpName("y"), x, f, strides, "SAME");

    grappler_item_ = absl::make_unique<GrapplerItem>();
    TF_CHECK_OK(s.ToGraphDef(&grappler_item_->graph));
    grappler_item_->id = "test_conv2d_var_graph";

    grappler_item_->fetch = {"y"};

    // Expected input dependencies of the fetched node.
    dependency_["y"] = {"x", "f"};
  }
739 
  // Builds grappler_item_ with a chain of four MatMuls (ab -> abc -> abcd ->
  // abcde) over five 3200x3200 random inputs; fetches "abcde".
  void CreateGrapplerItemWithMatmulChain() {
    Scope s = Scope::NewRootScope().WithDevice(kCPU0);
    // Add control dependencies to ensure tests do not rely on specific
    // manager and the order remains consistent for the test.
    auto a = ops::RandomUniform(s.WithOpName("a"), {3200, 3200}, DT_FLOAT);
    auto b = ops::RandomUniform(s.WithOpName("b").WithControlDependencies(a),
                                {3200, 3200}, DT_FLOAT);
    auto c = ops::RandomUniform(s.WithOpName("c").WithControlDependencies(b),
                                {3200, 3200}, DT_FLOAT);
    auto d = ops::RandomUniform(s.WithOpName("d").WithControlDependencies(c),
                                {3200, 3200}, DT_FLOAT);
    auto e = ops::RandomUniform(s.WithOpName("e").WithControlDependencies(d),
                                {3200, 3200}, DT_FLOAT);

    auto ab = ops::MatMul(s.WithOpName("ab").WithControlDependencies(e), a, b);
    auto abc = ops::MatMul(s.WithOpName("abc"), ab, c);
    auto abcd = ops::MatMul(s.WithOpName("abcd"), abc, d);
    auto abcde = ops::MatMul(s.WithOpName("abcde"), abcd, e);

    grappler_item_ = absl::make_unique<GrapplerItem>();
    TF_CHECK_OK(s.ToGraphDef(&grappler_item_->graph));
    grappler_item_->id = "test_matmul_sequence_graph";
    grappler_item_->fetch = {"abcde"};

    // Expected input dependencies of each MatMul in the chain.
    dependency_["ab"] = {"a", "b"};
    dependency_["abc"] = {"ab", "c"};
    dependency_["abcd"] = {"abc", "d"};
    dependency_["abcde"] = {"abcd", "e"};
  }
769 
770   // AddN that takes 4 tensors with 10x10x10x10.
CreateGrapplerItemWithAddN()771   void CreateGrapplerItemWithAddN() {
772     Scope s = Scope::NewRootScope().WithDevice(kCPU0);
773     auto x = ops::RandomUniform(s.WithOpName("x"), {10, 10, 10, 10}, DT_FLOAT);
774     auto y = ops::RandomUniform(s.WithOpName("y"), {10, 10, 10, 10}, DT_FLOAT);
775     auto z = ops::RandomUniform(s.WithOpName("z"), {10, 10, 10, 10}, DT_FLOAT);
776     auto w = ops::RandomUniform(s.WithOpName("w"), {10, 10, 10, 10}, DT_FLOAT);
777     OutputList input_tensors = {x, y, z, w};
778     auto add = ops::AddN(s.WithOpName("add"), input_tensors);
779     auto out = ops::Identity(s.WithOpName("out"), add);
780 
781     grappler_item_ = absl::make_unique<GrapplerItem>();
782     TF_CHECK_OK(s.ToGraphDef(&grappler_item_->graph));
783     grappler_item_->id = "test_addn_graph";
784     grappler_item_->fetch = {"out"};
785 
786     dependency_["out"] = {"x", "y", "z", "w", "add"};
787   }
788 
789   // Graph with some placeholder feed nodes that are not in the fetch fan-in.
CreateGrapplerItemWithUnnecessaryPlaceholderNodes()790   void CreateGrapplerItemWithUnnecessaryPlaceholderNodes() {
791     Scope s = Scope::NewRootScope().WithDevice(kCPU0);
792     auto unnecessary = ops::Placeholder(s.WithOpName("unnecessary"), DT_FLOAT);
793     auto x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT);
794 
795     grappler_item_ = absl::make_unique<GrapplerItem>();
796     TF_CHECK_OK(s.ToGraphDef(&grappler_item_->graph));
797 
798     grappler_item_->id = "test_extra_placeholders";
799     grappler_item_->fetch = {"x"};
800 
801     // Grappler Item Builder puts all placeholder nodes into the feed
802     // list by default.
803     grappler_item_->feed = {{"x", Tensor()}, {"unnecessary", Tensor()}};
804   }
805 
806   // NoOp that takes 7 NoOps as control dependency.
CreateGrapplerItemWithControlDependency()807   void CreateGrapplerItemWithControlDependency() {
808     Scope s = Scope::NewRootScope().WithDevice(kCPU0);
809     std::vector<string> input_noop_names = {"x", "y", "z", "w", "u", "v", "t"};
810     std::vector<Operation> input_tensors;
811     for (const auto& input : input_noop_names) {
812       auto x = ops::NoOp(s.WithOpName(input));
813       input_tensors.push_back(x.operation);
814     }
815     auto out =
816         ops::NoOp(s.WithControlDependencies(input_tensors).WithOpName("out"));
817 
818     grappler_item_ = absl::make_unique<GrapplerItem>();
819     TF_CHECK_OK(s.ToGraphDef(&grappler_item_->graph));
820 
821     grappler_item_->id = "test_control_dependency_graph";
822     grappler_item_->fetch = {"out"};
823 
824     dependency_["out"] = input_noop_names;
825   }
826 
CreateGrapplerItemWithAddFromOneTensor()827   void CreateGrapplerItemWithAddFromOneTensor() {
828     Scope s = Scope::NewRootScope().WithDevice(kCPU0);
829     auto x = tensorflow::ops::RandomUniform(
830         s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
831 
832     auto y = tensorflow::ops::Add(s.WithOpName("y"), x, x);
833     Output fetch = ops::Identity(s.WithOpName("fetch"), y);
834 
835     grappler_item_ = absl::make_unique<GrapplerItem>();
836     TF_CHECK_OK(s.ToGraphDef(&grappler_item_->graph));
837 
838     grappler_item_->id = "test_add_from_one_tensor";
839     grappler_item_->fetch = {"fetch"};
840 
841     dependency_["fetch"] = {"y"};
842     dependency_["y"] = {"x"};
843   }
844 
CreateGrapplerItemWithSwitchMergeInput()845   void CreateGrapplerItemWithSwitchMergeInput() {
846     // sw = Switch(x, pred)
847     // a = Add(S:1, b)
848     // m = Merge(sw:0, a)
849     // y = Add(m, z)
850 
851     Scope s = Scope::NewRootScope().WithDevice(kCPU0);
852     auto x = ops::RandomUniform(
853         s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
854     auto pred = ops::Const(s.WithOpName("pred"), false, {});
855     auto sw = ops::Switch(s.WithOpName("switch"), x, pred);
856     auto b = ops::RandomUniform(
857         s.WithOpName("b"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
858     auto a = ops::Add(s.WithOpName("a"), sw.output_true, b);
859     auto m = ops::Merge(s.WithOpName("m"), {sw.output_false, a.z});
860     auto z = ops::RandomUniform(
861         s.WithOpName("z"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
862     auto y = ops::Add(s.WithOpName("y"), m.output, z);
863 
864     grappler_item_ = absl::make_unique<GrapplerItem>();
865     TF_CHECK_OK(s.ToGraphDef(&grappler_item_->graph));
866 
867     grappler_item_->id = "test_add_merge_switch";
868     grappler_item_->fetch = {"y"};
869 
870     dependency_["y"] = {"m", "z"};
871   }
872 
873   // FusedBN [an op with multiple outputs] with multiple consumers (including
874   // control dependency).
CreateGrapplerItemWithBatchNorm()875   void CreateGrapplerItemWithBatchNorm() {
876     Scope s = Scope::NewRootScope().WithDevice(kCPU0);
877     auto x = ops::RandomUniform(
878         s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
879     auto scale =
880         ops::RandomUniform(s.WithOpName("scale"), {depth_in_}, DT_FLOAT);
881     auto offset =
882         ops::RandomUniform(s.WithOpName("offset"), {depth_in_}, DT_FLOAT);
883     auto mean = ops::RandomUniform(s.WithOpName("mean"), {0}, DT_FLOAT);
884     auto var = ops::RandomUniform(s.WithOpName("var"), {0}, DT_FLOAT);
885 
886     auto batch_norm = ops::FusedBatchNorm(
887         s.WithOpName("bn"), x, scale, offset, mean, var,
888         ops::FusedBatchNorm::IsTraining(true).Epsilon(0.1f));
889     auto y = batch_norm.y;
890     auto batch_mean = batch_norm.batch_mean;
891     auto batch_var = batch_norm.batch_variance;
892 
893     auto z1 = ops::Add(s.WithOpName("z1"), x, y);
894     auto z2 = ops::Add(s.WithOpName("z2"), batch_var, batch_var);
895     auto z3 = ops::Add(s.WithOpName("z3"), batch_var, batch_var);
896     std::vector<Operation> input_tensors = {
897         batch_mean.op(),
898         z1.z.op(),
899         z2.z.op(),
900         z3.z.op(),
901     };
902     auto z4 = ops::NoOp(s.WithControlDependencies(batch_var).WithOpName("z4"));
903 
904     grappler_item_ = absl::make_unique<GrapplerItem>();
905     TF_CHECK_OK(s.ToGraphDef(&grappler_item_->graph));
906 
907     grappler_item_->id = "test_complex_dependency_graph";
908     grappler_item_->fetch = {"z1", "z2", "z3", "z4"};
909 
910     dependency_["bn"] = {"x", "scale", "offset", "mean", "var"};
911     dependency_["z1"] = {"x", "bn"};
912     dependency_["z2"] = {"bn"};
913     dependency_["z3"] = {"bn"};
914     dependency_["z4"] = {"bn"};
915   }
916 
CreateGrapplerItemWithSendRecv()917   void CreateGrapplerItemWithSendRecv() {
918     const string gdef_ascii = R"EOF(
919 node {
920   name: "Const"
921   op: "Const"
922   device: "/job:localhost/replica:0/task:0/device:CPU:0"
923   attr {
924     key: "dtype"
925     value {
926       type: DT_FLOAT
927     }
928   }
929   attr {
930     key: "_output_shapes"
931     value {
932       list { shape {
933         dim { size: 128 }
934         dim { size: 32 }
935       }}}
936   }
937   attr {
938     key: "shape"
939     value {
940       list { shape {
941         dim { size: 128 }
942         dim { size: 32 }
943       }}}
944   }
945   attr {
946     key: "value"
947     value {
948       tensor {
949         dtype: DT_FLOAT
950         tensor_shape {
951           dim { size: 128 }
952           dim { size: 32 }
953         }
954         float_val: 3.1415
955       }
956     }
957   }
958 }
959 node {
960   name: "Send"
961   op: "_Send"
962   input: "Const"
963   device: "/job:localhost/replica:0/task:0/device:CPU:0"
964   attr {
965     key: "T"
966     value {
967       type: DT_FLOAT
968     }
969   }
970   attr {
971     key: "_output_shapes"
972     value {
973       list { shape {
974         dim { size: 128 }
975         dim { size: 32 }
976       }}}
977   }
978   attr {
979     key: "shape"
980     value {
981       list { shape {
982         dim { size: 128 }
983         dim { size: 32 }
984       }}}
985   }
986   attr {
987     key: "client_terminated"
988     value {
989       b: false
990     }
991   }
992   attr {
993     key: "recv_device"
994     value {
995       s: "/job:localhost/replica:0/task:0/device:CPU:0"
996     }
997   }
998   attr {
999     key: "send_device"
1000     value {
1001       s: "/job:localhost/replica:0/task:0/device:CPU:0"
1002     }
1003   }
1004   attr {
1005     key: "send_device_incarnation"
1006     value {
1007       i: 0
1008     }
1009   }
1010   attr {
1011     key: "tensor_name"
1012     value {
1013       s: "test"
1014     }
1015   }
1016 }
1017 node {
1018   name: "Recv"
1019   op: "_Recv"
1020   device: "/job:localhost/replica:0/task:0/device:CPU:0"
1021   attr {
1022     key: "client_terminated"
1023     value {
1024       b: false
1025     }
1026   }
1027   attr {
1028     key: "_output_shapes"
1029     value {
1030       list { shape {
1031         dim { size: 128 }
1032         dim { size: 32 }
1033       }}}
1034   }
1035   attr {
1036     key: "shape"
1037     value {
1038       list { shape {
1039         dim { size: 128 }
1040         dim { size: 32 }
1041       }}}
1042   }
1043   attr {
1044     key: "recv_device"
1045     value {
1046       s: "/job:localhost/replica:0/task:0/device:CPU:0"
1047     }
1048   }
1049   attr {
1050     key: "send_device"
1051     value {
1052       s: "/job:localhost/replica:0/task:0/device:CPU:0"
1053     }
1054   }
1055   attr {
1056     key: "send_device_incarnation"
1057     value {
1058       i: 0
1059     }
1060   }
1061   attr {
1062     key: "tensor_name"
1063     value {
1064       s: "test"
1065     }
1066   }
1067   attr {
1068     key: "tensor_type"
1069     value {
1070       type: DT_FLOAT
1071     }
1072   }
1073 }
1074 library {
1075 }
1076 versions {
1077   producer: 24
1078 }
1079     )EOF";
1080 
1081     grappler_item_ = absl::make_unique<GrapplerItem>();
1082 
1083     CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii,
1084                                                 &grappler_item_->graph));
1085     grappler_item_->id = "test_graph";
1086     grappler_item_->fetch = {"Recv"};
1087   }
1088 
CreateGrapplerItemWithRecvWithoutSend()1089   void CreateGrapplerItemWithRecvWithoutSend() {
1090     const string gdef_ascii = R"EOF(
1091 node {
1092   name: "Recv"
1093   op: "_Recv"
1094   device: "/job:localhost/replica:0/task:0/device:CPU:0"
1095   attr {
1096     key: "client_terminated"
1097     value {
1098       b: false
1099     }
1100   }
1101   attr {
1102     key: "recv_device"
1103     value {
1104       s: "/job:localhost/replica:0/task:0/device:CPU:0"
1105     }
1106   }
1107   attr {
1108     key: "send_device"
1109     value {
1110       s: "/job:localhost/replica:0/task:0/device:CPU:0"
1111     }
1112   }
1113   attr {
1114     key: "send_device_incarnation"
1115     value {
1116       i: 0
1117     }
1118   }
1119   attr {
1120     key: "tensor_name"
1121     value {
1122       s: "test"
1123     }
1124   }
1125   attr {
1126     key: "tensor_type"
1127     value {
1128       type: DT_FLOAT
1129     }
1130   }
1131 }
1132 library {
1133 }
1134 versions {
1135   producer: 24
1136 }
1137     )EOF";
1138 
1139     grappler_item_ = absl::make_unique<GrapplerItem>();
1140     CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii,
1141                                                 &grappler_item_->graph));
1142     grappler_item_->id = "test_graph";
1143     grappler_item_->fetch = {"Recv"};
1144   }
1145 
  // A simple while loop (Enter/Merge/Switch/Exit control-flow frame) given as
  // a text-format GraphDef; fetches both Exit nodes.
  void CreateGrapplerItemWithLoop() {
    // Test graph produced in python using:
    /*
      with tf.Graph().as_default():
      i0 = tf.constant(0)
      m0 = tf.ones([2, 2])
      c = lambda i, m: i < 10
      b = lambda i, m: [i+1, tf.concat([m, m], axis=0)]
      r = tf.while_loop(
      c, b, loop_vars=[i0, m0],
      shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])])
      with open('/tmp/graph.pbtxt', 'w') as f:
      f.write(str(tf.get_default_graph().as_graph_def()))
    */
    const string gdef_ascii = R"EOF(
node {
  name: "Const"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
        }
        int_val: 0
      }
    }
  }
}
node {
  name: "ones"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_FLOAT
        tensor_shape {
          dim {
            size: 2
          }
          dim {
            size: 2
          }
        }
        float_val: 1.0
      }
    }
  }
}
node {
  name: "while/Enter"
  op: "Enter"
  input: "Const"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "frame_name"
    value {
      s: "while/while/"
    }
  }
  attr {
    key: "is_constant"
    value {
      b: false
    }
  }
  attr {
    key: "parallel_iterations"
    value {
      i: 10
    }
  }
}
node {
  name: "while/Enter_1"
  op: "Enter"
  input: "ones"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "frame_name"
    value {
      s: "while/while/"
    }
  }
  attr {
    key: "is_constant"
    value {
      b: false
    }
  }
  attr {
    key: "parallel_iterations"
    value {
      i: 10
    }
  }
}
node {
  name: "while/Merge"
  op: "Merge"
  input: "while/Enter"
  input: "while/NextIteration"
  attr {
    key: "N"
    value {
      i: 2
    }
  }
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
}
node {
  name: "while/Merge_1"
  op: "Merge"
  input: "while/Enter_1"
  input: "while/NextIteration_1"
  attr {
    key: "N"
    value {
      i: 2
    }
  }
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
node {
  name: "while/Less/y"
  op: "Const"
  input: "^while/Merge"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
        }
        int_val: 10
      }
    }
  }
}
node {
  name: "while/Less"
  op: "Less"
  input: "while/Merge"
  input: "while/Less/y"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
}
node {
  name: "while/LoopCond"
  op: "LoopCond"
  input: "while/Less"
}
node {
  name: "while/Switch"
  op: "Switch"
  input: "while/Merge"
  input: "while/LoopCond"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "_class"
    value {
      list {
        s: "loc:@while/Merge"
      }
    }
  }
}
node {
  name: "while/Switch_1"
  op: "Switch"
  input: "while/Merge_1"
  input: "while/LoopCond"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "_class"
    value {
      list {
        s: "loc:@while/Merge_1"
      }
    }
  }
}
node {
  name: "while/Identity"
  op: "Identity"
  input: "while/Switch:1"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
}
node {
  name: "while/Identity_1"
  op: "Identity"
  input: "while/Switch_1:1"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
node {
  name: "while/add/y"
  op: "Const"
  input: "^while/Identity"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
        }
        int_val: 1
      }
    }
  }
}
node {
  name: "while/add"
  op: "Add"
  input: "while/Identity"
  input: "while/add/y"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
}
node {
  name: "while/concat/axis"
  op: "Const"
  input: "^while/Identity"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
        }
        int_val: 0
      }
    }
  }
}
node {
  name: "while/concat"
  op: "ConcatV2"
  input: "while/Identity_1"
  input: "while/Identity_1"
  input: "while/concat/axis"
  attr {
    key: "N"
    value {
      i: 2
    }
  }
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "Tidx"
    value {
      type: DT_INT32
    }
  }
}
node {
  name: "while/NextIteration"
  op: "NextIteration"
  input: "while/add"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
}
node {
  name: "while/NextIteration_1"
  op: "NextIteration"
  input: "while/concat"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
node {
  name: "while/Exit"
  op: "Exit"
  input: "while/Switch"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
}
node {
  name: "while/Exit_1"
  op: "Exit"
  input: "while/Switch_1"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
versions {
  producer: 21
}
  )EOF";

    grappler_item_ = absl::make_unique<GrapplerItem>();
    CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii,
                                                &grappler_item_->graph));
    grappler_item_->id = "test_graph";
    grappler_item_->fetch = {"while/Exit", "while/Exit_1"};
  }
1541 
1542   // A simple while loop strengthened with Switch outputs xxx.
CreateGrapplerItemWithLoopAnnotated()1543   void CreateGrapplerItemWithLoopAnnotated() {
1544     // Test graph produced in python using:
1545     /*
1546       with tf.Graph().as_default():
1547       i0 = tf.constant(0)
1548       m0 = tf.ones([2, 2])
1549       c = lambda i, m: i < 10
1550       b = lambda i, m: [i+1, tf.concat([m, m], axis=0)]
1551       r = tf.while_loop(
1552       c, b, loop_vars=[i0, m0],
1553       shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])])
1554       with open('/tmp/graph.pbtxt', 'w') as f:
1555       f.write(str(tf.get_default_graph().as_graph_def()))
1556     */
1557     const string gdef_ascii = R"EOF(
1558 node {
1559   name: "Const"
1560   op: "Const"
1561   attr {
1562     key: "dtype"
1563     value {
1564       type: DT_INT32
1565     }
1566   }
1567   attr {
1568     key: "value"
1569     value {
1570       tensor {
1571         dtype: DT_INT32
1572         tensor_shape {
1573         }
1574         int_val: 0
1575       }
1576     }
1577   }
1578   attr {
1579     key: "_execution_count"
1580     value {
1581       i: 1
1582     }
1583   }
1584 }
1585 node {
1586   name: "ones"
1587   op: "Const"
1588   attr {
1589     key: "dtype"
1590     value {
1591       type: DT_FLOAT
1592     }
1593   }
1594   attr {
1595     key: "value"
1596     value {
1597       tensor {
1598         dtype: DT_FLOAT
1599         tensor_shape {
1600           dim {
1601             size: 2
1602           }
1603           dim {
1604             size: 2
1605           }
1606         }
1607         float_val: 1.0
1608       }
1609     }
1610   }
1611   attr {
1612     key: "_execution_count"
1613     value {
1614       i: 1
1615     }
1616   }
1617 }
1618 node {
1619   name: "while/Enter"
1620   op: "Enter"
1621   input: "Const"
1622   attr {
1623     key: "T"
1624     value {
1625       type: DT_INT32
1626     }
1627   }
1628   attr {
1629     key: "frame_name"
1630     value {
1631       s: "while/while/"
1632     }
1633   }
1634   attr {
1635     key: "is_constant"
1636     value {
1637       b: false
1638     }
1639   }
1640   attr {
1641     key: "parallel_iterations"
1642     value {
1643       i: 10
1644     }
1645   }
1646   attr {
1647     key: "_execution_count"
1648     value {
1649       i: 1
1650     }
1651   }
1652 }
1653 node {
1654   name: "while/Enter_1"
1655   op: "Enter"
1656   input: "ones"
1657   attr {
1658     key: "T"
1659     value {
1660       type: DT_FLOAT
1661     }
1662   }
1663   attr {
1664     key: "frame_name"
1665     value {
1666       s: "while/while/"
1667     }
1668   }
1669   attr {
1670     key: "is_constant"
1671     value {
1672       b: false
1673     }
1674   }
1675   attr {
1676     key: "parallel_iterations"
1677     value {
1678       i: 10
1679     }
1680   }
1681   attr {
1682     key: "_execution_count"
1683     value {
1684       i: 1
1685     }
1686   }
1687 }
1688 node {
1689   name: "while/Merge"
1690   op: "Merge"
1691   input: "while/Enter"
1692   input: "while/NextIteration"
1693   attr {
1694     key: "N"
1695     value {
1696       i: 2
1697     }
1698   }
1699   attr {
1700     key: "T"
1701     value {
1702       type: DT_INT32
1703     }
1704   }
1705   attr {
1706     key: "_execution_count"
1707     value {
1708       i: 10
1709     }
1710   }
1711 }
1712 node {
1713   name: "while/Merge_1"
1714   op: "Merge"
1715   input: "while/Enter_1"
1716   input: "while/NextIteration_1"
1717   attr {
1718     key: "N"
1719     value {
1720       i: 2
1721     }
1722   }
1723   attr {
1724     key: "T"
1725     value {
1726       type: DT_FLOAT
1727     }
1728   }
1729   attr {
1730     key: "_execution_count"
1731     value {
1732       i: 10
1733     }
1734   }
1735 }
1736 node {
1737   name: "while/Less/y"
1738   op: "Const"
1739   input: "^while/Merge"
1740   attr {
1741     key: "dtype"
1742     value {
1743       type: DT_INT32
1744     }
1745   }
1746   attr {
1747     key: "value"
1748     value {
1749       tensor {
1750         dtype: DT_INT32
1751         tensor_shape {
1752         }
1753         int_val: 10
1754       }
1755     }
1756   }
1757   attr {
1758     key: "_execution_count"
1759     value {
1760       i: 10
1761     }
1762   }
1763 }
1764 node {
1765   name: "while/Less"
1766   op: "Less"
1767   input: "while/Merge"
1768   input: "while/Less/y"
1769   attr {
1770     key: "T"
1771     value {
1772       type: DT_INT32
1773     }
1774   }
1775   attr {
1776     key: "_execution_count"
1777     value {
1778       i: 10
1779     }
1780   }
1781 }
1782 node {
1783   name: "while/LoopCond"
1784   op: "LoopCond"
1785   input: "while/Less"
1786   attr {
1787     key: "_execution_count"
1788     value {
1789       i: 10
1790     }
1791   }
1792 }
1793 node {
1794   name: "while/Switch"
1795   op: "Switch"
1796   input: "while/Merge"
1797   input: "while/LoopCond"
1798   attr {
1799     key: "T"
1800     value {
1801       type: DT_INT32
1802     }
1803   }
1804   attr {
1805     key: "_class"
1806     value {
1807       list {
1808         s: "loc:@while/Merge"
1809       }
1810     }
1811   }
1812   attr {
1813     key: "_execution_count"
1814     value {
1815       i: 11
1816     }
1817   }
1818   attr {
1819     key: "_output_slot_vector"
1820     value {
1821       list {
1822         i: 1
1823         i: 1
1824         i: 1
1825         i: 1
1826         i: 1
1827         i: 1
1828         i: 1
1829         i: 1
1830         i: 1
1831         i: 1
1832         i: 0
1833       }
1834     }
1835   }
1836 }
1837 node {
1838   name: "while/Switch_1"
1839   op: "Switch"
1840   input: "while/Merge_1"
1841   input: "while/LoopCond"
1842   attr {
1843     key: "T"
1844     value {
1845       type: DT_FLOAT
1846     }
1847   }
1848   attr {
1849     key: "_class"
1850     value {
1851       list {
1852         s: "loc:@while/Merge_1"
1853       }
1854     }
1855   }
1856   attr {
1857     key: "_execution_count"
1858     value {
1859       i: 11
1860     }
1861   }
1862   attr {
1863     key: "_output_slot_vector"
1864     value {
1865       list {
1866         i: 1
1867         i: 1
1868         i: 1
1869         i: 1
1870         i: 1
1871         i: 1
1872         i: 1
1873         i: 1
1874         i: 1
1875         i: 1
1876         i: 0
1877       }
1878     }
1879   }
1880 }
1881 node {
1882   name: "while/Identity"
1883   op: "Identity"
1884   input: "while/Switch:1"
1885   attr {
1886     key: "T"
1887     value {
1888       type: DT_INT32
1889     }
1890   }
1891   attr {
1892     key: "_execution_count"
1893     value {
1894       i: 10
1895     }
1896   }
1897 }
1898 node {
1899   name: "while/Identity_1"
1900   op: "Identity"
1901   input: "while/Switch_1:1"
1902   attr {
1903     key: "T"
1904     value {
1905       type: DT_FLOAT
1906     }
1907   }
1908   attr {
1909     key: "_execution_count"
1910     value {
1911       i: 10
1912     }
1913   }
1914 }
1915 node {
1916   name: "while/add/y"
1917   op: "Const"
1918   input: "^while/Identity"
1919   attr {
1920     key: "dtype"
1921     value {
1922       type: DT_INT32
1923     }
1924   }
1925   attr {
1926     key: "value"
1927     value {
1928       tensor {
1929         dtype: DT_INT32
1930         tensor_shape {
1931         }
1932         int_val: 1
1933       }
1934     }
1935   }
1936   attr {
1937     key: "_execution_count"
1938     value {
1939       i: 10
1940     }
1941   }
1942 }
1943 node {
1944   name: "while/add"
1945   op: "Add"
1946   input: "while/Identity"
1947   input: "while/add/y"
1948   attr {
1949     key: "T"
1950     value {
1951       type: DT_INT32
1952     }
1953   }
1954   attr {
1955     key: "_execution_count"
1956     value {
1957       i: 10
1958     }
1959   }
1960 }
1961 node {
1962   name: "while/concat/axis"
1963   op: "Const"
1964   input: "^while/Identity"
1965   attr {
1966     key: "dtype"
1967     value {
1968       type: DT_INT32
1969     }
1970   }
1971   attr {
1972     key: "value"
1973     value {
1974       tensor {
1975         dtype: DT_INT32
1976         tensor_shape {
1977         }
1978         int_val: 0
1979       }
1980     }
1981   }
1982   attr {
1983     key: "_execution_count"
1984     value {
1985       i: 10
1986     }
1987   }
1988 }
1989 node {
1990   name: "while/concat"
1991   op: "ConcatV2"
1992   input: "while/Identity_1"
1993   input: "while/Identity_1"
1994   input: "while/concat/axis"
1995   attr {
1996     key: "N"
1997     value {
1998       i: 2
1999     }
2000   }
2001   attr {
2002     key: "T"
2003     value {
2004       type: DT_FLOAT
2005     }
2006   }
2007   attr {
2008     key: "Tidx"
2009     value {
2010       type: DT_INT32
2011     }
2012   }
2013   attr {
2014     key: "_execution_count"
2015     value {
2016       i: 10
2017     }
2018   }
2019 }
2020 node {
2021   name: "while/NextIteration"
2022   op: "NextIteration"
2023   input: "while/add"
2024   attr {
2025     key: "T"
2026     value {
2027       type: DT_INT32
2028     }
2029   }
2030   attr {
2031     key: "_execution_count"
2032     value {
2033       i: 10
2034     }
2035   }
2036 }
2037 node {
2038   name: "while/NextIteration_1"
2039   op: "NextIteration"
2040   input: "while/concat"
2041   attr {
2042     key: "T"
2043     value {
2044       type: DT_FLOAT
2045     }
2046   }
2047   attr {
2048     key: "_execution_count"
2049     value {
2050       i: 10
2051     }
2052   }
2053 }
2054 node {
2055   name: "while/Exit"
2056   op: "Exit"
2057   input: "while/Switch"
2058   attr {
2059     key: "T"
2060     value {
2061       type: DT_INT32
2062     }
2063   }
2064   attr {
2065     key: "_execution_count"
2066     value {
2067       i: 1
2068     }
2069   }
2070 }
2071 node {
2072   name: "while/Exit_1"
2073   op: "Exit"
2074   input: "while/Switch_1"
2075   attr {
2076     key: "T"
2077     value {
2078       type: DT_FLOAT
2079     }
2080   }
2081   attr {
2082     key: "_execution_count"
2083     value {
2084       i: 1
2085     }
2086   }
2087 }
2088 versions {
2089   producer: 21
2090 }
2091   )EOF";
2092 
2093     grappler_item_.reset(new GrapplerItem);
2094     CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii,
2095                                                 &grappler_item_->graph));
2096     grappler_item_->id = "test_graph";
2097     grappler_item_->fetch = {"while/Exit", "while/Exit_1"};
2098   }
2099 
  // A simple condition graph.
  void CreateGrapplerItemWithCondition() {
    // Handcrafted test graph: a/Less -> Switch -> First/Second -> Merge.
    // "Less" is a constant-true bool Const (not an actual Less op) driving
    // the Switch; First/Second consume ports 0/1 and feed a Merge.
    const string gdef_ascii = R"EOF(
node {
  name: "a"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_FLOAT
        tensor_shape {
        }
        float_val: 2.0
      }
    }
  }
}
node {
  name: "Less"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_BOOL
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_BOOL
        tensor_shape {
        }
        tensor_content: "\001"
      }
    }
  }
}
node {
  name: "Switch"
  op: "Switch"
  input: "a"
  input: "Less"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
node {
  name: "First"
  op: "Identity"
  input: "Switch"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
node {
  name: "Second"
  op: "Identity"
  input: "Switch:1"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
node {
  name: "Merge"
  op: "Merge"
  input: "First"
  input: "Second"
  attr {
    key: "N"
    value {
      i: 2
    }
  }
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
versions {
  producer: 27
})EOF";

    grappler_item_ = absl::make_unique<GrapplerItem>();
    CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii,
                                                &grappler_item_->graph));
    grappler_item_->id = "test_graph";
    grappler_item_->fetch = {"Merge"};
  }
2208 
  // Graph with a FusedBatchNorm on kCPU0 whose outputs are consumed by
  // Identity/NoOp nodes on kCPU1, so scheduling it must insert inter-device
  // _Send/_Recv pairs. Covers: two consumers of the same output port (one
  // transfer), consumers of distinct ports (one transfer each), and a
  // control dependency across devices.
  void CreateGrapplerItemWithInterDeviceTransfers() {
    tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice(kCPU0);

    // Create a FusedBatchNorm op that has multiple output ports.
    auto x = ops::RandomUniform(
        s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
    auto scale =
        ops::RandomUniform(s.WithOpName("scale"), {depth_in_}, DT_FLOAT);
    auto offset =
        ops::RandomUniform(s.WithOpName("offset"), {depth_in_}, DT_FLOAT);
    auto mean = ops::RandomUniform(s.WithOpName("mean"), {0}, DT_FLOAT);
    auto var = ops::RandomUniform(s.WithOpName("var"), {0}, DT_FLOAT);

    auto batch_norm = ops::FusedBatchNorm(
        s.WithOpName("bn"), x, scale, offset, mean, var,
        ops::FusedBatchNorm::IsTraining(true).Epsilon(0.1f));
    auto y = batch_norm.y;
    auto batch_mean = batch_norm.batch_mean;
    auto batch_var = batch_norm.batch_variance;
    // y1 and y2 take the same tensor, so there should be only 1 Send and Recv.
    auto y1 = ops::Identity(s.WithOpName("y1").WithDevice(kCPU1), y);
    auto y2 = ops::Identity(s.WithOpName("y2").WithDevice(kCPU1), y);
    // batch_mean1 and batch_var1 take different output ports, so each will
    // initiate Send/Recv.
    auto batch_mean1 = ops::Identity(
        s.WithOpName("batch_mean1").WithDevice(kCPU1), batch_mean);
    auto batch_var1 =
        ops::Identity(s.WithOpName("batch_var1").WithDevice(kCPU1), batch_var);
    // This is control dependency.
    auto control_dep = ops::NoOp(s.WithOpName("control_dep")
                                     .WithControlDependencies(y)
                                     .WithDevice(kCPU1));

    // Fetch every cross-device consumer so all transfers are scheduled.
    grappler_item_ = absl::make_unique<GrapplerItem>();
    TF_CHECK_OK(s.ToGraphDef(&grappler_item_->graph));
    grappler_item_->id = "test_conv2d_graph";
    grappler_item_->fetch = {"y1", "y2", "batch_mean1", "batch_var1",
                             "control_dep"};

    // Expected scheduling order: bn after its inputs, consumers after bn.
    dependency_["bn"] = {"x", "mean", "var"};
    dependency_["y1"] = {"bn"};
    dependency_["y2"] = {"bn"};
    dependency_["batch_mean1"] = {"bn"};
    dependency_["batch_var1"] = {"bn"};
    dependency_["control_dep"] = {"bn"};
  }
2256 
  // Initializes scheduler_ with the current grappler_item_. Call this after
  // creating grappler_item_ and setting up dependency_; TF_ASSERT_OK aborts
  // the current test early if Init() fails.
  void InitScheduler() { TF_ASSERT_OK(scheduler_->Init(grappler_item_.get())); }
2259 
2260   // Returns cost based on op.
SimplePredictCosts(const OpContext & op_context) const2261   Costs SimplePredictCosts(const OpContext& op_context) const {
2262     Costs c;
2263     int64 exec_cost = 0;
2264     if (op_context.op_info.op() == "MatMul") {
2265       exec_cost = 2000000000;
2266     } else if (op_context.op_info.op() == "RandomUniform") {
2267       exec_cost = 1000000000;
2268     } else {
2269       exec_cost = 1000;
2270     }
2271     c.execution_time = Costs::NanoSeconds(exec_cost);
2272     return c;
2273   }
2274 
2275   // Call this after init scheduler_. Scheduler stops after executing
2276   // target_node.
RunScheduler(const string & target_node)2277   std::unordered_map<string, OpContext> RunScheduler(
2278       const string& target_node) {
2279     std::unordered_map<string, OpContext> ops_executed;
2280     bool more_nodes = true;
2281     do {
2282       OpContext op_context = scheduler_->GetCurrNode();
2283       ops_executed[op_context.name] = op_context;
2284       std::cout << op_context.name << std::endl;
2285 
2286       Costs node_costs = SimplePredictCosts(op_context);
2287 
2288       // Check scheduling order.
2289       auto it = dependency_.find(op_context.name);
2290       if (it != dependency_.end()) {
2291         for (const auto& preceding_node : it->second) {
2292           EXPECT_GT(ops_executed.count(preceding_node), 0);
2293         }
2294       }
2295       more_nodes = scheduler_->MarkCurrNodeExecuted(node_costs);
2296 
2297       if (op_context.name == target_node) {
2298         // Scheduler has the state after executing the target node.
2299         break;
2300       }
2301     } while (more_nodes);
2302     return ops_executed;
2303   }
2304 
2305   // Helper method for validating a vector.
2306   template <typename T>
ExpectVectorEq(const std::vector<T> & expected,const std::vector<T> & test_elements)2307   void ExpectVectorEq(const std::vector<T>& expected,
2308                       const std::vector<T>& test_elements) {
2309     // Set of expected elements for an easy comparison.
2310     std::set<T> expected_set(expected.begin(), expected.end());
2311     for (const auto& element : test_elements) {
2312       EXPECT_GT(expected_set.count(element), 0);
2313     }
2314     EXPECT_EQ(expected.size(), test_elements.size());
2315   }
2316 
2317   // Helper method that checks the name of nodes.
ValidateNodeDefs(const std::vector<string> & expected,const std::vector<const NodeDef * > & node_defs)2318   void ValidateNodeDefs(const std::vector<string>& expected,
2319                         const std::vector<const NodeDef*>& node_defs) {
2320     std::vector<string> node_names;
2321     std::transform(node_defs.begin(), node_defs.end(),
2322                    std::back_inserter(node_names),
2323                    [](const NodeDef* node) { return node->name(); });
2324     ExpectVectorEq(expected, node_names);
2325   }
2326 
2327   // Helper method for validating a set.
2328   template <typename T>
ExpectSetEq(const std::set<T> & expected,const std::set<T> & test_elements)2329   void ExpectSetEq(const std::set<T>& expected,
2330                    const std::set<T>& test_elements) {
2331     for (const auto& element : test_elements) {
2332       EXPECT_GT(expected.count(element), 0);
2333     }
2334     EXPECT_EQ(expected.size(), test_elements.size());
2335   }
2336 
2337   // Helper method for validating an unordered map.
2338   template <typename T, typename U>
ExpectUnorderedMapEq(const std::unordered_map<T,U> & expected,const std::unordered_map<T,U> & test_map)2339   void ExpectUnorderedMapEq(const std::unordered_map<T, U>& expected,
2340                             const std::unordered_map<T, U>& test_map) {
2341     EXPECT_EQ(expected.size(), test_map.size());
2342     for (const auto& key_val : expected) {
2343       EXPECT_GT(test_map.count(key_val.first), 0);
2344       EXPECT_EQ(test_map.at(key_val.first), key_val.second);
2345     }
2346   }
2347 
2348   // Helper method that checks name - port pairs.
ValidateMemoryUsageSnapshot(const std::vector<string> & expected_names,const int port_num_expected,const std::unordered_set<std::pair<const NodeDef *,int>,DeviceState::NodePairHash> & mem_usage_snapshot)2349   void ValidateMemoryUsageSnapshot(
2350       const std::vector<string>& expected_names, const int port_num_expected,
2351       const std::unordered_set<std::pair<const NodeDef*, int>,
2352                                DeviceState::NodePairHash>& mem_usage_snapshot) {
2353     std::set<std::pair<string, int>> nodes_at_peak_mem_usage;
2354     std::transform(
2355         mem_usage_snapshot.begin(), mem_usage_snapshot.end(),
2356         std::inserter(nodes_at_peak_mem_usage, nodes_at_peak_mem_usage.begin()),
2357         [](const std::pair<const NodeDef*, int>& node_port) {
2358           return std::make_pair(node_port.first->name(), node_port.second);
2359         });
2360     std::set<std::pair<string, int>> expected;
2361     std::transform(expected_names.begin(), expected_names.end(),
2362                    std::inserter(expected, expected.begin()),
2363                    [port_num_expected](const string& name) {
2364                      return std::make_pair(name, port_num_expected);
2365                    });
2366     ExpectSetEq(expected, nodes_at_peak_mem_usage);
2367   }
2368 
2369   // Helper method for checking nodes dependency.
ValidateDependencyChain(const std::unordered_map<string,int64> & start_times,const std::vector<string> & nodes_in_dependency_order)2370   void ValidateDependencyChain(
2371       const std::unordered_map<string, int64>& start_times,
2372       const std::vector<string>& nodes_in_dependency_order) {
2373     int64 prev_node_time = -1;
2374     for (const auto& node : nodes_in_dependency_order) {
2375       int64 curr_node_time = start_times.at(node);
2376       EXPECT_GE(curr_node_time, prev_node_time);
2377       prev_node_time = curr_node_time;
2378     }
2379   }
2380 
  // cluster_ and scheduler_ are initialized in the c'tor.
  std::unique_ptr<VirtualCluster> cluster_;
  std::unique_ptr<TestVirtualScheduler> scheduler_;
  // Ready-node managers available to the scheduler under test.
  FirstReadyManager first_ready_manager_;
  CompositeNodeManager composite_node_manager_;

  // grappler_item_ will be initialized differently for each test case.
  std::unique_ptr<GrapplerItem> grappler_item_;
  // Node name -> its preceding nodes map for testing scheduling order.
  // Populated by the CreateGrapplerItemWith*() helpers; checked in
  // RunScheduler().
  std::unordered_map<string, std::vector<string>> dependency_;

  // Shared params for Conv2D related graphs:
  const int batch_size_ = 4;
  const int width_ = 10;
  const int height_ = 10;
  const int depth_in_ = 8;
  const int kernel_ = 3;
  const int depth_out_ = 16;
};
2400 
2401 // Create small graph, run predict costs on it, make sure the costs from the
2402 // summary match the hand-calculated costs.
TEST_F(VirtualSchedulerTest,SummaryCostTest)2403 TEST_F(VirtualSchedulerTest, SummaryCostTest) {
2404   // Run matmul test.
2405   CreateGrapplerItemWithMatmulChain();
2406   InitScheduler();
2407   auto ops_executed = RunScheduler("");
2408   Costs c = scheduler_->Summary();
2409 
2410   // RandomUniform - 5 * 1s
2411   // Matmuls - 4 * 2s = 8
2412   // Misc - 5 * 1us
2413   // Total: 13000005
2414   EXPECT_EQ(13000005, c.execution_time.asMicroSeconds().count());
2415   EXPECT_EQ(grappler_item_->graph.node_size(), c.num_ops_total);
2416   EXPECT_FALSE(c.inaccurate);
2417   EXPECT_EQ(0, c.num_ops_with_unknown_shapes);
2418 }
2419 
2420 // Like the above SummaryCostTest, but makes sure the stepstats timeline is
2421 // correct.
TEST_F(VirtualSchedulerTest,SummaryCostStepStatsTest)2422 TEST_F(VirtualSchedulerTest, SummaryCostStepStatsTest) {
2423   // Run matmul test.
2424   CreateGrapplerItemWithMatmulChain();
2425   InitScheduler();
2426   auto ops_executed = RunScheduler("");
2427   RunMetadata metadata;
2428   Costs c = scheduler_->Summary(&metadata);
2429   StepStats stepstats = metadata.step_stats();
2430   EXPECT_EQ(13000005, c.execution_time.asMicroSeconds().count());
2431   EXPECT_EQ(grappler_item_->graph.node_size(), c.num_ops_total);
2432   EXPECT_FALSE(c.inaccurate);
2433   EXPECT_EQ(0, c.num_ops_with_unknown_shapes);
2434 
2435   // Should only be 1 device!
2436   EXPECT_EQ(1, stepstats.dev_stats().size());
2437 
2438   // Create a map of op name -> start and end times (micros).
2439   std::map<string, std::pair<int64, int64>> start_end_times;
2440   for (const auto& device_step_stats : stepstats.dev_stats()) {
2441     for (const auto& stats : device_step_stats.node_stats()) {
2442       int64 start = stats.all_start_micros();
2443       int64 end = start + stats.all_end_rel_micros();
2444       start_end_times[stats.node_name()] = std::pair<int64, int64>(start, end);
2445 
2446       // Make sure that the output properties are correct for
2447       // MatMul and RandomUniform operations.
2448       // We only check for dtype, and shape (excluding alloc)
2449       // since alloc is not set by the virtual scheduler.
2450       if (stats.timeline_label() == "MatMul" ||
2451           stats.timeline_label() == "RandomUniform") {
2452         EXPECT_EQ(1, stats.output().size());
2453         for (const auto& output : stats.output()) {
2454           EXPECT_EQ(DT_FLOAT, output.tensor_description().dtype());
2455           EXPECT_EQ(2, output.tensor_description().shape().dim().size());
2456           for (const auto& dim : output.tensor_description().shape().dim()) {
2457             EXPECT_EQ(3200, dim.size());
2458           }
2459         }
2460       }
2461     }
2462   }
2463 
2464   // The base start_time is the time to compute RandomUniforms
2465   int64 cur_time = static_cast<int64>(5000005);
2466   // The increment is the execution time of one matmul. See
2467   // CreateGrapplerItemWithMatmulChain for details.
2468   int64 increment = static_cast<int64>(2000000);
2469   auto op_names = {"ab", "abc", "abcd", "abcde"};
2470   for (const auto& op_name : op_names) {
2471     int64 actual_start = start_end_times[op_name].first;
2472     int64 actual_end = start_end_times[op_name].second;
2473     int64 expected_start = cur_time;
2474     int64 expected_end = cur_time + increment;
2475     EXPECT_EQ(expected_start, actual_start);
2476     EXPECT_EQ(expected_end, actual_end);
2477     cur_time += increment;
2478   }
2479 }
2480 
TEST_F(VirtualSchedulerTest,InitAndBasicScheduling)2481 TEST_F(VirtualSchedulerTest, InitAndBasicScheduling) {
2482   // Init.
2483   CreateGrapplerItemWithConv2Ds();
2484   InitScheduler();
2485 
2486   // Run the scheduler.
2487   auto ops_executed = RunScheduler("");  // Run all the nodes.
2488 
2489   // [const and rand] * (x, y, f), and c0 and c1. c2 and z shouldn't be
2490   // executed.
2491   EXPECT_EQ(8, ops_executed.size());
2492 
2493   // x, y, f, c0, and c1 should be in the ops executed.
2494   EXPECT_GT(ops_executed.count("x"), 0);
2495   EXPECT_GT(ops_executed.count("y"), 0);
2496   EXPECT_GT(ops_executed.count("f"), 0);
2497   EXPECT_GT(ops_executed.count("c0"), 0);
2498   EXPECT_GT(ops_executed.count("c1"), 0);
2499 
2500   // z and c2 shouldn't be part of it.
2501   EXPECT_EQ(ops_executed.count("z"), 0);
2502   EXPECT_EQ(ops_executed.count("c2"), 0);
2503 
2504   // Check input / output properties.
2505   EXPECT_EQ(1, ops_executed["x"].op_info.outputs_size());
2506   EXPECT_EQ(1, ops_executed["y"].op_info.outputs_size());
2507   EXPECT_EQ(1, ops_executed["f"].op_info.outputs_size());
2508   EXPECT_EQ(2, ops_executed["c0"].op_info.inputs_size());
2509   EXPECT_EQ(2, ops_executed["c1"].op_info.inputs_size());
2510 }
2511 
TEST_F(VirtualSchedulerTest,MemoryUsage)2512 TEST_F(VirtualSchedulerTest, MemoryUsage) {
2513   // Init.
2514   CreateGrapplerItemWithAddN();
2515   InitScheduler();
2516 
2517   // Run the scheduler.
2518   RunScheduler("");
2519 
2520   const auto* device_states = scheduler_->GetDeviceStates();
2521   const auto& cpu_state = device_states->at(kCPU0);
2522 
2523   // out node adds 4 tensors, each with 10x10x10x10, so the peak memory usage
2524   // is 4 x the input tensor size while executing the out node.
2525   int64 one_input_node_size = 4 * 10 * 10 * 10 * 10;
2526   const std::vector<string> expected_names = {"x", "y", "z", "w", "add"};
2527   EXPECT_EQ(expected_names.size() * one_input_node_size,
2528             cpu_state.max_memory_usage);
2529   ValidateMemoryUsageSnapshot(expected_names, 0 /* port_num_expected */,
2530                               cpu_state.mem_usage_snapshot_at_peak);
2531   ExpectUnorderedMapEq(
2532       {std::make_pair("/job:localhost/replica:0/task:0/cpu:0", 64)},
2533       scheduler_->GetPersistentMemoryUsage());
2534   ExpectUnorderedMapEq(
2535       {std::make_pair("/job:localhost/replica:0/task:0/cpu:0", 200000)},
2536       scheduler_->GetPeakMemoryUsage());
2537 }
2538 
TEST_F(VirtualSchedulerTest,MemoryUsageForStreamingOps)2539 TEST_F(VirtualSchedulerTest, MemoryUsageForStreamingOps) {
2540   // Init.
2541   CreateGrapplerItemWithAddN();
2542   auto& graph = grappler_item_->graph;
2543   // Nodes add and out are placed on CPU1.
2544   // Nodes x, y are allocate in memory, while Nodes z and w are streaming nodes.
2545   for (auto& node : *graph.mutable_node()) {
2546     if (node.name() == "out" || node.name() == "add") {
2547       node.set_device(kCPU1);
2548     }
2549     if (node.name() == "z" || node.name() == "w")
2550       (*node.mutable_attr())[kStreaming].mutable_list()->add_b(true);
2551   }
2552 
2553   InitScheduler();
2554 
2555   // Run the scheduler.
2556   auto ops_executed = RunScheduler("");
2557 
2558   const auto* device_states = scheduler_->GetDeviceStates();
2559   const auto& cpu_state_0 = device_states->at(kCPU0);
2560   const auto& cpu_state_1 = device_states->at(kCPU1);
2561   // All tensors are of the same size, 10 x 10 x 10 x 10.
2562   int64 one_input_node_size = 4 * 10 * 10 * 10 * 10;
2563   const std::vector<string> cpu_0_expected_tensors = {"x", "y"};
2564   const std::vector<string> cpu_1_expected_tensors = {"x", "y", "add"};
2565   EXPECT_EQ(cpu_0_expected_tensors.size() * one_input_node_size,
2566             cpu_state_0.max_memory_usage);
2567   EXPECT_EQ(cpu_1_expected_tensors.size() * one_input_node_size,
2568             cpu_state_1.max_memory_usage);
2569   // After the graph is executed, at the end, memory usage for the device
2570   // should be zero.
2571   EXPECT_EQ(cpu_state_0.memory_usage, 0);
2572   EXPECT_EQ(cpu_state_1.memory_usage, 0);
2573 }
2574 
TEST_F(VirtualSchedulerTest,UnnecessaryFeedNodes)2575 TEST_F(VirtualSchedulerTest, UnnecessaryFeedNodes) {
2576   CreateGrapplerItemWithUnnecessaryPlaceholderNodes();
2577   InitScheduler();
2578 
2579   // Test that scheduler can run graphs with extra unnecessary feed nodes.
2580   auto ops_executed = RunScheduler("");
2581   ASSERT_EQ(1, ops_executed.size());
2582   ASSERT_EQ(ops_executed.count("x"), 1);
2583 }
2584 
TEST_F(VirtualSchedulerTest,ControlDependency)2585 TEST_F(VirtualSchedulerTest, ControlDependency) {
2586   // Init.
2587   CreateGrapplerItemWithControlDependency();
2588   InitScheduler();
2589 
2590   // Run the scheduler.
2591   RunScheduler("");
2592 
2593   const auto* device_states = scheduler_->GetDeviceStates();
2594   const auto& cpu_state = device_states->at(kCPU0);
2595 
2596   // The graph has a NoOp that takes control dependency from 7 NoOps. The peak
2597   // memory usage is when executing the final NoOp.
2598   int64 one_input_node_size = 4;  // control dependency
2599   const std::vector<string> expected_names = {"x", "y", "z", "w",
2600                                               "u", "v", "t"};
2601   EXPECT_EQ(expected_names.size() * one_input_node_size,
2602             cpu_state.max_memory_usage);
2603   ValidateMemoryUsageSnapshot(expected_names, -1 /* port_num_expected */,
2604                               cpu_state.mem_usage_snapshot_at_peak);
2605   ExpectUnorderedMapEq(
2606       {std::make_pair("/job:localhost/replica:0/task:0/cpu:0", 0)},
2607       scheduler_->GetPersistentMemoryUsage());
2608   ExpectUnorderedMapEq(
2609       {std::make_pair("/job:localhost/replica:0/task:0/cpu:0", 28)},
2610       scheduler_->GetPeakMemoryUsage());
2611 }
2612 
TEST_F(VirtualSchedulerTest,ComplexDependency)2613 TEST_F(VirtualSchedulerTest, ComplexDependency) {
2614   // Init.
2615   CreateGrapplerItemWithBatchNorm();
2616   InitScheduler();
2617 
2618   // Run the scheduler.
2619   RunScheduler("bn");
2620 
2621   const auto& device_states = scheduler_->GetDeviceStates();
2622   const auto& cpu_state = device_states->at(kCPU0);
2623 
2624   // The graph is
2625   //  bn = FusedBatchNorm(x, scale, offset, mean, var)
2626   //  z1 = bn.y + x
2627   //  z2 = bn.var + bn.var
2628   //  z3 = bn.var + bn.var
2629   //  z4 = control dependency from bn.
2630   //  Note that bn.mean doesn't have any consumer.
2631   const int x_size = batch_size_ * width_ * height_ * depth_in_;
2632   int64 expected_size =
2633       4 * (2 * x_size /* x and bn.y */ + depth_in_ /* bn.var */ +
2634            1 /* control dependency */);
2635   EXPECT_EQ(expected_size, cpu_state.memory_usage);
2636 
2637   // Nodes currently in memory: bn's port -1, 0, and 2, and x's port 0.
2638   std::set<std::pair<string, int>> nodes_in_memory;
2639   std::transform(
2640       cpu_state.nodes_in_memory.begin(), cpu_state.nodes_in_memory.end(),
2641       std::inserter(nodes_in_memory, nodes_in_memory.begin()),
2642       [](const std::pair<const NodeDef*, int>& node_port) {
2643         return std::make_pair(node_port.first->name(), node_port.second);
2644       });
2645   std::set<std::pair<string, int>> expected = {
2646       std::make_pair("bn", -1),
2647       std::make_pair("bn", 0),
2648       std::make_pair("bn", 2),
2649       std::make_pair("x", 0),
2650   };
2651   ExpectSetEq(expected, nodes_in_memory);
2652 
2653   const auto* node_states = scheduler_->GetNodeStates();
2654   const NodeState* bn_node = nullptr;
2655   const NodeState* x_node = nullptr;
2656   for (const auto& nodedef_node_state : *node_states) {
2657     const NodeDef* node = nodedef_node_state.first;
2658     const NodeState& node_state = nodedef_node_state.second;
2659     if (node->name() == "bn") {
2660       bn_node = &node_state;
2661     }
2662     if (node->name() == "x") {
2663       x_node = &node_state;
2664     }
2665   }
2666   CHECK_NOTNULL(bn_node);
2667   CHECK_NOTNULL(x_node);
2668 
2669   ValidateNodeDefs({"bn", "z1"}, x_node->outputs.at(0));
2670   ValidateNodeDefs({"z4"}, bn_node->outputs.at(-1));
2671   ValidateNodeDefs({"z1"}, bn_node->outputs.at(0));
2672   // z2 and z3 are bn.var + bn.var, so they appear twice in bn's output port 2.
2673   ValidateNodeDefs({"z2", "z3", "z2", "z3"}, bn_node->outputs.at(2));
2674 }
2675 
TEST_F(VirtualSchedulerTest,Variable)2676 TEST_F(VirtualSchedulerTest, Variable) {
2677   // Init.
2678   CreateGrapplerItemWithConv2DAndVariable();
2679   InitScheduler();
2680 
2681   // Run the scheduler.
2682   RunScheduler("");
2683 
2684   const auto* device_states = scheduler_->GetDeviceStates();
2685   const auto& cpu_state = device_states->at(kCPU0);
2686 
2687   // There is one Conv2D that takes x and f, but f is variable, so it should be
2688   // in persistent nodes.
2689   ValidateMemoryUsageSnapshot({"f", "Const/Const"}, /*port_num_expected=*/0,
2690                               cpu_state.persistent_nodes);
2691   // Only x in peak memory usage snapshot.
2692   ValidateMemoryUsageSnapshot({"x"}, /*port_num_expected=*/0,
2693                               cpu_state.mem_usage_snapshot_at_peak);
2694   ExpectUnorderedMapEq(
2695       {std::make_pair("/job:localhost/replica:0/task:0/cpu:0", 4624)},
2696       scheduler_->GetPersistentMemoryUsage());
2697   ExpectUnorderedMapEq(
2698       {std::make_pair("/job:localhost/replica:0/task:0/cpu:0", 12800)},
2699       scheduler_->GetPeakMemoryUsage());
2700 }
2701 
TEST_F(VirtualSchedulerTest,WhileLoop)2702 TEST_F(VirtualSchedulerTest, WhileLoop) {
2703   // Init.
2704   CreateGrapplerItemWithLoop();
2705   InitScheduler();
2706 
2707   // Run the scheduler.
2708   RunScheduler("");
2709 
2710   // Check the timeline
2711   RunMetadata metadata;
2712   scheduler_->Summary(&metadata);
2713 
2714   // Nodes in topological order:
2715   // * const, ones
2716   // * while/Enter, while/Enter_1
2717   // * while/Merge, while/Merge_1
2718   // * while/Less/y
2719   // * while/Less
2720   // * while/LoopCond
2721   // * while/Switch, while/Switch_1
2722   // * while/Identity, while/Identity_1, while/Exit, while/Exit_1
2723   // * while/add/y, while/concat/axis
2724   // * while/add, while/concat
2725   // * while/NextIteration, while/NextIteration_1
2726 
2727   int num_next_iteration = 0;
2728   int num_next_iteration_1 = 0;
2729   int num_exit = 0;
2730   int num_exit_1 = 0;
2731   int64 next_iter_start_micro;
2732   int64 next_iter_1_start_micro;
2733   int64 exit_start_micro;
2734   int64 exit_1_start_micro;
2735 
2736   std::unordered_map<string, int64> start_times;
2737   for (const auto& device_step_stats : metadata.step_stats().dev_stats()) {
2738     for (const auto& stats : device_step_stats.node_stats()) {
2739       start_times[stats.node_name()] = stats.all_start_micros();
2740       if (stats.node_name() == "while/NextIteration") {
2741         ++num_next_iteration;
2742         next_iter_start_micro = stats.all_start_micros();
2743       } else if (stats.node_name() == "while/NextIteration_1") {
2744         ++num_next_iteration_1;
2745         next_iter_1_start_micro = stats.all_start_micros();
2746       } else if (stats.node_name() == "while/Exit") {
2747         ++num_exit;
2748         exit_start_micro = stats.all_start_micros();
2749       } else if (stats.node_name() == "while/Exit_1") {
2750         ++num_exit_1;
2751         exit_1_start_micro = stats.all_start_micros();
2752       }
2753     }
2754   }
2755 
2756   // Make sure we went though the body of the loop once, and that the output of
2757   // the loop was scheduled as well.
2758   EXPECT_EQ(1, num_next_iteration);
2759   EXPECT_EQ(1, num_next_iteration_1);
2760   EXPECT_EQ(1, num_exit);
2761   EXPECT_EQ(1, num_exit_1);
2762 
2763   // Start times of while/NextIteration and while/NextIteration_1 should be
2764   // different, so should be those of while/Exit and while/Exit_1.
2765   EXPECT_NE(next_iter_start_micro, next_iter_1_start_micro);
2766   EXPECT_NE(exit_start_micro, exit_1_start_micro);
2767 
2768   // Check dependency among the nodes; no matter what scheduling mechanism we
2769   // use, the scheduled ops should follow these dependency chains.
2770   // Note that currently, VirtualScheduler executes while/Merge twice; hence,
2771   // we're not testing dependency chains related to while/Merge.
2772   // TODO(dyoon): after fixing while loop behavior correctly (run nodes in the
2773   // order of Enter, Merge, ...loop condition ..., ... loop body ...,
2774   // NextIteration, Merge, ... loop condition ..., Exit), re-enable dependency
2775   // chaining test w/ Merge nodes.
2776   ValidateDependencyChain(
2777       start_times,
2778       {"Const", "while/Enter",  // "while/Merge",
2779        "while/Less/y", "while/Less", "while/LoopCond", "while/Switch",
2780        "while/Identity", "while/add/y", "while/add", "while/NextIteration"});
2781   // ValidateDependencyChain(start_times, {"while/Merge", "while/Less"});
2782   ValidateDependencyChain(start_times,
2783                           {"ones", "while/Enter_1",  // "while/Merge_1",
2784                            "while/Switch_1", "while/Identity_1", "while/concat",
2785                            "while/NextIteration_1"});
2786   ValidateDependencyChain(start_times, {"while/Switch", "while/Exit"});
2787   ValidateDependencyChain(
2788       start_times, {"while/Identity", "while/concat/axis", "while/concat"});
2789   ValidateDependencyChain(start_times, {"while/Identity", "while/add"});
2790   ValidateDependencyChain(start_times, {"while/Switch_1", "while/Exit_1"});
2791 }
2792 
TEST_F(VirtualSchedulerTest,AnnotatedWhileLoop)2793 TEST_F(VirtualSchedulerTest, AnnotatedWhileLoop) {
2794   {
2795     // Init.
2796     CreateGrapplerItemWithLoop();
2797     InitScheduler();
2798 
2799     // Runs the scheduler.
2800     RunScheduler("");
2801     Costs c = scheduler_->Summary();
2802 
2803     EXPECT_EQ(23, c.execution_time.asMicroSeconds().count());
2804     // Both while/Merge and while/Merge_1 are scheduled twice.
2805     EXPECT_EQ(grappler_item_->graph.node_size() + 2, c.num_ops_total);
2806     EXPECT_FALSE(c.inaccurate);
2807     EXPECT_EQ(0, c.num_ops_with_unknown_shapes);
2808   }
2809 
2810   {
2811     // Init.
2812     CreateGrapplerItemWithLoopAnnotated();
2813     InitScheduler();
2814 
2815     // Runs the scheduler.
2816     RunScheduler("");
2817     Costs c = scheduler_->Summary();
2818 
2819     // The costs for Merge is accumulated twice for execution_count times, but
2820     // since Merge's cost is minimal, we keep this behavior here.
2821     EXPECT_EQ(178, c.execution_time.asMicroSeconds().count());
2822     // Both while/Merge and while/Merge_1 are scheduled twice.
2823     EXPECT_EQ(grappler_item_->graph.node_size() + 2, c.num_ops_total);
2824     EXPECT_FALSE(c.inaccurate);
2825     EXPECT_EQ(0, c.num_ops_with_unknown_shapes);
2826   }
2827 }
2828 
TEST_F(VirtualSchedulerTest,Condition)2829 TEST_F(VirtualSchedulerTest, Condition) {
2830   // Without annotation.
2831   {
2832     // Inits.
2833     CreateGrapplerItemWithCondition();
2834     InitScheduler();
2835 
2836     // Runs the scheduler.
2837     RunScheduler("");
2838     RunMetadata metadata;
2839     Costs c = scheduler_->Summary(&metadata);
2840 
2841     // Nodes in topological order: a/Less, Switch, First/Second, Merge.
2842     int num_a = 0;
2843     int num_less = 0;
2844     int num_switch = 0;
2845     int num_first = 0;
2846     int num_second = 0;
2847     int num_merge = 0;
2848 
2849     for (const auto& device_step_stats : metadata.step_stats().dev_stats()) {
2850       for (const auto& stats : device_step_stats.node_stats()) {
2851         if (stats.node_name() == "a") {
2852           ++num_a;
2853         } else if (stats.node_name() == "Less") {
2854           ++num_less;
2855         } else if (stats.node_name() == "Switch") {
2856           ++num_switch;
2857         } else if (stats.node_name() == "First") {
2858           ++num_first;
2859         } else if (stats.node_name() == "Second") {
2860           ++num_second;
2861         } else if (stats.node_name() == "Merge") {
2862           ++num_merge;
2863         }
2864       }
2865     }
2866 
2867     EXPECT_EQ(1, num_a);
2868     EXPECT_EQ(1, num_less);
2869     EXPECT_EQ(1, num_switch);
2870     EXPECT_EQ(1, num_first);
2871     EXPECT_EQ(1, num_second);
2872     EXPECT_EQ(2, num_merge);
2873 
2874     EXPECT_EQ(7, c.execution_time.asMicroSeconds().count());
2875     // Merge is executed twice.
2876     EXPECT_EQ(grappler_item_->graph.node_size() + 1, c.num_ops_total);
2877     EXPECT_FALSE(c.inaccurate);
2878     EXPECT_EQ(0, c.num_ops_with_unknown_shapes);
2879   }
2880 
2881   // With annotation.
2882   {
2883     // Inits.
2884     CreateGrapplerItemWithCondition();
2885 
2886     // Annotates the Switch node.
2887     for (auto& node : *grappler_item_->graph.mutable_node()) {
2888       if (node.name() == "Switch") {
2889         AttrValue attr_output_info;
2890         // Adds one output slot 0 so that Second shouldn't be executed.
2891         (*attr_output_info.mutable_list()).add_i(0);
2892         AddNodeAttr(kOutputSlots, attr_output_info, &node);
2893       }
2894     }
2895 
2896     InitScheduler();
2897 
2898     // Runs the scheduler.
2899     RunScheduler("");
2900     RunMetadata metadata;
2901     Costs c = scheduler_->Summary(&metadata);
2902 
2903     // Nodes in topological order: a/Less, Switch, Merge
2904     int num_a = 0;
2905     int num_less = 0;
2906     int num_switch = 0;
2907     int num_first = 0;
2908     int num_second = 0;
2909     int num_merge = 0;
2910 
2911     for (const auto& device_step_stats : metadata.step_stats().dev_stats()) {
2912       for (const auto& stats : device_step_stats.node_stats()) {
2913         if (stats.node_name() == "a") {
2914           ++num_a;
2915         } else if (stats.node_name() == "Less") {
2916           ++num_less;
2917         } else if (stats.node_name() == "Switch") {
2918           ++num_switch;
2919         } else if (stats.node_name() == "First") {
2920           ++num_first;
2921         } else if (stats.node_name() == "Second") {
2922           ++num_second;
2923         } else if (stats.node_name() == "Merge") {
2924           ++num_merge;
2925         }
2926       }
2927     }
2928 
2929     EXPECT_EQ(1, num_a);
2930     EXPECT_EQ(1, num_less);
2931     EXPECT_EQ(1, num_switch);
2932     EXPECT_EQ(1, num_first);
2933     EXPECT_EQ(0, num_second);
2934     EXPECT_EQ(1, num_merge);
2935 
2936     EXPECT_EQ(5, c.execution_time.asMicroSeconds().count());
2937     // Second is not executed.
2938     EXPECT_EQ(grappler_item_->graph.node_size() - 1, c.num_ops_total);
2939     EXPECT_FALSE(c.inaccurate);
2940     EXPECT_EQ(0, c.num_ops_with_unknown_shapes);
2941   }
2942 }
2943 
TEST_F(VirtualSchedulerTest,InterDeviceTransfer)2944 TEST_F(VirtualSchedulerTest, InterDeviceTransfer) {
2945   // Init.
2946   CreateGrapplerItemWithInterDeviceTransfers();
2947   InitScheduler();
2948 
2949   // Run the scheduler.
2950   auto ops_executed = RunScheduler("");
2951 
2952   // Helper lambda to extract port num from _Send and _Recv op name.
2953   auto get_port_num = [](const string& name) -> int {
2954     if (name.find("bn_0") != string::npos) {
2955       return 0;
2956     } else if (name.find("bn_1") != string::npos) {
2957       return 1;
2958     } else if (name.find("bn_2") != string::npos) {
2959       return 2;
2960     } else if (name.find("bn_minus1") != string::npos) {
2961       return -1;
2962     }
2963     return -999;
2964   };
2965 
2966   // Reorganize ops_executed for further testing.
2967   std::unordered_map<string, int> op_count;
2968   std::unordered_map<int, string> recv_op_names;
2969   std::unordered_map<int, string> send_op_names;
2970   for (const auto& x : ops_executed) {
2971     const auto& name = x.first;
2972     const auto& node_info = x.second;
2973     const auto& op = node_info.op_info.op();
2974     if (op == kRecv) {
2975       recv_op_names[get_port_num(name)] = name;
2976     } else if (op == kSend) {
2977       send_op_names[get_port_num(name)] = name;
2978     }
2979     op_count[op]++;
2980   }
2981 
2982   // Same number of _Send and _Recv.
2983   EXPECT_EQ(op_count.at(kSend), op_count.at(kRecv));
2984 
2985   // Expect 3 Send and Recvs each: port 0, 1, and, 2.
2986   // Control dependency bypasses the channel.
2987   EXPECT_EQ(op_count.at(kRecv), 3);
2988   EXPECT_EQ(op_count.at(kSend), 3);
2989 
2990   // Helper lambda for extracting output Tensor size.
2991   auto get_output_size = [this, ops_executed](const string& name) -> int64 {
2992     const auto& output_properties_ = ops_executed.at(name).op_info.outputs();
2993     std::vector<OpInfo::TensorProperties> output_properties;
2994     for (const auto& output_property : output_properties_) {
2995       output_properties.push_back(output_property);
2996     }
2997     return CalculateOutputSize(output_properties, 0);
2998   };
2999 
3000   // Validate transfer size.
3001   // Batchnorm output y is 4D vector: batch x width x width x depth.
3002   int input_size = 4 * batch_size_ * width_ * height_ * depth_in_;
3003   EXPECT_EQ(get_output_size(recv_op_names[0]), input_size);
3004   EXPECT_EQ(get_output_size(send_op_names[0]), input_size);
3005   // Mean and vars are 1-D vector with size depth_in_.
3006   EXPECT_EQ(get_output_size(recv_op_names[1]), 4 * depth_in_);
3007   EXPECT_EQ(get_output_size(send_op_names[1]), 4 * depth_in_);
3008   EXPECT_EQ(get_output_size(recv_op_names[2]), 4 * depth_in_);
3009   EXPECT_EQ(get_output_size(send_op_names[2]), 4 * depth_in_);
3010 }
3011 
TEST_F(VirtualSchedulerTest, GraphWithSendRecv) {
  // Schedules a graph that already contains explicit Send/Recv nodes.
  CreateGrapplerItemWithSendRecv();
  InitScheduler();

  const auto ops_executed = RunScheduler("");

  // Every node of the graph must have been scheduled.
  for (const auto* node_name : {"Const", "Send", "Recv"}) {
    EXPECT_GT(ops_executed.count(node_name), 0);
  }
}
3024 
TEST_F(VirtualSchedulerTest, GraphWithSendRecvDifferentDevice) {
  // Inits.
  CreateGrapplerItemWithSendRecv();
  // Re-place the Recv node on a second device so that Send and Recv end up on
  // different devices and the scheduler has to create its own _Send/_Recv
  // pair across the boundary.
  const string recv_device = kCPU1;
  for (auto& node : *grappler_item_->graph.mutable_node()) {
    if (node.name() == "Recv") {
      node.set_device(recv_device);
      (*node.mutable_attr())["recv_device"].set_s(recv_device);
    } else if (node.name() == "Send") {
      (*node.mutable_attr())["recv_device"].set_s(recv_device);
    }
  }
  InitScheduler();

  // Runs the scheduler.
  const auto ops_executed = RunScheduler("");

  // Expect Const, Send, Recv, and VirtualScheduler created Send and Recv ops.
  EXPECT_GT(ops_executed.count("Const"), 0);
  EXPECT_GT(ops_executed.count("Send"), 0);
  EXPECT_GT(ops_executed.count("Send_Send_0_from_/job_localhost/replica_0/"
                               "task_0/cpu_0_to_/job_localhost"
                               "/replica_0/task_0/cpu_1"),
            0);
  EXPECT_GT(ops_executed.count(
                "Recv_Send_0_on_/job_localhost/replica_0/task_0/cpu_1"),
            0);
  EXPECT_GT(ops_executed.count("Recv"), 0);
}
3060 
// NOTE(review): "Wiht" in the test name is a long-standing typo of "With";
// kept as-is so existing --gtest_filter patterns keep matching.
TEST_F(VirtualSchedulerTest, GraphWihtOnlyRecv) {
  // Builds a graph containing a Recv whose matching Send does not exist.
  CreateGrapplerItemWithRecvWithoutSend();
  InitScheduler();

  const auto ops_executed = RunScheduler("");

  // A Recv without a matching Send is treated as an initially-ready node, so
  // it must still appear among the executed ops.
  EXPECT_GT(ops_executed.count("Recv"), 0);
}
3072 
TEST_F(VirtualSchedulerTest, AddMergeSwitch) {
  // Replace the default ready-node manager with CompositeNodeManager.
  scheduler_ = absl::make_unique<TestVirtualScheduler>(
      /*use_static_shapes=*/true,
      /*use_aggressive_shape_inference=*/true, &composite_node_manager_,
      cluster_.get());
  CreateGrapplerItemWithSwitchMergeInput();
  InitScheduler();

  // Graph under test:
  //
  // pred --+                      z --+
  //        |                          |
  //        V                          V
  // x -> Switch --------> Merge ---> Add --> y
  //        |                ^
  //        |                |
  //        +-----> Add -----+
  //                 ^
  //                 |
  // b --------------+
  //
  // Without annotation, the scheduler triggers both outputs of Switch, and
  // Merge becomes ready as soon as any one of its inputs is ready. If the
  // final Add's readiness were tracked with a plain num_inputs_ready counter,
  // it could be marked ready before z is scheduled, and z would be skipped.
  // CompositeNodeManager is required to reproduce this case.
  const auto ops_executed = RunScheduler("");

  // z must still have been scheduled.
  EXPECT_GT(ops_executed.count("z"), 0);
}
3102 
TEST_F(VirtualSchedulerTest, AddFromOneTensor) {
  // Graph under test: the same tensor feeds both inputs of Add.
  //
  // x -+----> Add --> y
  //    |       ^
  //    |       |
  //    +-------+
  CreateGrapplerItemWithAddFromOneTensor();
  InitScheduler();

  const auto ops_executed = RunScheduler("");

  // Both the producer and the consumer must have been scheduled.
  EXPECT_GT(ops_executed.count("x"), 0);
  EXPECT_GT(ops_executed.count("y"), 0);
}
3117 
3118 }  // namespace
3119 }  // end namespace grappler
3120 }  // end namespace tensorflow
3121