1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "absl/memory/memory.h"
17 #include "absl/strings/str_split.h"
18 #include "llvm/ADT/APFloat.h"
19 #include "llvm/ADT/DenseMap.h"
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/ADT/StringMap.h"
23 #include "llvm/ADT/StringSwitch.h"
24 #include "llvm/Support/Regex.h"
25 #include "llvm/Support/raw_ostream.h"
26 #include "mlir/Dialect/Quant/FakeQuantSupport.h" // from @llvm-project
27 #include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
28 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
29 #include "mlir/IR/AffineExpr.h" // from @llvm-project
30 #include "mlir/IR/AffineMap.h" // from @llvm-project
31 #include "mlir/IR/Attributes.h" // from @llvm-project
32 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
33 #include "mlir/IR/Location.h" // from @llvm-project
34 #include "mlir/IR/PatternMatch.h" // from @llvm-project
35 #include "mlir/Pass/Pass.h" // from @llvm-project
36 #include "mlir/Support/LLVM.h" // from @llvm-project
37 #include "tensorflow/compiler/mlir/lite/quantization/quantization_info.pb.h"
38 #include "tensorflow/compiler/mlir/lite/quantization/quantization_passes.h"
39 #include "tensorflow/compiler/mlir/tensorflow/utils/import_utils.h"
40
41 // NOLINTNEXTLINE
42 static llvm::cl::opt<std::string> quantize_stats(
43 "quant-test-stats", llvm::cl::value_desc("string"),
44 llvm::cl::desc("serialized quant info string. Only used in tests"),
45 llvm::cl::init(""));
46
47 //===----------------------------------------------------------------------===//
48 // The Pass to import quantization stats to the ops in a function. This requires
49 // a custom method to retrieve the unique name of the operation.
50
51 namespace mlir {
52 namespace quant {
53
// Shorthand for a single entry of the `QuantizationInfo` proto. This file
// reads its `name` / `name_regex` keys, its list of min/max `params`, and the
// quantization axis stored in its `meta`.
using QuantParamsEntry = QuantizationInfo::QuantParams;
55
56 namespace {
57 class ImportQuantStatsPass
58 : public PassWrapper<ImportQuantStatsPass, FunctionPass> {
59 public:
ImportQuantStatsPass(OperationToName op_to_name)60 explicit ImportQuantStatsPass(OperationToName op_to_name)
61 : op_to_name_(op_to_name) {}
62
63 void runOnFunction() override;
64
getDependentDialects(DialectRegistry & registry) const65 void getDependentDialects(DialectRegistry ®istry) const override {
66 registry.insert<quant::QuantizationDialect>();
67 }
68
69 // Parses the serialized quant stats protobuf and initialize the internal
70 // data structure. This method must be called after the pass is created.
71 bool ParseQuantStats(const std::string &stats_str);
72
73 private:
74 void ImportAsStatsOps(OpBuilder b, Operation *op, int index,
75 const QuantParamsEntry &info);
76
77 void InsertStatsOpAtResult(OpBuilder b, Value res, ElementsAttr layer_stats,
78 ElementsAttr axis_stats, IntegerAttr axis);
79
80 // If the index is out of range, this method returns false. Otherwise it
81 // returns true if the value is a float tensor.
IsQuantizableResult(Operation * op,int index)82 bool IsQuantizableResult(Operation *op, int index) {
83 if (index < 0 || index >= static_cast<int>(op->getNumResults()))
84 return false;
85 Value res = op->getResult(index);
86 return res.getType().isa<ShapedType>() &&
87 res.getType().cast<ShapedType>().getElementType().isa<FloatType>();
88 }
89
90 // A method to retrieve the name for the given op.
91 OperationToName op_to_name_;
92
93 // We split the normal names and regex names, since the former can use hash
94 // map to lookup and the latter needs to iterate all the regex to find the
95 // match.
96 // The `int` in the following two containers are to specify the result index
97 // of the given op. -1 indicates all the floating-point results.
98 llvm::StringMap<std::pair<int, const QuantParamsEntry>> name_to_info_;
99 llvm::StringMap<std::pair<int, const QuantParamsEntry>> regex_to_info_;
100 };
101 } // namespace
102
ParseQuantStats(const std::string & stats_str)103 bool ImportQuantStatsPass::ParseQuantStats(const std::string &stats_str) {
104 QuantizationInfo quant_stats;
105 if (!tensorflow::LoadProtoFromBuffer(stats_str, &quant_stats).ok()) {
106 return true;
107 }
108
109 for (const auto &entry : quant_stats.entries()) {
110 if (!entry.name().empty()) {
111 std::vector<std::string> name_and_port =
112 absl::StrSplit(entry.name(), ':');
113 int port = name_and_port.size() == 2 ? std::stoi(name_and_port[1]) : -1;
114 name_to_info_.insert({name_and_port[0], {port, entry}});
115 } else if (!entry.name_regex().empty()) {
116 std::vector<std::string> name_and_port =
117 absl::StrSplit(entry.name_regex(), ':');
118 int port = name_and_port.size() == 2 ? std::stoi(name_and_port[1]) : -1;
119 regex_to_info_.insert({name_and_port[0], {port, entry}});
120 }
121 }
122 return false;
123 }
124
InsertStatsOpAtResult(OpBuilder b,Value res,ElementsAttr layer_stats,ElementsAttr axis_stats,IntegerAttr axis)125 void ImportQuantStatsPass::InsertStatsOpAtResult(OpBuilder b, Value res,
126 ElementsAttr layer_stats,
127 ElementsAttr axis_stats,
128 IntegerAttr axis) {
129 auto stats_op = b.create<quant::StatisticsOp>(b.getUnknownLoc(), res,
130 layer_stats, axis_stats, axis);
131 res.replaceAllUsesWith(stats_op);
132 stats_op.getOperation()->replaceUsesOfWith(stats_op, res);
133 }
134
ImportAsStatsOps(OpBuilder b,Operation * op,int index,const QuantParamsEntry & info)135 void ImportQuantStatsPass::ImportAsStatsOps(OpBuilder b, Operation *op,
136 int index,
137 const QuantParamsEntry &info) {
138 if (info.params_size() == 0) return;
139
140 SmallVector<APFloat, 4> min_maxs;
141 min_maxs.reserve(info.params_size() * 2);
142 for (const auto ¶m : info.params()) {
143 llvm::APFloat min(param.min_max().min());
144 llvm::APFloat max(param.min_max().max());
145 min_maxs.push_back(min);
146 min_maxs.push_back(max);
147 }
148 // The layer stats contain only the first min/max pairs.
149 ElementsAttr layer_stats = DenseFPElementsAttr::get(
150 RankedTensorType::get({2}, b.getF32Type()), {min_maxs[0], min_maxs[1]});
151 ElementsAttr axis_stats;
152 IntegerAttr axis;
153
154 if (info.params_size() > 1) {
155 SmallVector<int64_t, 4> axis_stats_shape{info.params_size(), 2};
156 axis_stats = DenseFPElementsAttr::get(
157 RankedTensorType::get(axis_stats_shape, b.getF32Type()), min_maxs);
158 axis = b.getI64IntegerAttr(info.meta().quantize_axis());
159 }
160
161 b.setInsertionPointAfter(op);
162 if (IsQuantizableResult(op, index)) {
163 InsertStatsOpAtResult(b, op->getResult(index), layer_stats, axis_stats,
164 axis);
165 } else {
166 for (int i = 0, e = op->getNumResults(); i < e; ++i) {
167 if (IsQuantizableResult(op, i)) {
168 InsertStatsOpAtResult(b, op->getResult(i), layer_stats, axis_stats,
169 axis);
170 }
171 }
172 }
173 }
174
runOnFunction()175 void ImportQuantStatsPass::runOnFunction() {
176 FuncOp func = getFunction();
177 OpBuilder builder(func);
178
179 func.walk([&](Operation *op) {
180 if (op->hasTrait<OpTrait::IsTerminator>()) return;
181 auto op_name = op_to_name_(op);
182
183 // Check the named info collection first.
184 auto it = name_to_info_.find(op_name);
185 if (it != name_to_info_.end()) {
186 ImportAsStatsOps(builder, op, it->second.first, it->second.second);
187 return;
188 }
189
190 // Iterate all the regex names and matches the first one.
191 for (auto ®ex : regex_to_info_) {
192 if (llvm::Regex(regex.first()).match(op_name)) {
193 ImportAsStatsOps(builder, op, regex.second.first, regex.second.second);
194 break;
195 }
196 }
197 });
198 }
199
200 // Creates an instance of the default quant parameters pass.
CreateImportQuantStatsPass(OperationToName op_to_name,const std::string & stats_str)201 std::unique_ptr<OperationPass<FuncOp>> CreateImportQuantStatsPass(
202 OperationToName op_to_name, const std::string &stats_str) {
203 auto pass = absl::make_unique<ImportQuantStatsPass>(op_to_name);
204 if (pass->ParseQuantStats(stats_str)) return nullptr;
205 return pass;
206 }
207
208 // Creates an instance pass to import quantization stats to the operations in
209 // the function. A custom method to get the name from the op is used because
210 // different dialect ops might have different ways to assign the name.
211 std::unique_ptr<OperationPass<FuncOp>>
CreateImportQuantStatsPassForTFControlDialect(const std::string & stats_str)212 CreateImportQuantStatsPassForTFControlDialect(const std::string &stats_str) {
213 auto get_name_func = [](Operation *op) {
214 Location loc = op->getLoc();
215 if (auto name = loc.dyn_cast<NameLoc>()) {
216 return name.getName().strref();
217 } else if (auto fused_name = loc.dyn_cast<FusedLoc>()) {
218 for (auto sub_loc : fused_name.getLocations()) {
219 if (auto named_sub_loc = sub_loc.dyn_cast<NameLoc>()) {
220 return named_sub_loc.getName().strref();
221 }
222 }
223 }
224 return llvm::StringRef("");
225 };
226
227 return CreateImportQuantStatsPass(get_name_func, stats_str);
228 }
229
230 // Registers this pass with default values, only for test
231 static PassRegistration<ImportQuantStatsPass> pass(
__anond0edd7300402null232 "quant-import-stats", "Import quantization stats to the model", [] {
233 return CreateImportQuantStatsPassForTFControlDialect(quantize_stats);
234 });
235
236 } // namespace quant
237 } // namespace mlir
238