1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // This file implements logic for lowering MHLO dialect to Standard dialect.
17
18 #include "llvm/ADT/StringSwitch.h"
19 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
20 #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
21 #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
22 #include "mlir/Dialect/StandardOps/IR/Ops.h"
23 #include "mlir/IR/BuiltinOps.h"
24 #include "mlir/Pass/Pass.h"
25 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
26
27 namespace mlir {
28 namespace {
29 #include "generated_legalize_to_standard.inc"
30 } // end anonymous namespace
31 namespace mhlo {
32 namespace {
33
34 class CompareIConvert : public OpRewritePattern<mhlo::CompareOp> {
35 public:
36 using OpRewritePattern::OpRewritePattern;
37
matchAndRewrite(mhlo::CompareOp op,PatternRewriter & rewriter) const38 LogicalResult matchAndRewrite(mhlo::CompareOp op,
39 PatternRewriter &rewriter) const override {
40 auto lhs = op.lhs();
41 auto rhs = op.rhs();
42 auto lhs_type = lhs.getType().cast<TensorType>();
43 auto rhs_type = rhs.getType().cast<TensorType>();
44
45 // Broadcasting not supported by this rewrite.
46 if (lhs_type.getShape() != rhs_type.getShape()) return failure();
47
48 if (!lhs_type.getElementType().isSignlessInteger() ||
49 !rhs_type.getElementType().isSignlessInteger())
50 return failure();
51
52 auto comparison_direction = op.comparison_direction();
53 auto compare_predicate =
54 llvm::StringSwitch<Optional<CmpIPredicate>>(comparison_direction)
55 .Case("EQ", CmpIPredicate::eq)
56 .Case("NE", CmpIPredicate::ne)
57 .Case("LT", CmpIPredicate::slt)
58 .Case("LE", CmpIPredicate::sle)
59 .Case("GT", CmpIPredicate::sgt)
60 .Case("GE", CmpIPredicate::sge)
61 .Default(llvm::None);
62
63 if (!compare_predicate.hasValue()) return failure();
64
65 rewriter.replaceOpWithNewOp<CmpIOp>(op, compare_predicate.getValue(), lhs,
66 rhs);
67 return success();
68 }
69 };
70
71 class CompareFConvert : public OpRewritePattern<mhlo::CompareOp> {
72 public:
73 using OpRewritePattern::OpRewritePattern;
74
matchAndRewrite(mhlo::CompareOp op,PatternRewriter & rewriter) const75 LogicalResult matchAndRewrite(mhlo::CompareOp op,
76 PatternRewriter &rewriter) const override {
77 auto lhs = op.lhs();
78 auto rhs = op.rhs();
79 auto lhs_type = lhs.getType().cast<TensorType>();
80 auto rhs_type = rhs.getType().cast<TensorType>();
81
82 // Broadcasting not supported by this rewrite.
83 if (lhs_type.getShape() != rhs_type.getShape()) return failure();
84
85 if (!lhs_type.getElementType().isa<FloatType>() ||
86 !rhs_type.getElementType().isa<FloatType>())
87 return failure();
88
89 auto comparison_direction = op.comparison_direction();
90 auto compare_predicate =
91 llvm::StringSwitch<Optional<CmpFPredicate>>(comparison_direction)
92 .Case("EQ", CmpFPredicate::OEQ)
93 .Case("NE", CmpFPredicate::UNE)
94 .Case("LT", CmpFPredicate::OLT)
95 .Case("LE", CmpFPredicate::OLE)
96 .Case("GT", CmpFPredicate::OGT)
97 .Case("GE", CmpFPredicate::OGE)
98 .Default(llvm::None);
99
100 if (!compare_predicate.hasValue()) return failure();
101
102 rewriter.replaceOpWithNewOp<CmpFOp>(op, compare_predicate.getValue(), lhs,
103 rhs);
104 return success();
105 }
106 };
107
108 // Replace IotaOp with an integer constant. A ConvertOp is added to
109 // convert the integer constant to iota result type. For complex types, the real
110 // part is replaced with the generated constant and the imaginary part is
111 // replaced with zero tensor.
112 class ConvertIotaOp : public OpRewritePattern<mhlo::IotaOp> {
113 public:
114 using OpRewritePattern::OpRewritePattern;
115
matchAndRewrite(mhlo::IotaOp op,PatternRewriter & rewriter) const116 LogicalResult matchAndRewrite(mhlo::IotaOp op,
117 PatternRewriter &rewriter) const override {
118 auto output_type = op.getType().cast<ShapedType>();
119 auto output_size = output_type.getNumElements();
120 auto dimension = op.iota_dimension();
121 auto max_dim_size = output_type.getDimSize(dimension);
122
123 auto element_type = output_type.getElementType();
124 int bitwidth;
125
126 auto complex_ty = element_type.dyn_cast<ComplexType>();
127 Type int_or_float_ty = element_type;
128 if (complex_ty) int_or_float_ty = complex_ty.getElementType();
129
130 bitwidth = int_or_float_ty.getIntOrFloatBitWidth();
131 llvm::SmallVector<APInt, 10> values;
132 values.reserve(output_size);
133
134 int64_t increase_stride = output_size;
135 for (int i = 0; i <= dimension; i++) {
136 increase_stride /= output_type.getDimSize(i);
137 }
138
139 int64_t current_value = 0;
140 for (int i = 0; i < output_size; i++) {
141 int64_t value = (current_value / increase_stride) % max_dim_size;
142 values.push_back(APInt(bitwidth, value));
143 ++current_value;
144 }
145
146 auto int_shape_type = RankedTensorType::get(
147 output_type.getShape(),
148 IntegerType::get(rewriter.getContext(), bitwidth));
149 auto loc = op.getLoc();
150 auto integer_const = rewriter.create<mlir::ConstantOp>(
151 loc, DenseIntElementsAttr::get(int_shape_type, values));
152
153 auto int_or_float_shape_ty =
154 RankedTensorType::get(output_type.getShape(), int_or_float_ty);
155
156 auto iota_const =
157 rewriter.create<ConvertOp>(loc, int_or_float_shape_ty, integer_const);
158
159 // For int/float types we are done, replace op and return.
160 if (!complex_ty) {
161 rewriter.replaceOp(op, iota_const.getResult());
162 return success();
163 }
164
165 // For complex types, generate a constant tensor of zeroes for the imaginary
166 // part and use iota_const for real part.
167 auto zeroes = rewriter.create<mlir::ConstantOp>(
168 loc, DenseIntElementsAttr::get(int_shape_type, APInt(bitwidth, 0)));
169 auto imag_zeroes =
170 rewriter.create<ConvertOp>(loc, int_or_float_shape_ty, zeroes);
171 rewriter.replaceOpWithNewOp<mhlo::ComplexOp>(op, iota_const, imag_zeroes);
172 return success();
173 }
174 };
175
176 } // end anonymous namespace
177
178 namespace {
179 struct LegalizeToStandardPass
180 : public PassWrapper<LegalizeToStandardPass, FunctionPass> {
getDependentDialectsmlir::mhlo::__anona68d694f0311::LegalizeToStandardPass181 void getDependentDialects(DialectRegistry ®istry) const override {
182 registry.insert<StandardOpsDialect>();
183 }
184
185 /// Perform the lowering to Standard dialect.
186 void runOnFunction() override;
187 };
188 } // end anonymous namespace
189
createLegalizeToStdPass()190 std::unique_ptr<mlir::OperationPass<mlir::FuncOp>> createLegalizeToStdPass() {
191 return std::make_unique<LegalizeToStandardPass>();
192 }
193
PopulateMhloToStdPatterns(OwningRewritePatternList * patterns,mlir::MLIRContext * ctx)194 void PopulateMhloToStdPatterns(OwningRewritePatternList *patterns,
195 mlir::MLIRContext *ctx) {
196 mlir::populateWithGenerated(ctx, *patterns);
197 patterns->insert<CompareFConvert, CompareIConvert, ConvertIotaOp>(ctx);
198 }
199
200 /// Perform the lowering to standard dialect.
runOnFunction()201 void LegalizeToStandardPass::runOnFunction() {
202 OwningRewritePatternList patterns;
203 mlir::mhlo::PopulateMhloToStdPatterns(&patterns, &getContext());
204 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
205 }
206
207 } // end namespace mhlo
208 } // end namespace mlir
209