/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <utility>

#include "absl/algorithm/container.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/gather_expander.h"
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/while_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"

namespace xla {

static StatusOr<HloInstruction*> TransposeIndexVectorDimToLast(
    HloInstruction* start_indices, int64 index_vector_dim) {
  const Shape& start_indices_shape = start_indices->shape();

  if (start_indices_shape.dimensions_size() == index_vector_dim) {
    return start_indices;
  }

  if (index_vector_dim == (start_indices_shape.dimensions_size() - 1)) {
    return start_indices;
  }

  std::vector<int64> permutation;
  permutation.reserve(start_indices_shape.dimensions_size());
  for (int64 i = 0, e = start_indices_shape.dimensions_size(); i < e; i++) {
    if (i != index_vector_dim) {
      permutation.push_back(i);
    }
  }
  permutation.push_back(index_vector_dim);
  return MakeTransposeHlo(start_indices, permutation);
}

// Canonicalizes the start_indices tensor so that we only have to deal with
// some specific cases in the while loop that does the heavy lifting.
//
// See the "High Level Algorithm" section for a broader picture.
static StatusOr<HloInstruction*> CanonicalizeGatherIndices(
    HloInstruction* start_indices, int64 index_vector_dim) {
  // Transpose the non-index-vector dimensions to the front.
  TF_ASSIGN_OR_RETURN(
      HloInstruction * transposed_start_indices,
      TransposeIndexVectorDimToLast(start_indices, index_vector_dim));
  bool indices_are_scalar =
      index_vector_dim == start_indices->shape().dimensions_size();

  // The number of dimensions in start_indices that are index dimensions.
  const int64 index_dims_in_start_indices = indices_are_scalar ? 0 : 1;

  // If there is only one index (i.e. start_indices has rank 1 and this gather
  // is really just a dynamic slice) add a leading degenerate dimension for
  // uniformity.  Otherwise create a "collapsed" leading dimension that subsumes
  // all of the non-index-vector dimensions.
  const Shape& shape = transposed_start_indices->shape();
  if (shape.dimensions_size() == index_dims_in_start_indices) {
    return PrependDegenerateDims(transposed_start_indices, 1);
  } else {
    // Collapse all but the dimensions (0 or 1) in start_indices containing the
    // index vectors.
    return CollapseFirstNDims(
        transposed_start_indices,
        shape.dimensions_size() - index_dims_in_start_indices);
  }
}

// Expands out or contracts away the gather dimensions in the accumulator
// produced by the while loop.
static StatusOr<HloInstruction*> AdjustBatchDimsInAccumulator(
    const Shape& start_indices_shape, HloInstruction* accumulator,
    int64 index_vector_dim) {
  std::vector<int64> batch_dim_bounds;
  batch_dim_bounds.reserve(start_indices_shape.dimensions_size());
  for (int64 i = 0, e = start_indices_shape.dimensions_size(); i < e; i++) {
    if (i != index_vector_dim) {
      batch_dim_bounds.push_back(start_indices_shape.dimensions(i));
    }
  }

  if (batch_dim_bounds.empty()) {
    // If batch_dim_bounds is empty we must be lowering a gather that is
    // effectively a dynamic-slice.  In that case, there is a leading
    // degenerate gather dimension that we added to make this special case
    // play well with the general while loop; remove it now.
    return ElideDegenerateDims(accumulator, {0});
  }

  return ExpandFirstDimIntoNDims(accumulator, batch_dim_bounds);
}

// Expand an index vector from the start_indices tensor into a vector that can
// be used to dynamic-slice out of the gather operand.
static StatusOr<HloInstruction*> ExpandIndexVectorIntoOperandSpace(
    HloInstruction* index_vector, const GatherDimensionNumbers& dim_numbers,
    int64 operand_rank) {
  HloComputation* computation = index_vector->parent();
  const Shape& index_shape = index_vector->shape();

  if (operand_rank == 0) {
    // This is a gather from a scalar.  So, the index vector in operand space
    // must be a zero-sized vector.
    return computation->AddInstruction(HloInstruction::CreateConstant(
        LiteralUtil::CreateFromDimensions(index_shape.element_type(), {0})));
  }

  HloInstruction* zero =
      computation->AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateFromDimensions(index_shape.element_type(), {1})));

  // We extract out individual components from the smaller index and concatenate
  // them (interspersing zeros as needed) into the larger index.
  std::vector<HloInstruction*> expanded_index_components;

  for (int i = 0; i < operand_rank; i++) {
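    // If operand dimension i appears in start_index_map, FindIndex returns
    // its position in the map; otherwise it returns start_index_map_size()
    // and the start index along dimension i is implicitly zero.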
    int64 index_vector_dim_index = FindIndex(dim_numbers.start_index_map(), i);
    if (index_vector_dim_index != dim_numbers.start_index_map_size()) {
      TF_ASSIGN_OR_RETURN(
          HloInstruction * component_to_concat,
          MakeSliceHlo(index_vector, /*start_indices=*/{index_vector_dim_index},
                       /*limit_indices=*/{index_vector_dim_index + 1},
                       /*strides=*/{1}));
      expanded_index_components.push_back(component_to_concat);
    } else {
      expanded_index_components.push_back(zero);
    }
  }

  return MakeConcatHlo(expanded_index_components, /*dimension=*/0);
}

// This generates the body of the while that implements the main data movement
// behavior of gather using dynamic-slice and dynamic-update-slice.
static StatusOr<std::vector<HloInstruction*>> GatherLoopBody(
    const HloInstruction& gather, HloInstruction* induction_var,
    const std::vector<HloInstruction*>& incoming_loop_state) {
  const GatherDimensionNumbers& dim_numbers = gather.gather_dimension_numbers();
  CHECK_EQ(incoming_loop_state.size(), 3);
  HloInstruction* const operand = incoming_loop_state[0];
  HloInstruction* const start_indices = incoming_loop_state[1];
  HloInstruction* const output_accumulator = incoming_loop_state[2];

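  // After CanonicalizeGatherIndices the start indices have rank 1 exactly
  // when each gather index is a scalar, i.e. when index_vector_dim was equal
  // to the rank of the original start_indices operand.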
  bool has_scalar_indices = start_indices->shape().dimensions_size() == 1;
  CHECK_EQ(has_scalar_indices,
           dim_numbers.index_vector_dim() ==
               gather.operand(1)->shape().dimensions_size());

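  // Broadcast the scalar induction variable to a vector of shape {1} so it
  // can be used as a dynamic-slice start index.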
  HloInstruction* induction_var_as_vector =
      MakeBroadcastHlo(induction_var, /*broadcast_dimensions=*/{},
                       /*result_shape_bounds=*/{1});

  HloInstruction* index_vector;

  if (has_scalar_indices) {
    // In this case start_indices has rank 1 and induction_var_as_vector (of
    // shape {1}) is an index into this rank 1 tensor.
    TF_ASSIGN_OR_RETURN(
        index_vector,
        MakeDynamicSliceHlo(start_indices, induction_var_as_vector, {1}));
  } else {
    // In this case start_indices has rank 2 and induction_var_as_vector (of
    // shape {1}) is an index into just the first dimension of this rank 2
    // tensor.
    TF_ASSIGN_OR_RETURN(
        HloInstruction * index_into_start_indices,
        PadVectorWithZeros(induction_var_as_vector,
                           /*zeros_to_prepend=*/0, /*zeros_to_append=*/1));

    int64 index_vector_size = start_indices->shape().dimensions(1);
    TF_ASSIGN_OR_RETURN(
        HloInstruction * index_vector_2d,
        MakeDynamicSliceHlo(start_indices, index_into_start_indices,
                            {1, index_vector_size}));

    TF_ASSIGN_OR_RETURN(index_vector,
                        ElideDegenerateDims(index_vector_2d, {0}));
  }

  TF_ASSIGN_OR_RETURN(
      HloInstruction * gathered_slice_start,
      ExpandIndexVectorIntoOperandSpace(index_vector, dim_numbers,
                                        operand->shape().dimensions_size()));

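  // Slice a window of shape gather_slice_sizes out of the operand, starting
  // at the index vector mapped into operand space.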
  TF_ASSIGN_OR_RETURN(HloInstruction * gathered_slice,
                      MakeDynamicSliceHlo(operand, gathered_slice_start,
                                          gather.gather_slice_sizes()));

  TF_ASSIGN_OR_RETURN(
      HloInstruction* const gathered_slice_with_dims_collapsed,
      ElideDegenerateDims(gathered_slice,
                          AsInt64Slice(dim_numbers.collapsed_slice_dims())));

  TF_ASSIGN_OR_RETURN(
      HloInstruction* const gathered_slice_for_update,
      PrependDegenerateDims(gathered_slice_with_dims_collapsed, 1));

  TF_ASSIGN_OR_RETURN(
      HloInstruction* const index_vector_into_accumulator,
      PadVectorWithZeros(
          induction_var_as_vector, /*zeros_to_prepend=*/0,
          /*zeros_to_append=*/
          gathered_slice_with_dims_collapsed->shape().dimensions_size()));

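  // Write the rank-expanded slice into row `induction_var` of the
  // accumulator; all other accumulator dimensions are updated from offset
  // zero.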
  TF_ASSIGN_OR_RETURN(
      HloInstruction* const updated_accumulator,
      MakeDynamicUpdateSliceHlo(output_accumulator, gathered_slice_for_update,
                                index_vector_into_accumulator));

  // New loop state -- only the accumulator has changed.  The
  // WhileUtil::MakeCountedLoop function takes care of the induction variable
  // and the while loop exit condition.
  return StatusOr<std::vector<HloInstruction*>>{
      {operand, start_indices, updated_accumulator}};
}

static HloInstruction* CreateGatherLoopAccumulatorInitValue(
    HloComputation* computation, PrimitiveType element_type,
    absl::Span<const int64> slice_sizes, int64 gather_loop_trip_count,
    const GatherDimensionNumbers& dim_numbers) {
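  // The accumulator starts out as a zero tensor of shape
  // [gather_loop_trip_count] ++ slice_sizes (with the collapsed_slice_dims
  // removed); each loop iteration fills in one row.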
  std::vector<int64> accumulator_state_shape_dims;
  accumulator_state_shape_dims.reserve(1 + slice_sizes.size());
  accumulator_state_shape_dims.push_back(gather_loop_trip_count);
  for (int64 i = 0; i < slice_sizes.size(); i++) {
    if (!absl::c_binary_search(dim_numbers.collapsed_slice_dims(), i)) {
      accumulator_state_shape_dims.push_back(slice_sizes[i]);
    }
  }
  return BroadcastZeros(computation, element_type,
                        accumulator_state_shape_dims);
}

// `accumulator` is almost the tensor the gather operation would have produced,
// except that it has the dimensions in the wrong order -- the batch dimensions
// are the major dimensions and the offset dimensions are the minor dimensions.
// Fix this up with a transpose.
static StatusOr<HloInstruction*> PermuteBatchAndOffsetDims(
    HloInstruction* accumulator, absl::Span<const int64> offset_dims,
    int64 output_rank) {
  std::vector<int64> permutation;
  permutation.reserve(output_rank);

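  // In the accumulator the batch dimensions come first and the offset
  // dimensions come last; build a permutation that reads each output
  // dimension from the corresponding accumulator dimension.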
  int64 batch_idx_counter = 0;
  int64 offset_idx_counter = output_rank - offset_dims.size();
  for (int64 i = 0; i < output_rank; i++) {
    bool is_offset_dim = absl::c_binary_search(offset_dims, i);
    if (is_offset_dim) {
      permutation.push_back(offset_idx_counter++);
    } else {
      permutation.push_back(batch_idx_counter++);
    }
  }

  return MakeTransposeHlo(accumulator, permutation);
}

// High Level Algorithm
//
// We perform the following steps in sequence:
//
//  1. We canonicalize the start_indices tensor such that it has rank
//     2 (i.e. is a matrix) where each row is an index vector into the
//     operand.
//  2. We iterate over the set of indices in the canonicalized
//     start_indices tensor using a while loop, accumulating slices
//     of the operand tensor into an accumulator using
//     DynamicUpdateSlice.
//  3. The accumulator result from the while loop in (2) is then
//     reshaped to split out all the individual gather dimensions and
//     then transposed to give the final result.
//
// As an example, if we started with the following operation:
//
//   HloModule TensorFlowGatherMultipleBatchDims
//
//   ENTRY main {
//     operand = s32[3,3] parameter(0)
//     indices = s32[2,2] parameter(1)
//     ROOT gather = s32[2,3,2] gather(operand, indices),
//         offset_dims={1},
//         collapsed_slice_dims={1},
//         start_index_map={1},
//         index_vector_dim=2,
//         slice_sizes={3, 1}
//   }
//
// We'd first reshape indices to s32[4,1], where each row is an index
// into operand.  We'd then run a loop to slice out 4 tensors of shape
// [3,1] out of operand into an accumulator of shape [4,3,1].  We then
// reshape this result to [2,2,3] and finally transpose it to [2,3,2].

StatusOr<HloInstruction*> GatherExpander::ExpandInstruction(
    HloInstruction* gather_instr) {
  CHECK(!ShapeUtil::IsZeroElementArray(gather_instr->shape()));

  HloComputation* computation = gather_instr->parent();
  HloInstruction* operand = gather_instr->mutable_operand(0);
  HloInstruction* start_indices = gather_instr->mutable_operand(1);
  const Shape& start_indices_shape = start_indices->shape();
  const Shape& output_shape = gather_instr->shape();
  int64 output_rank = output_shape.dimensions_size();

  const GatherDimensionNumbers& dim_numbers =
      gather_instr->gather_dimension_numbers();

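  // The loop runs once per gather index, i.e. once per element of
  // start_indices, not counting the index_vector_dim dimension.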
  int64 gather_loop_trip_count = 1;
  for (int64 i = 0, e = start_indices_shape.dimensions_size(); i < e; i++) {
    if (i != dim_numbers.index_vector_dim()) {
      gather_loop_trip_count *= start_indices_shape.dimensions(i);
    }
  }

  if (!IsInt32(gather_loop_trip_count)) {
    return Unimplemented(
        "Gather operations with more than 2147483647 gather indices are not "
        "supported. This error occurred for %s.",
        gather_instr->ToString());
  }

  TF_ASSIGN_OR_RETURN(
      HloInstruction * canonical_start_indices,
      CanonicalizeGatherIndices(start_indices, dim_numbers.index_vector_dim()));

  CHECK_EQ(gather_loop_trip_count,
           canonical_start_indices->shape().dimensions(0));

  HloInstruction* accumulator_init = CreateGatherLoopAccumulatorInitValue(
      computation, output_shape.element_type(),
      gather_instr->gather_slice_sizes(), gather_loop_trip_count,
      gather_instr->gather_dimension_numbers());

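  // Run the counted loop.  Its state is {operand, canonical_start_indices,
  // accumulator}; only the accumulator changes across iterations.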
  StatusOr<std::vector<HloInstruction*>> gather_loop_result_or_error =
      WhileUtil::MakeCountedLoop(
          computation, gather_loop_trip_count,
          {operand, canonical_start_indices, accumulator_init},
          [&](HloInstruction* indvar,
              const std::vector<HloInstruction*>& loop_state) {
            return GatherLoopBody(*gather_instr, indvar, loop_state);
          },
          gather_instr->metadata());

  TF_ASSIGN_OR_RETURN(std::vector<HloInstruction*> gather_loop_result,
                      gather_loop_result_or_error);

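  // The final loop state has the same layout as the initial one, so the
  // accumulator is its last element.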
  HloInstruction* accumulator_result = gather_loop_result.back();

  TF_ASSIGN_OR_RETURN(
      HloInstruction* const accumulator_with_batch_dims_decanonicalized,
      AdjustBatchDimsInAccumulator(start_indices->shape(), accumulator_result,
                                   dim_numbers.index_vector_dim()));

  return PermuteBatchAndOffsetDims(accumulator_with_batch_dims_decanonicalized,
                                   AsInt64Slice(dim_numbers.offset_dims()),
                                   output_rank);
}

bool GatherExpander::InstructionMatchesPattern(HloInstruction* inst) {
  return inst->opcode() == HloOpcode::kGather &&
         // Avoid expanding gather ops that produce zero sized tensors;
         // instead, punt these to ZeroSizedHloElimination.
         !ShapeUtil::IsZeroElementArray(inst->shape());
}

}  // namespace xla