/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CREATION_UTILS_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CREATION_UTILS_H_

#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/statusor.h"

namespace xla {

// Some lightweight utilities intended to make HLO instruction creation more
// ergonomic. We don't have a complete set of helpers yet; we expect to expand
// this interface as needed.

// Creates a binary HLO instruction and adds it to the computation containing
// `lhs` and `rhs` (`lhs` and `rhs` must be in the same computation).
StatusOr<HloInstruction*> MakeBinaryHlo(HloOpcode opcode, HloInstruction* lhs,
                                        HloInstruction* rhs);
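
// Example usage (a sketch; assumes `a` and `b` are `HloInstruction*`s of
// compatible shape in the same computation):
//
//   TF_ASSIGN_OR_RETURN(HloInstruction* sum,
//                       MakeBinaryHlo(HloOpcode::kAdd, a, b));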

// Creates a compare HLO instruction and adds it to the computation containing
// `lhs` and `rhs` (`lhs` and `rhs` must be in the same computation).
StatusOr<HloInstruction*> MakeCompareHlo(ComparisonDirection direction,
                                         HloInstruction* lhs,
                                         HloInstruction* rhs);
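
// Example usage (a sketch; assumes `a` and `b` are compatible
// `HloInstruction*`s in the same computation; the result has element type
// PRED):
//
//   TF_ASSIGN_OR_RETURN(HloInstruction* lt,
//                       MakeCompareHlo(ComparisonDirection::kLt, a, b));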

// Creates a pad HLO instruction and adds it to the computation containing
// `operand` and `padding_value` (`operand` and `padding_value` must be in the
// same computation).
StatusOr<HloInstruction*> MakePadHlo(HloInstruction* operand,
                                     HloInstruction* padding_value,
                                     const PaddingConfig& padding_config);
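
// Example usage (a sketch; assumes `operand` is a rank-1 `HloInstruction*` and
// `zero` is a scalar of the same element type in the same computation):
//
//   PaddingConfig config;
//   auto* dim = config.add_dimensions();
//   dim->set_edge_padding_low(1);   // one element of padding in front
//   dim->set_edge_padding_high(2);  // two elements of padding in back
//   dim->set_interior_padding(0);   // no padding between elements
//   TF_ASSIGN_OR_RETURN(HloInstruction* padded,
//                       MakePadHlo(operand, zero, config));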

// Creates a slice HLO instruction and adds it to the computation containing
// `operand`.
StatusOr<HloInstruction*> MakeSliceHlo(HloInstruction* operand,
                                       absl::Span<const int64> start_indices,
                                       absl::Span<const int64> limit_indices,
                                       absl::Span<const int64> strides);
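
// Example usage (a sketch; assumes `operand` has shape f32[10,10]; the slice
// below has shape f32[3,4]):
//
//   TF_ASSIGN_OR_RETURN(
//       HloInstruction* slice,
//       MakeSliceHlo(operand, /*start_indices=*/{0, 2},
//                    /*limit_indices=*/{3, 6}, /*strides=*/{1, 1}));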

// Creates a convolution HLO instruction and adds it to the computation
// containing `lhs` and `rhs` (`lhs` and `rhs` must be in the same computation).
StatusOr<HloInstruction*> MakeConvolveHlo(
    HloInstruction* lhs, HloInstruction* rhs, int64 feature_group_count,
    const Window& window, const ConvolutionDimensionNumbers& dimension_numbers,
    const PrecisionConfig& precision_config);

// Creates a transpose HLO instruction and adds it to the computation
// containing `operand`.
StatusOr<HloInstruction*> MakeTransposeHlo(HloInstruction* operand,
                                           absl::Span<const int64> dimensions);
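
// Example usage (a sketch; assumes `operand` has shape f32[3,4]; the result
// has shape f32[4,3]):
//
//   TF_ASSIGN_OR_RETURN(HloInstruction* transposed,
//                       MakeTransposeHlo(operand, /*dimensions=*/{1, 0}));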

// Creates a reshape HLO instruction and adds it to the computation containing
// `operand`.
StatusOr<HloInstruction*> MakeReshapeHlo(const Shape& result_shape,
                                         HloInstruction* operand);

StatusOr<HloInstruction*> MakeReshapeHlo(
    absl::Span<const int64> result_shape_dim_bounds, HloInstruction* operand);
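
// Example usage (a sketch; assumes `operand` has shape f32[4,6], which has the
// same element count as the requested f32[2,12]):
//
//   TF_ASSIGN_OR_RETURN(
//       HloInstruction* reshaped,
//       MakeReshapeHlo(/*result_shape_dim_bounds=*/{2, 12}, operand));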

// Creates a dynamic-slice HLO instruction and adds it to the computation
// containing `operand` and `start_indices` (`operand` and `start_indices` must
// be in the same computation).
StatusOr<HloInstruction*> MakeDynamicSliceHlo(
    HloInstruction* operand, HloInstruction* start_indices,
    absl::Span<const int64> slice_sizes);

// Creates a dynamic-update-slice HLO instruction and adds it to the
// computation containing `operand`, `update` and `start_indices` (`operand`,
// `update` and `start_indices` must be in the same computation).
StatusOr<HloInstruction*> MakeDynamicUpdateSliceHlo(
    HloInstruction* operand, HloInstruction* update,
    HloInstruction* start_indices);

// Creates a broadcast HLO instruction and adds it to the computation
// containing `operand`.
HloInstruction* MakeBroadcastHlo(HloInstruction* operand,
                                 absl::Span<const int64> broadcast_dimensions,
                                 absl::Span<const int64> result_shape_bounds);
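
// Example usage (a sketch; assumes `operand` has shape f32[3]; mapping operand
// dimension 0 to result dimension 1 yields shape f32[2,3]):
//
//   HloInstruction* broadcast =
//       MakeBroadcastHlo(operand, /*broadcast_dimensions=*/{1},
//                        /*result_shape_bounds=*/{2, 3});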

// Creates a GetTupleElement HLO instruction and adds it to the computation
// containing `operand`.
StatusOr<HloInstruction*> MakeGetTupleElementHlo(HloInstruction* operand,
                                                 int64 index);

// Creates a Concatenate HLO instruction and adds it to the computation
// containing `operands` (`operands` must be non-empty and every element must
// be contained in the same computation).
StatusOr<HloInstruction*> MakeConcatHlo(
    absl::Span<HloInstruction* const> operands, int64 dimension);
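
// Example usage (a sketch; assumes `a` has shape f32[2,5] and `b` has shape
// f32[3,5]; concatenating along dimension 0 yields shape f32[5,5]):
//
//   TF_ASSIGN_OR_RETURN(HloInstruction* concat,
//                       MakeConcatHlo({a, b}, /*dimension=*/0));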

// Creates a Dot HLO instruction and adds it to the computation containing
// `lhs` and `rhs` (both must be in the same computation).
StatusOr<HloInstruction*> MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs,
                                     const DotDimensionNumbers& dim_numbers,
                                     const PrecisionConfig& precision_config);
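
// Example usage (a sketch; assumes `lhs` has shape f32[2,3] and `rhs` has
// shape f32[3,4]; the matrix product below has shape f32[2,4]):
//
//   DotDimensionNumbers dnums;
//   dnums.add_lhs_contracting_dimensions(1);
//   dnums.add_rhs_contracting_dimensions(0);
//   TF_ASSIGN_OR_RETURN(HloInstruction* dot,
//                       MakeDotHlo(lhs, rhs, dnums, PrecisionConfig()));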

// Creates a Map HLO instruction and adds it to the computation containing the
// operands. All operands must be in the same computation.
StatusOr<HloInstruction*> MakeMapHlo(absl::Span<HloInstruction* const> operands,
                                     HloComputation* map_computation);

// Creates a Reduce HLO instruction and adds it to the computation containing
// the operand. This will create the sub-computation needed for the reduction
// in the given module. `binary_opcode` should represent a binary operation.
StatusOr<HloInstruction*> MakeReduceHlo(HloInstruction* operand,
                                        HloInstruction* init_value,
                                        HloOpcode binary_opcode,
                                        HloModule* module);
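
// Example usage (a sketch; assumes `operand` lives in a computation owned by
// `module` and `zero` is a scalar of the same element type; this reduces
// `operand` with addition):
//
//   TF_ASSIGN_OR_RETURN(
//       HloInstruction* sum,
//       MakeReduceHlo(operand, zero, HloOpcode::kAdd, module));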

// Creates a Select HLO instruction and adds it to the computation containing
// the predicate. The `on_true` and `on_false` instructions must also be
// contained in the same computation.
StatusOr<HloInstruction*> MakeSelectHlo(HloInstruction* pred,
                                        HloInstruction* on_true,
                                        HloInstruction* on_false);

// Creates a Sort HLO instruction and adds it to the computation containing the
// operands. All operands must be in the same computation. Also creates a
// default compare sub-computation which sorts the first operand into ascending
// order. `is_stable` specifies whether the sorting should be stable.
StatusOr<HloInstruction*> MakeSortHlo(
    const Shape& sort_shape, absl::Span<HloInstruction* const> operands,
    int64 dimension_to_sort, bool is_stable, HloComputation::Builder* builder,
    HloModule* module);

// Creates an R1 Constant HLO instruction of the given PrimitiveType with the
// given values and adds it to the given computation.
template <typename NativeT>
StatusOr<HloInstruction*> MakeR1ConstantHlo(HloComputation* computation,
                                            PrimitiveType type,
                                            absl::Span<const NativeT> values) {
  Literal literal = LiteralUtil::CreateR1<NativeT>(values);
  if (literal.shape().element_type() != type) {
    TF_ASSIGN_OR_RETURN(literal, literal.Convert(type));
  }
  return computation->AddInstruction(
      HloInstruction::CreateConstant(std::move(literal)));
}
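
// Example usage (a sketch; creates the constant s64[3] {1, 2, 3} in
// `computation` by building an int32 literal and converting it to S64):
//
//   std::vector<int32> values = {1, 2, 3};
//   TF_ASSIGN_OR_RETURN(
//       HloInstruction* constant,
//       MakeR1ConstantHlo<int32>(computation, S64, values));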

// -----------------------------------------------------------------------------
// Some other miscellaneous helpers to generate common HLO patterns. All of
// these add all the instructions they generate into the computation containing
// their operand(s).

// Collapses (via reshape) the first N (logical) dimensions of `operand` into a
// single leading dimension. `operand` must have rank > `n` and `n` must not be
// 0.
//
// For instance, if `operand` has shape f32[7,8,9] and `n` is 2 then the output
// is `operand` reshaped to [56,9].
StatusOr<HloInstruction*> CollapseFirstNDims(HloInstruction* operand, int64 n);

// Prepends `n` degenerate dimensions (dimensions with bound = 1) to `operand`
// using a reshape.
//
// For instance, if `operand` has shape f32[3,4,5] then this returns `operand`
// reshaped to f32[1,3,4,5]. If `operand` is an f32 scalar (i.e. has shape
// f32[]) then this returns `operand` reshaped to f32[1].
StatusOr<HloInstruction*> PrependDegenerateDims(HloInstruction* operand,
                                                int64 n);

// Expands (via reshape) the first (logical) dimension of `operand` into a
// sequence of `expanded_dims` dimensions. `operand` must have rank at least 1,
// and the number of elements in its first dimension must be equal to the
// product of `expanded_dims`.
//
// For instance, if `operand` has shape f32[200,9,7] and `expanded_dims` is
// {2,5,20} then the result is `operand` reshaped to [2,5,20,9,7].
StatusOr<HloInstruction*> ExpandFirstDimIntoNDims(
    HloInstruction* operand, absl::Span<const int64> expanded_dims);

// Elides (via reshape) a set of degenerate dimensions (dimensions containing
// exactly one element), `dims_to_elide`, from `operand`. Every dimension in
// `dims_to_elide` must be a degenerate dimension. `dims_to_elide` must be
// sorted and must not contain duplicates.
//
// For example, if `operand` has shape f32[19,1,20,1,7,1,9] and `dims_to_elide`
// is {1,5} then the result is `operand` reshaped to [19,20,1,7,9].
StatusOr<HloInstruction*> ElideDegenerateDims(
    HloInstruction* operand, absl::Span<const int64> dims_to_elide);

// Inserts (via reshape) a set of degenerate dimensions (dimensions containing
// exactly one element), `dims_to_insert`, into `operand`. The dimensions in
// `dims_to_insert` refer to the dimensions in the result, and hence must be
// less than the rank of the result. Also, `dims_to_insert` must be sorted.
//
// For example, if `operand` has shape f32[12,21,8,34] and `dims_to_insert` is
// {0, 2}, then the result is `operand` reshaped to [1,12,1,21,8,34].
StatusOr<HloInstruction*> InsertDegenerateDims(
    HloInstruction* operand, absl::Span<const int64> dims_to_insert);

// Pads `operand` (which must have rank 1) with `zeros_to_prepend` zeros in the
// front and `zeros_to_append` zeros in the back.
StatusOr<HloInstruction*> PadVectorWithZeros(HloInstruction* operand,
                                             int64 zeros_to_prepend,
                                             int64 zeros_to_append);

// Broadcasts a zero value of type `element_type` into a tensor with element
// type `element_type` and dimension bounds `broadcast_dimensions`. The
// broadcast instruction is emitted into `computation`.
HloInstruction* BroadcastZeros(HloComputation* computation,
                               PrimitiveType element_type,
                               absl::Span<const int64> broadcast_dimensions);
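
// Example usage (a sketch; creates an f32[2,3] tensor of zeros in
// `computation`):
//
//   HloInstruction* zeros =
//       BroadcastZeros(computation, F32, /*broadcast_dimensions=*/{2, 3});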

// Creates an HLO computation that takes arguments of type `domain` and
// produces a value of type `range`.
StatusOr<std::unique_ptr<HloComputation>> CreateComputationWithSignature(
    absl::Span<const Shape* const> domain, const Shape& range,
    absl::string_view name);

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CREATION_UTILS_H_