1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
// This transformation pass transforms region based control flow operations in
// the TensorFlow dialect to their functional counterparts, i.e.,
// tf.IfRegion -> tf.If and tf.WhileRegion -> tf.While.
19 
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/Support/Casting.h"
23 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
24 #include "mlir/IR/Attributes.h"  // from @llvm-project
25 #include "mlir/IR/Builders.h"  // from @llvm-project
26 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
27 #include "mlir/IR/Operation.h"  // from @llvm-project
28 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
29 #include "mlir/IR/Value.h"  // from @llvm-project
30 #include "mlir/IR/Verifier.h"  // from @llvm-project
31 #include "mlir/IR/Visitors.h"  // from @llvm-project
32 #include "mlir/Pass/Pass.h"  // from @llvm-project
33 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
34 #include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
35 #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
36 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
37 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
38 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
39 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
40 #include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h"
41 
42 #define DEBUG_TYPE "tf-region-cf-to-functional"
43 
44 namespace mlir {
45 namespace TF {
46 
47 namespace {
48 
// Names of the optional string attributes on tf.IfRegion that, when present
// and non-empty, provide the base names (uniqued before use) of the outlined
// then/else branch functions.
constexpr char kThenFuncNameAttr[] = "_then_func_name";
constexpr char kElseFuncNameAttr[] = "_else_func_name";
51 
52 struct RegionControlFlowToFunctional
53     : public TF::RegionControlFlowToFunctionalPassBase<
54           RegionControlFlowToFunctional> {
55   void runOnOperation() override;
56 
57  private:
58   LogicalResult ConvertIfOp(IfRegionOp if_region);
59   LogicalResult ConvertWhileOp(WhileRegionOp while_region);
60 
61   // Get unique name by using the loc to name mapping.
62   std::string GetName(Operation* op, StringRef suffix);
63 
64   tensorflow::OpOrArgLocNameMapper mapper;
65   llvm::SmallVector<FuncOp, 4> worklist;
66 };
67 
GetName(Operation * op,StringRef suffix)68 std::string RegionControlFlowToFunctional::GetName(Operation* op,
69                                                    StringRef suffix) {
70   return (mapper.GetUniqueName(op) + suffix).str();
71 }
72 
73 // Returns all the external values referenced from the given regions. If the
74 // external value is a constant, sink it into the region instead (and do not
75 // add it to the returned vector).
CollectExternValues(Region & first,Region & second)76 llvm::SmallVector<Value, 4> CollectExternValues(Region& first, Region& second) {
77   llvm::SetVector<Value> extern_values;
78 
79   for (Region* region : {&first, &second}) {
80     llvm::SetVector<Value> region_extern_values;
81     getUsedValuesDefinedAbove(*region, region_extern_values);
82 
83     // Sink down constants into the functions.
84     for (auto extern_value : region_extern_values) {
85       if (!matchPattern(extern_value, m_Constant())) {
86         extern_values.insert(extern_value);
87         continue;
88       }
89       // Add constant at start of region.
90       auto const_builder = OpBuilder::atBlockBegin(&region->front());
91       auto const_value = const_builder.clone(*extern_value.getDefiningOp());
92       replaceAllUsesInRegionWith(extern_value, const_value->getResult(0),
93                                  *region);
94     }
95   }
96 
97   return llvm::to_vector<4>(extern_values);
98 }
99 
100 // Extracts the contents of a region with a single block into a new function.
101 // `extern_values` is the set of external values that the region refers to.
102 //
103 // Inputs to the terminator of the region are converted to return values of
104 // the function. If `extern_values_passthrough` is true, all the extern values
105 // are also added as return values from the function
ExtractSingleBlockRegion(Region & region,StringRef name,llvm::SmallVectorImpl<Value> & extern_values,llvm::SmallVectorImpl<FuncOp> & worklist,bool extern_values_passthrough)106 void ExtractSingleBlockRegion(Region& region, StringRef name,
107                               llvm::SmallVectorImpl<Value>& extern_values,
108                               llvm::SmallVectorImpl<FuncOp>& worklist,
109                               bool extern_values_passthrough) {
110   ModuleOp module = region.getParentOfType<ModuleOp>();
111   auto builder = OpBuilder::atBlockBegin(module.getBody());
112   auto loc = region.getParentOp()->getLoc();
113   Block& entry = region.front();
114   int num_region_arguments = entry.getNumArguments();
115   Operation* terminator = entry.getTerminator();
116 
117   // Build the function type. Region arguments and extern values together
118   // become the function arguments, with region arguments going first.
119   auto input_types = llvm::to_vector<4>(entry.getArgumentTypes());
120   for (auto input : extern_values) input_types.push_back(input.getType());
121 
122   // Terminator operands and pass through extern values (if enabled) together
123   // become the function return values.
124   auto return_types = llvm::to_vector<4>(terminator->getOperandTypes());
125   if (extern_values_passthrough)
126     for (auto input : extern_values) return_types.push_back(input.getType());
127 
128   auto type = FunctionType::get(region.getContext(), input_types, return_types);
129 
130   // Create new function and extract region body into the function.
131   auto outlined_func = builder.create<FuncOp>(loc, name, type);
132   Region& func_region = outlined_func.getBody();
133   func_region.takeBody(region);
134   Block& first_block = func_region.front();
135 
136   // Replace all external uses with function arguments.
137   for (auto it : llvm::enumerate(extern_values)) {
138     Value arg = first_block.addArgument(it.value().getType());
139     replaceAllUsesInRegionWith(it.value(), arg, func_region);
140   }
141 
142   // Function return values are all the terminator operands + pass through
143   // extern values (if enabled).
144   auto return_values = llvm::to_vector<4>(terminator->getOperands());
145   if (extern_values_passthrough)
146     return_values.insert(return_values.end(),
147                          first_block.args_begin() + num_region_arguments,
148                          first_block.args_end());
149 
150   // Replace the existing terminator with a return.
151   terminator = first_block.getTerminator();
152   builder.setInsertionPoint(terminator);
153   builder.create<ReturnOp>(terminator->getLoc(), return_values);
154   terminator->erase();
155 
156   outlined_func.setPrivate();
157 
158   // Add the outlined function to the worklist in case its body has
159   // IfRegion or WhileRegion ops that need to converted.
160   worklist.push_back(outlined_func);
161 }
162 
163 // Returns call for region with single call whose result feeds into the
164 // terminator of the region. if `allow_to_bool` is true, also allows a single
165 // ToBoolOp between the region yield and the call. Returns none if the region
166 // does not conform to this pattern.
IsSingleCallRegion(Region & region,bool allow_to_bool=false)167 llvm::Optional<CallOp> IsSingleCallRegion(Region& region,
168                                           bool allow_to_bool = false) {
169   if (!llvm::hasSingleElement(region)) return llvm::None;
170 
171   Block& block = region.front();
172   auto it = block.rbegin();
173   YieldOp yield = dyn_cast<YieldOp>(*it++);
174 
175   if (it == block.rend()) return llvm::None;
176 
177   // Operation which is expected to consume all the call results.
178   Operation* call_consumer = yield;
179 
180   // Allow a single ToBoolOp between the call and the yield (valid only
181   // when the yield has a single operand)
182   if (allow_to_bool && yield.getNumOperands() == 1 && isa<ToBoolOp>(*it)) {
183     if (it->getResult(0) != yield.getOperand(0)) return llvm::None;
184     call_consumer = cast<ToBoolOp>(*it);
185     it++;
186   }
187 
188   // Check if there is a Call before the Yield.
189   CallOp call = dyn_cast<CallOp>(*it++);
190   if (!call) return llvm::None;
191 
192   // All call results should feed into expected consumer
193   // All results of the call should feed into the yield.
194   if (call.getNumResults() != call_consumer->getNumOperands())
195     return llvm::None;
196 
197   for (auto res_it : llvm::zip(call.getResults(), call_consumer->getOperands()))
198     if (std::get<0>(res_it) != std::get<1>(res_it)) return llvm::None;
199 
200   // There can only be non-truncating cast op's prior to the call.
201   for (; it != block.rend(); ++it) {
202     CastOp cast = dyn_cast<CastOp>(*it);
203     if (!cast || cast.Truncate()) return llvm::None;
204   }
205 
206   return call;
207 }
208 
209 using ArgMatcherFn = function_ref<bool(Value, Region&, Value, Region&)>;
210 
211 // Returns whether the arguments of the given 2 calls are match (after looking
212 // through cast ops). `matcher` is the predicate used to check if two arguments
213 // match.
MatchCallArgs(CallOp first,CallOp second,ArgMatcherFn matcher)214 bool MatchCallArgs(CallOp first, CallOp second, ArgMatcherFn matcher) {
215   if (first.getNumOperands() != second.getNumOperands()) return false;
216 
217   Region& first_region = *first->getParentRegion();
218   Region& second_region = *second->getParentRegion();
219 
220   for (auto it : llvm::zip(first.getArgOperands(), second.getArgOperands())) {
221     // Get the defining Op, skipping over casts.
222     auto get_defining_op = [](Value value) {
223       while (auto cast_op =
224                  llvm::dyn_cast_or_null<CastOp>(value.getDefiningOp())) {
225         // Consider cast compatibility in case
226         //    %cast = "tf.Cast"(%0) : (tensor<2xi64>) -> tensor<2xf32>
227         // is skipped.
228         if (cast_op.SrcT() != cast_op.DstT()) {
229           break;
230         }
231         value = cast_op.getOperand();
232       }
233       return value;
234     };
235     Value first_arg = get_defining_op(std::get<0>(it));
236     Value second_arg = get_defining_op(std::get<1>(it));
237 
238     if (!matcher(first_arg, first_region, second_arg, second_region))
239       return false;
240   }
241   return true;
242 }
243 
244 // Summary information for trivially transforming region based op's to
245 // functional ops. A trivial transformation can be done when the regions are
246 // just calls to functions, in which case no outlining is needed.
247 struct TrivialTransformInfo {
248   // Can the op be transformed trivially?
249   bool can_transform = false;
250 
251   // List of callee names (one for each region).
252   llvm::SmallVector<StringRef, 2> callee_names;
253 
254   // Analyzes the given calls (from regions attached to the same parent op) to
255   // check if the parent op be transformed to functional form trivially (i.e.,
256   // reusing existing functions and without outlining). This is possible when
257   // all the regions are single call regions (checked using matchers outside
258   // this class) and the all the calls match using the given argument matcher.
259   //
260   // If such a trivial transformation is possible, stash the relevant
261   // information needed for the transformation, else indicate that a trivial
262   // transformation is not possible by setting `can_transform` to false.
TrivialTransformInfomlir::TF::__anon2f708ce50111::TrivialTransformInfo263   TrivialTransformInfo(llvm::Optional<CallOp> first_call,
264                        llvm::Optional<CallOp> second_call,
265                        ArgMatcherFn arg_matcher) {
266     if (!first_call || !second_call) return;
267 
268     if (!MatchCallArgs(first_call.getValue(), second_call.getValue(),
269                        arg_matcher))
270       return;
271 
272     can_transform = true;
273     callee_names = {first_call.getValue().getCallee(),
274                     second_call.getValue().getCallee()};
275   }
276 };
277 
278 // Transform IfRegionOp to IfOp.
ConvertIfOp(IfRegionOp if_region)279 LogicalResult RegionControlFlowToFunctional::ConvertIfOp(IfRegionOp if_region) {
280   llvm::SmallVector<Value, 4> extern_values;
281 
282   // For IfOp, arguments of calls in the then and else regions match if they
283   // are the same value.
284   auto if_arg_matcher = [&](Value first, Region&, Value second, Region&) {
285     if (first != second) return false;
286 
287     // collect the call arguments post lookup through cast Op's
288     extern_values.push_back(first);
289     return true;
290   };
291 
292   const TrivialTransformInfo tti(IsSingleCallRegion(if_region.then_branch()),
293                                  IsSingleCallRegion(if_region.else_branch()),
294                                  if_arg_matcher);
295 
296   std::string then_name, else_name;
297 
298   if (tti.can_transform) {
299     // We can transform to functional form trivially without outlining.
300     then_name = tti.callee_names[0].str();
301     else_name = tti.callee_names[1].str();
302   } else {
303     // Collect external values that are used within the else and then bodies.
304     extern_values =
305         CollectExternValues(if_region.then_branch(), if_region.else_branch());
306 
307     // These external values need to be added as inputs to the generated If. The
308     // order is determined by the order of these values the `extern_vales`.
309 
310     // Create 2 new functions with the input signature matching this order,
311     // and outline the `then` and `else` regions by moving the bodies of these
312     // regions into these functions. Replace tf.yield with a regular return.
313     if (if_region->hasAttrOfType<StringAttr>(kThenFuncNameAttr) &&
314         !if_region._then_func_nameAttr().getValue().empty()) {
315       then_name =
316           mapper.GetUniqueName(if_region._then_func_nameAttr().getValue())
317               .str();
318     } else {
319       then_name = GetName(if_region, "_then");
320     }
321     ExtractSingleBlockRegion(if_region.then_branch(), then_name, extern_values,
322                              worklist, /*extern_values_passthrough=*/false);
323 
324     if (if_region->hasAttrOfType<StringAttr>(kElseFuncNameAttr) &&
325         !if_region._else_func_nameAttr().getValue().empty()) {
326       else_name =
327           mapper.GetUniqueName(if_region._else_func_nameAttr().getValue())
328               .str();
329     } else {
330       else_name = GetName(if_region, "_else");
331     }
332     ExtractSingleBlockRegion(if_region.else_branch(), else_name, extern_values,
333                              worklist, /*extern_values_passthrough=*/false);
334   }
335 
336   // Look through ToBool operations for the condition.
337   Value cond = if_region.cond();
338   auto to_bool = dyn_cast_or_null<ToBoolOp>(cond.getDefiningOp());
339   if (to_bool) cond = to_bool.getOperand();
340 
341   // Once we have the `then` and `else` functions ready (either outlined or
342   // existing ones), replace the region based op with a functional control flow
343   // op.
344   OpBuilder builder(if_region);
345   auto if_op = builder.create<IfOp>(
346       if_region.getLoc(), if_region.getResultTypes(), cond, extern_values,
347       then_name, else_name, if_region.is_stateless());
348   CopyDeviceAndUnderscoredAttributes(if_region, if_op);
349   if_region.replaceAllUsesWith(if_op.getResults());
350   if_region.erase();
351 
352   if (to_bool && to_bool.use_empty()) to_bool.erase();
353   return success();
354 }
355 
356 // Transform WhileRegion to WhileOp.
ConvertWhileOp(WhileRegionOp while_region)357 LogicalResult RegionControlFlowToFunctional::ConvertWhileOp(
358     WhileRegionOp while_region) {
359   // For While, the arguments of the calls in the body and cond regions match
360   // if they are region arguments with the same region argument numbers. If the
361   // 2 calls have the same value (an extern value) used as an argument, we
362   // cannot do a trivial transformation because post transform, we will need to
363   // pass this extern value as an argument to the function, so we cannot use the
364   // existing function as is.
365   auto while_arg_matcher = [](Value first, Region& first_region, Value second,
366                               Region& second_region) {
367     if (!first.isa<BlockArgument>() || !second.isa<BlockArgument>())
368       return false;
369     BlockArgument first_block_arg = first.cast<BlockArgument>();
370     BlockArgument second_block_arg = second.cast<BlockArgument>();
371 
372     // 2 block arguments will match if they are the same argument number, and
373     // are block arguments of the corresponding containing regions.
374     return first_block_arg.getArgNumber() == second_block_arg.getArgNumber() &&
375            first_block_arg.getParentBlock() == &first_region.front() &&
376            second_block_arg.getParentBlock() == &second_region.front();
377   };
378 
379   const TrivialTransformInfo tti(
380       IsSingleCallRegion(while_region.cond(), /*allow_to_bool=*/true),
381       IsSingleCallRegion(while_region.body()), while_arg_matcher);
382 
383   // All existing inputs to while region are inputs to the functional while.
384   auto new_inputs = llvm::to_vector<4>(while_region.getOperands());
385 
386   // All existing results will also be generated by the functional while.
387   auto new_result_types = llvm::to_vector<4>(while_region.getResultTypes());
388 
389   std::string cond_name, body_name;
390   if (tti.can_transform) {
391     // We can transform to functional form trivially without outlining.
392     cond_name = tti.callee_names[0].str();
393     body_name = tti.callee_names[1].str();
394   } else {
395     // The WhileRegion regions can refer to either arguments of the region, or
396     // external values implicitly captured by the region. When converting to
397     // functional form, all such external values need to become function
398     // arguments of the outlined functions, and become pass through values in
399     // the outlined body function. So when outlining the while body, in addition
400     // to the region arguments, all these external references need to be added
401     // as function arguments.
402     llvm::SmallVector<Value, 4> extern_values =
403         CollectExternValues(while_region.cond(), while_region.body());
404 
405     // Outline the `cond` and `body` regions by moving the bodies of these
406     // regions into new functions. Replace tf.yield with a regular return.
407     cond_name = GetName(while_region, "_cond");
408     ExtractSingleBlockRegion(while_region.cond(), cond_name, extern_values,
409                              worklist, /*extern_values_passthrough=*/false);
410 
411     body_name = GetName(while_region, "_body");
412     ExtractSingleBlockRegion(while_region.body(), body_name, extern_values,
413                              worklist, /*extern_values_passthrough=*/true);
414 
415     // All extern values become additional inputs and additional output types
416     // for the functional while.
417     new_inputs.append(extern_values.begin(), extern_values.end());
418     for (auto ext : extern_values) new_result_types.push_back(ext.getType());
419   }
420 
421   // Once we have the `cond` and `body` functions ready (either outlined or
422   // existing ones), replace the region based op with a functional op.
423   OpBuilder builder(while_region);
424   auto while_op = builder.create<WhileOp>(
425       while_region.getLoc(), new_result_types, new_inputs, cond_name, body_name,
426       while_region.parallel_iterations(), while_region.is_stateless(),
427       while_region.shape_invariant());
428   CopyDeviceAndUnderscoredAttributes(while_region, while_op);
429 
430   // Redirect old results to new results.
431   for (auto it : llvm::zip(
432            while_region.getResults(),
433            while_op.getResults().take_front(while_region.getNumResults())))
434     std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
435 
436   while_region.erase();
437   return success();
438 }
439 
runOnOperation()440 void RegionControlFlowToFunctional::runOnOperation() {
441   ModuleOp module = getOperation();
442 
443   // Seed worklist with all functions in the module.
444   worklist = llvm::to_vector<4>(module.getOps<FuncOp>());
445   while (!worklist.empty()) {
446     FuncOp function = worklist.pop_back_val();
447 
448     auto result = function.walk([&](Operation* op) {
449       if (auto if_region = llvm::dyn_cast<IfRegionOp>(op)) {
450         if (failed(ConvertIfOp(if_region))) {
451           op->emitOpError() << "failed to convert to functional form";
452           return WalkResult::interrupt();
453         }
454       } else if (auto while_region = llvm::dyn_cast<WhileRegionOp>(op)) {
455         if (failed(ConvertWhileOp(while_region))) {
456           op->emitOpError() << "failed to convert to functional form";
457           return WalkResult::interrupt();
458         }
459       }
460       return WalkResult::advance();
461     });
462 
463     if (result.wasInterrupted()) return signalPassFailure();
464   }
465 }
466 
467 }  // namespace
468 
469 std::unique_ptr<OperationPass<ModuleOp>>
CreateTFRegionControlFlowToFunctional()470 CreateTFRegionControlFlowToFunctional() {
471   return std::make_unique<RegionControlFlowToFunctional>();
472 }
473 
474 }  // namespace TF
475 }  // namespace mlir
476