// Copyright 2018 The Gemmlowp Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// kernel_msa.h: a collection of MSA-optimized (MIPS SIMD Architecture)
// kernels.
// Check in kernel_default.h which one(s) are actually used by default.
// Others are mere experiments; they are still covered by tests
// in case they might be useful some day.

20 #ifndef GEMMLOWP_INTERNAL_KERNEL_MSA_H_
21 #define GEMMLOWP_INTERNAL_KERNEL_MSA_H_
22 
23 #include "kernel.h"
24 
25 #include <msa.h>
26 #include <cassert>
27 
28 namespace gemmlowp {
29 
30 #ifdef GEMMLOWP_MSA
31 
32 // Some convenience macros to hide differences between MIPS32 and MIPS64.
33 #ifdef GEMMLOWP_MIPS_64
34 #define GEMMLOWP_MIPS_XADDU "daddu"
35 #define GEMMLOWP_MIPS_XADDIU "daddiu"
36 #define GEMMLOWP_MIPS_XSLL "dsll"
37 #else
38 #define GEMMLOWP_MIPS_XADDU "addu"
39 #define GEMMLOWP_MIPS_XADDIU "addiu"
40 #define GEMMLOWP_MIPS_XSLL "sll"
41 #endif
42 
43 // Our main GEMM kernel.
44 struct MSA_Kernel12x8Depth2 : KernelBase {
45   typedef KernelFormat<KernelSideFormat<CellFormat<4, 2>, 3>,
46                        KernelSideFormat<CellFormat<4, 2>, 2> >
47       Format;
48 
NameMSA_Kernel12x8Depth249   const char* Name() const override { return "MSA, 12x8, depth 2"; }
50 
51   // TODO(benoitjacob): reorder function arguments so dst comes last
RunMSA_Kernel12x8Depth252   void Run(std::int32_t* dst_ptr, std::size_t dst_row_stride,
53            std::size_t dst_col_stride, const std::uint8_t* lhs_ptr,
54            const std::uint8_t* rhs_ptr, std::size_t start_depth,
55            std::size_t run_depth) const override {
56     ScopedProfilingLabel label("optimized kernel (MSA 12x8)");
57 // See comments above for why we need local numerical labels in our asm.
58 #define GEMMLOWP_LABEL_CLEAR_ACCUMULATORS "1"
59 #define GEMMLOWP_LABEL_BEFORE_LOOP "2"
60 #define GEMMLOWP_LABEL_LOOP "3"
61 #define GEMMLOWP_LABEL_AFTER_LOOP "4"
62 
63     assert(dst_row_stride == 1);
64     asm volatile(
65         // Set a temp to all zeroes.
66         "ldi.b  $w31, 0\n"
67 
68         // Multiply dst_col_stride by 4 == sizeof(int32) to use
69         // it as a byte offset below.
70         GEMMLOWP_MIPS_XSLL
71         " %[dst_col_stride], %[dst_col_stride], 2\n"
72 
73         // Check if start_depth==0 to decide whether we will clear
74         // accumulators or load existing accumulators.
75         "beqz   %[start_depth], " GEMMLOWP_LABEL_CLEAR_ACCUMULATORS "f\n"
76 
77         // Load accumulators (start_depth != 0).
78         GEMMLOWP_MIPS_XADDU
79         " $a0, %[dst_ptr], %[dst_col_stride]\n"
80         "ld.w   $w0,  (0*16)(%[dst_ptr])\n"
81         "ld.w   $w4,  (1*16)(%[dst_ptr])\n"
82         "ld.w   $w8,  (2*16)(%[dst_ptr])\n" GEMMLOWP_MIPS_XADDU
83         " $a1, $a0, %[dst_col_stride]\n"
84         "ld.w   $w1,  (0*16)($a0)\n"
85         "ld.w   $w5,  (1*16)($a0)\n"
86         "ld.w   $w9,  (2*16)($a0)\n" GEMMLOWP_MIPS_XADDU
87         " $a0, $a1, %[dst_col_stride]\n"
88         "ld.w   $w2,  (0*16)($a1)\n"
89         "ld.w   $w6,  (1*16)($a1)\n"
90         "ld.w   $w10, (2*16)($a1)\n" GEMMLOWP_MIPS_XADDU
91         " $a1, $a0, %[dst_col_stride]\n"
92         "ld.w   $w3,  (0*16)($a0)\n"
93         "ld.w   $w7,  (1*16)($a0)\n"
94         "ld.w   $w11, (2*16)($a0)\n" GEMMLOWP_MIPS_XADDU
95         " $a0, $a1, %[dst_col_stride]\n"
96         "ld.w   $w12, (0*16)($a1)\n"
97         "ld.w   $w16, (1*16)($a1)\n"
98         "ld.w   $w20, (2*16)($a1)\n" GEMMLOWP_MIPS_XADDU
99         " $a1, $a0, %[dst_col_stride]\n"
100         "ld.w   $w13, (0*16)($a0)\n"
101         "ld.w   $w17, (1*16)($a0)\n"
102         "ld.w   $w21, (2*16)($a0)\n" GEMMLOWP_MIPS_XADDU
103         " $a0, $a1, %[dst_col_stride]\n"
104         "ld.w   $w14, (0*16)($a1)\n"
105         "ld.w   $w18, (1*16)($a1)\n"
106         "ld.w   $w22, (2*16)($a1)\n"
107         "ld.w   $w15, (0*16)($a0)\n"
108         "ld.w   $w19, (1*16)($a0)\n"
109         "ld.w   $w23, (2*16)($a0)\n"
110         "b " GEMMLOWP_LABEL_BEFORE_LOOP "f\n"
111 
112         GEMMLOWP_LABEL_CLEAR_ACCUMULATORS
113         ":\n"
114         // Clear accumulators (start_depth == 0).
115         "ldi.w  $w0,  0\n"
116         "ldi.w  $w4,  0\n"
117         "ldi.w  $w8,  0\n"
118         "ldi.w  $w1,  0\n"
119         "ldi.w  $w5,  0\n"
120         "ldi.w  $w9,  0\n"
121         "ldi.w  $w2,  0\n"
122         "ldi.w  $w6,  0\n"
123         "ldi.w  $w10, 0\n"
124         "ldi.w  $w3,  0\n"
125         "ldi.w  $w7,  0\n"
126         "ldi.w  $w11, 0\n"
127         "ldi.w  $w12, 0\n"
128         "ldi.w  $w16, 0\n"
129         "ldi.w  $w20, 0\n"
130         "ldi.w  $w13, 0\n"
131         "ldi.w  $w17, 0\n"
132         "ldi.w  $w21, 0\n"
133         "ldi.w  $w14, 0\n"
134         "ldi.w  $w18, 0\n"
135         "ldi.w  $w22, 0\n"
136         "ldi.w  $w15, 0\n"
137         "ldi.w  $w19, 0\n"
138         "ldi.w  $w23, 0\n"
139 
140         GEMMLOWP_LABEL_BEFORE_LOOP ":\n"
141 
142         GEMMLOWP_LABEL_LOOP
143         ":\n"
144         // Overview of register layout:
145         //
146         // A half of the 2 2x4 cells of Rhs is stored in 16bit in w27-w30
147         // (each register contains 4 replicas of a pair of elements).
148         // A 12x2 block of 3 4x2 cells Lhs is stored in 16bit in w24-w26.
149         // A 12x8 block of accumulators is stored in 32bit in w0-w23.
150         //
151         //                    +------+------+------+------+
152         //               Rhs  |w27   |w28   |w29   |w30   |
153         //                    +------+------+------+------+
154         //
155         //                    |      |      |      |      |
156         //
157         //       Lhs          |      |      |      |      |
158         //
159         //      +---+ - - - - +------+------+------+------+
160         //      |w24|         |w0/12 |w1/13 |w2/14 |w3/15 |
161         //      |w24|         |w0/12 |w1/13 |w2/14 |w3/15 |
162         //      |w24|         |w0/12 |w1/13 |w2/14 |w3/15 |
163         //      |w24|         |w0/12 |w1/13 |w2/14 |w3/15 |
164         //      +---+ - - - - +------+------+------+------+
165         //      |w25|         |w4/16 |w5/17 |w6/18 |w7/19 |
166         //      |w25|         |w4/16 |w5/17 |w6/18 |w7/19 |
167         //      |w25|         |w4/16 |w5/17 |w6/18 |w7/19 |
168         //      |w25|         |w4/16 |w5/17 |w6/18 |w7/19 |
169         //      +---+ - - - - +------+------+------+------+
170         //      |w26|         |w8/20 |w9/21 |w10/22|w11/23|
171         //      |w26|         |w8/20 |w9/21 |w10/22|w11/23|
172         //      |w26|         |w8/20 |w9/21 |w10/22|w11/23|
173         //      |w26|         |w8/20 |w9/21 |w10/22|w11/23|
174         //      +---+ - - - - +------+------+------+------+
175         //
176         //                             Accumulators
177 
178         // Load 3 x 8 bytes of lhs[] with 2 16-byte overlapped loads.
179         "ld.b   $w24, 0(%[lhs_ptr])\n"
180         "ld.b   $w25, 8(%[lhs_ptr])\n"
181 
182         // Load 4 bytes of rhs[] for the first half of depth 0.
183         "lbu    $a0, 0(%[rhs_ptr])\n"
184         "lbu    $a1, 1(%[rhs_ptr])\n"
185         "lbu    $a2, 2(%[rhs_ptr])\n"
186         "lbu    $a3, 3(%[rhs_ptr])\n"
187         // Load 4 bytes of rhs[] for the first half of depth 1.
188         "lbu    $v0, 4(%[rhs_ptr])\n"
189         "lbu    $v1, 5(%[rhs_ptr])\n"
190         "lbu    $t8, 6(%[rhs_ptr])\n"
191         "lbu    $t9, 7(%[rhs_ptr])\n"
192 
193         // Zero-extend 8-bit elements of lhs[] to 16 bits.
194         "ilvr.b $w24, $w31, $w24\n"
195         "ilvl.b $w26, $w31, $w25\n"
196         "ilvr.b $w25, $w31, $w25\n"
197         // Interleave depth 0 and depth 1 elements of lhs[] for dpadd_u.w.
198         "ilvl.d $w27, $w31, $w24\n"
199         "ilvl.d $w28, $w31, $w25\n"
200         "ilvl.d $w29, $w31, $w26\n"
201         "ilvr.h $w24, $w27, $w24\n"
202         "ilvr.h $w25, $w28, $w25\n"
203         "ilvr.h $w26, $w29, $w26\n"
204 
205         // Combine and interleave depth 0 and depth 1 elements of rhs[] for
206         // dpadd_u.w (for the first half).
207         "ins    $a0, $v0, 16, 8\n"
208         "ins    $a1, $v1, 16, 8\n"
209         "ins    $a2, $t8, 16, 8\n"
210         "ins    $a3, $t9, 16, 8\n"
211         // Make 4 replicas of every pair of rhs[] elements.
212         "fill.w $w27, $a0\n"
213         "fill.w $w28, $a1\n"
214         "fill.w $w29, $a2\n"
215         "fill.w $w30, $a3\n"
216 
217         // Load 4 bytes of rhs[] for the second half of depth 0.
218         "lbu    $a0, 8(%[rhs_ptr])\n"
219         "lbu    $a1, 9(%[rhs_ptr])\n"
220         "lbu    $a2, 10(%[rhs_ptr])\n"
221         "lbu    $a3, 11(%[rhs_ptr])\n"
222         // Load 4 bytes of rhs[] for the second half of depth 1.
223         "lbu    $v0, 12(%[rhs_ptr])\n"
224         "lbu    $v1, 13(%[rhs_ptr])\n"
225         "lbu    $t8, 14(%[rhs_ptr])\n"
226         "lbu    $t9, 15(%[rhs_ptr])\n"
227 
228         // First half of depths 0 and 1.
229         // Dot-product-(and)-add doubles multiplicand width.
230         "dpadd_u.w $w0, $w24, $w27\n"
231         "dpadd_u.w $w4, $w25, $w27\n"
232         "dpadd_u.w $w8, $w26, $w27\n"
233         "dpadd_u.w $w1, $w24, $w28\n"
234         "dpadd_u.w $w5, $w25, $w28\n"
235         "dpadd_u.w $w9, $w26, $w28\n"
236         "dpadd_u.w $w2, $w24, $w29\n"
237         "dpadd_u.w $w6, $w25, $w29\n"
238         "dpadd_u.w $w10, $w26, $w29\n"
239         "dpadd_u.w $w3, $w24, $w30\n"
240         "dpadd_u.w $w7, $w25, $w30\n"
241         "dpadd_u.w $w11, $w26, $w30\n"
242 
243         // Combine and interleave depth 0 and depth 1 elements of rhs[] for
244         // dpadd_u.w (for the second half).
245         "ins    $a0, $v0, 16, 8\n"
246         "ins    $a1, $v1, 16, 8\n"
247         "ins    $a2, $t8, 16, 8\n"
248         "ins    $a3, $t9, 16, 8\n"
249         // Make 4 replicas of every pair of rhs[] elements.
250         "fill.w $w27, $a0\n"
251         "fill.w $w28, $a1\n"
252         "fill.w $w29, $a2\n"
253         "fill.w $w30, $a3\n"
254 
255         // Second half of depths 0 and 1.
256         // Dot-product-(and)-add doubles multiplicand width.
257         "dpadd_u.w $w12, $w24, $w27\n"
258         "dpadd_u.w $w16, $w25, $w27\n"
259         "dpadd_u.w $w20, $w26, $w27\n"
260         "dpadd_u.w $w13, $w24, $w28\n"
261         "dpadd_u.w $w17, $w25, $w28\n"
262         "dpadd_u.w $w21, $w26, $w28\n"
263         "dpadd_u.w $w14, $w24, $w29\n"
264         "dpadd_u.w $w18, $w25, $w29\n"
265         "dpadd_u.w $w22, $w26, $w29\n"
266         "dpadd_u.w $w15, $w24, $w30\n"
267         "dpadd_u.w $w19, $w25, $w30\n"
268         "dpadd_u.w $w23, $w26, $w30\n"
269 
270         GEMMLOWP_MIPS_XADDIU " %[run_depth], -2\n" GEMMLOWP_MIPS_XADDIU
271         " %[lhs_ptr], 24\n" GEMMLOWP_MIPS_XADDIU
272         " %[rhs_ptr], 16\n"
273         "bnez   %[run_depth]," GEMMLOWP_LABEL_LOOP "b\n"
274 
275         GEMMLOWP_LABEL_AFTER_LOOP ":\n"
276 
277         // Store accumulators.
278         GEMMLOWP_MIPS_XADDU
279         " $a0, %[dst_ptr], %[dst_col_stride]\n"
280         "st.w   $w0,  (0*16)(%[dst_ptr])\n"
281         "st.w   $w4,  (1*16)(%[dst_ptr])\n"
282         "st.w   $w8,  (2*16)(%[dst_ptr])\n" GEMMLOWP_MIPS_XADDU
283         " $a1, $a0, %[dst_col_stride]\n"
284         "st.w   $w1,  (0*16)($a0)\n"
285         "st.w   $w5,  (1*16)($a0)\n"
286         "st.w   $w9,  (2*16)($a0)\n" GEMMLOWP_MIPS_XADDU
287         " $a0, $a1, %[dst_col_stride]\n"
288         "st.w   $w2,  (0*16)($a1)\n"
289         "st.w   $w6,  (1*16)($a1)\n"
290         "st.w   $w10, (2*16)($a1)\n" GEMMLOWP_MIPS_XADDU
291         " $a1, $a0, %[dst_col_stride]\n"
292         "st.w   $w3,  (0*16)($a0)\n"
293         "st.w   $w7,  (1*16)($a0)\n"
294         "st.w   $w11, (2*16)($a0)\n" GEMMLOWP_MIPS_XADDU
295         " $a0, $a1, %[dst_col_stride]\n"
296         "st.w   $w12, (0*16)($a1)\n"
297         "st.w   $w16, (1*16)($a1)\n"
298         "st.w   $w20, (2*16)($a1)\n" GEMMLOWP_MIPS_XADDU
299         " $a1, $a0, %[dst_col_stride]\n"
300         "st.w   $w13, (0*16)($a0)\n"
301         "st.w   $w17, (1*16)($a0)\n"
302         "st.w   $w21, (2*16)($a0)\n" GEMMLOWP_MIPS_XADDU
303         " $a0, $a1, %[dst_col_stride]\n"
304         "st.w   $w14, (0*16)($a1)\n"
305         "st.w   $w18, (1*16)($a1)\n"
306         "st.w   $w22, (2*16)($a1)\n"
307         "st.w   $w15, (0*16)($a0)\n"
308         "st.w   $w19, (1*16)($a0)\n"
309         "st.w   $w23, (2*16)($a0)\n"
310         :  // outputs
311         [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
312         [run_depth] "+r"(run_depth),
313         [dst_col_stride] "+r"(dst_col_stride)
314         :  // inputs
315         [dst_ptr] "r"(dst_ptr),
316         [start_depth] "r"(start_depth)
317         :  // clobbers
318         "memory", "v0", "v1", "a0", "a1", "a2", "a3", "t8", "t9", "$f0", "$f1",
319         "$f2", "$f3", "$f4", "$f5", "$f6", "$f7", "$f8", "$f9", "$f10", "$f11",
320         "$f12", "$f13", "$f14", "$f15", "$f16", "$f17", "$f18", "$f19", "$f20",
321         "$f21", "$f22", "$f23", "$f24", "$f25", "$f26", "$f27", "$f28", "$f29",
322         "$f30", "$f31");
323 
324 #undef GEMMLOWP_LABEL_CLEAR_ACCUMULATORS
325 #undef GEMMLOWP_LABEL_BEFORE_LOOP
326 #undef GEMMLOWP_LABEL_LOOP
327 #undef GEMMLOWP_LABEL_AFTER_LOOP
328   }
329 };
330 
331 #undef GEMMLOWP_MIPS_XADDU
332 #undef GEMMLOWP_MIPS_XADDIU
333 #undef GEMMLOWP_MIPS_XSLL
334 
335 #endif  // GEMMLOWP_MSA
336 
337 }  // namespace gemmlowp
338 
339 #endif  // GEMMLOWP_INTERNAL_KERNEL_MSA_H_
340