1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/cholesky_expander.h"
17
18 #include <memory>
19 #include <vector>
20
21 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
22 #include "tensorflow/compiler/xla/client/lib/constants.h"
23 #include "tensorflow/compiler/xla/client/lib/loops.h"
24 #include "tensorflow/compiler/xla/client/lib/math.h"
25 #include "tensorflow/compiler/xla/client/lib/matrix.h"
26 #include "tensorflow/compiler/xla/client/lib/slicing.h"
27 #include "tensorflow/compiler/xla/client/xla_builder.h"
28 #include "tensorflow/compiler/xla/literal.h"
29 #include "tensorflow/compiler/xla/primitive_util.h"
30 #include "tensorflow/compiler/xla/shape_util.h"
31 #include "tensorflow/compiler/xla/status_macros.h"
32 #include "tensorflow/compiler/xla/statusor.h"
33 #include "tensorflow/compiler/xla/util.h"
34 #include "tensorflow/core/lib/core/errors.h"
35
36 namespace xla {
37
38 // The Cholesky–Banachiewicz algorithm. See
39 // https://en.wikipedia.org/wiki/Cholesky_decomposition#The_Cholesky–Banachiewicz_and_Cholesky–Crout_algorithms
40 // for a description.
41 //
42 // def cholesky_unblocked(a):
43 // assert len(a.shape) == 2 and a.shape[-2] == a.shape[-1]
44 // n = a.shape[-2]
45 // l = np.zeros_like(a)
46 // for j in xrange(n):
47 // mask = np.zeros_like(a)
48 // mask[i, k] == 1 when i >= k and k == j
49 // l_square = np.dot(l, l_t)
50 // temp = a - l_square
51 // l[..., j, j] = temp(j, j)
52 // l = temp / l[..., j, j) * mask + l
53 // return l
54 // Returns a (result, error) pair.
CholeskyUnblocked(XlaOp a,PrecisionConfig::Precision precision)55 StatusOr<std::pair<XlaOp, XlaOp>> CholeskyExpander::CholeskyUnblocked(
56 XlaOp a, PrecisionConfig::Precision precision) {
57 XlaBuilder* builder = a.builder();
58 TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
59 const int ndims = a_shape.rank();
60 const int64 n = ShapeUtil::GetDimension(a_shape, -1);
61 std::vector<int64> error_dims(a_shape.dimensions().begin(),
62 a_shape.dimensions().end());
63 error_dims.back() = error_dims.at(ndims - 2) = 1;
64
65 auto major_dims = AsInt64Slice(a_shape.dimensions())
66 .subspan(
67 /*pos=*/0,
68 /*len=*/ndims - 2);
69
70 auto matrix_dims = AsInt64Slice(a_shape.dimensions())
71 .subspan(
72 /*pos=*/0,
73 /*len=*/ndims);
74
75 XlaOp l = ZerosLike(a);
76
77 // Construct the for loop body to iterate over rows.
78 auto body_fn = [&](XlaOp i, absl::Span<const XlaOp> loop_vars,
79 XlaBuilder* body_builder) -> StatusOr<std::vector<XlaOp>> {
80 std::vector<int64> row_shape_dims(major_dims.begin(), major_dims.end());
81 std::vector<int64> col_shape_dims(major_dims.begin(), major_dims.end());
82 auto body_a = loop_vars[0];
83 auto body_l = loop_vars[1];
84 auto seen_error = loop_vars[2];
85 auto iota_row =
86 Iota(body_builder, ShapeUtil::MakeShape(S32, matrix_dims), ndims - 1);
87 auto iota_col =
88 Iota(body_builder, ShapeUtil::MakeShape(S32, matrix_dims), ndims - 2);
89
90 auto mask_pred = Ge(iota_col, iota_row);
91 mask_pred = And(mask_pred, Eq(iota_row, i));
92 auto mask_zeros =
93 Zeros(body_builder,
94 ShapeUtil::MakeShape(a_shape.element_type(), matrix_dims));
95 // L * L.T, This matrix has of a lot of multiplying with zero
96 // (namely, L[:, j:] = 0) and redundant computation, but it is faster
97 // than slice.
98 auto l_square =
99 BatchDot(body_l, false, MaybeConjugate(body_l, true), true, precision);
100
101 // A - L*L.T
102 l_square = body_a - l_square;
103 auto l_ii = DynamicSliceInMinorDims(l_square, {i, i}, {1, 1});
104 if (ShapeUtil::ElementIsComplex(a_shape)) {
105 auto sqrt = Sqrt(Real(l_ii));
106 l_ii = Complex(sqrt, ZerosLike(sqrt));
107 seen_error = Or(seen_error, IsNan(sqrt));
108 } else {
109 l_ii = Sqrt(l_ii);
110 seen_error = Or(seen_error, IsNan(l_ii));
111 }
112 // L = (A - L*L.T) / l_ii * mask + L
113 body_l = Select(mask_pred, l_square / l_ii, mask_zeros) + body_l;
114
115 return std::vector<XlaOp>{body_a, body_l, seen_error};
116 };
117
118 TF_ASSIGN_OR_RETURN(
119 auto cholesky_while,
120 ForEachIndex(
121 n, S32, body_fn,
122 {a, l, Zeros(builder, ShapeUtil::MakeShape(PRED, error_dims))},
123 "unblocked", builder));
124
125 return std::make_pair(cholesky_while[1], cholesky_while[2]);
126 }
127
BuildCholesky(XlaOp a,int64 block_size,PrecisionConfig::Precision precision)128 XlaOp CholeskyExpander::BuildCholesky(XlaOp a, int64 block_size,
129 PrecisionConfig::Precision precision) {
130 XlaBuilder* builder = a.builder();
131 return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
132 TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
133 const int ndims = a_shape.rank();
134 if (ndims < 2) {
135 return InvalidArgument(
136 "Argument to Cholesky must have rank >= 2; shape was %s",
137 a_shape.ToString());
138 }
139
140 const int64 n = ShapeUtil::GetDimension(a_shape, -1);
141 if (n != ShapeUtil::GetDimension(a_shape, -2)) {
142 return InvalidArgument(
143 "Argument to Cholesky must be batched square matrices; got shape %s",
144 ShapeUtil::HumanString(a_shape));
145 }
146
147 if (block_size < 1) {
148 return InvalidArgument(
149 "block_size argument to Cholesky must be >= 1; got %d", block_size);
150 }
151
152 std::vector<int64> error_dims(a_shape.dimensions().begin(),
153 a_shape.dimensions().end());
154 error_dims.back() = error_dims.at(ndims - 2) = 1;
155 std::vector<int64> error_dim_indices(ndims);
156 absl::c_iota(error_dim_indices, 0);
157
158 // Blocked left-looking Cholesky factorization.
159 // Algorithm 1 from
160 // Haidar, Azzam, et al. "High-performance Cholesky factorization for
161 // GPU-only execution." Proceedings of General Purpose GPUs. ACM, 2017.
162 XlaOp l = ZerosLike(a);
163 XlaOp seen_error = Zeros(builder, ShapeUtil::MakeShape(PRED, error_dims));
164 for (int64 i = 0; i < n; i += block_size) {
165 int64 k = std::min(block_size, n - i);
166 auto panel = SliceInMinorDims(a, {i, i}, {n, i + k});
167 if (i > 0) {
168 // TODO(phawkins): consider implementing SYRK for the diagonal part of
169 // the panel.
170 // a[i:, i:i+k] -= np.dot(l[i:, :i], np.transpose(l[i:i+k, :i]))
171 auto lhs = SliceInMinorDims(l, {i, 0}, {n, i});
172 auto rhs = SliceInMinorDims(l, {i, 0}, {i + k, i});
173 auto delta =
174 BatchDot(lhs, false, MaybeConjugate(rhs, true), true, precision);
175 panel = panel - delta;
176 }
177
178 // l[i:i+k, i:i+k] = cholesky_unblocked(a[i:i+k, i:i+k])
179 auto x = SliceInMinorDims(panel, {0, 0}, {k, k});
180 XlaOp factorized;
181 // TODO(b/167896062): A failure in one element of a batch shouldn't fail
182 // other elements.
183 XlaOp factorized_error;
184 if (k == 1) {
185 if (ShapeUtil::ElementIsComplex(a_shape)) {
186 auto sqrt = Sqrt(Real(x));
187 factorized = Complex(sqrt, ZerosLike(sqrt));
188 factorized_error = IsNan(sqrt);
189 } else {
190 factorized = Sqrt(x);
191 factorized_error = IsNan(factorized);
192 }
193 } else {
194 TF_ASSIGN_OR_RETURN(auto tile_output, CholeskyUnblocked(x, precision));
195 std::tie(factorized, factorized_error) = tile_output;
196 }
197 seen_error = Or(seen_error, factorized_error);
198 l = UpdateSliceInMinorDims(l, factorized, {i, i});
199
200 if (i + k < n) {
201 // l[i+k:, i:i+k] =
202 // trsm_right_transpose(l[i:i+k, i:i+k], a[i+k:, i:i+k])
203 auto update = TriangularSolve(
204 factorized, SliceInMinorDims(panel, {k, 0}, {n - i, k}),
205 /*left_side=*/false,
206 /*lower=*/true,
207 /*unit_diagonal=*/false,
208 /*transpose_a=*/TriangularSolveOptions::ADJOINT);
209 l = UpdateSliceInMinorDims(l, update, {i + k, i});
210 }
211 }
212 return Select(
213 BroadcastInDim(seen_error, a_shape.dimensions(), error_dim_indices),
214 FullLike(l, std::numeric_limits<float>::quiet_NaN()), l);
215 });
216 }
217
InstructionMatchesPattern(HloInstruction * instruction)218 bool CholeskyExpander::InstructionMatchesPattern(HloInstruction* instruction) {
219 return instruction->opcode() == HloOpcode::kCholesky;
220 }
221
ExpandInstruction(HloInstruction * instruction)222 StatusOr<HloInstruction*> CholeskyExpander::ExpandInstruction(
223 HloInstruction* instruction) {
224 const CholeskyOptions& options = instruction->cholesky_options();
225 const string name = absl::StrFormat(
226 "xla.cholesky_%s_%s", instruction->operand(0)->shape().ToString(),
227 options.lower() ? "lower" : "upper");
228
229 HloModule* module = instruction->parent()->parent();
230
231 HloComputation*& computation =
232 computation_cache_.emplace(name, nullptr).first->second;
233 if (!computation) {
234 // Builds a new expansion.
235 //
236 // TODO(b/62327888): We do something unusual here: we build the computation
237 // using the XlaBuilder API, which is nominally an XLA client API. We do
238 // this because the external APIs for building complicated computations
239 // (XlaBuilder) are much more ergonomic than the internal ones. As it turns
240 // out, XlaBuilder isn't really a client API—what it does is build a
241 // HloModuleProto protocol buffer, that we can then deserialize and clone
242 // into our HloModule. Ideally we would avoid the protocol buffer step;
243 // that is left as an exercise for future work.
244 XlaBuilder builder(name);
245 XlaOp a = Parameter(&builder, 0, instruction->operand(0)->shape(), "a");
246 XlaOp l = BuildCholesky(MaybeTransposeInMinorDims(a, !options.lower()),
247 /*block_size=*/128,
248 /*precision=*/PrecisionConfig::HIGHEST);
249 MaybeTransposeInMinorDims(l, !options.lower());
250
251 TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build());
252
253 TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
254 xla_computation.GetProgramShape());
255 HloModuleConfig config(program_shape);
256 TF_ASSIGN_OR_RETURN(auto new_module, HloModule::CreateFromProto(
257 xla_computation.proto(), config));
258 HloCloneContext context(module);
259 computation =
260 module->DeepCloneComputation(new_module->entry_computation(), &context);
261 }
262
263 return instruction->parent()->AddInstruction(HloInstruction::CreateCall(
264 instruction->shape(), instruction->operands(), computation));
265 }
266
267 } // namespace xla
268