1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Tests the select-and-scatter XLA operation.
17 
#include <limits>
#include <memory>
#include <vector>
20 
21 #include "tensorflow/compiler/xla/array2d.h"
22 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
23 #include "tensorflow/compiler/xla/client/local_client.h"
24 #include "tensorflow/compiler/xla/client/padding.h"
25 #include "tensorflow/compiler/xla/client/xla_builder.h"
26 #include "tensorflow/compiler/xla/client/xla_computation.h"
27 #include "tensorflow/compiler/xla/layout_util.h"
28 #include "tensorflow/compiler/xla/literal.h"
29 #include "tensorflow/compiler/xla/reference_util.h"
30 #include "tensorflow/compiler/xla/status_macros.h"
31 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
32 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
33 #include "tensorflow/compiler/xla/tests/test_macros.h"
34 #include "tensorflow/compiler/xla/xla_data.pb.h"
35 #include "tensorflow/core/platform/test.h"
36 #include "tensorflow/core/platform/types.h"
37 
38 namespace xla {
39 namespace {
40 
41 struct SelectAndScatterTestParam {
42   std::vector<int64> operand_shape;
43   std::vector<int64> source_shape;
44   Padding padding_type;
45   absl::Span<const int64> window_dimensions;
46   absl::Span<const int64> window_strides;
47 };
48 
49 class SelectAndScatterTest
50     : public ClientLibraryTestBase,
51       public ::testing::WithParamInterface<SelectAndScatterTestParam> {
52  public:
SelectAndScatterTest()53   SelectAndScatterTest() : builder_(TestName()) {
54     // Create S32 GE and ADD computations for select and scatter respectively.
55     ge_s32_ = CreateScalarGeComputation(S32, &builder_);
56     add_s32_ = CreateScalarAddComputation(S32, &builder_);
57     ge_f32_ = CreateScalarGeComputation(F32, &builder_);
58     add_f32_ = CreateScalarAddComputation(F32, &builder_);
59     max_f32_ = CreateScalarMaxComputation(F32, &builder_);
60     min_f32_ = CreateScalarMinComputation(F32, &builder_);
61   }
62 
63   XlaBuilder builder_;
64   XlaComputation ge_s32_;
65   XlaComputation add_s32_;
66   XlaComputation ge_f32_;
67   XlaComputation add_f32_;
68   XlaComputation max_f32_;
69   XlaComputation min_f32_;
70 };
71 
// Builds a SelectAndScatter (GE select, ADD scatter, init value 0) over
// randomly filled F32 operand and source arrays whose shapes, window, strides
// and padding come from the test parameter, then hands the computation to
// ComputeAndCompare with a 1e-5 error spec.
XLA_TEST_P(SelectAndScatterTest, ParamTest) {
  const SelectAndScatterTestParam& param = GetParam();

  Array<float> operand_values(param.operand_shape);
  operand_values.FillRandom(1.5f);
  const XlaOp operand = ConstantFromArray(&builder_, operand_values);

  Array<float> source_values(param.source_shape);
  source_values.FillRandom(12.0f);
  const XlaOp source = ConstantFromArray(&builder_, source_values);

  const XlaOp init_value = ConstantR0<float>(&builder_, 0.0f);
  SelectAndScatter(operand, ge_f32_, param.window_dimensions,
                   param.window_strides, param.padding_type, source,
                   init_value, add_f32_);

  ComputeAndCompare(&builder_, {}, ErrorSpec(1e-5));
}
89 
90 INSTANTIATE_TEST_CASE_P(
91     SelectAndScatterTest_Instantiation, SelectAndScatterTest,
92     ::testing::Values(
93         SelectAndScatterTestParam{{6, 6, 6, 4, 4},
94                                   {3, 3, 3, 4, 4},
95                                   Padding::kSame,
96                                   {3, 3, 3, 1, 1},
97                                   {2, 2, 2, 1, 1}},
98         SelectAndScatterTestParam{{7, 7, 7, 4, 4},
99                                   {3, 3, 3, 4, 4},
100                                   Padding::kValid,
101                                   {3, 3, 3, 1, 1},
102                                   {2, 2, 2, 1, 1}},
103 
104         SelectAndScatterTestParam{{8, 8, 8, 4, 4},
105                                   {1, 3, 3, 4, 4},
106                                   Padding::kValid,
107                                   {8, 4, 4, 1, 1},
108                                   {1, 2, 2, 1, 1}},
109         SelectAndScatterTestParam{{6, 6, 256, 128},
110                                   {3, 3, 256, 128},
111                                   Padding::kSame,
112                                   {3, 3, 1, 1},
113                                   {2, 2, 1, 1}},
114         SelectAndScatterTestParam{{7, 7, 256, 128},
115                                   {3, 3, 256, 128},
116                                   Padding::kValid,
117                                   {3, 3, 1, 1},
118                                   {2, 2, 1, 1}},
119         SelectAndScatterTestParam{{6, 7, 256, 128},
120                                   {3, 3, 256, 128},
121                                   Padding::kValid,
122                                   {2, 3, 1, 1},
123                                   {2, 2, 1, 1}},
124         SelectAndScatterTestParam{{6, 7, 256, 128},
125                                   {2, 3, 256, 128},
126                                   Padding::kValid,
127                                   {2, 3, 1, 1},
128                                   {3, 2, 1, 1}},
129         SelectAndScatterTestParam{{9, 9, 16, 128},
130                                   {3, 3, 16, 128},
131                                   Padding::kValid,
132                                   {3, 3, 1, 1},
133                                   {3, 3, 1, 1}},
134         SelectAndScatterTestParam{{3, 3, 4, 4},
135                                   {1, 1, 4, 4},
136                                   Padding::kValid,
137                                   {3, 3, 1, 1},
138                                   {3, 3, 1, 1}},
139         SelectAndScatterTestParam{{3, 3, 4, 4},
140                                   {1, 1, 4, 4},
141                                   Padding::kValid,
142                                   {3, 3, 1, 1},
143                                   {3, 3, 1, 1}},
144         SelectAndScatterTestParam{{9, 3, 4, 4},
145                                   {3, 1, 4, 4},
146                                   Padding::kValid,
147                                   {3, 3, 1, 1},
148                                   {3, 3, 1, 1}},
149         // Uncovered by b/126212776.
150         SelectAndScatterTestParam{{15, 1, 1, 1},
151                                   {2, 1, 1, 1},
152                                   Padding::kValid,
153                                   {14, 1, 1, 1},
154                                   {1, 1, 1, 1}},
155         SelectAndScatterTestParam{{7, 3, 4, 4},
156                                   {3, 1, 4, 4},
157                                   Padding::kValid,
158                                   {3, 3, 1, 1},
159                                   {2, 3, 1, 1}},
160         SelectAndScatterTestParam{{1, 1, 5, 5},
161                                   {1, 1, 5, 5},
162                                   Padding::kSame,
163                                   {3, 3, 1, 1},
164                                   {3, 3, 1, 1}},
165         SelectAndScatterTestParam{{7, 7, 8, 256},
166                                   {4, 4, 8, 256},
167                                   Padding::kSame,
168                                   {2, 2, 1, 1},
169                                   {2, 2, 1, 1}},
170         SelectAndScatterTestParam{
171             {6, 4, 4}, {3, 4, 4}, Padding::kSame, {3, 1, 1}, {2, 1, 1}},
172         SelectAndScatterTestParam{
173             {6, 256, 128}, {3, 256, 128}, Padding::kSame, {3, 1, 1}, {2, 1, 1}},
174         SelectAndScatterTestParam{{7, 256, 128},
175                                   {3, 256, 128},
176                                   Padding::kValid,
177                                   {3, 1, 1},
178                                   {2, 1, 1}},
179         SelectAndScatterTestParam{{6, 256, 128},
180                                   {3, 256, 128},
181                                   Padding::kValid,
182                                   {2, 1, 1},
183                                   {2, 1, 1}},
184         SelectAndScatterTestParam{{6, 256, 128},
185                                   {2, 256, 128},
186                                   Padding::kValid,
187                                   {2, 1, 1},
188                                   {3, 1, 1}},
189         SelectAndScatterTestParam{
190             {9, 16, 128}, {3, 16, 128}, Padding::kValid, {3, 1, 1}, {3, 1, 1}},
191         SelectAndScatterTestParam{
192             {3, 4, 4}, {1, 4, 4}, Padding::kValid, {3, 1, 1}, {3, 1, 1}},
193         SelectAndScatterTestParam{
194             {3, 4, 4}, {1, 4, 4}, Padding::kValid, {3, 1, 1}, {3, 1, 1}},
195         SelectAndScatterTestParam{
196             {9, 4, 4}, {3, 4, 4}, Padding::kValid, {3, 1, 1}, {3, 1, 1}},
197         SelectAndScatterTestParam{
198             {7, 4, 4}, {3, 4, 4}, Padding::kValid, {3, 1, 1}, {2, 1, 1}},
199         SelectAndScatterTestParam{
200             {1, 5, 5}, {1, 5, 5}, Padding::kSame, {3, 1, 1}, {3, 1, 1}},
201         SelectAndScatterTestParam{
202             {7, 8, 256}, {4, 8, 256}, Padding::kSame, {2, 1, 1}, {2, 1, 1}}));
203 
204 // Test for F32 1D array, with a zero-element input.
XLA_TEST_F(SelectAndScatterTest,R1S0F32)205 XLA_TEST_F(SelectAndScatterTest, R1S0F32) {
206   const auto operand = ConstantR1<float>(&builder_, {});
207   const auto source = ConstantR1<float>(&builder_, {});
208   SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{3},
209                    /*window_strides=*/{3}, Padding::kValid, source,
210                    ConstantR0<float>(&builder_, 0.0f), add_f32_);
211   ComputeAndCompareR1<float>(&builder_, {}, {}, ErrorSpec(1e-7));
212 }
213 
214 // Test for F32 1D array, when windows do not overlap.
XLA_TEST_F(SelectAndScatterTest,R1F32)215 XLA_TEST_F(SelectAndScatterTest, R1F32) {
216   const auto operand =
217       ConstantR1<float>(&builder_, {1.f, 9.f, 3.f, 7.f, 5.f, 6.f});
218   const auto source = ConstantR1<float>(&builder_, {34.f, 42.f});
219   const std::vector<float> expected = {0.f, 34.f, 0.f, 42.f, 0.f, 0.f};
220   SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{3},
221                    /*window_strides=*/{3}, Padding::kValid, source,
222                    ConstantR0<float>(&builder_, 0.0f), add_f32_);
223   ComputeAndCompareR1<float>(&builder_, expected, {}, ErrorSpec(1e-7));
224 }
225 
226 // Test for S32 1D array, when windows do not overlap and the init value is 1.
XLA_TEST_F(SelectAndScatterTest,R1S32)227 XLA_TEST_F(SelectAndScatterTest, R1S32) {
228   const auto operand = ConstantR1<int32>(&builder_, {-1, 0, 6, 4, -4, 10});
229   const auto source = ConstantR1<int32>(&builder_, {-10, 20});
230   const std::vector<int32> expected = {1, 1, -9, 1, 1, 21};
231   SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{3},
232                    /*window_strides=*/{3}, Padding::kValid, source,
233                    ConstantR0<int32>(&builder_, 1), add_s32_);
234   ComputeAndCompareR1<int32>(&builder_, expected, {});
235 }
236 
237 // Test for S32 1D array, when windows overlap with each other.
XLA_TEST_F(SelectAndScatterTest,R1S32OverlappingWindow)238 XLA_TEST_F(SelectAndScatterTest, R1S32OverlappingWindow) {
239   const auto operand = ConstantR1<int32>(&builder_, {1, 9, 3, 7, 5, 6});
240   const auto source = ConstantR1<int32>(&builder_, {34, 42, 53, 19});
241   const std::vector<int32> expected = {0, 76, 0, 72, 0, 0};
242   SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{3},
243                    /*window_strides=*/{1}, Padding::kValid, source,
244                    ConstantR0<int32>(&builder_, 0), add_s32_);
245   ComputeAndCompareR1<int32>(&builder_, expected, {});
246 }
247 
248 // Test for S32 2D array, when windows do not overlap.
XLA_TEST_F(SelectAndScatterTest,R2S32)249 XLA_TEST_F(SelectAndScatterTest, R2S32) {
250   const auto operand =
251       ConstantR2<int32>(&builder_, {{7, 2, 5, 3, 10, 2}, {3, 8, 9, 3, 4, 2}});
252   const auto source = ConstantR2<int32>(&builder_, {{2, 6}});
253   Array2D<int32> expected({{0, 0, 0, 0, 6, 0}, {0, 0, 2, 0, 0, 0}});
254   SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{2, 3},
255                    /*window_strides=*/{2, 3}, Padding::kValid, source,
256                    ConstantR0<int32>(&builder_, 0), add_s32_);
257   ComputeAndCompareR2<int32>(&builder_, expected, {});
258 }
259 
260 // Test for tie breaking rule in ge_f32_. When a tie is present, the operand
261 // that has the lower lexicographical order (smaller index) should be chosen.
XLA_TEST_F(SelectAndScatterTest,R2F32Tie)262 XLA_TEST_F(SelectAndScatterTest, R2F32Tie) {
263   const auto operand = ConstantR2<float>(
264       &builder_, {{0.f, 0.f, 0.f}, {0.f, 0.f, 0.f}, {0.f, 0.f, 0.f}});
265   const auto source = ConstantR2<float>(
266       &builder_, {{1.0f, 2.0f, 3.0f}, {4.f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}});
267   Array2D<float> expected(
268       {{12.f, 9.f, 0.f}, {15.f, 9.f, 0.f}, {0.f, 0.f, 0.f}});
269   SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{3, 3},
270                    /*window_strides=*/{1, 1}, Padding::kSame, source,
271                    ConstantR0<float>(&builder_, 0.0f), add_f32_);
272   ComputeAndCompareR2<float>(&builder_, expected, {}, ErrorSpec(1e-7));
273 }
274 
275 // Similar to SelectAndScatterTest.R2S32 but the input is transposed.
XLA_TEST_F(SelectAndScatterTest,ReshapeR2S32)276 XLA_TEST_F(SelectAndScatterTest, ReshapeR2S32) {
277   const auto operand = ConstantR2<int32>(
278       &builder_, {{7, 3}, {2, 8}, {5, 9}, {3, 3}, {10, 4}, {2, 2}});
279   const auto reshape =
280       Reshape(operand, /*dimensions=*/{1, 0}, /*new_sizes=*/{2, 6});
281   const auto source = ConstantR2<int32>(&builder_, {{2, 6}});
282   Array2D<int32> expected({{0, 0, 0, 0, 6, 0}, {0, 0, 2, 0, 0, 0}});
283   SelectAndScatter(reshape, ge_s32_, /*window_dimensions=*/{2, 3},
284                    /*window_strides=*/{2, 3}, Padding::kValid, source,
285                    ConstantR0<int32>(&builder_, 0), add_s32_);
286   ComputeAndCompareR2<int32>(&builder_, expected, {});
287 }
288 
289 // Test for S32 2D array, when windows overlap with each other.
XLA_TEST_F(SelectAndScatterTest,R2S32OverlappingWindow)290 XLA_TEST_F(SelectAndScatterTest, R2S32OverlappingWindow) {
291   const auto operand =
292       ConstantR2<int32>(&builder_, {{7, 2, 5, 3, 8}, {3, 8, 9, 3, 4}});
293   const auto source = ConstantR2<int32>(&builder_, {{2, 6, 4}});
294   Array2D<int32> expected({{0, 0, 0, 0, 0}, {0, 0, 12, 0, 0}});
295   SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{2, 3},
296                    /*window_strides=*/{1, 1}, Padding::kValid, source,
297                    ConstantR0<int32>(&builder_, 0), add_s32_);
298   ComputeAndCompareR2<int32>(&builder_, expected, {});
299 }
300 
301 // Test for S32 2D array, when the padding is Padding::kSAME.
XLA_TEST_F(SelectAndScatterTest,R2S32SamePadding)302 XLA_TEST_F(SelectAndScatterTest, R2S32SamePadding) {
303   const auto operand =
304       ConstantR2<int32>(&builder_, {{7, 2, 5, 3, 8}, {3, 8, 9, 3, 4}});
305   const auto source = ConstantR2<int32>(&builder_, {{2, 6, 4}});
306   Array2D<int32> expected({{0, 0, 0, 0, 4}, {0, 2, 6, 0, 0}});
307   SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{2, 2},
308                    /*window_strides=*/{2, 2}, Padding::kSame, source,
309                    ConstantR0<int32>(&builder_, 0), add_s32_);
310   ComputeAndCompareR2<int32>(&builder_, expected, {});
311 }
312 
313 // Test for S32 2D array, when the padding is Padding::kSAME and windows overlap
314 // with each other.
XLA_TEST_F(SelectAndScatterTest,R2S32SamePaddingOverlappingWindow)315 XLA_TEST_F(SelectAndScatterTest, R2S32SamePaddingOverlappingWindow) {
316   const auto operand =
317       ConstantR2<int32>(&builder_, {{7, 2, 5, 3, 8}, {3, 8, 9, 3, 4}});
318   const auto source =
319       ConstantR2<int32>(&builder_, {{2, 6, 4, 7, 1}, {3, 5, 8, 9, 10}});
320   Array2D<int32> expected({{0, 0, 0, 0, 8}, {0, 5, 23, 0, 19}});
321   SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{2, 2},
322                    /*window_strides=*/{1, 1}, Padding::kSame, source,
323                    ConstantR0<int32>(&builder_, 0), add_s32_);
324   ComputeAndCompareR2<int32>(&builder_, expected, {});
325 }
326 
// F32, rank 2, window 2x2 / stride 1x1 over a 3x3 operand with a 2x2 source.
// The windows overlap, but each one selects a different element, so every
// source value appears unchanged in the output.
XLA_TEST_F(SelectAndScatterTest, R2F32OverlappingR2Source) {
  const XlaOp input = ConstantR2<float>(
      &builder_, {{1.5f, 2.5f, 1.5f}, {3.5f, 1.5f, 3.5f}, {4.5f, 2.5f, 4.5f}});
  const XlaOp updates =
      ConstantR2<float>(&builder_, {{1.0f, 2.0f}, {3.0f, 4.0f}});
  const XlaOp init_value = ConstantR0<float>(&builder_, 0.0f);

  SelectAndScatter(input, ge_f32_, /*window_dimensions=*/{2, 2},
                   /*window_strides=*/{1, 1}, Padding::kValid, updates,
                   init_value, add_f32_);

  const Array2D<float> want = {
      {0.0f, 0.0f, 0.0f}, {1.0f, 0.0f, 2.0f}, {3.0f, 0.0f, 4.0f}};
  ComputeAndCompareR2<float>(&builder_, want, {}, ErrorSpec(1e-7));
}
339 
// F32, rank 4: a 4x6 pattern with non-overlapping 2x3 windows (stride 2x3,
// valid padding) over the first two dimensions; the window is 1x1 in the two
// trailing dimensions (15 x 220).
// NOTE(review): FillWithPZ presumably replicates the given 2D pattern across
// the two trailing dimensions -- confirm against Array4D.
TEST_F(SelectAndScatterTest, R4F32Valid) {
  // Operand pattern.
  const Array2D<float> pzo = {{7.0f, 2.0f, 5.0f, 3.0f, 10.0f, 2.0f},
                              {3.0f, 8.0f, 9.0f, 3.0f, 4.00f, 2.0f},
                              {1.0f, 5.0f, 7.0f, 5.0f, 6.00f, 1.0f},
                              {0.0f, 6.0f, 2.0f, 7.0f, 2.00f, 8.0f}};
  // Source pattern: one value per window.
  const Array2D<float> pzs = {{2.0f, 6.0f}, {3.0f, 1.0f}};
  // Expected pattern: each source value sits at its window's maximum.
  const Array2D<float> pze = {{0.0f, 0.0f, 0.0f, 0.0f, 6.0f, 0.0f},
                              {0.0f, 0.0f, 2.0f, 0.0f, 0.0f, 0.0f},
                              {0.0f, 0.0f, 3.0f, 0.0f, 0.0f, 0.0f},
                              {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f}};

  Array4D<float> o(4, 6, 15, 220);
  o.FillWithPZ(pzo);
  auto operand = ConstantR4FromArray4D(&builder_, o);

  Array4D<float> e(4, 6, 15, 220);
  e.FillWithPZ(pze);

  // The source array is filled once, before the constant is created from it.
  Array4D<float> s(2, 2, 15, 220);
  s.FillWithPZ(pzs);
  auto source = ConstantR4FromArray4D(&builder_, s);

  SelectAndScatter(operand, ge_f32_, {2, 3, 1, 1}, {2, 3, 1, 1},
                   Padding::kValid, source, ConstantR0<float>(&builder_, 0.0f),
                   add_f32_);
  ComputeAndCompareR4<float>(&builder_, e, {}, ErrorSpec(1e-7));
}
364 
// F32, rank 4: a 4x5 pattern with 2x3 windows at stride 2x2 (valid padding),
// so windows overlap along the second dimension; the window is 1x1 in the two
// trailing dimensions (17 x 128).
// NOTE(review): FillWithPZ presumably replicates the given 2D pattern across
// the two trailing dimensions -- confirm against Array4D.
TEST_F(SelectAndScatterTest, R4F32Overlap) {
  // Operand pattern.
  const Array2D<float> pzo = {{7.0f, 2.0f, 5.0f, 3.0f, 8.0f},
                              {3.0f, 8.0f, 9.0f, 3.0f, 4.0f},
                              {1.0f, 5.0f, 7.0f, 5.0f, 6.0f},
                              {0.0f, 6.0f, 2.0f, 10.0f, 2.0f}};
  // Source pattern: one value per window.
  const Array2D<float> pzs = {{2.0f, 6.0f}, {3.0f, 1.0f}};
  // Expected pattern: 8 = 2 + 6, because both top windows select the 9.
  const Array2D<float> pze = {{0.0f, 0.0f, 0.0f, 0.0f, 0.0f},
                              {0.0f, 0.0f, 8.0f, 0.0f, 0.0f},
                              {0.0f, 0.0f, 3.0f, 0.0f, 0.0f},
                              {0.0f, 0.0f, 0.0f, 1.0f, 0.0f}};

  Array4D<float> o(4, 5, 17, 128);
  o.FillWithPZ(pzo);
  auto operand = ConstantR4FromArray4D(&builder_, o);

  Array4D<float> e(4, 5, 17, 128);
  e.FillWithPZ(pze);

  // The source array is filled once, before the constant is created from it.
  Array4D<float> s(2, 2, 17, 128);
  s.FillWithPZ(pzs);
  auto source = ConstantR4FromArray4D(&builder_, s);

  SelectAndScatter(operand, ge_f32_, {2, 3, 1, 1}, {2, 2, 1, 1},
                   Padding::kValid, source, ConstantR0<float>(&builder_, 0.0f),
                   add_f32_);
  ComputeAndCompareR4<float>(&builder_, e, {}, ErrorSpec(1e-7));
}
389 
// Same patterns, window and strides as R4F32Overlap, but with trailing
// dimensions of size 1 x 1.
// NOTE(review): FillWithPZ presumably replicates the given 2D pattern across
// the two trailing dimensions -- confirm against Array4D.
TEST_F(SelectAndScatterTest, R4F32OverlapSmall) {
  // Operand pattern.
  const Array2D<float> pzo = {{7.0f, 2.0f, 5.0f, 3.0f, 8.0f},
                              {3.0f, 8.0f, 9.0f, 3.0f, 4.0f},
                              {1.0f, 5.0f, 7.0f, 5.0f, 6.0f},
                              {0.0f, 6.0f, 2.0f, 10.0f, 2.0f}};
  // Source pattern: one value per window.
  const Array2D<float> pzs = {{2.0f, 6.0f}, {3.0f, 1.0f}};
  // Expected pattern: 8 = 2 + 6, because both top windows select the 9.
  const Array2D<float> pze = {{0.0f, 0.0f, 0.0f, 0.0f, 0.0f},
                              {0.0f, 0.0f, 8.0f, 0.0f, 0.0f},
                              {0.0f, 0.0f, 3.0f, 0.0f, 0.0f},
                              {0.0f, 0.0f, 0.0f, 1.0f, 0.0f}};

  Array4D<float> o(4, 5, 1, 1);
  o.FillWithPZ(pzo);
  auto operand = ConstantR4FromArray4D(&builder_, o);

  Array4D<float> e(4, 5, 1, 1);
  e.FillWithPZ(pze);

  // The source array is filled once, before the constant is created from it.
  Array4D<float> s(2, 2, 1, 1);
  s.FillWithPZ(pzs);
  auto source = ConstantR4FromArray4D(&builder_, s);

  SelectAndScatter(operand, ge_f32_, {2, 3, 1, 1}, {2, 2, 1, 1},
                   Padding::kValid, source, ConstantR0<float>(&builder_, 0.0f),
                   add_f32_);
  ComputeAndCompareR4<float>(&builder_, e, {}, ErrorSpec(1e-7));
}
414 
// Exercises ReferenceUtil::SelectAndScatter4DGePlus: the expected value is not
// hard-coded but computed by the reference util from the same operand and
// source arrays, using the same window and strides as the XLA op.
TEST_F(SelectAndScatterTest, R4F32RefValidFixedSmall) {
  // Operand pattern.
  const Array2D<float> pzo = {{7.0f, 2.0f, 5.0f, 3.0f, 10.0f, 2.0f},
                              {3.0f, 8.0f, 9.0f, 3.0f, 4.00f, 2.0f},
                              {1.0f, 5.0f, 7.0f, 5.0f, 6.00f, 1.0f},
                              {0.0f, 6.0f, 2.0f, 7.0f, 2.00f, 8.0f}};
  // Source pattern: one value per window.
  const Array2D<float> pzs = {{2.0f, 6.0f}, {3.0f, 1.0f}};

  Array4D<float> o(4, 6, 4, 4);
  o.FillWithPZ(pzo);
  auto operand = ConstantR4FromArray4D(&builder_, o);

  // Filled once; the same contents feed both the XLA constant and the
  // reference computation below.
  Array4D<float> s(2, 2, 4, 4);
  s.FillWithPZ(pzs);
  auto source = ConstantR4FromArray4D(&builder_, s);

  SelectAndScatter(operand, ge_f32_, {2, 3, 1, 1}, {2, 3, 1, 1},
                   Padding::kValid, source, ConstantR0<float>(&builder_, 0.0f),
                   add_f32_);
  auto e = ReferenceUtil::SelectAndScatter4DGePlus(o, s, 0.0f, {2, 3, 1, 1},
                                                   {2, 3, 1, 1}, false);
  ComputeAndCompareR4<float>(&builder_, *e, {}, ErrorSpec(1e-7));
}
437 
// F32, rank 1, window 4 / stride 1 with MAX as the scatter function. All four
// windows contain the 100, so that position receives the largest source value
// (53) and everything else keeps the init value 0.
XLA_TEST_F(SelectAndScatterTest, R1F32OverlappingWindowMaxScatter) {
  const XlaOp input = ConstantR1<float>(&builder_, {1, 2, 3, 100, 3, 2, 1});
  const XlaOp updates = ConstantR1<float>(&builder_, {34, 42, 53, 19});
  const XlaOp init_value = ConstantR0<float>(&builder_, 0);

  SelectAndScatter(input, ge_f32_, /*window_dimensions=*/{4},
                   /*window_strides=*/{1}, Padding::kValid, updates,
                   init_value, max_f32_);

  const std::vector<float> want = {0, 0, 0, 53, 0, 0, 0};
  ComputeAndCompareR1<float>(&builder_, want, {}, ErrorSpec(1e-7));
}
447 
// F32, rank 1, window 4 / stride 1 with MIN as the scatter function. The init
// value is the largest finite float, so unselected positions keep it, while
// the position of the 100 (selected by all four windows) receives the
// smallest source value (19).
XLA_TEST_F(SelectAndScatterTest, R1F32OverlappingWindowMinScatter) {
  // Requires <limits>.
  constexpr float kMaxFloat = std::numeric_limits<float>::max();

  const auto operand = ConstantR1<float>(&builder_, {1, 2, 3, 100, 3, 2, 1});
  const auto source = ConstantR1<float>(&builder_, {34, 42, 53, 19});
  const std::vector<float> expected = {kMaxFloat, kMaxFloat, kMaxFloat, 19,
                                       kMaxFloat, kMaxFloat, kMaxFloat};
  SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{4},
                   /*window_strides=*/{1}, Padding::kValid, source,
                   ConstantR0<float>(&builder_, kMaxFloat), min_f32_);
  ComputeAndCompareR1<float>(&builder_, expected, {}, ErrorSpec(1e-7));
}
459 
460 }  // namespace
461 }  // namespace xla
462