1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Enable the use of M_* math constants.
17 // NOTE: this must be first in the file to ensure that if cmath is transitively
18 // included by any other header it has the define set on first processing.
19 // https://docs.microsoft.com/en-us/cpp/c-runtime-library/math-constants
20 #define _USE_MATH_DEFINES
#include <array>
#include <cmath>
#include <numeric>
#include <vector>
24 
25 #include "llvm/ADT/SmallVector.h"
26 #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
27 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
28 #include "mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h"
29 #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
30 #include "mlir-hlo/utils/broadcast_utils.h"
31 #include "mlir/Dialect/SCF/SCF.h"
32 #include "mlir/Dialect/Shape/IR/Shape.h"
33 #include "mlir/Dialect/StandardOps/IR/Ops.h"
34 #include "mlir/Dialect/Tensor/IR/Tensor.h"
35 #include "mlir/IR/Attributes.h"
36 #include "mlir/IR/BuiltinTypes.h"
37 #include "mlir/IR/ImplicitLocOpBuilder.h"
38 #include "mlir/IR/MLIRContext.h"
39 #include "mlir/IR/OperationSupport.h"
40 #include "mlir/IR/PatternMatch.h"
41 #include "mlir/Transforms/DialectConversion.h"
42 
43 namespace mlir {
44 namespace chlo {
45 namespace {
46 
47 struct ConvertConstantLikeOp : public OpConversionPattern<ConstantLikeOp> {
48   using OpConversionPattern<ConstantLikeOp>::OpConversionPattern;
matchAndRewritemlir::chlo::__anonca6190260111::ConvertConstantLikeOp49   LogicalResult matchAndRewrite(
50       ConstantLikeOp op, ArrayRef<Value> operands,
51       ConversionPatternRewriter &rewriter) const override {
52     auto result_ty = op.getType().cast<ShapedType>();
53 
54     // Unranked uses are not supported.  Consider `transform-unranked-hlo`.
55     if (!result_ty.hasRank()) return failure();
56 
57     // Lower to MHLO constant if statically shaped.
58     if (result_ty.hasStaticShape()) {
59       rewriter.replaceOpWithNewOp<mhlo::ConstOp>(
60           op, DenseElementsAttr::get(result_ty, op.value()));
61       return success();
62     }
63 
64     // Lower to broadcasted constant.
65     ConstantLikeOp::Adaptor transformed(operands);
66     auto loc = op.getLoc();
67     Type extent_tensor_type = shape::getExtentTensorType(op.getContext());
68     Value constant = rewriter.create<mhlo::ConstOp>(loc, op.value());
69     Value uncasted_shape = rewriter.create<shape::ShapeOfOp>(
70         loc, extent_tensor_type, transformed.operand());
71     Type shape_ty =
72         RankedTensorType::get({result_ty.getRank()}, rewriter.getIndexType());
73     Value shape =
74         rewriter.create<tensor::CastOp>(loc, shape_ty, uncasted_shape);
75     rewriter.replaceOpWithNewOp<mhlo::DynamicBroadcastInDimOp>(
76         op, result_ty, constant, shape, rewriter.getI64TensorAttr({}));
77     return success();
78   }
79 };
80 
81 template <typename FTy>
MaterializePolynomialApproximation(ConversionPatternRewriter & rewriter,Location loc,Value x,const std::vector<FTy> & coefficients)82 Value MaterializePolynomialApproximation(ConversionPatternRewriter &rewriter,
83                                          Location loc, Value x,
84                                          const std::vector<FTy> &coefficients) {
85   Value poly = chlo::getConstantLike(rewriter, loc, 0.0, x);
86   for (FTy c : coefficients) {
87     poly = rewriter.create<mhlo::MulOp>(loc, x.getType(), poly, x);
88     poly = rewriter.create<mhlo::AddOp>(
89         loc, x.getType(), poly, chlo::getConstantLike(rewriter, loc, c, x));
90   }
91   return poly;
92 }
93 
// Materializes an approximation of erfc(x) for f64 operands.
//
// Precondition is |x| >= 1. Use erf approximation, otherwise.
//
// We rely on multiple polynomial approximations for x >= 1. We pass |x| as an
// argument and derive the final approximation for all |x| >= 1.
// This implementation is based on Cephes.
Value MaterializeErfcApproximationF64ForMagnituteGEOne(
    ConversionPatternRewriter &rewriter, Location loc, ValueRange args) {
  Value x = args.front();
  assert(x.getType().cast<ShapedType>().getElementType().isF64() &&
         "expect f64 element type");
  // Largest z for which exp(-z) is representable (Cephes' MAXLOG for f64);
  // used below to clamp the result to 0 where exp(z) would underflow.
  const double kMaxlog = 7.09782712893383996843E2;
  // Rational approximation coefficients from Cephes, highest degree first.
  const std::vector<double> kErfcPCoefficients{
      2.46196981473530512524E-10, 5.64189564831068821977E-1,
      7.46321056442269912687E0,   4.86371970985681366614E1,
      1.96520832956077098242E2,   5.26445194995477358631E2,
      9.34528527171957607540E2,   1.02755188689515710272E3,
      5.57535335369399327526E2};
  const std::vector<double> kErfcQCoefficients{
      1.00000000000000000000E0, 1.32281951154744992508E1,
      8.67072140885989742329E1, 3.54937778887819891062E2,
      9.75708501743205489753E2, 1.82390916687909736289E3,
      2.24633760818710981792E3, 1.65666309194161350182E3,
      5.57535340817727675546E2};
  const std::vector<double> kErfcRCoefficients{
      5.64189583547755073984E-1, 1.27536670759978104416E0,
      5.01905042251180477414E0,  6.16021097993053585195E0,
      7.40974269950448939160E0,  2.97886665372100240670E0};
  const std::vector<double> kErfcSCoefficients{
      1.00000000000000000000E0, 2.26052863220117276590E0,
      9.39603524938001434673E0, 1.20489539808096656605E1,
      1.70814450747565897222E1, 9.60896809063285878198E0,
      3.36907645100081516050E0};

  // Let z = -x^2.
  Value x_sq = rewriter.create<mhlo::MulOp>(loc, x, x);
  Value z = rewriter.create<mhlo::NegOp>(loc, x_sq);

  // Materialize polynomial approximation for x in [1, 8) as
  //   erfc(x) = exp(z) P(|x|) / Q(|x|).
  Value exp_z = rewriter.create<mhlo::ExpOp>(loc, z);
  Value abs_x = rewriter.create<mhlo::AbsOp>(loc, x);
  Value poly_p = MaterializePolynomialApproximation(rewriter, loc, abs_x,
                                                    kErfcPCoefficients);
  Value exp_z_mul_poly_p = rewriter.create<mhlo::MulOp>(loc, exp_z, poly_p);
  Value poly_q = MaterializePolynomialApproximation(rewriter, loc, abs_x,
                                                    kErfcQCoefficients);
  Value erfc_approx_1_8 =
      rewriter.create<mhlo::DivOp>(loc, exp_z_mul_poly_p, poly_q);

  // Materialize polynomial approximation for x >= 8 as
  //   erfc(x) = exp(z) R(|x|) / S(|x|).
  Value poly_r = MaterializePolynomialApproximation(rewriter, loc, abs_x,
                                                    kErfcRCoefficients);
  Value exp_z_mul_poly_r = rewriter.create<mhlo::MulOp>(loc, exp_z, poly_r);
  Value poly_s = MaterializePolynomialApproximation(rewriter, loc, abs_x,
                                                    kErfcSCoefficients);
  Value erfc_approx_8_inf =
      rewriter.create<mhlo::DivOp>(loc, exp_z_mul_poly_r, poly_s);

  // Combine polynomial approximations for x >= 1.
  const StringAttr kLT = rewriter.getStringAttr(
      mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT));
  Value eight = chlo::getConstantLike(rewriter, loc, 8.0, x);
  Value abs_x_lt_8 = rewriter.create<mhlo::CompareOp>(loc, abs_x, eight, kLT);
  Value erfc_approx = rewriter.create<mhlo::SelectOp>(
      loc, abs_x_lt_8, erfc_approx_1_8, erfc_approx_8_inf);

  // Clamp to prevent overflow and materialize approximation for large x as
  //   erfc(x) = 0.
  Value z_lt_neg_maxlog = rewriter.create<mhlo::CompareOp>(
      loc, z, chlo::getConstantLike(rewriter, loc, -kMaxlog, x), kLT);
  Value zero = chlo::getConstantLike(rewriter, loc, 0.0, x);
  Value erfc_approx_clamped =
      rewriter.create<mhlo::SelectOp>(loc, z_lt_neg_maxlog, zero, erfc_approx);

  // Derive approximation for x <= -1 as
  //   erfc(x) = 2 - erfc(-x).
  // Reuse previously materialized approximations all of which take |x| as their
  // argument.
  Value x_lt_zero = rewriter.create<mhlo::CompareOp>(loc, x, zero, kLT);
  Value two = chlo::getConstantLike(rewriter, loc, 2.0, x);
  Value two_sub_erfc_approx_clamped =
      rewriter.create<mhlo::SubOp>(loc, two, erfc_approx_clamped);
  return rewriter.create<mhlo::SelectOp>(
      loc, x_lt_zero, two_sub_erfc_approx_clamped, erfc_approx_clamped);
}
180 
181 // Precondition is |x| <= 1. Use erfc approximation, otherwise.
182 // This implementation is based on Cephes.
MaterializeErfApproximationF64ForMagnituteLEOne(ConversionPatternRewriter & rewriter,Location loc,ValueRange args)183 Value MaterializeErfApproximationF64ForMagnituteLEOne(
184     ConversionPatternRewriter &rewriter, Location loc, ValueRange args) {
185   Value x = args.front();
186   assert(x.getType().cast<ShapedType>().getElementType().isF64() &&
187          "expect f64 element type");
188   const std::vector<double> kErfTCoefficients{
189       9.60497373987051638749E0, 9.00260197203842689217E1,
190       2.23200534594684319226E3, 7.00332514112805075473E3,
191       5.55923013010394962768E4};
192   const std::vector<double> kErfUCoefficients{
193       1.00000000000000000000E0, 3.35617141647503099647E1,
194       5.21357949780152679795E2, 4.59432382970980127987E3,
195       2.26290000613890934246E4, 4.92673942608635921086E4};
196 
197   // Materialize polynomial approximation for |x| <= 1 as
198   //   erf(x) = x T(x^2) / U(x^2).
199   Value x_sq = rewriter.create<mhlo::MulOp>(loc, x, x);
200   Value poly_t = MaterializePolynomialApproximation(rewriter, loc, x_sq,
201                                                     kErfTCoefficients);
202   Value x_mul_poly_t = rewriter.create<mhlo::MulOp>(loc, x, poly_t);
203   Value poly_u = MaterializePolynomialApproximation(rewriter, loc, x_sq,
204                                                     kErfUCoefficients);
205   return rewriter.create<mhlo::DivOp>(loc, x_mul_poly_t, poly_u);
206 }
207 
208 // This implementation is based on Cephes.
MaterializeErfApproximationF64(ConversionPatternRewriter & rewriter,Location loc,ValueRange args)209 Value MaterializeErfApproximationF64(ConversionPatternRewriter &rewriter,
210                                      Location loc, ValueRange args) {
211   Value x = args.front();
212   assert(x.getType().cast<ShapedType>().getElementType().isF64() &&
213          "expect f64 element type");
214 
215   // Rely on erf approximation for |x| < 1
216   //   erf(x) = erf_approx(x)
217   Value erf_approx =
218       MaterializeErfApproximationF64ForMagnituteLEOne(rewriter, loc, x);
219 
220   // Rely on erfc approximation for |x| >= 1 and materialize erf as
221   //   erf(x) = 1 - erfc_approx(x)
222   Value one = chlo::getConstantLike(rewriter, loc, 1.0, x);
223   Value erfc_approx =
224       MaterializeErfcApproximationF64ForMagnituteGEOne(rewriter, loc, x);
225   Value erfc_based_approx = rewriter.create<mhlo::SubOp>(loc, one, erfc_approx);
226 
227   // Materialize approximation selection based on argument.
228   Value abs_x = rewriter.create<mhlo::AbsOp>(loc, x);
229   const StringAttr kLT = rewriter.getStringAttr(
230       mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT));
231   Value abs_x_lt_one = rewriter.create<mhlo::CompareOp>(loc, abs_x, one, kLT);
232   return rewriter.create<mhlo::SelectOp>(loc, abs_x_lt_one, erf_approx,
233                                          erfc_based_approx);
234 }
235 
MaterializeErfcApproximationF64(ConversionPatternRewriter & rewriter,Location loc,ValueRange args)236 Value MaterializeErfcApproximationF64(ConversionPatternRewriter &rewriter,
237                                       Location loc, ValueRange args) {
238   Value x = args.front();
239   assert(x.getType().cast<ShapedType>().getElementType().isF64() &&
240          "expect f64 element type");
241 
242   // Rely on erfc approximation for |x| >= 1
243   //   erfc(x) = erfc_approx(x)
244   Value erfc_approx =
245       MaterializeErfcApproximationF64ForMagnituteGEOne(rewriter, loc, x);
246 
247   // Rely on erf approximation for |x| < 1 and materialize erfc as
248   //   erfc(x) = 1 - erf_approx(x)
249   Value one = chlo::getConstantLike(rewriter, loc, 1.0, x);
250   Value erf_approx =
251       MaterializeErfApproximationF64ForMagnituteLEOne(rewriter, loc, x);
252   Value erf_based_approx = rewriter.create<mhlo::SubOp>(loc, one, erf_approx);
253 
254   // Materialize approximation selection based on argument.
255   Value abs_x = rewriter.create<mhlo::AbsOp>(loc, x);
256   const StringAttr kLT = rewriter.getStringAttr(
257       mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT));
258   Value abs_x_lt_one = rewriter.create<mhlo::CompareOp>(loc, abs_x, one, kLT);
259   return rewriter.create<mhlo::SelectOp>(loc, abs_x_lt_one, erf_based_approx,
260                                          erfc_approx);
261 }
262 
263 // Precondition is |x| >= 1. Use erf approximation, otherwise.
264 //
265 // We rely on multiple polynomial approximations for x >= 1. We pass |x| as an
266 // argument and derive the final approximation for all |x| >= 1.
267 // This implementation is based on Cephes.
MaterializeErfcApproximationF32ForMagnitudeGEOne(ConversionPatternRewriter & rewriter,Location loc,ValueRange args)268 Value MaterializeErfcApproximationF32ForMagnitudeGEOne(
269     ConversionPatternRewriter &rewriter, Location loc, ValueRange args) {
270   Value x = args.front();
271   assert(x.getType().cast<ShapedType>().getElementType().isF32() &&
272          "expect f32 element type");
273   const double kMaxlog = 88.72283905206835;
274   const std::vector<float> kErfcPCoefficients{
275       +2.326819970068386E-2, -1.387039388740657E-1, +3.687424674597105E-1,
276       -5.824733027278666E-1, +6.210004621745983E-1, -4.944515323274145E-1,
277       +3.404879937665872E-1, -2.741127028184656E-1, +5.638259427386472E-1,
278   };
279   const std::vector<float> kErfcRCoefficients{
280       -1.047766399936249E+1, +1.297719955372516E+1, -7.495518717768503E+0,
281       +2.921019019210786E+0, -1.015265279202700E+0, +4.218463358204948E-1,
282       -2.820767439740514E-1, +5.641895067754075E-1,
283   };
284 
285   // Let z = -x^2.
286   Value x_sq = rewriter.create<mhlo::MulOp>(loc, x, x);
287   Value z = rewriter.create<mhlo::NegOp>(loc, x_sq);
288 
289   // Materialize polynomial approximation for x >= 1 as
290   //   erfc(x) = exp(z) 1/x P(1/x^2)   if x in [1, 2)
291   //   erfc(x) = exp(z) 1/x R(1/x^2)   if x >= 2
292   const StringAttr kLT = rewriter.getStringAttr(
293       mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT));
294   Value abs_x = rewriter.create<mhlo::AbsOp>(loc, x);
295   Value one = chlo::getConstantLike(rewriter, loc, 1.0, x);
296   Value reciprocal_x_sq = rewriter.create<mhlo::DivOp>(loc, one, x_sq);
297   Value exp_z = rewriter.create<mhlo::ExpOp>(loc, z);
298   Value one_div_abs_x = rewriter.create<mhlo::DivOp>(loc, one, abs_x);
299   Value exp_z_mul_one_div_abs_x =
300       rewriter.create<mhlo::MulOp>(loc, exp_z, one_div_abs_x);
301   Value two = chlo::getConstantLike(rewriter, loc, 2.0, x);
302   Value abs_x_lt_two = rewriter.create<mhlo::CompareOp>(loc, abs_x, two, kLT);
303   Value poly_p = MaterializePolynomialApproximation(
304       rewriter, loc, reciprocal_x_sq, kErfcPCoefficients);
305   Value poly_r = MaterializePolynomialApproximation(
306       rewriter, loc, reciprocal_x_sq, kErfcRCoefficients);
307   Value poly =
308       rewriter.create<mhlo::SelectOp>(loc, abs_x_lt_two, poly_p, poly_r);
309   Value erfc_approx =
310       rewriter.create<mhlo::MulOp>(loc, exp_z_mul_one_div_abs_x, poly);
311 
312   // Clamp to prevent overflow and materialize approximation for large x as
313   //   erfc(x) = 0.
314   Value z_lt_neq_maxlog = rewriter.create<mhlo::CompareOp>(
315       loc, z, chlo::getConstantLike(rewriter, loc, -kMaxlog, x), kLT);
316   Value zero = chlo::getConstantLike(rewriter, loc, 0.0, x);
317   Value erfc_approx_clamped =
318       rewriter.create<mhlo::SelectOp>(loc, z_lt_neq_maxlog, zero, erfc_approx);
319 
320   // Derive approximation for x <= -1 as
321   //   erfc(x) = 2 - erfc(-x).
322   // Reuse previously materialized approximations all of which take |x| as their
323   // argument.
324   Value x_lt_zero = rewriter.create<mhlo::CompareOp>(loc, x, zero, kLT);
325   Value two_sub_erfc_approx =
326       rewriter.create<mhlo::SubOp>(loc, two, erfc_approx_clamped);
327   return rewriter.create<mhlo::SelectOp>(loc, x_lt_zero, two_sub_erfc_approx,
328                                          erfc_approx_clamped);
329 }
330 
331 // Precondition is |x| <= 1. Use erfc approximation, otherwise.
332 // This implementation is based on Cephes.
MaterializeErfApproximationF32ForMagnitudeLEOne(ConversionPatternRewriter & rewriter,Location loc,ValueRange args)333 Value MaterializeErfApproximationF32ForMagnitudeLEOne(
334     ConversionPatternRewriter &rewriter, Location loc, ValueRange args) {
335   Value x = args.front();
336   assert(x.getType().cast<ShapedType>().getElementType().isF32() &&
337          "expect f32 element type");
338   const std::vector<float> kErfTCoefficients{
339       +7.853861353153693E-5, -8.010193625184903E-4, +5.188327685732524E-3,
340       -2.685381193529856E-2, +1.128358514861418E-1, -3.761262582423300E-1,
341       +1.128379165726710E+0,
342   };
343 
344   // Materialize polynomial approximation for |x| <= 1 as
345   //   erf(x) = x T(x^2).
346   Value x_sq = rewriter.create<mhlo::MulOp>(loc, x, x);
347   Value poly_t = MaterializePolynomialApproximation(rewriter, loc, x_sq,
348                                                     kErfTCoefficients);
349   return rewriter.create<mhlo::MulOp>(loc, x, poly_t);
350 }
351 
352 // This is the same approximation as used in Eigen.
MaterializeErfApproximationF32(ConversionPatternRewriter & rewriter,Location loc,ValueRange args)353 Value MaterializeErfApproximationF32(ConversionPatternRewriter &rewriter,
354                                      Location loc, ValueRange args) {
355   Value x = args.front();
356   assert(x.getType().cast<ShapedType>().getElementType().isF32() &&
357          "expect f32 element type");
358   const std::vector<float> kAlpha{
359       -2.72614225801306e-10f, 2.77068142495902e-08f,  -2.10102402082508e-06f,
360       -5.69250639462346e-05f, -7.34990630326855e-04f, -2.95459980854025e-03f,
361       -1.60960333262415e-02f,
362   };
363   const std::vector<float> kBeta{
364       -1.45660718464996e-05f, -2.13374055278905e-04f, -1.68282697438203e-03f,
365       -7.37332916720468e-03f, -1.42647390514189e-02f,
366   };
367 
368   // Clamp argument between -4 and 4.
369   Value lb = chlo::getConstantLike(rewriter, loc, -4.0, x);
370   Value ub = chlo::getConstantLike(rewriter, loc, 4.0, x);
371   x = rewriter.create<mhlo::ClampOp>(loc, x.getType(), lb, x, ub);
372   Value x_sq = rewriter.create<mhlo::MulOp>(loc, x, x);
373 
374   // Materialize polynomial approximation for x in [-4, 4] as
375   //   erf(x) = x * Alpha(x^2) / Beta(x^2).
376   Value alpha_poly =
377       MaterializePolynomialApproximation(rewriter, loc, x_sq, kAlpha);
378   Value beta_poly =
379       MaterializePolynomialApproximation(rewriter, loc, x_sq, kBeta);
380   Value x_mul_alpha_poly = rewriter.create<mhlo::MulOp>(loc, x, alpha_poly);
381   return rewriter.create<mhlo::DivOp>(loc, x_mul_alpha_poly, beta_poly);
382 }
383 
MaterializeErfcApproximationF32(ConversionPatternRewriter & rewriter,Location loc,ValueRange args)384 Value MaterializeErfcApproximationF32(ConversionPatternRewriter &rewriter,
385                                       Location loc, ValueRange args) {
386   Value x = args.front();
387   assert(x.getType().cast<ShapedType>().getElementType().isF32() &&
388          "expect f32 element type");
389 
390   // Rely on erfc approximation for |x| >= 1
391   //   erfc(x) = erfc_approx(x)
392   Value erfc_approx =
393       MaterializeErfcApproximationF32ForMagnitudeGEOne(rewriter, loc, x);
394 
395   // Rely on erf approximation for |x| < 1 and materialize erfc as
396   //   erfc(x) = 1 - erf_approx(x)
397   Value one = chlo::getConstantLike(rewriter, loc, 1.0, x);
398   Value erf_approx =
399       MaterializeErfApproximationF32ForMagnitudeLEOne(rewriter, loc, x);
400   Value erf_based_approx = rewriter.create<mhlo::SubOp>(loc, one, erf_approx);
401 
402   // Materialize approximation selection based on argument.
403   const StringAttr kLT = rewriter.getStringAttr(
404       mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT));
405   Value abs_x = rewriter.create<mhlo::AbsOp>(loc, x);
406   Value abs_x_lt_one = rewriter.create<mhlo::CompareOp>(loc, abs_x, one, kLT);
407   return rewriter.create<mhlo::SelectOp>(loc, abs_x_lt_one, erf_based_approx,
408                                          erfc_approx);
409 }
410 
MaterializeWithUpcast(ConversionPatternRewriter & rewriter,Location loc,ValueRange args,FloatType min_precision_ty,Value callback (ConversionPatternRewriter &,Location,ValueRange))411 Value MaterializeWithUpcast(ConversionPatternRewriter &rewriter, Location loc,
412                             ValueRange args, FloatType min_precision_ty,
413                             Value callback(ConversionPatternRewriter &,
414                                            Location, ValueRange)) {
415   auto original_ty =
416       getElementTypeOrSelf(args.front().getType()).cast<FloatType>();
417   bool needs_upcast = original_ty.getWidth() < min_precision_ty.getWidth();
418 
419   // Upcast arguments if necessary.
420   llvm::SmallVector<Value, 2> casted_args;
421   if (needs_upcast) {
422     for (Value a : args) {
423       casted_args.push_back(
424           rewriter.create<mhlo::ConvertOp>(loc, a, min_precision_ty));
425     }
426     args = casted_args;
427   }
428 
429   Value result = callback(rewriter, loc, args);
430 
431   // Cast back if necessary.
432   if (needs_upcast) {
433     result = rewriter.create<mhlo::ConvertOp>(loc, result, original_ty);
434   }
435 
436   return result;
437 }
438 
439 struct ConvertErfOp : public OpConversionPattern<ErfOp> {
440   using OpConversionPattern<ErfOp>::OpConversionPattern;
matchAndRewritemlir::chlo::__anonca6190260111::ConvertErfOp441   LogicalResult matchAndRewrite(
442       ErfOp op, ArrayRef<Value> operands,
443       ConversionPatternRewriter &rewriter) const override {
444     Location loc = op.getLoc();
445     ErfOp::Adaptor transformed(operands);
446     Value x = transformed.operand();
447     Type ty = x.getType().cast<ShapedType>().getElementType();
448 
449     // For now, we support only f64, f32, and f16.
450     if (!ty.isF64() && !ty.isF32() && !ty.isF16()) return failure();
451 
452     if (ty.isF64()) {
453       rewriter.replaceOp(op, MaterializeErfApproximationF64(rewriter, loc, x));
454       return success();
455     }
456 
457     rewriter.replaceOp(op, MaterializeWithUpcast(
458                                rewriter, loc, operands, rewriter.getF32Type(),
459                                &MaterializeErfApproximationF32));
460     return success();
461   }
462 };
463 
464 struct ConvertErfcOp : public OpConversionPattern<ErfcOp> {
465   using OpConversionPattern<ErfcOp>::OpConversionPattern;
matchAndRewritemlir::chlo::__anonca6190260111::ConvertErfcOp466   LogicalResult matchAndRewrite(
467       ErfcOp op, ArrayRef<Value> operands,
468       ConversionPatternRewriter &rewriter) const override {
469     Location loc = op.getLoc();
470     ErfcOp::Adaptor transformed(operands);
471     Value x = transformed.operand();
472     Type ty = x.getType().cast<ShapedType>().getElementType();
473 
474     // For now, we support only f64, f32, and f16.
475     if (!ty.isF64() && !ty.isF32() && !ty.isF16()) return failure();
476 
477     if (ty.isF64()) {
478       rewriter.replaceOp(op, MaterializeErfcApproximationF64(rewriter, loc, x));
479       return success();
480     }
481 
482     rewriter.replaceOp(op, MaterializeWithUpcast(
483                                rewriter, loc, operands, rewriter.getF32Type(),
484                                &MaterializeErfcApproximationF32));
485     return success();
486   }
487 };
488 
// Coefficients for the Lanczos approximation of the gamma function. The
// coefficients are uniquely determined by the choice of g and n (kLanczosGamma
// and kLanczosCoefficients.size() + 1). The coefficients below correspond to
// [7, 9]. [5, 7], [7, 9], [9, 10], and [607/128.0, 15] were evaluated and
// [7, 9] seemed to be the least sensitive to the quality of the log function.
// In particular, [5, 7] is the only choice where -1.5e-5 <= lgamma(2) <= 1.5e-5
// for a particularly inaccurate log function.
constexpr double kLanczosGamma = 7;  // aka g
// c_0 of the Lanczos series a(z); the remaining coefficients follow below.
constexpr double kBaseLanczosCoeff = 0.99999999999980993227684700473478;
// Lanczos series coefficients c_1 ... c_8 for g = 7, n = 9.
constexpr std::array<double, 8> kLanczosCoefficients = {
    676.520368121885098567009190444019, -1259.13921672240287047156078755283,
    771.3234287776530788486528258894,   -176.61502916214059906584551354,
    12.507343278686904814458936853,     -0.13857109526572011689554707,
    9.984369578019570859563e-6,         1.50563273514931155834e-7};
503 
// Compute the Lgamma function using Lanczos' approximation from "A Precision
// Approximation of the Gamma Function". SIAM Journal on Numerical Analysis
// series B. Vol. 1:
//   lgamma(z + 1) = (log(2) + log(pi)) / 2
//                     + (z + 1/2) * log(t(z))
//                     - t(z) + log(a(z))
//   with   t(z) = z + kLanczosGamma + 1/2
//          a(z) = kBaseLanczosCoeff
//                   + sum(k = 1, n, kLanczosCoefficients[i] / (z + k))
Value MaterializeLgamma(ConversionPatternRewriter &rewriter, Location loc,
                        ValueRange args) {
  // If the input is less than 0.5 use Euler's reflection formula.
  //   gamma(x) = pi / (sin(pi * x) * gamma(1 - x))
  // Let z be
  //   z = -x      if x < 1/2
  //   z = x - 1   otherwise
  Value x = args.front();
  const StringAttr kLT = rewriter.getStringAttr(
      mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT));
  Value half = getConstantLike(rewriter, loc, 0.5, x);
  Value need_to_reflect = rewriter.create<mhlo::CompareOp>(loc, x, half, kLT);
  Value neg_x = rewriter.create<mhlo::NegOp>(loc, x);
  Value one = getConstantLike(rewriter, loc, 1, x);
  Value x_sub_one = rewriter.create<mhlo::SubOp>(loc, x, one);
  Value z =
      rewriter.create<mhlo::SelectOp>(loc, need_to_reflect, neg_x, x_sub_one);

  // Materialize
  //   a(z) = kBaseLanczosCoeff
  //            + sum(k = 1, n, kLanczosCoefficients[i] / (z + k))
  Value a = getConstantLike(rewriter, loc, kBaseLanczosCoeff, x);
  for (int i = 0, end = kLanczosCoefficients.size(); i < end; ++i) {
    Value coeff = getConstantLike(rewriter, loc, kLanczosCoefficients[i], x);
    Value one_based_index = getConstantLike(rewriter, loc, i + 1, x);
    Value quotient = rewriter.create<mhlo::DivOp>(
        loc, coeff, rewriter.create<mhlo::AddOp>(loc, z, one_based_index));
    a = rewriter.create<mhlo::AddOp>(loc, a, quotient);
  }

  // To improve accuracy on platforms with less-precise log implementations,
  // compute log(kLanczosGamma + 1/2) at compile time and use log1p on the
  // device.
  // Materialize as
  //   log(t) = log(kLanczosGamma + 1/2 + z)
  //          = log(kLanczosGamma + 1/2) + log1p(z / (kLanczosGamma + 1/2)).
  Value lanczos_plus_half =
      getConstantLike(rewriter, loc, kLanczosGamma + 0.5, x);
  Value t = rewriter.create<mhlo::AddOp>(loc, lanczos_plus_half, z);
  Value log_term =
      getConstantLike(rewriter, loc, std::log(kLanczosGamma + 0.5), x);
  Value log1p_term = rewriter.create<mhlo::Log1pOp>(
      loc, rewriter.create<mhlo::DivOp>(loc, z, lanczos_plus_half));
  Value log_t = rewriter.create<mhlo::AddOp>(loc, log_term, log1p_term);

  // Note that t(z) may be large and we need to be careful not to overflow to
  // infinity in the relevant term
  //   r = (z + 1/2) * log(t(z)) - t(z).
  // Therefore, we compute this as
  //   r = (z + 1/2 - t(z) / log(t(z))) * log(t(z)).
  Value t_div_log_t = rewriter.create<mhlo::DivOp>(loc, t, log_t);
  Value sum = rewriter.create<mhlo::SubOp>(
      loc, rewriter.create<mhlo::AddOp>(loc, z, half), t_div_log_t);
  Value r = rewriter.create<mhlo::MulOp>(loc, sum, log_t);

  // Compute the final result (modulo reflection) as
  //   lgamma(z + 1) = (log(2) + log(pi)) / 2 + r + log(a(z)).
  Value log_a = rewriter.create<mhlo::LogOp>(loc, a);
  Value lgamma = rewriter.create<mhlo::AddOp>(
      loc,
      rewriter.create<mhlo::AddOp>(
          loc,
          getConstantLike(rewriter, loc, (std::log(2) + std::log(M_PI)) / 2, x),
          r),
      log_a);

  // Compute the reflected value for x < 0.5 as
  //   lgamma(x) = log(pi) - lgamma(1-x) - log(abs(sin(pi * x))).
  //
  // The abs is needed because lgamma is the log of the absolute value of the
  // gamma function.
  //
  // We have to be careful when computing the final term above. gamma(x) goes
  // to +/-inf at every integer x < 0, and this is controlled by the sin(pi * x)
  // term. The slope is large, so precision is particularly important.
  //
  // Because abs(sin(pi * x)) has period of 1 we can equivalently use
  // abs(sin(pi * frac(x))) where frac(x) is the fractional part of x. This is
  // more numerically accurate: It doesn't overflow to inf like pi * x would and
  // if x is an integer it evaluates to exactly 0 which is important because we
  // then take the log of this value, and log(0) is inf.
  //
  // We don't have a frac(x) primitive in HLO and computing it is tricky, but
  // because abs(sin(pi * x)) = abs(sin(pi * abs(x))), it's good enough for our
  // purposes to use abs(frac(x)) = abs(x) - floor(abs(x)).
  //
  // Furthermore, pi * abs(frac(x)) loses precision when abs(frac(x)) is close
  // to 1. To remedy this, we can use the fact that sin(pi * x) in the domain
  // [0, 1] is symmetric across the line Y=0.5.
  //

  // Convert values of abs_frac > 0.5 to (1 - abs_frac) to improve precision of
  // pi * abs_frac for values of abs_frac close to 1.
  Value abs = rewriter.create<mhlo::AbsOp>(loc, x);
  Value abs_frac = rewriter.create<mhlo::SubOp>(
      loc, abs, rewriter.create<mhlo::FloorOp>(loc, abs));
  Value reduce_abs_frac =
      rewriter.create<mhlo::CompareOp>(loc, half, abs_frac, kLT);
  abs_frac = rewriter.create<mhlo::SelectOp>(
      loc, reduce_abs_frac, rewriter.create<mhlo::SubOp>(loc, one, abs_frac),
      abs_frac);

  // Materialize reflection.
  Value reflection_denom = rewriter.create<mhlo::LogOp>(
      loc,
      rewriter.create<mhlo::SinOp>(
          loc, rewriter.create<mhlo::MulOp>(
                   loc, getConstantLike(rewriter, loc, M_PI, x), abs_frac)));
  Value lgamma_reflection = rewriter.create<mhlo::SubOp>(
      loc,
      rewriter.create<mhlo::SubOp>(
          loc, getConstantLike(rewriter, loc, std::log(M_PI), x),
          reflection_denom),
      lgamma);

  // Avoid computing -inf - inf, which is nan. If reflection_denom is +/-inf,
  // then it "wins" and the result is +/-inf.
  Value finite_reflection_denom =
      rewriter.create<mhlo::IsFiniteOp>(loc, reflection_denom);
  Value neg_reflection_denom =
      rewriter.create<mhlo::NegOp>(loc, reflection_denom);
  lgamma_reflection = rewriter.create<mhlo::SelectOp>(
      loc, finite_reflection_denom, lgamma_reflection, neg_reflection_denom);

  // Select whether or not to rely on the reflection.
  lgamma = rewriter.create<mhlo::SelectOp>(loc, need_to_reflect,
                                           lgamma_reflection, lgamma);

  // Materialize +/-inf behavior as
  //   lgamma(+/-inf) = +inf.
  Value x_is_inf = rewriter.create<chlo::IsInfOp>(loc, x);
  return rewriter.create<mhlo::SelectOp>(
      loc, x_is_inf,
      chlo::getConstantLikeInfValue(rewriter, loc, x, /*negative=*/false),
      lgamma);
}
649 
650 // Compute the Digamma function using Lanczos' approximation from "A Precision
651 // Approximation of the Gamma Function". SIAM Journal on Numerical Analysis
652 // series B. Vol. 1:
653 //   digamma(z + 1) = log(t(z)) + a'(z) / a(z) - kLanczosGamma / t(z)
654 //   with   t(z) = z + kLanczosGamma + 1/2
//          a(z) = kBaseLanczosCoeff
//                   + sum(k = 1, n, kLanczosCoefficients[k - 1] / (z + k))
//          a'(z) = - sum(k = 1, n,
//                        kLanczosCoefficients[k - 1] / (z + k) / (z + k))
Value MaterializeDigamma(ConversionPatternRewriter &rewriter, Location loc,
                         ValueRange args) {
  // If the input is less than 0.5 use Euler's reflection formula.
  //   digamma(x) = digamma(1 - x) - pi * cot(pi * x)
  // Let z be
  //   z = -x      if x < 1/2
  //   z = x - 1   otherwise
  Value x = args.front();
  const StringAttr kLT = rewriter.getStringAttr(
      mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT));
  Value half = getConstantLike(rewriter, loc, 0.5, x);
  Value need_to_reflect = rewriter.create<mhlo::CompareOp>(loc, x, half, kLT);
  Value neg_x = rewriter.create<mhlo::NegOp>(loc, x);
  Value one = getConstantLike(rewriter, loc, 1, x);
  Value x_sub_one = rewriter.create<mhlo::SubOp>(loc, x, one);
  Value z =
      rewriter.create<mhlo::SelectOp>(loc, need_to_reflect, neg_x, x_sub_one);

  // Materialize
  //   a(z) = kBaseLanczosCoeff
  //            + sum(k = 1, n, kLanczosCoefficients[k - 1] / (z + k))
  //   a'(z) = - sum(k = 1, n, kLanczosCoefficients[k - 1] / (z + k) / (z + k))
  Value zero = getConstantLike(rewriter, loc, 0.0, x);
  Value a = getConstantLike(rewriter, loc, kBaseLanczosCoeff, x);
  Value a_prime = zero;
  for (int i = 0, end = kLanczosCoefficients.size(); i < end; ++i) {
    Value coeff = getConstantLike(rewriter, loc, kLanczosCoefficients[i], x);
    // k = i + 1 is the one-based summation index in the formulas above.
    Value one_based_index = getConstantLike(rewriter, loc, i + 1, x);
    Value z_term = rewriter.create<mhlo::AddOp>(loc, z, one_based_index);
    // a' -= coeff / (z + k)^2
    a_prime = rewriter.create<mhlo::SubOp>(
        loc, a_prime,
        rewriter.create<mhlo::DivOp>(
            loc, coeff, rewriter.create<mhlo::MulOp>(loc, z_term, z_term)));
    // a += coeff / (z + k)
    a = rewriter.create<mhlo::AddOp>(
        loc, a, rewriter.create<mhlo::DivOp>(loc, coeff, z_term));
  }

  // To improve accuracy on platforms with less-precise log implementations,
  // compute log(kLanczosGamma + 1/2) at compile time and use log1p on the
  // device.
  // Materialize as
  //   log(t) = log(kLanczosGamma + 1/2 + z)
  //          = log(kLanczosGamma + 1/2) + log1p(z / (kLanczosGamma + 1/2)).
  Value lanczos_plus_half =
      getConstantLike(rewriter, loc, kLanczosGamma + 0.5, x);
  Value t = rewriter.create<mhlo::AddOp>(loc, lanczos_plus_half, z);
  Value log_term =
      getConstantLike(rewriter, loc, std::log(kLanczosGamma + 0.5), x);
  Value log1p_term = rewriter.create<mhlo::Log1pOp>(
      loc, rewriter.create<mhlo::DivOp>(loc, z, lanczos_plus_half));
  Value log_t = rewriter.create<mhlo::AddOp>(loc, log_term, log1p_term);

  // Materialize the final result (modulo reflection) as
  //   digamma(z + 1) = log(t(z)) + a'(z) / a(z) - kLanczosGamma / t(z).
  Value a_prime_div_a = rewriter.create<mhlo::DivOp>(loc, a_prime, a);
  Value lanczos_gamma_div_t = rewriter.create<mhlo::DivOp>(
      loc, getConstantLike(rewriter, loc, kLanczosGamma, x), t);
  Value digamma = rewriter.create<mhlo::SubOp>(
      loc, rewriter.create<mhlo::AddOp>(loc, log_t, a_prime_div_a),
      lanczos_gamma_div_t);

  // We need to be careful how we compute cot(pi * input) below: For
  // near-integral arguments, pi * input can lose precision.
  //
  // Input is already known to be less than 0.5 (otherwise we don't have to
  // reflect). We shift values smaller than -0.5 into the range [-0.5, 0.5] to
  // increase precision of pi * x and the resulting cotangent.
  //   reduced_x = x + abs(floor(x + 0.5))
  Value reduced_x = rewriter.create<mhlo::AddOp>(
      loc, x,
      rewriter.create<mhlo::AbsOp>(
          loc, rewriter.create<mhlo::FloorOp>(
                   loc, rewriter.create<mhlo::AddOp>(
                            loc, x, getConstantLike(rewriter, loc, 0.5, x)))));

  // Materialize reflection for inputs less than 0.5 as
  //   digamma(x) = digamma(1 - x) - pi * cot(pi * x)
  //              = digamma(1 - x) - pi * cos(pi * x) / sin(pi * x)
  Value pi = getConstantLike(rewriter, loc, M_PI, x);
  Value pi_mul_reduced_x = rewriter.create<mhlo::MulOp>(loc, pi, reduced_x);
  Value cos = rewriter.create<mhlo::CosOp>(loc, pi_mul_reduced_x);
  Value sin = rewriter.create<mhlo::SinOp>(loc, pi_mul_reduced_x);
  Value reflection = rewriter.create<mhlo::SubOp>(
      loc, digamma,
      rewriter.create<mhlo::DivOp>(
          loc, rewriter.create<mhlo::MulOp>(loc, pi, cos), sin));

  // Select whether or not to rely on the reflection.
  digamma = rewriter.create<mhlo::SelectOp>(loc, need_to_reflect, reflection,
                                            digamma);

  // Digamma has poles at negative integers and zero; return nan for those.
  const StringAttr kLE = rewriter.getStringAttr(
      mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LE));
  Value is_le_zero = rewriter.create<mhlo::CompareOp>(loc, x, zero, kLE);
  const StringAttr kEQ = rewriter.getStringAttr(
      mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::EQ));
  Value is_int = rewriter.create<mhlo::CompareOp>(
      loc, x, rewriter.create<mhlo::FloorOp>(loc, x), kEQ);
  Value is_pole = rewriter.create<mhlo::AndOp>(loc, is_le_zero, is_int);
  return rewriter.create<mhlo::SelectOp>(
      loc, is_pole,
      getConstantLike(rewriter, loc, std::numeric_limits<double>::quiet_NaN(),
                      x),
      digamma);
}
763 
// Materializes an approximation of the Hurwitz zeta function
//   zeta(x, q) = sum(k = 0, inf, (q + k)^(-x))
// as an initial partial sum plus an Euler-Maclaurin correction term.
// Special cases handled at the end: x == 1 (harmonic series -> inf), x < 1
// (undefined -> NaN), non-positive integer q (pole -> inf), and non-positive
// q with non-integral x (-> NaN).
Value MaterializeZeta(ConversionPatternRewriter &rewriter, Location loc,
                      ValueRange args) {
  assert(args.size() == 2);
  Value x = args[0];
  Value q = args[1];
  // Only the reciprocals 1 / kZetaCoeffs[i] are used below. These constants
  // presumably mirror the Euler-Maclaurin coefficients of Cephes' zeta (see
  // the Horner's-rule comment further down) -- TODO confirm against Cephes.
  static const std::array<double, 12> kZetaCoeffs{
      -7.1661652561756670113e18,
      1.8152105401943546773e17,
      -4.5979787224074726105e15,
      1.1646782814350067249e14,
      -2.950130727918164224e12,
      7.47242496e10,
      -1.8924375803183791606e9,
      47900160.0,
      -1209600.0,
      30240.0,
      -720.0,
      12.0,
  };

  // For speed we'll always use 9 iterations for the initial series estimate,
  // and a 12 term expansion for the Euler-Maclaurin formula.
  // After the loop: initial_sum = sum(k = 0, 9, (q + k)^(-x)) and a = q + 9.
  Value a = q;
  Value zero_like_a = chlo::getConstantLike(rewriter, loc, 0.0, a);
  Value neg_power = zero_like_a;
  Value neg_x = rewriter.create<mhlo::NegOp>(loc, x);
  Value initial_sum = rewriter.create<mhlo::PowOp>(loc, q, neg_x);
  Value one_like_a = chlo::getConstantLike(rewriter, loc, 1.0, a);
  for (int i = 0; i < 9; ++i) {
    a = rewriter.create<mhlo::AddOp>(loc, a, one_like_a);
    neg_power = rewriter.create<mhlo::PowOp>(loc, a, neg_x);
    initial_sum = rewriter.create<mhlo::AddOp>(loc, initial_sum, neg_power);
  }
  // a = q + 10; estimate the series tail by the integral term
  //   a^(1 - x) / (x - 1).
  a = rewriter.create<mhlo::AddOp>(loc, a, one_like_a);
  neg_power = rewriter.create<mhlo::PowOp>(loc, a, neg_x);
  Value one_like_x = chlo::getConstantLike(rewriter, loc, 1.0, x);
  Value x_minus_one = rewriter.create<mhlo::SubOp>(loc, x, one_like_x);
  Value neg_power_mul_a = rewriter.create<mhlo::MulOp>(loc, neg_power, a);
  Value neg_power_mul_a_div_x_minus_one =
      rewriter.create<mhlo::DivOp>(loc, neg_power_mul_a, x_minus_one);
  Value s = rewriter.create<mhlo::AddOp>(loc, initial_sum,
                                         neg_power_mul_a_div_x_minus_one);
  Value a_inverse_square = rewriter.create<mhlo::DivOp>(
      loc, one_like_a, rewriter.create<mhlo::MulOp>(loc, a, a));

  Value horner_sum = zero_like_a;
  Value factor = one_like_a;
  // Use Horner's rule for this.
  // Note this differs from Cephes which does a 'naive' polynomial evaluation.
  // Using Horner's rule allows to avoid some NaN's and Infs from happening,
  // resulting in more numerically stable code.
  for (int i = 0; i < 11; ++i) {
    // factor = (x - (22 - 2i)) * (x - (21 - 2i)), the next pair of factors of
    // the rising product in the Euler-Maclaurin derivative terms.
    Value factor_lhs = rewriter.create<mhlo::SubOp>(
        loc, x, chlo::getConstantLike(rewriter, loc, 22 - 2 * i, x));
    Value factor_rhs = rewriter.create<mhlo::SubOp>(
        loc, x, chlo::getConstantLike(rewriter, loc, 21 - 2 * i, x));
    factor = rewriter.create<mhlo::MulOp>(loc, factor_lhs, factor_rhs);
    horner_sum = rewriter.create<mhlo::MulOp>(
        loc, factor,
        rewriter.create<mhlo::MulOp>(
            loc, a_inverse_square,
            rewriter.create<mhlo::AddOp>(
                loc, horner_sum,
                chlo::getConstantLike(rewriter, loc, 1. / kZetaCoeffs[i], a))));
  }
  // s += a^(-x) * (0.5 + (x / a) * (1 / kZetaCoeffs[11] + horner_sum))
  Value zero_point_five_like_neg_power =
      chlo::getConstantLike(rewriter, loc, .5, neg_power);
  Value x_div_a = rewriter.create<mhlo::DivOp>(loc, x, a);
  s = rewriter.create<mhlo::AddOp>(
      loc, s,
      rewriter.create<mhlo::MulOp>(
          loc, neg_power,
          rewriter.create<mhlo::AddOp>(
              loc, zero_point_five_like_neg_power,
              rewriter.create<mhlo::MulOp>(
                  loc, x_div_a,
                  rewriter.create<mhlo::AddOp>(
                      loc,
                      chlo::getConstantLike(rewriter, loc, 1. / kZetaCoeffs[11],
                                            a),
                      horner_sum)))));
  const double nan = std::numeric_limits<double>::quiet_NaN();
  const double inf = std::numeric_limits<double>::infinity();
  // Use the initial zeta sum without the correction term coming
  // from Euler-Maclaurin if it is accurate enough, i.e. if the last summand
  // is already negligible relative to the partial sum.
  const StringAttr kLT = rewriter.getStringAttr(
      mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT));
  Value abs_neg_power = rewriter.create<mhlo::AbsOp>(loc, neg_power);
  Value abs_initial_sum = rewriter.create<mhlo::AbsOp>(loc, initial_sum);
  Value output = rewriter.create<mhlo::SelectOp>(
      loc,
      rewriter.create<mhlo::CompareOp>(
          loc, abs_neg_power,
          rewriter.create<mhlo::MulOp>(
              loc, abs_initial_sum,
              chlo::getConstantLikeSmallestFiniteValue(rewriter, loc, a)),
          kLT),
      initial_sum, s);
  // x == 1 is the harmonic series, which diverges to +inf.
  const StringAttr kEQ = rewriter.getStringAttr(
      mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::EQ));
  Value inf_like_x = chlo::getConstantLike(rewriter, loc, inf, x);
  output = rewriter.create<mhlo::SelectOp>(
      loc, rewriter.create<mhlo::CompareOp>(loc, x, one_like_x, kEQ),
      inf_like_x, output);
  // Function is not defined for x < 1.
  Value nan_like_x = chlo::getConstantLike(rewriter, loc, nan, x);
  output = rewriter.create<mhlo::SelectOp>(
      loc, rewriter.create<mhlo::CompareOp>(loc, x, one_like_x, kLT),
      nan_like_x, output);
  // If q <= 0, then when q is an integer or x is not an integer, this is
  // NaN.
  const StringAttr kLE = rewriter.getStringAttr(
      mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LE));
  const StringAttr kNE = rewriter.getStringAttr(
      mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::NE));
  Value zero_like_q = chlo::getConstantLike(rewriter, loc, 0.0, q);
  Value q_le_zero = rewriter.create<mhlo::CompareOp>(loc, q, zero_like_q, kLE);
  // q <= 0 with non-integral x: undefined.
  Value domain_error = rewriter.create<mhlo::AndOp>(
      loc, q_le_zero,
      rewriter.create<mhlo::CompareOp>(
          loc, x, rewriter.create<mhlo::FloorOp>(loc, x), kNE));
  // Non-positive integer q: pole of the function.
  Value negative_integer_q = rewriter.create<mhlo::AndOp>(
      loc, q_le_zero,
      rewriter.create<mhlo::CompareOp>(
          loc, q, rewriter.create<mhlo::FloorOp>(loc, q), kEQ));
  output = rewriter.create<mhlo::SelectOp>(loc, negative_integer_q, inf_like_x,
                                           output);
  output =
      rewriter.create<mhlo::SelectOp>(loc, domain_error, nan_like_x, output);
  return output;
}
896 
// Materializes polygamma(n, x) as
//   (-1)^(n + 1) * Gamma(n + 1) * zeta(n + 1, x)
// for natural n > 0, as digamma(x) for n == 0, and as NaN when n is not a
// natural number (n is a floating-point operand here and may be fractional
// or negative).
Value MaterializePolygamma(ConversionPatternRewriter &rewriter, Location loc,
                           ValueRange args) {
  PolygammaOp::Adaptor transformed(args);
  Value n = transformed.n();
  Value x = transformed.x();

  // Handle integer n > 0.
  Value one = getConstantLike(rewriter, loc, 1.0, x);
  Value two = getConstantLike(rewriter, loc, 2.0, x);
  // sign = 2 * rem(n, 2) - 1, i.e. (-1)^(n + 1) for integral n.
  Value sign = rewriter.create<mhlo::SubOp>(
      loc,
      rewriter.create<mhlo::MulOp>(loc, two,
                                   rewriter.create<mhlo::RemOp>(loc, n, two)),
      one);
  Value n_plus_one = rewriter.create<mhlo::AddOp>(loc, n, one);
  // exp(lgamma(n + 1)) = Gamma(n + 1) = n! for natural n.
  Value exp_lgamma_np1 = rewriter.create<mhlo::ExpOp>(
      loc, rewriter.create<chlo::LgammaOp>(loc, n_plus_one));
  Value zeta = rewriter.create<chlo::ZetaOp>(loc, n_plus_one, x);
  Value result = rewriter.create<mhlo::MulOp>(
      loc, rewriter.create<mhlo::MulOp>(loc, sign, exp_lgamma_np1), zeta);

  // Handle n = 0: polygamma(0, x) = digamma(x).
  const StringAttr kEQ = rewriter.getStringAttr(
      mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::EQ));
  Value zero = getConstantLike(rewriter, loc, 0.0, x);
  Value n_eq_zero = rewriter.create<mhlo::CompareOp>(loc, n, zero, kEQ);
  result = rewriter.create<mhlo::SelectOp>(
      loc, n_eq_zero, rewriter.create<chlo::DigammaOp>(loc, x), result);

  // Check that n is a natural number. Return NaN otherwise.
  const StringAttr kNE = rewriter.getStringAttr(
      mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::NE));
  Value non_int = rewriter.create<mhlo::CompareOp>(
      loc, n, rewriter.create<mhlo::FloorOp>(loc, n), kNE);
  const StringAttr kLT = rewriter.getStringAttr(
      mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT));
  Value negative = rewriter.create<mhlo::CompareOp>(loc, n, zero, kLT);
  Value non_natural = rewriter.create<mhlo::OrOp>(loc, non_int, negative);
  return rewriter.create<mhlo::SelectOp>(
      loc, non_natural,
      getConstantLike(rewriter, loc, std::numeric_limits<double>::quiet_NaN(),
                      x),
      result);
}
941 
942 struct ConvertLgammaOp : public OpConversionPattern<LgammaOp> {
943   using OpConversionPattern<LgammaOp>::OpConversionPattern;
matchAndRewritemlir::chlo::__anonca6190260111::ConvertLgammaOp944   LogicalResult matchAndRewrite(
945       LgammaOp op, ArrayRef<Value> operands,
946       ConversionPatternRewriter &rewriter) const override {
947     FloatType min_precision_ty = rewriter.getF32Type();
948     rewriter.replaceOp(
949         op, MaterializeWithUpcast(rewriter, op.getLoc(), operands,
950                                   min_precision_ty, &MaterializeLgamma));
951     return success();
952   }
953 };
954 
955 struct ConvertDigammaOp : public OpConversionPattern<DigammaOp> {
956   using OpConversionPattern<DigammaOp>::OpConversionPattern;
matchAndRewritemlir::chlo::__anonca6190260111::ConvertDigammaOp957   LogicalResult matchAndRewrite(
958       DigammaOp op, ArrayRef<Value> operands,
959       ConversionPatternRewriter &rewriter) const override {
960     FloatType min_precision_ty = rewriter.getF32Type();
961     rewriter.replaceOp(
962         op, MaterializeWithUpcast(rewriter, op.getLoc(), operands,
963                                   min_precision_ty, &MaterializeDigamma));
964     return success();
965   }
966 };
967 
968 struct ConvertPolygammaOp : public OpConversionPattern<PolygammaOp> {
969   using OpConversionPattern<PolygammaOp>::OpConversionPattern;
matchAndRewritemlir::chlo::__anonca6190260111::ConvertPolygammaOp970   LogicalResult matchAndRewrite(
971       PolygammaOp op, ArrayRef<Value> operands,
972       ConversionPatternRewriter &rewriter) const override {
973     Location loc = op.getLoc();
974     FloatType min_precision_ty = rewriter.getF32Type();
975     rewriter.replaceOp(
976         op, MaterializeWithUpcast(rewriter, loc, operands, min_precision_ty,
977                                   &MaterializePolygamma));
978     return success();
979   }
980 };
981 
982 struct ConvertZetaOp : public OpConversionPattern<ZetaOp> {
983   using OpConversionPattern<ZetaOp>::OpConversionPattern;
matchAndRewritemlir::chlo::__anonca6190260111::ConvertZetaOp984   LogicalResult matchAndRewrite(
985       ZetaOp op, ArrayRef<Value> operands,
986       ConversionPatternRewriter &rewriter) const override {
987     Location loc = op.getLoc();
988     FloatType min_precision_ty = rewriter.getF32Type();
989     rewriter.replaceOp(
990         op, MaterializeWithUpcast(rewriter, loc, operands, min_precision_ty,
991                                   &MaterializeZeta));
992     return success();
993   }
994 };
995 
996 // Converts binary ops that statically are determined to not broadcast directly
997 // to the corresponding mhlo non-broadcasting op.
998 template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
999 struct ConvertTrivialNonBroadcastBinaryOp
1000     : public OpConversionPattern<ChloOpTy> {
1001   using OpConversionPattern<ChloOpTy>::OpConversionPattern;
matchAndRewritemlir::chlo::__anonca6190260111::ConvertTrivialNonBroadcastBinaryOp1002   LogicalResult matchAndRewrite(
1003       ChloOpTy op, ArrayRef<Value> operands,
1004       ConversionPatternRewriter &rewriter) const override {
1005     // Only rewrite for statically determinable non-broadcasting cases.
1006     typename ChloOpTy::Adaptor transformed(operands);
1007     auto lhs_type =
1008         transformed.lhs().getType().template dyn_cast<RankedTensorType>();
1009     auto rhs_type =
1010         transformed.rhs().getType().template dyn_cast<RankedTensorType>();
1011     if (!lhs_type || !rhs_type) return failure();
1012 
1013     // Requires rank broadcast.
1014     if (lhs_type.getRank() != rhs_type.getRank()) return failure();
1015     // Any dynamic dimension may require broadcasting and requires more
1016     // analysis.
1017     if (!lhs_type.hasStaticShape() || !rhs_type.hasStaticShape())
1018       return failure();
1019 
1020     for (auto extents : llvm::zip(lhs_type.getShape(), rhs_type.getShape())) {
1021       auto lhs_extent = std::get<0>(extents);
1022       auto rhs_extent = std::get<1>(extents);
1023       if (lhs_extent != rhs_extent) {
1024         return failure();
1025       }
1026     }
1027 
1028     rewriter.replaceOp(
1029         op, {Adaptor::CreateOp(op, op.getResult().getType(), operands[0],
1030                                operands[1], rewriter)});
1031     return success();
1032   }
1033 };
1034 
1035 // Converts a binary op with ranked broadcasting operands to explicitly
1036 // broadcast and invoke the corresponding mhlo non-broadcasting op.
1037 // Note that dynamic broadcasting supported by this pattern is only valid for
1038 // "numpy" broadcasting semantics as defined here:
1039 //   https://docs.scipy.org/doc/numpy/reference/ufuncs.html
1040 // Specifically, this includes the following cases:
1041 //   - Same rank broadcast (operands have the same static rank).
1042 //   - Different-rank broadcast, either without a broadcast_dims attribte or
1043 //     with the broadcast_dims attribute set to map to a prefix padding.
1044 //   - Legal combinations of degenerate (1-dim) implicit broadcasting.
1045 // The restriction on broadcast_dims derives from the definition of the
1046 // `shape.broadcast` op, which only supports prefix-padding.
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
struct ConvertRankedDynamicBroadcastBinaryOp
    : public OpConversionPattern<ChloOpTy> {
  using OpConversionPattern<ChloOpTy>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      ChloOpTy op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    // Only support ranked operands (and a ranked result).
    typename ChloOpTy::Adaptor transformed(operands);
    Value lhs = transformed.lhs();
    Value rhs = transformed.rhs();
    auto lhs_type = lhs.getType().dyn_cast<RankedTensorType>();
    auto rhs_type = rhs.getType().dyn_cast<RankedTensorType>();
    auto result_type =
        op.getResult().getType().template dyn_cast<RankedTensorType>();
    if (!lhs_type || !rhs_type || !result_type) return failure();

    // Check for "numpy"-style rank broadcast.
    auto broadcast_dimensions = op.broadcast_dimensions();
    if (broadcast_dimensions &&
        !hlo::IsLegalNumpyRankedBroadcast(lhs, rhs, *broadcast_dimensions)) {
      // Note: It is unclear whether the general specification of explicit
      // broadcast_dimensions on binary ops is a feature we want to carry
      // forward. While it can technically be implemented for ranked-dynamic,
      // it is incompatible with unranked inputs. If this warning is emitted
      // in real programs, it is an indication that the feature should be
      // implemented versus just falling back on the more standard definition
      // of numpy-like prefix-padding.
      op.emitWarning() << "unsupported non prefix-padded dynamic rank "
                       << "broadcast_dimensions = " << *broadcast_dimensions;
      return failure();
    }

    auto loc = op.getLoc();

    // Insert a constraint on the shapes being broadcastable and insert all
    // future code into an assuming block reliant on the constraint.
    Value lhs_shape = rewriter.create<shape::ShapeOfOp>(loc, lhs);
    Value rhs_shape = rewriter.create<shape::ShapeOfOp>(loc, rhs);
    auto broadcastable_cstr =
        rewriter.create<shape::CstrBroadcastableOp>(loc, lhs_shape, rhs_shape);
    auto assuming_op = rewriter.create<shape::AssumingOp>(
        loc, ArrayRef<Type>{result_type}, broadcastable_cstr.result());

    // Everything created below goes into the assuming op's region; the guard
    // restores the previous insertion point when this method returns.
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.createBlock(&assuming_op.doRegion());

    // Compute the result's extent tensor (numpy-style broadcast shape).
    int64_t result_rank = std::max(lhs_type.getRank(), rhs_type.getRank());
    Value result_extents =
        hlo::ComputeBinaryElementwiseBroadcastingResultExtents(
            loc, lhs, rhs, rewriter, /*unsafe_as_extent_tensor=*/true);

    // Note that we unconditionally emit DynamicBroadcastInDim ops and let
    // downstream canonicalizations fold them away if possible. This is
    // because, in the dynamic case, there are many corner cases regarding
    // when it is safe to omit, and some of them require analysis to prove
    // properly.
    // Prefix padding: each operand's dimensions map onto the trailing
    // dimensions of the result.
    auto lhs_broadcast_dimensions = llvm::to_vector<4>(
        llvm::seq<int64_t>(result_rank - lhs_type.getRank(), result_rank));
    Value broadcasted_lhs = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
        loc,
        RankedTensorType::get(result_type.getShape(),
                              lhs_type.getElementType()),
        lhs, result_extents,
        rewriter.getI64TensorAttr(lhs_broadcast_dimensions));
    auto rhs_broadcast_dimensions = llvm::to_vector<4>(
        llvm::seq<int64_t>(result_rank - rhs_type.getRank(), result_rank));
    Value broadcasted_rhs = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
        loc,
        RankedTensorType::get(result_type.getShape(),
                              rhs_type.getElementType()),
        rhs, result_extents,
        rewriter.getI64TensorAttr(rhs_broadcast_dimensions));

    // And generate the final non-broadcasted binary op, yielding it from the
    // assuming region and replacing the chlo op with the assuming result.
    Value final_result = Adaptor::CreateOp(op, result_type, broadcasted_lhs,
                                           broadcasted_rhs, rewriter);
    rewriter.create<shape::AssumingYieldOp>(loc, final_result);
    rewriter.replaceOp(op, {assuming_op.getResult(0)});
    return success();
  }
};
1130 
1131 #include "generated_chlo_legalize_to_hlo.inc"
1132 }  // namespace
1133 
// Registers the patterns that lower broadcasting CHLO binary elementwise ops
// to (explicitly broadcasted) MHLO ops.
void PopulateChloBroadcastingPatterns(MLIRContext *context,
                                      OwningRewritePatternList *patterns) {
  // Instantiate conversion templates for conforming binary elementwise ops
  // that do not have different dtypes between operands and results and do
  // not have special attributes that need to be preserved.
  // NOTE(review): the trailing 10/5 arguments are presumably pattern
  // benefits, so the trivial non-broadcast lowering is preferred whenever it
  // applies -- confirm against PopulateForBroadcastingBinaryOp.
  PopulateForBroadcastingBinaryOp<ConvertTrivialNonBroadcastBinaryOp>(
      context, patterns, 10);
  PopulateForBroadcastingBinaryOp<ConvertRankedDynamicBroadcastBinaryOp>(
      context, patterns, 5);
}
1144 
// Registers all patterns needed to legalize the CHLO dialect to MHLO: the
// TableGen-erated rewrites, the broadcasting binary-op lowerings, and the
// special-function decompositions defined above.
void PopulateLegalizeChloToHloPatterns(MLIRContext *context,
                                       OwningRewritePatternList *patterns) {
  // Patterns from generated_chlo_legalize_to_hlo.inc (included above).
  populateWithGenerated(context, *patterns);
  PopulateChloBroadcastingPatterns(context, patterns);

  // Other patterns.
  // clang-format off
  patterns->insert<ConvertConstantLikeOp,
                   ConvertDigammaOp,
                   ConvertErfOp,
                   ConvertErfcOp,
                   ConvertLgammaOp,
                   ConvertPolygammaOp,
                   ConvertZetaOp>(context);
  // clang-format on
}
1161 
1162 }  // namespace chlo
1163 }  // namespace mlir
1164