/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <iterator>
#include <memory>
#include <tuple>
#include <utility>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/Identifier.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h"

#define DEBUG_TYPE "tf-tpu-merge-variables-with-execute"

namespace mlir {
namespace TFTPU {

namespace {
constexpr char kAliasingAttr[] = "tf.aliasing_output";
constexpr char kDeviceAttr[] = "device";
constexpr char kFuncDeviceAttr[] = "tf.device";

// A pass that finds on-device resource variable reads/assigns surrounding a
// tf.TPUExecute op, and merges them into a tf.TPUExecuteAndUpdateVariables.
// This allows the TPU execution to perform in-place variable updates.
//
// For example,
//
//   %0 = "tf.ReadVariableOp"(%arg0)
//   %1 = "tf.ReadVariableOp"(%arg1)
//   %2 = "tf.TPUExecute"(%0, %1, %compile)
//   %3 = "tf.AssignVariableOp"(%arg0, %2)
//
// will be transformed into
//
//   %2 = "tf.TPUExecuteAndUpdateVariables"(%arg0, %arg1, %compile)
//     { device_var_reads_indices = [0, 1],
//       device_var_updates_indices = [0, -1] }
//
// The transformation happens only for on-device variables. The above
// transformation requires %arg0, %arg1 to have the same device assignment as
// the TPUExecute op.

struct TPUMergeVariablesWithExecutePass
    : public PassWrapper<TPUMergeVariablesWithExecutePass, FunctionPass> {
  void runOnFunction() override;
};

// Information for a pair of input/output of the TPUExecute op and the
// surrounding read/assign ops.
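// As an illustration (a sketch of the intended field values, not text emitted
// by the pass), for the example at the top of this file the info for %arg0
// would be {execute_input_index = 0, execute_output_index = 0, read = the op
// defining %0, assign = the op defining %3}, and the info for %arg1 would be
// {execute_input_index = 1, execute_output_index = -1, read = the op defining
// %1, assign = nullptr}.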
struct VariableAccessInfo {
  int execute_input_index = -1;
  int execute_output_index = -1;
  Operation* read = nullptr;
  Operation* assign = nullptr;
};

// Information about all resource accesses to be fused into a TPUExecute op.
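// For the example at the top of this file, one would expect (a sketch, not
// text produced by the code): old_to_new_output_mapping = [-1], since the
// single TPUExecute output is consumed by a fused assign; resources_read =
// [%arg0, %arg1]; and new_operand_values = [%arg0, %arg1, %compile].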
struct VariableAccessesForTPUExecute {
  // Maps each resource detected to VariableAccessInfo.
  llvm::SmallDenseMap<Value, VariableAccessInfo, 8> per_resource_info;
  // The corresponding new output index in TPUExecuteAndUpdateVariables for
  // each old output index in TPUExecute.
  llvm::SmallVector<int, 8> old_to_new_output_mapping;
  // The resources read by ReadVariableOps that are inputs to TPUExecute,
  // ordered by the input indices to TPUExecute.
  llvm::SmallVector<Value, 8> resources_read;
  // Operands for the new TPUExecuteAndUpdateVariables.
  llvm::SmallVector<Value, 8> new_operand_values;
};

// Returns whether an op accesses a resource.
//
// TODO(yuanzx): Decide how to make this fine-grained. Right now we do not know
// if the resources alias.
bool OpAccessesResource(Operation* op) {
  return llvm::any_of(op->getOperandTypes(), [](const Type& type) {
    return type.isa<TF::ResourceType>() ||
           (type.isa<TensorType>() &&
            type.cast<TensorType>().getElementType().isa<TF::ResourceType>());
  });
}

// Finds the variable access info for a TPUExecute op.
// - `check_device` specifies whether it checks the device assignment of the
//   variables to match the TPUExecute op. This is optional in some contexts,
//   e.g., when it is guaranteed by replication.
// - `check_same_region` specifies whether the reads/assigns need to be in the
//   same region as `execute`. This is needed if `execute` is inside a
//   ReplicateOp.
VariableAccessesForTPUExecute BuildVariableAccessInfo(
    tf_device::LaunchOp execute_launch, bool check_device,
    bool check_same_region) {
  VariableAccessesForTPUExecute infos;
  Attribute device_attr = execute_launch.deviceAttr();
  if (check_device && !device_attr) return infos;
  auto func = execute_launch->getParentOfType<mlir::FuncOp>();

  // Track the first read op found, which is used later to check if there are
  // assign ops between it and the TPUExecute op. We will conservatively
  // exclude reads that precede interfering accesses (see below). We do not
  // consider resource accesses in other islands since their ordering is
  // enforced by inter-island dependencies.
  Operation* first_read = nullptr;
  auto execute = cast<TF::TPUExecuteOp>(execute_launch.GetBody().front());
  auto parallel_execute = llvm::dyn_cast<tf_device::ParallelExecuteOp>(
      execute_launch->getParentOp());
  Operation* execute_parent =
      parallel_execute ? parallel_execute.getOperation() : execute_launch;
  // Find inputs that are variable reads.
  for (auto operand : llvm::enumerate(execute->getOpOperands())) {
    infos.new_operand_values.push_back(operand.value().get());
    auto read_op = llvm::dyn_cast_or_null<TF::ReadVariableOp>(
        operand.value().get().getDefiningOp());
    if (!read_op) continue;
    if (check_same_region &&
        read_op->getParentRegion() != execute_parent->getParentRegion())
      continue;

    auto resource = read_op.resource();

    if (check_device) {
      // TODO(lyandy): Wrap resource ops in tf_device.launch.
      if (auto* resource_op = resource.getDefiningOp()) {
        auto resource_attr = resource_op->getAttr(kDeviceAttr);
        // Check device matching for the node defining the resource.
        if (!resource_attr || resource_attr != device_attr) continue;
      } else {
        auto resource_arg = resource.dyn_cast<BlockArgument>();
        assert(resource_arg);
        if (resource_arg.getOwner() != &func.front()) continue;
        // Check device matching for the argument defining the resource.
        auto resource_attr = func.getArgAttrOfType<mlir::StringAttr>(
            resource_arg.getArgNumber(), kFuncDeviceAttr);
        if (!resource_attr || resource_attr != device_attr) continue;
      }
    }

    auto emplace_res =
        infos.per_resource_info.try_emplace(resource, VariableAccessInfo());
    if (!emplace_res.second) {
      LLVM_DEBUG(llvm::dbgs()
                 << "Skipping execute that has multiple reads of a variable: "
                 << execute << "\n");
      infos.per_resource_info.shrink_and_clear();
      return infos;
    }

    auto& info = emplace_res.first->getSecond();
    info.execute_input_index = operand.index();
    info.read = read_op;
    infos.new_operand_values[operand.index()] = resource;
    infos.resources_read.push_back(resource);
    if (!first_read || info.read->isBeforeInBlock(first_read)) {
      first_read = info.read;
    }
  }

  if (!first_read) return infos;

  // Conservatively find the last resource-accessing op between first_read and
  // execute, excluding ReadVariableOps since they are read-only. This should
  // work fine for the reads/assigns created by resource lifting, since they
  // are placed close to the TPUExecute.
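  // For instance, in the hypothetical IR below ("tf.SomeAssign" is an
  // illustrative stand-in for any resource-modifying op), the write to %arg0
  // between the read and the execute causes the read of %arg0 to be excluded
  // from merging:
  //
  //   %0 = "tf.ReadVariableOp"(%arg0)
  //   "tf.SomeAssign"(%arg0, %c)
  //   %2 = "tf.TPUExecute"(%0, %compile)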
  Operation* last_may_modify_resource_access_before_execute = nullptr;
  for (Operation& op :
       llvm::reverse(llvm::make_range(std::next(first_read->getIterator()),
                                      execute_parent->getIterator()))) {
    if (llvm::dyn_cast<TF::ReadVariableOp>(&op)) continue;
    if (!OpAccessesResource(&op)) continue;
    last_may_modify_resource_access_before_execute = &op;
    break;
  }

  if (last_may_modify_resource_access_before_execute) {
    // Remove the reads before last_may_modify_resource_access_before_execute.
    for (auto& op : llvm::make_range(
             first_read->getIterator(),
             last_may_modify_resource_access_before_execute->getIterator())) {
      auto read = llvm::dyn_cast<TF::ReadVariableOp>(&op);
      if (!read) continue;
      auto info_it = infos.per_resource_info.find(read.resource());
      if (info_it == infos.per_resource_info.end()) continue;
      int input_index = info_it->getSecond().execute_input_index;
      infos.new_operand_values[input_index] = execute.getOperand(input_index);
      infos.per_resource_info.erase(info_it);
    }
    infos.resources_read.erase(
        llvm::remove_if(infos.resources_read,
                        [&](const Value resource) {
                          return infos.per_resource_info.count(resource) == 0;
                        }),
        infos.resources_read.end());
  }

  if (infos.per_resource_info.empty()) {
    return infos;
  }

  // Find outputs that are variable assigns.
  Operation* last_assign = nullptr;
  llvm::SmallPtrSet<Operation*, 8> all_assigns;
  llvm::SmallVector<bool, 8> output_fused(execute_launch.getNumResults(),
                                          false);

  auto execute_outputs =
      parallel_execute
          ? parallel_execute.GetRegionOutputs(
                execute_launch->getParentRegion()->getRegionNumber())
          : execute_launch.getResults();
  for (auto execute_output : llvm::enumerate(execute_outputs)) {
    // TODO(lyandy): Handle updates to resource writes by remapping to parent
    // launch result and checking if launch result is an AssignVariableOp.
    auto result = execute_output.value();
    if (!result.hasOneUse()) continue;
    auto assign_op = llvm::dyn_cast<TF::AssignVariableOp>(*result.user_begin());
    if (!assign_op) continue;
    auto resource = assign_op.resource();
    auto it = infos.per_resource_info.find(resource);
    if (it == infos.per_resource_info.end()) continue;
    auto& info = it->getSecond();
    if (info.assign) {
      LLVM_DEBUG(llvm::dbgs()
                 << "Skipping execute that has multiple assigns of a variable: "
                 << execute << "\n");
      infos.per_resource_info.shrink_and_clear();
      return infos;
    }
    info.execute_output_index = execute_output.index();
    info.assign = assign_op;
    if (!last_assign || last_assign->isBeforeInBlock(assign_op)) {
      last_assign = assign_op;
    }
    all_assigns.insert(assign_op);
    output_fused[execute_output.index()] = true;
  }

  // Check if there are other resource accesses after execute.
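  // For instance, in the hypothetical IR below ("tf.SomeResourceOp" is an
  // illustrative stand-in for any op that may access the resource), the
  // access between the execute and the assign causes the assign to %arg0 to
  // be excluded from merging:
  //
  //   %2 = "tf.TPUExecute"(%0, %compile)
  //   "tf.SomeResourceOp"(%arg0)
  //   "tf.AssignVariableOp"(%arg0, %2)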
  Operation* first_unknown_resource_access_after_execute = nullptr;
  if (last_assign) {
    for (auto& op : llvm::make_range(std::next(execute_parent->getIterator()),
                                     last_assign->getIterator())) {
      if (all_assigns.count(&op) > 0) continue;
      if (!OpAccessesResource(&op)) continue;
      first_unknown_resource_access_after_execute = &op;
      break;
    }
  }
  if (first_unknown_resource_access_after_execute) {
    // Remove the assigns after first_unknown_resource_access_after_execute.
    for (auto& op : llvm::make_range(
             first_unknown_resource_access_after_execute->getIterator(),
             std::next(last_assign->getIterator()))) {
      if (auto assign = llvm::dyn_cast<TF::AssignVariableOp>(&op)) {
        if (all_assigns.count(assign) == 0) continue;
        auto info_it = infos.per_resource_info.find(assign.resource());
        if (info_it == infos.per_resource_info.end()) continue;
        output_fused[info_it->second.execute_output_index] = false;
        info_it->second.execute_output_index = -1;
        info_it->second.assign = nullptr;
      }
    }
  }

  // Populate infos.old_to_new_output_mapping.
  int new_output_index = 0;
  infos.old_to_new_output_mapping.resize(execute_launch.getNumResults());
  for (int i = 0, end = execute_launch.getNumResults(); i < end; ++i) {
    if (output_fused[i]) {
      infos.old_to_new_output_mapping[i] = -1;
    } else {
      infos.old_to_new_output_mapping[i] = new_output_index;
      ++new_output_index;
    }
  }
  return infos;
}

// Appends result types of tf_device.parallel_execute from `start` index region
// (inclusive) to `end` index region (exclusive) to `output_types` and returns
// the number of types added.
int AppendTypes(llvm::SmallVectorImpl<Type>* output_types,
                tf_device::ParallelExecuteOp parallel_execute, int start,
                int end) {
  const int size_before = output_types->size();
  for (int index = start; index < end; ++index) {
    Block& block = parallel_execute.GetRegionBlockWithIndex(index);
    auto terminator_operand_types = block.getTerminator()->getOperandTypes();
    output_types->append(terminator_operand_types.begin(),
                         terminator_operand_types.end());
  }
  return output_types->size() - size_before;
}

// Replaces TPUExecute with TPUExecuteAndUpdateVariables in a
// tf_device.parallel_execute op.
void ReplaceParallelExecute(tf_device::ParallelExecuteOp parallel_execute,
                            tf_device::LaunchOp execute_launch,
                            tf_device::LaunchOp merged_execute_launch,
                            const VariableAccessesForTPUExecute& infos,
                            OpBuilder* builder) {
  Operation* parallel_execute_op = parallel_execute.getOperation();

  // Collect result types of tf_device.parallel_execute and update region
  // result types with the new merged execute result types.
  llvm::SmallVector<Type, 8> output_types;
  const int parallel_execute_num_results =
      parallel_execute_op->getNumResults();
  output_types.reserve(parallel_execute_num_results);
  Region* execute_region = merged_execute_launch->getParentRegion();
  const int region_index = execute_region->getRegionNumber();
  const int num_results_before_region =
      AppendTypes(&output_types, parallel_execute, 0, region_index);
  // Append updated results from merged execute.
  output_types.append(merged_execute_launch.getResultTypes().begin(),
                      merged_execute_launch.getResultTypes().end());
  const int num_regions = parallel_execute_op->getNumRegions();
  const int num_results_after_region = AppendTypes(
      &output_types, parallel_execute, region_index + 1, num_regions);

  builder->setInsertionPoint(parallel_execute);
  auto new_parallel_execute = builder->create<tf_device::ParallelExecuteOp>(
      parallel_execute.getLoc(), num_regions, output_types);

  // Replace the uses of the original parallel_execute before the region
  // containing the merged execute.
  Operation* new_parallel_execute_op = new_parallel_execute.getOperation();
  for (int i = 0; i < num_results_before_region; ++i)
    parallel_execute_op->getResult(i).replaceAllUsesWith(
        new_parallel_execute_op->getResult(i));

  // Replace the uses of the original parallel_execute after the region
  // containing the merged execute. The number of results changed in the region
  // containing the merged execute, but the trailing regions' results still
  // match, so results are replaced starting from the ends of both
  // parallel_execute ops.
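  // As a worked example (hypothetical region/result counts, for illustration
  // only): suppose the old parallel_execute has regions [R0 | R1], where R0
  // is the region with the merged execute and produced results {0, 1}, and
  // R1 produced result {2}. If result 1 is fused into a variable update, the
  // new parallel_execute has results {0, 1}, and uses of old result 2 are
  // replaced with new result 1 by matching from the end.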
  const int new_parallel_execute_num_results =
      new_parallel_execute_op->getNumResults();
  for (int i = 0; i < num_results_after_region; ++i)
    parallel_execute_op->getResult(parallel_execute_num_results - i - 1)
        .replaceAllUsesWith(new_parallel_execute_op->getResult(
            new_parallel_execute_num_results - i - 1));

  // Replace the uses of the original parallel_execute for the region
  // containing the merged execute.
  auto old_region_results = parallel_execute.GetRegionOutputs(region_index);
  for (int i = 0, end = infos.old_to_new_output_mapping.size(); i < end; ++i) {
    if (infos.old_to_new_output_mapping[i] < 0) continue;
    old_region_results[i].replaceAllUsesWith(new_parallel_execute_op->getResult(
        infos.old_to_new_output_mapping[i] + num_results_before_region));
  }

  // Replace the original terminator with a new terminator that returns the
  // merged execute results.
  Operation* old_terminator = execute_region->front().getTerminator();
  builder->setInsertionPointToEnd(&execute_region->front());
  builder->create<tf_device::ReturnOp>(old_terminator->getLoc(),
                                       merged_execute_launch.getResults());
  old_terminator->erase();

  // Remove the original TPUExecute op.
  execute_launch.erase();

  // Move all regions from the old parallel_execute to the new one.
  for (auto region : llvm::zip(new_parallel_execute_op->getRegions(),
                               parallel_execute_op->getRegions()))
    std::get<0>(region).takeBody(std::get<1>(region));

  // Remove the original parallel_execute.
  parallel_execute_op->dropAllUses();
  parallel_execute.erase();
}

// Replaces TPUExecute with TPUExecuteAndUpdateVariables.
void ReplaceExecute(tf_device::LaunchOp execute_launch,
                    tf_device::LaunchOp merged_execute_launch,
                    const VariableAccessesForTPUExecute& infos) {
  // Replace the uses.
  for (int i = 0, end = infos.old_to_new_output_mapping.size(); i < end; ++i) {
    if (infos.old_to_new_output_mapping[i] < 0) continue;
    execute_launch.getResult(i).replaceAllUsesWith(
        merged_execute_launch.getResult(infos.old_to_new_output_mapping[i]));
  }

  // Remove the original TPUExecute op.
  execute_launch.getOperation()->dropAllUses();
  execute_launch.erase();
}

// Returns the _TPUCompileMlir op that generates the program executed by the
// TPUExecute op.
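// In the expected pattern (a hedged sketch of the IR shape; names, types, and
// result counts are illustrative), the last TPUExecute operand is a program
// produced by a _TPUCompileMlir op wrapped in its own launch:
//
//   %compile:2 = "tf_device.launch"() ( {
//     %program:2 = "tf._TPUCompileMlir"() {...}
//     tf_device.return %program#0, %program#1
//   })
//   "tf_device.launch"() ( {
//     %result = "tf.TPUExecute"(%0, %1, %compile#1)
//     tf_device.return %result
//   })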
TF::_TPUCompileMlirOp GetTPUCompileOp(tf_device::LaunchOp execute_launch) {
  auto execute =
      llvm::dyn_cast<TF::TPUExecuteOp>(execute_launch.GetBody().front());
  if (!execute) return {};
  auto compile_launch = llvm::dyn_cast_or_null<tf_device::LaunchOp>(
      execute.getOperand(execute.getNumOperands() - 1).getDefiningOp());
  if (!compile_launch) return {};
  return llvm::dyn_cast<TF::_TPUCompileMlirOp>(
      compile_launch.GetBody().front());
}

// Updates the serialized module associated with the TPUExecute op to reflect
// the aliasing information for better management of device memory.
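// Concretely, this sets the tf.aliasing_output argument attribute (see
// kAliasingAttr) on the serialized module's main function. For the example at
// the top of this file, the result would look roughly like the sketch below
// (types are illustrative):
//
//   func @main(%arg0: tensor<f32> {tf.aliasing_output = 0 : i64},
//              %arg1: tensor<f32>) -> tensor<f32>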
LogicalResult UpdateSerializedModule(tf_device::LaunchOp execute_launch,
                                     VariableAccessesForTPUExecute& infos) {
  TF::_TPUCompileMlirOp compile = GetTPUCompileOp(execute_launch);

  // Skip adding alias information in the case of model parallelism, i.e.,
  // when the TPUCompileMlir op generates multiple programs.
  if (!compile || compile.program().size() > 1) return failure();

  // Parse the serialized module.
  mlir::OwningModuleRef module_ref;
  tensorflow::Status status = tensorflow::DeserializeMlirModule(
      compile.mlir_module().str(), compile.getContext(), &module_ref);
  if (!status.ok()) {
    LLVM_DEBUG(llvm::dbgs() << "Error in parsing serialized module: "
                            << status.error_message() << "\n");
    return failure();
  }

  // Add aliasing information to main function arguments.
  FuncOp main_func = module_ref->lookupSymbol<FuncOp>("main");
  if (!main_func) return failure();

  OpBuilder builder(main_func.getContext());
  for (auto resource : infos.resources_read) {
    auto& info = infos.per_resource_info[resource];
    if (info.execute_input_index < 0 || info.execute_output_index < 0) continue;
    auto aliasing_attr = main_func.getArgAttrOfType<mlir::IntegerAttr>(
        info.execute_input_index, kAliasingAttr);

    // Set only if the aliasing attribute does not exist.
    if (!aliasing_attr) {
      main_func.setArgAttr(
          info.execute_input_index, kAliasingAttr,
          builder.getI64IntegerAttr(info.execute_output_index));
      continue;
    }
    // If the aliasing attribute already exists, it must match the new value.
    assert(aliasing_attr.getInt() == info.execute_output_index);
  }

  // Serialize the updated module back into the TPUCompileMlir op.
  auto module_string = tensorflow::SerializeMlirModule(module_ref.get());
  compile.mlir_moduleAttr(
      mlir::StringAttr::get(module_ref->getContext(), module_string));
  return success();
}

// Merges the variable accesses into one TPUExecute op.
LogicalResult MergeForOneTPUExecute(tf_device::LaunchOp execute_launch,
                                    bool check_device, bool check_same_region,
                                    OpBuilder* builder) {
  auto infos =
      BuildVariableAccessInfo(execute_launch, check_device, check_same_region);
  if (infos.per_resource_info.empty()) return success();

  // Update the serialized module with aliasing information for better memory
  // management on device.
  // TODO(b/172608422): Benchmark the cost of deserialization/serialization of
  // the attached module. We can avoid it by serializing it at the end of the
  // bridge pipeline.
  if (failed(UpdateSerializedModule(execute_launch, infos))) {
    LLVM_DEBUG(
        llvm::dbgs()
        << "Unable to update the serialized module with aliasing information, "
           "which can lead to poor memory management on device.\n");
  }

  // Start creating the new TPUExecuteAndUpdateVariables op.
  builder->setInsertionPoint(execute_launch);
  // Output types. Skip the original outputs for fused assigns.
  llvm::SmallVector<Type, 8> new_output_types;
  int old_output_index = 0;
  for (const auto& type : execute_launch.getResultTypes()) {
    if (infos.old_to_new_output_mapping[old_output_index] >= 0) {
      new_output_types.push_back(type);
    }
    ++old_output_index;
  }
  // The attributes for fused variable reads and updates.
  llvm::SmallVector<int64_t, 8> device_var_reads_indices;
  llvm::SmallVector<int64_t, 8> device_var_updates_indices;
  for (auto resource : infos.resources_read) {
    const auto& info = infos.per_resource_info[resource];
    device_var_reads_indices.push_back(info.execute_input_index);
    device_var_updates_indices.push_back(info.execute_output_index);
  }

  // Check that all resource operands are either read or written to.
  for (auto it : llvm::enumerate(infos.new_operand_values)) {
    Type type = it.value().getType();
    if (type.isa<TensorType>() &&
        type.cast<TensorType>().getElementType().isa<TF::ResourceType>()) {
      if (!llvm::is_contained(device_var_reads_indices, it.index()) &&
          !llvm::is_contained(device_var_updates_indices, it.index())) {
        return execute_launch.GetBody().front().emitError("operand #")
               << it.index()
               << " is a resource that was neither read nor written to; this "
                  "resource potentially failed to be hoisted";
      }
    }
  }

  // Create the merged execute and update variables op.
  auto merged_execute = builder->create<TF::TPUExecuteAndUpdateVariablesOp>(
      execute_launch.getLoc(), new_output_types, infos.new_operand_values,
      llvm::ArrayRef<NamedAttribute>{
          builder->getNamedAttr(
              "device_var_reads_indices",
              builder->getI64ArrayAttr(device_var_reads_indices)),
          builder->getNamedAttr(
              "device_var_updates_indices",
              builder->getI64ArrayAttr(device_var_updates_indices))});

  // Wrap in launch for device assignment.
  auto merged_execute_launch = builder->create<tf_device::LaunchOp>(
      merged_execute.getLoc(), execute_launch.deviceAttr(),
      merged_execute.getResultTypes());
  merged_execute_launch.body().push_back(new Block);

  builder->setInsertionPointToEnd(&merged_execute_launch.GetBody());
  builder->create<tf_device::ReturnOp>(merged_execute.getLoc(),
                                       merged_execute.getResults());

  merged_execute.getOperation()->moveBefore(
      merged_execute_launch.GetBody().getTerminator());

  if (auto parallel_execute = llvm::dyn_cast<tf_device::ParallelExecuteOp>(
          execute_launch->getParentOp()))
    ReplaceParallelExecute(parallel_execute, execute_launch,
                           merged_execute_launch, infos, builder);
  else
    ReplaceExecute(execute_launch, merged_execute_launch, infos);

  // Remove the assign ops.
  for (const auto& entry : infos.per_resource_info) {
    const auto& info = entry.getSecond();
    if (info.assign) info.assign->erase();
  }

  // Remove the read ops if they have no more uses.
  for (const auto& entry : infos.per_resource_info) {
    const auto& info = entry.getSecond();
    if (info.read->use_empty()) info.read->erase();
  }
  return success();
}

// Checks if an op's parent is a tf_device.parallel_execute and whether the
// region the op is in perfectly wraps that single op.
bool ParentParallelExecuteWrapsSingleOp(Operation* op) {
  auto parallel_execute =
      llvm::dyn_cast<tf_device::ParallelExecuteOp>(op->getParentOp());
  if (!parallel_execute) return true;

  return parallel_execute.RegionWrapsSingleOp(
      op->getParentRegion()->getRegionNumber());
}

void TPUMergeVariablesWithExecutePass::runOnFunction() {
  // Find all the executes first, since we will mutate the nodes around each
  // execute.
  llvm::SmallVector<tf_device::LaunchOp, 8> execute_launches;
  getFunction().walk([&](tf_device::LaunchOp op) {
    if (op.WrapsSingleOp() &&
        llvm::isa<TF::TPUExecuteOp>(op.GetBody().front()) &&
        ParentParallelExecuteWrapsSingleOp(op))
      execute_launches.push_back(op);
  });

  for (auto execute_launch : execute_launches) {
    OpBuilder builder(&getContext());
    const bool parent_is_replicate =
        llvm::isa<tf_device::ReplicateOp>(execute_launch->getParentOp()) ||
        (llvm::isa<tf_device::ParallelExecuteOp>(
             execute_launch->getParentOp()) &&
         llvm::isa<tf_device::ReplicateOp>(
             execute_launch->getParentOp()->getParentOp()));

    // If this is inside a tf_device::ReplicateOp, the variables are guaranteed
    // to be on the same device as the TPUExecute op. Skip device checking in
    // that case, but we need to check that we are only merging reads/assigns
    // that are also in this replicated region.
    if (failed(MergeForOneTPUExecute(
            execute_launch, /*check_device=*/!parent_is_replicate,
            /*check_same_region=*/parent_is_replicate, &builder))) {
      signalPassFailure();
      return;
    }
  }
}

}  // namespace

std::unique_ptr<OperationPass<FuncOp>>
CreateTPUMergeVariablesWithExecutePass() {
  return std::make_unique<TPUMergeVariablesWithExecutePass>();
}

static PassRegistration<TPUMergeVariablesWithExecutePass> pass(
    "tf-tpu-merge-variables-with-execute",
    "Merges device variable reads/updates into TPU execute nodes");

}  // namespace TFTPU
}  // namespace mlir