1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/compiler/xla/service/hlo_evaluator.h"
16 
17 #include <initializer_list>
18 #include <memory>
19 #include <string>
20 #include <tuple>
21 #include <utility>
22 #include <vector>
23 
24 #include "absl/memory/memory.h"
25 #include "absl/strings/str_format.h"
26 #include "tensorflow/compiler/xla/client/xla_builder.h"
27 #include "tensorflow/compiler/xla/literal.h"
28 #include "tensorflow/compiler/xla/reference_util.h"
29 #include "tensorflow/compiler/xla/service/hlo_computation.h"
30 #include "tensorflow/compiler/xla/service/hlo_element_type_converter.h"
31 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
32 #include "tensorflow/compiler/xla/shape_util.h"
33 #include "tensorflow/compiler/xla/status.h"
34 #include "tensorflow/compiler/xla/status_macros.h"
35 #include "tensorflow/compiler/xla/statusor.h"
36 #include "tensorflow/compiler/xla/test.h"
37 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
38 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
39 #include "tensorflow/compiler/xla/tests/test_utils.h"
40 #include "tensorflow/compiler/xla/types.h"
41 #include "tensorflow/compiler/xla/util.h"
42 #include "tensorflow/compiler/xla/xla_data.pb.h"
43 #include "tensorflow/core/lib/core/status.h"
44 #include "tensorflow/core/lib/core/status_test_util.h"
45 #include "tensorflow/core/platform/test.h"
46 #include "tensorflow/core/platform/test_benchmark.h"
47 #include "tensorflow/core/platform/types.h"
48 
49 namespace xla {
50 namespace {
51 
// Parameter values fed to the bf16-parameterized test suite: one run with
// bf16 conversion enabled (true) and one with it disabled (false).
static std::array<bool, 2> use_bf16_params = {true, false};
53 
54 // Test fixture for the HloEvaluator.
55 //
56 // In bf16 mode, all f32 shapes are converted to bf16 before running.
57 class HloEvaluatorTest : public HloTestBase {
58  public:
HloEvaluatorTest()59   HloEvaluatorTest() : use_bfloat16_(false) {}
60 
Evaluate(absl::Span<const Literal * const> arg_literals={})61   StatusOr<Literal> Evaluate(
62       absl::Span<const Literal* const> arg_literals = {}) {
63     if (use_bfloat16_) {
64       HloElementTypeConverter(F32, BF16).Run(m_.get()).ValueOrDie();
65     }
66     return evaluator_.Evaluate(*m_->entry_computation(), arg_literals);
67   }
68 
69   // Evaluate function that takes in a local module instead of using m_
70   // that is in HloTestBase. Once m_ in HloTestBase is
71   // removed, this should be the default Evaluate function.
EvaluateWithModule(HloModule * module,absl::Span<const Literal * const> arg_literals={})72   Literal EvaluateWithModule(
73       HloModule* module, absl::Span<const Literal* const> arg_literals = {}) {
74     if (use_bfloat16_) {
75       HloElementTypeConverter(F32, BF16).Run(m_.get()).ValueOrDie();
76     }
77     return evaluator_.Evaluate(*module->entry_computation(), arg_literals)
78         .ConsumeValueOrDie();
79   }
80 
TestUnaryOp(HloOpcode opcode,Literal expected,Literal input,float aabs=0)81   void TestUnaryOp(HloOpcode opcode, Literal expected, Literal input,
82                    float aabs = 0) {
83     HloComputation::Builder b(TestName());
84     auto c1 =
85         b.AddInstruction(HloInstruction::CreateConstant(std::move(input)));
86     b.AddInstruction(HloInstruction::CreateUnary(expected.shape(), opcode, c1));
87     m_->AddEntryComputation(b.Build());
88 
89     TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
90 
91     auto element_type = expected.shape().element_type();
92     if (element_type == F32 || element_type == F64) {
93       ErrorSpec error(aabs);
94       EXPECT_TRUE(LiteralTestUtil::Near(expected, result, error));
95     } else {
96       EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
97     }
98   }
99 
TestBinaryOp(HloOpcode opcode,Literal expected,Literal lhs,Literal rhs)100   void TestBinaryOp(HloOpcode opcode, Literal expected, Literal lhs,
101                     Literal rhs) {
102     HloComputation::Builder b(TestName());
103     auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs)));
104     auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs)));
105     b.AddInstruction(
106         HloInstruction::CreateBinary(expected.shape(), opcode, c1, c2));
107     m_->AddEntryComputation(b.Build());
108 
109     TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
110 
111     EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
112   }
113 
TestTernaryOp(HloOpcode opcode,Literal expected,Literal src0,Literal src1,Literal src2)114   void TestTernaryOp(HloOpcode opcode, Literal expected, Literal src0,
115                      Literal src1, Literal src2) {
116     HloComputation::Builder b(TestName());
117     auto operand0 =
118         b.AddInstruction(HloInstruction::CreateConstant(std::move(src0)));
119     auto operand1 =
120         b.AddInstruction(HloInstruction::CreateConstant(std::move(src1)));
121     auto operand2 =
122         b.AddInstruction(HloInstruction::CreateConstant(std::move(src2)));
123     b.AddInstruction(HloInstruction::CreateTernary(
124         expected.shape(), opcode, operand0, operand1, operand2));
125     m_->AddEntryComputation(b.Build());
126 
127     TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
128 
129     EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
130   }
131 
132  protected:
HloEvaluatorTest(bool use_bfloat16)133   explicit HloEvaluatorTest(bool use_bfloat16) : use_bfloat16_(use_bfloat16) {}
134   HloEvaluator evaluator_;
135 
136   const bool use_bfloat16_;
137   std::unique_ptr<HloModule> m_ = CreateNewVerifiedModule();
138 };
139 
140 // Lets you write TEST_Ps that run twice, once with and once without bf16.
141 class HloEvaluatorBf16Test : public ::testing::WithParamInterface<bool>,
142                              public HloEvaluatorTest {
143  protected:
HloEvaluatorBf16Test()144   HloEvaluatorBf16Test() : HloEvaluatorTest(/*use_bfloat16=*/GetParam()) {}
145 };
146 
147 INSTANTIATE_TEST_SUITE_P(HloEvaluatorTest_Instantiation, HloEvaluatorBf16Test,
148                          ::testing::ValuesIn(use_bf16_params));
149 
150 // Verifies that HloEvaluator evaluates a HLO instruction that performs clamp
151 // with 3 operands.
TEST_P(HloEvaluatorBf16Test,DoesClamp)152 TEST_P(HloEvaluatorBf16Test, DoesClamp) {
153   auto low = LiteralUtil::CreateR2<float>({{0.f, 2.f}, {2.f, 4.f}});
154   auto value = LiteralUtil::CreateR2<float>({{0.f, 5.f}, {0.f, 4.f}});
155   auto high = LiteralUtil::CreateR2<float>({{2.f, 4.f}, {4.f, 4.f}});
156 
157   Shape shape = low.shape();
158   HloComputation::Builder b(TestName());
159   auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(low)));
160   auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(value)));
161   auto c3 = b.AddInstruction(HloInstruction::CreateConstant(std::move(high)));
162   b.AddInstruction(
163       HloInstruction::CreateTernary(shape, HloOpcode::kClamp, c1, c2, c3));
164   m_->AddEntryComputation(b.Build());
165 
166   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
167 
168   auto expected = LiteralUtil::CreateR2<float>({{0, 4}, {2, 4}});
169 
170   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
171 }
172 
173 // Verifies that clamping of int64 does not cause loss of precision
TEST_P(HloEvaluatorBf16Test,DoesClampInt64)174 TEST_P(HloEvaluatorBf16Test, DoesClampInt64) {
175   auto ones = [](int bits) { return (int64{1} << bits) - 1; };
176 
177   auto low =
178       LiteralUtil::CreateR2<int64>({{0, ones(54)}, {ones(54), ones(58)}});
179   auto value = LiteralUtil::CreateR2<int64>({{0, ones(56)}, {0, ones(58)}});
180   auto high = LiteralUtil::CreateR2<int64>(
181       {{ones(54), ones(55)}, {ones(56), ones(58)}});
182 
183   Shape shape = low.shape();
184   HloComputation::Builder b(TestName());
185   auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(low)));
186   auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(value)));
187   auto c3 = b.AddInstruction(HloInstruction::CreateConstant(std::move(high)));
188   b.AddInstruction(
189       HloInstruction::CreateTernary(shape, HloOpcode::kClamp, c1, c2, c3));
190   m_->AddEntryComputation(b.Build());
191 
192   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
193 
194   auto expected =
195       LiteralUtil::CreateR2<int64>({{0, ones(55)}, {ones(54), ones(58)}});
196 
197   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
198 }
199 
// Clamp with scalar low/high bounds and a rank-2 value operand. The
// DISABLED_ prefix keeps gtest from running it by default.
TEST_P(HloEvaluatorBf16Test, DISABLED_DoesClampSpecialBroadcast) {
  Literal lower_bound = LiteralUtil::CreateR0<float>(0.f);
  Literal operand = LiteralUtil::CreateR2<float>({{-1.f, 0.f}, {1.f, 2.f}});
  Literal upper_bound = LiteralUtil::CreateR0<float>(1.f);

  // The result takes the shape of the (non-scalar) value operand.
  const Shape result_shape = operand.shape();

  HloComputation::Builder builder(TestName());
  HloInstruction* low_instr = builder.AddInstruction(
      HloInstruction::CreateConstant(std::move(lower_bound)));
  HloInstruction* value_instr = builder.AddInstruction(
      HloInstruction::CreateConstant(std::move(operand)));
  HloInstruction* high_instr = builder.AddInstruction(
      HloInstruction::CreateConstant(std::move(upper_bound)));
  builder.AddInstruction(HloInstruction::CreateTernary(
      result_shape, HloOpcode::kClamp, low_instr, value_instr, high_instr));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  Literal expected = LiteralUtil::CreateR2<float>({{0, 0}, {1, 1}});
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
220 
221 // Verifies that HloEvaluator evaluates a HLO instruction that performs select
222 // with 3 operands.
TEST_P(HloEvaluatorBf16Test,DoesSelect)223 TEST_P(HloEvaluatorBf16Test, DoesSelect) {
224   auto pred = LiteralUtil::CreateR2<bool>({{true, false}, {false, true}});
225   auto on_true = LiteralUtil::CreateR2<float>({{2.f, 4.f}, {4.f, 4.f}});
226   auto on_false = LiteralUtil::CreateR2<float>({{0.f, 5.f}, {0.f, 4.f}});
227 
228   Shape shape = on_true.shape();
229   HloComputation::Builder b(TestName());
230   auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(pred)));
231   auto c2 =
232       b.AddInstruction(HloInstruction::CreateConstant(std::move(on_true)));
233   auto c3 =
234       b.AddInstruction(HloInstruction::CreateConstant(std::move(on_false)));
235   b.AddInstruction(
236       HloInstruction::CreateTernary(shape, HloOpcode::kSelect, c1, c2, c3));
237   m_->AddEntryComputation(b.Build());
238 
239   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({}));
240 
241   auto expected = LiteralUtil::CreateR2<float>({{2, 5}, {0, 4}});
242 
243   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
244 }
245 
246 // Verifies that HloEvaluator evaluates a HLO instruction that performs
247 // element-wise addition with 2 operands.
TEST_F(HloEvaluatorTest,DoesAdd)248 TEST_F(HloEvaluatorTest, DoesAdd) {
249   auto lhs = LiteralUtil::CreateR2<int64>({{1, 0}, {-100, 4}});
250   auto rhs = LiteralUtil::CreateR2<int64>({{2, 4}, {4, 4}});
251   auto expected = LiteralUtil::CreateR2<int64>({{3, 4}, {-96, 8}});
252   TestBinaryOp(HloOpcode::kAdd, std::move(expected), std::move(lhs),
253                std::move(rhs));
254 }
255 // Verifies that HloEvaluator evaluates a HLO instruction that performs
256 // element-wise and with 2 operands.
TEST_P(HloEvaluatorBf16Test,DoesAnd)257 TEST_P(HloEvaluatorBf16Test, DoesAnd) {
258   auto lhs = LiteralUtil::CreateR2<int64>({{1, 0}, {-100, 4}});
259   auto rhs = LiteralUtil::CreateR2<int64>({{2, 4}, {4, 4}});
260   auto expected = LiteralUtil::CreateR2<int64>({{0, 0}, {4, 4}});
261   TestBinaryOp(HloOpcode::kAnd, std::move(expected), std::move(lhs),
262                std::move(rhs));
263 }
264 // Verifies that HloEvaluator evaluates a HLO instruction that performs
265 // element-wise or with 2 operands.
TEST_F(HloEvaluatorTest,DoesOr)266 TEST_F(HloEvaluatorTest, DoesOr) {
267   auto lhs = LiteralUtil::CreateR2<int64>({{1, 0}, {-100, 4}});
268   auto rhs = LiteralUtil::CreateR2<int64>({{2, 4}, {4, 4}});
269   auto expected = LiteralUtil::CreateR2<int64>({{3, 4}, {-100, 4}});
270   TestBinaryOp(HloOpcode::kOr, std::move(expected), std::move(lhs),
271                std::move(rhs));
272 }
273 // Verifies that HloEvaluator evaluates a HLO instruction that performs
274 // element-wise or with 2 operands.
TEST_F(HloEvaluatorTest,DoesXor)275 TEST_F(HloEvaluatorTest, DoesXor) {
276   auto lhs = LiteralUtil::CreateR2<int64>({{1, 0}, {-100, 4}});
277   auto rhs = LiteralUtil::CreateR2<int64>({{2, 4}, {4, 4}});
278   auto expected = LiteralUtil::CreateR2<int64>({{3, 4}, {-104, 0}});
279   TestBinaryOp(HloOpcode::kXor, std::move(expected), std::move(lhs),
280                std::move(rhs));
281 }
282 // Verifies that HloEvaluator evaluates a HLO instruction that performs
283 // element-wise multiply with 2 operands.
TEST_F(HloEvaluatorTest,DoesMultiply)284 TEST_F(HloEvaluatorTest, DoesMultiply) {
285   auto lhs = LiteralUtil::CreateR2<int32>({{-1, 0}, {-100, 4}});
286   auto rhs = LiteralUtil::CreateR2<int32>(
287       {{std::numeric_limits<int32>::min(), 4}, {4, 4}});
288   auto expected = LiteralUtil::CreateR2<int32>(
289       {{std::numeric_limits<int32>::min(), 0}, {-400, 16}});
290   TestBinaryOp(HloOpcode::kMultiply, std::move(expected), std::move(lhs),
291                std::move(rhs));
292 }
293 // Verifies that HloEvaluator evaluates a HLO instruction that performs
294 // element-wise divide with 2 operands.
TEST_F(HloEvaluatorTest,DoesDivideInt64)295 TEST_F(HloEvaluatorTest, DoesDivideInt64) {
296   auto lhs = LiteralUtil::CreateR2<int64>({{1, 0}, {-100, 4}});
297   auto rhs = LiteralUtil::CreateR2<int64>({{2, 4}, {4, 4}});
298   auto expected = LiteralUtil::CreateR2<int64>({{0, 0}, {-25, 1}});
299   TestBinaryOp(HloOpcode::kDivide, std::move(expected), std::move(lhs),
300                std::move(rhs));
301 }
302 
// Checks kClamp on rank-1 int64 operands with large-magnitude values.
TEST_F(HloEvaluatorTest, DoesClampS64) {
  TestTernaryOp(
      HloOpcode::kClamp,
      /*expected=*/
      LiteralUtil::CreateR1<int64>({-6780561065411491190LL,
                                    6780561065411491190LL,
                                    3832151243857508051LL}),
      /*src0=*/  // low
      LiteralUtil::CreateR1<int64>({-8616761059752331528LL,
                                    6780561065411491190LL,
                                    -8616761059752331528LL}),
      /*src1=*/  // value
      LiteralUtil::CreateR1<int64>({-6780561065411491190LL,
                                    6780561065411491180LL,
                                    4241131823772864090LL}),
      /*src2=*/  // high
      LiteralUtil::CreateR1<int64>({-6780561065411491180LL,
                                    8616761059752331528LL,
                                    3832151243857508051LL}));
}
315 
// Checks element-wise kDivide on two rank-2 f64 constants.
TEST_P(HloEvaluatorBf16Test, DoesDivideDouble) {
  TestBinaryOp(
      HloOpcode::kDivide,
      /*expected=*/
      LiteralUtil::CreateR2<double>({{0.45454545454545453, 0}, {-25, 1}}),
      /*lhs=*/LiteralUtil::CreateR2<double>({{1.0, 0.0}, {-100.0, 4.0}}),
      /*rhs=*/LiteralUtil::CreateR2<double>({{2.2, 4.0}, {4.0, 4.0}}));
}
324 
325 // Verifies that HloEvaluator evaluates a HLO instruction that performs
326 // element-wise abs op with 1 operand.
TEST_F(HloEvaluatorTest,DoesAbsR2)327 TEST_F(HloEvaluatorTest, DoesAbsR2) {
328   auto operand = LiteralUtil::CreateR2<int64>({{1, -20}, {-100, 4}});
329   auto expected = LiteralUtil::CreateR2<int64>({{1, 20}, {100, 4}});
330   TestUnaryOp(HloOpcode::kAbs, std::move(expected), std::move(operand));
331 }
// Checks kAbs on an f32 scalar.
TEST_P(HloEvaluatorBf16Test, DoesAbsR0) {
  TestUnaryOp(HloOpcode::kAbs,
              /*expected=*/LiteralUtil::CreateR0<float>(1.0f),
              /*input=*/LiteralUtil::CreateR0<float>(-1.0f));
}
// Checks kAbs on an empty rank-1 f32 literal.
TEST_P(HloEvaluatorBf16Test, DoesAbsR1WithZeroSize) {
  TestUnaryOp(HloOpcode::kAbs,
              /*expected=*/LiteralUtil::CreateR1<float>({}),
              /*input=*/LiteralUtil::CreateR1<float>({}));
}
// Checks element-wise kNegate on a rank-2 int32 constant. The operand
// includes the minimum int32, whose expected negation is the minimum int32
// itself.
TEST_F(HloEvaluatorTest, DoesNegateR2) {
  // Take the limit from int32 (the literal's element type) on both sides
  // rather than mixing in plain `int`, whose width is platform-defined.
  constexpr int32 kInt32Min = std::numeric_limits<int32>::min();
  auto operand = LiteralUtil::CreateR2<int32>({{0, kInt32Min}, {-1, 4}});
  auto expected = LiteralUtil::CreateR2<int32>({{0, kInt32Min}, {1, -4}});
  TestUnaryOp(HloOpcode::kNegate, std::move(expected), std::move(operand));
}
// Checks element-wise kCos at 0, pi, -pi and 2*pi. A looser tolerance is
// used when the fixture runs in bf16 mode.
TEST_P(HloEvaluatorBf16Test, DoesCosR2) {
  const float tolerance = use_bfloat16_ ? 0.031250 : 9.5367431640625E-7;
  TestUnaryOp(
      HloOpcode::kCos,
      /*expected=*/LiteralUtil::CreateR2<float>({{1, -1}, {-1, 1}}),
      /*input=*/LiteralUtil::CreateR2<float>({{0, M_PI}, {-M_PI, 2 * M_PI}}),
      /*aabs=*/tolerance);
}
// Checks element-wise kSin at 0, pi, -pi and 2*pi. A looser tolerance is
// used when the fixture runs in bf16 mode.
TEST_P(HloEvaluatorBf16Test, DoesSinR2) {
  const float tolerance = use_bfloat16_ ? 0.031250 : 9.5367431640625E-7;
  TestUnaryOp(
      HloOpcode::kSin,
      /*expected=*/LiteralUtil::CreateR2<float>({{0, 0}, {0, 0}}),
      /*input=*/LiteralUtil::CreateR2<float>({{0, M_PI}, {-M_PI, 2 * M_PI}}),
      /*aabs=*/tolerance);
}
// Checks element-wise kNot on a rank-2 int32 constant covering 0, -1 and the
// int32 extremes.
TEST_F(HloEvaluatorTest, DoesNotR2) {
  // Take the limits from int32 (the literal's element type) rather than
  // plain `int`, whose width is platform-defined.
  constexpr int32 kInt32Min = std::numeric_limits<int32>::min();
  constexpr int32 kInt32Max = std::numeric_limits<int32>::max();
  auto operand =
      LiteralUtil::CreateR2<int32>({{0, kInt32Min}, {-1, kInt32Max}});
  auto expected =
      LiteralUtil::CreateR2<int32>({{-1, kInt32Max}, {0, kInt32Min}});
  TestUnaryOp(HloOpcode::kNot, std::move(expected), std::move(operand));
}
370 
// Checks that kReal extracts the real parts of a rank-1 complex128 literal.
TEST_F(HloEvaluatorTest, DoesRealC128) {
  TestUnaryOp(
      HloOpcode::kReal,
      /*expected=*/LiteralUtil::CreateR1<double>({1, -100}),
      /*input=*/LiteralUtil::CreateR1<complex128>({{1, 0}, {-100, 4}}));
}
376 
// Checks that kImag extracts the imaginary parts of a rank-1 complex128
// literal.
TEST_F(HloEvaluatorTest, DoesImagC128) {
  TestUnaryOp(
      HloOpcode::kImag,
      /*expected=*/LiteralUtil::CreateR1<double>({0, 4}),
      /*input=*/LiteralUtil::CreateR1<complex128>({{1, 0}, {-100, 4}}));
}
382 
383 // Verifies that HloEvaluator evaluates a HLO Computation with non-parameter nor
384 // constant operands.
TEST_F(HloEvaluatorTest,DoesTraverseInstructions)385 TEST_F(HloEvaluatorTest, DoesTraverseInstructions) {
386   auto lhs = LiteralUtil::CreateR2<int64>({{1, 0}, {-100, 4}});
387   auto rhs = LiteralUtil::CreateR2<int64>({{2, 4}, {4, 4}});
388   auto rhs2 = LiteralUtil::CreateR2<int64>({{1, -20}, {-100, 4}});
389   std::vector<const Literal*> args = {&lhs, &rhs, &rhs2};
390 
391   Shape shape = ShapeUtil::MakeShape(S64, {2, 2});
392 
393   HloComputation::Builder b(TestName());
394   auto param_lhs =
395       b.AddInstruction(HloInstruction::CreateParameter(0, shape, "lhs"));
396   auto param_rhs =
397       b.AddInstruction(HloInstruction::CreateParameter(1, shape, "rhs"));
398   auto lhs_instruction = b.AddInstruction(HloInstruction::CreateBinary(
399       shape, HloOpcode::kAdd, param_lhs, param_rhs));
400 
401   auto param_rhs2 =
402       b.AddInstruction(HloInstruction::CreateParameter(2, shape, "rhs2"));
403   b.AddInstruction(HloInstruction::CreateBinary(shape, HloOpcode::kAdd,
404                                                 lhs_instruction, param_rhs2));
405   m_->AddEntryComputation(b.Build());
406 
407   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate(args));
408 
409   auto expected = LiteralUtil::CreateR2<int64>({{4, -16}, {-196, 12}});
410 
411   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
412 }
413 
414 // Verifies Reshape operation is correctly evaluated.
TEST_F(HloEvaluatorTest,DoesReshape)415 TEST_F(HloEvaluatorTest, DoesReshape) {
416   HloComputation::Builder b(TestName());
417   const int64 dimensions[] = {11, 8, 7, 5, 9};
418   TF_ASSERT_OK_AND_ASSIGN(auto literal,
419                           LiteralUtil::CreateRandomLiteral<F32>(
420                               ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0));
421   auto literal_clone = literal.Clone();
422   HloInstruction* literal_instruction =
423       b.AddInstruction(HloInstruction::CreateConstant(std::move(literal)));
424 
425   Shape shape = ShapeUtil::MakeShape(F32, {8, 7, 11, 9, 5});
426   const int64 permutation[] = {1, 2, 0, 4, 3};
427   b.AddInstruction(
428       HloInstruction::CreateTranspose(shape, literal_instruction, permutation));
429   m_->AddEntryComputation(b.Build());
430 
431   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({}));
432 
433   using NativeT = typename primitive_util::PrimitiveTypeToNative<F32>::type;
434   result.EachCell<NativeT>([&](absl::Span<const int64> indices, NativeT value) {
435     std::vector<int64> rindexes = Permute(permutation, indices);
436     EXPECT_NEAR(value, literal_clone.Get<NativeT>(rindexes), 0.031250);
437   });
438 }
439 
440 // Verifies Broadcast operation is correctly evaluated.
TEST_F(HloEvaluatorTest,DoesBroadcast)441 TEST_F(HloEvaluatorTest, DoesBroadcast) {
442   HloComputation::Builder b(TestName());
443   auto input_literal = LiteralUtil::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}});
444   auto output_literal = LiteralUtil::CreateR3<int32>(
445       {{{1, 2}, {3, 4}, {5, 6}}, {{1, 2}, {3, 4}, {5, 6}}});
446   HloInstruction* literal_instruction = b.AddInstruction(
447       HloInstruction::CreateConstant(std::move(input_literal)));
448   b.AddInstruction(HloInstruction::CreateBroadcast(
449       output_literal.shape(), literal_instruction, {1, 2}));
450   m_->AddEntryComputation(b.Build());
451 
452   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({}));
453 
454   EXPECT_TRUE(LiteralTestUtil::Equal(result, output_literal));
455 }
456 
// Checks kBroadcast of an s32 scalar into an s32[6,2] result.
TEST_F(HloEvaluatorTest, DoesBroadcastScalar) {
  Literal scalar = LiteralUtil::CreateR0<int32>(111);
  Literal expected = LiteralUtil::CreateR2<int32>(
      {{111, 111}, {111, 111}, {111, 111}, {111, 111}, {111, 111}, {111, 111}});

  HloComputation::Builder builder(TestName());
  HloInstruction* scalar_instruction =
      builder.AddInstruction(HloInstruction::CreateConstant(std::move(scalar)));
  // A scalar operand has no dimensions to map, so the list stays empty.
  builder.AddInstruction(HloInstruction::CreateBroadcast(
      expected.shape(), scalar_instruction,
      /*broadcast_dimensions=*/{}));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({}));

  EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
}
475 
// Checks kConcatenate of two s64[2,2] constants along dimension 0.
TEST_F(HloEvaluatorTest, DoesConcatenateSimple) {
  HloComputation::Builder builder(TestName());

  HloInstruction* first = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<int64>({{-1, -2}, {100, 200}})));
  HloInstruction* second =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2<int64>({{-2, -3}, {-100, -200}})));
  std::vector<HloInstruction*> pieces = {first, second};

  const Shape concat_shape = ShapeUtil::MakeShape(S64, {4, 2});
  builder.AddInstruction(
      HloInstruction::CreateConcatenate(concat_shape, pieces, 0));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  Literal expected = LiteralUtil::CreateR2<int64>(
      {{-1, -2}, {100, 200}, {-2, -3}, {-100, -200}});
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
497 
// Checks kConcatenate when one of the rank-1 operands has no elements.
TEST_F(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) {
  HloComputation::Builder builder(TestName());

  HloInstruction* non_empty = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<int64>({100, 200})));
  HloInstruction* empty = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<int64>({})));
  std::vector<HloInstruction*> pieces = {non_empty, empty};

  const Shape concat_shape = ShapeUtil::MakeShape(S64, {2});
  builder.AddInstruction(
      HloInstruction::CreateConcatenate(concat_shape, pieces, 0));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  Literal expected = LiteralUtil::CreateR1<int64>({100, 200});
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
518 
// Checks kConvert from s32 to f32 when the operand and result shapes have
// equal layouts.
TEST_P(HloEvaluatorBf16Test, ConvertWithSameLayout) {
  Literal source = LiteralUtil::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}});
  Literal expected =
      LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}});
  // Precondition of this test: both shapes carry the same layout.
  ASSERT_TRUE(
      LayoutUtil::LayoutsInShapesEqual(source.shape(), expected.shape()));

  HloComputation::Builder builder(TestName());
  HloInstruction* source_instruction =
      builder.AddInstruction(HloInstruction::CreateConstant(std::move(source)));
  builder.AddInstruction(
      HloInstruction::CreateConvert(expected.shape(), source_instruction));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
}
537 
// Checks kConvert from s32 to f32 when the operand uses layout {0, 1} and
// the result uses layout {1, 0}.
TEST_P(HloEvaluatorBf16Test, ConvertWithDifferentLayout) {
  Literal source = LiteralUtil::CreateR2WithLayout<int32>(
      {{1, 2}, {3, 4}, {5, 6}}, LayoutUtil::MakeLayout({0, 1}));
  Literal expected = LiteralUtil::CreateR2WithLayout<float>(
      {{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}, LayoutUtil::MakeLayout({1, 0}));
  // Precondition of this test: the two shapes really differ in layout.
  ASSERT_FALSE(
      LayoutUtil::LayoutsInShapesEqual(source.shape(), expected.shape()));

  HloComputation::Builder builder(TestName());
  HloInstruction* source_instruction =
      builder.AddInstruction(HloInstruction::CreateConstant(std::move(source)));
  builder.AddInstruction(
      HloInstruction::CreateConvert(expected.shape(), source_instruction));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
}
557 
CreatePaddingConfig(std::initializer_list<std::array<int64,3>> padding_dimensions)558 PaddingConfig CreatePaddingConfig(
559     std::initializer_list<std::array<int64, 3>> padding_dimensions) {
560   PaddingConfig padding_config;
561 
562   for (auto& paddings_per_dim : padding_dimensions) {
563     auto dimension = padding_config.add_dimensions();
564     dimension->set_edge_padding_low(paddings_per_dim[0]);
565     dimension->set_edge_padding_high(paddings_per_dim[1]);
566     dimension->set_interior_padding(paddings_per_dim[2]);
567   }
568   return padding_config;
569 }
570 
// Checks kPad on a rank-2 s32 operand created from two empty rows: the
// s32[5,2] result is expected to hold only the padding value.
TEST_F(HloEvaluatorTest, Pad2DIntegerArrayWithZeroDimension) {
  constexpr int32 kPadValue = 10;

  HloComputation::Builder builder(TestName());
  HloInstruction* operand_instruction = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2<int32>({{}, {}})));
  HloInstruction* pad_value_instruction = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(kPadValue)));

  // Per dimension: {edge_padding_low, edge_padding_high, interior_padding}.
  PaddingConfig padding = CreatePaddingConfig({{{1, 0, 2}}, {{0, 2, 1}}});
  const Shape padded_shape = ShapeUtil::MakeShape(S32, {5, 2});
  builder.AddInstruction(HloInstruction::CreatePad(
      padded_shape, operand_instruction, pad_value_instruction, padding));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  Literal expected = LiteralUtil::CreateR2<int32>(
      {{10, 10}, {10, 10}, {10, 10}, {10, 10}, {10, 10}});
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
595 
// Checks kPad with edge and interior padding on the first two dimensions of
// an f32[3,2,1,1] operand, producing f32[8,5,1,1].
TEST_P(HloEvaluatorBf16Test, Pad4DFloatArrayWithInteriorPadding) {
  constexpr float kPadValue = 1.5;

  HloComputation::Builder builder(TestName());

  Array4D<float> input_values(3, 2, 1, 1, {1, 2, 3, 4, 5, 6});
  HloInstruction* input_instruction =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR4FromArray4D<float>(input_values)));
  HloInstruction* pad_instruction = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(kPadValue)));

  const Shape padded_shape = ShapeUtil::MakeShape(F32, {8, 5, 1, 1});
  // Per dimension: {edge_padding_low, edge_padding_high, interior_padding}.
  PaddingConfig padding =
      CreatePaddingConfig({{{1, 0, 2}}, {{0, 2, 1}}, {{0, 0, 0}}, {{0, 0, 0}}});
  builder.AddInstruction(HloInstruction::CreatePad(
      padded_shape, input_instruction, pad_instruction, padding));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  // Start from an all-padding array, then place the six input values at the
  // positions they are expected to occupy in the padded result.
  Array4D<float> expected_values(8, 5, 1, 1);
  expected_values.Fill(kPadValue);
  expected_values(1, 0, 0, 0) = 1.0f;
  expected_values(1, 2, 0, 0) = 2.0f;
  expected_values(4, 0, 0, 0) = 3.0f;
  expected_values(4, 2, 0, 0) = 4.0f;
  expected_values(7, 0, 0, 0) = 5.0f;
  expected_values(7, 2, 0, 0) = 6.0f;

  Literal expected = LiteralUtil::CreateR4FromArray4D<float>(expected_values);
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
630 
// Checks kPad with negative edge padding on an f32[4,3] operand, producing
// f32[1,5].
TEST_P(HloEvaluatorBf16Test, NegativePadding2D) {
  HloComputation::Builder builder(TestName());

  // input_values:
  // f32[4,3] {
  //  { 1, 2, 3 },
  //  { 5, 6, 7 },
  //  { 9, 10, 11 },
  //  { 13, 14, 15 },
  // }
  Array2D<float> input_values(4, 3);
  input_values.FillUnique(1.0f);
  HloInstruction* input_instruction =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2FromArray2D<float>(input_values)));

  HloInstruction* pad_value_instruction = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.718f)));

  // Per dimension: {edge_padding_low, edge_padding_high, interior_padding}.
  PaddingConfig padding = CreatePaddingConfig({{{-1, -2, 0}}, {{-2, 4, 0}}});
  const Shape padded_shape = ShapeUtil::MakeShape(F32, {1, 5});
  builder.AddInstruction(HloInstruction::CreatePad(
      padded_shape, input_instruction, pad_value_instruction, padding));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  // f32[1,5] { 7.0, 2.718, 2.718, 2.718, 2.718 }
  Array2D<float> expected_values(1, 5);
  expected_values(0, 0) = 7.0f;
  expected_values(0, 1) = 2.718f;
  expected_values(0, 2) = 2.718f;
  expected_values(0, 3) = 2.718f;
  expected_values(0, 4) = 2.718f;
  Literal expected = LiteralUtil::CreateR2FromArray2D<float>(expected_values);

  EXPECT_TRUE(LiteralTestUtil::Near(expected, result, ErrorSpec(0.031250)));
}
672 
// Evaluates a Pad that combines negative edge padding with interior padding
// and whose declared result shape, f32[0,9], has a zero-sized dimension.
// The evaluator is expected to produce an empty literal of that shape.
TEST_P(HloEvaluatorBf16Test, NegativeAndInteriorPadding2D) {
  HloComputation::Builder b(TestName());

  // f32[4,3] {
  //  { 1, 2, 3 },
  //  { 5, 6, 7 },
  //  { 9, 10, 11 },
  //  { 13, 14, 15 },
  // }
  auto input_array = absl::make_unique<Array2D<float>>(4, 3);
  input_array->FillUnique(1.0f);
  auto input = LiteralUtil::CreateR2FromArray2D<float>(*input_array);
  HloInstruction* input_instruction =
      b.AddInstruction(HloInstruction::CreateConstant(std::move(input)));

  auto pad_value_instruction = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.718f)));

  // Negative padding that results in zero dimensions.
  auto r2_padding_on_dim0_dim1 =
      CreatePaddingConfig({{{-2, -5, 1}}, {{-2, 4, 2}}});

  Shape shape = ShapeUtil::MakeShape(F32, {0, 9});
  // The pad is added last so that it is the final instruction in the builder.
  b.AddInstruction(HloInstruction::CreatePad(shape, input_instruction,
                                             pad_value_instruction,
                                             r2_padding_on_dim0_dim1));

  m_->AddEntryComputation(b.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  // An empty 0x9 array: there are no elements to fill in.
  auto expected_array = absl::make_unique<Array2D<float>>(0, 9);
  auto expected = LiteralUtil::CreateR2FromArray2D<float>(*expected_array);

  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
711 
// Evaluates a Dot of a 4x1 lhs with a 1x2 rhs, contracting lhs dimension 1
// against rhs dimension 0, and checks the 4x2 result.
TEST_P(HloEvaluatorBf16Test, DotRank2AndRank1) {
  HloComputation::Builder builder(TestName());

  // lhs:
  // f32[4,1] {
  //  { 1 },
  //  { 2 },
  //  { 3 },
  //  { 4 },
  // }
  Array2D<float> lhs_values(4, 1);
  lhs_values.FillUnique(1.0f);
  HloInstruction* lhs = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2FromArray2D<float>(lhs_values)));

  // rhs (built with CreateR2, so it is a rank-2 literal with a single row):
  // f32[1,2] {
  //  { 1, 2 },
  // }
  HloInstruction* rhs = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2<float>({{1, 2}})));

  DotDimensionNumbers contraction;
  contraction.add_rhs_contracting_dimensions(0);
  contraction.add_lhs_contracting_dimensions(1);

  const Shape result_shape = ShapeUtil::MakeShape(F32, {4, 2});
  builder.AddInstruction(HloInstruction::CreateDot(
      result_shape, lhs, rhs, contraction, DefaultPrecisionConfig(2)));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  // Row i of the expectation is lhs[i] * {1, 2}.
  // clang-format off
  const Array2D<float> expected_values({
      {1.f, 2.f},
      {2.f, 4.f},
      {3.f, 6.f},
      {4.f, 8.f},
  });
  // clang-format on
  const Literal expected =
      LiteralUtil::CreateR2FromArray2D<float>(expected_values);

  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
757 
// Evaluates a Dot of a rank-1 lhs with a rank-2 rhs, contracting dimension 0
// of both operands, and checks the rank-1 result.
TEST_P(HloEvaluatorBf16Test, DotRank1AndRank2) {
  HloComputation::Builder builder(TestName());

  // lhs:
  // f32[3]
  //  { 1, 2, 3 },
  HloInstruction* lhs = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({1, 2, 3})));

  // rhs:
  // f32[3,2] {
  //  { 1, 2 },
  //  { 3, 4 },
  //  { 5, 6 },
  // }
  Array2D<float> rhs_values(3, 2);
  rhs_values.FillUnique(1.0f);
  HloInstruction* rhs = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2FromArray2D<float>(rhs_values)));

  DotDimensionNumbers contraction;
  contraction.add_rhs_contracting_dimensions(0);
  contraction.add_lhs_contracting_dimensions(0);

  const Shape result_shape = ShapeUtil::MakeShape(F32, {2});
  builder.AddInstruction(HloInstruction::CreateDot(
      result_shape, lhs, rhs, contraction, DefaultPrecisionConfig(2)));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  // 1*1 + 2*3 + 3*5 = 22 and 1*2 + 2*4 + 3*6 = 28.
  const Literal expected = LiteralUtil::CreateR1<float>({22.f, 28.f});

  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
795 
// Evaluates a Dot of a 4x3 lhs with a 3x2 rhs, contracting lhs dimension 1
// against rhs dimension 0, and checks the 4x2 result.
TEST_P(HloEvaluatorBf16Test, DotRank2AndRank2) {
  HloComputation::Builder builder(TestName());

  // lhs:
  // f32[4,3] {
  //  { 1, 2, 3 },
  //  { 5, 6, 7 },
  //  { 9, 10, 11 },
  //  { 13, 14, 15 },
  // }
  Array2D<float> lhs_values(4, 3);
  lhs_values.FillUnique(1.0f);
  HloInstruction* lhs = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2FromArray2D<float>(lhs_values)));

  // rhs:
  // f32[3,2] {
  //  { 1, 2 },
  //  { 3, 4 },
  //  { 5, 6 },
  // }
  Array2D<float> rhs_values(3, 2);
  rhs_values.FillUnique(1.0f);
  HloInstruction* rhs = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2FromArray2D<float>(rhs_values)));

  DotDimensionNumbers contraction;
  contraction.add_rhs_contracting_dimensions(0);
  contraction.add_lhs_contracting_dimensions(1);

  const Shape result_shape = ShapeUtil::MakeShape(F32, {4, 2});
  builder.AddInstruction(HloInstruction::CreateDot(
      result_shape, lhs, rhs, contraction, DefaultPrecisionConfig(2)));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  // Each entry is the inner product of an lhs row with an rhs column,
  // e.g. 1*1 + 2*3 + 3*5 = 22.
  const Array2D<float> expected_values({
      {22.f, 28.f},
      {58.f, 76.f},
      {94.f, 124.f},
      {130.f, 172.f},
  });
  const Literal expected =
      LiteralUtil::CreateR2FromArray2D<float>(expected_values);

  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
845 
// Evaluates a Dot of two rank-4 operands with one batch dimension (0) and two
// contracting dimensions (1 and 2) and checks the f32[2,1,1] result.
TEST_P(HloEvaluatorBf16Test, DotRank4AndRank4) {
  HloComputation::Builder builder(TestName());

  // Both operands are f32[2,2,3,1]; they are filled via FillIota starting at
  // 1 (lhs) and 2 (rhs) respectively.
  Array4D<float> lhs_values(2, 2, 3, 1);
  lhs_values.FillIota(1.0f);
  HloInstruction* lhs = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR4FromArray4D<float>(lhs_values)));

  Array4D<float> rhs_values(2, 2, 3, 1);
  rhs_values.FillIota(2.0f);
  HloInstruction* rhs = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR4FromArray4D<float>(rhs_values)));

  DotDimensionNumbers dims;
  dims.add_lhs_batch_dimensions(0);
  dims.add_lhs_contracting_dimensions(1);
  dims.add_lhs_contracting_dimensions(2);
  dims.add_rhs_batch_dimensions(0);
  dims.add_rhs_contracting_dimensions(1);
  dims.add_rhs_contracting_dimensions(2);

  const Shape result_shape = ShapeUtil::MakeShape(F32, {2, 1, 1});
  builder.AddInstruction(HloInstruction::CreateDot(
      result_shape, lhs, rhs, dims, DefaultPrecisionConfig(2)));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  // Sums k*k + k (i.e. k * (k + 1)) for k in [first, last), accumulating in
  // float exactly as the products are formed.
  const auto sum_of_products = [](int first, int last) {
    float total = 0;
    for (int k = first; k < last; ++k) {
      const float value = static_cast<float>(k);
      total += value * value + value;
    }
    return total;
  };
  const float batch0_expected = sum_of_products(1, 7);
  const float batch1_expected = sum_of_products(7, 13);

  const Array3D<float> expected_values(
      {{{batch0_expected}}, {{batch1_expected}}});
  const Literal expected =
      LiteralUtil::CreateR3FromArray3D<float>(expected_values);

  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
890 
// Evaluates a 1D convolution of a 3-element input with a 2-element kernel,
// using one element of high padding so the output keeps 3 elements.
TEST_P(HloEvaluatorBf16Test, SimpleConv1D) {
  HloComputation::Builder builder(TestName());

  const Array3D<float> input_values({{{1, 2, 3}}});
  HloInstruction* input = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR3FromArray3D<float>(input_values)));

  const Array3D<float> kernel_values({{{3.f, 4.f}}});
  HloInstruction* kernel =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR3FromArray3D<float>(kernel_values)));

  // Dimension layout: batch = 0, feature = 1, spatial = 2 for input and
  // output; output feature = 0, input feature = 1, spatial = 2 for the kernel.
  ConvolutionDimensionNumbers dnums;
  dnums.set_input_batch_dimension(0);
  dnums.set_input_feature_dimension(1);
  dnums.add_input_spatial_dimensions(2);
  dnums.set_output_batch_dimension(0);
  dnums.set_output_feature_dimension(1);
  dnums.add_output_spatial_dimensions(2);
  dnums.set_kernel_output_feature_dimension(0);
  dnums.set_kernel_input_feature_dimension(1);
  dnums.add_kernel_spatial_dimensions(2);

  // Single spatial window dimension: size 2, stride 1, padding {0, 1}, no
  // dilation.
  Window window;
  WindowDimension* spatial = window.add_dimensions();
  spatial->set_size(2);
  spatial->set_stride(1);
  spatial->set_padding_low(0);
  spatial->set_padding_high(1);
  spatial->set_window_dilation(1);
  spatial->set_base_dilation(1);

  const Shape result_shape = ShapeUtil::MakeShape(F32, {1, 1, 3});
  builder.AddInstruction(HloInstruction::CreateConvolve(
      result_shape, input, kernel, /*feature_group_count=*/1,
      /*batch_group_count=*/1, window, dnums, DefaultPrecisionConfig(2)));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  // 1*3 + 2*4 = 11, 2*3 + 3*4 = 18, 3*3 + pad = 9.
  const Array3D<float> expected_values({{{11.f, 18.f, 9.f}}});
  const Literal expected =
      LiteralUtil::CreateR3FromArray3D<float>(expected_values);

  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
939 
TEST_P(HloEvaluatorBf16Test,Simple4x4Conv2DWith2x2Kernel)940 TEST_P(HloEvaluatorBf16Test, Simple4x4Conv2DWith2x2Kernel) {
941   HloComputation::Builder b(TestName());
942 
943   Array4D<float> lhs_array(1, 1, 4, 4);
944   // clang-format off
945   lhs_array.FillWithYX(Array2D<float>({
946     {1,  2,  3,  4 },
947     {5,  6,  7,  8 },
948     {9,  10, 11, 12},
949     {13, 14, 15, 16},
950   }));
951   // clang-format on
952   auto lhs_literal = LiteralUtil::CreateR4FromArray4D<float>(lhs_array);
953   HloInstruction* lhs_instruction =
954       b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal)));
955 
956   Array4D<float> rhs_array(1, 1, 2, 2);
957   // clang-format off
958   rhs_array.FillWithYX(Array2D<float>({
959     {5, 6},
960     {7, 8},
961   }));
962   // clang-format on
963   auto rhs_literal = LiteralUtil::CreateR4FromArray4D<float>(rhs_array);
964   HloInstruction* rhs_instruction =
965       b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal)));
966 
967   Window window;
968   WindowDimension dim;
969   dim.set_size(2);
970   dim.set_stride(1);
971   dim.set_padding_low(0);
972   dim.set_padding_high(1);
973   dim.set_window_dilation(1);
974   dim.set_base_dilation(1);
975   *window.add_dimensions() = dim;
976   *window.add_dimensions() = dim;
977 
978   ConvolutionDimensionNumbers dnums =
979       XlaBuilder::CreateDefaultConvDimensionNumbers(2);
980 
981   Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4});
982   b.AddInstruction(HloInstruction::CreateConvolve(
983       shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1,
984       /*batch_group_count=*/1, window, dnums, DefaultPrecisionConfig(2)));
985   m_->AddEntryComputation(b.Build());
986 
987   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
988 
989   Array4D<float> expected_array(1, 1, 4, 4);
990   // clang-format off
991   expected_array.FillWithYX(Array2D<float>({
992     {100, 126, 152,  76},
993     {204, 230, 256, 124},
994     {308, 334, 360, 172},
995     {149, 160, 171,  80},
996   }));
997   // clang-format on
998   auto expected = LiteralUtil::CreateR4FromArray4D<float>(expected_array);
999 
1000   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
1001 }
1002 
TEST_P(HloEvaluatorBf16Test,Conv2DGeneralDimensionsReversed)1003 TEST_P(HloEvaluatorBf16Test, Conv2DGeneralDimensionsReversed) {
1004   HloComputation::Builder b(TestName());
1005 
1006   // clang-format off
1007   // Input dimensions: [feature=2, height=3, batch=1, width=4]
1008   Array4D<float> input({
1009     {{{1, 2, 3, 4}},
1010      {{5, 6, 7, 8}},
1011      {{9, 10, 11, 12}}},
1012     {{{13, 14, 15, 16}},
1013      {{17, 18, 19, 20}},
1014      {{21, 22, 23, 24}}}
1015   });
1016   // Weight dimensions:
1017   // [kernel_output_feature=1, width=3, kernel_input_feature=2, height=3]
1018   Array4D<float> weight({{
1019     {{1, 7, 13},
1020      {4, 10, 16}},
1021     {{2, 8, 14},
1022      {5, 11, 17}},
1023     {{3, 9, 15},
1024      {6, 12, 18}}
1025   }});
1026   // clang-format on
1027 
1028   auto lhs_literal = LiteralUtil::CreateR4FromArray4D<float>(input);
1029   HloInstruction* lhs_instruction =
1030       b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal)));
1031 
1032   auto rhs_literal = LiteralUtil::CreateR4FromArray4D<float>(weight);
1033   HloInstruction* rhs_instruction =
1034       b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal)));
1035   rhs_instruction = b.AddInstruction(HloInstruction::CreateReverse(
1036       rhs_instruction->shape(), rhs_instruction, {3, 1}));
1037 
1038   Window window;
1039   WindowDimension dim;
1040   dim.set_size(3);
1041   dim.set_stride(1);
1042   dim.set_padding_low(0);
1043   dim.set_padding_high(0);
1044   dim.set_window_dilation(1);
1045   dim.set_base_dilation(1);
1046   dim.set_window_reversal(true);
1047   *window.add_dimensions() = dim;
1048   *window.add_dimensions() = dim;
1049 
1050   ConvolutionDimensionNumbers dnums;
1051   dnums.set_input_batch_dimension(2);
1052   dnums.set_output_batch_dimension(2);
1053   dnums.set_input_feature_dimension(0);
1054   dnums.set_output_feature_dimension(0);
1055   dnums.add_input_spatial_dimensions(1);
1056   dnums.add_output_spatial_dimensions(1);
1057   dnums.add_input_spatial_dimensions(3);
1058   dnums.add_output_spatial_dimensions(3);
1059 
1060   dnums.set_kernel_output_feature_dimension(0);
1061   dnums.set_kernel_input_feature_dimension(2);
1062   dnums.add_kernel_spatial_dimensions(3);
1063   dnums.add_kernel_spatial_dimensions(1);
1064 
1065   Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2});
1066   b.AddInstruction(HloInstruction::CreateConvolve(
1067       shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1,
1068       /*batch_group_count=*/1, window, dnums, DefaultPrecisionConfig(2)));
1069   m_->AddEntryComputation(b.Build());
1070 
1071   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
1072 
1073   // clang-format off
1074   // Result dimensions: [feature=1, height=1, batch=1, width=2]
1075   Array4D<float> expected_array({{{{2514, 2685}}}});
1076   Array4D<float> expected_array_bf16({{{{2512, 2688}}}});
1077   // clang-format on
1078   auto expected = LiteralUtil::CreateR4FromArray4D<float>(
1079       use_bfloat16_ ? expected_array_bf16 : expected_array);
1080 
1081   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
1082 }
1083 
TEST_P(HloEvaluatorBf16Test,Conv2DGeneralDimensions)1084 TEST_P(HloEvaluatorBf16Test, Conv2DGeneralDimensions) {
1085   HloComputation::Builder b(TestName());
1086 
1087   // clang-format off
1088   // Input dimensions: [feature=2, height=3, batch=1, width=4]
1089   Array4D<float> input({
1090     {{{1, 2, 3, 4}},
1091      {{5, 6, 7, 8}},
1092      {{9, 10, 11, 12}}},
1093     {{{13, 14, 15, 16}},
1094      {{17, 18, 19, 20}},
1095      {{21, 22, 23, 24}}}
1096   });
1097   // Weight dimensions:
1098   // [kernel_output_feature=1, width=3, kernel_input_feature=2, height=3]
1099   Array4D<float> weight({{
1100     {{1, 7, 13},
1101      {4, 10, 16}},
1102     {{2, 8, 14},
1103      {5, 11, 17}},
1104     {{3, 9, 15},
1105      {6, 12, 18}}
1106   }});
1107   // clang-format on
1108 
1109   auto lhs_literal = LiteralUtil::CreateR4FromArray4D<float>(input);
1110   HloInstruction* lhs_instruction =
1111       b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal)));
1112 
1113   auto rhs_literal = LiteralUtil::CreateR4FromArray4D<float>(weight);
1114   HloInstruction* rhs_instruction =
1115       b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal)));
1116 
1117   Window window;
1118   WindowDimension dim;
1119   dim.set_size(3);
1120   dim.set_stride(1);
1121   dim.set_padding_low(0);
1122   dim.set_padding_high(0);
1123   dim.set_window_dilation(1);
1124   dim.set_base_dilation(1);
1125   *window.add_dimensions() = dim;
1126   *window.add_dimensions() = dim;
1127 
1128   ConvolutionDimensionNumbers dnums;
1129   dnums.set_input_batch_dimension(2);
1130   dnums.set_output_batch_dimension(2);
1131   dnums.set_input_feature_dimension(0);
1132   dnums.set_output_feature_dimension(0);
1133   dnums.add_input_spatial_dimensions(1);
1134   dnums.add_output_spatial_dimensions(1);
1135   dnums.add_input_spatial_dimensions(3);
1136   dnums.add_output_spatial_dimensions(3);
1137 
1138   dnums.set_kernel_output_feature_dimension(0);
1139   dnums.set_kernel_input_feature_dimension(2);
1140   dnums.add_kernel_spatial_dimensions(3);
1141   dnums.add_kernel_spatial_dimensions(1);
1142 
1143   Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2});
1144   b.AddInstruction(HloInstruction::CreateConvolve(
1145       shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1,
1146       /*batch_group_count=*/1, window, dnums, DefaultPrecisionConfig(2)));
1147   m_->AddEntryComputation(b.Build());
1148 
1149   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
1150 
1151   // clang-format off
1152   // Result dimensions: [feature=1, height=1, batch=1, width=2]
1153   Array4D<float> expected_array({{{{2514, 2685}}}});
1154   Array4D<float> expected_array_bf16({{{{2512, 2688}}}});
1155   // clang-format on
1156   auto expected = LiteralUtil::CreateR4FromArray4D<float>(
1157       use_bfloat16_ ? expected_array_bf16 : expected_array);
1158 
1159   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
1160 }
1161 
TEST_P(HloEvaluatorBf16Test,DilatedBaseConv2DWithHighPadding)1162 TEST_P(HloEvaluatorBf16Test, DilatedBaseConv2DWithHighPadding) {
1163   HloComputation::Builder b(TestName());
1164 
1165   Array4D<float> lhs_array(1, 1, 4, 4);
1166   // clang-format off
1167   lhs_array.FillWithYX(Array2D<float>({
1168     {1,  2,  3,  4 },
1169     {5,  6,  7,  8 },
1170     {9,  10, 11, 12},
1171     {13, 14, 15, 16},
1172   }));
1173   // clang-format on
1174   auto lhs_literal = LiteralUtil::CreateR4FromArray4D<float>(lhs_array);
1175   HloInstruction* lhs_instruction =
1176       b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal)));
1177 
1178   Array4D<float> rhs_array(1, 1, 2, 2);
1179   // clang-format off
1180   rhs_array.FillWithYX(Array2D<float>({
1181     {5, 6},
1182     {7, 8},
1183   }));
1184   // clang-format on
1185   auto rhs_literal = LiteralUtil::CreateR4FromArray4D<float>(rhs_array);
1186   HloInstruction* rhs_instruction =
1187       b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal)));
1188 
1189   Window window;
1190   WindowDimension dim;
1191   dim.set_size(2);
1192   dim.set_stride(1);
1193   dim.set_padding_low(0);
1194   dim.set_padding_high(1);
1195   dim.set_window_dilation(1);
1196   dim.set_base_dilation(2);
1197   *window.add_dimensions() = dim;
1198   *window.add_dimensions() = dim;
1199 
1200   ConvolutionDimensionNumbers dnums =
1201       XlaBuilder::CreateDefaultConvDimensionNumbers(2);
1202 
1203   Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 7, 7});
1204   b.AddInstruction(HloInstruction::CreateConvolve(
1205       shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1,
1206       /*batch_group_count=*/1, window, dnums, DefaultPrecisionConfig(2)));
1207   m_->AddEntryComputation(b.Build());
1208 
1209   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
1210 
1211   Array4D<float> expected_array(1, 1, 7, 7);
1212   expected_array.FillWithYX(Array2D<float>({
1213       {5, 12, 10, 18, 15, 24, 20},
1214       {35, 48, 42, 56, 49, 64, 56},
1215       {25, 36, 30, 42, 35, 48, 40},
1216       {63, 80, 70, 88, 77, 96, 84},
1217       {45, 60, 50, 66, 55, 72, 60},
1218       {91, 112, 98, 120, 105, 128, 112},
1219       {65, 84, 70, 90, 75, 96, 80},
1220   }));
1221   auto expected = LiteralUtil::CreateR4FromArray4D<float>(expected_array);
1222 
1223   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
1224 }
1225 
TEST_P(HloEvaluatorBf16Test,DilatedBaseConv2DWithLowAndHighPadding)1226 TEST_P(HloEvaluatorBf16Test, DilatedBaseConv2DWithLowAndHighPadding) {
1227   HloComputation::Builder b(TestName());
1228 
1229   Array4D<float> lhs_array(1, 1, 4, 4);
1230   // clang-format off
1231   lhs_array.FillWithYX(Array2D<float>({
1232     {1,  2,  3,  4 },
1233     {5,  6,  7,  8 },
1234     {9,  10, 11, 12},
1235     {13, 14, 15, 16},
1236   }));
1237   // clang-format on
1238   auto lhs_literal = LiteralUtil::CreateR4FromArray4D<float>(lhs_array);
1239   HloInstruction* lhs_instruction =
1240       b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal)));
1241 
1242   Array4D<float> rhs_array(1, 1, 2, 2);
1243   // clang-format off
1244   rhs_array.FillWithYX(Array2D<float>({
1245     {5, 6},
1246     {7, 8},
1247   }));
1248   // clang-format on
1249   auto rhs_literal = LiteralUtil::CreateR4FromArray4D<float>(rhs_array);
1250   HloInstruction* rhs_instruction =
1251       b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal)));
1252 
1253   Window window;
1254   WindowDimension dim;
1255   dim.set_size(2);
1256   dim.set_stride(1);
1257   dim.set_padding_low(1);
1258   dim.set_padding_high(1);
1259   dim.set_window_dilation(1);
1260   dim.set_base_dilation(2);
1261   *window.add_dimensions() = dim;
1262   *window.add_dimensions() = dim;
1263 
1264   ConvolutionDimensionNumbers dnums =
1265       XlaBuilder::CreateDefaultConvDimensionNumbers(2);
1266 
1267   Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 8, 8});
1268   b.AddInstruction(HloInstruction::CreateConvolve(
1269       shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1,
1270       /*batch_group_count=*/1, window, dnums, DefaultPrecisionConfig(2)));
1271   m_->AddEntryComputation(b.Build());
1272 
1273   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
1274 
1275   Array4D<float> expected_array(1, 1, 8, 8);
1276   expected_array.FillWithYX(Array2D<float>({
1277       {8, 7, 16, 14, 24, 21, 32, 28},
1278       {6, 5, 12, 10, 18, 15, 24, 20},
1279       {40, 35, 48, 42, 56, 49, 64, 56},
1280       {30, 25, 36, 30, 42, 35, 48, 40},
1281       {72, 63, 80, 70, 88, 77, 96, 84},
1282       {54, 45, 60, 50, 66, 55, 72, 60},
1283       {104, 91, 112, 98, 120, 105, 128, 112},
1284       {78, 65, 84, 70, 90, 75, 96, 80},
1285   }));
1286   auto expected = LiteralUtil::CreateR4FromArray4D<float>(expected_array);
1287 
1288   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
1289 }
1290 
TEST_P(HloEvaluatorBf16Test,DilatedWindowAndBaseConv2DWithDifferentLowAndHighPaddingAndStrides)1291 TEST_P(HloEvaluatorBf16Test,
1292        DilatedWindowAndBaseConv2DWithDifferentLowAndHighPaddingAndStrides) {
1293   HloComputation::Builder b(TestName());
1294 
1295   Array4D<float> lhs_array(1, 1, 4, 4);
1296   // clang-format off
1297   lhs_array.FillWithYX(Array2D<float>({
1298     {1,  2,  3,  4 },
1299     {5,  6,  7,  8 },
1300     {9,  10, 11, 12},
1301     {13, 14, 15, 16},
1302   }));
1303   // clang-format on
1304   auto lhs_literal = LiteralUtil::CreateR4FromArray4D<float>(lhs_array);
1305   HloInstruction* lhs_instruction =
1306       b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal)));
1307 
1308   Array4D<float> rhs_array(1, 1, 2, 3);
1309   // clang-format off
1310   rhs_array.FillWithYX(Array2D<float>({
1311     {5, 6, 7},
1312     {8, 9, 10},
1313   }));
1314   // clang-format on
1315   auto rhs_literal = LiteralUtil::CreateR4FromArray4D<float>(rhs_array);
1316   HloInstruction* rhs_instruction =
1317       b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal)));
1318 
1319   Window window;
1320   WindowDimension dim;
1321   dim.set_size(2);
1322   dim.set_stride(1);
1323   dim.set_padding_low(2);
1324   dim.set_padding_high(2);
1325   dim.set_window_dilation(2);
1326   dim.set_base_dilation(2);
1327   *window.add_dimensions() = dim;
1328   dim.set_size(3);
1329   dim.set_stride(3);
1330   dim.set_padding_low(2);
1331   dim.set_padding_high(-1);
1332   dim.set_window_dilation(1);
1333   dim.set_base_dilation(3);
1334   *window.add_dimensions() = dim;
1335 
1336   ConvolutionDimensionNumbers dnums =
1337       XlaBuilder::CreateDefaultConvDimensionNumbers(2);
1338 
1339   Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 9, 3});
1340   b.AddInstruction(HloInstruction::CreateConvolve(
1341       shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1,
1342       /*batch_group_count=*/1, window, dnums, DefaultPrecisionConfig(2)));
1343   m_->AddEntryComputation(b.Build());
1344 
1345   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
1346 
1347   Array4D<float> expected_array(1, 1, 9, 3);
1348   expected_array.FillWithYX(Array2D<float>({
1349       {10, 20, 30},
1350       {0, 0, 0},
1351       {57, 74, 91},
1352       {0, 0, 0},
1353       {125, 142, 159},
1354       {0, 0, 0},
1355       {193, 210, 227},
1356       {0, 0, 0},
1357       {91, 98, 105},
1358   }));
1359   auto expected = LiteralUtil::CreateR4FromArray4D<float>(expected_array);
1360 
1361   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
1362 }
1363 
TEST_P(HloEvaluatorBf16Test,Conv2DGroupedConvolution)1364 TEST_P(HloEvaluatorBf16Test, Conv2DGroupedConvolution) {
1365   HloComputation::Builder b(TestName());
1366   std::vector<int64> input_dims = {1, 2, 2, 4};
1367   std::vector<int64> filter_dims = {2, 2, 2, 8};
1368   Shape input_shape = ShapeUtil::MakeShapeWithType<float>(input_dims);
1369   Shape filter_shape = ShapeUtil::MakeShapeWithType<float>(filter_dims);
1370   // Tensorflow dimension numbers for 2D convolution.
1371   ConvolutionDimensionNumbers dnums;
1372   dnums.set_input_batch_dimension(0);
1373   dnums.set_output_batch_dimension(0);
1374   dnums.add_input_spatial_dimensions(1);
1375   dnums.add_output_spatial_dimensions(1);
1376   dnums.add_input_spatial_dimensions(2);
1377   dnums.add_output_spatial_dimensions(2);
1378   dnums.set_input_feature_dimension(3);
1379   dnums.set_output_feature_dimension(3);
1380   dnums.add_kernel_spatial_dimensions(0);
1381   dnums.add_kernel_spatial_dimensions(1);
1382   dnums.set_kernel_input_feature_dimension(2);
1383   dnums.set_kernel_output_feature_dimension(3);
1384 
1385   Window window;
1386   WindowDimension dim;
1387   dim.set_size(2);
1388   dim.set_stride(1);
1389   dim.set_padding_low(0);
1390   dim.set_padding_high(0);
1391   dim.set_window_dilation(1);
1392   dim.set_base_dilation(1);
1393   *window.add_dimensions() = dim;
1394   *window.add_dimensions() = dim;
1395 
1396   std::vector<float> input_elems(ShapeUtil::ElementsIn(input_shape));
1397   std::iota(input_elems.begin(), input_elems.end(), -7);
1398   auto input_r1 = LiteralUtil::CreateR1<float>(input_elems);
1399   auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
1400   HloInstruction* lhs_instruction =
1401       b.AddInstruction(HloInstruction::CreateConstant(std::move(input_r4)));
1402 
1403   std::vector<float> filter_elems(ShapeUtil::ElementsIn(filter_shape));
1404   std::iota(filter_elems.begin(), filter_elems.end(), -31);
1405   auto filter_r1 = LiteralUtil::CreateR1<float>(filter_elems);
1406   auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
1407   HloInstruction* rhs_instruction =
1408       b.AddInstruction(HloInstruction::CreateConstant(std::move(filter_r4)));
1409 
1410   Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 8});
1411   b.AddInstruction(HloInstruction::CreateConvolve(
1412       shape, lhs_instruction, rhs_instruction,
1413       /*feature_group_count=*/2, /*batch_group_count=*/1, window, dnums,
1414       DefaultPrecisionConfig(2)));
1415   m_->AddEntryComputation(b.Build());
1416 
1417   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
1418 
1419   Array4D<float> expected_array(1, 1, 1, 8);
1420   expected_array.FillWithYX(
1421       Array2D<float>({{668, 664, 660, 656, 668, 680, 692, 704}}));
1422   auto expected = LiteralUtil::CreateR4FromArray4D<float>(expected_array);
1423   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
1424 }
1425 
// Test fixture for the reduce-precision test that follows; it adds nothing on
// top of HloTestBase.
class HloEvaluatorPreciseReduceTest : public HloTestBase {};
1427 
1428 // Tests that Reduce doesn't lose precision when adding many numbers (because
1429 // it accumulates its result in a double).
TEST_F(HloEvaluatorPreciseReduceTest,AddReductionPrecisionTest)1430 TEST_F(HloEvaluatorPreciseReduceTest, AddReductionPrecisionTest) {
1431   auto m = CreateNewVerifiedModule();
1432   HloComputation::Builder b(TestName());
1433 
1434   constexpr int kNumElements = 1 << 25;  // float += 1 saturates at 1<<24
1435   std::vector<float> v(kNumElements, 1.0f);
1436   HloInstruction* arg_instruction = b.AddInstruction(
1437       HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(v)));
1438   HloInstruction* init_value = b.AddInstruction(
1439       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.f)));
1440 
1441   HloComputation::Builder add_computation("add");
1442   Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
1443   auto param_lhs = add_computation.AddInstruction(
1444       HloInstruction::CreateParameter(0, scalar_shape, "lhs"));
1445   auto param_rhs = add_computation.AddInstruction(
1446       HloInstruction::CreateParameter(1, scalar_shape, "rhs"));
1447   add_computation.AddInstruction(HloInstruction::CreateBinary(
1448       scalar_shape, HloOpcode::kAdd, param_lhs, param_rhs));
1449   auto add_func = m->AddEmbeddedComputation(add_computation.Build());
1450 
1451   HloInstruction* reduce_instruction = b.AddInstruction(
1452       HloInstruction::CreateReduce(scalar_shape, arg_instruction, init_value,
1453                                    /*dimensions_to_reduce=*/{0}, add_func));
1454   m->AddEntryComputation(b.Build());
1455 
1456   HloEvaluator hlo_eval;
1457   Literal result = hlo_eval.Evaluate(reduce_instruction).ConsumeValueOrDie();
1458   LiteralTestUtil::ExpectR0Equal<float>(kNumElements, result);
1459 }
1460 
1461 // Reducing many numbers should be fast because it doesn't create
1462 // intermediate Literals; the microbenchmark should finish in < 1 msec.
BM_ReducePrecisely(int num_iters)1463 void BM_ReducePrecisely(int num_iters) {
1464   tensorflow::testing::StopTiming();
1465   HloComputation::Builder b("BM_ReducePrecisely");
1466   HloModuleConfig config;
1467   config.set_debug_options(GetDebugOptionsFromFlags());
1468   HloModule module("BM_ReducePrecisely", config);
1469 
1470   constexpr int kNumElements = 1 << 25;  // float += 1 saturates at 1<<24
1471   std::vector<float> v(kNumElements, 1.0f);
1472   HloInstruction* arg_instruction = b.AddInstruction(
1473       HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(v)));
1474   auto init_value = b.AddInstruction(
1475       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.f)));
1476 
1477   HloComputation::Builder add_computation("add");
1478   Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
1479   auto param_lhs = add_computation.AddInstruction(
1480       HloInstruction::CreateParameter(0, scalar_shape, "lhs"));
1481   auto param_rhs = add_computation.AddInstruction(
1482       HloInstruction::CreateParameter(1, scalar_shape, "rhs"));
1483   add_computation.AddInstruction(HloInstruction::CreateBinary(
1484       scalar_shape, HloOpcode::kAdd, param_lhs, param_rhs));
1485   auto add_func = module.AddEmbeddedComputation(add_computation.Build());
1486 
1487   HloInstruction* reduce_instruction = b.AddInstruction(
1488       HloInstruction::CreateReduce(scalar_shape, arg_instruction, init_value,
1489                                    /*dimensions_to_reduce=*/{0}, add_func));
1490   module.AddEntryComputation(b.Build());
1491 
1492   HloEvaluator hlo_eval;
1493   tensorflow::testing::StartTiming();
1494   hlo_eval.Evaluate(reduce_instruction).ConsumeValueOrDie();
1495   tensorflow::testing::StopTiming();
1496 }
1497 
1498 BENCHMARK(BM_ReducePrecisely);
1499 
TEST_P(HloEvaluatorBf16Test, ReduceAdd) {
  HloComputation::Builder builder(TestName());

  // Operand is f32[2,3]:
  //   { 1, 2, 3 },
  //   { 5, 6, 7 }
  Array2D<float> input(2, 3);
  input.FillUnique(1.0f);
  HloInstruction* operand = builder.AddInstruction(
      HloInstruction::CreateConstant(
          LiteralUtil::CreateR2FromArray2D<float>(input)));
  HloInstruction* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.f)));

  // Scalar f32 addition used as the reducer.
  const Shape f32_scalar = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder adder("add");
  HloInstruction* lhs = adder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_scalar, "lhs"));
  HloInstruction* rhs = adder.AddInstruction(
      HloInstruction::CreateParameter(1, f32_scalar, "rhs"));
  adder.AddInstruction(
      HloInstruction::CreateBinary(f32_scalar, HloOpcode::kAdd, lhs, rhs));
  HloComputation* reducer = m_->AddEmbeddedComputation(adder.Build());

  // Sum along dimension 1, leaving one value per row.
  const Shape row_sums_shape = ShapeUtil::MakeShape(F32, {2});
  builder.AddInstruction(HloInstruction::CreateReduce(
      row_sums_shape, operand, zero, /*dimensions_to_reduce=*/{1}, reducer));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate());

  // Row sums: 1 + 2 + 3 and 5 + 6 + 7.
  const Literal expected = LiteralUtil::CreateR1<float>({6, 18});
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
1541 
TEST_P(HloEvaluatorBf16Test, ReduceWindowMax) {
  HloComputation::Builder builder(TestName());

  // Operand is f32[2,3]:
  //   { 1, 2, 3 },
  //   { 5, 6, 7 }
  Array2D<float> input(2, 3);
  input.FillUnique(1.0f);
  HloInstruction* operand = builder.AddInstruction(
      HloInstruction::CreateConstant(
          LiteralUtil::CreateR2FromArray2D<float>(input)));
  HloInstruction* init = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.f)));

  // Scalar f32 maximum used to combine the elements of each window.
  const Shape f32_scalar = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder max_builder("max");
  HloInstruction* lhs = max_builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_scalar, "lhs"));
  HloInstruction* rhs = max_builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32_scalar, "rhs"));
  max_builder.AddInstruction(
      HloInstruction::CreateBinary(f32_scalar, HloOpcode::kMaximum, lhs, rhs));
  HloComputation* max_func = m_->AddEmbeddedComputation(max_builder.Build());

  // Both dimensions use the same settings: size 2, stride 1, no padding and
  // no dilation.
  Window window;
  for (int i = 0; i < 2; ++i) {
    WindowDimension* window_dim = window.add_dimensions();
    window_dim->set_size(2);
    window_dim->set_stride(1);
    window_dim->set_padding_low(0);
    window_dim->set_padding_high(0);
    window_dim->set_window_dilation(1);
    window_dim->set_base_dilation(1);
  }

  const Shape output_shape = ShapeUtil::MakeShape(F32, {1, 2});
  builder.AddInstruction(HloInstruction::CreateReduceWindow(
      output_shape, operand, init, window, max_func));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate());

  // The two 2x2 windows cover {1, 2, 5, 6} and {2, 3, 6, 7}.
  const Literal expected = LiteralUtil::CreateR2<float>({{6, 7}});
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
1592 
TEST_P(HloEvaluatorBf16Test, ReduceWindowMaxWindowDilation) {
  HloComputation::Builder builder(TestName());

  // Operand is f32[3,3]:
  //   { 1, 2, 3 },
  //   { 5, 6, 7 },
  //   { 9, 10, 11 }
  Array2D<float> input(3, 3);
  input.FillUnique(1.0f);
  HloInstruction* operand = builder.AddInstruction(
      HloInstruction::CreateConstant(
          LiteralUtil::CreateR2FromArray2D<float>(input)));
  HloInstruction* init = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.f)));

  // Scalar f32 maximum used to combine the elements of each window.
  const Shape f32_scalar = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder max_builder("max");
  HloInstruction* lhs = max_builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_scalar, "lhs"));
  HloInstruction* rhs = max_builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32_scalar, "rhs"));
  max_builder.AddInstruction(
      HloInstruction::CreateBinary(f32_scalar, HloOpcode::kMaximum, lhs, rhs));
  HloComputation* max_func = m_->AddEmbeddedComputation(max_builder.Build());

  // Both dimensions use size 2 with a window dilation of 2 (stride 1, no
  // padding, no base dilation).
  Window window;
  for (int i = 0; i < 2; ++i) {
    WindowDimension* window_dim = window.add_dimensions();
    window_dim->set_size(2);
    window_dim->set_stride(1);
    window_dim->set_padding_low(0);
    window_dim->set_padding_high(0);
    window_dim->set_window_dilation(2);
    window_dim->set_base_dilation(1);
  }

  const Shape output_shape = ShapeUtil::MakeShape(F32, {1, 1});
  builder.AddInstruction(HloInstruction::CreateReduceWindow(
      output_shape, operand, init, window, max_func));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate());

  // A single output element is produced; its expected value is the largest
  // entry of the operand.
  const Literal expected = LiteralUtil::CreateR2<float>({{11}});
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
1644 
TEST_P(HloEvaluatorBf16Test, ReduceWindowAdd) {
  HloComputation::Builder builder(TestName());

  // Operand is f32[2,3]:
  //   { 1, 2, 3 },
  //   { 5, 6, 7 }
  Array2D<float> input(2, 3);
  input.FillUnique(1.0f);
  HloInstruction* operand = builder.AddInstruction(
      HloInstruction::CreateConstant(
          LiteralUtil::CreateR2FromArray2D<float>(input)));
  HloInstruction* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.f)));

  // Scalar f32 addition used to combine the elements of each window.
  const Shape f32_scalar = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder adder("add");
  HloInstruction* lhs = adder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_scalar, "lhs"));
  HloInstruction* rhs = adder.AddInstruction(
      HloInstruction::CreateParameter(1, f32_scalar, "rhs"));
  adder.AddInstruction(
      HloInstruction::CreateBinary(f32_scalar, HloOpcode::kAdd, lhs, rhs));
  HloComputation* add_func = m_->AddEmbeddedComputation(adder.Build());

  // Dimension 0: size-1 window, so rows are never mixed.
  WindowDimension row_dim;
  row_dim.set_size(1);
  row_dim.set_stride(1);
  row_dim.set_padding_low(0);
  row_dim.set_padding_high(0);
  row_dim.set_window_dilation(1);
  row_dim.set_base_dilation(1);

  // Dimension 1: size-2 window with one element of low padding.
  WindowDimension col_dim;
  col_dim.set_size(2);
  col_dim.set_stride(1);
  col_dim.set_padding_low(1);
  col_dim.set_padding_high(0);
  col_dim.set_window_dilation(1);
  col_dim.set_base_dilation(1);

  Window window;
  *window.add_dimensions() = row_dim;
  *window.add_dimensions() = col_dim;

  const Shape output_shape = ShapeUtil::MakeShape(F32, {2, 3});
  builder.AddInstruction(HloInstruction::CreateReduceWindow(
      output_shape, operand, zero, window, add_func));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate());

  // Each output is the element plus its left neighbour within the row; the
  // first column only sees the padding (the init value 0) on its left.
  const Literal expected =
      LiteralUtil::CreateR2<float>({{1, 3, 5}, {5, 11, 13}});
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
1701 
TEST_P(HloEvaluatorBf16Test,ReduceWindowAdd6D)1702 TEST_P(HloEvaluatorBf16Test, ReduceWindowAdd6D) {
1703   HloComputation::Builder b(TestName());
1704 
1705   // arg: f32[4,4,4,4,4,4] full of ones. Using small dims to limit run-time.
1706   std::vector<int64> input_dims(6, 4);
1707   Literal arg_literal =
1708       LiteralUtil::CreateFullWithDescendingLayout<float>(input_dims, 1.0f);
1709 
1710   HloInstruction* arg_instruction =
1711       b.AddInstruction(HloInstruction::CreateConstant(std::move(arg_literal)));
1712 
1713   auto init_value = b.AddInstruction(
1714       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.f)));
1715 
1716   HloComputation::Builder add_computation("add");
1717   Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
1718   auto param_lhs = add_computation.AddInstruction(
1719       HloInstruction::CreateParameter(0, scalar_shape, "lhs"));
1720   auto param_rhs = add_computation.AddInstruction(
1721       HloInstruction::CreateParameter(1, scalar_shape, "rhs"));
1722   add_computation.AddInstruction(HloInstruction::CreateBinary(
1723       scalar_shape, HloOpcode::kAdd, param_lhs, param_rhs));
1724   auto add_func = m_->AddEmbeddedComputation(add_computation.Build());
1725 
1726   Window window;
1727 
1728   WindowDimension trivial_dim;
1729   trivial_dim.set_size(1);
1730   trivial_dim.set_stride(1);
1731   trivial_dim.set_padding_low(0);
1732   trivial_dim.set_padding_high(0);
1733   trivial_dim.set_window_dilation(1);
1734   trivial_dim.set_base_dilation(1);
1735 
1736   WindowDimension active_dim;
1737   active_dim.set_size(2);
1738   active_dim.set_stride(1);
1739   active_dim.set_padding_low(0);
1740   active_dim.set_padding_high(0);
1741   active_dim.set_window_dilation(1);
1742   active_dim.set_base_dilation(1);
1743 
1744   *window.add_dimensions() = trivial_dim;
1745   *window.add_dimensions() = active_dim;
1746   *window.add_dimensions() = active_dim;
1747   *window.add_dimensions() = active_dim;
1748   *window.add_dimensions() = trivial_dim;
1749   *window.add_dimensions() = trivial_dim;
1750 
1751   Shape shape = ShapeUtil::MakeShape(F32, {4, 3, 3, 3, 4, 4});
1752   b.AddInstruction(HloInstruction::CreateReduceWindow(
1753       shape, arg_instruction, init_value, window, add_func));
1754 
1755   m_->AddEntryComputation(b.Build());
1756 
1757   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
1758 
1759   std::vector<int64> output_dims = {4, 3, 3, 3, 4, 4};
1760   Literal result_literal =
1761       LiteralUtil::CreateFullWithDescendingLayout<float>(output_dims, 8.0f);
1762   EXPECT_TRUE(LiteralTestUtil::Equal(result_literal, result));
1763 }
1764 
TEST_P(HloEvaluatorBf16Test, StridedSlice) {
  HloComputation::Builder builder(TestName());

  // Operand is f32[3,5]:
  //   { 1, 2, 3, 4, 5 },
  //   { 9, 10, 11, 12, 13 },
  //   { 17, 18, 19, 20, 21 }
  Array2D<float> input(3, 5);
  input.FillUnique(1.0f);
  HloInstruction* operand = builder.AddInstruction(
      HloInstruction::CreateConstant(
          LiteralUtil::CreateR2FromArray2D<float>(input)));

  // Take rows 0 and 2 (stride 2) and, from columns [2, 5), only column 2
  // (stride 3).
  const Shape slice_shape = ShapeUtil::MakeShape(F32, {2, 1});
  builder.AddInstruction(HloInstruction::CreateSlice(slice_shape, operand,
                                                     /*start_indices=*/{0, 2},
                                                     /*limit_indices=*/{3, 5},
                                                     /*strides=*/{2, 3}));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate());

  const Literal expected = LiteralUtil::CreateR2<float>({
      {3},
      {19},
  });
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
1798 
TEST_P(HloEvaluatorBf16Test, DynamicSlice) {
  HloComputation::Builder builder(TestName());

  // Operand is f32[2,4]:
  //   { 1, 2, 3, 4 },
  //   { 5, 6, 7, 8 }
  Array2D<float> input(2, 4);
  input.FillUnique(1.0f);
  HloInstruction* operand = builder.AddInstruction(
      HloInstruction::CreateConstant(
          LiteralUtil::CreateR2FromArray2D<float>(input)));

  // Scalar s32 start indices: row 0, column 1.
  HloInstruction* row_start = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
  HloInstruction* col_start = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));

  // Extract a 2x3 slice starting at (0, 1).
  const Shape slice_shape = ShapeUtil::MakeShape(F32, {2, 3});
  builder.AddInstruction(HloInstruction::CreateDynamicSlice(
      slice_shape, operand, {row_start, col_start}, {2, 3}));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate());

  const Literal expected = LiteralUtil::CreateR2<float>({
      {2, 3, 4},
      {6, 7, 8},
  });
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
1834 
1835 // Verifies that the HloEvaluator's implementation goes along with existing
1836 // backends' behavior, although this is not required by the spec.
TEST_P(HloEvaluatorBf16Test, DynamicSliceModSlice) {
  HloComputation::Builder builder(TestName());

  // Operand is f32[2,4]:
  //   { 1, 2, 3, 4 },
  //   { 5, 6, 7, 8 }
  Array2D<float> input(2, 4);
  input.FillUnique(1.0f);
  HloInstruction* operand = builder.AddInstruction(
      HloInstruction::CreateConstant(
          LiteralUtil::CreateR2FromArray2D<float>(input)));

  // Scalar s32 start indices: row 2, column 1. Row 2 lies outside the
  // two-row operand.
  HloInstruction* row_start = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(2)));
  HloInstruction* col_start = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));

  const Shape slice_shape = ShapeUtil::MakeShape(F32, {2, 3});
  builder.AddInstruction(HloInstruction::CreateDynamicSlice(
      slice_shape, operand, {row_start, col_start}, {2, 3}));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate());

  // The expected result is the same as for a slice starting at (0, 1).
  const Literal expected = LiteralUtil::CreateR2<float>({
      {2, 3, 4},
      {6, 7, 8},
  });
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
1872 
TEST_P(HloEvaluatorBf16Test, DynamicSliceUpdate) {
  HloComputation::Builder builder(TestName());

  // Operand is f64[2,3]:
  //   { 1, 2, 3 },
  //   { 5, 6, 7 }
  Array2D<double> input(2, 3);
  input.FillUnique(1.0);
  HloInstruction* operand = builder.AddInstruction(
      HloInstruction::CreateConstant(
          LiteralUtil::CreateR2FromArray2D<double>(input)));

  // Scalar s32 start indices: row 0, column 1.
  HloInstruction* row_start = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
  HloInstruction* col_start = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));

  // 2x2 block of negative values written over the operand at (0, 1).
  HloInstruction* patch = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<double>({{-2.0, -3.0}, {-6.0, -7.0}})));

  const Shape result_shape = ShapeUtil::MakeShape(F64, {2, 3});
  builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
      result_shape, operand, patch, {row_start, col_start}));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate());

  // Column 0 keeps the operand values; columns 1 and 2 hold the update.
  const Literal expected = LiteralUtil::CreateR2<double>({
      {1, -2, -3},
      {5, -6, -7},
  });
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
1911 
TEST_P(HloEvaluatorBf16Test, SetAndGetTuples) {
  HloComputation::Builder builder(TestName());

  // Second tuple element is f64[2,3]:
  //   { 1, 2, 3 },
  //   { 5, 6, 7 }
  Array2D<double> matrix(2, 3);
  matrix.FillUnique(1.0);
  HloInstruction* matrix_operand = builder.AddInstruction(
      HloInstruction::CreateConstant(
          LiteralUtil::CreateR2FromArray2D<double>(matrix)));

  // First tuple element is the s64 vector {0, 1}.
  HloInstruction* vector_operand = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<int64>({0, 1})));

  HloInstruction* pair = builder.AddInstruction(
      HloInstruction::CreateTuple({vector_operand, matrix_operand}));

  // Read element 1 (the matrix) back out of the tuple.
  const Shape matrix_shape = ShapeUtil::MakeShape(F64, {2, 3});
  builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(matrix_shape, pair, 1));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate());

  const Literal expected = LiteralUtil::CreateR2<double>({
      {1, 2, 3},
      {5, 6, 7},
  });
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
1947 
TEST_P(HloEvaluatorBf16Test, SetAndGetNestedTuples) {
  HloComputation::Builder builder(TestName());

  // Matrix operand is f64[2,3]:
  //   { 1, 2, 3 },
  //   { 5, 6, 7 }
  Array2D<double> matrix(2, 3);
  matrix.FillUnique(1.0);
  HloInstruction* matrix_operand = builder.AddInstruction(
      HloInstruction::CreateConstant(
          LiteralUtil::CreateR2FromArray2D<double>(matrix)));

  // Vector operand is the s64 vector {0, 1}.
  HloInstruction* vector_operand = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<int64>({0, 1})));

  // outer = ((vector, matrix), (matrix, matrix))
  HloInstruction* mixed_pair = builder.AddInstruction(
      HloInstruction::CreateTuple({vector_operand, matrix_operand}));
  HloInstruction* matrix_pair = builder.AddInstruction(
      HloInstruction::CreateTuple({matrix_operand, matrix_operand}));
  HloInstruction* outer = builder.AddInstruction(
      HloInstruction::CreateTuple({mixed_pair, matrix_pair}));

  // Read element 1 (the tuple of two matrices) back out of the outer tuple.
  builder.AddInstruction(HloInstruction::CreateGetTupleElement(
      matrix_pair->shape(), outer, 1));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate());

  // The expected value is a tuple holding the matrix twice.
  const Literal matrix_literal =
      LiteralUtil::CreateR2FromArray2D<double>(matrix);
  const Literal expected =
      LiteralUtil::MakeTuple({&matrix_literal, &matrix_literal});
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
1986 
TEST_P(HloEvaluatorBf16Test, Reverse) {
  HloComputation::Builder builder(TestName());

  // Operand is f32[4,3,2,1] holding the values 1 through 24 in order.
  // clang-format off
  const Array4D<float> forward({
    {{{1.0f}, {2.0f}},
     {{3.0f}, {4.0f}},
     {{5.0f}, {6.0f}}},
    {{{7.0f}, {8.0f}},
     {{9.0f}, {10.0f}},
     {{11.0f}, {12.0f}}},
    {{{13.0f}, {14.0f}},
     {{15.0f}, {16.0f}},
     {{17.0f}, {18.0f}}},
    {{{19.0f}, {20.0f}},
     {{21.0f}, {22.0f}},
     {{23.0f}, {24.0f}}},
  });
  // clang-format on
  HloInstruction* operand = builder.AddInstruction(
      HloInstruction::CreateConstant(
          LiteralUtil::CreateR4FromArray4D<float>(forward)));

  // Reverse along dimensions 0 and 1; dimensions 2 and 3 keep their order.
  const Shape operand_shape = ShapeUtil::MakeShape(F32, {4, 3, 2, 1});
  builder.AddInstruction(
      HloInstruction::CreateReverse(operand_shape, operand, {0, 1}));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate());

  // clang-format off
  const Array4D<float> reversed({
    {{{23.0f}, {24.0f}},
     {{21.0f}, {22.0f}},
     {{19.0f}, {20.0f}}},

    {{{17.0f}, {18.0f}},
     {{15.0f}, {16.0f}},
     {{13.0f}, {14.0f}}},

    {{{11.0f}, {12.0f}},
     {{9.0f}, {10.0f}},
     {{7.0f}, {8.0f}}},

    {{{5.0f}, {6.0f}},
     {{3.0f}, {4.0f}},
     {{1.0f}, {2.0f}}},
  });
  // clang-format on
  const Literal expected = LiteralUtil::CreateR4FromArray4D<float>(reversed);
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
2039 
TEST_P(HloEvaluatorBf16Test, EvaluateWithSubstitutions) {
  // Graph under test, all f32[4]: sum = param0 + (param0 * param0).
  const Shape vec4 = ShapeUtil::MakeShape(F32, {4});
  HloComputation::Builder builder(TestName());

  HloInstruction* input =
      builder.AddInstruction(HloInstruction::CreateParameter(0, vec4, "param0"));
  HloInstruction* squared = builder.AddInstruction(
      HloInstruction::CreateBinary(vec4, HloOpcode::kMultiply, input, input));
  HloInstruction* sum = builder.AddInstruction(
      HloInstruction::CreateBinary(vec4, HloOpcode::kAdd, input, squared));

  // Substitute literals for both operands of the add:
  //   param0 = {1, 2, 3, 4}, square = {10, 20, 30, 40}.
  // The square literal is not param0 * param0; the expected sum below is
  // computed from the substituted values.
  Literal input_value = LiteralUtil::CreateR1<float>({1, 2, 3, 4});
  Literal squared_value = LiteralUtil::CreateR1<float>({10, 20, 30, 40});

  HloEvaluator evaluator;
  TF_ASSERT_OK_AND_ASSIGN(
      Literal actual,
      evaluator.EvaluateWithSubstitutions(
          sum, {{input, &input_value}, {squared, &squared_value}}));

  const Literal expected = LiteralUtil::CreateR1<float>({11, 22, 33, 44});
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
2062 
2063 // Check that EvaluateWithSubstitutions works if one of the operands to the op
2064 // we're evaluating is a constant.
TEST_P(HloEvaluatorBf16Test, EvaluateWithSubstitutionsWithConstantOperand) {
  // Graph under test, all f32[4]: sum = constant + (param0 * param0), where
  // constant = {1, 2, 3, 4}.
  const Shape vec4 = ShapeUtil::MakeShape(F32, {4});
  HloComputation::Builder builder(TestName());

  HloInstruction* input =
      builder.AddInstruction(HloInstruction::CreateParameter(0, vec4, "param0"));
  HloInstruction* squared = builder.AddInstruction(
      HloInstruction::CreateBinary(vec4, HloOpcode::kMultiply, input, input));
  HloInstruction* addend = builder.AddInstruction(
      HloInstruction::CreateConstant(
          LiteralUtil::CreateR1<float>({1, 2, 3, 4})));
  HloInstruction* sum = builder.AddInstruction(
      HloInstruction::CreateBinary(vec4, HloOpcode::kAdd, addend, squared));

  // Only the square operand is substituted, with {10, 20, 30, 40}; the other
  // operand of the add is the constant instruction itself.
  Literal squared_value = LiteralUtil::CreateR1<float>({10, 20, 30, 40});

  HloEvaluator evaluator;
  TF_ASSERT_OK_AND_ASSIGN(
      Literal actual,
      evaluator.EvaluateWithSubstitutions(sum, {{squared, &squared_value}}));

  const Literal expected = LiteralUtil::CreateR1<float>({11, 22, 33, 44});
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
2087 
TEST_F(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV1) {
  // Gathers whole rows of a 3x3 operand, one row per scalar index.
  const char* const kModuleText = R"(
HloModule TensorFlowGatherV1

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2] parameter(1)
  ROOT gather = s32[2,3] gather(operand, indices),
      offset_dims={1},
      collapsed_slice_dims={0},
      start_index_map={0},
      index_vector_dim=1,
      slice_sizes={1, 3}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(kModuleText));

  Literal input =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal indices = LiteralUtil::CreateR1<int32>({0, 2});
  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate({&input, &indices}));

  // Rows 0 and 2 of the operand.
  const Literal expected =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {7, 8, 9}});
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
2111 
TEST_F(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV2) {
  // Gathers whole columns of a 3x3 operand, one column per scalar index.
  const char* const kModuleText = R"(
HloModule TensorFlowGatherV2

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2] parameter(1)
  ROOT gather = s32[3,2] gather(operand, indices),
      offset_dims={0},
      collapsed_slice_dims={1},
      start_index_map={1},
      index_vector_dim=1,
      slice_sizes={3, 1}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(kModuleText));

  Literal input =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal indices = LiteralUtil::CreateR1<int32>({0, 2});
  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate({&input, &indices}));

  // Columns 0 and 2 of the operand, laid out side by side.
  const Literal expected =
      LiteralUtil::CreateR2<int32>({{1, 3}, {4, 6}, {7, 9}});
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
2135 
TEST_F(HloEvaluatorTest, EvaluateGather_TensorFlowGatherMultipleBatchDims) {
  // Gathers columns of a 3x3 operand using a 2x2 grid of scalar indices, so
  // the indices contribute two batch dimensions to the output.
  const char* const kModuleText = R"(
HloModule TensorFlowGatherMultipleBatchDims

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2,2] parameter(1)
  ROOT gather = s32[2,3,2] gather(operand, indices),
      offset_dims={1},
      collapsed_slice_dims={1},
      start_index_map={1},
      index_vector_dim=2,
      slice_sizes={3, 1}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(kModuleText));

  Literal input =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal indices = LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate({&input, &indices}));

  // First slab pairs columns 0 and 2; second slab pairs columns 2 and 1.
  const Literal expected = LiteralUtil::CreateR3<int32>(
      {{{1, 3}, {4, 6}, {7, 9}}, {{3, 2}, {6, 5}, {9, 8}}});
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
2161 
TEST_F(HloEvaluatorTest, EvaluateGather_TensorFlowGatherNd) {
  // Gathers length-2 vectors from a [3,3,2] operand; each row of `indices`
  // is a 2D index into the first two operand dimensions.
  const char* const kModuleText = R"(
HloModule TensorFlowGatherNd

ENTRY main {
  operand = s32[3,3,2] parameter(0)
  indices = s32[2,2] parameter(1)
  ROOT gather = s32[2,2] gather(operand, indices),
      offset_dims={1},
      collapsed_slice_dims={0,1},
      start_index_map={0,1},
      index_vector_dim=1,
      slice_sizes={1,1,2}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(kModuleText));

  Literal input =
      LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}},  //
                                    {{-4, 4}, {-5, 5}, {-6, 6}},  //
                                    {{-7, 7}, {-8, 8}, {-9, 9}}});
  Literal indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate({&input, &indices}));

  // Index (0, 0) selects {-1, 1}; index (1, 0) selects {-4, 4}.
  const Literal expected = LiteralUtil::CreateR2<int32>({{-1, 1}, {-4, 4}});
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
2187 
TEST_F(HloEvaluatorTest,
       EvaluateGather_TensorFlowGatherNdNonDefaultIndexVectorDim) {
  // Same gather as the GatherNd case, but with index_vector_dim=0, so each
  // column of `indices` (rather than each row) forms a 2D index.
  const char* const kModuleText = R"(
HloModule TensorFlowGatherNd

ENTRY main {
  operand = s32[3,3,2] parameter(0)
  indices = s32[2,2] parameter(1)
  ROOT gather = s32[2,2] gather(operand, indices),
      offset_dims={1},
      collapsed_slice_dims={0,1},
      start_index_map={0,1},
      index_vector_dim=0,
      slice_sizes={1,1,2}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(kModuleText));

  Literal input =
      LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}},  //
                                    {{-4, 4}, {-5, 5}, {-6, 6}},  //
                                    {{-7, 7}, {-8, 8}, {-9, 9}}});
  Literal indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate({&input, &indices}));

  // Column 0 gives index (0, 1) -> {-2, 2}; column 1 gives (0, 0) -> {-1, 1}.
  const Literal expected = LiteralUtil::CreateR2<int32>({{-2, 2}, {-1, 1}});
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
2214 
TEST_F(HloEvaluatorTest, EvaluateGather_DynamicSlice) {
  // Expresses a dynamic slice as a gather: a single 2D index selects a 1x1
  // slice and no operand dimension is collapsed.
  const char* const kModuleText = R"(
HloModule DynamicSlice

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2] parameter(1)
  ROOT gather = s32[1,1] gather(operand, indices),
      offset_dims={0,1},
      collapsed_slice_dims={},
      start_index_map={0,1},
      index_vector_dim=0,
      slice_sizes={1,1}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(kModuleText));

  Literal input =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal indices = LiteralUtil::CreateR1<int32>({1, 1});
  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate({&input, &indices}));

  // The element at (1, 1) is 5.
  const Literal expected = LiteralUtil::CreateR2<int32>({{5}});
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
2238 
TEST_F(HloEvaluatorTest, EvaluateGather_BatchDynamicSlice) {
  // Batched form of the dynamic-slice gather: with index_vector_dim=0, each
  // column of `indices` is a 2D index selecting a 1x1 slice.
  const char* const kModuleText = R"(
HloModule BatchDynamicSlice

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2,2] parameter(1)
  ROOT gather = s32[2,1,1] gather(operand, indices),
      offset_dims={1,2},
      collapsed_slice_dims={},
      start_index_map={0,1},
      index_vector_dim=0,
      slice_sizes={1,1}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(kModuleText));

  Literal input =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal indices = LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate({&input, &indices}));

  // Column 0 gives index (2, 1) -> 8; column 1 gives index (1, 1) -> 5.
  const Literal expected = LiteralUtil::CreateR3<int32>({{{8}}, {{5}}});
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
2262 
TEST_F(HloEvaluatorTest, EvaluateGather_ZeroDimBounds) {
  // The operand has a zero-sized dimension (s32[3,0]); the gather is expected
  // to evaluate to an empty s32[2,0] literal.
  const char* module_text = R"(
HloModule TensorFlowGatherV1

ENTRY main {
  operand = s32[3,0] parameter(0)
  indices = s32[2] parameter(1)
  ROOT gather = s32[2,0] gather(operand, indices),
      offset_dims={1},
      collapsed_slice_dims={0},
      start_index_map={0},
      index_vector_dim=1,
      slice_sizes={1, 0}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal input = LiteralUtil::CreateR2<int32>({{}, {}, {}});
  Literal gather_indices = LiteralUtil::CreateR1<int32>({0, 2});
  Literal expected = LiteralUtil::CreateR2<int32>({{}, {}});

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate({&input, &gather_indices}));
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
2285 
TEST_F(HloEvaluatorTest, EvaluateGather_NoOutputWindowDims) {
  // Gather with an empty offset_dims list: every index in the s32[2,2,1]
  // literal picks one scalar from the rank-1 operand.
  const string module_text = R"(
HloModule GatherXd

ENTRY main {
  operand = s32[3] parameter(0)
  indices = s32[2,2,1] parameter(1)
  ROOT gather = s32[2,2] gather(operand, indices),
      offset_dims={},
      collapsed_slice_dims={0},
      start_index_map={0},
      index_vector_dim=2,
      slice_sizes={1}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal input = LiteralUtil::CreateR1<int32>({0, 1, 2});
  Literal gather_indices =
      LiteralUtil::CreateR3<int32>({{{0}, {1}}, {{2}, {1}}});
  Literal expected = LiteralUtil::CreateR2<int32>({{0, 1}, {2, 1}});

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate({&input, &gather_indices}));
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
2310 
TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterV1_Update) {
  // The combiner returns its second parameter, so scattered rows 0 and 2 of
  // the operand are expected to be overwritten by the rows of `updates`.
  const char* module_text = R"(
HloModule TensorFlowScatterV1

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  ROOT rhs = s32[] parameter(1)
}

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2] parameter(1)
  updates = s32[2,3] parameter(2)
  ROOT scatter = s32[3,3] scatter(operand, indices, updates),
      to_apply=update_s32,
      update_window_dims={1},
      inserted_window_dims={0},
      scatter_dims_to_operand_dims={0},
      index_vector_dim=1
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal input =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal indices = LiteralUtil::CreateR1<int32>({0, 2});
  Literal new_rows =
      LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
  Literal expected =
      LiteralUtil::CreateR2<int32>({{10, 20, 30}, {4, 5, 6}, {70, 80, 90}});

  TF_ASSERT_OK_AND_ASSIGN(Literal actual,
                          Evaluate({&input, &indices, &new_rows}));
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
2343 
TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterV2_Update) {
  // Same overwrite combiner as the V1 test, but the scatter indices address
  // operand dimension 1, so columns 0 and 2 are expected to be replaced.
  const char* module_text = R"(
HloModule TensorFlowScatterV2

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  ROOT rhs = s32[] parameter(1)
}

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2] parameter(1)
  updates = s32[3,2] parameter(2)
  ROOT scatter = s32[3,3] scatter(operand, indices, updates),
      to_apply=update_s32,
      update_window_dims={0},
      inserted_window_dims={1},
      scatter_dims_to_operand_dims={1},
      index_vector_dim=1
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal input =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal indices = LiteralUtil::CreateR1<int32>({0, 2});
  Literal new_columns =
      LiteralUtil::CreateR2<int32>({{10, 30}, {40, 60}, {70, 90}});
  Literal expected =
      LiteralUtil::CreateR2<int32>({{10, 2, 30}, {40, 5, 60}, {70, 8, 90}});

  TF_ASSERT_OK_AND_ASSIGN(Literal actual,
                          Evaluate({&input, &indices, &new_columns}));
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
2377 
TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_Add) {
  // Scatter with an additive combiner: rows 0 and 2 of the operand are
  // expected to have the corresponding update rows added to them.
  const char* module_text = R"(
HloModule TensorFlowScatter

add_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(s32[] lhs, s32[] rhs)
}

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2] parameter(1)
  updates = s32[2,3] parameter(2)
  ROOT scatter = s32[3,3] scatter(operand, indices, updates),
      to_apply=add_s32,
      update_window_dims={1},
      inserted_window_dims={0},
      scatter_dims_to_operand_dims={0},
      index_vector_dim=1
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal input =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal indices = LiteralUtil::CreateR1<int32>({0, 2});
  Literal addends =
      LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
  Literal expected =
      LiteralUtil::CreateR2<int32>({{11, 22, 33}, {4, 5, 6}, {77, 88, 99}});

  TF_ASSERT_OK_AND_ASSIGN(Literal actual,
                          Evaluate({&input, &indices, &addends}));
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
2411 
TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_Mul) {
  // Scatter with a multiplicative combiner: rows 0 and 2 of the operand are
  // expected to be multiplied element-wise by the update rows.
  const char* module_text = R"(
HloModule TensorFlowScatter

mul_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT mul = s32[] multiply(s32[] lhs, s32[] rhs)
}

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2] parameter(1)
  updates = s32[2,3] parameter(2)
  ROOT scatter = s32[3,3] scatter(operand, indices, updates),
      to_apply=mul_s32,
      update_window_dims={1},
      inserted_window_dims={0},
      scatter_dims_to_operand_dims={0},
      index_vector_dim=1
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal input =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal indices = LiteralUtil::CreateR1<int32>({0, 2});
  Literal factors =
      LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
  Literal expected = LiteralUtil::CreateR2<int32>(
      {{10, 40, 90}, {4, 5, 6}, {490, 640, 810}});

  TF_ASSERT_OK_AND_ASSIGN(Literal actual,
                          Evaluate({&input, &indices, &factors}));
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
2445 
TEST_P(HloEvaluatorBf16Test, EvaluateScatter_TensorFlowScatter_F32) {
  // Floating-point scatter-add; the result is compared with a tolerance
  // (ErrorSpec{0.1, 0.01}) rather than for exact equality.
  const char* module_text = R"(
HloModule TensorFlowScatter

add_f32 (lhs: f32[], rhs: f32[]) -> f32[] {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT add = f32[] add(f32[] lhs, f32[] rhs)
}

ENTRY main {
  operand = f32[3,3] parameter(0)
  indices = s32[2] parameter(1)
  updates = f32[2,3] parameter(2)
  ROOT scatter = f32[3,3] scatter(operand, indices, updates),
      to_apply=add_f32,
      update_window_dims={1},
      inserted_window_dims={0},
      scatter_dims_to_operand_dims={0},
      index_vector_dim=1
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal input = LiteralUtil::CreateR2<float>(
      {{1.1, 2.2, 3.3}, {4.4, 5.5, 6.6}, {7.7, 8.8, 9.9}});
  Literal indices = LiteralUtil::CreateR1<int32>({2, 1});
  Literal addends =
      LiteralUtil::CreateR2<float>({{0.4, 1.1, 0.7}, {2.3, 3.1, 1.6}});
  Literal expected = LiteralUtil::CreateR2<float>(
      {{1.1, 2.2, 3.3}, {6.7, 8.6, 8.2}, {8.1, 9.9, 10.6}});

  TF_ASSERT_OK_AND_ASSIGN(Literal actual,
                          Evaluate({&input, &indices, &addends}));
  EXPECT_TRUE(LiteralTestUtil::Near(expected, actual, ErrorSpec{0.1, 0.01}));
}
2481 
TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_RepeatedIndices) {
  // Both scatter indices are 1, so both update rows are expected to be
  // accumulated into row 1 of the operand.
  const char* module_text = R"(
HloModule TensorFlowScatter

add_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(s32[] lhs, s32[] rhs)
}

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2] parameter(1)
  updates = s32[2,3] parameter(2)
  ROOT scatter = s32[3,3] scatter(operand, indices, updates),
      to_apply=add_s32,
      update_window_dims={1},
      inserted_window_dims={0},
      scatter_dims_to_operand_dims={0},
      index_vector_dim=1
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal input =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal indices = LiteralUtil::CreateR1<int32>({1, 1});
  Literal addends =
      LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
  Literal expected =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {84, 105, 126}, {7, 8, 9}});

  TF_ASSERT_OK_AND_ASSIGN(Literal actual,
                          Evaluate({&input, &indices, &addends}));
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
2515 
TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_MultipleBatchDims) {
  // Scatter-add with a rank-2 index literal and index_vector_dim=2, applied
  // to a rank-3 `updates` literal.
  const char* module_text = R"(
HloModule TensorFlowScatterMultipleBatchDims

add_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(s32[] lhs, s32[] rhs)
}

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2,2] parameter(1)
  updates = s32[2,3,2] parameter(2)
  ROOT scatter = s32[3,3] scatter(operand, indices, updates),
      to_apply=add_s32,
      update_window_dims={1},
      inserted_window_dims={1},
      scatter_dims_to_operand_dims={1},
      index_vector_dim=2
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal input =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal indices = LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
  Literal addends = LiteralUtil::CreateR3<int32>(
      {{{10, 30}, {40, 60}, {70, 90}}, {{5, 5}, {5, 5}, {5, 5}}});
  Literal expected = LiteralUtil::CreateR2<int32>(
      {{11, 7, 38}, {44, 10, 71}, {77, 13, 104}});

  TF_ASSERT_OK_AND_ASSIGN(Literal actual,
                          Evaluate({&input, &indices, &addends}));
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
2550 
TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterNd) {
  // Each row of the index literal addresses operand dimensions 0 and 1; the
  // matching two-element update row is expected to overwrite that slice.
  const char* module_text = R"(
HloModule TensorFlowScatterNd

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  ROOT rhs = s32[] parameter(1)
}

ENTRY main {
  operand = s32[3,3,2] parameter(0)
  indices = s32[2,2] parameter(1)
  updates = s32[2,2] parameter(2)
  ROOT scatter = s32[3,3,2] scatter(operand, indices, updates),
      to_apply=update_s32,
      update_window_dims={1},
      inserted_window_dims={0,1},
      scatter_dims_to_operand_dims={0,1},
      index_vector_dim=1
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal input =
      LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}},  //
                                    {{-4, 4}, {-5, 5}, {-6, 6}},  //
                                    {{-7, 7}, {-8, 8}, {-9, 9}}});
  Literal indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
  Literal replacements = LiteralUtil::CreateR2<int32>({{-10, 10}, {-40, 40}});

  TF_ASSERT_OK_AND_ASSIGN(Literal actual,
                          Evaluate({&input, &indices, &replacements}));

  Literal want =
      LiteralUtil::CreateR3<int32>({{{-10, 10}, {-2, 2}, {-3, 3}},  //
                                    {{-40, 40}, {-5, 5}, {-6, 6}},  //
                                    {{-7, 7}, {-8, 8}, {-9, 9}}});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, actual));
}
2587 
TEST_F(HloEvaluatorTest,
       EvaluateScatter_TensorFlowScatterNd_NonDefaultIndexVectorDim) {
  // ScatterNd-style overwrite where the index vectors are laid out along
  // dimension 0 of the index literal (index_vector_dim=0).
  const char* module_text = R"(
HloModule TensorFlowScatterNdNonDefaultIndexVectorDim

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  ROOT rhs = s32[] parameter(1)
}

ENTRY main {
  operand = s32[3,3,2] parameter(0)
  indices = s32[2,2] parameter(1)
  updates = s32[2,2] parameter(2)
  ROOT scatter = s32[3,3,2] scatter(operand, indices, updates),
      to_apply=update_s32,
      update_window_dims={1},
      inserted_window_dims={0,1},
      scatter_dims_to_operand_dims={0,1},
      index_vector_dim=0
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal input =
      LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}},  //
                                    {{-4, 4}, {-5, 5}, {-6, 6}},  //
                                    {{-7, 7}, {-8, 8}, {-9, 9}}});
  Literal indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
  Literal replacements = LiteralUtil::CreateR2<int32>({{-10, 10}, {-20, 20}});

  TF_ASSERT_OK_AND_ASSIGN(Literal actual,
                          Evaluate({&input, &indices, &replacements}));

  Literal want =
      LiteralUtil::CreateR3<int32>({{{-20, 20}, {-10, 10}, {-3, 3}},  //
                                    {{-4, 4}, {-5, 5}, {-6, 6}},      //
                                    {{-7, 7}, {-8, 8}, {-9, 9}}});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, actual));
}
2625 
TEST_F(HloEvaluatorTest, EvaluateScatter_DynamicUpdateSlice) {
  // Scatter with no inserted window dims: a single 1x1 update is expected to
  // overwrite the operand element addressed by the index literal {1, 1}.
  const char* module_text = R"(
HloModule DynamicUpdateSlice

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  ROOT rhs = s32[] parameter(1)
}

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2] parameter(1)
  updates = s32[1,1] parameter(2)
  ROOT scatter = s32[3,3] scatter(operand, indices, updates),
      to_apply=update_s32,
      update_window_dims={0,1},
      inserted_window_dims={},
      scatter_dims_to_operand_dims={0,1},
      index_vector_dim=0
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal input =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal indices = LiteralUtil::CreateR1<int32>({1, 1});
  Literal replacement = LiteralUtil::CreateR2<int32>({{10}});

  TF_ASSERT_OK_AND_ASSIGN(Literal actual,
                          Evaluate({&input, &indices, &replacement}));

  Literal want =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 10, 6}, {7, 8, 9}});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, actual));
}
2658 
TEST_F(HloEvaluatorTest, EvaluateScatter_BatchDynamicUpdateSlice) {
  // Two 1x1 updates are scattered into a 3x3 operand; the index literal is
  // 2x2 with index_vector_dim=0.
  const char* module_text = R"(
HloModule BatchDynamicUpdateSlice

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  ROOT rhs = s32[] parameter(1)
}

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2,2] parameter(1)
  updates = s32[2,1,1] parameter(2)
  ROOT scatter = s32[3,3] scatter(operand, indices, updates),
      to_apply=update_s32,
      update_window_dims={1,2},
      inserted_window_dims={},
      scatter_dims_to_operand_dims={0,1},
      index_vector_dim=0
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal input =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal indices = LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
  Literal replacements = LiteralUtil::CreateR3<int32>({{{10}}, {{20}}});

  TF_ASSERT_OK_AND_ASSIGN(Literal actual,
                          Evaluate({&input, &indices, &replacements}));

  Literal want =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 20, 6}, {7, 10, 9}});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, actual));
}
2691 
TEST_F(HloEvaluatorTest, EvaluateScatter_ZeroDimBounds) {
  // Operand and updates both have a zero-sized dimension, so the scatter is
  // expected to return a literal equal to the (empty) operand.
  const char* module_text = R"(
HloModule TensorFlowScatter_ZeroDimBounds

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  ROOT rhs = s32[] parameter(1)
}

ENTRY main {
  operand = s32[3,0] parameter(0)
  indices = s32[2] parameter(1)
  updates = s32[2,0] parameter(2)
  ROOT scatter = s32[3,0] scatter(operand, indices, updates),
      to_apply=update_s32,
      update_window_dims={1},
      inserted_window_dims={0},
      scatter_dims_to_operand_dims={0},
      index_vector_dim=1
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal input = LiteralUtil::CreateR2<int32>({{}, {}, {}});
  Literal indices = LiteralUtil::CreateR1<int32>({0, 2});
  Literal empty_updates = LiteralUtil::CreateR2<int32>({{}, {}});

  TF_ASSERT_OK_AND_ASSIGN(Literal actual,
                          Evaluate({&input, &indices, &empty_updates}));
  EXPECT_TRUE(LiteralTestUtil::Equal(input, actual));
}
2721 
TEST_F(HloEvaluatorTest, EvaluateScatter_NoUpdateWindowDims) {
  // Scatter-add with an empty update_window_dims list: every element of
  // `updates` is a scalar added to one element of the rank-1 operand.
  const string module_text = R"(
HloModule Scatter_NoUpdateWindowDims

add_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(s32[] lhs, s32[] rhs)
}

ENTRY main {
  operand = s32[3] parameter(0)
  indices = s32[2,2,1] parameter(1)
  updates = s32[2,2] parameter(2)
  ROOT scatter = s32[3] scatter(operand, indices, updates),
      to_apply=add_s32,
      update_window_dims={},
      inserted_window_dims={0},
      scatter_dims_to_operand_dims={0},
      index_vector_dim=2
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal input = LiteralUtil::CreateR1<int32>({0, 1, 2});
  Literal indices = LiteralUtil::CreateR3<int32>({{{0}, {1}}, {{2}, {1}}});
  Literal addends = LiteralUtil::CreateR2<int32>({{10, 20}, {30, 40}});

  TF_ASSERT_OK_AND_ASSIGN(Literal actual,
                          Evaluate({&input, &indices, &addends}));

  Literal want = LiteralUtil::CreateR1<int32>({10, 61, 32});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, actual));
}
2755 
TEST_F(HloEvaluatorTest, EvaluateScatter_NegativeIndices) {
  // Scatter-add where one of the two scatter indices is negative. Uses a
  // locally owned module together with EvaluateWithModule.
  const char* module_text = R"(
HloModule TensorFlowScatter_NegativeIndices

add_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(s32[] lhs, s32[] rhs)
}

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2] parameter(1)
  updates = s32[2,3] parameter(2)
  ROOT scatter = s32[3,3] scatter(operand, indices, updates),
      to_apply=add_s32,
      update_window_dims={1},
      inserted_window_dims={0},
      scatter_dims_to_operand_dims={0},
      index_vector_dim=1
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseAndReturnVerifiedModule(module_text));

  Literal input =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  // The update row paired with index -1 must be dropped; only row 2 changes.
  Literal indices = LiteralUtil::CreateR1<int32>({-1, 2});
  Literal addends =
      LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
  Literal expected =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {77, 88, 99}});

  EXPECT_TRUE(LiteralTestUtil::Equal(
      expected,
      EvaluateWithModule(hlo_module.get(), {&input, &indices, &addends})));
}
2790 
TEST_F(HloEvaluatorTest, EvaluateScatter_OobIndices) {
  // Six 1x1 updates are scattered into a 3x3 operand; several of the index
  // pairs fall outside the operand bounds.
  const string module_text = R"(
HloModule BatchDynamicUpdateSlice

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  ROOT rhs = s32[] parameter(1)
}

ENTRY main {
  operand = s32[3,3]{1,0} parameter(0)
  indices = s32[6,2]{1,0} parameter(1)
  updates = s32[6,1,1]{2,1,0} parameter(2)
  ROOT scatter = s32[3,3]{1,0} scatter(operand, indices, updates),
      to_apply=update_s32,
      update_window_dims={1,2},
      inserted_window_dims={},
      scatter_dims_to_operand_dims={0,1},
      index_vector_dim=1
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseAndReturnVerifiedModule(module_text));

  Literal input =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  // Updates paired with out-of-bounds indices must leave the operand alone.
  Literal indices = LiteralUtil::CreateR2<int32>(
      {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483647, 1}, {1, 2}});
  Literal replacements = LiteralUtil::CreateR3<int32>(
      {{{10}}, {{20}}, {{30}}, {{40}}, {{50}}, {{60}}});
  Literal expected =
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 30, 60}, {7, 20, 9}});

  EXPECT_TRUE(LiteralTestUtil::Equal(
      expected, EvaluateWithModule(hlo_module.get(),
                                   {&input, &indices, &replacements})));
}
2826 
TEST_F(HloEvaluatorTest, EvaluateScatter_OobUpdateWindow) {
  // The scatter index itself is in bounds, but the update window placed at
  // that index extends past the operand.
  const char* module_text = R"(
HloModule TensorFlowScatterNd_OobUpdateWindow

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  ROOT rhs = s32[] parameter(1)
}

ENTRY main {
  operand = s32[3,3,2] parameter(0)
  indices = s32[1,2] parameter(1)
  updates = s32[1,2,2] parameter(2)
  ROOT scatter = s32[3,3,2] scatter(operand, indices, updates),
      to_apply=update_s32,
      update_window_dims={1,2},
      inserted_window_dims={0},
      scatter_dims_to_operand_dims={0,1},
      index_vector_dim=1
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseAndReturnVerifiedModule(module_text));

  Literal input =
      LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}},  //
                                    {{-4, 4}, {-5, 5}, {-6, 6}},  //
                                    {{-7, 7}, {-8, 8}, {-9, 9}}});
  Literal indices = LiteralUtil::CreateR2<int32>({{0, 2}});
  Literal replacements =
      LiteralUtil::CreateR3<int32>({{{-10, 10}, {-40, 40}}});

  // With a 2,2 update window starting at index 0,2 the window runs out of
  // bounds, so the result must be identical to the untouched operand.
  Literal unchanged = input.Clone();
  EXPECT_TRUE(LiteralTestUtil::Equal(
      unchanged, EvaluateWithModule(hlo_module.get(),
                                    {&input, &indices, &replacements})));
}
2863 
2864 // Verifies that HloEvaluator evaluates a HLO instruction that performs
2865 // element-wise comparison with 2 bfloat16 operands.
TEST_F(HloEvaluatorTest, DoesCompareBF16) {
  // Builds a kGe (lhs >= rhs) comparison over two 2x3 bfloat16 constants and
  // checks the resulting predicate literal.
  auto lhs_values = LiteralUtil::CreateR2<bfloat16>(
      {{bfloat16(0.25), bfloat16(0.35), bfloat16(0.125)},
       {bfloat16(-0.25), bfloat16(-0.35), bfloat16(-0.125)}});
  auto rhs_values = LiteralUtil::CreateR2<bfloat16>(
      {{bfloat16(0.5), bfloat16(0.125), bfloat16(0.125)},
       {bfloat16(0.25), bfloat16(-0.375), bfloat16(-0.127)}});
  auto want =
      LiteralUtil::CreateR2<bool>({{false, true, true}, {false, true, true}});

  HloComputation::Builder builder(TestName());
  auto* lhs_constant = builder.AddInstruction(
      HloInstruction::CreateConstant(std::move(lhs_values)));
  auto* rhs_constant = builder.AddInstruction(
      HloInstruction::CreateConstant(std::move(rhs_values)));
  builder.AddInstruction(HloInstruction::CreateCompare(
      want.shape(), lhs_constant, rhs_constant, ComparisonDirection::kGe));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate());
  EXPECT_TRUE(LiteralTestUtil::Equal(want, actual));
}
2887 
TEST_P(HloEvaluatorBf16Test, Bf16Reduction) {
  // Sum-reduces a four-element bf16 vector; 1 + 3 - 2 + 42 should give 44.
  const string module_text = R"(
HloModule Bf16Reduction

add_bf16 (lhs: bf16[], rhs: bf16[]) -> bf16[] {
  lhs = bf16[] parameter(0)
  rhs = bf16[] parameter(1)
  ROOT add = bf16[] add(bf16[] lhs, bf16[] rhs)
}

ENTRY main {
  arg0 = bf16[4]{0} parameter(0)
  init = bf16[] constant(0)
  ROOT %reduce = bf16[] reduce(arg0, init), dimensions={0}, to_apply=add_bf16
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal input = LiteralUtil::CreateR1<bfloat16>(
      {bfloat16(1.0f), bfloat16(3.0f), bfloat16(-2.0f), bfloat16(42.0f)});
  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate({&input}));

  Literal want = LiteralUtil::CreateR0<bfloat16>(bfloat16(44.0f));
  EXPECT_TRUE(LiteralTestUtil::Equal(want, actual));
}
2912 
TEST_F(HloEvaluatorTest, DontFailOnCallUnimplementedOps) {
  // Infeed triggers unimplemented error within HandleCall, and we verify that
  // Evaluate() reports this as a non-OK status.
  const string module_text = R"(
HloModule DontFailOnCall

call {
  token0 = token[] after-all()
  ROOT infeed = ((u32[3]{0}, pred[]), token[]) infeed(token0)
}

ENTRY main {
  ROOT result = ((u32[3]{0}, pred[]), token[]) call(), to_apply=call
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  auto evaluation = Evaluate();
  EXPECT_FALSE(evaluation.status().ok());
}
2932 
TEST_F(HloEvaluatorTest, DontFailOnFusionWithUnimplementedOps) {
  // Infeed triggers unimplemented error within HandleFusion, and we verify that
  // Evaluate() reports this as a non-OK status.
  const string module_text = R"(
HloModule DontFailOnFusion

fused_computation {
  token0 = token[] after-all()
  ROOT infeed = ((u32[3]{0}, pred[]), token[]) infeed(token0)
}

ENTRY main {
  ROOT result = ((u32[3]{0}, pred[]), token[]) fusion(), kind=kLoop, calls=fused_computation
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  auto evaluation = Evaluate();
  EXPECT_FALSE(evaluation.status().ok());
}
2952 
TEST_P(HloEvaluatorBf16Test, SliceWithDifferentLayout) {
  // Regression test for b/114735354.
  // A full-range slice whose result layout differs from the operand layout
  // must still compare equal to the input literal.
  const string module_text = R"(
HloModule SliceWithDifferentLayout

ENTRY main {
  arg = f32[2,2,2]{0,1,2} parameter(0)
  ROOT %slice = f32[2,2,2]{1,0,2} slice(f32[2,2,2]{0,1,2} %arg), slice={[0:2], [0:2], [0:2]}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal input = LiteralUtil::CreateR3WithLayout<float>(
      {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}},
      LayoutUtil::MakeLayout({0, 1, 2}));
  TF_ASSERT_OK_AND_ASSIGN(Literal sliced, Evaluate({&input}));
  EXPECT_TRUE(LiteralTestUtil::Equal(input, sliced));
}
2971 
TEST_P(HloEvaluatorBf16Test, Bitcast) {
  // Regression test for b/114735354.
  // The module text is a format template; the three %s placeholders receive
  // the element type selected by use_bfloat16_.
  constexpr absl::string_view kModuleTemplate = R"(
HloModule Bitcast

ENTRY main {
  param = %s[32,121]{1,0} parameter(0)
  ROOT bitcast = %s[121,32,1]{0,1,2} bitcast(%s[32,121]{1,0} param)
}
)";
  const string module_text =
      use_bfloat16_
          ? absl::StrFormat(kModuleTemplate, "bf16", "bf16", "bf16")
          : absl::StrFormat(kModuleTemplate, "f32", "f32", "f32");
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  auto fake_args = MakeFakeArguments(m_.get()).ConsumeValueOrDie();
  TF_ASSERT_OK_AND_ASSIGN(Literal bitcasted, Evaluate({&fake_args[0]}));

  // The bitcast must leave the underlying element data unchanged.
  const bool same_data =
      use_bfloat16_ ? absl::c_equal(fake_args[0].data<bfloat16>(),
                                    bitcasted.data<bfloat16>())
                    : absl::c_equal(fake_args[0].data<float>(),
                                    bitcasted.data<float>());
  EXPECT_TRUE(same_data);
}
2998 
2999 // Check that s32 under/overflow doesn't trigger a ubsan failure.
TEST_F(HloEvaluatorTest, Int32Overflow) {
  // Evaluates an add, a subtract and a multiply on s32 constants whose exact
  // results do not fit in 32 bits, and checks the wrapped-around values.
  constexpr absl::string_view hlo_text = R"(
HloModule Test

ENTRY main {
  c1 = s32[] constant(1073741824)  // 2^30
  sum = s32[] add(c1, c1)  // 2^31, i.e. INT_MIN

  c2 = s32[] constant(-2147483648)  // -2^31
  sub = s32[] subtract(c2, c1)  // -2^31 - 2^30, underflows

  mul = s32[] multiply(c1, c1)
  ROOT tuple = (s32[], s32[], s32[]) tuple(sum, sub, mul)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
  TF_ASSERT_OK_AND_ASSIGN(auto literal, Evaluate({}));
  std::vector<Literal> actual = literal.DecomposeTuple();
  ASSERT_EQ(actual.size(), 3);

  // Expected values are computed in uint32, where wraparound is well defined,
  // and then converted to int32.
  uint32 pow30 = uint32{1} << 30;
  uint32 pow31 = uint32{1} << 31;
  EXPECT_EQ(actual[0].GetFirstElement<int32>(), static_cast<int32>(pow31));
  EXPECT_EQ(actual[1].GetFirstElement<int32>(),
            static_cast<int32>(-(pow31 + pow30)));
  // `mul` is c1 * c1, i.e. 2^30 * 2^30, so the expectation must be built from
  // pow30 to mirror the HLO above.
  EXPECT_EQ(actual[2].GetFirstElement<int32>(),
            static_cast<int32>(pow30 * pow30));
}
3028 
// Checks that get-dimension-size yields the value supplied through the bound
// dynamic-size parameter (3) rather than the static extent of the shape (4).
TEST_F(HloEvaluatorTest, GetDimensionSize) {
  constexpr absl::string_view hlo_text = R"(
HloModule Test

ENTRY main {
  size = u32[] parameter(0)

  data = s32[4] parameter(1)

  sum = s32[4] add(data, data)

  ROOT dynamic_size = u32[] get-dimension-size(sum), dimensions={0}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));

  // Set up dynamic parameter binding: parameter 0 is bound to dimension 0 of
  // parameter 1 (presumably `size` carries the runtime extent of `data`).
  // A failed bind is reported as a test failure instead of aborting the
  // whole test binary.
  TF_ASSERT_OK(m_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{0, {}},
      DynamicParameterBinding::DynamicDimension{1, {}, 0}));

  TF_ASSERT_OK_AND_ASSIGN(DynamicDimensionInference dynamic_dimension_inference,
                          DynamicDimensionInference::Run(m_.get()));

  evaluator_.set_dynamic_dimension_inference(&dynamic_dimension_inference);
  Literal size_arg = LiteralUtil::CreateR0<uint32>(3);
  Literal data_arg = LiteralUtil::CreateR1<int32>({1, 2, 3, 4});

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate({&size_arg, &data_arg}));

  EXPECT_EQ(actual.GetFirstElement<uint32>(), static_cast<uint32>(3));
}
3061 
3062 // Check that we get a useful error if we pass inputs of the wrong shape.
TEST_F(HloEvaluatorTest, EvaluateWithWrongInputShapes) {
  constexpr absl::string_view kModuleText = R"(
HloModule Test

ENTRY main {
  p0 = s32[1] parameter(0)
  ROOT sum = s32[1] add(p0, p0)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(kModuleText));

  // The entry computation takes s32[1]; hand it a two-element vector instead.
  Literal mismatched_arg = LiteralUtil::CreateR1<int32>({0, 1});

  // The module and the computation entry points must both report this error.
  constexpr char kExpectedError[] =
      "Shape mismatch at parameter 0. Computation expected s32[1]{0}, "
      "but arg was s32[2].";

  HloEvaluator module_evaluator;
  EXPECT_EQ(
      module_evaluator.Evaluate(*m_, {&mismatched_arg}).status().error_message(),
      kExpectedError);

  HloEvaluator computation_evaluator;
  EXPECT_EQ(computation_evaluator
                .Evaluate(*m_->entry_computation(), {&mismatched_arg})
                .status()
                .error_message(),
            kExpectedError);
}
3088 
3089 // Check that we get a useful error if we pass too many or too few inputs.
TEST_F(HloEvaluatorTest, EvaluateWithWrongNumberOfInputs) {
  constexpr absl::string_view hlo_text = R"(
HloModule Test

ENTRY main {
  p0 = s32[1] parameter(0)
  ROOT sum = s32[1] add(p0, p0)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
  Literal input = LiteralUtil::CreateR1<int32>({0});

  // Too many inputs: two arguments for a one-parameter computation, through
  // both the module and the computation entry points.
  EXPECT_EQ(
      HloEvaluator().Evaluate(*m_, {&input, &input}).status().error_message(),
      "Expected 1 argument, but got 2.");
  EXPECT_EQ(HloEvaluator()
                .Evaluate(*m_->entry_computation(), {&input, &input})
                .status()
                .error_message(),
            "Expected 1 argument, but got 2.");

  // Too few inputs: no arguments for a one-parameter computation.
  // NOTE(review): the message is assumed to follow the same format as the
  // too-many case above -- confirm against the evaluator implementation.
  EXPECT_EQ(HloEvaluator()
                .Evaluate(*m_->entry_computation(), {})
                .status()
                .error_message(),
            "Expected 1 argument, but got 0.");
}
3111 
// Evaluates a loop fusion whose parameter has layout {0,1} and whose root
// bitcasts it to layout {1,0}, then checks that the flat data of the result
// equals the flat data of the argument.
TEST_F(HloEvaluatorTest, PreserveFusionInputLayout) {
  constexpr absl::string_view kModuleText = R"(
    HloModule FusionInputLayout

    fused_computation {
      param_0 = f32[20,20]{0,1} parameter(0)
      ROOT bitcast = f32[20,20]{1,0} bitcast(param_0)
    }

    ENTRY kernel_entry {
      parameter.0 = f32[20,20]{0,1} parameter(0)
      ROOT fusion = f32[20,20]{1,0} fusion(parameter.0),
        kind=kLoop, calls=fused_computation
    })";

  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(kModuleText));
  auto fake_args = MakeFakeArguments(m_.get()).ConsumeValueOrDie();

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&fake_args[0]}));

  const auto input_values = fake_args[0].data<float>();
  const auto output_values = result.data<float>();
  EXPECT_TRUE(absl::c_equal(input_values, output_values));
}
3133 
// Evaluates a loop fusion whose parameter has layout {1,0} and whose root
// bitcasts it to layout {0,1}, then checks that the flat data of the result
// equals the flat data of the argument.
TEST_F(HloEvaluatorTest, PreserveFusionOutputLayout) {
  constexpr absl::string_view kModuleText = R"(
    HloModule FusionOutputLayout

    fused_computation {
      param_0 = f32[20,20]{1,0} parameter(0)
      ROOT bitcast = f32[20,20]{0,1} bitcast(param_0)
    }

    ENTRY kernel_entry {
      parameter.0 = f32[20,20]{1,0} parameter(0)
      ROOT fusion = f32[20,20]{0,1} fusion(parameter.0),
        kind=kLoop, calls=fused_computation
    })";

  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(kModuleText));
  auto fake_args = MakeFakeArguments(m_.get()).ConsumeValueOrDie();
  Literal& input = fake_args[0];
  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
  EXPECT_TRUE(absl::c_equal(input.data<float>(), result.data<float>()));
}
3154 
// Evaluates a loop fusion that returns a one-element tuple whose element is
// the parameter (layout {1,0}) bitcast to layout {0,1}, then checks that the
// flat data of that element equals the flat data of the argument.
TEST_F(HloEvaluatorTest, PreserveMOFusionOutputLayout) {
  constexpr absl::string_view hlo_text = R"(
    HloModule MOFusionOutputLayout

    fused_computation {
      param_0 = f32[20,20]{1,0} parameter(0)
      bitcast = f32[20,20]{0,1} bitcast(param_0)
      ROOT tuple = (f32[20,20]{0,1}) tuple(bitcast)
    }

    ENTRY kernel_entry {
      parameter.0 = f32[20,20]{1,0} parameter(0)
      ROOT fusion = (f32[20,20]{0,1}) fusion(parameter.0),
        kind=kLoop, calls=fused_computation
    })";

  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
  auto args = MakeFakeArguments(m_.get()).ConsumeValueOrDie();
  TF_ASSERT_OK_AND_ASSIGN(Literal actual_tuple, Evaluate({&args[0]}));
  std::vector<Literal> actual_literals = actual_tuple.DecomposeTuple();
  // Stop here if the tuple does not have exactly one element, so the index
  // below can never read out of bounds.
  ASSERT_EQ(actual_literals.size(), 1);
  EXPECT_TRUE(
      absl::c_equal(args[0].data<float>(), actual_literals[0].data<float>()));
}
3178 
3179 // Tests that custom_calls fail to evaluate when no handler is specified.
TEST_F(HloEvaluatorTest, EvaluateCustomCall_NoHandler) {
  constexpr absl::string_view kModuleText = R"(
    HloModule EvaluateCustomCall_NoHandler
    ENTRY kernel_entry {
      parameter.0 = u32[2,2]{1,0} parameter(0)
      ROOT test_root = (u32[2,2]{1,0}) custom-call(parameter.0),
          custom_call_target="_my_custom_call"
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(kModuleText));
  auto fake_args = MakeFakeArguments(m_.get()).ConsumeValueOrDie();

  // A fresh evaluator has no custom-call handler installed by this test.
  HloEvaluator evaluator;
  const Status status = evaluator.Evaluate(*m_, {&fake_args[0]}).status();
  EXPECT_EQ(status.code(), ::tensorflow::error::UNIMPLEMENTED);
}
3195 
3196 // Tests when a custom_call handler returns an error.
TEST_F(HloEvaluatorTest, EvaluateCustomCall_HandlerError) {
  constexpr absl::string_view kModuleText = R"(
    HloModule EvaluateCustomCall_HandlerError
    ENTRY kernel_entry {
      parameter.0 = u32[2,2]{1,0} parameter(0)
      ROOT test_root = (u32[2,2]{1,0}) custom-call(parameter.0),
          custom_call_target="_my_custom_call"
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(kModuleText));
  auto fake_args = MakeFakeArguments(m_.get()).ConsumeValueOrDie();

  // Install a handler that always fails; evaluation should report INTERNAL.
  HloEvaluator failing_evaluator;
  failing_evaluator.set_custom_call_handler(
      [](HloInstruction* /*custom_call*/,
         absl::Span<const Literal*> /*operands*/) {
        return InternalError("Test error");
      });

  const Status status =
      failing_evaluator.Evaluate(*m_, {&fake_args[0]}).status();
  EXPECT_EQ(status.code(), ::tensorflow::error::INTERNAL);
}
3217 
3218 // Tests the custom_call handler on calls with many inputs.
3219 // We sum the operands so that we can verify the operand and output literals
3220 // are properly mapped for access.
TEST_F(HloEvaluatorTest, EvaluateCustomCall_ManyInputs) {
  constexpr absl::string_view kModuleText = R"(
    HloModule EvaluateCustomCall_ManyInputs
    ENTRY kernel_entry {
      parameter.0 = u32[1]{0} parameter(0)
      parameter.1 = u32[1]{0} parameter(1)
      ROOT test_root = u32[1]{0} custom-call(parameter.0, parameter.1),
          custom_call_target="_my_custom_call"
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(kModuleText));
  auto fake_args = MakeFakeArguments(m_.get()).ConsumeValueOrDie();

  HloEvaluator summing_evaluator;
  summing_evaluator.set_custom_call_handler(
      [](HloInstruction* custom_call, absl::Span<const Literal*> operands) {
        // Check the instruction and operand list handed to the handler.
        EXPECT_EQ(HloOpcode::kCustomCall, custom_call->opcode());
        EXPECT_EQ("_my_custom_call", custom_call->custom_call_target());
        EXPECT_EQ(2, custom_call->operand_count());
        EXPECT_EQ(2, operands.size());

        // Produce operand0[0] + operand1[0] in a literal of the call's shape.
        auto sum_literal = Literal::CreateFromShape(custom_call->shape());
        auto lhs = operands[0]->data<uint32>();
        auto rhs = operands[1]->data<uint32>();
        sum_literal.data<uint32>()[0] = lhs[0] + rhs[0];
        return sum_literal;
      });

  TF_ASSERT_OK_AND_ASSIGN(
      Literal result,
      summing_evaluator.Evaluate(*m_->entry_computation(),
                                 {&fake_args[0], &fake_args[1]}));

  // Recompute the sum outside the handler and compare.
  auto first_values = fake_args[0].data<uint32>();
  auto second_values = fake_args[1].data<uint32>();
  std::vector<uint32> expected_sum = {first_values[0] + second_values[0]};
  EXPECT_TRUE(absl::c_equal(expected_sum, result.data<uint32>()));
}
3256 
// Evaluates is-finite over an f16 constant holding NaN, ordinary and infinite
// values; only the ordinary values (7 and -1) are expected to yield true.
TEST_F(HloEvaluatorTest, IsFiniteF16) {
  constexpr absl::string_view kModuleText = R"(
  HloModule test

  ENTRY IsFiniteTest {
    c = f16[6] constant({nan, 7, nan, -1, inf, -inf})
    ROOT is-finite = pred[6] is-finite(c)
  })";

  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(kModuleText));

  HloEvaluator evaluator;
  TF_ASSERT_OK_AND_ASSIGN(
      Literal finite_mask, evaluator.Evaluate(*m_->entry_computation(), {}));

  EXPECT_THAT(finite_mask.data<bool>(),
              ::testing::ElementsAre(false, true, false, true, false, false));
}
3273 
// Evaluates is-finite over a bf16 constant holding NaN, ordinary and infinite
// values; only the ordinary values (7 and -1) are expected to yield true.
TEST_F(HloEvaluatorTest, IsFiniteBf16) {
  constexpr absl::string_view kModuleText = R"(
  HloModule test

  ENTRY IsFiniteTest {
    c = bf16[6] constant({nan, 7, nan, -1, inf, -inf})
    ROOT is-finite = pred[6] is-finite(c)
  })";

  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(kModuleText));

  HloEvaluator evaluator;
  TF_ASSERT_OK_AND_ASSIGN(
      Literal finite_mask, evaluator.Evaluate(*m_->entry_computation(), {}));

  EXPECT_THAT(finite_mask.data<bool>(),
              ::testing::ElementsAre(false, true, false, true, false, false));
}
3290 
3291 }  // namespace
3292 }  // namespace xla
3293