1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // Enable the use of M_* math constants.
17 // NOTE: this must be first in the file to ensure that if cmath is transitively
18 // included by any other header it has the define set on first processing.
19 // https://docs.microsoft.com/en-us/cpp/c-runtime-library/math-constants
20 #define _USE_MATH_DEFINES
21 #include <cmath>
22 #include <numeric>
23 #include <vector>
24
25 #include "llvm/ADT/SmallVector.h"
26 #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
27 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
28 #include "mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h"
29 #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
30 #include "mlir-hlo/utils/broadcast_utils.h"
31 #include "mlir/Dialect/SCF/SCF.h"
32 #include "mlir/Dialect/Shape/IR/Shape.h"
33 #include "mlir/Dialect/StandardOps/IR/Ops.h"
34 #include "mlir/Dialect/Tensor/IR/Tensor.h"
35 #include "mlir/IR/Attributes.h"
36 #include "mlir/IR/BuiltinTypes.h"
37 #include "mlir/IR/ImplicitLocOpBuilder.h"
38 #include "mlir/IR/MLIRContext.h"
39 #include "mlir/IR/OperationSupport.h"
40 #include "mlir/IR/PatternMatch.h"
41 #include "mlir/Transforms/DialectConversion.h"
42
43 namespace mlir {
44 namespace chlo {
45 namespace {
46
47 struct ConvertConstantLikeOp : public OpConversionPattern<ConstantLikeOp> {
48 using OpConversionPattern<ConstantLikeOp>::OpConversionPattern;
matchAndRewritemlir::chlo::__anonca6190260111::ConvertConstantLikeOp49 LogicalResult matchAndRewrite(
50 ConstantLikeOp op, ArrayRef<Value> operands,
51 ConversionPatternRewriter &rewriter) const override {
52 auto result_ty = op.getType().cast<ShapedType>();
53
54 // Unranked uses are not supported. Consider `transform-unranked-hlo`.
55 if (!result_ty.hasRank()) return failure();
56
57 // Lower to MHLO constant if statically shaped.
58 if (result_ty.hasStaticShape()) {
59 rewriter.replaceOpWithNewOp<mhlo::ConstOp>(
60 op, DenseElementsAttr::get(result_ty, op.value()));
61 return success();
62 }
63
64 // Lower to broadcasted constant.
65 ConstantLikeOp::Adaptor transformed(operands);
66 auto loc = op.getLoc();
67 Type extent_tensor_type = shape::getExtentTensorType(op.getContext());
68 Value constant = rewriter.create<mhlo::ConstOp>(loc, op.value());
69 Value uncasted_shape = rewriter.create<shape::ShapeOfOp>(
70 loc, extent_tensor_type, transformed.operand());
71 Type shape_ty =
72 RankedTensorType::get({result_ty.getRank()}, rewriter.getIndexType());
73 Value shape =
74 rewriter.create<tensor::CastOp>(loc, shape_ty, uncasted_shape);
75 rewriter.replaceOpWithNewOp<mhlo::DynamicBroadcastInDimOp>(
76 op, result_ty, constant, shape, rewriter.getI64TensorAttr({}));
77 return success();
78 }
79 };
80
81 template <typename FTy>
MaterializePolynomialApproximation(ConversionPatternRewriter & rewriter,Location loc,Value x,const std::vector<FTy> & coefficients)82 Value MaterializePolynomialApproximation(ConversionPatternRewriter &rewriter,
83 Location loc, Value x,
84 const std::vector<FTy> &coefficients) {
85 Value poly = chlo::getConstantLike(rewriter, loc, 0.0, x);
86 for (FTy c : coefficients) {
87 poly = rewriter.create<mhlo::MulOp>(loc, x.getType(), poly, x);
88 poly = rewriter.create<mhlo::AddOp>(
89 loc, x.getType(), poly, chlo::getConstantLike(rewriter, loc, c, x));
90 }
91 return poly;
92 }
93
// Precondition is |x| >= 1. Use erf approximation, otherwise.
//
// We rely on multiple polynomial approximations for x >= 1. We pass |x| as an
// argument and derive the final approximation for all |x| >= 1.
// This implementation is based on Cephes.
Value MaterializeErfcApproximationF64ForMagnituteGEOne(
    ConversionPatternRewriter &rewriter, Location loc, ValueRange args) {
  Value x = args.front();
  assert(x.getType().cast<ShapedType>().getElementType().isF64() &&
         "expect f64 element type");
  // Above this threshold exp(-x^2) underflows; used to clamp the result to 0.
  const double kMaxlog = 7.09782712893383996843E2;
  // Cephes rational-approximation coefficient tables, ordered from highest to
  // lowest degree. P/Q cover x in [1, 8); R/S cover x >= 8.
  const std::vector<double> kErfcPCoefficients{
      2.46196981473530512524E-10, 5.64189564831068821977E-1,
      7.46321056442269912687E0, 4.86371970985681366614E1,
      1.96520832956077098242E2, 5.26445194995477358631E2,
      9.34528527171957607540E2, 1.02755188689515710272E3,
      5.57535335369399327526E2};
  const std::vector<double> kErfcQCoefficients{
      1.00000000000000000000E0, 1.32281951154744992508E1,
      8.67072140885989742329E1, 3.54937778887819891062E2,
      9.75708501743205489753E2, 1.82390916687909736289E3,
      2.24633760818710981792E3, 1.65666309194161350182E3,
      5.57535340817727675546E2};
  const std::vector<double> kErfcRCoefficients{
      5.64189583547755073984E-1, 1.27536670759978104416E0,
      5.01905042251180477414E0, 6.16021097993053585195E0,
      7.40974269950448939160E0, 2.97886665372100240670E0};
  const std::vector<double> kErfcSCoefficients{
      1.00000000000000000000E0, 2.26052863220117276590E0,
      9.39603524938001434673E0, 1.20489539808096656605E1,
      1.70814450747565897222E1, 9.60896809063285878198E0,
      3.36907645100081516050E0};

  // Let z = -x^2.
  Value x_sq = rewriter.create<mhlo::MulOp>(loc, x, x);
  Value z = rewriter.create<mhlo::NegOp>(loc, x_sq);

  // Materialize polynomial approximation for x in [1, 8) as
  //   erfc(x) = exp(z) P(|x|) / Q(|x|).
  Value exp_z = rewriter.create<mhlo::ExpOp>(loc, z);
  Value abs_x = rewriter.create<mhlo::AbsOp>(loc, x);
  Value poly_p = MaterializePolynomialApproximation(rewriter, loc, abs_x,
                                                    kErfcPCoefficients);
  Value exp_z_mul_poly_p = rewriter.create<mhlo::MulOp>(loc, exp_z, poly_p);
  Value poly_q = MaterializePolynomialApproximation(rewriter, loc, abs_x,
                                                    kErfcQCoefficients);
  Value erfc_approx_1_8 =
      rewriter.create<mhlo::DivOp>(loc, exp_z_mul_poly_p, poly_q);

  // Materialize polynomial approximation for x >= 8 as
  //   erfc(x) = exp(z) R(|x|) / S(|x|).
  Value poly_r = MaterializePolynomialApproximation(rewriter, loc, abs_x,
                                                    kErfcRCoefficients);
  Value exp_z_mul_poly_r = rewriter.create<mhlo::MulOp>(loc, exp_z, poly_r);
  Value poly_s = MaterializePolynomialApproximation(rewriter, loc, abs_x,
                                                    kErfcSCoefficients);
  Value erfc_approx_8_inf =
      rewriter.create<mhlo::DivOp>(loc, exp_z_mul_poly_r, poly_s);

  // Combine polynomial approximations for x >= 1, choosing on |x| < 8.
  const StringAttr kLT = rewriter.getStringAttr(
      mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT));
  Value eight = chlo::getConstantLike(rewriter, loc, 8.0, x);
  Value abs_x_lt_8 = rewriter.create<mhlo::CompareOp>(loc, abs_x, eight, kLT);
  Value erfc_approx = rewriter.create<mhlo::SelectOp>(
      loc, abs_x_lt_8, erfc_approx_1_8, erfc_approx_8_inf);

  // Clamp to prevent overflow and materialize approximation for large x as
  //   erfc(x) = 0  (where z = -x^2 < -kMaxlog, i.e. exp(z) underflows).
  Value z_lt_neg_maxlog = rewriter.create<mhlo::CompareOp>(
      loc, z, chlo::getConstantLike(rewriter, loc, -kMaxlog, x), kLT);
  Value zero = chlo::getConstantLike(rewriter, loc, 0.0, x);
  Value erfc_approx_clamped =
      rewriter.create<mhlo::SelectOp>(loc, z_lt_neg_maxlog, zero, erfc_approx);

  // Derive approximation for x <= -1 as
  //   erfc(x) = 2 - erfc(-x).
  // Reuse previously materialized approximations all of which take |x| as their
  // argument.
  Value x_lt_zero = rewriter.create<mhlo::CompareOp>(loc, x, zero, kLT);
  Value two = chlo::getConstantLike(rewriter, loc, 2.0, x);
  Value two_sub_erfc_approx_clamped =
      rewriter.create<mhlo::SubOp>(loc, two, erfc_approx_clamped);
  return rewriter.create<mhlo::SelectOp>(
      loc, x_lt_zero, two_sub_erfc_approx_clamped, erfc_approx_clamped);
}
180
181 // Precondition is |x| <= 1. Use erfc approximation, otherwise.
182 // This implementation is based on Cephes.
MaterializeErfApproximationF64ForMagnituteLEOne(ConversionPatternRewriter & rewriter,Location loc,ValueRange args)183 Value MaterializeErfApproximationF64ForMagnituteLEOne(
184 ConversionPatternRewriter &rewriter, Location loc, ValueRange args) {
185 Value x = args.front();
186 assert(x.getType().cast<ShapedType>().getElementType().isF64() &&
187 "expect f64 element type");
188 const std::vector<double> kErfTCoefficients{
189 9.60497373987051638749E0, 9.00260197203842689217E1,
190 2.23200534594684319226E3, 7.00332514112805075473E3,
191 5.55923013010394962768E4};
192 const std::vector<double> kErfUCoefficients{
193 1.00000000000000000000E0, 3.35617141647503099647E1,
194 5.21357949780152679795E2, 4.59432382970980127987E3,
195 2.26290000613890934246E4, 4.92673942608635921086E4};
196
197 // Materialize polynomial approximation for |x| <= 1 as
198 // erf(x) = x T(x^2) / U(x^2).
199 Value x_sq = rewriter.create<mhlo::MulOp>(loc, x, x);
200 Value poly_t = MaterializePolynomialApproximation(rewriter, loc, x_sq,
201 kErfTCoefficients);
202 Value x_mul_poly_t = rewriter.create<mhlo::MulOp>(loc, x, poly_t);
203 Value poly_u = MaterializePolynomialApproximation(rewriter, loc, x_sq,
204 kErfUCoefficients);
205 return rewriter.create<mhlo::DivOp>(loc, x_mul_poly_t, poly_u);
206 }
207
208 // This implementation is based on Cephes.
MaterializeErfApproximationF64(ConversionPatternRewriter & rewriter,Location loc,ValueRange args)209 Value MaterializeErfApproximationF64(ConversionPatternRewriter &rewriter,
210 Location loc, ValueRange args) {
211 Value x = args.front();
212 assert(x.getType().cast<ShapedType>().getElementType().isF64() &&
213 "expect f64 element type");
214
215 // Rely on erf approximation for |x| < 1
216 // erf(x) = erf_approx(x)
217 Value erf_approx =
218 MaterializeErfApproximationF64ForMagnituteLEOne(rewriter, loc, x);
219
220 // Rely on erfc approximation for |x| >= 1 and materialize erf as
221 // erf(x) = 1 - erfc_approx(x)
222 Value one = chlo::getConstantLike(rewriter, loc, 1.0, x);
223 Value erfc_approx =
224 MaterializeErfcApproximationF64ForMagnituteGEOne(rewriter, loc, x);
225 Value erfc_based_approx = rewriter.create<mhlo::SubOp>(loc, one, erfc_approx);
226
227 // Materialize approximation selection based on argument.
228 Value abs_x = rewriter.create<mhlo::AbsOp>(loc, x);
229 const StringAttr kLT = rewriter.getStringAttr(
230 mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT));
231 Value abs_x_lt_one = rewriter.create<mhlo::CompareOp>(loc, abs_x, one, kLT);
232 return rewriter.create<mhlo::SelectOp>(loc, abs_x_lt_one, erf_approx,
233 erfc_based_approx);
234 }
235
MaterializeErfcApproximationF64(ConversionPatternRewriter & rewriter,Location loc,ValueRange args)236 Value MaterializeErfcApproximationF64(ConversionPatternRewriter &rewriter,
237 Location loc, ValueRange args) {
238 Value x = args.front();
239 assert(x.getType().cast<ShapedType>().getElementType().isF64() &&
240 "expect f64 element type");
241
242 // Rely on erfc approximation for |x| >= 1
243 // erfc(x) = erfc_approx(x)
244 Value erfc_approx =
245 MaterializeErfcApproximationF64ForMagnituteGEOne(rewriter, loc, x);
246
247 // Rely on erf approximation for |x| < 1 and materialize erfc as
248 // erfc(x) = 1 - erf_approx(x)
249 Value one = chlo::getConstantLike(rewriter, loc, 1.0, x);
250 Value erf_approx =
251 MaterializeErfApproximationF64ForMagnituteLEOne(rewriter, loc, x);
252 Value erf_based_approx = rewriter.create<mhlo::SubOp>(loc, one, erf_approx);
253
254 // Materialize approximation selection based on argument.
255 Value abs_x = rewriter.create<mhlo::AbsOp>(loc, x);
256 const StringAttr kLT = rewriter.getStringAttr(
257 mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT));
258 Value abs_x_lt_one = rewriter.create<mhlo::CompareOp>(loc, abs_x, one, kLT);
259 return rewriter.create<mhlo::SelectOp>(loc, abs_x_lt_one, erf_based_approx,
260 erfc_approx);
261 }
262
// Precondition is |x| >= 1. Use erf approximation, otherwise.
//
// We rely on multiple polynomial approximations for x >= 1. We pass |x| as an
// argument and derive the final approximation for all |x| >= 1.
// This implementation is based on Cephes.
Value MaterializeErfcApproximationF32ForMagnitudeGEOne(
    ConversionPatternRewriter &rewriter, Location loc, ValueRange args) {
  Value x = args.front();
  assert(x.getType().cast<ShapedType>().getElementType().isF32() &&
         "expect f32 element type");
  // Above this threshold exp(-x^2) underflows in f32; used to clamp to 0.
  const double kMaxlog = 88.72283905206835;
  // Cephes coefficient tables, ordered from highest to lowest degree.
  // P covers x in [1, 2); R covers x >= 2.
  const std::vector<float> kErfcPCoefficients{
      +2.326819970068386E-2, -1.387039388740657E-1, +3.687424674597105E-1,
      -5.824733027278666E-1, +6.210004621745983E-1, -4.944515323274145E-1,
      +3.404879937665872E-1, -2.741127028184656E-1, +5.638259427386472E-1,
  };
  const std::vector<float> kErfcRCoefficients{
      -1.047766399936249E+1, +1.297719955372516E+1, -7.495518717768503E+0,
      +2.921019019210786E+0, -1.015265279202700E+0, +4.218463358204948E-1,
      -2.820767439740514E-1, +5.641895067754075E-1,
  };

  // Let z = -x^2.
  Value x_sq = rewriter.create<mhlo::MulOp>(loc, x, x);
  Value z = rewriter.create<mhlo::NegOp>(loc, x_sq);

  // Materialize polynomial approximation for x >= 1 as
  //   erfc(x) = exp(z) 1/x P(1/x^2)   if x in [1, 2)
  //   erfc(x) = exp(z) 1/x R(1/x^2)   if x >= 2
  const StringAttr kLT = rewriter.getStringAttr(
      mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT));
  Value abs_x = rewriter.create<mhlo::AbsOp>(loc, x);
  Value one = chlo::getConstantLike(rewriter, loc, 1.0, x);
  Value reciprocal_x_sq = rewriter.create<mhlo::DivOp>(loc, one, x_sq);
  Value exp_z = rewriter.create<mhlo::ExpOp>(loc, z);
  Value one_div_abs_x = rewriter.create<mhlo::DivOp>(loc, one, abs_x);
  Value exp_z_mul_one_div_abs_x =
      rewriter.create<mhlo::MulOp>(loc, exp_z, one_div_abs_x);
  Value two = chlo::getConstantLike(rewriter, loc, 2.0, x);
  Value abs_x_lt_two = rewriter.create<mhlo::CompareOp>(loc, abs_x, two, kLT);
  Value poly_p = MaterializePolynomialApproximation(
      rewriter, loc, reciprocal_x_sq, kErfcPCoefficients);
  Value poly_r = MaterializePolynomialApproximation(
      rewriter, loc, reciprocal_x_sq, kErfcRCoefficients);
  Value poly =
      rewriter.create<mhlo::SelectOp>(loc, abs_x_lt_two, poly_p, poly_r);
  Value erfc_approx =
      rewriter.create<mhlo::MulOp>(loc, exp_z_mul_one_div_abs_x, poly);

  // Clamp to prevent overflow and materialize approximation for large x as
  //   erfc(x) = 0  (where z = -x^2 < -kMaxlog, i.e. exp(z) underflows).
  // NOTE(review): "neq" in the variable name below looks like a typo for
  // "neg" (the comparison is z < -kMaxlog); behavior is unaffected.
  Value z_lt_neq_maxlog = rewriter.create<mhlo::CompareOp>(
      loc, z, chlo::getConstantLike(rewriter, loc, -kMaxlog, x), kLT);
  Value zero = chlo::getConstantLike(rewriter, loc, 0.0, x);
  Value erfc_approx_clamped =
      rewriter.create<mhlo::SelectOp>(loc, z_lt_neq_maxlog, zero, erfc_approx);

  // Derive approximation for x <= -1 as
  //   erfc(x) = 2 - erfc(-x).
  // Reuse previously materialized approximations all of which take |x| as their
  // argument.
  Value x_lt_zero = rewriter.create<mhlo::CompareOp>(loc, x, zero, kLT);
  Value two_sub_erfc_approx =
      rewriter.create<mhlo::SubOp>(loc, two, erfc_approx_clamped);
  return rewriter.create<mhlo::SelectOp>(loc, x_lt_zero, two_sub_erfc_approx,
                                         erfc_approx_clamped);
}
330
331 // Precondition is |x| <= 1. Use erfc approximation, otherwise.
332 // This implementation is based on Cephes.
MaterializeErfApproximationF32ForMagnitudeLEOne(ConversionPatternRewriter & rewriter,Location loc,ValueRange args)333 Value MaterializeErfApproximationF32ForMagnitudeLEOne(
334 ConversionPatternRewriter &rewriter, Location loc, ValueRange args) {
335 Value x = args.front();
336 assert(x.getType().cast<ShapedType>().getElementType().isF32() &&
337 "expect f32 element type");
338 const std::vector<float> kErfTCoefficients{
339 +7.853861353153693E-5, -8.010193625184903E-4, +5.188327685732524E-3,
340 -2.685381193529856E-2, +1.128358514861418E-1, -3.761262582423300E-1,
341 +1.128379165726710E+0,
342 };
343
344 // Materialize polynomial approximation for |x| <= 1 as
345 // erf(x) = x T(x^2).
346 Value x_sq = rewriter.create<mhlo::MulOp>(loc, x, x);
347 Value poly_t = MaterializePolynomialApproximation(rewriter, loc, x_sq,
348 kErfTCoefficients);
349 return rewriter.create<mhlo::MulOp>(loc, x, poly_t);
350 }
351
352 // This is the same approximation as used in Eigen.
MaterializeErfApproximationF32(ConversionPatternRewriter & rewriter,Location loc,ValueRange args)353 Value MaterializeErfApproximationF32(ConversionPatternRewriter &rewriter,
354 Location loc, ValueRange args) {
355 Value x = args.front();
356 assert(x.getType().cast<ShapedType>().getElementType().isF32() &&
357 "expect f32 element type");
358 const std::vector<float> kAlpha{
359 -2.72614225801306e-10f, 2.77068142495902e-08f, -2.10102402082508e-06f,
360 -5.69250639462346e-05f, -7.34990630326855e-04f, -2.95459980854025e-03f,
361 -1.60960333262415e-02f,
362 };
363 const std::vector<float> kBeta{
364 -1.45660718464996e-05f, -2.13374055278905e-04f, -1.68282697438203e-03f,
365 -7.37332916720468e-03f, -1.42647390514189e-02f,
366 };
367
368 // Clamp argument between -4 and 4.
369 Value lb = chlo::getConstantLike(rewriter, loc, -4.0, x);
370 Value ub = chlo::getConstantLike(rewriter, loc, 4.0, x);
371 x = rewriter.create<mhlo::ClampOp>(loc, x.getType(), lb, x, ub);
372 Value x_sq = rewriter.create<mhlo::MulOp>(loc, x, x);
373
374 // Materialize polynomial approximation for x in [-4, 4] as
375 // erf(x) = x * Alpha(x^2) / Beta(x^2).
376 Value alpha_poly =
377 MaterializePolynomialApproximation(rewriter, loc, x_sq, kAlpha);
378 Value beta_poly =
379 MaterializePolynomialApproximation(rewriter, loc, x_sq, kBeta);
380 Value x_mul_alpha_poly = rewriter.create<mhlo::MulOp>(loc, x, alpha_poly);
381 return rewriter.create<mhlo::DivOp>(loc, x_mul_alpha_poly, beta_poly);
382 }
383
MaterializeErfcApproximationF32(ConversionPatternRewriter & rewriter,Location loc,ValueRange args)384 Value MaterializeErfcApproximationF32(ConversionPatternRewriter &rewriter,
385 Location loc, ValueRange args) {
386 Value x = args.front();
387 assert(x.getType().cast<ShapedType>().getElementType().isF32() &&
388 "expect f32 element type");
389
390 // Rely on erfc approximation for |x| >= 1
391 // erfc(x) = erfc_approx(x)
392 Value erfc_approx =
393 MaterializeErfcApproximationF32ForMagnitudeGEOne(rewriter, loc, x);
394
395 // Rely on erf approximation for |x| < 1 and materialize erfc as
396 // erfc(x) = 1 - erf_approx(x)
397 Value one = chlo::getConstantLike(rewriter, loc, 1.0, x);
398 Value erf_approx =
399 MaterializeErfApproximationF32ForMagnitudeLEOne(rewriter, loc, x);
400 Value erf_based_approx = rewriter.create<mhlo::SubOp>(loc, one, erf_approx);
401
402 // Materialize approximation selection based on argument.
403 const StringAttr kLT = rewriter.getStringAttr(
404 mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT));
405 Value abs_x = rewriter.create<mhlo::AbsOp>(loc, x);
406 Value abs_x_lt_one = rewriter.create<mhlo::CompareOp>(loc, abs_x, one, kLT);
407 return rewriter.create<mhlo::SelectOp>(loc, abs_x_lt_one, erf_based_approx,
408 erfc_approx);
409 }
410
MaterializeWithUpcast(ConversionPatternRewriter & rewriter,Location loc,ValueRange args,FloatType min_precision_ty,Value callback (ConversionPatternRewriter &,Location,ValueRange))411 Value MaterializeWithUpcast(ConversionPatternRewriter &rewriter, Location loc,
412 ValueRange args, FloatType min_precision_ty,
413 Value callback(ConversionPatternRewriter &,
414 Location, ValueRange)) {
415 auto original_ty =
416 getElementTypeOrSelf(args.front().getType()).cast<FloatType>();
417 bool needs_upcast = original_ty.getWidth() < min_precision_ty.getWidth();
418
419 // Upcast arguments if necessary.
420 llvm::SmallVector<Value, 2> casted_args;
421 if (needs_upcast) {
422 for (Value a : args) {
423 casted_args.push_back(
424 rewriter.create<mhlo::ConvertOp>(loc, a, min_precision_ty));
425 }
426 args = casted_args;
427 }
428
429 Value result = callback(rewriter, loc, args);
430
431 // Cast back if necessary.
432 if (needs_upcast) {
433 result = rewriter.create<mhlo::ConvertOp>(loc, result, original_ty);
434 }
435
436 return result;
437 }
438
439 struct ConvertErfOp : public OpConversionPattern<ErfOp> {
440 using OpConversionPattern<ErfOp>::OpConversionPattern;
matchAndRewritemlir::chlo::__anonca6190260111::ConvertErfOp441 LogicalResult matchAndRewrite(
442 ErfOp op, ArrayRef<Value> operands,
443 ConversionPatternRewriter &rewriter) const override {
444 Location loc = op.getLoc();
445 ErfOp::Adaptor transformed(operands);
446 Value x = transformed.operand();
447 Type ty = x.getType().cast<ShapedType>().getElementType();
448
449 // For now, we support only f64, f32, and f16.
450 if (!ty.isF64() && !ty.isF32() && !ty.isF16()) return failure();
451
452 if (ty.isF64()) {
453 rewriter.replaceOp(op, MaterializeErfApproximationF64(rewriter, loc, x));
454 return success();
455 }
456
457 rewriter.replaceOp(op, MaterializeWithUpcast(
458 rewriter, loc, operands, rewriter.getF32Type(),
459 &MaterializeErfApproximationF32));
460 return success();
461 }
462 };
463
464 struct ConvertErfcOp : public OpConversionPattern<ErfcOp> {
465 using OpConversionPattern<ErfcOp>::OpConversionPattern;
matchAndRewritemlir::chlo::__anonca6190260111::ConvertErfcOp466 LogicalResult matchAndRewrite(
467 ErfcOp op, ArrayRef<Value> operands,
468 ConversionPatternRewriter &rewriter) const override {
469 Location loc = op.getLoc();
470 ErfcOp::Adaptor transformed(operands);
471 Value x = transformed.operand();
472 Type ty = x.getType().cast<ShapedType>().getElementType();
473
474 // For now, we support only f64, f32, and f16.
475 if (!ty.isF64() && !ty.isF32() && !ty.isF16()) return failure();
476
477 if (ty.isF64()) {
478 rewriter.replaceOp(op, MaterializeErfcApproximationF64(rewriter, loc, x));
479 return success();
480 }
481
482 rewriter.replaceOp(op, MaterializeWithUpcast(
483 rewriter, loc, operands, rewriter.getF32Type(),
484 &MaterializeErfcApproximationF32));
485 return success();
486 }
487 };
488
// Coefficients for the Lanczos approximation of the gamma function. The
// coefficients are uniquely determined by the choice of g and n (kLanczosGamma
// and kLanczosCoefficients.size() + 1). The coefficients below correspond to
// [7, 9]. [5, 7], [7, 9], [9, 10], and [607/128.0, 15] were evaluated and
// [7, 9] seemed to be the least sensitive to the quality of the log function.
// In particular, [5, 7] is the only choice where -1.5e-5 <= lgamma(2) <= 1.5e-5
// for a particularly inaccurate log function.
// NOTE(review): std::array is used below but <array> is not included directly
// at the top of this file; it appears to be pulled in transitively — consider
// adding the include.
constexpr double kLanczosGamma = 7;  // aka g
constexpr double kBaseLanczosCoeff = 0.99999999999980993227684700473478;
constexpr std::array<double, 8> kLanczosCoefficients = {
    676.520368121885098567009190444019, -1259.13921672240287047156078755283,
    771.3234287776530788486528258894, -176.61502916214059906584551354,
    12.507343278686904814458936853, -0.13857109526572011689554707,
    9.984369578019570859563e-6, 1.50563273514931155834e-7};
503
// Compute the Lgamma function using Lanczos' approximation from "A Precision
// Approximation of the Gamma Function". SIAM Journal on Numerical Analysis
// series B. Vol. 1:
//   lgamma(z + 1) = (log(2) + log(pi)) / 2
//                   + (z + 1/2) * log(t(z))
//                   - t(z) + log(a(z))
//   with t(z) = z + kLanczosGamma + 1/2
//        a(z) = kBaseLanczosCoeff
//               + sum(k = 1, n, kLanczosCoefficients[i] / (z + k))
Value MaterializeLgamma(ConversionPatternRewriter &rewriter, Location loc,
                        ValueRange args) {
  // If the input is less than 0.5 use Euler's reflection formula.
  //   gamma(x) = pi / (sin(pi * x) * gamma(1 - x))
  // Let z be
  //   z = -x      if x < 1/2
  //   z = x - 1   otherwise
  Value x = args.front();
  const StringAttr kLT = rewriter.getStringAttr(
      mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT));
  Value half = getConstantLike(rewriter, loc, 0.5, x);
  Value need_to_reflect = rewriter.create<mhlo::CompareOp>(loc, x, half, kLT);
  Value neg_x = rewriter.create<mhlo::NegOp>(loc, x);
  Value one = getConstantLike(rewriter, loc, 1, x);
  Value x_sub_one = rewriter.create<mhlo::SubOp>(loc, x, one);
  Value z =
      rewriter.create<mhlo::SelectOp>(loc, need_to_reflect, neg_x, x_sub_one);

  // Materialize
  //   a(z) = kBaseLanczosCoeff
  //          + sum(k = 1, n, kLanczosCoefficients[i] / (z + k))
  Value a = getConstantLike(rewriter, loc, kBaseLanczosCoeff, x);
  for (int i = 0, end = kLanczosCoefficients.size(); i < end; ++i) {
    Value coeff = getConstantLike(rewriter, loc, kLanczosCoefficients[i], x);
    Value one_based_index = getConstantLike(rewriter, loc, i + 1, x);
    Value quotient = rewriter.create<mhlo::DivOp>(
        loc, coeff, rewriter.create<mhlo::AddOp>(loc, z, one_based_index));
    a = rewriter.create<mhlo::AddOp>(loc, a, quotient);
  }

  // To improve accuracy on platforms with less-precise log implementations,
  // compute log(kLanczosGamma + 1/2) at compile time and use log1p on the
  // device.
  // Materialize as
  //   log(t) = log(kLanczosGamma + 1/2 + z)
  //          = log(kLanczosGamma + 1/2) + log1p(z / (kLanczosGamma + 1/2)).
  Value lanczos_plus_half =
      getConstantLike(rewriter, loc, kLanczosGamma + 0.5, x);
  Value t = rewriter.create<mhlo::AddOp>(loc, lanczos_plus_half, z);
  Value log_term =
      getConstantLike(rewriter, loc, std::log(kLanczosGamma + 0.5), x);
  Value log1p_term = rewriter.create<mhlo::Log1pOp>(
      loc, rewriter.create<mhlo::DivOp>(loc, z, lanczos_plus_half));
  Value log_t = rewriter.create<mhlo::AddOp>(loc, log_term, log1p_term);

  // Note that t(z) may be large and we need to be careful not to overflow to
  // infinity in the relevant term
  //   r = (z + 1/2) * log(t(z)) - t(z).
  // Therefore, we compute this as
  //   r = (z + 1/2 - t(z) / log(t(z))) * log(t(z)).
  Value t_div_log_t = rewriter.create<mhlo::DivOp>(loc, t, log_t);
  Value sum = rewriter.create<mhlo::SubOp>(
      loc, rewriter.create<mhlo::AddOp>(loc, z, half), t_div_log_t);
  Value r = rewriter.create<mhlo::MulOp>(loc, sum, log_t);

  // Compute the final result (modulo reflection) as
  //   lgamma(z + 1) = (log(2) + log(pi)) / 2 + r + log(a(z)).
  Value log_a = rewriter.create<mhlo::LogOp>(loc, a);
  Value lgamma = rewriter.create<mhlo::AddOp>(
      loc,
      rewriter.create<mhlo::AddOp>(
          loc,
          getConstantLike(rewriter, loc, (std::log(2) + std::log(M_PI)) / 2, x),
          r),
      log_a);

  // Compute the reflected value for x < 0.5 as
  //   lgamma(x) = log(pi) - lgamma(1-x) - log(abs(sin(pi * x))).
  //
  // The abs is needed because lgamma is the log of the absolute value of the
  // gamma function.
  //
  // We have to be careful when computing the final term above. gamma(x) goes
  // to +/-inf at every integer x < 0, and this is controlled by the sin(pi * x)
  // term. The slope is large, so precision is particularly important.
  //
  // Because abs(sin(pi * x)) has period of 1 we can equivalently use
  // abs(sin(pi * frac(x))) where frac(x) is the fractional part of x. This is
  // more numerically accurate: It doesn't overflow to inf like pi * x would and
  // if x is an integer it evaluates to exactly 0 which is important because we
  // then take the log of this value, and log(0) is inf.
  //
  // We don't have a frac(x) primitive in HLO and computing it is tricky, but
  // because abs(sin(pi * x)) = abs(sin(pi * abs(x))), it's good enough for our
  // purposes to use abs(frac(x)) = abs(x) - floor(abs(x)).
  //
  // Furthermore, pi * abs(frac(x)) loses precision when abs(frac(x)) is close
  // to 1. To remedy this, we can use the fact that sin(pi * x) in the domain
  // [0, 1] is symmetric across the line Y=0.5.

  // Convert values of abs_frac > 0.5 to (1 - abs_frac) to improve precision of
  // pi * abs_frac for values of abs_frac close to 1.
  Value abs = rewriter.create<mhlo::AbsOp>(loc, x);
  Value abs_frac = rewriter.create<mhlo::SubOp>(
      loc, abs, rewriter.create<mhlo::FloorOp>(loc, abs));
  Value reduce_abs_frac =
      rewriter.create<mhlo::CompareOp>(loc, half, abs_frac, kLT);
  abs_frac = rewriter.create<mhlo::SelectOp>(
      loc, reduce_abs_frac, rewriter.create<mhlo::SubOp>(loc, one, abs_frac),
      abs_frac);

  // Materialize reflection:
  //   lgamma_reflection = log(pi) - log(sin(pi * abs_frac)) - lgamma.
  Value reflection_denom = rewriter.create<mhlo::LogOp>(
      loc,
      rewriter.create<mhlo::SinOp>(
          loc, rewriter.create<mhlo::MulOp>(
                   loc, getConstantLike(rewriter, loc, M_PI, x), abs_frac)));
  Value lgamma_reflection = rewriter.create<mhlo::SubOp>(
      loc,
      rewriter.create<mhlo::SubOp>(
          loc, getConstantLike(rewriter, loc, std::log(M_PI), x),
          reflection_denom),
      lgamma);

  // Avoid computing -inf - inf, which is nan. If reflection_denom is +/-inf,
  // then it "wins" and the result is +/-inf.
  Value finite_reflection_denom =
      rewriter.create<mhlo::IsFiniteOp>(loc, reflection_denom);
  Value neg_reflection_denom =
      rewriter.create<mhlo::NegOp>(loc, reflection_denom);
  lgamma_reflection = rewriter.create<mhlo::SelectOp>(
      loc, finite_reflection_denom, lgamma_reflection, neg_reflection_denom);

  // Select whether or not to rely on the reflection.
  lgamma = rewriter.create<mhlo::SelectOp>(loc, need_to_reflect,
                                           lgamma_reflection, lgamma);

  // Materialize +/-inf behavior as
  //   lgamma(+/-inf) = +inf.
  Value x_is_inf = rewriter.create<chlo::IsInfOp>(loc, x);
  return rewriter.create<mhlo::SelectOp>(
      loc, x_is_inf,
      chlo::getConstantLikeInfValue(rewriter, loc, x, /*negative=*/false),
      lgamma);
}
649
// Compute the Digamma function using Lanczos' approximation from "A Precision
// Approximation of the Gamma Function". SIAM Journal on Numerical Analysis
// series B. Vol. 1:
// digamma(z + 1) = log(t(z)) + a'(z) / a(z) - kLanczosGamma / t(z)
// with t(z) = z + kLanczosGamma + 1/2
//      a(z) = kBaseLanczosCoeff
//               + sum(k = 1, n, kLanczosCoefficients[k - 1] / (z + k))
//      a'(z) = - sum(k = 1, n, kLanczosCoefficients[k - 1] / (z + k) / (z + k))
//
// Emits the MHLO op sequence computing digamma(args[0]) elementwise at `loc`
// and returns the final result value. Callers (ConvertDigammaOp below) wrap
// this in MaterializeWithUpcast so the computation runs in at least f32.
Value MaterializeDigamma(ConversionPatternRewriter &rewriter, Location loc,
                         ValueRange args) {
  // If the input is less than 0.5 use Euler's reflection formula.
  // digamma(x) = digamma(1 - x) - pi * cot(pi * x)
  // Let z be
  // z = -x if x < 1/2
  // z = x - 1 otherwise
  Value x = args.front();
  const StringAttr kLT = rewriter.getStringAttr(
      mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT));
  Value half = getConstantLike(rewriter, loc, 0.5, x);
  Value need_to_reflect = rewriter.create<mhlo::CompareOp>(loc, x, half, kLT);
  Value neg_x = rewriter.create<mhlo::NegOp>(loc, x);
  Value one = getConstantLike(rewriter, loc, 1, x);
  Value x_sub_one = rewriter.create<mhlo::SubOp>(loc, x, one);
  Value z =
      rewriter.create<mhlo::SelectOp>(loc, need_to_reflect, neg_x, x_sub_one);

  // Materialize
  // a(z) = kBaseLanczosCoeff
  //          + sum(k = 1, n, kLanczosCoefficients[k - 1] / (z + k))
  // a'(z) = - sum(k = 1, n, kLanczosCoefficients[k - 1] / (z + k) / (z + k))
  Value zero = getConstantLike(rewriter, loc, 0.0, x);
  Value a = getConstantLike(rewriter, loc, kBaseLanczosCoeff, x);
  Value a_prime = zero;
  for (int i = 0, end = kLanczosCoefficients.size(); i < end; ++i) {
    Value coeff = getConstantLike(rewriter, loc, kLanczosCoefficients[i], x);
    // k = i + 1 in the formulas above.
    Value one_based_index = getConstantLike(rewriter, loc, i + 1, x);
    Value z_term = rewriter.create<mhlo::AddOp>(loc, z, one_based_index);
    // a'(z) accumulates -coeff / (z + k)^2.
    a_prime = rewriter.create<mhlo::SubOp>(
        loc, a_prime,
        rewriter.create<mhlo::DivOp>(
            loc, coeff, rewriter.create<mhlo::MulOp>(loc, z_term, z_term)));
    // a(z) accumulates coeff / (z + k).
    a = rewriter.create<mhlo::AddOp>(
        loc, a, rewriter.create<mhlo::DivOp>(loc, coeff, z_term));
  }

  // To improve accuracy on platforms with less-precise log implementations,
  // compute log(kLanczosGamma + 1/2) at compile time and use log1p on the
  // device.
  // Materialize as
  // log(t) = log(kLanczosGamma + 1/2 + z)
  //        = log(kLanczosGamma + 1/2) + log1p(z / (kLanczosGamma + 1/2)).
  Value lanczos_plus_half =
      getConstantLike(rewriter, loc, kLanczosGamma + 0.5, x);
  Value t = rewriter.create<mhlo::AddOp>(loc, lanczos_plus_half, z);
  Value log_term =
      getConstantLike(rewriter, loc, std::log(kLanczosGamma + 0.5), x);
  Value log1p_term = rewriter.create<mhlo::Log1pOp>(
      loc, rewriter.create<mhlo::DivOp>(loc, z, lanczos_plus_half));
  Value log_t = rewriter.create<mhlo::AddOp>(loc, log_term, log1p_term);

  // Materialize the final result (modulo reflection) as
  // digamma(z + 1) = log(t(z)) + a'(z) / a(z) - kLanczosGamma / t(z).
  Value a_prime_div_a = rewriter.create<mhlo::DivOp>(loc, a_prime, a);
  Value lanczos_gamma_div_t = rewriter.create<mhlo::DivOp>(
      loc, getConstantLike(rewriter, loc, kLanczosGamma, x), t);
  Value digamma = rewriter.create<mhlo::SubOp>(
      loc, rewriter.create<mhlo::AddOp>(loc, log_t, a_prime_div_a),
      lanczos_gamma_div_t);

  // We need to be careful how we compute cot(pi * input) below: For
  // near-integral arguments, pi * input can lose precision.
  //
  // Input is already known to be less than 0.5 (otherwise we don't have to
  // reflect). We shift values smaller than -0.5 into the range [-0.5, 0.5] to
  // increase precision of pi * x and the resulting cotangent.
  Value reduced_x = rewriter.create<mhlo::AddOp>(
      loc, x,
      rewriter.create<mhlo::AbsOp>(
          loc, rewriter.create<mhlo::FloorOp>(
                   loc, rewriter.create<mhlo::AddOp>(
                            loc, x, getConstantLike(rewriter, loc, 0.5, x)))));

  // Materialize reflection for inputs less than 0.5 as
  // digamma(x) = digamma(1 - x) - pi * cot(pi * x)
  //            = digamma(1 - x) - pi * cos(pi * x) / sin(pi * x)
  Value pi = getConstantLike(rewriter, loc, M_PI, x);
  Value pi_mul_reduced_x = rewriter.create<mhlo::MulOp>(loc, pi, reduced_x);
  Value cos = rewriter.create<mhlo::CosOp>(loc, pi_mul_reduced_x);
  Value sin = rewriter.create<mhlo::SinOp>(loc, pi_mul_reduced_x);
  Value reflection = rewriter.create<mhlo::SubOp>(
      loc, digamma,
      rewriter.create<mhlo::DivOp>(
          loc, rewriter.create<mhlo::MulOp>(loc, pi, cos), sin));

  // Select whether or not to rely on the reflection.
  digamma = rewriter.create<mhlo::SelectOp>(loc, need_to_reflect, reflection,
                                            digamma);

  // Digamma has poles at negative integers and zero; return nan for those.
  const StringAttr kLE = rewriter.getStringAttr(
      mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LE));
  Value is_le_zero = rewriter.create<mhlo::CompareOp>(loc, x, zero, kLE);
  const StringAttr kEQ = rewriter.getStringAttr(
      mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::EQ));
  Value is_int = rewriter.create<mhlo::CompareOp>(
      loc, x, rewriter.create<mhlo::FloorOp>(loc, x), kEQ);
  Value is_pole = rewriter.create<mhlo::AndOp>(loc, is_le_zero, is_int);
  return rewriter.create<mhlo::SelectOp>(
      loc, is_pole,
      getConstantLike(rewriter, loc, std::numeric_limits<double>::quiet_NaN(),
                      x),
      digamma);
}
763
MaterializeZeta(ConversionPatternRewriter & rewriter,Location loc,ValueRange args)764 Value MaterializeZeta(ConversionPatternRewriter &rewriter, Location loc,
765 ValueRange args) {
766 assert(args.size() == 2);
767 Value x = args[0];
768 Value q = args[1];
769 static const std::array<double, 12> kZetaCoeffs{
770 -7.1661652561756670113e18,
771 1.8152105401943546773e17,
772 -4.5979787224074726105e15,
773 1.1646782814350067249e14,
774 -2.950130727918164224e12,
775 7.47242496e10,
776 -1.8924375803183791606e9,
777 47900160.0,
778 -1209600.0,
779 30240.0,
780 -720.0,
781 12.0,
782 };
783
784 // For speed we'll always use 9 iterations for the initial series estimate,
785 // and a 12 term expansion for the Euler-Maclaurin formula.
786 Value a = q;
787 Value zero_like_a = chlo::getConstantLike(rewriter, loc, 0.0, a);
788 Value neg_power = zero_like_a;
789 Value neg_x = rewriter.create<mhlo::NegOp>(loc, x);
790 Value initial_sum = rewriter.create<mhlo::PowOp>(loc, q, neg_x);
791 Value one_like_a = chlo::getConstantLike(rewriter, loc, 1.0, a);
792 for (int i = 0; i < 9; ++i) {
793 a = rewriter.create<mhlo::AddOp>(loc, a, one_like_a);
794 neg_power = rewriter.create<mhlo::PowOp>(loc, a, neg_x);
795 initial_sum = rewriter.create<mhlo::AddOp>(loc, initial_sum, neg_power);
796 }
797 a = rewriter.create<mhlo::AddOp>(loc, a, one_like_a);
798 neg_power = rewriter.create<mhlo::PowOp>(loc, a, neg_x);
799 Value one_like_x = chlo::getConstantLike(rewriter, loc, 1.0, x);
800 Value x_minus_one = rewriter.create<mhlo::SubOp>(loc, x, one_like_x);
801 Value neg_power_mul_a = rewriter.create<mhlo::MulOp>(loc, neg_power, a);
802 Value neg_power_mul_a_div_x_minus_one =
803 rewriter.create<mhlo::DivOp>(loc, neg_power_mul_a, x_minus_one);
804 Value s = rewriter.create<mhlo::AddOp>(loc, initial_sum,
805 neg_power_mul_a_div_x_minus_one);
806 Value a_inverse_square = rewriter.create<mhlo::DivOp>(
807 loc, one_like_a, rewriter.create<mhlo::MulOp>(loc, a, a));
808
809 Value horner_sum = zero_like_a;
810 Value factor = one_like_a;
811 // Use Horner's rule for this.
812 // Note this differs from Cephes which does a 'naive' polynomial evaluation.
813 // Using Horner's rule allows to avoid some NaN's and Infs from happening,
814 // resulting in more numerically stable code.
815 for (int i = 0; i < 11; ++i) {
816 Value factor_lhs = rewriter.create<mhlo::SubOp>(
817 loc, x, chlo::getConstantLike(rewriter, loc, 22 - 2 * i, x));
818 Value factor_rhs = rewriter.create<mhlo::SubOp>(
819 loc, x, chlo::getConstantLike(rewriter, loc, 21 - 2 * i, x));
820 factor = rewriter.create<mhlo::MulOp>(loc, factor_lhs, factor_rhs);
821 horner_sum = rewriter.create<mhlo::MulOp>(
822 loc, factor,
823 rewriter.create<mhlo::MulOp>(
824 loc, a_inverse_square,
825 rewriter.create<mhlo::AddOp>(
826 loc, horner_sum,
827 chlo::getConstantLike(rewriter, loc, 1. / kZetaCoeffs[i], a))));
828 }
829 Value zero_point_five_like_neg_power =
830 chlo::getConstantLike(rewriter, loc, .5, neg_power);
831 Value x_div_a = rewriter.create<mhlo::DivOp>(loc, x, a);
832 s = rewriter.create<mhlo::AddOp>(
833 loc, s,
834 rewriter.create<mhlo::MulOp>(
835 loc, neg_power,
836 rewriter.create<mhlo::AddOp>(
837 loc, zero_point_five_like_neg_power,
838 rewriter.create<mhlo::MulOp>(
839 loc, x_div_a,
840 rewriter.create<mhlo::AddOp>(
841 loc,
842 chlo::getConstantLike(rewriter, loc, 1. / kZetaCoeffs[11],
843 a),
844 horner_sum)))));
845 const double nan = std::numeric_limits<double>::quiet_NaN();
846 const double inf = std::numeric_limits<double>::infinity();
847 // Use the initial zeta sum without the correction term coming
848 // from Euler-Maclaurin if it is accurate enough.
849 const StringAttr kLT = rewriter.getStringAttr(
850 mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT));
851 Value abs_neg_power = rewriter.create<mhlo::AbsOp>(loc, neg_power);
852 Value abs_initial_sum = rewriter.create<mhlo::AbsOp>(loc, initial_sum);
853 Value output = rewriter.create<mhlo::SelectOp>(
854 loc,
855 rewriter.create<mhlo::CompareOp>(
856 loc, abs_neg_power,
857 rewriter.create<mhlo::MulOp>(
858 loc, abs_initial_sum,
859 chlo::getConstantLikeSmallestFiniteValue(rewriter, loc, a)),
860 kLT),
861 initial_sum, s);
862 // This is the harmonic series.
863 const StringAttr kEQ = rewriter.getStringAttr(
864 mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::EQ));
865 Value inf_like_x = chlo::getConstantLike(rewriter, loc, inf, x);
866 output = rewriter.create<mhlo::SelectOp>(
867 loc, rewriter.create<mhlo::CompareOp>(loc, x, one_like_x, kEQ),
868 inf_like_x, output);
869 // Function is not defined for x < 1.
870 Value nan_like_x = chlo::getConstantLike(rewriter, loc, nan, x);
871 output = rewriter.create<mhlo::SelectOp>(
872 loc, rewriter.create<mhlo::CompareOp>(loc, x, one_like_x, kLT),
873 nan_like_x, output);
874 // If q <= 0, then when q is an integer or x is not an integer, this is
875 // NaN.
876 const StringAttr kLE = rewriter.getStringAttr(
877 mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LE));
878 const StringAttr kNE = rewriter.getStringAttr(
879 mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::NE));
880 Value zero_like_q = chlo::getConstantLike(rewriter, loc, 0.0, q);
881 Value q_le_zero = rewriter.create<mhlo::CompareOp>(loc, q, zero_like_q, kLE);
882 Value domain_error = rewriter.create<mhlo::AndOp>(
883 loc, q_le_zero,
884 rewriter.create<mhlo::CompareOp>(
885 loc, x, rewriter.create<mhlo::FloorOp>(loc, x), kNE));
886 Value negative_integer_q = rewriter.create<mhlo::AndOp>(
887 loc, q_le_zero,
888 rewriter.create<mhlo::CompareOp>(
889 loc, q, rewriter.create<mhlo::FloorOp>(loc, q), kEQ));
890 output = rewriter.create<mhlo::SelectOp>(loc, negative_integer_q, inf_like_x,
891 output);
892 output =
893 rewriter.create<mhlo::SelectOp>(loc, domain_error, nan_like_x, output);
894 return output;
895 }
896
MaterializePolygamma(ConversionPatternRewriter & rewriter,Location loc,ValueRange args)897 Value MaterializePolygamma(ConversionPatternRewriter &rewriter, Location loc,
898 ValueRange args) {
899 PolygammaOp::Adaptor transformed(args);
900 Value n = transformed.n();
901 Value x = transformed.x();
902
903 // Handle integer n > 0.
904 Value one = getConstantLike(rewriter, loc, 1.0, x);
905 Value two = getConstantLike(rewriter, loc, 2.0, x);
906 Value sign = rewriter.create<mhlo::SubOp>(
907 loc,
908 rewriter.create<mhlo::MulOp>(loc, two,
909 rewriter.create<mhlo::RemOp>(loc, n, two)),
910 one);
911 Value n_plus_one = rewriter.create<mhlo::AddOp>(loc, n, one);
912 Value exp_lgamma_np1 = rewriter.create<mhlo::ExpOp>(
913 loc, rewriter.create<chlo::LgammaOp>(loc, n_plus_one));
914 Value zeta = rewriter.create<chlo::ZetaOp>(loc, n_plus_one, x);
915 Value result = rewriter.create<mhlo::MulOp>(
916 loc, rewriter.create<mhlo::MulOp>(loc, sign, exp_lgamma_np1), zeta);
917
918 // Handle n = 0.
919 const StringAttr kEQ = rewriter.getStringAttr(
920 mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::EQ));
921 Value zero = getConstantLike(rewriter, loc, 0.0, x);
922 Value n_eq_zero = rewriter.create<mhlo::CompareOp>(loc, n, zero, kEQ);
923 result = rewriter.create<mhlo::SelectOp>(
924 loc, n_eq_zero, rewriter.create<chlo::DigammaOp>(loc, x), result);
925
926 // Check that n is a natural number.
927 const StringAttr kNE = rewriter.getStringAttr(
928 mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::NE));
929 Value non_int = rewriter.create<mhlo::CompareOp>(
930 loc, n, rewriter.create<mhlo::FloorOp>(loc, n), kNE);
931 const StringAttr kLT = rewriter.getStringAttr(
932 mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT));
933 Value negative = rewriter.create<mhlo::CompareOp>(loc, n, zero, kLT);
934 Value non_natural = rewriter.create<mhlo::OrOp>(loc, non_int, negative);
935 return rewriter.create<mhlo::SelectOp>(
936 loc, non_natural,
937 getConstantLike(rewriter, loc, std::numeric_limits<double>::quiet_NaN(),
938 x),
939 result);
940 }
941
942 struct ConvertLgammaOp : public OpConversionPattern<LgammaOp> {
943 using OpConversionPattern<LgammaOp>::OpConversionPattern;
matchAndRewritemlir::chlo::__anonca6190260111::ConvertLgammaOp944 LogicalResult matchAndRewrite(
945 LgammaOp op, ArrayRef<Value> operands,
946 ConversionPatternRewriter &rewriter) const override {
947 FloatType min_precision_ty = rewriter.getF32Type();
948 rewriter.replaceOp(
949 op, MaterializeWithUpcast(rewriter, op.getLoc(), operands,
950 min_precision_ty, &MaterializeLgamma));
951 return success();
952 }
953 };
954
955 struct ConvertDigammaOp : public OpConversionPattern<DigammaOp> {
956 using OpConversionPattern<DigammaOp>::OpConversionPattern;
matchAndRewritemlir::chlo::__anonca6190260111::ConvertDigammaOp957 LogicalResult matchAndRewrite(
958 DigammaOp op, ArrayRef<Value> operands,
959 ConversionPatternRewriter &rewriter) const override {
960 FloatType min_precision_ty = rewriter.getF32Type();
961 rewriter.replaceOp(
962 op, MaterializeWithUpcast(rewriter, op.getLoc(), operands,
963 min_precision_ty, &MaterializeDigamma));
964 return success();
965 }
966 };
967
968 struct ConvertPolygammaOp : public OpConversionPattern<PolygammaOp> {
969 using OpConversionPattern<PolygammaOp>::OpConversionPattern;
matchAndRewritemlir::chlo::__anonca6190260111::ConvertPolygammaOp970 LogicalResult matchAndRewrite(
971 PolygammaOp op, ArrayRef<Value> operands,
972 ConversionPatternRewriter &rewriter) const override {
973 Location loc = op.getLoc();
974 FloatType min_precision_ty = rewriter.getF32Type();
975 rewriter.replaceOp(
976 op, MaterializeWithUpcast(rewriter, loc, operands, min_precision_ty,
977 &MaterializePolygamma));
978 return success();
979 }
980 };
981
982 struct ConvertZetaOp : public OpConversionPattern<ZetaOp> {
983 using OpConversionPattern<ZetaOp>::OpConversionPattern;
matchAndRewritemlir::chlo::__anonca6190260111::ConvertZetaOp984 LogicalResult matchAndRewrite(
985 ZetaOp op, ArrayRef<Value> operands,
986 ConversionPatternRewriter &rewriter) const override {
987 Location loc = op.getLoc();
988 FloatType min_precision_ty = rewriter.getF32Type();
989 rewriter.replaceOp(
990 op, MaterializeWithUpcast(rewriter, loc, operands, min_precision_ty,
991 &MaterializeZeta));
992 return success();
993 }
994 };
995
996 // Converts binary ops that statically are determined to not broadcast directly
997 // to the corresponding mhlo non-broadcasting op.
998 template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
999 struct ConvertTrivialNonBroadcastBinaryOp
1000 : public OpConversionPattern<ChloOpTy> {
1001 using OpConversionPattern<ChloOpTy>::OpConversionPattern;
matchAndRewritemlir::chlo::__anonca6190260111::ConvertTrivialNonBroadcastBinaryOp1002 LogicalResult matchAndRewrite(
1003 ChloOpTy op, ArrayRef<Value> operands,
1004 ConversionPatternRewriter &rewriter) const override {
1005 // Only rewrite for statically determinable non-broadcasting cases.
1006 typename ChloOpTy::Adaptor transformed(operands);
1007 auto lhs_type =
1008 transformed.lhs().getType().template dyn_cast<RankedTensorType>();
1009 auto rhs_type =
1010 transformed.rhs().getType().template dyn_cast<RankedTensorType>();
1011 if (!lhs_type || !rhs_type) return failure();
1012
1013 // Requires rank broadcast.
1014 if (lhs_type.getRank() != rhs_type.getRank()) return failure();
1015 // Any dynamic dimension may require broadcasting and requires more
1016 // analysis.
1017 if (!lhs_type.hasStaticShape() || !rhs_type.hasStaticShape())
1018 return failure();
1019
1020 for (auto extents : llvm::zip(lhs_type.getShape(), rhs_type.getShape())) {
1021 auto lhs_extent = std::get<0>(extents);
1022 auto rhs_extent = std::get<1>(extents);
1023 if (lhs_extent != rhs_extent) {
1024 return failure();
1025 }
1026 }
1027
1028 rewriter.replaceOp(
1029 op, {Adaptor::CreateOp(op, op.getResult().getType(), operands[0],
1030 operands[1], rewriter)});
1031 return success();
1032 }
1033 };
1034
// Converts a binary op with ranked broadcasting operands to explicitly
// broadcast and invoke the corresponding mhlo non-broadcasting op.
// Note that dynamic broadcasting supported by this pattern is only valid for
// "numpy" broadcasting semantics as defined here:
//   https://docs.scipy.org/doc/numpy/reference/ufuncs.html
// Specifically, this includes the following cases:
//   - Same rank broadcast (operands have the same static rank).
//   - Different-rank broadcast, either without a broadcast_dims attribute or
//     with the broadcast_dims attribute set to map to a prefix padding.
//   - Legal combinations of degenerate (1-dim) implicit broadcasting.
// The restriction on broadcast_dims derives from the definition of the
// `shape.broadcast` op, which only supports prefix-padding.
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
struct ConvertRankedDynamicBroadcastBinaryOp
    : public OpConversionPattern<ChloOpTy> {
  using OpConversionPattern<ChloOpTy>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      ChloOpTy op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    // Only support ranked operands.
    typename ChloOpTy::Adaptor transformed(operands);
    Value lhs = transformed.lhs();
    Value rhs = transformed.rhs();
    auto lhs_type = lhs.getType().dyn_cast<RankedTensorType>();
    auto rhs_type = rhs.getType().dyn_cast<RankedTensorType>();
    auto result_type =
        op.getResult().getType().template dyn_cast<RankedTensorType>();
    if (!lhs_type || !rhs_type || !result_type) return failure();

    // Check for "numpy"-style rank broadcast.
    auto broadcast_dimensions = op.broadcast_dimensions();
    if (broadcast_dimensions &&
        !hlo::IsLegalNumpyRankedBroadcast(lhs, rhs, *broadcast_dimensions)) {
      // Note: It is unclear whether the general specification of explicit
      // broadcast_dimensions on binary ops is a feature we want to carry
      // forward. While it can technically be implemented for ranked-dynamic,
      // it is incompatible with unranked inputs. If this warning is emitted
      // in real programs, it is an indication that the feature should be
      // implemented versus just falling back on the more standard definition
      // of numpy-like prefix-padding.
      op.emitWarning() << "unsupported non prefix-padded dynamic rank "
                       << "broadcast_dimensions = " << *broadcast_dimensions;
      return failure();
    }

    // Compute result shape.
    auto loc = op.getLoc();

    // Insert a constraint on the shapes being broadcastable and insert all
    // future code into an assuming block reliant on the constraint.
    Value lhs_shape = rewriter.create<shape::ShapeOfOp>(loc, lhs);
    Value rhs_shape = rewriter.create<shape::ShapeOfOp>(loc, rhs);
    auto broadcastable_cstr =
        rewriter.create<shape::CstrBroadcastableOp>(loc, lhs_shape, rhs_shape);
    auto assuming_op = rewriter.create<shape::AssumingOp>(
        loc, ArrayRef<Type>{result_type}, broadcastable_cstr.result());

    // The guard restores the insertion point after the assuming region's
    // block (created next) has been populated.
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.createBlock(&assuming_op.doRegion());

    int64_t result_rank = std::max(lhs_type.getRank(), rhs_type.getRank());
    Value result_extents =
        hlo::ComputeBinaryElementwiseBroadcastingResultExtents(
            loc, lhs, rhs, rewriter, /*unsafe_as_extent_tensor=*/true);

    // Note that we unconditionally emit DynamicBroadcastInDim ops and let
    // downstream canonicalizations fold them away if possible. This is
    // because, in the dynamic case, there are many corner cases regarding
    // when it is safe to omit, and some of them require analysis to prove
    // properly.
    // The broadcast_dimensions map each operand dimension to the trailing
    // (suffix) dimensions of the result, i.e. numpy prefix-padding.
    auto lhs_broadcast_dimensions = llvm::to_vector<4>(
        llvm::seq<int64_t>(result_rank - lhs_type.getRank(), result_rank));
    Value broadcasted_lhs = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
        loc,
        RankedTensorType::get(result_type.getShape(),
                              lhs_type.getElementType()),
        lhs, result_extents,
        rewriter.getI64TensorAttr(lhs_broadcast_dimensions));
    auto rhs_broadcast_dimensions = llvm::to_vector<4>(
        llvm::seq<int64_t>(result_rank - rhs_type.getRank(), result_rank));
    Value broadcasted_rhs = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
        loc,
        RankedTensorType::get(result_type.getShape(),
                              rhs_type.getElementType()),
        rhs, result_extents,
        rewriter.getI64TensorAttr(rhs_broadcast_dimensions));

    // And generate the final non-broadcasted binary op.
    Value final_result = Adaptor::CreateOp(op, result_type, broadcasted_lhs,
                                           broadcasted_rhs, rewriter);
    rewriter.create<shape::AssumingYieldOp>(loc, final_result);
    rewriter.replaceOp(op, {assuming_op.getResult(0)});
    return success();
  }
};
1130
1131 #include "generated_chlo_legalize_to_hlo.inc"
1132 } // namespace
1133
// Registers the broadcasting lowering patterns for all conforming CHLO binary
// elementwise ops into `patterns`.
void PopulateChloBroadcastingPatterns(MLIRContext *context,
                                      OwningRewritePatternList *patterns) {
  // Instantiate conversion templates for conforming binary elementwise ops
  // that do not have different dtypes between operands and results and do
  // not have special attributes that need to be preserved.
  // The trivially-non-broadcasting pattern gets the higher benefit (10) so it
  // is preferred over the general ranked-dynamic pattern (benefit 5) when
  // both apply.
  PopulateForBroadcastingBinaryOp<ConvertTrivialNonBroadcastBinaryOp>(
      context, patterns, 10);
  PopulateForBroadcastingBinaryOp<ConvertRankedDynamicBroadcastBinaryOp>(
      context, patterns, 5);
}
1144
// Populates `patterns` with every CHLO->MHLO lowering in this file: the
// TableGen-generated rules, the broadcasting binary-op patterns, and the
// special-function expansion patterns defined above.
void PopulateLegalizeChloToHloPatterns(MLIRContext *context,
                                       OwningRewritePatternList *patterns) {
  populateWithGenerated(context, *patterns);
  PopulateChloBroadcastingPatterns(context, patterns);

  // Other patterns.
  // clang-format off
  patterns->insert<ConvertConstantLikeOp,
                   ConvertDigammaOp,
                   ConvertErfOp,
                   ConvertErfcOp,
                   ConvertLgammaOp,
                   ConvertPolygammaOp,
                   ConvertZetaOp>(context);
  // clang-format on
}
1161
1162 } // namespace chlo
1163 } // namespace mlir
1164