1 // This file is part of Eigen, a lightweight C++ template library 2 // for linear algebra. 3 // 4 // Copyright (C) 2009 Gael Guennebaud <gael.guennebaud@inria.fr> 5 // 6 // This Source Code Form is subject to the terms of the Mozilla 7 // Public License v. 2.0. If a copy of the MPL was not distributed 8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 9 10 #ifndef EIGEN_TRIANGULAR_SOLVER_MATRIX_H 11 #define EIGEN_TRIANGULAR_SOLVER_MATRIX_H 12 13 namespace Eigen { 14 15 namespace internal { 16 17 // if the rhs is row major, let's transpose the product 18 template <typename Scalar, typename Index, int Side, int Mode, bool Conjugate, int TriStorageOrder> 19 struct triangular_solve_matrix<Scalar,Index,Side,Mode,Conjugate,TriStorageOrder,RowMajor> 20 { 21 static void run( 22 Index size, Index cols, 23 const Scalar* tri, Index triStride, 24 Scalar* _other, Index otherStride, 25 level3_blocking<Scalar,Scalar>& blocking) 26 { 27 triangular_solve_matrix< 28 Scalar, Index, Side==OnTheLeft?OnTheRight:OnTheLeft, 29 (Mode&UnitDiag) | ((Mode&Upper) ? Lower : Upper), 30 NumTraits<Scalar>::IsComplex && Conjugate, 31 TriStorageOrder==RowMajor ? ColMajor : RowMajor, ColMajor> 32 ::run(size, cols, tri, triStride, _other, otherStride, blocking); 33 } 34 }; 35 36 /* Optimized triangular solver with multiple right hand side and the triangular matrix on the left 37 */ 38 template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder> 39 struct triangular_solve_matrix<Scalar,Index,OnTheLeft,Mode,Conjugate,TriStorageOrder,ColMajor> 40 { 41 static EIGEN_DONT_INLINE void run( 42 Index size, Index otherSize, 43 const Scalar* _tri, Index triStride, 44 Scalar* _other, Index otherStride, 45 level3_blocking<Scalar,Scalar>& blocking); 46 }; 47 template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder> 48 EIGEN_DONT_INLINE void triangular_solve_matrix<Scalar,Index,OnTheLeft,Mode,Conjugate,TriStorageOrder,ColMajor>::run( 49 Index size, Index otherSize, 50 const Scalar* _tri, Index triStride, 51 Scalar* _other, Index otherStride, 52 level3_blocking<Scalar,Scalar>& blocking) 53 { 54 Index cols = otherSize; 55 56 typedef const_blas_data_mapper<Scalar, Index, TriStorageOrder> TriMapper; 57 typedef blas_data_mapper<Scalar, Index, ColMajor> OtherMapper; 58 TriMapper tri(_tri, triStride); 59 OtherMapper other(_other, otherStride); 60 61 typedef gebp_traits<Scalar,Scalar> Traits; 62 63 enum { 64 SmallPanelWidth = EIGEN_PLAIN_ENUM_MAX(Traits::mr,Traits::nr), 65 IsLower = (Mode&Lower) == Lower 66 }; 67 68 Index kc = blocking.kc(); // cache block size along the K direction 69 Index mc = (std::min)(size,blocking.mc()); // cache block size along the M direction 70 71 std::size_t sizeA = kc*mc; 72 std::size_t sizeB = kc*cols; 73 74 ei_declare_aligned_stack_constructed_variable(Scalar, blockA, sizeA, blocking.blockA()); 75 ei_declare_aligned_stack_constructed_variable(Scalar, blockB, sizeB, blocking.blockB()); 76 77 conj_if<Conjugate> conj; 78 gebp_kernel<Scalar, Scalar, Index, OtherMapper, Traits::mr, Traits::nr, Conjugate, false> gebp_kernel; 79 gemm_pack_lhs<Scalar, Index, TriMapper, Traits::mr, Traits::LhsProgress, TriStorageOrder> pack_lhs; 80 gemm_pack_rhs<Scalar, Index, OtherMapper, Traits::nr, ColMajor, false, true> pack_rhs; 81 82 // the goal here is to subdivise the Rhs panels such that we keep some cache 83 // coherence when accessing the rhs elements 84 std::ptrdiff_t l1, l2, l3; 85 manage_caching_sizes(GetAction, &l1, &l2, &l3); 86 Index subcols = cols>0 ? l2/(4 * sizeof(Scalar) * std::max<Index>(otherStride,size)) : 0; 87 subcols = std::max<Index>((subcols/Traits::nr)*Traits::nr, Traits::nr); 88 89 for(Index k2=IsLower ? 0 : size; 90 IsLower ? k2<size : k2>0; 91 IsLower ? k2+=kc : k2-=kc) 92 { 93 const Index actual_kc = (std::min)(IsLower ? size-k2 : k2, kc); 94 95 // We have selected and packed a big horizontal panel R1 of rhs. Let B be the packed copy of this panel, 96 // and R2 the remaining part of rhs. The corresponding vertical panel of lhs is split into 97 // A11 (the triangular part) and A21 the remaining rectangular part. 98 // Then the high level algorithm is: 99 // - B = R1 => general block copy (done during the next step) 100 // - R1 = A11^-1 B => tricky part 101 // - update B from the new R1 => actually this has to be performed continuously during the above step 102 // - R2 -= A21 * B => GEPP 103 104 // The tricky part: compute R1 = A11^-1 B while updating B from R1 105 // The idea is to split A11 into multiple small vertical panels. 106 // Each panel can be split into a small triangular part T1k which is processed without optimization, 107 // and the remaining small part T2k which is processed using gebp with appropriate block strides 108 for(Index j2=0; j2<cols; j2+=subcols) 109 { 110 Index actual_cols = (std::min)(cols-j2,subcols); 111 // for each small vertical panels [T1k^T, T2k^T]^T of lhs 112 for (Index k1=0; k1<actual_kc; k1+=SmallPanelWidth) 113 { 114 Index actualPanelWidth = std::min<Index>(actual_kc-k1, SmallPanelWidth); 115 // tr solve 116 for (Index k=0; k<actualPanelWidth; ++k) 117 { 118 // TODO write a small kernel handling this (can be shared with trsv) 119 Index i = IsLower ? k2+k1+k : k2-k1-k-1; 120 Index rs = actualPanelWidth - k - 1; // remaining size 121 Index s = TriStorageOrder==RowMajor ? (IsLower ? k2+k1 : i+1) 122 : IsLower ? i+1 : i-rs; 123 124 Scalar a = (Mode & UnitDiag) ? Scalar(1) : Scalar(1)/conj(tri(i,i)); 125 for (Index j=j2; j<j2+actual_cols; ++j) 126 { 127 if (TriStorageOrder==RowMajor) 128 { 129 Scalar b(0); 130 const Scalar* l = &tri(i,s); 131 Scalar* r = &other(s,j); 132 for (Index i3=0; i3<k; ++i3) 133 b += conj(l[i3]) * r[i3]; 134 135 other(i,j) = (other(i,j) - b)*a; 136 } 137 else 138 { 139 Scalar b = (other(i,j) *= a); 140 Scalar* r = &other(s,j); 141 const Scalar* l = &tri(s,i); 142 for (Index i3=0;i3<rs;++i3) 143 r[i3] -= b * conj(l[i3]); 144 } 145 } 146 } 147 148 Index lengthTarget = actual_kc-k1-actualPanelWidth; 149 Index startBlock = IsLower ? k2+k1 : k2-k1-actualPanelWidth; 150 Index blockBOffset = IsLower ? k1 : lengthTarget; 151 152 // update the respective rows of B from other 153 pack_rhs(blockB+actual_kc*j2, other.getSubMapper(startBlock,j2), actualPanelWidth, actual_cols, actual_kc, blockBOffset); 154 155 // GEBP 156 if (lengthTarget>0) 157 { 158 Index startTarget = IsLower ? k2+k1+actualPanelWidth : k2-actual_kc; 159 160 pack_lhs(blockA, tri.getSubMapper(startTarget,startBlock), actualPanelWidth, lengthTarget); 161 162 gebp_kernel(other.getSubMapper(startTarget,j2), blockA, blockB+actual_kc*j2, lengthTarget, actualPanelWidth, actual_cols, Scalar(-1), 163 actualPanelWidth, actual_kc, 0, blockBOffset); 164 } 165 } 166 } 167 168 // R2 -= A21 * B => GEPP 169 { 170 Index start = IsLower ? k2+kc : 0; 171 Index end = IsLower ? size : k2-kc; 172 for(Index i2=start; i2<end; i2+=mc) 173 { 174 const Index actual_mc = (std::min)(mc,end-i2); 175 if (actual_mc>0) 176 { 177 pack_lhs(blockA, tri.getSubMapper(i2, IsLower ? k2 : k2-kc), actual_kc, actual_mc); 178 179 gebp_kernel(other.getSubMapper(i2, 0), blockA, blockB, actual_mc, actual_kc, cols, Scalar(-1), -1, -1, 0, 0); 180 } 181 } 182 } 183 } 184 } 185 186 /* Optimized triangular solver with multiple left hand sides and the triangular matrix on the right 187 */ 188 template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder> 189 struct triangular_solve_matrix<Scalar,Index,OnTheRight,Mode,Conjugate,TriStorageOrder,ColMajor> 190 { 191 static EIGEN_DONT_INLINE void run( 192 Index size, Index otherSize, 193 const Scalar* _tri, Index triStride, 194 Scalar* _other, Index otherStride, 195 level3_blocking<Scalar,Scalar>& blocking); 196 }; 197 template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder> 198 EIGEN_DONT_INLINE void triangular_solve_matrix<Scalar,Index,OnTheRight,Mode,Conjugate,TriStorageOrder,ColMajor>::run( 199 Index size, Index otherSize, 200 const Scalar* _tri, Index triStride, 201 Scalar* _other, Index otherStride, 202 level3_blocking<Scalar,Scalar>& blocking) 203 { 204 Index rows = otherSize; 205 typedef typename NumTraits<Scalar>::Real RealScalar; 206 207 typedef blas_data_mapper<Scalar, Index, ColMajor> LhsMapper; 208 typedef const_blas_data_mapper<Scalar, Index, TriStorageOrder> RhsMapper; 209 LhsMapper lhs(_other, otherStride); 210 RhsMapper rhs(_tri, triStride); 211 212 typedef gebp_traits<Scalar,Scalar> Traits; 213 enum { 214 RhsStorageOrder = TriStorageOrder, 215 SmallPanelWidth = EIGEN_PLAIN_ENUM_MAX(Traits::mr,Traits::nr), 216 IsLower = (Mode&Lower) == Lower 217 }; 218 219 Index kc = blocking.kc(); // cache block size along the K direction 220 Index mc = (std::min)(rows,blocking.mc()); // cache block size along the M direction 221 222 std::size_t sizeA = kc*mc; 223 std::size_t sizeB = kc*size; 224 225 ei_declare_aligned_stack_constructed_variable(Scalar, blockA, sizeA, blocking.blockA()); 226 ei_declare_aligned_stack_constructed_variable(Scalar, blockB, sizeB, blocking.blockB()); 227 228 conj_if<Conjugate> conj; 229 gebp_kernel<Scalar, Scalar, Index, LhsMapper, Traits::mr, Traits::nr, false, Conjugate> gebp_kernel; 230 gemm_pack_rhs<Scalar, Index, RhsMapper, Traits::nr, RhsStorageOrder> pack_rhs; 231 gemm_pack_rhs<Scalar, Index, RhsMapper, Traits::nr, RhsStorageOrder,false,true> pack_rhs_panel; 232 gemm_pack_lhs<Scalar, Index, LhsMapper, Traits::mr, Traits::LhsProgress, ColMajor, false, true> pack_lhs_panel; 233 234 for(Index k2=IsLower ? size : 0; 235 IsLower ? k2>0 : k2<size; 236 IsLower ? k2-=kc : k2+=kc) 237 { 238 const Index actual_kc = (std::min)(IsLower ? k2 : size-k2, kc); 239 Index actual_k2 = IsLower ? k2-actual_kc : k2 ; 240 241 Index startPanel = IsLower ? 0 : k2+actual_kc; 242 Index rs = IsLower ? actual_k2 : size - actual_k2 - actual_kc; 243 Scalar* geb = blockB+actual_kc*actual_kc; 244 245 if (rs>0) pack_rhs(geb, rhs.getSubMapper(actual_k2,startPanel), actual_kc, rs); 246 247 // triangular packing (we only pack the panels off the diagonal, 248 // neglecting the blocks overlapping the diagonal 249 { 250 for (Index j2=0; j2<actual_kc; j2+=SmallPanelWidth) 251 { 252 Index actualPanelWidth = std::min<Index>(actual_kc-j2, SmallPanelWidth); 253 Index actual_j2 = actual_k2 + j2; 254 Index panelOffset = IsLower ? j2+actualPanelWidth : 0; 255 Index panelLength = IsLower ? actual_kc-j2-actualPanelWidth : j2; 256 257 if (panelLength>0) 258 pack_rhs_panel(blockB+j2*actual_kc, 259 rhs.getSubMapper(actual_k2+panelOffset, actual_j2), 260 panelLength, actualPanelWidth, 261 actual_kc, panelOffset); 262 } 263 } 264 265 for(Index i2=0; i2<rows; i2+=mc) 266 { 267 const Index actual_mc = (std::min)(mc,rows-i2); 268 269 // triangular solver kernel 270 { 271 // for each small block of the diagonal (=> vertical panels of rhs) 272 for (Index j2 = IsLower 273 ? (actual_kc - ((actual_kc%SmallPanelWidth) ? Index(actual_kc%SmallPanelWidth) 274 : Index(SmallPanelWidth))) 275 : 0; 276 IsLower ? j2>=0 : j2<actual_kc; 277 IsLower ? j2-=SmallPanelWidth : j2+=SmallPanelWidth) 278 { 279 Index actualPanelWidth = std::min<Index>(actual_kc-j2, SmallPanelWidth); 280 Index absolute_j2 = actual_k2 + j2; 281 Index panelOffset = IsLower ? j2+actualPanelWidth : 0; 282 Index panelLength = IsLower ? actual_kc - j2 - actualPanelWidth : j2; 283 284 // GEBP 285 if(panelLength>0) 286 { 287 gebp_kernel(lhs.getSubMapper(i2,absolute_j2), 288 blockA, blockB+j2*actual_kc, 289 actual_mc, panelLength, actualPanelWidth, 290 Scalar(-1), 291 actual_kc, actual_kc, // strides 292 panelOffset, panelOffset); // offsets 293 } 294 295 // unblocked triangular solve 296 for (Index k=0; k<actualPanelWidth; ++k) 297 { 298 Index j = IsLower ? absolute_j2+actualPanelWidth-k-1 : absolute_j2+k; 299 300 Scalar* r = &lhs(i2,j); 301 for (Index k3=0; k3<k; ++k3) 302 { 303 Scalar b = conj(rhs(IsLower ? j+1+k3 : absolute_j2+k3,j)); 304 Scalar* a = &lhs(i2,IsLower ? j+1+k3 : absolute_j2+k3); 305 for (Index i=0; i<actual_mc; ++i) 306 r[i] -= a[i] * b; 307 } 308 if((Mode & UnitDiag)==0) 309 { 310 Scalar inv_rjj = RealScalar(1)/conj(rhs(j,j)); 311 for (Index i=0; i<actual_mc; ++i) 312 r[i] *= inv_rjj; 313 } 314 } 315 316 // pack the just computed part of lhs to A 317 pack_lhs_panel(blockA, LhsMapper(_other+absolute_j2*otherStride+i2, otherStride), 318 actualPanelWidth, actual_mc, 319 actual_kc, j2); 320 } 321 } 322 323 if (rs>0) 324 gebp_kernel(lhs.getSubMapper(i2, startPanel), blockA, geb, 325 actual_mc, actual_kc, rs, Scalar(-1), 326 -1, -1, 0, 0); 327 } 328 } 329 } 330 331 } // end namespace internal 332 333 } // end namespace Eigen 334 335 #endif // EIGEN_TRIANGULAR_SOLVER_MATRIX_H 336