1 /*===------------ avx512bf16intrin.h - AVX512_BF16 intrinsics --------------===
2  *
3  * Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4  * See https://llvm.org/LICENSE.txt for license information.
5  * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6  *
7  *===-----------------------------------------------------------------------===
8  */
9 #ifndef __IMMINTRIN_H
10 #error "Never use <avx512bf16intrin.h> directly; include <immintrin.h> instead."
11 #endif
12 
13 #ifndef __AVX512BF16INTRIN_H
14 #define __AVX512BF16INTRIN_H
15 
/* Vector types whose elements are BF16 values held as 16-bit integers. */
typedef short __m512bh __attribute__((__vector_size__(64), __aligned__(64)));
typedef short __m256bh __attribute__((__vector_size__(32), __aligned__(32)));
/* A single BF16 value, held as its raw 16-bit pattern. */
typedef unsigned short __bfloat16;

/* Attributes shared by the scalar intrinsics in this header. */
#define __DEFAULT_FN_ATTRS                                                     \
  __attribute__((__always_inline__, __nodebug__, __target__("avx512bf16")))

/* Attributes shared by the 512-bit vector intrinsics in this header. */
#define __DEFAULT_FN_ATTRS512                                                  \
  __attribute__((__always_inline__, __nodebug__, __target__("avx512bf16"),     \
                 __min_vector_width__(512)))
25 
26 /// Convert One BF16 Data to One Single Float Data.
27 ///
28 /// \headerfile <x86intrin.h>
29 ///
30 /// This intrinsic does not correspond to a specific instruction.
31 ///
32 /// \param __A
33 ///    A bfloat data.
34 /// \returns A float data whose sign field and exponent field keep unchanged,
35 ///    and fraction field is extended to 23 bits.
_mm_cvtsbh_ss(__bfloat16 __A)36 static __inline__ float __DEFAULT_FN_ATTRS _mm_cvtsbh_ss(__bfloat16 __A) {
37   return __builtin_ia32_cvtsbf162ss_32(__A);
38 }
39 
40 /// Convert Two Packed Single Data to One Packed BF16 Data.
41 ///
42 /// \headerfile <x86intrin.h>
43 ///
44 /// This intrinsic corresponds to the <c> VCVTNE2PS2BF16 </c> instructions.
45 ///
46 /// \param __A
47 ///    A 512-bit vector of [16 x float].
48 /// \param __B
49 ///    A 512-bit vector of [16 x float].
50 /// \returns A 512-bit vector of [32 x bfloat] whose lower 256 bits come from
51 ///    conversion of __B, and higher 256 bits come from conversion of __A.
52 static __inline__ __m512bh __DEFAULT_FN_ATTRS512
_mm512_cvtne2ps_pbh(__m512 __A,__m512 __B)53 _mm512_cvtne2ps_pbh(__m512 __A, __m512 __B) {
54   return (__m512bh)__builtin_ia32_cvtne2ps2bf16_512((__v16sf) __A,
55                                                     (__v16sf) __B);
56 }
57 
58 /// Convert Two Packed Single Data to One Packed BF16 Data.
59 ///
60 /// \headerfile <x86intrin.h>
61 ///
62 /// This intrinsic corresponds to the <c> VCVTNE2PS2BF16 </c> instructions.
63 ///
64 /// \param __A
65 ///    A 512-bit vector of [16 x float].
66 /// \param __B
67 ///    A 512-bit vector of [16 x float].
68 /// \param __W
69 ///    A 512-bit vector of [32 x bfloat].
70 /// \param __U
71 ///    A 32-bit mask value specifying what is chosen for each element.
72 ///    A 1 means conversion of __A or __B. A 0 means element from __W.
73 /// \returns A 512-bit vector of [32 x bfloat] whose lower 256 bits come from
74 ///    conversion of __B, and higher 256 bits come from conversion of __A.
75 static __inline__ __m512bh __DEFAULT_FN_ATTRS512
_mm512_mask_cvtne2ps_pbh(__m512bh __W,__mmask32 __U,__m512 __A,__m512 __B)76 _mm512_mask_cvtne2ps_pbh(__m512bh __W, __mmask32 __U, __m512 __A, __m512 __B) {
77   return (__m512bh)__builtin_ia32_selectw_512((__mmask32)__U,
78                                         (__v32hi)_mm512_cvtne2ps_pbh(__A, __B),
79                                         (__v32hi)__W);
80 }
81 
82 /// Convert Two Packed Single Data to One Packed BF16 Data.
83 ///
84 /// \headerfile <x86intrin.h>
85 ///
86 /// This intrinsic corresponds to the <c> VCVTNE2PS2BF16 </c> instructions.
87 ///
88 /// \param __A
89 ///    A 512-bit vector of [16 x float].
90 /// \param __B
91 ///    A 512-bit vector of [16 x float].
92 /// \param __U
93 ///    A 32-bit mask value specifying what is chosen for each element.
94 ///    A 1 means conversion of __A or __B. A 0 means element is zero.
95 /// \returns A 512-bit vector of [32 x bfloat] whose lower 256 bits come from
96 ///    conversion of __B, and higher 256 bits come from conversion of __A.
97 static __inline__ __m512bh __DEFAULT_FN_ATTRS512
_mm512_maskz_cvtne2ps_pbh(__mmask32 __U,__m512 __A,__m512 __B)98 _mm512_maskz_cvtne2ps_pbh(__mmask32 __U, __m512 __A, __m512 __B) {
99   return (__m512bh)__builtin_ia32_selectw_512((__mmask32)__U,
100                                         (__v32hi)_mm512_cvtne2ps_pbh(__A, __B),
101                                         (__v32hi)_mm512_setzero_si512());
102 }
103 
104 /// Convert Packed Single Data to Packed BF16 Data.
105 ///
106 /// \headerfile <x86intrin.h>
107 ///
108 /// This intrinsic corresponds to the <c> VCVTNEPS2BF16 </c> instructions.
109 ///
110 /// \param __A
111 ///    A 512-bit vector of [16 x float].
112 /// \returns A 256-bit vector of [16 x bfloat] come from conversion of __A.
113 static __inline__ __m256bh __DEFAULT_FN_ATTRS512
_mm512_cvtneps_pbh(__m512 __A)114 _mm512_cvtneps_pbh(__m512 __A) {
115   return (__m256bh)__builtin_ia32_cvtneps2bf16_512_mask((__v16sf)__A,
116                                               (__v16hi)_mm256_undefined_si256(),
117                                               (__mmask16)-1);
118 }
119 
120 /// Convert Packed Single Data to Packed BF16 Data.
121 ///
122 /// \headerfile <x86intrin.h>
123 ///
124 /// This intrinsic corresponds to the <c> VCVTNEPS2BF16 </c> instructions.
125 ///
126 /// \param __A
127 ///    A 512-bit vector of [16 x float].
128 /// \param __W
129 ///    A 256-bit vector of [16 x bfloat].
130 /// \param __U
131 ///    A 16-bit mask value specifying what is chosen for each element.
132 ///    A 1 means conversion of __A. A 0 means element from __W.
133 /// \returns A 256-bit vector of [16 x bfloat] come from conversion of __A.
134 static __inline__ __m256bh __DEFAULT_FN_ATTRS512
_mm512_mask_cvtneps_pbh(__m256bh __W,__mmask16 __U,__m512 __A)135 _mm512_mask_cvtneps_pbh(__m256bh __W, __mmask16 __U, __m512 __A) {
136   return (__m256bh)__builtin_ia32_cvtneps2bf16_512_mask((__v16sf)__A,
137                                                         (__v16hi)__W,
138                                                         (__mmask16)__U);
139 }
140 
141 /// Convert Packed Single Data to Packed BF16 Data.
142 ///
143 /// \headerfile <x86intrin.h>
144 ///
145 /// This intrinsic corresponds to the <c> VCVTNEPS2BF16 </c> instructions.
146 ///
147 /// \param __A
148 ///    A 512-bit vector of [16 x float].
149 /// \param __U
150 ///    A 16-bit mask value specifying what is chosen for each element.
151 ///    A 1 means conversion of __A. A 0 means element is zero.
152 /// \returns A 256-bit vector of [16 x bfloat] come from conversion of __A.
153 static __inline__ __m256bh __DEFAULT_FN_ATTRS512
_mm512_maskz_cvtneps_pbh(__mmask16 __U,__m512 __A)154 _mm512_maskz_cvtneps_pbh(__mmask16 __U, __m512 __A) {
155   return (__m256bh)__builtin_ia32_cvtneps2bf16_512_mask((__v16sf)__A,
156                                                 (__v16hi)_mm256_setzero_si256(),
157                                                 (__mmask16)__U);
158 }
159 
160 /// Dot Product of BF16 Pairs Accumulated into Packed Single Precision.
161 ///
162 /// \headerfile <x86intrin.h>
163 ///
164 /// This intrinsic corresponds to the <c> VDPBF16PS </c> instructions.
165 ///
166 /// \param __A
167 ///    A 512-bit vector of [32 x bfloat].
168 /// \param __B
169 ///    A 512-bit vector of [32 x bfloat].
170 /// \param __D
171 ///    A 512-bit vector of [16 x float].
172 /// \returns A 512-bit vector of [16 x float] comes from  Dot Product of
173 ///  __A, __B and __D
174 static __inline__ __m512 __DEFAULT_FN_ATTRS512
_mm512_dpbf16_ps(__m512 __D,__m512bh __A,__m512bh __B)175 _mm512_dpbf16_ps(__m512 __D, __m512bh __A, __m512bh __B) {
176   return (__m512)__builtin_ia32_dpbf16ps_512((__v16sf) __D,
177                                              (__v16si) __A,
178                                              (__v16si) __B);
179 }
180 
181 /// Dot Product of BF16 Pairs Accumulated into Packed Single Precision.
182 ///
183 /// \headerfile <x86intrin.h>
184 ///
185 /// This intrinsic corresponds to the <c> VDPBF16PS </c> instructions.
186 ///
187 /// \param __A
188 ///    A 512-bit vector of [32 x bfloat].
189 /// \param __B
190 ///    A 512-bit vector of [32 x bfloat].
191 /// \param __D
192 ///    A 512-bit vector of [16 x float].
193 /// \param __U
194 ///    A 16-bit mask value specifying what is chosen for each element.
195 ///    A 1 means __A and __B's dot product accumulated with __D. A 0 means __D.
196 /// \returns A 512-bit vector of [16 x float] comes from  Dot Product of
197 ///  __A, __B and __D
198 static __inline__ __m512 __DEFAULT_FN_ATTRS512
_mm512_mask_dpbf16_ps(__m512 __D,__mmask16 __U,__m512bh __A,__m512bh __B)199 _mm512_mask_dpbf16_ps(__m512 __D, __mmask16 __U, __m512bh __A, __m512bh __B) {
200   return (__m512)__builtin_ia32_selectps_512((__mmask16)__U,
201                                        (__v16sf)_mm512_dpbf16_ps(__D, __A, __B),
202                                        (__v16sf)__D);
203 }
204 
205 /// Dot Product of BF16 Pairs Accumulated into Packed Single Precision.
206 ///
207 /// \headerfile <x86intrin.h>
208 ///
209 /// This intrinsic corresponds to the <c> VDPBF16PS </c> instructions.
210 ///
211 /// \param __A
212 ///    A 512-bit vector of [32 x bfloat].
213 /// \param __B
214 ///    A 512-bit vector of [32 x bfloat].
215 /// \param __D
216 ///    A 512-bit vector of [16 x float].
217 /// \param __U
218 ///    A 16-bit mask value specifying what is chosen for each element.
219 ///    A 1 means __A and __B's dot product accumulated with __D. A 0 means 0.
220 /// \returns A 512-bit vector of [16 x float] comes from  Dot Product of
221 ///  __A, __B and __D
222 static __inline__ __m512 __DEFAULT_FN_ATTRS512
_mm512_maskz_dpbf16_ps(__mmask16 __U,__m512 __D,__m512bh __A,__m512bh __B)223 _mm512_maskz_dpbf16_ps(__mmask16 __U, __m512 __D, __m512bh __A, __m512bh __B) {
224   return (__m512)__builtin_ia32_selectps_512((__mmask16)__U,
225                                        (__v16sf)_mm512_dpbf16_ps(__D, __A, __B),
226                                        (__v16sf)_mm512_setzero_si512());
227 }
228 
229 /// Convert Packed BF16 Data to Packed float Data.
230 ///
231 /// \headerfile <x86intrin.h>
232 ///
233 /// \param __A
234 ///    A 256-bit vector of [16 x bfloat].
235 /// \returns A 512-bit vector of [16 x float] come from convertion of __A
_mm512_cvtpbh_ps(__m256bh __A)236 static __inline__ __m512 __DEFAULT_FN_ATTRS512 _mm512_cvtpbh_ps(__m256bh __A) {
237   return _mm512_castsi512_ps((__m512i)_mm512_slli_epi32(
238       (__m512i)_mm512_cvtepi16_epi32((__m256i)__A), 16));
239 }
240 
241 /// Convert Packed BF16 Data to Packed float Data using zeroing mask.
242 ///
243 /// \headerfile <x86intrin.h>
244 ///
245 /// \param __U
246 ///    A 16-bit mask. Elements are zeroed out when the corresponding mask
247 ///    bit is not set.
248 /// \param __A
249 ///    A 256-bit vector of [16 x bfloat].
250 /// \returns A 512-bit vector of [16 x float] come from convertion of __A
251 static __inline__ __m512 __DEFAULT_FN_ATTRS512
_mm512_maskz_cvtpbh_ps(__mmask16 __U,__m256bh __A)252 _mm512_maskz_cvtpbh_ps(__mmask16 __U, __m256bh __A) {
253   return _mm512_castsi512_ps((__m512i)_mm512_slli_epi32(
254       (__m512i)_mm512_maskz_cvtepi16_epi32((__mmask16)__U, (__m256i)__A), 16));
255 }
256 
257 /// Convert Packed BF16 Data to Packed float Data using merging mask.
258 ///
259 /// \headerfile <x86intrin.h>
260 ///
261 /// \param __S
262 ///    A 512-bit vector of [16 x float]. Elements are copied from __S when
263 ///     the corresponding mask bit is not set.
264 /// \param __U
265 ///    A 16-bit mask.
266 /// \param __A
267 ///    A 256-bit vector of [16 x bfloat].
268 /// \returns A 512-bit vector of [16 x float] come from convertion of __A
269 static __inline__ __m512 __DEFAULT_FN_ATTRS512
_mm512_mask_cvtpbh_ps(__m512 __S,__mmask16 __U,__m256bh __A)270 _mm512_mask_cvtpbh_ps(__m512 __S, __mmask16 __U, __m256bh __A) {
271   return _mm512_castsi512_ps((__m512i)_mm512_mask_slli_epi32(
272       (__m512i)__S, (__mmask16)__U,
273       (__m512i)_mm512_cvtepi16_epi32((__m256i)__A), 16));
274 }
275 
/* The attribute helper macros are private to this header; retire them. */
#undef __DEFAULT_FN_ATTRS512
#undef __DEFAULT_FN_ATTRS
278 
279 #endif
280