1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This file implements the lowering for trigonometric standard ops to
17 // approximations.
18 
#include <array>
#include <cassert>
#include <cstddef>
#include <memory>
#include <utility>

#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
26 
27 namespace mlir {
28 namespace mhlo {
29 namespace {
30 
31 template <typename OpTy>
32 class ApproximateOnExtendedF32Lowering : public OpRewritePattern<OpTy> {
33  public:
ApproximateOnExtendedF32Lowering(MLIRContext * ctx)34   explicit ApproximateOnExtendedF32Lowering(MLIRContext *ctx)
35       : OpRewritePattern<OpTy>(ctx, /*benefit=*/100) {}
36 
37   virtual Value emitApproximation(ValueRange, Location,
38                                   PatternRewriter &) const = 0;
39 
matchAndRewrite(OpTy op,PatternRewriter & rewriter) const40   LogicalResult matchAndRewrite(OpTy op,
41                                 PatternRewriter &rewriter) const override {
42     Location loc = op.getLoc();
43     auto raw_args = op.getOperation()->getOperands();
44 
45     // Supports only f16 and f32 for now.
46     if (!op.getType().isF16() && !op.getType().isF32()) return failure();
47 
48     // Extend operands to f32 if needed and possible.
49     SmallVector<Value, 2> f32_args;
50     f32_args.reserve(raw_args.size());
51     for (Value arg : raw_args) {
52       // Similar to XLA, do not rewrite f64 as precision might matter.
53       Type arg_ty = arg.getType();
54       if (arg_ty.isF64()) return failure();
55 
56       if (arg_ty.isF16())
57         arg = rewriter.create<FPExtOp>(loc, arg, rewriter.getF32Type());
58 
59       // If we still do not have f32, fail.
60       if (!arg.getType().isF32()) return failure();
61 
62       f32_args.push_back(arg);
63     }
64 
65     Value result = emitApproximation(f32_args, loc, rewriter);
66     assert(result.getType().isF32() && "Expect f32 intermediate result.");
67 
68     // Truncate back if needed.
69     if (op.getType().isF16())
70       result = rewriter.create<FPTruncOp>(loc, result, rewriter.getF16Type());
71 
72     rewriter.replaceOp(op, {result});
73     return success();
74   }
75 };
76 
77 class ApproximateTanhLowering
78     : public ApproximateOnExtendedF32Lowering<math::TanhOp> {
79  public:
ApproximateTanhLowering(MLIRContext * ctx)80   explicit ApproximateTanhLowering(MLIRContext *ctx)
81       : ApproximateOnExtendedF32Lowering<math::TanhOp>(ctx) {}
82 
83   // Emits the fast tanh approximation that is also used by XLA.
emitApproximation(ValueRange args,Location loc,PatternRewriter & rewriter) const84   Value emitApproximation(ValueRange args, Location loc,
85                           PatternRewriter &rewriter) const override {
86     // For small values of x, we can approximate tanh(x) = x.  For extremely
87     // small values of x (|x| < 1e-37), the other approximation would evaluate
88     // tanh(x) = 0.
89     Value input = args.front();
90     assert(input.getType().isF32());
91     constexpr float kCanUseApprox = 0.0004;
92     Value abs_value = rewriter.create<AbsFOp>(loc, input);
93     Value can_use_approx = rewriter.create<ConstantOp>(
94         loc, rewriter.getF32FloatAttr(kCanUseApprox));
95     Value return_input = rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT,
96                                                  abs_value, can_use_approx);
97     // Clamp the input to [-c, c].
98     Value max_clamp = rewriter.create<ConstantOp>(
99         loc, rewriter.getF32FloatAttr(7.90531110763549805f));
100     Value smaller_than_max =
101         rewriter.create<CmpFOp>(loc, CmpFPredicate::ULE, input, max_clamp);
102     Value clamped_half =
103         rewriter.create<SelectOp>(loc, smaller_than_max, input, max_clamp);
104     Value min_clamp = rewriter.create<ConstantOp>(
105         loc, rewriter.getF32FloatAttr(-7.90531110763549805f));
106     Value larger_than_min = rewriter.create<CmpFOp>(loc, CmpFPredicate::UGE,
107                                                     clamped_half, min_clamp);
108     Value input_clamped = rewriter.create<SelectOp>(loc, larger_than_min,
109                                                     clamped_half, min_clamp);
110 
111     static constexpr std::array<float, 7> numerator_coeffs{
112         -2.76076847742355e-16f, 2.00018790482477e-13f, -8.60467152213735e-11f,
113         5.12229709037114e-08f,  1.48572235717979e-05f, 6.37261928875436e-04f,
114         4.89352455891786e-03f};
115 
116     static constexpr std::array<float, 4> denominator_coeffs{
117         1.19825839466702e-06f, 1.18534705686654e-04f, 2.26843463243900e-03f,
118         4.89352518554385e-03f};
119 
120     Value input_squared =
121         rewriter.create<MulFOp>(loc, input_clamped, input_clamped);
122     Value numerator = rewriter.create<ConstantOp>(
123         loc, rewriter.getF32FloatAttr(numerator_coeffs[0]));
124     for (int i = 1; i < numerator_coeffs.size(); i++) {
125       numerator = rewriter.create<AddFOp>(
126           loc, rewriter.create<MulFOp>(loc, input_squared, numerator),
127           rewriter.create<ConstantOp>(
128               loc, rewriter.getF32FloatAttr(numerator_coeffs[i])));
129     }
130 
131     numerator = rewriter.create<MulFOp>(loc, input_clamped, numerator);
132 
133     Value denominator = rewriter.create<ConstantOp>(
134         loc, rewriter.getF32FloatAttr(denominator_coeffs[0]));
135     for (int i = 1; i < denominator_coeffs.size(); i++) {
136       denominator = rewriter.create<AddFOp>(
137           loc, rewriter.create<MulFOp>(loc, input_squared, denominator),
138           rewriter.create<ConstantOp>(
139               loc, rewriter.getF32FloatAttr(denominator_coeffs[i])));
140     }
141 
142     Value approx = rewriter.create<DivFOp>(loc, numerator, denominator);
143 
144     return rewriter.create<SelectOp>(loc, return_input, input, approx);
145   }
146 };
147 
148 struct LegalizeTrigonometricToApproximationPass
149     : public PassWrapper<LegalizeTrigonometricToApproximationPass,
150                          FunctionPass> {
151   /// Perform the lowering of standard dialect operations to approximations.
runOnFunctionmlir::mhlo::__anonc050ebae0111::LegalizeTrigonometricToApproximationPass152   void runOnFunction() override {
153     OwningRewritePatternList patterns;
154     PopulateTrigonometricToApproximationPatterns(&getContext(), &patterns);
155     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
156   }
157 };
158 
159 }  // anonymous namespace
160 
161 std::unique_ptr<mlir::OperationPass<mlir::FuncOp>>
createLegalizeTrigonometricToApproximationPass()162 createLegalizeTrigonometricToApproximationPass() {
163   return std::make_unique<LegalizeTrigonometricToApproximationPass>();
164 }
165 
PopulateTrigonometricToApproximationPatterns(mlir::MLIRContext * context,OwningRewritePatternList * patterns)166 void PopulateTrigonometricToApproximationPatterns(
167     mlir::MLIRContext *context, OwningRewritePatternList *patterns) {
168   // clang-format off
169   patterns->insert<ApproximateTanhLowering>(context);
170   // clang-format on
171 }
172 
173 }  // namespace mhlo
174 }  // namespace mlir
175