/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/graph_analyzer/sig_node.h"

#include <algorithm>

#include "absl/strings/str_format.h"

namespace tensorflow {
namespace grappler {
namespace graph_analyzer {

static constexpr bool debug = false;

//=== SigNode

SigNode::SigNode(const NodeDef* node) : node_(node) {}

void SigNode::CopyLinks(const GenNode& from, const TranslationMap& map) {
  hash_to_link_.clear();
  hashed_peers_.clear();

  std::map<LinkTag, Link> link_map;
  CopyLinksPass1(from, map, &link_map);
  CopyLinksPass2(&link_map);
}

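// Pass 1: for every link in the original GenNode whose peer is also in the
// subgraph (i.e. present in the translation map), groups the translated peers
// by the link's tag (the pair of the local and remote ports).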
void SigNode::CopyLinksPass1(const GenNode& from, const TranslationMap& map,
                             std::map<LinkTag, Link>* link_map) {
  LinkTag::Hasher link_hasher;

  for (const auto& entry : from.links()) {
    for (const auto& target : entry.second) {
      auto nodeit = map.find(target.node);
      if (nodeit == map.end()) {
        // Node is not in the subgraph, ignore.
        continue;
      }

      LinkTag tag(entry.first, target.port);
      size_t hval = link_hasher(tag);

      // This instantiates the entry if it was not present.
      Link& map_entry = (*link_map)[tag];
      if (map_entry.peers.empty()) {
        map_entry.tag = tag;
        map_entry.unique_hash = hval;
      }
      map_entry.peers.push_back(nodeit->second);
    }
  }
}

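// Pass 2: moves the collected links into hash_to_link_, keyed by the tag hash
// (rehashing the rare conflicts), and records the (hash, peer) pairs in
// hashed_peers_.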
void SigNode::CopyLinksPass2(std::map<LinkTag, Link>* link_map) {
  for (auto& entry : *link_map) {
    Link* hl_entry_ptr = &hash_to_link_[entry.second.unique_hash];
    // In case of a conflict, rehash. This should almost never happen.
    // Because the order of iteration is predictable, the rehashed values
    // will also be predictable.
    while (!hl_entry_ptr->peers.empty()) {
      CombineHash(1, &entry.second.unique_hash);
      hl_entry_ptr = &hash_to_link_[entry.second.unique_hash];
    }

    for (const auto& peer : entry.second.peers) {
      hashed_peers_.emplace_back(HashedPeer(entry.second.unique_hash, peer));
    }

    hl_entry_ptr->tag = entry.second.tag;
    hl_entry_ptr->unique_hash = entry.second.unique_hash;
    hl_entry_ptr->peers.swap(entry.second.peers);
  }
}

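// Computes the hash at distance 0: it depends only on this node's own opcode
// and on the tags of its links, not on any properties of the peer nodes.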
void SigNode::ComputeTopoHash0() {
  topo_hash_.clear();
  last_hashed_nodes_ = next_hashed_nodes_ = node_mask_;

  // TODO(babkin): include the attributes too, as an option.
  size_t hval = std::hash<string>()(opcode());

  // Getting the topology of the links into the hash early should get more
  // conflicts resolved early.
  for (const auto& entry : hashed_peers_) {
    CombineHash(entry.link_hash, &hval);
  }

  topo_hash_.push_back(hval);
}

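// Computes the hash at the given distance from the hashes at (distance - 1):
// the node's own hash at distance 0 is combined with the peers' hashes at the
// previous distance (commutatively within each group of equal link tags), and
// the bitmask of the nodes that have influenced this hash is accumulated.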
void SigNode::ComputeTopoHash(int distance) {
  // The new starting point.
  next_hashed_nodes_ = last_hashed_nodes_;
  if (debug) {
    LOG(INFO) << "DEBUG    node " << name() << " mask=" << std::hex
              << next_hashed_nodes_;
  }

  if (hash_is_final_) {
    return;
  }

  CHECK(topo_hash_.size() == distance);

  int prev = distance - 1;

  // Start with the node's own local topology hash. This value is stable, so
  // if the hashes of the surrounding nodes don't change at the following
  // distances, the hash of this node won't change either.
  size_t hval = topo_hash_[0];

  if (!hashed_peers_.empty()) {
    size_t last_link_hash = hashed_peers_[0].link_hash;
    size_t comm_hash = 0;

    for (const auto& entry : hashed_peers_) {
      if (entry.link_hash != last_link_hash) {
        CombineHash(last_link_hash, &hval);
        CombineHash(comm_hash, &hval);
        comm_hash = 0;
        last_link_hash = entry.link_hash;
      }

      // The links in the same vector are commutative, so combine their
      // hashes in a commutative way.
      CombineHashCommutative(entry.peer->GetTopoHash(prev), &comm_hash);
      next_hashed_nodes_ |= entry.peer->last_hashed_nodes_;
      if (debug) {
        LOG(INFO) << "DEBUG    node " << name() << " += " << entry.peer->name()
                  << " mask=" << std::hex << next_hashed_nodes_;
      }
    }

    // The last commutative group.
    CombineHash(last_link_hash, &hval);
    CombineHash(comm_hash, &hval);
  }

  topo_hash_.push_back(hval);
}

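// Returns the hash at the given distance. Once the node's hash becomes final,
// the last computed value stands in for all the greater distances.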
size_t SigNode::GetTopoHash(int distance) const {
  CHECK(!topo_hash_.empty());
  if (distance >= topo_hash_.size()) {
    CHECK(hash_is_final_);
    return topo_hash_.back();
  } else {
    return topo_hash_[distance];
  }
}

bool SigNode::operator==(const SigNode& other) const {
  // TODO(babkin): add attributes too.
  if (opcode() != other.opcode()) {
    return false;
  }

  // Normally the caller is expected to compare the nodes
  // at the same rank in different graphs, but just in case...
  if (unique_rank_ != other.unique_rank_) {
    return false;
  }

  if (hashed_peers_.size() != other.hashed_peers_.size()) {
    return false;
  }

  for (auto it1 = hashed_peers_.begin(), it2 = other.hashed_peers_.begin();
       it1 != hashed_peers_.end(); ++it1, ++it2) {
    // TODO(babkin): might compare the actual values too
    // but the hash is probably just as good.
    if (it1->link_hash != it2->link_hash) {
      return false;
    }
    if (it1->peer->unique_rank_ != it2->peer->unique_rank_) {
      return false;
    }
  }

  return true;
}

//=== Signature

constexpr int Signature::kMaxGraphSize;

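// Prints the signature in a human-readable form: every node in the order of
// its unique rank, as rank:opcode followed by its inbound links in the form
// [local_port:remote_port:peer_rank].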
string Signature::ToString() const {
  string result;
  for (size_t n = 0; n < nodes.size(); ++n) {
    // TODO(babkin): add attributes too.
    result += absl::StrFormat("%d:%s", n, nodes[n]->opcode());
    for (const auto& entry : nodes[n]->hashed_peers_) {
      const auto& link = nodes[n]->hash_to_link_[entry.link_hash];

      // The link entries are already sorted, by tags and then by the
      // node ranks.
      if (link.tag.local.IsInbound()) {
        result +=
            absl::StrFormat("[%s:%s:%d]", string(link.tag.local),
                            string(link.tag.remote), entry.peer->unique_rank_);
      }
    }
    result.push_back(',');
  }
  return result;
}

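// Computes the signature: assigns the unique ranks to the nodes by repeatedly
// recomputing the topology hashes and extracting the nodes whose hashes have
// become unique, then orders the links for a deterministic printout.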
Status Signature::Compute() {
  if (map.size() > kMaxGraphSize) {
    return Status(
        error::INVALID_ARGUMENT,
        absl::StrFormat(
            "A graph of %d nodes is too big for signature computation, "
            "the maximal supported node count is %d.",
            map.size(), kMaxGraphSize));
  }

  // The value that will be assigned next as the unique node id.
  // This also means that all the entries in nodes at indexes less than this
  // have been finalized and don't need to be touched any more.
  size_t next_node_id = 0;

  sig_short = 0;
  sig_full.resize(0);  // Keep the storage.

  // The main signature generation.
  PrepareNodes();
  FindUniqueHashes(&next_node_id);
  while (next_node_id < map.size()) {
    ComputeOneRound(next_node_id);
    FindUniqueHashes(&next_node_id);
  }

  OrderLinks();

  return Status::OK();
}

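// Resets the per-node state and computes the hashes at distance 0. Each node
// gets its own bit in the node mask; the hash values up to map.size() are
// reserved to represent the already-ranked nodes.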
void Signature::PrepareNodes() {
  nodes.resize(0);  // Keep the storage.

  // Initialize the nodes.
  int64_t mask = 1;
  for (const auto& entry : map) {
    SigNode* node = entry.second.get();
    node->last_hashed_nodes_ = node->node_mask_ = mask;
    mask <<= 1;
    node->unique_rank_ = ~0;
    node->hash_is_final_ = false;
    node->ComputeTopoHash0();
    if (node->GetHighTopoHash() <= map.size()) {
      // Would conflict with one of the reserved values.
      node->ReHighTopoHash();
    }

    // The initial order is random.
    nodes.emplace_back(node);
  }
}

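// Sorts the not-yet-ranked nodes by their current hash and assigns the unique
// ranks to the nodes whose hashes are unique. If no unique hash is found, the
// last node in the sorted order gets ranked anyway, to guarantee progress.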
void Signature::FindUniqueHashes(size_t* next_node_id_p) {
  // Start by sorting by the hash value.
  std::sort(nodes.begin() + *next_node_id_p, nodes.end(),
            SigNode::NodeOrderLess());

  // At each call, if no nodes have unique hashes, one node with a non-unique
  // (shared) hash can still be made unique by assigning it a unique id. This
  // node gets picked predictably by taking the last node.
  // TODO(babkin): Technically, more than one node can be unshared, as long as
  // their last_hashed_nodes_ overlap only in the nodes that already had the
  // assigned ids before the current round. But it's not clear yet how often
  // this would be beneficial, because it looks like for many subgraphs
  // unsharing one node should be enough to untangle them. This would need
  // more measurement before implementing.
  bool found_unique = false;
  for (size_t n = *next_node_id_p; n < nodes.size(); ++n) {
    size_t cur_hash = nodes[n]->GetHighTopoHash();
    if (n + 1 < nodes.size() && nodes[n + 1]->GetHighTopoHash() == cur_hash) {
      // A sequence of nodes sharing the same hash. Skip over it.
      // TODO(babkin): check here for the arbitrary hash conflicts and resolve
      // them.
      for (++n;
           n + 1 < nodes.size() && nodes[n + 1]->GetHighTopoHash() == cur_hash;
           ++n) {
      }
      if (found_unique || n != nodes.size() - 1) {
        // Either some unique nodes have already been found, or this is
        // not the last chance, keep trying to find the unique nodes.
        continue;
      }
      // Here we're at the last node and haven't found any unique ones.
      // So fall through and make this last node unique.
    }

    found_unique = true;
    size_t id = (*next_node_id_p)++;
    nodes[n]->unique_rank_ = id;

    size_t last_hash = nodes[n]->GetHighTopoHash();
    CombineHash(last_hash, &sig_short);
    sig_full.push_back(last_hash);

    // Replace the hash at distance 0 with the unique rank. After that it will
    // stay fixed.
    nodes[n]->topo_hash_.resize(1);
    nodes[n]->topo_hash_[0] = id + 1;  // Avoid the value of 0.

    nodes[n]->hash_is_final_ = true;
    nodes[n]->last_hashed_nodes_ = nodes[n]->node_mask_;
    if (n != id) {
      std::swap(nodes[id], nodes[n]);
    }
  }
}

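// Recomputes the topology hashes of the not-yet-ranked nodes at increasing
// distances, until the set of nodes that influences each hash stops growing.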
void Signature::ComputeOneRound(size_t next_node_id) {
  // Reset the state of the nodes.
  int debug_i = 0;
  for (auto it = nodes.begin() + next_node_id; it != nodes.end(); ++it) {
    auto node = *it;
    // The hash at distance 0 never changes, so preserve it.
    node->topo_hash_.resize(1);
    node->last_hashed_nodes_ = node->node_mask_;
    node->hash_is_final_ = false;
    if (debug) {
      LOG(INFO) << "DEBUG distance=" << 0 << " node " << debug_i++ << " "
                << node->name() << " mask=" << std::hex
                << node->last_hashed_nodes_;
    }
  }

  bool stop = false;
  // The distance can reach up to nodes.size()+1, to include not only all the
  // nodes but also all the redundant paths.
  for (int distance = 1; !stop; ++distance) {
    for (auto it = nodes.begin() + next_node_id; it != nodes.end(); ++it) {
      auto node = *it;
      if (node->hash_is_final_) {
        continue;
      }
      node->ComputeTopoHash(distance);
      if (node->GetHighTopoHash() <= nodes.size()) {
        // Would conflict with one of the reserved values.
        node->ReHighTopoHash();
      }
    }

    // Stop unless some node's set of influencing nodes is still growing.
    stop = true;

    debug_i = 0;
    // The bitmasks get moved after all the hash computations are done.
    for (auto it = nodes.begin() + next_node_id; it != nodes.end(); ++it) {
      auto node = *it;
      if (debug) {
        LOG(INFO) << "DEBUG distance=" << distance << " node " << debug_i++
                  << " " << node->name() << " oldmask=" << std::hex
                  << node->last_hashed_nodes_ << " mask=" << std::hex
                  << node->next_hashed_nodes_;
      }
      if (node->last_hashed_nodes_ == node->next_hashed_nodes_) {
        // Stopped growing, this part of the graph must be fully
        // surrounded by nodes that already have the unique ids.
        node->hash_is_final_ = true;
      } else {
        node->last_hashed_nodes_ = node->next_hashed_nodes_;
        stop = false;
      }
    }
  }
}

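// Within each group of links that share the same link hash, sorts the peers
// by their unique ranks, making the order of hashed_peers_ deterministic.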
void Signature::OrderLinks() {
  for (const auto& node : nodes) {
    if (node->hashed_peers_.empty()) {
      continue;
    }

    size_t cur_link_hash = node->hashed_peers_[0].link_hash + 1;
    int first_idx = -1;

    int idx;
    for (idx = 0; idx < node->hashed_peers_.size(); ++idx) {
      auto& entry = node->hashed_peers_[idx];
      if (entry.link_hash == cur_link_hash) {
        continue;
      }
      if (idx - first_idx > 1) {
        // Need to sort.
        std::sort(node->hashed_peers_.begin() + first_idx,
                  node->hashed_peers_.begin() + idx,
                  SigNode::HashedPeer::LessByRank());
      }

      cur_link_hash = entry.link_hash;
      first_idx = idx;
    }
    if (idx - first_idx > 1) {
      // Sort the last bunch.
      std::sort(node->hashed_peers_.begin() + first_idx,
                node->hashed_peers_.begin() + idx,
                SigNode::HashedPeer::LessByRank());
    }
  }
}

bool Signature::operator==(const Signature& other) const {
  // Tries to find the differences as early as possible by
  // comparing the hashes first.

  if (sig_short != other.sig_short) {
    return false;
  }
  if (sig_full.size() != other.sig_full.size()) {
    return false;
  }

  for (auto it1 = sig_full.begin(), it2 = other.sig_full.begin();
       it1 != sig_full.end(); ++it1, ++it2) {
    if (*it1 != *it2) {
      return false;
    }
  }

  if (nodes.size() != other.nodes.size()) {
    return false;
  }
  for (auto it1 = nodes.begin(), it2 = other.nodes.begin(); it1 != nodes.end();
       ++it1, ++it2) {
    if (**it1 != **it2) {
      return false;
    }
  }

  return true;
}

}  // end namespace graph_analyzer
}  // end namespace grappler
}  // end namespace tensorflow