1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MATCHERS_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MATCHERS_H_
18 
#include <ostream>
#include <type_traits>
#include <utility>
#include <vector>

#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/test.h"
23 
24 namespace xla {
25 namespace testing {
26 
27 class HloMatcher : public ::testing::MatcherInterface<const HloInstruction*> {
28  public:
HloMatcher(HloOpcode opcode,std::vector<::testing::Matcher<const HloInstruction * >> operands)29   HloMatcher(HloOpcode opcode,
30              std::vector<::testing::Matcher<const HloInstruction*>> operands)
31       : opcode_(opcode), operands_(operands) {}
32 
33   bool MatchAndExplain(const HloInstruction* instruction,
34                        ::testing::MatchResultListener* listener) const override;
35 
36   void DescribeTo(::std::ostream* os) const override;
37 
38  private:
39   HloOpcode opcode_;
40   std::vector<::testing::Matcher<const HloInstruction*>> operands_;
41 };
42 
43 // Custom matcher for parameters, which accepts a parameter number.
44 class HloParameterMatcher : public HloMatcher {
45  public:
HloParameterMatcher(int64 parameter_number)46   explicit HloParameterMatcher(int64 parameter_number)
47       : HloMatcher(HloOpcode::kParameter, /*operands=*/{}),
48         parameter_number_(parameter_number) {}
49 
50   bool MatchAndExplain(const HloInstruction* instruction,
51                        ::testing::MatchResultListener* listener) const override;
52 
53  private:
54   int64 parameter_number_;
55 };
56 
57 // Custom matcher for comparisons, which accepts a comparison direction.
58 class HloComparisonMatcher : public HloMatcher {
59  public:
HloComparisonMatcher(ComparisonDirection direction,std::vector<::testing::Matcher<const HloInstruction * >> operands)60   explicit HloComparisonMatcher(
61       ComparisonDirection direction,
62       std::vector<::testing::Matcher<const HloInstruction*>> operands)
63       : HloMatcher(HloOpcode::kCompare, operands), direction_(direction) {}
64 
65   bool MatchAndExplain(const HloInstruction* instruction,
66                        ::testing::MatchResultListener* listener) const override;
67 
68  private:
69   ComparisonDirection direction_;
70 };
71 
72 // Custom matcher for get-tuple-element instructions, which accepts a tuple
73 // index to match.
74 class HloGetTupleElementMatcher : public HloMatcher {
75  public:
HloGetTupleElementMatcher(::testing::Matcher<const HloInstruction * > operand,int64 tuple_index)76   HloGetTupleElementMatcher(::testing::Matcher<const HloInstruction*> operand,
77                             int64 tuple_index)
78       : HloMatcher(HloOpcode::kGetTupleElement, /*operands=*/{operand}),
79         tuple_index_(tuple_index) {}
80 
81   bool MatchAndExplain(const HloInstruction* instruction,
82                        ::testing::MatchResultListener* listener) const override;
83 
84  private:
85   int64 tuple_index_;
86 };
87 
88 // Custom matcher for custom-call instructions, which accepts a matcher for its
89 // call target.
90 class HloCustomCallMatcher : public HloMatcher {
91  public:
HloCustomCallMatcher(::testing::Matcher<string> call_target_matcher,std::vector<::testing::Matcher<const HloInstruction * >> operands)92   HloCustomCallMatcher(
93       ::testing::Matcher<string> call_target_matcher,
94       std::vector<::testing::Matcher<const HloInstruction*>> operands)
95       : HloMatcher(HloOpcode::kCustomCall, operands),
96         call_target_matcher_(call_target_matcher) {}
97 
98   bool MatchAndExplain(const HloInstruction* instruction,
99                        ::testing::MatchResultListener* listener) const override;
100   void DescribeTo(std::ostream* os) const override;
101 
102  private:
103   ::testing::Matcher<string> call_target_matcher_;
104 };
105 
106 class HloShapeMatcher
107     : public ::testing::MatcherInterface<const HloInstruction*> {
108  public:
HloShapeMatcher(const Shape & shape)109   explicit HloShapeMatcher(const Shape& shape) : shape_(shape) {}
110 
111   bool MatchAndExplain(const HloInstruction* instruction,
112                        ::testing::MatchResultListener* listener) const override;
113   void DescribeTo(std::ostream* os) const override;
114 
115  private:
116   Shape shape_;
117 };
118 
119 class HloShapeAndLayoutMatcher
120     : public ::testing::MatcherInterface<const HloInstruction*> {
121  public:
HloShapeAndLayoutMatcher(const Shape & shape)122   explicit HloShapeAndLayoutMatcher(const Shape& shape) : shape_(shape) {}
123 
124   bool MatchAndExplain(const HloInstruction* instruction,
125                        ::testing::MatchResultListener* listener) const override;
126   void DescribeTo(std::ostream* os) const override;
127 
128  private:
129   Shape shape_;
130 };
131 
132 // Verify the sharding of an instruction against the provided HloSharding. If a
133 // nullopt is provided for the expected sharding then it checks that no sharding
134 // is present for an instruction.
135 class HloShardingMatcher
136     : public ::testing::MatcherInterface<const HloInstruction*> {
137  public:
HloShardingMatcher(const absl::optional<HloSharding> & sharding)138   explicit HloShardingMatcher(const absl::optional<HloSharding>& sharding)
139       : sharding_(sharding) {}
140 
141   bool MatchAndExplain(const HloInstruction* instruction,
142                        ::testing::MatchResultListener* listener) const override;
143   void DescribeTo(std::ostream* os) const override;
144 
145  private:
146   absl::optional<HloSharding> sharding_;
147 };
148 
149 // Matches a Dot HLO instruction with specific LHS and RHS contracting
150 // dimensions.
151 class HloDotWithContractingDimsMatcher : public HloMatcher {
152  public:
HloDotWithContractingDimsMatcher(::testing::Matcher<const HloInstruction * > lhs,::testing::Matcher<const HloInstruction * > rhs,int64 lhs_contracting_dim,int64 rhs_contracting_dim)153   explicit HloDotWithContractingDimsMatcher(
154       ::testing::Matcher<const HloInstruction*> lhs,
155       ::testing::Matcher<const HloInstruction*> rhs, int64 lhs_contracting_dim,
156       int64 rhs_contracting_dim)
157       : HloMatcher(HloOpcode::kDot, /*operands=*/{lhs, rhs}),
158         lhs_contracting_dim_(lhs_contracting_dim),
159         rhs_contracting_dim_(rhs_contracting_dim) {}
160 
161   bool MatchAndExplain(const HloInstruction* instruction,
162                        ::testing::MatchResultListener* listener) const override;
163   void DescribeTo(std::ostream* os) const override;
164 
165  private:
166   int64 lhs_contracting_dim_;
167   int64 rhs_contracting_dim_;
168 };
169 
170 // HloInstruction* matchers for opcode and operands. Example:
//   namespace op = xla::testing::opcode_matchers;
172 //   EXPECT_THAT(instruction,
173 //               op::Add(op::Reshape(), op::Add(op::Reshape(), _)));
174 namespace opcode_matchers {
175 #define HLO_MATCHER(opcode)                                                \
176   template <typename... M>                                                 \
177   ::testing::Matcher<const ::xla::HloInstruction*> opcode(M... operands) { \
178     return ::testing::MakeMatcher(new ::xla::testing::HloMatcher(          \
179         ::xla::HloOpcode::k##opcode, {operands...}));                      \
180   }
181 HLO_MATCHER(Abs);
182 HLO_MATCHER(Add);
183 HLO_MATCHER(AllToAll);
184 HLO_MATCHER(Bitcast);
185 HLO_MATCHER(Broadcast);
186 HLO_MATCHER(BatchNormGrad);
187 HLO_MATCHER(Call);
188 HLO_MATCHER(Ceil);
189 HLO_MATCHER(Clamp);
190 HLO_MATCHER(Compare);
191 HLO_MATCHER(Concatenate);
192 HLO_MATCHER(Conditional);
193 HLO_MATCHER(Constant);
194 HLO_MATCHER(Convert);
195 HLO_MATCHER(Convolution);
196 HLO_MATCHER(Copy);
197 HLO_MATCHER(AllReduce);
198 HLO_MATCHER(CollectivePermute);
199 HLO_MATCHER(Divide);
200 HLO_MATCHER(Domain);
201 HLO_MATCHER(DynamicSlice);
202 HLO_MATCHER(DynamicUpdateSlice);
203 HLO_MATCHER(Exp);
204 HLO_MATCHER(Floor);
205 HLO_MATCHER(Fusion);
206 HLO_MATCHER(AfterAll);
207 HLO_MATCHER(Iota);
208 HLO_MATCHER(Infeed);
209 HLO_MATCHER(IsFinite);
210 HLO_MATCHER(Log);
211 HLO_MATCHER(And);
212 HLO_MATCHER(Not);
213 HLO_MATCHER(Or);
214 HLO_MATCHER(Xor);
215 HLO_MATCHER(Map);
216 HLO_MATCHER(Maximum);
217 HLO_MATCHER(Minimum);
218 HLO_MATCHER(Multiply);
219 HLO_MATCHER(Negate);
220 HLO_MATCHER(Outfeed);
221 HLO_MATCHER(Pad);
222 HLO_MATCHER(Power);
223 HLO_MATCHER(Recv);
224 HLO_MATCHER(RecvDone);
225 HLO_MATCHER(Reduce);
226 HLO_MATCHER(ReducePrecision);
227 HLO_MATCHER(ReduceWindow);
228 HLO_MATCHER(Remainder);
229 HLO_MATCHER(Reshape);
230 HLO_MATCHER(Reverse);
231 HLO_MATCHER(Rng);
232 HLO_MATCHER(Scatter);
233 HLO_MATCHER(Select);
234 HLO_MATCHER(SelectAndScatter);
235 HLO_MATCHER(Send);
236 HLO_MATCHER(SendDone);
237 HLO_MATCHER(ShiftLeft);
238 HLO_MATCHER(ShiftRightLogical);
239 HLO_MATCHER(ShiftRightArithmetic);
240 HLO_MATCHER(Sign);
241 HLO_MATCHER(Slice);
242 HLO_MATCHER(Sort);
243 HLO_MATCHER(Subtract);
244 HLO_MATCHER(Tanh);
245 HLO_MATCHER(Trace);
246 HLO_MATCHER(Transpose);
247 HLO_MATCHER(Tuple);
248 HLO_MATCHER(TupleSelect);
249 HLO_MATCHER(While);
250 
251 // The special cases below let you check additional information about the
252 // HloInstruction, beyond just its opcode and operands.  In all cases you can
253 // still use the generic matcher which doesn't check this info.
254 //
255 // Feel free to add additional custom matchers below.
256 
257 //  - Parameter(N) matches parameter number N.
258 //  - Parameter() matches any parameter.
Parameter(int64 parameter_number)259 inline ::testing::Matcher<const ::xla::HloInstruction*> Parameter(
260     int64 parameter_number) {
261   return ::testing::MakeMatcher(
262       new ::xla::testing::HloParameterMatcher(parameter_number));
263 }
Parameter()264 inline ::testing::Matcher<const ::xla::HloInstruction*> Parameter() {
265   return ::testing::MakeMatcher(
266       new ::xla::testing::HloMatcher(HloOpcode::kParameter, {}));
267 }
268 
269 // Comparison matchers below do not require any additional arguments.
270 template <typename... M>
Eq(M...operands)271 inline ::testing::Matcher<const ::xla::HloInstruction*> Eq(M... operands) {
272   return ::testing::MakeMatcher(new ::xla::testing::HloComparisonMatcher(
273       ComparisonDirection::kEq, {operands...}));
274 }
275 template <typename... M>
Ne(M...operands)276 inline ::testing::Matcher<const ::xla::HloInstruction*> Ne(M... operands) {
277   return ::testing::MakeMatcher(new ::xla::testing::HloComparisonMatcher(
278       ComparisonDirection::kNe, {operands...}));
279 }
280 template <typename... M>
Ge(M...operands)281 inline ::testing::Matcher<const ::xla::HloInstruction*> Ge(M... operands) {
282   return ::testing::MakeMatcher(new ::xla::testing::HloComparisonMatcher(
283       ComparisonDirection::kGe, {operands...}));
284 }
285 template <typename... M>
Gt(M...operands)286 inline ::testing::Matcher<const ::xla::HloInstruction*> Gt(M... operands) {
287   return ::testing::MakeMatcher(new ::xla::testing::HloComparisonMatcher(
288       ComparisonDirection::kGt, {operands...}));
289 }
290 template <typename... M>
Le(M...operands)291 inline ::testing::Matcher<const ::xla::HloInstruction*> Le(M... operands) {
292   return ::testing::MakeMatcher(new ::xla::testing::HloComparisonMatcher(
293       ComparisonDirection::kLe, {operands...}));
294 }
295 template <typename... M>
Lt(M...operands)296 inline ::testing::Matcher<const ::xla::HloInstruction*> Lt(M... operands) {
297   return ::testing::MakeMatcher(new ::xla::testing::HloComparisonMatcher(
298       ComparisonDirection::kLt, {operands...}));
299 }
300 
301 // GetTupleElement(operand, N) matches a GTE instruction which gets the N'th
302 // tuple element of operand, while GetTupleElement(operand) matches any GTE
303 // operation on operand, and GetTupleElement() matches any GTE operation at all.
GetTupleElement(::testing::Matcher<const HloInstruction * > operand,int64 tuple_index)304 inline ::testing::Matcher<const ::xla::HloInstruction*> GetTupleElement(
305     ::testing::Matcher<const HloInstruction*> operand, int64 tuple_index) {
306   return ::testing::MakeMatcher(
307       new ::xla::testing::HloGetTupleElementMatcher(operand, tuple_index));
308 }
GetTupleElement(::testing::Matcher<const HloInstruction * > operand)309 inline ::testing::Matcher<const ::xla::HloInstruction*> GetTupleElement(
310     ::testing::Matcher<const HloInstruction*> operand) {
311   return ::testing::MakeMatcher(
312       new ::xla::testing::HloMatcher(HloOpcode::kGetTupleElement, {operand}));
313 }
GetTupleElement()314 inline ::testing::Matcher<const ::xla::HloInstruction*> GetTupleElement() {
315   return ::testing::MakeMatcher(
316       new ::xla::testing::HloMatcher(HloOpcode::kGetTupleElement, {}));
317 }
318 
319 // - CustomCall(T, operand1, ..., operandN) matches a CustomCall with call
320 //   target T and the given operands.
321 //
322 // - CustomCall(operand1, ..., operandN) matches any CustomCall HLO with the
323 //   given operands.
324 //
325 // - CustomCall() matches any CustomCall HLO at all.
326 template <typename... M>
CustomCall(::testing::Matcher<string> call_target_matcher,M...operands)327 inline ::testing::Matcher<const ::xla::HloInstruction*> CustomCall(
328     ::testing::Matcher<string> call_target_matcher, M... operands) {
329   return ::testing::MakeMatcher(new ::xla::testing::HloCustomCallMatcher(
330       call_target_matcher, {operands...}));
331 }
332 // This overload of CustomCall(A, B, C, ...) exists iff A is not convertible to
333 // ::testing::Matcher<string>.  In that case, we want to prefer the overload
334 // above.
335 template <typename FirstM, typename... M,
336           typename Dummy = typename std::enable_if<
337               !std::is_convertible<FirstM, ::testing::Matcher<string>>::value,
338               void>::type*>
CustomCall(FirstM operands_first,M...operands_rest)339 inline ::testing::Matcher<const ::xla::HloInstruction*> CustomCall(
340     FirstM operands_first, M... operands_rest) {
341   return ::testing::MakeMatcher(new ::xla::testing::HloMatcher(
342       HloOpcode::kCustomCall, {operands_first, operands_rest...}));
343 }
CustomCall()344 inline ::testing::Matcher<const ::xla::HloInstruction*> CustomCall() {
345   return ::testing::MakeMatcher(
346       new ::xla::testing::HloMatcher(HloOpcode::kCustomCall, {}));
347 }
348 
349 // Verifies the shape or the shape and the layout of an HLO instruction against
350 // the provided shape object.
Shape(const class Shape & shape)351 inline ::testing::Matcher<const ::xla::HloInstruction*> Shape(
352     const class Shape& shape) {
353   return ::testing::MakeMatcher(new ::xla::testing::HloShapeMatcher(shape));
354 }
Shape(absl::string_view shape)355 inline ::testing::Matcher<const ::xla::HloInstruction*> Shape(
356     absl::string_view shape) {
357   return ::testing::MakeMatcher(
358       new ::xla::testing::HloShapeMatcher(ParseShape(shape).ValueOrDie()));
359 }
ShapeWithLayout(const class Shape & shape)360 inline ::testing::Matcher<const ::xla::HloInstruction*> ShapeWithLayout(
361     const class Shape& shape) {
362   return ::testing::MakeMatcher(
363       new ::xla::testing::HloShapeAndLayoutMatcher(shape));
364 }
ShapeWithLayout(absl::string_view shape)365 inline ::testing::Matcher<const ::xla::HloInstruction*> ShapeWithLayout(
366     absl::string_view shape) {
367   return ::testing::MakeMatcher(new ::xla::testing::HloShapeAndLayoutMatcher(
368       ParseShape(shape).ValueOrDie()));
369 }
370 
371 // Verifies the value of the HloSharing against the provided sharding object.
Sharding(const HloSharding & sharding)372 inline ::testing::Matcher<const ::xla::HloInstruction*> Sharding(
373     const HloSharding& sharding) {
374   return ::testing::MakeMatcher(
375       new ::xla::testing::HloShardingMatcher(sharding));
376 }
377 // Matcher for Sharding from sharding string
Sharding(absl::string_view sharding)378 inline ::testing::Matcher<const ::xla::HloInstruction*> Sharding(
379     absl::string_view sharding) {
380   return ::testing::MakeMatcher(new ::xla::testing::HloShardingMatcher(
381       ParseSharding(sharding).ValueOrDie()));
382 }
383 // Verifies that no HloSharding is set for an HLO instruction.
NoSharding()384 inline ::testing::Matcher<const ::xla::HloInstruction*> NoSharding() {
385   return ::testing::MakeMatcher(
386       new ::xla::testing::HloShardingMatcher(absl::nullopt));
387 }
388 
Dot(::testing::Matcher<const HloInstruction * > lhs_matcher,::testing::Matcher<const HloInstruction * > rhs_matcher)389 inline ::testing::Matcher<const ::xla::HloInstruction*> Dot(
390     ::testing::Matcher<const HloInstruction*> lhs_matcher,
391     ::testing::Matcher<const HloInstruction*> rhs_matcher) {
392   return ::testing::MakeMatcher(new ::xla::testing::HloMatcher(
393       ::xla::HloOpcode::kDot, {lhs_matcher, rhs_matcher}));
394 }
395 
396 // Matches a Dot HLO instruction if it has exactly one lhs contracting dimension
397 // equal to `lhs_contracting_dim` and exactly one rhs contracting dimension
398 // equal to `rhs_contracting_dim`.
399 //
400 // Currently the HLO verifier rejects Dot operations with more than one
401 // contracting dimension (even though we can represent these in the
402 // DotDimensionNumbers proto) so there is no need to generalize this to support
403 // multiple contracting dimensions.
Dot(::testing::Matcher<const HloInstruction * > lhs_matcher,::testing::Matcher<const HloInstruction * > rhs_matcher,int64 lhs_contracting_dim,int64 rhs_contracting_dim)404 inline ::testing::Matcher<const ::xla::HloInstruction*> Dot(
405     ::testing::Matcher<const HloInstruction*> lhs_matcher,
406     ::testing::Matcher<const HloInstruction*> rhs_matcher,
407     int64 lhs_contracting_dim, int64 rhs_contracting_dim) {
408   return ::testing::MakeMatcher(
409       new ::xla::testing::HloDotWithContractingDimsMatcher(
410           lhs_matcher, rhs_matcher, lhs_contracting_dim, rhs_contracting_dim));
411 }
412 
413 #undef HLO_MATCHER
414 }  // namespace opcode_matchers
415 
416 // Helper to convert smart to raw pointers for matching.
417 template <typename Container>
Pointers(const Container & container)418 std::vector<const HloInstruction*> Pointers(const Container& container) {
419   std::vector<const HloInstruction*> result;
420   result.reserve(container.size());
421   for (const auto& entry : container) result.push_back(entry.get());
422   return result;
423 }
424 
425 }  // namespace testing
426 
427 // Tell GMock to print HloInstruction* by value, so error messages are nice.
428 // Has to be in the same namespace as 'HloInstruction'.
429 void PrintTo(const HloInstruction* inst, ::std::ostream* os);
430 
431 }  // namespace xla
432 
433 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MATCHERS_H_
434