1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include <memory>
#include <numeric>
#include <vector>

#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/core/platform/test.h"
32 
33 namespace xla {
34 namespace {
35 
36 using ConcatTest = ClientLibraryTestBase;
37 using ::testing::HasSubstr;
38 
39 // Concatenate expects at least one argument.
XLA_TEST_F(ConcatTest, Concat_Nothing) {
  XlaBuilder builder(TestName());
  // An empty operand list must be rejected when the computation is built.
  ConcatInDim(&builder, {}, /*dimension=*/0);
  const StatusOr<XlaComputation> built = builder.Build();
  ASSERT_FALSE(built.ok());
  EXPECT_THAT(built.status().ToString(),
              HasSubstr("Concatenate expects at least one argument"));
}
48 
49 // Concatenate with one argument works.
XLA_TEST_F(ConcatTest, Concat_R1_With_Nothing) {
  XlaBuilder builder(TestName());
  auto only = ConstantR1<float>(&builder, {42.0, 64.0});
  // A single-operand concatenation should reproduce its operand unchanged.
  ConcatInDim(&builder, {only}, /*dimension=*/0);

  const std::vector<float> want = {42, 64};
  ComputeAndCompareR1<float>(&builder, want, {}, ErrorSpec(0.0001));
}
58 
// Single-operand concatenation of a zero-length vector.
XLA_TEST_F(ConcatTest, Concat_R1_L0_With_Nothing) {
  XlaBuilder builder(TestName());
  auto empty = ConstantR1<float>(&builder, {});
  ConcatInDim(&builder, {empty}, /*dimension=*/0);

  const std::vector<float> want;  // Result is expected to be empty as well.
  ComputeAndCompareR1<float>(&builder, want, {}, ErrorSpec(0.0001));
}
67 
68 // Show that we can't concatenate R0 with R0 because we can't name the dimension
69 // to concatenate on.
XLA_TEST_F(ConcatTest, CannotConcatR0WithR0) {
  XlaBuilder builder(TestName());
  auto lhs = ConstantR0<float>(&builder, 42.0);
  auto rhs = ConstantR0<float>(&builder, 64.0);
  // Scalars have no dimension 0, so the requested axis is out of bounds.
  ConcatInDim(&builder, {lhs, rhs}, /*dimension=*/0);
  const StatusOr<XlaComputation> built = builder.Build();
  ASSERT_FALSE(built.ok());
  EXPECT_THAT(built.status().ToString(), HasSubstr("out of bounds: 0"));
}
80 
// Two zero-length vectors concatenate to a zero-length vector.
XLA_TEST_F(ConcatTest, Concat_R1_L0_With_R1_L0) {
  XlaBuilder builder(TestName());
  auto empty_lhs = ConstantR1<float>(&builder, {});
  auto empty_rhs = ConstantR1<float>(&builder, {});
  ConcatInDim(&builder, {empty_lhs, empty_rhs}, /*dimension=*/0);

  const std::vector<float> want;
  ComputeAndCompareR1<float>(&builder, want, {}, ErrorSpec(0.0001));
}
90 
// An empty left operand contributes nothing; the result is the right operand.
XLA_TEST_F(ConcatTest, Concat_R1_L0_With_R1_L1) {
  XlaBuilder builder(TestName());
  auto empty_lhs = ConstantR1<float>(&builder, {});
  auto single_rhs = ConstantR1<float>(&builder, {256.0});
  ConcatInDim(&builder, {empty_lhs, single_rhs}, /*dimension=*/0);

  const std::vector<float> want = {256.0f};
  ComputeAndCompareR1<float>(&builder, want, {}, ErrorSpec(0.0001));
}
100 
// An empty right operand contributes nothing; the result is the left operand.
XLA_TEST_F(ConcatTest, Concat_R1_L2_With_R1_L0) {
  XlaBuilder builder(TestName());
  auto pair_lhs = ConstantR1<float>(&builder, {42.0, 64.0});
  auto empty_rhs = ConstantR1<float>(&builder, {});
  ConcatInDim(&builder, {pair_lhs, empty_rhs}, /*dimension=*/0);

  const std::vector<float> want = {42.0f, 64.0f};
  ComputeAndCompareR1<float>(&builder, want, {}, ErrorSpec(0.0001));
}
110 
// Concatenating lengths 2 and 1 appends the right element after the left ones.
XLA_TEST_F(ConcatTest, Concat_R1_L2_With_R1_L1) {
  XlaBuilder builder(TestName());
  auto pair_lhs = ConstantR1<float>(&builder, {42.0, 64.0});
  auto single_rhs = ConstantR1<float>(&builder, {256.0});
  ConcatInDim(&builder, {pair_lhs, single_rhs}, /*dimension=*/0);

  const std::vector<float> want = {42.0f, 64.0f, 256.0f};
  ComputeAndCompareR1<float>(&builder, want, {}, ErrorSpec(0.0001));
}
120 
XLA_TEST_F(ConcatTest, Concat_R1_L253_With_R1_L7) {
  constexpr int kLhsSize = 253;
  constexpr int kRhsSize = 7;

  // The operands hold 1..253 and 254..260, so their concatenation is 1..260.
  std::vector<float> lhs(kLhsSize);
  std::vector<float> rhs(kRhsSize);
  std::vector<float> expected(kLhsSize + kRhsSize);
  for (int pos = 0; pos < kLhsSize + kRhsSize; ++pos) {
    const float value = pos + 1;
    expected[pos] = value;
    if (pos < kLhsSize) {
      lhs[pos] = value;
    } else {
      rhs[pos - kLhsSize] = value;
    }
  }

  XlaBuilder builder(TestName());
  auto lhs_op = ConstantR1<float>(&builder, lhs);
  auto rhs_op = ConstantR1<float>(&builder, rhs);
  ConcatInDim(&builder, {lhs_op, rhs_op}, /*dimension=*/0);

  ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
139 
// Concatenating two 0x0 matrices along either axis yields a 0x0 matrix.
XLA_TEST_F(ConcatTest, Concat_0x0_With_0x0) {
  for (const int dim : {0, 1}) {
    XlaBuilder builder(TestName());
    const Array2D<float> empty(0, 0);
    auto lhs = ConstantR2FromArray2D(&builder, empty);
    auto rhs = ConstantR2FromArray2D(&builder, empty);
    ConcatInDim(&builder, {lhs, rhs}, dim);

    ComputeAndCompareR2<float>(&builder, empty, {}, ErrorSpec(0.0001));
  }
}
151 
// Stacking two 1x1 matrices along dimension 0 gives a 2x1 column.
XLA_TEST_F(ConcatTest, Concat_1x1_With_1x1_InDim0) {
  XlaBuilder builder(TestName());
  auto top = CreatePatternedMatrix(1, 1);
  auto bottom = CreatePatternedMatrix(1, 1, /*offset=*/64.0);
  auto top_op = ConstantR2FromArray2D(&builder, *top);
  auto bottom_op = ConstantR2FromArray2D(&builder, *bottom);
  ConcatInDim(&builder, {top_op, bottom_op}, /*dimension=*/0);

  const Array2D<float> expected({{0}, {64}});
  ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
166 
// Joining two 1x1 matrices along dimension 1 gives a 1x2 row.
XLA_TEST_F(ConcatTest, Concat_1x1_With_1x1_InDim1) {
  XlaBuilder builder(TestName());
  auto left = CreatePatternedMatrix(1, 1);
  auto right = CreatePatternedMatrix(1, 1, /*offset=*/64.0);
  auto left_op = ConstantR2FromArray2D(&builder, *left);
  auto right_op = ConstantR2FromArray2D(&builder, *right);
  ConcatInDim(&builder, {left_op, right_op}, /*dimension=*/1);

  const Array2D<float> expected({{0, 64}});
  ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
180 
// A zero-width left operand leaves the right operand unchanged.
XLA_TEST_F(ConcatTest, Concat2x0With2x5) {
  XlaBuilder builder(TestName());
  auto patterned = CreatePatternedMatrix(2, 5, /*offset=*/64.0);
  auto empty_op = ConstantR2FromArray2D(&builder, Array2D<float>(2, 0));
  auto patterned_op = ConstantR2FromArray2D(&builder, *patterned);
  ConcatInDim(&builder, {empty_op, patterned_op}, /*dimension=*/1);

  ComputeAndCompareR2<float>(&builder, *patterned, {}, ErrorSpec(0.0001));
}
190 
// Joining along dimension 1 appends each row of the right operand to the
// matching row of the left operand, producing a 2x8 result.
XLA_TEST_F(ConcatTest, Concat2x3With2x5) {
  XlaBuilder builder(TestName());
  auto left = CreatePatternedMatrix(2, 3);
  auto right = CreatePatternedMatrix(2, 5, /*offset=*/64.0);
  auto left_op = ConstantR2FromArray2D(&builder, *left);
  auto right_op = ConstantR2FromArray2D(&builder, *right);
  ConcatInDim(&builder, {left_op, right_op}, /*dimension=*/1);

  const Array2D<float> expected({
      {0, 1, 2, 64, 65, 66, 67, 68},
      {1000, 1001, 1002, 1064, 1065, 1066, 1067, 1068},
  });
  ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
205 
// A zero-height bottom operand leaves the top operand unchanged.
XLA_TEST_F(ConcatTest, Concat3x2With0x2) {
  XlaBuilder builder(TestName());
  auto patterned = CreatePatternedMatrix(3, 2);
  auto patterned_op = ConstantR2FromArray2D(&builder, *patterned);
  auto empty_op = ConstantR2FromArray2D(&builder, Array2D<float>(0, 2));
  ConcatInDim(&builder, {patterned_op, empty_op}, /*dimension=*/0);

  ComputeAndCompareR2<float>(&builder, *patterned, {}, ErrorSpec(0.0001));
}
215 
// Stacking a 3x2 on top of a 5x2 along dimension 0 gives an 8x2 result whose
// first three rows come from the top operand.
XLA_TEST_F(ConcatTest, Concat3x2With5x2) {
  XlaBuilder builder(TestName());
  auto top = CreatePatternedMatrix(3, 2);
  auto bottom = CreatePatternedMatrix(5, 2, /*offset=*/64.0);
  auto top_op = ConstantR2FromArray2D(&builder, *top);
  auto bottom_op = ConstantR2FromArray2D(&builder, *bottom);
  ConcatInDim(&builder, {top_op, bottom_op}, /*dimension=*/0);

  const Array2D<float> expected({
      // From the top operand.
      {0, 1},
      {1000, 1001},
      {2000, 2001},
      // From the bottom operand.
      {64, 65},
      {1064, 1065},
      {2064, 2065},
      {3064, 3065},
      {4064, 4065},
  });
  ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
236 
// Rank-3 operands with a zero-sized middle dimension still concatenate along
// dimension 2; the result shape is 3x0x3.
XLA_TEST_F(ConcatTest, Concat_R3_3x0x2_3x0x1) {
  XlaBuilder builder(TestName());
  auto lhs = ConstantR3FromArray3D(&builder, Array3D<float>(3, 0, 2));
  auto rhs = ConstantR3FromArray3D(&builder, Array3D<float>(3, 0, 1));
  ConcatInDim(&builder, {lhs, rhs}, /*dimension=*/2);

  const Array3D<float> expected(3, 0, 3);
  ComputeAndCompareR3<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
245 
// Joining 3x1x2 and 3x1x1 along the innermost dimension appends one element to
// each innermost row.
XLA_TEST_F(ConcatTest, Concat_R3_3x1x2_3x1x1) {
  XlaBuilder builder(TestName());
  const Array3D<float> wide({{{0, 1}}, {{2, 3}}, {{4, 5}}});  // 3x1x2
  const Array3D<float> narrow({{{6}}, {{7}}, {{8}}});          // 3x1x1
  auto wide_op = ConstantR3FromArray3D(&builder, wide);
  auto narrow_op = ConstantR3FromArray3D(&builder, narrow);
  ConcatInDim(&builder, {wide_op, narrow_op}, /*dimension=*/2);

  const Array3D<float> expected({{{0, 1, 6}}, {{2, 3, 7}}, {{4, 5, 8}}});
  ComputeAndCompareR3<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
271 
// Three single-element vectors concatenate in operand order.
XLA_TEST_F(ConcatTest, Concat_R1_1x1_1x1_1x1) {
  XlaBuilder builder(TestName());
  auto first = ConstantR1<float>(&builder, {42.0});
  auto second = ConstantR1<float>(&builder, {64.0});
  auto third = ConstantR1<float>(&builder, {256.0});
  ConcatInDim(&builder, {first, second, third}, /*dimension=*/0);

  const std::vector<float> want = {42.0f, 64.0f, 256.0f};
  ComputeAndCompareR1<float>(&builder, want, {}, ErrorSpec(0.0001));
}
282 
// Three rank-3 operands joined along dimension 2. Their values are chosen so
// that the result counts 0..11 in row-major order.
XLA_TEST_F(ConcatTest, Concat_R3_3x1x2_3x1x1_3x1x1) {
  XlaBuilder builder(TestName());
  const Array3D<float> head({{{0, 1}}, {{4, 5}}, {{8, 9}}});  // 3x1x2
  const Array3D<float> mid({{{2}}, {{6}}, {{10}}});           // 3x1x1
  const Array3D<float> tail({{{3}}, {{7}}, {{11}}});          // 3x1x1
  auto head_op = ConstantR3FromArray3D(&builder, head);
  auto mid_op = ConstantR3FromArray3D(&builder, mid);
  auto tail_op = ConstantR3FromArray3D(&builder, tail);
  ConcatInDim(&builder, {head_op, mid_op, tail_op}, /*dimension=*/2);

  const Array3D<float> expected({
      {{0, 1, 2, 3}},
      {{4, 5, 6, 7}},
      {{8, 9, 10, 11}},
  });
  ComputeAndCompareR3<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
315 
// Computes (a ++ b) ++ c via two nested concatenations.
XLA_TEST_F(ConcatTest, DoubleConcatLeftAssociative) {
  XlaBuilder builder(TestName());
  auto a = ConstantR1<float>(&builder, {42.0});
  auto b = ConstantR1<float>(&builder, {64.0});
  auto c = ConstantR1<float>(&builder, {256.0});
  auto a_then_b = ConcatInDim(&builder, {a, b}, /*dimension=*/0);
  ConcatInDim(&builder, {a_then_b, c}, /*dimension=*/0);

  const std::vector<float> want = {42.0f, 64.0f, 256.0f};
  ComputeAndCompareR1<float>(&builder, want, {}, ErrorSpec(0.0001));
}
327 
// Computes a ++ (b ++ c) via two nested concatenations.
XLA_TEST_F(ConcatTest, DoubleConcatRightAssociative) {
  XlaBuilder builder(TestName());
  auto a = ConstantR1<float>(&builder, {42.0});
  auto b = ConstantR1<float>(&builder, {64.0});
  auto c = ConstantR1<float>(&builder, {256.0});
  auto b_then_c = ConcatInDim(&builder, {b, c}, /*dimension=*/0);
  ConcatInDim(&builder, {a, b_then_c}, /*dimension=*/0);

  const std::vector<float> want = {42.0f, 64.0f, 256.0f};
  ComputeAndCompareR1<float>(&builder, want, {}, ErrorSpec(0.0001));
}
339 
// Stacks two 1x1024 rows (holding 0..1023 and 1024..2047) into a 2x1024 matrix.
XLA_TEST_F(ConcatTest, Concat_1x1024_With_1x1024_InDim0) {
  constexpr int kWidth = 1024;
  Array2D<float> lhs(1, kWidth);
  Array2D<float> rhs(1, kWidth);
  Array2D<float> expected(2, kWidth);
  for (int col = 0; col < kWidth; ++col) {
    lhs(0, col) = col;
    rhs(0, col) = col + kWidth;
    // Row 0 of the result comes from lhs, row 1 from rhs.
    expected(0, col) = lhs(0, col);
    expected(1, col) = rhs(0, col);
  }

  XlaBuilder builder(TestName());
  auto lhs_op = ConstantR2FromArray2D<float>(&builder, lhs);
  auto rhs_op = ConstantR2FromArray2D<float>(&builder, rhs);
  ConcatInDim(&builder, {lhs_op, rhs_op}, /*dimension=*/0);

  ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
360 
// Joins two 1x1024 rows end to end into a single 1x2048 row holding 0..2047.
XLA_TEST_F(ConcatTest, Concat_1x1024_With_1x1024_InDim1) {
  constexpr int kWidth = 1024;
  Array2D<float> lhs(1, kWidth);
  Array2D<float> rhs(1, kWidth);
  Array2D<float> expected(1, 2 * kWidth);
  for (int col = 0; col < kWidth; ++col) {
    lhs(0, col) = col;
    rhs(0, col) = col + kWidth;
    // lhs fills the first half of the row, rhs the second half.
    expected(0, col) = lhs(0, col);
    expected(0, col + kWidth) = rhs(0, col);
  }

  XlaBuilder builder(TestName());
  auto lhs_op = ConstantR2FromArray2D<float>(&builder, lhs);
  auto rhs_op = ConstantR2FromArray2D<float>(&builder, rhs);
  ConcatInDim(&builder, {lhs_op, rhs_op}, /*dimension=*/1);

  ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
381 
// Each element stores its (row, column) position in the 64x66 result encoded
// as (row << 10) | column, so a misplaced element is easy to spot.
XLA_TEST_F(ConcatTest, Concat_64x64_With_64x2) {
  constexpr int kRows = 64;
  constexpr int kLhsCols = 64;
  constexpr int kRhsCols = 2;
  auto encode = [](int row, int col) -> float { return (row << 10) | col; };

  Array2D<float> lhs(kRows, kLhsCols);
  Array2D<float> rhs(kRows, kRhsCols);
  Array2D<float> expected(kRows, kLhsCols + kRhsCols);
  for (int row = 0; row < kRows; ++row) {
    for (int col = 0; col < kLhsCols + kRhsCols; ++col) {
      const float value = encode(row, col);
      expected(row, col) = value;
      if (col < kLhsCols) {
        lhs(row, col) = value;
      } else {
        rhs(row, col - kLhsCols) = value;
      }
    }
  }

  XlaBuilder builder(TestName());
  auto lhs_op = ConstantR2FromArray2D<float>(&builder, lhs);
  auto rhs_op = ConstantR2FromArray2D<float>(&builder, rhs);
  ConcatInDim(&builder, {lhs_op, rhs_op}, /*dimension=*/1);

  ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
407 
// Show that we can't concatenate an array with an opaque-shaped operand.
XLA_TEST_F(ConcatTest, CannotConcatOpaques) {
  XlaBuilder builder(TestName());
  auto opaque_shape = ShapeUtil::MakeOpaqueShape();
  auto vector_shape = ShapeUtil::MakeShape(F32, {1});
  auto array_param = Parameter(&builder, 0, vector_shape, "x");
  auto opaque_param = Parameter(&builder, 1, opaque_shape, "y");
  // The opaque operand is not an array, so shape inference must reject it.
  ConcatInDim(&builder, {array_param, opaque_param}, /*dimension=*/0);
  const StatusOr<XlaComputation> built = builder.Build();
  ASSERT_FALSE(built.ok());
  EXPECT_THAT(
      built.status().ToString(),
      HasSubstr("Expected array argument for operand of concatenation"));
}
422 
423 // Show that we can't concatenate with tokens.
XLA_TEST_F(ConcatTest, CannotConcatTokens) {
  XlaBuilder builder(TestName());
  auto token_shape = ShapeUtil::MakeTokenShape();
  auto vector_shape = ShapeUtil::MakeShape(F32, {1});
  auto array_param = Parameter(&builder, 0, vector_shape, "x");
  auto token_param = Parameter(&builder, 1, token_shape, "y");
  // The token operand is not an array, so shape inference must reject it.
  ConcatInDim(&builder, {array_param, token_param}, /*dimension=*/0);
  const StatusOr<XlaComputation> built = builder.Build();
  ASSERT_FALSE(built.ok());
  EXPECT_THAT(
      built.status().ToString(),
      HasSubstr("Expected array argument for operand of concatenation"));
}
437 
// Concatenates three single-element PRED vectors.
XLA_TEST_F(ConcatTest, ConcatSeveralBoxedPredicates) {
  XlaBuilder builder(TestName());
  auto first = ConstantR1<bool>(&builder, {true});
  auto second = ConstantR1<bool>(&builder, {false});
  auto third = ConstantR1<bool>(&builder, {true});
  ConcatInDim(&builder, {first, second, third}, /*dimension=*/0);

  // A plain array rather than std::vector<bool>, which is a packed-bit proxy
  // container without contiguous bool storage.
  const bool expected[] = {true, false, true};
  ComputeAndCompareR1<bool>(&builder, expected, {});
}
448 
// Concatenates four S32 vectors of lengths 1..4 whose contents continue one
// another, so the result is the sequence 1..10.
XLA_TEST_F(ConcatTest, ConcatSeveralR1S32s) {
  XlaBuilder builder(TestName());
  auto a0 = ConstantR1<int32>(&builder, {1});
  auto a1 = ConstantR1<int32>(&builder, {2, 3});
  auto a2 = ConstantR1<int32>(&builder, {4, 5, 6});
  auto a3 = ConstantR1<int32>(&builder, {7, 8, 9, 10});
  ConcatInDim(&builder, {a0, a1, a2, a3}, /*dimension=*/0);

  // std::iota is declared in <numeric>, which this file now includes directly
  // instead of relying on a transitive include.
  std::vector<int32> expected(10);
  std::iota(expected.begin(), expected.end(), 1);
  ComputeAndCompareR1<int32>(&builder, expected, {});
}
461 
// Concatenates a 9x17x1 array filled with 1 and a 9x17x256 array filled with 2
// along dimension 2. Both operands are fed in as parameters.
XLA_TEST_F(ConcatTest, ConcatR3WeirdDims) {
  XlaBuilder builder(TestName());

  Array3D<float> arr0(9, 17, 1);
  arr0.Fill(1);

  Array3D<float> arr1(9, 17, 256);
  arr1.Fill(2);

  // Host-side reference: for every (i, j), the dimension-2 entries of arr0 are
  // followed by those of arr1. The loop iterates over pointers because a
  // braced list of Array3D values would copy both arrays into a temporary
  // std::initializer_list on every (i, j) iteration.
  Array3D<float> expected(9, 17, arr0.n3() + arr1.n3());
  for (int64 i = 0; i < expected.n1(); ++i) {
    for (int64 j = 0; j < expected.n2(); ++j) {
      int64 kk = 0;
      for (const Array3D<float>* arr : {&arr0, &arr1}) {
        for (int64 k = 0; k < arr->n3(); ++k, ++kk) {
          expected(i, j, kk) = (*arr)(i, j, k);
        }
      }
    }
  }

  XlaOp h0;
  auto p0 = CreateR3Parameter<float>(arr0, /*parameter_number=*/0, "p0",
                                     &builder, &h0);
  XlaOp h1;
  auto p1 = CreateR3Parameter<float>(arr1, /*parameter_number=*/1, "p1",
                                     &builder, &h1);

  ConcatInDim(&builder, {h0, h1}, 2);

  ComputeAndCompareR3<float>(&builder, expected, {p0.get(), p1.get()});
}
494 
// Builds a chain of 17 concatenations, each joining the previous result with
// itself. Starting from one element, this yields 2^17 = 131072 copies of it.
XLA_TEST_F(ConcatTest, ConcatDeeplyNested) {
  constexpr int kDoublings = 17;
  XlaBuilder builder(TestName());
  auto a_literal = LiteralUtil::CreateR1<float>({256.0});
  auto node = Parameter(&builder, 0, a_literal.shape(), "x");
  for (int level = 0; level < kDoublings; ++level) {
    node = ConcatInDim(&builder, {node, node}, 0);
  }
  std::vector<float> expected(1 << kDoublings, 256.0);
  auto a_data = client_->TransferToServer(a_literal).ConsumeValueOrDie();
  ComputeAndCompareR1<float>(&builder, expected, {a_data.get()});
}
520 
521 // Describes a binary rank-2 concatenation test.
522 struct R2BinarySpec {
523   int64 lhs_dim0;
524   int64 lhs_dim1;
525   int64 rhs_dim0;
526   int64 rhs_dim1;
527   int64 concat_dimension;
528 };
529 
530 // TEST_P harness for binary rank-2 concatenation.
531 class ConcatR2BinaryTest : public ClientLibraryTestBase,
532                            public ::testing::WithParamInterface<R2BinarySpec> {
533 };
534 
// Concatenates two uniquely-filled S32 matrices shaped by the spec and checks
// the result against ReferenceUtil::Concat2D.
TEST_P(ConcatR2BinaryTest, DoIt) {
  const R2BinarySpec& spec = GetParam();
  Array2D<int32> lhs(spec.lhs_dim0, spec.lhs_dim1);
  Array2D<int32> rhs(spec.rhs_dim0, spec.rhs_dim1);
  lhs.FillUnique();
  rhs.FillUnique(1000);

  XlaBuilder builder(TestName());
  // Elements of a braced-init-list are evaluated left to right, so the lhs
  // constant is still created before the rhs constant.
  ConcatInDim(&builder,
              {ConstantR2FromArray2D<int32>(&builder, lhs),
               ConstantR2FromArray2D<int32>(&builder, rhs)},
              spec.concat_dimension);

  const std::unique_ptr<Array2D<int32>> reference =
      ReferenceUtil::Concat2D(lhs, rhs, spec.concat_dimension);
  ComputeAndCompareR2<int32>(&builder, *reference, {});
}
551 
552 // Regression test for b/31944287. x*y is used (at the same index) by all
553 // operands of the concat. We should emit x*y in three incoming basic blocks of
554 // the concat because these basic blocks are not control-equivalent.
555 //
556 //      x*y
557 //    /  |   \
558 // add1 add2 add3
559 //    \  |   /
560 //     concat
XLA_TEST_F(ConcatTest, ConcatOperandsOfSameOperand) {
  auto scalar_f32 = ShapeUtil::MakeShape(xla::F32, {});
  auto x_literal = LiteralUtil::CreateR0<float>(2.f);
  auto y_literal = LiteralUtil::CreateR0<float>(3.f);
  auto x_data = client_->TransferToServer(x_literal).ConsumeValueOrDie();
  auto y_data = client_->TransferToServer(y_literal).ConsumeValueOrDie();

  XlaBuilder builder(TestName());
  auto x = Parameter(&builder, 0, scalar_f32, "x");
  auto y = Parameter(&builder, 1, scalar_f32, "y");
  auto product = Mul(x, y);  // 2 * 3 = 6, shared by all three adds.
  auto offsets1 = ConstantR1<float>(&builder, {1.f, 2.f});
  auto add1 = Add(product, offsets1);
  auto offsets2 = ConstantR1<float>(&builder, {3.f, 4.f});
  auto add2 = Add(product, offsets2);
  auto offsets3 = ConstantR1<float>(&builder, {5.f, 6.f});
  auto add3 = Add(product, offsets3);
  ConcatInDim(&builder, {add1, add2, add3}, /*dimension=*/0);

  const std::vector<float> expected = {7.f, 8.f, 9.f, 10.f, 11.f, 12.f};
  ComputeAndCompareR1<float>(&builder, expected, {x_data.get(), y_data.get()},
                             ErrorSpec(1e-4));
}
580 
// Test that the HLO optimization that rewrites a concat involving a
// broadcasted scalar produces the correct result in rank 1.
XLA_TEST_F(ConcatTest, ConcatBroadcastArgument) {
  auto scalar_f32 = ShapeUtil::MakeShape(xla::F32, {});
  auto vec_literal = LiteralUtil::CreateR1<float>({2.0f, 3.0f, 5.0f, 6.0f});
  auto lead_literal = LiteralUtil::CreateR0<float>(1.5f);
  auto tail_literal = LiteralUtil::CreateR0<float>(5.5f);
  auto vec_data = client_->TransferToServer(vec_literal).ConsumeValueOrDie();
  auto lead_data = client_->TransferToServer(lead_literal).ConsumeValueOrDie();
  auto tail_data = client_->TransferToServer(tail_literal).ConsumeValueOrDie();

  XlaBuilder builder(TestName());
  auto vec = Parameter(&builder, 0, vec_literal.shape(), "x");
  auto lead = Parameter(&builder, 1, scalar_f32, "y");
  auto tail = Parameter(&builder, 2, scalar_f32, "z");
  auto lead_bcast = Broadcast(lead, {5});
  auto tail_bcast = Broadcast(tail, {3});
  // Result: five copies of y, then x, then three copies of z.
  auto head = ConcatInDim(&builder, {lead_bcast, vec}, /*dimension=*/0);
  ConcatInDim(&builder, {head, tail_bcast}, /*dimension=*/0);

  const std::vector<float> expected = {1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 2.0f,
                                       3.0f, 5.0f, 6.0f, 5.5f, 5.5f, 5.5f};
  ComputeAndCompareR1<float>(
      &builder, expected, {vec_data.get(), lead_data.get(), tail_data.get()},
      ErrorSpec(1e-4));
}
606 
// Test that the HLO optimization that rewrites a concat involving a
// broadcasted scalar produces the correct result in rank 3, with broadcast
// slabs added both before and after the data in different dimensions.
XLA_TEST_F(ConcatTest, ConcatBroadcastArgumentR3) {
  auto f32_scalar = ShapeUtil::MakeShape(xla::F32, {});
  Array3D<float> x3d(3, 5, 7, 3.14f);
  auto x_literal = LiteralUtil::CreateR3FromArray3D<float>(x3d);
  auto y_literal = LiteralUtil::CreateR0<float>(1.5f);
  auto z_literal = LiteralUtil::CreateR0<float>(5.5f);
  auto x_data = client_->TransferToServer(x_literal).ConsumeValueOrDie();
  auto y_data = client_->TransferToServer(y_literal).ConsumeValueOrDie();
  auto z_data = client_->TransferToServer(z_literal).ConsumeValueOrDie();

  XlaBuilder builder(TestName());
  auto x = Parameter(&builder, 0, x_literal.shape(), "x");
  auto y = Parameter(&builder, 1, f32_scalar, "y");
  // Each parameter name matches its variable so HLO dumps stay unambiguous.
  auto z = Parameter(&builder, 2, f32_scalar, "z");
  auto y_bcast = Broadcast(y, {1, 5, 7});
  auto z_bcast = Broadcast(z, {4, 1, 7});
  // y slab prepended along dim 0 (-> 4x5x7), then z slab appended along
  // dim 1 (-> 4x6x7).
  auto concat = ConcatInDim(&builder, {y_bcast, x}, /*dimension=*/0);
  ConcatInDim(&builder, {concat, z_bcast}, /*dimension=*/1);

  // Host-side reference built from the same two concatenations.
  Array3D<float> y_bcast3d(1, 5, 7, 1.5f);
  Array3D<float> z_bcast3d(4, 1, 7, 5.5f);
  auto concat0 = ReferenceUtil::Concat3D(y_bcast3d, x3d, 0);
  auto concat1 = ReferenceUtil::Concat3D(*concat0, z_bcast3d, 1);

  ComputeAndCompareR3<float>(&builder, *concat1,
                             {x_data.get(), y_data.get(), z_data.get()},
                             ErrorSpec(1e-4));
}
637 
638 INSTANTIATE_TEST_CASE_P(ConcatR2BinaryTestInstantiation, ConcatR2BinaryTest,
639                         ::testing::Values(R2BinarySpec{1, 1, 1, 1, 0},
640                                           R2BinarySpec{1, 1, 1, 1, 1},
641                                           R2BinarySpec{4, 3, 4, 3, 0},
642                                           R2BinarySpec{4, 3, 4, 3, 1},
643                                           R2BinarySpec{7, 128, 1, 128, 0},
644                                           R2BinarySpec{8, 127, 8, 1, 1}));
645 
646 }  // namespace
647 }  // namespace xla
648