1/****************************************************************************
2 * Copyright (C) 2017 Intel Corporation.   All Rights Reserved.
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 ****************************************************************************/
23#if !defined(__SIMD_LIB_AVX512_HPP__)
24#error Do not include this file directly, use "simdlib.hpp" instead.
25#endif
26
27//============================================================================
28// SIMD128 AVX (512) implementation
29//
30// Since this implementation inherits from the AVX (2) implementation,
// the only operations below are the ones that replace AVX (2) operations.
32// These use native AVX512 instructions with masking to enable a larger
33// register set.
34//============================================================================
35
36private:
37static SIMDINLINE __m512 __conv(Float r)
38{
39    return _mm512_castps128_ps512(r.v);
40}
41static SIMDINLINE __m512d __conv(Double r)
42{
43    return _mm512_castpd128_pd512(r.v);
44}
45static SIMDINLINE __m512i __conv(Integer r)
46{
47    return _mm512_castsi128_si512(r.v);
48}
49static SIMDINLINE Float __conv(__m512 r)
50{
51    return _mm512_castps512_ps128(r);
52}
53static SIMDINLINE Double __conv(__m512d r)
54{
55    return _mm512_castpd512_pd128(r);
56}
57static SIMDINLINE Integer __conv(__m512i r)
58{
59    return _mm512_castsi512_si128(r);
60}
61
62public:
63#define SIMD_WRAPPER_1_(op, intrin, mask)                        \
64    static SIMDINLINE Float SIMDCALL op(Float a)                 \
65    {                                                            \
66        return __conv(_mm512_maskz_##intrin((mask), __conv(a))); \
67    }
68#define SIMD_WRAPPER_1(op) SIMD_WRAPPER_1_(op, op, __mmask16(0xf))
69
70#define SIMD_WRAPPER_1I_(op, intrin, mask)                             \
71    template <int ImmT>                                                \
72    static SIMDINLINE Float SIMDCALL op(Float a)                       \
73    {                                                                  \
74        return __conv(_mm512_maskz_##intrin((mask), __conv(a), ImmT)); \
75    }
76#define SIMD_WRAPPER_1I(op) SIMD_WRAPPER_1I_(op, op, __mmask16(0xf))
77
78#define SIMD_WRAPPER_2_(op, intrin, mask)                                   \
79    static SIMDINLINE Float SIMDCALL op(Float a, Float b)                   \
80    {                                                                       \
81        return __conv(_mm512_maskz_##intrin((mask), __conv(a), __conv(b))); \
82    }
83#define SIMD_WRAPPER_2(op) SIMD_WRAPPER_2_(op, op, __mmask16(0xf))
84
85#define SIMD_WRAPPER_2I(op)                                                \
86    template <int ImmT>                                                    \
87    static SIMDINLINE Float SIMDCALL op(Float a, Float b)                  \
88    {                                                                      \
89        return __conv(_mm512_maskz_##op(0xf, __conv(a), __conv(b), ImmT)); \
90    }
91
92#define SIMD_WRAPPER_3_(op, intrin, mask)                                              \
93    static SIMDINLINE Float SIMDCALL op(Float a, Float b, Float c)                     \
94    {                                                                                  \
95        return __conv(_mm512_maskz_##intrin((mask), __conv(a), __conv(b), __conv(c))); \
96    }
97#define SIMD_WRAPPER_3(op) SIMD_WRAPPER_3_(op, op, __mmask16(0xf))
98
99#define SIMD_DWRAPPER_2I(op)                                               \
100    template <int ImmT>                                                    \
101    static SIMDINLINE Double SIMDCALL op(Double a, Double b)               \
102    {                                                                      \
103        return __conv(_mm512_maskz_##op(0x3, __conv(a), __conv(b), ImmT)); \
104    }
105
106#define SIMD_IWRAPPER_1_(op, intrin, mask)                       \
107    static SIMDINLINE Integer SIMDCALL op(Integer a)             \
108    {                                                            \
109        return __conv(_mm512_maskz_##intrin((mask), __conv(a))); \
110    }
111#define SIMD_IWRAPPER_1_32(op) SIMD_IWRAPPER_1_(op, op, __mmask16(0xf))
112
113#define SIMD_IWRAPPER_1I_(op, intrin, mask)                            \
114    template <int ImmT>                                                \
115    static SIMDINLINE Integer SIMDCALL op(Integer a)                   \
116    {                                                                  \
117        return __conv(_mm512_maskz_##intrin((mask), __conv(a), ImmT)); \
118    }
119#define SIMD_IWRAPPER_1I_32(op) SIMD_IWRAPPER_1I_(op, op, __mmask16(0xf))
120
121#define SIMD_IWRAPPER_2_(op, intrin, mask)                                  \
122    static SIMDINLINE Integer SIMDCALL op(Integer a, Integer b)             \
123    {                                                                       \
124        return __conv(_mm512_maskz_##intrin((mask), __conv(a), __conv(b))); \
125    }
126#define SIMD_IWRAPPER_2_32(op) SIMD_IWRAPPER_2_(op, op, __mmask16(0xf))
127
128#define SIMD_IWRAPPER_2I(op)                                               \
129    template <int ImmT>                                                    \
130    static SIMDINLINE Integer SIMDCALL op(Integer a, Integer b)            \
131    {                                                                      \
132        return __conv(_mm512_maskz_##op(0xf, __conv(a), __conv(b), ImmT)); \
133    }
134
135//-----------------------------------------------------------------------
136// Single precision floating point arithmetic operations
137//-----------------------------------------------------------------------
138SIMD_WRAPPER_2(add_ps);                                // return a + b
139SIMD_WRAPPER_2(div_ps);                                // return a / b
140SIMD_WRAPPER_3(fmadd_ps);                              // return (a * b) + c
141SIMD_WRAPPER_3(fmsub_ps);                              // return (a * b) - c
142SIMD_WRAPPER_2(max_ps);                                // return (a > b) ? a : b
143SIMD_WRAPPER_2(min_ps);                                // return (a < b) ? a : b
144SIMD_WRAPPER_2(mul_ps);                                // return a * b
145SIMD_WRAPPER_1_(rcp_ps, rcp14_ps, __mmask16(0xf));     // return 1.0f / a
146SIMD_WRAPPER_1_(rsqrt_ps, rsqrt14_ps, __mmask16(0xf)); // return 1.0f / sqrt(a)
147SIMD_WRAPPER_2(sub_ps);                                // return a - b
148
149//-----------------------------------------------------------------------
150// Integer (various width) arithmetic operations
151//-----------------------------------------------------------------------
152SIMD_IWRAPPER_1_32(abs_epi32); // return absolute_value(a) (int32)
153SIMD_IWRAPPER_2_32(add_epi32); // return a + b (int32)
154SIMD_IWRAPPER_2_32(max_epi32); // return (a > b) ? a : b (int32)
155SIMD_IWRAPPER_2_32(max_epu32); // return (a > b) ? a : b (uint32)
156SIMD_IWRAPPER_2_32(min_epi32); // return (a < b) ? a : b (int32)
157SIMD_IWRAPPER_2_32(min_epu32); // return (a < b) ? a : b (uint32)
158SIMD_IWRAPPER_2_32(mul_epi32); // return a * b (int32)
159
160// SIMD_IWRAPPER_2_8(add_epi8);    // return a + b (int8)
161// SIMD_IWRAPPER_2_8(adds_epu8);   // return ((a + b) > 0xff) ? 0xff : (a + b) (uint8)
162
163// return (a * b) & 0xFFFFFFFF
164//
165// Multiply the packed 32-bit integers in a and b, producing intermediate 64-bit integers,
166// and store the low 32 bits of the intermediate integers in dst.
167SIMD_IWRAPPER_2_32(mullo_epi32);
168SIMD_IWRAPPER_2_32(sub_epi32); // return a - b (int32)
169
170// SIMD_IWRAPPER_2_64(sub_epi64);  // return a - b (int64)
171// SIMD_IWRAPPER_2_8(subs_epu8);   // return (b > a) ? 0 : (a - b) (uint8)
172
173//-----------------------------------------------------------------------
174// Logical operations
175//-----------------------------------------------------------------------
176SIMD_IWRAPPER_2_(and_si, and_epi32, __mmask16(0xf));       // return a & b       (int)
177SIMD_IWRAPPER_2_(andnot_si, andnot_epi32, __mmask16(0xf)); // return (~a) & b    (int)
178SIMD_IWRAPPER_2_(or_si, or_epi32, __mmask16(0xf));         // return a | b       (int)
179SIMD_IWRAPPER_2_(xor_si, xor_epi32, __mmask16(0xf));       // return a ^ b       (int)
180
181//-----------------------------------------------------------------------
182// Shift operations
183//-----------------------------------------------------------------------
184SIMD_IWRAPPER_1I_32(slli_epi32); // return a << ImmT
185SIMD_IWRAPPER_2_32(sllv_epi32);  // return a << b      (uint32)
186SIMD_IWRAPPER_1I_32(srai_epi32); // return a >> ImmT   (int32)
187SIMD_IWRAPPER_1I_32(srli_epi32); // return a >> ImmT   (uint32)
188SIMD_IWRAPPER_2_32(srlv_epi32);  // return a >> b      (uint32)
189
190// use AVX2 version
191// SIMD_IWRAPPER_1I_(srli_si, srli_si256);     // return a >> (ImmT*8) (uint)
192
193//-----------------------------------------------------------------------
194// Conversion operations (Use AVX2 versions)
195//-----------------------------------------------------------------------
196// SIMD_IWRAPPER_1L(cvtepu8_epi16, 0xffff);    // return (int16)a    (uint8 --> int16)
197// SIMD_IWRAPPER_1L(cvtepu8_epi32, 0xff);      // return (int32)a    (uint8 --> int32)
198// SIMD_IWRAPPER_1L(cvtepu16_epi32, 0xff);     // return (int32)a    (uint16 --> int32)
199// SIMD_IWRAPPER_1L(cvtepu16_epi64, 0xf);      // return (int64)a    (uint16 --> int64)
200// SIMD_IWRAPPER_1L(cvtepu32_epi64, 0xf);      // return (int64)a    (uint32 --> int64)
201
202//-----------------------------------------------------------------------
// Comparison operations (Use AVX2 versions)
204//-----------------------------------------------------------------------
205// SIMD_IWRAPPER_2_CMP(cmpeq_epi8);    // return a == b (int8)
206// SIMD_IWRAPPER_2_CMP(cmpeq_epi16);   // return a == b (int16)
207// SIMD_IWRAPPER_2_CMP(cmpeq_epi32);   // return a == b (int32)
208// SIMD_IWRAPPER_2_CMP(cmpeq_epi64);   // return a == b (int64)
209// SIMD_IWRAPPER_2_CMP(cmpgt_epi8,);   // return a > b (int8)
210// SIMD_IWRAPPER_2_CMP(cmpgt_epi16);   // return a > b (int16)
211// SIMD_IWRAPPER_2_CMP(cmpgt_epi32);   // return a > b (int32)
212// SIMD_IWRAPPER_2_CMP(cmpgt_epi64);   // return a > b (int64)
213//
214// static SIMDINLINE Integer SIMDCALL cmplt_epi32(Integer a, Integer b)   // return a < b (int32)
215//{
216//    return cmpgt_epi32(b, a);
217//}
218
219//-----------------------------------------------------------------------
220// Blend / shuffle / permute operations
221//-----------------------------------------------------------------------
// SIMD_IWRAPPER_2_8(packs_epi16);     // int16 --> int8    See documentation for _mm256_packs_epi16 and _mm512_packs_epi16
// SIMD_IWRAPPER_2_16(packs_epi32);    // int32 --> int16   See documentation for _mm256_packs_epi32 and _mm512_packs_epi32
// SIMD_IWRAPPER_2_8(packus_epi16);    // uint16 --> uint8  See documentation for _mm256_packus_epi16 and _mm512_packus_epi16
// SIMD_IWRAPPER_2_16(packus_epi32);   // uint32 --> uint16 See documentation for _mm256_packus_epi32 and _mm512_packus_epi32
// SIMD_IWRAPPER_2_(permute_epi32, permutevar8x32_epi32);
229
// static SIMDINLINE Float SIMDCALL permute_ps(Float a, Integer swiz)    // return a[swiz[i]] for each 32-bit lane i (float)
232//{
233//    return _mm256_permutevar8x32_ps(a, swiz);
234//}
235
236SIMD_IWRAPPER_1I_32(shuffle_epi32);
237// template<int ImmT>
238// static SIMDINLINE Integer SIMDCALL shuffle_epi64(Integer a, Integer b)
239//{
240//    return castpd_si(shuffle_pd<ImmT>(castsi_pd(a), castsi_pd(b)));
241//}
242// SIMD_IWRAPPER_2(shuffle_epi8);
243SIMD_IWRAPPER_2_32(unpackhi_epi32);
244SIMD_IWRAPPER_2_32(unpacklo_epi32);
245
246// SIMD_IWRAPPER_2_16(unpackhi_epi16);
247// SIMD_IWRAPPER_2_64(unpackhi_epi64);
248// SIMD_IWRAPPER_2_8(unpackhi_epi8);
249// SIMD_IWRAPPER_2_16(unpacklo_epi16);
250// SIMD_IWRAPPER_2_64(unpacklo_epi64);
251// SIMD_IWRAPPER_2_8(unpacklo_epi8);
252
253//-----------------------------------------------------------------------
254// Load / store operations
255//-----------------------------------------------------------------------
256static SIMDINLINE Float SIMDCALL
257                        load_ps(float const* p) // return *p    (loads SIMD width elements from memory)
258{
259    return __conv(_mm512_maskz_loadu_ps(__mmask16(0xf), p));
260}
261
262static SIMDINLINE Integer SIMDCALL load_si(Integer const* p) // return *p
263{
264    return __conv(_mm512_maskz_loadu_epi32(__mmask16(0xf), p));
265}
266
267static SIMDINLINE Float SIMDCALL
268                        loadu_ps(float const* p) // return *p    (same as load_ps but allows for unaligned mem)
269{
270    return __conv(_mm512_maskz_loadu_ps(__mmask16(0xf), p));
271}
272
273static SIMDINLINE Integer SIMDCALL
274                          loadu_si(Integer const* p) // return *p    (same as load_si but allows for unaligned mem)
275{
276    return __conv(_mm512_maskz_loadu_epi32(__mmask16(0xf), p));
277}
278
279template <ScaleFactor ScaleT = ScaleFactor::SF_1>
280static SIMDINLINE Float SIMDCALL
281                        i32gather_ps(float const* p, Integer idx) // return *(float*)(((int8*)p) + (idx * ScaleT))
282{
283    return __conv(_mm512_mask_i32gather_ps(
284        _mm512_setzero_ps(), __mmask16(0xf), __conv(idx), p, static_cast<int>(ScaleT)));
285}
286
287// for each element: (mask & (1 << 31)) ? (i32gather_ps<ScaleT>(p, idx), mask = 0) : old
288template <ScaleFactor ScaleT = ScaleFactor::SF_1>
289static SIMDINLINE Float SIMDCALL
290                        mask_i32gather_ps(Float old, float const* p, Integer idx, Float mask)
291{
292    __mmask16 m = 0xf;
293    m           = _mm512_mask_test_epi32_mask(
294        m, _mm512_castps_si512(__conv(mask)), _mm512_set1_epi32(0x80000000));
295    return __conv(
296        _mm512_mask_i32gather_ps(__conv(old), m, __conv(idx), p, static_cast<int>(ScaleT)));
297}
298
299// static SIMDINLINE uint32_t SIMDCALL movemask_epi8(Integer a)
300// {
301//     __mmask64 m = 0xffffull;
302//     return static_cast<uint32_t>(
303//         _mm512_mask_test_epi8_mask(m, __conv(a), _mm512_set1_epi8(0x80)));
304// }
305
306static SIMDINLINE void SIMDCALL maskstore_ps(float* p, Integer mask, Float src)
307{
308    __mmask16 m = 0xf;
309    m           = _mm512_mask_test_epi32_mask(m, __conv(mask), _mm512_set1_epi32(0x80000000));
310    _mm512_mask_storeu_ps(p, m, __conv(src));
311}
312
313static SIMDINLINE void SIMDCALL
314                       store_ps(float* p, Float a) // *p = a   (stores all elements contiguously in memory)
315{
316    _mm512_mask_storeu_ps(p, __mmask16(0xf), __conv(a));
317}
318
319static SIMDINLINE void SIMDCALL store_si(Integer* p, Integer a) // *p = a
320{
321    _mm512_mask_storeu_epi32(p, __mmask16(0xf), __conv(a));
322}
323
324static SIMDINLINE Float SIMDCALL vmask_ps(int32_t mask)
325{
326    return castsi_ps(__conv(_mm512_maskz_set1_epi32(__mmask16(mask & 0xf), -1)));
327}
328
329//=======================================================================
330// Legacy interface (available only in SIMD256 width)
331//=======================================================================
332
// Remove the wrapper-generator macros so they do not leak past this file.
// Some names below are never defined in this file; #undef of an undefined
// macro is a no-op. NOTE(review): the list presumably mirrors the other
// simdlib implementation files; confirm before pruning it.
#undef SIMD_WRAPPER_1_
#undef SIMD_WRAPPER_1
#undef SIMD_WRAPPER_1I_
#undef SIMD_WRAPPER_1I
#undef SIMD_WRAPPER_2_
#undef SIMD_WRAPPER_2
#undef SIMD_WRAPPER_2I
#undef SIMD_WRAPPER_3_
#undef SIMD_WRAPPER_3
#undef SIMD_DWRAPPER_1_
#undef SIMD_DWRAPPER_1
#undef SIMD_DWRAPPER_1I_
#undef SIMD_DWRAPPER_1I
#undef SIMD_DWRAPPER_2_
#undef SIMD_DWRAPPER_2
#undef SIMD_DWRAPPER_2I
#undef SIMD_IWRAPPER_1_
#undef SIMD_IWRAPPER_1_8
#undef SIMD_IWRAPPER_1_16
#undef SIMD_IWRAPPER_1_32
#undef SIMD_IWRAPPER_1_64
#undef SIMD_IWRAPPER_1I_
#undef SIMD_IWRAPPER_1I_8
#undef SIMD_IWRAPPER_1I_16
#undef SIMD_IWRAPPER_1I_32
#undef SIMD_IWRAPPER_1I_64
#undef SIMD_IWRAPPER_2_
#undef SIMD_IWRAPPER_2_8
#undef SIMD_IWRAPPER_2_16
#undef SIMD_IWRAPPER_2_32
#undef SIMD_IWRAPPER_2_64
#undef SIMD_IWRAPPER_2I
//#undef SIMD_IWRAPPER_2I_8
//#undef SIMD_IWRAPPER_2I_16
//#undef SIMD_IWRAPPER_2I_32
//#undef SIMD_IWRAPPER_2I_64
369