1 // Copyright 2015 Google Inc. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 // simd_wrappers_common_neon_sse.h: common SIMD (NEON and SSE) wrapper code
16 
17 #ifndef GEMMLOWP_INTERNAL_SIMD_WRAPPERS_COMMON_NEON_SSE_H_
18 #define GEMMLOWP_INTERNAL_SIMD_WRAPPERS_COMMON_NEON_SSE_H_
19 
20 #include "simd_wrappers.h"
21 
22 namespace gemmlowp {
23 
24 template <typename SrcScalarType, int N>
25 struct LoadImpl<RegBlockInt32<4, N>,
26                 MatrixMap<SrcScalarType, MapOrder::ColMajor>> {
27   static RegBlockInt32<4, N> Run(
28       const MatrixMap<SrcScalarType, MapOrder::ColMajor>& src, int row,
29       int col) {
30     RegBlockInt32<4, N> result;
31     for (int i = 0; i < N; i++) {
32       result.buf.reg[i] = LoadInt32x4(src.data(row, col + i));
33     }
34     return result;
35   }
36 };
37 
38 template <typename SrcScalarType, int N>
39 struct LoadImpl<RegBlockInt32<8, N>,
40                 MatrixMap<SrcScalarType, MapOrder::ColMajor>> {
41   static RegBlockInt32<8, N> Run(
42       const MatrixMap<SrcScalarType, MapOrder::ColMajor>& src, int row,
43       int col) {
44     RegBlockInt32<8, N> result;
45     for (int i = 0; i < N; i++) {
46       result.buf.reg[2 * i + 0] = LoadInt32x4(src.data(row + 0, col + i));
47       result.buf.reg[2 * i + 1] = LoadInt32x4(src.data(row + 4, col + i));
48     }
49     return result;
50   }
51 };
52 
53 template <typename SrcScalarType>
54 struct LoadImpl<RegBlockInt32<1, 4>,
55                 MatrixMap<SrcScalarType, MapOrder::ColMajor>> {
56   static RegBlockInt32<1, 4> Run(
57       const MatrixMap<SrcScalarType, MapOrder::ColMajor>& src, int row,
58       int col) {
59     RegBlockInt32<1, 4> result;
60     std::int32_t buf[4];
61     for (int i = 0; i < 4; i++) {
62       buf[i] = src(row, col + i);
63     }
64     result.buf.reg[0] = LoadInt32x4(buf);
65     return result;
66   }
67 };
68 
69 template <typename SrcScalarType>
70 struct LoadImpl<RegBlockInt32<1, 8>,
71                 MatrixMap<SrcScalarType, MapOrder::ColMajor>> {
72   static RegBlockInt32<1, 8> Run(
73       const MatrixMap<SrcScalarType, MapOrder::ColMajor>& src, int row,
74       int col) {
75     RegBlockInt32<1, 8> result;
76     std::int32_t buf[8];
77     for (int i = 0; i < 8; i++) {
78       buf[i] = src(row, col + i);
79     }
80     result.buf.reg[0] = LoadInt32x4(buf);
81     result.buf.reg[1] = LoadInt32x4(buf + 4);
82     return result;
83   }
84 };
85 
86 template <typename SrcScalarType>
87 struct LoadImpl<RegBlockInt32<4, 1>,
88                 VectorMap<SrcScalarType, VectorShape::Col>> {
89   static RegBlockInt32<4, 1> Run(
90       const VectorMap<SrcScalarType, VectorShape::Col>& src, int pos) {
91     RegBlockInt32<4, 1> result;
92     result.buf.reg[0] = LoadInt32x4(src.data(pos));
93     return result;
94   }
95 };
96 
97 template <typename SrcScalarType>
98 struct LoadImpl<RegBlockInt32<4, 1>,
99                 VectorDup<SrcScalarType, VectorShape::Col>> {
100   static RegBlockInt32<4, 1> Run(
101       const VectorDup<SrcScalarType, VectorShape::Col>& src, int) {
102     RegBlockInt32<4, 1> result;
103     result.buf.reg[0] = LoadInt32x4(src(0));
104     return result;
105   }
106 };
107 
108 template <typename SrcScalarType, int N>
109 struct LoadForBroadcastingImpl<RegBlockInt32<4, N>,
110                                VectorMap<SrcScalarType, VectorShape::Col>> {
111   using SrcObjectType = VectorMap<SrcScalarType, VectorShape::Col>;
112   using RegisterBlockType = RegBlockInt32<4, N>;
113   using ResultBlockType =
114       typename LoadForBroadcastingRegisterBlock<RegisterBlockType,
115                                                 SrcObjectType>::Type;
116 
117   static ResultBlockType Run(const SrcObjectType& src, int pos) {
118     ResultBlockType result;
119     static_assert(ResultBlockType::kRegisterCount == 1, "");
120     result.buf.reg[0] = LoadInt32x4(src.data(pos));
121     return result;
122   }
123 };
124 
125 template <typename SrcScalarType, int N>
126 struct LoadForBroadcastingImpl<RegBlockInt32<8, N>,
127                                VectorMap<SrcScalarType, VectorShape::Col>> {
128   using SrcObjectType = VectorMap<SrcScalarType, VectorShape::Col>;
129   using RegisterBlockType = RegBlockInt32<8, N>;
130   using ResultBlockType =
131       typename LoadForBroadcastingRegisterBlock<RegisterBlockType,
132                                                 SrcObjectType>::Type;
133 
134   static ResultBlockType Run(const SrcObjectType& src, int pos) {
135     ResultBlockType result;
136     static_assert(ResultBlockType::kRegisterCount == 2, "");
137     result.buf.reg[0] = LoadInt32x4(src.data(pos));
138     result.buf.reg[1] = LoadInt32x4(src.data(pos + 4));
139     return result;
140   }
141 };
142 
143 template <typename SrcScalarType>
144 struct LoadForBroadcastingImpl<RegBlockInt32<4, 1>,
145                                VectorMap<SrcScalarType, VectorShape::Row>> {
146   using SrcObjectType = VectorMap<SrcScalarType, VectorShape::Row>;
147   using RegisterBlockType = RegBlockInt32<4, 1>;
148   using ResultBlockType =
149       typename LoadForBroadcastingRegisterBlock<RegisterBlockType,
150                                                 SrcObjectType>::Type;
151 
152   static ResultBlockType Run(const SrcObjectType& src, int pos) {
153     ResultBlockType result;
154     result.buf.reg[0] = src(pos);
155     return result;
156   }
157 };
158 
159 template <typename SrcScalarType, int N>
160 struct LoadForBroadcastingImpl<RegBlockInt32<N, 4>,
161                                VectorMap<SrcScalarType, VectorShape::Row>> {
162   using SrcObjectType = VectorMap<SrcScalarType, VectorShape::Row>;
163   using RegisterBlockType = RegBlockInt32<N, 4>;
164   using ResultBlockType =
165       typename LoadForBroadcastingRegisterBlock<RegisterBlockType,
166                                                 SrcObjectType>::Type;
167 
168   static ResultBlockType Run(const SrcObjectType& src, int pos) {
169     ResultBlockType result;
170     static_assert(ResultBlockType::kRegisterCount == 1, "");
171     result.buf.reg[0] = LoadInt32x4(src.data(pos));
172     return result;
173   }
174 };
175 
176 template <typename SrcScalarType, int N>
177 struct LoadForBroadcastingImpl<RegBlockInt32<N, 8>,
178                                VectorMap<SrcScalarType, VectorShape::Row>> {
179   using SrcObjectType = VectorMap<SrcScalarType, VectorShape::Row>;
180   using RegisterBlockType = RegBlockInt32<N, 8>;
181   using ResultBlockType =
182       typename LoadForBroadcastingRegisterBlock<RegisterBlockType,
183                                                 SrcObjectType>::Type;
184 
185   static ResultBlockType Run(const SrcObjectType& src, int pos) {
186     ResultBlockType result;
187     static_assert(ResultBlockType::kRegisterCount == 2, "");
188     result.buf.reg[0] = LoadInt32x4(src.data(pos));
189     result.buf.reg[1] = LoadInt32x4(src.data(pos + 4));
190     return result;
191   }
192 };
193 
194 // 4x1 := 4x1 + 1x1
195 template <>
196 struct BroadcastAddImpl<RegBlockInt32<4, 1>, RegBlockInt32<1, 1>> {
197   static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs,
198                                  const RegBlockInt32<1, 1>& rhs) {
199     RegBlockInt32<4, 1> result;
200     result.buf.reg[0] = Add(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0]));
201     return result;
202   }
203 };
204 
205 // 1x4 := 1x4 + 1x1
206 template <>
207 struct BroadcastAddImpl<RegBlockInt32<1, 4>, RegBlockInt32<1, 1>> {
208   static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs,
209                                  const RegBlockInt32<1, 1>& rhs) {
210     RegBlockInt32<1, 4> result;
211     result.buf.reg[0] = Add(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0]));
212     return result;
213   }
214 };
215 
216 // 4x1 := 4x1 + 4x1
217 template <>
218 struct BroadcastAddImpl<RegBlockInt32<4, 1>, RegBlockInt32<4, 1>> {
219   static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs,
220                                  const RegBlockInt32<4, 1>& rhs) {
221     RegBlockInt32<4, 1> result;
222     result.buf.reg[0] = Add(lhs.buf.reg[0], rhs.buf.reg[0]);
223     return result;
224   }
225 };
226 
227 // 1x4 := 1x4 + 1x4
228 template <>
229 struct BroadcastAddImpl<RegBlockInt32<1, 4>, RegBlockInt32<1, 4>> {
230   static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs,
231                                  const RegBlockInt32<1, 4>& rhs) {
232     RegBlockInt32<1, 4> result;
233     result.buf.reg[0] = Add(lhs.buf.reg[0], rhs.buf.reg[0]);
234     return result;
235   }
236 };
237 
238 // 4x4 := 4x4 + 1x4
239 template <>
240 struct BroadcastAddImpl<RegBlockInt32<4, 4>, RegBlockInt32<1, 4>> {
241   static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs,
242                                  const RegBlockInt32<1, 4>& rhs) {
243     RegBlockInt32<4, 4> result;
244     result.buf.reg[0] = Add(lhs.buf.reg[0], DupLane<0>(rhs.buf.reg[0]));
245     result.buf.reg[1] = Add(lhs.buf.reg[1], DupLane<1>(rhs.buf.reg[0]));
246     result.buf.reg[2] = Add(lhs.buf.reg[2], DupLane<2>(rhs.buf.reg[0]));
247     result.buf.reg[3] = Add(lhs.buf.reg[3], DupLane<3>(rhs.buf.reg[0]));
248     return result;
249   }
250 };
251 
252 // 4x4 := 4x4 + 4x1
253 template <>
254 struct BroadcastAddImpl<RegBlockInt32<4, 4>, RegBlockInt32<4, 1>> {
255   static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs,
256                                  const RegBlockInt32<4, 1>& rhs) {
257     RegBlockInt32<4, 4> result;
258     result.buf.reg[0] = Add(lhs.buf.reg[0], rhs.buf.reg[0]);
259     result.buf.reg[1] = Add(lhs.buf.reg[1], rhs.buf.reg[0]);
260     result.buf.reg[2] = Add(lhs.buf.reg[2], rhs.buf.reg[0]);
261     result.buf.reg[3] = Add(lhs.buf.reg[3], rhs.buf.reg[0]);
262     return result;
263   }
264 };
265 
266 // 8x1 := 8x1 + 1x1
267 template <>
268 struct BroadcastAddImpl<RegBlockInt32<8, 1>, RegBlockInt32<1, 1>> {
269   static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs,
270                                  const RegBlockInt32<1, 1>& rhs) {
271     RegBlockInt32<8, 1> result;
272     const Int32x4 p = Dup<Int32x4>(rhs.buf.reg[0]);
273     for (int i = 0; i < 2; i++) {
274       result.buf.reg[i] = Add(lhs.buf.reg[i], p);
275     }
276     return result;
277   }
278 };
279 
280 // 8x1 := 8x1 + 8x1
281 template <>
282 struct BroadcastAddImpl<RegBlockInt32<8, 1>, RegBlockInt32<8, 1>> {
283   static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs,
284                                  const RegBlockInt32<8, 1>& rhs) {
285     RegBlockInt32<8, 1> result;
286     for (int i = 0; i < 2; i++) {
287       result.buf.reg[i] = Add(lhs.buf.reg[i], rhs.buf.reg[i]);
288     }
289     return result;
290   }
291 };
292 
293 // 8x4 := 8x4 + 1x4
294 template <>
295 struct BroadcastAddImpl<RegBlockInt32<8, 4>, RegBlockInt32<1, 4>> {
296   static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs,
297                                  const RegBlockInt32<1, 4>& rhs) {
298     RegBlockInt32<8, 4> result;
299     result.buf.reg[0] = Add(lhs.buf.reg[0], DupLane<0>(rhs.buf.reg[0]));
300     result.buf.reg[1] = Add(lhs.buf.reg[1], DupLane<0>(rhs.buf.reg[0]));
301     result.buf.reg[2] = Add(lhs.buf.reg[2], DupLane<1>(rhs.buf.reg[0]));
302     result.buf.reg[3] = Add(lhs.buf.reg[3], DupLane<1>(rhs.buf.reg[0]));
303     result.buf.reg[4] = Add(lhs.buf.reg[4], DupLane<2>(rhs.buf.reg[0]));
304     result.buf.reg[5] = Add(lhs.buf.reg[5], DupLane<2>(rhs.buf.reg[0]));
305     result.buf.reg[6] = Add(lhs.buf.reg[6], DupLane<3>(rhs.buf.reg[0]));
306     result.buf.reg[7] = Add(lhs.buf.reg[7], DupLane<3>(rhs.buf.reg[0]));
307     return result;
308   }
309 };
310 
311 // 8x4 := 8x4 + 8x1
312 template <>
313 struct BroadcastAddImpl<RegBlockInt32<8, 4>, RegBlockInt32<8, 1>> {
314   static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs,
315                                  const RegBlockInt32<8, 1>& rhs) {
316     RegBlockInt32<8, 4> result;
317     result.buf.reg[0] = Add(lhs.buf.reg[0], rhs.buf.reg[0]);
318     result.buf.reg[1] = Add(lhs.buf.reg[1], rhs.buf.reg[1]);
319     result.buf.reg[2] = Add(lhs.buf.reg[2], rhs.buf.reg[0]);
320     result.buf.reg[3] = Add(lhs.buf.reg[3], rhs.buf.reg[1]);
321     result.buf.reg[4] = Add(lhs.buf.reg[4], rhs.buf.reg[0]);
322     result.buf.reg[5] = Add(lhs.buf.reg[5], rhs.buf.reg[1]);
323     result.buf.reg[6] = Add(lhs.buf.reg[6], rhs.buf.reg[0]);
324     result.buf.reg[7] = Add(lhs.buf.reg[7], rhs.buf.reg[1]);
325     return result;
326   }
327 };
328 
329 // 1x8 := 1x8 + 1x8
330 template <>
331 struct BroadcastAddImpl<RegBlockInt32<1, 8>, RegBlockInt32<1, 8>> {
332   static RegBlockInt32<1, 8> Run(const RegBlockInt32<1, 8>& lhs,
333                                  const RegBlockInt32<1, 8>& rhs) {
334     RegBlockInt32<1, 8> result;
335     result.buf.reg[0] = Add(lhs.buf.reg[0], rhs.buf.reg[0]);
336     result.buf.reg[1] = Add(lhs.buf.reg[1], rhs.buf.reg[1]);
337     return result;
338   }
339 };
340 
341 // 1x8 := 1x8 + 1x1
342 template <>
343 struct BroadcastAddImpl<RegBlockInt32<1, 8>, RegBlockInt32<1, 1>> {
344   static RegBlockInt32<1, 8> Run(const RegBlockInt32<1, 8>& lhs,
345                                  const RegBlockInt32<1, 1>& rhs) {
346     RegBlockInt32<1, 8> result;
347     result.buf.reg[0] = Add(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0]));
348     result.buf.reg[1] = Add(lhs.buf.reg[1], Dup<Int32x4>(rhs.buf.reg[0]));
349     return result;
350   }
351 };
352 
353 // 4x1 := 4x1 * 1x1
354 template <>
355 struct BroadcastMulImpl<RegBlockInt32<4, 1>, RegBlockInt32<1, 1>> {
356   static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs,
357                                  const RegBlockInt32<1, 1>& rhs) {
358     RegBlockInt32<4, 1> result;
359     result.buf.reg[0] = Mul(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0]));
360     return result;
361   }
362 };
363 
364 // 4x1 := 4x1 * 4x1
365 template <>
366 struct BroadcastMulImpl<RegBlockInt32<4, 1>, RegBlockInt32<4, 1>> {
367   static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs,
368                                  const RegBlockInt32<4, 1>& rhs) {
369     RegBlockInt32<4, 1> result;
370     result.buf.reg[0] = Mul(lhs.buf.reg[0], rhs.buf.reg[0]);
371     return result;
372   }
373 };
374 
375 // 1x4 := 1x4 * 1x4
376 template <>
377 struct BroadcastMulImpl<RegBlockInt32<1, 4>, RegBlockInt32<1, 4>> {
378   static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs,
379                                  const RegBlockInt32<1, 4>& rhs) {
380     RegBlockInt32<1, 4> result;
381     result.buf.reg[0] = Mul(lhs.buf.reg[0], rhs.buf.reg[0]);
382     return result;
383   }
384 };
385 
386 // 1x4 := 1x4 * 1x1
387 template <>
388 struct BroadcastMulImpl<RegBlockInt32<1, 4>, RegBlockInt32<1, 1>> {
389   static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs,
390                                  const RegBlockInt32<1, 1>& rhs) {
391     RegBlockInt32<1, 4> result;
392     result.buf.reg[0] = Mul(lhs.buf.reg[0], rhs.buf.reg[0]);
393     return result;
394   }
395 };
396 
397 // 4x4 := 4x4 * 1x4
398 template <>
399 struct BroadcastMulImpl<RegBlockInt32<4, 4>, RegBlockInt32<1, 4>> {
400   static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs,
401                                  const RegBlockInt32<1, 4>& rhs) {
402     RegBlockInt32<4, 4> result;
403     const Int32x4 p = rhs.buf.reg[0];
404     result.buf.reg[0] = MulByRhsLane<0>(lhs.buf.reg[0], p);
405     result.buf.reg[1] = MulByRhsLane<1>(lhs.buf.reg[1], p);
406     result.buf.reg[2] = MulByRhsLane<2>(lhs.buf.reg[2], p);
407     result.buf.reg[3] = MulByRhsLane<3>(lhs.buf.reg[3], p);
408     return result;
409   }
410 };
411 
412 // 4x4 := 4x4 * 4x1
413 template <>
414 struct BroadcastMulImpl<RegBlockInt32<4, 4>, RegBlockInt32<4, 1>> {
415   static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs,
416                                  const RegBlockInt32<4, 1>& rhs) {
417     RegBlockInt32<4, 4> result;
418     const Int32x4 p = rhs.buf.reg[0];
419     result.buf.reg[0] = Mul(lhs.buf.reg[0], p);
420     result.buf.reg[1] = Mul(lhs.buf.reg[1], p);
421     result.buf.reg[2] = Mul(lhs.buf.reg[2], p);
422     result.buf.reg[3] = Mul(lhs.buf.reg[3], p);
423     return result;
424   }
425 };
426 
427 // 8x1 := 8x1 * 1x1
428 template <>
429 struct BroadcastMulImpl<RegBlockInt32<8, 1>, RegBlockInt32<1, 1>> {
430   static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs,
431                                  const RegBlockInt32<1, 1>& rhs) {
432     RegBlockInt32<8, 1> result;
433     const std::int32_t p = rhs.buf.reg[0];
434     for (int i = 0; i < 2; i++) {
435       result.buf.reg[i] = Mul(lhs.buf.reg[i], p);
436     }
437     return result;
438   }
439 };
440 
441 // 8x1 := 8x1 * 8x1
442 template <>
443 struct BroadcastMulImpl<RegBlockInt32<8, 1>, RegBlockInt32<8, 1>> {
444   static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs,
445                                  const RegBlockInt32<8, 1>& rhs) {
446     RegBlockInt32<8, 1> result;
447     for (int i = 0; i < 2; i++) {
448       result.buf.reg[i] = Mul(lhs.buf.reg[i], rhs.buf.reg[i]);
449     }
450     return result;
451   }
452 };
453 
454 // 8x4 := 8x4 * 1x4
455 template <>
456 struct BroadcastMulImpl<RegBlockInt32<8, 4>, RegBlockInt32<1, 4>> {
457   static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs,
458                                  const RegBlockInt32<1, 4>& rhs) {
459     RegBlockInt32<8, 4> result;
460     const Int32x4 p = rhs.buf.reg[0];
461     for (int i = 0; i < 2; i++) {
462       result.buf.reg[i + 0] = MulByRhsLane<0>(lhs.buf.reg[i + 0], p);
463       result.buf.reg[i + 2] = MulByRhsLane<1>(lhs.buf.reg[i + 2], p);
464       result.buf.reg[i + 4] = MulByRhsLane<2>(lhs.buf.reg[i + 4], p);
465       result.buf.reg[i + 6] = MulByRhsLane<3>(lhs.buf.reg[i + 6], p);
466     }
467     return result;
468   }
469 };
470 
471 // 8x4 := 8x4 * 8x1
472 template <>
473 struct BroadcastMulImpl<RegBlockInt32<8, 4>, RegBlockInt32<8, 1>> {
474   static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs,
475                                  const RegBlockInt32<8, 1>& rhs) {
476     RegBlockInt32<8, 4> result;
477     const Int32x4 p[2]{rhs.buf.reg[0], rhs.buf.reg[1]};
478     for (int i = 0; i < 4; i++) {
479       for (int j = 0; j < 2; j++) {
480         const int k = j + 2 * i;
481         result.buf.reg[k] = Mul(lhs.buf.reg[k], p[j]);
482       }
483     }
484     return result;
485   }
486 };
487 
488 // Rx1 += Rx1 * 1x1
489 template <int Rows>
490 struct BroadcastMulAddImpl<RegBlockInt32<Rows, 1>, RegBlockInt32<1, 1>,
491                            RegBlockInt32<Rows, 1>> {
492   static void Run(const RegBlockInt32<Rows, 1>& lhs,
493                   const RegBlockInt32<1, 1>& rhs, RegBlockInt32<Rows, 1>* acc) {
494     const std::int32_t p = rhs.buf.reg[0];
495     for (int i = 0; i < RegBlockInt32<Rows, 1>::kRegisterCount; i++) {
496       MulAdd(lhs.buf.reg[i], p, &acc->buf.reg[i]);
497     }
498   }
499 };
500 
501 // RxC += Rx1 * 1x1
502 template <int Rows, int Cols>
503 struct BroadcastMulAddImpl<RegBlockInt32<Rows, 1>, RegBlockInt32<1, 1>,
504                            RegBlockInt32<Rows, Cols>> {
505   static void Run(const RegBlockInt32<Rows, 1>& lhs,
506                   const RegBlockInt32<1, 1>& rhs,
507                   RegBlockInt32<Rows, Cols>* acc) {
508     const std::int32_t p = rhs.buf.reg[0];
509     static constexpr int kRegsPerCol = RegBlockInt32<Rows, 1>::kRegisterCount;
510     for (int i = 0; i < kRegsPerCol; i++) {
511       const Int32x4 q = Mul(lhs.buf.reg[i], p);
512       for (int j = 0; j < Cols; j++) {
513         acc->buf.reg[i + j * kRegsPerCol] =
514             Add(acc->buf.reg[i + j * kRegsPerCol], q);
515       }
516     }
517   }
518 };
519 
520 // 1xC += 1xC * 1x1
521 template <int Cols>
522 struct BroadcastMulAddImpl<RegBlockInt32<1, Cols>, RegBlockInt32<1, 1>,
523                            RegBlockInt32<1, Cols>> {
524   static void Run(const RegBlockInt32<1, Cols>& lhs,
525                   const RegBlockInt32<1, 1>& rhs, RegBlockInt32<1, Cols>* acc) {
526     const std::int32_t p = rhs.buf.reg[0];
527     for (int i = 0; i < RegBlockInt32<1, Cols>::kRegisterCount; i++) {
528       MulAdd(lhs.buf.reg[i], p, &acc->buf.reg[i]);
529     }
530   }
531 };
532 
533 // RxC += 1x1 * 1x1
534 template <int Rows, int Cols>
535 struct BroadcastMulAddImpl<RegBlockInt32<1, 1>, RegBlockInt32<1, 1>,
536                            RegBlockInt32<Rows, Cols>> {
537   static void Run(const RegBlockInt32<1, 1>& lhs,
538                   const RegBlockInt32<1, 1>& rhs,
539                   RegBlockInt32<Rows, Cols>* acc) {
540     const Int32x4 p = Dup<Int32x4>(Mul(lhs.buf.reg[0], rhs.buf.reg[0]));
541     for (int i = 0; i < RegBlockInt32<Rows, Cols>::kRegisterCount; i++) {
542       acc->buf.reg[i] = Add(acc->buf.reg[i], p);
543     }
544   }
545 };
546 
547 // 1x1 += 1x1 * 1x1
548 template <>
549 struct BroadcastMulAddImpl<RegBlockInt32<1, 1>, RegBlockInt32<1, 1>,
550                            RegBlockInt32<1, 1>> {
551   static void Run(const RegBlockInt32<1, 1>& lhs,
552                   const RegBlockInt32<1, 1>& rhs, RegBlockInt32<1, 1>* acc) {
553     MulAdd(lhs.buf.reg[0], rhs.buf.reg[0], &acc->buf.reg[0]);
554   }
555 };
556 
557 // Rx4 += Rx1 * 1x4
558 template <int Rows>
559 struct BroadcastMulAddImpl<RegBlockInt32<Rows, 1>, RegBlockInt32<1, 4>,
560                            RegBlockInt32<Rows, 4>> {
561   static void Run(const RegBlockInt32<Rows, 1>& lhs,
562                   const RegBlockInt32<1, 4>& rhs, RegBlockInt32<Rows, 4>* acc) {
563     const Int32x4 p = rhs.buf.reg[0];
564     static constexpr int kRegsPerCol = RegBlockInt32<Rows, 1>::kRegisterCount;
565     for (int i = 0; i < kRegsPerCol; i++) {
566       MulAddByRhsLane<0>(lhs.buf.reg[i], p, &acc->buf.reg[i + 0 * kRegsPerCol]);
567       MulAddByRhsLane<1>(lhs.buf.reg[i], p, &acc->buf.reg[i + 1 * kRegsPerCol]);
568       MulAddByRhsLane<2>(lhs.buf.reg[i], p, &acc->buf.reg[i + 2 * kRegsPerCol]);
569       MulAddByRhsLane<3>(lhs.buf.reg[i], p, &acc->buf.reg[i + 3 * kRegsPerCol]);
570     }
571   }
572 };
573 
574 // Rx4 += 1x4 * 1x1
575 template <int Rows>
576 struct BroadcastMulAddImpl<RegBlockInt32<1, 4>, RegBlockInt32<1, 1>,
577                            RegBlockInt32<Rows, 4>> {
578   static void Run(const RegBlockInt32<1, 4>& lhs,
579                   const RegBlockInt32<1, 1>& rhs, RegBlockInt32<Rows, 4>* acc) {
580     const Int32x4 p = Mul(lhs.buf.reg[0], rhs.buf.reg[0]);
581     Int32x4 q[4];
582     q[0] = DupLane<0>(p);
583     q[1] = DupLane<1>(p);
584     q[2] = DupLane<2>(p);
585     q[3] = DupLane<3>(p);
586     static constexpr int kRegsPerCol = RegBlockInt32<Rows, 1>::kRegisterCount;
587     for (int i = 0; i < kRegsPerCol; i++) {
588       for (int j = 0; j < 4; j++) {
589         acc->buf.reg[i + j * kRegsPerCol] =
590             Add(q[j], acc->buf.reg[i + j * kRegsPerCol]);
591       }
592     }
593   }
594 };
595 
596 // 1xC += 1x1 * 1x1
597 template <int Cols>
598 struct BroadcastMulAddImpl<RegBlockInt32<1, 1>, RegBlockInt32<1, 1>,
599                            RegBlockInt32<1, Cols>> {
600   static void Run(const RegBlockInt32<1, 1>& lhs,
601                   const RegBlockInt32<1, 1>& rhs, RegBlockInt32<1, Cols>* acc) {
602     const Int32x4 p = Dup<Int32x4>(Mul(lhs.buf.reg[0], rhs.buf.reg[0]));
603     for (int i = 0; i < RegBlockInt32<1, Cols>::kRegisterCount; i++) {
604       acc->buf.reg[i] = Add(acc->buf.reg[i], p);
605     }
606   }
607 };
608 
609 // 1x4 += 1x4 * 1x1
610 template <>
611 struct BroadcastMulAddImpl<RegBlockInt32<1, 4>, RegBlockInt32<1, 1>,
612                            RegBlockInt32<1, 4>> {
613   static void Run(const RegBlockInt32<1, 4>& lhs,
614                   const RegBlockInt32<1, 1>& rhs, RegBlockInt32<1, 4>* acc) {
615     const std::int32_t p = rhs.buf.reg[0];
616     MulAdd(lhs.buf.reg[0], p, &acc->buf.reg[0]);
617   }
618 };
619 
620 // 4xC += 4x1 * 1x1
621 template <int Cols>
622 struct BroadcastMulAddImpl<RegBlockInt32<4, 1>, RegBlockInt32<1, 1>,
623                            RegBlockInt32<4, Cols>> {
624   static void Run(const RegBlockInt32<4, 1>& lhs,
625                   const RegBlockInt32<1, 1>& rhs, RegBlockInt32<4, Cols>* acc) {
626     const Int32x4 p = Mul(lhs.buf.reg[0], rhs.buf.reg[0]);
627     for (int i = 0; i < Cols; i++) {
628       acc->buf.reg[i] = Add(p, acc->buf.reg[i]);
629     }
630   }
631 };
632 
633 // 4x1 += 4x1 * 1x1
634 template <>
635 struct BroadcastMulAddImpl<RegBlockInt32<4, 1>, RegBlockInt32<1, 1>,
636                            RegBlockInt32<4, 1>> {
637   static void Run(const RegBlockInt32<4, 1>& lhs,
638                   const RegBlockInt32<1, 1>& rhs, RegBlockInt32<4, 1>* acc) {
639     const std::int32_t p = rhs.buf.reg[0];
640     MulAdd(lhs.buf.reg[0], p, &acc->buf.reg[0]);
641   }
642 };
643 
644 }  // namespace gemmlowp
645 
646 #endif  // GEMMLOWP_INTERNAL_SIMD_WRAPPERS_COMMON_NEON_SSE_H_
647