1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/compiler/xla/service/hlo_evaluator.h"
16 
17 #include <initializer_list>
18 #include <memory>
19 #include <string>
20 #include <tuple>
21 #include <utility>
22 #include <vector>
23 
24 #include "absl/memory/memory.h"
25 #include "absl/strings/str_format.h"
26 #include "tensorflow/compiler/xla/client/xla_builder.h"
27 #include "tensorflow/compiler/xla/literal.h"
28 #include "tensorflow/compiler/xla/permutation_util.h"
29 #include "tensorflow/compiler/xla/reference_util.h"
30 #include "tensorflow/compiler/xla/service/hlo_computation.h"
31 #include "tensorflow/compiler/xla/service/hlo_element_type_converter.h"
32 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
33 #include "tensorflow/compiler/xla/shape_util.h"
34 #include "tensorflow/compiler/xla/status.h"
35 #include "tensorflow/compiler/xla/status_macros.h"
36 #include "tensorflow/compiler/xla/statusor.h"
37 #include "tensorflow/compiler/xla/test.h"
38 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
39 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
40 #include "tensorflow/compiler/xla/tests/test_utils.h"
41 #include "tensorflow/compiler/xla/types.h"
42 #include "tensorflow/compiler/xla/util.h"
43 #include "tensorflow/compiler/xla/xla_data.pb.h"
44 #include "tensorflow/core/lib/core/status.h"
45 #include "tensorflow/core/lib/core/status_test_util.h"
46 #include "tensorflow/core/platform/test.h"
47 #include "tensorflow/core/platform/test_benchmark.h"
48 #include "tensorflow/core/platform/types.h"
49 
50 namespace xla {
51 namespace {
52 
// Parameter values for the bf16-parameterized suite: every TEST_P on
// HloEvaluatorBf16Test is instantiated once with `true` and once with `false`.
static std::array<bool, 2> use_bf16_params = {true, false};
54 
55 // Test fixture for the HloEvaluator.
56 //
57 // In bf16 mode, all f32 shapes are converted to bf16 before running.
58 class HloEvaluatorTest : public HloTestBase {
59  public:
HloEvaluatorTest()60   HloEvaluatorTest() : use_bfloat16_(false) { InitializeFftData(); }
61 
Evaluate(absl::Span<const Literal * const> arg_literals={})62   StatusOr<Literal> Evaluate(
63       absl::Span<const Literal* const> arg_literals = {}) {
64     if (use_bfloat16_) {
65       HloElementTypeConverter(F32, BF16).Run(m_.get()).ValueOrDie();
66     }
67     return evaluator_.Evaluate(*m_->entry_computation(), arg_literals);
68   }
69 
70   // Evaluate function that takes in a local module instead of using m_
71   // that is in HloTestBase. Once m_ in HloTestBase is
72   // removed, this should be the default Evaluate function.
EvaluateWithModule(HloModule * module,absl::Span<const Literal * const> arg_literals={})73   Literal EvaluateWithModule(
74       HloModule* module, absl::Span<const Literal* const> arg_literals = {}) {
75     if (use_bfloat16_) {
76       HloElementTypeConverter(F32, BF16).Run(m_.get()).ValueOrDie();
77     }
78     return evaluator_.Evaluate(*module->entry_computation(), arg_literals)
79         .ConsumeValueOrDie();
80   }
81 
TestUnaryOp(HloOpcode opcode,Literal expected,Literal input,float aabs=0)82   void TestUnaryOp(HloOpcode opcode, Literal expected, Literal input,
83                    float aabs = 0) {
84     HloComputation::Builder b(TestName());
85     auto c1 =
86         b.AddInstruction(HloInstruction::CreateConstant(std::move(input)));
87     b.AddInstruction(HloInstruction::CreateUnary(expected.shape(), opcode, c1));
88     m_->AddEntryComputation(b.Build());
89 
90     TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
91 
92     auto element_type = expected.shape().element_type();
93     if (element_type == F32 || element_type == F64) {
94       ErrorSpec error(aabs);
95       EXPECT_TRUE(LiteralTestUtil::Near(expected, result, error));
96     } else {
97       EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
98     }
99   }
100 
TestBinaryOp(HloOpcode opcode,Literal expected,Literal lhs,Literal rhs)101   void TestBinaryOp(HloOpcode opcode, Literal expected, Literal lhs,
102                     Literal rhs) {
103     HloComputation::Builder b(TestName());
104     auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs)));
105     auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs)));
106     b.AddInstruction(
107         HloInstruction::CreateBinary(expected.shape(), opcode, c1, c2));
108     m_->AddEntryComputation(b.Build());
109 
110     TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
111 
112     EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
113   }
114 
TestTernaryOp(HloOpcode opcode,Literal expected,Literal src0,Literal src1,Literal src2)115   void TestTernaryOp(HloOpcode opcode, Literal expected, Literal src0,
116                      Literal src1, Literal src2) {
117     HloComputation::Builder b(TestName());
118     auto operand0 =
119         b.AddInstruction(HloInstruction::CreateConstant(std::move(src0)));
120     auto operand1 =
121         b.AddInstruction(HloInstruction::CreateConstant(std::move(src1)));
122     auto operand2 =
123         b.AddInstruction(HloInstruction::CreateConstant(std::move(src2)));
124     b.AddInstruction(HloInstruction::CreateTernary(
125         expected.shape(), opcode, operand0, operand1, operand2));
126     m_->AddEntryComputation(b.Build());
127 
128     TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
129 
130     EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
131   }
132 
MaxComputationScalarF32()133   std::unique_ptr<HloComputation> MaxComputationScalarF32() {
134     HloComputation::Builder max_computation("max");
135     Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
136     auto param_lhs = max_computation.AddInstruction(
137         HloInstruction::CreateParameter(0, scalar_shape, "lhs"));
138     auto param_rhs = max_computation.AddInstruction(
139         HloInstruction::CreateParameter(1, scalar_shape, "rhs"));
140     max_computation.AddInstruction(HloInstruction::CreateBinary(
141         scalar_shape, HloOpcode::kMaximum, param_lhs, param_rhs));
142     return max_computation.Build();
143   }
144 
ReduceWindowMaxIotaTest(int window_size,int padding,int stride,int window_dilation,int base_dilation,const Literal & expected)145   void ReduceWindowMaxIotaTest(int window_size, int padding, int stride,
146                                int window_dilation, int base_dilation,
147                                const Literal& expected) {
148     HloComputation::Builder b(TestName());
149 
150     // arg:
151     // f32[4,4] {
152     //  {  0,  1,  2,  3 },
153     //  {  4,  5,  6,  7 },
154     //  {  8,  9, 10, 11 },
155     //  { 12, 13, 14, 15 }
156     // }
157     auto arg_array = absl::make_unique<Array2D<float>>(4, 4);
158     arg_array->FillIota(0);
159     auto arg_literal = LiteralUtil::CreateR2FromArray2D<float>(*arg_array);
160 
161     HloInstruction* arg_instruction = b.AddInstruction(
162         HloInstruction::CreateConstant(std::move(arg_literal)));
163     auto init_value = b.AddInstruction(
164         HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.f)));
165     auto max_func = m_->AddEmbeddedComputation(MaxComputationScalarF32());
166 
167     Window window;
168     WindowDimension dim;
169     dim.set_size(window_size);
170     dim.set_stride(stride);
171     dim.set_padding_low(padding);
172     dim.set_padding_high(padding);
173     dim.set_window_dilation(window_dilation);
174     dim.set_base_dilation(base_dilation);
175     *window.add_dimensions() = dim;
176     *window.add_dimensions() = dim;
177 
178     int dim0 = expected.shape().dimensions(0);
179     int dim1 = expected.shape().dimensions(1);
180     Shape shape = ShapeUtil::MakeShape(F32, {dim0, dim1});
181     b.AddInstruction(HloInstruction::CreateReduceWindow(
182         shape, arg_instruction, init_value, window, max_func));
183 
184     m_->AddEntryComputation(b.Build());
185     TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
186     EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
187   }
188 
189  protected:
HloEvaluatorTest(bool use_bfloat16)190   explicit HloEvaluatorTest(bool use_bfloat16) : use_bfloat16_(use_bfloat16) {
191     InitializeFftData();
192   }
193 
194   // Initializes data sets used in FFT tests below.
195   void InitializeFftData();
196 
197   HloEvaluator evaluator_;
198 
199   const bool use_bfloat16_;
200   std::unique_ptr<HloModule> m_ = CreateNewVerifiedModule();
201 
202   // Data sets used in FFT tests below.
203   ErrorSpec fft_error_ = ErrorSpec(1e-4, 1e-5);
204   Literal fft_c64x2x4x8_;
205   Literal fft_c64x2x4x8_1d_;
206   Literal fft_c64x2x4x8_2d_;
207   Literal fft_c64x2x4x8_3d_;
208 };
209 
210 // Lets you write TEST_Ps that run twice, once with and once without bf16.
211 class HloEvaluatorBf16Test : public ::testing::WithParamInterface<bool>,
212                              public HloEvaluatorTest {
213  protected:
HloEvaluatorBf16Test()214   HloEvaluatorBf16Test() : HloEvaluatorTest(/*use_bfloat16=*/GetParam()) {}
215 };
216 
217 INSTANTIATE_TEST_SUITE_P(HloEvaluatorTest_Instantiation, HloEvaluatorBf16Test,
218                          ::testing::ValuesIn(use_bf16_params));
219 
220 // Verifies that HloEvaluator evaluates a HLO instruction that performs clamp
221 // with 3 operands.
TEST_P(HloEvaluatorBf16Test,DoesClamp)222 TEST_P(HloEvaluatorBf16Test, DoesClamp) {
223   auto low = LiteralUtil::CreateR2<float>({{0.f, 2.f}, {2.f, 4.f}});
224   auto value = LiteralUtil::CreateR2<float>({{0.f, 5.f}, {0.f, 4.f}});
225   auto high = LiteralUtil::CreateR2<float>({{2.f, 4.f}, {4.f, 4.f}});
226 
227   Shape shape = low.shape();
228   HloComputation::Builder b(TestName());
229   auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(low)));
230   auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(value)));
231   auto c3 = b.AddInstruction(HloInstruction::CreateConstant(std::move(high)));
232   b.AddInstruction(
233       HloInstruction::CreateTernary(shape, HloOpcode::kClamp, c1, c2, c3));
234   m_->AddEntryComputation(b.Build());
235 
236   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
237 
238   auto expected = LiteralUtil::CreateR2<float>({{0, 4}, {2, 4}});
239 
240   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
241 }
242 
243 // Verifies that clamping of int64 does not cause loss of precision
TEST_P(HloEvaluatorBf16Test,DoesClampInt64)244 TEST_P(HloEvaluatorBf16Test, DoesClampInt64) {
245   auto ones = [](int bits) { return (int64{1} << bits) - 1; };
246 
247   auto low =
248       LiteralUtil::CreateR2<int64>({{0, ones(54)}, {ones(54), ones(58)}});
249   auto value = LiteralUtil::CreateR2<int64>({{0, ones(56)}, {0, ones(58)}});
250   auto high = LiteralUtil::CreateR2<int64>(
251       {{ones(54), ones(55)}, {ones(56), ones(58)}});
252 
253   Shape shape = low.shape();
254   HloComputation::Builder b(TestName());
255   auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(low)));
256   auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(value)));
257   auto c3 = b.AddInstruction(HloInstruction::CreateConstant(std::move(high)));
258   b.AddInstruction(
259       HloInstruction::CreateTernary(shape, HloOpcode::kClamp, c1, c2, c3));
260   m_->AddEntryComputation(b.Build());
261 
262   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
263 
264   auto expected =
265       LiteralUtil::CreateR2<int64>({{0, ones(55)}, {ones(54), ones(58)}});
266 
267   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
268 }
269 
// Disabled test: kClamp with scalar lower/upper bounds and a rank-2 value
// operand; the expected result has the value operand's shape.
TEST_P(HloEvaluatorBf16Test, DISABLED_DoesClampSpecialBroadcast) {
  Literal lower = LiteralUtil::CreateR0<float>(0.f);
  Literal operand = LiteralUtil::CreateR2<float>({{-1.f, 0.f}, {1.f, 2.f}});
  Literal upper = LiteralUtil::CreateR0<float>(1.f);

  // The result takes the shape of the (non-scalar) value operand.
  const Shape result_shape = operand.shape();

  HloComputation::Builder builder(TestName());
  HloInstruction* lower_hlo =
      builder.AddInstruction(HloInstruction::CreateConstant(std::move(lower)));
  HloInstruction* operand_hlo = builder.AddInstruction(
      HloInstruction::CreateConstant(std::move(operand)));
  HloInstruction* upper_hlo =
      builder.AddInstruction(HloInstruction::CreateConstant(std::move(upper)));
  builder.AddInstruction(HloInstruction::CreateTernary(
      result_shape, HloOpcode::kClamp, lower_hlo, operand_hlo, upper_hlo));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate());

  const Literal want = LiteralUtil::CreateR2<float>({{0, 0}, {1, 1}});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, actual));
}
290 
291 // Verifies that HloEvaluator evaluates a HLO instruction that performs select
292 // with 3 operands.
TEST_P(HloEvaluatorBf16Test,DoesSelect)293 TEST_P(HloEvaluatorBf16Test, DoesSelect) {
294   auto pred = LiteralUtil::CreateR2<bool>({{true, false}, {false, true}});
295   auto on_true = LiteralUtil::CreateR2<float>({{2.f, 4.f}, {4.f, 4.f}});
296   auto on_false = LiteralUtil::CreateR2<float>({{0.f, 5.f}, {0.f, 4.f}});
297 
298   Shape shape = on_true.shape();
299   HloComputation::Builder b(TestName());
300   auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(pred)));
301   auto c2 =
302       b.AddInstruction(HloInstruction::CreateConstant(std::move(on_true)));
303   auto c3 =
304       b.AddInstruction(HloInstruction::CreateConstant(std::move(on_false)));
305   b.AddInstruction(
306       HloInstruction::CreateTernary(shape, HloOpcode::kSelect, c1, c2, c3));
307   m_->AddEntryComputation(b.Build());
308 
309   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({}));
310 
311   auto expected = LiteralUtil::CreateR2<float>({{2, 5}, {0, 4}});
312 
313   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
314 }
315 
316 // Verifies that HloEvaluator evaluates a HLO instruction that performs
317 // element-wise addition with 2 operands.
TEST_F(HloEvaluatorTest,DoesAdd)318 TEST_F(HloEvaluatorTest, DoesAdd) {
319   auto lhs = LiteralUtil::CreateR2<int64>({{1, 0}, {-100, 4}});
320   auto rhs = LiteralUtil::CreateR2<int64>({{2, 4}, {4, 4}});
321   auto expected = LiteralUtil::CreateR2<int64>({{3, 4}, {-96, 8}});
322   TestBinaryOp(HloOpcode::kAdd, std::move(expected), std::move(lhs),
323                std::move(rhs));
324 }
325 // Verifies that HloEvaluator evaluates a HLO instruction that performs
326 // element-wise and with 2 operands.
TEST_P(HloEvaluatorBf16Test,DoesAnd)327 TEST_P(HloEvaluatorBf16Test, DoesAnd) {
328   auto lhs = LiteralUtil::CreateR2<int64>({{1, 0}, {-100, 4}});
329   auto rhs = LiteralUtil::CreateR2<int64>({{2, 4}, {4, 4}});
330   auto expected = LiteralUtil::CreateR2<int64>({{0, 0}, {4, 4}});
331   TestBinaryOp(HloOpcode::kAnd, std::move(expected), std::move(lhs),
332                std::move(rhs));
333 }
334 // Verifies that HloEvaluator evaluates a HLO instruction that performs
335 // element-wise or with 2 operands.
TEST_F(HloEvaluatorTest,DoesOr)336 TEST_F(HloEvaluatorTest, DoesOr) {
337   auto lhs = LiteralUtil::CreateR2<int64>({{1, 0}, {-100, 4}});
338   auto rhs = LiteralUtil::CreateR2<int64>({{2, 4}, {4, 4}});
339   auto expected = LiteralUtil::CreateR2<int64>({{3, 4}, {-100, 4}});
340   TestBinaryOp(HloOpcode::kOr, std::move(expected), std::move(lhs),
341                std::move(rhs));
342 }
343 // Verifies that HloEvaluator evaluates a HLO instruction that performs
344 // element-wise or with 2 operands.
TEST_F(HloEvaluatorTest,DoesXor)345 TEST_F(HloEvaluatorTest, DoesXor) {
346   auto lhs = LiteralUtil::CreateR2<int64>({{1, 0}, {-100, 4}});
347   auto rhs = LiteralUtil::CreateR2<int64>({{2, 4}, {4, 4}});
348   auto expected = LiteralUtil::CreateR2<int64>({{3, 4}, {-104, 0}});
349   TestBinaryOp(HloOpcode::kXor, std::move(expected), std::move(lhs),
350                std::move(rhs));
351 }
352 // Verifies that HloEvaluator evaluates a HLO instruction that performs
353 // element-wise multiply with 2 operands.
TEST_F(HloEvaluatorTest,DoesMultiply)354 TEST_F(HloEvaluatorTest, DoesMultiply) {
355   auto lhs = LiteralUtil::CreateR2<int32>({{-1, 0}, {-100, 4}});
356   auto rhs = LiteralUtil::CreateR2<int32>(
357       {{std::numeric_limits<int32>::min(), 4}, {4, 4}});
358   auto expected = LiteralUtil::CreateR2<int32>(
359       {{std::numeric_limits<int32>::min(), 0}, {-400, 16}});
360   TestBinaryOp(HloOpcode::kMultiply, std::move(expected), std::move(lhs),
361                std::move(rhs));
362 }
363 // Verifies that HloEvaluator evaluates a HLO instruction that performs
364 // element-wise divide with 2 operands.
TEST_F(HloEvaluatorTest,DoesDivideInt64)365 TEST_F(HloEvaluatorTest, DoesDivideInt64) {
366   auto lhs = LiteralUtil::CreateR2<int64>({{1, 0}, {-100, 4}});
367   auto rhs = LiteralUtil::CreateR2<int64>({{2, 4}, {4, 4}});
368   auto expected = LiteralUtil::CreateR2<int64>({{0, 0}, {-25, 1}});
369   TestBinaryOp(HloOpcode::kDivide, std::move(expected), std::move(lhs),
370                std::move(rhs));
371 }
372 
// Checks kClamp on rank-1 int64 operands with large-magnitude values, using
// the shared ternary-op helper.
TEST_F(HloEvaluatorTest, DoesClampS64) {
  Literal lower = LiteralUtil::CreateR1<int64>(
      {-8616761059752331528LL, 6780561065411491190LL, -8616761059752331528LL});
  Literal operand = LiteralUtil::CreateR1<int64>(
      {-6780561065411491190LL, 6780561065411491180LL, 4241131823772864090LL});
  Literal upper = LiteralUtil::CreateR1<int64>(
      {-6780561065411491180LL, 8616761059752331528LL, 3832151243857508051LL});
  Literal want = LiteralUtil::CreateR1<int64>(
      {-6780561065411491190LL, 6780561065411491190LL, 3832151243857508051LL});
  TestTernaryOp(HloOpcode::kClamp, /*expected=*/std::move(want),
                /*src0=*/std::move(lower), /*src1=*/std::move(operand),
                /*src2=*/std::move(upper));
}
385 
// Checks element-wise kDivide on two double matrices.
TEST_P(HloEvaluatorBf16Test, DoesDivideDouble) {
  Literal dividend = LiteralUtil::CreateR2<double>({{1.0, 0.0}, {-100.0, 4.0}});
  Literal divisor = LiteralUtil::CreateR2<double>({{2.2, 4.0}, {4.0, 4.0}});
  Literal want =
      LiteralUtil::CreateR2<double>({{0.45454545454545453, 0}, {-25, 1}});
  TestBinaryOp(HloOpcode::kDivide, /*expected=*/std::move(want),
               /*lhs=*/std::move(dividend), /*rhs=*/std::move(divisor));
}
394 
395 // Verifies that HloEvaluator evaluates a HLO instruction that performs
396 // element-wise abs op with 1 operand.
TEST_F(HloEvaluatorTest,DoesAbsR2)397 TEST_F(HloEvaluatorTest, DoesAbsR2) {
398   auto operand = LiteralUtil::CreateR2<int64>({{1, -20}, {-100, 4}});
399   auto expected = LiteralUtil::CreateR2<int64>({{1, 20}, {100, 4}});
400   TestUnaryOp(HloOpcode::kAbs, std::move(expected), std::move(operand));
401 }
// Checks kAbs on an f32 scalar.
TEST_P(HloEvaluatorBf16Test, DoesAbsR0) {
  Literal input = LiteralUtil::CreateR0<float>(-1.0f);
  Literal want = LiteralUtil::CreateR0<float>(1.0f);
  TestUnaryOp(HloOpcode::kAbs, /*expected=*/std::move(want),
              /*input=*/std::move(input));
}
// Checks kAbs on an empty rank-1 f32 literal.
TEST_P(HloEvaluatorBf16Test, DoesAbsR1WithZeroSize) {
  Literal input = LiteralUtil::CreateR1<float>({});
  Literal want = LiteralUtil::CreateR1<float>({});
  TestUnaryOp(HloOpcode::kAbs, /*expected=*/std::move(want),
              /*input=*/std::move(input));
}
412 
// Checks kAbs on a complex128 scalar (1 + 2i); the result is a double that is
// compared against 2.23607 with an error bound of 3e-06.
TEST_F(HloEvaluatorTest, DoesAbsC128) {
  Literal input = LiteralUtil::CreateR0<complex128>({1, 2});
  Literal want_magnitude = LiteralUtil::CreateR0<double>(2.23607);
  TestUnaryOp(HloOpcode::kAbs, /*expected=*/std::move(want_magnitude),
              /*input=*/std::move(input), /*aabs=*/3e-06);
}
418 
// Checks element-wise kNegate on an int32 matrix. The minimum int32 value is
// included; the test expects its negation to be the minimum int32 again.
TEST_F(HloEvaluatorTest, DoesNegateR2) {
  // Use the literal's own element type (int32) for the limit in both the
  // operand and the expectation so they cannot drift apart.
  constexpr int32 kMin = std::numeric_limits<int32>::min();
  auto operand = LiteralUtil::CreateR2<int32>({{0, kMin}, {-1, 4}});
  auto expected = LiteralUtil::CreateR2<int32>({{0, kMin}, {1, -4}});
  TestUnaryOp(HloOpcode::kNegate, std::move(expected), std::move(operand));
}
// Checks element-wise kCos on an f32 matrix of multiples of pi. The absolute
// error bound depends on whether the fixture runs in bf16 mode.
TEST_P(HloEvaluatorBf16Test, DoesCosR2) {
  Literal angles = LiteralUtil::CreateR2<float>({{0, M_PI}, {-M_PI, 2 * M_PI}});
  Literal want = LiteralUtil::CreateR2<float>({{1, -1}, {-1, 1}});
  const float tolerance = use_bfloat16_ ? 0.031250 : 9.5367431640625E-7;
  TestUnaryOp(HloOpcode::kCos, /*expected=*/std::move(want),
              /*input=*/std::move(angles), /*aabs=*/tolerance);
}
// Checks element-wise kSin on an f32 matrix of multiples of pi. The absolute
// error bound depends on whether the fixture runs in bf16 mode.
TEST_P(HloEvaluatorBf16Test, DoesSinR2) {
  Literal angles = LiteralUtil::CreateR2<float>({{0, M_PI}, {-M_PI, 2 * M_PI}});
  Literal want = LiteralUtil::CreateR2<float>({{0, 0}, {0, 0}});
  const float tolerance = use_bfloat16_ ? 0.031250 : 9.5367431640625E-7;
  TestUnaryOp(HloOpcode::kSin, /*expected=*/std::move(want),
              /*input=*/std::move(angles), /*aabs=*/tolerance);
}
// Checks element-wise kNot on an int32 matrix: 0 <-> -1 and
// minimum int32 <-> maximum int32.
TEST_F(HloEvaluatorTest, DoesNotR2) {
  // Take the limits from the literal's element type (int32) rather than from
  // plain int.
  constexpr int32 kMin = std::numeric_limits<int32>::min();
  constexpr int32 kMax = std::numeric_limits<int32>::max();
  auto operand = LiteralUtil::CreateR2<int32>({{0, kMin}, {-1, kMax}});
  auto expected = LiteralUtil::CreateR2<int32>({{-1, kMax}, {0, kMin}});
  TestUnaryOp(HloOpcode::kNot, std::move(expected), std::move(operand));
}
447 
// Checks that kReal extracts the real parts of a complex128 vector as doubles.
TEST_F(HloEvaluatorTest, DoesRealC128) {
  Literal input = LiteralUtil::CreateR1<complex128>({{1, 0}, {-100, 4}});
  Literal want_real = LiteralUtil::CreateR1<double>({1, -100});
  TestUnaryOp(HloOpcode::kReal, /*expected=*/std::move(want_real),
              /*input=*/std::move(input));
}
453 
// Checks that kImag extracts the imaginary parts of a complex128 vector as
// doubles.
TEST_F(HloEvaluatorTest, DoesImagC128) {
  Literal input = LiteralUtil::CreateR1<complex128>({{1, 0}, {-100, 4}});
  Literal want_imag = LiteralUtil::CreateR1<double>({0, 4});
  TestUnaryOp(HloOpcode::kImag, /*expected=*/std::move(want_imag),
              /*input=*/std::move(input));
}
459 
460 // Verifies that HloEvaluator evaluates a HLO Computation with non-parameter nor
461 // constant operands.
TEST_F(HloEvaluatorTest,DoesTraverseInstructions)462 TEST_F(HloEvaluatorTest, DoesTraverseInstructions) {
463   auto lhs = LiteralUtil::CreateR2<int64>({{1, 0}, {-100, 4}});
464   auto rhs = LiteralUtil::CreateR2<int64>({{2, 4}, {4, 4}});
465   auto rhs2 = LiteralUtil::CreateR2<int64>({{1, -20}, {-100, 4}});
466   std::vector<const Literal*> args = {&lhs, &rhs, &rhs2};
467 
468   Shape shape = ShapeUtil::MakeShape(S64, {2, 2});
469 
470   HloComputation::Builder b(TestName());
471   auto param_lhs =
472       b.AddInstruction(HloInstruction::CreateParameter(0, shape, "lhs"));
473   auto param_rhs =
474       b.AddInstruction(HloInstruction::CreateParameter(1, shape, "rhs"));
475   auto lhs_instruction = b.AddInstruction(HloInstruction::CreateBinary(
476       shape, HloOpcode::kAdd, param_lhs, param_rhs));
477 
478   auto param_rhs2 =
479       b.AddInstruction(HloInstruction::CreateParameter(2, shape, "rhs2"));
480   b.AddInstruction(HloInstruction::CreateBinary(shape, HloOpcode::kAdd,
481                                                 lhs_instruction, param_rhs2));
482   m_->AddEntryComputation(b.Build());
483 
484   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate(args));
485 
486   auto expected = LiteralUtil::CreateR2<int64>({{4, -16}, {-196, 12}});
487 
488   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
489 }
490 
491 // Verifies Reshape operation is correctly evaluated.
TEST_F(HloEvaluatorTest,DoesReshape)492 TEST_F(HloEvaluatorTest, DoesReshape) {
493   HloComputation::Builder b(TestName());
494   const int64 dimensions[] = {11, 8, 7, 5, 9};
495   TF_ASSERT_OK_AND_ASSIGN(auto literal,
496                           LiteralUtil::CreateRandomLiteral<F32>(
497                               ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0));
498   auto literal_clone = literal.Clone();
499   HloInstruction* literal_instruction =
500       b.AddInstruction(HloInstruction::CreateConstant(std::move(literal)));
501 
502   Shape shape = ShapeUtil::MakeShape(F32, {8, 7, 11, 9, 5});
503   const int64 permutation[] = {1, 2, 0, 4, 3};
504   b.AddInstruction(
505       HloInstruction::CreateTranspose(shape, literal_instruction, permutation));
506   m_->AddEntryComputation(b.Build());
507 
508   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({}));
509 
510   using NativeT = typename primitive_util::PrimitiveTypeToNative<F32>::type;
511   result.EachCell<NativeT>([&](absl::Span<const int64> indices, NativeT value) {
512     std::vector<int64> rindexes = PermuteInverse(indices, permutation);
513     EXPECT_NEAR(value, literal_clone.Get<NativeT>(rindexes), 0.031250);
514   });
515 }
516 
517 // Verifies Broadcast operation is correctly evaluated.
TEST_F(HloEvaluatorTest,DoesBroadcast)518 TEST_F(HloEvaluatorTest, DoesBroadcast) {
519   HloComputation::Builder b(TestName());
520   auto input_literal = LiteralUtil::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}});
521   auto output_literal = LiteralUtil::CreateR3<int32>(
522       {{{1, 2}, {3, 4}, {5, 6}}, {{1, 2}, {3, 4}, {5, 6}}});
523   HloInstruction* literal_instruction = b.AddInstruction(
524       HloInstruction::CreateConstant(std::move(input_literal)));
525   b.AddInstruction(HloInstruction::CreateBroadcast(
526       output_literal.shape(), literal_instruction, {1, 2}));
527   m_->AddEntryComputation(b.Build());
528 
529   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({}));
530 
531   EXPECT_TRUE(LiteralTestUtil::Equal(result, output_literal));
532 }
533 
// Checks kBroadcast of an int32 scalar into a 6x2 matrix.
TEST_F(HloEvaluatorTest, DoesBroadcastScalar) {
  HloComputation::Builder builder(TestName());
  Literal scalar = LiteralUtil::CreateR0<int32>(111);
  const Literal want = LiteralUtil::CreateR2<int32>(
      {{111, 111}, {111, 111}, {111, 111}, {111, 111}, {111, 111}, {111, 111}});

  HloInstruction* scalar_hlo =
      builder.AddInstruction(HloInstruction::CreateConstant(std::move(scalar)));
  // A scalar operand has no dimensions to map, so the broadcast dimension
  // list is empty.
  builder.AddInstruction(HloInstruction::CreateBroadcast(
      want.shape(), scalar_hlo, /*broadcast_dimensions=*/{}));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate());

  EXPECT_TRUE(LiteralTestUtil::Equal(actual, want));
}
552 
// Checks kConcatenate of two 2x2 int64 matrices along dimension 0.
TEST_F(HloEvaluatorTest, DoesConcatenateSimple) {
  HloComputation::Builder builder(TestName());

  HloInstruction* top = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<int64>({{-1, -2}, {100, 200}})));
  HloInstruction* bottom =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2<int64>({{-2, -3}, {-100, -200}})));
  std::vector<HloInstruction*> pieces = {top, bottom};

  const Shape joined_shape = ShapeUtil::MakeShape(S64, {4, 2});
  builder.AddInstruction(
      HloInstruction::CreateConcatenate(joined_shape, pieces, 0));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate());

  const Literal want = LiteralUtil::CreateR2<int64>(
      {{-1, -2}, {100, 200}, {-2, -3}, {-100, -200}});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, actual));
}
574 
// Checks kConcatenate when one of the operands is an empty rank-1 literal:
// the result equals the non-empty operand.
TEST_F(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) {
  HloComputation::Builder builder(TestName());

  HloInstruction* filled = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<int64>({100, 200})));
  HloInstruction* empty = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<int64>({})));
  std::vector<HloInstruction*> pieces = {filled, empty};

  const Shape joined_shape = ShapeUtil::MakeShape(S64, {2});
  builder.AddInstruction(
      HloInstruction::CreateConcatenate(joined_shape, pieces, 0));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate());

  const Literal want = LiteralUtil::CreateR1<int64>({100, 200});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, actual));
}
595 
// Checks kConvert from int32 to f32 when the source and destination shapes
// have equal layouts (asserted up front).
TEST_P(HloEvaluatorBf16Test, ConvertWithSameLayout) {
  HloComputation::Builder builder(TestName());

  Literal source = LiteralUtil::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}});
  const Literal want =
      LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}});
  ASSERT_TRUE(
      LayoutUtil::LayoutsInShapesEqual(source.shape(), want.shape()));

  HloInstruction* source_hlo =
      builder.AddInstruction(HloInstruction::CreateConstant(std::move(source)));
  builder.AddInstruction(
      HloInstruction::CreateConvert(want.shape(), source_hlo));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate());

  EXPECT_TRUE(LiteralTestUtil::Equal(actual, want));
}
614 
// Checks kConvert from int32 to f32 when the source literal uses layout {0, 1}
// and the destination uses layout {1, 0} (asserted to differ up front).
TEST_P(HloEvaluatorBf16Test, ConvertWithDifferentLayout) {
  HloComputation::Builder builder(TestName());

  Literal source = LiteralUtil::CreateR2WithLayout<int32>(
      {{1, 2}, {3, 4}, {5, 6}}, LayoutUtil::MakeLayout({0, 1}));
  const Literal want = LiteralUtil::CreateR2WithLayout<float>(
      {{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}, LayoutUtil::MakeLayout({1, 0}));
  ASSERT_FALSE(
      LayoutUtil::LayoutsInShapesEqual(source.shape(), want.shape()));

  HloInstruction* source_hlo =
      builder.AddInstruction(HloInstruction::CreateConstant(std::move(source)));
  builder.AddInstruction(
      HloInstruction::CreateConvert(want.shape(), source_hlo));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate());

  EXPECT_TRUE(LiteralTestUtil::Equal(actual, want));
}
634 
CreatePaddingConfig(std::initializer_list<std::array<int64,3>> padding_dimensions)635 PaddingConfig CreatePaddingConfig(
636     std::initializer_list<std::array<int64, 3>> padding_dimensions) {
637   PaddingConfig padding_config;
638 
639   for (auto& paddings_per_dim : padding_dimensions) {
640     auto dimension = padding_config.add_dimensions();
641     dimension->set_edge_padding_low(paddings_per_dim[0]);
642     dimension->set_edge_padding_high(paddings_per_dim[1]);
643     dimension->set_interior_padding(paddings_per_dim[2]);
644   }
645   return padding_config;
646 }
647 
// Checks kPad on an int32 literal built from two empty rows: the 5x2 result
// is expected to consist solely of the padding value.
TEST_F(HloEvaluatorTest, Pad2DIntegerArrayWithZeroDimension) {
  Literal empty_rows = LiteralUtil::CreateR2<int32>({{}, {}});
  HloComputation::Builder builder(TestName());
  HloInstruction* operand_hlo = builder.AddInstruction(
      HloInstruction::CreateConstant(std::move(empty_rows)));

  constexpr int32 kPadValue = 10;
  HloInstruction* pad_value_hlo = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(kPadValue)));

  // Per dimension: {edge_padding_low, edge_padding_high, interior_padding}.
  auto padding = CreatePaddingConfig({{{1, 0, 2}}, {{0, 2, 1}}});
  const Shape padded_shape = ShapeUtil::MakeShape(S32, {5, 2});
  builder.AddInstruction(HloInstruction::CreatePad(padded_shape, operand_hlo,
                                                   pad_value_hlo, padding));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate());

  const Literal want = LiteralUtil::CreateR2<int32>(
      {{10, 10}, {10, 10}, {10, 10}, {10, 10}, {10, 10}});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, actual));
}
672 
// Pads a 3x2x1x1 operand with low and interior padding on dimension 0 and
// high and interior padding on dimension 1, then checks where the operand
// values end up in the 8x5x1x1 result.
TEST_P(HloEvaluatorBf16Test, Pad4DFloatArrayWithInteriorPadding) {
  constexpr float kPadValue = 1.5;
  HloComputation::Builder b(TestName());

  Array4D<float> input_array(3, 2, 1, 1, {1, 2, 3, 4, 5, 6});
  HloInstruction* input_instruction =
      b.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR4FromArray4D<float>(input_array)));
  HloInstruction* pad_instruction = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(kPadValue)));

  // Triples are {edge_padding_low, edge_padding_high, interior_padding}; the
  // two trailing dimensions are left unpadded.
  auto padding =
      CreatePaddingConfig({{{1, 0, 2}}, {{0, 2, 1}}, {{0, 0, 0}}, {{0, 0, 0}}});
  Shape shape = ShapeUtil::MakeShape(F32, {8, 5, 1, 1});
  b.AddInstruction(HloInstruction::CreatePad(shape, input_instruction,
                                             pad_instruction, padding));
  m_->AddEntryComputation(b.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  // Start from an all-padding array, then place the values 1..6 (in that
  // order) at rows 1, 4, 7 and columns 0, 2.
  Array4D<float> expected_array(8, 5, 1, 1);
  expected_array.Fill(kPadValue);
  float next_value = 1.0f;
  for (int row = 1; row <= 7; row += 3) {
    for (int col = 0; col <= 2; col += 2) {
      expected_array(row, col, 0, 0) = next_value;
      next_value += 1.0f;
    }
  }

  auto expected = LiteralUtil::CreateR4FromArray4D<float>(expected_array);

  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
707 
// Applies edge padding with negative amounts (and no interior padding) to a
// 4x3 operand and checks the resulting 1x5 array.
TEST_P(HloEvaluatorBf16Test, NegativePadding2D) {
  constexpr float kPadValue = 2.718f;
  HloComputation::Builder b(TestName());

  // Operand after FillUnique(1.0f):
  // f32[4,3] {
  //  { 1, 2, 3 },
  //  { 5, 6, 7 },
  //  { 9, 10, 11 },
  //  { 13, 14, 15 },
  // }
  Array2D<float> input_array(4, 3);
  input_array.FillUnique(1.0f);
  HloInstruction* input_instruction =
      b.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2FromArray2D<float>(input_array)));

  auto pad_value_instruction = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(kPadValue)));

  // Triples are {edge_padding_low, edge_padding_high, interior_padding}.
  auto padding = CreatePaddingConfig({{{-1, -2, 0}}, {{-2, 4, 0}}});
  Shape shape = ShapeUtil::MakeShape(F32, {1, 5});
  b.AddInstruction(HloInstruction::CreatePad(
      shape, input_instruction, pad_value_instruction, padding));

  m_->AddEntryComputation(b.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  // f32[1,5] { 7.0, 2.718, 2.718, 2.718, 2.718 }
  Array2D<float> expected_array(1, 5);
  expected_array(0, 0) = 7.0f;
  for (int col = 1; col < 5; ++col) {
    expected_array(0, col) = kPadValue;
  }
  auto expected = LiteralUtil::CreateR2FromArray2D<float>(expected_array);

  // Compared with a tolerance rather than exact equality.
  EXPECT_TRUE(LiteralTestUtil::Near(expected, result, ErrorSpec(0.031250)));
}
749 
// Combines negative edge padding with interior padding on a 4x3 operand such
// that dimension 0 of the result is empty, and checks that the evaluator
// produces an f32[0,9] literal.
TEST_P(HloEvaluatorBf16Test, NegativeAndInteriorPadding2D) {
  HloComputation::Builder b(TestName());

  // Operand after FillUnique(1.0f):
  // f32[4,3] {
  //  { 1, 2, 3 },
  //  { 5, 6, 7 },
  //  { 9, 10, 11 },
  //  { 13, 14, 15 },
  // }
  auto input_array = absl::make_unique<Array2D<float>>(4, 3);
  input_array->FillUnique(1.0f);
  auto input = LiteralUtil::CreateR2FromArray2D<float>(*input_array);
  HloInstruction* input_instruction =
      b.AddInstruction(HloInstruction::CreateConstant(std::move(input)));

  auto pad_value_instruction = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.718f)));

  // Negative padding that results in zero dimensions. Triples are
  // {edge_padding_low, edge_padding_high, interior_padding}:
  //   dim 0: 4 elements + 3 interior - 2 - 5 = 0
  //   dim 1: 3 elements + 4 interior - 2 + 4 = 9
  auto r2_padding_on_dim0_dim1 =
      CreatePaddingConfig({{{-2, -5, 1}}, {{-2, 4, 2}}});

  Shape shape = ShapeUtil::MakeShape(F32, {0, 9});
  b.AddInstruction(HloInstruction::CreatePad(shape, input_instruction,
                                             pad_value_instruction,
                                             r2_padding_on_dim0_dim1));

  m_->AddEntryComputation(b.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  // The expected literal has no elements; only its 0x9 shape matters.
  auto expected_array = absl::make_unique<Array2D<float>>(0, 9);
  auto expected = LiteralUtil::CreateR2FromArray2D<float>(*expected_array);

  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
788 
// Evaluates a dot of a 4x1 lhs with a single-row rhs, contracting lhs
// dimension 1 with rhs dimension 0, and checks the 4x2 result.
TEST_P(HloEvaluatorBf16Test, DotRank2AndRank1) {
  HloComputation::Builder b(TestName());

  // lhs after FillUnique(1.0f):
  // f32[4,1] {
  //  { 1 },
  //  { 2 },
  //  { 3 },
  //  { 4 },
  // }
  Array2D<float> lhs_array(4, 1);
  lhs_array.FillUnique(1.0f);
  HloInstruction* lhs_instruction =
      b.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2FromArray2D<float>(lhs_array)));

  // rhs:
  // Built with CreateR2, so this is the rank-2, single-row literal
  // { { 1, 2 } } rather than a rank-1 vector.
  HloInstruction* rhs_instruction = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2<float>({{1, 2}})));

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  Shape shape = ShapeUtil::MakeShape(F32, {4, 2});
  b.AddInstruction(HloInstruction::CreateDot(shape, lhs_instruction,
                                             rhs_instruction, dot_dnums,
                                             DefaultPrecisionConfig(2)));
  m_->AddEntryComputation(b.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  // Row i of the result is lhs[i] * { 1, 2 }.
  // clang-format off
  Array2D<float> expected_array({
      {1.f, 2.f},
      {2.f, 4.f},
      {3.f, 6.f},
      {4.f, 8.f},
  });
  // clang-format on
  auto expected = LiteralUtil::CreateR2FromArray2D<float>(expected_array);

  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
834 
// Evaluates a dot of a length-3 vector with a 3x2 matrix, contracting
// dimension 0 of both operands, and checks the length-2 result.
TEST_P(HloEvaluatorBf16Test, DotRank1AndRank2) {
  HloComputation::Builder b(TestName());

  // lhs:
  // f32[3]
  //  { 1, 2, 3 },
  HloInstruction* lhs_instruction = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({1, 2, 3})));

  // rhs after FillUnique(1.0f):
  // f32[3,2] {
  //  { 1, 2 },
  //  { 3, 4 },
  //  { 5, 6 },
  // }
  Array2D<float> rhs_array(3, 2);
  rhs_array.FillUnique(1.0f);
  HloInstruction* rhs_instruction =
      b.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2FromArray2D<float>(rhs_array)));

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(0);
  dot_dnums.add_rhs_contracting_dimensions(0);
  Shape shape = ShapeUtil::MakeShape(F32, {2});
  b.AddInstruction(HloInstruction::CreateDot(shape, lhs_instruction,
                                             rhs_instruction, dot_dnums,
                                             DefaultPrecisionConfig(2)));
  m_->AddEntryComputation(b.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  // 1*1 + 2*3 + 3*5 = 22 and 1*2 + 2*4 + 3*6 = 28.
  auto expected = LiteralUtil::CreateR1<float>({22.f, 28.f});

  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
872 
// Evaluates a dot of a 4x3 lhs with a 3x2 rhs, contracting lhs dimension 1
// with rhs dimension 0, and checks the 4x2 result.
TEST_P(HloEvaluatorBf16Test, DotRank2AndRank2) {
  HloComputation::Builder b(TestName());

  // lhs after FillUnique(1.0f):
  // f32[4,3] {
  //  { 1, 2, 3 },
  //  { 5, 6, 7 },
  //  { 9, 10, 11 },
  //  { 13, 14, 15 },
  // }
  Array2D<float> lhs_array(4, 3);
  lhs_array.FillUnique(1.0f);
  HloInstruction* lhs_instruction =
      b.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2FromArray2D<float>(lhs_array)));

  // rhs after FillUnique(1.0f):
  // f32[3,2] {
  //  { 1, 2 },
  //  { 3, 4 },
  //  { 5, 6 },
  // }
  Array2D<float> rhs_array(3, 2);
  rhs_array.FillUnique(1.0f);
  HloInstruction* rhs_instruction =
      b.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2FromArray2D<float>(rhs_array)));

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  Shape shape = ShapeUtil::MakeShape(F32, {4, 2});
  b.AddInstruction(HloInstruction::CreateDot(shape, lhs_instruction,
                                             rhs_instruction, dot_dnums,
                                             DefaultPrecisionConfig(2)));
  m_->AddEntryComputation(b.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  // Ordinary matrix product of the two operands above.
  Array2D<float> expected_array({
      {22.f, 28.f},
      {58.f, 76.f},
      {94.f, 124.f},
      {130.f, 172.f},
  });
  auto expected = LiteralUtil::CreateR2FromArray2D<float>(expected_array);

  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
922 
// Evaluates a dot of two 2x2x3x1 operands with dimension 0 as the batch
// dimension and dimensions 1 and 2 contracted, and checks the 2x1x1 result.
TEST_P(HloEvaluatorBf16Test, DotRank4AndRank4) {
  HloComputation::Builder b(TestName());

  // lhs is filled with FillIota(1.0f), rhs with FillIota(2.0f).
  Array4D<float> lhs_array(2, 2, 3, 1);
  lhs_array.FillIota(1.0f);
  HloInstruction* lhs_instruction =
      b.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR4FromArray4D<float>(lhs_array)));

  Array4D<float> rhs_array(2, 2, 3, 1);
  rhs_array.FillIota(2.0f);
  HloInstruction* rhs_instruction =
      b.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR4FromArray4D<float>(rhs_array)));

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_batch_dimensions(0);
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_lhs_contracting_dimensions(2);
  dot_dnums.add_rhs_batch_dimensions(0);
  dot_dnums.add_rhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(2);

  Shape shape = ShapeUtil::MakeShape(F32, {2, 1, 1});
  b.AddInstruction(HloInstruction::CreateDot(shape, lhs_instruction,
                                             rhs_instruction, dot_dnums,
                                             DefaultPrecisionConfig(2)));
  m_->AddEntryComputation(b.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  // Accumulates i * i + i in float for i = first, first + 1, ..., last - 1,
  // in ascending order.
  auto sum_of_products = [](int first, int last) {
    float total = 0;
    for (int i = first; i < last; ++i) {
      const float value = static_cast<float>(i);
      total += value * value + value;
    }
    return total;
  };
  const float batch0_expected = sum_of_products(1, 7);
  const float batch1_expected = sum_of_products(7, 13);

  Array3D<float> expected_array({{{batch0_expected}}, {{batch1_expected}}});
  auto expected = LiteralUtil::CreateR3FromArray3D<float>(expected_array);

  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
967 
// Evaluates a 1D convolution of a three-element input with a two-element
// kernel, using one element of high padding, and checks the 1x1x3 result.
TEST_P(HloEvaluatorBf16Test, SimpleConv1D) {
  HloComputation::Builder b(TestName());

  // Input and kernel are both 1x1xN arrays.
  Array3D<float> lhs_array = {{{1, 2, 3}}};
  HloInstruction* lhs_instruction =
      b.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR3FromArray3D<float>(lhs_array)));

  Array3D<float> rhs_array = {{{3.f, 4.f}}};
  HloInstruction* rhs_instruction =
      b.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR3FromArray3D<float>(rhs_array)));

  // Single window dimension: size 2, unit stride, no dilation, and padding
  // only on the high side.
  Window window;
  WindowDimension dim;
  dim.set_size(2);
  dim.set_stride(1);
  dim.set_window_dilation(1);
  dim.set_base_dilation(1);
  dim.set_padding_low(0);
  dim.set_padding_high(1);
  *window.add_dimensions() = dim;

  // Dimension 0 is batch / kernel output feature, dimension 1 is feature /
  // kernel input feature, and dimension 2 is the only spatial dimension.
  ConvolutionDimensionNumbers dnums;
  dnums.set_input_batch_dimension(0);
  dnums.set_input_feature_dimension(1);
  dnums.add_input_spatial_dimensions(2);

  dnums.set_output_batch_dimension(0);
  dnums.set_output_feature_dimension(1);
  dnums.add_output_spatial_dimensions(2);

  dnums.set_kernel_output_feature_dimension(0);
  dnums.set_kernel_input_feature_dimension(1);
  dnums.add_kernel_spatial_dimensions(2);

  Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 3});
  b.AddInstruction(HloInstruction::CreateConvolve(
      shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1,
      /*batch_group_count=*/1, window, dnums, DefaultPrecisionConfig(2)));
  m_->AddEntryComputation(b.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  // 1*3 + 2*4 = 11, 2*3 + 3*4 = 18, 3*3 + pad = 9.
  Array3D<float> expected_array = {{{11.f, 18.f, 9.f}}};
  auto expected = LiteralUtil::CreateR3FromArray3D<float>(expected_array);

  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
1016 
TEST_P(HloEvaluatorBf16Test,Simple4x4Conv2DWith2x2Kernel)1017 TEST_P(HloEvaluatorBf16Test, Simple4x4Conv2DWith2x2Kernel) {
1018   HloComputation::Builder b(TestName());
1019 
1020   Array4D<float> lhs_array(1, 1, 4, 4);
1021   // clang-format off
1022   lhs_array.FillWithYX(Array2D<float>({
1023     {1,  2,  3,  4 },
1024     {5,  6,  7,  8 },
1025     {9,  10, 11, 12},
1026     {13, 14, 15, 16},
1027   }));
1028   // clang-format on
1029   auto lhs_literal = LiteralUtil::CreateR4FromArray4D<float>(lhs_array);
1030   HloInstruction* lhs_instruction =
1031       b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal)));
1032 
1033   Array4D<float> rhs_array(1, 1, 2, 2);
1034   // clang-format off
1035   rhs_array.FillWithYX(Array2D<float>({
1036     {5, 6},
1037     {7, 8},
1038   }));
1039   // clang-format on
1040   auto rhs_literal = LiteralUtil::CreateR4FromArray4D<float>(rhs_array);
1041   HloInstruction* rhs_instruction =
1042       b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal)));
1043 
1044   Window window;
1045   WindowDimension dim;
1046   dim.set_size(2);
1047   dim.set_stride(1);
1048   dim.set_padding_low(0);
1049   dim.set_padding_high(1);
1050   dim.set_window_dilation(1);
1051   dim.set_base_dilation(1);
1052   *window.add_dimensions() = dim;
1053   *window.add_dimensions() = dim;
1054 
1055   ConvolutionDimensionNumbers dnums =
1056       XlaBuilder::CreateDefaultConvDimensionNumbers(2);
1057 
1058   Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4});
1059   b.AddInstruction(HloInstruction::CreateConvolve(
1060       shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1,
1061       /*batch_group_count=*/1, window, dnums, DefaultPrecisionConfig(2)));
1062   m_->AddEntryComputation(b.Build());
1063 
1064   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
1065 
1066   Array4D<float> expected_array(1, 1, 4, 4);
1067   // clang-format off
1068   expected_array.FillWithYX(Array2D<float>({
1069     {100, 126, 152,  76},
1070     {204, 230, 256, 124},
1071     {308, 334, 360, 172},
1072     {149, 160, 171,  80},
1073   }));
1074   // clang-format on
1075   auto expected = LiteralUtil::CreateR4FromArray4D<float>(expected_array);
1076 
1077   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
1078 }
1079 
TEST_P(HloEvaluatorBf16Test,Conv2DGeneralDimensionsReversed)1080 TEST_P(HloEvaluatorBf16Test, Conv2DGeneralDimensionsReversed) {
1081   HloComputation::Builder b(TestName());
1082 
1083   // clang-format off
1084   // Input dimensions: [feature=2, height=3, batch=1, width=4]
1085   Array4D<float> input({
1086     {{{1, 2, 3, 4}},
1087      {{5, 6, 7, 8}},
1088      {{9, 10, 11, 12}}},
1089     {{{13, 14, 15, 16}},
1090      {{17, 18, 19, 20}},
1091      {{21, 22, 23, 24}}}
1092   });
1093   // Weight dimensions:
1094   // [kernel_output_feature=1, width=3, kernel_input_feature=2, height=3]
1095   Array4D<float> weight({{
1096     {{1, 7, 13},
1097      {4, 10, 16}},
1098     {{2, 8, 14},
1099      {5, 11, 17}},
1100     {{3, 9, 15},
1101      {6, 12, 18}}
1102   }});
1103   // clang-format on
1104 
1105   auto lhs_literal = LiteralUtil::CreateR4FromArray4D<float>(input);
1106   HloInstruction* lhs_instruction =
1107       b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal)));
1108 
1109   auto rhs_literal = LiteralUtil::CreateR4FromArray4D<float>(weight);
1110   HloInstruction* rhs_instruction =
1111       b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal)));
1112   rhs_instruction = b.AddInstruction(HloInstruction::CreateReverse(
1113       rhs_instruction->shape(), rhs_instruction, {3, 1}));
1114 
1115   Window window;
1116   WindowDimension dim;
1117   dim.set_size(3);
1118   dim.set_stride(1);
1119   dim.set_padding_low(0);
1120   dim.set_padding_high(0);
1121   dim.set_window_dilation(1);
1122   dim.set_base_dilation(1);
1123   dim.set_window_reversal(true);
1124   *window.add_dimensions() = dim;
1125   *window.add_dimensions() = dim;
1126 
1127   ConvolutionDimensionNumbers dnums;
1128   dnums.set_input_batch_dimension(2);
1129   dnums.set_output_batch_dimension(2);
1130   dnums.set_input_feature_dimension(0);
1131   dnums.set_output_feature_dimension(0);
1132   dnums.add_input_spatial_dimensions(1);
1133   dnums.add_output_spatial_dimensions(1);
1134   dnums.add_input_spatial_dimensions(3);
1135   dnums.add_output_spatial_dimensions(3);
1136 
1137   dnums.set_kernel_output_feature_dimension(0);
1138   dnums.set_kernel_input_feature_dimension(2);
1139   dnums.add_kernel_spatial_dimensions(3);
1140   dnums.add_kernel_spatial_dimensions(1);
1141 
1142   Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2});
1143   b.AddInstruction(HloInstruction::CreateConvolve(
1144       shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1,
1145       /*batch_group_count=*/1, window, dnums, DefaultPrecisionConfig(2)));
1146   m_->AddEntryComputation(b.Build());
1147 
1148   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
1149 
1150   // clang-format off
1151   // Result dimensions: [feature=1, height=1, batch=1, width=2]
1152   Array4D<float> expected_array({{{{2514, 2685}}}});
1153   Array4D<float> expected_array_bf16({{{{2512, 2688}}}});
1154   // clang-format on
1155   auto expected = LiteralUtil::CreateR4FromArray4D<float>(
1156       use_bfloat16_ ? expected_array_bf16 : expected_array);
1157 
1158   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
1159 }
1160 
TEST_P(HloEvaluatorBf16Test,Conv2DGeneralDimensions)1161 TEST_P(HloEvaluatorBf16Test, Conv2DGeneralDimensions) {
1162   HloComputation::Builder b(TestName());
1163 
1164   // clang-format off
1165   // Input dimensions: [feature=2, height=3, batch=1, width=4]
1166   Array4D<float> input({
1167     {{{1, 2, 3, 4}},
1168      {{5, 6, 7, 8}},
1169      {{9, 10, 11, 12}}},
1170     {{{13, 14, 15, 16}},
1171      {{17, 18, 19, 20}},
1172      {{21, 22, 23, 24}}}
1173   });
1174   // Weight dimensions:
1175   // [kernel_output_feature=1, width=3, kernel_input_feature=2, height=3]
1176   Array4D<float> weight({{
1177     {{1, 7, 13},
1178      {4, 10, 16}},
1179     {{2, 8, 14},
1180      {5, 11, 17}},
1181     {{3, 9, 15},
1182      {6, 12, 18}}
1183   }});
1184   // clang-format on
1185 
1186   auto lhs_literal = LiteralUtil::CreateR4FromArray4D<float>(input);
1187   HloInstruction* lhs_instruction =
1188       b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal)));
1189 
1190   auto rhs_literal = LiteralUtil::CreateR4FromArray4D<float>(weight);
1191   HloInstruction* rhs_instruction =
1192       b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal)));
1193 
1194   Window window;
1195   WindowDimension dim;
1196   dim.set_size(3);
1197   dim.set_stride(1);
1198   dim.set_padding_low(0);
1199   dim.set_padding_high(0);
1200   dim.set_window_dilation(1);
1201   dim.set_base_dilation(1);
1202   *window.add_dimensions() = dim;
1203   *window.add_dimensions() = dim;
1204 
1205   ConvolutionDimensionNumbers dnums;
1206   dnums.set_input_batch_dimension(2);
1207   dnums.set_output_batch_dimension(2);
1208   dnums.set_input_feature_dimension(0);
1209   dnums.set_output_feature_dimension(0);
1210   dnums.add_input_spatial_dimensions(1);
1211   dnums.add_output_spatial_dimensions(1);
1212   dnums.add_input_spatial_dimensions(3);
1213   dnums.add_output_spatial_dimensions(3);
1214 
1215   dnums.set_kernel_output_feature_dimension(0);
1216   dnums.set_kernel_input_feature_dimension(2);
1217   dnums.add_kernel_spatial_dimensions(3);
1218   dnums.add_kernel_spatial_dimensions(1);
1219 
1220   Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2});
1221   b.AddInstruction(HloInstruction::CreateConvolve(
1222       shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1,
1223       /*batch_group_count=*/1, window, dnums, DefaultPrecisionConfig(2)));
1224   m_->AddEntryComputation(b.Build());
1225 
1226   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
1227 
1228   // clang-format off
1229   // Result dimensions: [feature=1, height=1, batch=1, width=2]
1230   Array4D<float> expected_array({{{{2514, 2685}}}});
1231   Array4D<float> expected_array_bf16({{{{2512, 2688}}}});
1232   // clang-format on
1233   auto expected = LiteralUtil::CreateR4FromArray4D<float>(
1234       use_bfloat16_ ? expected_array_bf16 : expected_array);
1235 
1236   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
1237 }
1238 
TEST_P(HloEvaluatorBf16Test,DilatedBaseConv2DWithHighPadding)1239 TEST_P(HloEvaluatorBf16Test, DilatedBaseConv2DWithHighPadding) {
1240   HloComputation::Builder b(TestName());
1241 
1242   Array4D<float> lhs_array(1, 1, 4, 4);
1243   // clang-format off
1244   lhs_array.FillWithYX(Array2D<float>({
1245     {1,  2,  3,  4 },
1246     {5,  6,  7,  8 },
1247     {9,  10, 11, 12},
1248     {13, 14, 15, 16},
1249   }));
1250   // clang-format on
1251   auto lhs_literal = LiteralUtil::CreateR4FromArray4D<float>(lhs_array);
1252   HloInstruction* lhs_instruction =
1253       b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal)));
1254 
1255   Array4D<float> rhs_array(1, 1, 2, 2);
1256   // clang-format off
1257   rhs_array.FillWithYX(Array2D<float>({
1258     {5, 6},
1259     {7, 8},
1260   }));
1261   // clang-format on
1262   auto rhs_literal = LiteralUtil::CreateR4FromArray4D<float>(rhs_array);
1263   HloInstruction* rhs_instruction =
1264       b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal)));
1265 
1266   Window window;
1267   WindowDimension dim;
1268   dim.set_size(2);
1269   dim.set_stride(1);
1270   dim.set_padding_low(0);
1271   dim.set_padding_high(1);
1272   dim.set_window_dilation(1);
1273   dim.set_base_dilation(2);
1274   *window.add_dimensions() = dim;
1275   *window.add_dimensions() = dim;
1276 
1277   ConvolutionDimensionNumbers dnums =
1278       XlaBuilder::CreateDefaultConvDimensionNumbers(2);
1279 
1280   Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 7, 7});
1281   b.AddInstruction(HloInstruction::CreateConvolve(
1282       shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1,
1283       /*batch_group_count=*/1, window, dnums, DefaultPrecisionConfig(2)));
1284   m_->AddEntryComputation(b.Build());
1285 
1286   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
1287 
1288   Array4D<float> expected_array(1, 1, 7, 7);
1289   expected_array.FillWithYX(Array2D<float>({
1290       {5, 12, 10, 18, 15, 24, 20},
1291       {35, 48, 42, 56, 49, 64, 56},
1292       {25, 36, 30, 42, 35, 48, 40},
1293       {63, 80, 70, 88, 77, 96, 84},
1294       {45, 60, 50, 66, 55, 72, 60},
1295       {91, 112, 98, 120, 105, 128, 112},
1296       {65, 84, 70, 90, 75, 96, 80},
1297   }));
1298   auto expected = LiteralUtil::CreateR4FromArray4D<float>(expected_array);
1299 
1300   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
1301 }
1302 
TEST_P(HloEvaluatorBf16Test,DilatedBaseConv2DWithLowAndHighPadding)1303 TEST_P(HloEvaluatorBf16Test, DilatedBaseConv2DWithLowAndHighPadding) {
1304   HloComputation::Builder b(TestName());
1305 
1306   Array4D<float> lhs_array(1, 1, 4, 4);
1307   // clang-format off
1308   lhs_array.FillWithYX(Array2D<float>({
1309     {1,  2,  3,  4 },
1310     {5,  6,  7,  8 },
1311     {9,  10, 11, 12},
1312     {13, 14, 15, 16},
1313   }));
1314   // clang-format on
1315   auto lhs_literal = LiteralUtil::CreateR4FromArray4D<float>(lhs_array);
1316   HloInstruction* lhs_instruction =
1317       b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal)));
1318 
1319   Array4D<float> rhs_array(1, 1, 2, 2);
1320   // clang-format off
1321   rhs_array.FillWithYX(Array2D<float>({
1322     {5, 6},
1323     {7, 8},
1324   }));
1325   // clang-format on
1326   auto rhs_literal = LiteralUtil::CreateR4FromArray4D<float>(rhs_array);
1327   HloInstruction* rhs_instruction =
1328       b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal)));
1329 
1330   Window window;
1331   WindowDimension dim;
1332   dim.set_size(2);
1333   dim.set_stride(1);
1334   dim.set_padding_low(1);
1335   dim.set_padding_high(1);
1336   dim.set_window_dilation(1);
1337   dim.set_base_dilation(2);
1338   *window.add_dimensions() = dim;
1339   *window.add_dimensions() = dim;
1340 
1341   ConvolutionDimensionNumbers dnums =
1342       XlaBuilder::CreateDefaultConvDimensionNumbers(2);
1343 
1344   Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 8, 8});
1345   b.AddInstruction(HloInstruction::CreateConvolve(
1346       shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1,
1347       /*batch_group_count=*/1, window, dnums, DefaultPrecisionConfig(2)));
1348   m_->AddEntryComputation(b.Build());
1349 
1350   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
1351 
1352   Array4D<float> expected_array(1, 1, 8, 8);
1353   expected_array.FillWithYX(Array2D<float>({
1354       {8, 7, 16, 14, 24, 21, 32, 28},
1355       {6, 5, 12, 10, 18, 15, 24, 20},
1356       {40, 35, 48, 42, 56, 49, 64, 56},
1357       {30, 25, 36, 30, 42, 35, 48, 40},
1358       {72, 63, 80, 70, 88, 77, 96, 84},
1359       {54, 45, 60, 50, 66, 55, 72, 60},
1360       {104, 91, 112, 98, 120, 105, 128, 112},
1361       {78, 65, 84, 70, 90, 75, 96, 80},
1362   }));
1363   auto expected = LiteralUtil::CreateR4FromArray4D<float>(expected_array);
1364 
1365   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
1366 }
1367 
TEST_P(HloEvaluatorBf16Test,DilatedWindowAndBaseConv2DWithDifferentLowAndHighPaddingAndStrides)1368 TEST_P(HloEvaluatorBf16Test,
1369        DilatedWindowAndBaseConv2DWithDifferentLowAndHighPaddingAndStrides) {
1370   HloComputation::Builder b(TestName());
1371 
1372   Array4D<float> lhs_array(1, 1, 4, 4);
1373   // clang-format off
1374   lhs_array.FillWithYX(Array2D<float>({
1375     {1,  2,  3,  4 },
1376     {5,  6,  7,  8 },
1377     {9,  10, 11, 12},
1378     {13, 14, 15, 16},
1379   }));
1380   // clang-format on
1381   auto lhs_literal = LiteralUtil::CreateR4FromArray4D<float>(lhs_array);
1382   HloInstruction* lhs_instruction =
1383       b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal)));
1384 
1385   Array4D<float> rhs_array(1, 1, 2, 3);
1386   // clang-format off
1387   rhs_array.FillWithYX(Array2D<float>({
1388     {5, 6, 7},
1389     {8, 9, 10},
1390   }));
1391   // clang-format on
1392   auto rhs_literal = LiteralUtil::CreateR4FromArray4D<float>(rhs_array);
1393   HloInstruction* rhs_instruction =
1394       b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal)));
1395 
1396   Window window;
1397   WindowDimension dim;
1398   dim.set_size(2);
1399   dim.set_stride(1);
1400   dim.set_padding_low(2);
1401   dim.set_padding_high(2);
1402   dim.set_window_dilation(2);
1403   dim.set_base_dilation(2);
1404   *window.add_dimensions() = dim;
1405   dim.set_size(3);
1406   dim.set_stride(3);
1407   dim.set_padding_low(2);
1408   dim.set_padding_high(-1);
1409   dim.set_window_dilation(1);
1410   dim.set_base_dilation(3);
1411   *window.add_dimensions() = dim;
1412 
1413   ConvolutionDimensionNumbers dnums =
1414       XlaBuilder::CreateDefaultConvDimensionNumbers(2);
1415 
1416   Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 9, 3});
1417   b.AddInstruction(HloInstruction::CreateConvolve(
1418       shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1,
1419       /*batch_group_count=*/1, window, dnums, DefaultPrecisionConfig(2)));
1420   m_->AddEntryComputation(b.Build());
1421 
1422   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
1423 
1424   Array4D<float> expected_array(1, 1, 9, 3);
1425   expected_array.FillWithYX(Array2D<float>({
1426       {10, 20, 30},
1427       {0, 0, 0},
1428       {57, 74, 91},
1429       {0, 0, 0},
1430       {125, 142, 159},
1431       {0, 0, 0},
1432       {193, 210, 227},
1433       {0, 0, 0},
1434       {91, 98, 105},
1435   }));
1436   auto expected = LiteralUtil::CreateR4FromArray4D<float>(expected_array);
1437 
1438   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
1439 }
1440 
TEST_P(HloEvaluatorBf16Test,Conv2DGroupedConvolution)1441 TEST_P(HloEvaluatorBf16Test, Conv2DGroupedConvolution) {
1442   HloComputation::Builder b(TestName());
1443   std::vector<int64> input_dims = {1, 2, 2, 4};
1444   std::vector<int64> filter_dims = {2, 2, 2, 8};
1445   Shape input_shape = ShapeUtil::MakeShapeWithType<float>(input_dims);
1446   Shape filter_shape = ShapeUtil::MakeShapeWithType<float>(filter_dims);
1447   // Tensorflow dimension numbers for 2D convolution.
1448   ConvolutionDimensionNumbers dnums;
1449   dnums.set_input_batch_dimension(0);
1450   dnums.set_output_batch_dimension(0);
1451   dnums.add_input_spatial_dimensions(1);
1452   dnums.add_output_spatial_dimensions(1);
1453   dnums.add_input_spatial_dimensions(2);
1454   dnums.add_output_spatial_dimensions(2);
1455   dnums.set_input_feature_dimension(3);
1456   dnums.set_output_feature_dimension(3);
1457   dnums.add_kernel_spatial_dimensions(0);
1458   dnums.add_kernel_spatial_dimensions(1);
1459   dnums.set_kernel_input_feature_dimension(2);
1460   dnums.set_kernel_output_feature_dimension(3);
1461 
1462   Window window;
1463   WindowDimension dim;
1464   dim.set_size(2);
1465   dim.set_stride(1);
1466   dim.set_padding_low(0);
1467   dim.set_padding_high(0);
1468   dim.set_window_dilation(1);
1469   dim.set_base_dilation(1);
1470   *window.add_dimensions() = dim;
1471   *window.add_dimensions() = dim;
1472 
1473   std::vector<float> input_elems(ShapeUtil::ElementsIn(input_shape));
1474   std::iota(input_elems.begin(), input_elems.end(), -7);
1475   auto input_r1 = LiteralUtil::CreateR1<float>(input_elems);
1476   auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
1477   HloInstruction* lhs_instruction =
1478       b.AddInstruction(HloInstruction::CreateConstant(std::move(input_r4)));
1479 
1480   std::vector<float> filter_elems(ShapeUtil::ElementsIn(filter_shape));
1481   std::iota(filter_elems.begin(), filter_elems.end(), -31);
1482   auto filter_r1 = LiteralUtil::CreateR1<float>(filter_elems);
1483   auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
1484   HloInstruction* rhs_instruction =
1485       b.AddInstruction(HloInstruction::CreateConstant(std::move(filter_r4)));
1486 
1487   Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 8});
1488   b.AddInstruction(HloInstruction::CreateConvolve(
1489       shape, lhs_instruction, rhs_instruction,
1490       /*feature_group_count=*/2, /*batch_group_count=*/1, window, dnums,
1491       DefaultPrecisionConfig(2)));
1492   m_->AddEntryComputation(b.Build());
1493 
1494   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
1495 
1496   Array4D<float> expected_array(1, 1, 1, 8);
1497   expected_array.FillWithYX(
1498       Array2D<float>({{668, 664, 660, 656, 668, 680, 692, 704}}));
1499   auto expected = LiteralUtil::CreateR4FromArray4D<float>(expected_array);
1500   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
1501 }
1502 
1503 // Initialization of data sets for FFT tests:
1504 
// Populates the shared c64[2, 4, 8] literals used by the FFT tests in this
// file:
//   * fft_c64x2x4x8_    - the base data set fed to the forward FFT tests.
//   * fft_c64x2x4x8_1d_ - reference result for fft_length={8}.
//   * fft_c64x2x4x8_2d_ - reference result for fft_length={4, 8}.
//   * fft_c64x2x4x8_3d_ - reference result for fft_length={2, 4, 8}.
// The inverse (IFFT) tests feed one of the reference results back in and
// expect to recover fft_c64x2x4x8_. Each element is written as {real, imag}.
void HloEvaluatorTest::InitializeFftData() {
  // clang-format off
  // Base data set.
  fft_c64x2x4x8_ = LiteralUtil::CreateR3<complex64>({
    {{{0.0, 0.0}, {1.0, 0.0}, {2.0, 0.0}, {3.0, 0.0},
      {4.0, 0.0}, {5.0, 0.0}, {6.0, 0.0}, {7.0, 0.0}},
     {{0.0, 0.0}, {0.0, 1.0}, {0.0, 2.0}, {0.0, 3.0},
      {0.0, 4.0}, {0.0, 5.0}, {0.0, 6.0}, {0.0, 7.0}},
     {{0.0, 7.0}, {1.0, 6.0}, {2.0, 5.0}, {3.0, 4.0},
      {4.0, 3.0}, {5.0, 2.0}, {6.0, 1.0}, {7.0, 0.0}},
     {{7.0, 0.0}, {6.0, 1.0}, {5.0, 2.0}, {4.0, 3.0},
      {3.0, 4.0}, {2.0, 5.0}, {1.0, 6.0}, {0.0, 7.0}}},
    {{{-4.0, 0.0}, {-3.0, 0.0}, {-2.0, 0.0}, {-1.0, 0.0},
      {1.0, 0.0}, {2.0, 0.0}, {3.0, 0.0}, {4.0, 0.0}},
     {{0.0, -4.0}, {0.0, -3.0}, {0.0, -2.0}, {0.0, -1.0},
      {0.0, 1.0}, {0.0, 2.0}, {0.0, 3.0}, {0.0, 4.0}},
     {{3.5, 3.5}, {-1.707107, -0.707107}, {-1.0, -0.0}, {-0.707107, 0.292893},
      {-0.5, 0.5}, {-0.292893, 0.707107}, {0.0, 1.0}, {0.707107, 1.707107}},
     {{3.5, 3.5}, {1.707107, 0.707107}, {1.0, 0.0}, {0.707107, -0.292893},
      {0.5, -0.5}, {0.292893, -0.707107}, {-0.0, -1.0}, {-0.707107, -1.707107}}}
  });
  // Reference result of the 1D FFT (fft_length={8}) of the base data set.
  fft_c64x2x4x8_1d_ = LiteralUtil::CreateR3<complex64>({
    {{{28.0, 0.0}, {-4.0, 9.656854}, {-4.0, 4.0}, {-4.0, 1.656854},
      {-4.0, 0.0}, {-4.0, -1.656854}, {-4.0, -4.0}, {-4.0, -9.656854}},
     {{0.0, 28.0}, {-9.656854, -4.0}, {-4.0, -4.0}, {-1.656854, -4.0},
      {0.0, -4.0}, {1.656854, -4.0}, {4.0, -4.0}, {9.656854, -4.0}},
     {{28.0, 28.0}, {5.656854, 13.656854}, {0.0, 8.0}, {-2.343146, 5.656854},
      {-4.0, 4.0}, {-5.656854, 2.343146}, {-8.0, -0.0}, {-13.656854, -5.656854}},  // NOLINT
     {{28.0, 28.0}, {-5.656854, -13.656854}, {-0.0, -8.0}, {2.343146, -5.656854},  // NOLINT
      {4.0, -4.0}, {5.656854, -2.343146}, {8.0, 0.0}, {13.656854, 5.656854}}},
    {{{0.0, 0.0}, {-5.0, 12.071068}, {-4.0, 4.0}, {-5.0, 2.071068},
      {-4.0, 0.0}, {-5.0, -2.071068}, {-4.0, -4.0}, {-5.0, -12.071068}},
     {{0.0, 0.0}, {-12.071068, -5.0}, {-4.0, -4.0}, {-2.071068, -5.0},
      {0.0, -4.0}, {2.071068, -5.0}, {4.0, -4.0}, {12.071068, -5.0}},
     {{0.0, 7.0}, {1.0, 6.0}, {2.0, 5.0}, {3.0, 4.0},
      {4.0, 3.0}, {5.0, 2.0}, {6.0, 1.0}, {7.0, 0.0}},
     {{7.0, 0.0}, {6.0, 1.0}, {5.0, 2.0}, {4.0, 3.0},
      {3.0, 4.0}, {2.0, 5.0}, {1.0, 6.0}, {0.0, 7.0}}}
  });
  // Reference result of the 2D FFT (fft_length={4, 8}) of the base data set.
  fft_c64x2x4x8_2d_ = LiteralUtil::CreateR3<complex64>({
    {{{84.0, 84.0}, {-13.656854, 5.656854}, {-8.0, 0.0}, {-5.656854, -2.343146},
      {-4.0, -4.0}, {-2.343146, -5.656854}, {0.0, -8.0}, {5.656854, -13.656854}},  // NOLINT
     {{0.0, 0.0}, {0.0, -0.0}, {0.0, 0.0}, {0.0, 0.0},
      {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
     {{28.0, -28.0}, {16.970562, 40.970562}, {0.0, 24.0}, {-7.029438, 16.970562},      // NOLINT
      {-12.0, 12.0}, {-16.970562, 7.029438}, {-24.0, 0.0}, {-40.970562, -16.970562}},  // NOLINT
     {{0.0, -56.0}, {-19.313708, -8.0}, {-8.0, -8.0}, {-3.313708, -8.0},
      {0.0, -8.0}, {3.313708, -8.0}, {8.0, -8.0}, {19.313708, -8.0}}},
    {{{7.0, 7.0}, {-10.071068, 14.071068}, {-1.0, 7.0}, {-0.071068, 4.071068},
      {3.0, 3.0}, {4.071068, -0.071068}, {7.0, -1.0}, {14.071068, -10.071068}},
     {{0.0, 0.0}, {-12.0, 24.142136}, {-12.0, 8.0}, {-16.0, 4.142136},
      {-16.0, 0.0}, {-20.0, -4.142136}, {-20.0, -8.0}, {-24.0, -24.142136}},
     {{-7.0, 7.0}, {2.071068, 22.071068}, {-3.0, 11.0}, {-3.928932, 8.071068},
      {-3.0, 3.0}, {-4.071068, -0.071068}, {-3.0, -5.0}, {-10.071068, -14.071068}},  // NOLINT
     {{0.0, -14.0}, {0.0, -12.0}, {0.0, -10.0}, {0.0, -8.0},
      {0.0, -6.0}, {0.0, -4.0}, {0.0, -2.0}, {0.0, 0.0}}}
  });
  // Reference result of the 3D FFT (fft_length={2, 4, 8}) of the base data
  // set.
  fft_c64x2x4x8_3d_ = LiteralUtil::CreateR3<complex64>({
    {{{91.0, 91.0}, {-23.727922, 19.727922}, {-9.0, 7.0}, {-5.727922, 1.727922},
      {-1.0, -1.0}, {1.727922, -5.727922}, {7.0, -9}, {19.727922, -23.727922}},
     {{0.0, 0.0}, {-12.0, 24.142136}, {-12.0, 8.0}, {-16.0, 4.142136},
      {-16.0, 0.0}, {-20.0, -4.142136}, {-20.0, -8.0}, {-24.0, -24.142136}},
     {{21.0, -21.0}, {19.041630, 63.041630}, {-3.0, 35.0}, {-10.958370, 25.041630},     // NOLINT
      {-15.0, 15.0}, {-21.041630, 6.958370}, {-27.0, -5.0}, {-51.041630, -31.041630}},  // NOLINT
     {{0.0, -70.0}, {-19.313708, -20.0}, {-8.0, -18.0}, {-3.313708, -16.0},
      {0.0, -14.0}, {3.313708, -12.0}, {8.0, -10.0}, {19.313708, -8.0}}},
    {{{77.0, 77.0}, {-3.585786, -8.414214}, {-7.0, -7.0}, {-5.585786, -6.414214},   // NOLINT
      {-7.0, -7.0}, {-6.414214, -5.585786}, {-7.0, -7.0}, {-8.414214, -3.585786}},  // NOLINT
     {{0.0, 0.0}, {12.0, -24.142136}, {12.0, -8.0}, {16.0, -4.142136},
      {16.0, 0.0}, {20.0, 4.142136}, {20.0, 8.0}, {24.0, 24.142136}},
     {{35.0, -35.0}, {14.899494, 18.899494}, {3.0, 13.0}, {-3.100506, 8.899494},
      {-9.0, 9.0}, {-12.899494, 7.100506}, {-21.0, 5.0}, {-30.899494, -2.899494}},  // NOLINT
     {{0.0, -42.0}, {-19.313708, 4.0}, {-8.0, 2.0}, {-3.313708, 0.0},
      {0.0, -2.0}, {3.313708, -4.0}, {8.0, -6.0}, {19.313708, -8.0}}}
  });
  // clang-format on
}
1581 
1582 // Simple FFT tests:
1583 
1584 TEST_F(HloEvaluatorTest, 1D_FFT_4_on_c64x4) {
1585   const char* hlo_text = R"(
1586 HloModule Fft
1587 
1588 ENTRY main {
1589   operand = c64[4] parameter(0)
1590   ROOT fft = c64[4] fft(operand), fft_type=FFT, fft_length={4}
1591 }
1592 )";
1593   auto input = LiteralUtil::CreateR1<complex64>(
1594       {{1.0, 0.0}, {2.0, 0.0}, {3.0, 0.0}, {4.0, 0.0}});
1595   auto expected = LiteralUtil::CreateR1<complex64>(
1596       {{10.0, 0.0}, {-2.0, 2.0}, {-2.0, 0.0}, {-2.0, -2.0}});
1597   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
1598   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
1599   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
1600   EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
1601 }
1602 
1603 TEST_F(HloEvaluatorTest, 1D_IFFT_4_on_c64x4) {
1604   const char* hlo_text = R"(
1605 HloModule Fft
1606 
1607 ENTRY main {
1608   operand = c64[4] parameter(0)
1609   ROOT ifft = c64[4] fft(operand), fft_type=IFFT, fft_length={4}
1610 }
1611 )";
1612   auto input = LiteralUtil::CreateR1<complex64>(
1613       {{10.0, 0.0}, {-2.0, 2.0}, {-2.0, 0.0}, {-2.0, -2.0}});
1614   auto expected = LiteralUtil::CreateR1<complex64>(
1615       {{1.0, 0.0}, {2.0, 0.0}, {3.0, 0.0}, {4.0, 0.0}});
1616   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
1617   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
1618   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
1619   EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
1620 }
1621 
1622 TEST_F(HloEvaluatorTest, 1D_RFFT_4_on_f32x4) {
1623   const char* hlo_text = R"(
1624 HloModule Fft
1625 
1626 ENTRY main {
1627   operand = f32[4] parameter(0)
1628   ROOT rfft = c64[3] fft(operand), fft_type=RFFT, fft_length={4}
1629 }
1630 )";
1631   auto input = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0, 4.0});
1632   auto expected =
1633       LiteralUtil::CreateR1<complex64>({{10.0, 0.0}, {-2.0, 2.0}, {-2.0, 0.0}});
1634   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
1635   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
1636   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
1637   EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
1638 }
1639 
1640 TEST_F(HloEvaluatorTest, 1D_IRFFT_4_on_c64x3) {
1641   const char* hlo_text = R"(
1642 HloModule Fft
1643 
1644 ENTRY main {
1645   operand = c64[3] parameter(0)
1646   ROOT irfft = f32[4] fft(operand), fft_type=IRFFT, fft_length={4}
1647 }
1648 )";
1649   auto input =
1650       LiteralUtil::CreateR1<complex64>({{10.0, 0.0}, {-2.0, 2.0}, {-2.0, 0.0}});
1651   auto expected = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0, 4.0});
1652   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
1653   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
1654   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
1655   EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
1656 }
1657 
1658 // 1D FFT tests:
1659 
1660 TEST_F(HloEvaluatorTest, 1D_FFT_8_on_c64x2x4x8) {
1661   const char* hlo_text = R"(
1662 HloModule Fft
1663 
1664 ENTRY main {
1665   operand = c64[2, 4, 8] parameter(0)
1666   ROOT fft = c64[2, 4, 8] fft(operand), fft_type=FFT, fft_length={8}
1667 }
1668 )";
1669   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
1670   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&fft_c64x2x4x8_}));
1671   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), fft_c64x2x4x8_1d_.shape()));
1672   EXPECT_TRUE(LiteralTestUtil::Near(fft_c64x2x4x8_1d_, result, fft_error_));
1673 }
1674 
1675 TEST_F(HloEvaluatorTest, 1D_IFFT_8_on_c64x2x4x8) {
1676   const char* hlo_text = R"(
1677 HloModule Fft
1678 
1679 ENTRY main {
1680   operand = c64[2, 4, 8] parameter(0)
1681   ROOT ifft = c64[2, 4, 8] fft(operand), fft_type=IFFT, fft_length={8}
1682 }
1683 )";
1684   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
1685   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&fft_c64x2x4x8_1d_}));
1686   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), fft_c64x2x4x8_.shape()));
1687   EXPECT_TRUE(LiteralTestUtil::Near(fft_c64x2x4x8_, result, fft_error_));
1688 }
1689 
1690 TEST_F(HloEvaluatorTest, 1D_RFFT_8_on_f32x8) {
1691   const char* hlo_text = R"(
1692 HloModule Fft
1693 
1694 ENTRY main {
1695   operand = f32[8] parameter(0)
1696   ROOT rfft = c64[5] fft(operand), fft_type=RFFT, fft_length={8}
1697 }
1698 )";
1699   auto input =
1700       LiteralUtil::CreateR1<float>({1.8, 2.7, 3.6, 4.5, 5.4, 6.3, 7.2, 8.1});
1701   auto expected = LiteralUtil::CreateR1<complex64>({{39.6, 0.0},
1702                                                     {-3.6, 8.691169},
1703                                                     {-3.6, 3.6},
1704                                                     {-3.6, 1.491169},
1705                                                     {-3.6, 0.0}});
1706   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
1707   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
1708   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
1709   EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
1710 }
1711 
1712 TEST_F(HloEvaluatorTest, 1D_IRFFT_8_on_c64x5) {
1713   const char* hlo_text = R"(
1714 HloModule Fft
1715 
1716 ENTRY main {
1717   operand = c64[5] parameter(0)
1718   ROOT irfft = f32[8] fft(operand), fft_type=IRFFT, fft_length={8}
1719 }
1720 )";
1721   auto input = LiteralUtil::CreateR1<complex64>({{39.6, 0.0},
1722                                                  {-3.6, 8.691169},
1723                                                  {-3.6, 3.6},
1724                                                  {-3.6, 1.491169},
1725                                                  {-3.6, 0.0}});
1726   auto expected =
1727       LiteralUtil::CreateR1<float>({1.8, 2.7, 3.6, 4.5, 5.4, 6.3, 7.2, 8.1});
1728   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
1729   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
1730   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
1731   EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
1732 }
1733 
1734 TEST_F(HloEvaluatorTest, 1D_RFFT_9_on_f32x9) {
1735   const char* hlo_text = R"(
1736 HloModule Fft
1737 
1738 ENTRY main {
1739   operand = f32[9] parameter(0)
1740   ROOT rfft = c64[5] fft(operand), fft_type=RFFT, fft_length={9}
1741 }
1742 )";
1743   auto input = LiteralUtil::CreateR1<float>(
1744       {1.8, 2.7, 3.6, 4.5, 5.4, 6.3, 7.2, 8.1, 9.9});
1745   auto expected = LiteralUtil::CreateR1<complex64>({{49.5, 0.0},
1746                                                     {-3.360560, 11.705792},
1747                                                     {-3.893717, 5.712929},
1748                                                     {-4.5, 3.117691},
1749                                                     {-4.895723, 1.021942}});
1750   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
1751   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
1752   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
1753   EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
1754 }
1755 
1756 TEST_F(HloEvaluatorTest, 1D_IRFFT_9_on_c64x5) {
1757   const char* hlo_text = R"(
1758 HloModule Fft
1759 
1760 ENTRY main {
1761   operand = c64[5] parameter(0)
1762   ROOT irfft = f32[9] fft(operand), fft_type=IRFFT, fft_length={9}
1763 }
1764 )";
1765   auto input = LiteralUtil::CreateR1<complex64>({{49.5, 0.0},
1766                                                  {-3.360560, 11.705792},
1767                                                  {-3.893717, 5.712929},
1768                                                  {-4.5, 3.117691},
1769                                                  {-4.895723, 1.021942}});
1770   auto expected = LiteralUtil::CreateR1<float>(
1771       {1.8, 2.7, 3.6, 4.5, 5.4, 6.3, 7.2, 8.1, 9.9});
1772   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
1773   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
1774   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
1775   EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
1776 }
1777 
1778 // 2D FFT tests:
1779 
1780 TEST_F(HloEvaluatorTest, 2D_FFT_4x8_on_c64x2x4x8) {
1781   const char* hlo_text = R"(
1782 HloModule Fft
1783 
1784 ENTRY main {
1785   operand = c64[2, 4, 8] parameter(0)
1786   ROOT fft = c64[2, 4, 8] fft(operand), fft_type=FFT, fft_length={4, 8}
1787 }
1788 )";
1789   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
1790   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&fft_c64x2x4x8_}));
1791   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), fft_c64x2x4x8_2d_.shape()));
1792   EXPECT_TRUE(LiteralTestUtil::Near(fft_c64x2x4x8_2d_, result, fft_error_));
1793 }
1794 
1795 TEST_F(HloEvaluatorTest, 2D_IFFT_4x8_on_c64x2x4x8) {
1796   const char* hlo_text = R"(
1797 HloModule Fft
1798 
1799 ENTRY main {
1800   operand = c64[2, 4, 8] parameter(0)
1801   ROOT ifft = c64[2, 4, 8] fft(operand), fft_type=IFFT, fft_length={4, 8}
1802 }
1803 )";
1804   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
1805   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&fft_c64x2x4x8_2d_}));
1806   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), fft_c64x2x4x8_.shape()));
1807   EXPECT_TRUE(LiteralTestUtil::Near(fft_c64x2x4x8_, result, fft_error_));
1808 }
1809 
1810 TEST_F(HloEvaluatorTest, 2D_RFFT_3x8_on_f32x3x8) {
1811   const char* hlo_text = R"(
1812 HloModule Fft
1813 
1814 ENTRY main {
1815   operand = f32[3, 8] parameter(0)
1816   ROOT rfft = c64[3, 5] fft(operand), fft_type=RFFT, fft_length={3, 8}
1817 }
1818 )";
1819   auto input =
1820       LiteralUtil::CreateR2<float>({{1.8, 2.7, 3.6, 4.5, 5.4, 6.3, 7.2, 8.1},
1821                                     {8.1, 7.2, 6.3, 5.4, 4.5, 3.6, 2.7, 1.8},
1822                                     {1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8}});
1823   auto expected = LiteralUtil::CreateR2<complex64>({{{118.8, 0.0},
1824                                                      {-4.4, 10.622540},
1825                                                      {-4.4, 4.4},
1826                                                      {-4.4, 1.822540},
1827                                                      {-4.4, 0.0}},
1828                                                     {{0.0, 0.0},
1829                                                      {-19.926162, 0.797280},
1830                                                      {-10.128203, -3.728203},
1831                                                      {-6.069756, -5.602720},
1832                                                      {-3.2, -6.928203}},
1833                                                     {{0.0, 0.0},
1834                                                      {13.526162, 14.653687},
1835                                                      {3.728203, 10.128203},
1836                                                      {-0.330244, 8.253687},
1837                                                      {-3.2, 6.928203}}});
1838   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
1839   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
1840   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
1841   EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
1842 }
1843 
1844 TEST_F(HloEvaluatorTest, 2D_IRFFT_3x8_on_c64x3x5) {
1845   const char* hlo_text = R"(
1846 HloModule Fft
1847 
1848 ENTRY main {
1849   operand = c64[3, 5] parameter(0)
1850   ROOT irfft = f32[3, 8] fft(operand), fft_type=IRFFT, fft_length={3, 8}
1851 }
1852 )";
1853   auto input = LiteralUtil::CreateR2<complex64>({{{118.8, 0.0},
1854                                                   {-4.4, 10.622540},
1855                                                   {-4.4, 4.4},
1856                                                   {-4.4, 1.822540},
1857                                                   {-4.4, 0.0}},
1858                                                  {{0.0, 0.0},
1859                                                   {-19.926162, 0.797280},
1860                                                   {-10.128203, -3.728203},
1861                                                   {-6.069756, -5.602720},
1862                                                   {-3.2, -6.928203}},
1863                                                  {{0.0, 0.0},
1864                                                   {13.526162, 14.653687},
1865                                                   {3.728203, 10.128203},
1866                                                   {-0.330244, 8.253687},
1867                                                   {-3.2, 6.928203}}});
1868   auto expected =
1869       LiteralUtil::CreateR2<float>({{1.8, 2.7, 3.6, 4.5, 5.4, 6.3, 7.2, 8.1},
1870                                     {8.1, 7.2, 6.3, 5.4, 4.5, 3.6, 2.7, 1.8},
1871                                     {1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8}});
1872   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
1873   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
1874   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
1875   EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
1876 }
1877 
1878 TEST_F(HloEvaluatorTest, 2D_RFFT_3x9_on_f32x3x9) {
1879   const char* hlo_text = R"(
1880 HloModule Fft
1881 
1882 ENTRY main {
1883   operand = f32[3, 9] parameter(0)
1884   ROOT rfft = c64[3, 5] fft(operand), fft_type=RFFT, fft_length={3, 9}
1885 }
1886 )";
1887   auto input = LiteralUtil::CreateR2<float>(
1888       {{1.9, 2.8, 3.7, 4.6, 5.5, 6.4, 7.3, 8.2, 9.1},
1889        {9.1, 8.2, 7.3, 6.4, 5.5, 4.6, 3.7, 2.8, 1.9},
1890        {1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9}});
1891   auto expected = LiteralUtil::CreateR2<complex64>({{{148.5, 0.0},
1892                                                      {-4.95, 13.600013},
1893                                                      {-4.95, 5.899180},
1894                                                      {-4.95, 2.857884},
1895                                                      {-4.95, 0.872819}},
1896                                                     {{0.0, 0.0},
1897                                                      {-25.014467, 2.096690},
1898                                                      {-12.888800, -3.503916},
1899                                                      {-8.1, -5.715768},
1900                                                      {-4.974333, -7.159452}},
1901                                                     {{0.0, 0.0},
1902                                                      {17.814467, 17.685147},
1903                                                      {5.688800, 12.084542},
1904                                                      {0.9, 9.872690},
1905                                                      {-2.225667, 8.429006}}});
1906   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
1907   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
1908   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
1909   EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
1910 }
1911 
1912 TEST_F(HloEvaluatorTest, 2D_IRFFT_3x9_on_c64x3x5) {
1913   const char* hlo_text = R"(
1914 HloModule Fft
1915 
1916 ENTRY main {
1917   operand = c64[3, 5] parameter(0)
1918   ROOT irfft = f32[3, 9] fft(operand), fft_type=IRFFT, fft_length={3, 9}
1919 }
1920 )";
1921   auto input = LiteralUtil::CreateR2<complex64>({{{148.5, 0.0},
1922                                                   {-4.95, 13.600013},
1923                                                   {-4.95, 5.899180},
1924                                                   {-4.95, 2.857884},
1925                                                   {-4.95, 0.872819}},
1926                                                  {{0.0, 0.0},
1927                                                   {-25.014467, 2.096690},
1928                                                   {-12.888800, -3.503916},
1929                                                   {-8.1, -5.715768},
1930                                                   {-4.974333, -7.159452}},
1931                                                  {{0.0, 0.0},
1932                                                   {17.814467, 17.685147},
1933                                                   {5.688800, 12.084542},
1934                                                   {0.9, 9.872690},
1935                                                   {-2.225667, 8.429006}}});
1936   auto expected = LiteralUtil::CreateR2<float>(
1937       {{1.9, 2.8, 3.7, 4.6, 5.5, 6.4, 7.3, 8.2, 9.1},
1938        {9.1, 8.2, 7.3, 6.4, 5.5, 4.6, 3.7, 2.8, 1.9},
1939        {1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9}});
1940   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
1941   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
1942   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
1943   EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
1944 }
1945 
1946 // 3D FFT tests:
1947 
1948 TEST_F(HloEvaluatorTest, 3D_FFT_2x4x8_on_c64x2x4x8) {
1949   const char* hlo_text = R"(
1950 HloModule Fft
1951 
1952 ENTRY main {
1953   operand = c64[2, 4, 8] parameter(0)
1954   ROOT fft = c64[2, 4, 8] fft(operand), fft_type=FFT, fft_length={2, 4, 8}
1955 }
1956 )";
1957   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
1958   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&fft_c64x2x4x8_}));
1959   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), fft_c64x2x4x8_3d_.shape()));
1960   EXPECT_TRUE(LiteralTestUtil::Near(fft_c64x2x4x8_3d_, result, fft_error_));
1961 }
1962 
1963 TEST_F(HloEvaluatorTest, 3D_IFFT_2x4x8_on_c64x2x4x8) {
1964   const char* hlo_text = R"(
1965 HloModule Fft
1966 
1967 ENTRY main {
1968   operand = c64[2, 4, 8] parameter(0)
1969   ROOT ifft = c64[2, 4, 8] fft(operand), fft_type=IFFT, fft_length={2, 4, 8}
1970 }
1971 )";
1972   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
1973   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&fft_c64x2x4x8_3d_}));
1974   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), fft_c64x2x4x8_.shape()));
1975   EXPECT_TRUE(LiteralTestUtil::Near(fft_c64x2x4x8_, result, fft_error_));
1976 }
1977 
1978 TEST_F(HloEvaluatorTest, 3D_RFFT_3x3x4_on_f32x3x3x4) {
1979   const char* hlo_text = R"(
1980 HloModule Fft
1981 
1982 ENTRY main {
1983   operand = f32[3, 3, 4] parameter(0)
1984   ROOT rfft = c64[3, 3, 3] fft(operand), fft_type=RFFT, fft_length={3, 3, 4}
1985 }
1986 )";
1987   auto input = LiteralUtil::CreateR3<float>(
1988       {{{1.8, 2.7, 3.6, 4.5}, {8.1, 7.2, 6.3, 5.4}, {1.1, 2.2, 3.3, 4.4}},
1989        {{5.4, 6.3, 7.2, 8.1}, {4.5, 3.6, 2.7, 1.8}, {5.5, 6.6, 7.7, 8.8}},
1990        {{-1.8, -2.7, -3.6, -4.5},
1991         {-5.4, -6.3, -7.2, -8.1},
1992         {1.9, 2.9, 3.9, 4.9}}});
1993   auto expected = LiteralUtil::CreateR3<complex64>(
1994       {{{{92.8, 0.0}, {-2.8, 2.8}, {-2.8, 0.0}},
1995         {{-5.9, 35.160631}, {-11.519100, -8.919100}, {-1.3, -10.219100}},
1996         {{-5.9, -35.160631}, {8.919100, 11.519100}, {-1.3, 10.219100}}},
1997        {{{29.5, -81.579593}, {1.390897, 5.190897}, {-1.9, 3.290897}},
1998         {{-25.1, -49.017038}, {1.044486, 4.844486}, {-1.9, 2.944486}},
1999         {{11.8, 27.712813}, {1.517691, 4.717691}, {-1.6, 3.117691}}},
2000        {{{29.5, 81.579593}, {-5.190897, -1.390897}, {-1.9, -3.290897}},
2001         {{11.8, -27.712813}, {-4.717691, -1.517691}, {-1.6, -3.117691}},
2002         {{-25.1, 49.017038}, {-4.844486, -1.044486}, {-1.9, -2.944486}}}});
2003   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2004   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2005   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
2006   EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
2007 }
2008 
2009 TEST_F(HloEvaluatorTest, 3D_IRFFT_3x3x4_on_c64x3x3x3) {
2010   const char* hlo_text = R"(
2011 HloModule Fft
2012 
2013 ENTRY main {
2014   operand = c64[3, 3, 3] parameter(0)
2015   ROOT irfft = f32[3, 3, 4] fft(operand), fft_type=IRFFT, fft_length={3, 3, 4}
2016 }
2017 )";
2018   auto input = LiteralUtil::CreateR3<complex64>(
2019       {{{{92.8, 0.0}, {-2.8, 2.8}, {-2.8, 0.0}},
2020         {{-5.9, 35.160631}, {-11.519100, -8.919100}, {-1.3, -10.219100}},
2021         {{-5.9, -35.160631}, {8.919100, 11.519100}, {-1.3, 10.219100}}},
2022        {{{29.5, -81.579593}, {1.390897, 5.190897}, {-1.9, 3.290897}},
2023         {{-25.1, -49.017038}, {1.044486, 4.844486}, {-1.9, 2.944486}},
2024         {{11.8, 27.712813}, {1.517691, 4.717691}, {-1.6, 3.117691}}},
2025        {{{29.5, 81.579593}, {-5.190897, -1.390897}, {-1.9, -3.290897}},
2026         {{11.8, -27.712813}, {-4.717691, -1.517691}, {-1.6, -3.117691}},
2027         {{-25.1, 49.017038}, {-4.844486, -1.044486}, {-1.9, -2.944486}}}});
2028   auto expected = LiteralUtil::CreateR3<float>(
2029       {{{1.8, 2.7, 3.6, 4.5}, {8.1, 7.2, 6.3, 5.4}, {1.1, 2.2, 3.3, 4.4}},
2030        {{5.4, 6.3, 7.2, 8.1}, {4.5, 3.6, 2.7, 1.8}, {5.5, 6.6, 7.7, 8.8}},
2031        {{-1.8, -2.7, -3.6, -4.5},
2032         {-5.4, -6.3, -7.2, -8.1},
2033         {1.9, 2.9, 3.9, 4.9}}});
2034   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2035   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2036   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
2037   EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
2038 }
2039 
2040 TEST_F(HloEvaluatorTest, 3D_RFFT_3x3x5_on_f32x3x3x5) {
2041   const char* hlo_text = R"(
2042 HloModule Fft
2043 
2044 ENTRY main {
2045   operand = f32[3, 3, 5] parameter(0)
2046   ROOT rfft = c64[3, 3, 3] fft(operand), fft_type=RFFT, fft_length={3, 3, 5}
2047 }
2048 )";
2049   auto input = LiteralUtil::CreateR3<float>({{{1.8, 2.7, 3.6, 4.5, 5.4},
2050                                               {8.1, 7.2, 6.3, 5.4, 4.5},
2051                                               {1.1, 2.2, 3.3, 4.4, 5.5}},
2052                                              {{5.4, 6.3, 7.2, 8.1, 9.0},
2053                                               {4.5, 3.6, 2.7, 1.8, 0.9},
2054                                               {5.5, 6.6, 7.7, 8.8, 9.9}},
2055                                              {{-1.8, -2.7, -3.6, -4.5, -5.4},
2056                                               {-5.4, -6.3, -7.2, -8.1, -9.0},
2057                                               {1.9, 2.9, 3.9, 4.9, 5.9}}});
2058   auto expected = LiteralUtil::CreateR3<complex64>(
2059       {{{{119.5, 0.0}, {-3.5, 4.817337}, {-3.5, 1.137219}},
2060         {{-5.75, 56.724664}, {-19.206730, -10.537254}, {-5.775483, -12.245880}},
2061         {{-5.75, -56.724664}, {15.956730, 15.010495}, {2.525483, 13.301869}}},
2062        {{{39.25, -106.088112}, {3.286913, 7.382528}, {-1.038404, 4.885305}},
2063         {{-29.0, -64.951905}, {2.690922, 6.949515}, {-1.179098, 4.452292}},
2064         {{16.75, 30.743902}, {3.363918, 6.649878}, {-0.733751, 4.546954}}},
2065        {{{39.25, 106.088112}, {-8.036913, -0.844714}, {-3.711596, -3.341936}},
2066         {{16.75, -30.743902}, {-7.363918, -1.144350}, {-3.266249, -3.247275}},
2067         {{-29.0, 64.951905}, {-7.440922, -0.411701}, {-3.570902, -2.908924}}}});
2068   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2069   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2070   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
2071   EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
2072 }
2073 
2074 TEST_F(HloEvaluatorTest, 3D_IRFFT_3x3x5_on_c64x3x3x3) {
2075   const char* hlo_text = R"(
2076 HloModule Fft
2077 
2078 ENTRY main {
2079   operand = c64[3, 3, 3] parameter(0)
2080   ROOT irfft = f32[3, 3, 5] fft(operand), fft_type=IRFFT, fft_length={3, 3, 5}
2081 }
2082 )";
2083   auto input = LiteralUtil::CreateR3<complex64>(
2084       {{{{119.5, 0.0}, {-3.5, 4.817337}, {-3.5, 1.137219}},
2085         {{-5.75, 56.724664}, {-19.206730, -10.537254}, {-5.775483, -12.245880}},
2086         {{-5.75, -56.724664}, {15.956730, 15.010495}, {2.525483, 13.301869}}},
2087        {{{39.25, -106.088112}, {3.286913, 7.382528}, {-1.038404, 4.885305}},
2088         {{-29.0, -64.951905}, {2.690922, 6.949515}, {-1.179098, 4.452292}},
2089         {{16.75, 30.743902}, {3.363918, 6.649878}, {-0.733751, 4.546954}}},
2090        {{{39.25, 106.088112}, {-8.036913, -0.844714}, {-3.711596, -3.341936}},
2091         {{16.75, -30.743902}, {-7.363918, -1.144350}, {-3.266249, -3.247275}},
2092         {{-29.0, 64.951905}, {-7.440922, -0.411701}, {-3.570902, -2.908924}}}});
2093   auto expected = LiteralUtil::CreateR3<float>({{{1.8, 2.7, 3.6, 4.5, 5.4},
2094                                                  {8.1, 7.2, 6.3, 5.4, 4.5},
2095                                                  {1.1, 2.2, 3.3, 4.4, 5.5}},
2096                                                 {{5.4, 6.3, 7.2, 8.1, 9.0},
2097                                                  {4.5, 3.6, 2.7, 1.8, 0.9},
2098                                                  {5.5, 6.6, 7.7, 8.8, 9.9}},
2099                                                 {{-1.8, -2.7, -3.6, -4.5, -5.4},
2100                                                  {-5.4, -6.3, -7.2, -8.1, -9.0},
2101                                                  {1.9, 2.9, 3.9, 4.9, 5.9}}});
2102   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2103   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2104   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
2105   EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
2106 }
2107 
2108 // FFT tests with non-default data layout:
2109 
2110 TEST_F(HloEvaluatorTest, 1D_FFT_8_on_c64x2x4x8_with_layout) {
2111   const char* hlo_text = R"(
2112 HloModule Fft
2113 
2114 ENTRY main {
2115   operand = c64[2, 4, 8]{0, 2, 1} parameter(0)
2116   ROOT fft = c64[2, 4, 8]{1, 2, 0} fft(operand), fft_type=FFT, fft_length={8}
2117 }
2118 )";
2119   auto input = fft_c64x2x4x8_.Relayout(LayoutUtil::MakeLayout({0, 2, 1}));
2120   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2121   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2122   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), fft_c64x2x4x8_1d_.shape()));
2123   EXPECT_TRUE(LiteralTestUtil::Near(fft_c64x2x4x8_1d_, result, fft_error_));
2124 }
2125 
2126 TEST_F(HloEvaluatorTest, 2D_FFT_4x8_on_c64x2x4x8_with_layout) {
2127   const char* hlo_text = R"(
2128 HloModule Fft
2129 
2130 ENTRY main {
2131   operand = c64[2, 4, 8]{2, 0, 1} parameter(0)
2132   ROOT fft = c64[2, 4, 8]{1, 0, 2} fft(operand), fft_type=FFT, fft_length={4, 8}
2133 }
2134 )";
2135   auto input = fft_c64x2x4x8_.Relayout(LayoutUtil::MakeLayout({2, 0, 1}));
2136   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2137   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2138   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), fft_c64x2x4x8_2d_.shape()));
2139   EXPECT_TRUE(LiteralTestUtil::Near(fft_c64x2x4x8_2d_, result, fft_error_));
2140 }
2141 
2142 TEST_F(HloEvaluatorTest, 3D_FFT_2x4x8_on_c64x2x4x8_with_layout) {
2143   const char* hlo_text = R"(
2144 HloModule Fft
2145 
2146 ENTRY main {
2147   operand = c64[2, 4, 8]{1, 2, 0} parameter(0)
2148   ROOT fft =
2149     c64[2, 4, 8]{0, 2, 1} fft(operand), fft_type=FFT, fft_length={2, 4, 8}
2150 }
2151 )";
2152   auto input = fft_c64x2x4x8_.Relayout(LayoutUtil::MakeLayout({1, 2, 0}));
2153   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2154   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2155   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), fft_c64x2x4x8_3d_.shape()));
2156   EXPECT_TRUE(LiteralTestUtil::Near(fft_c64x2x4x8_3d_, result, fft_error_));
2157 }
2158 
2159 // FFT tests with unusual parameters:
2160 
2161 // Zero-length transform.
2162 TEST_F(HloEvaluatorTest, 1D_FFT_0_on_c64x1x1x1x1) {
2163   const char* hlo_text = R"(
2164 HloModule Fft
2165 
2166 ENTRY main {
2167   operand = c64[1, 1, 1, 1] parameter(0)
2168   ROOT fft = c64[1, 1, 1, 1] fft(operand), fft_type=FFT, fft_length={0}
2169 }
2170 )";
2171   auto input = LiteralUtil::CreateR4<complex64>({{{{{42.24, 24.42}}}}});
2172   auto expected = LiteralUtil::CreateR4<complex64>({{{{{0.0, 0.0}}}}});
2173   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2174   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2175   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
2176   EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
2177 }
2178 
2179 // Zero-length axis.
2180 TEST_F(HloEvaluatorTest, 1D_FFT_1_on_c64x1x1x1x0) {
2181   const char* hlo_text = R"(
2182 HloModule Fft
2183 
2184 ENTRY main {
2185   operand = c64[1, 1, 1, 0] parameter(0)
2186   ROOT fft = c64[1, 1, 1, 0] fft(operand), fft_type=FFT, fft_length={1}
2187 }
2188 )";
2189   TF_ASSERT_OK_AND_ASSIGN(
2190       auto input,
2191       LiteralUtil::CreateR4<complex64>({{{{}}}}).Reshape({1, 1, 1, 0}));
2192   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2193   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2194   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), input.shape()));
2195   EXPECT_TRUE(LiteralTestUtil::Near(input, result, fft_error_));
2196 }
2197 
2198 // Some/all dimensions have length 1.
2199 TEST_F(HloEvaluatorTest, 1D_FFT_1_on_c64x1x1x1x1) {
2200   const char* hlo_text = R"(
2201 HloModule Fft
2202 
2203 ENTRY main {
2204   operand = c64[1, 1, 1, 1] parameter(0)
2205   ROOT fft = c64[1, 1, 1, 1] fft(operand), fft_type=FFT, fft_length={1}
2206 }
2207 )";
2208   auto input = LiteralUtil::CreateR4<complex64>({{{{{42.24, 24.42}}}}});
2209   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2210   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2211   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), input.shape()));
2212   EXPECT_TRUE(LiteralTestUtil::Near(input, result, fft_error_));
2213 }
2214 
2215 // Zero-length transform.
2216 TEST_F(HloEvaluatorTest, 3D_FFT_1x0x1_on_c64x1x1x1x1) {
2217   const char* hlo_text = R"(
2218 HloModule Fft
2219 
2220 ENTRY main {
2221   operand = c64[1, 1, 1, 1] parameter(0)
2222   ROOT fft = c64[1, 1, 1, 1] fft(operand), fft_type=FFT, fft_length={1, 0, 1}
2223 }
2224 )";
2225   auto input = LiteralUtil::CreateR4<complex64>({{{{{42.24, 24.42}}}}});
2226   auto expected = LiteralUtil::CreateR4<complex64>({{{{{0.0, 0.0}}}}});
2227   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2228   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2229   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
2230   EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
2231 }
2232 
2233 // Zero-length axis.
2234 TEST_F(HloEvaluatorTest, 3D_FFT_1x1x1_on_c64x0x1x0x1) {
2235   const char* hlo_text = R"(
2236 HloModule Fft
2237 
2238 ENTRY main {
2239   operand = c64[0, 1, 0, 1] parameter(0)
2240   ROOT fft = c64[0, 1, 0, 1] fft(operand), fft_type=FFT, fft_length={1, 1, 1}
2241 }
2242 )";
2243   TF_ASSERT_OK_AND_ASSIGN(
2244       auto input,
2245       LiteralUtil::CreateR4<complex64>({{{{}}}}).Reshape({0, 1, 0, 1}));
2246   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2247   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2248   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), input.shape()));
2249   EXPECT_TRUE(LiteralTestUtil::Near(input, result, fft_error_));
2250 }
2251 
2252 // Some/all dimensions have length 1.
2253 TEST_F(HloEvaluatorTest, 3D_FFT_1x1x1_on_c64x1x1x1x1) {
2254   const char* hlo_text = R"(
2255 HloModule Fft
2256 
2257 ENTRY main {
2258   operand = c64[1, 1, 1, 1] parameter(0)
2259   ROOT fft = c64[1, 1, 1, 1] fft(operand), fft_type=FFT, fft_length={1, 1, 1}
2260 }
2261 )";
2262   auto input = LiteralUtil::CreateR4<complex64>({{{{{42.24, 24.42}}}}});
2263   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2264   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2265   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), input.shape()));
2266   EXPECT_TRUE(LiteralTestUtil::Near(input, result, fft_error_));
2267 }
2268 
2269 // Some/all dimensions have length 1.
2270 TEST_F(HloEvaluatorTest, 3D_FFT_3x1x1_on_c64x1x3x1x1) {
2271   const char* hlo_text = R"(
2272 HloModule Fft
2273 
2274 ENTRY main {
2275   operand = c64[1, 3, 1, 1] parameter(0)
2276   ROOT fft = c64[1, 3, 1, 1] fft(operand), fft_type=FFT, fft_length={3, 1, 1}
2277 }
2278 )";
2279   auto input = LiteralUtil::CreateR4<complex64>(
2280       {{{{{42.24, 24.42}}}, {{{-42.24, 24.42}}}, {{{42.24, -24.42}}}}});
2281   auto expected =
2282       LiteralUtil::CreateR4<complex64>({{{{{42.24, 24.42}}},
2283                                          {{{84.5367, 97.5818}}},
2284                                          {{{-0.0566792, -48.7418}}}}});
2285   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2286   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2287   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
2288   EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
2289 }
2290 
2291 // Some/all dimensions have length 1.
2292 TEST_F(HloEvaluatorTest, 3D_IFFT_3x1x1_on_c64x1x3x1x1) {
2293   const char* hlo_text = R"(
2294 HloModule Fft
2295 
2296 ENTRY main {
2297   operand = c64[1, 3, 1, 1] parameter(0)
2298   ROOT ifft = c64[1, 3, 1, 1] fft(operand), fft_type=IFFT, fft_length={3, 1, 1}
2299 }
2300 )";
2301   auto input = LiteralUtil::CreateR4<complex64>({{{{{42.24, 24.42}}},
2302                                                   {{{84.5367, 97.5818}}},
2303                                                   {{{-0.0566792, -48.7418}}}}});
2304   auto expected = LiteralUtil::CreateR4<complex64>(
2305       {{{{{42.24, 24.42}}}, {{{-42.24, 24.42}}}, {{{42.24, -24.42}}}}});
2306   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2307   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2308   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
2309   EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
2310 }
2311 
2312 // Odd transform length.
2313 TEST_F(HloEvaluatorTest, 1D_FFT_5_on_c64x5) {
2314   const char* hlo_text = R"(
2315 HloModule Fft
2316 
2317 ENTRY main {
2318   operand = c64[5] parameter(0)
2319   ROOT fft = c64[5] fft(operand), fft_type=FFT, fft_length={5}
2320 }
2321 )";
2322   auto input = LiteralUtil::CreateR1<complex64>(
2323       {{1.0, 5.0}, {2.0, 4.0}, {3.0, 3.0}, {4.0, 2.0}, {5.0, 1.0}});
2324   auto expected = LiteralUtil::CreateR1<complex64>({{15.0, 15.0},
2325                                                     {0.940955, 5.94095},
2326                                                     {-1.6877, 3.3123},
2327                                                     {-3.3123, 1.6877},
2328                                                     {-5.94095, -0.940955}});
2329   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2330   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2331   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
2332   EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
2333 }
2334 
2335 // Odd transform length.
2336 TEST_F(HloEvaluatorTest, 1D_IFFT_5_on_c64x5) {
2337   const char* hlo_text = R"(
2338 HloModule Fft
2339 
2340 ENTRY main {
2341   operand = c64[5] parameter(0)
2342   ROOT ifft = c64[5] fft(operand), fft_type=IFFT, fft_length={5}
2343 }
2344 )";
2345   auto input = LiteralUtil::CreateR1<complex64>({{15.0, 15.0},
2346                                                  {0.940955, 5.94095},
2347                                                  {-1.6877, 3.3123},
2348                                                  {-3.3123, 1.6877},
2349                                                  {-5.94095, -0.940955}});
2350   auto expected = LiteralUtil::CreateR1<complex64>(
2351       {{1.0, 5.0}, {2.0, 4.0}, {3.0, 3.0}, {4.0, 2.0}, {5.0, 1.0}});
2352   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2353   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2354   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
2355   EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
2356 }
2357 
2358 // All input values are zero.
2359 TEST_F(HloEvaluatorTest, 1D_FFT_4_on_zero_c64x4) {
2360   const char* hlo_text = R"(
2361 HloModule Fft
2362 
2363 ENTRY main {
2364   operand = c64[4] parameter(0)
2365   ROOT fft = c64[4] fft(operand), fft_type=FFT, fft_length={4}
2366 }
2367 )";
2368   auto input = LiteralUtil::CreateR1<complex64>(
2369       {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}});
2370   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2371   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2372   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), input.shape()));
2373   EXPECT_TRUE(LiteralTestUtil::Near(input, result, fft_error_));
2374 }
2375 
2376 // All input values are zero.
2377 TEST_F(HloEvaluatorTest, 3D_FFT_3x3x4_on_zero_c64x3x3x4) {
2378   const char* hlo_text = R"(
2379 HloModule Fft
2380 
2381 ENTRY main {
2382   operand = c64[3, 3, 4] parameter(0)
2383   ROOT fft = c64[3, 3, 4] fft(operand), fft_type=FFT, fft_length={3, 3, 4}
2384 }
2385 )";
2386   auto input = LiteralUtil::CreateR3<complex64>(
2387       {{{{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2388         {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2389         {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}},
2390        {{{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2391         {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2392         {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}},
2393        {{{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2394         {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2395         {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}}});
2396   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2397   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2398   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), input.shape()));
2399   EXPECT_TRUE(LiteralTestUtil::Near(input, result, fft_error_));
2400 }
2401 
2402 // All input values are zero.
2403 TEST_F(HloEvaluatorTest, 3D_IFFT_3x3x4_on_zero_c64x3x3x4) {
2404   const char* hlo_text = R"(
2405 HloModule Fft
2406 
2407 ENTRY main {
2408   operand = c64[3, 3, 4] parameter(0)
2409   ROOT ifft = c64[3, 3, 4] fft(operand), fft_type=IFFT, fft_length={3, 3, 4}
2410 }
2411 )";
2412   auto input = LiteralUtil::CreateR3<complex64>(
2413       {{{{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2414         {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2415         {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}},
2416        {{{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2417         {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2418         {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}},
2419        {{{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2420         {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2421         {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}}});
2422   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2423   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2424   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), input.shape()));
2425   EXPECT_TRUE(LiteralTestUtil::Near(input, result, fft_error_));
2426 }
2427 
2428 // All input values are zero.
2429 TEST_F(HloEvaluatorTest, 3D_RFFT_3x3x4_on_zero_f32x3x3x4) {
2430   const char* hlo_text = R"(
2431 HloModule Fft
2432 
2433 ENTRY main {
2434   operand = f32[3, 3, 4] parameter(0)
2435   ROOT rfft = c64[3, 3, 3] fft(operand), fft_type=RFFT, fft_length={3, 3, 4}
2436 }
2437 )";
2438   auto input = LiteralUtil::CreateR3<float>(
2439       {{{0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}},
2440        {{0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}},
2441        {{0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}}});
2442   auto expected = LiteralUtil::CreateR3<complex64>(
2443       {{{{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2444         {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2445         {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}},
2446        {{{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2447         {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2448         {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}},
2449        {{{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2450         {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2451         {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}}});
2452   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2453   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2454   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
2455   EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
2456 }
2457 
2458 // All input values are zero.
2459 TEST_F(HloEvaluatorTest, 3D_IRFFT_3x3x4_on_zero_c64x3x3x3) {
2460   const char* hlo_text = R"(
2461 HloModule Fft
2462 
2463 ENTRY main {
2464   operand = c64[3, 3, 3] parameter(0)
2465   ROOT irfft = f32[3, 3, 4] fft(operand), fft_type=IRFFT, fft_length={3, 3, 4}
2466 }
2467 )";
2468   auto input = LiteralUtil::CreateR3<complex64>(
2469       {{{{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2470         {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2471         {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}},
2472        {{{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2473         {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2474         {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}},
2475        {{{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2476         {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2477         {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}}});
2478   auto expected = LiteralUtil::CreateR3<float>(
2479       {{{0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}},
2480        {{0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}},
2481        {{0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}}});
2482   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2483   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2484   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
2485   EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
2486 }
2487 
2488 // Input values, for which IRFFT discards non-zero imaginary parts.
2489 TEST_F(HloEvaluatorTest, 2D_IRFFT_3x4_on_c64x3x3) {
2490   const char* hlo_text = R"(
2491 HloModule Fft
2492 
2493 ENTRY main {
2494   operand = c64[3, 3] parameter(0)
2495   ROOT irfft = f32[3, 4] fft(operand), fft_type=IRFFT, fft_length={3, 4}
2496 }
2497 )";
2498   auto input =
2499       LiteralUtil::CreateR2<complex64>({{{0.0, 0.0}, {1.0, 0.0}, {2.0, 0.0}},
2500                                         {{3.0, 0.0}, {4.0, 0.0}, {5.0, 0.0}},
2501                                         {{6.0, 0.0}, {7.0, 0.0}, {8.0, 0.0}}});
2502   auto expected =
2503       LiteralUtil::CreateR2<float>({{4.0, -0.5, 0.0, -0.5},
2504                                     {-1.5, 0.433013, 0.0, -0.433013},
2505                                     {-1.5, -0.433013, 0.0, 0.433013}});
2506   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2507   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2508   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
2509   EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
2510 }
2511 
// Fixture for the reduction precision test; it declares no members of its own
// on top of HloTestBase.
class HloEvaluatorPreciseReduceTest : public HloTestBase {};
2513 
2514 // Tests that Reduce doesn't lose precision when adding many numbers (because
2515 // it accumulates its result in a double).
TEST_F(HloEvaluatorPreciseReduceTest,AddReductionPrecisionTest)2516 TEST_F(HloEvaluatorPreciseReduceTest, AddReductionPrecisionTest) {
2517   auto m = CreateNewVerifiedModule();
2518   HloComputation::Builder b(TestName());
2519 
2520   constexpr int kNumElements = 1 << 25;  // float += 1 saturates at 1<<24
2521   std::vector<float> v(kNumElements, 1.0f);
2522   HloInstruction* arg_instruction = b.AddInstruction(
2523       HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(v)));
2524   HloInstruction* init_value = b.AddInstruction(
2525       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.f)));
2526 
2527   HloComputation::Builder add_computation("add");
2528   Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
2529   auto param_lhs = add_computation.AddInstruction(
2530       HloInstruction::CreateParameter(0, scalar_shape, "lhs"));
2531   auto param_rhs = add_computation.AddInstruction(
2532       HloInstruction::CreateParameter(1, scalar_shape, "rhs"));
2533   add_computation.AddInstruction(HloInstruction::CreateBinary(
2534       scalar_shape, HloOpcode::kAdd, param_lhs, param_rhs));
2535   auto add_func = m->AddEmbeddedComputation(add_computation.Build());
2536 
2537   HloInstruction* reduce_instruction = b.AddInstruction(
2538       HloInstruction::CreateReduce(scalar_shape, arg_instruction, init_value,
2539                                    /*dimensions_to_reduce=*/{0}, add_func));
2540   m->AddEntryComputation(b.Build());
2541 
2542   HloEvaluator hlo_eval;
2543   Literal result = hlo_eval.Evaluate(reduce_instruction).ConsumeValueOrDie();
2544   LiteralTestUtil::ExpectR0Equal<float>(kNumElements, result);
2545 }
2546 
2547 // Reducing many numbers should be fast because it doesn't create
2548 // intermediate Literals; the microbenchmark should finish in < 1 msec.
BM_ReducePrecisely(::testing::benchmark::State & state)2549 void BM_ReducePrecisely(::testing::benchmark::State& state) {
2550   HloComputation::Builder b("BM_ReducePrecisely");
2551   HloModuleConfig config;
2552   config.set_debug_options(GetDebugOptionsFromFlags());
2553   HloModule module("BM_ReducePrecisely", config);
2554 
2555   constexpr int kNumElements = 1 << 25;  // float += 1 saturates at 1<<24
2556   std::vector<float> v(kNumElements, 1.0f);
2557   HloInstruction* arg_instruction = b.AddInstruction(
2558       HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(v)));
2559   auto init_value = b.AddInstruction(
2560       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.f)));
2561 
2562   HloComputation::Builder add_computation("add");
2563   Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
2564   auto param_lhs = add_computation.AddInstruction(
2565       HloInstruction::CreateParameter(0, scalar_shape, "lhs"));
2566   auto param_rhs = add_computation.AddInstruction(
2567       HloInstruction::CreateParameter(1, scalar_shape, "rhs"));
2568   add_computation.AddInstruction(HloInstruction::CreateBinary(
2569       scalar_shape, HloOpcode::kAdd, param_lhs, param_rhs));
2570   auto add_func = module.AddEmbeddedComputation(add_computation.Build());
2571 
2572   HloInstruction* reduce_instruction = b.AddInstruction(
2573       HloInstruction::CreateReduce(scalar_shape, arg_instruction, init_value,
2574                                    /*dimensions_to_reduce=*/{0}, add_func));
2575   module.AddEntryComputation(b.Build());
2576 
2577   // Benchmark loop
2578   for (auto s : state) {
2579     HloEvaluator hlo_eval;
2580     hlo_eval.Evaluate(reduce_instruction).ConsumeValueOrDie();
2581   }
2582 }
2583 
2584 BENCHMARK(BM_ReducePrecisely);
2585 
// Reduces a 2x3 f32 constant over dimension 1 with a scalar add computation
// and expects the per-row sums {6, 18}.
TEST_P(HloEvaluatorBf16Test, ReduceAdd) {
  HloComputation::Builder builder(TestName());

  // Operand values:
  // f32[2,3] {
  //  { 1, 2, 3 },
  //  { 5, 6, 7 },
  // }
  Array2D<float> operand_values(2, 3);
  operand_values.FillUnique(1.0f);
  HloInstruction* operand =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2FromArray2D<float>(operand_values)));

  HloInstruction* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.f)));

  // Scalar (lhs, rhs) -> lhs + rhs computation used as the reducer.
  HloComputation::Builder adder("add");
  const Shape scalar_f32 = ShapeUtil::MakeShape(F32, {});
  HloInstruction* lhs = adder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_f32, "lhs"));
  HloInstruction* rhs = adder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_f32, "rhs"));
  adder.AddInstruction(
      HloInstruction::CreateBinary(scalar_f32, HloOpcode::kAdd, lhs, rhs));
  HloComputation* add = m_->AddEmbeddedComputation(adder.Build());

  const Shape row_sums_shape = ShapeUtil::MakeShape(F32, {2});
  builder.AddInstruction(
      HloInstruction::CreateReduce(row_sums_shape, operand, zero,
                                   /*dimensions_to_reduce=*/{1}, add));

  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal got, Evaluate());

  Literal want = LiteralUtil::CreateR1<float>({6, 18});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, got));
}
2627 
// Slides a 2x2 max window (stride 1, no padding, no dilation) over a 2x3 f32
// constant and expects the 1x2 result {{6, 7}}.
TEST_P(HloEvaluatorBf16Test, ReduceWindowMax) {
  HloComputation::Builder builder(TestName());

  // Operand values:
  // f32[2,3] {
  //  { 1, 2, 3 },
  //  { 5, 6, 7 },
  // }
  Array2D<float> operand_values(2, 3);
  operand_values.FillUnique(1.0f);
  HloInstruction* operand =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2FromArray2D<float>(operand_values)));

  HloInstruction* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.f)));
  HloComputation* max = m_->AddEmbeddedComputation(MaxComputationScalarF32());

  // Both window dimensions use the same settings.
  Window window;
  for (int axis = 0; axis < 2; ++axis) {
    WindowDimension* dim = window.add_dimensions();
    dim->set_size(2);
    dim->set_stride(1);
    dim->set_padding_low(0);
    dim->set_padding_high(0);
    dim->set_window_dilation(1);
    dim->set_base_dilation(1);
  }

  const Shape output_shape = ShapeUtil::MakeShape(F32, {1, 2});
  builder.AddInstruction(HloInstruction::CreateReduceWindow(
      output_shape, operand, zero, window, max));

  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal got, Evaluate());

  Literal want = LiteralUtil::CreateR2<float>({{6, 7}});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, got));
}
2669 
// Runs the shared ReduceWindowMaxIotaTest helper with window size 2, stride 1
// and window dilation 2 (no padding, no base dilation).
TEST_P(HloEvaluatorBf16Test, ReduceWindowMaxIotaWindowDilation) {
  Literal want = LiteralUtil::CreateR2<float>({{10, 11}, {14, 15}});
  ReduceWindowMaxIotaTest(/*window_size=*/2, /*padding=*/0, /*stride=*/1,
                          /*window_dilation=*/2, /*base_dilation=*/1,
                          /*expected=*/want);
}
2680 
// Runs the shared ReduceWindowMaxIotaTest helper with window size 2, stride 2
// and window dilation 2 (no padding, no base dilation).
TEST_P(HloEvaluatorBf16Test, ReduceWindowMaxIotaStrideWindowDilation) {
  Literal want = LiteralUtil::CreateR2<float>({{10}});
  ReduceWindowMaxIotaTest(/*window_size=*/2, /*padding=*/0, /*stride=*/2,
                          /*window_dilation=*/2, /*base_dilation=*/1,
                          /*expected=*/want);
}
2691 
// Runs the shared ReduceWindowMaxIotaTest helper with window size 2, stride 1
// and base dilation 2 (no padding, no window dilation).
TEST_P(HloEvaluatorBf16Test, ReduceWindowMaxIotaBaseDilation) {
  Literal want = LiteralUtil::CreateR2<float>({{0, 1, 1, 2, 2, 3},
                                               {4, 5, 5, 6, 6, 7},
                                               {4, 5, 5, 6, 6, 7},
                                               {8, 9, 9, 10, 10, 11},
                                               {8, 9, 9, 10, 10, 11},
                                               {12, 13, 13, 14, 14, 15}});
  ReduceWindowMaxIotaTest(/*window_size=*/2, /*padding=*/0, /*stride=*/1,
                          /*window_dilation=*/1, /*base_dilation=*/2,
                          /*expected=*/want);
}
2707 
// Runs the shared ReduceWindowMaxIotaTest helper with window size 2, stride 2
// and base dilation 2 (no padding, no window dilation).
TEST_P(HloEvaluatorBf16Test, ReduceWindowMaxIotaStrideBaseDilation) {
  Literal want =
      LiteralUtil::CreateR2<float>({{0, 1, 2}, {4, 5, 6}, {8, 9, 10}});
  ReduceWindowMaxIotaTest(/*window_size=*/2, /*padding=*/0, /*stride=*/2,
                          /*window_dilation=*/1, /*base_dilation=*/2,
                          /*expected=*/want);
}
2719 
// Runs the shared ReduceWindowMaxIotaTest helper with window size 2, stride 2
// and both window and base dilation set to 2 (no padding).
TEST_P(HloEvaluatorBf16Test, ReduceWindowMaxIotaStrideBothDilation) {
  Literal want =
      LiteralUtil::CreateR2<float>({{5, 6, 7}, {9, 10, 11}, {13, 14, 15}});
  ReduceWindowMaxIotaTest(/*window_size=*/2, /*padding=*/0, /*stride=*/2,
                          /*window_dilation=*/2, /*base_dilation=*/2,
                          /*expected=*/want);
}
2731 
// Runs the shared ReduceWindowMaxIotaTest helper with window size 3, padding
// 1, stride 3 and base dilation 2 (no window dilation).
TEST_P(HloEvaluatorBf16Test, ReduceWindowMaxIotaPaddingStrideBaseDilation) {
  // The expected values follow from the base being dilated first and the
  // padding being applied afterwards.
  Literal want =
      LiteralUtil::CreateR2<float>({{0, 2, 3}, {8, 10, 11}, {12, 14, 15}});
  ReduceWindowMaxIotaTest(/*window_size=*/3, /*padding=*/1, /*stride=*/3,
                          /*window_dilation=*/1, /*base_dilation=*/2,
                          /*expected=*/want);
}
2744 
// Slides a 1x2 add window over a 2x3 f32 constant, with one element of low
// padding on dimension 1, and expects {{1, 3, 5}, {5, 11, 13}}.
TEST_P(HloEvaluatorBf16Test, ReduceWindowAdd) {
  HloComputation::Builder builder(TestName());

  // Operand values:
  // f32[2,3] {
  //  { 1, 2, 3 },
  //  { 5, 6, 7 },
  // }
  Array2D<float> operand_values(2, 3);
  operand_values.FillUnique(1.0f);
  HloInstruction* operand =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2FromArray2D<float>(operand_values)));

  HloInstruction* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.f)));

  // Scalar (lhs, rhs) -> lhs + rhs computation used as the reducer.
  HloComputation::Builder adder("add");
  const Shape scalar_f32 = ShapeUtil::MakeShape(F32, {});
  HloInstruction* lhs = adder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_f32, "lhs"));
  HloInstruction* rhs = adder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_f32, "rhs"));
  adder.AddInstruction(
      HloInstruction::CreateBinary(scalar_f32, HloOpcode::kAdd, lhs, rhs));
  HloComputation* add = m_->AddEmbeddedComputation(adder.Build());

  // The two window dimensions differ only in size and low padding; stride,
  // high padding and both dilations are the same for each.
  Window window;
  const auto append_dimension = [&window](int size, int padding_low) {
    WindowDimension* dim = window.add_dimensions();
    dim->set_size(size);
    dim->set_stride(1);
    dim->set_padding_low(padding_low);
    dim->set_padding_high(0);
    dim->set_window_dilation(1);
    dim->set_base_dilation(1);
  };
  append_dimension(/*size=*/1, /*padding_low=*/0);
  append_dimension(/*size=*/2, /*padding_low=*/1);

  const Shape output_shape = ShapeUtil::MakeShape(F32, {2, 3});
  builder.AddInstruction(HloInstruction::CreateReduceWindow(
      output_shape, operand, zero, window, add));

  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal got, Evaluate());

  Literal want = LiteralUtil::CreateR2<float>({{1, 3, 5}, {5, 11, 13}});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, got));
}
2801 
TEST_P(HloEvaluatorBf16Test,ReduceWindowAdd6D)2802 TEST_P(HloEvaluatorBf16Test, ReduceWindowAdd6D) {
2803   HloComputation::Builder b(TestName());
2804 
2805   // arg: f32[4,4,4,4,4,4] full of ones. Using small dims to limit run-time.
2806   std::vector<int64> input_dims(6, 4);
2807   Literal arg_literal =
2808       LiteralUtil::CreateFullWithDescendingLayout<float>(input_dims, 1.0f);
2809 
2810   HloInstruction* arg_instruction =
2811       b.AddInstruction(HloInstruction::CreateConstant(std::move(arg_literal)));
2812 
2813   auto init_value = b.AddInstruction(
2814       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.f)));
2815 
2816   HloComputation::Builder add_computation("add");
2817   Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
2818   auto param_lhs = add_computation.AddInstruction(
2819       HloInstruction::CreateParameter(0, scalar_shape, "lhs"));
2820   auto param_rhs = add_computation.AddInstruction(
2821       HloInstruction::CreateParameter(1, scalar_shape, "rhs"));
2822   add_computation.AddInstruction(HloInstruction::CreateBinary(
2823       scalar_shape, HloOpcode::kAdd, param_lhs, param_rhs));
2824   auto add_func = m_->AddEmbeddedComputation(add_computation.Build());
2825 
2826   Window window;
2827 
2828   WindowDimension trivial_dim;
2829   trivial_dim.set_size(1);
2830   trivial_dim.set_stride(1);
2831   trivial_dim.set_padding_low(0);
2832   trivial_dim.set_padding_high(0);
2833   trivial_dim.set_window_dilation(1);
2834   trivial_dim.set_base_dilation(1);
2835 
2836   WindowDimension active_dim;
2837   active_dim.set_size(2);
2838   active_dim.set_stride(1);
2839   active_dim.set_padding_low(0);
2840   active_dim.set_padding_high(0);
2841   active_dim.set_window_dilation(1);
2842   active_dim.set_base_dilation(1);
2843 
2844   *window.add_dimensions() = trivial_dim;
2845   *window.add_dimensions() = active_dim;
2846   *window.add_dimensions() = active_dim;
2847   *window.add_dimensions() = active_dim;
2848   *window.add_dimensions() = trivial_dim;
2849   *window.add_dimensions() = trivial_dim;
2850 
2851   Shape shape = ShapeUtil::MakeShape(F32, {4, 3, 3, 3, 4, 4});
2852   b.AddInstruction(HloInstruction::CreateReduceWindow(
2853       shape, arg_instruction, init_value, window, add_func));
2854 
2855   m_->AddEntryComputation(b.Build());
2856 
2857   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
2858 
2859   std::vector<int64> output_dims = {4, 3, 3, 3, 4, 4};
2860   Literal result_literal =
2861       LiteralUtil::CreateFullWithDescendingLayout<float>(output_dims, 8.0f);
2862   EXPECT_TRUE(LiteralTestUtil::Equal(result_literal, result));
2863 }
2864 
// Variadic (tuple) reduce-window taking the minimum over windows of size 3
// with stride 2 on two identical 5-element f32 inputs. The windows cover
// elements [0,3) and [2,5), so each output is expected to be {100, 1}.
TEST_P(HloEvaluatorBf16Test, Min3In5Stride2Tuple) {
  HloComputation::Builder entry_builder("main");
  auto operand_a = entry_builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({10000, 1000, 100, 10, 1})));
  auto operand_b = entry_builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({10000, 1000, 100, 10, 1})));

  // Reducer: (x0, x1, y0, y1) -> (min(x0, y0), min(x1, y1)).
  HloComputation::Builder reducer_builder("ComputeFunction");
  auto scalar_a = ShapeUtil::MakeShape(F32, {});
  auto scalar_b = ShapeUtil::MakeShape(F32, {});
  auto x0 = reducer_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_a, "x0"));
  auto x1 = reducer_builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_b, "x1"));
  auto y0 = reducer_builder.AddInstruction(
      HloInstruction::CreateParameter(2, scalar_a, "y0"));
  auto y1 = reducer_builder.AddInstruction(
      HloInstruction::CreateParameter(3, scalar_b, "y1"));
  HloInstruction* min_a = reducer_builder.AddInstruction(
      HloInstruction::CreateBinary(scalar_a, HloOpcode::kMinimum, x0, y0));
  HloInstruction* min_b = reducer_builder.AddInstruction(
      HloInstruction::CreateBinary(scalar_b, HloOpcode::kMinimum, x1, y1));
  std::vector<HloInstruction*> reducer_outputs = {min_a, min_b};
  reducer_builder.AddInstruction(HloInstruction::CreateTuple(reducer_outputs));
  auto reducer = m_->AddEmbeddedComputation(reducer_builder.Build());

  std::vector<HloInstruction*> operands = {operand_a, operand_b};

  // Both reductions start from the largest F32 value.
  auto init_a = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::MaxValue(F32)));
  auto init_b = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::MaxValue(F32)));
  std::vector<HloInstruction*> init_values = {init_a, init_b};

  // Window of size 3, stride 2, no padding and no dilation.
  std::pair<int64, int64> no_padding(0, 0);
  TF_ASSERT_OK_AND_ASSIGN(auto window,
                          ShapeInference::InferWindowFromDimensions(
                              {3}, {2}, absl::MakeSpan(&no_padding, 1),
                              /*lhs_dilation=*/{},
                              /*rhs_dilation=*/{}));

  std::vector<const Shape*> operand_shapes = {&operand_a->shape(),
                                              &operand_b->shape()};
  std::vector<const Shape*> init_value_shapes = {&init_a->shape(),
                                                 &init_b->shape()};
  TF_ASSERT_OK_AND_ASSIGN(Shape result_shape,
                          ShapeInference::InferReduceWindowShape(
                              operand_shapes, init_value_shapes, window,
                              reducer->ComputeProgramShape()));

  // Added last so that it is the root of the entry computation.
  entry_builder.AddInstruction(HloInstruction::CreateReduceWindow(
      result_shape, operands, init_values, window, reducer));
  m_->AddEntryComputation(entry_builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate());

  auto expected_element = LiteralUtil::CreateR1<float>({100, 1});
  auto expected =
      LiteralUtil::MakeTuple({&expected_element, &expected_element});
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
2915 
// Variadic (tuple) reduce-window taking the minimum over windows of size 3
// with stride 2 on two 5-element inputs of different element types (f32 and
// s32). The windows cover elements [0,3) and [2,5).
TEST_P(HloEvaluatorBf16Test, Min3In5Stride2TupleDiffInput) {
  HloComputation::Builder entry_builder("main");
  auto float_operand =
      entry_builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR1<float>({10000, 1000, 100, 10, 1})));
  auto int_operand =
      entry_builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR1<int>({15, 28, 300, 107, 12})));

  // Reducer: (x0, x1, y0, y1) -> (min(x0, y0), min(x1, y1)), where the "0"
  // operands are f32 scalars and the "1" operands are s32 scalars.
  HloComputation::Builder reducer_builder("ComputeFunction");
  auto float_scalar = ShapeUtil::MakeShape(F32, {});
  auto int_scalar = ShapeUtil::MakeShape(S32, {});
  auto x0 = reducer_builder.AddInstruction(
      HloInstruction::CreateParameter(0, float_scalar, "x0"));
  auto x1 = reducer_builder.AddInstruction(
      HloInstruction::CreateParameter(1, int_scalar, "x1"));
  auto y0 = reducer_builder.AddInstruction(
      HloInstruction::CreateParameter(2, float_scalar, "y0"));
  auto y1 = reducer_builder.AddInstruction(
      HloInstruction::CreateParameter(3, int_scalar, "y1"));
  HloInstruction* float_min = reducer_builder.AddInstruction(
      HloInstruction::CreateBinary(float_scalar, HloOpcode::kMinimum, x0, y0));
  HloInstruction* int_min = reducer_builder.AddInstruction(
      HloInstruction::CreateBinary(int_scalar, HloOpcode::kMinimum, x1, y1));
  std::vector<HloInstruction*> reducer_outputs = {float_min, int_min};
  reducer_builder.AddInstruction(HloInstruction::CreateTuple(reducer_outputs));
  auto reducer = m_->AddEmbeddedComputation(reducer_builder.Build());

  std::vector<HloInstruction*> operands = {float_operand, int_operand};

  // Each reduction starts from the largest value of its element type.
  auto float_init = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::MaxValue(F32)));
  auto int_init = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::MaxValue(S32)));
  std::vector<HloInstruction*> init_values = {float_init, int_init};

  // Window of size 3, stride 2, no padding and no dilation.
  std::pair<int64, int64> no_padding(0, 0);
  TF_ASSERT_OK_AND_ASSIGN(auto window,
                          ShapeInference::InferWindowFromDimensions(
                              {3}, {2}, absl::MakeSpan(&no_padding, 1),
                              /*lhs_dilation=*/{},
                              /*rhs_dilation=*/{}));

  std::vector<const Shape*> operand_shapes = {&float_operand->shape(),
                                              &int_operand->shape()};
  std::vector<const Shape*> init_value_shapes = {&float_init->shape(),
                                                 &int_init->shape()};
  TF_ASSERT_OK_AND_ASSIGN(Shape result_shape,
                          ShapeInference::InferReduceWindowShape(
                              operand_shapes, init_value_shapes, window,
                              reducer->ComputeProgramShape()));

  // Added last so that it is the root of the entry computation.
  entry_builder.AddInstruction(HloInstruction::CreateReduceWindow(
      result_shape, operands, init_values, window, reducer));
  m_->AddEntryComputation(entry_builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate());

  auto expected_floats = LiteralUtil::CreateR1<float>({100, 1});
  auto expected_ints = LiteralUtil::CreateR1<int>({15, 12});
  auto expected = LiteralUtil::MakeTuple({&expected_floats, &expected_ints});
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
2967 
// Slices a 3x5 operand with start {0, 2}, limit {3, 5} and strides {2, 3},
// which selects rows 0 and 2 of column 2.
TEST_P(HloEvaluatorBf16Test, StridedSlice) {
  HloComputation::Builder builder(TestName());

  // Operand:
  // f32[3,5] {
  //  { 1, 2, 3, 4, 5 },
  //  { 9, 10, 11, 12, 13 },
  //  { 17, 18, 19, 20, 21 },
  // }
  Array2D<float> operand_values(3, 5);
  operand_values.FillUnique(1.0f);
  HloInstruction* operand =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2FromArray2D<float>(operand_values)));

  // The slice is added last so it becomes the root of the entry computation.
  Shape result_shape = ShapeUtil::MakeShape(F32, {2, 1});
  builder.AddInstruction(HloInstruction::CreateSlice(result_shape, operand,
                                                     /*start_indices=*/{0, 2},
                                                     /*limit_indices=*/{3, 5},
                                                     /*strides=*/{2, 3}));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate());

  Literal expected = LiteralUtil::CreateR2<float>({{3}, {19}});
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
3001 
// Dynamic-slices a 2x4 operand at start indices (0, 1) with slice sizes
// {2, 3}, i.e. both rows and columns 1..3.
TEST_P(HloEvaluatorBf16Test, DynamicSlice) {
  HloComputation::Builder builder(TestName());

  // Operand:
  // f32[2,4] {
  //  { 1, 2, 3, 4 },
  //  { 5, 6, 7, 8 },
  // }
  Array2D<float> operand_values(2, 4);
  operand_values.FillUnique(1.0f);
  HloInstruction* operand =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2FromArray2D<float>(operand_values)));

  // Scalar s32 start indices, one per operand dimension.
  auto row_start = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
  auto col_start = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));

  // The dynamic-slice is added last so it becomes the root of the entry
  // computation.
  Shape result_shape = ShapeUtil::MakeShape(F32, {2, 3});
  builder.AddInstruction(HloInstruction::CreateDynamicSlice(
      result_shape, operand, {row_start, col_start}, {2, 3}));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate());

  Literal expected = LiteralUtil::CreateR2<float>({{2, 3, 4}, {6, 7, 8}});
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
3037 
3038 // Verifies that the HloEvaluator's implementation goes along with existing
3039 // backends' behavior, although this is not required by the spec.
TEST_P(HloEvaluatorBf16Test, DynamicSliceModSlice) {
  HloComputation::Builder builder(TestName());

  // Operand:
  // f32[2,4] {
  //  { 1, 2, 3, 4 },
  //  { 5, 6, 7, 8 },
  // }
  Array2D<float> operand_values(2, 4);
  operand_values.FillUnique(1.0f);
  HloInstruction* operand =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2FromArray2D<float>(operand_values)));

  // A row start of 2 together with a slice size of 2 does not fit inside an
  // operand that has only 2 rows.
  auto row_start = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(2)));
  auto col_start = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));

  // The dynamic-slice is added last so it becomes the root of the entry
  // computation.
  Shape result_shape = ShapeUtil::MakeShape(F32, {2, 3});
  builder.AddInstruction(HloInstruction::CreateDynamicSlice(
      result_shape, operand, {row_start, col_start}, {2, 3}));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate());

  // The expectation is the same as for a row start of 0: both operand rows,
  // columns 1..3.
  Literal expected = LiteralUtil::CreateR2<float>({{2, 3, 4}, {6, 7, 8}});
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
3075 
// Dynamic-update-slices a 2x3 f64 operand with a 2x2 update placed at start
// indices (0, 1), overwriting columns 1 and 2 of both rows.
TEST_P(HloEvaluatorBf16Test, DynamicSliceUpdate) {
  HloComputation::Builder builder(TestName());

  // Operand:
  // f64[2,3] {
  //  { 1, 2, 3 },
  //  { 5, 6, 7 },
  // }
  Array2D<double> operand_values(2, 3);
  operand_values.FillUnique(1.0);
  HloInstruction* operand =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2FromArray2D<double>(operand_values)));

  // Scalar s32 start indices, one per operand dimension.
  auto row_start = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
  auto col_start = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));

  auto update_values = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<double>({{-2.0, -3.0}, {-6.0, -7.0}})));

  // The update is added last so it becomes the root of the entry computation.
  Shape result_shape = ShapeUtil::MakeShape(F64, {2, 3});
  builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
      result_shape, operand, update_values, {row_start, col_start}));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate());

  Literal expected =
      LiteralUtil::CreateR2<double>({{1, -2, -3}, {5, -6, -7}});
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
3114 
// Packs an s64[2] constant and an f64[2,3] constant into a tuple and reads
// element 1 back with get-tuple-element.
TEST_P(HloEvaluatorBf16Test, SetAndGetTuples) {
  HloComputation::Builder builder(TestName());

  // Matrix element:
  // f64[2,3] {
  //  { 1, 2, 3 },
  //  { 5, 6, 7 },
  // }
  Array2D<double> matrix_values(2, 3);
  matrix_values.FillUnique(1.0);
  HloInstruction* matrix =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2FromArray2D<double>(matrix_values)));
  HloInstruction* vector = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<int64>({0, 1})));

  auto packed =
      builder.AddInstruction(HloInstruction::CreateTuple({vector, matrix}));

  // The get-tuple-element is added last so it becomes the root of the entry
  // computation.
  Shape element_shape = ShapeUtil::MakeShape(F64, {2, 3});
  builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(element_shape, packed, 1));

  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate());

  Literal expected = LiteralUtil::CreateR2<double>({{1, 2, 3}, {5, 6, 7}});
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
3150 
// Builds a tuple of two tuples and reads the second inner tuple back with
// get-tuple-element; that inner tuple holds the same f64[2,3] matrix twice.
TEST_P(HloEvaluatorBf16Test, SetAndGetNestedTuples) {
  HloComputation::Builder builder(TestName());

  // Matrix element:
  // f64[2,3] {
  //  { 1, 2, 3 },
  //  { 5, 6, 7 },
  // }
  Array2D<double> matrix_values(2, 3);
  matrix_values.FillUnique(1.0);

  HloInstruction* matrix =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2FromArray2D<double>(matrix_values)));
  HloInstruction* vector = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<int64>({0, 1})));

  auto mixed_tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({vector, matrix}));
  auto matrix_pair =
      builder.AddInstruction(HloInstruction::CreateTuple({matrix, matrix}));

  auto nested =
      builder.AddInstruction(HloInstruction::CreateTuple({mixed_tuple,
                                                          matrix_pair}));

  // The get-tuple-element is added last so it becomes the root of the entry
  // computation.
  builder.AddInstruction(HloInstruction::CreateGetTupleElement(
      matrix_pair->shape(), nested, 1));

  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate());

  auto expected_matrix =
      LiteralUtil::CreateR2FromArray2D<double>(matrix_values);
  auto expected = LiteralUtil::MakeTuple({&expected_matrix, &expected_matrix});

  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
3189 
// Reverses a float[4x3x2x1] operand along dimensions 0 and 1; the two
// innermost dimensions keep their order.
TEST_P(HloEvaluatorBf16Test, Reverse) {
  HloComputation::Builder builder(TestName());

  // Operand values, shaped float[4x3x2x1].
  // clang-format off
  Array4D<float> operand_values({
    {{{1.0f}, {2.0f}},
     {{3.0f}, {4.0f}},
     {{5.0f}, {6.0f}}},
    {{{7.0f}, {8.0f}},
     {{9.0f}, {10.0f}},
     {{11.0f}, {12.0f}}},
    {{{13.0f}, {14.0f}},
     {{15.0f}, {16.0f}},
     {{17.0f}, {18.0f}}},
    {{{19.0f}, {20.0f}},
     {{21.0f}, {22.0f}},
     {{23.0f}, {24.0f}}},
  });
  // clang-format on
  HloInstruction* operand =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR4FromArray4D<float>(operand_values)));

  // The reverse is added last so it becomes the root of the entry
  // computation.
  const Shape result_shape = ShapeUtil::MakeShape(F32, {4, 3, 2, 1});
  builder.AddInstruction(
      HloInstruction::CreateReverse(result_shape, operand, {0, 1}));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate());

  // clang-format off
  Literal expected = LiteralUtil::CreateR4FromArray4D<float>({
    {{{23.0f}, {24.0f}},
     {{21.0f}, {22.0f}},
     {{19.0f}, {20.0f}}},

    {{{17.0f}, {18.0f}},
     {{15.0f}, {16.0f}},
     {{13.0f}, {14.0f}}},

    {{{11.0f}, {12.0f}},
     {{9.0f}, {10.0f}},
     {{7.0f}, {8.0f}}},

    {{{5.0f}, {6.0f}},
     {{3.0f}, {4.0f}},
     {{1.0f}, {2.0f}}},
  });
  // clang-format on

  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
3242 
// Evaluates `param0 + square` while substituting literals for both operands,
// so the multiply that defines `square` is replaced by the supplied value.
TEST_P(HloEvaluatorBf16Test, EvaluateWithSubstitutions) {
  HloComputation::Builder builder(TestName());
  Shape vec4 = ShapeUtil::MakeShape(F32, {4});

  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, vec4, "param0"));
  HloInstruction* square = builder.AddInstruction(
      HloInstruction::CreateBinary(vec4, HloOpcode::kMultiply, param0, param0));
  HloInstruction* sum = builder.AddInstruction(
      HloInstruction::CreateBinary(vec4, HloOpcode::kAdd, param0, square));

  // Substitute param0 = {1, 2, 3, 4} and square = {10, 20, 30, 40}.
  Literal param0_value = LiteralUtil::CreateR1<float>({1, 2, 3, 4});
  Literal square_value = LiteralUtil::CreateR1<float>({10, 20, 30, 40});

  HloEvaluator evaluator;
  TF_ASSERT_OK_AND_ASSIGN(
      Literal actual,
      evaluator.EvaluateWithSubstitutions(
          sum, {{param0, &param0_value}, {square, &square_value}}));

  Literal expected = LiteralUtil::CreateR1<float>({11, 22, 33, 44});
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
3265 
3266 // Check that EvaluateWithSubstitutions works if one of the operands to the op
3267 // we're evaluating is a constant.
TEST_P(HloEvaluatorBf16Test, EvaluateWithSubstitutionsWithConstantOperand) {
  HloComputation::Builder builder(TestName());
  Shape vec4 = ShapeUtil::MakeShape(F32, {4});

  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, vec4, "param0"));
  HloInstruction* square = builder.AddInstruction(
      HloInstruction::CreateBinary(vec4, HloOpcode::kMultiply, param0, param0));
  // The left operand of the add is a constant, so no substitution is supplied
  // for it.
  HloInstruction* constant =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR1<float>({1, 2, 3, 4})));
  HloInstruction* sum = builder.AddInstruction(
      HloInstruction::CreateBinary(vec4, HloOpcode::kAdd, constant, square));

  // Substitute only square = {10, 20, 30, 40}.
  Literal square_value = LiteralUtil::CreateR1<float>({10, 20, 30, 40});

  HloEvaluator evaluator;
  TF_ASSERT_OK_AND_ASSIGN(
      Literal actual,
      evaluator.EvaluateWithSubstitutions(sum, {{square, &square_value}}));

  Literal expected = LiteralUtil::CreateR1<float>({11, 22, 33, 44});
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
3290 
// Gather that indexes along dimension 0 with full-width slices: with indices
// {0, 2} the expected result is rows 0 and 2 of the 3x3 operand.
TEST_F(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV1) {
  const char* module_text = R"(
HloModule TensorFlowGatherV1

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2] parameter(1)
  ROOT gather = s32[2,3] gather(operand, indices),
      offset_dims={1},
      collapsed_slice_dims={0},
      start_index_map={0},
      index_vector_dim=1,
      slice_sizes={1, 3}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal operand_literal =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal indices_literal = LiteralUtil::CreateR1<int32>({0, 2});
  TF_ASSERT_OK_AND_ASSIGN(Literal actual,
                          Evaluate({&operand_literal, &indices_literal}));

  Literal expected = LiteralUtil::CreateR2<int32>({{1, 2, 3}, {7, 8, 9}});
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
3314 
// Gather that indexes along dimension 1 with full-height slices: with indices
// {0, 2} the expected result is columns 0 and 2 of the 3x3 operand.
TEST_F(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV2) {
  const char* module_text = R"(
HloModule TensorFlowGatherV2

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2] parameter(1)
  ROOT gather = s32[3,2] gather(operand, indices),
      offset_dims={0},
      collapsed_slice_dims={1},
      start_index_map={1},
      index_vector_dim=1,
      slice_sizes={3, 1}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal operand_literal =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal indices_literal = LiteralUtil::CreateR1<int32>({0, 2});
  TF_ASSERT_OK_AND_ASSIGN(Literal actual,
                          Evaluate({&operand_literal, &indices_literal}));

  Literal expected = LiteralUtil::CreateR2<int32>({{1, 3}, {4, 6}, {7, 9}});
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
3338 
// Column gather driven by a 2x2 index matrix: every index selects a full
// column of the 3x3 operand, producing an s32[2,3,2] result.
TEST_F(HloEvaluatorTest, EvaluateGather_TensorFlowGatherMultipleBatchDims) {
  const char* module_text = R"(
HloModule TensorFlowGatherMultipleBatchDims

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2,2] parameter(1)
  ROOT gather = s32[2,3,2] gather(operand, indices),
      offset_dims={1},
      collapsed_slice_dims={1},
      start_index_map={1},
      index_vector_dim=2,
      slice_sizes={3, 1}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal operand_literal =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal indices_literal = LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
  TF_ASSERT_OK_AND_ASSIGN(Literal actual,
                          Evaluate({&operand_literal, &indices_literal}));

  Literal expected = LiteralUtil::CreateR3<int32>(
      {{{1, 3}, {4, 6}, {7, 9}}, {{3, 2}, {6, 5}, {9, 8}}});
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
3364 
// Gather where each row of the index matrix is a 2-component index vector
// addressing the two leading operand dimensions; the 2-element innermost
// slice is returned for each index vector.
TEST_F(HloEvaluatorTest, EvaluateGather_TensorFlowGatherNd) {
  const char* module_text = R"(
HloModule TensorFlowGatherNd

ENTRY main {
  operand = s32[3,3,2] parameter(0)
  indices = s32[2,2] parameter(1)
  ROOT gather = s32[2,2] gather(operand, indices),
      offset_dims={1},
      collapsed_slice_dims={0,1},
      start_index_map={0,1},
      index_vector_dim=1,
      slice_sizes={1,1,2}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal operand_literal =
      LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}},  //
                                    {{-4, 4}, {-5, 5}, {-6, 6}},  //
                                    {{-7, 7}, {-8, 8}, {-9, 9}}});
  // Index vectors (rows): (0, 0) and (1, 0).
  Literal indices_literal = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
  TF_ASSERT_OK_AND_ASSIGN(Literal actual,
                          Evaluate({&operand_literal, &indices_literal}));

  Literal expected = LiteralUtil::CreateR2<int32>({{-1, 1}, {-4, 4}});
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
3390 
// Same gather as the GatherNd case but with index_vector_dim=0, so the index
// vectors are read down the columns of the index matrix instead of along its
// rows.
TEST_F(HloEvaluatorTest,
       EvaluateGather_TensorFlowGatherNdNonDefaultIndexVectorDim) {
  const char* module_text = R"(
HloModule TensorFlowGatherNd

ENTRY main {
  operand = s32[3,3,2] parameter(0)
  indices = s32[2,2] parameter(1)
  ROOT gather = s32[2,2] gather(operand, indices),
      offset_dims={1},
      collapsed_slice_dims={0,1},
      start_index_map={0,1},
      index_vector_dim=0,
      slice_sizes={1,1,2}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal operand_literal =
      LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}},  //
                                    {{-4, 4}, {-5, 5}, {-6, 6}},  //
                                    {{-7, 7}, {-8, 8}, {-9, 9}}});
  // Index vectors (columns): (0, 1) and (0, 0).
  Literal indices_literal = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
  TF_ASSERT_OK_AND_ASSIGN(Literal actual,
                          Evaluate({&operand_literal, &indices_literal}));

  Literal expected = LiteralUtil::CreateR2<int32>({{-2, 2}, {-1, 1}});
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
3417 
// Gather with no collapsed dimensions and a single index vector (1, 1): a 1x1
// slice of the 3x3 operand taken at that position.
TEST_F(HloEvaluatorTest, EvaluateGather_DynamicSlice) {
  const char* module_text = R"(
HloModule DynamicSlice

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2] parameter(1)
  ROOT gather = s32[1,1] gather(operand, indices),
      offset_dims={0,1},
      collapsed_slice_dims={},
      start_index_map={0,1},
      index_vector_dim=0,
      slice_sizes={1,1}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal operand_literal =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal indices_literal = LiteralUtil::CreateR1<int32>({1, 1});
  TF_ASSERT_OK_AND_ASSIGN(Literal actual,
                          Evaluate({&operand_literal, &indices_literal}));

  Literal expected = LiteralUtil::CreateR2<int32>({{5}});
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
3441 
// Batched variant of the dynamic-slice gather: with index_vector_dim=0 each
// column of the index matrix is an index vector, and a 1x1 slice is taken per
// vector.
TEST_F(HloEvaluatorTest, EvaluateGather_BatchDynamicSlice) {
  const char* module_text = R"(
HloModule BatchDynamicSlice

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2,2] parameter(1)
  ROOT gather = s32[2,1,1] gather(operand, indices),
      offset_dims={1,2},
      collapsed_slice_dims={},
      start_index_map={0,1},
      index_vector_dim=0,
      slice_sizes={1,1}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal operand_literal =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  // Index vectors (columns): (2, 1) and (1, 1).
  Literal indices_literal = LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
  TF_ASSERT_OK_AND_ASSIGN(Literal actual,
                          Evaluate({&operand_literal, &indices_literal}));

  Literal expected = LiteralUtil::CreateR3<int32>({{{8}}, {{5}}});
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
3465 
// Row gather on an operand whose second dimension has size 0; the expected
// result is an empty s32[2,0] literal.
TEST_F(HloEvaluatorTest, EvaluateGather_ZeroDimBounds) {
  const char* module_text = R"(
HloModule TensorFlowGatherV1

ENTRY main {
  operand = s32[3,0] parameter(0)
  indices = s32[2] parameter(1)
  ROOT gather = s32[2,0] gather(operand, indices),
      offset_dims={1},
      collapsed_slice_dims={0},
      start_index_map={0},
      index_vector_dim=1,
      slice_sizes={1, 0}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal operand_literal = LiteralUtil::CreateR2<int32>({{}, {}, {}});
  Literal indices_literal = LiteralUtil::CreateR1<int32>({0, 2});
  TF_ASSERT_OK_AND_ASSIGN(Literal actual,
                          Evaluate({&operand_literal, &indices_literal}));

  Literal expected = LiteralUtil::CreateR2<int32>({{}, {}});
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
3488 
// Gather with no offset dimensions: every index picks a single element of the
// rank-1 operand, so the s32[2,2] result takes its shape from the batch
// dimensions of the s32[2,2,1] indices.
TEST_F(HloEvaluatorTest, EvaluateGather_NoOutputWindowDims) {
  // Held as `const char*`, like the HLO text of the sibling gather/scatter
  // tests, rather than being copied into a string object.
  const char* hlo_text = R"(
HloModule GatherXd

ENTRY main {
  operand = s32[3] parameter(0)
  indices = s32[2,2,1] parameter(1)
  ROOT gather = s32[2,2] gather(operand, indices),
      offset_dims={},
      collapsed_slice_dims={0},
      start_index_map={0},
      index_vector_dim=2,
      slice_sizes={1}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));

  // operand[i] == i, so the expected output equals the indices with their
  // trailing size-1 dimension dropped.
  Literal operand = LiteralUtil::CreateR1<int32>({0, 1, 2});
  Literal start_indices =
      LiteralUtil::CreateR3<int32>({{{0}, {1}}, {{2}, {1}}});
  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&operand, &start_indices}));
  EXPECT_TRUE(LiteralTestUtil::Equal(
      LiteralUtil::CreateR2<int32>({{0, 1}, {2, 1}}), result));
}
3513 
// Scatter whose combiner returns its second parameter, so each update
// overwrites the existing value. With indices {0, 2} along dimension 0, rows
// 0 and 2 of the operand are replaced and row 1 is left as is.
TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterV1_Update) {
  const char* module_text = R"(
HloModule TensorFlowScatterV1

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  ROOT rhs = s32[] parameter(1)
}

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2] parameter(1)
  updates = s32[2,3] parameter(2)
  ROOT scatter = s32[3,3] scatter(operand, indices, updates),
      to_apply=update_s32,
      update_window_dims={1},
      inserted_window_dims={0},
      scatter_dims_to_operand_dims={0},
      index_vector_dim=1
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal operand_literal =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal indices_literal = LiteralUtil::CreateR1<int32>({0, 2});
  Literal updates_literal =
      LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
  TF_ASSERT_OK_AND_ASSIGN(
      Literal actual,
      Evaluate({&operand_literal, &indices_literal, &updates_literal}));

  Literal expected =
      LiteralUtil::CreateR2<int32>({{10, 20, 30}, {4, 5, 6}, {70, 80, 90}});
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
3546 
// Scatter whose combiner returns its second parameter, so each update
// overwrites the existing value. With indices {0, 2} along dimension 1,
// columns 0 and 2 of the operand are replaced and column 1 is left as is.
TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterV2_Update) {
  const char* module_text = R"(
HloModule TensorFlowScatterV2

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  ROOT rhs = s32[] parameter(1)
}

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2] parameter(1)
  updates = s32[3,2] parameter(2)
  ROOT scatter = s32[3,3] scatter(operand, indices, updates),
      to_apply=update_s32,
      update_window_dims={0},
      inserted_window_dims={1},
      scatter_dims_to_operand_dims={1},
      index_vector_dim=1
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal operand_literal =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal indices_literal = LiteralUtil::CreateR1<int32>({0, 2});
  Literal updates_literal =
      LiteralUtil::CreateR2<int32>({{10, 30}, {40, 60}, {70, 90}});
  TF_ASSERT_OK_AND_ASSIGN(
      Literal actual,
      Evaluate({&operand_literal, &indices_literal, &updates_literal}));

  Literal expected =
      LiteralUtil::CreateR2<int32>({{10, 2, 30}, {40, 5, 60}, {70, 8, 90}});
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
3580 
// Scatter with an additive combiner: the update rows are added to the operand
// rows selected by the indices.
TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_Add) {
  const char* module_text = R"(
HloModule TensorFlowScatter

add_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(s32[] lhs, s32[] rhs)
}

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2] parameter(1)
  updates = s32[2,3] parameter(2)
  ROOT scatter = s32[3,3] scatter(operand, indices, updates),
      to_apply=add_s32,
      update_window_dims={1},
      inserted_window_dims={0},
      scatter_dims_to_operand_dims={0},
      index_vector_dim=1
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal input =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal row_indices = LiteralUtil::CreateR1<int32>({0, 2});
  Literal addends =
      LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});

  // Rows 0 and 2 receive the element-wise sums; row 1 is untouched.
  const Literal expected =
      LiteralUtil::CreateR2<int32>({{11, 22, 33}, {4, 5, 6}, {77, 88, 99}});

  TF_ASSERT_OK_AND_ASSIGN(Literal actual,
                          Evaluate({&input, &row_indices, &addends}));
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
3614 
// Scatter with a multiplicative combiner: the operand rows selected by the
// indices are multiplied element-wise by the update rows.
TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_Mul) {
  const char* module_text = R"(
HloModule TensorFlowScatter

mul_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT mul = s32[] multiply(s32[] lhs, s32[] rhs)
}

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2] parameter(1)
  updates = s32[2,3] parameter(2)
  ROOT scatter = s32[3,3] scatter(operand, indices, updates),
      to_apply=mul_s32,
      update_window_dims={1},
      inserted_window_dims={0},
      scatter_dims_to_operand_dims={0},
      index_vector_dim=1
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal input =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal row_indices = LiteralUtil::CreateR1<int32>({0, 2});
  Literal factors =
      LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});

  // Rows 0 and 2 hold the element-wise products; row 1 is untouched.
  const Literal expected = LiteralUtil::CreateR2<int32>(
      {{10, 40, 90}, {4, 5, 6}, {490, 640, 810}});

  TF_ASSERT_OK_AND_ASSIGN(Literal actual,
                          Evaluate({&input, &row_indices, &factors}));
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
3648 
// Additive scatter on floating-point data. The result is compared with a
// tolerance rather than for exact equality.
TEST_P(HloEvaluatorBf16Test, EvaluateScatter_TensorFlowScatter_F32) {
  const char* module_text = R"(
HloModule TensorFlowScatter

add_f32 (lhs: f32[], rhs: f32[]) -> f32[] {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT add = f32[] add(f32[] lhs, f32[] rhs)
}

ENTRY main {
  operand = f32[3,3] parameter(0)
  indices = s32[2] parameter(1)
  updates = f32[2,3] parameter(2)
  ROOT scatter = f32[3,3] scatter(operand, indices, updates),
      to_apply=add_f32,
      update_window_dims={1},
      inserted_window_dims={0},
      scatter_dims_to_operand_dims={0},
      index_vector_dim=1
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal input = LiteralUtil::CreateR2<float>(
      {{1.1, 2.2, 3.3}, {4.4, 5.5, 6.6}, {7.7, 8.8, 9.9}});
  Literal row_indices = LiteralUtil::CreateR1<int32>({2, 1});
  Literal addends =
      LiteralUtil::CreateR2<float>({{0.4, 1.1, 0.7}, {2.3, 3.1, 1.6}});

  // The first update row lands on row 2 and the second on row 1.
  const Literal expected = LiteralUtil::CreateR2<float>(
      {{1.1, 2.2, 3.3}, {6.7, 8.6, 8.2}, {8.1, 9.9, 10.6}});

  TF_ASSERT_OK_AND_ASSIGN(Literal actual,
                          Evaluate({&input, &row_indices, &addends}));
  EXPECT_TRUE(LiteralTestUtil::Near(expected, actual, ErrorSpec{0.1, 0.01}));
}
3684 
// Additive scatter where both indices name the same operand row, so both
// update rows must be accumulated into it.
TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_RepeatedIndices) {
  const char* module_text = R"(
HloModule TensorFlowScatter

add_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(s32[] lhs, s32[] rhs)
}

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2] parameter(1)
  updates = s32[2,3] parameter(2)
  ROOT scatter = s32[3,3] scatter(operand, indices, updates),
      to_apply=add_s32,
      update_window_dims={1},
      inserted_window_dims={0},
      scatter_dims_to_operand_dims={0},
      index_vector_dim=1
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal input =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal row_indices = LiteralUtil::CreateR1<int32>({1, 1});
  Literal addends =
      LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});

  // Row 1 is the original row plus both update rows.
  const Literal expected =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {84, 105, 126}, {7, 8, 9}});

  TF_ASSERT_OK_AND_ASSIGN(Literal actual,
                          Evaluate({&input, &row_indices, &addends}));
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
3718 
// Additive scatter whose indices form a 2x2 batch of scalar column indices.
// Each index adds a length-3 column from the updates to one operand column.
TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_MultipleBatchDims) {
  const char* module_text = R"(
HloModule TensorFlowScatterMultipleBatchDims

add_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(s32[] lhs, s32[] rhs)
}

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2,2] parameter(1)
  updates = s32[2,3,2] parameter(2)
  ROOT scatter = s32[3,3] scatter(operand, indices, updates),
      to_apply=add_s32,
      update_window_dims={1},
      inserted_window_dims={1},
      scatter_dims_to_operand_dims={1},
      index_vector_dim=2
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal input =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal column_indices = LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
  Literal addends = LiteralUtil::CreateR3<int32>(
      {{{10, 30}, {40, 60}, {70, 90}}, {{5, 5}, {5, 5}, {5, 5}}});

  // Column 2 is addressed twice, so it accumulates two update columns.
  const Literal expected = LiteralUtil::CreateR2<int32>(
      {{11, 7, 38}, {44, 10, 71}, {77, 13, 104}});

  TF_ASSERT_OK_AND_ASSIGN(Literal actual,
                          Evaluate({&input, &column_indices, &addends}));
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
3753 
// Overwriting scatter with two-component index vectors: each index vector
// picks a (dim0, dim1) position in the 3x3x2 operand, and the innermost pair
// at that position is replaced by one row of the updates.
TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterNd) {
  const char* module_text = R"(
HloModule TensorFlowScatterNd

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  ROOT rhs = s32[] parameter(1)
}

ENTRY main {
  operand = s32[3,3,2] parameter(0)
  indices = s32[2,2] parameter(1)
  updates = s32[2,2] parameter(2)
  ROOT scatter = s32[3,3,2] scatter(operand, indices, updates),
      to_apply=update_s32,
      update_window_dims={1},
      inserted_window_dims={0,1},
      scatter_dims_to_operand_dims={0,1},
      index_vector_dim=1
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal input =
      LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}},  //
                                    {{-4, 4}, {-5, 5}, {-6, 6}},  //
                                    {{-7, 7}, {-8, 8}, {-9, 9}}});
  Literal index_vectors = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
  Literal new_pairs = LiteralUtil::CreateR2<int32>({{-10, 10}, {-40, 40}});

  // Positions (0,0) and (1,0) are overwritten.
  const Literal want =
      LiteralUtil::CreateR3<int32>({{{-10, 10}, {-2, 2}, {-3, 3}},  //
                                    {{-40, 40}, {-5, 5}, {-6, 6}},  //
                                    {{-7, 7}, {-8, 8}, {-9, 9}}});

  TF_ASSERT_OK_AND_ASSIGN(Literal got,
                          Evaluate({&input, &index_vectors, &new_pairs}));
  EXPECT_TRUE(LiteralTestUtil::Equal(want, got));
}
3790 
// ScatterNd-style overwrite where index_vector_dim is 0, so the index vectors
// are read down the columns of the indices literal instead of along its rows.
TEST_F(HloEvaluatorTest,
       EvaluateScatter_TensorFlowScatterNd_NonDefaultIndexVectorDim) {
  const char* module_text = R"(
HloModule TensorFlowScatterNdNonDefaultIndexVectorDim

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  ROOT rhs = s32[] parameter(1)
}

ENTRY main {
  operand = s32[3,3,2] parameter(0)
  indices = s32[2,2] parameter(1)
  updates = s32[2,2] parameter(2)
  ROOT scatter = s32[3,3,2] scatter(operand, indices, updates),
      to_apply=update_s32,
      update_window_dims={1},
      inserted_window_dims={0,1},
      scatter_dims_to_operand_dims={0,1},
      index_vector_dim=0
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal input =
      LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}},  //
                                    {{-4, 4}, {-5, 5}, {-6, 6}},  //
                                    {{-7, 7}, {-8, 8}, {-9, 9}}});
  Literal index_vectors = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
  Literal new_pairs = LiteralUtil::CreateR2<int32>({{-10, 10}, {-20, 20}});

  // Read column-wise, the index vectors are (0,1) and (0,0).
  const Literal want =
      LiteralUtil::CreateR3<int32>({{{-20, 20}, {-10, 10}, {-3, 3}},  //
                                    {{-4, 4}, {-5, 5}, {-6, 6}},      //
                                    {{-7, 7}, {-8, 8}, {-9, 9}}});

  TF_ASSERT_OK_AND_ASSIGN(Literal got,
                          Evaluate({&input, &index_vectors, &new_pairs}));
  EXPECT_TRUE(LiteralTestUtil::Equal(want, got));
}
3828 
// Scatter configured like a dynamic-update-slice: a single index vector places
// a 1x1 update window inside the 3x3 operand.
TEST_F(HloEvaluatorTest, EvaluateScatter_DynamicUpdateSlice) {
  const char* module_text = R"(
HloModule DynamicUpdateSlice

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  ROOT rhs = s32[] parameter(1)
}

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2] parameter(1)
  updates = s32[1,1] parameter(2)
  ROOT scatter = s32[3,3] scatter(operand, indices, updates),
      to_apply=update_s32,
      update_window_dims={0,1},
      inserted_window_dims={},
      scatter_dims_to_operand_dims={0,1},
      index_vector_dim=0
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal input =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal start_index = LiteralUtil::CreateR1<int32>({1, 1});
  Literal window = LiteralUtil::CreateR2<int32>({{10}});

  // Only the centre element, at (1,1), is overwritten.
  const Literal want =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 10, 6}, {7, 8, 9}});

  TF_ASSERT_OK_AND_ASSIGN(Literal got,
                          Evaluate({&input, &start_index, &window}));
  EXPECT_TRUE(LiteralTestUtil::Equal(want, got));
}
3861 
// Batched version of the dynamic-update-slice style scatter: two index vectors
// (stored along dimension 0 of the indices) each place a 1x1 update window.
TEST_F(HloEvaluatorTest, EvaluateScatter_BatchDynamicUpdateSlice) {
  const char* module_text = R"(
HloModule BatchDynamicUpdateSlice

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  ROOT rhs = s32[] parameter(1)
}

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2,2] parameter(1)
  updates = s32[2,1,1] parameter(2)
  ROOT scatter = s32[3,3] scatter(operand, indices, updates),
      to_apply=update_s32,
      update_window_dims={1,2},
      inserted_window_dims={},
      scatter_dims_to_operand_dims={0,1},
      index_vector_dim=0
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal input =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal start_indices = LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
  Literal windows = LiteralUtil::CreateR3<int32>({{{10}}, {{20}}});

  // Read column-wise, the index vectors are (2,1) and (1,1).
  const Literal want =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 20, 6}, {7, 10, 9}});

  TF_ASSERT_OK_AND_ASSIGN(Literal got,
                          Evaluate({&input, &start_indices, &windows}));
  EXPECT_TRUE(LiteralTestUtil::Equal(want, got));
}
3894 
// Scatter on shapes that have a zero-sized dimension. There are no elements to
// update, so the result must equal the (empty) operand.
TEST_F(HloEvaluatorTest, EvaluateScatter_ZeroDimBounds) {
  const char* module_text = R"(
HloModule TensorFlowScatter_ZeroDimBounds

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  ROOT rhs = s32[] parameter(1)
}

ENTRY main {
  operand = s32[3,0] parameter(0)
  indices = s32[2] parameter(1)
  updates = s32[2,0] parameter(2)
  ROOT scatter = s32[3,0] scatter(operand, indices, updates),
      to_apply=update_s32,
      update_window_dims={1},
      inserted_window_dims={0},
      scatter_dims_to_operand_dims={0},
      index_vector_dim=1
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal empty_input = LiteralUtil::CreateR2<int32>({{}, {}, {}});
  Literal row_indices = LiteralUtil::CreateR1<int32>({0, 2});
  Literal empty_updates = LiteralUtil::CreateR2<int32>({{}, {}});

  TF_ASSERT_OK_AND_ASSIGN(
      Literal actual, Evaluate({&empty_input, &row_indices, &empty_updates}));
  EXPECT_TRUE(LiteralTestUtil::Equal(empty_input, actual));
}
3924 
// Additive scatter with an empty update_window_dims: every update is a scalar
// added to the rank-1 operand element named by its index.
TEST_F(HloEvaluatorTest, EvaluateScatter_NoUpdateWindowDims) {
  const string module_text = R"(
HloModule Scatter_NoUpdateWindowDims

add_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(s32[] lhs, s32[] rhs)
}

ENTRY main {
  operand = s32[3] parameter(0)
  indices = s32[2,2,1] parameter(1)
  updates = s32[2,2] parameter(2)
  ROOT scatter = s32[3] scatter(operand, indices, updates),
      to_apply=add_s32,
      update_window_dims={},
      inserted_window_dims={0},
      scatter_dims_to_operand_dims={0},
      index_vector_dim=2
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal input = LiteralUtil::CreateR1<int32>({0, 1, 2});
  Literal element_indices =
      LiteralUtil::CreateR3<int32>({{{0}, {1}}, {{2}, {1}}});
  Literal addends = LiteralUtil::CreateR2<int32>({{10, 20}, {30, 40}});

  // Element 1 is addressed twice and therefore receives 20 + 40.
  const Literal want = LiteralUtil::CreateR1<int32>({10, 61, 32});

  TF_ASSERT_OK_AND_ASSIGN(Literal got,
                          Evaluate({&input, &element_indices, &addends}));
  EXPECT_TRUE(LiteralTestUtil::Equal(want, got));
}
3958 
// Additive scatter where one of the indices is negative. The update for the
// negative index is expected to be dropped while the valid one is applied.
TEST_F(HloEvaluatorTest, EvaluateScatter_NegativeIndices) {
  const char* module_text = R"(
HloModule TensorFlowScatter_NegativeIndices

add_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(s32[] lhs, s32[] rhs)
}

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2] parameter(1)
  updates = s32[2,3] parameter(2)
  ROOT scatter = s32[3,3] scatter(operand, indices, updates),
      to_apply=add_s32,
      update_window_dims={1},
      inserted_window_dims={0},
      scatter_dims_to_operand_dims={0},
      index_vector_dim=1
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> parsed,
                          ParseAndReturnVerifiedModule(module_text));

  Literal input =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  // Index -1 is invalid; only the update aimed at row 2 should take effect.
  Literal row_indices = LiteralUtil::CreateR1<int32>({-1, 2});
  Literal addends =
      LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});

  const Literal expected =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {77, 88, 99}});
  EXPECT_TRUE(LiteralTestUtil::Equal(
      expected,
      EvaluateWithModule(parsed.get(), {&input, &row_indices, &addends})));
}
3993 
// Overwriting scatter of six 1x1 windows, three of which are addressed by
// index vectors that fall outside the 3x3 operand. Only the in-bounds updates
// are expected to be applied.
TEST_F(HloEvaluatorTest, EvaluateScatter_OobIndices) {
  const string module_text = R"(
HloModule BatchDynamicUpdateSlice

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  ROOT rhs = s32[] parameter(1)
}

ENTRY main {
  operand = s32[3,3]{1,0} parameter(0)
  indices = s32[6,2]{1,0} parameter(1)
  updates = s32[6,1,1]{2,1,0} parameter(2)
  ROOT scatter = s32[3,3]{1,0} scatter(operand, indices, updates),
      to_apply=update_s32,
      update_window_dims={1,2},
      inserted_window_dims={},
      scatter_dims_to_operand_dims={0,1},
      index_vector_dim=1
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> parsed,
                          ParseAndReturnVerifiedModule(module_text));

  Literal input =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  // (2,7), (5,1) and (2147483647,1) lie outside the operand and must be
  // skipped.
  Literal start_indices = LiteralUtil::CreateR2<int32>(
      {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483647, 1}, {1, 2}});
  Literal windows = LiteralUtil::CreateR3<int32>(
      {{{10}}, {{20}}, {{30}}, {{40}}, {{50}}, {{60}}});

  const Literal expected =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 30, 60}, {7, 20, 9}});
  EXPECT_TRUE(LiteralTestUtil::Equal(
      expected,
      EvaluateWithModule(parsed.get(), {&input, &start_indices, &windows})));
}
4029 
// Overwriting scatter whose index is in bounds but whose update window would
// extend past the edge of the operand. The operand is expected to come back
// unchanged.
TEST_F(HloEvaluatorTest, EvaluateScatter_OobUpdateWindow) {
  const char* module_text = R"(
HloModule TensorFlowScatterNd_OobUpdateWindow

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  ROOT rhs = s32[] parameter(1)
}

ENTRY main {
  operand = s32[3,3,2] parameter(0)
  indices = s32[1,2] parameter(1)
  updates = s32[1,2,2] parameter(2)
  ROOT scatter = s32[3,3,2] scatter(operand, indices, updates),
      to_apply=update_s32,
      update_window_dims={1,2},
      inserted_window_dims={0},
      scatter_dims_to_operand_dims={0,1},
      index_vector_dim=1
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> parsed,
                          ParseAndReturnVerifiedModule(module_text));

  Literal input =
      LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}},  //
                                    {{-4, 4}, {-5, 5}, {-6, 6}},  //
                                    {{-7, 7}, {-8, 8}, {-9, 9}}});
  Literal start_index = LiteralUtil::CreateR2<int32>({{0, 2}});
  Literal window = LiteralUtil::CreateR3<int32>({{{-10, 10}, {-40, 40}}});

  // A 2x2 window starting at index (0,2) does not fit inside the operand, so
  // no element may be modified.
  Literal unchanged = input.Clone();
  EXPECT_TRUE(LiteralTestUtil::Equal(
      unchanged,
      EvaluateWithModule(parsed.get(), {&input, &start_index, &window})));
}
4066 
4067 // Verifies that HloEvaluator evaluates a HLO instruction that performs
4068 // element-wise comparison with 2 bfloat16 operands.
TEST_F(HloEvaluatorTest,DoesCompareBF16)4069 TEST_F(HloEvaluatorTest, DoesCompareBF16) {
4070   // lhs >= rhs
4071   auto lhs = LiteralUtil::CreateR2<bfloat16>(
4072       {{bfloat16(0.25), bfloat16(0.35), bfloat16(0.125)},
4073        {bfloat16(-0.25), bfloat16(-0.35), bfloat16(-0.125)}});
4074   auto rhs = LiteralUtil::CreateR2<bfloat16>(
4075       {{bfloat16(0.5), bfloat16(0.125), bfloat16(0.125)},
4076        {bfloat16(0.25), bfloat16(-0.375), bfloat16(-0.127)}});
4077   auto expected =
4078       LiteralUtil::CreateR2<bool>({{false, true, true}, {false, true, true}});
4079 
4080   HloComputation::Builder b(TestName());
4081   auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs)));
4082   auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs)));
4083   b.AddInstruction(HloInstruction::CreateCompare(expected.shape(), c1, c2,
4084                                                  ComparisonDirection::kGe));
4085   m_->AddEntryComputation(b.Build());
4086 
4087   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
4088   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
4089 }
4090 
// Sum-reduces a four-element bf16 vector with a bf16 add computation and
// checks the scalar result.
TEST_P(HloEvaluatorBf16Test, Bf16Reduction) {
  const string module_text = R"(
HloModule Bf16Reduction

add_bf16 (lhs: bf16[], rhs: bf16[]) -> bf16[] {
  lhs = bf16[] parameter(0)
  rhs = bf16[] parameter(1)
  ROOT add = bf16[] add(bf16[] lhs, bf16[] rhs)
}

ENTRY main {
  arg0 = bf16[4]{0} parameter(0)
  init = bf16[] constant(0)
  ROOT %reduce = bf16[] reduce(arg0, init), dimensions={0}, to_apply=add_bf16
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal input = LiteralUtil::CreateR1<bfloat16>(
      {bfloat16(1.0f), bfloat16(3.0f), bfloat16(-2.0f), bfloat16(42.0f)});
  // 1 + 3 - 2 + 42
  const Literal want = LiteralUtil::CreateR0<bfloat16>(bfloat16(44.0f));

  TF_ASSERT_OK_AND_ASSIGN(Literal got, Evaluate({&input}));
  EXPECT_TRUE(LiteralTestUtil::Equal(want, got));
}
4115 
// Reduces an f32 vector with an f32 add computation while the reduce
// instruction itself is declared to produce bf16, and checks the bf16 result.
TEST_F(HloEvaluatorTest, MixedPrecisionReduction) {
  const string module_text = R"(
HloModule MixedPrecisionReduction

add_f32 {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT add = f32[] add(lhs, rhs)
}

ENTRY main {
  arg0 = f32[4]{0} parameter(0)
  init = f32[] constant(0)
  ROOT %reduce = bf16[] reduce(arg0, init), dimensions={0}, to_apply=add_f32
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal input = LiteralUtil::CreateR1<float>({1.0f, 3.0f, -2.0f, 42.0f});
  // 1 + 3 - 2 + 42, expressed in the output element type.
  const Literal want = LiteralUtil::CreateR0<bfloat16>(bfloat16(44.0f));

  TF_ASSERT_OK_AND_ASSIGN(Literal got, Evaluate({&input}));
  EXPECT_TRUE(LiteralTestUtil::Equal(want, got));
}
4139 
// The called computation contains an infeed, which triggers an unimplemented
// error inside HandleCall. The test checks that evaluation reports a non-OK
// status in that case.
TEST_F(HloEvaluatorTest, DontFailOnCallUnimplementedOps) {
  const string module_text = R"(
HloModule DontFailOnCall

call {
  token0 = token[] after-all()
  ROOT infeed = ((u32[3]{0}, pred[]), token[]) infeed(token0)
}

ENTRY main {
  ROOT result = ((u32[3]{0}, pred[]), token[]) call(), to_apply=call
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  auto evaluation = Evaluate();
  EXPECT_FALSE(evaluation.status().ok());
}
4159 
// The fused computation contains an infeed, which triggers an unimplemented
// error inside HandleFusion. The test checks that evaluation reports a non-OK
// status in that case.
TEST_F(HloEvaluatorTest, DontFailOnFusionWithUnimplementedOps) {
  const string module_text = R"(
HloModule DontFailOnFusion

fused_computation {
  token0 = token[] after-all()
  ROOT infeed = ((u32[3]{0}, pred[]), token[]) infeed(token0)
}

ENTRY main {
  ROOT result = ((u32[3]{0}, pred[]), token[]) fusion(), kind=kLoop, calls=fused_computation
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  auto evaluation = Evaluate();
  EXPECT_FALSE(evaluation.status().ok());
}
4179 
// Regression test for b/114735354.
//
// Takes a full-extent slice whose declared output layout differs from the
// input layout; the result must compare equal to the input literal.
TEST_P(HloEvaluatorBf16Test, SliceWithDifferentLayout) {
  const string module_text = R"(
HloModule SliceWithDifferentLayout

ENTRY main {
  arg = f32[2,2,2]{0,1,2} parameter(0)
  ROOT %slice = f32[2,2,2]{1,0,2} slice(f32[2,2,2]{0,1,2} %arg), slice={[0:2], [0:2], [0:2]}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  // The input literal uses the same {0,1,2} layout as the entry parameter.
  Literal input = LiteralUtil::CreateR3WithLayout<float>(
      {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}},
      LayoutUtil::MakeLayout({0, 1, 2}));

  TF_ASSERT_OK_AND_ASSIGN(Literal sliced, Evaluate({&input}));
  EXPECT_TRUE(LiteralTestUtil::Equal(input, sliced));
}
4198 
// Regression test for b/114735354.
//
// Evaluates a bitcast between two shapes and checks that the raw element data
// of the result matches the raw element data of the argument. The element
// type (bf16 or f32) is chosen by use_bfloat16_.
TEST_P(HloEvaluatorBf16Test, Bitcast) {
  const absl::string_view hlo_text_base = R"(
HloModule Bitcast

ENTRY main {
  param = %s[32,121]{1,0} parameter(0)
  ROOT bitcast = %s[121,32,1]{0,1,2} bitcast(%s[32,121]{1,0} param)
}
)";
  // The same element type name is substituted into all three %s slots.
  const char* element_type = use_bfloat16_ ? "bf16" : "f32";
  string module_text = absl::StrFormat(hlo_text_base, element_type,
                                       element_type, element_type);
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  auto fake_args = MakeFakeArguments(m_.get()).ConsumeValueOrDie();
  TF_ASSERT_OK_AND_ASSIGN(Literal bitcasted, Evaluate({&fake_args[0]}));

  // Compare the underlying data using the element type that was selected.
  if (!use_bfloat16_) {
    EXPECT_TRUE(
        absl::c_equal(fake_args[0].data<float>(), bitcasted.data<float>()));
  } else {
    EXPECT_TRUE(absl::c_equal(fake_args[0].data<bfloat16>(),
                              bitcasted.data<bfloat16>()));
  }
}
4225 
4226 // Check that s32 under/overflow doesn't trigger a ubsan failure.
TEST_F(HloEvaluatorTest,Int32Overflow)4227 TEST_F(HloEvaluatorTest, Int32Overflow) {
4228   const absl::string_view hlo_text = R"(
4229 HloModule Test
4230 
4231 ENTRY main {
4232   c1 = s32[] constant(1073741824)  // 2^30
4233   sum = s32[] add(c1, c1)  // 2^31, i.e. INT_MIN
4234 
4235   c2 = s32[] constant(-2147483648)  // -2^31
4236   sub = s32[] subtract(c2, c1)  // -2^31 - 2^30, underflows
4237 
4238   mul = s32[] multiply(c1, c1)
4239   ROOT tuple = (s32[], s32[], s32[]) tuple(sum, sub, mul)
4240 }
4241 )";
4242   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
4243   TF_ASSERT_OK_AND_ASSIGN(auto literal, Evaluate({}));
4244   std::vector<Literal> actual = literal.DecomposeTuple();
4245   ASSERT_EQ(actual.size(), 3);
4246 
4247   uint32 pow30 = uint32{1} << 30;
4248   uint32 pow31 = uint32{1} << 31;
4249   EXPECT_EQ(actual[0].GetFirstElement<int32>(), static_cast<int32>(pow31));
4250   EXPECT_EQ(actual[1].GetFirstElement<int32>(),
4251             static_cast<int32>(-(pow31 + pow30)));
4252   EXPECT_EQ(actual[2].GetFirstElement<int32>(),
4253             static_cast<int32>(pow31 * pow31));
4254 }
4255 
// Evaluates get-dimension-size on a value derived from a parameter whose
// dimension 0 is bound to a dynamic-size parameter, and checks that the
// bound size argument is what comes back.
TEST_F(HloEvaluatorTest, GetDimensionSize) {
  const absl::string_view module_text = R"(
HloModule Test

ENTRY main {
  size = s32[] parameter(0)

  data = s32[4] parameter(1)

  sum = s32[4] add(data, data)

  ROOT dynamic_size = s32[] get-dimension-size(sum), dimensions={0}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  // Bind parameter 0 as the dynamic size of dimension 0 of parameter 1.
  const DynamicParameterBinding::DynamicParameter size_parameter{0, {}};
  const DynamicParameterBinding::DynamicDimension bound_dimension{1, {}, 0};
  TF_CHECK_OK(
      m_->dynamic_parameter_binding().Bind(size_parameter, bound_dimension));

  TF_ASSERT_OK_AND_ASSIGN(DynamicDimensionInference inference,
                          DynamicDimensionInference::Run(m_.get()));
  evaluator_.set_dynamic_dimension_inference(&inference);

  Literal size_literal = LiteralUtil::CreateR0<int32>(3);
  Literal data_literal = LiteralUtil::CreateR1<int32>({1, 2, 3, 4});

  TF_ASSERT_OK_AND_ASSIGN(Literal evaluated,
                          Evaluate({&size_literal, &data_literal}));

  const int32 expected_size = static_cast<int32>(3);
  EXPECT_EQ(evaluated.GetFirstElement<int32>(), expected_size);
}
4288 
4289 // Check that we get a useful error if we pass inputs of the wrong shape.
TEST_F(HloEvaluatorTest,EvaluateWithWrongInputShapes)4290 TEST_F(HloEvaluatorTest, EvaluateWithWrongInputShapes) {
4291   const absl::string_view hlo_text = R"(
4292 HloModule Test
4293 
4294 ENTRY main {
4295   p0 = s32[1] parameter(0)
4296   ROOT sum = s32[1] add(p0, p0)
4297 }
4298 )";
4299   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
4300   Literal input_wrong_shape = LiteralUtil::CreateR1<int32>({0, 1});
4301 
4302   EXPECT_EQ(HloEvaluator()
4303                 .Evaluate(*m_, {&input_wrong_shape})
4304                 .status()
4305                 .error_message(),
4306             "Shape mismatch at parameter 0. Computation expected s32[1]{0}, "
4307             "but arg was s32[2]{0}.");
4308   EXPECT_EQ(HloEvaluator()
4309                 .Evaluate(*m_->entry_computation(), {&input_wrong_shape})
4310                 .status()
4311                 .error_message(),
4312             "Shape mismatch at parameter 0. Computation expected s32[1]{0}, "
4313             "but arg was s32[2]{0}.");
4314 }
4315 
4316 // Check that we get a useful error if we pass too many or too few inputs.
TEST_F(HloEvaluatorTest,EvaluateWithWrongNumberOfInputs)4317 TEST_F(HloEvaluatorTest, EvaluateWithWrongNumberOfInputs) {
4318   const absl::string_view hlo_text = R"(
4319 HloModule Test
4320 
4321 ENTRY main {
4322   p0 = s32[1] parameter(0)
4323   ROOT sum = s32[1] add(p0, p0)
4324 }
4325 )";
4326   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
4327   Literal input = LiteralUtil::CreateR1<int32>({0});
4328 
4329   EXPECT_EQ(
4330       HloEvaluator().Evaluate(*m_, {&input, &input}).status().error_message(),
4331       "Expected 1 argument, but got 2.");
4332   EXPECT_EQ(HloEvaluator()
4333                 .Evaluate(*m_->entry_computation(), {&input, &input})
4334                 .status()
4335                 .error_message(),
4336             "Expected 1 argument, but got 2.");
4337 }
4338 
// Evaluates a fusion whose parameter has layout {0,1} and whose bitcast root
// has layout {1,0}, then checks that the result's flat float data equals the
// input's flat float data element for element.
TEST_F(HloEvaluatorTest, PreserveFusionInputLayout) {
  const absl::string_view hlo_text = R"(
    HloModule FusionInputLayout

    fused_computation {
      param_0 = f32[20,20]{0,1} parameter(0)
      ROOT bitcast = f32[20,20]{1,0} bitcast(param_0)
    }

    ENTRY kernel_entry {
      parameter.0 = f32[20,20]{0,1} parameter(0)
      ROOT fusion = f32[20,20]{1,0} fusion(parameter.0),
        kind=kLoop, calls=fused_computation
    })";

  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));

  std::vector<Literal> fake_args =
      MakeFakeArguments(m_.get()).ConsumeValueOrDie();
  const Literal& input = fake_args[0];

  TF_ASSERT_OK_AND_ASSIGN(Literal output, Evaluate({&input}));

  const bool same_data =
      absl::c_equal(input.data<float>(), output.data<float>());
  EXPECT_TRUE(same_data);
}
4360 
// Evaluates a fusion whose parameter has layout {1,0} and whose bitcast root
// has layout {0,1}, then checks that the result's flat float data equals the
// input's flat float data element for element.
TEST_F(HloEvaluatorTest, PreserveFusionOutputLayout) {
  const absl::string_view hlo_text = R"(
    HloModule FusionOutputLayout

    fused_computation {
      param_0 = f32[20,20]{1,0} parameter(0)
      ROOT bitcast = f32[20,20]{0,1} bitcast(param_0)
    }

    ENTRY kernel_entry {
      parameter.0 = f32[20,20]{1,0} parameter(0)
      ROOT fusion = f32[20,20]{0,1} fusion(parameter.0),
        kind=kLoop, calls=fused_computation
    })";

  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));

  std::vector<Literal> fake_args =
      MakeFakeArguments(m_.get()).ConsumeValueOrDie();
  const Literal& input = fake_args[0];

  TF_ASSERT_OK_AND_ASSIGN(Literal output, Evaluate({&input}));

  const bool same_data =
      absl::c_equal(input.data<float>(), output.data<float>());
  EXPECT_TRUE(same_data);
}
4381 
// Same check as the single-output fusion case, but the fusion root is a
// one-element tuple: the tuple result is decomposed and its first element's
// flat float data is compared against the input's.
TEST_F(HloEvaluatorTest, PreserveMOFusionOutputLayout) {
  const absl::string_view hlo_text = R"(
    HloModule MOFusionOutputLayout

    fused_computation {
      param_0 = f32[20,20]{1,0} parameter(0)
      bitcast = f32[20,20]{0,1} bitcast(param_0)
      ROOT tuple = (f32[20,20]{0,1}) tuple(bitcast)
    }

    ENTRY kernel_entry {
      parameter.0 = f32[20,20]{1,0} parameter(0)
      ROOT fusion = (f32[20,20]{0,1}) fusion(parameter.0),
        kind=kLoop, calls=fused_computation
    })";

  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));

  std::vector<Literal> fake_args =
      MakeFakeArguments(m_.get()).ConsumeValueOrDie();
  const Literal& input = fake_args[0];

  TF_ASSERT_OK_AND_ASSIGN(Literal output_tuple, Evaluate({&input}));
  std::vector<Literal> tuple_elements = output_tuple.DecomposeTuple();

  const bool same_data =
      absl::c_equal(input.data<float>(), tuple_elements[0].data<float>());
  EXPECT_TRUE(same_data);
}
4405 
4406 // Tests that custom_calls fail to evaluate when no handler is specified.
TEST_F(HloEvaluatorTest,EvaluateCustomCall_NoHandler)4407 TEST_F(HloEvaluatorTest, EvaluateCustomCall_NoHandler) {
4408   const absl::string_view hlo_text = R"(
4409     HloModule EvaluateCustomCall_NoHandler
4410     ENTRY kernel_entry {
4411       parameter.0 = u32[2,2]{1,0} parameter(0)
4412       ROOT test_root = (u32[2,2]{1,0}) custom-call(parameter.0),
4413           custom_call_target="_my_custom_call"
4414     }
4415   )";
4416 
4417   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
4418   auto args = MakeFakeArguments(m_.get()).ConsumeValueOrDie();
4419   EXPECT_EQ(HloEvaluator().Evaluate(*m_, {&args[0]}).status().code(),
4420             ::tensorflow::error::UNIMPLEMENTED);
4421 }
4422 
4423 // Tests when a custom_call handler returns an error.
TEST_F(HloEvaluatorTest,EvaluateCustomCall_HandlerError)4424 TEST_F(HloEvaluatorTest, EvaluateCustomCall_HandlerError) {
4425   const absl::string_view hlo_text = R"(
4426     HloModule EvaluateCustomCall_HandlerError
4427     ENTRY kernel_entry {
4428       parameter.0 = u32[2,2]{1,0} parameter(0)
4429       ROOT test_root = (u32[2,2]{1,0}) custom-call(parameter.0),
4430           custom_call_target="_my_custom_call"
4431     }
4432   )";
4433 
4434   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
4435   auto args = MakeFakeArguments(m_.get()).ConsumeValueOrDie();
4436   HloEvaluator evaluator;
4437   evaluator.set_custom_call_handler(
4438       [](HloInstruction* custom_call, absl::Span<const Literal*> operands) {
4439         return InternalError("Test error");
4440       });
4441   EXPECT_EQ(evaluator.Evaluate(*m_, {&args[0]}).status().code(),
4442             ::tensorflow::error::INTERNAL);
4443 }
4444 
4445 // Tests the custom_call handler on calls with many inputs.
4446 // We sum the operands so that we can verify the operand and output literals
4447 // are properly mapped for access.
TEST_F(HloEvaluatorTest,EvaluateCustomCall_ManyInputs)4448 TEST_F(HloEvaluatorTest, EvaluateCustomCall_ManyInputs) {
4449   const absl::string_view hlo_text = R"(
4450     HloModule EvaluateCustomCall_ManyInputs
4451     ENTRY kernel_entry {
4452       parameter.0 = u32[1]{0} parameter(0)
4453       parameter.1 = u32[1]{0} parameter(1)
4454       ROOT test_root = u32[1]{0} custom-call(parameter.0, parameter.1),
4455           custom_call_target="_my_custom_call"
4456     }
4457   )";
4458 
4459   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
4460   auto args = MakeFakeArguments(m_.get()).ConsumeValueOrDie();
4461   HloEvaluator evaluator;
4462   evaluator.set_custom_call_handler(
4463       [](HloInstruction* custom_call, absl::Span<const Literal*> operands) {
4464         EXPECT_EQ(HloOpcode::kCustomCall, custom_call->opcode());
4465         EXPECT_EQ("_my_custom_call", custom_call->custom_call_target());
4466         EXPECT_EQ(2, custom_call->operand_count());
4467         EXPECT_EQ(2, operands.size());
4468         auto output = Literal::CreateFromShape(custom_call->shape());
4469         auto operand0_data = operands[0]->data<uint32>();
4470         auto operand1_data = operands[1]->data<uint32>();
4471         auto output_data = output.data<uint32>();
4472         output_data[0] = operand0_data[0] + operand1_data[0];
4473         return output;
4474       });
4475   TF_ASSERT_OK_AND_ASSIGN(
4476       Literal actual_literal,
4477       evaluator.Evaluate(*m_->entry_computation(), {&args[0], &args[1]}));
4478   auto arg0_data = args[0].data<uint32>();
4479   auto arg1_data = args[1].data<uint32>();
4480   std::vector<uint32> expected_data = {arg0_data[0] + arg1_data[0]};
4481   EXPECT_TRUE(absl::c_equal(expected_data, actual_literal.data<uint32>()));
4482 }
4483 
// is-finite on an f16 constant: expects true for the ordinary values (7, -1)
// and false for nan, inf and -inf.
TEST_F(HloEvaluatorTest, IsFiniteF16) {
  const absl::string_view hlo_text = R"(
  HloModule test

  ENTRY IsFiniteTest {
    c = f16[6] constant({nan, 7, nan, -1, inf, -inf})
    ROOT is-finite = pred[6] is-finite(c)
  })";

  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));

  // The module has no parameters, so no arguments are supplied.
  HloEvaluator evaluator;
  TF_ASSERT_OK_AND_ASSIGN(
      Literal finite_mask,
      evaluator.Evaluate(*m_->entry_computation(), {}));

  EXPECT_THAT(finite_mask.data<bool>(),
              ::testing::ElementsAre(false, true, false, true, false, false));
}
4500 
// is-finite on a bf16 constant: expects true for the ordinary values (7, -1)
// and false for nan, inf and -inf.
TEST_F(HloEvaluatorTest, IsFiniteBf16) {
  const absl::string_view hlo_text = R"(
  HloModule test

  ENTRY IsFiniteTest {
    c = bf16[6] constant({nan, 7, nan, -1, inf, -inf})
    ROOT is-finite = pred[6] is-finite(c)
  })";

  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));

  // The module has no parameters, so no arguments are supplied.
  HloEvaluator evaluator;
  TF_ASSERT_OK_AND_ASSIGN(
      Literal finite_mask,
      evaluator.Evaluate(*m_->entry_computation(), {}));

  EXPECT_THAT(finite_mask.data<bool>(),
              ::testing::ElementsAre(false, true, false, true, false, false));
}
4517 
4518 // Check that evaluating `f32[<huge>, 0] iota` doesn't oom (it's an empty
4519 // array!).
TEST_F(HloEvaluatorTest,ZeroSizedIotaWithHugeDimension)4520 TEST_F(HloEvaluatorTest, ZeroSizedIotaWithHugeDimension) {
4521   const absl::string_view hlo_text = R"(
4522   HloModule test
4523   ENTRY t {
4524     ROOT i = f32[1000000000000, 0] iota(), iota_dimension=0
4525   })";
4526   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
4527   TF_ASSERT_OK_AND_ASSIGN(
4528       Literal actual_literal,
4529       HloEvaluator().Evaluate(*m_->entry_computation(), {}));
4530   EXPECT_THAT(actual_literal.data<float>(), ::testing::IsEmpty());
4531 }
4532 
// A copy-start / copy-done pair applied to the scalar constant 42.0 must
// evaluate to that same scalar.
TEST_F(HloEvaluatorTest, CopyStartCopyDone) {
  const absl::string_view hlo_text = R"(
  HloModule test
  ENTRY CopyStartCopyDone {
    init = f32[] constant(42.0)
    copy-start = (f32[]{:S(1)}, f32[], u32[]) copy-start(init)
    ROOT copy-done = f32[] copy-done(copy-start)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));

  HloEvaluator evaluator;
  TF_ASSERT_OK_AND_ASSIGN(
      Literal copied, evaluator.Evaluate(*m_->entry_computation(), {}));

  Literal want = LiteralUtil::CreateR0<float>(42.0f);
  EXPECT_TRUE(LiteralTestUtil::Equal(want, copied));
}
4548 
// Maps a computation over a bf16 vector: each element is doubled in bf16 and
// then converted to f32, so {1, 2, 3} must become {2, 4, 6}.
TEST_F(HloEvaluatorTest, MapBF16) {
  const absl::string_view hlo_text = R"(
  HloModule test

  map_computation {
    p = bf16[] parameter(0)
    add = bf16[] add(p, p)
    ROOT conv = f32[] convert(add)
  }

  ENTRY CopyStartCopyDone {
    c = bf16[3] constant({1, 2, 3})
    ROOT map = f32[3] map(c), to_apply=map_computation
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));

  HloEvaluator evaluator;
  TF_ASSERT_OK_AND_ASSIGN(
      Literal mapped, evaluator.Evaluate(*m_->entry_computation(), {}));

  Literal want = LiteralUtil::CreateR1<float>({2.f, 4.f, 6.f});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, mapped));
}
4570 
// Maps a computation over an s16 vector: each element is doubled in s16 and
// then converted to f32, so {1, 2, 3} must become {2, 4, 6}.
TEST_F(HloEvaluatorTest, MapS16) {
  const absl::string_view hlo_text = R"(
  HloModule test

  map_computation {
    p = s16[] parameter(0)
    add = s16[] add(p, p)
    ROOT conv = f32[] convert(add)
  }

  ENTRY CopyStartCopyDone {
    c = s16[3] constant({1, 2, 3})
    ROOT map = f32[3] map(c), to_apply=map_computation
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));

  HloEvaluator evaluator;
  TF_ASSERT_OK_AND_ASSIGN(
      Literal mapped, evaluator.Evaluate(*m_->entry_computation(), {}));

  Literal want = LiteralUtil::CreateR1<float>({2.f, 4.f, 6.f});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, mapped));
}
4592 
// Maps a computation over a u16 vector: each element is doubled in u16 and
// then converted to f32, so {1, 2, 3} must become {2, 4, 6}.
TEST_F(HloEvaluatorTest, MapU16) {
  const absl::string_view hlo_text = R"(
  HloModule test

  map_computation {
    p = u16[] parameter(0)
    add = u16[] add(p, p)
    ROOT conv = f32[] convert(add)
  }

  ENTRY CopyStartCopyDone {
    c = u16[3] constant({1, 2, 3})
    ROOT map = f32[3] map(c), to_apply=map_computation
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));

  HloEvaluator evaluator;
  TF_ASSERT_OK_AND_ASSIGN(
      Literal mapped, evaluator.Evaluate(*m_->entry_computation(), {}));

  Literal want = LiteralUtil::CreateR1<float>({2.f, 4.f, 6.f});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, mapped));
}
4614 
// Dot of an s16[4,3] matrix with an s8[3,2] matrix producing an s32[4,2]
// result: the operand element types differ from each other and from the
// result type.
TEST_F(HloEvaluatorTest, DotUpcast) {
  const absl::string_view hlo_text = R"(
  HloModule test
  ENTRY DotUpcast {
    l = s16[4,3]{1,0} parameter(0)
    r = s8[3,2]{1,0} parameter(1)
    ROOT result = s32[4,2] dot(l, r), lhs_contracting_dims={1},
                                      rhs_contracting_dims={0}
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));

  // Left operand, s16[4,3]:
  //   {  1,  2,  3 }
  //   {  5,  6,  7 }
  //   {  9, 10, 11 }
  //   { 13, 14, 15 }
  Array2D<int16> lhs_values(4, 3);
  lhs_values.FillUnique(1);
  Literal lhs_literal = LiteralUtil::CreateR2FromArray2D<int16>(lhs_values);

  // Right operand, s8[3,2]:
  //   { 1, 2 }
  //   { 3, 4 }
  //   { 5, 6 }
  Array2D<int8> rhs_values(3, 2);
  rhs_values.FillUnique(1);
  Literal rhs_literal = LiteralUtil::CreateR2FromArray2D<int8>(rhs_values);

  TF_ASSERT_OK_AND_ASSIGN(Literal product,
                          Evaluate({&lhs_literal, &rhs_literal}));

  // Each entry is the row-by-column product of the matrices above,
  // e.g. 1*1 + 2*3 + 3*5 = 22.
  const Array2D<int32> expected_values(
      {{22, 28}, {58, 76}, {94, 124}, {130, 172}});
  Literal want = LiteralUtil::CreateR2FromArray2D<int32>(expected_values);

  EXPECT_TRUE(LiteralTestUtil::Equal(want, product));
}
4655 
// Sorts a c64 vector with a comparator that orders elements by their real
// part (less-than) and checks the result against {2, 4, 6}.
TEST_F(HloEvaluatorTest, SortC64) {
  const absl::string_view hlo_text = R"(
  HloModule m

  sort_lt_comparator {
    parameter.0 = c64[] parameter(0)
    real.0 = f32[] real(parameter.0)
    parameter.1 = c64[] parameter(1)
    real.1 = f32[] real(parameter.1)
    ROOT compare = pred[] compare(real.0, real.1), direction=LT
  }

  ENTRY main {
    c = c64[3] constant({(2, 0), (4, 0), (6, 0)})
    ROOT sort = c64[3]{0} sort(c), dimensions={0}, to_apply=sort_lt_comparator
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));

  HloEvaluator evaluator;
  TF_ASSERT_OK_AND_ASSIGN(
      Literal sorted, evaluator.Evaluate(*m_->entry_computation(), {}));

  Literal want = LiteralUtil::CreateR1<std::complex<float>>({2.f, 4.f, 6.f});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, sorted));
}
4680 
4681 }  // namespace
4682 }  // namespace xla
4683