1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <initializer_list>
17 #include <memory>
18 
19 #include "absl/memory/memory.h"
20 #include "tensorflow/compiler/xla/array2d.h"
21 #include "tensorflow/compiler/xla/client/local_client.h"
22 #include "tensorflow/compiler/xla/client/xla_builder.h"
23 #include "tensorflow/compiler/xla/client/xla_computation.h"
24 #include "tensorflow/compiler/xla/literal_util.h"
25 #include "tensorflow/compiler/xla/shape_util.h"
26 #include "tensorflow/compiler/xla/statusor.h"
27 #include "tensorflow/compiler/xla/test_helpers.h"
28 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
29 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
30 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
31 #include "tensorflow/compiler/xla/tests/test_macros.h"
32 #include "tensorflow/compiler/xla/xla_data.pb.h"
33 #include "tensorflow/core/lib/core/status_test_util.h"
34 #include "tensorflow/core/platform/test.h"
35 
36 namespace xla {
37 namespace {
38 
39 class TupleTest : public ClientLibraryTestBase {
40  public:
41   ErrorSpec error_spec_{0.0001};
42 };
43 
44 // Tests a tuple-shaped constant.
XLA_TEST_F(TupleTest,TupleConstant)45 XLA_TEST_F(TupleTest, TupleConstant) {
46   XlaBuilder builder(TestName());
47 
48   const float constant_scalar = 7.3f;
49   std::initializer_list<float> constant_vector = {1.1f, 2.0f, 3.3f};
50   std::initializer_list<std::initializer_list<float>> constant_matrix = {
51       {1.1f, 2.2f, 3.5f},  // row 0
52       {4.8f, 5.0f, 6.7f},  // row 1
53   };
54   auto value = LiteralUtil::MakeTupleFromSlices(
55       {LiteralUtil::CreateR0<float>(constant_scalar),
56        LiteralUtil::CreateR1<float>(constant_vector),
57        LiteralUtil::CreateR2<float>(constant_matrix)});
58 
59   ConstantLiteral(&builder, value);
60   ComputeAndCompareTuple(&builder, value, {}, error_spec_);
61 }
62 
63 // Tests a tuple made of scalar constants.
XLA_TEST_F(TupleTest,TupleScalarConstant)64 XLA_TEST_F(TupleTest, TupleScalarConstant) {
65   XlaBuilder builder(TestName());
66 
67   const float constant_scalar1 = 7.3f;
68   const float constant_scalar2 = 1.2f;
69   auto value = LiteralUtil::MakeTupleFromSlices(
70       {LiteralUtil::CreateR0<float>(constant_scalar1),
71        LiteralUtil::CreateR0<float>(constant_scalar2)});
72 
73   ConstantLiteral(&builder, value);
74   ComputeAndCompareTuple(&builder, value, {}, error_spec_);
75 }
76 
77 // Tests the creation of tuple data.
XLA_TEST_F(TupleTest,TupleCreate)78 XLA_TEST_F(TupleTest, TupleCreate) {
79   XlaBuilder builder(TestName());
80 
81   const float constant_scalar = 7.3f;
82   std::initializer_list<float> constant_vector = {1.1f, 2.0f, 3.3f};
83   std::initializer_list<std::initializer_list<float>> constant_matrix = {
84       {1.1f, 2.2f, 3.5f},  // row 0
85       {4.8f, 5.0f, 6.7f},  // row 1
86   };
87   Tuple(&builder, {ConstantR0<float>(&builder, constant_scalar),
88                    ConstantR1<float>(&builder, constant_vector),
89                    ConstantR2<float>(&builder, constant_matrix)});
90 
91   auto expected = LiteralUtil::MakeTupleFromSlices(
92       {LiteralUtil::CreateR0<float>(constant_scalar),
93        LiteralUtil::CreateR1<float>(constant_vector),
94        LiteralUtil::CreateR2<float>(constant_matrix)});
95   ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
96 }
97 
98 // Tests the creation of tuple data.
XLA_TEST_F(TupleTest,TupleCreateWithZeroElementEntry)99 XLA_TEST_F(TupleTest, TupleCreateWithZeroElementEntry) {
100   XlaBuilder builder(TestName());
101 
102   Tuple(&builder,
103         {ConstantR0<float>(&builder, 7.0), ConstantR1<float>(&builder, {})});
104 
105   auto expected = LiteralUtil::MakeTupleFromSlices(
106       {LiteralUtil::CreateR0<float>(7.0), LiteralUtil::CreateR1<float>({})});
107   ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
108 }
109 
110 // Tests the creation of an empty tuple.
XLA_TEST_F(TupleTest,EmptyTupleCreate)111 XLA_TEST_F(TupleTest, EmptyTupleCreate) {
112   XlaBuilder builder(TestName());
113   Tuple(&builder, {});
114   auto expected = LiteralUtil::MakeTuple({});
115   ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
116 }
117 
118 // Trivial test for extracting a tuple element with GetTupleElement.
XLA_TEST_F(TupleTest,GetTupleElement)119 XLA_TEST_F(TupleTest, GetTupleElement) {
120   XlaBuilder builder(TestName());
121   std::initializer_list<float> constant_vector = {1.f, 2.f, 3.f};
122   std::initializer_list<std::initializer_list<float>> constant_matrix = {
123       {1.f, 2.f, 3.f},  // row 0
124       {4.f, 5.f, 6.f},  // row 1
125   };
126   auto tuple_data =
127       Tuple(&builder, {ConstantR1<float>(&builder, constant_vector),
128                        ConstantR2<float>(&builder, constant_matrix)});
129   GetTupleElement(tuple_data, 1);
130   ComputeAndCompareR2<float>(&builder, Array2D<float>(constant_matrix), {},
131                              error_spec_);
132 }
133 
134 // Trivial test for extracting a tuple element with GetTupleElement.
XLA_TEST_F(TupleTest,GetTupleElementWithZeroElements)135 XLA_TEST_F(TupleTest, GetTupleElementWithZeroElements) {
136   XlaBuilder builder(TestName());
137   auto tuple_data =
138       Tuple(&builder,
139             {ConstantR1<float>(&builder, {}),
140              ConstantR2FromArray2D<float>(&builder, Array2D<float>(0, 101))});
141   GetTupleElement(tuple_data, 1);
142   ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 101), {}, error_spec_);
143 }
144 
// Applies GetTupleElement to a plain F32 vector and checks that Build()
// reports an error naming the non-tuple operand.
XLA_TEST_F(TupleTest, GetTupleElementOfNonTupleFailsGracefully) {
  XlaBuilder builder(TestName());
  GetTupleElement(ConstantR1<float>(&builder, {4.5f}), 1);

  auto build_result = builder.Build();
  EXPECT_FALSE(build_result.ok());
  const auto& message = build_result.status().error_message();
  EXPECT_THAT(message, ::testing::HasSubstr(
                           "Operand to GetTupleElement() is not a tuple"));
}
155 
156 // Extracts both elements from a tuple with GetTupleElement and then adds them
157 // together.
XLA_TEST_F(TupleTest,AddTupleElements)158 XLA_TEST_F(TupleTest, AddTupleElements) {
159   XlaBuilder builder(TestName());
160   std::initializer_list<float> constant_vector = {1.f, 2.f, 3.f};
161   std::initializer_list<std::initializer_list<float>> constant_matrix = {
162       {1.f, 2.f, 3.f},  // row 0
163       {4.f, 5.f, 6.f},  // row 1
164   };
165   auto tuple_data =
166       Tuple(&builder, {ConstantR1<float>(&builder, constant_vector),
167                        ConstantR2<float>(&builder, constant_matrix)});
168   auto vector_element = GetTupleElement(tuple_data, 0);
169   auto matrix_element = GetTupleElement(tuple_data, 1);
170   auto vector_shape = builder.GetShape(vector_element).ConsumeValueOrDie();
171   auto matrix_shape = builder.GetShape(matrix_element).ConsumeValueOrDie();
172   Add(matrix_element, vector_element,
173       /*broadcast_dimensions=*/{1});
174 
175   Array2D<float> expected({
176       {2.f, 4.f, 6.f},  // row 0
177       {5.f, 7.f, 9.f},  // row 1
178   });
179   ASSERT_TRUE(ShapeUtil::Equal(vector_shape, ShapeUtil::MakeShape(F32, {3})));
180   ASSERT_TRUE(ShapeUtil::Equal(matrix_shape,
181                                ShapeUtil::MakeShape(F32, {/*y=*/2, /*x=*/3})));
182   ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
183 }
184 
185 // Extracts both elements from a tuple and then puts them into a new tuple in
186 // the opposite order.
XLA_TEST_F(TupleTest,TupleGTEToTuple)187 XLA_TEST_F(TupleTest, TupleGTEToTuple) {
188   XlaBuilder builder(TestName());
189   std::initializer_list<float> constant_vector = {1.f, 2.f, 3.f};
190   std::initializer_list<std::initializer_list<float>> constant_matrix = {
191       {1.f, 2.f, 3.f},  // row 0
192       {4.f, 5.f, 6.f},  // row 1
193   };
194   auto tuple_data =
195       Tuple(&builder, {ConstantR1<float>(&builder, constant_vector),
196                        ConstantR2<float>(&builder, constant_matrix)});
197   Tuple(&builder,
198         {GetTupleElement(tuple_data, 1), GetTupleElement(tuple_data, 0)});
199   auto expected = LiteralUtil::MakeTupleFromSlices(
200       {LiteralUtil::CreateR2<float>(constant_matrix),
201        LiteralUtil::CreateR1<float>(constant_vector)});
202   ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
203 }
204 
// Selects between two tuples of PRED values whose selector is computed from
// runtime parameters (v1 = 0, v2 = 1). Runs once per selector direction.
XLA_TEST_F(TupleTest, SelectBetweenPredTuples) {
  XlaBuilder b(TestName());

  // NOTE(review): the same builder is reused on the second iteration; this
  // presumes ComputeAndCompareTuple's Build() clears the builder's state.
  for (const bool direction : {false, true}) {
    XlaOp v1;
    XlaOp v2;
    std::unique_ptr<GlobalData> v1_data = CreateR0Parameter<float>(
        0.0f, /*parameter_number=*/0, /*name=*/"v1", /*builder=*/&b,
        /*data_handle=*/&v1);
    std::unique_ptr<GlobalData> v2_data = CreateR0Parameter<float>(
        1.0f, /*parameter_number=*/1, /*name=*/"v2", /*builder=*/&b,
        /*data_handle=*/&v2);

    auto v1_gt = Gt(v1, v2);                      // 0 > 1: false
    auto v2_gt = Gt(v2, v1);                      // 1 > 0: true
    auto false_true = Tuple(&b, {v1_gt, v2_gt});  // {false, true}
    auto true_false = Tuple(&b, {v2_gt, v1_gt});  // {true, false}

    // direction=false picks the true predicate -> {false, true};
    // direction=true picks the false predicate -> {true, false}.
    auto selector = direction ? v1_gt : v2_gt;
    Select(selector, false_true, true_false);

    const auto expected = LiteralUtil::MakeTupleFromSlices(
        {LiteralUtil::CreateR0<bool>(direction),
         LiteralUtil::CreateR0<bool>(!direction)});
    ComputeAndCompareTuple(&b, expected, {v1_data.get(), v2_data.get()},
                           error_spec_);
  }
}
229 
230 // Builds two new tuples from an existing tuple (by means of GetTupleElement),
231 // then adds up the components of the new tuples.
XLA_TEST_F(TupleTest,TupleGTEToTupleToGTEAdd)232 XLA_TEST_F(TupleTest, TupleGTEToTupleToGTEAdd) {
233   //
234   // v------           --(GTE 0)--             --(GTE 0)----------
235   //        \         /           \           /                   \
236   //         (tuple)--             (tuple01)--                     \
237   //        /   |     \           /           \                     \
238   // m------    |      --(GTE 1)--             --(GTE 1)------------ \
239   //            |                                                   \ \
240   //            |                                                    (add)
241   //            |                                                   / /
242   //            |--------(GTE 1)--             --(GTE 0)------------ /
243   //             \                \           /                     /
244   //              \                (tuple10)--                     /
245   //               \              /           \                   /
246   //                -----(GTE 0)--             --(GTE 1)----------
247   XlaBuilder builder(TestName());
248   std::initializer_list<float> constant_vector = {1.f, 2.f, 3.f};
249   std::initializer_list<std::initializer_list<float>> constant_matrix = {
250       {1.f, 2.f, 3.f},  // row 0
251       {4.f, 5.f, 6.f},  // row 1
252   };
253   auto tuple_data =
254       Tuple(&builder, {ConstantR1<float>(&builder, constant_vector),
255                        ConstantR2<float>(&builder, constant_matrix)});
256   auto new_tuple01 = Tuple(&builder, {GetTupleElement(tuple_data, 0),
257                                       GetTupleElement(tuple_data, 1)});
258   auto new_tuple10 = Tuple(&builder, {GetTupleElement(tuple_data, 1),
259                                       GetTupleElement(tuple_data, 0)});
260   auto vector_from_01 = GetTupleElement(new_tuple01, 0);
261   auto vector_from_10 = GetTupleElement(new_tuple10, 1);
262   auto matrix_from_01 = GetTupleElement(new_tuple01, 1);
263   auto matrix_from_10 = GetTupleElement(new_tuple10, 0);
264 
265   auto addvectors = Add(vector_from_01, vector_from_10);
266   auto addmatrices = Add(matrix_from_01, matrix_from_10);
267 
268   Add(addmatrices, addvectors,
269       /*broadcast_dimensions=*/{1});
270 
271   Array2D<float> expected({
272       {4.f, 8.f, 12.f},    // row 0
273       {10.f, 14.f, 18.f},  // row 1
274   });
275   ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
276 }
277 
XLA_TEST_F(TupleTest, SelectBetweenTuplesOnFalse) {
  // Selects between two tuples with a constant false predicate, so the
  // on-false tuple (doubled, base) is the expected result.
  XlaBuilder builder(TestName());

  std::initializer_list<float> base = {1.f, 2.f, 3.f};
  std::initializer_list<float> doubled = {2.f, 4.f, 6.f};
  auto base_doubled = Tuple(&builder, {ConstantR1<float>(&builder, base),
                                       ConstantR1<float>(&builder, doubled)});
  auto doubled_base = Tuple(&builder, {ConstantR1<float>(&builder, doubled),
                                       ConstantR1<float>(&builder, base)});

  auto pred = ConstantR0<bool>(&builder, false);
  Select(pred, base_doubled, doubled_base);

  const auto expected =
      LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>(doubled),
                                        LiteralUtil::CreateR1<float>(base)});
  ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
}
294 
// Maps a scalar computation that uses tuples internally over an R1 input.
XLA_TEST_F(TupleTest, TuplesInAMap) {
  // f(x) = 100 * min(x, x^2) + max(x, x^2), expressed with tuples. The select
  // is there so HLO-level optimizations cannot optimize the tuples away.
  XlaComputation tuple_computation;
  {
    XlaBuilder sub("sort_square");
    auto x = Parameter(&sub, 0, ShapeUtil::MakeShape(F32, {}), "x");
    auto x_squared = Mul(x, x);
    auto x_first = Tuple(&sub, {x, x_squared});
    auto x_squared_first = Tuple(&sub, {x_squared, x});
    auto ordered = Select(Lt(x, x_squared), x_first, x_squared_first);
    auto low = GetTupleElement(ordered, 0);
    auto high = GetTupleElement(ordered, 1);
    auto hundred = ConstantR0<float>(&sub, 100.0f);
    Add(high, Mul(hundred, low));

    auto built = sub.Build();
    ASSERT_IS_OK(built.status());
    tuple_computation = built.ConsumeValueOrDie();
  }

  XlaBuilder builder(TestName());
  auto input = ConstantR1<float>(&builder, {-1.0f, 1.0f, 2.1f});
  Map(&builder, {input}, tuple_computation, {0});
  // f(-1) = -100 + 1, f(1) = 100 + 1, f(2.1) = 210 + 4.41.
  ComputeAndCompareR1<float>(&builder, {-99.0f, 101.0f, 214.41f}, {},
                             error_spec_);
}
321 
XLA_TEST_F(TupleTest, SelectBetweenTuplesOnTrue) {
  // Selects between two tuples with a constant true predicate, so the on-true
  // tuple (base, doubled) is the expected result.
  XlaBuilder builder(TestName());

  std::initializer_list<float> base = {1.f, 2.f, 3.f};
  std::initializer_list<float> doubled = {2.f, 4.f, 6.f};
  auto base_doubled = Tuple(&builder, {ConstantR1<float>(&builder, base),
                                       ConstantR1<float>(&builder, doubled)});
  auto doubled_base = Tuple(&builder, {ConstantR1<float>(&builder, doubled),
                                       ConstantR1<float>(&builder, base)});

  auto pred = ConstantR0<bool>(&builder, true);
  Select(pred, base_doubled, doubled_base);

  const auto expected =
      LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>(base),
                                        LiteralUtil::CreateR1<float>(doubled)});
  ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
}
338 
XLA_TEST_F(TupleTest, SelectBetweenTuplesElementResult) {
  // Selects between two tuples, but the computation's result is element 0 of
  // the selected tuple rather than the tuple itself.
  XlaBuilder builder(TestName());

  std::initializer_list<float> base = {1.f, 2.f, 3.f};
  std::initializer_list<float> doubled = {2.f, 4.f, 6.f};
  auto base_doubled = Tuple(&builder, {ConstantR1<float>(&builder, base),
                                       ConstantR1<float>(&builder, doubled)});
  auto doubled_base = Tuple(&builder, {ConstantR1<float>(&builder, doubled),
                                       ConstantR1<float>(&builder, base)});

  auto pred = ConstantR0<bool>(&builder, false);
  auto selected = Select(pred, base_doubled, doubled_base);
  GetTupleElement(selected, 0);

  // The false predicate picks (doubled, base); element 0 is `doubled`.
  ComputeAndCompareR1<float>(&builder, doubled, {}, error_spec_);
}
356 
357 // Cascaded selects between tuple types.
XLA_TEST_F(TupleTest,SelectBetweenTuplesCascaded)358 XLA_TEST_F(TupleTest, SelectBetweenTuplesCascaded) {
359   //
360   //                       vec1     vec2   vec2     vec1
361   //                        |        |      |        |
362   //                        |        |      |        |
363   //                        (tuple 12)      (tuple 21)
364   //                               \            /
365   //                                \          /
366   //                                 \        /
367   //  true  --            --(GTE 0)--(select 1)
368   //          \          /             |
369   //       (pred tuple)--              |          --(GTE 0)--
370   //          /          \             V         /           \
371   //  false --            --(GTE 1)--(select 2)--             --(add)
372   //                                 /           \           /
373   //                                /             --(GTE 1)--
374   //                               /
375   //                          (tuple 21)
376   XlaBuilder builder(TestName());
377 
378   std::initializer_list<float> vec1 = {1.f, 2.f, 3.f};
379   std::initializer_list<float> vec2 = {2.f, 4.f, 6.f};
380 
381   auto pred_tuple = Tuple(&builder, {ConstantR0<bool>(&builder, true),
382                                      ConstantR0<bool>(&builder, false)});
383   auto tuple12 = Tuple(&builder, {ConstantR1<float>(&builder, vec1),
384                                   ConstantR1<float>(&builder, vec2)});
385   auto tuple21 = Tuple(&builder, {ConstantR1<float>(&builder, vec2),
386                                   ConstantR1<float>(&builder, vec1)});
387 
388   auto select1 = Select(GetTupleElement(pred_tuple, 0), tuple12, tuple21);
389   auto select2 = Select(GetTupleElement(pred_tuple, 1), tuple21, select1);
390   Add(GetTupleElement(select2, 0), GetTupleElement(select2, 1));
391 
392   ComputeAndCompareR1<float>(&builder, {3.f, 6.f, 9.f}, {}, error_spec_);
393 }
394 
XLA_TEST_F(TupleTest, SelectBetweenTuplesReuseConstants) {
  // Similar to SelectBetweenTuplesOnFalse, but the two constant ops are shared
  // between the input tuples instead of being created once per tuple.
  XlaBuilder builder(TestName());

  std::initializer_list<float> vec1 = {1.f, 2.f, 3.f};
  std::initializer_list<float> vec2 = {2.f, 4.f, 6.f};
  auto c1 = ConstantR1<float>(&builder, vec1);
  auto c2 = ConstantR1<float>(&builder, vec2);
  auto tuple12 = Tuple(&builder, {c1, c2});
  auto tuple21 = Tuple(&builder, {c2, c1});

  // A false predicate selects the on-false operand, tuple21.
  Select(ConstantR0<bool>(&builder, false), tuple12, tuple21);

  auto expected = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR1<float>(vec2), LiteralUtil::CreateR1<float>(vec1)});
  ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
}
413 
// Builds the nested tuple ((f32[2], f32[]), f32[2]) from constants and compares
// it against an equivalent nested tuple literal.
XLA_TEST_F(TupleTest, NestedTuples) {
  XlaBuilder builder(TestName());
  auto inner = Tuple(&builder, {ConstantR1<float>(&builder, {1.0, 2.0}),
                                ConstantR0<float>(&builder, 42.0)});
  auto outer_vector = ConstantR1<float>(&builder, {22.0, 44.0});
  Tuple(&builder, {inner, outer_vector});

  // Expected value: ((inner_vec, inner_scalar), outer_vec).
  auto inner_vec = LiteralUtil::CreateR1<float>({1.0, 2.0});
  auto inner_scalar = LiteralUtil::CreateR0<float>(42.0);
  auto inner_literal = LiteralUtil::MakeTuple({&inner_vec, &inner_scalar});
  auto outer_vec = LiteralUtil::CreateR1<float>({22.0, 44.0});
  auto expected = LiteralUtil::MakeTuple({&inner_literal, &outer_vec});

  ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
}
429 
// Reads element 1 of the inner tuple of a nested tuple parameter with shape
// ((f32[3], f32[3]), f32[3]) and adds a constant vector to it.
XLA_TEST_F(TupleTest, GetTupleElementOfNestedTuple) {
  XlaBuilder builder(TestName());

  Shape vec3 = ShapeUtil::MakeShape(F32, {3});
  Shape pair_shape = ShapeUtil::MakeTupleShape({vec3, vec3});
  Shape nested_shape = ShapeUtil::MakeTupleShape({pair_shape, vec3});

  auto param = Parameter(&builder, 0, nested_shape, "input");
  auto inner_pair = GetTupleElement(param, 0);
  auto picked = GetTupleElement(inner_pair, 1);
  auto offsets = ConstantR1<float>(&builder, {10.0, 11.0, 12.0});
  Add(picked, offsets);

  auto argument_literal = LiteralUtil::MakeTupleFromSlices({
      LiteralUtil::MakeTupleFromSlices({
          LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0}),
          LiteralUtil::CreateR1<float>({4.0, 5.0, 6.0}),
      }),
      LiteralUtil::CreateR1<float>({7.0, 8.0, 9.0}),
  });
  std::unique_ptr<GlobalData> data =
      client_->TransferToServer(argument_literal).ConsumeValueOrDie();

  // The picked element is {4, 5, 6}; each entry gets its offset added.
  const std::vector<float> expected = {4.0 + 10.0, 5.0 + 11.0, 6.0 + 12.0};
  ComputeAndCompareR1<float>(&builder, expected, {data.get()},
                             ErrorSpec(1e-5));
}
458 
// Exercises C64 values inside tuples. The parameters are a nested complex
// tuple (scalar, (vector, matrix)) and a complex vector. The result is the
// nested tuple ((prod, sum), constant).
XLA_TEST_F(TupleTest, ComplexTuples) {
  XlaBuilder builder(TestName());
  {
    Shape scalar_c64 = ShapeUtil::MakeShape(C64, {});
    Shape vector_c64 = ShapeUtil::MakeShape(C64, {2});
    Shape matrix_c64 = ShapeUtil::MakeShape(C64, {3, 2});
    Shape arg0_shape = ShapeUtil::MakeTupleShape(
        {scalar_c64, ShapeUtil::MakeTupleShape({vector_c64, matrix_c64})});

    auto input0 = Parameter(&builder, 0, arg0_shape, "input0");
    auto scalar = GetTupleElement(input0, 0);
    auto pair = GetTupleElement(input0, 1);
    auto vec = GetTupleElement(pair, 0);
    auto mat = GetTupleElement(pair, 1);
    // The vector is broadcast along matrix dimension 1, then the scalar is
    // added.
    auto sum = Add(Add(vec, mat, {1}), scalar);

    auto input1 = Parameter(&builder, 1, vector_c64, "input1");
    auto prod = Mul(input1, sum, {1});
    auto inner = Tuple(&builder, {prod, sum});
    auto constant = ConstantR0<complex64>(&builder, {123, 456});
    Tuple(&builder, {inner, constant});
  }

  auto arg0_literal = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR0<complex64>({1, 2}),
       LiteralUtil::MakeTupleFromSlices(
           {LiteralUtil::CreateR1<complex64>({{10, 20}, {30, 40}}),
            LiteralUtil::CreateR2<complex64>(
                {{{100, 200}, {300, 400}},
                 {{1000, 2000}, {3000, 4000}},
                 {{10000, 20000}, {30000, 40000}}})})});
  std::unique_ptr<GlobalData> arg0 =
      client_->TransferToServer(arg0_literal).ConsumeValueOrDie();
  std::unique_ptr<GlobalData> arg1 =
      client_
          ->TransferToServer(
              LiteralUtil::CreateR1<complex64>({{1, 2}, {1, -2}}))
          .ConsumeValueOrDie();

  // matrix + vector (per row) + scalar (1+2i).
  auto sum =
      LiteralUtil::CreateR2<complex64>({{{111, 222}, {331, 442}},
                                        {{1011, 2022}, {3031, 4042}},
                                        {{10011, 20022}, {30031, 40042}}});
  // input1 is broadcast along the last dimension: column 0 is scaled by
  // (1+2i) and column 1 by (1-2i).
  Literal prod(sum.shape());
  ASSERT_TRUE(prod.Populate<complex64>([&sum](absl::Span<const int64> indexes) {
                    const complex64 factor = indexes.back() == 0
                                                 ? complex64(1, 2)
                                                 : complex64(1, -2);
                    return sum.Get<complex64>(indexes) * factor;
                  })
                  .ok());

  auto expected = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::MakeTupleFromSlices({prod, sum}),
       LiteralUtil::CreateR0<complex64>({123, 456})});
  ComputeAndCompareTuple(&builder, expected, {arg0.get(), arg1.get()},
                         error_spec_);
}
513 
514 class TupleHloTest : public HloTestBase {};
515 
// Runs, without HLO passes, a scheduled module that bitcasts a
// get-tuple-element result from f32[3] to f32[1,3], and checks the output tuple.
XLA_TEST_F(TupleHloTest, BitcastAfterGTE) {
  const char* testcase = R"(
    HloModule m, is_scheduled=true

    ENTRY test {
      name.1 = (f32[3]{0}) parameter(0)
      get-tuple-element.1 = f32[3]{0} get-tuple-element(name.1), index=0
      bitcast = f32[1,3]{1,0} bitcast(get-tuple-element.1)
      copy = f32[1,3]{1,0} copy(bitcast)
      ROOT tuple.4 = (f32[1,3]{1,0}) tuple(copy)
    }
  )";
  auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie();

  auto input =
      LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1<float>({1, 2, 3}));
  auto output = ExecuteNoHloPasses(std::move(module), {&input});

  auto expected =
      LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR2<float>({{1, 2, 3}}));
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, output));
}
536 
537 // Disabled on interpreter due to lack of outfeed.
XLA_TEST_F(TupleHloTest,DISABLED_ON_INTERPRETER (NonAmbiguousTopLevelAllocation))538 XLA_TEST_F(TupleHloTest,
539            DISABLED_ON_INTERPRETER(NonAmbiguousTopLevelAllocation)) {
540   const char* testcase = R"(
541     HloModule tuple
542 
543     ENTRY main {
544       a = f32[2] parameter(0)
545       b = f32[2] parameter(1)
546       c = f32[2] parameter(2)
547       d = f32[2] parameter(3)
548       cond = pred[] parameter(4)
549 
550       tup0 = (f32[2],f32[2]) tuple(a, b)
551       tup1 = (f32[2],f32[2]) tuple(c, d)
552 
553       s = (f32[2],f32[2]) tuple-select(cond, tup0, tup1)
554       gte = f32[2] get-tuple-element(s), index=0
555       tuple = (f32[2]) tuple(gte)
556       token0 = token[] after-all()
557       ROOT outfeed = token[] outfeed(tuple, token0)
558     }
559   )";
560   auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie();
561   auto param0 = LiteralUtil::CreateR1<float>({1, 2});
562   auto param1 = LiteralUtil::CreateR1<float>({2, 3});
563   auto param4 = LiteralUtil::CreateR0<bool>(false);
564   // Put execution on a separate thread so we can block on outfeed.
565   std::unique_ptr<tensorflow::Thread> thread(
566       tensorflow::Env::Default()->StartThread(
567           tensorflow::ThreadOptions(), "execute_thread", [&] {
568             TF_EXPECT_OK(Execute(std::move(module),
569                                  {&param0, &param1, &param1, &param0, &param4})
570                              .status());
571           }));
572   auto expected =
573       LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1<float>({2, 3}));
574   auto literal = Literal::CreateFromShape(expected.shape());
575   TF_EXPECT_OK(backend().transfer_manager()->TransferLiteralFromOutfeed(
576       backend().default_stream_executor(), expected.shape(), literal));
577   EXPECT_TRUE(LiteralTestUtil::Equal(expected, literal));
578 }
579 
580 }  // namespace
581 }  // namespace xla
582