1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // Provides a set of matchers for tensorflow nodes.
17 //
18 // Example usage:
19 //
20 // tensorflow::Node* node = ...;
21 // EXPECT_THAT(node, NodeWith(Name("name"), Op("op"),
22 // Inputs(Out(3, NodeWith(Name("input"))))))
23 //
24 // Matchable node properties (the expressions that go inside NodeWith(...))
25 // are:
26 //
// - Name(string): matches the node name exactly. We will probably need this
//   to take a string matcher in the future.
29 //
30 // - Op(string): matches the op exactly.
31 //
32 // - AssignedDevice(string): matches the assigned device exactly.
33 //
34 // - Inputs(<ordered list>): matches the list of non-control inputs to the node
35 // exactly (i.e. does not match a suffix or a prefix) where each element
36 // matches an output of a node (see Out(idx, node) below).
37 //
38 // - CtrlDeps(<unordered list>): matches the list of control dependences on the
39 // node exactly but in any order.
40 //
41 // - ConstantValue(tensorflow::Input::Initializer init): matches a Const node
42 // with the constant value `init`. Implies Op("Const").
43 //
// - Attr(name, value): Matches a single attribute with name `name` and value
//   `value`. Supported value types are bool, a list of ints and a list of
//   strings.
46 //
47 // Overlapping node properties may not be repeated in a single NodeWith(...)
48 // matcher. E.g. NodeWith(Op("Foo"), Op("Bar")) will CHECK-fail. Since
49 // ConstantValue implies Op("Const"), a single NodeWith matcher can't have both
50 // ConstantValue(...) and Op(...). Multiple Attr() values can be combined as
51 // long as the attribute names are different.
52 //
53 // Out(idx, node) matches the `idx`'th output of a node that matches `node`.
54
55 #ifndef TENSORFLOW_COMPILER_JIT_NODE_MATCHERS_H_
56 #define TENSORFLOW_COMPILER_JIT_NODE_MATCHERS_H_
57
58 #include <array>
59 #include <string>
60 #include <vector>
61
62 #include "absl/algorithm/container.h"
63 #include "absl/strings/string_view.h"
64 #include "absl/types/optional.h"
65 #include "absl/types/span.h"
66 #include "tensorflow/cc/framework/ops.h"
67 #include "tensorflow/compiler/xla/test.h"
68 #include "tensorflow/core/graph/graph.h"
69
70 namespace tensorflow {
71 namespace testing {
72 namespace matchers {
73
74 namespace impl {
75
76 using OutEdge = std::pair<const Node*, int>;
77
78 // -----------------------------------------------------------------------------
79 // Implementation details.
80
81 // Properties that we match on for a particular Node. If a particular property
82 // is nullopt then any value for it is allowed.
83 class NodeMatcherProperties {
84 public:
85 using NodeSeqMatcher = std::vector<::testing::Matcher<const Node*>>;
86 using InputSeqMatcher = std::vector<::testing::Matcher<OutEdge>>;
87 using AttrKeyValuePair = std::pair<string, absl::optional<AttrValue>>;
88
name()89 const absl::optional<string>& name() const { return name_; }
op()90 const absl::optional<string>& op() const { return op_; }
assigned_device()91 const absl::optional<string>& assigned_device() const {
92 return assigned_device_;
93 }
constant_value()94 const absl::optional<Tensor>& constant_value() const {
95 return constant_value_;
96 }
inputs()97 const absl::optional<InputSeqMatcher>& inputs() const {
98 return input_matchers_;
99 }
control_deps()100 const absl::optional<NodeSeqMatcher>& control_deps() const {
101 return control_deps_;
102 }
attr()103 const absl::optional<AttrKeyValuePair>& attr() const { return attr_; }
104
set_name(string name)105 void set_name(string name) {
106 DCHECK(IsEmpty());
107 name_ = std::move(name);
108 }
109
set_op(string op)110 void set_op(string op) {
111 DCHECK(IsEmpty());
112 op_ = std::move(op);
113 }
114
set_assigned_device(string assigned_device)115 void set_assigned_device(string assigned_device) {
116 DCHECK(IsEmpty());
117 assigned_device_ = std::move(assigned_device);
118 }
119
set_constant_value(Tensor constant_value)120 void set_constant_value(Tensor constant_value) {
121 DCHECK(IsEmpty());
122 constant_value_ = std::move(constant_value);
123 op_ = "Const";
124 }
125
set_inputs(InputSeqMatcher inputs)126 void set_inputs(InputSeqMatcher inputs) {
127 DCHECK(IsEmpty());
128 input_matchers_ = std::move(inputs);
129 }
130
set_control_deps(NodeSeqMatcher control_deps)131 void set_control_deps(NodeSeqMatcher control_deps) {
132 DCHECK(IsEmpty());
133 control_deps_ = std::move(control_deps);
134 }
135
set_attr(AttrKeyValuePair attr)136 void set_attr(AttrKeyValuePair attr) {
137 DCHECK(IsEmpty());
138 attr_ = std::move(attr);
139 }
140
IsEmpty()141 bool IsEmpty() const {
142 return !name().has_value() && !op().has_value() && !inputs().has_value() &&
143 !control_deps().has_value() && !attr().has_value();
144 }
145
146 private:
147 absl::optional<string> name_;
148 absl::optional<string> op_;
149 absl::optional<string> assigned_device_;
150 absl::optional<Tensor> constant_value_;
151 absl::optional<InputSeqMatcher> input_matchers_;
152 absl::optional<NodeSeqMatcher> control_deps_;
153 absl::optional<AttrKeyValuePair> attr_;
154 };
155
// Builds a node matcher from a list of property objects.  The variadic
// matchers::NodeWith(...) template packs its arguments into an array and
// forwards to this overload.
::testing::Matcher<const Node*> NodeWith(
    absl::Span<const NodeMatcherProperties> props);

// Span-based form of Inputs(...); the variadic matchers::Inputs(...) template
// forwards its arguments here.
impl::NodeMatcherProperties Inputs(
    absl::Span<const ::testing::Matcher<OutEdge>> inputs);

// Span-based form of CtrlDeps(...); the variadic matchers::CtrlDeps(...)
// template forwards its arguments here.
impl::NodeMatcherProperties CtrlDeps(
    absl::Span<const ::testing::Matcher<const Node*>> control_deps);

// Attribute properties: the first overload takes a (name, value) pair, the
// second takes only an attribute name.
// NOTE(review): how the name-only form is matched is decided by the
// out-of-line definition, which is not visible in this header -- confirm there.
impl::NodeMatcherProperties Attr(std::pair<string, AttrValue> attrs);
impl::NodeMatcherProperties Attr(string name);

// Overloads that turn a (name, literal) pair into a (name, AttrValue) pair.
// Presumably each packs the literal into the returned AttrValue -- confirm
// against the definitions.  This overload accepts a bool literal.
std::pair<string, AttrValue> AttrLiteralHelper(
    const std::pair<string, bool>& bool_attr);

// As above, for a list of ints.
std::pair<string, AttrValue> AttrLiteralHelper(
    const std::pair<string, absl::Span<const int>>& int_list_attr);

// As above, for a list of strings.
std::pair<string, AttrValue> AttrLiteralHelper(
    const std::pair<string, absl::Span<const string>>& string_list_attr);
176 } // namespace impl
177
178 // -----------------------------------------------------------------------------
179 // Public interface.
180
181 // Matches a node with name `name`.
182 impl::NodeMatcherProperties Name(string name);
183
184 // Matches a node with op `op`.
185 impl::NodeMatcherProperties Op(string op);
186
187 // Matches a node with assigned device `assigned_device`.
188 impl::NodeMatcherProperties AssignedDevice(string assigned_device);
189
190 // Matches a node with a boolean typed attrbute named `name` and with value
191 // `value`.
192 template <typename ValueTy>
Attr(const string & name,ValueTy value)193 impl::NodeMatcherProperties Attr(const string& name, ValueTy value) {
194 return impl::Attr({impl::AttrLiteralHelper({name, value})});
195 }
196
Attr(const string & name)197 inline impl::NodeMatcherProperties Attr(const string& name) {
198 return impl::Attr(name);
199 }
200
201 // Matches a node with inputs `inputs`.
202 //
203 // `inputs` are ordered; `inputs`[i] must match input i.
204 template <typename... Ts>
Inputs(Ts...inputs)205 impl::NodeMatcherProperties Inputs(Ts... inputs) {
206 return impl::Inputs({inputs...});
207 }
208
// Matches the `oidx`'th output of a node that matches `node`.
210 ::testing::Matcher<impl::OutEdge> Out(int oidx,
211 ::testing::Matcher<const Node*> node);
212
213 // Matches the first output of a node that matches `node`.
Out(::testing::Matcher<const Node * > node)214 inline ::testing::Matcher<impl::OutEdge> Out(
215 ::testing::Matcher<const Node*> node) {
216 return Out(0, node);
217 }
218
219 // Matches a node with control dependences `control_deps`.
220 //
221 // `control_deps` are unordered and will match the control deps of a node in any
222 // order.
223 template <typename... Ts>
CtrlDeps(Ts...control_deps)224 impl::NodeMatcherProperties CtrlDeps(Ts... control_deps) {
225 return impl::CtrlDeps({control_deps...});
226 }
227
228 // Matches a constant node with value `val`.
229 impl::NodeMatcherProperties ConstantValue(
230 const ::tensorflow::Input::Initializer& val);
231
232 // The main gmock matcher. See file comment for example usage.
233 template <typename... Ts>
NodeWith(Ts...args)234 ::testing::Matcher<const Node*> NodeWith(Ts... args) {
235 std::array<impl::NodeMatcherProperties, sizeof...(Ts)> array = {args...};
236 return impl::NodeWith(array);
237 }
238
// Returns an output-edge matcher built from the constant initializer `val`.
// NOTE(review): presumably shorthand for Out(NodeWith(ConstantValue(val))) --
// confirm against the out-of-line definition.
::testing::Matcher<impl::OutEdge> Const(
    const ::tensorflow::Input::Initializer& val);
241 } // namespace matchers
242
243 // If `g` has a node named `name` returns it, otherwise returns null.
244 Node* FindNodeByName(Graph* g, absl::string_view name);
245 } // namespace testing
246
// Print a Node (given by const or non-const pointer) to `os`.
// NOTE(review): these look like googletest PrintTo customization hooks, so
// that matcher failure messages can describe nodes -- confirm against the
// definitions.
void PrintTo(const Node* n, ::std::ostream* os);
void PrintTo(Node* n, ::std::ostream* os);
249 } // namespace tensorflow
250
251 #endif // TENSORFLOW_COMPILER_JIT_NODE_MATCHERS_H_
252