1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
// This transformation pass transforms region based control flow operations in
// the TensorFlow dialect to their functional counterparts, i.e.,
// tf.IfRegion -> tf.If and tf.WhileRegion -> tf.While
19
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/Support/Casting.h"
23 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
24 #include "mlir/IR/Attributes.h" // from @llvm-project
25 #include "mlir/IR/Builders.h" // from @llvm-project
26 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
27 #include "mlir/IR/Operation.h" // from @llvm-project
28 #include "mlir/IR/TypeUtilities.h" // from @llvm-project
29 #include "mlir/IR/Value.h" // from @llvm-project
30 #include "mlir/IR/Verifier.h" // from @llvm-project
31 #include "mlir/IR/Visitors.h" // from @llvm-project
32 #include "mlir/Pass/Pass.h" // from @llvm-project
33 #include "mlir/Pass/PassRegistry.h" // from @llvm-project
34 #include "mlir/Transforms/RegionUtils.h" // from @llvm-project
35 #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
36 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
37 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
38 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
39 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
40 #include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h"
41
42 #define DEBUG_TYPE "tf-region-cf-to-functional"
43
44 namespace mlir {
45 namespace TF {
46
47 namespace {
48
// Names of the optional string attributes on tf.IfRegion that request a name
// for the outlined `then` / `else` function (the requested name is still made
// unique through the name mapper before use).
constexpr char kThenFuncNameAttr[] = "_then_func_name";
constexpr char kElseFuncNameAttr[] = "_else_func_name";
51
52 struct RegionControlFlowToFunctional
53 : public TF::RegionControlFlowToFunctionalPassBase<
54 RegionControlFlowToFunctional> {
55 void runOnOperation() override;
56
57 private:
58 LogicalResult ConvertIfOp(IfRegionOp if_region);
59 LogicalResult ConvertWhileOp(WhileRegionOp while_region);
60
61 // Get unique name by using the loc to name mapping.
62 std::string GetName(Operation* op, StringRef suffix);
63
64 tensorflow::OpOrArgLocNameMapper mapper;
65 llvm::SmallVector<FuncOp, 4> worklist;
66 };
67
GetName(Operation * op,StringRef suffix)68 std::string RegionControlFlowToFunctional::GetName(Operation* op,
69 StringRef suffix) {
70 return (mapper.GetUniqueName(op) + suffix).str();
71 }
72
73 // Returns all the external values referenced from the given regions. If the
74 // external value is a constant, sink it into the region instead (and do not
75 // add it to the returned vector).
CollectExternValues(Region & first,Region & second)76 llvm::SmallVector<Value, 4> CollectExternValues(Region& first, Region& second) {
77 llvm::SetVector<Value> extern_values;
78
79 for (Region* region : {&first, &second}) {
80 llvm::SetVector<Value> region_extern_values;
81 getUsedValuesDefinedAbove(*region, region_extern_values);
82
83 // Sink down constants into the functions.
84 for (auto extern_value : region_extern_values) {
85 if (!matchPattern(extern_value, m_Constant())) {
86 extern_values.insert(extern_value);
87 continue;
88 }
89 // Add constant at start of region.
90 auto const_builder = OpBuilder::atBlockBegin(®ion->front());
91 auto const_value = const_builder.clone(*extern_value.getDefiningOp());
92 replaceAllUsesInRegionWith(extern_value, const_value->getResult(0),
93 *region);
94 }
95 }
96
97 return llvm::to_vector<4>(extern_values);
98 }
99
100 // Extracts the contents of a region with a single block into a new function.
101 // `extern_values` is the set of external values that the region refers to.
102 //
103 // Inputs to the terminator of the region are converted to return values of
104 // the function. If `extern_values_passthrough` is true, all the extern values
105 // are also added as return values from the function
ExtractSingleBlockRegion(Region & region,StringRef name,llvm::SmallVectorImpl<Value> & extern_values,llvm::SmallVectorImpl<FuncOp> & worklist,bool extern_values_passthrough)106 void ExtractSingleBlockRegion(Region& region, StringRef name,
107 llvm::SmallVectorImpl<Value>& extern_values,
108 llvm::SmallVectorImpl<FuncOp>& worklist,
109 bool extern_values_passthrough) {
110 ModuleOp module = region.getParentOfType<ModuleOp>();
111 auto builder = OpBuilder::atBlockBegin(module.getBody());
112 auto loc = region.getParentOp()->getLoc();
113 Block& entry = region.front();
114 int num_region_arguments = entry.getNumArguments();
115 Operation* terminator = entry.getTerminator();
116
117 // Build the function type. Region arguments and extern values together
118 // become the function arguments, with region arguments going first.
119 auto input_types = llvm::to_vector<4>(entry.getArgumentTypes());
120 for (auto input : extern_values) input_types.push_back(input.getType());
121
122 // Terminator operands and pass through extern values (if enabled) together
123 // become the function return values.
124 auto return_types = llvm::to_vector<4>(terminator->getOperandTypes());
125 if (extern_values_passthrough)
126 for (auto input : extern_values) return_types.push_back(input.getType());
127
128 auto type = FunctionType::get(region.getContext(), input_types, return_types);
129
130 // Create new function and extract region body into the function.
131 auto outlined_func = builder.create<FuncOp>(loc, name, type);
132 Region& func_region = outlined_func.getBody();
133 func_region.takeBody(region);
134 Block& first_block = func_region.front();
135
136 // Replace all external uses with function arguments.
137 for (auto it : llvm::enumerate(extern_values)) {
138 Value arg = first_block.addArgument(it.value().getType());
139 replaceAllUsesInRegionWith(it.value(), arg, func_region);
140 }
141
142 // Function return values are all the terminator operands + pass through
143 // extern values (if enabled).
144 auto return_values = llvm::to_vector<4>(terminator->getOperands());
145 if (extern_values_passthrough)
146 return_values.insert(return_values.end(),
147 first_block.args_begin() + num_region_arguments,
148 first_block.args_end());
149
150 // Replace the existing terminator with a return.
151 terminator = first_block.getTerminator();
152 builder.setInsertionPoint(terminator);
153 builder.create<ReturnOp>(terminator->getLoc(), return_values);
154 terminator->erase();
155
156 outlined_func.setPrivate();
157
158 // Add the outlined function to the worklist in case its body has
159 // IfRegion or WhileRegion ops that need to converted.
160 worklist.push_back(outlined_func);
161 }
162
163 // Returns call for region with single call whose result feeds into the
164 // terminator of the region. if `allow_to_bool` is true, also allows a single
165 // ToBoolOp between the region yield and the call. Returns none if the region
166 // does not conform to this pattern.
IsSingleCallRegion(Region & region,bool allow_to_bool=false)167 llvm::Optional<CallOp> IsSingleCallRegion(Region& region,
168 bool allow_to_bool = false) {
169 if (!llvm::hasSingleElement(region)) return llvm::None;
170
171 Block& block = region.front();
172 auto it = block.rbegin();
173 YieldOp yield = dyn_cast<YieldOp>(*it++);
174
175 if (it == block.rend()) return llvm::None;
176
177 // Operation which is expected to consume all the call results.
178 Operation* call_consumer = yield;
179
180 // Allow a single ToBoolOp between the call and the yield (valid only
181 // when the yield has a single operand)
182 if (allow_to_bool && yield.getNumOperands() == 1 && isa<ToBoolOp>(*it)) {
183 if (it->getResult(0) != yield.getOperand(0)) return llvm::None;
184 call_consumer = cast<ToBoolOp>(*it);
185 it++;
186 }
187
188 // Check if there is a Call before the Yield.
189 CallOp call = dyn_cast<CallOp>(*it++);
190 if (!call) return llvm::None;
191
192 // All call results should feed into expected consumer
193 // All results of the call should feed into the yield.
194 if (call.getNumResults() != call_consumer->getNumOperands())
195 return llvm::None;
196
197 for (auto res_it : llvm::zip(call.getResults(), call_consumer->getOperands()))
198 if (std::get<0>(res_it) != std::get<1>(res_it)) return llvm::None;
199
200 // There can only be non-truncating cast op's prior to the call.
201 for (; it != block.rend(); ++it) {
202 CastOp cast = dyn_cast<CastOp>(*it);
203 if (!cast || cast.Truncate()) return llvm::None;
204 }
205
206 return call;
207 }
208
209 using ArgMatcherFn = function_ref<bool(Value, Region&, Value, Region&)>;
210
211 // Returns whether the arguments of the given 2 calls are match (after looking
212 // through cast ops). `matcher` is the predicate used to check if two arguments
213 // match.
MatchCallArgs(CallOp first,CallOp second,ArgMatcherFn matcher)214 bool MatchCallArgs(CallOp first, CallOp second, ArgMatcherFn matcher) {
215 if (first.getNumOperands() != second.getNumOperands()) return false;
216
217 Region& first_region = *first->getParentRegion();
218 Region& second_region = *second->getParentRegion();
219
220 for (auto it : llvm::zip(first.getArgOperands(), second.getArgOperands())) {
221 // Get the defining Op, skipping over casts.
222 auto get_defining_op = [](Value value) {
223 while (auto cast_op =
224 llvm::dyn_cast_or_null<CastOp>(value.getDefiningOp())) {
225 // Consider cast compatibility in case
226 // %cast = "tf.Cast"(%0) : (tensor<2xi64>) -> tensor<2xf32>
227 // is skipped.
228 if (cast_op.SrcT() != cast_op.DstT()) {
229 break;
230 }
231 value = cast_op.getOperand();
232 }
233 return value;
234 };
235 Value first_arg = get_defining_op(std::get<0>(it));
236 Value second_arg = get_defining_op(std::get<1>(it));
237
238 if (!matcher(first_arg, first_region, second_arg, second_region))
239 return false;
240 }
241 return true;
242 }
243
244 // Summary information for trivially transforming region based op's to
245 // functional ops. A trivial transformation can be done when the regions are
246 // just calls to functions, in which case no outlining is needed.
247 struct TrivialTransformInfo {
248 // Can the op be transformed trivially?
249 bool can_transform = false;
250
251 // List of callee names (one for each region).
252 llvm::SmallVector<StringRef, 2> callee_names;
253
254 // Analyzes the given calls (from regions attached to the same parent op) to
255 // check if the parent op be transformed to functional form trivially (i.e.,
256 // reusing existing functions and without outlining). This is possible when
257 // all the regions are single call regions (checked using matchers outside
258 // this class) and the all the calls match using the given argument matcher.
259 //
260 // If such a trivial transformation is possible, stash the relevant
261 // information needed for the transformation, else indicate that a trivial
262 // transformation is not possible by setting `can_transform` to false.
TrivialTransformInfomlir::TF::__anon2f708ce50111::TrivialTransformInfo263 TrivialTransformInfo(llvm::Optional<CallOp> first_call,
264 llvm::Optional<CallOp> second_call,
265 ArgMatcherFn arg_matcher) {
266 if (!first_call || !second_call) return;
267
268 if (!MatchCallArgs(first_call.getValue(), second_call.getValue(),
269 arg_matcher))
270 return;
271
272 can_transform = true;
273 callee_names = {first_call.getValue().getCallee(),
274 second_call.getValue().getCallee()};
275 }
276 };
277
278 // Transform IfRegionOp to IfOp.
ConvertIfOp(IfRegionOp if_region)279 LogicalResult RegionControlFlowToFunctional::ConvertIfOp(IfRegionOp if_region) {
280 llvm::SmallVector<Value, 4> extern_values;
281
282 // For IfOp, arguments of calls in the then and else regions match if they
283 // are the same value.
284 auto if_arg_matcher = [&](Value first, Region&, Value second, Region&) {
285 if (first != second) return false;
286
287 // collect the call arguments post lookup through cast Op's
288 extern_values.push_back(first);
289 return true;
290 };
291
292 const TrivialTransformInfo tti(IsSingleCallRegion(if_region.then_branch()),
293 IsSingleCallRegion(if_region.else_branch()),
294 if_arg_matcher);
295
296 std::string then_name, else_name;
297
298 if (tti.can_transform) {
299 // We can transform to functional form trivially without outlining.
300 then_name = tti.callee_names[0].str();
301 else_name = tti.callee_names[1].str();
302 } else {
303 // Collect external values that are used within the else and then bodies.
304 extern_values =
305 CollectExternValues(if_region.then_branch(), if_region.else_branch());
306
307 // These external values need to be added as inputs to the generated If. The
308 // order is determined by the order of these values the `extern_vales`.
309
310 // Create 2 new functions with the input signature matching this order,
311 // and outline the `then` and `else` regions by moving the bodies of these
312 // regions into these functions. Replace tf.yield with a regular return.
313 if (if_region->hasAttrOfType<StringAttr>(kThenFuncNameAttr) &&
314 !if_region._then_func_nameAttr().getValue().empty()) {
315 then_name =
316 mapper.GetUniqueName(if_region._then_func_nameAttr().getValue())
317 .str();
318 } else {
319 then_name = GetName(if_region, "_then");
320 }
321 ExtractSingleBlockRegion(if_region.then_branch(), then_name, extern_values,
322 worklist, /*extern_values_passthrough=*/false);
323
324 if (if_region->hasAttrOfType<StringAttr>(kElseFuncNameAttr) &&
325 !if_region._else_func_nameAttr().getValue().empty()) {
326 else_name =
327 mapper.GetUniqueName(if_region._else_func_nameAttr().getValue())
328 .str();
329 } else {
330 else_name = GetName(if_region, "_else");
331 }
332 ExtractSingleBlockRegion(if_region.else_branch(), else_name, extern_values,
333 worklist, /*extern_values_passthrough=*/false);
334 }
335
336 // Look through ToBool operations for the condition.
337 Value cond = if_region.cond();
338 auto to_bool = dyn_cast_or_null<ToBoolOp>(cond.getDefiningOp());
339 if (to_bool) cond = to_bool.getOperand();
340
341 // Once we have the `then` and `else` functions ready (either outlined or
342 // existing ones), replace the region based op with a functional control flow
343 // op.
344 OpBuilder builder(if_region);
345 auto if_op = builder.create<IfOp>(
346 if_region.getLoc(), if_region.getResultTypes(), cond, extern_values,
347 then_name, else_name, if_region.is_stateless());
348 CopyDeviceAndUnderscoredAttributes(if_region, if_op);
349 if_region.replaceAllUsesWith(if_op.getResults());
350 if_region.erase();
351
352 if (to_bool && to_bool.use_empty()) to_bool.erase();
353 return success();
354 }
355
356 // Transform WhileRegion to WhileOp.
ConvertWhileOp(WhileRegionOp while_region)357 LogicalResult RegionControlFlowToFunctional::ConvertWhileOp(
358 WhileRegionOp while_region) {
359 // For While, the arguments of the calls in the body and cond regions match
360 // if they are region arguments with the same region argument numbers. If the
361 // 2 calls have the same value (an extern value) used as an argument, we
362 // cannot do a trivial transformation because post transform, we will need to
363 // pass this extern value as an argument to the function, so we cannot use the
364 // existing function as is.
365 auto while_arg_matcher = [](Value first, Region& first_region, Value second,
366 Region& second_region) {
367 if (!first.isa<BlockArgument>() || !second.isa<BlockArgument>())
368 return false;
369 BlockArgument first_block_arg = first.cast<BlockArgument>();
370 BlockArgument second_block_arg = second.cast<BlockArgument>();
371
372 // 2 block arguments will match if they are the same argument number, and
373 // are block arguments of the corresponding containing regions.
374 return first_block_arg.getArgNumber() == second_block_arg.getArgNumber() &&
375 first_block_arg.getParentBlock() == &first_region.front() &&
376 second_block_arg.getParentBlock() == &second_region.front();
377 };
378
379 const TrivialTransformInfo tti(
380 IsSingleCallRegion(while_region.cond(), /*allow_to_bool=*/true),
381 IsSingleCallRegion(while_region.body()), while_arg_matcher);
382
383 // All existing inputs to while region are inputs to the functional while.
384 auto new_inputs = llvm::to_vector<4>(while_region.getOperands());
385
386 // All existing results will also be generated by the functional while.
387 auto new_result_types = llvm::to_vector<4>(while_region.getResultTypes());
388
389 std::string cond_name, body_name;
390 if (tti.can_transform) {
391 // We can transform to functional form trivially without outlining.
392 cond_name = tti.callee_names[0].str();
393 body_name = tti.callee_names[1].str();
394 } else {
395 // The WhileRegion regions can refer to either arguments of the region, or
396 // external values implicitly captured by the region. When converting to
397 // functional form, all such external values need to become function
398 // arguments of the outlined functions, and become pass through values in
399 // the outlined body function. So when outlining the while body, in addition
400 // to the region arguments, all these external references need to be added
401 // as function arguments.
402 llvm::SmallVector<Value, 4> extern_values =
403 CollectExternValues(while_region.cond(), while_region.body());
404
405 // Outline the `cond` and `body` regions by moving the bodies of these
406 // regions into new functions. Replace tf.yield with a regular return.
407 cond_name = GetName(while_region, "_cond");
408 ExtractSingleBlockRegion(while_region.cond(), cond_name, extern_values,
409 worklist, /*extern_values_passthrough=*/false);
410
411 body_name = GetName(while_region, "_body");
412 ExtractSingleBlockRegion(while_region.body(), body_name, extern_values,
413 worklist, /*extern_values_passthrough=*/true);
414
415 // All extern values become additional inputs and additional output types
416 // for the functional while.
417 new_inputs.append(extern_values.begin(), extern_values.end());
418 for (auto ext : extern_values) new_result_types.push_back(ext.getType());
419 }
420
421 // Once we have the `cond` and `body` functions ready (either outlined or
422 // existing ones), replace the region based op with a functional op.
423 OpBuilder builder(while_region);
424 auto while_op = builder.create<WhileOp>(
425 while_region.getLoc(), new_result_types, new_inputs, cond_name, body_name,
426 while_region.parallel_iterations(), while_region.is_stateless(),
427 while_region.shape_invariant());
428 CopyDeviceAndUnderscoredAttributes(while_region, while_op);
429
430 // Redirect old results to new results.
431 for (auto it : llvm::zip(
432 while_region.getResults(),
433 while_op.getResults().take_front(while_region.getNumResults())))
434 std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
435
436 while_region.erase();
437 return success();
438 }
439
runOnOperation()440 void RegionControlFlowToFunctional::runOnOperation() {
441 ModuleOp module = getOperation();
442
443 // Seed worklist with all functions in the module.
444 worklist = llvm::to_vector<4>(module.getOps<FuncOp>());
445 while (!worklist.empty()) {
446 FuncOp function = worklist.pop_back_val();
447
448 auto result = function.walk([&](Operation* op) {
449 if (auto if_region = llvm::dyn_cast<IfRegionOp>(op)) {
450 if (failed(ConvertIfOp(if_region))) {
451 op->emitOpError() << "failed to convert to functional form";
452 return WalkResult::interrupt();
453 }
454 } else if (auto while_region = llvm::dyn_cast<WhileRegionOp>(op)) {
455 if (failed(ConvertWhileOp(while_region))) {
456 op->emitOpError() << "failed to convert to functional form";
457 return WalkResult::interrupt();
458 }
459 }
460 return WalkResult::advance();
461 });
462
463 if (result.wasInterrupted()) return signalPassFailure();
464 }
465 }
466
467 } // namespace
468
469 std::unique_ptr<OperationPass<ModuleOp>>
CreateTFRegionControlFlowToFunctional()470 CreateTFRegionControlFlowToFunctional() {
471 return std::make_unique<RegionControlFlowToFunctional>();
472 }
473
474 } // namespace TF
475 } // namespace mlir
476