// Copyright 2015 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// fixedpoint_SSE.h: optimized SSE specializations of the templates
// in fixedpoint.h.

#ifndef GEMMLOWP_INTERNAL_FIXEDPOINT_SSE_H_
#define GEMMLOWP_INTERNAL_FIXEDPOINT_SSE_H_

#include <smmintrin.h>
#include "fixedpoint.h"

namespace gemmlowp {

// SSE intrinsics are not finely typed: there is a single __m128i vector
// type that does not distinguish between "int32x4" and "int16x8" use
// cases, unlike the NEON equivalents. Because we initially focused on
// int32x4, these fixedpoint templates were specialized directly for
// __m128i, hardcoding the int32x4 semantics and leaving no room for
// int16x8 semantics. We amend that by adding a separate data type,
// int16x8_m128i, that wraps a __m128i while being a distinct C++ type.
struct int16x8_m128i {
  int16x8_m128i() {}
  explicit int16x8_m128i(__m128i w) : v(w) {}
  ~int16x8_m128i() {}

  __m128i v;
};

template <>
struct FixedPointRawTypeTraits<__m128i> {
  typedef std::int32_t ScalarRawType;
  static const int kLanes = 4;
};

template <>
struct FixedPointRawTypeTraits<int16x8_m128i> {
  typedef std::int16_t ScalarRawType;
  static const int kLanes = 8;
};
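// Illustration (not part of the original header): the wrapper type is what
// lets overload resolution pick int16x8 semantics for the very same 128-bit
// register, purely through the C++ type used to refer to it, e.g.
//
//   __m128i x = _mm_set1_epi32(1);
//   Add(x, x);                                // int32x4 semantics: _mm_add_epi32
//   Add(int16x8_m128i(x), int16x8_m128i(x));  // int16x8 semantics: _mm_add_epi16
//
// (These calls assume the Add specializations defined below.)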
template <>
inline __m128i BitAnd(__m128i a, __m128i b) {
  return _mm_and_si128(a, b);
}

template <>
inline int16x8_m128i BitAnd(int16x8_m128i a, int16x8_m128i b) {
  return int16x8_m128i(_mm_and_si128(a.v, b.v));
}

template <>
inline __m128i BitOr(__m128i a, __m128i b) {
  return _mm_or_si128(a, b);
}

template <>
inline int16x8_m128i BitOr(int16x8_m128i a, int16x8_m128i b) {
  return int16x8_m128i(_mm_or_si128(a.v, b.v));
}

template <>
inline __m128i BitXor(__m128i a, __m128i b) {
  return _mm_xor_si128(a, b);
}

template <>
inline int16x8_m128i BitXor(int16x8_m128i a, int16x8_m128i b) {
  return int16x8_m128i(_mm_xor_si128(a.v, b.v));
}

template <>
inline __m128i BitNot(__m128i a) {
  return _mm_andnot_si128(a, _mm_set1_epi32(-1));
}

template <>
inline int16x8_m128i BitNot(int16x8_m128i a) {
  return int16x8_m128i(_mm_andnot_si128(a.v, _mm_set1_epi16(-1)));
}

template <>
inline __m128i Add(__m128i a, __m128i b) {
  return _mm_add_epi32(a, b);
}

template <>
inline int16x8_m128i Add(int16x8_m128i a, int16x8_m128i b) {
  return int16x8_m128i(_mm_add_epi16(a.v, b.v));
}

template <>
inline __m128i Mul(__m128i a, __m128i b) {
  return _mm_mullo_epi32(a, b);
}

template <>
inline int16x8_m128i Mul(int16x8_m128i a, int16x8_m128i b) {
  return int16x8_m128i(_mm_mullo_epi16(a.v, b.v));
}

template <>
inline __m128i Sub(__m128i a, __m128i b) {
  return _mm_sub_epi32(a, b);
}

template <>
inline int16x8_m128i Sub(int16x8_m128i a, int16x8_m128i b) {
  return int16x8_m128i(_mm_sub_epi16(a.v, b.v));
}

template <>
inline __m128i Neg(__m128i a) {
  return _mm_sign_epi32(a, _mm_set1_epi32(-1));
}

template <>
inline int16x8_m128i Neg(int16x8_m128i a) {
  return int16x8_m128i(_mm_sign_epi16(a.v, _mm_set1_epi16(-1)));
}

template <>
inline __m128i ShiftLeft(__m128i a, int offset) {
  return _mm_slli_epi32(a, offset);
}

template <>
inline int16x8_m128i ShiftLeft(int16x8_m128i a, int offset) {
  return int16x8_m128i(_mm_slli_epi16(a.v, offset));
}

template <>
inline __m128i ShiftRight(__m128i a, int offset) {
  return _mm_srai_epi32(a, offset);
}

template <>
inline int16x8_m128i ShiftRight(int16x8_m128i a, int offset) {
  return int16x8_m128i(_mm_srai_epi16(a.v, offset));
}

template <>
inline __m128i SelectUsingMask(__m128i if_mask, __m128i then_val,
                               __m128i else_val) {
  // Borrowed from Intel's arm_neon_sse.h header.
  return _mm_or_si128(_mm_and_si128(if_mask, then_val),
                      _mm_andnot_si128(if_mask, else_val));
}

template <>
inline int16x8_m128i SelectUsingMask(int16x8_m128i if_mask,
                                     int16x8_m128i then_val,
                                     int16x8_m128i else_val) {
  // Borrowed from Intel's arm_neon_sse.h header.
  return int16x8_m128i(SelectUsingMask(if_mask.v, then_val.v, else_val.v));
}

template <>
inline __m128i MaskIfEqual(__m128i a, __m128i b) {
  return _mm_cmpeq_epi32(a, b);
}

template <>
inline int16x8_m128i MaskIfEqual(int16x8_m128i a, int16x8_m128i b) {
  return int16x8_m128i(_mm_cmpeq_epi16(a.v, b.v));
}

template <>
inline __m128i MaskIfNotEqual(__m128i a, __m128i b) {
  return BitNot(MaskIfEqual(a, b));
}

template <>
inline int16x8_m128i MaskIfNotEqual(int16x8_m128i a, int16x8_m128i b) {
  return BitNot(MaskIfEqual(a, b));
}

template <>
inline __m128i MaskIfZero(__m128i a) {
  return MaskIfEqual(a, _mm_set1_epi32(0));
}

template <>
inline int16x8_m128i MaskIfZero(int16x8_m128i a) {
  return MaskIfEqual(a, int16x8_m128i(_mm_set1_epi16(0)));
}

template <>
inline __m128i MaskIfNonZero(__m128i a) {
  return MaskIfNotEqual(a, _mm_set1_epi32(0));
}

template <>
inline int16x8_m128i MaskIfNonZero(int16x8_m128i a) {
  return MaskIfNotEqual(a, int16x8_m128i(_mm_set1_epi16(0)));
}

template <>
inline __m128i MaskIfGreaterThan(__m128i a, __m128i b) {
  return _mm_cmpgt_epi32(a, b);
}

template <>
inline int16x8_m128i MaskIfGreaterThan(int16x8_m128i a, int16x8_m128i b) {
  return int16x8_m128i(_mm_cmpgt_epi16(a.v, b.v));
}

template <>
inline __m128i MaskIfLessThan(__m128i a, __m128i b) {
  return _mm_cmplt_epi32(a, b);
}

template <>
inline int16x8_m128i MaskIfLessThan(int16x8_m128i a, int16x8_m128i b) {
  return int16x8_m128i(_mm_cmplt_epi16(a.v, b.v));
}

template <>
inline __m128i MaskIfGreaterThanOrEqual(__m128i a, __m128i b) {
  return BitNot(MaskIfLessThan(a, b));
}

template <>
inline int16x8_m128i MaskIfGreaterThanOrEqual(int16x8_m128i a,
                                              int16x8_m128i b) {
  return BitNot(MaskIfLessThan(a, b));
}

template <>
inline __m128i MaskIfLessThanOrEqual(__m128i a, __m128i b) {
  return BitNot(MaskIfGreaterThan(a, b));
}

template <>
inline int16x8_m128i MaskIfLessThanOrEqual(int16x8_m128i a, int16x8_m128i b) {
  return BitNot(MaskIfGreaterThan(a, b));
}

/* Assumptions:
   - All and Any are only used on masks.
   - Masks are all-ones for true lanes, all-zeroes otherwise.
   Hence, All means all 128 bits set, and Any means any bit set.
*/

template <>
inline bool All(__m128i a) {
  return _mm_testc_si128(a, a);
}

template <>
inline bool All(int16x8_m128i a) {
  return _mm_testc_si128(a.v, a.v);
}

template <>
inline bool Any(__m128i a) {
  return !_mm_testz_si128(a, a);
}

template <>
inline bool Any(int16x8_m128i a) {
  return !_mm_testz_si128(a.v, a.v);
}

template <>
inline __m128i RoundingHalfSum(__m128i a, __m128i b) {
  /* Alternative: divide the inputs before the add to avoid the overflow and
   * the costly test of checking whether an overflow occurred on signed add:
   *   __m128i round_bit_mask, a_over_2, b_over_2, round_bit, sum;
   *   round_bit_mask = _mm_set1_epi32(1);
   *   a_over_2 = _mm_srai_epi32(a, 1);
   *   b_over_2 = _mm_srai_epi32(b, 1);
   *   sum = Add(a_over_2, b_over_2);
   *   round_bit = _mm_sign_epi32(BitAnd(BitOr(a, b), round_bit_mask), sum);
   *   return Add(sum, round_bit);
   */

  /* Here we instead detect overflow and xor the sign bit if an overflow
   * happened. */
  __m128i one, sign_bit_mask, sum, rounded_half_sum, overflow, result;
  one = _mm_set1_epi32(1);
  sign_bit_mask = _mm_set1_epi32(0x80000000);
  sum = Add(a, b);
  rounded_half_sum = _mm_srai_epi32(Add(sum, one), 1);
  overflow =
      BitAnd(BitAnd(BitXor(a, rounded_half_sum), BitXor(b, rounded_half_sum)),
             sign_bit_mask);
  result = BitXor(rounded_half_sum, overflow);
  return result;
}
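// Illustrative only, not part of the original header: a hypothetical scalar
// helper documenting the per-lane value that the SSE RoundingHalfSum above
// computes. It widens to 64 bits instead of adding in 32 bits and then
// patching the sign bit of lanes whose 32-bit sum overflowed.
inline std::int32_t RoundingHalfSumScalarSketch(std::int32_t a,
                                                std::int32_t b) {
  const std::int64_t sum = static_cast<std::int64_t>(a) + b;
  // (sum + 1) >> 1 with an arithmetic shift, matching _mm_srai_epi32 above.
  return static_cast<std::int32_t>((sum + 1) >> 1);
}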
template <>
inline int16x8_m128i RoundingHalfSum(int16x8_m128i a, int16x8_m128i b) {
  // Idea: go to unsigned to use _mm_avg_epu16,
  // borrowed from Intel's arm_neon_sse.h header.
  __m128i constant_neg_32768 = _mm_set1_epi16(-32768);
  __m128i a_unsigned = _mm_sub_epi16(a.v, constant_neg_32768);
  __m128i b_unsigned = _mm_sub_epi16(b.v, constant_neg_32768);
  __m128i avg_unsigned = _mm_avg_epu16(a_unsigned, b_unsigned);
  __m128i avg = _mm_add_epi16(avg_unsigned, constant_neg_32768);
  return int16x8_m128i(avg);
}

template <>
inline __m128i SaturatingRoundingDoublingHighMul(__m128i a, __m128i b) {
  __m128i min, saturation_mask, a0_a2, a1_a3, b0_b2, b1_b3;
  __m128i a0b0_a2b2, a1b1_a3b3, a0b0_a2b2_rounded, a1b1_a3b3_rounded;
  __m128i a0b0_a2b2_rounded_2x, a1b1_a3b3_rounded_2x, result;
  __m128i nudge;

  // Saturation only happens if a == b == INT_MIN.
  min = _mm_set1_epi32(std::numeric_limits<std::int32_t>::min());
  saturation_mask = BitAnd(MaskIfEqual(a, b), MaskIfEqual(a, min));

  // a = a0 | a1 | a2 | a3
  // b = b0 | b1 | b2 | b3
  a0_a2 = a;
  a1_a3 = _mm_srli_si128(a, 4);
  b0_b2 = b;
  b1_b3 = _mm_srli_si128(b, 4);

  a0b0_a2b2 = _mm_mul_epi32(a0_a2, b0_b2);
  a1b1_a3b3 = _mm_mul_epi32(a1_a3, b1_b3);

  // Do the rounding, taking into account that the result will be doubled.
  nudge = _mm_set1_epi64x(1 << 30);
  a0b0_a2b2_rounded = _mm_add_epi64(a0b0_a2b2, nudge);
  a1b1_a3b3_rounded = _mm_add_epi64(a1b1_a3b3, nudge);

  // Do the doubling.
  a0b0_a2b2_rounded_2x = _mm_slli_epi64(a0b0_a2b2_rounded, 1);
  a1b1_a3b3_rounded_2x = _mm_slli_epi64(a1b1_a3b3_rounded, 1);

  // Get the high part of the products.
  result = _mm_blend_epi16(_mm_srli_si128(a0b0_a2b2_rounded_2x, 4),
                           a1b1_a3b3_rounded_2x, 0xcc);

  // Saturate those lanes that overflowed.
  return SelectUsingMask(saturation_mask, min, result);
}
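// Illustrative only, not part of the original header: a hypothetical scalar
// helper documenting the per-lane computation of the SSE
// SaturatingRoundingDoublingHighMul above for the non-saturating case,
// i.e. excluding a == b == INT32_MIN, which the vector code handles
// separately via saturation_mask.
inline std::int32_t SaturatingRoundingDoublingHighMulScalarSketch(
    std::int32_t a, std::int32_t b) {
  const std::int64_t ab = static_cast<std::int64_t>(a) * b;
  // Add the rounding nudge, then keep bits 31..62 of the product: this is
  // the high 32-bit half of the doubled, rounded product, as extracted by
  // the shift/blend sequence above.
  return static_cast<std::int32_t>((ab + (1 << 30)) >> 31);
}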
template <>
inline int16x8_m128i SaturatingRoundingDoublingHighMul(int16x8_m128i a,
                                                       int16x8_m128i b) {
  // Idea: use _mm_mulhrs_epi16 then saturate with a bit-operation,
  // borrowed from Intel's arm_neon_sse.h header.
  __m128i result_unsaturated = _mm_mulhrs_epi16(a.v, b.v);
  __m128i saturation_mask =
      _mm_cmpeq_epi16(result_unsaturated, _mm_set1_epi16(0x8000));
  __m128i result = _mm_xor_si128(result_unsaturated, saturation_mask);
  return int16x8_m128i(result);
}

template <>
inline __m128i Dup<__m128i>(std::int32_t x) {
  return _mm_set1_epi32(x);
}

template <>
inline int16x8_m128i Dup<int16x8_m128i>(std::int16_t x) {
  return int16x8_m128i(_mm_set1_epi16(x));
}

// So far this is only needed for int16.
template <>
inline int16x8_m128i SaturatingAdd(int16x8_m128i a, int16x8_m128i b) {
  return int16x8_m128i(_mm_adds_epi16(a.v, b.v));
}

}  // end namespace gemmlowp

#endif  // GEMMLOWP_INTERNAL_FIXEDPOINT_SSE_H_