1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // This file implements the lowering for trigonometric standard ops to
17 // approximations.
18
#include <array>
#include <cassert>

#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
26
27 namespace mlir {
28 namespace mhlo {
29 namespace {
30
31 template <typename OpTy>
32 class ApproximateOnExtendedF32Lowering : public OpRewritePattern<OpTy> {
33 public:
ApproximateOnExtendedF32Lowering(MLIRContext * ctx)34 explicit ApproximateOnExtendedF32Lowering(MLIRContext *ctx)
35 : OpRewritePattern<OpTy>(ctx, /*benefit=*/100) {}
36
37 virtual Value emitApproximation(ValueRange, Location,
38 PatternRewriter &) const = 0;
39
matchAndRewrite(OpTy op,PatternRewriter & rewriter) const40 LogicalResult matchAndRewrite(OpTy op,
41 PatternRewriter &rewriter) const override {
42 Location loc = op.getLoc();
43 auto raw_args = op.getOperation()->getOperands();
44
45 // Supports only f16 and f32 for now.
46 if (!op.getType().isF16() && !op.getType().isF32()) return failure();
47
48 // Extend operands to f32 if needed and possible.
49 SmallVector<Value, 2> f32_args;
50 f32_args.reserve(raw_args.size());
51 for (Value arg : raw_args) {
52 // Similar to XLA, do not rewrite f64 as precision might matter.
53 Type arg_ty = arg.getType();
54 if (arg_ty.isF64()) return failure();
55
56 if (arg_ty.isF16())
57 arg = rewriter.create<FPExtOp>(loc, arg, rewriter.getF32Type());
58
59 // If we still do not have f32, fail.
60 if (!arg.getType().isF32()) return failure();
61
62 f32_args.push_back(arg);
63 }
64
65 Value result = emitApproximation(f32_args, loc, rewriter);
66 assert(result.getType().isF32() && "Expect f32 intermediate result.");
67
68 // Truncate back if needed.
69 if (op.getType().isF16())
70 result = rewriter.create<FPTruncOp>(loc, result, rewriter.getF16Type());
71
72 rewriter.replaceOp(op, {result});
73 return success();
74 }
75 };
76
77 class ApproximateTanhLowering
78 : public ApproximateOnExtendedF32Lowering<math::TanhOp> {
79 public:
ApproximateTanhLowering(MLIRContext * ctx)80 explicit ApproximateTanhLowering(MLIRContext *ctx)
81 : ApproximateOnExtendedF32Lowering<math::TanhOp>(ctx) {}
82
83 // Emits the fast tanh approximation that is also used by XLA.
emitApproximation(ValueRange args,Location loc,PatternRewriter & rewriter) const84 Value emitApproximation(ValueRange args, Location loc,
85 PatternRewriter &rewriter) const override {
86 // For small values of x, we can approximate tanh(x) = x. For extremely
87 // small values of x (|x| < 1e-37), the other approximation would evaluate
88 // tanh(x) = 0.
89 Value input = args.front();
90 assert(input.getType().isF32());
91 constexpr float kCanUseApprox = 0.0004;
92 Value abs_value = rewriter.create<AbsFOp>(loc, input);
93 Value can_use_approx = rewriter.create<ConstantOp>(
94 loc, rewriter.getF32FloatAttr(kCanUseApprox));
95 Value return_input = rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT,
96 abs_value, can_use_approx);
97 // Clamp the input to [-c, c].
98 Value max_clamp = rewriter.create<ConstantOp>(
99 loc, rewriter.getF32FloatAttr(7.90531110763549805f));
100 Value smaller_than_max =
101 rewriter.create<CmpFOp>(loc, CmpFPredicate::ULE, input, max_clamp);
102 Value clamped_half =
103 rewriter.create<SelectOp>(loc, smaller_than_max, input, max_clamp);
104 Value min_clamp = rewriter.create<ConstantOp>(
105 loc, rewriter.getF32FloatAttr(-7.90531110763549805f));
106 Value larger_than_min = rewriter.create<CmpFOp>(loc, CmpFPredicate::UGE,
107 clamped_half, min_clamp);
108 Value input_clamped = rewriter.create<SelectOp>(loc, larger_than_min,
109 clamped_half, min_clamp);
110
111 static constexpr std::array<float, 7> numerator_coeffs{
112 -2.76076847742355e-16f, 2.00018790482477e-13f, -8.60467152213735e-11f,
113 5.12229709037114e-08f, 1.48572235717979e-05f, 6.37261928875436e-04f,
114 4.89352455891786e-03f};
115
116 static constexpr std::array<float, 4> denominator_coeffs{
117 1.19825839466702e-06f, 1.18534705686654e-04f, 2.26843463243900e-03f,
118 4.89352518554385e-03f};
119
120 Value input_squared =
121 rewriter.create<MulFOp>(loc, input_clamped, input_clamped);
122 Value numerator = rewriter.create<ConstantOp>(
123 loc, rewriter.getF32FloatAttr(numerator_coeffs[0]));
124 for (int i = 1; i < numerator_coeffs.size(); i++) {
125 numerator = rewriter.create<AddFOp>(
126 loc, rewriter.create<MulFOp>(loc, input_squared, numerator),
127 rewriter.create<ConstantOp>(
128 loc, rewriter.getF32FloatAttr(numerator_coeffs[i])));
129 }
130
131 numerator = rewriter.create<MulFOp>(loc, input_clamped, numerator);
132
133 Value denominator = rewriter.create<ConstantOp>(
134 loc, rewriter.getF32FloatAttr(denominator_coeffs[0]));
135 for (int i = 1; i < denominator_coeffs.size(); i++) {
136 denominator = rewriter.create<AddFOp>(
137 loc, rewriter.create<MulFOp>(loc, input_squared, denominator),
138 rewriter.create<ConstantOp>(
139 loc, rewriter.getF32FloatAttr(denominator_coeffs[i])));
140 }
141
142 Value approx = rewriter.create<DivFOp>(loc, numerator, denominator);
143
144 return rewriter.create<SelectOp>(loc, return_input, input, approx);
145 }
146 };
147
148 struct LegalizeTrigonometricToApproximationPass
149 : public PassWrapper<LegalizeTrigonometricToApproximationPass,
150 FunctionPass> {
151 /// Perform the lowering of standard dialect operations to approximations.
runOnFunctionmlir::mhlo::__anonc050ebae0111::LegalizeTrigonometricToApproximationPass152 void runOnFunction() override {
153 OwningRewritePatternList patterns;
154 PopulateTrigonometricToApproximationPatterns(&getContext(), &patterns);
155 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
156 }
157 };
158
159 } // anonymous namespace
160
// Creates a pass instance that lowers trigonometric standard ops on functions
// to approximations; ownership is transferred to the caller (pass manager).
std::unique_ptr<mlir::OperationPass<mlir::FuncOp>>
createLegalizeTrigonometricToApproximationPass() {
  return std::make_unique<LegalizeTrigonometricToApproximationPass>();
}
165
// Populates `patterns` with the trigonometric-approximation rewrite patterns
// (currently only the fast tanh lowering).
void PopulateTrigonometricToApproximationPatterns(
    mlir::MLIRContext *context, OwningRewritePatternList *patterns) {
  // clang-format off
  patterns->insert<ApproximateTanhLowering>(context);
  // clang-format on
}
172
173 } // namespace mhlo
174 } // namespace mlir
175