1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
17 #include "absl/algorithm/container.h"
18 #include "absl/memory/memory.h"
19 #include "absl/strings/str_cat.h"
20 #include "tensorflow/compiler/xla/client/lib/comparators.h"
21 #include "tensorflow/compiler/xla/client/xla_builder.h"
22 #include "tensorflow/compiler/xla/client/xla_computation.h"
23 #include "tensorflow/compiler/xla/literal.h"
24 #include "tensorflow/compiler/xla/literal_util.h"
25 #include "tensorflow/compiler/xla/service/hlo_clone_context.h"
26 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
27 #include "tensorflow/compiler/xla/service/hlo_module.h"
28 #include "tensorflow/compiler/xla/service/hlo_module_config.h"
29 #include "tensorflow/compiler/xla/service/shape_inference.h"
30 #include "tensorflow/compiler/xla/util.h"
31 
32 namespace xla {
33 using absl::StrCat;
34 
MakeBinaryHlo(HloOpcode opcode,HloInstruction * lhs,HloInstruction * rhs)35 StatusOr<HloInstruction*> MakeBinaryHlo(HloOpcode opcode, HloInstruction* lhs,
36                                         HloInstruction* rhs) {
37   HloComputation* computation = lhs->parent();
38   CHECK_EQ(computation, rhs->parent());
39   TF_ASSIGN_OR_RETURN(Shape binary_op_shape,
40                       ShapeInference::InferBinaryOpShape(opcode, lhs, rhs));
41   return computation->AddInstruction(
42       HloInstruction::CreateBinary(binary_op_shape, opcode, lhs, rhs));
43 }
44 
MakeCompareHlo(ComparisonDirection direction,HloInstruction * lhs,HloInstruction * rhs)45 StatusOr<HloInstruction*> MakeCompareHlo(ComparisonDirection direction,
46                                          HloInstruction* lhs,
47                                          HloInstruction* rhs) {
48   HloComputation* computation = lhs->parent();
49   CHECK_EQ(computation, rhs->parent());
50   TF_ASSIGN_OR_RETURN(
51       Shape binary_op_shape,
52       ShapeInference::InferBinaryOpShape(HloOpcode::kCompare, lhs, rhs));
53   return computation->AddInstruction(
54       HloInstruction::CreateCompare(binary_op_shape, lhs, rhs, direction));
55 }
56 
MakePadHlo(HloInstruction * operand,HloInstruction * padding_value,const PaddingConfig & padding_config)57 StatusOr<HloInstruction*> MakePadHlo(HloInstruction* operand,
58                                      HloInstruction* padding_value,
59                                      const PaddingConfig& padding_config) {
60   HloComputation* computation = operand->parent();
61   CHECK_EQ(computation, padding_value->parent());
62   TF_ASSIGN_OR_RETURN(
63       Shape pad_shape,
64       ShapeInference::InferPadShape(operand->shape(), padding_value->shape(),
65                                     padding_config));
66   return computation->AddInstruction(HloInstruction::CreatePad(
67       pad_shape, operand, padding_value, padding_config));
68 }
69 
MakeSliceHlo(HloInstruction * operand,absl::Span<const int64> start_indices,absl::Span<const int64> limit_indices,absl::Span<const int64> strides)70 StatusOr<HloInstruction*> MakeSliceHlo(HloInstruction* operand,
71                                        absl::Span<const int64> start_indices,
72                                        absl::Span<const int64> limit_indices,
73                                        absl::Span<const int64> strides) {
74   HloComputation* computation = operand->parent();
75   TF_ASSIGN_OR_RETURN(Shape slice_shape, ShapeInference::InferSliceShape(
76                                              operand->shape(), start_indices,
77                                              limit_indices, strides));
78   return computation->AddInstruction(HloInstruction::CreateSlice(
79       slice_shape, operand, start_indices, limit_indices, strides));
80 }
81 
MakeConvolveHlo(HloInstruction * lhs,HloInstruction * rhs,int64 feature_group_count,const Window & window,const ConvolutionDimensionNumbers & dimension_numbers,const PrecisionConfig & precision_config)82 StatusOr<HloInstruction*> MakeConvolveHlo(
83     HloInstruction* lhs, HloInstruction* rhs, int64 feature_group_count,
84     const Window& window, const ConvolutionDimensionNumbers& dimension_numbers,
85     const PrecisionConfig& precision_config) {
86   HloComputation* computation = lhs->parent();
87   CHECK_EQ(computation, rhs->parent());
88   TF_ASSIGN_OR_RETURN(Shape convolve_shape,
89                       ShapeInference::InferConvolveShape(
90                           lhs->shape(), rhs->shape(), feature_group_count, 1,
91                           window, dimension_numbers));
92   return computation->AddInstruction(HloInstruction::CreateConvolve(
93       convolve_shape, lhs, rhs, feature_group_count, 1, window,
94       dimension_numbers, precision_config));
95 }
96 
MakeTransposeHlo(HloInstruction * operand,absl::Span<const int64> dimensions)97 StatusOr<HloInstruction*> MakeTransposeHlo(HloInstruction* operand,
98                                            absl::Span<const int64> dimensions) {
99   HloComputation* computation = operand->parent();
100   TF_ASSIGN_OR_RETURN(
101       Shape transpose_shape,
102       ShapeInference::InferTransposeShape(operand->shape(), dimensions));
103   return computation->AddInstruction(
104       HloInstruction::CreateTranspose(transpose_shape, operand, dimensions));
105 }
106 
MakeReshapeHlo(const Shape & result_shape,HloInstruction * operand)107 StatusOr<HloInstruction*> MakeReshapeHlo(const Shape& result_shape,
108                                          HloInstruction* operand) {
109   HloComputation* computation = operand->parent();
110   return computation->AddInstruction(
111       HloInstruction::CreateReshape(result_shape, operand));
112 }
113 
MakeReshapeHlo(absl::Span<const int64> result_shape_dim_bounds,HloInstruction * operand)114 StatusOr<HloInstruction*> MakeReshapeHlo(
115     absl::Span<const int64> result_shape_dim_bounds, HloInstruction* operand) {
116   Shape new_shape = ShapeUtil::MakeShape(operand->shape().element_type(),
117                                          result_shape_dim_bounds);
118   return MakeReshapeHlo(new_shape, operand);
119 }
120 
MakeDynamicSliceHlo(HloInstruction * operand,HloInstruction * start_indices,absl::Span<const int64> slice_sizes)121 StatusOr<HloInstruction*> MakeDynamicSliceHlo(
122     HloInstruction* operand, HloInstruction* start_indices,
123     absl::Span<const int64> slice_sizes) {
124   HloComputation* computation = operand->parent();
125   CHECK_EQ(computation, start_indices->parent());
126   int64 rank = start_indices->shape().dimensions(0);
127   std::vector<HloInstruction*> scalar_start_indices;
128   for (int i = 0; i < rank; ++i) {
129     // TODO(b/118437727): Update callers to provide scalars directly.
130     auto slice = computation->AddInstruction(HloInstruction::CreateSlice(
131         ShapeUtil::MakeShape(start_indices->shape().element_type(), {1}),
132         start_indices, {i}, {i + 1}, {1}));
133     scalar_start_indices.push_back(
134         computation->AddInstruction(HloInstruction::CreateReshape(
135             ShapeUtil::MakeShape(start_indices->shape().element_type(), {}),
136             slice)));
137   }
138   std::vector<Shape> scalar_start_indices_shapes(
139       rank, ShapeUtil::MakeShape(start_indices->shape().element_type(), {}));
140   TF_ASSIGN_OR_RETURN(
141       Shape dynamic_slice_shape,
142       ShapeInference::InferDynamicSliceShape(
143           operand->shape(), scalar_start_indices_shapes, slice_sizes));
144   return computation->AddInstruction(HloInstruction::CreateDynamicSlice(
145       dynamic_slice_shape, operand, scalar_start_indices, slice_sizes));
146 }
147 
MakeDynamicUpdateSliceHlo(HloInstruction * operand,HloInstruction * update,HloInstruction * start_indices)148 StatusOr<HloInstruction*> MakeDynamicUpdateSliceHlo(
149     HloInstruction* operand, HloInstruction* update,
150     HloInstruction* start_indices) {
151   HloComputation* computation = operand->parent();
152   CHECK_EQ(computation, update->parent());
153   CHECK_EQ(computation, start_indices->parent());
154   int64 rank = start_indices->shape().dimensions(0);
155   std::vector<HloInstruction*> scalar_start_indices;
156   for (int i = 0; i < rank; ++i) {
157     // TODO(b/118437727): Update callers to provide scalars directly.
158     auto slice = computation->AddInstruction(HloInstruction::CreateSlice(
159         ShapeUtil::MakeShape(start_indices->shape().element_type(), {1}),
160         start_indices, {i}, {i + 1}, {1}));
161     scalar_start_indices.push_back(
162         computation->AddInstruction(HloInstruction::CreateReshape(
163             ShapeUtil::MakeShape(start_indices->shape().element_type(), {}),
164             slice)));
165   }
166   std::vector<Shape> scalar_start_indices_shapes(
167       rank, ShapeUtil::MakeShape(start_indices->shape().element_type(), {}));
168   TF_ASSIGN_OR_RETURN(
169       Shape dynamic_update_slice_shape,
170       ShapeInference::InferDynamicUpdateSliceShape(
171           operand->shape(), update->shape(), scalar_start_indices_shapes));
172   return computation->AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
173       dynamic_update_slice_shape, operand, update, scalar_start_indices));
174 }
175 
MakeBroadcastHlo(HloInstruction * operand,absl::Span<const int64> broadcast_dimensions,absl::Span<const int64> result_shape_bounds)176 HloInstruction* MakeBroadcastHlo(HloInstruction* operand,
177                                  absl::Span<const int64> broadcast_dimensions,
178                                  absl::Span<const int64> result_shape_bounds) {
179   HloComputation* computation = operand->parent();
180   Shape broadcast_shape = ShapeUtil::MakeShape(operand->shape().element_type(),
181                                                result_shape_bounds);
182 
183   return computation->AddInstruction(HloInstruction::CreateBroadcast(
184       broadcast_shape, operand, broadcast_dimensions));
185 }
186 
MakeGetTupleElementHlo(HloInstruction * operand,int64 index)187 StatusOr<HloInstruction*> MakeGetTupleElementHlo(HloInstruction* operand,
188                                                  int64 index) {
189   HloComputation* computation = operand->parent();
190 
191   TF_ASSIGN_OR_RETURN(
192       Shape gte_shape,
193       ShapeInference::InferGetTupleElementShape(operand->shape(), index));
194   return computation->AddInstruction(
195       HloInstruction::CreateGetTupleElement(gte_shape, operand, index));
196 }
197 
MakeConcatHlo(absl::Span<HloInstruction * const> operands,int64 dimension)198 StatusOr<HloInstruction*> MakeConcatHlo(
199     absl::Span<HloInstruction* const> operands, int64 dimension) {
200   CHECK_GT(operands.size(), 0);
201 
202   HloComputation* computation = operands[0]->parent();
203   CHECK(absl::c_all_of(operands, [&](HloInstruction* instr) {
204     return instr->parent() == computation;
205   }));
206 
207   std::vector<const Shape*> operand_shapes;
208   absl::c_transform(operands, std::back_inserter(operand_shapes),
209                     [](HloInstruction* instr) { return &instr->shape(); });
210 
211   TF_ASSIGN_OR_RETURN(Shape concat_shape, ShapeInference::InferConcatOpShape(
212                                               operand_shapes, dimension));
213   return computation->AddInstruction(
214       HloInstruction::CreateConcatenate(concat_shape, operands, dimension));
215 }
216 
MakeDotHlo(HloInstruction * lhs,HloInstruction * rhs,const DotDimensionNumbers & dim_numbers,const PrecisionConfig & precision_config)217 StatusOr<HloInstruction*> MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs,
218                                      const DotDimensionNumbers& dim_numbers,
219                                      const PrecisionConfig& precision_config) {
220   HloComputation* computation = lhs->parent();
221   CHECK_EQ(computation, rhs->parent());
222   TF_ASSIGN_OR_RETURN(
223       Shape dot_shape,
224       ShapeInference::InferDotOpShape(lhs->shape(), rhs->shape(), dim_numbers));
225   return computation->AddInstruction(HloInstruction::CreateDot(
226       dot_shape, lhs, rhs, dim_numbers, precision_config));
227 }
228 
MakeMapHlo(absl::Span<HloInstruction * const> operands,HloComputation * map_computation)229 StatusOr<HloInstruction*> MakeMapHlo(absl::Span<HloInstruction* const> operands,
230                                      HloComputation* map_computation) {
231   CHECK(!operands.empty()) << "Map Hlo requires at least one operand.";
232   HloComputation* computation = operands.front()->parent();
233   std::vector<const Shape*> operand_shapes;
234   int64 max_operand_rank = 0;
235   for (const HloInstruction* operand : operands) {
236     CHECK_EQ(computation, operand->parent());
237     operand_shapes.push_back(&operand->shape());
238     max_operand_rank = std::max(max_operand_rank, operand->shape().rank());
239   }
240   std::vector<int64> map_dims(max_operand_rank);
241   std::iota(map_dims.begin(), map_dims.end(), 0);
242   TF_ASSIGN_OR_RETURN(
243       Shape map_shape,
244       ShapeInference::InferMapShape(
245           operand_shapes, map_computation->ComputeProgramShape(), map_dims));
246   return computation->AddInstruction(
247       HloInstruction::CreateMap(map_shape, operands, map_computation));
248 }
249 
MakeReduceHlo(HloInstruction * operand,HloInstruction * init_value,HloOpcode binary_opcode,HloModule * module)250 StatusOr<HloInstruction*> MakeReduceHlo(HloInstruction* operand,
251                                         HloInstruction* init_value,
252                                         HloOpcode binary_opcode,
253                                         HloModule* module) {
254   DCHECK_NE(nullptr, module);
255   std::vector<int64> all_dims(operand->shape().rank());
256   std::iota(all_dims.begin(), all_dims.end(), 0);
257 
258   auto scalar_shape = ShapeUtil::MakeShape(operand->shape().element_type(), {});
259   HloComputation* reduce_computation;
260   {
261     HloComputation::Builder b(operand->name() + ".reduce_sub_computation");
262     auto lhs = b.AddInstruction(
263         HloInstruction::CreateParameter(0, scalar_shape, "lhs"));
264     auto rhs = b.AddInstruction(
265         HloInstruction::CreateParameter(1, scalar_shape, "rhs"));
266     b.AddInstruction(
267         HloInstruction::CreateBinary(scalar_shape, binary_opcode, lhs, rhs));
268     reduce_computation = module->AddEmbeddedComputation(b.Build());
269   }
270 
271   return operand->parent()->AddInstruction(HloInstruction::CreateReduce(
272       scalar_shape, operand, init_value, all_dims, reduce_computation));
273 }
274 
MakeSelectHlo(HloInstruction * pred,HloInstruction * on_true,HloInstruction * on_false)275 StatusOr<HloInstruction*> MakeSelectHlo(HloInstruction* pred,
276                                         HloInstruction* on_true,
277                                         HloInstruction* on_false) {
278   HloComputation* computation = pred->parent();
279   DCHECK_EQ(computation, on_true->parent());
280   DCHECK_EQ(computation, on_false->parent());
281   TF_ASSIGN_OR_RETURN(Shape select_shape,
282                       ShapeInference::InferTernaryOpShape(
283                           HloOpcode::kSelect, pred, on_true, on_false));
284   return computation->AddInstruction(HloInstruction::CreateTernary(
285       select_shape, HloOpcode::kSelect, pred, on_true, on_false));
286 }
287 
MakeSortHlo(const Shape & sort_shape,absl::Span<HloInstruction * const> operands,int64 dimension_to_sort,bool is_stable,HloComputation::Builder * builder,HloModule * module)288 StatusOr<HloInstruction*> MakeSortHlo(
289     const Shape& sort_shape, absl::Span<HloInstruction* const> operands,
290     int64 dimension_to_sort, bool is_stable, HloComputation::Builder* builder,
291     HloModule* module) {
292   CHECK(!operands.empty()) << "Sort Hlo requires at least one operand.";
293   HloComputation* compare_computation;
294   XlaBuilder b("Sort.Compare");
295   std::vector<PrimitiveType> operand_types(operands.size());
296   for (int64 i = 0; i < operands.size(); ++i) {
297     operand_types[i] = operands[i]->shape().element_type();
298   }
299   XlaComputation comparator = CreateScalarLtComputation(operand_types, &b);
300   TF_ASSIGN_OR_RETURN(ProgramShape program_shape, comparator.GetProgramShape());
301   HloModuleConfig config(program_shape);
302   TF_ASSIGN_OR_RETURN(auto new_module,
303                       HloModule::CreateFromProto(comparator.proto(), config));
304   HloCloneContext context(module);
305   compare_computation =
306       module->DeepCloneComputation(new_module->entry_computation(), &context);
307   return builder->AddInstruction(HloInstruction::CreateSort(
308       sort_shape, dimension_to_sort, operands, compare_computation, is_stable));
309 }
310 
CollapseFirstNDims(HloInstruction * operand,int64 n)311 StatusOr<HloInstruction*> CollapseFirstNDims(HloInstruction* operand, int64 n) {
312   CHECK_GT(n, 0);
313 
314   const Shape& operand_shape = operand->shape();
315   CHECK_GE(operand_shape.dimensions_size(), n);
316   int64 new_shape_leading_bound = 1;
317   for (int64 i = 0; i < n; i++) {
318     new_shape_leading_bound *= operand_shape.dimensions(i);
319   }
320 
321   std::vector<int64> new_shape_dims;
322   new_shape_dims.reserve(operand_shape.dimensions_size() - n + 1);
323   new_shape_dims.push_back(new_shape_leading_bound);
324 
325   std::copy(operand_shape.dimensions().begin() + n,
326             operand_shape.dimensions().end(),
327             std::back_inserter(new_shape_dims));
328 
329   Shape output_shape =
330       ShapeUtil::MakeShape(operand_shape.element_type(), new_shape_dims);
331 
332   return MakeReshapeHlo(output_shape, operand);
333 }
334 
PrependDegenerateDims(HloInstruction * operand,int64 n)335 StatusOr<HloInstruction*> PrependDegenerateDims(HloInstruction* operand,
336                                                 int64 n) {
337   CHECK_GT(n, 0);
338   std::vector<int64> new_shape_dims;
339   const Shape& operand_shape = operand->shape();
340   new_shape_dims.reserve(n + operand_shape.dimensions_size());
341   new_shape_dims.insert(new_shape_dims.begin(), n, 1);
342   absl::c_copy(operand_shape.dimensions(), std::back_inserter(new_shape_dims));
343   return MakeReshapeHlo(new_shape_dims, operand);
344 }
345 
ExpandFirstDimIntoNDims(HloInstruction * operand,absl::Span<const int64> expanded_dims)346 StatusOr<HloInstruction*> ExpandFirstDimIntoNDims(
347     HloInstruction* operand, absl::Span<const int64> expanded_dims) {
348   CHECK_GT(operand->shape().dimensions_size(), 0);
349   CHECK_EQ(operand->shape().dimensions(0), Product(expanded_dims));
350 
351   std::vector<int64> expanded_shape_dim_bounds;
352   expanded_shape_dim_bounds.reserve(expanded_dims.size() +
353                                     operand->shape().dimensions_size() - 1);
354   absl::c_copy(expanded_dims, std::back_inserter(expanded_shape_dim_bounds));
355   std::copy(operand->shape().dimensions().begin() + 1,
356             operand->shape().dimensions().end(),
357             std::back_inserter(expanded_shape_dim_bounds));
358   Shape new_shape = ShapeUtil::MakeShape(operand->shape().element_type(),
359                                          expanded_shape_dim_bounds);
360   return MakeReshapeHlo(new_shape, operand);
361 }
362 
ElideDegenerateDims(HloInstruction * operand,absl::Span<const int64> dims_to_elide)363 StatusOr<HloInstruction*> ElideDegenerateDims(
364     HloInstruction* operand, absl::Span<const int64> dims_to_elide) {
365   CHECK(absl::c_is_sorted(dims_to_elide));
366 
367   const Shape& input_shape = operand->shape();
368   // First accumulate in reverse
369   std::vector<int64> new_shape_dim_bounds;
370   new_shape_dim_bounds.reserve(input_shape.dimensions_size() -
371                                dims_to_elide.size());
372   int64 dims_to_elide_idx = dims_to_elide.size() - 1;
373   for (int64 i = input_shape.dimensions_size() - 1; i >= 0; i--) {
374     if (dims_to_elide_idx >= 0 && i == dims_to_elide[dims_to_elide_idx]) {
375       CHECK_EQ(input_shape.dimensions(i), 1);
376       dims_to_elide_idx--;
377     } else {
378       new_shape_dim_bounds.push_back(input_shape.dimensions(i));
379     }
380   }
381 
382   absl::c_reverse(new_shape_dim_bounds);
383   Shape output_shape =
384       ShapeUtil::MakeShape(input_shape.element_type(), new_shape_dim_bounds);
385   return MakeReshapeHlo(output_shape, operand);
386 }
387 
InsertDegenerateDims(HloInstruction * operand,absl::Span<const int64> dims_to_insert)388 StatusOr<HloInstruction*> InsertDegenerateDims(
389     HloInstruction* operand, absl::Span<const int64> dims_to_insert) {
390   CHECK(absl::c_is_sorted(dims_to_insert));
391 
392   const Shape& operand_shape = operand->shape();
393   int64 output_shape_rank =
394       operand_shape.dimensions_size() + dims_to_insert.size();
395   for (auto dim_to_insert : dims_to_insert) {
396     CHECK_LT(dim_to_insert, output_shape_rank);
397   }
398 
399   std::vector<int64> output_shape_dim_bounds;
400   output_shape_dim_bounds.reserve(output_shape_rank);
401   int64 operand_dims_idx = 0;
402   int64 dims_to_insert_idx = 0;
403   for (int64 i = 0; i < output_shape_rank; ++i) {
404     if (dims_to_insert_idx < dims_to_insert.size() &&
405         i == dims_to_insert[dims_to_insert_idx]) {
406       output_shape_dim_bounds.push_back(1);
407       ++dims_to_insert_idx;
408     } else {
409       output_shape_dim_bounds.push_back(
410           operand_shape.dimensions(operand_dims_idx));
411       ++operand_dims_idx;
412     }
413   }
414 
415   Shape output_shape = ShapeUtil::MakeShape(operand_shape.element_type(),
416                                             output_shape_dim_bounds);
417   return MakeReshapeHlo(output_shape, operand);
418 }
419 
PadVectorWithZeros(HloInstruction * operand,int64 zeros_to_prepend,int64 zeros_to_append)420 StatusOr<HloInstruction*> PadVectorWithZeros(HloInstruction* operand,
421                                              int64 zeros_to_prepend,
422                                              int64 zeros_to_append) {
423   HloComputation* computation = operand->parent();
424   CHECK_EQ(operand->shape().dimensions_size(), 1);
425   PaddingConfig padding_config;
426   PaddingConfig::PaddingConfigDimension padding_config_dim;
427   padding_config_dim.set_edge_padding_low(zeros_to_prepend);
428   padding_config_dim.set_edge_padding_high(zeros_to_append);
429   *padding_config.add_dimensions() = padding_config_dim;
430 
431   HloInstruction* zero =
432       computation->AddInstruction(HloInstruction::CreateConstant(
433           LiteralUtil::Zero(operand->shape().element_type())));
434   return MakePadHlo(operand, zero, padding_config);
435 }
436 
BroadcastZeros(HloComputation * computation,PrimitiveType element_type,absl::Span<const int64> broadcast_dimensions)437 HloInstruction* BroadcastZeros(HloComputation* computation,
438                                PrimitiveType element_type,
439                                absl::Span<const int64> broadcast_dimensions) {
440   HloInstruction* zero = computation->AddInstruction(
441       HloInstruction::CreateConstant(LiteralUtil::Zero(element_type)));
442   return MakeBroadcastHlo(zero, /*broadcast_dimensions=*/{},
443                           /*result_shape_bounds=*/broadcast_dimensions);
444 }
445 
CreateComputationWithSignature(absl::Span<const Shape * const> domain,const Shape & range,absl::string_view name)446 StatusOr<std::unique_ptr<HloComputation>> CreateComputationWithSignature(
447     absl::Span<const Shape* const> domain, const Shape& range,
448     absl::string_view name) {
449   HloComputation::Builder b{string(name)};
450   int64 param_idx = 0;
451   for (const Shape* param_shape : domain) {
452     b.AddInstruction(HloInstruction::CreateParameter(
453         param_idx, *param_shape, StrCat("param.", param_idx)));
454     param_idx++;
455   }
456 
457   // We can't change the root type of a computation once it is created so create
458   // a dummy root instruction to give the computation the right root shape.  In
459   // the future we may want to use a (recursive) broadcast here to avoid
460   // creating large constants.
461   b.AddInstruction(
462       HloInstruction::CreateConstant(Literal::CreateFromShape(range)));
463 
464   return b.Build();
465 }
466 
467 }  // namespace xla
468