1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include <memory>
#include <numeric>
#include <vector>

#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/core/platform/test.h"
32
33 namespace xla {
34 namespace {
35
36 using ConcatTest = ClientLibraryTestBase;
37 using ::testing::HasSubstr;
38
39 // Concatenate expects at least one argument.
XLA_TEST_F(ConcatTest,Concat_Nothing)40 XLA_TEST_F(ConcatTest, Concat_Nothing) {
41 XlaBuilder builder(TestName());
42 ConcatInDim(&builder, {}, 0);
43 StatusOr<XlaComputation> computation_status = builder.Build();
44 ASSERT_FALSE(computation_status.ok());
45 EXPECT_THAT(computation_status.status().ToString(),
46 HasSubstr("Concatenate expects at least one argument"));
47 }
48
49 // Concatenate with one argument works.
XLA_TEST_F(ConcatTest,Concat_R1_With_Nothing)50 XLA_TEST_F(ConcatTest, Concat_R1_With_Nothing) {
51 XlaBuilder builder(TestName());
52 auto a = ConstantR1<float>(&builder, {42.0, 64.0});
53 ConcatInDim(&builder, {a}, 0);
54
55 std::vector<float> expected = {42, 64};
56 ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
57 }
58
// A single empty R1 operand concatenates to an empty R1 result.
XLA_TEST_F(ConcatTest, Concat_R1_L0_With_Nothing) {
  XlaBuilder builder(TestName());
  auto empty = ConstantR1<float>(&builder, {});
  ConcatInDim(&builder, {empty}, /*dimension=*/0);

  const std::vector<float> want;
  ComputeAndCompareR1<float>(&builder, want, {}, ErrorSpec(0.0001));
}
67
68 // Show that we can't concatenate R0 with R0 because we can't name the dimension
69 // to concatenate on.
XLA_TEST_F(ConcatTest,CannotConcatR0WithR0)70 XLA_TEST_F(ConcatTest, CannotConcatR0WithR0) {
71 XlaBuilder builder(TestName());
72 auto a = ConstantR0<float>(&builder, 42.0);
73 auto b = ConstantR0<float>(&builder, 64.0);
74 ConcatInDim(&builder, {a, b}, 0);
75 StatusOr<XlaComputation> computation_status = builder.Build();
76 ASSERT_FALSE(computation_status.ok());
77 EXPECT_THAT(computation_status.status().ToString(),
78 HasSubstr("out of bounds: 0"));
79 }
80
// Two empty R1 operands concatenate to an empty R1.
XLA_TEST_F(ConcatTest, Concat_R1_L0_With_R1_L0) {
  XlaBuilder builder(TestName());
  auto first = ConstantR1<float>(&builder, {});
  auto second = ConstantR1<float>(&builder, {});
  ConcatInDim(&builder, {first, second}, /*dimension=*/0);

  const std::vector<float> want;
  ComputeAndCompareR1<float>(&builder, want, {}, ErrorSpec(0.0001));
}
90
// An empty lhs followed by a one-element rhs yields just the rhs element.
XLA_TEST_F(ConcatTest, Concat_R1_L0_With_R1_L1) {
  XlaBuilder builder(TestName());
  auto empty = ConstantR1<float>(&builder, {});
  auto single = ConstantR1<float>(&builder, {256.0});
  ConcatInDim(&builder, {empty, single}, /*dimension=*/0);

  const std::vector<float> want = {256};
  ComputeAndCompareR1<float>(&builder, want, {}, ErrorSpec(0.0001));
}
100
// A trailing empty operand leaves the leading operand unchanged.
XLA_TEST_F(ConcatTest, Concat_R1_L2_With_R1_L0) {
  XlaBuilder builder(TestName());
  auto pair = ConstantR1<float>(&builder, {42.0, 64.0});
  auto empty = ConstantR1<float>(&builder, {});
  ConcatInDim(&builder, {pair, empty}, /*dimension=*/0);

  const std::vector<float> want = {42, 64};
  ComputeAndCompareR1<float>(&builder, want, {}, ErrorSpec(0.0001));
}
110
// A two-element R1 followed by a one-element R1 yields all three, in order.
XLA_TEST_F(ConcatTest, Concat_R1_L2_With_R1_L1) {
  XlaBuilder builder(TestName());
  auto head = ConstantR1<float>(&builder, {42.0, 64.0});
  auto tail = ConstantR1<float>(&builder, {256.0});
  ConcatInDim(&builder, {head, tail}, /*dimension=*/0);

  const std::vector<float> want = {42, 64, 256};
  ComputeAndCompareR1<float>(&builder, want, {}, ErrorSpec(0.0001));
}
120
// Concatenates R1 operands of odd lengths 253 and 7; the result must be the
// values 1..260 in order.
XLA_TEST_F(ConcatTest, Concat_R1_L253_With_R1_L7) {
  constexpr int kLhsLength = 253;
  constexpr int kRhsLength = 7;

  // Build the full expected sequence first, then slice the two operands out
  // of it so they are consistent by construction.
  std::vector<float> expected(kLhsLength + kRhsLength);
  for (size_t pos = 0; pos < expected.size(); ++pos) {
    expected[pos] = static_cast<float>(pos + 1);
  }
  const std::vector<float> lhs(expected.begin(),
                               expected.begin() + kLhsLength);
  const std::vector<float> rhs(expected.begin() + kLhsLength, expected.end());

  XlaBuilder builder(TestName());
  auto a = ConstantR1<float>(&builder, lhs);
  auto b = ConstantR1<float>(&builder, rhs);
  ConcatInDim(&builder, {a, b}, /*dimension=*/0);

  ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
139
// Concatenating two 0x0 matrices along either dimension yields a 0x0 matrix.
XLA_TEST_F(ConcatTest, Concat_0x0_With_0x0) {
  for (const int dim : {0, 1}) {
    XlaBuilder builder(TestName());
    const Array2D<float> empty(0, 0);
    auto lhs = ConstantR2FromArray2D(&builder, empty);
    auto rhs = ConstantR2FromArray2D(&builder, empty);
    ConcatInDim(&builder, {lhs, rhs}, dim);

    ComputeAndCompareR2<float>(&builder, empty, {}, ErrorSpec(0.0001));
  }
}
151
// Stacking two 1x1 matrices along dimension 0 gives a 2x1 column.
XLA_TEST_F(ConcatTest, Concat_1x1_With_1x1_InDim0) {
  XlaBuilder builder(TestName());
  auto top = CreatePatternedMatrix(1, 1);
  auto bottom = CreatePatternedMatrix(1, 1, /*offset=*/64.0);
  auto top_op = ConstantR2FromArray2D(&builder, *top);
  auto bottom_op = ConstantR2FromArray2D(&builder, *bottom);
  ConcatInDim(&builder, {top_op, bottom_op}, /*dimension=*/0);

  const Array2D<float> expected({{0}, {64}});
  ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
166
// Placing two 1x1 matrices side by side along dimension 1 gives a 1x2 row.
XLA_TEST_F(ConcatTest, Concat_1x1_With_1x1_InDim1) {
  XlaBuilder builder(TestName());
  auto left = CreatePatternedMatrix(1, 1);
  auto right = CreatePatternedMatrix(1, 1, /*offset=*/64.0);
  auto left_op = ConstantR2FromArray2D(&builder, *left);
  auto right_op = ConstantR2FromArray2D(&builder, *right);
  ConcatInDim(&builder, {left_op, right_op}, /*dimension=*/1);

  const Array2D<float> expected({{0, 64}});
  ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
180
// A zero-column operand contributes nothing to a dimension-1 concatenation.
XLA_TEST_F(ConcatTest, Concat2x0With2x5) {
  XlaBuilder builder(TestName());
  auto populated = CreatePatternedMatrix(2, 5, /*offset=*/64.0);
  auto empty_op = ConstantR2FromArray2D(&builder, Array2D<float>(2, 0));
  auto populated_op = ConstantR2FromArray2D(&builder, *populated);
  ConcatInDim(&builder, {empty_op, populated_op}, /*dimension=*/1);

  // The result is exactly the non-empty operand.
  ComputeAndCompareR2<float>(&builder, *populated, {}, ErrorSpec(0.0001));
}
190
// Concatenates a 2x3 and a 2x5 patterned matrix along columns (dimension 1).
XLA_TEST_F(ConcatTest, Concat2x3With2x5) {
  XlaBuilder builder(TestName());
  auto left = CreatePatternedMatrix(2, 3);
  auto right = CreatePatternedMatrix(2, 5, /*offset=*/64.0);
  auto lhs = ConstantR2FromArray2D(&builder, *left);
  auto rhs = ConstantR2FromArray2D(&builder, *right);
  ConcatInDim(&builder, {lhs, rhs}, /*dimension=*/1);

  // Each output row is the lhs row followed by the matching rhs row.
  const Array2D<float> expected({
      {0, 1, 2, 64, 65, 66, 67, 68},
      {1000, 1001, 1002, 1064, 1065, 1066, 1067, 1068},
  });
  ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
205
// A zero-row operand contributes nothing to a dimension-0 concatenation.
XLA_TEST_F(ConcatTest, Concat3x2With0x2) {
  XlaBuilder builder(TestName());
  auto populated = CreatePatternedMatrix(3, 2);
  auto populated_op = ConstantR2FromArray2D(&builder, *populated);
  auto empty_op = ConstantR2FromArray2D(&builder, Array2D<float>(0, 2));
  ConcatInDim(&builder, {populated_op, empty_op}, /*dimension=*/0);

  // The result is exactly the non-empty operand.
  ComputeAndCompareR2<float>(&builder, *populated, {}, ErrorSpec(0.0001));
}
215
// Stacks a 3x2 patterned matrix on top of a 5x2 one (dimension 0).
XLA_TEST_F(ConcatTest, Concat3x2With5x2) {
  XlaBuilder builder(TestName());
  auto upper = CreatePatternedMatrix(3, 2);
  auto lower = CreatePatternedMatrix(5, 2, /*offset=*/64.0);
  auto upper_op = ConstantR2FromArray2D(&builder, *upper);
  auto lower_op = ConstantR2FromArray2D(&builder, *lower);
  ConcatInDim(&builder, {upper_op, lower_op}, /*dimension=*/0);

  // The three upper rows come first, then the five lower rows.
  const Array2D<float> expected({
      {0, 1},
      {1000, 1001},
      {2000, 2001},
      {64, 65},
      {1064, 1065},
      {2064, 2065},
      {3064, 3065},
      {4064, 4065},
  });
  ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
236
// R3 operands with an empty middle dimension still concatenate: the result is
// an empty 3x0x3 array (last dimension 2 + 1).
XLA_TEST_F(ConcatTest, Concat_R3_3x0x2_3x0x1) {
  XlaBuilder builder(TestName());
  auto wide = ConstantR3FromArray3D(&builder, Array3D<float>(3, 0, 2));
  auto narrow = ConstantR3FromArray3D(&builder, Array3D<float>(3, 0, 1));
  ConcatInDim(&builder, {wide, narrow}, /*dimension=*/2);

  const Array3D<float> expected(3, 0, 3);
  ComputeAndCompareR3<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
245
// Concatenates a 3x1x2 and a 3x1x1 array along the last dimension, so each
// innermost row gains one trailing element.
XLA_TEST_F(ConcatTest, Concat_R3_3x1x2_3x1x1) {
  XlaBuilder builder(TestName());
  const Array3D<float> pairs({{{0, 1}}, {{2, 3}}, {{4, 5}}});  // 3x1x2
  const Array3D<float> singles({{{6}}, {{7}}, {{8}}});         // 3x1x1
  auto pairs_op = ConstantR3FromArray3D(&builder, pairs);
  auto singles_op = ConstantR3FromArray3D(&builder, singles);
  ConcatInDim(&builder, {pairs_op, singles_op}, /*dimension=*/2);

  const Array3D<float> expected({
      {{0, 1, 6}},
      {{2, 3, 7}},
      {{4, 5, 8}},
  });
  ComputeAndCompareR3<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
271
// Three single-element R1 operands concatenate in operand order.
XLA_TEST_F(ConcatTest, Concat_R1_1x1_1x1_1x1) {
  XlaBuilder builder(TestName());
  const std::vector<float> expected = {42, 64, 256};

  // One single-element constant per expected value, created in order.
  std::vector<XlaOp> operands;
  for (const float value : expected) {
    operands.push_back(ConstantR1<float>(&builder, {value}));
  }
  ConcatInDim(&builder, operands, /*dimension=*/0);

  ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
282
// Concatenates three R3 arrays (3x1x2, 3x1x1, 3x1x1) along the last
// dimension; the operands interleave so the result counts 0..11.
XLA_TEST_F(ConcatTest, Concat_R3_3x1x2_3x1x1_3x1x1) {
  XlaBuilder builder(TestName());
  const Array3D<float> first({{{0, 1}}, {{4, 5}}, {{8, 9}}});  // 3x1x2
  const Array3D<float> second({{{2}}, {{6}}, {{10}}});         // 3x1x1
  const Array3D<float> third({{{3}}, {{7}}, {{11}}});          // 3x1x1

  auto first_op = ConstantR3FromArray3D(&builder, first);
  auto second_op = ConstantR3FromArray3D(&builder, second);
  auto third_op = ConstantR3FromArray3D(&builder, third);
  ConcatInDim(&builder, {first_op, second_op, third_op}, /*dimension=*/2);

  const Array3D<float> expected({
      {{0, 1, 2, 3}},
      {{4, 5, 6, 7}},
      {{8, 9, 10, 11}},
  });
  ComputeAndCompareR3<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
315
// Nesting concats as (a ++ b) ++ c must produce the flat sequence a, b, c.
XLA_TEST_F(ConcatTest, DoubleConcatLeftAssociative) {
  XlaBuilder builder(TestName());
  auto a = ConstantR1<float>(&builder, {42.0});
  auto b = ConstantR1<float>(&builder, {64.0});
  auto c = ConstantR1<float>(&builder, {256.0});
  auto ab = ConcatInDim(&builder, {a, b}, /*dimension=*/0);
  ConcatInDim(&builder, {ab, c}, /*dimension=*/0);

  const std::vector<float> want = {42, 64, 256};
  ComputeAndCompareR1<float>(&builder, want, {}, ErrorSpec(0.0001));
}
327
// Nesting concats as a ++ (b ++ c) must produce the flat sequence a, b, c.
XLA_TEST_F(ConcatTest, DoubleConcatRightAssociative) {
  XlaBuilder builder(TestName());
  auto a = ConstantR1<float>(&builder, {42.0});
  auto b = ConstantR1<float>(&builder, {64.0});
  auto c = ConstantR1<float>(&builder, {256.0});
  auto bc = ConcatInDim(&builder, {b, c}, /*dimension=*/0);
  ConcatInDim(&builder, {a, bc}, /*dimension=*/0);

  const std::vector<float> want = {42, 64, 256};
  ComputeAndCompareR1<float>(&builder, want, {}, ErrorSpec(0.0001));
}
339
// Stacks two 1x1024 rows into a 2x1024 matrix (dimension 0).
XLA_TEST_F(ConcatTest, Concat_1x1024_With_1x1024_InDim0) {
  constexpr int kWidth = 1024;
  Array2D<float> lhs(1, kWidth);
  Array2D<float> rhs(1, kWidth);
  Array2D<float> expected(2, kWidth);
  // Fill operands and expected result together: lhs holds 0..1023, rhs holds
  // 1024..2047, and they become rows 0 and 1 of the output.
  for (int col = 0; col < kWidth; ++col) {
    lhs(0, col) = expected(0, col) = col;
    rhs(0, col) = expected(1, col) = col + kWidth;
  }

  XlaBuilder builder(TestName());
  auto a = ConstantR2FromArray2D<float>(&builder, lhs);
  auto b = ConstantR2FromArray2D<float>(&builder, rhs);
  ConcatInDim(&builder, {a, b}, /*dimension=*/0);

  ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
360
// Joins two 1x1024 rows end to end into a single 1x2048 row (dimension 1).
XLA_TEST_F(ConcatTest, Concat_1x1024_With_1x1024_InDim1) {
  constexpr int kWidth = 1024;
  Array2D<float> lhs(1, kWidth);
  Array2D<float> rhs(1, kWidth);
  Array2D<float> expected(1, 2 * kWidth);
  // lhs holds 0..1023 and rhs holds 1024..2047, so the output is 0..2047.
  for (int col = 0; col < kWidth; ++col) {
    lhs(0, col) = expected(0, col) = col;
    rhs(0, col) = expected(0, col + kWidth) = col + kWidth;
  }

  XlaBuilder builder(TestName());
  auto a = ConstantR2FromArray2D<float>(&builder, lhs);
  auto b = ConstantR2FromArray2D<float>(&builder, rhs);
  ConcatInDim(&builder, {a, b}, /*dimension=*/1);

  ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
381
// Appends two columns to a 64x64 matrix. Every element encodes its
// (row, col) position in the 64x66 result as (row << 10) | col.
XLA_TEST_F(ConcatTest, Concat_64x64_With_64x2) {
  constexpr int kRows = 64;
  constexpr int kLhsCols = 64;
  constexpr int kRhsCols = 2;

  Array2D<float> expected(kRows, kLhsCols + kRhsCols);
  for (int row = 0; row < kRows; ++row) {
    for (int col = 0; col < kLhsCols + kRhsCols; ++col) {
      expected(row, col) = (row << 10) | col;
    }
  }

  // Slice the operands out of the expected result column-wise.
  Array2D<float> lhs(kRows, kLhsCols);
  Array2D<float> rhs(kRows, kRhsCols);
  for (int row = 0; row < kRows; ++row) {
    for (int col = 0; col < kLhsCols; ++col) {
      lhs(row, col) = expected(row, col);
    }
    for (int col = 0; col < kRhsCols; ++col) {
      rhs(row, col) = expected(row, kLhsCols + col);
    }
  }

  XlaBuilder builder(TestName());
  auto a = ConstantR2FromArray2D<float>(&builder, lhs);
  auto b = ConstantR2FromArray2D<float>(&builder, rhs);
  ConcatInDim(&builder, {a, b}, /*dimension=*/1);

  ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
407
408 // Show that we can't concatenate with an opaques.
XLA_TEST_F(ConcatTest,CannotConcatOpaques)409 XLA_TEST_F(ConcatTest, CannotConcatOpaques) {
410 XlaBuilder builder(TestName());
411 auto opaque_shape = ShapeUtil::MakeOpaqueShape();
412 auto r1f32 = xla::ShapeUtil::MakeShape(xla::F32, {1});
413 auto x = Parameter(&builder, 0, r1f32, "x");
414 auto y = Parameter(&builder, 1, opaque_shape, "y");
415 ConcatInDim(&builder, {x, y}, 0);
416 StatusOr<XlaComputation> computation_status = builder.Build();
417 ASSERT_FALSE(computation_status.ok());
418 EXPECT_THAT(
419 computation_status.status().ToString(),
420 HasSubstr("Expected array argument for operand of concatenation"));
421 }
422
423 // Show that we can't concatenate with tokens.
XLA_TEST_F(ConcatTest,CannotConcatTokens)424 XLA_TEST_F(ConcatTest, CannotConcatTokens) {
425 XlaBuilder builder(TestName());
426 auto token_shape = ShapeUtil::MakeTokenShape();
427 auto r1f32 = xla::ShapeUtil::MakeShape(xla::F32, {1});
428 auto x = Parameter(&builder, 0, r1f32, "x");
429 auto y = Parameter(&builder, 1, token_shape, "y");
430 ConcatInDim(&builder, {x, y}, 0);
431 StatusOr<XlaComputation> computation_status = builder.Build();
432 ASSERT_FALSE(computation_status.ok());
433 EXPECT_THAT(
434 computation_status.status().ToString(),
435 HasSubstr("Expected array argument for operand of concatenation"));
436 }
437
// Concatenates three single-element PRED vectors into one PRED vector.
XLA_TEST_F(ConcatTest, ConcatSeveralBoxedPredicates) {
  XlaBuilder builder(TestName());
  auto first = ConstantR1<bool>(&builder, {true});
  auto second = ConstantR1<bool>(&builder, {false});
  auto third = ConstantR1<bool>(&builder, {true});
  ConcatInDim(&builder, {first, second, third}, /*dimension=*/0);

  bool want[] = {true, false, true};
  ComputeAndCompareR1<bool>(&builder, want, {});
}
448
// Concatenates four S32 vectors of lengths 1..4 that together hold 1..10.
XLA_TEST_F(ConcatTest, ConcatSeveralR1S32s) {
  XlaBuilder builder(TestName());
  const std::vector<XlaOp> pieces = {
      ConstantR1<int32>(&builder, {1}),
      ConstantR1<int32>(&builder, {2, 3}),
      ConstantR1<int32>(&builder, {4, 5, 6}),
      ConstantR1<int32>(&builder, {7, 8, 9, 10}),
  };
  ConcatInDim(&builder, pieces, /*dimension=*/0);

  std::vector<int32> want(10);
  std::iota(want.begin(), want.end(), 1);
  ComputeAndCompareR1<int32>(&builder, want, {});
}
461
// Concatenates R3 parameters along the minor dimension when that dimension has
// very different sizes in the two operands (1 vs 256).
XLA_TEST_F(ConcatTest, ConcatR3WeirdDims) {
  XlaBuilder builder(TestName());

  Array3D<float> arr0(9, 17, 1);
  arr0.Fill(1);

  Array3D<float> arr1(9, 17, 256);
  arr1.Fill(2);

  // Build the expected result by appending each operand's dimension-2 slice in
  // order. Iterate over pointers: a braced list of Array3D values would
  // deep-copy both arrays into a temporary initializer_list on every (i, j)
  // iteration.
  Array3D<float> expected(9, 17, arr0.n3() + arr1.n3());
  for (int64 i = 0; i < expected.n1(); ++i) {
    for (int64 j = 0; j < expected.n2(); ++j) {
      int64 kk = 0;
      for (const Array3D<float>* arr : {&arr0, &arr1}) {
        for (int64 k = 0; k < arr->n3(); ++k, ++kk) {
          expected(i, j, kk) = (*arr)(i, j, k);
        }
      }
    }
  }

  XlaOp h0;
  auto p0 = CreateR3Parameter<float>(arr0, /*parameter_number=*/0, "p0",
                                     &builder, &h0);
  XlaOp h1;
  auto p1 = CreateR3Parameter<float>(arr1, /*parameter_number=*/1, "p1",
                                     &builder, &h1);

  ConcatInDim(&builder, {h0, h1}, 2);

  ComputeAndCompareR3<float>(&builder, expected, {p0.get(), p1.get()});
}
494
// Stresses a deep chain of concatenates: a one-element parameter is doubled
// 17 times by concatenating each intermediate with itself, giving 2^17
// elements that all equal the parameter's value.
XLA_TEST_F(ConcatTest, ConcatDeeplyNested) {
  XlaBuilder builder(TestName());
  auto a_literal = LiteralUtil::CreateR1<float>({256.0});
  XlaOp current = Parameter(&builder, 0, a_literal.shape(), "x");

  constexpr int kDoublings = 17;
  for (int level = 0; level < kDoublings; ++level) {
    current = ConcatInDim(&builder, {current, current}, 0);
  }

  std::vector<float> expected(1 << kDoublings, 256.0);
  auto a_data = client_->TransferToServer(a_literal).ConsumeValueOrDie();
  ComputeAndCompareR1<float>(&builder, expected, {a_data.get()});
}
520
521 // Describes a binary rank-2 concatenation test.
522 struct R2BinarySpec {
523 int64 lhs_dim0;
524 int64 lhs_dim1;
525 int64 rhs_dim0;
526 int64 rhs_dim1;
527 int64 concat_dimension;
528 };
529
530 // TEST_P harness for binary rank-2 concatenation.
531 class ConcatR2BinaryTest : public ClientLibraryTestBase,
532 public ::testing::WithParamInterface<R2BinarySpec> {
533 };
534
// Concatenates two uniquely filled S32 matrices and compares against
// ReferenceUtil::Concat2D for the same inputs.
TEST_P(ConcatR2BinaryTest, DoIt) {
  const R2BinarySpec& spec = GetParam();

  Array2D<int32> lhs(spec.lhs_dim0, spec.lhs_dim1);
  Array2D<int32> rhs(spec.rhs_dim0, spec.rhs_dim1);
  lhs.FillUnique();
  rhs.FillUnique(1000);  // offset so rhs values are distinguishable from lhs

  XlaBuilder builder(TestName());
  auto lhs_op = ConstantR2FromArray2D<int32>(&builder, lhs);
  auto rhs_op = ConstantR2FromArray2D<int32>(&builder, rhs);
  ConcatInDim(&builder, {lhs_op, rhs_op}, spec.concat_dimension);

  const auto reference =
      ReferenceUtil::Concat2D(lhs, rhs, spec.concat_dimension);
  ComputeAndCompareR2<int32>(&builder, *reference, {});
}
551
552 // Regression test for b/31944287. x*y is used (at the same index) by all
553 // operands of the concat. We should emit x*y in three incoming basic blocks of
554 // the concat because these basic blocks are not control-equivalent.
555 //
556 // x*y
557 // / | \
558 // add1 add2 add3
559 // \ | /
560 // concat
XLA_TEST_F(ConcatTest,ConcatOperandsOfSameOperand)561 XLA_TEST_F(ConcatTest, ConcatOperandsOfSameOperand) {
562 auto f32_scalar = ShapeUtil::MakeShape(xla::F32, {});
563 auto x_literal = LiteralUtil::CreateR0<float>(2.f);
564 auto y_literal = LiteralUtil::CreateR0<float>(3.f);
565 auto x_data = client_->TransferToServer(x_literal).ConsumeValueOrDie();
566 auto y_data = client_->TransferToServer(y_literal).ConsumeValueOrDie();
567
568 XlaBuilder builder(TestName());
569 auto x = Parameter(&builder, 0, f32_scalar, "x");
570 auto y = Parameter(&builder, 1, f32_scalar, "y");
571 auto mul = Mul(x, y);
572 auto add1 = Add(mul, ConstantR1<float>(&builder, {1.f, 2.f}));
573 auto add2 = Add(mul, ConstantR1<float>(&builder, {3.f, 4.f}));
574 auto add3 = Add(mul, ConstantR1<float>(&builder, {5.f, 6.f}));
575 ConcatInDim(&builder, {add1, add2, add3}, /*dimension=*/0);
576
577 ComputeAndCompareR1<float>(&builder, {7., 8., 9., 10., 11., 12.},
578 {x_data.get(), y_data.get()}, ErrorSpec(1e-4));
579 }
580
581 // Test that the HLO optimization to replace a concat of a bradcasted scalar
582 // produces the correct result in rank 1.
XLA_TEST_F(ConcatTest,ConcatBroadcastArgument)583 XLA_TEST_F(ConcatTest, ConcatBroadcastArgument) {
584 auto f32_scalar = ShapeUtil::MakeShape(xla::F32, {});
585 auto x_literal = LiteralUtil::CreateR1<float>({2.0f, 3.0f, 5.0f, 6.0f});
586 auto y_literal = LiteralUtil::CreateR0<float>(1.5f);
587 auto z_literal = LiteralUtil::CreateR0<float>(5.5f);
588 auto x_data = client_->TransferToServer(x_literal).ConsumeValueOrDie();
589 auto y_data = client_->TransferToServer(y_literal).ConsumeValueOrDie();
590 auto z_data = client_->TransferToServer(z_literal).ConsumeValueOrDie();
591
592 XlaBuilder builder(TestName());
593 auto x = Parameter(&builder, 0, x_literal.shape(), "x");
594 auto y = Parameter(&builder, 1, f32_scalar, "y");
595 auto z = Parameter(&builder, 2, f32_scalar, "z");
596 auto bcast = Broadcast(y, {5});
597 auto bcast2 = Broadcast(z, {3});
598 auto concat = ConcatInDim(&builder, {bcast, x}, /*dimension=*/0);
599 ConcatInDim(&builder, {concat, bcast2}, /*dimension=*/0);
600
601 ComputeAndCompareR1<float>(
602 &builder,
603 {1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 2.0f, 3.0f, 5.0f, 6.0f, 5.5f, 5.5f, 5.5f},
604 {x_data.get(), y_data.get(), z_data.get()}, ErrorSpec(1e-4));
605 }
606
607 // Test that the HLO optimization to replace a concat of a bradcasted scalar
608 // produces the correct result in rank 3 with both high and low padding in
609 // different dimensions.
XLA_TEST_F(ConcatTest,ConcatBroadcastArgumentR3)610 XLA_TEST_F(ConcatTest, ConcatBroadcastArgumentR3) {
611 auto f32_scalar = ShapeUtil::MakeShape(xla::F32, {});
612 Array3D<float> x3d(3, 5, 7, 3.14f);
613 auto x_literal = LiteralUtil::CreateR3FromArray3D<float>(x3d);
614 auto y_literal = LiteralUtil::CreateR0<float>(1.5f);
615 auto z_literal = LiteralUtil::CreateR0<float>(5.5f);
616 auto x_data = client_->TransferToServer(x_literal).ConsumeValueOrDie();
617 auto y_data = client_->TransferToServer(y_literal).ConsumeValueOrDie();
618 auto z_data = client_->TransferToServer(z_literal).ConsumeValueOrDie();
619
620 XlaBuilder builder(TestName());
621 auto x = Parameter(&builder, 0, x_literal.shape(), "x");
622 auto y = Parameter(&builder, 1, f32_scalar, "y");
623 auto z = Parameter(&builder, 2, f32_scalar, "y");
624 auto y_bcast = Broadcast(y, {1, 5, 7});
625 auto z_bcast = Broadcast(z, {4, 1, 7});
626 auto concat = ConcatInDim(&builder, {y_bcast, x}, /*dimension=*/0);
627 ConcatInDim(&builder, {concat, z_bcast}, /*dimension=*/1);
628 Array3D<float> y_bcast3d(1, 5, 7, 1.5f);
629 Array3D<float> z_bcast3d(4, 1, 7, 5.5f);
630 auto concat0 = ReferenceUtil::Concat3D(y_bcast3d, x3d, 0);
631 auto concat1 = ReferenceUtil::Concat3D(*concat0, z_bcast3d, 1);
632
633 ComputeAndCompareR3<float>(&builder, *concat1,
634 {x_data.get(), y_data.get(), z_data.get()},
635 ErrorSpec(1e-4));
636 }
637
638 INSTANTIATE_TEST_CASE_P(ConcatR2BinaryTestInstantiation, ConcatR2BinaryTest,
639 ::testing::Values(R2BinarySpec{1, 1, 1, 1, 0},
640 R2BinarySpec{1, 1, 1, 1, 1},
641 R2BinarySpec{4, 3, 4, 3, 0},
642 R2BinarySpec{4, 3, 4, 3, 1},
643 R2BinarySpec{7, 128, 1, 128, 0},
644 R2BinarySpec{8, 127, 8, 1, 1}));
645
646 } // namespace
647 } // namespace xla
648