1 //===- llvm/MatrixBuilder.h - Builder to lower matrix ops -------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the MatrixBuilder class, which is used as a convenient way
10 // to lower matrix operations to LLVM IR.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef LLVM_IR_MATRIXBUILDER_H
15 #define LLVM_IR_MATRIXBUILDER_H
16 
17 #include "llvm/IR/Constant.h"
18 #include "llvm/IR/Constants.h"
19 #include "llvm/IR/IRBuilder.h"
20 #include "llvm/IR/InstrTypes.h"
21 #include "llvm/IR/Instruction.h"
22 #include "llvm/IR/IntrinsicInst.h"
23 #include "llvm/IR/Type.h"
24 #include "llvm/IR/Value.h"
25 #include "llvm/Support/Alignment.h"
26 
27 namespace llvm {
28 
29 class Function;
30 class Twine;
31 class Module;
32 
33 template <class IRBuilderTy> class MatrixBuilder {
34   IRBuilderTy &B;
getModule()35   Module *getModule() { return B.GetInsertBlock()->getParent()->getParent(); }
36 
splatScalarOperandIfNeeded(Value * LHS,Value * RHS)37   std::pair<Value *, Value *> splatScalarOperandIfNeeded(Value *LHS,
38                                                          Value *RHS) {
39     assert((LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy()) &&
40            "One of the operands must be a matrix (embedded in a vector)");
41     if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) {
42       assert(!isa<ScalableVectorType>(LHS->getType()) &&
43              "LHS Assumed to be fixed width");
44       RHS = B.CreateVectorSplat(
45           cast<VectorType>(LHS->getType())->getElementCount(), RHS,
46           "scalar.splat");
47     } else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) {
48       assert(!isa<ScalableVectorType>(RHS->getType()) &&
49              "RHS Assumed to be fixed width");
50       LHS = B.CreateVectorSplat(
51           cast<VectorType>(RHS->getType())->getElementCount(), LHS,
52           "scalar.splat");
53     }
54     return {LHS, RHS};
55   }
56 
57 public:
MatrixBuilder(IRBuilderTy & Builder)58   MatrixBuilder(IRBuilderTy &Builder) : B(Builder) {}
59 
60   /// Create a column major, strided matrix load.
61   /// \p DataPtr - Start address of the matrix read
62   /// \p Rows    - Number of rows in matrix (must be a constant)
63   /// \p Columns - Number of columns in matrix (must be a constant)
64   /// \p Stride  - Space between columns
65   CallInst *CreateColumnMajorLoad(Value *DataPtr, Align Alignment,
66                                   Value *Stride, bool IsVolatile, unsigned Rows,
67                                   unsigned Columns, const Twine &Name = "") {
68 
69     // Deal with the pointer
70     PointerType *PtrTy = cast<PointerType>(DataPtr->getType());
71     Type *EltTy = PtrTy->getElementType();
72 
73     auto *RetType = FixedVectorType::get(EltTy, Rows * Columns);
74 
75     Value *Ops[] = {DataPtr, Stride, B.getInt1(IsVolatile), B.getInt32(Rows),
76                     B.getInt32(Columns)};
77     Type *OverloadedTypes[] = {RetType};
78 
79     Function *TheFn = Intrinsic::getDeclaration(
80         getModule(), Intrinsic::matrix_column_major_load, OverloadedTypes);
81 
82     CallInst *Call = B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
83     Attribute AlignAttr =
84         Attribute::getWithAlignment(Call->getContext(), Alignment);
85     Call->addAttribute(1, AlignAttr);
86     return Call;
87   }
88 
89   /// Create a column major, strided matrix store.
90   /// \p Matrix  - Matrix to store
91   /// \p Ptr     - Pointer to write back to
92   /// \p Stride  - Space between columns
93   CallInst *CreateColumnMajorStore(Value *Matrix, Value *Ptr, Align Alignment,
94                                    Value *Stride, bool IsVolatile,
95                                    unsigned Rows, unsigned Columns,
96                                    const Twine &Name = "") {
97     Value *Ops[] = {Matrix,           Ptr,
98                     Stride,           B.getInt1(IsVolatile),
99                     B.getInt32(Rows), B.getInt32(Columns)};
100     Type *OverloadedTypes[] = {Matrix->getType()};
101 
102     Function *TheFn = Intrinsic::getDeclaration(
103         getModule(), Intrinsic::matrix_column_major_store, OverloadedTypes);
104 
105     CallInst *Call = B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
106     Attribute AlignAttr =
107         Attribute::getWithAlignment(Call->getContext(), Alignment);
108     Call->addAttribute(2, AlignAttr);
109     return Call;
110   }
111 
112   /// Create a llvm.matrix.transpose call, transposing \p Matrix with \p Rows
113   /// rows and \p Columns columns.
114   CallInst *CreateMatrixTranspose(Value *Matrix, unsigned Rows,
115                                   unsigned Columns, const Twine &Name = "") {
116     auto *OpType = cast<VectorType>(Matrix->getType());
117     auto *ReturnType =
118         FixedVectorType::get(OpType->getElementType(), Rows * Columns);
119 
120     Type *OverloadedTypes[] = {ReturnType};
121     Value *Ops[] = {Matrix, B.getInt32(Rows), B.getInt32(Columns)};
122     Function *TheFn = Intrinsic::getDeclaration(
123         getModule(), Intrinsic::matrix_transpose, OverloadedTypes);
124 
125     return B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
126   }
127 
128   /// Create a llvm.matrix.multiply call, multiplying matrixes \p LHS and \p
129   /// RHS.
130   CallInst *CreateMatrixMultiply(Value *LHS, Value *RHS, unsigned LHSRows,
131                                  unsigned LHSColumns, unsigned RHSColumns,
132                                  const Twine &Name = "") {
133     auto *LHSType = cast<VectorType>(LHS->getType());
134     auto *RHSType = cast<VectorType>(RHS->getType());
135 
136     auto *ReturnType =
137         FixedVectorType::get(LHSType->getElementType(), LHSRows * RHSColumns);
138 
139     Value *Ops[] = {LHS, RHS, B.getInt32(LHSRows), B.getInt32(LHSColumns),
140                     B.getInt32(RHSColumns)};
141     Type *OverloadedTypes[] = {ReturnType, LHSType, RHSType};
142 
143     Function *TheFn = Intrinsic::getDeclaration(
144         getModule(), Intrinsic::matrix_multiply, OverloadedTypes);
145     return B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
146   }
147 
148   /// Insert a single element \p NewVal into \p Matrix at indices (\p RowIdx, \p
149   /// ColumnIdx).
CreateMatrixInsert(Value * Matrix,Value * NewVal,Value * RowIdx,Value * ColumnIdx,unsigned NumRows)150   Value *CreateMatrixInsert(Value *Matrix, Value *NewVal, Value *RowIdx,
151                             Value *ColumnIdx, unsigned NumRows) {
152     return B.CreateInsertElement(
153         Matrix, NewVal,
154         B.CreateAdd(B.CreateMul(ColumnIdx, ConstantInt::get(
155                                                ColumnIdx->getType(), NumRows)),
156                     RowIdx));
157   }
158 
159   /// Add matrixes \p LHS and \p RHS. Support both integer and floating point
160   /// matrixes.
CreateAdd(Value * LHS,Value * RHS)161   Value *CreateAdd(Value *LHS, Value *RHS) {
162     assert(LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy());
163     if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) {
164       assert(!isa<ScalableVectorType>(LHS->getType()) &&
165              "LHS Assumed to be fixed width");
166       RHS = B.CreateVectorSplat(
167           cast<VectorType>(LHS->getType())->getElementCount(), RHS,
168           "scalar.splat");
169     } else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) {
170       assert(!isa<ScalableVectorType>(RHS->getType()) &&
171              "RHS Assumed to be fixed width");
172       LHS = B.CreateVectorSplat(
173           cast<VectorType>(RHS->getType())->getElementCount(), LHS,
174           "scalar.splat");
175     }
176 
177     return cast<VectorType>(LHS->getType())
178                    ->getElementType()
179                    ->isFloatingPointTy()
180                ? B.CreateFAdd(LHS, RHS)
181                : B.CreateAdd(LHS, RHS);
182   }
183 
184   /// Subtract matrixes \p LHS and \p RHS. Support both integer and floating
185   /// point matrixes.
CreateSub(Value * LHS,Value * RHS)186   Value *CreateSub(Value *LHS, Value *RHS) {
187     assert(LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy());
188     if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) {
189       assert(!isa<ScalableVectorType>(LHS->getType()) &&
190              "LHS Assumed to be fixed width");
191       RHS = B.CreateVectorSplat(
192           cast<VectorType>(LHS->getType())->getElementCount(), RHS,
193           "scalar.splat");
194     } else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) {
195       assert(!isa<ScalableVectorType>(RHS->getType()) &&
196              "RHS Assumed to be fixed width");
197       LHS = B.CreateVectorSplat(
198           cast<VectorType>(RHS->getType())->getElementCount(), LHS,
199           "scalar.splat");
200     }
201 
202     return cast<VectorType>(LHS->getType())
203                    ->getElementType()
204                    ->isFloatingPointTy()
205                ? B.CreateFSub(LHS, RHS)
206                : B.CreateSub(LHS, RHS);
207   }
208 
209   /// Multiply matrix \p LHS with scalar \p RHS or scalar \p LHS with matrix \p
210   /// RHS.
CreateScalarMultiply(Value * LHS,Value * RHS)211   Value *CreateScalarMultiply(Value *LHS, Value *RHS) {
212     std::tie(LHS, RHS) = splatScalarOperandIfNeeded(LHS, RHS);
213     if (LHS->getType()->getScalarType()->isFloatingPointTy())
214       return B.CreateFMul(LHS, RHS);
215     return B.CreateMul(LHS, RHS);
216   }
217 
218   /// Extracts the element at (\p RowIdx, \p ColumnIdx) from \p Matrix.
219   Value *CreateExtractElement(Value *Matrix, Value *RowIdx, Value *ColumnIdx,
220                               unsigned NumRows, Twine const &Name = "") {
221 
222     unsigned MaxWidth = std::max(RowIdx->getType()->getScalarSizeInBits(),
223                                  ColumnIdx->getType()->getScalarSizeInBits());
224     Type *IntTy = IntegerType::get(RowIdx->getType()->getContext(), MaxWidth);
225     RowIdx = B.CreateZExt(RowIdx, IntTy);
226     ColumnIdx = B.CreateZExt(ColumnIdx, IntTy);
227     Value *NumRowsV = B.getIntN(MaxWidth, NumRows);
228     return B.CreateExtractElement(
229         Matrix, B.CreateAdd(B.CreateMul(ColumnIdx, NumRowsV), RowIdx),
230         "matext");
231   }
232 };
233 
234 } // end namespace llvm
235 
236 #endif // LLVM_IR_MATRIXBUILDER_H
237