1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/core/util/equal_graph_def.h"

#include <map>
#include <set>
#include <unordered_map>
#include <unordered_set>

#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/protobuf.h"
28
29 namespace tensorflow {
30
EqualGraphDef(const GraphDef & actual,const GraphDef & expected,string * diff,const EqualGraphDefOptions & options)31 bool EqualGraphDef(const GraphDef& actual, const GraphDef& expected,
32 string* diff, const EqualGraphDefOptions& options) {
33 // Intentionally do not check that versions match so that this routine can
34 // be used for less brittle golden file tests.
35 return EqualRepeatedNodeDef(actual.node(), expected.node(), diff, options);
36 }
37
GraphDefHash(const GraphDef & gdef,const EqualGraphDefOptions & options)38 uint64 GraphDefHash(const GraphDef& gdef, const EqualGraphDefOptions& options) {
39 return RepeatedNodeDefHash(gdef.node(), options);
40 }
41
EqualRepeatedNodeDef(const protobuf::RepeatedPtrField<NodeDef> & actual,const protobuf::RepeatedPtrField<NodeDef> & expected,string * diff,const EqualGraphDefOptions & options)42 bool EqualRepeatedNodeDef(const protobuf::RepeatedPtrField<NodeDef>& actual,
43 const protobuf::RepeatedPtrField<NodeDef>& expected,
44 string* diff, const EqualGraphDefOptions& options) {
45 std::unordered_map<string, const NodeDef*> actual_index;
46 for (const NodeDef& node : actual) {
47 actual_index[node.name()] = &node;
48 }
49
50 for (const NodeDef& expected_node : expected) {
51 auto actual_iter = actual_index.find(expected_node.name());
52 if (actual_iter == actual_index.end()) {
53 if (diff != nullptr) {
54 *diff = strings::StrCat("Did not find expected node '",
55 SummarizeNodeDef(expected_node), "'");
56 }
57 return false;
58 }
59
60 if (!EqualNodeDef(*actual_iter->second, expected_node, diff, options)) {
61 return false;
62 }
63
64 actual_index.erase(actual_iter);
65 }
66
67 if (!actual_index.empty()) {
68 if (diff != nullptr) {
69 *diff =
70 strings::StrCat("Found unexpected node '",
71 SummarizeNodeDef(*actual_index.begin()->second), "'");
72 }
73 return false;
74 }
75
76 return true;
77 }
78
RepeatedNodeDefHash(const protobuf::RepeatedPtrField<NodeDef> & ndefs,const EqualGraphDefOptions & options)79 uint64 RepeatedNodeDefHash(const protobuf::RepeatedPtrField<NodeDef>& ndefs,
80 const EqualGraphDefOptions& options) {
81 uint64 h = 0xDECAFCAFFE;
82 // Insert NodeDefs into map to deterministically sort by name
83 std::map<string, const NodeDef*> nodes;
84 for (const NodeDef& node : ndefs) {
85 nodes[node.name()] = &node;
86 }
87 for (const auto& pair : nodes) {
88 h = Hash64(pair.first.data(), pair.first.size(), h);
89 h = Hash64Combine(NodeDefHash(*pair.second, options), h);
90 }
91 return h;
92 }
93
94 namespace {
95
JoinStringField(const protobuf::RepeatedPtrField<string> & f)96 string JoinStringField(const protobuf::RepeatedPtrField<string>& f) {
97 string ret;
98 for (int i = 0; i < f.size(); ++i) {
99 if (i > 0) strings::StrAppend(&ret, ", ");
100 strings::StrAppend(&ret, f.Get(i));
101 }
102 return ret;
103 }
104
105 } // namespace
106
EqualNodeDef(const NodeDef & actual,const NodeDef & expected,string * diff,const EqualGraphDefOptions & options)107 bool EqualNodeDef(const NodeDef& actual, const NodeDef& expected, string* diff,
108 const EqualGraphDefOptions& options) {
109 if (actual.name() != expected.name()) {
110 if (diff != nullptr) {
111 *diff = strings::StrCat("Actual node name '", actual.name(),
112 "' is not expected '", expected.name(), "'");
113 }
114 return false;
115 }
116
117 if (actual.op() != expected.op()) {
118 if (diff != nullptr) {
119 *diff = strings::StrCat("Node named '", actual.name(), "' has op '",
120 actual.op(), "' that is not expected '",
121 expected.op(), "'");
122 }
123 return false;
124 }
125
126 if (actual.device() != expected.device()) {
127 if (diff != nullptr) {
128 *diff = strings::StrCat("Node named '", actual.name(), "' has device '",
129 actual.device(), "' that is not expected '",
130 expected.device(), "'");
131 }
132 return false;
133 }
134
135 if (actual.input_size() != expected.input_size()) {
136 if (diff != nullptr) {
137 *diff = strings::StrCat("Node named '", actual.name(), "' has inputs '",
138 JoinStringField(actual.input()),
139 "' that don't match expected '",
140 JoinStringField(expected.input()), "'");
141 }
142 return false;
143 }
144
145 int first_control_input = actual.input_size();
146 for (int i = 0; i < actual.input_size(); ++i) {
147 if (StringPiece(actual.input(i)).starts_with("^")) {
148 first_control_input = i;
149 break;
150 }
151 // Special case for inputs: "tensor" is equivalent to "tensor:0"
152 if (actual.input(i) != expected.input(i) &&
153 actual.input(i) != strings::StrCat(expected.input(i), ":0") &&
154 strings::StrCat(actual.input(i), ":0") != expected.input(i)) {
155 if (diff != nullptr) {
156 *diff = strings::StrCat("Node named '", actual.name(), "' has input ",
157 i, " '", actual.input(i),
158 "' that doesn't match expected '",
159 expected.input(i), "'");
160 }
161 return false;
162 }
163 }
164
165 std::unordered_set<string> actual_control;
166 std::unordered_set<string> expected_control;
167 for (int i = first_control_input; i < actual.input_size(); ++i) {
168 actual_control.insert(actual.input(i));
169 expected_control.insert(expected.input(i));
170 }
171 for (const auto& e : expected_control) {
172 if (actual_control.erase(e) == 0) {
173 if (diff != nullptr) {
174 *diff = strings::StrCat("Node named '", actual.name(),
175 "' missing expected control input '", e, "'");
176 }
177 return false;
178 }
179 }
180 if (!actual_control.empty()) {
181 if (diff != nullptr) {
182 *diff = strings::StrCat("Node named '", actual.name(),
183 "' has unexpected control input '",
184 *actual_control.begin(), "'");
185 }
186 return false;
187 }
188
189 std::unordered_set<string> actual_attr;
190 for (const auto& a : actual.attr()) {
191 if (options.ignore_internal_attrs && !a.first.empty() &&
192 a.first[0] == '_') {
193 continue;
194 }
195 actual_attr.insert(a.first);
196 }
197 for (const auto& e : expected.attr()) {
198 if (options.ignore_internal_attrs && !e.first.empty() &&
199 e.first[0] == '_') {
200 continue;
201 }
202
203 if (actual_attr.erase(e.first) == 0) {
204 if (diff != nullptr) {
205 *diff = strings::StrCat("Node named '", actual.name(),
206 "' missing expected attr '", e.first,
207 "' with value: ", SummarizeAttrValue(e.second));
208 }
209 return false;
210 }
211 auto iter = actual.attr().find(e.first);
212 if (!AreAttrValuesEqual(e.second, iter->second)) {
213 if (diff != nullptr) {
214 *diff = strings::StrCat(
215 "Node named '", actual.name(), "' has attr '", e.first,
216 "' with value: ", SummarizeAttrValue(iter->second),
217 " that does not match expected: ", SummarizeAttrValue(e.second));
218 }
219 return false;
220 }
221 }
222 if (!actual_attr.empty()) {
223 if (diff != nullptr) {
224 *diff = strings::StrCat(
225 "Node named '", actual.name(), "' has unexpected attr '",
226 *actual_attr.begin(), "' with value: ",
227 SummarizeAttrValue(actual.attr().find(*actual_attr.begin())->second));
228 }
229 return false;
230 }
231
232 return true;
233 }
234
NodeDefHash(const NodeDef & ndef,const EqualGraphDefOptions & options)235 uint64 NodeDefHash(const NodeDef& ndef, const EqualGraphDefOptions& options) {
236 uint64 h = Hash64(ndef.name());
237 h = Hash64(ndef.op().data(), ndef.op().size(), h);
238 h = Hash64(ndef.device().data(), ndef.device().size(), h);
239
240 // Normal inputs. Order important.
241 int first_control_input = ndef.input_size();
242 for (int i = 0; i < ndef.input_size(); ++i) {
243 if (StringPiece(ndef.input(i)).starts_with("^")) {
244 first_control_input = i;
245 break;
246 }
247 h = Hash64(ndef.input(i).data(), ndef.input(i).size(), h);
248 }
249
250 // Control inputs. Order irrelevant.
251 std::set<string> ndef_control;
252 for (int i = first_control_input; i < ndef.input_size(); ++i) {
253 ndef_control.insert(ndef.input(i));
254 }
255 for (const string& s : ndef_control) {
256 h = Hash64(s.data(), s.size(), h);
257 }
258
259 // Attributes
260 std::map<string, AttrValue> ndef_attr;
261 for (const auto& a : ndef.attr()) {
262 if (options.ignore_internal_attrs && !a.first.empty() &&
263 a.first[0] == '_') {
264 continue;
265 }
266 ndef_attr[a.first] = a.second;
267 }
268 for (const auto& a : ndef_attr) {
269 h = Hash64(a.first.data(), a.first.size(), h);
270 h = Hash64Combine(AttrValueHash(a.second), h);
271 }
272
273 return h;
274 }
275
276 } // namespace tensorflow
277