/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_VECTOR_SUPPORT_LIBRARY_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_VECTOR_SUPPORT_LIBRARY_H_

#include <string>

#include "absl/types/span.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {
namespace cpu {

// Simple wrappers around llvm::APFloat's constructors to make the calling
// code more obvious.

inline llvm::APFloat GetIeeeF32(float f) { return llvm::APFloat(f); }
inline llvm::APFloat GetIeeeF32FromBitwiseRep(int32 bitwise_value) {
  return llvm::APFloat(llvm::APFloat::IEEEsingle(),
                       llvm::APInt(/*numBits=*/32, /*val=*/bitwise_value));
}
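// For example (illustrative), GetIeeeF32FromBitwiseRep(0x3f800000) yields the
// same value as GetIeeeF32(1.0f), since 0x3f800000 is the IEEE-754 bit
// pattern of 1.0f.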

// A thin wrapper around llvm_util.h to make code generating vector math flow
// more readable.
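//
// Example (illustrative sketch; `builder`, `lhs_ptr`, `rhs_ptr` and
// `result_ptr` are hypothetical values from the surrounding codegen):
//
//   VectorSupportLibrary vsl(F32, /*vector_size=*/8, &builder, "multiply");
//   llvm::Value* lhs = vsl.LoadVector(lhs_ptr, /*offset_elements=*/0);
//   llvm::Value* rhs = vsl.LoadVector(rhs_ptr, /*offset_elements=*/0);
//   vsl.StoreVector(vsl.Mul(lhs, rhs), result_ptr, /*offset_elements=*/0);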
class VectorSupportLibrary {
 public:
  // This VectorSupportLibrary instance remembers `primitive_type` and
  // `vector_size`, and these are implicitly used by the methods on this
  // instance (i.e. LoadVector will load a vector of type <`vector_size` x
  // `primitive_type`>).
  VectorSupportLibrary(PrimitiveType primitive_type, int64 vector_size,
                       llvm::IRBuilder<>* b, std::string name);

  llvm::Value* Mul(llvm::Value* lhs, llvm::Value* rhs);
  llvm::Value* Mul(int64 lhs, llvm::Value* rhs) {
    return Mul(b()->getInt64(lhs), rhs);
  }
  llvm::Value* Mul(const llvm::APFloat& lhs, llvm::Value* rhs) {
    return Mul(GetConstantFloat(rhs->getType(), lhs), rhs);
  }

  // If your call resolves to one of these deleted overloads, you probably
  // wanted the versions taking an APFloat.
  llvm::Value* Mul(double lhs, llvm::Value* rhs) = delete;
  llvm::Value* Mul(float lhs, llvm::Value* rhs) = delete;

  llvm::Value* Add(llvm::Value* lhs, llvm::Value* rhs);
  llvm::Value* Add(int64 lhs, llvm::Value* rhs) {
    return Add(b()->getInt64(lhs), rhs);
  }
  llvm::Value* Add(const llvm::APFloat& lhs, llvm::Value* rhs) {
    return Add(GetConstantFloat(rhs->getType(), lhs), rhs);
  }

  // If your call resolves to one of these deleted overloads, you probably
  // wanted the versions taking an APFloat.
  llvm::Value* Add(double lhs, llvm::Value* rhs) = delete;
  llvm::Value* Add(float lhs, llvm::Value* rhs) = delete;

  llvm::Value* Sub(llvm::Value* lhs, llvm::Value* rhs);
  llvm::Value* Sub(llvm::Value* lhs, const llvm::APFloat& rhs) {
    return Sub(lhs, GetConstantFloat(lhs->getType(), rhs));
  }
  llvm::Value* Max(llvm::Value* lhs, llvm::Value* rhs,
                   bool enable_fast_min_max);
  llvm::Value* Max(const llvm::APFloat& lhs, llvm::Value* rhs,
                   bool enable_fast_min_max) {
    return Max(GetConstantFloat(rhs->getType(), lhs), rhs,
               enable_fast_min_max);
  }
  llvm::Value* Div(llvm::Value* lhs, llvm::Value* rhs);

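  // MulAdd(a, b, c) computes a * b + c. It emits separate multiply and add
  // instructions; any contraction into a fused multiply-add is left to later
  // LLVM passes.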
  llvm::Value* MulAdd(llvm::Value* a, llvm::Value* b, llvm::Value* c) {
    return Add(c, Mul(a, b));
  }

  llvm::Value* MulAdd(llvm::Value* a, llvm::Value* b, const llvm::APFloat& c) {
    return Add(GetConstantFloat(vector_type(), c), Mul(a, b));
  }

  llvm::Value* MulAdd(llvm::Value* a, const llvm::APFloat& b,
                      const llvm::APFloat& c) {
    return Add(GetConstantFloat(a->getType(), c),
               Mul(a, GetConstantFloat(a->getType(), b)));
  }

  llvm::Value* Floor(llvm::Value* a);

  // Precondition: neither `low` nor `high` is NaN.
  llvm::Value* Clamp(llvm::Value* a, const llvm::APFloat& low,
                     const llvm::APFloat& high);

  llvm::Value* SplatFloat(const llvm::APFloat& d) {
    return GetConstantFloat(vector_type(), d);
  }

  // These compare instructions return a floating-point-typed mask instead of
  // an i1. For instance, on a vector-typed input, lanes where the predicate
  // is true get a float with all ones and other lanes get a float with all
  // zeros. This is slightly odd from the perspective of LLVM's type system,
  // but it makes kernel IR generation code written using VectorSupportLibrary
  // (its raison d'etre) less cluttered.

  llvm::Value* FCmpEQMask(llvm::Value* lhs, llvm::Value* rhs);
  llvm::Value* FCmpEQMask(llvm::Value* lhs, const llvm::APFloat& rhs) {
    return FCmpEQMask(lhs, GetConstantFloat(lhs->getType(), rhs));
  }
  llvm::Value* FCmpULEMask(llvm::Value* lhs, llvm::Value* rhs);
  llvm::Value* FCmpOLTMask(llvm::Value* lhs, llvm::Value* rhs);
  llvm::Value* FCmpOLTMask(llvm::Value* lhs, const llvm::APFloat& rhs) {
    return FCmpOLTMask(lhs, GetConstantFloat(lhs->getType(), rhs));
  }

  // These boolean operations operate on the bitwise values of the floating
  // point inputs. They return a (vector of) float(s) but, like the mask
  // generating predicates above, this type system oddity makes the kernel IR
  // generation code less cluttered.
  llvm::Value* FloatAnd(llvm::Value* lhs, llvm::Value* rhs);
  llvm::Value* FloatAnd(llvm::Value* lhs, const llvm::APFloat& rhs) {
    return FloatAnd(lhs, GetConstantFloat(lhs->getType(), rhs));
  }
  llvm::Value* FloatOr(llvm::Value* lhs, llvm::Value* rhs);
  llvm::Value* FloatOr(llvm::Value* lhs, const llvm::APFloat& rhs) {
    return FloatOr(lhs, GetConstantFloat(lhs->getType(), rhs));
  }
  llvm::Value* FloatNot(llvm::Value* lhs);
  llvm::Value* FloatAndNot(llvm::Value* lhs, llvm::Value* rhs) {
    return FloatAnd(FloatNot(lhs), rhs);
  }
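
  // Example (illustrative): the masks and bitwise ops above compose into a
  // branch-free select; with hypothetical llvm::Value*s `a` and `b`:
  //
  //   llvm::Value* mask = vsl.FCmpOLTMask(a, b);
  //   llvm::Value* min =
  //       vsl.FloatOr(vsl.FloatAnd(mask, a),      // a in lanes where a < b
  //                   vsl.FloatAndNot(mask, b));  // b in the other lanes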

  llvm::Value* BroadcastScalar(llvm::Value* x);
  llvm::Value* BroadcastScalar(const llvm::APFloat& d) {
    return BroadcastScalar(GetConstantFloat(scalar_type(), d));
  }

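  // Returns a pointer offset from `base_pointer` by `offset_elements`
  // elements of the scalar type (the second overload additionally scales the
  // offset by `scale`).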
  llvm::Value* ComputeOffsetPointer(llvm::Value* base_pointer,
                                    llvm::Value* offset_elements);
  llvm::Value* ComputeOffsetPointer(llvm::Value* base_pointer,
                                    llvm::Value* offset_elements,
                                    int64 scale) {
    return ComputeOffsetPointer(
        base_pointer, b_->CreateMul(b_->getInt64(scale), offset_elements));
  }
  llvm::Value* ComputeOffsetPointer(llvm::Value* base_pointer,
                                    int64 offset_elements) {
    return ComputeOffsetPointer(base_pointer, b()->getInt64(offset_elements));
  }

  llvm::Value* LoadVector(llvm::Value* pointer);

  llvm::Value* LoadVector(llvm::Value* base_pointer,
                          llvm::Value* offset_elements) {
    return LoadVector(ComputeOffsetPointer(base_pointer, offset_elements));
  }

  llvm::Value* LoadVector(llvm::Value* base_pointer, int64 offset_elements) {
    return LoadVector(base_pointer, b()->getInt64(offset_elements));
  }

  llvm::Value* LoadScalar(llvm::Value* pointer);

  llvm::Value* LoadScalar(llvm::Value* base_pointer,
                          llvm::Value* offset_elements) {
    return LoadScalar(ComputeOffsetPointer(base_pointer, offset_elements));
  }

  llvm::Value* LoadScalar(llvm::Value* base_pointer, int64 offset_elements) {
    return LoadScalar(base_pointer, b()->getInt64(offset_elements));
  }

  void StoreVector(llvm::Value* value, llvm::Value* pointer);

  void StoreVector(llvm::Value* value, llvm::Value* base_pointer,
                   llvm::Value* offset_elements) {
    StoreVector(value, ComputeOffsetPointer(base_pointer, offset_elements));
  }

  void StoreVector(llvm::Value* value, llvm::Value* base_pointer,
                   int64 offset_elements) {
    StoreVector(value, base_pointer, b()->getInt64(offset_elements));
  }

  void StoreScalar(llvm::Value* value, llvm::Value* pointer);
  void StoreScalar(llvm::Value* value, llvm::Value* base_pointer,
                   llvm::Value* offset_elements) {
    StoreScalar(value, ComputeOffsetPointer(base_pointer, offset_elements));
  }

  void StoreScalar(llvm::Value* value, llvm::Value* base_pointer,
                   int64 offset_elements) {
    StoreScalar(value, base_pointer, b()->getInt64(offset_elements));
  }

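  // Loads a scalar from `pointer` and broadcasts it into every lane of a
  // vector of type vector_type().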
  llvm::Value* LoadBroadcast(llvm::Value* pointer);
  llvm::Value* LoadBroadcast(llvm::Value* base_pointer,
                             llvm::Value* offset_elements) {
    return LoadBroadcast(ComputeOffsetPointer(base_pointer, offset_elements));
  }
  llvm::Value* LoadBroadcast(llvm::Value* base_pointer,
                             int64 offset_elements) {
    return LoadBroadcast(base_pointer, b()->getInt64(offset_elements));
  }

  // Compute the horizontal sum of each vector in `vectors`. The i'th element
  // in the result vector is the (scalar) horizontal sum of the i'th vector in
  // `vectors`. If `init_values` is not nullptr then the value in the i'th
  // lane in `init_values` is added to the i'th horizontal sum.
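  //
  // For example (illustrative), with vector_size() == 4 and a null
  // `init_values`, vectors = {<1, 2, 3, 4>, <5, 6, 7, 8>} produces the two
  // scalar values {10, 26}.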
  std::vector<llvm::Value*> ComputeHorizontalSums(
      std::vector<llvm::Value*> vectors, llvm::Value* init_values = nullptr);

  llvm::Value* GetZeroVector();
  llvm::Value* GetZeroScalar();

  llvm::IRBuilder<>* b() const { return b_; }
  int64 vector_size() const { return vector_size_; }
  llvm::Type* vector_type() const { return vector_type_; }
  llvm::Type* vector_pointer_type() const { return vector_pointer_type_; }
  llvm::Type* scalar_type() const { return scalar_type_; }
  llvm::Type* scalar_pointer_type() const { return scalar_pointer_type_; }
  int64 scalar_byte_size() const {
    return primitive_util::BitWidth(primitive_type_) / 8;
  }

  const std::string& name() const { return name_; }

 private:
  llvm::Value* ExtractLowHalf(llvm::Value*);
  llvm::Value* ExtractHighHalf(llvm::Value*);

  llvm::Value* MulInternal(llvm::Value* lhs, llvm::Value* rhs);
  llvm::Value* AddInternal(llvm::Value* lhs, llvm::Value* rhs);

  llvm::Value* AddReduce(llvm::Value* vector);

  // Checks that each value in `values` is either of type scalar_type() or
  // vector_type(). This LOG(FATAL)'s on a mismatch, so it should only be
  // called in cases where a mismatching type indicates a programmer bug.
  void AssertCorrectTypes(std::initializer_list<llvm::Value*> values);

  // Perform an X86 AVX style horizontal add between `lhs` and `rhs`. The
  // resulting IR for an 8-float wide vector is expected to lower to a single
  // vhaddps instruction on a CPU that supports vhaddps, and not be too bad in
  // other cases.
  //
  // For a vector width of 8, the result vector is computed as:
  //   Result[0] = Lhs[0] + Lhs[1]
  //   Result[1] = Lhs[2] + Lhs[3]
  //   Result[2] = Rhs[0] + Rhs[1]
  //   Result[3] = Rhs[2] + Rhs[3]
  //   Result[4] = Lhs[4] + Lhs[5]
  //   Result[5] = Lhs[6] + Lhs[7]
  //   Result[6] = Rhs[4] + Rhs[5]
  //   Result[7] = Rhs[6] + Rhs[7]
  llvm::Value* AvxStyleHorizontalAdd(llvm::Value* lhs, llvm::Value* rhs);

  std::vector<llvm::Value*> ComputeAvxOptimizedHorizontalSums(
      std::vector<llvm::Value*> vectors, llvm::Value* init_values);

  llvm::Type* IntegerTypeForFloatSize(bool vector);
  llvm::Value* I1ToFloat(llvm::Value* i1);
  llvm::Value* GetConstantFloat(llvm::Type* type, const llvm::APFloat& f) {
    llvm::Constant* scalar_value =
        llvm::ConstantFP::get(type->getContext(), f);
    if (llvm::isa<llvm::VectorType>(type)) {
      return llvm::ConstantVector::getSplat(
          llvm::ElementCount::getFixed(vector_size()), scalar_value);
    }
    return scalar_value;
  }

  int64 vector_size_;
  PrimitiveType primitive_type_;
  llvm::IRBuilder<>* b_;
  llvm::Type* vector_type_;
  llvm::Type* vector_pointer_type_;
  llvm::Type* scalar_type_;
  llvm::Type* scalar_pointer_type_;
  std::string name_;
};

// This wraps an alloca-backed stack variable which LLVM's SSA construction
// pass can later convert to an SSA value.
class LlvmVariable {
 public:
  LlvmVariable(llvm::Type*, llvm::IRBuilder<>* b);

  llvm::Value* Get() const;
  void Set(llvm::Value* new_value);

 private:
  llvm::AllocaInst* alloca_;
  llvm::IRBuilder<>* b_;
};

class VectorVariable : public LlvmVariable {
 public:
  VectorVariable(VectorSupportLibrary* vector_support,
                 llvm::Value* initial_value)
      : LlvmVariable(vector_support->vector_type(), vector_support->b()) {
    Set(initial_value);
  }
};
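
// Example (illustrative): a VectorVariable works as a loop-carried
// accumulator that mem2reg can later promote to an SSA phi. Here `vsl` and
// `input_ptr` are hypothetical, and `i` is the loop's induction variable (an
// llvm::Value*):
//
//   VectorVariable acc(&vsl, vsl.GetZeroVector());
//   // ... inside the loop body:
//   acc.Set(vsl.Add(acc.Get(), vsl.LoadVector(input_ptr, i)));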

class ScalarVariable : public LlvmVariable {
 public:
  ScalarVariable(VectorSupportLibrary* vector_support,
                 llvm::Value* initial_value)
      : LlvmVariable(vector_support->scalar_type(), vector_support->b()) {
    Set(initial_value);
  }
};

// This wraps a set of alloca-backed stack variables that can, as a whole,
// store a tile. A "tile" is a sequence of vectors that is typically used as a
// 2D grid of scalar values (e.g. for tiled GEMMs).
class TileVariable {
 public:
  TileVariable(VectorSupportLibrary* vector_support,
               std::vector<llvm::Value*> initial_value);

  std::vector<llvm::Value*> Get() const;
  void Set(absl::Span<llvm::Value* const> value);

 private:
  std::vector<VectorVariable> storage_;
};
}  // namespace cpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_VECTOR_SUPPORT_LIBRARY_H_