1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // This pass lifts resource variable operations outside of device computation.
17
18 #include <cstddef>
19 #include <cstdint>
20
21 #include "llvm/ADT/BitVector.h"
22 #include "llvm/ADT/DenseMap.h"
23 #include "llvm/ADT/DenseSet.h"
24 #include "llvm/ADT/MapVector.h"
25 #include "llvm/ADT/STLExtras.h"
26 #include "llvm/ADT/SetVector.h"
27 #include "llvm/ADT/SmallVector.h"
28 #include "llvm/ADT/StringRef.h"
29 #include "llvm/Support/Casting.h"
30 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
31 #include "mlir/IR/Attributes.h" // from @llvm-project
32 #include "mlir/IR/Block.h" // from @llvm-project
33 #include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project
34 #include "mlir/IR/Builders.h" // from @llvm-project
35 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
36 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
37 #include "mlir/IR/Diagnostics.h" // from @llvm-project
38 #include "mlir/IR/Operation.h" // from @llvm-project
39 #include "mlir/IR/Region.h" // from @llvm-project
40 #include "mlir/IR/SymbolTable.h" // from @llvm-project
41 #include "mlir/IR/TypeUtilities.h" // from @llvm-project
42 #include "mlir/IR/Types.h" // from @llvm-project
43 #include "mlir/IR/Value.h" // from @llvm-project
44 #include "mlir/IR/Verifier.h" // from @llvm-project
45 #include "mlir/IR/Visitors.h" // from @llvm-project
46 #include "mlir/Pass/Pass.h" // from @llvm-project
47 #include "mlir/Support/LLVM.h" // from @llvm-project
48 #include "mlir/Support/LogicalResult.h" // from @llvm-project
49 #include "mlir/Transforms/RegionUtils.h" // from @llvm-project
50 #include "tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h"
51 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
52 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
53 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
54 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
55 #include "tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting_cleanup.h"
56 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
57 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
58 #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
59 #include "tensorflow/core/framework/tensor_shape.pb.h"
60
61 namespace mlir {
62
63 namespace {
64
65 // This pass lifts resource variable operations outside of device computation.
66 // This is useful because a lot of accelerator devices can not interact with
// resource variables directly.
68 //
69 // Here is a simple example in TensorFlow where a device doubles the value of a
70 // TensorFlow resource variable and returns new value:
71 //
72 // %resource_handle = "tf.VarHandleOp"()
73 // %1 = "tf_device.cluster"() ( {
74 // %init_value = "tf.ReadVariableOp"(%resource_handle)
75 // "tf.AssignAddVariableOp"(%resource_handle, %init_value)
76 // %new_value = "tf.ReadVariableOp"(%resource_handle)
77 // tf_device.return %new_value
78 // })
79 //
80 // After this pass, the computation would become:
81 //
82 // %resource_handle = "tf.VarHandleOp"()
83 // %init_value = "tf.ReadVariableOp"(%resource_handle)
84 // %1:2 = "tf_device.cluster"() ( {
85 // %new_value = "tf.AddV2"(%init_value, %init_value)
86 // tf_device.return %new_value, %new_value
87 // })
88 // "tf.AssignVariableOp"(%resource_handle, %1#1)
89 //
90 // You can see that there are a few main changes applied:
91 // 1) All the resource variable reads and writes are now outside of
92 // tf_device.cluster op.
93 // 2) Instead of taking resource handles as input, this device computation now
94 // takes snapshotted values of that device.
95 // 3) Some resource load operations are eliminated with store-load forwarding.
96 // 4) Updated values to resource are appended to `tf_device.return` and used by
97 // external resource store operations so that resources are still updated
98 // after the computation.
99 //
100 // If the cluster body contains functional control flow, the pass first lifts
101 // the loads/stores in the body/cond/branch functions to the cluster body, then
102 // performs the above lifting. E.g.,
103 //
104 // func @cluster_with_loop() -> () {
105 // %0 = "tf.VarHandleOp"() ...
106 // "tf_device.cluster"() ( {
107 // %1 = "tf.While"(%0) {body = @while_body, cond = @while_cond}
108 // tf_device.return
109 // })
110 // return
111 // }
112 // func @while_body(%arg0: tensor<*x!tf.resource<tensor<f32>>>) {
113 // %constant = "tf.Const"() ...
114 // "tf.AssignVariableOp"(%arg0, %constant)
115 // return %arg0
116 // }
117 // func @while_cond(%arg0: tensor<*x!tf.resource<tensor<f32>>>) {
118 // %read = "tf.ReadVariableOp"(%arg0)
119 // return %read
120 // }
121 //
122 // will be transformed to:
123 //
124 // func @cluster_with_loop() {
125 // %0 = "tf.VarHandleOp"() ...
126 // %1 = "tf.ReadVariableOp"(%0)
127 // %2 = "tf_device.cluster"() ( {
128 // %3 = "tf.While"(%1) {body = @while_body, cond = @while_cond}
129 // tf_device.return %3 : tensor<f32>
130 // }) : () -> tensor<f32>
131 // "tf.AssignVariableOp"(%0, %2)
132 // return
133 // }
134 // func @while_body(%arg0: tensor<f32>) {
135 // %0 = "tf.Const"() ...
136 // return %0 : tensor<f32>
137 // }
138 // func @while_cond(%arg0: tensor<f32>) {
139 // return %arg0
140 // }
141 //
struct ResourceOpLiftingPass
    : public PassWrapper<ResourceOpLiftingPass, OperationPass<ModuleOp>> {
  // Pass entry point; this pass is declared to operate on a ModuleOp.
  void runOnOperation() override;
};
146
IsResource(Value value)147 bool IsResource(Value value) {
148 return getElementTypeOrSelf(value.getType()).isa<TF::ResourceType>();
149 }
150
151 // Get the type of the data contained in a resource. Returns null if there is
152 // no single type in the resource.
GetResourceSubtype(Value value)153 Type GetResourceSubtype(Value value) {
154 auto resource_type =
155 getElementTypeOrSelf(value.getType()).dyn_cast<TF::ResourceType>();
156 auto subtypes = resource_type.getSubtypes();
157 if (subtypes.size() == 1) return subtypes[0];
158 return nullptr;
159 }
160
161 // Replaces all `tf.VarIsInitializedOp` in a block with a constant true.
162 // TODO(b/171039585): Replace this with proper analysis of
163 // `tf.VarIsInitializedOp` in regards to resource writes and control flow.
SetAllVarIsInitializedToTrue(Block * block)164 void SetAllVarIsInitializedToTrue(Block* block) {
165 auto builder = OpBuilder::atBlockBegin(block);
166 TF::ConstOp const_true = nullptr;
167 for (auto op :
168 llvm::make_early_inc_range(block->getOps<TF::VarIsInitializedOp>())) {
169 builder.setInsertionPoint(op);
170 if (!const_true)
171 const_true = builder.create<TF::ConstOp>(
172 op.getLoc(),
173 DenseIntElementsAttr::get(
174 RankedTensorType::get(/*shape=*/{}, builder.getI1Type()), true));
175
176 op.is_initialized().replaceAllUsesWith(const_true);
177 op.erase();
178 }
179 }
180
181 // Performs store-load forwarding. This effectively removes
182 // 1) Any resource loads after a store to that same resource is done
183 // 2) Any resource stores except the last one.
184 // TODO(ycao): Store-load forwarding implemented here is only correct when
185 // computation is purely sequential (no concurrency). Need to support concurrent
186 // computation as well.
ForwardStoreToLoad(Block * block)187 void ForwardStoreToLoad(Block* block) {
188 // resource_handle_to_last_store_op keeps track of the most recent (last)
189 // store to each resource. Non-existent entry indicates that a resource has
190 // not been stored to yet.
191 llvm::SmallDenseMap<Value, TF::AssignVariableOp>
192 resource_handle_to_last_store_op;
193
194 // Only iterate through ops directly in the block as we can't handle ops
195 // nested deeper in regions.
196 for (Operation& op : llvm::make_early_inc_range(*block)) {
197 if (auto read_variable_op = dyn_cast<TF::ReadVariableOp>(&op)) {
198 Value resource = read_variable_op.resource();
199 auto last_store = resource_handle_to_last_store_op[resource];
200 if (!last_store) continue;
201
202 // Use stored value in last_store to replace all uses of current resource
203 // load's result, then erase this resource load. Add an intermediate
204 // CastOp if the shape of types doesn't exactly match.
205 Type read_type = read_variable_op.value().getType();
206 if (read_type != last_store.value().getType()) {
207 OpBuilder builder(last_store);
208 builder.setInsertionPointAfter(last_store);
209 auto cast = builder.create<TF::CastOp>(
210 last_store.getLoc(), read_type, last_store.value(),
211 /*Truncate=*/builder.getBoolAttr(false));
212 read_variable_op.value().replaceAllUsesWith(cast);
213 } else {
214 read_variable_op.value().replaceAllUsesWith(last_store.value());
215 }
216
217 read_variable_op.erase();
218 continue;
219 }
220
221 if (auto assign_variable_op = dyn_cast<TF::AssignVariableOp>(&op)) {
222 Value resource = assign_variable_op.resource();
223 auto last_store = resource_handle_to_last_store_op[resource];
224 // Previous store ops to same resource can be erased.
225 if (last_store) last_store.erase();
226
227 resource_handle_to_last_store_op[resource] = assign_variable_op;
228 }
229 }
230 }
231
232 //===----------------------------------------------------------------------===//
233 // RegionResourceHoister
234 //===----------------------------------------------------------------------===//
235
236 // Helper class to hoist resource ops out of regions attached to an op.
237 class RegionResourceHoister {
238 public:
RegionResourceHoister(Operation * op)239 explicit RegionResourceHoister(Operation* op) : op_(op) {}
240
241 // Analyzes attached regions to record resources read and written.
242 LogicalResult Analyze();
243
244 // Returns all resources accessed by the regions attached the op.
GetResources()245 auto& GetResources() { return resources_; }
246
247 // Returns if the given value is a resource that needs lifting.
Contains(Value resource) const248 bool Contains(Value resource) const {
249 return resources_.find(resource) != resources_.end();
250 }
251
252 // Drops the given resource from lifting.
DropResource(Value resource)253 void DropResource(Value resource) {
254 resources_.erase(resource);
255 written_resources_.remove(resource);
256 }
257
258 // Replaces all resource loads in all regions attached to the op.
ReplaceResourceLoads(bool read_only)259 void ReplaceResourceLoads(bool read_only) {
260 llvm::for_each(op_->getRegions(), [&](Region& region) {
261 ReplaceResourceLoads(region, read_only);
262 });
263 }
264
265 static LogicalResult ReplaceOpWithNewOp(Operation* op);
266
267 private:
268 // Returns if any resources need lifting.
NeedsLifting() const269 bool NeedsLifting() const { return !resources_.empty(); }
270
271 // Returns the number of results generated by the lifted op.
GetLiftedNumResults() const272 int GetLiftedNumResults() const { return num_new_results_; }
273
274 // Generates hoisted reads for resources that need them before the op.
275 void GenerateHoistedReads();
276
277 // Replaces all resource loads in the given region with hoisted loads. If
278 // `read_only` is true, limit this replacement to read only resources.
279 void ReplaceResourceLoads(Region& region, bool read_only);
280
281 // Appends final values writte to resources to the region returns for the
282 // given set of regions.
283 void AppendResourceStoreValueToReturn(RegionRange regions);
284
285 // Performs the final replacement of the op.
286 void ReplaceOpWithNewOp();
287
288 // Returns is this resource was written to in any of the regions.
IsWritten(Value resource) const289 bool IsWritten(Value resource) const {
290 return written_resources_.contains(resource);
291 }
292
293 static LogicalResult HoistResourcesOutOfIfCaseCluster(Operation* op);
294 static LogicalResult HoistResourcesOutOfWhileRegion(TF::WhileRegionOp op);
295
296 Operation* op_;
297
298 // Per resource information about accesses to that resource.
299 struct ResourceInfo {
300 // Is this resource read in any of the regions?
301 bool is_read;
302 // Is this resource written in any of the regions?
303 bool is_written;
304 // Is this resource written in all of the regions?
305 bool is_written_all;
306 // The hoisted read used to replace region reads.
307 Value hoisted_read;
308 // the type of the data held by the resource.
309 Type data_type;
310 // For written resources, the result # of the lifted op which will hold the
311 // value of the resource. This result will be used to generates writes to
312 // the resource after the lifted op.
313 int result_index;
314 // Attributes on the read operation.
315 DictionaryAttr read_attrs;
316 // Attributes on the write operation.
317 DictionaryAttr write_attrs;
318
ResourceInfomlir::__anonf2669e680111::RegionResourceHoister::ResourceInfo319 ResourceInfo()
320 : is_read(false),
321 is_written(false),
322 is_written_all(false),
323 hoisted_read(nullptr),
324 data_type(nullptr),
325 result_index(-1) {}
326
IsResultIndexAssignedmlir::__anonf2669e680111::RegionResourceHoister::ResourceInfo327 bool IsResultIndexAssigned() { return result_index != -1; }
328
329 // Refine the resource type using the given type `type`.
RefineTypemlir::__anonf2669e680111::RegionResourceHoister::ResourceInfo330 void RefineType(Type type) {
331 if (!data_type) {
332 data_type = type;
333 } else {
334 data_type = TF::GetCastCompatibleType(data_type, type,
335 /*may_ignore_ref_type_a=*/false);
336 assert(data_type != nullptr && "Resource used with incompatible types");
337 }
338 }
339 };
340 llvm::MapVector<Value, ResourceInfo> resources_;
341 llvm::SetVector<Value> written_resources_;
342 // number of new results after lifting.
343 int num_new_results_;
344 };
345
346 // Analyzes resources that are read or written within attached regions.
Analyze()347 LogicalResult RegionResourceHoister::Analyze() {
348 // Hoisting of child regions might have created opportunity for store-load
349 // forwarding.
350 for (Region& region : op_->getRegions()) {
351 ForwardStoreToLoad(®ion.front());
352 }
353
354 llvm::SetVector<Value> all_resources;
355 bool is_func = false;
356 // For functions, the resources to analyze are the function arguments.
357 // Otherwise, its the region captures.
358 if (FuncOp func = dyn_cast<FuncOp>(op_)) {
359 is_func = true;
360 Region& body = func.getBody();
361 for (BlockArgument arg : body.getArguments()) {
362 if (IsResource(arg)) all_resources.insert(arg);
363 }
364 } else {
365 getUsedValuesDefinedAbove(op_->getRegions(), all_resources);
366 all_resources.remove_if([](Value value) { return !IsResource(value); });
367 }
368
369 num_new_results_ = op_->getNumResults();
370
371 for (auto resource : all_resources) {
372 ResourceInfo info;
373 info.data_type = GetResourceSubtype(resource);
374 llvm::BitVector written_regions(op_->getNumRegions());
375 bool unsupported_use = false;
376 for (OpOperand& use : resource.getUses()) {
377 Operation* user = use.getOwner();
378 // If the user is not in one of the regions, we are not interested in it.
379 // Since all the sub-regions within this region (i.e., regions attached to
380 // op's in this region) have themselves gone through lifting, all resource
381 // users are expected to be operations in this region and not embedded
382 // within other sub-regions attached to op's in this region. So the check
383 // for whether a user is in one of the regions attached to this op is
384 // straightforward.
385 if (user->getParentRegion()->getParentOp() != op_) continue;
386
387 // For functions, if the resource is used as a return operand, use that
388 // as its result index.
389 if (is_func && isa<ReturnOp>(user)) {
390 assert(!info.IsResultIndexAssigned() &&
391 "Expect resource argument to returned no more than once");
392 info.result_index = use.getOperandNumber();
393 continue;
394 }
395
396 auto read = dyn_cast<TF::ReadVariableOp>(user);
397 auto write = dyn_cast<TF::AssignVariableOp>(user);
398 if (!read && !write) {
399 unsupported_use = true;
400 break;
401 }
402
403 if (read && !info.is_read) {
404 info.is_read = true;
405 info.RefineType(read.value().getType());
406 info.read_attrs = user->getAttrDictionary();
407 }
408
409 if (write) {
410 info.is_written = true;
411 info.RefineType(write.value().getType());
412 info.write_attrs = user->getAttrDictionary();
413 written_regions.set(user->getParentRegion()->getRegionNumber());
414 }
415 }
416
417 // If the resource is used in an op that we do not understand, skip
418 // lifting for that resource.
419 if (unsupported_use) continue;
420
421 info.is_written_all = written_regions.count() == op_->getNumRegions();
422
423 // If the resource is written in some but not all regions, we would need
424 // a read for the value before these regions. Note that this is applicable
425 // only to multi-region ops:
426 // If/Case: If not all regions write to the resource, post hoisting the read
427 // value need to be routed through all paths that don't write.
428 // While: since while condition cannot write, any resource written in the
429 // while body will need to be read as well in case the while body is never
430 // executed.
431 // Both cases are handled by the condition below.
432 if (info.is_written && !info.is_written_all) info.is_read = true;
433
434 // Allocate a result index for written resources that don't have one.
435 if (info.is_written) {
436 written_resources_.insert(resource);
437 if (!info.IsResultIndexAssigned()) info.result_index = num_new_results_++;
438 }
439
440 resources_.insert({resource, info});
441 }
442 return success();
443 }
444
445 // Generates hoisted reads for all resources that need them just before the op.
GenerateHoistedReads()446 void RegionResourceHoister::GenerateHoistedReads() {
447 OpBuilder builder(op_);
448 DictionaryAttr empty_attrs = builder.getDictionaryAttr({});
449 for (auto& resource_it : GetResources()) {
450 Value resource = resource_it.first;
451 auto& info = resource_it.second;
452
453 if (info.is_read) {
454 Operation* read = builder.create<TF::ReadVariableOp>(
455 op_->getLoc(), info.data_type, resource);
456 read->setAttrs(info.read_attrs ? info.read_attrs : empty_attrs);
457 info.hoisted_read = read->getResult(0);
458 }
459 }
460 }
461
462 // Replaces all resource reads with the hoisted read.
ReplaceResourceLoads(Region & region,bool read_only)463 void RegionResourceHoister::ReplaceResourceLoads(Region& region,
464 bool read_only) {
465 assert(llvm::hasSingleElement(region) && "Expected single block region");
466 // Only iterate through ops directly in the body as we can't handle
467 // ops nested deeper in regions.
468 auto all_reads = region.front().getOps<TF::ReadVariableOp>();
469 for (auto read_op : llvm::make_early_inc_range(all_reads)) {
470 Value resource = read_op.resource();
471 if (!Contains(resource)) continue;
472
473 ResourceInfo& info = resources_[resource];
474 // If replacing loads for read only resources, skip if the resource
475 // was written to.
476 if (read_only && info.is_written) continue;
477
478 read_op.replaceAllUsesWith(info.hoisted_read);
479 read_op.erase();
480 }
481 }
482
483 // For written resources, add its value at the end of each region to that
484 // regions return value. For a region, its value at the end may be a value
485 // written to that resource in that region, or its hoisted read value if the
486 // resource is not written in that region. The return value can be vended out
487 // either as an existing return value, or a newly allocated return value.
AppendResourceStoreValueToReturn(RegionRange regions)488 void RegionResourceHoister::AppendResourceStoreValueToReturn(
489 RegionRange regions) {
490 for (Region* region : regions) {
491 assert(llvm::hasSingleElement(*region) && "Expected single block region");
492 Block& front = region->front();
493 auto old_return = front.getTerminator();
494 assert(old_return->getNumOperands() == op_->getNumResults());
495 auto new_return_operands = llvm::to_vector<4>(old_return->getOperands());
496 new_return_operands.resize(num_new_results_);
497
498 // initialize return values for written resources to be the hoisted reads.
499 for (Value resource : written_resources_) {
500 const ResourceInfo& info = resources_[resource];
501 new_return_operands[info.result_index] = info.hoisted_read;
502 }
503
504 // Only iterate through ops directly in the body as op's embedded in child
505 // regions should have been lifted out.
506 auto assign_ops = front.getOps<TF::AssignVariableOp>();
507 for (auto assign_variable_op : llvm::make_early_inc_range(assign_ops)) {
508 Value resource = assign_variable_op.resource();
509 if (!IsWritten(resource)) continue;
510
511 // TODO(ycao): Prevent same value from being returned multiple times.
512 // TODO(ycao): Do not return resource store value if it is defined outside
513 // of cluster. Both of these can be post-resource-op-lifting cleanup
514 // passes.
515 int result_index = resources_[resource].result_index;
516 new_return_operands[result_index] = assign_variable_op.value();
517 assign_variable_op.erase();
518 }
519 old_return->setOperands(new_return_operands);
520 }
521 }
522
523 // Replace the old op with a new op (with potentially additional results), and
524 // add stores to written resources after the new op.
ReplaceOpWithNewOp()525 void RegionResourceHoister::ReplaceOpWithNewOp() {
526 auto new_result_types = llvm::to_vector<4>(op_->getResultTypes());
527 int result_region = isa<TF::WhileRegionOp>(op_) ? 1 : 0;
528 Operation* terminator = op_->getRegion(result_region).front().getTerminator();
529 auto extra_result_types =
530 terminator->getOperands().drop_front(op_->getNumResults()).getTypes();
531 new_result_types.insert(new_result_types.end(), extra_result_types.begin(),
532 extra_result_types.end());
533 OpBuilder builder(op_);
534 // Clone this old operation but with new result types.
535 Operation* new_op = Operation::create(
536 op_->getLoc(), op_->getName(), new_result_types, op_->getOperands(),
537 op_->getAttrs(), op_->getSuccessors(), op_->getNumRegions());
538 builder.insert(new_op);
539
540 // Move regions to the new op.
541 for (auto it : llvm::zip(op_->getRegions(), new_op->getRegions())) {
542 Region& old_region = std::get<0>(it);
543 Region& new_region = std::get<1>(it);
544 new_region.takeBody(old_region);
545 }
546
547 // Insert stores to all written resources.
548 for (Value resource : written_resources_) {
549 ResourceInfo& info = resources_[resource];
550 Value value_to_write = new_op->getResult(info.result_index);
551 Operation* write = builder.create<TF::AssignVariableOp>(
552 op_->getLoc(), resource, value_to_write);
553 write->setAttrs(info.write_attrs);
554 }
555
556 // As a part of lifting, we either reuse an existing slot for resource type
557 // results or add a new slot. Resource type results should not have any uses
558 // to begin with. So we can safely replace each old op result with the
559 // corresponding new op result.
560 int old_num_results = op_->getNumResults();
561 op_->replaceAllUsesWith(new_op->getResults().take_front(old_num_results));
562 op_->erase();
563 op_ = nullptr;
564 }
565
566 // Lift resource load and stores out of regions attached to `op`, where op is
567 // an If/case/cluster op.
HoistResourcesOutOfIfCaseCluster(Operation * op)568 LogicalResult RegionResourceHoister::HoistResourcesOutOfIfCaseCluster(
569 Operation* op) {
570 RegionResourceHoister hoister(op);
571 if (failed(hoister.Analyze())) return failure();
572
573 // If there are no resource region captures, then nothing to do.
574 if (!hoister.NeedsLifting()) return success();
575
576 // Start the transformation. For each region, replace the resource read with
577 // the value read before the op.
578 hoister.GenerateHoistedReads();
579 hoister.ReplaceResourceLoads(/*read_only=*/false);
580 hoister.AppendResourceStoreValueToReturn(op->getRegions());
581 hoister.ReplaceOpWithNewOp();
582 return success();
583 }
584
585 // Lift resource loads and stores out of WhileRegion
HoistResourcesOutOfWhileRegion(TF::WhileRegionOp op)586 LogicalResult RegionResourceHoister::HoistResourcesOutOfWhileRegion(
587 TF::WhileRegionOp op) {
588 // For WhileRegion, post canonicalization all resource used within the
589 // body and condition regions are replaced with captured values, so we do not
590 // need to take into account the body and condition region arguments.
591 RegionResourceHoister hoister(op);
592
593 if (failed(hoister.Analyze())) return failure();
594
595 // If there are no resource region captures, then nothing to do.
596 if (!hoister.NeedsLifting()) return success();
597
598 // The resources captured for While loop fall into two categories:
599 // (a) read-only. These reads can be replaced by a hoisted read created
600 // before the WhileOp (similar to if and case).
601 // (b) written: since the value is written in the loop (which can only in
602 // loop body, all these will become loop variables. Since all resource
603 // variables are removed from the loop variabled during
604 // canonicalizationW, we need to create new operand/result slots. The
605 // input operands for these slots are the read values
606 // prior to the op, and all references to these are replaced by the
607 // corresponding slot argument. We need to generate writes following
608 // the while for these resources.
609 //
610 // Note that for WhileRegion ops, if a resource is written, it will be written
611 // only in the body and not the condition, so the hoister analysis will infer
612 // it as needing a read as well.
613
614 // Generate hoisted reads before the while.
615 hoister.GenerateHoistedReads();
616
617 // Replace just the read-only resources with the hoisted reads.
618 hoister.ReplaceResourceLoads(/*read_only=*/true);
619
620 // For written resources, add additional operands to the while op.
621 int num_old_results = op.getNumResults();
622 int num_new_results = hoister.GetLiftedNumResults();
623 int num_extra_results = num_new_results - num_old_results;
624
625 SmallVector<Type, 4> new_result_types;
626 SmallVector<Value, 4> new_while_operands;
627 new_result_types.resize(num_extra_results);
628 new_while_operands.resize(num_extra_results);
629
630 for (auto& it : hoister.GetResources()) {
631 if (!it.second.is_written) continue;
632 int index = it.second.result_index - num_old_results;
633 new_result_types[index] = it.second.data_type;
634 new_while_operands[index] = it.second.hoisted_read;
635 }
636 op.getOperation()->insertOperands(op.getNumOperands(), new_while_operands);
637
638 // Patch the cond and body regions to have additional arguments, and replace
639 // the remaining resource reads (which will be resource reads for written
640 // resources) with these arguments.
641 for (Region* region : op.getRegions()) {
642 region->addArguments(new_result_types);
643 // Point hoisted read for written resources to the region's arguments.
644 for (auto& it : hoister.GetResources()) {
645 if (!it.second.is_written) continue;
646 it.second.hoisted_read = region->getArgument(it.second.result_index);
647 }
648 hoister.ReplaceResourceLoads(*region, /*read_only=*/false);
649 }
650
651 // Add additional return values to body return. These correspond to values
652 // written to resources in the body region.
653 hoister.AppendResourceStoreValueToReturn(op.getRegions().drop_front());
654
655 // Finally, create a new while with additional return values.
656 hoister.ReplaceOpWithNewOp();
657 return success();
658 }
659
660 // Lift resources out of the regions attached to `op`
ReplaceOpWithNewOp(Operation * op)661 LogicalResult RegionResourceHoister::ReplaceOpWithNewOp(Operation* op) {
662 if (auto while_op = dyn_cast<TF::WhileRegionOp>(op))
663 return HoistResourcesOutOfWhileRegion(while_op);
664 return HoistResourcesOutOfIfCaseCluster(op);
665 }
666
667 // Holds information about a function's use of a resource argument.
668 struct ResourceArgUseInfo {
669 // Data type of the data contained in the resource.
670 Type data_type;
671 // Is the resource argument used in an assign op?
672 bool updated;
673 // Is the resource argument used in a read or assign op?
674 bool used;
675 };
676
677 // Finds the ResourceArgUseInfo for each resource argument. Forwarding to the
678 // output (i.e., the argument is an operand of the return op) is not considered
679 // as a use. This doesn't support nesting of ops, so before calling this, nested
680 // ops/functions need to be already resource-lifted.
FindResourceArgUseInfo(FuncOp func_op,llvm::SmallDenseMap<int64_t,ResourceArgUseInfo> * result)681 LogicalResult FindResourceArgUseInfo(
682 FuncOp func_op, llvm::SmallDenseMap<int64_t, ResourceArgUseInfo>* result) {
683 auto return_op = func_op.front().getTerminator();
684 for (auto arg : TF::filter_resources(func_op.getArguments())) {
685 ResourceArgUseInfo info;
686 info.used = false;
687 info.updated = false;
688 bool read_or_assigned = false;
689 bool used_in_unsupported_op = false;
690 for (auto user : arg.getUsers()) {
691 if (user == return_op) continue;
692 info.used = true;
693 if (auto read = llvm::dyn_cast<TF::ReadVariableOp>(user)) {
694 read_or_assigned = true;
695 info.data_type = read.getType();
696 continue;
697 }
698
699 if (auto assign = llvm::dyn_cast<TF::AssignVariableOp>(user)) {
700 read_or_assigned = true;
701 info.updated = true;
702 info.data_type = assign.value().getType();
703 continue;
704 }
705
706 used_in_unsupported_op = true;
707 break;
708 }
709
710 // If the arg is used in an unsupported op, skip lifting it.
711 if (used_in_unsupported_op) continue;
712 (*result)[arg.getArgNumber()] = info;
713 }
714 return success();
715 }
716
717 // Merges two sets of resource arg use infos. An argument is considered used in
718 // the merged result as long as either set marks it as used. This is used to
719 // merge results from functions that have aliasing inputs, e.g., a while loop's
720 // body and condition. The sets of keys of the two maps must be the same.
MergeArgResourceUseInfo(const llvm::SmallDenseMap<int64_t,ResourceArgUseInfo> & infos0,const llvm::SmallDenseMap<int64_t,ResourceArgUseInfo> & infos1)721 llvm::SmallDenseMap<int64_t, ResourceArgUseInfo> MergeArgResourceUseInfo(
722 const llvm::SmallDenseMap<int64_t, ResourceArgUseInfo>& infos0,
723 const llvm::SmallDenseMap<int64_t, ResourceArgUseInfo>& infos1) {
724 llvm::SmallDenseMap<int64_t, ResourceArgUseInfo> result;
725 for (const auto& entry : infos0) {
726 auto info1_it = infos1.find(entry.getFirst());
727 // If the entry is missing in any input, we should not touch this entry.
728 if (info1_it == infos1.end()) continue;
729 auto& info = result[entry.getFirst()];
730 info = entry.getSecond();
731 if (info.updated) continue;
732 if (info1_it->getSecond().used) {
733 info.used = true;
734 info.updated = info1_it->getSecond().updated;
735 info.data_type = info1_it->getSecond().data_type;
736 }
737 }
738 return result;
739 }
740
741 // Removes the unused resource arguments, and the return values that forward the
742 // removed arguments. If old_to_new_arg_indices is provided, it will store the
743 // new argument index that corresponds to each original index (-1 means it is
744 // removed). If remaining_resource_data_types is provided, it will store the
745 // data types of the remaining resource arguments, where the indices are after
746 // removing unused ones.
RemoveUnusedResourceArgumentsAndForwardedRetvals(const llvm::SmallDenseMap<int64_t,ResourceArgUseInfo> & infos,FuncOp func_op,llvm::SmallVector<int64_t,4> * old_to_new_arg_indices=nullptr,llvm::SmallDenseMap<int64_t,Type> * remaining_resource_data_types=nullptr)747 void RemoveUnusedResourceArgumentsAndForwardedRetvals(
748 const llvm::SmallDenseMap<int64_t, ResourceArgUseInfo>& infos,
749 FuncOp func_op,
750 llvm::SmallVector<int64_t, 4>* old_to_new_arg_indices = nullptr,
751 llvm::SmallDenseMap<int64_t, Type>* remaining_resource_data_types =
752 nullptr) {
753 // Remove return values forwarded from unused arguments.
754 auto return_op = func_op.front().getTerminator();
755 auto old_return_vals = llvm::to_vector<8>(return_op->getOperands());
756 int64_t skipped_retvals = 0;
757 for (auto entry : llvm::enumerate(old_return_vals)) {
758 auto return_val = entry.value();
759 if (auto arg = return_val.dyn_cast<BlockArgument>()) {
760 auto it = infos.find(arg.getArgNumber());
761 if (it != infos.end() && !it->getSecond().used) {
762 return_op->eraseOperand(entry.index() - skipped_retvals++);
763 }
764 }
765 }
766 llvm::SmallVector<unsigned int, 4> indices_to_erase;
767 llvm::SmallVector<Type, 4> new_types;
768 int64_t skipped_args = 0;
769 for (auto arg : func_op.getArguments()) {
770 auto it = infos.find(arg.getArgNumber());
771 if (it != infos.end() && !it->getSecond().used) {
772 indices_to_erase.push_back(arg.getArgNumber());
773 skipped_args++;
774 if (old_to_new_arg_indices != nullptr) {
775 old_to_new_arg_indices->push_back(-1);
776 }
777 } else {
778 new_types.push_back(arg.getType());
779 if (old_to_new_arg_indices != nullptr) {
780 old_to_new_arg_indices->push_back(arg.getArgNumber() - skipped_args);
781 }
782 if (it != infos.end() && remaining_resource_data_types != nullptr) {
783 (*remaining_resource_data_types)[arg.getArgNumber() - skipped_args] =
784 it->second.data_type;
785 }
786 }
787 }
788 func_op.eraseArguments(indices_to_erase);
789 func_op.setType(
790 FunctionType::get(func_op.getContext(), new_types,
791 llvm::to_vector<4>(return_op->getOperandTypes())));
792 }
793
794 // Lifts reads/writes of resource arguments from func_op and changes its
795 // signature. resource_data_types is the (index, data type) pair for each
796 // resource argument. handle_updated_arg_value is a caller-provided function
797 // that handles the updated value for an resource argument.
LiftArgRetResourcesForFunction(FuncOp func_op,const llvm::SmallDenseMap<int64_t,Type> & resource_data_types,llvm::function_ref<void (int64_t,Value)> handle_updated_arg_value)798 LogicalResult LiftArgRetResourcesForFunction(
799 FuncOp func_op,
800 const llvm::SmallDenseMap<int64_t, Type>& resource_data_types,
801 llvm::function_ref<void(int64_t, Value)> handle_updated_arg_value) {
802 RegionResourceHoister hoister(func_op);
803 if (failed(hoister.Analyze())) return failure();
804
805 // Each of these resources could be read or written in the function. If its
806 // read, we need to replace the resource arg with a value arg to get the
807 // read value. If its written, we need to replace the write with an additional
808 // value to be written.
809
810 // Now create read values that will be used to replace each resource that
811 // is read in the function body. These read values are just the same argument
812 // with type replaced.
813 llvm::SmallVector<Value, 4> skipped_args;
814 for (auto& it : hoister.GetResources()) {
815 BlockArgument arg = it.first.dyn_cast<BlockArgument>();
816 assert(arg && "Expect resources for FuncOp to be its arguments");
817 auto type_iter = resource_data_types.find(arg.getArgNumber());
818 if (type_iter == resource_data_types.end()) {
819 // Skip lifting the resource if it's not present in the data type map.
820 // This indicates that the resource is not to be lifted because it is used
821 // in an unsupported op in some other function.
822 skipped_args.push_back(arg);
823 } else {
824 arg.setType(type_iter->second);
825 it.second.hoisted_read = arg;
826 }
827 }
828
829 // Drop all the args that have to be skipped.
830 for (Value arg : skipped_args) hoister.DropResource(arg);
831
832 hoister.ReplaceResourceLoads(/*read_only=*/false);
833
834 // For writes, invoke the callback and then erase the write.
835 auto assign_ops = func_op.front().getOps<TF::AssignVariableOp>();
836 for (auto assign_variable_op : llvm::make_early_inc_range(assign_ops)) {
837 Value resource = assign_variable_op.resource();
838 if (!hoister.Contains(resource)) continue;
839
840 auto arg = resource.dyn_cast<BlockArgument>();
841 handle_updated_arg_value(arg.getArgNumber(), assign_variable_op.value());
842 assign_variable_op.erase();
843 }
844
845 func_op.setType(FunctionType::get(
846 func_op.getContext(), func_op.front().getArgumentTypes(),
847 func_op.front().getTerminator()->getOperandTypes()));
848
849 return success();
850 }
851
852 // Returns a vector filtered from range where the unused elements (specified by
853 // resource_arg_uses) are removed.
854 template <typename T, typename Range>
FilterRange(Range range,const llvm::SmallDenseMap<int64_t,ResourceArgUseInfo> & resource_arg_uses)855 llvm::SmallVector<T, 4> FilterRange(
856 Range range,
857 const llvm::SmallDenseMap<int64_t, ResourceArgUseInfo>& resource_arg_uses) {
858 llvm::SmallVector<T, 4> filtered;
859 for (auto entry : llvm::enumerate(range)) {
860 auto it = resource_arg_uses.find(entry.index());
861 if (it == resource_arg_uses.end() || it->getSecond().used)
862 filtered.push_back(entry.value());
863 }
864 return filtered;
865 }
866
867 // Changes the types of the control flow op (e.g., while, if) and adds loads and
868 // stores around it. arg_data_type_and_updated_output_index maps an operand (to
869 // be changed) index to its data type and the updated value index in the output
870 // (-1 means not updated.)
AddLoadsStoresOutsideControlFlowOp(Operation * caller,const llvm::SmallDenseMap<int64_t,std::pair<Type,int64_t>> & arg_data_type_and_updated_output_index)871 void AddLoadsStoresOutsideControlFlowOp(
872 Operation* caller,
873 const llvm::SmallDenseMap<int64_t, std::pair<Type, int64_t>>&
874 arg_data_type_and_updated_output_index) {
875 OpBuilder builder(caller);
876 auto new_operands = llvm::to_vector<8>(caller->getOperands());
877 llvm::SmallVector<int64_t, 8> changed_indices;
878 // Find the operands to change, and create the loads.
879 for (auto& entry : arg_data_type_and_updated_output_index) {
880 int64_t index = entry.getFirst();
881 Type new_type = entry.getSecond().first;
882 int64_t updated_index = entry.getSecond().second;
883 auto operand = caller->getOperand(index);
884 builder.setInsertionPoint(caller);
885 new_operands[index] = builder.create<TF::ReadVariableOp>(
886 caller->getLoc(), ArrayRef<Type>{new_type}, ArrayRef<Value>{operand});
887 caller->setOperand(index, new_operands[index]);
888 if (updated_index < 0) continue;
889 builder.setInsertionPointAfter(caller);
890 builder.create<TF::AssignVariableOp>(
891 caller->getLoc(), ArrayRef<Type>{},
892 ArrayRef<Value>{operand, caller->getResult(updated_index)});
893 }
894 }
895
896 // Lifts loads/stores from while loop's body and cond functions.
HandleWhileLoop(TF::WhileOp while_op,FuncOp body,FuncOp cond)897 LogicalResult HandleWhileLoop(TF::WhileOp while_op, FuncOp body, FuncOp cond) {
898 auto return_op = body.front().getTerminator();
899 llvm::SmallDenseMap<int64_t, ResourceArgUseInfo> body_use_info;
900 llvm::SmallDenseMap<int64_t, ResourceArgUseInfo> cond_use_info;
901 if (failed(FindResourceArgUseInfo(body, &body_use_info)) ||
902 failed(FindResourceArgUseInfo(cond, &cond_use_info))) {
903 return failure();
904 }
905 // A resource is considered used as long as it is used in either body or cond.
906 auto resource_arg_uses =
907 MergeArgResourceUseInfo(body_use_info, cond_use_info);
908 if (resource_arg_uses.empty()) return success();
909
910 // Remove unused resources in functions.
911 llvm::SmallVector<int64_t, 4> old_to_new_indices;
912 llvm::SmallDenseMap<int64_t, Type> remaining_resource_data_types;
913 RemoveUnusedResourceArgumentsAndForwardedRetvals(
914 resource_arg_uses, body, &old_to_new_indices,
915 &remaining_resource_data_types);
916 RemoveUnusedResourceArgumentsAndForwardedRetvals(resource_arg_uses, cond);
917 (void)LiftArgRetResourcesForFunction(
918 body, remaining_resource_data_types,
919 [&](int64_t index, Value value) { return_op->setOperand(index, value); });
920 (void)LiftArgRetResourcesForFunction(cond, remaining_resource_data_types,
921 [&](int64_t index, Value value) {
922 // We already checked that cond should
923 // not have variable writes.
924 assert(false && "Should not happen");
925 });
926 // Recreate the while op.
927 OpBuilder builder(while_op);
928 // Now use the filtered original operands, which will be replaced by
929 // AddLoadsStoresOutsideControlFlowOp().
930 auto new_while = builder.create<TF::WhileOp>(
931 while_op.getLoc(), body.getType().getResults(),
932 FilterRange<Value, OperandRange>(while_op.getOperands(),
933 resource_arg_uses),
934 while_op.getAttrs());
935 // Prepare for AddLoadsStoresOutsideControlFlowOp().
936 llvm::SmallDenseMap<int64_t, std::pair<Type, int64_t>>
937 arg_data_type_and_updated_output_index;
938 for (const auto& entry : remaining_resource_data_types) {
939 int64_t update_index = return_op->getOperand(entry.getFirst()) ==
940 body.getArgument(entry.getFirst())
941 ? -1
942 : entry.getFirst();
943 arg_data_type_and_updated_output_index[entry.getFirst()] = {
944 entry.getSecond(), update_index};
945 }
946 AddLoadsStoresOutsideControlFlowOp(new_while,
947 arg_data_type_and_updated_output_index);
948 // Replace uses.
949 for (int64_t i = 0, end = old_to_new_indices.size(); i < end; ++i) {
950 if (old_to_new_indices[i] >= 0) {
951 while_op.getResult(i).replaceAllUsesWith(
952 new_while.getResult(old_to_new_indices[i]));
953 }
954 }
955 while_op.erase();
956 return success();
957 }
958
959 // Lifts loads/stores from an IfOp or CaseOp's branches.
960 template <class CaseOrIfOp>
HandleCaseOrIfOp(CaseOrIfOp op,ArrayRef<FuncOp> branches)961 LogicalResult HandleCaseOrIfOp(CaseOrIfOp op, ArrayRef<FuncOp> branches) {
962 // For canonicalized If/Case, there should not be any resource outputs
963 int64_t non_resource_results = op.getNumResults();
964
965 llvm::SmallDenseMap<int64_t, ResourceArgUseInfo> resource_arg_uses;
966 if (failed(FindResourceArgUseInfo(branches.front(), &resource_arg_uses)))
967 return failure();
968
969 for (auto func : branches.drop_front()) {
970 llvm::SmallDenseMap<int64_t, ResourceArgUseInfo> branch_use_info;
971 if (failed(FindResourceArgUseInfo(func, &branch_use_info)))
972 return failure();
973 // A resource is considered used as long as it is used in either branch.
974 resource_arg_uses =
975 MergeArgResourceUseInfo(resource_arg_uses, branch_use_info);
976 }
977
978 if (resource_arg_uses.empty()) return success();
979 // Remove unused resources in functions.
980 llvm::SmallDenseMap<int64_t, Type> remaining_resource_data_types;
981 RemoveUnusedResourceArgumentsAndForwardedRetvals(
982 resource_arg_uses, branches.front(), /*old_to_new_arg_indices=*/nullptr,
983 &remaining_resource_data_types);
984 for (auto func : branches.drop_front())
985 RemoveUnusedResourceArgumentsAndForwardedRetvals(resource_arg_uses, func);
986
987 // Forward resource inputs updated in any branch to the outputs of both
988 // branches. First prepare the mapping from arg to new update output.
989 llvm::SmallDenseMap<int64_t, int64_t> resource_arg_to_new_output;
990 {
991 int64_t removed_args = 0;
992 for (const auto& entry : resource_arg_uses) {
993 if (!entry.getSecond().used) {
994 removed_args++;
995 continue;
996 }
997 if (!entry.getSecond().updated) continue;
998 int64_t new_output_index =
999 non_resource_results + resource_arg_to_new_output.size();
1000 resource_arg_to_new_output[entry.getFirst() - removed_args] =
1001 new_output_index;
1002 }
1003 }
1004
1005 // Append resource updates to the return ops: now they are just forwarded
1006 // input resources, but will be replaced by the data value in
1007 // LiftArgRetResourcesForFunction().
1008 for (auto branch : branches) {
1009 auto new_retvals =
1010 llvm::to_vector<4>(branch.front().getTerminator()->getOperands());
1011 new_retvals.resize(new_retvals.size() + resource_arg_to_new_output.size());
1012 for (const auto& entry : resource_arg_to_new_output) {
1013 int64_t resource_arg_index = entry.getFirst();
1014 int64_t output_index = entry.getSecond();
1015 new_retvals[output_index] = branch.getArgument(resource_arg_index);
1016 }
1017 auto old_return = branch.front().getTerminator();
1018 OpBuilder builder(old_return);
1019 auto new_return =
1020 builder.create<ReturnOp>(old_return->getLoc(), new_retvals);
1021 old_return->erase();
1022 (void)LiftArgRetResourcesForFunction(
1023 branch, remaining_resource_data_types, [&](int64_t index, Value value) {
1024 new_return.setOperand(resource_arg_to_new_output[index], value);
1025 });
1026 }
1027
1028 // Recreate the op without resource operands.
1029 OpBuilder builder(op);
1030 // Now use the filtered original operands, which will be replaced by
1031 // AddLoadsStoresOutsideControlFlowOp().
1032 auto new_operands =
1033 FilterRange<Value, OperandRange>(op.input(), resource_arg_uses);
1034 new_operands.insert(new_operands.begin(), op.getOperand(0));
1035 FuncOp first_func = branches.front();
1036 auto new_op =
1037 builder.create<CaseOrIfOp>(op.getLoc(), first_func.getType().getResults(),
1038 new_operands, op.getAttrs());
1039 // Prepare for AddLoadsStoresOutsideControlFlowOp()
1040 llvm::SmallDenseMap<int64_t, std::pair<Type, int64_t>>
1041 arg_data_type_and_updated_output_index;
1042 for (const auto& entry : remaining_resource_data_types) {
1043 auto new_output_it = resource_arg_to_new_output.find(entry.getFirst());
1044 int64_t update_index = new_output_it == resource_arg_to_new_output.end()
1045 ? -1
1046 : new_output_it->getSecond();
1047 arg_data_type_and_updated_output_index[entry.getFirst() + 1] = {
1048 entry.getSecond(), update_index};
1049 }
1050 AddLoadsStoresOutsideControlFlowOp(new_op,
1051 arg_data_type_and_updated_output_index);
1052 // Replace uses.
1053 op.replaceAllUsesWith(new_op.getResults().take_front(op.getNumResults()));
1054 op.erase();
1055 return success();
1056 }
1057
1058 // A resource-lifted function for (potentially multiple) PartitionedCallOps and
1059 // information about the lifting changes.
1060 struct PartitionedCallLiftingInfo {
1061 // Function with resources lifted. Can be nullptr if nothing needs to change.
1062 FuncOp lifted_callee;
1063 // Mapping from old resource outputs to their aliasing output inputs.
1064 llvm::SmallDenseMap<int64_t, int64_t> old_outputs_aliasing_old_inputs;
1065 // Mapping from old to new output indices in case any output is removed.
1066 llvm::SmallVector<int64_t, 4> old_to_new_output_indices;
1067 // ResourceArgUseInfo for each old resource argument.
1068 llvm::SmallDenseMap<int64_t, ResourceArgUseInfo> use_info;
1069 // Input for AddLoadsStoresOutsideControlFlowOp(), see its comment.
1070 llvm::SmallDenseMap<int64_t, std::pair<Type, int64_t>>
1071 arg_data_type_and_updated_output_index;
1072 };
1073
1074 // Lifts loads/stores from a PartitionedCallOp's callee function. If anything
1075 // needs to be changed, the original function will be preserved, and the lifting
1076 // happens on a clone, which will be stored in `result`.
HandlePartitionedCallOpCallee(FuncOp callee,PartitionedCallLiftingInfo * result)1077 LogicalResult HandlePartitionedCallOpCallee(
1078 FuncOp callee, PartitionedCallLiftingInfo* result) {
1079 // Sanity check: return of resources should be aliases of inputs. Such outputs
1080 // will be removed later.
1081 int64_t non_resource_results = 0;
1082 for (auto entry :
1083 llvm::enumerate(callee.front().getTerminator()->getOperands())) {
1084 auto retval = entry.value();
1085 if (!getElementTypeOrSelf(retval.getType()).isa<TF::ResourceType>()) {
1086 result->old_to_new_output_indices.push_back(non_resource_results++);
1087 continue;
1088 }
1089 auto aliasing_arg = retval.dyn_cast<BlockArgument>();
1090 if (!aliasing_arg) {
1091 return callee.emitOpError("unsupported function call: ")
1092 << "resource return value does not alias an input.";
1093 }
1094 result->old_outputs_aliasing_old_inputs[entry.index()] =
1095 aliasing_arg.getArgNumber();
1096 result->old_to_new_output_indices.push_back(-1);
1097 }
1098
1099 if (failed(FindResourceArgUseInfo(callee, &result->use_info))) {
1100 return failure();
1101 }
1102 if (result->use_info.empty()) {
1103 result->lifted_callee = nullptr;
1104 return success();
1105 }
1106
1107 // Clone the callee before making changes.
1108 SmallString<64> name_base = callee.getName();
1109 auto module = callee->getParentOfType<ModuleOp>();
1110 name_base += "_resource_lifted";
1111 auto name = name_base;
1112 callee = callee.clone();
1113 callee.setPrivate();
1114 callee.setName(name);
1115 SymbolTable(module).insert(callee);
1116 result->lifted_callee = callee;
1117
1118 // Remove unused resources in functions.
1119 llvm::SmallDenseMap<int64_t, Type> remaining_resource_data_types;
1120 RemoveUnusedResourceArgumentsAndForwardedRetvals(
1121 result->use_info, callee, /*old_to_new_arg_indices=*/nullptr,
1122 &remaining_resource_data_types);
1123 for (const auto& entry : remaining_resource_data_types) {
1124 result->arg_data_type_and_updated_output_index[entry.getFirst()] = {
1125 entry.getSecond(), -1};
1126 }
1127 llvm::SmallVector<int64_t, 4> retval_indices_to_preserve;
1128 for (auto& val : callee.front().getTerminator()->getOpOperands()) {
1129 // Store indices of results that are not resources.
1130 if (!getElementTypeOrSelf(val.get().getType()).isa<TF::ResourceType>())
1131 retval_indices_to_preserve.push_back(val.getOperandNumber());
1132 }
1133 int64_t num_retvals = retval_indices_to_preserve.size();
1134 llvm::SmallVector<Value, 4> new_retvals;
1135 // Lift resources.
1136 (void)LiftArgRetResourcesForFunction(
1137 callee, remaining_resource_data_types, [&](int64_t index, Value value) {
1138 result->arg_data_type_and_updated_output_index[index].second =
1139 num_retvals++;
1140 new_retvals.push_back(value);
1141 });
1142
1143 auto old_return = callee.front().getTerminator();
1144 llvm::SmallVector<Value, 4> old_and_new_retvals;
1145 old_and_new_retvals.reserve(retval_indices_to_preserve.size() +
1146 new_retvals.size());
1147 for (int64_t retval_index : retval_indices_to_preserve)
1148 old_and_new_retvals.push_back(old_return->getOperand(retval_index));
1149
1150 old_and_new_retvals.append(new_retvals.begin(), new_retvals.end());
1151 // Replace old return with the new ones with update values.
1152 OpBuilder builder(old_return);
1153 auto new_return =
1154 builder.create<ReturnOp>(old_return->getLoc(), old_and_new_retvals);
1155 old_return->erase();
1156 callee.setType(
1157 FunctionType::get(callee.getContext(), callee.getType().getInputs(),
1158 llvm::to_vector<4>(new_return.getOperandTypes())));
1159 return success();
1160 }
1161
1162 // Updates a PartitionedCallOp/StatefulPartitionedCallOp according to the
1163 // resource-lifted new callee function in lifting_info.
1164 template <typename CallOpType>
UpdatePartitionedCallOpWithNewCallee(CallOpType call_op,PartitionedCallLiftingInfo & lifting_info)1165 void UpdatePartitionedCallOpWithNewCallee(
1166 CallOpType call_op, PartitionedCallLiftingInfo& lifting_info) {
1167 if (!lifting_info.lifted_callee) return;
1168 // Replace output resource uses with the aliasing input, so that we can remove
1169 // this output.
1170 for (const auto& entry : lifting_info.old_outputs_aliasing_old_inputs) {
1171 call_op.getResult(entry.getFirst())
1172 .replaceAllUsesWith(call_op.getOperand(entry.getSecond()));
1173 }
1174 // Recreate the call op.
1175 OpBuilder builder(call_op);
1176 // Now use the filtered original operands, which will be replaced by
1177 // AddLoadsStoresOutsideControlFlowOp().
1178 auto new_operands =
1179 FilterRange<Value, OperandRange>(call_op.args(), lifting_info.use_info);
1180 auto new_call = builder.create<CallOpType>(
1181 call_op.getLoc(), lifting_info.lifted_callee.getType().getResults(),
1182 new_operands, call_op.getAttrs());
1183 new_call->setAttr(
1184 "f", builder.getSymbolRefAttr(lifting_info.lifted_callee.getName()));
1185 AddLoadsStoresOutsideControlFlowOp(
1186 new_call, lifting_info.arg_data_type_and_updated_output_index);
1187 // Replace uses.
1188 for (int64_t i = 0, end = lifting_info.old_to_new_output_indices.size();
1189 i < end; ++i) {
1190 if (lifting_info.old_to_new_output_indices[i] >= 0) {
1191 call_op.getResult(i).replaceAllUsesWith(
1192 new_call.getResult(lifting_info.old_to_new_output_indices[i]));
1193 }
1194 }
1195 call_op.erase();
1196 }
1197
1198 LogicalResult HoistForControlFlow(
1199 Block*, ModuleOp, bool,
1200 llvm::SmallDenseMap<llvm::StringRef, PartitionedCallLiftingInfo>*);
1201
1202 // A templated routine for handling both PartitionedCallOp and
1203 // StatefulPartitionedCallOp. If the callee is already lifted, it just updates
1204 // the caller op itself; otherwise, it first recursively handles nested control
1205 // flow, then performs lifting on the callee.
1206 template <typename CallOpType>
HandlePartitionedCallOp(CallOpType call_op,FuncOp callee,ModuleOp module,bool vars_initialized,llvm::SmallDenseMap<llvm::StringRef,PartitionedCallLiftingInfo> * lifted_callees)1207 LogicalResult HandlePartitionedCallOp(
1208 CallOpType call_op, FuncOp callee, ModuleOp module, bool vars_initialized,
1209 llvm::SmallDenseMap<llvm::StringRef, PartitionedCallLiftingInfo>*
1210 lifted_callees) {
1211 auto emplace_res = lifted_callees->try_emplace(callee.getName(),
1212 PartitionedCallLiftingInfo());
1213 if (emplace_res.second) {
1214 // Unseen callee. Perform resource lifting on it.
1215 if (failed(HoistForControlFlow(&callee.front(), module, vars_initialized,
1216 lifted_callees)))
1217 return failure();
1218
1219 if (failed(HandlePartitionedCallOpCallee(
1220 callee, &emplace_res.first->getSecond()))) {
1221 return failure();
1222 }
1223 }
1224 UpdatePartitionedCallOpWithNewCallee(call_op, emplace_res.first->getSecond());
1225 return success();
1226 }
1227
1228 // Hoists resource loads/stores from control flow ops in `block` outside the
1229 // body/cond/branch/callee functions.
HoistForControlFlow(Block * block,ModuleOp module,bool vars_initialized,llvm::SmallDenseMap<llvm::StringRef,PartitionedCallLiftingInfo> * lifted_partitioned_call_callees)1230 LogicalResult HoistForControlFlow(
1231 Block* block, ModuleOp module, bool vars_initialized,
1232 llvm::SmallDenseMap<llvm::StringRef, PartitionedCallLiftingInfo>*
1233 lifted_partitioned_call_callees) {
1234 if (vars_initialized) SetAllVarIsInitializedToTrue(block);
1235
1236 for (Operation& op : llvm::make_early_inc_range(*block)) {
1237 if (auto while_op = llvm::dyn_cast<TF::WhileOp>(&op)) {
1238 auto body = while_op.body_function();
1239 auto cond = while_op.cond_function();
1240 // Recursively handle the nested control flow.
1241 (void)HoistForControlFlow(&body.front(), module, vars_initialized,
1242 lifted_partitioned_call_callees);
1243 (void)HoistForControlFlow(&cond.front(), module, vars_initialized,
1244 lifted_partitioned_call_callees);
1245 if (failed(HandleWhileLoop(while_op, body, cond))) return failure();
1246 } else if (auto if_op = llvm::dyn_cast<TF::IfOp>(&op)) {
1247 auto then_branch = if_op.then_function();
1248 auto else_branch = if_op.else_function();
1249 // Recursively handle the nested control flow.
1250 (void)HoistForControlFlow(&then_branch.front(), module, vars_initialized,
1251 lifted_partitioned_call_callees);
1252 (void)HoistForControlFlow(&else_branch.front(), module, vars_initialized,
1253 lifted_partitioned_call_callees);
1254 if (failed(HandleCaseOrIfOp(if_op, {then_branch, else_branch})))
1255 return failure();
1256 } else if (auto case_op = llvm::dyn_cast<TF::CaseOp>(&op)) {
1257 SmallVector<FuncOp, 4> branch_functions;
1258 case_op.get_branch_functions(branch_functions);
1259 for (FuncOp func : branch_functions) {
1260 // Recursively handle the nested control flow.
1261 (void)HoistForControlFlow(&func.front(), module, vars_initialized,
1262 lifted_partitioned_call_callees);
1263 }
1264 if (failed(HandleCaseOrIfOp(case_op, branch_functions))) return failure();
1265 } else if (auto call_op = llvm::dyn_cast<TF::PartitionedCallOp>(&op)) {
1266 auto callee = call_op.func();
1267 if (!callee) {
1268 return call_op.emitOpError(
1269 "resource lifting does not support call with nested references.");
1270 }
1271 if (failed(HandlePartitionedCallOp(call_op, callee, module,
1272 vars_initialized,
1273 lifted_partitioned_call_callees))) {
1274 // Nested control flow handling is done in HandlePartitionedCallOp().
1275 return failure();
1276 }
1277 } else if (auto call_op =
1278 llvm::dyn_cast<TF::StatefulPartitionedCallOp>(&op)) {
1279 if (failed(HandlePartitionedCallOp(call_op, call_op.func(), module,
1280 vars_initialized,
1281 lifted_partitioned_call_callees))) {
1282 return failure();
1283 }
1284 } else if (isa<TF::IfRegionOp, TF::CaseRegionOp, TF::WhileRegionOp>(op)) {
1285 for (Region& region : op.getRegions())
1286 (void)HoistForControlFlow(®ion.front(), module, vars_initialized,
1287 lifted_partitioned_call_callees);
1288 LogicalResult result = RegionResourceHoister::ReplaceOpWithNewOp(&op);
1289 if (failed(result)) return failure();
1290 }
1291 }
1292
1293 // After we have hoisted operations in the block, we may have added new read
1294 // and writes of resources to this block. Clean them up by doing store-load
1295 // forwarding.
1296 ForwardStoreToLoad(block);
1297 return success();
1298 }
1299
1300 // Lifts resource operation from tf_device.cluster ops nested in `op` outside.
1301 // Returns failure if there are remaining resource-type values that can not be
1302 // lifted.
runOnOperation()1303 void ResourceOpLiftingPass::runOnOperation() {
1304 llvm::SmallDenseMap<llvm::StringRef, PartitionedCallLiftingInfo>
1305 lifted_partitioned_call_callees;
1306 ModuleOp module = getOperation();
1307
1308 if (failed(TF::CleanupAndCanonicalizeForResourceOpLifting(module)))
1309 return signalPassFailure();
1310
1311 auto walk_result = module.walk([&](FuncOp func_op) {
1312 return func_op.walk([&](tf_device::ClusterOp cluster) {
1313 LogicalResult result = HoistForControlFlow(
1314 &cluster.GetBody(), module, /*vars_initialized=*/true,
1315 &lifted_partitioned_call_callees);
1316 if (failed(result)) return WalkResult::interrupt();
1317 result = RegionResourceHoister::ReplaceOpWithNewOp(cluster);
1318 if (failed(result)) return WalkResult::interrupt();
1319 return WalkResult::advance();
1320 });
1321 });
1322
1323 if (walk_result.wasInterrupted()) return signalPassFailure();
1324 }
1325
1326 struct ResourceOpLiftingForMainFunctionPass
1327 : public PassWrapper<ResourceOpLiftingForMainFunctionPass,
1328 OperationPass<ModuleOp>> {
1329 void runOnOperation() override;
1330 };
1331
runOnOperation()1332 void ResourceOpLiftingForMainFunctionPass::runOnOperation() {
1333 ModuleOp module = getOperation();
1334 FuncOp main_func = module.lookupSymbol<FuncOp>("main");
1335 if (!main_func) {
1336 return;
1337 }
1338
1339 if (failed(TF::ResourceLiftingForFunctionalControlFlow(main_func))) {
1340 return signalPassFailure();
1341 }
1342 }
1343
1344 static PassRegistration<ResourceOpLiftingForMainFunctionPass>
1345 lift_main_func_pass(
1346 "tf-resource-op-lifting-for-main-function",
1347 "Lifting resource operations out of control flow statements for the "
1348 "main function");
1349
1350 static PassRegistration<ResourceOpLiftingPass> pass(
1351 "tf-resource-op-lifting",
1352 "Lifting resource operations out of device computation");
1353
1354 } // namespace
1355
1356 namespace TFDevice {
CreateResourceOpLiftingPass()1357 std::unique_ptr<OperationPass<ModuleOp>> CreateResourceOpLiftingPass() {
1358 return std::make_unique<ResourceOpLiftingPass>();
1359 }
1360 } // namespace TFDevice
1361
1362 namespace TF {
ResourceLiftingForFunctionalControlFlow(FuncOp function)1363 LogicalResult ResourceLiftingForFunctionalControlFlow(FuncOp function) {
1364 // This routine should only be called when control flow operations are still
1365 // represented with TF IfOp and WhileOp operations. In this case, there should
1366 // be only one basic blocks in the MLIR representation.
1367 if (!llvm::hasSingleElement(function)) {
1368 return function.emitError()
1369 << "expect the function to have 1 block while it has "
1370 << function.getBlocks().size();
1371 }
1372
1373 if (failed(TF::CleanupAndCanonicalizeForResourceOpLifting(function)))
1374 return failure();
1375
1376 llvm::SmallDenseMap<llvm::StringRef, PartitionedCallLiftingInfo>
1377 lifted_partitioned_call_callees;
1378 if (failed(HoistForControlFlow(
1379 &function.front(), cast<ModuleOp>(function->getParentOp()),
1380 /*vars_initialized=*/false, &lifted_partitioned_call_callees)))
1381 return failure();
1382
1383 // Clean up and canonicalize to remove dead local variables as some local
1384 // variables might be dead after hoisting resource loads/stores from control
1385 // flow ops.
1386 return TF::CleanupAndCanonicalizeForResourceOpLifting(function);
1387 }
1388 } // namespace TF
1389
1390 } // namespace mlir
1391