1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "llvm/ADT/ArrayRef.h"
17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/ADT/SetVector.h"
19 #include "llvm/Support/CommandLine.h"
20 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
21 #include "mlir/IR/Builders.h"  // from @llvm-project
22 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
23 #include "mlir/IR/Identifier.h"  // from @llvm-project
24 #include "mlir/IR/Location.h"  // from @llvm-project
25 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
26 #include "mlir/IR/Matchers.h"  // from @llvm-project
27 #include "mlir/IR/SymbolTable.h"  // from @llvm-project
28 #include "mlir/Pass/Pass.h"  // from @llvm-project
29 #include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
30 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
31 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
32 #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
33 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
34 
35 namespace mlir {
36 namespace TFL {
37 namespace {
38 
39 // This pass outlines the cond/body region of the TFL WhileOp into functions and
40 // replaces the regions with calls to these outlined functions.
41 class WhileOutlinePass
42     : public mlir::PassWrapper<WhileOutlinePass, OperationPass<ModuleOp>> {
43  public:
WhileOutlinePass()44   explicit WhileOutlinePass() {}
45 
46  private:
47   void runOnOperation() override;
48 
49   // Outlines the regions of the WhileOp's cond and body and insert function
50   // calls instead,
51   void OutlineWhile(WhileOp while_op);
52 
53   // Get unique name by using the loc to name mapping.
54   std::string GetName(Operation* op, StringRef suffix);
55 
56   tensorflow::OpOrArgLocNameMapper mapper_;
57 };
58 
GetName(Operation * op,StringRef suffix)59 std::string WhileOutlinePass::GetName(Operation* op, StringRef suffix) {
60   return (mapper_.GetUniqueName(op) + suffix).str();
61 }
62 
63 // Returns whether the WhileOp is already outlined (e.g., only consists of calls
64 // to functions).
IsAlreadyOutlined(WhileOp while_op)65 bool IsAlreadyOutlined(WhileOp while_op) {
66   auto just_call = [](Region& region) {
67     auto it = region.front().begin();
68     if (!isa<CallOp>(*it)) return false;
69     ++it;
70     if (!isa<YieldOp>(*it)) return false;
71     return true;
72   };
73   return just_call(while_op.body()) && just_call(while_op.cond());
74 }
75 
IsCompatibleTypeWithTFLCastOp(Type type)76 bool IsCompatibleTypeWithTFLCastOp(Type type) {
77   auto elemType = getElementTypeOrSelf(type);
78   // F32 and BF16 types are allowed.
79   if (elemType.isBF16() || elemType.isF32()) return true;
80 
81   // I1, I16, I32, I64 types are allowed.
82   if (elemType.isInteger(1) || elemType.isInteger(16) ||
83       elemType.isInteger(32) || elemType.isInteger(64))
84     return true;
85 
86   // Complex<F<32>> is allowed.
87   if (elemType.isa<ComplexType>() &&
88       elemType.cast<ComplexType>().getElementType().isF32())
89     return true;
90 
91   // QUINT8 and UI8 are allowed.
92   if (elemType.isa<TF::Quint8Type>() ||
93       (elemType.isInteger(8) && elemType.cast<IntegerType>().isUnsigned()))
94     return true;
95 
96   return false;
97 }
98 
OutlineWhile(WhileOp while_op)99 void WhileOutlinePass::OutlineWhile(WhileOp while_op) {
100   OpBuilder builder(&getContext());
101   // Collect external values used.
102   llvm::SetVector<Value> extern_values;
103 
104   // The basic block arguments correspond to values that are loop carried, while
105   // all those post are loop independent. Initialize extern_values with while_op
106   // not loop carried operands.
107   auto num_loop_carried = while_op.cond().getNumArguments();
108   auto not_carried_operands =
109       while_op.getOperands().drop_front(num_loop_carried);
110   extern_values.insert(not_carried_operands.begin(),
111                        not_carried_operands.end());
112   auto old_extern_values_size = extern_values.size();
113 
114   llvm::SmallVector<Region*, 2> regions{&while_op.cond(), &while_op.body()};
115   for (auto it : llvm::enumerate(regions)) {
116     llvm::SetVector<Value> region_extern_values;
117     getUsedValuesDefinedAbove(*it.value(), region_extern_values);
118 
119     // Sink down constants into the functions.
120     for (auto extern_value : region_extern_values) {
121       if (!matchPattern(extern_value, m_Constant())) {
122         extern_values.insert(extern_value);
123         continue;
124       }
125       // Add constant at start of region.
126       auto const_builder =
127           OpBuilder(&it.value()->front(), it.value()->front().begin());
128       auto const_value = const_builder.clone(*extern_value.getDefiningOp());
129       replaceAllUsesInRegionWith(extern_value, const_value->getResult(0),
130                                  *it.value());
131     }
132   }
133 
134   bool has_extra_extern_values = old_extern_values_size != extern_values.size();
135   // If an extern value is already an operand post the loop carried operands,
136   // then it need not be passed in again.
137   // Compute all the extra operands that have to be added to the while.
138   llvm::SetVector<Value> extra_operands;
139   if (has_extra_extern_values) {
140     auto new_extern =
141         extern_values.getArrayRef().drop_front(old_extern_values_size);
142     extra_operands.insert(new_extern.begin(), new_extern.end());
143   }
144 
145   // Skip if already just calls.
146   if (extra_operands.empty() && IsAlreadyOutlined(while_op)) return;
147 
148   // Collect new types.
149   SmallVector<Type, 4> types;
150   types.reserve(extra_operands.size() + while_op.getNumOperands());
151   for (Type type : while_op.cond().getArgumentTypes()) types.push_back(type);
152   for (Value operand : extern_values) types.push_back(operand.getType());
153 
154   // Create outline function from region. Optional pass extra arguments through
155   // to yield.
156   SymbolTable symbol_table(getOperation());
157   auto create_outline_func = [&](StringRef name, Region& region,
158                                  bool passthru_extra_args) {
159     FunctionType type;
160     if (passthru_extra_args) {
161       type = FunctionType::get(&getContext(), types, types);
162     } else {
163       SmallVector<Type, 4> result_types;
164       auto operands = region.front().getTerminator()->getOperandTypes();
165       result_types.append(operands.begin(), operands.end());
166       type = FunctionType::get(&getContext(), types, result_types);
167     }
168 
169     auto outlined_func = builder.create<FuncOp>(while_op.getLoc(), name, type);
170     outlined_func.getBody().takeBody(region);
171     Region& func_region = outlined_func.getBody();
172 
173     // Replace all external uses with block args and update uses.
174     llvm::SmallVector<Value, 4> new_args;
175     new_args.reserve(extern_values.size());
176     Block& block = func_region.front();
177     for (Value value : extern_values) {
178       auto arg = block.addArgument(value.getType());
179       replaceAllUsesInRegionWith(value, arg, func_region);
180       new_args.push_back(arg);
181     }
182 
183     // Replace yield op with return.
184     Operation* yield_op = outlined_func.getBody().front().getTerminator();
185     OpBuilder b(yield_op);
186     llvm::SmallVector<Value, 4> args;
187     auto loop_carried_yield_operands =
188         yield_op->getOperands().take_front(num_loop_carried);
189     args.reserve(loop_carried_yield_operands.size() + new_args.size());
190     if (passthru_extra_args) {
191       // Add operands of yield to the return, inserting casts if needed.
192       for (auto it : llvm::zip_first(loop_carried_yield_operands, types)) {
193         auto value = std::get<0>(it);
194         auto type = std::get<1>(it);
195         if (value.getType() == type) {
196           args.push_back(value);
197         } else {
198           if (IsCompatibleTypeWithTFLCastOp(value.getType()) &&
199               IsCompatibleTypeWithTFLCastOp(type)) {
200             auto cast = b.create<CastOp>(yield_op->getLoc(), type, value);
201             args.push_back(cast);
202           } else {
203             auto cast = b.create<TF::CastOp>(yield_op->getLoc(), type, value);
204             args.push_back(cast);
205           }
206         }
207       }
208       args.append(new_args.begin(), new_args.end());
209     } else {
210       args.append(yield_op->operand_begin(), yield_op->operand_end());
211     }
212     b.create<ReturnOp>(yield_op->getLoc(), args);
213     yield_op->erase();
214     symbol_table.insert(outlined_func);
215     outlined_func.setPrivate();
216     return outlined_func;
217   };
218 
219   // Replace region with call to outline function.
220   auto replace_with_call = [&](StringRef name, Region& region,
221                                bool passthru_extra_args) {
222     auto func = create_outline_func(name, region, passthru_extra_args);
223     OpBuilder b(region);
224     // The body of the region is empty/has been outlined into the function.
225     auto block = b.createBlock(&region);
226     SmallVector<Value, 4> new_operands;
227     new_operands.reserve(types.size());
228     for (Type t : llvm::makeArrayRef(types).drop_back(extern_values.size()))
229       new_operands.push_back(block->addArgument(t));
230     for (Value v : extern_values) new_operands.push_back(v);
231     auto call = b.create<CallOp>(while_op.getLoc(), func, new_operands);
232     b.create<YieldOp>(while_op.getLoc(), call.getResults());
233   };
234 
235   replace_with_call(GetName(while_op.getOperation(), "_cond"), while_op.cond(),
236                     false);
237   replace_with_call(GetName(while_op.getOperation(), "_body"), while_op.body(),
238                     true);
239 
240   // If there are extern values used then the result type of the while has to
241   // change, so replace with new while op.
242   if (extra_operands.empty()) return;
243 
244   const int operands_size = while_op.getNumOperands() + extra_operands.size();
245   SmallVector<Value, 4> operands;
246   operands.reserve(operands_size);
247   operands.append(while_op.getOperands().begin(), while_op.getOperands().end());
248   operands.append(extra_operands.begin(), extra_operands.end());
249   SmallVector<Type, 4> new_types;
250   new_types.reserve(operands_size);
251   new_types.append(while_op.getResultTypes().begin(),
252                    while_op.getResultTypes().end());
253   for (auto extra_operand : extra_operands)
254     new_types.push_back(extra_operand.getType());
255 
256   auto new_while_op = OpBuilder(while_op).create<WhileOp>(
257       while_op.getLoc(), new_types, operands, while_op.getAttrs());
258   new_while_op.cond().takeBody(while_op.cond());
259   new_while_op.body().takeBody(while_op.body());
260   while_op.replaceAllUsesWith(
261       new_while_op.getResults().take_front(while_op.getNumResults()));
262   while_op.erase();
263 }
264 
runOnOperation()265 void WhileOutlinePass::runOnOperation() {
266   getOperation().walk(
267       [&](mlir::TFL::WhileOp while_op) { OutlineWhile(while_op); });
268 }
269 }  // namespace
270 
271 // Creates an instance of the TensorFlow Lite dialect WhileOp outline pass.
CreateWhileOutlinePass()272 std::unique_ptr<OperationPass<ModuleOp>> CreateWhileOutlinePass() {
273   return std::make_unique<WhileOutlinePass>();
274 }
275 
276 static PassRegistration<WhileOutlinePass> pass(
277     "tfl-while-loop-outline", "Hoist while op regions into functions");
278 
279 }  // namespace TFL
280 }  // namespace mlir
281