1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/triangular_solve_expander.h"
17 
#include <algorithm>
#include <memory>
#include <numeric>
#include <utility>
#include <vector>
20 
21 #include "tensorflow/compiler/xla/client/lib/constants.h"
22 #include "tensorflow/compiler/xla/client/lib/math.h"
23 #include "tensorflow/compiler/xla/client/lib/matrix.h"
24 #include "tensorflow/compiler/xla/client/lib/slicing.h"
25 #include "tensorflow/compiler/xla/client/xla_builder.h"
26 #include "tensorflow/compiler/xla/client/xla_computation.h"
27 #include "tensorflow/compiler/xla/literal.h"
28 #include "tensorflow/compiler/xla/shape_util.h"
29 #include "tensorflow/compiler/xla/status_macros.h"
30 #include "tensorflow/compiler/xla/statusor.h"
31 #include "tensorflow/compiler/xla/util.h"
32 #include "tensorflow/core/lib/math/math_util.h"
33 
34 namespace xla {
35 
36 namespace {
37 
38 // Get the diagonal blocks of the coefficient matrix
DiagonalBlocks(XlaOp a,int64 block_size)39 XlaOp DiagonalBlocks(XlaOp a, int64 block_size) {
40   XlaBuilder* builder = a.builder();
41   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
42     TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(a));
43     int ndims = shape.rank();
44     int64 n = ShapeUtil::GetDimension(shape, -1);
45     int64 num_blocks = n / block_size;
46 
47     XlaOp diag_blocks;
48 
49     // If the coefficient matrix is exactly the block size, we just add a
50     // singleton dimension i.e. [..., n, n] -> [..., 1, n, n]
51     if (n == block_size) {
52       std::vector<int64> permutation(ndims);
53       std::iota(permutation.begin(), permutation.end(), 1);
54       permutation.insert(permutation.end() - 2, 0);
55       return Transpose(Broadcast(a, /*broadcast_sizes=*/{1}), permutation);
56     }
57 
58     // We can grab entire blocks using gather
59     if (n > block_size) {
60       // Construct the starting indices of the diagonal blocks
61       auto start_indices =
62           Transpose(Broadcast(Mul(Iota(builder, S32, num_blocks),
63                                   ConstantR0<int32>(builder, block_size)),
64                               /*broadcast_sizes=*/{2}),
65                     /*permutation=*/{1, 0});
66 
67       PaddingConfig padding_config =
68           MakeEdgePaddingConfig({{0, 0}, {ndims - 2, 0}});
69       start_indices =
70           Pad(start_indices, ConstantR0<int32>(builder, 0), padding_config);
71 
72       // Gather the diagonal blocks
73       std::vector<int64> slice_sizes(ndims);
74       GatherDimensionNumbers dim_numbers;
75       for (int i = 0; i < ndims - 2; ++i) {
76         dim_numbers.add_offset_dims(i);
77         dim_numbers.add_start_index_map(i);
78         slice_sizes[i] = ShapeUtil::GetDimension(shape, i);
79       }
80       slice_sizes[ndims - 2] = slice_sizes[ndims - 1] = block_size;
81       dim_numbers.add_offset_dims(ndims - 1);
82       dim_numbers.add_offset_dims(ndims);
83       dim_numbers.add_start_index_map(ndims - 2);
84       dim_numbers.add_start_index_map(ndims - 1);
85       dim_numbers.set_index_vector_dim(1);
86       diag_blocks = Gather(a, start_indices, dim_numbers, slice_sizes);
87     }
88 
89     // The last block might be smaller than the block size,
90     // so we will need to pad it
91     if (n % block_size != 0) {
92       // Pad with zeros
93       auto last_blocks =
94           SliceInMinorDims(a, {n - n % block_size, n - n % block_size}, {n, n});
95       PaddingConfig config = MakeNoPaddingConfig(ndims);
96       int64 padding = block_size - n % block_size;
97       config.mutable_dimensions(ndims - 1)->set_edge_padding_high(padding);
98       config.mutable_dimensions(ndims - 2)->set_edge_padding_high(padding);
99       last_blocks =
100           Pad(last_blocks, Zero(builder, shape.element_type()), config);
101 
102       // Add a singleton dimension
103       // i.e. [..., block_size, block_size] -> [..., 1, block_size, block_size]
104       TF_ASSIGN_OR_RETURN(Shape blocks_shape, builder->GetShape(last_blocks));
105       auto shape_dims = AsInt64Slice(blocks_shape.dimensions());
106       auto last_blocks_dims = std::vector<int64>(ndims);
107       std::copy(shape_dims.begin(), shape_dims.end(), last_blocks_dims.begin());
108       last_blocks_dims.insert(last_blocks_dims.end() - 2, 1);
109       last_blocks = Reshape(last_blocks, last_blocks_dims);
110 
111       // Concatenate with the other blocks if necessary
112       if (n > block_size) {
113         diag_blocks =
114             ConcatInDim(builder, {diag_blocks, last_blocks}, ndims - 2);
115       } else {
116         diag_blocks = last_blocks;
117       }
118     }
119 
120     return diag_blocks;
121   });
122 }
123 
InvertDiagonalBlocks(XlaOp diag_blocks,bool lower,bool transpose_a,bool conjugate_a,PrecisionConfig::Precision precision)124 XlaOp InvertDiagonalBlocks(XlaOp diag_blocks, bool lower, bool transpose_a,
125                            bool conjugate_a,
126                            PrecisionConfig::Precision precision) {
127   XlaBuilder* builder = diag_blocks.builder();
128   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
129     // Input is a batch of square lower triangular square matrices. Its shape is
130     // (..., size, size). We resize this to (num_blocks, size, size).
131     TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(diag_blocks));
132     int64 block_size = ShapeUtil::GetDimension(shape, -1);
133     int64 num_blocks = ShapeUtil::ElementsIn(shape) /
134                        tensorflow::MathUtil::IPow(block_size, 2);
135     diag_blocks = Reshape(diag_blocks, {num_blocks, block_size, block_size});
136 
137     // The input must be triangular because we rely on that when doing
138     // multiplications later on
139     diag_blocks = Triangle(diag_blocks, /*lower=*/lower);
140 
141     // Rescale blocks to be unit triangular, but avoid dividing by
142     // zero (which can happen if the last block was padded) otherwise it will
143     // introduce nans which will propagate
144     auto diags = GetMatrixDiagonal(diag_blocks);
145     auto ones = FullLike(diags, 1);
146     diags = Select(Eq(diags, Zero(builder, shape.element_type())), ones, diags);
147     auto scaled_diag_blocks = Div(diag_blocks, diags, {0, 2});
148 
149     // We can now use the fact that for an upper triangular matrix
150     // [[L11, 0], [L21, L22]], given the inverses L11' and L22', we have
151     // L22' = -L22' * L21 * L11'. In our case, L21 is a vector and our blocks
152     // have been rescaled to be unit triangular, so L22 = L22' = 1.
153 
154     // Initialize the output matrix with -1s on the diagonal. We use -1 instead
155     // of 1 because we cannot do matrix-vector multiplies with variable shapes
156     // inside of a loop, or do irregularly shaped in-place updates. Hence,
157     // L21 <- -L22 * L21 * L11 cannot be done naively. Instead, we update the
158     // entire row i.e. we calculate
159     // [L21 L22 0] <- -[L21 L22 0] @ diag_blocks([L11', -I, -I])
160     // which means [L21 L22 0] <- [-L21 * L11', L22, 0].
161     auto identity =
162         IdentityMatrix(builder, shape.element_type(), block_size, block_size);
163     auto neg_identity = -identity;
164 
165     // The first or last  diagonal element should be set to 1 instead of -1
166     // though, since we never update it
167     auto pos_one = Reshape(One(builder, shape.element_type()), {1, 1});
168     auto start_index = ConstantR0<int>(builder, (lower) ? 0 : block_size - 1);
169     auto output_block =
170         DynamicUpdateSlice(neg_identity, pos_one,
171                            /*start_indices=*/{start_index, start_index});
172 
173     // Broadcast diag([1, -1, -1, ...]) to every block
174     XlaOp output = Broadcast(output_block,
175                              /*broadcast_sizes=*/{num_blocks});
176 
177     // Now we construct a loop that performs matrix-vector multiplications
178     // inverting the blocks one row at a time
179     std::vector<Shape> tuple_shapes = {
180         // The loop iteration counter is a scalar, incremented each iteration.
181         ShapeUtil::MakeShape(S32, {}),
182         // The output has the shape of A, with one row updated each iteration.
183         ShapeUtil::MakeShape(shape.element_type(),
184                              {num_blocks, block_size, block_size}),
185         // The input is a loop invariant.
186         ShapeUtil::MakeShape(shape.element_type(),
187                              {num_blocks, block_size, block_size})};
188     Shape tuple_shape = ShapeUtil::MakeTupleShape(tuple_shapes);
189 
190     auto init_i = One(builder, S32);
191     auto init = Tuple(builder, {init_i, output, scaled_diag_blocks});
192 
193     // Construct the loop condition function.
194     std::unique_ptr<XlaBuilder> condb =
195         builder->CreateSubBuilder("InvertDiagCond");
196     {
197       auto i = GetTupleElement(
198           Parameter(condb.get(), 0, tuple_shape, "InvertDiagCondTuple"), 0);
199       Lt(i, ConstantR0<int32>(condb.get(), block_size));
200     }
201     TF_ASSIGN_OR_RETURN(auto cond, condb->Build());
202 
203     // Construct the loop body function.
204     std::unique_ptr<XlaBuilder> bodyb =
205         builder->CreateSubBuilder("InvertDiagBody");
206     {
207       auto input_tuple =
208           Parameter(bodyb.get(), 0, tuple_shape, "InvertDiagBodyTuple");
209 
210       auto i = GetTupleElement(input_tuple, 0);
211       auto body_out = GetTupleElement(input_tuple, 1);
212       auto body_input = GetTupleElement(input_tuple, 2);
213 
214       auto zero = ConstantR0<int32>(bodyb.get(), 0);
215       auto j = (lower) ? i : ScalarLike(i, block_size - 1) - i;
216       auto input_row =
217           DynamicSlice(body_input, {zero, j, zero},
218                        /*slice_sizes=*/{num_blocks, 1, block_size});
219 
220       // We want -L21 L11^{-1}
221       DotDimensionNumbers dnums;
222       dnums.add_lhs_batch_dimensions(0);
223       dnums.add_rhs_batch_dimensions(0);
224       dnums.add_lhs_contracting_dimensions(2);
225       dnums.add_rhs_contracting_dimensions(1);
226       PrecisionConfig precision_proto;
227       precision_proto.add_operand_precision(precision);
228       precision_proto.add_operand_precision(precision);
229       auto update = -DotGeneral(input_row, body_out, dnums, &precision_proto);
230 
231       body_out = DynamicUpdateSlice(body_out, update, {zero, j, zero});
232 
233       auto next_i = i + ScalarLike(i, 1);
234       Tuple(bodyb.get(), {next_i, body_out, body_input});
235     }
236     TF_ASSIGN_OR_RETURN(auto body, bodyb->Build());
237 
238     // Construct the While loop and return the result,
239     // return while_loop(cond_fun, body_fun, init)[1]
240     auto invert_while = While(cond, body, init);
241     auto inv_diag_blocks = GetTupleElement(invert_while, 1);
242 
243     // Undo the scaling
244     inv_diag_blocks = Div(inv_diag_blocks, diags,
245                           /*broadcast_dimensions=*/{0, 1});
246 
247     // Reshape back to original batch major dimensions
248     return Reshape(inv_diag_blocks, AsInt64Slice(shape.dimensions()));
249   });
250 }
251 
SolveWithInvertedDiagonalBlocks(XlaOp a,XlaOp b,XlaOp inv_diag_blocks,bool left_side,bool lower,bool transpose_a,bool conjugate_a,PrecisionConfig::Precision precision)252 XlaOp SolveWithInvertedDiagonalBlocks(XlaOp a, XlaOp b, XlaOp inv_diag_blocks,
253                                       bool left_side, bool lower,
254                                       bool transpose_a, bool conjugate_a,
255                                       PrecisionConfig::Precision precision) {
256   XlaBuilder* builder = a.builder();
257   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
258     TF_ASSIGN_OR_RETURN(Shape blocks_shape, builder->GetShape(inv_diag_blocks));
259     TF_ASSIGN_OR_RETURN(Shape b_shape, builder->GetShape(b));
260     int64 block_size = ShapeUtil::GetDimension(blocks_shape, -1);
261 
262     TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
263     int64 ndims = a_shape.rank();
264     int64 n = ShapeUtil::GetDimension(a_shape, -1);
265     int64 num_blocks = n / block_size + (n % block_size != 0);
266     int64 m_dim = (left_side) ? -1 : -2;
267     int64 m = ShapeUtil::GetDimension(b_shape, m_dim);
268 
269     // Initialize the solution
270     auto x = ZerosLike(b);
271 
272     // This loop is unrolled for performance reasons, but it could be expressed
273     // rolled as well since the matrices are of the same size each iteration
274     for (int i = 0; i < num_blocks; i++) {
275       // High-level intuition: We have B[i] = L[i] @ X. Since L is upper
276       // triangular this means B[i] = L[i, :i + 1] @ X[:i + 1]. We can split
277       // this into two parts: B[i] = L[i, :i] @ X[:i] + L[i, i] @ X[i] which
278       // can be solved for X[i] as X[i] = inv(L[i, i]) @ B[i] - L[i, :i] @ X[:i]
279 
280       // Decide whether we go from first block to last or vice versa
281       auto j = (left_side ^ lower ^ transpose_a) ? num_blocks - 1 - i : i;
282 
283       // Get the size of the inverse blocks (the last one might be smaller)
284       int64 block = (n % block_size != 0 && j + 1 == num_blocks)
285                         ? n % block_size
286                         : block_size;
287       auto inv_block =
288           MaybeConjugate(Collapse(SliceInMinorDims(inv_diag_blocks, {j, 0, 0},
289                                                    {j + 1, block, block}),
290                                   /*dimensions=*/{ndims - 2, ndims - 1}),
291                          conjugate_a);
292 
293       // Get the corresponding row of B
294       int64 k = std::min((j + 1) * block_size, n);
295       std::vector<int64> start = {j * block_size, 0};
296       std::vector<int64> end = {k, m};
297       if (!left_side) {
298         std::swap(start[0], start[1]);
299         std::swap(end[0], end[1]);
300       }
301       auto b_row = SliceInMinorDims(b, start, end);
302 
303       XlaOp remainder;
304       if (i == 0) {
305         remainder = b_row;
306       } else {
307         // This matrix multiply involves a lot of multiplying with zero (namely,
308         // X[i * block_size:] = 0), but this is faster than slicing...
309         end = {k, n};
310         if (!left_side) {
311           std::swap(end[0], end[1]);
312         }
313         if (transpose_a) {
314           std::swap(start[0], start[1]);
315           std::swap(end[0], end[1]);
316         }
317         auto a_row =
318             MaybeConjugate(SliceInMinorDims(a, start, end), conjugate_a);
319         if (left_side) {
320           remainder =
321               b_row - BatchDot(MaybeTransposeInMinorDims(a_row, transpose_a), x,
322                                precision);
323         } else {
324           remainder =
325               b_row - BatchDot(x, MaybeTransposeInMinorDims(a_row, transpose_a),
326                                precision);
327         }
328       }
329 
330       XlaOp x_update;
331       auto zero = Zero(builder, S32);
332       auto start_index = ConstantR0WithType(builder, S32, j * block_size);
333       std::vector<XlaOp> update_starts = {start_index, zero};
334       if (left_side) {
335         x_update = BatchDot(MaybeTransposeInMinorDims(inv_block, transpose_a),
336                             remainder, precision);
337       } else {
338         x_update = BatchDot(remainder,
339                             MaybeTransposeInMinorDims(inv_block, transpose_a),
340                             precision);
341         std::swap(update_starts[0], update_starts[1]);
342       }
343       x = DynamicUpdateSliceInMinorDims(x, x_update, /*starts=*/update_starts);
344     }
345 
346     return x;
347   });
348 }
349 
BuildTriangularSolve(XlaOp a,XlaOp b,bool left_side,bool lower,bool transpose_a,bool conjugate_a,bool unit_diagonal,int64 block_size,PrecisionConfig::Precision precision)350 XlaOp BuildTriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower,
351                            bool transpose_a, bool conjugate_a,
352                            bool unit_diagonal, int64 block_size,
353                            PrecisionConfig::Precision precision) {
354   XlaBuilder* builder = a.builder();
355   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
356     TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
357     TF_ASSIGN_OR_RETURN(Shape b_shape, builder->GetShape(b));
358     if (a_shape.rank() != b_shape.rank()) {
359       return InvalidArgument(
360           "Arguments to TriangularSolve have shapes with different ranks: "
361           "%s vs. %s",
362           ShapeUtil::HumanString(a_shape), ShapeUtil::HumanString(b_shape));
363     }
364     const int64 ndims = a_shape.rank();
365     if (ndims < 2) {
366       return InvalidArgument(
367           "Arguments to TriangularSolve was rank %d but must have rank >= 2.",
368           ndims);
369     }
370     // The batch dimensions must be equal.
371     std::vector<int64> batch_dimensions;
372     for (int i = 0; i < ndims - 2; ++i) {
373       int64 a_size = a_shape.dimensions(i);
374       int64 b_size = b_shape.dimensions(i);
375       if (a_size != b_size) {
376         return InvalidArgument(
377             "Batch dimensions of arguments to TriangularSolve must be equal; "
378             "shapes were %s and %s.",
379             ShapeUtil::HumanString(a_shape), ShapeUtil::HumanString(b_shape));
380       }
381       batch_dimensions.push_back(a_size);
382     }
383 
384     if (ShapeUtil::GetDimension(a_shape, -1) !=
385         ShapeUtil::GetDimension(a_shape, -2)) {
386       return InvalidArgument(
387           "The 'a' argument to TriangularSolve must be a batched square matrix;"
388           " shape was: %s",
389           ShapeUtil::HumanString(a_shape));
390     }
391     const int64 m = ShapeUtil::GetDimension(b_shape, -2);
392     const int64 n = ShapeUtil::GetDimension(b_shape, -1);
393     if ((left_side ? m : n) != ShapeUtil::GetDimension(a_shape, -1)) {
394       return InvalidArgument(
395           "Arguments to TriangularSolve have incompatible matrix shapes %s and "
396           "%s",
397           ShapeUtil::HumanString(a_shape), ShapeUtil::HumanString(b_shape));
398     }
399 
400     if (block_size < 1) {
401       return InvalidArgument(
402           "block_size argument to TriangularSolve must be >= 1; got %d",
403           block_size);
404     }
405 
406     block_size = std::max(
407         int64{1}, std::min(block_size, ShapeUtil::GetDimension(a_shape, -1)));
408 
409     if (ShapeUtil::IsZeroElementArray(b_shape)) {
410       // The output has the same shape as 'b', and since the output has zero
411       // elements, any such array will do.
412       return b;
413     }
414 
415     // TODO(phawkins): consider pushing triangle masking into
416     // InvertDiagonalBlocks.
417     if (unit_diagonal) {
418       // Mask everything but the subdiagonal/superdiagonal elements.
419       a = lower ? Select(TriangleMask(a, -1), a, ZerosLike(a))
420                 : Select(TriangleMask(a, 0), ZerosLike(a), a);
421       int64 k = ShapeUtil::GetDimension(a_shape, -1);
422       a = xla::Add(a, IdentityMatrix(builder, a_shape.element_type(), k, k),
423                    /*broadcast_dimensions=*/{ndims - 2, ndims - 1});
424     } else {
425       // Mask off the ignored elements of the triangular matrix a.
426       a = Triangle(a, lower);
427     }
428 
429     // We find the diagonal blocks of the coefficient matrix
430     auto diag_blocks = DiagonalBlocks(a, block_size);
431 
432     // We invert these blocks in parallel using batched matrix-vector products
433     auto inv_diag_blocks = InvertDiagonalBlocks(diag_blocks, lower, transpose_a,
434                                                 conjugate_a, precision);
435 
436     // We now find the solution using GEMMs
437     auto x =
438         SolveWithInvertedDiagonalBlocks(a, b, inv_diag_blocks, left_side, lower,
439                                         transpose_a, conjugate_a, precision);
440 
441     return x;
442   });
443 }
444 
445 }  // namespace
446 
InstructionMatchesPattern(HloInstruction * instruction)447 bool TriangularSolveExpander::InstructionMatchesPattern(
448     HloInstruction* instruction) {
449   return instruction->opcode() == HloOpcode::kTriangularSolve;
450 }
451 
ExpandInstruction(HloInstruction * instruction)452 StatusOr<HloInstruction*> TriangularSolveExpander::ExpandInstruction(
453     HloInstruction* instruction) {
454   const TriangularSolveOptions& options =
455       instruction->triangular_solve_options();
456   const string name = absl::StrFormat(
457       "xla.triangular_solve_%s_%s_%s_%s_%s_%s",
458       instruction->operand(0)->shape().ToString(),
459       instruction->operand(1)->shape().ToString(),
460       options.left_side() ? "left" : "right",
461       options.lower() ? "lower" : "upper",
462       TriangularSolveOptions_Transpose_Name(options.transpose_a()),
463       options.unit_diagonal() ? "unit" : "nonunit");
464 
465   HloModule* module = instruction->parent()->parent();
466 
467   HloComputation*& computation =
468       computation_cache_.emplace(name, nullptr).first->second;
469   if (!computation) {
470     // Builds a new expansion.
471     //
472     // We do something unusual here: we build the computation using the
473     // XlaBuilder API, which is nominally an XLA client API. We do this because
474     // the external APIs for building complicated computations (XlaBuilder)
475     // are much more ergonomic than the internal ones. As it turns out,
476     // XlaBuilder isn't really a client API—what it does is build a
477     // HloModuleProto protocol buffer, that we can then deserialize and clone
478     // into our HloModule. Ideally we would avoid the protocol buffer step;
479     // that is left as an exercise for future work.
480     XlaBuilder builder(name);
481     XlaOp a = Parameter(&builder, 0, instruction->operand(0)->shape(), "a");
482     XlaOp b = Parameter(&builder, 1, instruction->operand(1)->shape(), "b");
483     bool transpose_a =
484         options.transpose_a() != TriangularSolveOptions::NO_TRANSPOSE;
485     bool conjugate_a = options.transpose_a() == TriangularSolveOptions::ADJOINT;
486 
487     BuildTriangularSolve(a, b, options.left_side(), options.lower(),
488                          transpose_a, conjugate_a, options.unit_diagonal(),
489                          /*block_size=*/128,
490                          /*precision=*/PrecisionConfig::HIGHEST);
491     TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build());
492 
493     TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
494                         xla_computation.GetProgramShape());
495     HloModuleConfig config(program_shape);
496     TF_ASSIGN_OR_RETURN(auto new_module, HloModule::CreateFromProto(
497                                              xla_computation.proto(), config));
498     HloCloneContext context(module);
499     computation =
500         module->DeepCloneComputation(new_module->entry_computation(), &context);
501   }
502 
503   return instruction->parent()->AddInstruction(HloInstruction::CreateCall(
504       instruction->shape(), instruction->operands(), computation));
505 }
506 
507 }  // namespace xla
508