1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This transformation pass applies some clean up steps after quantization.
17 
18 #include "llvm/Support/Casting.h"
19 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
20 #include "mlir/Pass/Pass.h"  // from @llvm-project
21 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
22 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
23 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
24 #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
25 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
26 
27 //===----------------------------------------------------------------------===//
28 // The post-quantize Pass.
29 //
30 namespace mlir {
31 namespace TFL {
32 namespace {
33 
34 // Applies all the clean up steps after quantization.
35 class PostQuantizePass : public PassWrapper<PostQuantizePass, FunctionPass> {
36  public:
37   // Constructor used by the PassRegistration. This will remove the adaptor ops.
PostQuantizePass()38   explicit PostQuantizePass() : emit_quant_adaptor_ops_(false) {}
39 
40   // Constructor used by manually creating the pass.
PostQuantizePass(bool emit_quant_adaptor_ops)41   explicit PostQuantizePass(bool emit_quant_adaptor_ops)
42       : emit_quant_adaptor_ops_(emit_quant_adaptor_ops) {}
43 
44   void runOnFunction() override;
45 
46  private:
47   // Set this flag to true if the inputs and outputs are in floating point. The
48   // quant adaptor ops convert them to fixed point values (i.e. quantize) before
49   // feeding them to the model and convert them back to floating point
50   // (i.e. dequantize) as the output.
51   bool emit_quant_adaptor_ops_;
52 };
53 
RemoveQuantizationAdaptorOps(FuncOp func)54 void RemoveQuantizationAdaptorOps(FuncOp func) {
55   mlir::OpBuilder builder(func.getBody());
56   auto& bb = func.front();
57 
58   int num_args = bb.getNumArguments();
59   llvm::SmallVector<Type, 4> input_types;
60   input_types.reserve(num_args);
61   // Edit the block arguments and create the new input ops in place to replace
62   // the old input ops and quantize ops.
63   for (int i = 0; i != num_args; ++i) {
64     // Previous loop iteration may invalidate the insertion point so we have to
65     // reset insertion point each iteration.
66     builder.setInsertionPointToStart(&bb);
67 
68     // In each iteration, a new argument is appended to the end of the list
69     // and the current argument is erased, so here we always process the first
70     // argument in the list.
71     auto arg = bb.getArgument(0);
72 
73     auto remove_quantize_op = [&](QuantizeOp quantize_op) {
74       auto quantize_output = quantize_op.output();
75       auto quantize_type = quantize_output.getType();
76       input_types.push_back(quantize_type);
77       auto new_arg = bb.addArgument(quantize_type);
78       quantize_output.replaceAllUsesWith(new_arg);
79       quantize_op.erase();
80       arg.dropAllUses();
81       bb.eraseArgument(0);
82     };
83 
84     // This is looking for a pattern: arg -> tfl.quantize
85     if (arg.hasOneUse() && llvm::isa<QuantizeOp>(*arg.user_begin())) {
86       auto quantize_op = llvm::cast<QuantizeOp>(*arg.user_begin());
87       remove_quantize_op(quantize_op);
88       continue;
89     }
90 
91     // Make a copy of current argument and append it to the end of the list if
92     // the pattern isn't found.
93     Type arg_type = arg.getType();
94     input_types.push_back(arg_type);
95     auto new_arg = bb.addArgument(arg_type);
96     arg.replaceAllUsesWith(new_arg);
97     arg.dropAllUses();
98     bb.eraseArgument(0);
99   }
100 
101   // Edit the return ops and remove the dequantize ops in place.
102   auto* terminator = bb.getTerminator();
103   int num_return_operands = terminator->getNumOperands();
104   llvm::SmallVector<Type, 4> output_types;
105   output_types.reserve(num_return_operands);
106   for (int i = 0; i != num_return_operands; ++i) {
107     auto returned_value = terminator->getOperand(i);
108     Operation* returned_op = returned_value.getDefiningOp();
109     if (returned_op && returned_op->hasOneUse() &&
110         llvm::isa<DequantizeOp>(returned_op)) {
111       auto dequantize_op = llvm::cast<DequantizeOp>(returned_op);
112       Value dequantized_result = dequantize_op.input();
113       output_types.push_back(dequantized_result.getType());
114       terminator->setOperand(i, dequantized_result);
115       returned_op->erase();
116     } else {
117       output_types.push_back(returned_value.getType());
118     }
119   }
120   auto new_func_type = builder.getFunctionType(input_types, output_types);
121   func.setType(new_func_type);
122 }
123 
124 // Remove the back-to-back quantize and dequantize ops with volatile attribute.
125 struct RemoveVolatileOps : public OpRewritePattern<DequantizeOp> {
RemoveVolatileOpsmlir::TFL::__anon5a1d2ac70111::RemoveVolatileOps126   explicit RemoveVolatileOps(MLIRContext* context)
127       : OpRewritePattern<DequantizeOp>(context, 1) {}
128 
matchAndRewritemlir::TFL::__anon5a1d2ac70111::RemoveVolatileOps129   LogicalResult matchAndRewrite(DequantizeOp op,
130                                 PatternRewriter& rewriter) const override {
131     auto input_op = op.input().getDefiningOp();
132     if (auto q = llvm::dyn_cast_or_null<QuantizeOp>(input_op)) {
133       if (!q->getAttr(mlir::quant::kVolatileOpAttrName)) return failure();
134 
135       // Don't remove leading and tailing QDQ for PQT workflow, so the io
136       // modifying lib can work correctly.
137       if (!q.input().getDefiningOp()) return failure();
138       if (op->hasOneUse() &&
139           op->user_begin()->hasTrait<OpTrait::IsTerminator>())
140         return failure();
141 
142       op.replaceAllUsesWith(q.input());
143       return success();
144     }
145     return failure();
146   }
147 };
148 
149 // Removes operations with side effect (i.e. LSTM, SVDF) that have dangling
150 // output.
151 template <typename OpTy>
152 struct PruneUnusedOpsWithSideEffect : public OpRewritePattern<OpTy> {
153  public:
PruneUnusedOpsWithSideEffectmlir::TFL::__anon5a1d2ac70111::PruneUnusedOpsWithSideEffect154   explicit PruneUnusedOpsWithSideEffect(MLIRContext* context)
155       : OpRewritePattern<OpTy>(context) {}
156 
matchAndRewritemlir::TFL::__anon5a1d2ac70111::PruneUnusedOpsWithSideEffect157   LogicalResult matchAndRewrite(OpTy op,
158                                 PatternRewriter& rewriter) const override {
159     if (op.getOperation()->template hasTrait<OpTrait::IsTerminator>()) {
160       return failure();
161     }
162     for (auto result : op.getOperation()->getOpResults()) {
163       if (!result.use_empty()) {
164         return failure();
165       }
166     }
167     rewriter.eraseOp(op);
168     return success();
169   }
170 };
171 
172 #include "tensorflow/compiler/mlir/lite/transforms/generated_post_quantize.inc"
173 
runOnFunction()174 void PostQuantizePass::runOnFunction() {
175   OwningRewritePatternList patterns;
176   auto func = getFunction();
177   auto* ctx = func.getContext();
178   TFL::populateWithGenerated(ctx, patterns);
179   patterns.insert<quant::FoldTrivalRequantizeOp<QuantizeOp>>(ctx);
180   patterns.insert<PruneUnusedOpsWithSideEffect<TFL::LSTMOp>>(ctx);
181   patterns
182       .insert<PruneUnusedOpsWithSideEffect<TFL::UnidirectionalSequenceLSTMOp>>(
183           ctx);
184   patterns.insert<PruneUnusedOpsWithSideEffect<TFL::SVDFOp>>(ctx);
185   (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
186 
187   if (!emit_quant_adaptor_ops_) {
188     RemoveQuantizationAdaptorOps(getFunction());
189   }
190 
191   OwningRewritePatternList phase_2_patterns;
192   TFL::populateWithGenerated(ctx, phase_2_patterns);
193   phase_2_patterns
194       .insert<quant::FoldTrivalRequantizeOp<QuantizeOp>, RemoveVolatileOps>(
195           ctx);
196   (void)applyPatternsAndFoldGreedily(func, std::move(phase_2_patterns));
197 }
198 
199 }  // namespace
200 
201 // Creates an instance of the TensorFlow Lite dialect PostQuantize pass.
CreatePostQuantizePass(bool emit_quant_adaptor_ops)202 std::unique_ptr<OperationPass<FuncOp>> CreatePostQuantizePass(
203     bool emit_quant_adaptor_ops) {
204   return std::make_unique<PostQuantizePass>(emit_quant_adaptor_ops);
205 }
206 
207 static PassRegistration<PostQuantizePass> pass(
208     "tfl-post-quantize", "Apply post quantization clean up after quantization");
209 
210 }  // namespace TFL
211 }  // namespace mlir
212