1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_KERNELS_SPARSE_MATMUL_OP_H_
17 #define TENSORFLOW_CORE_KERNELS_SPARSE_MATMUL_OP_H_
18 
#include <cstring>

#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/types.h"
22 
23 #if defined(PLATFORM_WINDOWS)
24 #include "tensorflow/core/platform/windows/intrinsics_port.h"
25 #endif
26 
27 namespace Eigen {
28 namespace internal {
29 
30 // Return the float representation of the bfloat16 value
31 // in the lower 16-bits of input
32 template <typename Packet>
pexpand_bf16_l(const Packet & from)33 EIGEN_DEVICE_FUNC inline Packet pexpand_bf16_l(const Packet& from) {
34   tensorflow::uint32 tmp;
35 #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
36   tmp = (reinterpret_cast<const tensorflow::uint32&>(from)) & 0xffff0000;
37 #else
38   tmp = (reinterpret_cast<const tensorflow::uint32&>(from) << 16) & 0xffff0000;
39 #endif
40   return reinterpret_cast<const float&>(tmp);
41 }
42 
43 // Return the float representation of the bfloat16 value
44 // in the upper 16-bits of input
45 template <typename Packet>
pexpand_bf16_u(const Packet & from)46 EIGEN_DEVICE_FUNC inline Packet pexpand_bf16_u(const Packet& from) {
47   tensorflow::uint32 tmp;
48 #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
49   tmp = (reinterpret_cast<const tensorflow::uint32&>(from) << 16) & 0xffff0000;
50 #else
51   tmp = (reinterpret_cast<const tensorflow::uint32&>(from)) & 0xffff0000;
52 #endif
53   return reinterpret_cast<const float&>(tmp);
54 }
55 
// Non-scalar (Packet4f) versions for targets without SSE.
// Enables vectorization on z13 and higher.
58 #if defined(EIGEN_VECTORIZE_ALTIVEC) || defined(EIGEN_VECTORIZE_VSX) || \
59     defined(EIGEN_VECTORIZE_NEON) || defined(EIGEN_VECTORIZE_ZVECTOR)
60 template <typename Packet>
pexpand_bf16_l(const Packet4f & from)61 EIGEN_DEVICE_FUNC inline Packet4f pexpand_bf16_l(const Packet4f& from) {
62   float r[4];
63   tensorflow::uint32 p[4];
64   pstoreu(r, from);
65   tensorflow::uint32* ir = reinterpret_cast<tensorflow::uint32*>(r);
66   p[0] = (ir[0] << 16) & 0xffff0000;
67   p[1] = ir[0] & 0xffff0000;
68   p[2] = (ir[1] << 16) & 0xffff0000;
69   p[3] = ir[1] & 0xffff0000;
70   return ploadu<Packet4f>(reinterpret_cast<float*>(p));
71 }
72 
73 template <typename Packet>
pexpand_bf16_u(const Packet4f & from)74 EIGEN_DEVICE_FUNC inline Packet4f pexpand_bf16_u(const Packet4f& from) {
75   float r[4];
76   tensorflow::uint32 p[4];
77   pstoreu(r, from);
78   tensorflow::uint32* ir = reinterpret_cast<tensorflow::uint32*>(r);
79   p[0] = (ir[2] << 16) & 0xffff0000;
80   p[1] = ir[2] & 0xffff0000;
81   p[2] = (ir[3] << 16) & 0xffff0000;
82   p[3] = ir[3] & 0xffff0000;
83   return ploadu<Packet4f>(reinterpret_cast<float*>(p));
84 }
85 #endif
86 
87 template <typename Packet>
pinterleave4x64(const Packet & from)88 EIGEN_DEVICE_FUNC inline Packet pinterleave4x64(const Packet& from) {
89   return from;
90 }
91 
92 template <typename Packet>
pbroadcast_first(const Packet & a)93 EIGEN_DEVICE_FUNC inline Packet pbroadcast_first(const Packet& a) {
94   return a;
95 }
96 
97 template <typename Packet>
pbroadcast_second(const Packet & a)98 EIGEN_DEVICE_FUNC inline Packet pbroadcast_second(const Packet& a) {
99   assert(false && "Not applicable to Scalar Values");
100   return a;
101 }
102 
103 template <typename Packet>
pbroadcast_third(const Packet & a)104 EIGEN_DEVICE_FUNC inline Packet pbroadcast_third(const Packet& a) {
105   assert(false && "Not applicable to Scalar Values");
106   return a;
107 }
108 
109 template <typename Packet>
pbroadcast_fourth(const Packet & a)110 EIGEN_DEVICE_FUNC inline Packet pbroadcast_fourth(const Packet& a) {
111   assert(false && "Not applicable to Scalar Values");
112   return a;
113 }
114 
115 template <typename Packet>
pload4bf16(const typename unpacket_traits<Packet>::type * from)116 EIGEN_DEVICE_FUNC inline Packet pload4bf16(
117     const typename unpacket_traits<Packet>::type* from) {
118   assert(false && "Not applicable to Scalar Values");
119   return Packet();
120 }
121 
122 template <typename Packet>
pload2bf16(const typename unpacket_traits<Packet>::type * from)123 EIGEN_DEVICE_FUNC inline Packet pload2bf16(
124     const typename unpacket_traits<Packet>::type* from) {
125   assert(false && "Not applicable to Scalar Values");
126   return Packet();
127 }
128 
// Specializations of pload4bf16 and pload2bf16 for targets without SSE.
// Enables vectorization on z13 and higher.
131 #if defined(EIGEN_VECTORIZE_ALTIVEC) || defined(EIGEN_VECTORIZE_VSX) || \
132     defined(EIGEN_VECTORIZE_NEON) || defined(EIGEN_VECTORIZE_ZVECTOR)
133 template <>
134 EIGEN_STRONG_INLINE Packet4f pload4bf16<Packet4f>(const float* from) {
135   tensorflow::uint32 p[4];
136   const tensorflow::uint32* ir =
137       reinterpret_cast<const tensorflow::uint32*>(from);
138   p[0] = (ir[0] << 16) & 0xffff0000;
139   p[1] = ir[0] & 0xffff0000;
140   p[2] = (ir[1] << 16) & 0xffff0000;
141   p[3] = ir[1] & 0xffff0000;
142   return ploadu<Packet4f>(reinterpret_cast<float*>(p));
143 }
144 
145 template <>
146 EIGEN_STRONG_INLINE Packet4f pload2bf16<Packet4f>(const float* from) {
147   tensorflow::uint32 p[4];
148   const tensorflow::uint32* ir =
149       reinterpret_cast<const tensorflow::uint32*>(from);
150   p[0] = (ir[0] << 16) & 0xffff0000;
151   p[1] = ir[0] & 0xffff0000;
152   p[2] = (ir[0] << 16) & 0xffff0000;
153   p[3] = ir[0] & 0xffff0000;
154   return ploadu<Packet4f>(reinterpret_cast<float*>(p));
155 }
156 #endif
157 
158 #if defined(EIGEN_VECTORIZE_ALTIVEC) || defined(EIGEN_VECTORIZE_VSX)
159 // Return a packet with the first value of the input Packet replicated
160 template <>
161 EIGEN_STRONG_INLINE Packet4f pbroadcast_first<Packet4f>(const Packet4f& a) {
162   return vec_splat(a, 0);
163 }
164 
165 // Return a packet with the second value of the input Packet replicated
166 template <>
167 EIGEN_STRONG_INLINE Packet4f pbroadcast_second<Packet4f>(const Packet4f& a) {
168   return vec_splat(a, 1);
169 }
170 
171 // Return a packet with the third value of the input Packet replicated
172 template <>
173 EIGEN_STRONG_INLINE Packet4f pbroadcast_third<Packet4f>(const Packet4f& a) {
174   return vec_splat(a, 2);
175 }
176 
177 // Return a packet with the fourth value of the input Packet replicated
178 template <>
179 EIGEN_STRONG_INLINE Packet4f pbroadcast_fourth<Packet4f>(const Packet4f& a) {
180   return vec_splat(a, 3);
181 }
182 #endif
183 
184 #ifdef EIGEN_VECTORIZE_SSE2
185 // For PacketSize of 4 floats the Packet is not modified
186 template <>
187 EIGEN_STRONG_INLINE Packet4f pinterleave4x64<Packet4f>(const Packet4f& from) {
188   return from;
189 }
190 
191 // Return a Packet with 4 floats loaded from 4 bfloat16 values
192 template <>
193 EIGEN_STRONG_INLINE Packet4f pload4bf16<Packet4f>(const float* from) {
194   __m128i zero = _mm_setzero_si128();
195   __m128i tmp = _mm_castpd_si128(_mm_load_pd1((const double*)from));
196   return _mm_castsi128_ps(_mm_unpacklo_epi16(zero, tmp));
197 }
198 
199 // Return a Packet with 2 floats loaded from 2 bfloat16 values
200 template <>
201 EIGEN_STRONG_INLINE Packet4f pload2bf16<Packet4f>(const float* from) {
202   __m128i zero = _mm_setzero_si128();
203   __m128i tmp = _mm_castps_si128(_mm_load_ps1(from));
204   return _mm_castsi128_ps(_mm_unpacklo_epi16(zero, tmp));
205 }
206 
207 // Return a Packet with 4 floats expanded from 4 bfloat16 values
208 // in the lower half of the 128-bit lane
209 template <typename Packet>
pexpand_bf16_l(const Packet4f & from)210 EIGEN_DEVICE_FUNC inline Packet4f pexpand_bf16_l(const Packet4f& from) {
211   __m128i zero = _mm_setzero_si128();
212   __m128i tmp = _mm_castps_si128(from);
213   return _mm_castsi128_ps(_mm_unpacklo_epi16(zero, tmp));
214 }
215 
216 // Return a Packet with 4 floats expanded from 4 bfloat16 values
217 // in the upper half of the 128-bit lane
218 template <typename Packet>
pexpand_bf16_u(const Packet4f & from)219 EIGEN_DEVICE_FUNC inline Packet4f pexpand_bf16_u(const Packet4f& from) {
220   __m128i zero = _mm_setzero_si128();
221   __m128i tmp = _mm_castps_si128(from);
222   return _mm_castsi128_ps(_mm_unpackhi_epi16(zero, tmp));
223 }
224 
225 // Return a packet with the first value of the input Packet replicated
226 template <>
227 EIGEN_STRONG_INLINE Packet4f pbroadcast_first<Packet4f>(const Packet4f& a) {
228   return _mm_set1_ps(pfirst<Packet4f>(a));
229 }
230 
231 // Return a packet with the second value of the input Packet replicated
232 template <>
233 EIGEN_STRONG_INLINE Packet4f pbroadcast_second<Packet4f>(const Packet4f& a) {
234   return _mm_set1_ps(_mm_cvtss_f32(_mm_shuffle_ps(a, a, 1)));
235 }
236 
237 // Return a packet with the third value of the input Packet replicated
238 template <>
239 EIGEN_STRONG_INLINE Packet4f pbroadcast_third<Packet4f>(const Packet4f& a) {
240   return _mm_set1_ps(_mm_cvtss_f32(_mm_shuffle_ps(a, a, 2)));
241 }
242 
243 // Return a packet with the fourth value of the input Packet replicated
244 template <>
245 EIGEN_STRONG_INLINE Packet4f pbroadcast_fourth<Packet4f>(const Packet4f& a) {
246   return _mm_set1_ps(_mm_cvtss_f32(_mm_shuffle_ps(a, a, 3)));
247 }
248 
249 #endif
250 
251 #ifdef EIGEN_VECTORIZE_AVX512
252 template <>
253 EIGEN_STRONG_INLINE Packet16f
254 pbroadcast_first<Packet16f>(const Packet16f& a_in) {
255   Packet4f a = _mm512_castps512_ps128(a_in);
256   return _mm512_broadcastss_ps(a);
257 }
258 template <>
259 EIGEN_STRONG_INLINE Packet16f
260 pbroadcast_second<Packet16f>(const Packet16f& a_in) {
261   Packet4f a = _mm512_castps512_ps128(a_in);
262   return _mm512_broadcastss_ps(_mm_shuffle_ps(a, a, _MM_SHUFFLE(1, 1, 1, 1)));
263 }
264 template <>
265 EIGEN_STRONG_INLINE Packet16f
266 pbroadcast_third<Packet16f>(const Packet16f& a_in) {
267   Packet4f a = _mm512_castps512_ps128(a_in);
268   return _mm512_broadcastss_ps(_mm_shuffle_ps(a, a, _MM_SHUFFLE(2, 2, 2, 2)));
269 }
270 template <>
271 EIGEN_STRONG_INLINE Packet16f
272 pbroadcast_fourth<Packet16f>(const Packet16f& a_in) {
273   Packet4f a = _mm512_castps512_ps128(a_in);
274   return _mm512_broadcastss_ps(_mm_shuffle_ps(a, a, _MM_SHUFFLE(3, 3, 3, 3)));
275 }
276 template <>
277 EIGEN_STRONG_INLINE Packet8d pbroadcast_first<Packet8d>(const Packet8d& a_in) {
278   Packet2d a = _mm512_castpd512_pd128(a_in);
279   return _mm512_broadcastsd_pd(a);
280 }
281 template <>
282 EIGEN_STRONG_INLINE Packet8d pbroadcast_second<Packet8d>(const Packet8d& a_in) {
283   Packet2d a = _mm_permute_pd(_mm512_castpd512_pd128(a_in), 3);
284   return _mm512_broadcastsd_pd(a);
285 }
286 template <>
287 EIGEN_STRONG_INLINE Packet8d pbroadcast_third<Packet8d>(const Packet8d& a_in) {
288   Packet2d a = _mm256_extractf128_pd(_mm512_castpd512_pd256(a_in), 1);
289   return _mm512_broadcastsd_pd(a);
290 }
291 template <>
292 EIGEN_STRONG_INLINE Packet8d pbroadcast_fourth<Packet8d>(const Packet8d& a_in) {
293   Packet2d a =
294       _mm_permute_pd(_mm256_extractf128_pd(_mm512_castpd512_pd256(a_in), 1), 3);
295   return _mm512_broadcastsd_pd(a);
296 }
297 template <>
298 EIGEN_STRONG_INLINE Packet16i
299 pbroadcast_first<Packet16i>(const Packet16i& a_in) {
300   Packet4i a = _mm512_castsi512_si128(a_in);
301   return _mm512_broadcastd_epi32(a);
302 }
303 template <>
304 EIGEN_STRONG_INLINE Packet16i
305 pbroadcast_second<Packet16i>(const Packet16i& a_in) {
306   Packet4i a = _mm512_castsi512_si128(a_in);
307   return _mm512_broadcastd_epi32(_mm_shuffle_epi32(a, _MM_SHUFFLE(1, 1, 1, 1)));
308 }
309 template <>
310 EIGEN_STRONG_INLINE Packet16i
311 pbroadcast_third<Packet16i>(const Packet16i& a_in) {
312   Packet4i a = _mm512_castsi512_si128(a_in);
313   return _mm512_broadcastd_epi32(_mm_shuffle_epi32(a, _MM_SHUFFLE(2, 2, 2, 2)));
314 }
315 template <>
316 EIGEN_STRONG_INLINE Packet16i
317 pbroadcast_fourth<Packet16i>(const Packet16i& a_in) {
318   Packet4i a = _mm512_castsi512_si128(a_in);
319   return _mm512_broadcastd_epi32(_mm_shuffle_epi32(a, _MM_SHUFFLE(3, 3, 3, 3)));
320 }
321 #endif
322 
323 #ifdef EIGEN_VECTORIZE_AVX
324 // For a Packet of Size 8 floats(256-bits), swap the 2nd and 3rd quadwords
325 template <>
326 EIGEN_STRONG_INLINE Packet8f pinterleave4x64<Packet8f>(const Packet8f& from) {
327 #ifdef EIGEN_VECTORIZE_AVX2
328   return _mm256_castsi256_ps(_mm256_permute4x64_epi64(_mm256_castps_si256(from),
329                                                       _MM_SHUFFLE(3, 1, 2, 0)));
330 #else
331   auto tmp1 = _mm256_extract_epi32(_mm256_castps_si256(from), 2);
332   auto tmp2 = _mm256_extract_epi32(_mm256_castps_si256(from), 3);
333   auto tmp3 = _mm256_extract_epi32(_mm256_castps_si256(from), 4);
334   auto tmp4 = _mm256_extract_epi32(_mm256_castps_si256(from), 5);
335   auto tmp5 = _mm256_insert_epi32(_mm256_castps_si256(from), tmp1, 4);
336   tmp5 = _mm256_insert_epi32(tmp5, tmp2, 5);
337   tmp5 = _mm256_insert_epi32(tmp5, tmp3, 2);
338   tmp5 = _mm256_insert_epi32(tmp5, tmp4, 3);
339   return _mm256_castsi256_ps(tmp5);
340 #endif
341 }
342 // Return a Packet with 4 floats loaded from 4 bfloat16 values
343 template <>
344 EIGEN_STRONG_INLINE Packet8f pload4bf16<Packet8f>(const float* from) {
345   __m128i zero = _mm_setzero_si128();
346   __m128i tmp = _mm_castpd_si128(_mm_load_pd1((const double*)from));
347   return _mm256_castps128_ps256(
348       _mm_castsi128_ps(_mm_unpacklo_epi16(zero, tmp)));
349 }
350 // Return a Packet with 2 floats loaded from 2 bfloat16 values
351 template <>
352 EIGEN_STRONG_INLINE Packet8f pload2bf16<Packet8f>(const float* from) {
353   __m128i zero = _mm_setzero_si128();
354   __m128i tmp = _mm_castps_si128(_mm_load_ps1(from));
355   return _mm256_castps128_ps256(
356       _mm_castsi128_ps(_mm_unpacklo_epi16(zero, tmp)));
357 }
358 
359 #ifdef EIGEN_VECTORIZE_AVX512
360 // Return a Packet with 4 floats loaded from 4 bfloat16 values
361 template <>
362 EIGEN_STRONG_INLINE Packet16f pload4bf16<Packet16f>(const float* from) {
363   __m128i zero = _mm_setzero_si128();
364   __m128i tmp = _mm_castpd_si128(_mm_load_pd1((const double*)from));
365   return _mm512_castps128_ps512(
366       _mm_castsi128_ps(_mm_unpacklo_epi16(zero, tmp)));
367 }
368 // Return a Packet with 2 floats loaded from 2 bfloat16 values
369 template <>
370 EIGEN_STRONG_INLINE Packet16f pload2bf16<Packet16f>(const float* from) {
371   __m128i zero = _mm_setzero_si128();
372   __m128i tmp = _mm_castps_si128(_mm_load_ps1(from));
373   return _mm512_castps128_ps512(
374       _mm_castsi128_ps(_mm_unpacklo_epi16(zero, tmp)));
375 }
376 #endif
377 
378 // For each 128-bit lane convert 4 bfloat to 4 float values from the lower half
379 // of the 128-bit lane
380 template <typename Packet>
pexpand_bf16_l(const Packet8f & from)381 EIGEN_DEVICE_FUNC inline Packet8f pexpand_bf16_l(const Packet8f& from) {
382 #ifdef EIGEN_VECTORIZE_AVX2
383   __m256i zero = _mm256_setzero_si256();
384   __m256i tmp = _mm256_castps_si256(from);
385   return _mm256_castsi256_ps(_mm256_unpacklo_epi16(zero, tmp));
386 #else
387   __m128i zero = _mm_setzero_si128();
388   __m128i low = _mm_castps_si128(_mm256_extractf128_ps(from, 0));
389   __m128i res_l = _mm_unpacklo_epi16(zero, low);
390   __m128i high = _mm_castps_si128(_mm256_extractf128_ps(from, 1));
391   __m128i res_h = _mm_unpacklo_epi16(zero, high);
392   __m256 res = _mm256_castps128_ps256(_mm_castsi128_ps(res_l));
393   res = _mm256_insertf128_ps(res, _mm_castsi128_ps(res_h), 1);
394   return res;
395 #endif
396 }
397 
398 // For each 128-bit lane convert 4 bfloat to 4 float values from the upper half
399 // of the 128-bit lane
400 template <typename Packet>
pexpand_bf16_u(const Packet8f & from)401 EIGEN_DEVICE_FUNC inline Packet8f pexpand_bf16_u(const Packet8f& from) {
402 #ifdef EIGEN_VECTORIZE_AVX2
403   __m256i zero = _mm256_setzero_si256();
404   __m256i tmp = _mm256_castps_si256(from);
405   return _mm256_castsi256_ps(_mm256_unpackhi_epi16(zero, tmp));
406 #else
407   __m128i zero = _mm_setzero_si128();
408   __m128i low = _mm_castps_si128(_mm256_extractf128_ps(from, 0));
409   __m128i res_l = _mm_unpackhi_epi16(zero, low);
410   __m128i high = _mm_castps_si128(_mm256_extractf128_ps(from, 1));
411   __m128i res_h = _mm_unpackhi_epi16(zero, high);
412   __m256 res = _mm256_castps128_ps256(_mm_castsi128_ps(res_l));
413   res = _mm256_insertf128_ps(res, _mm_castsi128_ps(res_h), 1);
414   return res;
415 #endif
416 }
417 
418 // Return a packet with the first value of the input Packet replicated
419 template <>
420 EIGEN_STRONG_INLINE Packet8f pbroadcast_first<Packet8f>(const Packet8f& a) {
421   return _mm256_set1_ps(pfirst<Packet8f>(a));
422 }
423 
424 // Return a packet with the second value of the input Packet replicated
425 template <>
426 EIGEN_STRONG_INLINE Packet8f pbroadcast_second<Packet8f>(const Packet8f& a) {
427   return _mm256_set1_ps(
428       _mm_cvtss_f32(_mm256_castps256_ps128(_mm256_permute_ps(a, 1))));
429 }
430 
431 // Return a packet with the third value of the input Packet replicated
432 template <>
433 EIGEN_STRONG_INLINE Packet8f pbroadcast_third<Packet8f>(const Packet8f& a) {
434   return _mm256_set1_ps(
435       _mm_cvtss_f32(_mm256_castps256_ps128(_mm256_permute_ps(a, 2))));
436 }
437 
438 // Return a packet with the fourth value of the input Packet replicated
439 template <>
440 EIGEN_STRONG_INLINE Packet8f pbroadcast_fourth<Packet8f>(const Packet8f& a) {
441   return _mm256_set1_ps(
442       _mm_cvtss_f32(_mm256_castps256_ps128(_mm256_permute_ps(a, 3))));
443 }
444 
445 #endif
446 
447 #ifdef EIGEN_VECTORIZE_AVX512
448 
449 template <typename Packet>
pexpand_bf16_l(const Packet16f & from)450 EIGEN_DEVICE_FUNC inline Packet16f pexpand_bf16_l(const Packet16f& from) {
451   return _mm512_castsi512_ps(_mm512_slli_epi32(
452       _mm512_cvtepu16_epi32(_mm512_castsi512_si256(_mm512_castps_si512(from))),
453       16));
454 }
455 
456 template <typename Packet>
pexpand_bf16_u(const Packet16f & from)457 EIGEN_DEVICE_FUNC inline Packet16f pexpand_bf16_u(const Packet16f& from) {
458   Packet16i tmp = _mm512_castps_si512(from);
459   Packet16i tmp2 = _mm512_alignr_epi32(tmp, tmp, 8);
460   return _mm512_castsi512_ps(_mm512_slli_epi32(
461       _mm512_cvtepu16_epi32(_mm512_castsi512_si256(tmp2)), 16));
462 }
463 
464 #endif
465 }  // namespace internal
466 }  // namespace Eigen
467 #endif  // TENSORFLOW_CORE_KERNELS_SPARSE_MATMUL_OP_H_
468