1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_VECTOR_SUPPORT_LIBRARY_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_VECTOR_SUPPORT_LIBRARY_H_
18 
19 #include <string>
20 
21 #include "absl/types/span.h"
22 #include "llvm/IR/IRBuilder.h"
23 #include "llvm/IR/Value.h"
24 #include "tensorflow/compiler/xla/primitive_util.h"
25 #include "tensorflow/compiler/xla/types.h"
26 #include "tensorflow/compiler/xla/xla_data.pb.h"
27 
28 namespace xla {
29 namespace cpu {
30 
31 // Simple wrappers around llvm::APFloat::APFloat to make the calling code more
32 // obvious.
33 
GetIeeeF32(float f)34 inline llvm::APFloat GetIeeeF32(float f) { return llvm::APFloat(f); }
GetIeeeF32FromBitwiseRep(int32 bitwise_value)35 inline llvm::APFloat GetIeeeF32FromBitwiseRep(int32 bitwise_value) {
36   return llvm::APFloat(llvm::APFloat::IEEEsingle(),
37                        llvm::APInt(/*numBits=*/32, /*val=*/bitwise_value));
38 }
39 
40 // A thin wrapper around llvm_util.h to make code generating vector math flow
41 // more readable.
42 class VectorSupportLibrary {
43  public:
44   // This VectorSupportLibrary instance remembers `primitive_type` and
45   // `vector_size`, and these are implicitly used by the methods on this
46   // instance (i.e. LoadVector will load a vector of type <`vector_size` x
47   // `primitive_type`>).
48   VectorSupportLibrary(PrimitiveType primitive_type, int64 vector_size,
49                        llvm::IRBuilder<>* b, std::string name);
50 
51   llvm::Value* Mul(llvm::Value* lhs, llvm::Value* rhs);
Mul(int64 lhs,llvm::Value * rhs)52   llvm::Value* Mul(int64 lhs, llvm::Value* rhs) {
53     return Mul(b()->getInt64(lhs), rhs);
54   }
Mul(const llvm::APFloat & lhs,llvm::Value * rhs)55   llvm::Value* Mul(const llvm::APFloat& lhs, llvm::Value* rhs) {
56     return Mul(GetConstantFloat(rhs->getType(), lhs), rhs);
57   }
58 
59   // If your call resolved to these then you probably wanted the versions taking
60   // APFloat.
61   llvm::Value* Mul(double lhs, llvm::Value* rhs) = delete;
62   llvm::Value* Mul(float lhs, llvm::Value* rhs) = delete;
63 
64   llvm::Value* Add(llvm::Value* lhs, llvm::Value* rhs);
Add(int64 lhs,llvm::Value * rhs)65   llvm::Value* Add(int64 lhs, llvm::Value* rhs) {
66     return Add(b()->getInt64(lhs), rhs);
67   }
Add(const llvm::APFloat & lhs,llvm::Value * rhs)68   llvm::Value* Add(const llvm::APFloat& lhs, llvm::Value* rhs) {
69     return Add(GetConstantFloat(rhs->getType(), lhs), rhs);
70   }
71 
72   // If your call resolved to these then you probably wanted the versions taking
73   // APFloat.
74   llvm::Value* Add(double lhs, llvm::Value* rhs) = delete;
75   llvm::Value* Add(float lhs, llvm::Value* rhs) = delete;
76 
77   llvm::Value* Sub(llvm::Value* lhs, llvm::Value* rhs);
Sub(llvm::Value * lhs,const llvm::APFloat & rhs)78   llvm::Value* Sub(llvm::Value* lhs, const llvm::APFloat& rhs) {
79     return Sub(lhs, GetConstantFloat(lhs->getType(), rhs));
80   }
81   llvm::Value* Max(llvm::Value* lhs, llvm::Value* rhs,
82                    bool enable_fast_min_max);
Max(const llvm::APFloat & lhs,llvm::Value * rhs,bool enable_fast_min_max)83   llvm::Value* Max(const llvm::APFloat& lhs, llvm::Value* rhs,
84                    bool enable_fast_min_max) {
85     return Max(GetConstantFloat(rhs->getType(), lhs), rhs, enable_fast_min_max);
86   }
87   llvm::Value* Div(llvm::Value* lhs, llvm::Value* rhs);
88 
MulAdd(llvm::Value * a,llvm::Value * b,llvm::Value * c)89   llvm::Value* MulAdd(llvm::Value* a, llvm::Value* b, llvm::Value* c) {
90     return Add(c, Mul(a, b));
91   }
92 
MulAdd(llvm::Value * a,llvm::Value * b,const llvm::APFloat & c)93   llvm::Value* MulAdd(llvm::Value* a, llvm::Value* b, const llvm::APFloat& c) {
94     return Add(GetConstantFloat(vector_type(), c), Mul(a, b));
95   }
96 
MulAdd(llvm::Value * a,const llvm::APFloat & b,const llvm::APFloat & c)97   llvm::Value* MulAdd(llvm::Value* a, const llvm::APFloat& b,
98                       const llvm::APFloat& c) {
99     return Add(GetConstantFloat(a->getType(), c),
100                Mul(a, GetConstantFloat(a->getType(), b)));
101   }
102 
103   llvm::Value* Floor(llvm::Value* a);
104 
105   // Precondition: Neither `low` nor `high` is nan.
106   llvm::Value* Clamp(llvm::Value* a, const llvm::APFloat& low,
107                      const llvm::APFloat& high);
108 
SplatFloat(const llvm::APFloat & d)109   llvm::Value* SplatFloat(const llvm::APFloat& d) {
110     return GetConstantFloat(vector_type(), d);
111   }
112 
113   // These compare instructions return a floating point typed mask instead of an
114   // i1.  For instance, on a vector typed input, lanes where the predicate is
115   // true get a float with all ones and other lanes get a float with all zeros.
116   // This is slightly odd from the perspective of LLVM's type system, but it
117   // makes kernel IR generation code written using VectorSupportLibrary (its
118   // raison d'etre) less cluttered.
119 
120   llvm::Value* FCmpEQMask(llvm::Value* lhs, llvm::Value* rhs);
FCmpEQMask(llvm::Value * lhs,const llvm::APFloat & rhs)121   llvm::Value* FCmpEQMask(llvm::Value* lhs, const llvm::APFloat& rhs) {
122     return FCmpEQMask(lhs, GetConstantFloat(lhs->getType(), rhs));
123   }
124   llvm::Value* FCmpULEMask(llvm::Value* lhs, llvm::Value* rhs);
125   llvm::Value* FCmpOLTMask(llvm::Value* lhs, llvm::Value* rhs);
FCmpOLTMask(llvm::Value * lhs,const llvm::APFloat & rhs)126   llvm::Value* FCmpOLTMask(llvm::Value* lhs, const llvm::APFloat& rhs) {
127     return FCmpOLTMask(lhs, GetConstantFloat(lhs->getType(), rhs));
128   }
129 
130   // These boolean operations operate on the bitwise values of the floating
131   // point inputs.  They return a (vector of) float(s) but like in the mask
132   // generating predicates above this type system oddity makes the kernel IR
133   // generation code less cluttered.
134   llvm::Value* FloatAnd(llvm::Value* lhs, llvm::Value* rhs);
FloatAnd(llvm::Value * lhs,const llvm::APFloat & rhs)135   llvm::Value* FloatAnd(llvm::Value* lhs, const llvm::APFloat& rhs) {
136     return FloatAnd(lhs, GetConstantFloat(lhs->getType(), rhs));
137   }
138   llvm::Value* FloatOr(llvm::Value* lhs, llvm::Value* rhs);
FloatOr(llvm::Value * lhs,const llvm::APFloat & rhs)139   llvm::Value* FloatOr(llvm::Value* lhs, const llvm::APFloat& rhs) {
140     return FloatOr(lhs, GetConstantFloat(lhs->getType(), rhs));
141   }
142   llvm::Value* FloatNot(llvm::Value* lhs);
FloatAndNot(llvm::Value * lhs,llvm::Value * rhs)143   llvm::Value* FloatAndNot(llvm::Value* lhs, llvm::Value* rhs) {
144     return FloatAnd(FloatNot(lhs), rhs);
145   }
146 
147   llvm::Value* BroadcastScalar(llvm::Value* x);
BroadcastScalar(const llvm::APFloat & d)148   llvm::Value* BroadcastScalar(const llvm::APFloat& d) {
149     return BroadcastScalar(GetConstantFloat(scalar_type(), d));
150   }
151 
152   llvm::Value* ComputeOffsetPointer(llvm::Value* base_pointer,
153                                     llvm::Value* offset_elements);
ComputeOffsetPointer(llvm::Value * base_pointer,llvm::Value * offset_elements,int64 scale)154   llvm::Value* ComputeOffsetPointer(llvm::Value* base_pointer,
155                                     llvm::Value* offset_elements, int64 scale) {
156     return ComputeOffsetPointer(
157         base_pointer, b_->CreateMul(b_->getInt64(scale), offset_elements));
158   }
ComputeOffsetPointer(llvm::Value * base_pointer,int64 offset_elements)159   llvm::Value* ComputeOffsetPointer(llvm::Value* base_pointer,
160                                     int64 offset_elements) {
161     return ComputeOffsetPointer(base_pointer, b()->getInt64(offset_elements));
162   }
163 
164   llvm::Value* LoadVector(llvm::Value* pointer);
165 
LoadVector(llvm::Value * base_pointer,llvm::Value * offset_elements)166   llvm::Value* LoadVector(llvm::Value* base_pointer,
167                           llvm::Value* offset_elements) {
168     return LoadVector(ComputeOffsetPointer(base_pointer, offset_elements));
169   }
170 
LoadVector(llvm::Value * base_pointer,int64 offset_elements)171   llvm::Value* LoadVector(llvm::Value* base_pointer, int64 offset_elements) {
172     return LoadVector(base_pointer, b()->getInt64(offset_elements));
173   }
174 
175   llvm::Value* LoadScalar(llvm::Value* pointer);
176 
LoadScalar(llvm::Value * base_pointer,llvm::Value * offset_elements)177   llvm::Value* LoadScalar(llvm::Value* base_pointer,
178                           llvm::Value* offset_elements) {
179     return LoadScalar(ComputeOffsetPointer(base_pointer, offset_elements));
180   }
181 
LoadScalar(llvm::Value * base_pointer,int64 offset_elements)182   llvm::Value* LoadScalar(llvm::Value* base_pointer, int64 offset_elements) {
183     return LoadScalar(base_pointer, b()->getInt64(offset_elements));
184   }
185 
186   void StoreVector(llvm::Value* value, llvm::Value* pointer);
187 
StoreVector(llvm::Value * value,llvm::Value * base_pointer,llvm::Value * offset_elements)188   void StoreVector(llvm::Value* value, llvm::Value* base_pointer,
189                    llvm::Value* offset_elements) {
190     StoreVector(value, ComputeOffsetPointer(base_pointer, offset_elements));
191   }
192 
StoreVector(llvm::Value * value,llvm::Value * base_pointer,int64 offset_elements)193   void StoreVector(llvm::Value* value, llvm::Value* base_pointer,
194                    int64 offset_elements) {
195     StoreVector(value, base_pointer, b()->getInt64(offset_elements));
196   }
197 
198   void StoreScalar(llvm::Value* value, llvm::Value* pointer);
StoreScalar(llvm::Value * value,llvm::Value * base_pointer,llvm::Value * offset_elements)199   void StoreScalar(llvm::Value* value, llvm::Value* base_pointer,
200                    llvm::Value* offset_elements) {
201     StoreScalar(value, ComputeOffsetPointer(base_pointer, offset_elements));
202   }
203 
StoreScalar(llvm::Value * value,llvm::Value * base_pointer,int64 offset_elements)204   void StoreScalar(llvm::Value* value, llvm::Value* base_pointer,
205                    int64 offset_elements) {
206     StoreScalar(base_pointer, b()->getInt64(offset_elements));
207   }
208 
209   llvm::Value* LoadBroadcast(llvm::Value* pointer);
LoadBroadcast(llvm::Value * base_pointer,llvm::Value * offset_elements)210   llvm::Value* LoadBroadcast(llvm::Value* base_pointer,
211                              llvm::Value* offset_elements) {
212     return LoadBroadcast(ComputeOffsetPointer(base_pointer, offset_elements));
213   }
LoadBroadcast(llvm::Value * base_pointer,int64 offset_elements)214   llvm::Value* LoadBroadcast(llvm::Value* base_pointer, int64 offset_elements) {
215     return LoadBroadcast(base_pointer, b()->getInt64(offset_elements));
216   }
217 
218   // Compute the horizontal sum of each vector in `vectors`.  The i'th element
219   // in the result vector is the (scalar) horizontal sum of the i'th vector in
220   // `vectors`.  If `init_values` is not nullptr then the value in the i'th lane
221   // in `init_values` is added to the i'th horizontal sum.
222   std::vector<llvm::Value*> ComputeHorizontalSums(
223       std::vector<llvm::Value*> vectors, llvm::Value* init_values = nullptr);
224 
225   llvm::Value* GetZeroVector();
226   llvm::Value* GetZeroScalar();
227 
b()228   llvm::IRBuilder<>* b() const { return b_; }
vector_size()229   int64 vector_size() const { return vector_size_; }
vector_type()230   llvm::Type* vector_type() const { return vector_type_; }
vector_pointer_type()231   llvm::Type* vector_pointer_type() const { return vector_pointer_type_; }
scalar_type()232   llvm::Type* scalar_type() const { return scalar_type_; }
scalar_pointer_type()233   llvm::Type* scalar_pointer_type() const { return scalar_pointer_type_; }
scalar_byte_size()234   int64 scalar_byte_size() const {
235     return primitive_util::BitWidth(primitive_type_) / 8;
236   }
237 
name()238   const std::string& name() const { return name_; }
239 
240  private:
241   llvm::Value* ExtractLowHalf(llvm::Value*);
242   llvm::Value* ExtractHighHalf(llvm::Value*);
243 
244   llvm::Value* MulInternal(llvm::Value* lhs, llvm::Value* rhs);
245   llvm::Value* AddInternal(llvm::Value* lhs, llvm::Value* rhs);
246 
247   llvm::Value* AddReduce(llvm::Value* vector);
248 
249   // Checks that each value in `values` is either of type scalar_type() or
250   // vector_type().  This LOG(FATAL)'s so it should only be called in cases
251   // where a mismatching type is a programmer bug.
252   void AssertCorrectTypes(std::initializer_list<llvm::Value*> values);
253 
254   // Perform an X86 AVX style horizontal add between `lhs` and `rhs`.  The
255   // resulting IR for an 8-float wide vector is expected to lower to a single
256   // vhaddps instruction on a CPU that supports vhaddps, and not be too bad in
257   // other cases.
258   //
259   // For a vector width of 8, the result vector is computed as:
260   //   Result[0] = Lhs[0] + Lhs[1]
261   //   Result[1] = Lhs[2] + Lhs[3]
262   //   Result[2] = Rhs[0] + Rhs[1]
263   //   Result[3] = Rhs[2] + Rhs[3]
264   //   Result[4] = Lhs[4] + Lhs[5]
265   //   Result[5] = Lhs[6] + Lhs[7]
266   //   Result[6] = Rhs[4] + Rhs[5]
267   //   Result[7] = Rhs[6] + Rhs[7]
268   llvm::Value* AvxStyleHorizontalAdd(llvm::Value* lhs, llvm::Value* rhs);
269 
270   std::vector<llvm::Value*> ComputeAvxOptimizedHorizontalSums(
271       std::vector<llvm::Value*> vectors, llvm::Value* init_values);
272 
273   llvm::Type* IntegerTypeForFloatSize(bool vector);
274   llvm::Value* I1ToFloat(llvm::Value* i1);
GetConstantFloat(llvm::Type * type,const llvm::APFloat & f)275   llvm::Value* GetConstantFloat(llvm::Type* type, const llvm::APFloat& f) {
276     llvm::Constant* scalar_value = llvm::ConstantFP::get(type->getContext(), f);
277     if (llvm::isa<llvm::VectorType>(type)) {
278       return llvm::ConstantVector::getSplat(
279           llvm::ElementCount::getFixed(vector_size()), scalar_value);
280     }
281     return scalar_value;
282   }
283 
284   int64 vector_size_;
285   PrimitiveType primitive_type_;
286   llvm::IRBuilder<>* b_;
287   llvm::Type* vector_type_;
288   llvm::Type* vector_pointer_type_;
289   llvm::Type* scalar_type_;
290   llvm::Type* scalar_pointer_type_;
291   std::string name_;
292 };
293 
// This wraps an alloca-backed stack variable which LLVM's SSA construction pass
// can later convert to a SSA value.
class LlvmVariable {
 public:
  // Creates the variable; presumably emits an alloca of `type` via `b`
  // (definition lives in the .cc file — confirm there).
  LlvmVariable(llvm::Type*, llvm::IRBuilder<>* b);

  // Reads the variable's current value.
  llvm::Value* Get() const;
  // Overwrites the variable with `new_value`.
  void Set(llvm::Value* new_value);

 private:
  llvm::AllocaInst* alloca_;  // Stack slot backing this variable.
  llvm::IRBuilder<>* b_;      // Builder used by Get/Set.
};
307 
// An LlvmVariable holding one vector value of the VectorSupportLibrary's
// vector_type(), initialized to `initial_value`.
class VectorVariable : public LlvmVariable {
 public:
  VectorVariable(VectorSupportLibrary* vector_support,
                 llvm::Value* initial_value)
      : LlvmVariable(vector_support->vector_type(), vector_support->b()) {
    Set(initial_value);
  }
};
316 
// An LlvmVariable holding one scalar value of the VectorSupportLibrary's
// scalar_type(), initialized to `initial_value`.
class ScalarVariable : public LlvmVariable {
 public:
  ScalarVariable(VectorSupportLibrary* vector_support,
                 llvm::Value* initial_value)
      : LlvmVariable(vector_support->scalar_type(), vector_support->b()) {
    Set(initial_value);
  }
};
325 
// This wraps a set of alloca-backed stack variables that can, as a whole, store
// a tile.  A "tile" is a sequence of vectors that is typically used as a 2D
// grid of scalar values (e.g. for tiled GEMMs).
class TileVariable {
 public:
  // Creates one VectorVariable per element of `initial_value`.
  TileVariable(VectorSupportLibrary* vector_support,
               std::vector<llvm::Value*> initial_value);

  // Reads the current value of every vector in the tile, in order.
  std::vector<llvm::Value*> Get() const;
  // Overwrites the tile; `value` presumably must match the size given at
  // construction (enforced in the .cc file — confirm there).
  void Set(absl::Span<llvm::Value* const> value);

 private:
  std::vector<VectorVariable> storage_;  // One variable per vector in the tile.
};
340 }  // namespace cpu
341 }  // namespace xla
342 
343 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_VECTOR_SUPPORT_LIBRARY_H_
344