1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "absl/memory/memory.h"
17 #include "absl/strings/str_split.h"
18 #include "llvm/ADT/APFloat.h"
19 #include "llvm/ADT/DenseMap.h"
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/ADT/StringMap.h"
23 #include "llvm/ADT/StringSwitch.h"
24 #include "llvm/Support/Regex.h"
25 #include "llvm/Support/raw_ostream.h"
26 #include "mlir/Dialect/Quant/FakeQuantSupport.h"  // from @llvm-project
27 #include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
28 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
29 #include "mlir/IR/AffineExpr.h"  // from @llvm-project
30 #include "mlir/IR/AffineMap.h"  // from @llvm-project
31 #include "mlir/IR/Attributes.h"  // from @llvm-project
32 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
33 #include "mlir/IR/Location.h"  // from @llvm-project
34 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
35 #include "mlir/Pass/Pass.h"  // from @llvm-project
36 #include "mlir/Support/LLVM.h"  // from @llvm-project
37 #include "tensorflow/compiler/mlir/lite/quantization/quantization_info.pb.h"
38 #include "tensorflow/compiler/mlir/lite/quantization/quantization_passes.h"
39 #include "tensorflow/compiler/mlir/tensorflow/utils/import_utils.h"
40 
// Command-line option "quant-test-stats": a serialized quantization info
// string, empty by default. It is consumed by the test-only pass registration
// at the bottom of this file.
// NOLINTNEXTLINE
static llvm::cl::opt<std::string> quantize_stats(
    "quant-test-stats", llvm::cl::value_desc("string"),
    llvm::cl::desc("serialized quant info string. Only used in tests"),
    llvm::cl::init(""));
46 
47 //===----------------------------------------------------------------------===//
48 // The Pass to import quantization stats to the ops in a function. This requires
49 // a custom method to retrieve the unique name of the operation.
50 
51 namespace mlir {
52 namespace quant {
53 
54 using QuantParamsEntry = QuantizationInfo::QuantParams;
55 
56 namespace {
57 class ImportQuantStatsPass
58     : public PassWrapper<ImportQuantStatsPass, FunctionPass> {
59  public:
ImportQuantStatsPass(OperationToName op_to_name)60   explicit ImportQuantStatsPass(OperationToName op_to_name)
61       : op_to_name_(op_to_name) {}
62 
63   void runOnFunction() override;
64 
getDependentDialects(DialectRegistry & registry) const65   void getDependentDialects(DialectRegistry &registry) const override {
66     registry.insert<quant::QuantizationDialect>();
67   }
68 
69   // Parses the serialized quant stats protobuf and initialize the internal
70   // data structure. This method must be called after the pass is created.
71   bool ParseQuantStats(const std::string &stats_str);
72 
73  private:
74   void ImportAsStatsOps(OpBuilder b, Operation *op, int index,
75                         const QuantParamsEntry &info);
76 
77   void InsertStatsOpAtResult(OpBuilder b, Value res, ElementsAttr layer_stats,
78                              ElementsAttr axis_stats, IntegerAttr axis);
79 
80   // If the index is out of range, this method returns false. Otherwise it
81   // returns true if the value is a float tensor.
IsQuantizableResult(Operation * op,int index)82   bool IsQuantizableResult(Operation *op, int index) {
83     if (index < 0 || index >= static_cast<int>(op->getNumResults()))
84       return false;
85     Value res = op->getResult(index);
86     return res.getType().isa<ShapedType>() &&
87            res.getType().cast<ShapedType>().getElementType().isa<FloatType>();
88   }
89 
90   // A method to retrieve the name for the given op.
91   OperationToName op_to_name_;
92 
93   // We split the normal names and regex names, since the former can use hash
94   // map to lookup and the latter needs to iterate all the regex to find the
95   // match.
96   // The `int` in the following two containers are to specify the result index
97   // of the given op. -1 indicates all the floating-point results.
98   llvm::StringMap<std::pair<int, const QuantParamsEntry>> name_to_info_;
99   llvm::StringMap<std::pair<int, const QuantParamsEntry>> regex_to_info_;
100 };
101 }  // namespace
102 
ParseQuantStats(const std::string & stats_str)103 bool ImportQuantStatsPass::ParseQuantStats(const std::string &stats_str) {
104   QuantizationInfo quant_stats;
105   if (!tensorflow::LoadProtoFromBuffer(stats_str, &quant_stats).ok()) {
106     return true;
107   }
108 
109   for (const auto &entry : quant_stats.entries()) {
110     if (!entry.name().empty()) {
111       std::vector<std::string> name_and_port =
112           absl::StrSplit(entry.name(), ':');
113       int port = name_and_port.size() == 2 ? std::stoi(name_and_port[1]) : -1;
114       name_to_info_.insert({name_and_port[0], {port, entry}});
115     } else if (!entry.name_regex().empty()) {
116       std::vector<std::string> name_and_port =
117           absl::StrSplit(entry.name_regex(), ':');
118       int port = name_and_port.size() == 2 ? std::stoi(name_and_port[1]) : -1;
119       regex_to_info_.insert({name_and_port[0], {port, entry}});
120     }
121   }
122   return false;
123 }
124 
InsertStatsOpAtResult(OpBuilder b,Value res,ElementsAttr layer_stats,ElementsAttr axis_stats,IntegerAttr axis)125 void ImportQuantStatsPass::InsertStatsOpAtResult(OpBuilder b, Value res,
126                                                  ElementsAttr layer_stats,
127                                                  ElementsAttr axis_stats,
128                                                  IntegerAttr axis) {
129   auto stats_op = b.create<quant::StatisticsOp>(b.getUnknownLoc(), res,
130                                                 layer_stats, axis_stats, axis);
131   res.replaceAllUsesWith(stats_op);
132   stats_op.getOperation()->replaceUsesOfWith(stats_op, res);
133 }
134 
ImportAsStatsOps(OpBuilder b,Operation * op,int index,const QuantParamsEntry & info)135 void ImportQuantStatsPass::ImportAsStatsOps(OpBuilder b, Operation *op,
136                                             int index,
137                                             const QuantParamsEntry &info) {
138   if (info.params_size() == 0) return;
139 
140   SmallVector<APFloat, 4> min_maxs;
141   min_maxs.reserve(info.params_size() * 2);
142   for (const auto &param : info.params()) {
143     llvm::APFloat min(param.min_max().min());
144     llvm::APFloat max(param.min_max().max());
145     min_maxs.push_back(min);
146     min_maxs.push_back(max);
147   }
148   // The layer stats contain only the first min/max pairs.
149   ElementsAttr layer_stats = DenseFPElementsAttr::get(
150       RankedTensorType::get({2}, b.getF32Type()), {min_maxs[0], min_maxs[1]});
151   ElementsAttr axis_stats;
152   IntegerAttr axis;
153 
154   if (info.params_size() > 1) {
155     SmallVector<int64_t, 4> axis_stats_shape{info.params_size(), 2};
156     axis_stats = DenseFPElementsAttr::get(
157         RankedTensorType::get(axis_stats_shape, b.getF32Type()), min_maxs);
158     axis = b.getI64IntegerAttr(info.meta().quantize_axis());
159   }
160 
161   b.setInsertionPointAfter(op);
162   if (IsQuantizableResult(op, index)) {
163     InsertStatsOpAtResult(b, op->getResult(index), layer_stats, axis_stats,
164                           axis);
165   } else {
166     for (int i = 0, e = op->getNumResults(); i < e; ++i) {
167       if (IsQuantizableResult(op, i)) {
168         InsertStatsOpAtResult(b, op->getResult(i), layer_stats, axis_stats,
169                               axis);
170       }
171     }
172   }
173 }
174 
runOnFunction()175 void ImportQuantStatsPass::runOnFunction() {
176   FuncOp func = getFunction();
177   OpBuilder builder(func);
178 
179   func.walk([&](Operation *op) {
180     if (op->hasTrait<OpTrait::IsTerminator>()) return;
181     auto op_name = op_to_name_(op);
182 
183     // Check the named info collection first.
184     auto it = name_to_info_.find(op_name);
185     if (it != name_to_info_.end()) {
186       ImportAsStatsOps(builder, op, it->second.first, it->second.second);
187       return;
188     }
189 
190     // Iterate all the regex names and matches the first one.
191     for (auto &regex : regex_to_info_) {
192       if (llvm::Regex(regex.first()).match(op_name)) {
193         ImportAsStatsOps(builder, op, regex.second.first, regex.second.second);
194         break;
195       }
196     }
197   });
198 }
199 
200 // Creates an instance of the default quant parameters pass.
CreateImportQuantStatsPass(OperationToName op_to_name,const std::string & stats_str)201 std::unique_ptr<OperationPass<FuncOp>> CreateImportQuantStatsPass(
202     OperationToName op_to_name, const std::string &stats_str) {
203   auto pass = absl::make_unique<ImportQuantStatsPass>(op_to_name);
204   if (pass->ParseQuantStats(stats_str)) return nullptr;
205   return pass;
206 }
207 
208 // Creates an instance pass to import quantization stats to the operations in
209 // the function. A custom method to get the name from the op is used because
210 // different dialect ops might have different ways to assign the name.
211 std::unique_ptr<OperationPass<FuncOp>>
CreateImportQuantStatsPassForTFControlDialect(const std::string & stats_str)212 CreateImportQuantStatsPassForTFControlDialect(const std::string &stats_str) {
213   auto get_name_func = [](Operation *op) {
214     Location loc = op->getLoc();
215     if (auto name = loc.dyn_cast<NameLoc>()) {
216       return name.getName().strref();
217     } else if (auto fused_name = loc.dyn_cast<FusedLoc>()) {
218       for (auto sub_loc : fused_name.getLocations()) {
219         if (auto named_sub_loc = sub_loc.dyn_cast<NameLoc>()) {
220           return named_sub_loc.getName().strref();
221         }
222       }
223     }
224     return llvm::StringRef("");
225   };
226 
227   return CreateImportQuantStatsPass(get_name_func, stats_str);
228 }
229 
230 // Registers this pass with default values, only for test
231 static PassRegistration<ImportQuantStatsPass> pass(
__anond0edd7300402null232     "quant-import-stats", "Import quantization stats to the model", [] {
233       return CreateImportQuantStatsPassForTFControlDialect(quantize_stats);
234     });
235 
236 }  // namespace quant
237 }  // namespace mlir
238