1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/jit/node_matchers.h"
17 
18 #include <utility>
19 #include "absl/algorithm/container.h"
20 #include "absl/strings/str_cat.h"
21 #include "absl/strings/str_join.h"
22 #include "absl/strings/str_replace.h"
23 #include "absl/strings/str_split.h"
24 #include "tensorflow/core/framework/attr_value_util.h"
25 #include "tensorflow/core/framework/node_def.pb.h"
26 #include "tensorflow/core/framework/tensor.pb.h"
27 
28 namespace tensorflow {
29 namespace testing {
30 namespace matchers {
31 namespace {
32 
33 using impl::NodeMatcherProperties;
34 using impl::OutEdge;
35 
IndentAllButFirstLine(absl::string_view text)36 string IndentAllButFirstLine(absl::string_view text) {
37   std::vector<std::string> lines = absl::StrSplit(text, '\n');
38   for (int i = 1; i < lines.size(); i++) {
39     lines[i].insert(0, "  ");
40   }
41   return absl::StrJoin(lines, "\n");
42 }
43 
44 template <typename T>
CompareTensor(const Tensor & actual,const Tensor & expected,::testing::MatchResultListener * listener)45 bool CompareTensor(const Tensor& actual, const Tensor& expected,
46                    ::testing::MatchResultListener* listener) {
47   if (actual.NumElements() != expected.NumElements()) {
48     if (listener->IsInterested()) {
49       *listener << "\nwas looking for tensor with " << expected.NumElements()
50                 << " elements, found tensor with " << actual.NumElements()
51                 << " elements";
52       return false;
53     }
54   }
55 
56   for (int64 i = 0, e = actual.NumElements(); i < e; i++) {
57     if (actual.flat<T>()(i) != expected.flat<T>()(i)) {
58       *listener << "\nmismatch in constant tensor at index " << i
59                 << " expected = " << expected.flat<T>()(i)
60                 << " actual = " << actual.flat<T>()(i);
61       return false;
62     }
63   }
64 
65   return true;
66 }
67 
MatchAndExplainTensor(const Tensor & tensor,const Tensor & expected_tensor,::testing::MatchResultListener * listener)68 bool MatchAndExplainTensor(const Tensor& tensor, const Tensor& expected_tensor,
69                            ::testing::MatchResultListener* listener) {
70   if (tensor.dtype() != expected_tensor.dtype()) {
71     if (listener->IsInterested()) {
72       *listener << "\nexpected tensor of type "
73                 << DataType_Name(expected_tensor.dtype())
74                 << " but found one of type " << DataType_Name(tensor.dtype());
75       return false;
76     }
77   }
78 
79   switch (tensor.dtype()) {
80     case DT_FLOAT:
81       return CompareTensor<float>(tensor, expected_tensor, listener);
82     case DT_DOUBLE:
83       return CompareTensor<double>(tensor, expected_tensor, listener);
84     case DT_INT8:
85       return CompareTensor<int8>(tensor, expected_tensor, listener);
86     case DT_INT16:
87       return CompareTensor<int16>(tensor, expected_tensor, listener);
88     case DT_INT32:
89       return CompareTensor<int32>(tensor, expected_tensor, listener);
90     case DT_INT64:
91       return CompareTensor<int64>(tensor, expected_tensor, listener);
92     case DT_UINT8:
93       return CompareTensor<uint8>(tensor, expected_tensor, listener);
94     case DT_UINT16:
95       return CompareTensor<uint16>(tensor, expected_tensor, listener);
96     case DT_UINT32:
97       return CompareTensor<uint32>(tensor, expected_tensor, listener);
98     case DT_UINT64:
99       return CompareTensor<uint64>(tensor, expected_tensor, listener);
100     default:
101       LOG(FATAL) << "Unsupported dtype "  // Crash ok: testonly.
102                  << DataType_Name(tensor.dtype());
103   }
104 }
105 
106 struct NodeMatcher : public ::testing::MatcherInterface<const Node*> {
MatchAndExplaintensorflow::testing::matchers::__anon857546a50111::NodeMatcher107   bool MatchAndExplain(
108       const Node* node,
109       ::testing::MatchResultListener* listener) const override {
110     if (op && node->type_string() != *op) {
111       if (listener->IsInterested()) {
112         *listener << "\nexpected op " << *op << " but found "
113                   << node->type_string();
114       }
115       return false;
116     }
117 
118     if (assigned_device && node->assigned_device_name() != *assigned_device) {
119       if (listener->IsInterested()) {
120         *listener << "\nexpected assigned_device " << *assigned_device
121                   << " but found \"" << node->assigned_device_name() << "\"";
122       }
123       return false;
124     }
125 
126     if (name && node->name() != *name) {
127       if (listener->IsInterested()) {
128         *listener << "\nexpected name " << *name << " but found "
129                   << node->name();
130       }
131       return false;
132     }
133 
134     if (constant_value) {
135       const TensorProto* proto = nullptr;
136       if (!GetNodeAttr(node->def(), "value", &proto).ok()) {
137         if (listener->IsInterested()) {
138           *listener << "\ncould not find \"value\" attribute in node";
139         }
140         return false;
141       }
142 
143       Tensor tensor(proto->dtype());
144       if (!tensor.FromProto(*proto)) {
145         if (listener->IsInterested()) {
146           *listener << "\ncould not convert TensorProto in \"value\" attribute "
147                        "to Tensor";
148         }
149         return false;
150       }
151 
152       if (!MatchAndExplainTensor(/*tensor=*/tensor,
153                                  /*expected_tensor=*/*constant_value,
154                                  listener)) {
155         return false;
156       }
157     }
158 
159     if (input_matchers) {
160       if (input_matchers->size() != node->num_inputs()) {
161         if (listener->IsInterested()) {
162           *listener << "\nexpected " << input_matchers->size()
163                     << " inputs but node has " << node->num_inputs();
164         }
165         return false;
166       }
167 
168       for (int input_idx = 0, e = input_matchers->size(); input_idx < e;
169            input_idx++) {
170         if (!MatchAndExplainInput(node, input_idx, listener)) {
171           return false;
172         }
173       }
174     }
175 
176     std::vector<const Node*> control_deps;
177     for (const Edge* e : node->in_edges()) {
178       if (e->IsControlEdge()) {
179         control_deps.push_back(e->src());
180       }
181     }
182 
183     ::testing::StringMatchResultListener inner_listener;
184     if (control_dep_set &&
185         !control_dep_set->MatchAndExplain(control_deps, &inner_listener)) {
186       if (listener->IsInterested()) {
187         string explanation = inner_listener.str();
188         if (!explanation.empty()) {
189           explanation = absl::StrCat(", ", explanation, ",");
190         }
191         *listener << "ctrl_deps" << explanation << " does not match expected: ";
192         control_dep_set->DescribeTo(listener->stream());
193       }
194       return false;
195     }
196 
197     const AttrValueMap attr_value_map = node->def().attr();
198     for (const auto& attr_kv_pair : attrs) {
199       auto it = attr_value_map.find(attr_kv_pair.first);
200       if (it == attr_value_map.end()) {
201         if (listener->IsInterested()) {
202           *listener << "did not find attribute named \"" << attr_kv_pair.first
203                     << "\" in node";
204         }
205         return false;
206       }
207       if (attr_kv_pair.second &&
208           !AreAttrValuesEqual(it->second, *attr_kv_pair.second)) {
209         if (listener->IsInterested()) {
210           *listener << "attribute named " << attr_kv_pair.first
211                     << " does not match value; expected: \""
212                     << SummarizeAttrValue(*attr_kv_pair.second)
213                     << "\", found: \"" << SummarizeAttrValue(it->second)
214                     << "\"";
215         }
216         return false;
217       }
218     }
219 
220     return true;
221   }
222 
DescribeTotensorflow::testing::matchers::__anon857546a50111::NodeMatcher223   void DescribeTo(::std::ostream* os) const override {
224     std::vector<string> predicates;
225 
226     if (name) {
227       predicates.push_back(absl::StrCat("name: ", *name));
228     }
229 
230     if (op) {
231       predicates.push_back(absl::StrCat("op: ", *op));
232     }
233 
234     if (assigned_device) {
235       predicates.push_back(absl::StrCat("assigned device: ", *assigned_device));
236     }
237 
238     bool printed_something = !predicates.empty();
239 
240     *os << absl::StrJoin(predicates, ", ");
241 
242     if (constant_value) {
243       printed_something = true;
244       *os << "constant value: " << constant_value->DebugString();
245     }
246 
247     if (input_matchers) {
248       if (!input_matchers->empty()) {
249         printed_something = true;
250         *os << " with " << (input_matchers->size() == 1 ? "only " : "")
251             << "input" << (input_matchers->size() == 1 ? "" : "s") << " ";
252       }
253 
254       if (input_matchers->size() == 1) {
255         ::std::stringstream ss;
256         input_matchers->front().DescribeTo(&ss);
257         printed_something = true;
258         *os << "matching " << ss.str();
259       } else {
260         int edge_idx = 0;
261         for (const ::testing::Matcher<OutEdge>& matcher : (*input_matchers)) {
262           *os << "\n  [" << edge_idx << "] matching (";
263           ::std::stringstream ss;
264           matcher.DescribeTo(&ss);
265           printed_something = true;
266           *os << IndentAllButFirstLine(ss.str());
267           *os << ")";
268           edge_idx++;
269         }
270       }
271     }
272 
273     if (control_dep_set) {
274       printed_something = true;
275       *os << " and control deps ";
276       control_dep_set->DescribeTo(os);
277     }
278 
279     if (!attrs.empty()) {
280       printed_something = true;
281       std::vector<string> attrs_str;
282       absl::c_transform(
283           attrs, std::back_inserter(attrs_str),
284           [](const std::pair<string, absl::optional<AttrValue>>& attr_kv_pair) {
285             return absl::StrCat(attr_kv_pair.first, "->",
286                                 attr_kv_pair.second
287                                     ? SummarizeAttrValue(*attr_kv_pair.second)
288                                     : "*");
289           });
290       *os << " and attr values matching [" << absl::StrJoin(attrs_str, ", ")
291           << "]";
292     }
293 
294     if (!printed_something) {
295       *os << "is any node";
296     }
297   }
298 
MatchAndExplainInputtensorflow::testing::matchers::__anon857546a50111::NodeMatcher299   bool MatchAndExplainInput(const Node* node, int input_idx,
300                             ::testing::MatchResultListener* listener) const {
301     const Edge* edge;
302     if (!node->input_edge(input_idx, &edge).ok()) {
303       if (listener->IsInterested()) {
304         *listener << "\ncould not find incoming edge for input " << input_idx;
305       }
306       return false;
307     }
308 
309     ::testing::StringMatchResultListener inner_listener;
310     OutEdge input = {edge->src(), edge->src_output()};
311     if ((*input_matchers)[input_idx].MatchAndExplain(input, &inner_listener)) {
312       return true;
313     }
314 
315     if (listener->IsInterested()) {
316       *listener << "\ninput " << input_idx << " does not match expected:\n";
317       (*input_matchers)[input_idx].DescribeTo(listener->stream());
318       string explanation = inner_listener.str();
319       if (!explanation.empty()) {
320         *listener << ", " << explanation;
321       }
322     }
323     return false;
324   }
325 
326   absl::optional<string> op;
327   absl::optional<string> name;
328   absl::optional<string> assigned_device;
329   absl::optional<Tensor> constant_value;
330   absl::optional<std::vector<::testing::Matcher<OutEdge>>> input_matchers;
331   absl::optional<::testing::Matcher<absl::Span<const Node* const>>>
332       control_dep_set;
333   std::map<string, absl::optional<AttrValue>> attrs;
334 };
335 
336 // Matches a dst and dst_output on an input edge.  Today we only use this with
337 // dst_output=0 but we will eventually need to support multi-output operations.
338 class OutEdgeMatcher : public ::testing::MatcherInterface<OutEdge> {
339  public:
OutEdgeMatcher(::testing::Matcher<const Node * > src_matcher,int src_oidx)340   OutEdgeMatcher(::testing::Matcher<const Node*> src_matcher, int src_oidx)
341       : src_matcher_(std::move(src_matcher)), src_oidx_(src_oidx) {}
342 
MatchAndExplain(OutEdge out_edge,::testing::MatchResultListener * listener) const343   bool MatchAndExplain(
344       OutEdge out_edge,
345       ::testing::MatchResultListener* listener) const override {
346     ::testing::StringMatchResultListener inner_listener;
347     if (!src_matcher_.MatchAndExplain(out_edge.first, &inner_listener)) {
348       if (listener->IsInterested()) {
349         *listener << "\nsource does not match expected ";
350         src_matcher_.DescribeTo(listener->stream());
351         string explanation = inner_listener.str();
352         if (!explanation.empty()) {
353           *listener << "\n\t" << explanation;
354         }
355       }
356       return false;
357     }
358     if (out_edge.second != src_oidx_) {
359       if (listener->IsInterested()) {
360         *listener << "\nexpected output slot to be " << src_oidx_
361                   << " but found " << out_edge.second;
362       }
363       return false;
364     }
365 
366     return true;
367   }
368 
DescribeTo(::std::ostream * os) const369   void DescribeTo(::std::ostream* os) const override {
370     if (src_oidx_) {
371       *os << "output slot: " << src_oidx_ << ", source: (";
372     }
373 
374     src_matcher_.DescribeTo(os);
375 
376     if (src_oidx_) {
377       *os << ")";
378     }
379   }
380 
381  private:
382   ::testing::Matcher<const Node*> src_matcher_;
383   int src_oidx_;
384 };
385 }  // namespace
386 
NodeWith(absl::Span<const NodeMatcherProperties> props)387 ::testing::Matcher<const Node*> impl::NodeWith(
388     absl::Span<const NodeMatcherProperties> props) {
389   NodeMatcher* matcher = new NodeMatcher();
390   for (const NodeMatcherProperties& prop : props) {
391     if (prop.name()) {
392       DCHECK(!matcher->name);
393       matcher->name = prop.name();
394     }
395 
396     if (prop.op()) {
397       DCHECK(!matcher->op);
398       matcher->op = prop.op();
399     }
400 
401     if (prop.constant_value()) {
402       DCHECK(!matcher->constant_value);
403       matcher->constant_value = prop.constant_value();
404     }
405 
406     if (prop.assigned_device()) {
407       DCHECK(!matcher->assigned_device);
408       matcher->assigned_device = prop.assigned_device();
409     }
410 
411     if (prop.inputs()) {
412       DCHECK(!matcher->input_matchers);
413       matcher->input_matchers = *prop.inputs();
414     }
415 
416     if (prop.control_deps()) {
417       DCHECK(!matcher->control_dep_set);
418       matcher->control_dep_set =
419           ::testing::UnorderedElementsAreArray(*prop.control_deps());
420     }
421 
422     if (prop.attr()) {
423       auto insert_result = matcher->attrs.insert(*prop.attr());
424       DCHECK(insert_result.second);
425     }
426   }
427 
428   return ::testing::MakeMatcher(matcher);
429 }
430 
Name(string name)431 impl::NodeMatcherProperties Name(string name) {
432   impl::NodeMatcherProperties props;
433   props.set_name(std::move(name));
434   return props;
435 }
436 
437 // Matches a node with op `op`.
Op(string op)438 impl::NodeMatcherProperties Op(string op) {
439   impl::NodeMatcherProperties props;
440   props.set_op(std::move(op));
441   return props;
442 }
443 
444 // Matches a node with assigned device `assigned_device`.
AssignedDevice(string assigned_device)445 impl::NodeMatcherProperties AssignedDevice(string assigned_device) {
446   impl::NodeMatcherProperties props;
447   props.set_assigned_device(std::move(assigned_device));
448   return props;
449 }
450 
Inputs(absl::Span<const::testing::Matcher<OutEdge>> inputs)451 impl::NodeMatcherProperties impl::Inputs(
452     absl::Span<const ::testing::Matcher<OutEdge>> inputs) {
453   std::vector<::testing::Matcher<OutEdge>> inputs_vector;
454   absl::c_copy(inputs, std::back_inserter(inputs_vector));
455 
456   impl::NodeMatcherProperties props;
457   props.set_inputs(std::move(inputs_vector));
458   return props;
459 }
460 
CtrlDeps(absl::Span<const::testing::Matcher<const Node * >> control_deps)461 impl::NodeMatcherProperties impl::CtrlDeps(
462     absl::Span<const ::testing::Matcher<const Node*>> control_deps) {
463   std::vector<::testing::Matcher<const Node*>> control_deps_vector;
464   absl::c_copy(control_deps, std::back_inserter(control_deps_vector));
465 
466   impl::NodeMatcherProperties props;
467   props.set_control_deps(std::move(control_deps_vector));
468   return props;
469 }
470 
AttrLiteralHelper(const std::pair<string,bool> & bool_attr)471 std::pair<string, AttrValue> impl::AttrLiteralHelper(
472     const std::pair<string, bool>& bool_attr) {
473   AttrValue attr_value;
474   attr_value.set_b(bool_attr.second);
475   return {bool_attr.first, attr_value};
476 }
477 
AttrLiteralHelper(const std::pair<string,absl::Span<const int>> & int_list_attr)478 std::pair<string, AttrValue> impl::AttrLiteralHelper(
479     const std::pair<string, absl::Span<const int>>& int_list_attr) {
480   AttrValue attr_value;
481   AttrValue::ListValue* list = attr_value.mutable_list();
482   for (int i : int_list_attr.second) {
483     list->add_i(i);
484   }
485   return {int_list_attr.first, attr_value};
486 }
487 
AttrLiteralHelper(const std::pair<string,absl::Span<const string>> & string_list_attr)488 std::pair<string, AttrValue> impl::AttrLiteralHelper(
489     const std::pair<string, absl::Span<const string>>& string_list_attr) {
490   AttrValue attr_value;
491   AttrValue::ListValue* list = attr_value.mutable_list();
492   for (string s : string_list_attr.second) {
493     list->add_s(s);
494   }
495   return {string_list_attr.first, attr_value};
496 }
497 
Attr(std::pair<string,AttrValue> attr)498 impl::NodeMatcherProperties impl::Attr(std::pair<string, AttrValue> attr) {
499   impl::NodeMatcherProperties props;
500   props.set_attr(std::move(attr));
501   return props;
502 }
503 
Attr(string name)504 impl::NodeMatcherProperties impl::Attr(string name) {
505   impl::NodeMatcherProperties props;
506   props.set_attr({std::move(name), absl::nullopt});
507   return props;
508 }
509 
ConstantValue(const::tensorflow::Input::Initializer & val)510 NodeMatcherProperties ConstantValue(
511     const ::tensorflow::Input::Initializer& val) {
512   TF_CHECK_OK(val.status);
513   NodeMatcherProperties props;
514   props.set_constant_value(val.tensor);
515   return props;
516 }
517 
Const(const::tensorflow::Input::Initializer & val)518 ::testing::Matcher<impl::OutEdge> Const(
519     const ::tensorflow::Input::Initializer& val) {
520   return Out(NodeWith(ConstantValue(val)));
521 }
Out(int oidx,::testing::Matcher<const Node * > node_matcher)522 ::testing::Matcher<impl::OutEdge> Out(
523     int oidx, ::testing::Matcher<const Node*> node_matcher) {
524   return ::testing::MakeMatcher(new OutEdgeMatcher(node_matcher, oidx));
525 }
526 }  // namespace matchers
527 
FindNodeByName(Graph * g,absl::string_view name)528 Node* FindNodeByName(Graph* g, absl::string_view name) {
529   for (Node* n : g->nodes()) {
530     if (n->name() == name) {
531       return n;
532     }
533   }
534 
535   return nullptr;
536 }
537 }  // namespace testing
538 
PrintTo(const Node * n,::std::ostream * os)539 void PrintTo(const Node* n, ::std::ostream* os) { *os << SummarizeNode(*n); }
PrintTo(Node * n,::std::ostream * os)540 void PrintTo(Node* n, ::std::ostream* os) { *os << SummarizeNode(*n); }
541 }  // namespace tensorflow
542