/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <string>

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/SymbolTable.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/types.h"

namespace mlir {

namespace {

namespace cutil = TF::collection_ops_util;

// A pass that converts stack operations to tensor operations and read/assign
// ops on local variables. A later resource lifting pass can further remove the
// local variables.
//
// This pass requires that the full shape of the stack can be inferred: 1) the
// maximum size needs to be a constant and 2) a push op can be found with a
// known shape, and all push ops need to have the same shape.
//
// A stack creation op "tf.StackV2" will be turned into two zero-initialized
// variables, for the buffer and current size. Each push will be turned into
//
//   %old_val = "tf.ReadVariableOp"(%buffer)
//   %old_size = "tf.ReadVariableOp"(%size)
//   %offsets = "tf.ConcatV2"(%old_size, %other_dims_0s, %const0)
//   %new_val = "tf.XlaDynamicUpdateSlice"(%old_val, %push_val, %offsets)
//   "tf.AssignVariableOp"(%buffer, %new_val)
//   %new_size = "tf.AddV2"(%old_size, %const1)
//   "tf.AssignVariableOp"(%size, %new_size)
//
// and each pop will be turned into
//
//   %old_val = "tf.ReadVariableOp"(%buffer)
//   %old_size = "tf.ReadVariableOp"(%size)
//   %new_size = "tf.Sub"(%old_size, %const1)
//   %offsets = "tf.ConcatV2"(%new_size, %other_dims_0s, %const0)
//   %slice = "tf.Slice"(%old_val, %offsets, %slice_size_const)
//   %pop_result = "tf.Reshape"(%slice, %elem_size_const)
//   "tf.AssignVariableOp"(%size, %new_size)
//
// The pass also works across control flow and functional calls.
struct StackOpsDecompositionPass
    : public PassWrapper<StackOpsDecompositionPass, OperationPass<ModuleOp>> {
  void runOnOperation() override;
};

// Returns the type of the local variable for the stack size.
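// Assuming cutil::GetSizeType yields a 1-element i32 tensor (consistent with
// the R1 zero initialization in HandleStackV2Op below), the returned type is
// tensor<!tf.resource<tensor<1xi32>>>.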
Type GetSizeVarType(OpBuilder builder) {
  auto size_type = cutil::GetSizeType(builder);
  return RankedTensorType::get(
      {}, TF::ResourceType::get(ArrayRef<TensorType>{size_type},
                                builder.getContext()));
}

// Returns the aliasing argument number of a function return value if it simply
// forwards the argument. Otherwise, returns -1.
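// For example (an illustrative sketch), in
//   func @f(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
//     return %arg1 : tensor<f32>
//   }
// return value 0 simply forwards argument 1, so this returns 1.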
int64_t FindAliasedInput(FuncOp func, int64_t return_index) {
  Value return_val = func.front().getTerminator()->getOperand(return_index);
  auto maybe_arg = return_val.dyn_cast<BlockArgument>();
  if (!maybe_arg) return -1;
  return maybe_arg.getArgNumber();
}

// Changes the signature of a function that has stacks in its arguments. A
// stack argument will be turned into a variable type if arg_to_stack_type
// returns such a type, and a new argument will be added to the end of the
// argument list for the size variable.
//
// If stack_var_to_size_var is not nullptr, it will be used to store the
// mapping from the stack-variable argument to the size-variable argument.
//
// If handle_new_size_vars is provided, it will be invoked on the list of new
// size variables before finally changing the function type.
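// For example (an illustrative sketch; the actual types depend on the
// inferred stack shape), a function
//   func @f(%arg0: tensor<*x!tf.resource>) -> ...
// whose argument 0 is mapped to a stack buffer type becomes
//   func @f(%arg0: tensor<!tf.resource<tensor<10xf32>>>,
//           %arg1: tensor<!tf.resource<tensor<1xi32>>>) -> ...
// with %arg1 being the newly appended size variable.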
void ModifyFunctionSignature(
    FuncOp func, llvm::SmallDenseMap<Value, Value>* stack_var_to_size_var,
    llvm::function_ref<llvm::Optional<Type>(int64_t)> arg_to_stack_type,
    llvm::function_ref<void(ArrayRef<BlockArgument>)> handle_new_size_vars =
        nullptr) {
  auto new_input_types = llvm::to_vector<8>(func.getType().getInputs());
  auto size_var_type = GetSizeVarType(OpBuilder(func));
  int64_t original_arg_count = new_input_types.size();
  for (int64_t i = 0; i < original_arg_count; ++i) {
    auto stack_type = arg_to_stack_type(i);
    if (!stack_type.hasValue()) continue;
    func.getArgument(i).setType(*stack_type);
    new_input_types[i] = *stack_type;
    auto size_arg = func.front().addArgument(size_var_type);
    new_input_types.push_back(size_arg.getType());
    if (stack_var_to_size_var) {
      (*stack_var_to_size_var)[func.getArgument(i)] = size_arg;
    }
  }
  if (handle_new_size_vars) {
    handle_new_size_vars(func.getArguments().drop_front(original_arg_count));
  }
  func.setType(
      FunctionType::get(func.getContext(), new_input_types,
                        func.front().getTerminator()->getOperandTypes()));
}

// Contains cached information for decomposed callee functions for (stateful)
// partitioned call ops.
struct PartitionedCallStackOpsInfo {
  bool signature_change;
  FuncOp decomposed_callee;
  llvm::SmallDenseMap<int64_t, int64_t> stack_var_arg_to_size_arg;
};

LogicalResult DecomposeStackOpsInternal(
    Block*, ModuleOp, llvm::SmallDenseMap<Value, Value>*,
    llvm::StringMap<PartitionedCallStackOpsInfo>*);

// Handles stack usage by a tf.While. It converts the body and cond function
// signatures and performs stack ops decomposition on them.
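// For example (an illustrative sketch), if operand 0 of the tf.While is a
// decomposed stack buffer, the recreated op gains a trailing operand for the
// stack's size variable, and the body and cond signatures gain a matching
// trailing argument; the body also forwards that argument through its return.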
LogicalResult HandleWhileOp(
    TF::WhileOp while_op, ModuleOp module,
    const llvm::SmallDenseMap<Value, Value>& data_var_to_size_var,
    llvm::StringMap<PartitionedCallStackOpsInfo>*
        decomposed_partitioned_call_callees) {
  auto body = while_op.body_function();
  llvm::SmallDenseMap<Value, Value> body_map;
  auto find_arg_stack_type = [&](int64_t index) -> llvm::Optional<Type> {
    auto it = data_var_to_size_var.find(while_op.getOperand(index));
    if (it == data_var_to_size_var.end()) return llvm::None;
    return it->getFirst().getType();
  };
  auto add_size_vars_to_return = [&](ArrayRef<BlockArgument> new_args) {
    if (new_args.empty()) return;
    auto body_ret = body.front().getTerminator();
    auto new_body_returns = llvm::to_vector<8>(body_ret->getOperands());
    for (auto arg : new_args) new_body_returns.push_back(arg);
    OpBuilder(body_ret).create<ReturnOp>(body_ret->getLoc(), new_body_returns);
    body_ret->erase();
  };
  // Handle body.
  ModifyFunctionSignature(body, &body_map, find_arg_stack_type,
                          add_size_vars_to_return);
  const bool signature_change = !body_map.empty();
  if (failed(DecomposeStackOpsInternal(&body.front(), module, &body_map,
                                       decomposed_partitioned_call_callees))) {
    return failure();
  }
  // Cond should not change stacks in the arguments, so use an empty map.
  auto cond = while_op.cond_function();
  ModifyFunctionSignature(cond, nullptr, find_arg_stack_type);
  llvm::SmallDenseMap<Value, Value> empty_map;
  if (failed(DecomposeStackOpsInternal(&cond.front(), module, &empty_map,
                                       decomposed_partitioned_call_callees))) {
    return failure();
  }
  if (!signature_change) return success();
  // Create the new while op.
  auto new_while_operands = llvm::to_vector<8>(while_op.getOperands());
  OpBuilder builder(while_op);
  assert(while_op.getNumOperands() == while_op.getNumResults());
  for (int64_t i = 0; i < while_op.getNumResults(); ++i) {
    auto it = data_var_to_size_var.find(while_op.getOperand(i));
    if (it == data_var_to_size_var.end()) continue;
    new_while_operands.push_back(it->getSecond());
  }
  auto new_while =
      builder.create<TF::WhileOp>(while_op.getLoc(), body.getType().getInputs(),
                                  new_while_operands, while_op.getAttrs());
  for (int64_t i = 0; i < while_op.getNumResults(); ++i) {
    if (!getElementTypeOrSelf(while_op.getOperand(i).getType())
             .isa<TF::ResourceType>()) {
      continue;
    }
    int64_t aliased_input = FindAliasedInput(body, i);
    if (aliased_input == i) {
      // Replace aliased stack output uses with input.
      while_op.getResult(i).replaceAllUsesWith(while_op.getOperand(i));
    }
  }
  while_op.replaceAllUsesWith(
      new_while.getResults().take_front(while_op.getNumResults()));
  while_op.erase();
  return success();
}

// Handles stack usage by a tf.If. It converts the branch function signatures
// and performs stack ops decomposition on them.
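// Note that operand 0 of tf.If is the condition, so branch argument i
// corresponds to if_op operand i + 1 in the rewrites below.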
LogicalResult HandleIfOp(
    TF::IfOp if_op, ModuleOp module,
    const llvm::SmallDenseMap<Value, Value>& data_var_to_size_var,
    llvm::StringMap<PartitionedCallStackOpsInfo>*
        decomposed_partitioned_call_callees) {
  auto then_func = if_op.then_function();
  auto else_func = if_op.else_function();
  llvm::SmallDenseMap<Value, Value> then_map;
  llvm::SmallDenseMap<Value, Value> else_map;

  auto find_arg_stack_type = [&](int64_t index) -> llvm::Optional<Type> {
    auto it = data_var_to_size_var.find(if_op.getOperand(index + 1));
    if (it == data_var_to_size_var.end()) return llvm::None;
    return it->getFirst().getType();
  };
  ModifyFunctionSignature(then_func, &then_map, find_arg_stack_type);
  ModifyFunctionSignature(else_func, &else_map, find_arg_stack_type);
  const bool signature_change = !then_map.empty() || !else_map.empty();
  if (failed(DecomposeStackOpsInternal(&then_func.front(), module, &then_map,
                                       decomposed_partitioned_call_callees)) ||
      failed(DecomposeStackOpsInternal(&else_func.front(), module, &else_map,
                                       decomposed_partitioned_call_callees))) {
    return failure();
  }
  if (!signature_change) return success();
  auto new_if_operands = llvm::to_vector<8>(if_op.getOperands());
  for (auto operand : if_op.getOperands()) {
    auto it = data_var_to_size_var.find(operand);
    if (it == data_var_to_size_var.end()) continue;
    new_if_operands.push_back(it->getSecond());
  }
  auto new_if = OpBuilder(if_op).create<TF::IfOp>(
      if_op.getLoc(), then_func.getType().getResults(), new_if_operands,
      if_op.getAttrs());
  for (auto result : if_op.getResults()) {
    if (!getElementTypeOrSelf(result.getType()).isa<TF::ResourceType>()) {
      continue;
    }
    int64_t then_aliased_input =
        FindAliasedInput(then_func, result.getResultNumber());
    int64_t else_aliased_input =
        FindAliasedInput(else_func, result.getResultNumber());
    if (then_aliased_input >= 0 && then_aliased_input == else_aliased_input) {
      // Replace aliased stack output uses with input.
      result.replaceAllUsesWith(if_op.getOperand(then_aliased_input + 1));
    }
  }
  if_op.replaceAllUsesWith(new_if);
  if_op.erase();
  return success();
}

// Handles stack usage by a tf.StatefulPartitionedCall or a tf.PartitionedCall.
// It first checks whether the callee was previously handled and reuses that
// result if so. Otherwise, it clones and converts the callee function, and
// performs stack ops decomposition on it.
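// For example (an illustrative sketch), a decomposed copy of a non-private
// callee @f is inserted as a private function named "f_stack_decomposed", and
// the caller is recreated with the size-variable operands appended.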
template <typename CallOp>
LogicalResult HandlePartitionedCallOp(
    CallOp call, FuncOp callee, ModuleOp module,
    const llvm::SmallDenseMap<Value, Value>& data_var_to_size_var,
    llvm::StringMap<PartitionedCallStackOpsInfo>*
        decomposed_partitioned_call_callees) {
  auto emplace_res = decomposed_partitioned_call_callees->try_emplace(
      callee.getName(), PartitionedCallStackOpsInfo());
  auto& info = emplace_res.first->second;
  // Recreate the call op with info.
  auto recreate_caller = [&] {
    auto new_operands = llvm::to_vector<8>(call.getOperands());
    for (int64_t i = 0; i < call.getNumOperands(); ++i) {
      auto arg_it = info.stack_var_arg_to_size_arg.find(i);
      if (arg_it == info.stack_var_arg_to_size_arg.end()) continue;
      auto it = data_var_to_size_var.find(call.getOperand(i));
      if (it == data_var_to_size_var.end()) {
        call.emitOpError("unknown stack");
        return failure();
      }
      assert(arg_it->second == new_operands.size());
      new_operands.push_back(it->getSecond());
    }
    OpBuilder builder(call);
    auto new_call = builder.create<CallOp>(
        call.getLoc(), info.decomposed_callee.getType().getResults(),
        new_operands, call.getAttrs());
    new_call->setAttr(
        "f", builder.getSymbolRefAttr(
                 const_cast<FuncOp&>(info.decomposed_callee).getName()));
    for (int64_t i = 0; i < call.getNumResults(); ++i) {
      auto result = call.getResult(i);
      if (!getElementTypeOrSelf(result.getType())
               .template isa<TF::ResourceType>()) {
        continue;
      }
      int64_t aliased_input = FindAliasedInput(info.decomposed_callee, i);
      if (aliased_input >= 0) {
        // Replace aliased stack output uses with input.
        result.replaceAllUsesWith(call.getOperand(aliased_input));
      }
    }
    call.replaceAllUsesWith(new_call);
    call.erase();
    return success();
  };
  if (!emplace_res.second) {
    // This callee was handled before.
    if (!info.signature_change) return success();
    return recreate_caller();
  }
  llvm::SmallDenseMap<Value, Value> callee_map;
  FuncOp lowered_callee = callee;
  if (!callee.isPrivate()) {
    // Clone non-private callee in case of signature change.
    lowered_callee = callee.clone();
    lowered_callee.setPrivate();
  }
  auto find_arg_stack_type = [&](int64_t index) -> llvm::Optional<Type> {
    auto it = data_var_to_size_var.find(call.getOperand(index));
    if (it == data_var_to_size_var.end()) return llvm::None;
    return it->getFirst().getType();
  };
  ModifyFunctionSignature(lowered_callee, &callee_map, find_arg_stack_type);
  info.signature_change = !callee_map.empty();
  if (!info.signature_change) {
    // Signature is not modified. We do not need the clone.
    if (lowered_callee != callee) {
      lowered_callee.erase();
    }
  } else {
    info.decomposed_callee = lowered_callee;
    for (auto& entry : callee_map) {
      info.stack_var_arg_to_size_arg
          [entry.getFirst().cast<BlockArgument>().getArgNumber()] =
          entry.getSecond().cast<BlockArgument>().getArgNumber();
    }
    if (lowered_callee != callee) {
      // Add the clone with a new name.
      lowered_callee.setName(
          llvm::formatv("{0}_stack_decomposed", callee.getName()).str());
      SymbolTable(module).insert(lowered_callee);
      callee = lowered_callee;
    }
  }
  if (failed(DecomposeStackOpsInternal(&callee.front(), module, &callee_map,
                                       decomposed_partitioned_call_callees))) {
    return failure();
  }
  if (info.signature_change) return recreate_caller();
  return success();
}

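// Handles a tf.StackV2 op: replaces the stack with a zero-initialized buffer
// local variable and a size local variable. For example (an illustrative
// sketch), with a constant max size of 10 and an inferred element type of
// tensor<2xf32>, the buffer variable holds a tensor<10x2xf32> and the size
// variable holds a 1-element i32 tensor initialized to zero.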
LogicalResult HandleStackV2Op(
    TF::StackV2Op stack, ModuleOp module,
    llvm::SmallDenseMap<Value, Value>* data_var_to_size_var) {
  // Create a buffer variable and a size variable to replace the stack.
  auto elem_type = cutil::GetElementTypeFromAccess(
      stack.handle(), module, [](Operation* user) -> llvm::Optional<Type> {
        auto push = llvm::dyn_cast<TF::StackPushV2Op>(user);
        if (!push) return llvm::None;
        return push.elem().getType();
      });
  if (!elem_type.hasValue()) {
    return stack.emitOpError("cannot infer element shape of stack");
  }
  OpBuilder builder(stack);
  Value buffer;
  if (failed(cutil::CreateInitBufferValue(
          elem_type->getShape(), stack.max_size(), stack,
          elem_type->getElementType(), builder, &buffer))) {
    return failure();
  }
  auto size_var_type = GetSizeVarType(builder);
  auto var_type = RankedTensorType::get(
      {}, TF::ResourceType::get(
              ArrayRef<TensorType>{buffer.getType().cast<TensorType>()},
              stack.getContext()));
  auto local_var = builder.create<TF::MlirLocalVarOp>(
      stack.getLoc(), ArrayRef<Type>{var_type}, ArrayRef<Value>{});
  auto local_size_var = builder.create<TF::MlirLocalVarOp>(
      stack.getLoc(), ArrayRef<Type>{size_var_type}, ArrayRef<Value>{});
  // Zero-initialize the local vars.
  cutil::WriteLocalVariable(local_size_var,
                            cutil::GetR1Const({0LL}, builder, stack.getLoc()),
                            builder, stack.getLoc());
  cutil::WriteLocalVariable(local_var, buffer, builder, stack.getLoc());
  stack.handle().replaceAllUsesWith(local_var);
  (*data_var_to_size_var)[local_var] = local_size_var;
  stack.erase();
  return success();
}

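// Handles a tf.StackPushV2 op: reads the buffer and size variables, writes the
// pushed element at the current size index, and increments the size by one.
// The push result simply forwards the pushed element.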
LogicalResult HandleStackPushV2Op(
    TF::StackPushV2Op push,
    llvm::SmallDenseMap<Value, Value>* data_var_to_size_var) {
  auto it = data_var_to_size_var->find(push.handle());
  if (it == data_var_to_size_var->end()) {
    return push.emitOpError("unknown stack");
  }
  // The push output simply forwards the input element.
  push.replaceAllUsesWith(push.elem());
  OpBuilder builder(push);
  // Read the current buffer and size.
  auto stack_val =
      cutil::ReadLocalVariable(push.handle(), builder, push.getLoc());
  auto index =
      cutil::ReadLocalVariable(it->getSecond(), builder, push.getLoc());
  stack_val =
      cutil::SetElement(index, stack_val, push.elem(), builder, push.getLoc());
  // Assign the new buffer and size.
  cutil::WriteLocalVariable(push.handle(), stack_val, builder, push.getLoc());
  index = builder.create<TF::AddV2Op>(
      push.getLoc(), ArrayRef<Type>{index.getType()},
      ArrayRef<Value>{index, cutil::GetR1Const({1}, builder, push.getLoc())});
  cutil::WriteLocalVariable(it->getSecond(), index, builder, push.getLoc());
  push.erase();
  return success();
}

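// Handles a tf.StackPopV2 op: decrements the size variable by one, reads the
// element at the new size index from the buffer, and replaces all uses of the
// pop result with that element.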
LogicalResult HandleStackPopV2Op(
    TF::StackPopV2Op pop,
    llvm::SmallDenseMap<Value, Value>* data_var_to_size_var) {
  auto it = data_var_to_size_var->find(pop.handle());
  if (it == data_var_to_size_var->end()) {
    return pop.emitOpError("unknown stack");
  }
  OpBuilder builder(pop);
  // Read the current buffer and size.
  auto stack_val =
      cutil::ReadLocalVariable(pop.handle(), builder, pop.getLoc());
  auto size = cutil::ReadLocalVariable(it->getSecond(), builder, pop.getLoc());
  auto new_size = builder.create<TF::SubOp>(
      pop.getLoc(), ArrayRef<Type>{size.getType()},
      ArrayRef<Value>{size, cutil::GetR1Const({1}, builder, pop.getLoc())});
  auto pop_val = cutil::GetElement(new_size, stack_val, builder, pop.getLoc());
  pop.replaceAllUsesWith(pop_val);
  // Update the size.
  cutil::WriteLocalVariable(it->getSecond(), new_size, builder, pop.getLoc());
  pop.erase();
  return success();
}

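// Handles region-based control flow ops (tf.WhileRegion, tf.IfRegion,
// tf.CaseRegion): verifies that no resource-type operands or results remain,
// then decomposes stack ops within each region.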
LogicalResult HandleRegionControlFlowOps(
    Operation& op, ModuleOp module,
    llvm::SmallDenseMap<Value, Value>* data_var_to_size_var,
    llvm::StringMap<PartitionedCallStackOpsInfo>*
        decomposed_partitioned_call_callees) {
  for (OpOperand& operand : op.getOpOperands()) {
    if (getElementTypeOrSelf(operand.get().getType()).isa<TF::ResourceType>()) {
      return op.emitOpError()
             << "found unexpected type " << operand.get().getType()
             << " of operand #" << operand.getOperandNumber()
             << ", resource type operands are expected to have been "
                "canonicalized away for region based control flow ops";
    }
  }
  for (OpResult result : op.getResults()) {
    if (getElementTypeOrSelf(result.getType()).isa<TF::ResourceType>()) {
      return op.emitOpError()
             << "found unexpected type " << result.getType() << " of result #"
             << result.getResultNumber()
             << ", resource type results are expected to have been "
                "canonicalized away for region based control flow ops";
    }
  }
  for (Region& region : op.getRegions()) {
    if (failed(DecomposeStackOpsInternal(&region.front(), module,
                                         data_var_to_size_var,
                                         decomposed_partitioned_call_callees)))
      return failure();
  }
  return success();
}

// Decomposes stack ops in a block and recursively decomposes called functions.
//   data_var_to_size_var: a mapping from stacks' buffer local variables to
//     size local variables.
//   decomposed_partitioned_call_callees: cache for partitioned call ops'
//     callee function handling.
LogicalResult DecomposeStackOpsInternal(
    Block* block, ModuleOp module,
    llvm::SmallDenseMap<Value, Value>* data_var_to_size_var,
    llvm::StringMap<PartitionedCallStackOpsInfo>*
        decomposed_partitioned_call_callees) {
  for (auto& op : llvm::make_early_inc_range(block->getOperations())) {
    if (llvm::isa<TF::IdentityOp, TF::IdentityNOp>(&op)) {
      // Removes identity nodes in the block. The device computation does not
      // need such nodes to carry information.
      op.replaceAllUsesWith(op.getOperands());
      op.erase();
    } else if (auto stack = llvm::dyn_cast<TF::StackV2Op>(&op)) {
      if (failed(HandleStackV2Op(stack, module, data_var_to_size_var))) {
        return failure();
      }
    } else if (auto push = llvm::dyn_cast<TF::StackPushV2Op>(&op)) {
      if (failed(HandleStackPushV2Op(push, data_var_to_size_var))) {
        return failure();
      }
    } else if (auto pop = llvm::dyn_cast<TF::StackPopV2Op>(&op)) {
      if (failed(HandleStackPopV2Op(pop, data_var_to_size_var))) {
        return failure();
      }
    } else if (auto close = llvm::dyn_cast<TF::StackCloseV2Op>(&op)) {
      data_var_to_size_var->erase(close.handle());
      close.erase();
    } else if (auto while_op = llvm::dyn_cast<TF::WhileOp>(&op)) {
      if (failed(HandleWhileOp(while_op, module, *data_var_to_size_var,
                               decomposed_partitioned_call_callees))) {
        return failure();
      }
    } else if (auto if_op = llvm::dyn_cast<TF::IfOp>(&op)) {
      if (failed(HandleIfOp(if_op, module, *data_var_to_size_var,
                            decomposed_partitioned_call_callees))) {
        return failure();
      }
    } else if (llvm::isa<TF::WhileRegionOp>(op) ||
               llvm::isa<TF::IfRegionOp>(op) ||
               llvm::isa<TF::CaseRegionOp>(op)) {
      if (failed(
              HandleRegionControlFlowOps(op, module, data_var_to_size_var,
                                         decomposed_partitioned_call_callees)))
        return failure();
    } else if (auto pcall = llvm::dyn_cast<TF::PartitionedCallOp>(&op)) {
      if (!pcall.func()) {
        return pcall.emitOpError(
            "stack decomposition does not support call with nested references");
      }
      if (failed(HandlePartitionedCallOp(
              pcall, pcall.func(), module, *data_var_to_size_var,
              decomposed_partitioned_call_callees))) {
        return failure();
      }
    } else if (auto spcall =
                   llvm::dyn_cast<TF::StatefulPartitionedCallOp>(&op)) {
      if (failed(HandlePartitionedCallOp(
              spcall, spcall.func(), module, *data_var_to_size_var,
              decomposed_partitioned_call_callees))) {
        return failure();
      }
    }
  }
  return success();
}

LogicalResult DecomposeStackOps(Block* block, ModuleOp module) {
  llvm::SmallDenseMap<Value, Value> data_var_to_size_var;
  llvm::StringMap<PartitionedCallStackOpsInfo>
      decomposed_partitioned_call_callees;
  return DecomposeStackOpsInternal(block, module, &data_var_to_size_var,
                                   &decomposed_partitioned_call_callees);
}

void StackOpsDecompositionPass::runOnOperation() {
  auto module = getOperation();
  auto main = module.lookupSymbol<FuncOp>("main");
  if (!main) return;
  if (failed(DecomposeStackOps(&main.front(), module))) {
    signalPassFailure();
  }
}

static PassRegistration<StackOpsDecompositionPass> pass(
    "tf-stack-ops-decomposition",
    "Decompose stack operations into local variable operations. Needs static "
    "shapes.");

}  // namespace

namespace TF {
std::unique_ptr<OperationPass<ModuleOp>> CreateStackOpsDecompositionPass() {
  return std::make_unique<StackOpsDecompositionPass>();
}

}  // namespace TF
}  // namespace mlir