1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
17 
18 #include "absl/strings/str_join.h"
19 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
20 #include "tensorflow/compiler/xla/test.h"
21 
22 namespace xla {
23 namespace testing {
24 
MatchAndExplain(const HloInstruction * instruction,::testing::MatchResultListener * listener) const25 bool HloMatcher::MatchAndExplain(
26     const HloInstruction* instruction,
27     ::testing::MatchResultListener* listener) const {
28   // These cases are self-explanatory from the printed value.
29   if (!instruction || instruction->opcode() != opcode_) {
30     return false;
31   }
32   // Special case: no operand matchers means don't verify.
33   if (operands_.empty()) {
34     return true;
35   }
36   const auto& operands = instruction->operands();
37   if (operands.size() != operands_.size()) {
38     *listener << "has too "
39               << (operands.size() > operands_.size() ? "many" : "few")
40               << " operands (got " << operands.size() << ", want "
41               << operands_.size() << ")";
42     return false;
43   }
44   for (int index = 0; index < operands.size(); index++) {
45     ::testing::StringMatchResultListener inner_listener;
46     if (!operands_[index].MatchAndExplain(operands[index], &inner_listener)) {
47       if (listener->IsInterested()) {
48         *listener << "\noperand " << index << ":\n\t"
49                   << operands[index]->ToString()
50                   << "\ndoesn't match expected:\n\t";
51         operands_[index].DescribeTo(listener->stream());
52         string explanation = inner_listener.str();
53         if (!explanation.empty()) {
54           *listener << ", " << explanation;
55         }
56       }
57       return false;
58     }
59   }
60   return true;
61 }
62 
DescribeTo(::std::ostream * os) const63 void HloMatcher::DescribeTo(::std::ostream* os) const {
64   *os << opcode_;
65   if (!operands_.empty()) {
66     *os << "(";
67     for (int i = 0; i < operands_.size(); i++) {
68       if (i > 0) {
69         *os << ", ";
70       }
71       operands_[i].DescribeTo(os);
72     }
73     *os << ")";
74   }
75 }
76 
MatchAndExplain(const HloInstruction * instruction,::testing::MatchResultListener * listener) const77 bool HloParameterMatcher::MatchAndExplain(
78     const HloInstruction* instruction,
79     ::testing::MatchResultListener* listener) const {
80   if (!HloMatcher::MatchAndExplain(instruction, listener)) {
81     return false;
82   }
83   if (instruction->parameter_number() != parameter_number_) {
84     *listener << "has wrong parameter number (got "
85               << instruction->parameter_number() << ", want "
86               << parameter_number_ << ")";
87     return false;
88   }
89   return true;
90 }
91 
MatchAndExplain(const HloInstruction * instruction,::testing::MatchResultListener * listener) const92 bool HloComparisonMatcher::MatchAndExplain(
93     const HloInstruction* instruction,
94     ::testing::MatchResultListener* listener) const {
95   if (!HloMatcher::MatchAndExplain(instruction, listener)) {
96     return false;
97   }
98   if (instruction->comparison_direction() != direction_) {
99     *listener << "has wrong comparison direction (got "
100               << ComparisonDirectionToString(
101                      instruction->comparison_direction())
102               << ", want " << ComparisonDirectionToString(direction_) << ")";
103     return false;
104   }
105   return true;
106 }
107 
MatchAndExplain(const HloInstruction * instruction,::testing::MatchResultListener * listener) const108 bool HloGetTupleElementMatcher::MatchAndExplain(
109     const HloInstruction* instruction,
110     ::testing::MatchResultListener* listener) const {
111   if (!HloMatcher::MatchAndExplain(instruction, listener)) {
112     return false;
113   }
114   if (instruction->tuple_index() != tuple_index_) {
115     *listener << "has wrong tuple index (got " << instruction->tuple_index()
116               << ", want " << tuple_index_ << ")";
117     return false;
118   }
119   return true;
120 }
121 
DescribeTo(std::ostream * os) const122 void HloCustomCallMatcher::DescribeTo(std::ostream* os) const {
123   HloMatcher::DescribeTo(os);
124   *os << " with call target that ";
125   call_target_matcher_.DescribeTo(os);
126 }
127 
MatchAndExplain(const HloInstruction * instruction,::testing::MatchResultListener * listener) const128 bool HloCustomCallMatcher::MatchAndExplain(
129     const HloInstruction* instruction,
130     ::testing::MatchResultListener* listener) const {
131   if (!HloMatcher::MatchAndExplain(instruction, listener)) {
132     return false;
133   }
134   ::testing::StringMatchResultListener sub_listener;
135   bool result = ExplainMatchResult(
136       call_target_matcher_, instruction->custom_call_target(), &sub_listener);
137   if (sub_listener.str().empty()) {
138     sub_listener << " that ";
139 
140     std::stringstream desc_stream;
141     if (result) {
142       call_target_matcher_.DescribeTo(&desc_stream);
143     } else {
144       call_target_matcher_.DescribeNegationTo(&desc_stream);
145     }
146     sub_listener << desc_stream.str();
147   }
148   *listener << "custom-call with call target" << sub_listener.str();
149   return result;
150 }
151 
MatchAndExplain(const HloInstruction * instruction,::testing::MatchResultListener * listener) const152 bool HloShapeMatcher::MatchAndExplain(
153     const HloInstruction* instruction,
154     ::testing::MatchResultListener* listener) const {
155   if (ShapeUtil::Compatible(instruction->shape(), shape_)) {
156     return true;
157   }
158   *listener << instruction->ToString() << " has incorrect shape (expected: "
159             << ShapeUtil::HumanString(shape_) << ")";
160   return false;
161 }
162 
DescribeTo(std::ostream * os) const163 void HloShapeMatcher::DescribeTo(std::ostream* os) const {
164   *os << ShapeUtil::HumanString(shape_);
165 }
166 
MatchAndExplain(const HloInstruction * instruction,::testing::MatchResultListener * listener) const167 bool HloShapeAndLayoutMatcher::MatchAndExplain(
168     const HloInstruction* instruction,
169     ::testing::MatchResultListener* listener) const {
170   if (ShapeUtil::Equal(instruction->shape(), shape_)) {
171     return true;
172   }
173   *listener << instruction->ToString() << " has incorrect shape (expected: "
174             << ShapeUtil::HumanStringWithLayout(shape_) << ")";
175   return false;
176 }
177 
DescribeTo(std::ostream * os) const178 void HloShapeAndLayoutMatcher::DescribeTo(std::ostream* os) const {
179   *os << ShapeUtil::HumanStringWithLayout(shape_);
180 }
181 
MatchAndExplain(const HloInstruction * instruction,::testing::MatchResultListener * listener) const182 bool HloShardingMatcher::MatchAndExplain(
183     const HloInstruction* instruction,
184     ::testing::MatchResultListener* listener) const {
185   if (!sharding_.has_value()) {
186     if (!instruction->has_sharding()) {
187       return true;
188     }
189     *listener << instruction->ToString() << " expected to have no sharding.";
190     return false;
191   }
192   if (instruction->has_sharding()) {
193     if (instruction->sharding() == sharding_.value()) {
194       return true;
195     }
196     *listener << instruction->ToString()
197               << " has incorrect sharding (expected: " << sharding_->ToString()
198               << ")";
199     return false;
200   } else {
201     *listener << instruction->ToString()
202               << " has no sharding (expected: " << sharding_->ToString() << ")";
203     return false;
204   }
205 }
206 
DescribeTo(std::ostream * os) const207 void HloShardingMatcher::DescribeTo(std::ostream* os) const {
208   if (sharding_.has_value()) {
209     *os << sharding_->ToString();
210   } else {
211     *os << "<no-sharding>";
212   }
213 }
214 
MatchAndExplain(const HloInstruction * instruction,::testing::MatchResultListener * listener) const215 bool HloDotWithContractingDimsMatcher::MatchAndExplain(
216     const HloInstruction* instruction,
217     ::testing::MatchResultListener* listener) const {
218   if (!HloMatcher::MatchAndExplain(instruction, listener)) {
219     return false;
220   }
221 
222   const DotDimensionNumbers& dim_nums = instruction->dot_dimension_numbers();
223   if (dim_nums.lhs_contracting_dimensions_size() != 1 ||
224       dim_nums.lhs_contracting_dimensions(0) != lhs_contracting_dim_) {
225     *listener << instruction->ToString()
226               << " has wrong lhs_contracting_dimensions (got {"
227               << absl::StrJoin(dim_nums.lhs_contracting_dimensions(), ",")
228               << "} want {" << lhs_contracting_dim_ << "})";
229     return false;
230   }
231 
232   if (dim_nums.rhs_contracting_dimensions_size() != 1 ||
233       dim_nums.rhs_contracting_dimensions(0) != rhs_contracting_dim_) {
234     *listener << instruction->ToString()
235               << " has wrong rhs_contracting_dimensions (got {"
236               << absl::StrJoin(dim_nums.rhs_contracting_dimensions(), ",")
237               << "} want {" << rhs_contracting_dim_ << "})";
238     return false;
239   }
240 
241   return true;
242 }
243 
DescribeTo(std::ostream * os) const244 void HloDotWithContractingDimsMatcher::DescribeTo(std::ostream* os) const {
245   HloMatcher::DescribeTo(os);
246   *os << " with lhs_contracting_dims={" << lhs_contracting_dim_
247       << "} and rhs_contracting_dims={" << rhs_contracting_dim_ << "}";
248 }
249 
250 }  // namespace testing
251 
PrintTo(const HloInstruction * inst,::std::ostream * os)252 void PrintTo(const HloInstruction* inst, ::std::ostream* os) {
253   *os << (inst ? inst->ToString() : "nullptr");
254 }
255 
256 }  // namespace xla
257