1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/cholesky_expander.h"
17 
18 #include <memory>
19 #include <vector>
20 
21 #include "tensorflow/compiler/xla/client/lib/constants.h"
22 #include "tensorflow/compiler/xla/client/lib/loops.h"
23 #include "tensorflow/compiler/xla/client/lib/math.h"
24 #include "tensorflow/compiler/xla/client/lib/matrix.h"
25 #include "tensorflow/compiler/xla/client/lib/slicing.h"
26 #include "tensorflow/compiler/xla/client/xla_builder.h"
27 #include "tensorflow/compiler/xla/literal.h"
28 #include "tensorflow/compiler/xla/primitive_util.h"
29 #include "tensorflow/compiler/xla/shape_util.h"
30 #include "tensorflow/compiler/xla/status_macros.h"
31 #include "tensorflow/compiler/xla/statusor.h"
32 #include "tensorflow/compiler/xla/util.h"
33 #include "tensorflow/core/lib/core/errors.h"
34 
35 namespace xla {
36 
37 namespace {
38 
// The Cholesky–Banachiewicz algorithm. See
// https://en.wikipedia.org/wiki/Cholesky_decomposition#The_Cholesky–Banachiewicz_and_Cholesky–Crout_algorithms
// for a description.
//
// def cholesky_unblocked(a):
//   assert len(a.shape) == 2 and a.shape[-2] == a.shape[-1]
//   n = a.shape[-2]
//   l = np.zeros_like(a)
//   for j in xrange(n):
//     row = l[..., j, :j]
//     row_t = np.swapaxes(row, -1, -2)
//     l[..., j, j] = np.sqrt(a[..., j, j] - np.dot(row, row_t))
//     l[..., j+1:, j] = (a[..., j+1:, j] - np.dot(l[..., j+1:, :j], row_t)) /
//                       l[..., j, j]
//   return l
CholeskyUnblocked(XlaOp a,PrecisionConfig::Precision precision)54 XlaOp CholeskyUnblocked(XlaOp a, PrecisionConfig::Precision precision) {
55   XlaBuilder* builder = a.builder();
56   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
57     TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
58     const int n_dims = a_shape.rank();
59     const int64 n = ShapeUtil::GetDimension(a_shape, -1);
60     auto major_dims = AsInt64Slice(a_shape.dimensions())
61                           .subspan(
62                               /*pos=*/0,
63                               /*len=*/n_dims - 2);
64 
65     XlaOp l = ZerosLike(a);
66 
67     // Construct the for loop body to iterate over rows.
68     auto body_fn =
69         [&](XlaOp i, absl::Span<const XlaOp> loop_vars,
70             XlaBuilder* body_builder) -> StatusOr<std::vector<XlaOp>> {
71       std::vector<int64> row_shape_dims(major_dims.begin(), major_dims.end());
72       std::vector<int64> col_shape_dims(major_dims.begin(), major_dims.end());
73       row_shape_dims.push_back(1);
74       row_shape_dims.push_back(n);
75       auto mask_zeros_row =
76           Zeros(body_builder,
77                 ShapeUtil::MakeShape(a_shape.element_type(), row_shape_dims));
78 
79       col_shape_dims.push_back(n);
80       col_shape_dims.push_back(1);
81       auto mask_zeros_col =
82           Zeros(body_builder,
83                 ShapeUtil::MakeShape(a_shape.element_type(), col_shape_dims));
84 
85       auto mask_range_row =
86           Iota(body_builder, ShapeUtil::MakeShape(S32, row_shape_dims),
87                /*iota_dimension=*/n_dims - 1);
88       auto mask_range_col =
89           Iota(body_builder, ShapeUtil::MakeShape(S32, col_shape_dims),
90                /*iota_dimension=*/n_dims - 2);
91       auto body_a = loop_vars[0];
92       auto body_l = loop_vars[1];
93 
94       // row = l[..., i, :i]
95       // select the whole i-th row, then mask out all columns past i-1
96       auto zero = ConstantR0<int32>(body_builder, 0);
97       auto l_i = DynamicSliceInMinorDims(body_l, {i, zero}, {1, n});
98       auto row = Select(Ge(mask_range_row, i), mask_zeros_row, l_i);
99       // a[..., i, i]
100       auto a_ii = DynamicSliceInMinorDims(body_a, {i, i}, {1, 1});
101       // np.dot(row, np.swapaxes(row, -1, -2))
102       auto diag_dot = BatchDot(row, TransposeInMinorDims(row), precision);
103       // l[..., i, i] = np.sqrt(a[..., i, i] - np.dot(row,
104       //                                              np.swapaxes(row, -1, -2)))
105       auto l_ii = Sqrt(a_ii - diag_dot);
106 
107       // a[..., i+1:, i]
108       // select the whole i-th column, then mask out all rows above i+1
109       auto a_0i = DynamicSliceInMinorDims(body_a, {i}, {1});
110       auto a_ip1i = Select(Le(mask_range_col, i), mask_zeros_col, a_0i);
111 
112       // l[..., i+1:, i] = (a[..., i+1:, i] - np.dot(l[..., i+1:, :i], r.T)) /
113       //                   l[..., i, i]
114       // The columns in [i, n] are zeroed out in `row`, so we just have to
115       // zero out rows above i+1 after the BatchDot. np.dot(l[..., :, :i],
116       // r.T)
117       auto dot = BatchDot(body_l, TransposeInMinorDims(row), precision);
118       // np.dot(l[..., i+1:, :i], r.T)
119       auto dot_ip1 = Select(Le(mask_range_col, i), mask_zeros_col, dot);
120 
121       body_l =
122           DynamicUpdateSliceInMinorDims(body_l, (a_ip1i - dot_ip1) / l_ii, {i});
123       // Assign the diagonal after the rest of the column because otherwise the
124       // column assign will wrap around and overwrite the diagonal assign.
125       body_l = DynamicUpdateSliceInMinorDims(body_l, l_ii, {i, i});
126 
127       return std::vector<XlaOp>{body_a, body_l};
128     };
129 
130     TF_ASSIGN_OR_RETURN(
131         auto cholesky_while,
132         ForEachIndex(n, S32, body_fn, {a, l}, "unblocked", builder));
133 
134     return cholesky_while[1];
135   });
136 }
137 
BuildCholesky(XlaOp a,int64 block_size,PrecisionConfig::Precision precision)138 XlaOp BuildCholesky(XlaOp a, int64 block_size,
139                     PrecisionConfig::Precision precision) {
140   XlaBuilder* builder = a.builder();
141   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
142     TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
143     const int ndims = a_shape.rank();
144     if (ndims < 2) {
145       return InvalidArgument(
146           "Argument to Cholesky must have rank >= 2; shape was %s",
147           a_shape.ToString());
148     }
149 
150     const int64 n = ShapeUtil::GetDimension(a_shape, -1);
151     if (n != ShapeUtil::GetDimension(a_shape, -2)) {
152       return InvalidArgument(
153           "Argument to Cholesky must be batched square matrices; got shape %s",
154           ShapeUtil::HumanString(a_shape));
155     }
156 
157     if (primitive_util::IsComplexType(a_shape.element_type())) {
158       return Unimplemented(
159           "Complex types are not implemented in Cholesky; got shape %s",
160           ShapeUtil::HumanString(a_shape));
161     }
162 
163     if (block_size < 1) {
164       return InvalidArgument(
165           "block_size argument to Cholesky must be >= 1; got %d", block_size);
166     }
167 
168     // Blocked left-looking Cholesky factorization.
169     // Algorithm 1 from
170     // Haidar, Azzam, et al. "High-performance Cholesky factorization for
171     // GPU-only execution." Proceedings of General Purpose GPUs. ACM, 2017.
172     XlaOp l = ZerosLike(a);
173     for (int64 i = 0; i < n; i += block_size) {
174       int64 k = std::min(block_size, n - i);
175       if (i > 0) {
176         // TODO(phawkins): consider implementing SYRK for the diagonal part of
177         // the panel.
178         // a[i:, i:i+k] -= np.dot(l[i:, :i], np.transpose(l[i:i+k, :i]))
179         auto lhs = SliceInMinorDims(l, {i, 0}, {n, i});
180         auto rhs = SliceInMinorDims(l, {i, 0}, {i + k, i});
181         auto delta = BatchDot(lhs, TransposeInMinorDims(rhs), precision);
182         auto before = SliceInMinorDims(a, {i, i}, {n, i + k});
183         a = UpdateSliceInMinorDims(a, before - delta, {i, i});
184       }
185 
186       // l[i:i+k, i:i+k] = cholesky_unblocked(a[i:i+k, i:i+k])
187       auto x = SliceInMinorDims(a, {i, i}, {i + k, i + k});
188       auto factorized = CholeskyUnblocked(x, precision);
189       l = UpdateSliceInMinorDims(l, factorized, {i, i});
190 
191       if (i + k < n) {
192         // l[i+k:, i:i+k] =
193         //     trsm_right_transpose(l[i:i+k, i:i+k], a[i+k:, i:i+k])
194         auto panel = SliceInMinorDims(a, {i + k, i}, {n, i + k});
195         auto update =
196             TriangularSolve(factorized, panel,
197                             /*left_side=*/false,
198                             /*lower=*/true,
199                             /*unit_diagonal=*/false,
200                             /*transpose_a=*/TriangularSolveOptions::TRANSPOSE);
201         l = UpdateSliceInMinorDims(l, update, {i + k, i});
202       }
203     }
204     return l;
205   });
206 }
207 
208 }  // namespace
209 
InstructionMatchesPattern(HloInstruction * instruction)210 bool CholeskyExpander::InstructionMatchesPattern(HloInstruction* instruction) {
211   return instruction->opcode() == HloOpcode::kCholesky;
212 }
213 
ExpandInstruction(HloInstruction * instruction)214 StatusOr<HloInstruction*> CholeskyExpander::ExpandInstruction(
215     HloInstruction* instruction) {
216   const CholeskyOptions& options = instruction->cholesky_options();
217   const string name = absl::StrFormat(
218       "xla.cholesky_%s_%s", instruction->operand(0)->shape().ToString(),
219       options.lower() ? "lower" : "upper");
220 
221   HloModule* module = instruction->parent()->parent();
222 
223   HloComputation*& computation =
224       computation_cache_.emplace(name, nullptr).first->second;
225   if (!computation) {
226     // Builds a new expansion.
227     //
228     // TODO(b/62327888): We do something unusual here: we build the computation
229     // using the XlaBuilder API, which is nominally an XLA client API. We do
230     // this because the external APIs for building complicated computations
231     // (XlaBuilder) are much more ergonomic than the internal ones. As it turns
232     // out, XlaBuilder isn't really a client API—what it does is build a
233     // HloModuleProto protocol buffer, that we can then deserialize and clone
234     // into our HloModule. Ideally we would avoid the protocol buffer step;
235     // that is left as an exercise for future work.
236     XlaBuilder builder(name);
237     XlaOp a = Parameter(&builder, 0, instruction->operand(0)->shape(), "a");
238     XlaOp l = BuildCholesky(MaybeTransposeInMinorDims(a, !options.lower()),
239                             /*block_size=*/128,
240                             /*precision=*/PrecisionConfig::HIGHEST);
241     MaybeTransposeInMinorDims(l, !options.lower());
242 
243     TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build());
244 
245     TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
246                         xla_computation.GetProgramShape());
247     HloModuleConfig config(program_shape);
248     TF_ASSIGN_OR_RETURN(auto new_module, HloModule::CreateFromProto(
249                                              xla_computation.proto(), config));
250     HloCloneContext context(module);
251     computation =
252         module->DeepCloneComputation(new_module->entry_computation(), &context);
253   }
254 
255   return instruction->parent()->AddInstruction(HloInstruction::CreateCall(
256       instruction->shape(), instruction->operands(), computation));
257 }
258 
259 }  // namespace xla
260