1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // See docs in ../ops/math_ops.cc.
17 
18 #define EIGEN_USE_THREADS
19 
20 #include "tensorflow/core/kernels/sparse_matmul_op.h"
21 
#include <cstdint>
#include <cstring>
#include <map>
#include <memory>
#include <vector>
25 
26 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
27 #include "tensorflow/core/common_runtime/device.h"
28 #include "tensorflow/core/framework/bfloat16.h"
29 #include "tensorflow/core/framework/op.h"
30 #include "tensorflow/core/framework/op_kernel.h"
31 #include "tensorflow/core/framework/types.h"
32 #include "tensorflow/core/kernels/fill_functor.h"
33 #include "tensorflow/core/lib/core/blocking_counter.h"
34 #include "tensorflow/core/lib/core/threadpool.h"
35 #include "tensorflow/core/lib/gtl/stl_util.h"
36 #include "tensorflow/core/platform/logging.h"
37 #include "tensorflow/core/platform/macros.h"
38 #include "tensorflow/core/platform/mutex.h"
39 #include "tensorflow/core/platform/thread_annotations.h"
40 #include "tensorflow/core/platform/types.h"
41 #ifdef TENSORFLOW_USE_LIBXSMM
42 #include "include/libxsmm_intrinsics_x86.h"
43 #include "include/libxsmm_malloc.h"
44 #include "include/libxsmm_spmdm.h"
45 #endif
46 
47 #if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
48 #include "tensorflow/core/kernels/eigen_contraction_kernel.h"
49 #endif
50 
// Shorthand for Eigen's force-inline attribute; applied to the small packet
// helpers in this file.
#define ALWAYS_INLINE EIGEN_ALWAYS_INLINE
52 
53 namespace tensorflow {
54 namespace {
55 
56 template <typename T>
57 using BasicMatrix = Eigen::Tensor<T, 2, Eigen::RowMajor>;
58 
59 template <typename T>
60 using BasicMatrixMap =
61     Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>, Eigen::Aligned>;
62 
63 using Matrix = BasicMatrix<float>;
64 using MatrixMap = BasicMatrixMap<float>;
65 using CPUDevice = Eigen::ThreadPoolDevice;
66 using DSizes = Eigen::DSizes<Eigen::DenseIndex, 2>;
67 
68 // Two commonly used static dsizes. We use Eigen::type2index to allow as much
69 // compile time optimization as possible.
70 #ifdef EIGEN_HAS_INDEX_LIST
71 inline Eigen::IndexList<Eigen::type2index<0>, Eigen::type2index<0>>
dsizes_00()72 dsizes_00() {
73   return Eigen::IndexList<Eigen::type2index<0>, Eigen::type2index<0>>();
74 }
75 inline Eigen::IndexList<Eigen::type2index<1>, Eigen::type2index<0>>
dsizes_10()76 dsizes_10() {
77   return Eigen::IndexList<Eigen::type2index<1>, Eigen::type2index<0>>();
78 }
79 #else
dsizes_00()80 inline DSizes dsizes_00() { return DSizes(0, 0); }
dsizes_10()81 inline DSizes dsizes_10() { return DSizes(1, 0); }
82 #endif
83 
// Blocking sizes used by the sparse kernels in this file. They are compile-time
// constants because GEPP sizes stack arrays with them.
// TODO(agarwal): compute these sizes based on cache sizes.
constexpr int K = 64;
constexpr int M = 64;
constexpr int N = 128;
89 
90 // This stores a sparse representation of a slice of a matrix with size
91 // (num_rows, num_cols). The slice is represented as a series of blocks of size
92 // (num_rows, b), where b = block_size for all but the last block, which may
93 // have fewer columns.
94 //
95 // num_rows and block_size are assumed to be <= 256. This allows storing
96 // different indices as uint8.
97 //
98 // For each block, we store all the non zero entries in data/data3 vector and
99 // the corresponding coordinates of the element in index/index3 vectors. index3
100 // vector stores index of 3 elements in the same row so that these elements can
101 // share the same row coordinate. Each entry in Index3 corresponds to 3 entries
102 // in data3.
103 //
104 // Note that all the data/indices of all the blocks are stored in the same
105 // vectors respectively. To identify block boundaries, we store the block
106 // offsets using index3_offset/index_offset. If there are n blocks in the slice,
107 // index3_offset and index_offset have n entries. The indices for the ith block
108 // are the values in the following range:
109 // [index3[index3_offset[i-1]], index3[index3_offset[i]]). Similarly for
110 // index_offset.
111 template <typename T>
112 struct SparseSlice {
113   using ConstMatrixMap = BasicMatrixMap<const T>;
114 
115  public:
116   // Indices of three elements on the same row.
117   struct Index3 {
118     uint8 m;  // row
119     // columns
120     uint8 k1;
121     uint8 k2;
122     uint8 k3;
123   };
124 
125   // Index of one element.
126   struct Index {
127     uint8 m;
128     uint8 k;
129   };
130 
SparseSlicetensorflow::__anon00ddf21f0111::SparseSlice131   SparseSlice(int nrows, int ncols, int bsize)
132       : num_rows(nrows), num_cols(ncols), block_size(bsize) {
133     DCHECK_LE(nrows, 256);
134     DCHECK_LE(block_size, 256);
135   }
136 
137   // Initializes the slice with data starting at mat(0, col_offset) and with
138   // size (num_rows, num_cols).
139   // If Transpose is true, implicitly transposes mat.
140   template <bool Transpose = false>
141   void Initialize(const ConstMatrixMap& mat, int col_offset);
142 
143   void Clear();
144 
145   // See comments above.
146   std::vector<int> index3_offset;
147   std::vector<Index3> index3;
148   std::vector<T> data3;
149 
150   // See comments above. Similar to "index3" except that each element in "index"
151   // corresponds to one element in data.
152   std::vector<int> index_offset;
153   std::vector<Index> index;
154   std::vector<T> data;
155 
156   // Number of rows and columns for the slice.
157   const int num_rows;
158   const int num_cols;
159 
160   // Block size used to initialize from a matrix.
161   const int block_size;
162 };
163 
164 template <typename T>
165 bool IsZero(T v);
166 
167 template <>
IsZero(bfloat16 v)168 ALWAYS_INLINE bool IsZero(bfloat16 v) {
169   return v.IsZero();
170 }
171 
172 template <>
IsZero(float v)173 ALWAYS_INLINE bool IsZero(float v) {
174   return v == 0.0f;
175 }
176 
// Encodes the window of "mat" made of rows [0, num_rows) and columns
// [col_offset, col_offset + num_cols) into this slice. With Transpose, the
// logical element (r, c) is read from mat(c, r).
//
// The window is processed one block of up to block_size columns at a time,
// row by row. The non-zeros of a row are emitted three at a time into
// data3/index3. When one or two non-zeros remain at the end of the row, they
// go to data/index instead; if there are two, the later one is pushed first.
// After each block, the current sizes of index3 and index are appended to
// index3_offset and index_offset.
//
// NOTE(review): nothing here clears the vectors first, so entries are
// appended; presumably a reused slice is Clear()ed by the caller — confirm.
template <typename T>
template <bool Transpose>
void SparseSlice<T>::Initialize(
    const typename SparseSlice<T>::ConstMatrixMap& mat, int col_offset) {
  const int mat_rows = Transpose ? mat.dimension(1) : mat.dimension(0);
  const int mat_cols = Transpose ? mat.dimension(0) : mat.dimension(1);
  DCHECK_LE(num_rows, mat_rows);
  DCHECK_LE(num_cols + col_offset, mat_cols);

  int num_blocks = (num_cols + block_size - 1) / block_size;
  int mat_size = num_rows * num_cols;

  // Capacity hints for the triple stream: at most mat_size values can be
  // non-zero.
  index3_offset.reserve(num_blocks);
  data3.reserve(mat_size);
  index3.reserve(mat_size / 3);

  // Each row of each block leaves at most two values for the single stream.
  index_offset.reserve(num_blocks);
  data.reserve(num_blocks * num_rows * 2);
  index.reserve(num_blocks * num_rows * 2);

  Index3 idx3;
  // Element distance between consecutive logical columns of one row.
  const int stride = Transpose ? mat.dimension(1) : 1;

  for (int i = 0; i < num_blocks; ++i) {
    // The last block may be narrower than block_size.
    int num_block_cols = std::min(block_size, num_cols - block_size * i);
    for (int row = 0; row < num_rows; ++row) {
      idx3.m = static_cast<uint8>(row);
      // Safety note: The following code has a race, since it checks whether
      // *curr is nonzero and then reads it again on use.  However, the result
      // of the race is only that some of the "nonzeros" in the resulting sparse
      // representation may actually be zero, which is harmless.
      const auto* start =
          Transpose ? &mat(col_offset, row) : &mat(row, col_offset);
      const auto* curr = start;
      const auto* end = start + stride * num_block_cols;
      // Column index within the current block. It may wrap after the last
      // column, but is not read again once curr >= end.
      uint8 k = 0;
      // NEXT_ELEM advances to the next logical column; EAT_ZEROS skips over
      // zero entries until a non-zero or the end of the row is reached.
#define NEXT_ELEM \
  curr += stride; \
  ++k;
#define EAT_ZEROS                          \
  while (curr < end && IsZero<T>(*curr)) { \
    NEXT_ELEM;                             \
  }
      while (true) {
        // First non-zero of a candidate triple.
        EAT_ZEROS
        if (curr >= end) break;
        idx3.k1 = k;
        const T value1 = *curr;
        NEXT_ELEM;

        // Second non-zero; if the row ends here, value1 goes to the single
        // stream.
        EAT_ZEROS
        if (curr >= end) {
          data.push_back(value1);
          index.push_back({idx3.m, idx3.k1});
          break;
        }
        idx3.k2 = k;
        const T value2 = *curr;
        NEXT_ELEM;

        // Third non-zero; if the row ends here, value2 and value1 (in that
        // order) go to the single stream.
        EAT_ZEROS
        if (curr >= end) {
          data.push_back(value2);
          index.push_back({idx3.m, idx3.k2});
          data.push_back(value1);
          index.push_back({idx3.m, idx3.k1});
          break;
        }
        // A full triple on this row.
        idx3.k3 = k;
        data3.push_back(value1);
        data3.push_back(value2);
        data3.push_back(*curr);
        NEXT_ELEM;
        index3.push_back(idx3);
#undef NEXT_ELEM
#undef EAT_ZEROS
      }
    }
    // Move to the next block and record where this block's entries end.
    col_offset += block_size;
    index3_offset.push_back(index3.size());
    index_offset.push_back(index.size());
  }
  DCHECK_EQ(index3_offset.size(), num_blocks);
  DCHECK_EQ(index_offset.size(), num_blocks);
  DCHECK_EQ(3 * index3.size(), data3.size());
  DCHECK_EQ(index.size(), data.size());
}
264 
265 template <typename T>
Clear()266 void SparseSlice<T>::Clear() {
267   index3_offset.clear();
268   index3.clear();
269   data3.clear();
270   index_offset.clear();
271   index.clear();
272   data.clear();
273 }
274 
275 using Packet = Eigen::internal::packet_traits<float>::type;
276 const int kNumOperands = (sizeof(Packet) / sizeof(float));
277 #define LOAD(x) Eigen::internal::pload<Packet>(x);
278 #define EXPAND_BFLOAT_L(x, y) \
279   const auto y = Eigen::internal::pexpand_bf16_l<Packet>(x);
280 #define EXPAND_BFLOAT_U(x, y) \
281   const auto y = Eigen::internal::pexpand_bf16_u<Packet>(x);
282 #define STORE(x, y) Eigen::internal::pstore<float>(x, y);
283 #define FMA(a, b, c, d) d = Eigen::internal::pmadd<Packet>(a, b, c);
284 
ConvertBfloat16ToFloat(const bfloat16 * src)285 ALWAYS_INLINE float ConvertBfloat16ToFloat(const bfloat16* src) {
286   float out = 0;
287   auto tmp = reinterpret_cast<bfloat16*>(&out);
288 #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
289   tmp[0] = *src;
290 #else
291   tmp[1] = *src;
292 #endif
293   return out;
294 }
295 
ConvertFourBfloat16ToFloat(const bfloat16 * src)296 ALWAYS_INLINE Packet ConvertFourBfloat16ToFloat(const bfloat16* src) {
297   return Eigen::internal::pload4bf16<Packet>(
298       reinterpret_cast<const float*>(src));
299 }
300 
ConvertTwoBfloat16ToFloat(const bfloat16 * src)301 ALWAYS_INLINE Packet ConvertTwoBfloat16ToFloat(const bfloat16* src) {
302   return Eigen::internal::pload2bf16<Packet>(
303       reinterpret_cast<const float*>(src));
304 }
305 
ScalarMulAdd(const float a,const float ** inp,float ** out)306 ALWAYS_INLINE void ScalarMulAdd(const float a, const float** inp, float** out) {
307   **out += a * **inp;
308   ++*inp;
309   ++*out;
310 }
311 
ScalarMulAdd(const float a,const bfloat16 ** inp,float ** out)312 ALWAYS_INLINE void ScalarMulAdd(const float a, const bfloat16** inp,
313                                 float** out) {
314   float inp_f = ConvertBfloat16ToFloat(*inp);
315   **out += a * inp_f;
316   ++*inp;
317   ++*out;
318 }
ScalarMulAdd3Way(const float a1,const float a2,const float a3,const bfloat16 ** inp1,const bfloat16 ** inp2,const bfloat16 ** inp3,float ** out)319 ALWAYS_INLINE void ScalarMulAdd3Way(const float a1, const float a2,
320                                     const float a3, const bfloat16** inp1,
321                                     const bfloat16** inp2,
322                                     const bfloat16** inp3, float** out) {
323   float inp1_f = ConvertBfloat16ToFloat(*inp1);
324   float inp2_f = ConvertBfloat16ToFloat(*inp2);
325   float inp3_f = ConvertBfloat16ToFloat(*inp3);
326   **out += a1 * inp1_f + a2 * inp2_f + a3 * inp3_f;
327   ++*out;
328   ++*inp1;
329   ++*inp2;
330   ++*inp3;
331 }
332 
ScalarMulAdd3Way(const float a1,const float a2,const float a3,const float ** inp1,const float ** inp2,const float ** inp3,float ** out)333 ALWAYS_INLINE void ScalarMulAdd3Way(const float a1, const float a2,
334                                     const float a3, const float** inp1,
335                                     const float** inp2, const float** inp3,
336                                     float** out) {
337   **out += a1 * **inp1 + a2 * **inp2 + a3 * **inp3;
338   ++*out;
339   ++*inp1;
340   ++*inp2;
341   ++*inp3;
342 }
343 
LoadSingleScalar(const bfloat16 ** data,Packet * l)344 ALWAYS_INLINE void LoadSingleScalar(const bfloat16** data, Packet* l) {
345   auto tmp = ConvertBfloat16ToFloat(*data);
346   *l = Eigen::internal::pset1<Packet>(tmp);
347   ++*data;
348 }
349 
LoadTwoScalars(const bfloat16 ** data,Packet * l1,Packet * l2)350 ALWAYS_INLINE void LoadTwoScalars(const bfloat16** data, Packet* l1,
351                                   Packet* l2) {
352   if (kNumOperands >= 2) {
353     auto tmp = ConvertTwoBfloat16ToFloat(*data);
354     *l1 = Eigen::internal::pbroadcast_first<Packet>(tmp);
355     *l2 = Eigen::internal::pbroadcast_second<Packet>(tmp);
356     *data += 2;
357   } else {
358     LoadSingleScalar(data, l1);
359     LoadSingleScalar(data, l2);
360   }
361 }
362 
LoadFourScalars(const bfloat16 ** data,Packet * l1,Packet * l2,Packet * l3,Packet * l4)363 ALWAYS_INLINE void LoadFourScalars(const bfloat16** data, Packet* l1,
364                                    Packet* l2, Packet* l3, Packet* l4) {
365   if (kNumOperands >= 4) {
366     auto tmp = ConvertFourBfloat16ToFloat(*data);
367     *l1 = Eigen::internal::pbroadcast_first<Packet>(tmp);
368     *l2 = Eigen::internal::pbroadcast_second<Packet>(tmp);
369     *l3 = Eigen::internal::pbroadcast_third<Packet>(tmp);
370     *l4 = Eigen::internal::pbroadcast_fourth<Packet>(tmp);
371     *data += 4;
372   } else {
373     LoadTwoScalars(data, l1, l2);
374     LoadTwoScalars(data, l3, l4);
375   }
376 }
377 
LoadSingleScalar(const float ** data,Packet * l)378 ALWAYS_INLINE void LoadSingleScalar(const float** data, Packet* l) {
379   *l = Eigen::internal::pload1<Packet>(*data);
380   ++(*data);
381 }
382 
LoadTwoScalars(const float ** data,Packet * l1,Packet * l2)383 ALWAYS_INLINE void LoadTwoScalars(const float** data, Packet* l1, Packet* l2) {
384   LoadSingleScalar(data, l1);
385   LoadSingleScalar(data, l2);
386 }
387 
LoadFourScalars(const float ** data,Packet * l1,Packet * l2,Packet * l3,Packet * l4)388 ALWAYS_INLINE void LoadFourScalars(const float** data, Packet* l1, Packet* l2,
389                                    Packet* l3, Packet* l4) {
390   LoadTwoScalars(data, l1, l2);
391   LoadTwoScalars(data, l3, l4);
392 }
393 
394 template <typename T>
LoadThreeScalars(const T ** data,Packet * l1,Packet * l2,Packet * l3)395 ALWAYS_INLINE void LoadThreeScalars(const T** data, Packet* l1, Packet* l2,
396                                     Packet* l3) {
397   LoadTwoScalars(data, l1, l2);
398   LoadSingleScalar(data, l3);
399 }
400 
401 template <typename T>
LoadSixScalars(const T ** data,Packet * l1,Packet * l2,Packet * l3,Packet * l4,Packet * l5,Packet * l6)402 ALWAYS_INLINE void LoadSixScalars(const T** data, Packet* l1, Packet* l2,
403                                   Packet* l3, Packet* l4, Packet* l5,
404                                   Packet* l6) {
405   LoadFourScalars(data, l1, l2, l3, l4);
406   LoadTwoScalars(data, l5, l6);
407 }
408 
409 // Vectorized version of ScalarMulAdd.
MulAdd(const Packet a,const bfloat16 ** binp,float ** out)410 ALWAYS_INLINE void MulAdd(const Packet a, const bfloat16** binp, float** out) {
411   auto inp = reinterpret_cast<const float*>(*binp);
412   const auto b = LOAD(inp);
413   EXPAND_BFLOAT_L(b, b_0);
414   EXPAND_BFLOAT_U(b, b_1);
415   *binp += 2 * kNumOperands;
416   auto c1 = LOAD(*out);
417   auto c2 = LOAD(*out + kNumOperands);
418   FMA(a, b_0, c1, c1);
419   FMA(a, b_1, c2, c2);
420   STORE(*out, c1);
421   STORE(*out + kNumOperands, c2);
422   *out += 2 * kNumOperands;
423 }
424 
425 // Vectorized version of ScalarMulAdd3Way.
MulAdd3Way(const Packet a1,const Packet a2,const Packet a3,const bfloat16 ** binp1,const bfloat16 ** binp2,const bfloat16 ** binp3,float ** out)426 ALWAYS_INLINE void MulAdd3Way(const Packet a1, const Packet a2, const Packet a3,
427                               const bfloat16** binp1, const bfloat16** binp2,
428                               const bfloat16** binp3, float** out) {
429   auto inp1 = reinterpret_cast<const float*>(*binp1);
430   auto inp2 = reinterpret_cast<const float*>(*binp2);
431   auto inp3 = reinterpret_cast<const float*>(*binp3);
432   auto c1 = LOAD(*out);
433   auto c2 = LOAD(*out + kNumOperands);
434   const auto b1 = LOAD(inp1);
435   EXPAND_BFLOAT_L(b1, b1_0);
436   EXPAND_BFLOAT_U(b1, b1_1);
437   *binp1 += 2 * kNumOperands;
438   const auto b2 = LOAD(inp2);
439   EXPAND_BFLOAT_L(b2, b2_0);
440   EXPAND_BFLOAT_U(b2, b2_1);
441   *binp2 += 2 * kNumOperands;
442   const auto b3 = LOAD(inp3);
443   EXPAND_BFLOAT_L(b3, b3_0);
444   EXPAND_BFLOAT_U(b3, b3_1);
445   *binp3 += 2 * kNumOperands;
446   FMA(a1, b1_0, c1, c1);
447   FMA(a1, b1_1, c2, c2);
448   FMA(a2, b2_0, c1, c1);
449   FMA(a2, b2_1, c2, c2);
450   FMA(a3, b3_0, c1, c1);
451   FMA(a3, b3_1, c2, c2);
452   STORE(*out, c1);
453   STORE(*out + kNumOperands, c2);
454   *out += 2 * kNumOperands;
455 }
456 
// Two-iteration unroll of the bfloat16 MulAdd3Way. It adds
// a1 * in1 + a2 * in2 + a3 * in3 to the 4 * kNumOperands floats at *out (four
// output packets c1..c4). Each input supplies 4 * kNumOperands bfloat16 values,
// read as two float-sized packet loads and expanded into four float packets.
// All four cursors advance by 4 * kNumOperands elements. Loads for both
// iterations are issued before the FMAs, presumably to overlap memory latency.
ALWAYS_INLINE void TwoMulAdd3Way(const Packet a1, const Packet a2,
                                 const Packet a3, const bfloat16** binp1,
                                 const bfloat16** binp2, const bfloat16** binp3,
                                 float** out) {
  auto inp1 = reinterpret_cast<const float*>(*binp1);
  auto inp2 = reinterpret_cast<const float*>(*binp2);
  auto inp3 = reinterpret_cast<const float*>(*binp3);
  // First iteration: output packets c1, c2 and the first packet of each input.
  auto c1 = LOAD(*out);
  auto c2 = LOAD(*out + kNumOperands);
  const auto b1 = LOAD(inp1);
  const auto b2 = LOAD(inp2);
  const auto b3 = LOAD(inp3);

  EXPAND_BFLOAT_L(b1, b1_0);
  EXPAND_BFLOAT_U(b1, b1_1);
  EXPAND_BFLOAT_L(b2, b2_0);
  EXPAND_BFLOAT_U(b2, b2_1);
  EXPAND_BFLOAT_L(b3, b3_0);
  EXPAND_BFLOAT_U(b3, b3_1);
  // Second iteration: output packets c3, c4 and the next packet of each input.
  auto c3 = LOAD(*out + 2 * kNumOperands);
  auto c4 = LOAD(*out + 3 * kNumOperands);
  const auto b4 = LOAD(inp1 + kNumOperands);
  const auto b5 = LOAD(inp2 + kNumOperands);
  const auto b6 = LOAD(inp3 + kNumOperands);

  EXPAND_BFLOAT_L(b4, b4_0);
  EXPAND_BFLOAT_U(b4, b4_1);
  EXPAND_BFLOAT_L(b5, b5_0);
  EXPAND_BFLOAT_U(b5, b5_1);
  EXPAND_BFLOAT_L(b6, b6_0);
  EXPAND_BFLOAT_U(b6, b6_1);

  // Each output packet accumulates its products in the order a1, a2, a3.
  FMA(a1, b1_0, c1, c1);
  FMA(a1, b1_1, c2, c2);
  FMA(a1, b4_0, c3, c3);
  FMA(a1, b4_1, c4, c4);
  FMA(a2, b2_0, c1, c1);
  FMA(a2, b2_1, c2, c2);
  FMA(a2, b5_0, c3, c3);
  FMA(a2, b5_1, c4, c4);
  FMA(a3, b3_0, c1, c1);
  FMA(a3, b3_1, c2, c2);
  FMA(a3, b6_0, c3, c3);
  FMA(a3, b6_1, c4, c4);
  STORE(*out, c1);
  STORE(*out + kNumOperands, c2);
  STORE(*out + 2 * kNumOperands, c3);
  STORE(*out + 3 * kNumOperands, c4);
  *out += 4 * kNumOperands;
  *binp1 += 4 * kNumOperands;
  *binp2 += 4 * kNumOperands;
  *binp3 += 4 * kNumOperands;
}
511 
512 // Apply MulAdd3Way on 128 operands.
MulAdd3Way128(const Packet a1,const Packet a2,const Packet a3,const bfloat16 ** inp1,const bfloat16 ** inp2,const bfloat16 ** inp3,float ** out)513 ALWAYS_INLINE void MulAdd3Way128(const Packet a1, const Packet a2,
514                                  const Packet a3, const bfloat16** inp1,
515                                  const bfloat16** inp2, const bfloat16** inp3,
516                                  float** out) {
517   for (int k = 0; k < 128 / (8 * kNumOperands); ++k) {
518     TwoMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
519     TwoMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
520   }
521 }
522 
523 // Vectorized version of ScalarMulAdd
MulAdd(const Packet a,const float ** inp,float ** out)524 ALWAYS_INLINE void MulAdd(const Packet a, const float** inp, float** out) {
525   const auto b = LOAD(*inp);
526   *inp += kNumOperands;
527   auto c = LOAD(*out);
528   FMA(a, b, c, c);
529   STORE(*out, c);
530   *out += kNumOperands;
531 }
532 
533 // Vectorized version of ScalarMulAdd3Way
MulAdd3Way(const Packet a1,const Packet a2,const Packet a3,const float ** inp1,const float ** inp2,const float ** inp3,float ** out)534 ALWAYS_INLINE void MulAdd3Way(const Packet a1, const Packet a2, const Packet a3,
535                               const float** inp1, const float** inp2,
536                               const float** inp3, float** out) {
537   auto c = LOAD(*out);
538   const auto b1 = LOAD(*inp1);
539   *inp1 += kNumOperands;
540   const auto b2 = LOAD(*inp2);
541   *inp2 += kNumOperands;
542   const auto b3 = LOAD(*inp3);
543   *inp3 += kNumOperands;
544   FMA(a1, b1, c, c);
545   FMA(a2, b2, c, c);
546   FMA(a3, b3, c, c);
547   STORE(*out, c);
548   *out += kNumOperands;
549 }
550 
551 // Unroll MulAdd3Way for two iterations
TwoMulAdd3Way(const Packet a1,const Packet a2,const Packet a3,const float ** inp1,const float ** inp2,const float ** inp3,float ** out)552 ALWAYS_INLINE void TwoMulAdd3Way(const Packet a1, const Packet a2,
553                                  const Packet a3, const float** inp1,
554                                  const float** inp2, const float** inp3,
555                                  float** out) {
556   auto c1 = LOAD(*out);
557   const auto b1 = LOAD(*inp1);
558   const auto b2 = LOAD(*inp2);
559   const auto b3 = LOAD(*inp3);
560 
561   auto c2 = LOAD(*out + kNumOperands);
562   const auto b4 = LOAD(*inp1 + kNumOperands);
563   const auto b5 = LOAD(*inp2 + kNumOperands);
564   const auto b6 = LOAD(*inp3 + kNumOperands);
565 
566   FMA(a1, b1, c1, c1);
567   FMA(a1, b4, c2, c2);
568   FMA(a2, b2, c1, c1);
569   FMA(a2, b5, c2, c2);
570   FMA(a3, b3, c1, c1);
571   FMA(a3, b6, c2, c2);
572   STORE(*out, c1);
573   STORE(*out + kNumOperands, c2);
574   *out += 2 * kNumOperands;
575   *inp1 += 2 * kNumOperands;
576   *inp2 += 2 * kNumOperands;
577   *inp3 += 2 * kNumOperands;
578 }
579 
580 // Unroll MulAdd3Way for four iterations
FourMulAdd3Way(const Packet a1,const Packet a2,const Packet a3,const float ** inp1,const float ** inp2,const float ** inp3,float ** out)581 ALWAYS_INLINE void FourMulAdd3Way(const Packet a1, const Packet a2,
582                                   const Packet a3, const float** inp1,
583                                   const float** inp2, const float** inp3,
584                                   float** out) {
585   TwoMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
586   TwoMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
587 }
588 
589 // Apply MulAdd3Way on 128 operands.
MulAdd3Way128(const Packet a1,const Packet a2,const Packet a3,const float ** inp1,const float ** inp2,const float ** inp3,float ** out)590 ALWAYS_INLINE void MulAdd3Way128(const Packet a1, const Packet a2,
591                                  const Packet a3, const float** inp1,
592                                  const float** inp2, const float** inp3,
593                                  float** out) {
594   if (kNumOperands == 8) {
595     FourMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
596     FourMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
597     FourMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
598     FourMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
599   } else {
600     DCHECK_LE(4 * kNumOperands, 128);
601     for (int i = 0; i < 128 / (4 * kNumOperands); ++i) {
602       MulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
603       MulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
604       MulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
605       MulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
606     }
607   }
608 }
// Computes the product of "left_slices" with "num_cols" columns of "right" and
// accumulates it into *"output". Every update adds to the existing output
// values; nothing here zeroes the output first.
// left_slices is a list of SparseSlices that are conceptually concatenated
// along the column dimension. Each slice is a list of column blocks (see
// SparseSlice). The j-th block of the concatenation multiplies rows
// [k_offset, k_offset + block_size) of "right", where k_offset starts at 0 and
// grows by block_size per block.
//
// Cols is the column count when known at compile time, or -1 to use the
// runtime "num_cols".
//
// NOTE(review): right_ptrs has K entries and out_ptrs has M entries, so every
// block is assumed to have block_size <= K and num_rows <= M — confirm at the
// callers.
template <typename TL, typename TR, int Cols>
inline void GEPP(
    const std::vector<SparseSlice<TL>*>& left_slices,
    const Eigen::TensorMap<Eigen::Tensor<const TR, 2, Eigen::RowMajor>,
                           Eigen::Aligned>& right,
    const int num_cols, Matrix* output) {
  const int cols = (Cols == -1) ? num_cols : Cols;
  DCHECK_EQ(num_cols, cols);
  const int right_num_cols = right.dimension(1);
  const int output_num_cols = output->dimension(1);
  // Elements of TR consumed by one vectorized MulAdd step: kNumOperands for
  // float, 2 * kNumOperands for 16-bit types such as bfloat16.
  static const int kNumOperandsR = kNumOperands * sizeof(float) / sizeof(TR);
  // Columns not covered by the vectorized loops; handled with scalar code.
  const int cols_mod = cols % kNumOperandsR;
  int k_offset = 0;
  // Pre-compute pointers for output matrix.
  float* out_ptrs[M];
  float* const out_start = &(*output)(0, 0);
  for (int j = 0; j < M; ++j) {
    out_ptrs[j] = out_start + output_num_cols * j;
  }
  for (const auto* left_slice : left_slices) {
    const auto& left = *left_slice;
    // Cursors into the value streams; they advance as entries are consumed.
    const auto* data3 = (!left.data3.empty()) ? &left.data3[0] : nullptr;
    const auto* data = (!left.data.empty()) ? &left.data[0] : nullptr;
    const int num_blocks = left.index3_offset.size();
    int begin3 = 0;
    int begin = 0;
    for (int i = 0; i < num_blocks; ++i) {
      // Pre-compute pointers for right matrix
      const TR* right_ptrs[K];
      const auto* const right_start = &right(k_offset, 0);
      DCHECK_LT(k_offset, right.dimension(0));
      for (int j = 0; j < K; ++j) {
        right_ptrs[j] = right_start + right_num_cols * j;
      }

      // Triples of block i: index3[begin3, end3).
      const int end3 = left.index3_offset[i];
      int j = begin3;
      // Loop unrolled for 2 iterations.
      for (; j + 1 < end3; j += 2) {
        Packet l1, l2, l3, nl1, nl2, nl3;
        LoadSixScalars(&data3, &l1, &l2, &l3, &nl1, &nl2, &nl3);
        const auto& index = left.index3[j];
        const auto& nindex = left.index3[j + 1];
        float* out = out_ptrs[index.m];
        float* nout = out_ptrs[nindex.m];
        const auto* r1 = right_ptrs[index.k1];
        const auto* r2 = right_ptrs[index.k2];
        const auto* r3 = right_ptrs[index.k3];

        const auto* nr1 = right_ptrs[nindex.k1];
        const auto* nr2 = right_ptrs[nindex.k2];
        const auto* nr3 = right_ptrs[nindex.k3];
        // Fully unrolled path for exactly 128 columns.
        if (cols == 128) {
          MulAdd3Way128(l1, l2, l3, &r1, &r2, &r3, &out);
          MulAdd3Way128(nl1, nl2, nl3, &nr1, &nr2, &nr3, &nout);
        } else {
          for (int n = 0; n < cols / kNumOperandsR; ++n) {
            MulAdd3Way(l1, l2, l3, &r1, &r2, &r3, &out);
            MulAdd3Way(nl1, nl2, nl3, &nr1, &nr2, &nr3, &nout);
          }

          // Scalar tail: take lane 0 of each broadcast packet.
          const float sl1 = Eigen::internal::pfirst<Packet>(l1);
          const float sl2 = Eigen::internal::pfirst<Packet>(l2);
          const float sl3 = Eigen::internal::pfirst<Packet>(l3);
          const float nsl1 = Eigen::internal::pfirst<Packet>(nl1);
          const float nsl2 = Eigen::internal::pfirst<Packet>(nl2);
          const float nsl3 = Eigen::internal::pfirst<Packet>(nl3);
          for (int k = 0; k < cols_mod; ++k) {
            ScalarMulAdd3Way(sl1, sl2, sl3, &r1, &r2, &r3, &out);
            ScalarMulAdd3Way(nsl1, nsl2, nsl3, &nr1, &nr2, &nr3, &nout);
          }
        }
      }
      // Odd triple left over from the unrolled loop.
      if (j < end3) {
        Packet l1, l2, l3;
        LoadThreeScalars(&data3, &l1, &l2, &l3);

        const auto& index = left.index3[j];
        float* out = out_ptrs[index.m];
        const auto* r1 = right_ptrs[index.k1];
        const auto* r2 = right_ptrs[index.k2];
        const auto* r3 = right_ptrs[index.k3];
        if (cols == 128) {
          MulAdd3Way128(l1, l2, l3, &r1, &r2, &r3, &out);
        } else {
          for (int n = 0; n < cols / kNumOperandsR; ++n) {
            MulAdd3Way(l1, l2, l3, &r1, &r2, &r3, &out);
          }
          const float sl1 = Eigen::internal::pfirst<Packet>(l1);
          const float sl2 = Eigen::internal::pfirst<Packet>(l2);
          const float sl3 = Eigen::internal::pfirst<Packet>(l3);
          for (int k = 0; k < cols_mod; ++k) {
            ScalarMulAdd3Way(sl1, sl2, sl3, &r1, &r2, &r3, &out);
          }
        }
      }
      begin3 = end3;
      // Single entries of block i: index[begin, end).
      int end = left.index_offset[i];
      // Loop unrolled for 4 iterations.
      j = begin;
      for (; j + 3 < end; j += 4) {
        Packet l, nl, n2l, n3l;
        LoadFourScalars(&data, &l, &nl, &n2l, &n3l);

        const auto& index = left.index[j];
        const auto& nindex = left.index[j + 1];
        const auto& n2index = left.index[j + 2];
        const auto& n3index = left.index[j + 3];
        const auto* r = right_ptrs[index.k];
        const auto* nr = right_ptrs[nindex.k];
        const auto* n2r = right_ptrs[n2index.k];
        const auto* n3r = right_ptrs[n3index.k];
        float* out = out_ptrs[index.m];
        float* nout = out_ptrs[nindex.m];
        float* n2out = out_ptrs[n2index.m];
        float* n3out = out_ptrs[n3index.m];

        for (int n = 0; n < cols / kNumOperandsR; ++n) {
          MulAdd(l, &r, &out);
          MulAdd(nl, &nr, &nout);
          MulAdd(n2l, &n2r, &n2out);
          MulAdd(n3l, &n3r, &n3out);
        }

        const float sl1 = Eigen::internal::pfirst<Packet>(l);
        const float sl2 = Eigen::internal::pfirst<Packet>(nl);
        const float sl3 = Eigen::internal::pfirst<Packet>(n2l);
        const float sl4 = Eigen::internal::pfirst<Packet>(n3l);
        for (int k = 0; k < cols_mod; ++k) {
          ScalarMulAdd(sl1, &r, &out);
          ScalarMulAdd(sl2, &nr, &nout);
          ScalarMulAdd(sl3, &n2r, &n2out);
          ScalarMulAdd(sl4, &n3r, &n3out);
        }
      }
      // Remaining (fewer than four) single entries.
      while (j < end) {
        Packet l;
        LoadSingleScalar(&data, &l);
        const auto& index = left.index[j];
        const auto* r = right_ptrs[index.k];
        float* out = out_ptrs[index.m];
        for (int n = 0; n < cols / kNumOperandsR; ++n) {
          MulAdd(l, &r, &out);
        }
        const float sl = Eigen::internal::pfirst<Packet>(l);
        for (int k = 0; k < cols_mod; ++k) {
          ScalarMulAdd(sl, &r, &out);
        }
        j++;
      }
      // The next block multiplies the next block_size rows of "right".
      k_offset += left.block_size;
      begin = end;
    }
  }
}
770 
// The packet helper macros exist only for the kernels above; undefine them so
// they do not leak into the rest of this translation unit.
#undef LOAD
#undef EXPAND_BFLOAT_L
#undef EXPAND_BFLOAT_U
#undef STORE
#undef FMA
776 
777 }  // namespace
778 
// Multiplies a "left" matrix, encoded as a grid of SparseSlice<TL> blocks, by
// a dense "right" matrix and writes float results into a MatrixMap. TL and TR
// are the element types of the two operands (presumably float or bfloat16;
// see the bfloat16 handling in SparseMatMulOp). Every member function is
// static; the class only groups the implementation and disallows copying.
template <typename TL, typename TR>
class SparseMatMul {
  using MatrixL = BasicMatrix<TL>;
  using MatrixR = BasicMatrix<TR>;
  using ConstMatrixMapL = BasicMatrixMap<const TL>;
  using ConstMatrixMapR = BasicMatrixMap<const TR>;
  using MatrixMapR = BasicMatrixMap<TR>;

 public:
  // Not used; added to match interface of LibxsmmSparseMatMul, so that callers
  // can hold a DoMatMul<TL, TR>::TensorInfoCache for either implementation.
  struct TensorInfoCache {};

  // Perform matrix multiplication of "left" and "right", and store the result
  // in *"output".
  //   cache: ignored by this implementation (the libxsmm fallback path passes
  //     nullptr here).
  //   transpose_left: if true, "left" is read as its transpose.
  //   thread_pool: worker pool; the slice-building helpers below schedule
  //     their work on thread_pool->workers.
  //   transpose_output: if true, the product is written transposed into
  //     *output.
 public:
  static inline void Compute(TensorInfoCache* cache,
                             const ConstMatrixMapL& left,
                             const ConstMatrixMapR& right, bool transpose_left,
                             const DeviceBase::CpuWorkerThreads* thread_pool,
                             bool transpose_output, MatrixMap* output);

 private:
  // Computes multiplication of left and num_cols columns of right, and stores
  // the output block in *"output" at offsets "output_row_offset" and
  // "output_col_offset". If assign is true, assigns the value to that block,
  // else adds the values to the existing values. If transpose_output is true,
  // the block is written into the transposed view of *output.
  static inline void ComputeOutputBlock(
      const std::vector<SparseSlice<TL>*>& left, const ConstMatrixMapR& right,
      int num_cols, int output_row_offset, int output_col_offset, bool assign,
      bool transpose_output, MatrixMap* output);

  // Encodes "mat" using a sparse representation and stores that in
  // "mat_slices". "mat" is broken into a grid with sizes "slice_num_rows" and
  // "slice_num_cols", each grid element is converted into a SparseSlice and
  // stored in mat_slices. "slice_block_size" is used to perform further column
  // blocking of each slice. The encoding runs asynchronously on thread_pool;
  // the returned BlockingCounter must be waited on before the slices are used.
  // The SparseSlice objects are heap-allocated and owned by the caller.
  static inline std::unique_ptr<BlockingCounter> CreateSparseSlices(
      const ConstMatrixMapL& mat, bool transpose, int slice_num_rows,
      int slice_block_size, int slice_num_cols,
      std::vector<std::vector<SparseSlice<TL>*>>* mat_slices,
      const DeviceBase::CpuWorkerThreads* thread_pool);

  // This function chops "mat" along column dimension into pieces with at most N
  // columns, and concatenates the pieces one after the other in "buffer". It
  // returns the list of the pieces in "slices". It returns a BlockingCounter
  // which should be used to wait for the shuffle operations to complete.
  static inline std::unique_ptr<BlockingCounter> CreateDenseSlices(
      const ConstMatrixMapR& mat, int row_start, int num_rows, int col_start,
      int num_cols, const DeviceBase::CpuWorkerThreads* thread_pool,
      MatrixR* buffer, std::vector<ConstMatrixMapR*>* slices);

  // Helper function for CreateDenseSlices to move the data around. It returns a
  // BlockingCounter which should be used to wait for the shuffle operations to
  // complete. The returned counter is heap-allocated; the caller owns it.
  static inline BlockingCounter* ShuffleMatrix(
      const ConstMatrixMapR& mat, int slice_row_start, int slice_num_rows,
      int slice_col_start, int slice_num_cols, const int N,
      const DeviceBase::CpuWorkerThreads* thread_pool, MatrixR* buffer);

  // Helper function for CreateDenseSlices to create slices. Each slice is a
  // heap-allocated view of "num_rows" consecutive rows of "mat"; the caller
  // owns the views.
  static inline void SliceMatrix(const MatrixR& mat, const int num_rows,
                                 const int num_slices,
                                 std::vector<ConstMatrixMapR*>* slices);

  // Heuristics to compute various block sizes.
  // KR, NR: block sizes for "right". We run blocking iterations that operate on
  // matrices with at most this size.
  // KL: grid size along the column dimension used while encoding left.
  // IB, JB: number of left and right slices to multiply together. This is used
  // for ordering different ComputeOutputBlock operations inside each blocking
  // iteration so as to potentially reduce the working set size.
  static inline void ComputeBlockSizes(const ConstMatrixMapL& left,
                                       const ConstMatrixMapR& right,
                                       bool transpose_left, int num_threads,
                                       int* KR, int* NR, int* KL, int* JB,
                                       int* IB);

  TF_DISALLOW_COPY_AND_ASSIGN(SparseMatMul);
};
858 
#ifdef TENSORFLOW_USE_LIBXSMM
// Alternative to SparseMatMul that delegates the multiplication to libxsmm's
// sparse-times-dense (spmdm) routines. Per-size libxsmm state is cached in a
// TensorInfoCache owned by the calling op.
template <typename TL, typename TR>
class LibxsmmSparseMatMul {
  using MatrixL = BasicMatrix<TL>;
  using MatrixR = BasicMatrix<TR>;
  using ConstMatrixMapL = BasicMatrixMap<const TL>;
  using ConstMatrixMapR = BasicMatrixMap<const TR>;
  using MatrixMapR = BasicMatrixMap<TR>;

 public:
  // This structure contains a set of libxsmm kernels for sizes that have been
  // encountered previously by this operator so that libxsmm does not need to
  // reallocate its scratchpad memory each time (which hurts performance
  // substantially).
  struct TensorInfoCache {
    struct TensorInfoCacheEntry {
      // Parameters for kernel
      int M;
      int K;
      int N;
      int max_threads;
      // libxsmm handle and matrix data
      libxsmm_spmdm_handle handle;
      libxsmm_CSR_sparseslice* output_csr;
      // Chain to non-libxsmm implementation's cache in case that ever becomes
      // useful (it is an empty struct right now)
      typename SparseMatMul<TL, TR>::TensorInfoCache
          non_libxsmm_cache;  // Currently not used
    };
    // protects entries; invariant: entries is a valid std::multimap
    tensorflow::mutex lock;
    // Because there could be multiple matrix multiplies with the same sizes
    // going on at the same time, we need to allow multiple cache entries for a
    // given set of parameters. Taking and returning entries is used to make
    // sure the same cache entry is not used from two threads at a time.
    // Key: (M, K, N, max_threads).
    std::multimap<std::tuple<int, int, int, int>,
                  std::unique_ptr<TensorInfoCacheEntry>>
        entries GUARDED_BY(lock);

    TensorInfoCache() : lock(), entries() {}
    // Look up and remove first entry with these parameters, creating one if
    // there isn't one. The caller owns the returned entry exclusively and can
    // hand it back with return_cache_entry() for reuse. Note that a fresh
    // entry is initialized (libxsmm_spmdm_init) while "lock" is held.
    std::unique_ptr<TensorInfoCacheEntry> take_cache_entry(int M, int K, int N,
                                                           int max_threads)
        LOCKS_EXCLUDED(lock) {
      tensorflow::mutex_lock ml(lock);
      auto key = std::make_tuple(M, K, N, max_threads);
      auto it = entries.find(key);
      if (it != entries.end()) {
        auto val = std::move(it->second);
        entries.erase(it);
        return val;
      } else {
        std::unique_ptr<TensorInfoCacheEntry> e{
            new TensorInfoCacheEntry{M, K, N, max_threads, {}, nullptr}};
        // setup scoped allocator, which uses cpu_allocator() for this scope
        const libxsmm_tf_allocator<libxsmm_scratch_allocator> tf_allocator;
        // NOTE(review): sizes are passed in (M, N, K) order here while the
        // cache key uses (M, K, N); presumably this matches the parameter
        // order of libxsmm_spmdm_init -- confirm against libxsmm_spmdm.h.
        libxsmm_spmdm_init(M, N, K, max_threads, &e->handle, &e->output_csr);
        return e;
      }
    }
    // Add a cache entry with certain parameters
    void return_cache_entry(std::unique_ptr<TensorInfoCacheEntry> e)
        LOCKS_EXCLUDED(lock) {
      tensorflow::mutex_lock ml(lock);
      auto key = std::make_tuple(e->M, e->K, e->N, e->max_threads);
      entries.insert(std::make_pair(key, std::move(e)));
    }
    // Releases the libxsmm handle of every entry currently stored in the
    // cache. Entries that are checked out via take_cache_entry() are not in
    // "entries" and are not released here. output_csr is not freed
    // separately; presumably libxsmm_spmdm_destroy releases it -- TODO
    // confirm.
    ~TensorInfoCache() {
      tensorflow::mutex_lock ml(lock);
      for (auto& p : entries) {
        libxsmm_spmdm_destroy(&p.second->handle);
      }
      entries.clear();
    }

   private:
    TF_DISALLOW_COPY_AND_ASSIGN(TensorInfoCache);
  };

  // Perform matrix multiplication of "left" and "right", and store the result
  // in *"output".
 public:
  static inline void Compute(TensorInfoCache* cache,
                             const ConstMatrixMapL& left,
                             const ConstMatrixMapR& right, bool transpose_left,
                             const DeviceBase::CpuWorkerThreads* thread_pool,
                             bool transpose_output, MatrixMap* output);

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(LibxsmmSparseMatMul);
};
#endif
952 
953 template <typename TL, typename TR,
954           template <typename TL2, typename TR2> class DoMatMul>
955 class SparseMatMulOp : public OpKernel {
956   using MatrixR = BasicMatrix<TR>;
957   using ConstMatrixMapR = BasicMatrixMap<const TR>;
958 
959  public:
SparseMatMulOp(OpKernelConstruction * ctx)960   explicit SparseMatMulOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
961     OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_a", &transpose_a_));
962     OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_b", &transpose_b_));
963     OP_REQUIRES_OK(ctx, ctx->GetAttr("a_is_sparse", &a_is_sparse_));
964     OP_REQUIRES_OK(ctx, ctx->GetAttr("b_is_sparse", &b_is_sparse_));
965   }
966 
Compute(OpKernelContext * ctx)967   void Compute(OpKernelContext* ctx) override {
968     const Tensor& a = ctx->input(0);
969     const Tensor& b = ctx->input(1);
970     OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(a.shape()),
971                 errors::InvalidArgument("a is not a matrix"));
972     OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(b.shape()),
973                 errors::InvalidArgument("b is not a matrix"));
974 
975     const int m = transpose_a_ ? a.dim_size(1) : a.dim_size(0);
976     const int k = transpose_a_ ? a.dim_size(0) : a.dim_size(1);
977     const int n = transpose_b_ ? b.dim_size(0) : b.dim_size(1);
978     const int k2 = transpose_b_ ? b.dim_size(1) : b.dim_size(0);
979 
980     OP_REQUIRES(ctx, k == k2,
981                 errors::InvalidArgument(
982                     "Matrix size incompatible: a: ", a.shape().DebugString(),
983                     ", b: ", b.shape().DebugString()));
984     Tensor* output = nullptr;
985     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({m, n}), &output));
986 
987     if (k == 0) {
988       // If the inner dimension k in the matrix multiplication is zero, we fill
989       // the output with zeros.
990       functor::SetZeroFunctor<CPUDevice, float> f;
991       f(ctx->eigen_device<CPUDevice>(), output->flat<float>());
992       return;
993     }
994 
995     auto out = output->matrix<float>();
996 
997     std::unique_ptr<Tensor> a_float;
998     std::unique_ptr<Tensor> b_float;
999     if (!a_is_sparse_ && !b_is_sparse_) {
1000       auto left = &a;
1001       auto right = &b;
1002       // TODO(agarwal): multi-thread the conversions from bfloat16 to float.
1003       if (std::is_same<TL, bfloat16>::value) {
1004         a_float.reset(new Tensor(DT_FLOAT, a.shape()));
1005         BFloat16ToFloat(a.flat<bfloat16>().data(),
1006                         a_float->flat<float>().data(), a.NumElements());
1007         left = a_float.get();
1008       }
1009       if (std::is_same<TR, bfloat16>::value) {
1010         b_float.reset(new Tensor(DT_FLOAT, b.shape()));
1011         BFloat16ToFloat(b.flat<bfloat16>().data(),
1012                         b_float->flat<float>().data(), b.NumElements());
1013         right = b_float.get();
1014       }
1015       Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
1016       dim_pair[0].first = transpose_a_ ? 0 : 1;
1017       dim_pair[0].second = transpose_b_ ? 1 : 0;
1018 
1019       out.device(ctx->template eigen_device<CPUDevice>()) =
1020           left->matrix<float>().contract(right->matrix<float>(), dim_pair);
1021       return;
1022     }
1023 
1024     auto left = &a;
1025     auto right = &b;
1026     bool transpose_output = false;
1027     bool transpose_a = transpose_a_;
1028     bool transpose_b = transpose_b_;
1029     if (!a_is_sparse_) {
1030       // Swap the order of multiplications using the identity:
1031       // A * B = (B' *  A')'.
1032       std::swap(left, right);
1033       std::swap(transpose_a, transpose_b);
1034       transpose_a = !transpose_a;
1035       transpose_b = !transpose_b;
1036       transpose_output = !transpose_output;
1037     }
1038 
1039     std::unique_ptr<Tensor> right_tr;
1040     if (transpose_b) {
1041       // TODO(agarwal): avoid transposing the matrix here and directly handle
1042       // transpose in CreateDenseSlices.
1043       right_tr.reset(
1044           new Tensor(right->dtype(),
1045                      TensorShape({right->dim_size(1), right->dim_size(0)})));
1046 
1047       const auto perm = dsizes_10();
1048       if (transpose_output) {
1049         right_tr->matrix<TL>().device(ctx->template eigen_device<CPUDevice>()) =
1050             right->matrix<TL>().shuffle(perm);
1051       } else {
1052         right_tr->matrix<TR>().device(ctx->template eigen_device<CPUDevice>()) =
1053             right->matrix<TR>().shuffle(perm);
1054       }
1055       right = right_tr.get();
1056     }
1057 
1058     if (transpose_output) {
1059       DoMatMul<TR, TL>::Compute(&this->cache_tr_, left->matrix<TR>(),
1060                                 right->matrix<TL>(), transpose_a,
1061                                 ctx->device()->tensorflow_cpu_worker_threads(),
1062                                 transpose_output, &out);
1063     } else {
1064       DoMatMul<TL, TR>::Compute(&this->cache_nt_, left->matrix<TL>(),
1065                                 right->matrix<TR>(), transpose_a,
1066                                 ctx->device()->tensorflow_cpu_worker_threads(),
1067                                 transpose_output, &out);
1068     }
1069   }
1070 
1071  private:
1072   bool transpose_a_;
1073   bool transpose_b_;
1074   bool a_is_sparse_;
1075   bool b_is_sparse_;
1076 
1077   // Cache for non-transposed-output multiply
1078   typename DoMatMul<TL, TR>::TensorInfoCache cache_nt_;
1079   // Cache for transposed-output multiply
1080   typename DoMatMul<TR, TL>::TensorInfoCache cache_tr_;
1081 
1082   TF_DISALLOW_COPY_AND_ASSIGN(SparseMatMulOp);
1083 };
1084 
1085 template <typename TL, typename TR>
ComputeOutputBlock(const std::vector<SparseSlice<TL> * > & left,const typename SparseMatMul<TL,TR>::ConstMatrixMapR & right,int num_cols,int output_row_offset,int output_col_offset,bool assign,bool transpose_output,MatrixMap * output)1086 inline void SparseMatMul<TL, TR>::ComputeOutputBlock(
1087     const std::vector<SparseSlice<TL>*>& left,
1088     const typename SparseMatMul<TL, TR>::ConstMatrixMapR& right, int num_cols,
1089     int output_row_offset, int output_col_offset, bool assign,
1090     bool transpose_output, MatrixMap* output) {
1091   const auto perm = dsizes_10();
1092   int num_rows = left[0]->num_rows;
1093   const int rhs_num_cols = right.dimension(1);
1094   DCHECK_LE(num_cols, rhs_num_cols);
1095   Matrix out(num_rows, rhs_num_cols);
1096   out.setZero();
1097   if (num_cols == N) {
1098     GEPP<TL, TR, N>(left, right, num_cols, &out);
1099   } else {
1100     GEPP<TL, TR, -1>(left, right, num_cols, &out);
1101   }
1102   if (!assign) {
1103     const DSizes begin(output_row_offset, output_col_offset);
1104     const DSizes sizes(num_rows, num_cols);
1105     if (transpose_output) {
1106       if (num_cols == rhs_num_cols) {
1107         output->shuffle(perm).slice(begin, sizes) += out;
1108       } else {
1109         const auto zero = dsizes_00();
1110         output->shuffle(perm).slice(begin, sizes) += out.slice(zero, sizes);
1111       }
1112     } else {
1113       if (num_cols == rhs_num_cols) {
1114         output->slice(begin, sizes) += out;
1115       } else {
1116         const auto zero = dsizes_00();
1117         output->slice(begin, sizes) += out.slice(zero, sizes);
1118       }
1119     }
1120   } else {
1121     std::unique_ptr<Matrix> out_tr;
1122     if (transpose_output) {
1123       out_tr.reset(new Matrix(rhs_num_cols, num_rows));
1124       *out_tr = out.shuffle(perm);
1125       std::swap(output_row_offset, output_col_offset);
1126       std::swap(num_rows, num_cols);
1127     }
1128     const Matrix& final_out = transpose_output ? *out_tr : out;
1129     for (int i = 0; i < num_rows; ++i) {
1130       memcpy(&(*output)(output_row_offset + i, output_col_offset),
1131              &final_out(i, 0), num_cols * sizeof(float));
1132     }
1133   }
1134 }
1135 
1136 template <typename TL, typename TR>
1137 inline std::unique_ptr<BlockingCounter>
CreateSparseSlices(const typename SparseMatMul<TL,TR>::ConstMatrixMapL & mat,bool transpose,int slice_num_rows,int slice_block_size,int slice_num_cols,std::vector<std::vector<SparseSlice<TL> * >> * mat_slices,const DeviceBase::CpuWorkerThreads * thread_pool)1138 SparseMatMul<TL, TR>::CreateSparseSlices(
1139     const typename SparseMatMul<TL, TR>::ConstMatrixMapL& mat, bool transpose,
1140     int slice_num_rows, int slice_block_size, int slice_num_cols,
1141     std::vector<std::vector<SparseSlice<TL>*>>* mat_slices,
1142     const DeviceBase::CpuWorkerThreads* thread_pool) {
1143   const int mat_num_rows = transpose ? mat.dimension(1) : mat.dimension(0);
1144   const int mat_num_cols = transpose ? mat.dimension(0) : mat.dimension(1);
1145   const int num_slices_dim0 =
1146       std::max(1, (mat_num_rows + slice_num_rows - 1) / slice_num_rows);
1147   const int num_slices_dim1 =
1148       std::max(1, (mat_num_cols + slice_num_cols - 1) / slice_num_cols);
1149   mat_slices->resize(num_slices_dim0);
1150   BlockingCounter* counter =
1151       new BlockingCounter(num_slices_dim0 * num_slices_dim1);
1152   auto work = [counter, transpose](SparseSlice<TL>* sparse_slice,
1153                                    SparseMatMul<TL, TR>::ConstMatrixMapL* slice,
1154                                    int col_offset) {
1155     if (transpose) {
1156       sparse_slice->template Initialize<true>(*slice, col_offset);
1157     } else {
1158       sparse_slice->template Initialize<false>(*slice, col_offset);
1159     }
1160     delete slice;
1161     counter->DecrementCount();
1162   };
1163   for (int i = 0; i < num_slices_dim0; ++i) {
1164     (*mat_slices)[i].resize(num_slices_dim1);
1165     int num_rows =
1166         std::min<int>(slice_num_rows, mat_num_rows - i * slice_num_rows);
1167     for (int j = 0; j < num_slices_dim1; ++j) {
1168       int num_cols =
1169           std::min<int>(slice_num_cols, mat_num_cols - j * slice_num_cols);
1170       SparseMatMul<TL, TR>::ConstMatrixMapL* slice = nullptr;
1171       if (transpose) {
1172         slice = new SparseMatMul<TL, TR>::ConstMatrixMapL(
1173             &mat(0, i * slice_num_rows), mat.dimensions());
1174       } else {
1175         DSizes d(num_rows, mat_num_cols);
1176         slice = new SparseMatMul<TL, TR>::ConstMatrixMapL(
1177             &mat(i * slice_num_rows, 0), d);
1178       }
1179       auto* sparse_slice =
1180           new SparseSlice<TL>(num_rows, num_cols, slice_block_size);
1181       (*mat_slices)[i][j] = sparse_slice;
1182       thread_pool->workers->Schedule(
1183           [=]() { work(sparse_slice, slice, slice_num_cols * j); });
1184     }
1185   }
1186   return std::unique_ptr<BlockingCounter>(counter);
1187 }
// Packet helpers for CopyAndMayBeInterleaveBfloat16. The bodies carry no
// trailing semicolon so each use (which supplies its own ';') expands to a
// single statement rather than a statement followed by an empty one.
#define LOAD(x) Eigen::internal::ploadu<Packet>((x))
#define INTERLEAVE(x) Eigen::internal::pinterleave4x64<Packet>(x)
#define STORE(x, y) Eigen::internal::pstoreu<float>(x, y)
1191 
1192 template <int NUM_ELEM = -1>
CopyAndMayBeInterleaveBfloat16(void * bdst,const void * bsrc,int num_elements)1193 ALWAYS_INLINE void CopyAndMayBeInterleaveBfloat16(void* bdst, const void* bsrc,
1194                                                   int num_elements) {
1195   DCHECK_GE(kNumOperands, 8);
1196   static const int kStep = kNumOperands * sizeof(float) / sizeof(bfloat16);
1197   const int num = (NUM_ELEM == -1) ? num_elements : NUM_ELEM;
1198   DCHECK_EQ(num, num_elements);
1199   const float* src = reinterpret_cast<const float*>(bsrc);
1200   float* dst = reinterpret_cast<float*>(bdst);
1201   for (int index = 0; index + kStep <= num; index += kStep) {
1202     auto in = LOAD(src);
1203     auto tmp = INTERLEAVE(in);
1204     STORE(dst, tmp);
1205     src += kNumOperands;
1206     dst += kNumOperands;
1207   }
1208   if (num % kStep != 0) {
1209     memcpy(dst, src, (num % kStep) * sizeof(bfloat16));
1210   }
1211 }
1212 
1213 template <typename T>
CopyAndMayBeInterleave(void * dst,const void * src,int num_elements)1214 ALWAYS_INLINE void CopyAndMayBeInterleave(void* dst, const void* src,
1215                                           int num_elements) {
1216   if (std::is_same<T, float>::value || kNumOperands < 8) {
1217     memcpy(dst, src, num_elements * sizeof(T));
1218   } else if (std::is_same<T, bfloat16>::value) {
1219     if (num_elements == N) {
1220       CopyAndMayBeInterleaveBfloat16<N>(dst, src, num_elements);
1221     } else {
1222       CopyAndMayBeInterleaveBfloat16<-1>(dst, src, num_elements);
1223     }
1224   } else {
1225     LOG(FATAL) << "Unsupported type";
1226   }
1227 }
1228 
// Drop the packet helpers defined for CopyAndMayBeInterleaveBfloat16 so they
// do not leak into the rest of the file. (Macro names are case-sensitive:
// they must match the #define spelling exactly.)
#undef LOAD
#undef INTERLEAVE
#undef STORE
1232 
1233 template <typename TL, typename TR>
ShuffleMatrix(const typename SparseMatMul<TL,TR>::ConstMatrixMapR & mat,int slice_row_start,int slice_num_rows,int slice_col_start,int slice_num_cols,const int N,const DeviceBase::CpuWorkerThreads * thread_pool,MatrixR * buffer)1234 inline BlockingCounter* SparseMatMul<TL, TR>::ShuffleMatrix(
1235     const typename SparseMatMul<TL, TR>::ConstMatrixMapR& mat,
1236     int slice_row_start, int slice_num_rows, int slice_col_start,
1237     int slice_num_cols, const int N,
1238     const DeviceBase::CpuWorkerThreads* thread_pool, MatrixR* buffer) {
1239   DCHECK_EQ(N % 2, 0);
1240   DCHECK_LE(kNumOperands * sizeof(float) / sizeof(TR), N);
1241   int num_threads = std::min(thread_pool->num_threads, 16);
1242   BlockingCounter* counter = new BlockingCounter(num_threads);
1243   DCHECK_EQ(N, buffer->dimension(1));
1244   auto shuffle_work = [&mat, slice_row_start, slice_num_rows, slice_col_start,
1245                        slice_num_cols, N, buffer, counter](int s, int e) {
1246     const int row_start = s % slice_num_rows + slice_row_start;
1247     const int col_start = s / slice_num_rows * N + slice_col_start;
1248     auto* out_start = &(*buffer)(s, 0);
1249     const auto* input_start = &mat(row_start, col_start);
1250     const auto* input_end = &mat(slice_row_start + slice_num_rows - 1,
1251                                  slice_col_start + slice_num_cols - 1);
1252     const int mat_num_cols = mat.dimension(1);
1253     const int row_slice_size = slice_num_rows * mat_num_cols;
1254 
1255     const int aligned_end = slice_num_cols / N * slice_num_rows;
1256     const int e1 = std::min(e, aligned_end);
1257     while (s < e1) {
1258       CopyAndMayBeInterleave<TR>(out_start, input_start, N);
1259       out_start += N;
1260       input_start += mat_num_cols;
1261       if (input_start > input_end) {
1262         input_start = input_start - row_slice_size + N;
1263       }
1264       ++s;
1265     }
1266     int s1 = std::max(s, aligned_end);
1267     const int copy_num_cols = slice_num_cols % N;
1268     while (s1 < e) {
1269       CopyAndMayBeInterleave<TR>(out_start, input_start, copy_num_cols);
1270       out_start += N;
1271       input_start += mat_num_cols;
1272       ++s1;
1273     }
1274     if (counter) counter->DecrementCount();
1275   };
1276 
1277   int start = 0;
1278   int end = 0;
1279   int num_out_rows = (slice_num_cols + N - 1) / N * slice_num_rows;
1280   DCHECK_LE(num_out_rows, buffer->dimension(0));
1281   for (int i = std::max(1, num_threads); i > 0; --i) {
1282     end = start + num_out_rows / i;
1283     thread_pool->workers->Schedule([=]() { shuffle_work(start, end); });
1284     num_out_rows -= (end - start);
1285     start = end;
1286   }
1287   return counter;
1288 }
1289 
1290 template <typename TL, typename TR>
SliceMatrix(const MatrixR & mat,const int num_rows,const int num_slices,std::vector<typename SparseMatMul<TL,TR>::ConstMatrixMapR * > * slices)1291 inline void SparseMatMul<TL, TR>::SliceMatrix(
1292     const MatrixR& mat, const int num_rows, const int num_slices,
1293     std::vector<typename SparseMatMul<TL, TR>::ConstMatrixMapR*>* slices) {
1294   slices->resize(num_slices);
1295   DSizes d(num_rows, mat.dimension(1));
1296   DCHECK_LE(num_rows * num_slices, mat.dimension(0));
1297   for (int i = 0; i < num_slices; ++i) {
1298     (*slices)[i] = new ConstMatrixMapR(&mat(i * num_rows, 0), d);
1299   }
1300 }
1301 
1302 template <typename TL, typename TR>
CreateDenseSlices(const typename SparseMatMul<TL,TR>::ConstMatrixMapR & mat,int row_start,int num_rows,int col_start,int num_cols,const DeviceBase::CpuWorkerThreads * thread_pool,MatrixR * buffer,std::vector<typename SparseMatMul<TL,TR>::ConstMatrixMapR * > * slices)1303 inline std::unique_ptr<BlockingCounter> SparseMatMul<TL, TR>::CreateDenseSlices(
1304     const typename SparseMatMul<TL, TR>::ConstMatrixMapR& mat, int row_start,
1305     int num_rows, int col_start, int num_cols,
1306     const DeviceBase::CpuWorkerThreads* thread_pool, MatrixR* buffer,
1307     std::vector<typename SparseMatMul<TL, TR>::ConstMatrixMapR*>* slices) {
1308   std::unique_ptr<BlockingCounter> shuffle_counter(ShuffleMatrix(
1309       mat, row_start, num_rows, col_start, num_cols, N, thread_pool, buffer));
1310   const int num_slices = (num_cols + N - 1) / N;
1311   SliceMatrix(*buffer, num_rows, num_slices, slices);
1312   return shuffle_counter;
1313 }
1314 
1315 template <typename TL, typename TR>
ComputeBlockSizes(const typename SparseMatMul<TL,TR>::ConstMatrixMapL & left,const typename SparseMatMul<TL,TR>::ConstMatrixMapR & right,bool transpose_left,int num_threads,int * KR,int * NR,int * KL,int * JB,int * IB)1316 inline void SparseMatMul<TL, TR>::ComputeBlockSizes(
1317     const typename SparseMatMul<TL, TR>::ConstMatrixMapL& left,
1318     const typename SparseMatMul<TL, TR>::ConstMatrixMapR& right,
1319     bool transpose_left, int num_threads, int* KR, int* NR, int* KL, int* JB,
1320     int* IB) {
1321   // Heuristics for calculating block sizes
1322   // Assume two hyperthreads per core.
1323   const int est_num_cores = std::max(1, (num_threads + 1) / 2);
1324   // Use block of rhs with at most 128K floats per core.
1325   const int mem = est_num_cores * 128 * 1024;
1326   *KR = std::min(static_cast<int>(right.dimension(0)), mem / 256);
1327   *NR = right.dimension(1);
1328   if (*KR * *NR > mem) {
1329     // 4096 may be enough to amortize the cost of writes.
1330     *KR = std::min<int>(*KR, 4096);
1331   }
1332   // Use sizes that are multiples of K and 256.
1333   *KR = std::max(1, *KR / K) * K;
1334   *NR = std::max(1, *NR / 256) * 256;
1335   if (*KR * *NR > mem) {
1336     *NR = mem / *KR;
1337   }
1338   *NR = std::max(1, *NR / 256) * 256;
1339 
1340   const int left_dim0 = transpose_left ? left.dimension(1) : left.dimension(0);
1341   const int left_dim1 = transpose_left ? left.dimension(0) : left.dimension(1);
1342   for (*KL = 1024; *KL > K; *KL /= 2) {
1343     if (*KR % *KL == 0 &&
1344         std::max<int>(1, left_dim0 / 64) * (left_dim1 / *KL) > est_num_cores) {
1345       break;
1346     }
1347   }
1348   DCHECK_EQ(*KL % K, 0);
1349   DCHECK_GE(*KR, *KL);
1350   if (*KR < right.dimension(0)) {
1351     CHECK_EQ(*KR % *KL, 0);
1352   }
1353 
1354   *JB = std::max(1, static_cast<int>(sqrt(num_threads) / 2.0));
1355   *IB = 8 * *JB;
1356   DCHECK_EQ(N * sizeof(float) % 64, size_t{0});
1357 }
1358 
1359 #ifdef TENSORFLOW_USE_LIBXSMM
1360 
1361 template <typename F>
do_on_all_threads(const DeviceBase::CpuWorkerThreads * thread_pool,const F & f)1362 void do_on_all_threads(const DeviceBase::CpuWorkerThreads* thread_pool,
1363                        const F& f) {
1364   int num_threads = thread_pool->num_threads;
1365   if (num_threads == 0) {
1366     LOG(FATAL) << "Have 0 threads in thread pool";
1367   } else if (num_threads == 1) {
1368     f(0);
1369   } else {
1370     BlockingCounter counter(num_threads - 1);
1371     for (int i = 1; i < num_threads; ++i) {
1372       thread_pool->workers->Schedule([&, i]() {
1373         f(i);
1374         counter.DecrementCount();
1375       });
1376     }
1377     f(0);
1378     counter.Wait();
1379   }
1380 }
1381 
// Empty tag type: carries the element type T only at compile time, so that
// the libxsmm wrapper functions can be overloaded per element type.
template <class T>
struct empty_type_wrapper {};
1384 
1385 // Copies of interface to libxsmm_spmdm_createSparseSlice_*_notrans_thread to
1386 // allow overloading
wrapper_libxsmm_spmdm_createSparseSlice_generic_thread(empty_type_wrapper<float>,const libxsmm_spmdm_handle * handle,char transA,const float * A,libxsmm_CSR_sparseslice * libxsmm_output_csr_a,int block_id,int tid,int nthreads)1387 void wrapper_libxsmm_spmdm_createSparseSlice_generic_thread(
1388     empty_type_wrapper<float>, const libxsmm_spmdm_handle* handle, char transA,
1389     const float* A, libxsmm_CSR_sparseslice* libxsmm_output_csr_a, int block_id,
1390     int tid, int nthreads) {
1391   return libxsmm_spmdm_createSparseSlice_fp32_thread(
1392       handle, transA, A, libxsmm_output_csr_a, block_id, tid, nthreads);
1393 }
wrapper_libxsmm_spmdm_createSparseSlice_generic_thread(empty_type_wrapper<bfloat16>,const libxsmm_spmdm_handle * handle,char transA,const bfloat16 * A,libxsmm_CSR_sparseslice * libxsmm_output_csr_a,int block_id,int tid,int nthreads)1394 void wrapper_libxsmm_spmdm_createSparseSlice_generic_thread(
1395     empty_type_wrapper<bfloat16>, const libxsmm_spmdm_handle* handle,
1396     char transA, const bfloat16* A,
1397     libxsmm_CSR_sparseslice* libxsmm_output_csr_a, int block_id, int tid,
1398     int nthreads) {
1399   return libxsmm_spmdm_createSparseSlice_bfloat16_thread(
1400       handle, transA, reinterpret_cast<const uint16*>(A), libxsmm_output_csr_a,
1401       block_id, tid, nthreads);
1402 }
1403 
wrapper_libxsmm_spmdm_compute_generic_thread(empty_type_wrapper<bfloat16>,const libxsmm_spmdm_handle * handle,char transA,char transB,const bfloat16 * alpha,libxsmm_CSR_sparseslice * A_sparse,const bfloat16 * B,char transC,const bfloat16 * beta,float * C,int block_id,int tid,int nthreads)1404 void wrapper_libxsmm_spmdm_compute_generic_thread(
1405     empty_type_wrapper<bfloat16>, const libxsmm_spmdm_handle* handle,
1406     char transA, char transB, const bfloat16* alpha,
1407     libxsmm_CSR_sparseslice* A_sparse, const bfloat16* B, char transC,
1408     const bfloat16* beta, float* C, int block_id, int tid, int nthreads) {
1409   return libxsmm_spmdm_compute_bfloat16_thread(
1410       handle, transA, transB, reinterpret_cast<const uint16*>(alpha), A_sparse,
1411       reinterpret_cast<const uint16*>(B), transC,
1412       reinterpret_cast<const uint16*>(beta), C, block_id, tid, nthreads);
1413 }
wrapper_libxsmm_spmdm_compute_generic_thread(empty_type_wrapper<float>,const libxsmm_spmdm_handle * handle,char transA,char transB,const float * alpha,libxsmm_CSR_sparseslice * A_sparse,const float * B,char transC,const float * beta,float * C,int block_id,int tid,int nthreads)1414 void wrapper_libxsmm_spmdm_compute_generic_thread(
1415     empty_type_wrapper<float>, const libxsmm_spmdm_handle* handle, char transA,
1416     char transB, const float* alpha, libxsmm_CSR_sparseslice* A_sparse,
1417     const float* B, char transC, const float* beta, float* C, int block_id,
1418     int tid, int nthreads) {
1419   return libxsmm_spmdm_compute_fp32_thread(handle, transA, transB, alpha,
1420                                            A_sparse, B, transC, beta, C,
1421                                            block_id, tid, nthreads);
1422 }
1423 
1424 template <typename TL, typename TR>
Compute(typename LibxsmmSparseMatMul<TL,TR>::TensorInfoCache * cache,const typename LibxsmmSparseMatMul<TL,TR>::ConstMatrixMapL & left,const typename LibxsmmSparseMatMul<TL,TR>::ConstMatrixMapR & right,bool transpose_left,const DeviceBase::CpuWorkerThreads * thread_pool,bool transpose_output,MatrixMap * output)1425 inline void LibxsmmSparseMatMul<TL, TR>::Compute(
1426     typename LibxsmmSparseMatMul<TL, TR>::TensorInfoCache* cache,
1427     const typename LibxsmmSparseMatMul<TL, TR>::ConstMatrixMapL& left,
1428     const typename LibxsmmSparseMatMul<TL, TR>::ConstMatrixMapR& right,
1429     bool transpose_left, const DeviceBase::CpuWorkerThreads* thread_pool,
1430     bool transpose_output, MatrixMap* output) {
1431   if (false) {
1432     // Not handled by libxsmm currently
1433     SparseMatMul<TL, TR>::Compute(
1434         nullptr /* Assumes no cached data for fallback */, left, right,
1435         transpose_left, thread_pool, transpose_output, output);
1436     return;
1437   }
1438   const int num_threads = thread_pool->num_threads;
1439   const int left_dim0 = transpose_left ? left.dimension(1) : left.dimension(0);
1440   const int left_dim1 = transpose_left ? left.dimension(0) : left.dimension(1);
1441   const int right_dim0 = right.dimension(0);
1442   const int right_dim1 = right.dimension(1);
1443   CHECK_EQ(left_dim1, right_dim0);
1444   CHECK_EQ(left_dim0,
1445            (transpose_output ? output->dimension(1) : output->dimension(0)));
1446   CHECK_EQ(right_dim1,
1447            (transpose_output ? output->dimension(0) : output->dimension(1)));
1448   if (left_dim0 < 32 || left_dim1 < 32 || right_dim1 < 32) {
1449     // Causes problems in libxsmm
1450     SparseMatMul<TL, TR>::Compute(
1451         nullptr /* Assumes no cached data for fallback */, left, right,
1452         transpose_left, thread_pool, transpose_output, output);
1453     return;
1454   }
1455   auto left_data = left.data();
1456   auto right_data = right.data();
1457   auto output_data = output->data();
1458   // Initialize libxsmm for this matrix; make sure another thread doesn't use
1459   // this handle
1460   auto entry =
1461       cache->take_cache_entry(left_dim0, right_dim0, right_dim1, num_threads);
1462   // Convert the left matrix to compressed sparse row (CSR) format
1463   ptrdiff_t total_num_creation_blocks =
1464       libxsmm_spmdm_get_num_createSparseSlice_blocks(&entry->handle);
1465   std::atomic<int> cur_create_block_number;
1466   cur_create_block_number.store(0);
1467   do_on_all_threads(thread_pool, [&](int i) {
1468     while (true) {
1469       int work_item = cur_create_block_number.fetch_add(1);
1470       if (work_item >= total_num_creation_blocks) break;
1471       wrapper_libxsmm_spmdm_createSparseSlice_generic_thread(
1472           empty_type_wrapper<TL>{}, &entry->handle,
1473           (transpose_left ? 'T' : 'N'), left_data, entry->output_csr, work_item,
1474           i, num_threads);
1475     }
1476   });
1477   // Do matrix-matrix multiplication
1478   ptrdiff_t total_num_mult_blocks =
1479       libxsmm_spmdm_get_num_compute_blocks(&entry->handle);
1480   std::atomic<int> cur_mult_block_number;
1481   cur_mult_block_number.store(0);
1482   do_on_all_threads(thread_pool, [&](int i) {
1483     while (true) {
1484       int work_item = cur_mult_block_number.fetch_add(1);
1485       if (work_item >= total_num_mult_blocks) break;
1486       const TL alpha(1.0);  // Stored in a variable so we can get a pointer
1487       const TL beta(0.0);   // Stored in a variable so we can get a pointer
1488       wrapper_libxsmm_spmdm_compute_generic_thread(
1489           empty_type_wrapper<TL>{}, &entry->handle,
1490           (transpose_left ? 'T' : 'N'), 'N', &alpha, entry->output_csr,
1491           right_data, (transpose_output ? 'T' : 'N'), &beta, output_data,
1492           work_item, i, num_threads);
1493     }
1494   });
1495   // Put handle + CSR storage back into cache
1496   cache->return_cache_entry(std::move(entry));
1497 }
1498 
1499 #endif  // TENSORFLOW_USE_LIBXSMM
1500 
1501 // Here is an overview of the SparseMatMul code. Note that we assume that the
1502 // left matrix is sparse.
1503 //
1504 // The matrix "left" is divided into a grid with blocksize of (M, KL). Each
1505 // block is encoded as a SparseSlice. These grid elements are stored as
1506 // std::vector<std::vector<SparseSlice>>. Each element of the outer vector
// represents M rows of the left matrix. Let's call these elements l_i, and
// let's call each element of the inner vector L_mk.
1509 //
// The matrix "right" is divided into a grid with block size KR * NR. Let's
// denote the blocks on the right as R_kn. Note that we ensure that KL divides
1512 // KR so that for each element R_kn, we don't need to multiply it with any
1513 // partial L_mk blocks.
1514 //
1515 // We then multiply each right side block R_kn with the full "left" matrix and
1516 // update the output. These iterations are run sequentially since R_kn are
1517 // packed into the same underlying temporary buffer.
1518 //
1519 // In each iteration we do the following:
// 1. Create slices r_j of R_kn: We split R_kn into vertical blocks with N
//    (=128) columns and then concatenate these slices into a buffer. This is
//    done so that each slice r_j of R_kn is stored contiguously in memory. Note
//    that if R_kn has dimensions (KR, NR), we create NR / N slices, and the
//    buffer has dimensions (KR * NR / N, N) (assuming N divides NR).
1525 // 2. For each (l_i, r_j), we compute the inner product using the GEPP function
1526 //    and update the output block o_ij. These calls are further blocked to
1527 //    reduce the working set size. In each iteration we take IB elements from
1528 //    {l_i} and JB elements from {r_j} and compute the IB * JB inner products.
1529 template <typename TL, typename TR>
Compute(typename SparseMatMul<TL,TR>::TensorInfoCache *,const typename SparseMatMul<TL,TR>::ConstMatrixMapL & left,const typename SparseMatMul<TL,TR>::ConstMatrixMapR & right,bool transpose_left,const DeviceBase::CpuWorkerThreads * thread_pool,bool transpose_output,MatrixMap * output)1530 inline void SparseMatMul<TL, TR>::Compute(
1531     typename SparseMatMul<TL, TR>::TensorInfoCache* /*cache*/,
1532     const typename SparseMatMul<TL, TR>::ConstMatrixMapL& left,
1533     const typename SparseMatMul<TL, TR>::ConstMatrixMapR& right,
1534     bool transpose_left, const DeviceBase::CpuWorkerThreads* thread_pool,
1535     bool transpose_output, MatrixMap* output) {
1536   const int num_threads = thread_pool->num_threads;
1537   int KR, NR, KL, JB, IB;
1538   ComputeBlockSizes(left, right, transpose_left, num_threads, &KR, &NR, &KL,
1539                     &JB, &IB);
1540   // Slice the left matrix
1541   std::vector<std::vector<SparseSlice<TL>*>> left_slices;
1542   std::unique_ptr<BlockingCounter> sparse_slice_counter =
1543       CreateSparseSlices(ConstMatrixMapL(left.data(), left.dimensions()),
1544                          transpose_left, M, K, KL, &left_slices, thread_pool);
1545   const int num_left_slices = left_slices.size();
1546 
1547   const int right_dim0 = right.dimension(0);
1548   const int right_dim1 = right.dimension(1);
1549   // Allocate buffer for storing slices of right matrix.
1550   // Note buffer needs enough space to hold at most a KR * NR matrix since that
1551   // is the block size per iteration.
1552   const int buffer_num_rows =
1553       std::min(KR, right_dim0) * (std::min(NR, right_dim1) + N - 1) / N;
1554   MatrixR buffer(buffer_num_rows, N);
1555   std::vector<ConstMatrixMapR*> right_slices;
1556 
1557   std::vector<SparseSlice<TL>*> block_left_slices;
1558   std::vector<std::function<void(void)>> tasks;
1559   // Number of blocks based on block sizes of KR * NR.
1560   const int num_k_blocks = (right_dim0 + KR - 1) / KR;
1561   const int num_n_blocks = (right_dim1 + NR - 1) / NR;
1562   std::unique_ptr<BlockingCounter> dense_slice_counter;
1563 
1564   for (int nb = 0; nb < num_n_blocks; ++nb) {
1565     const int right_num_cols =
1566         std::min(NR, static_cast<int>(right_dim1 - NR * nb));
1567     for (int kb = 0; kb < num_k_blocks; ++kb) {
1568       const int right_num_rows =
1569           std::min(KR, static_cast<int>(right_dim0 - KR * kb));
1570       dense_slice_counter = CreateDenseSlices(
1571           right, kb * KR, right_num_rows, nb * NR, right_num_cols, thread_pool,
1572           &buffer, &right_slices);
1573       const int num_right_slices = right_slices.size();
1574       tasks.reserve(num_left_slices * num_right_slices);
1575       for (int j_outer = 0; j_outer < num_right_slices; j_outer += JB) {
1576         for (int i_outer = 0; i_outer < num_left_slices; i_outer += IB) {
1577           for (int j_inner = j_outer;
1578                j_inner < std::min(num_right_slices, j_outer + JB); ++j_inner) {
1579             const int num_cols = std::min(N, right_num_cols - N * j_inner);
1580             for (int i_inner = i_outer;
1581                  i_inner < std::min(num_left_slices, i_outer + IB); ++i_inner) {
1582               block_left_slices.clear();
1583               int begin = kb * KR / KL;
1584               int end = std::min<int>((kb + 1) * KR / KL,
1585                                       (right.dimension(0) + KL - 1) / KL);
1586               DCHECK_LT(begin, end);
1587               block_left_slices.insert(block_left_slices.begin(),
1588                                        left_slices[i_inner].begin() + begin,
1589                                        left_slices[i_inner].begin() + end);
1590               tasks.push_back(std::bind(
1591                   &ComputeOutputBlock, block_left_slices,
1592                   std::ref(*right_slices[j_inner]), num_cols, M * i_inner,
1593                   N * j_inner + nb * NR, kb == 0, transpose_output, output));
1594             }
1595           }
1596         }
1597       }
1598       if (sparse_slice_counter) {
1599         sparse_slice_counter->Wait();
1600         sparse_slice_counter.reset(nullptr);
1601       }
1602       if (dense_slice_counter) {
1603         dense_slice_counter->Wait();
1604         dense_slice_counter.reset(nullptr);
1605       }
1606       BlockingCounter bc(tasks.size());
1607       for (const auto& t : tasks) {
1608         thread_pool->workers->Schedule([&bc, &t]() {
1609           t();
1610           bc.DecrementCount();
1611         });
1612       }
1613       bc.Wait();
1614       tasks.clear();
1615       gtl::STLDeleteElements(&right_slices);
1616       right_slices.clear();
1617     }
1618   }
1619   for (auto& left_slice : left_slices) {
1620     gtl::STLDeleteElements(&left_slice);
1621   }
1622 }
1623 
1624 #define REGISTER_SPARSE_MATMUL(TA, TB)                   \
1625   REGISTER_KERNEL_BUILDER(Name("SparseMatMul")           \
1626                               .Device(DEVICE_CPU)        \
1627                               .TypeConstraint<TA>("Ta")  \
1628                               .TypeConstraint<TB>("Tb"), \
1629                           SparseMatMulOp<TA, TB, SparseMatMul>);
1630 #ifdef TENSORFLOW_USE_LIBXSMM
1631 #define REGISTER_SPARSE_MATMUL_LIBXSMM(TA, TB)           \
1632   REGISTER_KERNEL_BUILDER(Name("SparseMatMul")           \
1633                               .Device(DEVICE_CPU)        \
1634                               .TypeConstraint<TA>("Ta")  \
1635                               .TypeConstraint<TB>("Tb"), \
1636                           SparseMatMulOp<TA, TB, LibxsmmSparseMatMul>);
1637 #endif
1638 
1639 REGISTER_SPARSE_MATMUL(bfloat16, bfloat16);
1640 
1641 REGISTER_SPARSE_MATMUL(float, bfloat16);
1642 
1643 REGISTER_SPARSE_MATMUL(bfloat16, float);
1644 
1645 #ifdef TENSORFLOW_USE_LIBXSMM
1646 REGISTER_SPARSE_MATMUL_LIBXSMM(float, float);
1647 #else
1648 REGISTER_SPARSE_MATMUL(float, float);
1649 #endif
1650 
1651 #undef REGISTER_SPARSE_MATMUL
1652 
1653 }  // end namespace tensorflow
1654