/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <utility>

#include "absl/algorithm/container.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/gather_expander.h"
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/while_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"

namespace xla {

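// Transposes start_indices so that index_vector_dim becomes the most-minor
// (last) dimension, leaving the other dimensions in their original order.
// For example (illustrative shapes): start_indices of shape s32[10,9,8,7,5]
// with index_vector_dim=2 is transposed to s32[10,9,7,5,8]. If
// index_vector_dim is already the last dimension, or equals the rank of
// start_indices (i.e. the indices are scalars), this is a no-op.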
static StatusOr<HloInstruction*> TransposeIndexVectorDimToLast(
    HloInstruction* start_indices, int64 index_vector_dim) {
  const Shape& start_indices_shape = start_indices->shape();

  if (start_indices_shape.dimensions_size() == index_vector_dim) {
    return start_indices;
  }

  if (index_vector_dim == (start_indices_shape.dimensions_size() - 1)) {
    return start_indices;
  }

  std::vector<int64> permutation;
  permutation.reserve(start_indices_shape.dimensions_size());
  for (int64 i = 0, e = start_indices_shape.dimensions_size(); i < e; i++) {
    if (i != index_vector_dim) {
      permutation.push_back(i);
    }
  }
  permutation.push_back(index_vector_dim);
  return MakeTransposeHlo(start_indices, permutation);
}

// Canonicalizes the start_indices tensor so that we only have to deal with
// some specific cases in the while loop that does the heavy lifting.
//
// See the "High Level Algorithm" section for a broader picture.
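//
// For example (illustrative shapes): start_indices of shape s32[2,3,4] with
// index_vector_dim=1 is transposed to s32[2,4,3] and then collapsed to
// s32[8,3], i.e. a matrix of 8 three-element index vectors. Scalar indices
// (index_vector_dim == rank), e.g. of shape s32[2,2], are collapsed to
// s32[4], a vector of 4 scalar indices.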
static StatusOr<HloInstruction*> CanonicalizeGatherIndices(
    HloInstruction* start_indices, int64 index_vector_dim) {
  // Transpose the non-index-vector dimensions to the front.
  TF_ASSIGN_OR_RETURN(
      HloInstruction * transposed_start_indices,
      TransposeIndexVectorDimToLast(start_indices, index_vector_dim));
  bool indices_are_scalar =
      index_vector_dim == start_indices->shape().dimensions_size();

  // The number of dimensions in start_indices that are index dimensions.
  const int64 index_dims_in_start_indices = indices_are_scalar ? 0 : 1;

  // If there is only one index (i.e. start_indices has rank 1 and this gather
  // is really just a dynamic slice) add a leading degenerate dimension for
  // uniformity. Otherwise create a "collapsed" leading dimension that subsumes
  // all of the non-index-vector dimensions.
  const Shape& shape = transposed_start_indices->shape();
  if (shape.dimensions_size() == index_dims_in_start_indices) {
    return PrependDegenerateDims(transposed_start_indices, 1);
  } else {
    // Collapse all but the dimensions (0 or 1) in start_indices containing the
    // index vectors.
    return CollapseFirstNDims(
        transposed_start_indices,
        shape.dimensions_size() - index_dims_in_start_indices);
  }
}

// Expands out or contracts away the gather dimensions in the accumulator
// produced by the while loop.
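//
// For example (illustrative shapes): if start_indices was s32[2,3,5] with
// index_vector_dim=2, the accumulator's leading dimension of size 6 is
// expanded back into dimensions [2,3]. If start_indices contained a single
// index (no batch dimensions), the degenerate leading dimension added by
// CanonicalizeGatherIndices is removed instead.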
static StatusOr<HloInstruction*> AdjustBatchDimsInAccumulator(
    const Shape& start_indices_shape, HloInstruction* accumulator,
    int64 index_vector_dim) {
  std::vector<int64> batch_dim_bounds;
  batch_dim_bounds.reserve(start_indices_shape.dimensions_size());
  for (int64 i = 0, e = start_indices_shape.dimensions_size(); i < e; i++) {
    if (i != index_vector_dim) {
      batch_dim_bounds.push_back(start_indices_shape.dimensions(i));
    }
  }

  if (batch_dim_bounds.empty()) {
    // If batch_dim_bounds is empty we must be lowering an (effectively)
    // dynamic-slice. In that case, there is a leading degenerate gather
    // dimension that we added to make this special case play well with the
    // general while loop; remove it now.
    return ElideDegenerateDims(accumulator, {0});
  }

  return ExpandFirstDimIntoNDims(accumulator, batch_dim_bounds);
}

// Expands an index vector from the start_indices tensor into a vector that
// can be used to dynamic-slice out of the gather operand.
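//
// The expanded vector has one component per operand dimension: dimensions
// named in start_index_map take the corresponding component of the incoming
// index vector, and all other dimensions get a zero. For example
// (illustrative): with start_index_map={1} and operand_rank=2, an index
// vector [i] expands to [0, i].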
static StatusOr<HloInstruction*> ExpandIndexVectorIntoOperandSpace(
    HloInstruction* index_vector, const GatherDimensionNumbers& dim_numbers,
    int64 operand_rank) {
  HloComputation* computation = index_vector->parent();
  const Shape& index_shape = index_vector->shape();

  if (operand_rank == 0) {
    // This is a gather from a scalar, so the index vector in operand space
    // must be a zero-sized vector.
    return computation->AddInstruction(HloInstruction::CreateConstant(
        LiteralUtil::CreateFromDimensions(index_shape.element_type(), {0})));
  }

  HloInstruction* zero =
      computation->AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateFromDimensions(index_shape.element_type(), {1})));

  // We extract individual components from the smaller index and concatenate
  // them (interspersing zeros as needed) into the larger index.
  std::vector<HloInstruction*> expanded_index_components;

  for (int i = 0; i < operand_rank; i++) {
    int64 index_vector_dim_index = FindIndex(dim_numbers.start_index_map(), i);
    if (index_vector_dim_index != dim_numbers.start_index_map_size()) {
      TF_ASSIGN_OR_RETURN(
          HloInstruction * component_to_concat,
          MakeSliceHlo(index_vector, /*start_indices=*/{index_vector_dim_index},
                       /*limit_indices=*/{index_vector_dim_index + 1},
                       /*strides=*/{1}));
      expanded_index_components.push_back(component_to_concat);
    } else {
      expanded_index_components.push_back(zero);
    }
  }

  return MakeConcatHlo(expanded_index_components, /*dimension=*/0);
}

// This generates the body of the while loop that implements the main data
// movement behavior of gather using dynamic-slice and dynamic-update-slice.
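//
// The loop state is {operand, start_indices, output_accumulator}; only the
// accumulator changes from one iteration to the next. On iteration i the body
// reads index vector i out of the canonicalized start_indices, dynamic-slices
// the corresponding window out of operand, and dynamic-update-slices it into
// position i of the accumulator.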
static StatusOr<std::vector<HloInstruction*>> GatherLoopBody(
    const HloInstruction& gather, HloInstruction* induction_var,
    const std::vector<HloInstruction*>& incoming_loop_state) {
  const GatherDimensionNumbers& dim_numbers = gather.gather_dimension_numbers();
  CHECK_EQ(incoming_loop_state.size(), 3);
  HloInstruction* const operand = incoming_loop_state[0];
  HloInstruction* const start_indices = incoming_loop_state[1];
  HloInstruction* const output_accumulator = incoming_loop_state[2];

  bool has_scalar_indices = start_indices->shape().dimensions_size() == 1;
  CHECK_EQ(has_scalar_indices,
           dim_numbers.index_vector_dim() ==
               gather.operand(1)->shape().dimensions_size());

  HloInstruction* induction_var_as_vector =
      MakeBroadcastHlo(induction_var, /*broadcast_dimensions=*/{},
                       /*result_shape_bounds=*/{1});

  HloInstruction* index_vector;

  if (has_scalar_indices) {
    // In this case start_indices has rank 1 and induction_var_as_vector (of
    // shape {1}) is an index into this rank 1 tensor.
    TF_ASSIGN_OR_RETURN(
        index_vector,
        MakeDynamicSliceHlo(start_indices, induction_var_as_vector, {1}));
  } else {
    // In this case start_indices has rank 2 and induction_var_as_vector (of
    // shape {1}) is an index into just the first dimension of this rank 2
    // tensor.
    TF_ASSIGN_OR_RETURN(
        HloInstruction * index_into_start_indices,
        PadVectorWithZeros(induction_var_as_vector,
                           /*zeros_to_prepend=*/0, /*zeros_to_append=*/1));

    int64 index_vector_size = start_indices->shape().dimensions(1);
    TF_ASSIGN_OR_RETURN(
        HloInstruction * index_vector_2d,
        MakeDynamicSliceHlo(start_indices, index_into_start_indices,
                            {1, index_vector_size}));

    TF_ASSIGN_OR_RETURN(index_vector,
                        ElideDegenerateDims(index_vector_2d, {0}));
  }

  TF_ASSIGN_OR_RETURN(
      HloInstruction * gathered_slice_start,
      ExpandIndexVectorIntoOperandSpace(index_vector, dim_numbers,
                                        operand->shape().dimensions_size()));

  TF_ASSIGN_OR_RETURN(HloInstruction * gathered_slice,
                      MakeDynamicSliceHlo(operand, gathered_slice_start,
                                          gather.gather_slice_sizes()));

  TF_ASSIGN_OR_RETURN(
      HloInstruction* const gathered_slice_with_dims_collapsed,
      ElideDegenerateDims(gathered_slice,
                          AsInt64Slice(dim_numbers.collapsed_slice_dims())));

  TF_ASSIGN_OR_RETURN(
      HloInstruction* const gathered_slice_for_update,
      PrependDegenerateDims(gathered_slice_with_dims_collapsed, 1));

  TF_ASSIGN_OR_RETURN(
      HloInstruction* const index_vector_into_accumulator,
      PadVectorWithZeros(
          induction_var_as_vector, /*zeros_to_prepend=*/0,
          /*zeros_to_append=*/
          gathered_slice_with_dims_collapsed->shape().dimensions_size()));

  TF_ASSIGN_OR_RETURN(
      HloInstruction* const updated_accumulator,
      MakeDynamicUpdateSliceHlo(output_accumulator, gathered_slice_for_update,
                                index_vector_into_accumulator));

  // New loop state -- only the accumulator has changed. The
  // WhileUtil::MakeCountedLoop function takes care of the induction variable
  // and the while loop exit condition.
  return StatusOr<std::vector<HloInstruction*>>{
      {operand, start_indices, updated_accumulator}};
}

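// Creates the zero-initialized accumulator that the gather loop writes into.
// Its shape is the gather loop trip count followed by the slice sizes, with
// the collapsed_slice_dims dropped; e.g. (illustrative) a trip count of 4
// with slice_sizes={3,1} and collapsed_slice_dims={1} yields an s32[4,3]
// accumulator for an s32 gather.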
static HloInstruction* CreateGatherLoopAccumulatorInitValue(
    HloComputation* computation, PrimitiveType element_type,
    absl::Span<const int64> slice_sizes, int64 gather_loop_trip_count,
    const GatherDimensionNumbers& dim_numbers) {
  std::vector<int64> accumulator_state_shape_dims;
  accumulator_state_shape_dims.reserve(1 + slice_sizes.size());
  accumulator_state_shape_dims.push_back(gather_loop_trip_count);
  for (int64 i = 0; i < slice_sizes.size(); i++) {
    if (!absl::c_binary_search(dim_numbers.collapsed_slice_dims(), i)) {
      accumulator_state_shape_dims.push_back(slice_sizes[i]);
    }
  }
  return BroadcastZeros(computation, element_type,
                        accumulator_state_shape_dims);
}

// `accumulator` is almost the tensor the gather operation would have produced,
// except that it has the dimensions in the wrong order -- the batch dimensions
// are the major dimensions and the offset dimensions are the minor dimensions.
// Fix this up with a transpose.
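//
// For example (illustrative): with output_rank=3 and offset_dims={1}, the
// batch dimensions occupy the two major positions of `accumulator` and the
// offset dimension the minor one, so the permutation computed below is
// {0, 2, 1}.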
static StatusOr<HloInstruction*> PermuteBatchAndOffsetDims(
    HloInstruction* accumulator, absl::Span<const int64> offset_dims,
    int64 output_rank) {
  std::vector<int64> permutation;
  permutation.reserve(output_rank);

  int64 batch_idx_counter = 0;
  int64 offset_idx_counter = output_rank - offset_dims.size();
  for (int64 i = 0; i < output_rank; i++) {
    bool is_offset_dim = absl::c_binary_search(offset_dims, i);
    if (is_offset_dim) {
      permutation.push_back(offset_idx_counter++);
    } else {
      permutation.push_back(batch_idx_counter++);
    }
  }

  return MakeTransposeHlo(accumulator, permutation);
}

// High Level Algorithm
//
// We perform the following steps in sequence:
//
//  1. We canonicalize the start_indices tensor so that it becomes a matrix
//     in which each row is an index vector into the operand (or, when the
//     indices are scalars, a vector of scalar indices).
//  2. We iterate over the set of indices in the canonicalized
//     start_indices tensor using a while loop, accumulating slices
//     of the operand tensor into an accumulator using
//     DynamicUpdateSlice.
//  3. The accumulator result from the while loop from (2) is then
//     reshaped to split out all the individual gather dimensions and
//     then transposed to give the final result.
//
// As an example, if we started with the following operation:
//
//   HloModule TensorFlowGatherMultipleBatchDims
//
//   ENTRY main {
//     operand = s32[3,3] parameter(0)
//     indices = s32[2,2] parameter(1)
//     ROOT gather = s32[2,3,2] gather(operand, indices),
//         offset_dims={1},
//         collapsed_slice_dims={1},
//         start_index_map={1},
//         index_vector_dim=2,
//         slice_sizes={3, 1}
//   }
//
// We'd first canonicalize indices to s32[4], where each element is a row
// index into operand. We'd then run a loop to slice 4 tensors of shape [3,1]
// out of operand; after collapsing the collapsed_slice_dims these become [3]
// slices accumulated into an accumulator of shape [4,3]. We then reshape this
// result to [2,2,3] and finally transpose it to [2,3,2].

StatusOr<HloInstruction*> GatherExpander::ExpandInstruction(
    HloInstruction* gather_instr) {
  CHECK(!ShapeUtil::IsZeroElementArray(gather_instr->shape()));

  HloComputation* computation = gather_instr->parent();
  HloInstruction* operand = gather_instr->mutable_operand(0);
  HloInstruction* start_indices = gather_instr->mutable_operand(1);
  const Shape& start_indices_shape = start_indices->shape();
  const Shape& output_shape = gather_instr->shape();
  int64 output_rank = output_shape.dimensions_size();

  const GatherDimensionNumbers& dim_numbers =
      gather_instr->gather_dimension_numbers();

  int64 gather_loop_trip_count = 1;
  for (int64 i = 0, e = start_indices_shape.dimensions_size(); i < e; i++) {
    if (i != dim_numbers.index_vector_dim()) {
      gather_loop_trip_count *= start_indices_shape.dimensions(i);
    }
  }

  if (!IsInt32(gather_loop_trip_count)) {
    return Unimplemented(
        "Gather operations with more than 2147483647 gather indices are not "
        "supported. This error occurred for %s.",
        gather_instr->ToString());
  }

  TF_ASSIGN_OR_RETURN(
      HloInstruction * canonical_start_indices,
      CanonicalizeGatherIndices(start_indices, dim_numbers.index_vector_dim()));

  CHECK_EQ(gather_loop_trip_count,
           canonical_start_indices->shape().dimensions(0));

  HloInstruction* accumulator_init = CreateGatherLoopAccumulatorInitValue(
      computation, output_shape.element_type(),
      gather_instr->gather_slice_sizes(), gather_loop_trip_count,
      gather_instr->gather_dimension_numbers());

  StatusOr<std::vector<HloInstruction*>> gather_loop_result_or_error =
      WhileUtil::MakeCountedLoop(
          computation, gather_loop_trip_count,
          {operand, canonical_start_indices, accumulator_init},
          [&](HloInstruction* indvar,
              const std::vector<HloInstruction*>& loop_state) {
            return GatherLoopBody(*gather_instr, indvar, loop_state);
          },
          gather_instr->metadata());

  TF_ASSIGN_OR_RETURN(std::vector<HloInstruction*> gather_loop_result,
                      gather_loop_result_or_error);

  HloInstruction* accumulator_result = gather_loop_result.back();

  TF_ASSIGN_OR_RETURN(
      HloInstruction* const accumulator_with_batch_dims_decanonicalized,
      AdjustBatchDimsInAccumulator(start_indices->shape(), accumulator_result,
                                   dim_numbers.index_vector_dim()));

  return PermuteBatchAndOffsetDims(accumulator_with_batch_dims_decanonicalized,
                                   AsInt64Slice(dim_numbers.offset_dims()),
                                   output_rank);
}

bool GatherExpander::InstructionMatchesPattern(HloInstruction* inst) {
  return inst->opcode() == HloOpcode::kGather &&
         // Avoid expanding gather ops that produce zero-sized tensors;
         // instead, punt these to ZeroSizedHloElimination.
         !ShapeUtil::IsZeroElementArray(inst->shape());
}

}  // namespace xla