1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
17 #include "tensorflow/compiler/xla/literal_util.h"
18 #include "tensorflow/compiler/xla/service/hlo_parser.h"
19 #include "tensorflow/compiler/xla/shape_util.h"
20 
21 namespace op = xla::testing::opcode_matchers;
22 using ::testing::_;
23 using ::testing::Eq;
24 
25 namespace xla {
26 namespace {
27 
DescribeHloMatcher(const::testing::Matcher<const HloInstruction * > & m)28 string DescribeHloMatcher(const ::testing::Matcher<const HloInstruction*>& m) {
29   std::stringstream ss;
30   m.DescribeTo(&ss);
31   return ss.str();
32 }
33 
34 template <typename M, typename T>
Explain(const T & t,const M & m)35 string Explain(const T& t, const M& m) {
36   ::testing::StringMatchResultListener listener;
37   EXPECT_THAT(t, ::testing::Not(m));  // For the error message.
38   EXPECT_FALSE(m.MatchAndExplain(t, &listener));
39   return listener.str();
40 }
41 
// Builds add(param, multiply(param, param)) and checks positive structural
// matches followed by the explanation strings produced for mismatches.
TEST(HloMatchersTest, Test) {
  auto f32_vec = ShapeUtil::MakeShape(F32, {1});
  auto param_instr = HloInstruction::CreateParameter(0, f32_vec, "param");
  auto mul_instr = HloInstruction::CreateBinary(
      f32_vec, HloOpcode::kMultiply, param_instr.get(), param_instr.get());
  auto add_instr = HloInstruction::CreateBinary(
      f32_vec, HloOpcode::kAdd, param_instr.get(), mul_instr.get());
  HloInstruction* root = add_instr.get();

  // Positive matches, from opcode-only to nested operand patterns.
  EXPECT_THAT(root, op::Add());
  EXPECT_THAT(root, op::Add(op::Parameter(), op::Multiply()));
  EXPECT_THAT(root,
              op::Add(op::Parameter(), op::Multiply(_, op::Parameter())));

  // Opcode mismatch: an empty explanation is expected.
  EXPECT_THAT(Explain(root, op::Parameter()), Eq(""));

  // Operand-count mismatch.
  EXPECT_THAT(Explain(root, op::Add(op::Parameter())),
              Eq("has too many operands (got 2, want 1)"));

  // Operand mismatch one level below the root.
  EXPECT_THAT(
      Explain(root, op::Add(op::Parameter(), op::Parameter())),
      Eq("\noperand 1:\n\t"
         "%multiply = f32[1]{0} multiply(f32[1]{0} %param, f32[1]{0} %param)\n"
         "doesn't match expected:\n\t"
         "parameter"));

  // Operand mismatch two levels below the root; the explanations nest.
  EXPECT_THAT(
      Explain(root,
              op::Add(op::Parameter(), op::Multiply(op::Add(), op::Add()))),
      Eq("\noperand 1:\n\t"
         "%multiply = f32[1]{0} multiply(f32[1]{0} %param, f32[1]{0} %param)\n"
         "doesn't match expected:\n\t"
         "multiply(add, add), \n"
         "operand 0:\n\t"
         "%param = f32[1]{0} parameter(0)\n"
         "doesn't match expected:\n\t"
         "add"));
}
77 
// Checks op::CustomCall with no arguments, with operand matchers, with a
// call-target matcher (literal or gMock string matcher), and with both, then
// pins the explanation/description text for call-target matchers.
TEST(HloMatchersTest, CustomCallMatcher) {
  auto f32_const =
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({1, 2, 3}));
  auto s32_const =
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({1, 2, 3}));
  auto custom_call = HloInstruction::CreateCustomCall(
      ShapeUtil::MakeShape(F32, {1}), {f32_const.get(), s32_const.get()},
      "foo_target");
  HloInstruction* instr = custom_call.get();

  // Opcode only, then operands, then call target, then target plus operands.
  EXPECT_THAT(instr, op::CustomCall());
  EXPECT_THAT(instr, op::CustomCall(f32_const.get(), s32_const.get()));
  EXPECT_THAT(instr, op::CustomCall("foo_target"));
  EXPECT_THAT(instr,
              op::CustomCall("foo_target", f32_const.get(), s32_const.get()));

  // The call target may also be checked with arbitrary string matchers.
  EXPECT_THAT(instr, op::CustomCall(::testing::StartsWith("foo")));
  EXPECT_THAT(instr,
              op::CustomCall(::testing::Not(::testing::StartsWith("bar"))));

  // Negative cases: wrong number of operands, then a non-matching target.
  EXPECT_THAT(instr, ::testing::Not(op::CustomCall(f32_const.get())));
  EXPECT_THAT(instr,
              ::testing::Not(op::CustomCall(::testing::StartsWith("bar"))));

  EXPECT_THAT(Explain(instr, op::CustomCall("bar")),
              R"(custom-call with call target that isn't equal to "bar")");
  EXPECT_THAT(DescribeHloMatcher(op::CustomCall("foo_target")),
              R"(custom-call with call target that is equal to "foo_target")");
}
106 
// Checks op::Shape and op::ShapeWithLayout, given either a Shape or its text
// form, against an f32[5,7] parameter whose layout is {0,1}. A layout-less
// expected shape is accepted by op::Shape but rejected by
// op::ShapeWithLayout. The last two checks pin the mismatch explanations.
TEST(HloMatchersTest, ShapeMatcher) {
  auto param = HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShapeWithLayout(F32, {5, 7}, {0, 1}), "param");
  HloInstruction* instr = param.get();

  const auto f32_5x7 = ShapeUtil::MakeShape(F32, {5, 7});
  const auto f32_7x5 = ShapeUtil::MakeShape(F32, {7, 5});
  const auto f32_5x7_layout01 =
      ShapeUtil::MakeShapeWithLayout(F32, {5, 7}, {0, 1});
  const auto f32_5x7_layout10 =
      ShapeUtil::MakeShapeWithLayout(F32, {5, 7}, {1, 0});
  const auto f32_7x5_layout10 =
      ShapeUtil::MakeShapeWithLayout(F32, {7, 5}, {1, 0});

  // Same dimensions, no expected layout.
  EXPECT_THAT(instr, op::Shape(f32_5x7));
  EXPECT_THAT(instr, op::Shape("f32[5,7]"));
  EXPECT_THAT(instr, ::testing::Not(op::ShapeWithLayout(f32_5x7)));
  EXPECT_THAT(instr, ::testing::Not(op::ShapeWithLayout("f32[5,7]")));

  // Transposed dimensions do not match either matcher.
  EXPECT_THAT(instr, ::testing::Not(op::Shape(f32_7x5)));
  EXPECT_THAT(instr, ::testing::Not(op::Shape("f32[7,5]")));
  EXPECT_THAT(instr, ::testing::Not(op::ShapeWithLayout(f32_7x5)));
  EXPECT_THAT(instr, ::testing::Not(op::ShapeWithLayout("f32[7,5]")));

  // Same dimensions and the parameter's own {0,1} layout.
  EXPECT_THAT(instr, op::Shape(f32_5x7_layout01));
  EXPECT_THAT(instr, op::Shape("f32[5,7]{0,1}"));
  EXPECT_THAT(instr, op::ShapeWithLayout(f32_5x7_layout01));
  EXPECT_THAT(instr, op::ShapeWithLayout("f32[5,7]{0,1}"));

  // Same dimensions but a different layout.
  EXPECT_THAT(instr, ::testing::Not(op::ShapeWithLayout(f32_5x7_layout10)));
  EXPECT_THAT(instr, ::testing::Not(op::ShapeWithLayout("f32[5,7]{1,0}")));

  // Mismatch explanations print the instruction and the expected shape.
  EXPECT_THAT(Explain(instr, op::Shape(f32_7x5)),
              "%param = f32[5,7]{0,1} parameter(0) has incorrect shape "
              "(expected: f32[7,5])");
  EXPECT_THAT(Explain(instr, op::ShapeWithLayout(f32_7x5_layout10)),
              "%param = f32[5,7]{0,1} parameter(0) has incorrect shape "
              "(expected: f32[7,5]{1,0})");
}
144 
// Checks op::NoSharding and op::Sharding (given an HloSharding or its text
// form) on three parameters: one with no sharding, one assigned to device 1,
// and a tuple-shaped one with a per-element tuple sharding. Also pins the
// explanations produced on mismatch.
TEST(HloMatchersTest, ShardingMatcher) {
  auto p0 = HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {5}),
                                            "param.0");
  p0->clear_sharding();
  auto p1 = HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(F32, {7}),
                                            "param.1");
  p1->set_sharding(HloSharding::AssignDevice(1));

  auto tuple_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {7}), ShapeUtil::MakeShape(S32, {9}),
       ShapeUtil::MakeShape(F32, {11})});
  // Parameter number 2 matches the "param.2" name and stays distinct from
  // p1's parameter number.
  auto p2 = HloInstruction::CreateParameter(2, tuple_shape, "param.2");
  // Tile assignment over devices {0, 1} for the first tuple element.
  Array<int64> assignment({2});
  assignment.SetValues({0, 1});
  auto sharding = HloSharding::Tuple(
      tuple_shape, {HloSharding::Tile(assignment), HloSharding::AssignDevice(1),
                    HloSharding::Replicate()});
  p2->set_sharding(sharding);

  EXPECT_THAT(p0.get(), op::NoSharding());
  EXPECT_THAT(p0.get(),
              ::testing::Not(op::Sharding(HloSharding::AssignDevice(1))));
  EXPECT_THAT(p1.get(), ::testing::Not(op::NoSharding()));
  EXPECT_THAT(p1.get(),
              ::testing::Not(op::Sharding(HloSharding::AssignDevice(0))));
  EXPECT_THAT(p1.get(), op::Sharding(HloSharding::AssignDevice(1)));

  // The tuple sharding can be matched through its textual form.
  EXPECT_THAT(
      p2.get(),
      op::Sharding("{{devices=[2]0,1}, {maximal device=1}, {replicated}}"));

  EXPECT_THAT(Explain(p0.get(), op::Sharding(HloSharding::AssignDevice(1))),
              "%param.0 = f32[5]{0} parameter(0) has no sharding (expected: "
              "{maximal device=1})");
  EXPECT_THAT(Explain(p1.get(), op::NoSharding()),
              "%param.1 = f32[7]{0} parameter(1), sharding={maximal device=1} "
              "expected to have no sharding.");
  EXPECT_THAT(Explain(p1.get(), op::Sharding(HloSharding::AssignDevice(0))),
              "%param.1 = f32[7]{0} parameter(1), sharding={maximal device=1} "
              "has incorrect sharding (expected: {maximal device=0})");
}
186 
// Parses a module whose root is a dot, then checks op::Dot against its
// operands and contracting dimensions. Also pins the explanation text for a
// wrong lhs and a wrong rhs contracting dimension.
TEST(HloMatchersTest, DotMatcher) {
  const string module_text = R"(
HloModule DotOperationFusion_TransposeFusion

ENTRY DotOperationFusion_TransposeFusion {
  arg0 = f32[1,256] parameter(0)
  arg1 = f32[256,1024] parameter(1)
  ROOT dot = f32[1,1024] dot(arg0, arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> parsed_module,
                          ParseHloString(module_text));
  HloInstruction* dot_instr =
      parsed_module->entry_computation()->root_instruction();

  // Exact match: lhs contracts dimension 1, rhs contracts dimension 0.
  EXPECT_THAT(dot_instr, op::Dot(op::Parameter(0), op::Parameter(1),
                                 /*lhs_contracting_dim=*/1,
                                 /*rhs_contracting_dim=*/0));

  // Wrong lhs contracting dimension.
  EXPECT_THAT(
      Explain(dot_instr, op::Dot(op::Parameter(0), op::Parameter(1),
                                 /*lhs_contracting_dim=*/0,
                                 /*rhs_contracting_dim=*/0)),
      "%dot = f32[1,1024]{1,0} dot(f32[1,256]{1,0} %arg0, f32[256,1024]{1,0} "
      "%arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0} has wrong "
      "lhs_contracting_dimensions (got {1} want {0})");

  // Wrong rhs contracting dimension.
  EXPECT_THAT(
      Explain(dot_instr, op::Dot(op::Parameter(0), op::Parameter(1),
                                 /*lhs_contracting_dim=*/1,
                                 /*rhs_contracting_dim=*/1)),
      "%dot = f32[1,1024]{1,0} dot(f32[1,256]{1,0} %arg0, f32[256,1024]{1,0} "
      "%arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0} has wrong "
      "rhs_contracting_dimensions (got {0} want {1})");
}
222 
// Checks the generic op::Compare matcher and the direction-specific matchers
// (op::Eq, op::Ne, op::Le), with and without operand matchers. Also pins the
// explanations for an opcode mismatch and a comparison-direction mismatch.
TEST(HloMatchersTest, ComparisonMatcher) {
  auto shape = ShapeUtil::MakeShape(F32, {1});
  // XLA comparisons produce PRED values, so the compare instructions get a
  // PRED result shape rather than reusing the F32 operand shape.
  auto pred_shape = ShapeUtil::MakeShape(PRED, {1});
  auto p0 = HloInstruction::CreateParameter(0, shape, "param.0");
  auto p1 = HloInstruction::CreateParameter(1, shape, "param.1");
  auto eq = HloInstruction::CreateCompare(pred_shape, p0.get(), p1.get(),
                                          ComparisonDirection::kEq);
  auto ne = HloInstruction::CreateCompare(pred_shape, p0.get(), p1.get(),
                                          ComparisonDirection::kNe);
  auto add =
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p0.get(), p1.get());
  auto le = HloInstruction::CreateCompare(pred_shape, p0.get(), add.get(),
                                          ComparisonDirection::kLe);

  EXPECT_THAT(eq.get(), op::Compare());
  EXPECT_THAT(eq.get(), op::Eq());
  EXPECT_THAT(ne.get(), op::Compare());
  EXPECT_THAT(ne.get(), op::Ne());
  EXPECT_THAT(le.get(),
              op::Compare(op::Parameter(0),
                          op::Add(op::Parameter(0), op::Parameter(1))));
  EXPECT_THAT(le.get(), op::Le(op::Parameter(0),
                               op::Add(op::Parameter(0), op::Parameter(1))));

  // Opcode mismatch yields an empty explanation; a direction mismatch names
  // both the actual and the expected direction.
  EXPECT_THAT(Explain(eq.get(), op::Add()), Eq(""));
  EXPECT_THAT(Explain(eq.get(), op::Ne()),
              Eq("has wrong comparison direction (got EQ, want NE)"));
}
250 
251 }  // namespace
252 }  // namespace xla
253