1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/cpu/vector_support_library.h"
17
18 #include "absl/algorithm/container.h"
19 #include "llvm/Support/raw_ostream.h"
20 #include "tensorflow/compiler/xla/service/cpu/target_machine_features.h"
21 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
22
23 namespace xla {
24 namespace cpu {
// Builds a support library for emitting scalar and `vector_size`-lane vector
// operations of `primitive_type` through the IRBuilder `b`.  `name` labels the
// LLVM IR values created by this instance.
VectorSupportLibrary::VectorSupportLibrary(PrimitiveType primitive_type,
                                           int64 vector_size,
                                           llvm::IRBuilder<>* b,
                                           std::string name)
    : vector_size_(vector_size),
      primitive_type_(primitive_type),
      b_(b),
      name_(std::move(name)) {
  // Cache the scalar/vector IR types (and their pointer types) up front so
  // the emission helpers below don't recompute them.
  scalar_type_ = llvm_ir::PrimitiveTypeToIrType(
      primitive_type, b_->GetInsertBlock()->getModule());
  scalar_pointer_type_ = llvm::PointerType::getUnqual(scalar_type_);
  // /*Scalable=*/false: a fixed-width vector, not a scalable one.
  vector_type_ = llvm::VectorType::get(scalar_type_, vector_size, false);
  vector_pointer_type_ = llvm::PointerType::getUnqual(vector_type_);
}
39
TypeToString(llvm::Type * type)40 static string TypeToString(llvm::Type* type) {
41 std::string o;
42 llvm::raw_string_ostream ostream(o);
43 type->print(ostream);
44 return ostream.str();
45 }
46
AssertCorrectTypes(std::initializer_list<llvm::Value * > values)47 void VectorSupportLibrary::AssertCorrectTypes(
48 std::initializer_list<llvm::Value*> values) {
49 for (llvm::Value* v : values) {
50 llvm::Type* type = v->getType();
51 if (type != scalar_type() && type != vector_type()) {
52 LOG(FATAL) << "Expected either " << TypeToString(scalar_type()) << " or "
53 << TypeToString(vector_type()) << " but got "
54 << TypeToString(type);
55 }
56 }
57 }
58
// Emits lhs * rhs after checking that both operands are of this library's
// scalar or vector type.
llvm::Value* VectorSupportLibrary::Mul(llvm::Value* lhs, llvm::Value* rhs) {
  AssertCorrectTypes({lhs, rhs});
  return MulInternal(lhs, rhs);
}
63
MulInternal(llvm::Value * lhs,llvm::Value * rhs)64 llvm::Value* VectorSupportLibrary::MulInternal(llvm::Value* lhs,
65 llvm::Value* rhs) {
66 if (scalar_type_->isFloatingPointTy()) {
67 return b()->CreateFMul(lhs, rhs, name());
68 } else {
69 return b()->CreateMul(lhs, rhs, name());
70 }
71 }
72
// Emits lhs + rhs after checking that both operands are of this library's
// scalar or vector type.
llvm::Value* VectorSupportLibrary::Add(llvm::Value* lhs, llvm::Value* rhs) {
  AssertCorrectTypes({lhs, rhs});
  return AddInternal(lhs, rhs);
}
77
Sub(llvm::Value * lhs,llvm::Value * rhs)78 llvm::Value* VectorSupportLibrary::Sub(llvm::Value* lhs, llvm::Value* rhs) {
79 AssertCorrectTypes({lhs, rhs});
80 return b()->CreateFSub(lhs, rhs);
81 }
82
Max(llvm::Value * lhs,llvm::Value * rhs,bool enable_fast_min_max)83 llvm::Value* VectorSupportLibrary::Max(llvm::Value* lhs, llvm::Value* rhs,
84 bool enable_fast_min_max) {
85 AssertCorrectTypes({lhs, rhs});
86 if (scalar_type_->isFloatingPointTy()) {
87 return llvm_ir::EmitFloatMax(lhs, rhs, b_, enable_fast_min_max);
88 } else {
89 LOG(FATAL) << "Max for integers is unimplemented";
90 }
91 }
92
Floor(llvm::Value * a)93 llvm::Value* VectorSupportLibrary::Floor(llvm::Value* a) {
94 AssertCorrectTypes({a});
95 return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::floor, {a},
96 {a->getType()}, b());
97 }
98
Div(llvm::Value * lhs,llvm::Value * rhs)99 llvm::Value* VectorSupportLibrary::Div(llvm::Value* lhs, llvm::Value* rhs) {
100 AssertCorrectTypes({lhs, rhs});
101 if (scalar_type_->isFloatingPointTy()) {
102 return b()->CreateFDiv(lhs, rhs, name());
103 } else {
104 LOG(FATAL) << "Division for integers is unimplemented";
105 }
106 }
107
Clamp(llvm::Value * a,const llvm::APFloat & low,const llvm::APFloat & high)108 llvm::Value* VectorSupportLibrary::Clamp(llvm::Value* a,
109 const llvm::APFloat& low,
110 const llvm::APFloat& high) {
111 CHECK(!low.isNaN());
112 CHECK(!high.isNaN());
113 CHECK(low.compare(high) == llvm::APFloat::cmpLessThan);
114
115 AssertCorrectTypes({a});
116 llvm::Type* type = a->getType();
117 CHECK(scalar_type_->isFloatingPointTy());
118
119 llvm::Value* low_value = GetConstantFloat(type, low);
120 llvm::Value* high_value = GetConstantFloat(type, high);
121 a = b_->CreateSelect(b_->CreateFCmpUGE(a, low_value), a, low_value);
122 a = b_->CreateSelect(b_->CreateFCmpULE(a, high_value), a, high_value);
123 return a;
124 }
125
FCmpEQMask(llvm::Value * lhs,llvm::Value * rhs)126 llvm::Value* VectorSupportLibrary::FCmpEQMask(llvm::Value* lhs,
127 llvm::Value* rhs) {
128 AssertCorrectTypes({lhs, rhs});
129 return I1ToFloat(b()->CreateFCmpOEQ(lhs, rhs, name()));
130 }
131
FCmpOLTMask(llvm::Value * lhs,llvm::Value * rhs)132 llvm::Value* VectorSupportLibrary::FCmpOLTMask(llvm::Value* lhs,
133 llvm::Value* rhs) {
134 AssertCorrectTypes({lhs, rhs});
135 return I1ToFloat(b()->CreateFCmpOLT(lhs, rhs, name()));
136 }
137
FCmpULEMask(llvm::Value * lhs,llvm::Value * rhs)138 llvm::Value* VectorSupportLibrary::FCmpULEMask(llvm::Value* lhs,
139 llvm::Value* rhs) {
140 AssertCorrectTypes({lhs, rhs});
141 return I1ToFloat(b()->CreateFCmpULE(lhs, rhs, name()));
142 }
143
I1ToFloat(llvm::Value * i1)144 llvm::Value* VectorSupportLibrary::I1ToFloat(llvm::Value* i1) {
145 bool is_vector = llvm::isa<llvm::VectorType>(i1->getType());
146 llvm::Type* integer_type = IntegerTypeForFloatSize(is_vector);
147 return b()->CreateBitCast(b()->CreateSExt(i1, integer_type, name()),
148 is_vector ? vector_type() : scalar_type(), name());
149 }
150
IntegerTypeForFloatSize(bool vector)151 llvm::Type* VectorSupportLibrary::IntegerTypeForFloatSize(bool vector) {
152 CHECK(scalar_type()->isFloatingPointTy());
153 const llvm::DataLayout& data_layout =
154 b()->GetInsertBlock()->getModule()->getDataLayout();
155 int64 float_size_bits = data_layout.getTypeSizeInBits(scalar_type());
156 llvm::Type* scalar_int_type = b()->getIntNTy(float_size_bits);
157 if (vector) {
158 return llvm::VectorType::get(scalar_int_type, vector_size(), false);
159 } else {
160 return scalar_int_type;
161 }
162 }
163
BroadcastScalar(llvm::Value * x)164 llvm::Value* VectorSupportLibrary::BroadcastScalar(llvm::Value* x) {
165 CHECK_EQ(x->getType(), scalar_type());
166 return b()->CreateVectorSplat(vector_size(), x, name());
167 }
168
FloatAnd(llvm::Value * lhs,llvm::Value * rhs)169 llvm::Value* VectorSupportLibrary::FloatAnd(llvm::Value* lhs,
170 llvm::Value* rhs) {
171 AssertCorrectTypes({lhs, rhs});
172 llvm::Type* int_type =
173 IntegerTypeForFloatSize(lhs->getType() == vector_type());
174 return b()->CreateBitCast(
175 b()->CreateAnd(b()->CreateBitCast(lhs, int_type, name()),
176 b()->CreateBitCast(rhs, int_type, name()), name()),
177 vector_type());
178 }
179
FloatNot(llvm::Value * lhs)180 llvm::Value* VectorSupportLibrary::FloatNot(llvm::Value* lhs) {
181 AssertCorrectTypes({lhs});
182 llvm::Type* int_type =
183 IntegerTypeForFloatSize(lhs->getType() == vector_type());
184 return b()->CreateBitCast(
185 b()->CreateNot(b()->CreateBitCast(lhs, int_type, name()), name()),
186 vector_type());
187 }
188
FloatOr(llvm::Value * lhs,llvm::Value * rhs)189 llvm::Value* VectorSupportLibrary::FloatOr(llvm::Value* lhs, llvm::Value* rhs) {
190 AssertCorrectTypes({lhs, rhs});
191 llvm::Type* int_type =
192 IntegerTypeForFloatSize(lhs->getType() == vector_type());
193 return b()->CreateBitCast(
194 b()->CreateOr(b()->CreateBitCast(lhs, int_type, name()),
195 b()->CreateBitCast(rhs, int_type, name()), name()),
196 vector_type(), name());
197 }
198
AddInternal(llvm::Value * lhs,llvm::Value * rhs)199 llvm::Value* VectorSupportLibrary::AddInternal(llvm::Value* lhs,
200 llvm::Value* rhs) {
201 if (scalar_type_->isFloatingPointTy()) {
202 return b()->CreateFAdd(lhs, rhs, name());
203 } else {
204 return b()->CreateAdd(lhs, rhs, name());
205 }
206 }
207
ComputeOffsetPointer(llvm::Value * base_pointer,llvm::Value * offset_elements)208 llvm::Value* VectorSupportLibrary::ComputeOffsetPointer(
209 llvm::Value* base_pointer, llvm::Value* offset_elements) {
210 if (base_pointer->getType() != scalar_pointer_type()) {
211 base_pointer =
212 b()->CreateBitCast(base_pointer, scalar_pointer_type(), name());
213 }
214 return b()->CreateInBoundsGEP(base_pointer, {offset_elements}, name());
215 }
216
LoadVector(llvm::Value * pointer)217 llvm::Value* VectorSupportLibrary::LoadVector(llvm::Value* pointer) {
218 if (pointer->getType() != vector_pointer_type()) {
219 pointer = b()->CreateBitCast(pointer, vector_pointer_type(), name());
220 }
221 return b()->CreateAlignedLoad(
222 pointer, llvm::Align(ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_)),
223 name());
224 }
225
LoadScalar(llvm::Value * pointer)226 llvm::Value* VectorSupportLibrary::LoadScalar(llvm::Value* pointer) {
227 if (pointer->getType() != scalar_pointer_type()) {
228 pointer = b()->CreateBitCast(pointer, scalar_pointer_type(), name());
229 }
230 return b()->CreateAlignedLoad(
231 pointer, llvm::Align(ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_)),
232 name());
233 }
234
StoreVector(llvm::Value * value,llvm::Value * pointer)235 void VectorSupportLibrary::StoreVector(llvm::Value* value,
236 llvm::Value* pointer) {
237 AssertCorrectTypes({value});
238 if (pointer->getType() != vector_pointer_type()) {
239 pointer = b()->CreateBitCast(pointer, vector_pointer_type());
240 }
241 b()->CreateAlignedStore(
242 value, pointer,
243 llvm::Align(ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_)));
244 }
245
StoreScalar(llvm::Value * value,llvm::Value * pointer)246 void VectorSupportLibrary::StoreScalar(llvm::Value* value,
247 llvm::Value* pointer) {
248 AssertCorrectTypes({value});
249 if (pointer->getType() != scalar_pointer_type()) {
250 pointer = b()->CreateBitCast(pointer, scalar_pointer_type(), name());
251 }
252 b()->CreateAlignedStore(
253 value, pointer,
254 llvm::Align(ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_)));
255 }
256
LoadBroadcast(llvm::Value * pointer)257 llvm::Value* VectorSupportLibrary::LoadBroadcast(llvm::Value* pointer) {
258 if (pointer->getType() != scalar_pointer_type()) {
259 pointer = b()->CreateBitCast(pointer, scalar_pointer_type(), name());
260 }
261 return b()->CreateVectorSplat(vector_size(), b()->CreateLoad(pointer),
262 name());
263 }
264
// Horizontally sums all lanes of `vector` into a single scalar using a
// log-depth shuffle-and-add reduction.
//
// NOTE(review): the halving loop (`i >>= 1` until i == 1) assumes
// vector_size() is a power of two — confirm callers guarantee this.
llvm::Value* VectorSupportLibrary::AddReduce(llvm::Value* vector) {
  llvm::SmallVector<llvm::Constant*, 32> mask(vector_size(), nullptr);
  for (unsigned i = vector_size(); i != 1; i >>= 1) {
    // On every iteration, we shuffle half of the remaining lanes to the top
    // half of shuffle, and add two old and the new vector.

    // Lanes [0, i/2) of the shuffle pick up lanes [i/2, i) of `vector`; all
    // other lanes are undef because their values no longer participate.
    for (unsigned j = 0; j < vector_size(); ++j) {
      if (j < (i / 2)) {
        mask[j] = b()->getInt32(i / 2 + j);
      } else {
        mask[j] = llvm::UndefValue::get(b()->getInt32Ty());
      }
    }

    llvm::Value* half_remaining_lanes =
        b()->CreateShuffleVector(vector, llvm::UndefValue::get(vector_type()),
                                 llvm::ConstantVector::get(mask), "");
    vector = Add(vector, half_remaining_lanes);
  }

  // After the loop, lane 0 holds the sum over all original lanes.
  return b()->CreateExtractElement(vector, b()->getInt32(0), name());
}
287
// Emits an AVX-style "haddps"-like horizontal add of `lhs` and `rhs`:
// adjacent lane pairs are summed, with the results of lhs's pairs and rhs's
// pairs interleaved per 128-bit half, exactly as x86 vhaddps does.
llvm::Value* VectorSupportLibrary::AvxStyleHorizontalAdd(llvm::Value* lhs,
                                                         llvm::Value* rhs) {
  CHECK_EQ(lhs->getType(), vector_type());
  CHECK_EQ(rhs->getType(), vector_type());
  CHECK_EQ(vector_size() % 2, 0);

  llvm::SmallVector<llvm::Constant*, 32> mask_a, mask_b;

  // Adding the values shuffled using mask_a and mask_b gives us the
  // AVX-style horizontal add we want. The masks work as documented
  // in https://llvm.org/docs/LangRef.html#shufflevector-instruction
  //
  // Here are the masks for vector_size() == 8:
  //
  // index: |0 |1 |2 | 3 |4 |5 | 6 | 7
  // --------+--+--+--+---+--+--+---+---
  // mask_a: |0 |2 |8 |10 |4 |6 |12 |14
  // mask_b: |1 |3 |9 |11 |5 |7 |13 |15
  //
  // So, as an example, the value at lane 3 of the result vector is
  // the result of adding lane 10 and lane 11 in the combined lhs++rhs
  // vector, which are the lanes 2 and 3 in the rhs vector.
  //
  // The first loop builds the masks' low half; the second builds the high
  // half, which draws from the top halves of lhs (increment vector_size()/2)
  // and rhs (increment vector_size()).
  for (int i = 0; i < vector_size(); i += 2) {
    int increment = i < vector_size() / 2 ? 0 : (vector_size() / 2);
    mask_a.push_back(b()->getInt32(increment + i));
    mask_b.push_back(b()->getInt32(increment + i + 1));
  }
  for (int i = 0; i < vector_size(); i += 2) {
    int increment = i < vector_size() / 2 ? (vector_size() / 2) : vector_size();
    mask_a.push_back(b()->getInt32(increment + i));
    mask_b.push_back(b()->getInt32(increment + i + 1));
  }

  llvm::Value* shuffle_0 =
      b()->CreateShuffleVector(lhs, rhs, llvm::ConstantVector::get(mask_a));
  llvm::Value* shuffle_1 =
      b()->CreateShuffleVector(lhs, rhs, llvm::ConstantVector::get(mask_b));

  return Add(shuffle_0, shuffle_1);
}
328
ExtractLowHalf(llvm::Value * vector)329 llvm::Value* VectorSupportLibrary::ExtractLowHalf(llvm::Value* vector) {
330 llvm::SmallVector<llvm::Constant*, 32> mask;
331 for (int i = 0; i < vector_size() / 2; i++) {
332 mask.push_back(b()->getInt32(i));
333 }
334
335 return b()->CreateShuffleVector(vector, llvm::UndefValue::get(vector_type()),
336 llvm::ConstantVector::get(mask));
337 }
338
ExtractHighHalf(llvm::Value * vector)339 llvm::Value* VectorSupportLibrary::ExtractHighHalf(llvm::Value* vector) {
340 llvm::SmallVector<llvm::Constant*, 32> mask;
341 for (int i = 0; i < vector_size() / 2; i++) {
342 mask.push_back(b()->getInt32(i + vector_size() / 2));
343 }
344
345 return b()->CreateShuffleVector(vector, llvm::UndefValue::get(vector_type()),
346 llvm::ConstantVector::get(mask));
347 }
348
ComputeHorizontalSums(std::vector<llvm::Value * > vectors,llvm::Value * init_values)349 std::vector<llvm::Value*> VectorSupportLibrary::ComputeHorizontalSums(
350 std::vector<llvm::Value*> vectors, llvm::Value* init_values) {
351 const int x86_avx_vector_elements =
352 TargetMachineFeatures::kX86AvxVectorByteSize / scalar_byte_size();
353 if (vector_size() == x86_avx_vector_elements &&
354 vectors.size() == x86_avx_vector_elements) {
355 return ComputeAvxOptimizedHorizontalSums(std::move(vectors), init_values);
356 }
357
358 std::vector<llvm::Value*> result;
359 std::transform(vectors.begin(), vectors.end(), std::back_inserter(result),
360 [this](llvm::Value* vector) { return AddReduce(vector); });
361 if (init_values) {
362 for (int64 i = 0, e = result.size(); i < e; i++) {
363 result[i] = Add(result[i],
364 b()->CreateExtractElement(init_values, b()->getInt32(i)));
365 }
366 }
367 return result;
368 }
369
// Computes the horizontal sums of N vectors of N lanes each using
// AVX-style pairwise horizontal adds, producing N scalars — one per input
// vector — in O(log N) vector adds instead of N independent reductions.
// Precondition (established by the caller): vectors.size() == vector_size().
std::vector<llvm::Value*>
VectorSupportLibrary::ComputeAvxOptimizedHorizontalSums(
    std::vector<llvm::Value*> vectors, llvm::Value* init_values) {
  // vectors are N llvm vector values, each with N elements.
  int64 lane_width = vectors.size();

  // Repeatedly fold pairs of vectors with AvxStyleHorizontalAdd until only
  // two vectors remain; together they hold all partial pair sums.
  while (vectors.size() != 2) {
    std::vector<llvm::Value*> new_vectors;
    for (int i = 0; i < vectors.size(); i += 2) {
      new_vectors.push_back(AvxStyleHorizontalAdd(vectors[i], vectors[i + 1]));
    }

    vectors = std::move(new_vectors);
  }

  // Fold each remaining vector's halves together, mixing in the matching
  // half of `init_values` when provided.
  llvm::Value* low =
      AddInternal(ExtractLowHalf(vectors[0]), ExtractHighHalf(vectors[0]));
  if (init_values) {
    low = AddInternal(ExtractLowHalf(init_values), low);
  }
  llvm::Value* high =
      AddInternal(ExtractLowHalf(vectors[1]), ExtractHighHalf(vectors[1]));
  if (init_values) {
    high = AddInternal(ExtractHighHalf(init_values), high);
  }

  // `low` has the first `lane_width / 2` horizontal reductions, and `high` has
  // the next `lane_width / 2` horizontal reductions.

  // Extract the i-th reduction as a scalar: lane i % (lane_width/2) of `low`
  // for the first half, of `high` for the second half.
  std::vector<llvm::Value*> results;
  for (int i = 0; i < lane_width; i++) {
    llvm::Value* scalar_result =
        b()->CreateExtractElement(i < (lane_width / 2) ? low : high,
                                  b()->getInt32(i % (lane_width / 2)), name());
    results.push_back(scalar_result);
  }

  return results;
}
409
GetZeroVector()410 llvm::Value* VectorSupportLibrary::GetZeroVector() {
411 return llvm::Constant::getNullValue(vector_type());
412 }
413
GetZeroScalar()414 llvm::Value* VectorSupportLibrary::GetZeroScalar() {
415 return llvm::Constant::getNullValue(scalar_type());
416 }
417
LlvmVariable(llvm::Type * type,llvm::IRBuilder<> * b)418 LlvmVariable::LlvmVariable(llvm::Type* type, llvm::IRBuilder<>* b) : b_(b) {
419 alloca_ = llvm_ir::EmitAllocaAtFunctionEntry(type, "", b_);
420 }
421
Get() const422 llvm::Value* LlvmVariable::Get() const { return b_->CreateLoad(alloca_); }
423
Set(llvm::Value * new_value)424 void LlvmVariable::Set(llvm::Value* new_value) {
425 b_->CreateStore(new_value, alloca_);
426 }
427
TileVariable(VectorSupportLibrary * vector_support,std::vector<llvm::Value * > initial_value)428 TileVariable::TileVariable(VectorSupportLibrary* vector_support,
429 std::vector<llvm::Value*> initial_value) {
430 for (llvm::Value* initial_vector_value : initial_value) {
431 storage_.emplace_back(vector_support, initial_vector_value);
432 }
433 }
434
Get() const435 std::vector<llvm::Value*> TileVariable::Get() const {
436 std::vector<llvm::Value*> result;
437 absl::c_transform(storage_, std::back_inserter(result),
438 [&](VectorVariable vect_var) { return vect_var.Get(); });
439 return result;
440 }
441
Set(absl::Span<llvm::Value * const> value)442 void TileVariable::Set(absl::Span<llvm::Value* const> value) {
443 CHECK_EQ(value.size(), storage_.size());
444 for (int64 i = 0, e = value.size(); i < e; i++) {
445 storage_[i].Set(value[i]);
446 }
447 }
448
449 } // namespace cpu
450 } // namespace xla
451