1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
17
18 #include "absl/strings/str_join.h"
19 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
20 #include "tensorflow/compiler/xla/test.h"
21
22 namespace xla {
23 namespace testing {
24
MatchAndExplain(const HloInstruction * instruction,::testing::MatchResultListener * listener) const25 bool HloMatcher::MatchAndExplain(
26 const HloInstruction* instruction,
27 ::testing::MatchResultListener* listener) const {
28 // These cases are self-explanatory from the printed value.
29 if (!instruction || instruction->opcode() != opcode_) {
30 return false;
31 }
32 // Special case: no operand matchers means don't verify.
33 if (operands_.empty()) {
34 return true;
35 }
36 const auto& operands = instruction->operands();
37 if (operands.size() != operands_.size()) {
38 *listener << "has too "
39 << (operands.size() > operands_.size() ? "many" : "few")
40 << " operands (got " << operands.size() << ", want "
41 << operands_.size() << ")";
42 return false;
43 }
44 for (int index = 0; index < operands.size(); index++) {
45 ::testing::StringMatchResultListener inner_listener;
46 if (!operands_[index].MatchAndExplain(operands[index], &inner_listener)) {
47 if (listener->IsInterested()) {
48 *listener << "\noperand " << index << ":\n\t"
49 << operands[index]->ToString()
50 << "\ndoesn't match expected:\n\t";
51 operands_[index].DescribeTo(listener->stream());
52 string explanation = inner_listener.str();
53 if (!explanation.empty()) {
54 *listener << ", " << explanation;
55 }
56 }
57 return false;
58 }
59 }
60 return true;
61 }
62
DescribeTo(::std::ostream * os) const63 void HloMatcher::DescribeTo(::std::ostream* os) const {
64 *os << opcode_;
65 if (!operands_.empty()) {
66 *os << "(";
67 for (int i = 0; i < operands_.size(); i++) {
68 if (i > 0) {
69 *os << ", ";
70 }
71 operands_[i].DescribeTo(os);
72 }
73 *os << ")";
74 }
75 }
76
MatchAndExplain(const HloInstruction * instruction,::testing::MatchResultListener * listener) const77 bool HloParameterMatcher::MatchAndExplain(
78 const HloInstruction* instruction,
79 ::testing::MatchResultListener* listener) const {
80 if (!HloMatcher::MatchAndExplain(instruction, listener)) {
81 return false;
82 }
83 if (instruction->parameter_number() != parameter_number_) {
84 *listener << "has wrong parameter number (got "
85 << instruction->parameter_number() << ", want "
86 << parameter_number_ << ")";
87 return false;
88 }
89 return true;
90 }
91
MatchAndExplain(const HloInstruction * instruction,::testing::MatchResultListener * listener) const92 bool HloComparisonMatcher::MatchAndExplain(
93 const HloInstruction* instruction,
94 ::testing::MatchResultListener* listener) const {
95 if (!HloMatcher::MatchAndExplain(instruction, listener)) {
96 return false;
97 }
98 if (instruction->comparison_direction() != direction_) {
99 *listener << "has wrong comparison direction (got "
100 << ComparisonDirectionToString(
101 instruction->comparison_direction())
102 << ", want " << ComparisonDirectionToString(direction_) << ")";
103 return false;
104 }
105 return true;
106 }
107
MatchAndExplain(const HloInstruction * instruction,::testing::MatchResultListener * listener) const108 bool HloGetTupleElementMatcher::MatchAndExplain(
109 const HloInstruction* instruction,
110 ::testing::MatchResultListener* listener) const {
111 if (!HloMatcher::MatchAndExplain(instruction, listener)) {
112 return false;
113 }
114 if (instruction->tuple_index() != tuple_index_) {
115 *listener << "has wrong tuple index (got " << instruction->tuple_index()
116 << ", want " << tuple_index_ << ")";
117 return false;
118 }
119 return true;
120 }
121
DescribeTo(std::ostream * os) const122 void HloCustomCallMatcher::DescribeTo(std::ostream* os) const {
123 HloMatcher::DescribeTo(os);
124 *os << " with call target that ";
125 call_target_matcher_.DescribeTo(os);
126 }
127
MatchAndExplain(const HloInstruction * instruction,::testing::MatchResultListener * listener) const128 bool HloCustomCallMatcher::MatchAndExplain(
129 const HloInstruction* instruction,
130 ::testing::MatchResultListener* listener) const {
131 if (!HloMatcher::MatchAndExplain(instruction, listener)) {
132 return false;
133 }
134 ::testing::StringMatchResultListener sub_listener;
135 bool result = ExplainMatchResult(
136 call_target_matcher_, instruction->custom_call_target(), &sub_listener);
137 if (sub_listener.str().empty()) {
138 sub_listener << " that ";
139
140 std::stringstream desc_stream;
141 if (result) {
142 call_target_matcher_.DescribeTo(&desc_stream);
143 } else {
144 call_target_matcher_.DescribeNegationTo(&desc_stream);
145 }
146 sub_listener << desc_stream.str();
147 }
148 *listener << "custom-call with call target" << sub_listener.str();
149 return result;
150 }
151
MatchAndExplain(const HloInstruction * instruction,::testing::MatchResultListener * listener) const152 bool HloShapeMatcher::MatchAndExplain(
153 const HloInstruction* instruction,
154 ::testing::MatchResultListener* listener) const {
155 if (ShapeUtil::Compatible(instruction->shape(), shape_)) {
156 return true;
157 }
158 *listener << instruction->ToString() << " has incorrect shape (expected: "
159 << ShapeUtil::HumanString(shape_) << ")";
160 return false;
161 }
162
DescribeTo(std::ostream * os) const163 void HloShapeMatcher::DescribeTo(std::ostream* os) const {
164 *os << ShapeUtil::HumanString(shape_);
165 }
166
MatchAndExplain(const HloInstruction * instruction,::testing::MatchResultListener * listener) const167 bool HloShapeAndLayoutMatcher::MatchAndExplain(
168 const HloInstruction* instruction,
169 ::testing::MatchResultListener* listener) const {
170 if (ShapeUtil::Equal(instruction->shape(), shape_)) {
171 return true;
172 }
173 *listener << instruction->ToString() << " has incorrect shape (expected: "
174 << ShapeUtil::HumanStringWithLayout(shape_) << ")";
175 return false;
176 }
177
DescribeTo(std::ostream * os) const178 void HloShapeAndLayoutMatcher::DescribeTo(std::ostream* os) const {
179 *os << ShapeUtil::HumanStringWithLayout(shape_);
180 }
181
MatchAndExplain(const HloInstruction * instruction,::testing::MatchResultListener * listener) const182 bool HloShardingMatcher::MatchAndExplain(
183 const HloInstruction* instruction,
184 ::testing::MatchResultListener* listener) const {
185 if (!sharding_.has_value()) {
186 if (!instruction->has_sharding()) {
187 return true;
188 }
189 *listener << instruction->ToString() << " expected to have no sharding.";
190 return false;
191 }
192 if (instruction->has_sharding()) {
193 if (instruction->sharding() == sharding_.value()) {
194 return true;
195 }
196 *listener << instruction->ToString()
197 << " has incorrect sharding (expected: " << sharding_->ToString()
198 << ")";
199 return false;
200 } else {
201 *listener << instruction->ToString()
202 << " has no sharding (expected: " << sharding_->ToString() << ")";
203 return false;
204 }
205 }
206
DescribeTo(std::ostream * os) const207 void HloShardingMatcher::DescribeTo(std::ostream* os) const {
208 if (sharding_.has_value()) {
209 *os << sharding_->ToString();
210 } else {
211 *os << "<no-sharding>";
212 }
213 }
214
MatchAndExplain(const HloInstruction * instruction,::testing::MatchResultListener * listener) const215 bool HloDotWithContractingDimsMatcher::MatchAndExplain(
216 const HloInstruction* instruction,
217 ::testing::MatchResultListener* listener) const {
218 if (!HloMatcher::MatchAndExplain(instruction, listener)) {
219 return false;
220 }
221
222 const DotDimensionNumbers& dim_nums = instruction->dot_dimension_numbers();
223 if (dim_nums.lhs_contracting_dimensions_size() != 1 ||
224 dim_nums.lhs_contracting_dimensions(0) != lhs_contracting_dim_) {
225 *listener << instruction->ToString()
226 << " has wrong lhs_contracting_dimensions (got {"
227 << absl::StrJoin(dim_nums.lhs_contracting_dimensions(), ",")
228 << "} want {" << lhs_contracting_dim_ << "})";
229 return false;
230 }
231
232 if (dim_nums.rhs_contracting_dimensions_size() != 1 ||
233 dim_nums.rhs_contracting_dimensions(0) != rhs_contracting_dim_) {
234 *listener << instruction->ToString()
235 << " has wrong rhs_contracting_dimensions (got {"
236 << absl::StrJoin(dim_nums.rhs_contracting_dimensions(), ",")
237 << "} want {" << rhs_contracting_dim_ << "})";
238 return false;
239 }
240
241 return true;
242 }
243
DescribeTo(std::ostream * os) const244 void HloDotWithContractingDimsMatcher::DescribeTo(std::ostream* os) const {
245 HloMatcher::DescribeTo(os);
246 *os << " with lhs_contracting_dims={" << lhs_contracting_dim_
247 << "} and rhs_contracting_dims={" << rhs_contracting_dim_ << "}";
248 }
249
250 } // namespace testing
251
PrintTo(const HloInstruction * inst,::std::ostream * os)252 void PrintTo(const HloInstruction* inst, ::std::ostream* os) {
253 *os << (inst ? inst->ToString() : "nullptr");
254 }
255
256 } // namespace xla
257