1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h"
17 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
18 #include "tensorflow/compiler/xla/service/hlo_computation.h"
19 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
20 #include "tensorflow/compiler/xla/service/hlo_module.h"
21 #include "tensorflow/compiler/xla/service/while_util.h"
22 #include "tensorflow/compiler/xla/window_util.h"
23
24 namespace xla {
25
26 class DynamicDimensionInferenceVisitor : public DfsHloVisitorWithDefault {
27 public:
DynamicDimensionInferenceVisitor(const DynamicParameterBinding & param_bindings,DynamicDimensionInference * parent)28 explicit DynamicDimensionInferenceVisitor(
29 const DynamicParameterBinding& param_bindings,
30 DynamicDimensionInference* parent)
31 : param_bindings_(param_bindings), parent_(parent) {}
32
33 Status DefaultAction(HloInstruction* hlo) override;
34
Run(HloComputation * computation,const DynamicParameterBinding & param_bindings,DynamicDimensionInference * parent)35 static Status Run(HloComputation* computation,
36 const DynamicParameterBinding& param_bindings,
37 DynamicDimensionInference* parent) {
38 DynamicDimensionInferenceVisitor visitor(param_bindings, parent);
39 return computation->Accept(&visitor);
40 }
41
42 Status HandleParameter(HloInstruction* hlo) override;
43
44 Status HandleReduce(HloInstruction* hlo) override;
45
46 Status HandleDot(HloInstruction* hlo) override;
47
48 Status HandleTuple(HloInstruction* hlo) override;
49
50 Status HandleTranspose(HloInstruction* hlo) override;
51
52 Status HandleReshape(HloInstruction* hlo) override;
53
54 Status HandlePad(HloInstruction* hlo) override;
55
56 Status HandleBroadcast(HloInstruction* hlo) override;
57
58 Status HandleGetDimensionSize(HloInstruction* hlo) override;
59
60 Status HandleSelect(HloInstruction* hlo) override;
61
62 Status HandleConvolution(HloInstruction* hlo) override;
63
64 Status HandleReduceWindow(HloInstruction* hlo) override;
65
66 Status HandleSelectAndScatter(HloInstruction* hlo) override;
67
68 Status HandleGetTupleElement(HloInstruction* hlo) override;
69
70 Status HandleElementwiseUnary(HloInstruction* hlo) override;
71
72 Status HandleElementwiseBinary(HloInstruction* hlo) override;
73
74 Status HandleWhile(HloInstruction* hlo) override;
75
76 Status HandleSlice(HloInstruction* hlo) override;
77
78 private:
79 using OperandDynamicDimensionFn = std::function<Status(
80 HloInstruction* operand, ShapeIndex index, int64 dimension,
81 int64 operand_index, HloInstruction* dynamic_size)>;
82
83 Status ForEachOperandDynamicDimension(HloInstruction* inst,
84 const OperandDynamicDimensionFn&);
85
86 // Pass through a dynamic dimension from the input to the output with the same
87 // value and index in the shape. This is a helper function to handle trivial
88 // instructions like elementwise operations.
89 Status PassThroughDynamicDimension(HloInstruction*);
90
91 // The dynamic parameter bindings of this computation.
92 const DynamicParameterBinding& param_bindings_;
93
94 // A pointer to DynamicDimensionInference, used to update the dynamic mapping.
95 DynamicDimensionInference* parent_;
96 };
97
DefaultAction(HloInstruction * hlo)98 Status DynamicDimensionInferenceVisitor::DefaultAction(HloInstruction* hlo) {
99 return ForEachOperandDynamicDimension(
100 hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
101 int64 operand_index, HloInstruction* dynamic_size) {
102 return UnimplementedStrCat(
103 "Asked to propagate a dynamic dimension from hlo ",
104 operand->ToString(), "@", index.ToString(), "@", dimension,
105 " to hlo ", hlo->ToString(), ", which is not implemented.");
106 });
107 }
108
HandleGetTupleElement(HloInstruction * hlo)109 Status DynamicDimensionInferenceVisitor::HandleGetTupleElement(
110 HloInstruction* hlo) {
111 return ForEachOperandDynamicDimension(
112 hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
113 int64 operand_index, HloInstruction* dynamic_size) {
114 if (hlo->tuple_index() == index[0]) {
115 ShapeIndex new_index =
116 ShapeIndexView(index).ConsumeFront().ToShapeIndex();
117 parent_->SetDynamicSize(hlo, new_index, dimension, dynamic_size);
118 }
119 return Status::OK();
120 });
121 }
122
HandleTuple(HloInstruction * hlo)123 Status DynamicDimensionInferenceVisitor::HandleTuple(HloInstruction* hlo) {
124 return ForEachOperandDynamicDimension(
125 hlo, [&](HloInstruction*, ShapeIndex index, int64 dimension,
126 int64 operand_index, HloInstruction* dynamic_size) {
127 index.push_front(operand_index);
128 parent_->SetDynamicSize(hlo, index, dimension, dynamic_size);
129 return Status::OK();
130 });
131 }
132
HandleBroadcast(HloInstruction * hlo)133 Status DynamicDimensionInferenceVisitor::HandleBroadcast(HloInstruction* hlo) {
134 return ForEachOperandDynamicDimension(
135 hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
136 int64 operand_index, HloInstruction* dynamic_size) {
137 int64 broadcast_dim = hlo->dimensions(dimension);
138 parent_->SetDynamicSize(hlo, {}, broadcast_dim, dynamic_size);
139 return Status::OK();
140 });
141 }
142
HandlePad(HloInstruction * hlo)143 Status DynamicDimensionInferenceVisitor::HandlePad(HloInstruction* hlo) {
144 return ForEachOperandDynamicDimension(
145 hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
146 int64 operand_index, HloInstruction* dynamic_size) {
147 if (operand_index != 0) {
148 return Unimplemented(
149 "Dynamic dimension on padding value is not supported");
150 }
151 const PaddingConfig_PaddingConfigDimension& padding_config =
152 hlo->padding_config().dimensions(dimension);
153 if (padding_config.interior_padding() == 0 &&
154 padding_config.edge_padding_low() == 0 &&
155 padding_config.edge_padding_high() == 0) {
156 parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size);
157 return Status::OK();
158 } else {
159 return Unimplemented(
160 "Dynamic dimension propagation on padding dimension is not "
161 "supported.");
162 }
163 });
164 }
165
HandleReduce(HloInstruction * hlo)166 Status DynamicDimensionInferenceVisitor::HandleReduce(HloInstruction* hlo) {
167 return ForEachOperandDynamicDimension(
168 hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
169 int64 operand_index, HloInstruction* dynamic_size) {
170 HloInstruction* reduce = hlo;
171 int64 operand_count = reduce->operand_count();
172 CHECK_EQ(operand_count % 2, 0);
173 if (operand_index >= operand_count / 2) {
174 // Init values doesn't have dynamic size.
175 return Status::OK();
176 }
177 if ((absl::c_count(reduce->dimensions(), dimension) != 0)) {
178 // Dimension is to be reduce, stop tracing.
179 return Status::OK();
180 }
181
182 // Find out the new dynamic dimension after reduce.
183 int64 dimensions_not_reduced_count = 0;
184 for (int i = 0; i < operand->shape().rank(); ++i) {
185 if (dimension == i) {
186 parent_->SetDynamicSize(reduce, {}, dimensions_not_reduced_count,
187 dynamic_size);
188
189 return Status::OK();
190 }
191 if (absl::c_count(reduce->dimensions(), i) == 0) {
192 dimensions_not_reduced_count++;
193 }
194 }
195
196 return Status::OK();
197 });
198 }
199
HandleDot(HloInstruction * hlo)200 Status DynamicDimensionInferenceVisitor::HandleDot(HloInstruction* hlo) {
201 return ForEachOperandDynamicDimension(
202 hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
203 int64 operand_index, HloInstruction* dynamic_size) {
204 HloInstruction* dot = hlo;
205 const DotDimensionNumbers& dimension_numbers =
206 dot->dot_dimension_numbers();
207 // A map from the operand dimensions to result dimension.
208 absl::flat_hash_map<int64, int64> result_dim_mapping;
209 int64 current_result_dims = 0;
210 std::unordered_set<int64> batch_dims(
211 dimension_numbers.rhs_batch_dimensions().begin(),
212 dimension_numbers.rhs_batch_dimensions().end());
213
214 for (int64 i : dimension_numbers.rhs_batch_dimensions()) {
215 result_dim_mapping[i] = current_result_dims++;
216 }
217
218 for (int64 i = 0; i < dot->operand(0)->shape().rank(); i++) {
219 if (!absl::c_linear_search(
220 dimension_numbers.lhs_contracting_dimensions(), i)) {
221 if (operand_index == 0) {
222 result_dim_mapping[i] = current_result_dims;
223 }
224 current_result_dims++;
225 }
226 }
227
228 for (int64 i = 0; i < dot->operand(1)->shape().rank(); i++) {
229 if (!absl::c_linear_search(
230 dimension_numbers.rhs_contracting_dimensions(), i) &&
231 !absl::c_linear_search(dimension_numbers.rhs_batch_dimensions(),
232 i)) {
233 if (operand_index == 1) {
234 result_dim_mapping[i] = current_result_dims;
235 }
236 current_result_dims++;
237 }
238 }
239
240 // Check if the operand dim is in the result shape. If so, add another
241 // work item to trace that dimension.
242 auto iter = result_dim_mapping.find(dimension);
243 if (iter != result_dim_mapping.end()) {
244 parent_->SetDynamicSize(dot, {}, iter->second, dynamic_size);
245 }
246
247 return Status::OK();
248 });
249 }
250
HandleTranspose(HloInstruction * hlo)251 Status DynamicDimensionInferenceVisitor::HandleTranspose(HloInstruction* hlo) {
252 return ForEachOperandDynamicDimension(
253 hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
254 int64 operand_index, HloInstruction* dynamic_size) {
255 parent_->SetDynamicSize(hlo, {}, hlo->dimensions()[dimension],
256 dynamic_size);
257 return Status::OK();
258 });
259 }
260
HandleConvolution(HloInstruction * hlo)261 Status DynamicDimensionInferenceVisitor::HandleConvolution(
262 HloInstruction* hlo) {
263 return ForEachOperandDynamicDimension(
264 hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
265 int64 operand_index, HloInstruction* dynamic_size) {
266 HloInstruction* conv = hlo;
267 const ConvolutionDimensionNumbers& dimension_numbers =
268 conv->convolution_dimension_numbers();
269
270 if (operand_index == 0) {
271 if (dimension == dimension_numbers.input_batch_dimension()) {
272 parent_->SetDynamicSize(conv, {},
273 dimension_numbers.output_batch_dimension(),
274 dynamic_size);
275 return Status::OK();
276 }
277
278 if (dimension == dimension_numbers.input_feature_dimension()) {
279 return Status::OK();
280 }
281 } else {
282 if (dimension == dimension_numbers.kernel_input_feature_dimension()) {
283 return Status::OK();
284 }
285 }
286
287 return Unimplemented("Dynamic Spatial Convolution is not supported: %s",
288 conv->ToString());
289 });
290 }
291
HandleGetDimensionSize(HloInstruction *)292 Status DynamicDimensionInferenceVisitor::HandleGetDimensionSize(
293 HloInstruction*) {
294 // Dynamic dimension doesn't propagate through GetDimensionSize:
295 //
296 // Input: F32[x, y, z]
297 // |
298 // GetDimensionSize(1): U32[]
299 //
300 // The returned value is a scalar, which doesn't have any dynamic dimension in
301 // the shape (although the value contains the real size of the dynamic
302 // dimension of the input).
303 return Status::OK();
304 }
305
PassThroughDynamicDimension(HloInstruction * hlo)306 Status DynamicDimensionInferenceVisitor::PassThroughDynamicDimension(
307 HloInstruction* hlo) {
308 return ForEachOperandDynamicDimension(
309 hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
310 int64 operand_index, HloInstruction* dynamic_size) {
311 parent_->SetDynamicSize(hlo, index, dimension, dynamic_size);
312 return Status::OK();
313 });
314 }
315
HandleElementwiseUnary(HloInstruction * hlo)316 Status DynamicDimensionInferenceVisitor::HandleElementwiseUnary(
317 HloInstruction* hlo) {
318 return PassThroughDynamicDimension(hlo);
319 }
320
HandleSelect(HloInstruction * hlo)321 Status DynamicDimensionInferenceVisitor::HandleSelect(HloInstruction* hlo) {
322 return PassThroughDynamicDimension(hlo);
323 }
324
HandleElementwiseBinary(HloInstruction * hlo)325 Status DynamicDimensionInferenceVisitor::HandleElementwiseBinary(
326 HloInstruction* hlo) {
327 return PassThroughDynamicDimension(hlo);
328 }
329
HandleReshape(HloInstruction * hlo)330 Status DynamicDimensionInferenceVisitor::HandleReshape(HloInstruction* hlo) {
331 return ForEachOperandDynamicDimension(
332 hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
333 int64 operand_index, HloInstruction* dynamic_size) {
334 HloInstruction* reshape = hlo;
335 std::vector<std::pair<int64, int64>> unmodified_dims =
336 ShapeUtil::DimensionsUnmodifiedByReshape(operand->shape(),
337 reshape->shape());
338 for (auto& unmodified : unmodified_dims) {
339 if (unmodified.first == dimension) {
340 parent_->SetDynamicSize(reshape, {}, unmodified.second,
341 dynamic_size);
342 return Status::OK();
343 }
344 }
345 return Unimplemented(
346 "Dynamic Reshape on modified dimensions is yet not supported: %s",
347 reshape->ToString());
348 });
349 }
350
HandleReduceWindow(HloInstruction * hlo)351 Status DynamicDimensionInferenceVisitor::HandleReduceWindow(
352 HloInstruction* hlo) {
353 return ForEachOperandDynamicDimension(
354 hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
355 int64 operand_index, HloInstruction* dynamic_size) {
356 HloInstruction* reduce_window = hlo;
357 const WindowDimension& window_dimension =
358 reduce_window->window().dimensions(dimension);
359
360 if (!window_util::IsTrivialWindowDimension(window_dimension)) {
361 return Unimplemented(
362 "Dynamic Spatial reduce window is not supported: %s",
363 reduce_window->ToString());
364 }
365
366 parent_->SetDynamicSize(reduce_window, {}, dimension, dynamic_size);
367
368 return Status::OK();
369 });
370 }
371
HandleSelectAndScatter(HloInstruction * hlo)372 Status DynamicDimensionInferenceVisitor::HandleSelectAndScatter(
373 HloInstruction* hlo) {
374 return ForEachOperandDynamicDimension(
375 hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
376 int64 operand_index, HloInstruction* dynamic_size) {
377 HloInstruction* select_and_scatter = hlo;
378 const WindowDimension& window_dimension =
379 select_and_scatter->window().dimensions(dimension);
380
381 if (!window_util::IsTrivialWindowDimension(window_dimension)) {
382 return Unimplemented(
383 "Dynamic Spatial select and scatter is not supported: %s",
384 select_and_scatter->ToString());
385 }
386
387 parent_->SetDynamicSize(select_and_scatter, {}, dimension,
388 dynamic_size);
389
390 return Status::OK();
391 });
392 }
393
HandleSlice(HloInstruction * hlo)394 Status DynamicDimensionInferenceVisitor::HandleSlice(HloInstruction* hlo) {
395 return ForEachOperandDynamicDimension(
396 hlo, [&](HloInstruction* operand, ShapeIndex /*index*/, int64 dimension,
397 int64 /*operand_index*/, HloInstruction* dynamic_size) {
398 if (hlo->slice_starts(dimension) != 0 ||
399 hlo->slice_strides(dimension) != 1 ||
400 hlo->slice_limits(dimension) !=
401 operand->shape().dimensions(dimension)) {
402 return Unimplemented(
403 "Dynamic dimension propagation on Slice where it doesn't slice "
404 "out an entire dimension is not supported %s",
405 hlo->ToString());
406 }
407
408 parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size);
409
410 return Status::OK();
411 });
412 }
413
Status DynamicDimensionInferenceVisitor::HandleWhile(HloInstruction* hlo) {
  // While loop is handled by passing dynamic size hlos as parameters into the
  // hlo while loop. This is done by replacing the original while with a new
  // one.
  //
  // Before:
  //
  // op1 = ...
  // op2 = ...
  // op1_x = ... // dynamic dimension size of op1
  // while = while(op1, op2)
  //
  //
  // After:
  //
  // op1 = ...
  // op2 = ...
  // op1_x = ... // dynamic dimension size of op1
  // while = while(op1, op2, op1_x)
  //
  // In the above graph, op1_x is the bound of the dynamic dimension size of
  // op1 and is wired into the while loop as a new parameter.
  //
  // TODO(b/119843103): Once we implement dynamic bounds in XLA backend, dynamic
  // bound can be propagated through native xla values instead of relying on
  // additional parameter.

  // dynamic_size_to_operand_id_index_map maps each dynamic size instruction to
  // its tuple element index in the (possibly extended) while loop state.
  absl::flat_hash_map<HloInstruction*, int64>
      dynamic_size_to_operand_id_index_map;

  // operands_to_add collects dynamic sizes that need to be added to the while
  // loop as parameters. Note that a dynamic size is ignored if it is already
  // part of the parameter. i.e.:
  //
  // We don't do:
  //
  // op1 = ...
  // op2 = ...
  // op_x = ... // dynamic dimension size of both op1 and op2
  // while = while(op1, op2, op_x, op_x) // 4 parameters
  //
  // But we do:
  //
  // op1 = ...
  // op2 = ...
  // op_x = ... // dynamic dimension size of both op1 and op2
  // while = while(op1, op2, op_x)
  //
  // An alternative is to do this in a while loop CSE pass.
  //
  std::vector<HloInstruction*> operands_to_add;
  // Next tuple index to hand out. Newly added sizes are numbered after the
  // existing tuple elements, in the order they are pushed into
  // operands_to_add. NOTE(review): this assumes
  // WhileUtil::MakeInstructionsLiveIn appends operands_to_add after the
  // existing elements in the same order; confirm against while_util.
  int64 operand_count = hlo->shape().tuple_shapes_size();
  // Pass 1: give every dynamic size that feeds the while a tuple index.
  TF_RETURN_IF_ERROR(ForEachOperandDynamicDimension(
      hlo, [&](HloInstruction*, ShapeIndex, int64, int64,
               HloInstruction* dynamic_size) {
        const HloInstruction* tuple_operand = hlo->operand(0);
        // Reuse an existing element when the size is already in the tuple.
        for (int64 i = 0; i < tuple_operand->operand_count(); ++i) {
          if (dynamic_size == tuple_operand->operand(i)) {
            dynamic_size_to_operand_id_index_map[dynamic_size] = i;
            return Status::OK();
          }
        }
        // Otherwise schedule it for addition, once per distinct size.
        auto iter = dynamic_size_to_operand_id_index_map.find(dynamic_size);
        if (iter == dynamic_size_to_operand_id_index_map.end()) {
          operands_to_add.push_back(dynamic_size);
          dynamic_size_to_operand_id_index_map[dynamic_size] = operand_count++;
        }
        return Status::OK();
      }));

  if (!operands_to_add.empty()) {
    // Only replace the while loop if there are new parameters to add.
    HloInstruction* old_tuple_operand = hlo->mutable_operand(0);
    TF_ASSIGN_OR_RETURN(
        WhileUtil::MakeInstructionsLiveInResult result,
        WhileUtil::MakeInstructionsLiveIn(hlo, operands_to_add));
    // WhileUtil creates a new while hlo and tuple. Update the dynamic size
    // mapping for the newly created tuple.
    HloInstruction* new_tuple_operand =
        result.new_while_instr->mutable_operand(0);
    parent_->CopyMapping(/*from=*/old_tuple_operand, /*to=*/new_tuple_operand);
    // From here on, `hlo` refers to the replacement while instruction.
    hlo = result.new_while_instr;
  }

  // We have replaced the while loop, now set the dynamic dimensions for the
  // newly created while loop so that the hlos that consumes the while loop can
  // see the dynamic dimensions. Also sets the dynamic parameter binding for
  // running inference in the while loop.
  DynamicParameterBinding binding_for_while;
  TF_RETURN_IF_ERROR(ForEachOperandDynamicDimension(
      hlo, [&](HloInstruction*, ShapeIndex index, int64 dimension,
               int64 operand_index, HloInstruction* dynamic_size) {
        // Inside body/condition, the size is tuple element
        // dynamic_size_to_operand_id_index_map[dynamic_size] of parameter
        // `operand_index` (the code treats operand 0 as the whole loop state,
        // so this is presumably always 0). Every size seen here was indexed in
        // pass 1 (CopyMapping carries the same sizes over); operator[] would
        // otherwise silently insert index 0.
        DynamicParameterBinding::DynamicParameter dynamic_parameter{
            operand_index,
            {dynamic_size_to_operand_id_index_map[dynamic_size]}};
        DynamicParameterBinding::DynamicDimension dynamic_dimension{
            operand_index, index, dimension};
        TF_RETURN_IF_ERROR(
            binding_for_while.Bind(dynamic_parameter, dynamic_dimension));
        // The while result reuses the size instruction from outside the loop.
        // NOTE(review): this presumes the dynamic size is loop-invariant.
        parent_->SetDynamicSize(hlo, index, dimension, dynamic_size);
        return Status::OK();
      }));

  // Run inference in while body and condition.
  TF_RETURN_IF_ERROR(DynamicDimensionInferenceVisitor::Run(
      hlo->while_body(), binding_for_while, parent_));
  TF_RETURN_IF_ERROR(DynamicDimensionInferenceVisitor::Run(
      hlo->while_condition(), binding_for_while, parent_));

  return Status::OK();
}
527
HandleParameter(HloInstruction * hlo)528 Status DynamicDimensionInferenceVisitor::HandleParameter(HloInstruction* hlo) {
529 return param_bindings_.ForEachBinding(
530 [&](const DynamicParameterBinding::DynamicParameter& dynamic_parameter,
531 const DynamicParameterBinding::DynamicDimension& dynamic_dimension) {
532 if (dynamic_dimension.parameter_num != hlo->parameter_number()) {
533 return Status::OK();
534 }
535 HloComputation* computation = hlo->parent();
536 HloInstruction* target_parameter =
537 computation->parameter_instruction(dynamic_dimension.parameter_num);
538
539 HloInstruction* dynamic_size =
540 computation->parameter_instruction(dynamic_parameter.parameter_num);
541 for (int64 i : dynamic_parameter.parameter_index) {
542 dynamic_size =
543 computation->AddInstruction(HloInstruction::CreateGetTupleElement(
544 ShapeUtil::GetSubshape(dynamic_size->shape(), {i}),
545 dynamic_size, i));
546 }
547
548 parent_->SetDynamicSize(target_parameter,
549 dynamic_dimension.parameter_index,
550 dynamic_dimension.dimension, dynamic_size);
551 return Status::OK();
552 });
553 }
554
ForEachOperandDynamicDimension(HloInstruction * inst,const OperandDynamicDimensionFn & fn)555 Status DynamicDimensionInferenceVisitor::ForEachOperandDynamicDimension(
556 HloInstruction* inst, const OperandDynamicDimensionFn& fn) {
557 for (int64 operand_index = 0; operand_index < inst->operand_count();
558 ++operand_index) {
559 auto iter =
560 parent_->per_hlo_dynamic_dimensions_.find(inst->operand(operand_index));
561 if (iter != parent_->per_hlo_dynamic_dimensions_.end()) {
562 for (auto& dynamic_dimension : iter->second) {
563 HloInstruction* dynamic_size = parent_->GetDynamicSize(
564 dynamic_dimension.inst, dynamic_dimension.index,
565 dynamic_dimension.dim);
566 TF_RETURN_IF_ERROR(fn(dynamic_dimension.inst, dynamic_dimension.index,
567 dynamic_dimension.dim, operand_index,
568 dynamic_size));
569 }
570 }
571 }
572 return Status::OK();
573 }
574
CopyMapping(HloInstruction * from,HloInstruction * to)575 void DynamicDimensionInference::CopyMapping(HloInstruction* from,
576 HloInstruction* to) {
577 auto iter = per_hlo_dynamic_dimensions_.find(from);
578 if (iter != per_hlo_dynamic_dimensions_.end()) {
579 for (auto& dynamic_dimension : iter->second) {
580 HloInstruction* dynamic_size =
581 GetDynamicSize(dynamic_dimension.inst, dynamic_dimension.index,
582 dynamic_dimension.dim);
583 SetDynamicSize(to, dynamic_dimension.index, dynamic_dimension.dim,
584 dynamic_size);
585 }
586 }
587 }
588
589 /* static */
Run(HloModule * module)590 StatusOr<DynamicDimensionInference> DynamicDimensionInference::Run(
591 HloModule* module) {
592 VLOG(2) << "Param Config " << module->dynamic_parameter_binding().ToString();
593 DynamicDimensionInference inference(module);
594 TF_RETURN_IF_ERROR(inference.AnalyzeDynamicDimensions());
595 return inference;
596 }
597
ToString() const598 string DynamicDimensionInference::ToString() const {
599 std::vector<string> pieces;
600 pieces.push_back("DynamicDimensionInference: ");
601 for (const auto& mapping : dynamic_mapping_) {
602 const DynamicDimension& dynamic_dimension = mapping.first;
603 pieces.push_back(absl::StrFormat(
604 " -- instruction %s at %s has dim %lld as dynamic"
605 " dimension, which is represented by instruction %s",
606 dynamic_dimension.inst->ToString(), dynamic_dimension.index.ToString(),
607 dynamic_dimension.dim, mapping.second->ToString()));
608 }
609 return absl::StrJoin(pieces, "\n");
610 }
611
DynamicDimensionInference(HloModule * module)612 DynamicDimensionInference::DynamicDimensionInference(HloModule* module)
613 : module_(module) {}
614
AnalyzeDynamicDimensions()615 Status DynamicDimensionInference::AnalyzeDynamicDimensions() {
616 return DynamicDimensionInferenceVisitor::Run(
617 module_->entry_computation(), module_->dynamic_parameter_binding(), this);
618 }
619
GetDynamicSize(HloInstruction * inst,const ShapeIndex & index,int64 dim) const620 HloInstruction* DynamicDimensionInference::GetDynamicSize(
621 HloInstruction* inst, const ShapeIndex& index, int64 dim) const {
622 auto iter = dynamic_mapping_.find(DynamicDimension{inst, index, dim});
623 if (iter != dynamic_mapping_.end()) {
624 return iter->second;
625 }
626 return nullptr;
627 }
628
629 } // namespace xla
630