1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_KERNELS_SPARSE_MATMUL_OP_H_
17 #define TENSORFLOW_CORE_KERNELS_SPARSE_MATMUL_OP_H_
18 
#include <cstring>

#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/types.h"
22 
23 #if defined(PLATFORM_WINDOWS)
24 #include "tensorflow/core/platform/windows/cpu_info.h"
25 #include "tensorflow/core/platform/windows/intrinsics_port.h"
26 #endif
27 
28 namespace Eigen {
29 namespace internal {
30 
31 // Return the float representation of the bfloat16 value
32 // in the lower 16-bits of input
33 template <typename Packet>
pexpand_bf16_l(const Packet & from)34 EIGEN_DEVICE_FUNC inline Packet pexpand_bf16_l(const Packet& from) {
35   tensorflow::uint32 tmp;
36 #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
37   tmp = (reinterpret_cast<const tensorflow::uint32&>(from)) & 0xffff0000;
38 #else
39   tmp = (reinterpret_cast<const tensorflow::uint32&>(from) << 16) & 0xffff0000;
40 #endif
41   return reinterpret_cast<const float&>(tmp);
42 }
43 
44 // Return the float representation of the bfloat16 value
45 // in the upper 16-bits of input
46 template <typename Packet>
pexpand_bf16_u(const Packet & from)47 EIGEN_DEVICE_FUNC inline Packet pexpand_bf16_u(const Packet& from) {
48   tensorflow::uint32 tmp;
49 #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
50   tmp = (reinterpret_cast<const tensorflow::uint32&>(from) << 16) & 0xffff0000;
51 #else
52   tmp = (reinterpret_cast<const tensorflow::uint32&>(from)) & 0xffff0000;
53 #endif
54   return reinterpret_cast<const float&>(tmp);
55 }
56 
// Non-scalar (Packet4f) versions for non-SSE vector targets.
// Vectorization is also enabled on z13 and higher (ZVECTOR).
59 #if defined(EIGEN_VECTORIZE_ALTIVEC) || defined(EIGEN_VECTORIZE_VSX) || \
60     defined(EIGEN_VECTORIZE_NEON) || defined(EIGEN_VECTORIZE_ZVECTOR)
61 template <typename Packet>
pexpand_bf16_l(const Packet4f & from)62 EIGEN_DEVICE_FUNC inline Packet4f pexpand_bf16_l(const Packet4f& from) {
63   float r[4];
64   tensorflow::uint32 p[4];
65   pstoreu(r, from);
66   tensorflow::uint32* ir = reinterpret_cast<tensorflow::uint32*>(r);
67   p[0] = (ir[0] << 16) & 0xffff0000;
68   p[1] = ir[0] & 0xffff0000;
69   p[2] = (ir[1] << 16) & 0xffff0000;
70   p[3] = ir[1] & 0xffff0000;
71   return ploadu<Packet4f>(reinterpret_cast<float*>(p));
72 }
73 
74 template <typename Packet>
pexpand_bf16_u(const Packet4f & from)75 EIGEN_DEVICE_FUNC inline Packet4f pexpand_bf16_u(const Packet4f& from) {
76   float r[4];
77   tensorflow::uint32 p[4];
78   pstoreu(r, from);
79   tensorflow::uint32* ir = reinterpret_cast<tensorflow::uint32*>(r);
80   p[0] = (ir[2] << 16) & 0xffff0000;
81   p[1] = ir[2] & 0xffff0000;
82   p[2] = (ir[3] << 16) & 0xffff0000;
83   p[3] = ir[3] & 0xffff0000;
84   return ploadu<Packet4f>(reinterpret_cast<float*>(p));
85 }
86 #endif
87 
88 template <typename Packet>
pinterleave4x64(const Packet & from)89 EIGEN_DEVICE_FUNC inline Packet pinterleave4x64(const Packet& from) {
90   return from;
91 }
92 
93 template <typename Packet>
pbroadcast_first(const Packet & a)94 EIGEN_DEVICE_FUNC inline Packet pbroadcast_first(const Packet& a) {
95   return a;
96 }
97 
98 template <typename Packet>
pbroadcast_second(const Packet & a)99 EIGEN_DEVICE_FUNC inline Packet pbroadcast_second(const Packet& a) {
100   assert(false && "Not applicable to Scalar Values");
101   return a;
102 }
103 
104 template <typename Packet>
pbroadcast_third(const Packet & a)105 EIGEN_DEVICE_FUNC inline Packet pbroadcast_third(const Packet& a) {
106   assert(false && "Not applicable to Scalar Values");
107   return a;
108 }
109 
110 template <typename Packet>
pbroadcast_fourth(const Packet & a)111 EIGEN_DEVICE_FUNC inline Packet pbroadcast_fourth(const Packet& a) {
112   assert(false && "Not applicable to Scalar Values");
113   return a;
114 }
115 
116 template <typename Packet>
pload4bf16(const typename unpacket_traits<Packet>::type * from)117 EIGEN_DEVICE_FUNC inline Packet pload4bf16(
118     const typename unpacket_traits<Packet>::type* from) {
119   assert(false && "Not applicable to Scalar Values");
120   return Packet();
121 }
122 
123 template <typename Packet>
pload2bf16(const typename unpacket_traits<Packet>::type * from)124 EIGEN_DEVICE_FUNC inline Packet pload2bf16(
125     const typename unpacket_traits<Packet>::type* from) {
126   assert(false && "Not applicable to Scalar Values");
127   return Packet();
128 }
129 
// Specializations of pload4bf16 and pload2bf16 for non-SSE vector targets.
// Vectorization is also enabled on z13 and higher (ZVECTOR).
132 #if defined(EIGEN_VECTORIZE_ALTIVEC) || defined(EIGEN_VECTORIZE_VSX) || \
133     defined(EIGEN_VECTORIZE_NEON) || defined(EIGEN_VECTORIZE_ZVECTOR)
134 template <>
135 EIGEN_STRONG_INLINE Packet4f pload4bf16<Packet4f>(const float* from) {
136   tensorflow::uint32 p[4];
137   const tensorflow::uint32* ir =
138       reinterpret_cast<const tensorflow::uint32*>(from);
139   p[0] = (ir[0] << 16) & 0xffff0000;
140   p[1] = ir[0] & 0xffff0000;
141   p[2] = (ir[1] << 16) & 0xffff0000;
142   p[3] = ir[1] & 0xffff0000;
143   return ploadu<Packet4f>(reinterpret_cast<float*>(p));
144 }
145 
146 template <>
147 EIGEN_STRONG_INLINE Packet4f pload2bf16<Packet4f>(const float* from) {
148   tensorflow::uint32 p[4];
149   const tensorflow::uint32* ir =
150       reinterpret_cast<const tensorflow::uint32*>(from);
151   p[0] = (ir[0] << 16) & 0xffff0000;
152   p[1] = ir[0] & 0xffff0000;
153   p[2] = (ir[0] << 16) & 0xffff0000;
154   p[3] = ir[0] & 0xffff0000;
155   return ploadu<Packet4f>(reinterpret_cast<float*>(p));
156 }
157 #endif
158 
159 #if defined(EIGEN_VECTORIZE_ALTIVEC) || defined(EIGEN_VECTORIZE_VSX)
160 // Return a packet with the first value of the input Packet replicated
161 template <>
162 EIGEN_STRONG_INLINE Packet4f pbroadcast_first<Packet4f>(const Packet4f& a) {
163   return vec_splat(a, 0);
164 }
165 
166 // Return a packet with the second value of the input Packet replicated
167 template <>
168 EIGEN_STRONG_INLINE Packet4f pbroadcast_second<Packet4f>(const Packet4f& a) {
169   return vec_splat(a, 1);
170 }
171 
172 // Return a packet with the third value of the input Packet replicated
173 template <>
174 EIGEN_STRONG_INLINE Packet4f pbroadcast_third<Packet4f>(const Packet4f& a) {
175   return vec_splat(a, 2);
176 }
177 
178 // Return a packet with the fourth value of the input Packet replicated
179 template <>
180 EIGEN_STRONG_INLINE Packet4f pbroadcast_fourth<Packet4f>(const Packet4f& a) {
181   return vec_splat(a, 3);
182 }
183 #endif
184 
185 #ifdef EIGEN_VECTORIZE_SSE2
186 // For PacketSize of 4 floats the Packet is not modified
187 template <>
188 EIGEN_STRONG_INLINE Packet4f pinterleave4x64<Packet4f>(const Packet4f& from) {
189   return from;
190 }
191 
192 // Return a Packet with 4 floats loaded from 4 bfloat16 values
193 template <>
194 EIGEN_STRONG_INLINE Packet4f pload4bf16<Packet4f>(const float* from) {
195   __m128i zero = _mm_setzero_si128();
196   __m128i tmp = _mm_castpd_si128(_mm_load_pd1((const double*)from));
197   return _mm_castsi128_ps(_mm_unpacklo_epi16(zero, tmp));
198 }
199 
200 // Return a Packet with 2 floats loaded from 2 bfloat16 values
201 template <>
202 EIGEN_STRONG_INLINE Packet4f pload2bf16<Packet4f>(const float* from) {
203   __m128i zero = _mm_setzero_si128();
204   __m128i tmp = _mm_castps_si128(_mm_load_ps1(from));
205   return _mm_castsi128_ps(_mm_unpacklo_epi16(zero, tmp));
206 }
207 
208 // Return a Packet with 4 floats expanded from 4 bfloat16 values
209 // in the lower half of the 128-bit lane
210 template <typename Packet>
pexpand_bf16_l(const Packet4f & from)211 EIGEN_DEVICE_FUNC inline Packet4f pexpand_bf16_l(const Packet4f& from) {
212   __m128i zero = _mm_setzero_si128();
213   __m128i tmp = _mm_castps_si128(from);
214   return _mm_castsi128_ps(_mm_unpacklo_epi16(zero, tmp));
215 }
216 
217 // Return a Packet with 4 floats expanded from 4 bfloat16 values
218 // in the upper half of the 128-bit lane
219 template <typename Packet>
pexpand_bf16_u(const Packet4f & from)220 EIGEN_DEVICE_FUNC inline Packet4f pexpand_bf16_u(const Packet4f& from) {
221   __m128i zero = _mm_setzero_si128();
222   __m128i tmp = _mm_castps_si128(from);
223   return _mm_castsi128_ps(_mm_unpackhi_epi16(zero, tmp));
224 }
225 
226 // Return a packet with the first value of the input Packet replicated
227 template <>
228 EIGEN_STRONG_INLINE Packet4f pbroadcast_first<Packet4f>(const Packet4f& a) {
229   return _mm_set1_ps(pfirst<Packet4f>(a));
230 }
231 
232 // Return a packet with the second value of the input Packet replicated
233 template <>
234 EIGEN_STRONG_INLINE Packet4f pbroadcast_second<Packet4f>(const Packet4f& a) {
235   return _mm_set1_ps(_mm_cvtss_f32(_mm_shuffle_ps(a, a, 1)));
236 }
237 
238 // Return a packet with the third value of the input Packet replicated
239 template <>
240 EIGEN_STRONG_INLINE Packet4f pbroadcast_third<Packet4f>(const Packet4f& a) {
241   return _mm_set1_ps(_mm_cvtss_f32(_mm_shuffle_ps(a, a, 2)));
242 }
243 
244 // Return a packet with the fourth value of the input Packet replicated
245 template <>
246 EIGEN_STRONG_INLINE Packet4f pbroadcast_fourth<Packet4f>(const Packet4f& a) {
247   return _mm_set1_ps(_mm_cvtss_f32(_mm_shuffle_ps(a, a, 3)));
248 }
249 
250 #endif
251 
252 #ifdef EIGEN_VECTORIZE_AVX512
253 template <>
254 EIGEN_STRONG_INLINE Packet16f
255 pbroadcast_first<Packet16f>(const Packet16f& a_in) {
256   Packet4f a = _mm512_castps512_ps128(a_in);
257   return _mm512_broadcastss_ps(a);
258 }
259 template <>
260 EIGEN_STRONG_INLINE Packet16f
261 pbroadcast_second<Packet16f>(const Packet16f& a_in) {
262   Packet4f a = _mm512_castps512_ps128(a_in);
263   return _mm512_broadcastss_ps(_mm_shuffle_ps(a, a, _MM_SHUFFLE(1, 1, 1, 1)));
264 }
265 template <>
266 EIGEN_STRONG_INLINE Packet16f
267 pbroadcast_third<Packet16f>(const Packet16f& a_in) {
268   Packet4f a = _mm512_castps512_ps128(a_in);
269   return _mm512_broadcastss_ps(_mm_shuffle_ps(a, a, _MM_SHUFFLE(2, 2, 2, 2)));
270 }
271 template <>
272 EIGEN_STRONG_INLINE Packet16f
273 pbroadcast_fourth<Packet16f>(const Packet16f& a_in) {
274   Packet4f a = _mm512_castps512_ps128(a_in);
275   return _mm512_broadcastss_ps(_mm_shuffle_ps(a, a, _MM_SHUFFLE(3, 3, 3, 3)));
276 }
277 template <>
278 EIGEN_STRONG_INLINE Packet8d pbroadcast_first<Packet8d>(const Packet8d& a_in) {
279   Packet2d a = _mm512_castpd512_pd128(a_in);
280   return _mm512_broadcastsd_pd(a);
281 }
282 template <>
283 EIGEN_STRONG_INLINE Packet8d pbroadcast_second<Packet8d>(const Packet8d& a_in) {
284   Packet2d a = _mm_permute_pd(_mm512_castpd512_pd128(a_in), 3);
285   return _mm512_broadcastsd_pd(a);
286 }
287 template <>
288 EIGEN_STRONG_INLINE Packet8d pbroadcast_third<Packet8d>(const Packet8d& a_in) {
289   Packet2d a = _mm256_extractf128_pd(_mm512_castpd512_pd256(a_in), 1);
290   return _mm512_broadcastsd_pd(a);
291 }
292 template <>
293 EIGEN_STRONG_INLINE Packet8d pbroadcast_fourth<Packet8d>(const Packet8d& a_in) {
294   Packet2d a =
295       _mm_permute_pd(_mm256_extractf128_pd(_mm512_castpd512_pd256(a_in), 1), 3);
296   return _mm512_broadcastsd_pd(a);
297 }
298 template <>
299 EIGEN_STRONG_INLINE Packet16i
300 pbroadcast_first<Packet16i>(const Packet16i& a_in) {
301   Packet4i a = _mm512_castsi512_si128(a_in);
302   return _mm512_broadcastd_epi32(a);
303 }
304 template <>
305 EIGEN_STRONG_INLINE Packet16i
306 pbroadcast_second<Packet16i>(const Packet16i& a_in) {
307   Packet4i a = _mm512_castsi512_si128(a_in);
308   return _mm512_broadcastd_epi32(_mm_shuffle_epi32(a, _MM_SHUFFLE(1, 1, 1, 1)));
309 }
310 template <>
311 EIGEN_STRONG_INLINE Packet16i
312 pbroadcast_third<Packet16i>(const Packet16i& a_in) {
313   Packet4i a = _mm512_castsi512_si128(a_in);
314   return _mm512_broadcastd_epi32(_mm_shuffle_epi32(a, _MM_SHUFFLE(2, 2, 2, 2)));
315 }
316 template <>
317 EIGEN_STRONG_INLINE Packet16i
318 pbroadcast_fourth<Packet16i>(const Packet16i& a_in) {
319   Packet4i a = _mm512_castsi512_si128(a_in);
320   return _mm512_broadcastd_epi32(_mm_shuffle_epi32(a, _MM_SHUFFLE(3, 3, 3, 3)));
321 }
322 #endif
323 
324 #ifdef EIGEN_VECTORIZE_AVX
325 // For a Packet of Size 8 floats(256-bits), swap the 2nd and 3rd quadwords
326 template <>
327 EIGEN_STRONG_INLINE Packet8f pinterleave4x64<Packet8f>(const Packet8f& from) {
328 #ifdef EIGEN_VECTORIZE_AVX2
329   return _mm256_castsi256_ps(_mm256_permute4x64_epi64(_mm256_castps_si256(from),
330                                                       _MM_SHUFFLE(3, 1, 2, 0)));
331 #else
332   auto tmp1 = _mm256_extract_epi32(_mm256_castps_si256(from), 2);
333   auto tmp2 = _mm256_extract_epi32(_mm256_castps_si256(from), 3);
334   auto tmp3 = _mm256_extract_epi32(_mm256_castps_si256(from), 4);
335   auto tmp4 = _mm256_extract_epi32(_mm256_castps_si256(from), 5);
336   auto tmp5 = _mm256_insert_epi32(_mm256_castps_si256(from), tmp1, 4);
337   tmp5 = _mm256_insert_epi32(tmp5, tmp2, 5);
338   tmp5 = _mm256_insert_epi32(tmp5, tmp3, 2);
339   tmp5 = _mm256_insert_epi32(tmp5, tmp4, 3);
340   return _mm256_castsi256_ps(tmp5);
341 #endif
342 }
343 // Return a Packet with 4 floats loaded from 4 bfloat16 values
344 template <>
345 EIGEN_STRONG_INLINE Packet8f pload4bf16<Packet8f>(const float* from) {
346   __m128i zero = _mm_setzero_si128();
347   __m128i tmp = _mm_castpd_si128(_mm_load_pd1((const double*)from));
348   return _mm256_castps128_ps256(
349       _mm_castsi128_ps(_mm_unpacklo_epi16(zero, tmp)));
350 }
351 // Return a Packet with 2 floats loaded from 2 bfloat16 values
352 template <>
353 EIGEN_STRONG_INLINE Packet8f pload2bf16<Packet8f>(const float* from) {
354   __m128i zero = _mm_setzero_si128();
355   __m128i tmp = _mm_castps_si128(_mm_load_ps1(from));
356   return _mm256_castps128_ps256(
357       _mm_castsi128_ps(_mm_unpacklo_epi16(zero, tmp)));
358 }
359 
360 #ifdef EIGEN_VECTORIZE_AVX512
361 // Return a Packet with 4 floats loaded from 4 bfloat16 values
362 template <>
363 EIGEN_STRONG_INLINE Packet16f pload4bf16<Packet16f>(const float* from) {
364   __m128i zero = _mm_setzero_si128();
365   __m128i tmp = _mm_castpd_si128(_mm_load_pd1((const double*)from));
366   return _mm512_castps128_ps512(
367       _mm_castsi128_ps(_mm_unpacklo_epi16(zero, tmp)));
368 }
369 // Return a Packet with 2 floats loaded from 2 bfloat16 values
370 template <>
371 EIGEN_STRONG_INLINE Packet16f pload2bf16<Packet16f>(const float* from) {
372   __m128i zero = _mm_setzero_si128();
373   __m128i tmp = _mm_castps_si128(_mm_load_ps1(from));
374   return _mm512_castps128_ps512(
375       _mm_castsi128_ps(_mm_unpacklo_epi16(zero, tmp)));
376 }
377 #endif
378 
379 // For each 128-bit lane convert 4 bfloat to 4 float values from the lower half
380 // of the 128-bit lane
381 template <typename Packet>
pexpand_bf16_l(const Packet8f & from)382 EIGEN_DEVICE_FUNC inline Packet8f pexpand_bf16_l(const Packet8f& from) {
383 #ifdef EIGEN_VECTORIZE_AVX2
384   __m256i zero = _mm256_setzero_si256();
385   __m256i tmp = _mm256_castps_si256(from);
386   return _mm256_castsi256_ps(_mm256_unpacklo_epi16(zero, tmp));
387 #else
388   __m128i zero = _mm_setzero_si128();
389   __m128i low = _mm_castps_si128(_mm256_extractf128_ps(from, 0));
390   __m128i res_l = _mm_unpacklo_epi16(zero, low);
391   __m128i high = _mm_castps_si128(_mm256_extractf128_ps(from, 1));
392   __m128i res_h = _mm_unpacklo_epi16(zero, high);
393   __m256 res = _mm256_castps128_ps256(_mm_castsi128_ps(res_l));
394   res = _mm256_insertf128_ps(res, _mm_castsi128_ps(res_h), 1);
395   return res;
396 #endif
397 }
398 
399 // For each 128-bit lane convert 4 bfloat to 4 float values from the upper half
400 // of the 128-bit lane
401 template <typename Packet>
pexpand_bf16_u(const Packet8f & from)402 EIGEN_DEVICE_FUNC inline Packet8f pexpand_bf16_u(const Packet8f& from) {
403 #ifdef EIGEN_VECTORIZE_AVX2
404   __m256i zero = _mm256_setzero_si256();
405   __m256i tmp = _mm256_castps_si256(from);
406   return _mm256_castsi256_ps(_mm256_unpackhi_epi16(zero, tmp));
407 #else
408   __m128i zero = _mm_setzero_si128();
409   __m128i low = _mm_castps_si128(_mm256_extractf128_ps(from, 0));
410   __m128i res_l = _mm_unpackhi_epi16(zero, low);
411   __m128i high = _mm_castps_si128(_mm256_extractf128_ps(from, 1));
412   __m128i res_h = _mm_unpackhi_epi16(zero, high);
413   __m256 res = _mm256_castps128_ps256(_mm_castsi128_ps(res_l));
414   res = _mm256_insertf128_ps(res, _mm_castsi128_ps(res_h), 1);
415   return res;
416 #endif
417 }
418 
419 // Return a packet with the first value of the input Packet replicated
420 template <>
421 EIGEN_STRONG_INLINE Packet8f pbroadcast_first<Packet8f>(const Packet8f& a) {
422   return _mm256_set1_ps(pfirst<Packet8f>(a));
423 }
424 
425 // Return a packet with the second value of the input Packet replicated
426 template <>
427 EIGEN_STRONG_INLINE Packet8f pbroadcast_second<Packet8f>(const Packet8f& a) {
428   return _mm256_set1_ps(
429       _mm_cvtss_f32(_mm256_castps256_ps128(_mm256_permute_ps(a, 1))));
430 }
431 
432 // Return a packet with the third value of the input Packet replicated
433 template <>
434 EIGEN_STRONG_INLINE Packet8f pbroadcast_third<Packet8f>(const Packet8f& a) {
435   return _mm256_set1_ps(
436       _mm_cvtss_f32(_mm256_castps256_ps128(_mm256_permute_ps(a, 2))));
437 }
438 
439 // Return a packet with the fourth value of the input Packet replicated
440 template <>
441 EIGEN_STRONG_INLINE Packet8f pbroadcast_fourth<Packet8f>(const Packet8f& a) {
442   return _mm256_set1_ps(
443       _mm_cvtss_f32(_mm256_castps256_ps128(_mm256_permute_ps(a, 3))));
444 }
445 
446 #endif
447 
448 #ifdef EIGEN_VECTORIZE_AVX512
449 
450 template <typename Packet>
pexpand_bf16_l(const Packet16f & from)451 EIGEN_DEVICE_FUNC inline Packet16f pexpand_bf16_l(const Packet16f& from) {
452   return _mm512_castsi512_ps(_mm512_slli_epi32(
453       _mm512_cvtepu16_epi32(_mm512_castsi512_si256(_mm512_castps_si512(from))),
454       16));
455 }
456 
457 template <typename Packet>
pexpand_bf16_u(const Packet16f & from)458 EIGEN_DEVICE_FUNC inline Packet16f pexpand_bf16_u(const Packet16f& from) {
459   Packet16i tmp = _mm512_castps_si512(from);
460   Packet16i tmp2 = _mm512_alignr_epi32(tmp, tmp, 8);
461   return _mm512_castsi512_ps(_mm512_slli_epi32(
462       _mm512_cvtepu16_epi32(_mm512_castsi512_si256(tmp2)), 16));
463 }
464 
465 #endif
466 }  // namespace internal
467 }  // namespace Eigen
468 #endif  // TENSORFLOW_CORE_KERNELS_SPARSE_MATMUL_OP_H_
469