1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This file implements logic for lowering MHLO dialect to Standard dialect.
17 
18 #include "llvm/ADT/StringSwitch.h"
19 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
20 #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
21 #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
22 #include "mlir/Dialect/StandardOps/IR/Ops.h"
23 #include "mlir/IR/BuiltinOps.h"
24 #include "mlir/Pass/Pass.h"
25 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
26 
27 namespace mlir {
28 namespace {
29 #include "generated_legalize_to_standard.inc"
30 }  // end anonymous namespace
31 namespace mhlo {
32 namespace {
33 
34 class CompareIConvert : public OpRewritePattern<mhlo::CompareOp> {
35  public:
36   using OpRewritePattern::OpRewritePattern;
37 
matchAndRewrite(mhlo::CompareOp op,PatternRewriter & rewriter) const38   LogicalResult matchAndRewrite(mhlo::CompareOp op,
39                                 PatternRewriter &rewriter) const override {
40     auto lhs = op.lhs();
41     auto rhs = op.rhs();
42     auto lhs_type = lhs.getType().cast<TensorType>();
43     auto rhs_type = rhs.getType().cast<TensorType>();
44 
45     // Broadcasting not supported by this rewrite.
46     if (lhs_type.getShape() != rhs_type.getShape()) return failure();
47 
48     if (!lhs_type.getElementType().isSignlessInteger() ||
49         !rhs_type.getElementType().isSignlessInteger())
50       return failure();
51 
52     auto comparison_direction = op.comparison_direction();
53     auto compare_predicate =
54         llvm::StringSwitch<Optional<CmpIPredicate>>(comparison_direction)
55             .Case("EQ", CmpIPredicate::eq)
56             .Case("NE", CmpIPredicate::ne)
57             .Case("LT", CmpIPredicate::slt)
58             .Case("LE", CmpIPredicate::sle)
59             .Case("GT", CmpIPredicate::sgt)
60             .Case("GE", CmpIPredicate::sge)
61             .Default(llvm::None);
62 
63     if (!compare_predicate.hasValue()) return failure();
64 
65     rewriter.replaceOpWithNewOp<CmpIOp>(op, compare_predicate.getValue(), lhs,
66                                         rhs);
67     return success();
68   }
69 };
70 
71 class CompareFConvert : public OpRewritePattern<mhlo::CompareOp> {
72  public:
73   using OpRewritePattern::OpRewritePattern;
74 
matchAndRewrite(mhlo::CompareOp op,PatternRewriter & rewriter) const75   LogicalResult matchAndRewrite(mhlo::CompareOp op,
76                                 PatternRewriter &rewriter) const override {
77     auto lhs = op.lhs();
78     auto rhs = op.rhs();
79     auto lhs_type = lhs.getType().cast<TensorType>();
80     auto rhs_type = rhs.getType().cast<TensorType>();
81 
82     // Broadcasting not supported by this rewrite.
83     if (lhs_type.getShape() != rhs_type.getShape()) return failure();
84 
85     if (!lhs_type.getElementType().isa<FloatType>() ||
86         !rhs_type.getElementType().isa<FloatType>())
87       return failure();
88 
89     auto comparison_direction = op.comparison_direction();
90     auto compare_predicate =
91         llvm::StringSwitch<Optional<CmpFPredicate>>(comparison_direction)
92             .Case("EQ", CmpFPredicate::OEQ)
93             .Case("NE", CmpFPredicate::UNE)
94             .Case("LT", CmpFPredicate::OLT)
95             .Case("LE", CmpFPredicate::OLE)
96             .Case("GT", CmpFPredicate::OGT)
97             .Case("GE", CmpFPredicate::OGE)
98             .Default(llvm::None);
99 
100     if (!compare_predicate.hasValue()) return failure();
101 
102     rewriter.replaceOpWithNewOp<CmpFOp>(op, compare_predicate.getValue(), lhs,
103                                         rhs);
104     return success();
105   }
106 };
107 
108 // Replace IotaOp with an integer constant. A ConvertOp is added to
109 // convert the integer constant to iota result type. For complex types, the real
110 // part is replaced with the generated constant and the imaginary part is
111 // replaced with zero tensor.
112 class ConvertIotaOp : public OpRewritePattern<mhlo::IotaOp> {
113  public:
114   using OpRewritePattern::OpRewritePattern;
115 
matchAndRewrite(mhlo::IotaOp op,PatternRewriter & rewriter) const116   LogicalResult matchAndRewrite(mhlo::IotaOp op,
117                                 PatternRewriter &rewriter) const override {
118     auto output_type = op.getType().cast<ShapedType>();
119     auto output_size = output_type.getNumElements();
120     auto dimension = op.iota_dimension();
121     auto max_dim_size = output_type.getDimSize(dimension);
122 
123     auto element_type = output_type.getElementType();
124     int bitwidth;
125 
126     auto complex_ty = element_type.dyn_cast<ComplexType>();
127     Type int_or_float_ty = element_type;
128     if (complex_ty) int_or_float_ty = complex_ty.getElementType();
129 
130     bitwidth = int_or_float_ty.getIntOrFloatBitWidth();
131     llvm::SmallVector<APInt, 10> values;
132     values.reserve(output_size);
133 
134     int64_t increase_stride = output_size;
135     for (int i = 0; i <= dimension; i++) {
136       increase_stride /= output_type.getDimSize(i);
137     }
138 
139     int64_t current_value = 0;
140     for (int i = 0; i < output_size; i++) {
141       int64_t value = (current_value / increase_stride) % max_dim_size;
142       values.push_back(APInt(bitwidth, value));
143       ++current_value;
144     }
145 
146     auto int_shape_type = RankedTensorType::get(
147         output_type.getShape(),
148         IntegerType::get(rewriter.getContext(), bitwidth));
149     auto loc = op.getLoc();
150     auto integer_const = rewriter.create<mlir::ConstantOp>(
151         loc, DenseIntElementsAttr::get(int_shape_type, values));
152 
153     auto int_or_float_shape_ty =
154         RankedTensorType::get(output_type.getShape(), int_or_float_ty);
155 
156     auto iota_const =
157         rewriter.create<ConvertOp>(loc, int_or_float_shape_ty, integer_const);
158 
159     // For int/float types we are done, replace op and return.
160     if (!complex_ty) {
161       rewriter.replaceOp(op, iota_const.getResult());
162       return success();
163     }
164 
165     // For complex types, generate a constant tensor of zeroes for the imaginary
166     // part and use iota_const for real part.
167     auto zeroes = rewriter.create<mlir::ConstantOp>(
168         loc, DenseIntElementsAttr::get(int_shape_type, APInt(bitwidth, 0)));
169     auto imag_zeroes =
170         rewriter.create<ConvertOp>(loc, int_or_float_shape_ty, zeroes);
171     rewriter.replaceOpWithNewOp<mhlo::ComplexOp>(op, iota_const, imag_zeroes);
172     return success();
173   }
174 };
175 
176 }  // end anonymous namespace
177 
178 namespace {
179 struct LegalizeToStandardPass
180     : public PassWrapper<LegalizeToStandardPass, FunctionPass> {
getDependentDialectsmlir::mhlo::__anona68d694f0311::LegalizeToStandardPass181   void getDependentDialects(DialectRegistry &registry) const override {
182     registry.insert<StandardOpsDialect>();
183   }
184 
185   /// Perform the lowering to Standard dialect.
186   void runOnFunction() override;
187 };
188 }  // end anonymous namespace
189 
createLegalizeToStdPass()190 std::unique_ptr<mlir::OperationPass<mlir::FuncOp>> createLegalizeToStdPass() {
191   return std::make_unique<LegalizeToStandardPass>();
192 }
193 
PopulateMhloToStdPatterns(OwningRewritePatternList * patterns,mlir::MLIRContext * ctx)194 void PopulateMhloToStdPatterns(OwningRewritePatternList *patterns,
195                                mlir::MLIRContext *ctx) {
196   mlir::populateWithGenerated(ctx, *patterns);
197   patterns->insert<CompareFConvert, CompareIConvert, ConvertIotaOp>(ctx);
198 }
199 
200 /// Perform the lowering to standard dialect.
runOnFunction()201 void LegalizeToStandardPass::runOnFunction() {
202   OwningRewritePatternList patterns;
203   mlir::mhlo::PopulateMhloToStdPatterns(&patterns, &getContext());
204   (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
205 }
206 
207 }  // end namespace mhlo
208 }  // end namespace mlir
209