1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <algorithm>
17 #include <memory>
18 #include <string>
19 #include <vector>
20 
21 #include "tensorflow/compiler/xla/array2d.h"
22 #include "tensorflow/compiler/xla/client/global_data.h"
23 #include "tensorflow/compiler/xla/client/local_client.h"
24 #include "tensorflow/compiler/xla/client/xla_builder.h"
25 #include "tensorflow/compiler/xla/client/xla_computation.h"
26 #include "tensorflow/compiler/xla/layout_util.h"
27 #include "tensorflow/compiler/xla/literal.h"
28 #include "tensorflow/compiler/xla/shape_util.h"
29 #include "tensorflow/compiler/xla/statusor.h"
30 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
31 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
32 #include "tensorflow/compiler/xla/tests/test_macros.h"
33 #include "tensorflow/compiler/xla/xla_data.pb.h"
34 #include "tensorflow/core/platform/protobuf.h"
35 #include "tensorflow/core/platform/test.h"
36 #include "tensorflow/core/platform/types.h"
37 
38 namespace xla {
39 namespace {
40 
41 class ParamsTest : public ClientLibraryTestBase {};
42 
XLA_TEST_F(ParamsTest, ConstantR0F32Param) {
  // A scalar f32 argument fed straight through a parameter must come back
  // unchanged.
  constexpr float kValue = 3.14159f;
  XlaBuilder builder(TestName());
  Literal literal = LiteralUtil::CreateR0<float>(kValue);
  auto device_arg = client_->TransferToServer(literal).ConsumeValueOrDie();

  Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "param0");

  ComputeAndCompareR0<float>(&builder, kValue, {device_arg.get()},
                             ErrorSpec(0.0001f));
}
54 
XLA_TEST_F(ParamsTest, ConstantR1S0F32Param) {
  // A zero-length rank-1 f32 argument must round-trip through a parameter.
  XlaBuilder builder(TestName());
  Literal empty_vector = LiteralUtil::CreateR1<float>({});
  auto device_arg =
      client_->TransferToServer(empty_vector).ConsumeValueOrDie();

  const Shape empty_r1_shape = ShapeUtil::MakeShape(F32, {0});
  Parameter(&builder, 0, empty_r1_shape, "param0");

  ComputeAndCompareR1<float>(&builder, {}, {device_arg.get()},
                             ErrorSpec(0.01f));
}
66 
XLA_TEST_F(ParamsTest, ConstantR1S2F32Param) {
  // A two-element rank-1 f32 argument must round-trip through a parameter.
  const std::vector<float> values = {3.14f, -100.25f};
  XlaBuilder builder(TestName());
  Literal literal = LiteralUtil::CreateR1<float>(values);
  auto device_arg = client_->TransferToServer(literal).ConsumeValueOrDie();

  Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2}), "param0");

  ComputeAndCompareR1<float>(&builder, values, {device_arg.get()},
                             ErrorSpec(0.01f));
}
78 
XLA_TEST_F(ParamsTest, ConstantR1U8Param) {
  // A u8 vector built from the bytes of a string must round-trip through a
  // parameter whose length matches the string.
  const string text("hello world");
  const int64 length = static_cast<int64>(text.size());

  XlaBuilder builder(TestName());
  Literal literal = LiteralUtil::CreateR1U8(text);
  auto device_arg = client_->TransferToServer(literal).ConsumeValueOrDie();

  Parameter(&builder, 0, ShapeUtil::MakeShape(U8, {length}), "param0");

  ComputeAndCompareR1U8(&builder, text, {device_arg.get()});
}
92 
XLA_TEST_F(ParamsTest, ConstantR2_3x0_F32Param) {
  // A 3x0 (element-free) rank-2 f32 argument must round-trip through a
  // parameter.
  const Array2D<float> empty_3x0(3, 0);
  XlaBuilder builder(TestName());
  Literal literal = LiteralUtil::CreateR2FromArray2D<float>(empty_3x0);
  auto device_arg = client_->TransferToServer(literal).ConsumeValueOrDie();

  Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3, 0}), "param0");

  ComputeAndCompareR2<float>(&builder, empty_3x0, {device_arg.get()},
                             ErrorSpec(0.01f));
}
105 
XLA_TEST_F(ParamsTest, ConstantR2F32Param) {
  // A 3x2 f32 argument mixing very large and very small magnitudes must
  // round-trip through a parameter.
  const Array2D<float> expected(
      {{3.14f, -100.25f}, {7e8f, 7e-9f}, {30.3f, -100.0f}});

  XlaBuilder builder(TestName());
  Literal literal = LiteralUtil::CreateR2<float>(
      {{3.14f, -100.25f}, {7e8f, 7e-9f}, {30.3f, -100.0f}});
  auto device_arg = client_->TransferToServer(literal).ConsumeValueOrDie();

  Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3, 2}), "param0");

  ComputeAndCompareR2<float>(&builder, expected, {device_arg.get()},
                             ErrorSpec(0.01f));
}
120 
XLA_TEST_F(ParamsTest, TwoParameters) {
  // Feeds two rank-1 f32 parameters into an asymmetric expression,
  // (param0 + param1) * param1, so that swapped parameters would be detected.
  XlaBuilder builder(TestName());

  Literal literal0 = LiteralUtil::CreateR1<float>({1, 2});
  std::unique_ptr<GlobalData> param0_data =
      client_->TransferToServer(literal0).ConsumeValueOrDie();
  auto param0 = Parameter(&builder, 0, literal0.shape(), "param0");

  Literal literal1 = LiteralUtil::CreateR1<float>({10, 20});
  std::unique_ptr<GlobalData> param1_data =
      client_->TransferToServer(literal1).ConsumeValueOrDie();
  auto param1 = Parameter(&builder, 1, literal1.shape(), "param1");

  // Use both parameters.
  //
  // {1, 2} + {10, 20} = {11, 22}
  // A single Add is enough. The previous code built this Add twice and left
  // the first copy as a dead instruction.
  auto sum = Add(param0, param1);

  // Use only the second parameter again, to show that it can be used
  // twice and to make the computation asymmetric in the two
  // parameters to test that the parameters are not swapped.
  //
  // {11, 22} * {10, 20} = {110, 440}
  Mul(sum, param1);

  ComputeAndCompareR1<float>(&builder, {110, 440},
                             {param0_data.get(), param1_data.get()},
                             ErrorSpec(0.0001f));
}
151 
XLA_TEST_F(ParamsTest, MissingParameter) {
  // Test that building a computation with an incomplete set of parameters
  // (parameter numbers not contiguous from 0; here only parameter 2 exists)
  // returns an error. The failure is detected at Build() time, so no
  // argument data is needed. The previous code transferred an unused
  // literal to the server.
  XlaBuilder builder(TestName());
  Parameter(&builder, 2, ShapeUtil::MakeShape(F32, {}), "param2");
  auto computation_status = builder.Build();

  ASSERT_NE(computation_status.status(), Status::OK());
}
165 
XLA_TEST_F(ParamsTest, UnusedParameter) {
  // Parameter 0 is declared but never read. The last instruction added, and
  // therefore the computation's result, is parameter 1, so the output must
  // equal the second argument.
  XlaBuilder builder(TestName());

  Literal first = LiteralUtil::CreateR1<float>({1, 2});
  std::unique_ptr<GlobalData> first_arg =
      client_->TransferToServer(first).ConsumeValueOrDie();
  Literal second = LiteralUtil::CreateR1<float>({10, 20});
  std::unique_ptr<GlobalData> second_arg =
      client_->TransferToServer(second).ConsumeValueOrDie();

  Parameter(&builder, 0, first.shape(), "param0");
  Parameter(&builder, 1, second.shape(), "param1");

  ComputeAndCompareR1<float>(&builder, {10, 20},
                             {first_arg.get(), second_arg.get()},
                             ErrorSpec(0.0001f));
}
183 
XLA_TEST_F(ParamsTest, UnusedParametersInUnusedExpression) {
  // Parameters 1 and 2 feed only an Add whose result is never consumed; the
  // computation's result is Neg(param0). Parameters 1 and 2 share a shape, so
  // the same device buffer is passed for both.
  XlaBuilder builder(TestName());

  Literal two_elems = LiteralUtil::CreateR1<float>({1, 2});
  std::unique_ptr<GlobalData> two_elems_arg =
      client_->TransferToServer(two_elems).ConsumeValueOrDie();

  Literal three_elems = LiteralUtil::CreateR1<float>({10, 20, 30});
  std::unique_ptr<GlobalData> three_elems_arg =
      client_->TransferToServer(three_elems).ConsumeValueOrDie();

  XlaOp p0 = Parameter(&builder, 0, two_elems.shape(), "param0");
  XlaOp p1 = Parameter(&builder, 1, three_elems.shape(), "param1");
  XlaOp p2 = Parameter(&builder, 2, three_elems.shape(), "param2");

  Add(p1, p2);  // Dead expression: never used by anything downstream.
  Neg(p0);

  ComputeAndCompareR1<float>(
      &builder, {-1, -2},
      {two_elems_arg.get(), three_elems_arg.get(), three_elems_arg.get()},
      ErrorSpec(0.0001f));
}
211 
XLA_TEST_F(ParamsTest, HundredLargeR1Parameters) {
  // Adds 100 large f32 vector parameters onto a constant vector. In every
  // vector only the first two elements are non-zero, so the expected result
  // is easy to track on the host alongside the graph.
  XlaBuilder builder(TestName());
  constexpr int kSize = 8 * 128 * 2;
  constexpr int kParameterCount = 100;

  // Host-side expected value; it starts out equal to the constant operand.
  std::vector<float> expected(kSize, 0.0f);
  expected[1] = 1.0f;
  XlaOp accumulator = ConstantR1<float>(&builder, expected);

  std::vector<std::unique_ptr<GlobalData>> owned_args;
  owned_args.reserve(kParameterCount);
  for (int i = 0; i < kParameterCount; ++i) {
    const float head0 = i;
    const float head1 = 2 * i;
    expected[0] += head0;
    expected[1] += head1;

    std::vector<float> values(kSize, 0.0f);
    values[0] = head0;
    values[1] = head1;
    Literal literal = LiteralUtil::CreateR1<float>(values);
    owned_args.push_back(
        client_->TransferToServer(literal).ConsumeValueOrDie());
    XlaOp param = Parameter(&builder, i, literal.shape(), "param");
    accumulator = Add(accumulator, param);
  }

  std::vector<GlobalData*> args;
  args.reserve(owned_args.size());
  for (const auto& owned : owned_args) {
    args.push_back(owned.get());
  }

  ComputeAndCompareR1<float>(&builder, expected, args, ErrorSpec(0.0001f));
}
248 
// Only run the many-parameter tests below (1,900 to 3,000 parameters) in opt
// mode to avoid test timeouts. Timeout last observed on 2017-11-20.
251 #ifdef NDEBUG
252 
253 // TODO(b/65525254) Fails on GPU on 2017-09-10 because we try to reserve too
254 // much space in parameter memory for the kernel.
255 //
256 // TODO(b/65526061) Failed on CPU on 2017-09-10 due to timeout in LLVM
257 // compilation.
XLA_TEST_F(ParamsTest,
           DISABLED_ON_CPU(DISABLED_ON_GPU(ThreeThousandParameters))) {
  // Sums 3000 scalar f32 parameters, where parameter i holds the value i, and
  // checks the total against a host-side running sum.
  constexpr int kParamCount = 3000;
  XlaBuilder builder(TestName());

  std::vector<std::unique_ptr<GlobalData>> owned_args;
  owned_args.reserve(kParamCount);
  XlaOp running_sum = ConstantR0<float>(&builder, 0.0f);
  float expected = 0.0f;
  for (int i = 0; i < kParamCount; ++i) {
    const float value = static_cast<float>(i);
    expected += value;
    Literal literal = LiteralUtil::CreateR0<float>(value);
    owned_args.push_back(
        client_->TransferToServer(literal).ConsumeValueOrDie());
    XlaOp param = Parameter(&builder, i, literal.shape(), "param");
    running_sum = Add(running_sum, param);
  }

  std::vector<GlobalData*> args;
  args.reserve(owned_args.size());
  for (const auto& owned : owned_args) {
    args.push_back(owned.get());
  }

  ComputeAndCompareR0<float>(&builder, expected, args, ErrorSpec(0.0001f));
}
283 
284 // TODO(b/65525254) Fails on GPU on 2017-09-10 because we try to reserve too
285 // much space in parameter memory for the kernel.
286 //
287 // TODO(b/65526061) Failed on CPU on 2017-09-10 due to timeout in LLVM
288 // compilation.
XLA_TEST_F(ParamsTest, DISABLED_ON_CPU(DISABLED_ON_GPU(
                           ThreeThousandParametersAndOutputElements))) {
  // Parameter i is the int32 vector {i, i}. All parameters are summed into
  // `total`, and output i is params[i] + total. With
  // target = 0 + 1 + ... + (kParamCount - 1), output i is therefore
  // {target + i, target + i}.
  constexpr int kParamCount = 3000;
  XlaBuilder builder(TestName());

  std::vector<std::unique_ptr<GlobalData>> owned_args;
  owned_args.reserve(kParamCount);
  std::vector<XlaOp> params;
  params.reserve(kParamCount);
  XlaOp total = ConstantR1<int32>(&builder, {0, 0});
  int32 target = 0;
  for (int i = 0; i < kParamCount; ++i) {
    target += i;
    Literal literal = LiteralUtil::CreateR1<int32>({i, i});
    owned_args.push_back(
        client_->TransferToServer(literal).ConsumeValueOrDie());
    params.push_back(Parameter(&builder, i, literal.shape(), "param"));
    total = Add(total, params.back());
  }

  std::vector<XlaOp> outputs;
  outputs.reserve(kParamCount);
  for (const XlaOp& param : params) {
    outputs.push_back(Add(param, total));
  }
  Tuple(&builder, outputs);

  std::vector<GlobalData*> args;
  args.reserve(owned_args.size());
  for (const auto& owned : owned_args) {
    args.push_back(owned.get());
  }

  // The reserve() is required: expected_ptrs keeps pointers into
  // expected_elements, so that vector must never reallocate.
  std::vector<Literal> expected_elements;
  expected_elements.reserve(kParamCount);
  std::vector<const Literal*> expected_ptrs;
  expected_ptrs.reserve(kParamCount);
  for (int i = 0; i < kParamCount; ++i) {
    expected_elements.push_back(
        LiteralUtil::CreateR1<int32>({target + i, target + i}));
    expected_ptrs.push_back(&expected_elements.back());
  }
  ComputeAndCompareTuple(&builder, LiteralUtil::MakeTuple(expected_ptrs),
                         args);
}
330 
331 // Test large number of parameters flowing into a while-loop.
332 // Construct conceptually the following HLO graph:
333 //
334 // p0 = parameter(0)
335 // p1 = parameter(1)
336 // ...
337 // pN = parameter(N)
338 // result = while (false) {
339 //   p0 += (1, 1);
340 //   p1 += (1, 1);
341 //   ...
342 //   pN += (1, 1)
343 // }
344 // result = {p0, p1, ..., pN}
345 //
346 // TODO(b/70173746): Times out during compilation on GPU and CPU backends as of
347 // 2017-12-12.
XLA_TEST_F(ParamsTest,DISABLED_ON_CPU (DISABLED_ON_GPU (ManyParametersIntoWhileLoop)))348 XLA_TEST_F(ParamsTest,
349            DISABLED_ON_CPU(DISABLED_ON_GPU(ManyParametersIntoWhileLoop))) {
350   XlaBuilder builder(TestName());
351 
352   std::vector<std::unique_ptr<GlobalData>> param_data_owner;
353   constexpr int kParamCount = 1900;
354   std::vector<XlaOp> params;
355   std::vector<Shape> parameter_shapes;
356   for (int i = 0; i < kParamCount; ++i) {
357     Literal literal = LiteralUtil::CreateR1<int32>({i, i});
358     param_data_owner.push_back(
359         std::move(client_->TransferToServer(literal)).ValueOrDie());
360     XlaOp param = Parameter(&builder, i, literal.shape(), "param");
361     params.push_back(param);
362     parameter_shapes.push_back(literal.shape());
363   }
364 
365   // Add bool parameter for the loop condition. Use a parameter HLO instead of a
366   // constant because DCE may eliminate the while-body otherwise.
367   Literal bool_literal = LiteralUtil::CreateR0<bool>(false);
368   param_data_owner.push_back(
369       std::move(client_->TransferToServer(bool_literal)).ValueOrDie());
370   XlaOp bool_param =
371       Parameter(&builder, kParamCount, bool_literal.shape(), "bool_param");
372   params.push_back(bool_param);
373   parameter_shapes.push_back(bool_literal.shape());
374 
375   auto init = Tuple(&builder, params);
376 
377   // Create a computation for the condition: while(bool_param).
378   Shape while_shape = ShapeUtil::MakeTupleShape(parameter_shapes);
379   XlaComputation condition;
380   {
381     XlaBuilder builder("condition");
382     auto condition_parameter =
383         Parameter(&builder, 0, while_shape, "condition_parameter");
384     GetTupleElement(condition_parameter, kParamCount);
385     condition = builder.Build().ConsumeValueOrDie();
386   }
387 
388   // Create a computation for the body.
389   // Add {1, 1} to the each tuple element.
390   XlaComputation body;
391   {
392     XlaBuilder builder("body");
393     auto body_parameter = Parameter(&builder, 0, while_shape, "body_parameter");
394     std::vector<XlaOp> updates;
395     for (int i = 0; i < kParamCount; ++i) {
396       auto add = Add(GetTupleElement(body_parameter, i),
397                      ConstantR1<int32>(&builder, {1, 1}));
398       updates.push_back(add);
399     }
400     // Add bool parameter.
401     updates.push_back(GetTupleElement(body_parameter, kParamCount));
402 
403     Tuple(&builder, updates);
404     body = builder.Build().ConsumeValueOrDie();
405   }
406 
407   auto loop = While(condition, body, init);
408 
409   std::vector<XlaOp> outputs;
410   for (int i = 0; i < kParamCount; ++i) {
411     outputs.push_back(GetTupleElement(loop, i));
412   }
413   Tuple(&builder, outputs);
414 
415   std::vector<GlobalData*> param_data;
416   param_data.reserve(param_data_owner.size());
417   for (const std::unique_ptr<GlobalData>& data : param_data_owner) {
418     param_data.push_back(data.get());
419   }
420 
421   std::vector<Literal> elements;
422   std::vector<const Literal*> ptrs;
423   elements.reserve(kParamCount);
424   for (int i = 0; i < kParamCount; ++i) {
425     elements.push_back(LiteralUtil::CreateR1<int32>({i, i}));
426     ptrs.push_back(&elements.back());
427   }
428   ComputeAndCompareTuple(&builder, LiteralUtil::MakeTuple(ptrs), param_data);
429 }
430 
431 #endif
432 
XLA_TEST_F(ParamsTest, TupleOfR1ParametersAddedTogether) {
  // A single tuple-shaped parameter holds two f32[3] vectors. The computation
  // unpacks both elements and adds them.
  XlaBuilder builder(TestName());

  const Shape vec3 = ShapeUtil::MakeShape(F32, {3});
  const Shape pair_of_vec3 = ShapeUtil::MakeTupleShape({vec3, vec3});
  XlaOp tuple_param = Parameter(&builder, 0, pair_of_vec3, "input");
  XlaOp first = GetTupleElement(tuple_param, 0);
  XlaOp second = GetTupleElement(tuple_param, 1);
  Add(first, second);

  Literal tuple_literal = LiteralUtil::MakeTupleFromSlices({
      LiteralUtil::CreateR1<float>({1, 2, 3}),
      LiteralUtil::CreateR1<float>({4, 5, 6}),
  });
  std::unique_ptr<GlobalData> tuple_arg =
      client_->TransferToServer(tuple_literal).ConsumeValueOrDie();

  const std::vector<float> expected = {1 + 4, 2 + 5, 3 + 6};
  std::vector<GlobalData*> args = {tuple_arg.get()};
  ComputeAndCompareR1<float>(&builder, expected, args, ErrorSpec(1e-5));
}
455 
// Verifies that a 2x2 literal with {0, 1} layout comes back unchanged after
// being transferred to the server and passed through a parameter.
XLA_TEST_F(ParamsTest, R2_2x2_Layout_01) {
  // Identity computation over a 2x2 f32 parameter whose shape carries the
  // minor-to-major layout {0, 1}.
  XlaBuilder builder(TestName());
  const Literal input_literal = LiteralUtil::CreateR2WithLayout<float>(
      {{1, 2}, {3, 4}}, LayoutUtil::MakeLayout({0, 1}));
  Parameter(&builder, 0, input_literal.shape(), "input");

  auto device_arg =
      client_->TransferToServer(input_literal).ConsumeValueOrDie();
  ComputeAndCompareLiteral(&builder, input_literal, {device_arg.get()},
                           ErrorSpec(1e-3));
}
468 
469 // As above, but for {1, 0} layout.
XLA_TEST_F(ParamsTest, R2_2x2_Layout_10) {
  // Identity computation over a 2x2 f32 parameter whose shape carries the
  // minor-to-major layout {1, 0}.
  XlaBuilder builder(TestName());
  const Literal input_literal = LiteralUtil::CreateR2WithLayout<float>(
      {{1, 3}, {2, 4}}, LayoutUtil::MakeLayout({1, 0}));
  Parameter(&builder, 0, input_literal.shape(), "input");

  auto device_arg =
      client_->TransferToServer(input_literal).ConsumeValueOrDie();
  ComputeAndCompareLiteral(&builder, input_literal, {device_arg.get()},
                           ErrorSpec(1e-3));
}
480 
XLA_TEST_F(ParamsTest, R2_2x2_TryToPassReverseLayoutToParameter) {
  Literal literal = LiteralUtil::CreateR2<float>({
      {1, 3},
      {2, 4},
  });
  const Shape original = literal.shape();

  // Overwrite the literal's layout with the reverse of its original
  // minor-to-major order. The underlying buffer is not rearranged, so the
  // elements are now read back transposed: logical element (0, 1) reads as 2.
  {
    const auto& minor_to_major = original.layout().minor_to_major();
    std::vector<int64> reversed_order(minor_to_major.rbegin(),
                                      minor_to_major.rend());
    *literal.mutable_shape_do_not_use()->mutable_layout() =
        LayoutUtil::MakeLayout(reversed_order);
    ASSERT_EQ(2, literal.Get<float>({0, 1}));
  }

  // The computation is built against the ORIGINAL shape, not the literal's
  // reversed-layout shape.
  XlaBuilder builder(TestName());
  XlaOp input = Parameter(&builder, 0, original, "input");
  // Slice out the single off-diagonal element at (0, 1).
  Slice(input, {0, 1}, {1, 2}, {1, 1});

  std::unique_ptr<GlobalData> device_arg =
      client_->TransferToServer(literal).ConsumeValueOrDie();
  // The off-diagonal value must be the one the literal reported above.
  Array2D<float> expected(1, 1);
  expected(0, 0) = 2;
  ComputeAndCompareR2(&builder, expected, {device_arg.get()}, ErrorSpec(1e-3));
}
511 
512 }  // namespace
513 }  // namespace xla
514