1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/costs/virtual_scheduler.h"
17
18 #include "tensorflow/cc/ops/standard_ops.h"
19 #include "tensorflow/core/framework/tensor_description.pb.h"
20 #include "tensorflow/core/framework/tensor_shape.pb.h"
21 #include "tensorflow/core/grappler/clusters/virtual_cluster.h"
22 #include "tensorflow/core/grappler/costs/utils.h"
23 #include "tensorflow/core/grappler/costs/virtual_placer.h"
24 #include "tensorflow/core/lib/core/status_test_util.h"
25 #include "tensorflow/core/platform/test.h"
26
27 namespace tensorflow {
28 namespace grappler {
29 namespace {
30
31 // Device names:
32 constexpr char kCPU0[] = "/job:localhost/replica:0/task:0/cpu:0";
33 constexpr char kCPU1[] = "/job:localhost/replica:0/task:0/cpu:1";
34 constexpr char kChannelFrom0To1[] = "Channel from CPU0 to CPU1";
35 constexpr char kChannelFrom1To0[] = "Channel from CPU1 to CPU0";
36 // Op names:
37 constexpr char kConv2D[] = "Conv2D";
38 constexpr char kSend[] = "_Send";
39 constexpr char kRecv[] = "_Recv";
40
41 class ReadyNodeManagerTest : public ::testing::Test {
42 protected:
ReadyNodeManagerTest()43 ReadyNodeManagerTest() {
44 // node1_ to node6_ on kCPU0, with time_ready in reverse_order.
45 NodeSetUp("Node1", kConv2D, kCPU0, 6000, &node1_);
46 NodeSetUp("Node2", kConv2D, kCPU0, 5000, &node2_);
47 NodeSetUp("Node3", kConv2D, kCPU0, 4000, &node3_);
48 NodeSetUp("Node4", kConv2D, kCPU0, 3000, &node4_);
49 NodeSetUp("Node5", kConv2D, kCPU0, 2000, &node5_);
50 NodeSetUp("Node6", kConv2D, kCPU0, 1000, &node6_);
51 }
52
NodeSetUp(const string & name,const string & op_name,const string & device_name,const uint64 time_ready,NodeDef * node)53 void NodeSetUp(const string& name, const string& op_name,
54 const string& device_name, const uint64 time_ready,
55 NodeDef* node) {
56 node->set_name(name);
57 node->set_op(op_name);
58 node->set_device(device_name);
59
60 node_states_[node] = NodeState();
61 node_states_[node].time_ready = time_ready;
62 node_states_[node].device_name = device_name;
63 }
64
65 NodeDef node1_, node2_, node3_, node4_, node5_, node6_;
66 std::unordered_map<const NodeDef*, NodeState> node_states_;
67 };
68
69 // Tests that FIFOManager correctly returns the current node with only 1 node.
TEST_F(ReadyNodeManagerTest,GetSingleNodeFIFOManager)70 TEST_F(ReadyNodeManagerTest, GetSingleNodeFIFOManager) {
71 FIFOManager manager = FIFOManager();
72 manager.AddNode(&node1_);
73 EXPECT_EQ(manager.GetCurrNode()->name(), "Node1");
74 }
75
76 // Tests that FIFOManager removes the only node contained within.
TEST_F(ReadyNodeManagerTest,RemoveSingleNodeFIFOManager)77 TEST_F(ReadyNodeManagerTest, RemoveSingleNodeFIFOManager) {
78 FIFOManager manager = FIFOManager();
79 manager.AddNode(&node1_);
80
81 // Removes the only node in FIFOManager.
82 manager.RemoveCurrNode();
83 EXPECT_TRUE(manager.Empty());
84 }
85
86 // Tests that FIFOManager can remove multiple nodes and returns the current node
87 // in the right order.
TEST_F(ReadyNodeManagerTest,GetAndRemoveMultipleFIFOManager)88 TEST_F(ReadyNodeManagerTest, GetAndRemoveMultipleFIFOManager) {
89 FIFOManager manager = FIFOManager();
90 manager.AddNode(&node1_);
91 manager.AddNode(&node2_);
92 manager.AddNode(&node3_);
93 manager.AddNode(&node4_);
94
95 // Keeps checking current node while removing nodes from manager.
96 EXPECT_EQ(manager.GetCurrNode()->name(), "Node1");
97 manager.RemoveCurrNode();
98 EXPECT_EQ(manager.GetCurrNode()->name(), "Node2");
99 manager.RemoveCurrNode();
100 EXPECT_EQ(manager.GetCurrNode()->name(), "Node3");
101 manager.RemoveCurrNode();
102 EXPECT_EQ(manager.GetCurrNode()->name(), "Node4");
103 manager.RemoveCurrNode();
104 EXPECT_TRUE(manager.Empty());
105 }
106
107 // Tests that FIFOManager can remove multiple nodes and add more nodes, still
108 // returning the current node in the right order.
TEST_F(ReadyNodeManagerTest,AddAndRemoveMultipleFIFOManager)109 TEST_F(ReadyNodeManagerTest, AddAndRemoveMultipleFIFOManager) {
110 FIFOManager manager = FIFOManager();
111 manager.AddNode(&node1_);
112 manager.AddNode(&node2_);
113 manager.AddNode(&node3_);
114 manager.AddNode(&node4_);
115
116 // Keeps checking current node as nodes are removed and added.
117 EXPECT_EQ(manager.GetCurrNode()->name(), "Node1");
118 manager.RemoveCurrNode();
119 EXPECT_EQ(manager.GetCurrNode()->name(), "Node2");
120 manager.AddNode(&node5_);
121 // GetCurrNode() should return the same node even if some nodes are added,
122 // until RemoveCurrNode() is called.
123 EXPECT_EQ(manager.GetCurrNode()->name(), "Node2");
124 manager.RemoveCurrNode();
125 EXPECT_EQ(manager.GetCurrNode()->name(), "Node3");
126 manager.RemoveCurrNode();
127 EXPECT_EQ(manager.GetCurrNode()->name(), "Node4");
128 manager.AddNode(&node6_);
129 EXPECT_EQ(manager.GetCurrNode()->name(), "Node4");
130 manager.RemoveCurrNode();
131 EXPECT_EQ(manager.GetCurrNode()->name(), "Node5");
132 manager.RemoveCurrNode();
133 EXPECT_EQ(manager.GetCurrNode()->name(), "Node6");
134 manager.RemoveCurrNode();
135 EXPECT_TRUE(manager.Empty());
136 }
137
138 // Tests that LIFOManager correctly returns the current node with only 1 node.
TEST_F(ReadyNodeManagerTest,GetSingleNodeLIFOManager)139 TEST_F(ReadyNodeManagerTest, GetSingleNodeLIFOManager) {
140 LIFOManager manager = LIFOManager();
141 manager.AddNode(&node1_);
142 EXPECT_EQ(manager.GetCurrNode()->name(), "Node1");
143 }
144
145 // Tests that LIFOManager removes the only node contained within.
TEST_F(ReadyNodeManagerTest,RemoveSingleNodeLIFOManager)146 TEST_F(ReadyNodeManagerTest, RemoveSingleNodeLIFOManager) {
147 LIFOManager manager = LIFOManager();
148 manager.AddNode(&node1_);
149
150 // Removes the only node in LIFOManager.
151 manager.RemoveCurrNode();
152 EXPECT_TRUE(manager.Empty());
153 }
154
155 // Tests that LIFOManager can remove multiple nodes and returns the current node
156 // in the right order.
TEST_F(ReadyNodeManagerTest,GetAndRemoveMultipleLIFOManager)157 TEST_F(ReadyNodeManagerTest, GetAndRemoveMultipleLIFOManager) {
158 LIFOManager manager = LIFOManager();
159 manager.AddNode(&node1_);
160 manager.AddNode(&node2_);
161 manager.AddNode(&node3_);
162 manager.AddNode(&node4_);
163
164 // Keeps checking current node while removing nodes from manager.
165 EXPECT_EQ(manager.GetCurrNode()->name(), "Node4");
166 manager.RemoveCurrNode();
167 EXPECT_EQ(manager.GetCurrNode()->name(), "Node3");
168 manager.RemoveCurrNode();
169 EXPECT_EQ(manager.GetCurrNode()->name(), "Node2");
170 manager.RemoveCurrNode();
171 EXPECT_EQ(manager.GetCurrNode()->name(), "Node1");
172 manager.RemoveCurrNode();
173 EXPECT_TRUE(manager.Empty());
174 }
175
176 // Tests that LIFOManager can remove multiple nodes (must be removing the
177 // current node) and add more nodes, still returning the current node in the
178 // right order.
TEST_F(ReadyNodeManagerTest,AddAndRemoveMultipleLIFOManager)179 TEST_F(ReadyNodeManagerTest, AddAndRemoveMultipleLIFOManager) {
180 LIFOManager manager = LIFOManager();
181 manager.AddNode(&node1_);
182 manager.AddNode(&node2_);
183 manager.AddNode(&node3_);
184 manager.AddNode(&node4_);
185
186 // Keeps checking current node as nodes are removed and added.
187 EXPECT_EQ(manager.GetCurrNode()->name(), "Node4");
188 manager.RemoveCurrNode();
189 EXPECT_EQ(manager.GetCurrNode()->name(), "Node3");
190 manager.AddNode(&node5_);
191 // GetCurrNode() should return the same node even if some nodes are added,
192 // until RemoveCurrNode() is called.
193 EXPECT_EQ(manager.GetCurrNode()->name(), "Node3");
194 manager.RemoveCurrNode();
195 EXPECT_EQ(manager.GetCurrNode()->name(), "Node5");
196 manager.RemoveCurrNode();
197 EXPECT_EQ(manager.GetCurrNode()->name(), "Node2");
198 manager.AddNode(&node6_);
199 EXPECT_EQ(manager.GetCurrNode()->name(), "Node2");
200 manager.RemoveCurrNode();
201 EXPECT_EQ(manager.GetCurrNode()->name(), "Node6");
202 manager.RemoveCurrNode();
203 EXPECT_EQ(manager.GetCurrNode()->name(), "Node1");
204 manager.RemoveCurrNode();
205 EXPECT_TRUE(manager.Empty());
206 }
207
// Tests that LIFOManager defers a Merge node: it is scheduled last even
// though it was not added last.
TEST_F(ReadyNodeManagerTest, MergeOrderInLIFOManager) {
  LIFOManager manager = LIFOManager();
  node3_.set_op("Merge");
  manager.AddNode(&node1_);
  manager.AddNode(&node2_);
  manager.AddNode(&node3_);
  manager.AddNode(&node4_);

  // Merge node (node3) will be scheduled at the end (even though it's added
  // after node2).
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node4");
  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node2");
  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node1");
  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node3");
  manager.RemoveCurrNode();
}
227
// Tests that FirstReadyManager correctly returns the current node with only
// 1 node.
TEST_F(ReadyNodeManagerTest, GetSingleNodeFirstReadyManager) {
  FirstReadyManager manager;
  TF_EXPECT_OK(manager.Init(&node_states_));
  manager.AddNode(&node1_);
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node1");
}
234
// Tests that FirstReadyManager removes the only node contained within.
TEST_F(ReadyNodeManagerTest, RemoveSingleNodeFirstReadyManager) {
  FirstReadyManager manager;
  TF_EXPECT_OK(manager.Init(&node_states_));
  manager.AddNode(&node1_);
  manager.RemoveCurrNode();
  EXPECT_TRUE(manager.Empty());
}
242
// Tests that FirstReadyManager returns nodes in ascending time_ready order,
// regardless of insertion order.
TEST_F(ReadyNodeManagerTest, GetAndRemoveMultipleFirstReadyManager) {
  FirstReadyManager manager;
  TF_EXPECT_OK(manager.Init(&node_states_));
  // Inserts nodes in some random order.
  manager.AddNode(&node2_);
  manager.AddNode(&node1_);
  manager.AddNode(&node4_);
  manager.AddNode(&node5_);
  manager.AddNode(&node3_);
  manager.AddNode(&node6_);

  // In whatever order we insert nodes, we get the same order based on nodes'
  // time_ready.
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node6");
  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node5");
  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node4");
  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node3");
  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node2");
  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node1");
  manager.RemoveCurrNode();
  EXPECT_TRUE(manager.Empty());
}
270
// Tests that FirstReadyManager keeps returning the same current node until
// RemoveCurrNode() is called, even when earlier-ready nodes are added.
TEST_F(ReadyNodeManagerTest, GetCurrNodeFirstReadyManager) {
  FirstReadyManager manager;
  TF_EXPECT_OK(manager.Init(&node_states_));

  // Inserts nodes in some random order.
  manager.AddNode(&node2_);
  manager.AddNode(&node1_);
  manager.AddNode(&node4_);
  manager.AddNode(&node5_);
  manager.AddNode(&node3_);
  manager.AddNode(&node6_);

  // Among these nodes, node6 has the smallest time_ready, hence, GetCurrNode()
  // should return it.
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node6");

  // Now inserts a few other nodes, but their time_ready's are even smaller than
  // that of Node6. Before calling RemoveCurrNode(), GetCurrNode() should return
  // the same node, Node6, in this case.
  NodeDef node7;
  NodeDef node8;
  NodeDef node9;
  NodeSetUp("Node7", kConv2D, kCPU0, 5, &node7);
  NodeSetUp("Node8", kConv2D, kCPU0, 4, &node8);
  NodeSetUp("Node9", kConv2D, kCPU0, 3, &node9);

  manager.AddNode(&node7);
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node6");

  manager.AddNode(&node8);
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node6");

  manager.RemoveCurrNode();
  // Now Node6 is removed, and GetCurrNode() will return Node8.
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node8");

  // Again, AddNode shouldn't change GetCurrNode().
  manager.AddNode(&node9);
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node8");

  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node9");
  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node7");
  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node5");
  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node4");
  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node3");
  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node2");
  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node1");
  manager.RemoveCurrNode();
  EXPECT_TRUE(manager.Empty());
}
328
// Tests that FirstReadyManager breaks time_ready ties deterministically: two
// managers fed the same nodes in different orders must schedule identically.
TEST_F(ReadyNodeManagerTest, DeterminismInFirstReadyManager) {
  FirstReadyManager manager1;
  TF_EXPECT_OK(manager1.Init(&node_states_));
  FirstReadyManager manager2;
  TF_EXPECT_OK(manager2.Init(&node_states_));

  // 6 nodes with same time_ready.
  NodeDef node7;
  NodeDef node8;
  NodeDef node9;
  NodeDef node10;
  NodeDef node11;
  NodeDef node12;
  NodeSetUp("Node7", kConv2D, kCPU0, 1000, &node7);
  NodeSetUp("Node8", kConv2D, kCPU0, 1000, &node8);
  NodeSetUp("Node9", kConv2D, kCPU0, 1000, &node9);
  NodeSetUp("Node10", kConv2D, kCPU0, 1000, &node10);
  NodeSetUp("Node11", kConv2D, kCPU0, 1000, &node11);
  NodeSetUp("Node12", kConv2D, kCPU0, 1000, &node12);

  // Adds the above 6 nodes to manager1.
  manager1.AddNode(&node7);
  manager1.AddNode(&node8);
  manager1.AddNode(&node9);
  manager1.AddNode(&node10);
  manager1.AddNode(&node11);
  manager1.AddNode(&node12);

  // Adds the above 6 nodes to manager2, but in a different order.
  manager2.AddNode(&node8);
  manager2.AddNode(&node11);
  manager2.AddNode(&node9);
  manager2.AddNode(&node10);
  manager2.AddNode(&node7);
  manager2.AddNode(&node12);

  // Expects both managers return the same nodes for deterministic node
  // scheduling.
  EXPECT_EQ(manager1.GetCurrNode()->name(), manager2.GetCurrNode()->name());
  manager1.RemoveCurrNode();
  manager2.RemoveCurrNode();

  EXPECT_EQ(manager1.GetCurrNode()->name(), manager2.GetCurrNode()->name());
  manager1.RemoveCurrNode();
  manager2.RemoveCurrNode();

  EXPECT_EQ(manager1.GetCurrNode()->name(), manager2.GetCurrNode()->name());
  manager1.RemoveCurrNode();
  manager2.RemoveCurrNode();

  EXPECT_EQ(manager1.GetCurrNode()->name(), manager2.GetCurrNode()->name());
  manager1.RemoveCurrNode();
  manager2.RemoveCurrNode();

  EXPECT_EQ(manager1.GetCurrNode()->name(), manager2.GetCurrNode()->name());
  manager1.RemoveCurrNode();
  manager2.RemoveCurrNode();

  EXPECT_EQ(manager1.GetCurrNode()->name(), manager2.GetCurrNode()->name());
  manager1.RemoveCurrNode();
  manager2.RemoveCurrNode();

  EXPECT_TRUE(manager1.Empty());
  EXPECT_TRUE(manager2.Empty());
}
394
// Tests that PriorityReadyManager schedules by explicit priority (lower value
// first), defaulting unknown nodes to highest precedence and breaking
// priority ties by time_ready.
TEST_F(ReadyNodeManagerTest, GetAndRemoveMultiplePriorityReadyManager) {
  PriorityReadyManager manager;
  TF_EXPECT_OK(manager.Init(&node_states_));

  // Sets up node priorities.
  std::unordered_map<string, int> node_priority = {
      {"Node1", 1}, {"Node2", 2}, {"Node3", 2}, {"Node4", 4}, {"Node5", 5}};
  TF_EXPECT_OK(manager.SetPriority(node_priority));

  // Inserts nodes in some random order.
  manager.AddNode(&node3_);
  manager.AddNode(&node1_);
  manager.AddNode(&node4_);
  manager.AddNode(&node5_);
  manager.AddNode(&node2_);
  manager.AddNode(&node6_);

  // Expects nodes scheduled based on priority.
  // Node6 should default to lowest priority, since it is not found.
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node6");
  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node1");
  manager.RemoveCurrNode();
  // Nodes 2 and 3 have equal priority and so should be scheduled ready-first.
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node3");
  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node2");
  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node4");
  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node5");
  manager.RemoveCurrNode();
  EXPECT_TRUE(manager.Empty());
}
429
// Tests that CompositeNodeManager removes the only node contained within.
TEST_F(ReadyNodeManagerTest, RemoveSingleNodeCompositeNodeManager) {
  CompositeNodeManager manager;
  TF_EXPECT_OK(manager.Init(&node_states_));
  manager.AddNode(&node1_);
  manager.RemoveCurrNode();
  EXPECT_TRUE(manager.Empty());
}
437
// Tests that CompositeNodeManager behaves like LIFO for ordinary ops on a
// single device, and keeps the current node stable across AddNode calls.
TEST_F(ReadyNodeManagerTest, GetAndRemoveMultipleCompositeNodeManager) {
  CompositeNodeManager manager;
  TF_EXPECT_OK(manager.Init(&node_states_));
  manager.AddNode(&node1_);
  manager.AddNode(&node2_);
  manager.AddNode(&node3_);
  manager.AddNode(&node4_);

  // Keeps checking current node as nodes are removed and added.
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node4");
  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node3");
  manager.AddNode(&node5_);
  // GetCurrNode() should return the same node even if some nodes are added,
  // until RemoveCurrNode() is called.
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node3");
  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node5");
  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node2");
  manager.AddNode(&node6_);
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node2");
  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node6");
  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node1");
  manager.RemoveCurrNode();
  EXPECT_TRUE(manager.Empty());
}
467
// Tests CompositeNodeManager across two devices with _Send/_Recv nodes: per
// device the manager is LIFO for ordinary ops, and across devices/channels it
// picks the candidate with the earliest time_ready.
TEST_F(ReadyNodeManagerTest, MultiDeviceSendRecvCompositeNodeManager) {
  CompositeNodeManager manager;
  TF_EXPECT_OK(manager.Init(&node_states_));
  // Additional nodes on kCPU1.
  NodeDef node7;
  NodeDef node8;
  NodeDef node9;
  NodeSetUp("Node7", kConv2D, kCPU1, 1001, &node7);
  NodeSetUp("Node8", kConv2D, kCPU1, 2001, &node8);
  NodeSetUp("Node9", kConv2D, kCPU1, 3001, &node9);

  // Send and Recv nodes.
  NodeDef send1;
  NodeDef send2;
  NodeDef recv1;
  NodeDef recv2;
  NodeSetUp("Send1", kSend, kChannelFrom0To1, 2002, &send1);
  NodeSetUp("Send2", kSend, kChannelFrom1To0, 2005, &send2);
  NodeSetUp("Recv1", kRecv, kCPU0, 2003, &recv1);
  NodeSetUp("Recv2", kRecv, kCPU1, 2004, &recv2);

  // Inserts nodes.
  manager.AddNode(&node1_);
  manager.AddNode(&node2_);
  manager.AddNode(&node3_);
  manager.AddNode(&node4_);
  manager.AddNode(&node5_);
  manager.AddNode(&node6_);
  manager.AddNode(&node7);
  manager.AddNode(&node8);
  manager.AddNode(&node9);
  manager.AddNode(&send1);
  manager.AddNode(&send2);
  manager.AddNode(&recv1);
  manager.AddNode(&recv2);

  // On kCPU0; last one is node6_, on kCPU1: last one is node9;
  // so choose one that has earliest time_ready among node6_, node9,
  // Send1, Send2, Recv1, and Recv2.
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node6");
  manager.RemoveCurrNode();
  // Then, the next one on kCPU0 is node5_; choose the earliest time_ready node
  // among node5_, node9, Send1, Send2, Recv1, and Recv2.
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node5");
  manager.RemoveCurrNode();
  // Next, choose among node4_, node9, Send1, Send2, Recv1, and Recv2.
  EXPECT_EQ(manager.GetCurrNode()->name(), "Send1");
  manager.RemoveCurrNode();
  // Next, choose among node4_, node9, Send2, Recv1, and Recv2.
  EXPECT_EQ(manager.GetCurrNode()->name(), "Recv1");
  manager.RemoveCurrNode();
  // Next, choose among node4_, node9, Send2, and Recv2.
  EXPECT_EQ(manager.GetCurrNode()->name(), "Recv2");
  manager.RemoveCurrNode();
  // Next, choose among node4_, node9, and Send2.
  EXPECT_EQ(manager.GetCurrNode()->name(), "Send2");
  manager.RemoveCurrNode();
  // Next, choose between node4_, node9.
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node4");
  manager.RemoveCurrNode();
  // Next, choose between node3_, node9.
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node9");
  manager.RemoveCurrNode();
  // Next, choose between node3_, node8.
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node8");
  manager.RemoveCurrNode();
  // Next, choose between node3_, node7.
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node7");
  manager.RemoveCurrNode();
  // Then, just the nodes on kCPU0 -- LIFO.
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node3");
  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node2");
  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node1");
  manager.RemoveCurrNode();
  EXPECT_TRUE(manager.Empty());
}
546
// Tests CompositeNodeManager tie-breaking and determinism: with equal
// time_ready, _Send is preferred over _Recv over other ops; a smaller
// time_ready wins regardless of op; and the node order is independent of the
// order in which nodes were added.
TEST_F(ReadyNodeManagerTest, DeterminismInCompositeNodeManager) {
  CompositeNodeManager manager;
  TF_EXPECT_OK(manager.Init(&node_states_));
  CompositeNodeManager manager2;
  TF_EXPECT_OK(manager2.Init(&node_states_));

  // 6 nodes with same time_ready.
  NodeDef node7;
  NodeDef node8;
  NodeDef node9;
  NodeDef node10;
  NodeDef node11;
  NodeDef node12;
  NodeSetUp("Node7", kConv2D, kCPU0, 1000, &node7);
  NodeSetUp("Node8", kSend, kCPU0, 1000, &node8);
  NodeSetUp("Node9", kRecv, kCPU0, 1000, &node9);
  NodeSetUp("Node10", kConv2D, kCPU0, 999, &node10);
  NodeSetUp("Node11", kRecv, kCPU0, 999, &node11);
  NodeSetUp("Node12", kConv2D, kCPU1, 1000, &node12);

  // Adds Nodes 7 to 9 to manager.
  manager.AddNode(&node7);
  manager.AddNode(&node8);
  manager.AddNode(&node9);

  // It should return _Send, _Recv, and the other op order, when the candidate
  // nodes have same time_ready.
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node8");
  EXPECT_EQ(manager.GetCurrNode()->op(), kSend);
  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node9");
  EXPECT_EQ(manager.GetCurrNode()->op(), kRecv);
  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node7");
  EXPECT_EQ(manager.GetCurrNode()->op(), kConv2D);
  manager.RemoveCurrNode();
  EXPECT_TRUE(manager.Empty());

  // Adds Nodes 7 to 9 to manager, but in a different order.
  manager.AddNode(&node9);
  manager.AddNode(&node8);
  manager.AddNode(&node7);

  // Expects same order (_Send, _Recv, and the other op), regardless of Add
  // order.
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node8");
  EXPECT_EQ(manager.GetCurrNode()->op(), kSend);
  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node9");
  EXPECT_EQ(manager.GetCurrNode()->op(), kRecv);
  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node7");
  EXPECT_EQ(manager.GetCurrNode()->op(), kConv2D);
  manager.RemoveCurrNode();
  EXPECT_TRUE(manager.Empty());

  // Conv2D's time_ready < Send's time_ready; Expects Conv2D first.
  manager.AddNode(&node8);
  manager.AddNode(&node10);
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node10");
  EXPECT_EQ(manager.GetCurrNode()->op(), kConv2D);
  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node8");
  EXPECT_EQ(manager.GetCurrNode()->op(), kSend);
  manager.RemoveCurrNode();
  EXPECT_TRUE(manager.Empty());

  // Recv's time_ready < Send's time_ready; Expects Recv first.
  manager.AddNode(&node11);
  manager.AddNode(&node8);
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node11");
  EXPECT_EQ(manager.GetCurrNode()->op(), kRecv);
  manager.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node8");
  EXPECT_EQ(manager.GetCurrNode()->op(), kSend);
  manager.RemoveCurrNode();
  EXPECT_TRUE(manager.Empty());

  // Node7 and 12 are normal ops with the same time_ready, placed on different
  // devices. These two nodes are added to manager and manager2, but in
  // different orders; Expects GetCurrNode() returns the nodes in the same
  // order.
  manager.AddNode(&node7);
  manager.AddNode(&node12);

  manager2.AddNode(&node12);
  manager2.AddNode(&node7);

  EXPECT_EQ(manager.GetCurrNode()->name(), manager2.GetCurrNode()->name());
  manager.RemoveCurrNode();
  manager2.RemoveCurrNode();
  EXPECT_EQ(manager.GetCurrNode()->name(), manager2.GetCurrNode()->name());
  manager.RemoveCurrNode();
  manager2.RemoveCurrNode();
  EXPECT_TRUE(manager.Empty());
}
643
// VirtualScheduler subclass used by the tests: wires a VirtualPlacer built
// from the cluster's devices into the base class and turns on memory-usage
// tracking, while exposing internals to the listed tests via FRIEND_TEST.
class TestVirtualScheduler : public VirtualScheduler {
 public:
  TestVirtualScheduler(const bool use_static_shapes,
                       const bool use_aggressive_shape_inference,
                       ReadyNodeManager* ready_node_manager, Cluster* cluster)
      : VirtualScheduler(
            use_static_shapes, use_aggressive_shape_inference, cluster,
            ready_node_manager,
            absl::make_unique<VirtualPlacer>(cluster->GetDevices())) {
    enable_mem_usage_tracking();
  }

  FRIEND_TEST(VirtualSchedulerTest, MemoryUsage);
  FRIEND_TEST(VirtualSchedulerTest, ControlDependency);
  FRIEND_TEST(VirtualSchedulerTest, ComplexDependency);
  FRIEND_TEST(VirtualSchedulerTest, Variable);
  FRIEND_TEST(VirtualSchedulerTest, InterDeviceTransfer);
};
663
664 class VirtualSchedulerTest : public ::testing::Test {
665 protected:
VirtualSchedulerTest()666 VirtualSchedulerTest() {
667 // Initializes cluster_ and scheduler_.
668 std::unordered_map<string, DeviceProperties> devices;
669
670 // Set some dummy CPU properties
671 DeviceProperties cpu_device = GetDummyCPUDevice();
672
673 // IMPORTANT: Device is not actually ever used in the test case since
674 // force_cpu_type is defaulted to "Haswell"
675 devices[kCPU0] = cpu_device;
676 devices[kCPU1] = cpu_device;
677 cluster_ = absl::make_unique<VirtualCluster>(devices);
678 scheduler_ = absl::make_unique<TestVirtualScheduler>(
679 /*use_static_shapes=*/true,
680 /*use_aggressive_shape_inference=*/true, &first_ready_manager_,
681 cluster_.get());
682 }
683
GetDummyCPUDevice()684 DeviceProperties GetDummyCPUDevice() {
685 // Create CPU with 2 cores, 4 Ghz freq, 2 GB/s mem bandwidth.
686 // - 8 Gflops
687 // - 2 GB/s
688 DeviceProperties cpu_device;
689 cpu_device.set_type("CPU");
690 cpu_device.set_frequency(4000);
691 cpu_device.set_num_cores(2);
692 cpu_device.set_bandwidth(2000000);
693 return cpu_device;
694 }
695
696 // Three Conv2Ds with only two in fetch nodes.
CreateGrapplerItemWithConv2Ds()697 void CreateGrapplerItemWithConv2Ds() {
698 Scope s = Scope::NewRootScope().WithDevice(kCPU0);
699 auto x = ops::RandomUniform(
700 s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
701 auto y = ops::RandomUniform(
702 s.WithOpName("y"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
703 auto z = ops::RandomUniform(
704 s.WithOpName("z"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
705 auto f = ops::RandomUniform(
706 s.WithOpName("f"), {kernel_, kernel_, depth_in_, depth_out_}, DT_FLOAT);
707 std::vector<int> strides = {1, 1, 1, 1};
708 auto c0 = ops::Conv2D(s.WithOpName("c0"), x, f, strides, "SAME");
709 auto c1 = ops::Conv2D(s.WithOpName("c1"), y, f, strides, "SAME");
710 auto c2 = ops::Conv2D(s.WithOpName("c2"), z, f, strides, "SAME");
711
712 grappler_item_ = absl::make_unique<GrapplerItem>();
713 TF_CHECK_OK(s.ToGraphDef(&grappler_item_->graph));
714 grappler_item_->id = "test_conv2d_graph";
715 grappler_item_->fetch = {"c0", "c1"};
716
717 dependency_["c0"] = {"x", "f"};
718 dependency_["c1"] = {"y", "f"};
719 }
720
721 // A Conv2D with a variable.
CreateGrapplerItemWithConv2DAndVariable()722 void CreateGrapplerItemWithConv2DAndVariable() {
723 Scope s = Scope::NewRootScope().WithDevice(kCPU0);
724 auto x = ops::RandomUniform(
725 s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
726 auto f = ops::Variable(s.WithOpName("f"),
727 {kernel_, kernel_, depth_in_, depth_out_}, DT_FLOAT);
728 std::vector<int> strides = {1, 1, 1, 1};
729 auto y = ops::Conv2D(s.WithOpName("y"), x, f, strides, "SAME");
730
731 grappler_item_ = absl::make_unique<GrapplerItem>();
732 TF_CHECK_OK(s.ToGraphDef(&grappler_item_->graph));
733 grappler_item_->id = "test_conv2d_var_graph";
734
735 grappler_item_->fetch = {"y"};
736
737 dependency_["y"] = {"x", "f"};
738 }
739
CreateGrapplerItemWithMatmulChain()740 void CreateGrapplerItemWithMatmulChain() {
741 Scope s = Scope::NewRootScope().WithDevice(kCPU0);
742 // Add control dependencies to ensure tests do not rely on specific
743 // manager and the order remains consistent for the test.
744 auto a = ops::RandomUniform(s.WithOpName("a"), {3200, 3200}, DT_FLOAT);
745 auto b = ops::RandomUniform(s.WithOpName("b").WithControlDependencies(a),
746 {3200, 3200}, DT_FLOAT);
747 auto c = ops::RandomUniform(s.WithOpName("c").WithControlDependencies(b),
748 {3200, 3200}, DT_FLOAT);
749 auto d = ops::RandomUniform(s.WithOpName("d").WithControlDependencies(c),
750 {3200, 3200}, DT_FLOAT);
751 auto e = ops::RandomUniform(s.WithOpName("e").WithControlDependencies(d),
752 {3200, 3200}, DT_FLOAT);
753
754 auto ab = ops::MatMul(s.WithOpName("ab").WithControlDependencies(e), a, b);
755 auto abc = ops::MatMul(s.WithOpName("abc"), ab, c);
756 auto abcd = ops::MatMul(s.WithOpName("abcd"), abc, d);
757 auto abcde = ops::MatMul(s.WithOpName("abcde"), abcd, e);
758
759 grappler_item_ = absl::make_unique<GrapplerItem>();
760 TF_CHECK_OK(s.ToGraphDef(&grappler_item_->graph));
761 grappler_item_->id = "test_matmul_sequence_graph";
762 grappler_item_->fetch = {"abcde"};
763
764 dependency_["ab"] = {"a", "b"};
765 dependency_["abc"] = {"ab", "c"};
766 dependency_["abcd"] = {"abc", "d"};
767 dependency_["abcde"] = {"abcd", "e"};
768 }
769
770 // AddN that takes 4 tensors with 10x10x10x10.
CreateGrapplerItemWithAddN()771 void CreateGrapplerItemWithAddN() {
772 Scope s = Scope::NewRootScope().WithDevice(kCPU0);
773 auto x = ops::RandomUniform(s.WithOpName("x"), {10, 10, 10, 10}, DT_FLOAT);
774 auto y = ops::RandomUniform(s.WithOpName("y"), {10, 10, 10, 10}, DT_FLOAT);
775 auto z = ops::RandomUniform(s.WithOpName("z"), {10, 10, 10, 10}, DT_FLOAT);
776 auto w = ops::RandomUniform(s.WithOpName("w"), {10, 10, 10, 10}, DT_FLOAT);
777 OutputList input_tensors = {x, y, z, w};
778 auto add = ops::AddN(s.WithOpName("add"), input_tensors);
779 auto out = ops::Identity(s.WithOpName("out"), add);
780
781 grappler_item_ = absl::make_unique<GrapplerItem>();
782 TF_CHECK_OK(s.ToGraphDef(&grappler_item_->graph));
783 grappler_item_->id = "test_addn_graph";
784 grappler_item_->fetch = {"out"};
785
786 dependency_["out"] = {"x", "y", "z", "w", "add"};
787 }
788
789 // Graph with some placeholder feed nodes that are not in the fetch fan-in.
CreateGrapplerItemWithUnnecessaryPlaceholderNodes()790 void CreateGrapplerItemWithUnnecessaryPlaceholderNodes() {
791 Scope s = Scope::NewRootScope().WithDevice(kCPU0);
792 auto unnecessary = ops::Placeholder(s.WithOpName("unnecessary"), DT_FLOAT);
793 auto x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT);
794
795 grappler_item_ = absl::make_unique<GrapplerItem>();
796 TF_CHECK_OK(s.ToGraphDef(&grappler_item_->graph));
797
798 grappler_item_->id = "test_extra_placeholders";
799 grappler_item_->fetch = {"x"};
800
801 // Grappler Item Builder puts all placeholder nodes into the feed
802 // list by default.
803 grappler_item_->feed = {{"x", Tensor()}, {"unnecessary", Tensor()}};
804 }
805
806 // NoOp that takes 7 NoOps as control dependency.
CreateGrapplerItemWithControlDependency()807 void CreateGrapplerItemWithControlDependency() {
808 Scope s = Scope::NewRootScope().WithDevice(kCPU0);
809 std::vector<string> input_noop_names = {"x", "y", "z", "w", "u", "v", "t"};
810 std::vector<Operation> input_tensors;
811 for (const auto& input : input_noop_names) {
812 auto x = ops::NoOp(s.WithOpName(input));
813 input_tensors.push_back(x.operation);
814 }
815 auto out =
816 ops::NoOp(s.WithControlDependencies(input_tensors).WithOpName("out"));
817
818 grappler_item_ = absl::make_unique<GrapplerItem>();
819 TF_CHECK_OK(s.ToGraphDef(&grappler_item_->graph));
820
821 grappler_item_->id = "test_control_dependency_graph";
822 grappler_item_->fetch = {"out"};
823
824 dependency_["out"] = input_noop_names;
825 }
826
CreateGrapplerItemWithAddFromOneTensor()827 void CreateGrapplerItemWithAddFromOneTensor() {
828 Scope s = Scope::NewRootScope().WithDevice(kCPU0);
829 auto x = tensorflow::ops::RandomUniform(
830 s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
831
832 auto y = tensorflow::ops::Add(s.WithOpName("y"), x, x);
833 Output fetch = ops::Identity(s.WithOpName("fetch"), y);
834
835 grappler_item_ = absl::make_unique<GrapplerItem>();
836 TF_CHECK_OK(s.ToGraphDef(&grappler_item_->graph));
837
838 grappler_item_->id = "test_add_from_one_tensor";
839 grappler_item_->fetch = {"fetch"};
840
841 dependency_["fetch"] = {"y"};
842 dependency_["y"] = {"x"};
843 }
844
CreateGrapplerItemWithSwitchMergeInput()845 void CreateGrapplerItemWithSwitchMergeInput() {
846 // sw = Switch(x, pred)
847 // a = Add(S:1, b)
848 // m = Merge(sw:0, a)
849 // y = Add(m, z)
850
851 Scope s = Scope::NewRootScope().WithDevice(kCPU0);
852 auto x = ops::RandomUniform(
853 s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
854 auto pred = ops::Const(s.WithOpName("pred"), false, {});
855 auto sw = ops::Switch(s.WithOpName("switch"), x, pred);
856 auto b = ops::RandomUniform(
857 s.WithOpName("b"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
858 auto a = ops::Add(s.WithOpName("a"), sw.output_true, b);
859 auto m = ops::Merge(s.WithOpName("m"), {sw.output_false, a.z});
860 auto z = ops::RandomUniform(
861 s.WithOpName("z"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
862 auto y = ops::Add(s.WithOpName("y"), m.output, z);
863
864 grappler_item_ = absl::make_unique<GrapplerItem>();
865 TF_CHECK_OK(s.ToGraphDef(&grappler_item_->graph));
866
867 grappler_item_->id = "test_add_merge_switch";
868 grappler_item_->fetch = {"y"};
869
870 dependency_["y"] = {"m", "z"};
871 }
872
873 // FusedBN [an op with multiple outputs] with multiple consumers (including
874 // control dependency).
CreateGrapplerItemWithBatchNorm()875 void CreateGrapplerItemWithBatchNorm() {
876 Scope s = Scope::NewRootScope().WithDevice(kCPU0);
877 auto x = ops::RandomUniform(
878 s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
879 auto scale =
880 ops::RandomUniform(s.WithOpName("scale"), {depth_in_}, DT_FLOAT);
881 auto offset =
882 ops::RandomUniform(s.WithOpName("offset"), {depth_in_}, DT_FLOAT);
883 auto mean = ops::RandomUniform(s.WithOpName("mean"), {0}, DT_FLOAT);
884 auto var = ops::RandomUniform(s.WithOpName("var"), {0}, DT_FLOAT);
885
886 auto batch_norm = ops::FusedBatchNorm(
887 s.WithOpName("bn"), x, scale, offset, mean, var,
888 ops::FusedBatchNorm::IsTraining(true).Epsilon(0.1f));
889 auto y = batch_norm.y;
890 auto batch_mean = batch_norm.batch_mean;
891 auto batch_var = batch_norm.batch_variance;
892
893 auto z1 = ops::Add(s.WithOpName("z1"), x, y);
894 auto z2 = ops::Add(s.WithOpName("z2"), batch_var, batch_var);
895 auto z3 = ops::Add(s.WithOpName("z3"), batch_var, batch_var);
896 std::vector<Operation> input_tensors = {
897 batch_mean.op(),
898 z1.z.op(),
899 z2.z.op(),
900 z3.z.op(),
901 };
902 auto z4 = ops::NoOp(s.WithControlDependencies(batch_var).WithOpName("z4"));
903
904 grappler_item_ = absl::make_unique<GrapplerItem>();
905 TF_CHECK_OK(s.ToGraphDef(&grappler_item_->graph));
906
907 grappler_item_->id = "test_complex_dependency_graph";
908 grappler_item_->fetch = {"z1", "z2", "z3", "z4"};
909
910 dependency_["bn"] = {"x", "scale", "offset", "mean", "var"};
911 dependency_["z1"] = {"x", "bn"};
912 dependency_["z2"] = {"bn"};
913 dependency_["z3"] = {"bn"};
914 dependency_["z4"] = {"bn"};
915 }
916
  // Const -> _Send -> _Recv, all pinned to CPU:0, hand-written as a GraphDef
  // in text format. Fetches "Recv", so the scheduler must handle explicit
  // Send/Recv nodes that already exist in the input graph.
  void CreateGrapplerItemWithSendRecv() {
    const string gdef_ascii = R"EOF(
    node {
      name: "Const"
      op: "Const"
      device: "/job:localhost/replica:0/task:0/device:CPU:0"
      attr {
        key: "dtype"
        value {
          type: DT_FLOAT
        }
      }
      attr {
        key: "_output_shapes"
        value {
          list { shape {
            dim { size: 128 }
            dim { size: 32 }
          }}}
      }
      attr {
        key: "shape"
        value {
          list { shape {
            dim { size: 128 }
            dim { size: 32 }
          }}}
      }
      attr {
        key: "value"
        value {
          tensor {
            dtype: DT_FLOAT
            tensor_shape {
              dim { size: 128 }
              dim { size: 32 }
            }
            float_val: 3.1415
          }
        }
      }
    }
    node {
      name: "Send"
      op: "_Send"
      input: "Const"
      device: "/job:localhost/replica:0/task:0/device:CPU:0"
      attr {
        key: "T"
        value {
          type: DT_FLOAT
        }
      }
      attr {
        key: "_output_shapes"
        value {
          list { shape {
            dim { size: 128 }
            dim { size: 32 }
          }}}
      }
      attr {
        key: "shape"
        value {
          list { shape {
            dim { size: 128 }
            dim { size: 32 }
          }}}
      }
      attr {
        key: "client_terminated"
        value {
          b: false
        }
      }
      attr {
        key: "recv_device"
        value {
          s: "/job:localhost/replica:0/task:0/device:CPU:0"
        }
      }
      attr {
        key: "send_device"
        value {
          s: "/job:localhost/replica:0/task:0/device:CPU:0"
        }
      }
      attr {
        key: "send_device_incarnation"
        value {
          i: 0
        }
      }
      attr {
        key: "tensor_name"
        value {
          s: "test"
        }
      }
    }
    node {
      name: "Recv"
      op: "_Recv"
      device: "/job:localhost/replica:0/task:0/device:CPU:0"
      attr {
        key: "client_terminated"
        value {
          b: false
        }
      }
      attr {
        key: "_output_shapes"
        value {
          list { shape {
            dim { size: 128 }
            dim { size: 32 }
          }}}
      }
      attr {
        key: "shape"
        value {
          list { shape {
            dim { size: 128 }
            dim { size: 32 }
          }}}
      }
      attr {
        key: "recv_device"
        value {
          s: "/job:localhost/replica:0/task:0/device:CPU:0"
        }
      }
      attr {
        key: "send_device"
        value {
          s: "/job:localhost/replica:0/task:0/device:CPU:0"
        }
      }
      attr {
        key: "send_device_incarnation"
        value {
          i: 0
        }
      }
      attr {
        key: "tensor_name"
        value {
          s: "test"
        }
      }
      attr {
        key: "tensor_type"
        value {
          type: DT_FLOAT
        }
      }
    }
    library {
    }
    versions {
      producer: 24
    }
    )EOF";

    grappler_item_ = absl::make_unique<GrapplerItem>();

    // CHECK (not TF_CHECK_OK): a parse failure means the hand-written graph
    // text above is malformed, which should abort the test immediately.
    CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii,
                                                &grappler_item_->graph));
    grappler_item_->id = "test_graph";
    grappler_item_->fetch = {"Recv"};
  }
1088
  // A lone _Recv with no matching _Send in the graph; fetches "Recv" so the
  // scheduler is exercised with a Recv whose producer is absent.
  void CreateGrapplerItemWithRecvWithoutSend() {
    const string gdef_ascii = R"EOF(
    node {
      name: "Recv"
      op: "_Recv"
      device: "/job:localhost/replica:0/task:0/device:CPU:0"
      attr {
        key: "client_terminated"
        value {
          b: false
        }
      }
      attr {
        key: "recv_device"
        value {
          s: "/job:localhost/replica:0/task:0/device:CPU:0"
        }
      }
      attr {
        key: "send_device"
        value {
          s: "/job:localhost/replica:0/task:0/device:CPU:0"
        }
      }
      attr {
        key: "send_device_incarnation"
        value {
          i: 0
        }
      }
      attr {
        key: "tensor_name"
        value {
          s: "test"
        }
      }
      attr {
        key: "tensor_type"
        value {
          type: DT_FLOAT
        }
      }
    }
    library {
    }
    versions {
      producer: 24
    }
    )EOF";

    grappler_item_ = absl::make_unique<GrapplerItem>();
    // Abort the test if the hand-written graph text is malformed.
    CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii,
                                                &grappler_item_->graph));
    grappler_item_->id = "test_graph";
    grappler_item_->fetch = {"Recv"};
  }
1145
  // A simple while loop: i = 0; while (i < 10) { i += 1; m = concat(m, m); }
  // expressed with the raw control-flow ops (Enter/Merge/Switch/Identity/
  // NextIteration/Exit). Fetches both loop results via the Exit nodes.
  void CreateGrapplerItemWithLoop() {
    // Test graph produced in python using:
    /*
       with tf.Graph().as_default():
         i0 = tf.constant(0)
         m0 = tf.ones([2, 2])
         c = lambda i, m: i < 10
         b = lambda i, m: [i+1, tf.concat([m, m], axis=0)]
         r = tf.while_loop(
             c, b, loop_vars=[i0, m0],
             shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])])
         with open('/tmp/graph.pbtxt', 'w') as f:
           f.write(str(tf.get_default_graph().as_graph_def()))
    */
    const string gdef_ascii = R"EOF(
    node {
      name: "Const"
      op: "Const"
      attr {
        key: "dtype"
        value {
          type: DT_INT32
        }
      }
      attr {
        key: "value"
        value {
          tensor {
            dtype: DT_INT32
            tensor_shape {
            }
            int_val: 0
          }
        }
      }
    }
    node {
      name: "ones"
      op: "Const"
      attr {
        key: "dtype"
        value {
          type: DT_FLOAT
        }
      }
      attr {
        key: "value"
        value {
          tensor {
            dtype: DT_FLOAT
            tensor_shape {
              dim {
                size: 2
              }
              dim {
                size: 2
              }
            }
            float_val: 1.0
          }
        }
      }
    }
    node {
      name: "while/Enter"
      op: "Enter"
      input: "Const"
      attr {
        key: "T"
        value {
          type: DT_INT32
        }
      }
      attr {
        key: "frame_name"
        value {
          s: "while/while/"
        }
      }
      attr {
        key: "is_constant"
        value {
          b: false
        }
      }
      attr {
        key: "parallel_iterations"
        value {
          i: 10
        }
      }
    }
    node {
      name: "while/Enter_1"
      op: "Enter"
      input: "ones"
      attr {
        key: "T"
        value {
          type: DT_FLOAT
        }
      }
      attr {
        key: "frame_name"
        value {
          s: "while/while/"
        }
      }
      attr {
        key: "is_constant"
        value {
          b: false
        }
      }
      attr {
        key: "parallel_iterations"
        value {
          i: 10
        }
      }
    }
    node {
      name: "while/Merge"
      op: "Merge"
      input: "while/Enter"
      input: "while/NextIteration"
      attr {
        key: "N"
        value {
          i: 2
        }
      }
      attr {
        key: "T"
        value {
          type: DT_INT32
        }
      }
    }
    node {
      name: "while/Merge_1"
      op: "Merge"
      input: "while/Enter_1"
      input: "while/NextIteration_1"
      attr {
        key: "N"
        value {
          i: 2
        }
      }
      attr {
        key: "T"
        value {
          type: DT_FLOAT
        }
      }
    }
    node {
      name: "while/Less/y"
      op: "Const"
      input: "^while/Merge"
      attr {
        key: "dtype"
        value {
          type: DT_INT32
        }
      }
      attr {
        key: "value"
        value {
          tensor {
            dtype: DT_INT32
            tensor_shape {
            }
            int_val: 10
          }
        }
      }
    }
    node {
      name: "while/Less"
      op: "Less"
      input: "while/Merge"
      input: "while/Less/y"
      attr {
        key: "T"
        value {
          type: DT_INT32
        }
      }
    }
    node {
      name: "while/LoopCond"
      op: "LoopCond"
      input: "while/Less"
    }
    node {
      name: "while/Switch"
      op: "Switch"
      input: "while/Merge"
      input: "while/LoopCond"
      attr {
        key: "T"
        value {
          type: DT_INT32
        }
      }
      attr {
        key: "_class"
        value {
          list {
            s: "loc:@while/Merge"
          }
        }
      }
    }
    node {
      name: "while/Switch_1"
      op: "Switch"
      input: "while/Merge_1"
      input: "while/LoopCond"
      attr {
        key: "T"
        value {
          type: DT_FLOAT
        }
      }
      attr {
        key: "_class"
        value {
          list {
            s: "loc:@while/Merge_1"
          }
        }
      }
    }
    node {
      name: "while/Identity"
      op: "Identity"
      input: "while/Switch:1"
      attr {
        key: "T"
        value {
          type: DT_INT32
        }
      }
    }
    node {
      name: "while/Identity_1"
      op: "Identity"
      input: "while/Switch_1:1"
      attr {
        key: "T"
        value {
          type: DT_FLOAT
        }
      }
    }
    node {
      name: "while/add/y"
      op: "Const"
      input: "^while/Identity"
      attr {
        key: "dtype"
        value {
          type: DT_INT32
        }
      }
      attr {
        key: "value"
        value {
          tensor {
            dtype: DT_INT32
            tensor_shape {
            }
            int_val: 1
          }
        }
      }
    }
    node {
      name: "while/add"
      op: "Add"
      input: "while/Identity"
      input: "while/add/y"
      attr {
        key: "T"
        value {
          type: DT_INT32
        }
      }
    }
    node {
      name: "while/concat/axis"
      op: "Const"
      input: "^while/Identity"
      attr {
        key: "dtype"
        value {
          type: DT_INT32
        }
      }
      attr {
        key: "value"
        value {
          tensor {
            dtype: DT_INT32
            tensor_shape {
            }
            int_val: 0
          }
        }
      }
    }
    node {
      name: "while/concat"
      op: "ConcatV2"
      input: "while/Identity_1"
      input: "while/Identity_1"
      input: "while/concat/axis"
      attr {
        key: "N"
        value {
          i: 2
        }
      }
      attr {
        key: "T"
        value {
          type: DT_FLOAT
        }
      }
      attr {
        key: "Tidx"
        value {
          type: DT_INT32
        }
      }
    }
    node {
      name: "while/NextIteration"
      op: "NextIteration"
      input: "while/add"
      attr {
        key: "T"
        value {
          type: DT_INT32
        }
      }
    }
    node {
      name: "while/NextIteration_1"
      op: "NextIteration"
      input: "while/concat"
      attr {
        key: "T"
        value {
          type: DT_FLOAT
        }
      }
    }
    node {
      name: "while/Exit"
      op: "Exit"
      input: "while/Switch"
      attr {
        key: "T"
        value {
          type: DT_INT32
        }
      }
    }
    node {
      name: "while/Exit_1"
      op: "Exit"
      input: "while/Switch_1"
      attr {
        key: "T"
        value {
          type: DT_FLOAT
        }
      }
    }
    versions {
      producer: 21
    }
    )EOF";

    grappler_item_ = absl::make_unique<GrapplerItem>();
    // Abort the test if the hand-written graph text is malformed.
    CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii,
                                                &grappler_item_->graph));
    grappler_item_->id = "test_graph";
    grappler_item_->fetch = {"while/Exit", "while/Exit_1"};
  }
1541
1542 // A simple while loop strengthened with Switch outputs xxx.
CreateGrapplerItemWithLoopAnnotated()1543 void CreateGrapplerItemWithLoopAnnotated() {
1544 // Test graph produced in python using:
1545 /*
1546 with tf.Graph().as_default():
1547 i0 = tf.constant(0)
1548 m0 = tf.ones([2, 2])
1549 c = lambda i, m: i < 10
1550 b = lambda i, m: [i+1, tf.concat([m, m], axis=0)]
1551 r = tf.while_loop(
1552 c, b, loop_vars=[i0, m0],
1553 shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])])
1554 with open('/tmp/graph.pbtxt', 'w') as f:
1555 f.write(str(tf.get_default_graph().as_graph_def()))
1556 */
1557 const string gdef_ascii = R"EOF(
1558 node {
1559 name: "Const"
1560 op: "Const"
1561 attr {
1562 key: "dtype"
1563 value {
1564 type: DT_INT32
1565 }
1566 }
1567 attr {
1568 key: "value"
1569 value {
1570 tensor {
1571 dtype: DT_INT32
1572 tensor_shape {
1573 }
1574 int_val: 0
1575 }
1576 }
1577 }
1578 attr {
1579 key: "_execution_count"
1580 value {
1581 i: 1
1582 }
1583 }
1584 }
1585 node {
1586 name: "ones"
1587 op: "Const"
1588 attr {
1589 key: "dtype"
1590 value {
1591 type: DT_FLOAT
1592 }
1593 }
1594 attr {
1595 key: "value"
1596 value {
1597 tensor {
1598 dtype: DT_FLOAT
1599 tensor_shape {
1600 dim {
1601 size: 2
1602 }
1603 dim {
1604 size: 2
1605 }
1606 }
1607 float_val: 1.0
1608 }
1609 }
1610 }
1611 attr {
1612 key: "_execution_count"
1613 value {
1614 i: 1
1615 }
1616 }
1617 }
1618 node {
1619 name: "while/Enter"
1620 op: "Enter"
1621 input: "Const"
1622 attr {
1623 key: "T"
1624 value {
1625 type: DT_INT32
1626 }
1627 }
1628 attr {
1629 key: "frame_name"
1630 value {
1631 s: "while/while/"
1632 }
1633 }
1634 attr {
1635 key: "is_constant"
1636 value {
1637 b: false
1638 }
1639 }
1640 attr {
1641 key: "parallel_iterations"
1642 value {
1643 i: 10
1644 }
1645 }
1646 attr {
1647 key: "_execution_count"
1648 value {
1649 i: 1
1650 }
1651 }
1652 }
1653 node {
1654 name: "while/Enter_1"
1655 op: "Enter"
1656 input: "ones"
1657 attr {
1658 key: "T"
1659 value {
1660 type: DT_FLOAT
1661 }
1662 }
1663 attr {
1664 key: "frame_name"
1665 value {
1666 s: "while/while/"
1667 }
1668 }
1669 attr {
1670 key: "is_constant"
1671 value {
1672 b: false
1673 }
1674 }
1675 attr {
1676 key: "parallel_iterations"
1677 value {
1678 i: 10
1679 }
1680 }
1681 attr {
1682 key: "_execution_count"
1683 value {
1684 i: 1
1685 }
1686 }
1687 }
1688 node {
1689 name: "while/Merge"
1690 op: "Merge"
1691 input: "while/Enter"
1692 input: "while/NextIteration"
1693 attr {
1694 key: "N"
1695 value {
1696 i: 2
1697 }
1698 }
1699 attr {
1700 key: "T"
1701 value {
1702 type: DT_INT32
1703 }
1704 }
1705 attr {
1706 key: "_execution_count"
1707 value {
1708 i: 10
1709 }
1710 }
1711 }
1712 node {
1713 name: "while/Merge_1"
1714 op: "Merge"
1715 input: "while/Enter_1"
1716 input: "while/NextIteration_1"
1717 attr {
1718 key: "N"
1719 value {
1720 i: 2
1721 }
1722 }
1723 attr {
1724 key: "T"
1725 value {
1726 type: DT_FLOAT
1727 }
1728 }
1729 attr {
1730 key: "_execution_count"
1731 value {
1732 i: 10
1733 }
1734 }
1735 }
1736 node {
1737 name: "while/Less/y"
1738 op: "Const"
1739 input: "^while/Merge"
1740 attr {
1741 key: "dtype"
1742 value {
1743 type: DT_INT32
1744 }
1745 }
1746 attr {
1747 key: "value"
1748 value {
1749 tensor {
1750 dtype: DT_INT32
1751 tensor_shape {
1752 }
1753 int_val: 10
1754 }
1755 }
1756 }
1757 attr {
1758 key: "_execution_count"
1759 value {
1760 i: 10
1761 }
1762 }
1763 }
1764 node {
1765 name: "while/Less"
1766 op: "Less"
1767 input: "while/Merge"
1768 input: "while/Less/y"
1769 attr {
1770 key: "T"
1771 value {
1772 type: DT_INT32
1773 }
1774 }
1775 attr {
1776 key: "_execution_count"
1777 value {
1778 i: 10
1779 }
1780 }
1781 }
1782 node {
1783 name: "while/LoopCond"
1784 op: "LoopCond"
1785 input: "while/Less"
1786 attr {
1787 key: "_execution_count"
1788 value {
1789 i: 10
1790 }
1791 }
1792 }
1793 node {
1794 name: "while/Switch"
1795 op: "Switch"
1796 input: "while/Merge"
1797 input: "while/LoopCond"
1798 attr {
1799 key: "T"
1800 value {
1801 type: DT_INT32
1802 }
1803 }
1804 attr {
1805 key: "_class"
1806 value {
1807 list {
1808 s: "loc:@while/Merge"
1809 }
1810 }
1811 }
1812 attr {
1813 key: "_execution_count"
1814 value {
1815 i: 11
1816 }
1817 }
1818 attr {
1819 key: "_output_slot_vector"
1820 value {
1821 list {
1822 i: 1
1823 i: 1
1824 i: 1
1825 i: 1
1826 i: 1
1827 i: 1
1828 i: 1
1829 i: 1
1830 i: 1
1831 i: 1
1832 i: 0
1833 }
1834 }
1835 }
1836 }
1837 node {
1838 name: "while/Switch_1"
1839 op: "Switch"
1840 input: "while/Merge_1"
1841 input: "while/LoopCond"
1842 attr {
1843 key: "T"
1844 value {
1845 type: DT_FLOAT
1846 }
1847 }
1848 attr {
1849 key: "_class"
1850 value {
1851 list {
1852 s: "loc:@while/Merge_1"
1853 }
1854 }
1855 }
1856 attr {
1857 key: "_execution_count"
1858 value {
1859 i: 11
1860 }
1861 }
1862 attr {
1863 key: "_output_slot_vector"
1864 value {
1865 list {
1866 i: 1
1867 i: 1
1868 i: 1
1869 i: 1
1870 i: 1
1871 i: 1
1872 i: 1
1873 i: 1
1874 i: 1
1875 i: 1
1876 i: 0
1877 }
1878 }
1879 }
1880 }
1881 node {
1882 name: "while/Identity"
1883 op: "Identity"
1884 input: "while/Switch:1"
1885 attr {
1886 key: "T"
1887 value {
1888 type: DT_INT32
1889 }
1890 }
1891 attr {
1892 key: "_execution_count"
1893 value {
1894 i: 10
1895 }
1896 }
1897 }
1898 node {
1899 name: "while/Identity_1"
1900 op: "Identity"
1901 input: "while/Switch_1:1"
1902 attr {
1903 key: "T"
1904 value {
1905 type: DT_FLOAT
1906 }
1907 }
1908 attr {
1909 key: "_execution_count"
1910 value {
1911 i: 10
1912 }
1913 }
1914 }
1915 node {
1916 name: "while/add/y"
1917 op: "Const"
1918 input: "^while/Identity"
1919 attr {
1920 key: "dtype"
1921 value {
1922 type: DT_INT32
1923 }
1924 }
1925 attr {
1926 key: "value"
1927 value {
1928 tensor {
1929 dtype: DT_INT32
1930 tensor_shape {
1931 }
1932 int_val: 1
1933 }
1934 }
1935 }
1936 attr {
1937 key: "_execution_count"
1938 value {
1939 i: 10
1940 }
1941 }
1942 }
1943 node {
1944 name: "while/add"
1945 op: "Add"
1946 input: "while/Identity"
1947 input: "while/add/y"
1948 attr {
1949 key: "T"
1950 value {
1951 type: DT_INT32
1952 }
1953 }
1954 attr {
1955 key: "_execution_count"
1956 value {
1957 i: 10
1958 }
1959 }
1960 }
1961 node {
1962 name: "while/concat/axis"
1963 op: "Const"
1964 input: "^while/Identity"
1965 attr {
1966 key: "dtype"
1967 value {
1968 type: DT_INT32
1969 }
1970 }
1971 attr {
1972 key: "value"
1973 value {
1974 tensor {
1975 dtype: DT_INT32
1976 tensor_shape {
1977 }
1978 int_val: 0
1979 }
1980 }
1981 }
1982 attr {
1983 key: "_execution_count"
1984 value {
1985 i: 10
1986 }
1987 }
1988 }
1989 node {
1990 name: "while/concat"
1991 op: "ConcatV2"
1992 input: "while/Identity_1"
1993 input: "while/Identity_1"
1994 input: "while/concat/axis"
1995 attr {
1996 key: "N"
1997 value {
1998 i: 2
1999 }
2000 }
2001 attr {
2002 key: "T"
2003 value {
2004 type: DT_FLOAT
2005 }
2006 }
2007 attr {
2008 key: "Tidx"
2009 value {
2010 type: DT_INT32
2011 }
2012 }
2013 attr {
2014 key: "_execution_count"
2015 value {
2016 i: 10
2017 }
2018 }
2019 }
2020 node {
2021 name: "while/NextIteration"
2022 op: "NextIteration"
2023 input: "while/add"
2024 attr {
2025 key: "T"
2026 value {
2027 type: DT_INT32
2028 }
2029 }
2030 attr {
2031 key: "_execution_count"
2032 value {
2033 i: 10
2034 }
2035 }
2036 }
2037 node {
2038 name: "while/NextIteration_1"
2039 op: "NextIteration"
2040 input: "while/concat"
2041 attr {
2042 key: "T"
2043 value {
2044 type: DT_FLOAT
2045 }
2046 }
2047 attr {
2048 key: "_execution_count"
2049 value {
2050 i: 10
2051 }
2052 }
2053 }
2054 node {
2055 name: "while/Exit"
2056 op: "Exit"
2057 input: "while/Switch"
2058 attr {
2059 key: "T"
2060 value {
2061 type: DT_INT32
2062 }
2063 }
2064 attr {
2065 key: "_execution_count"
2066 value {
2067 i: 1
2068 }
2069 }
2070 }
2071 node {
2072 name: "while/Exit_1"
2073 op: "Exit"
2074 input: "while/Switch_1"
2075 attr {
2076 key: "T"
2077 value {
2078 type: DT_FLOAT
2079 }
2080 }
2081 attr {
2082 key: "_execution_count"
2083 value {
2084 i: 1
2085 }
2086 }
2087 }
2088 versions {
2089 producer: 21
2090 }
2091 )EOF";
2092
2093 grappler_item_.reset(new GrapplerItem);
2094 CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii,
2095 &grappler_item_->graph));
2096 grappler_item_->id = "test_graph";
2097 grappler_item_->fetch = {"while/Exit", "while/Exit_1"};
2098 }
2099
  // A simple condition graph.
  void CreateGrapplerItemWithCondition() {
    // Handcrafted test graph: a/Less -> Switch -> First/Second -> Merge.
    // "Less" is a constant true predicate, so Switch routes "a" to its
    // output port 1 (consumed by "Second").
    const string gdef_ascii = R"EOF(
    node {
      name: "a"
      op: "Const"
      attr {
        key: "dtype"
        value {
          type: DT_FLOAT
        }
      }
      attr {
        key: "value"
        value {
          tensor {
            dtype: DT_FLOAT
            tensor_shape {
            }
            float_val: 2.0
          }
        }
      }
    }
    node {
      name: "Less"
      op: "Const"
      attr {
        key: "dtype"
        value {
          type: DT_BOOL
        }
      }
      attr {
        key: "value"
        value {
          tensor {
            dtype: DT_BOOL
            tensor_shape {
            }
            tensor_content: "\001"
          }
        }
      }
    }
    node {
      name: "Switch"
      op: "Switch"
      input: "a"
      input: "Less"
      attr {
        key: "T"
        value {
          type: DT_FLOAT
        }
      }
    }
    node {
      name: "First"
      op: "Identity"
      input: "Switch"
      attr {
        key: "T"
        value {
          type: DT_FLOAT
        }
      }
    }
    node {
      name: "Second"
      op: "Identity"
      input: "Switch:1"
      attr {
        key: "T"
        value {
          type: DT_FLOAT
        }
      }
    }
    node {
      name: "Merge"
      op: "Merge"
      input: "First"
      input: "Second"
      attr {
        key: "N"
        value {
          i: 2
        }
      }
      attr {
        key: "T"
        value {
          type: DT_FLOAT
        }
      }
    }
    versions {
      producer: 27
    })EOF";

    grappler_item_ = absl::make_unique<GrapplerItem>();
    // Abort the test if the hand-written graph text is malformed.
    CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii,
                                                &grappler_item_->graph));
    grappler_item_->id = "test_graph";
    grappler_item_->fetch = {"Merge"};
  }
2208
2209 // Create a FusedBatchNorm op that has multiple output ports.
CreateGrapplerItemWithInterDeviceTransfers()2210 void CreateGrapplerItemWithInterDeviceTransfers() {
2211 tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice(kCPU0);
2212
2213 // Create a FusedBatchNorm op that has multiple output ports.
2214 auto x = ops::RandomUniform(
2215 s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
2216 auto scale =
2217 ops::RandomUniform(s.WithOpName("scale"), {depth_in_}, DT_FLOAT);
2218 auto offset =
2219 ops::RandomUniform(s.WithOpName("offset"), {depth_in_}, DT_FLOAT);
2220 auto mean = ops::RandomUniform(s.WithOpName("mean"), {0}, DT_FLOAT);
2221 auto var = ops::RandomUniform(s.WithOpName("var"), {0}, DT_FLOAT);
2222
2223 auto batch_norm = ops::FusedBatchNorm(
2224 s.WithOpName("bn"), x, scale, offset, mean, var,
2225 ops::FusedBatchNorm::IsTraining(true).Epsilon(0.1f));
2226 auto y = batch_norm.y;
2227 auto batch_mean = batch_norm.batch_mean;
2228 auto batch_var = batch_norm.batch_variance;
2229 // y1 and y2 take the same tensor, so there should be only 1 Send and Recv.
2230 auto y1 = ops::Identity(s.WithOpName("y1").WithDevice(kCPU1), y);
2231 auto y2 = ops::Identity(s.WithOpName("y2").WithDevice(kCPU1), y);
2232 // batch_mean1 and batch_var1 take different output ports, so each will
2233 // initiate Send/Recv.
2234 auto batch_mean1 = ops::Identity(
2235 s.WithOpName("batch_mean1").WithDevice(kCPU1), batch_mean);
2236 auto batch_var1 =
2237 ops::Identity(s.WithOpName("batch_var1").WithDevice(kCPU1), batch_var);
2238 // This is control dependency.
2239 auto control_dep = ops::NoOp(s.WithOpName("control_dep")
2240 .WithControlDependencies(y)
2241 .WithDevice(kCPU1));
2242
2243 grappler_item_ = absl::make_unique<GrapplerItem>();
2244 TF_CHECK_OK(s.ToGraphDef(&grappler_item_->graph));
2245 grappler_item_->id = "test_conv2d_graph";
2246 grappler_item_->fetch = {"y1", "y2", "batch_mean1", "batch_var1",
2247 "control_dep"};
2248
2249 dependency_["bn"] = {"x", "mean", "var"};
2250 dependency_["y1"] = {"bn"};
2251 dependency_["y2"] = {"bn"};
2252 dependency_["batch_mean1"] = {"bn"};
2253 dependency_["batch_var1"] = {"bn"};
2254 dependency_["control_dep"] = {"bn"};
2255 }
2256
2257 // Call this after creating grappler_item_ and setting up dependency_.
InitScheduler()2258 void InitScheduler() { TF_ASSERT_OK(scheduler_->Init(grappler_item_.get())); }
2259
2260 // Returns cost based on op.
SimplePredictCosts(const OpContext & op_context) const2261 Costs SimplePredictCosts(const OpContext& op_context) const {
2262 Costs c;
2263 int64 exec_cost = 0;
2264 if (op_context.op_info.op() == "MatMul") {
2265 exec_cost = 2000000000;
2266 } else if (op_context.op_info.op() == "RandomUniform") {
2267 exec_cost = 1000000000;
2268 } else {
2269 exec_cost = 1000;
2270 }
2271 c.execution_time = Costs::NanoSeconds(exec_cost);
2272 return c;
2273 }
2274
2275 // Call this after init scheduler_. Scheduler stops after executing
2276 // target_node.
RunScheduler(const string & target_node)2277 std::unordered_map<string, OpContext> RunScheduler(
2278 const string& target_node) {
2279 std::unordered_map<string, OpContext> ops_executed;
2280 bool more_nodes = true;
2281 do {
2282 OpContext op_context = scheduler_->GetCurrNode();
2283 ops_executed[op_context.name] = op_context;
2284 std::cout << op_context.name << std::endl;
2285
2286 Costs node_costs = SimplePredictCosts(op_context);
2287
2288 // Check scheduling order.
2289 auto it = dependency_.find(op_context.name);
2290 if (it != dependency_.end()) {
2291 for (const auto& preceding_node : it->second) {
2292 EXPECT_GT(ops_executed.count(preceding_node), 0);
2293 }
2294 }
2295 more_nodes = scheduler_->MarkCurrNodeExecuted(node_costs);
2296
2297 if (op_context.name == target_node) {
2298 // Scheduler has the state after executing the target node.
2299 break;
2300 }
2301 } while (more_nodes);
2302 return ops_executed;
2303 }
2304
2305 // Helper method for validating a vector.
2306 template <typename T>
ExpectVectorEq(const std::vector<T> & expected,const std::vector<T> & test_elements)2307 void ExpectVectorEq(const std::vector<T>& expected,
2308 const std::vector<T>& test_elements) {
2309 // Set of expected elements for an easy comparison.
2310 std::set<T> expected_set(expected.begin(), expected.end());
2311 for (const auto& element : test_elements) {
2312 EXPECT_GT(expected_set.count(element), 0);
2313 }
2314 EXPECT_EQ(expected.size(), test_elements.size());
2315 }
2316
2317 // Helper method that checks the name of nodes.
ValidateNodeDefs(const std::vector<string> & expected,const std::vector<const NodeDef * > & node_defs)2318 void ValidateNodeDefs(const std::vector<string>& expected,
2319 const std::vector<const NodeDef*>& node_defs) {
2320 std::vector<string> node_names;
2321 std::transform(node_defs.begin(), node_defs.end(),
2322 std::back_inserter(node_names),
2323 [](const NodeDef* node) { return node->name(); });
2324 ExpectVectorEq(expected, node_names);
2325 }
2326
2327 // Helper method for validating a set.
2328 template <typename T>
ExpectSetEq(const std::set<T> & expected,const std::set<T> & test_elements)2329 void ExpectSetEq(const std::set<T>& expected,
2330 const std::set<T>& test_elements) {
2331 for (const auto& element : test_elements) {
2332 EXPECT_GT(expected.count(element), 0);
2333 }
2334 EXPECT_EQ(expected.size(), test_elements.size());
2335 }
2336
2337 // Helper method for validating an unordered map.
2338 template <typename T, typename U>
ExpectUnorderedMapEq(const std::unordered_map<T,U> & expected,const std::unordered_map<T,U> & test_map)2339 void ExpectUnorderedMapEq(const std::unordered_map<T, U>& expected,
2340 const std::unordered_map<T, U>& test_map) {
2341 EXPECT_EQ(expected.size(), test_map.size());
2342 for (const auto& key_val : expected) {
2343 EXPECT_GT(test_map.count(key_val.first), 0);
2344 EXPECT_EQ(test_map.at(key_val.first), key_val.second);
2345 }
2346 }
2347
2348 // Helper method that checks name - port pairs.
ValidateMemoryUsageSnapshot(const std::vector<string> & expected_names,const int port_num_expected,const std::unordered_set<std::pair<const NodeDef *,int>,DeviceState::NodePairHash> & mem_usage_snapshot)2349 void ValidateMemoryUsageSnapshot(
2350 const std::vector<string>& expected_names, const int port_num_expected,
2351 const std::unordered_set<std::pair<const NodeDef*, int>,
2352 DeviceState::NodePairHash>& mem_usage_snapshot) {
2353 std::set<std::pair<string, int>> nodes_at_peak_mem_usage;
2354 std::transform(
2355 mem_usage_snapshot.begin(), mem_usage_snapshot.end(),
2356 std::inserter(nodes_at_peak_mem_usage, nodes_at_peak_mem_usage.begin()),
2357 [](const std::pair<const NodeDef*, int>& node_port) {
2358 return std::make_pair(node_port.first->name(), node_port.second);
2359 });
2360 std::set<std::pair<string, int>> expected;
2361 std::transform(expected_names.begin(), expected_names.end(),
2362 std::inserter(expected, expected.begin()),
2363 [port_num_expected](const string& name) {
2364 return std::make_pair(name, port_num_expected);
2365 });
2366 ExpectSetEq(expected, nodes_at_peak_mem_usage);
2367 }
2368
2369 // Helper method for checking nodes dependency.
ValidateDependencyChain(const std::unordered_map<string,int64> & start_times,const std::vector<string> & nodes_in_dependency_order)2370 void ValidateDependencyChain(
2371 const std::unordered_map<string, int64>& start_times,
2372 const std::vector<string>& nodes_in_dependency_order) {
2373 int64 prev_node_time = -1;
2374 for (const auto& node : nodes_in_dependency_order) {
2375 int64 curr_node_time = start_times.at(node);
2376 EXPECT_GE(curr_node_time, prev_node_time);
2377 prev_node_time = curr_node_time;
2378 }
2379 }
2380
  // cluster_ and scheduler_ are initialized in the c'tor.
  std::unique_ptr<VirtualCluster> cluster_;
  std::unique_ptr<TestVirtualScheduler> scheduler_;
  // Node managers available to tests; some tests replace scheduler_ with one
  // built around composite_node_manager_ to exercise a different policy.
  FirstReadyManager first_ready_manager_;
  CompositeNodeManager composite_node_manager_;

  // grappler_item_ will be initialized differently for each test case.
  std::unique_ptr<GrapplerItem> grappler_item_;
  // Node name -> its preceding nodes map for testing scheduling order.
  std::unordered_map<string, std::vector<string>> dependency_;

  // Shared params for Conv2D related graphs:
  const int batch_size_ = 4;
  const int width_ = 10;
  const int height_ = 10;
  const int depth_in_ = 8;
  const int kernel_ = 3;
  const int depth_out_ = 16;
};
2400
2401 // Create small graph, run predict costs on it, make sure the costs from the
2402 // summary match the hand-calculated costs.
// Create small graph, run predict costs on it, make sure the costs from the
// summary match the hand-calculated costs.
TEST_F(VirtualSchedulerTest, SummaryCostTest) {
  // Build and schedule the matmul-chain graph end to end.
  CreateGrapplerItemWithMatmulChain();
  InitScheduler();
  auto ops_executed = RunScheduler("");
  const Costs summary_costs = scheduler_->Summary();

  // Hand-calculated totals:
  //   RandomUniform: 5 * 1s
  //   MatMul:        4 * 2s = 8s
  //   Misc:          5 * 1us
  //   Total: 13000005 us
  EXPECT_EQ(13000005, summary_costs.execution_time.asMicroSeconds().count());
  EXPECT_EQ(grappler_item_->graph.node_size(), summary_costs.num_ops_total);
  EXPECT_FALSE(summary_costs.inaccurate);
  EXPECT_EQ(0, summary_costs.num_ops_with_unknown_shapes);
}
2419
2420 // Like the above SummaryCostTest, but makes sure the stepstats timeline is
2421 // correct.
// Like the above SummaryCostTest, but makes sure the stepstats timeline is
// correct.
TEST_F(VirtualSchedulerTest, SummaryCostStepStatsTest) {
  // Run matmul test.
  CreateGrapplerItemWithMatmulChain();
  InitScheduler();
  auto ops_executed = RunScheduler("");
  RunMetadata metadata;
  Costs c = scheduler_->Summary(&metadata);
  StepStats stepstats = metadata.step_stats();
  // Same hand-calculated total as SummaryCostTest: 5 RandomUniforms at 1s,
  // 4 MatMuls at 2s, 5 misc ops at 1us.
  EXPECT_EQ(13000005, c.execution_time.asMicroSeconds().count());
  EXPECT_EQ(grappler_item_->graph.node_size(), c.num_ops_total);
  EXPECT_FALSE(c.inaccurate);
  EXPECT_EQ(0, c.num_ops_with_unknown_shapes);

  // Should only be 1 device!
  EXPECT_EQ(1, stepstats.dev_stats().size());

  // Create a map of op name -> start and end times (micros).
  std::map<string, std::pair<int64, int64>> start_end_times;
  for (const auto& device_step_stats : stepstats.dev_stats()) {
    for (const auto& stats : device_step_stats.node_stats()) {
      int64 start = stats.all_start_micros();
      int64 end = start + stats.all_end_rel_micros();
      start_end_times[stats.node_name()] = std::pair<int64, int64>(start, end);

      // Make sure that the output properties are correct for
      // MatMul and RandomUniform operations.
      // We only check for dtype, and shape (excluding alloc)
      // since alloc is not set by the virtual scheduler.
      if (stats.timeline_label() == "MatMul" ||
          stats.timeline_label() == "RandomUniform") {
        EXPECT_EQ(1, stats.output().size());
        for (const auto& output : stats.output()) {
          EXPECT_EQ(DT_FLOAT, output.tensor_description().dtype());
          // Every tensor in the chain is a 3200 x 3200 matrix.
          EXPECT_EQ(2, output.tensor_description().shape().dim().size());
          for (const auto& dim : output.tensor_description().shape().dim()) {
            EXPECT_EQ(3200, dim.size());
          }
        }
      }
    }
  }

  // The base start_time is the time to compute RandomUniforms
  int64 cur_time = static_cast<int64>(5000005);
  // The increment is the execution time of one matmul. See
  // CreateGrapplerItemWithMatmulChain for details.
  int64 increment = static_cast<int64>(2000000);
  // The four chained matmuls must be scheduled back to back, each occupying
  // exactly one `increment`-long slot.
  auto op_names = {"ab", "abc", "abcd", "abcde"};
  for (const auto& op_name : op_names) {
    int64 actual_start = start_end_times[op_name].first;
    int64 actual_end = start_end_times[op_name].second;
    int64 expected_start = cur_time;
    int64 expected_end = cur_time + increment;
    EXPECT_EQ(expected_start, actual_start);
    EXPECT_EQ(expected_end, actual_end);
    cur_time += increment;
  }
}
2480
TEST_F(VirtualSchedulerTest, InitAndBasicScheduling) {
  // Init.
  CreateGrapplerItemWithConv2Ds();
  InitScheduler();

  // Run every node the scheduler considers reachable.
  auto ops_executed = RunScheduler("");

  // [const and rand] * (x, y, f), and c0 and c1. c2 and z shouldn't be
  // executed.
  EXPECT_EQ(8, ops_executed.size());

  // x, y, f, c0, and c1 should be in the ops executed.
  for (const string& op_name : {"x", "y", "f", "c0", "c1"}) {
    EXPECT_GT(ops_executed.count(op_name), 0);
  }

  // z and c2 shouldn't be part of it.
  EXPECT_EQ(0, ops_executed.count("z"));
  EXPECT_EQ(0, ops_executed.count("c2"));

  // Check input / output properties.
  EXPECT_EQ(1, ops_executed["x"].op_info.outputs_size());
  EXPECT_EQ(1, ops_executed["y"].op_info.outputs_size());
  EXPECT_EQ(1, ops_executed["f"].op_info.outputs_size());
  EXPECT_EQ(2, ops_executed["c0"].op_info.inputs_size());
  EXPECT_EQ(2, ops_executed["c1"].op_info.inputs_size());
}
2511
TEST_F(VirtualSchedulerTest, MemoryUsage) {
  // Init.
  CreateGrapplerItemWithAddN();
  InitScheduler();

  // Run the scheduler.
  RunScheduler("");

  const auto* device_states = scheduler_->GetDeviceStates();
  const auto& cpu_state = device_states->at(kCPU0);

  // out node adds 4 tensors, each with 10x10x10x10, so the peak memory usage
  // is 4 x the input tensor size while executing the out node.
  const int64 one_input_node_size = 4 * 10 * 10 * 10 * 10;
  const std::vector<string> expected_names = {"x", "y", "z", "w", "add"};
  EXPECT_EQ(one_input_node_size * expected_names.size(),
            cpu_state.max_memory_usage);
  ValidateMemoryUsageSnapshot(expected_names, /*port_num_expected=*/0,
                              cpu_state.mem_usage_snapshot_at_peak);
  // Per-device persistent and peak memory as reported by the scheduler.
  ExpectUnorderedMapEq(
      {std::make_pair("/job:localhost/replica:0/task:0/cpu:0", 64)},
      scheduler_->GetPersistentMemoryUsage());
  ExpectUnorderedMapEq(
      {std::make_pair("/job:localhost/replica:0/task:0/cpu:0", 200000)},
      scheduler_->GetPeakMemoryUsage());
}
2538
TEST_F(VirtualSchedulerTest, MemoryUsageForStreamingOps) {
  // Init.
  CreateGrapplerItemWithAddN();
  auto& graph = grappler_item_->graph;
  // Nodes add and out are placed on CPU1.
  // Nodes x, y are allocated in memory, while nodes z and w are streaming.
  for (auto& node : *graph.mutable_node()) {
    const string& name = node.name();
    if (name == "out" || name == "add") {
      node.set_device(kCPU1);
    }
    if (name == "z" || name == "w") {
      (*node.mutable_attr())[kStreaming].mutable_list()->add_b(true);
    }
  }

  InitScheduler();

  // Run the scheduler.
  auto ops_executed = RunScheduler("");

  const auto* device_states = scheduler_->GetDeviceStates();
  const auto& cpu_state_0 = device_states->at(kCPU0);
  const auto& cpu_state_1 = device_states->at(kCPU1);
  // All tensors are of the same size, 10 x 10 x 10 x 10.
  const int64 one_input_node_size = 4 * 10 * 10 * 10 * 10;
  const std::vector<string> cpu_0_expected_tensors = {"x", "y"};
  const std::vector<string> cpu_1_expected_tensors = {"x", "y", "add"};
  EXPECT_EQ(cpu_0_expected_tensors.size() * one_input_node_size,
            cpu_state_0.max_memory_usage);
  EXPECT_EQ(cpu_1_expected_tensors.size() * one_input_node_size,
            cpu_state_1.max_memory_usage);
  // After the graph is executed, at the end, memory usage for the device
  // should be zero.
  EXPECT_EQ(0, cpu_state_0.memory_usage);
  EXPECT_EQ(0, cpu_state_1.memory_usage);
}
2574
TEST_F(VirtualSchedulerTest, UnnecessaryFeedNodes) {
  CreateGrapplerItemWithUnnecessaryPlaceholderNodes();
  InitScheduler();

  // Test that scheduler can run graphs with extra unnecessary feed nodes;
  // only "x" should actually be executed.
  const auto ops_executed = RunScheduler("");
  ASSERT_EQ(1, ops_executed.size());
  ASSERT_EQ(1, ops_executed.count("x"));
}
2584
TEST_F(VirtualSchedulerTest, ControlDependency) {
  // Init.
  CreateGrapplerItemWithControlDependency();
  InitScheduler();

  // Run the scheduler.
  RunScheduler("");

  const auto* device_states = scheduler_->GetDeviceStates();
  const auto& cpu_state = device_states->at(kCPU0);

  // The graph has a NoOp that takes control dependency from 7 NoOps. The peak
  // memory usage is when executing the final NoOp.
  const int64 one_input_node_size = 4;  // control dependency
  const std::vector<string> expected_names = {"x", "y", "z", "w",
                                              "u", "v", "t"};
  EXPECT_EQ(one_input_node_size * expected_names.size(),
            cpu_state.max_memory_usage);
  // Control-dependency edges are tracked as port -1.
  ValidateMemoryUsageSnapshot(expected_names, /*port_num_expected=*/-1,
                              cpu_state.mem_usage_snapshot_at_peak);
  ExpectUnorderedMapEq(
      {std::make_pair("/job:localhost/replica:0/task:0/cpu:0", 0)},
      scheduler_->GetPersistentMemoryUsage());
  ExpectUnorderedMapEq(
      {std::make_pair("/job:localhost/replica:0/task:0/cpu:0", 28)},
      scheduler_->GetPeakMemoryUsage());
}
2612
TEST_F(VirtualSchedulerTest, ComplexDependency) {
  // Init.
  CreateGrapplerItemWithBatchNorm();
  InitScheduler();

  // Run the scheduler up to and including node "bn".
  RunScheduler("bn");

  const auto& device_states = scheduler_->GetDeviceStates();
  const auto& cpu_state = device_states->at(kCPU0);

  // The graph is
  //  bn = FusedBatchNorm(x, scale, offset, mean, var)
  //  z1 = bn.y + x
  //  z2 = bn.var + bn.var
  //  z3 = bn.var + bn.var
  //  z4 = control dependency from bn.
  //  Note that bn.mean doesn't have any consumer.
  const int x_size = batch_size_ * width_ * height_ * depth_in_;
  int64 expected_size =
      4 * (2 * x_size /* x and bn.y */ + depth_in_ /* bn.var */ +
           1 /* control dependency */);
  EXPECT_EQ(expected_size, cpu_state.memory_usage);

  // Nodes currently in memory: bn's port -1, 0, and 2, and x's port 0.
  // Project (NodeDef*, port) pairs onto (name, port) pairs for comparison.
  std::set<std::pair<string, int>> nodes_in_memory;
  std::transform(
      cpu_state.nodes_in_memory.begin(), cpu_state.nodes_in_memory.end(),
      std::inserter(nodes_in_memory, nodes_in_memory.begin()),
      [](const std::pair<const NodeDef*, int>& node_port) {
        return std::make_pair(node_port.first->name(), node_port.second);
      });
  std::set<std::pair<string, int>> expected = {
      std::make_pair("bn", -1),
      std::make_pair("bn", 0),
      std::make_pair("bn", 2),
      std::make_pair("x", 0),
  };
  ExpectSetEq(expected, nodes_in_memory);

  // Look up the per-node states for "bn" and "x" to validate their consumers.
  const auto* node_states = scheduler_->GetNodeStates();
  const NodeState* bn_node = nullptr;
  const NodeState* x_node = nullptr;
  for (const auto& nodedef_node_state : *node_states) {
    const NodeDef* node = nodedef_node_state.first;
    const NodeState& node_state = nodedef_node_state.second;
    if (node->name() == "bn") {
      bn_node = &node_state;
    }
    if (node->name() == "x") {
      x_node = &node_state;
    }
  }
  CHECK_NOTNULL(bn_node);
  CHECK_NOTNULL(x_node);

  // Per-port consumer lists; port -1 holds control-dependency consumers.
  ValidateNodeDefs({"bn", "z1"}, x_node->outputs.at(0));
  ValidateNodeDefs({"z4"}, bn_node->outputs.at(-1));
  ValidateNodeDefs({"z1"}, bn_node->outputs.at(0));
  // z2 and z3 are bn.var + bn.var, so they appear twice in bn's output port 2.
  ValidateNodeDefs({"z2", "z3", "z2", "z3"}, bn_node->outputs.at(2));
}
2675
TEST_F(VirtualSchedulerTest, Variable) {
  // Init.
  CreateGrapplerItemWithConv2DAndVariable();
  InitScheduler();

  // Run the scheduler.
  RunScheduler("");

  const auto* device_states = scheduler_->GetDeviceStates();
  const auto& cpu_state = device_states->at(kCPU0);

  // There is one Conv2D that takes x and f, but f is variable, so it should be
  // in persistent nodes.
  ValidateMemoryUsageSnapshot({"f", "Const/Const"}, /*port_num_expected=*/0,
                              cpu_state.persistent_nodes);
  // Only x in peak memory usage snapshot.
  ValidateMemoryUsageSnapshot({"x"}, /*port_num_expected=*/0,
                              cpu_state.mem_usage_snapshot_at_peak);
  // Persistent (variable) bytes and peak activation bytes per device.
  ExpectUnorderedMapEq(
      {std::make_pair("/job:localhost/replica:0/task:0/cpu:0", 4624)},
      scheduler_->GetPersistentMemoryUsage());
  ExpectUnorderedMapEq(
      {std::make_pair("/job:localhost/replica:0/task:0/cpu:0", 12800)},
      scheduler_->GetPeakMemoryUsage());
}
2701
TEST_F(VirtualSchedulerTest, WhileLoop) {
  // Init.
  CreateGrapplerItemWithLoop();
  InitScheduler();

  // Run the scheduler.
  RunScheduler("");

  // Check the timeline
  RunMetadata metadata;
  scheduler_->Summary(&metadata);

  // Nodes in topological order:
  // * const, ones
  // * while/Enter, while/Enter_1
  // * while/Merge, while/Merge_1
  // * while/Less/y
  // * while/Less
  // * while/LoopCond
  // * while/Switch, while/Switch_1
  // * while/Identity, while/Identity_1, while/Exit, while/Exit_1
  // * while/add/y, while/concat/axis
  // * while/add, while/concat
  // * while/NextIteration, while/NextIteration_1

  int num_next_iteration = 0;
  int num_next_iteration_1 = 0;
  int num_exit = 0;
  int num_exit_1 = 0;
  // Initialize the start times so that reading them below is well-defined
  // even if a node is unexpectedly absent from the stats; previously these
  // were uninitialized, and the EXPECT_NE checks below would read
  // indeterminate values (undefined behavior) in that case.
  int64 next_iter_start_micro = -1;
  int64 next_iter_1_start_micro = -1;
  int64 exit_start_micro = -1;
  int64 exit_1_start_micro = -1;

  // Record each node's start time, and count/locate the NextIteration and
  // Exit nodes of both loop variables.
  std::unordered_map<string, int64> start_times;
  for (const auto& device_step_stats : metadata.step_stats().dev_stats()) {
    for (const auto& stats : device_step_stats.node_stats()) {
      start_times[stats.node_name()] = stats.all_start_micros();
      if (stats.node_name() == "while/NextIteration") {
        ++num_next_iteration;
        next_iter_start_micro = stats.all_start_micros();
      } else if (stats.node_name() == "while/NextIteration_1") {
        ++num_next_iteration_1;
        next_iter_1_start_micro = stats.all_start_micros();
      } else if (stats.node_name() == "while/Exit") {
        ++num_exit;
        exit_start_micro = stats.all_start_micros();
      } else if (stats.node_name() == "while/Exit_1") {
        ++num_exit_1;
        exit_1_start_micro = stats.all_start_micros();
      }
    }
  }

  // Make sure we went though the body of the loop once, and that the output of
  // the loop was scheduled as well.
  EXPECT_EQ(1, num_next_iteration);
  EXPECT_EQ(1, num_next_iteration_1);
  EXPECT_EQ(1, num_exit);
  EXPECT_EQ(1, num_exit_1);

  // Start times of while/NextIteration and while/NextIteration_1 should be
  // different, so should be those of while/Exit and while/Exit_1.
  EXPECT_NE(next_iter_start_micro, next_iter_1_start_micro);
  EXPECT_NE(exit_start_micro, exit_1_start_micro);

  // Check dependency among the nodes; no matter what scheduling mechanism we
  // use, the scheduled ops should follow these dependency chains.
  // Note that currently, VirtualScheduler executes while/Merge twice; hence,
  // we're not testing dependency chains related to while/Merge.
  // TODO(dyoon): after fixing while loop behavior correctly (run nodes in the
  // order of Enter, Merge, ...loop condition ..., ... loop body ...,
  // NextIteration, Merge, ... loop condition ..., Exit), re-enable dependency
  // chaining test w/ Merge nodes.
  ValidateDependencyChain(
      start_times,
      {"Const", "while/Enter",  // "while/Merge",
       "while/Less/y", "while/Less", "while/LoopCond", "while/Switch",
       "while/Identity", "while/add/y", "while/add", "while/NextIteration"});
  // ValidateDependencyChain(start_times, {"while/Merge", "while/Less"});
  ValidateDependencyChain(start_times,
                          {"ones", "while/Enter_1",  // "while/Merge_1",
                           "while/Switch_1", "while/Identity_1", "while/concat",
                           "while/NextIteration_1"});
  ValidateDependencyChain(start_times, {"while/Switch", "while/Exit"});
  ValidateDependencyChain(
      start_times, {"while/Identity", "while/concat/axis", "while/concat"});
  ValidateDependencyChain(start_times, {"while/Identity", "while/add"});
  ValidateDependencyChain(start_times, {"while/Switch_1", "while/Exit_1"});
}
2792
TEST_F(VirtualSchedulerTest, AnnotatedWhileLoop) {
  {
    // First, the plain (unannotated) loop from CreateGrapplerItemWithLoop.
    CreateGrapplerItemWithLoop();
    InitScheduler();

    // Runs the scheduler.
    RunScheduler("");
    const Costs costs = scheduler_->Summary();

    EXPECT_EQ(23, costs.execution_time.asMicroSeconds().count());
    // Both while/Merge and while/Merge_1 are scheduled twice.
    EXPECT_EQ(grappler_item_->graph.node_size() + 2, costs.num_ops_total);
    EXPECT_FALSE(costs.inaccurate);
    EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
  }

  {
    // Then the same loop with execution-count annotations.
    CreateGrapplerItemWithLoopAnnotated();
    InitScheduler();

    // Runs the scheduler.
    RunScheduler("");
    const Costs costs = scheduler_->Summary();

    // The costs for Merge is accumulated twice for execution_count times, but
    // since Merge's cost is minimal, we keep this behavior here.
    EXPECT_EQ(178, costs.execution_time.asMicroSeconds().count());
    // Both while/Merge and while/Merge_1 are scheduled twice.
    EXPECT_EQ(grappler_item_->graph.node_size() + 2, costs.num_ops_total);
    EXPECT_FALSE(costs.inaccurate);
    EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
  }
}
2828
TEST_F(VirtualSchedulerTest, Condition) {
  // Without annotation.
  {
    // Inits.
    CreateGrapplerItemWithCondition();
    InitScheduler();

    // Runs the scheduler.
    RunScheduler("");
    RunMetadata metadata;
    Costs c = scheduler_->Summary(&metadata);

    // Nodes in topological order: a/Less, Switch, First/Second, Merge.
    // Count how many times each node appears in the step stats.
    int num_a = 0;
    int num_less = 0;
    int num_switch = 0;
    int num_first = 0;
    int num_second = 0;
    int num_merge = 0;

    for (const auto& device_step_stats : metadata.step_stats().dev_stats()) {
      for (const auto& stats : device_step_stats.node_stats()) {
        if (stats.node_name() == "a") {
          ++num_a;
        } else if (stats.node_name() == "Less") {
          ++num_less;
        } else if (stats.node_name() == "Switch") {
          ++num_switch;
        } else if (stats.node_name() == "First") {
          ++num_first;
        } else if (stats.node_name() == "Second") {
          ++num_second;
        } else if (stats.node_name() == "Merge") {
          ++num_merge;
        }
      }
    }

    // Without annotation, both branches (First and Second) are executed.
    EXPECT_EQ(1, num_a);
    EXPECT_EQ(1, num_less);
    EXPECT_EQ(1, num_switch);
    EXPECT_EQ(1, num_first);
    EXPECT_EQ(1, num_second);
    EXPECT_EQ(2, num_merge);

    EXPECT_EQ(7, c.execution_time.asMicroSeconds().count());
    // Merge is executed twice.
    EXPECT_EQ(grappler_item_->graph.node_size() + 1, c.num_ops_total);
    EXPECT_FALSE(c.inaccurate);
    EXPECT_EQ(0, c.num_ops_with_unknown_shapes);
  }

  // With annotation.
  {
    // Inits.
    CreateGrapplerItemWithCondition();

    // Annotates the Switch node.
    for (auto& node : *grappler_item_->graph.mutable_node()) {
      if (node.name() == "Switch") {
        AttrValue attr_output_info;
        // Adds one output slot 0 so that Second shouldn't be executed.
        (*attr_output_info.mutable_list()).add_i(0);
        AddNodeAttr(kOutputSlots, attr_output_info, &node);
      }
    }

    InitScheduler();

    // Runs the scheduler.
    RunScheduler("");
    RunMetadata metadata;
    Costs c = scheduler_->Summary(&metadata);

    // Nodes in topological order: a/Less, Switch, Merge
    int num_a = 0;
    int num_less = 0;
    int num_switch = 0;
    int num_first = 0;
    int num_second = 0;
    int num_merge = 0;

    for (const auto& device_step_stats : metadata.step_stats().dev_stats()) {
      for (const auto& stats : device_step_stats.node_stats()) {
        if (stats.node_name() == "a") {
          ++num_a;
        } else if (stats.node_name() == "Less") {
          ++num_less;
        } else if (stats.node_name() == "Switch") {
          ++num_switch;
        } else if (stats.node_name() == "First") {
          ++num_first;
        } else if (stats.node_name() == "Second") {
          ++num_second;
        } else if (stats.node_name() == "Merge") {
          ++num_merge;
        }
      }
    }

    // With the output-slot annotation, only the First branch runs and Merge
    // is executed just once.
    EXPECT_EQ(1, num_a);
    EXPECT_EQ(1, num_less);
    EXPECT_EQ(1, num_switch);
    EXPECT_EQ(1, num_first);
    EXPECT_EQ(0, num_second);
    EXPECT_EQ(1, num_merge);

    EXPECT_EQ(5, c.execution_time.asMicroSeconds().count());
    // Second is not executed.
    EXPECT_EQ(grappler_item_->graph.node_size() - 1, c.num_ops_total);
    EXPECT_FALSE(c.inaccurate);
    EXPECT_EQ(0, c.num_ops_with_unknown_shapes);
  }
}
2943
TEST_F(VirtualSchedulerTest, InterDeviceTransfer) {
  // Init.
  CreateGrapplerItemWithInterDeviceTransfers();
  InitScheduler();

  // Run the scheduler.
  auto ops_executed = RunScheduler("");

  // Helper lambda to extract port num from _Send and _Recv op name.
  auto get_port_num = [](const string& name) -> int {
    if (name.find("bn_0") != string::npos) {
      return 0;
    } else if (name.find("bn_1") != string::npos) {
      return 1;
    } else if (name.find("bn_2") != string::npos) {
      return 2;
    } else if (name.find("bn_minus1") != string::npos) {
      return -1;
    }
    return -999;
  };

  // Reorganize ops_executed for further testing.
  std::unordered_map<string, int> op_count;
  std::unordered_map<int, string> recv_op_names;
  std::unordered_map<int, string> send_op_names;
  for (const auto& x : ops_executed) {
    const auto& name = x.first;
    const auto& node_info = x.second;
    const auto& op = node_info.op_info.op();
    if (op == kRecv) {
      recv_op_names[get_port_num(name)] = name;
    } else if (op == kSend) {
      send_op_names[get_port_num(name)] = name;
    }
    op_count[op]++;
  }

  // Same number of _Send and _Recv.
  EXPECT_EQ(op_count.at(kSend), op_count.at(kRecv));

  // Expect 3 Send and Recvs each: port 0, 1, and, 2.
  // Control dependency bypasses the channel.
  EXPECT_EQ(op_count.at(kRecv), 3);
  EXPECT_EQ(op_count.at(kSend), 3);

  // Helper lambda for extracting output Tensor size.
  // Note: capture ops_executed by reference; capturing by value would copy
  // the entire executed-ops map into the closure for no benefit (the lambda
  // is only used synchronously within this scope).
  auto get_output_size = [this, &ops_executed](const string& name) -> int64 {
    const auto& output_properties_ = ops_executed.at(name).op_info.outputs();
    // Range-construct the vector instead of copying element by element.
    std::vector<OpInfo::TensorProperties> output_properties(
        output_properties_.begin(), output_properties_.end());
    return CalculateOutputSize(output_properties, 0);
  };

  // Validate transfer size.
  // Batchnorm output y is 4D vector: batch x width x width x depth.
  int input_size = 4 * batch_size_ * width_ * height_ * depth_in_;
  EXPECT_EQ(get_output_size(recv_op_names[0]), input_size);
  EXPECT_EQ(get_output_size(send_op_names[0]), input_size);
  // Mean and vars are 1-D vector with size depth_in_.
  EXPECT_EQ(get_output_size(recv_op_names[1]), 4 * depth_in_);
  EXPECT_EQ(get_output_size(send_op_names[1]), 4 * depth_in_);
  EXPECT_EQ(get_output_size(recv_op_names[2]), 4 * depth_in_);
  EXPECT_EQ(get_output_size(send_op_names[2]), 4 * depth_in_);
}
3011
TEST_F(VirtualSchedulerTest, GraphWithSendRecv) {
  // Init.
  CreateGrapplerItemWithSendRecv();
  InitScheduler();

  // Run the scheduler; all three nodes in the graph must be executed.
  const auto ops_executed = RunScheduler("");

  for (const string& op_name : {"Const", "Send", "Recv"}) {
    EXPECT_GT(ops_executed.count(op_name), 0);
  }
}
3024
TEST_F(VirtualSchedulerTest, GraphWithSendRecvDifferentDevice) {
  // Init.
  CreateGrapplerItemWithSendRecv();
  // Change Recv node's device so that Send and Recv are placed on different
  // devices.
  auto& graph = grappler_item_->graph;
  const string recv_device = kCPU1;
  for (auto& node : *graph.mutable_node()) {
    if (node.name() == "Recv") {
      node.set_device(recv_device);
      (*node.mutable_attr())["recv_device"].set_s(recv_device);
    } else if (node.name() == "Send") {
      (*node.mutable_attr())["recv_device"].set_s(recv_device);
    }
  }
  InitScheduler();

  // Run the scheduler.
  const auto ops_executed = RunScheduler("");

  // Expect Const, Send, Recv, and VirtualScheduler created Send and Recv ops.
  EXPECT_GT(ops_executed.count("Const"), 0);
  EXPECT_GT(ops_executed.count("Send"), 0);
  EXPECT_GT(ops_executed.count("Send_Send_0_from_/job_localhost/replica_0/"
                               "task_0/cpu_0_to_/job_localhost"
                               "/replica_0/task_0/cpu_1"),
            0);
  EXPECT_GT(ops_executed.count(
                "Recv_Send_0_on_/job_localhost/replica_0/task_0/cpu_1"),
            0);
  EXPECT_GT(ops_executed.count("Recv"), 0);
}
3060
// NOTE: "Wiht" is a long-standing typo in this test's name; kept as-is since
// renaming would break existing --gtest_filter invocations.
TEST_F(VirtualSchedulerTest, GraphWihtOnlyRecv) {
  // Init.
  CreateGrapplerItemWithRecvWithoutSend();
  InitScheduler();

  // Run the scheduler.
  const auto ops_executed = RunScheduler("");

  // Recv without Send will be treated as initially ready node.
  EXPECT_GT(ops_executed.count("Recv"), 0);
}
3072
TEST_F(VirtualSchedulerTest, AddMergeSwitch) {
  // Override scheduler_ with a scheduler that uses CompositeNodeManager.
  scheduler_ = absl::make_unique<TestVirtualScheduler>(
      /*use_static_shapes=*/true,
      /*use_aggressive_shape_inference=*/true, &composite_node_manager_,
      cluster_.get());
  CreateGrapplerItemWithSwitchMergeInput();
  InitScheduler();

  // pred --+                      z --+
  //        |                          |
  //        V                          V
  // x -> Switch --------> Merge ---> Add --> y
  //        |                ^
  //        |                |
  //        +-----> Add -----+
  //                 ^
  //                 |
  // b --------------+

  // Run the scheduler. The current VirtualScheduler, w/o annotation, triggers
  // both outputs of Switch; then Merge (as long as one input is ready, it's z
  // is ready, if we just use num_inputs_ready counter, the final Add becomes
  // ready. possible to skip scheduling z. (Need to use CompositeNodeManager
  // to test this case).
  const auto ops_executed = RunScheduler("");

  // z must not be skipped.
  EXPECT_GT(ops_executed.count("z"), 0);
}
3102
TEST_F(VirtualSchedulerTest, AddFromOneTensor) {
  CreateGrapplerItemWithAddFromOneTensor();
  InitScheduler();

  // x -+----> Add --> y
  //    |       ^
  //    |       |
  //    +-------+

  // Run the scheduler; a node feeding both inputs of Add must still be
  // scheduled, and its consumer executed.
  const auto ops_executed = RunScheduler("");
  EXPECT_GT(ops_executed.count("x"), 0);
  EXPECT_GT(ops_executed.count("y"), 0);
}
3117
3118 } // namespace
3119 } // end namespace grappler
3120 } // end namespace tensorflow
3121