1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h"
17 
18 #include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h"
19 #include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h"
20 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
21 #include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
22 
23 namespace xla {
24 namespace llvm_ir {
25 
MayBeImplementedAsInPlaceDynamicUpdateSlice(const HloInstruction * instr)26 bool MayBeImplementedAsInPlaceDynamicUpdateSlice(const HloInstruction* instr) {
27   // Today we can't emit a dynamic-update-slice if the DUS node is parallelized;
28   // the emitter will not emit correct code.  It's possible to change this, but
29   // then ParallelTaskAssigner would have to somehow know whether a node *will*
30   // be emitted as an in-place DUS, and it can't, because it doesn't have a
31   // buffer assignment when it runs.
32   if (!instr->outer_dimension_partitions().empty()) {
33     return false;
34   }
35 
36   // Until we know the final buffer assignment, any unfused dynamic-update-slice
37   // might be implementable as an in-place DUS.
38   if (instr->opcode() == HloOpcode::kDynamicUpdateSlice) {
39     return true;
40   }
41 
42   // A fusion may be implementable as an in-place dynamic update slice if
43   //  - it's a loop fusion,
44   //  - dynamic-update-slice is the root of the fusion, and
45   //  - operand 0 of the dynamic-update-slice is a parameter to the fusion
46   //    (ignoring any get-tuple-element operations in the way).
47   if (instr->IsLoopFusion()) {
48     const HloInstruction* fused_root = instr->fused_expression_root();
49     return fused_root->opcode() == HloOpcode::kDynamicUpdateSlice &&
50            fused_root->operand(0)->LatestNonGteAncestor()->opcode() ==
51                HloOpcode::kParameter;
52   }
53 
54   return false;
55 }
56 
CanUpdateDynamicSliceInPlace(HloInstruction * dynamic_update_slice,const BufferAssignment & assignment)57 bool CanUpdateDynamicSliceInPlace(HloInstruction* dynamic_update_slice,
58                                   const BufferAssignment& assignment) {
59   CHECK_EQ(HloOpcode::kDynamicUpdateSlice, dynamic_update_slice->opcode());
60   const HloInstruction* operand = dynamic_update_slice->operand(0);
61   return assignment.HasTopLevelAllocation(dynamic_update_slice) &&
62          assignment.HasTopLevelAllocation(operand) &&
63          assignment.SharesTopLevelSlice(dynamic_update_slice, operand);
64 }
65 
CanEmitFusedDynamicUpdateSliceInPlace(HloInstruction * fusion,const BufferAssignment & assignment)66 bool CanEmitFusedDynamicUpdateSliceInPlace(HloInstruction* fusion,
67                                            const BufferAssignment& assignment) {
68   CHECK_EQ(fusion->opcode(), HloOpcode::kFusion);
69   if (!MayBeImplementedAsInPlaceDynamicUpdateSlice(fusion)) {
70     return false;
71   }
72 
73   // Walk DynamicUpdateSlice operand(0) to fused parameter and get its
74   // associated operand. See if it shares an allocation with this operand.
75   HloInstruction* fused_root = fusion->fused_expression_root();
76   HloInstruction* fusion_operand;
77   ShapeIndex index;
78   std::tie(fusion_operand, index) =
79       fused_root->mutable_operand(0)->LatestNonGteAncestorAndIndex();
80   // MayBeImplementedAsInPlaceDynamicUpdateSlice should have ensured that
81   // fusion_operand is a parameter.
82   CHECK_EQ(fusion_operand->opcode(), HloOpcode::kParameter);
83   auto* operand = fusion->operand(fusion_operand->parameter_number());
84   return assignment.HasAllocationAt(operand, index) &&
85          assignment.HasAllocationAt(fusion, {}) &&
86          assignment.SharesSliceAtIndex(fusion, {}, operand, index);
87 }
88 
// Shared implementation of EmitDynamicUpdateSliceInPlace and
// EmitFusedDynamicUpdateSliceInPlace.
//
// Emits a sequential loop if launch_dimensions is null.

// Produces the (scalar) start index of the update region along the given
// dimension.
using IndexGenerator = std::function<StatusOr<llvm::Value*>(int64)>;

// Emits a loop over update_shape that writes each update element into
// output_array at a position offset by the (clamped) start indices.  The loop
// is parallel (gpu::ParallelLoopEmitter) when launch_dimensions is non-null
// and sequential (LoopEmitter) otherwise.  is_signed tells us whether the
// start-index comparisons should use signed or unsigned predicates.
static Status EmitDynamicUpdateSliceInPlaceImpl(
    const Shape& update_shape, const IndexGenerator& start_indices_generator,
    bool is_signed, ElementGenerator update_array_generator,
    const IrArray& output_array, const gpu::LaunchDimensions* launch_dimensions,
    absl::string_view name, llvm::IRBuilder<>* b) {
  const Shape& output_shape = output_array.GetShape();

  // Read start indices from start_indices_generator.
  const int64 rank = output_shape.rank();
  std::vector<llvm::Value*> start_multi_index(rank);
  for (int64 i = 0; i < rank; ++i) {
    TF_ASSIGN_OR_RETURN(start_multi_index[i], start_indices_generator(i));
    llvm::Value* output_dim_size = llvm::ConstantInt::get(
        start_multi_index[i]->getType(), output_shape.dimensions(i));
    llvm::Value* update_dim_size = llvm::ConstantInt::get(
        start_multi_index[i]->getType(), update_shape.dimensions(i));

    // Clamp the start index so that the update region fits in the operand.
    // start_index = clamp(start_index, 0, output_dim_size - update_dim_size)
    llvm::Value* max_bound = b->CreateSub(output_dim_size, update_dim_size);
    llvm::Value* zero =
        llvm::ConstantInt::get(start_multi_index[i]->getType(), 0);
    // Lower bound: start_index = max(start_index, 0).  Selects zero when
    // zero >= start_index, with signedness chosen by is_signed.
    start_multi_index[i] =
        b->CreateSelect(b->CreateICmp(is_signed ? llvm::ICmpInst::ICMP_SGE
                                                : llvm::ICmpInst::ICMP_UGE,
                                      zero, start_multi_index[i]),
                        zero, start_multi_index[i]);

    // Upper bound: start_index = min(start_index, max_bound).  Selects
    // max_bound when max_bound <= start_index.
    start_multi_index[i] =
        b->CreateSelect(b->CreateICmp(is_signed ? llvm::ICmpInst::ICMP_SLE
                                                : llvm::ICmpInst::ICMP_ULE,
                                      max_bound, start_multi_index[i]),
                        max_bound, start_multi_index[i]);
  }

  auto loop_body_emitter = [&](const IrArray::Index& update_index) -> Status {
    // Calculate output_index, where we'll write the value from update.  For
    // each dimension,
    //
    //   output_index[dim] = start_index[dim] + update_index[dim]
    //
    std::vector<llvm::Value*> output_multi_index(rank);
    for (int64 i = 0; i < rank; ++i) {
      // Widen (or bitcast) the clamped start index to the loop index type
      // before adding it to the update index.
      llvm::Value* start_index0 = b->CreateSExtOrBitCast(
          start_multi_index[i], update_index[i]->getType());
      output_multi_index[i] = b->CreateAdd(start_index0, update_index[i]);
    }

    // Do output[output_index] = update[update_index].
    IrArray::Index output_index(output_multi_index, output_shape,
                                b->getInt64Ty());
    TF_ASSIGN_OR_RETURN(llvm::Value * update_data,
                        update_array_generator(update_index));
    output_array.EmitWriteArrayElement(output_index, update_data, b);
    return Status::OK();
  };

  if (launch_dimensions != nullptr) {
    return gpu::ParallelLoopEmitter(loop_body_emitter, update_shape,
                                    *launch_dimensions, b)
        .EmitLoop(name);
  }
  return LoopEmitter(loop_body_emitter, update_shape, b).EmitLoop(name);
}
159 
EmitDynamicUpdateSliceInPlace(absl::Span<const IrArray> operand_arrays,const IrArray & output_array,absl::string_view name,llvm::IRBuilder<> * b)160 Status EmitDynamicUpdateSliceInPlace(absl::Span<const IrArray> operand_arrays,
161                                      const IrArray& output_array,
162                                      absl::string_view name,
163                                      llvm::IRBuilder<>* b) {
164   VLOG(2) << "EmitDynamicUpdateSliceInPlace for " << name;
165 
166   // No need to use operand_arrays[0], the input array of the
167   // dynamic-update-slice, because we know it aliases the op's output.
168   IrArray update_array = operand_arrays[1];
169   IrArray start_indices_array = operand_arrays[2];
170   Shape output_shape = output_array.GetShape();
171   Shape update_shape = update_array.GetShape();
172 
173   IndexGenerator start_indices_generator = [&](int64 index) {
174     return operand_arrays[2 + index].EmitReadArrayElement(
175         IrArray::Index(b->getInt64Ty()), b);
176   };
177   ElementGenerator update_array_generator = [&](const IrArray::Index& index) {
178     return update_array.EmitReadArrayElement(index, b);
179   };
180 
181   bool is_signed = ShapeUtil::ElementIsSigned(start_indices_array.GetShape());
182   return EmitDynamicUpdateSliceInPlaceImpl(
183       update_shape, start_indices_generator, is_signed, update_array_generator,
184       output_array, /*launch_dimensions=*/nullptr, name, b);
185 }
186 
187 // Shared implementation for EmitFusedDynamicUpdateSliceInPlace and
188 // EmitParallelFusedDynamicUpdateSliceInPlace.
189 //
190 // Emits a sequential loop if launch_dimensions is null.
EmitFusedDynamicUpdateSliceInPlaceImpl(const HloComputation * fusion,const IrArray & fusion_output_array,FusedIrEmitter * fused_emitter,const gpu::LaunchDimensions * launch_dimensions,llvm::IRBuilder<> * b)191 static Status EmitFusedDynamicUpdateSliceInPlaceImpl(
192     const HloComputation* fusion, const IrArray& fusion_output_array,
193     FusedIrEmitter* fused_emitter,
194     const gpu::LaunchDimensions* launch_dimensions, llvm::IRBuilder<>* b) {
195   VLOG(2) << "EmitFusedDynamicUpdateSliceInPlace for " << fusion->ToString();
196 
197   auto* dynamic_update_slice = fusion->root_instruction();
198 
199   const auto* update = dynamic_update_slice->operand(1);
200   const auto* start_indices = dynamic_update_slice->operand(2);
201   Shape update_shape = update->shape();
202 
203   // Our in-place dynamic-update-slice implementation emits a loop over
204   // update_shape.  To emit a cache-friendly loop, we need to know that shape's
205   // layout.
206   //
207   // update_shape is inside a fusion node -- it's never materialized in memory
208   // and thus doesn't have a layout.  In this case we use the layout of the
209   // fusion node for iteration, since that corresponds to the order in memory of
210   // the buffer we'll be writing to.
211   //
212   // (This isn't necessarily optimal; in some cases it might be faster to peek
213   // through the chain of ops that gives us the update operand and use the
214   // layout of its source buffer(s).  But this is no worse than we do with
215   // fusion elsewhere.)
216   TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
217       dynamic_update_slice->shape(), &update_shape));
218 
219   // Create element generators for update and start_indices.
220   TF_ASSIGN_OR_RETURN(ElementGenerator update_array_generator,
221                       fused_emitter->GetGenerator(update));
222 
223   IndexGenerator start_indices_generator =
224       [&](int64 index) -> StatusOr<llvm::Value*> {
225     TF_ASSIGN_OR_RETURN(
226         ElementGenerator element_generator,
227         fused_emitter->GetGenerator(dynamic_update_slice->operand(2 + index)));
228     return element_generator(IrArray::Index(b->getInt64Ty()));
229   };
230   bool is_signed = ShapeUtil::ElementIsSigned(start_indices->shape());
231   return EmitDynamicUpdateSliceInPlaceImpl(
232       update_shape, start_indices_generator, is_signed, update_array_generator,
233       fusion_output_array, launch_dimensions, IrName(dynamic_update_slice), b);
234 }
235 
EmitFusedDynamicUpdateSliceInPlace(HloInstruction * fusion,const IrArray & fusion_output_array,FusedIrEmitter * fused_emitter,llvm::IRBuilder<> * b)236 Status EmitFusedDynamicUpdateSliceInPlace(HloInstruction* fusion,
237                                           const IrArray& fusion_output_array,
238                                           FusedIrEmitter* fused_emitter,
239                                           llvm::IRBuilder<>* b) {
240   return EmitFusedDynamicUpdateSliceInPlaceImpl(
241       fusion->called_computations()[0], fusion_output_array, fused_emitter,
242       /*launch_dimensions=*/nullptr, b);
243 }
244 
EmitParallelFusedDynamicUpdateSliceInPlace(const HloComputation * fusion,const IrArray & fusion_output_array,FusedIrEmitter * fused_emitter,const gpu::LaunchDimensions & launch_dimensions,llvm::IRBuilder<> * b)245 Status EmitParallelFusedDynamicUpdateSliceInPlace(
246     const HloComputation* fusion, const IrArray& fusion_output_array,
247     FusedIrEmitter* fused_emitter,
248     const gpu::LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* b) {
249   return EmitFusedDynamicUpdateSliceInPlaceImpl(
250       fusion, fusion_output_array, fused_emitter, &launch_dimensions, b);
251 }
252 
253 }  // namespace llvm_ir
254 }  // namespace xla
255