/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/cholesky_expander.h"

#include <memory>
#include <vector>

#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/loops.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"

namespace xla {

namespace {

// The Cholesky–Banachiewicz algorithm. See
// https://en.wikipedia.org/wiki/Cholesky_decomposition#The_Cholesky–Banachiewicz_and_Cholesky–Crout_algorithms
// for a description.
//
// def cholesky_unblocked(a):
//   assert len(a.shape) == 2 and a.shape[-2] == a.shape[-1]
//   n = a.shape[-2]
//   l = np.zeros_like(a)
//   for j in range(n):
//     row = l[j, :j]
//     l[j, j] = np.sqrt(a[j, j] - np.dot(row, row))
//     l[j+1:, j] = (a[j+1:, j] - np.dot(l[j+1:, :j], row)) / l[j, j]
//   return l
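//
// For example (a hand-checked case, not from the reference above): with
// a = [[4, 2], [2, 3]] the loop computes l[0, 0] = sqrt(4) = 2,
// l[1, 0] = 2 / 2 = 1, and l[1, 1] = sqrt(3 - 1*1) = sqrt(2), so
// l = [[2, 0], [1, sqrt(2)]] and np.dot(l, l.T) reproduces a.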
XlaOp CholeskyUnblocked(XlaOp a, PrecisionConfig::Precision precision) {
  XlaBuilder* builder = a.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
    const int n_dims = a_shape.rank();
    const int64 n = ShapeUtil::GetDimension(a_shape, -1);
    auto major_dims = AsInt64Slice(a_shape.dimensions())
                          .subspan(
                              /*pos=*/0,
                              /*len=*/n_dims - 2);

    XlaOp l = ZerosLike(a);

    // Construct the for loop body to iterate over rows.
    auto body_fn =
        [&](XlaOp i, absl::Span<const XlaOp> loop_vars,
            XlaBuilder* body_builder) -> StatusOr<std::vector<XlaOp>> {
      std::vector<int64> row_shape_dims(major_dims.begin(), major_dims.end());
      std::vector<int64> col_shape_dims(major_dims.begin(), major_dims.end());
      row_shape_dims.push_back(1);
      row_shape_dims.push_back(n);
      auto mask_zeros_row =
          Zeros(body_builder,
                ShapeUtil::MakeShape(a_shape.element_type(), row_shape_dims));

      col_shape_dims.push_back(n);
      col_shape_dims.push_back(1);
      auto mask_zeros_col =
          Zeros(body_builder,
                ShapeUtil::MakeShape(a_shape.element_type(), col_shape_dims));

      auto mask_range_row =
          Iota(body_builder, ShapeUtil::MakeShape(S32, row_shape_dims),
               /*iota_dimension=*/n_dims - 1);
      auto mask_range_col =
          Iota(body_builder, ShapeUtil::MakeShape(S32, col_shape_dims),
               /*iota_dimension=*/n_dims - 2);
      auto body_a = loop_vars[0];
      auto body_l = loop_vars[1];

      // row = l[..., i, :i]
      // select the whole i-th row, then mask out all columns past i-1
      auto zero = ConstantR0<int32>(body_builder, 0);
      auto l_i = DynamicSliceInMinorDims(body_l, {i, zero}, {1, n});
      auto row = Select(Ge(mask_range_row, i), mask_zeros_row, l_i);
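      // For example (illustrative values): with n = 4 and i = 2,
      // mask_range_row holds [0, 1, 2, 3] along the last dimension,
      // Ge(mask_range_row, i) is [0, 0, 1, 1], and the Select keeps
      // l[..., 2, 0:2] while zeroing columns 2 and 3.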
      // a[..., i, i]
      auto a_ii = DynamicSliceInMinorDims(body_a, {i, i}, {1, 1});
      // np.dot(row, np.swapaxes(row, -1, -2))
      auto diag_dot = BatchDot(row, TransposeInMinorDims(row), precision);
      // l[..., i, i] = np.sqrt(a[..., i, i] - np.dot(row,
      //                                              np.swapaxes(row, -1, -2)))
      auto l_ii = Sqrt(a_ii - diag_dot);

      // a[..., i+1:, i]
      // select the whole i-th column, then mask out rows 0..i
      auto a_0i = DynamicSliceInMinorDims(body_a, {i}, {1});
      auto a_ip1i = Select(Le(mask_range_col, i), mask_zeros_col, a_0i);
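      // For example (illustrative values): with n = 4 and i = 2,
      // mask_range_col holds [0, 1, 2, 3] down the column,
      // Le(mask_range_col, i) is [1, 1, 1, 0], so rows 0..2 are zeroed and
      // only the entries of a[..., i+1:, i] survive.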
111
112 // l[..., i+1:, i] = (a[..., i+1:, i] - np.dot(l[..., i+1:, :i], r.T)) /
113 // l[..., i, i]
114 // The columns in [i, n] are zeroed out in `row`, so we just have to
115 // zero out rows above i+1 after the BatchDot. np.dot(l[..., :, :i],
116 // r.T)
117 auto dot = BatchDot(body_l, TransposeInMinorDims(row), precision);
118 // np.dot(l[..., i+1:, :i], r.T)
119 auto dot_ip1 = Select(Le(mask_range_col, i), mask_zeros_col, dot);
120
121 body_l =
122 DynamicUpdateSliceInMinorDims(body_l, (a_ip1i - dot_ip1) / l_ii, {i});
123 // Assign the diagonal after the rest of the column because otherwise the
124 // column assign will wrap around and overwrite the diagonal assign.
125 body_l = DynamicUpdateSliceInMinorDims(body_l, l_ii, {i, i});
126
127 return std::vector<XlaOp>{body_a, body_l};
128 };
129
130 TF_ASSIGN_OR_RETURN(
131 auto cholesky_while,
132 ForEachIndex(n, S32, body_fn, {a, l}, "unblocked", builder));
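    // ForEachIndex emits an XLA While loop that runs body_fn for
    // i = 0, ..., n-1, threading {a, l} through as loop-carried values;
    // cholesky_while holds their final values.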

    return cholesky_while[1];
  });
}

XlaOp BuildCholesky(XlaOp a, int64 block_size,
                    PrecisionConfig::Precision precision) {
  XlaBuilder* builder = a.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
    const int ndims = a_shape.rank();
    if (ndims < 2) {
      return InvalidArgument(
          "Argument to Cholesky must have rank >= 2; shape was %s",
          a_shape.ToString());
    }

    const int64 n = ShapeUtil::GetDimension(a_shape, -1);
    if (n != ShapeUtil::GetDimension(a_shape, -2)) {
      return InvalidArgument(
          "Argument to Cholesky must be batched square matrices; got shape %s",
          ShapeUtil::HumanString(a_shape));
    }

    if (primitive_util::IsComplexType(a_shape.element_type())) {
      return Unimplemented(
          "Complex types are not implemented in Cholesky; got shape %s",
          ShapeUtil::HumanString(a_shape));
    }

    if (block_size < 1) {
      return InvalidArgument(
          "block_size argument to Cholesky must be >= 1; got %d", block_size);
    }

    // Blocked left-looking Cholesky factorization.
    // Algorithm 1 from
    // Haidar, Azzam, et al. "High-performance Cholesky factorization for
    // GPU-only execution." Proceedings of General Purpose GPUs. ACM, 2017.
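    //
    // A rough numpy sketch of the loop below (the names `cholesky_blocked`
    // and `trsm` are illustrative; trsm stands for the TriangularSolve call
    // and solves x @ m.T == b for x):
    //
    // def cholesky_blocked(a, block_size):
    //   n = a.shape[-1]
    //   l = np.zeros_like(a)
    //   for i in range(0, n, block_size):
    //     k = min(block_size, n - i)
    //     if i > 0:
    //       a[i:, i:i+k] -= np.dot(l[i:, :i], l[i:i+k, :i].T)
    //     l[i:i+k, i:i+k] = cholesky_unblocked(a[i:i+k, i:i+k])
    //     if i + k < n:
    //       l[i+k:, i:i+k] = trsm(l[i:i+k, i:i+k], a[i+k:, i:i+k])
    //   return l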
    XlaOp l = ZerosLike(a);
    for (int64 i = 0; i < n; i += block_size) {
      int64 k = std::min(block_size, n - i);
      if (i > 0) {
        // TODO(phawkins): consider implementing SYRK for the diagonal part of
        // the panel.
        // a[i:, i:i+k] -= np.dot(l[i:, :i], np.transpose(l[i:i+k, :i]))
        auto lhs = SliceInMinorDims(l, {i, 0}, {n, i});
        auto rhs = SliceInMinorDims(l, {i, 0}, {i + k, i});
        auto delta = BatchDot(lhs, TransposeInMinorDims(rhs), precision);
        auto before = SliceInMinorDims(a, {i, i}, {n, i + k});
        a = UpdateSliceInMinorDims(a, before - delta, {i, i});
      }

      // l[i:i+k, i:i+k] = cholesky_unblocked(a[i:i+k, i:i+k])
      auto x = SliceInMinorDims(a, {i, i}, {i + k, i + k});
      auto factorized = CholeskyUnblocked(x, precision);
      l = UpdateSliceInMinorDims(l, factorized, {i, i});

      if (i + k < n) {
        // l[i+k:, i:i+k] =
        //     trsm_right_transpose(l[i:i+k, i:i+k], a[i+k:, i:i+k])
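        // With left_side=false, lower=true, and TRANSPOSE, TriangularSolve
        // solves x @ np.transpose(l[i:i+k, i:i+k]) == a[i+k:, i:i+k] for x.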
        auto panel = SliceInMinorDims(a, {i + k, i}, {n, i + k});
        auto update =
            TriangularSolve(factorized, panel,
                            /*left_side=*/false,
                            /*lower=*/true,
                            /*unit_diagonal=*/false,
                            /*transpose_a=*/TriangularSolveOptions::TRANSPOSE);
        l = UpdateSliceInMinorDims(l, update, {i + k, i});
      }
    }
    return l;
  });
}

}  // namespace

bool CholeskyExpander::InstructionMatchesPattern(HloInstruction* instruction) {
  return instruction->opcode() == HloOpcode::kCholesky;
}

StatusOr<HloInstruction*> CholeskyExpander::ExpandInstruction(
    HloInstruction* instruction) {
  const CholeskyOptions& options = instruction->cholesky_options();
  const string name = absl::StrFormat(
      "xla.cholesky_%s_%s", instruction->operand(0)->shape().ToString(),
      options.lower() ? "lower" : "upper");

  HloModule* module = instruction->parent()->parent();

  HloComputation*& computation =
      computation_cache_.emplace(name, nullptr).first->second;
  if (!computation) {
    // Builds a new expansion.
    //
    // TODO(b/62327888): We do something unusual here: we build the computation
    // using the XlaBuilder API, which is nominally an XLA client API. We do
    // this because the external APIs for building complicated computations
    // (XlaBuilder) are much more ergonomic than the internal ones. As it turns
    // out, XlaBuilder isn't really a client API; what it does is build an
    // HloModuleProto protocol buffer, which we can then deserialize and clone
    // into our HloModule. Ideally we would avoid the protocol buffer step;
    // that is left as an exercise for future work.
    XlaBuilder builder(name);
    XlaOp a = Parameter(&builder, 0, instruction->operand(0)->shape(), "a");
    XlaOp l = BuildCholesky(MaybeTransposeInMinorDims(a, !options.lower()),
                            /*block_size=*/128,
                            /*precision=*/PrecisionConfig::HIGHEST);
    MaybeTransposeInMinorDims(l, !options.lower());
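    // The return value here can be discarded: XlaBuilder::Build() uses the
    // last instruction added to the builder as the computation's root, so
    // when a transpose is added for the upper-triangular case it still
    // becomes the result.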

    TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build());

    TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
                        xla_computation.GetProgramShape());
    HloModuleConfig config(program_shape);
    TF_ASSIGN_OR_RETURN(auto new_module, HloModule::CreateFromProto(
                                             xla_computation.proto(), config));
    HloCloneContext context(module);
    computation =
        module->DeepCloneComputation(new_module->entry_computation(), &context);
  }

  return instruction->parent()->AddInstruction(HloInstruction::CreateCall(
      instruction->shape(), instruction->operands(), computation));
}

}  // namespace xla