/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/optimizers/common_subgraph_elimination.h"

#include <set>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/grappler/graph_topology_view.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/canonicalizer.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/grappler/utils/traversal.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/hash.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
namespace grappler {
class Cluster;
}  // namespace grappler
}  // namespace tensorflow

using tensorflow::strings::StrCat;

namespace tensorflow {
namespace grappler {

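// UniqueNodes maintains a canonical representative for each distinct node
// "signature" (op, device, inputs, and attributes), so that nodes computing
// the same value can be detected and collapsed onto a single representative.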
class UniqueNodes {
 public:
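  // Returns the canonical representative for `node`: the first previously
  // seen node with the same signature that compares equal under SameNode.
  // If there is none, `node` itself becomes the representative.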
  NodeDef* FindOrAddRepresentative(NodeDef* node) {
    uint64 sig = ComputeSignature(*node);
    std::vector<NodeDef*>& candidates = rep_[sig];
    for (auto& candidate : candidates) {
      if ((candidate == node) || SameNode(*candidate, *node)) {
        return candidate;
      }
    }
    candidates.push_back(node);
    return node;
  }

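  // Drops `node` from the candidate list for its memoized signature. Must be
  // called before mutating a node that was handed to FindOrAddRepresentative,
  // since the mutation invalidates the cached signature.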
  void RemoveRepresentative(NodeDef* node) {
    auto it = memoized_signatures_.find(node);
    if (it == memoized_signatures_.end()) return;

    std::vector<NodeDef*>& candidates = rep_[it->second];
    for (int i = 0, end = candidates.size(); i < end; ++i) {
      if (candidates[i] == node) {
        std::swap(candidates[i], candidates[candidates.size() - 1]);
        candidates.resize(candidates.size() - 1);
        break;
      }
    }
    memoized_signatures_.erase(node);
  }

 private:
  uint64 ComputeSignature(const NodeDef& node);
  bool SameNode(const NodeDef& node1, const NodeDef& node2) const;

  // Maps a signature to all nodes seen so far with that signature.
  absl::flat_hash_map<uint64, std::vector<NodeDef*>> rep_;
  // Caches each node's signature so it is computed at most once.
  absl::flat_hash_map<const NodeDef*, uint64> memoized_signatures_;
};

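// Hashes the node's op, device, inputs, and attributes into a signature.
// Inputs and attributes are folded in with an order-independent combiner, so
// a permutation of either yields the same signature; exact equality is left
// to SameNode. Results are memoized per NodeDef pointer.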
uint64 UniqueNodes::ComputeSignature(const NodeDef& node) {
  auto it = memoized_signatures_.find(&node);
  if (it != memoized_signatures_.end()) return it->second;

  uint64 h = Hash64(node.op());
  h = Hash64Combine(Hash64(node.device()), h);

  for (const auto& input : node.input()) {
    const TensorId input_tensor = ParseTensorName(input);
    uint64 input_hash = Hash64Combine(
        Hash64(input_tensor.node().data(), input_tensor.node().size()),
        std::hash<int>()(input_tensor.index()));
    h = Hash64CombineUnordered(input_hash, h);
  }
  for (const auto& attr : node.attr()) {
    uint64 attr_hash =
        Hash64Combine(Hash64(attr.first), FastAttrValueHash(attr.second));
    h = Hash64CombineUnordered(attr_hash, h);
  }
  memoized_signatures_.emplace(&node, h);
  return h;
}

// PRECONDITION:
//  Node input orders are assumed to be canonicalized, i.e. control inputs for
//  all nodes as well as regular inputs for commutative nodes must be sorted.
bool UniqueNodes::SameNode(const NodeDef& node1, const NodeDef& node2) const {
  if (node1.op() != node2.op()) {
    return false;
  }
  if (node1.device() != node2.device()) {
    return false;
  }
  if (node1.input_size() != node2.input_size()) {
    return false;
  }
  if (node1.attr_size() != node2.attr_size()) {
    return false;
  }

  // Compare inputs.
  auto it1 = node1.input().begin();
  auto it2 = node2.input().begin();
  for (; it1 != node1.input().end(); ++it1, ++it2) {
    if (*it1 != *it2) return false;
  }

  // Compare attributes.
  for (const auto& attr1 : node1.attr()) {
    auto it = node2.attr().find(attr1.first);
    if (it == node2.attr().end()) return false;
    if (!FastAreAttrValuesEqual(attr1.second, it->second)) return false;
  }

  return true;
}

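// Returns true if it is safe to replace `node` with an equivalent node:
// the node is not in the preserve set, is not a frame Enter/Exit, is not
// placed on an "SPU" device, and is free of side effects (Assert and Print
// are explicitly allowed despite being labeled stateful).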
bool CommonSubgraphElimination::CanDedup(const NodeDef& node) const {
  if (nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end()) {
    return false;
  }
  if (IsEnter(node) || IsExit(node)) {
    return false;
  }
  if (node.device().find("SPU") != string::npos) {
    return false;
  }
  // Workaround for Assert and Print mistakenly being labeled as stateful.
  if (IsAssert(node) || IsPrint(node)) {
    return true;
  }
  return IsFreeOfSideEffect(node);
}

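// Collapses duplicate computations: canonicalizes the graph, groups nodes by
// signature with UniqueNodes, rewires every fanout of a duplicate onto its
// representative, and finally erases the duplicates from the graph.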
Status CommonSubgraphElimination::DedupComputations(GraphDef* optimized_graph) {
  CanonicalizeGraph(optimized_graph);

  GraphTopologyView graph_view;
  if (!graph_view.InitializeFromGraph(*optimized_graph).ok()) {
    LOG(WARNING) << "Failed to initialize GraphTopologyView.";
    return Status::OK();
  }

  // If either node or rep feeds an inplace op, deduping them may cause data
  // races. For example, if we dedup nodes initializing two independent
  // inplace accumulations, they will write to the same buffer, clobbering
  // each other's results.
  absl::flat_hash_set<const NodeDef*> feeds_inplace_op;
  for (int i = 0; i < optimized_graph->node_size(); ++i) {
    const NodeDef& root = optimized_graph->node(i);
    if (feeds_inplace_op.find(&root) != feeds_inplace_op.end()) continue;
    if (ModifiesInputsInPlace(root)) {
      const auto is_continue_traversal = [&](const NodeDef* node) -> bool {
        return node->op() == root.op() || !NeverForwardsInputs(*node);
      };

      DfsTraversal(graph_view, {&root}, TraversalDirection::kFollowInputs,
                   DfsPredicates::Advance(is_continue_traversal),
                   DfsCallbacks::PreOrder([&](const NodeDef* node) {
                     feeds_inplace_op.insert(node);
                   }));
    }
  }

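  // Precompute which nodes can be deduplicated at all: a node qualifies only
  // if it does not feed an in-place op (see above) and passes CanDedup.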
  std::vector<bool> can_dedup(optimized_graph->node_size());
  for (int i = 0; i < optimized_graph->node_size(); ++i) {
    const NodeDef& node = optimized_graph->node(i);
    can_dedup[i] = (feeds_inplace_op.find(&node) == feeds_inplace_op.end()) &&
                   CanDedup(node);
  }

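  // Dedup to a fixed point: rewiring fanouts onto a representative can make
  // previously distinct downstream nodes identical, so keep sweeping until a
  // full pass replaces nothing.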
  bool stop = true;
  std::set<int> duplicates;
  UniqueNodes nodes;
  NodeMap node_map(optimized_graph);
  do {
    stop = true;
    for (int i = 0; i < optimized_graph->node_size(); ++i) {
      if (!can_dedup[i] || duplicates.find(i) != duplicates.end()) {
        continue;
      }
      NodeDef* node = optimized_graph->mutable_node(i);
      NodeDef* rep = nodes.FindOrAddRepresentative(node);
      if (rep == node) {
        continue;
      }
      // Make a copy since we mutate the set below.
      const auto fanouts = node_map.GetOutputs(node->name());
      for (NodeDef* fanout : fanouts) {
        // Update consumers of node.
        bool updated_fanout = false;
        for (int j = 0; j < fanout->input_size(); ++j) {
          string* fanout_input = fanout->mutable_input(j);

          // A position of -1 denotes a control input; anything below -1
          // means this input does not refer to `node` at all.
          const int position =
              NodePositionIfSameNode(*fanout_input, node->name());
          if (position < -1) {
            continue;
          } else {
            // Update the input name in-place to point at `rep`.
            if (!updated_fanout) {
              // The signature of the fanout node will change. Remove it from
              // nodes.
              nodes.RemoveRepresentative(fanout);
            }
            updated_fanout = true;
            if (position > 0) {
              *fanout_input = StrCat(rep->name(), ":", position);
            } else if (position == 0) {
              *fanout_input = rep->name();
            } else {
              *fanout_input = StrCat("^", rep->name());
            }
          }
        }
        if (updated_fanout) {
          node_map.UpdateInput(fanout->name(), node->name(), rep->name());
          CanonicalizeNode(fanout);
        }
      }
      if (fetch_nodes_known_) {
        node->Clear();
      }
      duplicates.insert(i);
      stop = false;
    }
  } while (!stop);

  // Delete duplicates.
  if (fetch_nodes_known_ && !duplicates.empty()) {
    EraseNodesFromGraph(duplicates, optimized_graph);
  }

  return Status::OK();
}

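// Entry point for the optimizer: copies the item's graph, topologically
// sorts it (which helps DedupComputations, see below), and runs the dedup
// pass.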
Status CommonSubgraphElimination::Optimize(Cluster* /*cluster*/,
                                           const GrapplerItem& item,
                                           GraphDef* optimized_graph) {
  // Set up helper data structures.
  nodes_to_preserve_ = item.NodesToPreserve();
  fetch_nodes_known_ = !item.fetch.empty();
  *optimized_graph = item.graph;

  // Perform topological sort on the graph in order to help DedupComputations
  // optimize larger subgraphs starting from the roots with more inputs.
  TF_RETURN_IF_ERROR(TopologicalSort(optimized_graph));
  GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();

  return DedupComputations(optimized_graph);
}

void CommonSubgraphElimination::Feedback(Cluster* /*cluster*/,
                                         const GrapplerItem& /*item*/,
                                         const GraphDef& /*optimized_graph*/,
                                         double /*result*/) {
  // Nothing to do for CommonSubgraphElimination.
}

}  // namespace grappler
}  // namespace tensorflow