/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/cholesky_expander.h"

#include <memory>
#include <vector>

#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/loops.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
namespace xla {

// The Cholesky–Banachiewicz algorithm. See
// https://en.wikipedia.org/wiki/Cholesky_decomposition#The_Cholesky–Banachiewicz_and_Cholesky–Crout_algorithms
// for a description.
//
// A NumPy sketch of the unblocked algorithm (real-valued input, ignoring
// batch dimensions):
//
// def cholesky_unblocked(a):
//   assert a.ndim == 2 and a.shape[-2] == a.shape[-1]
//   n = a.shape[-2]
//   l = np.zeros_like(a)
//   for j in range(n):
//     # mask[i, k] = 1 when i >= k and k == j, i.e. it selects the entries
//     # of column j on or below the diagonal.
//     mask = np.zeros_like(a)
//     mask[j:, j] = 1
//     temp = a - np.dot(l, l.T)
//     # Dividing column j of temp by sqrt(temp[j, j]) yields column j of l;
//     # the diagonal entry itself becomes sqrt(temp[j, j]).
//     l = temp / np.sqrt(temp[j, j]) * mask + l
//   return l
// Returns a (result, error) pair, where `error` is a per-batch predicate
// that is true if the factorization failed, i.e. the input was not
// Hermitian positive definite.
StatusOr<std::pair<XlaOp, XlaOp>> CholeskyExpander::CholeskyUnblocked(
    XlaOp a, PrecisionConfig::Precision precision) {
  XlaBuilder* builder = a.builder();
  TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
  const int ndims = a_shape.rank();
  const int64 n = ShapeUtil::GetDimension(a_shape, -1);
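  // Shape of the per-batch error flag: the batch dimensions of `a` with a
  // trailing 1x1 matrix in place of the n x n one.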
  std::vector<int64> error_dims(a_shape.dimensions().begin(),
                                a_shape.dimensions().end());
  error_dims.back() = error_dims.at(ndims - 2) = 1;

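  // The batch dimensions (all but the trailing two) and the full dimension
  // span of `a`, respectively.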
  auto major_dims = AsInt64Slice(a_shape.dimensions())
                        .subspan(
                            /*pos=*/0,
                            /*len=*/ndims - 2);

  auto matrix_dims = AsInt64Slice(a_shape.dimensions())
                         .subspan(
                             /*pos=*/0,
                             /*len=*/ndims);

  XlaOp l = ZerosLike(a);

  // Construct the loop body, which computes one column of L per iteration.
  auto body_fn = [&](XlaOp i, absl::Span<const XlaOp> loop_vars,
                     XlaBuilder* body_builder) -> StatusOr<std::vector<XlaOp>> {
    auto body_a = loop_vars[0];
    auto body_l = loop_vars[1];
    auto seen_error = loop_vars[2];
    auto iota_row =
        Iota(body_builder, ShapeUtil::MakeShape(S32, matrix_dims), ndims - 1);
    auto iota_col =
        Iota(body_builder, ShapeUtil::MakeShape(S32, matrix_dims), ndims - 2);

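    // iota_row counts along the last (column) dimension and iota_col along
    // the row dimension, so mask_pred selects the entries of column i on or
    // below the diagonal: Ge(iota_col, iota_row) is "row >= column" and
    // Eq(iota_row, i) is "column == i".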
    auto mask_pred = Ge(iota_col, iota_row);
    mask_pred = And(mask_pred, Eq(iota_row, i));
    auto mask_zeros =
        Zeros(body_builder,
              ShapeUtil::MakeShape(a_shape.element_type(), matrix_dims));
    // L * L.T. This product does a lot of multiplying by zero (namely,
    // L[:, j:] = 0) and redundant computation, but it is still faster than
    // slicing.
    auto l_square =
        BatchDot(body_l, false, MaybeConjugate(body_l, true), true, precision);

    // A - L * L.T
    l_square = body_a - l_square;
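    // Extract the (i, i) element; its square root becomes the new diagonal
    // entry of L. A NaN from the square root means the matrix was not
    // positive definite.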
    auto l_ii = DynamicSliceInMinorDims(l_square, {i, i}, {1, 1});
    if (ShapeUtil::ElementIsComplex(a_shape)) {
      auto sqrt = Sqrt(Real(l_ii));
      l_ii = Complex(sqrt, ZerosLike(sqrt));
      seen_error = Or(seen_error, IsNan(sqrt));
    } else {
      l_ii = Sqrt(l_ii);
      seen_error = Or(seen_error, IsNan(l_ii));
    }
    // L = (A - L*L.T) / l_ii * mask + L
    body_l = Select(mask_pred, l_square / l_ii, mask_zeros) + body_l;

    return std::vector<XlaOp>{body_a, body_l, seen_error};
  };

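  // Run the body for i = 0, ..., n-1, threading (a, l, seen_error) through
  // the loop as loop-carried values.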
  TF_ASSIGN_OR_RETURN(
      auto cholesky_while,
      ForEachIndex(
          n, S32, body_fn,
          {a, l, Zeros(builder, ShapeUtil::MakeShape(PRED, error_dims))},
          "unblocked", builder));

  return std::make_pair(cholesky_while[1], cholesky_while[2]);
}

XlaOp CholeskyExpander::BuildCholesky(XlaOp a, int64 block_size,
                                      PrecisionConfig::Precision precision) {
  XlaBuilder* builder = a.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
    const int ndims = a_shape.rank();
    if (ndims < 2) {
      return InvalidArgument(
          "Argument to Cholesky must have rank >= 2; shape was %s",
          a_shape.ToString());
    }

    const int64 n = ShapeUtil::GetDimension(a_shape, -1);
    if (n != ShapeUtil::GetDimension(a_shape, -2)) {
      return InvalidArgument(
          "Argument to Cholesky must be batched square matrices; got shape %s",
          ShapeUtil::HumanString(a_shape));
    }

    if (block_size < 1) {
      return InvalidArgument(
          "block_size argument to Cholesky must be >= 1; got %d", block_size);
    }

    std::vector<int64> error_dims(a_shape.dimensions().begin(),
                                  a_shape.dimensions().end());
    error_dims.back() = error_dims.at(ndims - 2) = 1;
    std::vector<int64> error_dim_indices(ndims);
    absl::c_iota(error_dim_indices, 0);

    // Blocked left-looking Cholesky factorization.
    // Algorithm 1 from
    // Haidar, Azzam, et al. "High-performance Cholesky factorization for
    // GPU-only execution." Proceedings of General Purpose GPUs. ACM, 2017.
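    //
    // Roughly, in NumPy/SciPy-style pseudocode (a sketch of the loop below,
    // not the exact lowering; `cholesky_unblocked` stands for
    // CholeskyUnblocked above, `solve_triangular` is used in the
    // scipy.linalg sense, and batch dimensions are ignored):
    //
    //   for i in range(0, n, block_size):
    //     k = min(block_size, n - i)
    //     panel = a[i:, i:i+k] - np.dot(l[i:, :i], l[i:i+k, :i].T.conj())
    //     l[i:i+k, i:i+k] = cholesky_unblocked(panel[:k, :])
    //     if i + k < n:
    //       # Solve x @ l[i:i+k, i:i+k].T.conj() = panel[k:, :] for x.
    //       l[i+k:, i:i+k] = solve_triangular(
    //           l[i:i+k, i:i+k], panel[k:, :].T.conj(), lower=True).T.conj()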
    XlaOp l = ZerosLike(a);
    XlaOp seen_error = Zeros(builder, ShapeUtil::MakeShape(PRED, error_dims));
    for (int64 i = 0; i < n; i += block_size) {
      int64 k = std::min(block_size, n - i);
      auto panel = SliceInMinorDims(a, {i, i}, {n, i + k});
      if (i > 0) {
        // TODO(phawkins): consider implementing SYRK for the diagonal part of
        // the panel.
        // a[i:, i:i+k] -= np.dot(l[i:, :i], np.transpose(l[i:i+k, :i]))
        auto lhs = SliceInMinorDims(l, {i, 0}, {n, i});
        auto rhs = SliceInMinorDims(l, {i, 0}, {i + k, i});
        auto delta =
            BatchDot(lhs, false, MaybeConjugate(rhs, true), true, precision);
        panel = panel - delta;
      }

      // l[i:i+k, i:i+k] = cholesky_unblocked(a[i:i+k, i:i+k])
      auto x = SliceInMinorDims(panel, {0, 0}, {k, k});
      XlaOp factorized;
      // TODO(b/167896062): A failure in one element of a batch shouldn't fail
      // other elements.
      XlaOp factorized_error;
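      // A 1x1 block factorizes directly as sqrt(a), so skip the unblocked
      // loop in that case.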
      if (k == 1) {
        if (ShapeUtil::ElementIsComplex(a_shape)) {
          auto sqrt = Sqrt(Real(x));
          factorized = Complex(sqrt, ZerosLike(sqrt));
          factorized_error = IsNan(sqrt);
        } else {
          factorized = Sqrt(x);
          factorized_error = IsNan(factorized);
        }
      } else {
        TF_ASSIGN_OR_RETURN(auto tile_output, CholeskyUnblocked(x, precision));
        std::tie(factorized, factorized_error) = tile_output;
      }
      seen_error = Or(seen_error, factorized_error);
      l = UpdateSliceInMinorDims(l, factorized, {i, i});

      if (i + k < n) {
        // l[i+k:, i:i+k] =
        //     trsm_right_transpose(l[i:i+k, i:i+k], a[i+k:, i:i+k])
        auto update = TriangularSolve(
            factorized, SliceInMinorDims(panel, {k, 0}, {n - i, k}),
            /*left_side=*/false,
            /*lower=*/true,
            /*unit_diagonal=*/false,
            /*transpose_a=*/TriangularSolveOptions::ADJOINT);
        l = UpdateSliceInMinorDims(l, update, {i + k, i});
      }
    }
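    // Where an error occurred in a batch element, replace its entire result
    // with NaNs rather than returning a partially valid factor.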
    return Select(
        BroadcastInDim(seen_error, a_shape.dimensions(), error_dim_indices),
        FullLike(l, std::numeric_limits<float>::quiet_NaN()), l);
  });
}

bool CholeskyExpander::InstructionMatchesPattern(HloInstruction* instruction) {
  return instruction->opcode() == HloOpcode::kCholesky;
}

StatusOr<HloInstruction*> CholeskyExpander::ExpandInstruction(
    HloInstruction* instruction) {
  const CholeskyOptions& options = instruction->cholesky_options();
  const string name = absl::StrFormat(
      "xla.cholesky_%s_%s", instruction->operand(0)->shape().ToString(),
      options.lower() ? "lower" : "upper");

  HloModule* module = instruction->parent()->parent();

  HloComputation*& computation =
      computation_cache_.emplace(name, nullptr).first->second;
  if (!computation) {
    // Builds a new expansion.
    //
    // TODO(b/62327888): We do something unusual here: we build the computation
    // using the XlaBuilder API, which is nominally an XLA client API. We do
    // this because the external APIs for building complicated computations
    // (XlaBuilder) are much more ergonomic than the internal ones. As it turns
    // out, XlaBuilder isn't really a client API; what it does is build an
    // HloModuleProto protocol buffer that we can then deserialize and clone
    // into our HloModule. Ideally we would avoid the protocol buffer step;
    // that is left as an exercise for future work.
    XlaBuilder builder(name);
    XlaOp a = Parameter(&builder, 0, instruction->operand(0)->shape(), "a");
    XlaOp l = BuildCholesky(MaybeTransposeInMinorDims(a, !options.lower()),
                            /*block_size=*/128,
                            /*precision=*/PrecisionConfig::HIGHEST);
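    // If the caller asked for an upper factorization, transpose the result
    // back. The return value is intentionally unused: when Build() is called
    // without an explicit root, the root is the last instruction added to
    // the builder, which is this transpose (or, in the lower case, the
    // result of BuildCholesky, since MaybeTransposeInMinorDims is then a
    // no-op).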
    MaybeTransposeInMinorDims(l, !options.lower());

    TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build());

    TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
                        xla_computation.GetProgramShape());
    HloModuleConfig config(program_shape);
    TF_ASSIGN_OR_RETURN(auto new_module, HloModule::CreateFromProto(
                                             xla_computation.proto(), config));
    HloCloneContext context(module);
    computation =
        module->DeepCloneComputation(new_module->entry_computation(), &context);
  }

  return instruction->parent()->AddInstruction(HloInstruction::CreateCall(
      instruction->shape(), instruction->operands(), computation));
}

}  // namespace xla