1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/graph/graph_partition.h"
17 
18 #include <deque>
19 #include <queue>
20 #include <unordered_map>
21 #include <unordered_set>
22 #include <utility>
23 #include <vector>
24 
25 #include "absl/container/flat_hash_map.h"
26 #include "tensorflow/core/framework/function.h"
27 #include "tensorflow/core/framework/memory_types.h"
28 #include "tensorflow/core/framework/node_def_builder.h"
29 #include "tensorflow/core/framework/tensor.pb.h"
30 #include "tensorflow/core/framework/types.h"
31 #include "tensorflow/core/framework/versions.pb.h"
32 #include "tensorflow/core/graph/algorithm.h"
33 #include "tensorflow/core/graph/control_flow.h"
34 #include "tensorflow/core/graph/costmodel.h"
35 #include "tensorflow/core/graph/graph_def_builder.h"
36 #include "tensorflow/core/graph/node_builder.h"
37 #include "tensorflow/core/graph/tensor_id.h"
38 #include "tensorflow/core/lib/core/errors.h"
39 #include "tensorflow/core/lib/hash/hash.h"
40 #include "tensorflow/core/lib/strings/str_util.h"
41 #include "tensorflow/core/platform/logging.h"
42 #include "tensorflow/core/util/device_name_utils.h"
43 #include "tensorflow/core/util/dump_graph.h"
44 
45 namespace tensorflow {
46 
47 namespace {
48 
IsMerge(const NodeDef & node_def)49 inline bool IsMerge(const NodeDef& node_def) {
50   return node_def.op() == "Merge" || node_def.op() == "RefMerge" ||
51          node_def.op() == "_XlaMerge";
52 }
53 
IsNextIteration(const NodeDef & node_def)54 inline bool IsNextIteration(const NodeDef& node_def) {
55   return node_def.op() == "NextIteration" ||
56          node_def.op() == "RefNextIteration";
57 }
58 
59 struct DupRecvKey {
60   int src_node_id;           // Edge's src node id
61   int src_output_slot;       // Edge's src node output slot
62   GraphDef* dst_graph;       // Edge's dst node is in this subgraph
63   bool recv_output_on_host;  // The output of recv is on host
64 
65   template <typename H>
AbslHashValue(H h,const DupRecvKey & c)66   friend H AbslHashValue(H h, const DupRecvKey& c) {
67     return H::combine(std::move(h), c.src_node_id, c.src_output_slot,
68                       reinterpret_cast<std::uintptr_t>(c.dst_graph),
69                       c.recv_output_on_host);
70   }
71 
operator ==(const DupRecvKey & x,const DupRecvKey & y)72   friend bool operator==(const DupRecvKey& x, const DupRecvKey& y) {
73     return (x.src_node_id == y.src_node_id) &&
74            (x.src_output_slot == y.src_output_slot) &&
75            (x.dst_graph == y.dst_graph) &&
76            (x.recv_output_on_host == y.recv_output_on_host);
77   }
78 };
79 
80 // struct used to store the recvs, so that start times can be properly updated
81 struct RecvInfo {
82   NodeDef* recv;
83   NodeDef* real_recv;
84   int64 start_time;
85 };
86 
87 typedef absl::flat_hash_map<DupRecvKey, RecvInfo> DupRecvTable;
88 
// Key type of the memory-type maps: identifies one input or output of a node
// by the node's id together with the input/output index.
// TODO(power): migrate back to std::pair when absl::Hash is fixed for MSVC.
struct NodePort {
  int node_id;
  int index;

  // absl hashing hook: combines both fields.
  template <typename H>
  friend H AbslHashValue(H h, const NodePort& c) {
    return H::combine(std::move(h), c.node_id, c.index);
  }

  // Ports are equal when both the node id and the index match.
  friend bool operator==(const NodePort& x, const NodePort& y) {
    if (x.node_id != y.node_id) return false;
    return x.index == y.index;
  }
};
105 
106 typedef absl::flat_hash_map<NodePort, MemoryType> MemoryTypeMap;
107 
// We collect the following information about the graph before performing
// graph partitioning.
struct GraphInfo {
  // Device type of each node, indexed by node id.
  std::vector<DeviceType> device_types;
  // Memory type of each node input, keyed by {node id, input index}.
  MemoryTypeMap input_types;
  // Memory type of each node output, keyed by {node id, output index}.
  MemoryTypeMap output_types;
  // Control flow info of each node, indexed by node id.
  std::vector<ControlFlowInfo> cf_info;
};
116 
EdgeType(const Edge * e)117 DataType EdgeType(const Edge* e) {
118   if (e->IsControlEdge()) {
119     return DT_FLOAT;
120   } else {
121     return e->dst()->input_type(e->dst_input());
122   }
123 }
124 
125 // Return true iff we need to add the same device send/recv for 'edge'.
NeedSameDeviceSendRecv(const Edge * edge,const GraphInfo & info)126 bool NeedSameDeviceSendRecv(const Edge* edge, const GraphInfo& info) {
127   if (edge->IsControlEdge()) {
128     return false;
129   }
130 
131   const Node* src = edge->src();
132   const Node* dst = edge->dst();
133   if (src->assigned_device_name() == dst->assigned_device_name()) {
134     int src_port = edge->src_output();
135     int dst_port = edge->dst_input();
136     if (info.device_types[src->id()] != DEVICE_CPU) {
137       auto src_it = info.output_types.find({src->id(), src_port});
138       DCHECK(src_it != info.output_types.end());
139       auto dst_it = info.input_types.find({dst->id(), dst_port});
140       DCHECK(dst_it != info.input_types.end());
141       return src_it->second != dst_it->second;
142     }
143   }
144   return false;
145 }
146 
147 // Return true iff (dst, dst_input) is specified on host memory.
IsDstInputOnHost(const Edge * edge,const GraphInfo & info)148 bool IsDstInputOnHost(const Edge* edge, const GraphInfo& info) {
149   const Node* dst = edge->dst();
150   int dst_port = edge->dst_input();
151   if (info.device_types[dst->id()] != DEVICE_CPU) {
152     if (edge->IsControlEdge()) return false;
153     auto dst_it = info.input_types.find({dst->id(), dst_port});
154     DCHECK(dst_it != info.input_types.end());
155     return dst_it->second == HOST_MEMORY;
156   }
157   return true;
158 }
159 
160 // Add an input to dst that comes from the "src_slot" output of the
161 // node named by "src_name".
AddInput(NodeDef * dst,StringPiece src_name,int src_slot)162 void AddInput(NodeDef* dst, StringPiece src_name, int src_slot) {
163   if (src_slot == Graph::kControlSlot) {
164     dst->add_input(strings::StrCat("^", src_name));
165   } else if (src_slot == 0) {
166     dst->add_input(src_name.data(), src_name.size());
167   } else {
168     dst->add_input(strings::StrCat(src_name, ":", src_slot));
169   }
170 }
171 
172 // Add a control edge from each input to each recv.
AddReadControl(const std::vector<NodeDef * > & recvs,const std::vector<string> & inputs)173 void AddReadControl(const std::vector<NodeDef*>& recvs,
174                     const std::vector<string>& inputs) {
175   for (NodeDef* recv : recvs) {
176     for (const string& input : inputs) {
177       recv->add_input(strings::StrCat("^", input));
178     }
179   }
180 }
181 
SetSendRecvAttrs(const PartitionOptions & opts,const Edge * edge,NodeDefBuilder * builder)182 void SetSendRecvAttrs(const PartitionOptions& opts, const Edge* edge,
183                       NodeDefBuilder* builder) {
184   builder->Attr("tensor_name",
185                 strings::StrCat("edge_", edge->id(), "_", edge->src()->name()));
186   builder->Attr("send_device", edge->src()->assigned_device_name());
187   builder->Attr("send_device_incarnation",
188                 static_cast<int64>(
189                     opts.get_incarnation(edge->src()->assigned_device_name())));
190   builder->Attr("recv_device", edge->dst()->assigned_device_name());
191   builder->Attr("client_terminated", false);
192   builder->Attr("_src", edge->src()->name());
193   builder->Attr("_dst", edge->dst()->name());
194 }
195 
AddSend(const PartitionOptions & opts,const GraphInfo & g_info,GraphDef * gdef,const Edge * edge,NodeDefBuilder::NodeOut send_from,int64 start_time,Status * status)196 NodeDef* AddSend(const PartitionOptions& opts, const GraphInfo& g_info,
197                  GraphDef* gdef, const Edge* edge,
198                  NodeDefBuilder::NodeOut send_from, int64 start_time,
199                  Status* status) {
200   const DataType dtype = send_from.data_type;
201   const DataType cast_dtype = opts.should_cast ? opts.should_cast(edge) : dtype;
202   const Node* src = edge->src();
203   const int src_port = edge->src_output();
204 
205   // host_memory = true iff we need to use HostSend/HostCast.
206   bool host_memory = false;
207   if (!edge->IsControlEdge()) {
208     auto src_it = g_info.output_types.find({src->id(), src_port});
209     DCHECK(src_it != g_info.output_types.end());
210     host_memory = (src_it->second == HOST_MEMORY);
211   }
212 
213   // Add a cast node that casts dtype to cast_dtype.
214   // NOTE(yuanbyu): Only cast for cross-device send/recv.
215   if (dtype != cast_dtype && !NeedSameDeviceSendRecv(edge, g_info)) {
216     const string cast_op = (host_memory) ? "_HostCast" : "Cast";
217     NodeDefBuilder cast_builder(opts.new_name(src->name()), cast_op,
218                                 NodeDebugInfo(*src));
219     cast_builder.Device(src->assigned_device_name()).Input(send_from);
220     if (opts.scheduling_for_recvs) {
221       cast_builder.Attr("_start_time", start_time);
222     }
223     cast_builder.Attr("DstT", cast_dtype);
224 
225     if (cast_dtype == DT_BFLOAT16) {
226       // the below attribute specifies that the cast to bfloat16 should use
227       // truncation. This is needed to retain legacy behavior when we change
228       // the default bfloat16 casts to use rounding instead of truncation
229       cast_builder.Attr("Truncate", true);
230     }
231 
232     NodeDef* cast = gdef->add_node();
233     *status = cast_builder.Finalize(cast, /*consume=*/true);
234     if (!status->ok()) return nullptr;
235 
236     // Connect the Send op to the cast.
237     send_from.Reset(cast->name(), 0, cast_dtype);
238   }
239 
240   // Add the send node.
241   const string send_op = (host_memory) ? "_HostSend" : "_Send";
242   NodeDefBuilder send_builder(opts.new_name(src->name()), send_op,
243                               NodeDebugInfo(*src));
244   SetSendRecvAttrs(opts, edge, &send_builder);
245   send_builder.Device(src->assigned_device_name()).Input(send_from);
246   if (opts.scheduling_for_recvs) {
247     send_builder.Attr("_start_time", start_time);
248   }
249   NodeDef* send = gdef->add_node();
250   *status = send_builder.Finalize(send, /*consume=*/true);
251   return send;
252 }
253 
AddRecv(const PartitionOptions & opts,const GraphInfo & g_info,GraphDef * gdef,const Edge * edge,NodeDef ** real_recv,Status * status)254 NodeDef* AddRecv(const PartitionOptions& opts, const GraphInfo& g_info,
255                  GraphDef* gdef, const Edge* edge, NodeDef** real_recv,
256                  Status* status) {
257   const DataType dtype = EdgeType(edge);
258   const Node* src = edge->src();
259   const Node* dst = edge->dst();
260   const int dst_port = edge->dst_input();
261   DataType cast_dtype = dtype;
262 
263   // NOTE(yuanbyu): Only cast for cross-device send/recv.
264   if (opts.should_cast && !NeedSameDeviceSendRecv(edge, g_info)) {
265     cast_dtype = opts.should_cast(edge);
266   }
267 
268   // host_memory = true iff we need to use HostRecv/HostCast.
269   // Also log the introduction of the send-recv pair, for performance debugging.
270   bool host_memory = false;
271   if (!edge->IsControlEdge()) {
272     auto dst_it = g_info.input_types.find({dst->id(), dst_port});
273     DCHECK(dst_it != g_info.input_types.end());
274     host_memory = (dst_it->second == HOST_MEMORY);
275     bool src_host_memory = false;
276     if (VLOG_IS_ON(1)) {
277       const int src_port = edge->src_output();
278       auto src_it = g_info.output_types.find({src->id(), src_port});
279       DCHECK(src_it != g_info.output_types.end());
280       src_host_memory = (src_it->second == HOST_MEMORY);
281     }
282     VLOG(1) << "Receiving data"
283             << " from " << src->name() << " (" << src->type_string() << ")"
284             << " on " << src->assigned_device_name() << " in "
285             << (src_host_memory ? "host memory" : "device memory") << " for "
286             << dst->name() << " (" << dst->type_string() << ")"
287             << " on " << dst->assigned_device_name() << " in "
288             << (host_memory ? "host memory" : "device memory");
289   } else {
290     // Log control-edge transfers too, but don't mention memory space since it's
291     // irrelevant.
292     VLOG(1) << "Receiving control"
293             << " from " << src->name() << " (" << src->type_string() << ")"
294             << " on " << src->assigned_device_name() << " for " << dst->name()
295             << " (" << dst->type_string() << ")"
296             << " on " << dst->assigned_device_name();
297   }
298 
299   // Add the recv node.
300   const string recv_op = (host_memory) ? "_HostRecv" : "_Recv";
301   NodeDefBuilder recv_builder(opts.new_name(src->name()), recv_op,
302                               NodeDebugInfo(*src));
303   SetSendRecvAttrs(opts, edge, &recv_builder);
304   recv_builder.Device(dst->assigned_device_name())
305       .Attr("tensor_type", cast_dtype);
306   NodeDef* recv = gdef->add_node();
307   *status = recv_builder.Finalize(recv, /*consume=*/true);
308   if (!status->ok()) return nullptr;
309   *real_recv = recv;
310 
311   // Add the cast node (from cast_dtype to dtype) or an Identity node.
312   if (dtype != cast_dtype) {
313     const string cast_op = (host_memory) ? "_HostCast" : "Cast";
314     NodeDefBuilder cast_builder(opts.new_name(src->name()), cast_op,
315                                 NodeDebugInfo(*src));
316     cast_builder.Attr("DstT", dtype);
317     cast_builder.Device(dst->assigned_device_name())
318         .Input(recv->name(), 0, cast_dtype);
319     NodeDef* cast = gdef->add_node();
320     *status = cast_builder.Finalize(cast, /*consume=*/true);
321     if (!status->ok()) return nullptr;
322     return cast;
323   } else if (edge->IsControlEdge()) {
324     // An Identity is only needed for control edges.
325     NodeDefBuilder id_builder(opts.new_name(src->name()), "Identity",
326                               NodeDebugInfo(*src));
327     id_builder.Device(dst->assigned_device_name())
328         .Input(recv->name(), 0, cast_dtype);
329     NodeDef* id = gdef->add_node();
330     *status = id_builder.Finalize(id, /*consume=*/true);
331     if (!status->ok()) return nullptr;
332     return id;
333   } else {
334     return recv;
335   }
336 }
337 
AddDummyConst(const PartitionOptions & opts,GraphDef * gdef,const Edge * edge,Status * status)338 NodeDef* AddDummyConst(const PartitionOptions& opts, GraphDef* gdef,
339                        const Edge* edge, Status* status) {
340   const Node* src = edge->src();
341   Tensor tensor(DT_FLOAT, TensorShape({0}));
342   NodeDef* result = gdef->add_node();
343   *status = NodeDefBuilder(opts.new_name(src->name()), "Const")
344                 .Device(src->assigned_device_name())
345                 .Attr("dtype", DT_FLOAT)
346                 .Attr("value", tensor)
347                 .Finalize(result, /*consume=*/true);
348   return result;
349 }
350 
351 // A dummy node for scheduling.
AddControlTrigger(const PartitionOptions & opts,GraphDef * gdef,const string & assigned_device_name,int64 epoch,int64 starttime,Status * status)352 NodeDef* AddControlTrigger(const PartitionOptions& opts, GraphDef* gdef,
353                            const string& assigned_device_name, int64 epoch,
354                            int64 starttime, Status* status) {
355   NodeDef* result = gdef->add_node();
356   *status = NodeDefBuilder(opts.new_name(strings::StrCat("synch_", epoch)),
357                            "ControlTrigger")
358                 .Device(assigned_device_name)
359                 .Attr("_start_time", starttime)
360                 .Finalize(result, /*consume=*/true);
361   return result;
362 }
363 
364 // Optimize colocation for control flow nodes. For cond, we want the
365 // switch nodes to colocate with its data input. This is particularly
366 // needed for conditional reading of a remote variable. It may also
367 // reduce the number of devices involved in a loop.
368 // TODO(yuanbyu): In this case, we don't respect the requested device in
369 // the GraphDef for these nodes. Ideally, the placer would enforce the
370 // colocation to render this unnecessary.
OptimizeControlFlowColocation(Graph * graph)371 void OptimizeControlFlowColocation(Graph* graph) {
372   auto visit = [](Node* node) {
373     if (IsSwitch(node)) {
374       for (const Edge* in_edge : node->in_edges()) {
375         if (in_edge->dst_input() == 0) {
376           // Colocate with the data input.
377           node->set_assigned_device_name(
378               in_edge->src()->assigned_device_name());
379           return;
380         }
381       }
382     } else if (IsExit(node)) {
383       for (const Edge* in_edge : node->in_edges()) {
384         if (!in_edge->IsControlEdge()) {
385           // Colocate with upstream node.
386           node->set_assigned_device_name(
387               in_edge->src()->assigned_device_name());
388           return;
389         }
390       }
391     } else {
392       if ((IsEnter(node) && !IsRefType(node->input_type(0))) ||
393           IsNextIteration(node)) {
394         const Edge* data_edge = nullptr;
395         for (const Edge* out_edge : node->out_edges()) {
396           if (!out_edge->IsControlEdge()) {
397             data_edge = out_edge;
398             break;
399           }
400         }
401         // Colocate with the first downstream data node.
402         if (data_edge) {
403           node->set_assigned_device_name(
404               data_edge->dst()->assigned_device_name());
405         }
406       }
407     }
408   };
409   DFS(*graph, visit, {});
410 }
411 
// Returns `name` with the "_cloop" prefix that marks control loop nodes.
string ControlLoopName(const string& name) {
  string loop_name("_cloop");
  loop_name.append(name);
  return loop_name;
}
415 
IsControlLoop(const Node * node)416 bool IsControlLoop(const Node* node) {
417   const string& name = node->name();
418   return absl::StartsWith(name, "_cloop");
419 }
420 
421 // An enter node for control flow.
AddControlEnter(Graph * g,const string & node_name,const string & device_name,const string & frame_name,const int parallel_iterations,Status * status)422 Node* AddControlEnter(Graph* g, const string& node_name,
423                       const string& device_name, const string& frame_name,
424                       const int parallel_iterations, Status* status) {
425   NodeBuilder node_builder(node_name, "Enter", g->op_registry());
426   node_builder.Input({"dummy", 0, DT_FLOAT});
427   node_builder.Attr("frame_name", frame_name);
428   node_builder.Attr("parallel_iterations", parallel_iterations);
429   Node* res_node;
430   *status = node_builder.Finalize(g, &res_node, /*consume=*/true);
431   if (!status->ok()) return nullptr;
432   res_node->set_assigned_device_name(device_name);
433   return res_node;
434 }
435 
436 // A merge node for control flow.
AddControlMerge(const string & in_name1,const string & in_name2,Graph * g,const string & node_name,const string & device_name,Status * status)437 Node* AddControlMerge(const string& in_name1, const string& in_name2, Graph* g,
438                       const string& node_name, const string& device_name,
439                       Status* status) {
440   NodeBuilder node_builder(node_name, "Merge", g->op_registry());
441   node_builder.Input({{in_name1, 0, DT_FLOAT}, {in_name2, 0, DT_FLOAT}});
442   Node* res_node;
443   *status = node_builder.Finalize(g, &res_node, /*consume=*/true);
444   if (!status->ok()) return nullptr;
445   res_node->set_assigned_device_name(device_name);
446   return res_node;
447 }
448 
449 // A switch node for control flow.
AddControlSwitch(NodeBuilder::NodeOut input1,NodeBuilder::NodeOut input2,const string & device_name,const GraphDefBuilder::Options & bopts)450 Node* AddControlSwitch(NodeBuilder::NodeOut input1, NodeBuilder::NodeOut input2,
451                        const string& device_name,
452                        const GraphDefBuilder::Options& bopts) {
453   Node* res_node =
454       ops::BinaryOp("Switch", std::move(input1), std::move(input2), bopts);
455   if (bopts.HaveError()) return nullptr;
456   res_node->set_assigned_device_name(device_name);
457   return res_node;
458 }
459 
460 // A next_iteration node for control flow.
AddControlNext(NodeBuilder::NodeOut input,const string & device_name,const GraphDefBuilder::Options & bopts)461 Node* AddControlNext(NodeBuilder::NodeOut input, const string& device_name,
462                      const GraphDefBuilder::Options& bopts) {
463   Node* res_node = ops::UnaryOp("NextIteration", std::move(input), bopts);
464   if (bopts.HaveError()) return nullptr;
465   res_node->set_assigned_device_name(device_name);
466   return res_node;
467 }
468 
EmptyConst(const GraphDefBuilder::Options & options)469 Node* EmptyConst(const GraphDefBuilder::Options& options) {
470   if (options.HaveError()) return nullptr;
471   NodeBuilder node_builder(options.GetNameForOp("Const"), "Const",
472                            options.op_registry());
473   const DataType dt = DataTypeToEnum<float>::v();
474   TensorProto proto;
475   proto.set_dtype(dt);
476   TensorShape empty_shape({0});
477   empty_shape.AsProto(proto.mutable_tensor_shape());
478   node_builder.Attr("dtype", dt).Attr("value", proto);
479   return options.FinalizeBuilder(&node_builder);
480 }
481 
482 // A dummy const node for control flow.
AddControlConst(const string & device_name,const GraphDefBuilder::Options & bopts)483 Node* AddControlConst(const string& device_name,
484                       const GraphDefBuilder::Options& bopts) {
485   Node* res_node = EmptyConst(bopts);
486   if (bopts.HaveError()) return nullptr;
487   res_node->set_assigned_device_name(device_name);
488   return res_node;
489 }
490 
// A synthetic loop, made up of dummy nodes. It performs control-flow actions
// on behalf of a leader on a different device.
//
// All members are null until AddControlLoop fills them in.
struct ControlLoop {
  // The loop's Enter node.
  Node* enter = nullptr;
  // The loop's Merge node; its inputs are the Enter and NextIteration nodes.
  Node* merge = nullptr;
  // The loop's Switch node, fed by the Merge node and the frame's LoopCond.
  Node* switch_node = nullptr;
};
498 
499 // Add the control flow info of a new node added during partitioning.
500 // The new node has the same control flow info as src.
AddControlFlowInfo(const Node * node,const Node * src,std::vector<ControlFlowInfo> * cf_info)501 void AddControlFlowInfo(const Node* node, const Node* src,
502                         std::vector<ControlFlowInfo>* cf_info) {
503   int id = node->id();
504   if (static_cast<size_t>(id) >= cf_info->size()) {
505     cf_info->resize(id + 1);
506   }
507   const ControlFlowInfo& src_info = (*cf_info)[src->id()];
508   ControlFlowInfo* info = &(*cf_info)[id];
509   info->frame = src_info.frame;
510   info->parent_frame = src_info.parent_frame;
511   info->frame_name = src_info.frame_name;
512 }
513 
514 // Constructs a control loop. Returns a struct containing the newly created
515 // enter, merge, and switch nodes. The enter and merge nodes are used in the
516 // recursive construction of control loops for nested frames (loops). The
517 // switch node will be connected to the LoopCond node. The merge node will
518 // be connected to all the recvs of the same frame by control edges when
519 // the actual partitioning happens.
AddControlLoop(const PartitionOptions & opts,Graph * g,const Node * src,const Edge * edge,Node * loop_cond,std::vector<ControlFlowInfo> * cf_info,ControlLoop * loop)520 Status AddControlLoop(const PartitionOptions& opts, Graph* g, const Node* src,
521                       const Edge* edge, Node* loop_cond,
522                       std::vector<ControlFlowInfo>* cf_info,
523                       ControlLoop* loop) {
524   Status status;
525   GraphDefBuilder::Options bopts(g, &status);
526   const ControlFlowInfo& src_info = (*cf_info)[src->id()];
527   const string& device_name = edge->dst()->assigned_device_name();
528   const string& frame_name = src_info.frame_name;
529   int parallel_iterations;
530   status = GetNodeAttr(src_info.frame->attrs(), "parallel_iterations",
531                        &parallel_iterations);
532   if (!status.ok()) return status;
533 
534   // The names of the nodes to be added.
535   const string& enter_name =
536       ControlLoopName(opts.new_name(edge->dst()->name()));
537   const string& merge_name =
538       ControlLoopName(opts.new_name(edge->dst()->name()));
539   const string& switch_name =
540       ControlLoopName(opts.new_name(edge->dst()->name()));
541   const string& next_name = ControlLoopName(opts.new_name(edge->dst()->name()));
542 
543   // Add the nodes to the graph g.
544   Node* enter = AddControlEnter(g, enter_name, device_name, frame_name,
545                                 parallel_iterations, &status);
546   if (!status.ok()) return status;
547   Node* merge = AddControlMerge(enter_name, next_name, g, merge_name,
548                                 device_name, &status);
549   if (!status.ok()) return status;
550   Node* switch_node = AddControlSwitch(merge, loop_cond, device_name,
551                                        bopts.WithName(switch_name));
552   if (!status.ok()) return status;
553   Node* next =
554       AddControlNext({switch_node, 1}, device_name, bopts.WithName(next_name));
555   if (!status.ok()) return status;
556 
557   // Add control flow info for these new nodes:
558   AddControlFlowInfo(enter, src, cf_info);
559   AddControlFlowInfo(merge, src, cf_info);
560   AddControlFlowInfo(switch_node, src, cf_info);
561   AddControlFlowInfo(next, src, cf_info);
562 
563   // Add input edges for the newly created merge node:
564   g->AddEdge(enter, 0, merge, 0);
565   g->AddEdge(next, 0, merge, 1);
566 
567   loop->enter = enter;
568   loop->merge = merge;
569   loop->switch_node = switch_node;
570   return Status::OK();
571 }
572 
573 // Build memory and device type info for every node in the graph.
574 // TODO(yuanbyu): It might be simpler if we convert MemoryType to
575 // DeviceType for the inputs/outputs of each node.
BuildMemoryDeviceInfo(const Graph & g,GraphInfo * info)576 Status BuildMemoryDeviceInfo(const Graph& g, GraphInfo* info) {
577   MemoryTypeVector input_memory_types;
578   MemoryTypeVector output_memory_types;
579 
580   info->device_types.resize(g.num_node_ids(), DEVICE_CPU);
581   for (const Node* node : g.op_nodes()) {
582     DeviceNameUtils::ParsedName parsed;
583     if (!DeviceNameUtils::ParseFullName(node->assigned_device_name(),
584                                         &parsed)) {
585       return errors::Internal("Malformed assigned device '",
586                               node->assigned_device_name(), "'");
587     }
588 
589     TF_RETURN_IF_ERROR(MemoryTypesForNode(
590         g.op_registry(), DeviceType(parsed.type), node->def(),
591         &input_memory_types, &output_memory_types));
592 
593     int node_id = node->id();
594     info->device_types[node_id] = DeviceType(parsed.type);
595     for (int i = 0; i < input_memory_types.size(); ++i) {
596       info->input_types[{node_id, i}] = input_memory_types[i];
597     }
598     for (int i = 0; i < output_memory_types.size(); ++i) {
599       info->output_types[{node_id, i}] = output_memory_types[i];
600     }
601   }
602   return Status::OK();
603 }
604 
// Returns the node that identifies the frame of `node`'s inputs.
// An input is in the same frame as the node except for Enter nodes: the
// input of Enter is in the parent frame of the Enter node.
const Node* InputFrame(const Node* node,
                       const std::vector<ControlFlowInfo>& cf_info) {
  return node->IsEnter() ? cf_info[node->id()].parent_frame : node;
}
614 
// Returns the node that identifies the frame of `node`'s outputs.
// An output is in the same frame as the node except for Exit nodes: the
// output of Exit is in the parent frame of the Exit node.
const Node* OutputFrame(const Node* node,
                        const std::vector<ControlFlowInfo>& cf_info) {
  return node->IsExit() ? cf_info[node->id()].parent_frame : node;
}
624 
625 // Each participating device needs to decide a) if there is a next iteration,
626 // and b) if the loop terminates. We take the approach to encode this control
627 // flow logic in the dataflow graph. There are at least two possible encodings.
628 // In a completely decentralized encoding, the participants communicate peer
629 // to peer. The other encoding uses a frame leader (the participant who owns
630 // the pivot termination predicate) to broadcast the termination condition to
631 // all the participants. For now we take the latter because it is simpler.
632 //
633 // TODO(yuanbyu): The correctness of this construction is rather subtle. I got
634 // it wrong many times so it would be nice to write a proof to be sure.
AddControlFlow(const PartitionOptions & opts,Graph * g,GraphInfo * g_info)635 Status AddControlFlow(const PartitionOptions& opts, Graph* g,
636                       GraphInfo* g_info) {
637   Status status;
638   GraphDefBuilder::Options bopts(g, &status);
639   std::vector<ControlFlowInfo>& cf_info = g_info->cf_info;
640 
641   // Build the control flow info for every node.
642   status = BuildControlFlowInfo(g, &cf_info);
643   if (!status.ok()) return status;
644 
645   OptimizeControlFlowColocation(g);
646 
647   // The map from frames to their LoopCond nodes.
648   std::unordered_map<string, Node*> frame_cond_map;
649   int num_node_ids = g->num_node_ids();
650   for (int i = 0; i < num_node_ids; ++i) {
651     Node* node = g->FindNodeId(i);
652     if (node == nullptr) continue;
653 
654     if (IsLoopCond(node)) {
655       const string& frame_name = cf_info[node->id()].frame_name;
656       DCHECK(!frame_name.empty());
657       frame_cond_map[frame_name] = node;
658     }
659   }
660 
661   // Add all control loops for cross-device frames.
662   // A control loop is added only when there is a cross-device edge in a
663   // non-root frame. Nothing is added if there is no loops. We also don't
664   // add anything for a frame that is completely local to a device. For
665   // nested loops, we stack the control loops together by connecting
666   // the merge of the outer loop to the enter of the inner loop.
667   //
668   // A map from <frame_name, device_name> to ControlLoop.
669   std::unordered_map<string, ControlLoop> control_loops;
670   int num_edge_ids = g->num_edge_ids();
671   for (int i = 0; i < num_edge_ids; ++i) {
672     const Edge* edge = g->FindEdgeId(i);
673     if (edge == nullptr) continue;
674 
675     const Node* src = edge->src();
676     const Node* dst = edge->dst();
677     // Skip Sink/Source nodes.
678     if (!src->IsOp() || !dst->IsOp()) continue;
679 
680     const string& src_device = src->assigned_device_name();
681     const string& dst_device = dst->assigned_device_name();
682     // Skip local edges.
683     if (src_device == dst_device) continue;
684 
685     const Node* src_frame = OutputFrame(src, cf_info);
686     const Node* dst_frame = InputFrame(dst, cf_info);
687     const string& src_frame_name = cf_info[src_frame->id()].frame_name;
688     const string& dst_frame_name = cf_info[dst_frame->id()].frame_name;
689     // Skip if src and dst are not in the same frame.
690     if (src_frame_name.empty() || src_frame_name != dst_frame_name) {
691       continue;
692     }
693 
694     // Add the control loop. Start by adding the control loop for the
695     // current frame if needed, and recursively adding the control loop
696     // for its outer frame when nested.
697     ControlLoop child_loop;
698     while (true) {
699       const string& curr_frame_name = cf_info[src_frame->id()].frame_name;
700       if (curr_frame_name.empty()) {
701         // We have reached the root frame.
702         if (child_loop.merge != nullptr) {
703           const string& node_name = opts.new_name(edge->dst()->name());
704           const string& device_name = edge->dst()->assigned_device_name();
705           Node* const_node =
706               AddControlConst(device_name, bopts.WithName(node_name));
707           if (!status.ok()) return status;
708           AddControlFlowInfo(const_node, src_frame, &cf_info);
709           g->AddEdge(const_node, 0, child_loop.enter, 0);
710         }
711         break;
712       }
713 
714       const string& cl_key = strings::StrCat(curr_frame_name, "$$", dst_device);
715       auto it = control_loops.find(cl_key);
716       if (it != control_loops.end()) {
717         if (child_loop.enter != nullptr) {
718           g->AddEdge(it->second.merge, 0, child_loop.enter, 0);
719         }
720         break;
721       }
722 
723       // Get the frame's LoopCond.
724       auto cond_it = frame_cond_map.find(curr_frame_name);
725       if (cond_it == frame_cond_map.end()) {
726         return errors::InvalidArgument(
727             "A cross-device loop must have a pivot predicate: ",
728             curr_frame_name);
729       }
730       Node* loop_cond = cond_it->second;
731 
732       // Add the control loop.
733       ControlLoop curr_loop;
734       status = AddControlLoop(opts, g, src_frame, edge, loop_cond, &cf_info,
735                               &curr_loop);
736       if (!status.ok()) return status;
737       control_loops[cl_key] = curr_loop;
738 
739       if (child_loop.enter != nullptr) {
740         // Connect the merge of the outer loop to the enter of the inner.
741         g->AddEdge(curr_loop.merge, 0, child_loop.enter, 0);
742       }
743       src_frame = cf_info[src_frame->id()].parent_frame;
744       child_loop = curr_loop;
745     }
746   }
747 
748   // For a cross-device edge, on the dst device, add a control edge
749   // from the merge node of the control loop to dst. If a send/recv is
750   // introduced for this edge in future partitioning, we delete this
751   // control edge and add a new control edge from the merge to the recv.
752   num_edge_ids = g->num_edge_ids();
753   for (int i = 0; i < num_edge_ids; ++i) {
754     const Edge* edge = g->FindEdgeId(i);
755     if (edge == nullptr) continue;
756 
757     const Node* src = edge->src();
758     Node* dst = edge->dst();
759     // Skip Sink/Source nodes.
760     if (!src->IsOp() || !dst->IsOp()) continue;
761 
762     const string& src_device = src->assigned_device_name();
763     const string& dst_device = dst->assigned_device_name();
764     if (src_device != dst_device) {
765       const Node* src_frame = OutputFrame(src, cf_info);
766       const Node* dst_frame = InputFrame(dst, cf_info);
767       const string& src_frame_name = cf_info[src_frame->id()].frame_name;
768       const string& dst_frame_name = cf_info[dst_frame->id()].frame_name;
769       if (!src_frame_name.empty() && src_frame_name == dst_frame_name) {
770         const string& cl_key =
771             strings::StrCat(dst_frame_name, "$$", dst_device);
772         ControlLoop loop = control_loops[cl_key];
773         DCHECK(loop.enter != nullptr);
774         // Note that we'll create multiple duplicate edges if dst has multiple
775         // cross-device inputs. This is expected by the logic in Partition(), so
776         // it can add control edges to the recv nodes once they're created.
777         g->AddControlEdge(loop.merge, dst, /*allow_duplicates=*/true);
778       }
779     }
780   }
781   return Status::OK();
782 }
783 
784 struct PriorityTopoSortNode {
PriorityTopoSortNodetensorflow::__anon303478ad0111::PriorityTopoSortNode785   PriorityTopoSortNode(const NodeDef* n, int64 st) : node(n), start_time(st) {}
786 
787   const NodeDef* node;
788   int64 start_time;
789 };
790 
791 struct PriorityTopoSortNodeGreater {
operator ()tensorflow::__anon303478ad0111::PriorityTopoSortNodeGreater792   bool operator()(const PriorityTopoSortNode& left,
793                   const PriorityTopoSortNode& right) {
794     return left.start_time > right.start_time;
795   }
796 };
797 
798 }  // namespace
799 
800 // Returns in <nodes> the nodes that should participate in epoch-based recv
801 // scheduling, along with their times; <nodes> is ordered by increasing
802 // start_time. Returns in <node_to_start_time_out> the timing for all nodes,
803 // even those not in <nodes>.
804 //
805 // Comparing to sorting on the node's start time only, this also processes the
806 // nodes in dependency order, and updates start times to ensure a node's
807 // start_time > the start time for all dependencies.
808 //
809 // Note that graph_partition_test.cc accesses this function for testing, even
810 // though it's not declared in the header.
TopologicalSortNodesWithTimePriority(const GraphDef * gdef,std::vector<std::pair<const NodeDef *,int64>> * nodes,std::unordered_map<const NodeDef *,int64> * node_to_start_time_out)811 Status TopologicalSortNodesWithTimePriority(
812     const GraphDef* gdef, std::vector<std::pair<const NodeDef*, int64>>* nodes,
813     std::unordered_map<const NodeDef*, int64>* node_to_start_time_out) {
814   // Queue of nodes to process; lowest start time is returned first.
815   std::priority_queue<PriorityTopoSortNode, std::vector<PriorityTopoSortNode>,
816                       PriorityTopoSortNodeGreater>
817       q;
818   std::unordered_map<const NodeDef*, int64> node_to_start_time;
819   auto enqueue = [&q, &node_to_start_time](const NodeDef* node) {
820     const int64 start_time = node_to_start_time[node];
821     q.emplace(node, start_time);
822   };
823 
824   // Build initial structures, initial contents of queue.
825   std::unordered_map<string, std::vector<const NodeDef*>> node_to_output_nodes;
826   std::unordered_map<const NodeDef*, int> inputs_needed;
827   for (int n = 0; n < gdef->node_size(); ++n) {
828     const NodeDef* ndef = &gdef->node(n);
829     for (int i = 0; i < ndef->input_size(); ++i) {
830       node_to_output_nodes[string(ParseTensorName(ndef->input(i)).first)]
831           .push_back(ndef);
832     }
833     int64 start_time;
834     TF_RETURN_IF_ERROR(GetNodeAttr(*ndef, "_start_time", &start_time));
835     node_to_start_time[ndef] = start_time;
836     inputs_needed[ndef] = ndef->input_size();
837     if (ndef->input_size() == 0) {
838       enqueue(ndef);
839     }
840   }
841 
842   // Determine which merge nodes are parts of loops; these
843   // need to happen in the traversal after all non-NextIteration inputs
844   // are run.
845   for (int n = 0; n < gdef->node_size(); ++n) {
846     const NodeDef* ndef = &gdef->node(n);
847     if (IsNextIteration(*ndef)) {
848       for (const NodeDef* n : node_to_output_nodes[ndef->name()]) {
849         if (IsMerge(*n)) {
850           // n is a merge that is part of a loop structure.
851           // It doesn't need to wait for this NextIteration loop
852           // when doing the traversal.
853           --inputs_needed[n];
854         }
855       }
856     }
857   }
858 
859   // Traverse.
860   std::vector<std::pair<const NodeDef*, int64>> start_times;
861   start_times.reserve(gdef->node_size());
862   while (!q.empty()) {
863     PriorityTopoSortNode cur = q.top();
864     q.pop();
865 
866     start_times.emplace_back(cur.node, cur.start_time);
867 
868     for (const NodeDef* n : node_to_output_nodes[cur.node->name()]) {
869       auto& output_start_time = node_to_start_time[n];
870       if (output_start_time <= cur.start_time) {
871         output_start_time = cur.start_time + 1;
872       }
873       if (--inputs_needed[n] == 0) {
874         enqueue(n);
875       }
876     }
877   }
878 
879   // Done.
880   nodes->swap(start_times);
881   node_to_start_time_out->swap(node_to_start_time);
882   return Status::OK();
883 }
884 
AddControlEdges(const PartitionOptions & opts,std::unordered_map<string,GraphDef> * partitions)885 Status AddControlEdges(const PartitionOptions& opts,
886                        std::unordered_map<string, GraphDef>* partitions) {
887   Status status;
888   // TODO(yuanbyu): Very naive for now. To be improved.
889   const int num_epochs = 100;
890   const int prefetch = 6;
891 
892   for (auto& part : *partitions) {
893     GraphDef* gdef = &part.second;
894     std::vector<std::pair<const NodeDef*, int64>> start_times;
895     std::unordered_map<const NodeDef*, int64> node_to_start_time;
896     status = TopologicalSortNodesWithTimePriority(gdef, &start_times,
897                                                   &node_to_start_time);
898     if (!status.ok()) {
899       return status;
900     }
901 
902     // Add a dummy node for every epoch, and add a control edge from the
903     // "last" node in the preceding epoch to the dummy node.
904     string device_name = gdef->node(0).device();
905     int64 makespan = start_times.back().second;
906     int64 resolution = (makespan / num_epochs) + 1;
907 
908     int i = 0;
909     int j = 0;
910     std::vector<NodeDef*> dummys;
911     while (i < num_epochs && static_cast<size_t>(j) < start_times.size()) {
912       if (i * resolution > start_times[j].second) {
913         j++;
914       } else {
915         NodeDef* dummy = AddControlTrigger(opts, gdef, device_name, i,
916                                            i * resolution, &status);
917         if (!status.ok()) {
918           return status;
919         }
920         dummys.push_back(dummy);
921         if (j > 0) {
922           string src_name = start_times[j - 1].first->name();
923           AddInput(dummy, src_name, Graph::kControlSlot);
924         }
925         i++;
926       }
927     }
928 
929     // Finally, add the control edges to recvs.
930     for (int n = 0; n < gdef->node_size(); ++n) {
931       NodeDef* ndef = gdef->mutable_node(n);
932       if (ndef->op() == "_Recv") {
933         const int64 start_time = node_to_start_time[ndef];
934         const int recv_epoch = start_time / resolution;
935         if (recv_epoch >= prefetch) {
936           NodeDef* dummy = dummys[recv_epoch - prefetch];
937           AddInput(ndef, dummy->name(), Graph::kControlSlot);
938         }
939       }
940     }
941   }
942   return Status::OK();
943 }
944 
945 // If 'ndef' is a Send or Recv, fills its attr send_device_incarnation
946 // if possible.
SetIncarnation(const PartitionOptions & opts,NodeDef * ndef)947 void SetIncarnation(const PartitionOptions& opts, NodeDef* ndef) {
948   StringPiece op(ndef->op());
949   if (op != "_Send" && op != "_Recv") {
950     // Not related to send/recv.
951     return;
952   }
953   const string& send_device = GetNodeAttrString(*ndef, "send_device");
954   if (send_device.empty()) {
955     // No known send_device. The runtime will detect it later.
956     return;
957   }
958   int64 incarnation = PartitionOptions::kIllegalIncarnation;
959   if (!TryGetNodeAttr(*ndef, "send_device_incarnation", &incarnation) ||
960       (incarnation == PartitionOptions::kIllegalIncarnation)) {
961     incarnation = opts.get_incarnation(send_device);
962     SetAttrValue(incarnation,
963                  &((*ndef->mutable_attr())["send_device_incarnation"]));
964   }
965 }
966 
967 // Sets attribute send_device_incarnation of all Send/Recv nodes in
968 // 'gdef', if possible.
SetIncarnation(const PartitionOptions & opts,GraphDef * gdef)969 void SetIncarnation(const PartitionOptions& opts, GraphDef* gdef) {
970   for (NodeDef& ndef : *gdef->mutable_node()) {
971     SetIncarnation(opts, &ndef);
972   }
973   for (FunctionDef& fdef : *gdef->mutable_library()->mutable_function()) {
974     for (NodeDef& ndef : *fdef.mutable_node_def()) {
975       SetIncarnation(opts, &ndef);
976     }
977   }
978 }
979 
// Splits graph 'g' into one GraphDef per location and stores the results in
// '*partitions', keyed by the string returned by opts.node_to_loc().
//
// Every op node of 'g' is copied into the GraphDef of its location. An input
// edge whose source is in another partition, or in the same partition when
// NeedSameDeviceSendRecv() returns true, is replaced by a send/recv pair
// created with AddSend()/AddRecv(). For a control edge, a dummy const node is
// added on the source side and its output is what gets sent. Send/recv pairs
// for non-ref edges are remembered in 'dup_recv' so that a later edge with the
// same key reuses the existing recv instead of creating a new pair.
//
// '*partitions' is cleared on entry. Unless opts.control_flow_added is set,
// 'g' is first modified by AddControlFlow(). After all nodes are processed,
// each partition gets the versions of 'g', the function definitions reachable
// from it, and send/recv incarnations.
//
// Returns InvalidArgument if a node has fewer data input edges than inputs,
// or the first non-OK status reported by one of the helpers called here.
Status Partition(const PartitionOptions& opts, Graph* g,
                 std::unordered_map<string, GraphDef>* partitions) {
  Status status;
  partitions->clear();

  GraphInfo g_info;
  if (!opts.control_flow_added) {
    // Add the "code" for distributed execution of control flow. Code is
    // added only for the frames that are placed on multiple devices. The
    // new graph is an equivalent transformation of the original graph and
    // has the property that it can be subsequently partitioned arbitrarily
    // (down to the level of individual device) for distributed execution.
    status = AddControlFlow(opts, g, &g_info);
    if (!status.ok()) return status;
  }

  // At this point, all the graph mutations have been done. Build memory
  // and device type info for every node and edge in the graph.
  status = BuildMemoryDeviceInfo(*g, &g_info);
  if (!status.ok()) return status;

  // Scratch variables that are reused for every destination node.
  string dstp;
  std::vector<const Edge*> inputs;
  // Recvs already created, so an identical transfer is not duplicated.
  DupRecvTable dup_recv(3);
  // For a node dst, 'ref_recvs' remembers the recvs introduced by a ref
  // edge to dst. 'ref_control_inputs' remembers the inputs by a non-ref
  // edge to dst. We will add a control edge for every pair in
  // (ref_recvs x ref_control_inputs).
  std::vector<NodeDef*> ref_recvs;
  std::vector<string> ref_control_inputs;

  // Number of send/recv pairs added for data and control edges; only used
  // for the VLOG at the end.
  int32 num_data = 0;
  int32 num_control = 0;
  for (const Node* dst : g->op_nodes()) {
    // Copy dst into the GraphDef of its partition (created on first use).
    dstp = opts.node_to_loc(dst);
    GraphDef* dst_graph = &(*partitions)[dstp];
    NodeDef* dst_def = dst_graph->add_node();
    *dst_def = dst->def();
    MergeDebugInfo(NodeDebugInfo(dst->def()), dst_def);
    dst_def->set_device(dst->assigned_device_name());
    dst_def->clear_input();  // Inputs are filled below
    if (opts.need_to_record_start_times) {
      // Keep an existing "_start_time" attr; otherwise record the one from
      // opts.start_times.
      int64 start_time;
      status = GetNodeAttr(*dst_def, "_start_time", &start_time);
      if (errors::IsNotFound(status)) {
        start_time = opts.start_times[dst->id()].value();
        AddNodeAttr("_start_time", start_time, dst_def);
      } else if (!status.ok()) {
        return status;
      }
    }

    // Arrange the incoming edges to dst so that input[i] holds the
    // input flowing into slot numbered i. Trailing entries in input[]
    // hold control edges.
    inputs.clear();
    inputs.resize(dst->num_inputs(), nullptr);
    ref_recvs.clear();
    ref_control_inputs.clear();
    // One of the control edges coming from a control-loop Merge, if any, and
    // how many such edges dst has.
    const Edge* control_flow_edge = nullptr;
    int32 num_control_flow_edges = 0;
    int32 num_input_edges = 0;
    for (const Edge* edge : dst->in_edges()) {
      if (edge->IsControlEdge()) {
        if (IsMerge(edge->src()) && IsControlLoop(edge->src())) {
          // This is one of the control edges added for control flow. There
          // can be multiple such edges as the dest node may have multiple
          // remote inputs. We keep track of the number of such edges.
          control_flow_edge = edge;
          ++num_control_flow_edges;
        } else {
          inputs.push_back(edge);
        }
      } else {
        DCHECK(inputs[edge->dst_input()] == nullptr);
        inputs[edge->dst_input()] = edge;
        ++num_input_edges;
      }
    }

    // Every data input slot must have been filled; this also guarantees that
    // 'inputs' holds no null entry in the loop below.
    if (num_input_edges != dst->num_inputs()) {
      return errors::InvalidArgument("Incomplete graph, missing ",
                                     (dst->num_inputs() - num_input_edges),
                                     " inputs for ", dst->name());
    }

    // Process in order so that all data edges are added as inputs to
    // dst in Edge::dst_input() order.
    for (const Edge* edge : inputs) {
      const Node* src = edge->src();
      if (!src->IsOp()) continue;  // Skip Sink/Source nodes.

      GraphDef* src_graph = &(*partitions)[opts.node_to_loc(src)];
      if (src_graph == dst_graph && !NeedSameDeviceSendRecv(edge, g_info)) {
        // Same partition and compatible memory types:
        AddInput(dst_def, src->name(), edge->src_output());
        if (edge->IsControlEdge() ||
            !IsRefType(src->output_type(edge->src_output()))) {
          ref_control_inputs.push_back(src->name());
        }
        continue;
      }

      // From here on the edge needs a send/recv pair. Look up the start
      // times of both ends when recv scheduling is requested; they stay 0
      // otherwise.
      int64 send_start_time = 0;
      int64 recv_start_time = 0;
      if (opts.scheduling_for_recvs) {
        status = GetNodeAttr(src->attrs(), "_start_time", &send_start_time);
        if (errors::IsNotFound(status) && opts.need_to_record_start_times) {
          send_start_time = opts.start_times[src->id()].value();
        } else if (!status.ok()) {
          return status;
        }

        status = GetNodeAttr(dst->attrs(), "_start_time", &recv_start_time);
        if (errors::IsNotFound(status) && opts.need_to_record_start_times) {
          recv_start_time = opts.start_times[dst->id()].value();
        } else if (!status.ok()) {
          return status;
        }
      }

      // Check whether there is already a send/recv pair transferring
      // the same tensor/control from the src to dst partition.
      const bool on_host = IsDstInputOnHost(edge, g_info);
      DupRecvKey key{src->id(), edge->src_output(), dst_graph, on_host};
      auto iter = dup_recv.find(key);
      if (iter != dup_recv.end()) {
        // We found one. Reuse the data/control transferred already.
        const string& recv_node_name = iter->second.recv->name();
        if (edge->IsControlEdge()) {
          AddInput(dst_def, recv_node_name, Graph::kControlSlot);
        } else {
          AddInput(dst_def, recv_node_name, 0);
        }
        ref_control_inputs.push_back(recv_node_name);

        // We want the start_time for the recv to be the smallest of the start
        // times of its consumers. So we update this whenever we use a recv,
        // and write it out to the attribute at the end of the subroutine.
        if (iter->second.start_time > recv_start_time) {
          iter->second.start_time = recv_start_time;
        }
        continue;
      }

      // Decide which output the send node reads from.
      NodeDefBuilder::NodeOut send_from;
      if (edge->IsControlEdge()) {
        // Insert a dummy const node that will generate a tiny
        // data element to be sent from send to recv.
        VLOG(1) << "Send/Recv control: " << src->assigned_device_name() << "["
                << src->name() << "] -> " << dst->assigned_device_name() << "["
                << dst->name() << "]";
        NodeDef* dummy = AddDummyConst(opts, src_graph, edge, &status);
        if (!status.ok()) return status;
        // Set the start time for this dummy node.
        if (opts.scheduling_for_recvs) {
          AddNodeAttr("_start_time", send_start_time, dummy);
        }
        AddInput(dummy, src->name(), Graph::kControlSlot);
        send_from.Reset(dummy->name(), 0, DT_FLOAT);
      } else {
        send_from.Reset(src->name(), edge->src_output(), EdgeType(edge));
      }

      // Need to split edge by placing matching send/recv nodes on
      // the src/dst sides of the edge.
      NodeDef* send = AddSend(opts, g_info, src_graph, edge, send_from,
                              send_start_time, &status);
      if (!status.ok()) return status;

      // 'recv' is the node dst consumes from; 'real_recv' is set by AddRecv
      // and may be a different node (see the comparisons below).
      NodeDef* real_recv = nullptr;
      NodeDef* recv =
          AddRecv(opts, g_info, dst_graph, edge, &real_recv, &status);
      if (!status.ok()) return status;

      // Fix up the control flow edge.
      // NOTE(yuanbyu): 'real_recv' must be the real recv node.
      if (src_graph == dst_graph) {
        // For same device send/recv, add a control edge from send to recv.
        // This prevents the asynchronous recv kernel from being scheduled
        // before the data is available.
        AddInput(real_recv, send->name(), Graph::kControlSlot);
      } else if (control_flow_edge != nullptr) {
        // Redirect control edge to the real recv since this is not the same
        // device send/recv.
        --num_control_flow_edges;
        AddInput(real_recv, control_flow_edge->src()->name(),
                 Graph::kControlSlot);
      }

      if (!edge->IsControlEdge() &&
          IsRefType(src->output_type(edge->src_output()))) {
        // Ref edges are not stored in 'dup_recv', so their start time is
        // written here rather than in the loop at the end of the function.
        AddNodeAttr("_start_time", recv_start_time, recv);
        if (real_recv != recv) {
          AddNodeAttr("_start_time", recv_start_time, real_recv);
        }
        // If src is of ref type and the edge is not a control edge, dst has
        // read semantics and therefore we must control the recv.
        ref_recvs.push_back(real_recv);
      } else {
        // Memorize the send/recv pair, only if this is not a "ref" edge.
        // NOTE(yuanbyu): Collapsing ref edges requires extreme care so
        // for now we don't do it.
        dup_recv[key] = {recv, real_recv, recv_start_time};
        ref_control_inputs.push_back(recv->name());
      }

      // Finally wire the recv into dst.
      if (edge->IsControlEdge()) {
        ++num_control;
        AddInput(dst_def, recv->name(), Graph::kControlSlot);
      } else {
        ++num_data;
        AddInput(dst_def, recv->name(), 0);
      }
    }

    // Add control edges from 'ref_control_inputs' to 'ref_recvs'.
    // NOTE(yuanbyu): Adding these control edges should not introduce
    // deadlocks. 'dst' has implicit "read" nodes that, when we split
    // across devices, are made explicit; Retargeting the dependencies
    // to 'dst' to those nodes would not introduce cycles if there isn't
    // one before the transformation.
    // NOTE(yuanbyu): This may impact performance because it defers the
    // execution of recvs until all the other inputs become available.
    AddReadControl(ref_recvs, ref_control_inputs);

    // Add back the control edges for control flow that are not used.
    if (control_flow_edge != nullptr) {
      for (int i = 0; i < num_control_flow_edges; ++i) {
        AddInput(dst_def, control_flow_edge->src()->name(),
                 Graph::kControlSlot);
      }
    }
  }

  // Use the function library from the options when one is given, otherwise
  // the one that belongs to the graph.
  const FunctionLibraryDefinition* flib_def = opts.flib_def;
  if (flib_def == nullptr) {
    flib_def = &g->flib_def();
  }

  // Set versions, function library and send/recv incarnation.
  for (auto& it : *partitions) {
    GraphDef* gdef = &it.second;
    *gdef->mutable_versions() = g->versions();
    // Prune unreachable functions from `flib_def` before adding them to `gdef`.
    *gdef->mutable_library() = flib_def->ReachableDefinitions(*gdef).ToProto();

    // Traverse the graph to fill every send/recv op's incarnation
    // information.
    SetIncarnation(opts, gdef);
  }

  // Set the start times for recvs at the very end.
  if (opts.scheduling_for_recvs) {
    for (auto& it : dup_recv) {
      AddNodeAttr("_start_time", it.second.start_time, it.second.recv);
      if (it.second.real_recv != it.second.recv) {
        AddNodeAttr("_start_time", it.second.start_time, it.second.real_recv);
      }
    }
  }

  VLOG(1) << "Added send/recv: controls=" << num_control
          << ", data=" << num_data;
  // At verbosity level 2, dump every partition to a file for debugging.
  if (VLOG_IS_ON(2)) {
    for (auto& it : *partitions) {
      GraphDef* gdef = &it.second;
      DumpGraphDefToFile(strings::StrCat("partition_", it.first, "_",
                                         reinterpret_cast<uintptr_t>(gdef)),
                         *gdef);
    }
  }
  return Status::OK();
}
1254 
1255 }  // namespace tensorflow
1256