1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // Tests the select-and-scatter XLA operation.
17
#include <limits>
#include <memory>
#include <vector>

#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/padding.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
37
38 namespace xla {
39 namespace {
40
41 struct SelectAndScatterTestParam {
42 std::vector<int64> operand_shape;
43 std::vector<int64> source_shape;
44 Padding padding_type;
45 absl::Span<const int64> window_dimensions;
46 absl::Span<const int64> window_strides;
47 };
48
49 class SelectAndScatterTest
50 : public ClientLibraryTestBase,
51 public ::testing::WithParamInterface<SelectAndScatterTestParam> {
52 public:
SelectAndScatterTest()53 SelectAndScatterTest() : builder_(TestName()) {
54 // Create S32 GE and ADD computations for select and scatter respectively.
55 ge_s32_ = CreateScalarGeComputation(S32, &builder_);
56 add_s32_ = CreateScalarAddComputation(S32, &builder_);
57 ge_f32_ = CreateScalarGeComputation(F32, &builder_);
58 add_f32_ = CreateScalarAddComputation(F32, &builder_);
59 max_f32_ = CreateScalarMaxComputation(F32, &builder_);
60 min_f32_ = CreateScalarMinComputation(F32, &builder_);
61 }
62
63 XlaBuilder builder_;
64 XlaComputation ge_s32_;
65 XlaComputation add_s32_;
66 XlaComputation ge_f32_;
67 XlaComputation add_f32_;
68 XlaComputation max_f32_;
69 XlaComputation min_f32_;
70 };
71
// Runs SelectAndScatter (GE select, ADD scatter, init 0) on randomly filled
// operand and source arrays described by the current parameter set. No
// expected value is spelled out here: ComputeAndCompare checks the result
// against ClientLibraryTestBase's reference computation.
XLA_TEST_P(SelectAndScatterTest, ParamTest) {
  const SelectAndScatterTestParam& param = GetParam();

  Array<float> operand_values(param.operand_shape);
  operand_values.FillRandom(1.5f);
  XlaOp operand = ConstantFromArray(&builder_, operand_values);

  Array<float> source_values(param.source_shape);
  source_values.FillRandom(12.0f);
  XlaOp source = ConstantFromArray(&builder_, source_values);

  XlaOp init = ConstantR0<float>(&builder_, 0.0f);
  SelectAndScatter(operand, ge_f32_, param.window_dimensions,
                   param.window_strides, param.padding_type, source, init,
                   add_f32_);

  ComputeAndCompare(&builder_, {}, ErrorSpec(1e-5));
}
89
90 INSTANTIATE_TEST_CASE_P(
91 SelectAndScatterTest_Instantiation, SelectAndScatterTest,
92 ::testing::Values(
93 SelectAndScatterTestParam{{6, 6, 6, 4, 4},
94 {3, 3, 3, 4, 4},
95 Padding::kSame,
96 {3, 3, 3, 1, 1},
97 {2, 2, 2, 1, 1}},
98 SelectAndScatterTestParam{{7, 7, 7, 4, 4},
99 {3, 3, 3, 4, 4},
100 Padding::kValid,
101 {3, 3, 3, 1, 1},
102 {2, 2, 2, 1, 1}},
103
104 SelectAndScatterTestParam{{8, 8, 8, 4, 4},
105 {1, 3, 3, 4, 4},
106 Padding::kValid,
107 {8, 4, 4, 1, 1},
108 {1, 2, 2, 1, 1}},
109 SelectAndScatterTestParam{{6, 6, 256, 128},
110 {3, 3, 256, 128},
111 Padding::kSame,
112 {3, 3, 1, 1},
113 {2, 2, 1, 1}},
114 SelectAndScatterTestParam{{7, 7, 256, 128},
115 {3, 3, 256, 128},
116 Padding::kValid,
117 {3, 3, 1, 1},
118 {2, 2, 1, 1}},
119 SelectAndScatterTestParam{{6, 7, 256, 128},
120 {3, 3, 256, 128},
121 Padding::kValid,
122 {2, 3, 1, 1},
123 {2, 2, 1, 1}},
124 SelectAndScatterTestParam{{6, 7, 256, 128},
125 {2, 3, 256, 128},
126 Padding::kValid,
127 {2, 3, 1, 1},
128 {3, 2, 1, 1}},
129 SelectAndScatterTestParam{{9, 9, 16, 128},
130 {3, 3, 16, 128},
131 Padding::kValid,
132 {3, 3, 1, 1},
133 {3, 3, 1, 1}},
134 SelectAndScatterTestParam{{3, 3, 4, 4},
135 {1, 1, 4, 4},
136 Padding::kValid,
137 {3, 3, 1, 1},
138 {3, 3, 1, 1}},
139 SelectAndScatterTestParam{{3, 3, 4, 4},
140 {1, 1, 4, 4},
141 Padding::kValid,
142 {3, 3, 1, 1},
143 {3, 3, 1, 1}},
144 SelectAndScatterTestParam{{9, 3, 4, 4},
145 {3, 1, 4, 4},
146 Padding::kValid,
147 {3, 3, 1, 1},
148 {3, 3, 1, 1}},
149 // Uncovered by b/126212776.
150 SelectAndScatterTestParam{{15, 1, 1, 1},
151 {2, 1, 1, 1},
152 Padding::kValid,
153 {14, 1, 1, 1},
154 {1, 1, 1, 1}},
155 SelectAndScatterTestParam{{7, 3, 4, 4},
156 {3, 1, 4, 4},
157 Padding::kValid,
158 {3, 3, 1, 1},
159 {2, 3, 1, 1}},
160 SelectAndScatterTestParam{{1, 1, 5, 5},
161 {1, 1, 5, 5},
162 Padding::kSame,
163 {3, 3, 1, 1},
164 {3, 3, 1, 1}},
165 SelectAndScatterTestParam{{7, 7, 8, 256},
166 {4, 4, 8, 256},
167 Padding::kSame,
168 {2, 2, 1, 1},
169 {2, 2, 1, 1}},
170 SelectAndScatterTestParam{
171 {6, 4, 4}, {3, 4, 4}, Padding::kSame, {3, 1, 1}, {2, 1, 1}},
172 SelectAndScatterTestParam{
173 {6, 256, 128}, {3, 256, 128}, Padding::kSame, {3, 1, 1}, {2, 1, 1}},
174 SelectAndScatterTestParam{{7, 256, 128},
175 {3, 256, 128},
176 Padding::kValid,
177 {3, 1, 1},
178 {2, 1, 1}},
179 SelectAndScatterTestParam{{6, 256, 128},
180 {3, 256, 128},
181 Padding::kValid,
182 {2, 1, 1},
183 {2, 1, 1}},
184 SelectAndScatterTestParam{{6, 256, 128},
185 {2, 256, 128},
186 Padding::kValid,
187 {2, 1, 1},
188 {3, 1, 1}},
189 SelectAndScatterTestParam{
190 {9, 16, 128}, {3, 16, 128}, Padding::kValid, {3, 1, 1}, {3, 1, 1}},
191 SelectAndScatterTestParam{
192 {3, 4, 4}, {1, 4, 4}, Padding::kValid, {3, 1, 1}, {3, 1, 1}},
193 SelectAndScatterTestParam{
194 {3, 4, 4}, {1, 4, 4}, Padding::kValid, {3, 1, 1}, {3, 1, 1}},
195 SelectAndScatterTestParam{
196 {9, 4, 4}, {3, 4, 4}, Padding::kValid, {3, 1, 1}, {3, 1, 1}},
197 SelectAndScatterTestParam{
198 {7, 4, 4}, {3, 4, 4}, Padding::kValid, {3, 1, 1}, {2, 1, 1}},
199 SelectAndScatterTestParam{
200 {1, 5, 5}, {1, 5, 5}, Padding::kSame, {3, 1, 1}, {3, 1, 1}},
201 SelectAndScatterTestParam{
202 {7, 8, 256}, {4, 8, 256}, Padding::kSame, {2, 1, 1}, {2, 1, 1}}));
203
204 // Test for F32 1D array, with a zero-element input.
XLA_TEST_F(SelectAndScatterTest,R1S0F32)205 XLA_TEST_F(SelectAndScatterTest, R1S0F32) {
206 const auto operand = ConstantR1<float>(&builder_, {});
207 const auto source = ConstantR1<float>(&builder_, {});
208 SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{3},
209 /*window_strides=*/{3}, Padding::kValid, source,
210 ConstantR0<float>(&builder_, 0.0f), add_f32_);
211 ComputeAndCompareR1<float>(&builder_, {}, {}, ErrorSpec(1e-7));
212 }
213
214 // Test for F32 1D array, when windows do not overlap.
XLA_TEST_F(SelectAndScatterTest,R1F32)215 XLA_TEST_F(SelectAndScatterTest, R1F32) {
216 const auto operand =
217 ConstantR1<float>(&builder_, {1.f, 9.f, 3.f, 7.f, 5.f, 6.f});
218 const auto source = ConstantR1<float>(&builder_, {34.f, 42.f});
219 const std::vector<float> expected = {0.f, 34.f, 0.f, 42.f, 0.f, 0.f};
220 SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{3},
221 /*window_strides=*/{3}, Padding::kValid, source,
222 ConstantR0<float>(&builder_, 0.0f), add_f32_);
223 ComputeAndCompareR1<float>(&builder_, expected, {}, ErrorSpec(1e-7));
224 }
225
226 // Test for S32 1D array, when windows do not overlap and the init value is 1.
XLA_TEST_F(SelectAndScatterTest,R1S32)227 XLA_TEST_F(SelectAndScatterTest, R1S32) {
228 const auto operand = ConstantR1<int32>(&builder_, {-1, 0, 6, 4, -4, 10});
229 const auto source = ConstantR1<int32>(&builder_, {-10, 20});
230 const std::vector<int32> expected = {1, 1, -9, 1, 1, 21};
231 SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{3},
232 /*window_strides=*/{3}, Padding::kValid, source,
233 ConstantR0<int32>(&builder_, 1), add_s32_);
234 ComputeAndCompareR1<int32>(&builder_, expected, {});
235 }
236
237 // Test for S32 1D array, when windows overlap with each other.
XLA_TEST_F(SelectAndScatterTest,R1S32OverlappingWindow)238 XLA_TEST_F(SelectAndScatterTest, R1S32OverlappingWindow) {
239 const auto operand = ConstantR1<int32>(&builder_, {1, 9, 3, 7, 5, 6});
240 const auto source = ConstantR1<int32>(&builder_, {34, 42, 53, 19});
241 const std::vector<int32> expected = {0, 76, 0, 72, 0, 0};
242 SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{3},
243 /*window_strides=*/{1}, Padding::kValid, source,
244 ConstantR0<int32>(&builder_, 0), add_s32_);
245 ComputeAndCompareR1<int32>(&builder_, expected, {});
246 }
247
248 // Test for S32 2D array, when windows do not overlap.
XLA_TEST_F(SelectAndScatterTest,R2S32)249 XLA_TEST_F(SelectAndScatterTest, R2S32) {
250 const auto operand =
251 ConstantR2<int32>(&builder_, {{7, 2, 5, 3, 10, 2}, {3, 8, 9, 3, 4, 2}});
252 const auto source = ConstantR2<int32>(&builder_, {{2, 6}});
253 Array2D<int32> expected({{0, 0, 0, 0, 6, 0}, {0, 0, 2, 0, 0, 0}});
254 SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{2, 3},
255 /*window_strides=*/{2, 3}, Padding::kValid, source,
256 ConstantR0<int32>(&builder_, 0), add_s32_);
257 ComputeAndCompareR2<int32>(&builder_, expected, {});
258 }
259
260 // Test for tie breaking rule in ge_f32_. When a tie is present, the operand
261 // that has the lower lexicographical order (smaller index) should be chosen.
XLA_TEST_F(SelectAndScatterTest,R2F32Tie)262 XLA_TEST_F(SelectAndScatterTest, R2F32Tie) {
263 const auto operand = ConstantR2<float>(
264 &builder_, {{0.f, 0.f, 0.f}, {0.f, 0.f, 0.f}, {0.f, 0.f, 0.f}});
265 const auto source = ConstantR2<float>(
266 &builder_, {{1.0f, 2.0f, 3.0f}, {4.f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}});
267 Array2D<float> expected(
268 {{12.f, 9.f, 0.f}, {15.f, 9.f, 0.f}, {0.f, 0.f, 0.f}});
269 SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{3, 3},
270 /*window_strides=*/{1, 1}, Padding::kSame, source,
271 ConstantR0<float>(&builder_, 0.0f), add_f32_);
272 ComputeAndCompareR2<float>(&builder_, expected, {}, ErrorSpec(1e-7));
273 }
274
275 // Similar to SelectAndScatterTest.R2S32 but the input is transposed.
XLA_TEST_F(SelectAndScatterTest,ReshapeR2S32)276 XLA_TEST_F(SelectAndScatterTest, ReshapeR2S32) {
277 const auto operand = ConstantR2<int32>(
278 &builder_, {{7, 3}, {2, 8}, {5, 9}, {3, 3}, {10, 4}, {2, 2}});
279 const auto reshape =
280 Reshape(operand, /*dimensions=*/{1, 0}, /*new_sizes=*/{2, 6});
281 const auto source = ConstantR2<int32>(&builder_, {{2, 6}});
282 Array2D<int32> expected({{0, 0, 0, 0, 6, 0}, {0, 0, 2, 0, 0, 0}});
283 SelectAndScatter(reshape, ge_s32_, /*window_dimensions=*/{2, 3},
284 /*window_strides=*/{2, 3}, Padding::kValid, source,
285 ConstantR0<int32>(&builder_, 0), add_s32_);
286 ComputeAndCompareR2<int32>(&builder_, expected, {});
287 }
288
289 // Test for S32 2D array, when windows overlap with each other.
XLA_TEST_F(SelectAndScatterTest,R2S32OverlappingWindow)290 XLA_TEST_F(SelectAndScatterTest, R2S32OverlappingWindow) {
291 const auto operand =
292 ConstantR2<int32>(&builder_, {{7, 2, 5, 3, 8}, {3, 8, 9, 3, 4}});
293 const auto source = ConstantR2<int32>(&builder_, {{2, 6, 4}});
294 Array2D<int32> expected({{0, 0, 0, 0, 0}, {0, 0, 12, 0, 0}});
295 SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{2, 3},
296 /*window_strides=*/{1, 1}, Padding::kValid, source,
297 ConstantR0<int32>(&builder_, 0), add_s32_);
298 ComputeAndCompareR2<int32>(&builder_, expected, {});
299 }
300
301 // Test for S32 2D array, when the padding is Padding::kSAME.
XLA_TEST_F(SelectAndScatterTest,R2S32SamePadding)302 XLA_TEST_F(SelectAndScatterTest, R2S32SamePadding) {
303 const auto operand =
304 ConstantR2<int32>(&builder_, {{7, 2, 5, 3, 8}, {3, 8, 9, 3, 4}});
305 const auto source = ConstantR2<int32>(&builder_, {{2, 6, 4}});
306 Array2D<int32> expected({{0, 0, 0, 0, 4}, {0, 2, 6, 0, 0}});
307 SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{2, 2},
308 /*window_strides=*/{2, 2}, Padding::kSame, source,
309 ConstantR0<int32>(&builder_, 0), add_s32_);
310 ComputeAndCompareR2<int32>(&builder_, expected, {});
311 }
312
313 // Test for S32 2D array, when the padding is Padding::kSAME and windows overlap
314 // with each other.
XLA_TEST_F(SelectAndScatterTest,R2S32SamePaddingOverlappingWindow)315 XLA_TEST_F(SelectAndScatterTest, R2S32SamePaddingOverlappingWindow) {
316 const auto operand =
317 ConstantR2<int32>(&builder_, {{7, 2, 5, 3, 8}, {3, 8, 9, 3, 4}});
318 const auto source =
319 ConstantR2<int32>(&builder_, {{2, 6, 4, 7, 1}, {3, 5, 8, 9, 10}});
320 Array2D<int32> expected({{0, 0, 0, 0, 8}, {0, 5, 23, 0, 19}});
321 SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{2, 2},
322 /*window_strides=*/{1, 1}, Padding::kSame, source,
323 ConstantR0<int32>(&builder_, 0), add_s32_);
324 ComputeAndCompareR2<int32>(&builder_, expected, {});
325 }
326
// F32 3x3 operand with overlapping 2x2 windows (stride 1x1, VALID), giving a
// 2x2 source. Each window picks a distinct maximum: 3.5 at (1, 0), 3.5 at
// (1, 2), 4.5 at (2, 0) and 4.5 at (2, 2). These receive 1, 2, 3 and 4.
XLA_TEST_F(SelectAndScatterTest, R2F32OverlappingR2Source) {
  const XlaOp input = ConstantR2<float>(
      &builder_, {{1.5f, 2.5f, 1.5f}, {3.5f, 1.5f, 3.5f}, {4.5f, 2.5f, 4.5f}});
  const XlaOp updates =
      ConstantR2<float>(&builder_, {{1.0f, 2.0f}, {3.0f, 4.0f}});
  const XlaOp zero = ConstantR0<float>(&builder_, 0.0f);
  SelectAndScatter(input, ge_f32_, /*window_dimensions=*/{2, 2},
                   /*window_strides=*/{1, 1}, Padding::kValid, updates, zero,
                   add_f32_);

  const Array2D<float> want = {
      {0.0f, 0.0f, 0.0f}, {1.0f, 0.0f, 2.0f}, {3.0f, 0.0f, 4.0f}};
  ComputeAndCompareR2<float>(&builder_, want, {}, ErrorSpec(1e-7));
}
339
// Rank-4 F32 test with non-overlapping 2x3 windows (stride 2x3) over the two
// major dimensions and 1x1 windows over the two minor ones.
//
// Each 2-D pattern has the shape of the two major dimensions. FillWithPZ
// copies it into those dimensions for every minor index, so every minor slice
// of the result must equal `pze`. The window maxima 9, 10, 7 and 8 receive
// the source values 2, 6, 3 and 1.
XLA_TEST_F(SelectAndScatterTest, R4F32Valid) {
  Array2D<float> pzo = {{7.0f, 2.0f, 5.0f, 3.0f, 10.0f, 2.0f},
                        {3.0f, 8.0f, 9.0f, 3.0f, 4.00f, 2.0f},
                        {1.0f, 5.0f, 7.0f, 5.0f, 6.00f, 1.0f},
                        {0.0f, 6.0f, 2.0f, 7.0f, 2.00f, 8.0f}};
  Array2D<float> pzs = {{2.0f, 6.0f}, {3.0f, 1.0f}};
  Array2D<float> pze = {{0.0f, 0.0f, 0.0f, 0.0f, 6.0f, 0.0f},
                        {0.0f, 0.0f, 2.0f, 0.0f, 0.0f, 0.0f},
                        {0.0f, 0.0f, 3.0f, 0.0f, 0.0f, 0.0f},
                        {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f}};
  Array4D<float> o(4, 6, 15, 220);
  o.FillWithPZ(pzo);
  auto operand = ConstantR4FromArray4D(&builder_, o);
  Array4D<float> e(4, 6, 15, 220);
  e.FillWithPZ(pze);
  Array4D<float> s(2, 2, 15, 220);
  s.FillWithPZ(pzs);
  auto source = ConstantR4FromArray4D(&builder_, s);
  SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{2, 3, 1, 1},
                   /*window_strides=*/{2, 3, 1, 1}, Padding::kValid, source,
                   ConstantR0<float>(&builder_, 0.0f), add_f32_);
  ComputeAndCompareR4<float>(&builder_, e, {}, ErrorSpec(1e-7));
}
364
// Rank-4 F32 test whose 2x3 windows overlap along dimension 1: the stride is
// 2x2, so both windows in a row share column 2.
//
// FillWithPZ copies each 2-D pattern into the two major dimensions for every
// minor index. The two upper windows both pick 9 at (1, 2), which collects
// 2 + 6 = 8. The lower windows pick 7 at (2, 2) and 10 at (3, 3).
XLA_TEST_F(SelectAndScatterTest, R4F32Overlap) {
  Array2D<float> pzo = {{7.0f, 2.0f, 5.0f, 3.0f, 8.0f},
                        {3.0f, 8.0f, 9.0f, 3.0f, 4.0f},
                        {1.0f, 5.0f, 7.0f, 5.0f, 6.0f},
                        {0.0f, 6.0f, 2.0f, 10.0f, 2.0f}};
  Array2D<float> pzs = {{2.0f, 6.0f}, {3.0f, 1.0f}};
  Array2D<float> pze = {{0.0f, 0.0f, 0.0f, 0.0f, 0.0f},
                        {0.0f, 0.0f, 8.0f, 0.0f, 0.0f},
                        {0.0f, 0.0f, 3.0f, 0.0f, 0.0f},
                        {0.0f, 0.0f, 0.0f, 1.0f, 0.0f}};
  Array4D<float> o(4, 5, 17, 128);
  o.FillWithPZ(pzo);
  auto operand = ConstantR4FromArray4D(&builder_, o);
  Array4D<float> e(4, 5, 17, 128);
  e.FillWithPZ(pze);
  Array4D<float> s(2, 2, 17, 128);
  s.FillWithPZ(pzs);
  auto source = ConstantR4FromArray4D(&builder_, s);
  SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{2, 3, 1, 1},
                   /*window_strides=*/{2, 2, 1, 1}, Padding::kValid, source,
                   ConstantR0<float>(&builder_, 0.0f), add_f32_);
  ComputeAndCompareR4<float>(&builder_, e, {}, ErrorSpec(1e-7));
}
389
// Same overlapping-window case as R4F32Overlap, but with 1x1 minor
// dimensions. The rank-4 op therefore reduces to the 2-D pattern itself.
// 9 at (1, 2) collects 2 + 6 = 8; 7 at (2, 2) gets 3; 10 at (3, 3) gets 1.
XLA_TEST_F(SelectAndScatterTest, R4F32OverlapSmall) {
  Array2D<float> pzo = {{7.0f, 2.0f, 5.0f, 3.0f, 8.0f},
                        {3.0f, 8.0f, 9.0f, 3.0f, 4.0f},
                        {1.0f, 5.0f, 7.0f, 5.0f, 6.0f},
                        {0.0f, 6.0f, 2.0f, 10.0f, 2.0f}};
  Array2D<float> pzs = {{2.0f, 6.0f}, {3.0f, 1.0f}};
  Array2D<float> pze = {{0.0f, 0.0f, 0.0f, 0.0f, 0.0f},
                        {0.0f, 0.0f, 8.0f, 0.0f, 0.0f},
                        {0.0f, 0.0f, 3.0f, 0.0f, 0.0f},
                        {0.0f, 0.0f, 0.0f, 1.0f, 0.0f}};
  Array4D<float> o(4, 5, 1, 1);
  o.FillWithPZ(pzo);
  auto operand = ConstantR4FromArray4D(&builder_, o);
  Array4D<float> e(4, 5, 1, 1);
  e.FillWithPZ(pze);
  Array4D<float> s(2, 2, 1, 1);
  s.FillWithPZ(pzs);
  auto source = ConstantR4FromArray4D(&builder_, s);
  SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{2, 3, 1, 1},
                   /*window_strides=*/{2, 2, 1, 1}, Padding::kValid, source,
                   ConstantR0<float>(&builder_, 0.0f), add_f32_);
  ComputeAndCompareR4<float>(&builder_, e, {}, ErrorSpec(1e-7));
}
414
// Cross-checks the XLA SelectAndScatter result against
// ReferenceUtil::SelectAndScatter4DGePlus on a small fixed rank-4 input, using
// 2x3 windows with stride 2x3 over the two major dimensions.
XLA_TEST_F(SelectAndScatterTest, R4F32RefValidFixedSmall) {
  Array2D<float> pzo = {{7.0f, 2.0f, 5.0f, 3.0f, 10.0f, 2.0f},
                        {3.0f, 8.0f, 9.0f, 3.0f, 4.00f, 2.0f},
                        {1.0f, 5.0f, 7.0f, 5.0f, 6.00f, 1.0f},
                        {0.0f, 6.0f, 2.0f, 7.0f, 2.00f, 8.0f}};
  Array2D<float> pzs = {{2.0f, 6.0f}, {3.0f, 1.0f}};
  Array4D<float> o(4, 6, 4, 4);
  o.FillWithPZ(pzo);
  auto operand = ConstantR4FromArray4D(&builder_, o);
  Array4D<float> s(2, 2, 4, 4);
  s.FillWithPZ(pzs);
  auto source = ConstantR4FromArray4D(&builder_, s);
  SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{2, 3, 1, 1},
                   /*window_strides=*/{2, 3, 1, 1}, Padding::kValid, source,
                   ConstantR0<float>(&builder_, 0.0f), add_f32_);
  // NOTE(review): the trailing `false` is assumed to request non-SAME (VALID)
  // padding, mirroring Padding::kValid above. Confirm against ReferenceUtil.
  auto e = ReferenceUtil::SelectAndScatter4DGePlus(o, s, 0.0f, {2, 3, 1, 1},
                                                   {2, 3, 1, 1}, false);
  ComputeAndCompareR4<float>(&builder_, *e, {}, ErrorSpec(1e-7));
}
437
// Every width-4, stride-1 window contains the 100 at index 3, so all four
// source values are scattered there. With max as the scatter function and
// init 0, index 3 ends up holding the largest source value, 53.
XLA_TEST_F(SelectAndScatterTest, R1F32OverlappingWindowMaxScatter) {
  const XlaOp input =
      ConstantR1<float>(&builder_, {1.f, 2.f, 3.f, 100.f, 3.f, 2.f, 1.f});
  const XlaOp updates = ConstantR1<float>(&builder_, {34.f, 42.f, 53.f, 19.f});
  const XlaOp zero = ConstantR0<float>(&builder_, 0.0f);
  SelectAndScatter(input, ge_f32_, /*window_dimensions=*/{4},
                   /*window_strides=*/{1}, Padding::kValid, updates, zero,
                   max_f32_);

  const std::vector<float> want = {0.f, 0.f, 0.f, 53.f, 0.f, 0.f, 0.f};
  ComputeAndCompareR1<float>(&builder_, want, {}, ErrorSpec(1e-7));
}
447
// Mirror of the max-scatter test with min as the scatter function. The init
// value is the largest finite float, so untouched positions keep it, and index
// 3 (selected by every window) ends up with the smallest source value, 19.
XLA_TEST_F(SelectAndScatterTest, R1F32OverlappingWindowMinScatter) {
  const float kBig = std::numeric_limits<float>::max();
  const XlaOp input =
      ConstantR1<float>(&builder_, {1.f, 2.f, 3.f, 100.f, 3.f, 2.f, 1.f});
  const XlaOp updates = ConstantR1<float>(&builder_, {34.f, 42.f, 53.f, 19.f});
  const XlaOp init = ConstantR0<float>(&builder_, kBig);
  SelectAndScatter(input, ge_f32_, /*window_dimensions=*/{4},
                   /*window_strides=*/{1}, Padding::kValid, updates, init,
                   min_f32_);

  const std::vector<float> want = {kBig, kBig, kBig, 19.f,
                                   kBig, kBig, kBig};
  ComputeAndCompareR1<float>(&builder_, want, {}, ErrorSpec(1e-7));
}
459
460 } // namespace
461 } // namespace xla
462