1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
17 
18 #include "absl/algorithm/container.h"
19 #include "absl/memory/memory.h"
20 #include "absl/strings/str_cat.h"
21 #include "tensorflow/compiler/xla/client/lib/comparators.h"
22 #include "tensorflow/compiler/xla/client/xla_builder.h"
23 #include "tensorflow/compiler/xla/client/xla_computation.h"
24 #include "tensorflow/compiler/xla/comparison_util.h"
25 #include "tensorflow/compiler/xla/literal.h"
26 #include "tensorflow/compiler/xla/literal_util.h"
27 #include "tensorflow/compiler/xla/service/hlo_clone_context.h"
28 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
29 #include "tensorflow/compiler/xla/service/hlo_module.h"
30 #include "tensorflow/compiler/xla/service/hlo_module_config.h"
31 #include "tensorflow/compiler/xla/service/shape_inference.h"
32 #include "tensorflow/compiler/xla/util.h"
33 
34 namespace xla {
35 using absl::StrCat;
36 
MakeUnaryHlo(HloOpcode opcode,HloInstruction * operand)37 StatusOr<HloInstruction*> MakeUnaryHlo(HloOpcode opcode,
38                                        HloInstruction* operand) {
39   HloComputation* computation = operand->parent();
40   TF_ASSIGN_OR_RETURN(Shape unary_op_shape,
41                       ShapeInference::InferUnaryOpShape(opcode, operand));
42   return computation->AddInstruction(
43       HloInstruction::CreateUnary(unary_op_shape, opcode, operand));
44 }
45 
MakeBinaryHlo(HloOpcode opcode,HloInstruction * lhs,HloInstruction * rhs)46 StatusOr<HloInstruction*> MakeBinaryHlo(HloOpcode opcode, HloInstruction* lhs,
47                                         HloInstruction* rhs) {
48   HloComputation* computation = lhs->parent();
49   CHECK_EQ(computation, rhs->parent());
50   TF_ASSIGN_OR_RETURN(Shape binary_op_shape,
51                       ShapeInference::InferBinaryOpShape(opcode, lhs, rhs));
52   return computation->AddInstruction(
53       HloInstruction::CreateBinary(binary_op_shape, opcode, lhs, rhs));
54 }
55 
MakeCompareHlo(ComparisonDirection direction,HloInstruction * lhs,HloInstruction * rhs)56 StatusOr<HloInstruction*> MakeCompareHlo(ComparisonDirection direction,
57                                          HloInstruction* lhs,
58                                          HloInstruction* rhs) {
59   HloComputation* computation = lhs->parent();
60   CHECK_EQ(computation, rhs->parent());
61   TF_ASSIGN_OR_RETURN(
62       Shape binary_op_shape,
63       ShapeInference::InferBinaryOpShape(HloOpcode::kCompare, lhs, rhs));
64   return computation->AddInstruction(
65       HloInstruction::CreateCompare(binary_op_shape, lhs, rhs, direction));
66 }
67 
MakePadHlo(HloInstruction * operand,HloInstruction * padding_value,const PaddingConfig & padding_config)68 StatusOr<HloInstruction*> MakePadHlo(HloInstruction* operand,
69                                      HloInstruction* padding_value,
70                                      const PaddingConfig& padding_config) {
71   HloComputation* computation = operand->parent();
72   CHECK_EQ(computation, padding_value->parent());
73   TF_ASSIGN_OR_RETURN(
74       Shape pad_shape,
75       ShapeInference::InferPadShape(operand->shape(), padding_value->shape(),
76                                     padding_config));
77   return computation->AddInstruction(HloInstruction::CreatePad(
78       pad_shape, operand, padding_value, padding_config));
79 }
80 
MakeSliceHlo(HloInstruction * operand,absl::Span<const int64> start_indices,absl::Span<const int64> limit_indices,absl::Span<const int64> strides)81 StatusOr<HloInstruction*> MakeSliceHlo(HloInstruction* operand,
82                                        absl::Span<const int64> start_indices,
83                                        absl::Span<const int64> limit_indices,
84                                        absl::Span<const int64> strides) {
85   HloComputation* computation = operand->parent();
86   TF_ASSIGN_OR_RETURN(Shape slice_shape, ShapeInference::InferSliceShape(
87                                              operand->shape(), start_indices,
88                                              limit_indices, strides));
89   return computation->AddInstruction(HloInstruction::CreateSlice(
90       slice_shape, operand, start_indices, limit_indices, strides));
91 }
92 
MakeConvolveHlo(HloInstruction * lhs,HloInstruction * rhs,int64 feature_group_count,int64 batch_group_count,const Window & window,const ConvolutionDimensionNumbers & dimension_numbers,const PrecisionConfig & precision_config,absl::optional<PrimitiveType> preferred_element_type)93 StatusOr<HloInstruction*> MakeConvolveHlo(
94     HloInstruction* lhs, HloInstruction* rhs, int64 feature_group_count,
95     int64 batch_group_count, const Window& window,
96     const ConvolutionDimensionNumbers& dimension_numbers,
97     const PrecisionConfig& precision_config,
98     absl::optional<PrimitiveType> preferred_element_type) {
99   HloComputation* computation = lhs->parent();
100   CHECK_EQ(computation, rhs->parent());
101   TF_ASSIGN_OR_RETURN(
102       Shape convolve_shape,
103       ShapeInference::InferConvolveShape(
104           lhs->shape(), rhs->shape(), feature_group_count, batch_group_count,
105           window, dimension_numbers, preferred_element_type));
106   return computation->AddInstruction(HloInstruction::CreateConvolve(
107       convolve_shape, lhs, rhs, feature_group_count, batch_group_count, window,
108       dimension_numbers, precision_config));
109 }
110 
MakeTransposeHlo(HloInstruction * operand,absl::Span<const int64> dimensions)111 StatusOr<HloInstruction*> MakeTransposeHlo(HloInstruction* operand,
112                                            absl::Span<const int64> dimensions) {
113   HloComputation* computation = operand->parent();
114   TF_ASSIGN_OR_RETURN(
115       Shape transpose_shape,
116       ShapeInference::InferTransposeShape(operand->shape(), dimensions));
117   return computation->AddInstruction(
118       HloInstruction::CreateTranspose(transpose_shape, operand, dimensions));
119 }
120 
MakeReshapeHlo(const Shape & result_shape,HloInstruction * operand)121 StatusOr<HloInstruction*> MakeReshapeHlo(const Shape& result_shape,
122                                          HloInstruction* operand) {
123   HloComputation* computation = operand->parent();
124   return computation->AddInstruction(
125       HloInstruction::CreateReshape(result_shape, operand));
126 }
127 
MakeReshapeHlo(absl::Span<const int64> result_shape_dim_bounds,HloInstruction * operand)128 StatusOr<HloInstruction*> MakeReshapeHlo(
129     absl::Span<const int64> result_shape_dim_bounds, HloInstruction* operand) {
130   Shape new_shape = ShapeUtil::MakeShape(operand->shape().element_type(),
131                                          result_shape_dim_bounds);
132   return MakeReshapeHlo(new_shape, operand);
133 }
134 
MakeDynamicSliceHlo(HloInstruction * operand,absl::Span<HloInstruction * const> start_indices,absl::Span<const int64> slice_sizes)135 StatusOr<HloInstruction*> MakeDynamicSliceHlo(
136     HloInstruction* operand, absl::Span<HloInstruction* const> start_indices,
137     absl::Span<const int64> slice_sizes) {
138   HloComputation* computation = operand->parent();
139   std::vector<Shape> scalar_start_indices_shapes(
140       start_indices.size(),
141       ShapeUtil::MakeShape(start_indices[0]->shape().element_type(), {}));
142   TF_ASSIGN_OR_RETURN(
143       Shape dynamic_slice_shape,
144       ShapeInference::InferDynamicSliceShape(
145           operand->shape(), scalar_start_indices_shapes, slice_sizes));
146   return computation->AddInstruction(HloInstruction::CreateDynamicSlice(
147       dynamic_slice_shape, operand, start_indices, slice_sizes));
148 }
149 
MakeDynamicSliceHlo(HloInstruction * operand,HloInstruction * start_indices,absl::Span<const int64> slice_sizes)150 StatusOr<HloInstruction*> MakeDynamicSliceHlo(
151     HloInstruction* operand, HloInstruction* start_indices,
152     absl::Span<const int64> slice_sizes) {
153   HloComputation* computation = operand->parent();
154   CHECK_EQ(computation, start_indices->parent());
155   int64 rank = start_indices->shape().dimensions(0);
156   std::vector<HloInstruction*> scalar_start_indices;
157   for (int i = 0; i < rank; ++i) {
158     // TODO(b/118437727): Update callers to provide scalars directly.
159     auto slice = computation->AddInstruction(HloInstruction::CreateSlice(
160         ShapeUtil::MakeShape(start_indices->shape().element_type(), {1}),
161         start_indices, {i}, {i + 1}, {1}));
162     scalar_start_indices.push_back(
163         computation->AddInstruction(HloInstruction::CreateReshape(
164             ShapeUtil::MakeShape(start_indices->shape().element_type(), {}),
165             slice)));
166   }
167   std::vector<Shape> scalar_start_indices_shapes(
168       rank, ShapeUtil::MakeShape(start_indices->shape().element_type(), {}));
169   TF_ASSIGN_OR_RETURN(
170       Shape dynamic_slice_shape,
171       ShapeInference::InferDynamicSliceShape(
172           operand->shape(), scalar_start_indices_shapes, slice_sizes));
173   return computation->AddInstruction(HloInstruction::CreateDynamicSlice(
174       dynamic_slice_shape, operand, scalar_start_indices, slice_sizes));
175 }
176 
MakeDynamicUpdateSliceHlo(HloInstruction * operand,HloInstruction * update,HloInstruction * start_indices)177 StatusOr<HloInstruction*> MakeDynamicUpdateSliceHlo(
178     HloInstruction* operand, HloInstruction* update,
179     HloInstruction* start_indices) {
180   HloComputation* computation = operand->parent();
181   CHECK_EQ(computation, update->parent());
182   CHECK_EQ(computation, start_indices->parent());
183   int64 rank = start_indices->shape().dimensions(0);
184   std::vector<HloInstruction*> scalar_start_indices;
185   for (int i = 0; i < rank; ++i) {
186     // TODO(b/118437727): Update callers to provide scalars directly.
187     auto slice = computation->AddInstruction(HloInstruction::CreateSlice(
188         ShapeUtil::MakeShape(start_indices->shape().element_type(), {1}),
189         start_indices, {i}, {i + 1}, {1}));
190     scalar_start_indices.push_back(
191         computation->AddInstruction(HloInstruction::CreateReshape(
192             ShapeUtil::MakeShape(start_indices->shape().element_type(), {}),
193             slice)));
194   }
195   std::vector<Shape> scalar_start_indices_shapes(
196       rank, ShapeUtil::MakeShape(start_indices->shape().element_type(), {}));
197   TF_ASSIGN_OR_RETURN(
198       Shape dynamic_update_slice_shape,
199       ShapeInference::InferDynamicUpdateSliceShape(
200           operand->shape(), update->shape(), scalar_start_indices_shapes));
201   return computation->AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
202       dynamic_update_slice_shape, operand, update, scalar_start_indices));
203 }
204 
MakeBroadcastHlo(HloInstruction * operand,absl::Span<const int64> broadcast_dimensions,absl::Span<const int64> result_shape_bounds)205 HloInstruction* MakeBroadcastHlo(HloInstruction* operand,
206                                  absl::Span<const int64> broadcast_dimensions,
207                                  absl::Span<const int64> result_shape_bounds) {
208   HloComputation* computation = operand->parent();
209   Shape broadcast_shape = ShapeUtil::MakeShape(operand->shape().element_type(),
210                                                result_shape_bounds);
211 
212   return computation->AddInstruction(HloInstruction::CreateBroadcast(
213       broadcast_shape, operand, broadcast_dimensions));
214 }
215 
MakeBroadcastHlo(HloInstruction * operand,absl::Span<const int64> broadcast_dimensions,const Shape & shape)216 HloInstruction* MakeBroadcastHlo(HloInstruction* operand,
217                                  absl::Span<const int64> broadcast_dimensions,
218                                  const Shape& shape) {
219   return MakeBroadcastHlo(operand, broadcast_dimensions, shape.dimensions());
220 }
221 
MakeGetTupleElementHlo(HloInstruction * operand,int64 index)222 StatusOr<HloInstruction*> MakeGetTupleElementHlo(HloInstruction* operand,
223                                                  int64 index) {
224   HloComputation* computation = operand->parent();
225 
226   TF_ASSIGN_OR_RETURN(
227       Shape gte_shape,
228       ShapeInference::InferGetTupleElementShape(operand->shape(), index));
229   return computation->AddInstruction(
230       HloInstruction::CreateGetTupleElement(gte_shape, operand, index));
231 }
232 
MakeConcatHlo(absl::Span<HloInstruction * const> operands,int64 dimension)233 StatusOr<HloInstruction*> MakeConcatHlo(
234     absl::Span<HloInstruction* const> operands, int64 dimension) {
235   CHECK_GT(operands.size(), 0);
236 
237   HloComputation* computation = operands[0]->parent();
238   CHECK(absl::c_all_of(operands, [&](HloInstruction* instr) {
239     return instr->parent() == computation;
240   }));
241 
242   std::vector<const Shape*> operand_shapes;
243   absl::c_transform(operands, std::back_inserter(operand_shapes),
244                     [](HloInstruction* instr) { return &instr->shape(); });
245 
246   TF_ASSIGN_OR_RETURN(Shape concat_shape, ShapeInference::InferConcatOpShape(
247                                               operand_shapes, dimension));
248   return computation->AddInstruction(
249       HloInstruction::CreateConcatenate(concat_shape, operands, dimension));
250 }
251 
MakeConvertToHlo(HloInstruction * hlo,PrimitiveType type)252 HloInstruction* MakeConvertToHlo(HloInstruction* hlo, PrimitiveType type) {
253   if (hlo->shape().element_type() == type) {
254     return hlo;
255   }
256   Shape shape = ShapeUtil::ChangeElementType(hlo->shape(), type);
257   hlo =
258       hlo->parent()->AddInstruction(HloInstruction::CreateConvert(shape, hlo));
259   CHECK_EQ(hlo->shape().element_type(), type);
260   return hlo;
261 }
262 
MakeBitcastConvertToHlo(HloInstruction * hlo,PrimitiveType type)263 HloInstruction* MakeBitcastConvertToHlo(HloInstruction* hlo,
264                                         PrimitiveType type) {
265   if (hlo->shape().element_type() == type) {
266     return hlo;
267   }
268   Shape shape = ShapeUtil::ChangeElementType(hlo->shape(), type);
269   // PRED are stored as one byte, PRED have a BitWidth of 1, avoid this problem
270   // by using a convert instead of bitcast convert.
271   if (type == PRED || hlo->shape().element_type() == PRED) {
272     return MakeConvertToHlo(hlo, type);
273   }
274   hlo = hlo->parent()->AddInstruction(
275       HloInstruction::CreateBitcastConvert(shape, hlo));
276   CHECK_EQ(hlo->shape().element_type(), type);
277   return hlo;
278 }
279 
MakeIotaHlo(HloComputation * computation,const Shape & shape,int64 iota_dimension)280 HloInstruction* MakeIotaHlo(HloComputation* computation, const Shape& shape,
281                             int64 iota_dimension) {
282   return computation->AddInstruction(
283       HloInstruction::CreateIota(shape, iota_dimension));
284 }
285 
MakeDotHlo(HloInstruction * lhs,HloInstruction * rhs,const DotDimensionNumbers & dim_numbers,const PrecisionConfig & precision_config,absl::optional<PrimitiveType> preferred_element_type)286 StatusOr<HloInstruction*> MakeDotHlo(
287     HloInstruction* lhs, HloInstruction* rhs,
288     const DotDimensionNumbers& dim_numbers,
289     const PrecisionConfig& precision_config,
290     absl::optional<PrimitiveType> preferred_element_type) {
291   HloComputation* computation = lhs->parent();
292   CHECK_EQ(computation, rhs->parent());
293   TF_ASSIGN_OR_RETURN(
294       Shape dot_shape,
295       ShapeInference::InferDotOpShape(lhs->shape(), rhs->shape(), dim_numbers,
296                                       preferred_element_type));
297   return computation->AddInstruction(HloInstruction::CreateDot(
298       dot_shape, lhs, rhs, dim_numbers, precision_config));
299 }
300 
MakeMapHlo(absl::Span<HloInstruction * const> operands,HloComputation * map_computation)301 StatusOr<HloInstruction*> MakeMapHlo(absl::Span<HloInstruction* const> operands,
302                                      HloComputation* map_computation) {
303   CHECK(!operands.empty()) << "Map Hlo requires at least one operand.";
304   HloComputation* computation = operands.front()->parent();
305   std::vector<const Shape*> operand_shapes;
306   int64 max_operand_rank = 0;
307   for (const HloInstruction* operand : operands) {
308     CHECK_EQ(computation, operand->parent());
309     operand_shapes.push_back(&operand->shape());
310     max_operand_rank = std::max(max_operand_rank, operand->shape().rank());
311   }
312   std::vector<int64> map_dims(max_operand_rank);
313   std::iota(map_dims.begin(), map_dims.end(), 0);
314   TF_ASSIGN_OR_RETURN(
315       Shape map_shape,
316       ShapeInference::InferMapShape(
317           operand_shapes, map_computation->ComputeProgramShape(), map_dims));
318   return computation->AddInstruction(
319       HloInstruction::CreateMap(map_shape, operands, map_computation));
320 }
321 
MakeReduceHlo(HloInstruction * operand,HloInstruction * init_value,absl::Span<const int64> dimensions,HloOpcode binary_opcode)322 StatusOr<HloInstruction*> MakeReduceHlo(HloInstruction* operand,
323                                         HloInstruction* init_value,
324                                         absl::Span<const int64> dimensions,
325                                         HloOpcode binary_opcode) {
326   auto scalar_shape = ShapeUtil::MakeShape(operand->shape().element_type(), {});
327   auto result_shape = ShapeUtil::FilterDimensions(
328       [&](const int64 dim) { return !absl::c_linear_search(dimensions, dim); },
329       operand->shape());
330   HloComputation* reduce_computation;
331   {
332     HloComputation::Builder b(operand->name() + ".reduce_sub_computation");
333     auto lhs = b.AddInstruction(
334         HloInstruction::CreateParameter(0, scalar_shape, "lhs"));
335     auto rhs = b.AddInstruction(
336         HloInstruction::CreateParameter(1, scalar_shape, "rhs"));
337     b.AddInstruction(
338         HloInstruction::CreateBinary(scalar_shape, binary_opcode, lhs, rhs));
339     reduce_computation =
340         operand->parent()->parent()->AddEmbeddedComputation(b.Build());
341   }
342 
343   return operand->parent()->AddInstruction(HloInstruction::CreateReduce(
344       result_shape, operand, init_value, dimensions, reduce_computation));
345 }
346 
MakeReduceHlo(HloInstruction * operand,HloInstruction * init_value,HloOpcode binary_opcode,HloModule * module)347 StatusOr<HloInstruction*> MakeReduceHlo(HloInstruction* operand,
348                                         HloInstruction* init_value,
349                                         HloOpcode binary_opcode,
350                                         HloModule* module) {
351   DCHECK_NE(nullptr, module);
352   std::vector<int64> all_dims(operand->shape().rank());
353   std::iota(all_dims.begin(), all_dims.end(), 0);
354 
355   auto scalar_shape = ShapeUtil::MakeShape(operand->shape().element_type(), {});
356   HloComputation* reduce_computation;
357   {
358     HloComputation::Builder b(operand->name() + ".reduce_sub_computation");
359     auto lhs = b.AddInstruction(
360         HloInstruction::CreateParameter(0, scalar_shape, "lhs"));
361     auto rhs = b.AddInstruction(
362         HloInstruction::CreateParameter(1, scalar_shape, "rhs"));
363     b.AddInstruction(
364         HloInstruction::CreateBinary(scalar_shape, binary_opcode, lhs, rhs));
365     reduce_computation = module->AddEmbeddedComputation(b.Build());
366   }
367 
368   return operand->parent()->AddInstruction(HloInstruction::CreateReduce(
369       scalar_shape, operand, init_value, all_dims, reduce_computation));
370 }
371 
MakeReverseHlo(HloInstruction * operand,absl::Span<const int64> dimensions)372 StatusOr<HloInstruction*> MakeReverseHlo(HloInstruction* operand,
373                                          absl::Span<const int64> dimensions) {
374   HloComputation* computation = operand->parent();
375   TF_ASSIGN_OR_RETURN(Shape reverse_shape, ShapeInference::InferReverseShape(
376                                                operand->shape(), dimensions));
377   return computation->AddInstruction(
378       HloInstruction::CreateReverse(reverse_shape, operand, dimensions));
379 }
380 
MakeSelectHlo(HloInstruction * pred,HloInstruction * on_true,HloInstruction * on_false,HloInstruction * derived_from)381 StatusOr<HloInstruction*> MakeSelectHlo(HloInstruction* pred,
382                                         HloInstruction* on_true,
383                                         HloInstruction* on_false,
384                                         HloInstruction* derived_from) {
385   HloComputation* computation = pred->parent();
386   DCHECK_EQ(computation, on_true->parent());
387   DCHECK_EQ(computation, on_false->parent());
388   Shape op_shape = on_true->shape();
389   if (ShapeUtil::IsScalar(pred->shape())) {
390     if (!ShapeUtil::IsScalar(op_shape) && !op_shape.IsTuple()) {
391       // If the output is not scalar, we need to broadcast the condition
392       // to match the contract of kSelect. For tuples, we use kTupleSelect
393       // which expects the condition to be a scalar.
394       pred = computation->AddInstruction(HloInstruction::CreateBroadcast(
395           ShapeUtil::ChangeElementType(op_shape, PrimitiveType::PRED), pred,
396           {}));
397       if (derived_from) {
398         derived_from->SetupDerivedInstruction(pred);
399       }
400     }
401   }
402   HloOpcode select_op_code =
403       op_shape.IsTuple() ? HloOpcode::kTupleSelect : HloOpcode::kSelect;
404   TF_ASSIGN_OR_RETURN(Shape select_shape,
405                       ShapeInference::InferTernaryOpShape(select_op_code, pred,
406                                                           on_true, on_false));
407   HloInstruction* select =
408       computation->AddInstruction(HloInstruction::CreateTernary(
409           select_shape, select_op_code, pred, on_true, on_false));
410   if (derived_from) {
411     derived_from->SetupDerivedInstruction(select);
412   }
413   return select;
414 }
415 
MakeSortHlo(const Shape & sort_shape,absl::Span<HloInstruction * const> operands,int64 dimension_to_sort,bool is_stable,HloComputation::Builder * builder,HloModule * module)416 StatusOr<HloInstruction*> MakeSortHlo(
417     const Shape& sort_shape, absl::Span<HloInstruction* const> operands,
418     int64 dimension_to_sort, bool is_stable, HloComputation::Builder* builder,
419     HloModule* module) {
420   CHECK(!operands.empty()) << "Sort Hlo requires at least one operand.";
421   HloComputation* compare_computation;
422   XlaBuilder b("Sort.Compare");
423   std::vector<PrimitiveType> operand_types(operands.size());
424   for (int64 i = 0; i < operands.size(); ++i) {
425     operand_types[i] = operands[i]->shape().element_type();
426   }
427   XlaComputation comparator = CreateScalarLtComputation(operand_types, &b);
428   TF_ASSIGN_OR_RETURN(ProgramShape program_shape, comparator.GetProgramShape());
429   HloModuleConfig config(program_shape);
430   TF_ASSIGN_OR_RETURN(auto new_module,
431                       HloModule::CreateFromProto(comparator.proto(), config));
432   HloCloneContext context(module);
433   compare_computation =
434       module->DeepCloneComputation(new_module->entry_computation(), &context);
435   return builder->AddInstruction(HloInstruction::CreateSort(
436       sort_shape, dimension_to_sort, operands, compare_computation, is_stable));
437 }
438 
CollapseFirstNDims(HloInstruction * operand,int64 n)439 StatusOr<HloInstruction*> CollapseFirstNDims(HloInstruction* operand, int64 n) {
440   CHECK_GT(n, 0);
441 
442   const Shape& operand_shape = operand->shape();
443   CHECK_GE(operand_shape.dimensions_size(), n);
444   int64 new_shape_leading_bound = 1;
445   for (int64 i = 0; i < n; i++) {
446     new_shape_leading_bound *= operand_shape.dimensions(i);
447   }
448 
449   std::vector<int64> new_shape_dims;
450   new_shape_dims.reserve(operand_shape.dimensions_size() - n + 1);
451   new_shape_dims.push_back(new_shape_leading_bound);
452 
453   std::copy(operand_shape.dimensions().begin() + n,
454             operand_shape.dimensions().end(),
455             std::back_inserter(new_shape_dims));
456 
457   Shape output_shape =
458       ShapeUtil::MakeShape(operand_shape.element_type(), new_shape_dims);
459 
460   return MakeReshapeHlo(output_shape, operand);
461 }
462 
PrependDegenerateDims(HloInstruction * operand,int64 n)463 StatusOr<HloInstruction*> PrependDegenerateDims(HloInstruction* operand,
464                                                 int64 n) {
465   CHECK_GT(n, 0);
466   std::vector<int64> new_shape_dims;
467   const Shape& operand_shape = operand->shape();
468   new_shape_dims.reserve(n + operand_shape.dimensions_size());
469   new_shape_dims.insert(new_shape_dims.begin(), n, 1);
470   absl::c_copy(operand_shape.dimensions(), std::back_inserter(new_shape_dims));
471   return MakeReshapeHlo(new_shape_dims, operand);
472 }
473 
ExpandFirstDimIntoNDims(HloInstruction * operand,absl::Span<const int64> expanded_dims)474 StatusOr<HloInstruction*> ExpandFirstDimIntoNDims(
475     HloInstruction* operand, absl::Span<const int64> expanded_dims) {
476   CHECK_GT(operand->shape().dimensions_size(), 0);
477   CHECK_EQ(operand->shape().dimensions(0), Product(expanded_dims));
478 
479   std::vector<int64> expanded_shape_dim_bounds;
480   expanded_shape_dim_bounds.reserve(expanded_dims.size() +
481                                     operand->shape().dimensions_size() - 1);
482   absl::c_copy(expanded_dims, std::back_inserter(expanded_shape_dim_bounds));
483   std::copy(operand->shape().dimensions().begin() + 1,
484             operand->shape().dimensions().end(),
485             std::back_inserter(expanded_shape_dim_bounds));
486   Shape new_shape = ShapeUtil::MakeShape(operand->shape().element_type(),
487                                          expanded_shape_dim_bounds);
488   return MakeReshapeHlo(new_shape, operand);
489 }
490 
ElideDegenerateDims(HloInstruction * operand,absl::Span<const int64> dims_to_elide)491 StatusOr<HloInstruction*> ElideDegenerateDims(
492     HloInstruction* operand, absl::Span<const int64> dims_to_elide) {
493   return MakeReshapeHlo(
494       ShapeUtil::FilterDimensions(
495           [&](int64 dim) { return !absl::c_linear_search(dims_to_elide, dim); },
496           operand->shape()),
497       operand);
498 }
499 
InsertDegenerateDims(HloInstruction * operand,absl::Span<const int64> dims_to_insert)500 StatusOr<HloInstruction*> InsertDegenerateDims(
501     HloInstruction* operand, absl::Span<const int64> dims_to_insert) {
502   CHECK(absl::c_is_sorted(dims_to_insert));
503 
504   const Shape& operand_shape = operand->shape();
505   int64 output_shape_rank =
506       operand_shape.dimensions_size() + dims_to_insert.size();
507   for (auto dim_to_insert : dims_to_insert) {
508     CHECK_LT(dim_to_insert, output_shape_rank);
509   }
510 
511   std::vector<int64> output_shape_dim_bounds;
512   output_shape_dim_bounds.reserve(output_shape_rank);
513   int64 operand_dims_idx = 0;
514   int64 dims_to_insert_idx = 0;
515   for (int64 i = 0; i < output_shape_rank; ++i) {
516     if (dims_to_insert_idx < dims_to_insert.size() &&
517         i == dims_to_insert[dims_to_insert_idx]) {
518       output_shape_dim_bounds.push_back(1);
519       ++dims_to_insert_idx;
520     } else {
521       output_shape_dim_bounds.push_back(
522           operand_shape.dimensions(operand_dims_idx));
523       ++operand_dims_idx;
524     }
525   }
526 
527   Shape output_shape = ShapeUtil::MakeShape(operand_shape.element_type(),
528                                             output_shape_dim_bounds);
529   return MakeReshapeHlo(output_shape, operand);
530 }
531 
PadVectorWithZeros(HloInstruction * operand,int64 zeros_to_prepend,int64 zeros_to_append)532 StatusOr<HloInstruction*> PadVectorWithZeros(HloInstruction* operand,
533                                              int64 zeros_to_prepend,
534                                              int64 zeros_to_append) {
535   HloComputation* computation = operand->parent();
536   CHECK_EQ(operand->shape().dimensions_size(), 1);
537   PaddingConfig padding_config;
538   PaddingConfig::PaddingConfigDimension padding_config_dim;
539   padding_config_dim.set_edge_padding_low(zeros_to_prepend);
540   padding_config_dim.set_edge_padding_high(zeros_to_append);
541   *padding_config.add_dimensions() = padding_config_dim;
542 
543   HloInstruction* zero =
544       computation->AddInstruction(HloInstruction::CreateConstant(
545           LiteralUtil::Zero(operand->shape().element_type())));
546   return MakePadHlo(operand, zero, padding_config);
547 }
548 
BroadcastZeros(HloComputation * computation,PrimitiveType element_type,absl::Span<const int64> broadcast_dimensions)549 HloInstruction* BroadcastZeros(HloComputation* computation,
550                                PrimitiveType element_type,
551                                absl::Span<const int64> broadcast_dimensions) {
552   HloInstruction* zero = computation->AddInstruction(
553       HloInstruction::CreateConstant(LiteralUtil::Zero(element_type)));
554   return MakeBroadcastHlo(zero, /*broadcast_dimensions=*/{},
555                           /*result_shape_bounds=*/broadcast_dimensions);
556 }
557 
BroadcastOnes(HloComputation * computation,PrimitiveType element_type,absl::Span<const int64> broadcast_dimensions)558 HloInstruction* BroadcastOnes(HloComputation* computation,
559                               PrimitiveType element_type,
560                               absl::Span<const int64> broadcast_dimensions) {
561   HloInstruction* one = computation->AddInstruction(
562       HloInstruction::CreateConstant(LiteralUtil::One(element_type)));
563   return MakeBroadcastHlo(one, /*broadcast_dimensions=*/{},
564                           /*result_shape_bounds=*/broadcast_dimensions);
565 }
566 
567 // Recursively creates a dummy op given a shape. Leaf nodes are broadcasted zero
568 // while internal nodes are tuples.
CreateDummyOp(HloComputation::Builder * b,const Shape & shape)569 HloInstruction* CreateDummyOp(HloComputation::Builder* b, const Shape& shape) {
570   if (shape.IsArray()) {
571     auto zero = b->AddInstruction(HloInstruction::CreateConstant(
572         LiteralUtil::Zero(shape.element_type())));
573     return b->AddInstruction(HloInstruction::CreateBroadcast(shape, zero, {}));
574   }
575   CHECK(shape.IsTuple());
576   std::vector<HloInstruction*> sub_instructions;
577   for (const Shape& subshape : shape.tuple_shapes()) {
578     sub_instructions.push_back(CreateDummyOp(b, subshape));
579   }
580   return b->AddInstruction(HloInstruction::CreateTuple(sub_instructions));
581 }
582 
CreateComputationWithSignature(absl::Span<const Shape * const> domain,const Shape & range,absl::string_view name)583 StatusOr<std::unique_ptr<HloComputation>> CreateComputationWithSignature(
584     absl::Span<const Shape* const> domain, const Shape& range,
585     absl::string_view name) {
586   HloComputation::Builder b{string(name)};
587   int64 param_idx = 0;
588   for (const Shape* param_shape : domain) {
589     b.AddInstruction(HloInstruction::CreateParameter(
590         param_idx, *param_shape, StrCat("param.", param_idx)));
591     param_idx++;
592   }
593 
594   // We can't change the root type of a computation once it is created so create
595   // a dummy root instruction to give the computation the right root shape.  Use
596   // a (recursive) broadcast here to avoid creating large constants.
597   CreateDummyOp(&b, range);
598   return b.Build();
599 }
600 
601 }  // namespace xla
602