1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // Tests that slice operations can be performed.
17
18 #include <numeric>
19 #include <vector>
20
21 #include "absl/container/inlined_vector.h"
22 #include "absl/strings/str_cat.h"
23 #include "absl/strings/str_format.h"
24 #include "absl/strings/str_join.h"
25 #include "absl/types/span.h"
26 #include "tensorflow/compiler/xla/array2d.h"
27 #include "tensorflow/compiler/xla/client/local_client.h"
28 #include "tensorflow/compiler/xla/client/xla_builder.h"
29 #include "tensorflow/compiler/xla/reference_util.h"
30 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
31 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
32 #include "tensorflow/compiler/xla/tests/test_macros.h"
33 #include "tensorflow/core/platform/test.h"
34 #include "tensorflow/core/platform/types.h"
35
36 namespace xla {
37 namespace {
38
// Fixture for the hand-written (non-parameterized) slice tests below. It adds
// nothing of its own; every helper used by those tests is inherited from
// ClientLibraryTestBase.
class SliceTest : public ClientLibraryTestBase {};
40
// Keeps only index 0 of the minor-most dimension of a 3x3x3 iota cube,
// producing a 3x3x1 result.
TEST_F(SliceTest, Slice3x3x3_To_3x3x1_F32) {
  XlaBuilder builder(TestName());

  Array3D<float> cube(3, 3, 3);
  cube.FillIota(0);
  auto operand = ConstantR3FromArray3D<float>(&builder, cube);
  Slice(operand, /*start_indices=*/{0, 0, 0}, /*limit_indices=*/{3, 3, 1},
        /*strides=*/{1, 1, 1});

  // Element (i, j, 0) of the iota cube is 9 * i + 3 * j.
  const Array3D<float> expected{{{0.0f}, {3.0f}, {6.0f}},
                                {{9.0f}, {12.0f}, {15.0f}},
                                {{18.0f}, {21.0f}, {24.0f}}};
  ComputeAndCompareR3<float>(&builder, expected, {}, ErrorSpec(1e-6));
}
53
// Keeps only row 0 of each 3x3 plane of a 3x3x3 iota cube, producing a
// 3x1x3 result.
TEST_F(SliceTest, Slice3x3x3_To_3x1x3_F32) {
  XlaBuilder builder(TestName());

  Array3D<float> cube(3, 3, 3);
  cube.FillIota(0);
  auto operand = ConstantR3FromArray3D<float>(&builder, cube);
  Slice(operand, /*start_indices=*/{0, 0, 0}, /*limit_indices=*/{3, 1, 3},
        /*strides=*/{1, 1, 1});

  // Element (i, 0, k) of the iota cube is 9 * i + k.
  const Array3D<float> expected{{{0.0f, 1.0f, 2.0f}},
                                {{9.0f, 10.0f, 11.0f}},
                                {{18.0f, 19.0f, 20.0f}}};
  ComputeAndCompareR3<float>(&builder, expected, {}, ErrorSpec(1e-6));
}
66
// Keeps only plane 0 of the major-most dimension of a 3x3x3 iota cube,
// producing a 1x3x3 result.
TEST_F(SliceTest, Slice3x3x3_To_1x3x3_F32) {
  XlaBuilder builder(TestName());

  Array3D<float> cube(3, 3, 3);
  cube.FillIota(0);
  auto operand = ConstantR3FromArray3D<float>(&builder, cube);
  Slice(operand, /*start_indices=*/{0, 0, 0}, /*limit_indices=*/{1, 3, 3},
        /*strides=*/{1, 1, 1});

  // A single plane holding the first nine iota values.
  const Array3D<float> expected{
      {{0.0f, 1.0f, 2.0f}, {3.0f, 4.0f, 5.0f}, {6.0f, 7.0f, 8.0f}}};
  ComputeAndCompareR3<float>(&builder, expected, {}, ErrorSpec(1e-6));
}
79
// An empty window over an empty 0x0 matrix yields an empty 0x0 matrix.
XLA_TEST_F(SliceTest, Slice0x0to0x0F32) {
  XlaBuilder builder(TestName());
  const Array2D<float> empty(0, 0);

  auto operand = ConstantR2FromArray2D<float>(&builder, empty);
  Slice(operand, {0, 0}, {0, 0}, {1, 1});

  ComputeAndCompareR2<float>(&builder, empty, {});
}
87
// Columns [15, 20) of a matrix with no rows: still no rows, now five columns.
XLA_TEST_F(SliceTest, Slice0x20to0x5F32) {
  XlaBuilder builder(TestName());
  const Array2D<float> input(0, 20);
  const Array2D<float> expected(0, 5);

  auto operand = ConstantR2FromArray2D<float>(&builder, input);
  Slice(operand, {0, 15}, {0, 20}, {1, 1});

  ComputeAndCompareR2<float>(&builder, expected, {});
}
95
// Rows [1, 3) of a matrix with no columns: two rows, still no columns.
XLA_TEST_F(SliceTest, Slice3x0to2x0F32) {
  XlaBuilder builder(TestName());
  const Array2D<float> input(3, 0);
  const Array2D<float> expected(2, 0);

  auto operand = ConstantR2FromArray2D<float>(&builder, input);
  Slice(operand, {1, 0}, {3, 0}, {1, 1});

  ComputeAndCompareR2<float>(&builder, expected, {});
}
103
// Extracts the bottom-right 128x128 quadrant of a 256x256 matrix whose
// elements encode their own coordinates as (row << 10) | col.
XLA_TEST_F(SliceTest, SliceQuadrantOf256x256) {
  constexpr int kSize = 256;
  constexpr int kHalf = kSize / 2;

  Array2D<float> values(kSize, kSize);
  for (int r = 0; r < kSize; ++r) {
    for (int c = 0; c < kSize; ++c) {
      values(r, c) = (r << 10) | c;
    }
  }

  // Output (r, c) is input (r + 128, c + 128), i.e. ((r+128) << 10) | (c+128).
  Array2D<float> expected(kHalf, kHalf);
  for (int r = 0; r < kHalf; ++r) {
    for (int c = 0; c < kHalf; ++c) {
      expected(r, c) = values(r + kHalf, c + kHalf);
    }
  }

  XlaBuilder builder(TestName());
  auto operand = ConstantR2FromArray2D<float>(&builder, values);
  Slice(operand, {kHalf, kHalf}, {kSize, kSize}, {1, 1});
  ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(1e-6));
}
124
125 // Tests: (f32[1,4096], starts={0, 3072}, limits={1, 4096}) -> f32[1,1024])
TEST_F(SliceTest,Slice_1x4096_To_1x1024)126 TEST_F(SliceTest, Slice_1x4096_To_1x1024) {
127 Array2D<float> values(1, 4096);
128 std::iota(values.data(), values.data() + 4096, 0.0);
129
130 XlaBuilder builder(TestName());
131 auto original = ConstantR2FromArray2D<float>(&builder, values);
132 Slice(original, {0, 3072}, {1, 4096}, {1, 1});
133
134 Array2D<float> expected(1, 1024);
135 std::iota(expected.data(), expected.data() + 1024, 3072.0);
136 ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.000001));
137 }
138
139 // Tests slice: (f32[16,4], starts={0, 0}, limits={16, 2}) -> f32[16,2]
TEST_F(SliceTest,Slice_16x4_To_16x2)140 TEST_F(SliceTest, Slice_16x4_To_16x2) {
141 Array2D<float> values(16, 4);
142 Array2D<float> expected(16, 2);
143 for (int row = 0; row < 16; ++row) {
144 for (int col = 0; col < 4; ++col) {
145 values(row, col) = (row << 10) | col;
146 if (col < 2) {
147 expected(row, col) = (row << 10) | col;
148 }
149 }
150 }
151 XlaBuilder builder(TestName());
152 auto original = ConstantR2FromArray2D<float>(&builder, values);
153 Slice(original, {0, 0}, {16, 2}, {1, 1});
154 ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.000001));
155 }
156
157 // Tests: (f32[2, 2, 24, 256], starts = {1, 0, 8, 0}, ends = {2, 2, 16, 128}
TEST_F(SliceTest,SliceR4ThreeDimsMiddleMinor)158 TEST_F(SliceTest, SliceR4ThreeDimsMiddleMinor) {
159 Array4D<float> values(2, 2, 24, 256);
160 values.FillRandom(3.14f);
161 auto expected = ReferenceUtil::Slice4D(
162 values, {{1, 0, 8, 0}}, {{2, 2, 16, 128}}, /*strides=*/{{1, 1, 1, 1}});
163 XlaBuilder builder(TestName());
164 auto original = ConstantR4FromArray4D(&builder, values);
165 Slice(original, {1, 0, 8, 0}, {2, 2, 16, 128}, {1, 1, 1, 1});
166 ComputeAndCompareR4(&builder, *expected, {}, ErrorSpec(0.000001));
167 }
168
// Reshapes a 144x7 iota matrix (seeded with 1) to 24x3x2x7 and keeps the
// first 11 entries of the leading dimension. No expected value is supplied;
// ComputeAndCompare presumably derives one from a reference backend (confirm
// in ClientLibraryTestBase).
TEST_F(SliceTest, SliceOfReshape) {
  XlaBuilder builder(TestName());

  Array2D<int> values(2 * 3 * 24, 7);
  values.FillIota(1);
  auto matrix = ConstantR2FromArray2D(&builder, values);
  auto reshaped = Reshape(matrix, {24, 3, 2, 7});
  Slice(reshaped, {0, 0, 0, 0}, {11, 3, 2, 7}, {1, 1, 1, 1});

  ComputeAndCompare(&builder, {});
}
178
// Collapses a 2x3x5x7 iota array (seeded with 1) to 30x7 and keeps its first
// 4 rows. As with SliceOfReshape, no explicit expected value is given to
// ComputeAndCompare.
TEST_F(SliceTest, SliceOfCollapsingReshape) {
  XlaBuilder builder(TestName());

  Array4D<int> values(2, 3, 5, 7);
  values.FillIota(1);
  auto cube = ConstantR4FromArray4D(&builder, values);
  auto collapsed = Reshape(cube, {2 * 3 * 5, 7});
  Slice(collapsed, {0, 0}, {4, 7}, {1, 1});

  ComputeAndCompare(&builder, {});
}
188
// Strided slice (step 2 along dimension 2) of a rank-4 array. The expected
// literal is built with layout {0, 1, 2, 3}, and its shape is passed as the
// final ComputeAndCompareLiteral argument, presumably to request that layout
// for the computed output.
XLA_TEST_F(SliceTest, StridedSliceR4WithOutputLayout) {
  const std::array<int64, 4> starts = {{0, 0, 0, 0}};
  const std::array<int64, 4> limits = {{2, 4, 6, 8}};
  const std::array<int64, 4> strides = {{1, 1, 2, 1}};

  Array4D<float> values(2, 4, 6, 8);
  values.FillRandom(3.14f);
  auto expected = ReferenceUtil::Slice4D(values, starts, limits, strides);
  auto expected_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
      *expected, LayoutUtil::MakeLayout({0, 1, 2, 3}));

  XlaBuilder builder(TestName());
  auto operand = ConstantR4FromArray4D(&builder, values);
  Slice(operand, starts, limits, strides);
  ComputeAndCompareLiteral(&builder, expected_literal, {}, ErrorSpec(1e-6),
                           &expected_literal.shape());
}
202
203 struct R1Spec {
204 int64 input_dim0;
205 int64 slice_start;
206 int64 slice_limit;
207 int64 slice_stride;
208 };
209
210 // Parameterized test that generates R1 values, slices them according
211 // to the R1Spec, and compares the result with a computed version.
212 class SliceR1Test : public ClientLibraryTestBase,
213 public ::testing::WithParamInterface<R1Spec> {
214 protected:
215 template <typename NativeT>
Run(const R1Spec & spec)216 void Run(const R1Spec& spec) {
217 // This can't be an std::vector, since you can't grab a Span of a
218 // vector<bool>.
219 absl::InlinedVector<NativeT, 1> input(spec.input_dim0);
220 std::iota(input.begin(), input.end(), NativeT());
221 auto literal = LiteralUtil::CreateR1<NativeT>(input);
222
223 XlaBuilder builder(TestName());
224 auto original = Parameter(&builder, 0, literal.shape(), "p0");
225 Slice(original, {spec.slice_start}, {spec.slice_limit},
226 {spec.slice_stride});
227
228 // Ditto.
229 absl::InlinedVector<NativeT, 1> expected;
230 for (int i = spec.slice_start; i < spec.slice_limit;
231 i += spec.slice_stride) {
232 expected.push_back(i);
233 }
234
235 TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> arg,
236 client_->TransferToServer(literal));
237 ComputeAndCompareR1<NativeT>(&builder, expected, {arg.get()});
238 }
239 };
240
241 // A version of SliceR1Test used to label and disable 'large' tests
242 class SliceR1LargeTest : public SliceR1Test {};
243
SliceR1TestDataToString(const::testing::TestParamInfo<R1Spec> & data)244 string SliceR1TestDataToString(const ::testing::TestParamInfo<R1Spec>& data) {
245 const R1Spec& spec = data.param;
246 return absl::StrFormat("%d_%d_%d_%d", spec.input_dim0, spec.slice_start,
247 spec.slice_limit, spec.slice_stride);
248 }
249
// Runs the R1 slice check once per element type. The large-input variants
// are registered under SliceR1LargeTest; PRED is only run via SliceR1Test.
XLA_TEST_P(SliceR1Test, DoIt_F32) {
  const R1Spec& spec = GetParam();
  Run<float>(spec);
}

XLA_TEST_P(SliceR1Test, DoIt_F64) {
  const R1Spec& spec = GetParam();
  Run<double>(spec);
}

XLA_TEST_P(SliceR1Test, DoIt_U32) {
  const R1Spec& spec = GetParam();
  Run<uint32>(spec);
}

XLA_TEST_P(SliceR1Test, DoIt_S32) {
  const R1Spec& spec = GetParam();
  Run<int32>(spec);
}

XLA_TEST_P(SliceR1Test, DoIt_U64) {
  const R1Spec& spec = GetParam();
  Run<uint64>(spec);
}

XLA_TEST_P(SliceR1Test, DoIt_S64) {
  const R1Spec& spec = GetParam();
  Run<int64>(spec);
}

XLA_TEST_P(SliceR1LargeTest, DoIt_F32) {
  const R1Spec& spec = GetParam();
  Run<float>(spec);
}

XLA_TEST_P(SliceR1LargeTest, DoIt_F64) {
  const R1Spec& spec = GetParam();
  Run<double>(spec);
}

XLA_TEST_P(SliceR1LargeTest, DoIt_U32) {
  const R1Spec& spec = GetParam();
  Run<uint32>(spec);
}

XLA_TEST_P(SliceR1LargeTest, DoIt_S32) {
  const R1Spec& spec = GetParam();
  Run<int32>(spec);
}

XLA_TEST_P(SliceR1LargeTest, DoIt_U64) {
  const R1Spec& spec = GetParam();
  Run<uint64>(spec);
}

XLA_TEST_P(SliceR1LargeTest, DoIt_S64) {
  const R1Spec& spec = GetParam();
  Run<int64>(spec);
}

XLA_TEST_P(SliceR1Test, DoIt_PRED) {
  const R1Spec& spec = GetParam();
  Run<bool>(spec);
}
275
// Tests for R1 slice ops.
// The format for each testcase is {input size, start, limit, stride}.
// Every case in this instantiation uses stride 1; test names are produced by
// SliceR1TestDataToString.
// clang-format off
INSTANTIATE_TEST_CASE_P(
    SliceR1TestInstantiation,
    SliceR1Test,
    ::testing::Values(
        // Empty slices (start == limit).
        R1Spec{10, 0, 0, 1},
        R1Spec{10, 7, 7, 1},
        // Small input.
        R1Spec{10, 0, 5, 1},
        R1Spec{10, 3, 5, 1},
        R1Spec{10, 0, 10, 1},
        // Inputs around 1K elements, including a non-multiple size (1024 + 17)
        // and slices near the 1024 boundary.
        R1Spec{1024, 0, 5, 1},
        R1Spec{1024, 3, 5, 1},
        R1Spec{1024 + 17, 0, 5, 1},
        R1Spec{1024 + 17, 3, 5, 1},
        R1Spec{1024 + 17, 1024, 1024 + 6, 1},
        R1Spec{1024 + 17, 1024 + 1, 1024 + 6, 1},
        R1Spec{1024, 1024 - 4, 1024, 1},
        // 4K-element inputs.
        R1Spec{4 * 1024, 7, 7 + 1024, 1},
        R1Spec{4 * 1024, 0, 4 * 1024, 1},
        R1Spec{4 * 1024, 1, 4 * 1024 - 1, 1},
        R1Spec{4 * 1024, 1024, 3 * 1024, 1},
        R1Spec{4 * 1024, 1024 + 1, 3 * 1024 - 1, 1},
        // 16K-element inputs.
        R1Spec{16 * 1024, 0, 5, 1},
        R1Spec{16 * 1024, 3, 5, 1},
        R1Spec{16 * 1024 + 17, 0, 5, 1},
        R1Spec{16 * 1024 + 17, 3, 5, 1},
        R1Spec{16 * 1024 + 17, 16 * 1024, 16 * 1024 + 6, 1},
        R1Spec{16 * 1024 + 17, 16 * 1024 + 1, 16 * 1024 + 6, 1},
        R1Spec{16 * 1024, 4 * 1024 - 17, 8 * 1024 - 18, 1},
        // 64K-element inputs.
        R1Spec{64 * 1024, 0, 64 * 1024, 1},
        R1Spec{64 * 1024, 1, 64 * 1024 - 1, 1},
        R1Spec{64 * 1024, 1024, 63 * 1024, 1},
        R1Spec{64 * 1024, 1024 + 1, 63 * 1024 - 1, 1},
        R1Spec{64 * 1024, 32 * 1024, 33 * 1024, 1},
        R1Spec{64 * 1024, 32 * 1024 + 1, 33 * 1024 - 1, 1},
        R1Spec{64 * 1024, 32 * 1024 - 17, 36 * 1024 - 18, 1}
    ),
    SliceR1TestDataToString
);
317
// TODO(b/69425338): This uses too much memory on GPU.
// 16M-element inputs with roughly 8M-element slices, with the window shifted
// by one element in either direction. Registered under SliceR1LargeTest and
// compiled out entirely when XLA_TEST_BACKEND_GPU is defined.
#ifndef XLA_TEST_BACKEND_GPU
INSTANTIATE_TEST_CASE_P(
    SliceR1TestBigSlicesInstantiation,
    SliceR1LargeTest,
    ::testing::Values(
        R1Spec{
            16 * 1024 * 1024, 4 * 1024 * 1024, 12 * 1024 * 1024, 1},
        R1Spec{
            16 * 1024 * 1024, 4 * 1024 * 1024 + 1, 12 * 1024 * 1024 - 1, 1},
        R1Spec{
            16 * 1024 * 1024, 4 * 1024 * 1024 - 1, 12 * 1024 * 1024 + 1, 1}
    ),
    SliceR1TestDataToString
);
#endif
334
// Strided R1 cases ({input size, start, limit, stride}). Strides range from 2
// up to larger than some of the slice windows, on inputs from 10 elements to
// 16M elements.
INSTANTIATE_TEST_CASE_P(
    SliceStridedR1TestInstantiation,
    SliceR1Test,
    ::testing::Values(
        // Small inputs.
        R1Spec{10, 2, 4, 2},
        R1Spec{10, 0, 10, 2},
        R1Spec{10, 0, 10, 3},
        R1Spec{10, 0, 10, 4},
        R1Spec{10, 0, 10, 5},
        R1Spec{10, 0, 10, 10},
        // Medium inputs, including odd sizes and large strides.
        R1Spec{500, 200, 400, 7},
        R1Spec{4096, 1, 4095, 3},
        R1Spec{2047, 1024 - 24, 1024 + 160, 31},
        R1Spec{2047, 1, 2046, 3 * 128},
        R1Spec{4096, 1024 + 3, 4095, 500},
        R1Spec{8192, 0, 8192, 1024 * 3 + 400},
        // 1M-element inputs (and 1M + 71) with strides 2, 8, 7 and 125.
        R1Spec{1024 * 1024, 0, 1024 * 1024, 2},
        R1Spec{1024 * 1024, 0, 1024 * 1024, 8},
        R1Spec{1024 * 1024, 0, 1024 * 1024, 7},
        R1Spec{1024 * 1024, 0, 1024 * 1024, 125},
        R1Spec{1024 * 1024, 3, 1024 - 9, 2},
        R1Spec{1024 * 1024, 3, 1024 - 9, 8},
        R1Spec{1024 * 1024, 3, 1024 - 9, 7},
        R1Spec{1024 * 1024, 3, 1024 - 9, 125},
        R1Spec{1024 * 1024, 3, 1024 * 512 - 9, 2},
        R1Spec{1024 * 1024, 3, 1024 * 512 - 9, 8},
        R1Spec{1024 * 1024, 3, 1024 * 512 - 9, 7},
        R1Spec{1024 * 1024, 3, 1024 * 512 - 9, 125},
        R1Spec{1024 * 1024 + 71, 3, 1024 * 512 - 9, 2},
        R1Spec{1024 * 1024 + 71, 3, 1024 * 512 - 9, 8},
        R1Spec{1024 * 1024 + 71, 3, 1024 * 512 - 9, 7},
        R1Spec{1024 * 1024 + 71, 3, 1024 * 512 - 9, 125},
        // 16M-element inputs with strides around 4K.
        R1Spec{16 * 1024 * 1024, 0, 16 * 1024 * 1024, 4097},
        R1Spec{16 * 1024 * 1024, 0, 16 * 1024 * 1024, 4093},
        R1Spec{16 * 1024 * 1024, 12 * 1024 + 17, 16 * 1024 * 1024 - 231, 4097},
        R1Spec{16 * 1024 * 1024, 12 * 1024 + 17, 16 * 1024 * 1024 - 231, 4093}
    ),
    SliceR1TestDataToString
);
// clang-format on
375
376 struct R2Spec {
377 int64 input_dim0;
378 int64 input_dim1;
379 std::array<int64, 2> slice_starts;
380 std::array<int64, 2> slice_limits;
381 std::array<int64, 2> slice_strides;
382 std::array<int64, 2> layout;
383 };
384
385 // Parameterized test that generates patterned R2 values, slices them according
386 // to the R2Spec, and compares the results with the ReferenceUtil version.
387 class SliceR2Test : public ClientLibraryTestBase,
388 public ::testing::WithParamInterface<R2Spec> {};
389
XLA_TEST_P(SliceR2Test,DoIt)390 XLA_TEST_P(SliceR2Test, DoIt) {
391 const R2Spec& spec = GetParam();
392 Array2D<int32> input(spec.input_dim0, spec.input_dim1);
393 input.FillUnique();
394 auto literal = LiteralUtil::CreateR2FromArray2DWithLayout(
395 input, LayoutUtil::MakeLayout(spec.layout));
396
397 XlaBuilder builder(TestName());
398 auto a = Parameter(&builder, 0, literal.shape(), "p0");
399 Slice(a, spec.slice_starts, spec.slice_limits, spec.slice_strides);
400
401 TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> arg,
402 client_->TransferToServer(literal));
403 std::unique_ptr<Array2D<int32>> expected = ReferenceUtil::Slice2D(
404 input, spec.slice_starts, spec.slice_limits, spec.slice_strides);
405 ComputeAndCompareR2<int32>(&builder, *expected, {arg.get()});
406 }
407
// Rank-2 cases: {dim0, dim1, {starts}, {limits}, {strides}, {layout}}. Most
// shapes are run with both layouts {0, 1} and {1, 0}. The trailing empty `//`
// comments keep clang-format from packing several specs onto one line.
INSTANTIATE_TEST_CASE_P(
    SliceR2TestInstantiation, SliceR2Test,
    ::testing::Values(
        // Unit strides.
        R2Spec{4, 12, {{0, 3}}, {{4, 6}}, {{1, 1}}, {{0, 1}}},              //
        R2Spec{4, 12, {{0, 3}}, {{4, 6}}, {{1, 1}}, {{1, 0}}},              //
        R2Spec{16, 4, {{0, 2}}, {{16, 4}}, {{1, 1}}, {{0, 1}}},             //
        R2Spec{16, 4, {{0, 2}}, {{16, 4}}, {{1, 1}}, {{1, 0}}},             //
        R2Spec{256, 400, {{0, 300}}, {{256, 400}}, {{1, 1}}, {{1, 0}}},     //
        R2Spec{500, 400, {{111, 123}}, {{300, 257}}, {{1, 1}}, {{1, 0}}},   //
        R2Spec{500, 400, {{111, 123}}, {{300, 400}}, {{1, 1}}, {{1, 0}}},   //
        R2Spec{384, 512, {{128, 256}}, {{256, 384}}, {{1, 1}}, {{1, 0}}},   //
        R2Spec{357, 512, {{111, 256}}, {{301, 384}}, {{1, 1}}, {{1, 0}}},   //
        // Small strides on a 10x10 input.
        R2Spec{10, 10, {{0, 0}}, {{10, 10}}, {{1, 2}}, {{0, 1}}},           //
        R2Spec{10, 10, {{0, 0}}, {{10, 10}}, {{1, 2}}, {{1, 0}}},           //
        R2Spec{10, 10, {{0, 0}}, {{10, 10}}, {{2, 1}}, {{0, 1}}},           //
        R2Spec{10, 10, {{0, 0}}, {{10, 10}}, {{2, 1}}, {{1, 0}}},           //
        R2Spec{10, 10, {{0, 0}}, {{10, 10}}, {{2, 2}}, {{0, 1}}},           //
        R2Spec{10, 10, {{0, 0}}, {{10, 10}}, {{2, 2}}, {{1, 0}}},           //
        // Strides on larger, odd-sized inputs.
        R2Spec{256, 400, {{100, 129}}, {{256, 400}}, {{3, 5}}, {{1, 0}}},   //
        R2Spec{256, 400, {{100, 129}}, {{256, 400}}, {{3, 5}}, {{0, 1}}},   //
        R2Spec{256, 400, {{100, 129}}, {{256, 400}}, {{5, 3}}, {{1, 0}}},   //
        R2Spec{256, 400, {{100, 129}}, {{256, 400}}, {{5, 3}}, {{0, 1}}},   //
        R2Spec{511, 513, {{129, 300}}, {{400, 500}}, {{7, 11}}, {{1, 0}}},  //
        R2Spec{511, 513, {{129, 300}}, {{400, 500}}, {{7, 11}}, {{0, 1}}},  //
        R2Spec{511, 513, {{129, 300}}, {{400, 500}}, {{11, 7}}, {{1, 0}}},  //
        R2Spec{511, 513, {{129, 300}}, {{400, 500}}, {{11, 7}}, {{0, 1}}},  //
        R2Spec{8672, 512, {{8, 0}}, {{8672, 512}}, {{542, 1}}, {{1, 0}}},   //
        // Large strides.
        R2Spec{
            511, 513, {{129, 300}}, {{400, 500}}, {{101, 129}}, {{1, 0}}},  //
        R2Spec{
            511, 513, {{129, 300}}, {{400, 500}}, {{101, 129}}, {{0, 1}}},  //
        R2Spec{
            511, 513, {{129, 300}}, {{400, 500}}, {{129, 101}}, {{1, 0}}},  //
        R2Spec{
            511, 513, {{129, 300}}, {{400, 500}}, {{129, 101}}, {{0, 1}}},  //
        R2Spec{
            511, 1023, {{129, 257}}, {{500, 1000}}, {{129, 255}}, {{1, 0}}},  //
        R2Spec{
            511, 1023, {{129, 257}}, {{500, 1000}}, {{129, 255}}, {{0, 1}}},  //
        // Limits written relative to the input extents.
        R2Spec{511,
               513,
               {{129, 255}},
               {{511 - 129, 513 - 140}},
               {{13, 19}},
               {{1, 0}}},  //
        R2Spec{511,
               513,
               {{129, 255}},
               {{511 - 129, 513 - 140}},
               {{13, 19}},
               {{0, 1}}}  //
        ));
460
461 struct R4Spec {
462 std::array<int64, 4> input_dims;
463 std::array<int64, 4> input_layout; // minor-to-major
464 std::array<int64, 4> slice_starts;
465 std::array<int64, 4> slice_limits;
466 std::array<int64, 4> slice_strides;
467 };
468
R4SpecToString(const::testing::TestParamInfo<R4Spec> & data)469 string R4SpecToString(const ::testing::TestParamInfo<R4Spec>& data) {
470 const R4Spec& spec = data.param;
471 return absl::StrCat("input_", absl::StrJoin(spec.input_dims, "x"),
472 "__layout_", absl::StrJoin(spec.input_layout, ""),
473 "__starts_", absl::StrJoin(spec.slice_starts, "x"),
474 "__limits_", absl::StrJoin(spec.slice_limits, "x"),
475 "__strides_", absl::StrJoin(spec.slice_strides, "x"));
476 }
477
478 class SliceR4Test : public ClientLibraryTestBase,
479 public ::testing::WithParamInterface<R4Spec> {
480 protected:
Run(const R4Spec & spec)481 void Run(const R4Spec& spec) {
482 Array4D<float> values(spec.input_dims[0], spec.input_dims[1],
483 spec.input_dims[2], spec.input_dims[3]);
484 values.FillIota(3.14159);
485 auto expected = ReferenceUtil::Slice4D(
486 values, spec.slice_starts, spec.slice_limits, spec.slice_strides);
487 XlaBuilder builder(TestName());
488 auto literal = LiteralUtil::CreateR4FromArray4DWithLayout(
489 values, LayoutUtil::MakeLayout(spec.input_layout));
490 auto parameter = Parameter(&builder, 0, literal.shape(), "p0");
491 TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> arg,
492 client_->TransferToServer(literal));
493 Slice(parameter, spec.slice_starts, spec.slice_limits, spec.slice_strides);
494 ComputeAndCompareR4(&builder, *expected, {arg.get()}, ErrorSpec(0.000001));
495 }
496 };
497
XLA_TEST_P(SliceR4Test,DoIt)498 XLA_TEST_P(SliceR4Test, DoIt) { Run(GetParam()); }
499
500 const R4Spec kR4SpecValues[] = {
501 R4Spec{{{2, 2, 2, 2}},
502 {{3, 2, 1, 0}},
503 {{0, 0, 0, 0}},
504 {{0, 0, 0, 0}},
505 {{1, 1, 1, 1}}}, //
506 R4Spec{{{3, 3, 4, 4}},
507 {{3, 2, 1, 0}},
508 {{0, 0, 0, 0}},
509 {{3, 3, 4, 4}},
510 {{1, 1, 2, 1}}}, //
511 R4Spec{{{2, 3, 16, 4}},
512 {{3, 2, 1, 0}},
513 {{0, 0, 0, 0}},
514 {{2, 3, 16, 4}},
515 {{1, 1, 3, 1}}}, //
516 R4Spec{{{4, 16, 3, 2}},
517 {{0, 1, 2, 3}},
518 {{1, 4, 1, 0}},
519 {{3, 12, 3, 2}},
520 {{1, 1, 3, 2}}}, //
521 R4Spec{{{2, 2, 257, 129}},
522 {{3, 2, 1, 0}},
523 {{1, 1, 62, 64}},
524 {{2, 2, 195, 129}},
525 {{1, 1, 3, 1}}}, //
526 R4Spec{{{3, 5, 257, 129}},
527 {{3, 2, 1, 0}},
528 {{1, 2, 61, 64}},
529 {{3, 5, 199, 129}},
530 {{1, 1, 3, 1}}}, //
531 R4Spec{{{5, 8, 257, 129}},
532 {{3, 2, 1, 0}},
533 {{2, 3, 60, 64}},
534 {{3, 5, 200, 68}},
535 {{1, 1, 1, 1}}}, //
536 R4Spec{{{8, 10, 256, 130}},
537 {{3, 2, 1, 0}},
538 {{1, 2, 60, 127}},
539 {{7, 9, 166, 129}},
540 {{4, 2, 3, 1}}}, //
541 R4Spec{{{2, 4, 8, 4}},
542 {{3, 2, 1, 0}},
543 {{1, 2, 0, 1}},
544 {{2, 4, 8, 3}},
545 {{1, 1, 7, 1}}}, //
546 R4Spec{{{10, 21, 256, 150}},
547 {{3, 2, 1, 0}},
548 {{1, 2, 9, 127}},
549 {{9, 16, 82, 133}},
550 {{3, 5, 7, 2}}}, //
551 R4Spec{{{15, 25, 256, 150}},
552 {{3, 2, 1, 0}},
553 {{4, 6, 19, 126}},
554 {{15, 25, 89, 135}},
555 {{5, 7, 7, 3}}}, //
556 R4Spec{{{2, 4, 256, 150}},
557 {{3, 2, 1, 0}},
558 {{1, 2, 29, 125}},
559 {{2, 4, 159, 145}},
560 {{1, 1, 7, 7}}}, //
561 R4Spec{{{2, 4, 256, 150}},
562 {{3, 2, 1, 0}},
563 {{1, 2, 39, 119}},
564 {{2, 4, 158, 145}},
565 {{1, 1, 7, 11}}}, //
566 R4Spec{{{1, 1, 5, 512}},
567 {{3, 2, 1, 0}},
568 {{0, 0, 0, 0}},
569 {{1, 1, 5, 512}},
570 {{1, 1, 4, 1}}}, //
571 R4Spec{{{1, 1, 513, 513}},
572 {{3, 2, 1, 0}},
573 {{0, 0, 0, 0}},
574 {{1, 1, 513, 513}},
575 {{1, 1, 512, 512}}}, //
576 R4Spec{{{1, 1, 1024, 4}},
577 {{3, 2, 1, 0}},
578 {{0, 0, 15, 0}},
579 {{1, 1, 1022, 4}},
580 {{1, 1, 23, 1}}}, //
581 R4Spec{{{1, 1, 1024, 4}},
582 {{3, 2, 1, 0}},
583 {{0, 0, 14, 0}},
584 {{1, 1, 1023, 4}},
585 {{1, 1, 101, 1}}}, //
586 R4Spec{{{1, 1, 4, 1024}},
587 {{3, 2, 1, 0}},
588 {{0, 0, 1, 20}},
589 {{1, 1, 4, 1023}},
590 {{1, 1, 1, 129}}}, //
591 R4Spec{{{5, 5, 512, 1024}},
592 {{3, 2, 1, 0}},
593 {{1, 1, 0, 0}},
594 {{4, 4, 512, 1024}},
595 {{2, 2, 2, 1}}}, //
596 R4Spec{{{5, 5, 512, 1024}},
597 {{3, 2, 1, 0}},
598 {{1, 1, 0, 0}},
599 {{4, 4, 512, 1024}},
600 {{2, 1, 1, 400}}}, //
601 R4Spec{{{32, 64, 128, 256}},
602 {{3, 2, 1, 0}},
603 {{10, 20, 30, 40}},
604 {{30, 60, 100, 200}},
605 {{11, 21, 31, 41}}}, //
606 R4Spec{{{1, 1, 14, 2048}},
607 {{3, 2, 1, 0}},
608 {{0, 0, 2, 0}},
609 {{1, 1, 14, 2}},
610 {{1, 1, 1, 1}}}, //
611 };
612
613 INSTANTIATE_TEST_CASE_P(SliceR4TestInstantiation, SliceR4Test,
614 ::testing::ValuesIn(kR4SpecValues), R4SpecToString);
615
616 } // namespace
617 } // namespace xla
618