1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/compiler/xla/service/hlo_evaluator.h"
16
17 #include <initializer_list>
18 #include <memory>
19 #include <string>
20 #include <tuple>
21 #include <utility>
22 #include <vector>
23
24 #include "absl/memory/memory.h"
25 #include "absl/strings/str_format.h"
26 #include "tensorflow/compiler/xla/client/xla_builder.h"
27 #include "tensorflow/compiler/xla/literal.h"
28 #include "tensorflow/compiler/xla/reference_util.h"
29 #include "tensorflow/compiler/xla/service/hlo_computation.h"
30 #include "tensorflow/compiler/xla/service/hlo_element_type_converter.h"
31 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
32 #include "tensorflow/compiler/xla/shape_util.h"
33 #include "tensorflow/compiler/xla/status.h"
34 #include "tensorflow/compiler/xla/status_macros.h"
35 #include "tensorflow/compiler/xla/statusor.h"
36 #include "tensorflow/compiler/xla/test.h"
37 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
38 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
39 #include "tensorflow/compiler/xla/tests/test_utils.h"
40 #include "tensorflow/compiler/xla/types.h"
41 #include "tensorflow/compiler/xla/util.h"
42 #include "tensorflow/compiler/xla/xla_data.pb.h"
43 #include "tensorflow/core/lib/core/status.h"
44 #include "tensorflow/core/lib/core/status_test_util.h"
45 #include "tensorflow/core/platform/test.h"
46 #include "tensorflow/core/platform/test_benchmark.h"
47 #include "tensorflow/core/platform/types.h"
48
49 namespace xla {
50 namespace {
51
// Parameter values for the bf16-parameterized test suite below: each TEST_P
// runs once with `true` and once with `false`.
static std::array<bool, 2> use_bf16_params = {true, false};
53
54 // Test fixture for the HloEvaluator.
55 //
56 // In bf16 mode, all f32 shapes are converted to bf16 before running.
57 class HloEvaluatorTest : public HloTestBase {
58 public:
HloEvaluatorTest()59 HloEvaluatorTest() : use_bfloat16_(false) {}
60
Evaluate(absl::Span<const Literal * const> arg_literals={})61 StatusOr<Literal> Evaluate(
62 absl::Span<const Literal* const> arg_literals = {}) {
63 if (use_bfloat16_) {
64 HloElementTypeConverter(F32, BF16).Run(m_.get()).ValueOrDie();
65 }
66 return evaluator_.Evaluate(*m_->entry_computation(), arg_literals);
67 }
68
69 // Evaluate function that takes in a local module instead of using m_
70 // that is in HloTestBase. Once m_ in HloTestBase is
71 // removed, this should be the default Evaluate function.
EvaluateWithModule(HloModule * module,absl::Span<const Literal * const> arg_literals={})72 Literal EvaluateWithModule(
73 HloModule* module, absl::Span<const Literal* const> arg_literals = {}) {
74 if (use_bfloat16_) {
75 HloElementTypeConverter(F32, BF16).Run(m_.get()).ValueOrDie();
76 }
77 return evaluator_.Evaluate(*module->entry_computation(), arg_literals)
78 .ConsumeValueOrDie();
79 }
80
TestUnaryOp(HloOpcode opcode,Literal expected,Literal input,float aabs=0)81 void TestUnaryOp(HloOpcode opcode, Literal expected, Literal input,
82 float aabs = 0) {
83 HloComputation::Builder b(TestName());
84 auto c1 =
85 b.AddInstruction(HloInstruction::CreateConstant(std::move(input)));
86 b.AddInstruction(HloInstruction::CreateUnary(expected.shape(), opcode, c1));
87 m_->AddEntryComputation(b.Build());
88
89 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
90
91 auto element_type = expected.shape().element_type();
92 if (element_type == F32 || element_type == F64) {
93 ErrorSpec error(aabs);
94 EXPECT_TRUE(LiteralTestUtil::Near(expected, result, error));
95 } else {
96 EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
97 }
98 }
99
TestBinaryOp(HloOpcode opcode,Literal expected,Literal lhs,Literal rhs)100 void TestBinaryOp(HloOpcode opcode, Literal expected, Literal lhs,
101 Literal rhs) {
102 HloComputation::Builder b(TestName());
103 auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs)));
104 auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs)));
105 b.AddInstruction(
106 HloInstruction::CreateBinary(expected.shape(), opcode, c1, c2));
107 m_->AddEntryComputation(b.Build());
108
109 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
110
111 EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
112 }
113
TestTernaryOp(HloOpcode opcode,Literal expected,Literal src0,Literal src1,Literal src2)114 void TestTernaryOp(HloOpcode opcode, Literal expected, Literal src0,
115 Literal src1, Literal src2) {
116 HloComputation::Builder b(TestName());
117 auto operand0 =
118 b.AddInstruction(HloInstruction::CreateConstant(std::move(src0)));
119 auto operand1 =
120 b.AddInstruction(HloInstruction::CreateConstant(std::move(src1)));
121 auto operand2 =
122 b.AddInstruction(HloInstruction::CreateConstant(std::move(src2)));
123 b.AddInstruction(HloInstruction::CreateTernary(
124 expected.shape(), opcode, operand0, operand1, operand2));
125 m_->AddEntryComputation(b.Build());
126
127 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
128
129 EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
130 }
131
132 protected:
HloEvaluatorTest(bool use_bfloat16)133 explicit HloEvaluatorTest(bool use_bfloat16) : use_bfloat16_(use_bfloat16) {}
134 HloEvaluator evaluator_;
135
136 const bool use_bfloat16_;
137 std::unique_ptr<HloModule> m_ = CreateNewVerifiedModule();
138 };
139
140 // Lets you write TEST_Ps that run twice, once with and once without bf16.
141 class HloEvaluatorBf16Test : public ::testing::WithParamInterface<bool>,
142 public HloEvaluatorTest {
143 protected:
HloEvaluatorBf16Test()144 HloEvaluatorBf16Test() : HloEvaluatorTest(/*use_bfloat16=*/GetParam()) {}
145 };
146
147 INSTANTIATE_TEST_SUITE_P(HloEvaluatorTest_Instantiation, HloEvaluatorBf16Test,
148 ::testing::ValuesIn(use_bf16_params));
149
150 // Verifies that HloEvaluator evaluates a HLO instruction that performs clamp
151 // with 3 operands.
TEST_P(HloEvaluatorBf16Test,DoesClamp)152 TEST_P(HloEvaluatorBf16Test, DoesClamp) {
153 auto low = LiteralUtil::CreateR2<float>({{0.f, 2.f}, {2.f, 4.f}});
154 auto value = LiteralUtil::CreateR2<float>({{0.f, 5.f}, {0.f, 4.f}});
155 auto high = LiteralUtil::CreateR2<float>({{2.f, 4.f}, {4.f, 4.f}});
156
157 Shape shape = low.shape();
158 HloComputation::Builder b(TestName());
159 auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(low)));
160 auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(value)));
161 auto c3 = b.AddInstruction(HloInstruction::CreateConstant(std::move(high)));
162 b.AddInstruction(
163 HloInstruction::CreateTernary(shape, HloOpcode::kClamp, c1, c2, c3));
164 m_->AddEntryComputation(b.Build());
165
166 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
167
168 auto expected = LiteralUtil::CreateR2<float>({{0, 4}, {2, 4}});
169
170 EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
171 }
172
173 // Verifies that clamping of int64 does not cause loss of precision
TEST_P(HloEvaluatorBf16Test,DoesClampInt64)174 TEST_P(HloEvaluatorBf16Test, DoesClampInt64) {
175 auto ones = [](int bits) { return (int64{1} << bits) - 1; };
176
177 auto low =
178 LiteralUtil::CreateR2<int64>({{0, ones(54)}, {ones(54), ones(58)}});
179 auto value = LiteralUtil::CreateR2<int64>({{0, ones(56)}, {0, ones(58)}});
180 auto high = LiteralUtil::CreateR2<int64>(
181 {{ones(54), ones(55)}, {ones(56), ones(58)}});
182
183 Shape shape = low.shape();
184 HloComputation::Builder b(TestName());
185 auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(low)));
186 auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(value)));
187 auto c3 = b.AddInstruction(HloInstruction::CreateConstant(std::move(high)));
188 b.AddInstruction(
189 HloInstruction::CreateTernary(shape, HloOpcode::kClamp, c1, c2, c3));
190 m_->AddEntryComputation(b.Build());
191
192 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
193
194 auto expected =
195 LiteralUtil::CreateR2<int64>({{0, ones(55)}, {ones(54), ones(58)}});
196
197 EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
198 }
199
// Disabled test: clamp whose low and high bounds are scalars while the value
// operand is a 2x2 array.
TEST_P(HloEvaluatorBf16Test, DISABLED_DoesClampSpecialBroadcast) {
  Literal lower_bound = LiteralUtil::CreateR0<float>(0.f);
  Literal operand = LiteralUtil::CreateR2<float>({{-1.f, 0.f}, {1.f, 2.f}});
  Literal upper_bound = LiteralUtil::CreateR0<float>(1.f);

  // The result takes the shape of the (non-scalar) value operand.
  Shape result_shape = operand.shape();
  HloComputation::Builder builder(TestName());
  HloInstruction* low_instr = builder.AddInstruction(
      HloInstruction::CreateConstant(std::move(lower_bound)));
  HloInstruction* value_instr =
      builder.AddInstruction(HloInstruction::CreateConstant(std::move(operand)));
  HloInstruction* high_instr = builder.AddInstruction(
      HloInstruction::CreateConstant(std::move(upper_bound)));
  builder.AddInstruction(HloInstruction::CreateTernary(
      result_shape, HloOpcode::kClamp, low_instr, value_instr, high_instr));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  Literal expected = LiteralUtil::CreateR2<float>({{0, 0}, {1, 1}});
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
220
221 // Verifies that HloEvaluator evaluates a HLO instruction that performs select
222 // with 3 operands.
TEST_P(HloEvaluatorBf16Test,DoesSelect)223 TEST_P(HloEvaluatorBf16Test, DoesSelect) {
224 auto pred = LiteralUtil::CreateR2<bool>({{true, false}, {false, true}});
225 auto on_true = LiteralUtil::CreateR2<float>({{2.f, 4.f}, {4.f, 4.f}});
226 auto on_false = LiteralUtil::CreateR2<float>({{0.f, 5.f}, {0.f, 4.f}});
227
228 Shape shape = on_true.shape();
229 HloComputation::Builder b(TestName());
230 auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(pred)));
231 auto c2 =
232 b.AddInstruction(HloInstruction::CreateConstant(std::move(on_true)));
233 auto c3 =
234 b.AddInstruction(HloInstruction::CreateConstant(std::move(on_false)));
235 b.AddInstruction(
236 HloInstruction::CreateTernary(shape, HloOpcode::kSelect, c1, c2, c3));
237 m_->AddEntryComputation(b.Build());
238
239 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({}));
240
241 auto expected = LiteralUtil::CreateR2<float>({{2, 5}, {0, 4}});
242
243 EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
244 }
245
246 // Verifies that HloEvaluator evaluates a HLO instruction that performs
247 // element-wise addition with 2 operands.
TEST_F(HloEvaluatorTest,DoesAdd)248 TEST_F(HloEvaluatorTest, DoesAdd) {
249 auto lhs = LiteralUtil::CreateR2<int64>({{1, 0}, {-100, 4}});
250 auto rhs = LiteralUtil::CreateR2<int64>({{2, 4}, {4, 4}});
251 auto expected = LiteralUtil::CreateR2<int64>({{3, 4}, {-96, 8}});
252 TestBinaryOp(HloOpcode::kAdd, std::move(expected), std::move(lhs),
253 std::move(rhs));
254 }
255 // Verifies that HloEvaluator evaluates a HLO instruction that performs
256 // element-wise and with 2 operands.
TEST_P(HloEvaluatorBf16Test,DoesAnd)257 TEST_P(HloEvaluatorBf16Test, DoesAnd) {
258 auto lhs = LiteralUtil::CreateR2<int64>({{1, 0}, {-100, 4}});
259 auto rhs = LiteralUtil::CreateR2<int64>({{2, 4}, {4, 4}});
260 auto expected = LiteralUtil::CreateR2<int64>({{0, 0}, {4, 4}});
261 TestBinaryOp(HloOpcode::kAnd, std::move(expected), std::move(lhs),
262 std::move(rhs));
263 }
264 // Verifies that HloEvaluator evaluates a HLO instruction that performs
265 // element-wise or with 2 operands.
TEST_F(HloEvaluatorTest,DoesOr)266 TEST_F(HloEvaluatorTest, DoesOr) {
267 auto lhs = LiteralUtil::CreateR2<int64>({{1, 0}, {-100, 4}});
268 auto rhs = LiteralUtil::CreateR2<int64>({{2, 4}, {4, 4}});
269 auto expected = LiteralUtil::CreateR2<int64>({{3, 4}, {-100, 4}});
270 TestBinaryOp(HloOpcode::kOr, std::move(expected), std::move(lhs),
271 std::move(rhs));
272 }
273 // Verifies that HloEvaluator evaluates a HLO instruction that performs
274 // element-wise or with 2 operands.
TEST_F(HloEvaluatorTest,DoesXor)275 TEST_F(HloEvaluatorTest, DoesXor) {
276 auto lhs = LiteralUtil::CreateR2<int64>({{1, 0}, {-100, 4}});
277 auto rhs = LiteralUtil::CreateR2<int64>({{2, 4}, {4, 4}});
278 auto expected = LiteralUtil::CreateR2<int64>({{3, 4}, {-104, 0}});
279 TestBinaryOp(HloOpcode::kXor, std::move(expected), std::move(lhs),
280 std::move(rhs));
281 }
282 // Verifies that HloEvaluator evaluates a HLO instruction that performs
283 // element-wise multiply with 2 operands.
TEST_F(HloEvaluatorTest,DoesMultiply)284 TEST_F(HloEvaluatorTest, DoesMultiply) {
285 auto lhs = LiteralUtil::CreateR2<int32>({{-1, 0}, {-100, 4}});
286 auto rhs = LiteralUtil::CreateR2<int32>(
287 {{std::numeric_limits<int32>::min(), 4}, {4, 4}});
288 auto expected = LiteralUtil::CreateR2<int32>(
289 {{std::numeric_limits<int32>::min(), 0}, {-400, 16}});
290 TestBinaryOp(HloOpcode::kMultiply, std::move(expected), std::move(lhs),
291 std::move(rhs));
292 }
293 // Verifies that HloEvaluator evaluates a HLO instruction that performs
294 // element-wise divide with 2 operands.
TEST_F(HloEvaluatorTest,DoesDivideInt64)295 TEST_F(HloEvaluatorTest, DoesDivideInt64) {
296 auto lhs = LiteralUtil::CreateR2<int64>({{1, 0}, {-100, 4}});
297 auto rhs = LiteralUtil::CreateR2<int64>({{2, 4}, {4, 4}});
298 auto expected = LiteralUtil::CreateR2<int64>({{0, 0}, {-25, 1}});
299 TestBinaryOp(HloOpcode::kDivide, std::move(expected), std::move(lhs),
300 std::move(rhs));
301 }
302
// Checks kClamp on rank-1 int64 operands with large-magnitude values. The
// three elements cover: value inside the bounds, value below low, and value
// above high.
TEST_F(HloEvaluatorTest, DoesClampS64) {
  Literal lower_bound = LiteralUtil::CreateR1<int64>(
      {-8616761059752331528LL, 6780561065411491190LL, -8616761059752331528LL});
  Literal operand = LiteralUtil::CreateR1<int64>(
      {-6780561065411491190LL, 6780561065411491180LL, 4241131823772864090LL});
  Literal upper_bound = LiteralUtil::CreateR1<int64>(
      {-6780561065411491180LL, 8616761059752331528LL, 3832151243857508051LL});
  Literal clamped = LiteralUtil::CreateR1<int64>(
      {-6780561065411491190LL, 6780561065411491190LL, 3832151243857508051LL});
  TestTernaryOp(HloOpcode::kClamp, /*expected=*/std::move(clamped),
                /*src0=*/std::move(lower_bound), /*src1=*/std::move(operand),
                /*src2=*/std::move(upper_bound));
}
315
// Checks element-wise kDivide on two double 2x2 operands; the result is
// compared for exact equality.
TEST_P(HloEvaluatorBf16Test, DoesDivideDouble) {
  TestBinaryOp(
      HloOpcode::kDivide,
      /*expected=*/
      LiteralUtil::CreateR2<double>({{0.45454545454545453, 0}, {-25, 1}}),
      /*lhs=*/LiteralUtil::CreateR2<double>({{1.0, 0.0}, {-100.0, 4.0}}),
      /*rhs=*/LiteralUtil::CreateR2<double>({{2.2, 4.0}, {4.0, 4.0}}));
}
324
325 // Verifies that HloEvaluator evaluates a HLO instruction that performs
326 // element-wise abs op with 1 operand.
TEST_F(HloEvaluatorTest,DoesAbsR2)327 TEST_F(HloEvaluatorTest, DoesAbsR2) {
328 auto operand = LiteralUtil::CreateR2<int64>({{1, -20}, {-100, 4}});
329 auto expected = LiteralUtil::CreateR2<int64>({{1, 20}, {100, 4}});
330 TestUnaryOp(HloOpcode::kAbs, std::move(expected), std::move(operand));
331 }
// Checks kAbs on a float scalar.
TEST_P(HloEvaluatorBf16Test, DoesAbsR0) {
  TestUnaryOp(HloOpcode::kAbs,
              /*expected=*/LiteralUtil::CreateR0<float>(1.0f),
              /*input=*/LiteralUtil::CreateR0<float>(-1.0f));
}
// Checks kAbs on a rank-1 float operand that has no elements.
TEST_P(HloEvaluatorBf16Test, DoesAbsR1WithZeroSize) {
  TestUnaryOp(HloOpcode::kAbs,
              /*expected=*/LiteralUtil::CreateR1<float>({}),
              /*input=*/LiteralUtil::CreateR1<float>({}));
}
// Checks element-wise kNegate on an int32 2x2 operand. The minimum value is
// expected to negate to itself.
TEST_F(HloEvaluatorTest, DoesNegateR2) {
  Literal input = LiteralUtil::CreateR2<int32>(
      {{0, std::numeric_limits<int32>::min()}, {-1, 4}});
  Literal negated = LiteralUtil::CreateR2<int32>(
      {{0, std::numeric_limits<int>::min()}, {1, -4}});
  TestUnaryOp(HloOpcode::kNegate, /*expected=*/std::move(negated),
              /*input=*/std::move(input));
}
// Checks element-wise kCos at 0, pi, -pi and 2*pi. A looser absolute error
// bound is used when running in bf16 mode.
TEST_P(HloEvaluatorBf16Test, DoesCosR2) {
  Literal angles = LiteralUtil::CreateR2<float>({{0, M_PI}, {-M_PI, 2 * M_PI}});
  Literal cosines = LiteralUtil::CreateR2<float>({{1, -1}, {-1, 1}});
  TestUnaryOp(HloOpcode::kCos, /*expected=*/std::move(cosines),
              /*input=*/std::move(angles),
              /*aabs=*/use_bfloat16_ ? 0.031250 : 9.5367431640625E-7);
}
// Checks element-wise kSin at 0, pi, -pi and 2*pi. A looser absolute error
// bound is used when running in bf16 mode.
TEST_P(HloEvaluatorBf16Test, DoesSinR2) {
  Literal angles = LiteralUtil::CreateR2<float>({{0, M_PI}, {-M_PI, 2 * M_PI}});
  Literal sines = LiteralUtil::CreateR2<float>({{0, 0}, {0, 0}});
  TestUnaryOp(HloOpcode::kSin, /*expected=*/std::move(sines),
              /*input=*/std::move(angles),
              /*aabs=*/use_bfloat16_ ? 0.031250 : 9.5367431640625E-7);
}
// Checks element-wise kNot on an int32 2x2 operand, including the minimum and
// maximum values.
TEST_F(HloEvaluatorTest, DoesNotR2) {
  Literal input =
      LiteralUtil::CreateR2<int32>({{0, std::numeric_limits<int>::min()},
                                    {-1, std::numeric_limits<int>::max()}});
  Literal inverted =
      LiteralUtil::CreateR2<int32>({{-1, std::numeric_limits<int>::max()},
                                    {0, std::numeric_limits<int>::min()}});
  TestUnaryOp(HloOpcode::kNot, /*expected=*/std::move(inverted),
              /*input=*/std::move(input));
}
370
// Checks that kReal extracts the real parts of a rank-1 complex128 operand.
TEST_F(HloEvaluatorTest, DoesRealC128) {
  Literal complex_input =
      LiteralUtil::CreateR1<complex128>({{1, 0}, {-100, 4}});
  Literal real_parts = LiteralUtil::CreateR1<double>({1, -100});
  TestUnaryOp(HloOpcode::kReal, /*expected=*/std::move(real_parts),
              /*input=*/std::move(complex_input));
}
376
// Checks that kImag extracts the imaginary parts of a rank-1 complex128
// operand.
TEST_F(HloEvaluatorTest, DoesImagC128) {
  Literal complex_input =
      LiteralUtil::CreateR1<complex128>({{1, 0}, {-100, 4}});
  Literal imag_parts = LiteralUtil::CreateR1<double>({0, 4});
  TestUnaryOp(HloOpcode::kImag, /*expected=*/std::move(imag_parts),
              /*input=*/std::move(complex_input));
}
382
383 // Verifies that HloEvaluator evaluates a HLO Computation with non-parameter nor
384 // constant operands.
TEST_F(HloEvaluatorTest,DoesTraverseInstructions)385 TEST_F(HloEvaluatorTest, DoesTraverseInstructions) {
386 auto lhs = LiteralUtil::CreateR2<int64>({{1, 0}, {-100, 4}});
387 auto rhs = LiteralUtil::CreateR2<int64>({{2, 4}, {4, 4}});
388 auto rhs2 = LiteralUtil::CreateR2<int64>({{1, -20}, {-100, 4}});
389 std::vector<const Literal*> args = {&lhs, &rhs, &rhs2};
390
391 Shape shape = ShapeUtil::MakeShape(S64, {2, 2});
392
393 HloComputation::Builder b(TestName());
394 auto param_lhs =
395 b.AddInstruction(HloInstruction::CreateParameter(0, shape, "lhs"));
396 auto param_rhs =
397 b.AddInstruction(HloInstruction::CreateParameter(1, shape, "rhs"));
398 auto lhs_instruction = b.AddInstruction(HloInstruction::CreateBinary(
399 shape, HloOpcode::kAdd, param_lhs, param_rhs));
400
401 auto param_rhs2 =
402 b.AddInstruction(HloInstruction::CreateParameter(2, shape, "rhs2"));
403 b.AddInstruction(HloInstruction::CreateBinary(shape, HloOpcode::kAdd,
404 lhs_instruction, param_rhs2));
405 m_->AddEntryComputation(b.Build());
406
407 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate(args));
408
409 auto expected = LiteralUtil::CreateR2<int64>({{4, -16}, {-196, 12}});
410
411 EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
412 }
413
414 // Verifies Reshape operation is correctly evaluated.
TEST_F(HloEvaluatorTest,DoesReshape)415 TEST_F(HloEvaluatorTest, DoesReshape) {
416 HloComputation::Builder b(TestName());
417 const int64 dimensions[] = {11, 8, 7, 5, 9};
418 TF_ASSERT_OK_AND_ASSIGN(auto literal,
419 LiteralUtil::CreateRandomLiteral<F32>(
420 ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0));
421 auto literal_clone = literal.Clone();
422 HloInstruction* literal_instruction =
423 b.AddInstruction(HloInstruction::CreateConstant(std::move(literal)));
424
425 Shape shape = ShapeUtil::MakeShape(F32, {8, 7, 11, 9, 5});
426 const int64 permutation[] = {1, 2, 0, 4, 3};
427 b.AddInstruction(
428 HloInstruction::CreateTranspose(shape, literal_instruction, permutation));
429 m_->AddEntryComputation(b.Build());
430
431 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({}));
432
433 using NativeT = typename primitive_util::PrimitiveTypeToNative<F32>::type;
434 result.EachCell<NativeT>([&](absl::Span<const int64> indices, NativeT value) {
435 std::vector<int64> rindexes = Permute(permutation, indices);
436 EXPECT_NEAR(value, literal_clone.Get<NativeT>(rindexes), 0.031250);
437 });
438 }
439
440 // Verifies Broadcast operation is correctly evaluated.
TEST_F(HloEvaluatorTest,DoesBroadcast)441 TEST_F(HloEvaluatorTest, DoesBroadcast) {
442 HloComputation::Builder b(TestName());
443 auto input_literal = LiteralUtil::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}});
444 auto output_literal = LiteralUtil::CreateR3<int32>(
445 {{{1, 2}, {3, 4}, {5, 6}}, {{1, 2}, {3, 4}, {5, 6}}});
446 HloInstruction* literal_instruction = b.AddInstruction(
447 HloInstruction::CreateConstant(std::move(input_literal)));
448 b.AddInstruction(HloInstruction::CreateBroadcast(
449 output_literal.shape(), literal_instruction, {1, 2}));
450 m_->AddEntryComputation(b.Build());
451
452 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({}));
453
454 EXPECT_TRUE(LiteralTestUtil::Equal(result, output_literal));
455 }
456
// Checks that broadcasting an int32 scalar into a 6x2 result is evaluated
// correctly.
TEST_F(HloEvaluatorTest, DoesBroadcastScalar) {
  HloComputation::Builder builder(TestName());
  Literal scalar = LiteralUtil::CreateR0<int32>(111);
  Literal expected = LiteralUtil::CreateR2<int32>(
      {{111, 111}, {111, 111}, {111, 111}, {111, 111}, {111, 111}, {111, 111}});

  HloInstruction* scalar_instr =
      builder.AddInstruction(HloInstruction::CreateConstant(std::move(scalar)));
  // A scalar operand has no dimensions, so the broadcast dimension list is
  // empty.
  builder.AddInstruction(HloInstruction::CreateBroadcast(
      expected.shape(), scalar_instr,
      /*broadcast_dimensions=*/{}));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({}));

  EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
}
475
// Checks concatenation of two 2x2 int64 operands along dimension 0.
TEST_F(HloEvaluatorTest, DoesConcatenateSimple) {
  HloComputation::Builder builder(TestName());

  HloInstruction* first = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<int64>({{-1, -2}, {100, 200}})));
  HloInstruction* second =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2<int64>({{-2, -3}, {-100, -200}})));

  std::vector<HloInstruction*> pieces = {first, second};

  Shape concatenated_shape = ShapeUtil::MakeShape(S64, {4, 2});
  builder.AddInstruction(
      HloInstruction::CreateConcatenate(concatenated_shape, pieces, 0));

  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  Literal expected = LiteralUtil::CreateR2<int64>(
      {{-1, -2}, {100, 200}, {-2, -3}, {-100, -200}});
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
497
// Checks concatenation when one of the rank-1 operands has no elements: the
// result is expected to equal the non-empty operand.
TEST_F(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) {
  HloComputation::Builder builder(TestName());

  HloInstruction* non_empty = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<int64>({100, 200})));
  HloInstruction* empty = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<int64>({})));

  std::vector<HloInstruction*> pieces = {non_empty, empty};

  Shape concatenated_shape = ShapeUtil::MakeShape(S64, {2});
  builder.AddInstruction(
      HloInstruction::CreateConcatenate(concatenated_shape, pieces, 0));

  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  Literal expected = LiteralUtil::CreateR1<int64>({100, 200});
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
518
// Checks an int32 -> float convert where the source and destination shapes
// have equal layouts.
TEST_P(HloEvaluatorBf16Test, ConvertWithSameLayout) {
  HloComputation::Builder builder(TestName());

  Literal source = LiteralUtil::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}});
  Literal expected =
      LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}});
  // Precondition of this test: both literals share the same layout.
  ASSERT_TRUE(
      LayoutUtil::LayoutsInShapesEqual(source.shape(), expected.shape()));

  HloInstruction* source_instr =
      builder.AddInstruction(HloInstruction::CreateConstant(std::move(source)));
  builder.AddInstruction(
      HloInstruction::CreateConvert(expected.shape(), source_instr));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
}
537
// Checks an int32 -> float convert where the source literal uses layout
// {0, 1} and the destination uses layout {1, 0}.
TEST_P(HloEvaluatorBf16Test, ConvertWithDifferentLayout) {
  HloComputation::Builder builder(TestName());

  Literal source = LiteralUtil::CreateR2WithLayout<int32>(
      {{1, 2}, {3, 4}, {5, 6}}, LayoutUtil::MakeLayout({0, 1}));
  Literal expected = LiteralUtil::CreateR2WithLayout<float>(
      {{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}, LayoutUtil::MakeLayout({1, 0}));
  // Precondition of this test: the two literals really do differ in layout.
  ASSERT_FALSE(
      LayoutUtil::LayoutsInShapesEqual(source.shape(), expected.shape()));

  HloInstruction* source_instr =
      builder.AddInstruction(HloInstruction::CreateConstant(std::move(source)));
  builder.AddInstruction(
      HloInstruction::CreateConvert(expected.shape(), source_instr));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
}
557
CreatePaddingConfig(std::initializer_list<std::array<int64,3>> padding_dimensions)558 PaddingConfig CreatePaddingConfig(
559 std::initializer_list<std::array<int64, 3>> padding_dimensions) {
560 PaddingConfig padding_config;
561
562 for (auto& paddings_per_dim : padding_dimensions) {
563 auto dimension = padding_config.add_dimensions();
564 dimension->set_edge_padding_low(paddings_per_dim[0]);
565 dimension->set_edge_padding_high(paddings_per_dim[1]);
566 dimension->set_interior_padding(paddings_per_dim[2]);
567 }
568 return padding_config;
569 }
570
// Checks padding of a 2-D int32 operand that has a zero-sized dimension: the
// 5x2 result is expected to consist solely of the padding value.
TEST_F(HloEvaluatorTest, Pad2DIntegerArrayWithZeroDimension) {
  Literal empty_operand = LiteralUtil::CreateR2<int32>({{}, {}});
  HloComputation::Builder builder(TestName());
  HloInstruction* operand_instr = builder.AddInstruction(
      HloInstruction::CreateConstant(std::move(empty_operand)));

  constexpr int32 kPadValue = 10;
  HloInstruction* pad_value_instr = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(kPadValue)));

  // Per dimension: {edge_padding_low, edge_padding_high, interior_padding}.
  PaddingConfig config = CreatePaddingConfig({{{1, 0, 2}}, {{0, 2, 1}}});
  Shape padded_shape = ShapeUtil::MakeShape(S32, {5, 2});
  builder.AddInstruction(HloInstruction::CreatePad(
      padded_shape, operand_instr, pad_value_instr, config));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  Literal expected = LiteralUtil::CreateR2<int32>(
      {{10, 10}, {10, 10}, {10, 10}, {10, 10}, {10, 10}});
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
595
// Checks padding of a 3x2x1x1 float operand with edge and interior padding on
// dimensions 0 and 1, producing an 8x5x1x1 result.
TEST_P(HloEvaluatorBf16Test, Pad4DFloatArrayWithInteriorPadding) {
  HloComputation::Builder builder(TestName());

  Array4D<float> source_values(3, 2, 1, 1, {1, 2, 3, 4, 5, 6});
  HloInstruction* source_instr =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR4FromArray4D<float>(source_values)));
  constexpr float kPadValue = 1.5;
  HloInstruction* pad_value_instr = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(kPadValue)));

  Shape padded_shape = ShapeUtil::MakeShape(F32, {8, 5, 1, 1});
  // Per dimension: {edge_padding_low, edge_padding_high, interior_padding}.
  PaddingConfig config =
      CreatePaddingConfig({{{1, 0, 2}}, {{0, 2, 1}}, {{0, 0, 0}}, {{0, 0, 0}}});
  builder.AddInstruction(HloInstruction::CreatePad(
      padded_shape, source_instr, pad_value_instr, config));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  // Start from an all-padding array and drop the six source values into the
  // positions they are expected to occupy.
  Array4D<float> expected_values(8, 5, 1, 1);
  expected_values.Fill(kPadValue);
  expected_values(1, 0, 0, 0) = 1.0f;
  expected_values(1, 2, 0, 0) = 2.0f;
  expected_values(4, 0, 0, 0) = 3.0f;
  expected_values(4, 2, 0, 0) = 4.0f;
  expected_values(7, 0, 0, 0) = 5.0f;
  expected_values(7, 2, 0, 0) = 6.0f;

  Literal expected = LiteralUtil::CreateR4FromArray4D<float>(expected_values);
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
630
// Checks padding with negative edge padding on a 4x3 float operand, which
// yields a 1x5 result.
TEST_P(HloEvaluatorBf16Test, NegativePadding2D) {
  HloComputation::Builder builder(TestName());

  // source_values:
  // f32[4,3] {
  //  { 1, 2, 3 },
  //  { 5, 6, 7 },
  //  { 9, 10, 11 },
  //  { 13, 14, 15 },
  // }
  Array2D<float> source_values(4, 3);
  source_values.FillUnique(1.0f);
  HloInstruction* source_instr =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2FromArray2D<float>(source_values)));

  HloInstruction* pad_value_instr = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.718f)));

  // Per dimension: {edge_padding_low, edge_padding_high, interior_padding}.
  PaddingConfig config = CreatePaddingConfig({{{-1, -2, 0}}, {{-2, 4, 0}}});
  Shape padded_shape = ShapeUtil::MakeShape(F32, {1, 5});
  builder.AddInstruction(HloInstruction::CreatePad(
      padded_shape, source_instr, pad_value_instr, config));

  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  // f32[1,5] { 7.0, 2.718, 2.718, 2.718, 2.718 }
  Array2D<float> expected_values(1, 5);
  expected_values(0, 0) = 7.0f;
  expected_values(0, 1) = 2.718f;
  expected_values(0, 2) = 2.718f;
  expected_values(0, 3) = 2.718f;
  expected_values(0, 4) = 2.718f;
  Literal expected = LiteralUtil::CreateR2FromArray2D<float>(expected_values);

  EXPECT_TRUE(LiteralTestUtil::Near(expected, result, ErrorSpec(0.031250)));
}
672
// Checks padding that combines negative edge padding with interior padding on
// a 4x3 float operand. The negative padding on dimension 0 removes every row,
// so the expected result is an empty 0x9 array.
TEST_P(HloEvaluatorBf16Test, NegativeAndInteriorPadding2D) {
  HloComputation::Builder b(TestName());

  // f32[4,3] {
  //  { 1, 2, 3 },
  //  { 5, 6, 7 },
  //  { 9, 10, 11 },
  //  { 13, 14, 15 },
  // }
  auto input_array = absl::make_unique<Array2D<float>>(4, 3);
  input_array->FillUnique(1.0f);
  auto input = LiteralUtil::CreateR2FromArray2D<float>(*input_array);
  HloInstruction* input_instruction =
      b.AddInstruction(HloInstruction::CreateConstant(std::move(input)));

  auto pad_value_instruction = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.718f)));

  // Negative padding that results in zero dimensions.
  // Per dimension: {edge_padding_low, edge_padding_high, interior_padding}.
  auto r2_padding_on_dim0_dim1 =
      CreatePaddingConfig({{{-2, -5, 1}}, {{-2, 4, 2}}});

  Shape shape = ShapeUtil::MakeShape(F32, {0, 9});
  b.AddInstruction(HloInstruction::CreatePad(shape, input_instruction,
                                             pad_value_instruction,
                                             r2_padding_on_dim0_dim1));

  m_->AddEntryComputation(b.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  auto expected_array = absl::make_unique<Array2D<float>>(0, 9);
  auto expected = LiteralUtil::CreateR2FromArray2D<float>(*expected_array);

  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
711
// Checks a dot of a 4x1 float operand with a 1x2 float operand, contracting
// lhs dimension 1 with rhs dimension 0, which yields a 4x2 result.
TEST_P(HloEvaluatorBf16Test, DotRank2AndRank1) {
  HloComputation::Builder builder(TestName());

  // lhs:
  // f32[4,1] {
  //  { 1 },
  //  { 2 },
  //  { 3 },
  //  { 4 },
  // }
  Array2D<float> lhs_values(4, 1);
  lhs_values.FillUnique(1.0f);
  HloInstruction* lhs_instr =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2FromArray2D<float>(lhs_values)));

  // rhs:
  // f32[1,2] { { 1, 2 } }
  HloInstruction* rhs_instr = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2<float>({{1, 2}})));

  Shape result_shape = ShapeUtil::MakeShape(F32, {4, 2});
  DotDimensionNumbers dimension_numbers;
  dimension_numbers.add_lhs_contracting_dimensions(1);
  dimension_numbers.add_rhs_contracting_dimensions(0);
  builder.AddInstruction(
      HloInstruction::CreateDot(result_shape, lhs_instr, rhs_instr,
                                dimension_numbers, DefaultPrecisionConfig(2)));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  // clang-format off
  Array2D<float> expected_values({
      {1.f, 2.f},
      {2.f, 4.f},
      {3.f, 6.f},
      {4.f, 8.f},
  });
  // clang-format on
  Literal expected = LiteralUtil::CreateR2FromArray2D<float>(expected_values);

  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
757
// Evaluates a dot of a rank-1 lhs with a rank-2 rhs, contracting dimension 0
// of both operands, and checks the resulting f32[2] vector.
TEST_P(HloEvaluatorBf16Test, DotRank1AndRank2) {
  HloComputation::Builder builder(TestName());

  // Left operand: f32[3] { 1, 2, 3 }.
  HloInstruction* lhs = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({1, 2, 3})));

  // Right operand:
  // f32[3,2] {
  //  { 1, 2 },
  //  { 3, 4 },
  //  { 5, 6 },
  // }
  Array2D<float> rhs_values(3, 2);
  rhs_values.FillUnique(1.0f);
  HloInstruction* rhs = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2FromArray2D<float>(rhs_values)));

  DotDimensionNumbers contraction;
  contraction.add_lhs_contracting_dimensions(0);
  contraction.add_rhs_contracting_dimensions(0);
  const Shape result_shape = ShapeUtil::MakeShape(F32, {2});
  builder.AddInstruction(HloInstruction::CreateDot(
      result_shape, lhs, rhs, contraction, DefaultPrecisionConfig(2)));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  // 1*1 + 2*3 + 3*5 = 22 and 1*2 + 2*4 + 3*6 = 28.
  auto expected = LiteralUtil::CreateR1<float>({22.f, 28.f});

  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
795
// Evaluates a matrix product of an f32[4,3] lhs and an f32[3,2] rhs and
// checks the f32[4,2] result.
TEST_P(HloEvaluatorBf16Test, DotRank2AndRank2) {
  HloComputation::Builder builder(TestName());

  // Left operand:
  // f32[4,3] {
  //  { 1, 2, 3 },
  //  { 5, 6, 7 },
  //  { 9, 10, 11 },
  //  { 13, 14, 15 },
  // }
  Array2D<float> lhs_values(4, 3);
  lhs_values.FillUnique(1.0f);
  HloInstruction* lhs = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2FromArray2D<float>(lhs_values)));

  // Right operand:
  // f32[3,2] {
  //  { 1, 2 },
  //  { 3, 4 },
  //  { 5, 6 },
  // }
  Array2D<float> rhs_values(3, 2);
  rhs_values.FillUnique(1.0f);
  HloInstruction* rhs = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2FromArray2D<float>(rhs_values)));

  // Contract the lhs columns against the rhs rows.
  DotDimensionNumbers contraction;
  contraction.add_lhs_contracting_dimensions(1);
  contraction.add_rhs_contracting_dimensions(0);
  const Shape result_shape = ShapeUtil::MakeShape(F32, {4, 2});
  builder.AddInstruction(HloInstruction::CreateDot(
      result_shape, lhs, rhs, contraction, DefaultPrecisionConfig(2)));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  const Array2D<float> expected_values({
      {22.f, 28.f},
      {58.f, 76.f},
      {94.f, 124.f},
      {130.f, 172.f},
  });
  auto expected = LiteralUtil::CreateR2FromArray2D<float>(expected_values);

  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
845
// Evaluates a dot of two f32[2,2,3,1] operands with batch dimension 0 and
// contracting dimensions {1, 2}, and checks the f32[2,1,1] result.
TEST_P(HloEvaluatorBf16Test, DotRank4AndRank4) {
  HloComputation::Builder builder(TestName());

  // The lhs is iota-filled starting at 1, the rhs starting at 2.
  Array4D<float> lhs_values(2, 2, 3, 1);
  lhs_values.FillIota(1.0f);
  HloInstruction* lhs = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR4FromArray4D<float>(lhs_values)));

  Array4D<float> rhs_values(2, 2, 3, 1);
  rhs_values.FillIota(2.0f);
  HloInstruction* rhs = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR4FromArray4D<float>(rhs_values)));

  DotDimensionNumbers dims;
  dims.add_lhs_batch_dimensions(0);
  dims.add_rhs_batch_dimensions(0);
  dims.add_lhs_contracting_dimensions(1);
  dims.add_lhs_contracting_dimensions(2);
  dims.add_rhs_contracting_dimensions(1);
  dims.add_rhs_contracting_dimensions(2);

  const Shape result_shape = ShapeUtil::MakeShape(F32, {2, 1, 1});
  builder.AddInstruction(HloInstruction::CreateDot(
      result_shape, lhs, rhs, dims, DefaultPrecisionConfig(2)));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  // Sums i * i + i for i in [first, limit), accumulating in float.
  auto sum_of_products = [](float first, float limit) {
    float total = 0;
    for (float i = first; i < limit; ++i) {
      total += i * i + i;
    }
    return total;
  };
  const float batch0 = sum_of_products(1.0f, 7.0f);
  const float batch1 = sum_of_products(7.0f, 13.0f);

  const Array3D<float> expected_values({{{batch0}}, {{batch1}}});
  auto expected = LiteralUtil::CreateR3FromArray3D<float>(expected_values);

  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
890
// Convolves the f32[1,1,3] input {1, 2, 3} with the f32[1,1,2] kernel {3, 4}
// using one element of high padding, and checks the f32[1,1,3] result.
TEST_P(HloEvaluatorBf16Test, SimpleConv1D) {
  HloComputation::Builder builder(TestName());

  const Array3D<float> input_values = {{{1, 2, 3}}};
  HloInstruction* input = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR3FromArray3D<float>(input_values)));

  const Array3D<float> kernel_values = {{{3.f, 4.f}}};
  HloInstruction* kernel =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR3FromArray3D<float>(kernel_values)));

  // Dimension layout: batch = 0, feature = 1, spatial = 2 for both the input
  // and the output; the kernel uses output feature = 0, input feature = 1,
  // spatial = 2.
  ConvolutionDimensionNumbers conv_dims;
  conv_dims.set_input_batch_dimension(0);
  conv_dims.set_input_feature_dimension(1);
  conv_dims.add_input_spatial_dimensions(2);
  conv_dims.set_output_batch_dimension(0);
  conv_dims.set_output_feature_dimension(1);
  conv_dims.add_output_spatial_dimensions(2);
  conv_dims.set_kernel_output_feature_dimension(0);
  conv_dims.set_kernel_input_feature_dimension(1);
  conv_dims.add_kernel_spatial_dimensions(2);

  // A single spatial window dimension: size 2, stride 1, padding (0, 1), no
  // dilation.
  WindowDimension window_dim;
  window_dim.set_size(2);
  window_dim.set_stride(1);
  window_dim.set_padding_low(0);
  window_dim.set_padding_high(1);
  window_dim.set_window_dilation(1);
  window_dim.set_base_dilation(1);
  Window window;
  *window.add_dimensions() = window_dim;

  const Shape result_shape = ShapeUtil::MakeShape(F32, {1, 1, 3});
  builder.AddInstruction(HloInstruction::CreateConvolve(
      result_shape, input, kernel, /*feature_group_count=*/1,
      /*batch_group_count=*/1, window, conv_dims, DefaultPrecisionConfig(2)));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  const Array3D<float> expected_values = {{{11.f, 18.f, 9.f}}};
  auto expected = LiteralUtil::CreateR3FromArray3D<float>(expected_values);

  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
939
// Convolves an f32[1,1,4,4] input with an f32[1,1,2,2] kernel using a high
// padding of 1 on both window dimensions, and checks the f32[1,1,4,4] result.
TEST_P(HloEvaluatorBf16Test, Simple4x4Conv2DWith2x2Kernel) {
  HloComputation::Builder builder(TestName());

  // clang-format off
  const Array2D<float> input_yx({
    {1,  2,  3,  4 },
    {5,  6,  7,  8 },
    {9,  10, 11, 12},
    {13, 14, 15, 16},
  });
  // clang-format on
  Array4D<float> input_values(1, 1, 4, 4);
  input_values.FillWithYX(input_yx);
  HloInstruction* input = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR4FromArray4D<float>(input_values)));

  // clang-format off
  const Array2D<float> kernel_yx({
    {5, 6},
    {7, 8},
  });
  // clang-format on
  Array4D<float> kernel_values(1, 1, 2, 2);
  kernel_values.FillWithYX(kernel_yx);
  HloInstruction* kernel =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR4FromArray4D<float>(kernel_values)));

  // Both window dimensions share the same configuration.
  WindowDimension window_dim;
  window_dim.set_size(2);
  window_dim.set_stride(1);
  window_dim.set_padding_low(0);
  window_dim.set_padding_high(1);
  window_dim.set_window_dilation(1);
  window_dim.set_base_dilation(1);
  Window window;
  *window.add_dimensions() = window_dim;
  *window.add_dimensions() = window_dim;

  const ConvolutionDimensionNumbers conv_dims =
      XlaBuilder::CreateDefaultConvDimensionNumbers(2);

  const Shape result_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4});
  builder.AddInstruction(HloInstruction::CreateConvolve(
      result_shape, input, kernel, /*feature_group_count=*/1,
      /*batch_group_count=*/1, window, conv_dims, DefaultPrecisionConfig(2)));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  // clang-format off
  const Array2D<float> expected_yx({
    {100, 126, 152,  76},
    {204, 230, 256, 124},
    {308, 334, 360, 172},
    {149, 160, 171,  80},
  });
  // clang-format on
  Array4D<float> expected_values(1, 1, 4, 4);
  expected_values.FillWithYX(expected_yx);
  auto expected = LiteralUtil::CreateR4FromArray4D<float>(expected_values);

  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
1002
// Convolution with non-default dimension numbers where the kernel is first
// reversed along dimensions {3, 1} and the window dimensions are marked as
// reversed. Distinct expected values are used for the bfloat16 variant.
TEST_P(HloEvaluatorBf16Test, Conv2DGeneralDimensionsReversed) {
  HloComputation::Builder builder(TestName());

  // clang-format off
  // Input dimensions: [feature=2, height=3, batch=1, width=4]
  const Array4D<float> input_values({
    {{{1, 2, 3, 4}},
     {{5, 6, 7, 8}},
     {{9, 10, 11, 12}}},
    {{{13, 14, 15, 16}},
     {{17, 18, 19, 20}},
     {{21, 22, 23, 24}}}
  });
  // Weight dimensions:
  // [kernel_output_feature=1, width=3, kernel_input_feature=2, height=3]
  const Array4D<float> weight_values({{
    {{1, 7, 13},
     {4, 10, 16}},
    {{2, 8, 14},
     {5, 11, 17}},
    {{3, 9, 15},
     {6, 12, 18}}
  }});
  // clang-format on

  HloInstruction* input = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR4FromArray4D<float>(input_values)));

  HloInstruction* weight_constant =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR4FromArray4D<float>(weight_values)));
  // The convolution consumes the reversed kernel, not the raw constant.
  HloInstruction* reversed_weight =
      builder.AddInstruction(HloInstruction::CreateReverse(
          weight_constant->shape(), weight_constant, {3, 1}));

  // Both window dimensions share the same configuration, including reversal.
  WindowDimension window_dim;
  window_dim.set_size(3);
  window_dim.set_stride(1);
  window_dim.set_padding_low(0);
  window_dim.set_padding_high(0);
  window_dim.set_window_dilation(1);
  window_dim.set_base_dilation(1);
  window_dim.set_window_reversal(true);
  Window window;
  *window.add_dimensions() = window_dim;
  *window.add_dimensions() = window_dim;

  ConvolutionDimensionNumbers conv_dims;
  conv_dims.set_input_batch_dimension(2);
  conv_dims.set_input_feature_dimension(0);
  conv_dims.add_input_spatial_dimensions(1);
  conv_dims.add_input_spatial_dimensions(3);
  conv_dims.set_output_batch_dimension(2);
  conv_dims.set_output_feature_dimension(0);
  conv_dims.add_output_spatial_dimensions(1);
  conv_dims.add_output_spatial_dimensions(3);
  conv_dims.set_kernel_output_feature_dimension(0);
  conv_dims.set_kernel_input_feature_dimension(2);
  conv_dims.add_kernel_spatial_dimensions(3);
  conv_dims.add_kernel_spatial_dimensions(1);

  const Shape result_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2});
  builder.AddInstruction(HloInstruction::CreateConvolve(
      result_shape, input, reversed_weight, /*feature_group_count=*/1,
      /*batch_group_count=*/1, window, conv_dims, DefaultPrecisionConfig(2)));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  // clang-format off
  // Result dimensions: [feature=1, height=1, batch=1, width=2]
  const Array4D<float> expected_f32({{{{2514, 2685}}}});
  const Array4D<float> expected_bf16({{{{2512, 2688}}}});
  // clang-format on
  const Array4D<float>& expected_values =
      use_bfloat16_ ? expected_bf16 : expected_f32;
  auto expected = LiteralUtil::CreateR4FromArray4D<float>(expected_values);

  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
1083
// Convolution with non-default dimension numbers for the input, kernel and
// output. Distinct expected values are used for the bfloat16 variant.
TEST_P(HloEvaluatorBf16Test, Conv2DGeneralDimensions) {
  HloComputation::Builder builder(TestName());

  // clang-format off
  // Input dimensions: [feature=2, height=3, batch=1, width=4]
  const Array4D<float> input_values({
    {{{1, 2, 3, 4}},
     {{5, 6, 7, 8}},
     {{9, 10, 11, 12}}},
    {{{13, 14, 15, 16}},
     {{17, 18, 19, 20}},
     {{21, 22, 23, 24}}}
  });
  // Weight dimensions:
  // [kernel_output_feature=1, width=3, kernel_input_feature=2, height=3]
  const Array4D<float> weight_values({{
    {{1, 7, 13},
     {4, 10, 16}},
    {{2, 8, 14},
     {5, 11, 17}},
    {{3, 9, 15},
     {6, 12, 18}}
  }});
  // clang-format on

  HloInstruction* input = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR4FromArray4D<float>(input_values)));

  HloInstruction* weight =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR4FromArray4D<float>(weight_values)));

  // Both window dimensions share the same configuration.
  WindowDimension window_dim;
  window_dim.set_size(3);
  window_dim.set_stride(1);
  window_dim.set_padding_low(0);
  window_dim.set_padding_high(0);
  window_dim.set_window_dilation(1);
  window_dim.set_base_dilation(1);
  Window window;
  *window.add_dimensions() = window_dim;
  *window.add_dimensions() = window_dim;

  ConvolutionDimensionNumbers conv_dims;
  conv_dims.set_input_batch_dimension(2);
  conv_dims.set_input_feature_dimension(0);
  conv_dims.add_input_spatial_dimensions(1);
  conv_dims.add_input_spatial_dimensions(3);
  conv_dims.set_output_batch_dimension(2);
  conv_dims.set_output_feature_dimension(0);
  conv_dims.add_output_spatial_dimensions(1);
  conv_dims.add_output_spatial_dimensions(3);
  conv_dims.set_kernel_output_feature_dimension(0);
  conv_dims.set_kernel_input_feature_dimension(2);
  conv_dims.add_kernel_spatial_dimensions(3);
  conv_dims.add_kernel_spatial_dimensions(1);

  const Shape result_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2});
  builder.AddInstruction(HloInstruction::CreateConvolve(
      result_shape, input, weight, /*feature_group_count=*/1,
      /*batch_group_count=*/1, window, conv_dims, DefaultPrecisionConfig(2)));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  // clang-format off
  // Result dimensions: [feature=1, height=1, batch=1, width=2]
  const Array4D<float> expected_f32({{{{2514, 2685}}}});
  const Array4D<float> expected_bf16({{{{2512, 2688}}}});
  // clang-format on
  const Array4D<float>& expected_values =
      use_bfloat16_ ? expected_bf16 : expected_f32;
  auto expected = LiteralUtil::CreateR4FromArray4D<float>(expected_values);

  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
1161
// Convolves an f32[1,1,4,4] input with an f32[1,1,2,2] kernel using a base
// dilation of 2 and a high padding of 1 on both window dimensions, and checks
// the f32[1,1,7,7] result.
TEST_P(HloEvaluatorBf16Test, DilatedBaseConv2DWithHighPadding) {
  HloComputation::Builder builder(TestName());

  // clang-format off
  const Array2D<float> input_yx({
    {1,  2,  3,  4 },
    {5,  6,  7,  8 },
    {9,  10, 11, 12},
    {13, 14, 15, 16},
  });
  // clang-format on
  Array4D<float> input_values(1, 1, 4, 4);
  input_values.FillWithYX(input_yx);
  HloInstruction* input = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR4FromArray4D<float>(input_values)));

  // clang-format off
  const Array2D<float> kernel_yx({
    {5, 6},
    {7, 8},
  });
  // clang-format on
  Array4D<float> kernel_values(1, 1, 2, 2);
  kernel_values.FillWithYX(kernel_yx);
  HloInstruction* kernel =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR4FromArray4D<float>(kernel_values)));

  // Both window dimensions share the same configuration.
  WindowDimension window_dim;
  window_dim.set_size(2);
  window_dim.set_stride(1);
  window_dim.set_padding_low(0);
  window_dim.set_padding_high(1);
  window_dim.set_window_dilation(1);
  window_dim.set_base_dilation(2);
  Window window;
  *window.add_dimensions() = window_dim;
  *window.add_dimensions() = window_dim;

  const ConvolutionDimensionNumbers conv_dims =
      XlaBuilder::CreateDefaultConvDimensionNumbers(2);

  const Shape result_shape = ShapeUtil::MakeShape(F32, {1, 1, 7, 7});
  builder.AddInstruction(HloInstruction::CreateConvolve(
      result_shape, input, kernel, /*feature_group_count=*/1,
      /*batch_group_count=*/1, window, conv_dims, DefaultPrecisionConfig(2)));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  const Array2D<float> expected_yx({
      {5, 12, 10, 18, 15, 24, 20},
      {35, 48, 42, 56, 49, 64, 56},
      {25, 36, 30, 42, 35, 48, 40},
      {63, 80, 70, 88, 77, 96, 84},
      {45, 60, 50, 66, 55, 72, 60},
      {91, 112, 98, 120, 105, 128, 112},
      {65, 84, 70, 90, 75, 96, 80},
  });
  Array4D<float> expected_values(1, 1, 7, 7);
  expected_values.FillWithYX(expected_yx);
  auto expected = LiteralUtil::CreateR4FromArray4D<float>(expected_values);

  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
1225
// Convolves an f32[1,1,4,4] input with an f32[1,1,2,2] kernel using a base
// dilation of 2 and a padding of 1 on both sides of both window dimensions,
// and checks the f32[1,1,8,8] result.
TEST_P(HloEvaluatorBf16Test, DilatedBaseConv2DWithLowAndHighPadding) {
  HloComputation::Builder builder(TestName());

  // clang-format off
  const Array2D<float> input_yx({
    {1,  2,  3,  4 },
    {5,  6,  7,  8 },
    {9,  10, 11, 12},
    {13, 14, 15, 16},
  });
  // clang-format on
  Array4D<float> input_values(1, 1, 4, 4);
  input_values.FillWithYX(input_yx);
  HloInstruction* input = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR4FromArray4D<float>(input_values)));

  // clang-format off
  const Array2D<float> kernel_yx({
    {5, 6},
    {7, 8},
  });
  // clang-format on
  Array4D<float> kernel_values(1, 1, 2, 2);
  kernel_values.FillWithYX(kernel_yx);
  HloInstruction* kernel =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR4FromArray4D<float>(kernel_values)));

  // Both window dimensions share the same configuration.
  WindowDimension window_dim;
  window_dim.set_size(2);
  window_dim.set_stride(1);
  window_dim.set_padding_low(1);
  window_dim.set_padding_high(1);
  window_dim.set_window_dilation(1);
  window_dim.set_base_dilation(2);
  Window window;
  *window.add_dimensions() = window_dim;
  *window.add_dimensions() = window_dim;

  const ConvolutionDimensionNumbers conv_dims =
      XlaBuilder::CreateDefaultConvDimensionNumbers(2);

  const Shape result_shape = ShapeUtil::MakeShape(F32, {1, 1, 8, 8});
  builder.AddInstruction(HloInstruction::CreateConvolve(
      result_shape, input, kernel, /*feature_group_count=*/1,
      /*batch_group_count=*/1, window, conv_dims, DefaultPrecisionConfig(2)));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  const Array2D<float> expected_yx({
      {8, 7, 16, 14, 24, 21, 32, 28},
      {6, 5, 12, 10, 18, 15, 24, 20},
      {40, 35, 48, 42, 56, 49, 64, 56},
      {30, 25, 36, 30, 42, 35, 48, 40},
      {72, 63, 80, 70, 88, 77, 96, 84},
      {54, 45, 60, 50, 66, 55, 72, 60},
      {104, 91, 112, 98, 120, 105, 128, 112},
      {78, 65, 84, 70, 90, 75, 96, 80},
  });
  Array4D<float> expected_values(1, 1, 8, 8);
  expected_values.FillWithYX(expected_yx);
  auto expected = LiteralUtil::CreateR4FromArray4D<float>(expected_values);

  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
1290
// Convolves an f32[1,1,4,4] input with an f32[1,1,2,3] kernel where the two
// window dimensions differ in size, stride, padding (including a negative
// high padding) and dilation, and checks the f32[1,1,9,3] result.
TEST_P(HloEvaluatorBf16Test,
       DilatedWindowAndBaseConv2DWithDifferentLowAndHighPaddingAndStrides) {
  HloComputation::Builder builder(TestName());

  // clang-format off
  const Array2D<float> input_yx({
    {1,  2,  3,  4 },
    {5,  6,  7,  8 },
    {9,  10, 11, 12},
    {13, 14, 15, 16},
  });
  // clang-format on
  Array4D<float> input_values(1, 1, 4, 4);
  input_values.FillWithYX(input_yx);
  HloInstruction* input = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR4FromArray4D<float>(input_values)));

  // clang-format off
  const Array2D<float> kernel_yx({
    {5, 6, 7},
    {8, 9, 10},
  });
  // clang-format on
  Array4D<float> kernel_values(1, 1, 2, 3);
  kernel_values.FillWithYX(kernel_yx);
  HloInstruction* kernel =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR4FromArray4D<float>(kernel_values)));

  // First window dimension: size 2, stride 1, padding (2, 2), window and base
  // dilation of 2.
  WindowDimension first_dim;
  first_dim.set_size(2);
  first_dim.set_stride(1);
  first_dim.set_padding_low(2);
  first_dim.set_padding_high(2);
  first_dim.set_window_dilation(2);
  first_dim.set_base_dilation(2);

  // Second window dimension: size 3, stride 3, padding (2, -1), window
  // dilation 1 and base dilation 3.
  WindowDimension second_dim;
  second_dim.set_size(3);
  second_dim.set_stride(3);
  second_dim.set_padding_low(2);
  second_dim.set_padding_high(-1);
  second_dim.set_window_dilation(1);
  second_dim.set_base_dilation(3);

  Window window;
  *window.add_dimensions() = first_dim;
  *window.add_dimensions() = second_dim;

  const ConvolutionDimensionNumbers conv_dims =
      XlaBuilder::CreateDefaultConvDimensionNumbers(2);

  const Shape result_shape = ShapeUtil::MakeShape(F32, {1, 1, 9, 3});
  builder.AddInstruction(HloInstruction::CreateConvolve(
      result_shape, input, kernel, /*feature_group_count=*/1,
      /*batch_group_count=*/1, window, conv_dims, DefaultPrecisionConfig(2)));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  const Array2D<float> expected_yx({
      {10, 20, 30},
      {0, 0, 0},
      {57, 74, 91},
      {0, 0, 0},
      {125, 142, 159},
      {0, 0, 0},
      {193, 210, 227},
      {0, 0, 0},
      {91, 98, 105},
  });
  Array4D<float> expected_values(1, 1, 9, 3);
  expected_values.FillWithYX(expected_yx);
  auto expected = LiteralUtil::CreateR4FromArray4D<float>(expected_values);

  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
1363
// Evaluates a grouped convolution (feature_group_count = 2) of an
// f32[1,2,2,4] input with an f32[2,2,2,8] filter and checks the f32[1,1,1,8]
// result.
TEST_P(HloEvaluatorBf16Test, Conv2DGroupedConvolution) {
  HloComputation::Builder builder(TestName());

  const std::vector<int64> input_dims = {1, 2, 2, 4};
  const std::vector<int64> filter_dims = {2, 2, 2, 8};
  const Shape input_shape = ShapeUtil::MakeShapeWithType<float>(input_dims);
  const Shape filter_shape = ShapeUtil::MakeShapeWithType<float>(filter_dims);

  // Input elements are consecutive values starting at -7.
  std::vector<float> input_values(ShapeUtil::ElementsIn(input_shape));
  std::iota(input_values.begin(), input_values.end(), -7);
  auto input_flat = LiteralUtil::CreateR1<float>(input_values);
  auto input_literal = input_flat.Reshape(input_dims).ConsumeValueOrDie();
  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateConstant(std::move(input_literal)));

  // Filter elements are consecutive values starting at -31.
  std::vector<float> filter_values(ShapeUtil::ElementsIn(filter_shape));
  std::iota(filter_values.begin(), filter_values.end(), -31);
  auto filter_flat = LiteralUtil::CreateR1<float>(filter_values);
  auto filter_literal = filter_flat.Reshape(filter_dims).ConsumeValueOrDie();
  HloInstruction* filter = builder.AddInstruction(
      HloInstruction::CreateConstant(std::move(filter_literal)));

  // Tensorflow dimension numbers for 2D convolution.
  ConvolutionDimensionNumbers conv_dims;
  conv_dims.set_input_batch_dimension(0);
  conv_dims.add_input_spatial_dimensions(1);
  conv_dims.add_input_spatial_dimensions(2);
  conv_dims.set_input_feature_dimension(3);
  conv_dims.set_output_batch_dimension(0);
  conv_dims.add_output_spatial_dimensions(1);
  conv_dims.add_output_spatial_dimensions(2);
  conv_dims.set_output_feature_dimension(3);
  conv_dims.add_kernel_spatial_dimensions(0);
  conv_dims.add_kernel_spatial_dimensions(1);
  conv_dims.set_kernel_input_feature_dimension(2);
  conv_dims.set_kernel_output_feature_dimension(3);

  // Both window dimensions share the same configuration.
  WindowDimension window_dim;
  window_dim.set_size(2);
  window_dim.set_stride(1);
  window_dim.set_padding_low(0);
  window_dim.set_padding_high(0);
  window_dim.set_window_dilation(1);
  window_dim.set_base_dilation(1);
  Window window;
  *window.add_dimensions() = window_dim;
  *window.add_dimensions() = window_dim;

  const Shape result_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 8});
  builder.AddInstruction(HloInstruction::CreateConvolve(
      result_shape, input, filter,
      /*feature_group_count=*/2, /*batch_group_count=*/1, window, conv_dims,
      DefaultPrecisionConfig(2)));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  const Array2D<float> expected_yx(
      {{668, 664, 660, 656, 668, 680, 692, 704}});
  Array4D<float> expected_values(1, 1, 1, 8);
  expected_values.FillWithYX(expected_yx);
  auto expected = LiteralUtil::CreateR4FromArray4D<float>(expected_values);
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
1425
// Test fixture for the reduce-precision test that follows; it derives from
// HloTestBase and adds no members of its own.
class HloEvaluatorPreciseReduceTest : public HloTestBase {};
1427
1428 // Tests that Reduce doesn't lose precision when adding many numbers (because
1429 // it accumulates its result in a double).
TEST_F(HloEvaluatorPreciseReduceTest,AddReductionPrecisionTest)1430 TEST_F(HloEvaluatorPreciseReduceTest, AddReductionPrecisionTest) {
1431 auto m = CreateNewVerifiedModule();
1432 HloComputation::Builder b(TestName());
1433
1434 constexpr int kNumElements = 1 << 25; // float += 1 saturates at 1<<24
1435 std::vector<float> v(kNumElements, 1.0f);
1436 HloInstruction* arg_instruction = b.AddInstruction(
1437 HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(v)));
1438 HloInstruction* init_value = b.AddInstruction(
1439 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.f)));
1440
1441 HloComputation::Builder add_computation("add");
1442 Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
1443 auto param_lhs = add_computation.AddInstruction(
1444 HloInstruction::CreateParameter(0, scalar_shape, "lhs"));
1445 auto param_rhs = add_computation.AddInstruction(
1446 HloInstruction::CreateParameter(1, scalar_shape, "rhs"));
1447 add_computation.AddInstruction(HloInstruction::CreateBinary(
1448 scalar_shape, HloOpcode::kAdd, param_lhs, param_rhs));
1449 auto add_func = m->AddEmbeddedComputation(add_computation.Build());
1450
1451 HloInstruction* reduce_instruction = b.AddInstruction(
1452 HloInstruction::CreateReduce(scalar_shape, arg_instruction, init_value,
1453 /*dimensions_to_reduce=*/{0}, add_func));
1454 m->AddEntryComputation(b.Build());
1455
1456 HloEvaluator hlo_eval;
1457 Literal result = hlo_eval.Evaluate(reduce_instruction).ConsumeValueOrDie();
1458 LiteralTestUtil::ExpectR0Equal<float>(kNumElements, result);
1459 }
1460
1461 // Reducing many numbers should be fast because it doesn't create
1462 // intermediate Literals; the microbenchmark should finish in < 1 msec.
BM_ReducePrecisely(int num_iters)1463 void BM_ReducePrecisely(int num_iters) {
1464 tensorflow::testing::StopTiming();
1465 HloComputation::Builder b("BM_ReducePrecisely");
1466 HloModuleConfig config;
1467 config.set_debug_options(GetDebugOptionsFromFlags());
1468 HloModule module("BM_ReducePrecisely", config);
1469
1470 constexpr int kNumElements = 1 << 25; // float += 1 saturates at 1<<24
1471 std::vector<float> v(kNumElements, 1.0f);
1472 HloInstruction* arg_instruction = b.AddInstruction(
1473 HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(v)));
1474 auto init_value = b.AddInstruction(
1475 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.f)));
1476
1477 HloComputation::Builder add_computation("add");
1478 Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
1479 auto param_lhs = add_computation.AddInstruction(
1480 HloInstruction::CreateParameter(0, scalar_shape, "lhs"));
1481 auto param_rhs = add_computation.AddInstruction(
1482 HloInstruction::CreateParameter(1, scalar_shape, "rhs"));
1483 add_computation.AddInstruction(HloInstruction::CreateBinary(
1484 scalar_shape, HloOpcode::kAdd, param_lhs, param_rhs));
1485 auto add_func = module.AddEmbeddedComputation(add_computation.Build());
1486
1487 HloInstruction* reduce_instruction = b.AddInstruction(
1488 HloInstruction::CreateReduce(scalar_shape, arg_instruction, init_value,
1489 /*dimensions_to_reduce=*/{0}, add_func));
1490 module.AddEntryComputation(b.Build());
1491
1492 HloEvaluator hlo_eval;
1493 tensorflow::testing::StartTiming();
1494 hlo_eval.Evaluate(reduce_instruction).ConsumeValueOrDie();
1495 tensorflow::testing::StopTiming();
1496 }
1497
1498 BENCHMARK(BM_ReducePrecisely);
1499
// Reduces an f32[2,3] constant along dimension 1 with an "add" computation
// and checks the resulting f32[2] row sums.
TEST_P(HloEvaluatorBf16Test, ReduceAdd) {
  HloComputation::Builder builder(TestName());

  // Operand:
  // f32[2,3] {
  //  { 1, 2, 3 },
  //  { 5, 6, 7 },
  // }
  Array2D<float> operand_values(2, 3);
  operand_values.FillUnique(1.0f);
  HloInstruction* operand =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2FromArray2D<float>(operand_values)));

  // The reduction starts from 0.
  HloInstruction* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.f)));

  // Scalar computation (lhs, rhs) -> lhs + rhs used as the reduction function.
  const Shape scalar = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder add_builder("add");
  HloInstruction* lhs = add_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar, "lhs"));
  HloInstruction* rhs = add_builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar, "rhs"));
  add_builder.AddInstruction(
      HloInstruction::CreateBinary(scalar, HloOpcode::kAdd, lhs, rhs));
  auto add_computation = m_->AddEmbeddedComputation(add_builder.Build());

  const Shape result_shape = ShapeUtil::MakeShape(F32, {2});
  builder.AddInstruction(HloInstruction::CreateReduce(
      result_shape, operand, zero, /*dimensions_to_reduce=*/{1},
      add_computation));

  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  // 1 + 2 + 3 = 6 and 5 + 6 + 7 = 18.
  auto expected = LiteralUtil::CreateR1<float>({6, 18});

  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
1541
// Applies a 2x2, stride-1 reduce-window with a "max" computation to an
// f32[2,3] constant and checks the f32[1,2] result.
TEST_P(HloEvaluatorBf16Test, ReduceWindowMax) {
  HloComputation::Builder builder(TestName());

  // Operand:
  // f32[2,3] {
  //  { 1, 2, 3 },
  //  { 5, 6, 7 },
  // }
  Array2D<float> operand_values(2, 3);
  operand_values.FillUnique(1.0f);
  HloInstruction* operand =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2FromArray2D<float>(operand_values)));

  // The reduction starts from 0.
  HloInstruction* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.f)));

  // Scalar computation (lhs, rhs) -> maximum(lhs, rhs) used as the reduction
  // function.
  const Shape scalar = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder max_builder("max");
  HloInstruction* lhs = max_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar, "lhs"));
  HloInstruction* rhs = max_builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar, "rhs"));
  max_builder.AddInstruction(
      HloInstruction::CreateBinary(scalar, HloOpcode::kMaximum, lhs, rhs));
  auto max_computation = m_->AddEmbeddedComputation(max_builder.Build());

  // Both window dimensions share the same configuration.
  WindowDimension window_dim;
  window_dim.set_size(2);
  window_dim.set_stride(1);
  window_dim.set_padding_low(0);
  window_dim.set_padding_high(0);
  window_dim.set_window_dilation(1);
  window_dim.set_base_dilation(1);
  Window window;
  *window.add_dimensions() = window_dim;
  *window.add_dimensions() = window_dim;

  const Shape result_shape = ShapeUtil::MakeShape(F32, {1, 2});
  builder.AddInstruction(HloInstruction::CreateReduceWindow(
      result_shape, operand, zero, window, max_computation));

  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  auto expected = LiteralUtil::CreateR2<float>({{6, 7}});
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
1592
TEST_P(HloEvaluatorBf16Test, ReduceWindowMaxWindowDilation) {
  // Max-reduces a size-2 window with window dilation 2 over a constant 3x3
  // operand and checks the resulting 1x1 literal.
  HloComputation::Builder entry_builder(TestName());

  // Operand:
  // f32[3,3] {
  //  { 1, 2, 3 },
  //  { 5, 6, 7 },
  //  { 9, 10, 11 },
  // }
  Array2D<float> input_values(3, 3);
  input_values.FillUnique(1.0f);
  HloInstruction* input =
      entry_builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2FromArray2D<float>(input_values)));

  HloInstruction* init =
      entry_builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR0<float>(0.f)));

  // Scalar reducer: max(lhs, rhs).
  const Shape f32_scalar = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder reducer_builder("max");
  HloInstruction* lhs = reducer_builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_scalar, "lhs"));
  HloInstruction* rhs = reducer_builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32_scalar, "rhs"));
  reducer_builder.AddInstruction(HloInstruction::CreateBinary(
      f32_scalar, HloOpcode::kMaximum, lhs, rhs));
  HloComputation* reducer =
      m_->AddEmbeddedComputation(reducer_builder.Build());

  // Both operand dimensions use the same dilated window configuration.
  WindowDimension window_dim;
  window_dim.set_size(2);
  window_dim.set_stride(1);
  window_dim.set_padding_low(0);
  window_dim.set_padding_high(0);
  window_dim.set_window_dilation(2);
  window_dim.set_base_dilation(1);
  Window window;
  for (int i = 0; i < 2; ++i) {
    *window.add_dimensions() = window_dim;
  }

  entry_builder.AddInstruction(HloInstruction::CreateReduceWindow(
      ShapeUtil::MakeShape(F32, {1, 1}), input, init, window, reducer));
  m_->AddEntryComputation(entry_builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  EXPECT_TRUE(
      LiteralTestUtil::Equal(LiteralUtil::CreateR2<float>({{11}}), result));
}
1644
TEST_P(HloEvaluatorBf16Test, ReduceWindowAdd) {
  // Sum-reduces a 1x2 window over a constant 2x3 operand, with one element of
  // low padding on the second dimension, and checks the 2x3 result.
  HloComputation::Builder entry_builder(TestName());

  // Operand:
  // f32[2,3] {
  //  { 1, 2, 3 },
  //  { 5, 6, 7 },
  // }
  Array2D<float> input_values(2, 3);
  input_values.FillUnique(1.0f);
  HloInstruction* input =
      entry_builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2FromArray2D<float>(input_values)));

  HloInstruction* init =
      entry_builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR0<float>(0.f)));

  // Scalar reducer: lhs + rhs.
  const Shape f32_scalar = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder reducer_builder("add");
  HloInstruction* lhs = reducer_builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_scalar, "lhs"));
  HloInstruction* rhs = reducer_builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32_scalar, "rhs"));
  reducer_builder.AddInstruction(
      HloInstruction::CreateBinary(f32_scalar, HloOpcode::kAdd, lhs, rhs));
  HloComputation* reducer =
      m_->AddEmbeddedComputation(reducer_builder.Build());

  // Dimension 0: size-1 window, no padding.
  WindowDimension row_dim;
  row_dim.set_size(1);
  row_dim.set_stride(1);
  row_dim.set_padding_low(0);
  row_dim.set_padding_high(0);
  row_dim.set_window_dilation(1);
  row_dim.set_base_dilation(1);

  // Dimension 1: size-2 window with one element of low padding.
  WindowDimension col_dim;
  col_dim.set_size(2);
  col_dim.set_stride(1);
  col_dim.set_padding_low(1);
  col_dim.set_padding_high(0);
  col_dim.set_window_dilation(1);
  col_dim.set_base_dilation(1);

  Window window;
  *window.add_dimensions() = row_dim;
  *window.add_dimensions() = col_dim;

  entry_builder.AddInstruction(HloInstruction::CreateReduceWindow(
      ShapeUtil::MakeShape(F32, {2, 3}), input, init, window, reducer));
  m_->AddEntryComputation(entry_builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  EXPECT_TRUE(LiteralTestUtil::Equal(
      LiteralUtil::CreateR2<float>({{1, 3, 5}, {5, 11, 13}}), result));
}
1701
TEST_P(HloEvaluatorBf16Test, ReduceWindowAdd6D) {
  // Sum-reduces a 6D operand of ones with size-2 windows on dimensions 1-3
  // and size-1 windows elsewhere; every output element should be 2*2*2 = 8.
  HloComputation::Builder entry_builder(TestName());

  // Operand: f32[4,4,4,4,4,4] full of ones. Small dims keep run-time low.
  const std::vector<int64> operand_dims(6, 4);
  HloInstruction* input =
      entry_builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateFullWithDescendingLayout<float>(operand_dims,
                                                             1.0f)));

  HloInstruction* init =
      entry_builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR0<float>(0.f)));

  // Scalar reducer: lhs + rhs.
  const Shape f32_scalar = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder reducer_builder("add");
  HloInstruction* lhs = reducer_builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_scalar, "lhs"));
  HloInstruction* rhs = reducer_builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32_scalar, "rhs"));
  reducer_builder.AddInstruction(
      HloInstruction::CreateBinary(f32_scalar, HloOpcode::kAdd, lhs, rhs));
  HloComputation* reducer =
      m_->AddEmbeddedComputation(reducer_builder.Build());

  // Builds an unpadded, undilated, stride-1 window dimension of `size`.
  const auto make_dim = [](int64 size) {
    WindowDimension dim;
    dim.set_size(size);
    dim.set_stride(1);
    dim.set_padding_low(0);
    dim.set_padding_high(0);
    dim.set_window_dilation(1);
    dim.set_base_dilation(1);
    return dim;
  };
  const WindowDimension trivial_dim = make_dim(1);
  const WindowDimension active_dim = make_dim(2);

  // Dimensions 1, 2 and 3 are reduced; 0, 4 and 5 pass through.
  Window window;
  for (bool is_active : {false, true, true, true, false, false}) {
    *window.add_dimensions() = is_active ? active_dim : trivial_dim;
  }

  const std::vector<int64> result_dims = {4, 3, 3, 3, 4, 4};
  entry_builder.AddInstruction(HloInstruction::CreateReduceWindow(
      ShapeUtil::MakeShape(F32, result_dims), input, init, window, reducer));
  m_->AddEntryComputation(entry_builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  Literal expected =
      LiteralUtil::CreateFullWithDescendingLayout<float>(result_dims, 8.0f);
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
1764
TEST_P(HloEvaluatorBf16Test, StridedSlice) {
  // Slices a constant 3x5 operand with strides {2, 3} starting at {0, 2} and
  // checks the 2x1 result.
  HloComputation::Builder entry_builder(TestName());

  // Operand:
  // f32[3,5] {
  //  { 1, 2, 3, 4, 5 },
  //  { 9, 10, 11, 12, 13 },
  //  { 17, 18, 19, 20, 21 },
  // }
  Array2D<float> input_values(3, 5);
  input_values.FillUnique(1.0f);
  HloInstruction* input =
      entry_builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2FromArray2D<float>(input_values)));

  const Shape result_shape = ShapeUtil::MakeShape(F32, {2, 1});
  entry_builder.AddInstruction(
      HloInstruction::CreateSlice(result_shape, input,
                                  /*start_indices=*/{0, 2},
                                  /*limit_indices=*/{3, 5},
                                  /*strides=*/{2, 3}));
  m_->AddEntryComputation(entry_builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  // Rows 0 and 2, column 2.
  EXPECT_TRUE(LiteralTestUtil::Equal(
      LiteralUtil::CreateR2<float>({{3}, {19}}), result));
}
1798
TEST_P(HloEvaluatorBf16Test, DynamicSlice) {
  // Takes a 2x3 dynamic slice starting at {0, 1} from a constant 2x4 operand.
  HloComputation::Builder entry_builder(TestName());

  // Operand:
  // f32[2,4] {
  //  { 1, 2, 3, 4 },
  //  { 5, 6, 7, 8 },
  // }
  Array2D<float> input_values(2, 4);
  input_values.FillUnique(1.0f);
  HloInstruction* input =
      entry_builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2FromArray2D<float>(input_values)));

  // Scalar start indices, one per operand dimension.
  HloInstruction* row_start = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
  HloInstruction* col_start = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));

  const Shape result_shape = ShapeUtil::MakeShape(F32, {2, 3});
  entry_builder.AddInstruction(HloInstruction::CreateDynamicSlice(
      result_shape, input, {row_start, col_start}, {2, 3}));
  m_->AddEntryComputation(entry_builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  EXPECT_TRUE(LiteralTestUtil::Equal(
      LiteralUtil::CreateR2<float>({{2, 3, 4}, {6, 7, 8}}), result));
}
1834
1835 // Verifies that the HloEvaluator's implementation goes along with existing
1836 // backends' behavior, although this is not required by the spec.
TEST_P(HloEvaluatorBf16Test, DynamicSliceModSlice) {
  // Requests a 2x3 dynamic slice starting at {2, 1} from a 2x4 operand. The
  // row start does not fit (2 + 2 > 2); the expected result is the same slice
  // that a row start of 0 would produce.
  HloComputation::Builder entry_builder(TestName());

  // Operand:
  // f32[2,4] {
  //  { 1, 2, 3, 4 },
  //  { 5, 6, 7, 8 },
  // }
  Array2D<float> input_values(2, 4);
  input_values.FillUnique(1.0f);
  HloInstruction* input =
      entry_builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2FromArray2D<float>(input_values)));

  // Scalar start indices, one per operand dimension.
  HloInstruction* row_start = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(2)));
  HloInstruction* col_start = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));

  const Shape result_shape = ShapeUtil::MakeShape(F32, {2, 3});
  entry_builder.AddInstruction(HloInstruction::CreateDynamicSlice(
      result_shape, input, {row_start, col_start}, {2, 3}));
  m_->AddEntryComputation(entry_builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  EXPECT_TRUE(LiteralTestUtil::Equal(
      LiteralUtil::CreateR2<float>({{2, 3, 4}, {6, 7, 8}}), result));
}
1872
TEST_P(HloEvaluatorBf16Test, DynamicSliceUpdate) {
  // Overwrites the 2x2 region starting at {0, 1} of a constant 2x3 operand
  // with a constant update and checks the resulting literal.
  HloComputation::Builder entry_builder(TestName());

  // Operand:
  // f64[2,3] {
  //  { 1, 2, 3 },
  //  { 5, 6, 7 },
  // }
  Array2D<double> input_values(2, 3);
  input_values.FillUnique(1.0);
  HloInstruction* input =
      entry_builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2FromArray2D<double>(input_values)));

  // Scalar start indices, one per operand dimension.
  HloInstruction* row_start = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
  HloInstruction* col_start = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));

  HloInstruction* patch =
      entry_builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2<double>({{-2.0, -3.0}, {-6.0, -7.0}})));

  const Shape result_shape = ShapeUtil::MakeShape(F64, {2, 3});
  entry_builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
      result_shape, input, patch, {row_start, col_start}));
  m_->AddEntryComputation(entry_builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  // Column 0 is untouched; columns 1-2 carry the update values.
  EXPECT_TRUE(LiteralTestUtil::Equal(
      LiteralUtil::CreateR2<double>({{1, -2, -3}, {5, -6, -7}}), result));
}
1911
TEST_P(HloEvaluatorBf16Test, SetAndGetTuples) {
  // Packs two constants into a tuple and reads element 1 back out.
  HloComputation::Builder entry_builder(TestName());

  // Second tuple element:
  // f64[2,3] {
  //  { 1, 2, 3 },
  //  { 5, 6, 7 },
  // }
  Array2D<double> matrix_values(2, 3);
  matrix_values.FillUnique(1.0);
  HloInstruction* matrix =
      entry_builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2FromArray2D<double>(matrix_values)));
  // First tuple element: s64[2] {0, 1}.
  HloInstruction* vector = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<int64>({0, 1})));

  HloInstruction* packed = entry_builder.AddInstruction(
      HloInstruction::CreateTuple({vector, matrix}));

  const Shape element_shape = ShapeUtil::MakeShape(F64, {2, 3});
  entry_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(element_shape, packed, 1));

  m_->AddEntryComputation(entry_builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  EXPECT_TRUE(LiteralTestUtil::Equal(
      LiteralUtil::CreateR2<double>({{1, 2, 3}, {5, 6, 7}}), result));
}
1947
TEST_P(HloEvaluatorBf16Test, SetAndGetNestedTuples) {
  // Builds a tuple of two tuples and reads the second inner tuple back out.
  HloComputation::Builder entry_builder(TestName());

  // Matrix operand:
  // f64[2,3] {
  //  { 1, 2, 3 },
  //  { 5, 6, 7 },
  // }
  Array2D<double> matrix_values(2, 3);
  matrix_values.FillUnique(1.0);

  HloInstruction* matrix =
      entry_builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2FromArray2D<double>(matrix_values)));
  HloInstruction* vector = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<int64>({0, 1})));

  HloInstruction* mixed_pair = entry_builder.AddInstruction(
      HloInstruction::CreateTuple({vector, matrix}));
  HloInstruction* matrix_pair = entry_builder.AddInstruction(
      HloInstruction::CreateTuple({matrix, matrix}));

  HloInstruction* nested = entry_builder.AddInstruction(
      HloInstruction::CreateTuple({mixed_pair, matrix_pair}));

  entry_builder.AddInstruction(HloInstruction::CreateGetTupleElement(
      matrix_pair->shape(), nested, 1));

  m_->AddEntryComputation(entry_builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  // Element 1 of the outer tuple is (matrix, matrix).
  Literal matrix_literal =
      LiteralUtil::CreateR2FromArray2D<double>(matrix_values);
  Literal expected =
      LiteralUtil::MakeTuple({&matrix_literal, &matrix_literal});

  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
1986
TEST_P(HloEvaluatorBf16Test, Reverse) {
  // Reverses a constant f32[4,3,2,1] operand along dimensions 0 and 1 and
  // checks the resulting literal.
  HloComputation::Builder entry_builder(TestName());

  // clang-format off
  Array4D<float> input_values({
    {{{1.0f}, {2.0f}},
     {{3.0f}, {4.0f}},
     {{5.0f}, {6.0f}}},
    {{{7.0f}, {8.0f}},
     {{9.0f}, {10.0f}},
     {{11.0f}, {12.0f}}},
    {{{13.0f}, {14.0f}},
     {{15.0f}, {16.0f}},
     {{17.0f}, {18.0f}}},
    {{{19.0f}, {20.0f}},
     {{21.0f}, {22.0f}},
     {{23.0f}, {24.0f}}},
  });
  // clang-format on
  HloInstruction* input =
      entry_builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR4FromArray4D<float>(input_values)));

  const Shape result_shape = ShapeUtil::MakeShape(F32, {4, 3, 2, 1});
  entry_builder.AddInstruction(
      HloInstruction::CreateReverse(result_shape, input, {0, 1}));
  m_->AddEntryComputation(entry_builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  // Outermost two dimensions are flipped; the innermost pairs keep their
  // order.
  // clang-format off
  Literal expected = LiteralUtil::CreateR4FromArray4D<float>({
    {{{23.0f}, {24.0f}},
     {{21.0f}, {22.0f}},
     {{19.0f}, {20.0f}}},

    {{{17.0f}, {18.0f}},
     {{15.0f}, {16.0f}},
     {{13.0f}, {14.0f}}},

    {{{11.0f}, {12.0f}},
     {{9.0f}, {10.0f}},
     {{7.0f}, {8.0f}}},

    {{{5.0f}, {6.0f}},
     {{3.0f}, {4.0f}},
     {{1.0f}, {2.0f}}},
  });
  // clang-format on

  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
2039
TEST_P(HloEvaluatorBf16Test, EvaluateWithSubstitutions) {
  // Builds add = param0 + (param0 * param0) and evaluates `add` with literal
  // substitutions supplied for both of its operands.
  HloComputation::Builder b(TestName());
  Shape shape = ShapeUtil::MakeShape(F32, {4});

  HloInstruction* param0 =
      b.AddInstruction(HloInstruction::CreateParameter(0, shape, "param0"));
  HloInstruction* square = b.AddInstruction(HloInstruction::CreateBinary(
      shape, HloOpcode::kMultiply, param0, param0));
  HloInstruction* add = b.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, square));

  // Evaluate add with param0 = {1, 2, 3, 4}, square = {10, 20, 30, 40}.
  // The substituted `square` literal is deliberately not param0 * param0, so
  // the expected sum below only holds if the substitution is honored.
  HloEvaluator evaluator;
  Literal param0_literal = LiteralUtil::CreateR1<float>({1, 2, 3, 4});
  Literal square_literal = LiteralUtil::CreateR1<float>({10, 20, 30, 40});
  TF_ASSERT_OK_AND_ASSIGN(
      Literal result,
      evaluator.EvaluateWithSubstitutions(
          add, {{param0, &param0_literal}, {square, &square_literal}}));
  EXPECT_TRUE(LiteralTestUtil::Equal(
      LiteralUtil::CreateR1<float>({11, 22, 33, 44}), result));
}
2062
2063 // Check that EvaluateWithSubstitutions works if one of the operands to the op
2064 // we're evaluating is a constant.
TEST_P(HloEvaluatorBf16Test, EvaluateWithSubstitutionsWithConstantOperand) {
  // Builds sum = constant + (param0 * param0) and evaluates `sum` with a
  // literal substituted for the product only; the constant operand gets no
  // substitution.
  HloComputation::Builder entry_builder(TestName());
  const Shape vec4 = ShapeUtil::MakeShape(F32, {4});

  HloInstruction* input = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, vec4, "param0"));
  HloInstruction* product = entry_builder.AddInstruction(
      HloInstruction::CreateBinary(vec4, HloOpcode::kMultiply, input, input));
  HloInstruction* addend =
      entry_builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR1<float>({1, 2, 3, 4})));
  HloInstruction* sum = entry_builder.AddInstruction(
      HloInstruction::CreateBinary(vec4, HloOpcode::kAdd, addend, product));

  // Evaluate sum with product = {10, 20, 30, 40}.
  Literal product_literal = LiteralUtil::CreateR1<float>({10, 20, 30, 40});
  HloEvaluator evaluator;
  TF_ASSERT_OK_AND_ASSIGN(
      Literal result,
      evaluator.EvaluateWithSubstitutions(sum, {{product, &product_literal}}));

  Literal expected = LiteralUtil::CreateR1<float>({11, 22, 33, 44});
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
2087
TEST_F(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV1) {
  // Gathers whole rows (slice_sizes={1, 3}) of a 3x3 operand at row indices
  // 0 and 2.
  const char* const module_text = R"(
HloModule TensorFlowGatherV1

ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
ROOT gather = s32[2,3] gather(operand, indices),
offset_dims={1},
collapsed_slice_dims={0},
start_index_map={0},
index_vector_dim=1,
slice_sizes={1, 3}
}
)";
  Literal input =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal row_indices = LiteralUtil::CreateR1<int32>({0, 2});
  Literal expected = LiteralUtil::CreateR2<int32>({{1, 2, 3}, {7, 8, 9}});

  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));
  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input, &row_indices}));
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
2111
TEST_F(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV2) {
  // Gathers whole columns (slice_sizes={3, 1}) of a 3x3 operand at column
  // indices 0 and 2.
  const char* const module_text = R"(
HloModule TensorFlowGatherV2

ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
ROOT gather = s32[3,2] gather(operand, indices),
offset_dims={0},
collapsed_slice_dims={1},
start_index_map={1},
index_vector_dim=1,
slice_sizes={3, 1}
}
)";
  Literal input =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal column_indices = LiteralUtil::CreateR1<int32>({0, 2});
  Literal expected = LiteralUtil::CreateR2<int32>({{1, 3}, {4, 6}, {7, 9}});

  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));
  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input, &column_indices}));
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
2135
TEST_F(HloEvaluatorTest, EvaluateGather_TensorFlowGatherMultipleBatchDims) {
  // Gathers whole columns of a 3x3 operand using a 2x2 matrix of column
  // indices, producing an s32[2,3,2] result.
  const char* const module_text = R"(
HloModule TensorFlowGatherMultipleBatchDims

ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2,2] parameter(1)
ROOT gather = s32[2,3,2] gather(operand, indices),
offset_dims={1},
collapsed_slice_dims={1},
start_index_map={1},
index_vector_dim=2,
slice_sizes={3, 1}
}
)";
  Literal input =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal column_indices = LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
  Literal expected = LiteralUtil::CreateR3<int32>(
      {{{1, 3}, {4, 6}, {7, 9}}, {{3, 2}, {6, 5}, {9, 8}}});

  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));
  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input, &column_indices}));
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
2161
TEST_F(HloEvaluatorTest, EvaluateGather_TensorFlowGatherNd) {
  // Gathers length-2 innermost vectors of a 3x3x2 operand, addressed by
  // two-component index vectors stored along dimension 1 of `indices`.
  const char* const module_text = R"(
HloModule TensorFlowGatherNd

ENTRY main {
operand = s32[3,3,2] parameter(0)
indices = s32[2,2] parameter(1)
ROOT gather = s32[2,2] gather(operand, indices),
offset_dims={1},
collapsed_slice_dims={0,1},
start_index_map={0,1},
index_vector_dim=1,
slice_sizes={1,1,2}
}
)";
  Literal input =
      LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}},  //
                                    {{-4, 4}, {-5, 5}, {-6, 6}},  //
                                    {{-7, 7}, {-8, 8}, {-9, 9}}});
  Literal index_vectors = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
  // Index vectors (0,0) and (1,0).
  Literal expected = LiteralUtil::CreateR2<int32>({{-1, 1}, {-4, 4}});

  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));
  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input, &index_vectors}));
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
2187
TEST_F(HloEvaluatorTest,
       EvaluateGather_TensorFlowGatherNdNonDefaultIndexVectorDim) {
  // Same gather as the GatherNd case, but with index_vector_dim=0, so the
  // index vectors are read along dimension 0 of `indices`.
  const char* const module_text = R"(
HloModule TensorFlowGatherNd

ENTRY main {
operand = s32[3,3,2] parameter(0)
indices = s32[2,2] parameter(1)
ROOT gather = s32[2,2] gather(operand, indices),
offset_dims={1},
collapsed_slice_dims={0,1},
start_index_map={0,1},
index_vector_dim=0,
slice_sizes={1,1,2}
}
)";
  Literal input =
      LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}},  //
                                    {{-4, 4}, {-5, 5}, {-6, 6}},  //
                                    {{-7, 7}, {-8, 8}, {-9, 9}}});
  Literal index_vectors = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
  // Index vectors (0,1) and (0,0).
  Literal expected = LiteralUtil::CreateR2<int32>({{-2, 2}, {-1, 1}});

  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));
  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input, &index_vectors}));
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
2214
TEST_F(HloEvaluatorTest, EvaluateGather_DynamicSlice) {
  // Uses gather to extract a single 1x1 slice of a 3x3 operand at start
  // position (1, 1).
  const char* const module_text = R"(
HloModule DynamicSlice

ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
ROOT gather = s32[1,1] gather(operand, indices),
offset_dims={0,1},
collapsed_slice_dims={},
start_index_map={0,1},
index_vector_dim=0,
slice_sizes={1,1}
}
)";
  Literal input =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal slice_start = LiteralUtil::CreateR1<int32>({1, 1});
  Literal expected = LiteralUtil::CreateR2<int32>({{5}});

  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));
  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input, &slice_start}));
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
2238
TEST_F(HloEvaluatorTest, EvaluateGather_BatchDynamicSlice) {
  // Uses gather to extract two 1x1 slices of a 3x3 operand; index vectors are
  // read along dimension 0 of `indices`.
  const char* const module_text = R"(
HloModule BatchDynamicSlice

ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2,2] parameter(1)
ROOT gather = s32[2,1,1] gather(operand, indices),
offset_dims={1,2},
collapsed_slice_dims={},
start_index_map={0,1},
index_vector_dim=0,
slice_sizes={1,1}
}
)";
  Literal input =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal slice_starts = LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
  // Start positions (2,1) and (1,1).
  Literal expected = LiteralUtil::CreateR3<int32>({{{8}}, {{5}}});

  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));
  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input, &slice_starts}));
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
2262
TEST_F(HloEvaluatorTest, EvaluateGather_ZeroDimBounds) {
  // Row gather on an operand whose second dimension has size zero; the result
  // is an empty s32[2,0] literal.
  const char* const module_text = R"(
HloModule TensorFlowGatherV1

ENTRY main {
operand = s32[3,0] parameter(0)
indices = s32[2] parameter(1)
ROOT gather = s32[2,0] gather(operand, indices),
offset_dims={1},
collapsed_slice_dims={0},
start_index_map={0},
index_vector_dim=1,
slice_sizes={1, 0}
}
)";
  Literal input = LiteralUtil::CreateR2<int32>({{}, {}, {}});
  Literal row_indices = LiteralUtil::CreateR1<int32>({0, 2});
  Literal expected = LiteralUtil::CreateR2<int32>({{}, {}});

  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));
  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input, &row_indices}));
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
2285
TEST_F(HloEvaluatorTest, EvaluateGather_NoOutputWindowDims) {
  // Gathers scalars from a rank-1 operand with offset_dims={}, so the output
  // shape comes entirely from the batch dimensions of `indices`.
  const string module_text = R"(
HloModule GatherXd

ENTRY main {
operand = s32[3] parameter(0)
indices = s32[2,2,1] parameter(1)
ROOT gather = s32[2,2] gather(operand, indices),
offset_dims={},
collapsed_slice_dims={0},
start_index_map={0},
index_vector_dim=2,
slice_sizes={1}
}
)";
  Literal input = LiteralUtil::CreateR1<int32>({0, 1, 2});
  Literal element_indices =
      LiteralUtil::CreateR3<int32>({{{0}, {1}}, {{2}, {1}}});
  Literal expected = LiteralUtil::CreateR2<int32>({{0, 1}, {2, 1}});

  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));
  TF_ASSERT_OK_AND_ASSIGN(Literal result,
                          Evaluate({&input, &element_indices}));
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
2310
TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterV1_Update) {
  // Scatters whole rows into a 3x3 operand at row indices 0 and 2. The
  // combiner returns its second parameter, so the updates replace the
  // existing rows.
  const char* const module_text = R"(
HloModule TensorFlowScatterV1

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
lhs = s32[] parameter(0)
ROOT rhs = s32[] parameter(1)
}

ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
updates = s32[2,3] parameter(2)
ROOT scatter = s32[3,3] scatter(operand, indices, updates),
to_apply=update_s32,
update_window_dims={1},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0},
index_vector_dim=1
}
)";
  Literal input =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal row_indices = LiteralUtil::CreateR1<int32>({0, 2});
  Literal new_rows = LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
  Literal expected =
      LiteralUtil::CreateR2<int32>({{10, 20, 30}, {4, 5, 6}, {70, 80, 90}});

  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));
  TF_ASSERT_OK_AND_ASSIGN(Literal result,
                          Evaluate({&input, &row_indices, &new_rows}));
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
2343
TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterV2_Update) {
  // Scatters whole columns into a 3x3 operand at column indices 0 and 2. The
  // combiner returns its second parameter, so the updates replace the
  // existing columns.
  const char* const module_text = R"(
HloModule TensorFlowScatterV2

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
lhs = s32[] parameter(0)
ROOT rhs = s32[] parameter(1)
}

ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
updates = s32[3,2] parameter(2)
ROOT scatter = s32[3,3] scatter(operand, indices, updates),
to_apply=update_s32,
update_window_dims={0},
inserted_window_dims={1},
scatter_dims_to_operand_dims={1},
index_vector_dim=1
}
)";
  Literal input =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal column_indices = LiteralUtil::CreateR1<int32>({0, 2});
  Literal new_columns =
      LiteralUtil::CreateR2<int32>({{10, 30}, {40, 60}, {70, 90}});
  Literal expected =
      LiteralUtil::CreateR2<int32>({{10, 2, 30}, {40, 5, 60}, {70, 8, 90}});

  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));
  TF_ASSERT_OK_AND_ASSIGN(Literal result,
                          Evaluate({&input, &column_indices, &new_columns}));
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
2377
TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_Add) {
  // Scatters rows into a 3x3 operand at row indices 0 and 2, combining each
  // update with the existing value by addition.
  const char* const module_text = R"(
HloModule TensorFlowScatter

add_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
lhs = s32[] parameter(0)
rhs = s32[] parameter(1)
ROOT add = s32[] add(s32[] lhs, s32[] rhs)
}

ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
updates = s32[2,3] parameter(2)
ROOT scatter = s32[3,3] scatter(operand, indices, updates),
to_apply=add_s32,
update_window_dims={1},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0},
index_vector_dim=1
}
)";
  Literal input =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal row_indices = LiteralUtil::CreateR1<int32>({0, 2});
  Literal addends = LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
  Literal expected =
      LiteralUtil::CreateR2<int32>({{11, 22, 33}, {4, 5, 6}, {77, 88, 99}});

  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));
  TF_ASSERT_OK_AND_ASSIGN(Literal result,
                          Evaluate({&input, &row_indices, &addends}));
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
2411
TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_Mul) {
  // Scatters rows into a 3x3 operand at row indices 0 and 2, combining each
  // update with the existing value by multiplication.
  const char* const module_text = R"(
HloModule TensorFlowScatter

mul_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
lhs = s32[] parameter(0)
rhs = s32[] parameter(1)
ROOT mul = s32[] multiply(s32[] lhs, s32[] rhs)
}

ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
updates = s32[2,3] parameter(2)
ROOT scatter = s32[3,3] scatter(operand, indices, updates),
to_apply=mul_s32,
update_window_dims={1},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0},
index_vector_dim=1
}
)";
  Literal input =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal row_indices = LiteralUtil::CreateR1<int32>({0, 2});
  Literal factors = LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
  Literal expected =
      LiteralUtil::CreateR2<int32>({{10, 40, 90}, {4, 5, 6}, {490, 640, 810}});

  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));
  TF_ASSERT_OK_AND_ASSIGN(Literal result,
                          Evaluate({&input, &row_indices, &factors}));
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
2445
TEST_P(HloEvaluatorBf16Test, EvaluateScatter_TensorFlowScatter_F32) {
  // Floating-point variant of the additive row scatter. The result is checked
  // with an error tolerance rather than exact equality.
  const char* module_text = R"(
HloModule TensorFlowScatter

add_f32 (lhs: f32[], rhs: f32[]) -> f32[] {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT add = f32[] add(f32[] lhs, f32[] rhs)
}

ENTRY main {
  operand = f32[3,3] parameter(0)
  indices = s32[2] parameter(1)
  updates = f32[2,3] parameter(2)
  ROOT scatter = f32[3,3] scatter(operand, indices, updates),
      to_apply=add_f32,
      update_window_dims={1},
      inserted_window_dims={0},
      scatter_dims_to_operand_dims={0},
      index_vector_dim=1
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal base = LiteralUtil::CreateR2<float>(
      {{1.1, 2.2, 3.3}, {4.4, 5.5, 6.6}, {7.7, 8.8, 9.9}});
  // The first update row is directed at operand row 2, the second at row 1.
  Literal row_ids = LiteralUtil::CreateR1<int32>({2, 1});
  Literal deltas =
      LiteralUtil::CreateR2<float>({{0.4, 1.1, 0.7}, {2.3, 3.1, 1.6}});
  Literal want = LiteralUtil::CreateR2<float>(
      {{1.1, 2.2, 3.3}, {6.7, 8.6, 8.2}, {8.1, 9.9, 10.6}});
  const ErrorSpec tolerance{0.1, 0.01};

  TF_ASSERT_OK_AND_ASSIGN(Literal got, Evaluate({&base, &row_ids, &deltas}));
  EXPECT_TRUE(LiteralTestUtil::Near(want, got, tolerance));
}
2481
TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_RepeatedIndices) {
  // Both scatter indices name the same operand row, so both update rows must
  // be accumulated into it by the add combiner.
  const char* module_text = R"(
HloModule TensorFlowScatter

add_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(s32[] lhs, s32[] rhs)
}

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2] parameter(1)
  updates = s32[2,3] parameter(2)
  ROOT scatter = s32[3,3] scatter(operand, indices, updates),
      to_apply=add_s32,
      update_window_dims={1},
      inserted_window_dims={0},
      scatter_dims_to_operand_dims={0},
      index_vector_dim=1
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal base =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal row_ids = LiteralUtil::CreateR1<int32>({1, 1});
  Literal deltas =
      LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
  // Row 1 receives the sum of both update rows on top of its old contents.
  Literal want =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {84, 105, 126}, {7, 8, 9}});

  TF_ASSERT_OK_AND_ASSIGN(Literal got, Evaluate({&base, &row_ids, &deltas}));
  EXPECT_TRUE(LiteralTestUtil::Equal(want, got));
}
2515
TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_MultipleBatchDims) {
  // The indices form a 2x2 batch of scalar indices (index_vector_dim equals
  // the rank of the indices array), and the updates carry a 3-element window
  // for each of them.
  const char* module_text = R"(
HloModule TensorFlowScatterMultipleBatchDims

add_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(s32[] lhs, s32[] rhs)
}

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2,2] parameter(1)
  updates = s32[2,3,2] parameter(2)
  ROOT scatter = s32[3,3] scatter(operand, indices, updates),
      to_apply=add_s32,
      update_window_dims={1},
      inserted_window_dims={1},
      scatter_dims_to_operand_dims={1},
      index_vector_dim=2
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal base =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal column_ids = LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
  Literal deltas = LiteralUtil::CreateR3<int32>(
      {{{10, 30}, {40, 60}, {70, 90}}, {{5, 5}, {5, 5}, {5, 5}}});
  Literal want =
      LiteralUtil::CreateR2<int32>({{11, 7, 38}, {44, 10, 71}, {77, 13, 104}});

  TF_ASSERT_OK_AND_ASSIGN(Literal got,
                          Evaluate({&base, &column_ids, &deltas}));
  EXPECT_TRUE(LiteralTestUtil::Equal(want, got));
}
2550
TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterNd) {
  // Each row of the indices is a two-component index into the leading two
  // operand dimensions; the combiner returns its second parameter, so the
  // addressed 2-element slices are overwritten by the updates.
  const char* module_text = R"(
HloModule TensorFlowScatterNd

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  ROOT rhs = s32[] parameter(1)
}

ENTRY main {
  operand = s32[3,3,2] parameter(0)
  indices = s32[2,2] parameter(1)
  updates = s32[2,2] parameter(2)
  ROOT scatter = s32[3,3,2] scatter(operand, indices, updates),
      to_apply=update_s32,
      update_window_dims={1},
      inserted_window_dims={0,1},
      scatter_dims_to_operand_dims={0,1},
      index_vector_dim=1
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal base =
      LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}},  //
                                    {{-4, 4}, {-5, 5}, {-6, 6}},  //
                                    {{-7, 7}, {-8, 8}, {-9, 9}}});
  Literal index_vectors = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
  Literal replacements = LiteralUtil::CreateR2<int32>({{-10, 10}, {-40, 40}});
  // Only base[0][0] and base[1][0] change.
  Literal want =
      LiteralUtil::CreateR3<int32>({{{-10, 10}, {-2, 2}, {-3, 3}},  //
                                    {{-40, 40}, {-5, 5}, {-6, 6}},  //
                                    {{-7, 7}, {-8, 8}, {-9, 9}}});

  TF_ASSERT_OK_AND_ASSIGN(Literal got,
                          Evaluate({&base, &index_vectors, &replacements}));
  EXPECT_TRUE(LiteralTestUtil::Equal(want, got));
}
2587
TEST_F(HloEvaluatorTest,
       EvaluateScatter_TensorFlowScatterNd_NonDefaultIndexVectorDim) {
  // Same overwrite-style scatter as the ScatterNd case, but with
  // index_vector_dim=0, so the index vectors are read along dimension 0 of
  // the indices array.
  const char* module_text = R"(
HloModule TensorFlowScatterNdNonDefaultIndexVectorDim

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  ROOT rhs = s32[] parameter(1)
}

ENTRY main {
  operand = s32[3,3,2] parameter(0)
  indices = s32[2,2] parameter(1)
  updates = s32[2,2] parameter(2)
  ROOT scatter = s32[3,3,2] scatter(operand, indices, updates),
      to_apply=update_s32,
      update_window_dims={1},
      inserted_window_dims={0,1},
      scatter_dims_to_operand_dims={0,1},
      index_vector_dim=0
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal base =
      LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}},  //
                                    {{-4, 4}, {-5, 5}, {-6, 6}},  //
                                    {{-7, 7}, {-8, 8}, {-9, 9}}});
  Literal index_vectors = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
  Literal replacements = LiteralUtil::CreateR2<int32>({{-10, 10}, {-20, 20}});
  // The expected result places the first update at base[0][1] and the second
  // at base[0][0].
  Literal want =
      LiteralUtil::CreateR3<int32>({{{-20, 20}, {-10, 10}, {-3, 3}},  //
                                    {{-4, 4}, {-5, 5}, {-6, 6}},      //
                                    {{-7, 7}, {-8, 8}, {-9, 9}}});

  TF_ASSERT_OK_AND_ASSIGN(Literal got,
                          Evaluate({&base, &index_vectors, &replacements}));
  EXPECT_TRUE(LiteralTestUtil::Equal(want, got));
}
2625
TEST_F(HloEvaluatorTest, EvaluateScatter_DynamicUpdateSlice) {
  // A scatter shaped like a dynamic-update-slice: one index vector and a 1x1
  // update window that overwrites a single operand element.
  const char* module_text = R"(
HloModule DynamicUpdateSlice

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  ROOT rhs = s32[] parameter(1)
}

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2] parameter(1)
  updates = s32[1,1] parameter(2)
  ROOT scatter = s32[3,3] scatter(operand, indices, updates),
      to_apply=update_s32,
      update_window_dims={0,1},
      inserted_window_dims={},
      scatter_dims_to_operand_dims={0,1},
      index_vector_dim=0
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal base =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal start = LiteralUtil::CreateR1<int32>({1, 1});
  Literal patch = LiteralUtil::CreateR2<int32>({{10}});
  // Only the centre element, base[1][1], is replaced.
  Literal want =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 10, 6}, {7, 8, 9}});

  TF_ASSERT_OK_AND_ASSIGN(Literal got, Evaluate({&base, &start, &patch}));
  EXPECT_TRUE(LiteralTestUtil::Equal(want, got));
}
2658
TEST_F(HloEvaluatorTest, EvaluateScatter_BatchDynamicUpdateSlice) {
  // Batched form of the dynamic-update-slice scatter: two index vectors, each
  // paired with its own 1x1 update window.
  const char* module_text = R"(
HloModule BatchDynamicUpdateSlice

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  ROOT rhs = s32[] parameter(1)
}

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2,2] parameter(1)
  updates = s32[2,1,1] parameter(2)
  ROOT scatter = s32[3,3] scatter(operand, indices, updates),
      to_apply=update_s32,
      update_window_dims={1,2},
      inserted_window_dims={},
      scatter_dims_to_operand_dims={0,1},
      index_vector_dim=0
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal base =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal starts = LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
  Literal patches = LiteralUtil::CreateR3<int32>({{{10}}, {{20}}});
  // The expected result writes 10 to base[2][1] and 20 to base[1][1].
  Literal want =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 20, 6}, {7, 10, 9}});

  TF_ASSERT_OK_AND_ASSIGN(Literal got, Evaluate({&base, &starts, &patches}));
  EXPECT_TRUE(LiteralTestUtil::Equal(want, got));
}
2691
TEST_F(HloEvaluatorTest, EvaluateScatter_ZeroDimBounds) {
  // Operand and updates both have a dimension of size zero, so there are no
  // elements to write; the result must equal the (empty) operand.
  const char* module_text = R"(
HloModule TensorFlowScatter_ZeroDimBounds

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  ROOT rhs = s32[] parameter(1)
}

ENTRY main {
  operand = s32[3,0] parameter(0)
  indices = s32[2] parameter(1)
  updates = s32[2,0] parameter(2)
  ROOT scatter = s32[3,0] scatter(operand, indices, updates),
      to_apply=update_s32,
      update_window_dims={1},
      inserted_window_dims={0},
      scatter_dims_to_operand_dims={0},
      index_vector_dim=1
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal empty_base = LiteralUtil::CreateR2<int32>({{}, {}, {}});
  Literal row_ids = LiteralUtil::CreateR1<int32>({0, 2});
  Literal empty_updates = LiteralUtil::CreateR2<int32>({{}, {}});

  TF_ASSERT_OK_AND_ASSIGN(
      Literal got, Evaluate({&empty_base, &row_ids, &empty_updates}));
  EXPECT_TRUE(LiteralTestUtil::Equal(empty_base, got));
}
2721
TEST_F(HloEvaluatorTest, EvaluateScatter_NoUpdateWindowDims) {
  // update_window_dims is empty: every element of `updates` is a scalar that
  // is added to the operand element named by the matching index.
  const string module_text = R"(
HloModule Scatter_NoUpdateWindowDims

add_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(s32[] lhs, s32[] rhs)
}

ENTRY main {
  operand = s32[3] parameter(0)
  indices = s32[2,2,1] parameter(1)
  updates = s32[2,2] parameter(2)
  ROOT scatter = s32[3] scatter(operand, indices, updates),
      to_apply=add_s32,
      update_window_dims={},
      inserted_window_dims={0},
      scatter_dims_to_operand_dims={0},
      index_vector_dim=2
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal base = LiteralUtil::CreateR1<int32>({0, 1, 2});
  Literal positions = LiteralUtil::CreateR3<int32>({{{0}, {1}}, {{2}, {1}}});
  Literal deltas = LiteralUtil::CreateR2<int32>({{10, 20}, {30, 40}});
  // Position 1 is named twice, so it accumulates both 20 and 40.
  Literal want = LiteralUtil::CreateR1<int32>({10, 61, 32});

  TF_ASSERT_OK_AND_ASSIGN(Literal got, Evaluate({&base, &positions, &deltas}));
  EXPECT_TRUE(LiteralTestUtil::Equal(want, got));
}
2755
TEST_F(HloEvaluatorTest, EvaluateScatter_NegativeIndices) {
  // An update addressed by a negative index must be dropped, while the update
  // with a valid index is still applied.
  const char* module_text = R"(
HloModule TensorFlowScatter_NegativeIndices

add_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(s32[] lhs, s32[] rhs)
}

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2] parameter(1)
  updates = s32[2,3] parameter(2)
  ROOT scatter = s32[3,3] scatter(operand, indices, updates),
      to_apply=add_s32,
      update_window_dims={1},
      inserted_window_dims={0},
      scatter_dims_to_operand_dims={0},
      index_vector_dim=1
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> parsed,
                          ParseAndReturnVerifiedModule(module_text));

  Literal base =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  // The first index is negative; only the second one (row 2) is usable.
  Literal row_ids = LiteralUtil::CreateR1<int32>({-1, 2});
  Literal deltas =
      LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
  Literal want =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {77, 88, 99}});

  const auto got =
      EvaluateWithModule(parsed.get(), {&base, &row_ids, &deltas});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, got));
}
2790
TEST_F(HloEvaluatorTest, EvaluateScatter_OobIndices) {
  // Mixes in-range and out-of-range index vectors; only the in-range ones may
  // modify the 3x3 operand.
  const string module_text = R"(
HloModule BatchDynamicUpdateSlice

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  ROOT rhs = s32[] parameter(1)
}

ENTRY main {
  operand = s32[3,3]{1,0} parameter(0)
  indices = s32[6,2]{1,0} parameter(1)
  updates = s32[6,1,1]{2,1,0} parameter(2)
  ROOT scatter = s32[3,3]{1,0} scatter(operand, indices, updates),
      to_apply=update_s32,
      update_window_dims={1,2},
      inserted_window_dims={},
      scatter_dims_to_operand_dims={0,1},
      index_vector_dim=1
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> parsed,
                          ParseAndReturnVerifiedModule(module_text));

  Literal base =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  // {2, 7}, {5, 1} and {2147483647, 1} fall outside the operand and must be
  // ignored.
  Literal index_vectors = LiteralUtil::CreateR2<int32>(
      {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483647, 1}, {1, 2}});
  Literal patches = LiteralUtil::CreateR3<int32>(
      {{{10}}, {{20}}, {{30}}, {{40}}, {{50}}, {{60}}});
  Literal want =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 30, 60}, {7, 20, 9}});

  const auto got =
      EvaluateWithModule(parsed.get(), {&base, &index_vectors, &patches});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, got));
}
2826
TEST_F(HloEvaluatorTest, EvaluateScatter_OobUpdateWindow) {
  // The start index itself is in range, but the update window extends past
  // the operand bounds; the operand must come back unchanged.
  const char* module_text = R"(
HloModule TensorFlowScatterNd_OobUpdateWindow

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  ROOT rhs = s32[] parameter(1)
}

ENTRY main {
  operand = s32[3,3,2] parameter(0)
  indices = s32[1,2] parameter(1)
  updates = s32[1,2,2] parameter(2)
  ROOT scatter = s32[3,3,2] scatter(operand, indices, updates),
      to_apply=update_s32,
      update_window_dims={1,2},
      inserted_window_dims={0},
      scatter_dims_to_operand_dims={0,1},
      index_vector_dim=1
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> parsed,
                          ParseAndReturnVerifiedModule(module_text));

  Literal base =
      LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}},  //
                                    {{-4, 4}, {-5, 5}, {-6, 6}},  //
                                    {{-7, 7}, {-8, 8}, {-9, 9}}});
  Literal index_vectors = LiteralUtil::CreateR2<int32>({{0, 2}});
  Literal patch = LiteralUtil::CreateR3<int32>({{{-10, 10}, {-40, 40}}});
  // With an update window of size 2,2 placed at index 0,2 the window is out
  // of bounds, so nothing should be written.
  Literal want = base.Clone();

  const auto got =
      EvaluateWithModule(parsed.get(), {&base, &index_vectors, &patch});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, got));
}
2863
2864 // Verifies that HloEvaluator evaluates a HLO instruction that performs
2865 // element-wise comparison with 2 bfloat16 operands.
TEST_F(HloEvaluatorTest, DoesCompareBF16) {
  // Builds `left >= right` over two 2x3 bfloat16 constants and checks the
  // element-wise boolean result.
  auto left = LiteralUtil::CreateR2<bfloat16>(
      {{bfloat16(0.25), bfloat16(0.35), bfloat16(0.125)},
       {bfloat16(-0.25), bfloat16(-0.35), bfloat16(-0.125)}});
  auto right = LiteralUtil::CreateR2<bfloat16>(
      {{bfloat16(0.5), bfloat16(0.125), bfloat16(0.125)},
       {bfloat16(0.25), bfloat16(-0.375), bfloat16(-0.127)}});
  auto want =
      LiteralUtil::CreateR2<bool>({{false, true, true}, {false, true, true}});

  HloComputation::Builder builder(TestName());
  // The constants take ownership of the literals, hence the moves.
  HloInstruction* left_const = builder.AddInstruction(
      HloInstruction::CreateConstant(std::move(left)));
  HloInstruction* right_const = builder.AddInstruction(
      HloInstruction::CreateConstant(std::move(right)));
  builder.AddInstruction(HloInstruction::CreateCompare(
      want.shape(), left_const, right_const, ComparisonDirection::kGe));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal got, Evaluate());
  EXPECT_TRUE(LiteralTestUtil::Equal(want, got));
}
2887
TEST_P(HloEvaluatorBf16Test, Bf16Reduction) {
  // Sums a four-element bf16 vector with a reduce over dimension 0, starting
  // from a zero init value.
  const string module_text = R"(
HloModule Bf16Reduction

add_bf16 (lhs: bf16[], rhs: bf16[]) -> bf16[] {
  lhs = bf16[] parameter(0)
  rhs = bf16[] parameter(1)
  ROOT add = bf16[] add(bf16[] lhs, bf16[] rhs)
}

ENTRY main {
  arg0 = bf16[4]{0} parameter(0)
  init = bf16[] constant(0)
  ROOT %reduce = bf16[] reduce(arg0, init), dimensions={0}, to_apply=add_bf16
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal values = LiteralUtil::CreateR1<bfloat16>(
      {bfloat16(1.0f), bfloat16(3.0f), bfloat16(-2.0f), bfloat16(42.0f)});
  // 1 + 3 - 2 + 42 = 44.
  Literal want = LiteralUtil::CreateR0<bfloat16>(bfloat16(44.0f));

  TF_ASSERT_OK_AND_ASSIGN(Literal got, Evaluate({&values}));
  EXPECT_TRUE(LiteralTestUtil::Equal(want, got));
}
2912
TEST_F(HloEvaluatorTest, DontFailOnCallUnimplementedOps) {
  // The called computation contains an infeed, which triggers an
  // unimplemented error within HandleCall. The test checks that Evaluate()
  // reports this as a non-OK status.
  const string module_text = R"(
HloModule DontFailOnCall

call {
  token0 = token[] after-all()
  ROOT infeed = ((u32[3]{0}, pred[]), token[]) infeed(token0)
}

ENTRY main {
  ROOT result = ((u32[3]{0}, pred[]), token[]) call(), to_apply=call
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  auto evaluation = Evaluate();
  EXPECT_FALSE(evaluation.status().ok());
}
2932
TEST_F(HloEvaluatorTest, DontFailOnFusionWithUnimplementedOps) {
  // The fused computation contains an infeed, which triggers an
  // unimplemented error within HandleFusion. The test checks that Evaluate()
  // reports this as a non-OK status.
  const string module_text = R"(
HloModule DontFailOnFusion

fused_computation {
  token0 = token[] after-all()
  ROOT infeed = ((u32[3]{0}, pred[]), token[]) infeed(token0)
}

ENTRY main {
  ROOT result = ((u32[3]{0}, pred[]), token[]) fusion(), kind=kLoop, calls=fused_computation
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  auto evaluation = Evaluate();
  EXPECT_FALSE(evaluation.status().ok());
}
2952
TEST_P(HloEvaluatorBf16Test, SliceWithDifferentLayout) {
  // Regression test for b/114735354.
  // A full-extent slice whose result layout ({1,0,2}) differs from the
  // operand layout ({0,1,2}); the sliced literal must still compare equal to
  // the input.
  const string module_text = R"(
HloModule SliceWithDifferentLayout

ENTRY main {
  arg = f32[2,2,2]{0,1,2} parameter(0)
  ROOT %slice = f32[2,2,2]{1,0,2} slice(f32[2,2,2]{0,1,2} %arg), slice={[0:2], [0:2], [0:2]}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  const auto input_layout = LayoutUtil::MakeLayout({0, 1, 2});
  Literal input = LiteralUtil::CreateR3WithLayout<float>(
      {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}},
      input_layout);

  TF_ASSERT_OK_AND_ASSIGN(Literal sliced, Evaluate({&input}));
  EXPECT_TRUE(LiteralTestUtil::Equal(input, sliced));
}
2971
TEST_P(HloEvaluatorBf16Test, Bitcast) {
  // Regression test for b/114735354.
  // The module text is a format template; all three %s placeholders receive
  // the same element type, selected by the test parameter.
  constexpr absl::string_view kModuleTemplate = R"(
HloModule Bitcast

ENTRY main {
  param = %s[32,121]{1,0} parameter(0)
  ROOT bitcast = %s[121,32,1]{0,1,2} bitcast(%s[32,121]{1,0} param)
}
)";
  const char* element_type = use_bfloat16_ ? "bf16" : "f32";
  const string module_text = absl::StrFormat(kModuleTemplate, element_type,
                                             element_type, element_type);
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  auto inputs = MakeFakeArguments(m_.get()).ConsumeValueOrDie();
  TF_ASSERT_OK_AND_ASSIGN(Literal output, Evaluate({&inputs[0]}));

  // The flat element data of the output must match that of the input.
  if (!use_bfloat16_) {
    EXPECT_TRUE(absl::c_equal(inputs[0].data<float>(), output.data<float>()));
  } else {
    EXPECT_TRUE(
        absl::c_equal(inputs[0].data<bfloat16>(), output.data<bfloat16>()));
  }
}
2998
2999 // Check that s32 under/overflow doesn't trigger a ubsan failure.
TEST_F(HloEvaluatorTest, Int32Overflow) {
  // Evaluates an add, a subtract and a multiply whose mathematical results do
  // not fit in s32. The expected values are computed in uint32 arithmetic,
  // which wraps, and are then cast to int32, so the expectations themselves
  // never perform signed overflow.
  constexpr absl::string_view hlo_text = R"(
HloModule Test

ENTRY main {
  c1 = s32[] constant(1073741824)  // 2^30
  sum = s32[] add(c1, c1)  // 2^31, i.e. INT_MIN

  c2 = s32[] constant(-2147483648)  // -2^31
  sub = s32[] subtract(c2, c1)  // -2^31 - 2^30, underflows

  mul = s32[] multiply(c1, c1)
  ROOT tuple = (s32[], s32[], s32[]) tuple(sum, sub, mul)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
  TF_ASSERT_OK_AND_ASSIGN(auto literal, Evaluate({}));
  std::vector<Literal> actual = literal.DecomposeTuple();
  // Stop here on a wrong tuple arity so the indexing below stays in bounds.
  ASSERT_EQ(actual.size(), 3);

  const uint32 pow30 = uint32{1} << 30;
  const uint32 pow31 = uint32{1} << 31;
  // sum = 2^30 + 2^30.
  EXPECT_EQ(actual[0].GetFirstElement<int32>(), static_cast<int32>(pow31));
  // sub = -2^31 - 2^30.
  EXPECT_EQ(actual[1].GetFirstElement<int32>(),
            static_cast<int32>(-(pow31 + pow30)));
  // mul = c1 * c1 = 2^30 * 2^30, so the expectation uses pow30 as both
  // operands to mirror the HLO; the product wraps to 0 modulo 2^32.
  EXPECT_EQ(actual[2].GetFirstElement<int32>(),
            static_cast<int32>(pow30 * pow30));
}
3028
TEST_F(HloEvaluatorTest, GetDimensionSize) {
  // get-dimension-size is applied to a value derived from parameter 1, whose
  // dimension 0 is bound to the runtime value of parameter 0. The evaluator
  // must return that runtime value (3), not the static bound (4).
  constexpr absl::string_view kModuleText = R"(
HloModule Test

ENTRY main {
  size = u32[] parameter(0)

  data = s32[4] parameter(1)

  sum = s32[4] add(data, data)

  ROOT dynamic_size = u32[] get-dimension-size(sum), dimensions={0}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(kModuleText));

  // Bind parameter 0 as the dynamic size of dimension 0 of parameter 1.
  TF_CHECK_OK(m_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{0, {}},
      DynamicParameterBinding::DynamicDimension{1, {}, 0}));

  TF_ASSERT_OK_AND_ASSIGN(DynamicDimensionInference inference,
                          DynamicDimensionInference::Run(m_.get()));
  evaluator_.set_dynamic_dimension_inference(&inference);

  Literal runtime_size = LiteralUtil::CreateR0<uint32>(3);
  Literal payload = LiteralUtil::CreateR1<int32>({1, 2, 3, 4});

  TF_ASSERT_OK_AND_ASSIGN(Literal got, Evaluate({&runtime_size, &payload}));

  EXPECT_EQ(got.GetFirstElement<uint32>(), static_cast<uint32>(3));
}
3061
3062 // Check that we get a useful error if we pass inputs of the wrong shape.
TEST_F(HloEvaluatorTest, EvaluateWithWrongInputShapes) {
  // The computation expects s32[1] but is handed s32[2]. Both Evaluate entry
  // points (module and entry computation) must produce the same message.
  constexpr absl::string_view kModuleText = R"(
HloModule Test

ENTRY main {
  p0 = s32[1] parameter(0)
  ROOT sum = s32[1] add(p0, p0)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(kModuleText));
  Literal two_elements = LiteralUtil::CreateR1<int32>({0, 1});

  const auto module_status =
      HloEvaluator().Evaluate(*m_, {&two_elements}).status();
  EXPECT_EQ(module_status.error_message(),
            "Shape mismatch at parameter 0. Computation expected s32[1]{0}, "
            "but arg was s32[2].");

  const auto computation_status =
      HloEvaluator()
          .Evaluate(*m_->entry_computation(), {&two_elements})
          .status();
  EXPECT_EQ(computation_status.error_message(),
            "Shape mismatch at parameter 0. Computation expected s32[1]{0}, "
            "but arg was s32[2].");
}
3088
3089 // Check that we get a useful error if we pass too many or too few inputs.
TEST_F(HloEvaluatorTest, EvaluateWithWrongNumberOfInputs) {
  // The entry computation declares one parameter and is given two literals.
  // Both Evaluate entry points (module and entry computation) must produce
  // the same message.
  constexpr absl::string_view kModuleText = R"(
HloModule Test

ENTRY main {
  p0 = s32[1] parameter(0)
  ROOT sum = s32[1] add(p0, p0)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(kModuleText));
  Literal arg = LiteralUtil::CreateR1<int32>({0});

  const auto module_status =
      HloEvaluator().Evaluate(*m_, {&arg, &arg}).status();
  EXPECT_EQ(module_status.error_message(), "Expected 1 argument, but got 2.");

  const auto computation_status =
      HloEvaluator()
          .Evaluate(*m_->entry_computation(), {&arg, &arg})
          .status();
  EXPECT_EQ(computation_status.error_message(),
            "Expected 1 argument, but got 2.");
}
3111
TEST_F(HloEvaluatorTest, PreserveFusionInputLayout) {
  // The fusion parameter uses layout {0,1} and is bitcast to {1,0} inside the
  // fused computation. The flat element data of the result must match the
  // flat data of the argument.
  constexpr absl::string_view kModuleText = R"(
    HloModule FusionInputLayout

    fused_computation {
      param_0 = f32[20,20]{0,1} parameter(0)
      ROOT bitcast = f32[20,20]{1,0} bitcast(param_0)
    }

    ENTRY kernel_entry {
      parameter.0 = f32[20,20]{0,1} parameter(0)
      ROOT fusion = f32[20,20]{1,0} fusion(parameter.0),
        kind=kLoop, calls=fused_computation
    })";

  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(kModuleText));
  auto inputs = MakeFakeArguments(m_.get()).ConsumeValueOrDie();

  TF_ASSERT_OK_AND_ASSIGN(Literal output, Evaluate({&inputs[0]}));
  const auto input_values = inputs[0].data<float>();
  const auto output_values = output.data<float>();
  EXPECT_TRUE(absl::c_equal(input_values, output_values));
}
3133
TEST_F(HloEvaluatorTest, PreserveFusionOutputLayout) {
  // Mirror image of the input-layout case: the parameter uses layout {1,0}
  // and the fusion result is declared with {0,1}. The flat element data of
  // the result must match the flat data of the argument.
  constexpr absl::string_view kModuleText = R"(
    HloModule FusionOutputLayout

    fused_computation {
      param_0 = f32[20,20]{1,0} parameter(0)
      ROOT bitcast = f32[20,20]{0,1} bitcast(param_0)
    }

    ENTRY kernel_entry {
      parameter.0 = f32[20,20]{1,0} parameter(0)
      ROOT fusion = f32[20,20]{0,1} fusion(parameter.0),
        kind=kLoop, calls=fused_computation
    })";

  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(kModuleText));
  auto inputs = MakeFakeArguments(m_.get()).ConsumeValueOrDie();

  TF_ASSERT_OK_AND_ASSIGN(Literal output, Evaluate({&inputs[0]}));
  const auto input_values = inputs[0].data<float>();
  const auto output_values = output.data<float>();
  EXPECT_TRUE(absl::c_equal(input_values, output_values));
}
3154
TEST_F(HloEvaluatorTest, PreserveMOFusionOutputLayout) {
  // Tuple-returning variant of the output-layout case: the fusion yields a
  // one-element tuple whose element has layout {0,1}. The flat data of that
  // element must match the flat data of the argument.
  constexpr absl::string_view kModuleText = R"(
    HloModule MOFusionOutputLayout

    fused_computation {
      param_0 = f32[20,20]{1,0} parameter(0)
      bitcast = f32[20,20]{0,1} bitcast(param_0)
      ROOT tuple = (f32[20,20]{0,1}) tuple(bitcast)
    }

    ENTRY kernel_entry {
      parameter.0 = f32[20,20]{1,0} parameter(0)
      ROOT fusion = (f32[20,20]{0,1}) fusion(parameter.0),
        kind=kLoop, calls=fused_computation
    })";

  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(kModuleText));
  auto inputs = MakeFakeArguments(m_.get()).ConsumeValueOrDie();

  TF_ASSERT_OK_AND_ASSIGN(Literal output_tuple, Evaluate({&inputs[0]}));
  std::vector<Literal> elements = output_tuple.DecomposeTuple();
  const auto input_values = inputs[0].data<float>();
  const auto output_values = elements[0].data<float>();
  EXPECT_TRUE(absl::c_equal(input_values, output_values));
}
3178
3179 // Tests that custom_calls fail to evaluate when no handler is specified.
TEST_F(HloEvaluatorTest, EvaluateCustomCall_NoHandler) {
  // A fresh evaluator with no custom-call handler installed is asked to
  // evaluate a custom-call; the expected status code is UNIMPLEMENTED.
  constexpr absl::string_view kModuleText = R"(
    HloModule EvaluateCustomCall_NoHandler
    ENTRY kernel_entry {
      parameter.0 = u32[2,2]{1,0} parameter(0)
      ROOT test_root = (u32[2,2]{1,0}) custom-call(parameter.0),
          custom_call_target="_my_custom_call"
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(kModuleText));
  auto inputs = MakeFakeArguments(m_.get()).ConsumeValueOrDie();

  const auto outcome = HloEvaluator().Evaluate(*m_, {&inputs[0]}).status();
  EXPECT_EQ(outcome.code(), ::tensorflow::error::UNIMPLEMENTED);
}
3195
3196 // Tests when a custom_call handler returns an error.
TEST_F(HloEvaluatorTest, EvaluateCustomCall_HandlerError) {
  // Installs a handler that always returns an internal error and checks that
  // Evaluate() surfaces the INTERNAL status code.
  constexpr absl::string_view kModuleText = R"(
    HloModule EvaluateCustomCall_HandlerError
    ENTRY kernel_entry {
      parameter.0 = u32[2,2]{1,0} parameter(0)
      ROOT test_root = (u32[2,2]{1,0}) custom-call(parameter.0),
          custom_call_target="_my_custom_call"
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(kModuleText));
  auto inputs = MakeFakeArguments(m_.get()).ConsumeValueOrDie();

  HloEvaluator failing_evaluator;
  failing_evaluator.set_custom_call_handler(
      [](HloInstruction* instr, absl::Span<const Literal*> handler_args) {
        return InternalError("Test error");
      });

  const auto outcome = failing_evaluator.Evaluate(*m_, {&inputs[0]}).status();
  EXPECT_EQ(outcome.code(), ::tensorflow::error::INTERNAL);
}
3217
3218 // Tests the custom_call handler on calls with many inputs.
3219 // We sum the operands so that we can verify the operand and output literals
3220 // are properly mapped for access.
TEST_F(HloEvaluatorTest, EvaluateCustomCall_ManyInputs) {
  // The handler adds the first element of each of its two operands and writes
  // the sum into a freshly created output literal; the test then recomputes
  // the same sum from the original arguments.
  constexpr absl::string_view kModuleText = R"(
    HloModule EvaluateCustomCall_ManyInputs
    ENTRY kernel_entry {
      parameter.0 = u32[1]{0} parameter(0)
      parameter.1 = u32[1]{0} parameter(1)
      ROOT test_root = u32[1]{0} custom-call(parameter.0, parameter.1),
          custom_call_target="_my_custom_call"
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(kModuleText));
  auto inputs = MakeFakeArguments(m_.get()).ConsumeValueOrDie();

  HloEvaluator summing_evaluator;
  summing_evaluator.set_custom_call_handler(
      [](HloInstruction* instr, absl::Span<const Literal*> handler_args) {
        // Check that the handler was given the expected instruction and
        // operand list.
        EXPECT_EQ(HloOpcode::kCustomCall, instr->opcode());
        EXPECT_EQ("_my_custom_call", instr->custom_call_target());
        EXPECT_EQ(2, instr->operand_count());
        EXPECT_EQ(2, handler_args.size());

        auto sum_literal = Literal::CreateFromShape(instr->shape());
        auto first = handler_args[0]->data<uint32>();
        auto second = handler_args[1]->data<uint32>();
        auto sum_data = sum_literal.data<uint32>();
        sum_data[0] = first[0] + second[0];
        return sum_literal;
      });

  TF_ASSERT_OK_AND_ASSIGN(
      Literal got, summing_evaluator.Evaluate(*m_->entry_computation(),
                                              {&inputs[0], &inputs[1]}));

  auto lhs_values = inputs[0].data<uint32>();
  auto rhs_values = inputs[1].data<uint32>();
  std::vector<uint32> want = {lhs_values[0] + rhs_values[0]};
  EXPECT_TRUE(absl::c_equal(want, got.data<uint32>()));
}
3256
TEST_F(HloEvaluatorTest, IsFiniteF16) {
  // is-finite over an f16 constant: NaN and both infinities must map to
  // false, the ordinary values 7 and -1 to true.
  constexpr absl::string_view kModuleText = R"(
  HloModule test

  ENTRY IsFiniteTest {
    c = f16[6] constant({nan, 7, nan, -1, inf, -inf})
    ROOT is-finite = pred[6] is-finite(c)
  })";

  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(kModuleText));

  HloEvaluator evaluator;
  TF_ASSERT_OK_AND_ASSIGN(
      Literal finite_mask, evaluator.Evaluate(*m_->entry_computation(), {}));
  EXPECT_THAT(finite_mask.data<bool>(),
              ::testing::ElementsAre(false, true, false, true, false, false));
}
3273
TEST_F(HloEvaluatorTest, IsFiniteBf16) {
  // is-finite over a bf16 constant: NaN and both infinities must map to
  // false, the ordinary values 7 and -1 to true.
  constexpr absl::string_view kModuleText = R"(
  HloModule test

  ENTRY IsFiniteTest {
    c = bf16[6] constant({nan, 7, nan, -1, inf, -inf})
    ROOT is-finite = pred[6] is-finite(c)
  })";

  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(kModuleText));

  HloEvaluator evaluator;
  TF_ASSERT_OK_AND_ASSIGN(
      Literal finite_mask, evaluator.Evaluate(*m_->entry_computation(), {}));
  EXPECT_THAT(finite_mask.data<bool>(),
              ::testing::ElementsAre(false, true, false, true, false, false));
}
3290
3291 } // namespace
3292 } // namespace xla
3293