1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <initializer_list>
17 #include <memory>
18
19 #include "absl/memory/memory.h"
20 #include "tensorflow/compiler/xla/array2d.h"
21 #include "tensorflow/compiler/xla/client/local_client.h"
22 #include "tensorflow/compiler/xla/client/xla_builder.h"
23 #include "tensorflow/compiler/xla/client/xla_computation.h"
24 #include "tensorflow/compiler/xla/literal_util.h"
25 #include "tensorflow/compiler/xla/shape_util.h"
26 #include "tensorflow/compiler/xla/statusor.h"
27 #include "tensorflow/compiler/xla/test_helpers.h"
28 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
29 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
30 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
31 #include "tensorflow/compiler/xla/tests/test_macros.h"
32 #include "tensorflow/compiler/xla/xla_data.pb.h"
33 #include "tensorflow/core/lib/core/status_test_util.h"
34 #include "tensorflow/core/platform/test.h"
35
36 namespace xla {
37 namespace {
38
39 class TupleTest : public ClientLibraryTestBase {
40 public:
41 ErrorSpec error_spec_{0.0001};
42 };
43
44 // Tests a tuple-shaped constant.
XLA_TEST_F(TupleTest, TupleConstant) {
  XlaBuilder builder(TestName());

  // A scalar, a 3-element vector and a 2x3 matrix packed into one literal.
  const float scalar_value = 7.3f;
  std::initializer_list<float> vector_values = {1.1f, 2.0f, 3.3f};
  std::initializer_list<std::initializer_list<float>> matrix_values = {
      {1.1f, 2.2f, 3.5f},  // row 0
      {4.8f, 5.0f, 6.7f},  // row 1
  };
  auto tuple_literal = LiteralUtil::MakeTupleFromSlices({
      LiteralUtil::CreateR0<float>(scalar_value),
      LiteralUtil::CreateR1<float>(vector_values),
      LiteralUtil::CreateR2<float>(matrix_values),
  });

  // The computation is just the tuple-shaped constant, so executing it must
  // reproduce the same literal.
  ConstantLiteral(&builder, tuple_literal);
  ComputeAndCompareTuple(&builder, tuple_literal, {}, error_spec_);
}
62
63 // Tests a tuple made of scalar constants.
XLA_TEST_F(TupleTest, TupleScalarConstant) {
  XlaBuilder builder(TestName());

  // Two scalar elements; the constant must come back unchanged.
  auto scalar_pair = LiteralUtil::MakeTupleFromSlices({
      LiteralUtil::CreateR0<float>(7.3f),
      LiteralUtil::CreateR0<float>(1.2f),
  });

  ConstantLiteral(&builder, scalar_pair);
  ComputeAndCompareTuple(&builder, scalar_pair, {}, error_spec_);
}
76
77 // Tests the creation of tuple data.
XLA_TEST_F(TupleTest, TupleCreate) {
  XlaBuilder builder(TestName());

  const float scalar_value = 7.3f;
  std::initializer_list<float> vector_values = {1.1f, 2.0f, 3.3f};
  std::initializer_list<std::initializer_list<float>> matrix_values = {
      {1.1f, 2.2f, 3.5f},  // row 0
      {4.8f, 5.0f, 6.7f},  // row 1
  };

  // Host-side reference: the same three values packed into a tuple literal.
  auto expected = LiteralUtil::MakeTupleFromSlices({
      LiteralUtil::CreateR0<float>(scalar_value),
      LiteralUtil::CreateR1<float>(vector_values),
      LiteralUtil::CreateR2<float>(matrix_values),
  });

  // In the computation, assemble the tuple from three separate constants.
  XlaOp scalar_op = ConstantR0<float>(&builder, scalar_value);
  XlaOp vector_op = ConstantR1<float>(&builder, vector_values);
  XlaOp matrix_op = ConstantR2<float>(&builder, matrix_values);
  Tuple(&builder, {scalar_op, vector_op, matrix_op});

  ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
}
97
// Tests the creation of a tuple that contains a zero-element (empty) vector.
XLA_TEST_F(TupleTest, TupleCreateWithZeroElementEntry) {
  XlaBuilder builder(TestName());

  // Second element is an empty (length-0) vector.
  XlaOp scalar_op = ConstantR0<float>(&builder, 7.0);
  XlaOp empty_vector_op = ConstantR1<float>(&builder, {});
  Tuple(&builder, {scalar_op, empty_vector_op});

  auto expected = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR0<float>(7.0), LiteralUtil::CreateR1<float>({})});
  ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
}
109
110 // Tests the creation of an empty tuple.
XLA_TEST_F(TupleTest, EmptyTupleCreate) {
  // A tuple with no elements must round-trip as an empty tuple literal.
  XlaBuilder builder(TestName());
  Tuple(&builder, {});

  const auto empty_tuple = LiteralUtil::MakeTuple({});
  ComputeAndCompareTuple(&builder, empty_tuple, {}, error_spec_);
}
117
118 // Trivial test for extracting a tuple element with GetTupleElement.
XLA_TEST_F(TupleTest, GetTupleElement) {
  XlaBuilder builder(TestName());
  std::initializer_list<float> vector_values = {1.f, 2.f, 3.f};
  std::initializer_list<std::initializer_list<float>> matrix_values = {
      {1.f, 2.f, 3.f},  // row 0
      {4.f, 5.f, 6.f},  // row 1
  };

  XlaOp vector_op = ConstantR1<float>(&builder, vector_values);
  XlaOp matrix_op = ConstantR2<float>(&builder, matrix_values);
  XlaOp pair = Tuple(&builder, {vector_op, matrix_op});

  // Element 1 is the matrix; it becomes the computation's result.
  GetTupleElement(pair, 1);

  const Array2D<float> expected(matrix_values);
  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
133
// Tests extracting a zero-element array from a tuple with GetTupleElement.
XLA_TEST_F(TupleTest, GetTupleElementWithZeroElements) {
  XlaBuilder builder(TestName());

  // Both tuple elements are empty: a length-0 vector and a 0x101 matrix.
  const Array2D<float> empty_matrix(0, 101);
  XlaOp empty_vector_op = ConstantR1<float>(&builder, {});
  XlaOp empty_matrix_op = ConstantR2FromArray2D<float>(&builder, empty_matrix);
  XlaOp pair = Tuple(&builder, {empty_vector_op, empty_matrix_op});

  GetTupleElement(pair, 1);
  ComputeAndCompareR2<float>(&builder, empty_matrix, {}, error_spec_);
}
144
XLA_TEST_F(TupleTest, GetTupleElementOfNonTupleFailsGracefully) {
  XlaBuilder builder(TestName());

  // GetTupleElement on an array operand is invalid; Build() must report it
  // through its returned status.
  GetTupleElement(ConstantR1<float>(&builder, {4.5f}), 1);
  auto computation_or = builder.Build();

  EXPECT_FALSE(computation_or.ok());
  EXPECT_THAT(
      computation_or.status().error_message(),
      ::testing::HasSubstr("Operand to GetTupleElement() is not a tuple"));
}
155
156 // Extracts both elements from a tuple with GetTupleElement and then adds them
157 // together.
XLA_TEST_F(TupleTest, AddTupleElements) {
  XlaBuilder builder(TestName());
  std::initializer_list<float> vector_values = {1.f, 2.f, 3.f};
  std::initializer_list<std::initializer_list<float>> matrix_values = {
      {1.f, 2.f, 3.f},  // row 0
      {4.f, 5.f, 6.f},  // row 1
  };

  XlaOp pair = Tuple(&builder, {ConstantR1<float>(&builder, vector_values),
                                ConstantR2<float>(&builder, matrix_values)});
  XlaOp vector_elem = GetTupleElement(pair, 0);
  XlaOp matrix_elem = GetTupleElement(pair, 1);

  // The builder must infer f32[3] and f32[2,3] for the two elements.
  const Shape vector_shape = builder.GetShape(vector_elem).ConsumeValueOrDie();
  const Shape matrix_shape = builder.GetShape(matrix_elem).ConsumeValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(vector_shape, ShapeUtil::MakeShape(F32, {3})));
  ASSERT_TRUE(ShapeUtil::Equal(matrix_shape,
                               ShapeUtil::MakeShape(F32, {/*y=*/2, /*x=*/3})));

  // Add the vector to every row of the matrix (vector maps onto dimension 1).
  Add(matrix_elem, vector_elem, /*broadcast_dimensions=*/{1});

  const Array2D<float> expected({
      {2.f, 4.f, 6.f},  // row 0
      {5.f, 7.f, 9.f},  // row 1
  });
  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
184
185 // Extracts both elements from a tuple and then puts them into a new tuple in
186 // the opposite order.
XLA_TEST_F(TupleTest, TupleGTEToTuple) {
  XlaBuilder builder(TestName());
  std::initializer_list<float> vector_values = {1.f, 2.f, 3.f};
  std::initializer_list<std::initializer_list<float>> matrix_values = {
      {1.f, 2.f, 3.f},  // row 0
      {4.f, 5.f, 6.f},  // row 1
  };

  XlaOp source = Tuple(&builder, {ConstantR1<float>(&builder, vector_values),
                                  ConstantR2<float>(&builder, matrix_values)});

  // Swap the elements: (vector, matrix) -> (matrix, vector).
  XlaOp matrix_elem = GetTupleElement(source, 1);
  XlaOp vector_elem = GetTupleElement(source, 0);
  Tuple(&builder, {matrix_elem, vector_elem});

  auto expected = LiteralUtil::MakeTupleFromSlices({
      LiteralUtil::CreateR2<float>(matrix_values),
      LiteralUtil::CreateR1<float>(vector_values),
  });
  ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
}
204
XLA_TEST_F(TupleTest, SelectBetweenPredTuples) {
  XlaBuilder b(TestName());
  XlaOp v1, v2;

  // Runs once per direction. The same builder `b` is reused for both
  // iterations, and the parameters are declared again each time.
  // NOTE(review): this relies on ComputeAndCompareTuple consuming the
  // builder's state between iterations; confirm.
  for (const bool direction : {false, true}) {
    std::unique_ptr<GlobalData> v1_data =
        CreateR0Parameter<float>(0.0f, /*parameter_number=*/0, /*name=*/"v1",
                                 /*builder=*/&b, /*data_handle=*/&v1);
    std::unique_ptr<GlobalData> v2_data =
        CreateR0Parameter<float>(1.0f, /*parameter_number=*/1, /*name=*/"v2",
                                 /*builder=*/&b, /*data_handle=*/&v2);

    // With v1 = 0 and v2 = 1: (v1 > v2) is false and (v2 > v1) is true.
    XlaOp v1_greater = Gt(v1, v2);
    XlaOp v2_greater = Gt(v2, v1);
    XlaOp false_true = Tuple(&b, {v1_greater, v2_greater});
    XlaOp true_false = Tuple(&b, {v2_greater, v1_greater});

    // direction == false -> predicate is true  -> {false, true};
    // direction == true  -> predicate is false -> {true, false}.
    XlaOp predicate = direction ? v1_greater : v2_greater;
    Select(predicate, false_true, true_false);

    auto expected = LiteralUtil::MakeTupleFromSlices({
        LiteralUtil::CreateR0<bool>(direction),
        LiteralUtil::CreateR0<bool>(!direction),
    });
    ComputeAndCompareTuple(&b, expected, {v1_data.get(), v2_data.get()},
                           error_spec_);
  }
}
229
230 // Builds two new tuples from an existing tuple (by means of GetTupleElement),
231 // then adds up the components of the new tuples.
XLA_TEST_F(TupleTest, TupleGTEToTupleToGTEAdd) {
  // Dataflow (v = vector constant, m = matrix constant):
  //   tuple_data = (v, m)
  //   tuple01    = (GTE(tuple_data, 0), GTE(tuple_data, 1))   i.e. (v, m)
  //   tuple10    = (GTE(tuple_data, 1), GTE(tuple_data, 0))   i.e. (m, v)
  //   result     = (m from tuple01 + m from tuple10)
  //                + broadcast(v from tuple01 + v from tuple10)
  XlaBuilder builder(TestName());
  std::initializer_list<float> vector_values = {1.f, 2.f, 3.f};
  std::initializer_list<std::initializer_list<float>> matrix_values = {
      {1.f, 2.f, 3.f},  // row 0
      {4.f, 5.f, 6.f},  // row 1
  };

  XlaOp tuple_data =
      Tuple(&builder, {ConstantR1<float>(&builder, vector_values),
                       ConstantR2<float>(&builder, matrix_values)});
  XlaOp tuple01 = Tuple(&builder, {GetTupleElement(tuple_data, 0),
                                   GetTupleElement(tuple_data, 1)});
  XlaOp tuple10 = Tuple(&builder, {GetTupleElement(tuple_data, 1),
                                   GetTupleElement(tuple_data, 0)});

  XlaOp v_from_01 = GetTupleElement(tuple01, 0);
  XlaOp v_from_10 = GetTupleElement(tuple10, 1);
  XlaOp m_from_01 = GetTupleElement(tuple01, 1);
  XlaOp m_from_10 = GetTupleElement(tuple10, 0);

  XlaOp vector_sum = Add(v_from_01, v_from_10);
  XlaOp matrix_sum = Add(m_from_01, m_from_10);
  Add(matrix_sum, vector_sum, /*broadcast_dimensions=*/{1});

  // 2*m + broadcast(2*v).
  const Array2D<float> expected({
      {4.f, 8.f, 12.f},    // row 0
      {10.f, 14.f, 18.f},  // row 1
  });
  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
277
XLA_TEST_F(TupleTest, SelectBetweenTuplesOnFalse) {
  // A false predicate must pick the on_false tuple (tuple21).
  XlaBuilder builder(TestName());

  std::initializer_list<float> vec1 = {1.f, 2.f, 3.f};
  std::initializer_list<float> vec2 = {2.f, 4.f, 6.f};
  XlaOp tuple12 = Tuple(&builder, {ConstantR1<float>(&builder, vec1),
                                   ConstantR1<float>(&builder, vec2)});
  XlaOp tuple21 = Tuple(&builder, {ConstantR1<float>(&builder, vec2),
                                   ConstantR1<float>(&builder, vec1)});

  XlaOp predicate = ConstantR0<bool>(&builder, false);
  Select(predicate, tuple12, tuple21);

  auto expected = LiteralUtil::MakeTupleFromSlices({
      LiteralUtil::CreateR1<float>(vec2),
      LiteralUtil::CreateR1<float>(vec1),
  });
  ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
}
294
XLA_TEST_F(TupleTest, TuplesInAMap) {
  // Mapped function: f(x) = 100 * min(x, x^2) + max(x, x^2). min and max come
  // from selecting between two tuples. The select is there so that HLO-level
  // optimizations cannot optimize the tuples away.
  XlaComputation tuple_computation;
  {
    XlaBuilder sub("sort_square");
    XlaOp x = Parameter(&sub, 0, ShapeUtil::MakeShape(F32, {}), "x");
    XlaOp x_squared = Mul(x, x);
    XlaOp x_first = Tuple(&sub, {x, x_squared});
    XlaOp square_first = Tuple(&sub, {x_squared, x});
    XlaOp ordered = Select(Lt(x, x_squared), x_first, square_first);
    XlaOp low = GetTupleElement(ordered, 0);
    XlaOp high = GetTupleElement(ordered, 1);
    Add(high, Mul(ConstantR0<float>(&sub, 100.0f), low));

    auto built = sub.Build();
    ASSERT_IS_OK(built.status());
    tuple_computation = built.ConsumeValueOrDie();
  }

  // f(-1) = -99, f(1) = 101, f(2.1) = 214.41.
  XlaBuilder builder(TestName());
  XlaOp input = ConstantR1<float>(&builder, {-1.0f, 1.0f, 2.1f});
  Map(&builder, {input}, tuple_computation, {0});
  ComputeAndCompareR1<float>(&builder, {-99.0f, 101.0f, 214.41f}, {},
                             error_spec_);
}
321
XLA_TEST_F(TupleTest, SelectBetweenTuplesOnTrue) {
  // A true predicate must pick the on_true tuple (tuple12).
  XlaBuilder builder(TestName());

  std::initializer_list<float> vec1 = {1.f, 2.f, 3.f};
  std::initializer_list<float> vec2 = {2.f, 4.f, 6.f};
  XlaOp tuple12 = Tuple(&builder, {ConstantR1<float>(&builder, vec1),
                                   ConstantR1<float>(&builder, vec2)});
  XlaOp tuple21 = Tuple(&builder, {ConstantR1<float>(&builder, vec2),
                                   ConstantR1<float>(&builder, vec1)});

  XlaOp predicate = ConstantR0<bool>(&builder, true);
  Select(predicate, tuple12, tuple21);

  auto expected = LiteralUtil::MakeTupleFromSlices({
      LiteralUtil::CreateR1<float>(vec1),
      LiteralUtil::CreateR1<float>(vec2),
  });
  ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
}
338
XLA_TEST_F(TupleTest, SelectBetweenTuplesElementResult) {
  // The select operates on whole tuples, but the computation returns only
  // one element of the selected tuple.
  XlaBuilder builder(TestName());

  std::initializer_list<float> vec1 = {1.f, 2.f, 3.f};
  std::initializer_list<float> vec2 = {2.f, 4.f, 6.f};
  XlaOp tuple12 = Tuple(&builder, {ConstantR1<float>(&builder, vec1),
                                   ConstantR1<float>(&builder, vec2)});
  XlaOp tuple21 = Tuple(&builder, {ConstantR1<float>(&builder, vec2),
                                   ConstantR1<float>(&builder, vec1)});

  XlaOp chosen = Select(ConstantR0<bool>(&builder, false), tuple12, tuple21);
  // false selects tuple21, whose element 0 is vec2.
  GetTupleElement(chosen, 0);

  ComputeAndCompareR1<float>(&builder, vec2, {}, error_spec_);
}
356
357 // Cascaded selects between tuple types.
XLA_TEST_F(TupleTest, SelectBetweenTuplesCascaded) {
  // Two chained tuple selects whose predicates come from a tuple:
  //   pred_tuple = (true, false)
  //   select1    = GTE(pred_tuple, 0) ? tuple12 : tuple21   -> tuple12
  //   select2    = GTE(pred_tuple, 1) ? tuple21 : select1   -> tuple12
  //   result     = GTE(select2, 0) + GTE(select2, 1)        -> vec1 + vec2
  XlaBuilder builder(TestName());

  std::initializer_list<float> vec1 = {1.f, 2.f, 3.f};
  std::initializer_list<float> vec2 = {2.f, 4.f, 6.f};

  XlaOp pred_tuple = Tuple(&builder, {ConstantR0<bool>(&builder, true),
                                      ConstantR0<bool>(&builder, false)});
  XlaOp tuple12 = Tuple(&builder, {ConstantR1<float>(&builder, vec1),
                                   ConstantR1<float>(&builder, vec2)});
  XlaOp tuple21 = Tuple(&builder, {ConstantR1<float>(&builder, vec2),
                                   ConstantR1<float>(&builder, vec1)});

  XlaOp first_pred = GetTupleElement(pred_tuple, 0);
  XlaOp select1 = Select(first_pred, tuple12, tuple21);
  XlaOp second_pred = GetTupleElement(pred_tuple, 1);
  XlaOp select2 = Select(second_pred, tuple21, select1);

  XlaOp lhs = GetTupleElement(select2, 0);
  XlaOp rhs = GetTupleElement(select2, 1);
  Add(lhs, rhs);

  ComputeAndCompareR1<float>(&builder, {3.f, 6.f, 9.f}, {}, error_spec_);
}
394
XLA_TEST_F(TupleTest, SelectBetweenTuplesReuseConstants) {
  // Same as SelectBetweenTuplesOnFalse, except both input tuples share the
  // same two constant instructions.
  XlaBuilder builder(TestName());

  std::initializer_list<float> vec1 = {1.f, 2.f, 3.f};
  std::initializer_list<float> vec2 = {2.f, 4.f, 6.f};
  XlaOp shared1 = ConstantR1<float>(&builder, vec1);
  XlaOp shared2 = ConstantR1<float>(&builder, vec2);
  XlaOp tuple12 = Tuple(&builder, {shared1, shared2});
  XlaOp tuple21 = Tuple(&builder, {shared2, shared1});

  XlaOp predicate = ConstantR0<bool>(&builder, false);
  Select(predicate, tuple12, tuple21);

  auto expected = LiteralUtil::MakeTupleFromSlices({
      LiteralUtil::CreateR1<float>(vec2),
      LiteralUtil::CreateR1<float>(vec1),
  });
  ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
}
413
XLA_TEST_F(TupleTest, NestedTuples) {
  XlaBuilder builder(TestName());

  // Result shape: ((f32[2], f32[]), f32[2]).
  XlaOp inner = Tuple(&builder, {ConstantR1<float>(&builder, {1.0, 2.0}),
                                 ConstantR0<float>(&builder, 42.0)});
  XlaOp trailing_vector = ConstantR1<float>(&builder, {22.0, 44.0});
  Tuple(&builder, {inner, trailing_vector});

  auto inner_vector = LiteralUtil::CreateR1<float>({1.0, 2.0});
  auto inner_scalar = LiteralUtil::CreateR0<float>(42.0);
  auto inner_expected = LiteralUtil::MakeTuple({&inner_vector, &inner_scalar});
  auto outer_vector = LiteralUtil::CreateR1<float>({22.0, 44.0});
  auto expected = LiteralUtil::MakeTuple({&inner_expected, &outer_vector});

  ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
}
429
XLA_TEST_F(TupleTest, GetTupleElementOfNestedTuple) {
  XlaBuilder builder(TestName());

  // Parameter shape: ((f32[3], f32[3]), f32[3]).
  const Shape f32_vec3 = ShapeUtil::MakeShape(F32, {3});
  const Shape inner_shape = ShapeUtil::MakeTupleShape({f32_vec3, f32_vec3});
  const Shape outer_shape = ShapeUtil::MakeTupleShape({inner_shape, f32_vec3});

  // Result: input[0][1] + {10, 11, 12}.
  XlaOp input = Parameter(&builder, 0, outer_shape, "input");
  XlaOp inner = GetTupleElement(input, 0);
  XlaOp inner_second = GetTupleElement(inner, 1);
  Add(inner_second, ConstantR1<float>(&builder, {10.0, 11.0, 12.0}));

  auto argument_literal = LiteralUtil::MakeTupleFromSlices({
      LiteralUtil::MakeTupleFromSlices({
          LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0}),
          LiteralUtil::CreateR1<float>({4.0, 5.0, 6.0}),
      }),
      LiteralUtil::CreateR1<float>({7.0, 8.0, 9.0}),
  });
  std::unique_ptr<GlobalData> data =
      client_->TransferToServer(argument_literal).ConsumeValueOrDie();

  std::vector<GlobalData*> arguments = {data.get()};
  const std::vector<float> expected = {14.0f, 16.0f, 18.0f};
  ComputeAndCompareR1<float>(&builder, expected, arguments, ErrorSpec(1e-5));
}
458
XLA_TEST_F(TupleTest, ComplexTuples) {
  XlaBuilder builder(TestName());
  {
    // input0: (c64[], (c64[2], c64[3,2])), input1: c64[2].
    const Shape c64_scalar = ShapeUtil::MakeShape(C64, {});
    const Shape c64_vec2 = ShapeUtil::MakeShape(C64, {2});
    const Shape c64_mat32 = ShapeUtil::MakeShape(C64, {3, 2});
    const Shape arg0_shape = ShapeUtil::MakeTupleShape(
        {c64_scalar, ShapeUtil::MakeTupleShape({c64_vec2, c64_mat32})});

    XlaOp input0 = Parameter(&builder, 0, arg0_shape, "input0");
    XlaOp scalar_part = GetTupleElement(input0, 0);
    XlaOp nested_part = GetTupleElement(input0, 1);
    XlaOp nested_vec = GetTupleElement(nested_part, 0);
    XlaOp nested_mat = GetTupleElement(nested_part, 1);

    // sum[i][j] = nested_vec[j] + nested_mat[i][j] + scalar.
    XlaOp sum = Add(Add(nested_vec, nested_mat, {1}), scalar_part);
    XlaOp input1 = Parameter(&builder, 1, c64_vec2, "input1");
    // prod[i][j] = input1[j] * sum[i][j].
    XlaOp prod = Mul(input1, sum, {1});

    // Result: ((prod, sum), 123+456i).
    XlaOp inner_result = Tuple(&builder, {prod, sum});
    XlaOp trailing_constant = ConstantR0<complex64>(&builder, {123, 456});
    Tuple(&builder, {inner_result, trailing_constant});
  }

  auto arg0_literal = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR0<complex64>({1, 2}),
       LiteralUtil::MakeTupleFromSlices(
           {LiteralUtil::CreateR1<complex64>({{10, 20}, {30, 40}}),
            LiteralUtil::CreateR2<complex64>(
                {{{100, 200}, {300, 400}},
                 {{1000, 2000}, {3000, 4000}},
                 {{10000, 20000}, {30000, 40000}}})})});
  std::unique_ptr<GlobalData> arg0 =
      client_->TransferToServer(arg0_literal).ConsumeValueOrDie();

  auto arg1_literal = LiteralUtil::CreateR1<complex64>({{1, 2}, {1, -2}});
  std::unique_ptr<GlobalData> arg1 =
      client_->TransferToServer(arg1_literal).ConsumeValueOrDie();

  auto sum = LiteralUtil::CreateR2<complex64>({{{111, 222}, {331, 442}},
                                               {{1011, 2022}, {3031, 4042}},
                                               {{10011, 20022}, {30031, 40042}}});

  // prod[i][j] = sum[i][j] * input1[j], with input1 = {1+2i, 1-2i}.
  Literal prod(sum.shape());
  const auto populate_status =
      prod.Populate<complex64>([&sum](absl::Span<const int64> indexes) {
        const complex64 factor =
            indexes.back() == 0 ? complex64(1, 2) : complex64(1, -2);
        return sum.Get<complex64>(indexes) * factor;
      });
  ASSERT_TRUE(populate_status.ok());

  auto expected = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::MakeTupleFromSlices({prod, sum}),
       LiteralUtil::CreateR0<complex64>({123, 456})});
  ComputeAndCompareTuple(&builder, expected, {arg0.get(), arg1.get()},
                         error_spec_);
}
513
514 class TupleHloTest : public HloTestBase {};
515
// Bitcasts an f32[3] get-tuple-element result to f32[1,3], copies it into the
// output tuple, and runs the module exactly as written (no HLO passes).
XLA_TEST_F(TupleHloTest, BitcastAfterGTE) {
  const char* testcase = R"(
    HloModule m, is_scheduled=true

    ENTRY test {
      name.1 = (f32[3]{0}) parameter(0)
      get-tuple-element.1 = f32[3]{0} get-tuple-element(name.1), index=0
      bitcast = f32[1,3]{1,0} bitcast(get-tuple-element.1)
      copy = f32[1,3]{1,0} copy(bitcast)
      ROOT tuple.4 = (f32[1,3]{1,0}) tuple(copy)
    }
  )";
  auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie();
  auto param =
      LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1<float>({1, 2, 3}));
  // Arguments are passed as pointers to the literals.
  auto result = ExecuteNoHloPasses(std::move(module), {&param});
  EXPECT_TRUE(LiteralTestUtil::Equal(
      LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR2<float>({{1, 2, 3}})),
      result));
}
536
537 // Disabled on interpreter due to lack of outfeed.
XLA_TEST_F(TupleHloTest,
           DISABLED_ON_INTERPRETER(NonAmbiguousTopLevelAllocation)) {
  // A tuple-select chooses between two parameter tuples. Element 0 of the
  // chosen tuple is re-wrapped in a new tuple and sent through outfeed.
  // Arguments: a = param0, b = param1, c = param1, d = param0, cond = false.
  // So the select yields tup1 = (c, d), and the outfeed carries ({2, 3}).
  // NOTE(review): presumably this guards against the output allocation being
  // ambiguous between tup0 and tup1; confirm against buffer assignment.
  const char* testcase = R"(
    HloModule tuple

    ENTRY main {
      a = f32[2] parameter(0)
      b = f32[2] parameter(1)
      c = f32[2] parameter(2)
      d = f32[2] parameter(3)
      cond = pred[] parameter(4)

      tup0 = (f32[2],f32[2]) tuple(a, b)
      tup1 = (f32[2],f32[2]) tuple(c, d)

      s = (f32[2],f32[2]) tuple-select(cond, tup0, tup1)
      gte = f32[2] get-tuple-element(s), index=0
      tuple = (f32[2]) tuple(gte)
      token0 = token[] after-all()
      ROOT outfeed = token[] outfeed(tuple, token0)
    }
  )";
  auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie();
  auto param0 = LiteralUtil::CreateR1<float>({1, 2});
  auto param1 = LiteralUtil::CreateR1<float>({2, 3});
  auto param4 = LiteralUtil::CreateR0<bool>(false);
  // Run execution on a separate thread so this thread can block on the
  // outfeed. The lambda captures locals by reference. They are all declared
  // before `thread`, so they outlive it.
  std::unique_ptr<tensorflow::Thread> thread(
      tensorflow::Env::Default()->StartThread(
          tensorflow::ThreadOptions(), "execute_thread", [&] {
            TF_EXPECT_OK(Execute(std::move(module),
                                 {&param0, &param1, &param1, &param0, &param4})
                             .status());
          }));
  auto expected =
      LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1<float>({2, 3}));
  auto literal = Literal::CreateFromShape(expected.shape());
  TF_EXPECT_OK(backend().transfer_manager()->TransferLiteralFromOutfeed(
      backend().default_stream_executor(), expected.shape(), literal));
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, literal));
}
579
580 } // namespace
581 } // namespace xla
582