1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/jit/node_matchers.h"
17
18 #include <utility>
19 #include "absl/algorithm/container.h"
20 #include "absl/strings/str_cat.h"
21 #include "absl/strings/str_join.h"
22 #include "absl/strings/str_replace.h"
23 #include "absl/strings/str_split.h"
24 #include "tensorflow/core/framework/attr_value_util.h"
25 #include "tensorflow/core/framework/node_def.pb.h"
26 #include "tensorflow/core/framework/tensor.pb.h"
27
28 namespace tensorflow {
29 namespace testing {
30 namespace matchers {
31 namespace {
32
33 using impl::NodeMatcherProperties;
34 using impl::OutEdge;
35
IndentAllButFirstLine(absl::string_view text)36 string IndentAllButFirstLine(absl::string_view text) {
37 std::vector<std::string> lines = absl::StrSplit(text, '\n');
38 for (int i = 1; i < lines.size(); i++) {
39 lines[i].insert(0, " ");
40 }
41 return absl::StrJoin(lines, "\n");
42 }
43
44 template <typename T>
CompareTensor(const Tensor & actual,const Tensor & expected,::testing::MatchResultListener * listener)45 bool CompareTensor(const Tensor& actual, const Tensor& expected,
46 ::testing::MatchResultListener* listener) {
47 if (actual.NumElements() != expected.NumElements()) {
48 if (listener->IsInterested()) {
49 *listener << "\nwas looking for tensor with " << expected.NumElements()
50 << " elements, found tensor with " << actual.NumElements()
51 << " elements";
52 return false;
53 }
54 }
55
56 for (int64 i = 0, e = actual.NumElements(); i < e; i++) {
57 if (actual.flat<T>()(i) != expected.flat<T>()(i)) {
58 *listener << "\nmismatch in constant tensor at index " << i
59 << " expected = " << expected.flat<T>()(i)
60 << " actual = " << actual.flat<T>()(i);
61 return false;
62 }
63 }
64
65 return true;
66 }
67
MatchAndExplainTensor(const Tensor & tensor,const Tensor & expected_tensor,::testing::MatchResultListener * listener)68 bool MatchAndExplainTensor(const Tensor& tensor, const Tensor& expected_tensor,
69 ::testing::MatchResultListener* listener) {
70 if (tensor.dtype() != expected_tensor.dtype()) {
71 if (listener->IsInterested()) {
72 *listener << "\nexpected tensor of type "
73 << DataType_Name(expected_tensor.dtype())
74 << " but found one of type " << DataType_Name(tensor.dtype());
75 return false;
76 }
77 }
78
79 switch (tensor.dtype()) {
80 case DT_FLOAT:
81 return CompareTensor<float>(tensor, expected_tensor, listener);
82 case DT_DOUBLE:
83 return CompareTensor<double>(tensor, expected_tensor, listener);
84 case DT_INT8:
85 return CompareTensor<int8>(tensor, expected_tensor, listener);
86 case DT_INT16:
87 return CompareTensor<int16>(tensor, expected_tensor, listener);
88 case DT_INT32:
89 return CompareTensor<int32>(tensor, expected_tensor, listener);
90 case DT_INT64:
91 return CompareTensor<int64>(tensor, expected_tensor, listener);
92 case DT_UINT8:
93 return CompareTensor<uint8>(tensor, expected_tensor, listener);
94 case DT_UINT16:
95 return CompareTensor<uint16>(tensor, expected_tensor, listener);
96 case DT_UINT32:
97 return CompareTensor<uint32>(tensor, expected_tensor, listener);
98 case DT_UINT64:
99 return CompareTensor<uint64>(tensor, expected_tensor, listener);
100 default:
101 LOG(FATAL) << "Unsupported dtype " // Crash ok: testonly.
102 << DataType_Name(tensor.dtype());
103 }
104 }
105
106 struct NodeMatcher : public ::testing::MatcherInterface<const Node*> {
MatchAndExplaintensorflow::testing::matchers::__anon857546a50111::NodeMatcher107 bool MatchAndExplain(
108 const Node* node,
109 ::testing::MatchResultListener* listener) const override {
110 if (op && node->type_string() != *op) {
111 if (listener->IsInterested()) {
112 *listener << "\nexpected op " << *op << " but found "
113 << node->type_string();
114 }
115 return false;
116 }
117
118 if (assigned_device && node->assigned_device_name() != *assigned_device) {
119 if (listener->IsInterested()) {
120 *listener << "\nexpected assigned_device " << *assigned_device
121 << " but found \"" << node->assigned_device_name() << "\"";
122 }
123 return false;
124 }
125
126 if (name && node->name() != *name) {
127 if (listener->IsInterested()) {
128 *listener << "\nexpected name " << *name << " but found "
129 << node->name();
130 }
131 return false;
132 }
133
134 if (constant_value) {
135 const TensorProto* proto = nullptr;
136 if (!GetNodeAttr(node->def(), "value", &proto).ok()) {
137 if (listener->IsInterested()) {
138 *listener << "\ncould not find \"value\" attribute in node";
139 }
140 return false;
141 }
142
143 Tensor tensor(proto->dtype());
144 if (!tensor.FromProto(*proto)) {
145 if (listener->IsInterested()) {
146 *listener << "\ncould not convert TensorProto in \"value\" attribute "
147 "to Tensor";
148 }
149 return false;
150 }
151
152 if (!MatchAndExplainTensor(/*tensor=*/tensor,
153 /*expected_tensor=*/*constant_value,
154 listener)) {
155 return false;
156 }
157 }
158
159 if (input_matchers) {
160 if (input_matchers->size() != node->num_inputs()) {
161 if (listener->IsInterested()) {
162 *listener << "\nexpected " << input_matchers->size()
163 << " inputs but node has " << node->num_inputs();
164 }
165 return false;
166 }
167
168 for (int input_idx = 0, e = input_matchers->size(); input_idx < e;
169 input_idx++) {
170 if (!MatchAndExplainInput(node, input_idx, listener)) {
171 return false;
172 }
173 }
174 }
175
176 std::vector<const Node*> control_deps;
177 for (const Edge* e : node->in_edges()) {
178 if (e->IsControlEdge()) {
179 control_deps.push_back(e->src());
180 }
181 }
182
183 ::testing::StringMatchResultListener inner_listener;
184 if (control_dep_set &&
185 !control_dep_set->MatchAndExplain(control_deps, &inner_listener)) {
186 if (listener->IsInterested()) {
187 string explanation = inner_listener.str();
188 if (!explanation.empty()) {
189 explanation = absl::StrCat(", ", explanation, ",");
190 }
191 *listener << "ctrl_deps" << explanation << " does not match expected: ";
192 control_dep_set->DescribeTo(listener->stream());
193 }
194 return false;
195 }
196
197 const AttrValueMap attr_value_map = node->def().attr();
198 for (const auto& attr_kv_pair : attrs) {
199 auto it = attr_value_map.find(attr_kv_pair.first);
200 if (it == attr_value_map.end()) {
201 if (listener->IsInterested()) {
202 *listener << "did not find attribute named \"" << attr_kv_pair.first
203 << "\" in node";
204 }
205 return false;
206 }
207 if (attr_kv_pair.second &&
208 !AreAttrValuesEqual(it->second, *attr_kv_pair.second)) {
209 if (listener->IsInterested()) {
210 *listener << "attribute named " << attr_kv_pair.first
211 << " does not match value; expected: \""
212 << SummarizeAttrValue(*attr_kv_pair.second)
213 << "\", found: \"" << SummarizeAttrValue(it->second)
214 << "\"";
215 }
216 return false;
217 }
218 }
219
220 return true;
221 }
222
DescribeTotensorflow::testing::matchers::__anon857546a50111::NodeMatcher223 void DescribeTo(::std::ostream* os) const override {
224 std::vector<string> predicates;
225
226 if (name) {
227 predicates.push_back(absl::StrCat("name: ", *name));
228 }
229
230 if (op) {
231 predicates.push_back(absl::StrCat("op: ", *op));
232 }
233
234 if (assigned_device) {
235 predicates.push_back(absl::StrCat("assigned device: ", *assigned_device));
236 }
237
238 bool printed_something = !predicates.empty();
239
240 *os << absl::StrJoin(predicates, ", ");
241
242 if (constant_value) {
243 printed_something = true;
244 *os << "constant value: " << constant_value->DebugString();
245 }
246
247 if (input_matchers) {
248 if (!input_matchers->empty()) {
249 printed_something = true;
250 *os << " with " << (input_matchers->size() == 1 ? "only " : "")
251 << "input" << (input_matchers->size() == 1 ? "" : "s") << " ";
252 }
253
254 if (input_matchers->size() == 1) {
255 ::std::stringstream ss;
256 input_matchers->front().DescribeTo(&ss);
257 printed_something = true;
258 *os << "matching " << ss.str();
259 } else {
260 int edge_idx = 0;
261 for (const ::testing::Matcher<OutEdge>& matcher : (*input_matchers)) {
262 *os << "\n [" << edge_idx << "] matching (";
263 ::std::stringstream ss;
264 matcher.DescribeTo(&ss);
265 printed_something = true;
266 *os << IndentAllButFirstLine(ss.str());
267 *os << ")";
268 edge_idx++;
269 }
270 }
271 }
272
273 if (control_dep_set) {
274 printed_something = true;
275 *os << " and control deps ";
276 control_dep_set->DescribeTo(os);
277 }
278
279 if (!attrs.empty()) {
280 printed_something = true;
281 std::vector<string> attrs_str;
282 absl::c_transform(
283 attrs, std::back_inserter(attrs_str),
284 [](const std::pair<string, absl::optional<AttrValue>>& attr_kv_pair) {
285 return absl::StrCat(attr_kv_pair.first, "->",
286 attr_kv_pair.second
287 ? SummarizeAttrValue(*attr_kv_pair.second)
288 : "*");
289 });
290 *os << " and attr values matching [" << absl::StrJoin(attrs_str, ", ")
291 << "]";
292 }
293
294 if (!printed_something) {
295 *os << "is any node";
296 }
297 }
298
MatchAndExplainInputtensorflow::testing::matchers::__anon857546a50111::NodeMatcher299 bool MatchAndExplainInput(const Node* node, int input_idx,
300 ::testing::MatchResultListener* listener) const {
301 const Edge* edge;
302 if (!node->input_edge(input_idx, &edge).ok()) {
303 if (listener->IsInterested()) {
304 *listener << "\ncould not find incoming edge for input " << input_idx;
305 }
306 return false;
307 }
308
309 ::testing::StringMatchResultListener inner_listener;
310 OutEdge input = {edge->src(), edge->src_output()};
311 if ((*input_matchers)[input_idx].MatchAndExplain(input, &inner_listener)) {
312 return true;
313 }
314
315 if (listener->IsInterested()) {
316 *listener << "\ninput " << input_idx << " does not match expected:\n";
317 (*input_matchers)[input_idx].DescribeTo(listener->stream());
318 string explanation = inner_listener.str();
319 if (!explanation.empty()) {
320 *listener << ", " << explanation;
321 }
322 }
323 return false;
324 }
325
326 absl::optional<string> op;
327 absl::optional<string> name;
328 absl::optional<string> assigned_device;
329 absl::optional<Tensor> constant_value;
330 absl::optional<std::vector<::testing::Matcher<OutEdge>>> input_matchers;
331 absl::optional<::testing::Matcher<absl::Span<const Node* const>>>
332 control_dep_set;
333 std::map<string, absl::optional<AttrValue>> attrs;
334 };
335
336 // Matches a dst and dst_output on an input edge. Today we only use this with
337 // dst_output=0 but we will eventually need to support multi-output operations.
338 class OutEdgeMatcher : public ::testing::MatcherInterface<OutEdge> {
339 public:
OutEdgeMatcher(::testing::Matcher<const Node * > src_matcher,int src_oidx)340 OutEdgeMatcher(::testing::Matcher<const Node*> src_matcher, int src_oidx)
341 : src_matcher_(std::move(src_matcher)), src_oidx_(src_oidx) {}
342
MatchAndExplain(OutEdge out_edge,::testing::MatchResultListener * listener) const343 bool MatchAndExplain(
344 OutEdge out_edge,
345 ::testing::MatchResultListener* listener) const override {
346 ::testing::StringMatchResultListener inner_listener;
347 if (!src_matcher_.MatchAndExplain(out_edge.first, &inner_listener)) {
348 if (listener->IsInterested()) {
349 *listener << "\nsource does not match expected ";
350 src_matcher_.DescribeTo(listener->stream());
351 string explanation = inner_listener.str();
352 if (!explanation.empty()) {
353 *listener << "\n\t" << explanation;
354 }
355 }
356 return false;
357 }
358 if (out_edge.second != src_oidx_) {
359 if (listener->IsInterested()) {
360 *listener << "\nexpected output slot to be " << src_oidx_
361 << " but found " << out_edge.second;
362 }
363 return false;
364 }
365
366 return true;
367 }
368
DescribeTo(::std::ostream * os) const369 void DescribeTo(::std::ostream* os) const override {
370 if (src_oidx_) {
371 *os << "output slot: " << src_oidx_ << ", source: (";
372 }
373
374 src_matcher_.DescribeTo(os);
375
376 if (src_oidx_) {
377 *os << ")";
378 }
379 }
380
381 private:
382 ::testing::Matcher<const Node*> src_matcher_;
383 int src_oidx_;
384 };
385 } // namespace
386
NodeWith(absl::Span<const NodeMatcherProperties> props)387 ::testing::Matcher<const Node*> impl::NodeWith(
388 absl::Span<const NodeMatcherProperties> props) {
389 NodeMatcher* matcher = new NodeMatcher();
390 for (const NodeMatcherProperties& prop : props) {
391 if (prop.name()) {
392 DCHECK(!matcher->name);
393 matcher->name = prop.name();
394 }
395
396 if (prop.op()) {
397 DCHECK(!matcher->op);
398 matcher->op = prop.op();
399 }
400
401 if (prop.constant_value()) {
402 DCHECK(!matcher->constant_value);
403 matcher->constant_value = prop.constant_value();
404 }
405
406 if (prop.assigned_device()) {
407 DCHECK(!matcher->assigned_device);
408 matcher->assigned_device = prop.assigned_device();
409 }
410
411 if (prop.inputs()) {
412 DCHECK(!matcher->input_matchers);
413 matcher->input_matchers = *prop.inputs();
414 }
415
416 if (prop.control_deps()) {
417 DCHECK(!matcher->control_dep_set);
418 matcher->control_dep_set =
419 ::testing::UnorderedElementsAreArray(*prop.control_deps());
420 }
421
422 if (prop.attr()) {
423 auto insert_result = matcher->attrs.insert(*prop.attr());
424 DCHECK(insert_result.second);
425 }
426 }
427
428 return ::testing::MakeMatcher(matcher);
429 }
430
Name(string name)431 impl::NodeMatcherProperties Name(string name) {
432 impl::NodeMatcherProperties props;
433 props.set_name(std::move(name));
434 return props;
435 }
436
437 // Matches a node with op `op`.
Op(string op)438 impl::NodeMatcherProperties Op(string op) {
439 impl::NodeMatcherProperties props;
440 props.set_op(std::move(op));
441 return props;
442 }
443
444 // Matches a node with assigned device `assigned_device`.
AssignedDevice(string assigned_device)445 impl::NodeMatcherProperties AssignedDevice(string assigned_device) {
446 impl::NodeMatcherProperties props;
447 props.set_assigned_device(std::move(assigned_device));
448 return props;
449 }
450
Inputs(absl::Span<const::testing::Matcher<OutEdge>> inputs)451 impl::NodeMatcherProperties impl::Inputs(
452 absl::Span<const ::testing::Matcher<OutEdge>> inputs) {
453 std::vector<::testing::Matcher<OutEdge>> inputs_vector;
454 absl::c_copy(inputs, std::back_inserter(inputs_vector));
455
456 impl::NodeMatcherProperties props;
457 props.set_inputs(std::move(inputs_vector));
458 return props;
459 }
460
CtrlDeps(absl::Span<const::testing::Matcher<const Node * >> control_deps)461 impl::NodeMatcherProperties impl::CtrlDeps(
462 absl::Span<const ::testing::Matcher<const Node*>> control_deps) {
463 std::vector<::testing::Matcher<const Node*>> control_deps_vector;
464 absl::c_copy(control_deps, std::back_inserter(control_deps_vector));
465
466 impl::NodeMatcherProperties props;
467 props.set_control_deps(std::move(control_deps_vector));
468 return props;
469 }
470
AttrLiteralHelper(const std::pair<string,bool> & bool_attr)471 std::pair<string, AttrValue> impl::AttrLiteralHelper(
472 const std::pair<string, bool>& bool_attr) {
473 AttrValue attr_value;
474 attr_value.set_b(bool_attr.second);
475 return {bool_attr.first, attr_value};
476 }
477
AttrLiteralHelper(const std::pair<string,absl::Span<const int>> & int_list_attr)478 std::pair<string, AttrValue> impl::AttrLiteralHelper(
479 const std::pair<string, absl::Span<const int>>& int_list_attr) {
480 AttrValue attr_value;
481 AttrValue::ListValue* list = attr_value.mutable_list();
482 for (int i : int_list_attr.second) {
483 list->add_i(i);
484 }
485 return {int_list_attr.first, attr_value};
486 }
487
AttrLiteralHelper(const std::pair<string,absl::Span<const string>> & string_list_attr)488 std::pair<string, AttrValue> impl::AttrLiteralHelper(
489 const std::pair<string, absl::Span<const string>>& string_list_attr) {
490 AttrValue attr_value;
491 AttrValue::ListValue* list = attr_value.mutable_list();
492 for (string s : string_list_attr.second) {
493 list->add_s(s);
494 }
495 return {string_list_attr.first, attr_value};
496 }
497
Attr(std::pair<string,AttrValue> attr)498 impl::NodeMatcherProperties impl::Attr(std::pair<string, AttrValue> attr) {
499 impl::NodeMatcherProperties props;
500 props.set_attr(std::move(attr));
501 return props;
502 }
503
Attr(string name)504 impl::NodeMatcherProperties impl::Attr(string name) {
505 impl::NodeMatcherProperties props;
506 props.set_attr({std::move(name), absl::nullopt});
507 return props;
508 }
509
ConstantValue(const::tensorflow::Input::Initializer & val)510 NodeMatcherProperties ConstantValue(
511 const ::tensorflow::Input::Initializer& val) {
512 TF_CHECK_OK(val.status);
513 NodeMatcherProperties props;
514 props.set_constant_value(val.tensor);
515 return props;
516 }
517
Const(const::tensorflow::Input::Initializer & val)518 ::testing::Matcher<impl::OutEdge> Const(
519 const ::tensorflow::Input::Initializer& val) {
520 return Out(NodeWith(ConstantValue(val)));
521 }
Out(int oidx,::testing::Matcher<const Node * > node_matcher)522 ::testing::Matcher<impl::OutEdge> Out(
523 int oidx, ::testing::Matcher<const Node*> node_matcher) {
524 return ::testing::MakeMatcher(new OutEdgeMatcher(node_matcher, oidx));
525 }
526 } // namespace matchers
527
FindNodeByName(Graph * g,absl::string_view name)528 Node* FindNodeByName(Graph* g, absl::string_view name) {
529 for (Node* n : g->nodes()) {
530 if (n->name() == name) {
531 return n;
532 }
533 }
534
535 return nullptr;
536 }
537 } // namespace testing
538
PrintTo(const Node * n,::std::ostream * os)539 void PrintTo(const Node* n, ::std::ostream* os) { *os << SummarizeNode(*n); }
PrintTo(Node * n,::std::ostream * os)540 void PrintTo(Node* n, ::std::ostream* os) { *os << SummarizeNode(*n); }
541 } // namespace tensorflow
542