1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/graph_analyzer/sig_node.h"
17
18 #include <algorithm>
19
20 #include "absl/strings/str_format.h"
21
22 namespace tensorflow {
23 namespace grappler {
24 namespace graph_analyzer {
25
// Compile-time switch for the verbose LOG(INFO) tracing of the hash and
// bitmask propagation below; flip to true when debugging the signature
// computation.
static constexpr bool debug = false;
27
28 //=== SigNode
29
// Wraps the (not owned) NodeDef; the links get filled in separately by
// CopyLinks().
SigNode::SigNode(const NodeDef* node) : node_(node) {}
31
CopyLinks(const GenNode & from,const TranslationMap & map)32 void SigNode::CopyLinks(const GenNode& from, const TranslationMap& map) {
33 hash_to_link_.clear();
34 hashed_peers_.clear();
35
36 std::map<LinkTag, Link> link_map;
37 CopyLinksPass1(from, map, &link_map);
38 CopyLinksPass2(&link_map);
39 }
40
CopyLinksPass1(const GenNode & from,const TranslationMap & map,std::map<LinkTag,Link> * link_map)41 void SigNode::CopyLinksPass1(const GenNode& from, const TranslationMap& map,
42 std::map<LinkTag, Link>* link_map) {
43 LinkTag::Hasher link_hasher;
44
45 for (const auto& entry : from.links()) {
46 for (const auto& target : entry.second) {
47 auto nodeit = map.find(target.node);
48 if (nodeit == map.end()) {
49 // Node is not in the subgraph, ignore.
50 continue;
51 }
52
53 LinkTag tag(entry.first, target.port);
54 size_t hval = link_hasher(tag);
55
56 // This instantiates the entry if it was not present.
57 Link& map_entry = (*link_map)[tag];
58 if (map_entry.peers.empty()) {
59 map_entry.tag = tag;
60 map_entry.unique_hash = hval;
61 }
62 map_entry.peers.push_back(nodeit->second);
63 }
64 }
65 }
66
// Pass 2: moves the collected links into the member maps, resolving any
// hash collisions between distinct tags deterministically, and builds
// hashed_peers_ (one entry per peer, peers of the same tag adjacent).
void SigNode::CopyLinksPass2(std::map<LinkTag, Link>* link_map) {
  for (auto& entry : *link_map) {
    // Find (or create, via operator[]) the slot for this link's hash.
    Link* hl_entry_ptr = &hash_to_link_[entry.second.unique_hash];
    // In case of a conflict, rehash. This should almost never happen.
    // Because the order of iteration is predictable, the rehashed values
    // will also be predictable.
    while (!hl_entry_ptr->peers.empty()) {
      CombineHash(1, &entry.second.unique_hash);
      hl_entry_ptr = &hash_to_link_[entry.second.unique_hash];
    }

    // One (link hash, peer) pair per peer; since the outer map iterates
    // in tag order, peers sharing a tag end up adjacent in hashed_peers_.
    for (const auto& peer : entry.second.peers) {
      hashed_peers_.emplace_back(HashedPeer(entry.second.unique_hash, peer));
    }

    // Move the link (with its possibly-rehashed value) into the slot.
    hl_entry_ptr->tag = entry.second.tag;
    hl_entry_ptr->unique_hash = entry.second.unique_hash;
    hl_entry_ptr->peers.swap(entry.second.peers);
  }
}
87
ComputeTopoHash0()88 void SigNode::ComputeTopoHash0() {
89 topo_hash_.clear();
90 last_hashed_nodes_ = next_hashed_nodes_ = node_mask_;
91
92 // TODO(babkin): include the attrbutes too, as an option.
93 size_t hval = std::hash<string>()(opcode());
94
95 // Getting the topology of the links in to the hash early should get more
96 // conflicts resolved early.
97 for (const auto& entry : hashed_peers_) {
98 CombineHash(entry.link_hash, &hval);
99 }
100
101 topo_hash_.push_back(hval);
102 }
103
// Computes the hash at the given distance (>= 1) by mixing the peers'
// hashes from distance-1 into this node's own distance-0 hash, and grows
// next_hashed_nodes_ to the union of the peers' masks (tracking which
// nodes have influenced this hash so far).
void SigNode::ComputeTopoHash(int distance) {
  // The new starting point.
  next_hashed_nodes_ = last_hashed_nodes_;
  if (debug) {
    LOG(INFO) << "DEBUG node " << name() << " mask=" << std::hex
              << next_hashed_nodes_;
  }

  // Once finalized, the hash stays fixed; nothing more to compute.
  if (hash_is_final_) {
    return;
  }

  // Exactly the hashes for distances [0, distance-1] must be present.
  CHECK(topo_hash_.size() == distance);

  int prev = distance - 1;

  // Start with own's local topology hash. This value is stable, so
  // if the hashes of the surrounding nodes don't change on the following
  // distances, the hash of this node won't change either.
  size_t hval = topo_hash_[0];

  if (!hashed_peers_.empty()) {
    size_t last_link_hash = hashed_peers_[0].link_hash;
    size_t comm_hash = 0;

    // hashed_peers_ keeps peers with the same link hash adjacent, so a
    // change of the hash marks the start of a new group.
    for (const auto& entry : hashed_peers_) {
      if (entry.link_hash != last_link_hash) {
        // Close out the previous group of same-tagged links.
        CombineHash(last_link_hash, &hval);
        CombineHash(comm_hash, &hval);
        comm_hash = 0;
        last_link_hash = entry.link_hash;
      }

      // The links in the same vector are commutative, so combine their
      // hashes in a commutative way.
      CombineHashCommutative(entry.peer->GetTopoHash(prev), &comm_hash);
      next_hashed_nodes_ |= entry.peer->last_hashed_nodes_;
      if (debug) {
        LOG(INFO) << "DEBUG node " << name() << " += " << entry.peer->name()
                  << " mask=" << std::hex << next_hashed_nodes_;
      }
    }

    // The last commutative group.
    CombineHash(last_link_hash, &hval);
    CombineHash(comm_hash, &hval);
  }

  topo_hash_.push_back(hval);
}
154
GetTopoHash(int distance) const155 size_t SigNode::GetTopoHash(int distance) const {
156 CHECK(!topo_hash_.empty());
157 if (distance >= topo_hash_.size()) {
158 CHECK(hash_is_final_);
159 return topo_hash_.back();
160 } else {
161 return topo_hash_[distance];
162 }
163 }
164
operator ==(const SigNode & other) const165 bool SigNode::operator==(const SigNode& other) const {
166 // TODO(babkin): add attributes too.
167 if (opcode() != other.opcode()) {
168 return false;
169 }
170
171 // Normally the caller is expected to compare the nodes
172 // at the same rank in different graphs, but just in case...
173 if (unique_rank_ != other.unique_rank_) {
174 return false;
175 }
176
177 if (hashed_peers_.size() != other.hashed_peers_.size()) {
178 return false;
179 }
180
181 for (auto it1 = hashed_peers_.begin(), it2 = other.hashed_peers_.begin();
182 it1 != hashed_peers_.end(); ++it1, ++it2) {
183 // TODO(babkin): might compare the actual values too
184 // but the hash is probably just as good.
185 if (it1->link_hash != it2->link_hash) {
186 return false;
187 }
188 if (it1->peer->unique_rank_ != it2->peer->unique_rank_) {
189 return false;
190 }
191 }
192
193 return true;
194 }
195
196 //=== Signature
197
// Out-of-line definition for the static constexpr member, needed when it
// is ODR-used (pre-C++17 semantics; implicitly inline since C++17).
constexpr int Signature::kMaxGraphSize;
199
ToString() const200 string Signature::ToString() const {
201 string result;
202 for (size_t n = 0; n < nodes.size(); ++n) {
203 // TODO(babkin): add attributes too.
204 result += absl::StrFormat("%d:%s", n, nodes[n]->opcode());
205 for (const auto& entry : nodes[n]->hashed_peers_) {
206 const auto& link = nodes[n]->hash_to_link_[entry.link_hash];
207
208 // The link entries are already sorted, by tags and then by the
209 // node ranks.
210 if (link.tag.local.IsInbound()) {
211 result +=
212 absl::StrFormat("[%s:%s:%d]", string(link.tag.local),
213 string(link.tag.remote), entry.peer->unique_rank_);
214 }
215 }
216 result.push_back(',');
217 }
218 return result;
219 }
220
Compute()221 Status Signature::Compute() {
222 if (map.size() > kMaxGraphSize) {
223 return Status(
224 error::INVALID_ARGUMENT,
225 absl::StrFormat(
226 "A graph of %d nodes is too big for signature computation, "
227 "the maximal supported node count is %d.",
228 map.size(), kMaxGraphSize));
229 }
230
231 // The value that will be assigned next as the unique node id.
232 // This also means that all the entries in nodes at indexes less than this
233 // have been finalized and don't need to be touched any more.
234 size_t next_node_id = 0;
235
236 sig_short = 0;
237 sig_full.resize(0); // Keep the storage.
238
239 // The main signature generation.
240 PrepareNodes();
241 FindUniqueHashes(&next_node_id);
242 while (next_node_id < map.size()) {
243 ComputeOneRound(next_node_id);
244 FindUniqueHashes(&next_node_id);
245 }
246
247 OrderLinks();
248
249 return Status::OK();
250 }
251
PrepareNodes()252 void Signature::PrepareNodes() {
253 nodes.resize(0); // Keep the storage.
254
255 // Initialize the nodes.
256 int64_t mask = 1;
257 for (const auto& entry : map) {
258 SigNode* node = entry.second.get();
259 node->last_hashed_nodes_ = node->node_mask_ = mask;
260 mask <<= 1;
261 node->unique_rank_ = ~0;
262 node->hash_is_final_ = false;
263 node->ComputeTopoHash0();
264 if (node->GetHighTopoHash() <= map.size()) {
265 // Would conflict with one of the reserved values.
266 node->ReHighTopoHash();
267 }
268
269 // The initial order is random.
270 nodes.emplace_back(node);
271 }
272 }
273
// Finds the nodes whose current hash is unique among the not-yet-ranked
// nodes and assigns them the next unique ids, moving them to the front of
// the unranked range. If no unique hash exists, forcibly unshares the last
// node so that progress is always made.
void Signature::FindUniqueHashes(size_t* next_node_id_p) {
  // Start by sorting by the hash value.
  std::sort(nodes.begin() + *next_node_id_p, nodes.end(),
            SigNode::NodeOrderLess());

  // At each call, if no nodes have unique hashes, one node that has a
  // non-unique (shared) hash can be made unique by assigning a unique id.
  // This node gets picked predictably by taking the last node.
  // TODO(babkin): Technically, more than one node can be unshared,
  // as long as their last_hashed_nodes_ overlap only by the nodes that
  // already had the assigned ids before the current round. But it's not clear
  // yet, how often would this beneficial, because it looks like for many
  // subgraphs unsharing one node should be enough to untangle them. This
  // would need more measurement before implementing.
  bool found_unique = false;
  for (size_t n = *next_node_id_p; n < nodes.size(); ++n) {
    size_t cur_hash = nodes[n]->GetHighTopoHash();
    if (n + 1 < nodes.size() && nodes[n + 1]->GetHighTopoHash() == cur_hash) {
      // A sequence of nodes sharing the same hash. Skip over it.
      // TODO(babkin): check here for the arbitrary hash conflicts and resolve
      // them.
      for (++n;
           n + 1 < nodes.size() && nodes[n + 1]->GetHighTopoHash() == cur_hash;
           ++n) {
      }
      if (found_unique || n != nodes.size() - 1) {
        // Either some unique nodes have already been found, or this is
        // not the last chance, keep trying to find the unique nodes.
        continue;
      }
      // Here we're at the last node and haven't found any unique ones.
      // So fall through and make this last node unique.
    }

    // Assign the next id and fold this node's hash into the signature.
    found_unique = true;
    size_t id = (*next_node_id_p)++;
    nodes[n]->unique_rank_ = id;

    size_t last_hash = nodes[n]->GetHighTopoHash();
    CombineHash(last_hash, &sig_short);
    sig_full.push_back(last_hash);

    // Take the hash at 0 and mix the unique rank into it. After that it will
    // stay fixed.
    nodes[n]->topo_hash_.resize(1);
    nodes[n]->topo_hash_[0] = id + 1;  // Avoid the value of 0.

    nodes[n]->hash_is_final_ = true;
    nodes[n]->last_hashed_nodes_ = nodes[n]->node_mask_;
    // Move the newly-ranked node into the finalized prefix of the vector.
    if (n != id) {
      std::swap(nodes[id], nodes[n]);
    }
  }
}
328
// Runs one round of the topological hash computation for the nodes at
// indexes [next_node_id, nodes.size()): resets their per-round state, then
// iterates the hashes over growing distances until every node's
// reachability bitmask stops growing.
void Signature::ComputeOneRound(size_t next_node_id) {
  // Reset the state of the nodes.
  int debug_i = 0;
  for (auto it = nodes.begin() + next_node_id; it != nodes.end(); ++it) {
    auto node = *it;
    // The hash at distance 0 never changes, so preserve it.
    node->topo_hash_.resize(1);
    node->last_hashed_nodes_ = node->node_mask_;
    node->hash_is_final_ = false;
    if (debug) {
      LOG(INFO) << "DEBUG distance=" << 0 << " node " << debug_i++ << " "
                << node->name() << " mask=" << std::hex
                << node->last_hashed_nodes_;
    }
  }

  bool stop = false;
  // The distance can reach up to nodes.size()+1, to include not only all the
  // nodes but also all the redundant paths.
  for (int distance = 1; !stop; ++distance) {
    // Phase 1: compute the new hashes from the previous distance's state.
    for (auto it = nodes.begin() + next_node_id; it != nodes.end(); ++it) {
      auto node = *it;
      if (node->hash_is_final_) {
        continue;
      }
      node->ComputeTopoHash(distance);
      if (node->GetHighTopoHash() <= nodes.size()) {
        // Would conflict with one of the reserved values.
        node->ReHighTopoHash();
      }
    }

    // Will be looking for the indications to not stop.
    stop = true;

    debug_i = 0;
    // Phase 2: the bitmasks get moved after all the hash computations are
    // done, so that every node sees a consistent previous-distance view.
    for (auto it = nodes.begin() + next_node_id; it != nodes.end(); ++it) {
      auto node = *it;
      if (debug) {
        LOG(INFO) << "DEBUG distance=" << distance << " node " << debug_i++
                  << " " << node->name() << " oldmask=" << std::hex
                  << node->last_hashed_nodes_ << " mask=" << std::hex
                  << node->next_hashed_nodes_;
      }
      if (node->last_hashed_nodes_ == node->next_hashed_nodes_) {
        // Stopped growing, this part of the graph must be fully
        // surrounded by nodes that already have the unique ids.
        node->hash_is_final_ = true;
      } else {
        node->last_hashed_nodes_ = node->next_hashed_nodes_;
        stop = false;
      }
    }
  }
}
385
// Orders the peers within each run of equal link hashes in hashed_peers_
// by the peers' unique ranks, producing a fully deterministic ordering.
void Signature::OrderLinks() {
  for (const auto& node : nodes) {
    if (node->hashed_peers_.empty()) {
      continue;
    }

    // Initialize to a value guaranteed to differ from the first entry's
    // hash, so the first iteration always starts a new group (and sets
    // first_idx before any sort can use it).
    size_t cur_link_hash = node->hashed_peers_[0].link_hash + 1;
    int first_idx = -1;

    int idx;
    for (idx = 0; idx < node->hashed_peers_.size(); ++idx) {
      auto& entry = node->hashed_peers_[idx];
      if (entry.link_hash == cur_link_hash) {
        continue;  // Still inside the current group.
      }
      // The group [first_idx, idx) has ended; sort it if it has 2+ entries.
      if (idx - first_idx > 1) {
        // Need to sort.
        std::sort(node->hashed_peers_.begin() + first_idx,
                  node->hashed_peers_.begin() + idx,
                  SigNode::HashedPeer::LessByRank());
      }

      cur_link_hash = entry.link_hash;
      first_idx = idx;
    }
    if (idx - first_idx > 1) {
      // Sort the last bunch.
      std::sort(node->hashed_peers_.begin() + first_idx,
                node->hashed_peers_.begin() + idx,
                SigNode::HashedPeer::LessByRank());
    }
  }
}
419
operator ==(const Signature & other) const420 bool Signature::operator==(const Signature& other) const {
421 // Tries to find the differences as early as possible by
422 // comparing the hashes first.
423
424 if (sig_short != other.sig_short) {
425 return false;
426 }
427 if (sig_full.size() != other.sig_full.size()) {
428 return false;
429 }
430
431 for (auto it1 = sig_full.begin(), it2 = other.sig_full.begin();
432 it1 != sig_full.end(); ++it1, ++it2) {
433 if (*it1 != *it2) {
434 return false;
435 }
436 }
437
438 if (nodes.size() != other.nodes.size()) {
439 return false;
440 }
441 for (auto it1 = nodes.begin(), it2 = other.nodes.begin(); it1 != nodes.end();
442 ++it1, ++it2) {
443 if (**it1 != **it2) {
444 return false;
445 }
446 }
447
448 return true;
449 }
450
451 } // end namespace graph_analyzer
452 } // end namespace grappler
453 } // end namespace tensorflow
454