1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/scatter_expander.h"
17 
18 #include "absl/algorithm/container.h"
19 #include "tensorflow/compiler/xla/literal_util.h"
20 #include "tensorflow/compiler/xla/service/hlo_computation.h"
21 #include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
22 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
23 #include "tensorflow/compiler/xla/service/hlo_module.h"
24 #include "tensorflow/compiler/xla/service/while_util.h"
25 #include "tensorflow/compiler/xla/statusor.h"
26 
27 namespace xla {
28 
29 // Transposes the given scatter_indices such that the index_vector_dim becomes
30 // the most-minor dimension.
TransposeIndexVectorDimToLast(HloInstruction * scatter_indices,int64 index_vector_dim)31 static StatusOr<HloInstruction*> TransposeIndexVectorDimToLast(
32     HloInstruction* scatter_indices, int64 index_vector_dim) {
33   const Shape& scatter_indices_shape = scatter_indices->shape();
34 
35   if (scatter_indices_shape.dimensions_size() == index_vector_dim) {
36     return scatter_indices;
37   }
38 
39   if (index_vector_dim == (scatter_indices_shape.dimensions_size() - 1)) {
40     return scatter_indices;
41   }
42 
43   std::vector<int64> permutation;
44   permutation.reserve(scatter_indices_shape.dimensions_size());
45   for (int64 i = 0, e = scatter_indices_shape.dimensions_size(); i < e; i++) {
46     if (i != index_vector_dim) {
47       permutation.push_back(i);
48     }
49   }
50   permutation.push_back(index_vector_dim);
51   return MakeTransposeHlo(scatter_indices, permutation);
52 }
53 
54 // Canonicalizes the scatter_indices tensor in order to keep them uniform while
55 // performing the scatter operation.
CanonicalizeScatterIndices(HloInstruction * scatter_indices,int64 index_vector_dim)56 static StatusOr<HloInstruction*> CanonicalizeScatterIndices(
57     HloInstruction* scatter_indices, int64 index_vector_dim) {
58   // Transpose the non-index-vector dimensions to the front.
59   TF_ASSIGN_OR_RETURN(
60       HloInstruction * transposed_scatter_indices,
61       TransposeIndexVectorDimToLast(scatter_indices, index_vector_dim));
62   if (scatter_indices->shape().rank() == index_vector_dim + 1 &&
63       scatter_indices->shape().dimensions(index_vector_dim) == 1) {
64     auto new_shape =
65         ShapeUtil::DeleteDimension(index_vector_dim, scatter_indices->shape());
66     TF_ASSIGN_OR_RETURN(scatter_indices,
67                         MakeReshapeHlo(new_shape, scatter_indices));
68   }
69   bool indices_are_scalar =
70       index_vector_dim == scatter_indices->shape().dimensions_size();
71 
72   // The number of dimensions in scatter_indices that are index dimensions.
73   const int64 index_dims_in_scatter_indices = indices_are_scalar ? 0 : 1;
74 
75   // If there is only one index (i.e. scatter_indices has rank 1 and this
76   // scatter is really just a dynamic update slice) add a leading degenerate
77   // dimension for uniformity.  Otherwise create a "collapsed" leading dimension
78   // that subsumes all of the non-index-vector dimensions.
79   const Shape& shape = transposed_scatter_indices->shape();
80   if (shape.dimensions_size() == index_dims_in_scatter_indices) {
81     return PrependDegenerateDims(transposed_scatter_indices, 1);
82   } else {
83     // Collapse all but the dimensions (0 or 1) in scatter_indices containing
84     // the index vectors.
85     return CollapseFirstNDims(
86         transposed_scatter_indices,
87         shape.dimensions_size() - index_dims_in_scatter_indices);
88   }
89 }
90 
91 // Permutes the `updates` tensor such that all the scatter dims appear in the
92 // major dimensions and all the window dimensions appear in the minor
93 // dimensions.
PermuteScatterAndWindowDims(HloInstruction * updates,absl::Span<const int64> update_window_dims)94 static StatusOr<HloInstruction*> PermuteScatterAndWindowDims(
95     HloInstruction* updates, absl::Span<const int64> update_window_dims) {
96   std::vector<int64> permutation;
97   const int64 updates_rank = updates->shape().rank();
98   permutation.reserve(updates_rank);
99 
100   for (int64 i = 0; i < updates_rank; ++i) {
101     bool is_scatter_dim = !absl::c_binary_search(update_window_dims, i);
102     if (is_scatter_dim) {
103       permutation.push_back(i);
104     }
105   }
106   for (auto window_dim : update_window_dims) {
107     permutation.push_back(window_dim);
108   }
109 
110   return MakeTransposeHlo(updates, permutation);
111 }
112 
113 // Expands or contracts the scatter indices in the updates tensor.
AdjustScatterDims(const Shape & scatter_indices_shape,HloInstruction * updates,int64 index_vector_dim)114 static StatusOr<HloInstruction*> AdjustScatterDims(
115     const Shape& scatter_indices_shape, HloInstruction* updates,
116     int64 index_vector_dim) {
117   int64 num_scatter_dims = scatter_indices_shape.dimensions_size();
118   if (index_vector_dim < scatter_indices_shape.dimensions_size()) {
119     --num_scatter_dims;
120   }
121   if (num_scatter_dims == 0) {
122     // If there are no scatter dims, this must be a dynamic-update-slice kind of
123     // scatter. In this case, we prepend a degenerate dimension to work
124     // uniformly in the while loop.
125     return PrependDegenerateDims(updates, 1);
126   }
127   return CollapseFirstNDims(updates, num_scatter_dims);
128 }
129 
130 // Expands an index vector from the scatter_indices tensor into a vector that
131 // can be used to dynamic-update-slice to perform the scatter update.
ExpandIndexVectorIntoOperandSpace(HloInstruction * index_vector,const ScatterDimensionNumbers & dim_numbers,int64 operand_rank)132 static StatusOr<HloInstruction*> ExpandIndexVectorIntoOperandSpace(
133     HloInstruction* index_vector, const ScatterDimensionNumbers& dim_numbers,
134     int64 operand_rank) {
135   HloComputation* computation = index_vector->parent();
136   const Shape& index_shape = index_vector->shape();
137 
138   // Scatter of a scalar. Return a zero-sized vector of indices.
139   if (operand_rank == 0) {
140     return computation->AddInstruction(HloInstruction::CreateConstant(
141         LiteralUtil::CreateFromDimensions(index_shape.element_type(), {0})));
142   }
143 
144   HloInstruction* zero =
145       computation->AddInstruction(HloInstruction::CreateConstant(
146           LiteralUtil::CreateFromDimensions(index_shape.element_type(), {1})));
147 
148   // We extract out individual components from the smaller index and concatenate
149   // them (interspersing zeros as needed) into the larger index.
150   std::vector<HloInstruction*> expanded_index_components;
151 
152   for (int i = 0; i < operand_rank; i++) {
153     int64 index_vector_dim_index =
154         FindIndex(dim_numbers.scatter_dims_to_operand_dims(), i);
155     if (index_vector_dim_index !=
156         dim_numbers.scatter_dims_to_operand_dims_size()) {
157       TF_ASSIGN_OR_RETURN(
158           HloInstruction * component_to_concat,
159           MakeSliceHlo(index_vector, /*start_indices=*/{index_vector_dim_index},
160                        /*limit_indices=*/{index_vector_dim_index + 1},
161                        /*strides=*/{1}));
162       expanded_index_components.push_back(component_to_concat);
163     } else {
164       expanded_index_components.push_back(zero);
165     }
166   }
167 
168   return MakeConcatHlo(expanded_index_components, /*dimension=*/0);
169 }
170 
CheckIndexValidity(HloComputation * computation,HloInstruction * index,absl::Span<const int64> operand_dims,absl::Span<const int64> window_sizes,HloModule * module)171 static StatusOr<HloInstruction*> CheckIndexValidity(
172     HloComputation* computation, HloInstruction* index,
173     absl::Span<const int64> operand_dims, absl::Span<const int64> window_sizes,
174     HloModule* module) {
175   DCHECK_NE(nullptr, module);
176   DCHECK_EQ(operand_dims.size(), window_sizes.size());
177 
178   // Valid range for the index: [0, operand_dims - window_sizes]
179 
180   // Check if the index has any negative values.
181   HloInstruction* zero_index =
182       BroadcastZeros(computation, index->shape().element_type(),
183                      AsInt64Slice(index->shape().dimensions()));
184   TF_ASSIGN_OR_RETURN(
185       HloInstruction * negative_index_check,
186       MakeCompareHlo(ComparisonDirection::kLe, zero_index, index));
187 
188   // Check if the index is OOB w.r.t. the operand dimensions and window sizes.
189   std::vector<int64> max_valid_index(operand_dims.size());
190   for (int i = 0; i < operand_dims.size(); ++i) {
191     max_valid_index[i] = operand_dims[i] - window_sizes[i];
192   }
193   TF_ASSIGN_OR_RETURN(
194       HloInstruction * max_valid_index_constant,
195       MakeR1ConstantHlo<int64>(computation, index->shape().element_type(),
196                                max_valid_index));
197   TF_ASSIGN_OR_RETURN(HloInstruction * oob_index_check,
198                       MakeCompareHlo(ComparisonDirection::kGe,
199                                      max_valid_index_constant, index));
200 
201   // Combine the results of the two checks above.
202   TF_ASSIGN_OR_RETURN(
203       HloInstruction * valid_index,
204       MakeBinaryHlo(HloOpcode::kAnd, negative_index_check, oob_index_check));
205 
206   // Reduce the index validity check vector into a scalar predicate.
207   auto reduction_init = computation->AddInstruction(
208       HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
209   TF_ASSIGN_OR_RETURN(
210       HloInstruction * valid_index_reduced,
211       MakeReduceHlo(valid_index, reduction_init, HloOpcode::kAnd, module));
212 
213   // Return a broadcasted value of the scalar predicate to the same size as the
214   // window.
215   return MakeBroadcastHlo(valid_index_reduced, {}, window_sizes);
216 }
217 
218 // Body of the while loop that performs the scatter operation using other HLOs.
ScatterLoopBody(HloInstruction * scatter,HloInstruction * induction_var,const std::vector<HloInstruction * > & loop_state)219 static StatusOr<std::vector<HloInstruction*>> ScatterLoopBody(
220     HloInstruction* scatter, HloInstruction* induction_var,
221     const std::vector<HloInstruction*>& loop_state) {
222   const ScatterDimensionNumbers& dim_numbers =
223       scatter->scatter_dimension_numbers();
224   CHECK_EQ(loop_state.size(), 3);
225   HloInstruction* operand = loop_state[0];
226   HloInstruction* scatter_indices = loop_state[1];
227   HloInstruction* updates = loop_state[2];
228 
229   bool has_scalar_indices = scatter_indices->shape().dimensions_size() == 1;
230 
231   // Build a vector form of the induction variable of the while loop.
232   HloInstruction* induction_var_as_vector =
233       MakeBroadcastHlo(induction_var, /*broadcast_dimensions=*/{},
234                        /*result_shape_bounds=*/{1});
235 
236   // Pick the index to scatter from scatter_indices based on the induction_var
237   // and transform that to an index into the `operand` space.
238   HloInstruction* index_vector;
239   if (has_scalar_indices) {
240     TF_ASSIGN_OR_RETURN(
241         index_vector,
242         MakeDynamicSliceHlo(scatter_indices, induction_var_as_vector, {1}));
243   } else {
244     TF_ASSIGN_OR_RETURN(
245         HloInstruction * index_into_scatter_indices,
246         PadVectorWithZeros(induction_var_as_vector,
247                            /*zeros_to_prepend=*/0, /*zeros_to_append=*/1));
248     int index_vector_size = scatter_indices->shape().dimensions(1);
249     TF_ASSIGN_OR_RETURN(
250         HloInstruction * index_vector_2d,
251         MakeDynamicSliceHlo(scatter_indices, index_into_scatter_indices,
252                             {1, index_vector_size}));
253     TF_ASSIGN_OR_RETURN(index_vector,
254                         ElideDegenerateDims(index_vector_2d, {0}));
255   }
256   TF_ASSIGN_OR_RETURN(
257       HloInstruction * scatter_slice_start,
258       ExpandIndexVectorIntoOperandSpace(index_vector, dim_numbers,
259                                         operand->shape().dimensions_size()));
260 
261   // Extract the slice to be used to update from `updates` tensor for the
262   // induction_var corresponding to this iteration of the while loop.
263   TF_ASSIGN_OR_RETURN(
264       HloInstruction * index_into_updates,
265       PadVectorWithZeros(
266           induction_var_as_vector, /*zeros_to_prepend=*/0,
267           /*zeros_to_append=*/updates->shape().dimensions_size() - 1));
268   std::vector<int64> update_slice_bounds(updates->shape().dimensions().begin(),
269                                          updates->shape().dimensions().end());
270   update_slice_bounds[0] = 1;
271   TF_ASSIGN_OR_RETURN(
272       HloInstruction * update_slice,
273       MakeDynamicSliceHlo(updates, index_into_updates, update_slice_bounds));
274   TF_ASSIGN_OR_RETURN(HloInstruction * update_slice_for_scatter,
275                       ElideDegenerateDims(update_slice, {0}));
276   TF_ASSIGN_OR_RETURN(
277       HloInstruction * update_slice_with_dims_inserted,
278       InsertDegenerateDims(update_slice_for_scatter,
279                            AsInt64Slice(dim_numbers.inserted_window_dims())));
280 
281   // Note that the following transformation assumes that both DynamicSlice and
282   // DynamicUpdateSlice follow the same semantics for OOB indices. For example,
283   // if there are negative indices and DynamicSlice uses "clamping" semantics,
284   // then the extracted data will be "shifted". Since DynamicUpdateSlice also
285   // follows the same "clamping" semantics, writing the update will also be
286   // "shifted" by exactly the same amount. So, this transformation is correct as
287   // long as the semantics of handling OOB indices remain the same in
288   // DynamicSlice and DynamicUpdateSlice.
289 
290   // Extract the slice to update from `operand` tensor.
291   const Shape& update_slice_shape = update_slice_with_dims_inserted->shape();
292   TF_ASSIGN_OR_RETURN(
293       HloInstruction * operand_slice_to_update,
294       MakeDynamicSliceHlo(operand, scatter_slice_start,
295                           AsInt64Slice(update_slice_shape.dimensions())));
296 
297   // Compute the new value for the slice to be updated in `operand` tensor by
298   // combining the existing value and the update value using the update
299   // computation.
300   TF_ASSIGN_OR_RETURN(
301       HloInstruction * updated_operand_slice,
302       MakeMapHlo({operand_slice_to_update, update_slice_with_dims_inserted},
303                  scatter->to_apply()));
304 
305   TF_ASSIGN_OR_RETURN(
306       HloInstruction * is_index_valid,
307       CheckIndexValidity(
308           operand->parent(), scatter_slice_start,
309           AsInt64Slice(operand->shape().dimensions()),
310           AsInt64Slice(update_slice_with_dims_inserted->shape().dimensions()),
311           scatter->GetModule()));
312 
313   // Select the updated operand only if the index is valid. If not, select the
314   // original value.
315   TF_ASSIGN_OR_RETURN(HloInstruction * update_to_apply,
316                       MakeSelectHlo(is_index_valid, updated_operand_slice,
317                                     operand_slice_to_update));
318 
319   // Write the updated value of the slice into `operand` tensor.
320   TF_ASSIGN_OR_RETURN(
321       HloInstruction * updated_operand,
322       MakeDynamicUpdateSliceHlo(operand, update_to_apply, scatter_slice_start));
323 
324   return StatusOr<std::vector<HloInstruction*>>{
325       {updated_operand, scatter_indices, updates}};
326 }
327 
328 // High Level Algorithm.
329 //
330 // 1. Canonicalize the scatter_indices tensor such that it has rank 2, where
331 //    each row is an index into the operand.
// 2. Canonicalize the updates tensor such that it has rank `num_window_dims+1`
333 //    and the scatter dim is the most-major dimension.
334 // 3. Iterate over the set of indices in the canonicalized scatter_indices
335 //    tensor using a while loop, updating the operand for each such index. Each
336 //    iteration of this while loop performs the following:
337 //      a. Pick the index from scatter_indices for this iteration.
//      b. Transform this index into an index into the operand space.
339 //      c. Extract the slice to be used to update from the updates tensor.
340 //      d. Extract the slice to update from the operand tensor.
341 //      e. Compute the new value for the slice to update by combining the slices
342 //         from c. and d. using the update_computation of scatter.
343 //      f. Write the updated value of the slice into the operand tensor.
344 
ExpandScatter(HloInstruction * scatter)345 StatusOr<HloInstruction*> ScatterExpander::ExpandScatter(
346     HloInstruction* scatter) {
347   HloInstruction* operand = scatter->mutable_operand(0);
348   HloInstruction* scatter_indices = scatter->mutable_operand(1);
349   HloInstruction* updates = scatter->mutable_operand(2);
350   const ScatterDimensionNumbers& dim_numbers =
351       scatter->scatter_dimension_numbers();
352 
353   // If the updates tensor is empty, there is no need to update the operand. We
354   // can return the operand as is.
355   if (ShapeUtil::IsZeroElementArray(updates->shape())) {
356     return operand;
357   }
358 
359   // Compute the trip count for the while loop to be used for scatter. This
360   // should be the number of indices we should scatter into the operand.
361   const Shape& scatter_indices_shape = scatter_indices->shape();
362   int64 scatter_loop_trip_count = 1;
363   for (int64 i = 0, e = scatter_indices_shape.dimensions_size(); i < e; i++) {
364     if (i != dim_numbers.index_vector_dim()) {
365       scatter_loop_trip_count *= scatter_indices_shape.dimensions(i);
366     }
367   }
368   if (!IsInt32(scatter_loop_trip_count)) {
369     return Unimplemented(
370         "Scatter operations with more than 2147483647 scatter indices are not "
371         "supported. This error occurred for %s.",
372         scatter->ToString());
373   }
374 
375   // Canonicalize the scatter_indices, after which the size of its most-major
376   // dimension must be same as the while loop trip count.
377   TF_ASSIGN_OR_RETURN(HloInstruction * canonical_scatter_indices,
378                       CanonicalizeScatterIndices(
379                           scatter_indices, dim_numbers.index_vector_dim()));
380   CHECK_EQ(scatter_loop_trip_count,
381            canonical_scatter_indices->shape().dimensions(0));
382 
383   // Canonicalize the updates, after which the size of its most-major dimension
384   // must be same as the while loop trip count.
385   TF_ASSIGN_OR_RETURN(
386       HloInstruction * canonical_updates,
387       PermuteScatterAndWindowDims(
388           updates, AsInt64Slice(dim_numbers.update_window_dims())));
389   TF_ASSIGN_OR_RETURN(
390       HloInstruction * adjusted_canonical_updates,
391       AdjustScatterDims(scatter_indices->shape(), canonical_updates,
392                         dim_numbers.index_vector_dim()));
393   CHECK_EQ(scatter_loop_trip_count,
394            adjusted_canonical_updates->shape().dimensions(0));
395 
396   // The while loop that implements the scatter operation.
397   StatusOr<std::vector<HloInstruction*>> scatter_loop_result_status =
398       WhileUtil::MakeCountedLoop(
399           scatter->parent(), scatter_loop_trip_count,
400           {operand, canonical_scatter_indices, adjusted_canonical_updates},
401           [&](HloInstruction* induction_var,
402               const std::vector<HloInstruction*>& loop_state) {
403             return ScatterLoopBody(scatter, induction_var, loop_state);
404           },
405           scatter->metadata());
406   TF_ASSIGN_OR_RETURN(std::vector<HloInstruction*> scatter_loop_result,
407                       scatter_loop_result_status);
408   return scatter_loop_result.front();
409 }
410 
Run(HloModule * module)411 StatusOr<bool> ScatterExpander::Run(HloModule* module) {
412   std::vector<HloInstruction*> scatter_instrs;
413   for (HloComputation* computation : module->MakeNonfusionComputations()) {
414     for (HloInstruction* instr : computation->instructions()) {
415       if (instr->opcode() == HloOpcode::kScatter) {
416         scatter_instrs.push_back(instr);
417       }
418     }
419   }
420 
421   for (auto instr : scatter_instrs) {
422     TF_ASSIGN_OR_RETURN(HloInstruction * expanded_root, ExpandScatter(instr));
423     TF_RETURN_IF_ERROR(
424         instr->parent()->ReplaceInstruction(instr, expanded_root));
425   }
426 
427   return !scatter_instrs.empty();
428 }
429 
430 }  // namespace xla
431