1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/optimizers/dependency_optimizer.h"
17
18 #include <unordered_set>
19
20 #include "absl/container/flat_hash_map.h"
21 #include "tensorflow/core/framework/node_def.pb.h"
22 #include "tensorflow/core/framework/node_def_util.h"
23 #include "tensorflow/core/framework/op.h"
24 #include "tensorflow/core/grappler/costs/graph_properties.h"
25 #include "tensorflow/core/grappler/grappler_item.h"
26 #include "tensorflow/core/grappler/op_types.h"
27 #include "tensorflow/core/grappler/optimizers/constant_folding.h"
28 #include "tensorflow/core/grappler/utils.h"
29 #include "tensorflow/core/grappler/utils/topological_sort.h"
30 #include "tensorflow/core/lib/core/errors.h"
31 #include "tensorflow/core/lib/core/stringpiece.h"
32 #include "tensorflow/core/lib/gtl/inlined_vector.h"
33 #include "tensorflow/core/lib/strings/str_util.h"
34 #include "tensorflow/core/lib/strings/strcat.h"
35 #include "tensorflow/core/util/device_name_utils.h"
36
37 namespace tensorflow {
38 namespace grappler {
39
40 namespace {
41
RemoveControlInput(NodeDef * node,const string & control_input_to_remove,NodeMap * node_map)42 bool RemoveControlInput(NodeDef* node, const string& control_input_to_remove,
43 NodeMap* node_map) {
44 for (int pos = node->input_size() - 1; pos >= 0; --pos) {
45 const string& input = node->input(pos);
46 if (input[0] != '^') break;
47 if (input == control_input_to_remove) {
48 node->mutable_input()->SwapElements(pos, node->input_size() - 1);
49 node->mutable_input()->RemoveLast();
50 node_map->RemoveOutput(NodeName(input), node->name());
51 return true;
52 }
53 }
54 return false;
55 }
56
57 } // namespace
58
SafeToRemoveIdentity(const NodeDef & node) const59 bool DependencyOptimizer::SafeToRemoveIdentity(const NodeDef& node) const {
60 if (!IsIdentity(node) && !IsIdentityN(node)) {
61 return true;
62 }
63
64 if (nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end()) {
65 return false;
66 }
67 if (!fetch_nodes_known_) {
68 // The output values of this node may be needed.
69 return false;
70 }
71 const NodeDef* input = node_map_->GetNode(NodeName(node.input(0)));
72 CHECK(input != nullptr) << "node = " << node.name()
73 << " input = " << node.input(0);
74 // Don't remove Identity nodes corresponding to Variable reads or following
75 // Recv.
76 if (IsVariable(*input) || IsRecv(*input)) {
77 return false;
78 }
79 for (const auto& consumer : node_map_->GetOutputs(node.name())) {
80 if (node.input_size() > 1 && (IsRetval(*consumer) || IsMerge(*consumer))) {
81 return false;
82 }
83 if (IsSwitch(*input)) {
84 for (const string& consumer_input : consumer->input()) {
85 if (consumer_input == AsControlDependency(node.name())) {
86 return false;
87 }
88 }
89 }
90 }
91 return true;
92 }
93
SafeToConvertToNoOp(const NodeDef & node) const94 bool DependencyOptimizer::SafeToConvertToNoOp(const NodeDef& node) const {
95 if (HasRegularOutputs(node, *node_map_)) {
96 // The output values of this node may be needed.
97 VLOG(3) << "Not safe to convert '" << node.name()
98 << " to NoOp. Node has outputs.";
99 return false;
100 }
101 if (!fetch_nodes_known_) {
102 VLOG(3) << "Not safe to convert '" << node.name()
103 << " to NoOp. Fetches unknown.";
104 return false;
105 }
106 if (nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end()) {
107 VLOG(3) << "Not safe to convert to NoOp: " << node.name()
108 << " is in preserve set.";
109 return false;
110 }
111 if (IsMerge(node) || IsSwitch(node) || ModifiesFrameInfo(node)) {
112 VLOG(3) << "Not safe to convert '" << node.name()
113 << " to NoOp. Node modifies frame info.";
114 return false;
115 }
116 // Ops reading variables are marked as stateful, but are safe to remove if
117 // redundant.
118 static const absl::flat_hash_set<string>* gather_ops =
119 new absl::flat_hash_set<string>{"Gather", "GatherV2", "GatherNd",
120 "ResourceGather", "ResourceGatherNd"};
121 const bool is_variable_read =
122 IsReadVariableOp(node) || IsReadVariablesOp(node) ||
123 gather_ops->find(node.op()) != gather_ops->end();
124 if (!is_variable_read && !IsFreeOfSideEffect(node)) {
125 VLOG(3) << "Not safe to convert '" << node.name()
126 << " to NoOp. Node has side effect.";
127 return false;
128 }
129 if (node.op().rfind("Submodel", 0) == 0) {
130 return false;
131 }
132 const OpDef* op_def = nullptr;
133 Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def);
134 if (!status.ok() || op_def->output_arg_size() == 0) {
135 return false;
136 }
137 const std::unordered_set<string> do_not_rewrite_ops{
138 "Assert", "CheckNumerics", "_Retval",
139 "_Arg", "_ParallelConcatUpdate", "TPUExecute",
140 "TPUCompile", "ControlTrigger"};
141 if (do_not_rewrite_ops.find(node.op()) != do_not_rewrite_ops.end()) {
142 return false;
143 }
144 if (!SafeToRemoveIdentity(node)) {
145 return false;
146 }
147 return true;
148 }
149
NumEdgesIfBypassed(const NodeDef & node,const std::vector<NodeDef * > & output_nodes) const150 int DependencyOptimizer::NumEdgesIfBypassed(
151 const NodeDef& node, const std::vector<NodeDef*>& output_nodes) const {
152 const bool is_multi_input_identity_n =
153 IsIdentityN(node) && !IsIdentityNSingleInput(node);
154 const int num_outputs = output_nodes.size();
155 const int num_inputs = node.input_size();
156
157 if (is_multi_input_identity_n) {
158 // multi-input identity_n with input/output control dependencies will likely
159 // increase number of edges after optimization.
160 int num_edges_if_bypassed(0);
161 for (const string& input_node_name : node.input()) {
162 if (IsControlInput(input_node_name)) {
163 num_edges_if_bypassed += num_outputs;
164 } else {
165 ++num_edges_if_bypassed;
166 }
167 }
168
169 for (auto consumer : output_nodes) {
170 for (int j = 0; j < consumer->input_size(); ++j) {
171 const TensorId consumer_input = ParseTensorName(consumer->input(j));
172 if (consumer_input.node() == node.name()) {
173 if (IsControlInput(consumer_input)) {
174 num_edges_if_bypassed += num_inputs;
175 } else {
176 ++num_edges_if_bypassed;
177 }
178 }
179 }
180 }
181 return num_edges_if_bypassed;
182 } else {
183 return num_inputs * num_outputs;
184 }
185 }
186
BypassingNodeIsBeneficial(const NodeDef & node,const std::vector<NodeDef * > & input_nodes,const std::vector<NodeDef * > & output_nodes) const187 bool DependencyOptimizer::BypassingNodeIsBeneficial(
188 const NodeDef& node, const std::vector<NodeDef*>& input_nodes,
189 const std::vector<NodeDef*>& output_nodes) const {
190 const bool is_identity = IsIdentity(node) || IsIdentityNSingleInput(node);
191 const bool is_multi_input_identity_n =
192 IsIdentityN(node) && !IsIdentityNSingleInput(node);
193 const int num_outputs = output_nodes.size();
194 const int num_inputs = node.input_size();
195
196 if (NumEdgesIfBypassed(node, output_nodes) > num_inputs + num_outputs) {
197 return false;
198 }
199
200 // Make sure that we don't increase the number of edges that cross
201 // device boundaries.
202 if ((num_inputs == 1 && num_outputs > 1 &&
203 input_nodes[0]->device() != node.device()) ||
204 (num_inputs > 1 && num_outputs == 1 &&
205 output_nodes[0]->device() != node.device())) {
206 return false;
207 }
208
209 // TODO(rmlarsen): Not all device crossings are equally expensive.
210 // Assign a cost to each based on device affinity and compute a
211 // cost before and after.
212 const string& node_dev = node.device();
213 int num_cross_in = 0;
214 for (NodeDef* input_node : input_nodes) {
215 num_cross_in += static_cast<int>(input_node->device() != node_dev);
216 }
217 int num_cross_out = 0;
218 for (NodeDef* output_node : output_nodes) {
219 num_cross_out += static_cast<int>(output_node->device() != node_dev);
220 }
221
222 // Make sure we do not increase the number of device crossings.
223 const int num_cross_before = num_cross_in + num_cross_out;
224 int num_cross_after = 0;
225 for (NodeDef* input_node : input_nodes) {
226 for (NodeDef* output_node : output_nodes) {
227 num_cross_after +=
228 static_cast<int>(input_node->device() != output_node->device());
229 }
230 }
231 if (num_cross_after > num_cross_before) {
232 return false;
233 }
234
235 if ((is_identity || is_multi_input_identity_n) && num_cross_in > 0 &&
236 num_cross_out > 0 && num_cross_after > 0) {
237 // This identity node follows a device crossing, so it might be
238 // following a _Recv node after partitioning. Do not remove such nodes,
239 // unless they only have consumers on the same device as themselves.
240 return false;
241 }
242
243 return true;
244 }
245
// Tries to simplify the node at index `node_idx` of `optimized_graph_`. At most
// one of three rewrites is applied per call:
//   1. A Const node without inputs has all control edges leaving it pruned.
//   2. A node accepted by SafeToConvertToNoOp() is turned into a NoOp whose
//      inputs are all control dependencies.
//   3. A NoOp, or an Identity/IdentityN accepted by SafeToRemoveIdentity(), is
//      bypassed when BypassingNodeIsBeneficial() agrees.
//
// `nodes_to_simplify` receives the indices of nodes whose connectivity was
// changed, so that they get (re)visited. `nodes_to_delete` receives the index
// of this node when it became unused; that only happens when the fetch nodes
// are known and the node is not in `nodes_to_preserve_`.
void DependencyOptimizer::OptimizeNode(int node_idx,
                                       SetVector<int>* nodes_to_simplify,
                                       std::set<int>* nodes_to_delete) {
  NodeDef* node = optimized_graph_->mutable_node(node_idx);
  const bool is_noop = IsNoOp(*node);
  const bool is_identity = IsIdentity(*node) || IsIdentityNSingleInput(*node);
  const bool is_multi_input_identity =
      IsIdentityN(*node) && !IsIdentityNSingleInput(*node);
  const string node_name = node->name();
  // Constant nodes with no input control dependency are always executed early,
  // so we can prune all their output control dependencies.
  if (IsConstant(*node) && node->input_size() == 0) {
    // Taken by value: the node map entry for this node is edited in the loop.
    const auto output_nodes = node_map_->GetOutputs(node_name);
    for (NodeDef* fanout : output_nodes) {
      bool optimize_fanout = false;
      bool data_connection = false;
      // Walk backwards so that the swap-with-last removal below never skips an
      // input that has not been examined yet.
      for (int i = fanout->input_size() - 1; i >= 0; --i) {
        const TensorId input_tensor = ParseTensorName(fanout->input(i));
        if (input_tensor.node() == node_name) {
          if (input_tensor.index() < 0) {
            // Control edge from the constant: drop it.
            fanout->mutable_input()->SwapElements(i, fanout->input_size() - 1);
            fanout->mutable_input()->RemoveLast();
            optimize_fanout = true;
          } else {
            data_connection = true;
          }
        }
      }
      if (optimize_fanout) {
        nodes_to_simplify->PushBack(node_to_idx_[fanout]);
        // Keep the node map edge if the fanout still reads the constant's
        // value.
        if (!data_connection) {
          node_map_->RemoveOutput(node_name, fanout->name());
        }
      }
    }
    if (node_map_->GetOutputs(node_name).empty() && fetch_nodes_known_ &&
        nodes_to_preserve_.find(node_name) == nodes_to_preserve_.end()) {
      // Mark the node for deletion.
      nodes_to_delete->insert(node_to_idx_[node]);
    }
    return;
  }

  // Change ops that only have control dependencies as outputs to NoOps.
  if (!is_noop && SafeToConvertToNoOp(*node)) {
    VLOG(2) << "***** Replacing  " << node_name << " (" << node->op()
            << ") with NoOp.";
    // The outputs of this node are not consumed. Replace its inputs with
    // control dependencies and replace the op itself with the NoOp op.
    std::unordered_set<string> ctrl_inputs;
    int pos = 0;
    while (pos < node->input_size()) {
      // Copied on purpose: the slot at `pos` is overwritten or removed below.
      const string old_input = node->input(pos);
      if (IsControlInput(old_input)) {
        if (!ctrl_inputs.insert(old_input).second) {
          // We found a duplicate control input. Remove it.
          // `pos` is not advanced: the element swapped in still needs a look.
          node->mutable_input()->SwapElements(pos, node->input_size() - 1);
          node->mutable_input()->RemoveLast();
        } else {
          ++pos;
        }
        continue;
      }
      // Replace a normal input with a control input.
      const string ctrl_input = ConstantFolding::AddControlDependency(
          old_input, optimized_graph_, node_map_.get());
      ctrl_inputs.insert(ctrl_input);
      node->set_input(pos, ctrl_input);
      node_map_->UpdateInput(node_name, old_input, ctrl_input);
      // The producer lost a data consumer, so it may be simplifiable now.
      const NodeDef* old_input_node = node_map_->GetNode(old_input);
      nodes_to_simplify->PushBack(node_to_idx_[old_input_node]);
      ++pos;
    }
    node->set_op("NoOp");
    EraseRegularNodeAttributes(node);
    DedupControlInputs(node);
    // Revisit this node: as a NoOp it is now a candidate for bypassing.
    nodes_to_simplify->PushBack(node_to_idx_[node]);
    return;
  }

  // Remove NoOp nodes if the product of their fan-in and fan-out is less than
  // or equal to the sum of the fan-in and fan-out. The non-trivial rewrites
  // take the following form:
  //
  // Case a)
  //    x --^> +------+                x --^> +---+
  //    y --^> | NoOp | --^> a   ==>   y --^> | a |
  //    ...    |      |                  ...  |   |
  //    z --^> +------+                z --^> +---+
  //
  // Case b)
  //           +------+ --^> a         +---+ --^> a
  //    x --^> | NoOp | --^> b  ==>    | x | --^> b
  //           |      | ...            |   | ...
  //           +------+ --^> c         +---+ --^> c
  // Case c)
  //           +------+                x ---^> a
  //    x --^> | NoOp | --^> a  ==>      \/
  //    y --^> |      | --^> b           /\
  //           +------+                y ---^> b
  //
  // We only apply this optimization if we don't increase the number of control
  // edges across device boundaries, e.g. in cases a) and b) if NoOp and
  // a and x, respectively, are on the same device. Control edges across device
  // boundaries require inter-device communication (Send/Recv pairs to be
  // inserted in the graph), which is very costly.
  //
  // We also remove identity nodes, subject to the same constraints on number of
  // resulting control edges and device boundary crossings:
  //
  // Case a)
  //          +----------+ ---> a       +---+ ---> a
  //    x --> | Identity | --^> b  ==>  | x | --^> b
  //          |          | ...          |   | ...
  //          +----------+ --^> c       +---+ --^> c
  //
  // Case b)
  //    x ---> +----------+ ---> a      x ---> +---+
  //    y --^> | Identity |        ==>  y --^> | a |
  //    ...    |          |               ...  |   |
  //    z --^> +----------+             z --^> +---+
  //
  // Case c)
  //           +----------+             x ---> +---+
  //    x ---> | Identity | ---> a ==>    \--^> | a |
  //    y --^> |          | --^> b        /\    +---+
  //           +----------+             y --^> b

  if (is_noop || ((is_identity || is_multi_input_identity) &&
                  SafeToRemoveIdentity(*node))) {
    const int num_inputs = node->input_size();
    // input_nodes[i] is the producer of node->input(i).
    std::vector<NodeDef*> input_nodes;
    for (int i = 0; i < num_inputs; ++i) {
      NodeDef* input_node = node_map_->GetNode(node->input(i));
      if (input_node == nullptr) {
        LOG(ERROR) << "Invalid input " << node->input(i);
        return;
      }
      input_nodes.push_back(input_node);
    }
    // Snapshot the consumers into a vector; the node map is modified while the
    // consumers are being rewired below.
    const auto& output_node_set = node_map_->GetOutputs(node_name);
    const std::vector<NodeDef*> output_nodes(output_node_set.begin(),
                                             output_node_set.end());

    if (!BypassingNodeIsBeneficial(*node, input_nodes, output_nodes)) {
      return;
    }

    VLOG(2) << "***** Rerouting input around\n" << node->DebugString();
    // Now remove the node and re-wire its inputs to its outputs.
    for (auto consumer : output_nodes) {
      bool updated_consumer = false;
      VLOG(2) << "consumer before:\n" << consumer->DebugString();
      // Remove dependency on node from consumer.
      for (int i = 0; i < num_inputs; ++i) {
        const NodeDef* input = input_nodes[i];
        // Forward dependency from input to consumer if it doesn't already
        // depend on it.
        if ((is_identity && i == 0) ||
            (is_multi_input_identity && !IsControlInput(node->input(i)))) {
          // Replace regular input from Identity node.
          string new_input;
          const string& input_to_forward = node->input(i);
          CHECK(!IsControlInput(input_to_forward));
          for (int j = 0; j < consumer->input_size(); ++j) {
            const TensorId old_input = ParseTensorName(consumer->input(j));
            if (old_input.node() == node_name) {
              if (old_input.index() == i) {
                // Regular input
                new_input = input_to_forward;
                node_map_->UpdateInput(consumer->name(), old_input.ToString(),
                                       new_input);
                consumer->set_input(j, new_input);
              } else if (old_input.index() == -1) {
                // Control dependency
                new_input = AsControlDependency(NodeName(input_to_forward));
                node_map_->UpdateInput(consumer->name(), old_input.ToString(),
                                       new_input);
                consumer->set_input(j, new_input);
              }
            }
          }
          updated_consumer = true;
        } else {
          // Forward dependency from input to consumer if it doesn't already
          // depend on it.
          if (node_map_->GetOutputs(input->name()).count(consumer) == 0) {
            consumer->add_input(AsControlDependency(input->name()));
            node_map_->AddOutput(input->name(), consumer->name());
            nodes_to_simplify->PushBack(node_to_idx_[input]);
            updated_consumer = true;
          }
        }
      }
      // Drop any remaining control edge from the bypassed node itself.
      updated_consumer |= RemoveControlInput(
          consumer, AsControlDependency(node_name), node_map_.get());
      if (updated_consumer) {
        nodes_to_simplify->PushBack(node_to_idx_[consumer]);
      }
      VLOG(2) << "consumer after:\n" << consumer->DebugString();
    }
    node_map_->RemoveOutputs(node_name);
    if (fetch_nodes_known_ &&
        nodes_to_preserve_.find(node_name) == nodes_to_preserve_.end()) {
      // Mark the node for deletion.
      nodes_to_delete->insert(node_idx);

      // Disconnect the node from its inputs to enable further optimizations.
      node_map_->RemoveInputs(node_name);
      node->clear_input();
    }
  }
}
459
CleanControlInputs()460 void DependencyOptimizer::CleanControlInputs() {
461 for (int i = 0; i < optimized_graph_->node_size(); ++i) {
462 DedupControlInputs(optimized_graph_->mutable_node(i));
463 }
464 }
465
OptimizeDependencies()466 Status DependencyOptimizer::OptimizeDependencies() {
467 SetVector<int> nodes_to_simplify;
468 std::set<int> nodes_to_delete;
469 for (int i = 0; i < optimized_graph_->node_size(); ++i) {
470 const NodeDef& node = optimized_graph_->node(i);
471 if (IsNoOp(node) || IsIdentity(node) || IsIdentityN(node) ||
472 IsConstant(node) || SafeToConvertToNoOp(node)) {
473 nodes_to_simplify.PushBack(i);
474 }
475 }
476 while (!nodes_to_simplify.Empty()) {
477 int node_to_simplify = nodes_to_simplify.PopBack();
478 // Discard nodes that were marked for deletion already.
479 while (nodes_to_delete.find(node_to_simplify) != nodes_to_delete.end()) {
480 node_to_simplify = nodes_to_simplify.PopBack();
481 }
482 OptimizeNode(node_to_simplify, &nodes_to_simplify, &nodes_to_delete);
483 }
484
485 if (fetch_nodes_known_) {
486 VLOG(1) << "Deleted " << nodes_to_delete.size() << " out of "
487 << optimized_graph_->node_size() << " nodes.";
488 EraseNodesFromGraph(nodes_to_delete, optimized_graph_);
489 node_map_.reset(new NodeMap(optimized_graph_));
490 BuildNodeToIdx();
491 }
492 return Status::OK();
493 }
494
495 namespace {
496
497 enum DistanceFromSource : uint8 { ZERO = 0, ONE = 1, TWO_OR_GREATER = 2 };
498
LongestPathsLowerBounds(int source,const std::pair<int,int> & target_range,const std::vector<std::vector<int>> & outputs,std::vector<DistanceFromSource> * longest_distance)499 void LongestPathsLowerBounds(
500 int source, const std::pair<int, int>& target_range,
501 const std::vector<std::vector<int>>& outputs,
502 std::vector<DistanceFromSource>* longest_distance) {
503 std::deque<int> queue;
504 queue.emplace_front(source);
505 while (!queue.empty()) {
506 int node = queue.front();
507 queue.pop_front();
508 for (int fanout : outputs[node]) {
509 // 1) Only nodes in the target range can be on paths from source to one of
510 // its control outputs.
511 // 2) Since we only need a lower bound on the longest distance, we can
512 // skip nodes for which we have already proven have a path of
513 // length > 1 from the source.
514 if (fanout >= target_range.first && fanout <= target_range.second &&
515 (*longest_distance)[fanout] != TWO_OR_GREATER) {
516 (*longest_distance)[fanout] =
517 (*longest_distance)[fanout] == ZERO ? ONE : TWO_OR_GREATER;
518 queue.emplace_front(fanout);
519 }
520 }
521 }
522 }
523
524 } // namespace
525
TransitiveReduction()526 Status DependencyOptimizer::TransitiveReduction() {
527 // PRECONDITION: optimized_graph_ must be sorted topologically.
528 const int num_nodes = optimized_graph_->node_size();
529 // Set up a compressed version of the graph to save a constant factor in the
530 // expensive algorithm below. Also cache the set of control outputs and the
531 // highest index of a target of any control output from each node.
532 int num_controls = 0;
533 std::vector<std::vector<int>> outputs(num_nodes);
534 std::vector<gtl::InlinedVector<std::pair<int, int>, 2>> control_outputs(
535 num_nodes);
536 // target_range[i] contains the range of node indices for which to compute
537 // longest paths starting from node i.
538 std::vector<std::pair<int, int>> target_range(num_nodes, {num_nodes, -1});
539 for (int node_idx = 0; node_idx < num_nodes; ++node_idx) {
540 const NodeDef& node = optimized_graph_->node(node_idx);
541 if (ModifiesFrameInfo(node) || !HasOpDef(node)) {
542 // Ignore function nodes and nodes that modify frame info.
543 continue;
544 }
545 for (int input_slot = 0; input_slot < node.input_size(); ++input_slot) {
546 const string& input = node.input(input_slot);
547 const NodeDef* input_node = node_map_->GetNode(input);
548 if (ModifiesFrameInfo(*input_node) || IsMerge(*input_node)) {
549 // Ignore edges from nodes that modify frame info and from Merge nodes,
550 // because we cannot know which of it's input paths executes.
551 continue;
552 }
553 const int input_node_idx = node_to_idx_[input_node];
554 outputs[input_node_idx].push_back(node_idx);
555 target_range[input_node_idx].first =
556 std::min(target_range[input_node_idx].first, node_idx);
557 if (IsControlInput(input)) {
558 ++num_controls;
559 control_outputs[input_node_idx].emplace_back(node_idx, input_slot);
560 target_range[input_node_idx].second =
561 std::max(target_range[input_node_idx].second, node_idx);
562 }
563 }
564 }
565
566 // Run the longest path in DAG algorithm for each source node that has control
567 // outputs. If, for any target node of a control output, there exists a path
568 // of length > 1, we can drop that control dependency.
569 int num_controls_removed = 0;
570 std::vector<DistanceFromSource> longest_distance(num_nodes);
571 // Map from target_index -> set of (input_slot, source_index), representing
572 // the control edges to remove. We sort them in reverse order by input slot,
573 // such that when we swap them out so we don't clobber the
574 // node(target).input() repeated field.
575 typedef std::pair<int, int> InputSlotAndSource;
576 absl::flat_hash_map<
577 int, std::set<InputSlotAndSource, std::greater<InputSlotAndSource>>>
578 control_edges_to_remove;
579 for (int source = 0; source < num_nodes; ++source) {
580 if (target_range[source].first >= target_range[source].second ||
581 target_range[source].second <= source) {
582 continue;
583 }
584 // Compute the set of nodes in the transitive fanout of source with
585 // topological sort index in [target_range.first : target_range.second]]
586 // to which there exists a path of length 2 or more from source.
587 std::fill(longest_distance.begin() + target_range[source].first,
588 longest_distance.begin() + target_range[source].second + 1, ZERO);
589 LongestPathsLowerBounds(source, target_range[source], outputs,
590 &longest_distance);
591
592 // If the longest path from source to target of a control dependency is
593 // longer than 1, there exists an alternate path, and we can eliminate the
594 // redundant direct control dependency.
595 for (const auto& control_output : control_outputs[source]) {
596 const int target = control_output.first;
597 if (longest_distance[target] == TWO_OR_GREATER) {
598 const int input_slot = control_output.second;
599 control_edges_to_remove[target].emplace(input_slot, source);
600 }
601 }
602 }
603 for (const auto& it : control_edges_to_remove) {
604 const int target = it.first;
605 NodeDef* target_node = optimized_graph_->mutable_node(target);
606 for (const InputSlotAndSource& slot_and_source : it.second) {
607 const int input_slot = slot_and_source.first;
608 const int source = slot_and_source.second;
609 const NodeDef& source_node = optimized_graph_->node(source);
610 CHECK_LT(input_slot, target_node->input_size());
611 target_node->mutable_input()->SwapElements(input_slot,
612 target_node->input_size() - 1);
613 node_map_->RemoveOutput(source_node.name(), target_node->name());
614 target_node->mutable_input()->RemoveLast();
615 ++num_controls_removed;
616 }
617 }
618 VLOG(1) << "Removed " << num_controls_removed << " out of " << num_controls
619 << " control dependencies";
620 return Status::OK();
621 }
622
BuildNodeToIdx()623 void DependencyOptimizer::BuildNodeToIdx() {
624 // Set up &node -> index map.
625 node_to_idx_.clear();
626 for (int i = 0; i < optimized_graph_->node_size(); ++i) {
627 const NodeDef& node = optimized_graph_->node(i);
628 node_to_idx_[&node] = i;
629 }
630 }
631
632 // Suppose there are cross-device control inputs to node C from multiple nodes
633 // that are located on another device, e.g., we have control edges:
634 // A->C, B->C
635 // where A and B are on device X and C is on device Y.
636 // We can reduce cross-device communication by introducing an intermediate
637 // NoOp node C' on device X and rewriting the control edges to:
638 // A->C', B->C', C' -> C
GroupCrossDeviceControlEdges(bool host_granularity)639 void DependencyOptimizer::GroupCrossDeviceControlEdges(bool host_granularity) {
640 VLOG(1)
641 << "DependencyOptimizer::GroupCrossDeviceControlEdges host_granularity="
642 << host_granularity;
643 const int num_nodes = optimized_graph_->node_size();
644 for (int i = 0; i < num_nodes; ++i) {
645 NodeDef* node = optimized_graph_->mutable_node(i);
646 if (node->device().empty()) continue;
647 string rest, node_device = node->device();
648 if (host_granularity) {
649 DeviceNameUtils::SplitDeviceName(node->device(), &node_device, &rest);
650 }
651
652 // Creates new noop nodes for devices on which multiple control inputs are
653 // located.
654
655 // Map keyed by device name to the newly introduced Noop node for that
656 // device. A nullptr value means that we have only seen a single node on
657 // that device.
658 std::map<string, NodeDef*> noops;
659 int num_noops = 0;
660 for (int j = 0; j < node->input_size(); ++j) {
661 if (IsControlInput(node->input(j))) {
662 const NodeDef* input = node_map_->GetNode(node->input(j));
663 if (input == nullptr || input->device().empty()) continue;
664 string input_device = input->device();
665 if (host_granularity) {
666 DeviceNameUtils::SplitDeviceName(input->device(), &input_device,
667 &rest);
668 }
669 if (input_device != node_device) {
670 VLOG(2) << "Cross-device " << node->name() << " " << input->device()
671 << " -> " << node->device();
672 auto emplace_result = noops.emplace(input_device, nullptr);
673 if (!emplace_result.second &&
674 emplace_result.first->second == nullptr) {
675 VLOG(2) << "Duplicate input device from " << node->name();
676 // This is the second cross-device control input from the same
677 // device. Creates an intermediate noop node on that device.
678 string group_name;
679 NodeDef* noop;
680 // Creates a fresh node name; there may be conflicting names from
681 // a previous iteration of the optimizer.
682 do {
683 group_name = AddPrefixToNodeName(
684 node->name(),
685 strings::StrCat("GroupCrossDeviceControlEdges_", num_noops));
686 noop = node_map_->GetNode(group_name);
687 ++num_noops;
688 } while (noop != nullptr);
689 noop = optimized_graph_->add_node();
690 noop->set_name(group_name);
691 noop->set_device(input->device());
692 noop->set_op("NoOp");
693 node_map_->AddNode(noop->name(), noop);
694 emplace_result.first->second = noop;
695 VLOG(1) << "GroupCrossDeviceControlEdges: Added "
696 << SummarizeNodeDef(*noop);
697 }
698 }
699 }
700 }
701
702 // Reroute existing control edges to go via the newly introduced NoOp nodes.
703 int pos = 0;
704 while (pos < node->input_size()) {
705 const string& input_name = node->input(pos);
706 if (IsControlInput(input_name)) {
707 NodeDef* input = node_map_->GetNode(input_name);
708 if (input == nullptr) {
709 ++pos;
710 } else {
711 string input_device = input->device();
712 if (host_granularity) {
713 DeviceNameUtils::SplitDeviceName(input->device(), &input_device,
714 &rest);
715 }
716 auto it = noops.find(input_device);
717 if (it == noops.end() || it->second == nullptr) {
718 ++pos;
719 } else {
720 VLOG(2) << "Rewriting input from " << input_name;
721 node->mutable_input()->SwapElements(pos, node->input_size() - 1);
722 node->mutable_input()->RemoveLast();
723 it->second->add_input(AsControlDependency(*input));
724 node_map_->UpdateOutput(input_name, node->name(),
725 it->second->name());
726 }
727 }
728 } else {
729 ++pos;
730 }
731 }
732 for (const auto& entry : noops) {
733 if (entry.second) {
734 node->add_input(AsControlDependency(*entry.second));
735 node_map_->AddOutput(entry.second->name(), node->name());
736 }
737 }
738 }
739 }
740
Optimize(Cluster * cluster,const GrapplerItem & item,GraphDef * optimized_graph)741 Status DependencyOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
742 GraphDef* optimized_graph) {
743 optimized_graph_ = optimized_graph;
744 *optimized_graph_ = item.graph;
745 nodes_to_preserve_ = item.NodesToPreserve();
746 fetch_nodes_known_ = !item.fetch.empty();
747 CleanControlInputs();
748
749 const int num_iterations = 2;
750 for (int iteration = 0; iteration < num_iterations; ++iteration) {
751 GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
752 Status topo_sort_status;
753 // Perform topological sort to prepare the graph for transitive reduction.
754 topo_sort_status = TopologicalSort(optimized_graph_);
755 // Set up index-based graph datastructures to speed up analysis steps below.
756 node_map_.reset(new NodeMap(optimized_graph_));
757 BuildNodeToIdx();
758
759 if (topo_sort_status.ok()) {
760 // Remove redundant control dependencies.
761 TF_RETURN_IF_ERROR(TransitiveReduction());
762 } else {
763 LOG(ERROR) << "Iteration = " << iteration
764 << ", topological sort failed with message: "
765 << topo_sort_status.error_message();
766 }
767 // Turn nodes with only control outputs into NoOps, prune NoOp and Identity
768 // nodes.
769 TF_RETURN_IF_ERROR(OptimizeDependencies());
770
771 // Dedup control inputs.
772 CleanControlInputs();
773
774 // Merge multiple control edges from the same device.
775 GroupCrossDeviceControlEdges(/*host_granularity=*/false);
776
777 // Merge control edges from the same host to reduce RPC traffic.
778 GroupCrossDeviceControlEdges(/*host_granularity=*/true);
779 }
780
781 return Status::OK();
782 }
783
Feedback(Cluster *,const GrapplerItem &,const GraphDef &,double)784 void DependencyOptimizer::Feedback(Cluster* /*cluster*/,
785 const GrapplerItem& /*item*/,
786 const GraphDef& /*optimized_graph*/,
787 double /*result*/) {
788 // Nothing to do for DependencyOptimizer.
789 }
790
791 } // end namespace grappler
792 } // end namespace tensorflow
793