/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This module implements a common subexpression elimination pass.  We
// process the nodes in the graph in reverse postorder
// (i.e. inputs before their downstream dependencies).  The rough algorithm
// is as follows:
//
// std::unordered_map<size_t, Node*> available
// for each node n in forward topological order:
//   h = NodeHash(n)
//   if available[h] exists and Equivalent(available[h], n)
//     redirect downstream uses of outputs of n to available[h]
//     remove n from graph
//   else
//     if available[h] does not exist
//       available[h] = n
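//
// As a concrete illustration (a hypothetical graph; the op and tensor names
// are made up for this example): if two nodes both compute MatMul(x, y)
// with identical attrs, they hash to the same h and compare Equivalent, so
// the pass rewires consumers of the second node to use the first and
// removes the second node from the graph.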
//
// This is similar to the global value numbering algorithm described in
// this paper:
//
// "Global code motion/global value numbering", Cliff Click, PLDI '95
// Proceedings of the ACM SIGPLAN 1995 conference on Programming
// language design and implementation, Pages 246-257
//      http://dl.acm.org/citation.cfm?id=207154

#include "tensorflow/core/graph/optimizer_cse.h"

#include <algorithm>
#include <cstring>
#include <functional>
#include <iostream>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph_node_util.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"

namespace tensorflow {

class OptimizerCSE {
 public:
  explicit OptimizerCSE(Graph* g) : g_(g) {}

  bool Optimize(const std::function<bool(const Node*)>& consider_fn);

 private:
  static size_t NodeHash(const Node* n);
  static bool Equivalent(const Node* a, const Node* b,
                         AttrSlice::Scratch* scratch);

  Graph* g_;
};

static void FillInputs(const Node* n,
                       gtl::InlinedVector<const Node*, 4>* control_edges,
                       gtl::InlinedVector<std::pair<const Node*, int>, 4>* in) {
  DCHECK_EQ(in->size(), n->num_inputs());
  control_edges->clear();
  for (const Edge* e : n->in_edges()) {
    if (e->IsControlEdge()) {
      control_edges->push_back(e->src());
    } else {
      (*in)[e->dst_input()] = std::make_pair(e->src(), e->src_output());
    }
  }
  std::sort(control_edges->begin(), control_edges->end());
  if (n->op_def().is_commutative()) {
    // For commutative ops, we sort the inputs by their source
    // (Node*, output index) pair to get a canonical ordering (so that
    // add(a,b) and add(b,a) will hash to the same value if is_commutative
    // is true for 'add').
    std::sort(in->begin(), in->end());
  }
}
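
// As a small illustration (hypothetical nodes `a` and `b`, with a ordered
// before b by pointer value): for a commutative Add, FillInputs fills `in`
// with {(a,0), (b,0)} whether the graph was built as Add(a, b) or
// Add(b, a), so the two forms hash and compare identically below.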

static size_t kIllegalNodeHash = 0;

class Hasher {
 public:
  uint64 hash() { return h_ == kIllegalNodeHash ? kIllegalNodeHash + 1 : h_; }

  void MixString(const string& s) { h_ = Hash64(s.data(), s.size(), h_); }

  void MixInteger(size_t z) { h_ = Hash64Combine(h_, z); }

  void MixProto(const protobuf::MessageLite& msg) {
    msg.ByteSizeLong();  // Ensure sizes are cached accurately.
    HashingOutputStream hasher;
    {
      // CodedOutputStream doesn't call BackUp until it's destroyed, so we need
      // it to be destroyed before we call hasher.hash().
      protobuf::io::CodedOutputStream stream(&hasher);
      stream.EnableAliasing(true);
      stream.SetSerializationDeterministic(true);
      msg.SerializeWithCachedSizes(&stream);
    }
    h_ = Hash64Combine(h_, hasher.hash());
  }

 private:
  // HashingOutputStream produces the same exact hash as if you serialized the
  // proto and hashed it sequentially in kBufSize chunks, except it doesn't
  // manifest the entire proto into memory at any point.
  class HashingOutputStream : public protobuf::io::ZeroCopyOutputStream {
   public:
    // This kBufSize makes sizeof(HashingOutputStream) == 256.  It's not chosen
    // for any particular reason except it's a nice even number of cache lines.
    static constexpr size_t kBufSize = 228;
    static constexpr uint64 kDefaultSeed = 2570847921467975139ULL;
    bool Next(void** data, int* size) override {
      if (i_ == kBufSize) {
        // Mix the chunk in.
        Mix(buf_, kBufSize);
        *data = buf_;
        *size = kBufSize;
      } else {
        *data = buf_ + i_;
        *size = kBufSize - i_;
      }
      // We always set i_ to be past the end, since we've given the rest of buf_
      // out.
      i_ = kBufSize;
      return true;
    }

    void BackUp(int count) override { i_ -= count; }

    int64_t ByteCount() const override { return byte_count_; }

    bool WriteAliasedRaw(const void* void_data, int size) override {
      // We can't do math on void*.
      const char* data = static_cast<const char*>(void_data);
      const auto remaining = kBufSize - i_;
      if (remaining > 0) {
        if (size < remaining) {
          memcpy(buf_ + i_, data, size);
          i_ += size;
          return true;
        }
        memcpy(buf_ + i_, data, remaining);
        i_ = kBufSize;
        data += remaining;
        size -= remaining;
      }
      if (i_ == kBufSize) {
        Mix(buf_, kBufSize);
        i_ = 0;
      }
      while (size >= kBufSize) {
        Mix(data, kBufSize);
        data += kBufSize;
        size -= kBufSize;
      }
      memcpy(buf_, data, size);
      i_ = size;
      return true;
    }

    bool AllowsAliasing() const override { return true; }

    uint64 hash() {
      if (i_ != 0) {
        Mix(buf_, i_);
        i_ = 0;
      }
      return h_;
    }

   private:
    void Mix(const char* p, size_t n) {
      byte_count_ += n;
      h_ = Hash64(p, n, h_);
    }
    char buf_[kBufSize];
    int i_ = 0;
    int64_t byte_count_ = 0;
    uint64 h_ = kDefaultSeed;
  };

  uint64 h_ = HashingOutputStream::kDefaultSeed;
};

size_t OptimizerCSE::NodeHash(const Node* n) {
  Hasher hasher;
  hasher.MixString(n->type_string());
  hasher.MixInteger(n->output_types().size());
  for (DataType dt : n->output_types()) {
    hasher.MixInteger(dt);
  }

  hasher.MixInteger(n->num_inputs());
  gtl::InlinedVector<const Node*, 4> control_edges;
  gtl::InlinedVector<std::pair<const Node*, int>, 4> in(n->num_inputs());
  FillInputs(n, &control_edges, &in);
  for (const auto& edge : in) {
    hasher.MixInteger(edge.first->id());
    hasher.MixInteger(edge.second);
  }

#if !defined(__ANDROID__)
  // Hash the attrs.  For example, this makes sure different constants
  // end up in different hash buckets.
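  // Each attr is hashed on its own below, and the per-attr hashes are folded
  // together with an order-independent combine, since attr iteration order
  // is unspecified.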
  size_t attr_hashes = 0;
  for (const auto& attr : n->attrs()) {
    Hasher h;
    h.MixString(attr.first);
    h.MixProto(attr.second);
    attr_hashes = Hash64CombineUnordered(attr_hashes, h.hash());
  }
  hasher.MixInteger(attr_hashes);
#endif

  return hasher.hash();
}

static bool HasRefInput(const Node* n) {
  for (auto dt : n->input_types()) {
    if (IsRefType(dt)) return true;
  }
  return false;
}

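// Note: Optimize only calls Equivalent on nodes whose NodeHash values match,
// but Equivalent independently rechecks everything the hash mixes in, and it
// additionally compares control edges, which NodeHash does not hash.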
bool OptimizerCSE::Equivalent(const Node* a, const Node* b,
                              AttrSlice::Scratch* scratch) {
  // Nodes with different op types are never equivalent.
  if (a->type_string() != b->type_string()) return false;

  // Never consider stateful nodes (such as non-const inputs) equivalent.
  if (a->op_def().is_stateful()) return false;

  // For now, we consider any node that takes a ref input to not be
  // equivalent to any other node.
  if (HasRefInput(a) || HasRefInput(b)) return false;

  // Compare attrs.  Note that equal attrs implies equal input and
  // output types.
  if (!a->attrs().EqualAttrs(b->attrs(), scratch)) return false;

  // Compare input sources.
  if (a->num_inputs() != b->num_inputs()) return false;
  const int N_in = a->num_inputs();
  gtl::InlinedVector<const Node*, 4> a_control_edges;
  gtl::InlinedVector<const Node*, 4> b_control_edges;
  gtl::InlinedVector<std::pair<const Node*, int>, 4> a_in(N_in);
  gtl::InlinedVector<std::pair<const Node*, int>, 4> b_in(N_in);
  FillInputs(a, &a_control_edges, &a_in);
  FillInputs(b, &b_control_edges, &b_in);
  if (a_in != b_in) return false;
  if (a_control_edges != b_control_edges) return false;

  return true;
}

bool OptimizerCSE::Optimize(
    const std::function<bool(const Node*)>& consider_fn) {
  // This very simple implementation treats the whole graph as one
  // giant basic block (we just traverse the nodes in topological
  // order). It still behaves correctly in the presence of control
  // flow/loops/etc., but we would need to be careful about control
  // flow if we wanted to add more sophisticated CSE optimizations.

  // TODO(jeff): We need to handle Update nodes specially, but dealing
  // with more general control flow will also solve this issue, and for
  // now, our updates are almost always the most downstream nodes in
  // the graph.
  std::vector<Node*> order;
  GetReversePostOrder(*g_, &order);

  // Our value is just a single Node*, meaning we keep just a single
  // candidate for a given node hash value.  This may cause us to
  // (rarely) lose some optimization opportunities if there are
  // hash collisions, but it allows us to avoid having the value
  // be a set<Node*> (or equivalent).
  std::unordered_map<size_t, Node*> available;

  // Scratch space for Equivalent calls.  Allocated here and passed in to
  // Equivalent to avoid allocation inside the loop below.
  bool changed = false;
  AttrSlice::Scratch scratch;
  for (Node* n : order) {
    if (!n->IsOp()) continue;

    // Don't prune placeholder nodes.
    if (n->type_string() == "Placeholder" ||
        n->type_string() == "PlaceholderV2" ||
        n->type_string() == "PlaceholderWithDefault") {
      continue;
    }

    // See if we should consider this node at all.
    if (consider_fn != nullptr && !consider_fn(n)) continue;

    size_t h = NodeHash(n);
    Node** candidate = &available[h];
    if (*candidate == nullptr) {
      // No existing match: insert "n" into the hash table under "h".
      *candidate = n;
    } else if (Equivalent(*candidate, n, &scratch)) {
      VLOG(1) << "CSE: equivalent: " << (*candidate)->name() << " and "
              << n->name();
      // *candidate and n are equivalent.  Therefore, we can replace
      // n with *candidate by fixing up outgoing edges from "n" to instead
      // come from "*candidate", and then delete n from the graph.
      for (const Edge* e : n->out_edges()) {
        g_->AddEdge(*candidate, e->src_output(), e->dst(), e->dst_input());
      }

      MergeDebugInfo(NodeDebugInfo(*n), *candidate);
      g_->RemoveNode(n);
      changed = true;
    }
  }
  return changed;
}

bool OptimizeCSE(Graph* g,
                 const std::function<bool(const Node*)>& consider_fn) {
  OptimizerCSE opt(g);
  return opt.Optimize(consider_fn);
}
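
// Example usage (an illustrative sketch, not part of this module; `graph` is
// assumed to be a fully constructed Graph owned by the caller):
//
//   // Consider every node:
//   bool changed = OptimizeCSE(&graph, /*consider_fn=*/nullptr);
//
//   // Or restrict CSE to a subset of nodes, e.g. skip constants:
//   changed = OptimizeCSE(&graph, [](const Node* n) {
//     return n->type_string() != "Const";
//   });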

}  // namespace tensorflow