1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <iterator>
17 #include <memory>
18 #include <tuple>
19 #include <utility>
20 
21 #include "llvm/ADT/ArrayRef.h"
22 #include "llvm/ADT/DenseMap.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/ADT/SetVector.h"
25 #include "llvm/ADT/SmallPtrSet.h"
26 #include "llvm/ADT/SmallVector.h"
27 #include "llvm/ADT/StringRef.h"
28 #include "llvm/ADT/iterator_range.h"
29 #include "llvm/Support/Casting.h"
30 #include "llvm/Support/Debug.h"
31 #include "mlir/IR/Attributes.h"  // from @llvm-project
32 #include "mlir/IR/Builders.h"  // from @llvm-project
33 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
34 #include "mlir/IR/Identifier.h"  // from @llvm-project
35 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
36 #include "mlir/IR/Operation.h"  // from @llvm-project
37 #include "mlir/IR/Types.h"  // from @llvm-project
38 #include "mlir/IR/Value.h"  // from @llvm-project
39 #include "mlir/Pass/Pass.h"  // from @llvm-project
40 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
41 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
42 #include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
43 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
44 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
45 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h"
46 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
47 #include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h"
48 #define DEBUG_TYPE "tf-tpu-merge-variables-with-execute"
49 
50 namespace mlir {
51 namespace TFTPU {
52 
53 namespace {
// Argument attribute on the serialized module's `main` function that records
// which output an input aliases (used for in-place buffer reuse on device).
constexpr char kAliasingAttr[] = "tf.aliasing_output";
// Op attribute carrying the device assignment of an op (e.g. a VarHandleOp).
constexpr char kDeviceAttr[] = "device";
// Function-argument attribute carrying the device of a resource argument.
constexpr char kFuncDeviceAttr[] = "tf.device";
57 
58 // A pass that finds on-device resource variable reads/assigns surrounding a
59 // tf.TPUExecute op, and merges them into a tf.TPUExecuteAndUpdateVariables.
60 // This allows the TPU execution to perform in-place variable updates.
61 //
62 // For example,
63 //
64 //   %0 = "tf.ReadVariableOp"(%arg0)
65 //   %1 = "tf.ReadVariableOp"(%arg1)
66 //   %2 = "tf.TPUExecute"(%0, %1, %compile)
67 //   %3 = "tf.AssignVariableOp"(%arg0, %2)
68 //
69 // will be transformed into
70 //
71 //   %2 = "tf.TPUExecuteAndUpdateVariables"(%arg0, %arg1, %compile)
72 //     { device_var_reads_indices = [0, 1],
73 //       device_var_updates_indices = [0, -1] }
74 //
75 // The transformation happens only for on-device variables. The above
76 // transformation requires %arg0, %arg1 to have the same device assignment as
77 // the TPUExecute op.
78 
// Pass shell; the actual transformation is implemented in runOnFunction() and
// the helper functions below (see MergeForOneTPUExecute).
struct TPUMergeVariablesWithExecutePass
    : public PassWrapper<TPUMergeVariablesWithExecutePass, FunctionPass> {
  void runOnFunction() override;
};
83 
84 // Information for a pair of input/output of the TPUExecute op and the
85 // surrounding read/assign ops.
// Information for a pair of input/output of the TPUExecute op and the
// surrounding read/assign ops.
struct VariableAccessInfo {
  // Operand index of TPUExecute fed by the variable read, or -1 if unset.
  int execute_input_index = -1;
  // Result index of TPUExecute assigned back to the variable, or -1 if the
  // variable is only read (no fused update).
  int execute_output_index = -1;
  // The tf.ReadVariableOp producing the execute operand, if any.
  Operation* read = nullptr;
  // The tf.AssignVariableOp consuming the execute result, if any.
  Operation* assign = nullptr;
};
92 
93 // Information about all resource accesses to be fused into a TPUExecute op.
// Information about all resource accesses to be fused into a TPUExecute op.
struct VariableAccessesForTPUExecute {
  // Maps each detected resource value to its VariableAccessInfo.
  llvm::SmallDenseMap<Value, VariableAccessInfo, 8> per_resource_info;
  // The corresponding new output index in TPUExecuteAndUpdateVariables for
  // each old output index in TPUExecute; -1 for outputs fused into updates.
  llvm::SmallVector<int, 8> old_to_new_output_mapping;
  // The resources read by ReadVariableOps that are inputs to TPUExecute,
  // ordered by the input indices to TPUExecute.
  llvm::SmallVector<Value, 8> resources_read;
  // Operands for the new TPUExecuteAndUpdateVariables (reads replaced by the
  // resource values themselves).
  llvm::SmallVector<Value, 8> new_operand_values;
};
106 
107 // Returns if an op accesses a resource.
108 //
109 // TODO(yuanzx): Decide how to make this fine-grained. Right now we do not know
110 // if the resources alias.
OpAccessesResource(Operation * op)111 bool OpAccessesResource(Operation* op) {
112   return llvm::any_of(op->getOperandTypes(), [](const Type& type) {
113     return type.isa<TF::ResourceType>() ||
114            (type.isa<TensorType>() &&
115             type.cast<TensorType>().getElementType().isa<TF::ResourceType>());
116   });
117 }
118 
// Finds the variable access info for a TPUExecute op.
//  - `check_device` specifies whether the device assignment of the variables
//  must match that of the TPUExecute op. This check is optional in some
//  contexts, e.g., when matching placement is guaranteed by replication.
//  - `check_same_region` specifies whether the reads/assigns need to be in the
//  same region as `execute`. This is needed if `execute` is inside ReplicateOp.
VariableAccessesForTPUExecute BuildVariableAccessInfo(
    tf_device::LaunchOp execute_launch, bool check_device,
    bool check_same_region) {
  VariableAccessesForTPUExecute infos;
  Attribute device_attr = execute_launch.deviceAttr();
  // Without a device on the launch we cannot verify placement; bail out empty.
  if (check_device && !device_attr) return infos;
  auto func = execute_launch->getParentOfType<mlir::FuncOp>();

  // Track the first read op found, which is used later to check if there are
  // assign ops between it and the TPUExecute op. We will exclude reads before
  // interfering accesses in a conservative way (see below). We do not
  // consider resource accesses in other islands since their ordering is
  // enforced by inter-island dependencies.
  Operation* first_read = nullptr;
  auto execute = cast<TF::TPUExecuteOp>(execute_launch.GetBody().front());
  // When the launch is wrapped in a tf_device.parallel_execute, ordering
  // checks below must be done relative to the parallel_execute itself.
  auto parallel_execute = llvm::dyn_cast<tf_device::ParallelExecuteOp>(
      execute_launch->getParentOp());
  Operation* execute_parent =
      parallel_execute ? parallel_execute.getOperation() : execute_launch;
  // Find inputs that are variable reads.
  for (auto operand : llvm::enumerate(execute->getOpOperands())) {
    infos.new_operand_values.push_back(operand.value().get());
    auto read_op = llvm::dyn_cast_or_null<TF::ReadVariableOp>(
        operand.value().get().getDefiningOp());
    if (!read_op) continue;
    if (check_same_region &&
        read_op->getParentRegion() != execute_parent->getParentRegion())
      continue;

    auto resource = read_op.resource();

    if (check_device) {
      // TODO(lyandy): Wrap resource ops in tf_device.launch.
      if (auto* resource_op = resource.getDefiningOp()) {
        auto resource_attr = resource_op->getAttr(kDeviceAttr);
        // Check device matching for the node defining the resource.
        if (!resource_attr || resource_attr != device_attr) continue;
      } else {
        auto resource_arg = resource.dyn_cast<BlockArgument>();
        assert(resource_arg);
        // Only entry-block arguments carry tf.device argument attributes.
        if (resource_arg.getOwner() != &func.front()) continue;
        // Check device matching for the argument defining the resource.
        auto resource_attr = func.getArgAttrOfType<mlir::StringAttr>(
            resource_arg.getArgNumber(), kFuncDeviceAttr);
        if (!resource_attr || resource_attr != device_attr) continue;
      }
    }

    auto emplace_res =
        infos.per_resource_info.try_emplace(resource, VariableAccessInfo());
    if (!emplace_res.second) {
      // Multiple reads of the same variable are not fused; give up on this
      // execute entirely (conservative).
      LLVM_DEBUG(llvm::dbgs()
                 << "Skipping execute that has multiple reads of a variable: "
                 << execute << "\n");
      infos.per_resource_info.shrink_and_clear();
      return infos;
    }

    auto& info = emplace_res.first->getSecond();
    info.execute_input_index = operand.index();
    info.read = read_op;
    // Replace the read result with the resource itself in the new operands.
    infos.new_operand_values[operand.index()] = resource;
    infos.resources_read.push_back(resource);
    if (!first_read || info.read->isBeforeInBlock(first_read)) {
      first_read = info.read;
    }
  }

  if (!first_read) return infos;

  // Conservatively find the last resource-accessing op between first_read and
  // execute, excluding ReadVariableOps since they are read-only. This should
  // work fine for the reads/assigns created by resource lifting, since they
  // are placed close to the TPUExecute.
  Operation* last_may_modify_resource_access_before_execute = nullptr;
  for (Operation& op :
       llvm::reverse(llvm::make_range(std::next(first_read->getIterator()),
                                      execute_parent->getIterator()))) {
    if (llvm::dyn_cast<TF::ReadVariableOp>(&op)) continue;
    if (!OpAccessesResource(&op)) continue;
    last_may_modify_resource_access_before_execute = &op;
    break;
  }

  if (last_may_modify_resource_access_before_execute) {
    // Remove the reads before last_may_modify_resource_access_before_execute,
    // since moving them past that access could change observed values.
    for (auto& op : llvm::make_range(
             first_read->getIterator(),
             last_may_modify_resource_access_before_execute->getIterator())) {
      auto read = llvm::dyn_cast<TF::ReadVariableOp>(&op);
      if (!read) continue;
      auto info_it = infos.per_resource_info.find(read.resource());
      if (info_it == infos.per_resource_info.end()) continue;
      int input_index = info_it->getSecond().execute_input_index;
      // Restore the original operand (the read's result) for this input.
      infos.new_operand_values[input_index] = execute.getOperand(input_index);
      infos.per_resource_info.erase(info_it);
    }
    // Drop the corresponding entries from resources_read as well.
    infos.resources_read.erase(
        llvm::remove_if(infos.resources_read,
                        [&](const Value resource) {
                          return infos.per_resource_info.count(resource) == 0;
                        }),
        infos.resources_read.end());
  }

  if (infos.per_resource_info.empty()) {
    return infos;
  }

  // Find outputs that are variable assigns.
  Operation* last_assign = nullptr;
  llvm::SmallPtrSet<Operation*, 8> all_assigns;
  llvm::SmallVector<bool, 8> output_fused(execute_launch.getNumResults(),
                                          false);

  // When inside a parallel_execute, match against the parallel_execute results
  // that forward this launch's region outputs.
  auto execute_outputs =
      parallel_execute
          ? parallel_execute.GetRegionOutputs(
                execute_launch->getParentRegion()->getRegionNumber())
          : execute_launch.getResults();
  for (auto execute_output : llvm::enumerate(execute_outputs)) {
    // TODO(lyandy): Handle updates to resource writes by remapping to parent
    // launch result and checking if launch result is an AssignVariableOp.
    auto result = execute_output.value();
    // Only fuse when the assign is the sole user of the result.
    if (!result.hasOneUse()) continue;
    auto assign_op = llvm::dyn_cast<TF::AssignVariableOp>(*result.user_begin());
    if (!assign_op) continue;
    auto resource = assign_op.resource();
    auto it = infos.per_resource_info.find(resource);
    if (it == infos.per_resource_info.end()) continue;
    auto& info = it->getSecond();
    if (info.assign) {
      // Multiple assigns of the same variable: give up (conservative).
      LLVM_DEBUG(llvm::dbgs()
                 << "Skipping execute that has multiple assigns of a variable: "
                 << execute << "\n");
      infos.per_resource_info.shrink_and_clear();
      return infos;
    }
    info.execute_output_index = execute_output.index();
    info.assign = assign_op;
    if (!last_assign || last_assign->isBeforeInBlock(assign_op)) {
      last_assign = assign_op;
    }
    all_assigns.insert(assign_op);
    output_fused[execute_output.index()] = true;
  }

  // Check if there are other resource accesses after execute.
  Operation* first_unknown_resource_access_after_execute = nullptr;
  if (last_assign) {
    for (auto& op : llvm::make_range(std::next(execute_parent->getIterator()),
                                     last_assign->getIterator())) {
      if (all_assigns.count(&op) > 0) continue;
      if (!OpAccessesResource(&op)) continue;
      first_unknown_resource_access_after_execute = &op;
      break;
    }
  }
  if (first_unknown_resource_access_after_execute) {
    // Un-fuse the assigns after first_unknown_resource_access_after_execute,
    // since moving them before that access could change observed values.
    for (auto& op : llvm::make_range(
             first_unknown_resource_access_after_execute->getIterator(),
             std::next(last_assign->getIterator()))) {
      if (auto assign = llvm::dyn_cast<TF::AssignVariableOp>(&op)) {
        if (all_assigns.count(assign) == 0) continue;
        auto info_it = infos.per_resource_info.find(assign.resource());
        if (info_it == infos.per_resource_info.end()) continue;
        output_fused[info_it->second.execute_output_index] = false;
        info_it->second.execute_output_index = -1;
        info_it->second.assign = nullptr;
      }
    }
  }

  // Populate infos.old_to_new_output_mapping: fused outputs map to -1,
  // remaining outputs are compacted in order.
  int new_output_index = 0;
  infos.old_to_new_output_mapping.resize(execute_launch.getNumResults());
  for (int i = 0, end = execute_launch.getNumResults(); i < end; ++i) {
    if (output_fused[i]) {
      infos.old_to_new_output_mapping[i] = -1;
    } else {
      infos.old_to_new_output_mapping[i] = new_output_index;
      ++new_output_index;
    }
  }
  return infos;
}
312 
313 // Appends result types of tf_device.parallel_execute from `start` index region
314 // (inclusive) to `end` index region (exclusive) to `output_types` and returns
315 // the number of types added.
AppendTypes(llvm::SmallVectorImpl<Type> * output_types,tf_device::ParallelExecuteOp parallel_execute,int start,int end)316 int AppendTypes(llvm::SmallVectorImpl<Type>* output_types,
317                 tf_device::ParallelExecuteOp parallel_execute, int start,
318                 int end) {
319   const int size_before = output_types->size();
320   for (int index = start; index < end; ++index) {
321     Block& block = parallel_execute.GetRegionBlockWithIndex(index);
322     auto terminator_operand_types = block.getTerminator()->getOperandTypes();
323     output_types->append(terminator_operand_types.begin(),
324                          terminator_operand_types.end());
325   }
326   return output_types->size() - size_before;
327 }
328 
329 // Replaces TPUExecute with TPUExecuteAndUpdateVariables in a
330 // tf_device.parallel_execute op.
void ReplaceParallelExecute(tf_device::ParallelExecuteOp parallel_execute,
                            tf_device::LaunchOp execute_launch,
                            tf_device::LaunchOp merged_execute_launch,
                            const VariableAccessesForTPUExecute& infos,
                            OpBuilder* builder) {
  Operation* parallel_execute_op = parallel_execute.getOperation();

  // Collect result types of tf_device.parallel_execute and update region
  // result types with the new merged execute result types.
  llvm::SmallVector<Type, 8> output_types;
  const int parallel_execute_num_results = parallel_execute_op->getNumResults();
  output_types.reserve(parallel_execute_num_results);
  Region* execute_region = merged_execute_launch->getParentRegion();
  const int region_index = execute_region->getRegionNumber();
  // Types from regions before the one containing the merged execute.
  const int num_results_before_region =
      AppendTypes(&output_types, parallel_execute, 0, region_index);
  // Append updated results from merged execute.
  output_types.append(merged_execute_launch.getResultTypes().begin(),
                      merged_execute_launch.getResultTypes().end());
  // Types from regions after the one containing the merged execute.
  const int num_regions = parallel_execute_op->getNumRegions();
  const int num_results_after_region = AppendTypes(
      &output_types, parallel_execute, region_index + 1, num_regions);

  // A new parallel_execute is required because result counts/types changed.
  builder->setInsertionPoint(parallel_execute);
  auto new_parallel_execute = builder->create<tf_device::ParallelExecuteOp>(
      parallel_execute.getLoc(), num_regions, output_types);

  // Replace the uses of the original parallel_execute before region containing
  // merged execute.
  Operation* new_parallel_execute_op = new_parallel_execute.getOperation();
  for (int i = 0; i < num_results_before_region; ++i)
    parallel_execute_op->getResult(i).replaceAllUsesWith(
        new_parallel_execute_op->getResult(i));

  // Replace the uses of the original parallel_execute after region containing
  // merged execute. The number of results changed in the region containing the
  // merged execute, but they should match, so results are replaced starting
  // from the ends of both parallel_execute.
  const int new_parallel_execute_num_results =
      new_parallel_execute_op->getNumResults();
  for (int i = 0; i < num_results_after_region; ++i)
    parallel_execute_op->getResult(parallel_execute_num_results - i - 1)
        .replaceAllUsesWith(new_parallel_execute_op->getResult(
            new_parallel_execute_num_results - i - 1));

  // Replace the uses of the original parallel_execute for the region containing
  // the merged execute, using the old->new output mapping (fused outputs, with
  // mapping -1, have no remaining uses to replace).
  auto old_region_results = parallel_execute.GetRegionOutputs(region_index);
  for (int i = 0, end = infos.old_to_new_output_mapping.size(); i < end; ++i) {
    if (infos.old_to_new_output_mapping[i] < 0) continue;
    old_region_results[i].replaceAllUsesWith(new_parallel_execute_op->getResult(
        infos.old_to_new_output_mapping[i] + num_results_before_region));
  }

  // Replace original terminator with new terminator for returning merged
  // execute results.
  Operation* old_terminator = execute_region->front().getTerminator();
  builder->setInsertionPointToEnd(&execute_region->front());
  builder->create<tf_device::ReturnOp>(old_terminator->getLoc(),
                                       merged_execute_launch.getResults());
  old_terminator->erase();

  // Remove the original TPUExecute op.
  execute_launch.erase();

  // Move all regions from old parallel_execute to new parallel_execute.
  for (auto region : llvm::zip(new_parallel_execute_op->getRegions(),
                               parallel_execute_op->getRegions()))
    std::get<0>(region).takeBody(std::get<1>(region));

  // Remove the original parallel_execute.
  parallel_execute_op->dropAllUses();
  parallel_execute.erase();
}
405 
406 // Replaces TPUExecute with TPUExecuteAndUpdateVariables.
ReplaceExecute(tf_device::LaunchOp execute_launch,tf_device::LaunchOp merged_execute_launch,const VariableAccessesForTPUExecute & infos)407 void ReplaceExecute(tf_device::LaunchOp execute_launch,
408                     tf_device::LaunchOp merged_execute_launch,
409                     const VariableAccessesForTPUExecute& infos) {
410   // Replace the uses.
411   for (int i = 0, end = infos.old_to_new_output_mapping.size(); i < end; ++i) {
412     if (infos.old_to_new_output_mapping[i] < 0) continue;
413     execute_launch.getResult(i).replaceAllUsesWith(
414         merged_execute_launch.getResult(infos.old_to_new_output_mapping[i]));
415   }
416 
417   // Remove the original TPUExecute op.
418   execute_launch.getOperation()->dropAllUses();
419   execute_launch.erase();
420 }
421 
422 // Returns TPUCompileMlir op that generates the program executed by the
423 // TPUExecute op.
GetTPUCompileOp(tf_device::LaunchOp execute_launch)424 TF::_TPUCompileMlirOp GetTPUCompileOp(tf_device::LaunchOp execute_launch) {
425   auto execute =
426       llvm::dyn_cast<TF::TPUExecuteOp>(execute_launch.GetBody().front());
427   if (!execute) return {};
428   auto compile_launch = llvm::dyn_cast_or_null<tf_device::LaunchOp>(
429       execute.getOperand(execute.getNumOperands() - 1).getDefiningOp());
430   if (!compile_launch) return {};
431   return llvm::dyn_cast<TF::_TPUCompileMlirOp>(
432       compile_launch.GetBody().front());
433 }
434 
435 // Updates the serialized module associated with the TPUExecute op to reflect
436 // the aliasing information for better management of device memory.
LogicalResult UpdateSerializedModule(tf_device::LaunchOp execute_launch,
                                     VariableAccessesForTPUExecute& infos) {
  TF::_TPUCompileMlirOp compile = GetTPUCompileOp(execute_launch);

  // Skip adding alias information in case of model parallelism i.e.,
  // TPUCompileMlir op generates multiple programs.
  if (!compile || compile.program().size() > 1) return failure();

  // Parse the serialized module attached to the compile op.
  mlir::OwningModuleRef module_ref;
  tensorflow::Status status = tensorflow::DeserializeMlirModule(
      compile.mlir_module().str(), compile.getContext(), &module_ref);
  if (!status.ok()) {
    LLVM_DEBUG(llvm::dbgs() << "Error in parsing serialized module: "
                            << status.error_message() << "\n");

    return failure();
  }

  // Add aliasing information to main function arguments.
  FuncOp main_func = module_ref->lookupSymbol<FuncOp>("main");
  if (!main_func) return failure();

  OpBuilder builder(main_func.getContext());
  for (auto resource : infos.resources_read) {
    auto& info = infos.per_resource_info[resource];
    // Only variables that are both read and updated can alias input to output.
    if (info.execute_input_index < 0 || info.execute_output_index < 0) continue;
    auto aliasing_attr = main_func.getArgAttrOfType<mlir::IntegerAttr>(
        info.execute_input_index, kAliasingAttr);

    // Set only if aliasing attribute does not exist.
    if (!aliasing_attr) {
      main_func.setArgAttr(
          info.execute_input_index, kAliasingAttr,
          builder.getI64IntegerAttr(info.execute_output_index));
      continue;
    }
    // If aliasing attribute already exists, it must match the new value.
    assert(aliasing_attr.getInt() == info.execute_output_index);
  }

  // Serialize the updated module back into the TPUCompileMlir op.
  auto module_string = tensorflow::SerializeMlirModule(module_ref.get());
  compile.mlir_moduleAttr(
      mlir::StringAttr::get(module_ref->getContext(), module_string));
  return success();
}
484 
485 // Merges the variable accesses into one TPUExecute op.
LogicalResult MergeForOneTPUExecute(tf_device::LaunchOp execute_launch,
                                    bool check_device, bool check_same_region,
                                    OpBuilder* builder) {
  auto infos =
      BuildVariableAccessInfo(execute_launch, check_device, check_same_region);
  // Nothing to fuse; leave the execute untouched.
  if (infos.per_resource_info.empty()) return success();

  // Update the serialized module with aliasing information for better memory
  // management on device. Failure here is non-fatal: only the memory hint is
  // lost, the merge itself still proceeds.
  // TODO(b/172608422): Benchmark the cost of deserialization/serialization of
  // the attached module. We can avoid it by serializing it at the end of the
  // bridge pipeline.
  if (failed(UpdateSerializedModule(execute_launch, infos))) {
    LLVM_DEBUG(
        llvm::dbgs()
        << "Unable to update the serialized module with aliasing information "
           "which can lead to poor memory management on device.\n");
  }

  // Start creating the new TPUExecuteAndUpdateVariables op.
  builder->setInsertionPoint(execute_launch);
  // Output types. Skip the original outputs for fused assigns.
  llvm::SmallVector<Type, 8> new_output_types;
  int old_output_index = 0;
  for (const auto& type : execute_launch.getResultTypes()) {
    if (infos.old_to_new_output_mapping[old_output_index] >= 0) {
      new_output_types.push_back(type);
    }
    ++old_output_index;
  }
  // The attributes for fused variable reads and updates.
  llvm::SmallVector<int64_t, 8> device_var_reads_indices;
  llvm::SmallVector<int64_t, 8> device_var_updates_indices;
  for (auto resource : infos.resources_read) {
    const auto& info = infos.per_resource_info[resource];
    device_var_reads_indices.push_back(info.execute_input_index);
    device_var_updates_indices.push_back(info.execute_output_index);
  }

  // Check that all resource operands are either read or written to; a
  // resource that is neither likely failed to be hoisted out of the cluster.
  for (auto it : llvm::enumerate(infos.new_operand_values)) {
    Type type = it.value().getType();
    if (type.isa<TensorType>() &&
        type.cast<TensorType>().getElementType().isa<TF::ResourceType>()) {
      if (!llvm::is_contained(device_var_reads_indices, it.index()) &&
          !llvm::is_contained(device_var_updates_indices, it.index())) {
        return execute_launch.GetBody().front().emitError("operand #")
               << it.index()
               << " is a resource that was neither read nor written to; this "
                  "resource potentially failed to be hoisted";
      }
    }
  }

  // Create the merged execute and update variables op.
  auto merged_execute = builder->create<TF::TPUExecuteAndUpdateVariablesOp>(
      execute_launch.getLoc(), new_output_types, infos.new_operand_values,
      llvm::ArrayRef<NamedAttribute>{
          builder->getNamedAttr(
              "device_var_reads_indices",
              builder->getI64ArrayAttr(device_var_reads_indices)),
          builder->getNamedAttr(
              "device_var_updates_indices",
              builder->getI64ArrayAttr(device_var_updates_indices))});

  // Wrap in launch for device assignment, reusing the original launch device.
  auto merged_execute_launch = builder->create<tf_device::LaunchOp>(
      merged_execute.getLoc(), execute_launch.deviceAttr(),
      merged_execute.getResultTypes());
  merged_execute_launch.body().push_back(new Block);

  builder->setInsertionPointToEnd(&merged_execute_launch.GetBody());
  builder->create<tf_device::ReturnOp>(merged_execute.getLoc(),
                                       merged_execute.getResults());

  // Move the merged execute into its launch body, before the return.
  merged_execute.getOperation()->moveBefore(
      merged_execute_launch.GetBody().getTerminator());

  // Replace the old execute, using the parallel_execute-aware path if needed.
  if (auto parallel_execute = llvm::dyn_cast<tf_device::ParallelExecuteOp>(
          execute_launch->getParentOp()))
    ReplaceParallelExecute(parallel_execute, execute_launch,
                           merged_execute_launch, infos, builder);
  else
    ReplaceExecute(execute_launch, merged_execute_launch, infos);

  // Remove the assign ops (their updates are now done in place on device).
  for (const auto& entry : infos.per_resource_info) {
    const auto& info = entry.getSecond();
    if (info.assign) info.assign->erase();
  }

  // Remove the read ops if they have no more uses.
  for (const auto& entry : infos.per_resource_info) {
    const auto& info = entry.getSecond();
    if (info.read->use_empty()) info.read->erase();
  }
  return success();
}
584 
585 // Checks if an ops parent is a tf_device.parallel_execute and the region the
586 // op is in is perfectly wrapped.
ParentParallelExecuteWrapsSingleOp(Operation * op)587 bool ParentParallelExecuteWrapsSingleOp(Operation* op) {
588   auto parallel_execute =
589       llvm::dyn_cast<tf_device::ParallelExecuteOp>(op->getParentOp());
590   if (!parallel_execute) return true;
591 
592   return parallel_execute.RegionWrapsSingleOp(
593       op->getParentRegion()->getRegionNumber());
594 }
595 
void TPUMergeVariablesWithExecutePass::runOnFunction() {
  // Find all the executes first, since we will mutate the nodes around each
  // execute.
  llvm::SmallVector<tf_device::LaunchOp, 8> execute_launches;
  getFunction().walk([&](tf_device::LaunchOp op) {
    // Only handle launches that perfectly wrap a single TPUExecute and, when
    // nested in a parallel_execute, whose region wraps just this launch.
    if (op.WrapsSingleOp() &&
        llvm::isa<TF::TPUExecuteOp>(op.GetBody().front()) &&
        ParentParallelExecuteWrapsSingleOp(op))
      execute_launches.push_back(op);
  });

  for (auto execute_launch : execute_launches) {
    OpBuilder builder(&getContext());
    // Replication may be the direct parent, or the grandparent when the
    // execute launch is wrapped in a parallel_execute.
    const bool parent_is_replicate =
        llvm::isa<tf_device::ReplicateOp>(execute_launch->getParentOp()) ||
        (llvm::isa<tf_device::ParallelExecuteOp>(
             execute_launch->getParentOp()) &&
         llvm::isa<tf_device::ReplicateOp>(
             execute_launch->getParentOp()->getParentOp()));

    // If this is inside a tf_device::ReplicateOp, the variables are guaranteed
    // to be on the same device as the TPUExecute op. Skip device checking in
    // that case, but we need to check that we are only merging reads/assigns
    // that are also in this replicated region.
    if (failed(MergeForOneTPUExecute(
            execute_launch, /*check_device=*/!parent_is_replicate,
            /*check_same_region=*/parent_is_replicate, &builder))) {
      signalPassFailure();
      return;
    }
  }
}
628 
629 }  // namespace
630 
631 std::unique_ptr<OperationPass<FuncOp>>
CreateTPUMergeVariablesWithExecutePass()632 CreateTPUMergeVariablesWithExecutePass() {
633   return std::make_unique<TPUMergeVariablesWithExecutePass>();
634 }
635 
// Registers the pass under its command-line flag for mlir-opt-style tools.
static PassRegistration<TPUMergeVariablesWithExecutePass> pass(
    "tf-tpu-merge-variables-with-execute",
    "Merges device variable reads/updates into tpu execute nodes");
639 
640 }  // namespace TFTPU
641 }  // namespace mlir
642