1 // Copyright 2017 The Gemmlowp Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 // simd_wrappers.h: some inline functions wrapping SIMD intrinsics,
16 // extending the set of such functions from fixedpoint.h.
17 
18 #ifndef GEMMLOWP_INTERNAL_SIMD_WRAPPERS_H_
19 #define GEMMLOWP_INTERNAL_SIMD_WRAPPERS_H_
20 
21 #include <algorithm>
22 #include <type_traits>
23 #include "../fixedpoint/fixedpoint.h"
24 
25 namespace gemmlowp {
26 
// Maps (scalar type, scalar count) to the type of one register holding those
// scalars. This generic version uses one scalar per register; wider register
// types are presumably selected by specializations in the SIMD-specific
// headers included at the end of this file (not visible here).
template <typename tScalar, int tCount>
struct RegisterType {
  using Type = tScalar;
};
31 
// Returns the smaller of two int32 scalars (a when they are equal).
inline std::int32_t Min(std::int32_t a, std::int32_t b) {
  return (b < a) ? b : a;
}
35 
// Returns the larger of two int32 scalars (a when they are equal).
inline std::int32_t Max(std::int32_t a, std::int32_t b) {
  return (a < b) ? b : a;
}
39 
// Scalar multiply-accumulate: *acc = *acc + lhs * rhs.
// NOTE(review): this is plain int32 arithmetic, so overflow of the product or
// the sum is undefined behavior; callers presumably keep values in range.
inline void MulAdd(std::int32_t lhs, std::int32_t rhs, std::int32_t* acc) {
  const std::int32_t product = lhs * rhs;
  *acc += product;
}
43 
44 template <typename tScalarType, int tScalarCount>
45 struct RegisterBuffer {
46   using ScalarType = tScalarType;
47   static constexpr int kScalarCount = tScalarCount;
48   using RegisterType = typename RegisterType<ScalarType, kScalarCount>::Type;
49   static_assert((kScalarCount & (kScalarCount - 1)) == 0,
50                 "kScalarCount must be a power of two");
51   static_assert(sizeof(RegisterType) % sizeof(ScalarType) == 0, "");
52   static constexpr int kRegisterLanes =
53       sizeof(RegisterType) / sizeof(ScalarType);
54   static constexpr int kRegisterCount =
55       (kScalarCount * sizeof(ScalarType) + sizeof(RegisterType) - 1) /
56       sizeof(RegisterType);
57 
58   RegisterType reg[kRegisterCount];
59 };
60 
61 template <typename tScalarType, int tRows, int tCols>
62 struct RegisterBlock {
63   using ScalarType = tScalarType;
64   static constexpr int kRows = tRows;
65   static constexpr int kCols = tCols;
66   static constexpr int kScalarCount = kRows * kCols;
67   using BufferType = RegisterBuffer<ScalarType, kScalarCount>;
68   using RegisterType = typename BufferType::RegisterType;
69   static constexpr int kRegisterCount = BufferType::kRegisterCount;
70   static constexpr int kRegisterLanes = BufferType::kRegisterLanes;
71 
72   BufferType buf;
73 };
74 
// Register-by-register sum of two blocks of the same type, using the Add()
// overload for the block's RegisterType.
template <typename RegisterBlockType>
struct RegisterBlockAddImpl {
  static RegisterBlockType Run(const RegisterBlockType& lhs,
                               const RegisterBlockType& rhs) {
    RegisterBlockType sum;
    int index = 0;
    while (index < RegisterBlockType::kRegisterCount) {
      sum.buf.reg[index] = Add(lhs.buf.reg[index], rhs.buf.reg[index]);
      ++index;
    }
    return sum;
  }
};
86 
87 template <typename RegisterBlockType>
RegisterBlockAdd(const RegisterBlockType & lhs,const RegisterBlockType & rhs)88 RegisterBlockType RegisterBlockAdd(const RegisterBlockType& lhs,
89                                    const RegisterBlockType& rhs) {
90   return RegisterBlockAddImpl<RegisterBlockType>::Run(lhs, rhs);
91 }
92 
93 template <typename LhsType, typename RhsType>
94 struct ShouldFlipLhsRhs {
95   static constexpr bool kValue =
96       (LhsType::kScalarCount < RhsType::kScalarCount) ||
97       (LhsType::kScalarCount == RhsType::kScalarCount &&
98        (LhsType::kRows < RhsType::kRows));
99 };
100 
101 template <typename LhsType, typename RhsType,
102           bool Flip = ShouldFlipLhsRhs<LhsType, RhsType>::kValue>
103 struct FlipLhsRhs {
104   using FlippedLhsType = LhsType;
105   using FlippedRhsType = RhsType;
FlippedLhsFlipLhsRhs106   static const FlippedLhsType& FlippedLhs(const LhsType& lhs,
107                                           const RhsType& rhs) {
108     return lhs;
109   }
FlippedRhsFlipLhsRhs110   static const FlippedRhsType& FlippedRhs(const LhsType& lhs,
111                                           const RhsType& rhs) {
112     return rhs;
113   }
114 };
115 
116 template <typename LhsType, typename RhsType>
117 struct FlipLhsRhs<LhsType, RhsType, true> {
118   using FlippedLhsType = RhsType;
119   using FlippedRhsType = LhsType;
120   static const FlippedLhsType& FlippedLhs(const LhsType& lhs,
121                                           const RhsType& rhs) {
122     return rhs;
123   }
124   static const FlippedRhsType& FlippedRhs(const LhsType& lhs,
125                                           const RhsType& rhs) {
126     return lhs;
127   }
128 };
129 
130 template <typename Lhs, typename Rhs>
131 struct BroadcastBinaryOpShape {
132   static constexpr int kRows =
133       Lhs::kRows > Rhs::kRows ? Lhs::kRows : Rhs::kRows;
134   static constexpr int kCols =
135       Lhs::kCols > Rhs::kCols ? Lhs::kCols : Rhs::kCols;
136 };
137 
138 template <typename Lhs, typename Rhs>
139 struct BroadcastBinaryOpRegisterBlock {
140   using Shape = BroadcastBinaryOpShape<Lhs, Rhs>;
141   using ScalarType = typename Lhs::ScalarType;
142   using Type = RegisterBlock<ScalarType, Shape::kRows, Shape::kCols>;
143 };
144 
145 template <typename Lhs, typename Rhs>
146 struct BroadcastAddImpl {
147   using ResultBlockType =
148       typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type;
149   static ResultBlockType Run(const Lhs& lhs, const Rhs& rhs) {
150     ResultBlockType result;
151     static constexpr int Rows = ResultBlockType::kRows;
152     static constexpr int Cols = ResultBlockType::kCols;
153     static constexpr int LhsRows = Lhs::kRows;
154     static constexpr int LhsCols = Lhs::kCols;
155     static constexpr int RhsRows = Rhs::kRows;
156     static constexpr int RhsCols = Rhs::kCols;
157 
158     static_assert(LhsRows == Rows || LhsRows == 1, "");
159     static_assert(RhsRows == Rows || RhsRows == 1, "");
160     static_assert(LhsCols == Cols || LhsCols == 1, "");
161     static_assert(RhsCols == Cols || RhsCols == 1, "");
162     static_assert(ResultBlockType::kRegisterLanes == 1,
163                   "This path is only for scalar values");
164     static_assert(Lhs::kRegisterLanes == 1,
165                   "This path is only for scalar values");
166     static_assert(Rhs::kRegisterLanes == 1,
167                   "This path is only for scalar values");
168 
169     for (int c = 0; c < Cols; c++) {
170       const int lhs_c = LhsCols == Cols ? c : 0;
171       const int rhs_c = RhsCols == Cols ? c : 0;
172       for (int r = 0; r < Rows; r++) {
173         const int lhs_r = LhsRows == Rows ? r : 0;
174         const int rhs_r = RhsRows == Rows ? r : 0;
175         result.buf.reg[r + c * Rows] =
176             Add(lhs.buf.reg[lhs_r + lhs_c * LhsRows],
177                 rhs.buf.reg[rhs_r + rhs_c * RhsRows]);
178       }
179     }
180     return result;
181   }
182 };
183 
184 template <typename Lhs, typename Rhs>
185 typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type BroadcastAdd(
186     const Lhs& lhs, const Rhs& rhs) {
187   using Flip = FlipLhsRhs<Lhs, Rhs>;
188   return BroadcastAddImpl<
189       typename Flip::FlippedLhsType,
190       typename Flip::FlippedRhsType>::Run(Flip::FlippedLhs(lhs, rhs),
191                                           Flip::FlippedRhs(lhs, rhs));
192 }
193 
194 template <typename Lhs, typename Rhs>
195 struct BroadcastMulImpl {
196   using ResultBlockType =
197       typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type;
198   static ResultBlockType Run(const Lhs& lhs, const Rhs& rhs) {
199     ResultBlockType result;
200     static constexpr int Rows = ResultBlockType::kRows;
201     static constexpr int Cols = ResultBlockType::kCols;
202     static constexpr int LhsRows = Lhs::kRows;
203     static constexpr int LhsCols = Lhs::kCols;
204     static constexpr int RhsRows = Rhs::kRows;
205     static constexpr int RhsCols = Rhs::kCols;
206     static_assert(ResultBlockType::kRegisterLanes == 1,
207                   "This path is only for scalar values");
208     static_assert(Lhs::kRegisterLanes == 1,
209                   "This path is only for scalar values");
210     static_assert(Rhs::kRegisterLanes == 1,
211                   "This path is only for scalar values");
212 
213     static_assert(LhsRows == Rows || LhsRows == 1, "");
214     static_assert(RhsRows == Rows || RhsRows == 1, "");
215     static_assert(LhsCols == Cols || LhsCols == 1, "");
216     static_assert(RhsCols == Cols || RhsCols == 1, "");
217     for (int c = 0; c < Cols; c++) {
218       const int lhs_c = LhsCols == Cols ? c : 0;
219       const int rhs_c = RhsCols == Cols ? c : 0;
220       for (int r = 0; r < Rows; r++) {
221         const int lhs_r = LhsRows == Rows ? r : 0;
222         const int rhs_r = RhsRows == Rows ? r : 0;
223         result.buf.reg[r + c * Rows] =
224             Mul(lhs.buf.reg[lhs_r + lhs_c * LhsRows],
225                 rhs.buf.reg[rhs_r + rhs_c * RhsRows]);
226       }
227     }
228     return result;
229   }
230 };
231 
232 template <typename Lhs, typename Rhs>
233 typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type BroadcastMul(
234     const Lhs& lhs, const Rhs& rhs) {
235   using Flip = FlipLhsRhs<Lhs, Rhs>;
236   return BroadcastMulImpl<
237       typename Flip::FlippedLhsType,
238       typename Flip::FlippedRhsType>::Run(Flip::FlippedLhs(lhs, rhs),
239                                           Flip::FlippedRhs(lhs, rhs));
240 }
241 
// Scalar-lane broadcasting multiply-accumulate into *acc, whose shape defines
// the iteration space. Along each dimension an operand either matches the
// accumulator extent or has extent 1 and is reused. Blocks are indexed
// column-major.
template <typename Lhs, typename Rhs, typename Acc>
struct BroadcastMulAddImpl {
  static void Run(const Lhs& lhs, const Rhs& rhs, Acc* acc) {
    static constexpr int Rows = Acc::kRows;
    static constexpr int Cols = Acc::kCols;
    static constexpr int LhsRows = Lhs::kRows;
    static constexpr int LhsCols = Lhs::kCols;
    static constexpr int RhsRows = Rhs::kRows;
    static constexpr int RhsCols = Rhs::kCols;

    static_assert(Acc::kRegisterLanes == 1,
                  "This path is only for scalar values");
    static_assert(Lhs::kRegisterLanes == 1,
                  "This path is only for scalar values");
    static_assert(Rhs::kRegisterLanes == 1,
                  "This path is only for scalar values");
    static_assert(LhsRows == Rows || LhsRows == 1, "");
    static_assert(RhsRows == Rows || RhsRows == 1, "");
    static_assert(LhsCols == Cols || LhsCols == 1, "");
    static_assert(RhsCols == Cols || RhsCols == 1, "");

    // Index step per accumulator row / column for each operand; a broadcast
    // (extent-1) dimension contributes a step of 0.
    static constexpr int kLhsRowStep = LhsRows == Rows ? 1 : 0;
    static constexpr int kLhsColStep = LhsCols == Cols ? LhsRows : 0;
    static constexpr int kRhsRowStep = RhsRows == Rows ? 1 : 0;
    static constexpr int kRhsColStep = RhsCols == Cols ? RhsRows : 0;

    for (int c = 0; c < Cols; ++c) {
      for (int r = 0; r < Rows; ++r) {
        MulAdd(lhs.buf.reg[r * kLhsRowStep + c * kLhsColStep],
               rhs.buf.reg[r * kRhsRowStep + c * kRhsColStep],
               &acc->buf.reg[r + c * Rows]);
      }
    }
  }
};
275 
276 template <typename Lhs, typename Rhs, typename Acc>
277 void BroadcastMulAdd(const Lhs& lhs, const Rhs& rhs, Acc* acc) {
278   using Flip = FlipLhsRhs<Lhs, Rhs>;
279   BroadcastMulAddImpl<typename Flip::FlippedLhsType,
280                       typename Flip::FlippedRhsType,
281                       Acc>::Run(Flip::FlippedLhs(lhs, rhs),
282                                 Flip::FlippedRhs(lhs, rhs), acc);
283 }
284 
// Primary template for Load(): real loads come from specializations on the
// (register block, source object) pair. Instantiating this version with any
// SrcObjectType other than void triggers the static_assert below.
template <typename tRegisterBlockType, typename tSrcObjectType>
struct LoadImpl {
  static_assert(std::is_same<tSrcObjectType, void>::value,
                "This generic impl should never be hit");
};
290 
291 template <typename ScalarType, int Rows, int Cols, typename SrcScalarType>
292 struct LoadImpl<RegisterBlock<ScalarType, Rows, Cols>,
293                 MatrixMap<SrcScalarType, MapOrder::ColMajor>> {
294   using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>;
295   using SrcObjectType = MatrixMap<SrcScalarType, MapOrder::ColMajor>;
296   static RegisterBlockType Run(const SrcObjectType& src, int row, int col) {
297     RegisterBlockType result;
298     int i = 0;
299     for (int c = 0; c < Cols; c++) {
300       const ScalarType* src_ptr = src.data(row, col + c);
301       for (int r = 0; r < Rows; r++) {
302         result.buf.reg[i++] = *src_ptr++;
303       }
304     }
305     return result;
306   }
307 };
308 
309 template <typename ScalarType, int Rows, int Cols, typename SrcScalarType,
310           VectorShape Shape>
311 struct LoadImpl<RegisterBlock<ScalarType, Rows, Cols>,
312                 VectorMap<SrcScalarType, Shape>> {
313   using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>;
314   using SrcObjectType = VectorMap<SrcScalarType, Shape>;
315   static RegisterBlockType Run(const SrcObjectType& src, int pos) {
316     static_assert(Shape == VectorShape::Col || Rows == 1, "");
317     static_assert(Shape == VectorShape::Row || Cols == 1, "");
318     RegisterBlockType result;
319     for (int i = 0; i < Rows * Cols; i++) {
320       result.buf.reg[i] = src(pos + i);
321     }
322     return result;
323   }
324 };
325 
326 template <typename ScalarType, int Rows, int Cols, typename SrcScalarType,
327           VectorShape Shape>
328 struct LoadImpl<RegisterBlock<ScalarType, Rows, Cols>,
329                 VectorDup<SrcScalarType, Shape>> {
330   using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>;
331   using SrcObjectType = VectorDup<SrcScalarType, Shape>;
332   static RegisterBlockType Run(const SrcObjectType& src, int) {
333     static_assert(Shape == VectorShape::Col || Rows == 1, "");
334     static_assert(Shape == VectorShape::Row || Cols == 1, "");
335     RegisterBlockType result;
336     for (int i = 0; i < Rows * Cols; i++) {
337       result.buf.reg[i] = src(0);
338     }
339     return result;
340   }
341 };
342 
343 template <typename RegisterBlockType, typename SrcObjectType>
344 RegisterBlockType Load(const SrcObjectType& src, int row, int col) {
345   return LoadImpl<RegisterBlockType, SrcObjectType>::Run(src, row, col);
346 }
347 
348 template <typename RegisterBlockType, typename SrcObjectType>
349 RegisterBlockType Load(const SrcObjectType& src, int pos) {
350   return LoadImpl<RegisterBlockType, SrcObjectType>::Run(src, pos);
351 }
352 
// Copies kScalarCount consecutive scalars from src into a scalar-lane register
// block, in order (src[i] goes to buf.reg[i]).
template <typename RegisterBlockType>
struct LoadContiguousImpl {
  using ScalarType = typename RegisterBlockType::ScalarType;
  static_assert(RegisterBlockType::kRegisterLanes == 1,
                "This path is only for scalar values");
  static RegisterBlockType Run(const ScalarType* src) {
    RegisterBlockType result;
    const ScalarType* cursor = src;
    for (auto* dst = result.buf.reg;
         dst != result.buf.reg + RegisterBlockType::kScalarCount; ++dst) {
      *dst = *cursor++;
    }
    return result;
  }
};
366 
367 template <typename RegisterBlockType>
368 RegisterBlockType LoadContiguous(
369     const typename RegisterBlockType::ScalarType* src) {
370   return LoadContiguousImpl<RegisterBlockType>::Run(src);
371 }
372 
// Shape of the block LoadForBroadcasting reads from a given source type.
// Intentionally empty: only the specializations define kRows / kCols.
template <int tBroadcastRows, int tBroadcastCols, typename tSrcObjectType>
struct LoadForBroadcastingShape {};
375 
376 template <int BroadcastRows, int BroadcastCols, typename ScalarType,
377           VectorShape Shape>
378 struct LoadForBroadcastingShape<BroadcastRows, BroadcastCols,
379                                 VectorMap<ScalarType, Shape>> {
380   static constexpr int kRows = Shape == VectorShape::Col ? BroadcastRows : 1;
381   static constexpr int kCols = Shape == VectorShape::Row ? BroadcastCols : 1;
382 };
383 
384 template <int BroadcastRows, int BroadcastCols, typename ScalarType,
385           VectorShape Shape>
386 struct LoadForBroadcastingShape<BroadcastRows, BroadcastCols,
387                                 VectorDup<ScalarType, Shape>> {
388   static constexpr int kRows = 1;
389   static constexpr int kCols = 1;
390 };
391 
392 template <typename RegisterBlockType, typename SrcObjectType>
393 struct LoadForBroadcastingRegisterBlock {
394   using Shape =
395       LoadForBroadcastingShape<RegisterBlockType::kRows,
396                                RegisterBlockType::kCols, SrcObjectType>;
397   using ScalarType = typename RegisterBlockType::ScalarType;
398   using Type = RegisterBlock<ScalarType, Shape::kRows, Shape::kCols>;
399 };
400 
// Primary template for LoadForBroadcasting(): real loads come from
// specializations on the (register block, source object) pair. Instantiating
// this version with any SrcObjectType other than void triggers the
// static_assert below.
template <typename tRegisterBlockType, typename tSrcObjectType>
struct LoadForBroadcastingImpl {
  static_assert(std::is_same<tSrcObjectType, void>::value,
                "This generic impl should never be hit");
};
406 
407 template <typename ScalarType, int Rows, int Cols, typename SrcScalarType,
408           VectorShape Shape>
409 struct LoadForBroadcastingImpl<RegisterBlock<ScalarType, Rows, Cols>,
410                                VectorMap<SrcScalarType, Shape>> {
411   using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>;
412   using SrcObjectType = VectorMap<SrcScalarType, Shape>;
413   using ResultBlockType =
414       typename LoadForBroadcastingRegisterBlock<RegisterBlockType,
415                                                 SrcObjectType>::Type;
416   static_assert(ResultBlockType::kRegisterLanes == 1,
417                 "This path is only for scalar values");
418   static ResultBlockType Run(const SrcObjectType& src, int pos) {
419     ResultBlockType result;
420     for (int c = 0; c < ResultBlockType::kCols; c++) {
421       for (int r = 0; r < ResultBlockType::kRows; r++) {
422         const int i = Shape == VectorShape::Col ? r : c;
423         result.buf.reg[r + c * ResultBlockType::kRows] = src(pos + i);
424       }
425     }
426     return result;
427   }
428 };
429 
430 template <typename ScalarType, int Rows, int Cols, typename SrcScalarType,
431           VectorShape Shape>
432 struct LoadForBroadcastingImpl<RegisterBlock<ScalarType, Rows, Cols>,
433                                VectorDup<SrcScalarType, Shape>> {
434   using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>;
435   using SrcObjectType = VectorDup<SrcScalarType, Shape>;
436   using ResultBlockType =
437       typename LoadForBroadcastingRegisterBlock<RegisterBlockType,
438                                                 SrcObjectType>::Type;
439   static_assert(ResultBlockType::kRegisterLanes == 1,
440                 "This path is only for scalar values");
441   static ResultBlockType Run(const SrcObjectType& src, int) {
442     ResultBlockType result;
443     for (int c = 0; c < ResultBlockType::kCols; c++) {
444       for (int r = 0; r < ResultBlockType::kRows; r++) {
445         result.buf.reg[r + c * ResultBlockType::kRows] = src(0);
446       }
447     }
448     return result;
449   }
450 };
451 
452 template <typename RegisterBlockType, typename SrcObjectType>
453 typename LoadForBroadcastingRegisterBlock<RegisterBlockType,
454                                           SrcObjectType>::Type
455 LoadForBroadcasting(const SrcObjectType& src, int row, int col) {
456   return LoadForBroadcastingImpl<RegisterBlockType, SrcObjectType>::Run(
457       src, row, col);
458 }
459 
460 template <typename RegisterBlockType, typename SrcObjectType>
461 typename LoadForBroadcastingRegisterBlock<RegisterBlockType,
462                                           SrcObjectType>::Type
463 LoadForBroadcasting(const SrcObjectType& src, int pos) {
464   return LoadForBroadcastingImpl<RegisterBlockType, SrcObjectType>::Run(src,
465                                                                         pos);
466 }
467 
468 template <int ConstantValue, typename RegisterBlockType>
469 struct AddConstantImpl {
470   static void Run(RegisterBlockType* block) {
471     using RegisterType = typename RegisterBlockType::RegisterType;
472     const RegisterType dup = Dup<RegisterType>(ConstantValue);
473     for (int i = 0; i < RegisterBlockType::kRegisterCount; i++) {
474       block->buf.reg[i] = Add(block->buf.reg[i], dup);
475     }
476   }
477 };
478 
479 template <typename RegisterBlockType>
480 struct AddConstantImpl<0, RegisterBlockType> {
481   static void Run(RegisterBlockType*) {
482     // This is a no-op.
483   }
484 };
485 
486 template <int ConstantValue, typename RegisterBlockType>
487 void AddConstant(RegisterBlockType* block) {
488   AddConstantImpl<ConstantValue, RegisterBlockType>::Run(block);
489 }
490 
491 template <int N>
492 using RegBufferInt32 = RegisterBuffer<std::int32_t, N>;
493 template <int N>
494 using RegBufferInt16 = RegisterBuffer<std::int16_t, N>;
495 template <int N>
496 using RegBufferUint8 = RegisterBuffer<std::uint8_t, N>;
497 template <int R, int C>
498 using RegBlockInt32 = RegisterBlock<std::int32_t, R, C>;
499 template <int R, int C>
500 using RegBlockInt16 = RegisterBlock<std::int16_t, R, C>;
501 template <int R, int C>
502 using RegBlockUint8 = RegisterBlock<std::uint8_t, R, C>;
503 
504 }  // end namespace gemmlowp
505 
506 #if defined GEMMLOWP_NEON
507 #include "simd_wrappers_neon.h"
508 #elif defined GEMMLOWP_SSE4
509 #include "simd_wrappers_sse.h"
510 #elif defined GEMMLOWP_MSA
511 #include "simd_wrappers_msa.h"
512 #endif
513 
514 #endif  // GEMMLOWP_INTERNAL_SIMD_WRAPPERS_H_
515