1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Tests that slice operations can be performed.
17 
#include <array>
#include <numeric>
#include <vector>

#include "absl/container/inlined_vector.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
35 
36 namespace xla {
37 namespace {
38 
39 class SliceTest : public ClientLibraryTestBase {};
40 
TEST_F(SliceTest, Slice3x3x3_To_3x3x1_F32) {
  // The cube holds 0..26 in row-major order, so element (i, j, k) is
  // 9*i + 3*j + k.
  Array3D<float> cube(3, 3, 3);
  cube.FillIota(0);

  XlaBuilder b(TestName());
  // Keep all planes and rows, but only index 0 of the minor dimension.
  Slice(ConstantR3FromArray3D<float>(&b, cube), /*start_indices=*/{0, 0, 0},
        /*limit_indices=*/{3, 3, 1}, /*strides=*/{1, 1, 1});

  const Array3D<float> want{{{0.0}, {3.0}, {6.0}},
                            {{9.0}, {12.0}, {15.0}},
                            {{18.0}, {21.0}, {24.0}}};
  ComputeAndCompareR3<float>(&b, want, {}, ErrorSpec(0.000001));
}
53 
TEST_F(SliceTest, Slice3x3x3_To_3x1x3_F32) {
  // The cube holds 0..26 in row-major order, so element (i, j, k) is
  // 9*i + 3*j + k.
  Array3D<float> cube(3, 3, 3);
  cube.FillIota(0);

  XlaBuilder b(TestName());
  // Keep all planes and all minor entries, but only row 0 of each plane.
  Slice(ConstantR3FromArray3D<float>(&b, cube), /*start_indices=*/{0, 0, 0},
        /*limit_indices=*/{3, 1, 3}, /*strides=*/{1, 1, 1});

  const Array3D<float> want{
      {{0.0, 1.0, 2.0}}, {{9.0, 10.0, 11.0}}, {{18.0, 19.0, 20.0}}};
  ComputeAndCompareR3<float>(&b, want, {}, ErrorSpec(0.000001));
}
66 
TEST_F(SliceTest, Slice3x3x3_To_1x3x3_F32) {
  // The cube holds 0..26 in row-major order, so element (i, j, k) is
  // 9*i + 3*j + k.
  Array3D<float> cube(3, 3, 3);
  cube.FillIota(0);

  XlaBuilder b(TestName());
  // Keep only the first plane, in full.
  Slice(ConstantR3FromArray3D<float>(&b, cube), /*start_indices=*/{0, 0, 0},
        /*limit_indices=*/{1, 3, 3}, /*strides=*/{1, 1, 1});

  const Array3D<float> want{
      {{{0.0, 1.0, 2.0}, {3.0, 4.0, 5.0}, {6.0, 7.0, 8.0}}}};
  ComputeAndCompareR3<float>(&b, want, {}, ErrorSpec(0.000001));
}
79 
XLA_TEST_F(SliceTest, Slice0x0to0x0F32) {
  // An empty window over an empty 0x0 matrix is again an empty 0x0 matrix, so
  // the same array serves as both input and expected result.
  const Array2D<float> empty(0, 0);
  XlaBuilder b(TestName());
  Slice(ConstantR2FromArray2D<float>(&b, empty), {0, 0}, {0, 0}, {1, 1});
  ComputeAndCompareR2<float>(&b, empty, {});
}
87 
XLA_TEST_F(SliceTest, Slice0x20to0x5F32) {
  // Columns [15, 20) of a matrix with zero rows: zero rows, five columns.
  XlaBuilder b(TestName());
  const Array2D<float> input(0, 20);
  Slice(ConstantR2FromArray2D<float>(&b, input), {0, 15}, {0, 20}, {1, 1});
  ComputeAndCompareR2<float>(&b, Array2D<float>(0, 5), {});
}
95 
XLA_TEST_F(SliceTest, Slice3x0to2x0F32) {
  // Dropping the first row of a 3x0 matrix leaves a 2x0 matrix.
  XlaBuilder b(TestName());
  const Array2D<float> input(3, 0);
  Slice(ConstantR2FromArray2D<float>(&b, input), {1, 0}, {3, 0}, {1, 1});
  ComputeAndCompareR2<float>(&b, Array2D<float>(2, 0), {});
}
103 
XLA_TEST_F(SliceTest, SliceQuadrantOf256x256) {
  // Each element encodes its coordinates as (row << 10) | col; the largest
  // value, (255 << 10) | 255, is well below 2^24 and so exact as a float.
  const auto encode = [](int row, int col) -> float {
    return static_cast<float>((row << 10) | col);
  };

  Array2D<float> values(256, 256);
  for (int r = 0; r < 256; ++r) {
    for (int c = 0; c < 256; ++c) values(r, c) = encode(r, c);
  }

  XlaBuilder builder(TestName());
  // Take the bottom-right 128x128 quadrant.
  Slice(ConstantR2FromArray2D<float>(&builder, values), {128, 128}, {256, 256},
        {1, 1});

  Array2D<float> expected(128, 128);
  for (int r = 0; r < 128; ++r) {
    for (int c = 0; c < 128; ++c) expected(r, c) = encode(r + 128, c + 128);
  }
  ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.000001));
}
124 
// Tests: (f32[1,4096], starts={0, 3072}, limits={1, 4096}) -> f32[1,1024]
TEST_F(SliceTest, Slice_1x4096_To_1x1024) {
  // A single row holding 0, 1, ..., 4095.
  Array2D<float> values(1, 4096);
  for (int c = 0; c < 4096; ++c) {
    values(0, c) = static_cast<float>(c);
  }

  XlaBuilder builder(TestName());
  // Keep only the last 1024 columns.
  Slice(ConstantR2FromArray2D<float>(&builder, values), {0, 3072}, {1, 4096},
        {1, 1});

  Array2D<float> expected(1, 1024);
  for (int c = 0; c < 1024; ++c) {
    expected(0, c) = static_cast<float>(3072 + c);
  }
  ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.000001));
}
138 
139 // Tests slice: (f32[16,4], starts={0, 0}, limits={16, 2}) -> f32[16,2]
TEST_F(SliceTest, Slice_16x4_To_16x2) {
  // Every element stores its coordinates as (row << 10) | col.
  Array2D<float> values(16, 4);
  for (int r = 0; r < 16; ++r) {
    for (int c = 0; c < 4; ++c) {
      values(r, c) = static_cast<float>((r << 10) | c);
    }
  }

  // The expected result is simply the two leading columns of `values`.
  Array2D<float> expected(16, 2);
  for (int r = 0; r < 16; ++r) {
    for (int c = 0; c < 2; ++c) {
      expected(r, c) = values(r, c);
    }
  }

  XlaBuilder builder(TestName());
  Slice(ConstantR2FromArray2D<float>(&builder, values), {0, 0}, {16, 2},
        {1, 1});
  ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.000001));
}
156 
// Tests: (f32[2, 2, 24, 256], starts = {1, 0, 8, 0}, ends = {2, 2, 16, 128},
// strides = {1, 1, 1, 1}) -> f32[1, 2, 8, 128]
TEST_F(SliceTest, SliceR4ThreeDimsMiddleMinor) {
  // The input is filled via FillRandom(3.14f); the expected result is
  // ReferenceUtil::Slice4D applied to the very same host array.
  Array4D<float> input(2, 2, 24, 256);
  input.FillRandom(3.14f);
  const auto reference = ReferenceUtil::Slice4D(
      input, /*starts=*/{{1, 0, 8, 0}}, /*limits=*/{{2, 2, 16, 128}},
      /*strides=*/{{1, 1, 1, 1}});

  XlaBuilder b(TestName());
  Slice(ConstantR4FromArray4D(&b, input), {1, 0, 8, 0}, {2, 2, 16, 128},
        {1, 1, 1, 1});
  ComputeAndCompareR4(&b, *reference, {}, ErrorSpec(0.000001));
}
168 
TEST_F(SliceTest, SliceOfReshape) {
  // A 144x7 iota matrix (starting at 1) is reshaped to 24x3x2x7, and the first
  // 11 entries of the major dimension are kept. No explicit expected value is
  // handed to ComputeAndCompare.
  Array2D<int> matrix(2 * 3 * 24, 7);
  matrix.FillIota(1);
  XlaBuilder builder(TestName());
  const auto reshaped =
      Reshape(ConstantR2FromArray2D(&builder, matrix), {24, 3, 2, 7});
  Slice(reshaped, {0, 0, 0, 0}, {11, 3, 2, 7}, {1, 1, 1, 1});
  ComputeAndCompare(&builder, {});
}
178 
TEST_F(SliceTest, SliceOfCollapsingReshape) {
  // A 2x3x5x7 iota array (starting at 1) is collapsed to 30x7, and the first
  // four rows are kept. No explicit expected value is handed to
  // ComputeAndCompare.
  Array4D<int> cube(2, 3, 5, 7);
  cube.FillIota(1);
  XlaBuilder builder(TestName());
  const auto collapsed =
      Reshape(ConstantR4FromArray4D(&builder, cube), {2 * 3 * 5, 7});
  Slice(collapsed, {0, 0}, {4, 7}, {1, 1});
  ComputeAndCompare(&builder, {});
}
188 
XLA_TEST_F(SliceTest, StridedSliceR4WithOutputLayout) {
  // Full-extent slice of a 2x4x6x8 array with stride 2 along dimension 2. The
  // result is requested in the layout built from {0, 1, 2, 3}, and the
  // comparison passes that literal's shape as the shape-with-layout.
  Array4D<float> input(2, 4, 6, 8);
  input.FillRandom(3.14f);
  const auto reference =
      ReferenceUtil::Slice4D(input, {{0, 0, 0, 0}}, {{2, 4, 6, 8}},
                             /*strides=*/{{1, 1, 2, 1}});
  const auto layout = LayoutUtil::MakeLayout({0, 1, 2, 3});
  auto want = LiteralUtil::CreateR4FromArray4DWithLayout(*reference, layout);

  XlaBuilder b(TestName());
  Slice(ConstantR4FromArray4D(&b, input), {0, 0, 0, 0}, {2, 4, 6, 8},
        {1, 1, 2, 1});
  ComputeAndCompareLiteral(&b, want, {}, ErrorSpec(0.000001), &want.shape());
}
202 
203 struct R1Spec {
204   int64 input_dim0;
205   int64 slice_start;
206   int64 slice_limit;
207   int64 slice_stride;
208 };
209 
210 // Parameterized test that generates R1 values, slices them according
211 // to the R1Spec, and compares the result with a computed version.
212 class SliceR1Test : public ClientLibraryTestBase,
213                     public ::testing::WithParamInterface<R1Spec> {
214  protected:
215   template <typename NativeT>
Run(const R1Spec & spec)216   void Run(const R1Spec& spec) {
217     // This can't be an std::vector, since you can't grab a Span of a
218     // vector<bool>.
219     absl::InlinedVector<NativeT, 1> input(spec.input_dim0);
220     std::iota(input.begin(), input.end(), NativeT());
221     auto literal = LiteralUtil::CreateR1<NativeT>(input);
222 
223     XlaBuilder builder(TestName());
224     auto original = Parameter(&builder, 0, literal.shape(), "p0");
225     Slice(original, {spec.slice_start}, {spec.slice_limit},
226           {spec.slice_stride});
227 
228     // Ditto.
229     absl::InlinedVector<NativeT, 1> expected;
230     for (int i = spec.slice_start; i < spec.slice_limit;
231          i += spec.slice_stride) {
232       expected.push_back(i);
233     }
234 
235     TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> arg,
236                             client_->TransferToServer(literal));
237     ComputeAndCompareR1<NativeT>(&builder, expected, {arg.get()});
238   }
239 };
240 
241 // A version of SliceR1Test used to label and disable 'large' tests
242 class SliceR1LargeTest : public SliceR1Test {};
243 
SliceR1TestDataToString(const::testing::TestParamInfo<R1Spec> & data)244 string SliceR1TestDataToString(const ::testing::TestParamInfo<R1Spec>& data) {
245   const R1Spec& spec = data.param;
246   return absl::StrFormat("%d_%d_%d_%d", spec.input_dim0, spec.slice_start,
247                          spec.slice_limit, spec.slice_stride);
248 }
249 
XLA_TEST_P(SliceR1Test, DoIt_F32) {
  const R1Spec& spec = GetParam();
  Run<float>(spec);
}
251 
XLA_TEST_P(SliceR1Test, DoIt_F64) {
  const R1Spec& spec = GetParam();
  Run<double>(spec);
}
253 
XLA_TEST_P(SliceR1Test, DoIt_U32) {
  const R1Spec& spec = GetParam();
  Run<uint32>(spec);
}
255 
XLA_TEST_P(SliceR1Test, DoIt_S32) {
  const R1Spec& spec = GetParam();
  Run<int32>(spec);
}
257 
XLA_TEST_P(SliceR1Test, DoIt_U64) {
  const R1Spec& spec = GetParam();
  Run<uint64>(spec);
}
259 
XLA_TEST_P(SliceR1Test, DoIt_S64) {
  const R1Spec& spec = GetParam();
  Run<int64>(spec);
}
261 
XLA_TEST_P(SliceR1LargeTest, DoIt_F32) {
  const R1Spec& spec = GetParam();
  Run<float>(spec);
}
263 
XLA_TEST_P(SliceR1LargeTest, DoIt_F64) {
  const R1Spec& spec = GetParam();
  Run<double>(spec);
}
265 
XLA_TEST_P(SliceR1LargeTest, DoIt_U32) {
  const R1Spec& spec = GetParam();
  Run<uint32>(spec);
}
267 
XLA_TEST_P(SliceR1LargeTest, DoIt_S32) {
  const R1Spec& spec = GetParam();
  Run<int32>(spec);
}
269 
XLA_TEST_P(SliceR1LargeTest, DoIt_U64) {
  const R1Spec& spec = GetParam();
  Run<uint64>(spec);
}
271 
XLA_TEST_P(SliceR1LargeTest, DoIt_S64) {
  const R1Spec& spec = GetParam();
  Run<int64>(spec);
}
273 
XLA_TEST_P(SliceR1Test, DoIt_PRED) {
  const R1Spec& spec = GetParam();
  Run<bool>(spec);
}
275 
276 // Tests for R1 slice ops.
277 // The format for each testcase is {input size, start, limit, stride}.
278 // clang-format off
INSTANTIATE_TEST_CASE_P(
    SliceR1TestInstantiation,
    SliceR1Test,
    ::testing::Values(
        // 10-element input: empty slices (start == limit), a prefix, an
        // interior window and the full range.
        R1Spec{10, 0, 0, 1},
        R1Spec{10, 7, 7, 1},
        R1Spec{10, 0, 5, 1},
        R1Spec{10, 3, 5, 1},
        R1Spec{10, 0, 10, 1},
        // 1024 and 1024 + 17 element inputs: short windows near the front,
        // windows starting at or just past index 1024, and the last four
        // elements of a 1024-element input.
        R1Spec{1024, 0, 5, 1},
        R1Spec{1024, 3, 5, 1},
        R1Spec{1024 + 17, 0, 5, 1},
        R1Spec{1024 + 17, 3, 5, 1},
        R1Spec{1024 + 17, 1024, 1024 + 6, 1},
        R1Spec{1024 + 17, 1024 + 1, 1024 + 6, 1},
        R1Spec{1024, 1024 - 4, 1024, 1},
        // 4096-element inputs: long windows, most with bounds on or one off a
        // multiple of 1024.
        R1Spec{4 * 1024, 7, 7 + 1024, 1},
        R1Spec{4 * 1024, 0, 4 * 1024, 1},
        R1Spec{4 * 1024, 1, 4 * 1024 - 1, 1},
        R1Spec{4 * 1024, 1024, 3 * 1024, 1},
        R1Spec{4 * 1024, 1024 + 1, 3 * 1024 - 1, 1},
        // 16K and 16K + 17 element inputs.
        R1Spec{16 * 1024, 0, 5, 1},
        R1Spec{16 * 1024, 3, 5, 1},
        R1Spec{16 * 1024 + 17, 0, 5, 1},
        R1Spec{16 * 1024 + 17, 3, 5, 1},
        R1Spec{16 * 1024 + 17, 16 * 1024, 16 * 1024 + 6, 1},
        R1Spec{16 * 1024 + 17, 16 * 1024 + 1, 16 * 1024 + 6, 1},
        R1Spec{16 * 1024, 4 * 1024 - 17, 8 * 1024 - 18, 1},
        // 64K-element inputs, including windows around the midpoint.
        R1Spec{64 * 1024, 0, 64 * 1024, 1},
        R1Spec{64 * 1024, 1, 64 * 1024 - 1, 1},
        R1Spec{64 * 1024, 1024, 63 * 1024, 1},
        R1Spec{64 * 1024, 1024 + 1, 63 * 1024 - 1, 1},
        R1Spec{64 * 1024, 32 * 1024, 33 * 1024, 1},
        R1Spec{64 * 1024, 32 * 1024 + 1, 33 * 1024 - 1, 1},
        R1Spec{64 * 1024, 32 * 1024 - 17, 36 * 1024 - 18, 1}
    ),
    SliceR1TestDataToString
);
317 
// TODO(b/69425338): This uses too much memory on GPU.
// 16M-element inputs sliced over windows of roughly 8M elements (aligned and
// one off in either direction). Registered only when XLA_TEST_BACKEND_GPU is
// not defined, and run through SliceR1LargeTest.
#ifndef XLA_TEST_BACKEND_GPU
INSTANTIATE_TEST_CASE_P(
    SliceR1TestBigSlicesInstantiation,
    SliceR1LargeTest,
    ::testing::Values(
          R1Spec{
              16 * 1024 * 1024, 4 * 1024 * 1024, 12 * 1024 * 1024, 1},
          R1Spec{
              16 * 1024 * 1024, 4 * 1024 * 1024 + 1, 12 * 1024 * 1024 - 1, 1},
          R1Spec{
              16 * 1024 * 1024, 4 * 1024 * 1024 - 1, 12 * 1024 * 1024 + 1, 1}
    ),
    SliceR1TestDataToString
);
#endif
334 
// Strided R1 cases, registered on SliceR1Test alongside the unit-stride ones.
INSTANTIATE_TEST_CASE_P(
    SliceStridedR1TestInstantiation,
    SliceR1Test,
    ::testing::Values(
        // 10-element input with strides from 2 up to the full length.
        R1Spec{10, 2, 4, 2},
        R1Spec{10, 0, 10, 2},
        R1Spec{10, 0, 10, 3},
        R1Spec{10, 0, 10, 4},
        R1Spec{10, 0, 10, 5},
        R1Spec{10, 0, 10, 10},
        // Medium inputs whose strides do not evenly divide the window length.
        R1Spec{500, 200, 400, 7},
        R1Spec{4096, 1, 4095, 3},
        R1Spec{2047, 1024 - 24, 1024 + 160, 31},
        R1Spec{2047, 1, 2046, 3 * 128},
        R1Spec{4096, 1024 + 3, 4095, 500},
        R1Spec{8192, 0, 8192, 1024 * 3 + 400},
        // 1M-element input: the full range, a short window and a roughly
        // half-length window, each with strides 2, 8, 7 and 125.
        R1Spec{1024 * 1024, 0, 1024 * 1024, 2},
        R1Spec{1024 * 1024, 0, 1024 * 1024, 8},
        R1Spec{1024 * 1024, 0, 1024 * 1024, 7},
        R1Spec{1024 * 1024, 0, 1024 * 1024, 125},
        R1Spec{1024 * 1024, 3, 1024 - 9, 2},
        R1Spec{1024 * 1024, 3, 1024 - 9, 8},
        R1Spec{1024 * 1024, 3, 1024 - 9, 7},
        R1Spec{1024 * 1024, 3, 1024 - 9, 125},
        R1Spec{1024 * 1024, 3, 1024 * 512 - 9, 2},
        R1Spec{1024 * 1024, 3, 1024 * 512 - 9, 8},
        R1Spec{1024 * 1024, 3, 1024 * 512 - 9, 7},
        R1Spec{1024 * 1024, 3, 1024 * 512 - 9, 125},
        // Same half-length window on a 1M + 71 element input.
        R1Spec{1024 * 1024 + 71, 3, 1024 * 512 - 9, 2},
        R1Spec{1024 * 1024 + 71, 3, 1024 * 512 - 9, 8},
        R1Spec{1024 * 1024 + 71, 3, 1024 * 512 - 9, 7},
        R1Spec{1024 * 1024 + 71, 3, 1024 * 512 - 9, 125},
        // 16M-element input with strides just above and below 4096.
        // NOTE(review): unlike SliceR1TestBigSlicesInstantiation, these
        // 16M-element cases are not excluded on GPU — confirm that is intended.
        R1Spec{16 * 1024 * 1024, 0, 16 * 1024 * 1024, 4097},
        R1Spec{16 * 1024 * 1024, 0, 16 * 1024 * 1024, 4093},
        R1Spec{16 * 1024 * 1024, 12 * 1024 + 17, 16 * 1024 * 1024 - 231, 4097},
        R1Spec{16 * 1024 * 1024, 12 * 1024 + 17, 16 * 1024 * 1024 - 231, 4093}
    ),
    SliceR1TestDataToString
);
374 // clang-format on
375 
376 struct R2Spec {
377   int64 input_dim0;
378   int64 input_dim1;
379   std::array<int64, 2> slice_starts;
380   std::array<int64, 2> slice_limits;
381   std::array<int64, 2> slice_strides;
382   std::array<int64, 2> layout;
383 };
384 
385 // Parameterized test that generates patterned R2 values, slices them according
386 // to the R2Spec, and compares the results with the ReferenceUtil version.
387 class SliceR2Test : public ClientLibraryTestBase,
388                     public ::testing::WithParamInterface<R2Spec> {};
389 
XLA_TEST_P(SliceR2Test, DoIt) {
  const R2Spec& spec = GetParam();

  // Host input with a distinct value per cell; the expected result is the
  // reference slice of that same array.
  Array2D<int32> input(spec.input_dim0, spec.input_dim1);
  input.FillUnique();
  const std::unique_ptr<Array2D<int32>> expected = ReferenceUtil::Slice2D(
      input, spec.slice_starts, spec.slice_limits, spec.slice_strides);

  // The device sees the input as parameter 0, laid out per spec.layout.
  auto literal = LiteralUtil::CreateR2FromArray2DWithLayout(
      input, LayoutUtil::MakeLayout(spec.layout));
  XlaBuilder builder(TestName());
  Slice(Parameter(&builder, 0, literal.shape(), "p0"), spec.slice_starts,
        spec.slice_limits, spec.slice_strides);

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> arg,
                          client_->TransferToServer(literal));
  ComputeAndCompareR2<int32>(&builder, *expected, {arg.get()});
}
407 
// R2 slice cases. Each entry is {input_dim0, input_dim1, slice_starts,
// slice_limits, slice_strides, layout}. Most configurations appear once per
// layout ({0, 1} and {1, 0}). No name generator is passed, so the cases are
// identified by their position in this list.
INSTANTIATE_TEST_CASE_P(
    SliceR2TestInstantiation, SliceR2Test,
    ::testing::Values(
        // Unit strides.
        R2Spec{4, 12, {{0, 3}}, {{4, 6}}, {{1, 1}}, {{0, 1}}},              //
        R2Spec{4, 12, {{0, 3}}, {{4, 6}}, {{1, 1}}, {{1, 0}}},              //
        R2Spec{16, 4, {{0, 2}}, {{16, 4}}, {{1, 1}}, {{0, 1}}},             //
        R2Spec{16, 4, {{0, 2}}, {{16, 4}}, {{1, 1}}, {{1, 0}}},             //
        R2Spec{256, 400, {{0, 300}}, {{256, 400}}, {{1, 1}}, {{1, 0}}},     //
        R2Spec{500, 400, {{111, 123}}, {{300, 257}}, {{1, 1}}, {{1, 0}}},   //
        R2Spec{500, 400, {{111, 123}}, {{300, 400}}, {{1, 1}}, {{1, 0}}},   //
        R2Spec{384, 512, {{128, 256}}, {{256, 384}}, {{1, 1}}, {{1, 0}}},   //
        R2Spec{357, 512, {{111, 256}}, {{301, 384}}, {{1, 1}}, {{1, 0}}},   //
        // Non-unit strides, from small ones up to strides larger than 100.
        R2Spec{10, 10, {{0, 0}}, {{10, 10}}, {{1, 2}}, {{0, 1}}},           //
        R2Spec{10, 10, {{0, 0}}, {{10, 10}}, {{1, 2}}, {{1, 0}}},           //
        R2Spec{10, 10, {{0, 0}}, {{10, 10}}, {{2, 1}}, {{0, 1}}},           //
        R2Spec{10, 10, {{0, 0}}, {{10, 10}}, {{2, 1}}, {{1, 0}}},           //
        R2Spec{10, 10, {{0, 0}}, {{10, 10}}, {{2, 2}}, {{0, 1}}},           //
        R2Spec{10, 10, {{0, 0}}, {{10, 10}}, {{2, 2}}, {{1, 0}}},           //
        R2Spec{256, 400, {{100, 129}}, {{256, 400}}, {{3, 5}}, {{1, 0}}},   //
        R2Spec{256, 400, {{100, 129}}, {{256, 400}}, {{3, 5}}, {{0, 1}}},   //
        R2Spec{256, 400, {{100, 129}}, {{256, 400}}, {{5, 3}}, {{1, 0}}},   //
        R2Spec{256, 400, {{100, 129}}, {{256, 400}}, {{5, 3}}, {{0, 1}}},   //
        R2Spec{511, 513, {{129, 300}}, {{400, 500}}, {{7, 11}}, {{1, 0}}},  //
        R2Spec{511, 513, {{129, 300}}, {{400, 500}}, {{7, 11}}, {{0, 1}}},  //
        R2Spec{511, 513, {{129, 300}}, {{400, 500}}, {{11, 7}}, {{1, 0}}},  //
        R2Spec{511, 513, {{129, 300}}, {{400, 500}}, {{11, 7}}, {{0, 1}}},  //
        R2Spec{8672, 512, {{8, 0}}, {{8672, 512}}, {{542, 1}}, {{1, 0}}},   //
        R2Spec{
            511, 513, {{129, 300}}, {{400, 500}}, {{101, 129}}, {{1, 0}}},  //
        R2Spec{
            511, 513, {{129, 300}}, {{400, 500}}, {{101, 129}}, {{0, 1}}},  //
        R2Spec{
            511, 513, {{129, 300}}, {{400, 500}}, {{129, 101}}, {{1, 0}}},  //
        R2Spec{
            511, 513, {{129, 300}}, {{400, 500}}, {{129, 101}}, {{0, 1}}},  //
        R2Spec{
            511, 1023, {{129, 257}}, {{500, 1000}}, {{129, 255}}, {{1, 0}}},  //
        R2Spec{
            511, 1023, {{129, 257}}, {{500, 1000}}, {{129, 255}}, {{0, 1}}},  //
        R2Spec{511,
               513,
               {{129, 255}},
               {{511 - 129, 513 - 140}},
               {{13, 19}},
               {{1, 0}}},  //
        R2Spec{511,
               513,
               {{129, 255}},
               {{511 - 129, 513 - 140}},
               {{13, 19}},
               {{0, 1}}}  //
        ));
460 
461 struct R4Spec {
462   std::array<int64, 4> input_dims;
463   std::array<int64, 4> input_layout;  // minor-to-major
464   std::array<int64, 4> slice_starts;
465   std::array<int64, 4> slice_limits;
466   std::array<int64, 4> slice_strides;
467 };
468 
R4SpecToString(const::testing::TestParamInfo<R4Spec> & data)469 string R4SpecToString(const ::testing::TestParamInfo<R4Spec>& data) {
470   const R4Spec& spec = data.param;
471   return absl::StrCat("input_", absl::StrJoin(spec.input_dims, "x"),
472                       "__layout_", absl::StrJoin(spec.input_layout, ""),
473                       "__starts_", absl::StrJoin(spec.slice_starts, "x"),
474                       "__limits_", absl::StrJoin(spec.slice_limits, "x"),
475                       "__strides_", absl::StrJoin(spec.slice_strides, "x"));
476 }
477 
478 class SliceR4Test : public ClientLibraryTestBase,
479                     public ::testing::WithParamInterface<R4Spec> {
480  protected:
Run(const R4Spec & spec)481   void Run(const R4Spec& spec) {
482     Array4D<float> values(spec.input_dims[0], spec.input_dims[1],
483                           spec.input_dims[2], spec.input_dims[3]);
484     values.FillIota(3.14159);
485     auto expected = ReferenceUtil::Slice4D(
486         values, spec.slice_starts, spec.slice_limits, spec.slice_strides);
487     XlaBuilder builder(TestName());
488     auto literal = LiteralUtil::CreateR4FromArray4DWithLayout(
489         values, LayoutUtil::MakeLayout(spec.input_layout));
490     auto parameter = Parameter(&builder, 0, literal.shape(), "p0");
491     TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> arg,
492                             client_->TransferToServer(literal));
493     Slice(parameter, spec.slice_starts, spec.slice_limits, spec.slice_strides);
494     ComputeAndCompareR4(&builder, *expected, {arg.get()}, ErrorSpec(0.000001));
495   }
496 };
497 
XLA_TEST_P(SliceR4Test, DoIt) {
  const R4Spec& spec = GetParam();
  Run(spec);
}
499 
// Cases for SliceR4Test. Each entry is {input_dims, input_layout
// (minor-to-major), slice_starts, slice_limits, slice_strides}; test names come
// from R4SpecToString.
const R4Spec kR4SpecValues[] = {
    // Empty slice: every limit equals its start.
    R4Spec{{{2, 2, 2, 2}},
           {{3, 2, 1, 0}},
           {{0, 0, 0, 0}},
           {{0, 0, 0, 0}},
           {{1, 1, 1, 1}}},  //
    R4Spec{{{3, 3, 4, 4}},
           {{3, 2, 1, 0}},
           {{0, 0, 0, 0}},
           {{3, 3, 4, 4}},
           {{1, 1, 2, 1}}},  //
    R4Spec{{{2, 3, 16, 4}},
           {{3, 2, 1, 0}},
           {{0, 0, 0, 0}},
           {{2, 3, 16, 4}},
           {{1, 1, 3, 1}}},  //
    // The only case whose input layout is {0, 1, 2, 3} rather than
    // {3, 2, 1, 0}.
    R4Spec{{{4, 16, 3, 2}},
           {{0, 1, 2, 3}},
           {{1, 4, 1, 0}},
           {{3, 12, 3, 2}},
           {{1, 1, 3, 2}}},  //
    R4Spec{{{2, 2, 257, 129}},
           {{3, 2, 1, 0}},
           {{1, 1, 62, 64}},
           {{2, 2, 195, 129}},
           {{1, 1, 3, 1}}},  //
    R4Spec{{{3, 5, 257, 129}},
           {{3, 2, 1, 0}},
           {{1, 2, 61, 64}},
           {{3, 5, 199, 129}},
           {{1, 1, 3, 1}}},  //
    R4Spec{{{5, 8, 257, 129}},
           {{3, 2, 1, 0}},
           {{2, 3, 60, 64}},
           {{3, 5, 200, 68}},
           {{1, 1, 1, 1}}},  //
    R4Spec{{{8, 10, 256, 130}},
           {{3, 2, 1, 0}},
           {{1, 2, 60, 127}},
           {{7, 9, 166, 129}},
           {{4, 2, 3, 1}}},  //
    R4Spec{{{2, 4, 8, 4}},
           {{3, 2, 1, 0}},
           {{1, 2, 0, 1}},
           {{2, 4, 8, 3}},
           {{1, 1, 7, 1}}},  //
    R4Spec{{{10, 21, 256, 150}},
           {{3, 2, 1, 0}},
           {{1, 2, 9, 127}},
           {{9, 16, 82, 133}},
           {{3, 5, 7, 2}}},  //
    R4Spec{{{15, 25, 256, 150}},
           {{3, 2, 1, 0}},
           {{4, 6, 19, 126}},
           {{15, 25, 89, 135}},
           {{5, 7, 7, 3}}},  //
    R4Spec{{{2, 4, 256, 150}},
           {{3, 2, 1, 0}},
           {{1, 2, 29, 125}},
           {{2, 4, 159, 145}},
           {{1, 1, 7, 7}}},  //
    R4Spec{{{2, 4, 256, 150}},
           {{3, 2, 1, 0}},
           {{1, 2, 39, 119}},
           {{2, 4, 158, 145}},
           {{1, 1, 7, 11}}},  //
    R4Spec{{{1, 1, 5, 512}},
           {{3, 2, 1, 0}},
           {{0, 0, 0, 0}},
           {{1, 1, 5, 512}},
           {{1, 1, 4, 1}}},  //
    // Strides of 512 over extents of 513: two elements kept per dimension.
    R4Spec{{{1, 1, 513, 513}},
           {{3, 2, 1, 0}},
           {{0, 0, 0, 0}},
           {{1, 1, 513, 513}},
           {{1, 1, 512, 512}}},  //
    R4Spec{{{1, 1, 1024, 4}},
           {{3, 2, 1, 0}},
           {{0, 0, 15, 0}},
           {{1, 1, 1022, 4}},
           {{1, 1, 23, 1}}},  //
    R4Spec{{{1, 1, 1024, 4}},
           {{3, 2, 1, 0}},
           {{0, 0, 14, 0}},
           {{1, 1, 1023, 4}},
           {{1, 1, 101, 1}}},  //
    R4Spec{{{1, 1, 4, 1024}},
           {{3, 2, 1, 0}},
           {{0, 0, 1, 20}},
           {{1, 1, 4, 1023}},
           {{1, 1, 1, 129}}},  //
    R4Spec{{{5, 5, 512, 1024}},
           {{3, 2, 1, 0}},
           {{1, 1, 0, 0}},
           {{4, 4, 512, 1024}},
           {{2, 2, 2, 1}}},  //
    R4Spec{{{5, 5, 512, 1024}},
           {{3, 2, 1, 0}},
           {{1, 1, 0, 0}},
           {{4, 4, 512, 1024}},
           {{2, 1, 1, 400}}},  //
    // NOTE(review): 32*64*128*256 = 2^26 elements. SliceR4Test::Run fills the
    // input with FillIota on float, and a float cannot represent every integer
    // step above 2^24, so later values may not be distinct. Confirm this case
    // still discriminates slicing errors.
    R4Spec{{{32, 64, 128, 256}},
           {{3, 2, 1, 0}},
           {{10, 20, 30, 40}},
           {{30, 60, 100, 200}},
           {{11, 21, 31, 41}}},  //
    R4Spec{{{1, 1, 14, 2048}},
           {{3, 2, 1, 0}},
           {{0, 0, 2, 0}},
           {{1, 1, 14, 2}},
           {{1, 1, 1, 1}}},  //
};
612 
613 INSTANTIATE_TEST_CASE_P(SliceR4TestInstantiation, SliceR4Test,
614                         ::testing::ValuesIn(kR4SpecValues), R4SpecToString);
615 
616 }  // namespace
617 }  // namespace xla
618