1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/optimizers/dependency_optimizer.h"
17 
18 #include <unordered_set>
19 
20 #include "absl/container/flat_hash_map.h"
21 #include "tensorflow/core/framework/node_def.pb.h"
22 #include "tensorflow/core/framework/node_def_util.h"
23 #include "tensorflow/core/framework/op.h"
24 #include "tensorflow/core/grappler/costs/graph_properties.h"
25 #include "tensorflow/core/grappler/grappler_item.h"
26 #include "tensorflow/core/grappler/op_types.h"
27 #include "tensorflow/core/grappler/optimizers/constant_folding.h"
28 #include "tensorflow/core/grappler/utils.h"
29 #include "tensorflow/core/grappler/utils/topological_sort.h"
30 #include "tensorflow/core/lib/core/errors.h"
31 #include "tensorflow/core/lib/core/stringpiece.h"
32 #include "tensorflow/core/lib/gtl/inlined_vector.h"
33 #include "tensorflow/core/lib/strings/str_util.h"
34 #include "tensorflow/core/lib/strings/strcat.h"
35 #include "tensorflow/core/util/device_name_utils.h"
36 
37 namespace tensorflow {
38 namespace grappler {
39 
40 namespace {
41 
RemoveControlInput(NodeDef * node,const string & control_input_to_remove,NodeMap * node_map)42 bool RemoveControlInput(NodeDef* node, const string& control_input_to_remove,
43                         NodeMap* node_map) {
44   for (int pos = node->input_size() - 1; pos >= 0; --pos) {
45     const string& input = node->input(pos);
46     if (input[0] != '^') break;
47     if (input == control_input_to_remove) {
48       node->mutable_input()->SwapElements(pos, node->input_size() - 1);
49       node->mutable_input()->RemoveLast();
50       node_map->RemoveOutput(NodeName(input), node->name());
51       return true;
52     }
53   }
54   return false;
55 }
56 
57 }  // namespace
58 
SafeToRemoveIdentity(const NodeDef & node) const59 bool DependencyOptimizer::SafeToRemoveIdentity(const NodeDef& node) const {
60   if (!IsIdentity(node) && !IsIdentityN(node)) {
61     return true;
62   }
63 
64   if (nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end()) {
65     return false;
66   }
67   if (!fetch_nodes_known_) {
68     // The output values of this node may be needed.
69     return false;
70   }
71   const NodeDef* input = node_map_->GetNode(NodeName(node.input(0)));
72   CHECK(input != nullptr) << "node = " << node.name()
73                           << " input = " << node.input(0);
74   // Don't remove Identity nodes corresponding to Variable reads or following
75   // Recv.
76   if (IsVariable(*input) || IsRecv(*input)) {
77     return false;
78   }
79   for (const auto& consumer : node_map_->GetOutputs(node.name())) {
80     if (node.input_size() > 1 && (IsRetval(*consumer) || IsMerge(*consumer))) {
81       return false;
82     }
83     if (IsSwitch(*input)) {
84       for (const string& consumer_input : consumer->input()) {
85         if (consumer_input == AsControlDependency(node.name())) {
86           return false;
87         }
88       }
89     }
90   }
91   return true;
92 }
93 
SafeToConvertToNoOp(const NodeDef & node) const94 bool DependencyOptimizer::SafeToConvertToNoOp(const NodeDef& node) const {
95   if (HasRegularOutputs(node, *node_map_)) {
96     // The output values of this node may be needed.
97     VLOG(3) << "Not safe to convert '" << node.name()
98             << " to NoOp. Node has outputs.";
99     return false;
100   }
101   if (!fetch_nodes_known_) {
102     VLOG(3) << "Not safe to convert '" << node.name()
103             << " to NoOp. Fetches unknown.";
104     return false;
105   }
106   if (nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end()) {
107     VLOG(3) << "Not safe to convert to NoOp: " << node.name()
108             << " is in preserve set.";
109     return false;
110   }
111   if (IsMerge(node) || IsSwitch(node) || ModifiesFrameInfo(node)) {
112     VLOG(3) << "Not safe to convert '" << node.name()
113             << " to NoOp. Node modifies frame info.";
114     return false;
115   }
116   // Ops reading variables are marked as stateful, but are safe to remove if
117   // redundant.
118   static const absl::flat_hash_set<string>* gather_ops =
119       new absl::flat_hash_set<string>{"Gather", "GatherV2", "GatherNd",
120                                       "ResourceGather", "ResourceGatherNd"};
121   const bool is_variable_read =
122       IsReadVariableOp(node) || IsReadVariablesOp(node) ||
123       gather_ops->find(node.op()) != gather_ops->end();
124   if (!is_variable_read && !IsFreeOfSideEffect(node)) {
125     VLOG(3) << "Not safe to convert '" << node.name()
126             << " to NoOp. Node has side effect.";
127     return false;
128   }
129   if (node.op().rfind("Submodel", 0) == 0) {
130     return false;
131   }
132   const OpDef* op_def = nullptr;
133   Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def);
134   if (!status.ok() || op_def->output_arg_size() == 0) {
135     return false;
136   }
137   const std::unordered_set<string> do_not_rewrite_ops{
138       "Assert",     "CheckNumerics",         "_Retval",
139       "_Arg",       "_ParallelConcatUpdate", "TPUExecute",
140       "TPUCompile", "ControlTrigger"};
141   if (do_not_rewrite_ops.find(node.op()) != do_not_rewrite_ops.end()) {
142     return false;
143   }
144   if (!SafeToRemoveIdentity(node)) {
145     return false;
146   }
147   return true;
148 }
149 
NumEdgesIfBypassed(const NodeDef & node,const std::vector<NodeDef * > & output_nodes) const150 int DependencyOptimizer::NumEdgesIfBypassed(
151     const NodeDef& node, const std::vector<NodeDef*>& output_nodes) const {
152   const bool is_multi_input_identity_n =
153       IsIdentityN(node) && !IsIdentityNSingleInput(node);
154   const int num_outputs = output_nodes.size();
155   const int num_inputs = node.input_size();
156 
157   if (is_multi_input_identity_n) {
158     // multi-input identity_n with input/output control dependencies will likely
159     // increase number of edges after optimization.
160     int num_edges_if_bypassed(0);
161     for (const string& input_node_name : node.input()) {
162       if (IsControlInput(input_node_name)) {
163         num_edges_if_bypassed += num_outputs;
164       } else {
165         ++num_edges_if_bypassed;
166       }
167     }
168 
169     for (auto consumer : output_nodes) {
170       for (int j = 0; j < consumer->input_size(); ++j) {
171         const TensorId consumer_input = ParseTensorName(consumer->input(j));
172         if (consumer_input.node() == node.name()) {
173           if (IsControlInput(consumer_input)) {
174             num_edges_if_bypassed += num_inputs;
175           } else {
176             ++num_edges_if_bypassed;
177           }
178         }
179       }
180     }
181     return num_edges_if_bypassed;
182   } else {
183     return num_inputs * num_outputs;
184   }
185 }
186 
BypassingNodeIsBeneficial(const NodeDef & node,const std::vector<NodeDef * > & input_nodes,const std::vector<NodeDef * > & output_nodes) const187 bool DependencyOptimizer::BypassingNodeIsBeneficial(
188     const NodeDef& node, const std::vector<NodeDef*>& input_nodes,
189     const std::vector<NodeDef*>& output_nodes) const {
190   const bool is_identity = IsIdentity(node) || IsIdentityNSingleInput(node);
191   const bool is_multi_input_identity_n =
192       IsIdentityN(node) && !IsIdentityNSingleInput(node);
193   const int num_outputs = output_nodes.size();
194   const int num_inputs = node.input_size();
195 
196   if (NumEdgesIfBypassed(node, output_nodes) > num_inputs + num_outputs) {
197     return false;
198   }
199 
200   // Make sure that we don't increase the number of edges that cross
201   // device boundaries.
202   if ((num_inputs == 1 && num_outputs > 1 &&
203        input_nodes[0]->device() != node.device()) ||
204       (num_inputs > 1 && num_outputs == 1 &&
205        output_nodes[0]->device() != node.device())) {
206     return false;
207   }
208 
209   // TODO(rmlarsen): Not all device crossings are equally expensive.
210   // Assign a cost to each based on device affinity and compute a
211   // cost before and after.
212   const string& node_dev = node.device();
213   int num_cross_in = 0;
214   for (NodeDef* input_node : input_nodes) {
215     num_cross_in += static_cast<int>(input_node->device() != node_dev);
216   }
217   int num_cross_out = 0;
218   for (NodeDef* output_node : output_nodes) {
219     num_cross_out += static_cast<int>(output_node->device() != node_dev);
220   }
221 
222   // Make sure we do not increase the number of device crossings.
223   const int num_cross_before = num_cross_in + num_cross_out;
224   int num_cross_after = 0;
225   for (NodeDef* input_node : input_nodes) {
226     for (NodeDef* output_node : output_nodes) {
227       num_cross_after +=
228           static_cast<int>(input_node->device() != output_node->device());
229     }
230   }
231   if (num_cross_after > num_cross_before) {
232     return false;
233   }
234 
235   if ((is_identity || is_multi_input_identity_n) && num_cross_in > 0 &&
236       num_cross_out > 0 && num_cross_after > 0) {
237     // This identity node follows a device crossing, so it might be
238     // following a _Recv node after partitioning. Do not remove such nodes,
239     // unless they only have consumers on the same device as themselves.
240     return false;
241   }
242 
243   return true;
244 }
245 
// Attempts to simplify the node at index `node_idx` of optimized_graph_.
// Three rewrites are tried, in this order, and the first one that applies
// ends the call:
//  1. A Constant with no inputs drops every outgoing control edge (it always
//     runs early, so those edges carry no ordering information). It is marked
//     for deletion if it is left without fanouts.
//  2. A node accepted by SafeToConvertToNoOp becomes a NoOp. Its regular
//     inputs are turned into control inputs and its attributes are erased.
//  3. A NoOp, or an Identity/IdentityN accepted by SafeToRemoveIdentity, is
//     bypassed (its inputs rewired to its consumers) when
//     BypassingNodeIsBeneficial agrees.
// Nodes whose fanin/fanout changed are pushed onto `nodes_to_simplify` so they
// are revisited. Nodes that became dead are added to `nodes_to_delete`, but
// only when the fetch nodes are known and the node is not in the preserve set.
void DependencyOptimizer::OptimizeNode(int node_idx,
                                       SetVector<int>* nodes_to_simplify,
                                       std::set<int>* nodes_to_delete) {
  NodeDef* node = optimized_graph_->mutable_node(node_idx);
  const bool is_noop = IsNoOp(*node);
  const bool is_identity = IsIdentity(*node) || IsIdentityNSingleInput(*node);
  const bool is_multi_input_identity =
      IsIdentityN(*node) && !IsIdentityNSingleInput(*node);
  // Copy the name: `node` is mutated below and the name is used as a key.
  const string node_name = node->name();
  // Constant nodes with no input control dependency are always executed early,
  // so we can prune all their output control dependencies.
  if (IsConstant(*node) && node->input_size() == 0) {
    const auto output_nodes = node_map_->GetOutputs(node_name);
    for (NodeDef* fanout : output_nodes) {
      bool optimize_fanout = false;
      bool data_connection = false;
      // Iterate backwards so that swap-with-last removal only moves entries
      // that were already visited.
      for (int i = fanout->input_size() - 1; i >= 0; --i) {
        const TensorId input_tensor = ParseTensorName(fanout->input(i));
        if (input_tensor.node() == node_name) {
          if (input_tensor.index() < 0) {
            fanout->mutable_input()->SwapElements(i, fanout->input_size() - 1);
            fanout->mutable_input()->RemoveLast();
            optimize_fanout = true;
          } else {
            data_connection = true;
          }
        }
      }
      if (optimize_fanout) {
        nodes_to_simplify->PushBack(node_to_idx_[fanout]);
        // Keep the NodeMap edge if the fanout still reads data from us.
        if (!data_connection) {
          node_map_->RemoveOutput(node_name, fanout->name());
        }
      }
    }
    if (node_map_->GetOutputs(node_name).empty() && fetch_nodes_known_ &&
        nodes_to_preserve_.find(node_name) == nodes_to_preserve_.end()) {
      // Mark the node for deletion.
      nodes_to_delete->insert(node_to_idx_[node]);
    }
    return;
  }

  // Change ops that only have control dependencies as outputs to NoOps.
  if (!is_noop && SafeToConvertToNoOp(*node)) {
    VLOG(2) << "***** Replacing  " << node_name << " (" << node->op()
            << ") with NoOp.";
    // The outputs of this node are not consumed. Replace its inputs with
    // control dependencies and replace the op itself with the NoOp op.
    std::unordered_set<string> ctrl_inputs;
    int pos = 0;
    while (pos < node->input_size()) {
      // Copied by value because slot `pos` is overwritten or removed below.
      const string old_input = node->input(pos);
      if (IsControlInput(old_input)) {
        if (!ctrl_inputs.insert(old_input).second) {
          // We found a duplicate control input. Remove it.
          // `pos` is not advanced: the former last input now occupies it.
          node->mutable_input()->SwapElements(pos, node->input_size() - 1);
          node->mutable_input()->RemoveLast();
        } else {
          ++pos;
        }
        continue;
      }
      // Replace a normal input with a control input.
      const string ctrl_input = ConstantFolding::AddControlDependency(
          old_input, optimized_graph_, node_map_.get());
      ctrl_inputs.insert(ctrl_input);
      node->set_input(pos, ctrl_input);
      node_map_->UpdateInput(node_name, old_input, ctrl_input);
      const NodeDef* old_input_node = node_map_->GetNode(old_input);
      nodes_to_simplify->PushBack(node_to_idx_[old_input_node]);
      ++pos;
    }
    node->set_op("NoOp");
    EraseRegularNodeAttributes(node);
    DedupControlInputs(node);
    // Revisit the new NoOp: it may now be removable by the bypass rewrite below.
    nodes_to_simplify->PushBack(node_to_idx_[node]);
    return;
  }

  // Remove NoOp nodes if the product of their fan-in and fan-out is less than
  // or equal to the sum of the fan-in and fan-out. The non-trivial rewrites
  // take the following form:
  //
  // Case a)
  //    x --^> +------+                x --^> +---+
  //    y --^> | NoOp | --^> a   ==>   y --^> | a |
  //    ...    |      |                  ...  |   |
  //    z --^> +------+                z --^> +---+
  //
  // Case b)
  //           +------+ --^> a         +---+ --^> a
  //    x --^> | NoOp | --^> b  ==>    | x | --^> b
  //           |      | ...            |   | ...
  //           +------+ --^> c         +---+ --^> c
  // Case c)
  //           +------+                x ---^> a
  //    x --^> | NoOp | --^> a  ==>      \/
  //    y --^> |      | --^> b           /\
  //           +------+                y ---^> b
  //
  // We only apply this optimization if we don't increase the number of control
  // edges across device boundaries, e.g. in cases a) and b) if NoOp and
  // a and x, respectively, are on the same device. Control edges across device
  // boundaries require inter-device communication (Send/Recv pairs to be
  // inserted in the graph), which is very costly.
  //
  // We also remove identity nodes, subject to the same constraints on number of
  // resulting control edges and device boundary crossings:
  //
  // Case a)
  //          +----------+ ---> a       +---+ ---> a
  //    x --> | Identity | --^> b  ==>  | x | --^> b
  //          |          | ...          |   | ...
  //          +----------+ --^> c       +---+ --^> c
  //
  // Case b)
  //    x ---> +----------+ ---> a      x ---> +---+
  //    y --^> | Identity |        ==>  y --^> | a |
  //    ...    |          |               ...  |   |
  //    z --^> +----------+             z --^> +---+
  //
  // Case c)
  //           +----------+             x ---> +---+
  //    x ---> | Identity | ---> a ==>   \--^> | a |
  //    y --^> |          | --^> b       /\    +---+
  //           +----------+             y --^> b

  if (is_noop || ((is_identity || is_multi_input_identity) &&
                  SafeToRemoveIdentity(*node))) {
    const int num_inputs = node->input_size();
    // Resolve every fanin up front; give up on this node if any is dangling.
    std::vector<NodeDef*> input_nodes;
    for (int i = 0; i < num_inputs; ++i) {
      NodeDef* input_node = node_map_->GetNode(node->input(i));
      if (input_node == nullptr) {
        LOG(ERROR) << "Invalid input " << node->input(i);
        return;
      }
      input_nodes.push_back(input_node);
    }
    // Snapshot the fanouts: the NodeMap's output set is modified while
    // rewiring consumers below.
    const auto& output_node_set = node_map_->GetOutputs(node_name);
    const std::vector<NodeDef*> output_nodes(output_node_set.begin(),
                                             output_node_set.end());

    if (!BypassingNodeIsBeneficial(*node, input_nodes, output_nodes)) {
      return;
    }

    VLOG(2) << "***** Rerouting input around\n" << node->DebugString();
    // Now remove the node and re-wire its inputs to its outputs.
    for (auto consumer : output_nodes) {
      bool updated_consumer = false;
      VLOG(2) << "consumer before:\n" << consumer->DebugString();
      // Remove dependency on node from consumer.
      for (int i = 0; i < num_inputs; ++i) {
        const NodeDef* input = input_nodes[i];
        // Forward dependency from input to consumer if it doesn't already
        // depend on it.
        // Data inputs are forwarded slot-for-slot: input 0 of an Identity, or
        // every regular input `i` of a multi-input IdentityN (consumer reads
        // output `i`). Everything else is forwarded as a control edge in the
        // else-branch.
        if ((is_identity && i == 0) ||
            (is_multi_input_identity && !IsControlInput(node->input(i)))) {
          // Replace regular input from Identity node.
          string new_input;
          const string& input_to_forward = node->input(i);
          CHECK(!IsControlInput(input_to_forward));
          for (int j = 0; j < consumer->input_size(); ++j) {
            const TensorId old_input = ParseTensorName(consumer->input(j));
            if (old_input.node() == node_name) {
              if (old_input.index() == i) {
                // Regular input
                new_input = input_to_forward;
                node_map_->UpdateInput(consumer->name(), old_input.ToString(),
                                       new_input);
                consumer->set_input(j, new_input);
              } else if (old_input.index() == -1) {
                // Control dependency
                new_input = AsControlDependency(NodeName(input_to_forward));
                node_map_->UpdateInput(consumer->name(), old_input.ToString(),
                                       new_input);
                consumer->set_input(j, new_input);
              }
            }
          }
          updated_consumer = true;
        } else {
          // Forward dependency from input to consumer if it doesn't already
          // depend on it.
          if (node_map_->GetOutputs(input->name()).count(consumer) == 0) {
            consumer->add_input(AsControlDependency(input->name()));
            node_map_->AddOutput(input->name(), consumer->name());
            nodes_to_simplify->PushBack(node_to_idx_[input]);
            updated_consumer = true;
          }
        }
      }
      // Drop any remaining "^node_name" control input from the consumer.
      updated_consumer |= RemoveControlInput(
          consumer, AsControlDependency(node_name), node_map_.get());
      if (updated_consumer) {
        nodes_to_simplify->PushBack(node_to_idx_[consumer]);
      }
      VLOG(2) << "consumer after:\n" << consumer->DebugString();
    }
    node_map_->RemoveOutputs(node_name);
    if (fetch_nodes_known_ &&
        nodes_to_preserve_.find(node_name) == nodes_to_preserve_.end()) {
      // Mark the node for deletion.
      nodes_to_delete->insert(node_idx);

      // Disconnect the node from its inputs to enable further optimizations.
      node_map_->RemoveInputs(node_name);
      node->clear_input();
    }
  }
}
459 
CleanControlInputs()460 void DependencyOptimizer::CleanControlInputs() {
461   for (int i = 0; i < optimized_graph_->node_size(); ++i) {
462     DedupControlInputs(optimized_graph_->mutable_node(i));
463   }
464 }
465 
OptimizeDependencies()466 Status DependencyOptimizer::OptimizeDependencies() {
467   SetVector<int> nodes_to_simplify;
468   std::set<int> nodes_to_delete;
469   for (int i = 0; i < optimized_graph_->node_size(); ++i) {
470     const NodeDef& node = optimized_graph_->node(i);
471     if (IsNoOp(node) || IsIdentity(node) || IsIdentityN(node) ||
472         IsConstant(node) || SafeToConvertToNoOp(node)) {
473       nodes_to_simplify.PushBack(i);
474     }
475   }
476   while (!nodes_to_simplify.Empty()) {
477     int node_to_simplify = nodes_to_simplify.PopBack();
478     // Discard nodes that were marked for deletion already.
479     while (nodes_to_delete.find(node_to_simplify) != nodes_to_delete.end()) {
480       node_to_simplify = nodes_to_simplify.PopBack();
481     }
482     OptimizeNode(node_to_simplify, &nodes_to_simplify, &nodes_to_delete);
483   }
484 
485   if (fetch_nodes_known_) {
486     VLOG(1) << "Deleted " << nodes_to_delete.size() << " out of "
487             << optimized_graph_->node_size() << " nodes.";
488     EraseNodesFromGraph(nodes_to_delete, optimized_graph_);
489     node_map_.reset(new NodeMap(optimized_graph_));
490     BuildNodeToIdx();
491   }
492   return Status::OK();
493 }
494 
495 namespace {
496 
497 enum DistanceFromSource : uint8 { ZERO = 0, ONE = 1, TWO_OR_GREATER = 2 };
498 
LongestPathsLowerBounds(int source,const std::pair<int,int> & target_range,const std::vector<std::vector<int>> & outputs,std::vector<DistanceFromSource> * longest_distance)499 void LongestPathsLowerBounds(
500     int source, const std::pair<int, int>& target_range,
501     const std::vector<std::vector<int>>& outputs,
502     std::vector<DistanceFromSource>* longest_distance) {
503   std::deque<int> queue;
504   queue.emplace_front(source);
505   while (!queue.empty()) {
506     int node = queue.front();
507     queue.pop_front();
508     for (int fanout : outputs[node]) {
509       // 1) Only nodes in the target range can be on paths from source to one of
510       //    its control outputs.
511       // 2) Since we only need a lower bound on the longest distance, we can
512       //    skip nodes for which we have already proven have a path of
513       //    length > 1 from the source.
514       if (fanout >= target_range.first && fanout <= target_range.second &&
515           (*longest_distance)[fanout] != TWO_OR_GREATER) {
516         (*longest_distance)[fanout] =
517             (*longest_distance)[fanout] == ZERO ? ONE : TWO_OR_GREATER;
518         queue.emplace_front(fanout);
519       }
520     }
521   }
522 }
523 
524 }  // namespace
525 
TransitiveReduction()526 Status DependencyOptimizer::TransitiveReduction() {
527   // PRECONDITION: optimized_graph_ must be sorted topologically.
528   const int num_nodes = optimized_graph_->node_size();
529   // Set up a compressed version of the graph to save a constant factor in the
530   // expensive algorithm below. Also cache the set of control outputs and the
531   // highest index of a target of any control output from each node.
532   int num_controls = 0;
533   std::vector<std::vector<int>> outputs(num_nodes);
534   std::vector<gtl::InlinedVector<std::pair<int, int>, 2>> control_outputs(
535       num_nodes);
536   // target_range[i] contains the range of node indices for which to compute
537   // longest paths starting from node i.
538   std::vector<std::pair<int, int>> target_range(num_nodes, {num_nodes, -1});
539   for (int node_idx = 0; node_idx < num_nodes; ++node_idx) {
540     const NodeDef& node = optimized_graph_->node(node_idx);
541     if (ModifiesFrameInfo(node) || !HasOpDef(node)) {
542       // Ignore function nodes and nodes that modify frame info.
543       continue;
544     }
545     for (int input_slot = 0; input_slot < node.input_size(); ++input_slot) {
546       const string& input = node.input(input_slot);
547       const NodeDef* input_node = node_map_->GetNode(input);
548       if (ModifiesFrameInfo(*input_node) || IsMerge(*input_node)) {
549         // Ignore edges from nodes that modify frame info and from Merge nodes,
550         // because we cannot know which of it's input paths executes.
551         continue;
552       }
553       const int input_node_idx = node_to_idx_[input_node];
554       outputs[input_node_idx].push_back(node_idx);
555       target_range[input_node_idx].first =
556           std::min(target_range[input_node_idx].first, node_idx);
557       if (IsControlInput(input)) {
558         ++num_controls;
559         control_outputs[input_node_idx].emplace_back(node_idx, input_slot);
560         target_range[input_node_idx].second =
561             std::max(target_range[input_node_idx].second, node_idx);
562       }
563     }
564   }
565 
566   // Run the longest path in DAG algorithm for each source node that has control
567   // outputs. If, for any target node of a control output, there exists a path
568   // of length > 1, we can drop that control dependency.
569   int num_controls_removed = 0;
570   std::vector<DistanceFromSource> longest_distance(num_nodes);
571   // Map from target_index -> set of (input_slot, source_index), representing
572   // the control edges to remove. We sort them in reverse order by input slot,
573   // such that when we swap them out so we don't clobber the
574   // node(target).input() repeated field.
575   typedef std::pair<int, int> InputSlotAndSource;
576   absl::flat_hash_map<
577       int, std::set<InputSlotAndSource, std::greater<InputSlotAndSource>>>
578       control_edges_to_remove;
579   for (int source = 0; source < num_nodes; ++source) {
580     if (target_range[source].first >= target_range[source].second ||
581         target_range[source].second <= source) {
582       continue;
583     }
584     // Compute the set of nodes in the transitive fanout of source with
585     // topological sort index in [target_range.first : target_range.second]]
586     // to which there exists a path of length 2 or more from source.
587     std::fill(longest_distance.begin() + target_range[source].first,
588               longest_distance.begin() + target_range[source].second + 1, ZERO);
589     LongestPathsLowerBounds(source, target_range[source], outputs,
590                             &longest_distance);
591 
592     // If the longest path from source to target of a control dependency is
593     // longer than 1, there exists an alternate path, and we can eliminate the
594     // redundant direct control dependency.
595     for (const auto& control_output : control_outputs[source]) {
596       const int target = control_output.first;
597       if (longest_distance[target] == TWO_OR_GREATER) {
598         const int input_slot = control_output.second;
599         control_edges_to_remove[target].emplace(input_slot, source);
600       }
601     }
602   }
603   for (const auto& it : control_edges_to_remove) {
604     const int target = it.first;
605     NodeDef* target_node = optimized_graph_->mutable_node(target);
606     for (const InputSlotAndSource& slot_and_source : it.second) {
607       const int input_slot = slot_and_source.first;
608       const int source = slot_and_source.second;
609       const NodeDef& source_node = optimized_graph_->node(source);
610       CHECK_LT(input_slot, target_node->input_size());
611       target_node->mutable_input()->SwapElements(input_slot,
612                                                  target_node->input_size() - 1);
613       node_map_->RemoveOutput(source_node.name(), target_node->name());
614       target_node->mutable_input()->RemoveLast();
615       ++num_controls_removed;
616     }
617   }
618   VLOG(1) << "Removed " << num_controls_removed << " out of " << num_controls
619           << " control dependencies";
620   return Status::OK();
621 }
622 
BuildNodeToIdx()623 void DependencyOptimizer::BuildNodeToIdx() {
624   // Set up &node -> index map.
625   node_to_idx_.clear();
626   for (int i = 0; i < optimized_graph_->node_size(); ++i) {
627     const NodeDef& node = optimized_graph_->node(i);
628     node_to_idx_[&node] = i;
629   }
630 }
631 
632 // Suppose there are cross-device control inputs to node C from multiple nodes
633 // that are located on another device, e.g., we have control edges:
634 // A->C, B->C
635 // where A and B are on device X and C is on device Y.
636 // We can reduce cross-device communication by introducing an intermediate
637 // NoOp node C' on device X and rewriting the control edges to:
638 // A->C', B->C', C' -> C
GroupCrossDeviceControlEdges(bool host_granularity)639 void DependencyOptimizer::GroupCrossDeviceControlEdges(bool host_granularity) {
640   VLOG(1)
641       << "DependencyOptimizer::GroupCrossDeviceControlEdges host_granularity="
642       << host_granularity;
643   const int num_nodes = optimized_graph_->node_size();
644   for (int i = 0; i < num_nodes; ++i) {
645     NodeDef* node = optimized_graph_->mutable_node(i);
646     if (node->device().empty()) continue;
647     string rest, node_device = node->device();
648     if (host_granularity) {
649       DeviceNameUtils::SplitDeviceName(node->device(), &node_device, &rest);
650     }
651 
652     // Creates new noop nodes for devices on which multiple control inputs are
653     // located.
654 
655     // Map keyed by device name to the newly introduced Noop node for that
656     // device. A nullptr value means that we have only seen a single node on
657     // that device.
658     std::map<string, NodeDef*> noops;
659     int num_noops = 0;
660     for (int j = 0; j < node->input_size(); ++j) {
661       if (IsControlInput(node->input(j))) {
662         const NodeDef* input = node_map_->GetNode(node->input(j));
663         if (input == nullptr || input->device().empty()) continue;
664         string input_device = input->device();
665         if (host_granularity) {
666           DeviceNameUtils::SplitDeviceName(input->device(), &input_device,
667                                            &rest);
668         }
669         if (input_device != node_device) {
670           VLOG(2) << "Cross-device " << node->name() << " " << input->device()
671                   << " -> " << node->device();
672           auto emplace_result = noops.emplace(input_device, nullptr);
673           if (!emplace_result.second &&
674               emplace_result.first->second == nullptr) {
675             VLOG(2) << "Duplicate input device from " << node->name();
676             // This is the second cross-device control input from the same
677             // device. Creates an intermediate noop node on that device.
678             string group_name;
679             NodeDef* noop;
680             // Creates a fresh node name; there may be conflicting names from
681             // a previous iteration of the optimizer.
682             do {
683               group_name = AddPrefixToNodeName(
684                   node->name(),
685                   strings::StrCat("GroupCrossDeviceControlEdges_", num_noops));
686               noop = node_map_->GetNode(group_name);
687               ++num_noops;
688             } while (noop != nullptr);
689             noop = optimized_graph_->add_node();
690             noop->set_name(group_name);
691             noop->set_device(input->device());
692             noop->set_op("NoOp");
693             node_map_->AddNode(noop->name(), noop);
694             emplace_result.first->second = noop;
695             VLOG(1) << "GroupCrossDeviceControlEdges: Added "
696                     << SummarizeNodeDef(*noop);
697           }
698         }
699       }
700     }
701 
702     // Reroute existing control edges to go via the newly introduced NoOp nodes.
703     int pos = 0;
704     while (pos < node->input_size()) {
705       const string& input_name = node->input(pos);
706       if (IsControlInput(input_name)) {
707         NodeDef* input = node_map_->GetNode(input_name);
708         if (input == nullptr) {
709           ++pos;
710         } else {
711           string input_device = input->device();
712           if (host_granularity) {
713             DeviceNameUtils::SplitDeviceName(input->device(), &input_device,
714                                              &rest);
715           }
716           auto it = noops.find(input_device);
717           if (it == noops.end() || it->second == nullptr) {
718             ++pos;
719           } else {
720             VLOG(2) << "Rewriting input from " << input_name;
721             node->mutable_input()->SwapElements(pos, node->input_size() - 1);
722             node->mutable_input()->RemoveLast();
723             it->second->add_input(AsControlDependency(*input));
724             node_map_->UpdateOutput(input_name, node->name(),
725                                     it->second->name());
726           }
727         }
728       } else {
729         ++pos;
730       }
731     }
732     for (const auto& entry : noops) {
733       if (entry.second) {
734         node->add_input(AsControlDependency(*entry.second));
735         node_map_->AddOutput(entry.second->name(), node->name());
736       }
737     }
738   }
739 }
740 
Optimize(Cluster * cluster,const GrapplerItem & item,GraphDef * optimized_graph)741 Status DependencyOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
742                                      GraphDef* optimized_graph) {
743   optimized_graph_ = optimized_graph;
744   *optimized_graph_ = item.graph;
745   nodes_to_preserve_ = item.NodesToPreserve();
746   fetch_nodes_known_ = !item.fetch.empty();
747   CleanControlInputs();
748 
749   const int num_iterations = 2;
750   for (int iteration = 0; iteration < num_iterations; ++iteration) {
751     GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
752     Status topo_sort_status;
753     // Perform topological sort to prepare the graph for transitive reduction.
754     topo_sort_status = TopologicalSort(optimized_graph_);
755     // Set up index-based graph datastructures to speed up analysis steps below.
756     node_map_.reset(new NodeMap(optimized_graph_));
757     BuildNodeToIdx();
758 
759     if (topo_sort_status.ok()) {
760       // Remove redundant control dependencies.
761       TF_RETURN_IF_ERROR(TransitiveReduction());
762     } else {
763       LOG(ERROR) << "Iteration = " << iteration
764                  << ", topological sort failed with message: "
765                  << topo_sort_status.error_message();
766     }
767     // Turn nodes with only control outputs into NoOps, prune NoOp and Identity
768     // nodes.
769     TF_RETURN_IF_ERROR(OptimizeDependencies());
770 
771     // Dedup control inputs.
772     CleanControlInputs();
773 
774     // Merge multiple control edges from the same device.
775     GroupCrossDeviceControlEdges(/*host_granularity=*/false);
776 
777     // Merge control edges from the same host to reduce RPC traffic.
778     GroupCrossDeviceControlEdges(/*host_granularity=*/true);
779   }
780 
781   return Status::OK();
782 }
783 
Feedback(Cluster *,const GrapplerItem &,const GraphDef &,double)784 void DependencyOptimizer::Feedback(Cluster* /*cluster*/,
785                                    const GrapplerItem& /*item*/,
786                                    const GraphDef& /*optimized_graph*/,
787                                    double /*result*/) {
788   // Nothing to do for DependencyOptimizer.
789 }
790 
791 }  // end namespace grappler
792 }  // end namespace tensorflow
793