1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <memory>
17 #include <string>
18 #include <vector>
19 
20 #include "tensorflow/compiler/xla/client/client_library.h"
21 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
22 #include "tensorflow/compiler/xla/client/local_client.h"
23 #include "tensorflow/compiler/xla/client/xla_builder.h"
24 #include "tensorflow/compiler/xla/client/xla_computation.h"
25 #include "tensorflow/compiler/xla/literal.h"
26 #include "tensorflow/compiler/xla/service/platform_util.h"
27 #include "tensorflow/compiler/xla/shape_util.h"
28 #include "tensorflow/compiler/xla/status_macros.h"
29 #include "tensorflow/compiler/xla/statusor.h"
30 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
31 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
32 #include "tensorflow/compiler/xla/tests/test_macros.h"
33 #include "tensorflow/compiler/xla/xla_data.pb.h"
34 #include "tensorflow/core/lib/core/status_test_util.h"
35 #include "tensorflow/core/platform/logging.h"
36 #include "tensorflow/core/platform/test.h"
37 #include "tensorflow/core/platform/test_benchmark.h"
38 #include "tensorflow/core/platform/types.h"
39 
40 namespace xla {
41 namespace {
42 
43 class WhileTest : public ClientLibraryTestBase {};
44 
45 // Tests a while node when the result type T is S32.
46 //
47 // int32 result = 0;
48 // while (result < 5) {
49 //   result = result + 1;
50 // }
XLA_TEST_F(WhileTest, WhileWithScalarS32Result) {
  // The loop carries a single scalar S32 value.
  const auto scalar_shape = ShapeUtil::MakeShape(S32, {});

  // Condition: keep looping while 5 > carried value.
  XlaComputation condition;
  {
    XlaBuilder cond_builder("condition");
    auto carried = Parameter(&cond_builder, 0, scalar_shape, "prev");
    auto limit = ConstantR0<int32>(&cond_builder, 5);
    Gt(limit, carried);
    condition = cond_builder.Build().ConsumeValueOrDie();
  }

  // Body: increment the carried value by one.
  XlaComputation body;
  {
    XlaBuilder body_builder("body");
    auto carried = Parameter(&body_builder, 0, scalar_shape, "prev");
    auto one = ConstantR0<int32>(&body_builder, 1);
    Add(one, carried);
    body = body_builder.Build().ConsumeValueOrDie();
  }

  // Top-level computation: a While starting from 0. The While op is the last
  // op added, so it is the root of the computation.
  XlaBuilder main_builder(TestName());
  auto start = ConstantR0<int32>(&main_builder, 0);
  While(condition, body, start);

  // The loop exits as soon as the value reaches 5.
  ComputeAndCompareR0<int32>(&main_builder, 5, {});
}
80 
81 // Tests a while node when the result type T is S64.
82 //
// int64 result = 0;
84 // while (result < 5) {
85 //   result = result + 1;
86 // }
XLA_TEST_F(WhileTest, WhileWithScalarS64Result) {
  // The loop carries a single scalar S64 value.
  const auto scalar_shape = ShapeUtil::MakeShape(S64, {});

  // Condition: keep looping while 5 > carried value.
  XlaComputation condition;
  {
    XlaBuilder cond_builder("condition");
    auto carried = Parameter(&cond_builder, 0, scalar_shape, "prev");
    auto limit = ConstantR0<int64>(&cond_builder, 5);
    Gt(limit, carried);
    condition = cond_builder.Build().ConsumeValueOrDie();
  }

  // Body: increment the carried value by one.
  XlaComputation body;
  {
    XlaBuilder body_builder("body");
    auto carried = Parameter(&body_builder, 0, scalar_shape, "prev");
    auto one = ConstantR0<int64>(&body_builder, 1);
    Add(one, carried);
    body = body_builder.Build().ConsumeValueOrDie();
  }

  // Top-level computation: a While starting from 0. The While op is the last
  // op added, so it is the root of the computation.
  XlaBuilder main_builder(TestName());
  auto start = ConstantR0<int64>(&main_builder, 0);
  While(condition, body, start);

  // The loop exits as soon as the value reaches 5.
  ComputeAndCompareR0<int64>(&main_builder, 5, {});
}
116 
// Tests a while node with a scalar S32 result whose initial value is not a
// plain constant but the result of a Reduce op.
//
// int32 result = sum({1, 1});
// while (result < 5) {
//   result = result + 1;
// }
XLA_TEST_F(WhileTest, WhileWithScalarResultNonConstInit) {
  auto result_shape = ShapeUtil::MakeShape(S32, {});

  // Create a computation for the condition: loop while 5 > prev.
  XlaComputation condition;
  {
    XlaBuilder builder("condition");
    auto prev = Parameter(&builder, 0, result_shape, "prev");
    Gt(ConstantR0<int32>(&builder, 5), prev);
    condition = builder.Build().ConsumeValueOrDie();
  }

  // Create a computation for the body: add 1 to the result variable.
  XlaComputation body;
  {
    XlaBuilder builder("body");
    auto prev = Parameter(&builder, 0, result_shape, "prev");
    auto input = ConstantR0<int32>(&builder, 1);
    Add(input, prev);
    body = builder.Build().ConsumeValueOrDie();
  }

  // Create a While node with computations for the condition and the body.
  // The initial value is computed by reducing a two-element vector of ones
  // with scalar addition (starting from 0), i.e. it is produced by an op
  // rather than given directly as a constant.
  XlaBuilder builder(TestName());
  auto init =
      Reduce(ConstantR1<int32>(&builder, 2, 1), ConstantR0<int32>(&builder, 0),
             CreateScalarAddComputation(S32, &builder), {0});
  While(condition, body, init);

  // The loop stops once the carried value reaches 5.
  ComputeAndCompareR0<int32>(&builder, 5, {});
}
149 
XLA_TEST_F(WhileTest, WhileWithPredicateResult) {
  // The loop carries a single scalar predicate.
  const auto pred_shape = ShapeUtil::MakeShape(PRED, {});

  // Condition: keep looping while the carried predicate differs from true.
  XlaComputation condition;
  {
    XlaBuilder cond_builder("condition");
    auto carried = Parameter(&cond_builder, 0, pred_shape, "prev");
    auto always_true = ConstantR0<bool>(&cond_builder, true);
    Ne(always_true, carried);
    condition = cond_builder.Build().ConsumeValueOrDie();
  }

  // Body: OR the carried predicate with true.
  XlaComputation body;
  {
    XlaBuilder body_builder("body");
    auto carried = Parameter(&body_builder, 0, pred_shape, "prev");
    auto always_true = ConstantR0<bool>(&body_builder, true);
    Or(carried, always_true);
    body = body_builder.Build().ConsumeValueOrDie();
  }

  // Top-level computation. The initial value is computed as (false != true)
  // instead of being supplied directly as a constant.
  XlaBuilder main_builder(TestName());
  auto start = Ne(ConstantR0<bool>(&main_builder, false),
                  ConstantR0<bool>(&main_builder, true));
  While(condition, body, start);

  ComputeAndCompareR0<bool>(&main_builder, true, {});
}
179 
180 // Tests a while node when the result type T is a vector.
181 //
182 // All constants are chosen to produce exact results.
183 // vector<float> result(0);
184 // while (result.sum() < 15.5f) {
185 //   result = result + vector<float>(0);
186 // }
XLA_TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithEmptyVectorResult)) {
  // The loop carries a zero-element F32 vector.
  Shape vector_shape = ShapeUtil::MakeShape(F32, {0});

  // Scalar addition, used as the reduction function in the condition.
  XlaComputation add;
  {
    XlaBuilder add_builder("add");
    const Shape scalar_f32 = ShapeUtil::MakeShape(F32, {});
    auto lhs = Parameter(&add_builder, 0, scalar_f32, "x");
    auto rhs = Parameter(&add_builder, 1, scalar_f32, "y");
    Add(lhs, rhs);
    add = add_builder.Build().ConsumeValueOrDie();
  }

  // Condition: keep looping while 15.5f > sum of the carried vector.
  XlaComputation condition;
  {
    XlaBuilder cond_builder("condition");
    auto carried = Parameter(&cond_builder, 0, vector_shape, "prev");
    auto zero = ConstantR0<float>(&cond_builder, 0.0f);
    auto total = Reduce(carried, zero, add, /*dimensions_to_reduce=*/{0});
    auto threshold = ConstantR0<float>(&cond_builder, 15.5f);
    Gt(threshold, total);
    condition = cond_builder.Build().ConsumeValueOrDie();
  }

  // Body: add an empty constant vector to the carried (empty) vector.
  XlaComputation body;
  {
    XlaBuilder body_builder("body");
    auto carried = Parameter(&body_builder, 0, vector_shape, "prev");
    auto empty_delta = ConstantR1<float>(&body_builder, {});
    Add(empty_delta, carried);
    body = body_builder.Build().ConsumeValueOrDie();
  }

  // Top-level computation: a While starting from an empty vector.
  XlaBuilder main_builder("while");
  auto start = ConstantR1<float>(&main_builder, {});
  auto loop = While(condition, body, start);
  VLOG(2) << "while = "
          << ShapeUtil::HumanString(
                 main_builder.GetShape(loop).ConsumeValueOrDie());

  // The expected result is an empty vector.
  ComputeAndCompareR1<float>(&main_builder, {}, {}, ErrorSpec(0.0001));
}
233 
234 // Tests a while node when the result type T is a vector.
235 //
236 // All constants are chosen to produce exact results.
237 // vector<float> result(8, 0.0f);
238 // while (result.sum() < 15.5f) {
239 //   result = result + vector<float>(8, 0.125f);
240 // }
XLA_TEST_F(WhileTest, WhileWithVectorResult) {
  // The loop carries an 8-element F32 vector.
  Shape vector_shape = ShapeUtil::MakeShape(F32, {8});

  // Scalar addition, used as the reduction function in the condition.
  XlaComputation add;
  {
    XlaBuilder add_builder("add");
    const Shape scalar_f32 = ShapeUtil::MakeShape(F32, {});
    auto lhs = Parameter(&add_builder, 0, scalar_f32, "x");
    auto rhs = Parameter(&add_builder, 1, scalar_f32, "y");
    Add(lhs, rhs);
    add = add_builder.Build().ConsumeValueOrDie();
  }

  // Condition: keep looping while 15.5f > sum of the carried vector.
  XlaComputation condition;
  {
    XlaBuilder cond_builder("condition");
    auto carried = Parameter(&cond_builder, 0, vector_shape, "prev");
    auto zero = ConstantR0<float>(&cond_builder, 0.0f);
    auto total = Reduce(carried, zero, add, /*dimensions_to_reduce=*/{0});
    auto threshold = ConstantR0<float>(&cond_builder, 15.5f);
    Gt(threshold, total);
    condition = cond_builder.Build().ConsumeValueOrDie();
  }

  // Body: add a constant vector of 0.125f to the carried vector.
  XlaComputation body;
  {
    XlaBuilder body_builder("body");
    auto carried = Parameter(&body_builder, 0, vector_shape, "prev");
    auto delta = ConstantR1<float>(&body_builder, 8, 0.125f);
    Add(delta, carried);
    body = body_builder.Build().ConsumeValueOrDie();
  }

  // Top-level computation: a While starting from a vector of zeros.
  XlaBuilder main_builder("while");
  auto start = ConstantR1<float>(&main_builder, 8, 0.f);
  auto loop = While(condition, body, start);
  VLOG(2) << "while = "
          << ShapeUtil::HumanString(
                 main_builder.GetShape(loop).ConsumeValueOrDie());

  // Each of the 8 elements grows by 1/8 per iteration, so the sum grows by
  // 1.0 per iteration. The sum first exceeds 15.5 when it reaches 16, i.e.
  // when every element equals 2.0.
  std::vector<float> expected(8, 2.f);
  ComputeAndCompareR1<float>(&main_builder, expected, {}, ErrorSpec(0.0001));
}
291 
292 // Tests a while node when the result type is a vector which is part
293 // of the result tuple.
294 //
295 // All constants are chosen to produce exact results.
296 // vector<float> result(8, 0.0f);
297 // while (result.sum() < 15.5f) {
298 //   result = result + vector<float>(8, 0.125f);
299 // }
300 // tuple = tuple { while }
XLA_TEST_F(WhileTest, WhileWithVectorResultIntoTuple) {
  // The loop carries an 8-element F32 vector.
  Shape vector_shape = ShapeUtil::MakeShape(F32, {8});

  // Scalar addition, used as the reduction function in the condition.
  XlaComputation add;
  {
    XlaBuilder add_builder("add");
    const Shape scalar_f32 = ShapeUtil::MakeShape(F32, {});
    auto lhs = Parameter(&add_builder, 0, scalar_f32, "x");
    auto rhs = Parameter(&add_builder, 1, scalar_f32, "y");
    Add(lhs, rhs);
    add = add_builder.Build().ConsumeValueOrDie();
  }

  // Condition: keep looping while 15.5f > sum of the carried vector.
  XlaComputation condition;
  {
    XlaBuilder cond_builder("condition");
    auto carried = Parameter(&cond_builder, 0, vector_shape, "prev");
    auto zero = ConstantR0<float>(&cond_builder, 0.0f);
    auto total = Reduce(carried, zero, add, /*dimensions_to_reduce=*/{0});
    auto threshold = ConstantR0<float>(&cond_builder, 15.5f);
    Gt(threshold, total);
    condition = cond_builder.Build().ConsumeValueOrDie();
  }

  // Body: add a constant vector of 0.125f to the carried vector.
  XlaComputation body;
  {
    XlaBuilder body_builder("body");
    auto carried = Parameter(&body_builder, 0, vector_shape, "prev");
    auto delta = ConstantR1<float>(&body_builder, 8, 0.125f);
    Add(delta, carried);
    body = body_builder.Build().ConsumeValueOrDie();
  }

  // Top-level computation: run the While from a vector of zeros, then wrap
  // its result in a one-element tuple (the tuple is the root).
  XlaBuilder main_builder("while");
  auto start = ConstantR1<float>(&main_builder, 8, 0.f);
  auto loop = While(condition, body, start);
  VLOG(2) << "while = "
          << ShapeUtil::HumanString(
                 main_builder.GetShape(loop).ConsumeValueOrDie());
  Tuple(&main_builder, {loop});

  // Each of the 8 elements grows by 1/8 per iteration, so the sum grows by
  // 1.0 per iteration. The sum first exceeds 15.5 when it reaches 16, i.e.
  // when every element equals 2.0.
  auto expected_vector =
      LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f});
  auto expected_tuple = LiteralUtil::MakeTuple({&expected_vector});
  VLOG(2) << "expected = " << ShapeUtil::HumanString(expected_tuple.shape());
  ComputeAndCompareTuple(&main_builder, expected_tuple, {}, ErrorSpec(0.0001));
}
355 
XLA_TEST_F(WhileTest, WhileWithPermutationAndTupleResult) {
  // Loop state: (S32 counter, three F32[3] weight vectors).
  std::vector<Shape> state_elements = {
      ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {3}),
      ShapeUtil::MakeShape(F32, {3}), ShapeUtil::MakeShape(F32, {3})};
  Shape state_shape = ShapeUtil::MakeTupleShape(state_elements);

  // Condition: keep looping while N > counter.
  const int N = 2;
  XlaComputation condition;
  {
    XlaBuilder cond_builder("condition");
    auto state = Parameter(&cond_builder, 0, state_shape, "prev");
    auto counter = GetTupleElement(state, 0);
    auto limit = ConstantR0<int32>(&cond_builder, N);
    Gt(limit, counter);
    condition = cond_builder.Build().ConsumeValueOrDie();
  }

  // Body: bump the counter and rotate the weights (w1, w2, w3) -> (w3, w1, w2).
  XlaComputation body;
  {
    XlaBuilder body_builder("body");
    auto state = Parameter(&body_builder, 0, state_shape, "prev");
    auto counter = GetTupleElement(state, 0);
    auto first = GetTupleElement(state, 1);
    auto second = GetTupleElement(state, 2);
    auto third = GetTupleElement(state, 3);
    auto one = ConstantR0<int32>(&body_builder, 1);
    auto next_counter = Add(counter, one);
    Tuple(&body_builder, {next_counter, third, first, second});
    body = body_builder.Build().ConsumeValueOrDie();
  }

  // Top-level computation: start from counter 0 and weights filled with
  // 1.f, 2.f and 3.f respectively.
  XlaBuilder main_builder("while");
  auto counter0 = ConstantR0<int32>(&main_builder, 0);
  auto ones = ConstantR1<float>(&main_builder, 3, 1.f);
  auto twos = ConstantR1<float>(&main_builder, 3, 2.f);
  auto threes = ConstantR1<float>(&main_builder, 3, 3.f);
  auto start = Tuple(&main_builder, {counter0, ones, twos, threes});
  auto loop = While(condition, body, start);
  VLOG(2) << "result = "
          << ShapeUtil::HumanString(
                 main_builder.GetShape(loop).ConsumeValueOrDie());

  // After two rotations the weights are ordered (2s, 3s, 1s) and the counter
  // equals N.
  auto want_counter = LiteralUtil::CreateR0<int32>(N);
  auto want_ones = LiteralUtil::CreateR1<float>({1.0f, 1.0f, 1.0f});
  auto want_twos = LiteralUtil::CreateR1<float>({2.0f, 2.0f, 2.0f});
  auto want_threes = LiteralUtil::CreateR1<float>({3.0f, 3.0f, 3.0f});
  auto want = LiteralUtil::MakeTuple(
      {&want_counter, &want_twos, &want_threes, &want_ones});
  VLOG(2) << "expected = " << ShapeUtil::HumanString(want.shape());
  ComputeAndCompareTuple(&main_builder, want, {}, ErrorSpec(0.0001));
}
409 
XLA_TEST_F(WhileTest, WhileWithPermutationAndVectorResult) {
  // Loop state: (S32 counter, three F32[3] weight vectors).
  std::vector<Shape> state_elements = {
      ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {3}),
      ShapeUtil::MakeShape(F32, {3}), ShapeUtil::MakeShape(F32, {3})};
  Shape state_shape = ShapeUtil::MakeTupleShape(state_elements);

  // Condition: keep looping while N > counter.
  const int N = 2;
  XlaComputation condition;
  {
    XlaBuilder cond_builder("condition");
    auto state = Parameter(&cond_builder, 0, state_shape, "prev");
    auto counter = GetTupleElement(state, 0);
    auto limit = ConstantR0<int32>(&cond_builder, N);
    Gt(limit, counter);
    condition = cond_builder.Build().ConsumeValueOrDie();
  }

  // Body: bump the counter and rotate the weights (w1, w2, w3) -> (w3, w1, w2).
  XlaComputation body;
  {
    XlaBuilder body_builder("body");
    auto state = Parameter(&body_builder, 0, state_shape, "prev");
    auto counter = GetTupleElement(state, 0);
    auto first = GetTupleElement(state, 1);
    auto second = GetTupleElement(state, 2);
    auto third = GetTupleElement(state, 3);
    auto one = ConstantR0<int32>(&body_builder, 1);
    auto next_counter = Add(counter, one);
    Tuple(&body_builder, {next_counter, third, first, second});
    body = body_builder.Build().ConsumeValueOrDie();
  }

  // Top-level computation: start from counter 0 and weights filled with
  // 1.f, 2.f and 3.f respectively.
  XlaBuilder main_builder("while");
  auto counter0 = ConstantR0<int32>(&main_builder, 0);
  auto ones = ConstantR1<float>(&main_builder, 3, 1.f);
  auto twos = ConstantR1<float>(&main_builder, 3, 2.f);
  auto threes = ConstantR1<float>(&main_builder, 3, 3.f);
  auto start = Tuple(&main_builder, {counter0, ones, twos, threes});
  auto loop = While(condition, body, start);

  // Sum the three weight vectors of the loop result. The sum does not depend
  // on the order the permutation left them in.
  auto partial_sum = Add(GetTupleElement(loop, 1), GetTupleElement(loop, 2));
  auto total = Add(partial_sum, GetTupleElement(loop, 3));
  VLOG(2) << "result = "
          << ShapeUtil::HumanString(
                 main_builder.GetShape(total).ConsumeValueOrDie());
  std::vector<float> want = {6.f, 6.f, 6.f};
  ComputeAndCompareR1<float>(&main_builder, want, {}, ErrorSpec(0.0001));
}
460 
461 // Tests a while node when the result type T is a Tuple.
462 //
463 // tuple<int32, vector<float>> result(0, vector<float>(10, 0.0f));
464 // while (get<0>(result) < 5) {
465 //   get<0>(result) = get<0>(result) + 1;
466 //   get<1>(result) = get<1>(result) + vector<float>(10, 1.0f);
467 // }
XLA_TEST_F(WhileTest, WhileWithTupleResult) {
  // Loop state: (S32 counter, F32[10] weights).
  std::vector<Shape> state_elements = {ShapeUtil::MakeShape(S32, {}),
                                       ShapeUtil::MakeShape(F32, {10})};
  Shape state_shape = ShapeUtil::MakeTupleShape(state_elements);

  // Condition: keep looping while 5 > counter.
  XlaComputation condition;
  {
    XlaBuilder cond_builder("condition");
    auto state = Parameter(&cond_builder, 0, state_shape, "prev");
    auto counter = GetTupleElement(state, 0);
    auto limit = ConstantR0<int32>(&cond_builder, 5);
    Gt(limit, counter);
    condition = cond_builder.Build().ConsumeValueOrDie();
  }

  // Body: add a vector of 1.0f to the weights and 1 to the counter, then
  // repackage both as the new state tuple.
  XlaComputation body;
  {
    XlaBuilder body_builder("body");
    auto state = Parameter(&body_builder, 0, state_shape, "prev");
    auto counter = GetTupleElement(state, 0);
    auto weights = GetTupleElement(state, 1);
    auto delta = ConstantR1<float>(&body_builder, 10, 1.f);
    auto next_weights = Add(weights, delta);
    auto one = ConstantR0<int32>(&body_builder, 1);
    auto next_counter = Add(counter, one);
    Tuple(&body_builder, {next_counter, next_weights});
    body = body_builder.Build().ConsumeValueOrDie();
  }

  // Top-level computation: start from counter 0 and all-zero weights.
  XlaBuilder main_builder("while");
  auto counter0 = ConstantR0<int32>(&main_builder, 0);
  auto weights0 = ConstantR1<float>(&main_builder, 10, 0.f);
  auto start = Tuple(&main_builder, {counter0, weights0});
  auto loop = While(condition, body, start);
  VLOG(2) << "while = "
          << ShapeUtil::HumanString(
                 main_builder.GetShape(loop).ConsumeValueOrDie());

  // Five iterations: counter ends at 5 and every weight at 5.0f.
  auto want_counter = LiteralUtil::CreateR0<int32>(5);
  auto want_weights = LiteralUtil::CreateR1<float>(
      {5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f});
  auto want = LiteralUtil::MakeTuple({&want_counter, &want_weights});
  VLOG(2) << "expected = " << ShapeUtil::HumanString(want.shape());
  ComputeAndCompareTuple(&main_builder, want, {}, ErrorSpec(0.0001));
}
516 
XLA_TEST_F(WhileTest, WhileWithPredicateTupleResult) {
  // Loop state: (S32 counter, scalar predicate).
  std::vector<Shape> state_elements = {ShapeUtil::MakeShape(S32, {}),
                                       ShapeUtil::MakeShape(PRED, {})};
  Shape state_shape = ShapeUtil::MakeTupleShape(state_elements);

  // Condition: keep looping while 5 > counter.
  XlaComputation condition;
  {
    XlaBuilder cond_builder("condition");
    auto state = Parameter(&cond_builder, 0, state_shape, "prev");
    auto counter = GetTupleElement(state, 0);
    auto limit = ConstantR0<int32>(&cond_builder, 5);
    Gt(limit, counter);
    condition = cond_builder.Build().ConsumeValueOrDie();
  }

  // Body: OR the predicate with true and add 1 to the counter.
  XlaComputation body;
  {
    XlaBuilder body_builder("body");
    auto state = Parameter(&body_builder, 0, state_shape, "prev");
    auto counter = GetTupleElement(state, 0);
    auto flag = GetTupleElement(state, 1);
    auto always_true = ConstantR0<bool>(&body_builder, true);
    auto next_flag = Or(flag, always_true);
    auto one = ConstantR0<int32>(&body_builder, 1);
    auto next_counter = Add(counter, one);
    Tuple(&body_builder, {next_counter, next_flag});
    body = body_builder.Build().ConsumeValueOrDie();
  }

  // Top-level computation: counter starts at 0; the initial predicate is
  // computed as (false != true) rather than given as a constant.
  XlaBuilder main_builder("while");
  auto counter0 = ConstantR0<int32>(&main_builder, 0);
  auto flag0 = Ne(ConstantR0<bool>(&main_builder, false),
                  ConstantR0<bool>(&main_builder, true));
  auto start = Tuple(&main_builder, {counter0, flag0});
  auto loop = While(condition, body, start);
  VLOG(2) << "while = "
          << ShapeUtil::HumanString(
                 main_builder.GetShape(loop).ConsumeValueOrDie());

  // Counter ends at 5 and the predicate is true.
  auto want_counter = LiteralUtil::CreateR0<int32>(5);
  auto want_flag = LiteralUtil::CreateR0<bool>(true);
  auto want = LiteralUtil::MakeTuple({&want_counter, &want_flag});
  ComputeAndCompareTuple(&main_builder, want, {}, ErrorSpec(0));
}
562 
XLA_TEST_F(WhileTest, WhileWithTupleConstantScalarResult) {
  // Loop state: (S32 counter, S32 scalar).
  std::vector<Shape> state_elements = {ShapeUtil::MakeShape(S32, {}),
                                       ShapeUtil::MakeShape(S32, {})};
  Shape state_shape = ShapeUtil::MakeTupleShape(state_elements);

  // Condition: keep looping while 5 > counter.
  XlaComputation condition;
  {
    XlaBuilder cond_builder("condition");
    auto state = Parameter(&cond_builder, 0, state_shape, "prev");
    auto counter = GetTupleElement(state, 0);
    auto limit = ConstantR0<int32>(&cond_builder, 5);
    Gt(limit, counter);
    condition = cond_builder.Build().ConsumeValueOrDie();
  }

  // Body: add 1 to the counter and overwrite the second element with the
  // constant 7 (it is not derived from the previous state).
  XlaComputation body;
  {
    XlaBuilder body_builder("body");
    auto state = Parameter(&body_builder, 0, state_shape, "prev");
    auto counter = GetTupleElement(state, 0);
    auto one = ConstantR0<int32>(&body_builder, 1);
    auto next_counter = Add(counter, one);
    auto seven = ConstantR0<int32>(&body_builder, 7);
    Tuple(&body_builder, {next_counter, seven});
    body = body_builder.Build().ConsumeValueOrDie();
  }

  // Top-level computation: start from (0, 7).
  XlaBuilder main_builder("while");
  auto counter0 = ConstantR0<int32>(&main_builder, 0);
  auto scalar0 = ConstantR0<int32>(&main_builder, 7);
  auto start = Tuple(&main_builder, {counter0, scalar0});
  auto loop = While(condition, body, start);
  VLOG(2) << "while = "
          << ShapeUtil::HumanString(
                 main_builder.GetShape(loop).ConsumeValueOrDie());

  // Counter ends at 5; the second element stays 7.
  auto want_counter = LiteralUtil::CreateR0<int32>(5);
  auto want_scalar = LiteralUtil::CreateR0<int32>(7);
  auto want = LiteralUtil::MakeTuple({&want_counter, &want_scalar});
  VLOG(2) << "expected = " << ShapeUtil::HumanString(want.shape());
  ComputeAndCompareTuple(&main_builder, want, {}, ErrorSpec(0.0001));
}
607 
608 // Tests two while nodes when the result type T is a Tuple and the second
609 // while node uses the result of the first while node which is used in two
610 // nodes.
611 // tuple<int32, vector<float>> w0(0, vector<float>(10, 0.0f));
612 // w0 = while (get<0>(w0) < c1) {
613 //        get<0>(w0) = get<0>(w0) + 1;
614 //        get<1>(w0) = get<1>(w0) + vector<float>(10, 1.0f);
615 //      }
616 // tuple<int32, vector<float>> w1(get<0>(w0), get<1>(w0));
617 // w1 = while (get<0>(w1) < c2) {
618 //        get<0>(w1) = get<0>(w1) + 1;
619 //        get<1>(w1) = get<1>(w1) + vector<float>(10, 1.0f);
620 //      }
621 // result = get<1>(w0) + get<1>(w1)
XLA_TEST_F(WhileTest,TwoWhileWithTupleResult)622 XLA_TEST_F(WhileTest, TwoWhileWithTupleResult) {
623   std::vector<Shape> shape_elements = {ShapeUtil::MakeShape(S32, {}),
624                                        ShapeUtil::MakeShape(F32, {10})};
625   Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements);
626 
627   // Create a computation for the condition.
628   // Repeat for 5 iterations.
629   XlaComputation condition;
630   const int c1 = 5;
631   {
632     XlaBuilder builder("condition");
633     auto prev = Parameter(&builder, 0, result_shape, "prev");
634     auto iteration = GetTupleElement(prev, 0);
635     Lt(iteration, ConstantR0<int32>(&builder, c1));
636     TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build());
637   }
638 
639   XlaComputation condition2;
640   const int c2 = 7;
641   {
642     XlaBuilder builder("condition2");
643     auto prev = Parameter(&builder, 0, result_shape, "prev");
644     auto iteration = GetTupleElement(prev, 0);
645     Lt(iteration, ConstantR0<int32>(&builder, c2));
646     TF_ASSERT_OK_AND_ASSIGN(condition2, builder.Build());
647   }
648 
649   // Create a computation for the body.
650   // Add 1 to the iteration variable and add a constant vector of 1.0f to
651   // the weight variable, both of which are tuple elements.
652   XlaComputation body;
653   {
654     XlaBuilder builder("body");
655     auto prev = Parameter(&builder, 0, result_shape, "prev");
656     auto iteration = GetTupleElement(prev, 0);
657     auto weights = GetTupleElement(prev, 1);
658     auto input = ConstantR1<float>(&builder, 10, 1.f);
659     auto new_weights = Add(weights, input);
660     Tuple(&builder,
661           {Add(iteration, ConstantR0<int32>(&builder, 1)), new_weights});
662     TF_ASSERT_OK_AND_ASSIGN(body, builder.Build());
663   }
664 
665   XlaComputation body2;
666   {
667     XlaBuilder builder("body");
668     auto prev = Parameter(&builder, 0, result_shape, "prev");
669     auto iteration = GetTupleElement(prev, 0);
670     auto weights = GetTupleElement(prev, 1);
671     auto input = ConstantR1<float>(&builder, 10, 1.f);
672     auto new_weights = Add(weights, input);
673     Tuple(&builder,
674           {Add(iteration, ConstantR0<int32>(&builder, 1)), new_weights});
675     TF_ASSERT_OK_AND_ASSIGN(body2, builder.Build());
676   }
677 
678   // Create a While node with computations for the condition and the body.
679   XlaBuilder builder("while");
680   auto init = Tuple(&builder, {ConstantR0<int32>(&builder, 0),
681                                ConstantR1<float>(&builder, 10, 0.f)});
682   auto while1 = While(condition, body, init);
683 
684   auto while2 = While(condition2, body2, while1);
685 
686   auto while_result1 = GetTupleElement(while1, 1);
687   auto while_result2 = GetTupleElement(while2, 1);
688   VLOG(2) << "while_result2 = "
689           << ShapeUtil::HumanString(
690                  builder.GetShape(while_result2).ConsumeValueOrDie());
691   auto result = Add(while_result1, while_result2);
692   VLOG(2) << "result = "
693           << ShapeUtil::HumanString(
694                  builder.GetShape(result).ConsumeValueOrDie());
695   const float sum = c1 + c2;
696   std::vector<float> expected(10, sum);
697   ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
698 }
699 
700 // Test while nodes that share the while body computation.
XLA_TEST_F(WhileTest,TwoWhileLoopsAndSharedBody)701 XLA_TEST_F(WhileTest, TwoWhileLoopsAndSharedBody) {
702   std::vector<Shape> shape_elements = {ShapeUtil::MakeShape(S32, {}),
703                                        ShapeUtil::MakeShape(F32, {10})};
704   Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements);
705 
706   // Create a computation for the condition.
707   // Repeat for 5 iterations.
708   XlaComputation condition;
709   const int c1 = 5;
710   {
711     XlaBuilder builder("condition");
712     auto prev = Parameter(&builder, 0, result_shape, "prev");
713     auto iteration = GetTupleElement(prev, 0);
714     Lt(iteration, ConstantR0<int32>(&builder, c1));
715     TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build());
716   }
717 
718   XlaComputation condition2;
719   const int c2 = 7;
720   {
721     XlaBuilder builder("condition2");
722     auto prev = Parameter(&builder, 0, result_shape, "prev");
723     auto iteration = GetTupleElement(prev, 0);
724     Lt(iteration, ConstantR0<int32>(&builder, c2));
725     TF_ASSERT_OK_AND_ASSIGN(condition2, builder.Build());
726   }
727 
728   // Create a computation for the body.
729   // Add 1 to the iteration variable and add a constant vector of 1.0f to
730   // the weight variable, both of which are tuple elements.
731   XlaComputation body;
732   {
733     XlaBuilder builder("body");
734     auto prev = Parameter(&builder, 0, result_shape, "prev");
735     auto iteration = GetTupleElement(prev, 0);
736     auto weights = GetTupleElement(prev, 1);
737     auto input = ConstantR1<float>(&builder, 10, 1.f);
738     auto new_weights = Add(weights, input);
739     Tuple(&builder,
740           {Add(iteration, ConstantR0<int32>(&builder, 1)), new_weights});
741     TF_ASSERT_OK_AND_ASSIGN(body, builder.Build());
742   }
743 
744   // Create a While node with computations for the condition and the body.
745   XlaBuilder builder("while");
746   auto init = Tuple(&builder, {ConstantR0<int32>(&builder, 0),
747                                ConstantR1<float>(&builder, 10, 0.f)});
748   auto while1 = While(condition, body, init);
749 
750   auto while2 = While(condition2, body, while1);
751 
752   auto while_result1 = GetTupleElement(while1, 1);
753   auto while_result2 = GetTupleElement(while2, 1);
754   VLOG(2) << "while_result2 = "
755           << ShapeUtil::HumanString(
756                  builder.GetShape(while_result2).ConsumeValueOrDie());
757   auto result = Add(while_result1, while_result2);
758   VLOG(2) << "result = "
759           << ShapeUtil::HumanString(
760                  builder.GetShape(result).ConsumeValueOrDie());
761   const float sum = c1 + c2;
762   std::vector<float> expected(10, sum);
763   ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
764 }
765 
XLA_TEST_F(WhileTest,WhileLoopsWithSharedBodyAndInit)766 XLA_TEST_F(WhileTest, WhileLoopsWithSharedBodyAndInit) {
767   std::vector<Shape> shape_elements = {ShapeUtil::MakeShape(S32, {}),
768                                        ShapeUtil::MakeShape(F32, {10})};
769   Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements);
770 
771   // Create a computation for the condition.
772   // Repeat for 5 iterations.
773   XlaComputation condition;
774   const int c1 = 5;
775   {
776     XlaBuilder builder("condition");
777     auto prev = Parameter(&builder, 0, result_shape, "prev");
778     auto iteration = GetTupleElement(prev, 0);
779     Lt(iteration, ConstantR0<int32>(&builder, c1));
780     TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build());
781   }
782 
783   XlaComputation condition2;
784   const int c2 = 7;
785   {
786     XlaBuilder builder("condition2");
787     auto prev = Parameter(&builder, 0, result_shape, "prev");
788     auto iteration = GetTupleElement(prev, 0);
789     Lt(iteration, ConstantR0<int32>(&builder, c2));
790     TF_ASSERT_OK_AND_ASSIGN(condition2, builder.Build());
791   }
792 
793   // Create a computation for the body.
794   // Add 1 to the iteration variable and add a constant vector of 1.0f to
795   // the weight variable, both of which are tuple elements.
796   XlaComputation body;
797   {
798     XlaBuilder builder("body");
799     auto prev = Parameter(&builder, 0, result_shape, "prev");
800     auto iteration = GetTupleElement(prev, 0);
801     auto weights = GetTupleElement(prev, 1);
802     auto input = ConstantR1<float>(&builder, 10, 1.f);
803     auto new_weights = Add(weights, input);
804     Tuple(&builder,
805           {Add(iteration, ConstantR0<int32>(&builder, 1)), new_weights});
806     TF_ASSERT_OK_AND_ASSIGN(body, builder.Build());
807   }
808 
809   // Create a While node with computations for the condition and the body.
810   XlaBuilder builder("while");
811   auto init = Tuple(&builder, {ConstantR0<int32>(&builder, 0),
812                                ConstantR1<float>(&builder, 10, 0.f)});
813   auto while1 = While(condition, body, init);
814   auto while2 = While(condition2, body, init);
815 
816   auto while_result1 = GetTupleElement(while1, 1);
817   auto while_result2 = GetTupleElement(while2, 1);
818   VLOG(2) << "while_result2 = "
819           << ShapeUtil::HumanString(
820                  builder.GetShape(while_result2).ConsumeValueOrDie());
821   auto result = Add(while_result1, while_result2);
822   VLOG(2) << "result = "
823           << ShapeUtil::HumanString(
824                  builder.GetShape(result).ConsumeValueOrDie());
825   const float sum = c1 + c2;
826   std::vector<float> expected(10, sum);
827   ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
828 }
829 
830 // WhileTest that uses DynamicUpdateSlice instruction in body computation.
831 // Loop state tuple element 1 has as its single user operand(0) of
832 // DynamicUpdateSlice, which will trigger in-place dynamic slice update on GPU.
XLA_TEST_F(WhileTest,WhileWithDynamicUpdateSlice)833 XLA_TEST_F(WhileTest, WhileWithDynamicUpdateSlice) {
834   std::vector<Shape> shape_elements = {ShapeUtil::MakeShape(S32, {}),
835                                        ShapeUtil::MakeShape(F32, {10})};
836   Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements);
837 
838   // Create a computation for the condition.
839   // Repeat for 5 iterations.
840   XlaComputation condition;
841   {
842     XlaBuilder builder("condition");
843     auto prev = Parameter(&builder, 0, result_shape, "prev");
844     auto iteration = GetTupleElement(prev, 0);
845     Gt(ConstantR0<int32>(&builder, 5), iteration);
846     condition = builder.Build().ConsumeValueOrDie();
847   }
848 
849   // Create a computation for the body.
850   // Add 1 to the iteration variable and add a constant vector of 1.0f to
851   // the weight variable, both of which are tuple elements.
852   XlaComputation body;
853   {
854     XlaBuilder builder("body");
855     auto prev = Parameter(&builder, 0, result_shape, "prev");
856     // TupleElement 0
857     auto iteration = GetTupleElement(prev, 0);
858     auto out0 = Add(iteration, ConstantR0<int32>(&builder, 1));
859     // TupleElement 1
860     auto input = GetTupleElement(prev, 1);
861     // Update.
862     auto update = ConvertElementType(Broadcast(out0, {2}), F32);
863     // Starts = iteration * 2;
864     auto starts = Mul(iteration, ConstantR0<int32>(&builder, 2));
865     // UpdateSlice.
866     auto out1 = DynamicUpdateSlice(input, update, starts);
867 
868     Tuple(&builder, {out0, out1});
869     body = builder.Build().ConsumeValueOrDie();
870   }
871 
872   // Create a While node with computations for the condition and the body.
873   XlaBuilder builder("while");
874   auto init = Tuple(&builder, {ConstantR0<int32>(&builder, 0),
875                                ConstantR1<float>(&builder, 10, 0.f)});
876   auto result = While(condition, body, init);
877   VLOG(2) << "while = "
878           << ShapeUtil::HumanString(
879                  builder.GetShape(result).ConsumeValueOrDie());
880 
881   auto expected_counter = LiteralUtil::CreateR0<int32>(5);
882   auto expected_data = LiteralUtil::CreateR1<float>(
883       {1.0f, 1.0f, 2.0f, 2.0f, 3.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f});
884   auto expected = LiteralUtil::MakeTuple({&expected_counter, &expected_data});
885   VLOG(2) << "expected = " << ShapeUtil::HumanString(expected.shape());
886   ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.0001));
887 }
888 
889 // Tests a while node when the result type T is a vector of S32.
890 //
891 // int32 result = (0, 0, 0, 0, 0, 0);
892 // while (result[0] < count) {
893 //   result += (1, U[0, 100], U[0, 100], U[0, 100], U[0, 100], U[0, 100]);
894 // }
895 //
// This test misuses a vector to represent a
// pair:
898 //   ((iteration, (random vector))).
899 //
900 // Note: this test currently only tests generating random values within a loop.
901 // Per backend the values generated can be different as the different backends
902 // use different random number generators.
903 // TODO(b/32240857): Extend test to verify outputs.
XLA_TEST_F(WhileTest, WhileWithPrngScalarResult) {
  // Loop state: S32[6], where element 0 acts as the iteration counter.
  auto state_shape = ShapeUtil::MakeShape(S32, {6});

  // Builds the loop condition `count > state[0]`, i.e. repeat for `count`
  // iterations.
  auto make_condition = [this, state_shape](int count) {
    XlaBuilder cond_builder(TestName());
    auto state = Parameter(&cond_builder, 0, state_shape, "prev");
    auto first_element = Slice(state, {0}, {1}, {1});
    auto counter = Reshape(first_element, {0}, {});
    Gt(ConstantR0<int32>(&cond_builder, count), counter);
    return cond_builder.Build().ConsumeValueOrDie();
  };

  // Loop body: add (1, r1, ..., r5) to the state, where each r is drawn with
  // RngUniform using bounds 0 and 100.
  XlaComputation loop_body;
  {
    XlaBuilder body_builder("body");
    auto state = Parameter(&body_builder, 0, state_shape, "prev");
    auto counter_step = ConstantR1<int32>(&body_builder, {1});
    auto random_step = RngUniform(ConstantR0<int32>(&body_builder, 0),
                                  ConstantR0<int32>(&body_builder, 100),
                                  ShapeUtil::MakeShape(S32, {5}));
    auto increment =
        ConcatInDim(&body_builder, {counter_step, random_step}, 0);
    Add(increment, state);
    loop_body = body_builder.Build().ConsumeValueOrDie();
  }

  // Builds the full computation: a While over an all-zero initial state.
  auto make_while = [this, &loop_body, make_condition](int count) {
    XlaBuilder while_builder(TestName());
    auto init = ConstantR1<int32>(&while_builder, {0, 0, 0, 0, 0, 0});
    While(make_condition(count), loop_body, init);
    return while_builder.Build();
  };

  for (int trip_count = 1; trip_count < 4; ++trip_count) {
    TF_ASSERT_OK_AND_ASSIGN(auto computation, make_while(trip_count));

    ExecutionOptions options_with_seed = execution_options_;
    options_with_seed.set_seed(65);
    // Only successful execution is checked; the values are not compared.
    TF_ASSERT_OK_AND_ASSIGN(
        auto result,
        client_->ExecuteAndTransfer(computation, {}, &options_with_seed));
  }
}
949 
// Runs a while loop over a tuple whose first element comes from an entry
// parameter; the body replaces both elements with the second tuple element.
XLA_TEST_F(WhileTest, WhileThatSwapsParameterWithTupleElement) {
  auto vec2_shape = ShapeUtil::MakeShape(F32, {2});

  // Initial state: (parameter, {1, 1}).
  XlaBuilder outer_builder("outer");
  auto param = Parameter(&outer_builder, 0, vec2_shape, "param");
  auto ones = ConstantR1<float>(&outer_builder, {1, 1});
  auto init = Tuple(&outer_builder, {param, ones});

  TF_ASSERT_OK_AND_ASSIGN(Shape state_shape, outer_builder.GetShape(init));

  // Condition: continue while any entry of element 0 equals 42.
  XlaBuilder cond_builder("cond");
  auto cond_state = Parameter(&cond_builder, 0, state_shape, "t");
  auto cond_first = GetTupleElement(cond_state, 0);
  auto sentinel = ConstantR1<float>(&cond_builder, {42, 42});
  Any(Eq(cond_first, sentinel));

  // Body: yield (element 1, element 1).
  XlaBuilder body_builder("body");
  auto body_state = Parameter(&body_builder, 0, state_shape, "t");
  auto second = GetTupleElement(body_state, 1);
  Tuple(&body_builder, {second, second});

  TF_ASSERT_OK_AND_ASSIGN(auto cond_computation, cond_builder.Build());
  TF_ASSERT_OK_AND_ASSIGN(auto body_computation, body_builder.Build());
  // Enqueued last on the outer builder, so the While is the result.
  While(cond_computation, body_computation, init);

  // Feeding {42, 42} makes the condition hold on entry; after the first body
  // execution both elements are {1, 1}.
  auto expected_element = LiteralUtil::CreateR1<float>({1, 1});
  auto expected =
      LiteralUtil::MakeTuple({&expected_element, &expected_element});
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<GlobalData> parameter_data,
      client_->TransferToServer(LiteralUtil::CreateR1<float>({42, 42})));
  ComputeAndCompareTuple(&outer_builder, expected, {parameter_data.get()},
                         ErrorSpec(1e-6));
}
981 
// Runs a while loop directly over an F32[2] entry parameter; the body ignores
// its input and yields a broadcast constant instead.
XLA_TEST_F(WhileTest, WhileThatSwapsParameterWithBroadcast) {
  auto vec2_shape = ShapeUtil::MakeShape(F32, {2});

  XlaBuilder outer_builder("outer");
  auto param = Parameter(&outer_builder, 0, vec2_shape, "param");

  // Condition: continue while any entry of the state equals 42.
  XlaBuilder cond_builder("cond");
  auto cond_state = Parameter(&cond_builder, 0, vec2_shape, "t");
  auto sentinel = ConstantR1<float>(&cond_builder, {42, 42});
  Any(Eq(cond_state, sentinel));

  // Body: declare the (unused) state parameter and return {1, 1}.
  XlaBuilder body_builder("body");
  Parameter(&body_builder, 0, vec2_shape, "t");
  auto one = ConstantR0<float>(&body_builder, 1.0);
  Broadcast(one, {2});

  TF_ASSERT_OK_AND_ASSIGN(auto cond_computation, cond_builder.Build());
  TF_ASSERT_OK_AND_ASSIGN(auto body_computation, body_builder.Build());
  // Enqueued last on the outer builder, so the While is the result.
  While(cond_computation, body_computation, param);

  // Feeding {42, 42} makes the condition hold on entry; the body then
  // replaces the state with {1, 1}.
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<GlobalData> parameter_data,
      client_->TransferToServer(LiteralUtil::CreateR1<float>({42, 42})));
  ComputeAndCompareR1<float>(&outer_builder, {1.0f, 1.0f},
                             {parameter_data.get()}, ErrorSpec(1e-6));
}
1006 
// Runs a while loop over a scalar entry parameter whose body routes the
// incremented value through a temporary tuple before returning it.
XLA_TEST_F(WhileTest, WhileThatTurnsScalarParameterToTupleElement) {
  auto scalar_shape = ShapeUtil::MakeShape(F32, {});

  XlaBuilder outer_builder("outer");
  auto param = Parameter(&outer_builder, 0, scalar_shape, "param");

  // Condition: continue while the state equals 42.
  XlaBuilder cond_builder("cond");
  auto cond_state = Parameter(&cond_builder, 0, scalar_shape, "t");
  Eq(cond_state, ConstantR0<float>(&cond_builder, 42));

  // Body: build (t, t + 1) and return its second element.
  XlaBuilder body_builder("body");
  auto body_state = Parameter(&body_builder, 0, scalar_shape, "t");
  auto incremented = Add(body_state, ConstantR0<float>(&body_builder, 1));
  auto pair = Tuple(&body_builder, {body_state, incremented});
  GetTupleElement(pair, 1);

  TF_ASSERT_OK_AND_ASSIGN(auto cond_computation, cond_builder.Build());
  TF_ASSERT_OK_AND_ASSIGN(auto body_computation, body_builder.Build());
  // Enqueued last on the outer builder, so the While is the result.
  While(cond_computation, body_computation, param);

  // Feeding 42 makes the condition hold on entry; one body execution yields 43.
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<GlobalData> parameter_data,
      client_->TransferToServer(LiteralUtil::CreateR0<float>(42)));
  ComputeAndCompareR0<float>(&outer_builder, 43.0f, {parameter_data.get()},
                             ErrorSpec(1e-6));
}
1032 
1033 // Tests loop where the init value comes from two sources (constant and
1034 // parameter).
1035 //
1036 // int32 result = (0, 1);
1037 // while (result[0] + result[1] < 30) {
1038 //   result[0] = result[0] + 1;
1039 //   result[1] = result[1] + 1;
1040 // }
XLA_TEST_F(WhileTest,WhileWithMixedTupleElements)1041 XLA_TEST_F(WhileTest, WhileWithMixedTupleElements) {
1042   auto result_shape = ShapeUtil::MakeTupleShape(
1043       {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(S32, {})});
1044 
1045   XlaBuilder outer("outer");
1046   auto p =
1047       Tuple(&outer, {ConstantR0<int32>(&outer, 0),
1048                      Parameter(&outer, 0, ShapeUtil::MakeShape(S32, {}), "t")});
1049 
1050   XlaBuilder cond("cond");
1051   auto params = Parameter(&cond, 0, result_shape, "prev");
1052   auto cond_t = Add(GetTupleElement(params, 1), GetTupleElement(params, 0));
1053   Lt(cond_t, ConstantR0<int32>(&cond, 30));
1054 
1055   XlaBuilder body("body");
1056   auto body_t = Parameter(&body, 0, result_shape, "t");
1057 
1058   Tuple(&body, {Add(GetTupleElement(body_t, 0), ConstantR0<int32>(&body, 1)),
1059                 Add(GetTupleElement(body_t, 1), ConstantR0<int32>(&body, 1))});
1060 
1061   TF_ASSERT_OK_AND_ASSIGN(auto cond_computation, cond.Build());
1062   TF_ASSERT_OK_AND_ASSIGN(auto body_computation, body.Build());
1063   While(cond_computation, body_computation, p);
1064 
1065   TF_ASSERT_OK_AND_ASSIGN(
1066       std::unique_ptr<GlobalData> parameter_data,
1067       client_->TransferToServer(LiteralUtil::CreateR0<int32>(1)));
1068 
1069   auto add1 = LiteralUtil::CreateR0<int32>(15);
1070   auto add2 = LiteralUtil::CreateR0<int32>(16);
1071   auto expected = LiteralUtil::MakeTuple({&add1, &add2});
1072   ComputeAndCompareTuple(&outer, expected, {parameter_data.get()},
1073                          ErrorSpec(1e-6));
1074 }
1075 
1076 // Tests nested while loops.
1077 //
1078 // int32 result = 0;
1079 // while (result < 30) {
1080 //   int i = 0;
1081 //   while (i < 7) {
1082 //     result = result + 2;
1083 //     i = i + 1;
1084 //   }
1085 // }
XLA_TEST_F(WhileTest,NestedWhileWithScalarResult)1086 XLA_TEST_F(WhileTest, NestedWhileWithScalarResult) {
1087   auto outer_result_shape = ShapeUtil::MakeShape(S32, {});
1088   auto inner_result_shape = ShapeUtil::MakeTupleShape(
1089       {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(S32, {})});
1090 
1091   XlaComputation inner_condition;
1092   {
1093     XlaBuilder builder("inner_condition");
1094     auto params = Parameter(&builder, 0, inner_result_shape, "prev");
1095     auto i = GetTupleElement(params, 0);
1096     Lt(i, ConstantR0<int32>(&builder, 7));
1097     inner_condition = builder.Build().ConsumeValueOrDie();
1098   }
1099 
1100   // Creates a computation for the outer loop condition:
1101   // repeat while result < 30.
1102   XlaComputation outer_condition;
1103   {
1104     XlaBuilder builder("outer_condition");
1105     auto prev = Parameter(&builder, 0, outer_result_shape, "prev");
1106     Lt(prev, ConstantR0<int32>(&builder, 30));
1107     outer_condition = builder.Build().ConsumeValueOrDie();
1108   }
1109 
1110   // Creates a computation for the inner loop body: add 1 to `i`, and add 2 to
1111   // `result`.
1112   XlaComputation inner_body;
1113   {
1114     XlaBuilder builder("inner_body");
1115     auto params = Parameter(&builder, 0, inner_result_shape, "prev");
1116     auto i = GetTupleElement(params, 0);
1117     auto result = GetTupleElement(params, 1);
1118     i = Add(ConstantR0<int32>(&builder, 1), i);
1119     result = Add(ConstantR0<int32>(&builder, 2), result);
1120     Tuple(&builder, {i, result});
1121     inner_body = builder.Build().ConsumeValueOrDie();
1122   }
1123 
1124   // Creates a computation for the outer loop: run the inner loop with i = 0.
1125   XlaComputation outer_body;
1126   {
1127     XlaBuilder builder("outer_body");
1128     auto prev = Parameter(&builder, 0, outer_result_shape, "prev");
1129     auto init = Tuple(&builder, {ConstantR0<int32>(&builder, 0), prev});
1130     auto result = While(inner_condition, inner_body, init);
1131     GetTupleElement(result, 1);
1132     outer_body = builder.Build().ConsumeValueOrDie();
1133   }
1134 
1135   // Create a While node with computations for the condition and the body.
1136   XlaBuilder builder(TestName());
1137   auto init = ConstantR0<int32>(&builder, 0);
1138   While(outer_condition, outer_body, init);
1139 
1140   ComputeAndCompareR0<int32>(&builder, 42, {});
1141 }
1142 
1143 // Tests a while node when the result type T is S32.
1144 // f = lambda result: tuple({result < 5})
1145 // int32 result = 0;
1146 // while (f(result).get<0>()) {
1147 //   result = result + 1;
1148 // }
XLA_TEST_F(WhileTest,WhileWithCallInsideCondition)1149 XLA_TEST_F(WhileTest, WhileWithCallInsideCondition) {
1150   auto result_shape = ShapeUtil::MakeShape(S32, {});
1151 
1152   // Create a computation for the condition: repeat for 5 iterations.
1153   XlaComputation condition_callee;
1154   {
1155     XlaBuilder builder("condition_callee");
1156     auto prev = Parameter(&builder, 0, result_shape, "prev");
1157     Tuple(&builder, {Gt(ConstantR0<int32>(&builder, 5), prev)});
1158 
1159     condition_callee = builder.Build().ConsumeValueOrDie();
1160   }
1161 
1162   XlaComputation condition;
1163   {
1164     XlaBuilder builder("condition");
1165     auto prev = Parameter(&builder, 0, result_shape, "prev");
1166     auto result = Call(&builder, condition_callee, {prev});
1167     GetTupleElement(result, 0);
1168     condition = builder.Build().ConsumeValueOrDie();
1169   }
1170 
1171   // Create a computation for the body: add 1 to the result variable.
1172   XlaComputation body;
1173   {
1174     XlaBuilder builder("body");
1175     auto prev = Parameter(&builder, 0, result_shape, "prev");
1176     auto input = ConstantR0<int32>(&builder, 1);
1177     Add(input, prev);
1178     body = builder.Build().ConsumeValueOrDie();
1179   }
1180 
1181   // Create a While node with computations for the condition and the body.
1182   XlaBuilder builder(TestName());
1183   auto init = ConstantR0<int32>(&builder, 0);
1184   While(condition, body, init);
1185 
1186   ComputeAndCompareR0<int32>(&builder, 5, {});
1187 }
1188 
// Runs a loop whose body recomputes Tanh(Dot(a, b)) from two state elements
// that it passes through unchanged, so the computed value is the same on
// every iteration.
XLA_TEST_F(WhileTest, WhileWithLoopInvariantOperation) {
  auto matrix_shape = ShapeUtil::MakeShape(F32, {2, 2});
  auto counter_shape = ShapeUtil::MakeShape(S32, {});
  // Loop state: (counter, lhs, rhs, output).
  auto loop_state_shape = ShapeUtil::MakeTupleShape(
      {counter_shape, matrix_shape, matrix_shape, matrix_shape});

  // Condition: repeat for 5 iterations.
  XlaComputation condition;
  {
    XlaBuilder cond_builder("condition");
    auto state = Parameter(&cond_builder, 0, loop_state_shape, "state");
    Gt(ConstantR0<int32>(&cond_builder, 5), GetTupleElement(state, 0));
    TF_ASSERT_OK_AND_ASSIGN(condition, cond_builder.Build());
  }

  // Body: forward lhs and rhs untouched, refresh the output and bump the
  // counter.
  XlaComputation body;
  {
    XlaBuilder body_builder("body");
    auto state = Parameter(&body_builder, 0, loop_state_shape, "state");
    auto counter = GetTupleElement(state, 0);
    auto lhs = GetTupleElement(state, 1);
    auto rhs = GetTupleElement(state, 2);
    auto product = Dot(lhs, rhs);
    auto activated = Tanh(product);
    auto next_counter = Add(counter, ConstantR0<int32>(&body_builder, 1));
    Tuple(&body_builder, {next_counter, lhs, rhs, activated});
    TF_ASSERT_OK_AND_ASSIGN(body, body_builder.Build());
  }

  // The same parameter seeds all three matrix slots of the initial state.
  XlaBuilder builder(TestName());
  auto matrix_param = Parameter(&builder, 0, matrix_shape, "matrix");
  auto zero_counter = ConstantR0<int32>(&builder, 0);
  auto init = Tuple(&builder,
                    {zero_counter, matrix_param, matrix_param, matrix_param});
  auto loop = While(condition, body, init);
  // Enqueued last, so the computation returns the output slot.
  GetTupleElement(loop, 3);

  TF_ASSERT_OK_AND_ASSIGN(
      auto param_value, client_->TransferToServer(LiteralUtil::CreateR2<float>(
                            {{1.0, 2.0}, {-1.0, -2.0}})));

  // Expected: tanh of the input matrix multiplied by itself.
  ComputeAndCompareR2<float>(
      &builder, {{-0.76159416, -0.96402758}, {0.76159416, 0.96402758}},
      {param_value.get()}, ErrorSpec(4e-5));
}
1232 
// Runs a loop whose condition is read from the infeed rather than computed
// from the loop state.
XLA_TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileInfeedCondition)) {
  auto counter_shape = ShapeUtil::MakeShape(S32, {});

  // Condition: declare the (unused) state parameter and return a PRED scalar
  // taken from the infeed.
  XlaComputation condition;
  {
    XlaBuilder cond_builder("condition");
    Parameter(&cond_builder, 0, counter_shape, "state");
    Infeed(&cond_builder, ShapeUtil::MakeShape(PRED, {}));
    TF_ASSERT_OK_AND_ASSIGN(condition, cond_builder.Build());
  }

  // Body: increment the counter.
  XlaComputation body;
  {
    XlaBuilder body_builder("body");
    auto counter = Parameter(&body_builder, 0, counter_shape, "state");
    Add(counter, ConstantR0<int32>(&body_builder, 1));
    TF_ASSERT_OK_AND_ASSIGN(body, body_builder.Build());
  }

  XlaBuilder builder(TestName());
  auto init = ConstantR0<int32>(&builder, 0);
  While(condition, body, init);

  // Queue the condition values before running: two `true`s followed by a
  // `false` let the body execute exactly twice.
  for (bool keep_going : {true, true, false}) {
    TF_ASSERT_OK(
        client_->TransferToInfeed(LiteralUtil::CreateR0<bool>(keep_going)));
  }

  ComputeAndCompareR0<int32>(&builder, 2, {});
}
1261 
BM_WhileLoop(int num_iters)1262 void BM_WhileLoop(int num_iters) {
1263   // Benchmark a simple kernel to measure while loop overheads.
1264   tensorflow::testing::StopTiming();
1265 
1266   se::Platform* platform = PlatformUtil::GetDefaultPlatform().ValueOrDie();
1267   auto executors = PlatformUtil::GetStreamExecutors(platform).ValueOrDie();
1268   StreamExecutorMemoryAllocator allocator(platform, executors);
1269   LocalClient* client =
1270       ClientLibrary::GetOrCreateLocalClient(platform).ValueOrDie();
1271 
1272   const int64 seq_len = 100;
1273   Shape loop_state_shape = ShapeUtil::MakeTupleShape(
1274       {ShapeUtil::MakeShape(S32, {}),
1275        ShapeUtil::MakeShape(F32, {seq_len, 1024, 1024})});
1276 
1277   // Create while condition computation with 'loop_limit'.
1278   const int32 loop_limit = 100;
1279   XlaComputation condition;
1280   {
1281     XlaBuilder builder("condition");
1282     auto prev = Parameter(&builder, 0, loop_state_shape, "prev");
1283     auto iteration = GetTupleElement(prev, 0);
1284     Lt(iteration, ConstantR0<int32>(&builder, loop_limit));
1285     condition = builder.Build().ConsumeValueOrDie();
1286   }
1287 
1288   // Create while body computation with unit loop increment.
1289   XlaComputation body;
1290   {
1291     XlaBuilder builder("body");
1292     auto prev = Parameter(&builder, 0, loop_state_shape, "prev");
1293     // TupleElement 0
1294     auto iteration = GetTupleElement(prev, 0);
1295     auto out0 = Add(iteration, ConstantR0<int32>(&builder, 1));
1296     // TupleElement 1
1297     auto input = GetTupleElement(prev, 1);
1298     // Update.
1299     auto one = ConstantR0<float>(&builder, 1.0);
1300     auto update = Broadcast(one, {1, 1024, 1024});
1301     // Starts = iteration * 2;
1302     auto zero = ConstantR0<int32>(&builder, 0);
1303     // UpdateSlice.
1304     auto out1 = DynamicUpdateSlice(input, update, {zero, zero, zero});
1305     Tuple(&builder, {out0, out1});
1306     body = builder.Build().ConsumeValueOrDie();
1307   }
1308 
1309   // Create a While instruction.
1310   XlaBuilder builder("while");
1311   auto zero = ConstantR0<float>(&builder, 0.0);
1312   auto input = Broadcast(zero, {seq_len, 1024, 1024});
1313   auto init = Tuple(&builder, {ConstantR0<int32>(&builder, 0), input});
1314   While(condition, body, init);
1315   auto computation = builder.Build().ConsumeValueOrDie();
1316 
1317   std::unique_ptr<LocalExecutable> executable =
1318       client->Compile(computation, {}, ExecutableBuildOptions())
1319           .ConsumeValueOrDie();
1320 
1321   // Run some warm-up executions.
1322   ExecutableRunOptions options;
1323   options.set_allocator(&allocator);
1324   const int kWarmups = 2;
1325   for (int i = 0; i < kWarmups; ++i) {
1326     auto result = executable->Run({}, options);
1327     ASSERT_TRUE(result.ok());
1328   }
1329 
1330   // Run benchmark.
1331   tensorflow::testing::StartTiming();
1332   for (int i = 0; i < num_iters; ++i) {
1333     auto result = executable->Run({}, options);
1334     ASSERT_TRUE(result.ok());
1335   }
1336 }
1337 
1338 BENCHMARK(BM_WhileLoop);
1339 }  // namespace
1340 }  // namespace xla
1341