/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/space_to_batch_converter.h"

#include <algorithm>
#include <cstddef>
#include <iterator>
#include <map>
#include <memory>
#include <queue>
#include <tuple>
#include <unordered_set>
#include <utility>
#include <vector>

#include "absl/algorithm/algorithm.h"
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/bitmap.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/stream_executor/lib/statusor.h"

namespace xla {

namespace {

namespace m = match;

// ConvolutionVisitor traverses the HLO computation and rewrites Convolution
// operations with small batch counts into convolutions with larger batch
// counts by moving space to batch.
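// As an illustrative sketch (the numbers are hypothetical, not taken from any
// particular model): a convolution over NHWC activations of shape
// [1, 256, 256, 64] with a 3x3 kernel and stride 1 can have its last spatial
// dimension split into kNumSplits = 8 pieces of size 32, each carrying a
// 2-element halo copied from its neighbor; the activations then become
// [8, 256, 34, 64] and the convolution runs with batch 8 instead of 1.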
class ConvolutionVisitor {
 public:
  // Top-level function to begin space-to-batch conversion.
  Status PerformSpaceToBatchOnConvolution(HloInstruction* convolution);

  // Struct containing details about a convolution.
  struct ConvDetails {
    int64 spatial_dimension_to_split, inherent_low_padding,
        inherent_high_padding, stride, spatial_size, base_dilation_factor,
        halo_size, high_padding_for_conv, low_padding_for_conv,
        kernel_spatial_dim_size, input_dim_size;
  };

  // Return a struct containing various necessary information pieces for
  // performing space-to-batch on a convolution.
  ConvDetails GetConvolutionDetails(HloInstruction* convolution,
                                    ConvolutionDimensionNumbers& dim_numbers);

  // Function that determines if space-to-batch can be propagated into the
  // consumer. Such propagation is only possible when all required operands are
  // space-to-batch'ed.
  bool CanPropagate(HloInstruction* consumer, HloInstruction* producer,
                    bool last_try = false);
  // Returns true if the op has all its direct and indirect operands being
  // created via broadcasts. Consumer uses op, and is space-to-batched.
  // instructions_to_transform is populated with the tree's instructions in
  // reverse post order.
  bool IsBroadcastTree(HloInstruction* op, HloInstruction* consumer,
                       std::vector<HloInstruction*>& instructions_to_transform);

  // Replicates the broadcast tree with space-to-batched instructions.
  void RewriteBroadcastTree(
      HloInstruction* producer,
      std::vector<HloInstruction*>& instructions_to_transform);

  // Propagate space-to-batch on a broadcast instruction.
  void PropagateOnBroadcast(HloInstruction* consumer, HloInstruction* producer);

  // This function checks if the HLO instruction supports propagation.
  bool SupportedOpForPropagation(HloInstruction* consumer,
                                 HloInstruction* producer);

  // Method that checks validity of Broadcast propagation.
  bool IsBroadcastPropagatable(HloInstruction* broadcast,
                               HloInstruction* old_other_op);

  // Propagates space-to-batch on the op, and returns a bool that indicates if
  // the users of the op need to be propagated through.
  StatusOr<bool> Propagate(HloInstruction* consumer, HloInstruction* producer);

  // Splits the given spatial dimension on the activations and returns the new
  // instruction, along with the dimension permutation of the new shape.
  StatusOr<std::pair<HloInstruction*, std::vector<int64>>> SplitSpace(
      HloInstruction* activations, ConvolutionDimensionNumbers& dim_numbers,
      int64& spatial_dimension_to_split, int64& activations_batch_dim,
      int64 high_padding, int64 low_padding, int64 spatial_split_size,
      int64 num_splits, bool is_backprop = false, bool is_rhs = false);

  // Helper function for the SplitSpace function above. Handles padding and
  // reshaping to generate space-to-batched shape.
  StatusOr<HloInstruction*> SplitSpaceHelper(
      HloInstruction* activations, int64 spatial_dimension_to_split,
      int64 activations_batch_dim, int64 high_padding, int64 low_padding,
      int64 spatial_split_size, int64 num_splits);
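  // A hypothetical walk-through of the split (illustrative numbers only):
  // activations of shape [2, 255, 64] split on dimension 1 with
  // spatial_split_size = 32 and num_splits = 8 are first padded up to
  // [2, 256, 64] and then reshaped to [16, 32, 64], with the new batch factor
  // placed next to the split spatial dimension.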

  // Perform space-to-batch propagation on constants.
  StatusOr<HloInstruction*> PropagateOnConstant(HloInstruction* consumer,
                                                HloInstruction* producer);

  // Perform space-to-batch propagation on the convolution. Assumes the
  // activations were already space-to-batched.
  Status PropagateOnConv(HloInstruction* convolution);

  // Perform space-to-batch propagation on the backprop filter convolution.
  // Assumes the activations and kernel were already space-to-batched.
  Status PropagateOnBackpropFilterConv(HloInstruction* convolution);

  // Method that checks validity of space-to-batch on a given convolution.
  bool IsConvSuitableForSpaceToBatch(HloInstruction* convolution);

  // Once a convolution has been space-to-batch'ed, this function will
  // transitively propagate the space-to-batch-ness to the rest of the graph.
  Status PropagateOnUsers(HloInstruction* old_conv);

  // Generates masked output with valid data. This is useful when larger shapes
  // are generated due to space-to-batch.
  StatusOr<HloInstruction*> SelectValidPortion(
      HloInstruction* new_instr, HloInstruction* old_instr,
      HloInstruction* select_val, int64 new_batch_dim, int64 new_space_dim,
      int64 old_batch_dim, int64 old_space_dim);

  struct SpaceNextToBatchDetails {
    HloInstruction* instr;
    std::vector<int64> transpose_dims;
  };

  // Performs transposition so that the space dimension follows the batch
  // dimension.
  StatusOr<SpaceNextToBatchDetails> BringSpaceNextToBatch(
      HloInstruction* activations, ConvolutionDimensionNumbers& dim_numbers,
      int64& spatial_dimension_to_split, int64& activations_batch_dim,
      bool is_backprop = false, bool is_rhs = false);

  // Increases the spatial dimension size in an already space-to-batched shape
  // so that the new size is new_spatial_dim_size.
  StatusOr<HloInstruction*> IncreaseSpatialSizeOnSpaceToBatchedShape(
      HloInstruction* activations, int64 batch_dimension, int64 old_batch_size,
      int64 spatial_dimension, int64 new_spatial_dim_size);

  // Function that converts a space-to-batched shape back to the original.
  StatusOr<HloInstruction*> BatchToSpace(HloInstruction* old_instr);

  // Duplicates elements at boundaries.
  StatusOr<HloInstruction*> HaloDuplicateWithSlice(
      HloInstruction* activations, int64 spatial_dimension_to_split,
      int64 activations_batch_dim, int64 old_batch_size, int64 low_padding,
      int64 high_padding, int64 halo_size, int64 original_split_dim_size,
      HloInstruction* pad_val = nullptr);

  // Runs the visitor on a computation.
  StatusOr<bool> Run();

  // Returns whether any convolution ops were rewritten.
  bool changed() const { return changed_; }

  ~ConvolutionVisitor() = default;

  explicit ConvolutionVisitor(int64 limit_on_batch_size,
                              HloComputation* computation);

  int64 get_chosen_spatial_dim(HloInstruction* convolution) {
    return convolution->convolution_dimension_numbers()
               .input_spatial_dimensions_size() -
           1;
  }

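  // DimLookUp maps a dimension index through permute_dims, and
  // ReverseDimLookUp inverts that mapping. E.g. with permute_dims =
  // {0, 3, 1, 2}, DimLookUp(permute_dims, 1) == 3 and
  // ReverseDimLookUp(permute_dims, 3) == 1.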
  int64 DimLookUp(absl::Span<const int64> permute_dims, int64 id) {
    return permute_dims[id];
  }

  int64 ReverseDimLookUp(absl::Span<const int64> permute_dims, int64 id) {
    return std::distance(permute_dims.begin(), absl::c_find(permute_dims, id));
  }

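  // Walks the users of instr (up to `depth` levels, bounded by
  // kReduceWindowSearchDepth) looking for a ReduceWindow or SelectAndScatter
  // that the convolution feeds, if any.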
  HloInstruction* DoesConvolutionFeedReduceWindowOrSelectAndScatter(
      HloInstruction* instr, int64 depth);

 private:
  // Current HloComputation instance the ConvolutionVisitor is traversing.
  HloComputation* computation_;

  absl::flat_hash_set<HloInstruction*> convs_to_visit_;
  std::vector<HloInstruction*> conv_visitor_list_;
  HloInstructionSet non_propagatable_instrs_;
  // Map from a given space-to-batched instruction to its batch-to-space
  // version.
  absl::flat_hash_map<HloInstruction*, HloInstruction*> batch_to_space_map_;

  // Map from old (non space-to-batch) instructions to space-to-batch'ed
  // instructions.
  absl::flat_hash_map<HloInstruction*, HloInstruction*> old_to_new_instrs_;

  // Map from instruction to dimensions of the shape (first is batch, second is
  // space). This is with respect to the old instruction.
  absl::flat_hash_map<HloInstruction*, std::pair<int64, int64>>
      instr_to_dim_map_;

  // Map from space-to-batch'ed instruction to its permute dims.
  absl::flat_hash_map<HloInstruction*, std::vector<int64>>
      instr_to_dim_permute_map_;

  // Map maintaining previously space-to-batched broadcasts.
  absl::flat_hash_map<HloInstruction*, absl::flat_hash_set<HloInstruction*>>
      broadcast_map_;

  // Whether rewrite has occurred.
  bool changed_ = false;

  // Limit on batch size to apply this technique on.
  int64 limit_on_batch_size_;

  // We choose the new batch size to be kNumSplits times that of the old batch
  // so that space-to-batch propagation through several convolutional layers is
  // consistent.
  static constexpr int64 kNumSplits = 8;
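  // For instance, an old batch of 2 always becomes 16 after space-to-batch,
  // regardless of spatial sizes, so convolutions chained through propagation
  // all see the same batch dimension.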

  // Depth for searching reduce window.
  static constexpr int64 kReduceWindowSearchDepth = 10;
};

ConvolutionVisitor::ConvolutionVisitor(int64 limit_on_batch_size,
                                       HloComputation* computation) {
  computation_ = computation;
  limit_on_batch_size_ = limit_on_batch_size;
  for (HloInstruction* inst : computation->MakeInstructionPostOrder()) {
    if (inst->opcode() != HloOpcode::kConvolution) {
      continue;
    }

    auto convolution = inst;
    // Perform legality checks.
    if (!IsConvSuitableForSpaceToBatch(convolution)) {
      VLOG(1) << "Conv not suitable for space-to-batch "
              << convolution->ToString();
      continue;
    }
    VLOG(1) << "Conv added to space-to-batch worklist "
            << convolution->ToString();
    convs_to_visit_.insert(convolution);
    conv_visitor_list_.push_back(convolution);
  }
}

bool ConvolutionVisitor::IsConvSuitableForSpaceToBatch(
    HloInstruction* convolution) {
  ConvolutionDimensionNumbers dim_numbers =
      convolution->convolution_dimension_numbers();

  // If there are no spatial dims, we return.
  if (dim_numbers.input_spatial_dimensions_size() < 1) {
    return false;
  }

  // Batch in batch_group_count has different semantics (it isn't true batch).
  // Consider supporting this case in future if needed.
  if (convolution->batch_group_count() != 1) {
    return false;
  }

  if (convolution->window()
          .dimensions(get_chosen_spatial_dim(convolution))
          .window_dilation() != 1) {
    return false;
  }

  const ConvDetails c = GetConvolutionDetails(convolution, dim_numbers);

  const int64 low_pad = convolution->window()
                            .dimensions(get_chosen_spatial_dim(convolution))
                            .padding_low();

  // TODO(b/168316428): Support base dilations more generically.
  if (c.base_dilation_factor != 1) {
    if (c.stride != 1) {
      return false;
    }
    // For low pad of 0, only support a pointwise kernel.
    if (low_pad == 0) {
      if (c.kernel_spatial_dim_size != 1) {
        return false;
      }
    } else if (c.kernel_spatial_dim_size != c.base_dilation_factor + 1 ||
               low_pad != c.base_dilation_factor - 1) {
      // Only support base dilations where the dilation factor and the low pad
      // line up with kernel_spatial_dim_size, so that HaloDuplicateWithSlice
      // can handle the result.
      return false;
    }
  }

  int64 activations_batch_dim = dim_numbers.input_batch_dimension();

  const int64 old_batch_size =
      convolution->operand(0)->shape().dimensions(activations_batch_dim);

  if (old_batch_size > limit_on_batch_size_) {
    return false;
  }

  VLOG(1) << "spatial size " << c.spatial_size;

  // If the halo does not fit within a single split of the space, we cannot
  // halo-pad from the neighboring split alone.
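  // E.g. (illustrative) with spatial_size = 255 and kNumSplits = 8, each
  // split holds ceil(255 / 8) = 32 elements, so any halo larger than 32 would
  // need data from more than the immediately neighboring split.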
  if (c.halo_size > CeilOfRatio(c.spatial_size, kNumSplits)) {
    return false;
  }
  VLOG(1) << "Legal space-to-batch convolution " << convolution->ToString();
  return true;
}

StatusOr<HloInstruction*> ConvolutionVisitor::HaloDuplicateWithSlice(
    HloInstruction* activations, int64 spatial_dimension_to_split,
    int64 activations_batch_dim, int64 old_batch_size, int64 low_padding,
    int64 high_padding, int64 halo_size, int64 original_split_dim_size,
    HloInstruction* pad_val) {
  const int64 original_batch_size =
      activations->shape().dimensions(activations_batch_dim) / kNumSplits;

  if (original_batch_size > 1) {
    std::vector<int64> new_dimensions(
        activations->shape().dimensions().begin(),
        activations->shape().dimensions().end());
    new_dimensions[activations_batch_dim] = kNumSplits;
    new_dimensions.insert(new_dimensions.begin() + activations_batch_dim,
                          original_batch_size);

    // Reshape the activations so that the kNumSplits factor sits in its own
    // batch dimension, next to the original batch.
    TF_ASSIGN_OR_RETURN(activations,
                        MakeReshapeHlo(new_dimensions, activations));

    spatial_dimension_to_split++;
    activations_batch_dim++;
  }

  const int64 rank = activations->shape().rank();
  const int64 spatial_split_size =
      activations->shape().dimensions(spatial_dimension_to_split);
  const int64 batch_size =
      activations->shape().dimensions(activations_batch_dim);

  CHECK_LE(std::abs(halo_size - low_padding), spatial_split_size);
  VLOG(1) << "In HaloDuplicateWithSlice with activations "
          << activations->ToString() << " batch_size " << batch_size
          << " spatial_split_size " << spatial_split_size << " low_padding "
          << low_padding << " halo size " << halo_size;

  HloInstruction* first_slice = nullptr;

  std::vector<int64> strides(rank, 1);
  HloInstruction* padding =
      pad_val == nullptr
          ? computation_->AddInstruction(HloInstruction::CreateConstant(
                LiteralUtil::Zero(activations->shape().element_type())))
          : pad_val;

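  // A sketch of the mechanics below: each split i receives the last
  // low_padding elements of split i - 1 (first_slice, aligned by padding the
  // batch dimension low by one) and the first halo_size - low_padding
  // elements of split i + 1 (halo_region, aligned by padding the batch
  // dimension high by one); the pieces are then concatenated along the split
  // spatial dimension.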
  if (low_padding > 0) {
    std::vector<int64> start_indices(rank, 0),
        end_indices(activations->shape().dimensions().begin(),
                    activations->shape().dimensions().end());
    start_indices[spatial_dimension_to_split] =
        spatial_split_size - low_padding;
    end_indices[activations_batch_dim] = batch_size - 1;
    end_indices[spatial_dimension_to_split] = spatial_split_size;

    TF_ASSIGN_OR_RETURN(first_slice, MakeSliceHlo(activations, start_indices,
                                                  end_indices, strides));
    VLOG(1) << "first slice " << first_slice->ToString();
    PaddingConfig padding_config =
        MakeNoPaddingConfig(first_slice->shape().dimensions_size());
    padding_config.mutable_dimensions(activations_batch_dim)
        ->set_edge_padding_low(1);

    TF_ASSIGN_OR_RETURN(first_slice,
                        MakePadHlo(first_slice, padding, padding_config));
  }

  HloInstruction* halo_region = nullptr;
  if (halo_size - low_padding > 0) {
    std::vector<int64> start_indices_halo(rank, 0),
        end_indices_halo(activations->shape().dimensions().begin(),
                         activations->shape().dimensions().end());

    start_indices_halo[activations_batch_dim] = 1;
    end_indices_halo[spatial_dimension_to_split] = halo_size - low_padding;

    TF_ASSIGN_OR_RETURN(halo_region,
                        MakeSliceHlo(activations, start_indices_halo,
                                     end_indices_halo, strides));
    VLOG(1) << "halo_region " << halo_region->ToString();
    PaddingConfig padding_config_halo =
        MakeNoPaddingConfig(halo_region->shape().dimensions_size());
    padding_config_halo.mutable_dimensions(activations_batch_dim)
        ->set_edge_padding_high(1);
    TF_ASSIGN_OR_RETURN(halo_region,
                        MakePadHlo(halo_region, padding, padding_config_halo));
  }

  if (halo_size == 0 && low_padding != 0) {
    std::vector<int64> start_indices_activations_cut(rank, 0),
        end_indices_activations_cut(activations->shape().dimensions().begin(),
                                    activations->shape().dimensions().end());
    // When no halo is needed, we must slice out activations.
    if (low_padding > 0) {
      end_indices_activations_cut[spatial_dimension_to_split] =
          spatial_split_size - low_padding;
    } else {
      start_indices_activations_cut[spatial_dimension_to_split] =
          0 - low_padding;
      end_indices_activations_cut[spatial_dimension_to_split] =
          spatial_split_size;
    }

    TF_ASSIGN_OR_RETURN(
        activations, MakeSliceHlo(activations, start_indices_activations_cut,
                                  end_indices_activations_cut, strides));
  }

  if (first_slice != nullptr) {
    TF_ASSIGN_OR_RETURN(activations, MakeConcatHlo({first_slice, activations},
                                                   spatial_dimension_to_split));
  }

  if (halo_region != nullptr) {
    TF_ASSIGN_OR_RETURN(activations, MakeConcatHlo({activations, halo_region},
                                                   spatial_dimension_to_split));
  }

  if (original_batch_size > 1) {
    std::vector<int64> new_dimensions(
        activations->shape().dimensions().begin(),
        activations->shape().dimensions().end());
    new_dimensions[activations_batch_dim] = original_batch_size * kNumSplits;
    new_dimensions.erase(new_dimensions.begin() + activations_batch_dim - 1);

    // Collapse the split and original batch dimensions back into a single
    // batch dimension.
    TF_ASSIGN_OR_RETURN(activations,
                        MakeReshapeHlo(new_dimensions, activations));

    spatial_dimension_to_split++;
    activations_batch_dim++;
  }

  VLOG(1) << "HaloDuplicated activations " << activations->ToString();
  return activations;
}

StatusOr<ConvolutionVisitor::SpaceNextToBatchDetails>
ConvolutionVisitor::BringSpaceNextToBatch(
    HloInstruction* activations, ConvolutionDimensionNumbers& dim_numbers,
    int64& spatial_dimension_to_split, int64& activations_batch_dim,
    bool is_backprop, bool is_rhs) {
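  // A sketch of the permutation built below: for rank-4 activations with the
  // batch at dimension 3 and the split spatial dimension at 1, transpose_dims
  // becomes {0, 3, 1, 2}, placing the batch at dimension 1, directly before
  // the spatial dimension at 2.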
  std::vector<int64> transpose_dims(activations->shape().rank());
  if (spatial_dimension_to_split == activations_batch_dim + 1) {
    absl::c_iota(transpose_dims, 0);
  } else {
    ConvolutionDimensionNumbers new_dim_numbers = dim_numbers;
    int64 pushed_counter = 0;
    int64 new_batch_dim, new_spatial_dim;
    int64 dim_counter = 0;
    if (is_rhs) {
      CHECK(is_backprop);
      for (int i = 0; i < activations->shape().rank(); ++i) {
        if (i == activations_batch_dim) {
          continue;
        }
        if (i == spatial_dimension_to_split) {
          transpose_dims[dim_counter++] = activations_batch_dim;
          new_batch_dim = pushed_counter;
          pushed_counter++;
          new_spatial_dim = pushed_counter;
        }

        if (i == dim_numbers.kernel_output_feature_dimension()) {
          new_dim_numbers.set_kernel_output_feature_dimension(pushed_counter);
        } else {
          auto it = absl::c_find(dim_numbers.kernel_spatial_dimensions(), i);
          if (it != dim_numbers.kernel_spatial_dimensions().end()) {
            int64 j = it - dim_numbers.kernel_spatial_dimensions().begin();
            new_dim_numbers.set_kernel_spatial_dimensions(j, pushed_counter);
          }
        }
        transpose_dims[dim_counter++] = i;
        pushed_counter++;
      }

      activations_batch_dim = new_batch_dim;
      spatial_dimension_to_split = new_spatial_dim;
      TF_ASSIGN_OR_RETURN(activations,
                          MakeTransposeHlo(activations, transpose_dims));

      new_dim_numbers.set_kernel_input_feature_dimension(activations_batch_dim);

    } else {
      for (int i = 0; i < activations->shape().rank(); ++i) {
        if (i == activations_batch_dim) {
          continue;
        }
        if (i == spatial_dimension_to_split) {
          transpose_dims[dim_counter++] = activations_batch_dim;
          new_batch_dim = pushed_counter;
          pushed_counter++;
          new_spatial_dim = pushed_counter;
        }

        if (is_backprop && i == dim_numbers.input_batch_dimension()) {
          new_dim_numbers.set_input_batch_dimension(pushed_counter);
        } else if (i == dim_numbers.input_feature_dimension()) {
          new_dim_numbers.set_input_feature_dimension(pushed_counter);
        } else {
          auto it = absl::c_find(dim_numbers.input_spatial_dimensions(), i);
          if (it != dim_numbers.input_spatial_dimensions().end()) {
            int64 j = it - dim_numbers.input_spatial_dimensions().begin();
            new_dim_numbers.set_input_spatial_dimensions(j, pushed_counter);
          }
        }
        transpose_dims[dim_counter++] = i;
        pushed_counter++;
      }

      activations_batch_dim = new_batch_dim;
      spatial_dimension_to_split = new_spatial_dim;
      TF_ASSIGN_OR_RETURN(activations,
                          MakeTransposeHlo(activations, transpose_dims));

      if (is_backprop) {
        new_dim_numbers.set_input_feature_dimension(activations_batch_dim);
      } else {
        new_dim_numbers.set_input_batch_dimension(activations_batch_dim);
      }
    }

    dim_numbers = new_dim_numbers;
  }

  return SpaceNextToBatchDetails{activations, transpose_dims};
}

StatusOr<HloInstruction*>
ConvolutionVisitor::IncreaseSpatialSizeOnSpaceToBatchedShape(
    HloInstruction* activations, int64 batch_dimension, int64 old_batch_size,
    int64 spatial_dimension, int64 new_spatial_dim_size) {
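  // Illustrative walk-through: activations of shape [8, 32] (split from an
  // old batch of 1) grown to new_spatial_dim_size = 36 are reshaped to
  // [1, 256], padded on the spatial edge to [1, 288], and reshaped back to
  // [8, 36].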
  CHECK_EQ(batch_dimension + 1, spatial_dimension);
  std::vector<int64> new_dimensions(activations->shape().dimensions().begin(),
                                    activations->shape().dimensions().end());

  const int64 new_batch_size =
      activations->shape().dimensions(batch_dimension);
  int64 spatial_dim_size = activations->shape().dimensions(spatial_dimension);
  const int64 reshaped_space_size =
      spatial_dim_size * new_batch_size / old_batch_size;

  VLOG(3) << "Increasing the spatial size while propagating new_batch_size "
          << new_batch_size << " old_batch_size " << old_batch_size;
  new_dimensions[spatial_dimension] = reshaped_space_size;
  new_dimensions[batch_dimension] = old_batch_size;

  // Reshape so that the batch dimension collapses to old_batch_size and the
  // spatial dimension absorbs the remaining factor.
  TF_ASSIGN_OR_RETURN(HloInstruction * reshaped_activations,
                      MakeReshapeHlo(new_dimensions, activations));

  VLOG(3) << "First reshape done";
  PaddingConfig padding_config =
      MakeNoPaddingConfig(reshaped_activations->shape().dimensions_size());
  padding_config.mutable_dimensions(spatial_dimension)
      ->set_edge_padding_high(new_spatial_dim_size * new_batch_size /
                                  old_batch_size -
                              reshaped_space_size);
  padding_config.mutable_dimensions(spatial_dimension)->set_edge_padding_low(0);
  HloInstruction* padding =
      computation_->AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::Zero(reshaped_activations->shape().element_type())));

  TF_ASSIGN_OR_RETURN(
      reshaped_activations,
      MakePadHlo(reshaped_activations, padding, padding_config));

  std::vector<int64> reshape_back_dims(
      reshaped_activations->shape().dimensions().begin(),
      reshaped_activations->shape().dimensions().end());

  reshape_back_dims[spatial_dimension] = new_spatial_dim_size;
  reshape_back_dims[batch_dimension] = new_batch_size;

  TF_ASSIGN_OR_RETURN(HloInstruction * activations_new,
                      MakeReshapeHlo(reshape_back_dims, reshaped_activations));

  VLOG(3) << "Size increased activations " << activations_new->ToString();

  return activations_new;
}

StatusOr<bool> ConvolutionVisitor::Run() {
  for (auto conv : conv_visitor_list_) {
    if (convs_to_visit_.count(conv) > 0) {
      TF_CHECK_OK(PerformSpaceToBatchOnConvolution(conv));
    }
  }
  conv_visitor_list_.clear();
  convs_to_visit_.clear();
  // Iterate through all instructions that we could not propagate through, and
  // convert their space-to-batched operands back via batch-to-space as needed.
  for (auto instr : non_propagatable_instrs_) {
    if (instr->opcode() == HloOpcode::kConvolution) {
      VLOG(1) << "Instr " << instr->ToString();
    }
    // Try to propagate on backprop filters.
    if (instr->opcode() == HloOpcode::kConvolution &&
        !IsConvSuitableForSpaceToBatch(instr)) {
      HloInstruction* producer = nullptr;
      if (old_to_new_instrs_.contains(instr->mutable_operand(0))) {
        producer = instr->mutable_operand(0);
      } else if (old_to_new_instrs_.contains(instr->mutable_operand(1))) {
        producer = instr->mutable_operand(1);
      }
      if (producer) {
        if (CanPropagate(instr, producer, /*last_try=*/true)) {
          bool needs_further_propagation;
          TF_ASSIGN_OR_RETURN(needs_further_propagation,
                              Propagate(instr, producer));
          TF_CHECK_OK(computation_->ReplaceInstruction(
              instr, old_to_new_instrs_[instr]));
          continue;
        }
      }
    }
    VLOG(1) << "Could not eventually propagate through " << instr->ToString();
    absl::flat_hash_map<int64, HloInstruction*> operand_map;
    for (int64 i = 0; i < instr->operand_count(); ++i) {
      if (old_to_new_instrs_.count(instr->mutable_operand(i))) {
        TF_ASSIGN_OR_RETURN(operand_map[i],
                            BatchToSpace(instr->mutable_operand(i)));
      }
    }
    for (auto entry : operand_map) {
      TF_CHECK_OK(instr->ReplaceOperandWith(entry.first, entry.second));
    }
  }
  non_propagatable_instrs_.clear();
  return changed_;
}

bool IsTrivialElementwise(HloInstruction* hlo) {
  if (hlo->opcode() == HloOpcode::kFusion || hlo->opcode() == HloOpcode::kRng ||
      hlo->opcode() == HloOpcode::kCopy ||
      hlo->opcode() == HloOpcode::kConstant ||
      hlo->opcode() == HloOpcode::kIota || hlo->opcode() == HloOpcode::kMap) {
    return false;
  }
  return hlo->IsElementwise();
}

bool ConvolutionVisitor::CanPropagate(HloInstruction* consumer,
                                      HloInstruction* producer,
                                      bool last_try) {
  if (IsTrivialElementwise(consumer)) {
    VLOG(2) << "Doing propagation check on elementwise op: "
            << consumer->ToString();

    HloInstruction* pivot_operand = nullptr;
    for (int64 i = 0; i < consumer->operand_count(); ++i) {
      auto old_producer = consumer->mutable_operand(i);
      std::vector<HloInstruction*> to_transform;
      const bool broadcast_or_constant =
          (old_producer->opcode() == HloOpcode::kConstant) ||
          (old_producer->opcode() == HloOpcode::kBroadcast &&
           IsBroadcastPropagatable(old_producer, producer)) ||
          (consumer->IsElementwiseBinary() &&
           old_producer->opcode() == HloOpcode::kBroadcast &&
           IsBroadcastTree(old_producer, producer, to_transform));

      if (!old_to_new_instrs_.contains(old_producer) &&
          !broadcast_or_constant) {
        VLOG(1) << "Cannot propagate on elementwise op " << consumer->ToString()
                << " because operand " << old_producer->ToString()
                << " isn't ready ";
        return false;
      } else {
        if (broadcast_or_constant) {
          VLOG(2) << "Skipping on " << old_producer->ToString();
          continue;
        }

        CHECK(old_to_new_instrs_.contains(old_producer));

        CHECK(instr_to_dim_map_.contains(old_producer));
        if (pivot_operand == nullptr) {
          pivot_operand = old_producer;
          VLOG(2) << "Elementwise op: pivot " << old_producer->ToString();
        } else {
          if (instr_to_dim_map_[pivot_operand] !=
              instr_to_dim_map_[old_producer]) {
            VLOG(2) << "Elementwise op: checking for shape equivalence "
                    << consumer->ToString()
                    << " failed due to changed batch space ordering ";
            return false;
          }
          auto pivot_new_instr = old_to_new_instrs_[pivot_operand];
          auto pivot_permute_dims = instr_to_dim_permute_map_[pivot_new_instr];
          auto new_instr = old_to_new_instrs_[old_producer];
          auto permute_dims = instr_to_dim_permute_map_[new_instr];
          for (int j = 0; j < pivot_permute_dims.size(); ++j) {
            // Ensure the dimension mapping is the same.
            if (pivot_permute_dims[j] != permute_dims[j]) {
              VLOG(2) << "Elementwise op: checking for shape equivalence "
                      << consumer->ToString()
                      << " failed due to permuted dimensions ";
              return false;
            }

            // Make sure all other dimensions are of the same size.
            if (pivot_new_instr->shape().dimensions(j) !=
                new_instr->shape().dimensions(j)) {
              if (!((consumer->IsElementwiseBinary() ||
                     consumer->opcode() == HloOpcode::kSelect) &&
                    j == instr_to_dim_map_[pivot_operand].second)) {
                VLOG(2) << "Elementwise op: checking for shape equivalence "
                        << consumer->ToString()
                        << " failed due to changed shape sizes ";
                return false;
              }
            }
          }
        }
      }
    }
  }

  if (consumer->opcode() == HloOpcode::kConvolution) {
    VLOG(1) << "Checking if conv is supported for propagation "
            << consumer->ToString();
    if (IsConvSuitableForSpaceToBatch(consumer)) {
      if (!old_to_new_instrs_.contains(consumer->mutable_operand(0))) {
        return false;
      }
      auto dim_map_val_op_0 = instr_to_dim_map_[consumer->mutable_operand(0)];
      // Make sure that the space dimension is the same across the producer
      // and consumer.
      if (consumer->convolution_dimension_numbers().input_spatial_dimensions(
              get_chosen_spatial_dim(consumer)) != dim_map_val_op_0.second) {
        return false;
      }
      // Make sure that the batch dimension is the same across the producer
      // and consumer.
      if (consumer->convolution_dimension_numbers().input_batch_dimension() !=
          dim_map_val_op_0.first) {
        return false;
      }

      return true;
    }

    // Check for space-to-batch readiness here. Note this is not done in
    // SupportedOpForPropagation because the readiness is dependent upon
    // space-to-batchedness of the operands.

    // We currently only support stride of 1.
    if (consumer->window()
            .dimensions(get_chosen_spatial_dim(consumer))
            .stride() != 1) {
      return false;
    }

    // The same reason we give up on batch group counts applies to feature
    // group counts in backprop.
    if (consumer->feature_group_count() != 1) {
      return false;
    }

    VLOG(2) << "Checking for backprop filter conv propagatability";
    CHECK_EQ(consumer->operand_count(), 2);

    auto activations = consumer->mutable_operand(0);
    auto kernel = consumer->mutable_operand(1);

    auto win_dims =
        consumer->window().dimensions(get_chosen_spatial_dim(consumer));
    const int64 rhs_dilation = win_dims.window_dilation();

    // If the rhs_dilation is absent, we want both LHS and RHS to be space-to-
    // batched for propagating on backprop convolutions.
    if (!last_try || rhs_dilation == 1) {
      if (!old_to_new_instrs_.contains(kernel) ||
          !old_to_new_instrs_.contains(activations)) {
        return false;
      }
    }

    if (!old_to_new_instrs_.contains(kernel) &&
        !old_to_new_instrs_.contains(activations)) {
      return false;
    }

    if (!old_to_new_instrs_.contains(kernel)) {
      const int64 rhs_batch =
          kernel->shape().dimensions(consumer->convolution_dimension_numbers()
                                         .kernel_input_feature_dimension());
      auto dim_map_val_op_0 = instr_to_dim_map_[activations];
      const int64 old_batch_dim = dim_map_val_op_0.first;
      const int64 old_space_dim = dim_map_val_op_0.second;
      auto first_operand = old_to_new_instrs_[activations];
      auto permute_dims_first_operand =
          instr_to_dim_permute_map_[first_operand];
      const int64 new_batch_dim =
          DimLookUp(permute_dims_first_operand, old_batch_dim);
      const int64 new_space_dim =
          DimLookUp(permute_dims_first_operand, old_space_dim);
      const int64 lhs_batch = first_operand->shape().dimensions(new_batch_dim);

      if (first_operand->shape().dimensions(new_space_dim) % rhs_dilation !=
          0) {
        return false;
      }
      // Because we want to convert activations into a space-to-batched version
      // only for backprop filter convolutions, we want to make sure that the
      // batch dimensions (feature dimensions, technically) are same sized.
      // Since LHS is already space-to-batched, we need to account for it too.
      if (rhs_batch * kNumSplits != lhs_batch) {
        return false;
      }

      // The kernel has not been propagated through; we can space-to-batch it
      // here because the activations have already been propagated.
      VLOG(2)
          << "Backprop filter conv ready for propagation: activations ready, "
             " kernel will be space-to-batched";
      return true;
    }

    if (!old_to_new_instrs_.contains(activations)) {
      const int64 lhs_batch = activations->shape().dimensions(
          consumer->convolution_dimension_numbers().input_feature_dimension());
      auto dim_map_val_op_1 = instr_to_dim_map_[consumer->mutable_operand(1)];
      const int64 old_batch_dim = dim_map_val_op_1.first;
      auto second_operand = old_to_new_instrs_[kernel];
      auto permute_dims_second_operand =
          instr_to_dim_permute_map_[second_operand];
      const int64 new_batch_dim =
          DimLookUp(permute_dims_second_operand, old_batch_dim);
      const int64 rhs_batch = second_operand->shape().dimensions(new_batch_dim);

      // Because we want to convert activations into a space-to-batched version
      // only for backprop filter convolutions, we want to make sure that the
      // batch dimensions (feature dimensions, technically) are same sized.
      // Since RHS is already space-to-batched, we need to account for it too.
      if (rhs_batch != kNumSplits * lhs_batch) {
        return false;
      }

      // The activations have not been propagated through; we can
      // space-to-batch them here because the kernel has been propagated.
      VLOG(2) << "Backprop filter conv ready for propagation: kernel ready, "
                 " activations will be space-to-batched";
      return true;
    }

    auto first_operand = old_to_new_instrs_[activations];
    auto dim_map_val_op_0 = instr_to_dim_map_[activations];
    auto second_operand = old_to_new_instrs_[kernel];
    auto dim_map_val_op_1 = instr_to_dim_map_[kernel];

    auto permute_dims_first_operand = instr_to_dim_permute_map_[first_operand];
    auto permute_dims_second_operand =
        instr_to_dim_permute_map_[second_operand];

    const int64 new_batch_dim_operand_0 =
        DimLookUp(permute_dims_first_operand, dim_map_val_op_0.first);
    const int64 new_space_dim_operand_0 =
        DimLookUp(permute_dims_first_operand, dim_map_val_op_0.second);

    const int64 new_batch_dim_operand_1 =
        DimLookUp(permute_dims_second_operand, dim_map_val_op_1.first);
    const int64 new_space_dim_operand_1 =
        DimLookUp(permute_dims_second_operand, dim_map_val_op_1.second);

    if (first_operand->shape().dimensions(new_batch_dim_operand_0) !=
        second_operand->shape().dimensions(new_batch_dim_operand_1)) {
      VLOG(2) << "Backprop filter conv not ready for propagation because batch "
                 "dimensions don't line up";
      return false;
    }

    if (first_operand->shape().dimensions(new_space_dim_operand_0) >
        rhs_dilation *
            second_operand->shape().dimensions(new_space_dim_operand_1)) {
      VLOG(2) << "Backprop filter conv not ready for propagation because of "
                 "dilation factor mismatch";
      return false;
    }

    VLOG(2) << "Backprop filter conv ready for propagation";

    return true;
  }

  if (consumer->opcode() == HloOpcode::kReduceWindow ||
      consumer->opcode() == HloOpcode::kReduce) {
    for (int64 i = 0; i < consumer->operand_count(); ++i) {
      auto old_producer = consumer->mutable_operand(i);
      if (i == 0 && !old_to_new_instrs_.contains(old_producer)) {
        return false;
      }
    }
  }

  if (consumer->opcode() == HloOpcode::kSelectAndScatter) {
    // We currently only support adds in the scatter.
    auto scatter_comp = consumer->scatter();
    if (!Match(scatter_comp->root_instruction(),
               m::AddAnyOrder(m::Parameter(0), m::Parameter(1)))) {
      return false;
    }

    for (int64 i = 0; i < consumer->operand_count(); ++i) {
      auto old_producer = consumer->mutable_operand(i);
      if (i < 2 && !old_to_new_instrs_.contains(old_producer)) {
        return false;
      }
    }

    auto first_operand = old_to_new_instrs_[consumer->mutable_operand(0)];
    auto dim_map_val_op_0 = instr_to_dim_map_[consumer->mutable_operand(0)];
    auto second_operand = old_to_new_instrs_[consumer->mutable_operand(1)];

    auto permute_dims_first_operand = instr_to_dim_permute_map_[first_operand];
    auto permute_dims_second_operand =
        instr_to_dim_permute_map_[second_operand];

    // The permuting must match.
    if (permute_dims_first_operand != permute_dims_second_operand) {
      VLOG(2) << "Can't propagate through select and scatter due to "
                 "permutation mismatch";
      return false;
    }

    const int64 old_batch_dim = dim_map_val_op_0.first;
    const int64 old_space_dim = dim_map_val_op_0.second;

    const int64 new_batch_dim =
        DimLookUp(permute_dims_first_operand, old_batch_dim);
    const int64 new_space_dim =
        DimLookUp(permute_dims_first_operand, old_space_dim);

    if (first_operand->shape().dimensions(new_batch_dim) !=
        second_operand->shape().dimensions(new_batch_dim)) {
      VLOG(2)
          << "Can't propagate through select and scatter due to dim mismatch";
      return false;
    }

    const int64 stride = consumer->window().dimensions(old_space_dim).stride();
    const int64 pad_high =
        consumer->window().dimensions(old_space_dim).padding_high();
    const int64 pad_low =
        consumer->window().dimensions(old_space_dim).padding_low();

    if ((first_operand->shape().dimensions(new_space_dim) + pad_high +
         pad_low) /
            stride !=
        second_operand->shape().dimensions(new_space_dim)) {
      VLOG(2) << "Can't propagate through select and scatter due to stride "
                 "mismatch";
      return false;
    }
    VLOG(1) << "Can propagate through select and scatter";
    return true;
  }
  return true;
}

void ConvolutionVisitor::PropagateOnBroadcast(HloInstruction* consumer,
                                              HloInstruction* producer) {
  auto new_producer = old_to_new_instrs_[producer];
  auto permute_dims = instr_to_dim_permute_map_[new_producer];
  auto dim_map_val = instr_to_dim_map_[producer];

  const int64 old_batch_dim = dim_map_val.first;
  const int64 old_space_dim = dim_map_val.second;

  auto orig_broadcast_dims = consumer->dimensions();

  bool batch_is_broadcasted =
      absl::c_linear_search(orig_broadcast_dims, old_batch_dim);
  const int64 new_batch_dim = DimLookUp(permute_dims, old_batch_dim);
  const int64 new_space_dim = DimLookUp(permute_dims, old_space_dim);

  bool map_found = broadcast_map_.contains(consumer);
  if (map_found) {
    // Check if we previously had created the same broadcast.
    for (auto previous_broadcast : broadcast_map_[consumer]) {
      if (ShapeUtil::CompatibleIgnoringElementType(previous_broadcast->shape(),
                                                   new_producer->shape())) {
        return;
      }
    }
  }

  std::vector<int64> final_shape_dims(
      new_producer->shape().dimensions().begin(),
      new_producer->shape().dimensions().end());
  if (batch_is_broadcasted) {
    final_shape_dims[new_batch_dim] =
        producer->shape().dimensions(old_batch_dim);
    final_shape_dims[new_space_dim] *= kNumSplits;
  }

  std::vector<int64> broadcast_dims;
  for (auto j : consumer->dimensions()) {
    broadcast_dims.push_back(DimLookUp(permute_dims, j));
  }
  auto new_broadcast = MakeBroadcastHlo(consumer->mutable_operand(0),
                                        broadcast_dims, final_shape_dims);
  VLOG(1) << "Created broadcast " << new_broadcast->ToString();

  if (batch_is_broadcasted) {
    new_broadcast =
        MakeReshapeHlo(new_producer->shape().dimensions(), new_broadcast)
            .ValueOrDie();
    VLOG(2) << "Created reshape of broadcast " << new_broadcast->ToString();
  }

  if (!map_found) {
    absl::flat_hash_set<HloInstruction*> set_of_broadcasts;
    broadcast_map_[consumer] = set_of_broadcasts;
  }
  broadcast_map_[consumer].insert(new_broadcast);
}

void ConvolutionVisitor::RewriteBroadcastTree(
    HloInstruction* producer,
    std::vector<HloInstruction*>& instructions_to_transform) {
  CHECK(old_to_new_instrs_.contains(producer));
  for (auto instr : instructions_to_transform) {
    if (instr->opcode() == HloOpcode::kBroadcast) {
      PropagateOnBroadcast(instr, producer);
    } else if (IsTrivialElementwise(instr)) {
      Propagate(instr, /*producer=*/instr->mutable_operand(0)).ValueOrDie();
    } else {
      LOG(FATAL) << "Unsupported opcode in RewriteBroadcastTree";
    }
  }
}

bool ConvolutionVisitor::IsBroadcastTree(
    HloInstruction* op, HloInstruction* consumer,
    std::vector<HloInstruction*>& instructions_to_transform) {
  if (op->opcode() == HloOpcode::kBroadcast) {
    // We want to ensure that the broadcast did not happen on the space and
    // batch dimensions.
    if (IsBroadcastPropagatable(op, consumer)) {
      instructions_to_transform.push_back(op);
      return true;
    } else {
      return false;
    }
  }
  if (Match(op, m::ConstantScalar())) {
    return true;
  }
  if (!IsTrivialElementwise(op)) {
    return false;
  }
  for (int64 i = 0; i < op->operand_count(); ++i) {
    if (!IsBroadcastTree(op->mutable_operand(i), consumer,
                         instructions_to_transform)) {
      return false;
    }
  }
  instructions_to_transform.push_back(op);
  return true;
}

bool ConvolutionVisitor::IsBroadcastPropagatable(HloInstruction* broadcast,
                                                 HloInstruction* old_other_op) {
  CHECK_EQ(broadcast->opcode(), HloOpcode::kBroadcast);
  CHECK(instr_to_dim_map_.contains(old_other_op));

  auto result = instr_to_dim_map_[old_other_op];
  const int64 space_dim = result.second;
  auto broadcast_dims = broadcast->dimensions();
  return !absl::c_linear_search(broadcast_dims, space_dim);
}

bool ConvolutionVisitor::SupportedOpForPropagation(HloInstruction* consumer,
                                                   HloInstruction* producer) {
  if (IsTrivialElementwise(consumer)) {
    for (int64 i = 0; i < consumer->operand_count(); ++i) {
      if (consumer->operand(i)->opcode() == HloOpcode::kBroadcast) {
        if (!IsBroadcastPropagatable(consumer->mutable_operand(i), producer)) {
          VLOG(2) << "Could not propagate through broadcast";
          return false;
        }
      }
    }
    return true;
  }

  if (consumer->opcode() == HloOpcode::kConvolution) {
    return true;
  }

  if (consumer->opcode() == HloOpcode::kReduce) {
    // Support only the trivial case where both batch and split spatial dim are
    // being reduced.

    auto reduce_dims = consumer->dimensions();
    auto result = instr_to_dim_map_[consumer->mutable_operand(0)];
    const int64 batch_dim = result.first;
    const int64 space_dim = result.second;
    VLOG(1) << "Checking if reduce is supported batch_dim " << batch_dim
            << " space_dim " << space_dim << " reduce "
            << consumer->ToString();
    return absl::c_linear_search(reduce_dims, batch_dim) &&
           absl::c_linear_search(reduce_dims, space_dim);
  }

  if (consumer->opcode() == HloOpcode::kReduceWindow &&
      consumer->shape().IsTuple()) {
    // TODO(b/73062247): Variadic reduce window is not yet supported.
    return false;
  }
  if (consumer->opcode() == HloOpcode::kReduceWindow ||
      consumer->opcode() == HloOpcode::kSelectAndScatter) {
    auto first_operand = consumer->mutable_operand(0);
    auto window = consumer->window();
    if (instr_to_dim_map_.count(first_operand) <= 0) {
      VLOG(1) << "Dim map not found on windowed operand. Window dim count "
              << window.dimensions().size();
      return false;
    }
    // Disallow windowing on the batch dim.
    auto result = instr_to_dim_map_[first_operand];
    const int64 old_batch_dim = result.first;
    const int64 old_space_dim = result.second;
    if (window.dimensions(old_batch_dim).size() != 1) {
      return false;
    }

    // Only allow no-low-padding cases.
    if (window.dimensions(old_space_dim).padding_low() != 0) {
      return false;
    }

    // Only allow small high pads.
    if (window.dimensions(old_space_dim).padding_high() >
        window.dimensions(old_space_dim).size()) {
      return false;
    }

    // Operand 0 must have been propagated through.
    if (old_to_new_instrs_.count(first_operand) <= 0) {
      return false;
    }

    auto new_operand = old_to_new_instrs_[first_operand];
    auto permute_dims = instr_to_dim_permute_map_[new_operand];
    const int64 new_space_dim = DimLookUp(permute_dims, old_space_dim);

    // Make sure that the stride lines up.
    if (window.dimensions(old_space_dim).size() != 1) {
      if (new_operand->shape().dimensions(new_space_dim) %
              window.dimensions(old_space_dim).stride() !=
          0) {
        return false;
      }
    }

    return true;
  }

  return false;
}

StatusOr<bool> ConvolutionVisitor::Propagate(HloInstruction* consumer,
                                             HloInstruction* producer) {
  auto computation = consumer->parent();
  if (IsTrivialElementwise(consumer)) {
    auto dim_map_val = instr_to_dim_map_[producer];
    auto new_consumer = computation->AddInstruction(consumer->Clone());

    bool is_pivot_producer_modified = false;
    // For elementwise binary ops, both of whose operands have been space-to-
    // batched, if their new spatial sizes don't match, choose the bigger one
    // as the producer.
    if (consumer->IsElementwiseBinary() ||
        consumer->opcode() == HloOpcode::kSelect) {
      int64 pivot_operand_number = -1;
      HloInstruction* pivot_operand = nullptr;
      for (int i = 0; i < consumer->operand_count(); ++i) {
        if (consumer->operand(i)->opcode() == HloOpcode::kBroadcast) {
          continue;
        }
        auto operand = consumer->mutable_operand(i);
        if (old_to_new_instrs_.contains(operand)) {
          if (pivot_operand_number == -1 ||
              old_to_new_instrs_[pivot_operand]->shape().dimensions() <
                  old_to_new_instrs_[operand]->shape().dimensions()) {
            is_pivot_producer_modified = true;
            pivot_operand_number = i;
            pivot_operand = consumer->mutable_operand(pivot_operand_number);
          }
        }
      }
      if (pivot_operand_number != -1) {
        producer = pivot_operand;
      }
    }

    for (int64 i = 0; i < consumer->operand_count(); ++i) {
      std::vector<HloInstruction*> instructions_to_transform;

      if (consumer->operand(i)->opcode() == HloOpcode::kBroadcast) {
        auto broadcast = consumer->mutable_operand(i);
        PropagateOnBroadcast(broadcast, producer);
        HloInstruction* new_broadcast = nullptr;
        auto new_producer = old_to_new_instrs_[producer];
        for (auto previous_broadcast : broadcast_map_[broadcast]) {
          if (ShapeUtil::CompatibleIgnoringElementType(
                  previous_broadcast->shape(), new_producer->shape())) {
            new_broadcast = previous_broadcast;
            break;
          }
        }
        CHECK_NE(new_broadcast, nullptr);
        TF_CHECK_OK(
            new_consumer->ReplaceOperandWithDifferentShape(i, new_broadcast));
      } else if (old_to_new_instrs_.contains(consumer->mutable_operand(i))) {
        HloInstruction* operand_to_use = nullptr;
        auto result = instr_to_dim_map_[producer];
        const int64 old_batch_dim = result.first;
        const int64 old_space_dim = result.second;
        const int64 old_batch_size =
            producer->shape().dimensions(old_batch_dim);
        HloInstruction* new_instr =
            old_to_new_instrs_[consumer->mutable_operand(i)];
        HloInstruction* pivot_new_instr = old_to_new_instrs_[producer];

        auto permute_dims = instr_to_dim_permute_map_[new_instr];
        const int64 batch_dim = DimLookUp(permute_dims, old_batch_dim);
        const int64 space_dim = DimLookUp(permute_dims, old_space_dim);
        const int64 batch_size = new_instr->shape().dimensions(batch_dim);

        if (new_instr->shape().dimensions(space_dim) !=
            pivot_new_instr->shape().dimensions(space_dim)) {
          // Because we do not propagate through transposes, the batch should
          // always be followed by the split space dimension.
          CHECK_EQ(batch_dim + 1, space_dim);

          // Reshape to 1D, pad to the producer's size, reshape back to 2D.
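          // E.g. (illustrative) an operand of shape [8, 30] catching up to a
          // pivot of shape [8, 32] becomes [1, 240], is padded to [1, 256],
          // and is reshaped to [8, 32].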
          std::vector<int64> new_dimensions(
              new_instr->shape().dimensions().begin(),
              new_instr->shape().dimensions().end());
          new_dimensions[space_dim] *= (batch_size / old_batch_size);
          new_dimensions[batch_dim] = old_batch_size;

          TF_ASSIGN_OR_RETURN(HloInstruction * reshape,
                              MakeReshapeHlo(new_dimensions, new_instr));

          const int64 pivot_space_size =
              pivot_new_instr->shape().dimensions(space_dim) * batch_size /
              old_batch_size;

          CHECK(pivot_space_size > new_dimensions[space_dim] ||
                !is_pivot_producer_modified);

          PaddingConfig padding_config =
              MakeNoPaddingConfig(reshape->shape().dimensions_size());
          padding_config.mutable_dimensions(space_dim)->set_edge_padding_high(
              pivot_space_size - new_dimensions[space_dim]);
          padding_config.mutable_dimensions(space_dim)->set_edge_padding_low(0);
          HloInstruction* padding =
              computation_->AddInstruction(HloInstruction::CreateConstant(
                  LiteralUtil::Zero(reshape->shape().element_type())));

          TF_ASSIGN_OR_RETURN(HloInstruction * padded_operand,
                              MakePadHlo(reshape, padding, padding_config));

          TF_ASSIGN_OR_RETURN(
              operand_to_use,
              MakeReshapeHlo(pivot_new_instr->shape().dimensions(),
                             padded_operand));

        } else {
          operand_to_use = old_to_new_instrs_[consumer->mutable_operand(i)];
        }
        TF_CHECK_OK(
            new_consumer->ReplaceOperandWithDifferentShape(i, operand_to_use));
      } else if (consumer->IsElementwiseBinary() &&
                 consumer->mutable_operand(i)->opcode() ==
                     HloOpcode::kBroadcast &&
                 IsBroadcastTree(consumer->mutable_operand(i), producer,
                                 instructions_to_transform)) {
        RewriteBroadcastTree(producer, instructions_to_transform);
        TF_CHECK_OK(new_consumer->ReplaceOperandWithDifferentShape(
            i, old_to_new_instrs_[consumer->mutable_operand(i)]));
      } else if (consumer->operand(i)->opcode() == HloOpcode::kConstant) {
        TF_ASSIGN_OR_RETURN(
            auto new_constant,
            PropagateOnConstant(consumer->mutable_operand(i), producer));
        TF_CHECK_OK(
            new_consumer->ReplaceOperandWithDifferentShape(i, new_constant));
      }
    }
    auto old_type = new_consumer->mutable_shape()->element_type();
    *(new_consumer->mutable_shape()) = old_to_new_instrs_[producer]->shape();

    // The element type needs to be retained.
    new_consumer->mutable_shape()->set_element_type(old_type);

    old_to_new_instrs_[consumer] = new_consumer;
    instr_to_dim_map_[consumer] = dim_map_val;
    CHECK(instr_to_dim_permute_map_.contains(old_to_new_instrs_[producer]));
    instr_to_dim_permute_map_[new_consumer] = std::vector<int64>(
        instr_to_dim_permute_map_[old_to_new_instrs_[producer]]);

    VLOG(2) << " new_consumer " << new_consumer->ToString()
            << " old_to_new_instrs_[producer] "
            << old_to_new_instrs_[producer]->ToString() << " permute dims "
            << instr_to_dim_permute_map_.count(new_consumer);

    return true;
  }

  if (consumer->opcode() == HloOpcode::kConvolution) {
    if (IsConvSuitableForSpaceToBatch(consumer)) {
      TF_CHECK_OK(PropagateOnConv(consumer));
      return true;
    } else {
      TF_CHECK_OK(PropagateOnBackpropFilterConv(consumer));
      return false;
    }
  }

  if (consumer->opcode() == HloOpcode::kReduce) {
    auto new_consumer = computation->AddInstruction(consumer->Clone());
    auto first_operand = old_to_new_instrs_[consumer->mutable_operand(0)];

    auto dim_map_val = instr_to_dim_map_[consumer->mutable_operand(0)];
    const int64 old_batch_dim = dim_map_val.first;
    const int64 old_space_dim = dim_map_val.second;
    auto permute_dims = instr_to_dim_permute_map_[first_operand];
    const int64 new_batch_dim = DimLookUp(permute_dims, old_batch_dim);
    const int64 new_space_dim = DimLookUp(permute_dims, old_space_dim);

    TF_ASSIGN_OR_RETURN(
        first_operand,
        SelectValidPortion(first_operand, consumer->mutable_operand(0),
                           consumer->mutable_operand(1), new_batch_dim,
                           new_space_dim, old_batch_dim, old_space_dim));

    std::vector<int64> changed_dims(new_consumer->dimensions().size());
    for (int64 i = 0; i < new_consumer->dimensions().size(); ++i) {
      changed_dims[i] = DimLookUp(permute_dims, new_consumer->dimensions(i));
    }
    *(new_consumer->mutable_dimensions()) = changed_dims;
    // Replace operand 0.
    TF_CHECK_OK(
        new_consumer->ReplaceOperandWithDifferentShape(0, first_operand));
    // We do not set instr_to_dim_permute_map_ here because no further
    // propagation is needed here.
    old_to_new_instrs_[consumer] = new_consumer;
    instr_to_dim_map_[consumer] = dim_map_val;

    // Since the resultant ordering of dimensions is the same as before, no
    // further propagation is needed.
    return false;
  }
1395
1396 if (consumer->opcode() == HloOpcode::kReduceWindow ||
1397 consumer->opcode() == HloOpcode::kSelectAndScatter) {
1398 bool is_select_and_scatter =
1399 consumer->opcode() == HloOpcode::kSelectAndScatter;
1400 auto first_operand = old_to_new_instrs_[consumer->mutable_operand(0)];
1401
1402 auto init_val = is_select_and_scatter ? consumer->mutable_operand(2)
1403 : consumer->mutable_operand(1);
1404 auto dim_map_val = instr_to_dim_map_[consumer->mutable_operand(0)];
1405 const int64 old_batch_dim = dim_map_val.first;
1406 const int64 old_space_dim = dim_map_val.second;
1407 auto permute_dims = instr_to_dim_permute_map_[first_operand];
1408 const int64 new_batch_dim = DimLookUp(permute_dims, old_batch_dim);
1409 const int64 new_space_dim = DimLookUp(permute_dims, old_space_dim);
1410
1411 TF_ASSIGN_OR_RETURN(
1412 first_operand,
1413 SelectValidPortion(first_operand, consumer->mutable_operand(0),
1414 init_val, new_batch_dim, new_space_dim,
1415 old_batch_dim, old_space_dim));
1416
1417 // Calculate the required halo size
1418 auto new_shape = first_operand->shape();
1419 auto old_shape = consumer->mutable_operand(0)->shape();
1420
1421 const int64 new_batch_size = new_shape.dimensions(new_batch_dim);
1422 const int64 new_space_size = new_shape.dimensions(new_space_dim);
1423 const int64 stride = consumer->window().dimensions(old_space_dim).stride();
1424 const int64 window_size =
1425 consumer->window().dimensions(old_space_dim).size();
1426 const int64 last_overlap_point = ((new_space_size - 1) / stride) * stride;
1427 VLOG(1) << "last_overlap_point " << last_overlap_point << " window_size "
1428 << window_size << " new_space_size " << new_space_size;
1429
1430 const int64 halo_size = last_overlap_point + window_size - new_space_size;
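    // Worked example (hypothetical values, purely illustrative): with
    // new_space_size = 8, stride = 2 and window_size = 3,
    // last_overlap_point = ((8 - 1) / 2) * 2 = 6 and
    // halo_size = 6 + 3 - 8 = 1, i.e. the last window position needs one
    // element from beyond the split.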
1431 if (halo_size > 0) {
1432 TF_ASSIGN_OR_RETURN(first_operand,
1433 HaloDuplicateWithSlice(first_operand, new_space_dim,
1434 new_batch_dim, new_batch_size,
1435 /*low_padding=*/0,
1436 /*high_padding=*/0, halo_size,
1437 new_space_size, init_val));
1438 }
1439
1440 Window new_win;
1441 for (int64 i = 0; i < consumer->window().dimensions().size(); ++i) {
1442 auto dim = ReverseDimLookUp(permute_dims, i);
1443 new_win.add_dimensions();
1444 new_win.mutable_dimensions(i)->set_stride(
1445 consumer->window().dimensions(dim).stride());
1446 new_win.mutable_dimensions(i)->set_size(
1447 consumer->window().dimensions(dim).size());
1448 if (i == old_space_dim) {
1449 new_win.mutable_dimensions(i)->set_padding_high(0);
1450 new_win.mutable_dimensions(i)->set_padding_low(0);
1451 } else {
1452 new_win.mutable_dimensions(i)->set_padding_high(
1453 consumer->window().dimensions(dim).padding_high());
1454 new_win.mutable_dimensions(i)->set_padding_low(
1455 consumer->window().dimensions(dim).padding_low());
1456 }
1457 new_win.mutable_dimensions(i)->set_window_dilation(
1458 consumer->window().dimensions(dim).window_dilation());
1459 new_win.mutable_dimensions(i)->set_base_dilation(
1460 consumer->window().dimensions(dim).base_dilation());
1461 new_win.mutable_dimensions(i)->set_window_reversal(
1462 consumer->window().dimensions(dim).window_reversal());
1463 }
1464
1465 new_shape = first_operand->shape();
1466
1467 HloInstruction* new_consumer = nullptr;
1468 if (is_select_and_scatter) {
1469 auto second_operand = old_to_new_instrs_[consumer->mutable_operand(1)];
1470
1471 auto select_comp = consumer->select();
1472
1473 auto scatter_comp = consumer->scatter();
1474 TF_ASSIGN_OR_RETURN(
1475 auto new_select_and_scatter_shape,
1476 ShapeInference::InferSelectAndScatterShape(
1477 new_shape, select_comp->ComputeProgramShape(), new_win,
1478 second_operand->shape(), init_val->shape(),
1479 scatter_comp->ComputeProgramShape()));
1480 new_consumer =
1481 computation_->AddInstruction(HloInstruction::CreateSelectAndScatter(
1482 new_select_and_scatter_shape, first_operand, select_comp, new_win,
1483 second_operand, init_val, scatter_comp));
1484 // Replace operand 0.
1485 TF_CHECK_OK(
1486 new_consumer->ReplaceOperandWithDifferentShape(0, first_operand));
1487 // Replace operand 1.
1488 TF_CHECK_OK(
1489 new_consumer->ReplaceOperandWithDifferentShape(1, second_operand));
1490 VLOG(2) << "New select and scatter " << new_consumer->ToString();
1491
1492 // If the window size was larger than the stride, there could be overlaps.
1493 // Such cases require updates from both overlaps to be applied.
1494 if (halo_size > 0) {
1495 const int64 rank = new_consumer->shape().rank();
1496
1497 const int64 batch_size =
1498 new_consumer->shape().dimensions(new_batch_dim);
1499
1500 std::vector<int64> start_indices(rank, 0),
1501 end_indices(new_consumer->shape().dimensions().begin(),
1502 new_consumer->shape().dimensions().end()),
1503 strides(rank, 1);
1504 start_indices[new_space_dim] = new_space_size;
1505 end_indices[new_space_dim] = new_space_size + halo_size;
1506 end_indices[new_batch_dim] = batch_size - 1;
1507
1508 // This is the slice from halo padding.
1509 TF_ASSIGN_OR_RETURN(
1510 HloInstruction * bottom,
1511 MakeSliceHlo(new_consumer, start_indices, end_indices, strides));
1512
1513 std::vector<int64> start_indices_top(rank, 0),
1514 end_indices_top(new_consumer->shape().dimensions().begin(),
1515 new_consumer->shape().dimensions().end());
1516 end_indices_top[new_space_dim] = halo_size;
1517 // The first batch has correct data.
1518 start_indices_top[new_batch_dim] = 1;
1519
1520       // This is the original area from which the halo was extracted.
1521 TF_ASSIGN_OR_RETURN(HloInstruction * top,
1522 MakeSliceHlo(new_consumer, start_indices_top,
1523 end_indices_top, strides));
1524
1525 HloInstruction* default_fill =
1526 MakeBroadcastHlo(init_val, {}, top->shape().dimensions());
1527
1528 // Compare to see if the bottom area was changed.
1529 TF_ASSIGN_OR_RETURN(
1530 HloInstruction * bottom_compare,
1531 MakeCompareHlo(ComparisonDirection::kNe, bottom, default_fill));
1532
1533       // Keep only the changed values.
1534 TF_ASSIGN_OR_RETURN(
1535 HloInstruction * bottom_taken,
1536 MakeSelectHlo(bottom_compare, bottom, default_fill));
1537
1538 // Compare to see if the top area was changed.
1539 TF_ASSIGN_OR_RETURN(
1540 HloInstruction * top_compare,
1541 MakeCompareHlo(ComparisonDirection::kNe, top, default_fill));
1542
1543       // Keep only the changed values.
1544 TF_ASSIGN_OR_RETURN(HloInstruction * top_taken,
1545 MakeSelectHlo(top_compare, top, bottom_taken));
1546
1547       // Check whether the area was updated by both overlaps.
1548 TF_ASSIGN_OR_RETURN(
1549 HloInstruction * both_compare,
1550 MakeBinaryHlo(HloOpcode::kAnd, top_compare, bottom_compare));
1551
1552 // If it was, add them up.
1553 TF_ASSIGN_OR_RETURN(HloInstruction * both_added,
1554 MakeBinaryHlo(HloOpcode::kAdd, top, bottom));
1555
1556       // Combine the overlaps: where both applied updates, use their sum.
1557       TF_ASSIGN_OR_RETURN(HloInstruction * final_selection,
1558       MakeSelectHlo(both_compare, both_added, top_taken));
1559
      // Pad the final result back to the original shape.
1560       PaddingConfig padding_config =
1561 MakeNoPaddingConfig(final_selection->shape().dimensions_size());
1562 padding_config.mutable_dimensions(new_batch_dim)
1563 ->set_edge_padding_low(1);
1564 padding_config.mutable_dimensions(new_space_dim)
1565 ->set_edge_padding_high(new_space_size);
1566 HloInstruction* padding =
1567 computation_->AddInstruction(HloInstruction::CreateConstant(
1568 LiteralUtil::Zero(final_selection->shape().element_type())));
1569
1570 TF_ASSIGN_OR_RETURN(
1571 final_selection,
1572 MakePadHlo(final_selection, padding, padding_config));
1573
1574 tensorflow::core::Bitmap b(batch_size * (new_space_size + halo_size));
1575 for (int k = 0; k < batch_size * (new_space_size + halo_size); ++k) {
1576 const int64 space_index = k % (new_space_size + halo_size);
1577 const int64 batch_index = (k / (new_space_size + halo_size));
1578 if (batch_index < 1 || space_index >= halo_size) {
1579 b.set(k);
1580 } else {
1581 b.clear(k);
1582 }
1583 }
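      // Illustration (hypothetical values): with batch_size = 4,
      // new_space_size = 6 and halo_size = 2, the mask row for batch 0 is all
      // ones (its data is already correct), while rows 1..3 are zero in the
      // first halo_size positions, where the merged overlap values from
      // final_selection are used instead.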
1584
1585 auto arg_literal = LiteralUtil::CreateR1(b);
1586 VLOG(4) << "Slice mask created: arg literal " << arg_literal.ToString();
1587 HloInstruction* slice_mask = computation_->AddInstruction(
1588 HloInstruction::CreateConstant(std::move(arg_literal)));
1589
1590 std::vector<int64> slice_mask_reshape_dims(2);
1591 slice_mask_reshape_dims[0] = batch_size;
1592 slice_mask_reshape_dims[1] = (new_space_size + halo_size);
1593
1594 TF_ASSIGN_OR_RETURN(
1595 HloInstruction * slice_mask_reshaped,
1596 MakeReshapeHlo(slice_mask_reshape_dims, slice_mask));
1597
1598 // Broadcast the mask in all dimensions.
1599 HloInstruction* shape_mask = MakeBroadcastHlo(
1600 slice_mask_reshaped, {new_batch_dim, new_space_dim},
1601 final_selection->shape().dimensions());
1602
1603 TF_ASSIGN_OR_RETURN(
1604 new_consumer,
1605 MakeSelectHlo(shape_mask, new_consumer, final_selection));
1606 }
1607
1608 auto previous_shape =
1609 old_to_new_instrs_[consumer->mutable_operand(0)]->shape();
1610 std::vector<int64> start_indices(previous_shape.rank(), 0),
1611 end_indices(previous_shape.dimensions().begin(),
1612 previous_shape.dimensions().end()),
1613 strides(previous_shape.rank(), 1);
1614
1615 TF_ASSIGN_OR_RETURN(
1616 new_consumer,
1617 MakeSliceHlo(new_consumer, start_indices, end_indices, strides));
1618
1619 } else {
1620 auto reduce_comp = consumer->to_apply();
1621 TF_ASSIGN_OR_RETURN(auto new_reduce_window_shape,
1622 ShapeInference::InferReduceWindowShape(
1623 new_shape, init_val->shape(), new_win));
1624 new_consumer =
1625 computation_->AddInstruction(HloInstruction::CreateReduceWindow(
1626 new_reduce_window_shape, first_operand, init_val, new_win,
1627 reduce_comp));
1628 // Replace operand 0.
1629 TF_CHECK_OK(
1630 new_consumer->ReplaceOperandWithDifferentShape(0, first_operand));
1631 VLOG(1) << "New reduce window " << new_consumer->ToString();
1632 }
1633
1634 old_to_new_instrs_[consumer] = new_consumer;
1635 instr_to_dim_map_[consumer] = dim_map_val;
1636
1637 instr_to_dim_permute_map_[new_consumer] = std::vector<int64>(
1638 instr_to_dim_permute_map_[old_to_new_instrs_[consumer->mutable_operand(
1639 0)]]);
1640
1641 return true;
1642 }
1643
1644 LOG(FATAL) << "Trying to propagate through an unsupported instruction "
1645 << consumer->ToString();
1646 return true;
1647 }
1648
1649 StatusOr<HloInstruction*> ConvolutionVisitor::SelectValidPortion(
1650 HloInstruction* new_instr, HloInstruction* old_instr,
1651 HloInstruction* select_val, int64 new_batch_dim, int64 new_space_dim,
1652 int64 old_batch_dim, int64 old_space_dim) {
1653 auto new_shape = new_instr->shape();
1654 auto old_shape = old_instr->shape();
1655 VLOG(1) << "In SelectValidPortion new_batch_dim " << new_batch_dim
1656 << " new_space_dim " << new_space_dim << " old_batch_dim "
1657 << old_batch_dim << " old_space_dim " << old_space_dim;
1658 const int64 new_batch_size = new_shape.dimensions(new_batch_dim);
1659 const int64 new_space_size = new_shape.dimensions(new_space_dim);
1660 const int64 old_batch_size = old_shape.dimensions(old_batch_dim);
1661 const int64 old_space_size = old_shape.dimensions(old_space_dim);
1662 CHECK_EQ(new_batch_size % old_batch_size, 0)
1663 << " New batch size " << new_batch_size << " old batch size "
1664 << old_batch_size;
1665 const int64 num_splits = new_batch_size / old_batch_size;
1666   // Build a constant PRED to decide which elements in the split dimension
1667   // are from the halo.
1668 tensorflow::core::Bitmap b(new_batch_size * new_space_size);
1669 for (int k = 0; k < new_batch_size * new_space_size; ++k) {
1670 const int64 space_index = k % new_space_size;
1671 const int64 batch_index = (k / new_space_size) % num_splits;
1672 if (batch_index * new_space_size + space_index < old_space_size) {
1673 b.set(k);
1674 } else {
1675 b.clear(k);
1676 }
1677 }
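  // Worked example (hypothetical values): with old_batch_size = 1,
  // num_splits = 4, new_space_size = 4 and old_space_size = 13, the linear
  // position batch_index * 4 + space_index runs over 0..15, so entries 0..12
  // are marked valid and the final 3 (pure padding) are masked off.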
1678
1679 auto arg_literal = LiteralUtil::CreateR1(b);
1680 VLOG(4) << "Slice mask created: arg literal " << arg_literal.ToString();
1681 HloInstruction* slice_mask = computation_->AddInstruction(
1682 HloInstruction::CreateConstant(std::move(arg_literal)));
1683
1684 std::vector<int64> slice_mask_reshape_dims(2);
1685 slice_mask_reshape_dims[0] = new_batch_size;
1686 slice_mask_reshape_dims[1] = new_space_size;
1687
1688 TF_ASSIGN_OR_RETURN(HloInstruction * slice_mask_reshaped,
1689 MakeReshapeHlo(slice_mask_reshape_dims, slice_mask));
1690
1691 // Broadcast the mask in all dimensions of the activations.
1692 HloInstruction* shape_mask =
1693 MakeBroadcastHlo(slice_mask_reshaped, {new_batch_dim, new_space_dim},
1694 new_instr->shape().dimensions());
1695
1696 VLOG(1) << "Shape mask made " << shape_mask->ToString();
1697
1698 HloInstruction* zeroes =
1699 MakeBroadcastHlo(select_val, {}, new_instr->shape().dimensions());
1700
1701 TF_ASSIGN_OR_RETURN(new_instr, MakeSelectHlo(shape_mask, new_instr, zeroes));
1702
1703 return new_instr;
1704 }
1705
1706 StatusOr<HloInstruction*> ConvolutionVisitor::BatchToSpace(
1707 HloInstruction* old_instr) {
1708 if (batch_to_space_map_.count(old_instr)) {
1709 CHECK_NE(batch_to_space_map_[old_instr], nullptr);
1710 return batch_to_space_map_[old_instr];
1711 }
1712
1713 auto result = instr_to_dim_map_[old_instr];
1714 const int64 old_batch_dim = result.first;
1715 const int64 old_space_dim = result.second;
1716
1717 const int64 old_batch_size = old_instr->shape().dimensions(old_batch_dim);
1718 CHECK(old_to_new_instrs_.contains(old_instr));
1719 auto new_instr = old_to_new_instrs_[old_instr];
1720 VLOG(2) << "old_batch_dim " << old_batch_dim << " old_space_dim "
1721 << old_space_dim << " old_instr " << old_instr->ToString()
1722 << "\n new_instr " << new_instr->ToString() << " permute dims "
1723 << instr_to_dim_permute_map_.count(new_instr) << " old_batch_size "
1724 << old_batch_size;
1725 CHECK(instr_to_dim_permute_map_.contains(new_instr));
1726 auto permute_dims = instr_to_dim_permute_map_[new_instr];
1727 const int64 batch_dim = DimLookUp(permute_dims, old_batch_dim);
1728 const int64 space_dim = DimLookUp(permute_dims, old_space_dim);
1729 const int64 batch_size = new_instr->shape().dimensions(batch_dim);
1730
1731 std::vector<int64> new_dimensions(new_instr->shape().dimensions().begin(),
1732 new_instr->shape().dimensions().end());
1733 new_dimensions[space_dim] *= (batch_size / old_batch_size);
1734 new_dimensions[batch_dim] = old_batch_size;
1735   // Reshape the output of the new conv into the old convolution's shape.
1736 TF_ASSIGN_OR_RETURN(HloInstruction * reshape,
1737 MakeReshapeHlo(new_dimensions, new_instr));
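  // For example (hypothetical shapes): a space-to-batched result of [4, 8]
  // with old_batch_size = 1 is reshaped to [1, 32] above; the slice below
  // then trims the space dimension back to its original size.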
1738
1739 const int64 rank = old_instr->shape().rank();
1740 std::vector<int64> start_indices(rank, 0),
1741 end_indices(new_dimensions.begin(), new_dimensions.end()),
1742 strides(rank, 1);
1743 end_indices[space_dim] = old_instr->shape().dimensions(old_space_dim);
1744
1745   // This slice removes the padding we added to evenly divide space.
1746 TF_ASSIGN_OR_RETURN(
1747 HloInstruction * output_slice,
1748 MakeSliceHlo(reshape, start_indices, end_indices, strides));
1749 VLOG(1) << "Batch to space slice " << output_slice->ToString();
1750 std::vector<int64> transpose_dims(permute_dims);
1751 TF_ASSIGN_OR_RETURN(HloInstruction * output_transpose,
1752 MakeTransposeHlo(output_slice, transpose_dims));
1753
1754 old_instr->SetupDerivedInstruction(output_transpose);
1755
1756 batch_to_space_map_[old_instr] = output_transpose;
1757 return output_transpose;
1758 }
1759
1760 Status ConvolutionVisitor::PropagateOnUsers(HloInstruction* old_conv) {
1761 std::queue<std::pair<HloInstruction*, HloInstruction*>> propagation_worklist;
1762
1763 if (old_conv->user_count() == 0) {
1764 TF_ASSIGN_OR_RETURN(HloInstruction * batch_to_space,
1765 BatchToSpace(old_conv));
1766     VLOG(1) << "Replacing the root instruction with "
1767             << batch_to_space->ToString();
1768 TF_CHECK_OK(computation_->ReplaceInstruction(old_conv, batch_to_space));
1769 VLOG(1) << "Replacement successful";
1770 return Status::OK();
1771 }
1772
1773 int64 iteration_count = 0;
1774 propagation_worklist.push(
1775 std::make_pair(old_conv, old_conv->mutable_operand(0)));
1776
1777 while (!propagation_worklist.empty()) {
1778 auto top = propagation_worklist.front();
1779 auto node = top.first;
1780 auto parent = top.second;
1781 VLOG(1) << "Traversing for propagation operating on " << node->ToString();
1782 propagation_worklist.pop();
1783
1784 // Don't work on the same node again.
1785 if (old_to_new_instrs_.count(node) > 0 && iteration_count != 0) {
1786 continue;
1787 }
1788
1789 bool needs_further_propagation = true;
1790 if (iteration_count != 0) {
1791 // Do the space-to-batch propagation on this node.
1792 TF_ASSIGN_OR_RETURN(needs_further_propagation, Propagate(node, parent));
1793 }
1794 iteration_count++;
1795 // If this is the root, no room for further propagation.
1796 if (node->parent()->root_instruction() == node) {
1797       // The case below does not require going back to space.
1798 if (!needs_further_propagation) {
1799         VLOG(1) << "Replacing the root instruction with "
1800 << old_to_new_instrs_[node]->ToString();
1801 TF_CHECK_OK(
1802 computation_->ReplaceInstruction(node, old_to_new_instrs_[node]));
1803 continue;
1804 }
1805
1806 TF_ASSIGN_OR_RETURN(HloInstruction * batch_to_space, BatchToSpace(node));
1807       VLOG(1) << "Replacing the root instruction with "
1808 << batch_to_space->ToString();
1809 TF_CHECK_OK(computation_->ReplaceInstruction(node, batch_to_space));
1810 } else {
1811 if (!needs_further_propagation) {
1812 TF_CHECK_OK(
1813 computation_->ReplaceInstruction(node, old_to_new_instrs_[node]));
1814 continue;
1815 }
1816       // Insert all users into the queue, as long as each op is supported and
1817       // ready for propagation. If an op is unsupported, do batch-to-space.
1818       // If it is not ready, mark it as non-propagatable.
1819 for (auto user : node->users()) {
1820 if (!SupportedOpForPropagation(user, node)) {
1821 VLOG(1) << "Unsupported op found " << user->ToString();
1822 TF_ASSIGN_OR_RETURN(HloInstruction * batch_to_space,
1823 BatchToSpace(node));
1824 for (int64 i = 0; i < user->operand_count(); ++i) {
1825 if (user->operand(i) == node) {
1826 TF_CHECK_OK(user->ReplaceOperandWith(i, batch_to_space));
1827 }
1828 }
1829 continue;
1830 }
1831 // If the instruction is ready for propagation, add it to the queue.
1832 if (CanPropagate(user, node)) {
1833 non_propagatable_instrs_.erase(user);
1834 propagation_worklist.push(std::make_pair(user, node));
1835 } else {
1836 // Mark it as non-propagatable for now, for later revisiting.
1837 non_propagatable_instrs_.insert(user);
1838 }
1839 }
1840 }
1841 }
1842 return Status::OK();
1843 }
1844
1845 Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) {
1846 auto activations_old = convolution->mutable_operand(0);
1847
1848 CHECK(old_to_new_instrs_.contains(activations_old));
1849 auto activations_new = old_to_new_instrs_[activations_old];
1850 auto permute_dims = instr_to_dim_permute_map_[activations_new];
1851
1852 auto original_conv_dims = convolution->convolution_dimension_numbers();
1853
1854 const int64 old_space_dim = original_conv_dims.input_spatial_dimensions(
1855 get_chosen_spatial_dim(convolution));
1856 const int64 old_split_dim_size =
1857 convolution->mutable_operand(0)->shape().dimensions(old_space_dim);
1858
1859 auto permuted_conv_dims_numbers = original_conv_dims;
1860
1861 int64 activations_batch_dim =
1862 DimLookUp(permute_dims, original_conv_dims.input_batch_dimension());
1863 int64 activations_feature_dim =
1864 DimLookUp(permute_dims, original_conv_dims.input_feature_dimension());
1865 permuted_conv_dims_numbers.set_input_batch_dimension(activations_batch_dim);
1866 permuted_conv_dims_numbers.set_input_feature_dimension(
1867 activations_feature_dim);
1868
1869 for (int64 i = 0; i < original_conv_dims.input_spatial_dimensions_size();
1870 ++i) {
1871 permuted_conv_dims_numbers.set_input_spatial_dimensions(
1872 i, DimLookUp(permute_dims,
1873 original_conv_dims.input_spatial_dimensions(i)));
1874 }
1875
1876 const int64 old_batch_dim = original_conv_dims.input_batch_dimension();
1877 const int64 old_batch_size =
1878 activations_old->shape().dimensions(old_batch_dim);
1879
1880 ConvDetails c =
1881 GetConvolutionDetails(convolution, permuted_conv_dims_numbers);
1882
1883 VLOG(1) << "Propagating on conv activations_batch_dim "
1884 << activations_batch_dim << " spatial_dimension_to_split "
1885 << c.spatial_dimension_to_split << " old_batch_size "
1886 << old_batch_size;
1887
1888 TF_ASSIGN_OR_RETURN(auto retval,
1889 BringSpaceNextToBatch(
1890 activations_new, permuted_conv_dims_numbers,
1891 c.spatial_dimension_to_split, activations_batch_dim));
1892 activations_new = retval.instr;
1893 std::vector<int64> trans_dims = retval.transpose_dims;
1894 CHECK(!trans_dims.empty());
1895 auto select_val = computation_->AddInstruction(HloInstruction::CreateConstant(
1896 LiteralUtil::Zero(activations_new->shape().element_type())));
1897
1898 TF_ASSIGN_OR_RETURN(
1899 activations_new,
1900 SelectValidPortion(activations_new, activations_old, select_val,
1901 activations_batch_dim, c.spatial_dimension_to_split,
1902 old_batch_dim, old_space_dim));
1903 // Create the new convolution dim numbers.
1904 auto new_dim_numbers = permuted_conv_dims_numbers;
1905
1906 VLOG(1) << "spatial size " << c.spatial_size;
1907
1908 const int64 num_splits = kNumSplits;
1909 const int64 output_offsets = convolution->shape().dimensions(
1910 permuted_conv_dims_numbers.output_spatial_dimensions(
1911 get_chosen_spatial_dim(convolution)));
1912 const int64 output_offsets_per_split =
1913 CeilOfRatio(output_offsets, num_splits);
1914
1915 int64 spatial_split_size =
1916 CeilOfRatio(output_offsets_per_split, c.base_dilation_factor) * c.stride;
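  // For example (hypothetical values): with output_offsets = 28,
  // num_splits = 8, base_dilation_factor = 1 and stride = 2, we get
  // output_offsets_per_split = CeilOfRatio(28, 8) = 4 and
  // spatial_split_size = 4 * 2 = 8.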
1917
1918   // Keep increasing the split size so that the overall size isn't smaller
1919   // than the original spatial dimension. Unlike for the first
1920   // space-to-batch'ed convolution, while propagating we can count the last
1921   // halo_size as available spatial size.
1922 while (spatial_split_size * num_splits + c.halo_size - c.spatial_size < 0) {
1923 spatial_split_size += c.stride;
1924 }
1925
1926 int64 slice_size = spatial_split_size + c.halo_size;
1927
1928 VLOG(1) << "spatial_split_size " << spatial_split_size << " slice_size "
1929 << slice_size;
1930
1931 const int64 new_space_size =
1932 activations_new->shape().dimensions(c.spatial_dimension_to_split);
1933   // In the case below, we cannot use the activations directly for halo
1934   // duplication; we must reshape them first.
1935 if (spatial_split_size > new_space_size) {
1936 TF_ASSIGN_OR_RETURN(
1937 activations_new,
1938 IncreaseSpatialSizeOnSpaceToBatchedShape(
1939 activations_new, activations_batch_dim, old_batch_size,
1940 c.spatial_dimension_to_split, spatial_split_size));
1941
1942 TF_ASSIGN_OR_RETURN(
1943 activations_new,
1944 HaloDuplicateWithSlice(activations_new, c.spatial_dimension_to_split,
1945 activations_batch_dim, old_batch_size,
1946 /*low_padding=*/c.base_dilation_factor != 1 &&
1947 c.inherent_low_padding != 0
1948 ? c.base_dilation_factor - 1
1949 : c.inherent_low_padding,
1950 c.inherent_high_padding,
1951 slice_size - spatial_split_size,
1952 old_split_dim_size));
1953 } else {
1954 // If the ideal spatial_split_size was smaller than the incoming spatial
1955 // dimension size, we don't need reshaping. Instead, we determine the
1956 // additional space available, and adjust the required slice size (and
1957 // thereby the halo size).
1958 if (spatial_split_size < new_space_size) {
1959 VLOG(3) << "Decreasing the spatial size while propagating";
1960 const int64 additional_space_present = spatial_split_size % c.stride;
1961 spatial_split_size = new_space_size;
1962 slice_size =
1963 spatial_split_size + std::max(c.kernel_spatial_dim_size - c.stride -
1964 additional_space_present,
1965 static_cast<int64>(0));
1966 }
1967
1968 TF_ASSIGN_OR_RETURN(
1969 activations_new,
1970 HaloDuplicateWithSlice(activations_new, c.spatial_dimension_to_split,
1971 activations_batch_dim, old_batch_size,
1972 /*low_padding=*/c.base_dilation_factor != 1 &&
1973 c.inherent_low_padding != 0
1974 ? c.base_dilation_factor - 1
1975 : c.inherent_low_padding,
1976 c.inherent_high_padding,
1977 slice_size - spatial_split_size,
1978 old_split_dim_size));
1979 }
1980
1981 // We will generate output such that batch is followed by the split spatial
1982 // dimension.
1983 const int64 rank = (convolution->shape().rank());
1984 std::vector<int64> transpose_dims(rank);
1985 int dim_count = 0;
1986 std::map<int64, int64> dim_map;
1987
1988 for (int j = 0;
1989 j < permuted_conv_dims_numbers.output_spatial_dimensions_size(); ++j) {
1990 if (j == get_chosen_spatial_dim(convolution)) {
1991 dim_map[permuted_conv_dims_numbers.output_batch_dimension()] = dim_count;
1992 new_dim_numbers.set_output_batch_dimension(dim_count++);
1993 }
1994 dim_map[permuted_conv_dims_numbers.output_spatial_dimensions(j)] =
1995 dim_count;
1996 new_dim_numbers.set_output_spatial_dimensions(j, dim_count);
1997 dim_count++;
1998 }
1999
2000 dim_map[permuted_conv_dims_numbers.output_feature_dimension()] = dim_count;
2001 new_dim_numbers.set_output_feature_dimension(dim_count);
2002
2003 int p = 0;
2004 for (const auto& entry : dim_map) {
2005 transpose_dims[p] = entry.second;
2006 p++;
2007 }
2008
2009 auto new_window = convolution->window();
2010 new_window.mutable_dimensions(get_chosen_spatial_dim(convolution))
2011 ->set_padding_high(c.high_padding_for_conv);
2012 new_window.mutable_dimensions(get_chosen_spatial_dim(convolution))
2013 ->set_padding_low(c.low_padding_for_conv);
2014 TF_ASSIGN_OR_RETURN(
2015 HloInstruction * new_conv,
2016 MakeConvolveHlo(
2017 activations_new, /*rhs=*/convolution->mutable_operand(1),
2018 convolution->feature_group_count(), convolution->batch_group_count(),
2019 new_window, new_dim_numbers, convolution->precision_config(),
2020 /*preferred_element_type=*/convolution->shape().element_type()));
2021 convolution->SetupDerivedInstruction(new_conv);
2022
2023 old_to_new_instrs_[convolution] = new_conv;
2024 VLOG(1) << "Space-to-batched convolution " << new_conv->ToString();
2025
2026 instr_to_dim_map_[convolution] =
2027 std::make_pair(original_conv_dims.output_batch_dimension(),
2028 original_conv_dims.output_spatial_dimensions(
2029 get_chosen_spatial_dim(convolution)));
2030
2031 instr_to_dim_permute_map_[new_conv] = std::vector<int64>(transpose_dims);
2032
2033 convs_to_visit_.erase(convolution);
2034 return Status::OK();
2035 }
2036
2037 StatusOr<HloInstruction*> ConvolutionVisitor::SplitSpaceHelper(
2038 HloInstruction* activations, int64 spatial_dimension_to_split,
2039 int64 activations_batch_dim, int64 high_padding, int64 low_padding,
2040 int64 spatial_split_size, int64 num_splits) {
2041 const int64 old_batch_size =
2042 activations->shape().dimensions(activations_batch_dim);
2043
2044   // Because we are splitting the spatial dimension, if the convolution
2045   // needed padding in the spatial dimension, we materialize it.
2046 if (high_padding || low_padding) {
2047 PaddingConfig padding_config =
2048 MakeNoPaddingConfig(activations->shape().dimensions_size());
2049 padding_config.mutable_dimensions(spatial_dimension_to_split)
2050 ->set_edge_padding_high(high_padding);
2051 padding_config.mutable_dimensions(spatial_dimension_to_split)
2052 ->set_edge_padding_low(low_padding);
2053 HloInstruction* padding =
2054 computation_->AddInstruction(HloInstruction::CreateConstant(
2055 LiteralUtil::Zero(activations->shape().element_type())));
2056 TF_ASSIGN_OR_RETURN(activations,
2057 MakePadHlo(activations, padding, padding_config));
2058 }
2059 VLOG(1) << "Initial padded activations shape "
2060 << activations->shape().ToString() << " old_batch_size "
2061 << old_batch_size << " activations_batch_dim "
2062 << activations_batch_dim;
2063
2064 // Now we reorganize the activations. E.g. if the shape [B, SPACE] was [1, 16]
2065 // and 4 splits were needed, we first create [4, 4]. Next, to deal with halo
2066 // in the spatial dimension, we generate a gather. E.g. if halo size was 2,
2067   // we'd create a shape of [24] using the gather, and reshape it into [4, 6]
2068   // (4 being the batch).
2069
2070   // The benefit of the above-mentioned scheme is that it allows for batch
2071 // growth. Here are some examples of the size increases it causes for a 3x3
2072 // kernel.
2073 // with batch=1, [1,16] -> [4,4] -> [4,6] -> [1,24] growth of 8.
2074 // with batch=2, [2,16] -> [8,4] -> [8,6] -> [1,48] growth of 16.
2075 // with batch=3, [3,16] -> [12,4] -> [12,6] -> [1,72] growth of 24.
2076
2077 std::vector<int64> reshape_dimensions(
2078 activations->shape().dimensions().begin(),
2079 activations->shape().dimensions().end());
2080
2081 reshape_dimensions[spatial_dimension_to_split] = spatial_split_size;
2082 reshape_dimensions[activations_batch_dim] = num_splits * old_batch_size;
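  // E.g. (hypothetical values): activations of [1, 16] with num_splits = 4
  // and spatial_split_size = 4 are reshaped to [4, 4] below, converting the
  // spatial splits into batch growth.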
2083
2084 TF_ASSIGN_OR_RETURN(HloInstruction * batch_increased_reshape,
2085 MakeReshapeHlo(reshape_dimensions, activations));
2086
2087 return batch_increased_reshape;
2088 }
2089
2090 StatusOr<std::pair<HloInstruction*, std::vector<int64>>>
2091 ConvolutionVisitor::SplitSpace(HloInstruction* activations,
2092 ConvolutionDimensionNumbers& dim_numbers,
2093 int64& spatial_dimension_to_split,
2094 int64& activations_batch_dim, int64 high_padding,
2095 int64 low_padding, int64 spatial_split_size,
2096 int64 num_splits, bool is_backprop,
2097 bool is_rhs) {
2098 TF_ASSIGN_OR_RETURN(auto retval,
2099 BringSpaceNextToBatch(
2100 activations, dim_numbers, spatial_dimension_to_split,
2101 activations_batch_dim, is_backprop, is_rhs));
2102
2103 activations = retval.instr;
2104 std::vector<int64> transpose_dims = retval.transpose_dims;
2105 TF_ASSIGN_OR_RETURN(
2106 auto new_activations,
2107 SplitSpaceHelper(activations, spatial_dimension_to_split,
2108 activations_batch_dim, high_padding, low_padding,
2109 spatial_split_size, num_splits));
2110 return std::make_pair(new_activations, transpose_dims);
2111 }
2112
2113 StatusOr<HloInstruction*> ConvolutionVisitor::PropagateOnConstant(
2114 HloInstruction* consumer, HloInstruction* producer) {
2115 CHECK(old_to_new_instrs_.contains(producer));
2116 HloInstruction* new_producer = old_to_new_instrs_[producer];
2117 auto prod_transpose_dims = instr_to_dim_permute_map_[new_producer];
2118 std::vector<int64> reversed_transpose_dims(prod_transpose_dims.size());
2119 for (int64 i = 0; i < prod_transpose_dims.size(); ++i) {
2120 reversed_transpose_dims[i] = ReverseDimLookUp(prod_transpose_dims, i);
2121 }
2122 // Bring space next to batch.
2123 TF_ASSIGN_OR_RETURN(consumer,
2124 MakeTransposeHlo(consumer, reversed_transpose_dims));
2125
2126 auto dim_map = instr_to_dim_map_[producer];
2127 const int64 old_batch_dim = dim_map.first;
2128 const int64 old_space_dim = dim_map.second;
2129 const int64 new_batch_dim = DimLookUp(prod_transpose_dims, old_batch_dim);
2130 const int64 new_space_dim = DimLookUp(prod_transpose_dims, old_space_dim);
2131
2132 const int64 old_batch_size = producer->shape().dimensions(old_batch_dim);
2133 const int64 new_batch_size = new_producer->shape().dimensions(new_batch_dim);
2134 const int64 high_padding =
2135 (new_batch_size * new_producer->shape().dimensions(new_space_dim) -
2136 old_batch_size * producer->shape().dimensions(old_space_dim)) /
2137 old_batch_size;
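  // Hypothetical example: if the producer was [2, 13] (batch, space) and its
  // space-to-batched form is [8, 4], then high_padding =
  // (8 * 4 - 2 * 13) / 2 = 3 elements of padding per old batch.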
2138
2139 auto new_consumer = SplitSpaceHelper(
2140 consumer, new_space_dim, new_batch_dim, high_padding, /*low_padding=*/0,
2141 new_producer->shape().dimensions(new_space_dim), kNumSplits);
2142
2143 return new_consumer;
2144 }
2145
2146 Status ConvolutionVisitor::PropagateOnBackpropFilterConv(
2147 HloInstruction* convolution) {
2148 auto activations_old = convolution->mutable_operand(0);
2149
2150 const int64 rhs_dilation =
2151 convolution->window()
2152 .dimensions(get_chosen_spatial_dim(convolution))
2153 .window_dilation();
2154
2155 auto original_conv_dims = convolution->convolution_dimension_numbers();
2156 int64 kernel_space_dim = original_conv_dims.kernel_spatial_dimensions(
2157 get_chosen_spatial_dim(convolution));
2158 auto kernel_old = convolution->mutable_operand(1);
2159 const int64 old_kernel_split_dim_size =
2160 kernel_old->shape().dimensions(kernel_space_dim);
2161
2162 int64 old_space_dim = original_conv_dims.input_spatial_dimensions(
2163 get_chosen_spatial_dim(convolution));
2164 int64 old_split_dim_size = activations_old->shape().dimensions(old_space_dim);
2165
2166 int64 old_batch_dim = original_conv_dims.input_feature_dimension();
2167 int64 kernel_old_batch_dim =
2168 original_conv_dims.kernel_input_feature_dimension();
2169 const int64 old_batch_size =
2170 activations_old->shape().dimensions(old_batch_dim);
2171
2172 CHECK(old_to_new_instrs_.contains(kernel_old) ||
2173 old_to_new_instrs_.contains(activations_old));
2174
2175 HloInstruction* activations_new = nullptr;
2176 HloInstruction* kernel_new = nullptr;
2177 bool activations_locally_space_to_batched = false;
2178 bool kernel_locally_space_to_batched = false;
2179 std::vector<int64> permute_dims_kernel, permute_dims;
2180   // If the activations were not space-to-batched, we space-to-batch them below.
2181 if (!old_to_new_instrs_.contains(activations_old)) {
2182 kernel_new = old_to_new_instrs_[kernel_old];
2183 permute_dims_kernel = instr_to_dim_permute_map_[kernel_new];
2184
2185 VLOG(1) << "Space-to-batching activations to enable space-to-depth";
2186 const int64 prev_feature_dim = original_conv_dims.input_feature_dimension();
2187 const int64 prev_batch_dim = original_conv_dims.input_batch_dimension();
2188 instr_to_dim_map_[activations_old] =
2189 std::make_pair(prev_feature_dim, prev_batch_dim);
2190
2191 const int64 new_kernel_space_dim =
2192 DimLookUp(permute_dims_kernel, kernel_space_dim);
2193
2194 const int64 new_kernel_split_dim_size =
2195 kernel_new->shape().dimensions(new_kernel_space_dim);
2196 const int64 needed_spatial_size = rhs_dilation * new_kernel_split_dim_size;
2197 const int64 pad_size =
2198 needed_spatial_size * kNumSplits - old_split_dim_size;
2199 ConvolutionDimensionNumbers tmp_dim_numbers;
2200 tmp_dim_numbers = original_conv_dims;
2201 TF_ASSIGN_OR_RETURN(
2202 auto retval,
2203 SplitSpace(activations_old, tmp_dim_numbers, old_space_dim,
2204 old_batch_dim,
2205 /*high_padding=*/pad_size, /*low_padding=*/0,
2206 needed_spatial_size, kNumSplits, /*is_backprop=*/true));
2207
2208 old_to_new_instrs_[activations_old] = retval.first;
2209
2210 std::vector<int64> reversed_transpose_dims(retval.second.size());
2211 for (int64 i = 0; i < retval.second.size(); ++i) {
2212 reversed_transpose_dims[i] = ReverseDimLookUp(retval.second, i);
2213 }
2214 instr_to_dim_permute_map_[retval.first] = reversed_transpose_dims;
2215
2216 VLOG(3) << "New Activations " << retval.first->ToString();
2217
2218 activations_locally_space_to_batched = true;
2219 } else if (!old_to_new_instrs_.contains(kernel_old)) {
2220 activations_new = old_to_new_instrs_[activations_old];
2221 permute_dims = instr_to_dim_permute_map_[activations_new];
2222
2223 VLOG(1) << "Space-to-batching kernel to enable space-to-depth";
2224 const int64 prev_feature_dim =
2225 original_conv_dims.kernel_input_feature_dimension();
2226 const int64 prev_output_feature_dim =
2227 original_conv_dims.kernel_output_feature_dimension();
2228 // TODO(b/168316428): The instr_to_dim_map_ is set incorrectly here, but it
2229 // doesn't matter since it is never used. Investigate further to see if just
2230 // not setting it works.
2231 instr_to_dim_map_[kernel_old] =
2232 std::make_pair(prev_feature_dim, prev_output_feature_dim);
2233
2234 const int64 new_space_dim = DimLookUp(permute_dims, old_space_dim);
2235 const int64 new_split_dim_size =
2236 activations_new->shape().dimensions(new_space_dim);
2237 const int64 needed_spatial_size =
2238 CeilOfRatio(new_split_dim_size, rhs_dilation);
2239 int64 old_kernel_split_dim_size =
2240 kernel_old->shape().dimensions(kernel_space_dim);
2241 const int64 pad_size =
2242 needed_spatial_size * kNumSplits - old_kernel_split_dim_size;
2243
2244 ConvolutionDimensionNumbers tmp_dim_numbers;
2245 tmp_dim_numbers = original_conv_dims;
2246 TF_ASSIGN_OR_RETURN(
2247 auto retval, SplitSpace(kernel_old, tmp_dim_numbers, kernel_space_dim,
2248 kernel_old_batch_dim,
2249 /*high_padding=*/pad_size, /*low_padding=*/0,
2250 needed_spatial_size, kNumSplits,
2251 /*is_backprop=*/true, /*is_rhs=*/true));
2252
2253 old_to_new_instrs_[kernel_old] = retval.first;
2254
2255 std::vector<int64> reversed_transpose_dims(retval.second.size());
2256 for (int64 i = 0; i < retval.second.size(); ++i) {
2257 reversed_transpose_dims[i] = ReverseDimLookUp(retval.second, i);
2258 }
2259 instr_to_dim_permute_map_[retval.first] = reversed_transpose_dims;
2260
2261 VLOG(3) << "New kernel " << retval.first->ToString();
2262
2263 kernel_locally_space_to_batched = true;
2264 }
2265
2266 CHECK(old_to_new_instrs_.contains(activations_old));
2267 CHECK(old_to_new_instrs_.contains(kernel_old));
2268 activations_new = old_to_new_instrs_[activations_old];
2269 kernel_new = old_to_new_instrs_[kernel_old];
2270 const int64 new_spatial_dimension =
2271 activations_new->shape().dimensions_size();
2272
2273 permute_dims = instr_to_dim_permute_map_[activations_new];
2274 permute_dims_kernel = instr_to_dim_permute_map_[kernel_new];
2275
2276 auto permuted_conv_dims_numbers = original_conv_dims;
2277
2278   // Note the inversion here: batch and feature are inverted in backprop
2279 // filters.
2280 int64 activations_batch_dim =
2281 DimLookUp(permute_dims, original_conv_dims.input_feature_dimension());
2282 int64 activations_feature_dim =
2283 DimLookUp(permute_dims, original_conv_dims.input_batch_dimension());
2284
2285 const int64 previous_spatial_dim_count =
2286 original_conv_dims.input_spatial_dimensions_size();
2287 for (int64 i = 0; i < previous_spatial_dim_count; ++i) {
2288 permuted_conv_dims_numbers.set_input_spatial_dimensions(
2289 i, DimLookUp(permute_dims,
2290 original_conv_dims.input_spatial_dimensions(i)));
2291 permuted_conv_dims_numbers.set_kernel_spatial_dimensions(
2292 i, DimLookUp(permute_dims_kernel,
2293 original_conv_dims.kernel_spatial_dimensions(i)));
2294 }
2295
2296 permuted_conv_dims_numbers.add_input_spatial_dimensions(
2297 new_spatial_dimension);
2298 permuted_conv_dims_numbers.add_kernel_spatial_dimensions(
2299 new_spatial_dimension);
2300 permuted_conv_dims_numbers.add_output_spatial_dimensions(
2301 new_spatial_dimension);
2302
2303 // For the output, make the last dimension size 1.
2304 const int64 previous_chosen_spatial_dim_in_output =
2305 permuted_conv_dims_numbers.output_spatial_dimensions(
2306 get_chosen_spatial_dim(convolution));
2307 permuted_conv_dims_numbers.set_output_spatial_dimensions(
2308 get_chosen_spatial_dim(convolution), new_spatial_dimension);
2309 permuted_conv_dims_numbers.set_output_spatial_dimensions(
2310 previous_spatial_dim_count, previous_chosen_spatial_dim_in_output);
2311
2312 const int64 kernel_input_feature_dim = DimLookUp(
2313 permute_dims_kernel, original_conv_dims.kernel_input_feature_dimension());
2314
2315 const int64 kernel_output_feature_dim =
2316 DimLookUp(permute_dims_kernel,
2317 original_conv_dims.kernel_output_feature_dimension());
2318
2319 permuted_conv_dims_numbers.set_kernel_input_feature_dimension(
2320 kernel_input_feature_dim);
2321 permuted_conv_dims_numbers.set_kernel_output_feature_dimension(
2322 kernel_output_feature_dim);
2323
2324 int64 spatial_dimension_to_split =
2325 permuted_conv_dims_numbers.input_spatial_dimensions(
2326 get_chosen_spatial_dim(convolution));
2327
2328 const int64 kernel_spatial_dimension_to_split =
2329 permuted_conv_dims_numbers.kernel_spatial_dimensions(
2330 get_chosen_spatial_dim(convolution));
2331
2332 int64 new_split_dim_size =
2333 activations_new->shape().dimensions(spatial_dimension_to_split);
2334
2335 const int64 kernel_new_split_dim_size =
2336 kernel_new->shape().dimensions(kernel_spatial_dimension_to_split);
2337
2338 permuted_conv_dims_numbers.set_input_batch_dimension(activations_feature_dim);
2339 permuted_conv_dims_numbers.set_input_feature_dimension(activations_batch_dim);
2340
2341 VLOG(1) << "Propagating on conv activations_batch_dim "
2342 << activations_batch_dim << " spatial_dimension_to_split "
2343 << spatial_dimension_to_split << " old_batch_size " << old_batch_size
2344 << " new_split_dim_size " << new_split_dim_size;
2345
2346 TF_ASSIGN_OR_RETURN(
2347 auto retval,
2348 BringSpaceNextToBatch(activations_new, permuted_conv_dims_numbers,
2349 spatial_dimension_to_split, activations_batch_dim,
2350 /*is_backprop=*/true));
2351
2352 std::vector<int64> transpose_dims = retval.transpose_dims;
2353 CHECK(!transpose_dims.empty());
2354 activations_new = retval.instr;
2355
2356 VLOG(1) << "Activations_new post BringSpaceNextToBatch "
2357 << activations_new->ToString();
2358 VLOG(1) << "activations_batch_dim " << activations_batch_dim
2359 << " activations_feature_dim " << activations_feature_dim;
2360 const int64 expected_split_dim_size =
2361 rhs_dilation * kernel_new_split_dim_size;
2362 if (new_split_dim_size != expected_split_dim_size) {
2363 CHECK_LT(new_split_dim_size, expected_split_dim_size);
2364 new_split_dim_size = expected_split_dim_size;
2365 TF_ASSIGN_OR_RETURN(
2366 activations_new,
2367 IncreaseSpatialSizeOnSpaceToBatchedShape(
2368 activations_new, activations_batch_dim, old_batch_size,
2369 spatial_dimension_to_split, new_split_dim_size));
2370 }
2371
2372 auto select_val = computation_->AddInstruction(HloInstruction::CreateConstant(
2373 LiteralUtil::Zero(activations_new->shape().element_type())));
2374
2375 if (!activations_locally_space_to_batched) {
2376 // Select activations correctly by masking additional space.
2377 TF_ASSIGN_OR_RETURN(
2378 activations_new,
2379 SelectValidPortion(activations_new, activations_old, select_val,
2380 activations_batch_dim, spatial_dimension_to_split,
2381 old_batch_dim, old_space_dim));
2382 }
2383 if (!kernel_locally_space_to_batched) {
2384 VLOG(3) << "Selecting the valid kernel area";
2385 // Select kernel correctly by masking additional space.
2386 TF_ASSIGN_OR_RETURN(
2387 kernel_new,
2388 SelectValidPortion(kernel_new, kernel_old, select_val,
2389 /*new_batch_dim=*/kernel_input_feature_dim,
2390 kernel_spatial_dimension_to_split,
2391 /*old_batch_dim=*/
2392 original_conv_dims.kernel_input_feature_dimension(),
2393 kernel_space_dim));
2394 }
2395
2396 // Create the new convolution dim numbers.
2397 auto new_dim_numbers = permuted_conv_dims_numbers;
2398
2399 VLOG(2) << "New dim numbers " << new_dim_numbers.DebugString();
2400
2401 const int64 inherent_low_padding =
2402 convolution->window()
2403 .dimensions(get_chosen_spatial_dim(convolution))
2404 .padding_low();
2405
2406 const int64 inherent_high_padding =
2407 convolution->window()
2408 .dimensions(get_chosen_spatial_dim(convolution))
2409 .padding_high();
2410
2411 std::vector<HloInstruction*> activations_chunks;
2412
2413 // Insert slices for low padding.
2414 for (int64 i = 0; i < inherent_low_padding; ++i) {
2415 HloInstruction* activations_to_use = nullptr;
2416 if (i == 0) {
2417 activations_to_use = activations_new;
2418 } else {
2419 activations_to_use = activations_chunks.back();
2420 }
2421 TF_ASSIGN_OR_RETURN(
2422 HloInstruction * activations_slice,
2423 HaloDuplicateWithSlice(activations_to_use, spatial_dimension_to_split,
2424 activations_batch_dim, old_batch_size,
2425 /*low_padding=*/1,
2426 /*high_padding=*/0,
2427 /*halo_size=*/0, old_split_dim_size));
2428 activations_chunks.push_back(activations_slice);
2429 }
2430 // Reverse the low padding slices because we created them in the opposite
2431 // order above.
2432 absl::c_reverse(activations_chunks);
2433
2434 const int64 expanded_kernel =
2435 old_kernel_split_dim_size * rhs_dilation - (rhs_dilation - 1);
2436 const int64 overlap_count =
2437 old_split_dim_size - expanded_kernel + 1 +
2438 (inherent_low_padding < 0 ? inherent_low_padding : 0) +
2439 (inherent_high_padding < 0 ? inherent_high_padding : 0);
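  // Worked example (hypothetical values): a kernel with
  // old_kernel_split_dim_size = 3 and rhs_dilation = 2 expands to
  // 3 * 2 - 1 = 5; with old_split_dim_size = 16 and no negative padding,
  // overlap_count = 16 - 5 + 1 = 12.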
2440 VLOG(1) << "overlap_count " << overlap_count << " inherent_low_padding "
2441 << inherent_low_padding << " inherent_high_padding "
2442 << inherent_high_padding;
2443
2444 // Insert original activations.
2445 for (int64 i = 0; i < overlap_count; ++i) {
2446 HloInstruction* activations_to_use = nullptr;
2447 HloInstruction* activations_slice = nullptr;
2448 if (i == 0) {
2449 activations_to_use = activations_new;
2450 if (inherent_low_padding < 0) {
2451 TF_ASSIGN_OR_RETURN(activations_slice,
2452 HaloDuplicateWithSlice(
2453 activations_to_use, spatial_dimension_to_split,
2454 activations_batch_dim, old_batch_size,
2455 /*low_padding=*/inherent_low_padding,
2456 /*high_padding=*/0,
2457 /*halo_size=*/0, old_split_dim_size));
2458 } else {
2459 activations_slice = activations_to_use;
2460 }
2461 } else {
2462 activations_to_use = activations_chunks.back();
2463
2464 TF_ASSIGN_OR_RETURN(
2465 activations_slice,
2466 HaloDuplicateWithSlice(activations_to_use, spatial_dimension_to_split,
2467 activations_batch_dim, old_batch_size,
2468 /*low_padding=*/-1,
2469 /*high_padding=*/0,
2470 /*halo_size=*/0, old_split_dim_size));
2471 }
2472
2473 activations_chunks.push_back(activations_slice);
2474 }
2475
2476 // Insert slices for high padding.
2477 for (int64 i = 0; i < inherent_high_padding; ++i) {
2478 HloInstruction* activations_to_use = nullptr;
2479 activations_to_use = activations_chunks.back();
2480
2481 TF_ASSIGN_OR_RETURN(
2482 HloInstruction * activations_slice,
2483 HaloDuplicateWithSlice(activations_to_use, spatial_dimension_to_split,
2484 activations_batch_dim, old_batch_size,
2485 /*low_padding=*/-1, /*high_padding=*/0,
2486 /*halo_size=*/0, old_split_dim_size));
2487 activations_chunks.push_back(activations_slice);
2488 }
2489
2490 for (int64 i = 0; i < activations_chunks.size(); ++i) {
2491 std::vector<int64> input_sizes(
2492 activations_chunks[i]->shape().dimensions().begin(),
2493 activations_chunks[i]->shape().dimensions().end());
2494     // Insert a 1-sized dimension at the end.
2495 input_sizes.push_back(1);
2496 TF_ASSIGN_OR_RETURN(activations_chunks[i],
2497 MakeReshapeHlo(input_sizes, activations_chunks[i]));
2498 }
2499
2500 TF_ASSIGN_OR_RETURN(
2501 activations_new,
2502 MakeConcatHlo(absl::MakeSpan(activations_chunks), new_spatial_dimension));
2503
2504   // Reshape the kernel with an additional spatial dim.
2505 std::vector<int64> kernel_sizes(kernel_new->shape().dimensions().begin(),
2506 kernel_new->shape().dimensions().end());
2507   // Insert a 1-sized dimension at the end.
2508 kernel_sizes.push_back(1);
2509 TF_ASSIGN_OR_RETURN(kernel_new, MakeReshapeHlo(kernel_sizes, kernel_new));
2510
2511 auto new_window = convolution->window();
2512 new_window.mutable_dimensions(get_chosen_spatial_dim(convolution))
2513 ->set_padding_high(-(rhs_dilation - 1));
2514 new_window.mutable_dimensions(get_chosen_spatial_dim(convolution))
2515 ->set_padding_low(0);
2516 new_window.mutable_dimensions(get_chosen_spatial_dim(convolution))
2517 ->set_size(CeilOfRatio(new_split_dim_size, rhs_dilation));
2518
2519 // Set the window for the additional spatial dim. This is a vanilla window.
2520 auto window_dim = new_window.add_dimensions();
2521 window_dim->set_base_dilation(1);
2522 window_dim->set_size(1);
2523 window_dim->set_stride(1);
2524 window_dim->set_padding_low(0);
2525 window_dim->set_padding_high(0);
2526 window_dim->set_window_reversal(false);
2527 window_dim->set_window_dilation(1);
2528
2529 TF_ASSIGN_OR_RETURN(
2530 HloInstruction * new_conv,
2531 MakeConvolveHlo(
2532 activations_new, kernel_new, convolution->feature_group_count(),
2533 convolution->batch_group_count(), new_window, new_dim_numbers,
2534 convolution->precision_config(),
2535 /*preferred_element_type=*/convolution->shape().element_type()));
2536 convolution->SetupDerivedInstruction(new_conv);
2537
2538 VLOG(2) << "New backprop filter convolution " << new_conv->ToString();
2539
2540 std::vector<int64> output_sizes(new_conv->shape().dimensions().begin(),
2541 new_conv->shape().dimensions().end());
2542
2543 output_sizes.erase(output_sizes.begin() +
2544 new_dim_numbers.output_spatial_dimensions(
2545 get_chosen_spatial_dim(convolution)));
2546
2547 TF_ASSIGN_OR_RETURN(new_conv, MakeReshapeHlo(output_sizes, new_conv));
2548
2549 old_to_new_instrs_[convolution] = new_conv;
2550 VLOG(1) << "Space-to-featured convolution " << new_conv->ToString();
2551
2552 instr_to_dim_map_[convolution] =
2553 std::make_pair(original_conv_dims.output_batch_dimension(),
2554 original_conv_dims.output_spatial_dimensions(
2555 get_chosen_spatial_dim(convolution)));
2556
2557 std::vector<int64> trans_dims(convolution->shape().dimensions_size());
2558 absl::c_iota(trans_dims, 0);
2559 instr_to_dim_permute_map_[new_conv] = trans_dims;
2560
2561 return Status::OK();
2562 }
2563
2564 HloInstruction*
2565 ConvolutionVisitor::DoesConvolutionFeedReduceWindowOrSelectAndScatter(
2566 HloInstruction* instr, int64 depth = kReduceWindowSearchDepth) {
2567 if (depth == 0) {
2568 return nullptr;
2569 }
2570
2571 for (auto user : instr->users()) {
2572 if (user->opcode() == HloOpcode::kReduceWindow ||
2573 user->opcode() == HloOpcode::kSelectAndScatter) {
2574 return user;
2575 }
2576 // Stop the search if these ops are encountered.
2577 if (user->opcode() == HloOpcode::kConvolution ||
2578 user->opcode() == HloOpcode::kPad ||
2579 user->opcode() == HloOpcode::kTranspose) {
2580 continue;
2581 }
2582 auto ret =
2583 DoesConvolutionFeedReduceWindowOrSelectAndScatter(user, depth - 1);
2584 if (ret != nullptr) {
2585 return ret;
2586 }
2587 }
2588 return nullptr;
2589 }
2590
2591 ConvolutionVisitor::ConvDetails ConvolutionVisitor::GetConvolutionDetails(
2592 HloInstruction* convolution, ConvolutionDimensionNumbers& dim_numbers) {
2593 auto activations = convolution->mutable_operand(0);
2594
2595 auto kernel = convolution->mutable_operand(1);
2596 const auto& kernel_shape = kernel->shape();
2597 const int64 kernel_spatial_dim_size =
2598 kernel_shape.dimensions(dim_numbers.kernel_spatial_dimensions(
2599 get_chosen_spatial_dim(convolution)));
2600
2601 const int64 spatial_dimension_to_split =
2602 dim_numbers.input_spatial_dimensions(get_chosen_spatial_dim(convolution));
2603
2604 const int64 input_dim_size =
2605 activations->shape().dimensions(spatial_dimension_to_split);
2606
2607 const int64 inherent_low_padding =
2608 convolution->window()
2609 .dimensions(get_chosen_spatial_dim(convolution))
2610 .padding_low();
2611 const int64 inherent_high_padding =
2612 convolution->window()
2613 .dimensions(get_chosen_spatial_dim(convolution))
2614 .padding_high();
2615
2616 const int64 stride = convolution->window()
2617 .dimensions(get_chosen_spatial_dim(convolution))
2618 .stride();
2619
2620 const int64 base_dilation_factor =
2621 convolution->window()
2622 .dimensions(get_chosen_spatial_dim(convolution))
2623 .base_dilation();
2624
2625 const int64 spatial_size =
2626 input_dim_size + (base_dilation_factor > 1 ? 0 : inherent_low_padding) +
2627 inherent_high_padding;
2628
2629 const int64 halo_size =
2630 std::max(kernel_spatial_dim_size - stride - (base_dilation_factor - 1),
2631 static_cast<int64>(0));
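  // For instance (hypothetical values): a 3-wide kernel with stride 1 and no
  // base dilation gives halo_size = max(3 - 1 - 0, 0) = 2, the number of
  // neighboring elements each split needs for its boundary outputs.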
2632 const int64 high_padding_for_conv = base_dilation_factor == 1 ? 0
2633 : inherent_low_padding == 0
2634 ? base_dilation_factor - 1
2635 : 0;
2636 const int64 low_padding_for_conv =
2637 base_dilation_factor == 1 ? 0 : inherent_low_padding;
2638
2639 return ConvDetails{spatial_dimension_to_split,
2640 inherent_low_padding,
2641 inherent_high_padding,
2642 stride,
2643 spatial_size,
2644 base_dilation_factor,
2645 halo_size,
2646 high_padding_for_conv,
2647 low_padding_for_conv,
2648 kernel_spatial_dim_size,
2649 input_dim_size};
2650 }
2651
2652 Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution(
2653 HloInstruction* convolution) {
2654 VLOG(1) << "Handling conv " << convolution->ToString();
2655
2656 changed_ = false;
2657
2658 ConvolutionDimensionNumbers dim_numbers =
2659 convolution->convolution_dimension_numbers();
2660
2661 ConvDetails c = GetConvolutionDetails(convolution, dim_numbers);
2662
2663 int64 activations_batch_dim = dim_numbers.input_batch_dimension();
2664
2665 const int64 old_batch_size =
2666 convolution->operand(0)->shape().dimensions(activations_batch_dim);
2667
2668 auto activations = convolution->mutable_operand(0);
2669
2670 VLOG(1) << "spatial size " << c.spatial_size;
2671
2672 auto original_conv = convolution;
2673
2674 const int64 output_spatial_dim = dim_numbers.output_spatial_dimensions(
2675 get_chosen_spatial_dim(convolution));
2676 const int64 output_offsets =
2677 convolution->shape().dimensions(output_spatial_dim);
2678 const int64 output_offsets_per_split =
2679 CeilOfRatio(output_offsets, kNumSplits);
2680
2681 int64 spatial_split_size =
2682 CeilOfRatio(output_offsets_per_split, c.base_dilation_factor) * c.stride;
2683   // Keep increasing the split size so that the overall size isn't smaller
2684   // than the original spatial dimension.
2685 while (spatial_split_size * kNumSplits - c.spatial_size < 0) {
2686 spatial_split_size += c.stride;
2687 }
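  // Hypothetical example (assuming kNumSplits were 8): with
  // output_offsets = 28, stride 1 and no base dilation, spatial_split_size
  // starts at CeilOfRatio(28, 8) * 1 = 4; if spatial_size were 33, then
  // 4 * 8 = 32 < 33, so one bump by the stride (to 5, covering 40) exits the
  // loop.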
2688
2689 auto reduce_window_or_select_and_scatter =
2690 DoesConvolutionFeedReduceWindowOrSelectAndScatter(convolution);
2691
2692 if (reduce_window_or_select_and_scatter != nullptr) {
2693 VLOG(2)
2694         << "DoesConvolutionFeedReduceWindowOrSelectAndScatter returned a user";
2695 // Take into account the stride of the reduce window while choosing the
2696 // spatial_split_size. This will guarantee propagation through reduce
2697 // windows.
2698 const int64 win_stride = reduce_window_or_select_and_scatter->window()
2699 .dimensions(output_spatial_dim)
2700 .stride();
2701 while ((spatial_split_size / c.stride) % win_stride != 0) {
2702 spatial_split_size += c.stride;
2703 }
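    // E.g. (hypothetical): with c.stride = 1 and a reduce-window stride of 2,
    // spatial_split_size is bumped until (spatial_split_size / 1) % 2 == 0,
    // so each split emits a whole number of reduce-window positions.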
2704 }
2705
2706 const int64 slice_size = spatial_split_size + c.halo_size;
2707
2708 // Pad spatial dim.
2709 const int64 pad_size = spatial_split_size * kNumSplits - c.spatial_size;
2710
2711 VLOG(1) << "spatial_split_size " << spatial_split_size << " stride "
2712 << c.stride << " slice_size " << slice_size;
2713 VLOG(1) << "spatial_dimension_to_split " << c.spatial_dimension_to_split
2714 << " num_splits " << kNumSplits << " kernel_spatial_dim_size "
2715 << c.kernel_spatial_dim_size;
2716 int64 spatial_dimension_to_split = c.spatial_dimension_to_split;
2717 TF_ASSIGN_OR_RETURN(
2718 auto retval,
2719 SplitSpace(activations, dim_numbers, spatial_dimension_to_split,
2720 activations_batch_dim,
2721 /*high_padding=*/c.inherent_high_padding + pad_size,
2722 /*low_padding=*/c.base_dilation_factor == 1
2723 ? c.inherent_low_padding
2724 : 0,
2725 spatial_split_size, kNumSplits));
2726 HloInstruction* batch_increased_reshape = retval.first;
2727 convolution->SetupDerivedInstruction(batch_increased_reshape);
2728
2729 VLOG(1) << "First reshape done " << batch_increased_reshape->ToString();
2730
2731 TF_ASSIGN_OR_RETURN(
2732 activations, HaloDuplicateWithSlice(batch_increased_reshape,
2733 spatial_dimension_to_split,
2734 activations_batch_dim, old_batch_size,
2735 /*low_padding=*/0, /*high_padding=*/0,
2736 c.halo_size, c.input_dim_size));
2737
2738 VLOG(1) << "Batch merge done " << activations->ToString();
2739
2740 // Now, we rewrite the convolution with a larger batch.
2741
2742 // Create the new convolution dim numbers.
2743 auto new_dim_numbers = dim_numbers;
2744
2745 // We will generate output such that batch is followed by the split spatial
2746 // dimension.
2747 const int64 rank = convolution->shape().rank();
2748 std::vector<int64> transpose_dims(rank);
2749 int dim_count = 0;
2750 std::map<int64, int64> dim_map;
2751
2752 for (int j = 0; j < dim_numbers.output_spatial_dimensions_size(); ++j) {
2753 if (j == get_chosen_spatial_dim(convolution)) {
2754 dim_map[dim_numbers.output_batch_dimension()] = dim_count;
2755 new_dim_numbers.set_output_batch_dimension(dim_count++);
2756 }
2757 dim_map[dim_numbers.output_spatial_dimensions(j)] = dim_count;
2758 new_dim_numbers.set_output_spatial_dimensions(j, dim_count);
2759 dim_count++;
2760 }
2761
2762 dim_map[dim_numbers.output_feature_dimension()] = dim_count;
2763 new_dim_numbers.set_output_feature_dimension(dim_count);
2764
2765 int p = 0;
2766 for (const auto& entry : dim_map) {
2767 transpose_dims[p] = entry.second;
2768 p++;
2769 }
2770 VLOG(1) << "New dim numbers " << new_dim_numbers.DebugString()
2771 << " batch dim " << new_dim_numbers.input_batch_dimension();
2772 auto new_window = convolution->window();
2773 new_window.mutable_dimensions(get_chosen_spatial_dim(convolution))
2774 ->set_padding_high(c.high_padding_for_conv);
2775 new_window.mutable_dimensions(get_chosen_spatial_dim(convolution))
2776 ->set_padding_low(c.low_padding_for_conv);
2777 TF_ASSIGN_OR_RETURN(
2778 HloInstruction * new_conv,
2779 MakeConvolveHlo(
2780 activations, /*rhs=*/convolution->mutable_operand(1),
2781 convolution->feature_group_count(), convolution->batch_group_count(),
2782 new_window, new_dim_numbers, convolution->precision_config(),
2783 /*preferred_element_type=*/convolution->shape().element_type()));
2784 convolution->SetupDerivedInstruction(new_conv);
2785
2786 // If the activations were to be batch-to-spaced again, simply use the
2787 // original value.
2788 batch_to_space_map_[convolution->mutable_operand(0)] =
2789 convolution->mutable_operand(0);
2790
2791 VLOG(1) << "Space-to-batched convolution " << new_conv->ToString();
2792
2793 const int64 output_split_spatial_dim =
2794 new_dim_numbers.output_spatial_dimensions(
2795 get_chosen_spatial_dim(convolution));
2796 const int64 output_batch_dim = new_dim_numbers.output_batch_dimension();
2797 VLOG(1) << "output_batch_dim " << output_batch_dim
2798 << " output_split_spatial_dim " << output_split_spatial_dim;
2799
2800 auto select_val = computation_->AddInstruction(HloInstruction::CreateConstant(
2801 LiteralUtil::Zero(new_conv->shape().element_type())));
2802
2803 TF_ASSIGN_OR_RETURN(
2804 new_conv, SelectValidPortion(new_conv, original_conv, select_val,
2805 output_batch_dim, output_split_spatial_dim,
2806 dim_numbers.output_batch_dimension(),
2807 dim_numbers.output_spatial_dimensions(
2808 get_chosen_spatial_dim(original_conv))));
2809 old_to_new_instrs_[original_conv] = new_conv;
2810
2811 instr_to_dim_map_[original_conv] =
2812 std::make_pair(dim_numbers.output_batch_dimension(),
2813 dim_numbers.output_spatial_dimensions(
2814 get_chosen_spatial_dim(original_conv)));
2815
2816 instr_to_dim_permute_map_[new_conv] = std::vector<int64>(transpose_dims);
2817 if (non_propagatable_instrs_.count(convolution) > 0) {
2818 non_propagatable_instrs_.erase(convolution);
2819 }
2820 TF_CHECK_OK(PropagateOnUsers(original_conv));
2821
2822 changed_ = true;
2823
2824 return Status::OK();
2825 }
2826
2827 } // namespace
2828
2829 StatusOr<bool> ConvolutionSpaceToBatchConverter::Run(HloModule* module) {
2830 XLA_VLOG_LINES(2, "ConvolutionSpaceToBatchConverter::Run(), before:\n" +
2831 module->ToString());
2832 bool changed = false;
2833
2834 for (auto* comp : module->MakeNonfusionComputations()) {
2835 ConvolutionVisitor visitor(limit_on_batch_size_, comp);
2836 if (visitor.Run().ValueOrDie()) {
2837 changed = true;
2838 }
2839 VLOG(1) << "Done operating on computation";
2840 }
2841 XLA_VLOG_LINES(2, "ConvolutionSpaceToBatchConverter::Run(), after:\n" +
2842 module->ToString());
2843 return changed;
2844 }
2845
2846 } // namespace xla
2847