/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/triangular_solve_expander.h"

#include <algorithm>
#include <memory>
#include <numeric>
#include <vector>

#include "absl/strings/str_format.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/math/math_util.h"

namespace xla {

namespace {

// Get the diagonal blocks of the coefficient matrix.
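// For example (illustrative shapes, not from the original source): given 'a'
// of shape [..., 6, 6] and block_size 2, the result has shape [..., 3, 2, 2]
// and holds a[..., 0:2, 0:2], a[..., 2:4, 2:4], and a[..., 4:6, 4:6].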
XlaOp DiagonalBlocks(XlaOp a, int64 block_size) {
  XlaBuilder* builder = a.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(a));
    int ndims = shape.rank();
    int64 n = ShapeUtil::GetDimension(shape, -1);
    int64 num_blocks = n / block_size;
    absl::Span<int64 const> batch_dims = absl::MakeConstSpan(
        shape.dimensions().begin(), shape.dimensions().begin() + (ndims - 2));

    XlaOp diag_blocks;

    // If the coefficient matrix is exactly the block size, we just add a
    // singleton dimension i.e. [..., n, n] -> [..., 1, n, n]
    if (n == block_size) {
      std::vector<int64> permutation(ndims);
      std::iota(permutation.begin(), permutation.end(), 1);
      permutation.insert(permutation.end() - 2, 0);
      return Transpose(Broadcast(a, /*broadcast_sizes=*/{1}), permutation);
    }

    // We can grab entire blocks using gather
    if (n > block_size) {
      // Construct the starting indices of the diagonal blocks
      auto start_indices =
          Transpose(Broadcast(Mul(Iota(builder, S32, num_blocks),
                                  ConstantR0<int32>(builder, block_size)),
                              /*broadcast_sizes=*/{2}),
                    /*permutation=*/{1, 0});

      PaddingConfig padding_config =
          MakeEdgePaddingConfig({{0, 0}, {ndims - 2, 0}});
      start_indices =
          Pad(start_indices, ConstantR0<int32>(builder, 0), padding_config);

      // Gather the diagonal blocks
      std::vector<int64> slice_sizes(ndims);
      GatherDimensionNumbers dim_numbers;
      for (int i = 0; i < ndims - 2; ++i) {
        dim_numbers.add_offset_dims(i);
        dim_numbers.add_start_index_map(i);
        slice_sizes[i] = ShapeUtil::GetDimension(shape, i);
      }
      slice_sizes[ndims - 2] = slice_sizes[ndims - 1] = block_size;
      dim_numbers.add_offset_dims(ndims - 1);
      dim_numbers.add_offset_dims(ndims);
      dim_numbers.add_start_index_map(ndims - 2);
      dim_numbers.add_start_index_map(ndims - 1);
      dim_numbers.set_index_vector_dim(1);
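      // The gather result has shape [..., num_blocks, block_size, block_size],
      // with the block index inserted just before the two minor dimensions.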
      diag_blocks = Gather(a, start_indices, dim_numbers, slice_sizes);
    }

    // The last block might be smaller than the block size,
    // so we will need to pad it
    if (n % block_size != 0) {
      // Pad with identity matrix.
      auto last_blocks =
          SliceInMinorDims(a, {n - n % block_size, n - n % block_size}, {n, n});
      PaddingConfig config = MakeNoPaddingConfig(ndims);
      int64 padding = block_size - n % block_size;
      config.mutable_dimensions(ndims - 2)->set_edge_padding_high(padding);
      last_blocks =
          Pad(last_blocks, Zero(builder, shape.element_type()), config);

      auto eye =
          IdentityMatrix(builder, shape.element_type(), padding, padding);
      config = MakeNoPaddingConfig(2);
      config.mutable_dimensions(0)->set_edge_padding_low(n % block_size);
      eye = Pad(eye, Zero(builder, shape.element_type()), config);
      eye = Broadcast(eye, batch_dims);
      last_blocks = ConcatInDim(builder, {last_blocks, eye}, ndims - 1);

      // Add a singleton dimension
      // i.e. [..., block_size, block_size] -> [..., 1, block_size, block_size]
      TF_ASSIGN_OR_RETURN(Shape blocks_shape, builder->GetShape(last_blocks));
      auto shape_dims = AsInt64Slice(blocks_shape.dimensions());
      auto last_blocks_dims = std::vector<int64>(ndims);
      std::copy(shape_dims.begin(), shape_dims.end(), last_blocks_dims.begin());
      last_blocks_dims.insert(last_blocks_dims.end() - 2, 1);
      last_blocks = Reshape(last_blocks, last_blocks_dims);

      // Concatenate with the other blocks if necessary
      if (n > block_size) {
        diag_blocks =
            ConcatInDim(builder, {diag_blocks, last_blocks}, ndims - 2);
      } else {
        diag_blocks = last_blocks;
      }
    }

    return diag_blocks;
  });
}

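// The helper below applies precomputed block inverses blockwise. As a rough
// guide to the left-side, lower, non-transposed case, here is a numpy-style
// sketch (ours, for exposition only; the names are not part of XLA):
//
// def solve_with_inverted_blocks(a, inv_blocks, b, block_size):
//   n = a.shape[-1]
//   num_blocks = -(-n // block_size)  # ceiling division
//   x = None
//   for i in range(num_blocks):
//     lo, hi = i * block_size, min((i + 1) * block_size, n)
//     b_row = b[lo:hi, :]
//     if i > 0:
//       b_row = b_row - np.dot(a[lo:hi, :lo], x)
//     x_update = np.dot(inv_blocks[i][:hi - lo, :hi - lo], b_row)
//     x = x_update if x is None else np.concatenate([x, x_update])
//   return x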
XlaOp SolveWithInvertedDiagonalBlocks(XlaOp a, XlaOp b, XlaOp inv_diag_blocks,
                                      bool left_side, bool lower,
                                      bool transpose_a, bool conjugate_a,
                                      PrecisionConfig::Precision precision) {
  XlaBuilder* builder = a.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape blocks_shape, builder->GetShape(inv_diag_blocks));
    TF_ASSIGN_OR_RETURN(Shape b_shape, builder->GetShape(b));
    int64 block_size = ShapeUtil::GetDimension(blocks_shape, -1);

    TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
    int64 ndims = a_shape.rank();
    int64 n = ShapeUtil::GetDimension(a_shape, -1);
    int64 num_blocks = n / block_size + (n % block_size != 0);
    int64 m_dim = (left_side) ? -1 : -2;
    int64 m = ShapeUtil::GetDimension(b_shape, m_dim);

    std::vector<XlaOp> update_ops;
    int bdims = b_shape.rank();
    int64 block_dim = (left_side) ? bdims - 2 : bdims - 1;

    // Initialize the solution
    XlaOp x;

    // This loop is unrolled for performance reasons, but it could be expressed
    // rolled as well since the matrices are of the same size each iteration
    for (int i = 0; i < num_blocks; i++) {
      // High-level intuition: We have B[i] = L[i] @ X. Since L is lower
      // triangular this means B[i] = L[i, :i + 1] @ X[:i + 1]. We can split
      // this into two parts: B[i] = L[i, :i] @ X[:i] + L[i, i] @ X[i], which
      // can be solved for X[i] as
      // X[i] = inv(L[i, i]) @ (B[i] - L[i, :i] @ X[:i]).

      // Decide whether we go from first block to last or vice versa
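      // For example (our illustration): left_side=true, lower=true,
      // transpose_a=false is ordinary forward substitution, so backward is
      // false and the blocks are visited first to last.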
      bool backward = left_side ^ lower ^ transpose_a;
      auto j = backward ? num_blocks - 1 - i : i;

      // Get the size of the inverse blocks (the last one might be smaller)
      int64 block = (n % block_size != 0 && j + 1 == num_blocks)
                        ? n % block_size
                        : block_size;
      auto inv_block =
          MaybeConjugate(Collapse(SliceInMinorDims(inv_diag_blocks, {j, 0, 0},
                                                   {j + 1, block, block}),
                                  /*dimensions=*/{ndims - 2, ndims - 1}),
                         conjugate_a);

      // Get the corresponding row of B
      int64 k = std::min((j + 1) * block_size, n);
      std::vector<int64> start = {j * block_size, 0};
      std::vector<int64> end = {k, m};
      if (!left_side) {
        std::swap(start[0], start[1]);
        std::swap(end[0], end[1]);
      }
      auto b_row = SliceInMinorDims(b, start, end);

      XlaOp remainder;
      if (i == 0) {
        remainder = b_row;
      } else {
        // This matrix multiply avoids a lot of multiplication by zero (since
        // X[i * block_size:] = 0), computing only L[i, :i] @ X[:i].
        if (backward) {
          start = {j * block_size,
                   std::max(int64{0}, (num_blocks - i) * block_size)};
          end = {k, n};
        } else {
          start = {j * block_size, 0};
          end = {k, std::min(i * block_size, n)};
        }

        if (!left_side ^ transpose_a) {
          std::swap(start[0], start[1]);
          std::swap(end[0], end[1]);
        }
        auto a_row =
            MaybeConjugate(SliceInMinorDims(a, start, end), conjugate_a);
        if (left_side) {
          remainder = b_row - BatchDot(a_row, transpose_a, x, false, precision);
        } else {
          remainder = b_row - BatchDot(x, false, a_row, transpose_a, precision);
        }
      }

      XlaOp x_update;
      if (left_side) {
        x_update =
            BatchDot(inv_block, transpose_a, remainder, false, precision);
      } else {
        x_update =
            BatchDot(remainder, false, inv_block, transpose_a, precision);
      }

      if (i == 0) {
        x = x_update;
      } else {
        if (backward) {
          x = ConcatInDim(builder, {x_update, x}, block_dim);
        } else {
          x = ConcatInDim(builder, {x, x_update}, block_dim);
        }
      }
    }

    return x;
  });
}

}  // namespace

XlaOp TriangularSolveExpander::InvertDiagonalBlocks(
    XlaOp diag_blocks, bool lower_triangular,
    PrecisionConfig::Precision precision) {
  XlaBuilder* builder = diag_blocks.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    // Input is a batch of square triangular matrices. Its shape is
    // (..., size, size). We reshape this to (num_blocks, size, size).
    TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(diag_blocks));
    int64 block_size = ShapeUtil::GetDimension(shape, -1);
    int64 num_blocks = ShapeUtil::ElementsIn(shape) /
                       tensorflow::MathUtil::IPow(block_size, 2);
    diag_blocks = Reshape(diag_blocks, {num_blocks, block_size, block_size});

    // The input must be triangular because we rely on that when doing
    // multiplications later on
    diag_blocks = Triangle(diag_blocks, /*lower=*/lower_triangular);

    // Rescale blocks to be unit triangular, but avoid dividing by zero
    // (which can happen if the last block was padded); division by zero
    // would introduce NaNs that then propagate.
    auto diags = GetMatrixDiagonal(diag_blocks);
    auto ones = FullLike(diags, 1);
    diags = Select(Eq(diags, Zero(builder, shape.element_type())), ones, diags);
    auto scaled_diag_blocks = Div(diag_blocks, diags, {0, 2});
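    // With broadcast_dimensions {0, 2}, column j of every block is divided by
    // its diagonal entry, i.e. diag_blocks = scaled_diag_blocks @ diag(diags);
    // the computed inverse is correspondingly rescaled by rows after the loop.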

    // We can now use the fact that for a block lower triangular matrix
    // [[L11, 0], [L21, L22]], given the inverses L11' and L22', the
    // off-diagonal block of the inverse is L21' = -L22' * L21 * L11'. In our
    // case, L21 is a vector and our blocks have been rescaled to be unit
    // triangular, so L22 = L22' = 1.

    // Initialize the output matrix with -1s on the diagonal. We use -1 instead
    // of 1 because we cannot do matrix-vector multiplies with variable shapes
    // inside of a loop, or do irregularly shaped in-place updates. Hence,
    // L21' <- -L22' * L21 * L11' cannot be done naively. Instead, we update
    // the entire row i.e. we calculate
    // [L21 L22 0] <- -[L21 L22 0] @ diag_blocks([L11', -I, -I])
    // which means [L21 L22 0] <- [-L21 * L11', L22, 0].
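    // A numpy-style sketch of this row-at-a-time recurrence for a single
    // k x k unit lower triangular block L (ours, for exposition only):
    //   inv = -np.eye(k)
    //   inv[0, 0] = 1.0
    //   for j in range(1, k):
    //     inv[j, :] = -np.dot(L[j, :], inv)  # row j of L^{-1}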
    auto identity =
        IdentityMatrix(builder, shape.element_type(), block_size, block_size);
    auto neg_identity = -identity;

    // The first or last diagonal element should be set to 1 instead of -1
    // though, since we never update it
    auto pos_one = Reshape(One(builder, shape.element_type()), {1, 1});
    auto start_index =
        ConstantR0<int>(builder, lower_triangular ? 0 : block_size - 1);
    auto output_block =
        DynamicUpdateSlice(neg_identity, pos_one,
                           /*start_indices=*/{start_index, start_index});

    // Broadcast diag([1, -1, -1, ...]) to every block
    XlaOp output = Broadcast(output_block,
                             /*broadcast_sizes=*/{num_blocks});

    // Now we construct a loop that performs matrix-vector multiplications
    // inverting the blocks one row at a time
    std::vector<Shape> tuple_shapes = {
        // The loop iteration counter is a scalar, incremented each iteration.
        ShapeUtil::MakeShape(S32, {}),
        // The output has the shape of A, with one row updated each iteration.
        ShapeUtil::MakeShape(shape.element_type(),
                             {num_blocks, block_size, block_size}),
        // The input is a loop invariant.
        ShapeUtil::MakeShape(shape.element_type(),
                             {num_blocks, block_size, block_size})};
    Shape tuple_shape = ShapeUtil::MakeTupleShape(tuple_shapes);

    auto init_i = One(builder, S32);
    auto init = Tuple(builder, {init_i, output, scaled_diag_blocks});

    // Construct the loop condition function.
    std::unique_ptr<XlaBuilder> condb =
        builder->CreateSubBuilder("InvertDiagCond");
    {
      auto i = GetTupleElement(
          Parameter(condb.get(), 0, tuple_shape, "InvertDiagCondTuple"), 0);
      Lt(i, ConstantR0<int32>(condb.get(), block_size));
    }
    TF_ASSIGN_OR_RETURN(auto cond, condb->Build());

    // Construct the loop body function.
    std::unique_ptr<XlaBuilder> bodyb =
        builder->CreateSubBuilder("InvertDiagBody");
    {
      auto input_tuple =
          Parameter(bodyb.get(), 0, tuple_shape, "InvertDiagBodyTuple");

      auto i = GetTupleElement(input_tuple, 0);
      auto body_out = GetTupleElement(input_tuple, 1);
      auto body_input = GetTupleElement(input_tuple, 2);

      auto zero = ConstantR0<int32>(bodyb.get(), 0);
      auto j = lower_triangular ? i : ScalarLike(i, block_size - 1) - i;
      auto input_row =
          DynamicSlice(body_input, {zero, j, zero},
                       /*slice_sizes=*/{num_blocks, 1, block_size});

      // We want -L21 L11^{-1}
      DotDimensionNumbers dnums;
      dnums.add_lhs_batch_dimensions(0);
      dnums.add_rhs_batch_dimensions(0);
      dnums.add_lhs_contracting_dimensions(2);
      dnums.add_rhs_contracting_dimensions(1);
      PrecisionConfig precision_proto;
      precision_proto.add_operand_precision(precision);
      precision_proto.add_operand_precision(precision);
      auto update = -DotGeneral(input_row, body_out, dnums, &precision_proto);

      body_out = DynamicUpdateSlice(body_out, update, {zero, j, zero});

      auto next_i = i + ScalarLike(i, 1);
      Tuple(bodyb.get(), {next_i, body_out, body_input});
    }
    TF_ASSIGN_OR_RETURN(auto body, bodyb->Build());
    // Construct the While loop and return the result:
    // return while_loop(cond_fun, body_fun, init)[1].
    auto invert_while = While(cond, body, init);
    auto inv_diag_blocks = GetTupleElement(invert_while, 1);
    // Undo the scaling
    inv_diag_blocks = Div(inv_diag_blocks, diags,
                          /*broadcast_dimensions=*/{0, 1});

    // Reshape back to original batch major dimensions
    return Reshape(inv_diag_blocks, AsInt64Slice(shape.dimensions()));
  });
}

XlaOp TriangularSolveExpander::SolveByInvertingDiagonalBlocks(
    XlaOp a, XlaOp b, bool left_side, bool lower, bool transpose_a,
    bool conjugate_a, bool unit_diagonal,
    PrecisionConfig::Precision precision) {
  XlaBuilder* builder = a.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
    const int64 ndims = a_shape.rank();
    int64 k = ShapeUtil::GetDimension(a_shape, -1);

    // TODO(phawkins): consider pushing triangle masking into
    // InvertDiagonalBlocks.
    if (unit_diagonal) {
      // Mask everything but the strictly lower/upper triangular elements.
      a = lower ? Select(TriangleMask(a, -1), a, ZerosLike(a))
                : Select(TriangleMask(a, 0), ZerosLike(a), a);
      a = xla::Add(a, IdentityMatrix(builder, a_shape.element_type(), k, k),
                   /*broadcast_dimensions=*/{ndims - 2, ndims - 1});
    } else {
      // Mask off the ignored elements of the triangular matrix a.
      a = Triangle(a, lower);
    }

    // We find the diagonal blocks of the coefficient matrix
    int64 block_size = std::min(block_size_, k);
    auto diag_blocks = DiagonalBlocks(a, block_size);

    // We invert these blocks in parallel using batched matrix-vector products
    auto inv_diag_blocks = InvertDiagonalBlocks(diag_blocks, lower, precision);

    // We now find the solution using GEMMs
    return SolveWithInvertedDiagonalBlocks(a, b, inv_diag_blocks, left_side,
                                           lower, transpose_a, conjugate_a,
                                           precision);
  });
}

// def trsm_left_lower_leftlooking(a, b):
//   n = a.shape[-1]
//   assert a.shape == (n, n)
//   b = b.copy()
//   for j in range(n):
//     b[j, :] = (b[j, :] - np.dot(a[j, :j], b[:j, :])) / a[j, j]
//   return b
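// The implementation below generalizes this reference loop: the right-side,
// upper, and transposed variants are obtained by reversing the iteration
// order and swapping the slice indices of a and b.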
XlaOp TriangularSolveExpander::SolveDirectly(
    XlaOp a, XlaOp b, bool left_side, bool lower, bool transpose_a,
    bool conjugate_a, bool unit_diagonal,
    PrecisionConfig::Precision precision) {
  XlaBuilder* builder = a.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
    TF_ASSIGN_OR_RETURN(Shape b_shape, builder->GetShape(b));
    int64 m = ShapeUtil::GetDimension(b_shape, -2);
    int64 n = ShapeUtil::GetDimension(b_shape, -1);
    const int64 a_size = ShapeUtil::GetDimension(a_shape, -1);
    a = MaybeConjugate(a, conjugate_a);
    bool backwards = transpose_a ^ lower ^ !left_side;
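    // Note: despite its name, 'backwards' == true selects ascending j below;
    // for the left/lower/no-transpose case of the reference above it is true,
    // so j runs from 0 to a_size - 1, matching the reference loop.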
    for (int64 i = 0; i < a_size; ++i) {
      int64 j = backwards ? i : (a_size - i - 1);
      std::vector<int64> b_row_start, b_row_end;
      if (left_side) {
        b_row_start = {j, 0};
        b_row_end = {j + 1, n};
      } else {
        b_row_start = {0, j};
        b_row_end = {m, j + 1};
      }
      auto b_row = SliceInMinorDims(b, b_row_start, b_row_end);

      std::vector<int64> a_start = {j, backwards ? 0 : (j + 1)};
      std::vector<int64> a_end = {j + 1, backwards ? j : a_size};
      if (transpose_a ^ !left_side) {
        std::swap(a_start[0], a_start[1]);
        std::swap(a_end[0], a_end[1]);
      }
      auto a_chunk = SliceInMinorDims(a, a_start, a_end);
      if (left_side) {
        bool which = transpose_a ^ lower;
        auto b_chunk =
            SliceInMinorDims(b, {which ? 0 : (j + 1), 0}, {which ? j : m, n});
        b_row = b_row - BatchDot(a_chunk, /*transpose_x=*/transpose_a, b_chunk,
                                 /*transpose_y=*/false, precision);
      } else {
        bool which = transpose_a ^ !lower;
        auto b_chunk =
            SliceInMinorDims(b, {0, which ? 0 : (j + 1)}, {m, which ? j : n});
        b_row = b_row - BatchDot(b_chunk, /*transpose_x=*/false, a_chunk,
                                 /*transpose_y=*/transpose_a, precision);
      }
      if (!unit_diagonal) {
        auto a_diag = SliceInMinorDims(a, {j, j}, {j + 1, j + 1});
        b_row = b_row / a_diag;
      }

      b = UpdateSliceInMinorDims(b, b_row, b_row_start);
    }

    return b;
  });
}

XlaOp TriangularSolveExpander::BuildTriangularSolve(
    XlaOp a, XlaOp b, bool left_side, bool lower, bool transpose_a,
    bool conjugate_a, bool unit_diagonal, int64 block_size,
    PrecisionConfig::Precision precision) {
  XlaBuilder* builder = a.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
    TF_ASSIGN_OR_RETURN(Shape b_shape, builder->GetShape(b));
    if (a_shape.rank() != b_shape.rank()) {
      return InvalidArgument(
          "Arguments to TriangularSolve have shapes with different ranks: "
          "%s vs. %s",
          ShapeUtil::HumanString(a_shape), ShapeUtil::HumanString(b_shape));
    }
    const int64 ndims = a_shape.rank();
    if (ndims < 2) {
      return InvalidArgument(
490 "Arguments to TriangularSolve was rank %d but must have rank >= 2.",
          ndims);
    }
    // The batch dimensions must be equal.
    std::vector<int64> batch_dimensions;
    int64 batch = 1;
    for (int i = 0; i < ndims - 2; ++i) {
      int64 a_size = a_shape.dimensions(i);
      int64 b_size = b_shape.dimensions(i);
      if (a_size != b_size) {
        return InvalidArgument(
            "Batch dimensions of arguments to TriangularSolve must be equal; "
            "shapes were %s and %s.",
            ShapeUtil::HumanString(a_shape), ShapeUtil::HumanString(b_shape));
      }
      batch_dimensions.push_back(a_size);
      batch *= a_size;
    }

    if (ShapeUtil::GetDimension(a_shape, -1) !=
        ShapeUtil::GetDimension(a_shape, -2)) {
      return InvalidArgument(
          "The 'a' argument to TriangularSolve must be a batched square matrix;"
          " shape was: %s",
          ShapeUtil::HumanString(a_shape));
    }
    const int64 m = ShapeUtil::GetDimension(b_shape, -2);
    const int64 n = ShapeUtil::GetDimension(b_shape, -1);
    if ((left_side ? m : n) != ShapeUtil::GetDimension(a_shape, -1)) {
      return InvalidArgument(
          "Arguments to TriangularSolve have incompatible matrix shapes %s and "
          "%s",
          ShapeUtil::HumanString(a_shape), ShapeUtil::HumanString(b_shape));
    }

    int64 a_size = ShapeUtil::GetDimension(a_shape, -1);

    if (ShapeUtil::IsZeroElementArray(b_shape)) {
      // The output has the same shape as 'b', and since the output has zero
      // elements, any such array will do.
      return b;
    }

    // Degenerate case: 1x1 matrices.
    if (a_size == 1) {
      return unit_diagonal ? b : Div(b, MaybeConjugate(a, conjugate_a));
    }

    // Prefer the direct implementation whenever there is a nontrivial batch
    // dimension and the matrix is very small.
    if (batch > block_size_ / 16 && a_size < block_size_ / 4) {
      return SolveDirectly(a, b, left_side, lower, transpose_a, conjugate_a,
                           unit_diagonal, precision);
    } else {
      return SolveByInvertingDiagonalBlocks(a, b, left_side, lower, transpose_a,
                                            conjugate_a, unit_diagonal,
                                            precision);
    }
  });
}

TriangularSolveExpander::TriangularSolveExpander(int64 block_size)
    : block_size_(block_size) {
  CHECK_GE(block_size_, 1);
}

bool TriangularSolveExpander::InstructionMatchesPattern(
    HloInstruction* instruction) {
  return instruction->opcode() == HloOpcode::kTriangularSolve;
}

StatusOr<HloInstruction*> TriangularSolveExpander::ExpandInstruction(
    HloInstruction* instruction) {
  const TriangularSolveOptions& options =
      instruction->triangular_solve_options();
  const string name = absl::StrFormat(
      "xla.triangular_solve_%s_%s_%s_%s_%s_%s",
      instruction->operand(0)->shape().ToString(),
      instruction->operand(1)->shape().ToString(),
      options.left_side() ? "left" : "right",
      options.lower() ? "lower" : "upper",
      TriangularSolveOptions_Transpose_Name(options.transpose_a()),
      options.unit_diagonal() ? "unit" : "nonunit");
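  // For example (an illustrative key, not taken from a real run), a 4x4
  // left/lower/non-transposed solve against a 4x2 right-hand side yields:
  // "xla.triangular_solve_f32[4,4]_f32[4,2]_left_lower_NO_TRANSPOSE_nonunit".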

  HloModule* module = instruction->parent()->parent();

  HloComputation*& computation =
      computation_cache_.emplace(name, nullptr).first->second;
  if (!computation) {
    // Builds a new expansion.
    //
    // We do something unusual here: we build the computation using the
    // XlaBuilder API, which is nominally an XLA client API. We do this because
    // the external APIs for building complicated computations (XlaBuilder)
    // are much more ergonomic than the internal ones. As it turns out,
    // XlaBuilder isn't really a client API; what it does is build an
    // HloModuleProto protocol buffer that we then deserialize and clone
    // into our HloModule. Ideally we would avoid the protocol buffer step;
    // that is left as an exercise for future work.
    XlaBuilder builder(name);
    XlaOp a = Parameter(&builder, 0, instruction->operand(0)->shape(), "a");
    XlaOp b = Parameter(&builder, 1, instruction->operand(1)->shape(), "b");
    bool transpose_a =
        options.transpose_a() != TriangularSolveOptions::NO_TRANSPOSE;
    bool conjugate_a = options.transpose_a() == TriangularSolveOptions::ADJOINT;

    BuildTriangularSolve(a, b, options.left_side(), options.lower(),
                         transpose_a, conjugate_a, options.unit_diagonal(),
                         /*block_size=*/block_size_,
                         /*precision=*/PrecisionConfig::HIGHEST);
    TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build());

    TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
                        xla_computation.GetProgramShape());
    HloModuleConfig config(program_shape);
    TF_ASSIGN_OR_RETURN(auto new_module, HloModule::CreateFromProto(
                                             xla_computation.proto(), config));
    HloCloneContext context(module);
    computation =
        module->DeepCloneComputation(new_module->entry_computation(), &context);
  }

  return instruction->parent()->AddInstruction(HloInstruction::CreateCall(
      instruction->shape(), instruction->operands(), computation));
}

}  // namespace xla