1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/graph/graph_partition.h"
17 
18 #include <unordered_map>
19 #include <utility>
20 
21 #include "tensorflow/cc/ops/array_ops.h"
22 #include "tensorflow/cc/ops/const_op.h"
23 #include "tensorflow/cc/ops/control_flow_ops.h"
24 #include "tensorflow/cc/ops/control_flow_ops_internal.h"
25 #include "tensorflow/cc/ops/math_ops.h"
26 #include "tensorflow/cc/ops/random_ops.h"
27 #include "tensorflow/cc/ops/sendrecv_ops.h"
28 #include "tensorflow/cc/ops/while_loop.h"
29 #include "tensorflow/core/common_runtime/graph_constructor.h"
30 #include "tensorflow/core/framework/common_shape_fns.h"
31 #include "tensorflow/core/framework/function_testlib.h"
32 #include "tensorflow/core/framework/op.h"
33 #include "tensorflow/core/framework/versions.pb.h"
34 #include "tensorflow/core/graph/graph.h"
35 #include "tensorflow/core/graph/graph_def_builder.h"
36 #include "tensorflow/core/kernels/ops_util.h"
37 #include "tensorflow/core/lib/core/status_test_util.h"
38 #include "tensorflow/core/lib/strings/str_util.h"
39 #include "tensorflow/core/platform/logging.h"
40 #include "tensorflow/core/platform/protobuf.h"
41 #include "tensorflow/core/platform/test.h"
42 #include "tensorflow/core/public/version.h"
43 #include "tensorflow/core/util/equal_graph_def.h"
44 
45 namespace tensorflow {
46 
// Defined in graph_partition.cc; redeclared here so the tests at the bottom of
// this file can call it directly (presumably it is not exported through
// graph_partition.h -- confirm).
// As exercised by those tests: `nodes` receives (node, start time) pairs in
// sorted order, and `node_to_start_time_out` receives the start time recorded
// for each node.
extern Status TopologicalSortNodesWithTimePriority(
    const GraphDef* gdef, std::vector<std::pair<const NodeDef*, int64>>* nodes,
    std::unordered_map<const NodeDef*, int64>* node_to_start_time_out);
51 
52 namespace {
53 
// Op wrappers that the tests below refer to without the ops:: qualifier.
using ops::_Recv;
using ops::_Send;
using ops::Const;
using ops::Identity;
using ops::LoopCond;
using ops::NextIteration;

// Device name returned by DeviceName() for nodes whose name starts with 'G'.
const char gpu_device[] = "/job:a/replica:0/task:0/device:GPU:0";
62 
SplitByDevice(const Node * node)63 string SplitByDevice(const Node* node) { return node->assigned_device_name(); }
64 
DeviceName(const Node * node)65 string DeviceName(const Node* node) {
66   char first = node->name()[0];
67   if (first == 'G') {
68     return gpu_device;
69   } else {
70     const string cpu_prefix = "/job:a/replica:0/task:0/cpu:";
71     int index = first - 'A';
72     return strings::StrCat(cpu_prefix, index);
73   }
74 }
75 
Partition(const GraphDef & graph_def,std::unordered_map<string,GraphDef> * partitions)76 void Partition(const GraphDef& graph_def,
77                std::unordered_map<string, GraphDef>* partitions) {
78   Graph g(OpRegistry::Global());
79   GraphConstructorOptions opts;
80   TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, &g));
81 
82   // Assigns devices to each node. Uses 1st letter of the node name as the
83   // device index if no device is specified.
84   for (Node* node : g.nodes()) {
85     string device_name = !node->requested_device().empty()
86                              ? node->requested_device()
87                              : DeviceName(node);
88     node->set_assigned_device_name(device_name);
89   }
90 
91   PartitionOptions popts;
92   popts.node_to_loc = SplitByDevice;
93   popts.new_name = [&g](const string& prefix) { return g.NewName(prefix); };
94   popts.get_incarnation = [](const string& name) {
95     return (name[0] - 'A') + 100;
96   };
97   Status s = Partition(popts, &g, partitions);
98   CHECK(s.ok()) << s;
99 
100   // Check versions.
101   EXPECT_EQ(graph_def.versions().producer(), TF_GRAPH_DEF_VERSION);
102   // Partitions must inherit the versions of the original graph.
103   for (auto& it : *partitions) {
104     EXPECT_EQ(graph_def.versions().producer(), it.second.versions().producer());
105     EXPECT_EQ(graph_def.versions().min_consumer(),
106               it.second.versions().min_consumer());
107   }
108 }
109 
CheckLoopConstruction(const GraphDef & graph_def)110 void CheckLoopConstruction(const GraphDef& graph_def) {
111   std::unordered_map<string, GraphDef> partitions;
112   Partition(graph_def, &partitions);
113   for (const auto& kv : partitions) {
114     const GraphDef& gdef = kv.second;
115     bool has_control_enter = false;
116     bool has_control_merge = false;
117     bool has_control_switch = false;
118     bool has_control_next = false;
119     for (const NodeDef& ndef : gdef.node()) {
120       // _recvs must have a control input
121       if (ndef.op() == "_Recv") {
122         bool has_control = false;
123         for (const string& input_name : ndef.input()) {
124           if (absl::StartsWith(input_name, "^")) {
125             has_control = true;
126             break;
127           }
128         }
129         EXPECT_TRUE(has_control);
130       }
131       // Must have a control loop
132       if (absl::StartsWith(ndef.name(), "_cloop")) {
133         if (ndef.op() == "Enter") {
134           has_control_enter = true;
135         }
136         if (ndef.op() == "Merge") {
137           has_control_merge = true;
138         }
139         if (ndef.op() == "Switch") {
140           has_control_switch = true;
141         }
142         if (ndef.op() == "NextIteration") {
143           has_control_next = true;
144         }
145       }
146     }
147     EXPECT_TRUE(has_control_enter);
148     EXPECT_TRUE(has_control_merge);
149     EXPECT_TRUE(has_control_switch);
150     EXPECT_TRUE(has_control_next);
151   }
152 }
153 
// Test-only ops used to build the graphs in this file. All three use the
// UnknownShape shape function.

// Source op: no inputs, one float output.
REGISTER_OP("FloatInput")
    .Output("o: float")
    .SetShapeFn(shape_inference::UnknownShape);
// Source op: no inputs, one bool output.
REGISTER_OP("BoolInput")
    .Output("o: bool")
    .SetShapeFn(shape_inference::UnknownShape);
// Binary op: two float inputs, one float output.
REGISTER_OP("Combine")
    .Input("a: float")
    .Input("b: float")
    .Output("o: float")
    .SetShapeFn(shape_inference::UnknownShape);
165 
ConstructOp(const Scope & scope,const string & op_type,const gtl::ArraySlice<Input> & inputs)166 Output ConstructOp(const Scope& scope, const string& op_type,
167                    const gtl::ArraySlice<Input>& inputs) {
168   if (!scope.ok()) return Output();
169   const string unique_name = scope.GetUniqueNameForOp(op_type);
170   auto builder =
171       NodeBuilder(unique_name, op_type, scope.graph()->op_registry());
172   for (auto const& input : inputs) {
173     builder.Input(ops::NodeOut(input.node(), input.index()));
174   }
175   scope.UpdateBuilder(&builder);
176   Node* ret;
177   scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
178   if (!scope.ok()) return Output();
179   scope.UpdateStatus(scope.DoShapeInference(ret));
180   if (!scope.ok()) return Output();
181   return Output(ret);
182 }
183 
FloatInput(const Scope & scope)184 Output FloatInput(const Scope& scope) {
185   return ConstructOp(scope, "FloatInput", {});
186 }
187 
BoolInput(const Scope & scope)188 Output BoolInput(const Scope& scope) {
189   return ConstructOp(scope, "BoolInput", {});
190 }
191 
Combine(const Scope & scope,Input a,Input b)192 Output Combine(const Scope& scope, Input a, Input b) {
193   return ConstructOp(scope, "Combine", {std::move(a), std::move(b)});
194 }
195 
196 class GraphPartitionTest : public ::testing::Test {
197  protected:
GraphPartitionTest()198   GraphPartitionTest()
199       : in_(Scope::NewRootScope().ExitOnError()),
200         scope_a_(Scope::NewRootScope().ExitOnError().WithDevice(
201             "/job:a/replica:0/task:0/cpu:0")),
202         scope_b_(Scope::NewRootScope().ExitOnError().WithDevice(
203             "/job:a/replica:0/task:0/cpu:1")) {}
204 
ToGraphDef()205   const GraphDef& ToGraphDef() {
206     TF_EXPECT_OK(in_.ToGraphDef(&in_graph_def_));
207     return in_graph_def_;
208   }
209 
ExpectMatchA()210   void ExpectMatchA() {
211     GraphDef graph_def;
212     TF_EXPECT_OK(scope_a_.ToGraphDef(&graph_def));
213     string a = "/job:a/replica:0/task:0/cpu:0";
214     TF_EXPECT_GRAPH_EQ(graph_def, partitions_[a]);
215   }
216 
ExpectMatchB()217   void ExpectMatchB() {
218     GraphDef graph_def;
219     TF_EXPECT_OK(scope_b_.ToGraphDef(&graph_def));
220     string b = "/job:a/replica:0/task:0/cpu:1";
221     TF_EXPECT_GRAPH_EQ(graph_def, partitions_[b]);
222   }
223 
ExpectFunctions(const FunctionDefLibrary & library,const std::set<string> & expected_names)224   void ExpectFunctions(const FunctionDefLibrary& library,
225                        const std::set<string>& expected_names) {
226     std::set<string> actual_names;
227     for (const FunctionDef& fdef : library.function()) {
228       actual_names.insert(fdef.signature().name());
229     }
230     EXPECT_EQ(actual_names, expected_names);
231   }
232 
233   Scope in_;
234   GraphDef in_graph_def_;
235   Scope scope_a_;
236   Scope scope_b_;
237   std::unordered_map<string, GraphDef> partitions_;
238 };
239 
// A graph whose nodes all start with 'A' yields a single partition that is
// identical to the input graph.
TEST_F(GraphPartitionTest, SingleDevice) {
  {
    const Output a1 = FloatInput(in_.WithOpName("A1"));
    Combine(in_.WithOpName("A2"), a1, a1);
  }

  Partition(ToGraphDef(), &partitions_);
  EXPECT_EQ(1, partitions_.size());

  // Expected cpu:0 partition: the same two nodes.
  const Output expected_a1 = FloatInput(scope_a_.WithOpName("A1"));
  Combine(scope_a_.WithOpName("A2"), expected_a1, expected_a1);
  ExpectMatchA();
}
251 
// A data edge from A1 to B2 crosses devices: the test expects a _Send in the
// cpu:0 partition and a matching _Recv feeding B2 in the cpu:1 partition.
TEST_F(GraphPartitionTest, CrossDeviceData) {
  {
    const Output a1 = FloatInput(in_.WithOpName("A1"));
    const Output b1 = FloatInput(in_.WithOpName("B1"));
    Combine(in_.WithOpName("B2"), a1, b1);
  }

  Partition(ToGraphDef(), &partitions_);
  EXPECT_EQ(2, partitions_.size());

  const string a = "/job:a/replica:0/task:0/cpu:0";
  const string b = "/job:a/replica:0/task:0/cpu:1";

  // Expected cpu:0 partition: A1 plus the send of its output.
  // 82 equals ('/' - 'A') + 100, i.e. the test helper's get_incarnation
  // applied to a name starting with '/' (presumably the device name).
  const Output sent = FloatInput(scope_a_.WithOpName("A1"));
  _Send(scope_a_.WithOpName("A1/_0"), sent, "edge_1_A1", a, 82, b);
  ExpectMatchA();

  // Expected cpu:1 partition: B1, the receive, and B2 consuming both.
  const Output local = FloatInput(scope_b_.WithOpName("B1"));
  auto received =
      _Recv(scope_b_.WithOpName("A1/_1"), DT_FLOAT, "edge_1_A1", a, 82, b);
  Combine(scope_b_.WithOpName("B2"), received, local);
  ExpectMatchB();
}
272 
// A control edge from A1 to B2 crosses devices: the test expects a Const
// (control-dependent on A1) to be sent from cpu:0, and on cpu:1 a _Recv
// followed by an Identity that B2 takes as its control dependency.
TEST_F(GraphPartitionTest, CrossDeviceControl) {
  {
    const Output a1 = FloatInput(in_.WithOpName("A1"));
    const Output b1 = FloatInput(in_.WithOpName("B1"));
    Combine(in_.WithOpName("B2").WithControlDependencies(a1), b1, b1);
  }

  Partition(ToGraphDef(), &partitions_);
  EXPECT_EQ(2, partitions_.size());

  const string a = "/job:a/replica:0/task:0/cpu:0";
  const string b = "/job:a/replica:0/task:0/cpu:1";

  // Expected cpu:0 partition.
  const Output source = FloatInput(scope_a_.WithOpName("A1"));
  auto dummy =
      Const(scope_a_.WithOpName("A1/_0").WithControlDependencies(source), {});
  _Send(scope_a_.WithOpName("A1/_1"), dummy, "edge_3_A1", a, 82, b);
  ExpectMatchA();

  // Expected cpu:1 partition.
  auto received =
      _Recv(scope_b_.WithOpName("A1/_2"), DT_FLOAT, "edge_3_A1", a, 82, b);
  auto control = Identity(scope_b_.WithOpName("A1/_3"), received);
  const Output local = FloatInput(scope_b_.WithOpName("B1"));
  Combine(scope_b_.WithOpName("B2").WithControlDependencies(control), local,
          local);
  ExpectMatchB();
}
295 
// A1's output is consumed three times on cpu:1 (once by B2, twice by B3); the
// test expects a single _Send/_Recv pair shared by all consumers.
TEST_F(GraphPartitionTest, CrossDeviceData_MultiUse) {
  {
    const Output a1 = FloatInput(in_.WithOpName("A1"));
    const Output b1 = FloatInput(in_.WithOpName("B1"));
    Combine(in_.WithOpName("B2"), a1, b1);
    Combine(in_.WithOpName("B3"), a1, a1);
  }

  Partition(ToGraphDef(), &partitions_);
  EXPECT_EQ(2, partitions_.size());

  const string a = "/job:a/replica:0/task:0/cpu:0";
  const string b = "/job:a/replica:0/task:0/cpu:1";

  // Expected cpu:0 partition.
  const Output sent = FloatInput(scope_a_.WithOpName("A1"));
  _Send(scope_a_.WithOpName("A1/_0"), sent, "edge_1_A1", a, 82, b);
  ExpectMatchA();

  // Expected cpu:1 partition: one receive feeding both Combine nodes.
  auto received =
      _Recv(scope_b_.WithOpName("A1/_1"), DT_FLOAT, "edge_1_A1", a, 82, b);
  const Output local = FloatInput(scope_b_.WithOpName("B1"));
  Combine(scope_b_.WithOpName("B2"), received, local);
  Combine(scope_b_.WithOpName("B3"), received, received);
  ExpectMatchB();
}
318 
// Two nodes on cpu:1 (B2 and B3) have a control dependency on A1; the test
// expects a single transferred control signal shared by both.
TEST_F(GraphPartitionTest, CrossDeviceControl_MultiUse) {
  {
    const Output a1 = FloatInput(in_.WithOpName("A1"));
    const Output b1 = FloatInput(in_.WithOpName("B1"));
    Combine(in_.WithOpName("B2").WithControlDependencies(a1), b1, b1);
    FloatInput(in_.WithOpName("B3").WithControlDependencies(a1));
  }

  Partition(ToGraphDef(), &partitions_);
  EXPECT_EQ(2, partitions_.size());

  const string a = "/job:a/replica:0/task:0/cpu:0";
  const string b = "/job:a/replica:0/task:0/cpu:1";

  // Expected cpu:0 partition.
  const Output source = FloatInput(scope_a_.WithOpName("A1"));
  auto dummy =
      Const(scope_a_.WithOpName("A1/_0").WithControlDependencies(source), {});
  _Send(scope_a_.WithOpName("A1/_1"), dummy, "edge_3_A1", a, 82, b);
  ExpectMatchA();

  // Expected cpu:1 partition: both B2 and B3 depend on the same Identity.
  auto received =
      _Recv(scope_b_.WithOpName("A1/_2"), DT_FLOAT, "edge_3_A1", a, 82, b);
  auto control = Identity(scope_b_.WithOpName("A1/_3"), received);
  const Output local = FloatInput(scope_b_.WithOpName("B1"));
  Combine(scope_b_.WithOpName("B2").WithControlDependencies(control), local,
          local);
  FloatInput(scope_b_.WithOpName("B3").WithControlDependencies(control));
  ExpectMatchB();
}
343 
// A1 reaches cpu:1 both through a data edge (to B2) and through a control
// edge (to B3); the test expects a separate _Send/_Recv pair for each.
TEST_F(GraphPartitionTest, CrossDevice_DataControl) {
  {
    const Output a1 = FloatInput(in_.WithOpName("A1"));
    const Output b1 = FloatInput(in_.WithOpName("B1"));
    Combine(in_.WithOpName("B2"), a1, b1);
    FloatInput(in_.WithOpName("B3").WithControlDependencies(a1));
  }

  Partition(ToGraphDef(), &partitions_);
  EXPECT_EQ(2, partitions_.size());

  const string a = "/job:a/replica:0/task:0/cpu:0";
  const string b = "/job:a/replica:0/task:0/cpu:1";

  // Expected cpu:0 partition: one send for the data edge, and a dummy Const
  // plus a second send for the control edge.
  const Output source = FloatInput(scope_a_.WithOpName("A1"));
  _Send(scope_a_.WithOpName("A1/_0"), source, "edge_1_A1", a, 82, b);
  auto dummy =
      Const(scope_a_.WithOpName("A1/_2").WithControlDependencies(source), {});
  // NOTE (from the original author): the dedicated transfer for the control
  // edge is not strictly needed; the data transfer of A1 could serve as the
  // control signal as a minor optimization.
  _Send(scope_a_.WithOpName("A1/_3"), dummy, "edge_3_A1", a, 82, b);
  ExpectMatchA();

  // Expected cpu:1 partition.
  auto control_recv =
      _Recv(scope_b_.WithOpName("A1/_4"), DT_FLOAT, "edge_3_A1", a, 82, b);
  auto control = Identity(scope_b_.WithOpName("A1/_5"), control_recv);
  auto data_recv =
      _Recv(scope_b_.WithOpName("A1/_1"), DT_FLOAT, "edge_1_A1", a, 82, b);
  const Output local = FloatInput(scope_b_.WithOpName("B1"));
  Combine(scope_b_.WithOpName("B2"), data_recv, local);
  FloatInput(scope_b_.WithOpName("B3").WithControlDependencies(control));
  ExpectMatchB();
}
373 
// A loop in frame "foo" whose nodes are named A* except for the Identity B1
// in the loop body; every partition must pass CheckLoopConstruction.
TEST_F(GraphPartitionTest, CrossDeviceLoopSimple) {
  auto loop_input = BoolInput(in_.WithOpName("A1"));
  auto enter = ::tensorflow::ops::internal::Enter(in_.WithOpName("A2"),
                                                  loop_input, "foo");
  // The second Merge input is the back edge from NextIteration "A5", which is
  // referenced by name because that node is created further down.
  ::tensorflow::ops::Merge merge(in_.WithOpName("A3"),
                                 {enter, Input("A5", 0, DT_BOOL)});
  auto merged = merge.output;
  LoopCond(in_.WithOpName("A4"), merged);
  auto body = Identity(in_.WithOpName("B1"), merged);
  NextIteration(in_.WithOpName("A5"), body);

  CheckLoopConstruction(ToGraphDef());
}
386 
// Same loop shape as CrossDeviceLoopSimple, but the Enter ("B2") and
// NextIteration ("B5") nodes have names starting with 'B' while the Merge
// ("A3") starts with 'A'. After partitioning, A3 must still take B2 and B5 as
// direct inputs (i.e. without an intervening send/recv).
TEST_F(GraphPartitionTest, CrossDeviceLoopSimple1) {
  auto a1 = BoolInput(in_.WithOpName("A1"));
  auto a2 = ::tensorflow::ops::internal::Enter(in_.WithOpName("B2"), a1, "foo");
  auto a3 = ::tensorflow::ops::Merge(in_.WithOpName("A3"),
                                     {a2, Input("B5", 0, DT_BOOL)})
                .output;
  LoopCond(in_.WithOpName("A4"), a3);
  auto b1 = Identity(in_.WithOpName("B1"), a3);
  NextIteration(in_.WithOpName("B5"), b1);

  std::unordered_map<string, GraphDef> partitions;
  Partition(ToGraphDef(), &partitions);

  // Tracks whether A3 was seen at all, so the test cannot pass without having
  // checked anything.
  bool found_a3 = false;
  for (const auto& kv : partitions) {
    const GraphDef& gdef = kv.second;
    for (const NodeDef& ndef : gdef.node()) {
      if (ndef.name() != "A3") continue;
      found_a3 = true;
      // Guard the indexed accesses below against a node with fewer inputs.
      ASSERT_GE(ndef.input_size(), 2);
      // A3, B2, and B5 are expected to be on the same device after
      // partitioning, so the inputs are the plain node names.
      EXPECT_EQ(ndef.input(0), "B2");
      EXPECT_EQ(ndef.input(1), "B5");
    }
  }
  EXPECT_TRUE(found_a3) << "node A3 not present in any partition";
}
410 
// Builds `while i1 < 10: i1 += i2` with ops::BuildWhileLoop. The loop is
// created on cpu:0 except for the AddN in the body, which is placed on cpu:1.
// Every partition must pass CheckLoopConstruction.
TEST_F(GraphPartitionTest, CrossDeviceLoopFull) {
  Scope cpu0 = in_.WithDevice("/job:a/replica:0/task:0/cpu:0");
  auto p1 = ops::Placeholder(cpu0, DT_INT32);
  auto p2 = ops::Placeholder(cpu0, DT_INT32);

  // Loop condition: i1 < 10.
  const auto cond = [](const Scope& s, const std::vector<Output>& inputs,
                       Output* output) {
    *output = ops::Less(s, inputs[0], 10);
    return s.status();
  };
  // Loop body: i1 += i2 (computed on cpu:1); i2 is passed through unchanged.
  const auto body = [](const Scope& s, const std::vector<Output>& inputs,
                       std::vector<Output>* outputs) {
    Scope cpu1 = s.WithDevice("/job:a/replica:0/task:0/cpu:1");
    outputs->push_back(ops::AddN(cpu1, {inputs[0], inputs[1]}));
    outputs->push_back(inputs[1]);
    return s.status();
  };

  OutputList outputs;
  TF_ASSERT_OK(
      ops::BuildWhileLoop(cpu0, {p1, p2}, cond, body, "test_loop", &outputs));
  CheckLoopConstruction(ToGraphDef());
}
433 
// Partitioning a graph that contains a Combine node with no inputs connected
// must return INVALID_ARGUMENT rather than crash.
TEST_F(GraphPartitionTest, PartitionIncompleteGraph) {
  // Invalid node: the Combine op declares inputs, none is provided here.
  NodeDef ndef;
  bool parsed = protobuf::TextFormat::ParseFromString(
      R"EOF(
      name: "N"
      op: "Combine"
      )EOF",
      &ndef);
  ASSERT_TRUE(parsed);

  Graph g(OpRegistry::Global());
  Status status;
  g.AddNode(ndef, &status);
  TF_ASSERT_OK(status);

  PartitionOptions popts;
  popts.node_to_loc = SplitByDevice;
  popts.new_name = [&g](const string& prefix) { return g.NewName(prefix); };
  popts.get_incarnation = [](const string&) { return 1; };

  // Partitioning should fail, but not crash like it did before the changes
  // that accompanied the addition of this test.
  std::unordered_map<string, GraphDef> partitions;
  status = Partition(popts, &g, &partitions);
  EXPECT_EQ(error::INVALID_ARGUMENT, status.code()) << status;
}
460 
// Each partition's function library must contain the functions expected for
// the nodes placed in that partition.
TEST_F(GraphPartitionTest, Functions) {
  using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)

  // Register XTimesTwo and XTimesFour with the input graph.
  FunctionDefLibrary fdef_lib;
  *fdef_lib.add_function() = test::function::XTimesTwo();
  *fdef_lib.add_function() = test::function::XTimesFour();
  TF_ASSERT_OK(in_.graph()->AddFunctionLibrary(fdef_lib));

  // One function call per device.
  auto input_a = FloatInput(in_.WithOpName("A1"));
  auto input_b = FloatInput(in_.WithOpName("B1"));
  ConstructOp(in_.WithOpName("A2"), "XTimesTwo", {input_a});
  ConstructOp(in_.WithOpName("B2"), "XTimesFour", {input_b});

  // The `Partition()` helper function uses the first letter of the op name
  // ('A' or 'B') to choose a device for each node.
  Partition(ToGraphDef(), &partitions_);
  EXPECT_EQ(2, partitions_.size());

  // Test that partition graphs inherit function library from original graph.
  const string a = "/job:a/replica:0/task:0/cpu:0";
  const string b = "/job:a/replica:0/task:0/cpu:1";

  // Node "A2" is placed in part `a`, and uses only "XTimesTwo".
  ExpectFunctions(partitions_[a].library(), {"XTimesTwo"});
  // Node "B2" is placed in part `b`, and uses both "XTimesFour" directly,
  // and "XTimesTwo" in the body of "XTimesFour".
  ExpectFunctions(partitions_[b].library(), {"XTimesTwo", "XTimesFour"});
}
488 
// A graph that already contains _Send/_Recv nodes with
// send_device_incarnation 0: after partitioning, both nodes must carry the
// incarnation the test helper's get_incarnation yields for send device "A",
// which is ('A' - 'A') + 100 = 100.
TEST_F(GraphPartitionTest, SetIncarnation) {
  // Attributes shared by the _Send node "A" and the _Recv node "B".
  const char* const kSendRecvAttrs = R"proto(
  attr { key: 'T' value { type: DT_FLOAT  }  }
  attr { key: 'client_terminated' value {  b: false } }
  attr { key: 'recv_device' value { s: 'B' } }
  attr { key: 'send_device' value { s: 'A' } }
  attr { key: 'send_device_incarnation' value { i: 0 }  }
  attr { key: 'tensor_name' value { s: 'test' } }
)proto";
  // Const "A/Pi" feeds the _Send; the _Recv has no inputs.
  const string graph_text = strings::StrCat(
      "node { name: 'A/Pi' op: 'Const' ",
      "  attr { key: 'dtype' value { type: DT_FLOAT } } ",
      "  attr { key: 'value' value { tensor { ",
      "    dtype: DT_FLOAT tensor_shape {} float_val: 3.14 } } } }",
      "node { name: 'A' op: '_Send' input: 'A/Pi' ", kSendRecvAttrs, "}",
      "node { name: 'B' op: '_Recv' ", kSendRecvAttrs,
      "  attr { key: 'tensor_type' value { type:DT_FLOAT}}}");

  GraphDef gdef;
  CHECK(protobuf::TextFormat::ParseFromString(graph_text, &gdef));
  gdef.mutable_versions()->set_producer(TF_GRAPH_DEF_VERSION);
  Partition(gdef, &partitions_);
  EXPECT_EQ(2, partitions_.size());

  for (const auto& kv : partitions_) {
    for (const NodeDef& ndef : kv.second.node()) {
      // Only the send/recv nodes carry the attribute under test.
      if (ndef.name() != "A" && ndef.name() != "B") continue;
      int64 val;
      TF_CHECK_OK(GetNodeAttr(ndef, "send_device_incarnation", &val));
      EXPECT_EQ(val, 100);  // Send device is "A".
    }
  }
}
524 
// Twenty independent placeholders p0..p19 with start times 1..20 must come
// out of the sort ordered by start time, regardless of creation order.
TEST(TopologicalSortNodesWithTimePriorityTest, NoDependencies) {
  Scope root = Scope::NewRootScope().ExitOnError();

  // Create the placeholders in the rotated order p1, p2, ..., p19, p0 so the
  // order in the graph is not strictly increasing.
  const int kNumPlaceholders = 20;
  std::vector<ops::Placeholder> placeholders;
  for (int k = 0; k < kNumPlaceholders; ++k) {
    const int i = (k + 2001) % kNumPlaceholders;
    placeholders.emplace_back(root.WithOpName(strings::StrCat("p", i)),
                              DT_FLOAT);
    placeholders.back().node()->AddAttr("_start_time", i + 1);
  }

  GraphDef gdef;
  TF_EXPECT_OK(root.ToGraphDef(&gdef));

  std::vector<std::pair<const NodeDef*, int64>> nodes;
  std::unordered_map<const NodeDef*, int64> node_to_start_time;
  TF_CHECK_OK(
      TopologicalSortNodesWithTimePriority(&gdef, &nodes, &node_to_start_time));
  ASSERT_EQ(nodes.size(), 20);

  // Position i must hold p<i> with start time i + 1.
  const int num_sorted = static_cast<int>(nodes.size());
  for (int i = 0; i < num_sorted; ++i) {
    EXPECT_EQ(strings::StrCat("p", i), nodes[i].first->name());
    EXPECT_EQ(i + 1, nodes[i].second);
  }
}
553 
TEST(TopologicalSortNodesWithTimePriority,Dependencies)554 TEST(TopologicalSortNodesWithTimePriority, Dependencies) {
555   // Create placeholders, shuffle them so the order in the graph is not strictly
556   // increasing.
557   Scope root = Scope::NewRootScope().ExitOnError();
558   std::vector<int> indexes;
559   std::vector<ops::Placeholder> placeholders_in_order;
560   const int num_leaves = 20;
561   for (int i = 0; i < num_leaves; ++i) {
562     indexes.push_back((i + 2001) % num_leaves);
563     placeholders_in_order.emplace_back(root.WithOpName(strings::StrCat("p", i)),
564                                        DT_FLOAT);
565     placeholders_in_order.back().node()->AddAttr("_start_time", i + 1);
566   }
567   std::vector<ops::Placeholder> placeholders;
568   for (int i : indexes) {
569     placeholders.push_back(placeholders_in_order[i]);
570   }
571 
572   // Create ops that depend on the placeholders. We give start times to these
573   // that are in descending order (e.g., the op that depends on the first
574   // placeholder runs last).
575   std::vector<ops::Square> squares;
576   for (int i : indexes) {
577     squares.emplace_back(root.WithOpName(strings::StrCat("s", i)),
578                          placeholders[i]);
579     squares.back().node()->AddAttr("_start_time", 50 - (i + 1));
580   }
581 
582   // Create addn to sum all squares.
583   std::vector<Input> inputs;
584   for (const auto& s : squares) inputs.push_back(s);
585   ops::AddN addn = ops::AddN(root.WithOpName("addn"),
586                              tensorflow::gtl::ArraySlice<Input>(inputs));
587   // Start times is actually listed earlier than the nodes it depends on.
588   // But because of dependency ordering, it is last in the list.
589   addn.node()->AddAttr("_start_time", 1);
590 
591   GraphDef gdef;
592   TF_EXPECT_OK(root.ToGraphDef(&gdef));
593 
594   std::vector<std::pair<const NodeDef*, int64>> nodes;
595   std::unordered_map<const NodeDef*, int64> node_to_start_time;
596   TF_CHECK_OK(
597       TopologicalSortNodesWithTimePriority(&gdef, &nodes, &node_to_start_time));
598   ASSERT_EQ(1 + squares.size() + placeholders.size(), nodes.size());
599   for (int i = 0; i < placeholders.size(); ++i) {
600     const NodeDef* node = nodes[i].first;
601     EXPECT_EQ(strings::StrCat("p", i), node->name());
602     EXPECT_EQ(i + 1, nodes[i].second);
603     EXPECT_EQ(i + 1, node_to_start_time[node]);
604   }
605   for (int i = 0; i < squares.size(); ++i) {
606     int node_index = placeholders.size() + i;
607     int square_index = num_leaves - 1 - i;
608     const NodeDef* node = nodes[node_index].first;
609     EXPECT_EQ(strings::StrCat("s", square_index), node->name());
610     EXPECT_EQ(50 - (square_index + 1), nodes[node_index].second);
611     EXPECT_EQ(50 - (square_index + 1), node_to_start_time[node]);
612   }
613   EXPECT_EQ("addn", nodes.back().first->name());
614   EXPECT_EQ(50, nodes.back().second);
615   EXPECT_EQ(50, node_to_start_time[nodes.back().first]);
616 }
617 
// Builds twenty placeholders, a hand-assembled while loop on top of each, and
// a Square on each loop exit, then checks the order and start times returned
// by TopologicalSortNodesWithTimePriority: placeholders first, then the loops
// (only each loop's Enter node is checked), then the squares.
TEST(TopologicalSortNodesWithTimePriority, WhileLoop) {
  using namespace ::tensorflow::ops;            // NOLINT(build/namespaces)
  using namespace ::tensorflow::ops::internal;  // NOLINT(build/namespaces)

  // Create placeholders p0..p19 with start times 1..20. `indexes` is the
  // rotated sequence 1, 2, ..., 19, 0 used to shuffle later operations.
  Scope root = Scope::NewRootScope().ExitOnError();
  std::vector<int> indexes;
  std::vector<Placeholder> placeholders_in_order;
  const int num_leaves = 20;
  for (int i = 0; i < num_leaves; ++i) {
    indexes.push_back((i + 2001) % num_leaves);
    placeholders_in_order.emplace_back(root.WithOpName(strings::StrCat("p", i)),
                                       DT_FLOAT);
    placeholders_in_order.back().node()->AddAttr("_start_time", i + 1);
  }
  std::vector<Placeholder> placeholders;
  placeholders.reserve(indexes.size());
  for (int i : indexes) {
    placeholders.push_back(placeholders_in_order[i]);
  }

  // Add a while loop above each placeholder.
  std::vector<Exit> while_exits;
  // Number of nodes per loop that receive a "_start_time" below: Enter, Merge,
  // Const, LoopCond, Switch, Identity, NextIteration and Exit.
  const int nodes_per_loop = 8;
  for (int i : indexes) {
    Scope scope = root.NewSubScope(strings::StrCat("while", i));
    // Temporary stand-in for the Merge's second input; it is removed and
    // replaced by the NextIteration back edge once that node exists.
    auto dummy = Placeholder(scope, DT_FLOAT);

    Enter enter(scope, placeholders[i], strings::StrCat("frame", i));
    Merge merge(scope, std::initializer_list<Input>{enter, dummy});
    auto cv = Const(scope.WithControlDependencies({merge.output}), false);
    LoopCond loop_cond(scope, cv);
    Switch switch_node(scope, merge.output, loop_cond);
    Identity identity(scope, switch_node.output_true);
    NextIteration next_iteration(scope, identity);
    while_exits.emplace_back(scope.WithOpName("exit"),
                             switch_node.output_false);

    // Complete loop by removing dummy node and attaching NextIteration to
    // that input of the merge node.
    scope.graph()->RemoveNode(dummy.node());
    scope.graph()->AddEdge(next_iteration.node(), 0, merge.output.node(), 1);

    // Consecutive start times for the loop's nodes, starting at 100 + 10 * i,
    // so the start-time ranges of different loops do not overlap.
    int base_start_time = i * 10 + 100;
    for (const auto& op : std::initializer_list<Output>{
             enter, merge.output, cv, loop_cond, switch_node.output_false,
             identity, next_iteration, while_exits.back()}) {
      op.node()->AddAttr("_start_time", base_start_time++);
    }
  }

  // Create ops that depend on the loop exits. Square s<i> gets start time
  // 500 - (i + 1).
  std::vector<Square> squares;
  squares.reserve(indexes.size());
  for (int i : indexes) {
    squares.emplace_back(root.WithOpName(strings::StrCat("s", i)),
                         while_exits[i]);
    squares.back().node()->AddAttr("_start_time", 500 - (i + 1));
  }

  GraphDef gdef;
  TF_EXPECT_OK(root.ToGraphDef(&gdef));

  // Run the sort.
  // NOTE(review): the comment previously here said the while loop nodes do
  // not appear in the output <nodes>, but the checks below expect each loop's
  // Enter node at a stride of nodes_per_loop entries, and the ASSERT_LT
  // requires more entries than placeholders + exits + squares -- confirm
  // which is intended.
  std::vector<std::pair<const NodeDef*, int64>> nodes;
  std::unordered_map<const NodeDef*, int64> node_to_start_time;
  TF_CHECK_OK(
      TopologicalSortNodesWithTimePriority(&gdef, &nodes, &node_to_start_time));
  ASSERT_LT(while_exits.size() + squares.size() + placeholders.size(),
            nodes.size());
  // `node_index` tracks the position in `nodes` across the three sections.
  int node_index = 0;
  // Section 1: placeholders in start-time order (i equals node_index here).
  for (int i = 0; i < placeholders.size(); ++i, ++node_index) {
    const NodeDef* node = nodes[i].first;
    EXPECT_EQ(strings::StrCat("p", i), node->name());
    EXPECT_EQ(i + 1, nodes[i].second);
    EXPECT_EQ(i + 1, node_to_start_time[node]);
  }
  // Section 2: one group of nodes_per_loop entries per loop; only the first
  // entry of each group (the Enter node) is checked.
  for (int i = 0; i < while_exits.size(); ++i, node_index += nodes_per_loop) {
    const NodeDef* node = nodes[node_index].first;
    EXPECT_EQ(strings::StrCat("while", i, "/Enter"), node->name());
    EXPECT_EQ(100 + i * 10, nodes[node_index].second);
    EXPECT_EQ(100 + i * 10, node_to_start_time[node]);
  }
  // Section 3: squares, expected from s19 down to s0.
  for (int i = 0; i < squares.size(); ++i, ++node_index) {
    int square_index = num_leaves - 1 - i;
    const NodeDef* node = nodes[node_index].first;
    EXPECT_EQ(strings::StrCat("s", square_index), node->name());
    EXPECT_EQ(500 - (square_index + 1), nodes[node_index].second);
    EXPECT_EQ(500 - (square_index + 1), node_to_start_time[node]);
  }
}
709 
710 }  // namespace
711 }  // namespace tensorflow
712