/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/while_util.h"
#include "tensorflow/compiler/xla/window_util.h"

namespace xla {

class DynamicDimensionInferenceVisitor : public DfsHloVisitorWithDefault {
 public:
  explicit DynamicDimensionInferenceVisitor(
      const DynamicParameterBinding& param_bindings,
      DynamicDimensionInference* parent)
      : param_bindings_(param_bindings), parent_(parent) {}

  Status DefaultAction(HloInstruction* hlo) override;

  static Status Run(HloComputation* computation,
                    const DynamicParameterBinding& param_bindings,
                    DynamicDimensionInference* parent) {
    DynamicDimensionInferenceVisitor visitor(param_bindings, parent);
    return computation->Accept(&visitor);
  }

  Status HandleParameter(HloInstruction* hlo) override;

  Status HandleReduce(HloInstruction* hlo) override;

  Status HandleDot(HloInstruction* hlo) override;

  Status HandleTuple(HloInstruction* hlo) override;

  Status HandleTranspose(HloInstruction* hlo) override;

  Status HandleReshape(HloInstruction* hlo) override;

  Status HandlePad(HloInstruction* hlo) override;

  Status HandleBroadcast(HloInstruction* hlo) override;

  Status HandleGetDimensionSize(HloInstruction* hlo) override;

  Status HandleSelect(HloInstruction* hlo) override;

  Status HandleConvolution(HloInstruction* hlo) override;

  Status HandleReduceWindow(HloInstruction* hlo) override;

  Status HandleSelectAndScatter(HloInstruction* hlo) override;

  Status HandleGetTupleElement(HloInstruction* hlo) override;

  Status HandleElementwiseUnary(HloInstruction* hlo) override;

  Status HandleElementwiseBinary(HloInstruction* hlo) override;

  Status HandleWhile(HloInstruction* hlo) override;

  Status HandleSlice(HloInstruction* hlo) override;

 private:
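  // Callback invoked for each dynamic dimension found on an operand:
  // `index` and `dimension` locate the dynamic dimension within `operand`'s
  // shape, `operand_index` is the operand's position in the instruction's
  // operand list, and `dynamic_size` is the HLO carrying the runtime size.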
  using OperandDynamicDimensionFn = std::function<Status(
      HloInstruction* operand, ShapeIndex index, int64 dimension,
      int64 operand_index, HloInstruction* dynamic_size)>;

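  // Calls the given function once for every dynamic dimension recorded on
  // any operand of `inst`.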
  Status ForEachOperandDynamicDimension(HloInstruction* inst,
                                        const OperandDynamicDimensionFn&);

  // Pass through a dynamic dimension from the input to the output with the
  // same value and index in the shape. This is a helper function to handle
  // trivial instructions like elementwise operations.
  Status PassThroughDynamicDimension(HloInstruction*);

  // The dynamic parameter bindings of this computation.
  const DynamicParameterBinding& param_bindings_;

  // A pointer to DynamicDimensionInference, used to update the dynamic
  // mapping.
  DynamicDimensionInference* parent_;
};

Status DynamicDimensionInferenceVisitor::DefaultAction(HloInstruction* hlo) {
  return ForEachOperandDynamicDimension(
      hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
               int64 operand_index, HloInstruction* dynamic_size) {
        return UnimplementedStrCat(
            "Asked to propagate a dynamic dimension from hlo ",
            operand->ToString(), "@", index.ToString(), "@", dimension,
            " to hlo ", hlo->ToString(), ", which is not implemented.");
      });
}

Status DynamicDimensionInferenceVisitor::HandleGetTupleElement(
    HloInstruction* hlo) {
  return ForEachOperandDynamicDimension(
      hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
               int64 operand_index, HloInstruction* dynamic_size) {
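        // The GTE extracts tuple element `tuple_index`, so only dynamic
        // dimensions whose shape index starts with that element number flow
        // through, with the leading index entry stripped off.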
        if (hlo->tuple_index() == index[0]) {
          ShapeIndex new_index =
              ShapeIndexView(index).ConsumeFront().ToShapeIndex();
          parent_->SetDynamicSize(hlo, new_index, dimension, dynamic_size);
        }
        return Status::OK();
      });
}

Status DynamicDimensionInferenceVisitor::HandleTuple(HloInstruction* hlo) {
  return ForEachOperandDynamicDimension(
      hlo, [&](HloInstruction*, ShapeIndex index, int64 dimension,
               int64 operand_index, HloInstruction* dynamic_size) {
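        // The operand becomes element `operand_index` of the tuple, so
        // prepend that element number to the dynamic dimension's shape
        // index.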
        index.push_front(operand_index);
        parent_->SetDynamicSize(hlo, index, dimension, dynamic_size);
        return Status::OK();
      });
}

Status DynamicDimensionInferenceVisitor::HandleBroadcast(HloInstruction* hlo) {
  return ForEachOperandDynamicDimension(
      hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
               int64 operand_index, HloInstruction* dynamic_size) {
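        // For a broadcast, dimensions(i) is the output dimension that
        // operand dimension i maps to.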
        int64 broadcast_dim = hlo->dimensions(dimension);
        parent_->SetDynamicSize(hlo, {}, broadcast_dim, dynamic_size);
        return Status::OK();
      });
}

Status DynamicDimensionInferenceVisitor::HandlePad(HloInstruction* hlo) {
  return ForEachOperandDynamicDimension(
      hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
               int64 operand_index, HloInstruction* dynamic_size) {
        if (operand_index != 0) {
          return Unimplemented(
              "Dynamic dimension on padding value is not supported");
        }
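        // Padding changes a dimension's size only if its config is
        // non-trivial; with all-zero padding the dynamic size passes through
        // unchanged.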
        const PaddingConfig_PaddingConfigDimension& padding_config =
            hlo->padding_config().dimensions(dimension);
        if (padding_config.interior_padding() == 0 &&
            padding_config.edge_padding_low() == 0 &&
            padding_config.edge_padding_high() == 0) {
          parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size);
          return Status::OK();
        } else {
          return Unimplemented(
              "Dynamic dimension propagation on padding dimension is not "
              "supported.");
        }
      });
}

Status DynamicDimensionInferenceVisitor::HandleReduce(HloInstruction* hlo) {
  return ForEachOperandDynamicDimension(
      hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
               int64 operand_index, HloInstruction* dynamic_size) {
        HloInstruction* reduce = hlo;
        int64 operand_count = reduce->operand_count();
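        // A (variadic) reduce takes N input arrays followed by N init
        // values, so the operand count is always even.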
        CHECK_EQ(operand_count % 2, 0);
        if (operand_index >= operand_count / 2) {
          // Init values don't have dynamic sizes.
          return Status::OK();
        }
        if ((absl::c_count(reduce->dimensions(), dimension) != 0)) {
          // The dimension is reduced away; stop tracing.
          return Status::OK();
        }

        // Find the dynamic dimension's new position after the reduce.
        int64 dimensions_not_reduced_count = 0;
        for (int i = 0; i < operand->shape().rank(); ++i) {
          if (dimension == i) {
            parent_->SetDynamicSize(reduce, {}, dimensions_not_reduced_count,
                                    dynamic_size);

            return Status::OK();
          }
          if (absl::c_count(reduce->dimensions(), i) == 0) {
            dimensions_not_reduced_count++;
          }
        }

        return Status::OK();
      });
}

Status DynamicDimensionInferenceVisitor::HandleDot(HloInstruction* hlo) {
  return ForEachOperandDynamicDimension(
      hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
               int64 operand_index, HloInstruction* dynamic_size) {
        HloInstruction* dot = hlo;
        const DotDimensionNumbers& dimension_numbers =
            dot->dot_dimension_numbers();
        // A map from operand dimensions to result dimensions.
        absl::flat_hash_map<int64, int64> result_dim_mapping;
        int64 current_result_dims = 0;
        std::unordered_set<int64> batch_dims(
            dimension_numbers.rhs_batch_dimensions().begin(),
            dimension_numbers.rhs_batch_dimensions().end());

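        // The result dimensions of a dot are ordered as: batch dimensions,
        // then the lhs non-contracting dimensions, then the rhs
        // non-contracting, non-batch dimensions. Walk the operands in that
        // order to build the mapping.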
        for (int64 i : dimension_numbers.rhs_batch_dimensions()) {
          result_dim_mapping[i] = current_result_dims++;
        }

        for (int64 i = 0; i < dot->operand(0)->shape().rank(); i++) {
          if (!absl::c_linear_search(
                  dimension_numbers.lhs_contracting_dimensions(), i)) {
            if (operand_index == 0) {
              result_dim_mapping[i] = current_result_dims;
            }
            current_result_dims++;
          }
        }

        for (int64 i = 0; i < dot->operand(1)->shape().rank(); i++) {
          if (!absl::c_linear_search(
                  dimension_numbers.rhs_contracting_dimensions(), i) &&
              !absl::c_linear_search(dimension_numbers.rhs_batch_dimensions(),
                                     i)) {
            if (operand_index == 1) {
              result_dim_mapping[i] = current_result_dims;
            }
            current_result_dims++;
          }
        }

        // Check if the operand dim is in the result shape. If so, add another
        // work item to trace that dimension.
        auto iter = result_dim_mapping.find(dimension);
        if (iter != result_dim_mapping.end()) {
          parent_->SetDynamicSize(dot, {}, iter->second, dynamic_size);
        }

        return Status::OK();
      });
}

Status DynamicDimensionInferenceVisitor::HandleTranspose(HloInstruction* hlo) {
  return ForEachOperandDynamicDimension(
      hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
               int64 operand_index, HloInstruction* dynamic_size) {
        // Output dimension i of a transpose is drawn from operand dimension
        // dimensions(i), so operand dimension `dimension` maps to the output
        // position whose permutation entry equals `dimension`.
        int64 permuted_dim = -1;
        for (int64 i = 0; i < hlo->shape().rank(); ++i) {
          if (hlo->dimensions(i) == dimension) {
            permuted_dim = i;
            break;
          }
        }
        CHECK_NE(permuted_dim, -1);
        parent_->SetDynamicSize(hlo, {}, permuted_dim, dynamic_size);
        return Status::OK();
      });
}

Status DynamicDimensionInferenceVisitor::HandleConvolution(
    HloInstruction* hlo) {
  return ForEachOperandDynamicDimension(
      hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
               int64 operand_index, HloInstruction* dynamic_size) {
        HloInstruction* conv = hlo;
        const ConvolutionDimensionNumbers& dimension_numbers =
            conv->convolution_dimension_numbers();

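        // Only the batch dimension of the activations propagates to the
        // output. Dynamic feature dimensions are contracted away by the
        // convolution, and dynamic spatial dimensions are unsupported.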
        if (operand_index == 0) {
          if (dimension == dimension_numbers.input_batch_dimension()) {
            parent_->SetDynamicSize(conv, {},
                                    dimension_numbers.output_batch_dimension(),
                                    dynamic_size);
            return Status::OK();
          }

          if (dimension == dimension_numbers.input_feature_dimension()) {
            return Status::OK();
          }
        } else {
          if (dimension == dimension_numbers.kernel_input_feature_dimension()) {
            return Status::OK();
          }
        }

        return Unimplemented("Dynamic Spatial Convolution is not supported: %s",
                             conv->ToString());
      });
}

Status DynamicDimensionInferenceVisitor::HandleGetDimensionSize(
    HloInstruction*) {
  // Dynamic dimension doesn't propagate through GetDimensionSize:
  //
  //   Input: F32[x, y, z]
  //     |
  //   GetDimensionSize(1): U32[]
  //
  // The returned value is a scalar, which doesn't have any dynamic dimension
  // in the shape (although the value contains the real size of the dynamic
  // dimension of the input).
  return Status::OK();
}

Status DynamicDimensionInferenceVisitor::PassThroughDynamicDimension(
    HloInstruction* hlo) {
  return ForEachOperandDynamicDimension(
      hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
               int64 operand_index, HloInstruction* dynamic_size) {
        parent_->SetDynamicSize(hlo, index, dimension, dynamic_size);
        return Status::OK();
      });
}

Status DynamicDimensionInferenceVisitor::HandleElementwiseUnary(
    HloInstruction* hlo) {
  return PassThroughDynamicDimension(hlo);
}

Status DynamicDimensionInferenceVisitor::HandleSelect(HloInstruction* hlo) {
  return PassThroughDynamicDimension(hlo);
}

Status DynamicDimensionInferenceVisitor::HandleElementwiseBinary(
    HloInstruction* hlo) {
  return PassThroughDynamicDimension(hlo);
}

Status DynamicDimensionInferenceVisitor::HandleReshape(HloInstruction* hlo) {
  return ForEachOperandDynamicDimension(
      hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
               int64 operand_index, HloInstruction* dynamic_size) {
        HloInstruction* reshape = hlo;
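        // DimensionsUnmodifiedByReshape returns (input_dim, output_dim)
        // pairs for dimensions the reshape carries through one-to-one; a
        // dynamic size can only follow such a pair.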
        std::vector<std::pair<int64, int64>> unmodified_dims =
            ShapeUtil::DimensionsUnmodifiedByReshape(operand->shape(),
                                                     reshape->shape());
        for (auto& unmodified : unmodified_dims) {
          if (unmodified.first == dimension) {
            parent_->SetDynamicSize(reshape, {}, unmodified.second,
                                    dynamic_size);
            return Status::OK();
          }
        }
        return Unimplemented(
            "Dynamic Reshape on modified dimensions is not yet supported: %s",
            reshape->ToString());
      });
}

Status DynamicDimensionInferenceVisitor::HandleReduceWindow(
    HloInstruction* hlo) {
  return ForEachOperandDynamicDimension(
      hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
               int64 operand_index, HloInstruction* dynamic_size) {
        HloInstruction* reduce_window = hlo;
        const WindowDimension& window_dimension =
            reduce_window->window().dimensions(dimension);

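        // A trivial window dimension (size 1, stride 1, no padding or
        // dilation) leaves the dimension's size unchanged, so the dynamic
        // size can be forwarded as-is.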
        if (!window_util::IsTrivialWindowDimension(window_dimension)) {
          return Unimplemented(
              "Dynamic Spatial reduce window is not supported: %s",
              reduce_window->ToString());
        }

        parent_->SetDynamicSize(reduce_window, {}, dimension, dynamic_size);

        return Status::OK();
      });
}

Status DynamicDimensionInferenceVisitor::HandleSelectAndScatter(
    HloInstruction* hlo) {
  return ForEachOperandDynamicDimension(
      hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
               int64 operand_index, HloInstruction* dynamic_size) {
        HloInstruction* select_and_scatter = hlo;
        const WindowDimension& window_dimension =
            select_and_scatter->window().dimensions(dimension);

        if (!window_util::IsTrivialWindowDimension(window_dimension)) {
          return Unimplemented(
              "Dynamic Spatial select and scatter is not supported: %s",
              select_and_scatter->ToString());
        }

        parent_->SetDynamicSize(select_and_scatter, {}, dimension,
                                dynamic_size);

        return Status::OK();
      });
}

Status DynamicDimensionInferenceVisitor::HandleSlice(HloInstruction* hlo) {
  return ForEachOperandDynamicDimension(
      hlo, [&](HloInstruction* operand, ShapeIndex /*index*/, int64 dimension,
               int64 /*operand_index*/, HloInstruction* dynamic_size) {
        if (hlo->slice_starts(dimension) != 0 ||
            hlo->slice_strides(dimension) != 1 ||
            hlo->slice_limits(dimension) !=
                operand->shape().dimensions(dimension)) {
          return Unimplemented(
              "Dynamic dimension propagation on Slice where it doesn't slice "
              "out an entire dimension is not supported: %s",
              hlo->ToString());
        }

        parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size);

        return Status::OK();
      });
}

Status DynamicDimensionInferenceVisitor::HandleWhile(HloInstruction* hlo) {
  // A while loop is handled by passing dynamic size hlos as parameters into
  // the while loop. This is done by replacing the original while with a new
  // one.
  //
  // Before:
  //
  // op1 = ...
  // op2 = ...
  // op1_x = ... // dynamic dimension size of op1
  // while = while(op1, op2)
  //
  //
  // After:
  //
  // op1 = ...
  // op2 = ...
  // op1_x = ... // dynamic dimension size of op1
  // while = while(op1, op2, op1_x)
  //
  // In the above graph, op1_x is the bound of the dynamic dimension size of
  // op1 and is wired into the while loop as a new parameter.
  //
  // TODO(b/119843103): Once we implement dynamic bounds in XLA backend,
  // dynamic bounds can be propagated through native xla values instead of
  // relying on additional parameters.

  // dynamic_size_to_operand_id_index_map maps each dynamic size instruction
  // to its operand id in the new while loop.
  absl::flat_hash_map<HloInstruction*, int64>
      dynamic_size_to_operand_id_index_map;

  // operands_to_add collects dynamic sizes that need to be added to the while
  // loop as parameters. Note that a dynamic size is ignored if it is already
  // among the while loop's parameters, i.e.:
  //
  // We don't do:
  //
  // op1 = ...
  // op2 = ...
  // op_x = ... // dynamic dimension size of both op1 and op2
  // while = while(op1, op2, op_x, op_x) // 4 parameters
  //
  // But we do:
  //
  // op1 = ...
  // op2 = ...
  // op_x = ... // dynamic dimension size of both op1 and op2
  // while = while(op1, op2, op_x)
  //
  // An alternative is to do this in a while loop CSE pass.
  //
  std::vector<HloInstruction*> operands_to_add;
  int64 operand_count = hlo->shape().tuple_shapes_size();
  TF_RETURN_IF_ERROR(ForEachOperandDynamicDimension(
      hlo, [&](HloInstruction*, ShapeIndex, int64, int64,
               HloInstruction* dynamic_size) {
        const HloInstruction* tuple_operand = hlo->operand(0);
        for (int64 i = 0; i < tuple_operand->operand_count(); ++i) {
          if (dynamic_size == tuple_operand->operand(i)) {
            dynamic_size_to_operand_id_index_map[dynamic_size] = i;
            return Status::OK();
          }
        }
        auto iter = dynamic_size_to_operand_id_index_map.find(dynamic_size);
        if (iter == dynamic_size_to_operand_id_index_map.end()) {
          operands_to_add.push_back(dynamic_size);
          dynamic_size_to_operand_id_index_map[dynamic_size] = operand_count++;
        }
        return Status::OK();
      }));

  if (!operands_to_add.empty()) {
    // Only replace the while loop if there are new parameters to add.
    HloInstruction* old_tuple_operand = hlo->mutable_operand(0);
    TF_ASSIGN_OR_RETURN(
        WhileUtil::MakeInstructionsLiveInResult result,
        WhileUtil::MakeInstructionsLiveIn(hlo, operands_to_add));
    // WhileUtil creates a new while hlo and tuple. Update the dynamic size
    // mapping for the newly created tuple.
    HloInstruction* new_tuple_operand =
        result.new_while_instr->mutable_operand(0);
    parent_->CopyMapping(/*from=*/old_tuple_operand, /*to=*/new_tuple_operand);
    hlo = result.new_while_instr;
  }

  // We have replaced the while loop, so now set the dynamic dimensions for
  // the newly created while loop so that the hlos that consume the while loop
  // can see the dynamic dimensions. Also set the dynamic parameter binding
  // for running inference in the while loop.
  DynamicParameterBinding binding_for_while;
  TF_RETURN_IF_ERROR(ForEachOperandDynamicDimension(
      hlo, [&](HloInstruction*, ShapeIndex index, int64 dimension,
               int64 operand_index, HloInstruction* dynamic_size) {
        DynamicParameterBinding::DynamicParameter dynamic_parameter{
            operand_index,
            {dynamic_size_to_operand_id_index_map[dynamic_size]}};
        DynamicParameterBinding::DynamicDimension dynamic_dimension{
            operand_index, index, dimension};
        TF_RETURN_IF_ERROR(
            binding_for_while.Bind(dynamic_parameter, dynamic_dimension));
        parent_->SetDynamicSize(hlo, index, dimension, dynamic_size);
        return Status::OK();
      }));

  // Run inference in while body and condition.
  TF_RETURN_IF_ERROR(DynamicDimensionInferenceVisitor::Run(
      hlo->while_body(), binding_for_while, parent_));
  TF_RETURN_IF_ERROR(DynamicDimensionInferenceVisitor::Run(
      hlo->while_condition(), binding_for_while, parent_));

  return Status::OK();
}

Status DynamicDimensionInferenceVisitor::HandleParameter(HloInstruction* hlo) {
  return param_bindings_.ForEachBinding(
      [&](const DynamicParameterBinding::DynamicParameter& dynamic_parameter,
          const DynamicParameterBinding::DynamicDimension& dynamic_dimension) {
        if (dynamic_dimension.parameter_num != hlo->parameter_number()) {
          return Status::OK();
        }
        HloComputation* computation = hlo->parent();
        HloInstruction* target_parameter =
            computation->parameter_instruction(dynamic_dimension.parameter_num);

        HloInstruction* dynamic_size =
            computation->parameter_instruction(dynamic_parameter.parameter_num);
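        // Walk down dynamic_parameter.parameter_index, emitting a
        // GetTupleElement at each step to extract the scalar that holds the
        // dynamic size.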
        for (int64 i : dynamic_parameter.parameter_index) {
          dynamic_size =
              computation->AddInstruction(HloInstruction::CreateGetTupleElement(
                  ShapeUtil::GetSubshape(dynamic_size->shape(), {i}),
                  dynamic_size, i));
        }

        parent_->SetDynamicSize(target_parameter,
                                dynamic_dimension.parameter_index,
                                dynamic_dimension.dimension, dynamic_size);
        return Status::OK();
      });
}

Status DynamicDimensionInferenceVisitor::ForEachOperandDynamicDimension(
    HloInstruction* inst, const OperandDynamicDimensionFn& fn) {
  for (int64 operand_index = 0; operand_index < inst->operand_count();
       ++operand_index) {
    auto iter =
        parent_->per_hlo_dynamic_dimensions_.find(inst->operand(operand_index));
    if (iter != parent_->per_hlo_dynamic_dimensions_.end()) {
      for (auto& dynamic_dimension : iter->second) {
        HloInstruction* dynamic_size = parent_->GetDynamicSize(
            dynamic_dimension.inst, dynamic_dimension.index,
            dynamic_dimension.dim);
        TF_RETURN_IF_ERROR(fn(dynamic_dimension.inst, dynamic_dimension.index,
                              dynamic_dimension.dim, operand_index,
                              dynamic_size));
      }
    }
  }
  return Status::OK();
}

void DynamicDimensionInference::CopyMapping(HloInstruction* from,
                                            HloInstruction* to) {
  auto iter = per_hlo_dynamic_dimensions_.find(from);
  if (iter != per_hlo_dynamic_dimensions_.end()) {
    for (auto& dynamic_dimension : iter->second) {
      HloInstruction* dynamic_size =
          GetDynamicSize(dynamic_dimension.inst, dynamic_dimension.index,
                         dynamic_dimension.dim);
      SetDynamicSize(to, dynamic_dimension.index, dynamic_dimension.dim,
                     dynamic_size);
    }
  }
}

/* static */
StatusOr<DynamicDimensionInference> DynamicDimensionInference::Run(
    HloModule* module) {
  VLOG(2) << "Param Config " << module->dynamic_parameter_binding().ToString();
  DynamicDimensionInference inference(module);
  TF_RETURN_IF_ERROR(inference.AnalyzeDynamicDimensions());
  return inference;
}

string DynamicDimensionInference::ToString() const {
  std::vector<string> pieces;
  pieces.push_back("DynamicDimensionInference: ");
  for (const auto& mapping : dynamic_mapping_) {
    const DynamicDimension& dynamic_dimension = mapping.first;
    pieces.push_back(absl::StrFormat(
        " -- instruction %s at %s has dim %lld as dynamic"
        " dimension, which is represented by instruction %s",
        dynamic_dimension.inst->ToString(), dynamic_dimension.index.ToString(),
        dynamic_dimension.dim, mapping.second->ToString()));
  }
  return absl::StrJoin(pieces, "\n");
}

DynamicDimensionInference::DynamicDimensionInference(HloModule* module)
    : module_(module) {}

Status DynamicDimensionInference::AnalyzeDynamicDimensions() {
  return DynamicDimensionInferenceVisitor::Run(
      module_->entry_computation(), module_->dynamic_parameter_binding(), this);
}

HloInstruction* DynamicDimensionInference::GetDynamicSize(
    HloInstruction* inst, const ShapeIndex& index, int64 dim) const {
  auto iter = dynamic_mapping_.find(DynamicDimension{inst, index, dim});
  if (iter != dynamic_mapping_.end()) {
    return iter->second;
  }
  return nullptr;
}

}  // namespace xla