/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h"

#include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h"
#include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"

namespace xla {
namespace llvm_ir {

bool MayBeImplementedAsInPlaceDynamicUpdateSlice(const HloInstruction* instr) {
  // Today we can't emit a dynamic-update-slice if the DUS node is
  // parallelized; the emitter will not emit correct code.  It's possible to
  // change this, but then ParallelTaskAssigner would have to somehow know
  // whether a node *will* be emitted as an in-place DUS, and it can't, because
  // it doesn't have a buffer assignment when it runs.
  if (!instr->outer_dimension_partitions().empty()) {
    return false;
  }

  // Until we know the final buffer assignment, any unfused
  // dynamic-update-slice might be implementable as an in-place DUS.
  if (instr->opcode() == HloOpcode::kDynamicUpdateSlice) {
    return true;
  }

  // A fusion may be implementable as an in-place dynamic-update-slice if
  //  - it's a loop fusion,
  //  - dynamic-update-slice is the root of the fusion, and
  //  - operand 0 of the dynamic-update-slice is a parameter to the fusion
  //    (ignoring any get-tuple-element operations in the way).
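  //
  // For example, a fusion of the following form (a schematic HLO sketch, not
  // taken from a real module) would qualify:
  //
  //   fused_computation {
  //     p0 = f32[10] parameter(0)
  //     p1 = f32[5] parameter(1)
  //     p2 = s32[] parameter(2)
  //     ROOT dus = f32[10] dynamic-update-slice(p0, p1, p2)
  //   }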
  if (instr->IsLoopFusion()) {
    const HloInstruction* fused_root = instr->fused_expression_root();
    return fused_root->opcode() == HloOpcode::kDynamicUpdateSlice &&
           fused_root->operand(0)->LatestNonGteAncestor()->opcode() ==
               HloOpcode::kParameter;
  }

  return false;
}

bool CanUpdateDynamicSliceInPlace(HloInstruction* dynamic_update_slice,
                                  const BufferAssignment& assignment) {
  CHECK_EQ(HloOpcode::kDynamicUpdateSlice, dynamic_update_slice->opcode());
  const HloInstruction* operand = dynamic_update_slice->operand(0);
  return assignment.HasTopLevelAllocation(dynamic_update_slice) &&
         assignment.HasTopLevelAllocation(operand) &&
         assignment.SharesTopLevelSlice(dynamic_update_slice, operand);
}

bool CanEmitFusedDynamicUpdateSliceInPlace(HloInstruction* fusion,
                                           const BufferAssignment& assignment) {
  CHECK_EQ(fusion->opcode(), HloOpcode::kFusion);
  if (!MayBeImplementedAsInPlaceDynamicUpdateSlice(fusion)) {
    return false;
  }

  // Walk the dynamic-update-slice's operand(0) back to the fused parameter it
  // comes from and find the fusion operand that feeds that parameter.  Then
  // check whether the fusion's output shares an allocation with that operand.
  HloInstruction* fused_root = fusion->fused_expression_root();
  HloInstruction* fusion_operand;
  ShapeIndex index;
  std::tie(fusion_operand, index) =
      fused_root->mutable_operand(0)->LatestNonGteAncestorAndIndex();
  // MayBeImplementedAsInPlaceDynamicUpdateSlice should have ensured that
  // fusion_operand is a parameter.
  CHECK_EQ(fusion_operand->opcode(), HloOpcode::kParameter);
  auto* operand = fusion->operand(fusion_operand->parameter_number());
  return assignment.HasAllocationAt(operand, index) &&
         assignment.HasAllocationAt(fusion, {}) &&
         assignment.SharesSliceAtIndex(fusion, {}, operand, index);
}

// Shared implementation of EmitDynamicUpdateSliceInPlace and
// EmitFusedDynamicUpdateSliceInPlace.
//
// Emits a sequential loop if launch_dimensions is null.
using IndexGenerator = std::function<StatusOr<llvm::Value*>(int64)>;

static Status EmitDynamicUpdateSliceInPlaceImpl(
    const Shape& update_shape, const IndexGenerator& start_indices_generator,
    bool is_signed, ElementGenerator update_array_generator,
    const IrArray& output_array, const gpu::LaunchDimensions* launch_dimensions,
    absl::string_view name, llvm::IRBuilder<>* b) {
  const Shape& output_shape = output_array.GetShape();

  // Read start indices from start_indices_generator.
  const int64 rank = output_shape.rank();
  std::vector<llvm::Value*> start_multi_index(rank);
  for (int64 i = 0; i < rank; ++i) {
    TF_ASSIGN_OR_RETURN(start_multi_index[i], start_indices_generator(i));
    llvm::Value* output_dim_size = llvm::ConstantInt::get(
        start_multi_index[i]->getType(), output_shape.dimensions(i));
    llvm::Value* update_dim_size = llvm::ConstantInt::get(
        start_multi_index[i]->getType(), update_shape.dimensions(i));

    // Clamp the start index so that the update region fits in the operand.
    //   start_index = clamp(start_index, 0, output_dim_size - update_dim_size)
    llvm::Value* max_bound = b->CreateSub(output_dim_size, update_dim_size);
    llvm::Value* zero =
        llvm::ConstantInt::get(start_multi_index[i]->getType(), 0);
    start_multi_index[i] =
        b->CreateSelect(b->CreateICmp(is_signed ? llvm::ICmpInst::ICMP_SGE
                                                : llvm::ICmpInst::ICMP_UGE,
                                      zero, start_multi_index[i]),
                        zero, start_multi_index[i]);

    start_multi_index[i] =
        b->CreateSelect(b->CreateICmp(is_signed ? llvm::ICmpInst::ICMP_SLE
                                                : llvm::ICmpInst::ICMP_ULE,
                                      max_bound, start_multi_index[i]),
                        max_bound, start_multi_index[i]);
  }

  auto loop_body_emitter = [&](const IrArray::Index& update_index) -> Status {
    // Calculate output_index, where we'll write the value from update.  For
    // each dimension,
    //
    //   output_index[dim] = start_index[dim] + update_index[dim]
    //
    std::vector<llvm::Value*> output_multi_index(rank);
    for (int64 i = 0; i < rank; ++i) {
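      // The clamped start index may have a narrower integer type than the
      // loop induction variable, so sign-extend (or bitcast) it to the
      // induction variable's type before adding.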
      llvm::Value* start_index0 = b->CreateSExtOrBitCast(
          start_multi_index[i], update_index[i]->getType());
      output_multi_index[i] = b->CreateAdd(start_index0, update_index[i]);
    }

    // Do output[output_index] = update[update_index].
    IrArray::Index output_index(output_multi_index, output_shape,
                                b->getInt64Ty());
    TF_ASSIGN_OR_RETURN(llvm::Value * update_data,
                        update_array_generator(update_index));
    output_array.EmitWriteArrayElement(output_index, update_data, b);
    return Status::OK();
  };

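  // If launch dimensions were provided, emit a parallel GPU loop over the
  // update shape; otherwise emit a plain sequential loop.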
  if (launch_dimensions != nullptr) {
    return gpu::ParallelLoopEmitter(loop_body_emitter, update_shape,
                                    *launch_dimensions, b)
        .EmitLoop(name);
  }
  return LoopEmitter(loop_body_emitter, update_shape, b).EmitLoop(name);
}

Status EmitDynamicUpdateSliceInPlace(absl::Span<const IrArray> operand_arrays,
                                     const IrArray& output_array,
                                     absl::string_view name,
                                     llvm::IRBuilder<>* b) {
  VLOG(2) << "EmitDynamicUpdateSliceInPlace for " << name;

  // No need to use operand_arrays[0], the input array of the
  // dynamic-update-slice, because we know it aliases the op's output.
  IrArray update_array = operand_arrays[1];
  IrArray start_indices_array = operand_arrays[2];
  Shape output_shape = output_array.GetShape();
  Shape update_shape = update_array.GetShape();

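  // The start indices are scalar operands, one per output dimension, stored
  // in operand_arrays[2], operand_arrays[3], ...; each is read with an empty
  // (rank-0) index.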
  IndexGenerator start_indices_generator = [&](int64 index) {
    return operand_arrays[2 + index].EmitReadArrayElement(
        IrArray::Index(b->getInt64Ty()), b);
  };
  ElementGenerator update_array_generator = [&](const IrArray::Index& index) {
    return update_array.EmitReadArrayElement(index, b);
  };

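  // Whether the start indices are signed decides which comparison predicates
  // EmitDynamicUpdateSliceInPlaceImpl uses when clamping them.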
  bool is_signed = ShapeUtil::ElementIsSigned(start_indices_array.GetShape());
  return EmitDynamicUpdateSliceInPlaceImpl(
      update_shape, start_indices_generator, is_signed, update_array_generator,
      output_array, /*launch_dimensions=*/nullptr, name, b);
}

// Shared implementation for EmitFusedDynamicUpdateSliceInPlace and
// EmitParallelFusedDynamicUpdateSliceInPlace.
//
// Emits a sequential loop if launch_dimensions is null.
static Status EmitFusedDynamicUpdateSliceInPlaceImpl(
    const HloComputation* fusion, const IrArray& fusion_output_array,
    FusedIrEmitter* fused_emitter,
    const gpu::LaunchDimensions* launch_dimensions, llvm::IRBuilder<>* b) {
  VLOG(2) << "EmitFusedDynamicUpdateSliceInPlace for " << fusion->ToString();

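  // Callers are expected to have checked, via
  // CanEmitFusedDynamicUpdateSliceInPlace, that the fusion root is a
  // dynamic-update-slice whose operand 0 shares an allocation with the
  // fusion's output.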
  auto* dynamic_update_slice = fusion->root_instruction();

  const auto* update = dynamic_update_slice->operand(1);
  const auto* start_indices = dynamic_update_slice->operand(2);
  Shape update_shape = update->shape();

  // Our in-place dynamic-update-slice implementation emits a loop over
  // update_shape.  To emit a cache-friendly loop, we need to know that shape's
  // layout.
  //
  // update_shape is inside a fusion node -- it's never materialized in memory
  // and thus doesn't have a layout.  In this case we use the layout of the
  // fusion node for iteration, since that corresponds to the order in memory
  // of the buffer we'll be writing to.
  //
  // (This isn't necessarily optimal; in some cases it might be faster to peek
  // through the chain of ops that gives us the update operand and use the
  // layout of its source buffer(s).  But this is no worse than what we do with
  // fusion elsewhere.)
  TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
      dynamic_update_slice->shape(), &update_shape));

  // Create element generators for update and start_indices.
  TF_ASSIGN_OR_RETURN(ElementGenerator update_array_generator,
                      fused_emitter->GetGenerator(update));

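  // As in the unfused case, the start indices are scalar operands of the
  // dynamic-update-slice, beginning at operand 2; read each one through the
  // fused emitter with an empty (rank-0) index.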
  IndexGenerator start_indices_generator =
      [&](int64 index) -> StatusOr<llvm::Value*> {
    TF_ASSIGN_OR_RETURN(
        ElementGenerator element_generator,
        fused_emitter->GetGenerator(dynamic_update_slice->operand(2 + index)));
    return element_generator(IrArray::Index(b->getInt64Ty()));
  };
  bool is_signed = ShapeUtil::ElementIsSigned(start_indices->shape());
  return EmitDynamicUpdateSliceInPlaceImpl(
      update_shape, start_indices_generator, is_signed, update_array_generator,
      fusion_output_array, launch_dimensions, IrName(dynamic_update_slice), b);
}

Status EmitFusedDynamicUpdateSliceInPlace(HloInstruction* fusion,
                                          const IrArray& fusion_output_array,
                                          FusedIrEmitter* fused_emitter,
                                          llvm::IRBuilder<>* b) {
  return EmitFusedDynamicUpdateSliceInPlaceImpl(
      fusion->called_computations()[0], fusion_output_array, fused_emitter,
      /*launch_dimensions=*/nullptr, b);
}

Status EmitParallelFusedDynamicUpdateSliceInPlace(
    const HloComputation* fusion, const IrArray& fusion_output_array,
    FusedIrEmitter* fused_emitter,
    const gpu::LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* b) {
  return EmitFusedDynamicUpdateSliceInPlaceImpl(
      fusion, fusion_output_array, fused_emitter, &launch_dimensions, b);
}

}  // namespace llvm_ir
}  // namespace xla