1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/graph_analyzer/sig_node.h"
17 
18 #include <gmock/gmock.h>
19 #include <gtest/gtest.h>
20 #include "absl/memory/memory.h"
21 #include "absl/strings/str_format.h"
22 #include "tensorflow/core/grappler/graph_analyzer/subgraph.h"
23 #include "tensorflow/core/grappler/graph_analyzer/test_tools.h"
24 #include "tensorflow/core/grappler/utils.h"
25 
26 namespace tensorflow {
27 namespace grappler {
28 namespace graph_analyzer {
29 namespace test {
30 
31 using ::testing::ElementsAre;
32 using ::testing::Eq;
33 using ::testing::Gt;
34 using ::testing::Ne;
35 using ::testing::SizeIs;
36 
37 //===
38 
// Exercises SigNode::LinkTag comparison: tags built from the same pair of
// ports are equal and unordered relative to each other, while a difference in
// either port breaks equality and is reflected by operator<.
TEST(SigNodeLinkTag, Compare) {
  // Every port in this test is constructed with `false` as its first
  // argument; only the numeric ids vary.
  auto make_tag = [](int first_id, int second_id) {
    return SigNode::LinkTag(GenNode::Port(false, first_id),
                            GenNode::Port(false, second_id));
  };

  SigNode::LinkTag base = make_tag(1, 2);
  SigNode::LinkTag same = make_tag(1, 2);
  SigNode::LinkTag swapped = make_tag(2, 1);
  SigNode::LinkTag bigger_second = make_tag(1, 3);
  SigNode::LinkTag bigger_first = make_tag(2, 2);

  // Equality requires both ports to match.
  EXPECT_TRUE(base == same);
  EXPECT_FALSE(base == swapped);
  EXPECT_FALSE(base == bigger_first);

  // Equal tags are not ordered either way.
  EXPECT_FALSE(base < same);
  EXPECT_FALSE(same < base);

  // A larger first port wins even when the second port is smaller.
  EXPECT_TRUE(base < swapped);
  EXPECT_FALSE(swapped < base);

  // With equal first ports, the second port decides.
  EXPECT_TRUE(base < bigger_second);
  EXPECT_FALSE(bigger_second < base);
}
59 
60 //===
61 
62 class SigBaseTest : public ::testing::Test, protected TestGraphs {
63  protected:
BuildSigMap(const GraphDef & graph)64   void BuildSigMap(const GraphDef& graph) {
65     gen_map_.clear();
66     sig_.map.clear();
67     CHECK(GenNode::BuildGraphInMap(graph, &gen_map_).ok());
68     Subgraph::Identity id;
69     for (const auto& entry : gen_map_) {
70       id.insert(entry.second.get());
71     }
72     Subgraph sg(id);
73     sg.ExtractForSignature(&sig_.map);
74   }
75 
CopyLinksPass2(std::map<SigNode::LinkTag,SigNode::Link> * link_map,SigNode * node)76   static void CopyLinksPass2(
77       std::map<SigNode::LinkTag, SigNode::Link>* link_map, SigNode* node) {
78     node->CopyLinksPass2(link_map);
79   }
80 
ComputeTopoHash0(SigNode * node)81   static void ComputeTopoHash0(SigNode* node) { node->ComputeTopoHash0(); }
82 
ComputeTopoHash(int distance,SigNode * node)83   static void ComputeTopoHash(int distance, SigNode* node) {
84     node->ComputeTopoHash(distance);
85   }
86 
GetTopoHash(int distance,SigNode * node)87   static size_t GetTopoHash(int distance, SigNode* node) {
88     return node->GetTopoHash(distance);
89   }
90 
GetHighTopoHash(SigNode * node)91   static size_t GetHighTopoHash(SigNode* node) {
92     return node->GetHighTopoHash();
93   }
94 
ReHighTopoHash(SigNode * node)95   static void ReHighTopoHash(SigNode* node) { node->ReHighTopoHash(); }
96 
RefHashedPeers(SigNode * node)97   static SigNode::HashedPeerVector& RefHashedPeers(SigNode* node) {
98     return node->hashed_peers_;
99   }
RefUniqueRank(SigNode * node)100   static size_t& RefUniqueRank(SigNode* node) { return node->unique_rank_; }
RefHashIsFinal(SigNode * node)101   static bool& RefHashIsFinal(SigNode* node) { return node->hash_is_final_; }
RefTopoHash(SigNode * node)102   static std::vector<size_t>& RefTopoHash(SigNode* node) {
103     return node->topo_hash_;
104   }
RefNodeMask(SigNode * node)105   static uint64_t& RefNodeMask(SigNode* node) { return node->node_mask_; }
RefLastHashedNodes(SigNode * node)106   static uint64_t& RefLastHashedNodes(SigNode* node) {
107     return node->last_hashed_nodes_;
108   }
RefNextHashedNodes(SigNode * node)109   static uint64_t& RefNextHashedNodes(SigNode* node) {
110     return node->next_hashed_nodes_;
111   }
112 
PrepareNodes(Signature * signature)113   static void PrepareNodes(Signature* signature) { signature->PrepareNodes(); }
114 
FindUniqueHashes(size_t * next_node_id_p,Signature * signature)115   static void FindUniqueHashes(size_t* next_node_id_p, Signature* signature) {
116     signature->FindUniqueHashes(next_node_id_p);
117   }
118 
ComputeOneRound(size_t next_node_id,Signature * signature)119   static void ComputeOneRound(size_t next_node_id, Signature* signature) {
120     signature->ComputeOneRound(next_node_id);
121   }
122 
OrderLinks(Signature * signature)123   static void OrderLinks(Signature* signature) { signature->OrderLinks(); }
124 
125   // These get initialized in BuildSigMap().
126   GenNodeMap gen_map_;
127   Signature sig_;
128 };
129 
130 //===
131 
132 class SigNodeTest : public SigBaseTest {};
133 
134 // Tests that the duplicate hashes get resolved by rehashing.
TEST_F(SigNodeTest,DuplicateHash)135 TEST_F(SigNodeTest, DuplicateHash) {
136   NodeDef node1 = MakeNodeConst("node1");
137   NodeDef node2 = MakeNodeConst("node2");
138   NodeDef node3 = MakeNodeShapeN("node3", "node1", "node2");
139 
140   SigNode sn1(&node1);
141   SigNode sn2(&node2);
142   SigNode sn3(&node3);
143 
144   constexpr size_t kSameHash = 999;
145 
146   SigNode::Link link1;
147   link1.tag = SigNode::LinkTag(GenNode::Port(true, 0), GenNode::Port(false, 0));
148   link1.unique_hash = kSameHash;
149   link1.peers.emplace_back(&sn1);
150 
151   SigNode::Link link2;
152   link2.tag = SigNode::LinkTag(GenNode::Port(true, 1), GenNode::Port(false, 0));
153   link2.unique_hash = kSameHash;
154   link2.peers.emplace_back(&sn2);
155 
156   SigNode::Link link3;
157   link3.tag = SigNode::LinkTag(GenNode::Port(true, 2), GenNode::Port(false, 0));
158   link3.unique_hash = kSameHash;
159   link3.peers.emplace_back(&sn3);
160 
161   std::map<SigNode::LinkTag, SigNode::Link> link_map;
162   link_map[link1.tag] = link1;
163   link_map[link2.tag] = link2;
164   link_map[link3.tag] = link3;
165 
166   CopyLinksPass2(&link_map, &sn3);
167   auto& hl = sn3.hash_to_link();
168   EXPECT_THAT(hl, SizeIs(3));
169 
170   // Check that the hashes are self_consistent, and put the entries into
171   // another map with a known order.
172   std::map<SigNode::LinkTag, SigNode::Link> rehashed;
173   auto hlit = hl.begin();
174   ASSERT_THAT(hlit, Ne(hl.end()));
175   EXPECT_THAT(hlit->second.unique_hash, Eq(hlit->first));
176   rehashed[hlit->second.tag] = hlit->second;
177   ++hlit;
178   ASSERT_THAT(hlit, Ne(hl.end()));
179   EXPECT_THAT(hlit->second.unique_hash, Eq(hlit->first));
180   rehashed[hlit->second.tag] = hlit->second;
181   ++hlit;
182   ASSERT_THAT(hlit, Ne(hl.end()));
183   EXPECT_THAT(hlit->second.unique_hash, Eq(hlit->first));
184   rehashed[hlit->second.tag] = hlit->second;
185 
186   // Just in case.
187   ASSERT_THAT(rehashed, SizeIs(3));
188 
189   auto rhit = rehashed.begin();
190   ASSERT_THAT(rhit, Ne(rehashed.end()));
191   EXPECT_TRUE(rhit->second.tag == link1.tag);
192   EXPECT_THAT(rhit->second.unique_hash, Eq(kSameHash));
193   EXPECT_THAT(rhit->second.peers, ElementsAre(&sn1));
194 
195   ++rhit;
196   ASSERT_THAT(rhit, Ne(rehashed.end()));
197   EXPECT_TRUE(rhit->second.tag == link2.tag);
198   // This hash must be rehashed.
199   EXPECT_THAT(rhit->second.unique_hash, Ne(kSameHash));
200   size_t hash2 = rhit->second.unique_hash;
201   EXPECT_THAT(rhit->second.peers, ElementsAre(&sn2));
202 
203   ++rhit;
204   ASSERT_THAT(rhit, Ne(rehashed.end()));
205   EXPECT_TRUE(rhit->second.tag == link3.tag);
206   // This hash must be rehashed.
207   EXPECT_THAT(rhit->second.unique_hash, Ne(kSameHash));
208   EXPECT_THAT(rhit->second.unique_hash, Ne(hash2));
209   size_t hash3 = rhit->second.unique_hash;
210   EXPECT_THAT(rhit->second.peers, ElementsAre(&sn3));
211 
212   auto& peers = sn3.hashed_peers();
213   EXPECT_THAT(peers, SizeIs(3));
214 
215   auto peerit = peers.begin();
216   ASSERT_THAT(peerit, Ne(peers.end()));
217   EXPECT_THAT(peerit->link_hash, Eq(kSameHash));
218   EXPECT_THAT(peerit->peer, Eq(&sn1));
219 
220   ++peerit;
221   ASSERT_THAT(peerit, Ne(peers.end()));
222   EXPECT_THAT(peerit->link_hash, Eq(hash2));
223   EXPECT_THAT(peerit->peer, Eq(&sn2));
224 
225   ++peerit;
226   ASSERT_THAT(peerit, Ne(peers.end()));
227   EXPECT_THAT(peerit->link_hash, Eq(hash3));
228   EXPECT_THAT(peerit->peer, Eq(&sn3));
229 }
230 
231 // The full CopyLinks() is tested in (SubgraphTest, ExtractForSignature).
232 
// GetTopoHash() returns the hash stored for the requested distance. Once the
// hash is marked as final, a distance past the stored ones yields the last
// stored hash, which is also what GetHighTopoHash() returns.
TEST_F(SigNodeTest, GetTopoHash) {
  NodeDef def = MakeNodeConst("node1");
  SigNode node(&def);

  // Plant fake hashes for the distances 0 and 1.
  std::vector<size_t>& hashes = RefTopoHash(&node);
  hashes.emplace_back(123);
  hashes.emplace_back(456);

  EXPECT_THAT(GetTopoHash(0, &node), Eq(123));
  EXPECT_THAT(GetTopoHash(1, &node), Eq(456));

  RefHashIsFinal(&node) = true;

  EXPECT_THAT(GetTopoHash(0, &node), Eq(123));
  EXPECT_THAT(GetTopoHash(1, &node), Eq(456));
  // Distance 2 was never stored, the final hash is reported instead.
  EXPECT_THAT(GetTopoHash(2, &node), Eq(456));

  EXPECT_THAT(GetHighTopoHash(&node), Eq(456));
}
252 
// ReHighTopoHash() must replace the highest stored hash with the result of
// CombineHash(1, <old value>) and leave the lower-distance hash alone.
TEST_F(SigNodeTest, ReTopoHash) {
  NodeDef def = MakeNodeConst("node1");
  SigNode node(&def);

  // Plant fake hashes for the distances 0 and 1.
  std::vector<size_t>& hashes = RefTopoHash(&node);
  hashes.emplace_back(123);
  hashes.emplace_back(456);

  EXPECT_THAT(GetTopoHash(0, &node), Eq(123));
  EXPECT_THAT(GetTopoHash(1, &node), Eq(456));

  ReHighTopoHash(&node);

  // Reproduce the rehashing of the highest value by hand.
  size_t rehashed_high = 456;
  CombineHash(1, &rehashed_high);

  EXPECT_THAT(GetTopoHash(0, &node), Eq(123));
  EXPECT_THAT(GetTopoHash(1, &node), Eq(rehashed_high));
}
272 
// ComputeTopoHash0() must discard any previous hashes, reset both hashed-node
// sets to the node's own mask, and store a single hash built from the opcode
// and the link hashes of the hashed peers, taken in order.
TEST_F(SigNodeTest, ComputeTopoHash0) {
  NodeDef def = MakeNodeConst("node1");
  SigNode node(&def);

  // Fake a topology.
  RefUniqueRank(&node) = 10;
  RefNodeMask(&node) = 0x02;

  // Stale hashes that are expected to be thrown away.
  RefTopoHash(&node).emplace_back(123);
  RefTopoHash(&node).emplace_back(456);

  // Stale state that is expected to be overwritten.
  RefLastHashedNodes(&node) = 0xFF;
  RefNextHashedNodes(&node) = 0xFF;

  // The link hashes of the peers, duplicates included. The peers themselves
  // are not needed for this computation, so they are left as null.
  const size_t kLinkHashes[] = {1, 1, 2, 3, 3};
  for (size_t link_hash : kLinkHashes) {
    RefHashedPeers(&node).emplace_back(SigNode::HashedPeer(link_hash, nullptr));
  }

  // Run the test.
  ComputeTopoHash0(&node);

  EXPECT_THAT(RefLastHashedNodes(&node), Eq(0x02));
  EXPECT_THAT(RefNextHashedNodes(&node), Eq(0x02));
  EXPECT_THAT(RefTopoHash(&node), SizeIs(1));

  // The expected value: the opcode hash combined with every link hash.
  size_t expected = std::hash<string>()(node.opcode());
  for (size_t link_hash : kLinkHashes) {
    CombineHash(link_hash, &expected);
  }

  EXPECT_THAT(GetTopoHash(0, &node), Eq(expected));
}
310 
TEST_F(SigNodeTest,ComputeTopoHashNotFinal)311 TEST_F(SigNodeTest, ComputeTopoHashNotFinal) {
312   NodeDef node1 = MakeNodeConst("node1");
313   SigNode sn1(&node1);
314   NodeDef node2 = MakeNodeConst("node2");
315   SigNode sn2(&node2);
316   NodeDef node3 = MakeNodeConst("node3");
317   SigNode sn3(&node3);
318 
319   // Fake a topology.
320   RefUniqueRank(&sn1) = 0;
321   RefNodeMask(&sn1) = 0x01;
322   RefUniqueRank(&sn2) = 0;
323   RefNodeMask(&sn2) = 0x02;
324   RefUniqueRank(&sn3) = 0;
325   RefNodeMask(&sn3) = 0x04;
326 
327   RefHashedPeers(&sn1).emplace_back(SigNode::HashedPeer(10, &sn2));
328   RefHashedPeers(&sn1).emplace_back(SigNode::HashedPeer(10, &sn3));
329   RefHashedPeers(&sn1).emplace_back(SigNode::HashedPeer(20, &sn2));
330   RefHashedPeers(&sn1).emplace_back(SigNode::HashedPeer(30, &sn3));
331   RefHashedPeers(&sn1).emplace_back(SigNode::HashedPeer(30, &sn2));
332 
333   // Fake a state.
334   RefTopoHash(&sn1).emplace_back(123);
335   RefTopoHash(&sn1).emplace_back(321);
336 
337   RefTopoHash(&sn2).emplace_back(456);
338   RefTopoHash(&sn2).emplace_back(654);
339 
340   RefTopoHash(&sn3).emplace_back(789);
341   RefTopoHash(&sn3).emplace_back(987);
342 
343   // These values are not realistic in the way that they don't include the bits
344   // from the mask of nodes themselves, but that's the point of this test: only
345   // the previous nodes' node sets are used in the computation, not their own
346   // masks directly.
347   RefLastHashedNodes(&sn1) = 0x8;
348   RefLastHashedNodes(&sn2) = 0x10;
349   RefLastHashedNodes(&sn3) = 0x20;
350 
351   // A scratch value to get overwritten.
352   RefNextHashedNodes(&sn1) = 0x100;
353 
354   ComputeTopoHash(2, &sn1);
355 
356   EXPECT_THAT(RefLastHashedNodes(&sn1), Eq(0x8));  // Unchanged.
357   EXPECT_THAT(RefNextHashedNodes(&sn1), Eq(0x38));
358 
359   // This computes the hash form the explicit numbers above.
360   size_t exp_hash = 123;  // The 0th hash is the starting point.
361   size_t comm_hash;
362 
363   comm_hash = 0;
364   CombineHashCommutative(654, &comm_hash);
365   CombineHashCommutative(987, &comm_hash);
366 
367   CombineHash(10, &exp_hash);
368   CombineHash(comm_hash, &exp_hash);
369 
370   comm_hash = 0;
371   CombineHashCommutative(654, &comm_hash);
372 
373   CombineHash(20, &exp_hash);
374   CombineHash(comm_hash, &exp_hash);
375 
376   comm_hash = 0;
377   CombineHashCommutative(654, &comm_hash);
378   CombineHashCommutative(987, &comm_hash);
379 
380   CombineHash(30, &exp_hash);
381   CombineHash(comm_hash, &exp_hash);
382 
383   EXPECT_THAT(GetTopoHash(2, &sn1), Eq(exp_hash));
384   EXPECT_THAT(RefTopoHash(&sn1), SizeIs(3));
385 }
386 
// ComputeTopoHash() on a node whose hash is already final: no new hash gets
// appended, the next set of hashed nodes is just a copy of the last one, and
// the hash reported for the new distance is the last stored one.
TEST_F(SigNodeTest, ComputeTopoHashFinal) {
  NodeDef def1 = MakeNodeConst("node1");
  SigNode sn1(&def1);
  NodeDef def2 = MakeNodeConst("node2");
  SigNode sn2(&def2);
  NodeDef def3 = MakeNodeConst("node3");
  SigNode sn3(&def3);

  // Fake a topology - same as for ComputeTopoHashNotFinal.
  auto fake_node = [](SigNode* node, uint64_t mask) {
    RefUniqueRank(node) = 0;
    RefNodeMask(node) = mask;
  };
  fake_node(&sn1, 0x01);
  fake_node(&sn2, 0x02);
  fake_node(&sn3, 0x04);

  auto& peers = RefHashedPeers(&sn1);
  peers.emplace_back(SigNode::HashedPeer(10, &sn2));
  peers.emplace_back(SigNode::HashedPeer(10, &sn3));
  peers.emplace_back(SigNode::HashedPeer(20, &sn2));
  peers.emplace_back(SigNode::HashedPeer(30, &sn3));
  peers.emplace_back(SigNode::HashedPeer(30, &sn2));

  // Fake a state - mostly same as for ComputeTopoHashNotFinal.
  auto fake_hashes = [](SigNode* node, size_t hash0, size_t hash1) {
    RefTopoHash(node).emplace_back(hash0);
    RefTopoHash(node).emplace_back(hash1);
  };
  fake_hashes(&sn1, 123, 321);
  fake_hashes(&sn2, 456, 654);
  fake_hashes(&sn3, 789, 987);

  // These values are not realistic, because they don't include the bits of
  // the nodes' own masks. But that is the point of this test: only the
  // previous node sets are used in the computation, not the masks directly.
  RefLastHashedNodes(&sn1) = 0x8;
  RefLastHashedNodes(&sn2) = 0x10;
  RefLastHashedNodes(&sn3) = 0x20;

  // A scratch value to get overwritten.
  RefNextHashedNodes(&sn1) = 0x100;

  // This is the difference in configuration.
  RefHashIsFinal(&sn1) = true;

  ComputeTopoHash(2, &sn1);

  EXPECT_THAT(RefLastHashedNodes(&sn1), Eq(0x8));  // Unchanged.
  EXPECT_THAT(RefNextHashedNodes(&sn1), Eq(0x8));
  EXPECT_THAT(RefTopoHash(&sn1), SizeIs(2));
  EXPECT_THAT(GetTopoHash(2, &sn1), Eq(321));
}
440 
// Two nodes with the same opcode compare equal; changing the op of one of the
// underlying NodeDefs makes them compare unequal.
TEST_F(SigNodeTest, EqualsOpcode) {
  NodeDef def_a = MakeNodeConst("node1");
  SigNode node_a(&def_a);

  NodeDef def_b = MakeNodeConst("node2");
  SigNode node_b(&def_b);

  // Same op, different names: still equal.
  EXPECT_TRUE(node_a == node_b);
  EXPECT_FALSE(node_a != node_b);

  // The change is made on the NodeDef after the SigNode was built from it.
  def_b.set_op("Mul");

  EXPECT_TRUE(node_a != node_b);
  EXPECT_FALSE(node_a == node_b);
}
456 
// Two otherwise identical nodes stop being equal once their unique ranks
// differ.
TEST_F(SigNodeTest, EqualsRank) {
  NodeDef def_a = MakeNodeConst("node1");
  SigNode node_a(&def_a);

  NodeDef def_b = MakeNodeConst("node2");
  SigNode node_b(&def_b);

  // Freshly constructed nodes with the same op are equal.
  EXPECT_TRUE(node_a == node_b);
  EXPECT_FALSE(node_a != node_b);

  // Assign different ranks.
  RefUniqueRank(&node_a) = 1;
  RefUniqueRank(&node_b) = 2;

  EXPECT_TRUE(node_a != node_b);
  EXPECT_FALSE(node_a == node_b);
}
473 
474 // Checks that if the nodes have a different number of links,
475 // they will be considered unequal.
TEST_F(SigNodeTest,EqualsLinkSize)476 TEST_F(SigNodeTest, EqualsLinkSize) {
477   GraphDef graph1;
478   (*graph1.add_node()) = MakeNodeConst("node1");
479   (*graph1.add_node()) = MakeNodeMul("node2", "node1", "node1");
480 
481   GenNodeMap gen_map1;
482   ASSERT_THAT(GenNode::BuildGraphInMap(graph1, &gen_map1), Eq(Status::OK()));
483 
484   Subgraph::Identity id1;
485   id1.insert(gen_map1["node1"].get());
486   id1.insert(gen_map1["node2"].get());
487   Subgraph sg1(id1);
488 
489   SigNodeMap sig_map1;
490   sg1.ExtractForSignature(&sig_map1);
491 
492   GraphDef graph2;
493   (*graph2.add_node()) = MakeNodeConst("node1");
494   // The difference between graph1 and graph2: one more input.
495   auto node22 = graph2.add_node();
496   *node22 = MakeNodeMul("node2", "node1", "node1");
497   node22->add_input("node2");
498 
499   GenNodeMap gen_map2;
500   ASSERT_THAT(GenNode::BuildGraphInMap(graph2, &gen_map2), Eq(Status::OK()));
501 
502   Subgraph::Identity id2;
503   id2.insert(gen_map2["node1"].get());
504   id2.insert(gen_map2["node2"].get());
505   Subgraph sg2(id2);
506 
507   SigNodeMap sig_map2;
508   sg2.ExtractForSignature(&sig_map2);
509 
510   EXPECT_TRUE(*sig_map1["node1"] == *sig_map2["node1"]);
511   EXPECT_FALSE(*sig_map1["node2"] == *sig_map2["node2"]);
512   EXPECT_FALSE(*sig_map2["node2"] == *sig_map1["node2"]);
513 }
514 
// Nodes extracted from two copies of the same graph compare equal; a change
// in a link hash, or in the unique rank of a linked node, breaks the equality.
TEST_F(SigNodeTest, EqualsLinks) {
  // A single graph definition serves as the source of both copies.
  GraphDef graph;
  (*graph.add_node()) = MakeNodeConst("node1");
  (*graph.add_node()) = MakeNodeMul("node2", "node1", "node1");

  // The first copy.
  GenNodeMap gen_map_a;
  ASSERT_THAT(GenNode::BuildGraphInMap(graph, &gen_map_a), Eq(Status::OK()));

  Subgraph::Identity id_a;
  id_a.insert(gen_map_a["node1"].get());
  id_a.insert(gen_map_a["node2"].get());
  Subgraph subgraph_a(id_a);

  SigNodeMap sig_map_a;
  subgraph_a.ExtractForSignature(&sig_map_a);

  // The second copy.
  GenNodeMap gen_map_b;
  ASSERT_THAT(GenNode::BuildGraphInMap(graph, &gen_map_b), Eq(Status::OK()));

  Subgraph::Identity id_b;
  id_b.insert(gen_map_b["node1"].get());
  id_b.insert(gen_map_b["node2"].get());
  Subgraph subgraph_b(id_b);

  SigNodeMap sig_map_b;
  subgraph_b.ExtractForSignature(&sig_map_b);

  EXPECT_TRUE(*sig_map_a["node1"] == *sig_map_b["node1"]);
  EXPECT_TRUE(*sig_map_a["node2"] == *sig_map_b["node2"]);

  // Alter the link hash of one of the nodes.
  SigNode* mul_b = sig_map_b["node2"].get();
  ++RefHashedPeers(mul_b)[0].link_hash;

  EXPECT_FALSE(*sig_map_a["node2"] == *sig_map_b["node2"]);

  // Restore it back, the equality must return.
  --RefHashedPeers(mul_b)[0].link_hash;
  EXPECT_TRUE(*sig_map_a["node2"] == *sig_map_b["node2"]);

  // Alter the unique rank of a referenced node.
  ++RefUniqueRank(sig_map_b["node1"].get());

  EXPECT_FALSE(*sig_map_a["node2"] == *sig_map_b["node2"]);
}
561 
562 //===
563 
564 class SignatureTest : public SigBaseTest {
565  protected:
566   // Initializeds the state used to generate the permutations of a given size.
InitPermutation(size_t size,std::vector<size_t> * plain_permutation,std::vector<size_t> * countdown)567   static void InitPermutation(size_t size,
568                               std::vector<size_t>* plain_permutation,
569                               std::vector<size_t>* countdown) {
570     plain_permutation->clear();
571     countdown->clear();
572     for (size_t i = 0; i < size; ++i) {
573       plain_permutation->emplace_back(i);
574       countdown->emplace_back(size - 1 - i);
575     }
576   }
577 
578   // Builds a permutation guided by the count-down value.
BuildPermutation(const std::vector<size_t> & plain_permutation,const std::vector<size_t> & countdown,std::vector<size_t> * result)579   static void BuildPermutation(const std::vector<size_t>& plain_permutation,
580                                const std::vector<size_t>& countdown,
581                                std::vector<size_t>* result) {
582     *result = plain_permutation;
583     for (int i = 0; i < result->size(); ++i) {
584       std::swap((*result)[i], (*result)[i + countdown[i]]);
585     }
586   }
587 
588   // Returns false when the count-down is finished.
CountDown(std::vector<size_t> * countdown)589   static bool CountDown(std::vector<size_t>* countdown) {
590     // The last position always contains 0, so skip it.
591     int pos;
592     for (pos = countdown->size() - 2; pos >= 0; --pos) {
593       if ((*countdown)[pos] > 0) {
594         --(*countdown)[pos];
595         break;
596       }
597       (*countdown)[pos] = (countdown->size() - 1 - pos);
598     }
599 
600     return pos >= 0;
601   }
602 
603   // Permutes the nodes every which way and checks that all the signatures
604   // produced are the same. This is reasonable for the graphs up to the
605   // size 5, maybe 6 at the stretch. After that the number of permutation grows
606   // huge and the test becomes very slow.
TestGraphEveryWay(const GraphDef & graph)607   void TestGraphEveryWay(const GraphDef& graph) {
608     size_t graph_size = graph.node_size();
609 
610     gen_map_.clear();
611     sig_.map.clear();
612     Status result = GenNode::BuildGraphInMap(graph, &gen_map_);
613     ASSERT_THAT(result, Eq(Status::OK()));
614     Subgraph::Identity id;
615     for (const auto& entry : gen_map_) {
616       id.insert(entry.second.get());
617     }
618     Subgraph sg(id);
619     sg.ExtractForSignature(&sig_.map);
620 
621     std::vector<size_t> plain_permutation;
622     std::vector<size_t> countdown;
623     InitPermutation(graph_size, &plain_permutation, &countdown);
624 
625     std::set<string> signatures;
626     std::vector<size_t> permutation;
627     do {
628       BuildPermutation(plain_permutation, countdown, &permutation);
629 
630       constexpr bool kDebugPermutation = false;
631       if (kDebugPermutation) {
632         string p;
633         for (int i = 0; i < permutation.size(); ++i) {
634           p.push_back('0' + permutation[i]);
635         }
636         LOG(INFO) << "Permutation: " << p;
637       }
638 
639       std::vector<std::unique_ptr<SigNode>> hold(graph_size);
640       int idx;
641 
642       // Permute the nodes.
643       sig_.nodes.clear();
644       idx = 0;
645       if (kDebugPermutation) {
646         LOG(INFO) << "    nodes before permutation:";
647       }
648       for (auto& entry : sig_.map) {
649         if (kDebugPermutation) {
650           LOG(INFO) << "        " << entry.second.get();
651         }
652         hold[idx++] = std::move(entry.second);
653       }
654       idx = 0;
655       if (kDebugPermutation) {
656         LOG(INFO) << "    nodes after permutation:";
657       }
658       for (auto& entry : sig_.map) {
659         entry.second = std::move(hold[permutation[idx++]]);
660         if (kDebugPermutation) {
661           LOG(INFO) << "        " << entry.second.get();
662         }
663         // This is used to order the links per permutation.
664         sig_.nodes.emplace_back(entry.second.get());
665         RefUniqueRank(entry.second.get()) = idx;
666       }
667       // Order the links with the same tags per permutation.
668       OrderLinks(&sig_);
669 
670       // The test as such.
671       ASSERT_THAT(sig_.Compute(), Eq(Status::OK()));
672 
673       signatures.insert(sig_.ToString());
674 
675       EXPECT_THAT(sig_.sig_full, SizeIs(graph_size));
676       size_t hval = 0;
677       for (size_t ih : sig_.sig_full) {
678         // The space 1..graph_size is reserved.
679         EXPECT_THAT(ih, Gt(graph_size));
680         CombineHash(ih, &hval);
681       }
682       EXPECT_THAT(sig_.sig_short, Eq(hval));
683 
684       // Un-permute the nodes for the next iteration.
685       idx = 0;
686       for (auto& entry : sig_.map) {
687         hold[permutation[idx++]] = std::move(entry.second);
688       }
689       idx = 0;
690       if (kDebugPermutation) {
691         LOG(INFO) << "    nodes after un-permutation:";
692       }
693       for (auto& entry : sig_.map) {
694         entry.second = std::move(hold[idx++]);
695         if (kDebugPermutation) {
696           LOG(INFO) << "        " << entry.second.get();
697         }
698       }
699     } while (CountDown(&countdown));
700 
701     for (const auto& s : signatures) {
702       LOG(INFO) << "Signature: " << s;
703     }
704 
705     // All the permutations should produce the same signature.
706     EXPECT_THAT(signatures, SizeIs(1));
707   }
708 };
709 
// PrepareNodes() must list every node of the map in sig_.nodes and bring each
// of them to the initial state: a single mask bit following the map order, an
// all-ones unique rank, a non-final hash, and exactly one topological hash.
TEST_F(SignatureTest, PrepareNodes) {
  NodeDef def1 = MakeNodeConst("node1");
  sig_.map["node1"] = absl::make_unique<SigNode>(&def1);
  NodeDef def2 = MakeNodeConst("node2");
  sig_.map["node2"] = absl::make_unique<SigNode>(&def2);
  NodeDef def3 = MakeNodeConst("node3");
  sig_.map["node3"] = absl::make_unique<SigNode>(&def3);

  PrepareNodes(&sig_);

  ASSERT_THAT(sig_.nodes, SizeIs(3));

  int idx = 0;
  for (const auto& name_and_node : sig_.map) {
    SigNode* node = name_and_node.second.get();
    EXPECT_THAT(RefNodeMask(node), Eq(1 << idx)) << " at index " << idx;
    EXPECT_THAT(RefUniqueRank(node), Eq(static_cast<size_t>(~0)))
        << " at index " << idx;
    EXPECT_THAT(RefHashIsFinal(node), Eq(false)) << " at index " << idx;
    EXPECT_THAT(RefTopoHash(node), SizeIs(1)) << " at index " << idx;
    ++idx;
  }
}
735 
TEST_F(SignatureTest,FindUniqueHashesAllDifferent)736 TEST_F(SignatureTest, FindUniqueHashesAllDifferent) {
737   NodeDef node1 = MakeNodeConst("node1");
738   SigNode sn1(&node1);
739   NodeDef node2 = MakeNodeConst("node2");
740   SigNode sn2(&node2);
741   NodeDef node3 = MakeNodeConst("node3");
742   SigNode sn3(&node3);
743   NodeDef node4 = MakeNodeConst("node4");
744   SigNode sn4(&node4);
745 
746   // The last values in the arrays values go in the backwards order.
747   RefTopoHash(&sn1).emplace_back(100);
748   RefTopoHash(&sn1).emplace_back(900);
749 
750   RefTopoHash(&sn2).emplace_back(200);
751   RefTopoHash(&sn2).emplace_back(800);
752 
753   RefTopoHash(&sn3).emplace_back(300);
754   RefTopoHash(&sn3).emplace_back(700);
755 
756   RefTopoHash(&sn4).emplace_back(400);
757   RefTopoHash(&sn4).emplace_back(600);
758 
759   sig_.nodes.emplace_back(&sn1);
760   sig_.nodes.emplace_back(&sn2);
761   sig_.nodes.emplace_back(&sn3);
762   sig_.nodes.emplace_back(&sn4);
763 
764   size_t next = 1;  // Skips over sn1.
765 
766   FindUniqueHashes(&next, &sig_);
767   EXPECT_THAT(next, Eq(4));
768 
769   EXPECT_THAT(sig_.nodes[0], Eq(&sn1));
770   // The nodes after first one get sorted by the high hash.
771   EXPECT_THAT(sig_.nodes[1], Eq(&sn4));
772   EXPECT_THAT(sig_.nodes[2], Eq(&sn3));
773   EXPECT_THAT(sig_.nodes[3], Eq(&sn2));
774 
775   EXPECT_THAT(RefHashIsFinal(&sn1), Eq(false));
776   // Nodes that get finalized are marked as such.
777   EXPECT_THAT(RefHashIsFinal(&sn2), Eq(true));
778   EXPECT_THAT(RefHashIsFinal(&sn3), Eq(true));
779   EXPECT_THAT(RefHashIsFinal(&sn4), Eq(true));
780 
781   EXPECT_THAT(RefTopoHash(&sn1), SizeIs(2));
782   ASSERT_THAT(RefTopoHash(&sn2), SizeIs(1));
783   ASSERT_THAT(RefTopoHash(&sn3), SizeIs(1));
784   ASSERT_THAT(RefTopoHash(&sn4), SizeIs(1));
785 
786   EXPECT_THAT(RefTopoHash(&sn2)[0], Eq(4));
787   EXPECT_THAT(RefTopoHash(&sn3)[0], Eq(3));
788   EXPECT_THAT(RefTopoHash(&sn4)[0], Eq(2));
789 
790   EXPECT_THAT(sig_.sig_full, ElementsAre(600, 700, 800));
791 
792   size_t exp_short_hash = 0;
793   CombineHash(600, &exp_short_hash);
794   CombineHash(700, &exp_short_hash);
795   CombineHash(800, &exp_short_hash);
796   EXPECT_THAT(sig_.sig_short, Eq(exp_short_hash));
797 }
798 
TEST_F(SignatureTest,FindUniqueHashesDuplicatesExceptOne)799 TEST_F(SignatureTest, FindUniqueHashesDuplicatesExceptOne) {
800   NodeDef node1 = MakeNodeConst("node1");
801   SigNode sn1(&node1);
802   NodeDef node2 = MakeNodeConst("node2");
803   SigNode sn2(&node2);
804   NodeDef node3 = MakeNodeConst("node3");
805   SigNode sn3(&node3);
806   NodeDef node4 = MakeNodeConst("node4");
807   SigNode sn4(&node4);
808   NodeDef node5 = MakeNodeConst("node5");
809   SigNode sn5(&node5);
810 
811   RefTopoHash(&sn1).emplace_back(100);
812   RefTopoHash(&sn1).emplace_back(600);
813 
814   RefTopoHash(&sn2).emplace_back(200);
815   RefTopoHash(&sn2).emplace_back(600);
816 
817   RefTopoHash(&sn3).emplace_back(300);
818   RefTopoHash(&sn3).emplace_back(700);
819 
820   RefTopoHash(&sn4).emplace_back(400);
821   RefTopoHash(&sn4).emplace_back(800);
822 
823   RefTopoHash(&sn5).emplace_back(500);
824   RefTopoHash(&sn5).emplace_back(800);
825 
826   sig_.nodes.emplace_back(&sn1);
827   sig_.nodes.emplace_back(&sn2);
828   sig_.nodes.emplace_back(&sn3);
829   sig_.nodes.emplace_back(&sn4);
830   sig_.nodes.emplace_back(&sn5);
831 
832   size_t next = 0;
833 
834   FindUniqueHashes(&next, &sig_);
835   EXPECT_THAT(next, Eq(1));
836 
837   // The unique node goes first.
838   EXPECT_THAT(sig_.nodes[0], Eq(&sn3));
839 
840   // The rest of the nodes are assumed to be sorted in a stable order.
841   EXPECT_THAT(sig_.nodes[1], Eq(&sn2));
842   // Node 1 gets swapped with node 3.
843   EXPECT_THAT(sig_.nodes[2], Eq(&sn1));
844   EXPECT_THAT(sig_.nodes[3], Eq(&sn4));
845   EXPECT_THAT(sig_.nodes[4], Eq(&sn5));
846 
847   EXPECT_THAT(RefHashIsFinal(&sn1), Eq(false));
848   EXPECT_THAT(RefHashIsFinal(&sn2), Eq(false));
849   EXPECT_THAT(RefHashIsFinal(&sn3), Eq(true));
850   EXPECT_THAT(RefHashIsFinal(&sn4), Eq(false));
851   EXPECT_THAT(RefHashIsFinal(&sn5), Eq(false));
852 
853   EXPECT_THAT(RefTopoHash(&sn1), SizeIs(2));
854   EXPECT_THAT(RefTopoHash(&sn2), SizeIs(2));
855   EXPECT_THAT(RefTopoHash(&sn3), SizeIs(1));
856   EXPECT_THAT(RefTopoHash(&sn4), SizeIs(2));
857   EXPECT_THAT(RefTopoHash(&sn5), SizeIs(2));
858 
859   EXPECT_THAT(RefTopoHash(&sn3)[0], Eq(1));
860 }
861 
TEST_F(SignatureTest,FindUniqueHashesDuplicates)862 TEST_F(SignatureTest, FindUniqueHashesDuplicates) {
863   NodeDef node1 = MakeNodeConst("node1");
864   SigNode sn1(&node1);
865   NodeDef node2 = MakeNodeConst("node2");
866   SigNode sn2(&node2);
867   NodeDef node3 = MakeNodeConst("node3");
868   SigNode sn3(&node3);
869   NodeDef node4 = MakeNodeConst("node4");
870   SigNode sn4(&node4);
871   NodeDef node5 = MakeNodeConst("node5");
872   SigNode sn5(&node5);
873 
874   RefTopoHash(&sn1).emplace_back(100);
875   RefTopoHash(&sn1).emplace_back(600);
876 
877   RefTopoHash(&sn2).emplace_back(200);
878   RefTopoHash(&sn2).emplace_back(600);
879 
880   RefTopoHash(&sn3).emplace_back(300);
881   RefTopoHash(&sn3).emplace_back(700);
882 
883   RefTopoHash(&sn4).emplace_back(400);
884   RefTopoHash(&sn4).emplace_back(700);
885 
886   RefTopoHash(&sn5).emplace_back(500);
887   RefTopoHash(&sn5).emplace_back(700);
888 
889   sig_.nodes.emplace_back(&sn1);
890   sig_.nodes.emplace_back(&sn2);
891   sig_.nodes.emplace_back(&sn3);
892   sig_.nodes.emplace_back(&sn4);
893   sig_.nodes.emplace_back(&sn5);
894 
895   size_t next = 0;
896 
897   FindUniqueHashes(&next, &sig_);
898   EXPECT_THAT(next, Eq(1));
899 
900   // The last copy of the last duplicate wins.
901   EXPECT_THAT(sig_.nodes[0], Eq(&sn5));
902 
903   // The rest of the nodes are assumed to be sorted in a stable order.
904   // Node 1 gets swapped.
905   EXPECT_THAT(sig_.nodes[1], Eq(&sn2));
906   EXPECT_THAT(sig_.nodes[2], Eq(&sn3));
907   EXPECT_THAT(sig_.nodes[3], Eq(&sn4));
908   EXPECT_THAT(sig_.nodes[4], Eq(&sn1));
909 
910   EXPECT_THAT(RefHashIsFinal(&sn1), Eq(false));
911   EXPECT_THAT(RefHashIsFinal(&sn2), Eq(false));
912   EXPECT_THAT(RefHashIsFinal(&sn3), Eq(false));
913   EXPECT_THAT(RefHashIsFinal(&sn4), Eq(false));
914   EXPECT_THAT(RefHashIsFinal(&sn5), Eq(true));
915 
916   EXPECT_THAT(RefTopoHash(&sn1), SizeIs(2));
917   EXPECT_THAT(RefTopoHash(&sn2), SizeIs(2));
918   EXPECT_THAT(RefTopoHash(&sn3), SizeIs(2));
919   EXPECT_THAT(RefTopoHash(&sn4), SizeIs(2));
920   EXPECT_THAT(RefTopoHash(&sn5), SizeIs(1));
921 
922   EXPECT_THAT(RefTopoHash(&sn5)[0], Eq(1));
923 }
924 
925 // On a circular topology.
TEST_F(SignatureTest, ComputeOneRoundCircular) {
  BuildSigMap(graph_circular_onedir_);
  PrepareNodes(&sig_);

  ASSERT_THAT(sig_.nodes, SizeIs(5));

  // FindUniqueHashes() is deliberately not called here: it would single out
  // one node, while this test needs all the nodes to stay equivalent for
  // ComputeOneRound().

  ComputeOneRound(0, &sig_);

  // With all the nodes being the same, every node must end up with the same
  // hash, so the first node's hash serves as the reference value.
  const size_t common_hash = GetHighTopoHash(sig_.nodes[0]);
  for (int idx = 0; idx < 5; ++idx) {
    SigNode* const node = sig_.nodes[idx];

    EXPECT_THAT(GetHighTopoHash(node), Eq(common_hash))
        << " at index " << idx;
    EXPECT_THAT(RefHashIsFinal(node), Eq(true)) << " at index " << idx;
    EXPECT_THAT(RefLastHashedNodes(node), Eq(0x1F)) << " at index " << idx;
    EXPECT_THAT(RefNextHashedNodes(node), Eq(0x1F)) << " at index " << idx;

    // How the set of hashed nodes grows for each node:
    //   Step 0: self.
    //   Step 1: self, previous (-1) and next (+1) node.
    //   Step 2: self, (-1), (-2), (+1), (+2): all 5 nodes in the graph
    //   Step 3: still all 5 nodes in the graph
    // Hence 4 entries in the hash history.
    EXPECT_THAT(RefTopoHash(node), SizeIs(4)) << " at index " << idx;
  }
}
954 
955 // On a linear topology.
TEST_F(SignatureTest, ComputeOneRoundLinear) {
  BuildSigMap(graph_linear_);
  PrepareNodes(&sig_);

  ASSERT_THAT(sig_.nodes, SizeIs(5));

  // FindUniqueHashes() is deliberately not called here: it would single out
  // one node, while this test feeds all the nodes to ComputeOneRound().

  ComputeOneRound(0, &sig_);

  // Collect the length of every node's hash history while checking the
  // per-node final state.
  std::vector<size_t> hash_counts;
  for (int idx = 0; idx < 5; ++idx) {
    SigNode* const node = sig_.nodes[idx];

    EXPECT_THAT(RefHashIsFinal(node), Eq(true)) << " at index " << idx;
    EXPECT_THAT(RefLastHashedNodes(node), Eq(0x1F)) << " at index " << idx;
    EXPECT_THAT(RefNextHashedNodes(node), Eq(0x1F)) << " at index " << idx;

    hash_counts.push_back(RefTopoHash(node).size());
  }

  // For the central node the set of hashed nodes grows like this:
  //   Step 0: self.
  //   Step 1: self, previous (-1) and next (+1) node.
  //   Step 2: self, (-1), (-2), (+1), (+2): all 5 nodes in the graph
  //   Step 3: still all 5 nodes in the graph
  //
  // The nodes one step closer to the ends require one more step. The end nodes
  // require one more step yet. The counts are sorted so that the check does
  // not depend on which index holds which node.
  std::sort(hash_counts.begin(), hash_counts.end());
  EXPECT_THAT(hash_counts, ElementsAre(4, 5, 5, 6, 6));
}
988 
989 // On a linear topology where the central node has been already marked as unique
990 // (yeah, not a very realistic case but tests the situations when the
991 // disconnected subgraphs get created).
TEST_F(SignatureTest, ComputeOneRoundSplitLinear) {
  BuildSigMap(graph_linear_);
  PrepareNodes(&sig_);

  ASSERT_THAT(sig_.nodes, SizeIs(5));

  // The expectations below rely on sig_.nodes following the order imposed by
  // SigNodeMap.

  // Separate the middle node by moving it to the front and declaring its hash
  // final before the round starts.
  std::swap(sig_.nodes[0], sig_.nodes[2]);
  SigNode* const center = sig_.nodes[0];
  ASSERT_THAT(RefNodeMask(center), Eq(0x04));
  ASSERT_THAT(RefLastHashedNodes(center), Eq(0x04));
  ASSERT_THAT(RefNextHashedNodes(center), Eq(0x04));
  RefHashIsFinal(center) = true;

  ComputeOneRound(1, &sig_);

  // The front node's masks should stay unchanged.
  EXPECT_THAT(RefLastHashedNodes(sig_.nodes[0]), Eq(0x04));
  EXPECT_THAT(RefNextHashedNodes(sig_.nodes[0]), Eq(0x04));

  // Every remaining node must be final; collect their hash history lengths.
  std::vector<size_t> hash_counts;
  for (int idx = 1; idx < 5; ++idx) {
    SigNode* const node = sig_.nodes[idx];
    EXPECT_THAT(RefHashIsFinal(node), Eq(true)) << " at index " << idx;
    hash_counts.push_back(RefTopoHash(node).size());
  }

  // The end nodes take 4 steps, closer to the center 3 steps.
  std::sort(hash_counts.begin(), hash_counts.end());
  EXPECT_THAT(hash_counts, ElementsAre(3, 3, 4, 4));

  // Positions 1 and 2 are expected to share one hashed-node mask.
  for (int idx = 1; idx <= 2; ++idx) {
    EXPECT_THAT(RefLastHashedNodes(sig_.nodes[idx]), Eq(0x07));
    EXPECT_THAT(RefNextHashedNodes(sig_.nodes[idx]), Eq(0x07));
  }

  // Positions 3 and 4 are expected to share another one.
  for (int idx = 3; idx <= 4; ++idx) {
    EXPECT_THAT(RefLastHashedNodes(sig_.nodes[idx]), Eq(0x1C));
    EXPECT_THAT(RefNextHashedNodes(sig_.nodes[idx]), Eq(0x1C));
  }
}
1033 
// Compares the ToString() rendering of the signature before and after
// OrderLinks() to check the resulting order of the links.
TEST_F(SignatureTest, OrderLinks) {
  gen_map_.clear();
  sig_.map.clear();
  ASSERT_THAT(GenNode::BuildGraphInMap(graph_for_link_order_, &gen_map_),
              Eq(Status::OK()));

  // The subgraph includes every node found in the generic map.
  Subgraph::Identity all_nodes;
  for (const auto& name_and_node : gen_map_) {
    all_nodes.insert(name_and_node.second.get());
  }
  Subgraph whole_graph(all_nodes);
  whole_graph.ExtractForSignature(&sig_.map);

  // Populate the fake signature by walking the map backwards, so that the
  // ranks get assigned in the backwards order.
  for (auto rit = sig_.map.rbegin(); rit != sig_.map.rend(); ++rit) {
    SigNode* const node = rit->second.get();
    RefUniqueRank(node) = sig_.nodes.size();
    sig_.nodes.emplace_back(node);
  }

  // How it was ordered in the original graph.
  // clang-format off
  EXPECT_THAT(sig_.ToString(), Eq(
    "0:Mul[i0:o0:5][i0:o0:4][i0:o1:4][i0:o2:3][i0:o2:2][i0:o3:2],"
    "1:Mul[i0:o0:5][i0:o0:4][i0:o0:3][i0:o0:2],"
    "2:Const,"
    "3:Const,"
    "4:Const,"
    "5:Const,"
    ));
  // clang-format on

  OrderLinks(&sig_);

  // The same signature with the links put in order.
  // clang-format off
  EXPECT_THAT(sig_.ToString(), Eq(
      "0:Mul[i0:o0:4][i0:o0:5][i0:o1:4][i0:o2:2][i0:o2:3][i0:o3:2],"
      "1:Mul[i0:o0:2][i0:o0:3][i0:o0:4][i0:o0:5],"
      "2:Const,"
      "3:Const,"
      "4:Const,"
      "5:Const,"
      ));
  // clang-format on
}
1080 
// Checks that Compute() reports INVALID_ARGUMENT for a graph that has more
// nodes than Signature::kMaxGraphSize.
TEST_F(SignatureTest, GraphTooBig) {
  // The loop bound is inclusive, producing one node more than the limit.
  GraphDef oversized;
  for (int n = 0; n <= Signature::kMaxGraphSize; ++n) {
    *oversized.add_node() = MakeNodeConst(absl::StrFormat("node%d", n));
  }

  ASSERT_THAT(GenNode::BuildGraphInMap(oversized, &gen_map_),
              Eq(Status::OK()));

  // The subgraph includes every node of the graph.
  Subgraph::Identity all_nodes;
  for (const auto& name_and_node : gen_map_) {
    all_nodes.insert(name_and_node.second.get());
  }
  Subgraph whole_graph(all_nodes);
  whole_graph.ExtractForSignature(&sig_.map);

  const Status expected(error::INVALID_ARGUMENT,
                        "A graph of 65 nodes is too big for signature "
                        "computation, the maximal supported node count is "
                        "64.");
  ASSERT_THAT(sig_.Compute(), Eq(expected));
}
1102 
// Checks the text produced by Signature::ToString() for the circular graph
// once every node has a unique rank assigned.
TEST_F(SignatureTest, ToString) {
  BuildSigMap(graph_circular_onedir_);
  PrepareNodes(&sig_);

  ASSERT_THAT(sig_.nodes, SizeIs(5));

  // Fake the results of the computation: the ranks simply follow the initial
  // order of the nodes, and every hash is declared final.
  for (int rank = 0; rank < 5; ++rank) {
    SigNode* const node = sig_.nodes[rank];
    RefUniqueRank(node) = rank;
    RefHashIsFinal(node) = true;
  }

  // clang-format off
  ASSERT_THAT(sig_.ToString(), Eq(
      "0:Mul[i0:o0:4][i0:o0:4],"
      "1:Mul[i0:o0:0][i0:o0:0],"
      "2:Mul[i0:o0:1][i0:o0:1],"
      "3:Mul[i0:o0:2][i0:o0:2],"
      "4:Mul[i0:o0:3][i0:o0:3],"
      ));
  // clang-format on
}
1127 
1128 // This is a test of the permutation logic itself.
TEST_F(SignatureTest, Permutation) {
  std::vector<size_t> plain_permutation;
  std::vector<size_t> countdown;
  InitPermutation(5, &plain_permutation, &countdown);

  // Every generated permutation is rendered as a string of digits; the set
  // collapses any repeats, so its final size counts the distinct permutations.
  std::set<string> results;

  std::vector<size_t> permutation;
  do {
    BuildPermutation(plain_permutation, countdown, &permutation);
    EXPECT_THAT(permutation, SizeIs(5));

    string p;
    p.reserve(permutation.size());
    // Range-for avoids comparing a signed index against the unsigned size().
    for (const size_t element : permutation) {
      // The elements are presumably 0..4 for a 5-element permutation, so each
      // one maps to a single digit; the cast makes the narrowing from size_t
      // to char explicit.
      p.push_back(static_cast<char>('0' + element));
    }
    LOG(INFO) << "Permutation: " << p;
    results.insert(p);
  } while (CountDown(&countdown));

  // All 5! orderings must have been produced.
  EXPECT_THAT(results, SizeIs(5 * 4 * 3 * 2 * 1));
}
1151 
// Runs the TestGraphEveryWay() helper (defined outside this section of the
// file) on the graph_circular_onedir_ test graph.
TEST_F(SignatureTest, ComputeCircularOneDir) {
  auto& graph = graph_circular_onedir_;
  TestGraphEveryWay(graph);
}
1155 
// Runs the TestGraphEveryWay() helper (defined outside this section of the
// file) on the graph_circular_bidir_ test graph.
TEST_F(SignatureTest, ComputeCircularBiDir) {
  auto& graph = graph_circular_bidir_;
  TestGraphEveryWay(graph);
}
1159 
// Runs the TestGraphEveryWay() helper (defined outside this section of the
// file) on the graph_linear_ test graph.
TEST_F(SignatureTest, ComputeLinear) {
  auto& graph = graph_linear_;
  TestGraphEveryWay(graph);
}
1161 
// Runs the TestGraphEveryWay() helper (defined outside this section of the
// file) on the graph_multi_input_ test graph.
TEST_F(SignatureTest, ComputeMultiInput) {
  auto& graph = graph_multi_input_;
  TestGraphEveryWay(graph);
}
1165 
// Runs the TestGraphEveryWay() helper (defined outside this section of the
// file) on the graph_all_or_none_ test graph.
TEST_F(SignatureTest, ComputeAllOrNone) {
  auto& graph = graph_all_or_none_;
  TestGraphEveryWay(graph);
}
1169 
// Runs the TestGraphEveryWay() helper (defined outside this section of the
// file) on the graph_small_cross_ test graph.
TEST_F(SignatureTest, ComputeCross) {
  auto& graph = graph_small_cross_;
  TestGraphEveryWay(graph);
}
1171 
// Exercises the equality comparison of two signatures: they start out equal,
// then one field at a time gets perturbed on the right-hand side and restored.
TEST_F(SignatureTest, Equals) {
  // Left-hand signature: built from "node1" and "node2" of the bidirectional
  // circular graph.
  GenNodeMap lhs_gen_map;
  ASSERT_THAT(GenNode::BuildGraphInMap(graph_circular_bidir_, &lhs_gen_map),
              Eq(Status::OK()));

  Subgraph::Identity lhs_id;
  lhs_id.insert(lhs_gen_map["node1"].get());
  lhs_id.insert(lhs_gen_map["node2"].get());
  Subgraph lhs_subgraph(lhs_id);

  Signature lhs;
  lhs_subgraph.ExtractForSignature(&lhs.map);
  ASSERT_THAT(lhs.Compute(), Eq(Status::OK()));

  // Right-hand signature: an independent copy built exactly the same way.
  GenNodeMap rhs_gen_map;
  ASSERT_THAT(GenNode::BuildGraphInMap(graph_circular_bidir_, &rhs_gen_map),
              Eq(Status::OK()));

  Subgraph::Identity rhs_id;
  rhs_id.insert(rhs_gen_map["node1"].get());
  rhs_id.insert(rhs_gen_map["node2"].get());
  Subgraph rhs_subgraph(rhs_id);

  Signature rhs;
  rhs_subgraph.ExtractForSignature(&rhs.map);
  ASSERT_THAT(rhs.Compute(), Eq(Status::OK()));

  EXPECT_TRUE(lhs == rhs);

  // A different short hash breaks the equality...
  rhs.sig_short += 1;
  EXPECT_FALSE(lhs == rhs);

  // ...and restoring it brings the equality back.
  rhs.sig_short -= 1;
  EXPECT_TRUE(lhs == rhs);

  // Same for the first element of the full hash.
  rhs.sig_full[0] += 1;
  EXPECT_FALSE(lhs == rhs);

  rhs.sig_full[0] -= 1;
  EXPECT_TRUE(lhs == rhs);

  // Same for the nodes themselves: swap two of them, then swap them back.
  std::swap(rhs.nodes[0], rhs.nodes[1]);
  EXPECT_FALSE(lhs == rhs);

  std::swap(rhs.nodes[0], rhs.nodes[1]);
  EXPECT_TRUE(lhs == rhs);

  // A different number of nodes must compare unequal in both directions.
  SigNode* const duplicate = rhs.nodes[0];
  rhs.nodes.emplace_back(duplicate);
  EXPECT_FALSE(lhs == rhs);
  EXPECT_FALSE(rhs == lhs);
}
1231 
1232 }  // end namespace test
1233 }  // end namespace graph_analyzer
1234 }  // end namespace grappler
1235 }  // end namespace tensorflow
1236