1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
17 
18 #include "llvm/IR/Constants.h"
19 #include "llvm/IR/Instructions.h"
20 #include "tensorflow/compiler/xla/layout_util.h"
21 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
22 #include "tensorflow/compiler/xla/shape_util.h"
23 #include "tensorflow/compiler/xla/statusor.h"
24 #include "tensorflow/compiler/xla/util.h"
25 #include "tensorflow/compiler/xla/xla_data.pb.h"
26 #include "tensorflow/core/platform/logging.h"
27 #include "tensorflow/core/platform/types.h"
28 
29 namespace xla {
30 namespace llvm_ir {
31 
Index(absl::Span<llvm::Value * const> multidim,llvm::Value * linear,const Shape & shape,llvm::Type * index_type)32 IrArray::Index::Index(absl::Span<llvm::Value* const> multidim,
33                       llvm::Value* linear, const Shape& shape,
34                       llvm::Type* index_type)
35     : Index(multidim, shape, index_type) {
36   CHECK_NE(linear, nullptr);
37   linear_ = linear;
38 }
39 
Delinearize(std::vector<llvm::Value * > * multidim,llvm::Value * linear,const Shape & shape,llvm::IRBuilder<> * b) const40 void IrArray::Index::Delinearize(std::vector<llvm::Value*>* multidim,
41                                  llvm::Value* linear, const Shape& shape,
42                                  llvm::IRBuilder<>* b) const {
43   int64 divisor = 1;
44   const Layout& layout = shape.layout();
45   for (int64 i = 0; i < layout.minor_to_major_size(); ++i) {
46     int64 dimension = layout.minor_to_major(i);
47     int64 size_of_current_dimension = shape.dimensions(dimension);
48 
49     // If i is not the last dimension, compute
50     //   (linear_index / divisor) % current_dimension.
51     // If i is the last dimension, we can skip the mod, because we assume that
52     // linear is in bounds.
53     //
54     // TODO(jlebar): We could add bounds checks here and elsewhere in this file,
55     // guarded under some sort of xla-memcheck flag.  This might be particularly
56     // useful because cuda-memcheck can't help us much in XLA: Most of our
57     // memory lives in one big allocation, so cuda-memcheck can't detect
58     // out-of-bounds accesses.
59     auto* quot = b->CreateUDiv(linear, GetConstantWithIndexType(divisor));
60     if (i < layout.minor_to_major_size() - 1) {
61       (*multidim)[dimension] = b->CreateURem(
62           quot, GetConstantWithIndexType(size_of_current_dimension));
63     } else {
64       (*multidim)[dimension] = quot;
65     }
66     divisor *= size_of_current_dimension;
67   }
68 }
69 
Index(llvm::Value * linear,const Shape & shape,llvm::IRBuilder<> * b)70 IrArray::Index::Index(llvm::Value* linear, const Shape& shape,
71                       llvm::IRBuilder<>* b)
72     : multidim_(shape.rank()),
73       linear_(linear),
74       layout_(shape.layout()),
75       dims_(shape.dimensions().begin(), shape.dimensions().end()) {
76   CHECK_NE(linear, nullptr);
77   index_type_ = linear->getType();
78   CHECK(LayoutUtil::HasLayout(shape))
79       << "Shape " << ShapeUtil::HumanStringWithLayout(shape)
80       << " should have a layout.";
81   Delinearize(&multidim_, linear, shape, b);
82 }
83 
Index(absl::Span<llvm::Value * const> multidim,const Shape & shape,llvm::Type * index_type)84 IrArray::Index::Index(absl::Span<llvm::Value* const> multidim,
85                       const Shape& shape, llvm::Type* index_type)
86     : multidim_(multidim.begin(), multidim.end()),
87       linear_(nullptr),
88       layout_(shape.layout()),
89       dims_(shape.dimensions().begin(), shape.dimensions().end()),
90       index_type_(index_type) {
91   CHECK_NE(index_type_, nullptr);
92   CHECK_EQ(shape.dimensions_size(), multidim.size());
93   for (const auto* dim : multidim) {
94     CHECK_NE(dim, nullptr);
95   }
96   CHECK(LayoutUtil::HasLayout(shape))
97       << "Shape " << ShapeUtil::HumanStringWithLayout(shape)
98       << " should have a layout.";
99 }
100 
IrArray(llvm::Value * base_ptr,Shape shape)101 IrArray::IrArray(llvm::Value* base_ptr, Shape shape)
102     : base_ptr_(base_ptr), shape_(std::move(shape)) {
103   TF_CHECK_OK(ShapeUtil::ValidateShape(shape));
104   CHECK(base_ptr_->getType()->isPointerTy());
105   int depth = 0;
106   element_type_ =
107       llvm::cast<llvm::PointerType>(base_ptr_->getType())->getElementType();
108   while (llvm::ArrayType* array_type =
109              llvm::dyn_cast<llvm::ArrayType>(element_type_)) {
110     element_type_ = array_type->getElementType();
111     ++depth;
112   }
113 
114   if (!shape_.IsArray() || ShapeUtil::IsScalar(shape_)) {
115     DCHECK(depth == 1 || depth == 0) << depth;
116   } else {
117     DCHECK_EQ(depth, shape_.rank()) << shape.ShortDebugString();
118   }
119 }
120 
121 // Returns whether the given linear index is valid on the given shape.
LinearValidOnShape(const Shape & a) const122 bool IrArray::Index::LinearValidOnShape(const Shape& a) const {
123   auto b = ShapeUtil::MakeShape(a.element_type(), dims_);
124   *b.mutable_layout() = layout_;
125   return linear_ != nullptr &&
126          ShapeUtil::ElementsIn(a) == ShapeUtil::ElementsIn(b) &&
127          ShapeUtil::ReshapeIsBitcast(a, b);
128 }
129 
SourceIndexOfReshape(const Shape & output_shape,const Shape & input_shape,llvm::IRBuilder<> * builder) const130 IrArray::Index IrArray::Index::SourceIndexOfReshape(
131     const Shape& output_shape, const Shape& input_shape,
132     llvm::IRBuilder<>* builder) const {
133   const auto& target_index = *this;
134   CHECK_EQ(target_index.size(), output_shape.rank());
135   std::vector<std::pair<int64, int64>> common_factors =
136       CommonFactors(AsInt64Slice(input_shape.dimensions()),
137                     AsInt64Slice(output_shape.dimensions()));
138   std::vector<llvm::Value*> source_multidim_index(
139       input_shape.rank(), llvm::UndefValue::get(index_type_));
140   // We compute the source indices in each common factor from only the target
141   // indices in the same common factor.
142   for (ssize_t k = common_factors.size() - 2; k >= 0; --k) {
143     llvm::Value* logical_linear_index =
144         Index(absl::Span<llvm::Value* const>(multidim_).subspan(
145                   common_factors[k].second,
146                   common_factors[k + 1].second - common_factors[k].second),
147               index_type_)
148             .Linearize(AsInt64Slice(output_shape.dimensions())
149                            .subspan(common_factors[k].second,
150                                     common_factors[k + 1].second -
151                                         common_factors[k].second),
152                        builder);
153     // Delinearizes logical_linear_index for the source array in row-major
154     // collapsed order. The first rank-1 indices are the remainder of the
155     // linear index by each dimension size.
156     for (int64 i = common_factors[k + 1].first - 1;
157          i >= common_factors[k].first; --i) {
158       llvm::Value* divisor =
159           GetConstantWithIndexType(input_shape.dimensions(i));
160       if (input_shape.dimensions(i) == 1) {
161         source_multidim_index[i] = GetConstantWithIndexType(0);
162       } else if (i == common_factors[k].first) {
163         source_multidim_index[i] = logical_linear_index;
164       } else {
165         source_multidim_index[i] =
166             builder->CreateURem(logical_linear_index, divisor);
167       }
168       logical_linear_index = builder->CreateUDiv(logical_linear_index, divisor);
169     }
170   }
171 
172   if (linear() != nullptr && LayoutUtil::HasLayout(input_shape) &&
173       LayoutUtil::HasLayout(output_shape) &&
174       ShapeUtil::ReshapeIsBitcast(input_shape, output_shape)) {
175     return Index(source_multidim_index, linear(), input_shape, index_type_);
176   }
177   return Index(source_multidim_index, index_type_);
178 }
179 
SourceIndexOfSlice(const Shape & operand_shape,absl::Span<const int64> starts,absl::Span<const int64> strides,llvm::IRBuilder<> * builder) const180 IrArray::Index IrArray::Index::SourceIndexOfSlice(
181     const Shape& operand_shape, absl::Span<const int64> starts,
182     absl::Span<const int64> strides, llvm::IRBuilder<>* builder) const {
183   std::vector<llvm::Value*> source_multi_index(multidim_.size());
184   for (int i = 0; i < multidim_.size(); ++i) {
185     int64 stride = strides[i];
186     auto type = multidim_[i]->getType();
187 
188     if (stride != 1) {
189       source_multi_index[i] = builder->CreateAdd(
190           builder->CreateMul(multidim_[i],
191                              llvm::ConstantInt::get(type, stride)),
192           llvm::ConstantInt::get(type, starts[i]));
193     } else {
194       source_multi_index[i] = builder->CreateAdd(
195           multidim_[i], llvm::ConstantInt::get(type, starts[i]));
196     }
197   }
198   return Index(source_multi_index, operand_shape, index_type_);
199 }
200 
SourceIndexOfTranspose(const Shape & shape,const Shape & operand_shape,absl::Span<const int64> dimension_mapping,llvm::IRBuilder<> * builder) const201 IrArray::Index IrArray::Index::SourceIndexOfTranspose(
202     const Shape& shape, const Shape& operand_shape,
203     absl::Span<const int64> dimension_mapping,
204     llvm::IRBuilder<>* builder) const {
205   std::vector<llvm::Value*> operand_multidim_index =
206       Permute(dimension_mapping, multidim());
207 
208   if (linear() != nullptr && LayoutUtil::HasLayout(operand_shape) &&
209       LayoutUtil::HasLayout(shape) &&
210       ShapeUtil::TransposeIsBitcast(operand_shape, shape, dimension_mapping)) {
211     return Index(operand_multidim_index, linear(), operand_shape, index_type_);
212   }
213 
214   return Index(operand_multidim_index);
215 }
216 
SourceIndexOfBitcast(const Shape & shape,const Shape & operand_shape,llvm::IRBuilder<> * builder) const217 IrArray::Index IrArray::Index::SourceIndexOfBitcast(
218     const Shape& shape, const Shape& operand_shape,
219     llvm::IRBuilder<>* builder) const {
220   CHECK(LayoutUtil::HasLayout(shape) && LayoutUtil::HasLayout(operand_shape));
221   // In case the bitcast is just a reshape, we can use SourceIndexOfReshape()
222   // instead. This will reuse linear() if possible, so we don't have to build a
223   // new 'linear_index'.
224   if (ShapeUtil::ReshapeIsBitcast(operand_shape, shape)) {
225     return SourceIndexOfReshape(shape, operand_shape, builder);
226   }
227 
228   // First linearize the index coming from the output of the bitcast. We want
229   // the physical index of the element in the buffer. This is like Linearize,
230   // but takes the layout into account.
231   int64 scale = 1;
232   llvm::Value* linear_index = GetConstantWithIndexType(0);
233   for (auto dimension : LayoutUtil::MinorToMajor(shape)) {
234     linear_index = builder->CreateAdd(
235         linear_index,
236         builder->CreateMul(multidim_[dimension],
237                            GetConstantWithIndexType(scale), "",
238                            /*HasNUW=*/true, /*HasNSW=*/true),
239         "", /*HasNUW=*/true, /*HasNSW=*/true);
240     scale *= shape.dimensions(dimension);
241   }
242 
243   // Now delinearize it for the input of the bitcast.
244   std::vector<llvm::Value*> multi_index(operand_shape.dimensions_size());
245   Delinearize(&multi_index, linear_index, operand_shape, builder);
246 
247   return Index(multi_index, linear_index, operand_shape, index_type_);
248 }
249 
SourceIndexOfBroadcast(const Shape & shape,const Shape & operand_shape,absl::Span<const int64> dimension_mapping,llvm::IRBuilder<> * builder) const250 IrArray::Index IrArray::Index::SourceIndexOfBroadcast(
251     const Shape& shape, const Shape& operand_shape,
252     absl::Span<const int64> dimension_mapping,
253     llvm::IRBuilder<>* builder) const {
254   int64 rank = operand_shape.rank();
255   std::vector<llvm::Value*> source_index(rank);
256   for (int64 i = 0; i < rank; ++i) {
257     source_index[i] = multidim_[dimension_mapping[i]];
258   }
259   if (linear_ == nullptr || !LayoutUtil::HasLayout(operand_shape) ||
260       !LayoutUtil::HasLayout(shape)) {
261     return Index(source_index, index_type_);
262   }
263   // High-level idea: we can reuse the linear index if the broadcasted
264   // dimensions are contiguous, and this part of the operation is a bitcast.
265   // The other dimensions can be masked out with a div and a mod operation.
266   std::vector<int64> logical_to_physical =
267       LayoutUtil::MakeLogicalToPhysical(shape.layout());
268   int64 output_rank = shape.rank();
269   // The minimum physical dimension that is broadcasted.
270   int64 min_broadcasted_dimension = output_rank;
271   // The maximum physical dimension that is broadcasted.
272   int64 max_broadcasted_dimension = -1;
273   for (int64 i = 0; i < rank; ++i) {
274     int64 physical_dim = logical_to_physical[dimension_mapping[i]];
275     min_broadcasted_dimension =
276         std::min(min_broadcasted_dimension, physical_dim);
277     max_broadcasted_dimension =
278         std::max(max_broadcasted_dimension, physical_dim);
279   }
280   bool contiguous_broadcast_dimensions =
281       max_broadcasted_dimension - min_broadcasted_dimension == rank - 1;
282   if (!contiguous_broadcast_dimensions) {
283     return Index(source_index, index_type_);
284   }
285   // Check if the mapped dimensions are a bitcast.
286   std::vector<int64> operand_logical_to_physical =
287       LayoutUtil::MakeLogicalToPhysical(operand_shape.layout());
288   for (int64 i = 0; i < rank; ++i) {
289     if (operand_logical_to_physical[i] !=
290         logical_to_physical[dimension_mapping[i]] - min_broadcasted_dimension) {
291       return Index(source_index, index_type_);
292     }
293   }
294   llvm::Value* linear = linear_;
295   int64 divisor = 1;
296   for (int64 i = max_broadcasted_dimension + 1; i < output_rank; ++i) {
297     divisor *= shape.dimensions(LayoutUtil::Major(shape.layout(), i));
298   }
299   if (divisor > 1) {
300     linear = builder->CreateUDiv(linear, GetConstantWithIndexType(divisor));
301   }
302   if (min_broadcasted_dimension > 0) {
303     int64 mod = 1;
304     for (int64 i = min_broadcasted_dimension; i <= max_broadcasted_dimension;
305          ++i) {
306       mod *= shape.dimensions(LayoutUtil::Major(shape.layout(), i));
307     }
308     linear = builder->CreateURem(linear, GetConstantWithIndexType(mod));
309   }
310   return Index(source_index, linear, operand_shape, index_type_);
311 }
312 
Linearize(absl::Span<const int64> dimensions,llvm::IRBuilder<> * builder) const313 llvm::Value* IrArray::Index::Linearize(absl::Span<const int64> dimensions,
314                                        llvm::IRBuilder<>* builder) const {
315   // Each dimension is multiplied by the product of the sizes of all
316   // earlier dimensions and added to the accumulator logical_linear_index.
317   CHECK_EQ(size(), dimensions.size());
318   llvm::Value* logical_linear_index = GetConstantWithIndexType(0);
319   int64 multiplier = 1;
320   for (ssize_t i = size() - 1; i >= 0; --i) {
321     llvm::Value* addend =
322         builder->CreateMul((*this)[i], GetConstantWithIndexType(multiplier), "",
323                            /*HasNUW=*/true, /*HasNSW=*/true);
324     addend = builder->CreateZExtOrTrunc(addend, index_type_);
325     logical_linear_index = builder->CreateAdd(logical_linear_index, addend, "",
326                                               /*HasNUW=*/true, /*HasNSW=*/true);
327     multiplier *= dimensions[i];
328   }
329   return logical_linear_index;
330 }
331 
EmitArrayElementAddress(const IrArray::Index & index,llvm::IRBuilder<> * b,absl::string_view name,bool use_linear_index) const332 llvm::Value* IrArray::EmitArrayElementAddress(const IrArray::Index& index,
333                                               llvm::IRBuilder<>* b,
334                                               absl::string_view name,
335                                               bool use_linear_index) const {
336   if (ShapeUtil::IsScalar(shape_)) {
337     // Special handling of scalars: a scalar pretends to have the same value for
338     // every index, thus effectively implementing broadcasting of its value
339     // over higher-rank arrays.
340     return base_ptr_;
341   }
342   CHECK_EQ(index.size(), shape_.rank());
343 
344   if (use_linear_index && index.LinearValidOnShape(shape_)) {
345     llvm::Module* module = b->GetInsertBlock()->getParent()->getParent();
346     return b->CreateInBoundsGEP(
347         b->CreateBitCast(base_ptr_,
348                          PrimitiveTypeToIrType(shape_.element_type(), module)
349                              ->getPointerTo()),
350         {index.linear()}, llvm_ir::AsStringRef(name));
351   }
352 
353   std::vector<llvm::Value*> actual_index;
354   for (int64 i = 0; i < index.size(); ++i) {
355     // When dimension i is of size 1, LLVM optimization is able to replace
356     // index[i] with 0. However, setting index[i] to 0 here still allows LLVM to
357     // produce better code in some cases.
358     auto dim = shape_.dimensions(i);
359     actual_index.push_back(
360         dim == 1 ? llvm::ConstantInt::get(index[i]->getType(), 0) : index[i]);
361   }
362 
363   // "base_ptr_" has the type of "<ir_type_for_its_shape>*"
364   // (e.g. [3 x [2 x float]]*). Therefore, the address of the indexed element
365   // should be computed by
366   //
367   //   getelementptr base_ptr_, 0, most major index, ..., most minor index
368   CHECK_GT(index.size(), 0);
369   std::vector<llvm::Value*> gep_indices(
370       1, llvm::ConstantInt::get(index[0]->getType(), 0));
371   for (int64 i = 0; i < LayoutUtil::MinorToMajor(shape_).size(); ++i) {
372     int64 dimension = LayoutUtil::Major(shape_.layout(), i);
373     gep_indices.push_back(actual_index[dimension]);
374   }
375   return b->CreateInBoundsGEP(base_ptr_, gep_indices,
376                               llvm_ir::AsStringRef(name));
377 }
378 
AnnotateLoadStoreInstructionWithMetadata(llvm::Instruction * instruction) const379 void IrArray::AnnotateLoadStoreInstructionWithMetadata(
380     llvm::Instruction* instruction) const {
381   CHECK(llvm::isa<llvm::LoadInst>(instruction) ||
382         llvm::isa<llvm::StoreInst>(instruction));
383   CHECK(!llvm::isa<llvm::StoreInst>(instruction) || !is_invariant_)
384       << "Trying to create a store to an invariant IRArray.";
385 
386   for (const auto& kind_md_pair : metadata_) {
387     instruction->setMetadata(kind_md_pair.first, kind_md_pair.second);
388   }
389 }
390 
EmitReadArrayElement(const Index & index,llvm::IRBuilder<> * b,absl::string_view name,bool use_linear_index) const391 llvm::Value* IrArray::EmitReadArrayElement(const Index& index,
392                                            llvm::IRBuilder<>* b,
393                                            absl::string_view name,
394                                            bool use_linear_index) const {
395   llvm::Value* element_address =
396       EmitArrayElementAddress(index, b, name, use_linear_index);
397   llvm::LoadInst* load = b->CreateLoad(element_address);
398   AnnotateLoadStoreInstructionWithMetadata(load);
399   return load;
400 }
401 
EmitWriteArrayElement(const Index & index,llvm::Value * value,llvm::IRBuilder<> * b,bool use_linear_index) const402 void IrArray::EmitWriteArrayElement(const Index& index, llvm::Value* value,
403                                     llvm::IRBuilder<>* b,
404                                     bool use_linear_index) const {
405   llvm::Value* element_address =
406       EmitArrayElementAddress(index, b, "", use_linear_index);
407   llvm::StoreInst* store = b->CreateStore(value, element_address);
408   AnnotateLoadStoreInstructionWithMetadata(store);
409 }
410 
CastToShape(const Shape & new_shape,llvm::IRBuilder<> * b) const411 IrArray IrArray::CastToShape(const Shape& new_shape,
412                              llvm::IRBuilder<>* b) const {
413   llvm::Module* module = b->GetInsertBlock()->getParent()->getParent();
414   llvm::Type* new_ir_type = llvm_ir::ShapeToIrType(new_shape, module);
415   IrArray new_irarray(
416       b->CreatePointerCast(base_ptr_, new_ir_type->getPointerTo()), new_shape);
417   new_irarray.metadata_ = metadata_;
418   return new_irarray;
419 }
420 
421 }  // namespace llvm_ir
422 }  // namespace xla
423