1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
17 #include "tensorflow/compiler/xla/literal_util.h"
18 #include "tensorflow/compiler/xla/service/hlo_parser.h"
19 #include "tensorflow/compiler/xla/shape_util.h"
20
21 namespace op = xla::testing::opcode_matchers;
22 using ::testing::_;
23 using ::testing::Eq;
24
25 namespace xla {
26 namespace {
27
DescribeHloMatcher(const::testing::Matcher<const HloInstruction * > & m)28 string DescribeHloMatcher(const ::testing::Matcher<const HloInstruction*>& m) {
29 std::stringstream ss;
30 m.DescribeTo(&ss);
31 return ss.str();
32 }
33
34 template <typename M, typename T>
Explain(const T & t,const M & m)35 string Explain(const T& t, const M& m) {
36 ::testing::StringMatchResultListener listener;
37 EXPECT_THAT(t, ::testing::Not(m)); // For the error message.
38 EXPECT_FALSE(m.MatchAndExplain(t, &listener));
39 return listener.str();
40 }
41
// Exercises the generic opcode matchers, with and without operand matchers, on
// a hand-built add(param, multiply(param, param)) graph of standalone
// instructions. Also checks the explanations produced when a match fails.
TEST(HloMatchersTest, Test) {
  const auto shape = ShapeUtil::MakeShape(F32, {1});
  auto param = HloInstruction::CreateParameter(0, shape, "param");
  HloInstruction* const param_ptr = param.get();
  auto mul = HloInstruction::CreateBinary(shape, HloOpcode::kMultiply,
                                          param_ptr, param_ptr);
  auto add = HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param_ptr,
                                          mul.get());

  // Positive matches at increasing levels of operand detail.
  EXPECT_THAT(add.get(), op::Add());
  EXPECT_THAT(add.get(), op::Add(op::Parameter(), op::Multiply()));
  EXPECT_THAT(add.get(),
              op::Add(op::Parameter(), op::Multiply(_, op::Parameter())));

  // Negative matches: check the explanation string.
  // An opcode mismatch yields an empty explanation.
  EXPECT_THAT(Explain(add.get(), op::Parameter()), Eq(""));
  // An operand-count mismatch reports both counts.
  EXPECT_THAT(Explain(add.get(), op::Add(op::Parameter())),
              Eq("has too many operands (got 2, want 1)"));
  // An operand mismatch prints the offending operand and the expected matcher.
  EXPECT_THAT(
      Explain(add.get(), op::Add(op::Parameter(), op::Parameter())),
      Eq("\noperand 1:\n\t"
         "%multiply = f32[1]{0} multiply(f32[1]{0} %param, f32[1]{0} %param)\n"
         "doesn't match expected:\n\t"
         "parameter"));
  // A nested operand mismatch explains each level it descends through.
  EXPECT_THAT(
      Explain(add.get(),
              op::Add(op::Parameter(), op::Multiply(op::Add(), op::Add()))),
      Eq("\noperand 1:\n\t"
         "%multiply = f32[1]{0} multiply(f32[1]{0} %param, f32[1]{0} %param)\n"
         "doesn't match expected:\n\t"
         "multiply(add, add), \n"
         "operand 0:\n\t"
         "%param = f32[1]{0} parameter(0)\n"
         "doesn't match expected:\n\t"
         "add"));
}
77
// Covers op::CustomCall in its operand-only, call-target-only and combined
// forms, and with arbitrary string matchers applied to the call target. Also
// checks the description and explanation text of the call-target matcher.
TEST(HloMatchersTest, CustomCallMatcher) {
  const auto result_shape = ShapeUtil::MakeShape(F32, {1});
  auto c1 =
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({1, 2, 3}));
  auto c2 =
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({1, 2, 3}));
  auto call = HloInstruction::CreateCustomCall(
      result_shape, {c1.get(), c2.get()}, "foo_target");

  const auto starts_with_foo = ::testing::StartsWith("foo");
  const auto starts_with_bar = ::testing::StartsWith("bar");

  // Matching forms.
  EXPECT_THAT(call.get(), op::CustomCall());
  EXPECT_THAT(call.get(), op::CustomCall(c1.get(), c2.get()));
  EXPECT_THAT(call.get(), op::CustomCall("foo_target"));
  EXPECT_THAT(call.get(), op::CustomCall("foo_target", c1.get(), c2.get()));
  EXPECT_THAT(call.get(), op::CustomCall(starts_with_foo));
  EXPECT_THAT(call.get(), op::CustomCall(::testing::Not(starts_with_bar)));

  // Wrong number of operands.
  EXPECT_THAT(call.get(), ::testing::Not(op::CustomCall(c1.get())));

  // Call target does not match.
  EXPECT_THAT(call.get(), ::testing::Not(op::CustomCall(starts_with_bar)));

  // Explanation on mismatch, and description of the matcher itself.
  EXPECT_THAT(Explain(call.get(), op::CustomCall("bar")),
              R"(custom-call with call target that isn't equal to "bar")");
  EXPECT_THAT(DescribeHloMatcher(op::CustomCall("foo_target")),
              R"(custom-call with call target that is equal to "foo_target")");
}
106
// Checks op::Shape and op::ShapeWithLayout against a parameter of shape
// f32[5,7] with layout {0,1}. The expectations show that ShapeWithLayout
// additionally requires the layout to match. Also checks the mismatch
// explanations.
TEST(HloMatchersTest, ShapeMatcher) {
  auto p0 = HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShapeWithLayout(F32, {5, 7}, {0, 1}), "param");

  const auto f32_5x7 = ShapeUtil::MakeShape(F32, {5, 7});
  const auto f32_7x5 = ShapeUtil::MakeShape(F32, {7, 5});
  const auto f32_5x7_layout01 =
      ShapeUtil::MakeShapeWithLayout(F32, {5, 7}, {0, 1});
  const auto f32_5x7_layout10 =
      ShapeUtil::MakeShapeWithLayout(F32, {5, 7}, {1, 0});

  // Same dimensions without the parameter's layout: only op::Shape matches.
  EXPECT_THAT(p0.get(), op::Shape(f32_5x7));
  EXPECT_THAT(p0.get(), op::Shape("f32[5,7]"));
  EXPECT_THAT(p0.get(), ::testing::Not(op::ShapeWithLayout(f32_5x7)));
  EXPECT_THAT(p0.get(), ::testing::Not(op::ShapeWithLayout("f32[5,7]")));

  // Transposed dimensions match neither matcher.
  EXPECT_THAT(p0.get(), ::testing::Not(op::Shape(f32_7x5)));
  EXPECT_THAT(p0.get(), ::testing::Not(op::Shape("f32[7,5]")));
  EXPECT_THAT(p0.get(), ::testing::Not(op::ShapeWithLayout(f32_7x5)));
  EXPECT_THAT(p0.get(), ::testing::Not(op::ShapeWithLayout("f32[7,5]")));

  // The exact layout {0,1} matches both; layout {1,0} fails ShapeWithLayout.
  EXPECT_THAT(p0.get(), op::Shape(f32_5x7_layout01));
  EXPECT_THAT(p0.get(), op::Shape("f32[5,7]{0,1}"));
  EXPECT_THAT(p0.get(), op::ShapeWithLayout(f32_5x7_layout01));
  EXPECT_THAT(p0.get(), op::ShapeWithLayout("f32[5,7]{0,1}"));
  EXPECT_THAT(p0.get(), ::testing::Not(op::ShapeWithLayout(f32_5x7_layout10)));
  EXPECT_THAT(p0.get(), ::testing::Not(op::ShapeWithLayout("f32[5,7]{1,0}")));

  // Mismatch explanations print the instruction and the expected shape.
  EXPECT_THAT(Explain(p0.get(), op::Shape(ShapeUtil::MakeShape(F32, {7, 5}))),
              "%param = f32[5,7]{0,1} parameter(0) has incorrect shape "
              "(expected: f32[7,5])");
  EXPECT_THAT(
      Explain(p0.get(), op::ShapeWithLayout(ShapeUtil::MakeShapeWithLayout(
                            F32, {7, 5}, {1, 0}))),
      "%param = f32[5,7]{0,1} parameter(0) has incorrect shape "
      "(expected: f32[7,5]{1,0})");
}
144
// Checks op::NoSharding and op::Sharding on three parameters: one with no
// sharding, one with a maximal sharding and one with a tuple sharding. The
// tuple case uses the string form of the expected sharding. Also checks the
// mismatch explanations.
TEST(HloMatchersTest, ShardingMatcher) {
  auto p0 = HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {5}),
                                            "param.0");
  p0->clear_sharding();
  auto p1 = HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(F32, {7}),
                                            "param.1");
  p1->set_sharding(HloSharding::AssignDevice(1));

  auto tuple_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {7}), ShapeUtil::MakeShape(S32, {9}),
       ShapeUtil::MakeShape(F32, {11})});
  // Parameter number 2 agrees with the "param.2" name and does not collide
  // with p1's parameter number.
  auto p2 = HloInstruction::CreateParameter(2, tuple_shape, "param.2");
  // Element 0 is tiled over devices {0, 1}, element 1 is placed on device 1
  // and element 2 is replicated (see the string form checked below).
  Array<int64> assignment({2});
  assignment.SetValues({0, 1});
  auto sharding = HloSharding::Tuple(
      tuple_shape, {HloSharding::Tile(assignment), HloSharding::AssignDevice(1),
                    HloSharding::Replicate()});
  p2->set_sharding(sharding);

  EXPECT_THAT(p0.get(), op::NoSharding());
  EXPECT_THAT(p0.get(),
              ::testing::Not(op::Sharding(HloSharding::AssignDevice(1))));
  EXPECT_THAT(p1.get(), ::testing::Not(op::NoSharding()));
  EXPECT_THAT(p1.get(),
              ::testing::Not(op::Sharding(HloSharding::AssignDevice(0))));
  EXPECT_THAT(p1.get(), op::Sharding(HloSharding::AssignDevice(1)));

  EXPECT_THAT(
      p2.get(),
      op::Sharding("{{devices=[2]0,1}, {maximal device=1}, {replicated}}"));

  EXPECT_THAT(Explain(p0.get(), op::Sharding(HloSharding::AssignDevice(1))),
              "%param.0 = f32[5]{0} parameter(0) has no sharding (expected: "
              "{maximal device=1})");
  EXPECT_THAT(Explain(p1.get(), op::NoSharding()),
              "%param.1 = f32[7]{0} parameter(1), sharding={maximal device=1} "
              "expected to have no sharding.");
  EXPECT_THAT(Explain(p1.get(), op::Sharding(HloSharding::AssignDevice(0))),
              "%param.1 = f32[7]{0} parameter(1), sharding={maximal device=1} "
              "has incorrect sharding (expected: {maximal device=0})");
}
186
// Parses a module whose root is a dot, then checks op::Dot's operand and
// contracting-dimension matching. Also checks the explanation produced when
// either side's contracting dimension is wrong.
TEST(HloMatchersTest, DotMatcher) {
  const string hlo_string = R"(
HloModule DotOperationFusion_TransposeFusion

ENTRY DotOperationFusion_TransposeFusion {
  arg0 = f32[1,256] parameter(0)
  arg1 = f32[256,1024] parameter(1)
  ROOT dot = f32[1,1024] dot(arg0, arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(hlo_string));
  HloInstruction* const root = module->entry_computation()->root_instruction();

  // The contracting dimensions written in the module text match.
  const auto expected_dot = op::Dot(op::Parameter(0), op::Parameter(1),
                                    /*lhs_contracting_dim=*/1,
                                    /*rhs_contracting_dim=*/0);
  EXPECT_THAT(root, expected_dot);

  // Wrong lhs contracting dimension.
  EXPECT_THAT(
      Explain(root, op::Dot(op::Parameter(0), op::Parameter(1),
                            /*lhs_contracting_dim=*/0,
                            /*rhs_contracting_dim=*/0)),
      "%dot = f32[1,1024]{1,0} dot(f32[1,256]{1,0} %arg0, f32[256,1024]{1,0} "
      "%arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0} has wrong "
      "lhs_contracting_dimensions (got {1} want {0})");

  // Wrong rhs contracting dimension.
  EXPECT_THAT(
      Explain(root, op::Dot(op::Parameter(0), op::Parameter(1),
                            /*lhs_contracting_dim=*/1,
                            /*rhs_contracting_dim=*/1)),
      "%dot = f32[1,1024]{1,0} dot(f32[1,256]{1,0} %arg0, f32[256,1024]{1,0} "
      "%arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0} has wrong "
      "rhs_contracting_dimensions (got {0} want {1})");
}
222
// Checks op::Compare and the direction-specific matchers (op::Eq, op::Ne,
// op::Le) on standalone compare instructions, with and without operand
// matchers. Also checks the explanation produced for a wrong comparison
// direction.
TEST(HloMatchersTest, ComparisonMatcher) {
  auto shape = ShapeUtil::MakeShape(F32, {1});
  // XLA comparisons yield booleans, so the compare instructions get a PRED
  // result shape with the same dimensions as their F32 operands.
  auto pred_shape = ShapeUtil::MakeShape(PRED, {1});
  auto p0 = HloInstruction::CreateParameter(0, shape, "param.0");
  auto p1 = HloInstruction::CreateParameter(1, shape, "param.1");
  auto eq = HloInstruction::CreateCompare(pred_shape, p0.get(), p1.get(),
                                          ComparisonDirection::kEq);
  auto ne = HloInstruction::CreateCompare(pred_shape, p0.get(), p1.get(),
                                          ComparisonDirection::kNe);
  auto add =
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p0.get(), p1.get());
  auto le = HloInstruction::CreateCompare(pred_shape, p0.get(), add.get(),
                                          ComparisonDirection::kLe);

  EXPECT_THAT(eq.get(), op::Compare());
  EXPECT_THAT(eq.get(), op::Eq());
  EXPECT_THAT(ne.get(), op::Compare());
  EXPECT_THAT(ne.get(), op::Ne());
  EXPECT_THAT(le.get(),
              op::Compare(op::Parameter(0),
                          op::Add(op::Parameter(0), op::Parameter(1))));
  EXPECT_THAT(le.get(), op::Le(op::Parameter(0),
                               op::Add(op::Parameter(0), op::Parameter(1))));

  // An opcode mismatch explains nothing; a direction mismatch names both
  // directions.
  EXPECT_THAT(Explain(eq.get(), op::Add()), Eq(""));
  EXPECT_THAT(Explain(eq.get(), op::Ne()),
              Eq("has wrong comparison direction (got EQ, want NE)"));
}
250
251 } // namespace
252 } // namespace xla
253