1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/graph/graph_partition.h"
17
18 #include <deque>
19 #include <queue>
20 #include <unordered_map>
21 #include <unordered_set>
22 #include <utility>
23 #include <vector>
24
25 #include "absl/container/flat_hash_map.h"
26 #include "tensorflow/core/framework/function.h"
27 #include "tensorflow/core/framework/memory_types.h"
28 #include "tensorflow/core/framework/node_def_builder.h"
29 #include "tensorflow/core/framework/tensor.pb.h"
30 #include "tensorflow/core/framework/types.h"
31 #include "tensorflow/core/framework/versions.pb.h"
32 #include "tensorflow/core/graph/algorithm.h"
33 #include "tensorflow/core/graph/control_flow.h"
34 #include "tensorflow/core/graph/costmodel.h"
35 #include "tensorflow/core/graph/graph_def_builder.h"
36 #include "tensorflow/core/graph/node_builder.h"
37 #include "tensorflow/core/graph/tensor_id.h"
38 #include "tensorflow/core/lib/core/errors.h"
39 #include "tensorflow/core/lib/hash/hash.h"
40 #include "tensorflow/core/lib/strings/str_util.h"
41 #include "tensorflow/core/platform/logging.h"
42 #include "tensorflow/core/util/device_name_utils.h"
43 #include "tensorflow/core/util/dump_graph.h"
44
45 namespace tensorflow {
46
47 namespace {
48
IsMerge(const NodeDef & node_def)49 inline bool IsMerge(const NodeDef& node_def) {
50 return node_def.op() == "Merge" || node_def.op() == "RefMerge" ||
51 node_def.op() == "_XlaMerge";
52 }
53
IsNextIteration(const NodeDef & node_def)54 inline bool IsNextIteration(const NodeDef& node_def) {
55 return node_def.op() == "NextIteration" ||
56 node_def.op() == "RefNextIteration";
57 }
58
59 struct DupRecvKey {
60 int src_node_id; // Edge's src node id
61 int src_output_slot; // Edge's src node output slot
62 GraphDef* dst_graph; // Edge's dst node is in this subgraph
63 bool recv_output_on_host; // The output of recv is on host
64
65 template <typename H>
AbslHashValue(H h,const DupRecvKey & c)66 friend H AbslHashValue(H h, const DupRecvKey& c) {
67 return H::combine(std::move(h), c.src_node_id, c.src_output_slot,
68 reinterpret_cast<std::uintptr_t>(c.dst_graph),
69 c.recv_output_on_host);
70 }
71
operator ==(const DupRecvKey & x,const DupRecvKey & y)72 friend bool operator==(const DupRecvKey& x, const DupRecvKey& y) {
73 return (x.src_node_id == y.src_node_id) &&
74 (x.src_output_slot == y.src_output_slot) &&
75 (x.dst_graph == y.dst_graph) &&
76 (x.recv_output_on_host == y.recv_output_on_host);
77 }
78 };
79
80 // struct used to store the recvs, so that start times can be properly updated
81 struct RecvInfo {
82 NodeDef* recv;
83 NodeDef* real_recv;
84 int64 start_time;
85 };
86
87 typedef absl::flat_hash_map<DupRecvKey, RecvInfo> DupRecvTable;
88
// Key type for the maps that store memory types for the inputs/outputs of
// every node: a node id paired with an input/output index.
// TODO(power): migrate back to std::pair when absl::Hash is fixed for MSVC.
struct NodePort {
  int node_id;
  int index;

  // Two ports are equal iff both the node id and the index match.
  friend bool operator==(const NodePort& x, const NodePort& y) {
    if (x.node_id != y.node_id) return false;
    return x.index == y.index;
  }

  // absl hashing hook: combines the node id and the index.
  template <typename H>
  friend H AbslHashValue(H h, const NodePort& c) {
    return H::combine(std::move(h), c.node_id, c.index);
  }
};
105
106 typedef absl::flat_hash_map<NodePort, MemoryType> MemoryTypeMap;
107
// We collect the following information about the graph before performing
// graph partitioning.
struct GraphInfo {
  // Device type of every node, indexed by node id.
  std::vector<DeviceType> device_types;
  // Memory type of every node input, keyed by {node id, input index}.
  MemoryTypeMap input_types;
  // Memory type of every node output, keyed by {node id, output index}.
  MemoryTypeMap output_types;
  // Control flow info (frame, parent frame, frame name) of every node, indexed
  // by node id.
  std::vector<ControlFlowInfo> cf_info;
};
116
EdgeType(const Edge * e)117 DataType EdgeType(const Edge* e) {
118 if (e->IsControlEdge()) {
119 return DT_FLOAT;
120 } else {
121 return e->dst()->input_type(e->dst_input());
122 }
123 }
124
125 // Return true iff we need to add the same device send/recv for 'edge'.
NeedSameDeviceSendRecv(const Edge * edge,const GraphInfo & info)126 bool NeedSameDeviceSendRecv(const Edge* edge, const GraphInfo& info) {
127 if (edge->IsControlEdge()) {
128 return false;
129 }
130
131 const Node* src = edge->src();
132 const Node* dst = edge->dst();
133 if (src->assigned_device_name() == dst->assigned_device_name()) {
134 int src_port = edge->src_output();
135 int dst_port = edge->dst_input();
136 if (info.device_types[src->id()] != DEVICE_CPU) {
137 auto src_it = info.output_types.find({src->id(), src_port});
138 DCHECK(src_it != info.output_types.end());
139 auto dst_it = info.input_types.find({dst->id(), dst_port});
140 DCHECK(dst_it != info.input_types.end());
141 return src_it->second != dst_it->second;
142 }
143 }
144 return false;
145 }
146
147 // Return true iff (dst, dst_input) is specified on host memory.
IsDstInputOnHost(const Edge * edge,const GraphInfo & info)148 bool IsDstInputOnHost(const Edge* edge, const GraphInfo& info) {
149 const Node* dst = edge->dst();
150 int dst_port = edge->dst_input();
151 if (info.device_types[dst->id()] != DEVICE_CPU) {
152 if (edge->IsControlEdge()) return false;
153 auto dst_it = info.input_types.find({dst->id(), dst_port});
154 DCHECK(dst_it != info.input_types.end());
155 return dst_it->second == HOST_MEMORY;
156 }
157 return true;
158 }
159
160 // Add an input to dst that comes from the "src_slot" output of the
161 // node named by "src_name".
AddInput(NodeDef * dst,StringPiece src_name,int src_slot)162 void AddInput(NodeDef* dst, StringPiece src_name, int src_slot) {
163 if (src_slot == Graph::kControlSlot) {
164 dst->add_input(strings::StrCat("^", src_name));
165 } else if (src_slot == 0) {
166 dst->add_input(src_name.data(), src_name.size());
167 } else {
168 dst->add_input(strings::StrCat(src_name, ":", src_slot));
169 }
170 }
171
172 // Add a control edge from each input to each recv.
AddReadControl(const std::vector<NodeDef * > & recvs,const std::vector<string> & inputs)173 void AddReadControl(const std::vector<NodeDef*>& recvs,
174 const std::vector<string>& inputs) {
175 for (NodeDef* recv : recvs) {
176 for (const string& input : inputs) {
177 recv->add_input(strings::StrCat("^", input));
178 }
179 }
180 }
181
SetSendRecvAttrs(const PartitionOptions & opts,const Edge * edge,NodeDefBuilder * builder)182 void SetSendRecvAttrs(const PartitionOptions& opts, const Edge* edge,
183 NodeDefBuilder* builder) {
184 builder->Attr("tensor_name",
185 strings::StrCat("edge_", edge->id(), "_", edge->src()->name()));
186 builder->Attr("send_device", edge->src()->assigned_device_name());
187 builder->Attr("send_device_incarnation",
188 static_cast<int64>(
189 opts.get_incarnation(edge->src()->assigned_device_name())));
190 builder->Attr("recv_device", edge->dst()->assigned_device_name());
191 builder->Attr("client_terminated", false);
192 builder->Attr("_src", edge->src()->name());
193 builder->Attr("_dst", edge->dst()->name());
194 }
195
AddSend(const PartitionOptions & opts,const GraphInfo & g_info,GraphDef * gdef,const Edge * edge,NodeDefBuilder::NodeOut send_from,int64 start_time,Status * status)196 NodeDef* AddSend(const PartitionOptions& opts, const GraphInfo& g_info,
197 GraphDef* gdef, const Edge* edge,
198 NodeDefBuilder::NodeOut send_from, int64 start_time,
199 Status* status) {
200 const DataType dtype = send_from.data_type;
201 const DataType cast_dtype = opts.should_cast ? opts.should_cast(edge) : dtype;
202 const Node* src = edge->src();
203 const int src_port = edge->src_output();
204
205 // host_memory = true iff we need to use HostSend/HostCast.
206 bool host_memory = false;
207 if (!edge->IsControlEdge()) {
208 auto src_it = g_info.output_types.find({src->id(), src_port});
209 DCHECK(src_it != g_info.output_types.end());
210 host_memory = (src_it->second == HOST_MEMORY);
211 }
212
213 // Add a cast node that casts dtype to cast_dtype.
214 // NOTE(yuanbyu): Only cast for cross-device send/recv.
215 if (dtype != cast_dtype && !NeedSameDeviceSendRecv(edge, g_info)) {
216 const string cast_op = (host_memory) ? "_HostCast" : "Cast";
217 NodeDefBuilder cast_builder(opts.new_name(src->name()), cast_op,
218 NodeDebugInfo(*src));
219 cast_builder.Device(src->assigned_device_name()).Input(send_from);
220 if (opts.scheduling_for_recvs) {
221 cast_builder.Attr("_start_time", start_time);
222 }
223 cast_builder.Attr("DstT", cast_dtype);
224
225 if (cast_dtype == DT_BFLOAT16) {
226 // the below attribute specifies that the cast to bfloat16 should use
227 // truncation. This is needed to retain legacy behavior when we change
228 // the default bfloat16 casts to use rounding instead of truncation
229 cast_builder.Attr("Truncate", true);
230 }
231
232 NodeDef* cast = gdef->add_node();
233 *status = cast_builder.Finalize(cast, /*consume=*/true);
234 if (!status->ok()) return nullptr;
235
236 // Connect the Send op to the cast.
237 send_from.Reset(cast->name(), 0, cast_dtype);
238 }
239
240 // Add the send node.
241 const string send_op = (host_memory) ? "_HostSend" : "_Send";
242 NodeDefBuilder send_builder(opts.new_name(src->name()), send_op,
243 NodeDebugInfo(*src));
244 SetSendRecvAttrs(opts, edge, &send_builder);
245 send_builder.Device(src->assigned_device_name()).Input(send_from);
246 if (opts.scheduling_for_recvs) {
247 send_builder.Attr("_start_time", start_time);
248 }
249 NodeDef* send = gdef->add_node();
250 *status = send_builder.Finalize(send, /*consume=*/true);
251 return send;
252 }
253
AddRecv(const PartitionOptions & opts,const GraphInfo & g_info,GraphDef * gdef,const Edge * edge,NodeDef ** real_recv,Status * status)254 NodeDef* AddRecv(const PartitionOptions& opts, const GraphInfo& g_info,
255 GraphDef* gdef, const Edge* edge, NodeDef** real_recv,
256 Status* status) {
257 const DataType dtype = EdgeType(edge);
258 const Node* src = edge->src();
259 const Node* dst = edge->dst();
260 const int dst_port = edge->dst_input();
261 DataType cast_dtype = dtype;
262
263 // NOTE(yuanbyu): Only cast for cross-device send/recv.
264 if (opts.should_cast && !NeedSameDeviceSendRecv(edge, g_info)) {
265 cast_dtype = opts.should_cast(edge);
266 }
267
268 // host_memory = true iff we need to use HostRecv/HostCast.
269 // Also log the introduction of the send-recv pair, for performance debugging.
270 bool host_memory = false;
271 if (!edge->IsControlEdge()) {
272 auto dst_it = g_info.input_types.find({dst->id(), dst_port});
273 DCHECK(dst_it != g_info.input_types.end());
274 host_memory = (dst_it->second == HOST_MEMORY);
275 bool src_host_memory = false;
276 if (VLOG_IS_ON(1)) {
277 const int src_port = edge->src_output();
278 auto src_it = g_info.output_types.find({src->id(), src_port});
279 DCHECK(src_it != g_info.output_types.end());
280 src_host_memory = (src_it->second == HOST_MEMORY);
281 }
282 VLOG(1) << "Receiving data"
283 << " from " << src->name() << " (" << src->type_string() << ")"
284 << " on " << src->assigned_device_name() << " in "
285 << (src_host_memory ? "host memory" : "device memory") << " for "
286 << dst->name() << " (" << dst->type_string() << ")"
287 << " on " << dst->assigned_device_name() << " in "
288 << (host_memory ? "host memory" : "device memory");
289 } else {
290 // Log control-edge transfers too, but don't mention memory space since it's
291 // irrelevant.
292 VLOG(1) << "Receiving control"
293 << " from " << src->name() << " (" << src->type_string() << ")"
294 << " on " << src->assigned_device_name() << " for " << dst->name()
295 << " (" << dst->type_string() << ")"
296 << " on " << dst->assigned_device_name();
297 }
298
299 // Add the recv node.
300 const string recv_op = (host_memory) ? "_HostRecv" : "_Recv";
301 NodeDefBuilder recv_builder(opts.new_name(src->name()), recv_op,
302 NodeDebugInfo(*src));
303 SetSendRecvAttrs(opts, edge, &recv_builder);
304 recv_builder.Device(dst->assigned_device_name())
305 .Attr("tensor_type", cast_dtype);
306 NodeDef* recv = gdef->add_node();
307 *status = recv_builder.Finalize(recv, /*consume=*/true);
308 if (!status->ok()) return nullptr;
309 *real_recv = recv;
310
311 // Add the cast node (from cast_dtype to dtype) or an Identity node.
312 if (dtype != cast_dtype) {
313 const string cast_op = (host_memory) ? "_HostCast" : "Cast";
314 NodeDefBuilder cast_builder(opts.new_name(src->name()), cast_op,
315 NodeDebugInfo(*src));
316 cast_builder.Attr("DstT", dtype);
317 cast_builder.Device(dst->assigned_device_name())
318 .Input(recv->name(), 0, cast_dtype);
319 NodeDef* cast = gdef->add_node();
320 *status = cast_builder.Finalize(cast, /*consume=*/true);
321 if (!status->ok()) return nullptr;
322 return cast;
323 } else if (edge->IsControlEdge()) {
324 // An Identity is only needed for control edges.
325 NodeDefBuilder id_builder(opts.new_name(src->name()), "Identity",
326 NodeDebugInfo(*src));
327 id_builder.Device(dst->assigned_device_name())
328 .Input(recv->name(), 0, cast_dtype);
329 NodeDef* id = gdef->add_node();
330 *status = id_builder.Finalize(id, /*consume=*/true);
331 if (!status->ok()) return nullptr;
332 return id;
333 } else {
334 return recv;
335 }
336 }
337
AddDummyConst(const PartitionOptions & opts,GraphDef * gdef,const Edge * edge,Status * status)338 NodeDef* AddDummyConst(const PartitionOptions& opts, GraphDef* gdef,
339 const Edge* edge, Status* status) {
340 const Node* src = edge->src();
341 Tensor tensor(DT_FLOAT, TensorShape({0}));
342 NodeDef* result = gdef->add_node();
343 *status = NodeDefBuilder(opts.new_name(src->name()), "Const")
344 .Device(src->assigned_device_name())
345 .Attr("dtype", DT_FLOAT)
346 .Attr("value", tensor)
347 .Finalize(result, /*consume=*/true);
348 return result;
349 }
350
351 // A dummy node for scheduling.
AddControlTrigger(const PartitionOptions & opts,GraphDef * gdef,const string & assigned_device_name,int64 epoch,int64 starttime,Status * status)352 NodeDef* AddControlTrigger(const PartitionOptions& opts, GraphDef* gdef,
353 const string& assigned_device_name, int64 epoch,
354 int64 starttime, Status* status) {
355 NodeDef* result = gdef->add_node();
356 *status = NodeDefBuilder(opts.new_name(strings::StrCat("synch_", epoch)),
357 "ControlTrigger")
358 .Device(assigned_device_name)
359 .Attr("_start_time", starttime)
360 .Finalize(result, /*consume=*/true);
361 return result;
362 }
363
364 // Optimize colocation for control flow nodes. For cond, we want the
365 // switch nodes to colocate with its data input. This is particularly
366 // needed for conditional reading of a remote variable. It may also
367 // reduce the number of devices involved in a loop.
368 // TODO(yuanbyu): In this case, we don't respect the requested device in
369 // the GraphDef for these nodes. Ideally, the placer would enforce the
370 // colocation to render this unnecessary.
OptimizeControlFlowColocation(Graph * graph)371 void OptimizeControlFlowColocation(Graph* graph) {
372 auto visit = [](Node* node) {
373 if (IsSwitch(node)) {
374 for (const Edge* in_edge : node->in_edges()) {
375 if (in_edge->dst_input() == 0) {
376 // Colocate with the data input.
377 node->set_assigned_device_name(
378 in_edge->src()->assigned_device_name());
379 return;
380 }
381 }
382 } else if (IsExit(node)) {
383 for (const Edge* in_edge : node->in_edges()) {
384 if (!in_edge->IsControlEdge()) {
385 // Colocate with upstream node.
386 node->set_assigned_device_name(
387 in_edge->src()->assigned_device_name());
388 return;
389 }
390 }
391 } else {
392 if ((IsEnter(node) && !IsRefType(node->input_type(0))) ||
393 IsNextIteration(node)) {
394 const Edge* data_edge = nullptr;
395 for (const Edge* out_edge : node->out_edges()) {
396 if (!out_edge->IsControlEdge()) {
397 data_edge = out_edge;
398 break;
399 }
400 }
401 // Colocate with the first downstream data node.
402 if (data_edge) {
403 node->set_assigned_device_name(
404 data_edge->dst()->assigned_device_name());
405 }
406 }
407 }
408 };
409 DFS(*graph, visit, {});
410 }
411
// Returns `name` prefixed with "_cloop", the marker IsControlLoop() looks for.
string ControlLoopName(const string& name) {
  string loop_node_name = strings::StrCat("_cloop", name);
  return loop_node_name;
}
415
IsControlLoop(const Node * node)416 bool IsControlLoop(const Node* node) {
417 const string& name = node->name();
418 return absl::StartsWith(name, "_cloop");
419 }
420
421 // An enter node for control flow.
AddControlEnter(Graph * g,const string & node_name,const string & device_name,const string & frame_name,const int parallel_iterations,Status * status)422 Node* AddControlEnter(Graph* g, const string& node_name,
423 const string& device_name, const string& frame_name,
424 const int parallel_iterations, Status* status) {
425 NodeBuilder node_builder(node_name, "Enter", g->op_registry());
426 node_builder.Input({"dummy", 0, DT_FLOAT});
427 node_builder.Attr("frame_name", frame_name);
428 node_builder.Attr("parallel_iterations", parallel_iterations);
429 Node* res_node;
430 *status = node_builder.Finalize(g, &res_node, /*consume=*/true);
431 if (!status->ok()) return nullptr;
432 res_node->set_assigned_device_name(device_name);
433 return res_node;
434 }
435
436 // A merge node for control flow.
AddControlMerge(const string & in_name1,const string & in_name2,Graph * g,const string & node_name,const string & device_name,Status * status)437 Node* AddControlMerge(const string& in_name1, const string& in_name2, Graph* g,
438 const string& node_name, const string& device_name,
439 Status* status) {
440 NodeBuilder node_builder(node_name, "Merge", g->op_registry());
441 node_builder.Input({{in_name1, 0, DT_FLOAT}, {in_name2, 0, DT_FLOAT}});
442 Node* res_node;
443 *status = node_builder.Finalize(g, &res_node, /*consume=*/true);
444 if (!status->ok()) return nullptr;
445 res_node->set_assigned_device_name(device_name);
446 return res_node;
447 }
448
449 // A switch node for control flow.
AddControlSwitch(NodeBuilder::NodeOut input1,NodeBuilder::NodeOut input2,const string & device_name,const GraphDefBuilder::Options & bopts)450 Node* AddControlSwitch(NodeBuilder::NodeOut input1, NodeBuilder::NodeOut input2,
451 const string& device_name,
452 const GraphDefBuilder::Options& bopts) {
453 Node* res_node =
454 ops::BinaryOp("Switch", std::move(input1), std::move(input2), bopts);
455 if (bopts.HaveError()) return nullptr;
456 res_node->set_assigned_device_name(device_name);
457 return res_node;
458 }
459
460 // A next_iteration node for control flow.
AddControlNext(NodeBuilder::NodeOut input,const string & device_name,const GraphDefBuilder::Options & bopts)461 Node* AddControlNext(NodeBuilder::NodeOut input, const string& device_name,
462 const GraphDefBuilder::Options& bopts) {
463 Node* res_node = ops::UnaryOp("NextIteration", std::move(input), bopts);
464 if (bopts.HaveError()) return nullptr;
465 res_node->set_assigned_device_name(device_name);
466 return res_node;
467 }
468
EmptyConst(const GraphDefBuilder::Options & options)469 Node* EmptyConst(const GraphDefBuilder::Options& options) {
470 if (options.HaveError()) return nullptr;
471 NodeBuilder node_builder(options.GetNameForOp("Const"), "Const",
472 options.op_registry());
473 const DataType dt = DataTypeToEnum<float>::v();
474 TensorProto proto;
475 proto.set_dtype(dt);
476 TensorShape empty_shape({0});
477 empty_shape.AsProto(proto.mutable_tensor_shape());
478 node_builder.Attr("dtype", dt).Attr("value", proto);
479 return options.FinalizeBuilder(&node_builder);
480 }
481
482 // A dummy const node for control flow.
AddControlConst(const string & device_name,const GraphDefBuilder::Options & bopts)483 Node* AddControlConst(const string& device_name,
484 const GraphDefBuilder::Options& bopts) {
485 Node* res_node = EmptyConst(bopts);
486 if (bopts.HaveError()) return nullptr;
487 res_node->set_assigned_device_name(device_name);
488 return res_node;
489 }
490
// A synthetic loop, made up of dummy nodes. It performs control-flow actions
// on behalf of a leader on a different device.
//
// All members default to nullptr; AddControlLoop() fills them in with the
// nodes it creates.
struct ControlLoop {
  // The loop's Enter node.
  Node* enter = nullptr;
  // The loop's Merge node (fed by the Enter and NextIteration nodes).
  Node* merge = nullptr;
  // The loop's Switch node (fed by the Merge node and the frame's LoopCond).
  Node* switch_node = nullptr;
};
498
499 // Add the control flow info of a new node added during partitioning.
500 // The new node has the same control flow info as src.
AddControlFlowInfo(const Node * node,const Node * src,std::vector<ControlFlowInfo> * cf_info)501 void AddControlFlowInfo(const Node* node, const Node* src,
502 std::vector<ControlFlowInfo>* cf_info) {
503 int id = node->id();
504 if (static_cast<size_t>(id) >= cf_info->size()) {
505 cf_info->resize(id + 1);
506 }
507 const ControlFlowInfo& src_info = (*cf_info)[src->id()];
508 ControlFlowInfo* info = &(*cf_info)[id];
509 info->frame = src_info.frame;
510 info->parent_frame = src_info.parent_frame;
511 info->frame_name = src_info.frame_name;
512 }
513
514 // Constructs a control loop. Returns a struct containing the newly created
515 // enter, merge, and switch nodes. The enter and merge nodes are used in the
516 // recursive construction of control loops for nested frames (loops). The
517 // switch node will be connected to the LoopCond node. The merge node will
518 // be connected to all the recvs of the same frame by control edges when
519 // the actual partitioning happens.
AddControlLoop(const PartitionOptions & opts,Graph * g,const Node * src,const Edge * edge,Node * loop_cond,std::vector<ControlFlowInfo> * cf_info,ControlLoop * loop)520 Status AddControlLoop(const PartitionOptions& opts, Graph* g, const Node* src,
521 const Edge* edge, Node* loop_cond,
522 std::vector<ControlFlowInfo>* cf_info,
523 ControlLoop* loop) {
524 Status status;
525 GraphDefBuilder::Options bopts(g, &status);
526 const ControlFlowInfo& src_info = (*cf_info)[src->id()];
527 const string& device_name = edge->dst()->assigned_device_name();
528 const string& frame_name = src_info.frame_name;
529 int parallel_iterations;
530 status = GetNodeAttr(src_info.frame->attrs(), "parallel_iterations",
531 ¶llel_iterations);
532 if (!status.ok()) return status;
533
534 // The names of the nodes to be added.
535 const string& enter_name =
536 ControlLoopName(opts.new_name(edge->dst()->name()));
537 const string& merge_name =
538 ControlLoopName(opts.new_name(edge->dst()->name()));
539 const string& switch_name =
540 ControlLoopName(opts.new_name(edge->dst()->name()));
541 const string& next_name = ControlLoopName(opts.new_name(edge->dst()->name()));
542
543 // Add the nodes to the graph g.
544 Node* enter = AddControlEnter(g, enter_name, device_name, frame_name,
545 parallel_iterations, &status);
546 if (!status.ok()) return status;
547 Node* merge = AddControlMerge(enter_name, next_name, g, merge_name,
548 device_name, &status);
549 if (!status.ok()) return status;
550 Node* switch_node = AddControlSwitch(merge, loop_cond, device_name,
551 bopts.WithName(switch_name));
552 if (!status.ok()) return status;
553 Node* next =
554 AddControlNext({switch_node, 1}, device_name, bopts.WithName(next_name));
555 if (!status.ok()) return status;
556
557 // Add control flow info for these new nodes:
558 AddControlFlowInfo(enter, src, cf_info);
559 AddControlFlowInfo(merge, src, cf_info);
560 AddControlFlowInfo(switch_node, src, cf_info);
561 AddControlFlowInfo(next, src, cf_info);
562
563 // Add input edges for the newly created merge node:
564 g->AddEdge(enter, 0, merge, 0);
565 g->AddEdge(next, 0, merge, 1);
566
567 loop->enter = enter;
568 loop->merge = merge;
569 loop->switch_node = switch_node;
570 return Status::OK();
571 }
572
573 // Build memory and device type info for every node in the graph.
574 // TODO(yuanbyu): It might be simpler if we convert MemoryType to
575 // DeviceType for the inputs/outputs of each node.
BuildMemoryDeviceInfo(const Graph & g,GraphInfo * info)576 Status BuildMemoryDeviceInfo(const Graph& g, GraphInfo* info) {
577 MemoryTypeVector input_memory_types;
578 MemoryTypeVector output_memory_types;
579
580 info->device_types.resize(g.num_node_ids(), DEVICE_CPU);
581 for (const Node* node : g.op_nodes()) {
582 DeviceNameUtils::ParsedName parsed;
583 if (!DeviceNameUtils::ParseFullName(node->assigned_device_name(),
584 &parsed)) {
585 return errors::Internal("Malformed assigned device '",
586 node->assigned_device_name(), "'");
587 }
588
589 TF_RETURN_IF_ERROR(MemoryTypesForNode(
590 g.op_registry(), DeviceType(parsed.type), node->def(),
591 &input_memory_types, &output_memory_types));
592
593 int node_id = node->id();
594 info->device_types[node_id] = DeviceType(parsed.type);
595 for (int i = 0; i < input_memory_types.size(); ++i) {
596 info->input_types[{node_id, i}] = input_memory_types[i];
597 }
598 for (int i = 0; i < output_memory_types.size(); ++i) {
599 info->output_types[{node_id, i}] = output_memory_types[i];
600 }
601 }
602 return Status::OK();
603 }
604
// Returns the node whose control flow info describes the frame of the inputs
// of `node`. An input is in the same frame as the node except for Enter
// nodes: the input of Enter is in the parent frame of the Enter node.
const Node* InputFrame(const Node* node,
                       const std::vector<ControlFlowInfo>& cf_info) {
  return node->IsEnter() ? cf_info[node->id()].parent_frame : node;
}
614
// Returns the node whose control flow info describes the frame of the outputs
// of `node`. An output is in the same frame as the node except for Exit
// nodes: the output of Exit is in the parent frame of the Exit node.
const Node* OutputFrame(const Node* node,
                        const std::vector<ControlFlowInfo>& cf_info) {
  return node->IsExit() ? cf_info[node->id()].parent_frame : node;
}
624
625 // Each participating device needs to decide a) if there is a next iteration,
626 // and b) if the loop terminates. We take the approach to encode this control
627 // flow logic in the dataflow graph. There are at least two possible encodings.
628 // In a completely decentralized encoding, the participants communicate peer
629 // to peer. The other encoding uses a frame leader (the participant who owns
630 // the pivot termination predicate) to broadcast the termination condition to
631 // all the participants. For now we take the latter because it is simpler.
632 //
633 // TODO(yuanbyu): The correctness of this construction is rather subtle. I got
634 // it wrong many times so it would be nice to write a proof to be sure.
AddControlFlow(const PartitionOptions & opts,Graph * g,GraphInfo * g_info)635 Status AddControlFlow(const PartitionOptions& opts, Graph* g,
636 GraphInfo* g_info) {
637 Status status;
638 GraphDefBuilder::Options bopts(g, &status);
639 std::vector<ControlFlowInfo>& cf_info = g_info->cf_info;
640
641 // Build the control flow info for every node.
642 status = BuildControlFlowInfo(g, &cf_info);
643 if (!status.ok()) return status;
644
645 OptimizeControlFlowColocation(g);
646
647 // The map from frames to their LoopCond nodes.
648 std::unordered_map<string, Node*> frame_cond_map;
649 int num_node_ids = g->num_node_ids();
650 for (int i = 0; i < num_node_ids; ++i) {
651 Node* node = g->FindNodeId(i);
652 if (node == nullptr) continue;
653
654 if (IsLoopCond(node)) {
655 const string& frame_name = cf_info[node->id()].frame_name;
656 DCHECK(!frame_name.empty());
657 frame_cond_map[frame_name] = node;
658 }
659 }
660
661 // Add all control loops for cross-device frames.
662 // A control loop is added only when there is a cross-device edge in a
663 // non-root frame. Nothing is added if there is no loops. We also don't
664 // add anything for a frame that is completely local to a device. For
665 // nested loops, we stack the control loops together by connecting
666 // the merge of the outer loop to the enter of the inner loop.
667 //
668 // A map from <frame_name, device_name> to ControlLoop.
669 std::unordered_map<string, ControlLoop> control_loops;
670 int num_edge_ids = g->num_edge_ids();
671 for (int i = 0; i < num_edge_ids; ++i) {
672 const Edge* edge = g->FindEdgeId(i);
673 if (edge == nullptr) continue;
674
675 const Node* src = edge->src();
676 const Node* dst = edge->dst();
677 // Skip Sink/Source nodes.
678 if (!src->IsOp() || !dst->IsOp()) continue;
679
680 const string& src_device = src->assigned_device_name();
681 const string& dst_device = dst->assigned_device_name();
682 // Skip local edges.
683 if (src_device == dst_device) continue;
684
685 const Node* src_frame = OutputFrame(src, cf_info);
686 const Node* dst_frame = InputFrame(dst, cf_info);
687 const string& src_frame_name = cf_info[src_frame->id()].frame_name;
688 const string& dst_frame_name = cf_info[dst_frame->id()].frame_name;
689 // Skip if src and dst are not in the same frame.
690 if (src_frame_name.empty() || src_frame_name != dst_frame_name) {
691 continue;
692 }
693
694 // Add the control loop. Start by adding the control loop for the
695 // current frame if needed, and recursively adding the control loop
696 // for its outer frame when nested.
697 ControlLoop child_loop;
698 while (true) {
699 const string& curr_frame_name = cf_info[src_frame->id()].frame_name;
700 if (curr_frame_name.empty()) {
701 // We have reached the root frame.
702 if (child_loop.merge != nullptr) {
703 const string& node_name = opts.new_name(edge->dst()->name());
704 const string& device_name = edge->dst()->assigned_device_name();
705 Node* const_node =
706 AddControlConst(device_name, bopts.WithName(node_name));
707 if (!status.ok()) return status;
708 AddControlFlowInfo(const_node, src_frame, &cf_info);
709 g->AddEdge(const_node, 0, child_loop.enter, 0);
710 }
711 break;
712 }
713
714 const string& cl_key = strings::StrCat(curr_frame_name, "$$", dst_device);
715 auto it = control_loops.find(cl_key);
716 if (it != control_loops.end()) {
717 if (child_loop.enter != nullptr) {
718 g->AddEdge(it->second.merge, 0, child_loop.enter, 0);
719 }
720 break;
721 }
722
723 // Get the frame's LoopCond.
724 auto cond_it = frame_cond_map.find(curr_frame_name);
725 if (cond_it == frame_cond_map.end()) {
726 return errors::InvalidArgument(
727 "A cross-device loop must have a pivot predicate: ",
728 curr_frame_name);
729 }
730 Node* loop_cond = cond_it->second;
731
732 // Add the control loop.
733 ControlLoop curr_loop;
734 status = AddControlLoop(opts, g, src_frame, edge, loop_cond, &cf_info,
735 &curr_loop);
736 if (!status.ok()) return status;
737 control_loops[cl_key] = curr_loop;
738
739 if (child_loop.enter != nullptr) {
740 // Connect the merge of the outer loop to the enter of the inner.
741 g->AddEdge(curr_loop.merge, 0, child_loop.enter, 0);
742 }
743 src_frame = cf_info[src_frame->id()].parent_frame;
744 child_loop = curr_loop;
745 }
746 }
747
748 // For a cross-device edge, on the dst device, add a control edge
749 // from the merge node of the control loop to dst. If a send/recv is
750 // introduced for this edge in future partitioning, we delete this
751 // control edge and add a new control edge from the merge to the recv.
752 num_edge_ids = g->num_edge_ids();
753 for (int i = 0; i < num_edge_ids; ++i) {
754 const Edge* edge = g->FindEdgeId(i);
755 if (edge == nullptr) continue;
756
757 const Node* src = edge->src();
758 Node* dst = edge->dst();
759 // Skip Sink/Source nodes.
760 if (!src->IsOp() || !dst->IsOp()) continue;
761
762 const string& src_device = src->assigned_device_name();
763 const string& dst_device = dst->assigned_device_name();
764 if (src_device != dst_device) {
765 const Node* src_frame = OutputFrame(src, cf_info);
766 const Node* dst_frame = InputFrame(dst, cf_info);
767 const string& src_frame_name = cf_info[src_frame->id()].frame_name;
768 const string& dst_frame_name = cf_info[dst_frame->id()].frame_name;
769 if (!src_frame_name.empty() && src_frame_name == dst_frame_name) {
770 const string& cl_key =
771 strings::StrCat(dst_frame_name, "$$", dst_device);
772 ControlLoop loop = control_loops[cl_key];
773 DCHECK(loop.enter != nullptr);
774 // Note that we'll create multiple duplicate edges if dst has multiple
775 // cross-device inputs. This is expected by the logic in Partition(), so
776 // it can add control edges to the recv nodes once they're created.
777 g->AddControlEdge(loop.merge, dst, /*allow_duplicates=*/true);
778 }
779 }
780 }
781 return Status::OK();
782 }
783
784 struct PriorityTopoSortNode {
PriorityTopoSortNodetensorflow::__anon303478ad0111::PriorityTopoSortNode785 PriorityTopoSortNode(const NodeDef* n, int64 st) : node(n), start_time(st) {}
786
787 const NodeDef* node;
788 int64 start_time;
789 };
790
791 struct PriorityTopoSortNodeGreater {
operator ()tensorflow::__anon303478ad0111::PriorityTopoSortNodeGreater792 bool operator()(const PriorityTopoSortNode& left,
793 const PriorityTopoSortNode& right) {
794 return left.start_time > right.start_time;
795 }
796 };
797
798 } // namespace
799
800 // Returns in <nodes> the nodes that should participate in epoch-based recv
801 // scheduling, along with their times; <nodes> is ordered by increasing
802 // start_time. Returns in <node_to_start_time_out> the timing for all nodes,
803 // even those not in <nodes>.
804 //
805 // Comparing to sorting on the node's start time only, this also processes the
806 // nodes in dependency order, and updates start times to ensure a node's
807 // start_time > the start time for all dependencies.
808 //
809 // Note that graph_partition_test.cc accesses this function for testing, even
810 // though it's not declared in the header.
TopologicalSortNodesWithTimePriority(const GraphDef * gdef,std::vector<std::pair<const NodeDef *,int64>> * nodes,std::unordered_map<const NodeDef *,int64> * node_to_start_time_out)811 Status TopologicalSortNodesWithTimePriority(
812 const GraphDef* gdef, std::vector<std::pair<const NodeDef*, int64>>* nodes,
813 std::unordered_map<const NodeDef*, int64>* node_to_start_time_out) {
814 // Queue of nodes to process; lowest start time is returned first.
815 std::priority_queue<PriorityTopoSortNode, std::vector<PriorityTopoSortNode>,
816 PriorityTopoSortNodeGreater>
817 q;
818 std::unordered_map<const NodeDef*, int64> node_to_start_time;
819 auto enqueue = [&q, &node_to_start_time](const NodeDef* node) {
820 const int64 start_time = node_to_start_time[node];
821 q.emplace(node, start_time);
822 };
823
824 // Build initial structures, initial contents of queue.
825 std::unordered_map<string, std::vector<const NodeDef*>> node_to_output_nodes;
826 std::unordered_map<const NodeDef*, int> inputs_needed;
827 for (int n = 0; n < gdef->node_size(); ++n) {
828 const NodeDef* ndef = &gdef->node(n);
829 for (int i = 0; i < ndef->input_size(); ++i) {
830 node_to_output_nodes[string(ParseTensorName(ndef->input(i)).first)]
831 .push_back(ndef);
832 }
833 int64 start_time;
834 TF_RETURN_IF_ERROR(GetNodeAttr(*ndef, "_start_time", &start_time));
835 node_to_start_time[ndef] = start_time;
836 inputs_needed[ndef] = ndef->input_size();
837 if (ndef->input_size() == 0) {
838 enqueue(ndef);
839 }
840 }
841
842 // Determine which merge nodes are parts of loops; these
843 // need to happen in the traversal after all non-NextIteration inputs
844 // are run.
845 for (int n = 0; n < gdef->node_size(); ++n) {
846 const NodeDef* ndef = &gdef->node(n);
847 if (IsNextIteration(*ndef)) {
848 for (const NodeDef* n : node_to_output_nodes[ndef->name()]) {
849 if (IsMerge(*n)) {
850 // n is a merge that is part of a loop structure.
851 // It doesn't need to wait for this NextIteration loop
852 // when doing the traversal.
853 --inputs_needed[n];
854 }
855 }
856 }
857 }
858
859 // Traverse.
860 std::vector<std::pair<const NodeDef*, int64>> start_times;
861 start_times.reserve(gdef->node_size());
862 while (!q.empty()) {
863 PriorityTopoSortNode cur = q.top();
864 q.pop();
865
866 start_times.emplace_back(cur.node, cur.start_time);
867
868 for (const NodeDef* n : node_to_output_nodes[cur.node->name()]) {
869 auto& output_start_time = node_to_start_time[n];
870 if (output_start_time <= cur.start_time) {
871 output_start_time = cur.start_time + 1;
872 }
873 if (--inputs_needed[n] == 0) {
874 enqueue(n);
875 }
876 }
877 }
878
879 // Done.
880 nodes->swap(start_times);
881 node_to_start_time_out->swap(node_to_start_time);
882 return Status::OK();
883 }
884
AddControlEdges(const PartitionOptions & opts,std::unordered_map<string,GraphDef> * partitions)885 Status AddControlEdges(const PartitionOptions& opts,
886 std::unordered_map<string, GraphDef>* partitions) {
887 Status status;
888 // TODO(yuanbyu): Very naive for now. To be improved.
889 const int num_epochs = 100;
890 const int prefetch = 6;
891
892 for (auto& part : *partitions) {
893 GraphDef* gdef = &part.second;
894 std::vector<std::pair<const NodeDef*, int64>> start_times;
895 std::unordered_map<const NodeDef*, int64> node_to_start_time;
896 status = TopologicalSortNodesWithTimePriority(gdef, &start_times,
897 &node_to_start_time);
898 if (!status.ok()) {
899 return status;
900 }
901
902 // Add a dummy node for every epoch, and add a control edge from the
903 // "last" node in the preceding epoch to the dummy node.
904 string device_name = gdef->node(0).device();
905 int64 makespan = start_times.back().second;
906 int64 resolution = (makespan / num_epochs) + 1;
907
908 int i = 0;
909 int j = 0;
910 std::vector<NodeDef*> dummys;
911 while (i < num_epochs && static_cast<size_t>(j) < start_times.size()) {
912 if (i * resolution > start_times[j].second) {
913 j++;
914 } else {
915 NodeDef* dummy = AddControlTrigger(opts, gdef, device_name, i,
916 i * resolution, &status);
917 if (!status.ok()) {
918 return status;
919 }
920 dummys.push_back(dummy);
921 if (j > 0) {
922 string src_name = start_times[j - 1].first->name();
923 AddInput(dummy, src_name, Graph::kControlSlot);
924 }
925 i++;
926 }
927 }
928
929 // Finally, add the control edges to recvs.
930 for (int n = 0; n < gdef->node_size(); ++n) {
931 NodeDef* ndef = gdef->mutable_node(n);
932 if (ndef->op() == "_Recv") {
933 const int64 start_time = node_to_start_time[ndef];
934 const int recv_epoch = start_time / resolution;
935 if (recv_epoch >= prefetch) {
936 NodeDef* dummy = dummys[recv_epoch - prefetch];
937 AddInput(ndef, dummy->name(), Graph::kControlSlot);
938 }
939 }
940 }
941 }
942 return Status::OK();
943 }
944
945 // If 'ndef' is a Send or Recv, fills its attr send_device_incarnation
946 // if possible.
SetIncarnation(const PartitionOptions & opts,NodeDef * ndef)947 void SetIncarnation(const PartitionOptions& opts, NodeDef* ndef) {
948 StringPiece op(ndef->op());
949 if (op != "_Send" && op != "_Recv") {
950 // Not related to send/recv.
951 return;
952 }
953 const string& send_device = GetNodeAttrString(*ndef, "send_device");
954 if (send_device.empty()) {
955 // No known send_device. The runtime will detect it later.
956 return;
957 }
958 int64 incarnation = PartitionOptions::kIllegalIncarnation;
959 if (!TryGetNodeAttr(*ndef, "send_device_incarnation", &incarnation) ||
960 (incarnation == PartitionOptions::kIllegalIncarnation)) {
961 incarnation = opts.get_incarnation(send_device);
962 SetAttrValue(incarnation,
963 &((*ndef->mutable_attr())["send_device_incarnation"]));
964 }
965 }
966
967 // Sets attribute send_device_incarnation of all Send/Recv nodes in
968 // 'gdef', if possible.
SetIncarnation(const PartitionOptions & opts,GraphDef * gdef)969 void SetIncarnation(const PartitionOptions& opts, GraphDef* gdef) {
970 for (NodeDef& ndef : *gdef->mutable_node()) {
971 SetIncarnation(opts, &ndef);
972 }
973 for (FunctionDef& fdef : *gdef->mutable_library()->mutable_function()) {
974 for (NodeDef& ndef : *fdef.mutable_node_def()) {
975 SetIncarnation(opts, &ndef);
976 }
977 }
978 }
979
// Splits graph 'g' into one GraphDef per location and stores the result in
// 'partitions', keyed by the string returned by opts.node_to_loc().
//
// Steps:
//   1. Unless opts.control_flow_added is set, AddControlFlow() is applied to
//      'g' first, so 'g' itself may be modified.
//   2. Every op node of 'g' is copied into the GraphDef of its location.
//      An input whose source is in the same partition (and for which
//      NeedSameDeviceSendRecv() is false) is wired directly; any other input
//      is routed through a send node added to the source partition and a
//      recv node added to the destination partition. A send/recv pair is
//      reused for later edges with the same DupRecvKey, except for
//      non-control edges carrying a ref type.
//   3. Each partition receives the versions of 'g', the reachable part of
//      the function library, and send/recv incarnations.
//   4. If opts.scheduling_for_recvs is set, the recorded start times are
//      written to the memorized recv nodes.
//
// 'partitions' is cleared on entry. On error the first failing status is
// returned and 'partitions' may contain partially built graphs.
Status Partition(const PartitionOptions& opts, Graph* g,
                 std::unordered_map<string, GraphDef>* partitions) {
  Status status;
  partitions->clear();

  GraphInfo g_info;
  if (!opts.control_flow_added) {
    // Add the "code" for distributed execution of control flow. Code is
    // added only for the frames that are placed on multiple devices. The
    // new graph is an equivalent transformation of the original graph and
    // has the property that it can be subsequently partitioned arbitrarily
    // (down to the level of individual device) for distributed execution.
    status = AddControlFlow(opts, g, &g_info);
    if (!status.ok()) return status;
  }

  // At this point, all the graph mutations have been done. Build memory
  // and device type info for every node and edge in the graph.
  status = BuildMemoryDeviceInfo(*g, &g_info);
  if (!status.ok()) return status;

  // Scratch state reused across the iterations of the node loop below.
  string dstp;
  std::vector<const Edge*> inputs;
  // Send/recv pairs created so far, looked up by DupRecvKey (source node id,
  // source output, destination partition, on-host flag).
  DupRecvTable dup_recv(3);
  // For a node dst, 'ref_recvs' remembers the recvs introduced by a ref
  // edge to dst. 'ref_control_inputs' remembers the inputs by a non-ref
  // edge to dst. We will add a control edge for every pair in
  // (ref_recvs x ref_control_inputs).
  std::vector<NodeDef*> ref_recvs;
  std::vector<string> ref_control_inputs;

  // Counters for the summary log line at the end.
  int32 num_data = 0;
  int32 num_control = 0;
  for (const Node* dst : g->op_nodes()) {
    // Copy dst into the GraphDef of its partition.
    dstp = opts.node_to_loc(dst);
    GraphDef* dst_graph = &(*partitions)[dstp];
    NodeDef* dst_def = dst_graph->add_node();
    *dst_def = dst->def();
    MergeDebugInfo(NodeDebugInfo(dst->def()), dst_def);
    dst_def->set_device(dst->assigned_device_name());
    dst_def->clear_input();  // Inputs are filled below
    if (opts.need_to_record_start_times) {
      // Record the start time from opts only if the node does not already
      // carry a "_start_time" attr.
      int64 start_time;
      status = GetNodeAttr(*dst_def, "_start_time", &start_time);
      if (errors::IsNotFound(status)) {
        start_time = opts.start_times[dst->id()].value();
        AddNodeAttr("_start_time", start_time, dst_def);
      } else if (!status.ok()) {
        return status;
      }
    }

    // Arrange the incoming edges to dst so that input[i] holds the
    // input flowing into slot numbered i. Trailing entries in input[]
    // hold control edges.
    inputs.clear();
    inputs.resize(dst->num_inputs(), nullptr);
    ref_recvs.clear();
    ref_control_inputs.clear();
    // The most recently seen control-flow control edge into dst, if any.
    const Edge* control_flow_edge = nullptr;
    int32 num_control_flow_edges = 0;
    int32 num_input_edges = 0;
    for (const Edge* edge : dst->in_edges()) {
      if (edge->IsControlEdge()) {
        if (IsMerge(edge->src()) && IsControlLoop(edge->src())) {
          // This is one of the control edges added for control flow. There
          // can be multiple such edges as the dest node may have multiple
          // remote inputs. We keep track of the number of such edges.
          control_flow_edge = edge;
          ++num_control_flow_edges;
        } else {
          inputs.push_back(edge);
        }
      } else {
        DCHECK(inputs[edge->dst_input()] == nullptr);
        inputs[edge->dst_input()] = edge;
        ++num_input_edges;
      }
    }

    // Every data slot must have been filled; this also guarantees that
    // 'inputs' holds no nullptr when it is iterated below.
    if (num_input_edges != dst->num_inputs()) {
      return errors::InvalidArgument("Incomplete graph, missing ",
                                     (dst->num_inputs() - num_input_edges),
                                     " inputs for ", dst->name());
    }

    // Process in order so that all data edges are added as inputs to
    // dst in Edge::dst_input() order.
    for (const Edge* edge : inputs) {
      const Node* src = edge->src();
      if (!src->IsOp()) continue;  // Skip Sink/Source nodes.

      GraphDef* src_graph = &(*partitions)[opts.node_to_loc(src)];
      if (src_graph == dst_graph && !NeedSameDeviceSendRecv(edge, g_info)) {
        // Same partition and compatible memory types:
        AddInput(dst_def, src->name(), edge->src_output());
        if (edge->IsControlEdge() ||
            !IsRefType(src->output_type(edge->src_output()))) {
          ref_control_inputs.push_back(src->name());
        }
        continue;
      }

      // Start times of the two ends of the edge. They stay 0 unless
      // scheduling for recvs is requested.
      int64 send_start_time = 0;
      int64 recv_start_time = 0;
      if (opts.scheduling_for_recvs) {
        status = GetNodeAttr(src->attrs(), "_start_time", &send_start_time);
        if (errors::IsNotFound(status) && opts.need_to_record_start_times) {
          send_start_time = opts.start_times[src->id()].value();
        } else if (!status.ok()) {
          return status;
        }

        status = GetNodeAttr(dst->attrs(), "_start_time", &recv_start_time);
        if (errors::IsNotFound(status) && opts.need_to_record_start_times) {
          recv_start_time = opts.start_times[dst->id()].value();
        } else if (!status.ok()) {
          return status;
        }
      }

      // Check whether there is already a send/recv pair transferring
      // the same tensor/control from the src to dst partition.
      const bool on_host = IsDstInputOnHost(edge, g_info);
      DupRecvKey key{src->id(), edge->src_output(), dst_graph, on_host};
      auto iter = dup_recv.find(key);
      if (iter != dup_recv.end()) {
        // We found one. Reuse the data/control transferred already.
        const string& recv_node_name = iter->second.recv->name();
        if (edge->IsControlEdge()) {
          AddInput(dst_def, recv_node_name, Graph::kControlSlot);
        } else {
          AddInput(dst_def, recv_node_name, 0);
        }
        ref_control_inputs.push_back(recv_node_name);

        // We want the start_time for the recv to be the smallest of the start
        // times of its consumers. So we update this whenever we use a recv,
        // and write it out to the attribute at the end of the subroutine
        if (iter->second.start_time > recv_start_time) {
          iter->second.start_time = recv_start_time;
        }
        continue;
      }

      // No reusable pair: decide what the new send node reads from.
      NodeDefBuilder::NodeOut send_from;
      if (edge->IsControlEdge()) {
        // Insert a dummy const node that will generate a tiny
        // data element to be sent from send to recv.
        VLOG(1) << "Send/Recv control: " << src->assigned_device_name() << "["
                << src->name() << "] -> " << dst->assigned_device_name() << "["
                << dst->name() << "]";
        NodeDef* dummy = AddDummyConst(opts, src_graph, edge, &status);
        if (!status.ok()) return status;
        // Set the start time for this dummy node.
        if (opts.scheduling_for_recvs) {
          AddNodeAttr("_start_time", send_start_time, dummy);
        }
        AddInput(dummy, src->name(), Graph::kControlSlot);
        send_from.Reset(dummy->name(), 0, DT_FLOAT);
      } else {
        send_from.Reset(src->name(), edge->src_output(), EdgeType(edge));
      }

      // Need to split edge by placing matching send/recv nodes on
      // the src/dst sides of the edge.
      NodeDef* send = AddSend(opts, g_info, src_graph, edge, send_from,
                              send_start_time, &status);
      if (!status.ok()) return status;

      NodeDef* real_recv = nullptr;
      NodeDef* recv =
          AddRecv(opts, g_info, dst_graph, edge, &real_recv, &status);
      if (!status.ok()) return status;

      // Fix up the control flow edge.
      // NOTE(yuanbyu): 'real_recv' must be the real recv node.
      if (src_graph == dst_graph) {
        // For same device send/recv, add a control edge from send to recv.
        // This prevents the asynchronous recv kernel from being scheduled
        // before the data is available.
        AddInput(real_recv, send->name(), Graph::kControlSlot);
      } else if (control_flow_edge != nullptr) {
        // Redirect control edge to the real recv since this is not the same
        // device send/recv.
        // One fewer control-flow edge is left to be added back onto dst
        // after this loop.
        --num_control_flow_edges;
        AddInput(real_recv, control_flow_edge->src()->name(),
                 Graph::kControlSlot);
      }

      if (!edge->IsControlEdge() &&
          IsRefType(src->output_type(edge->src_output()))) {
        // Ref edges are not entered into 'dup_recv', so their start time is
        // written here rather than in the final pass over 'dup_recv'.
        AddNodeAttr("_start_time", recv_start_time, recv);
        if (real_recv != recv) {
          AddNodeAttr("_start_time", recv_start_time, real_recv);
        }
        // If src is of ref type and the edge is not a control edge, dst has
        // read semantics and therefore we must control the recv.
        ref_recvs.push_back(real_recv);
      } else {
        // Memorize the send/recv pair, only if this is not a "ref" edge.
        // NOTE(yuanbyu): Collapsing ref edges requires extreme care so
        // for now we don't do it.
        dup_recv[key] = {recv, real_recv, recv_start_time};
        ref_control_inputs.push_back(recv->name());
      }

      // Finally make dst consume the recv in place of the original input.
      if (edge->IsControlEdge()) {
        ++num_control;
        AddInput(dst_def, recv->name(), Graph::kControlSlot);
      } else {
        ++num_data;
        AddInput(dst_def, recv->name(), 0);
      }
    }

    // Add control edges from 'ref_control_inputs' to 'ref_recvs'.
    // NOTE(yuanbyu): Adding these control edges should not introduce
    // deadlocks. 'dst' has implicit "read" nodes that, when we split
    // across devices, are made explicit; Retargeting the dependencies
    // to 'dst' to those nodes would not introduce cycles if there isn't
    // one before the transformation.
    // NOTE(yuanbyu): This may impact performance because it defers the
    // execution of recvs until all the other inputs become available.
    AddReadControl(ref_recvs, ref_control_inputs);

    // Add back the control edges for control flow that are not used.
    if (control_flow_edge != nullptr) {
      for (int i = 0; i < num_control_flow_edges; ++i) {
        AddInput(dst_def, control_flow_edge->src()->name(),
                 Graph::kControlSlot);
      }
    }
  }

  // Use the function library from opts when given, else the one of 'g'.
  const FunctionLibraryDefinition* flib_def = opts.flib_def;
  if (flib_def == nullptr) {
    flib_def = &g->flib_def();
  }

  // Set versions, function library and send/recv incarnation.
  for (auto& it : *partitions) {
    GraphDef* gdef = &it.second;
    *gdef->mutable_versions() = g->versions();
    // Prune unreachable functions from `flib_def` before adding them to `gdef`.
    *gdef->mutable_library() = flib_def->ReachableDefinitions(*gdef).ToProto();

    // Traverse the graph to fill every send/recv op's incarnation
    // information.
    SetIncarnation(opts, gdef);
  }

  // Set the start times for recvs at the very end.
  if (opts.scheduling_for_recvs) {
    for (auto& it : dup_recv) {
      AddNodeAttr("_start_time", it.second.start_time, it.second.recv);
      if (it.second.real_recv != it.second.recv) {
        AddNodeAttr("_start_time", it.second.start_time, it.second.real_recv);
      }
    }
  }

  VLOG(1) << "Added send/recv: controls=" << num_control
          << ", data=" << num_data;
  if (VLOG_IS_ON(2)) {
    // Dump every partition; the GraphDef address is appended to the file name.
    for (auto& it : *partitions) {
      GraphDef* gdef = &it.second;
      DumpGraphDefToFile(strings::StrCat("partition_", it.first, "_",
                                         reinterpret_cast<uintptr_t>(gdef)),
                         *gdef);
    }
  }
  return Status::OK();
}
1254
1255 } // namespace tensorflow
1256