1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/core/util/equal_graph_def.h"

#include <map>
#include <set>
#include <unordered_map>
#include <unordered_set>

#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/protobuf.h"
29 
30 namespace tensorflow {
31 
EqualGraphDef(const GraphDef & actual,const GraphDef & expected,string * diff,const EqualGraphDefOptions & options)32 bool EqualGraphDef(const GraphDef& actual, const GraphDef& expected,
33                    string* diff, const EqualGraphDefOptions& options) {
34   // Intentionally do not check that versions match so that this routine can
35   // be used for less brittle golden file tests.
36   return EqualRepeatedNodeDef(actual.node(), expected.node(), diff, options);
37 }
38 
GraphDefHash(const GraphDef & gdef,const EqualGraphDefOptions & options)39 uint64 GraphDefHash(const GraphDef& gdef, const EqualGraphDefOptions& options) {
40   return RepeatedNodeDefHash(gdef.node(), options);
41 }
42 
EqualRepeatedNodeDef(const protobuf::RepeatedPtrField<NodeDef> & actual,const protobuf::RepeatedPtrField<NodeDef> & expected,string * diff,const EqualGraphDefOptions & options)43 bool EqualRepeatedNodeDef(const protobuf::RepeatedPtrField<NodeDef>& actual,
44                           const protobuf::RepeatedPtrField<NodeDef>& expected,
45                           string* diff, const EqualGraphDefOptions& options) {
46   std::unordered_map<string, const NodeDef*> actual_index;
47   for (const NodeDef& node : actual) {
48     actual_index[node.name()] = &node;
49   }
50 
51   for (const NodeDef& expected_node : expected) {
52     auto actual_iter = actual_index.find(expected_node.name());
53     if (actual_iter == actual_index.end()) {
54       if (diff != nullptr) {
55         *diff = strings::StrCat("Did not find expected node '",
56                                 SummarizeNodeDef(expected_node), "'");
57       }
58       return false;
59     }
60 
61     if (!EqualNodeDef(*actual_iter->second, expected_node, diff, options)) {
62       return false;
63     }
64 
65     actual_index.erase(actual_iter);
66   }
67 
68   if (!actual_index.empty()) {
69     if (diff != nullptr) {
70       *diff =
71           strings::StrCat("Found unexpected node '",
72                           SummarizeNodeDef(*actual_index.begin()->second), "'");
73     }
74     return false;
75   }
76 
77   return true;
78 }
79 
RepeatedNodeDefHash(const protobuf::RepeatedPtrField<NodeDef> & ndefs,const EqualGraphDefOptions & options)80 uint64 RepeatedNodeDefHash(const protobuf::RepeatedPtrField<NodeDef>& ndefs,
81                            const EqualGraphDefOptions& options) {
82   uint64 h = 0xDECAFCAFFE;
83   // Insert NodeDefs into map to deterministically sort by name
84   std::map<string, const NodeDef*> nodes;
85   for (const NodeDef& node : ndefs) {
86     nodes[node.name()] = &node;
87   }
88   for (const auto& pair : nodes) {
89     h = Hash64(pair.first.data(), pair.first.size(), h);
90     h = Hash64Combine(NodeDefHash(*pair.second, options), h);
91   }
92   return h;
93 }
94 
95 namespace {
96 
JoinStringField(const protobuf::RepeatedPtrField<string> & f)97 string JoinStringField(const protobuf::RepeatedPtrField<string>& f) {
98   string ret;
99   for (int i = 0; i < f.size(); ++i) {
100     if (i > 0) strings::StrAppend(&ret, ", ");
101     strings::StrAppend(&ret, f.Get(i));
102   }
103   return ret;
104 }
105 
106 }  // namespace
107 
EqualNodeDef(const NodeDef & actual,const NodeDef & expected,string * diff,const EqualGraphDefOptions & options)108 bool EqualNodeDef(const NodeDef& actual, const NodeDef& expected, string* diff,
109                   const EqualGraphDefOptions& options) {
110   if (actual.name() != expected.name()) {
111     if (diff != nullptr) {
112       *diff = strings::StrCat("Actual node name '", actual.name(),
113                               "' is not expected '", expected.name(), "'");
114     }
115     return false;
116   }
117 
118   if (actual.op() != expected.op()) {
119     if (diff != nullptr) {
120       *diff = strings::StrCat("Node named '", actual.name(), "' has op '",
121                               actual.op(), "' that is not expected '",
122                               expected.op(), "'");
123     }
124     return false;
125   }
126 
127   if (actual.device() != expected.device()) {
128     if (diff != nullptr) {
129       *diff = strings::StrCat("Node named '", actual.name(), "' has device '",
130                               actual.device(), "' that is not expected '",
131                               expected.device(), "'");
132     }
133     return false;
134   }
135 
136   if (actual.input_size() != expected.input_size()) {
137     if (diff != nullptr) {
138       *diff = strings::StrCat("Node named '", actual.name(), "' has inputs '",
139                               JoinStringField(actual.input()),
140                               "' that don't match expected '",
141                               JoinStringField(expected.input()), "'");
142     }
143     return false;
144   }
145 
146   int first_control_input = actual.input_size();
147   for (int i = 0; i < actual.input_size(); ++i) {
148     if (str_util::StartsWith(actual.input(i), "^")) {
149       first_control_input = i;
150       break;
151     }
152     // Special case for inputs: "tensor" is equivalent to "tensor:0"
153     if (actual.input(i) != expected.input(i) &&
154         actual.input(i) != strings::StrCat(expected.input(i), ":0") &&
155         strings::StrCat(actual.input(i), ":0") != expected.input(i)) {
156       if (diff != nullptr) {
157         *diff = strings::StrCat("Node named '", actual.name(), "' has input ",
158                                 i, " '", actual.input(i),
159                                 "' that doesn't match expected '",
160                                 expected.input(i), "'");
161       }
162       return false;
163     }
164   }
165 
166   std::unordered_set<string> actual_control;
167   std::unordered_set<string> expected_control;
168   for (int i = first_control_input; i < actual.input_size(); ++i) {
169     actual_control.insert(actual.input(i));
170     expected_control.insert(expected.input(i));
171   }
172   for (const auto& e : expected_control) {
173     if (actual_control.erase(e) == 0) {
174       if (diff != nullptr) {
175         *diff = strings::StrCat("Node named '", actual.name(),
176                                 "' missing expected control input '", e, "'");
177       }
178       return false;
179     }
180   }
181   if (!actual_control.empty()) {
182     if (diff != nullptr) {
183       *diff = strings::StrCat("Node named '", actual.name(),
184                               "' has unexpected control input '",
185                               *actual_control.begin(), "'");
186     }
187     return false;
188   }
189 
190   std::unordered_set<string> actual_attr;
191   for (const auto& a : actual.attr()) {
192     if (options.ignore_internal_attrs && !a.first.empty() &&
193         a.first[0] == '_') {
194       continue;
195     }
196     actual_attr.insert(a.first);
197   }
198   for (const auto& e : expected.attr()) {
199     if (options.ignore_internal_attrs && !e.first.empty() &&
200         e.first[0] == '_') {
201       continue;
202     }
203 
204     if (actual_attr.erase(e.first) == 0) {
205       if (diff != nullptr) {
206         *diff = strings::StrCat("Node named '", actual.name(),
207                                 "' missing expected attr '", e.first,
208                                 "' with value: ", SummarizeAttrValue(e.second));
209       }
210       return false;
211     }
212     auto iter = actual.attr().find(e.first);
213     if (!AreAttrValuesEqual(e.second, iter->second)) {
214       if (diff != nullptr) {
215         *diff = strings::StrCat(
216             "Node named '", actual.name(), "' has attr '", e.first,
217             "' with value: ", SummarizeAttrValue(iter->second),
218             " that does not match expected: ", SummarizeAttrValue(e.second));
219       }
220       return false;
221     }
222   }
223   if (!actual_attr.empty()) {
224     if (diff != nullptr) {
225       *diff = strings::StrCat(
226           "Node named '", actual.name(), "' has unexpected attr '",
227           *actual_attr.begin(), "' with value: ",
228           SummarizeAttrValue(actual.attr().find(*actual_attr.begin())->second));
229     }
230     return false;
231   }
232 
233   return true;
234 }
235 
NodeDefHash(const NodeDef & ndef,const EqualGraphDefOptions & options)236 uint64 NodeDefHash(const NodeDef& ndef, const EqualGraphDefOptions& options) {
237   uint64 h = Hash64(ndef.name());
238   h = Hash64(ndef.op().data(), ndef.op().size(), h);
239   h = Hash64(ndef.device().data(), ndef.device().size(), h);
240 
241   // Normal inputs. Order important.
242   int first_control_input = ndef.input_size();
243   for (int i = 0; i < ndef.input_size(); ++i) {
244     if (str_util::StartsWith(ndef.input(i), "^")) {
245       first_control_input = i;
246       break;
247     }
248     h = Hash64(ndef.input(i).data(), ndef.input(i).size(), h);
249   }
250 
251   // Control inputs. Order irrelevant.
252   std::set<string> ndef_control;
253   for (int i = first_control_input; i < ndef.input_size(); ++i) {
254     ndef_control.insert(ndef.input(i));
255   }
256   for (const string& s : ndef_control) {
257     h = Hash64(s.data(), s.size(), h);
258   }
259 
260   // Attributes
261   std::map<string, AttrValue> ndef_attr;
262   for (const auto& a : ndef.attr()) {
263     if (options.ignore_internal_attrs && !a.first.empty() &&
264         a.first[0] == '_') {
265       continue;
266     }
267     ndef_attr[a.first] = a.second;
268   }
269   for (const auto& a : ndef_attr) {
270     h = Hash64(a.first.data(), a.first.size(), h);
271     h = Hash64Combine(AttrValueHash(a.second), h);
272   }
273 
274   return h;
275 }
276 
277 }  // namespace tensorflow
278