1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <memory>
17 #include <string>
18 #include <vector>
19
20 #include "tensorflow/compiler/xla/client/client_library.h"
21 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
22 #include "tensorflow/compiler/xla/client/local_client.h"
23 #include "tensorflow/compiler/xla/client/xla_builder.h"
24 #include "tensorflow/compiler/xla/client/xla_computation.h"
25 #include "tensorflow/compiler/xla/literal.h"
26 #include "tensorflow/compiler/xla/service/platform_util.h"
27 #include "tensorflow/compiler/xla/shape_util.h"
28 #include "tensorflow/compiler/xla/status_macros.h"
29 #include "tensorflow/compiler/xla/statusor.h"
30 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
31 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
32 #include "tensorflow/compiler/xla/tests/test_macros.h"
33 #include "tensorflow/compiler/xla/xla_data.pb.h"
34 #include "tensorflow/core/lib/core/status_test_util.h"
35 #include "tensorflow/core/platform/logging.h"
36 #include "tensorflow/core/platform/test.h"
37 #include "tensorflow/core/platform/test_benchmark.h"
38 #include "tensorflow/core/platform/types.h"
39
40 namespace xla {
41 namespace {
42
43 class WhileTest : public ClientLibraryTestBase {};
44
45 // Tests a while node when the result type T is S32.
46 //
47 // int32 result = 0;
48 // while (result < 5) {
49 // result = result + 1;
50 // }
XLA_TEST_F(WhileTest,WhileWithScalarS32Result)51 XLA_TEST_F(WhileTest, WhileWithScalarS32Result) {
52 auto result_shape = ShapeUtil::MakeShape(S32, {});
53
54 // Create a computation for the condition: repeat for 5 iterations.
55 XlaComputation condition;
56 {
57 XlaBuilder builder("condition");
58 auto prev = Parameter(&builder, 0, result_shape, "prev");
59 Gt(ConstantR0<int32>(&builder, 5), prev);
60 condition = builder.Build().ConsumeValueOrDie();
61 }
62
63 // Create a computation for the body: add 1 to the result variable.
64 XlaComputation body;
65 {
66 XlaBuilder builder("body");
67 auto prev = Parameter(&builder, 0, result_shape, "prev");
68 auto input = ConstantR0<int32>(&builder, 1);
69 Add(input, prev);
70 body = builder.Build().ConsumeValueOrDie();
71 }
72
73 // Create a While node with computations for the condition and the body.
74 XlaBuilder builder(TestName());
75 auto init = ConstantR0<int32>(&builder, 0);
76 While(condition, body, init);
77
78 ComputeAndCompareR0<int32>(&builder, 5, {});
79 }
80
81 // Tests a while node when the result type T is S64.
82 //
83 // int32 result = 0;
84 // while (result < 5) {
85 // result = result + 1;
86 // }
XLA_TEST_F(WhileTest,WhileWithScalarS64Result)87 XLA_TEST_F(WhileTest, WhileWithScalarS64Result) {
88 auto result_shape = ShapeUtil::MakeShape(S64, {});
89
90 // Create a computation for the condition: repeat for 5 iterations.
91 XlaComputation condition;
92 {
93 XlaBuilder builder("condition");
94 auto prev = Parameter(&builder, 0, result_shape, "prev");
95 Gt(ConstantR0<int64>(&builder, 5), prev);
96 condition = builder.Build().ConsumeValueOrDie();
97 }
98
99 // Create a computation for the body: add 1 to the result variable.
100 XlaComputation body;
101 {
102 XlaBuilder builder("body");
103 auto prev = Parameter(&builder, 0, result_shape, "prev");
104 auto input = ConstantR0<int64>(&builder, 1);
105 Add(input, prev);
106 body = builder.Build().ConsumeValueOrDie();
107 }
108
109 // Create a While node with computations for the condition and the body.
110 XlaBuilder builder(TestName());
111 auto init = ConstantR0<int64>(&builder, 0);
112 While(condition, body, init);
113
114 ComputeAndCompareR0<int64>(&builder, 5, {});
115 }
116
// Tests a scalar S32 while node whose initial value is computed by a Reduce
// instead of being a plain constant.
XLA_TEST_F(WhileTest, WhileWithScalarResultNonConstInit) {
  auto result_shape = ShapeUtil::MakeShape(S32, {});

  // Condition: keep looping while 5 > prev.
  XlaComputation condition;
  {
    XlaBuilder builder("condition");
    auto prev = Parameter(&builder, 0, result_shape, "prev");
    Gt(ConstantR0<int32>(&builder, 5), prev);
    // Surface a build error as a test failure instead of aborting, as the
    // other tests in this file that use TF_ASSERT_OK_AND_ASSIGN do.
    TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build());
  }

  // Body: add 1 to the loop-carried scalar.
  XlaComputation body;
  {
    XlaBuilder builder("body");
    auto prev = Parameter(&builder, 0, result_shape, "prev");
    auto input = ConstantR0<int32>(&builder, 1);
    Add(input, prev);
    TF_ASSERT_OK_AND_ASSIGN(body, builder.Build());
  }

  // Create a While node with computations for the condition and the body.
  // The initial value is the sum of a two-element vector of ones, so the loop
  // starts at 2 and stops once the value reaches 5.
  XlaBuilder builder(TestName());
  auto init =
      Reduce(ConstantR1<int32>(&builder, 2, 1), ConstantR0<int32>(&builder, 0),
             CreateScalarAddComputation(S32, &builder), {0});
  While(condition, body, init);

  ComputeAndCompareR0<int32>(&builder, 5, {});
}
149
// Tests a while node whose loop-carried value is a scalar PRED.
XLA_TEST_F(WhileTest, WhileWithPredicateResult) {
  auto result_shape = ShapeUtil::MakeShape(PRED, {});

  // Condition: keep looping while prev != true.
  XlaComputation condition;
  {
    XlaBuilder builder("condition");
    auto prev = Parameter(&builder, 0, result_shape, "prev");
    Ne(ConstantR0<bool>(&builder, true), prev);
    // Surface a build error as a test failure instead of aborting, as the
    // other tests in this file that use TF_ASSERT_OK_AND_ASSIGN do.
    TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build());
  }

  // Body: OR the loop-carried predicate with true.
  XlaComputation body;
  {
    XlaBuilder builder("body");
    auto prev = Parameter(&builder, 0, result_shape, "prev");
    Or(prev, ConstantR0<bool>(&builder, true));
    TF_ASSERT_OK_AND_ASSIGN(body, builder.Build());
  }

  // Create a While node with computations for the condition and the body.
  // The initial value is the computed predicate (false != true).
  XlaBuilder builder(TestName());
  auto init =
      Ne(ConstantR0<bool>(&builder, false), ConstantR0<bool>(&builder, true));
  While(condition, body, init);

  ComputeAndCompareR0<bool>(&builder, true, {});
}
179
180 // Tests a while node when the result type T is a vector.
181 //
182 // All constants are chosen to produce exact results.
183 // vector<float> result(0);
184 // while (result.sum() < 15.5f) {
185 // result = result + vector<float>(0);
186 // }
XLA_TEST_F(WhileTest,DISABLED_ON_INTERPRETER (WhileWithEmptyVectorResult))187 XLA_TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithEmptyVectorResult)) {
188 Shape result_shape = ShapeUtil::MakeShape(F32, {0});
189
190 // Create a computation for the reduction.
191 XlaComputation add;
192 {
193 XlaBuilder builder("add");
194 auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x");
195 auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {}), "y");
196 Add(x, y);
197 add = builder.Build().ConsumeValueOrDie();
198 }
199
200 // Create a computation for the condition.
201 // Repeat until the sum of the result vector is less than 15.5f.
202 XlaComputation condition;
203 {
204 XlaBuilder builder("condition");
205 auto prev = Parameter(&builder, 0, result_shape, "prev");
206 auto sum = Reduce(prev, ConstantR0<float>(&builder, 0.0f), add,
207 /*dimensions_to_reduce=*/{0});
208 Gt(ConstantR0<float>(&builder, 15.5f), sum);
209 condition = builder.Build().ConsumeValueOrDie();
210 }
211
212 // Create a computation for the body.
213 // Add a constant vector of 1.f to the result vector.
214 XlaComputation body;
215 {
216 XlaBuilder builder("body");
217 auto prev = Parameter(&builder, 0, result_shape, "prev");
218 auto input = ConstantR1<float>(&builder, {});
219 Add(input, prev);
220 body = builder.Build().ConsumeValueOrDie();
221 }
222
223 // Create a While node with computations for the condition and the body.
224 XlaBuilder builder("while");
225 auto init = ConstantR1<float>(&builder, {});
226 auto result = While(condition, body, init);
227 VLOG(2) << "while = "
228 << ShapeUtil::HumanString(
229 builder.GetShape(result).ConsumeValueOrDie());
230
231 ComputeAndCompareR1<float>(&builder, {}, {}, ErrorSpec(0.0001));
232 }
233
234 // Tests a while node when the result type T is a vector.
235 //
236 // All constants are chosen to produce exact results.
237 // vector<float> result(8, 0.0f);
238 // while (result.sum() < 15.5f) {
239 // result = result + vector<float>(8, 0.125f);
240 // }
XLA_TEST_F(WhileTest,WhileWithVectorResult)241 XLA_TEST_F(WhileTest, WhileWithVectorResult) {
242 Shape result_shape = ShapeUtil::MakeShape(F32, {8});
243
244 // Create a computation for the reduction.
245 XlaComputation add;
246 {
247 XlaBuilder builder("add");
248 auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x");
249 auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {}), "y");
250 Add(x, y);
251 add = builder.Build().ConsumeValueOrDie();
252 }
253
254 // Create a computation for the condition.
255 // Repeat until the sum of the result vector is less than 5.5f.
256 XlaComputation condition;
257 {
258 XlaBuilder builder("condition");
259 auto prev = Parameter(&builder, 0, result_shape, "prev");
260 auto sum = Reduce(prev, ConstantR0<float>(&builder, 0.0f), add,
261 /*dimensions_to_reduce=*/{0});
262 Gt(ConstantR0<float>(&builder, 15.5f), sum);
263 condition = builder.Build().ConsumeValueOrDie();
264 }
265
266 // Create a computation for the body.
267 // Add a constant vector of 1.f to the result vector.
268 XlaComputation body;
269 {
270 XlaBuilder builder("body");
271 auto prev = Parameter(&builder, 0, result_shape, "prev");
272 auto input = ConstantR1<float>(&builder, 8, 0.125f);
273 Add(input, prev);
274 body = builder.Build().ConsumeValueOrDie();
275 }
276
277 // Create a While node with computations for the condition and the body.
278 XlaBuilder builder("while");
279 auto init = ConstantR1<float>(&builder, 8, 0.f);
280 auto result = While(condition, body, init);
281 VLOG(2) << "while = "
282 << ShapeUtil::HumanString(
283 builder.GetShape(result).ConsumeValueOrDie());
284
285 // Individual elements with increase by 1/8 each time through the loop, so
286 // the sum will increase by 1.0. It will first be >15.5 when the elements
287 // have all reached 2.0.
288 std::vector<float> expected = {2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f};
289 ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
290 }
291
292 // Tests a while node when the result type is a vector which is part
293 // of the result tuple.
294 //
295 // All constants are chosen to produce exact results.
296 // vector<float> result(8, 0.0f);
297 // while (result.sum() < 15.5f) {
298 // result = result + vector<float>(8, 0.125f);
299 // }
300 // tuple = tuple { while }
XLA_TEST_F(WhileTest,WhileWithVectorResultIntoTuple)301 XLA_TEST_F(WhileTest, WhileWithVectorResultIntoTuple) {
302 Shape result_shape = ShapeUtil::MakeShape(F32, {8});
303
304 // Create a computation for the reduction.
305 XlaComputation add;
306 {
307 XlaBuilder builder("add");
308 auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x");
309 auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {}), "y");
310 Add(x, y);
311 add = builder.Build().ConsumeValueOrDie();
312 }
313
314 // Create a computation for the condition.
315 // Repeat until the sum of the result vector is less than 5.5f.
316 XlaComputation condition;
317 {
318 XlaBuilder builder("condition");
319 auto prev = Parameter(&builder, 0, result_shape, "prev");
320 auto sum = Reduce(prev, ConstantR0<float>(&builder, 0.0f), add,
321 /*dimensions_to_reduce=*/{0});
322 Gt(ConstantR0<float>(&builder, 15.5f), sum);
323 condition = builder.Build().ConsumeValueOrDie();
324 }
325
326 // Create a computation for the body.
327 // Add a constant vector of 1.f to the result vector.
328 XlaComputation body;
329 {
330 XlaBuilder builder("body");
331 auto prev = Parameter(&builder, 0, result_shape, "prev");
332 auto input = ConstantR1<float>(&builder, 8, 0.125f);
333 Add(input, prev);
334 body = builder.Build().ConsumeValueOrDie();
335 }
336
337 // Create a While node with computations for the condition and the body.
338 XlaBuilder builder("while");
339 auto init = ConstantR1<float>(&builder, 8, 0.f);
340 auto result = While(condition, body, init);
341 VLOG(2) << "while = "
342 << ShapeUtil::HumanString(
343 builder.GetShape(result).ConsumeValueOrDie());
344 Tuple(&builder, {result});
345
346 // Individual elements with increase by 1/8 each time through the loop, so
347 // the sum will increase by 1.0. It will first be >15.5 when the elements
348 // have all reached 2.0.
349 auto expected_data =
350 LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f});
351 auto expected = LiteralUtil::MakeTuple({&expected_data});
352 VLOG(2) << "expected = " << ShapeUtil::HumanString(expected.shape());
353 ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.0001));
354 }
355
// Tests a while node whose body rotates three vector tuple elements on every
// iteration and whose result is returned as a tuple.
XLA_TEST_F(WhileTest, WhileWithPermutationAndTupleResult) {
  std::vector<Shape> shape_elements = {
      ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {3}),
      ShapeUtil::MakeShape(F32, {3}), ShapeUtil::MakeShape(F32, {3})};
  Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements);

  // Condition: keep looping while N > counter, i.e. for N iterations.
  const int N = 2;
  XlaComputation condition;
  {
    XlaBuilder builder("condition");
    auto prev = Parameter(&builder, 0, result_shape, "prev");
    auto iteration = GetTupleElement(prev, 0);
    Gt(ConstantR0<int32>(&builder, N), iteration);
    // Surface a build error as a test failure instead of aborting, as the
    // other tests in this file that use TF_ASSERT_OK_AND_ASSIGN do.
    TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build());
  }

  // Body: add 1 to the counter and permute the weights
  // (w1, w2, w3) -> (w3, w1, w2).
  XlaComputation body;
  {
    XlaBuilder builder("body");
    auto prev = Parameter(&builder, 0, result_shape, "prev");
    auto iteration = GetTupleElement(prev, 0);
    auto w1 = GetTupleElement(prev, 1);
    auto w2 = GetTupleElement(prev, 2);
    auto w3 = GetTupleElement(prev, 3);
    Tuple(&builder,
          {Add(iteration, ConstantR0<int32>(&builder, 1)), w3, w1, w2});
    TF_ASSERT_OK_AND_ASSIGN(body, builder.Build());
  }

  // Create a While node with computations for the condition and the body.
  XlaBuilder builder("while");
  auto init = Tuple(&builder, {ConstantR0<int32>(&builder, 0),
                               ConstantR1<float>(&builder, 3, 1.f),
                               ConstantR1<float>(&builder, 3, 2.f),
                               ConstantR1<float>(&builder, 3, 3.f)});
  auto result = While(condition, body, init);
  VLOG(2) << "result = "
          << ShapeUtil::HumanString(
                 builder.GetShape(result).ConsumeValueOrDie());

  // Two rotations of (1, 2, 3) give (2, 3, 1).
  auto expected_counter = LiteralUtil::CreateR0<int32>(N);
  auto expected_w1 = LiteralUtil::CreateR1<float>({1.0f, 1.0f, 1.0f});
  auto expected_w2 = LiteralUtil::CreateR1<float>({2.0f, 2.0f, 2.0f});
  auto expected_w3 = LiteralUtil::CreateR1<float>({3.0f, 3.0f, 3.0f});
  auto expected = LiteralUtil::MakeTuple(
      {&expected_counter, &expected_w2, &expected_w3, &expected_w1});
  VLOG(2) << "expected = " << ShapeUtil::HumanString(expected.shape());
  ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.0001));
}
409
// Tests a while node whose body rotates three vector tuple elements on every
// iteration; the three vectors are summed after the loop, so the result does
// not depend on their order.
XLA_TEST_F(WhileTest, WhileWithPermutationAndVectorResult) {
  std::vector<Shape> shape_elements = {
      ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {3}),
      ShapeUtil::MakeShape(F32, {3}), ShapeUtil::MakeShape(F32, {3})};
  Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements);

  // Condition: keep looping while N > counter, i.e. for N iterations.
  const int N = 2;
  XlaComputation condition;
  {
    XlaBuilder builder("condition");
    auto prev = Parameter(&builder, 0, result_shape, "prev");
    auto iteration = GetTupleElement(prev, 0);
    Gt(ConstantR0<int32>(&builder, N), iteration);
    // Surface a build error as a test failure instead of aborting, as the
    // other tests in this file that use TF_ASSERT_OK_AND_ASSIGN do.
    TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build());
  }

  // Body: add 1 to the counter and permute the weights
  // (w1, w2, w3) -> (w3, w1, w2).
  XlaComputation body;
  {
    XlaBuilder builder("body");
    auto prev = Parameter(&builder, 0, result_shape, "prev");
    auto iteration = GetTupleElement(prev, 0);
    auto w1 = GetTupleElement(prev, 1);
    auto w2 = GetTupleElement(prev, 2);
    auto w3 = GetTupleElement(prev, 3);
    Tuple(&builder,
          {Add(iteration, ConstantR0<int32>(&builder, 1)), w3, w1, w2});
    TF_ASSERT_OK_AND_ASSIGN(body, builder.Build());
  }

  // Create a While node with computations for the condition and the body.
  XlaBuilder builder("while");
  auto init = Tuple(&builder, {ConstantR0<int32>(&builder, 0),
                               ConstantR1<float>(&builder, 3, 1.f),
                               ConstantR1<float>(&builder, 3, 2.f),
                               ConstantR1<float>(&builder, 3, 3.f)});
  auto xla_while = While(condition, body, init);

  // Sum the three weight vectors: 1 + 2 + 3 = 6 per element.
  auto add12 =
      Add(GetTupleElement(xla_while, 1), GetTupleElement(xla_while, 2));
  auto result = Add(add12, GetTupleElement(xla_while, 3));
  VLOG(2) << "result = "
          << ShapeUtil::HumanString(
                 builder.GetShape(result).ConsumeValueOrDie());
  std::vector<float> expected = {6.f, 6.f, 6.f};
  ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
460
461 // Tests a while node when the result type T is a Tuple.
462 //
463 // tuple<int32, vector<float>> result(0, vector<float>(10, 0.0f));
464 // while (get<0>(result) < 5) {
465 // get<0>(result) = get<0>(result) + 1;
466 // get<1>(result) = get<1>(result) + vector<float>(10, 1.0f);
467 // }
XLA_TEST_F(WhileTest,WhileWithTupleResult)468 XLA_TEST_F(WhileTest, WhileWithTupleResult) {
469 std::vector<Shape> shape_elements = {ShapeUtil::MakeShape(S32, {}),
470 ShapeUtil::MakeShape(F32, {10})};
471 Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements);
472
473 // Create a computation for the condition.
474 // Repeat for 5 iterations.
475 XlaComputation condition;
476 {
477 XlaBuilder builder("condition");
478 auto prev = Parameter(&builder, 0, result_shape, "prev");
479 auto iteration = GetTupleElement(prev, 0);
480 Gt(ConstantR0<int32>(&builder, 5), iteration);
481 condition = builder.Build().ConsumeValueOrDie();
482 }
483
484 // Create a computation for the body.
485 // Add 1 to the iteration variable and add a constant vector of 1.0f to
486 // the weight variable, both of which are tuple elements.
487 XlaComputation body;
488 {
489 XlaBuilder builder("body");
490 auto prev = Parameter(&builder, 0, result_shape, "prev");
491 auto iteration = GetTupleElement(prev, 0);
492 auto weights = GetTupleElement(prev, 1);
493 auto input = ConstantR1<float>(&builder, 10, 1.f);
494 auto new_weights = Add(weights, input);
495 Tuple(&builder,
496 {Add(iteration, ConstantR0<int32>(&builder, 1)), new_weights});
497 body = builder.Build().ConsumeValueOrDie();
498 }
499
500 // Create a While node with computations for the condition and the body.
501 XlaBuilder builder("while");
502 auto init = Tuple(&builder, {ConstantR0<int32>(&builder, 0),
503 ConstantR1<float>(&builder, 10, 0.f)});
504 auto result = While(condition, body, init);
505 VLOG(2) << "while = "
506 << ShapeUtil::HumanString(
507 builder.GetShape(result).ConsumeValueOrDie());
508
509 auto expected_counter = LiteralUtil::CreateR0<int32>(5);
510 auto expected_data = LiteralUtil::CreateR1<float>(
511 {5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f});
512 auto expected = LiteralUtil::MakeTuple({&expected_counter, &expected_data});
513 VLOG(2) << "expected = " << ShapeUtil::HumanString(expected.shape());
514 ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.0001));
515 }
516
// Tests a while node whose loop state is a tuple of an S32 counter and a
// scalar PRED.
XLA_TEST_F(WhileTest, WhileWithPredicateTupleResult) {
  std::vector<Shape> shape_elements = {ShapeUtil::MakeShape(S32, {}),
                                       ShapeUtil::MakeShape(PRED, {})};
  Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements);

  // Condition: keep looping while 5 > counter, i.e. for 5 iterations.
  XlaComputation condition;
  {
    XlaBuilder builder("condition");
    auto prev = Parameter(&builder, 0, result_shape, "prev");
    auto iteration = GetTupleElement(prev, 0);
    Gt(ConstantR0<int32>(&builder, 5), iteration);
    // Surface a build error as a test failure instead of aborting, as the
    // other tests in this file that use TF_ASSERT_OK_AND_ASSIGN do.
    TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build());
  }

  // Body: add 1 to the counter and OR the predicate with true.
  XlaComputation body;
  {
    XlaBuilder builder("body");
    auto prev = Parameter(&builder, 0, result_shape, "prev");
    auto iteration = GetTupleElement(prev, 0);
    auto pred = GetTupleElement(prev, 1);
    auto new_pred = Or(pred, ConstantR0<bool>(&builder, true));
    Tuple(&builder, {Add(iteration, ConstantR0<int32>(&builder, 1)), new_pred});
    TF_ASSERT_OK_AND_ASSIGN(body, builder.Build());
  }

  // Create a While node with computations for the condition and the body.
  // The initial predicate is the computed value (false != true).
  XlaBuilder builder("while");
  auto init = Tuple(&builder, {ConstantR0<int32>(&builder, 0),
                               Ne(ConstantR0<bool>(&builder, false),
                                  ConstantR0<bool>(&builder, true))});
  auto result = While(condition, body, init);
  VLOG(2) << "while = "
          << ShapeUtil::HumanString(
                 builder.GetShape(result).ConsumeValueOrDie());

  auto expected_counter = LiteralUtil::CreateR0<int32>(5);
  auto expected_predicate = LiteralUtil::CreateR0<bool>(true);
  auto expected =
      LiteralUtil::MakeTuple({&expected_counter, &expected_predicate});
  ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0));
}
562
// Tests a while node whose loop state is a tuple of two S32 scalars, where
// the body overwrites the second element with a constant on every iteration.
XLA_TEST_F(WhileTest, WhileWithTupleConstantScalarResult) {
  std::vector<Shape> shape_elements = {ShapeUtil::MakeShape(S32, {}),
                                       ShapeUtil::MakeShape(S32, {})};
  Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements);

  // Condition: keep looping while 5 > counter, i.e. for 5 iterations.
  XlaComputation condition;
  {
    XlaBuilder builder("condition");
    auto prev = Parameter(&builder, 0, result_shape, "prev");
    auto iteration = GetTupleElement(prev, 0);
    Gt(ConstantR0<int32>(&builder, 5), iteration);
    // Surface a build error as a test failure instead of aborting, as the
    // other tests in this file that use TF_ASSERT_OK_AND_ASSIGN do.
    TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build());
  }

  // Body: add 1 to the counter and set the other tuple element to the
  // constant 7.
  XlaComputation body;
  {
    XlaBuilder builder("body");
    auto prev = Parameter(&builder, 0, result_shape, "prev");
    auto iteration = GetTupleElement(prev, 0);
    Tuple(&builder, {Add(iteration, ConstantR0<int32>(&builder, 1)),
                     ConstantR0<int32>(&builder, 7)});
    TF_ASSERT_OK_AND_ASSIGN(body, builder.Build());
  }

  // Create a While node with computations for the condition and the body.
  XlaBuilder builder("while");
  auto init = Tuple(&builder, {ConstantR0<int32>(&builder, 0),
                               ConstantR0<int32>(&builder, 7)});
  auto result = While(condition, body, init);
  VLOG(2) << "while = "
          << ShapeUtil::HumanString(
                 builder.GetShape(result).ConsumeValueOrDie());

  auto expected_counter = LiteralUtil::CreateR0<int32>(5);
  auto expected_data = LiteralUtil::CreateR0<int32>(7);
  auto expected = LiteralUtil::MakeTuple({&expected_counter, &expected_data});
  VLOG(2) << "expected = " << ShapeUtil::HumanString(expected.shape());
  ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.0001));
}
607
608 // Tests two while nodes when the result type T is a Tuple and the second
609 // while node uses the result of the first while node which is used in two
610 // nodes.
611 // tuple<int32, vector<float>> w0(0, vector<float>(10, 0.0f));
612 // w0 = while (get<0>(w0) < c1) {
613 // get<0>(w0) = get<0>(w0) + 1;
614 // get<1>(w0) = get<1>(w0) + vector<float>(10, 1.0f);
615 // }
616 // tuple<int32, vector<float>> w1(get<0>(w0), get<1>(w0));
617 // w1 = while (get<0>(w1) < c2) {
618 // get<0>(w1) = get<0>(w1) + 1;
619 // get<1>(w1) = get<1>(w1) + vector<float>(10, 1.0f);
620 // }
621 // result = get<1>(w0) + get<1>(w1)
XLA_TEST_F(WhileTest,TwoWhileWithTupleResult)622 XLA_TEST_F(WhileTest, TwoWhileWithTupleResult) {
623 std::vector<Shape> shape_elements = {ShapeUtil::MakeShape(S32, {}),
624 ShapeUtil::MakeShape(F32, {10})};
625 Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements);
626
627 // Create a computation for the condition.
628 // Repeat for 5 iterations.
629 XlaComputation condition;
630 const int c1 = 5;
631 {
632 XlaBuilder builder("condition");
633 auto prev = Parameter(&builder, 0, result_shape, "prev");
634 auto iteration = GetTupleElement(prev, 0);
635 Lt(iteration, ConstantR0<int32>(&builder, c1));
636 TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build());
637 }
638
639 XlaComputation condition2;
640 const int c2 = 7;
641 {
642 XlaBuilder builder("condition2");
643 auto prev = Parameter(&builder, 0, result_shape, "prev");
644 auto iteration = GetTupleElement(prev, 0);
645 Lt(iteration, ConstantR0<int32>(&builder, c2));
646 TF_ASSERT_OK_AND_ASSIGN(condition2, builder.Build());
647 }
648
649 // Create a computation for the body.
650 // Add 1 to the iteration variable and add a constant vector of 1.0f to
651 // the weight variable, both of which are tuple elements.
652 XlaComputation body;
653 {
654 XlaBuilder builder("body");
655 auto prev = Parameter(&builder, 0, result_shape, "prev");
656 auto iteration = GetTupleElement(prev, 0);
657 auto weights = GetTupleElement(prev, 1);
658 auto input = ConstantR1<float>(&builder, 10, 1.f);
659 auto new_weights = Add(weights, input);
660 Tuple(&builder,
661 {Add(iteration, ConstantR0<int32>(&builder, 1)), new_weights});
662 TF_ASSERT_OK_AND_ASSIGN(body, builder.Build());
663 }
664
665 XlaComputation body2;
666 {
667 XlaBuilder builder("body");
668 auto prev = Parameter(&builder, 0, result_shape, "prev");
669 auto iteration = GetTupleElement(prev, 0);
670 auto weights = GetTupleElement(prev, 1);
671 auto input = ConstantR1<float>(&builder, 10, 1.f);
672 auto new_weights = Add(weights, input);
673 Tuple(&builder,
674 {Add(iteration, ConstantR0<int32>(&builder, 1)), new_weights});
675 TF_ASSERT_OK_AND_ASSIGN(body2, builder.Build());
676 }
677
678 // Create a While node with computations for the condition and the body.
679 XlaBuilder builder("while");
680 auto init = Tuple(&builder, {ConstantR0<int32>(&builder, 0),
681 ConstantR1<float>(&builder, 10, 0.f)});
682 auto while1 = While(condition, body, init);
683
684 auto while2 = While(condition2, body2, while1);
685
686 auto while_result1 = GetTupleElement(while1, 1);
687 auto while_result2 = GetTupleElement(while2, 1);
688 VLOG(2) << "while_result2 = "
689 << ShapeUtil::HumanString(
690 builder.GetShape(while_result2).ConsumeValueOrDie());
691 auto result = Add(while_result1, while_result2);
692 VLOG(2) << "result = "
693 << ShapeUtil::HumanString(
694 builder.GetShape(result).ConsumeValueOrDie());
695 const float sum = c1 + c2;
696 std::vector<float> expected(10, sum);
697 ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
698 }
699
700 // Test while nodes that share the while body computation.
XLA_TEST_F(WhileTest,TwoWhileLoopsAndSharedBody)701 XLA_TEST_F(WhileTest, TwoWhileLoopsAndSharedBody) {
702 std::vector<Shape> shape_elements = {ShapeUtil::MakeShape(S32, {}),
703 ShapeUtil::MakeShape(F32, {10})};
704 Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements);
705
706 // Create a computation for the condition.
707 // Repeat for 5 iterations.
708 XlaComputation condition;
709 const int c1 = 5;
710 {
711 XlaBuilder builder("condition");
712 auto prev = Parameter(&builder, 0, result_shape, "prev");
713 auto iteration = GetTupleElement(prev, 0);
714 Lt(iteration, ConstantR0<int32>(&builder, c1));
715 TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build());
716 }
717
718 XlaComputation condition2;
719 const int c2 = 7;
720 {
721 XlaBuilder builder("condition2");
722 auto prev = Parameter(&builder, 0, result_shape, "prev");
723 auto iteration = GetTupleElement(prev, 0);
724 Lt(iteration, ConstantR0<int32>(&builder, c2));
725 TF_ASSERT_OK_AND_ASSIGN(condition2, builder.Build());
726 }
727
728 // Create a computation for the body.
729 // Add 1 to the iteration variable and add a constant vector of 1.0f to
730 // the weight variable, both of which are tuple elements.
731 XlaComputation body;
732 {
733 XlaBuilder builder("body");
734 auto prev = Parameter(&builder, 0, result_shape, "prev");
735 auto iteration = GetTupleElement(prev, 0);
736 auto weights = GetTupleElement(prev, 1);
737 auto input = ConstantR1<float>(&builder, 10, 1.f);
738 auto new_weights = Add(weights, input);
739 Tuple(&builder,
740 {Add(iteration, ConstantR0<int32>(&builder, 1)), new_weights});
741 TF_ASSERT_OK_AND_ASSIGN(body, builder.Build());
742 }
743
744 // Create a While node with computations for the condition and the body.
745 XlaBuilder builder("while");
746 auto init = Tuple(&builder, {ConstantR0<int32>(&builder, 0),
747 ConstantR1<float>(&builder, 10, 0.f)});
748 auto while1 = While(condition, body, init);
749
750 auto while2 = While(condition2, body, while1);
751
752 auto while_result1 = GetTupleElement(while1, 1);
753 auto while_result2 = GetTupleElement(while2, 1);
754 VLOG(2) << "while_result2 = "
755 << ShapeUtil::HumanString(
756 builder.GetShape(while_result2).ConsumeValueOrDie());
757 auto result = Add(while_result1, while_result2);
758 VLOG(2) << "result = "
759 << ShapeUtil::HumanString(
760 builder.GetShape(result).ConsumeValueOrDie());
761 const float sum = c1 + c2;
762 std::vector<float> expected(10, sum);
763 ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
764 }
765
// Runs two independent while nodes that use the same body computation and the
// same initial value but different loop bounds, then adds their weight
// vectors together.
XLA_TEST_F(WhileTest, WhileLoopsWithSharedBodyAndInit) {
  std::vector<Shape> element_shapes = {ShapeUtil::MakeShape(S32, {}),
                                       ShapeUtil::MakeShape(F32, {10})};
  Shape state_shape = ShapeUtil::MakeTupleShape(element_shapes);

  // Builds a condition computation "counter < limit" over the loop state.
  auto build_condition = [&state_shape](const char* name, int limit) {
    XlaBuilder cond_builder(name);
    auto state = Parameter(&cond_builder, 0, state_shape, "prev");
    auto counter = GetTupleElement(state, 0);
    Lt(counter, ConstantR0<int32>(&cond_builder, limit));
    return cond_builder.Build();
  };

  // The first loop runs c1 iterations, the second c2 iterations.
  const int c1 = 5;
  TF_ASSERT_OK_AND_ASSIGN(XlaComputation condition,
                          build_condition("condition", c1));
  const int c2 = 7;
  TF_ASSERT_OK_AND_ASSIGN(XlaComputation condition2,
                          build_condition("condition2", c2));

  // Single body used by both loops: increment the counter and add a constant
  // vector of 1.0f to the weights.
  XlaComputation body;
  {
    XlaBuilder body_builder("body");
    auto state = Parameter(&body_builder, 0, state_shape, "prev");
    auto counter = GetTupleElement(state, 0);
    auto weights = GetTupleElement(state, 1);
    auto ones = ConstantR1<float>(&body_builder, 10, 1.f);
    auto updated_weights = Add(weights, ones);
    Tuple(&body_builder, {Add(counter, ConstantR0<int32>(&body_builder, 1)),
                          updated_weights});
    TF_ASSERT_OK_AND_ASSIGN(body, body_builder.Build());
  }

  // Both loops start from the same init tuple.
  XlaBuilder builder("while");
  auto init = Tuple(&builder, {ConstantR0<int32>(&builder, 0),
                               ConstantR1<float>(&builder, 10, 0.f)});
  auto first_loop = While(condition, body, init);
  auto second_loop = While(condition2, body, init);

  auto first_weights = GetTupleElement(first_loop, 1);
  auto second_weights = GetTupleElement(second_loop, 1);
  VLOG(2) << "while_result2 = "
          << ShapeUtil::HumanString(
                 builder.GetShape(second_weights).ConsumeValueOrDie());
  auto result = Add(first_weights, second_weights);
  VLOG(2) << "result = "
          << ShapeUtil::HumanString(
                 builder.GetShape(result).ConsumeValueOrDie());

  // Weights are c1 after the first loop and c2 after the second.
  std::vector<float> expected(10, static_cast<float>(c1 + c2));
  ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
829
830 // WhileTest that uses DynamicUpdateSlice instruction in body computation.
831 // Loop state tuple element 1 has as its single user operand(0) of
832 // DynamicUpdateSlice, which will trigger in-place dynamic slice update on GPU.
XLA_TEST_F(WhileTest,WhileWithDynamicUpdateSlice)833 XLA_TEST_F(WhileTest, WhileWithDynamicUpdateSlice) {
834 std::vector<Shape> shape_elements = {ShapeUtil::MakeShape(S32, {}),
835 ShapeUtil::MakeShape(F32, {10})};
836 Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements);
837
838 // Create a computation for the condition.
839 // Repeat for 5 iterations.
840 XlaComputation condition;
841 {
842 XlaBuilder builder("condition");
843 auto prev = Parameter(&builder, 0, result_shape, "prev");
844 auto iteration = GetTupleElement(prev, 0);
845 Gt(ConstantR0<int32>(&builder, 5), iteration);
846 condition = builder.Build().ConsumeValueOrDie();
847 }
848
849 // Create a computation for the body.
850 // Add 1 to the iteration variable and add a constant vector of 1.0f to
851 // the weight variable, both of which are tuple elements.
852 XlaComputation body;
853 {
854 XlaBuilder builder("body");
855 auto prev = Parameter(&builder, 0, result_shape, "prev");
856 // TupleElement 0
857 auto iteration = GetTupleElement(prev, 0);
858 auto out0 = Add(iteration, ConstantR0<int32>(&builder, 1));
859 // TupleElement 1
860 auto input = GetTupleElement(prev, 1);
861 // Update.
862 auto update = ConvertElementType(Broadcast(out0, {2}), F32);
863 // Starts = iteration * 2;
864 auto starts = Mul(iteration, ConstantR0<int32>(&builder, 2));
865 // UpdateSlice.
866 auto out1 = DynamicUpdateSlice(input, update, starts);
867
868 Tuple(&builder, {out0, out1});
869 body = builder.Build().ConsumeValueOrDie();
870 }
871
872 // Create a While node with computations for the condition and the body.
873 XlaBuilder builder("while");
874 auto init = Tuple(&builder, {ConstantR0<int32>(&builder, 0),
875 ConstantR1<float>(&builder, 10, 0.f)});
876 auto result = While(condition, body, init);
877 VLOG(2) << "while = "
878 << ShapeUtil::HumanString(
879 builder.GetShape(result).ConsumeValueOrDie());
880
881 auto expected_counter = LiteralUtil::CreateR0<int32>(5);
882 auto expected_data = LiteralUtil::CreateR1<float>(
883 {1.0f, 1.0f, 2.0f, 2.0f, 3.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f});
884 auto expected = LiteralUtil::MakeTuple({&expected_counter, &expected_data});
885 VLOG(2) << "expected = " << ShapeUtil::HumanString(expected.shape());
886 ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.0001));
887 }
888
889 // Tests a while node when the result type T is a vector of S32.
890 //
891 // int32 result = (0, 0, 0, 0, 0, 0);
892 // while (result[0] < count) {
893 // result += (1, U[0, 100], U[0, 100], U[0, 100], U[0, 100], U[0, 100]);
894 // }
895 //
896 // This test misuses a vector WhileTest.WhileLoopsWithSharedBodyto represent a
897 // pair:
898 // ((iteration, (random vector))).
899 //
900 // Note: this test currently only tests generating random values within a loop.
901 // Per backend the values generated can be different as the different backends
902 // use different random number generators.
903 // TODO(b/32240857): Extend test to verify outputs.
XLA_TEST_F(WhileTest,WhileWithPrngScalarResult)904 XLA_TEST_F(WhileTest, WhileWithPrngScalarResult) {
905 auto v6s32 = ShapeUtil::MakeShape(S32, {6});
906
907 // Create a computation for the condition: repeat for count iterations.
908 auto build_condition = [this, v6s32](int count) {
909 XlaBuilder builder(TestName());
910 auto prev = Reshape(
911 Slice(Parameter(&builder, 0, v6s32, "prev"), {0}, {1}, {1}), {0}, {});
912 Gt(ConstantR0<int32>(&builder, count), prev);
913 return builder.Build().ConsumeValueOrDie();
914 };
915
916 // Create a computation for the body: add 1 to the result variable.
917 XlaComputation body;
918 {
919 XlaBuilder builder("body");
920 auto prev = Parameter(&builder, 0, v6s32, "prev");
921 auto inc = ConcatInDim(&builder,
922 {ConstantR1<int32>(&builder, {1}),
923 RngUniform(ConstantR0<int32>(&builder, 0),
924 ConstantR0<int32>(&builder, 100),
925 ShapeUtil::MakeShape(S32, {5}))},
926 0);
927 Add(inc, prev);
928 body = builder.Build().ConsumeValueOrDie();
929 }
930
931 // Create a While node with computations for the condition and the body.
932 auto while_loop = [this, &body, build_condition](int count) {
933 XlaBuilder builder(TestName());
934 auto init = ConstantR1<int32>(&builder, {0, 0, 0, 0, 0, 0});
935 While(build_condition(count), body, init);
936 return builder.Build();
937 };
938
939 for (int i = 1; i < 4; ++i) {
940 TF_ASSERT_OK_AND_ASSIGN(auto computation, while_loop(i));
941
942 ExecutionOptions execution_options = execution_options_;
943 execution_options.set_seed(65);
944 TF_ASSERT_OK_AND_ASSIGN(
945 auto result,
946 client_->ExecuteAndTransfer(computation, {}, &execution_options));
947 }
948 }
949
// Tests a while loop whose init tuple holds an entry parameter in element 0
// and a constant in element 1, and whose body replaces both elements with
// element 1.
//
// (float[2], float[2]) t = (param, {1, 1});
// while (any(t[0] == {42, 42})) {
//   t = (t[1], t[1]);
// }
XLA_TEST_F(WhileTest, WhileThatSwapsParameterWithTupleElement) {
  auto vector_shape = ShapeUtil::MakeShape(F32, {2});

  // Entry computation: init tuple is (parameter, constant {1, 1}).
  XlaBuilder outer("outer");
  auto param = Parameter(&outer, 0, vector_shape, "param");
  auto ones = ConstantR1<float>(&outer, {1, 1});
  auto init = Tuple(&outer, {param, ones});

  TF_ASSERT_OK_AND_ASSIGN(Shape state_shape, outer.GetShape(init));

  // Condition: true if any entry of tuple element 0 equals 42.
  XlaBuilder cond("cond");
  auto cond_state = Parameter(&cond, 0, state_shape, "t");
  auto matches_42 =
      Eq(GetTupleElement(cond_state, 0), ConstantR1<float>(&cond, {42, 42}));
  Any(matches_42);

  // Body: produce a tuple holding tuple element 1 twice.
  XlaBuilder body("body");
  auto body_state = Parameter(&body, 0, state_shape, "t");
  auto second = GetTupleElement(body_state, 1);
  Tuple(&body, {second, second});

  TF_ASSERT_OK_AND_ASSIGN(auto cond_computation, cond.Build());
  TF_ASSERT_OK_AND_ASSIGN(auto body_computation, body.Build());
  While(cond_computation, body_computation, init);

  // With {42, 42} as the parameter the body runs, after which both tuple
  // elements are {1, 1}.
  auto expected_element = LiteralUtil::CreateR1<float>({1, 1});
  auto expected =
      LiteralUtil::MakeTuple({&expected_element, &expected_element});
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<GlobalData> parameter_data,
      client_->TransferToServer(LiteralUtil::CreateR1<float>({42, 42})));
  ComputeAndCompareTuple(&outer, expected, {parameter_data.get()},
                         ErrorSpec(1e-6));
}
981
// Tests a while loop whose init value is an entry parameter and whose body
// ignores its input and yields a broadcast constant instead.
//
// float[2] t = param;
// while (any(t == {42, 42})) {
//   t = broadcast(1.0, {2});
// }
XLA_TEST_F(WhileTest, WhileThatSwapsParameterWithBroadcast) {
  auto vector_shape = ShapeUtil::MakeShape(F32, {2});

  XlaBuilder outer("outer");
  auto param = Parameter(&outer, 0, vector_shape, "param");

  // Condition: true if any entry of the state equals 42.
  XlaBuilder cond("cond");
  auto cond_state = Parameter(&cond, 0, vector_shape, "t");
  auto matches_42 = Eq(cond_state, ConstantR1<float>(&cond, {42, 42}));
  Any(matches_42);

  // Body: the parameter is declared but unused; the root is the broadcast.
  XlaBuilder body("body");
  Parameter(&body, 0, vector_shape, "t");
  auto one = ConstantR0<float>(&body, 1.0);
  Broadcast(one, {2});

  TF_ASSERT_OK_AND_ASSIGN(auto cond_computation, cond.Build());
  TF_ASSERT_OK_AND_ASSIGN(auto body_computation, body.Build());
  While(cond_computation, body_computation, param);

  // With {42, 42} as the parameter the body runs and yields {1, 1}.
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<GlobalData> parameter_data,
      client_->TransferToServer(LiteralUtil::CreateR1<float>({42, 42})));
  ComputeAndCompareR1<float>(&outer, {1.0f, 1.0f}, {parameter_data.get()},
                             ErrorSpec(1e-6));
}
1006
// Tests a while loop over a scalar parameter whose body wraps the value in a
// tuple and returns one of the tuple's elements.
//
// float t = param;
// while (t == 42) {
//   t = (t, t + 1)[1];
// }
XLA_TEST_F(WhileTest, WhileThatTurnsScalarParameterToTupleElement) {
  auto scalar_shape = ShapeUtil::MakeShape(F32, {});

  XlaBuilder outer("outer");
  auto param = Parameter(&outer, 0, scalar_shape, "param");

  // Condition: loop while the state equals 42.
  XlaBuilder cond("cond");
  auto cond_state = Parameter(&cond, 0, scalar_shape, "t");
  Eq(cond_state, ConstantR0<float>(&cond, 42));

  // Body: build (t, t + 1) and return element 1 of that tuple.
  XlaBuilder body("body");
  auto body_state = Parameter(&body, 0, scalar_shape, "t");
  auto incremented = Add(body_state, ConstantR0<float>(&body, 1));
  auto pair = Tuple(&body, {body_state, incremented});
  GetTupleElement(pair, 1);

  TF_ASSERT_OK_AND_ASSIGN(auto cond_computation, cond.Build());
  TF_ASSERT_OK_AND_ASSIGN(auto body_computation, body.Build());
  While(cond_computation, body_computation, param);

  // Starting from 42 the body runs once, producing 43.
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<GlobalData> parameter_data,
      client_->TransferToServer(LiteralUtil::CreateR0<float>(42)));
  ComputeAndCompareR0<float>(&outer, 43.0f, {parameter_data.get()},
                             ErrorSpec(1e-6));
}
1032
1033 // Tests loop where the init value comes from two sources (constant and
1034 // parameter).
1035 //
1036 // int32 result = (0, 1);
1037 // while (result[0] + result[1] < 30) {
1038 // result[0] = result[0] + 1;
1039 // result[1] = result[1] + 1;
1040 // }
XLA_TEST_F(WhileTest,WhileWithMixedTupleElements)1041 XLA_TEST_F(WhileTest, WhileWithMixedTupleElements) {
1042 auto result_shape = ShapeUtil::MakeTupleShape(
1043 {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(S32, {})});
1044
1045 XlaBuilder outer("outer");
1046 auto p =
1047 Tuple(&outer, {ConstantR0<int32>(&outer, 0),
1048 Parameter(&outer, 0, ShapeUtil::MakeShape(S32, {}), "t")});
1049
1050 XlaBuilder cond("cond");
1051 auto params = Parameter(&cond, 0, result_shape, "prev");
1052 auto cond_t = Add(GetTupleElement(params, 1), GetTupleElement(params, 0));
1053 Lt(cond_t, ConstantR0<int32>(&cond, 30));
1054
1055 XlaBuilder body("body");
1056 auto body_t = Parameter(&body, 0, result_shape, "t");
1057
1058 Tuple(&body, {Add(GetTupleElement(body_t, 0), ConstantR0<int32>(&body, 1)),
1059 Add(GetTupleElement(body_t, 1), ConstantR0<int32>(&body, 1))});
1060
1061 TF_ASSERT_OK_AND_ASSIGN(auto cond_computation, cond.Build());
1062 TF_ASSERT_OK_AND_ASSIGN(auto body_computation, body.Build());
1063 While(cond_computation, body_computation, p);
1064
1065 TF_ASSERT_OK_AND_ASSIGN(
1066 std::unique_ptr<GlobalData> parameter_data,
1067 client_->TransferToServer(LiteralUtil::CreateR0<int32>(1)));
1068
1069 auto add1 = LiteralUtil::CreateR0<int32>(15);
1070 auto add2 = LiteralUtil::CreateR0<int32>(16);
1071 auto expected = LiteralUtil::MakeTuple({&add1, &add2});
1072 ComputeAndCompareTuple(&outer, expected, {parameter_data.get()},
1073 ErrorSpec(1e-6));
1074 }
1075
1076 // Tests nested while loops.
1077 //
1078 // int32 result = 0;
1079 // while (result < 30) {
1080 // int i = 0;
1081 // while (i < 7) {
1082 // result = result + 2;
1083 // i = i + 1;
1084 // }
1085 // }
XLA_TEST_F(WhileTest,NestedWhileWithScalarResult)1086 XLA_TEST_F(WhileTest, NestedWhileWithScalarResult) {
1087 auto outer_result_shape = ShapeUtil::MakeShape(S32, {});
1088 auto inner_result_shape = ShapeUtil::MakeTupleShape(
1089 {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(S32, {})});
1090
1091 XlaComputation inner_condition;
1092 {
1093 XlaBuilder builder("inner_condition");
1094 auto params = Parameter(&builder, 0, inner_result_shape, "prev");
1095 auto i = GetTupleElement(params, 0);
1096 Lt(i, ConstantR0<int32>(&builder, 7));
1097 inner_condition = builder.Build().ConsumeValueOrDie();
1098 }
1099
1100 // Creates a computation for the outer loop condition:
1101 // repeat while result < 30.
1102 XlaComputation outer_condition;
1103 {
1104 XlaBuilder builder("outer_condition");
1105 auto prev = Parameter(&builder, 0, outer_result_shape, "prev");
1106 Lt(prev, ConstantR0<int32>(&builder, 30));
1107 outer_condition = builder.Build().ConsumeValueOrDie();
1108 }
1109
1110 // Creates a computation for the inner loop body: add 1 to `i`, and add 2 to
1111 // `result`.
1112 XlaComputation inner_body;
1113 {
1114 XlaBuilder builder("inner_body");
1115 auto params = Parameter(&builder, 0, inner_result_shape, "prev");
1116 auto i = GetTupleElement(params, 0);
1117 auto result = GetTupleElement(params, 1);
1118 i = Add(ConstantR0<int32>(&builder, 1), i);
1119 result = Add(ConstantR0<int32>(&builder, 2), result);
1120 Tuple(&builder, {i, result});
1121 inner_body = builder.Build().ConsumeValueOrDie();
1122 }
1123
1124 // Creates a computation for the outer loop: run the inner loop with i = 0.
1125 XlaComputation outer_body;
1126 {
1127 XlaBuilder builder("outer_body");
1128 auto prev = Parameter(&builder, 0, outer_result_shape, "prev");
1129 auto init = Tuple(&builder, {ConstantR0<int32>(&builder, 0), prev});
1130 auto result = While(inner_condition, inner_body, init);
1131 GetTupleElement(result, 1);
1132 outer_body = builder.Build().ConsumeValueOrDie();
1133 }
1134
1135 // Create a While node with computations for the condition and the body.
1136 XlaBuilder builder(TestName());
1137 auto init = ConstantR0<int32>(&builder, 0);
1138 While(outer_condition, outer_body, init);
1139
1140 ComputeAndCompareR0<int32>(&builder, 42, {});
1141 }
1142
1143 // Tests a while node when the result type T is S32.
1144 // f = lambda result: tuple({result < 5})
1145 // int32 result = 0;
1146 // while (f(result).get<0>()) {
1147 // result = result + 1;
1148 // }
XLA_TEST_F(WhileTest,WhileWithCallInsideCondition)1149 XLA_TEST_F(WhileTest, WhileWithCallInsideCondition) {
1150 auto result_shape = ShapeUtil::MakeShape(S32, {});
1151
1152 // Create a computation for the condition: repeat for 5 iterations.
1153 XlaComputation condition_callee;
1154 {
1155 XlaBuilder builder("condition_callee");
1156 auto prev = Parameter(&builder, 0, result_shape, "prev");
1157 Tuple(&builder, {Gt(ConstantR0<int32>(&builder, 5), prev)});
1158
1159 condition_callee = builder.Build().ConsumeValueOrDie();
1160 }
1161
1162 XlaComputation condition;
1163 {
1164 XlaBuilder builder("condition");
1165 auto prev = Parameter(&builder, 0, result_shape, "prev");
1166 auto result = Call(&builder, condition_callee, {prev});
1167 GetTupleElement(result, 0);
1168 condition = builder.Build().ConsumeValueOrDie();
1169 }
1170
1171 // Create a computation for the body: add 1 to the result variable.
1172 XlaComputation body;
1173 {
1174 XlaBuilder builder("body");
1175 auto prev = Parameter(&builder, 0, result_shape, "prev");
1176 auto input = ConstantR0<int32>(&builder, 1);
1177 Add(input, prev);
1178 body = builder.Build().ConsumeValueOrDie();
1179 }
1180
1181 // Create a While node with computations for the condition and the body.
1182 XlaBuilder builder(TestName());
1183 auto init = ConstantR0<int32>(&builder, 0);
1184 While(condition, body, init);
1185
1186 ComputeAndCompareR0<int32>(&builder, 5, {});
1187 }
1188
// Tests a while loop whose body recomputes the same value on every iteration:
// tuple elements 1 and 2 are passed through unchanged, and element 3 is
// overwritten with tanh(dot(element 1, element 2)).
//
// (int32, float[2,2], float[2,2], float[2,2]) state = (0, m, m, m);
// while (5 > state[0]) {
//   state = (state[0] + 1, state[1], state[2], tanh(dot(state[1], state[2])));
// }
// result = state[3];
XLA_TEST_F(WhileTest, WhileWithLoopInvariantOperation) {
  auto matrix_shape = ShapeUtil::MakeShape(F32, {2, 2});
  auto counter_shape = ShapeUtil::MakeShape(S32, {});
  auto state_shape = ShapeUtil::MakeTupleShape(
      {counter_shape, matrix_shape, matrix_shape, matrix_shape});

  // Condition: repeat for 5 iterations.
  XlaComputation loop_condition;
  {
    XlaBuilder cond_builder("condition");
    auto state = Parameter(&cond_builder, 0, state_shape, "state");
    auto limit = ConstantR0<int32>(&cond_builder, 5);
    Gt(limit, GetTupleElement(state, 0));
    TF_ASSERT_OK_AND_ASSIGN(loop_condition, cond_builder.Build());
  }

  // Body: bump the counter, forward both inputs, and store tanh(lhs . rhs) in
  // the last tuple element.
  XlaComputation loop_body;
  {
    XlaBuilder body_builder("body");
    auto state = Parameter(&body_builder, 0, state_shape, "state");
    auto counter = GetTupleElement(state, 0);
    auto lhs = GetTupleElement(state, 1);
    auto rhs = GetTupleElement(state, 2);
    auto activation = Tanh(Dot(lhs, rhs));
    auto next_counter = Add(counter, ConstantR0<int32>(&body_builder, 1));
    Tuple(&body_builder, {next_counter, lhs, rhs, activation});
    TF_ASSERT_OK_AND_ASSIGN(loop_body, body_builder.Build());
  }

  // Entry computation: the same parameter matrix seeds all three matrix slots;
  // the result is tuple element 3 of the final loop state.
  XlaBuilder builder(TestName());
  auto matrix_param = Parameter(&builder, 0, matrix_shape, "matrix");
  auto zero_counter = ConstantR0<int32>(&builder, 0);
  auto init =
      Tuple(&builder, {zero_counter, matrix_param, matrix_param, matrix_param});
  auto loop = While(loop_condition, loop_body, init);
  GetTupleElement(loop, 3);

  TF_ASSERT_OK_AND_ASSIGN(
      auto param_value, client_->TransferToServer(LiteralUtil::CreateR2<float>(
                            {{1.0, 2.0}, {-1.0, -2.0}})));

  // dot(m, m) = {{-1, -2}, {1, 2}} for the matrix above; the expected values
  // are its elementwise tanh.
  ComputeAndCompareR2<float>(
      &builder, {{-0.76159416, -0.96402758}, {0.76159416, 0.96402758}},
      {param_value.get()}, ErrorSpec(4e-5));
}
1232
// Tests a while loop whose condition ignores the loop state and instead reads
// a PRED scalar from the infeed on every evaluation.
//
// int32 state = 0;
// while (infeed()) {
//   state = state + 1;
// }
XLA_TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileInfeedCondition)) {
  auto state_shape = ShapeUtil::MakeShape(S32, {});

  // Condition: the parameter is declared but unused; the root is the infeed.
  XlaComputation infeed_condition;
  {
    XlaBuilder cond_builder("condition");
    Parameter(&cond_builder, 0, state_shape, "state");
    Infeed(&cond_builder, ShapeUtil::MakeShape(PRED, {}));
    TF_ASSERT_OK_AND_ASSIGN(infeed_condition, cond_builder.Build());
  }

  // Body: increment the counter by one.
  XlaComputation increment_body;
  {
    XlaBuilder body_builder("body");
    auto counter = Parameter(&body_builder, 0, state_shape, "state");
    Add(counter, ConstantR0<int32>(&body_builder, 1));
    TF_ASSERT_OK_AND_ASSIGN(increment_body, body_builder.Build());
  }

  XlaBuilder builder(TestName());
  auto zero_counter = ConstantR0<int32>(&builder, 0);
  While(infeed_condition, increment_body, zero_counter);

  // Queue the condition values: true, true, false -> the body runs twice.
  TF_ASSERT_OK(client_->TransferToInfeed(LiteralUtil::CreateR0<bool>(true)));
  TF_ASSERT_OK(client_->TransferToInfeed(LiteralUtil::CreateR0<bool>(true)));
  TF_ASSERT_OK(client_->TransferToInfeed(LiteralUtil::CreateR0<bool>(false)));

  ComputeAndCompareR0<int32>(&builder, 2, {});
}
1261
BM_WhileLoop(int num_iters)1262 void BM_WhileLoop(int num_iters) {
1263 // Benchmark a simple kernel to measure while loop overheads.
1264 tensorflow::testing::StopTiming();
1265
1266 se::Platform* platform = PlatformUtil::GetDefaultPlatform().ValueOrDie();
1267 auto executors = PlatformUtil::GetStreamExecutors(platform).ValueOrDie();
1268 StreamExecutorMemoryAllocator allocator(platform, executors);
1269 LocalClient* client =
1270 ClientLibrary::GetOrCreateLocalClient(platform).ValueOrDie();
1271
1272 const int64 seq_len = 100;
1273 Shape loop_state_shape = ShapeUtil::MakeTupleShape(
1274 {ShapeUtil::MakeShape(S32, {}),
1275 ShapeUtil::MakeShape(F32, {seq_len, 1024, 1024})});
1276
1277 // Create while condition computation with 'loop_limit'.
1278 const int32 loop_limit = 100;
1279 XlaComputation condition;
1280 {
1281 XlaBuilder builder("condition");
1282 auto prev = Parameter(&builder, 0, loop_state_shape, "prev");
1283 auto iteration = GetTupleElement(prev, 0);
1284 Lt(iteration, ConstantR0<int32>(&builder, loop_limit));
1285 condition = builder.Build().ConsumeValueOrDie();
1286 }
1287
1288 // Create while body computation with unit loop increment.
1289 XlaComputation body;
1290 {
1291 XlaBuilder builder("body");
1292 auto prev = Parameter(&builder, 0, loop_state_shape, "prev");
1293 // TupleElement 0
1294 auto iteration = GetTupleElement(prev, 0);
1295 auto out0 = Add(iteration, ConstantR0<int32>(&builder, 1));
1296 // TupleElement 1
1297 auto input = GetTupleElement(prev, 1);
1298 // Update.
1299 auto one = ConstantR0<float>(&builder, 1.0);
1300 auto update = Broadcast(one, {1, 1024, 1024});
1301 // Starts = iteration * 2;
1302 auto zero = ConstantR0<int32>(&builder, 0);
1303 // UpdateSlice.
1304 auto out1 = DynamicUpdateSlice(input, update, {zero, zero, zero});
1305 Tuple(&builder, {out0, out1});
1306 body = builder.Build().ConsumeValueOrDie();
1307 }
1308
1309 // Create a While instruction.
1310 XlaBuilder builder("while");
1311 auto zero = ConstantR0<float>(&builder, 0.0);
1312 auto input = Broadcast(zero, {seq_len, 1024, 1024});
1313 auto init = Tuple(&builder, {ConstantR0<int32>(&builder, 0), input});
1314 While(condition, body, init);
1315 auto computation = builder.Build().ConsumeValueOrDie();
1316
1317 std::unique_ptr<LocalExecutable> executable =
1318 client->Compile(computation, {}, ExecutableBuildOptions())
1319 .ConsumeValueOrDie();
1320
1321 // Run some warm-up executions.
1322 ExecutableRunOptions options;
1323 options.set_allocator(&allocator);
1324 const int kWarmups = 2;
1325 for (int i = 0; i < kWarmups; ++i) {
1326 auto result = executable->Run({}, options);
1327 ASSERT_TRUE(result.ok());
1328 }
1329
1330 // Run benchmark.
1331 tensorflow::testing::StartTiming();
1332 for (int i = 0; i < num_iters; ++i) {
1333 auto result = executable->Run({}, options);
1334 ASSERT_TRUE(result.ok());
1335 }
1336 }
1337
1338 BENCHMARK(BM_WhileLoop);
1339 } // namespace
1340 } // namespace xla
1341