1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/xla/experimental/conv_emitter/conv_emitter_transforms.h"
17
18 #include "absl/algorithm/container.h"
19 #include "llvm/ADT/StringRef.h"
20 #include "mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project
21 #include "mlir/Transforms/LoopUtils.h" // from @llvm-project
22 #include "tensorflow/core/platform/logging.h"
23
24 namespace xla {
25 namespace experimental {
26
27 using mlir::OpBuilder;
28
GetBoundAffineMapFrom(mlir::Operation * op)29 BoundAffineMap GetBoundAffineMapFrom(mlir::Operation* op) {
30 if (auto load = mlir::dyn_cast<mlir::AffineLoadOp>(op)) {
31 return {load.getAffineMap(),
32 std::vector<mlir::Value>(load.getMapOperands().begin(),
33 load.getMapOperands().end())};
34 } else if (auto store = mlir::dyn_cast<mlir::AffineStoreOp>(op)) {
35 return {store.getAffineMap(),
36 std::vector<mlir::Value>(store.getMapOperands().begin(),
37 store.getMapOperands().end())};
38 } else {
39 CHECK(false);
40 }
41 }
42
CloneWithNewAffineMap(mlir::Operation * op,BoundAffineMap new_affine,OpBuilder builder)43 mlir::Operation* CloneWithNewAffineMap(mlir::Operation* op,
44 BoundAffineMap new_affine,
45 OpBuilder builder) {
46 if (auto load = mlir::dyn_cast<mlir::AffineLoadOp>(op)) {
47 return builder.create<mlir::AffineLoadOp>(
48 builder.getUnknownLoc(), load.getMemRef(), new_affine.affine_map,
49 new_affine.operands);
50 } else if (auto store = mlir::dyn_cast<mlir::AffineStoreOp>(op)) {
51 return builder.create<mlir::AffineStoreOp>(
52 builder.getUnknownLoc(), store.getValueToStore(), store.getMemRef(),
53 new_affine.affine_map, new_affine.operands);
54 } else {
55 CHECK(false);
56 }
57 }
58
IsSimpleLoop(mlir::AffineForOp loop)59 bool IsSimpleLoop(mlir::AffineForOp loop) {
60 return loop.getLowerBoundMap().isSingleConstant() &&
61 loop.getLowerBoundMap().getSingleConstantResult() == 0 &&
62 loop.getStep() == 1 && loop.getUpperBoundMap().getNumResults() == 1 &&
63 std::next(loop.region().begin()) == loop.region().end();
64 }
65
CreateNestedSimpleLoops(absl::Span<const int64_t> upper_bounds,OpBuilder builder)66 std::vector<mlir::AffineForOp> CreateNestedSimpleLoops(
67 absl::Span<const int64_t> upper_bounds, OpBuilder builder) {
68 std::vector<mlir::AffineForOp> loops;
69 loops.reserve(upper_bounds.size());
70 for (int64_t dim : upper_bounds) {
71 auto loop =
72 builder.create<mlir::AffineForOp>(builder.getUnknownLoc(), 0, dim);
73 loops.push_back(loop);
74 builder = OpBuilder::atBlockTerminator(loop.getBody());
75 }
76 return loops;
77 }
78
SetBoundForSimpleLoop(mlir::AffineForOp loop,mlir::AffineExpr new_bound,OpBuilder builder)79 void SetBoundForSimpleLoop(mlir::AffineForOp loop, mlir::AffineExpr new_bound,
80 OpBuilder builder) {
81 CHECK(IsSimpleLoop(loop));
82
83 loop.setUpperBoundMap(mlir::AffineMap::get(
84 loop.getUpperBoundMap().getNumDims(),
85 loop.getUpperBoundMap().getNumSymbols(), {new_bound}));
86 }
87
// Tiles `loop` by `size`: creates a new inner loop of trip count `size`
// inside `target` (which must be perfectly nested under `loop`), shrinks
// `loop`'s trip count by `size`, and rewrites every affine load/store that
// used `loop`'s induction variable to index with `outer * size + inner`.
// Returns the newly created inner loop. `loop`'s extent must be a constant
// divisible by `size`.
mlir::AffineForOp TileLoop(mlir::AffineForOp loop, int64_t size,
                           mlir::AffineForOp target) {
  CHECK(IsSimpleLoop(loop));
  CHECK(IsSimpleLoop(target));
  {
    // `target` must be part of the perfect nest rooted at `loop`.
    llvm::SmallVector<mlir::AffineForOp, 4> all_loops;
    getPerfectlyNestedLoops(all_loops, loop);
    CHECK(absl::c_linear_search(all_loops, target));
  }

  auto builder = OpBuilder::atBlockTerminator(target.getBody());

  auto inner_loop =
      builder.create<mlir::AffineForOp>(builder.getUnknownLoc(), 0, size);
  {
    auto& inner_operations = inner_loop.getBody()->getOperations();
    auto& target_operations = target.getBody()->getOperations();

    // Move target's body into the new inner loop, leaving behind the last two
    // operations: the just-created inner loop itself and target's terminator.
    inner_operations.splice(inner_operations.begin(), target_operations,
                            target_operations.begin(),
                            std::prev(target_operations.end(), 2));

    // The outer loop's constant extent must divide evenly by the tile size;
    // its trip count is then reduced to extent / size.
    mlir::AffineExpr length = loop.getUpperBoundMap().getResult(0);
    CHECK_EQ(0, length.cast<mlir::AffineConstantExpr>().getValue() % size);
    SetBoundForSimpleLoop(loop, length.ceilDiv(size), builder);
  }

  // Rewrite all users of the outer induction variable. early_inc_range is
  // required because each iteration erases the current user.
  for (auto& use :
       llvm::make_early_inc_range(loop.getInductionVar().getUses())) {
    mlir::Operation* owner = use.getOwner();
    // NOTE(review): GetBoundAffineMapFrom CHECK-fails unless `owner` is an
    // affine load/store — callers presumably guarantee that; confirm.
    BoundAffineMap affine_map = GetBoundAffineMapFrom(owner);
    // Append the inner induction variable as a fresh trailing dimension.
    unsigned new_dim = affine_map.operands.size();
    affine_map.operands.push_back(inner_loop.getInductionVar());
    std::vector<mlir::AffineExpr> replacements;
    for (int i = 0; i < affine_map.affine_map.getNumDims(); i++) {
      if (affine_map.operands[i] == loop.getInductionVar()) {
        // Occurrences of the outer iv become outer * size + inner.
        replacements.push_back(builder.getAffineDimExpr(i) * size +
                               builder.getAffineDimExpr(new_dim));
      } else {
        replacements.push_back(builder.getAffineDimExpr(i));
      }
    }
    // Rebuild the map over the widened dimension list (no symbols).
    affine_map.affine_map = affine_map.affine_map.replaceDimsAndSymbols(
        replacements, {}, affine_map.operands.size(), 0);
    // Clone with the new map, redirect users, then drop the original op.
    auto new_op = CloneWithNewAffineMap(owner, affine_map, OpBuilder(owner));
    owner->replaceAllUsesWith(new_op);
    owner->erase();
  }
  return inner_loop;
}
138
SinkPerfectlyNestedLoops(llvm::MutableArrayRef<mlir::AffineForOp> loops,int rotate_amount)139 void SinkPerfectlyNestedLoops(llvm::MutableArrayRef<mlir::AffineForOp> loops,
140 int rotate_amount) {
141 CHECK_GE(rotate_amount, 0);
142 std::vector<unsigned> permutation(loops.size());
143 std::iota(permutation.begin(), permutation.end(), unsigned(0));
144 std::rotate(permutation.begin(),
145 permutation.begin() + loops.size() - rotate_amount,
146 permutation.end());
147 mlir::permuteLoops(loops, permutation);
148 }
149
150 } // namespace experimental
151 } // namespace xla
152