/****************************************************************************
 * Copyright (C) 2017 Intel Corporation.   All Rights Reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 ****************************************************************************/
#if !defined(__SIMD_LIB_AVX512_HPP__)
#error Do not include this file directly, use "simdlib.hpp" instead.
#endif

#if defined(__GNUC__) && !defined(__clang__) && !defined(__INTEL_COMPILER)
// gcc as of 7.1 was missing these intrinsics
#ifndef _mm512_cmpneq_ps_mask
#define _mm512_cmpneq_ps_mask(a, b) _mm512_cmp_ps_mask((a), (b), _CMP_NEQ_UQ)
#endif

#ifndef _mm512_cmplt_ps_mask
#define _mm512_cmplt_ps_mask(a, b) _mm512_cmp_ps_mask((a), (b), _CMP_LT_OS)
#endif

#ifndef _mm512_cmplt_pd_mask
#define _mm512_cmplt_pd_mask(a, b) _mm512_cmp_pd_mask((a), (b), _CMP_LT_OS)
#endif

#endif

//============================================================================
// SIMD16 AVX512 (F) implementation (compatible with Knights and Core
// processors)
//
//============================================================================

static const int TARGET_SIMD_WIDTH = 16;
using SIMD256T                     = SIMD256Impl::AVX2Impl;

#define SIMD_WRAPPER_1_(op, intrin) \
    static SIMDINLINE Float SIMDCALL op(Float a) { return intrin(a); }

#define SIMD_WRAPPER_1(op) SIMD_WRAPPER_1_(op, _mm512_##op)

#define SIMD_WRAPPER_2_(op, intrin) \
    static SIMDINLINE Float SIMDCALL op(Float a, Float b) { return _mm512_##intrin(a, b); }
#define SIMD_WRAPPER_2(op) SIMD_WRAPPER_2_(op, op)

#define SIMD_WRAPPERI_2_(op, intrin)                                          \
    static SIMDINLINE Float SIMDCALL op(Float a, Float b)                     \
    {                                                                         \
        return _mm512_castsi512_ps(                                           \
            _mm512_##intrin(_mm512_castps_si512(a), _mm512_castps_si512(b))); \
    }

#define SIMD_DWRAPPER_2(op) \
    static SIMDINLINE Double SIMDCALL op(Double a, Double b) { return _mm512_##op(a, b); }

#define SIMD_WRAPPER_2I_(op, intrin)                      \
    template <int ImmT>                                   \
    static SIMDINLINE Float SIMDCALL op(Float a, Float b) \
    {                                                     \
        return _mm512_##intrin(a, b, ImmT);               \
    }
#define SIMD_WRAPPER_2I(op) SIMD_WRAPPER_2I_(op, op)

#define SIMD_DWRAPPER_2I_(op, intrin)                        \
    template <int ImmT>                                      \
    static SIMDINLINE Double SIMDCALL op(Double a, Double b) \
    {                                                        \
        return _mm512_##intrin(a, b, ImmT);                  \
    }
#define SIMD_DWRAPPER_2I(op) SIMD_DWRAPPER_2I_(op, op)

#define SIMD_WRAPPER_3(op) \
    static SIMDINLINE Float SIMDCALL op(Float a, Float b, Float c) { return _mm512_##op(a, b, c); }

#define SIMD_IWRAPPER_1(op) \
    static SIMDINLINE Integer SIMDCALL op(Integer a) { return _mm512_##op(a); }
#define SIMD_IWRAPPER_1_8(op) \
    static SIMDINLINE Integer SIMDCALL op(SIMD256Impl::Integer a) { return _mm512_##op(a); }

#define SIMD_IWRAPPER_1_4(op) \
    static SIMDINLINE Integer SIMDCALL op(SIMD128Impl::Integer a) { return _mm512_##op(a); }

#define SIMD_IWRAPPER_1I_(op, intrin)                \
    template <int ImmT>                              \
    static SIMDINLINE Integer SIMDCALL op(Integer a) \
    {                                                \
        return intrin(a, ImmT);                      \
    }
#define SIMD_IWRAPPER_1I(op) SIMD_IWRAPPER_1I_(op, _mm512_##op)

#define SIMD_IWRAPPER_2_(op, intrin) \
    static SIMDINLINE Integer SIMDCALL op(Integer a, Integer b) { return _mm512_##intrin(a, b); }
#define SIMD_IWRAPPER_2(op) SIMD_IWRAPPER_2_(op, op)

#define SIMD_IWRAPPER_2_CMP(op, cmp) \
    static SIMDINLINE Integer SIMDCALL op(Integer a, Integer b) { return cmp(a, b); }

#define SIMD_IFWRAPPER_2(op, intrin)                                   \
    static SIMDINLINE Integer SIMDCALL op(Integer a, Integer b)        \
    {                                                                  \
        return castps_si(_mm512_##intrin(castsi_ps(a), castsi_ps(b))); \
    }

#define SIMD_IWRAPPER_2I_(op, intrin)                           \
    template <int ImmT>                                         \
    static SIMDINLINE Integer SIMDCALL op(Integer a, Integer b) \
    {                                                           \
        return _mm512_##intrin(a, b, ImmT);                     \
    }
#define SIMD_IWRAPPER_2I(op) SIMD_IWRAPPER_2I_(op, op)

private:
static SIMDINLINE Integer vmask(__mmask16 m)
{
    return _mm512_maskz_set1_epi32(m, -1);
}

static SIMDINLINE Integer vmask(__mmask8 m)
{
    return _mm512_maskz_set1_epi64(m, -1LL);
}
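
// Illustrative note (not part of the original interface): vmask() widens an
// AVX512 predicate register into a legacy AVX-style per-lane vector mask.
// For example, given the zero-masked broadcast semantics of
// _mm512_maskz_set1_epi32,
//     vmask(__mmask16(0x0005))
// would yield a vector whose 32-bit lanes 0 and 2 are 0xFFFFFFFF and whose
// remaining lanes are 0.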

public:
//-----------------------------------------------------------------------
// Single precision floating point arithmetic operations
//-----------------------------------------------------------------------
SIMD_WRAPPER_2(add_ps);                       // return a + b
SIMD_WRAPPER_2(div_ps);                       // return a / b
SIMD_WRAPPER_3(fmadd_ps);                     // return (a * b) + c
SIMD_WRAPPER_3(fmsub_ps);                     // return (a * b) - c
SIMD_WRAPPER_2(max_ps);                       // return (a > b) ? a : b
SIMD_WRAPPER_2(min_ps);                       // return (a < b) ? a : b
SIMD_WRAPPER_2(mul_ps);                       // return a * b
SIMD_WRAPPER_1_(rcp_ps, _mm512_rcp14_ps);     // return 1.0f / a (approximate)
SIMD_WRAPPER_1_(rsqrt_ps, _mm512_rsqrt14_ps); // return 1.0f / sqrt(a) (approximate)
SIMD_WRAPPER_2(sub_ps);                       // return a - b

template <RoundMode RMT>
static SIMDINLINE Float SIMDCALL round_ps(Float a)
{
    return _mm512_roundscale_ps(a, static_cast<int>(RMT));
}

static SIMDINLINE Float SIMDCALL ceil_ps(Float a)
{
    return round_ps<RoundMode::CEIL_NOEXC>(a);
}
static SIMDINLINE Float SIMDCALL floor_ps(Float a)
{
    return round_ps<RoundMode::FLOOR_NOEXC>(a);
}

//-----------------------------------------------------------------------
// Integer (various width) arithmetic operations
//-----------------------------------------------------------------------
SIMD_IWRAPPER_1(abs_epi32); // return absolute_value(a) (int32)
SIMD_IWRAPPER_2(add_epi32); // return a + b (int32)
// SIMD_IWRAPPER_2(add_epi8);  // return a + b (int8)
// SIMD_IWRAPPER_2(adds_epu8); // return ((a + b) > 0xff) ? 0xff : (a + b) (uint8)
SIMD_IWRAPPER_2(max_epi32); // return (a > b) ? a : b (int32)
SIMD_IWRAPPER_2(max_epu32); // return (a > b) ? a : b (uint32)
SIMD_IWRAPPER_2(min_epi32); // return (a < b) ? a : b (int32)
SIMD_IWRAPPER_2(min_epu32); // return (a < b) ? a : b (uint32)
SIMD_IWRAPPER_2(mul_epi32); // return a * b (signed 32-bit multiply of even lanes, 64-bit results)

// return (a * b) & 0xFFFFFFFF
//
// Multiply the packed 32-bit integers in a and b, producing intermediate 64-bit integers,
// and store the low 32 bits of the intermediate integers in dst.
SIMD_IWRAPPER_2(mullo_epi32);
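
// Worked example (illustrative only): with a lane value of 70000 in both a and
// b, the full product is 4,900,000,000 = 0x1'2410'1100; mullo_epi32 keeps only
// the low 32 bits, so the result lane holds 0x24101100.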
SIMD_IWRAPPER_2(sub_epi32); // return a - b (int32)
SIMD_IWRAPPER_2(sub_epi64); // return a - b (int64)
// SIMD_IWRAPPER_2(subs_epu8); // return (b > a) ? 0 : (a - b) (uint8)

//-----------------------------------------------------------------------
// Logical operations
//-----------------------------------------------------------------------
SIMD_IWRAPPER_2_(and_si, and_si512);       // return a & b       (int)
SIMD_IWRAPPER_2_(andnot_si, andnot_si512); // return (~a) & b    (int)
SIMD_IWRAPPER_2_(or_si, or_si512);         // return a | b       (int)
SIMD_IWRAPPER_2_(xor_si, xor_si512);       // return a ^ b       (int)

// SIMD_WRAPPER_2(and_ps);                     // return a & b       (float treated as int)
// SIMD_WRAPPER_2(andnot_ps);                  // return (~a) & b    (float treated as int)
// SIMD_WRAPPER_2(or_ps);                      // return a | b       (float treated as int)
// SIMD_WRAPPER_2(xor_ps);                     // return a ^ b       (float treated as int)

//-----------------------------------------------------------------------
// Shift operations
//-----------------------------------------------------------------------
SIMD_IWRAPPER_1I(slli_epi32); // return a << ImmT
SIMD_IWRAPPER_2(sllv_epi32);
SIMD_IWRAPPER_1I(srai_epi32); // return a >> ImmT   (int32)
SIMD_IWRAPPER_1I(srli_epi32); // return a >> ImmT   (uint32)

#if 0
SIMD_IWRAPPER_1I_(srli_si, srli_si512);     // return a >> (ImmT*8) (uint)

template<int ImmT>                              // same as srli_si, but with Float cast to int
static SIMDINLINE Float SIMDCALL srlisi_ps(Float a)
{
    return castsi_ps(srli_si<ImmT>(castps_si(a)));
}
#endif

SIMD_IWRAPPER_2(srlv_epi32);

//-----------------------------------------------------------------------
// Conversion operations
//-----------------------------------------------------------------------
static SIMDINLINE Float SIMDCALL castpd_ps(Double a) // return *(Float*)(&a)
{
    return _mm512_castpd_ps(a);
}

static SIMDINLINE Integer SIMDCALL castps_si(Float a) // return *(Integer*)(&a)
{
    return _mm512_castps_si512(a);
}

static SIMDINLINE Double SIMDCALL castsi_pd(Integer a) // return *(Double*)(&a)
{
    return _mm512_castsi512_pd(a);
}

static SIMDINLINE Double SIMDCALL castps_pd(Float a) // return *(Double*)(&a)
{
    return _mm512_castps_pd(a);
}

static SIMDINLINE Integer SIMDCALL castpd_si(Double a) // return *(Integer*)(&a)
{
    return _mm512_castpd_si512(a);
}

static SIMDINLINE Float SIMDCALL castsi_ps(Integer a) // return *(Float*)(&a)
{
    return _mm512_castsi512_ps(a);
}

static SIMDINLINE Float SIMDCALL cvtepi32_ps(Integer a) // return (float)a    (int32 --> float)
{
    return _mm512_cvtepi32_ps(a);
}

// SIMD_IWRAPPER_1_8(cvtepu8_epi16);     // return (int16)a    (uint8 --> int16)
SIMD_IWRAPPER_1_4(cvtepu8_epi32);  // return (int32)a    (uint8 --> int32)
SIMD_IWRAPPER_1_8(cvtepu16_epi32); // return (int32)a    (uint16 --> int32)
SIMD_IWRAPPER_1_4(cvtepu16_epi64); // return (int64)a    (uint16 --> int64)
SIMD_IWRAPPER_1_8(cvtepu32_epi64); // return (int64)a    (uint32 --> int64)

static SIMDINLINE Integer SIMDCALL cvtps_epi32(Float a) // return (int32)a    (float --> int32)
{
    return _mm512_cvtps_epi32(a);
}

static SIMDINLINE Integer SIMDCALL
                          cvttps_epi32(Float a) // return (int32)a    (rnd_to_zero(float) --> int32)
{
    return _mm512_cvttps_epi32(a);
}
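
// Illustrative note (not part of the original source): cvtps_epi32 converts
// using the current rounding mode (round-to-nearest-even by default), while
// cvttps_epi32 truncates toward zero. For a lane holding 1.7f, cvtps_epi32
// would produce 2 and cvttps_epi32 would produce 1.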

//-----------------------------------------------------------------------
// Comparison operations
//-----------------------------------------------------------------------
template <CompareType CmpTypeT>
static SIMDINLINE Mask SIMDCALL cmp_ps_mask(Float a, Float b)
{
    return _mm512_cmp_ps_mask(a, b, static_cast<const int>(CmpTypeT));
}

template <CompareType CmpTypeT>
static SIMDINLINE Float SIMDCALL cmp_ps(Float a, Float b) // return a (CmpTypeT) b
{
    // Legacy vector mask generator
    __mmask16 result = cmp_ps_mask<CmpTypeT>(a, b);
    return castsi_ps(vmask(result));
}
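
// Usage sketch (illustrative only): cmp_ps returns an AVX-style all-ones /
// all-zeros lane mask rather than an AVX512 __mmask16, so existing callers can
// feed it straight into blendv_ps or movemask_ps, e.g.
//     Float ltMask = cmp_ps<CompareType::LT_OQ>(a, b); // lane = 0xFFFFFFFF where a < b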

static SIMDINLINE Float SIMDCALL cmplt_ps(Float a, Float b)
{
    return cmp_ps<CompareType::LT_OQ>(a, b);
}
static SIMDINLINE Float SIMDCALL cmpgt_ps(Float a, Float b)
{
    return cmp_ps<CompareType::GT_OQ>(a, b);
}
static SIMDINLINE Float SIMDCALL cmpneq_ps(Float a, Float b)
{
    return cmp_ps<CompareType::NEQ_OQ>(a, b);
}
static SIMDINLINE Float SIMDCALL cmpeq_ps(Float a, Float b)
{
    return cmp_ps<CompareType::EQ_OQ>(a, b);
}
static SIMDINLINE Float SIMDCALL cmpge_ps(Float a, Float b)
{
    return cmp_ps<CompareType::GE_OQ>(a, b);
}
static SIMDINLINE Float SIMDCALL cmple_ps(Float a, Float b)
{
    return cmp_ps<CompareType::LE_OQ>(a, b);
}

template <CompareTypeInt CmpTypeT>
static SIMDINLINE Integer SIMDCALL cmp_epi32(Integer a, Integer b)
{
    // Legacy vector mask generator
    __mmask16 result = _mm512_cmp_epi32_mask(a, b, static_cast<const int>(CmpTypeT));
    return vmask(result);
}
template <CompareTypeInt CmpTypeT>
static SIMDINLINE Integer SIMDCALL cmp_epi64(Integer a, Integer b)
{
    // Legacy vector mask generator
    __mmask8 result = _mm512_cmp_epi64_mask(a, b, static_cast<const int>(CmpTypeT));
    return vmask(result);
}

// SIMD_IWRAPPER_2_CMP(cmpeq_epi8,  cmp_epi8<CompareTypeInt::EQ>);    // return a == b (int8)
// SIMD_IWRAPPER_2_CMP(cmpeq_epi16, cmp_epi16<CompareTypeInt::EQ>);   // return a == b (int16)
SIMD_IWRAPPER_2_CMP(cmpeq_epi32, cmp_epi32<CompareTypeInt::EQ>); // return a == b (int32)
SIMD_IWRAPPER_2_CMP(cmpeq_epi64, cmp_epi64<CompareTypeInt::EQ>); // return a == b (int64)
// SIMD_IWRAPPER_2_CMP(cmpgt_epi8,  cmp_epi8<CompareTypeInt::GT>);    // return a > b (int8)
// SIMD_IWRAPPER_2_CMP(cmpgt_epi16, cmp_epi16<CompareTypeInt::GT>);   // return a > b (int16)
SIMD_IWRAPPER_2_CMP(cmpgt_epi32, cmp_epi32<CompareTypeInt::GT>); // return a > b (int32)
SIMD_IWRAPPER_2_CMP(cmpgt_epi64, cmp_epi64<CompareTypeInt::GT>); // return a > b (int64)
SIMD_IWRAPPER_2_CMP(cmplt_epi32, cmp_epi32<CompareTypeInt::LT>); // return a < b (int32)

static SIMDINLINE bool SIMDCALL testz_ps(Float a,
                                         Float b) // return all_lanes_zero(a & b) ? 1 : 0 (float)
{
    return (0 == static_cast<int>(_mm512_test_epi32_mask(castps_si(a), castps_si(b))));
}

static SIMDINLINE bool SIMDCALL testz_si(Integer a,
                                         Integer b) // return all_lanes_zero(a & b) ? 1 : 0 (int)
{
    return (0 == static_cast<int>(_mm512_test_epi32_mask(a, b)));
}
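
// Illustrative note (not part of the original source): testz_* reports whether
// the lane-wise AND of its arguments is zero in every lane, so testz_si(a, a)
// is true exactly when a is all zeros.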

//-----------------------------------------------------------------------
// Blend / shuffle / permute operations
//-----------------------------------------------------------------------
template <int ImmT>
static SIMDINLINE Float blend_ps(Float a, Float b) // return ImmT ? b : a  (float)
{
    return _mm512_mask_blend_ps(__mmask16(ImmT), a, b);
}

template <int ImmT>
static SIMDINLINE Integer blend_epi32(Integer a, Integer b) // return ImmT ? b : a  (int32)
{
    return _mm512_mask_blend_epi32(__mmask16(ImmT), a, b);
}

static SIMDINLINE Float blendv_ps(Float a, Float b, Float mask) // return mask ? b : a  (float)
{
    return _mm512_mask_blend_ps(__mmask16(movemask_ps(mask)), a, b);
}

static SIMDINLINE Integer SIMDCALL blendv_epi32(Integer a,
                                                Integer b,
                                                Float   mask) // return mask ? b : a (int)
{
    return castps_si(blendv_ps(castsi_ps(a), castsi_ps(b), mask));
}

static SIMDINLINE Integer SIMDCALL blendv_epi32(Integer a,
                                                Integer b,
                                                Integer mask) // return mask ? b : a (int)
{
    return castps_si(blendv_ps(castsi_ps(a), castsi_ps(b), castsi_ps(mask)));
}
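
// Usage sketch (illustrative only): blendv_ps selects from b wherever the mask
// lane's sign bit is set, so combining it with a comparison gives a per-lane
// select, e.g.
//     Float perLaneMax = blendv_ps(a, b, cmplt_ps(a, b)); // b where a < b, else a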

static SIMDINLINE Float SIMDCALL
                        broadcast_ss(float const* p) // return *p (all elements in vector get same value)
{
    return _mm512_set1_ps(*p);
}

template <int imm>
static SIMDINLINE SIMD256Impl::Float SIMDCALL extract_ps(Float a)
{
    return _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(a), imm));
}

template <int imm>
static SIMDINLINE SIMD256Impl::Double SIMDCALL extract_pd(Double a)
{
    return _mm512_extractf64x4_pd(a, imm);
}

template <int imm>
static SIMDINLINE SIMD256Impl::Integer SIMDCALL extract_si(Integer a)
{
    return _mm512_extracti64x4_epi64(a, imm);
}

template <int imm>
static SIMDINLINE Float SIMDCALL insert_ps(Float a, SIMD256Impl::Float b)
{
    return _mm512_castpd_ps(_mm512_insertf64x4(_mm512_castps_pd(a), _mm256_castps_pd(b), imm));
}

template <int imm>
static SIMDINLINE Double SIMDCALL insert_pd(Double a, SIMD256Impl::Double b)
{
    return _mm512_insertf64x4(a, b, imm);
}

template <int imm>
static SIMDINLINE Integer SIMDCALL insert_si(Integer a, SIMD256Impl::Integer b)
{
    return _mm512_inserti64x4(a, b, imm);
}
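
// Illustrative note (not part of the original source): extract_*<0> / <1>
// return the low / high 256-bit half of a 512-bit register, and insert_* is
// the inverse, so
//     insert_ps<1>(a, extract_ps<1>(b))
// would replace the upper eight float lanes of a with those of b.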

// SIMD_IWRAPPER_2(packs_epi16);  // See documentation for _mm512_packs_epi16
// SIMD_IWRAPPER_2(packs_epi32);  // See documentation for _mm512_packs_epi32
// SIMD_IWRAPPER_2(packus_epi16); // See documentation for _mm512_packus_epi16
// SIMD_IWRAPPER_2(packus_epi32); // See documentation for _mm512_packus_epi32

template <int ImmT>
static SIMDINLINE Float SIMDCALL permute_ps(Float const& a)
{
    return _mm512_permute_ps(a, ImmT);
}

static SIMDINLINE Integer SIMDCALL
                          permute_epi32(Integer a, Integer swiz) // return a[swiz[i]] for each 32-bit lane i (int32)
{
    return _mm512_permutexvar_epi32(swiz, a);
}

static SIMDINLINE Float SIMDCALL
                        permute_ps(Float a, Integer swiz) // return a[swiz[i]] for each 32-bit lane i (float)
{
    return _mm512_permutexvar_ps(swiz, a);
}
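
// Usage sketch (illustrative only): with permutexvar semantics, result lane i
// takes its value from a[swiz[i]]. For instance, a swiz vector whose lane i
// holds (15 - i) reverses the order of the sixteen 32-bit lanes:
//     Integer rev = set_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
//     Float reversed = permute_ps(a, rev);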

SIMD_WRAPPER_2I_(permute2f128_ps, shuffle_f32x4);
SIMD_DWRAPPER_2I_(permute2f128_pd, shuffle_f64x2);
SIMD_IWRAPPER_2I_(permute2f128_si, shuffle_i32x4);

SIMD_IWRAPPER_1I(shuffle_epi32);

// SIMD_IWRAPPER_2(shuffle_epi8);
SIMD_DWRAPPER_2I(shuffle_pd);
SIMD_WRAPPER_2I(shuffle_ps);

template <int ImmT>
static SIMDINLINE Integer SIMDCALL shuffle_epi64(Integer a, Integer b)
{
    return castpd_si(shuffle_pd<ImmT>(castsi_pd(a), castsi_pd(b)));
}

SIMD_IWRAPPER_2(unpackhi_epi16);

// SIMD_IFWRAPPER_2(unpackhi_epi32, _mm512_unpackhi_ps);
static SIMDINLINE Integer SIMDCALL unpackhi_epi32(Integer a, Integer b)
{
    return castps_si(_mm512_unpackhi_ps(castsi_ps(a), castsi_ps(b)));
}

SIMD_IWRAPPER_2(unpackhi_epi64);
// SIMD_IWRAPPER_2(unpackhi_epi8);
SIMD_DWRAPPER_2(unpackhi_pd);
SIMD_WRAPPER_2(unpackhi_ps);
// SIMD_IWRAPPER_2(unpacklo_epi16);
SIMD_IFWRAPPER_2(unpacklo_epi32, unpacklo_ps);
SIMD_IWRAPPER_2(unpacklo_epi64);
// SIMD_IWRAPPER_2(unpacklo_epi8);
SIMD_DWRAPPER_2(unpacklo_pd);
SIMD_WRAPPER_2(unpacklo_ps);

//-----------------------------------------------------------------------
// Load / store operations
//-----------------------------------------------------------------------
template <ScaleFactor ScaleT = ScaleFactor::SF_1>
static SIMDINLINE Float SIMDCALL
                        i32gather_ps(float const* p, Integer idx) // return *(float*)(((int8*)p) + (idx * ScaleT))
{
    return _mm512_i32gather_ps(idx, p, static_cast<int>(ScaleT));
}

static SIMDINLINE Float SIMDCALL
                        load1_ps(float const* p) // return *p    (broadcast 1 value to all elements)
{
    return broadcast_ss(p);
}

static SIMDINLINE Float SIMDCALL
                        load_ps(float const* p) // return *p    (loads SIMD width elements from memory)
{
    return _mm512_load_ps(p);
}

static SIMDINLINE Integer SIMDCALL load_si(Integer const* p) // return *p
{
    return _mm512_load_si512(&p->v);
}

static SIMDINLINE Float SIMDCALL
                        loadu_ps(float const* p) // return *p    (same as load_ps but allows for unaligned mem)
{
    return _mm512_loadu_ps(p);
}

static SIMDINLINE Integer SIMDCALL
                          loadu_si(Integer const* p) // return *p    (same as load_si but allows for unaligned mem)
{
    return _mm512_loadu_si512(p);
}

// for each element: (mask & (1 << 31)) ? (i32gather_ps<ScaleT>(p, idx), mask = 0) : old
template <ScaleFactor ScaleT = ScaleFactor::SF_1>
static SIMDINLINE Float SIMDCALL
                        mask_i32gather_ps(Float old, float const* p, Integer idx, Float mask)
{
    __mmask16 k = _mm512_test_epi32_mask(castps_si(mask), set1_epi32(0x80000000));

    return _mm512_mask_i32gather_ps(old, k, idx, p, static_cast<int>(ScaleT));
}
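
// Usage sketch (illustrative only): the gather mask follows the AVX convention
// of testing each lane's sign bit, so a comparison result can drive it
// directly, e.g.
//     Float v = mask_i32gather_ps(old, p, idx, cmplt_ps(x, y));
// loads *(float*)((int8*)p + idx[i] * ScaleT) only where x < y and keeps the
// corresponding lane of old elsewhere.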

static SIMDINLINE void SIMDCALL maskstore_ps(float* p, Integer mask, Float src)
{
    Mask m = _mm512_cmplt_epi32_mask(mask, setzero_si());
    _mm512_mask_store_ps(p, m, src);
}

// static SIMDINLINE uint64_t SIMDCALL movemask_epi8(Integer a)
//{
//    __mmask64 m = _mm512_cmplt_epi8_mask(a, setzero_si());
//    return static_cast<uint64_t>(m);
//}

static SIMDINLINE uint32_t SIMDCALL movemask_pd(Double a)
{
    __mmask8 m = _mm512_test_epi64_mask(castpd_si(a), set1_epi64(0x8000000000000000LL));
    return static_cast<uint32_t>(m);
}
static SIMDINLINE uint32_t SIMDCALL movemask_ps(Float a)
{
    __mmask16 m = _mm512_test_epi32_mask(castps_si(a), set1_epi32(0x80000000));
    return static_cast<uint32_t>(m);
}
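
// Illustrative note (not part of the original source): movemask_ps packs the
// sign bit of each float lane into bit i of the result, mirroring
// _mm256_movemask_ps; e.g. movemask_ps(cmpeq_ps(a, a)) == 0xFFFF whenever no
// lane of a is NaN.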

static SIMDINLINE Integer SIMDCALL set1_epi64(long long i) // return i (all elements are same value)
{
    return _mm512_set1_epi64(i);
}

static SIMDINLINE Integer SIMDCALL set1_epi32(int i) // return i (all elements are same value)
{
    return _mm512_set1_epi32(i);
}

static SIMDINLINE Integer SIMDCALL set1_epi8(char i) // return i (all elements are same value)
{
    return _mm512_set1_epi8(i);
}

static SIMDINLINE Float SIMDCALL set1_ps(float f) // return f (all elements are same value)
{
    return _mm512_set1_ps(f);
}

static SIMDINLINE Double SIMDCALL setzero_pd() // return 0 (double)
{
    return _mm512_setzero_pd();
}

static SIMDINLINE Float SIMDCALL setzero_ps() // return 0 (float)
{
    return _mm512_setzero_ps();
}

static SIMDINLINE Integer SIMDCALL setzero_si() // return 0 (integer)
{
    return _mm512_setzero_si512();
}

static SIMDINLINE void SIMDCALL
                       store_ps(float* p, Float a) // *p = a   (stores all elements contiguously in memory)
{
    _mm512_store_ps(p, a);
}

static SIMDINLINE void SIMDCALL store_si(Integer* p, Integer a) // *p = a
{
    _mm512_store_si512(&p->v, a);
}

static SIMDINLINE void SIMDCALL
                       storeu_si(Integer* p, Integer a) // *p = a    (same as store_si but allows for unaligned mem)
{
    _mm512_storeu_si512(&p->v, a);
}

static SIMDINLINE void SIMDCALL
                       stream_ps(float* p, Float a) // *p = a   (same as store_ps, but doesn't keep memory in cache)
{
    _mm512_stream_ps(p, a);
}

static SIMDINLINE Integer SIMDCALL set_epi32(int i15,
                                             int i14,
                                             int i13,
                                             int i12,
                                             int i11,
                                             int i10,
                                             int i9,
                                             int i8,
                                             int i7,
                                             int i6,
                                             int i5,
                                             int i4,
                                             int i3,
                                             int i2,
                                             int i1,
                                             int i0)
{
    return _mm512_set_epi32(i15, i14, i13, i12, i11, i10, i9, i8, i7, i6, i5, i4, i3, i2, i1, i0);
}

static SIMDINLINE Integer SIMDCALL
                          set_epi32(int i7, int i6, int i5, int i4, int i3, int i2, int i1, int i0)
{
    return set_epi32(0, 0, 0, 0, 0, 0, 0, 0, i7, i6, i5, i4, i3, i2, i1, i0);
}

static SIMDINLINE Float SIMDCALL set_ps(float i15,
                                        float i14,
                                        float i13,
                                        float i12,
                                        float i11,
                                        float i10,
                                        float i9,
                                        float i8,
                                        float i7,
                                        float i6,
                                        float i5,
                                        float i4,
                                        float i3,
                                        float i2,
                                        float i1,
                                        float i0)
{
    return _mm512_set_ps(i15, i14, i13, i12, i11, i10, i9, i8, i7, i6, i5, i4, i3, i2, i1, i0);
}

static SIMDINLINE Float SIMDCALL
                        set_ps(float i7, float i6, float i5, float i4, float i3, float i2, float i1, float i0)
{
    return set_ps(0, 0, 0, 0, 0, 0, 0, 0, i7, i6, i5, i4, i3, i2, i1, i0);
}

static SIMDINLINE Float SIMDCALL vmask_ps(int32_t mask)
{
    return castsi_ps(_mm512_maskz_mov_epi32(__mmask16(mask), set1_epi32(-1)));
}
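
// Illustrative note (not part of the original source): vmask_ps is the inverse
// of movemask_ps for the low 16 mask bits; bit i of the argument selects
// whether float lane i becomes all ones or all zeros, so
//     movemask_ps(vmask_ps(m)) == (m & 0xFFFF)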

#undef SIMD_WRAPPER_1_
#undef SIMD_WRAPPER_1
#undef SIMD_WRAPPER_2
#undef SIMD_WRAPPER_2_
#undef SIMD_WRAPPERI_2_
#undef SIMD_DWRAPPER_2
#undef SIMD_DWRAPPER_2I
#undef SIMD_DWRAPPER_2I_
#undef SIMD_WRAPPER_2I_
#undef SIMD_WRAPPER_2I
#undef SIMD_WRAPPER_3
#undef SIMD_IWRAPPER_1
#undef SIMD_IWRAPPER_1_4
#undef SIMD_IWRAPPER_1_8
#undef SIMD_IWRAPPER_1I
#undef SIMD_IWRAPPER_1I_
#undef SIMD_IWRAPPER_2
#undef SIMD_IWRAPPER_2_
#undef SIMD_IWRAPPER_2_CMP
#undef SIMD_IWRAPPER_2I
#undef SIMD_IWRAPPER_2I_
#undef SIMD_IFWRAPPER_2