1 /****************************************************************************
2 * Copyright (C) 2017 Intel Corporation.   All Rights Reserved.
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 ****************************************************************************/
23 #pragma once
24 
25 #if !defined(__cplusplus)
26 #error C++ compilation required
27 #endif
28 
29 #include <immintrin.h>
30 #include <inttypes.h>
31 #include <stdint.h>
32 
// Instruction-set levels that SIMD_ARCH may be set to.  The values increase
// with capability so that levels can be compared numerically in #if tests.
#define SIMD_ARCH_AVX       0
#define SIMD_ARCH_AVX2      1
#define SIMD_ARCH_AVX512    2

// Default to plain AVX when the build did not select a level.
#ifndef SIMD_ARCH
#define SIMD_ARCH SIMD_ARCH_AVX
#endif
40 
// Compiler-specific spellings used throughout the SIMD wrappers:
//   SIMDCALL                - calling-convention annotation (empty off MSVC)
//   SIMDINLINE              - inlining request
//   SIMDALIGN(type_, align_) - declares type_ with align_-byte alignment;
//                              the declared name follows the macro.
#if !defined(_MSC_VER)
#define SIMDCALL
#define SIMDINLINE inline
#define SIMDALIGN(type_, align_) type_ __attribute__((aligned(align_)))
#else
#define SIMDCALL __vectorcall
#define SIMDINLINE __forceinline
#define SIMDALIGN(type_, align_) __declspec(align(align_)) type_
#endif
50 
51 // For documentation, please see the following include...
52 // #include "simdlib_interface.hpp"
53 
54 namespace SIMDImpl
55 {
56     enum class CompareType
57     {
58         EQ_OQ      = 0x00, // Equal (ordered, nonsignaling)
59         LT_OS      = 0x01, // Less-than (ordered, signaling)
60         LE_OS      = 0x02, // Less-than-or-equal (ordered, signaling)
61         UNORD_Q    = 0x03, // Unordered (nonsignaling)
62         NEQ_UQ     = 0x04, // Not-equal (unordered, nonsignaling)
63         NLT_US     = 0x05, // Not-less-than (unordered, signaling)
64         NLE_US     = 0x06, // Not-less-than-or-equal (unordered, signaling)
65         ORD_Q      = 0x07, // Ordered (nonsignaling)
66         EQ_UQ      = 0x08, // Equal (unordered, non-signaling)
67         NGE_US     = 0x09, // Not-greater-than-or-equal (unordered, signaling)
68         NGT_US     = 0x0A, // Not-greater-than (unordered, signaling)
69         FALSE_OQ   = 0x0B, // False (ordered, nonsignaling)
70         NEQ_OQ     = 0x0C, // Not-equal (ordered, non-signaling)
71         GE_OS      = 0x0D, // Greater-than-or-equal (ordered, signaling)
72         GT_OS      = 0x0E, // Greater-than (ordered, signaling)
73         TRUE_UQ    = 0x0F, // True (unordered, non-signaling)
74         EQ_OS      = 0x10, // Equal (ordered, signaling)
75         LT_OQ      = 0x11, // Less-than (ordered, nonsignaling)
76         LE_OQ      = 0x12, // Less-than-or-equal (ordered, nonsignaling)
77         UNORD_S    = 0x13, // Unordered (signaling)
78         NEQ_US     = 0x14, // Not-equal (unordered, signaling)
79         NLT_UQ     = 0x15, // Not-less-than (unordered, nonsignaling)
80         NLE_UQ     = 0x16, // Not-less-than-or-equal (unordered, nonsignaling)
81         ORD_S      = 0x17, // Ordered (signaling)
82         EQ_US      = 0x18, // Equal (unordered, signaling)
83         NGE_UQ     = 0x19, // Not-greater-than-or-equal (unordered, nonsignaling)
84         NGT_UQ     = 0x1A, // Not-greater-than (unordered, nonsignaling)
85         FALSE_OS   = 0x1B, // False (ordered, signaling)
86         NEQ_OS     = 0x1C, // Not-equal (ordered, signaling)
87         GE_OQ      = 0x1D, // Greater-than-or-equal (ordered, nonsignaling)
88         GT_OQ      = 0x1E, // Greater-than (ordered, nonsignaling)
89         TRUE_US    = 0x1F, // True (unordered, signaling)
90     };
91 
92 #if SIMD_ARCH >= SIMD_ARCH_AVX512
93     enum class CompareTypeInt
94     {
95         EQ  = _MM_CMPINT_EQ,    // Equal
96         LT  = _MM_CMPINT_LT,    // Less than
97         LE  = _MM_CMPINT_LE,    // Less than or Equal
98         NE  = _MM_CMPINT_NE,    // Not Equal
99         GE  = _MM_CMPINT_GE,    // Greater than or Equal
100         GT  = _MM_CMPINT_GT,    // Greater than
101     };
102 #endif // SIMD_ARCH >= SIMD_ARCH_AVX512
103 
104     enum class ScaleFactor
105     {
106         SF_1 = 1,   // No scaling
107         SF_2 = 2,   // Scale offset by 2
108         SF_4 = 4,   // Scale offset by 4
109         SF_8 = 8,   // Scale offset by 8
110     };
111 
112     enum class RoundMode
113     {
114         TO_NEAREST_INT  = 0x00, // Round to nearest integer == TRUNCATE(value + 0.5)
115         TO_NEG_INF      = 0x01, // Round to negative infinity
116         TO_POS_INF      = 0x02, // Round to positive infinity
117         TO_ZERO         = 0x03, // Round to 0 a.k.a. truncate
118         CUR_DIRECTION   = 0x04, // Round in direction set in MXCSR register
119 
120         RAISE_EXC       = 0x00, // Raise exception on overflow
121         NO_EXC          = 0x08, // Suppress exceptions
122 
123         NINT            = static_cast<int>(TO_NEAREST_INT)  | static_cast<int>(RAISE_EXC),
124         NINT_NOEXC      = static_cast<int>(TO_NEAREST_INT)  | static_cast<int>(NO_EXC),
125         FLOOR           = static_cast<int>(TO_NEG_INF)      | static_cast<int>(RAISE_EXC),
126         FLOOR_NOEXC     = static_cast<int>(TO_NEG_INF)      | static_cast<int>(NO_EXC),
127         CEIL            = static_cast<int>(TO_POS_INF)      | static_cast<int>(RAISE_EXC),
128         CEIL_NOEXC      = static_cast<int>(TO_POS_INF)      | static_cast<int>(NO_EXC),
129         TRUNC           = static_cast<int>(TO_ZERO)         | static_cast<int>(RAISE_EXC),
130         TRUNC_NOEXC     = static_cast<int>(TO_ZERO)         | static_cast<int>(NO_EXC),
131         RINT            = static_cast<int>(CUR_DIRECTION)   | static_cast<int>(RAISE_EXC),
132         NEARBYINT       = static_cast<int>(CUR_DIRECTION)   | static_cast<int>(NO_EXC),
133     };
134 
135     struct Traits
136     {
137         using CompareType = SIMDImpl::CompareType;
138         using ScaleFactor = SIMDImpl::ScaleFactor;
139         using RoundMode   = SIMDImpl::RoundMode;
140     };
141 
142     // Attribute, 4-dimensional attribute in SIMD SOA layout
143     template<typename Float, typename Integer, typename Double>
144     union Vec4
145     {
146         Float   v[4];
147         Integer vi[4];
148         Double  vd[4];
149         struct
150         {
151             Float  x;
152             Float  y;
153             Float  z;
154             Float  w;
155         };
operator [](const int i)156         SIMDINLINE Float& SIMDCALL operator[] (const int i) { return v[i]; }
operator [](const int i) const157         SIMDINLINE Float const & SIMDCALL operator[] (const int i) const { return v[i]; }
operator =(Vec4 const & in)158         SIMDINLINE Vec4& SIMDCALL operator=(Vec4 const & in)
159         {
160             v[0] = in.v[0];
161             v[1] = in.v[1];
162             v[2] = in.v[2];
163             v[3] = in.v[3];
164             return *this;
165         }
166     };
167 
168     namespace SIMD128Impl
169     {
170         union Float
171         {
172             SIMDINLINE Float() = default;
Float(__m128 in)173             SIMDINLINE Float(__m128 in) : v(in) {}
operator =(__m128 in)174             SIMDINLINE Float& SIMDCALL operator=(__m128 in) { v = in; return *this; }
operator =(Float const & in)175             SIMDINLINE Float& SIMDCALL operator=(Float const & in) { v = in.v; return *this; }
operator __m128() const176             SIMDINLINE SIMDCALL operator __m128() const { return v; }
177 
178             SIMDALIGN(__m128, 16) v;
179         };
180 
181         union Integer
182         {
183             SIMDINLINE Integer() = default;
Integer(__m128i in)184             SIMDINLINE Integer(__m128i in) : v(in) {}
operator =(__m128i in)185             SIMDINLINE Integer& SIMDCALL operator=(__m128i in) { v = in; return *this; }
operator =(Integer const & in)186             SIMDINLINE Integer& SIMDCALL operator=(Integer const & in) { v = in.v; return *this; }
operator __m128i() const187             SIMDINLINE SIMDCALL operator __m128i() const { return v; }
188 
189             SIMDALIGN(__m128i, 16) v;
190         };
191 
192         union Double
193         {
194             SIMDINLINE Double() = default;
Double(__m128d in)195             SIMDINLINE Double(__m128d in) : v(in) {}
operator =(__m128d in)196             SIMDINLINE Double& SIMDCALL operator=(__m128d in) { v = in; return *this; }
operator =(Double const & in)197             SIMDINLINE Double& SIMDCALL operator=(Double const & in) { v = in.v; return *this; }
operator __m128d() const198             SIMDINLINE SIMDCALL operator __m128d() const { return v; }
199 
200             SIMDALIGN(__m128d, 16) v;
201         };
202 
203         using Vec4 = SIMDImpl::Vec4<Float, Integer, Double>;
204         using Mask = uint8_t;
205 
206         static const uint32_t SIMD_WIDTH = 4;
207     } // ns SIMD128Impl
208 
209     namespace SIMD256Impl
210     {
211         union Float
212         {
213             SIMDINLINE Float() = default;
Float(__m256 in)214             SIMDINLINE Float(__m256 in) : v(in) {}
Float(SIMD128Impl::Float const & in_lo,SIMD128Impl::Float const & in_hi=_mm_setzero_ps ())215             SIMDINLINE Float(SIMD128Impl::Float const &in_lo, SIMD128Impl::Float const &in_hi = _mm_setzero_ps())
216             {
217                 v = _mm256_insertf128_ps(_mm256_castps128_ps256(in_lo), in_hi, 0x1);
218             }
operator =(__m256 in)219             SIMDINLINE Float& SIMDCALL operator=(__m256 in) { v = in; return *this; }
operator =(Float const & in)220             SIMDINLINE Float& SIMDCALL operator=(Float const & in) { v = in.v; return *this; }
operator __m256() const221             SIMDINLINE SIMDCALL operator __m256() const { return v; }
222 
223             SIMDALIGN(__m256, 32) v;
224             SIMD128Impl::Float v4[2];
225         };
226 
227         union Integer
228         {
229             SIMDINLINE Integer() = default;
Integer(__m256i in)230             SIMDINLINE Integer(__m256i in) : v(in) {}
Integer(SIMD128Impl::Integer const & in_lo,SIMD128Impl::Integer const & in_hi=_mm_setzero_si128 ())231             SIMDINLINE Integer(SIMD128Impl::Integer const &in_lo, SIMD128Impl::Integer const &in_hi = _mm_setzero_si128())
232             {
233                 v = _mm256_insertf128_si256(_mm256_castsi128_si256(in_lo), in_hi, 0x1);
234             }
operator =(__m256i in)235             SIMDINLINE Integer& SIMDCALL operator=(__m256i in) { v = in; return *this; }
operator =(Integer const & in)236             SIMDINLINE Integer& SIMDCALL operator=(Integer const & in) { v = in.v; return *this; }
operator __m256i() const237             SIMDINLINE SIMDCALL operator __m256i() const { return v; }
238 
239             SIMDALIGN(__m256i, 32) v;
240             SIMD128Impl::Integer v4[2];
241         };
242 
243         union Double
244         {
245             SIMDINLINE Double() = default;
Double(__m256d const & in)246             SIMDINLINE Double(__m256d const &in) : v(in) {}
Double(SIMD128Impl::Double const & in_lo,SIMD128Impl::Double const & in_hi=_mm_setzero_pd ())247             SIMDINLINE Double(SIMD128Impl::Double const &in_lo, SIMD128Impl::Double const &in_hi = _mm_setzero_pd())
248             {
249                 v = _mm256_insertf128_pd(_mm256_castpd128_pd256(in_lo), in_hi, 0x1);
250             }
operator =(__m256d in)251             SIMDINLINE Double& SIMDCALL operator=(__m256d in) { v = in; return *this; }
operator =(Double const & in)252             SIMDINLINE Double& SIMDCALL operator=(Double const & in) { v = in.v; return *this; }
operator __m256d() const253             SIMDINLINE SIMDCALL operator __m256d() const { return v; }
254 
255             SIMDALIGN(__m256d, 32) v;
256             SIMD128Impl::Double v4[2];
257         };
258 
259         using Vec4 = SIMDImpl::Vec4<Float, Integer, Double>;
260         using Mask = uint8_t;
261 
262         static const uint32_t SIMD_WIDTH = 8;
263     } // ns SIMD256Impl
264 
265     namespace SIMD512Impl
266     {
#if !defined(__AVX512F__) && !defined(_MM_K0_REG)
        // Stand-in AVX512 types for builds where immintrin.h did not supply them.
        // The data members exist ONLY to be viewed in a debugger.
        // Do NOT access them via code!

        // 16 x float
        union __m512
        {
        private:
            float m512_f32[16];
        };

        // 8 x double
        struct __m512d
        {
        private:
            double m512d_f64[8];
        };

        // 512 bits of integer data, viewable at each element width
        union __m512i
        {
        private:
            int8_t              m512i_i8[64];
            int16_t             m512i_i16[32];
            int32_t             m512i_i32[16];
            int64_t             m512i_i64[8];
            uint8_t             m512i_u8[64];
            uint16_t            m512i_u16[32];
            uint32_t            m512i_u32[16];
            uint64_t            m512i_u64[8];
        };

        // 16-bit mask type
        using __mmask16 = uint16_t;
#endif
297 
// Alignment applied to the 512-bit wrapper members: 64 bytes with the Intel
// compiler or when building for AVX512, 32 bytes otherwise.
#if !defined(__INTEL_COMPILER) && (SIMD_ARCH < SIMD_ARCH_AVX512)
#define SIMD_ALIGNMENT_BYTES 32
#else
#define SIMD_ALIGNMENT_BYTES 64
#endif
303 
304         union Float
305         {
306             SIMDINLINE Float() = default;
Float(__m512 in)307             SIMDINLINE Float(__m512 in) : v(in) {}
Float(SIMD256Impl::Float const & in_lo,SIMD256Impl::Float const & in_hi=_mm256_setzero_ps ())308             SIMDINLINE Float(SIMD256Impl::Float const &in_lo, SIMD256Impl::Float const &in_hi = _mm256_setzero_ps()) { v8[0] = in_lo; v8[1] = in_hi; }
operator =(__m512 in)309             SIMDINLINE Float& SIMDCALL operator=(__m512 in) { v = in; return *this; }
operator =(Float const & in)310             SIMDINLINE Float& SIMDCALL operator=(Float const & in)
311             {
312 #if SIMD_ARCH >= SIMD_ARCH_AVX512
313                 v = in.v;
314 #else
315                 v8[0] = in.v8[0];
316                 v8[1] = in.v8[1];
317 #endif
318                 return *this;
319             }
operator __m512() const320             SIMDINLINE SIMDCALL operator __m512() const { return v; }
321 
322             SIMDALIGN(__m512, SIMD_ALIGNMENT_BYTES) v;
323             SIMD256Impl::Float v8[2];
324         };
325 
326         union Integer
327         {
328             SIMDINLINE Integer() = default;
Integer(__m512i in)329             SIMDINLINE Integer(__m512i in) : v(in) {}
Integer(SIMD256Impl::Integer const & in_lo,SIMD256Impl::Integer const & in_hi=_mm256_setzero_si256 ())330             SIMDINLINE Integer(SIMD256Impl::Integer const &in_lo, SIMD256Impl::Integer const &in_hi = _mm256_setzero_si256()) { v8[0] = in_lo; v8[1] = in_hi; }
operator =(__m512i in)331             SIMDINLINE Integer& SIMDCALL operator=(__m512i in) { v = in; return *this; }
operator =(Integer const & in)332             SIMDINLINE Integer& SIMDCALL operator=(Integer const & in)
333             {
334 #if SIMD_ARCH >= SIMD_ARCH_AVX512
335                 v = in.v;
336 #else
337                 v8[0] = in.v8[0];
338                 v8[1] = in.v8[1];
339 #endif
340                 return *this;
341             }
342 
operator __m512i() const343             SIMDINLINE SIMDCALL operator __m512i() const { return v; }
344 
345             SIMDALIGN(__m512i, SIMD_ALIGNMENT_BYTES) v;
346             SIMD256Impl::Integer v8[2];
347         };
348 
349         union Double
350         {
351             SIMDINLINE Double() = default;
Double(__m512d in)352             SIMDINLINE Double(__m512d in) : v(in) {}
Double(SIMD256Impl::Double const & in_lo,SIMD256Impl::Double const & in_hi=_mm256_setzero_pd ())353             SIMDINLINE Double(SIMD256Impl::Double const &in_lo, SIMD256Impl::Double const &in_hi = _mm256_setzero_pd()) { v8[0] = in_lo; v8[1] = in_hi; }
operator =(__m512d in)354             SIMDINLINE Double& SIMDCALL operator=(__m512d in) { v = in; return *this; }
operator =(Double const & in)355             SIMDINLINE Double& SIMDCALL operator=(Double const & in)
356             {
357 #if SIMD_ARCH >= SIMD_ARCH_AVX512
358                 v = in.v;
359 #else
360                 v8[0] = in.v8[0];
361                 v8[1] = in.v8[1];
362 #endif
363                 return *this;
364             }
365 
operator __m512d() const366             SIMDINLINE SIMDCALL operator __m512d() const { return v; }
367 
368             SIMDALIGN(__m512d, SIMD_ALIGNMENT_BYTES) v;
369             SIMD256Impl::Double v8[2];
370         };
371 
372         typedef SIMDImpl::Vec4<Float, Integer, Double> SIMDALIGN(Vec4, 64);
373         using Mask = __mmask16;
374 
375         static const uint32_t SIMD_WIDTH = 16;
376 
377 #undef SIMD_ALIGNMENT_BYTES
378     } // ns SIMD512Impl
379 } // ns SIMDImpl
380