/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/triangular_solve_expander.h"

#include <memory>
#include <numeric>
#include <vector>

#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/math/math_util.h"

namespace xla {

namespace {

// Gets the diagonal blocks of the coefficient matrix.
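// For example (an illustration of the shapes, not original commentary): with
// a of shape [..., 5, 5] and block_size = 2, the result has shape
// [..., 3, 2, 2], and the last block holds a[..., 4:5, 4:5] padded out to
// 2x2 with an identity matrix.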
XlaOp DiagonalBlocks(XlaOp a, int64 block_size) {
  XlaBuilder* builder = a.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(a));
    int ndims = shape.rank();
    int64 n = ShapeUtil::GetDimension(shape, -1);
    int64 num_blocks = n / block_size;
    absl::Span<int64 const> batch_dims = absl::MakeConstSpan(
        shape.dimensions().begin(), shape.dimensions().begin() + (ndims - 2));

    XlaOp diag_blocks;

    // If the coefficient matrix is exactly the block size, we just add a
    // singleton dimension, i.e. [..., n, n] -> [..., 1, n, n].
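    // (E.g., for a batched input of rank 3 the permutation below is
    // {1, 0, 2, 3}, which turns the broadcast result [1, b, n, n] into
    // [b, 1, n, n]; illustration only.)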
    if (n == block_size) {
      std::vector<int64> permutation(ndims);
      std::iota(permutation.begin(), permutation.end(), 1);
      permutation.insert(permutation.end() - 2, 0);
      return Transpose(Broadcast(a, /*broadcast_sizes=*/{1}), permutation);
    }

    // We can grab entire blocks using gather.
    if (n > block_size) {
      // Construct the starting indices of the diagonal blocks.
      auto start_indices =
          Transpose(Broadcast(Mul(Iota(builder, S32, num_blocks),
                                  ConstantR0<int32>(builder, block_size)),
                              /*broadcast_sizes=*/{2}),
                    /*permutation=*/{1, 0});
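      // (E.g., with num_blocks = 3 and block_size = 2 this yields
      // start_indices = [[0, 0], [2, 2], [4, 4]] before the batch padding
      // below; illustration only.)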

      PaddingConfig padding_config =
          MakeEdgePaddingConfig({{0, 0}, {ndims - 2, 0}});
      start_indices =
          Pad(start_indices, ConstantR0<int32>(builder, 0), padding_config);

      // Gather the diagonal blocks.
      std::vector<int64> slice_sizes(ndims);
      GatherDimensionNumbers dim_numbers;
      for (int i = 0; i < ndims - 2; ++i) {
        dim_numbers.add_offset_dims(i);
        dim_numbers.add_start_index_map(i);
        slice_sizes[i] = ShapeUtil::GetDimension(shape, i);
      }
      slice_sizes[ndims - 2] = slice_sizes[ndims - 1] = block_size;
      dim_numbers.add_offset_dims(ndims - 1);
      dim_numbers.add_offset_dims(ndims);
      dim_numbers.add_start_index_map(ndims - 2);
      dim_numbers.add_start_index_map(ndims - 1);
      dim_numbers.set_index_vector_dim(1);
      diag_blocks = Gather(a, start_indices, dim_numbers, slice_sizes);
    }

    // The last block might be smaller than the block size,
    // so we will need to pad it.
    if (n % block_size != 0) {
      // Pad with an identity matrix.
      auto last_blocks =
          SliceInMinorDims(a, {n - n % block_size, n - n % block_size}, {n, n});
      PaddingConfig config = MakeNoPaddingConfig(ndims);
      int64 padding = block_size - n % block_size;
      config.mutable_dimensions(ndims - 2)->set_edge_padding_high(padding);
      last_blocks =
          Pad(last_blocks, Zero(builder, shape.element_type()), config);

      auto eye =
          IdentityMatrix(builder, shape.element_type(), padding, padding);
      config = MakeNoPaddingConfig(2);
      config.mutable_dimensions(0)->set_edge_padding_low(n % block_size);
      eye = Pad(eye, Zero(builder, shape.element_type()), config);
      eye = Broadcast(eye, batch_dims);
      last_blocks = ConcatInDim(builder, {last_blocks, eye}, ndims - 1);

      // Add a singleton dimension,
      // i.e. [..., block_size, block_size] -> [..., 1, block_size, block_size].
      TF_ASSIGN_OR_RETURN(Shape blocks_shape, builder->GetShape(last_blocks));
      auto shape_dims = AsInt64Slice(blocks_shape.dimensions());
      auto last_blocks_dims = std::vector<int64>(ndims);
      std::copy(shape_dims.begin(), shape_dims.end(), last_blocks_dims.begin());
      last_blocks_dims.insert(last_blocks_dims.end() - 2, 1);
      last_blocks = Reshape(last_blocks, last_blocks_dims);

      // Concatenate with the other blocks if necessary.
      if (n > block_size) {
        diag_blocks =
            ConcatInDim(builder, {diag_blocks, last_blocks}, ndims - 2);
      } else {
        diag_blocks = last_blocks;
      }
    }

    return diag_blocks;
  });
}

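// A rough NumPy sketch of the blocked substitution performed below, for the
// left-side/lower/non-transposed case (an illustration with invented names,
// not part of the original code):
//
// def solve_with_inverted_blocks(a, b, inv_blocks, block_size):
//   n = a.shape[-1]
//   x = None
//   for i in range(inv_blocks.shape[-3]):
//     lo, hi = i * block_size, min((i + 1) * block_size, n)
//     rem = b[lo:hi] if i == 0 else b[lo:hi] - np.dot(a[lo:hi, :lo], x)
//     xi = np.dot(inv_blocks[i][:hi - lo, :hi - lo], rem)
//     x = xi if i == 0 else np.concatenate([x, xi])
//   return x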
XlaOp SolveWithInvertedDiagonalBlocks(XlaOp a, XlaOp b, XlaOp inv_diag_blocks,
                                      bool left_side, bool lower,
                                      bool transpose_a, bool conjugate_a,
                                      PrecisionConfig::Precision precision) {
  XlaBuilder* builder = a.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape blocks_shape, builder->GetShape(inv_diag_blocks));
    TF_ASSIGN_OR_RETURN(Shape b_shape, builder->GetShape(b));
    int64 block_size = ShapeUtil::GetDimension(blocks_shape, -1);

    TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
    int64 ndims = a_shape.rank();
    int64 n = ShapeUtil::GetDimension(a_shape, -1);
    int64 num_blocks = n / block_size + (n % block_size != 0);
    int64 m_dim = (left_side) ? -1 : -2;
    int64 m = ShapeUtil::GetDimension(b_shape, m_dim);

    std::vector<XlaOp> update_ops;
    int bdims = b_shape.rank();
    int64 block_dim = (left_side) ? bdims - 2 : bdims - 1;

    // Initialize the solution.
    XlaOp x;

    // This loop is unrolled for performance reasons, but it could also be
    // rolled, since the matrices are the same size on each iteration.
    for (int i = 0; i < num_blocks; i++) {
      // High-level intuition: We have B[i] = L[i] @ X. Since L is lower
      // triangular, this means B[i] = L[i, :i + 1] @ X[:i + 1]. We can split
      // this into two parts: B[i] = L[i, :i] @ X[:i] + L[i, i] @ X[i], which
      // can be solved for X[i] as
      // X[i] = inv(L[i, i]) @ (B[i] - L[i, :i] @ X[:i]).

      // Decide whether we go from first block to last or vice versa.
      bool backward = left_side ^ lower ^ transpose_a;
      auto j = backward ? num_blocks - 1 - i : i;

      // Get the size of the inverse blocks (the last one might be smaller).
      int64 block = (n % block_size != 0 && j + 1 == num_blocks)
                        ? n % block_size
                        : block_size;
      auto inv_block =
          MaybeConjugate(Collapse(SliceInMinorDims(inv_diag_blocks, {j, 0, 0},
                                                   {j + 1, block, block}),
                                  /*dimensions=*/{ndims - 2, ndims - 1}),
                         conjugate_a);

      // Get the corresponding row of B.
      int64 k = std::min((j + 1) * block_size, n);
      std::vector<int64> start = {j * block_size, 0};
      std::vector<int64> end = {k, m};
      if (!left_side) {
        std::swap(start[0], start[1]);
        std::swap(end[0], end[1]);
      }
      auto b_row = SliceInMinorDims(b, start, end);

      XlaOp remainder;
      if (i == 0) {
        remainder = b_row;
      } else {
        // This matrix multiply, which computes L[i, :i] @ X[:i], avoids a lot
        // of multiplying by zero (namely, X[i * block_size:] = 0).
        if (backward) {
          start = {j * block_size,
                   std::max(int64{0}, (num_blocks - i) * block_size)};
          end = {k, n};
        } else {
          start = {j * block_size, 0};
          end = {k, std::min(i * block_size, n)};
        }

        if (!left_side ^ transpose_a) {
          std::swap(start[0], start[1]);
          std::swap(end[0], end[1]);
        }
        auto a_row =
            MaybeConjugate(SliceInMinorDims(a, start, end), conjugate_a);
        if (left_side) {
          remainder = b_row - BatchDot(a_row, transpose_a, x, false, precision);
        } else {
          remainder = b_row - BatchDot(x, false, a_row, transpose_a, precision);
        }
      }

      XlaOp x_update;
      if (left_side) {
        x_update =
            BatchDot(inv_block, transpose_a, remainder, false, precision);
      } else {
        x_update =
            BatchDot(remainder, false, inv_block, transpose_a, precision);
      }

      if (i == 0) {
        x = x_update;
      } else {
        if (backward) {
          x = ConcatInDim(builder, {x_update, x}, block_dim);
        } else {
          x = ConcatInDim(builder, {x, x_update}, block_dim);
        }
      }
    }

    return x;
  });
}

}  // namespace

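// A rough NumPy sketch of the row-at-a-time inversion performed below on a
// single unit lower-triangular block (an illustration with invented names,
// not part of the original code):
//
// def invert_unit_lower(l):
//   k = l.shape[-1]
//   out = -np.eye(k)
//   out[0, 0] = 1
//   for i in range(1, k):
//     out[i] = -np.dot(l[i], out)  # row i of the inverse; out[i, i] becomes 1
//   return out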
XlaOp TriangularSolveExpander::InvertDiagonalBlocks(
    XlaOp diag_blocks, bool lower_triangular,
    PrecisionConfig::Precision precision) {
  XlaBuilder* builder = diag_blocks.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    // The input is a batch of square triangular matrices. Its shape is
    // (..., size, size). We reshape this to (num_blocks, size, size).
    TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(diag_blocks));
    int64 block_size = ShapeUtil::GetDimension(shape, -1);
    int64 num_blocks = ShapeUtil::ElementsIn(shape) /
                       tensorflow::MathUtil::IPow(block_size, 2);
    diag_blocks = Reshape(diag_blocks, {num_blocks, block_size, block_size});

    // The input must be triangular because we rely on that when doing
    // multiplications later on.
    diag_blocks = Triangle(diag_blocks, /*lower=*/lower_triangular);

    // Rescale the blocks to be unit triangular, but avoid dividing by zero
    // (which can happen if the last block was padded); otherwise NaNs would
    // be introduced and propagate.
    auto diags = GetMatrixDiagonal(diag_blocks);
    auto ones = FullLike(diags, 1);
    diags = Select(Eq(diags, Zero(builder, shape.element_type())), ones, diags);
    auto scaled_diag_blocks = Div(diag_blocks, diags, {0, 2});
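    // (For instance, a single 2x2 block [[2, 0], [3, 4]] has diagonal [2, 4]
    // and rescales column-wise to the unit triangular [[1, 0], [1.5, 1]];
    // illustration only.)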

    // We can now use the fact that for a lower triangular matrix
    // [[L11, 0], [L21, L22]], given the inverses L11' and L22', the inverse is
    // [[L11', 0], [-L22' * L21 * L11', L22']]. In our case, L21 is a vector
    // and our blocks have been rescaled to be unit triangular, so
    // L22 = L22' = 1.

    // Initialize the output matrix with -1s on the diagonal. We use -1 instead
    // of 1 because we cannot do matrix-vector multiplies with variable shapes
    // inside of a loop, or do irregularly shaped in-place updates. Hence,
    // L21' <- -L22' * L21 * L11' cannot be done naively. Instead, we update
    // the entire row, i.e. we calculate
    // [L21 L22 0] <- -[L21 L22 0] @ diag_blocks([L11', -I, -I])
    // which means [L21 L22 0] <- [-L21 * L11', L22, 0].
    auto identity =
        IdentityMatrix(builder, shape.element_type(), block_size, block_size);
    auto neg_identity = -identity;

    // The first or last diagonal element should be set to 1 instead of -1,
    // though, since we never update it.
    auto pos_one = Reshape(One(builder, shape.element_type()), {1, 1});
    auto start_index =
        ConstantR0<int>(builder, lower_triangular ? 0 : block_size - 1);
    auto output_block =
        DynamicUpdateSlice(neg_identity, pos_one,
                           /*start_indices=*/{start_index, start_index});

    // Broadcast diag([1, -1, -1, ...]) to every block.
    XlaOp output = Broadcast(output_block,
                             /*broadcast_sizes=*/{num_blocks});

    // Now we construct a loop that performs matrix-vector multiplications,
    // inverting the blocks one row at a time.
    std::vector<Shape> tuple_shapes = {
        // The loop iteration counter is a scalar, incremented each iteration.
        ShapeUtil::MakeShape(S32, {}),
        // The output has the shape of A, with one row updated each iteration.
        ShapeUtil::MakeShape(shape.element_type(),
                             {num_blocks, block_size, block_size}),
        // The input is a loop invariant.
        ShapeUtil::MakeShape(shape.element_type(),
                             {num_blocks, block_size, block_size})};
    Shape tuple_shape = ShapeUtil::MakeTupleShape(tuple_shapes);

    auto init_i = One(builder, S32);
    auto init = Tuple(builder, {init_i, output, scaled_diag_blocks});

    // Construct the loop condition function.
    std::unique_ptr<XlaBuilder> condb =
        builder->CreateSubBuilder("InvertDiagCond");
    {
      auto i = GetTupleElement(
          Parameter(condb.get(), 0, tuple_shape, "InvertDiagCondTuple"), 0);
      Lt(i, ConstantR0<int32>(condb.get(), block_size));
    }
    TF_ASSIGN_OR_RETURN(auto cond, condb->Build());

    // Construct the loop body function.
    std::unique_ptr<XlaBuilder> bodyb =
        builder->CreateSubBuilder("InvertDiagBody");
    {
      auto input_tuple =
          Parameter(bodyb.get(), 0, tuple_shape, "InvertDiagBodyTuple");

      auto i = GetTupleElement(input_tuple, 0);
      auto body_out = GetTupleElement(input_tuple, 1);
      auto body_input = GetTupleElement(input_tuple, 2);

      auto zero = ConstantR0<int32>(bodyb.get(), 0);
      auto j = lower_triangular ? i : ScalarLike(i, block_size - 1) - i;
      auto input_row =
          DynamicSlice(body_input, {zero, j, zero},
                       /*slice_sizes=*/{num_blocks, 1, block_size});

      // We want -L21 L11^{-1}.
      DotDimensionNumbers dnums;
      dnums.add_lhs_batch_dimensions(0);
      dnums.add_rhs_batch_dimensions(0);
      dnums.add_lhs_contracting_dimensions(2);
      dnums.add_rhs_contracting_dimensions(1);
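      // (With these dimension numbers, the product below is equivalent to the
      // NumPy expression np.einsum('bij,bjk->bik', input_row, body_out);
      // illustration only.)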
      PrecisionConfig precision_proto;
      precision_proto.add_operand_precision(precision);
      precision_proto.add_operand_precision(precision);
      auto update = -DotGeneral(input_row, body_out, dnums, &precision_proto);

      body_out = DynamicUpdateSlice(body_out, update, {zero, j, zero});

      auto next_i = i + ScalarLike(i, 1);
      Tuple(bodyb.get(), {next_i, body_out, body_input});
    }
    TF_ASSIGN_OR_RETURN(auto body, bodyb->Build());

    // Construct the While loop and return the result,
    // return while_loop(cond_fun, body_fun, init)[1]
    auto invert_while = While(cond, body, init);
    auto inv_diag_blocks = GetTupleElement(invert_while, 1);
    // Undo the scaling.
    inv_diag_blocks = Div(inv_diag_blocks, diags,
                          /*broadcast_dimensions=*/{0, 1});

    // Reshape back to the original batch-major dimensions.
    return Reshape(inv_diag_blocks, AsInt64Slice(shape.dimensions()));
  });
}

XlaOp TriangularSolveExpander::SolveByInvertingDiagonalBlocks(
    XlaOp a, XlaOp b, bool left_side, bool lower, bool transpose_a,
    bool conjugate_a, bool unit_diagonal,
    PrecisionConfig::Precision precision) {
  XlaBuilder* builder = a.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
    const int64 ndims = a_shape.rank();
    int64 k = ShapeUtil::GetDimension(a_shape, -1);

    // TODO(phawkins): consider pushing triangle masking into
    // InvertDiagonalBlocks.
    if (unit_diagonal) {
      // Mask everything but the subdiagonal/superdiagonal elements.
      a = lower ? Select(TriangleMask(a, -1), a, ZerosLike(a))
                : Select(TriangleMask(a, 0), ZerosLike(a), a);
      a = xla::Add(a, IdentityMatrix(builder, a_shape.element_type(), k, k),
                   /*broadcast_dimensions=*/{ndims - 2, ndims - 1});
    } else {
      // Mask off the ignored elements of the triangular matrix a.
      a = Triangle(a, lower);
    }
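    // (E.g., with lower = true and unit_diagonal = true, a 2x2 matrix
    // [[2, 0], [3, 4]] is masked above to [[1, 0], [3, 1]]; illustration
    // only.)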

    // We find the diagonal blocks of the coefficient matrix.
    int64 block_size = std::min(block_size_, k);
    auto diag_blocks = DiagonalBlocks(a, block_size);

    // We invert these blocks in parallel using batched matrix-vector products.
    auto inv_diag_blocks = InvertDiagonalBlocks(diag_blocks, lower, precision);

    // We now find the solution using GEMMs.
    return SolveWithInvertedDiagonalBlocks(a, b, inv_diag_blocks, left_side,
                                           lower, transpose_a, conjugate_a,
                                           precision);
  });
}

// def trsm_left_lower_leftlooking(a, b):
//   n = a.shape[-1]
//   assert a.shape == (n, n)
//   b = b.copy()
//   for j in range(n):
//     b[j, :] = (b[j, :] - np.dot(a[j, :j], b[:j, :])) / a[j, j]
//   return b
XlaOp TriangularSolveExpander::SolveDirectly(
    XlaOp a, XlaOp b, bool left_side, bool lower, bool transpose_a,
    bool conjugate_a, bool unit_diagonal,
    PrecisionConfig::Precision precision) {
  XlaBuilder* builder = a.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
    TF_ASSIGN_OR_RETURN(Shape b_shape, builder->GetShape(b));
    int64 m = ShapeUtil::GetDimension(b_shape, -2);
    int64 n = ShapeUtil::GetDimension(b_shape, -1);
    const int64 a_size = ShapeUtil::GetDimension(a_shape, -1);
    a = MaybeConjugate(a, conjugate_a);
    bool backwards = transpose_a ^ lower ^ !left_side;
    for (int64 i = 0; i < a_size; ++i) {
      int64 j = backwards ? i : (a_size - i - 1);
      std::vector<int64> b_row_start, b_row_end;
      if (left_side) {
        b_row_start = {j, 0};
        b_row_end = {j + 1, n};
      } else {
        b_row_start = {0, j};
        b_row_end = {m, j + 1};
      }
      auto b_row = SliceInMinorDims(b, b_row_start, b_row_end);

      std::vector<int64> a_start = {j, backwards ? 0 : (j + 1)};
      std::vector<int64> a_end = {j + 1, backwards ? j : a_size};
      if (transpose_a ^ !left_side) {
        std::swap(a_start[0], a_start[1]);
        std::swap(a_end[0], a_end[1]);
      }
      auto a_chunk = SliceInMinorDims(a, a_start, a_end);
      if (left_side) {
        bool which = transpose_a ^ lower;
        auto b_chunk =
            SliceInMinorDims(b, {which ? 0 : (j + 1), 0}, {which ? j : m, n});
        b_row = b_row - BatchDot(a_chunk, /*transpose_x=*/transpose_a, b_chunk,
                                 /*transpose_y=*/false, precision);
      } else {
        bool which = transpose_a ^ !lower;
        auto b_chunk =
            SliceInMinorDims(b, {0, which ? 0 : (j + 1)}, {m, which ? j : n});
        b_row = b_row - BatchDot(b_chunk, /*transpose_x=*/false, a_chunk,
                                 /*transpose_y=*/transpose_a, precision);
      }
      if (!unit_diagonal) {
        auto a_diag = SliceInMinorDims(a, {j, j}, {j + 1, j + 1});
        b_row = b_row / a_diag;
      }

      b = UpdateSliceInMinorDims(b, b_row, b_row_start);
    }

    return b;
  });
}

XlaOp TriangularSolveExpander::BuildTriangularSolve(
    XlaOp a, XlaOp b, bool left_side, bool lower, bool transpose_a,
    bool conjugate_a, bool unit_diagonal, int64 block_size,
    PrecisionConfig::Precision precision) {
  XlaBuilder* builder = a.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
    TF_ASSIGN_OR_RETURN(Shape b_shape, builder->GetShape(b));
    if (a_shape.rank() != b_shape.rank()) {
      return InvalidArgument(
          "Arguments to TriangularSolve have shapes with different ranks: "
          "%s vs. %s",
          ShapeUtil::HumanString(a_shape), ShapeUtil::HumanString(b_shape));
    }
    const int64 ndims = a_shape.rank();
    if (ndims < 2) {
      return InvalidArgument(
          "Arguments to TriangularSolve had rank %d, but must have rank >= 2.",
          ndims);
    }
    // The batch dimensions must be equal.
    std::vector<int64> batch_dimensions;
    int64 batch = 1;
    for (int i = 0; i < ndims - 2; ++i) {
      int64 a_size = a_shape.dimensions(i);
      int64 b_size = b_shape.dimensions(i);
      if (a_size != b_size) {
        return InvalidArgument(
            "Batch dimensions of arguments to TriangularSolve must be equal; "
            "shapes were %s and %s.",
            ShapeUtil::HumanString(a_shape), ShapeUtil::HumanString(b_shape));
      }
      batch_dimensions.push_back(a_size);
      batch *= a_size;
    }

    if (ShapeUtil::GetDimension(a_shape, -1) !=
        ShapeUtil::GetDimension(a_shape, -2)) {
      return InvalidArgument(
          "The 'a' argument to TriangularSolve must be a batched square matrix;"
          " shape was: %s",
          ShapeUtil::HumanString(a_shape));
    }
    const int64 m = ShapeUtil::GetDimension(b_shape, -2);
    const int64 n = ShapeUtil::GetDimension(b_shape, -1);
    if ((left_side ? m : n) != ShapeUtil::GetDimension(a_shape, -1)) {
      return InvalidArgument(
          "Arguments to TriangularSolve have incompatible matrix shapes %s and "
          "%s",
          ShapeUtil::HumanString(a_shape), ShapeUtil::HumanString(b_shape));
    }

    int64 a_size = ShapeUtil::GetDimension(a_shape, -1);

    if (ShapeUtil::IsZeroElementArray(b_shape)) {
      // The output has the same shape as 'b', and since the output has zero
      // elements, any such array will do.
      return b;
    }

    // Degenerate case: 1x1 matrices.
    if (a_size == 1) {
      return unit_diagonal ? b : Div(b, MaybeConjugate(a, conjugate_a));
    }

    // Prefer the direct implementation whenever there is a nontrivial batch
    // dimension and the matrix is very small.
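    // (For example, assuming a block size of 128, the direct solve is chosen
    // when batch > 8 and a_size < 32.)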
    if (batch > block_size_ / 16 && a_size < block_size_ / 4) {
      return SolveDirectly(a, b, left_side, lower, transpose_a, conjugate_a,
                           unit_diagonal, precision);
    } else {
      return SolveByInvertingDiagonalBlocks(a, b, left_side, lower, transpose_a,
                                            conjugate_a, unit_diagonal,
                                            precision);
    }
  });
}

TriangularSolveExpander::TriangularSolveExpander(int64 block_size)
    : block_size_(block_size) {
  CHECK_GE(block_size_, 1);
}

bool TriangularSolveExpander::InstructionMatchesPattern(
    HloInstruction* instruction) {
  return instruction->opcode() == HloOpcode::kTriangularSolve;
}

StatusOr<HloInstruction*> TriangularSolveExpander::ExpandInstruction(
    HloInstruction* instruction) {
  const TriangularSolveOptions& options =
      instruction->triangular_solve_options();
  const string name = absl::StrFormat(
      "xla.triangular_solve_%s_%s_%s_%s_%s_%s",
      instruction->operand(0)->shape().ToString(),
      instruction->operand(1)->shape().ToString(),
      options.left_side() ? "left" : "right",
      options.lower() ? "lower" : "upper",
      TriangularSolveOptions_Transpose_Name(options.transpose_a()),
      options.unit_diagonal() ? "unit" : "nonunit");

  HloModule* module = instruction->parent()->parent();

  HloComputation*& computation =
      computation_cache_.emplace(name, nullptr).first->second;
  if (!computation) {
    // Builds a new expansion.
    //
    // We do something unusual here: we build the computation using the
    // XlaBuilder API, which is nominally an XLA client API. We do this because
    // the external APIs for building complicated computations (XlaBuilder)
    // are much more ergonomic than the internal ones. As it turns out,
    // XlaBuilder isn't really a client API; what it does is build an
    // HloModuleProto protocol buffer that we can then deserialize and clone
    // into our HloModule. Ideally we would avoid the protocol buffer step;
    // that is left as an exercise for future work.
    XlaBuilder builder(name);
    XlaOp a = Parameter(&builder, 0, instruction->operand(0)->shape(), "a");
    XlaOp b = Parameter(&builder, 1, instruction->operand(1)->shape(), "b");
    bool transpose_a =
        options.transpose_a() != TriangularSolveOptions::NO_TRANSPOSE;
    bool conjugate_a = options.transpose_a() == TriangularSolveOptions::ADJOINT;

    BuildTriangularSolve(a, b, options.left_side(), options.lower(),
                         transpose_a, conjugate_a, options.unit_diagonal(),
                         /*block_size=*/block_size_,
                         /*precision=*/PrecisionConfig::HIGHEST);
    TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build());

    TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
                        xla_computation.GetProgramShape());
    HloModuleConfig config(program_shape);
    TF_ASSIGN_OR_RETURN(auto new_module, HloModule::CreateFromProto(
                                             xla_computation.proto(), config));
    HloCloneContext context(module);
    computation =
        module->DeepCloneComputation(new_module->entry_computation(), &context);
  }

  return instruction->parent()->AddInstruction(HloInstruction::CreateCall(
      instruction->shape(), instruction->operands(), computation));
}

}  // namespace xla