1 /*
2  Copyright (c) 2011, Intel Corporation. All rights reserved.
3 
4  Redistribution and use in source and binary forms, with or without modification,
5  are permitted provided that the following conditions are met:
6 
7  * Redistributions of source code must retain the above copyright notice, this
8    list of conditions and the following disclaimer.
9  * Redistributions in binary form must reproduce the above copyright notice,
10    this list of conditions and the following disclaimer in the documentation
11    and/or other materials provided with the distribution.
12  * Neither the name of Intel Corporation nor the names of its contributors may
13    be used to endorse or promote products derived from this software without
14    specific prior written permission.
15 
16  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
17  ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
18  WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19  DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
20  ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
21  (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
22  LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
23  ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24  (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25  SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 
27  ********************************************************************************
28  *   Content : Eigen bindings to BLAS F77
29  *   Triangular matrix * matrix product functionality based on ?TRMM.
30  ********************************************************************************
31 */
32 
33 #ifndef EIGEN_TRIANGULAR_MATRIX_MATRIX_BLAS_H
34 #define EIGEN_TRIANGULAR_MATRIX_MATRIX_BLAS_H
35 
36 namespace Eigen {
37 
38 namespace internal {
39 
40 
41 template <typename Scalar, typename Index,
42           int Mode, bool LhsIsTriangular,
43           int LhsStorageOrder, bool ConjugateLhs,
44           int RhsStorageOrder, bool ConjugateRhs,
45           int ResStorageOrder>
46 struct product_triangular_matrix_matrix_trmm :
47        product_triangular_matrix_matrix<Scalar,Index,Mode,
48           LhsIsTriangular,LhsStorageOrder,ConjugateLhs,
49           RhsStorageOrder, ConjugateRhs, ResStorageOrder, BuiltIn> {};
50 
51 
// Redirect the "Specialized" column-major triangular * general product for a
// given scalar type to product_triangular_matrix_matrix_trmm, which holds the
// BLAS-backed implementation for the scalar types instantiated below.
// LhsIsTriangular selects whether the triangular factor is the left (true)
// or the right (false) operand.
#define EIGEN_BLAS_TRMM_SPECIALIZE(Scalar, LhsIsTriangular) \
template <typename Index, int Mode, \
          int LhsStorageOrder, bool ConjugateLhs, \
          int RhsStorageOrder, bool ConjugateRhs> \
struct product_triangular_matrix_matrix<Scalar, Index, Mode, LhsIsTriangular, \
                                        LhsStorageOrder, ConjugateLhs, \
                                        RhsStorageOrder, ConjugateRhs, \
                                        ColMajor, Specialized> \
{ \
  static inline void run(Index _rows, Index _cols, Index _depth, \
                         const Scalar* _lhs, Index lhsStride, \
                         const Scalar* _rhs, Index rhsStride, \
                         Scalar* res, Index resStride, \
                         Scalar alpha, level3_blocking<Scalar,Scalar>& blocking) \
  { \
    typedef product_triangular_matrix_matrix_trmm<Scalar, Index, Mode, LhsIsTriangular, \
                                                  LhsStorageOrder, ConjugateLhs, \
                                                  RhsStorageOrder, ConjugateRhs, \
                                                  ColMajor> BlasImpl; \
    BlasImpl::run(_rows, _cols, _depth, \
                  _lhs, lhsStride, \
                  _rhs, rhsStride, \
                  res, resStride, \
                  alpha, blocking); \
  } \
};
67 
68 EIGEN_BLAS_TRMM_SPECIALIZE(double, true)
69 EIGEN_BLAS_TRMM_SPECIALIZE(double, false)
70 EIGEN_BLAS_TRMM_SPECIALIZE(dcomplex, true)
71 EIGEN_BLAS_TRMM_SPECIALIZE(dcomplex, false)
72 EIGEN_BLAS_TRMM_SPECIALIZE(float, true)
73 EIGEN_BLAS_TRMM_SPECIALIZE(float, false)
74 EIGEN_BLAS_TRMM_SPECIALIZE(scomplex, true)
75 EIGEN_BLAS_TRMM_SPECIALIZE(scomplex, false)
76 
// Implements, for a column-major result:
//     res += alpha * op(triangular lhs) * op(general rhs)
// where op() transposes a row-major operand and/or conjugates it.
//
// A square triangular factor is handed to BLAS ?TRMM. It overwrites a
// temporary copy of the general operand with alpha*op(A)*B, and that copy is
// then accumulated into res. A non-square (trapezoidal) factor either falls
// back to Eigen's built-in kernel (when it is almost square) or is expanded
// into a dense temporary and multiplied with ?GEMM.
//
// Macro parameters:
//   EIGTYPE    - Eigen scalar type (float, double, scomplex, dcomplex)
//   BLASTYPE   - real type used in the BLAS prototype for that scalar
//   EIGPREFIX  - suffix of the matching MatrixX* typedef (f, d, cf, cd)
//   BLASPREFIX - BLAS routine prefix (s, d, c, z)
#define EIGEN_BLAS_TRMM_L(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX) \
template <typename Index, int Mode, \
          int LhsStorageOrder, bool ConjugateLhs, \
          int RhsStorageOrder, bool ConjugateRhs> \
struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,true, \
         LhsStorageOrder,ConjugateLhs,RhsStorageOrder,ConjugateRhs,ColMajor> \
{ \
  enum { \
    IsLower = (Mode&Lower) == Lower, \
    SetDiag = (Mode&(ZeroDiag|UnitDiag)) ? 0 : 1, \
    IsUnitDiag  = (Mode&UnitDiag) ? 1 : 0, \
    IsZeroDiag  = (Mode&ZeroDiag) ? 1 : 0, \
    LowUp = IsLower ? Lower : Upper, \
    conjA = ((LhsStorageOrder==ColMajor) && ConjugateLhs) ? 1 : 0 \
  }; \
\
  static void run( \
    Index _rows, Index _cols, Index _depth, \
    const EIGTYPE* _lhs, Index lhsStride, \
    const EIGTYPE* _rhs, Index rhsStride, \
    EIGTYPE* res,        Index resStride, \
    EIGTYPE alpha, level3_blocking<EIGTYPE,EIGTYPE>& blocking) \
  { \
   Index diagSize  = (std::min)(_rows,_depth); \
   Index rows      = IsLower ? _rows : diagSize; \
   Index depth     = IsLower ? diagSize : _depth; \
   Index cols      = _cols; \
\
   /* Empty product: there is nothing to add to res. Returning early also  */ \
   /* avoids dividing by diagSize==0 below and handing ?TRMM a leading     */ \
   /* dimension of 0, which BLAS rejects (it requires LDA/LDB >= 1).       */ \
   if (diagSize == 0 || cols == 0) return; \
\
   typedef Matrix<EIGTYPE, Dynamic, Dynamic, LhsStorageOrder> MatrixLhs; \
   typedef Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder> MatrixRhs; \
\
   /* Non-square (trapezoidal) factor: does not fit BLAS ?TRMM. Use either */ \
   /* the built-in triangular kernel or a dense ?GEMM on an expanded copy. */ \
   if (rows != depth) { \
\
     /* FIXME handle mkl_domain_get_max_threads */ \
     /*int nthr = mkl_domain_get_max_threads(EIGEN_BLAS_DOMAIN_BLAS);*/ int nthr = 1;\
\
     if (((nthr==1) && (((std::max)(rows,depth)-diagSize)/(double)diagSize < 0.5))) { \
       /* Nearly square: most likely no benefit in calling BLAS TRMM/GEMM. */ \
       product_triangular_matrix_matrix<EIGTYPE,Index,Mode,true, \
       LhsStorageOrder,ConjugateLhs, RhsStorageOrder, ConjugateRhs, ColMajor, BuiltIn>::run( \
           _rows, _cols, _depth, _lhs, lhsStride, _rhs, rhsStride, res, resStride, alpha, blocking); \
     } else { \
       /* Strongly trapezoidal: expand the triangular view into a dense */ \
       /* temporary and let the general product kernel do the work.     */ \
       Map<const MatrixLhs, 0, OuterStride<> > lhsMap(_lhs,rows,depth,OuterStride<>(lhsStride)); \
       MatrixLhs aa_tmp=lhsMap.template triangularView<Mode>(); \
       BlasIndex aStride = convert_index<BlasIndex>(aa_tmp.outerStride()); \
       gemm_blocking_space<ColMajor,EIGTYPE,EIGTYPE,Dynamic,Dynamic,Dynamic> gemm_blocking(_rows,_cols,_depth, 1, true); \
       general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor>::run( \
       rows, cols, depth, aa_tmp.data(), aStride, _rhs, rhsStride, res, resStride, alpha, gemm_blocking, 0); \
     } \
     return; \
   } \
\
   /* ?TRMM has a native implicit-unit-diagonal mode (the stored diagonal */ \
   /* is then not referenced), so UnitDiag needs no copy of A.            */ \
   char side = 'L', transa, uplo, diag = IsUnitDiag ? 'U' : 'N'; \
   EIGTYPE *b; \
   const EIGTYPE *a; \
   BlasIndex m, n, lda, ldb; \
\
   /* Set m, n */ \
   m = convert_index<BlasIndex>(diagSize); \
   n = convert_index<BlasIndex>(cols); \
\
   /* Set trans: BLAS sees a row-major lhs as its transpose; a conjugated */ \
   /* row-major lhs becomes the conjugate transpose.                      */ \
   transa = (LhsStorageOrder==RowMajor) ? ((ConjugateLhs) ? 'C' : 'T') : 'N'; \
\
   /* Set b, ldb: ?TRMM overwrites B in place, so work on a copy of rhs. */ \
   Map<const MatrixRhs, 0, OuterStride<> > rhs(_rhs,depth,cols,OuterStride<>(rhsStride)); \
   MatrixX##EIGPREFIX b_tmp; \
\
   if (ConjugateRhs) b_tmp = rhs.conjugate(); else b_tmp = rhs; \
   b = b_tmp.data(); \
   ldb = convert_index<BlasIndex>(b_tmp.outerStride()); \
\
   /* Set uplo: viewing a row-major lhs as a transpose swaps lower/upper. */ \
   uplo = IsLower ? 'L' : 'U'; \
   if (LhsStorageOrder==RowMajor) uplo = (uplo == 'L') ? 'U' : 'L'; \
\
   /* Set a, lda: a copy is only needed to conjugate a col-major lhs      */ \
   /* (BLAS has no conjugate-without-transpose option) or to zero the     */ \
   /* diagonal for ZeroDiag, which BLAS cannot express.                   */ \
   Map<const MatrixLhs, 0, OuterStride<> > lhs(_lhs,rows,depth,OuterStride<>(lhsStride)); \
   MatrixLhs a_tmp; \
\
   if ((conjA!=0) || (IsZeroDiag!=0)) { \
     if (conjA) a_tmp = lhs.conjugate(); else a_tmp = lhs; \
     if (IsZeroDiag) \
       a_tmp.diagonal().setZero(); \
     a = a_tmp.data(); \
     lda = convert_index<BlasIndex>(a_tmp.outerStride()); \
   } else { \
     a = _lhs; \
     lda = convert_index<BlasIndex>(lhsStride); \
   } \
\
   /* b := alpha * op(a) * b */ \
   BLASPREFIX##trmm_(&side, &uplo, &transa, &diag, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (BLASTYPE*)b, &ldb); \
\
   /* Accumulate alpha * op(a_triangular) * op(rhs), now held in b_tmp, into res */ \
   Map<MatrixX##EIGPREFIX, 0, OuterStride<> > res_tmp(res,rows,cols,OuterStride<>(resStride)); \
   res_tmp=res_tmp+b_tmp; \
  } \
};
182 
183 EIGEN_BLAS_TRMM_L(double, double, d, d)
184 EIGEN_BLAS_TRMM_L(dcomplex, double, cd, z)
185 EIGEN_BLAS_TRMM_L(float, float, f, s)
186 EIGEN_BLAS_TRMM_L(scomplex, float, cf, c)
187 
// Implements, for a column-major result:
//     res += alpha * op(general lhs) * op(triangular rhs)
// where op() transposes a row-major operand and/or conjugates it.
//
// A square triangular factor is handed to BLAS ?TRMM with side='R'. It
// overwrites a temporary copy of the general operand with alpha*B*op(A), and
// that copy is then accumulated into res. A non-square (trapezoidal) factor
// either falls back to Eigen's built-in kernel (when it is almost square) or
// is expanded into a dense temporary and multiplied with ?GEMM.
//
// Macro parameters:
//   EIGTYPE    - Eigen scalar type (float, double, scomplex, dcomplex)
//   BLASTYPE   - real type used in the BLAS prototype for that scalar
//   EIGPREFIX  - suffix of the matching MatrixX* typedef (f, d, cf, cd)
//   BLASPREFIX - BLAS routine prefix (s, d, c, z)
#define EIGEN_BLAS_TRMM_R(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX) \
template <typename Index, int Mode, \
          int LhsStorageOrder, bool ConjugateLhs, \
          int RhsStorageOrder, bool ConjugateRhs> \
struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,false, \
         LhsStorageOrder,ConjugateLhs,RhsStorageOrder,ConjugateRhs,ColMajor> \
{ \
  enum { \
    IsLower = (Mode&Lower) == Lower, \
    SetDiag = (Mode&(ZeroDiag|UnitDiag)) ? 0 : 1, \
    IsUnitDiag  = (Mode&UnitDiag) ? 1 : 0, \
    IsZeroDiag  = (Mode&ZeroDiag) ? 1 : 0, \
    LowUp = IsLower ? Lower : Upper, \
    conjA = ((RhsStorageOrder==ColMajor) && ConjugateRhs) ? 1 : 0 \
  }; \
\
  static void run( \
    Index _rows, Index _cols, Index _depth, \
    const EIGTYPE* _lhs, Index lhsStride, \
    const EIGTYPE* _rhs, Index rhsStride, \
    EIGTYPE* res,        Index resStride, \
    EIGTYPE alpha, level3_blocking<EIGTYPE,EIGTYPE>& blocking) \
  { \
   Index diagSize  = (std::min)(_cols,_depth); \
   Index rows      = _rows; \
   Index depth     = IsLower ? _depth : diagSize; \
   Index cols      = IsLower ? diagSize : _cols; \
\
   /* Empty product: there is nothing to add to res. Returning early also  */ \
   /* avoids dividing by diagSize==0 below and handing ?TRMM a leading     */ \
   /* dimension of 0, which BLAS rejects (it requires LDA/LDB >= 1).       */ \
   if (rows == 0 || diagSize == 0) return; \
\
   typedef Matrix<EIGTYPE, Dynamic, Dynamic, LhsStorageOrder> MatrixLhs; \
   typedef Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder> MatrixRhs; \
\
   /* Non-square (trapezoidal) factor: does not fit BLAS ?TRMM. Use either */ \
   /* the built-in triangular kernel or a dense ?GEMM on an expanded copy. */ \
   if (cols != depth) { \
\
     int nthr = 1 /*mkl_domain_get_max_threads(EIGEN_BLAS_DOMAIN_BLAS)*/; \
\
     if ((nthr==1) && (((std::max)(cols,depth)-diagSize)/(double)diagSize < 0.5)) { \
       /* Nearly square: most likely no benefit in calling BLAS TRMM/GEMM. */ \
       product_triangular_matrix_matrix<EIGTYPE,Index,Mode,false, \
       LhsStorageOrder,ConjugateLhs, RhsStorageOrder, ConjugateRhs, ColMajor, BuiltIn>::run( \
           _rows, _cols, _depth, _lhs, lhsStride, _rhs, rhsStride, res, resStride, alpha, blocking); \
     } else { \
       /* Strongly trapezoidal: expand the triangular view into a dense */ \
       /* temporary and let the general product kernel do the work.     */ \
       Map<const MatrixRhs, 0, OuterStride<> > rhsMap(_rhs,depth,cols, OuterStride<>(rhsStride)); \
       MatrixRhs aa_tmp=rhsMap.template triangularView<Mode>(); \
       BlasIndex aStride = convert_index<BlasIndex>(aa_tmp.outerStride()); \
       gemm_blocking_space<ColMajor,EIGTYPE,EIGTYPE,Dynamic,Dynamic,Dynamic> gemm_blocking(_rows,_cols,_depth, 1, true); \
       general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor>::run( \
       rows, cols, depth, _lhs, lhsStride, aa_tmp.data(), aStride, res, resStride, alpha, gemm_blocking, 0); \
     } \
     return; \
   } \
\
   /* ?TRMM has a native implicit-unit-diagonal mode (the stored diagonal */ \
   /* is then not referenced), so UnitDiag needs no copy of A.            */ \
   char side = 'R', transa, uplo, diag = IsUnitDiag ? 'U' : 'N'; \
   EIGTYPE *b; \
   const EIGTYPE *a; \
   BlasIndex m, n, lda, ldb; \
\
   /* Set m, n */ \
   m = convert_index<BlasIndex>(rows); \
   n = convert_index<BlasIndex>(diagSize); \
\
   /* Set trans: BLAS sees a row-major rhs as its transpose; a conjugated */ \
   /* row-major rhs becomes the conjugate transpose.                      */ \
   transa = (RhsStorageOrder==RowMajor) ? ((ConjugateRhs) ? 'C' : 'T') : 'N'; \
\
   /* Set b, ldb: ?TRMM overwrites B in place, so work on a copy of lhs. */ \
   Map<const MatrixLhs, 0, OuterStride<> > lhs(_lhs,rows,depth,OuterStride<>(lhsStride)); \
   MatrixX##EIGPREFIX b_tmp; \
\
   if (ConjugateLhs) b_tmp = lhs.conjugate(); else b_tmp = lhs; \
   b = b_tmp.data(); \
   ldb = convert_index<BlasIndex>(b_tmp.outerStride()); \
\
   /* Set uplo: viewing a row-major rhs as a transpose swaps lower/upper. */ \
   uplo = IsLower ? 'L' : 'U'; \
   if (RhsStorageOrder==RowMajor) uplo = (uplo == 'L') ? 'U' : 'L'; \
\
   /* Set a, lda: a copy is only needed to conjugate a col-major rhs      */ \
   /* (BLAS has no conjugate-without-transpose option) or to zero the     */ \
   /* diagonal for ZeroDiag, which BLAS cannot express.                   */ \
   Map<const MatrixRhs, 0, OuterStride<> > rhs(_rhs,depth,cols, OuterStride<>(rhsStride)); \
   MatrixRhs a_tmp; \
\
   if ((conjA!=0) || (IsZeroDiag!=0)) { \
     if (conjA) a_tmp = rhs.conjugate(); else a_tmp = rhs; \
     if (IsZeroDiag) \
       a_tmp.diagonal().setZero(); \
     a = a_tmp.data(); \
     lda = convert_index<BlasIndex>(a_tmp.outerStride()); \
   } else { \
     a = _rhs; \
     lda = convert_index<BlasIndex>(rhsStride); \
   } \
\
   /* b := alpha * b * op(a) */ \
   BLASPREFIX##trmm_(&side, &uplo, &transa, &diag, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (BLASTYPE*)b, &ldb); \
\
   /* Accumulate alpha * op(lhs) * op(a_triangular), now held in b_tmp, into res */ \
   Map<MatrixX##EIGPREFIX, 0, OuterStride<> > res_tmp(res,rows,cols,OuterStride<>(resStride)); \
   res_tmp=res_tmp+b_tmp; \
  } \
};
292 
293 EIGEN_BLAS_TRMM_R(double, double, d, d)
294 EIGEN_BLAS_TRMM_R(dcomplex, double, cd, z)
295 EIGEN_BLAS_TRMM_R(float, float, f, s)
296 EIGEN_BLAS_TRMM_R(scomplex, float, cf, c)
297 
298 } // end namespace internal
299 
300 } // end namespace Eigen
301 
302 #endif // EIGEN_TRIANGULAR_MATRIX_MATRIX_BLAS_H
303