/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/triangular_solve_expander.h"

#include <memory>
#include <vector>

#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/math/math_util.h"

namespace xla {

namespace {

// Get the diagonal blocks of the coefficient matrix.
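// Illustrative example (not from the original source): with `a` of shape
// [..., 5, 5] and block_size = 2, the result has shape [..., 3, 2, 2] and
// holds the blocks a[0:2, 0:2], a[2:4, 2:4], and a zero-padded block
// containing a[4:5, 4:5].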
XlaOp DiagonalBlocks(XlaOp a, int64 block_size) {
  XlaBuilder* builder = a.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(a));
    int ndims = shape.rank();
    int64 n = ShapeUtil::GetDimension(shape, -1);
    int64 num_blocks = n / block_size;
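    // Note: integer division, so this counts only the full blocks; a partial
    // trailing block (when n % block_size != 0) is handled separately below.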

    XlaOp diag_blocks;

    // If the coefficient matrix is exactly one block in size, we just add a
    // singleton dimension, i.e. [..., n, n] -> [..., 1, n, n]
    if (n == block_size) {
      std::vector<int64> permutation(ndims);
      std::iota(permutation.begin(), permutation.end(), 1);
      permutation.insert(permutation.end() - 2, 0);
      return Transpose(Broadcast(a, /*broadcast_sizes=*/{1}), permutation);
    }

    // We can grab entire blocks using gather
    if (n > block_size) {
      // Construct the starting indices of the diagonal blocks
      auto start_indices =
          Transpose(Broadcast(Mul(Iota(builder, S32, num_blocks),
                                  ConstantR0<int32>(builder, block_size)),
                              /*broadcast_sizes=*/{2}),
                    /*permutation=*/{1, 0});

      PaddingConfig padding_config =
          MakeEdgePaddingConfig({{0, 0}, {ndims - 2, 0}});
      start_indices =
          Pad(start_indices, ConstantR0<int32>(builder, 0), padding_config);
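      // At this point start_indices has shape [num_blocks, ndims]: each row
      // is [0, ..., 0, i * block_size, i * block_size], so every gathered
      // slice spans the full batch dimensions and starts at the i-th diagonal
      // block in the two minor dimensions.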

      // Gather the diagonal blocks
      std::vector<int64> slice_sizes(ndims);
      GatherDimensionNumbers dim_numbers;
      for (int i = 0; i < ndims - 2; ++i) {
        dim_numbers.add_offset_dims(i);
        dim_numbers.add_start_index_map(i);
        slice_sizes[i] = ShapeUtil::GetDimension(shape, i);
      }
      slice_sizes[ndims - 2] = slice_sizes[ndims - 1] = block_size;
      dim_numbers.add_offset_dims(ndims - 1);
      dim_numbers.add_offset_dims(ndims);
      dim_numbers.add_start_index_map(ndims - 2);
      dim_numbers.add_start_index_map(ndims - 1);
      dim_numbers.set_index_vector_dim(1);
      diag_blocks = Gather(a, start_indices, dim_numbers, slice_sizes);
    }

    // The last block might be smaller than the block size,
    // so we will need to pad it
    if (n % block_size != 0) {
      // Pad with zeros
      auto last_blocks =
          SliceInMinorDims(a, {n - n % block_size, n - n % block_size}, {n, n});
      PaddingConfig config = MakeNoPaddingConfig(ndims);
      int64 padding = block_size - n % block_size;
      config.mutable_dimensions(ndims - 1)->set_edge_padding_high(padding);
      config.mutable_dimensions(ndims - 2)->set_edge_padding_high(padding);
      last_blocks =
          Pad(last_blocks, Zero(builder, shape.element_type()), config);

      // Add a singleton dimension
      // i.e. [..., block_size, block_size] -> [..., 1, block_size, block_size]
      TF_ASSIGN_OR_RETURN(Shape blocks_shape, builder->GetShape(last_blocks));
      auto shape_dims = AsInt64Slice(blocks_shape.dimensions());
      auto last_blocks_dims = std::vector<int64>(ndims);
      std::copy(shape_dims.begin(), shape_dims.end(), last_blocks_dims.begin());
      last_blocks_dims.insert(last_blocks_dims.end() - 2, 1);
      last_blocks = Reshape(last_blocks, last_blocks_dims);

      // Concatenate with the other blocks if necessary
      if (n > block_size) {
        diag_blocks =
            ConcatInDim(builder, {diag_blocks, last_blocks}, ndims - 2);
      } else {
        diag_blocks = last_blocks;
      }
    }

    return diag_blocks;
  });
}

XlaOp InvertDiagonalBlocks(XlaOp diag_blocks, bool lower, bool transpose_a,
                           bool conjugate_a,
                           PrecisionConfig::Precision precision) {
  XlaBuilder* builder = diag_blocks.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    // Input is a batch of square triangular matrices. Its shape is
    // (..., size, size). We reshape this to (num_blocks, size, size).
    TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(diag_blocks));
    int64 block_size = ShapeUtil::GetDimension(shape, -1);
    int64 num_blocks = ShapeUtil::ElementsIn(shape) /
                       tensorflow::MathUtil::IPow(block_size, 2);
    diag_blocks = Reshape(diag_blocks, {num_blocks, block_size, block_size});

    // The input must be triangular because we rely on that when doing
    // multiplications later on
    diag_blocks = Triangle(diag_blocks, /*lower=*/lower);

    // Rescale blocks to be unit triangular, but avoid dividing by zero
    // (which can happen if the last block was padded); otherwise the
    // division would introduce NaNs that propagate through the solve.
    auto diags = GetMatrixDiagonal(diag_blocks);
    auto ones = FullLike(diags, 1);
    diags = Select(Eq(diags, Zero(builder, shape.element_type())), ones, diags);
    auto scaled_diag_blocks = Div(diag_blocks, diags, {0, 2});

    // We can now use the fact that for a lower triangular block matrix
    // [[L11, 0], [L21, L22]], given the inverses L11' and L22', the (2,1)
    // block of the inverse is L21' = -L22' * L21 * L11'. In our case, L21 is
    // a vector and our blocks have been rescaled to be unit triangular, so
    // L22 = L22' = 1.

    // Initialize the output matrix with -1s on the diagonal. We use -1 instead
    // of 1 because we cannot do matrix-vector multiplies with variable shapes
    // inside of a loop, or do irregularly shaped in-place updates. Hence,
    // L21 <- -L22 * L21 * L11 cannot be done naively. Instead, we update the
    // entire row i.e. we calculate
    // [L21 L22 0] <- -[L21 L22 0] @ diag_blocks([L11', -I, -I])
    // which means [L21 L22 0] <- [-L21 * L11', L22, 0].
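    // Illustrative walkthrough (not part of the computation): for a 2x2 unit
    // lower triangular block L = [[1, 0], [a, 1]], the output starts as
    // [[1, 0], [0, -1]]. The single loop iteration computes
    //   -[a, 1] @ [[1, 0], [0, -1]] = [-a, 1]
    // and writes it into row 1, yielding [[1, 0], [-a, 1]] = inv(L).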
    auto identity =
        IdentityMatrix(builder, shape.element_type(), block_size, block_size);
    auto neg_identity = -identity;

    // The first or last diagonal element should be set to 1 instead of -1
    // though, since we never update it
    auto pos_one = Reshape(One(builder, shape.element_type()), {1, 1});
    auto start_index = ConstantR0<int>(builder, (lower) ? 0 : block_size - 1);
    auto output_block =
        DynamicUpdateSlice(neg_identity, pos_one,
                           /*start_indices=*/{start_index, start_index});

    // Broadcast diag([1, -1, -1, ...]) to every block
    XlaOp output = Broadcast(output_block,
                             /*broadcast_sizes=*/{num_blocks});

    // Now we construct a loop that performs matrix-vector multiplications
    // inverting the blocks one row at a time
    std::vector<Shape> tuple_shapes = {
        // The loop iteration counter is a scalar, incremented each iteration.
        ShapeUtil::MakeShape(S32, {}),
        // The output has the shape of A, with one row updated each iteration.
        ShapeUtil::MakeShape(shape.element_type(),
                             {num_blocks, block_size, block_size}),
        // The input is a loop invariant.
        ShapeUtil::MakeShape(shape.element_type(),
                             {num_blocks, block_size, block_size})};
    Shape tuple_shape = ShapeUtil::MakeTupleShape(tuple_shapes);

    auto init_i = One(builder, S32);
    auto init = Tuple(builder, {init_i, output, scaled_diag_blocks});

    // Construct the loop condition function.
    std::unique_ptr<XlaBuilder> condb =
        builder->CreateSubBuilder("InvertDiagCond");
    {
      auto i = GetTupleElement(
          Parameter(condb.get(), 0, tuple_shape, "InvertDiagCondTuple"), 0);
      Lt(i, ConstantR0<int32>(condb.get(), block_size));
    }
    TF_ASSIGN_OR_RETURN(auto cond, condb->Build());

    // Construct the loop body function.
    std::unique_ptr<XlaBuilder> bodyb =
        builder->CreateSubBuilder("InvertDiagBody");
    {
      auto input_tuple =
          Parameter(bodyb.get(), 0, tuple_shape, "InvertDiagBodyTuple");

      auto i = GetTupleElement(input_tuple, 0);
      auto body_out = GetTupleElement(input_tuple, 1);
      auto body_input = GetTupleElement(input_tuple, 2);

      auto zero = ConstantR0<int32>(bodyb.get(), 0);
      auto j = (lower) ? i : ScalarLike(i, block_size - 1) - i;
      auto input_row =
          DynamicSlice(body_input, {zero, j, zero},
                       /*slice_sizes=*/{num_blocks, 1, block_size});

      // We want -L21 L11^{-1}
      DotDimensionNumbers dnums;
      dnums.add_lhs_batch_dimensions(0);
      dnums.add_rhs_batch_dimensions(0);
      dnums.add_lhs_contracting_dimensions(2);
      dnums.add_rhs_contracting_dimensions(1);
      PrecisionConfig precision_proto;
      precision_proto.add_operand_precision(precision);
      precision_proto.add_operand_precision(precision);
      auto update = -DotGeneral(input_row, body_out, dnums, &precision_proto);
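      // With these dimension numbers, the DotGeneral above is a batched
      // matmul of shapes [num_blocks, 1, block_size] @
      // [num_blocks, block_size, block_size] -> [num_blocks, 1, block_size],
      // i.e. it updates one row of every block at once.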

      body_out = DynamicUpdateSlice(body_out, update, {zero, j, zero});

      auto next_i = i + ScalarLike(i, 1);
      Tuple(bodyb.get(), {next_i, body_out, body_input});
    }
    TF_ASSIGN_OR_RETURN(auto body, bodyb->Build());

    // Construct the While loop and return the result, i.e.
    // return while_loop(cond_fun, body_fun, init)[1]
    auto invert_while = While(cond, body, init);
    auto inv_diag_blocks = GetTupleElement(invert_while, 1);

    // Undo the scaling
    inv_diag_blocks = Div(inv_diag_blocks, diags,
                          /*broadcast_dimensions=*/{0, 1});

    // Reshape back to original batch major dimensions
    return Reshape(inv_diag_blocks, AsInt64Slice(shape.dimensions()));
  });
}

XlaOp SolveWithInvertedDiagonalBlocks(XlaOp a, XlaOp b, XlaOp inv_diag_blocks,
                                      bool left_side, bool lower,
                                      bool transpose_a, bool conjugate_a,
                                      PrecisionConfig::Precision precision) {
  XlaBuilder* builder = a.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape blocks_shape, builder->GetShape(inv_diag_blocks));
    TF_ASSIGN_OR_RETURN(Shape b_shape, builder->GetShape(b));
    int64 block_size = ShapeUtil::GetDimension(blocks_shape, -1);

    TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
    int64 ndims = a_shape.rank();
    int64 n = ShapeUtil::GetDimension(a_shape, -1);
    int64 num_blocks = n / block_size + (n % block_size != 0);
    int64 m_dim = (left_side) ? -1 : -2;
    int64 m = ShapeUtil::GetDimension(b_shape, m_dim);

    // Initialize the solution
    auto x = ZerosLike(b);

    // This loop is unrolled for performance reasons, but it could be expressed
    // rolled as well since the matrices are of the same size each iteration
    for (int i = 0; i < num_blocks; i++) {
      // High-level intuition: We have B[i] = L[i] @ X. Since L is lower
      // triangular this means B[i] = L[i, :i + 1] @ X[:i + 1]. We can split
      // this into two parts: B[i] = L[i, :i] @ X[:i] + L[i, i] @ X[i], which
      // can be solved for X[i] as
      // X[i] = inv(L[i, i]) @ (B[i] - L[i, :i] @ X[:i])
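      // For instance, with two blocks in the lower-triangular left-side case
      // this is just blocked forward substitution:
      //   X[0] = inv(L[0, 0]) @ B[0]
      //   X[1] = inv(L[1, 1]) @ (B[1] - L[1, 0] @ X[0])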

      // Decide whether we go from first block to last or vice versa
      auto j = (left_side ^ lower ^ transpose_a) ? num_blocks - 1 - i : i;
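      // E.g. a left-side lower-triangular solve without transpose (forward
      // substitution) walks the blocks first-to-last, while the corresponding
      // upper-triangular solve (back substitution) walks them last-to-first.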

      // Get the size of the inverse blocks (the last one might be smaller)
      int64 block = (n % block_size != 0 && j + 1 == num_blocks)
                        ? n % block_size
                        : block_size;
      auto inv_block =
          MaybeConjugate(Collapse(SliceInMinorDims(inv_diag_blocks, {j, 0, 0},
                                                   {j + 1, block, block}),
                                  /*dimensions=*/{ndims - 2, ndims - 1}),
                         conjugate_a);

      // Get the corresponding row of B
      int64 k = std::min((j + 1) * block_size, n);
      std::vector<int64> start = {j * block_size, 0};
      std::vector<int64> end = {k, m};
      if (!left_side) {
        std::swap(start[0], start[1]);
        std::swap(end[0], end[1]);
      }
      auto b_row = SliceInMinorDims(b, start, end);

      XlaOp remainder;
      if (i == 0) {
        remainder = b_row;
      } else {
        // This matrix multiply involves a lot of multiplying with zero (namely,
        // X[i * block_size:] = 0), but this is faster than slicing...
        end = {k, n};
        if (!left_side) {
          std::swap(end[0], end[1]);
        }
        if (transpose_a) {
          std::swap(start[0], start[1]);
          std::swap(end[0], end[1]);
        }
        auto a_row =
            MaybeConjugate(SliceInMinorDims(a, start, end), conjugate_a);
        if (left_side) {
          remainder =
              b_row - BatchDot(MaybeTransposeInMinorDims(a_row, transpose_a), x,
                               precision);
        } else {
          remainder =
              b_row - BatchDot(x, MaybeTransposeInMinorDims(a_row, transpose_a),
                               precision);
        }
      }

      XlaOp x_update;
      auto zero = Zero(builder, S32);
      auto start_index = ConstantR0WithType(builder, S32, j * block_size);
      std::vector<XlaOp> update_starts = {start_index, zero};
      if (left_side) {
        x_update = BatchDot(MaybeTransposeInMinorDims(inv_block, transpose_a),
                            remainder, precision);
      } else {
        x_update = BatchDot(remainder,
                            MaybeTransposeInMinorDims(inv_block, transpose_a),
                            precision);
        std::swap(update_starts[0], update_starts[1]);
      }
      x = DynamicUpdateSliceInMinorDims(x, x_update, /*starts=*/update_starts);
    }

    return x;
  });
}

XlaOp BuildTriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower,
                           bool transpose_a, bool conjugate_a,
                           bool unit_diagonal, int64 block_size,
                           PrecisionConfig::Precision precision) {
  XlaBuilder* builder = a.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
    TF_ASSIGN_OR_RETURN(Shape b_shape, builder->GetShape(b));
    if (a_shape.rank() != b_shape.rank()) {
      return InvalidArgument(
          "Arguments to TriangularSolve have shapes with different ranks: "
          "%s vs. %s",
          ShapeUtil::HumanString(a_shape), ShapeUtil::HumanString(b_shape));
    }
    const int64 ndims = a_shape.rank();
    if (ndims < 2) {
      return InvalidArgument(
          "Arguments to TriangularSolve have rank %d but must have rank >= 2.",
          ndims);
    }
    // The batch dimensions must be equal.
    std::vector<int64> batch_dimensions;
    for (int i = 0; i < ndims - 2; ++i) {
      int64 a_size = a_shape.dimensions(i);
      int64 b_size = b_shape.dimensions(i);
      if (a_size != b_size) {
        return InvalidArgument(
            "Batch dimensions of arguments to TriangularSolve must be equal; "
            "shapes were %s and %s.",
            ShapeUtil::HumanString(a_shape), ShapeUtil::HumanString(b_shape));
      }
      batch_dimensions.push_back(a_size);
    }

    if (ShapeUtil::GetDimension(a_shape, -1) !=
        ShapeUtil::GetDimension(a_shape, -2)) {
      return InvalidArgument(
          "The 'a' argument to TriangularSolve must be a batched square matrix;"
          " shape was: %s",
          ShapeUtil::HumanString(a_shape));
    }
    const int64 m = ShapeUtil::GetDimension(b_shape, -2);
    const int64 n = ShapeUtil::GetDimension(b_shape, -1);
    if ((left_side ? m : n) != ShapeUtil::GetDimension(a_shape, -1)) {
      return InvalidArgument(
          "Arguments to TriangularSolve have incompatible matrix shapes %s and "
          "%s",
          ShapeUtil::HumanString(a_shape), ShapeUtil::HumanString(b_shape));
    }

    if (block_size < 1) {
      return InvalidArgument(
          "block_size argument to TriangularSolve must be >= 1; got %d",
          block_size);
    }

    block_size = std::max(
        int64{1}, std::min(block_size, ShapeUtil::GetDimension(a_shape, -1)));

    if (ShapeUtil::IsZeroElementArray(b_shape)) {
      // The output has the same shape as 'b', and since the output has zero
      // elements, any such array will do.
      return b;
    }

    // TODO(phawkins): consider pushing triangle masking into
    // InvertDiagonalBlocks.
    if (unit_diagonal) {
      // Mask off everything but the strictly lower/upper triangular elements.
      a = lower ? Select(TriangleMask(a, -1), a, ZerosLike(a))
                : Select(TriangleMask(a, 0), ZerosLike(a), a);
      int64 k = ShapeUtil::GetDimension(a_shape, -1);
      a = xla::Add(a, IdentityMatrix(builder, a_shape.element_type(), k, k),
                   /*broadcast_dimensions=*/{ndims - 2, ndims - 1});
    } else {
      // Mask off the ignored elements of the triangular matrix a.
      a = Triangle(a, lower);
    }

    // We find the diagonal blocks of the coefficient matrix
    auto diag_blocks = DiagonalBlocks(a, block_size);

    // We invert these blocks in parallel using batched matrix-vector products
    auto inv_diag_blocks = InvertDiagonalBlocks(diag_blocks, lower, transpose_a,
                                                conjugate_a, precision);
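    // inv_diag_blocks has the same shape as diag_blocks, i.e.
    // [..., num_blocks, block_size, block_size], with each block inverted.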

    // We now find the solution using GEMMs
    auto x =
        SolveWithInvertedDiagonalBlocks(a, b, inv_diag_blocks, left_side, lower,
                                        transpose_a, conjugate_a, precision);

    return x;
  });
}

}  // namespace

bool TriangularSolveExpander::InstructionMatchesPattern(
    HloInstruction* instruction) {
  return instruction->opcode() == HloOpcode::kTriangularSolve;
}

StatusOr<HloInstruction*> TriangularSolveExpander::ExpandInstruction(
    HloInstruction* instruction) {
  const TriangularSolveOptions& options =
      instruction->triangular_solve_options();
  const string name = absl::StrFormat(
      "xla.triangular_solve_%s_%s_%s_%s_%s_%s",
      instruction->operand(0)->shape().ToString(),
      instruction->operand(1)->shape().ToString(),
      options.left_side() ? "left" : "right",
      options.lower() ? "lower" : "upper",
      TriangularSolveOptions_Transpose_Name(options.transpose_a()),
      options.unit_diagonal() ? "unit" : "nonunit");

  HloModule* module = instruction->parent()->parent();

  HloComputation*& computation =
      computation_cache_.emplace(name, nullptr).first->second;
  if (!computation) {
    // Builds a new expansion.
    //
    // We do something unusual here: we build the computation using the
    // XlaBuilder API, which is nominally an XLA client API. We do this because
    // the external APIs for building complicated computations (XlaBuilder)
    // are much more ergonomic than the internal ones. As it turns out,
    // XlaBuilder isn't really a client API: what it does is build an
    // HloModuleProto protocol buffer, which we can then deserialize and clone
    // into our HloModule. Ideally we would avoid the protocol buffer step;
    // that is left as an exercise for future work.
    XlaBuilder builder(name);
    XlaOp a = Parameter(&builder, 0, instruction->operand(0)->shape(), "a");
    XlaOp b = Parameter(&builder, 1, instruction->operand(1)->shape(), "b");
    bool transpose_a =
        options.transpose_a() != TriangularSolveOptions::NO_TRANSPOSE;
    bool conjugate_a = options.transpose_a() == TriangularSolveOptions::ADJOINT;

    BuildTriangularSolve(a, b, options.left_side(), options.lower(),
                         transpose_a, conjugate_a, options.unit_diagonal(),
                         /*block_size=*/128,
                         /*precision=*/PrecisionConfig::HIGHEST);
    TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build());

    TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
                        xla_computation.GetProgramShape());
    HloModuleConfig config(program_shape);
    TF_ASSIGN_OR_RETURN(auto new_module, HloModule::CreateFromProto(
                                             xla_computation.proto(), config));
    HloCloneContext context(module);
    computation =
        module->DeepCloneComputation(new_module->entry_computation(), &context);
  }

  return instruction->parent()->AddInstruction(HloInstruction::CreateCall(
      instruction->shape(), instruction->operands(), computation));
}

}  // namespace xla