1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include <limits>
#include <memory>
#include <numeric>
#include <vector>
19
20 #include "tensorflow/compiler/xla/array2d.h"
21 #include "tensorflow/compiler/xla/client/lib/math.h"
22 #include "tensorflow/compiler/xla/client/lib/matrix.h"
23 #include "tensorflow/compiler/xla/client/xla_builder.h"
24 #include "tensorflow/compiler/xla/literal.h"
25 #include "tensorflow/compiler/xla/statusor.h"
26 #include "tensorflow/compiler/xla/test.h"
27 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
28 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
29 #include "tensorflow/compiler/xla/tests/test_macros.h"
30 #include "tensorflow/compiler/xla/types.h"
31 #include "tensorflow/core/lib/core/status_test_util.h"
32
33 namespace xla {
34 namespace {
35
// Fixture used by the hand-written triangular-solve cases below.
using TriangularSolveTest = ClientLibraryTestBase;
// NOTE(review): no test visible in this file uses this alias; confirm whether
// it is still needed.
using TriangularSolveLeftLookingTest = ClientLibraryTestBase;
38
// Quiet NaN placed in every matrix entry that the test inputs below leave
// unspecified (the opposite triangle, and the diagonal in the unit-diagonal
// inputs) -- presumably so that a stray read of such an entry would surface as
// NaN in the compared result. std::numeric_limits comes from <limits>. The
// enclosing anonymous namespace already gives internal linkage, so no `static`.
constexpr float kNan = std::numeric_limits<float>::quiet_NaN();
40
AValsLower()41 Array2D<float> AValsLower() {
42 return {{2, kNan, kNan, kNan},
43 {3, 6, kNan, kNan},
44 {4, 7, 9, kNan},
45 {5, 8, 10, 11}};
46 }
47
AValsUpper()48 Array2D<float> AValsUpper() {
49 return {{2, 3, 4, 5},
50 {kNan, 6, 7, 8},
51 {kNan, kNan, 9, 10},
52 {kNan, kNan, kNan, 11}};
53 }
54
AValsLowerUnitDiagonal()55 Array2D<float> AValsLowerUnitDiagonal() {
56 return {{kNan, kNan, kNan, kNan},
57 {3, kNan, kNan, kNan},
58 {4, 7, kNan, kNan},
59 {5, 8, 10, kNan}};
60 }
61
AValsUpperUnitDiagonal()62 Array2D<float> AValsUpperUnitDiagonal() {
63 return {{kNan, 3, 4, 5},
64 {kNan, kNan, 7, 8},
65 {kNan, kNan, kNan, 10},
66 {kNan, kNan, kNan, kNan}};
67 }
68
BValsRight()69 Array2D<float> BValsRight() {
70 return {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}};
71 }
72
BValsLeft()73 Array2D<float> BValsLeft() {
74 return {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}, {10, 11, 12}};
75 }
76
77 static constexpr complex64 kNanC64 = complex64(kNan, kNan);
78
AValsLowerComplex()79 Array2D<complex64> AValsLowerComplex() {
80 return {{2, kNanC64, kNanC64, kNanC64},
81 {complex64(3, 1), 6, kNanC64, kNanC64},
82 {4, complex64(7, 2), 9, kNanC64},
83 {5, 8, complex64(10, 3), 11}};
84 }
85
AValsUpperComplex()86 Array2D<complex64> AValsUpperComplex() {
87 return {{2, 3, complex64(4, 3), 5},
88 {kNanC64, 6, complex64(7, 2), 8},
89 {kNanC64, kNanC64, complex64(9, 1), 10},
90 {kNanC64, kNanC64, kNanC64, 11}};
91 }
92
BValsRightComplex()93 Array2D<complex64> BValsRightComplex() {
94 return {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}};
95 }
96
BValsLeftComplex()97 Array2D<complex64> BValsLeftComplex() {
98 return {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}, {10, 11, 12}};
99 }
100
// Degenerate shapes: a 0x0 coefficient matrix with a 0x10 right-hand side is
// expected to produce a 0x10 result.
XLA_TEST_F(TriangularSolveTest, EmptyArrays) {
  const Array2D<float> empty_a(0, 0);
  const Array2D<float> empty_b(0, 10);

  XlaBuilder builder(TestName());
  XlaOp a_op;
  XlaOp b_op;
  auto a_handle = CreateR2Parameter<float>(empty_a, 0, "a", &builder, &a_op);
  auto b_handle = CreateR2Parameter<float>(empty_b, 1, "b", &builder, &b_op);
  TriangularSolve(a_op, b_op,
                  /*left_side=*/true, /*lower=*/true,
                  /*unit_diagonal=*/false,
                  /*transpose_a=*/TriangularSolveOptions::TRANSPOSE);

  // Unlike the other cases in this file, no ErrorSpec is supplied here.
  ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 10),
                             {a_handle.get(), b_handle.get()});
}
117
// Right-side solve X * A^T = B, with A supplied as a lower-triangular matrix.
XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTranspose) {
  // Reference solution and comparison tolerance.
  const Array2D<float> want({
      {0.5, 0.08333334, 0.04629629, 0.03367003},
      {2.5, -0.25, -0.1388889, -0.1010101},
      {4.5, -0.58333331, -0.32407406, -0.23569024},
  });
  const ErrorSpec tolerance(1e-2, 1e-2);

  XlaBuilder builder(TestName());
  XlaOp a_op;
  XlaOp b_op;
  auto a_handle =
      CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a_op);
  auto b_handle =
      CreateR2Parameter<float>(BValsRight(), 1, "b", &builder, &b_op);
  TriangularSolve(a_op, b_op,
                  /*left_side=*/false, /*lower=*/true,
                  /*unit_diagonal=*/false,
                  /*transpose_a=*/TriangularSolveOptions::TRANSPOSE);

  ComputeAndCompareR2<float>(&builder, want,
                             {a_handle.get(), b_handle.get()}, tolerance);
}
138
// Right-side solve X * A = B, with A supplied as a lower-triangular matrix.
XLA_TEST_F(TriangularSolveTest, SimpleRightLowerNotranspose) {
  // Reference solution and comparison tolerance.
  const Array2D<float> want({
      {-0.16414141, -0.06902357, -0.07070707, 0.36363636},
      {0.64393939, 0.06565657, -0.03030303, 0.72727273},
      {1.4520202, 0.2003367, 0.01010101, 1.09090909},
  });
  const ErrorSpec tolerance(1e-2, 1e-2);

  XlaBuilder builder(TestName());
  XlaOp a_op;
  XlaOp b_op;
  auto a_handle =
      CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a_op);
  auto b_handle =
      CreateR2Parameter<float>(BValsRight(), 1, "b", &builder, &b_op);
  TriangularSolve(a_op, b_op,
                  /*left_side=*/false, /*lower=*/true,
                  /*unit_diagonal=*/false,
                  /*transpose_a=*/TriangularSolveOptions::NO_TRANSPOSE);

  ComputeAndCompareR2<float>(&builder, want,
                             {a_handle.get(), b_handle.get()}, tolerance);
}
159
// Right-side solve X * A^T = B, with A supplied as an upper-triangular matrix.
// AValsUpper() is the transpose of AValsLower(), so the reference values match
// those of SimpleRightLowerNotranspose.
XLA_TEST_F(TriangularSolveTest, SimpleRightUpperTranspose) {
  // Reference solution and comparison tolerance.
  const Array2D<float> want({
      {-0.16414141, -0.06902357, -0.07070707, 0.36363636},
      {0.64393939, 0.06565657, -0.03030303, 0.72727273},
      {1.4520202, 0.2003367, 0.01010101, 1.09090909},
  });
  const ErrorSpec tolerance(1e-2, 1e-2);

  XlaBuilder builder(TestName());
  XlaOp a_op;
  XlaOp b_op;
  auto a_handle =
      CreateR2Parameter<float>(AValsUpper(), 0, "a", &builder, &a_op);
  auto b_handle =
      CreateR2Parameter<float>(BValsRight(), 1, "b", &builder, &b_op);
  TriangularSolve(a_op, b_op,
                  /*left_side=*/false, /*lower=*/false,
                  /*unit_diagonal=*/false,
                  /*transpose_a=*/TriangularSolveOptions::TRANSPOSE);

  ComputeAndCompareR2<float>(&builder, want,
                             {a_handle.get(), b_handle.get()}, tolerance);
}
180
// Right-side solve X * A = B, with A supplied as an upper-triangular matrix.
// AValsUpper() is the transpose of AValsLower(), so the reference values match
// those of SimpleRightLowerTranspose.
XLA_TEST_F(TriangularSolveTest, SimpleRightUpperNotranspose) {
  // Reference solution and comparison tolerance.
  const Array2D<float> want({
      {0.5, 0.08333334, 0.04629629, 0.03367003},
      {2.5, -0.25, -0.1388889, -0.1010101},
      {4.5, -0.58333331, -0.32407406, -0.23569024},
  });
  const ErrorSpec tolerance(1e-2, 1e-2);

  XlaBuilder builder(TestName());
  XlaOp a_op;
  XlaOp b_op;
  auto a_handle =
      CreateR2Parameter<float>(AValsUpper(), 0, "a", &builder, &a_op);
  auto b_handle =
      CreateR2Parameter<float>(BValsRight(), 1, "b", &builder, &b_op);
  TriangularSolve(a_op, b_op,
                  /*left_side=*/false, /*lower=*/false,
                  /*unit_diagonal=*/false,
                  /*transpose_a=*/TriangularSolveOptions::NO_TRANSPOSE);

  ComputeAndCompareR2<float>(&builder, want,
                             {a_handle.get(), b_handle.get()}, tolerance);
}
201
// Left-side solve A^T * X = B, with A supplied as a lower-triangular matrix.
XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerTranspose) {
  // Reference solution and comparison tolerance.
  const Array2D<float> want({
      {-0.89646465, -0.69444444, -0.49242424},
      {-0.27441077, -0.24074074, -0.20707071},
      {-0.23232323, -0.22222222, -0.21212121},
      {0.90909091, 1., 1.09090909},
  });
  const ErrorSpec tolerance(1e-2, 1e-2);

  XlaBuilder builder(TestName());
  XlaOp a_op;
  XlaOp b_op;
  auto a_handle =
      CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a_op);
  auto b_handle =
      CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b_op);
  TriangularSolve(a_op, b_op,
                  /*left_side=*/true, /*lower=*/true,
                  /*unit_diagonal=*/false,
                  /*transpose_a=*/TriangularSolveOptions::TRANSPOSE);

  ComputeAndCompareR2<float>(&builder, want,
                             {a_handle.get(), b_handle.get()}, tolerance);
}
223
// Left-side solve A * X = B, with A supplied as a lower-triangular matrix.
XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotranspose) {
  // Reference solution and comparison tolerance.
  const Array2D<float> want({
      {0.5, 1.0, 1.5},
      {0.41666667, 0.33333333, 0.25},
      {0.23148148, 0.18518519, 0.13888889},
      {0.16835017, 0.13468013, 0.1010101},
  });
  const ErrorSpec tolerance(1e-2, 1e-2);

  XlaBuilder builder(TestName());
  XlaOp a_op;
  XlaOp b_op;
  auto a_handle =
      CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a_op);
  auto b_handle =
      CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b_op);
  TriangularSolve(a_op, b_op,
                  /*left_side=*/true, /*lower=*/true,
                  /*unit_diagonal=*/false,
                  /*transpose_a=*/TriangularSolveOptions::NO_TRANSPOSE);

  ComputeAndCompareR2<float>(&builder, want,
                             {a_handle.get(), b_handle.get()}, tolerance);
}
245
// Left-side solve A * X = B with unit_diagonal=true. The input matrix carries
// NaN on its diagonal, while the reference values correspond to a diagonal of
// ones.
XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNoTransposeUnitDiagonal) {
  // Reference solution and comparison tolerance.
  const Array2D<float> want({
      {1., 2., 3.},
      {1., -1., -3.},
      {-4., 7., 18.},
      {37., -61., -159.},
  });
  const ErrorSpec tolerance(1e-2, 1e-2);

  XlaBuilder builder(TestName());
  XlaOp a_op;
  XlaOp b_op;
  auto a_handle = CreateR2Parameter<float>(AValsLowerUnitDiagonal(), 0, "a",
                                           &builder, &a_op);
  auto b_handle =
      CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b_op);
  TriangularSolve(a_op, b_op,
                  /*left_side=*/true, /*lower=*/true,
                  /*unit_diagonal=*/true,
                  /*transpose_a=*/TriangularSolveOptions::NO_TRANSPOSE);

  ComputeAndCompareR2<float>(&builder, want,
                             {a_handle.get(), b_handle.get()}, tolerance);
}
264
// Left-side solve A * X = B, with A supplied as a lower-triangular matrix.
// NOTE(review): inputs, options and expected values are the same as in
// SimpleLeftLowerNotranspose and no block-size setting is passed here; confirm
// whether this case still exercises anything distinct.
XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotransposeIrregularblock) {
  // Reference solution and comparison tolerance.
  const Array2D<float> want({
      {0.5, 1.0, 1.5},
      {0.41666667, 0.33333333, 0.25},
      {0.23148148, 0.18518519, 0.13888889},
      {0.16835017, 0.13468013, 0.1010101},
  });
  const ErrorSpec tolerance(1e-2, 1e-2);

  XlaBuilder builder(TestName());
  XlaOp a_op;
  XlaOp b_op;
  auto a_handle =
      CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a_op);
  auto b_handle =
      CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b_op);
  TriangularSolve(a_op, b_op,
                  /*left_side=*/true, /*lower=*/true,
                  /*unit_diagonal=*/false,
                  /*transpose_a=*/TriangularSolveOptions::NO_TRANSPOSE);

  ComputeAndCompareR2<float>(&builder, want,
                             {a_handle.get(), b_handle.get()}, tolerance);
}
286
// Left-side solve A^T * X = B, with A supplied as an upper-triangular matrix.
// AValsUpper() is the transpose of AValsLower(), so the reference values match
// those of SimpleLeftLowerNotranspose.
XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTranspose) {
  // Reference solution and comparison tolerance.
  const Array2D<float> want({
      {0.5, 1.0, 1.5},
      {0.41666667, 0.33333333, 0.25},
      {0.23148148, 0.18518519, 0.13888889},
      {0.16835017, 0.13468013, 0.1010101},
  });
  const ErrorSpec tolerance(1e-2, 1e-2);

  XlaBuilder builder(TestName());
  XlaOp a_op;
  XlaOp b_op;
  auto a_handle =
      CreateR2Parameter<float>(AValsUpper(), 0, "a", &builder, &a_op);
  auto b_handle =
      CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b_op);
  TriangularSolve(a_op, b_op,
                  /*left_side=*/true, /*lower=*/false,
                  /*unit_diagonal=*/false,
                  /*transpose_a=*/TriangularSolveOptions::TRANSPOSE);

  ComputeAndCompareR2<float>(&builder, want,
                             {a_handle.get(), b_handle.get()}, tolerance);
}
308
// Left-side solve A * X = B, with A supplied as an upper-triangular matrix.
// AValsUpper() is the transpose of AValsLower(), so the reference values match
// those of SimpleLeftLowerTranspose.
XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperNotranspose) {
  // Reference solution and comparison tolerance.
  const Array2D<float> want({
      {-0.89646465, -0.69444444, -0.49242424},
      {-0.27441077, -0.24074074, -0.20707071},
      {-0.23232323, -0.22222222, -0.21212121},
      {0.90909091, 1., 1.09090909},
  });
  const ErrorSpec tolerance(1e-2, 1e-2);

  XlaBuilder builder(TestName());
  XlaOp a_op;
  XlaOp b_op;
  auto a_handle =
      CreateR2Parameter<float>(AValsUpper(), 0, "a", &builder, &a_op);
  auto b_handle =
      CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b_op);
  TriangularSolve(a_op, b_op,
                  /*left_side=*/true, /*lower=*/false,
                  /*unit_diagonal=*/false,
                  /*transpose_a=*/TriangularSolveOptions::NO_TRANSPOSE);

  ComputeAndCompareR2<float>(&builder, want,
                             {a_handle.get(), b_handle.get()}, tolerance);
}
330
// Left-side solve A * X = B on an upper-triangular A with unit_diagonal=true.
// The input matrix carries NaN on its diagonal.
XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperNotransposeUnitDiagonal) {
  // Reference solution and comparison tolerance.
  const Array2D<float> want({
      {-1402., -1538., -1674.},
      {575., 631., 687.},
      {-93., -102., -111.},
      {10., 11., 12.},
  });
  const ErrorSpec tolerance(1e-2, 1e-2);

  XlaBuilder builder(TestName());
  XlaOp a_op;
  XlaOp b_op;
  auto a_handle = CreateR2Parameter<float>(AValsUpperUnitDiagonal(), 0, "a",
                                           &builder, &a_op);
  auto b_handle =
      CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b_op);
  TriangularSolve(a_op, b_op,
                  /*left_side=*/true, /*lower=*/false,
                  /*unit_diagonal=*/true,
                  /*transpose_a=*/TriangularSolveOptions::NO_TRANSPOSE);

  ComputeAndCompareR2<float>(&builder, want,
                             {a_handle.get(), b_handle.get()}, tolerance);
}
351
// Complex right-side solve with transpose_a = ADJOINT on a lower-triangular A.
// The reference values correspond to X * conj(A)^T = B.
XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTransposeConjugate) {
  // Reference solution and comparison tolerance.
  const Array2D<complex64> want({
      {0.5, complex64(0.08333333, 0.08333333),
       complex64(0.02777778, -0.0462963), complex64(0.06313131, -0.01094276)},
      {2.5, complex64(-0.25, 0.41666667), complex64(-0.23148148, -0.37962963),
       complex64(0.08670034, -0.02104377)},
      {4.5, complex64(-0.58333333, 0.75), complex64(-0.49074074, -0.71296296),
       complex64(0.11026936, -0.03114478)},
  });
  const ErrorSpec tolerance(1e-2, 1e-2);

  XlaBuilder builder(TestName());
  XlaOp a_op;
  XlaOp b_op;
  auto a_handle = CreateR2Parameter<complex64>(AValsLowerComplex(), 0, "a",
                                               &builder, &a_op);
  auto b_handle = CreateR2Parameter<complex64>(BValsRightComplex(), 1, "b",
                                               &builder, &b_op);
  TriangularSolve(a_op, b_op,
                  /*left_side=*/false, /*lower=*/true,
                  /*unit_diagonal=*/false,
                  /*transpose_a=*/TriangularSolveOptions::ADJOINT);

  ComputeAndCompareR2<complex64>(&builder, want,
                                 {a_handle.get(), b_handle.get()}, tolerance);
}
377
// Complex left-side solve with transpose_a = TRANSPOSE on an upper-triangular
// A. The reference values correspond to A^T * X = B without conjugation.
XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTransposeNoconjugate) {
  // Reference solution and comparison tolerance.
  const Array2D<complex64> want({
      {0.5, 1., 1.5},
      {0.41666667, 0.33333333, 0.25},
      {complex64(0.20020325, -2.81504065e-01),
       complex64(0.13821138, -4.22764228e-01),
       complex64(0.07621951, -5.64024390e-01)},
      {complex64(0.19678492, 2.55912786e-01),
       complex64(0.17738359, 3.84331116e-01),
       complex64(0.15798226, 5.12749446e-01)},
  });
  const ErrorSpec tolerance(1e-2, 1e-2);

  XlaBuilder builder(TestName());
  XlaOp a_op;
  XlaOp b_op;
  auto a_handle = CreateR2Parameter<complex64>(AValsUpperComplex(), 0, "a",
                                               &builder, &a_op);
  auto b_handle = CreateR2Parameter<complex64>(BValsLeftComplex(), 1, "b",
                                               &builder, &b_op);
  TriangularSolve(a_op, b_op,
                  /*left_side=*/true, /*lower=*/false,
                  /*unit_diagonal=*/false,
                  /*transpose_a=*/TriangularSolveOptions::TRANSPOSE);

  ComputeAndCompareR2<complex64>(&builder, want,
                                 {a_handle.get(), b_handle.get()}, tolerance);
}
405
// Batched round trip on 7x5x5 inputs: solves A * X = B with upper-triangular A
// and checks that multiplying A by the result reproduces B.
XLA_TEST_F(TriangularSolveTest, BatchedLeftUpper) {
  Array3D<float> bvals(7, 5, 5);
  bvals.FillIota(1.);

  // A starts as a copy of B; entries whose row index exceeds the column index
  // are zeroed, leaving only the upper triangle of each 5x5 slice.
  Array3D<float> avals = bvals;
  avals.Each([](absl::Span<const int64> idx, float* entry) {
    const bool below_diagonal = idx[1] > idx[2];
    if (below_diagonal) {
      *entry = 0;
    }
  });

  XlaBuilder builder(TestName());
  XlaOp a_op;
  XlaOp b_op;
  auto a_handle = CreateR3Parameter<float>(avals, 0, "a", &builder, &a_op);
  auto b_handle = CreateR3Parameter<float>(bvals, 1, "b", &builder, &b_op);

  XlaOp solution =
      TriangularSolve(a_op, b_op,
                      /*left_side=*/true, /*lower=*/false,
                      /*unit_diagonal=*/false,
                      /*transpose_a=*/TriangularSolveOptions::NO_TRANSPOSE);
  // The left factor of the product is a constant built from avals rather than
  // the parameter op.
  BatchDot(ConstantR3FromArray3D(&builder, avals), solution);

  const ErrorSpec tolerance(1e-2, 1e-2);
  ComputeAndCompareR3<float>(&builder, bvals,
                             {a_handle.get(), b_handle.get()}, tolerance);
}
433
434 struct TriangularSolveTestSpec {
435 int m, n; // A is mxm, B is mxn
436 bool left_side;
437 bool lower;
438 TriangularSolveOptions::Transpose transpose_a;
439 };
440
// Value-parameterized fixture: a ClientLibraryTestBase whose GetParam() yields
// a TriangularSolveTestSpec.
class TriangularSolveParametricTest
    : public ClientLibraryTestBase,
      public ::testing::WithParamInterface<TriangularSolveTestSpec> {};
444
// Round-trip check on random inputs: multiplying the solve result back by the
// selected (and optionally transposed) triangle of A must reproduce B within
// the error tolerance.
XLA_TEST_P(TriangularSolveParametricTest, Random) {
  const TriangularSolveTestSpec spec = GetParam();
  const bool transposed =
      spec.transpose_a != TriangularSolveOptions::NO_TRANSPOSE;

  // Random m x m coefficients with 10 added to each diagonal entry --
  // presumably to keep the triangular system well conditioned.
  Array2D<float> avals(spec.m, spec.m);
  avals.FillRandom(1.0);
  for (int d = 0; d < spec.m; ++d) {
    avals(d, d) += 10;
  }

  // B is m x n for a left-side solve and n x m for a right-side solve.
  const int b_rows = spec.left_side ? spec.m : spec.n;
  const int b_cols = spec.left_side ? spec.n : spec.m;
  Array2D<float> bvals(b_rows, b_cols);
  bvals.FillRandom(1.0);

  XlaBuilder builder(TestName());
  XlaOp a_op;
  XlaOp b_op;
  auto a_handle = CreateR2Parameter<float>(avals, 0, "a", &builder, &a_op);
  auto b_handle = CreateR2Parameter<float>(bvals, 1, "b", &builder, &b_op);

  XlaOp solution = TriangularSolve(a_op, b_op, spec.left_side, spec.lower,
                                   /*unit_diagonal=*/false, spec.transpose_a);
  XlaOp triangle = Triangle(a_op, spec.lower);
  triangle = MaybeTransposeInMinorDims(triangle, transposed);

  // Multiply back on the same side the solve used.
  if (!spec.left_side) {
    BatchDot(solution, triangle);
  } else {
    BatchDot(triangle, solution);
  }

  const ErrorSpec tolerance(1e-2, 1e-2);
  ComputeAndCompareR2<float>(&builder, bvals,
                             {a_handle.get(), b_handle.get()}, tolerance);
}
478
TriangularSolveTests()479 std::vector<TriangularSolveTestSpec> TriangularSolveTests() {
480 std::vector<TriangularSolveTestSpec> specs;
481 for (int m : {5, 10}) {
482 for (int n : {5, 10}) {
483 for (bool left_side : {false, true}) {
484 for (bool lower : {false, true}) {
485 for (TriangularSolveOptions::Transpose transpose_a :
486 {TriangularSolveOptions::NO_TRANSPOSE,
487 TriangularSolveOptions::TRANSPOSE}) {
488 specs.push_back({m, n, left_side, lower, transpose_a});
489 }
490 }
491 }
492 }
493 }
494 return specs;
495 }
496
// Instantiates TriangularSolveParametricTest once per spec returned by
// TriangularSolveTests().
INSTANTIATE_TEST_SUITE_P(TriangularSolveParametricTestInstantiation,
                         TriangularSolveParametricTest,
                         ::testing::ValuesIn(TriangularSolveTests()));
500
501 } // namespace
502 } // namespace xla
503