1 /* 2 Copyright (c) 2011, Intel Corporation. All rights reserved. 3 4 Redistribution and use in source and binary forms, with or without modification, 5 are permitted provided that the following conditions are met: 6 7 * Redistributions of source code must retain the above copyright notice, this 8 list of conditions and the following disclaimer. 9 * Redistributions in binary form must reproduce the above copyright notice, 10 this list of conditions and the following disclaimer in the documentation 11 and/or other materials provided with the distribution. 12 * Neither the name of Intel Corporation nor the names of its contributors may 13 be used to endorse or promote products derived from this software without 14 specific prior written permission. 15 16 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 17 ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 18 WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 19 DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 20 ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 21 (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 22 LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 23 ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 24 (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 25 SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 27 ******************************************************************************** 28 * Content : Eigen bindings to BLAS F77 29 * Triangular matrix * matrix product functionality based on ?TRMM. 30 ******************************************************************************** 31 */ 32 33 #ifndef EIGEN_TRIANGULAR_MATRIX_MATRIX_BLAS_H 34 #define EIGEN_TRIANGULAR_MATRIX_MATRIX_BLAS_H 35 36 namespace Eigen { 37 38 namespace internal { 39 40 41 template <typename Scalar, typename Index, 42 int Mode, bool LhsIsTriangular, 43 int LhsStorageOrder, bool ConjugateLhs, 44 int RhsStorageOrder, bool ConjugateRhs, 45 int ResStorageOrder> 46 struct product_triangular_matrix_matrix_trmm : 47 product_triangular_matrix_matrix<Scalar,Index,Mode, 48 LhsIsTriangular,LhsStorageOrder,ConjugateLhs, 49 RhsStorageOrder, ConjugateRhs, ResStorageOrder, BuiltIn> {}; 50 51 52 // try to go to BLAS specialization 53 #define EIGEN_BLAS_TRMM_SPECIALIZE(Scalar, LhsIsTriangular) \ 54 template <typename Index, int Mode, \ 55 int LhsStorageOrder, bool ConjugateLhs, \ 56 int RhsStorageOrder, bool ConjugateRhs> \ 57 struct product_triangular_matrix_matrix<Scalar,Index, Mode, LhsIsTriangular, \ 58 LhsStorageOrder,ConjugateLhs, RhsStorageOrder,ConjugateRhs,ColMajor,Specialized> { \ 59 static inline void run(Index _rows, Index _cols, Index _depth, const Scalar* _lhs, Index lhsStride,\ 60 const Scalar* _rhs, Index rhsStride, Scalar* res, Index resStride, Scalar alpha, level3_blocking<Scalar,Scalar>& blocking) { \ 61 product_triangular_matrix_matrix_trmm<Scalar,Index,Mode, \ 62 LhsIsTriangular,LhsStorageOrder,ConjugateLhs, \ 63 RhsStorageOrder, ConjugateRhs, ColMajor>::run( \ 64 _rows, _cols, _depth, _lhs, lhsStride, _rhs, rhsStride, res, resStride, alpha, blocking); \ 65 } \ 66 }; 67 68 EIGEN_BLAS_TRMM_SPECIALIZE(double, true) 69 EIGEN_BLAS_TRMM_SPECIALIZE(double, false) 70 EIGEN_BLAS_TRMM_SPECIALIZE(dcomplex, true) 71 EIGEN_BLAS_TRMM_SPECIALIZE(dcomplex, false) 72 EIGEN_BLAS_TRMM_SPECIALIZE(float, true) 73 EIGEN_BLAS_TRMM_SPECIALIZE(float, false) 74 EIGEN_BLAS_TRMM_SPECIALIZE(scomplex, true) 75 EIGEN_BLAS_TRMM_SPECIALIZE(scomplex, false) 76 77 // implements col-major += alpha * op(triangular) * op(general) 78 #define EIGEN_BLAS_TRMM_L(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX) \ 79 template <typename Index, int Mode, \ 80 int LhsStorageOrder, bool ConjugateLhs, \ 81 int RhsStorageOrder, bool ConjugateRhs> \ 82 struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,true, \ 83 LhsStorageOrder,ConjugateLhs,RhsStorageOrder,ConjugateRhs,ColMajor> \ 84 { \ 85 enum { \ 86 IsLower = (Mode&Lower) == Lower, \ 87 SetDiag = (Mode&(ZeroDiag|UnitDiag)) ? 0 : 1, \ 88 IsUnitDiag = (Mode&UnitDiag) ? 1 : 0, \ 89 IsZeroDiag = (Mode&ZeroDiag) ? 1 : 0, \ 90 LowUp = IsLower ? Lower : Upper, \ 91 conjA = ((LhsStorageOrder==ColMajor) && ConjugateLhs) ? 1 : 0 \ 92 }; \ 93 \ 94 static void run( \ 95 Index _rows, Index _cols, Index _depth, \ 96 const EIGTYPE* _lhs, Index lhsStride, \ 97 const EIGTYPE* _rhs, Index rhsStride, \ 98 EIGTYPE* res, Index resStride, \ 99 EIGTYPE alpha, level3_blocking<EIGTYPE,EIGTYPE>& blocking) \ 100 { \ 101 Index diagSize = (std::min)(_rows,_depth); \ 102 Index rows = IsLower ? _rows : diagSize; \ 103 Index depth = IsLower ? diagSize : _depth; \ 104 Index cols = _cols; \ 105 \ 106 typedef Matrix<EIGTYPE, Dynamic, Dynamic, LhsStorageOrder> MatrixLhs; \ 107 typedef Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder> MatrixRhs; \ 108 \ 109 /* Non-square case - doesn't fit to BLAS ?TRMM. Fall to default triangular product or call BLAS ?GEMM*/ \ 110 if (rows != depth) { \ 111 \ 112 /* FIXME handle mkl_domain_get_max_threads */ \ 113 /*int nthr = mkl_domain_get_max_threads(EIGEN_BLAS_DOMAIN_BLAS);*/ int nthr = 1;\ 114 \ 115 if (((nthr==1) && (((std::max)(rows,depth)-diagSize)/(double)diagSize < 0.5))) { \ 116 /* Most likely no benefit to call TRMM or GEMM from BLAS */ \ 117 product_triangular_matrix_matrix<EIGTYPE,Index,Mode,true, \ 118 LhsStorageOrder,ConjugateLhs, RhsStorageOrder, ConjugateRhs, ColMajor, BuiltIn>::run( \ 119 _rows, _cols, _depth, _lhs, lhsStride, _rhs, rhsStride, res, resStride, alpha, blocking); \ 120 /*std::cout << "TRMM_L: A is not square! Go to Eigen TRMM implementation!\n";*/ \ 121 } else { \ 122 /* Make sense to call GEMM */ \ 123 Map<const MatrixLhs, 0, OuterStride<> > lhsMap(_lhs,rows,depth,OuterStride<>(lhsStride)); \ 124 MatrixLhs aa_tmp=lhsMap.template triangularView<Mode>(); \ 125 BlasIndex aStride = convert_index<BlasIndex>(aa_tmp.outerStride()); \ 126 gemm_blocking_space<ColMajor,EIGTYPE,EIGTYPE,Dynamic,Dynamic,Dynamic> gemm_blocking(_rows,_cols,_depth, 1, true); \ 127 general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor>::run( \ 128 rows, cols, depth, aa_tmp.data(), aStride, _rhs, rhsStride, res, resStride, alpha, gemm_blocking, 0); \ 129 \ 130 /*std::cout << "TRMM_L: A is not square! Go to BLAS GEMM implementation! " << nthr<<" \n";*/ \ 131 } \ 132 return; \ 133 } \ 134 char side = 'L', transa, uplo, diag = 'N'; \ 135 EIGTYPE *b; \ 136 const EIGTYPE *a; \ 137 BlasIndex m, n, lda, ldb; \ 138 \ 139 /* Set m, n */ \ 140 m = convert_index<BlasIndex>(diagSize); \ 141 n = convert_index<BlasIndex>(cols); \ 142 \ 143 /* Set trans */ \ 144 transa = (LhsStorageOrder==RowMajor) ? ((ConjugateLhs) ? 'C' : 'T') : 'N'; \ 145 \ 146 /* Set b, ldb */ \ 147 Map<const MatrixRhs, 0, OuterStride<> > rhs(_rhs,depth,cols,OuterStride<>(rhsStride)); \ 148 MatrixX##EIGPREFIX b_tmp; \ 149 \ 150 if (ConjugateRhs) b_tmp = rhs.conjugate(); else b_tmp = rhs; \ 151 b = b_tmp.data(); \ 152 ldb = convert_index<BlasIndex>(b_tmp.outerStride()); \ 153 \ 154 /* Set uplo */ \ 155 uplo = IsLower ? 'L' : 'U'; \ 156 if (LhsStorageOrder==RowMajor) uplo = (uplo == 'L') ? 'U' : 'L'; \ 157 /* Set a, lda */ \ 158 Map<const MatrixLhs, 0, OuterStride<> > lhs(_lhs,rows,depth,OuterStride<>(lhsStride)); \ 159 MatrixLhs a_tmp; \ 160 \ 161 if ((conjA!=0) || (SetDiag==0)) { \ 162 if (conjA) a_tmp = lhs.conjugate(); else a_tmp = lhs; \ 163 if (IsZeroDiag) \ 164 a_tmp.diagonal().setZero(); \ 165 else if (IsUnitDiag) \ 166 a_tmp.diagonal().setOnes();\ 167 a = a_tmp.data(); \ 168 lda = convert_index<BlasIndex>(a_tmp.outerStride()); \ 169 } else { \ 170 a = _lhs; \ 171 lda = convert_index<BlasIndex>(lhsStride); \ 172 } \ 173 /*std::cout << "TRMM_L: A is square! Go to BLAS TRMM implementation! \n";*/ \ 174 /* call ?trmm*/ \ 175 BLASPREFIX##trmm_(&side, &uplo, &transa, &diag, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (BLASTYPE*)b, &ldb); \ 176 \ 177 /* Add op(a_triangular)*b into res*/ \ 178 Map<MatrixX##EIGPREFIX, 0, OuterStride<> > res_tmp(res,rows,cols,OuterStride<>(resStride)); \ 179 res_tmp=res_tmp+b_tmp; \ 180 } \ 181 }; 182 183 EIGEN_BLAS_TRMM_L(double, double, d, d) 184 EIGEN_BLAS_TRMM_L(dcomplex, double, cd, z) 185 EIGEN_BLAS_TRMM_L(float, float, f, s) 186 EIGEN_BLAS_TRMM_L(scomplex, float, cf, c) 187 188 // implements col-major += alpha * op(general) * op(triangular) 189 #define EIGEN_BLAS_TRMM_R(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX) \ 190 template <typename Index, int Mode, \ 191 int LhsStorageOrder, bool ConjugateLhs, \ 192 int RhsStorageOrder, bool ConjugateRhs> \ 193 struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,false, \ 194 LhsStorageOrder,ConjugateLhs,RhsStorageOrder,ConjugateRhs,ColMajor> \ 195 { \ 196 enum { \ 197 IsLower = (Mode&Lower) == Lower, \ 198 SetDiag = (Mode&(ZeroDiag|UnitDiag)) ? 0 : 1, \ 199 IsUnitDiag = (Mode&UnitDiag) ? 1 : 0, \ 200 IsZeroDiag = (Mode&ZeroDiag) ? 1 : 0, \ 201 LowUp = IsLower ? Lower : Upper, \ 202 conjA = ((RhsStorageOrder==ColMajor) && ConjugateRhs) ? 1 : 0 \ 203 }; \ 204 \ 205 static void run( \ 206 Index _rows, Index _cols, Index _depth, \ 207 const EIGTYPE* _lhs, Index lhsStride, \ 208 const EIGTYPE* _rhs, Index rhsStride, \ 209 EIGTYPE* res, Index resStride, \ 210 EIGTYPE alpha, level3_blocking<EIGTYPE,EIGTYPE>& blocking) \ 211 { \ 212 Index diagSize = (std::min)(_cols,_depth); \ 213 Index rows = _rows; \ 214 Index depth = IsLower ? _depth : diagSize; \ 215 Index cols = IsLower ? diagSize : _cols; \ 216 \ 217 typedef Matrix<EIGTYPE, Dynamic, Dynamic, LhsStorageOrder> MatrixLhs; \ 218 typedef Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder> MatrixRhs; \ 219 \ 220 /* Non-square case - doesn't fit to BLAS ?TRMM. Fall to default triangular product or call BLAS ?GEMM*/ \ 221 if (cols != depth) { \ 222 \ 223 int nthr = 1 /*mkl_domain_get_max_threads(EIGEN_BLAS_DOMAIN_BLAS)*/; \ 224 \ 225 if ((nthr==1) && (((std::max)(cols,depth)-diagSize)/(double)diagSize < 0.5)) { \ 226 /* Most likely no benefit to call TRMM or GEMM from BLAS*/ \ 227 product_triangular_matrix_matrix<EIGTYPE,Index,Mode,false, \ 228 LhsStorageOrder,ConjugateLhs, RhsStorageOrder, ConjugateRhs, ColMajor, BuiltIn>::run( \ 229 _rows, _cols, _depth, _lhs, lhsStride, _rhs, rhsStride, res, resStride, alpha, blocking); \ 230 /*std::cout << "TRMM_R: A is not square! Go to Eigen TRMM implementation!\n";*/ \ 231 } else { \ 232 /* Make sense to call GEMM */ \ 233 Map<const MatrixRhs, 0, OuterStride<> > rhsMap(_rhs,depth,cols, OuterStride<>(rhsStride)); \ 234 MatrixRhs aa_tmp=rhsMap.template triangularView<Mode>(); \ 235 BlasIndex aStride = convert_index<BlasIndex>(aa_tmp.outerStride()); \ 236 gemm_blocking_space<ColMajor,EIGTYPE,EIGTYPE,Dynamic,Dynamic,Dynamic> gemm_blocking(_rows,_cols,_depth, 1, true); \ 237 general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor>::run( \ 238 rows, cols, depth, _lhs, lhsStride, aa_tmp.data(), aStride, res, resStride, alpha, gemm_blocking, 0); \ 239 \ 240 /*std::cout << "TRMM_R: A is not square! Go to BLAS GEMM implementation! " << nthr<<" \n";*/ \ 241 } \ 242 return; \ 243 } \ 244 char side = 'R', transa, uplo, diag = 'N'; \ 245 EIGTYPE *b; \ 246 const EIGTYPE *a; \ 247 BlasIndex m, n, lda, ldb; \ 248 \ 249 /* Set m, n */ \ 250 m = convert_index<BlasIndex>(rows); \ 251 n = convert_index<BlasIndex>(diagSize); \ 252 \ 253 /* Set trans */ \ 254 transa = (RhsStorageOrder==RowMajor) ? ((ConjugateRhs) ? 'C' : 'T') : 'N'; \ 255 \ 256 /* Set b, ldb */ \ 257 Map<const MatrixLhs, 0, OuterStride<> > lhs(_lhs,rows,depth,OuterStride<>(lhsStride)); \ 258 MatrixX##EIGPREFIX b_tmp; \ 259 \ 260 if (ConjugateLhs) b_tmp = lhs.conjugate(); else b_tmp = lhs; \ 261 b = b_tmp.data(); \ 262 ldb = convert_index<BlasIndex>(b_tmp.outerStride()); \ 263 \ 264 /* Set uplo */ \ 265 uplo = IsLower ? 'L' : 'U'; \ 266 if (RhsStorageOrder==RowMajor) uplo = (uplo == 'L') ? 'U' : 'L'; \ 267 /* Set a, lda */ \ 268 Map<const MatrixRhs, 0, OuterStride<> > rhs(_rhs,depth,cols, OuterStride<>(rhsStride)); \ 269 MatrixRhs a_tmp; \ 270 \ 271 if ((conjA!=0) || (SetDiag==0)) { \ 272 if (conjA) a_tmp = rhs.conjugate(); else a_tmp = rhs; \ 273 if (IsZeroDiag) \ 274 a_tmp.diagonal().setZero(); \ 275 else if (IsUnitDiag) \ 276 a_tmp.diagonal().setOnes();\ 277 a = a_tmp.data(); \ 278 lda = convert_index<BlasIndex>(a_tmp.outerStride()); \ 279 } else { \ 280 a = _rhs; \ 281 lda = convert_index<BlasIndex>(rhsStride); \ 282 } \ 283 /*std::cout << "TRMM_R: A is square! Go to BLAS TRMM implementation! \n";*/ \ 284 /* call ?trmm*/ \ 285 BLASPREFIX##trmm_(&side, &uplo, &transa, &diag, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (BLASTYPE*)b, &ldb); \ 286 \ 287 /* Add op(a_triangular)*b into res*/ \ 288 Map<MatrixX##EIGPREFIX, 0, OuterStride<> > res_tmp(res,rows,cols,OuterStride<>(resStride)); \ 289 res_tmp=res_tmp+b_tmp; \ 290 } \ 291 }; 292 293 EIGEN_BLAS_TRMM_R(double, double, d, d) 294 EIGEN_BLAS_TRMM_R(dcomplex, double, cd, z) 295 EIGEN_BLAS_TRMM_R(float, float, f, s) 296 EIGEN_BLAS_TRMM_R(scomplex, float, cf, c) 297 298 } // end namespace internal 299 300 } // end namespace Eigen 301 302 #endif // EIGEN_TRIANGULAR_MATRIX_MATRIX_BLAS_H 303