1 // Copyright 2018 The Gemmlowp Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 // fixedpoint_msa.h: optimized MSA specializations of the templates
16 // in fixedpoint.h.
17 
18 #ifndef GEMMLOWP_INTERNAL_FIXEDPOINT_MSA_H_
19 #define GEMMLOWP_INTERNAL_FIXEDPOINT_MSA_H_
20 
21 #include <msa.h>
22 
23 namespace gemmlowp {
24 
25 template <>
26 struct FixedPointRawTypeTraits<v4i32> {
27   typedef std::int32_t ScalarRawType;
28   static constexpr int kLanes = 4;
29 };
30 
31 template <>
32 struct FixedPointRawTypeTraits<v8i16> {
33   typedef std::int16_t ScalarRawType;
34   static constexpr int kLanes = 8;
35 };
36 
37 template <>
38 inline v4i32 BitAnd(v4i32 a, v4i32 b) {
39   return reinterpret_cast<v4i32>(__builtin_msa_and_v(reinterpret_cast<v16u8>(a),
40                                                      reinterpret_cast<v16u8>(b)));
41 }
42 
43 template <>
44 inline v8i16 BitAnd(v8i16 a, v8i16 b) {
45   return reinterpret_cast<v8i16>(__builtin_msa_and_v(reinterpret_cast<v16u8>(a),
46                                                      reinterpret_cast<v16u8>(b)));
47 }
48 
49 template <>
50 inline v4i32 BitOr(v4i32 a, v4i32 b) {
51   return reinterpret_cast<v4i32>(__builtin_msa_or_v(reinterpret_cast<v16u8>(a),
52                                                     reinterpret_cast<v16u8>(b)));
53 }
54 
55 template <>
56 inline v8i16 BitOr(v8i16 a, v8i16 b) {
57   return reinterpret_cast<v8i16>(__builtin_msa_or_v(reinterpret_cast<v16u8>(a),
58                                                     reinterpret_cast<v16u8>(b)));
59 }
60 
61 template <>
62 inline v4i32 BitXor(v4i32 a, v4i32 b) {
63   return reinterpret_cast<v4i32>(__builtin_msa_xor_v(reinterpret_cast<v16u8>(a),
64                                                      reinterpret_cast<v16u8>(b)));
65 }
66 
67 template <>
68 inline v8i16 BitXor(v8i16 a, v8i16 b) {
69   return reinterpret_cast<v8i16>(__builtin_msa_xor_v(reinterpret_cast<v16u8>(a),
70                                                      reinterpret_cast<v16u8>(b)));
71 }
72 
73 template <>
74 inline v4i32 BitNot(v4i32 a) {
75   return reinterpret_cast<v4i32>(__builtin_msa_nor_v(reinterpret_cast<v16u8>(a),
76                                                      reinterpret_cast<v16u8>(a)));
77 }
78 
79 template <>
80 inline v8i16 BitNot(v8i16 a) {
81   return reinterpret_cast<v8i16>(__builtin_msa_nor_v(reinterpret_cast<v16u8>(a),
82                                                      reinterpret_cast<v16u8>(a)));
83 }
84 
85 template <>
86 inline v4i32 Add(v4i32 a, v4i32 b) {
87   return __builtin_msa_addv_w(a, b);
88 }
89 
90 template <>
91 inline v8i16 Add(v8i16 a, v8i16 b) {
92   return __builtin_msa_addv_h(a, b);
93 }
94 
95 template <>
96 inline v4i32 Sub(v4i32 a, v4i32 b) {
97   return __builtin_msa_subv_w(a, b);
98 }
99 
100 template <>
101 inline v8i16 Sub(v8i16 a, v8i16 b) {
102   return __builtin_msa_subv_h(a, b);
103 }
104 
105 template <>
106 inline v4i32 Neg(v4i32 a) {
107   v4i32 zeroes = __builtin_msa_ldi_w(0);
108   return __builtin_msa_subv_w(zeroes, a);
109 }
110 
111 template <>
112 inline v8i16 Neg(v8i16 a) {
113   v8i16 zeroes = __builtin_msa_ldi_h(0);
114   return __builtin_msa_subv_h(zeroes, a);
115 }
116 
117 template <>
118 inline v4i32 ShiftLeft(v4i32 a, int offset) {
119   return __builtin_msa_sll_w(a, __builtin_msa_fill_w(offset));
120 }
121 
122 template <>
123 inline v8i16 ShiftLeft(v8i16 a, int offset) {
124   return __builtin_msa_sll_h(a, __builtin_msa_fill_h(offset));
125 }
126 
127 template <>
128 inline v4i32 ShiftRight(v4i32 a, int offset) {
129   return __builtin_msa_sra_w(a, __builtin_msa_fill_w(offset));
130 }
131 
132 template <>
133 inline v8i16 ShiftRight(v8i16 a, int offset) {
134   return __builtin_msa_sra_h(a, __builtin_msa_fill_h(offset));
135 }
136 
137 template <>
138 inline v4i32 SelectUsingMask(v4i32 if_mask, v4i32 then_val, v4i32 else_val) {
139   if_mask = reinterpret_cast<v4i32>(__builtin_msa_bsel_v(reinterpret_cast<v16u8>(if_mask),
140                                                          reinterpret_cast<v16u8>(else_val),
141                                                          reinterpret_cast<v16u8>(then_val)));
142   return if_mask;
143 }
144 
145 template <>
146 inline v8i16 SelectUsingMask(v8i16 if_mask, v8i16 then_val, v8i16 else_val) {
147   if_mask = reinterpret_cast<v8i16>(__builtin_msa_bsel_v(reinterpret_cast<v16u8>(if_mask),
148                                                          reinterpret_cast<v16u8>(else_val),
149                                                          reinterpret_cast<v16u8>(then_val)));
150   return if_mask;
151 }
152 
153 template <>
154 inline v4i32 MaskIfEqual(v4i32 a, v4i32 b) {
155   return __builtin_msa_ceq_w(a, b);
156 }
157 
158 template <>
159 inline v8i16 MaskIfEqual(v8i16 a, v8i16 b) {
160   return __builtin_msa_ceq_h(a, b);
161 }
162 
163 template <>
164 inline v4i32 MaskIfNotEqual(v4i32 a, v4i32 b) {
165   return BitNot(MaskIfEqual(a, b));
166 }
167 
168 template <>
169 inline v8i16 MaskIfNotEqual(v8i16 a, v8i16 b) {
170   return BitNot(MaskIfEqual(a, b));
171 }
172 
173 template <>
174 inline v4i32 MaskIfZero(v4i32 a) {
175   return __builtin_msa_ceqi_w(a, 0);
176 }
177 
178 template <>
179 inline v8i16 MaskIfZero(v8i16 a) {
180   return __builtin_msa_ceqi_h(a, 0);
181 }
182 
183 template <>
184 inline v4i32 MaskIfNonZero(v4i32 a) {
185   return BitNot(MaskIfZero(a));
186 }
187 
188 template <>
189 inline v8i16 MaskIfNonZero(v8i16 a) {
190   return BitNot(MaskIfZero(a));
191 }
192 
193 template <>
194 inline v4i32 MaskIfGreaterThan(v4i32 a, v4i32 b) {
195   return __builtin_msa_clt_s_w(b, a);
196 }
197 
198 template <>
199 inline v8i16 MaskIfGreaterThan(v8i16 a, v8i16 b) {
200   return __builtin_msa_clt_s_h(b, a);
201 }
202 
203 template <>
204 inline v4i32 MaskIfGreaterThanOrEqual(v4i32 a, v4i32 b) {
205   return __builtin_msa_cle_s_w(b, a);
206 }
207 
208 template <>
209 inline v8i16 MaskIfGreaterThanOrEqual(v8i16 a, v8i16 b) {
210   return __builtin_msa_cle_s_h(b, a);
211 }
212 
213 template <>
214 inline v4i32 MaskIfLessThan(v4i32 a, v4i32 b) {
215   return __builtin_msa_clt_s_w(a, b);
216 }
217 
218 template <>
219 inline v8i16 MaskIfLessThan(v8i16 a, v8i16 b) {
220   return __builtin_msa_clt_s_h(a, b);
221 }
222 
223 template <>
224 inline v4i32 MaskIfLessThanOrEqual(v4i32 a, v4i32 b) {
225   return __builtin_msa_cle_s_w(a, b);
226 }
227 
228 template <>
229 inline v8i16 MaskIfLessThanOrEqual(v8i16 a, v8i16 b) {
230   return __builtin_msa_cle_s_h(a, b);
231 }
232 
233 template <>
234 inline bool All(v4i32 a) {
235   return __builtin_msa_bz_v(reinterpret_cast<v16u8>(BitNot(a)));
236 }
237 
238 template <>
239 inline bool All(v8i16 a) {
240   return __builtin_msa_bz_v(reinterpret_cast<v16u8>(BitNot(a)));
241 }
242 
243 template <>
244 inline bool Any(v4i32 a) {
245   return __builtin_msa_bnz_v(reinterpret_cast<v16u8>(a));
246 }
247 
248 template <>
249 inline bool Any(v8i16 a) {
250   return __builtin_msa_bnz_v(reinterpret_cast<v16u8>(a));
251 }
252 
253 template <>
254 inline v4i32 RoundingHalfSum(v4i32 a, v4i32 b) {
255   return __builtin_msa_aver_s_w(a, b);
256 }
257 
258 template <>
259 inline v8i16 RoundingHalfSum(v8i16 a, v8i16 b) {
260   return __builtin_msa_aver_s_h(a, b);
261 }
262 
263 template <>
264 inline v4i32 SaturatingRoundingDoublingHighMul(v4i32 a, v4i32 b) {
265   return __builtin_msa_mulr_q_w(a, b);
266 }
267 
268 template <>
269 inline v8i16 SaturatingRoundingDoublingHighMul(v8i16 a, v8i16 b) {
270   return __builtin_msa_mulr_q_h(a, b);
271 }
272 
273 template <int Exponent>
274 struct ImplSaturatingRoundingMultiplyByPOT<Exponent, v4i32, 1> {
275   static v4i32 eval(v4i32 x) {
276     static_assert(Exponent >= 0 && Exponent < 32, "");
277     if (Exponent < 5) {
278       for (int i = 0; i < Exponent; i++) {
279         x = __builtin_msa_adds_s_w(x, x);
280       }
281       return x;
282     } else {
283       // Saturate each signed 32-bit element to (32 - Exponent)
284       // bits (this takes full care of negative elements).
285       v4i32 res = __builtin_msa_sat_s_w(x, 31 - Exponent);
286       // Set tmp to 0x7FFFFFFF for those elements which staturated
287       // to smaller (positive) values and 0 for all others.
288       v4i32 tmp = __builtin_msa_srli_w(__builtin_msa_clt_s_w(res, x), 1);
289       // Shift the saturated elements. The positive saturated elements
290       // will have Exponent trailing zero bits after the shift. Those
291       // need to be ones, not zeroes.
292       res = __builtin_msa_slli_w(res, Exponent);
293       // Finally, set those trailing zero bits to ones.
294       res = reinterpret_cast<v4i32>(__builtin_msa_or_v(reinterpret_cast<v16u8>(res),
295                                                        reinterpret_cast<v16u8>(tmp)));
296       return res;
297     }
298   }
299 };
300 
301 template <int Exponent>
302 struct ImplSaturatingRoundingMultiplyByPOT<Exponent, v8i16, 1> {
303   static v8i16 eval(v8i16 x) {
304     static_assert(Exponent >= 0 && Exponent < 16, "");
305     if (Exponent < 5) {
306       for (int i = 0; i < Exponent; i++) {
307         x = __builtin_msa_adds_s_h(x, x);
308       }
309       return x;
310     } else {
311       // Saturate each signed 16-bit element to (16 - Exponent)
312       // bits (this takes full care of negative elements).
313       v8i16 res = __builtin_msa_sat_s_h(x, 15 - Exponent);
314       // Set tmp to 0x7FFF for those elements which staturated
315       // to smaller (positive) values and 0 for all others.
316       v8i16 tmp = __builtin_msa_srli_h(__builtin_msa_clt_s_h(res, x), 1);
317       // Shift the saturated elements. The positive saturated elements
318       // will have Exponent trailing zero bits after the shift. Those
319       // need to be ones, not zeroes.
320       res = __builtin_msa_slli_h(res, Exponent);
321       // Finally, set those trailing zero bits to ones.
322       res = reinterpret_cast<v8i16>(__builtin_msa_or_v(reinterpret_cast<v16u8>(res),
323                                                        reinterpret_cast<v16u8>(tmp)));
324       return res;
325     }
326   }
327 };
328 
329 template <int Exponent>
330 struct ImplSaturatingRoundingMultiplyByPOT<Exponent, v4i32, -1> {
331   static v4i32 eval(v4i32 x) {
332     static_assert(-31 <= Exponent && Exponent <= -1, "");
333     // Isolate the sign bits.
334     v4i32 sign = __builtin_msa_srli_w(x, 31);
335     // Decrement the negative elements by 1 (with saturation).
336     x = __builtin_msa_subs_s_w(x, sign);
337     // Arithmetic shift right with rounding.
338     // The srari instruction rounds all midpoint values towards +infinity.
339     // It will correctly round negative midpoint values as we just
340     // decremented the negative values by 1.
341     return __builtin_msa_srari_w(x, -Exponent);
342   }
343 };
344 
345 template <int Exponent>
346 struct ImplSaturatingRoundingMultiplyByPOT<Exponent, v8i16, -1> {
347   static v8i16 eval(v8i16 x) {
348     static_assert(-15 <= Exponent && Exponent <= -1, "");
349     // Isolate the sign bits.
350     v8i16 sign = __builtin_msa_srli_h(x, 15);
351     // Decrement the negative elements by 1 (with saturation).
352     x = __builtin_msa_subs_s_h(x, sign);
353     // Arithmetic shift right with rounding.
354     // The srari instruction rounds all midpoint values towards +infinity.
355     // It will correctly round negative midpoint values as we just
356     // decremented the negative values by 1.
357     return __builtin_msa_srari_h(x, -Exponent);
358   }
359 };
360 
361 template <>
362 inline v4i32 RoundingDivideByPOT(v4i32 x, int exponent) {
363   v4i32 e = __builtin_msa_fill_w(exponent);
364   // Isolate the sign bits.
365   v4i32 sign = __builtin_msa_srli_w(x, 31);
366   // Reset them to 0 if exponent is 0.
367   sign = __builtin_msa_min_s_w(sign, e);
368   // Decrement the negative elements by 1 (with saturation)
369   // if exponent is non-zero.
370   x = __builtin_msa_subs_s_w(x, sign);
371   // Arithmetic shift right with rounding.
372   // The srar instruction rounds all midpoint values towards +infinity.
373   // It will correctly round negative midpoint values as we just
374   // decremented the negative values by 1.
375   return __builtin_msa_srar_w(x, e);
376 }
377 
378 template <>
379 inline v8i16 RoundingDivideByPOT(v8i16 x, int exponent) {
380   v8i16 e = __builtin_msa_fill_h(exponent);
381   // Isolate the sign bits.
382   v8i16 sign = __builtin_msa_srli_h(x, 15);
383   // Reset them to 0 if exponent is 0.
384   sign = __builtin_msa_min_s_h(sign, e);
385   // Decrement the negative elements by 1 (with saturation)
386   // if exponent is non-zero.
387   x = __builtin_msa_subs_s_h(x, sign);
388   // Arithmetic shift right with rounding.
389   // The srar instruction rounds all midpoint values towards +infinity.
390   // It will correctly round negative midpoint values as we just
391   // decremented the negative values by 1.
392   return __builtin_msa_srar_h(x, e);
393 }
394 
395 template <>
396 inline v4i32 Dup<v4i32>(std::int32_t x) {
397   return __builtin_msa_fill_w(x);
398 }
399 
400 template <>
401 inline v8i16 Dup<v8i16>(std::int16_t x) {
402   return __builtin_msa_fill_h(x);
403 }
404 
405 // So far this is only needed for int16.
406 template <>
407 inline v8i16 SaturatingAdd(v8i16 a, v8i16 b) {
408   return __builtin_msa_adds_s_h(a, b);
409 }
410 
411 }  // end namespace gemmlowp
412 
413 #endif  // GEMMLOWP_INTERNAL_FIXEDPOINT_MSA_H_
414