1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Provides a set of matchers for tensorflow nodes.
17 //
18 // Example usage:
19 //
20 //  tensorflow::Node* node = ...;
21 //  EXPECT_THAT(node, NodeWith(Name("name"), Op("op"),
22 //                             Inputs(Out(3, NodeWith(Name("input"))))))
23 //
24 // Matchable node properties (the expressions that go inside NodeWith(...))
25 // are:
26 //
//  - Name(string): matches the node name exactly.  We will probably need to
//    make this take a string matcher in the future.
29 //
30 //  - Op(string): matches the op exactly.
31 //
32 //  - AssignedDevice(string): matches the assigned device exactly.
33 //
34 //  - Inputs(<ordered list>): matches the list of non-control inputs to the node
35 //    exactly (i.e. does not match a suffix or a prefix) where each element
36 //    matches an output of a node (see Out(idx, node) below).
37 //
38 //  - CtrlDeps(<unordered list>): matches the list of control dependences on the
39 //    node exactly but in any order.
40 //
41 //  - ConstantValue(tensorflow::Input::Initializer init): matches a Const node
42 //    with the constant value `init`.  Implies Op("Const").
43 //
//  - Attr(name, value): Matches a single attribute with name `name` and value
//    `value`.  Right now only boolean, integer-list and string-list values are
//    supported.
46 //
47 // Overlapping node properties may not be repeated in a single NodeWith(...)
48 // matcher.  E.g. NodeWith(Op("Foo"), Op("Bar")) will CHECK-fail.  Since
49 // ConstantValue implies Op("Const"), a single NodeWith matcher can't have both
50 // ConstantValue(...) and Op(...).  Multiple Attr() values can be combined as
51 // long as the attribute names are different.
52 //
53 // Out(idx, node) matches the `idx`'th output of a node that matches `node`.
54 
55 #ifndef TENSORFLOW_COMPILER_JIT_NODE_MATCHERS_H_
56 #define TENSORFLOW_COMPILER_JIT_NODE_MATCHERS_H_
57 
58 #include <array>
59 #include <string>
60 #include <vector>
61 
62 #include "absl/algorithm/container.h"
63 #include "absl/strings/string_view.h"
64 #include "absl/types/optional.h"
65 #include "absl/types/span.h"
66 #include "tensorflow/cc/framework/ops.h"
67 #include "tensorflow/compiler/xla/test.h"
68 #include "tensorflow/core/graph/graph.h"
69 
70 namespace tensorflow {
71 namespace testing {
72 namespace matchers {
73 
74 namespace impl {
75 
// One output of a node: the producing node paired with the output index, as
// matched by Out(idx, node).
using OutEdge = std::pair<const Node*, int>;
77 
78 // -----------------------------------------------------------------------------
79 // Implementation details.
80 
81 // Properties that we match on for a particular Node.  If a particular property
82 // is nullopt then any value for it is allowed.
83 class NodeMatcherProperties {
84  public:
85   using NodeSeqMatcher = std::vector<::testing::Matcher<const Node*>>;
86   using InputSeqMatcher = std::vector<::testing::Matcher<OutEdge>>;
87   using AttrKeyValuePair = std::pair<string, absl::optional<AttrValue>>;
88 
name()89   const absl::optional<string>& name() const { return name_; }
op()90   const absl::optional<string>& op() const { return op_; }
assigned_device()91   const absl::optional<string>& assigned_device() const {
92     return assigned_device_;
93   }
constant_value()94   const absl::optional<Tensor>& constant_value() const {
95     return constant_value_;
96   }
inputs()97   const absl::optional<InputSeqMatcher>& inputs() const {
98     return input_matchers_;
99   }
control_deps()100   const absl::optional<NodeSeqMatcher>& control_deps() const {
101     return control_deps_;
102   }
attr()103   const absl::optional<AttrKeyValuePair>& attr() const { return attr_; }
104 
set_name(string name)105   void set_name(string name) {
106     DCHECK(IsEmpty());
107     name_ = std::move(name);
108   }
109 
set_op(string op)110   void set_op(string op) {
111     DCHECK(IsEmpty());
112     op_ = std::move(op);
113   }
114 
set_assigned_device(string assigned_device)115   void set_assigned_device(string assigned_device) {
116     DCHECK(IsEmpty());
117     assigned_device_ = std::move(assigned_device);
118   }
119 
set_constant_value(Tensor constant_value)120   void set_constant_value(Tensor constant_value) {
121     DCHECK(IsEmpty());
122     constant_value_ = std::move(constant_value);
123     op_ = "Const";
124   }
125 
set_inputs(InputSeqMatcher inputs)126   void set_inputs(InputSeqMatcher inputs) {
127     DCHECK(IsEmpty());
128     input_matchers_ = std::move(inputs);
129   }
130 
set_control_deps(NodeSeqMatcher control_deps)131   void set_control_deps(NodeSeqMatcher control_deps) {
132     DCHECK(IsEmpty());
133     control_deps_ = std::move(control_deps);
134   }
135 
set_attr(AttrKeyValuePair attr)136   void set_attr(AttrKeyValuePair attr) {
137     DCHECK(IsEmpty());
138     attr_ = std::move(attr);
139   }
140 
IsEmpty()141   bool IsEmpty() const {
142     return !name().has_value() && !op().has_value() && !inputs().has_value() &&
143            !control_deps().has_value() && !attr().has_value();
144   }
145 
146  private:
147   absl::optional<string> name_;
148   absl::optional<string> op_;
149   absl::optional<string> assigned_device_;
150   absl::optional<Tensor> constant_value_;
151   absl::optional<InputSeqMatcher> input_matchers_;
152   absl::optional<NodeSeqMatcher> control_deps_;
153   absl::optional<AttrKeyValuePair> attr_;
154 };
155 
// Builds a single Node matcher from `props`.  Overlapping properties may not
// be repeated across the elements (see the file comment).
::testing::Matcher<const Node*> NodeWith(
    absl::Span<const NodeMatcherProperties> props);

// Non-template implementations backing the public variadic Inputs(...) and
// CtrlDeps(...) wrappers.
impl::NodeMatcherProperties Inputs(
    absl::Span<const ::testing::Matcher<OutEdge>> inputs);

impl::NodeMatcherProperties CtrlDeps(
    absl::Span<const ::testing::Matcher<const Node*>> control_deps);

// Attribute properties.  The first overload takes a single (name, value) pair
// despite the plural parameter name; the second takes only the name.
impl::NodeMatcherProperties Attr(std::pair<string, AttrValue> attrs);
impl::NodeMatcherProperties Attr(string name);

// Convert a (name, literal) pair into a (name, AttrValue) pair.  There is one
// overload per supported literal type: bool, list of int, list of string.
std::pair<string, AttrValue> AttrLiteralHelper(
    const std::pair<string, bool>& bool_attr);

std::pair<string, AttrValue> AttrLiteralHelper(
    const std::pair<string, absl::Span<const int>>& int_list_attr);

std::pair<string, AttrValue> AttrLiteralHelper(
    const std::pair<string, absl::Span<const string>>& string_list_attr);
176 }  // namespace impl
177 
178 // -----------------------------------------------------------------------------
179 // Public interface.
180 
// Matches a node whose name is exactly `name`.
impl::NodeMatcherProperties Name(string name);

// Matches a node whose op is exactly `op`.  Cannot be combined with
// ConstantValue(...) in the same NodeWith(...), since that implies
// Op("Const").
impl::NodeMatcherProperties Op(string op);

// Matches a node whose assigned device is exactly `assigned_device`.
impl::NodeMatcherProperties AssignedDevice(string assigned_device);
189 
190 // Matches a node with a boolean typed attrbute named `name` and with value
191 // `value`.
192 template <typename ValueTy>
Attr(const string & name,ValueTy value)193 impl::NodeMatcherProperties Attr(const string& name, ValueTy value) {
194   return impl::Attr({impl::AttrLiteralHelper({name, value})});
195 }
196 
Attr(const string & name)197 inline impl::NodeMatcherProperties Attr(const string& name) {
198   return impl::Attr(name);
199 }
200 
201 // Matches a node with inputs `inputs`.
202 //
203 // `inputs` are ordered; `inputs`[i] must match input i.
204 template <typename... Ts>
Inputs(Ts...inputs)205 impl::NodeMatcherProperties Inputs(Ts... inputs) {
206   return impl::Inputs({inputs...});
207 }
208 
// Matches the `oidx`'th output of a node that matches `node`.
::testing::Matcher<impl::OutEdge> Out(int oidx,
                                      ::testing::Matcher<const Node*> node);
212 
213 // Matches the first output of a node that matches `node`.
Out(::testing::Matcher<const Node * > node)214 inline ::testing::Matcher<impl::OutEdge> Out(
215     ::testing::Matcher<const Node*> node) {
216   return Out(0, node);
217 }
218 
219 // Matches a node with control dependences `control_deps`.
220 //
221 // `control_deps` are unordered and will match the control deps of a node in any
222 // order.
223 template <typename... Ts>
CtrlDeps(Ts...control_deps)224 impl::NodeMatcherProperties CtrlDeps(Ts... control_deps) {
225   return impl::CtrlDeps({control_deps...});
226 }
227 
// Matches a Const node whose value is `val`.  This implies Op("Const"), so it
// cannot be combined with Op(...) in the same NodeWith(...).
impl::NodeMatcherProperties ConstantValue(
    const ::tensorflow::Input::Initializer& val);
231 
232 // The main gmock matcher.  See file comment for example usage.
233 template <typename... Ts>
NodeWith(Ts...args)234 ::testing::Matcher<const Node*> NodeWith(Ts... args) {
235   std::array<impl::NodeMatcherProperties, sizeof...(Ts)> array = {args...};
236   return impl::NodeWith(array);
237 }
238 
// Matches an output edge (see Out) associated with the constant value `val`.
// NOTE(review): presumably equivalent to Out(NodeWith(ConstantValue(val)));
// confirm against the implementation.
::testing::Matcher<impl::OutEdge> Const(
    const ::tensorflow::Input::Initializer& val);
241 }  // namespace matchers
242 
// Returns the node in `g` named `name`, or nullptr if `g` has no such node.
Node* FindNodeByName(Graph* g, absl::string_view name);
245 }  // namespace testing
246 
// Printers for Node pointers.  They are declared in namespace tensorflow so
// that gtest/gmock can find them through argument-dependent lookup when
// printing a `Node*` or `const Node*` (e.g. in matcher failure messages).
void PrintTo(const Node* n, ::std::ostream* os);
void PrintTo(Node* n, ::std::ostream* os);
249 }  // namespace tensorflow
250 
251 #endif  // TENSORFLOW_COMPILER_JIT_NODE_MATCHERS_H_
252