/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This pass lifts resource variable operations outside of device computation.

#include <cstddef>
#include <cstdint>

#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Block.h"  // from @llvm-project
#include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Diagnostics.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/Region.h"  // from @llvm-project
#include "mlir/IR/SymbolTable.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/IR/Verifier.h"  // from @llvm-project
#include "mlir/IR/Visitors.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting_cleanup.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"

namespace mlir {

namespace {

// This pass lifts resource variable operations outside of device computation.
// This is useful because many accelerator devices cannot interact with
// resource variables directly.
//
// Here is a simple example in TensorFlow where a device doubles the value of a
// TensorFlow resource variable and returns the new value:
//
// %resource_handle = "tf.VarHandleOp"()
// %1 = "tf_device.cluster"() ( {
//   %init_value = "tf.ReadVariableOp"(%resource_handle)
//   "tf.AssignAddVariableOp"(%resource_handle, %init_value)
//   %new_value = "tf.ReadVariableOp"(%resource_handle)
//   tf_device.return %new_value
// })
//
// After this pass, the computation would become:
//
// %resource_handle = "tf.VarHandleOp"()
// %init_value = "tf.ReadVariableOp"(%resource_handle)
// %1:2 = "tf_device.cluster"() ( {
//   %new_value = "tf.AddV2"(%init_value, %init_value)
//   tf_device.return %new_value, %new_value
// })
// "tf.AssignVariableOp"(%resource_handle, %1#1)
//
// You can see that there are a few main changes applied:
// 1) All the resource variable reads and writes are now outside of the
//    tf_device.cluster op.
// 2) Instead of taking resource handles as input, this device computation now
//    takes snapshotted values of those resources.
// 3) Some resource load operations are eliminated with store-load forwarding.
// 4) Updated values to resources are appended to `tf_device.return` and used by
//    external resource store operations so that resources are still updated
//    after the computation.
//
// If the cluster body contains functional control flow, the pass first lifts
// the loads/stores in the body/cond/branch functions to the cluster body, then
// performs the above lifting. E.g.,
//
// func @cluster_with_loop() -> () {
//   %0 = "tf.VarHandleOp"() ...
//   "tf_device.cluster"() ( {
//      %1 = "tf.While"(%0) {body = @while_body, cond = @while_cond}
//      tf_device.return
//   })
//   return
// }
// func @while_body(%arg0: tensor<*x!tf.resource<tensor<f32>>>) {
//   %constant = "tf.Const"() ...
//   "tf.AssignVariableOp"(%arg0, %constant)
//   return %arg0
// }
// func @while_cond(%arg0: tensor<*x!tf.resource<tensor<f32>>>) {
//   %read = "tf.ReadVariableOp"(%arg0)
//   return %read
// }
//
// will be transformed to:
//
// func @cluster_with_loop() {
//   %0 = "tf.VarHandleOp"() ...
//   %1 = "tf.ReadVariableOp"(%0)
//   %2 = "tf_device.cluster"() ( {
//     %3 = "tf.While"(%1) {body = @while_body, cond = @while_cond}
//     tf_device.return %3 : tensor<f32>
//   }) : () -> tensor<f32>
//   "tf.AssignVariableOp"(%0, %2)
//   return
// }
// func @while_body(%arg0: tensor<f32>) {
//   %0 = "tf.Const"() ...
//   return %0 : tensor<f32>
// }
// func @while_cond(%arg0: tensor<f32>) {
//   return %arg0
// }
//
struct ResourceOpLiftingPass
    : public PassWrapper<ResourceOpLiftingPass, OperationPass<ModuleOp>> {
  void runOnOperation() override;
};

bool IsResource(Value value) {
  return getElementTypeOrSelf(value.getType()).isa<TF::ResourceType>();
}

// Get the type of the data contained in a resource. Returns null if there is
// no single type in the resource.
Type GetResourceSubtype(Value value) {
  auto resource_type =
      getElementTypeOrSelf(value.getType()).dyn_cast<TF::ResourceType>();
  auto subtypes = resource_type.getSubtypes();
  if (subtypes.size() == 1) return subtypes[0];
  return nullptr;
}

// Replaces all `tf.VarIsInitializedOp` in a block with a constant true.
// TODO(b/171039585): Replace this with proper analysis of
// `tf.VarIsInitializedOp` in regards to resource writes and control flow.
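// For example (an illustrative sketch, not IR emitted verbatim by this pass):
//
//   %inited = "tf.VarIsInitializedOp"(%resource_handle)
//
// is rewritten so that a single constant true is created at the top of the
// block,
//
//   %true = "tf.Const"() {value = dense<true> : tensor<i1>}
//
// all uses of %inited are replaced with %true, and the op is erased.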
void SetAllVarIsInitializedToTrue(Block* block) {
  auto builder = OpBuilder::atBlockBegin(block);
  TF::ConstOp const_true = nullptr;
  for (auto op :
       llvm::make_early_inc_range(block->getOps<TF::VarIsInitializedOp>())) {
    builder.setInsertionPoint(op);
    if (!const_true)
      const_true = builder.create<TF::ConstOp>(
          op.getLoc(),
          DenseIntElementsAttr::get(
              RankedTensorType::get(/*shape=*/{}, builder.getI1Type()), true));

    op.is_initialized().replaceAllUsesWith(const_true);
    op.erase();
  }
}

// Performs store-load forwarding. This effectively removes
// 1) Any resource load that follows a store to that same resource.
// 2) Any resource store except the last one.
// TODO(ycao): Store-load forwarding implemented here is only correct when
// computation is purely sequential (no concurrency). Need to support concurrent
// computation as well.
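// For example, within one block (an illustrative sketch; %res and the stored
// values are made-up names):
//
//   "tf.AssignVariableOp"(%res, %v0)
//   %read = "tf.ReadVariableOp"(%res)   // all uses replaced by %v0 (or a Cast)
//   "tf.AssignVariableOp"(%res, %v1)    // the earlier store to %res is erased
//
// leaving only the final store to %res, with the intermediate load removed.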
void ForwardStoreToLoad(Block* block) {
  // resource_handle_to_last_store_op keeps track of the most recent (last)
  // store to each resource. A non-existent entry indicates that a resource has
  // not been stored to yet.
  llvm::SmallDenseMap<Value, TF::AssignVariableOp>
      resource_handle_to_last_store_op;

  // Only iterate through ops directly in the block as we can't handle ops
  // nested deeper in regions.
  for (Operation& op : llvm::make_early_inc_range(*block)) {
    if (auto read_variable_op = dyn_cast<TF::ReadVariableOp>(&op)) {
      Value resource = read_variable_op.resource();
      auto last_store = resource_handle_to_last_store_op[resource];
      if (!last_store) continue;

      // Use the value stored by last_store to replace all uses of the current
      // resource load's result, then erase this resource load. Add an
      // intermediate CastOp if the shapes of the types don't exactly match.
      Type read_type = read_variable_op.value().getType();
      if (read_type != last_store.value().getType()) {
        OpBuilder builder(last_store);
        builder.setInsertionPointAfter(last_store);
        auto cast = builder.create<TF::CastOp>(
            last_store.getLoc(), read_type, last_store.value(),
            /*Truncate=*/builder.getBoolAttr(false));
        read_variable_op.value().replaceAllUsesWith(cast);
      } else {
        read_variable_op.value().replaceAllUsesWith(last_store.value());
      }

      read_variable_op.erase();
      continue;
    }

    if (auto assign_variable_op = dyn_cast<TF::AssignVariableOp>(&op)) {
      Value resource = assign_variable_op.resource();
      auto last_store = resource_handle_to_last_store_op[resource];
      // Previous store ops to the same resource can be erased.
      if (last_store) last_store.erase();

      resource_handle_to_last_store_op[resource] = assign_variable_op;
    }
  }
}

//===----------------------------------------------------------------------===//
// RegionResourceHoister
//===----------------------------------------------------------------------===//

// Helper class to hoist resource ops out of regions attached to an op.
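// The typical flow, mirroring HoistResourcesOutOfIfCaseCluster below, is
// roughly:
//
//   RegionResourceHoister hoister(op);
//   if (failed(hoister.Analyze())) return failure();
//   if (!hoister.NeedsLifting()) return success();
//   hoister.GenerateHoistedReads();
//   hoister.ReplaceResourceLoads(/*read_only=*/false);
//   hoister.AppendResourceStoreValueToReturn(op->getRegions());
//   hoister.ReplaceOpWithNewOp();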
class RegionResourceHoister {
 public:
  explicit RegionResourceHoister(Operation* op) : op_(op) {}

  // Analyzes attached regions to record resources read and written.
  LogicalResult Analyze();

  // Returns all resources accessed by the regions attached to the op.
  auto& GetResources() { return resources_; }

  // Returns whether the given value is a resource that needs lifting.
  bool Contains(Value resource) const {
    return resources_.find(resource) != resources_.end();
  }

  // Drops the given resource from lifting.
  void DropResource(Value resource) {
    resources_.erase(resource);
    written_resources_.remove(resource);
  }

  // Replaces all resource loads in all regions attached to the op.
  void ReplaceResourceLoads(bool read_only) {
    llvm::for_each(op_->getRegions(), [&](Region& region) {
      ReplaceResourceLoads(region, read_only);
    });
  }

  static LogicalResult ReplaceOpWithNewOp(Operation* op);

 private:
  // Returns whether any resources need lifting.
  bool NeedsLifting() const { return !resources_.empty(); }

  // Returns the number of results generated by the lifted op.
  int GetLiftedNumResults() const { return num_new_results_; }

  // Generates hoisted reads for resources that need them before the op.
  void GenerateHoistedReads();

  // Replaces all resource loads in the given region with hoisted loads. If
  // `read_only` is true, limit this replacement to read only resources.
  void ReplaceResourceLoads(Region& region, bool read_only);

  // Appends final values written to resources to the region returns for the
  // given set of regions.
  void AppendResourceStoreValueToReturn(RegionRange regions);

  // Performs the final replacement of the op.
  void ReplaceOpWithNewOp();

  // Returns whether this resource was written to in any of the regions.
  bool IsWritten(Value resource) const {
    return written_resources_.contains(resource);
  }

  static LogicalResult HoistResourcesOutOfIfCaseCluster(Operation* op);
  static LogicalResult HoistResourcesOutOfWhileRegion(TF::WhileRegionOp op);

  Operation* op_;

  // Per-resource information about accesses to that resource.
  struct ResourceInfo {
    // Is this resource read in any of the regions?
    bool is_read;
    // Is this resource written in any of the regions?
    bool is_written;
    // Is this resource written in all of the regions?
    bool is_written_all;
    // The hoisted read used to replace region reads.
    Value hoisted_read;
    // The type of the data held by the resource.
    Type data_type;
    // For written resources, the result # of the lifted op which will hold the
    // value of the resource. This result will be used to generate writes to
    // the resource after the lifted op.
    int result_index;
    // Attributes on the read operation.
    DictionaryAttr read_attrs;
    // Attributes on the write operation.
    DictionaryAttr write_attrs;

    ResourceInfo()
        : is_read(false),
          is_written(false),
          is_written_all(false),
          hoisted_read(nullptr),
          data_type(nullptr),
          result_index(-1) {}

    bool IsResultIndexAssigned() { return result_index != -1; }

    // Refine the resource type using the given type `type`.
    void RefineType(Type type) {
      if (!data_type) {
        data_type = type;
      } else {
        data_type = TF::GetCastCompatibleType(data_type, type,
                                              /*may_ignore_ref_type_a=*/false);
        assert(data_type != nullptr && "Resource used with incompatible types");
      }
    }
  };
  llvm::MapVector<Value, ResourceInfo> resources_;
  llvm::SetVector<Value> written_resources_;
  // Number of new results after lifting.
  int num_new_results_;
};

// Analyzes resources that are read or written within attached regions.
LogicalResult RegionResourceHoister::Analyze() {
  // Hoisting of child regions might have created opportunities for store-load
  // forwarding.
  for (Region& region : op_->getRegions()) {
    ForwardStoreToLoad(&region.front());
  }

  llvm::SetVector<Value> all_resources;
  bool is_func = false;
  // For functions, the resources to analyze are the function arguments.
  // Otherwise, they are the region captures.
  if (FuncOp func = dyn_cast<FuncOp>(op_)) {
    is_func = true;
    Region& body = func.getBody();
    for (BlockArgument arg : body.getArguments()) {
      if (IsResource(arg)) all_resources.insert(arg);
    }
  } else {
    getUsedValuesDefinedAbove(op_->getRegions(), all_resources);
    all_resources.remove_if([](Value value) { return !IsResource(value); });
  }

  num_new_results_ = op_->getNumResults();

  for (auto resource : all_resources) {
    ResourceInfo info;
    info.data_type = GetResourceSubtype(resource);
    llvm::BitVector written_regions(op_->getNumRegions());
    bool unsupported_use = false;
    for (OpOperand& use : resource.getUses()) {
      Operation* user = use.getOwner();
      // If the user is not in one of the regions, we are not interested in it.
      // Since all the sub-regions within this region (i.e., regions attached to
      // ops in this region) have themselves gone through lifting, all resource
      // users are expected to be operations in this region and not embedded
      // within other sub-regions attached to ops in this region. So the check
      // for whether a user is in one of the regions attached to this op is
      // straightforward.
      if (user->getParentRegion()->getParentOp() != op_) continue;

      // For functions, if the resource is used as a return operand, use that
      // as its result index.
      if (is_func && isa<ReturnOp>(user)) {
        assert(!info.IsResultIndexAssigned() &&
               "Expect resource argument to be returned no more than once");
        info.result_index = use.getOperandNumber();
        continue;
      }

      auto read = dyn_cast<TF::ReadVariableOp>(user);
      auto write = dyn_cast<TF::AssignVariableOp>(user);
      if (!read && !write) {
        unsupported_use = true;
        break;
      }

      if (read && !info.is_read) {
        info.is_read = true;
        info.RefineType(read.value().getType());
        info.read_attrs = user->getAttrDictionary();
      }

      if (write) {
        info.is_written = true;
        info.RefineType(write.value().getType());
        info.write_attrs = user->getAttrDictionary();
        written_regions.set(user->getParentRegion()->getRegionNumber());
      }
    }

    // If the resource is used in an op that we do not understand, skip
    // lifting for that resource.
    if (unsupported_use) continue;

    info.is_written_all = written_regions.count() == op_->getNumRegions();

    // If the resource is written in some but not all regions, we would need
    // a read for the value before these regions. Note that this is applicable
    // only to multi-region ops:
    // If/Case: If not all regions write to the resource, post hoisting the read
    //   value needs to be routed through all paths that don't write.
    // While: since the while condition cannot write, any resource written in
    //   the while body will need to be read as well in case the while body is
    //   never executed.
    // Both cases are handled by the condition below.
    if (info.is_written && !info.is_written_all) info.is_read = true;

    // Allocate a result index for written resources that don't have one.
    if (info.is_written) {
      written_resources_.insert(resource);
      if (!info.IsResultIndexAssigned()) info.result_index = num_new_results_++;
    }

    resources_.insert({resource, info});
  }
  return success();
}

// Generates hoisted reads for all resources that need them just before the op.
void RegionResourceHoister::GenerateHoistedReads() {
  OpBuilder builder(op_);
  DictionaryAttr empty_attrs = builder.getDictionaryAttr({});
  for (auto& resource_it : GetResources()) {
    Value resource = resource_it.first;
    auto& info = resource_it.second;

    if (info.is_read) {
      Operation* read = builder.create<TF::ReadVariableOp>(
          op_->getLoc(), info.data_type, resource);
      read->setAttrs(info.read_attrs ? info.read_attrs : empty_attrs);
      info.hoisted_read = read->getResult(0);
    }
  }
}

// Replaces all resource reads with the hoisted read.
void RegionResourceHoister::ReplaceResourceLoads(Region& region,
                                                 bool read_only) {
  assert(llvm::hasSingleElement(region) && "Expected single block region");
  // Only iterate through ops directly in the body as we can't handle
  // ops nested deeper in regions.
  auto all_reads = region.front().getOps<TF::ReadVariableOp>();
  for (auto read_op : llvm::make_early_inc_range(all_reads)) {
    Value resource = read_op.resource();
    if (!Contains(resource)) continue;

    ResourceInfo& info = resources_[resource];
    // If replacing loads for read only resources, skip if the resource
    // was written to.
    if (read_only && info.is_written) continue;

    read_op.replaceAllUsesWith(info.hoisted_read);
    read_op.erase();
  }
}

// For each written resource, adds its value at the end of each region to that
// region's return value. For a region, the value at the end may be a value
// written to that resource in that region, or its hoisted read value if the
// resource is not written in that region. The return value can be vended out
// either as an existing return value, or a newly allocated return value.
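// For example, if resource %res is written in a region and was assigned
// result_index 1, then (an illustrative sketch using a cluster region)
//
//   "tf.AssignVariableOp"(%res, %new_value)
//   tf_device.return %x
//
// becomes
//
//   tf_device.return %x, %new_value
//
// while a region that never writes %res returns the hoisted read in that slot.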
void RegionResourceHoister::AppendResourceStoreValueToReturn(
    RegionRange regions) {
  for (Region* region : regions) {
    assert(llvm::hasSingleElement(*region) && "Expected single block region");
    Block& front = region->front();
    auto old_return = front.getTerminator();
    assert(old_return->getNumOperands() == op_->getNumResults());
    auto new_return_operands = llvm::to_vector<4>(old_return->getOperands());
    new_return_operands.resize(num_new_results_);

    // Initialize return values for written resources to be the hoisted reads.
    for (Value resource : written_resources_) {
      const ResourceInfo& info = resources_[resource];
      new_return_operands[info.result_index] = info.hoisted_read;
    }

    // Only iterate through ops directly in the body as ops embedded in child
    // regions should have been lifted out.
    auto assign_ops = front.getOps<TF::AssignVariableOp>();
    for (auto assign_variable_op : llvm::make_early_inc_range(assign_ops)) {
      Value resource = assign_variable_op.resource();
      if (!IsWritten(resource)) continue;

      // TODO(ycao): Prevent same value from being returned multiple times.
      // TODO(ycao): Do not return resource store value if it is defined outside
      // of cluster. Both of these can be post-resource-op-lifting cleanup
      // passes.
      int result_index = resources_[resource].result_index;
      new_return_operands[result_index] = assign_variable_op.value();
      assign_variable_op.erase();
    }
    old_return->setOperands(new_return_operands);
  }
}

// Replace the old op with a new op (with potentially additional results), and
// add stores to written resources after the new op.
void RegionResourceHoister::ReplaceOpWithNewOp() {
  auto new_result_types = llvm::to_vector<4>(op_->getResultTypes());
  int result_region = isa<TF::WhileRegionOp>(op_) ? 1 : 0;
  Operation* terminator = op_->getRegion(result_region).front().getTerminator();
  auto extra_result_types =
      terminator->getOperands().drop_front(op_->getNumResults()).getTypes();
  new_result_types.insert(new_result_types.end(), extra_result_types.begin(),
                          extra_result_types.end());
  OpBuilder builder(op_);
  // Clone this old operation but with new result types.
  Operation* new_op = Operation::create(
      op_->getLoc(), op_->getName(), new_result_types, op_->getOperands(),
      op_->getAttrs(), op_->getSuccessors(), op_->getNumRegions());
  builder.insert(new_op);

  // Move regions to the new op.
  for (auto it : llvm::zip(op_->getRegions(), new_op->getRegions())) {
    Region& old_region = std::get<0>(it);
    Region& new_region = std::get<1>(it);
    new_region.takeBody(old_region);
  }

  // Insert stores to all written resources.
  for (Value resource : written_resources_) {
    ResourceInfo& info = resources_[resource];
    Value value_to_write = new_op->getResult(info.result_index);
    Operation* write = builder.create<TF::AssignVariableOp>(
        op_->getLoc(), resource, value_to_write);
    write->setAttrs(info.write_attrs);
  }

  // As a part of lifting, we either reuse an existing slot for resource type
  // results or add a new slot. Resource type results should not have any uses
  // to begin with. So we can safely replace each old op result with the
  // corresponding new op result.
  int old_num_results = op_->getNumResults();
  op_->replaceAllUsesWith(new_op->getResults().take_front(old_num_results));
  op_->erase();
  op_ = nullptr;
}

// Lifts resource loads and stores out of regions attached to `op`, where `op`
// is an If/Case/cluster op.
LogicalResult RegionResourceHoister::HoistResourcesOutOfIfCaseCluster(
    Operation* op) {
  RegionResourceHoister hoister(op);
  if (failed(hoister.Analyze())) return failure();

  // If there are no resource region captures, then nothing to do.
  if (!hoister.NeedsLifting()) return success();

  // Start the transformation. For each region, replace the resource read with
  // the value read before the op.
  hoister.GenerateHoistedReads();
  hoister.ReplaceResourceLoads(/*read_only=*/false);
  hoister.AppendResourceStoreValueToReturn(op->getRegions());
  hoister.ReplaceOpWithNewOp();
  return success();
}

// Lifts resource loads and stores out of a WhileRegion op.
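// For example (an illustrative sketch for a single written resource %res):
//
//   %read = "tf.ReadVariableOp"(%res)
//   %result = "tf.WhileRegion"(%read) ({...cond...}, {...body...})
//   "tf.AssignVariableOp"(%res, %result)
//
// The written resource becomes an extra loop variable: its initial value is
// the hoisted read, reads inside the regions are replaced with the new block
// argument, and the final value is stored back to the resource after the loop.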
LogicalResult RegionResourceHoister::HoistResourcesOutOfWhileRegion(
    TF::WhileRegionOp op) {
  // For WhileRegion, post canonicalization all resources used within the
  // body and condition regions are replaced with captured values, so we do not
  // need to take into account the body and condition region arguments.
  RegionResourceHoister hoister(op);

  if (failed(hoister.Analyze())) return failure();

  // If there are no resource region captures, then nothing to do.
  if (!hoister.NeedsLifting()) return success();

  // The resources captured for a While loop fall into two categories:
  // (a) read-only: these reads can be replaced by a hoisted read created
  //        before the WhileOp (similar to if and case).
  // (b) written: since the value is written in the loop (which can only happen
  //        in the loop body), all these will become loop variables. Since all
  //        resource variables are removed from the loop variables during
  //        canonicalization, we need to create new operand/result slots. The
  //        input operands for these slots are the read values
  //        prior to the op, and all references to these are replaced by the
  //        corresponding slot argument. We need to generate writes following
  //        the while for these resources.
  //
  // Note that for WhileRegion ops, if a resource is written, it will be written
  // only in the body and not the condition, so the hoister analysis will infer
  // it as needing a read as well.

  // Generate hoisted reads before the while.
  hoister.GenerateHoistedReads();

  // Replace just the read-only resources with the hoisted reads.
  hoister.ReplaceResourceLoads(/*read_only=*/true);

  // For written resources, add additional operands to the while op.
  int num_old_results = op.getNumResults();
  int num_new_results = hoister.GetLiftedNumResults();
  int num_extra_results = num_new_results - num_old_results;

  SmallVector<Type, 4> new_result_types;
  SmallVector<Value, 4> new_while_operands;
  new_result_types.resize(num_extra_results);
  new_while_operands.resize(num_extra_results);

  for (auto& it : hoister.GetResources()) {
    if (!it.second.is_written) continue;
    int index = it.second.result_index - num_old_results;
    new_result_types[index] = it.second.data_type;
    new_while_operands[index] = it.second.hoisted_read;
  }
  op.getOperation()->insertOperands(op.getNumOperands(), new_while_operands);

  // Patch the cond and body regions to have additional arguments, and replace
  // the remaining resource reads (which will be resource reads for written
  // resources) with these arguments.
  for (Region* region : op.getRegions()) {
    region->addArguments(new_result_types);
    // Point hoisted read for written resources to the region's arguments.
    for (auto& it : hoister.GetResources()) {
      if (!it.second.is_written) continue;
      it.second.hoisted_read = region->getArgument(it.second.result_index);
    }
    hoister.ReplaceResourceLoads(*region, /*read_only=*/false);
  }

  // Add additional return values to body return. These correspond to values
  // written to resources in the body region.
  hoister.AppendResourceStoreValueToReturn(op.getRegions().drop_front());

  // Finally, create a new while with additional return values.
  hoister.ReplaceOpWithNewOp();
  return success();
}

// Lift resources out of the regions attached to `op`
LogicalResult RegionResourceHoister::ReplaceOpWithNewOp(Operation* op) {
  if (auto while_op = dyn_cast<TF::WhileRegionOp>(op))
    return HoistResourcesOutOfWhileRegion(while_op);
  return HoistResourcesOutOfIfCaseCluster(op);
}

// Holds information about a function's use of a resource argument.
struct ResourceArgUseInfo {
  // Data type of the data contained in the resource.
  Type data_type;
  // Is the resource argument used in an assign op?
  bool updated;
  // Is the resource argument used in a read or assign op?
  bool used;
};

// Finds the ResourceArgUseInfo for each resource argument. Forwarding to the
// output (i.e., the argument is an operand of the return op) is not considered
// as a use. This doesn't support nesting of ops, so before calling this, nested
// ops/functions need to be already resource-lifted.
LogicalResult FindResourceArgUseInfo(
    FuncOp func_op, llvm::SmallDenseMap<int64_t, ResourceArgUseInfo>* result) {
  auto return_op = func_op.front().getTerminator();
  for (auto arg : TF::filter_resources(func_op.getArguments())) {
    ResourceArgUseInfo info;
    info.used = false;
    info.updated = false;
    bool read_or_assigned = false;
    bool used_in_unsupported_op = false;
    for (auto user : arg.getUsers()) {
      if (user == return_op) continue;
      info.used = true;
      if (auto read = llvm::dyn_cast<TF::ReadVariableOp>(user)) {
        read_or_assigned = true;
        info.data_type = read.getType();
        continue;
      }

      if (auto assign = llvm::dyn_cast<TF::AssignVariableOp>(user)) {
        read_or_assigned = true;
        info.updated = true;
        info.data_type = assign.value().getType();
        continue;
      }

      used_in_unsupported_op = true;
      break;
    }

    // If the arg is used in an unsupported op, skip lifting it.
    if (used_in_unsupported_op) continue;
    (*result)[arg.getArgNumber()] = info;
  }
  return success();
}

// Merges two sets of resource arg use infos. An argument is considered used in
// the merged result as long as either set marks it as used. This is used to
// merge results from functions that have aliasing inputs, e.g., a while loop's
// body and condition. The sets of keys of the two maps must be the same.
llvm::SmallDenseMap<int64_t, ResourceArgUseInfo> MergeArgResourceUseInfo(
    const llvm::SmallDenseMap<int64_t, ResourceArgUseInfo>& infos0,
    const llvm::SmallDenseMap<int64_t, ResourceArgUseInfo>& infos1) {
  llvm::SmallDenseMap<int64_t, ResourceArgUseInfo> result;
  for (const auto& entry : infos0) {
    auto info1_it = infos1.find(entry.getFirst());
    // If the entry is missing in any input, we should not touch this entry.
    if (info1_it == infos1.end()) continue;
    auto& info = result[entry.getFirst()];
    info = entry.getSecond();
    if (info.updated) continue;
    if (info1_it->getSecond().used) {
      info.used = true;
      info.updated = info1_it->getSecond().updated;
      info.data_type = info1_it->getSecond().data_type;
    }
  }
  return result;
}

// Removes the unused resource arguments, and the return values that forward the
// removed arguments. If old_to_new_arg_indices is provided, it will store the
// new argument index that corresponds to each original index (-1 means it is
// removed). If remaining_resource_data_types is provided, it will store the
// data types of the remaining resource arguments, where the indices are after
// removing unused ones.
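// For example (an illustrative sketch), for a function with arguments
// (%arg0, %arg1, %arg2) where only %arg1 is marked unused in `infos`,
// old_to_new_arg_indices would be {0, -1, 1}.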
void RemoveUnusedResourceArgumentsAndForwardedRetvals(
    const llvm::SmallDenseMap<int64_t, ResourceArgUseInfo>& infos,
    FuncOp func_op,
    llvm::SmallVector<int64_t, 4>* old_to_new_arg_indices = nullptr,
    llvm::SmallDenseMap<int64_t, Type>* remaining_resource_data_types =
        nullptr) {
  // Remove return values forwarded from unused arguments.
  auto return_op = func_op.front().getTerminator();
  auto old_return_vals = llvm::to_vector<8>(return_op->getOperands());
  int64_t skipped_retvals = 0;
  for (auto entry : llvm::enumerate(old_return_vals)) {
    auto return_val = entry.value();
    if (auto arg = return_val.dyn_cast<BlockArgument>()) {
      auto it = infos.find(arg.getArgNumber());
      if (it != infos.end() && !it->getSecond().used) {
        return_op->eraseOperand(entry.index() - skipped_retvals++);
      }
    }
  }
  llvm::SmallVector<unsigned int, 4> indices_to_erase;
  llvm::SmallVector<Type, 4> new_types;
  int64_t skipped_args = 0;
  for (auto arg : func_op.getArguments()) {
    auto it = infos.find(arg.getArgNumber());
    if (it != infos.end() && !it->getSecond().used) {
      indices_to_erase.push_back(arg.getArgNumber());
      skipped_args++;
      if (old_to_new_arg_indices != nullptr) {
        old_to_new_arg_indices->push_back(-1);
      }
    } else {
      new_types.push_back(arg.getType());
      if (old_to_new_arg_indices != nullptr) {
        old_to_new_arg_indices->push_back(arg.getArgNumber() - skipped_args);
      }
      if (it != infos.end() && remaining_resource_data_types != nullptr) {
        (*remaining_resource_data_types)[arg.getArgNumber() - skipped_args] =
            it->second.data_type;
      }
    }
  }
  func_op.eraseArguments(indices_to_erase);
  func_op.setType(
      FunctionType::get(func_op.getContext(), new_types,
                        llvm::to_vector<4>(return_op->getOperandTypes())));
}

// Lifts reads/writes of resource arguments from func_op and changes its
// signature. resource_data_types is the (index, data type) pair for each
// resource argument. handle_updated_arg_value is a caller-provided function
// that handles the updated value for a resource argument.
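// For example (an illustrative sketch), with resource_data_types =
// {0: tensor<f32>}, a function
//
//   func @f(%arg0: tensor<*x!tf.resource<tensor<f32>>>) {
//     %0 = "tf.ReadVariableOp"(%arg0)
//     "tf.AssignVariableOp"(%arg0, %1)
//     ...
//   }
//
// has %arg0 retyped to tensor<f32>, reads of the resource replaced with %arg0
// itself, and the write erased after handle_updated_arg_value(0, %1) is
// invoked.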
LogicalResult LiftArgRetResourcesForFunction(
    FuncOp func_op,
    const llvm::SmallDenseMap<int64_t, Type>& resource_data_types,
    llvm::function_ref<void(int64_t, Value)> handle_updated_arg_value) {
  RegionResourceHoister hoister(func_op);
  if (failed(hoister.Analyze())) return failure();

  // Each of these resources could be read or written in the function. If it is
  // read, we need to replace the resource arg with a value arg to get the
  // read value. If it is written, we need to replace the write with an
  // additional value to be written.

  // Now create read values that will be used to replace each resource that
  // is read in the function body. These read values are just the same argument
  // with type replaced.
  llvm::SmallVector<Value, 4> skipped_args;
  for (auto& it : hoister.GetResources()) {
    BlockArgument arg = it.first.dyn_cast<BlockArgument>();
    assert(arg && "Expect resources for FuncOp to be its arguments");
    auto type_iter = resource_data_types.find(arg.getArgNumber());
    if (type_iter == resource_data_types.end()) {
      // Skip lifting the resource if it's not present in the data type map.
      // This indicates that the resource is not to be lifted because it is used
      // in an unsupported op in some other function.
      skipped_args.push_back(arg);
    } else {
      arg.setType(type_iter->second);
      it.second.hoisted_read = arg;
    }
  }

  // Drop all the args that have to be skipped.
  for (Value arg : skipped_args) hoister.DropResource(arg);

  hoister.ReplaceResourceLoads(/*read_only=*/false);

  // For writes, invoke the callback and then erase the write.
  auto assign_ops = func_op.front().getOps<TF::AssignVariableOp>();
  for (auto assign_variable_op : llvm::make_early_inc_range(assign_ops)) {
    Value resource = assign_variable_op.resource();
    if (!hoister.Contains(resource)) continue;

    auto arg = resource.dyn_cast<BlockArgument>();
    handle_updated_arg_value(arg.getArgNumber(), assign_variable_op.value());
    assign_variable_op.erase();
  }

  func_op.setType(FunctionType::get(
      func_op.getContext(), func_op.front().getArgumentTypes(),
      func_op.front().getTerminator()->getOperandTypes()));

  return success();
}

// Returns a vector filtered from range where the unused elements (specified by
// resource_arg_uses) are removed.
template <typename T, typename Range>
llvm::SmallVector<T, 4> FilterRange(
    Range range,
    const llvm::SmallDenseMap<int64_t, ResourceArgUseInfo>& resource_arg_uses) {
  llvm::SmallVector<T, 4> filtered;
  for (auto entry : llvm::enumerate(range)) {
    auto it = resource_arg_uses.find(entry.index());
    if (it == resource_arg_uses.end() || it->getSecond().used)
      filtered.push_back(entry.value());
  }
  return filtered;
}

// Changes the types of the control flow op (e.g., while, if) and adds loads and
// stores around it. arg_data_type_and_updated_output_index maps an operand (to
// be changed) index to its data type and the updated value index in the output
// (-1 means not updated).
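// For example, with arg_data_type_and_updated_output_index = {1: (tensor<f32>,
// 2)}, the rewrite is roughly (an illustrative sketch; "some.op" stands for
// the control flow op):
//
//   %read = "tf.ReadVariableOp"(%res)        // inserted before the op
//   %results:3 = "some.op"(%x, %read, ...)   // operand 1 now takes the read
//   "tf.AssignVariableOp"(%res, %results#2)  // inserted after the op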
void AddLoadsStoresOutsideControlFlowOp(
    Operation* caller,
    const llvm::SmallDenseMap<int64_t, std::pair<Type, int64_t>>&
        arg_data_type_and_updated_output_index) {
  OpBuilder builder(caller);
  auto new_operands = llvm::to_vector<8>(caller->getOperands());
  llvm::SmallVector<int64_t, 8> changed_indices;
  // Find the operands to change, and create the loads.
  for (auto& entry : arg_data_type_and_updated_output_index) {
    int64_t index = entry.getFirst();
    Type new_type = entry.getSecond().first;
    int64_t updated_index = entry.getSecond().second;
    auto operand = caller->getOperand(index);
    builder.setInsertionPoint(caller);
    new_operands[index] = builder.create<TF::ReadVariableOp>(
        caller->getLoc(), ArrayRef<Type>{new_type}, ArrayRef<Value>{operand});
    caller->setOperand(index, new_operands[index]);
    if (updated_index < 0) continue;
    builder.setInsertionPointAfter(caller);
    builder.create<TF::AssignVariableOp>(
        caller->getLoc(), ArrayRef<Type>{},
        ArrayRef<Value>{operand, caller->getResult(updated_index)});
  }
}

// Lifts loads/stores from a while loop's body and cond functions.
LogicalResult HandleWhileLoop(TF::WhileOp while_op, FuncOp body, FuncOp cond) {
  auto return_op = body.front().getTerminator();
  llvm::SmallDenseMap<int64_t, ResourceArgUseInfo> body_use_info;
  llvm::SmallDenseMap<int64_t, ResourceArgUseInfo> cond_use_info;
  if (failed(FindResourceArgUseInfo(body, &body_use_info)) ||
      failed(FindResourceArgUseInfo(cond, &cond_use_info))) {
    return failure();
  }
  // A resource is considered used as long as it is used in either body or cond.
  auto resource_arg_uses =
      MergeArgResourceUseInfo(body_use_info, cond_use_info);
  if (resource_arg_uses.empty()) return success();

  // Remove unused resources in functions.
  llvm::SmallVector<int64_t, 4> old_to_new_indices;
  llvm::SmallDenseMap<int64_t, Type> remaining_resource_data_types;
  RemoveUnusedResourceArgumentsAndForwardedRetvals(
      resource_arg_uses, body, &old_to_new_indices,
      &remaining_resource_data_types);
  RemoveUnusedResourceArgumentsAndForwardedRetvals(resource_arg_uses, cond);
  (void)LiftArgRetResourcesForFunction(
      body, remaining_resource_data_types,
      [&](int64_t index, Value value) { return_op->setOperand(index, value); });
  (void)LiftArgRetResourcesForFunction(cond, remaining_resource_data_types,
                                       [&](int64_t index, Value value) {
                                         // We already checked that cond should
                                         // not have variable writes.
                                         assert(false && "Should not happen");
                                       });
  // Recreate the while op.
  OpBuilder builder(while_op);
  // Now use the filtered original operands, which will be replaced by
  // AddLoadsStoresOutsideControlFlowOp().
  auto new_while = builder.create<TF::WhileOp>(
      while_op.getLoc(), body.getType().getResults(),
      FilterRange<Value, OperandRange>(while_op.getOperands(),
                                       resource_arg_uses),
      while_op.getAttrs());
  // Prepare for AddLoadsStoresOutsideControlFlowOp().
  llvm::SmallDenseMap<int64_t, std::pair<Type, int64_t>>
      arg_data_type_and_updated_output_index;
  for (const auto& entry : remaining_resource_data_types) {
    int64_t update_index = return_op->getOperand(entry.getFirst()) ==
                                   body.getArgument(entry.getFirst())
                               ? -1
                               : entry.getFirst();
    arg_data_type_and_updated_output_index[entry.getFirst()] = {
        entry.getSecond(), update_index};
  }
  AddLoadsStoresOutsideControlFlowOp(new_while,
                                     arg_data_type_and_updated_output_index);
  // Replace uses.
  for (int64_t i = 0, end = old_to_new_indices.size(); i < end; ++i) {
    if (old_to_new_indices[i] >= 0) {
      while_op.getResult(i).replaceAllUsesWith(
          new_while.getResult(old_to_new_indices[i]));
    }
  }
  while_op.erase();
  return success();
}

// Lifts loads/stores from an IfOp or CaseOp's branches.
template <class CaseOrIfOp>
LogicalResult HandleCaseOrIfOp(CaseOrIfOp op, ArrayRef<FuncOp> branches) {
  // For canonicalized If/Case, there should not be any resource outputs.
  int64_t non_resource_results = op.getNumResults();

  llvm::SmallDenseMap<int64_t, ResourceArgUseInfo> resource_arg_uses;
  if (failed(FindResourceArgUseInfo(branches.front(), &resource_arg_uses)))
    return failure();

  for (auto func : branches.drop_front()) {
    llvm::SmallDenseMap<int64_t, ResourceArgUseInfo> branch_use_info;
    if (failed(FindResourceArgUseInfo(func, &branch_use_info)))
      return failure();
    // A resource is considered used as long as it is used in any branch.
    resource_arg_uses =
        MergeArgResourceUseInfo(resource_arg_uses, branch_use_info);
  }

  if (resource_arg_uses.empty()) return success();
  // Remove unused resources in functions.
  llvm::SmallDenseMap<int64_t, Type> remaining_resource_data_types;
  RemoveUnusedResourceArgumentsAndForwardedRetvals(
      resource_arg_uses, branches.front(), /*old_to_new_arg_indices=*/nullptr,
      &remaining_resource_data_types);
  for (auto func : branches.drop_front())
    RemoveUnusedResourceArgumentsAndForwardedRetvals(resource_arg_uses, func);

  // Forward resource inputs updated in any branch to the outputs of all
  // branches. First prepare the mapping from arg to new update output.
  llvm::SmallDenseMap<int64_t, int64_t> resource_arg_to_new_output;
  {
    int64_t removed_args = 0;
    for (const auto& entry : resource_arg_uses) {
      if (!entry.getSecond().used) {
        removed_args++;
        continue;
      }
      if (!entry.getSecond().updated) continue;
      int64_t new_output_index =
          non_resource_results + resource_arg_to_new_output.size();
      resource_arg_to_new_output[entry.getFirst() - removed_args] =
          new_output_index;
    }
  }

  // Append resource updates to the return ops: now they are just forwarded
  // input resources, but will be replaced by the data value in
  // LiftArgRetResourcesForFunction().
  for (auto branch : branches) {
    auto new_retvals =
        llvm::to_vector<4>(branch.front().getTerminator()->getOperands());
    new_retvals.resize(new_retvals.size() + resource_arg_to_new_output.size());
    for (const auto& entry : resource_arg_to_new_output) {
      int64_t resource_arg_index = entry.getFirst();
      int64_t output_index = entry.getSecond();
      new_retvals[output_index] = branch.getArgument(resource_arg_index);
    }
    auto old_return = branch.front().getTerminator();
    OpBuilder builder(old_return);
    auto new_return =
        builder.create<ReturnOp>(old_return->getLoc(), new_retvals);
    old_return->erase();
    (void)LiftArgRetResourcesForFunction(
        branch, remaining_resource_data_types, [&](int64_t index, Value value) {
          new_return.setOperand(resource_arg_to_new_output[index], value);
        });
  }

  // Recreate the op without resource operands.
  OpBuilder builder(op);
  // Now use the filtered original operands, which will be replaced by
  // AddLoadsStoresOutsideControlFlowOp().
  auto new_operands =
      FilterRange<Value, OperandRange>(op.input(), resource_arg_uses);
  new_operands.insert(new_operands.begin(), op.getOperand(0));
  FuncOp first_func = branches.front();
  auto new_op =
      builder.create<CaseOrIfOp>(op.getLoc(), first_func.getType().getResults(),
                                 new_operands, op.getAttrs());
  // Prepare for AddLoadsStoresOutsideControlFlowOp()
  llvm::SmallDenseMap<int64_t, std::pair<Type, int64_t>>
      arg_data_type_and_updated_output_index;
  for (const auto& entry : remaining_resource_data_types) {
    auto new_output_it = resource_arg_to_new_output.find(entry.getFirst());
    int64_t update_index = new_output_it == resource_arg_to_new_output.end()
                               ? -1
                               : new_output_it->getSecond();
    arg_data_type_and_updated_output_index[entry.getFirst() + 1] = {
        entry.getSecond(), update_index};
  }
  AddLoadsStoresOutsideControlFlowOp(new_op,
                                     arg_data_type_and_updated_output_index);
  // Replace uses.
  op.replaceAllUsesWith(new_op.getResults().take_front(op.getNumResults()));
  op.erase();
  return success();
}

// A resource-lifted function for (potentially multiple) PartitionedCallOps and
// information about the lifting changes.
struct PartitionedCallLiftingInfo {
  // Function with resources lifted. Can be nullptr if nothing needs to change.
  FuncOp lifted_callee;
  // Mapping from old resource outputs to the old inputs that they alias.
  llvm::SmallDenseMap<int64_t, int64_t> old_outputs_aliasing_old_inputs;
  // Mapping from old to new output indices in case any output is removed.
  llvm::SmallVector<int64_t, 4> old_to_new_output_indices;
  // ResourceArgUseInfo for each old resource argument.
  llvm::SmallDenseMap<int64_t, ResourceArgUseInfo> use_info;
  // Input for AddLoadsStoresOutsideControlFlowOp(), see its comment.
  llvm::SmallDenseMap<int64_t, std::pair<Type, int64_t>>
      arg_data_type_and_updated_output_index;
};

// Lifts loads/stores from a PartitionedCallOp's callee function. If anything
// needs to be changed, the original function will be preserved, and the lifting
// happens on a clone, which will be stored in `result`.
LogicalResult HandlePartitionedCallOpCallee(
    FuncOp callee, PartitionedCallLiftingInfo* result) {
  // Sanity check: return of resources should be aliases of inputs. Such outputs
  // will be removed later.
  int64_t non_resource_results = 0;
  for (auto entry :
       llvm::enumerate(callee.front().getTerminator()->getOperands())) {
    auto retval = entry.value();
    if (!getElementTypeOrSelf(retval.getType()).isa<TF::ResourceType>()) {
      result->old_to_new_output_indices.push_back(non_resource_results++);
      continue;
    }
    auto aliasing_arg = retval.dyn_cast<BlockArgument>();
    if (!aliasing_arg) {
      return callee.emitOpError("unsupported function call: ")
             << "resource return value does not alias an input.";
    }
    result->old_outputs_aliasing_old_inputs[entry.index()] =
        aliasing_arg.getArgNumber();
    result->old_to_new_output_indices.push_back(-1);
  }

  if (failed(FindResourceArgUseInfo(callee, &result->use_info))) {
    return failure();
  }
  if (result->use_info.empty()) {
    result->lifted_callee = nullptr;
    return success();
  }

  // Clone the callee before making changes.
  SmallString<64> name_base = callee.getName();
  auto module = callee->getParentOfType<ModuleOp>();
  name_base += "_resource_lifted";
  auto name = name_base;
  callee = callee.clone();
  callee.setPrivate();
  callee.setName(name);
  SymbolTable(module).insert(callee);
  result->lifted_callee = callee;

  // Remove unused resources in functions.
  llvm::SmallDenseMap<int64_t, Type> remaining_resource_data_types;
  RemoveUnusedResourceArgumentsAndForwardedRetvals(
      result->use_info, callee, /*old_to_new_arg_indices=*/nullptr,
      &remaining_resource_data_types);
  for (const auto& entry : remaining_resource_data_types) {
    result->arg_data_type_and_updated_output_index[entry.getFirst()] = {
        entry.getSecond(), -1};
  }
  llvm::SmallVector<int64_t, 4> retval_indices_to_preserve;
  for (auto& val : callee.front().getTerminator()->getOpOperands()) {
    // Store indices of results that are not resources.
    if (!getElementTypeOrSelf(val.get().getType()).isa<TF::ResourceType>())
      retval_indices_to_preserve.push_back(val.getOperandNumber());
  }
  int64_t num_retvals = retval_indices_to_preserve.size();
  llvm::SmallVector<Value, 4> new_retvals;
  // Lift resources.
  (void)LiftArgRetResourcesForFunction(
      callee, remaining_resource_data_types, [&](int64_t index, Value value) {
        result->arg_data_type_and_updated_output_index[index].second =
            num_retvals++;
        new_retvals.push_back(value);
      });

  auto old_return = callee.front().getTerminator();
  llvm::SmallVector<Value, 4> old_and_new_retvals;
  old_and_new_retvals.reserve(retval_indices_to_preserve.size() +
                              new_retvals.size());
  for (int64_t retval_index : retval_indices_to_preserve)
    old_and_new_retvals.push_back(old_return->getOperand(retval_index));

  old_and_new_retvals.append(new_retvals.begin(), new_retvals.end());
  // Replace the old return with a new one that includes the updated values.
  OpBuilder builder(old_return);
  auto new_return =
      builder.create<ReturnOp>(old_return->getLoc(), old_and_new_retvals);
  old_return->erase();
  callee.setType(
      FunctionType::get(callee.getContext(), callee.getType().getInputs(),
                        llvm::to_vector<4>(new_return.getOperandTypes())));
  return success();
}

// Updates a PartitionedCallOp/StatefulPartitionedCallOp according to the
// resource-lifted new callee function in lifting_info.
template <typename CallOpType>
void UpdatePartitionedCallOpWithNewCallee(
    CallOpType call_op, PartitionedCallLiftingInfo& lifting_info) {
  if (!lifting_info.lifted_callee) return;
  // Replace output resource uses with the aliasing input, so that we can
  // remove this output.
  for (const auto& entry : lifting_info.old_outputs_aliasing_old_inputs) {
    call_op.getResult(entry.getFirst())
        .replaceAllUsesWith(call_op.getOperand(entry.getSecond()));
  }
  // Recreate the call op.
  OpBuilder builder(call_op);
  // Use the filtered original operands; the remaining resource operands will
  // be replaced with loaded values by AddLoadsStoresOutsideControlFlowOp().
  auto new_operands =
      FilterRange<Value, OperandRange>(call_op.args(), lifting_info.use_info);
  auto new_call = builder.create<CallOpType>(
      call_op.getLoc(), lifting_info.lifted_callee.getType().getResults(),
      new_operands, call_op.getAttrs());
  new_call->setAttr(
      "f", builder.getSymbolRefAttr(lifting_info.lifted_callee.getName()));
  AddLoadsStoresOutsideControlFlowOp(
      new_call, lifting_info.arg_data_type_and_updated_output_index);
  // Replace uses.
  for (int64_t i = 0, end = lifting_info.old_to_new_output_indices.size();
       i < end; ++i) {
    if (lifting_info.old_to_new_output_indices[i] >= 0) {
      call_op.getResult(i).replaceAllUsesWith(
          new_call.getResult(lifting_info.old_to_new_output_indices[i]));
    }
  }
  call_op.erase();
}
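
// A minimal caller-side sketch of the rewrite above (simplified and
// illustrative; `@lifted_callee` and the exact operand/result layout are
// placeholders rather than output from an actual run):
//
//   %result = "tf.StatefulPartitionedCall"(%resource_handle) {f = @callee}
//
// becomes, roughly,
//
//   %read = "tf.ReadVariableOp"(%resource_handle)
//   %new_result:2 = "tf.StatefulPartitionedCall"(%read) {f = @lifted_callee}
//   "tf.AssignVariableOp"(%resource_handle, %new_result#1)
//
// with uses of the old results remapped through old_to_new_output_indices.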

LogicalResult HoistForControlFlow(
    Block*, ModuleOp, bool,
    llvm::SmallDenseMap<llvm::StringRef, PartitionedCallLiftingInfo>*);

// A templated routine for handling both PartitionedCallOp and
// StatefulPartitionedCallOp. If the callee is already lifted, it just updates
// the caller op itself; otherwise, it first recursively handles nested control
// flow, then performs lifting on the callee.
template <typename CallOpType>
LogicalResult HandlePartitionedCallOp(
    CallOpType call_op, FuncOp callee, ModuleOp module, bool vars_initialized,
    llvm::SmallDenseMap<llvm::StringRef, PartitionedCallLiftingInfo>*
        lifted_callees) {
  auto emplace_res = lifted_callees->try_emplace(callee.getName(),
                                                 PartitionedCallLiftingInfo());
  if (emplace_res.second) {
    // Unseen callee. Perform resource lifting on it.
    if (failed(HoistForControlFlow(&callee.front(), module, vars_initialized,
                                   lifted_callees)))
      return failure();

    if (failed(HandlePartitionedCallOpCallee(
            callee, &emplace_res.first->getSecond()))) {
      return failure();
    }
  }
  UpdatePartitionedCallOpWithNewCallee(call_op, emplace_res.first->getSecond());
  return success();
}

// Hoists resource loads/stores from control flow ops in `block` outside the
// body/cond/branch/callee functions.
LogicalResult HoistForControlFlow(
    Block* block, ModuleOp module, bool vars_initialized,
    llvm::SmallDenseMap<llvm::StringRef, PartitionedCallLiftingInfo>*
        lifted_partitioned_call_callees) {
  if (vars_initialized) SetAllVarIsInitializedToTrue(block);

  for (Operation& op : llvm::make_early_inc_range(*block)) {
    if (auto while_op = llvm::dyn_cast<TF::WhileOp>(&op)) {
      auto body = while_op.body_function();
      auto cond = while_op.cond_function();
      // Recursively handle the nested control flow.
      (void)HoistForControlFlow(&body.front(), module, vars_initialized,
                                lifted_partitioned_call_callees);
      (void)HoistForControlFlow(&cond.front(), module, vars_initialized,
                                lifted_partitioned_call_callees);
      if (failed(HandleWhileLoop(while_op, body, cond))) return failure();
    } else if (auto if_op = llvm::dyn_cast<TF::IfOp>(&op)) {
      auto then_branch = if_op.then_function();
      auto else_branch = if_op.else_function();
      // Recursively handle the nested control flow.
      (void)HoistForControlFlow(&then_branch.front(), module, vars_initialized,
                                lifted_partitioned_call_callees);
      (void)HoistForControlFlow(&else_branch.front(), module, vars_initialized,
                                lifted_partitioned_call_callees);
      if (failed(HandleCaseOrIfOp(if_op, {then_branch, else_branch})))
        return failure();
    } else if (auto case_op = llvm::dyn_cast<TF::CaseOp>(&op)) {
      SmallVector<FuncOp, 4> branch_functions;
      case_op.get_branch_functions(branch_functions);
      for (FuncOp func : branch_functions) {
        // Recursively handle the nested control flow.
        (void)HoistForControlFlow(&func.front(), module, vars_initialized,
                                  lifted_partitioned_call_callees);
      }
      if (failed(HandleCaseOrIfOp(case_op, branch_functions))) return failure();
    } else if (auto call_op = llvm::dyn_cast<TF::PartitionedCallOp>(&op)) {
      auto callee = call_op.func();
      if (!callee) {
        return call_op.emitOpError(
            "resource lifting does not support calls with nested references.");
      }
      if (failed(HandlePartitionedCallOp(call_op, callee, module,
                                         vars_initialized,
                                         lifted_partitioned_call_callees))) {
        // Nested control flow handling is done in HandlePartitionedCallOp().
        return failure();
      }
    } else if (auto call_op =
                   llvm::dyn_cast<TF::StatefulPartitionedCallOp>(&op)) {
      if (failed(HandlePartitionedCallOp(call_op, call_op.func(), module,
                                         vars_initialized,
                                         lifted_partitioned_call_callees))) {
        return failure();
      }
    } else if (isa<TF::IfRegionOp, TF::CaseRegionOp, TF::WhileRegionOp>(op)) {
      for (Region& region : op.getRegions())
        (void)HoistForControlFlow(&region.front(), module, vars_initialized,
                                  lifted_partitioned_call_callees);
      LogicalResult result = RegionResourceHoister::ReplaceOpWithNewOp(&op);
      if (failed(result)) return failure();
    }
  }

  // After we have hoisted operations in the block, we may have added new reads
  // and writes of resources to this block. Clean them up by doing store-load
  // forwarding.
  ForwardStoreToLoad(block);
  return success();
}
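
// As an illustrative sketch (not taken from an actual test), a functional
// tf.IfOp whose branches read and then update the same variable is expected
// to come out of this hoisting looking roughly like:
//
//   %read = "tf.ReadVariableOp"(%resource_handle)
//   %new_value = "tf.If"(%cond, %read)
//       {then_branch = @then_lifted, else_branch = @else_lifted}
//   "tf.AssignVariableOp"(%resource_handle, %new_value)
//
// so that the branch functions no longer take or produce resource values.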

// Lifts resource operations from tf_device.cluster ops nested in the module
// outside of the clusters. Signals a pass failure if there are remaining
// resource-type values that cannot be lifted.
void ResourceOpLiftingPass::runOnOperation() {
  llvm::SmallDenseMap<llvm::StringRef, PartitionedCallLiftingInfo>
      lifted_partitioned_call_callees;
  ModuleOp module = getOperation();

  if (failed(TF::CleanupAndCanonicalizeForResourceOpLifting(module)))
    return signalPassFailure();

  auto walk_result = module.walk([&](FuncOp func_op) {
    return func_op.walk([&](tf_device::ClusterOp cluster) {
      LogicalResult result = HoistForControlFlow(
          &cluster.GetBody(), module, /*vars_initialized=*/true,
          &lifted_partitioned_call_callees);
      if (failed(result)) return WalkResult::interrupt();
      result = RegionResourceHoister::ReplaceOpWithNewOp(cluster);
      if (failed(result)) return WalkResult::interrupt();
      return WalkResult::advance();
    });
  });

  if (walk_result.wasInterrupted()) return signalPassFailure();
}

struct ResourceOpLiftingForMainFunctionPass
    : public PassWrapper<ResourceOpLiftingForMainFunctionPass,
                         OperationPass<ModuleOp>> {
  void runOnOperation() override;
};

void ResourceOpLiftingForMainFunctionPass::runOnOperation() {
  ModuleOp module = getOperation();
  FuncOp main_func = module.lookupSymbol<FuncOp>("main");
  if (!main_func) {
    return;
  }

  if (failed(TF::ResourceLiftingForFunctionalControlFlow(main_func))) {
    return signalPassFailure();
  }
}

static PassRegistration<ResourceOpLiftingForMainFunctionPass>
    lift_main_func_pass(
        "tf-resource-op-lifting-for-main-function",
        "Lifting resource operations out of control flow statements for the "
        "main function");

static PassRegistration<ResourceOpLiftingPass> pass(
    "tf-resource-op-lifting",
    "Lifting resource operations out of device computation");
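
// For manual experimentation, the registrations above are meant to expose the
// passes through the usual MLIR opt-style tooling, e.g. (the exact binary
// name may differ between builds):
//
//   tf-opt -tf-resource-op-lifting input.mlir
//   tf-opt -tf-resource-op-lifting-for-main-function input.mlir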

}  // namespace

namespace TFDevice {
std::unique_ptr<OperationPass<ModuleOp>> CreateResourceOpLiftingPass() {
  return std::make_unique<ResourceOpLiftingPass>();
}
}  // namespace TFDevice
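
// A minimal usage sketch for this factory, assuming the caller owns a
// standard mlir::PassManager (illustrative only, not part of this file):
//
//   mlir::PassManager pm(module.getContext());
//   pm.addPass(mlir::TFDevice::CreateResourceOpLiftingPass());
//   if (failed(pm.run(module))) {
//     // Handle the pass failure.
//   }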

namespace TF {
LogicalResult ResourceLiftingForFunctionalControlFlow(FuncOp function) {
  // This routine should only be called when control flow operations are still
  // represented with TF IfOp and WhileOp operations. In that case, the MLIR
  // representation should contain only a single basic block.
  if (!llvm::hasSingleElement(function)) {
    return function.emitError()
           << "expect the function to have 1 block while it has "
           << function.getBlocks().size();
  }

  if (failed(TF::CleanupAndCanonicalizeForResourceOpLifting(function)))
    return failure();

  llvm::SmallDenseMap<llvm::StringRef, PartitionedCallLiftingInfo>
      lifted_partitioned_call_callees;
  if (failed(HoistForControlFlow(
          &function.front(), cast<ModuleOp>(function->getParentOp()),
          /*vars_initialized=*/false, &lifted_partitioned_call_callees)))
    return failure();

  // Clean up and canonicalize to remove dead local variables, as some local
  // variables might be dead after hoisting resource loads/stores from control
  // flow ops.
  return TF::CleanupAndCanonicalizeForResourceOpLifting(function);
}
}  // namespace TF

}  // namespace mlir