/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <string>

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/SymbolTable.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/types.h"

namespace mlir {

namespace {

namespace cutil = TF::collection_ops_util;

// A pass that converts stack operations to tensor operations and read/assign
// ops on local variables. A later resource lifting pass can further remove the
// local variables.
//
// This pass requires that the full shape of the stack can be inferred: 1) the
// maximum size needs to be a constant and 2) a push op can be found with a
// known shape, and all push ops need to have the same shape.
//
// A stack creation op "tf.StackV2" will be turned into two zero-initialized
// variables, for the buffer and current size. Each push will be turned into
//   %old_val = "tf.ReadVariableOp"(%buffer)
//   %old_size = "tf.ReadVariableOp"(%size)
//   %offsets = "tf.ConcatV2"(%old_size, %other_dims_0s, %const0)
//   %new_val = "tf.XlaDynamicUpdateSlice"(%old_val, %push_val, %offsets)
//   "tf.AssignVariableOp"(%buffer, %new_val)
//   %new_size = "tf.AddV2"(%old_size, %const1)
//   "tf.AssignVariableOp"(%size, %new_size)
//
// and each pop will be turned into
//
//   %old_val = "tf.ReadVariableOp"(%buffer)
//   %old_size = "tf.ReadVariableOp"(%size)
//   %new_size = "tf.Sub"(%old_size, %const1)
//   %offsets = "tf.ConcatV2"(%new_size, %other_dims_0s, %const0)
//   %slice = "tf.Slice"(%old_val, %offsets, %slice_size_const)
//   %pop_result = "tf.Reshape"(%slice, %elem_size_const)
//   "tf.AssignVariableOp"(%size, %new_size)
//
// The pass also works across control flow and functional calls.
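//
// For example (a sketch rather than exact IR; the variable types depend on
// the inferred element shape and the constant maximum size), a stack creation
//
//   %handle = "tf.StackV2"(%max_size)
//
// becomes roughly
//
//   %buffer = "tf.MlirLocalVarOp"()
//   %size = "tf.MlirLocalVarOp"()
//   "tf.AssignVariableOp"(%buffer, %zeros_buffer)
//   "tf.AssignVariableOp"(%size, %zero_size)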
struct StackOpsDecompositionPass
    : public PassWrapper<StackOpsDecompositionPass, OperationPass<ModuleOp>> {
  void runOnOperation() override;
};

// Returns the type of the local variable for the stack size.
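// Assuming cutil::GetSizeType yields a 1-D i64 tensor (consistent with the
// 0LL initializer used below), this is roughly
// tensor<!tf.resource<tensor<1xi64>>>.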
Type GetSizeVarType(OpBuilder builder) {
  auto size_type = cutil::GetSizeType(builder);
  return RankedTensorType::get(
      {}, TF::ResourceType::get(ArrayRef<TensorType>{size_type},
                                builder.getContext()));
}

// Returns the aliasing argument number of a function return value if it simply
// forwards the argument. Otherwise, returns -1.
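// For example, if the function body is simply "return %arg1", return index 0
// aliases argument number 1.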
int64_t FindAliasedInput(FuncOp func, int64_t return_index) {
  Value return_val = func.front().getTerminator()->getOperand(return_index);
  auto maybe_arg = return_val.dyn_cast<BlockArgument>();
  if (!maybe_arg) return -1;
  return maybe_arg.getArgNumber();
}

// Changes the function signature that has stacks in the arguments. A stack
// argument will be turned into a variable type if arg_to_stack_type returns
// such a type, and a new argument will be added to the end of the argument
// list for the size variable.
//
// If stack_var_to_size_var is not nullptr, it will be used to store the
// mapping from the stack-variable argument to the size-variable argument.
//
// If handle_new_size_vars is provided, it will be invoked on the list of new
// size variables before finally changing the function type.
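//
// For example, a signature (stack, f32) would afterwards be
// (buffer_var, f32, size_var), with one size variable appended per converted
// stack argument.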
void ModifyFunctionSignature(
    FuncOp func, llvm::SmallDenseMap<Value, Value>* stack_var_to_size_var,
    llvm::function_ref<llvm::Optional<Type>(int64_t)> arg_to_stack_type,
    llvm::function_ref<void(ArrayRef<BlockArgument>)> handle_new_size_vars =
        nullptr) {
  auto new_input_types = llvm::to_vector<8>(func.getType().getInputs());
  auto size_var_type = GetSizeVarType(OpBuilder(func));
  int64_t original_arg_count = new_input_types.size();
  for (int64_t i = 0; i < original_arg_count; ++i) {
    auto stack_type = arg_to_stack_type(i);
    if (!stack_type.hasValue()) continue;
    func.getArgument(i).setType(*stack_type);
    new_input_types[i] = *stack_type;
    auto size_arg = func.front().addArgument(size_var_type);
    new_input_types.push_back(size_arg.getType());
    if (stack_var_to_size_var) {
      (*stack_var_to_size_var)[func.getArgument(i)] = size_arg;
    }
  }
  if (handle_new_size_vars) {
    handle_new_size_vars(func.getArguments().drop_front(original_arg_count));
  }
  func.setType(
      FunctionType::get(func.getContext(), new_input_types,
                        func.front().getTerminator()->getOperandTypes()));
}

// Contains cached information for decomposed callee functions for (stateful)
// partitioned call ops.
struct PartitionedCallStackOpsInfo {
  bool signature_change;
  FuncOp decomposed_callee;
  llvm::SmallDenseMap<int64_t, int64_t> stack_var_arg_to_size_arg;
};

LogicalResult DecomposeStackOpsInternal(
    Block*, ModuleOp, llvm::SmallDenseMap<Value, Value>*,
    llvm::StringMap<PartitionedCallStackOpsInfo>*);

// Handles stack usage by a tf.While. It converts the body and cond function
// signatures and performs stack ops decomposition on them.
LogicalResult HandleWhileOp(
    TF::WhileOp while_op, ModuleOp module,
    const llvm::SmallDenseMap<Value, Value>& data_var_to_size_var,
    llvm::StringMap<PartitionedCallStackOpsInfo>*
        decomposed_partitioned_call_callees) {
  auto body = while_op.body_function();
  llvm::SmallDenseMap<Value, Value> body_map;
  auto find_arg_stack_type = [&](int64_t index) -> llvm::Optional<Type> {
    auto it = data_var_to_size_var.find(while_op.getOperand(index));
    if (it == data_var_to_size_var.end()) return llvm::None;
    return it->getFirst().getType();
  };
  auto add_size_vars_to_return = [&](ArrayRef<BlockArgument> new_args) {
    if (new_args.empty()) return;
    auto body_ret = body.front().getTerminator();
    auto new_body_returns = llvm::to_vector<8>(body_ret->getOperands());
    for (auto arg : new_args) new_body_returns.push_back(arg);
    OpBuilder(body_ret).create<ReturnOp>(body_ret->getLoc(), new_body_returns);
    body_ret->erase();
  };
  // Handle body.
  ModifyFunctionSignature(body, &body_map, find_arg_stack_type,
                          add_size_vars_to_return);
  const bool signature_change = !body_map.empty();
  if (failed(DecomposeStackOpsInternal(&body.front(), module, &body_map,
                                       decomposed_partitioned_call_callees))) {
    return failure();
  }
  // Cond should not change stacks in the arguments, so use an empty map.
  auto cond = while_op.cond_function();
  ModifyFunctionSignature(cond, nullptr, find_arg_stack_type);
  llvm::SmallDenseMap<Value, Value> empty_map;
  if (failed(DecomposeStackOpsInternal(&cond.front(), module, &empty_map,
                                       decomposed_partitioned_call_callees))) {
    return failure();
  }
  if (!signature_change) return success();
  // Create the new while op.
  auto new_while_operands = llvm::to_vector<8>(while_op.getOperands());
  OpBuilder builder(while_op);
  assert(while_op.getNumOperands() == while_op.getNumResults());
  for (int64_t i = 0; i < while_op.getNumResults(); ++i) {
    auto it = data_var_to_size_var.find(while_op.getOperand(i));
    if (it == data_var_to_size_var.end()) continue;
    new_while_operands.push_back(it->getSecond());
  }
  auto new_while =
      builder.create<TF::WhileOp>(while_op.getLoc(), body.getType().getInputs(),
                                  new_while_operands, while_op.getAttrs());
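  // If the body forwards a resource argument straight through to the matching
  // return value, the old while result aliases its operand; rewire those uses
  // to the operand directly.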
  for (int64_t i = 0; i < while_op.getNumResults(); ++i) {
    if (!getElementTypeOrSelf(while_op.getOperand(i).getType())
             .isa<TF::ResourceType>()) {
      continue;
    }
    int64_t aliased_input = FindAliasedInput(body, i);
    if (aliased_input == i) {
      // Replace aliased stack output uses with input.
      while_op.getResult(i).replaceAllUsesWith(while_op.getOperand(i));
    }
  }
  while_op.replaceAllUsesWith(
      new_while.getResults().take_front(while_op.getNumResults()));
  while_op.erase();
  return success();
}

// Handles stack usage by a tf.If. It converts the branch function signatures
// and performs stack ops decomposition on them.
LogicalResult HandleIfOp(
    TF::IfOp if_op, ModuleOp module,
    const llvm::SmallDenseMap<Value, Value>& data_var_to_size_var,
    llvm::StringMap<PartitionedCallStackOpsInfo>*
        decomposed_partitioned_call_callees) {
  auto then_func = if_op.then_function();
  auto else_func = if_op.else_function();
  llvm::SmallDenseMap<Value, Value> then_map;
  llvm::SmallDenseMap<Value, Value> else_map;

  auto find_arg_stack_type = [&](int64_t index) -> llvm::Optional<Type> {
    auto it = data_var_to_size_var.find(if_op.getOperand(index + 1));
    if (it == data_var_to_size_var.end()) return llvm::None;
    return it->getFirst().getType();
  };
  ModifyFunctionSignature(then_func, &then_map, find_arg_stack_type);
  ModifyFunctionSignature(else_func, &else_map, find_arg_stack_type);
  const bool signature_change = !then_map.empty() || !else_map.empty();
  if (failed(DecomposeStackOpsInternal(&then_func.front(), module, &then_map,
                                       decomposed_partitioned_call_callees)) ||
      failed(DecomposeStackOpsInternal(&else_func.front(), module, &else_map,
                                       decomposed_partitioned_call_callees))) {
    return failure();
  }
  if (!signature_change) return success();
  auto new_if_operands = llvm::to_vector<8>(if_op.getOperands());
  for (auto operand : if_op.getOperands()) {
    auto it = data_var_to_size_var.find(operand);
    if (it == data_var_to_size_var.end()) continue;
    new_if_operands.push_back(it->getSecond());
  }
  auto new_if = OpBuilder(if_op).create<TF::IfOp>(
      if_op.getLoc(), then_func.getType().getResults(), new_if_operands,
      if_op.getAttrs());
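  // Operand #0 of tf.If is the condition, so branch argument i corresponds to
  // if_op operand i + 1.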
  for (auto result : if_op.getResults()) {
    if (!getElementTypeOrSelf(result.getType()).isa<TF::ResourceType>()) {
      continue;
    }
    int64_t then_aliased_input =
        FindAliasedInput(then_func, result.getResultNumber());
    int64_t else_aliased_input =
        FindAliasedInput(else_func, result.getResultNumber());
    if (then_aliased_input >= 0 && then_aliased_input == else_aliased_input) {
      // Replace aliased stack output uses with input.
      result.replaceAllUsesWith(if_op.getOperand(then_aliased_input + 1));
    }
  }
  if_op.replaceAllUsesWith(new_if);
  if_op.erase();
  return success();
}

// Handles stack usage by a tf.StatefulPartitionedCall or a tf.PartitionedCall.
// It will first check if the callee was previously handled, and try to reuse
// that result if so. Otherwise, it will clone and convert the callee function
// and perform stack ops decomposition on it.
template <typename CallOp>
LogicalResult HandlePartitionedCallOp(
    CallOp call, FuncOp callee, ModuleOp module,
    const llvm::SmallDenseMap<Value, Value>& data_var_to_size_var,
    llvm::StringMap<PartitionedCallStackOpsInfo>*
        decomposed_partitioned_call_callees) {
  auto emplace_res = decomposed_partitioned_call_callees->try_emplace(
      callee.getName(), PartitionedCallStackOpsInfo());
  auto& info = emplace_res.first->second;
  // Recreate the call op with info.
  auto recreate_caller = [&] {
    auto new_operands = llvm::to_vector<8>(call.getOperands());
    for (int64_t i = 0; i < call.getNumOperands(); ++i) {
      auto arg_it = info.stack_var_arg_to_size_arg.find(i);
      if (arg_it == info.stack_var_arg_to_size_arg.end()) continue;
      auto it = data_var_to_size_var.find(call.getOperand(i));
      if (it == data_var_to_size_var.end()) {
        call.emitOpError("unknown stack");
        return failure();
      }
      assert(arg_it->second == new_operands.size());
      new_operands.push_back(it->getSecond());
    }
    OpBuilder builder(call);
    auto new_call = builder.create<CallOp>(
        call.getLoc(), info.decomposed_callee.getType().getResults(),
        new_operands, call.getAttrs());
    new_call->setAttr(
        "f", builder.getSymbolRefAttr(
                 const_cast<FuncOp&>(info.decomposed_callee).getName()));
    for (int64_t i = 0; i < call.getNumResults(); ++i) {
      auto result = call.getResult(i);
      if (!getElementTypeOrSelf(result.getType())
               .template isa<TF::ResourceType>()) {
        continue;
      }
      int64_t aliased_input = FindAliasedInput(info.decomposed_callee, i);
      if (aliased_input >= 0) {
        // Replace aliased stack output uses with input.
        result.replaceAllUsesWith(call.getOperand(aliased_input));
      }
    }
    call.replaceAllUsesWith(new_call);
    call.erase();
    return success();
  };
  if (!emplace_res.second) {
    // This callee was handled before.
    if (!info.signature_change) return success();
    return recreate_caller();
  }
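  // First time handling this callee: rewrite the signature of a (possibly
  // cloned) copy, decompose its body, and cache the result for later callers.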
  llvm::SmallDenseMap<Value, Value> callee_map;
  FuncOp lowered_callee = callee;
  if (!callee.isPrivate()) {
    // Clone non-private callee in case of signature change.
    lowered_callee = callee.clone();
    lowered_callee.setPrivate();
  }
  auto find_arg_stack_type = [&](int64_t index) -> llvm::Optional<Type> {
    auto it = data_var_to_size_var.find(call.getOperand(index));
    if (it == data_var_to_size_var.end()) return llvm::None;
    return it->getFirst().getType();
  };
  ModifyFunctionSignature(lowered_callee, &callee_map, find_arg_stack_type);
  info.signature_change = !callee_map.empty();
  if (!info.signature_change) {
    // Signature is not modified. We do not need the clone.
    if (lowered_callee != callee) {
      lowered_callee.erase();
    }
  } else {
    info.decomposed_callee = lowered_callee;
    for (auto& entry : callee_map) {
      info.stack_var_arg_to_size_arg
          [entry.getFirst().cast<BlockArgument>().getArgNumber()] =
          entry.getSecond().cast<BlockArgument>().getArgNumber();
    }
    if (lowered_callee != callee) {
      // Add the clone with a new name.
      lowered_callee.setName(
          llvm::formatv("{0}_stack_decomposed", callee.getName()).str());
      SymbolTable(module).insert(lowered_callee);
      callee = lowered_callee;
    }
  }
  if (failed(DecomposeStackOpsInternal(&callee.front(), module, &callee_map,
                                       decomposed_partitioned_call_callees))) {
    return failure();
  }
  if (info.signature_change) return recreate_caller();
  return success();
}

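// Handles a tf.StackV2 op: replaces the stack handle with a zero-initialized
// buffer local variable paired with a size local variable.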
LogicalResult HandleStackV2Op(
    TF::StackV2Op stack, ModuleOp module,
    llvm::SmallDenseMap<Value, Value>* data_var_to_size_var) {
  // Create a buffer variable and a size variable to replace the stack.
  auto elem_type = cutil::GetElementTypeFromAccess(
      stack.handle(), module, [](Operation* user) -> llvm::Optional<Type> {
        auto push = llvm::dyn_cast<TF::StackPushV2Op>(user);
        if (!push) return llvm::None;
        return push.elem().getType();
      });
  if (!elem_type.hasValue()) {
    return stack.emitOpError("cannot infer element shape of stack");
  }
  OpBuilder builder(stack);
  Value buffer;
  if (failed(cutil::CreateInitBufferValue(
          elem_type->getShape(), stack.max_size(), stack,
          elem_type->getElementType(), builder, &buffer))) {
    return failure();
  }
  auto size_var_type = GetSizeVarType(builder);
  auto var_type = RankedTensorType::get(
      {}, TF::ResourceType::get(
              ArrayRef<TensorType>{buffer.getType().cast<TensorType>()},
              stack.getContext()));
  auto local_var = builder.create<TF::MlirLocalVarOp>(
      stack.getLoc(), ArrayRef<Type>{var_type}, ArrayRef<Value>{});
  auto local_size_var = builder.create<TF::MlirLocalVarOp>(
      stack.getLoc(), ArrayRef<Type>{size_var_type}, ArrayRef<Value>{});
  // Zero-initialize the local vars.
  cutil::WriteLocalVariable(local_size_var,
                            cutil::GetR1Const({0LL}, builder, stack.getLoc()),
                            builder, stack.getLoc());
  cutil::WriteLocalVariable(local_var, buffer, builder, stack.getLoc());
  stack.handle().replaceAllUsesWith(local_var);
  (*data_var_to_size_var)[local_var] = local_size_var;
  stack.erase();
  return success();
}

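// Handles a tf.StackPushV2 op: writes the element into the buffer at the
// current size index and bumps the size variable.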
LogicalResult HandleStackPushV2Op(
    TF::StackPushV2Op push,
    llvm::SmallDenseMap<Value, Value>* data_var_to_size_var) {
  auto it = data_var_to_size_var->find(push.handle());
  if (it == data_var_to_size_var->end()) {
    return push.emitOpError("unknown stack");
  }
  // Push output simply forwards the input element.
  push.replaceAllUsesWith(push.elem());
  OpBuilder builder(push);
  // Read the current buffer and size.
  auto stack_val =
      cutil::ReadLocalVariable(push.handle(), builder, push.getLoc());
  auto index =
      cutil::ReadLocalVariable(it->getSecond(), builder, push.getLoc());
  stack_val =
      cutil::SetElement(index, stack_val, push.elem(), builder, push.getLoc());
  // Assign the new buffer and size.
  cutil::WriteLocalVariable(push.handle(), stack_val, builder, push.getLoc());
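  // Increment the size by one.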
  index = builder.create<TF::AddV2Op>(
      push.getLoc(), ArrayRef<Type>{index.getType()},
      ArrayRef<Value>{index, cutil::GetR1Const({1}, builder, push.getLoc())});
  cutil::WriteLocalVariable(it->getSecond(), index, builder, push.getLoc());
  push.erase();
  return success();
}

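// Handles a tf.StackPopV2 op: decrements the size and reads back the top
// element from the buffer.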
LogicalResult HandleStackPopV2Op(
    TF::StackPopV2Op pop,
    llvm::SmallDenseMap<Value, Value>* data_var_to_size_var) {
  auto it = data_var_to_size_var->find(pop.handle());
  if (it == data_var_to_size_var->end()) {
    return pop.emitOpError("unknown stack");
  }
  OpBuilder builder(pop);
  // Read the current buffer and size.
  auto stack_val =
      cutil::ReadLocalVariable(pop.handle(), builder, pop.getLoc());
  auto size = cutil::ReadLocalVariable(it->getSecond(), builder, pop.getLoc());
  auto new_size = builder.create<TF::SubOp>(
      pop.getLoc(), ArrayRef<Type>{size.getType()},
      ArrayRef<Value>{size, cutil::GetR1Const({1}, builder, pop.getLoc())});
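  // The top element lives at index size - 1, which is exactly new_size.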
  auto pop_val = cutil::GetElement(new_size, stack_val, builder, pop.getLoc());
  pop.replaceAllUsesWith(pop_val);
  // Update the size.
  cutil::WriteLocalVariable(it->getSecond(), new_size, builder, pop.getLoc());
  pop.erase();
  return success();
}

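// Checks that region-based control flow ops carry no resource operands or
// results (those should have been canonicalized away already), then
// decomposes stack ops inside their regions.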
LogicalResult HandleRegionControlFlowOps(
    Operation& op, ModuleOp module,
    llvm::SmallDenseMap<Value, Value>* data_var_to_size_var,
    llvm::StringMap<PartitionedCallStackOpsInfo>*
        decomposed_partitioned_call_callees) {
  for (OpOperand& operand : op.getOpOperands()) {
    if (getElementTypeOrSelf(operand.get().getType()).isa<TF::ResourceType>()) {
      return op.emitOpError()
             << "found unexpected type " << operand.get().getType()
             << " of operand #" << operand.getOperandNumber()
             << ", resource type operands are expected to have been "
                "canonicalized away for region based control flow ops";
    }
  }
  for (OpResult result : op.getResults()) {
    if (getElementTypeOrSelf(result.getType()).isa<TF::ResourceType>()) {
      return op.emitOpError()
             << "found unexpected type " << result.getType() << " of result #"
             << result.getResultNumber()
             << ", resource type results are expected to have been "
                "canonicalized away for region based control flow ops";
    }
  }
  for (Region& region : op.getRegions()) {
    if (failed(DecomposeStackOpsInternal(&region.front(), module,
                                         data_var_to_size_var,
                                         decomposed_partitioned_call_callees)))
      return failure();
  }
  return success();
}

// Decomposes stack ops in a block and recursively decomposes called
// functions.
// data_var_to_size_var: a mapping from stacks' buffer local variables to size
// local variables.
// decomposed_partitioned_call_callees: cache for partitioned call ops' callee
// function handling.
LogicalResult DecomposeStackOpsInternal(
    Block* block, ModuleOp module,
    llvm::SmallDenseMap<Value, Value>* data_var_to_size_var,
    llvm::StringMap<PartitionedCallStackOpsInfo>*
        decomposed_partitioned_call_callees) {
  for (auto& op : llvm::make_early_inc_range(block->getOperations())) {
    if (llvm::isa<TF::IdentityOp, TF::IdentityNOp>(&op)) {
      // Removes identity nodes in the block. The device computation does not
      // need such nodes to carry information.
      op.replaceAllUsesWith(op.getOperands());
      op.erase();
    } else if (auto stack = llvm::dyn_cast<TF::StackV2Op>(&op)) {
      if (failed(HandleStackV2Op(stack, module, data_var_to_size_var))) {
        return failure();
      }
    } else if (auto push = llvm::dyn_cast<TF::StackPushV2Op>(&op)) {
      if (failed(HandleStackPushV2Op(push, data_var_to_size_var))) {
        return failure();
      }
    } else if (auto pop = llvm::dyn_cast<TF::StackPopV2Op>(&op)) {
      if (failed(HandleStackPopV2Op(pop, data_var_to_size_var))) {
        return failure();
      }
    } else if (auto close = llvm::dyn_cast<TF::StackCloseV2Op>(&op)) {
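      // Closing a stack becomes a no-op once the stack is a pair of local
      // variables; just drop the bookkeeping entry.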
      data_var_to_size_var->erase(close.handle());
      close.erase();
    } else if (auto while_op = llvm::dyn_cast<TF::WhileOp>(&op)) {
      if (failed(HandleWhileOp(while_op, module, *data_var_to_size_var,
                               decomposed_partitioned_call_callees))) {
        return failure();
      }
    } else if (auto if_op = llvm::dyn_cast<TF::IfOp>(&op)) {
      if (failed(HandleIfOp(if_op, module, *data_var_to_size_var,
                            decomposed_partitioned_call_callees))) {
        return failure();
      }
    } else if (llvm::isa<TF::WhileRegionOp>(op) ||
               llvm::isa<TF::IfRegionOp>(op) ||
               llvm::isa<TF::CaseRegionOp>(op)) {
      if (failed(
              HandleRegionControlFlowOps(op, module, data_var_to_size_var,
                                         decomposed_partitioned_call_callees)))
        return failure();
    } else if (auto pcall = llvm::dyn_cast<TF::PartitionedCallOp>(&op)) {
      if (!pcall.func()) {
        return pcall.emitOpError(
            "stack decomposition does not support call with nested references");
      }
      if (failed(HandlePartitionedCallOp(
              pcall, pcall.func(), module, *data_var_to_size_var,
              decomposed_partitioned_call_callees))) {
        return failure();
      }
    } else if (auto spcall =
                   llvm::dyn_cast<TF::StatefulPartitionedCallOp>(&op)) {
      if (failed(HandlePartitionedCallOp(
              spcall, spcall.func(), module, *data_var_to_size_var,
              decomposed_partitioned_call_callees))) {
        return failure();
      }
    }
  }
  return success();
}

LogicalResult DecomposeStackOps(Block* block, ModuleOp module) {
  llvm::SmallDenseMap<Value, Value> data_var_to_size_var;
  llvm::StringMap<PartitionedCallStackOpsInfo>
      decomposed_partitioned_call_callees;
  return DecomposeStackOpsInternal(block, module, &data_var_to_size_var,
                                   &decomposed_partitioned_call_callees);
}

void StackOpsDecompositionPass::runOnOperation() {
  auto module = getOperation();
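  // Decomposition starts at the main function; called functions are handled
  // recursively via the call and control flow handlers.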
  auto main = module.lookupSymbol<FuncOp>("main");
  if (!main) return;
  if (failed(DecomposeStackOps(&main.front(), module))) {
    signalPassFailure();
  }
}

static PassRegistration<StackOpsDecompositionPass> pass(
    "tf-stack-ops-decomposition",
    "Decompose stack operations into local variable operations. Needs static "
    "shapes.");

}  // namespace

namespace TF {
std::unique_ptr<OperationPass<ModuleOp>> CreateStackOpsDecompositionPass() {
  return std::make_unique<StackOpsDecompositionPass>();
}

}  // namespace TF
}  // namespace mlir