/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc.

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/sparse_matmul_op.h"

#include <map>
#include <memory>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/bfloat16.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
#ifdef TENSORFLOW_USE_LIBXSMM
#include "include/libxsmm_intrinsics_x86.h"
#include "include/libxsmm_malloc.h"
#include "include/libxsmm_spmdm.h"
#endif

#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
#include "tensorflow/core/kernels/eigen_contraction_kernel.h"
#endif

#define ALWAYS_INLINE EIGEN_ALWAYS_INLINE

namespace tensorflow {
namespace {

template <typename T>
using BasicMatrix = Eigen::Tensor<T, 2, Eigen::RowMajor>;

template <typename T>
using BasicMatrixMap =
    Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>, Eigen::Aligned>;

using Matrix = BasicMatrix<float>;
using MatrixMap = BasicMatrixMap<float>;
using CPUDevice = Eigen::ThreadPoolDevice;
using DSizes = Eigen::DSizes<Eigen::DenseIndex, 2>;

// Two commonly used static dsizes. We use Eigen::type2index to allow as much
// compile time optimization as possible.
#ifdef EIGEN_HAS_INDEX_LIST
inline Eigen::IndexList<Eigen::type2index<0>, Eigen::type2index<0>>
dsizes_00() {
  return Eigen::IndexList<Eigen::type2index<0>, Eigen::type2index<0>>();
}
inline Eigen::IndexList<Eigen::type2index<1>, Eigen::type2index<0>>
dsizes_10() {
  return Eigen::IndexList<Eigen::type2index<1>, Eigen::type2index<0>>();
}
#else
inline DSizes dsizes_00() { return DSizes(0, 0); }
inline DSizes dsizes_10() { return DSizes(1, 0); }
#endif

// Blocksizes
// TODO(agarwal): compute these sizes based on cache sizes.
const int K = 64;
const int M = 64;
const int N = 128;
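// (Reading inferred from their use below: M is the row-blocking size for the
// left operand and the output, K the column-blocking size used when encoding
// the sparse left operand (and hence the row-blocking size for the right
// operand), and N the column-blocking size for the dense right operand.)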

// This stores a sparse representation of a slice of a matrix with size
// (num_rows, num_cols). The slice is represented as a series of blocks of size
// (num_rows, b), where b = block_size for all but the last block, which may
// have fewer columns.
//
// num_rows and block_size are assumed to be <= 256. This allows storing the
// different indices as uint8.
//
// For each block, we store all the nonzero entries in the data/data3 vectors
// and the corresponding coordinates of the elements in the index/index3
// vectors. The index3 vector stores the indices of 3 elements in the same row
// so that these elements can share the same row coordinate. Each entry in
// index3 corresponds to 3 entries in data3.
//
// Note that the data/indices of all the blocks are stored in the same vectors
// respectively. To identify block boundaries, we store the block offsets using
// index3_offset/index_offset. If there are n blocks in the slice, index3_offset
// and index_offset have n entries each. The index3 entries for the ith block
// occupy positions [index3_offset[i-1], index3_offset[i]) of the index3 vector
// (with index3_offset[-1] taken to be 0). Similarly for index_offset.
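//
// Illustrative example (not part of the original comment): if the first row of
// a block has nonzeros at columns 2, 5, 7 and 9, the first three are emitted as
// a single Index3 entry {m=0, k1=2, k2=5, k3=7} with their values appended to
// data3, while the trailing nonzero at column 9 is emitted as an Index entry
// {m=0, k=9} with its value appended to data.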
template <typename T>
struct SparseSlice {
  using ConstMatrixMap = BasicMatrixMap<const T>;

 public:
  // Indices of three elements on the same row.
  struct Index3 {
    uint8 m;  // row
    // columns
    uint8 k1;
    uint8 k2;
    uint8 k3;
  };

  // Index of one element.
  struct Index {
    uint8 m;
    uint8 k;
  };

  SparseSlice(int nrows, int ncols, int bsize)
      : num_rows(nrows), num_cols(ncols), block_size(bsize) {
    DCHECK_LE(nrows, 256);
    DCHECK_LE(block_size, 256);
  }

  // Initializes the slice with data starting at mat(0, col_offset) and with
  // size (num_rows, num_cols).
  // If Transpose is true, implicitly transposes mat.
  template <bool Transpose = false>
  void Initialize(const ConstMatrixMap& mat, int col_offset);

  void Clear();

  // See comments above.
  std::vector<int> index3_offset;
  std::vector<Index3> index3;
  std::vector<T> data3;

  // See comments above. Similar to "index3" except that each element in "index"
  // corresponds to one element in data.
  std::vector<int> index_offset;
  std::vector<Index> index;
  std::vector<T> data;

  // Number of rows and columns for the slice.
  const int num_rows;
  const int num_cols;

  // Block size used to initialize from a matrix.
  const int block_size;
};

template <typename T>
bool IsZero(T v);

template <>
ALWAYS_INLINE bool IsZero(bfloat16 v) {
  return !static_cast<bool>(v);
}

template <>
ALWAYS_INLINE bool IsZero(float v) {
  return v == 0.0f;
}

template <typename T>
template <bool Transpose>
void SparseSlice<T>::Initialize(
    const typename SparseSlice<T>::ConstMatrixMap& mat, int col_offset) {
  const int mat_rows = Transpose ? mat.dimension(1) : mat.dimension(0);
  const int mat_cols = Transpose ? mat.dimension(0) : mat.dimension(1);
  DCHECK_LE(num_rows, mat_rows);
  DCHECK_LE(num_cols + col_offset, mat_cols);

  int num_blocks = (num_cols + block_size - 1) / block_size;
  int mat_size = num_rows * num_cols;

  index3_offset.reserve(num_blocks);
  data3.reserve(mat_size);
  index3.reserve(mat_size / 3);

  index_offset.reserve(num_blocks);
  data.reserve(num_blocks * num_rows * 2);
  index.reserve(num_blocks * num_rows * 2);

  Index3 idx3;
  const int stride = Transpose ? mat.dimension(1) : 1;

  for (int i = 0; i < num_blocks; ++i) {
    int num_block_cols = std::min(block_size, num_cols - block_size * i);
    for (int row = 0; row < num_rows; ++row) {
      idx3.m = static_cast<uint8>(row);
      // Safety note: The following code has a race, since it checks whether
      // *curr is nonzero and then reads it again on use.  However, the result
      // of the race is only that some of the "nonzeros" in the resulting sparse
      // representation may actually be zero, which is harmless.
      const auto* start =
          Transpose ? &mat(col_offset, row) : &mat(row, col_offset);
      const auto* curr = start;
      const auto* end = start + stride * num_block_cols;
      uint8 k = 0;
#define NEXT_ELEM \
  curr += stride; \
  ++k;
#define EAT_ZEROS                          \
  while (curr < end && IsZero<T>(*curr)) { \
    NEXT_ELEM;                             \
  }
      while (true) {
        EAT_ZEROS
        if (curr >= end) break;
        idx3.k1 = k;
        const T value1 = *curr;
        NEXT_ELEM;

        EAT_ZEROS
        if (curr >= end) {
          data.push_back(value1);
          index.push_back({idx3.m, idx3.k1});
          break;
        }
        idx3.k2 = k;
        const T value2 = *curr;
        NEXT_ELEM;

        EAT_ZEROS
        if (curr >= end) {
          data.push_back(value2);
          index.push_back({idx3.m, idx3.k2});
          data.push_back(value1);
          index.push_back({idx3.m, idx3.k1});
          break;
        }
        idx3.k3 = k;
        data3.push_back(value1);
        data3.push_back(value2);
        data3.push_back(*curr);
        NEXT_ELEM;
        index3.push_back(idx3);
#undef NEXT_ELEM
#undef EAT_ZEROS
      }
    }
    col_offset += block_size;
    index3_offset.push_back(index3.size());
    index_offset.push_back(index.size());
  }
  DCHECK_EQ(index3_offset.size(), num_blocks);
  DCHECK_EQ(index_offset.size(), num_blocks);
  DCHECK_EQ(3 * index3.size(), data3.size());
  DCHECK_EQ(index.size(), data.size());
}

template <typename T>
void SparseSlice<T>::Clear() {
  index3_offset.clear();
  index3.clear();
  data3.clear();
  index_offset.clear();
  index.clear();
  data.clear();
}

using Packet = Eigen::internal::packet_traits<float>::type;
const int kNumOperands = (sizeof(Packet) / sizeof(float));
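// kNumOperands is the SIMD width in floats (typically 4 with SSE, 8 with AVX,
// 16 with AVX-512); the exact value depends on how Eigen was configured.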
#define LOAD(x) Eigen::internal::pload<Packet>(x);
#define EXPAND_BFLOAT_L(x, y) \
  const auto y = Eigen::internal::pexpand_bf16_l<Packet>(x);
#define EXPAND_BFLOAT_U(x, y) \
  const auto y = Eigen::internal::pexpand_bf16_u<Packet>(x);
#define STORE(x, y) Eigen::internal::pstore<float>(x, y);
#define FMA(a, b, c, d) d = Eigen::internal::pmadd<Packet>(a, b, c);

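// A bfloat16 value is the upper 16 bits of the corresponding float32 bit
// pattern, so converting it to float amounts to placing those 16 bits in the
// high half of a zeroed 32-bit word; which half is "high" in memory depends on
// the endianness handled below.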
ALWAYS_INLINE float ConvertBfloat16ToFloat(const bfloat16* src) {
  float out = 0;
  auto tmp = reinterpret_cast<bfloat16*>(&out);
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
  tmp[0] = *src;
#else
  tmp[1] = *src;
#endif
  return out;
}

ALWAYS_INLINE Packet ConvertFourBfloat16ToFloat(const bfloat16* src) {
  return Eigen::internal::pload4bf16<Packet>(
      reinterpret_cast<const float*>(src));
}

ALWAYS_INLINE Packet ConvertTwoBfloat16ToFloat(const bfloat16* src) {
  return Eigen::internal::pload2bf16<Packet>(
      reinterpret_cast<const float*>(src));
}

ALWAYS_INLINE void ScalarMulAdd(const float a, const float** inp, float** out) {
  **out += a * **inp;
  ++*inp;
  ++*out;
}

ALWAYS_INLINE void ScalarMulAdd(const float a, const bfloat16** inp,
                                float** out) {
  float inp_f = ConvertBfloat16ToFloat(*inp);
  **out += a * inp_f;
  ++*inp;
  ++*out;
}
ALWAYS_INLINE void ScalarMulAdd3Way(const float a1, const float a2,
                                    const float a3, const bfloat16** inp1,
                                    const bfloat16** inp2,
                                    const bfloat16** inp3, float** out) {
  float inp1_f = ConvertBfloat16ToFloat(*inp1);
  float inp2_f = ConvertBfloat16ToFloat(*inp2);
  float inp3_f = ConvertBfloat16ToFloat(*inp3);
  **out += a1 * inp1_f + a2 * inp2_f + a3 * inp3_f;
  ++*out;
  ++*inp1;
  ++*inp2;
  ++*inp3;
}

ALWAYS_INLINE void ScalarMulAdd3Way(const float a1, const float a2,
                                    const float a3, const float** inp1,
                                    const float** inp2, const float** inp3,
                                    float** out) {
  **out += a1 * **inp1 + a2 * **inp2 + a3 * **inp3;
  ++*out;
  ++*inp1;
  ++*inp2;
  ++*inp3;
}

ALWAYS_INLINE void LoadSingleScalar(const bfloat16** data, Packet* l) {
  auto tmp = ConvertBfloat16ToFloat(*data);
  *l = Eigen::internal::pset1<Packet>(tmp);
  ++*data;
}

ALWAYS_INLINE void LoadTwoScalars(const bfloat16** data, Packet* l1,
                                  Packet* l2) {
  if (kNumOperands >= 2) {
    auto tmp = ConvertTwoBfloat16ToFloat(*data);
    *l1 = Eigen::internal::pbroadcast_first<Packet>(tmp);
    *l2 = Eigen::internal::pbroadcast_second<Packet>(tmp);
    *data += 2;
  } else {
    LoadSingleScalar(data, l1);
    LoadSingleScalar(data, l2);
  }
}

ALWAYS_INLINE void LoadFourScalars(const bfloat16** data, Packet* l1,
                                   Packet* l2, Packet* l3, Packet* l4) {
  if (kNumOperands >= 4) {
    auto tmp = ConvertFourBfloat16ToFloat(*data);
    *l1 = Eigen::internal::pbroadcast_first<Packet>(tmp);
    *l2 = Eigen::internal::pbroadcast_second<Packet>(tmp);
    *l3 = Eigen::internal::pbroadcast_third<Packet>(tmp);
    *l4 = Eigen::internal::pbroadcast_fourth<Packet>(tmp);
    *data += 4;
  } else {
    LoadTwoScalars(data, l1, l2);
    LoadTwoScalars(data, l3, l4);
  }
}

ALWAYS_INLINE void LoadSingleScalar(const float** data, Packet* l) {
  *l = Eigen::internal::pload1<Packet>(*data);
  ++(*data);
}

ALWAYS_INLINE void LoadTwoScalars(const float** data, Packet* l1, Packet* l2) {
  LoadSingleScalar(data, l1);
  LoadSingleScalar(data, l2);
}

ALWAYS_INLINE void LoadFourScalars(const float** data, Packet* l1, Packet* l2,
                                   Packet* l3, Packet* l4) {
  LoadTwoScalars(data, l1, l2);
  LoadTwoScalars(data, l3, l4);
}

template <typename T>
ALWAYS_INLINE void LoadThreeScalars(const T** data, Packet* l1, Packet* l2,
                                    Packet* l3) {
  LoadTwoScalars(data, l1, l2);
  LoadSingleScalar(data, l3);
}

template <typename T>
ALWAYS_INLINE void LoadSixScalars(const T** data, Packet* l1, Packet* l2,
                                  Packet* l3, Packet* l4, Packet* l5,
                                  Packet* l6) {
  LoadFourScalars(data, l1, l2, l3, l4);
  LoadTwoScalars(data, l5, l6);
}

// Vectorized version of ScalarMulAdd.
ALWAYS_INLINE void MulAdd(const Packet a, const bfloat16** binp, float** out) {
  auto inp = reinterpret_cast<const float*>(*binp);
  const auto b = LOAD(inp);
  EXPAND_BFLOAT_L(b, b_0);
  EXPAND_BFLOAT_U(b, b_1);
  *binp += 2 * kNumOperands;
  auto c1 = LOAD(*out);
  auto c2 = LOAD(*out + kNumOperands);
  FMA(a, b_0, c1, c1);
  FMA(a, b_1, c2, c2);
  STORE(*out, c1);
  STORE(*out + kNumOperands, c2);
  *out += 2 * kNumOperands;
}

// Vectorized version of ScalarMulAdd3Way.
ALWAYS_INLINE void MulAdd3Way(const Packet a1, const Packet a2, const Packet a3,
                              const bfloat16** binp1, const bfloat16** binp2,
                              const bfloat16** binp3, float** out) {
  auto inp1 = reinterpret_cast<const float*>(*binp1);
  auto inp2 = reinterpret_cast<const float*>(*binp2);
  auto inp3 = reinterpret_cast<const float*>(*binp3);
  auto c1 = LOAD(*out);
  auto c2 = LOAD(*out + kNumOperands);
  const auto b1 = LOAD(inp1);
  EXPAND_BFLOAT_L(b1, b1_0);
  EXPAND_BFLOAT_U(b1, b1_1);
  *binp1 += 2 * kNumOperands;
  const auto b2 = LOAD(inp2);
  EXPAND_BFLOAT_L(b2, b2_0);
  EXPAND_BFLOAT_U(b2, b2_1);
  *binp2 += 2 * kNumOperands;
  const auto b3 = LOAD(inp3);
  EXPAND_BFLOAT_L(b3, b3_0);
  EXPAND_BFLOAT_U(b3, b3_1);
  *binp3 += 2 * kNumOperands;
  FMA(a1, b1_0, c1, c1);
  FMA(a1, b1_1, c2, c2);
  FMA(a2, b2_0, c1, c1);
  FMA(a2, b2_1, c2, c2);
  FMA(a3, b3_0, c1, c1);
  FMA(a3, b3_1, c2, c2);
  STORE(*out, c1);
  STORE(*out + kNumOperands, c2);
  *out += 2 * kNumOperands;
}

// Unroll MulAdd3Way for two iterations
ALWAYS_INLINE void TwoMulAdd3Way(const Packet a1, const Packet a2,
                                 const Packet a3, const bfloat16** binp1,
                                 const bfloat16** binp2, const bfloat16** binp3,
                                 float** out) {
  auto inp1 = reinterpret_cast<const float*>(*binp1);
  auto inp2 = reinterpret_cast<const float*>(*binp2);
  auto inp3 = reinterpret_cast<const float*>(*binp3);
  auto c1 = LOAD(*out);
  auto c2 = LOAD(*out + kNumOperands);
  const auto b1 = LOAD(inp1);
  const auto b2 = LOAD(inp2);
  const auto b3 = LOAD(inp3);

  EXPAND_BFLOAT_L(b1, b1_0);
  EXPAND_BFLOAT_U(b1, b1_1);
  EXPAND_BFLOAT_L(b2, b2_0);
  EXPAND_BFLOAT_U(b2, b2_1);
  EXPAND_BFLOAT_L(b3, b3_0);
  EXPAND_BFLOAT_U(b3, b3_1);
  auto c3 = LOAD(*out + 2 * kNumOperands);
  auto c4 = LOAD(*out + 3 * kNumOperands);
  const auto b4 = LOAD(inp1 + kNumOperands);
  const auto b5 = LOAD(inp2 + kNumOperands);
  const auto b6 = LOAD(inp3 + kNumOperands);

  EXPAND_BFLOAT_L(b4, b4_0);
  EXPAND_BFLOAT_U(b4, b4_1);
  EXPAND_BFLOAT_L(b5, b5_0);
  EXPAND_BFLOAT_U(b5, b5_1);
  EXPAND_BFLOAT_L(b6, b6_0);
  EXPAND_BFLOAT_U(b6, b6_1);

  FMA(a1, b1_0, c1, c1);
  FMA(a1, b1_1, c2, c2);
  FMA(a1, b4_0, c3, c3);
  FMA(a1, b4_1, c4, c4);
  FMA(a2, b2_0, c1, c1);
  FMA(a2, b2_1, c2, c2);
  FMA(a2, b5_0, c3, c3);
  FMA(a2, b5_1, c4, c4);
  FMA(a3, b3_0, c1, c1);
  FMA(a3, b3_1, c2, c2);
  FMA(a3, b6_0, c3, c3);
  FMA(a3, b6_1, c4, c4);
  STORE(*out, c1);
  STORE(*out + kNumOperands, c2);
  STORE(*out + 2 * kNumOperands, c3);
  STORE(*out + 3 * kNumOperands, c4);
  *out += 4 * kNumOperands;
  *binp1 += 4 * kNumOperands;
  *binp2 += 4 * kNumOperands;
  *binp3 += 4 * kNumOperands;
}

// Apply MulAdd3Way on 128 operands.
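// Each call to TwoMulAdd3Way below consumes 4 * kNumOperands floats of output,
// so the pair of calls per iteration covers 8 * kNumOperands floats and the
// loop runs 128 / (8 * kNumOperands) times to cover all 128 columns.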
ALWAYS_INLINE void MulAdd3Way128(const Packet a1, const Packet a2,
                                 const Packet a3, const bfloat16** inp1,
                                 const bfloat16** inp2, const bfloat16** inp3,
                                 float** out) {
  for (int k = 0; k < 128 / (8 * kNumOperands); ++k) {
    TwoMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
    TwoMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
  }
}

// Vectorized version of ScalarMulAdd
ALWAYS_INLINE void MulAdd(const Packet a, const float** inp, float** out) {
  const auto b = LOAD(*inp);
  *inp += kNumOperands;
  auto c = LOAD(*out);
  FMA(a, b, c, c);
  STORE(*out, c);
  *out += kNumOperands;
}

// Vectorized version of ScalarMulAdd3Way
ALWAYS_INLINE void MulAdd3Way(const Packet a1, const Packet a2, const Packet a3,
                              const float** inp1, const float** inp2,
                              const float** inp3, float** out) {
  auto c = LOAD(*out);
  const auto b1 = LOAD(*inp1);
  *inp1 += kNumOperands;
  const auto b2 = LOAD(*inp2);
  *inp2 += kNumOperands;
  const auto b3 = LOAD(*inp3);
  *inp3 += kNumOperands;
  FMA(a1, b1, c, c);
  FMA(a2, b2, c, c);
  FMA(a3, b3, c, c);
  STORE(*out, c);
  *out += kNumOperands;
}

// Unroll MulAdd3Way for two iterations
ALWAYS_INLINE void TwoMulAdd3Way(const Packet a1, const Packet a2,
                                 const Packet a3, const float** inp1,
                                 const float** inp2, const float** inp3,
                                 float** out) {
  auto c1 = LOAD(*out);
  const auto b1 = LOAD(*inp1);
  const auto b2 = LOAD(*inp2);
  const auto b3 = LOAD(*inp3);

  auto c2 = LOAD(*out + kNumOperands);
  const auto b4 = LOAD(*inp1 + kNumOperands);
  const auto b5 = LOAD(*inp2 + kNumOperands);
  const auto b6 = LOAD(*inp3 + kNumOperands);

  FMA(a1, b1, c1, c1);
  FMA(a1, b4, c2, c2);
  FMA(a2, b2, c1, c1);
  FMA(a2, b5, c2, c2);
  FMA(a3, b3, c1, c1);
  FMA(a3, b6, c2, c2);
  STORE(*out, c1);
  STORE(*out + kNumOperands, c2);
  *out += 2 * kNumOperands;
  *inp1 += 2 * kNumOperands;
  *inp2 += 2 * kNumOperands;
  *inp3 += 2 * kNumOperands;
}

// Unroll MulAdd3Way for four iterations
ALWAYS_INLINE void FourMulAdd3Way(const Packet a1, const Packet a2,
                                  const Packet a3, const float** inp1,
                                  const float** inp2, const float** inp3,
                                  float** out) {
  TwoMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
  TwoMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
}

// Apply MulAdd3Way on 128 operands.
ALWAYS_INLINE void MulAdd3Way128(const Packet a1, const Packet a2,
                                 const Packet a3, const float** inp1,
                                 const float** inp2, const float** inp3,
                                 float** out) {
  if (kNumOperands == 8) {
    FourMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
    FourMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
    FourMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
    FourMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
  } else {
    DCHECK_LE(4 * kNumOperands, 128);
    for (int i = 0; i < 128 / (4 * kNumOperands); ++i) {
      MulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
      MulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
      MulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
      MulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
    }
  }
}
// Computes the product of "left_slices" with "num_cols" columns of "right",
// and stores the output in *"output".
// Note that left_slices is a list of SparseSlices, which are conceptually
// assumed to be concatenated along the column dimension. Also, each SparseSlice
// is encoded as a list of blocks with up to N columns. See SparseSlice for more
// details.
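//
// Informal sketch of the traversal below: for each slice, and for each column
// block inside it, the Index3 entries are walked two at a time (and the
// leftover Index entries four at a time); each entry broadcasts its left-hand
// value(s) into packets and accumulates value * right_row into the output row
// given by the entry's "m" coordinate, using the 128-column fast path when
// cols == 128 and a vector loop plus a scalar tail otherwise.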
template <typename TL, typename TR, int Cols>
inline void GEPP(
    const std::vector<SparseSlice<TL>*>& left_slices,
    const Eigen::TensorMap<Eigen::Tensor<const TR, 2, Eigen::RowMajor>,
                           Eigen::Aligned>& right,
    const int num_cols, Matrix* output) {
  const int cols = (Cols == -1) ? num_cols : Cols;
  DCHECK_EQ(num_cols, cols);
  const int right_num_cols = right.dimension(1);
  const int output_num_cols = output->dimension(1);
  static const int kNumOperandsR = kNumOperands * sizeof(float) / sizeof(TR);
  const int cols_mod = cols % kNumOperandsR;
  int k_offset = 0;
  // Pre-compute pointers for output matrix.
  float* out_ptrs[M];
  float* const out_start = &(*output)(0, 0);
  for (int j = 0; j < M; ++j) {
    out_ptrs[j] = out_start + output_num_cols * j;
  }
  for (const auto* left_slice : left_slices) {
    const auto& left = *left_slice;
    const auto* data3 = (!left.data3.empty()) ? &left.data3[0] : nullptr;
    const auto* data = (!left.data.empty()) ? &left.data[0] : nullptr;
    const int num_blocks = left.index3_offset.size();
    int begin3 = 0;
    int begin = 0;
    for (int i = 0; i < num_blocks; ++i) {
      // Pre-compute pointers for right matrix
      const TR* right_ptrs[K];
      const auto* const right_start = &right(k_offset, 0);
      DCHECK_LT(k_offset, right.dimension(0));
      for (int j = 0; j < K; ++j) {
        right_ptrs[j] = right_start + right_num_cols * j;
      }

      const int end3 = left.index3_offset[i];
      int j = begin3;
      // Loop unrolled for 2 iterations.
      for (; j + 1 < end3; j += 2) {
        Packet l1, l2, l3, nl1, nl2, nl3;
        LoadSixScalars(&data3, &l1, &l2, &l3, &nl1, &nl2, &nl3);
        const auto& index = left.index3[j];
        const auto& nindex = left.index3[j + 1];
        float* out = out_ptrs[index.m];
        float* nout = out_ptrs[nindex.m];
        const auto* r1 = right_ptrs[index.k1];
        const auto* r2 = right_ptrs[index.k2];
        const auto* r3 = right_ptrs[index.k3];

        const auto* nr1 = right_ptrs[nindex.k1];
        const auto* nr2 = right_ptrs[nindex.k2];
        const auto* nr3 = right_ptrs[nindex.k3];
        if (cols == 128) {
          MulAdd3Way128(l1, l2, l3, &r1, &r2, &r3, &out);
          MulAdd3Way128(nl1, nl2, nl3, &nr1, &nr2, &nr3, &nout);
        } else {
          for (int n = 0; n < cols / kNumOperandsR; ++n) {
            MulAdd3Way(l1, l2, l3, &r1, &r2, &r3, &out);
            MulAdd3Way(nl1, nl2, nl3, &nr1, &nr2, &nr3, &nout);
          }

          const float sl1 = Eigen::internal::pfirst<Packet>(l1);
          const float sl2 = Eigen::internal::pfirst<Packet>(l2);
          const float sl3 = Eigen::internal::pfirst<Packet>(l3);
          const float nsl1 = Eigen::internal::pfirst<Packet>(nl1);
          const float nsl2 = Eigen::internal::pfirst<Packet>(nl2);
          const float nsl3 = Eigen::internal::pfirst<Packet>(nl3);
          for (int k = 0; k < cols_mod; ++k) {
            ScalarMulAdd3Way(sl1, sl2, sl3, &r1, &r2, &r3, &out);
            ScalarMulAdd3Way(nsl1, nsl2, nsl3, &nr1, &nr2, &nr3, &nout);
          }
        }
      }
      if (j < end3) {
        Packet l1, l2, l3;
        LoadThreeScalars(&data3, &l1, &l2, &l3);

        const auto& index = left.index3[j];
        float* out = out_ptrs[index.m];
        const auto* r1 = right_ptrs[index.k1];
        const auto* r2 = right_ptrs[index.k2];
        const auto* r3 = right_ptrs[index.k3];
        if (cols == 128) {
          MulAdd3Way128(l1, l2, l3, &r1, &r2, &r3, &out);
        } else {
          for (int n = 0; n < cols / kNumOperandsR; ++n) {
            MulAdd3Way(l1, l2, l3, &r1, &r2, &r3, &out);
          }
          const float sl1 = Eigen::internal::pfirst<Packet>(l1);
          const float sl2 = Eigen::internal::pfirst<Packet>(l2);
          const float sl3 = Eigen::internal::pfirst<Packet>(l3);
          for (int k = 0; k < cols_mod; ++k) {
            ScalarMulAdd3Way(sl1, sl2, sl3, &r1, &r2, &r3, &out);
          }
        }
      }
      begin3 = end3;
      int end = left.index_offset[i];
      // Loop unrolled for 4 iterations.
      j = begin;
      for (; j + 3 < end; j += 4) {
        Packet l, nl, n2l, n3l;
        LoadFourScalars(&data, &l, &nl, &n2l, &n3l);

        const auto& index = left.index[j];
        const auto& nindex = left.index[j + 1];
        const auto& n2index = left.index[j + 2];
        const auto& n3index = left.index[j + 3];
        const auto* r = right_ptrs[index.k];
        const auto* nr = right_ptrs[nindex.k];
        const auto* n2r = right_ptrs[n2index.k];
        const auto* n3r = right_ptrs[n3index.k];
        float* out = out_ptrs[index.m];
        float* nout = out_ptrs[nindex.m];
        float* n2out = out_ptrs[n2index.m];
        float* n3out = out_ptrs[n3index.m];

        for (int n = 0; n < cols / kNumOperandsR; ++n) {
          MulAdd(l, &r, &out);
          MulAdd(nl, &nr, &nout);
          MulAdd(n2l, &n2r, &n2out);
          MulAdd(n3l, &n3r, &n3out);
        }

        const float sl1 = Eigen::internal::pfirst<Packet>(l);
        const float sl2 = Eigen::internal::pfirst<Packet>(nl);
        const float sl3 = Eigen::internal::pfirst<Packet>(n2l);
        const float sl4 = Eigen::internal::pfirst<Packet>(n3l);
        for (int k = 0; k < cols_mod; ++k) {
          ScalarMulAdd(sl1, &r, &out);
          ScalarMulAdd(sl2, &nr, &nout);
          ScalarMulAdd(sl3, &n2r, &n2out);
          ScalarMulAdd(sl4, &n3r, &n3out);
        }
      }
      while (j < end) {
        Packet l;
        LoadSingleScalar(&data, &l);
        const auto& index = left.index[j];
        const auto* r = right_ptrs[index.k];
        float* out = out_ptrs[index.m];
        for (int n = 0; n < cols / kNumOperandsR; ++n) {
          MulAdd(l, &r, &out);
        }
        const float sl = Eigen::internal::pfirst<Packet>(l);
        for (int k = 0; k < cols_mod; ++k) {
          ScalarMulAdd(sl, &r, &out);
        }
        j++;
      }
      k_offset += left.block_size;
      begin = end;
    }
  }
}

#undef LOAD
#undef EXPAND_BFLOAT_L
#undef EXPAND_BFLOAT_U
#undef STORE
#undef FMA

}  // namespace

template <typename TL, typename TR>
class SparseMatMul {
  using MatrixL = BasicMatrix<TL>;
  using MatrixR = BasicMatrix<TR>;
  using ConstMatrixMapL = BasicMatrixMap<const TL>;
  using ConstMatrixMapR = BasicMatrixMap<const TR>;
  using MatrixMapR = BasicMatrixMap<TR>;

 public:
  // Not used; added to match interface of LibxsmmSparseMatMul
  struct TensorInfoCache {};

  // Perform matrix multiplication of "left" and "right", and store the result
  // in *"output".
 public:
  static inline void Compute(TensorInfoCache* cache,
                             const ConstMatrixMapL& left,
                             const ConstMatrixMapR& right, bool transpose_left,
                             const DeviceBase::CpuWorkerThreads* thread_pool,
                             bool transpose_output, MatrixMap* output);

 private:
  // Computes multiplication of left and num_cols columns of right, and stores
  // the output block in *"output" at offsets "output_row_offset" and
  // "output_col_offset". If assign is true, assigns the value to that block,
  // else adds the values to the existing values.
  static inline void ComputeOutputBlock(
      const std::vector<SparseSlice<TL>*>& left, const ConstMatrixMapR& right,
      int num_cols, int output_row_offset, int output_col_offset, bool assign,
      bool transpose_output, MatrixMap* output);

  // Encodes "mat" using a sparse representation and stores that in
  // "mat_slices". "mat" is broken into a grid with sizes "slice_num_rows" and
  // "slice_num_cols", each grid element is converted into a SparseSlice and
  // stored in mat_slices. "slice_block_size" is used to perform further column
  // blocking of each slice.
  static inline std::unique_ptr<BlockingCounter> CreateSparseSlices(
      const ConstMatrixMapL& mat, bool transpose, int slice_num_rows,
      int slice_block_size, int slice_num_cols,
      std::vector<std::vector<SparseSlice<TL>*>>* mat_slices,
      const DeviceBase::CpuWorkerThreads* thread_pool);

  // This function chops "mat" along column dimension into pieces with at most N
  // columns, and concatenates the pieces one after the other in "buffer". It
  // returns the list of the pieces in "slices". It returns a BlockingCounter
  // which should be used to wait for the shuffle operations to complete.
  static inline std::unique_ptr<BlockingCounter> CreateDenseSlices(
      const ConstMatrixMapR& mat, int row_start, int num_rows, int col_start,
      int num_cols, const DeviceBase::CpuWorkerThreads* thread_pool,
      MatrixR* buffer, std::vector<ConstMatrixMapR*>* slices);

  // Helper function for CreateDenseSlices to move the data around. It returns a
  // BlockingCounter which should be used to wait for the shuffle operations to
  // complete.
  static inline BlockingCounter* ShuffleMatrix(
      const ConstMatrixMapR& mat, int slice_row_start, int slice_num_rows,
      int slice_col_start, int slice_num_cols, const int N,
      const DeviceBase::CpuWorkerThreads* thread_pool, MatrixR* buffer);

  // Helper function for CreateDenseSlices to create slices.
  static inline void SliceMatrix(const MatrixR& mat, const int num_rows,
                                 const int num_slices,
                                 std::vector<ConstMatrixMapR*>* slices);

  // Heuristics to compute various block sizes.
  // KR, NR: block sizes for "right". We run blocking iterations that operate on
  // matrices with at most this size.
  // KL: grid size along the column dimension used while encoding left.
  // IB, JB: number of left and right slices to multiply together. This is used
  // for ordering different ComputeOutputBlock operations inside each blocking
  // iteration so as to potentially reduce the working set size.
  static inline void ComputeBlockSizes(const ConstMatrixMapL& left,
                                       const ConstMatrixMapR& right,
                                       bool transpose_left, int num_threads,
                                       int* KR, int* NR, int* KL, int* JB,
                                       int* IB);

  TF_DISALLOW_COPY_AND_ASSIGN(SparseMatMul);
};

#ifdef TENSORFLOW_USE_LIBXSMM
template <typename TL, typename TR>
class LibxsmmSparseMatMul {
  using MatrixL = BasicMatrix<TL>;
  using MatrixR = BasicMatrix<TR>;
  using ConstMatrixMapL = BasicMatrixMap<const TL>;
  using ConstMatrixMapR = BasicMatrixMap<const TR>;
  using MatrixMapR = BasicMatrixMap<TR>;

 public:
  // This structure contains a set of libxsmm kernels for sizes that have been
  // encountered previously by this operator so that libxsmm does not need to
  // reallocate its scratchpad memory each time (which hurts performance
  // substantially).
  struct TensorInfoCache {
    struct TensorInfoCacheEntry {
      // Parameters for kernel
      int M;
      int K;
      int N;
      int max_threads;
      // libxsmm handle and matrix data
      libxsmm_spmdm_handle handle;
      libxsmm_CSR_sparseslice* output_csr;
      // Chain to non-libxsmm implementation's cache in case that ever becomes
      // useful (it is an empty struct right now)
      typename SparseMatMul<TL, TR>::TensorInfoCache
          non_libxsmm_cache;  // Currently not used
    };
    // protects entries; invariant: entries is a valid std::multimap
    tensorflow::mutex lock;
    // Because there could be multiple matrix multiplies with the same sizes
    // going on at the same time, we need to allow multiple cache entries for a
    // given set of parameters. Taking and returning entries is used to make
    // sure the same cache entry is not used from two threads at a time.
    std::multimap<std::tuple<int, int, int, int>,
                  std::unique_ptr<TensorInfoCacheEntry>>
        entries TF_GUARDED_BY(lock);

    TensorInfoCache() : lock(), entries() {}
    // Look up and remove first entry with these parameters, creating one if
    // there isn't one
    std::unique_ptr<TensorInfoCacheEntry> take_cache_entry(int M, int K, int N,
                                                           int max_threads)
        TF_LOCKS_EXCLUDED(lock) {
      tensorflow::mutex_lock ml(lock);
      auto key = std::make_tuple(M, K, N, max_threads);
      auto it = entries.find(key);
      if (it != entries.end()) {
        auto val = std::move(it->second);
        entries.erase(it);
        return val;
      } else {
        std::unique_ptr<TensorInfoCacheEntry> e{
            new TensorInfoCacheEntry{M, K, N, max_threads, {}, nullptr}};
        // setup scoped allocator, which uses cpu_allocator() for this scope
        const libxsmm_tf_allocator<libxsmm_scratch_allocator> tf_allocator;
        libxsmm_spmdm_init(M, N, K, max_threads, &e->handle, &e->output_csr);
        return e;
      }
    }
    // Add a cache entry with certain parameters
    void return_cache_entry(std::unique_ptr<TensorInfoCacheEntry> e)
        TF_LOCKS_EXCLUDED(lock) {
      tensorflow::mutex_lock ml(lock);
      auto key = std::make_tuple(e->M, e->K, e->N, e->max_threads);
      entries.insert(std::make_pair(key, std::move(e)));
    }
    ~TensorInfoCache() {
      tensorflow::mutex_lock ml(lock);
      for (auto& p : entries) {
        libxsmm_spmdm_destroy(&p.second->handle);
      }
      entries.clear();
    }

   private:
    TF_DISALLOW_COPY_AND_ASSIGN(TensorInfoCache);
  };

  // Perform matrix multiplication of "left" and "right", and store the result
  // in *"output".
 public:
  static inline void Compute(TensorInfoCache* cache,
                             const ConstMatrixMapL& left,
                             const ConstMatrixMapR& right, bool transpose_left,
                             const DeviceBase::CpuWorkerThreads* thread_pool,
                             bool transpose_output, MatrixMap* output);

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(LibxsmmSparseMatMul);
};
#endif

template <typename TL, typename TR,
          template <typename TL2, typename TR2> class DoMatMul>
class SparseMatMulOp : public OpKernel {
  using MatrixR = BasicMatrix<TR>;
  using ConstMatrixMapR = BasicMatrixMap<const TR>;

 public:
  explicit SparseMatMulOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_a", &transpose_a_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_b", &transpose_b_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("a_is_sparse", &a_is_sparse_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("b_is_sparse", &b_is_sparse_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& a = ctx->input(0);
    const Tensor& b = ctx->input(1);
    OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(a.shape()),
                errors::InvalidArgument("a is not a matrix"));
    OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(b.shape()),
                errors::InvalidArgument("b is not a matrix"));

    const int m = transpose_a_ ? a.dim_size(1) : a.dim_size(0);
    const int k = transpose_a_ ? a.dim_size(0) : a.dim_size(1);
    const int n = transpose_b_ ? b.dim_size(0) : b.dim_size(1);
    const int k2 = transpose_b_ ? b.dim_size(1) : b.dim_size(0);

    OP_REQUIRES(ctx, k == k2,
                errors::InvalidArgument(
                    "Matrix size incompatible: a: ", a.shape().DebugString(),
                    ", b: ", b.shape().DebugString()));
    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({m, n}), &output));

    if (k == 0) {
      // If the inner dimension k in the matrix multiplication is zero, we fill
      // the output with zeros.
      functor::SetZeroFunctor<CPUDevice, float> f;
      f(ctx->eigen_device<CPUDevice>(), output->flat<float>());
      return;
    }

    auto out = output->matrix<float>();

    std::unique_ptr<Tensor> a_float;
    std::unique_ptr<Tensor> b_float;
    if (!a_is_sparse_ && !b_is_sparse_) {
      auto left = &a;
      auto right = &b;
      // TODO(agarwal): multi-thread the conversions from bfloat16 to float.
      if (std::is_same<TL, bfloat16>::value) {
        a_float.reset(new Tensor(DT_FLOAT, a.shape()));
        BFloat16ToFloat(a.flat<bfloat16>().data(),
                        a_float->flat<float>().data(), a.NumElements());
        left = a_float.get();
      }
      if (std::is_same<TR, bfloat16>::value) {
        b_float.reset(new Tensor(DT_FLOAT, b.shape()));
        BFloat16ToFloat(b.flat<bfloat16>().data(),
                        b_float->flat<float>().data(), b.NumElements());
        right = b_float.get();
      }
      Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
      dim_pair[0].first = transpose_a_ ? 0 : 1;
      dim_pair[0].second = transpose_b_ ? 1 : 0;

      out.device(ctx->template eigen_device<CPUDevice>()) =
          left->matrix<float>().contract(right->matrix<float>(), dim_pair);
      return;
    }

    auto left = &a;
    auto right = &b;
    bool transpose_output = false;
    bool transpose_a = transpose_a_;
    bool transpose_b = transpose_b_;
    if (!a_is_sparse_) {
      // Swap the order of multiplications using the identity:
      // A * B = (B' * A')'.
      std::swap(left, right);
      std::swap(transpose_a, transpose_b);
      transpose_a = !transpose_a;
      transpose_b = !transpose_b;
      transpose_output = !transpose_output;
    }

    std::unique_ptr<Tensor> right_tr;
    if (transpose_b) {
      // TODO(agarwal): avoid transposing the matrix here and directly handle
      // transpose in CreateDenseSlices.
      right_tr.reset(
          new Tensor(right->dtype(),
                     TensorShape({right->dim_size(1), right->dim_size(0)})));

      const auto perm = dsizes_10();
      if (transpose_output) {
        right_tr->matrix<TL>().device(ctx->template eigen_device<CPUDevice>()) =
            right->matrix<TL>().shuffle(perm);
      } else {
        right_tr->matrix<TR>().device(ctx->template eigen_device<CPUDevice>()) =
            right->matrix<TR>().shuffle(perm);
      }
      right = right_tr.get();
    }

    if (transpose_output) {
      DoMatMul<TR, TL>::Compute(&this->cache_tr_, left->matrix<TR>(),
                                right->matrix<TL>(), transpose_a,
                                ctx->device()->tensorflow_cpu_worker_threads(),
                                transpose_output, &out);
    } else {
      DoMatMul<TL, TR>::Compute(&this->cache_nt_, left->matrix<TL>(),
                                right->matrix<TR>(), transpose_a,
                                ctx->device()->tensorflow_cpu_worker_threads(),
                                transpose_output, &out);
    }
  }

 private:
  bool transpose_a_;
  bool transpose_b_;
  bool a_is_sparse_;
  bool b_is_sparse_;

  // Cache for non-transposed-output multiply
  typename DoMatMul<TL, TR>::TensorInfoCache cache_nt_;
  // Cache for transposed-output multiply
  typename DoMatMul<TR, TL>::TensorInfoCache cache_tr_;

  TF_DISALLOW_COPY_AND_ASSIGN(SparseMatMulOp);
};

template <typename TL, typename TR>
inline void SparseMatMul<TL, TR>::ComputeOutputBlock(
    const std::vector<SparseSlice<TL>*>& left,
    const typename SparseMatMul<TL, TR>::ConstMatrixMapR& right, int num_cols,
    int output_row_offset, int output_col_offset, bool assign,
    bool transpose_output, MatrixMap* output) {
  const auto perm = dsizes_10();
  int num_rows = left[0]->num_rows;
  const int rhs_num_cols = right.dimension(1);
  DCHECK_LE(num_cols, rhs_num_cols);
  Matrix out(num_rows, rhs_num_cols);
  out.setZero();
  if (num_cols == N) {
    GEPP<TL, TR, N>(left, right, num_cols, &out);
  } else {
    GEPP<TL, TR, -1>(left, right, num_cols, &out);
  }
  if (!assign) {
    const DSizes begin(output_row_offset, output_col_offset);
    const DSizes sizes(num_rows, num_cols);
    if (transpose_output) {
      if (num_cols == rhs_num_cols) {
        output->shuffle(perm).slice(begin, sizes) += out;
      } else {
        const auto zero = dsizes_00();
        output->shuffle(perm).slice(begin, sizes) += out.slice(zero, sizes);
      }
    } else {
      if (num_cols == rhs_num_cols) {
        output->slice(begin, sizes) += out;
      } else {
        const auto zero = dsizes_00();
        output->slice(begin, sizes) += out.slice(zero, sizes);
      }
    }
  } else {
    std::unique_ptr<Matrix> out_tr;
    if (transpose_output) {
      out_tr.reset(new Matrix(rhs_num_cols, num_rows));
      *out_tr = out.shuffle(perm);
      std::swap(output_row_offset, output_col_offset);
      std::swap(num_rows, num_cols);
    }
    const Matrix& final_out = transpose_output ? *out_tr : out;
    for (int i = 0; i < num_rows; ++i) {
      memcpy(&(*output)(output_row_offset + i, output_col_offset),
             &final_out(i, 0), num_cols * sizeof(float));
    }
  }
}

template <typename TL, typename TR>
inline std::unique_ptr<BlockingCounter>
SparseMatMul<TL, TR>::CreateSparseSlices(
    const typename SparseMatMul<TL, TR>::ConstMatrixMapL& mat, bool transpose,
    int slice_num_rows, int slice_block_size, int slice_num_cols,
    std::vector<std::vector<SparseSlice<TL>*>>* mat_slices,
    const DeviceBase::CpuWorkerThreads* thread_pool) {
  const int mat_num_rows = transpose ? mat.dimension(1) : mat.dimension(0);
  const int mat_num_cols = transpose ? mat.dimension(0) : mat.dimension(1);
  const int num_slices_dim0 =
      std::max(1, (mat_num_rows + slice_num_rows - 1) / slice_num_rows);
  const int num_slices_dim1 =
      std::max(1, (mat_num_cols + slice_num_cols - 1) / slice_num_cols);
  mat_slices->resize(num_slices_dim0);
  BlockingCounter* counter =
      new BlockingCounter(num_slices_dim0 * num_slices_dim1);
  auto work = [counter, transpose](SparseSlice<TL>* sparse_slice,
                                   SparseMatMul<TL, TR>::ConstMatrixMapL* slice,
                                   int col_offset) {
    if (transpose) {
      sparse_slice->template Initialize<true>(*slice, col_offset);
    } else {
      sparse_slice->template Initialize<false>(*slice, col_offset);
    }
    delete slice;
    counter->DecrementCount();
  };
  for (int i = 0; i < num_slices_dim0; ++i) {
    (*mat_slices)[i].resize(num_slices_dim1);
    int num_rows =
        std::min<int>(slice_num_rows, mat_num_rows - i * slice_num_rows);
    for (int j = 0; j < num_slices_dim1; ++j) {
      int num_cols =
          std::min<int>(slice_num_cols, mat_num_cols - j * slice_num_cols);
      SparseMatMul<TL, TR>::ConstMatrixMapL* slice = nullptr;
      if (transpose) {
        slice = new SparseMatMul<TL, TR>::ConstMatrixMapL(
            &mat(0, i * slice_num_rows), mat.dimensions());
      } else {
        DSizes d(num_rows, mat_num_cols);
        slice = new SparseMatMul<TL, TR>::ConstMatrixMapL(
            &mat(i * slice_num_rows, 0), d);
      }
      auto* sparse_slice =
          new SparseSlice<TL>(num_rows, num_cols, slice_block_size);
      (*mat_slices)[i][j] = sparse_slice;
      thread_pool->workers->Schedule(
          [=]() { work(sparse_slice, slice, slice_num_cols * j); });
    }
  }
  return std::unique_ptr<BlockingCounter>(counter);
}
#define LOAD(x) Eigen::internal::ploadu<Packet>((x));
#define INTERLEAVE(x) Eigen::internal::pinterleave4x64<Packet>(x);
#define STORE(x, y) Eigen::internal::pstoreu<float>(x, y);

template <int NUM_ELEM = -1>
ALWAYS_INLINE void CopyAndMayBeInterleaveBfloat16(void* bdst, const void* bsrc,
                                                  int num_elements) {
  DCHECK_GE(kNumOperands, 8);
  static const int kStep = kNumOperands * sizeof(float) / sizeof(bfloat16);
  const int num = (NUM_ELEM == -1) ? num_elements : NUM_ELEM;
  DCHECK_EQ(num, num_elements);
  const float* src = reinterpret_cast<const float*>(bsrc);
  float* dst = reinterpret_cast<float*>(bdst);
  for (int index = 0; index + kStep <= num; index += kStep) {
    auto in = LOAD(src);
    auto tmp = INTERLEAVE(in);
    STORE(dst, tmp);
    src += kNumOperands;
    dst += kNumOperands;
  }
  if (num % kStep != 0) {
    memcpy(dst, src, (num % kStep) * sizeof(bfloat16));
  }
}

template <typename T>
ALWAYS_INLINE void CopyAndMayBeInterleave(void* dst, const void* src,
                                          int num_elements) {
  if (std::is_same<T, float>::value || kNumOperands < 8) {
    memcpy(dst, src, num_elements * sizeof(T));
  } else if (std::is_same<T, bfloat16>::value) {
    if (num_elements == N) {
      CopyAndMayBeInterleaveBfloat16<N>(dst, src, num_elements);
    } else {
      CopyAndMayBeInterleaveBfloat16<-1>(dst, src, num_elements);
    }
  } else {
    LOG(FATAL) << "Unsupported type";
  }
}

#undef LOAD
#undef INTERLEAVE
#undef STORE

template <typename TL, typename TR>
inline BlockingCounter* SparseMatMul<TL, TR>::ShuffleMatrix(
    const typename SparseMatMul<TL, TR>::ConstMatrixMapR& mat,
    int slice_row_start, int slice_num_rows, int slice_col_start,
    int slice_num_cols, const int N,
    const DeviceBase::CpuWorkerThreads* thread_pool, MatrixR* buffer) {
  DCHECK_EQ(N % 2, 0);
  DCHECK_LE(kNumOperands * sizeof(float) / sizeof(TR), N);
  // Note(nikhilsarda): This heuristic is optimal in benchmarks as of
  // Jan 21, 2020.
  int num_threads = std::min(thread_pool->num_threads, 8);
  BlockingCounter* counter = new BlockingCounter(num_threads);
  DCHECK_EQ(N, buffer->dimension(1));
  auto shuffle_work = [&mat, slice_row_start, slice_num_rows, slice_col_start,
                       slice_num_cols, N, buffer, counter](int s, int e) {
    const int row_start = s % slice_num_rows + slice_row_start;
    const int col_start = s / slice_num_rows * N + slice_col_start;
    auto* out_start = &(*buffer)(s, 0);
    const auto* input_start = &mat(row_start, col_start);
    const auto* input_end = &mat(slice_row_start + slice_num_rows - 1,
                                 slice_col_start + slice_num_cols - 1);
    const int mat_num_cols = mat.dimension(1);
    const int row_slice_size = slice_num_rows * mat_num_cols;

    const int aligned_end = slice_num_cols / N * slice_num_rows;
    const int e1 = std::min(e, aligned_end);
    while (s < e1) {
      CopyAndMayBeInterleave<TR>(out_start, input_start, N);
      out_start += N;
      input_start += mat_num_cols;
      if (input_start > input_end) {
        input_start = input_start - row_slice_size + N;
      }
      ++s;
    }
    int s1 = std::max(s, aligned_end);
    const int copy_num_cols = slice_num_cols % N;
    while (s1 < e) {
      CopyAndMayBeInterleave<TR>(out_start, input_start, copy_num_cols);
      out_start += N;
      input_start += mat_num_cols;
      ++s1;
    }
    if (counter) counter->DecrementCount();
  };

  int start = 0;
  int end = 0;
  int num_out_rows = (slice_num_cols + N - 1) / N * slice_num_rows;
  DCHECK_LE(num_out_rows, buffer->dimension(0));
  for (int i = std::max(1, num_threads); i > 0; --i) {
    end = start + num_out_rows / i;
    thread_pool->workers->Schedule([=]() { shuffle_work(start, end); });
    num_out_rows -= (end - start);
    start = end;
  }
  return counter;
}
1290 
template <typename TL, typename TR>
inline void SparseMatMul<TL, TR>::SliceMatrix(
    const MatrixR& mat, const int num_rows, const int num_slices,
    std::vector<typename SparseMatMul<TL, TR>::ConstMatrixMapR*>* slices) {
  slices->resize(num_slices);
  DSizes d(num_rows, mat.dimension(1));
  DCHECK_LE(num_rows * num_slices, mat.dimension(0));
  for (int i = 0; i < num_slices; ++i) {
    (*slices)[i] = new ConstMatrixMapR(&mat(i * num_rows, 0), d);
  }
}

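// Shuffles the (num_rows x num_cols) block of `mat` starting at
// (row_start, col_start) into `buffer` and creates one matrix map per
// N-column slice of that block. Returns the BlockingCounter that signals
// completion of the shuffle; callers must Wait() on it before reading the
// slices.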
template <typename TL, typename TR>
inline std::unique_ptr<BlockingCounter> SparseMatMul<TL, TR>::CreateDenseSlices(
    const typename SparseMatMul<TL, TR>::ConstMatrixMapR& mat, int row_start,
    int num_rows, int col_start, int num_cols,
    const DeviceBase::CpuWorkerThreads* thread_pool, MatrixR* buffer,
    std::vector<typename SparseMatMul<TL, TR>::ConstMatrixMapR*>* slices) {
  std::unique_ptr<BlockingCounter> shuffle_counter(ShuffleMatrix(
      mat, row_start, num_rows, col_start, num_cols, N, thread_pool, buffer));
  const int num_slices = (num_cols + N - 1) / N;
  SliceMatrix(*buffer, num_rows, num_slices, slices);
  return shuffle_counter;
}

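// Picks the block sizes used by Compute (see the algorithm overview further
// below). Worked example with hypothetical shapes, for illustration only:
// for num_threads = 16 and 2048 x 2048 left and right matrices,
// est_num_cores = 8 and mem = 8 * 128 * 1024 = 1048576 floats, which yields
// KR = 2048, NR = 512 (reduced from 2048 so that KR * NR <= mem), KL = 1024,
// JB = 2 and IB = 16.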
template <typename TL, typename TR>
inline void SparseMatMul<TL, TR>::ComputeBlockSizes(
    const typename SparseMatMul<TL, TR>::ConstMatrixMapL& left,
    const typename SparseMatMul<TL, TR>::ConstMatrixMapR& right,
    bool transpose_left, int num_threads, int* KR, int* NR, int* KL, int* JB,
    int* IB) {
  // Heuristics for calculating block sizes
  // Assume two hyperthreads per core.
  const int est_num_cores = std::max(1, (num_threads + 1) / 2);
  // Use block of rhs with at most 128K floats per core.
  const int mem = est_num_cores * 128 * 1024;
  *KR = std::min(static_cast<int>(right.dimension(0)), mem / 256);
  *NR = right.dimension(1);
  if (*KR * *NR > mem) {
    // 4096 may be enough to amortize the cost of writes.
    *KR = std::min<int>(*KR, 4096);
  }
  // Use sizes that are multiples of K and 256.
  *KR = std::max(1, *KR / K) * K;
  *NR = std::max(1, *NR / 256) * 256;
  if (*KR * *NR > mem) {
    *NR = mem / *KR;
  }
  *NR = std::max(1, *NR / 256) * 256;

  const int left_dim0 = transpose_left ? left.dimension(1) : left.dimension(0);
  const int left_dim1 = transpose_left ? left.dimension(0) : left.dimension(1);
  for (*KL = 1024; *KL > K; *KL /= 2) {
    if (*KR % *KL == 0 &&
        std::max<int>(1, left_dim0 / 64) * (left_dim1 / *KL) > est_num_cores) {
      break;
    }
  }
  DCHECK_EQ(*KL % K, 0);
  DCHECK_GE(*KR, *KL);
  if (*KR < right.dimension(0)) {
    CHECK_EQ(*KR % *KL, 0);
  }

  *JB = std::max(1, static_cast<int>(sqrt(num_threads) / 2.0));
  *IB = 8 * *JB;
  DCHECK_EQ(N * sizeof(float) % 64, size_t{0});
}

#ifdef TENSORFLOW_USE_LIBXSMM

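// Runs f(i) for i in [0, num_threads): f(0) on the calling thread and the
// rest on the worker pool, blocking until all invocations have finished.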
template <typename F>
void do_on_all_threads(const DeviceBase::CpuWorkerThreads* thread_pool,
                       const F& f) {
  int num_threads = thread_pool->num_threads;
  if (num_threads == 0) {
    LOG(FATAL) << "Have 0 threads in thread pool";
  } else if (num_threads == 1) {
    f(0);
  } else {
    BlockingCounter counter(num_threads - 1);
    for (int i = 1; i < num_threads; ++i) {
      thread_pool->workers->Schedule([&, i]() {
        f(i);
        counter.DecrementCount();
      });
    }
    f(0);
    counter.Wait();
  }
}

template <typename T>
struct empty_type_wrapper {};

// Wrappers around the libxsmm_spmdm_createSparseSlice_*_thread interface that
// allow overloading on the element type.
void wrapper_libxsmm_spmdm_createSparseSlice_generic_thread(
    empty_type_wrapper<float>, const libxsmm_spmdm_handle* handle, char transA,
    const float* A, libxsmm_CSR_sparseslice* libxsmm_output_csr_a, int block_id,
    int tid, int nthreads) {
  return libxsmm_spmdm_createSparseSlice_fp32_thread(
      handle, transA, A, libxsmm_output_csr_a, block_id, tid, nthreads);
}
void wrapper_libxsmm_spmdm_createSparseSlice_generic_thread(
    empty_type_wrapper<bfloat16>, const libxsmm_spmdm_handle* handle,
    char transA, const bfloat16* A,
    libxsmm_CSR_sparseslice* libxsmm_output_csr_a, int block_id, int tid,
    int nthreads) {
  return libxsmm_spmdm_createSparseSlice_bfloat16_thread(
      handle, transA, reinterpret_cast<const libxsmm_bfloat16*>(A),
      libxsmm_output_csr_a, block_id, tid, nthreads);
}

void wrapper_libxsmm_spmdm_compute_generic_thread(
    empty_type_wrapper<bfloat16>, const libxsmm_spmdm_handle* handle,
    char transA, char transB, const bfloat16* alpha,
    libxsmm_CSR_sparseslice* A_sparse, const bfloat16* B, char transC,
    const bfloat16* beta, float* C, int block_id, int tid, int nthreads) {
  return libxsmm_spmdm_compute_bfloat16_thread(
      handle, transA, transB, reinterpret_cast<const libxsmm_bfloat16*>(alpha),
      A_sparse, reinterpret_cast<const libxsmm_bfloat16*>(B), transC,
      reinterpret_cast<const libxsmm_bfloat16*>(beta), C, block_id, tid,
      nthreads);
}
void wrapper_libxsmm_spmdm_compute_generic_thread(
    empty_type_wrapper<float>, const libxsmm_spmdm_handle* handle, char transA,
    char transB, const float* alpha, libxsmm_CSR_sparseslice* A_sparse,
    const float* B, char transC, const float* beta, float* C, int block_id,
    int tid, int nthreads) {
  return libxsmm_spmdm_compute_fp32_thread(handle, transA, transB, alpha,
                                           A_sparse, B, transC, beta, C,
                                           block_id, tid, nthreads);
}

template <typename TL, typename TR>
inline void LibxsmmSparseMatMul<TL, TR>::Compute(
    typename LibxsmmSparseMatMul<TL, TR>::TensorInfoCache* cache,
    const typename LibxsmmSparseMatMul<TL, TR>::ConstMatrixMapL& left,
    const typename LibxsmmSparseMatMul<TL, TR>::ConstMatrixMapR& right,
    bool transpose_left, const DeviceBase::CpuWorkerThreads* thread_pool,
    bool transpose_output, MatrixMap* output) {
  const int num_threads = thread_pool->num_threads;
  const int left_dim0 = transpose_left ? left.dimension(1) : left.dimension(0);
  const int left_dim1 = transpose_left ? left.dimension(0) : left.dimension(1);
  const int right_dim0 = right.dimension(0);
  const int right_dim1 = right.dimension(1);
  CHECK_EQ(left_dim1, right_dim0);
  CHECK_EQ(left_dim0,
           (transpose_output ? output->dimension(1) : output->dimension(0)));
  CHECK_EQ(right_dim1,
           (transpose_output ? output->dimension(0) : output->dimension(1)));
#if 0  // this issue seems to be resolved
  if (left_dim0 < 32 || left_dim1 < 32 || right_dim1 < 32) {
    // Causes problems in libxsmm
    SparseMatMul<TL, TR>::Compute(
        nullptr /* Assumes no cached data for fallback */, left, right,
        transpose_left, thread_pool, transpose_output, output);
    return;
  }
#endif
  auto left_data = left.data();
  auto right_data = right.data();
  auto output_data = output->data();
  // Initialize libxsmm for this matrix; make sure another thread doesn't use
  // this handle
  auto entry =
      cache->take_cache_entry(left_dim0, right_dim0, right_dim1, num_threads);
  // Convert the left matrix to compressed sparse row (CSR) format
  ptrdiff_t total_num_creation_blocks =
      libxsmm_spmdm_get_num_createSparseSlice_blocks(&entry->handle);
  std::atomic<int> cur_create_block_number;
  cur_create_block_number.store(0);
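  // Threads claim blocks from the shared atomic counter until every block has
  // been processed, which balances the work dynamically across the pool. The
  // multiplication stage below uses the same pattern.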
  do_on_all_threads(thread_pool, [&](int i) {
    while (true) {
      int work_item = cur_create_block_number.fetch_add(1);
      if (work_item >= total_num_creation_blocks) break;
      wrapper_libxsmm_spmdm_createSparseSlice_generic_thread(
          empty_type_wrapper<TL>{}, &entry->handle,
          (transpose_left ? 'T' : 'N'), left_data, entry->output_csr, work_item,
          i, num_threads);
    }
  });
  // Do matrix-matrix multiplication
  ptrdiff_t total_num_mult_blocks =
      libxsmm_spmdm_get_num_compute_blocks(&entry->handle);
  std::atomic<int> cur_mult_block_number;
  cur_mult_block_number.store(0);
  do_on_all_threads(thread_pool, [&](int i) {
    while (true) {
      int work_item = cur_mult_block_number.fetch_add(1);
      if (work_item >= total_num_mult_blocks) break;
      const TL alpha(1.0);  // Stored in a variable so we can get a pointer
      const TL beta(0.0);   // Stored in a variable so we can get a pointer
      wrapper_libxsmm_spmdm_compute_generic_thread(
          empty_type_wrapper<TL>{}, &entry->handle,
          (transpose_left ? 'T' : 'N'), 'N', &alpha, entry->output_csr,
          right_data, (transpose_output ? 'T' : 'N'), &beta, output_data,
          work_item, i, num_threads);
    }
  });
  // Put handle + CSR storage back into cache
  cache->return_cache_entry(std::move(entry));
}

#endif  // TENSORFLOW_USE_LIBXSMM


// Here is an overview of the SparseMatMul code. Note that we assume that the
// left matrix is sparse.
//
// The matrix "left" is divided into a grid with blocksize of (M, KL). Each
// block is encoded as a SparseSlice. These grid elements are stored as
// std::vector<std::vector<SparseSlice>>. Each element of the outer vector
// represents M rows of the left matrix. Let's call these elements l_i and
// let's call each element of the inner vector L_mk.
//
// The matrix "right" is divided into a grid with block size KR * NR. Let's
// denote the blocks on the right as R_kn. Note that we ensure that KL divides
// KR so that for each element R_kn, we don't need to multiply it with any
// partial L_mk blocks.
//
// We then multiply each right side block R_kn with the full "left" matrix and
// update the output. These iterations are run sequentially since R_kn are
// packed into the same underlying temporary buffer.
//
// In each iteration we do the following:
// 1. Create slices r_j of R_kn: We split R_kn into vertical blocks with N
//    (=128) columns and then concatenate these slices into a buffer. This is
//    done so that each slice r_j of R_kn is stored contiguously in memory.
//    Note that if R_kn has dimensions (KR, NR), we create NR / N slices, and
//    the buffer has dimensions (KR * NR / N, N) (assuming N divides NR).
// 2. For each (l_i, r_j), we compute the inner product using the GEPP function
//    and update the output block o_ij. These calls are further blocked to
//    reduce the working set size. In each iteration we take IB elements from
//    {l_i} and JB elements from {r_j} and compute the IB * JB inner products.
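//
// As a concrete (hypothetical) illustration of the blocking: with M = 64,
// N = 128 and block sizes KR = 2048, NR = 512, KL = 1024, a 2048 x 2048 left
// matrix is encoded as 32 row groups l_i of 2048 / KL = 2 SparseSlices each,
// each R_kn block of the right matrix is shuffled into a buffer of
// NR / N = 4 contiguous slices r_j of 128 columns, and step 2 computes the
// 32 x 4 output blocks o_ij for that R_kn, IB x JB at a time.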
template <typename TL, typename TR>
inline void SparseMatMul<TL, TR>::Compute(
    typename SparseMatMul<TL, TR>::TensorInfoCache* /*cache*/,
    const typename SparseMatMul<TL, TR>::ConstMatrixMapL& left,
    const typename SparseMatMul<TL, TR>::ConstMatrixMapR& right,
    bool transpose_left, const DeviceBase::CpuWorkerThreads* thread_pool,
    bool transpose_output, MatrixMap* output) {
  const int num_threads = thread_pool->num_threads;
  int KR, NR, KL, JB, IB;
  ComputeBlockSizes(left, right, transpose_left, num_threads, &KR, &NR, &KL,
                    &JB, &IB);
  // Slice the left matrix
  std::vector<std::vector<SparseSlice<TL>*>> left_slices;
  std::unique_ptr<BlockingCounter> sparse_slice_counter =
      CreateSparseSlices(ConstMatrixMapL(left.data(), left.dimensions()),
                         transpose_left, M, K, KL, &left_slices, thread_pool);
  const int num_left_slices = left_slices.size();

  const int right_dim0 = right.dimension(0);
  const int right_dim1 = right.dimension(1);
  // Allocate buffer for storing slices of right matrix.
  // Note buffer needs enough space to hold at most a KR * NR matrix since that
  // is the block size per iteration.
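  // For example, with KR = 2048, NR = 512, N = 128 (hypothetical values, see
  // ComputeBlockSizes), the buffer has 2048 * (512 / 128) = 8192 rows of 128
  // floats, i.e. KR * NR = 1048576 floats in total.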
  const int buffer_num_rows =
      std::min(KR, right_dim0) * ((std::min(NR, right_dim1) + N - 1) / N);
  MatrixR buffer(buffer_num_rows, N);
  std::vector<ConstMatrixMapR*> right_slices;

  std::vector<SparseSlice<TL>*> block_left_slices;
  std::vector<std::function<void(void)>> tasks;
  // Number of blocks based on block sizes of KR * NR.
  const int num_k_blocks = (right_dim0 + KR - 1) / KR;
  const int num_n_blocks = (right_dim1 + NR - 1) / NR;
  std::unique_ptr<BlockingCounter> dense_slice_counter;

  for (int nb = 0; nb < num_n_blocks; ++nb) {
    const int right_num_cols =
        std::min(NR, static_cast<int>(right_dim1 - NR * nb));
    for (int kb = 0; kb < num_k_blocks; ++kb) {
      const int right_num_rows =
          std::min(KR, static_cast<int>(right_dim0 - KR * kb));
      dense_slice_counter = CreateDenseSlices(
          right, kb * KR, right_num_rows, nb * NR, right_num_cols, thread_pool,
          &buffer, &right_slices);
      const int num_right_slices = right_slices.size();
      tasks.reserve(num_left_slices * num_right_slices);
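      // Enqueue the output-block tasks in IB x JB groups so that each group
      // reuses the same IB left slices and JB right slices while they stay
      // cache resident (see step 2 of the overview above).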
      for (int j_outer = 0; j_outer < num_right_slices; j_outer += JB) {
        for (int i_outer = 0; i_outer < num_left_slices; i_outer += IB) {
          for (int j_inner = j_outer;
               j_inner < std::min(num_right_slices, j_outer + JB); ++j_inner) {
            const int num_cols = std::min(N, right_num_cols - N * j_inner);
            for (int i_inner = i_outer;
                 i_inner < std::min(num_left_slices, i_outer + IB); ++i_inner) {
              block_left_slices.clear();
              int begin = kb * KR / KL;
              int end = std::min<int>((kb + 1) * KR / KL,
                                      (right.dimension(0) + KL - 1) / KL);
              DCHECK_LT(begin, end);
              block_left_slices.insert(block_left_slices.begin(),
                                       left_slices[i_inner].begin() + begin,
                                       left_slices[i_inner].begin() + end);
              tasks.push_back(std::bind(
                  &ComputeOutputBlock, block_left_slices,
                  std::ref(*right_slices[j_inner]), num_cols, M * i_inner,
                  N * j_inner + nb * NR, kb == 0, transpose_output, output));
            }
          }
        }
      }
      if (sparse_slice_counter) {
        sparse_slice_counter->Wait();
        sparse_slice_counter.reset(nullptr);
      }
      if (dense_slice_counter) {
        dense_slice_counter->Wait();
        dense_slice_counter.reset(nullptr);
      }
      BlockingCounter bc(tasks.size());
      for (const auto& t : tasks) {
        thread_pool->workers->Schedule([&bc, &t]() {
          t();
          bc.DecrementCount();
        });
      }
      bc.Wait();
      tasks.clear();
      for (auto& temp : right_slices) {
        delete temp;
      }
      right_slices.clear();
    }
  }
  for (auto& left_slice : left_slices) {
    for (auto& temp : left_slice) {
      delete temp;
    }
    left_slice.clear();
  }
}

#define REGISTER_SPARSE_MATMUL(TA, TB)                   \
  REGISTER_KERNEL_BUILDER(Name("SparseMatMul")           \
                              .Device(DEVICE_CPU)        \
                              .TypeConstraint<TA>("Ta")  \
                              .TypeConstraint<TB>("Tb"), \
                          SparseMatMulOp<TA, TB, SparseMatMul>);
#ifdef TENSORFLOW_USE_LIBXSMM
#define REGISTER_SPARSE_MATMUL_LIBXSMM(TA, TB)           \
  REGISTER_KERNEL_BUILDER(Name("SparseMatMul")           \
                              .Device(DEVICE_CPU)        \
                              .TypeConstraint<TA>("Ta")  \
                              .TypeConstraint<TB>("Tb"), \
                          SparseMatMulOp<TA, TB, LibxsmmSparseMatMul>);
#endif

REGISTER_SPARSE_MATMUL(float, bfloat16);
REGISTER_SPARSE_MATMUL(bfloat16, float);

#ifdef TENSORFLOW_USE_LIBXSMM
REGISTER_SPARSE_MATMUL_LIBXSMM(bfloat16, bfloat16);
REGISTER_SPARSE_MATMUL_LIBXSMM(float, float);
#else
REGISTER_SPARSE_MATMUL(bfloat16, bfloat16);
REGISTER_SPARSE_MATMUL(float, float);
#endif

#undef REGISTER_SPARSE_MATMUL

}  // end namespace tensorflow