// Copyright 2018 The Gemmlowp Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// kernel_msa.h: a collection of MSA optimized kernels.
// Check in kernel_default.h which one(s) are actually used by default.
// Others are mere experiments; they are still covered by tests
// in case they might be useful some day.

#ifndef GEMMLOWP_INTERNAL_KERNEL_MSA_H_
#define GEMMLOWP_INTERNAL_KERNEL_MSA_H_

#include "kernel.h"

#include <msa.h>
#include <cassert>

namespace gemmlowp {

#ifdef GEMMLOWP_MSA

// Some convenience macros to hide differences between MIPS32 and MIPS64.
//
// Each macro expands to an instruction mnemonic as a string literal, so that
// it can be concatenated with an operand string inside an inline asm block.
// They are used for arithmetic on pointer-sized registers (addresses, strides
// and the depth counter): the "d"-prefixed mnemonics are selected when
// GEMMLOWP_MIPS_64 is defined, the plain ones otherwise.
//   GEMMLOWP_MIPS_XADDU  - register + register add.
//   GEMMLOWP_MIPS_XADDIU - register + immediate add.
//   GEMMLOWP_MIPS_XSLL   - shift left by an immediate amount.
#ifdef GEMMLOWP_MIPS_64
#define GEMMLOWP_MIPS_XADDU "daddu"
#define GEMMLOWP_MIPS_XADDIU "daddiu"
#define GEMMLOWP_MIPS_XSLL "dsll"
#else
#define GEMMLOWP_MIPS_XADDU "addu"
#define GEMMLOWP_MIPS_XADDIU "addiu"
#define GEMMLOWP_MIPS_XSLL "sll"
#endif

// Our main GEMM kernel.
// MSA_Kernel12x8Depth2: an inline-assembly kernel that accumulates a 12x8
// block of 32-bit results from 8-bit inputs, consuming two levels of depth
// per loop iteration.
//
// The kernel format declares 3 cells of CellFormat<4, 2> on the first side
// and 2 cells of CellFormat<4, 2> on the second side; the comments inside the
// asm refer to these as Lhs (12 rows) and Rhs (8 columns) respectively.
struct MSA_Kernel12x8Depth2 : KernelBase {
  typedef KernelFormat<KernelSideFormat<CellFormat<4, 2>, 3>,
                       KernelSideFormat<CellFormat<4, 2>, 2> >
      Format;

  // Human-readable identifier of this kernel.
  const char* Name() const override { return "MSA, 12x8, depth 2"; }

  // Runs the kernel over `run_depth` levels of depth.
  //
  // Parameters, as used by the code below:
  //   dst_ptr        - base address of the 12x8 block of int32 accumulators.
  //                    Each column is read/written as three consecutive
  //                    16-byte vectors (12 int32 values).
  //   dst_row_stride - must be 1 (checked by assert only).
  //   dst_col_stride - distance between consecutive columns, in int32
  //                    elements; it is multiplied by 4 inside the asm to
  //                    obtain a byte offset.
  //   lhs_ptr        - Lhs data; 24 bytes are consumed per loop iteration.
  //   rhs_ptr        - Rhs data; 16 bytes are consumed per loop iteration.
  //   start_depth    - only tested against zero: if zero the accumulators
  //                    start cleared, otherwise they are first loaded from
  //                    dst_ptr.
  //   run_depth      - loop counter, decremented by 2 per iteration. The loop
  //                    body executes before the counter is tested and the
  //                    loop exits only when the counter is exactly zero, so
  //                    this is presumably required to be a positive multiple
  //                    of 2 -- TODO confirm against the caller.
  //
  // TODO(benoitjacob): reorder function arguments so dst comes last
  void Run(std::int32_t* dst_ptr, std::size_t dst_row_stride,
           std::size_t dst_col_stride, const std::uint8_t* lhs_ptr,
           const std::uint8_t* rhs_ptr, std::size_t start_depth,
           std::size_t run_depth) const override {
    ScopedProfilingLabel label("optimized kernel (MSA 12x8)");
// See comments above for why we need local numerical labels in our asm.
// NOTE(review): the comments referred to are not present in this file.
// The labels below are numeric local labels; branches reference them with an
// "f" (forward) or "b" (backward) suffix. GEMMLOWP_LABEL_AFTER_LOOP is
// defined and emitted but no branch in this block targets it.
#define GEMMLOWP_LABEL_CLEAR_ACCUMULATORS "1"
#define GEMMLOWP_LABEL_BEFORE_LOOP "2"
#define GEMMLOWP_LABEL_LOOP "3"
#define GEMMLOWP_LABEL_AFTER_LOOP "4"

    // The asm addresses each destination column as contiguous memory.
    assert(dst_row_stride == 1);
    asm volatile(
        // Set a temp to all zeroes.
        // ($w31 is used below as the zero operand of the interleaves.)
        "ldi.b $w31, 0\n"

        // Multiply dst_col_stride by 4 == sizeof(int32) to use
        // it as a byte offset below.
        GEMMLOWP_MIPS_XSLL
        " %[dst_col_stride], %[dst_col_stride], 2\n"

        // Check if start_depth==0 to decide whether we will clear
        // accumulators or load existing accumulators.
        "beqz %[start_depth], " GEMMLOWP_LABEL_CLEAR_ACCUMULATORS "f\n"

        // Load accumulators (start_depth != 0).
        // $a0 and $a1 alternate as the address of the next column, each
        // computed by adding the byte stride to the previous column address.
        // Columns 0-3 go to w0-w11, columns 4-7 to w12-w23.
        GEMMLOWP_MIPS_XADDU
        " $a0, %[dst_ptr], %[dst_col_stride]\n"
        "ld.w $w0, (0*16)(%[dst_ptr])\n"
        "ld.w $w4, (1*16)(%[dst_ptr])\n"
        "ld.w $w8, (2*16)(%[dst_ptr])\n"
        GEMMLOWP_MIPS_XADDU
        " $a1, $a0, %[dst_col_stride]\n"
        "ld.w $w1, (0*16)($a0)\n"
        "ld.w $w5, (1*16)($a0)\n"
        "ld.w $w9, (2*16)($a0)\n"
        GEMMLOWP_MIPS_XADDU
        " $a0, $a1, %[dst_col_stride]\n"
        "ld.w $w2, (0*16)($a1)\n"
        "ld.w $w6, (1*16)($a1)\n"
        "ld.w $w10, (2*16)($a1)\n"
        GEMMLOWP_MIPS_XADDU
        " $a1, $a0, %[dst_col_stride]\n"
        "ld.w $w3, (0*16)($a0)\n"
        "ld.w $w7, (1*16)($a0)\n"
        "ld.w $w11, (2*16)($a0)\n"
        GEMMLOWP_MIPS_XADDU
        " $a0, $a1, %[dst_col_stride]\n"
        "ld.w $w12, (0*16)($a1)\n"
        "ld.w $w16, (1*16)($a1)\n"
        "ld.w $w20, (2*16)($a1)\n"
        GEMMLOWP_MIPS_XADDU
        " $a1, $a0, %[dst_col_stride]\n"
        "ld.w $w13, (0*16)($a0)\n"
        "ld.w $w17, (1*16)($a0)\n"
        "ld.w $w21, (2*16)($a0)\n"
        GEMMLOWP_MIPS_XADDU
        " $a0, $a1, %[dst_col_stride]\n"
        "ld.w $w14, (0*16)($a1)\n"
        "ld.w $w18, (1*16)($a1)\n"
        "ld.w $w22, (2*16)($a1)\n"
        "ld.w $w15, (0*16)($a0)\n"
        "ld.w $w19, (1*16)($a0)\n"
        "ld.w $w23, (2*16)($a0)\n"
        // Skip over the clearing code.
        "b " GEMMLOWP_LABEL_BEFORE_LOOP "f\n"

        GEMMLOWP_LABEL_CLEAR_ACCUMULATORS
        ":\n"
        // Clear accumulators (start_depth == 0).
        "ldi.w $w0, 0\n"
        "ldi.w $w4, 0\n"
        "ldi.w $w8, 0\n"
        "ldi.w $w1, 0\n"
        "ldi.w $w5, 0\n"
        "ldi.w $w9, 0\n"
        "ldi.w $w2, 0\n"
        "ldi.w $w6, 0\n"
        "ldi.w $w10, 0\n"
        "ldi.w $w3, 0\n"
        "ldi.w $w7, 0\n"
        "ldi.w $w11, 0\n"
        "ldi.w $w12, 0\n"
        "ldi.w $w16, 0\n"
        "ldi.w $w20, 0\n"
        "ldi.w $w13, 0\n"
        "ldi.w $w17, 0\n"
        "ldi.w $w21, 0\n"
        "ldi.w $w14, 0\n"
        "ldi.w $w18, 0\n"
        "ldi.w $w22, 0\n"
        "ldi.w $w15, 0\n"
        "ldi.w $w19, 0\n"
        "ldi.w $w23, 0\n"

        GEMMLOWP_LABEL_BEFORE_LOOP ":\n"

        GEMMLOWP_LABEL_LOOP
        ":\n"
        // Overview of register layout:
        //
        // A half of the 2 2x4 cells of Rhs is stored in 16bit in w27-w30
        // (each register contains 4 replicas of a pair of elements).
        // A 12x2 block of 3 4x2 cells Lhs is stored in 16bit in w24-w26.
        // A 12x8 block of accumulators is stored in 32bit in w0-w23.
        //
        //                    +------+------+------+------+
        //               Rhs  |w27   |w28   |w29   |w30   |
        //                    +------+------+------+------+
        //
        //                    |      |      |      |      |
        //
        //       Lhs          |      |      |      |      |
        //
        //      +---+ - - - - +------+------+------+------+
        //      |w24|         |w0/12 |w1/13 |w2/14 |w3/15 |
        //      |w24|         |w0/12 |w1/13 |w2/14 |w3/15 |
        //      |w24|         |w0/12 |w1/13 |w2/14 |w3/15 |
        //      |w24|         |w0/12 |w1/13 |w2/14 |w3/15 |
        //      +---+ - - - - +------+------+------+------+
        //      |w25|         |w4/16 |w5/17 |w6/18 |w7/19 |
        //      |w25|         |w4/16 |w5/17 |w6/18 |w7/19 |
        //      |w25|         |w4/16 |w5/17 |w6/18 |w7/19 |
        //      |w25|         |w4/16 |w5/17 |w6/18 |w7/19 |
        //      +---+ - - - - +------+------+------+------+
        //      |w26|         |w8/20 |w9/21 |w10/22|w11/23|
        //      |w26|         |w8/20 |w9/21 |w10/22|w11/23|
        //      |w26|         |w8/20 |w9/21 |w10/22|w11/23|
        //      |w26|         |w8/20 |w9/21 |w10/22|w11/23|
        //      +---+ - - - - +------+------+------+------+
        //
        //                             Accumulators

        // Load 3 x 8 bytes of lhs[] with 2 16-byte overlapped loads.
        // (Bytes 0..15 go to w24 and bytes 8..23 to w25; the overlap keeps
        // the reads within the 24 bytes consumed per iteration.)
        "ld.b $w24, 0(%[lhs_ptr])\n"
        "ld.b $w25, 8(%[lhs_ptr])\n"

        // Load 4 bytes of rhs[] for the first half of depth 0.
        "lbu $a0, 0(%[rhs_ptr])\n"
        "lbu $a1, 1(%[rhs_ptr])\n"
        "lbu $a2, 2(%[rhs_ptr])\n"
        "lbu $a3, 3(%[rhs_ptr])\n"
        // Load 4 bytes of rhs[] for the first half of depth 1.
        "lbu $v0, 4(%[rhs_ptr])\n"
        "lbu $v1, 5(%[rhs_ptr])\n"
        "lbu $t8, 6(%[rhs_ptr])\n"
        "lbu $t9, 7(%[rhs_ptr])\n"

        // Zero-extend 8-bit elements of lhs[] to 16 bits.
        // w26 must be produced from w25 before w25 itself is overwritten.
        "ilvr.b $w24, $w31, $w24\n"
        "ilvl.b $w26, $w31, $w25\n"
        "ilvr.b $w25, $w31, $w25\n"
        // Interleave depth 0 and depth 1 elements of lhs[] for dpadd_u.w.
        // w27-w29 serve only as temporaries here; they are overwritten with
        // rhs[] data by the fill.w instructions below.
        "ilvl.d $w27, $w31, $w24\n"
        "ilvl.d $w28, $w31, $w25\n"
        "ilvl.d $w29, $w31, $w26\n"
        "ilvr.h $w24, $w27, $w24\n"
        "ilvr.h $w25, $w28, $w25\n"
        "ilvr.h $w26, $w29, $w26\n"

        // Combine and interleave depth 0 and depth 1 elements of rhs[] for
        // dpadd_u.w (for the first half).
        // Each depth 1 byte is inserted at bit 16 of the register holding the
        // matching depth 0 byte, giving a pair of 16-bit values per word.
        "ins $a0, $v0, 16, 8\n"
        "ins $a1, $v1, 16, 8\n"
        "ins $a2, $t8, 16, 8\n"
        "ins $a3, $t9, 16, 8\n"
        // Make 4 replicas of every pair of rhs[] elements.
        "fill.w $w27, $a0\n"
        "fill.w $w28, $a1\n"
        "fill.w $w29, $a2\n"
        "fill.w $w30, $a3\n"

        // Load 4 bytes of rhs[] for the second half of depth 0.
        // (Issued before the first batch of dot products; the scalar
        // registers are free again once the fill.w above have read them.)
        "lbu $a0, 8(%[rhs_ptr])\n"
        "lbu $a1, 9(%[rhs_ptr])\n"
        "lbu $a2, 10(%[rhs_ptr])\n"
        "lbu $a3, 11(%[rhs_ptr])\n"
        // Load 4 bytes of rhs[] for the second half of depth 1.
        "lbu $v0, 12(%[rhs_ptr])\n"
        "lbu $v1, 13(%[rhs_ptr])\n"
        "lbu $t8, 14(%[rhs_ptr])\n"
        "lbu $t9, 15(%[rhs_ptr])\n"

        // First half of depths 0 and 1.
        // Dot-product-(and)-add doubles multiplicand width.
        // Accumulates into columns 0-3 (w0-w11).
        "dpadd_u.w $w0, $w24, $w27\n"
        "dpadd_u.w $w4, $w25, $w27\n"
        "dpadd_u.w $w8, $w26, $w27\n"
        "dpadd_u.w $w1, $w24, $w28\n"
        "dpadd_u.w $w5, $w25, $w28\n"
        "dpadd_u.w $w9, $w26, $w28\n"
        "dpadd_u.w $w2, $w24, $w29\n"
        "dpadd_u.w $w6, $w25, $w29\n"
        "dpadd_u.w $w10, $w26, $w29\n"
        "dpadd_u.w $w3, $w24, $w30\n"
        "dpadd_u.w $w7, $w25, $w30\n"
        "dpadd_u.w $w11, $w26, $w30\n"

        // Combine and interleave depth 0 and depth 1 elements of rhs[] for
        // dpadd_u.w (for the second half).
        "ins $a0, $v0, 16, 8\n"
        "ins $a1, $v1, 16, 8\n"
        "ins $a2, $t8, 16, 8\n"
        "ins $a3, $t9, 16, 8\n"
        // Make 4 replicas of every pair of rhs[] elements.
        "fill.w $w27, $a0\n"
        "fill.w $w28, $a1\n"
        "fill.w $w29, $a2\n"
        "fill.w $w30, $a3\n"

        // Second half of depths 0 and 1.
        // Dot-product-(and)-add doubles multiplicand width.
        // Accumulates into columns 4-7 (w12-w23).
        "dpadd_u.w $w12, $w24, $w27\n"
        "dpadd_u.w $w16, $w25, $w27\n"
        "dpadd_u.w $w20, $w26, $w27\n"
        "dpadd_u.w $w13, $w24, $w28\n"
        "dpadd_u.w $w17, $w25, $w28\n"
        "dpadd_u.w $w21, $w26, $w28\n"
        "dpadd_u.w $w14, $w24, $w29\n"
        "dpadd_u.w $w18, $w25, $w29\n"
        "dpadd_u.w $w22, $w26, $w29\n"
        "dpadd_u.w $w15, $w24, $w30\n"
        "dpadd_u.w $w19, $w25, $w30\n"
        "dpadd_u.w $w23, $w26, $w30\n"

        // Two levels of depth were consumed: update the counter, advance the
        // input pointers (24 bytes of lhs[], 16 bytes of rhs[]) and loop
        // until the counter reaches zero.
        GEMMLOWP_MIPS_XADDIU " %[run_depth], -2\n"
        GEMMLOWP_MIPS_XADDIU " %[lhs_ptr], 24\n"
        GEMMLOWP_MIPS_XADDIU " %[rhs_ptr], 16\n"
        "bnez %[run_depth]," GEMMLOWP_LABEL_LOOP "b\n"

        GEMMLOWP_LABEL_AFTER_LOOP ":\n"

        // Store accumulators.
        // Same column walk and register-to-column mapping as the load
        // sequence above.
        GEMMLOWP_MIPS_XADDU
        " $a0, %[dst_ptr], %[dst_col_stride]\n"
        "st.w $w0, (0*16)(%[dst_ptr])\n"
        "st.w $w4, (1*16)(%[dst_ptr])\n"
        "st.w $w8, (2*16)(%[dst_ptr])\n"
        GEMMLOWP_MIPS_XADDU
        " $a1, $a0, %[dst_col_stride]\n"
        "st.w $w1, (0*16)($a0)\n"
        "st.w $w5, (1*16)($a0)\n"
        "st.w $w9, (2*16)($a0)\n"
        GEMMLOWP_MIPS_XADDU
        " $a0, $a1, %[dst_col_stride]\n"
        "st.w $w2, (0*16)($a1)\n"
        "st.w $w6, (1*16)($a1)\n"
        "st.w $w10, (2*16)($a1)\n"
        GEMMLOWP_MIPS_XADDU
        " $a1, $a0, %[dst_col_stride]\n"
        "st.w $w3, (0*16)($a0)\n"
        "st.w $w7, (1*16)($a0)\n"
        "st.w $w11, (2*16)($a0)\n"
        GEMMLOWP_MIPS_XADDU
        " $a0, $a1, %[dst_col_stride]\n"
        "st.w $w12, (0*16)($a1)\n"
        "st.w $w16, (1*16)($a1)\n"
        "st.w $w20, (2*16)($a1)\n"
        GEMMLOWP_MIPS_XADDU
        " $a1, $a0, %[dst_col_stride]\n"
        "st.w $w13, (0*16)($a0)\n"
        "st.w $w17, (1*16)($a0)\n"
        "st.w $w21, (2*16)($a0)\n"
        GEMMLOWP_MIPS_XADDU
        " $a0, $a1, %[dst_col_stride]\n"
        "st.w $w14, (0*16)($a1)\n"
        "st.w $w18, (1*16)($a1)\n"
        "st.w $w22, (2*16)($a1)\n"
        "st.w $w15, (0*16)($a0)\n"
        "st.w $w19, (1*16)($a0)\n"
        "st.w $w23, (2*16)($a0)\n"
        :  // outputs
        // All four are read-write: the asm advances both input pointers,
        // counts run_depth down and scales dst_col_stride in place.
        [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
        [run_depth] "+r"(run_depth),
        [dst_col_stride] "+r"(dst_col_stride)
        :  // inputs
        [dst_ptr] "r"(dst_ptr),
        [start_depth] "r"(start_depth)
        :  // clobbers
        // "memory" because the asm reads and writes through the pointers.
        // The general-purpose registers listed are the scratch registers the
        // asm names explicitly. The asm itself only names $w0..$w31; the
        // $f0..$f31 entries presumably stand for those vector registers
        // (MSA vector registers overlap the FPU registers) -- TODO confirm.
        "memory", "v0", "v1", "a0", "a1", "a2", "a3", "t8", "t9", "$f0", "$f1",
        "$f2", "$f3", "$f4", "$f5", "$f6", "$f7", "$f8", "$f9", "$f10", "$f11",
        "$f12", "$f13", "$f14", "$f15", "$f16", "$f17", "$f18", "$f19", "$f20",
        "$f21", "$f22", "$f23", "$f24", "$f25", "$f26", "$f27", "$f28", "$f29",
        "$f30", "$f31");

#undef GEMMLOWP_LABEL_CLEAR_ACCUMULATORS
#undef GEMMLOWP_LABEL_BEFORE_LOOP
#undef GEMMLOWP_LABEL_LOOP
#undef GEMMLOWP_LABEL_AFTER_LOOP
  }
};

// The MIPS32/MIPS64 mnemonic helpers are only needed by the asm above.
#undef GEMMLOWP_MIPS_XADDU
#undef GEMMLOWP_MIPS_XADDIU
#undef GEMMLOWP_MIPS_XSLL

#endif  // GEMMLOWP_MSA

}  // namespace gemmlowp

#endif  // GEMMLOWP_INTERNAL_KERNEL_MSA_H_