/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/cpu/vector_support_library.h"

#include "absl/algorithm/container.h"
#include "llvm/Support/raw_ostream.h"
#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"

namespace xla {
namespace cpu {
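
// A minimal usage sketch (illustrative only; `b` is an IRBuilder positioned
// inside a function, and `x_ptr`, `y_ptr`, `out_ptr` are hypothetical
// pointers to F32 buffers):
//
//   VectorSupportLibrary vsl(F32, /*vector_size=*/8, b, "dot");
//   llvm::Value* x = vsl.LoadVector(x_ptr);
//   llvm::Value* y = vsl.LoadVector(y_ptr);
//   vsl.StoreVector(vsl.Mul(x, y), out_ptr);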
VectorSupportLibrary::VectorSupportLibrary(PrimitiveType primitive_type,
                                           int64 vector_size,
                                           llvm::IRBuilder<>* b,
                                           std::string name)
    : vector_size_(vector_size),
      primitive_type_(primitive_type),
      b_(b),
      name_(std::move(name)) {
  scalar_type_ = llvm_ir::PrimitiveTypeToIrType(
      primitive_type, b_->GetInsertBlock()->getModule());
  scalar_pointer_type_ = llvm::PointerType::getUnqual(scalar_type_);
  vector_type_ = llvm::VectorType::get(scalar_type_, vector_size, false);
  vector_pointer_type_ = llvm::PointerType::getUnqual(vector_type_);
}

static std::string TypeToString(llvm::Type* type) {
  std::string o;
  llvm::raw_string_ostream ostream(o);
  type->print(ostream);
  return ostream.str();
}

void VectorSupportLibrary::AssertCorrectTypes(
    std::initializer_list<llvm::Value*> values) {
  for (llvm::Value* v : values) {
    llvm::Type* type = v->getType();
    if (type != scalar_type() && type != vector_type()) {
      LOG(FATAL) << "Expected either " << TypeToString(scalar_type()) << " or "
                 << TypeToString(vector_type()) << " but got "
                 << TypeToString(type);
    }
  }
}

llvm::Value* VectorSupportLibrary::Mul(llvm::Value* lhs, llvm::Value* rhs) {
  AssertCorrectTypes({lhs, rhs});
  return MulInternal(lhs, rhs);
}

llvm::Value* VectorSupportLibrary::MulInternal(llvm::Value* lhs,
                                               llvm::Value* rhs) {
  if (scalar_type_->isFloatingPointTy()) {
    return b()->CreateFMul(lhs, rhs, name());
  } else {
    return b()->CreateMul(lhs, rhs, name());
  }
}

llvm::Value* VectorSupportLibrary::Add(llvm::Value* lhs, llvm::Value* rhs) {
  AssertCorrectTypes({lhs, rhs});
  return AddInternal(lhs, rhs);
}

llvm::Value* VectorSupportLibrary::Sub(llvm::Value* lhs, llvm::Value* rhs) {
  AssertCorrectTypes({lhs, rhs});
  // Note: unlike Mul/Add, only floating-point subtraction is implemented;
  // CreateFSub is invalid for integer operands.
  return b()->CreateFSub(lhs, rhs);
}

llvm::Value* VectorSupportLibrary::Max(llvm::Value* lhs, llvm::Value* rhs,
                                       bool enable_fast_min_max) {
  AssertCorrectTypes({lhs, rhs});
  if (scalar_type_->isFloatingPointTy()) {
    return llvm_ir::EmitFloatMax(lhs, rhs, b_, enable_fast_min_max);
  } else {
    LOG(FATAL) << "Max for integers is unimplemented";
  }
}

llvm::Value* VectorSupportLibrary::Floor(llvm::Value* a) {
  AssertCorrectTypes({a});
  return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::floor, {a},
                                      {a->getType()}, b());
}

llvm::Value* VectorSupportLibrary::Div(llvm::Value* lhs, llvm::Value* rhs) {
  AssertCorrectTypes({lhs, rhs});
  if (scalar_type_->isFloatingPointTy()) {
    return b()->CreateFDiv(lhs, rhs, name());
  } else {
    LOG(FATAL) << "Division for integers is unimplemented";
  }
}

llvm::Value* VectorSupportLibrary::Clamp(llvm::Value* a,
                                         const llvm::APFloat& low,
                                         const llvm::APFloat& high) {
  CHECK(!low.isNaN());
  CHECK(!high.isNaN());
  CHECK(low.compare(high) == llvm::APFloat::cmpLessThan);

  AssertCorrectTypes({a});
  llvm::Type* type = a->getType();
  CHECK(scalar_type_->isFloatingPointTy());

  llvm::Value* low_value = GetConstantFloat(type, low);
  llvm::Value* high_value = GetConstantFloat(type, high);
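  // The unordered (U) compare predicates are true when `a` is NaN, so both
  // selects keep `a` in that case: a NaN input propagates through the clamp
  // unchanged.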
  a = b_->CreateSelect(b_->CreateFCmpUGE(a, low_value), a, low_value);
  a = b_->CreateSelect(b_->CreateFCmpULE(a, high_value), a, high_value);
  return a;
}

llvm::Value* VectorSupportLibrary::FCmpEQMask(llvm::Value* lhs,
                                              llvm::Value* rhs) {
  AssertCorrectTypes({lhs, rhs});
  return I1ToFloat(b()->CreateFCmpOEQ(lhs, rhs, name()));
}

llvm::Value* VectorSupportLibrary::FCmpOLTMask(llvm::Value* lhs,
                                               llvm::Value* rhs) {
  AssertCorrectTypes({lhs, rhs});
  return I1ToFloat(b()->CreateFCmpOLT(lhs, rhs, name()));
}

llvm::Value* VectorSupportLibrary::FCmpULEMask(llvm::Value* lhs,
                                               llvm::Value* rhs) {
  AssertCorrectTypes({lhs, rhs});
  return I1ToFloat(b()->CreateFCmpULE(lhs, rhs, name()));
}

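// Widens an i1 (or vector of i1) compare result into a float-typed mask:
// sign-extending each i1 to the integer type of the same bit width as the
// float yields an all-ones or all-zeros bit pattern per lane, which is then
// bitcast to the float type so it can be combined with FloatAnd/FloatOr/
// FloatNot below.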
llvm::Value* VectorSupportLibrary::I1ToFloat(llvm::Value* i1) {
  bool is_vector = llvm::isa<llvm::VectorType>(i1->getType());
  llvm::Type* integer_type = IntegerTypeForFloatSize(is_vector);
  return b()->CreateBitCast(b()->CreateSExt(i1, integer_type, name()),
                            is_vector ? vector_type() : scalar_type(), name());
}

llvm::Type* VectorSupportLibrary::IntegerTypeForFloatSize(bool vector) {
  CHECK(scalar_type()->isFloatingPointTy());
  const llvm::DataLayout& data_layout =
      b()->GetInsertBlock()->getModule()->getDataLayout();
  int64 float_size_bits = data_layout.getTypeSizeInBits(scalar_type());
  llvm::Type* scalar_int_type = b()->getIntNTy(float_size_bits);
  if (vector) {
    return llvm::VectorType::get(scalar_int_type, vector_size(), false);
  } else {
    return scalar_int_type;
  }
}

llvm::Value* VectorSupportLibrary::BroadcastScalar(llvm::Value* x) {
  CHECK_EQ(x->getType(), scalar_type());
  return b()->CreateVectorSplat(vector_size(), x, name());
}

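// LLVM has no bitwise instructions on floating-point types, so the Float*
// helpers below round-trip through the same-width integer type: bitcast to
// integers, apply the bitwise operation, and bitcast back.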
llvm::Value* VectorSupportLibrary::FloatAnd(llvm::Value* lhs,
                                            llvm::Value* rhs) {
  AssertCorrectTypes({lhs, rhs});
  llvm::Type* int_type =
      IntegerTypeForFloatSize(lhs->getType() == vector_type());
  return b()->CreateBitCast(
      b()->CreateAnd(b()->CreateBitCast(lhs, int_type, name()),
                     b()->CreateBitCast(rhs, int_type, name()), name()),
      // Cast back to the operand's own type so this is also valid when the
      // operands are scalars rather than vectors.
      lhs->getType(), name());
}

llvm::Value* VectorSupportLibrary::FloatNot(llvm::Value* lhs) {
  AssertCorrectTypes({lhs});
  llvm::Type* int_type =
      IntegerTypeForFloatSize(lhs->getType() == vector_type());
  return b()->CreateBitCast(
      b()->CreateNot(b()->CreateBitCast(lhs, int_type, name()), name()),
      lhs->getType(), name());
}

llvm::Value* VectorSupportLibrary::FloatOr(llvm::Value* lhs, llvm::Value* rhs) {
  AssertCorrectTypes({lhs, rhs});
  llvm::Type* int_type =
      IntegerTypeForFloatSize(lhs->getType() == vector_type());
  return b()->CreateBitCast(
      b()->CreateOr(b()->CreateBitCast(lhs, int_type, name()),
                    b()->CreateBitCast(rhs, int_type, name()), name()),
      lhs->getType(), name());
}

llvm::Value* VectorSupportLibrary::AddInternal(llvm::Value* lhs,
                                               llvm::Value* rhs) {
  if (scalar_type_->isFloatingPointTy()) {
    return b()->CreateFAdd(lhs, rhs, name());
  } else {
    return b()->CreateAdd(lhs, rhs, name());
  }
}

llvm::Value* VectorSupportLibrary::ComputeOffsetPointer(
    llvm::Value* base_pointer, llvm::Value* offset_elements) {
  if (base_pointer->getType() != scalar_pointer_type()) {
    base_pointer =
        b()->CreateBitCast(base_pointer, scalar_pointer_type(), name());
  }
  return b()->CreateInBoundsGEP(base_pointer, {offset_elements}, name());
}

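// The loads and stores below are annotated with the alignment of a single
// element (ShapeUtil::ByteSizeOfPrimitiveType), not of the whole vector, so
// LLVM is free to lower vector accesses as unaligned loads/stores where the
// target requires it.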
llvm::Value* VectorSupportLibrary::LoadVector(llvm::Value* pointer) {
  if (pointer->getType() != vector_pointer_type()) {
    pointer = b()->CreateBitCast(pointer, vector_pointer_type(), name());
  }
  return b()->CreateAlignedLoad(
      pointer, llvm::Align(ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_)),
      name());
}

llvm::Value* VectorSupportLibrary::LoadScalar(llvm::Value* pointer) {
  if (pointer->getType() != scalar_pointer_type()) {
    pointer = b()->CreateBitCast(pointer, scalar_pointer_type(), name());
  }
  return b()->CreateAlignedLoad(
      pointer, llvm::Align(ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_)),
      name());
}

void VectorSupportLibrary::StoreVector(llvm::Value* value,
                                       llvm::Value* pointer) {
  AssertCorrectTypes({value});
  if (pointer->getType() != vector_pointer_type()) {
    pointer = b()->CreateBitCast(pointer, vector_pointer_type());
  }
  b()->CreateAlignedStore(
      value, pointer,
      llvm::Align(ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_)));
}

void VectorSupportLibrary::StoreScalar(llvm::Value* value,
                                       llvm::Value* pointer) {
  AssertCorrectTypes({value});
  if (pointer->getType() != scalar_pointer_type()) {
    pointer = b()->CreateBitCast(pointer, scalar_pointer_type(), name());
  }
  b()->CreateAlignedStore(
      value, pointer,
      llvm::Align(ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_)));
}

llvm::Value* VectorSupportLibrary::LoadBroadcast(llvm::Value* pointer) {
  if (pointer->getType() != scalar_pointer_type()) {
    pointer = b()->CreateBitCast(pointer, scalar_pointer_type(), name());
  }
  return b()->CreateVectorSplat(vector_size(), b()->CreateLoad(pointer),
                                name());
}

llvm::Value* VectorSupportLibrary::AddReduce(llvm::Value* vector) {
  llvm::SmallVector<llvm::Constant*, 32> mask(vector_size(), nullptr);
  for (unsigned i = vector_size(); i != 1; i >>= 1) {
    // On every iteration, shuffle the upper half of the lanes that are still
    // live down into the lower half, then add the shuffled vector to the
    // original one.  This halves the number of live lanes each time, so after
    // log2(vector_size()) rounds lane 0 holds the full sum.
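    //
    // A worked example with vector_size() == 4 and lanes a0..a3:
    //   i == 4: mask = [2, 3, u, u]; after the add the lanes hold
    //           [a0+a2, a1+a3, ...]
    //   i == 2: mask = [1, u, u, u]; after the add lane 0 holds
    //           a0+a2+a1+a3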

    for (unsigned j = 0; j < vector_size(); ++j) {
      if (j < (i / 2)) {
        mask[j] = b()->getInt32(i / 2 + j);
      } else {
        mask[j] = llvm::UndefValue::get(b()->getInt32Ty());
      }
    }

    llvm::Value* half_remaining_lanes =
        b()->CreateShuffleVector(vector, llvm::UndefValue::get(vector_type()),
                                 llvm::ConstantVector::get(mask), "");
    vector = Add(vector, half_remaining_lanes);
  }

  return b()->CreateExtractElement(vector, b()->getInt32(0), name());
}

llvm::Value* VectorSupportLibrary::AvxStyleHorizontalAdd(llvm::Value* lhs,
                                                         llvm::Value* rhs) {
  CHECK_EQ(lhs->getType(), vector_type());
  CHECK_EQ(rhs->getType(), vector_type());
  CHECK_EQ(vector_size() % 2, 0);

  llvm::SmallVector<llvm::Constant*, 32> mask_a, mask_b;

  // Adding the values shuffled using mask_a and mask_b gives us the
  // AVX-style horizontal add we want.  The masks work as documented
  // in https://llvm.org/docs/LangRef.html#shufflevector-instruction
  //
  // Here are the masks for vector_size() == 8:
  //
  //    index: |0 |1 |2 | 3 |4 |5 | 6 | 7
  //   --------+--+--+--+---+--+--+---+---
  //   mask_a: |0 |2 |8 |10 |4 |6 |12 |14
  //   mask_b: |1 |3 |9 |11 |5 |7 |13 |15
  //
  // So, as an example, the value at lane 3 of the result vector is
  // the result of adding lane 10 and lane 11 in the combined lhs++rhs
  // vector, which are the lanes 2 and 3 in the rhs vector.
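  //
  // This blocked layout mirrors AVX's 256-bit horizontal adds (e.g. vhaddps),
  // which operate independently on each 128-bit half of the register, so the
  // shuffles below can lower to cheap in-lane operations on AVX targets.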
  for (int i = 0; i < vector_size(); i += 2) {
    int increment = i < vector_size() / 2 ? 0 : (vector_size() / 2);
    mask_a.push_back(b()->getInt32(increment + i));
    mask_b.push_back(b()->getInt32(increment + i + 1));
  }
  for (int i = 0; i < vector_size(); i += 2) {
    int increment = i < vector_size() / 2 ? (vector_size() / 2) : vector_size();
    mask_a.push_back(b()->getInt32(increment + i));
    mask_b.push_back(b()->getInt32(increment + i + 1));
  }

  llvm::Value* shuffle_0 =
      b()->CreateShuffleVector(lhs, rhs, llvm::ConstantVector::get(mask_a));
  llvm::Value* shuffle_1 =
      b()->CreateShuffleVector(lhs, rhs, llvm::ConstantVector::get(mask_b));

  return Add(shuffle_0, shuffle_1);
}

llvm::Value* VectorSupportLibrary::ExtractLowHalf(llvm::Value* vector) {
  llvm::SmallVector<llvm::Constant*, 32> mask;
  for (int i = 0; i < vector_size() / 2; i++) {
    mask.push_back(b()->getInt32(i));
  }

  return b()->CreateShuffleVector(vector, llvm::UndefValue::get(vector_type()),
                                  llvm::ConstantVector::get(mask));
}

llvm::Value* VectorSupportLibrary::ExtractHighHalf(llvm::Value* vector) {
  llvm::SmallVector<llvm::Constant*, 32> mask;
  for (int i = 0; i < vector_size() / 2; i++) {
    mask.push_back(b()->getInt32(i + vector_size() / 2));
  }

  return b()->CreateShuffleVector(vector, llvm::UndefValue::get(vector_type()),
                                  llvm::ConstantVector::get(mask));
}

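// Reduces each input vector to the scalar sum of its lanes, adding the
// corresponding lane of `init_values` (if non-null) to each result.  When
// both the vector width and the number of vectors match the element count of
// an x86 AVX register, the shuffle-based fast path lets the reductions share
// work instead of reducing each vector independently with AddReduce.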
std::vector<llvm::Value*> VectorSupportLibrary::ComputeHorizontalSums(
    std::vector<llvm::Value*> vectors, llvm::Value* init_values) {
  const int x86_avx_vector_elements =
      TargetMachineFeatures::kX86AvxVectorByteSize / scalar_byte_size();
  if (vector_size() == x86_avx_vector_elements &&
      vectors.size() == x86_avx_vector_elements) {
    return ComputeAvxOptimizedHorizontalSums(std::move(vectors), init_values);
  }

  std::vector<llvm::Value*> result;
  std::transform(vectors.begin(), vectors.end(), std::back_inserter(result),
                 [this](llvm::Value* vector) { return AddReduce(vector); });
  if (init_values) {
    for (int64 i = 0, e = result.size(); i < e; i++) {
      result[i] = Add(result[i],
                      b()->CreateExtractElement(init_values, b()->getInt32(i)));
    }
  }
  return result;
}

std::vector<llvm::Value*>
VectorSupportLibrary::ComputeAvxOptimizedHorizontalSums(
    std::vector<llvm::Value*> vectors, llvm::Value* init_values) {
  // `vectors` holds N llvm vector values, each with N elements.
  int64 lane_width = vectors.size();
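  //
  // Each AvxStyleHorizontalAdd halves the number of vectors while keeping all
  // of the pairwise partial sums, so with N == 8 the loop below runs twice
  // (8 -> 4 -> 2 vectors); the final two vectors are then folded into `low`
  // and `high`.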

  while (vectors.size() != 2) {
    std::vector<llvm::Value*> new_vectors;
    for (int i = 0; i < vectors.size(); i += 2) {
      new_vectors.push_back(AvxStyleHorizontalAdd(vectors[i], vectors[i + 1]));
    }

    vectors = std::move(new_vectors);
  }

  llvm::Value* low =
      AddInternal(ExtractLowHalf(vectors[0]), ExtractHighHalf(vectors[0]));
  if (init_values) {
    low = AddInternal(ExtractLowHalf(init_values), low);
  }
  llvm::Value* high =
      AddInternal(ExtractLowHalf(vectors[1]), ExtractHighHalf(vectors[1]));
  if (init_values) {
    high = AddInternal(ExtractHighHalf(init_values), high);
  }

  // `low` has the first `lane_width / 2` horizontal reductions, and `high` has
  // the next `lane_width / 2` horizontal reductions.

  std::vector<llvm::Value*> results;
  for (int i = 0; i < lane_width; i++) {
    llvm::Value* scalar_result =
        b()->CreateExtractElement(i < (lane_width / 2) ? low : high,
                                  b()->getInt32(i % (lane_width / 2)), name());
    results.push_back(scalar_result);
  }

  return results;
}

llvm::Value* VectorSupportLibrary::GetZeroVector() {
  return llvm::Constant::getNullValue(vector_type());
}

llvm::Value* VectorSupportLibrary::GetZeroScalar() {
  return llvm::Constant::getNullValue(scalar_type());
}

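// LlvmVariable is backed by an alloca in the function's entry block; LLVM's
// mem2reg/SROA passes are expected to promote it to SSA values, so the
// load/store pairs emitted by Get/Set are essentially free after
// optimization.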
LlvmVariable::LlvmVariable(llvm::Type* type, llvm::IRBuilder<>* b) : b_(b) {
  alloca_ = llvm_ir::EmitAllocaAtFunctionEntry(type, "", b_);
}

llvm::Value* LlvmVariable::Get() const { return b_->CreateLoad(alloca_); }

void LlvmVariable::Set(llvm::Value* new_value) {
  b_->CreateStore(new_value, alloca_);
}

TileVariable::TileVariable(VectorSupportLibrary* vector_support,
                           std::vector<llvm::Value*> initial_value) {
  for (llvm::Value* initial_vector_value : initial_value) {
    storage_.emplace_back(vector_support, initial_vector_value);
  }
}

std::vector<llvm::Value*> TileVariable::Get() const {
  std::vector<llvm::Value*> result;
  absl::c_transform(storage_, std::back_inserter(result),
                    [&](VectorVariable vect_var) { return vect_var.Get(); });
  return result;
}

void TileVariable::Set(absl::Span<llvm::Value* const> value) {
  CHECK_EQ(value.size(), storage_.size());
  for (int64 i = 0, e = value.size(); i < e; i++) {
    storage_[i].Set(value[i]);
  }
}

}  // namespace cpu
}  // namespace xla