/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 #include "tensorflow/compiler/xla/service/space_to_batch_converter.h"
16 
17 #include <algorithm>
18 #include <cstddef>
19 #include <iterator>
20 #include <map>
21 #include <memory>
22 #include <queue>
23 #include <tuple>
24 #include <unordered_set>
25 #include <utility>
26 #include <vector>
27 
28 #include "absl/algorithm/algorithm.h"
29 #include "absl/algorithm/container.h"
30 #include "absl/container/flat_hash_map.h"
31 #include "absl/container/flat_hash_set.h"
32 #include "absl/memory/memory.h"
33 #include "absl/types/span.h"
34 #include "tensorflow/compiler/xla/literal.h"
35 #include "tensorflow/compiler/xla/literal_util.h"
36 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
37 #include "tensorflow/compiler/xla/service/hlo_computation.h"
38 #include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
39 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
40 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
41 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
42 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
43 #include "tensorflow/compiler/xla/service/shape_inference.h"
44 #include "tensorflow/compiler/xla/shape_util.h"
45 #include "tensorflow/compiler/xla/status_macros.h"
46 #include "tensorflow/compiler/xla/types.h"
47 #include "tensorflow/compiler/xla/util.h"
48 #include "tensorflow/compiler/xla/xla_data.pb.h"
49 #include "tensorflow/core/lib/core/bitmap.h"
50 #include "tensorflow/core/lib/core/errors.h"
51 #include "tensorflow/core/lib/core/status.h"
52 #include "tensorflow/core/platform/logging.h"
53 #include "tensorflow/stream_executor/lib/statusor.h"
54 
55 namespace xla {
56 
57 namespace {
58 
59 namespace m = match;
60 
61 // ConvolutionVisitor traverses the HLO computation and rewrites Convolution
62 // operations with small batch counts into convolutions with larger batch
63 // counts by moving space to batch.
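// For example, with the default kNumSplits of 8, a convolution whose
// activations have batch 1 and a chosen spatial dimension of size 256 is
// rewritten so that the new activations have batch 8 and a spatial size of
// roughly 256/8 = 32 (plus halo elements duplicated across the splits so
// that each split still sees the kernel footprint it needs).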
class ConvolutionVisitor {
 public:
  // Top-level function to begin space-to-batch conversion.
  Status PerformSpaceToBatchOnConvolution(HloInstruction* convolution);

  // Struct containing details about a convolution.
  struct ConvDetails {
    int64 spatial_dimension_to_split, inherent_low_padding,
        inherent_high_padding, stride, spatial_size, base_dilation_factor,
        halo_size, high_padding_for_conv, low_padding_for_conv,
        kernel_spatial_dim_size, input_dim_size;
  };

  // Returns a struct containing the pieces of information needed to perform
  // space-to-batch on a convolution.
  ConvDetails GetConvolutionDetails(HloInstruction* convolution,
                                    ConvolutionDimensionNumbers& dim_numbers);

  // Function that determines if space-to-batch can be propagated into the
  // consumer. Such propagation is only possible when all required operands are
  // space-to-batch'ed.
  bool CanPropagate(HloInstruction* consumer, HloInstruction* producer,
                    bool last_try = false);

  // Returns true if all of op's direct and indirect operands are created via
  // broadcasts. Consumer uses op, and is space-to-batched.
  // instructions_to_transform is populated with the instructions of the
  // broadcast tree in reverse post order.
  bool IsBroadcastTree(HloInstruction* op, HloInstruction* consumer,
                       std::vector<HloInstruction*>& instructions_to_transform);

  // Replicates the broadcast tree with space-to-batched instructions.
  void RewriteBroadcastTree(
      HloInstruction* producer,
      std::vector<HloInstruction*>& instructions_to_transform);

  // Propagate space-to-batch on a broadcast instruction.
  void PropagateOnBroadcast(HloInstruction* consumer, HloInstruction* producer);

  // This function checks if the HLO instruction supports propagation.
  bool SupportedOpForPropagation(HloInstruction* consumer,
                                 HloInstruction* producer);

  // Method that checks validity of Broadcast propagation.
  bool IsBroadcastPropagatable(HloInstruction* broadcast,
                               HloInstruction* old_other_op);

  // Propagates space-to-batch on the op, and returns a bool that indicates if
  // the users of the op need to be propagated through.
  StatusOr<bool> Propagate(HloInstruction* consumer, HloInstruction* producer);

  // Splits the given spatial dimension on the activations and returns the
  // new instructions, and the dimension permutation of the new shape.
  StatusOr<std::pair<HloInstruction*, std::vector<int64>>> SplitSpace(
      HloInstruction* activations, ConvolutionDimensionNumbers& dim_numbers,
      int64& spatial_dimension_to_split, int64& activations_batch_dim,
      int64 high_padding, int64 low_padding, int64 spatial_split_size,
      int64 num_splits, bool is_backprop = false, bool is_rhs = false);
121 
122   // Helper function for the SplitSpace function above. Handles padding and
123   // reshaping to generate space-to-batched shape.
124   StatusOr<HloInstruction*> SplitSpaceHelper(
125       HloInstruction* activations, int64 spatial_dimension_to_split,
126       int64 activations_batch_dim, int64 high_padding, int64 low_padding,
127       int64 spatial_split_size, int64 num_splits);
128 
129   // Perform space-to-batch propagation on constants.
130   StatusOr<HloInstruction*> PropagateOnConstant(HloInstruction* consumer,
131                                                 HloInstruction* producer);
132 
133   // Perform space-to-batch propagation on the convolution. Assumes the
134   // activations were already space-to-batched.
135   Status PropagateOnConv(HloInstruction* convolution);
136 
137   // Perform space-to-batch propagation on the backprop filter convolution.
138   // Assumes the activations and kernel were already space-to-batched.
139   Status PropagateOnBackpropFilterConv(HloInstruction* convolution);
140 
141   // Method that checks validity of space-to-batch on a given convolution.
142   bool IsConvSuitableForSpaceToBatch(HloInstruction* convolution);
143 
144   // Once a convolution has been space-to-batch'ed, this function will
145   // transitively propagate the space-to-batch-ness on rest of the graph.
146   Status PropagateOnUsers(HloInstruction* old_conv);
147 
148   // Generates masked output with valid data. This is useful when larger shapes
149   // are generated due to space-to-batch.
150   StatusOr<HloInstruction*> SelectValidPortion(
151       HloInstruction* new_instr, HloInstruction* old_instr,
152       HloInstruction* select_val, int64 new_batch_dim, int64 new_space_dim,
153       int64 old_batch_dim, int64 old_space_dim);
154 
155   struct SpaceNextToBatchDetails {
156     HloInstruction* instr;
157     std::vector<int64> transpose_dims;
158   };
159 
160   // Performs tranposition so that space dimension follows the batch dimension.
161   StatusOr<SpaceNextToBatchDetails> BringSpaceNextToBatch(
162       HloInstruction* activations, ConvolutionDimensionNumbers& dim_numbers,
163       int64& spatial_dimension_to_split, int64& activations_batch_dim,
164       bool is_backprop = false, bool is_rhs = false);
165 
166   // Increases the spatial dimension size in an already space-to-batched shape
167   // so that the new size is new_spatial_dim_size.
168   StatusOr<HloInstruction*> IncreaseSpatialSizeOnSpaceToBatchedShape(
169       HloInstruction* activations, int64 batch_dimension, int64 old_batch_size,
170       int64 spatial_dimension, int64 new_spatial_dim_size);
171 
172   // Function that converts spaced-to-batch shape back to the original.
173   StatusOr<HloInstruction*> BatchToSpace(HloInstruction* old_instr);
174 
175   // Duplicates elements at boundaries.
176   StatusOr<HloInstruction*> HaloDuplicateWithSlice(
177       HloInstruction* activations, int64 spatial_dimension_to_split,
178       int64 activations_batch_dim, int64 old_batch_size, int64 low_padding,
179       int64 high_padding, int64 halo_size, int64 original_split_dim_size,
180       HloInstruction* pad_val = nullptr);
181 
182   // Runs the visitor on a computation.
183   StatusOr<bool> Run();
184 
185   // Returns whether any convolution ops were rewritten.
changed() const186   const bool changed() const { return changed_; }
187 
188   ~ConvolutionVisitor() = default;
189 
190   explicit ConvolutionVisitor(int64 limit_on_batch_size,
191                               HloComputation* computation);
192 
get_chosen_spatial_dim(HloInstruction * convolution)193   int64 get_chosen_spatial_dim(HloInstruction* convolution) {
194     return convolution->convolution_dimension_numbers()
195                .input_spatial_dimensions_size() -
196            1;
197   }
198 
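  // Helpers that treat permute_dims as a map from old dimension index to new
  // dimension index. For example, with permute_dims = {0, 2, 3, 1},
  // DimLookUp(permute_dims, 1) == 2 and ReverseDimLookUp(permute_dims, 2) == 1.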
  int64 DimLookUp(absl::Span<const int64> permute_dims, int64 id) {
    return permute_dims[id];
  }

  int64 ReverseDimLookUp(absl::Span<const int64> permute_dims, int64 id) {
    return std::distance(permute_dims.begin(), absl::c_find(permute_dims, id));
  }

  HloInstruction* DoesConvolutionFeedReduceWindowOrSelectAndScatter(
      HloInstruction* instr, int64 depth);

 private:
  // Current HloComputation instance the ConvolutionVisitor is traversing.
  HloComputation* computation_;

  absl::flat_hash_set<HloInstruction*> convs_to_visit_;
  std::vector<HloInstruction*> conv_visitor_list_;
  HloInstructionSet non_propagatable_instrs_;
  // Map from a given space-to-batched instruction to its batch-to-space
  // version.
  absl::flat_hash_map<HloInstruction*, HloInstruction*> batch_to_space_map_;

  // Map from old (non space-to-batch) instructions to space-to-batch'ed
  // instructions.
  absl::flat_hash_map<HloInstruction*, HloInstruction*> old_to_new_instrs_;

  // Map from instruction to dimensions of the shape (first is batch, second is
  // space). This is with respect to the old instruction.
  absl::flat_hash_map<HloInstruction*, std::pair<int64, int64>>
      instr_to_dim_map_;

  // Map from space-to-batch'ed instruction to its permute dims.
  absl::flat_hash_map<HloInstruction*, std::vector<int64>>
      instr_to_dim_permute_map_;

  // Map maintaining previously space-to-batched broadcasts.
  absl::flat_hash_map<HloInstruction*, absl::flat_hash_set<HloInstruction*>>
      broadcast_map_;

  // Whether rewrite has occurred.
  bool changed_ = false;

  // Limit on batch size to apply this technique on.
  int64 limit_on_batch_size_;

  // We choose the new batch size to be kNumSplits times that of the old batch
  // so that space-to-batch propagation through several convolutional layers is
  // consistent.
  static constexpr int64 kNumSplits = 8;

  // Depth limit when searching for a reduce window or select-and-scatter user.
  static constexpr int64 kReduceWindowSearchDepth = 10;
};

ConvolutionVisitor::ConvolutionVisitor(int64 limit_on_batch_size,
                                       HloComputation* computation) {
  computation_ = computation;
  limit_on_batch_size_ = limit_on_batch_size;
  for (HloInstruction* inst : computation->MakeInstructionPostOrder()) {
    if (inst->opcode() != HloOpcode::kConvolution) {
      continue;
    }

    auto convolution = inst;
    // Perform legality checks.
    if (!IsConvSuitableForSpaceToBatch(convolution)) {
      VLOG(1) << "Conv not suitable for space-to-batch "
              << convolution->ToString();
      continue;
    }
    VLOG(1) << "Conv added to space-to-batch worklist "
            << convolution->ToString();
    convs_to_visit_.insert(convolution);
    conv_visitor_list_.push_back(convolution);
  }
}

bool ConvolutionVisitor::IsConvSuitableForSpaceToBatch(
    HloInstruction* convolution) {
  ConvolutionDimensionNumbers dim_numbers =
      convolution->convolution_dimension_numbers();

  // If there are no spatial dims, we return false.
  if (dim_numbers.input_spatial_dimensions_size() < 1) {
    return false;
  }

  // Batch in batch_group_count has different semantics (it isn't true batch).
  // Consider supporting this case in the future if needed.
  if (convolution->batch_group_count() != 1) {
    return false;
  }

  if (convolution->window()
          .dimensions(get_chosen_spatial_dim(convolution))
          .window_dilation() != 1) {
    return false;
  }

  const ConvDetails c = GetConvolutionDetails(convolution, dim_numbers);

  const int64 low_pad = convolution->window()
                            .dimensions(get_chosen_spatial_dim(convolution))
                            .padding_low();

  // TODO(b/168316428): Support base dilations more generically.
  if (c.base_dilation_factor != 1) {
    if (c.stride != 1) {
      return false;
    }
    // For low pad of 0, only support a pointwise kernel.
    if (low_pad == 0) {
      if (c.kernel_spatial_dim_size != 1) {
        return false;
      }
    } else if (c.kernel_spatial_dim_size != c.base_dilation_factor + 1 ||
               low_pad != c.base_dilation_factor - 1) {
      // Only support dilations such that base dilation factor and low pad are
      // compatible with kernel_spatial_dim_size to be compatible with
      // HaloDuplicateWithSlice.
      return false;
    }
  }

  int64 activations_batch_dim = dim_numbers.input_batch_dimension();

  const int64 old_batch_size =
      convolution->operand(0)->shape().dimensions(activations_batch_dim);

  if (old_batch_size > limit_on_batch_size_) {
    return false;
  }

  VLOG(1) << "spatial size " << c.spatial_size;

  // If the halo exceeds the per-split spatial size, we cannot halo-pad using
  // just the neighboring split.
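  // For example, with kNumSplits = 8 and a spatial size of 256, each split
  // covers 32 elements, so a halo of up to 32 can be fetched from the
  // neighboring split, but a larger one cannot.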
  if (c.halo_size > CeilOfRatio(c.spatial_size, kNumSplits)) {
    return false;
  }
  VLOG(1) << "Legal space-to-batch convolution " << convolution->ToString();
  return true;
}

StatusOr<HloInstruction*> ConvolutionVisitor::HaloDuplicateWithSlice(
    HloInstruction* activations, int64 spatial_dimension_to_split,
    int64 activations_batch_dim, int64 old_batch_size, int64 low_padding,
    int64 high_padding, int64 halo_size, int64 original_split_dim_size,
    HloInstruction* pad_val) {
  const int64 original_batch_size =
      activations->shape().dimensions(activations_batch_dim) / kNumSplits;

  if (original_batch_size > 1) {
    std::vector<int64> new_dimensions(activations->shape().dimensions().begin(),
                                      activations->shape().dimensions().end());
    new_dimensions[activations_batch_dim] = kNumSplits;
    new_dimensions.insert(new_dimensions.begin() + activations_batch_dim,
                          original_batch_size);

    // Reshape the activations so that the split factor becomes its own
    // dimension, separate from the original batch.
    TF_ASSIGN_OR_RETURN(activations,
                        MakeReshapeHlo(new_dimensions, activations));

    spatial_dimension_to_split++;
    activations_batch_dim++;
  }

  const int64 rank = activations->shape().rank();
  const int64 spatial_split_size =
      activations->shape().dimensions(spatial_dimension_to_split);
  const int64 batch_size =
      activations->shape().dimensions(activations_batch_dim);

  CHECK_LE(std::abs(halo_size - low_padding), spatial_split_size);
  VLOG(1) << "In HaloDuplicateWithSlice with activations "
          << activations->ToString() << " batch_size " << batch_size
          << " spatial_split_size " << spatial_split_size << " low_padding "
          << low_padding << " halo size " << halo_size;

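  // Conceptually, each split is extended with data from its neighbors: the
  // last low_padding elements of split i-1 are prepended (first_slice below),
  // and the first (halo_size - low_padding) elements of split i+1 are
  // appended (halo_region below). The one-element pads on the batch dimension
  // shift those slices so that they line up with the right split.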
  HloInstruction* first_slice = nullptr;

  std::vector<int64> strides(rank, 1);
  HloInstruction* padding =
      pad_val == nullptr
          ? computation_->AddInstruction(HloInstruction::CreateConstant(
                LiteralUtil::Zero(activations->shape().element_type())))
          : pad_val;

  if (low_padding > 0) {
    std::vector<int64> start_indices(rank, 0),
        end_indices(activations->shape().dimensions().begin(),
                    activations->shape().dimensions().end());
    start_indices[spatial_dimension_to_split] =
        spatial_split_size - low_padding;
    end_indices[activations_batch_dim] = batch_size - 1;
    end_indices[spatial_dimension_to_split] = spatial_split_size;

    TF_ASSIGN_OR_RETURN(first_slice, MakeSliceHlo(activations, start_indices,
                                                  end_indices, strides));
    VLOG(1) << "first slice " << first_slice->ToString();
    PaddingConfig padding_config =
        MakeNoPaddingConfig(first_slice->shape().dimensions_size());
    padding_config.mutable_dimensions(activations_batch_dim)
        ->set_edge_padding_low(1);

    TF_ASSIGN_OR_RETURN(first_slice,
                        MakePadHlo(first_slice, padding, padding_config));
  }

  HloInstruction* halo_region = nullptr;
  if (halo_size - low_padding > 0) {
    std::vector<int64> start_indices_halo(rank, 0),
        end_indices_halo(activations->shape().dimensions().begin(),
                         activations->shape().dimensions().end());

    start_indices_halo[activations_batch_dim] = 1;
    end_indices_halo[spatial_dimension_to_split] = halo_size - low_padding;

    TF_ASSIGN_OR_RETURN(halo_region,
                        MakeSliceHlo(activations, start_indices_halo,
                                     end_indices_halo, strides));
    VLOG(1) << "halo_region " << halo_region->ToString();
    PaddingConfig padding_config_halo =
        MakeNoPaddingConfig(halo_region->shape().dimensions_size());
    padding_config_halo.mutable_dimensions(activations_batch_dim)
        ->set_edge_padding_high(1);
    TF_ASSIGN_OR_RETURN(halo_region,
                        MakePadHlo(halo_region, padding, padding_config_halo));
  }

  if (halo_size == 0 && low_padding != 0) {
    std::vector<int64> start_indices_activations_cut(rank, 0),
        end_indices_activations_cut(activations->shape().dimensions().begin(),
                                    activations->shape().dimensions().end());
    // When no halo is needed, we must slice out activations.
    if (low_padding > 0) {
      end_indices_activations_cut[spatial_dimension_to_split] =
          spatial_split_size - low_padding;
    } else {
      start_indices_activations_cut[spatial_dimension_to_split] =
          0 - low_padding;
      end_indices_activations_cut[spatial_dimension_to_split] =
          spatial_split_size;
    }

    TF_ASSIGN_OR_RETURN(activations,
                        MakeSliceHlo(activations, start_indices_activations_cut,
                                     end_indices_activations_cut, strides));
  }

  if (first_slice != nullptr) {
    TF_ASSIGN_OR_RETURN(activations, MakeConcatHlo({first_slice, activations},
                                                   spatial_dimension_to_split));
  }

  if (halo_region != nullptr) {
    TF_ASSIGN_OR_RETURN(activations, MakeConcatHlo({activations, halo_region},
                                                   spatial_dimension_to_split));
  }

  if (original_batch_size > 1) {
    std::vector<int64> new_dimensions(activations->shape().dimensions().begin(),
                                      activations->shape().dimensions().end());
    new_dimensions[activations_batch_dim] = original_batch_size * kNumSplits;
    new_dimensions.erase(new_dimensions.begin() + activations_batch_dim - 1);

    // Reshape to merge the original batch and the split factor back into a
    // single batch dimension.
    TF_ASSIGN_OR_RETURN(activations,
                        MakeReshapeHlo(new_dimensions, activations));

    spatial_dimension_to_split++;
    activations_batch_dim++;
  }

  VLOG(1) << "HaloDuplicated activations " << activations->ToString();
  return activations;
}

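// For example, for rank-4 activations with activations_batch_dim = 0 and
// spatial_dimension_to_split = 2, the loop below builds
// transpose_dims = {1, 0, 2, 3}: the batch dimension is moved to position 1
// so that it is immediately followed by the spatial dimension being split.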
StatusOr<ConvolutionVisitor::SpaceNextToBatchDetails>
ConvolutionVisitor::BringSpaceNextToBatch(
    HloInstruction* activations, ConvolutionDimensionNumbers& dim_numbers,
    int64& spatial_dimension_to_split, int64& activations_batch_dim,
    bool is_backprop, bool is_rhs) {
  std::vector<int64> transpose_dims(activations->shape().rank());
  if (spatial_dimension_to_split == activations_batch_dim + 1) {
    absl::c_iota(transpose_dims, 0);
  } else {
    ConvolutionDimensionNumbers new_dim_numbers = dim_numbers;
    int64 pushed_counter = 0;
    int64 new_batch_dim, new_spatial_dim;
    int64 dim_counter = 0;
    if (is_rhs) {
      CHECK(is_backprop);
      for (int i = 0; i < activations->shape().rank(); ++i) {
        if (i == activations_batch_dim) {
          continue;
        }
        if (i == spatial_dimension_to_split) {
          transpose_dims[dim_counter++] = activations_batch_dim;
          new_batch_dim = pushed_counter;
          pushed_counter++;
          new_spatial_dim = pushed_counter;
        }

        if (i == dim_numbers.kernel_output_feature_dimension()) {
          new_dim_numbers.set_kernel_output_feature_dimension(pushed_counter);
        } else {
          auto it = absl::c_find(dim_numbers.kernel_spatial_dimensions(), i);
          if (it != dim_numbers.kernel_spatial_dimensions().end()) {
            int64 j = it - dim_numbers.kernel_spatial_dimensions().begin();
            new_dim_numbers.set_kernel_spatial_dimensions(j, pushed_counter);
          }
        }
        transpose_dims[dim_counter++] = i;
        pushed_counter++;
      }

      activations_batch_dim = new_batch_dim;
      spatial_dimension_to_split = new_spatial_dim;
      TF_ASSIGN_OR_RETURN(activations,
                          MakeTransposeHlo(activations, transpose_dims));

      new_dim_numbers.set_kernel_input_feature_dimension(activations_batch_dim);

    } else {
      for (int i = 0; i < activations->shape().rank(); ++i) {
        if (i == activations_batch_dim) {
          continue;
        }
        if (i == spatial_dimension_to_split) {
          transpose_dims[dim_counter++] = activations_batch_dim;
          new_batch_dim = pushed_counter;
          pushed_counter++;
          new_spatial_dim = pushed_counter;
        }

        if (is_backprop && i == dim_numbers.input_batch_dimension()) {
          new_dim_numbers.set_input_batch_dimension(pushed_counter);
        } else if (i == dim_numbers.input_feature_dimension()) {
          new_dim_numbers.set_input_feature_dimension(pushed_counter);
        } else {
          auto it = absl::c_find(dim_numbers.input_spatial_dimensions(), i);
          if (it != dim_numbers.input_spatial_dimensions().end()) {
            int64 j = it - dim_numbers.input_spatial_dimensions().begin();
            new_dim_numbers.set_input_spatial_dimensions(j, pushed_counter);
          }
        }
        transpose_dims[dim_counter++] = i;
        pushed_counter++;
      }

      activations_batch_dim = new_batch_dim;
      spatial_dimension_to_split = new_spatial_dim;
      TF_ASSIGN_OR_RETURN(activations,
                          MakeTransposeHlo(activations, transpose_dims));

      if (is_backprop) {
        new_dim_numbers.set_input_feature_dimension(activations_batch_dim);
      } else {
        new_dim_numbers.set_input_batch_dimension(activations_batch_dim);
      }
    }

    dim_numbers = new_dim_numbers;
  }

  return SpaceNextToBatchDetails{activations, transpose_dims};
}

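// For example, going from a space-to-batched shape of (batch=8, space=4) with
// old_batch_size = 1 to new_spatial_dim_size = 6: we reshape to
// (batch=1, space=32), pad the space dimension up to 48, and reshape back to
// (batch=8, space=6).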
StatusOr<HloInstruction*>
ConvolutionVisitor::IncreaseSpatialSizeOnSpaceToBatchedShape(
    HloInstruction* activations, int64 batch_dimension, int64 old_batch_size,
    int64 spatial_dimension, int64 new_spatial_dim_size) {
  CHECK_EQ(batch_dimension + 1, spatial_dimension);
  std::vector<int64> new_dimensions(activations->shape().dimensions().begin(),
                                    activations->shape().dimensions().end());

  const int64 new_batch_size = activations->shape().dimensions(batch_dimension);
  int64 spatial_dim_size = activations->shape().dimensions(spatial_dimension);
  const int64 reshaped_space_size =
      spatial_dim_size * new_batch_size / old_batch_size;

  VLOG(3) << "Increasing the spatial size while propagating new_batch_size "
          << new_batch_size << " old_batch_size " << old_batch_size;
  new_dimensions[spatial_dimension] = reshaped_space_size;
  new_dimensions[batch_dimension] = old_batch_size;

  // Merge the split batch back into the space dimension so that it can be
  // padded as one contiguous dimension.
  TF_ASSIGN_OR_RETURN(HloInstruction * reshaped_activations,
                      MakeReshapeHlo(new_dimensions, activations));

  VLOG(3) << "First reshape done";
  PaddingConfig padding_config =
      MakeNoPaddingConfig(reshaped_activations->shape().dimensions_size());
  padding_config.mutable_dimensions(spatial_dimension)
      ->set_edge_padding_high(new_spatial_dim_size * new_batch_size /
                                  old_batch_size -
                              reshaped_space_size);
  padding_config.mutable_dimensions(spatial_dimension)->set_edge_padding_low(0);
  HloInstruction* padding =
      computation_->AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::Zero(reshaped_activations->shape().element_type())));

  TF_ASSIGN_OR_RETURN(
      reshaped_activations,
      MakePadHlo(reshaped_activations, padding, padding_config));

  std::vector<int64> reshape_back_dims(
      reshaped_activations->shape().dimensions().begin(),
      reshaped_activations->shape().dimensions().end());

  reshape_back_dims[spatial_dimension] = new_spatial_dim_size;
  reshape_back_dims[batch_dimension] = new_batch_size;

  TF_ASSIGN_OR_RETURN(HloInstruction * activations_new,
                      MakeReshapeHlo(reshape_back_dims, reshaped_activations));

  VLOG(3) << "Size increased activations " << activations_new->ToString();

  return activations_new;
}

StatusOr<bool> ConvolutionVisitor::Run() {
  for (auto conv : conv_visitor_list_) {
    if (convs_to_visit_.count(conv) > 0) {
      TF_CHECK_OK(PerformSpaceToBatchOnConvolution(conv));
    }
  }
  conv_visitor_list_.clear();
  convs_to_visit_.clear();
  // Iterate through all instructions that we could not propagate through, and
  // convert their space-to-batched operands back via batch-to-space as needed.
  for (auto instr : non_propagatable_instrs_) {
    if (instr->opcode() == HloOpcode::kConvolution) {
      VLOG(1) << "Instr " << instr->ToString();
    }
    // Try to propagate on backprop filter convolutions.
    if (instr->opcode() == HloOpcode::kConvolution &&
        !IsConvSuitableForSpaceToBatch(instr)) {
      HloInstruction* producer = nullptr;
      if (old_to_new_instrs_.contains(instr->mutable_operand(0))) {
        producer = instr->mutable_operand(0);
      } else if (old_to_new_instrs_.contains(instr->mutable_operand(1))) {
        producer = instr->mutable_operand(1);
      }
      if (producer) {
        if (CanPropagate(instr, producer, /*last_try=*/true)) {
          bool needs_further_propagation;
          TF_ASSIGN_OR_RETURN(needs_further_propagation,
                              Propagate(instr, producer));
          TF_CHECK_OK(computation_->ReplaceInstruction(
              instr, old_to_new_instrs_[instr]));
          continue;
        }
      }
    }
    VLOG(1) << "Could not eventually propagate through " << instr->ToString();
    absl::flat_hash_map<int64, HloInstruction*> operand_map;
    for (int64 i = 0; i < instr->operand_count(); ++i) {
      if (old_to_new_instrs_.count(instr->mutable_operand(i))) {
        TF_ASSIGN_OR_RETURN(operand_map[i],
                            BatchToSpace(instr->mutable_operand(i)));
      }
    }
    for (auto entry : operand_map) {
      TF_CHECK_OK(instr->ReplaceOperandWith(entry.first, entry.second));
    }
  }
  non_propagatable_instrs_.clear();
  return changed_;
}

bool IsTrivialElementwise(HloInstruction* hlo) {
  if (hlo->opcode() == HloOpcode::kFusion || hlo->opcode() == HloOpcode::kRng ||
      hlo->opcode() == HloOpcode::kCopy ||
      hlo->opcode() == HloOpcode::kConstant ||
      hlo->opcode() == HloOpcode::kIota || hlo->opcode() == HloOpcode::kMap) {
    return false;
  }
  return hlo->IsElementwise();
}

bool ConvolutionVisitor::CanPropagate(HloInstruction* consumer,
                                      HloInstruction* producer, bool last_try) {
  if (IsTrivialElementwise(consumer)) {
    VLOG(2) << "Doing propagation check on elementwise op: "
            << consumer->ToString();

    HloInstruction* pivot_operand = nullptr;
    for (int64 i = 0; i < consumer->operand_count(); ++i) {
      auto old_producer = consumer->mutable_operand(i);
      std::vector<HloInstruction*> to_transform;
      const bool broadcast_or_constant =
          (old_producer->opcode() == HloOpcode::kConstant) ||
          (old_producer->opcode() == HloOpcode::kBroadcast &&
           IsBroadcastPropagatable(old_producer, producer)) ||
          (consumer->IsElementwiseBinary() &&
           old_producer->opcode() == HloOpcode::kBroadcast &&
           IsBroadcastTree(old_producer, producer, to_transform));

      if (!old_to_new_instrs_.contains(old_producer) &&
          !broadcast_or_constant) {
        VLOG(1) << "Cannot propagate on elementwise op " << consumer->ToString()
                << " because operand " << old_producer->ToString()
                << " isn't ready ";
        return false;
      } else {
        if (broadcast_or_constant) {
          VLOG(2) << "Skipping on " << old_producer->ToString();
          continue;
        }

        CHECK(old_to_new_instrs_.contains(old_producer));

        CHECK(instr_to_dim_map_.contains(old_producer));
        if (pivot_operand == nullptr) {
          pivot_operand = old_producer;
          VLOG(2) << "Elementwise op: pivot " << old_producer->ToString();
        } else {
          if (instr_to_dim_map_[pivot_operand] !=
              instr_to_dim_map_[old_producer]) {
            VLOG(2) << "Elementwise op: checking for shape equivalence "
                    << consumer->ToString()
                    << " failed due to changed batch space ordering ";
            return false;
          }
          auto pivot_new_instr = old_to_new_instrs_[pivot_operand];
          auto pivot_permute_dims = instr_to_dim_permute_map_[pivot_new_instr];
          auto new_instr = old_to_new_instrs_[old_producer];
          auto permute_dims = instr_to_dim_permute_map_[new_instr];
          for (int j = 0; j < pivot_permute_dims.size(); ++j) {
            // Ensure the dimension mapping is the same.
            if (pivot_permute_dims[j] != permute_dims[j]) {
              VLOG(2) << "Elementwise op: checking for shape equivalence "
                      << consumer->ToString()
                      << " failed due to permuted dimensions ";
              return false;
            }

            // Make sure all other dimensions are of the same size.
            if (pivot_new_instr->shape().dimensions(j) !=
                new_instr->shape().dimensions(j)) {
              if (!((consumer->IsElementwiseBinary() ||
                     consumer->opcode() == HloOpcode::kSelect) &&
                    j == instr_to_dim_map_[pivot_operand].second)) {
                VLOG(2) << "Elementwise op: checking for shape equivalence "
                        << consumer->ToString()
                        << " failed due to changed shape sizes ";
                return false;
              }
            }
          }
        }
      }
    }
  }

  if (consumer->opcode() == HloOpcode::kConvolution) {
    VLOG(1) << "Checking if conv is supported for propagation "
            << consumer->ToString();
    if (IsConvSuitableForSpaceToBatch(consumer)) {
      if (!old_to_new_instrs_.contains(consumer->mutable_operand(0))) {
        return false;
      }
      auto dim_map_val_op_0 = instr_to_dim_map_[consumer->mutable_operand(0)];
      // Make sure that the space dimension is the same across the producer
      // and consumer.
      if (consumer->convolution_dimension_numbers().input_spatial_dimensions(
              get_chosen_spatial_dim(consumer)) != dim_map_val_op_0.second) {
        return false;
      }
      // Make sure that the batch dimension is the same across the producer
      // and consumer.
      if (consumer->convolution_dimension_numbers().input_batch_dimension() !=
          dim_map_val_op_0.first) {
        return false;
      }

      return true;
    }

    // Check for space-to-batch readiness here. Note this is not done in
    // SupportedOpForPropagation because the readiness is dependent upon
    // space-to-batchedness of the operands.

    // We currently only support stride of 1.
    if (consumer->window()
            .dimensions(get_chosen_spatial_dim(consumer))
            .stride() != 1) {
      return false;
    }

    // The reason we give up on batch group counts also applies to feature
    // group counts in backprop.
    if (consumer->feature_group_count() != 1) {
      return false;
    }

    VLOG(2) << "Checking for backprop filter conv propagatability";
    CHECK_EQ(consumer->operand_count(), 2);

    auto activations = consumer->mutable_operand(0);
    auto kernel = consumer->mutable_operand(1);

    auto win_dims =
        consumer->window().dimensions(get_chosen_spatial_dim(consumer));
    const int64 rhs_dilation = win_dims.window_dilation();

    // Unless this is a last-try propagation with RHS dilation present, we
    // require both LHS and RHS to be space-to-batched for propagating on
    // backprop convolutions.
    if (!last_try || rhs_dilation == 1) {
      if (!old_to_new_instrs_.contains(kernel) ||
          !old_to_new_instrs_.contains(activations)) {
        return false;
      }
    }

    if (!old_to_new_instrs_.contains(kernel) &&
        !old_to_new_instrs_.contains(activations)) {
      return false;
    }

    if (!old_to_new_instrs_.contains(kernel)) {
      const int64 rhs_batch =
          kernel->shape().dimensions(consumer->convolution_dimension_numbers()
                                         .kernel_input_feature_dimension());
      auto dim_map_val_op_0 = instr_to_dim_map_[activations];
      const int64 old_batch_dim = dim_map_val_op_0.first;
      const int64 old_space_dim = dim_map_val_op_0.second;
      auto first_operand = old_to_new_instrs_[activations];
      auto permute_dims_first_operand =
          instr_to_dim_permute_map_[first_operand];
      const int64 new_batch_dim =
          DimLookUp(permute_dims_first_operand, old_batch_dim);
      const int64 new_space_dim =
          DimLookUp(permute_dims_first_operand, old_space_dim);
      const int64 lhs_batch = first_operand->shape().dimensions(new_batch_dim);

      if (first_operand->shape().dimensions(new_space_dim) % rhs_dilation !=
          0) {
        return false;
      }
      // Because we want to convert activations into a space-to-batched version
      // only for backprop filter convolutions, we want to make sure that the
      // batch dimensions (feature dimensions, technically) have the same size.
      // Since the LHS is already space-to-batched, we need to account for it
      // too.
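      // For example, with kNumSplits = 8, a kernel input feature (batch)
      // size of 4 requires the space-to-batched LHS batch to be 4 * 8 = 32.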
      if (rhs_batch * kNumSplits != lhs_batch) {
        return false;
      }

      // The kernel has not been propagated through, but we can still
      // space-to-batch it here because the activations have already been
      // propagated.
      VLOG(2)
          << "Backprop filter conv ready for propagation: activations ready, "
             " kernel will be space-to-batched";
      return true;
    }

    if (!old_to_new_instrs_.contains(activations)) {
      const int64 lhs_batch = activations->shape().dimensions(
          consumer->convolution_dimension_numbers().input_feature_dimension());
      auto dim_map_val_op_1 = instr_to_dim_map_[consumer->mutable_operand(1)];
      const int64 old_batch_dim = dim_map_val_op_1.first;
      auto second_operand = old_to_new_instrs_[kernel];
      auto permute_dims_second_operand =
          instr_to_dim_permute_map_[second_operand];
      const int64 new_batch_dim =
          DimLookUp(permute_dims_second_operand, old_batch_dim);
      const int64 rhs_batch = second_operand->shape().dimensions(new_batch_dim);

      // Because we want to convert activations into a space-to-batched version
      // only for backprop filter convolutions, we want to make sure that the
      // batch dimensions (feature dimensions, technically) have the same size.
      // Since the RHS is already space-to-batched, we need to account for it
      // too.
      if (rhs_batch != kNumSplits * lhs_batch) {
        return false;
      }

      // If activations have not been propagated through, we can do
      // space-to-batch on them provided the kernel has been propagated.
      VLOG(2) << "Backprop filter conv ready for propagation: kernel ready, "
                 " activations will be space-to-batched";
      return true;
    }

    auto first_operand = old_to_new_instrs_[activations];
    auto dim_map_val_op_0 = instr_to_dim_map_[activations];
    auto second_operand = old_to_new_instrs_[kernel];
    auto dim_map_val_op_1 = instr_to_dim_map_[kernel];

    auto permute_dims_first_operand = instr_to_dim_permute_map_[first_operand];
    auto permute_dims_second_operand =
        instr_to_dim_permute_map_[second_operand];

    const int64 new_batch_dim_operand_0 =
        DimLookUp(permute_dims_first_operand, dim_map_val_op_0.first);
    const int64 new_space_dim_operand_0 =
        DimLookUp(permute_dims_first_operand, dim_map_val_op_0.second);

    const int64 new_batch_dim_operand_1 =
        DimLookUp(permute_dims_second_operand, dim_map_val_op_1.first);
    const int64 new_space_dim_operand_1 =
        DimLookUp(permute_dims_second_operand, dim_map_val_op_1.second);

    if (first_operand->shape().dimensions(new_batch_dim_operand_0) !=
        second_operand->shape().dimensions(new_batch_dim_operand_1)) {
      VLOG(2) << "Backprop filter conv not ready for propagation because batch "
                 "dimensions don't line up";
      return false;
    }

    if (first_operand->shape().dimensions(new_space_dim_operand_0) >
        rhs_dilation *
            second_operand->shape().dimensions(new_space_dim_operand_1)) {
      VLOG(2) << "Backprop filter conv not ready for propagation because of "
                 "dilation factor mismatch";
      return false;
    }

    VLOG(2) << "Backprop filter conv ready for propagation";

    return true;
  }

  if (consumer->opcode() == HloOpcode::kReduceWindow ||
      consumer->opcode() == HloOpcode::kReduce) {
    for (int64 i = 0; i < consumer->operand_count(); ++i) {
      auto old_producer = consumer->mutable_operand(i);
      if (i == 0 && !old_to_new_instrs_.contains(old_producer)) {
        return false;
      }
    }
  }

  if (consumer->opcode() == HloOpcode::kSelectAndScatter) {
    // We currently only support adds in the scatter.
    auto scatter_comp = consumer->scatter();
    if (!Match(scatter_comp->root_instruction(),
               m::AddAnyOrder(m::Parameter(0), m::Parameter(1)))) {
      return false;
    }

    for (int64 i = 0; i < consumer->operand_count(); ++i) {
      auto old_producer = consumer->mutable_operand(i);
      if (i < 2 && !old_to_new_instrs_.contains(old_producer)) {
        return false;
      }
    }

    auto first_operand = old_to_new_instrs_[consumer->mutable_operand(0)];
    auto dim_map_val_op_0 = instr_to_dim_map_[consumer->mutable_operand(0)];
    auto second_operand = old_to_new_instrs_[consumer->mutable_operand(1)];

    auto permute_dims_first_operand = instr_to_dim_permute_map_[first_operand];
    auto permute_dims_second_operand =
        instr_to_dim_permute_map_[second_operand];

    // The permuting must match.
    if (permute_dims_first_operand != permute_dims_second_operand) {
      VLOG(2) << "Can't propagate through select and scatter due to "
                 "permutation mismatch";
      return false;
    }

    const int64 old_batch_dim = dim_map_val_op_0.first;
    const int64 old_space_dim = dim_map_val_op_0.second;

    const int64 new_batch_dim =
        DimLookUp(permute_dims_first_operand, old_batch_dim);
    const int64 new_space_dim =
        DimLookUp(permute_dims_first_operand, old_space_dim);

    if (first_operand->shape().dimensions(new_batch_dim) !=
        second_operand->shape().dimensions(new_batch_dim)) {
      VLOG(2)
          << "Can't propagate through select and scatter due to dim mismatch";
      return false;
    }

    const int64 stride = consumer->window().dimensions(old_space_dim).stride();
    const int64 pad_high =
        consumer->window().dimensions(old_space_dim).padding_high();
    const int64 pad_low =
        consumer->window().dimensions(old_space_dim).padding_low();

    if ((first_operand->shape().dimensions(new_space_dim) + pad_high +
         pad_low) /
            stride !=
        second_operand->shape().dimensions(new_space_dim)) {
      VLOG(2) << "Can't propagate through select and scatter due to stride "
                 "mismatch";
      return false;
    }
    VLOG(1) << "Can propagate through select and scatter";
    return true;
  }
  return true;
}

void ConvolutionVisitor::PropagateOnBroadcast(HloInstruction* consumer,
                                              HloInstruction* producer) {
  auto new_producer = old_to_new_instrs_[producer];
  auto permute_dims = instr_to_dim_permute_map_[new_producer];
  auto dim_map_val = instr_to_dim_map_[producer];

  const int64 old_batch_dim = dim_map_val.first;
  const int64 old_space_dim = dim_map_val.second;

  auto orig_broadcast_dims = consumer->dimensions();

  bool batch_is_broadcasted =
      absl::c_linear_search(orig_broadcast_dims, old_batch_dim);
  const int64 new_batch_dim = DimLookUp(permute_dims, old_batch_dim);
  const int64 new_space_dim = DimLookUp(permute_dims, old_space_dim);

  bool map_found = broadcast_map_.contains(consumer);
  if (map_found) {
    // Check if we previously had created the same broadcast.
    for (auto previous_broadcast : broadcast_map_[consumer]) {
      if (ShapeUtil::CompatibleIgnoringElementType(previous_broadcast->shape(),
                                                   new_producer->shape())) {
        return;
      }
    }
  }

  std::vector<int64> final_shape_dims(
      new_producer->shape().dimensions().begin(),
      new_producer->shape().dimensions().end());
  if (batch_is_broadcasted) {
    final_shape_dims[new_batch_dim] =
        producer->shape().dimensions(old_batch_dim);
    final_shape_dims[new_space_dim] *= kNumSplits;
  }

  std::vector<int64> broadcast_dims;
  for (auto j : consumer->dimensions()) {
    broadcast_dims.push_back(DimLookUp(permute_dims, j));
  }
  auto new_broadcast = MakeBroadcastHlo(consumer->mutable_operand(0),
                                        broadcast_dims, final_shape_dims);
  VLOG(1) << "Created broadcast " << new_broadcast->ToString();

  if (batch_is_broadcasted) {
    new_broadcast =
        MakeReshapeHlo(new_producer->shape().dimensions(), new_broadcast)
            .ValueOrDie();
    VLOG(2) << "Created reshape of broadcast " << new_broadcast->ToString();
  }

  if (!map_found) {
    absl::flat_hash_set<HloInstruction*> set_of_broadcasts;
    broadcast_map_[consumer] = set_of_broadcasts;
  }
  broadcast_map_[consumer].insert(new_broadcast);
}

void ConvolutionVisitor::RewriteBroadcastTree(
    HloInstruction* producer,
    std::vector<HloInstruction*>& instructions_to_transform) {
  CHECK(old_to_new_instrs_.contains(producer));
  for (auto instr : instructions_to_transform) {
    if (instr->opcode() == HloOpcode::kBroadcast) {
      PropagateOnBroadcast(instr, producer);
    } else if (IsTrivialElementwise(instr)) {
      Propagate(instr, /*producer=*/instr->mutable_operand(0)).ValueOrDie();
    } else {
      LOG(FATAL) << "Unsupported opcode in RewriteBroadcastTree";
    }
  }
}

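// For example, add(broadcast(x), multiply(broadcast(y), scalar-constant)) is
// a broadcast tree: every leaf is either a propagatable broadcast or a scalar
// constant, with only trivial elementwise ops in between.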
bool ConvolutionVisitor::IsBroadcastTree(
    HloInstruction* op, HloInstruction* consumer,
    std::vector<HloInstruction*>& instructions_to_transform) {
  if (op->opcode() == HloOpcode::kBroadcast) {
    // We want to ensure that the broadcast did not happen on the split space
    // dimension.
    if (IsBroadcastPropagatable(op, consumer)) {
      instructions_to_transform.push_back(op);
      return true;
    } else {
      return false;
    }
  }
  if (Match(op, m::ConstantScalar())) {
    return true;
  }
  if (!IsTrivialElementwise(op)) {
    return false;
  }
  for (int64 i = 0; i < op->operand_count(); ++i) {
    if (!IsBroadcastTree(op->mutable_operand(i), consumer,
                         instructions_to_transform)) {
      return false;
    }
  }
  instructions_to_transform.push_back(op);
  return true;
}

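// A broadcast is propagatable as long as none of its broadcast dimensions
// land on old_other_op's split space dimension, since space-to-batch reorders
// and pads that dimension. For example, broadcasting a bias vector along the
// feature dimension is fine, while a broadcast whose operand spans the split
// spatial dimension is not.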
bool ConvolutionVisitor::IsBroadcastPropagatable(HloInstruction* broadcast,
                                                 HloInstruction* old_other_op) {
  CHECK_EQ(broadcast->opcode(), HloOpcode::kBroadcast);
  CHECK(instr_to_dim_map_.contains(old_other_op));

  auto result = instr_to_dim_map_[old_other_op];
  const int64 space_dim = result.second;
  auto broadcast_dims = broadcast->dimensions();
  return !absl::c_linear_search(broadcast_dims, space_dim);
}

bool ConvolutionVisitor::SupportedOpForPropagation(HloInstruction* consumer,
                                                   HloInstruction* producer) {
  if (IsTrivialElementwise(consumer)) {
    for (int64 i = 0; i < consumer->operand_count(); ++i) {
      if (consumer->operand(i)->opcode() == HloOpcode::kBroadcast) {
        if (!IsBroadcastPropagatable(consumer->mutable_operand(i), producer)) {
          VLOG(2) << "Could not propagate through broadcast";
          return false;
        }
      }
    }
    return true;
  }

  if (consumer->opcode() == HloOpcode::kConvolution) {
    return true;
  }

  if (consumer->opcode() == HloOpcode::kReduce) {
    // Support only the trivial case where both the batch and the split spatial
    // dim are being reduced.

    auto reduce_dims = consumer->dimensions();
    auto result = instr_to_dim_map_[consumer->mutable_operand(0)];
    const int64 batch_dim = result.first;
    const int64 space_dim = result.second;
    VLOG(1) << "Checking if reduce is supported batch_dim " << batch_dim
            << "  space_dim " << space_dim << " reduce "
            << consumer->ToString();
    return absl::c_linear_search(reduce_dims, batch_dim) &&
           absl::c_linear_search(reduce_dims, space_dim);
  }

  if (consumer->opcode() == HloOpcode::kReduceWindow &&
      consumer->shape().IsTuple()) {
    // TODO(b/73062247): Variadic reduce window is not yet supported.
    return false;
  }
  if (consumer->opcode() == HloOpcode::kReduceWindow ||
      consumer->opcode() == HloOpcode::kSelectAndScatter) {
    auto first_operand = consumer->mutable_operand(0);
    auto window = consumer->window();
    if (instr_to_dim_map_.count(first_operand) <= 0) {
      VLOG(1) << "Dim map not found on windowed operand. Window dim count "
              << window.dimensions().size();
      return false;
    }
    // Disallow windowing on the batch dim.
    auto result = instr_to_dim_map_[first_operand];
    const int64 old_batch_dim = result.first;
    const int64 old_space_dim = result.second;
    if (window.dimensions(old_batch_dim).size() != 1) {
      return false;
    }

    // Only allow no-low-padding cases.
    if (window.dimensions(old_space_dim).padding_low() != 0) {
      return false;
    }

    // Only allow small high pads.
    if (window.dimensions(old_space_dim).padding_high() >
        window.dimensions(old_space_dim).size()) {
      return false;
    }

    // Operand 0 must have been propagated through.
    if (old_to_new_instrs_.count(first_operand) <= 0) {
      return false;
    }

    auto new_operand = old_to_new_instrs_[first_operand];
    auto permute_dims = instr_to_dim_permute_map_[new_operand];
    const int64 new_space_dim = DimLookUp(permute_dims, old_space_dim);

    // Make sure that the stride lines up.
    if (window.dimensions(old_space_dim).size() != 1) {
      if (new_operand->shape().dimensions(new_space_dim) %
              window.dimensions(old_space_dim).stride() !=
          0) {
        return false;
      }
    }

    return true;
  }

  return false;
}

StatusOr<bool> ConvolutionVisitor::Propagate(HloInstruction* consumer,
                                             HloInstruction* producer) {
  auto computation = consumer->parent();
  if (IsTrivialElementwise(consumer)) {
    auto dim_map_val = instr_to_dim_map_[producer];
    auto new_consumer = computation->AddInstruction(consumer->Clone());

    bool is_pivot_producer_modified = false;
    // For elementwise binary ops, both of whose operands have been space-to-
    // batched, if their new spatial sizes don't match, choose the bigger one
    // as the producer.
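    // For example, if operand 0 was space-to-batched to a spatial size of 56
    // while operand 1 only reached 48, operand 0 becomes the pivot producer,
    // and operand 1 is padded up to size 56 in the loop below.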
    if (consumer->IsElementwiseBinary() ||
        consumer->opcode() == HloOpcode::kSelect) {
      int64 pivot_operand_number = -1;
      HloInstruction* pivot_operand = nullptr;
      for (int i = 0; i < consumer->operand_count(); ++i) {
        if (consumer->operand(i)->opcode() == HloOpcode::kBroadcast) {
          continue;
        }
        auto operand = consumer->mutable_operand(i);
        if (old_to_new_instrs_.contains(operand)) {
          if (pivot_operand_number == -1 ||
              old_to_new_instrs_[pivot_operand]->shape().dimensions() <
                  old_to_new_instrs_[operand]->shape().dimensions()) {
            is_pivot_producer_modified = true;
            pivot_operand_number = i;
            pivot_operand = consumer->mutable_operand(pivot_operand_number);
          }
        }
      }
      if (pivot_operand_number != -1) {
        producer = pivot_operand;
      }
    }

    for (int64 i = 0; i < consumer->operand_count(); ++i) {
      std::vector<HloInstruction*> instructions_to_transform;

      if (consumer->operand(i)->opcode() == HloOpcode::kBroadcast) {
        auto broadcast = consumer->mutable_operand(i);
        PropagateOnBroadcast(broadcast, producer);
        HloInstruction* new_broadcast = nullptr;
        auto new_producer = old_to_new_instrs_[producer];
        for (auto previous_broadcast : broadcast_map_[broadcast]) {
          if (ShapeUtil::CompatibleIgnoringElementType(
                  previous_broadcast->shape(), new_producer->shape())) {
            new_broadcast = previous_broadcast;
            break;
          }
        }
        CHECK_NE(new_broadcast, nullptr);
        TF_CHECK_OK(
            new_consumer->ReplaceOperandWithDifferentShape(i, new_broadcast));
      } else if (old_to_new_instrs_.contains(consumer->mutable_operand(i))) {
        HloInstruction* operand_to_use = nullptr;
        auto result = instr_to_dim_map_[producer];
        const int64 old_batch_dim = result.first;
        const int64 old_space_dim = result.second;
        const int64 old_batch_size =
            producer->shape().dimensions(old_batch_dim);
        HloInstruction* new_instr =
            old_to_new_instrs_[consumer->mutable_operand(i)];
        HloInstruction* pivot_new_instr = old_to_new_instrs_[producer];

        auto permute_dims = instr_to_dim_permute_map_[new_instr];
        const int64 batch_dim = DimLookUp(permute_dims, old_batch_dim);
        const int64 space_dim = DimLookUp(permute_dims, old_space_dim);
        const int64 batch_size = new_instr->shape().dimensions(batch_dim);

        if (new_instr->shape().dimensions(space_dim) !=
            pivot_new_instr->shape().dimensions(space_dim)) {
          // Because we do not propagate through transposes, the batch should
          // always be followed by the split space dimension.
          CHECK_EQ(batch_dim + 1, space_dim);

          // Merge the split batch back into space, pad the space up to the
          // pivot's size, then reshape back to the pivot's split form.
1277           std::vector<int64> new_dimensions(
1278               new_instr->shape().dimensions().begin(),
1279               new_instr->shape().dimensions().end());
1280           new_dimensions[space_dim] *= (batch_size / old_batch_size);
1281           new_dimensions[batch_dim] = old_batch_size;
1282 
1283           TF_ASSIGN_OR_RETURN(HloInstruction * reshape,
1284                               MakeReshapeHlo(new_dimensions, new_instr));
1285 
1286           const int64 pivot_space_size =
1287               pivot_new_instr->shape().dimensions(space_dim) * batch_size /
1288               old_batch_size;
1289 
1290           CHECK(pivot_space_size > new_dimensions[space_dim] ||
1291                 !is_pivot_producer_modified);
1292 
1293           PaddingConfig padding_config =
1294               MakeNoPaddingConfig(reshape->shape().dimensions_size());
1295           padding_config.mutable_dimensions(space_dim)->set_edge_padding_high(
1296               pivot_space_size - new_dimensions[space_dim]);
1297           padding_config.mutable_dimensions(space_dim)->set_edge_padding_low(0);
1298           HloInstruction* padding =
1299               computation_->AddInstruction(HloInstruction::CreateConstant(
1300                   LiteralUtil::Zero(reshape->shape().element_type())));
1301 
1302           TF_ASSIGN_OR_RETURN(HloInstruction * padded_operand,
1303                               MakePadHlo(reshape, padding, padding_config));
1304 
1305           TF_ASSIGN_OR_RETURN(
1306               operand_to_use,
1307               MakeReshapeHlo(pivot_new_instr->shape().dimensions(),
1308                              padded_operand));
1309 
1310         } else {
1311           operand_to_use = old_to_new_instrs_[consumer->mutable_operand(i)];
1312         }
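        // Worked example of the pad round-trip above (hypothetical sizes):
        // with old_batch_size = 1, batch_size = 8, operand spatial size 4 and
        // pivot spatial size 6, the operand [8, 4, ...] is first reshaped to
        // [1, 32, ...], padded high by 8 * 6 / 1 - 32 = 16 elements to
        // [1, 48, ...], and finally reshaped to the pivot's shape [8, 6, ...].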
        TF_CHECK_OK(
            new_consumer->ReplaceOperandWithDifferentShape(i, operand_to_use));
      } else if (consumer->IsElementwiseBinary() &&
                 consumer->mutable_operand(i)->opcode() ==
                     HloOpcode::kBroadcast &&
                 IsBroadcastTree(consumer->mutable_operand(i), producer,
                                 instructions_to_transform)) {
        RewriteBroadcastTree(producer, instructions_to_transform);
        TF_CHECK_OK(new_consumer->ReplaceOperandWithDifferentShape(
            i, old_to_new_instrs_[consumer->mutable_operand(i)]));
      } else if (consumer->operand(i)->opcode() == HloOpcode::kConstant) {
        TF_ASSIGN_OR_RETURN(
            auto new_constant,
            PropagateOnConstant(consumer->mutable_operand(i), producer));
        TF_CHECK_OK(
            new_consumer->ReplaceOperandWithDifferentShape(i, new_constant));
      }
    }
    auto old_type = new_consumer->mutable_shape()->element_type();
    *(new_consumer->mutable_shape()) = old_to_new_instrs_[producer]->shape();

    // The element type needs to be retained.
    new_consumer->mutable_shape()->set_element_type(old_type);

    old_to_new_instrs_[consumer] = new_consumer;
    instr_to_dim_map_[consumer] = dim_map_val;
    CHECK(instr_to_dim_permute_map_.contains(old_to_new_instrs_[producer]));
    instr_to_dim_permute_map_[new_consumer] = std::vector<int64>(
        instr_to_dim_permute_map_[old_to_new_instrs_[producer]]);

    VLOG(2) << " new_consumer " << new_consumer->ToString()
            << " old_to_new_instrs_[producer] "
            << old_to_new_instrs_[producer]->ToString() << " permute dims "
            << instr_to_dim_permute_map_.count(new_consumer);

    return true;
  }

  if (consumer->opcode() == HloOpcode::kConvolution) {
    if (IsConvSuitableForSpaceToBatch(consumer)) {
      TF_CHECK_OK(PropagateOnConv(consumer));
      return true;
    } else {
      TF_CHECK_OK(PropagateOnBackpropFilterConv(consumer));
      return false;
    }
  }

  if (consumer->opcode() == HloOpcode::kReduce) {
    auto new_consumer = computation->AddInstruction(consumer->Clone());
    auto first_operand = old_to_new_instrs_[consumer->mutable_operand(0)];

    auto dim_map_val = instr_to_dim_map_[consumer->mutable_operand(0)];
    const int64 old_batch_dim = dim_map_val.first;
    const int64 old_space_dim = dim_map_val.second;
    auto permute_dims = instr_to_dim_permute_map_[first_operand];
    const int64 new_batch_dim = DimLookUp(permute_dims, old_batch_dim);
    const int64 new_space_dim = DimLookUp(permute_dims, old_space_dim);

    TF_ASSIGN_OR_RETURN(
        first_operand,
        SelectValidPortion(first_operand, consumer->mutable_operand(0),
                           consumer->mutable_operand(1), new_batch_dim,
                           new_space_dim, old_batch_dim, old_space_dim));

    std::vector<int64> changed_dims(new_consumer->dimensions().size());
    for (int64 i = 0; i < new_consumer->dimensions().size(); ++i) {
      changed_dims[i] = DimLookUp(permute_dims, new_consumer->dimensions(i));
    }
    *(new_consumer->mutable_dimensions()) = changed_dims;
    // Replace operand 0.
    TF_CHECK_OK(
        new_consumer->ReplaceOperandWithDifferentShape(0, first_operand));
    // We do not set instr_to_dim_permute_map_ here because no further
    // propagation is needed.
    old_to_new_instrs_[consumer] = new_consumer;
    instr_to_dim_map_[consumer] = dim_map_val;

    // Since the resulting ordering of dimensions is the same as before, no
    // further propagation is needed.
    return false;
  }

  if (consumer->opcode() == HloOpcode::kReduceWindow ||
      consumer->opcode() == HloOpcode::kSelectAndScatter) {
    bool is_select_and_scatter =
        consumer->opcode() == HloOpcode::kSelectAndScatter;
    auto first_operand = old_to_new_instrs_[consumer->mutable_operand(0)];

    auto init_val = is_select_and_scatter ? consumer->mutable_operand(2)
                                          : consumer->mutable_operand(1);
    auto dim_map_val = instr_to_dim_map_[consumer->mutable_operand(0)];
    const int64 old_batch_dim = dim_map_val.first;
    const int64 old_space_dim = dim_map_val.second;
    auto permute_dims = instr_to_dim_permute_map_[first_operand];
    const int64 new_batch_dim = DimLookUp(permute_dims, old_batch_dim);
    const int64 new_space_dim = DimLookUp(permute_dims, old_space_dim);

    TF_ASSIGN_OR_RETURN(
        first_operand,
        SelectValidPortion(first_operand, consumer->mutable_operand(0),
                           init_val, new_batch_dim, new_space_dim,
                           old_batch_dim, old_space_dim));

    // Calculate the required halo size.
    auto new_shape = first_operand->shape();
    auto old_shape = consumer->mutable_operand(0)->shape();

    const int64 new_batch_size = new_shape.dimensions(new_batch_dim);
    const int64 new_space_size = new_shape.dimensions(new_space_dim);
    const int64 stride = consumer->window().dimensions(old_space_dim).stride();
    const int64 window_size =
        consumer->window().dimensions(old_space_dim).size();
    const int64 last_overlap_point = ((new_space_size - 1) / stride) * stride;
    VLOG(1) << "last_overlap_point " << last_overlap_point << " window_size "
            << window_size << " new_space_size " << new_space_size;

    const int64 halo_size = last_overlap_point + window_size - new_space_size;
    if (halo_size > 0) {
      TF_ASSIGN_OR_RETURN(first_operand,
                          HaloDuplicateWithSlice(first_operand, new_space_dim,
                                                 new_batch_dim, new_batch_size,
                                                 /*low_padding=*/0,
                                                 /*high_padding=*/0, halo_size,
                                                 new_space_size, init_val));
    }
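    // Illustrative halo computation (hypothetical values): with
    // new_space_size = 8, stride = 2 and window_size = 3, the last window
    // starts at ((8 - 1) / 2) * 2 = 6 and ends at 6 + 3 = 9, so
    // halo_size = 9 - 8 = 1 element must be borrowed from the next batch.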

    Window new_win;
    for (int64 i = 0; i < consumer->window().dimensions().size(); ++i) {
      auto dim = ReverseDimLookUp(permute_dims, i);
      new_win.add_dimensions();
      new_win.mutable_dimensions(i)->set_stride(
          consumer->window().dimensions(dim).stride());
      new_win.mutable_dimensions(i)->set_size(
          consumer->window().dimensions(dim).size());
      if (i == old_space_dim) {
        new_win.mutable_dimensions(i)->set_padding_high(0);
        new_win.mutable_dimensions(i)->set_padding_low(0);
      } else {
        new_win.mutable_dimensions(i)->set_padding_high(
            consumer->window().dimensions(dim).padding_high());
        new_win.mutable_dimensions(i)->set_padding_low(
            consumer->window().dimensions(dim).padding_low());
      }
      new_win.mutable_dimensions(i)->set_window_dilation(
          consumer->window().dimensions(dim).window_dilation());
      new_win.mutable_dimensions(i)->set_base_dilation(
          consumer->window().dimensions(dim).base_dilation());
      new_win.mutable_dimensions(i)->set_window_reversal(
          consumer->window().dimensions(dim).window_reversal());
    }

    new_shape = first_operand->shape();

    HloInstruction* new_consumer = nullptr;
    if (is_select_and_scatter) {
      auto second_operand = old_to_new_instrs_[consumer->mutable_operand(1)];

      auto select_comp = consumer->select();

      auto scatter_comp = consumer->scatter();
      TF_ASSIGN_OR_RETURN(
          auto new_select_and_scatter_shape,
          ShapeInference::InferSelectAndScatterShape(
              new_shape, select_comp->ComputeProgramShape(), new_win,
              second_operand->shape(), init_val->shape(),
              scatter_comp->ComputeProgramShape()));
      new_consumer =
          computation_->AddInstruction(HloInstruction::CreateSelectAndScatter(
              new_select_and_scatter_shape, first_operand, select_comp, new_win,
              second_operand, init_val, scatter_comp));
      // Replace operand 0.
      TF_CHECK_OK(
          new_consumer->ReplaceOperandWithDifferentShape(0, first_operand));
      // Replace operand 1.
      TF_CHECK_OK(
          new_consumer->ReplaceOperandWithDifferentShape(1, second_operand));
      VLOG(2) << "New select and scatter " << new_consumer->ToString();

      // If the window size was larger than the stride, there could be overlaps.
      // Such cases require updates from both overlaps to be applied.
      if (halo_size > 0) {
        const int64 rank = new_consumer->shape().rank();

        const int64 batch_size =
            new_consumer->shape().dimensions(new_batch_dim);

        std::vector<int64> start_indices(rank, 0),
            end_indices(new_consumer->shape().dimensions().begin(),
                        new_consumer->shape().dimensions().end()),
            strides(rank, 1);
        start_indices[new_space_dim] = new_space_size;
        end_indices[new_space_dim] = new_space_size + halo_size;
        end_indices[new_batch_dim] = batch_size - 1;

        // This is the slice from halo padding.
        TF_ASSIGN_OR_RETURN(
            HloInstruction * bottom,
            MakeSliceHlo(new_consumer, start_indices, end_indices, strides));

        std::vector<int64> start_indices_top(rank, 0),
            end_indices_top(new_consumer->shape().dimensions().begin(),
                            new_consumer->shape().dimensions().end());
        end_indices_top[new_space_dim] = halo_size;
        // The first batch has correct data.
        start_indices_top[new_batch_dim] = 1;

        // This is the original area from which the halo pad was extracted.
        TF_ASSIGN_OR_RETURN(HloInstruction * top,
                            MakeSliceHlo(new_consumer, start_indices_top,
                                         end_indices_top, strides));

        HloInstruction* default_fill =
            MakeBroadcastHlo(init_val, {}, top->shape().dimensions());

        // Compare to see if the bottom area was changed.
        TF_ASSIGN_OR_RETURN(
            HloInstruction * bottom_compare,
            MakeCompareHlo(ComparisonDirection::kNe, bottom, default_fill));

        // Take out only the changed values.
        TF_ASSIGN_OR_RETURN(
            HloInstruction * bottom_taken,
            MakeSelectHlo(bottom_compare, bottom, default_fill));

        // Compare to see if the top area was changed.
        TF_ASSIGN_OR_RETURN(
            HloInstruction * top_compare,
            MakeCompareHlo(ComparisonDirection::kNe, top, default_fill));

        // Take out only the changed values.
        TF_ASSIGN_OR_RETURN(HloInstruction * top_taken,
                            MakeSelectHlo(top_compare, top, bottom_taken));

        // This checks whether the area was updated by both overlaps.
        TF_ASSIGN_OR_RETURN(
            HloInstruction * both_compare,
            MakeBinaryHlo(HloOpcode::kAnd, top_compare, bottom_compare));

        // If it was, add them up.
        TF_ASSIGN_OR_RETURN(HloInstruction * both_added,
                            MakeBinaryHlo(HloOpcode::kAdd, top, bottom));

        // Combine the updates: use the sum where both overlaps touched the
        // area, otherwise keep the single update.
        TF_ASSIGN_OR_RETURN(HloInstruction * final_selection,
                            MakeSelectHlo(both_compare, both_added, top_taken));

        // Pad the final result back to the original shape.
        PaddingConfig padding_config =
            MakeNoPaddingConfig(final_selection->shape().dimensions_size());
        padding_config.mutable_dimensions(new_batch_dim)
            ->set_edge_padding_low(1);
        padding_config.mutable_dimensions(new_space_dim)
            ->set_edge_padding_high(new_space_size);
        HloInstruction* padding =
            computation_->AddInstruction(HloInstruction::CreateConstant(
                LiteralUtil::Zero(final_selection->shape().element_type())));

        TF_ASSIGN_OR_RETURN(
            final_selection,
            MakePadHlo(final_selection, padding, padding_config));

        tensorflow::core::Bitmap b(batch_size * (new_space_size + halo_size));
        for (int k = 0; k < batch_size * (new_space_size + halo_size); ++k) {
          const int64 space_index = k % (new_space_size + halo_size);
          const int64 batch_index = (k / (new_space_size + halo_size));
          if (batch_index < 1 || space_index >= halo_size) {
            b.set(k);
          } else {
            b.clear(k);
          }
        }
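        // Mask layout example (hypothetical sizes): with batch_size = 4,
        // new_space_size = 8 and halo_size = 1, the bitmap covers 4 rows of
        // 9 elements; only the first halo_size elements of every batch row
        // except the first are cleared, i.e. positions 9, 18 and 27, since
        // those positions also appeared as the previous batch's halo.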

        auto arg_literal = LiteralUtil::CreateR1(b);
        VLOG(4) << "Slice mask created: arg literal " << arg_literal.ToString();
        HloInstruction* slice_mask = computation_->AddInstruction(
            HloInstruction::CreateConstant(std::move(arg_literal)));

        std::vector<int64> slice_mask_reshape_dims(2);
        slice_mask_reshape_dims[0] = batch_size;
        slice_mask_reshape_dims[1] = (new_space_size + halo_size);

        TF_ASSIGN_OR_RETURN(
            HloInstruction * slice_mask_reshaped,
            MakeReshapeHlo(slice_mask_reshape_dims, slice_mask));

        // Broadcast the mask in all dimensions.
        HloInstruction* shape_mask = MakeBroadcastHlo(
            slice_mask_reshaped, {new_batch_dim, new_space_dim},
            final_selection->shape().dimensions());

        TF_ASSIGN_OR_RETURN(
            new_consumer,
            MakeSelectHlo(shape_mask, new_consumer, final_selection));
      }

      auto previous_shape =
          old_to_new_instrs_[consumer->mutable_operand(0)]->shape();
      std::vector<int64> start_indices(previous_shape.rank(), 0),
          end_indices(previous_shape.dimensions().begin(),
                      previous_shape.dimensions().end()),
          strides(previous_shape.rank(), 1);

      TF_ASSIGN_OR_RETURN(
          new_consumer,
          MakeSliceHlo(new_consumer, start_indices, end_indices, strides));

    } else {
      auto reduce_comp = consumer->to_apply();
      TF_ASSIGN_OR_RETURN(auto new_reduce_window_shape,
                          ShapeInference::InferReduceWindowShape(
                              new_shape, init_val->shape(), new_win));
      new_consumer =
          computation_->AddInstruction(HloInstruction::CreateReduceWindow(
              new_reduce_window_shape, first_operand, init_val, new_win,
              reduce_comp));
      // Replace operand 0.
      TF_CHECK_OK(
          new_consumer->ReplaceOperandWithDifferentShape(0, first_operand));
      VLOG(1) << "New reduce window " << new_consumer->ToString();
    }

    old_to_new_instrs_[consumer] = new_consumer;
    instr_to_dim_map_[consumer] = dim_map_val;

    instr_to_dim_permute_map_[new_consumer] = std::vector<int64>(
        instr_to_dim_permute_map_[old_to_new_instrs_[consumer->mutable_operand(
            0)]]);

    return true;
  }

  LOG(FATAL) << "Trying to propagate through an unsupported instruction "
             << consumer->ToString();
  return true;
}

StatusOr<HloInstruction*> ConvolutionVisitor::SelectValidPortion(
    HloInstruction* new_instr, HloInstruction* old_instr,
    HloInstruction* select_val, int64 new_batch_dim, int64 new_space_dim,
    int64 old_batch_dim, int64 old_space_dim) {
  auto new_shape = new_instr->shape();
  auto old_shape = old_instr->shape();
  VLOG(1) << "In SelectValidPortion new_batch_dim " << new_batch_dim
          << " new_space_dim " << new_space_dim << " old_batch_dim "
          << old_batch_dim << " old_space_dim " << old_space_dim;
  const int64 new_batch_size = new_shape.dimensions(new_batch_dim);
  const int64 new_space_size = new_shape.dimensions(new_space_dim);
  const int64 old_batch_size = old_shape.dimensions(old_batch_dim);
  const int64 old_space_size = old_shape.dimensions(old_space_dim);
  CHECK_EQ(new_batch_size % old_batch_size, 0)
      << " New batch size " << new_batch_size << " old batch size "
      << old_batch_size;
  const int64 num_splits = new_batch_size / old_batch_size;
  // Build a constant PRED to decide which elements in the split dimension
  // are from halo.
  tensorflow::core::Bitmap b(new_batch_size * new_space_size);
  for (int k = 0; k < new_batch_size * new_space_size; ++k) {
    const int64 space_index = k % new_space_size;
    const int64 batch_index = (k / new_space_size) % num_splits;
    if (batch_index * new_space_size + space_index < old_space_size) {
      b.set(k);
    } else {
      b.clear(k);
    }
  }
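  // Illustrative mask (hypothetical sizes): with num_splits = 4,
  // new_space_size = 5 and old_space_size = 16, the linearized positions
  // batch_index * 5 + space_index in 0..15 map to valid input, so the last
  // 4 of the 20 per-(old-)batch positions are cleared; they only hold padding.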

  auto arg_literal = LiteralUtil::CreateR1(b);
  VLOG(4) << "Slice mask created: arg literal " << arg_literal.ToString();
  HloInstruction* slice_mask = computation_->AddInstruction(
      HloInstruction::CreateConstant(std::move(arg_literal)));

  std::vector<int64> slice_mask_reshape_dims(2);
  slice_mask_reshape_dims[0] = new_batch_size;
  slice_mask_reshape_dims[1] = new_space_size;

  TF_ASSIGN_OR_RETURN(HloInstruction * slice_mask_reshaped,
                      MakeReshapeHlo(slice_mask_reshape_dims, slice_mask));

  // Broadcast the mask in all dimensions of the activations.
  HloInstruction* shape_mask =
      MakeBroadcastHlo(slice_mask_reshaped, {new_batch_dim, new_space_dim},
                       new_instr->shape().dimensions());

  VLOG(1) << "Shape mask made " << shape_mask->ToString();

  HloInstruction* zeroes =
      MakeBroadcastHlo(select_val, {}, new_instr->shape().dimensions());

  TF_ASSIGN_OR_RETURN(new_instr, MakeSelectHlo(shape_mask, new_instr, zeroes));

  return new_instr;
}

StatusOr<HloInstruction*> ConvolutionVisitor::BatchToSpace(
    HloInstruction* old_instr) {
  if (batch_to_space_map_.count(old_instr)) {
    CHECK_NE(batch_to_space_map_[old_instr], nullptr);
    return batch_to_space_map_[old_instr];
  }

  auto result = instr_to_dim_map_[old_instr];
  const int64 old_batch_dim = result.first;
  const int64 old_space_dim = result.second;

  const int64 old_batch_size = old_instr->shape().dimensions(old_batch_dim);
  CHECK(old_to_new_instrs_.contains(old_instr));
  auto new_instr = old_to_new_instrs_[old_instr];
  VLOG(2) << "old_batch_dim " << old_batch_dim << " old_space_dim "
          << old_space_dim << " old_instr " << old_instr->ToString()
          << "\n new_instr " << new_instr->ToString() << " permute dims "
          << instr_to_dim_permute_map_.count(new_instr) << " old_batch_size "
          << old_batch_size;
  CHECK(instr_to_dim_permute_map_.contains(new_instr));
  auto permute_dims = instr_to_dim_permute_map_[new_instr];
  const int64 batch_dim = DimLookUp(permute_dims, old_batch_dim);
  const int64 space_dim = DimLookUp(permute_dims, old_space_dim);
  const int64 batch_size = new_instr->shape().dimensions(batch_dim);

  std::vector<int64> new_dimensions(new_instr->shape().dimensions().begin(),
                                    new_instr->shape().dimensions().end());
  new_dimensions[space_dim] *= (batch_size / old_batch_size);
  new_dimensions[batch_dim] = old_batch_size;
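  // Illustrative example (hypothetical sizes): if the space-to-batched result
  // is [8, 5, ...] with old_batch_size = 2, the merged dimensions become
  // [2, 20, ...]; the slice below then trims the spatial dimension back to
  // the original (unpadded) size.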
  // Reshape the output of the new conv into the old convolution's shape.
  TF_ASSIGN_OR_RETURN(HloInstruction * reshape,
                      MakeReshapeHlo(new_dimensions, new_instr));

  const int64 rank = old_instr->shape().rank();
  std::vector<int64> start_indices(rank, 0),
      end_indices(new_dimensions.begin(), new_dimensions.end()),
      strides(rank, 1);
  end_indices[space_dim] = old_instr->shape().dimensions(old_space_dim);

  // This slicing removes the padding we added to evenly divide space.
  TF_ASSIGN_OR_RETURN(
      HloInstruction * output_slice,
      MakeSliceHlo(reshape, start_indices, end_indices, strides));
  VLOG(1) << "Batch to space slice " << output_slice->ToString();
  std::vector<int64> transpose_dims(permute_dims);
  TF_ASSIGN_OR_RETURN(HloInstruction * output_transpose,
                      MakeTransposeHlo(output_slice, transpose_dims));

  old_instr->SetupDerivedInstruction(output_transpose);

  batch_to_space_map_[old_instr] = output_transpose;
  return output_transpose;
}

Status ConvolutionVisitor::PropagateOnUsers(HloInstruction* old_conv) {
  std::queue<std::pair<HloInstruction*, HloInstruction*>> propagation_worklist;

  if (old_conv->user_count() == 0) {
    TF_ASSIGN_OR_RETURN(HloInstruction * batch_to_space,
                        BatchToSpace(old_conv));
    VLOG(1) << "Replacing the root instruction with "
            << batch_to_space->ToString();
    TF_CHECK_OK(computation_->ReplaceInstruction(old_conv, batch_to_space));
    VLOG(1) << "Replacement successful";
    return Status::OK();
  }

  int64 iteration_count = 0;
  propagation_worklist.push(
      std::make_pair(old_conv, old_conv->mutable_operand(0)));

  while (!propagation_worklist.empty()) {
    auto top = propagation_worklist.front();
    auto node = top.first;
    auto parent = top.second;
    VLOG(1) << "Traversing for propagation operating on " << node->ToString();
    propagation_worklist.pop();

    // Don't work on the same node again.
    if (old_to_new_instrs_.count(node) > 0 && iteration_count != 0) {
      continue;
    }

    bool needs_further_propagation = true;
    if (iteration_count != 0) {
      // Do the space-to-batch propagation on this node.
      TF_ASSIGN_OR_RETURN(needs_further_propagation, Propagate(node, parent));
    }
    iteration_count++;
    // If this is the root, there is no room for further propagation.
    if (node->parent()->root_instruction() == node) {
      // The case below does not require going back to space.
      if (!needs_further_propagation) {
        VLOG(1) << "Replacing the root instruction with "
                << old_to_new_instrs_[node]->ToString();
        TF_CHECK_OK(
            computation_->ReplaceInstruction(node, old_to_new_instrs_[node]));
        continue;
      }

      TF_ASSIGN_OR_RETURN(HloInstruction * batch_to_space, BatchToSpace(node));
      VLOG(1) << "Replacing the root instruction with "
              << batch_to_space->ToString();
      TF_CHECK_OK(computation_->ReplaceInstruction(node, batch_to_space));
    } else {
      if (!needs_further_propagation) {
        TF_CHECK_OK(
            computation_->ReplaceInstruction(node, old_to_new_instrs_[node]));
        continue;
      }
      // Insert all users into the queue, as long as the ops are supported and
      // the op is ready for propagation. If the op is unsupported, do
      // batch-to-space. If not ready, mark as non-propagatable.
      for (auto user : node->users()) {
        if (!SupportedOpForPropagation(user, node)) {
          VLOG(1) << "Unsupported op found " << user->ToString();
          TF_ASSIGN_OR_RETURN(HloInstruction * batch_to_space,
                              BatchToSpace(node));
          for (int64 i = 0; i < user->operand_count(); ++i) {
            if (user->operand(i) == node) {
              TF_CHECK_OK(user->ReplaceOperandWith(i, batch_to_space));
            }
          }
          continue;
        }
        // If the instruction is ready for propagation, add it to the queue.
        if (CanPropagate(user, node)) {
          non_propagatable_instrs_.erase(user);
          propagation_worklist.push(std::make_pair(user, node));
        } else {
          // Mark it as non-propagatable for now, for later revisiting.
          non_propagatable_instrs_.insert(user);
        }
      }
    }
  }
  return Status::OK();
}

Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) {
  auto activations_old = convolution->mutable_operand(0);

  CHECK(old_to_new_instrs_.contains(activations_old));
  auto activations_new = old_to_new_instrs_[activations_old];
  auto permute_dims = instr_to_dim_permute_map_[activations_new];

  auto original_conv_dims = convolution->convolution_dimension_numbers();

  const int64 old_space_dim = original_conv_dims.input_spatial_dimensions(
      get_chosen_spatial_dim(convolution));
  const int64 old_split_dim_size =
      convolution->mutable_operand(0)->shape().dimensions(old_space_dim);

  auto permuted_conv_dims_numbers = original_conv_dims;

  int64 activations_batch_dim =
      DimLookUp(permute_dims, original_conv_dims.input_batch_dimension());
  int64 activations_feature_dim =
      DimLookUp(permute_dims, original_conv_dims.input_feature_dimension());
  permuted_conv_dims_numbers.set_input_batch_dimension(activations_batch_dim);
  permuted_conv_dims_numbers.set_input_feature_dimension(
      activations_feature_dim);

  for (int64 i = 0; i < original_conv_dims.input_spatial_dimensions_size();
       ++i) {
    permuted_conv_dims_numbers.set_input_spatial_dimensions(
        i, DimLookUp(permute_dims,
                     original_conv_dims.input_spatial_dimensions(i)));
  }

  const int64 old_batch_dim = original_conv_dims.input_batch_dimension();
  const int64 old_batch_size =
      activations_old->shape().dimensions(old_batch_dim);

  ConvDetails c =
      GetConvolutionDetails(convolution, permuted_conv_dims_numbers);

  VLOG(1) << "Propagating on conv activations_batch_dim "
          << activations_batch_dim << " spatial_dimension_to_split "
          << c.spatial_dimension_to_split << " old_batch_size "
          << old_batch_size;

  TF_ASSIGN_OR_RETURN(auto retval,
                      BringSpaceNextToBatch(
                          activations_new, permuted_conv_dims_numbers,
                          c.spatial_dimension_to_split, activations_batch_dim));
  activations_new = retval.instr;
  std::vector<int64> trans_dims = retval.transpose_dims;
  CHECK(!trans_dims.empty());
  auto select_val = computation_->AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::Zero(activations_new->shape().element_type())));

  TF_ASSIGN_OR_RETURN(
      activations_new,
      SelectValidPortion(activations_new, activations_old, select_val,
                         activations_batch_dim, c.spatial_dimension_to_split,
                         old_batch_dim, old_space_dim));
  // Create the new convolution dim numbers.
  auto new_dim_numbers = permuted_conv_dims_numbers;

  VLOG(1) << "spatial size " << c.spatial_size;

  const int64 num_splits = kNumSplits;
  const int64 output_offsets = convolution->shape().dimensions(
      permuted_conv_dims_numbers.output_spatial_dimensions(
          get_chosen_spatial_dim(convolution)));
  const int64 output_offsets_per_split =
      CeilOfRatio(output_offsets, num_splits);

  int64 spatial_split_size =
      CeilOfRatio(output_offsets_per_split, c.base_dilation_factor) * c.stride;

  // Keep increasing the split size so that the overall size isn't smaller
  // than the original spatial dimension. Unlike for the first
  // space-to-batch'ed convolution, while propagating, we can use the last
  // halo_size as available spatial size.
  while (spatial_split_size * num_splits + c.halo_size - c.spatial_size < 0) {
    spatial_split_size += c.stride;
  }
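  // Illustrative sizing (hypothetical values): with num_splits = 8,
  // output_offsets = 28, base dilation 1 and stride 2, the initial
  // spatial_split_size is ceil(28 / 8) * 2 = 8; if 8 * 8 + halo_size still
  // falls short of c.spatial_size, the loop above grows the split size by
  // one stride at a time until the pieces cover the whole spatial dimension.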

  int64 slice_size = spatial_split_size + c.halo_size;

  VLOG(1) << "spatial_split_size " << spatial_split_size << " slice_size "
          << slice_size;

  const int64 new_space_size =
      activations_new->shape().dimensions(c.spatial_dimension_to_split);
  // In the case below, we cannot use the activations directly for halo
  // duplication; we must reshape them first.
  if (spatial_split_size > new_space_size) {
    TF_ASSIGN_OR_RETURN(
        activations_new,
        IncreaseSpatialSizeOnSpaceToBatchedShape(
            activations_new, activations_batch_dim, old_batch_size,
            c.spatial_dimension_to_split, spatial_split_size));

    TF_ASSIGN_OR_RETURN(
        activations_new,
        HaloDuplicateWithSlice(activations_new, c.spatial_dimension_to_split,
                               activations_batch_dim, old_batch_size,
                               /*low_padding=*/c.base_dilation_factor != 1 &&
                                       c.inherent_low_padding != 0
                                   ? c.base_dilation_factor - 1
                                   : c.inherent_low_padding,
                               c.inherent_high_padding,
                               slice_size - spatial_split_size,
                               old_split_dim_size));
  } else {
    // If the ideal spatial_split_size was smaller than the incoming spatial
    // dimension size, we don't need reshaping. Instead, we determine the
    // additional space available, and adjust the required slice size (and
    // thereby the halo size).
    if (spatial_split_size < new_space_size) {
      VLOG(3) << "Decreasing the spatial size while propagating";
      const int64 additional_space_present = spatial_split_size % c.stride;
      spatial_split_size = new_space_size;
      slice_size =
          spatial_split_size + std::max(c.kernel_spatial_dim_size - c.stride -
                                            additional_space_present,
                                        static_cast<int64>(0));
    }
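    // Illustrative adjustment (hypothetical values): if spatial_split_size
    // was 7 with stride 2, kernel size 3 and an incoming spatial size of 8,
    // then additional_space_present = 7 % 2 = 1, spatial_split_size becomes
    // 8, and slice_size = 8 + max(3 - 2 - 1, 0) = 8, so no extra halo is
    // needed in this case.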

    TF_ASSIGN_OR_RETURN(
        activations_new,
        HaloDuplicateWithSlice(activations_new, c.spatial_dimension_to_split,
                               activations_batch_dim, old_batch_size,
                               /*low_padding=*/c.base_dilation_factor != 1 &&
                                       c.inherent_low_padding != 0
                                   ? c.base_dilation_factor - 1
                                   : c.inherent_low_padding,
                               c.inherent_high_padding,
                               slice_size - spatial_split_size,
                               old_split_dim_size));
  }

  // We will generate output such that batch is followed by the split spatial
  // dimension.
  const int64 rank = (convolution->shape().rank());
  std::vector<int64> transpose_dims(rank);
  int dim_count = 0;
  std::map<int64, int64> dim_map;

  for (int j = 0;
       j < permuted_conv_dims_numbers.output_spatial_dimensions_size(); ++j) {
    if (j == get_chosen_spatial_dim(convolution)) {
      dim_map[permuted_conv_dims_numbers.output_batch_dimension()] = dim_count;
      new_dim_numbers.set_output_batch_dimension(dim_count++);
    }
    dim_map[permuted_conv_dims_numbers.output_spatial_dimensions(j)] =
        dim_count;
    new_dim_numbers.set_output_spatial_dimensions(j, dim_count);
    dim_count++;
  }

  dim_map[permuted_conv_dims_numbers.output_feature_dimension()] = dim_count;
  new_dim_numbers.set_output_feature_dimension(dim_count);

  int p = 0;
  for (const auto& entry : dim_map) {
    transpose_dims[p] = entry.second;
    p++;
  }

  auto new_window = convolution->window();
  new_window.mutable_dimensions(get_chosen_spatial_dim(convolution))
      ->set_padding_high(c.high_padding_for_conv);
  new_window.mutable_dimensions(get_chosen_spatial_dim(convolution))
      ->set_padding_low(c.low_padding_for_conv);
  TF_ASSIGN_OR_RETURN(
      HloInstruction * new_conv,
      MakeConvolveHlo(
          activations_new, /*rhs=*/convolution->mutable_operand(1),
          convolution->feature_group_count(), convolution->batch_group_count(),
          new_window, new_dim_numbers, convolution->precision_config(),
          /*preferred_element_type=*/convolution->shape().element_type()));
  convolution->SetupDerivedInstruction(new_conv);

  old_to_new_instrs_[convolution] = new_conv;
  VLOG(1) << "Space-to-batched convolution " << new_conv->ToString();

  instr_to_dim_map_[convolution] =
      std::make_pair(original_conv_dims.output_batch_dimension(),
                     original_conv_dims.output_spatial_dimensions(
                         get_chosen_spatial_dim(convolution)));

  instr_to_dim_permute_map_[new_conv] = std::vector<int64>(transpose_dims);

  convs_to_visit_.erase(convolution);
  return Status::OK();
}

StatusOr<HloInstruction*> ConvolutionVisitor::SplitSpaceHelper(
    HloInstruction* activations, int64 spatial_dimension_to_split,
    int64 activations_batch_dim, int64 high_padding, int64 low_padding,
    int64 spatial_split_size, int64 num_splits) {
  const int64 old_batch_size =
      activations->shape().dimensions(activations_batch_dim);

  // Because we are splitting the spatial dimension, if the convolution needed
  // padding in the spatial dimension, we materialize it here.
  if (high_padding || low_padding) {
    PaddingConfig padding_config =
        MakeNoPaddingConfig(activations->shape().dimensions_size());
    padding_config.mutable_dimensions(spatial_dimension_to_split)
        ->set_edge_padding_high(high_padding);
    padding_config.mutable_dimensions(spatial_dimension_to_split)
        ->set_edge_padding_low(low_padding);
    HloInstruction* padding =
        computation_->AddInstruction(HloInstruction::CreateConstant(
            LiteralUtil::Zero(activations->shape().element_type())));
    TF_ASSIGN_OR_RETURN(activations,
                        MakePadHlo(activations, padding, padding_config));
  }
  VLOG(1) << "Initial padded activations shape "
          << activations->shape().ToString() << " old_batch_size "
          << old_batch_size << " activations_batch_dim "
          << activations_batch_dim;

  // Now we reorganize the activations. E.g. if the shape [B, SPACE] was [1, 16]
  // and 4 splits were needed, we first create [4, 4]. Next, to deal with halo
  // in the spatial dimension, we generate a gather. E.g. if halo size was 2,
  // we'd create a shape of [24] using the gather, and reshape it into [4, 6]
  // (4 being the batch).

  // The benefit of the above-mentioned scheme is that it allows for batch
  // growth. Here are some examples of the size increases it causes for a 3x3
  // kernel.
  // with batch=1, [1,16] -> [4,4]  -> [4,6]  -> [1,24] growth of 8.
  // with batch=2, [2,16] -> [8,4]  -> [8,6]  -> [1,48] growth of 16.
  // with batch=3, [3,16] -> [12,4] -> [12,6] -> [1,72] growth of 24.

  std::vector<int64> reshape_dimensions(
      activations->shape().dimensions().begin(),
      activations->shape().dimensions().end());

  reshape_dimensions[spatial_dimension_to_split] = spatial_split_size;
  reshape_dimensions[activations_batch_dim] = num_splits * old_batch_size;

  TF_ASSIGN_OR_RETURN(HloInstruction * batch_increased_reshape,
                      MakeReshapeHlo(reshape_dimensions, activations));

  return batch_increased_reshape;
}

StatusOr<std::pair<HloInstruction*, std::vector<int64>>>
ConvolutionVisitor::SplitSpace(HloInstruction* activations,
                               ConvolutionDimensionNumbers& dim_numbers,
                               int64& spatial_dimension_to_split,
                               int64& activations_batch_dim, int64 high_padding,
                               int64 low_padding, int64 spatial_split_size,
                               int64 num_splits, bool is_backprop,
                               bool is_rhs) {
  TF_ASSIGN_OR_RETURN(auto retval,
                      BringSpaceNextToBatch(
                          activations, dim_numbers, spatial_dimension_to_split,
                          activations_batch_dim, is_backprop, is_rhs));

  activations = retval.instr;
  std::vector<int64> transpose_dims = retval.transpose_dims;
  TF_ASSIGN_OR_RETURN(
      auto new_activations,
      SplitSpaceHelper(activations, spatial_dimension_to_split,
                       activations_batch_dim, high_padding, low_padding,
                       spatial_split_size, num_splits));
  return std::make_pair(new_activations, transpose_dims);
}

StatusOr<HloInstruction*> ConvolutionVisitor::PropagateOnConstant(
    HloInstruction* consumer, HloInstruction* producer) {
  CHECK(old_to_new_instrs_.contains(producer));
  HloInstruction* new_producer = old_to_new_instrs_[producer];
  auto prod_transpose_dims = instr_to_dim_permute_map_[new_producer];
  std::vector<int64> reversed_transpose_dims(prod_transpose_dims.size());
  for (int64 i = 0; i < prod_transpose_dims.size(); ++i) {
    reversed_transpose_dims[i] = ReverseDimLookUp(prod_transpose_dims, i);
  }
  // Bring space next to batch.
  TF_ASSIGN_OR_RETURN(consumer,
                      MakeTransposeHlo(consumer, reversed_transpose_dims));

  auto dim_map = instr_to_dim_map_[producer];
  const int64 old_batch_dim = dim_map.first;
  const int64 old_space_dim = dim_map.second;
  const int64 new_batch_dim = DimLookUp(prod_transpose_dims, old_batch_dim);
  const int64 new_space_dim = DimLookUp(prod_transpose_dims, old_space_dim);

  const int64 old_batch_size = producer->shape().dimensions(old_batch_dim);
  const int64 new_batch_size = new_producer->shape().dimensions(new_batch_dim);
  const int64 high_padding =
      (new_batch_size * new_producer->shape().dimensions(new_space_dim) -
       old_batch_size * producer->shape().dimensions(old_space_dim)) /
      old_batch_size;
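  // Illustrative padding computation (hypothetical sizes): for a constant
  // whose producer was [1, 16] space-to-batched to [4, 6], the transposed
  // constant must be padded high by (4 * 6 - 1 * 16) / 1 = 8 elements before
  // SplitSpaceHelper reshapes it to match the producer's split layout.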
2138 
2139   auto new_consumer = SplitSpaceHelper(
2140       consumer, new_space_dim, new_batch_dim, high_padding, /*low_padding=*/0,
2141       new_producer->shape().dimensions(new_space_dim), kNumSplits);
2142 
2143   return new_consumer;
2144 }
2145 
PropagateOnBackpropFilterConv(HloInstruction * convolution)2146 Status ConvolutionVisitor::PropagateOnBackpropFilterConv(
2147     HloInstruction* convolution) {
2148   auto activations_old = convolution->mutable_operand(0);
2149 
2150   const int64 rhs_dilation =
2151       convolution->window()
2152           .dimensions(get_chosen_spatial_dim(convolution))
2153           .window_dilation();
2154 
2155   auto original_conv_dims = convolution->convolution_dimension_numbers();
2156   int64 kernel_space_dim = original_conv_dims.kernel_spatial_dimensions(
2157       get_chosen_spatial_dim(convolution));
2158   auto kernel_old = convolution->mutable_operand(1);
2159   const int64 old_kernel_split_dim_size =
2160       kernel_old->shape().dimensions(kernel_space_dim);
2161 
2162   int64 old_space_dim = original_conv_dims.input_spatial_dimensions(
2163       get_chosen_spatial_dim(convolution));
2164   int64 old_split_dim_size = activations_old->shape().dimensions(old_space_dim);
2165 
2166   int64 old_batch_dim = original_conv_dims.input_feature_dimension();
2167   int64 kernel_old_batch_dim =
2168       original_conv_dims.kernel_input_feature_dimension();
2169   const int64 old_batch_size =
2170       activations_old->shape().dimensions(old_batch_dim);
2171 
2172   CHECK(old_to_new_instrs_.contains(kernel_old) ||
2173         old_to_new_instrs_.contains(activations_old));
2174 
2175   HloInstruction* activations_new = nullptr;
2176   HloInstruction* kernel_new = nullptr;
2177   bool activations_locally_space_to_batched = false;
2178   bool kernel_locally_space_to_batched = false;
2179   std::vector<int64> permute_dims_kernel, permute_dims;
2180   // If activations were no space-to-batched, we space-to-batch them below.
2181   if (!old_to_new_instrs_.contains(activations_old)) {
2182     kernel_new = old_to_new_instrs_[kernel_old];
2183     permute_dims_kernel = instr_to_dim_permute_map_[kernel_new];
2184 
2185     VLOG(1) << "Space-to-batching activations to enable space-to-depth";
2186     const int64 prev_feature_dim = original_conv_dims.input_feature_dimension();
2187     const int64 prev_batch_dim = original_conv_dims.input_batch_dimension();
2188     instr_to_dim_map_[activations_old] =
2189         std::make_pair(prev_feature_dim, prev_batch_dim);
2190 
2191     const int64 new_kernel_space_dim =
2192         DimLookUp(permute_dims_kernel, kernel_space_dim);
2193 
2194     const int64 new_kernel_split_dim_size =
2195         kernel_new->shape().dimensions(new_kernel_space_dim);
2196     const int64 needed_spatial_size = rhs_dilation * new_kernel_split_dim_size;
2197     const int64 pad_size =
2198         needed_spatial_size * kNumSplits - old_split_dim_size;
2199     ConvolutionDimensionNumbers tmp_dim_numbers;
2200     tmp_dim_numbers = original_conv_dims;
2201     TF_ASSIGN_OR_RETURN(
2202         auto retval,
2203         SplitSpace(activations_old, tmp_dim_numbers, old_space_dim,
2204                    old_batch_dim,
2205                    /*high_padding=*/pad_size, /*low_padding=*/0,
2206                    needed_spatial_size, kNumSplits, /*is_backprop=*/true));
2207 
2208     old_to_new_instrs_[activations_old] = retval.first;
2209 
2210     std::vector<int64> reversed_transpose_dims(retval.second.size());
2211     for (int64 i = 0; i < retval.second.size(); ++i) {
2212       reversed_transpose_dims[i] = ReverseDimLookUp(retval.second, i);
2213     }
2214     instr_to_dim_permute_map_[retval.first] = reversed_transpose_dims;
2215 
2216     VLOG(3) << "New Activations " << retval.first->ToString();
2217 
2218     activations_locally_space_to_batched = true;
2219   } else if (!old_to_new_instrs_.contains(kernel_old)) {
2220     activations_new = old_to_new_instrs_[activations_old];
2221     permute_dims = instr_to_dim_permute_map_[activations_new];
2222 
2223     VLOG(1) << "Space-to-batching kernel to enable space-to-depth";
2224     const int64 prev_feature_dim =
2225         original_conv_dims.kernel_input_feature_dimension();
2226     const int64 prev_output_feature_dim =
2227         original_conv_dims.kernel_output_feature_dimension();
2228     // TODO(b/168316428): The instr_to_dim_map_ is set incorrectly here, but it
2229     // doesn't matter since it is never used. Investigate further to see if just
2230     // not setting it works.
2231     instr_to_dim_map_[kernel_old] =
2232         std::make_pair(prev_feature_dim, prev_output_feature_dim);
2233 
2234     const int64 new_space_dim = DimLookUp(permute_dims, old_space_dim);
2235     const int64 new_split_dim_size =
2236         activations_new->shape().dimensions(new_space_dim);
2237     const int64 needed_spatial_size =
2238         CeilOfRatio(new_split_dim_size, rhs_dilation);
2239     int64 old_kernel_split_dim_size =
2240         kernel_old->shape().dimensions(kernel_space_dim);
2241     const int64 pad_size =
2242         needed_spatial_size * kNumSplits - old_kernel_split_dim_size;
2243 
2244     ConvolutionDimensionNumbers tmp_dim_numbers;
2245     tmp_dim_numbers = original_conv_dims;
2246     TF_ASSIGN_OR_RETURN(
2247         auto retval, SplitSpace(kernel_old, tmp_dim_numbers, kernel_space_dim,
2248                                 kernel_old_batch_dim,
2249                                 /*high_padding=*/pad_size, /*low_padding=*/0,
2250                                 needed_spatial_size, kNumSplits,
2251                                 /*is_backprop=*/true, /*is_rhs=*/true));
2252 
2253     old_to_new_instrs_[kernel_old] = retval.first;
2254 
2255     std::vector<int64> reversed_transpose_dims(retval.second.size());
2256     for (int64 i = 0; i < retval.second.size(); ++i) {
2257       reversed_transpose_dims[i] = ReverseDimLookUp(retval.second, i);
2258     }
2259     instr_to_dim_permute_map_[retval.first] = reversed_transpose_dims;
2260 
2261     VLOG(3) << "New kernel " << retval.first->ToString();
2262 
2263     kernel_locally_space_to_batched = true;
2264   }
2265 
2266   CHECK(old_to_new_instrs_.contains(activations_old));
2267   CHECK(old_to_new_instrs_.contains(kernel_old));
2268   activations_new = old_to_new_instrs_[activations_old];
2269   kernel_new = old_to_new_instrs_[kernel_old];
2270   const int64 new_spatial_dimension =
2271       activations_new->shape().dimensions_size();
2272 
2273   permute_dims = instr_to_dim_permute_map_[activations_new];
2274   permute_dims_kernel = instr_to_dim_permute_map_[kernel_new];
2275 
2276   auto permuted_conv_dims_numbers = original_conv_dims;
2277 
2278   // Note the inversion here : batch and feature are inverted in backprop
2279   // filters.
2280   int64 activations_batch_dim =
2281       DimLookUp(permute_dims, original_conv_dims.input_feature_dimension());
2282   int64 activations_feature_dim =
2283       DimLookUp(permute_dims, original_conv_dims.input_batch_dimension());
2284 
2285   const int64 previous_spatial_dim_count =
2286       original_conv_dims.input_spatial_dimensions_size();
2287   for (int64 i = 0; i < previous_spatial_dim_count; ++i) {
2288     permuted_conv_dims_numbers.set_input_spatial_dimensions(
2289         i, DimLookUp(permute_dims,
2290                      original_conv_dims.input_spatial_dimensions(i)));
2291     permuted_conv_dims_numbers.set_kernel_spatial_dimensions(
2292         i, DimLookUp(permute_dims_kernel,
2293                      original_conv_dims.kernel_spatial_dimensions(i)));
2294   }
2295 
2296   permuted_conv_dims_numbers.add_input_spatial_dimensions(
2297       new_spatial_dimension);
2298   permuted_conv_dims_numbers.add_kernel_spatial_dimensions(
2299       new_spatial_dimension);
2300   permuted_conv_dims_numbers.add_output_spatial_dimensions(
2301       new_spatial_dimension);
2302 
2303   // For the output, make the last dimension size 1.
2304   const int64 previous_chosen_spatial_dim_in_output =
2305       permuted_conv_dims_numbers.output_spatial_dimensions(
2306           get_chosen_spatial_dim(convolution));
2307   permuted_conv_dims_numbers.set_output_spatial_dimensions(
2308       get_chosen_spatial_dim(convolution), new_spatial_dimension);
2309   permuted_conv_dims_numbers.set_output_spatial_dimensions(
2310       previous_spatial_dim_count, previous_chosen_spatial_dim_in_output);
2311 
2312   const int64 kernel_input_feature_dim = DimLookUp(
2313       permute_dims_kernel, original_conv_dims.kernel_input_feature_dimension());
2314 
2315   const int64 kernel_output_feature_dim =
2316       DimLookUp(permute_dims_kernel,
2317                 original_conv_dims.kernel_output_feature_dimension());
2318 
2319   permuted_conv_dims_numbers.set_kernel_input_feature_dimension(
2320       kernel_input_feature_dim);
2321   permuted_conv_dims_numbers.set_kernel_output_feature_dimension(
2322       kernel_output_feature_dim);
2323 
2324   int64 spatial_dimension_to_split =
2325       permuted_conv_dims_numbers.input_spatial_dimensions(
2326           get_chosen_spatial_dim(convolution));
2327 
2328   const int64 kernel_spatial_dimension_to_split =
2329       permuted_conv_dims_numbers.kernel_spatial_dimensions(
2330           get_chosen_spatial_dim(convolution));
2331 
2332   int64 new_split_dim_size =
2333       activations_new->shape().dimensions(spatial_dimension_to_split);
2334 
2335   const int64 kernel_new_split_dim_size =
2336       kernel_new->shape().dimensions(kernel_spatial_dimension_to_split);
2337 
  permuted_conv_dims_numbers.set_input_batch_dimension(activations_feature_dim);
  permuted_conv_dims_numbers.set_input_feature_dimension(activations_batch_dim);

  VLOG(1) << "Propagating on conv activations_batch_dim "
          << activations_batch_dim << " spatial_dimension_to_split "
          << spatial_dimension_to_split << " old_batch_size " << old_batch_size
          << " new_split_dim_size " << new_split_dim_size;

  TF_ASSIGN_OR_RETURN(
      auto retval,
      BringSpaceNextToBatch(activations_new, permuted_conv_dims_numbers,
                            spatial_dimension_to_split, activations_batch_dim,
                            /*is_backprop=*/true));

  std::vector<int64> transpose_dims = retval.transpose_dims;
  CHECK(!transpose_dims.empty());
  activations_new = retval.instr;

  VLOG(1) << "Activations_new post BringSpaceNextToBatch "
          << activations_new->ToString();
  VLOG(1) << "activations_batch_dim " << activations_batch_dim
          << " activations_feature_dim " << activations_feature_dim;
  const int64 expected_split_dim_size =
      rhs_dilation * kernel_new_split_dim_size;
  if (new_split_dim_size != expected_split_dim_size) {
    CHECK_LT(new_split_dim_size, expected_split_dim_size);
    new_split_dim_size = expected_split_dim_size;
    TF_ASSIGN_OR_RETURN(
        activations_new,
        IncreaseSpatialSizeOnSpaceToBatchedShape(
            activations_new, activations_batch_dim, old_batch_size,
            spatial_dimension_to_split, new_split_dim_size));
  }

  auto select_val = computation_->AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::Zero(activations_new->shape().element_type())));

  if (!activations_locally_space_to_batched) {
    // Select activations correctly by masking additional space.
    TF_ASSIGN_OR_RETURN(
        activations_new,
        SelectValidPortion(activations_new, activations_old, select_val,
                           activations_batch_dim, spatial_dimension_to_split,
                           old_batch_dim, old_space_dim));
  }
  if (!kernel_locally_space_to_batched) {
    VLOG(3) << "Selecting the valid kernel area";
    // Select kernel correctly by masking additional space.
    TF_ASSIGN_OR_RETURN(
        kernel_new,
        SelectValidPortion(kernel_new, kernel_old, select_val,
                           /*new_batch_dim=*/kernel_input_feature_dim,
                           kernel_spatial_dimension_to_split,
                           /*old_batch_dim=*/
                           original_conv_dims.kernel_input_feature_dimension(),
                           kernel_space_dim));
  }

  // Create the new convolution dim numbers.
  auto new_dim_numbers = permuted_conv_dims_numbers;

  VLOG(2) << "New dim numbers " << new_dim_numbers.DebugString();

  const int64 inherent_low_padding =
      convolution->window()
          .dimensions(get_chosen_spatial_dim(convolution))
          .padding_low();

  const int64 inherent_high_padding =
      convolution->window()
          .dimensions(get_chosen_spatial_dim(convolution))
          .padding_high();

  std::vector<HloInstruction*> activations_chunks;

  // Insert slices for low padding.
  for (int64 i = 0; i < inherent_low_padding; ++i) {
    HloInstruction* activations_to_use = nullptr;
    if (i == 0) {
      activations_to_use = activations_new;
    } else {
      activations_to_use = activations_chunks.back();
    }
    TF_ASSIGN_OR_RETURN(
        HloInstruction * activations_slice,
        HaloDuplicateWithSlice(activations_to_use, spatial_dimension_to_split,
                               activations_batch_dim, old_batch_size,
                               /*low_padding=*/1,
                               /*high_padding=*/0,
                               /*halo_size=*/0, old_split_dim_size));
    activations_chunks.push_back(activations_slice);
  }
  // Reverse the low padding slices because we created them in the opposite
  // order above.
  absl::c_reverse(activations_chunks);

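  // Effective extent of the kernel along the split dimension once rhs
  // dilation is applied: k taps spaced rhs_dilation apart span
  // k * rhs_dilation - (rhs_dilation - 1) elements.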
  const int64 expanded_kernel =
      old_kernel_split_dim_size * rhs_dilation - (rhs_dilation - 1);
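  // Number of positions at which the dilated kernel fits entirely inside
  // the split spatial dimension; negative inherent padding reduces the
  // usable range.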
  const int64 overlap_count =
      old_split_dim_size - expanded_kernel + 1 +
      (inherent_low_padding < 0 ? inherent_low_padding : 0) +
      (inherent_high_padding < 0 ? inherent_high_padding : 0);
  VLOG(1) << "overlap_count " << overlap_count << " inherent_low_padding "
          << inherent_low_padding << " inherent_high_padding "
          << inherent_high_padding;

  // Insert original activations.
  for (int64 i = 0; i < overlap_count; ++i) {
    HloInstruction* activations_to_use = nullptr;
    HloInstruction* activations_slice = nullptr;
    if (i == 0) {
      activations_to_use = activations_new;
      if (inherent_low_padding < 0) {
        TF_ASSIGN_OR_RETURN(activations_slice,
                            HaloDuplicateWithSlice(
                                activations_to_use, spatial_dimension_to_split,
                                activations_batch_dim, old_batch_size,
                                /*low_padding=*/inherent_low_padding,
                                /*high_padding=*/0,
                                /*halo_size=*/0, old_split_dim_size));
      } else {
        activations_slice = activations_to_use;
      }
    } else {
      activations_to_use = activations_chunks.back();

      TF_ASSIGN_OR_RETURN(
          activations_slice,
          HaloDuplicateWithSlice(activations_to_use, spatial_dimension_to_split,
                                 activations_batch_dim, old_batch_size,
                                 /*low_padding=*/-1,
                                 /*high_padding=*/0,
                                 /*halo_size=*/0, old_split_dim_size));
    }

    activations_chunks.push_back(activations_slice);
  }

  // Insert slices for high padding.
  for (int64 i = 0; i < inherent_high_padding; ++i) {
    HloInstruction* activations_to_use = nullptr;
    activations_to_use = activations_chunks.back();

    TF_ASSIGN_OR_RETURN(
        HloInstruction * activations_slice,
        HaloDuplicateWithSlice(activations_to_use, spatial_dimension_to_split,
                               activations_batch_dim, old_batch_size,
                               /*low_padding=*/-1, /*high_padding=*/0,
                               /*halo_size=*/0, old_split_dim_size));
    activations_chunks.push_back(activations_slice);
  }

  for (int64 i = 0; i < activations_chunks.size(); ++i) {
    std::vector<int64> input_sizes(
        activations_chunks[i]->shape().dimensions().begin(),
        activations_chunks[i]->shape().dimensions().end());
    // Insert a 1-sized dimension at the end.
    input_sizes.push_back(1);
    TF_ASSIGN_OR_RETURN(activations_chunks[i],
                        MakeReshapeHlo(input_sizes, activations_chunks[i]));
  }

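  // Concatenate all chunks along new_spatial_dimension; each chunk is one
  // shifted copy of the activations.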
  TF_ASSIGN_OR_RETURN(
      activations_new,
      MakeConcatHlo(absl::MakeSpan(activations_chunks), new_spatial_dimension));

  // Reshape the kernel with an additional spatial dim.
  std::vector<int64> kernel_sizes(kernel_new->shape().dimensions().begin(),
                                  kernel_new->shape().dimensions().end());
  // Insert a 1-sized dimension at the end.
  kernel_sizes.push_back(1);
  TF_ASSIGN_OR_RETURN(kernel_new, MakeReshapeHlo(kernel_sizes, kernel_new));

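  // Adjust the window on the split dimension: the negative high padding
  // trims the excess introduced by rhs dilation, and the window size
  // becomes the number of dilated taps,
  // CeilOfRatio(new_split_dim_size, rhs_dilation).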
  auto new_window = convolution->window();
  new_window.mutable_dimensions(get_chosen_spatial_dim(convolution))
      ->set_padding_high(-(rhs_dilation - 1));
  new_window.mutable_dimensions(get_chosen_spatial_dim(convolution))
      ->set_padding_low(0);
  new_window.mutable_dimensions(get_chosen_spatial_dim(convolution))
      ->set_size(CeilOfRatio(new_split_dim_size, rhs_dilation));

  // Set the window for the additional spatial dim. This is a vanilla window.
  auto window_dim = new_window.add_dimensions();
  window_dim->set_base_dilation(1);
  window_dim->set_size(1);
  window_dim->set_stride(1);
  window_dim->set_padding_low(0);
  window_dim->set_padding_high(0);
  window_dim->set_window_reversal(false);
  window_dim->set_window_dilation(1);

  TF_ASSIGN_OR_RETURN(
      HloInstruction * new_conv,
      MakeConvolveHlo(
          activations_new, kernel_new, convolution->feature_group_count(),
          convolution->batch_group_count(), new_window, new_dim_numbers,
          convolution->precision_config(),
          /*preferred_element_type=*/convolution->shape().element_type()));
  convolution->SetupDerivedInstruction(new_conv);

  VLOG(2) << "New backprop filter convolution " << new_conv->ToString();

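  // The chosen output spatial dimension now has size 1, so reshape it away.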
  std::vector<int64> output_sizes(new_conv->shape().dimensions().begin(),
                                  new_conv->shape().dimensions().end());

  output_sizes.erase(output_sizes.begin() +
                     new_dim_numbers.output_spatial_dimensions(
                         get_chosen_spatial_dim(convolution)));

  TF_ASSIGN_OR_RETURN(new_conv, MakeReshapeHlo(output_sizes, new_conv));

  old_to_new_instrs_[convolution] = new_conv;
  VLOG(1) << "Space-to-featured convolution " << new_conv->ToString();

  instr_to_dim_map_[convolution] =
      std::make_pair(original_conv_dims.output_batch_dimension(),
                     original_conv_dims.output_spatial_dimensions(
                         get_chosen_spatial_dim(convolution)));

  std::vector<int64> trans_dims(convolution->shape().dimensions_size());
  absl::c_iota(trans_dims, 0);
  instr_to_dim_permute_map_[new_conv] = trans_dims;

  return Status::OK();
}

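// Returns a reduce-window or select-and-scatter user found by a
// depth-limited search through the users of `instr`, or nullptr if none is
// found. Convolution, pad and transpose users prune the search.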
HloInstruction*
ConvolutionVisitor::DoesConvolutionFeedReduceWindowOrSelectAndScatter(
    HloInstruction* instr, int64 depth = kReduceWindowSearchDepth) {
  if (depth == 0) {
    return nullptr;
  }

  for (auto user : instr->users()) {
    if (user->opcode() == HloOpcode::kReduceWindow ||
        user->opcode() == HloOpcode::kSelectAndScatter) {
      return user;
    }
    // Stop the search if these ops are encountered.
    if (user->opcode() == HloOpcode::kConvolution ||
        user->opcode() == HloOpcode::kPad ||
        user->opcode() == HloOpcode::kTranspose) {
      continue;
    }
    auto ret =
        DoesConvolutionFeedReduceWindowOrSelectAndScatter(user, depth - 1);
    if (ret != nullptr) {
      return ret;
    }
  }
  return nullptr;
}

ConvolutionVisitor::ConvDetails ConvolutionVisitor::GetConvolutionDetails(
    HloInstruction* convolution, ConvolutionDimensionNumbers& dim_numbers) {
  auto activations = convolution->mutable_operand(0);

  auto kernel = convolution->mutable_operand(1);
  const auto& kernel_shape = kernel->shape();
  const int64 kernel_spatial_dim_size =
      kernel_shape.dimensions(dim_numbers.kernel_spatial_dimensions(
          get_chosen_spatial_dim(convolution)));

  const int64 spatial_dimension_to_split =
      dim_numbers.input_spatial_dimensions(get_chosen_spatial_dim(convolution));

  const int64 input_dim_size =
      activations->shape().dimensions(spatial_dimension_to_split);

  const int64 inherent_low_padding =
      convolution->window()
          .dimensions(get_chosen_spatial_dim(convolution))
          .padding_low();
  const int64 inherent_high_padding =
      convolution->window()
          .dimensions(get_chosen_spatial_dim(convolution))
          .padding_high();

  const int64 stride = convolution->window()
                           .dimensions(get_chosen_spatial_dim(convolution))
                           .stride();

  const int64 base_dilation_factor =
      convolution->window()
          .dimensions(get_chosen_spatial_dim(convolution))
          .base_dilation();

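  // Total extent of the chosen spatial dimension once inherent padding is
  // folded in; with base dilation > 1 the low padding is instead handled
  // via low_padding_for_conv below.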
  const int64 spatial_size =
      input_dim_size + (base_dilation_factor > 1 ? 0 : inherent_low_padding) +
      inherent_high_padding;

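  // Extra elements each split needs from its neighbor so that windows
  // straddling a split boundary still see all of their inputs.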
  const int64 halo_size =
      std::max(kernel_spatial_dim_size - stride - (base_dilation_factor - 1),
               static_cast<int64>(0));
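  // When the convolution is base-dilated, inherent low padding cannot be
  // folded into the split sizing, so it is carried on the convolution
  // itself.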
  const int64 high_padding_for_conv = base_dilation_factor == 1 ? 0
                                      : inherent_low_padding == 0
                                          ? base_dilation_factor - 1
                                          : 0;
  const int64 low_padding_for_conv =
      base_dilation_factor == 1 ? 0 : inherent_low_padding;

  return ConvDetails{spatial_dimension_to_split,
                     inherent_low_padding,
                     inherent_high_padding,
                     stride,
                     spatial_size,
                     base_dilation_factor,
                     halo_size,
                     high_padding_for_conv,
                     low_padding_for_conv,
                     kernel_spatial_dim_size,
                     input_dim_size};
}

Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution(
    HloInstruction* convolution) {
  VLOG(1) << "Handling conv " << convolution->ToString();

  changed_ = false;

  ConvolutionDimensionNumbers dim_numbers =
      convolution->convolution_dimension_numbers();

  ConvDetails c = GetConvolutionDetails(convolution, dim_numbers);

  int64 activations_batch_dim = dim_numbers.input_batch_dimension();

  const int64 old_batch_size =
      convolution->operand(0)->shape().dimensions(activations_batch_dim);

  auto activations = convolution->mutable_operand(0);

  VLOG(1) << "spatial size " << c.spatial_size;

  auto original_conv = convolution;

  const int64 output_spatial_dim = dim_numbers.output_spatial_dimensions(
      get_chosen_spatial_dim(convolution));
  const int64 output_offsets =
      convolution->shape().dimensions(output_spatial_dim);
  const int64 output_offsets_per_split =
      CeilOfRatio(output_offsets, kNumSplits);

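  // Pick a per-split spatial size big enough to generate
  // output_offsets_per_split outputs at the given stride and base dilation.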
  int64 spatial_split_size =
      CeilOfRatio(output_offsets_per_split, c.base_dilation_factor) * c.stride;
  // Keep increasing the split size so that overall size isn't smaller than the
  // original spatial dimension.
  while (spatial_split_size * kNumSplits - c.spatial_size < 0) {
    spatial_split_size += c.stride;
  }

  auto reduce_window_or_select_and_scatter =
      DoesConvolutionFeedReduceWindowOrSelectAndScatter(convolution);

  if (reduce_window_or_select_and_scatter != nullptr) {
    VLOG(2)
        << "DoesConvolutionFeedReduceWindowOrSelectAndScatter found a user";
    // Take into account the stride of the reduce window while choosing the
    // spatial_split_size. This will guarantee propagation through reduce
    // windows.
    const int64 win_stride = reduce_window_or_select_and_scatter->window()
                                 .dimensions(output_spatial_dim)
                                 .stride();
    while ((spatial_split_size / c.stride) % win_stride != 0) {
      spatial_split_size += c.stride;
    }
  }

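  // Each split slice carries halo_size extra elements on top of
  // spatial_split_size so that windows crossing split boundaries have all
  // the inputs they need.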
  const int64 slice_size = spatial_split_size + c.halo_size;

  // Pad the spatial dimension so that it splits evenly into kNumSplits
  // pieces of size spatial_split_size.
  const int64 pad_size = spatial_split_size * kNumSplits - c.spatial_size;

  VLOG(1) << "spatial_split_size " << spatial_split_size << " stride "
          << c.stride << " slice_size " << slice_size;
  VLOG(1) << "spatial_dimension_to_split " << c.spatial_dimension_to_split
          << " num_splits " << kNumSplits << " kernel_spatial_dim_size "
          << c.kernel_spatial_dim_size;
  int64 spatial_dimension_to_split = c.spatial_dimension_to_split;
  TF_ASSIGN_OR_RETURN(
      auto retval,
      SplitSpace(activations, dim_numbers, spatial_dimension_to_split,
                 activations_batch_dim,
                 /*high_padding=*/c.inherent_high_padding + pad_size,
                 /*low_padding=*/c.base_dilation_factor == 1
                     ? c.inherent_low_padding
                     : 0,
                 spatial_split_size, kNumSplits));
  HloInstruction* batch_increased_reshape = retval.first;
  convolution->SetupDerivedInstruction(batch_increased_reshape);

  VLOG(1) << "First reshape done " << batch_increased_reshape->ToString();

  TF_ASSIGN_OR_RETURN(
      activations, HaloDuplicateWithSlice(batch_increased_reshape,
                                          spatial_dimension_to_split,
                                          activations_batch_dim, old_batch_size,
                                          /*low_padding=*/0, /*high_padding=*/0,
                                          c.halo_size, c.input_dim_size));

  VLOG(1) << "Batch merge done " << activations->ToString();

  // Now, we rewrite the convolution with a larger batch.

  // Create the new convolution dim numbers.
  auto new_dim_numbers = dim_numbers;

  // We will generate output such that batch is followed by the split spatial
  // dimension.
  const int64 rank = convolution->shape().rank();
  std::vector<int64> transpose_dims(rank);
  int dim_count = 0;
  std::map<int64, int64> dim_map;

  for (int j = 0; j < dim_numbers.output_spatial_dimensions_size(); ++j) {
    if (j == get_chosen_spatial_dim(convolution)) {
      dim_map[dim_numbers.output_batch_dimension()] = dim_count;
      new_dim_numbers.set_output_batch_dimension(dim_count++);
    }
    dim_map[dim_numbers.output_spatial_dimensions(j)] = dim_count;
    new_dim_numbers.set_output_spatial_dimensions(j, dim_count);
    dim_count++;
  }

  dim_map[dim_numbers.output_feature_dimension()] = dim_count;
  new_dim_numbers.set_output_feature_dimension(dim_count);

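  // dim_map is keyed by the old dimension index, so an in-order walk yields
  // each old dimension's position in the new layout.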
  int p = 0;
  for (const auto& entry : dim_map) {
    transpose_dims[p] = entry.second;
    p++;
  }
  VLOG(1) << "New dim numbers " << new_dim_numbers.DebugString()
          << " batch dim " << new_dim_numbers.input_batch_dimension();
  auto new_window = convolution->window();
  new_window.mutable_dimensions(get_chosen_spatial_dim(convolution))
      ->set_padding_high(c.high_padding_for_conv);
  new_window.mutable_dimensions(get_chosen_spatial_dim(convolution))
      ->set_padding_low(c.low_padding_for_conv);
  TF_ASSIGN_OR_RETURN(
      HloInstruction * new_conv,
      MakeConvolveHlo(
          activations, /*rhs=*/convolution->mutable_operand(1),
          convolution->feature_group_count(), convolution->batch_group_count(),
          new_window, new_dim_numbers, convolution->precision_config(),
          /*preferred_element_type=*/convolution->shape().element_type()));
  convolution->SetupDerivedInstruction(new_conv);

  // If the activations were to be batch-to-spaced again, simply use the
  // original value.
  batch_to_space_map_[convolution->mutable_operand(0)] =
      convolution->mutable_operand(0);

  VLOG(1) << "Space-to-batched convolution " << new_conv->ToString();

  const int64 output_split_spatial_dim =
      new_dim_numbers.output_spatial_dimensions(
          get_chosen_spatial_dim(convolution));
  const int64 output_batch_dim = new_dim_numbers.output_batch_dimension();
  VLOG(1) << "output_batch_dim " << output_batch_dim
          << " output_split_spatial_dim " << output_split_spatial_dim;

  auto select_val = computation_->AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::Zero(new_conv->shape().element_type())));

  TF_ASSIGN_OR_RETURN(
      new_conv, SelectValidPortion(new_conv, original_conv, select_val,
                                   output_batch_dim, output_split_spatial_dim,
                                   dim_numbers.output_batch_dimension(),
                                   dim_numbers.output_spatial_dimensions(
                                       get_chosen_spatial_dim(original_conv))));
  old_to_new_instrs_[original_conv] = new_conv;

  instr_to_dim_map_[original_conv] =
      std::make_pair(dim_numbers.output_batch_dimension(),
                     dim_numbers.output_spatial_dimensions(
                         get_chosen_spatial_dim(original_conv)));

  instr_to_dim_permute_map_[new_conv] = std::vector<int64>(transpose_dims);
  if (non_propagatable_instrs_.count(convolution) > 0) {
    non_propagatable_instrs_.erase(convolution);
  }
  TF_CHECK_OK(PropagateOnUsers(original_conv));

  changed_ = true;

  return Status::OK();
}

}  // namespace

StatusOr<bool> ConvolutionSpaceToBatchConverter::Run(HloModule* module) {
  XLA_VLOG_LINES(2, "ConvolutionSpaceToBatchConverter::Run(), before:\n" +
                        module->ToString());
  bool changed = false;

  for (auto* comp : module->MakeNonfusionComputations()) {
    ConvolutionVisitor visitor(limit_on_batch_size_, comp);
    if (visitor.Run().ValueOrDie()) {
      changed = true;
    }
    VLOG(1) << "Done operating on computation";
  }
  XLA_VLOG_LINES(2, "ConvolutionSpaceToBatchConverter::Run(), after:\n" +
                        module->ToString());
  return changed;
}

}  // namespace xla