1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/optimizers/common_subgraph_elimination.h"
17
18 #include <set>
19 #include <string>
20 #include <unordered_set>
21 #include <utility>
22 #include <vector>
23
24 #include "absl/container/flat_hash_map.h"
25 #include "absl/container/flat_hash_set.h"
26 #include "tensorflow/core/framework/attr_value_util.h"
27 #include "tensorflow/core/framework/graph.pb.h"
28 #include "tensorflow/core/framework/node_def.pb.h"
29 #include "tensorflow/core/graph/tensor_id.h"
30 #include "tensorflow/core/grappler/graph_topology_view.h"
31 #include "tensorflow/core/grappler/grappler_item.h"
32 #include "tensorflow/core/grappler/op_types.h"
33 #include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
34 #include "tensorflow/core/grappler/utils.h"
35 #include "tensorflow/core/grappler/utils/canonicalizer.h"
36 #include "tensorflow/core/grappler/utils/topological_sort.h"
37 #include "tensorflow/core/grappler/utils/traversal.h"
38 #include "tensorflow/core/lib/gtl/flatset.h"
39 #include "tensorflow/core/platform/errors.h"
40 #include "tensorflow/core/platform/hash.h"
41 #include "tensorflow/core/platform/status.h"
42 #include "tensorflow/core/platform/strcat.h"
43 #include "tensorflow/core/platform/stringpiece.h"
44 #include "tensorflow/core/platform/types.h"
45
namespace tensorflow {
namespace grappler {
// Forward declaration only: Optimize() and Feedback() receive a Cluster* but
// never dereference it, so the full definition is not required here.
class Cluster;
}  // namespace grappler
}  // namespace tensorflow

using tensorflow::strings::StrCat;
53
54 namespace tensorflow {
55 namespace grappler {
56
57 class UniqueNodes {
58 public:
FindOrAddRepresentative(NodeDef * node)59 NodeDef* FindOrAddRepresentative(NodeDef* node) {
60 uint64 sig = ComputeSignature(*node);
61 std::vector<NodeDef*>& candidates = rep_[sig];
62 for (auto& candidate : candidates) {
63 if ((candidate == node) || SameNode(*candidate, *node)) {
64 return candidate;
65 }
66 }
67 candidates.push_back(node);
68 return node;
69 }
70
RemoveRepresentative(NodeDef * node)71 void RemoveRepresentative(NodeDef* node) {
72 auto it = memoized_signatures_.find(node);
73 if (it == memoized_signatures_.end()) return;
74
75 std::vector<NodeDef*>& candidates = rep_[it->second];
76 for (int i = 0, end = candidates.size(); i < end; ++i) {
77 if (candidates[i] == node) {
78 std::swap(candidates[i], candidates[candidates.size() - 1]);
79 candidates.resize(candidates.size() - 1);
80 break;
81 }
82 }
83 memoized_signatures_.erase(node);
84 }
85
86 private:
87 uint64 ComputeSignature(const NodeDef& node);
88 bool SameNode(const NodeDef& node1, const NodeDef& node2) const;
89
90 absl::flat_hash_map<uint64, std::vector<NodeDef*>> rep_;
91 absl::flat_hash_map<const NodeDef*, uint64> memoized_signatures_;
92 };
93
ComputeSignature(const NodeDef & node)94 uint64 UniqueNodes::ComputeSignature(const NodeDef& node) {
95 auto it = memoized_signatures_.find(&node);
96 if (it != memoized_signatures_.end()) return it->second;
97
98 uint64 h = Hash64(node.op());
99 h = Hash64Combine(Hash64(node.device()), h);
100
101 for (const auto& input : node.input()) {
102 const TensorId input_tensor = ParseTensorName(input);
103 uint64 input_hash = Hash64Combine(
104 Hash64(input_tensor.node().data(), input_tensor.node().size()),
105 std::hash<int>()(input_tensor.index()));
106 h = Hash64CombineUnordered(input_hash, h);
107 }
108 for (const auto& attr : node.attr()) {
109 uint64 attr_hash =
110 Hash64Combine(Hash64(attr.first), FastAttrValueHash(attr.second));
111 h = Hash64CombineUnordered(attr_hash, h);
112 }
113 memoized_signatures_.emplace(&node, h);
114 return h;
115 }
116
117 // PRECONDITION:
118 // Node input orders are assumed to be canonicalized, i.e. control inputs for
119 // all nodes as well as regular inputs for commutative nodes must be sorted.
SameNode(const NodeDef & node1,const NodeDef & node2) const120 bool UniqueNodes::SameNode(const NodeDef& node1, const NodeDef& node2) const {
121 if (node1.op() != node2.op()) {
122 return false;
123 }
124 if (node1.device() != node2.device()) {
125 return false;
126 }
127 if (node1.input_size() != node2.input_size()) {
128 return false;
129 }
130 if (node1.attr_size() != node2.attr_size()) {
131 return false;
132 }
133
134 // Compare inputs.
135 auto it1 = node1.input().begin();
136 auto it2 = node2.input().begin();
137 for (; it1 != node1.input().end(); ++it1, ++it2) {
138 if (*it1 != *it2) return false;
139 }
140
141 // Compare attributes.
142 for (const auto& attr1 : node1.attr()) {
143 auto it = node2.attr().find(attr1.first);
144 if (it == node2.attr().end()) return false;
145 if (!FastAreAttrValuesEqual(attr1.second, it->second)) return false;
146 }
147
148 return true;
149 }
150
CanDedup(const NodeDef & node) const151 bool CommonSubgraphElimination::CanDedup(const NodeDef& node) const {
152 if (nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end()) {
153 return false;
154 }
155 if (IsEnter(node) || IsExit(node)) {
156 return false;
157 }
158 if (node.device().find("SPU") != string::npos) {
159 return false;
160 }
161 // Workaround for Assert and Print mistakenly being labeled as stateful.
162 if (IsAssert(node) || IsPrint(node)) {
163 return true;
164 }
165 return IsFreeOfSideEffect(node);
166 }
167
// Collapses structurally identical nodes onto a single representative and
// rewires all consumers to it. Runs to a fixed point, since rewiring fanouts
// can make previously distinct nodes identical.
Status CommonSubgraphElimination::DedupComputations(GraphDef* optimized_graph) {
  // Canonicalize input order (sorted control inputs, sorted commutative
  // inputs) so that UniqueNodes::SameNode sees comparable node pairs.
  CanonicalizeGraph(optimized_graph);

  GraphTopologyView graph_view;
  if (!graph_view.InitializeFromGraph(*optimized_graph).ok()) {
    // Best-effort optimizer: bail out without failing the whole pipeline.
    LOG(WARNING) << "Failed to initialize GraphTopologyView.";
    return Status::OK();
  }

  // If either node or rep feeds an inplace op, deduping them may cause data
  // races. For example: If we dedup nodes initializing two independent
  // inplace accumulations, they will write to the same buffer, clobbering
  // each other's results.
  absl::flat_hash_set<const NodeDef*> feeds_inplace_op;
  for (int i = 0; i < optimized_graph->node_size(); ++i) {
    const NodeDef& root = optimized_graph->node(i);
    if (feeds_inplace_op.find(&root) != feeds_inplace_op.end()) continue;
    if (ModifiesInputsInPlace(root)) {
      // Walk upstream from the in-place op, continuing through nodes that may
      // forward their input buffers (or are the same op type as the root).
      const auto is_continue_traversal = [&](const NodeDef* node) -> bool {
        return node->op() == root.op() || !NeverForwardsInputs(*node);
      };

      DfsTraversal(graph_view, {&root}, TraversalDirection::kFollowInputs,
                   DfsPredicates::Advance(is_continue_traversal),
                   DfsCallbacks::PreOrder([&](const NodeDef* node) {
                     feeds_inplace_op.insert(node);
                   }));
    }
  }

  // Precompute eligibility per node index: not feeding an in-place op, and
  // passing the CanDedup() policy checks.
  std::vector<bool> can_dedup(optimized_graph->node_size());
  for (int i = 0; i < optimized_graph->node_size(); ++i) {
    const NodeDef& node = optimized_graph->node(i);
    can_dedup[i] = (feeds_inplace_op.find(&node) == feeds_inplace_op.end()) &&
                   CanDedup(node);
  }

  bool stop = true;
  std::set<int> duplicates;
  UniqueNodes nodes;
  NodeMap node_map(optimized_graph);
  // Fixed-point loop: keep sweeping until a full pass makes no change.
  do {
    stop = true;
    for (int i = 0; i < optimized_graph->node_size(); ++i) {
      if (!can_dedup[i] || duplicates.find(i) != duplicates.end()) {
        continue;
      }
      NodeDef* node = optimized_graph->mutable_node(i);
      NodeDef* rep = nodes.FindOrAddRepresentative(node);
      if (rep == node) {
        // `node` is itself the representative of its class; nothing to merge.
        continue;
      }
      // Make a copy since we mutate the set below.
      const auto fanouts = node_map.GetOutputs(node->name());
      for (NodeDef* fanout : fanouts) {
        // Update consumers of node.
        bool updated_fanout = false;
        for (int i = 0; i < fanout->input_size(); ++i) {
          string* fanout_input = fanout->mutable_input(i);

          const int position =
              NodePositionIfSameNode(*fanout_input, node->name());
          // position < -1: this input does not refer to `node` at all.
          // position == -1: control input ("^name"); >= 0: output slot.
          if (position < -1) {
            continue;
          } else {
            if (!updated_fanout) {
              // The signature of the fanout node will change. Remove it from
              // nodes.
              nodes.RemoveRepresentative(fanout);
            }
            updated_fanout = true;
            // Rewrite the input string in-place to point at `rep` instead.
            if (position > 0) {
              *fanout_input = StrCat(rep->name(), ":", position);
            } else if (position == 0) {
              *fanout_input = rep->name();
            } else {
              *fanout_input = StrCat("^", rep->name());
            }
          }
        }
        if (updated_fanout) {
          node_map.UpdateInput(fanout->name(), node->name(), rep->name());
          CanonicalizeNode(fanout);
        }
      }
      if (fetch_nodes_known_) {
        // Empty the duplicate now; it is physically erased from the graph
        // after the loop. When fetch nodes are unknown we must keep it.
        node->Clear();
      }
      duplicates.insert(i);
      stop = false;
    }
  } while (!stop);

  // Delete duplicates
  if (fetch_nodes_known_ && !duplicates.empty()) {
    EraseNodesFromGraph(duplicates, optimized_graph);
  }

  return Status::OK();
}
269
Optimize(Cluster *,const GrapplerItem & item,GraphDef * optimized_graph)270 Status CommonSubgraphElimination::Optimize(Cluster* /*cluster*/,
271 const GrapplerItem& item,
272 GraphDef* optimized_graph) {
273 // Set up helper data structures.
274 nodes_to_preserve_ = item.NodesToPreserve();
275 fetch_nodes_known_ = !item.fetch.empty();
276 *optimized_graph = item.graph;
277
278 // Perform topological sort on the graph in order to help DedupComputations
279 // optimize larger subgraphs starting from the roots with more inputs.
280 TF_RETURN_IF_ERROR(TopologicalSort(optimized_graph));
281 GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
282
283 return DedupComputations(optimized_graph);
284 }
285
void CommonSubgraphElimination::Feedback(Cluster* /*cluster*/,
                                         const GrapplerItem& /*item*/,
                                         const GraphDef& /*optimized_graph*/,
                                         double /*result*/) {
  // Nothing to do for CommonSubgraphElimination.
  // (The previous comment referenced ArithmeticOptimizer — a copy-paste
  // leftover from the file this pass was extracted from.)
}
292
293 } // namespace grappler
294 } // namespace tensorflow
295