1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/scatter_expander.h"
17
18 #include "absl/algorithm/container.h"
19 #include "tensorflow/compiler/xla/literal_util.h"
20 #include "tensorflow/compiler/xla/service/hlo_computation.h"
21 #include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
22 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
23 #include "tensorflow/compiler/xla/service/hlo_module.h"
24 #include "tensorflow/compiler/xla/service/while_util.h"
25 #include "tensorflow/compiler/xla/statusor.h"
26
27 namespace xla {
28
29 // Transposes the given scatter_indices such that the index_vector_dim becomes
30 // the most-minor dimension.
TransposeIndexVectorDimToLast(HloInstruction * scatter_indices,int64 index_vector_dim)31 static StatusOr<HloInstruction*> TransposeIndexVectorDimToLast(
32 HloInstruction* scatter_indices, int64 index_vector_dim) {
33 const Shape& scatter_indices_shape = scatter_indices->shape();
34
35 if (scatter_indices_shape.dimensions_size() == index_vector_dim) {
36 return scatter_indices;
37 }
38
39 if (index_vector_dim == (scatter_indices_shape.dimensions_size() - 1)) {
40 return scatter_indices;
41 }
42
43 std::vector<int64> permutation;
44 permutation.reserve(scatter_indices_shape.dimensions_size());
45 for (int64 i = 0, e = scatter_indices_shape.dimensions_size(); i < e; i++) {
46 if (i != index_vector_dim) {
47 permutation.push_back(i);
48 }
49 }
50 permutation.push_back(index_vector_dim);
51 return MakeTransposeHlo(scatter_indices, permutation);
52 }
53
54 // Canonicalizes the scatter_indices tensor in order to keep them uniform while
55 // performing the scatter operation.
static StatusOr<HloInstruction*> CanonicalizeScatterIndices(
    HloInstruction* scatter_indices, int64 index_vector_dim) {
  // Transpose the non-index-vector dimensions to the front.
  TF_ASSIGN_OR_RETURN(
      HloInstruction * transposed_scatter_indices,
      TransposeIndexVectorDimToLast(scatter_indices, index_vector_dim));
  // If the index vector dimension is trailing and degenerate (size 1), reshape
  // it away so the rank check below classifies the indices as scalar.
  // NOTE: the reshaped value is only consulted for the `indices_are_scalar`
  // rank query; `transposed_scatter_indices` (which still carries the size-1
  // dimension) is what flows into the collapse below — safe because that
  // dimension has size 1, so it collapses without changing element count.
  if (scatter_indices->shape().rank() == index_vector_dim + 1 &&
      scatter_indices->shape().dimensions(index_vector_dim) == 1) {
    auto new_shape =
        ShapeUtil::DeleteDimension(index_vector_dim, scatter_indices->shape());
    TF_ASSIGN_OR_RETURN(scatter_indices,
                        MakeReshapeHlo(new_shape, scatter_indices));
  }
  // Scalar indices: there is no dimension holding index vectors at all.
  bool indices_are_scalar =
      index_vector_dim == scatter_indices->shape().dimensions_size();

  // The number of dimensions in scatter_indices that are index dimensions.
  const int64 index_dims_in_scatter_indices = indices_are_scalar ? 0 : 1;

  // If there is only one index (i.e. scatter_indices has rank 1 and this
  // scatter is really just a dynamic update slice) add a leading degenerate
  // dimension for uniformity. Otherwise create a "collapsed" leading dimension
  // that subsumes all of the non-index-vector dimensions.
  const Shape& shape = transposed_scatter_indices->shape();
  if (shape.dimensions_size() == index_dims_in_scatter_indices) {
    return PrependDegenerateDims(transposed_scatter_indices, 1);
  } else {
    // Collapse all but the dimensions (0 or 1) in scatter_indices containing
    // the index vectors.
    return CollapseFirstNDims(
        transposed_scatter_indices,
        shape.dimensions_size() - index_dims_in_scatter_indices);
  }
}
90
91 // Permutes the `updates` tensor such that all the scatter dims appear in the
92 // major dimensions and all the window dimensions appear in the minor
93 // dimensions.
PermuteScatterAndWindowDims(HloInstruction * updates,absl::Span<const int64> update_window_dims)94 static StatusOr<HloInstruction*> PermuteScatterAndWindowDims(
95 HloInstruction* updates, absl::Span<const int64> update_window_dims) {
96 std::vector<int64> permutation;
97 const int64 updates_rank = updates->shape().rank();
98 permutation.reserve(updates_rank);
99
100 for (int64 i = 0; i < updates_rank; ++i) {
101 bool is_scatter_dim = !absl::c_binary_search(update_window_dims, i);
102 if (is_scatter_dim) {
103 permutation.push_back(i);
104 }
105 }
106 for (auto window_dim : update_window_dims) {
107 permutation.push_back(window_dim);
108 }
109
110 return MakeTransposeHlo(updates, permutation);
111 }
112
113 // Expands or contracts the scatter indices in the updates tensor.
AdjustScatterDims(const Shape & scatter_indices_shape,HloInstruction * updates,int64 index_vector_dim)114 static StatusOr<HloInstruction*> AdjustScatterDims(
115 const Shape& scatter_indices_shape, HloInstruction* updates,
116 int64 index_vector_dim) {
117 int64 num_scatter_dims = scatter_indices_shape.dimensions_size();
118 if (index_vector_dim < scatter_indices_shape.dimensions_size()) {
119 --num_scatter_dims;
120 }
121 if (num_scatter_dims == 0) {
122 // If there are no scatter dims, this must be a dynamic-update-slice kind of
123 // scatter. In this case, we prepend a degenerate dimension to work
124 // uniformly in the while loop.
125 return PrependDegenerateDims(updates, 1);
126 }
127 return CollapseFirstNDims(updates, num_scatter_dims);
128 }
129
130 // Expands an index vector from the scatter_indices tensor into a vector that
131 // can be used to dynamic-update-slice to perform the scatter update.
ExpandIndexVectorIntoOperandSpace(HloInstruction * index_vector,const ScatterDimensionNumbers & dim_numbers,int64 operand_rank)132 static StatusOr<HloInstruction*> ExpandIndexVectorIntoOperandSpace(
133 HloInstruction* index_vector, const ScatterDimensionNumbers& dim_numbers,
134 int64 operand_rank) {
135 HloComputation* computation = index_vector->parent();
136 const Shape& index_shape = index_vector->shape();
137
138 // Scatter of a scalar. Return a zero-sized vector of indices.
139 if (operand_rank == 0) {
140 return computation->AddInstruction(HloInstruction::CreateConstant(
141 LiteralUtil::CreateFromDimensions(index_shape.element_type(), {0})));
142 }
143
144 HloInstruction* zero =
145 computation->AddInstruction(HloInstruction::CreateConstant(
146 LiteralUtil::CreateFromDimensions(index_shape.element_type(), {1})));
147
148 // We extract out individual components from the smaller index and concatenate
149 // them (interspersing zeros as needed) into the larger index.
150 std::vector<HloInstruction*> expanded_index_components;
151
152 for (int i = 0; i < operand_rank; i++) {
153 int64 index_vector_dim_index =
154 FindIndex(dim_numbers.scatter_dims_to_operand_dims(), i);
155 if (index_vector_dim_index !=
156 dim_numbers.scatter_dims_to_operand_dims_size()) {
157 TF_ASSIGN_OR_RETURN(
158 HloInstruction * component_to_concat,
159 MakeSliceHlo(index_vector, /*start_indices=*/{index_vector_dim_index},
160 /*limit_indices=*/{index_vector_dim_index + 1},
161 /*strides=*/{1}));
162 expanded_index_components.push_back(component_to_concat);
163 } else {
164 expanded_index_components.push_back(zero);
165 }
166 }
167
168 return MakeConcatHlo(expanded_index_components, /*dimension=*/0);
169 }
170
CheckIndexValidity(HloComputation * computation,HloInstruction * index,absl::Span<const int64> operand_dims,absl::Span<const int64> window_sizes,HloModule * module)171 static StatusOr<HloInstruction*> CheckIndexValidity(
172 HloComputation* computation, HloInstruction* index,
173 absl::Span<const int64> operand_dims, absl::Span<const int64> window_sizes,
174 HloModule* module) {
175 DCHECK_NE(nullptr, module);
176 DCHECK_EQ(operand_dims.size(), window_sizes.size());
177
178 // Valid range for the index: [0, operand_dims - window_sizes]
179
180 // Check if the index has any negative values.
181 HloInstruction* zero_index =
182 BroadcastZeros(computation, index->shape().element_type(),
183 AsInt64Slice(index->shape().dimensions()));
184 TF_ASSIGN_OR_RETURN(
185 HloInstruction * negative_index_check,
186 MakeCompareHlo(ComparisonDirection::kLe, zero_index, index));
187
188 // Check if the index is OOB w.r.t. the operand dimensions and window sizes.
189 std::vector<int64> max_valid_index(operand_dims.size());
190 for (int i = 0; i < operand_dims.size(); ++i) {
191 max_valid_index[i] = operand_dims[i] - window_sizes[i];
192 }
193 TF_ASSIGN_OR_RETURN(
194 HloInstruction * max_valid_index_constant,
195 MakeR1ConstantHlo<int64>(computation, index->shape().element_type(),
196 max_valid_index));
197 TF_ASSIGN_OR_RETURN(HloInstruction * oob_index_check,
198 MakeCompareHlo(ComparisonDirection::kGe,
199 max_valid_index_constant, index));
200
201 // Combine the results of the two checks above.
202 TF_ASSIGN_OR_RETURN(
203 HloInstruction * valid_index,
204 MakeBinaryHlo(HloOpcode::kAnd, negative_index_check, oob_index_check));
205
206 // Reduce the index validity check vector into a scalar predicate.
207 auto reduction_init = computation->AddInstruction(
208 HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
209 TF_ASSIGN_OR_RETURN(
210 HloInstruction * valid_index_reduced,
211 MakeReduceHlo(valid_index, reduction_init, HloOpcode::kAnd, module));
212
213 // Return a broadcasted value of the scalar predicate to the same size as the
214 // window.
215 return MakeBroadcastHlo(valid_index_reduced, {}, window_sizes);
216 }
217
218 // Body of the while loop that performs the scatter operation using other HLOs.
static StatusOr<std::vector<HloInstruction*>> ScatterLoopBody(
    HloInstruction* scatter, HloInstruction* induction_var,
    const std::vector<HloInstruction*>& loop_state) {
  const ScatterDimensionNumbers& dim_numbers =
      scatter->scatter_dimension_numbers();
  // Loop state layout (set up in ExpandScatter): {operand, canonicalized
  // scatter_indices, canonicalized updates}.
  CHECK_EQ(loop_state.size(), 3);
  HloInstruction* operand = loop_state[0];
  HloInstruction* scatter_indices = loop_state[1];
  HloInstruction* updates = loop_state[2];

  // After canonicalization, rank-1 indices mean each index is a scalar (there
  // is no index-vector dimension).
  bool has_scalar_indices = scatter_indices->shape().dimensions_size() == 1;

  // Build a vector form of the induction variable of the while loop.
  HloInstruction* induction_var_as_vector =
      MakeBroadcastHlo(induction_var, /*broadcast_dimensions=*/{},
                       /*result_shape_bounds=*/{1});

  // Pick the index to scatter from scatter_indices based on the induction_var
  // and transform that to an index into the `operand` space.
  HloInstruction* index_vector;
  if (has_scalar_indices) {
    TF_ASSIGN_OR_RETURN(
        index_vector,
        MakeDynamicSliceHlo(scatter_indices, induction_var_as_vector, {1}));
  } else {
    // Slice the induction_var-th row out of the 2D canonical indices, then
    // drop the leading degenerate dimension to get a 1D index vector.
    TF_ASSIGN_OR_RETURN(
        HloInstruction * index_into_scatter_indices,
        PadVectorWithZeros(induction_var_as_vector,
                           /*zeros_to_prepend=*/0, /*zeros_to_append=*/1));
    int index_vector_size = scatter_indices->shape().dimensions(1);
    TF_ASSIGN_OR_RETURN(
        HloInstruction * index_vector_2d,
        MakeDynamicSliceHlo(scatter_indices, index_into_scatter_indices,
                            {1, index_vector_size}));
    TF_ASSIGN_OR_RETURN(index_vector,
                        ElideDegenerateDims(index_vector_2d, {0}));
  }
  TF_ASSIGN_OR_RETURN(
      HloInstruction * scatter_slice_start,
      ExpandIndexVectorIntoOperandSpace(index_vector, dim_numbers,
                                        operand->shape().dimensions_size()));

  // Extract the slice to be used to update from `updates` tensor for the
  // induction_var corresponding to this iteration of the while loop.
  TF_ASSIGN_OR_RETURN(
      HloInstruction * index_into_updates,
      PadVectorWithZeros(
          induction_var_as_vector, /*zeros_to_prepend=*/0,
          /*zeros_to_append=*/updates->shape().dimensions_size() - 1));
  // Slice bounds: one row along the (major) scatter dimension, full extent
  // along all window dimensions.
  std::vector<int64> update_slice_bounds(updates->shape().dimensions().begin(),
                                         updates->shape().dimensions().end());
  update_slice_bounds[0] = 1;
  TF_ASSIGN_OR_RETURN(
      HloInstruction * update_slice,
      MakeDynamicSliceHlo(updates, index_into_updates, update_slice_bounds));
  TF_ASSIGN_OR_RETURN(HloInstruction * update_slice_for_scatter,
                      ElideDegenerateDims(update_slice, {0}));
  // Re-insert the size-1 window dimensions that were elided from the updates
  // so the update slice's rank matches the operand slice's rank.
  TF_ASSIGN_OR_RETURN(
      HloInstruction * update_slice_with_dims_inserted,
      InsertDegenerateDims(update_slice_for_scatter,
                           AsInt64Slice(dim_numbers.inserted_window_dims())));

  // Note that the following transformation assumes that both DynamicSlice and
  // DynamicUpdateSlice follow the same semantics for OOB indices. For example,
  // if there are negative indices and DynamicSlice uses "clamping" semantics,
  // then the extracted data will be "shifted". Since DynamicUpdateSlice also
  // follows the same "clamping" semantics, writing the update will also be
  // "shifted" by exactly the same amount. So, this transformation is correct as
  // long as the semantics of handling OOB indices remain the same in
  // DynamicSlice and DynamicUpdateSlice.

  // Extract the slice to update from `operand` tensor.
  const Shape& update_slice_shape = update_slice_with_dims_inserted->shape();
  TF_ASSIGN_OR_RETURN(
      HloInstruction * operand_slice_to_update,
      MakeDynamicSliceHlo(operand, scatter_slice_start,
                          AsInt64Slice(update_slice_shape.dimensions())));

  // Compute the new value for the slice to be updated in `operand` tensor by
  // combining the existing value and the update value using the update
  // computation.
  TF_ASSIGN_OR_RETURN(
      HloInstruction * updated_operand_slice,
      MakeMapHlo({operand_slice_to_update, update_slice_with_dims_inserted},
                 scatter->to_apply()));

  // Guard against out-of-bounds scatter indices: the check's result is
  // broadcast to the window shape for use as a select mask below.
  TF_ASSIGN_OR_RETURN(
      HloInstruction * is_index_valid,
      CheckIndexValidity(
          operand->parent(), scatter_slice_start,
          AsInt64Slice(operand->shape().dimensions()),
          AsInt64Slice(update_slice_with_dims_inserted->shape().dimensions()),
          scatter->GetModule()));

  // Select the updated operand only if the index is valid. If not, select the
  // original value.
  TF_ASSIGN_OR_RETURN(HloInstruction * update_to_apply,
                      MakeSelectHlo(is_index_valid, updated_operand_slice,
                                    operand_slice_to_update));

  // Write the updated value of the slice into `operand` tensor.
  TF_ASSIGN_OR_RETURN(
      HloInstruction * updated_operand,
      MakeDynamicUpdateSliceHlo(operand, update_to_apply, scatter_slice_start));

  // Only the operand changes across iterations; indices and updates are
  // threaded through unchanged.
  return StatusOr<std::vector<HloInstruction*>>{
      {updated_operand, scatter_indices, updates}};
}
327
328 // High Level Algorithm.
329 //
330 // 1. Canonicalize the scatter_indices tensor such that it has rank 2, where
331 // each row is an index into the operand.
// 2. Canonicalize the updates tensor such that it has rank `num_window_dims+1`
333 // and the scatter dim is the most-major dimension.
334 // 3. Iterate over the set of indices in the canonicalized scatter_indices
335 // tensor using a while loop, updating the operand for each such index. Each
336 // iteration of this while loop performs the following:
337 // a. Pick the index from scatter_indices for this iteration.
//  b. Transform this index into an index into the operand space.
339 // c. Extract the slice to be used to update from the updates tensor.
340 // d. Extract the slice to update from the operand tensor.
341 // e. Compute the new value for the slice to update by combining the slices
342 // from c. and d. using the update_computation of scatter.
343 // f. Write the updated value of the slice into the operand tensor.
344
StatusOr<HloInstruction*> ScatterExpander::ExpandScatter(
    HloInstruction* scatter) {
  HloInstruction* operand = scatter->mutable_operand(0);
  HloInstruction* scatter_indices = scatter->mutable_operand(1);
  HloInstruction* updates = scatter->mutable_operand(2);
  const ScatterDimensionNumbers& dim_numbers =
      scatter->scatter_dimension_numbers();

  // If the updates tensor is empty, there is no need to update the operand. We
  // can return the operand as is.
  if (ShapeUtil::IsZeroElementArray(updates->shape())) {
    return operand;
  }

  // Compute the trip count for the while loop to be used for scatter. This
  // should be the number of indices we should scatter into the operand, i.e.
  // the product of all scatter_indices dimensions except the index vector dim.
  const Shape& scatter_indices_shape = scatter_indices->shape();
  int64 scatter_loop_trip_count = 1;
  for (int64 i = 0, e = scatter_indices_shape.dimensions_size(); i < e; i++) {
    if (i != dim_numbers.index_vector_dim()) {
      scatter_loop_trip_count *= scatter_indices_shape.dimensions(i);
    }
  }
  // The while-loop induction variable is a 32-bit integer, so trip counts that
  // do not fit in an int32 cannot be expanded this way.
  if (!IsInt32(scatter_loop_trip_count)) {
    return Unimplemented(
        "Scatter operations with more than 2147483647 scatter indices are not "
        "supported. This error occurred for %s.",
        scatter->ToString());
  }

  // Canonicalize the scatter_indices, after which the size of its most-major
  // dimension must be same as the while loop trip count.
  TF_ASSIGN_OR_RETURN(HloInstruction * canonical_scatter_indices,
                      CanonicalizeScatterIndices(
                          scatter_indices, dim_numbers.index_vector_dim()));
  CHECK_EQ(scatter_loop_trip_count,
           canonical_scatter_indices->shape().dimensions(0));

  // Canonicalize the updates, after which the size of its most-major dimension
  // must be same as the while loop trip count.
  TF_ASSIGN_OR_RETURN(
      HloInstruction * canonical_updates,
      PermuteScatterAndWindowDims(
          updates, AsInt64Slice(dim_numbers.update_window_dims())));
  TF_ASSIGN_OR_RETURN(
      HloInstruction * adjusted_canonical_updates,
      AdjustScatterDims(scatter_indices->shape(), canonical_updates,
                        dim_numbers.index_vector_dim()));
  CHECK_EQ(scatter_loop_trip_count,
           adjusted_canonical_updates->shape().dimensions(0));

  // The while loop that implements the scatter operation; its state is
  // {operand, canonical indices, canonical updates} and each iteration applies
  // one scatter index (see ScatterLoopBody).
  StatusOr<std::vector<HloInstruction*>> scatter_loop_result_status =
      WhileUtil::MakeCountedLoop(
          scatter->parent(), scatter_loop_trip_count,
          {operand, canonical_scatter_indices, adjusted_canonical_updates},
          [&](HloInstruction* induction_var,
              const std::vector<HloInstruction*>& loop_state) {
            return ScatterLoopBody(scatter, induction_var, loop_state);
          },
          scatter->metadata());
  TF_ASSIGN_OR_RETURN(std::vector<HloInstruction*> scatter_loop_result,
                      scatter_loop_result_status);
  // The first element of the loop state is the fully-updated operand.
  return scatter_loop_result.front();
}
410
Run(HloModule * module)411 StatusOr<bool> ScatterExpander::Run(HloModule* module) {
412 std::vector<HloInstruction*> scatter_instrs;
413 for (HloComputation* computation : module->MakeNonfusionComputations()) {
414 for (HloInstruction* instr : computation->instructions()) {
415 if (instr->opcode() == HloOpcode::kScatter) {
416 scatter_instrs.push_back(instr);
417 }
418 }
419 }
420
421 for (auto instr : scatter_instrs) {
422 TF_ASSIGN_OR_RETURN(HloInstruction * expanded_root, ExpandScatter(instr));
423 TF_RETURN_IF_ERROR(
424 instr->parent()->ReplaceInstruction(instr, expanded_root));
425 }
426
427 return !scatter_instrs.empty();
428 }
429
430 } // namespace xla
431