1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include <limits>
#include <memory>
#include <numeric>
#include <vector>

#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/status_test_util.h"
32 
33 namespace xla {
34 namespace {
35 
36 using TriangularSolveTest = ClientLibraryTestBase;
37 using TriangularSolveLeftLookingTest = ClientLibraryTestBase;
38 
39 static constexpr float kNan = std::numeric_limits<float>::quiet_NaN();
40 
AValsLower()41 Array2D<float> AValsLower() {
42   return {{2, kNan, kNan, kNan},
43           {3, 6, kNan, kNan},
44           {4, 7, 9, kNan},
45           {5, 8, 10, 11}};
46 }
47 
AValsUpper()48 Array2D<float> AValsUpper() {
49   return {{2, 3, 4, 5},
50           {kNan, 6, 7, 8},
51           {kNan, kNan, 9, 10},
52           {kNan, kNan, kNan, 11}};
53 }
54 
AValsLowerUnitDiagonal()55 Array2D<float> AValsLowerUnitDiagonal() {
56   return {{kNan, kNan, kNan, kNan},
57           {3, kNan, kNan, kNan},
58           {4, 7, kNan, kNan},
59           {5, 8, 10, kNan}};
60 }
61 
AValsUpperUnitDiagonal()62 Array2D<float> AValsUpperUnitDiagonal() {
63   return {{kNan, 3, 4, 5},
64           {kNan, kNan, 7, 8},
65           {kNan, kNan, kNan, 10},
66           {kNan, kNan, kNan, kNan}};
67 }
68 
BValsRight()69 Array2D<float> BValsRight() {
70   return {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}};
71 }
72 
BValsLeft()73 Array2D<float> BValsLeft() {
74   return {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}, {10, 11, 12}};
75 }
76 
77 static constexpr complex64 kNanC64 = complex64(kNan, kNan);
78 
AValsLowerComplex()79 Array2D<complex64> AValsLowerComplex() {
80   return {{2, kNanC64, kNanC64, kNanC64},
81           {complex64(3, 1), 6, kNanC64, kNanC64},
82           {4, complex64(7, 2), 9, kNanC64},
83           {5, 8, complex64(10, 3), 11}};
84 }
85 
AValsUpperComplex()86 Array2D<complex64> AValsUpperComplex() {
87   return {{2, 3, complex64(4, 3), 5},
88           {kNanC64, 6, complex64(7, 2), 8},
89           {kNanC64, kNanC64, complex64(9, 1), 10},
90           {kNanC64, kNanC64, kNanC64, 11}};
91 }
92 
BValsRightComplex()93 Array2D<complex64> BValsRightComplex() {
94   return {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}};
95 }
96 
BValsLeftComplex()97 Array2D<complex64> BValsLeftComplex() {
98   return {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}, {10, 11, 12}};
99 }
100 
// A 0x0 coefficient matrix with a 0x10 right-hand side must produce an empty
// 0x10 solution instead of failing.
XLA_TEST_F(TriangularSolveTest, EmptyArrays) {
  XlaBuilder builder(TestName());

  const Array2D<float> empty_a(0, 0);
  // The solution has B's shape, so the empty B doubles as the expected value.
  const Array2D<float> empty_b(0, 10);

  XlaOp a;
  XlaOp b;
  auto a_handle = CreateR2Parameter<float>(empty_a, 0, "a", &builder, &a);
  auto b_handle = CreateR2Parameter<float>(empty_b, 1, "b", &builder, &b);
  TriangularSolve(a, b, /*left_side=*/true, /*lower=*/true,
                  /*unit_diagonal=*/false,
                  /*transpose_a=*/TriangularSolveOptions::TRANSPOSE);

  ComputeAndCompareR2<float>(&builder, empty_b,
                             {a_handle.get(), b_handle.get()});
}
117 
// Right-side solve X * A^T = B with lower-triangular A. The NaN upper triangle
// of the fixture must not influence the result.
XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTranspose) {
  const Array2D<float> expected({
      {0.5, 0.08333334, 0.04629629, 0.03367003},
      {2.5, -0.25, -0.1388889, -0.1010101},
      {4.5, -0.58333331, -0.32407406, -0.23569024},
  });

  XlaBuilder builder(TestName());
  XlaOp a;
  XlaOp b;
  auto a_handle = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a);
  auto b_handle = CreateR2Parameter<float>(BValsRight(), 1, "b", &builder, &b);
  TriangularSolve(a, b, /*left_side=*/false, /*lower=*/true,
                  /*unit_diagonal=*/false,
                  /*transpose_a=*/TriangularSolveOptions::TRANSPOSE);

  const ErrorSpec tolerance(1e-2, 1e-2);
  ComputeAndCompareR2<float>(&builder, expected,
                             {a_handle.get(), b_handle.get()}, tolerance);
}
138 
// Right-side solve X * A = B with lower-triangular A (no transpose).
XLA_TEST_F(TriangularSolveTest, SimpleRightLowerNotranspose) {
  const Array2D<float> expected({
      {-0.16414141, -0.06902357, -0.07070707, 0.36363636},
      {0.64393939, 0.06565657, -0.03030303, 0.72727273},
      {1.4520202, 0.2003367, 0.01010101, 1.09090909},
  });

  XlaBuilder builder(TestName());
  XlaOp a;
  XlaOp b;
  auto a_handle = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a);
  auto b_handle = CreateR2Parameter<float>(BValsRight(), 1, "b", &builder, &b);
  TriangularSolve(a, b, /*left_side=*/false, /*lower=*/true,
                  /*unit_diagonal=*/false,
                  /*transpose_a=*/TriangularSolveOptions::NO_TRANSPOSE);

  const ErrorSpec tolerance(1e-2, 1e-2);
  ComputeAndCompareR2<float>(&builder, expected,
                             {a_handle.get(), b_handle.get()}, tolerance);
}
159 
// Right-side solve X * A^T = B with upper-triangular A. AValsUpper() is the
// transpose of AValsLower(), so the answer matches SimpleRightLowerNotranspose.
XLA_TEST_F(TriangularSolveTest, SimpleRightUpperTranspose) {
  const Array2D<float> expected({
      {-0.16414141, -0.06902357, -0.07070707, 0.36363636},
      {0.64393939, 0.06565657, -0.03030303, 0.72727273},
      {1.4520202, 0.2003367, 0.01010101, 1.09090909},
  });

  XlaBuilder builder(TestName());
  XlaOp a;
  XlaOp b;
  auto a_handle = CreateR2Parameter<float>(AValsUpper(), 0, "a", &builder, &a);
  auto b_handle = CreateR2Parameter<float>(BValsRight(), 1, "b", &builder, &b);
  TriangularSolve(a, b, /*left_side=*/false, /*lower=*/false,
                  /*unit_diagonal=*/false,
                  /*transpose_a=*/TriangularSolveOptions::TRANSPOSE);

  const ErrorSpec tolerance(1e-2, 1e-2);
  ComputeAndCompareR2<float>(&builder, expected,
                             {a_handle.get(), b_handle.get()}, tolerance);
}
180 
// Right-side solve X * A = B with upper-triangular A. Since AValsUpper() is the
// transpose of AValsLower(), the answer matches SimpleRightLowerTranspose.
XLA_TEST_F(TriangularSolveTest, SimpleRightUpperNotranspose) {
  const Array2D<float> expected({
      {0.5, 0.08333334, 0.04629629, 0.03367003},
      {2.5, -0.25, -0.1388889, -0.1010101},
      {4.5, -0.58333331, -0.32407406, -0.23569024},
  });

  XlaBuilder builder(TestName());
  XlaOp a;
  XlaOp b;
  auto a_handle = CreateR2Parameter<float>(AValsUpper(), 0, "a", &builder, &a);
  auto b_handle = CreateR2Parameter<float>(BValsRight(), 1, "b", &builder, &b);
  TriangularSolve(a, b, /*left_side=*/false, /*lower=*/false,
                  /*unit_diagonal=*/false,
                  /*transpose_a=*/TriangularSolveOptions::NO_TRANSPOSE);

  const ErrorSpec tolerance(1e-2, 1e-2);
  ComputeAndCompareR2<float>(&builder, expected,
                             {a_handle.get(), b_handle.get()}, tolerance);
}
201 
// Left-side solve A^T * X = B with lower-triangular A.
XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerTranspose) {
  const Array2D<float> expected({
      {-0.89646465, -0.69444444, -0.49242424},
      {-0.27441077, -0.24074074, -0.20707071},
      {-0.23232323, -0.22222222, -0.21212121},
      {0.90909091, 1., 1.09090909},
  });

  XlaBuilder builder(TestName());
  XlaOp a;
  XlaOp b;
  auto a_handle = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a);
  auto b_handle = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
  TriangularSolve(a, b, /*left_side=*/true, /*lower=*/true,
                  /*unit_diagonal=*/false,
                  /*transpose_a=*/TriangularSolveOptions::TRANSPOSE);

  const ErrorSpec tolerance(1e-2, 1e-2);
  ComputeAndCompareR2<float>(&builder, expected,
                             {a_handle.get(), b_handle.get()}, tolerance);
}
223 
// Left-side forward substitution A * X = B with lower-triangular A.
XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotranspose) {
  const Array2D<float> expected({
      {0.5, 1.0, 1.5},
      {0.41666667, 0.33333333, 0.25},
      {0.23148148, 0.18518519, 0.13888889},
      {0.16835017, 0.13468013, 0.1010101},
  });

  XlaBuilder builder(TestName());
  XlaOp a;
  XlaOp b;
  auto a_handle = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a);
  auto b_handle = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
  TriangularSolve(a, b, /*left_side=*/true, /*lower=*/true,
                  /*unit_diagonal=*/false,
                  /*transpose_a=*/TriangularSolveOptions::NO_TRANSPOSE);

  const ErrorSpec tolerance(1e-2, 1e-2);
  ComputeAndCompareR2<float>(&builder, expected,
                             {a_handle.get(), b_handle.get()}, tolerance);
}
245 
// Left-side solve A * X = B where A is lower-triangular with an implicit unit
// diagonal. The fixture's NaN diagonal must be ignored, which yields exact
// integer results.
XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNoTransposeUnitDiagonal) {
  const Array2D<float> expected({{1., 2., 3.},
                                 {1., -1., -3.},
                                 {-4., 7., 18.},
                                 {37., -61., -159.}});

  XlaBuilder builder(TestName());
  XlaOp a;
  XlaOp b;
  auto a_handle =
      CreateR2Parameter<float>(AValsLowerUnitDiagonal(), 0, "a", &builder, &a);
  auto b_handle = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
  TriangularSolve(a, b, /*left_side=*/true, /*lower=*/true,
                  /*unit_diagonal=*/true,
                  /*transpose_a=*/TriangularSolveOptions::NO_TRANSPOSE);

  const ErrorSpec tolerance(1e-2, 1e-2);
  ComputeAndCompareR2<float>(&builder, expected,
                             {a_handle.get(), b_handle.get()}, tolerance);
}
264 
// Left-side solve A * X = B with lower-triangular A.
//
// NOTE(review): this test currently has exactly the same inputs, options and
// expected values as SimpleLeftLowerNotranspose. The "Irregularblock" name
// suggests it once passed a block size that does not evenly divide the 4x4
// matrix, presumably a knob this TriangularSolve call no longer takes. Confirm
// the intent, then either restore a distinguishing input or drop the test.
XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotransposeIrregularblock) {
  const Array2D<float> expected({
      {0.5, 1.0, 1.5},
      {0.41666667, 0.33333333, 0.25},
      {0.23148148, 0.18518519, 0.13888889},
      {0.16835017, 0.13468013, 0.1010101},
  });

  XlaBuilder builder(TestName());
  XlaOp a;
  XlaOp b;
  auto a_handle = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a);
  auto b_handle = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
  TriangularSolve(a, b, /*left_side=*/true, /*lower=*/true,
                  /*unit_diagonal=*/false,
                  /*transpose_a=*/TriangularSolveOptions::NO_TRANSPOSE);

  const ErrorSpec tolerance(1e-2, 1e-2);
  ComputeAndCompareR2<float>(&builder, expected,
                             {a_handle.get(), b_handle.get()}, tolerance);
}
286 
// Left-side solve A^T * X = B with upper-triangular A. AValsUpper() is the
// transpose of AValsLower(), so the answer matches SimpleLeftLowerNotranspose.
XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTranspose) {
  const Array2D<float> expected({
      {0.5, 1.0, 1.5},
      {0.41666667, 0.33333333, 0.25},
      {0.23148148, 0.18518519, 0.13888889},
      {0.16835017, 0.13468013, 0.1010101},
  });

  XlaBuilder builder(TestName());
  XlaOp a;
  XlaOp b;
  auto a_handle = CreateR2Parameter<float>(AValsUpper(), 0, "a", &builder, &a);
  auto b_handle = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
  TriangularSolve(a, b, /*left_side=*/true, /*lower=*/false,
                  /*unit_diagonal=*/false,
                  /*transpose_a=*/TriangularSolveOptions::TRANSPOSE);

  const ErrorSpec tolerance(1e-2, 1e-2);
  ComputeAndCompareR2<float>(&builder, expected,
                             {a_handle.get(), b_handle.get()}, tolerance);
}
308 
// Left-side back substitution A * X = B with upper-triangular A. The answer
// matches SimpleLeftLowerTranspose because AValsUpper() == AValsLower()^T.
XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperNotranspose) {
  const Array2D<float> expected({
      {-0.89646465, -0.69444444, -0.49242424},
      {-0.27441077, -0.24074074, -0.20707071},
      {-0.23232323, -0.22222222, -0.21212121},
      {0.90909091, 1., 1.09090909},
  });

  XlaBuilder builder(TestName());
  XlaOp a;
  XlaOp b;
  auto a_handle = CreateR2Parameter<float>(AValsUpper(), 0, "a", &builder, &a);
  auto b_handle = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
  TriangularSolve(a, b, /*left_side=*/true, /*lower=*/false,
                  /*unit_diagonal=*/false,
                  /*transpose_a=*/TriangularSolveOptions::NO_TRANSPOSE);

  const ErrorSpec tolerance(1e-2, 1e-2);
  ComputeAndCompareR2<float>(&builder, expected,
                             {a_handle.get(), b_handle.get()}, tolerance);
}
330 
// Left-side solve A * X = B where A is upper-triangular with an implicit unit
// diagonal. The NaN diagonal of the fixture must be ignored; the exact
// solution is integral.
XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperNotransposeUnitDiagonal) {
  const Array2D<float> expected({{-1402., -1538., -1674.},
                                 {575., 631., 687.},
                                 {-93., -102., -111.},
                                 {10., 11., 12.}});

  XlaBuilder builder(TestName());
  XlaOp a;
  XlaOp b;
  auto a_handle =
      CreateR2Parameter<float>(AValsUpperUnitDiagonal(), 0, "a", &builder, &a);
  auto b_handle = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
  TriangularSolve(a, b, /*left_side=*/true, /*lower=*/false,
                  /*unit_diagonal=*/true,
                  /*transpose_a=*/TriangularSolveOptions::NO_TRANSPOSE);

  const ErrorSpec tolerance(1e-2, 1e-2);
  ComputeAndCompareR2<float>(&builder, expected,
                             {a_handle.get(), b_handle.get()}, tolerance);
}
351 
// Right-side solve X * A^H = B (ADJOINT, i.e. conjugate transpose) with a
// complex lower-triangular A whose unused upper triangle is NaN.
XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTransposeConjugate) {
  const Array2D<complex64> expected({
      {0.5, complex64(0.08333333, 0.08333333),
       complex64(0.02777778, -0.0462963), complex64(0.06313131, -0.01094276)},
      {2.5, complex64(-0.25, 0.41666667), complex64(-0.23148148, -0.37962963),
       complex64(0.08670034, -0.02104377)},
      {4.5, complex64(-0.58333333, 0.75), complex64(-0.49074074, -0.71296296),
       complex64(0.11026936, -0.03114478)},
  });

  XlaBuilder builder(TestName());
  XlaOp a;
  XlaOp b;
  auto a_handle =
      CreateR2Parameter<complex64>(AValsLowerComplex(), 0, "a", &builder, &a);
  auto b_handle =
      CreateR2Parameter<complex64>(BValsRightComplex(), 1, "b", &builder, &b);
  TriangularSolve(a, b, /*left_side=*/false, /*lower=*/true,
                  /*unit_diagonal=*/false,
                  /*transpose_a=*/TriangularSolveOptions::ADJOINT);

  const ErrorSpec tolerance(1e-2, 1e-2);
  ComputeAndCompareR2<complex64>(&builder, expected,
                                 {a_handle.get(), b_handle.get()}, tolerance);
}
377 
// Left-side solve A^T * X = B with a complex upper-triangular A. TRANSPOSE
// must transpose without conjugating, which is why the last two rows come out
// complex even though B is real.
XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTransposeNoconjugate) {
  const Array2D<complex64> expected({
      {0.5, 1., 1.5},
      {0.41666667, 0.33333333, 0.25},
      {complex64(0.20020325, -2.81504065e-01),
       complex64(0.13821138, -4.22764228e-01),
       complex64(0.07621951, -5.64024390e-01)},
      {complex64(0.19678492, 2.55912786e-01),
       complex64(0.17738359, 3.84331116e-01),
       complex64(0.15798226, 5.12749446e-01)},
  });

  XlaBuilder builder(TestName());
  XlaOp a;
  XlaOp b;
  auto a_handle =
      CreateR2Parameter<complex64>(AValsUpperComplex(), 0, "a", &builder, &a);
  auto b_handle =
      CreateR2Parameter<complex64>(BValsLeftComplex(), 1, "b", &builder, &b);
  TriangularSolve(a, b, /*left_side=*/true, /*lower=*/false,
                  /*unit_diagonal=*/false,
                  /*transpose_a=*/TriangularSolveOptions::TRANSPOSE);

  const ErrorSpec tolerance(1e-2, 1e-2);
  ComputeAndCompareR2<complex64>(&builder, expected,
                                 {a_handle.get(), b_handle.get()}, tolerance);
}
405 
// Batched left-side upper-triangular solve over 7 stacked 5x5 systems. The
// computed X is multiplied back by A, and the product must reproduce B.
XLA_TEST_F(TriangularSolveTest, BatchedLeftUpper) {
  XlaBuilder builder(TestName());

  Array3D<float> bvals(7, 5, 5);
  bvals.FillIota(1.);

  // A is a copy of B with every entry strictly below the diagonal of each
  // (row, column) slice set to zero.
  Array3D<float> avals = bvals;
  avals.Each([](absl::Span<const int64> idx, float* entry) {
    const bool below_diagonal = idx[1] > idx[2];
    if (below_diagonal) {
      *entry = 0;
    }
  });

  XlaOp a;
  XlaOp b;
  auto a_handle = CreateR3Parameter<float>(avals, 0, "a", &builder, &a);
  auto b_handle = CreateR3Parameter<float>(bvals, 1, "b", &builder, &b);
  XlaOp x = TriangularSolve(
      a, b, /*left_side=*/true, /*lower=*/false, /*unit_diagonal=*/false,
      /*transpose_a=*/TriangularSolveOptions::NO_TRANSPOSE);
  BatchDot(ConstantR3FromArray3D(&builder, avals), x);

  const ErrorSpec tolerance(1e-2, 1e-2);
  ComputeAndCompareR3<float>(&builder, bvals,
                             {a_handle.get(), b_handle.get()}, tolerance);
}
433 
434 struct TriangularSolveTestSpec {
435   int m, n;  // A is mxm, B is mxn
436   bool left_side;
437   bool lower;
438   TriangularSolveOptions::Transpose transpose_a;
439 };
440 
441 class TriangularSolveParametricTest
442     : public ClientLibraryTestBase,
443       public ::testing::WithParamInterface<TriangularSolveTestSpec> {};
444 
// Solves a random system for the given spec, multiplies the solution back by
// op(triangle(A)) on the matching side, and checks that B is recovered.
XLA_TEST_P(TriangularSolveParametricTest, Random) {
  const TriangularSolveTestSpec spec = GetParam();

  XlaBuilder builder(TestName());

  // Random square A. Adding 10 to the diagonal keeps the pivots well away from
  // zero, presumably so the solve stays well conditioned.
  Array2D<float> avals(spec.m, spec.m);
  avals.FillRandom(1.0);
  for (int d = 0; d < spec.m; ++d) {
    avals(d, d) += 10;
  }

  // B is m x n for left-side solves and n x m for right-side solves.
  const int b_rows = spec.left_side ? spec.m : spec.n;
  const int b_cols = spec.left_side ? spec.n : spec.m;
  Array2D<float> bvals(b_rows, b_cols);
  bvals.FillRandom(1.0);

  XlaOp a;
  XlaOp b;
  auto a_handle = CreateR2Parameter<float>(avals, 0, "a", &builder, &a);
  auto b_handle = CreateR2Parameter<float>(bvals, 1, "b", &builder, &b);
  XlaOp x = TriangularSolve(a, b, spec.left_side, spec.lower,
                            /*unit_diagonal=*/false, spec.transpose_a);

  // Rebuild op(A) from the selected triangle. For real inputs, any mode other
  // than NO_TRANSPOSE is a plain transpose.
  const bool transposed =
      spec.transpose_a != TriangularSolveOptions::NO_TRANSPOSE;
  XlaOp op_a = MaybeTransposeInMinorDims(Triangle(a, spec.lower), transposed);
  if (!spec.left_side) {
    BatchDot(x, op_a);
  } else {
    BatchDot(op_a, x);
  }

  const ErrorSpec tolerance(1e-2, 1e-2);
  ComputeAndCompareR2<float>(&builder, bvals,
                             {a_handle.get(), b_handle.get()}, tolerance);
}
478 
TriangularSolveTests()479 std::vector<TriangularSolveTestSpec> TriangularSolveTests() {
480   std::vector<TriangularSolveTestSpec> specs;
481   for (int m : {5, 10}) {
482     for (int n : {5, 10}) {
483       for (bool left_side : {false, true}) {
484         for (bool lower : {false, true}) {
485           for (TriangularSolveOptions::Transpose transpose_a :
486                {TriangularSolveOptions::NO_TRANSPOSE,
487                 TriangularSolveOptions::TRANSPOSE}) {
488             specs.push_back({m, n, left_side, lower, transpose_a});
489           }
490         }
491       }
492     }
493   }
494   return specs;
495 }
496 
497 INSTANTIATE_TEST_SUITE_P(TriangularSolveParametricTestInstantiation,
498                          TriangularSolveParametricTest,
499                          ::testing::ValuesIn(TriangularSolveTests()));
500 
501 }  // namespace
502 }  // namespace xla
503