1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // See docs in ../ops/math_ops.cc.
17
18 #define EIGEN_USE_THREADS
19
20 #include "tensorflow/core/kernels/sparse_matmul_op.h"
21
22 #include <map>
23 #include <memory>
24 #include <vector>
25
26 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
27 #include "tensorflow/core/common_runtime/device.h"
28 #include "tensorflow/core/framework/bfloat16.h"
29 #include "tensorflow/core/framework/op.h"
30 #include "tensorflow/core/framework/op_kernel.h"
31 #include "tensorflow/core/framework/types.h"
32 #include "tensorflow/core/kernels/fill_functor.h"
33 #include "tensorflow/core/lib/core/blocking_counter.h"
34 #include "tensorflow/core/lib/core/threadpool.h"
35 #include "tensorflow/core/platform/logging.h"
36 #include "tensorflow/core/platform/macros.h"
37 #include "tensorflow/core/platform/mutex.h"
38 #include "tensorflow/core/platform/thread_annotations.h"
39 #include "tensorflow/core/platform/types.h"
40 #ifdef TENSORFLOW_USE_LIBXSMM
41 #include "include/libxsmm_intrinsics_x86.h"
42 #include "include/libxsmm_malloc.h"
43 #include "include/libxsmm_spmdm.h"
44 #endif
45
46 #if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
47 #include "tensorflow/core/kernels/eigen_contraction_kernel.h"
48 #endif
49
50 #define ALWAYS_INLINE EIGEN_ALWAYS_INLINE
51
52 namespace tensorflow {
53 namespace {
54
55 template <typename T>
56 using BasicMatrix = Eigen::Tensor<T, 2, Eigen::RowMajor>;
57
58 template <typename T>
59 using BasicMatrixMap =
60 Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>, Eigen::Aligned>;
61
62 using Matrix = BasicMatrix<float>;
63 using MatrixMap = BasicMatrixMap<float>;
64 using CPUDevice = Eigen::ThreadPoolDevice;
65 using DSizes = Eigen::DSizes<Eigen::DenseIndex, 2>;
66
67 // Two commonly used static dsizes. We use Eigen::type2index to allow as much
68 // compile time optimization as possible.
69 #ifdef EIGEN_HAS_INDEX_LIST
70 inline Eigen::IndexList<Eigen::type2index<0>, Eigen::type2index<0>>
71 dsizes_00() {
72 return Eigen::IndexList<Eigen::type2index<0>, Eigen::type2index<0>>();
73 }
74 inline Eigen::IndexList<Eigen::type2index<1>, Eigen::type2index<0>>
75 dsizes_10() {
76 return Eigen::IndexList<Eigen::type2index<1>, Eigen::type2index<0>>();
77 }
78 #else
79 inline DSizes dsizes_00() { return DSizes(0, 0); }
80 inline DSizes dsizes_10() { return DSizes(1, 0); }
81 #endif
82
83 // Blocksizes
84 // TODO(agarwal): compute these sizes based on cache sizes.
85 const int K = 64;
86 const int M = 64;
87 const int N = 128;
88
89 // This stores a sparse representation of a slice of a matrix with size
90 // (num_rows, num_cols). The slice is represented as a series of blocks of size
91 // (num_rows, b), where b = block_size for all but the last block, which may
92 // have fewer columns.
93 //
94 // num_rows and block_size are assumed to be <= 256. This allows storing the
95 // row and column indices as uint8.
96 //
97 // For each block, we store all the nonzero entries in the data/data3 vectors
98 // and the corresponding coordinates of each element in the index/index3
99 // vectors. The index3 vector stores the indices of 3 elements in the same row
100 // so that these elements can share the same row coordinate. Each entry in
101 // index3 corresponds to 3 entries in data3.
102 //
103 // Note that all the data/indices of all the blocks are stored in the same
104 // vectors respectively. To identify block boundaries, we store the block
105 // offsets using index3_offset/index_offset. If there are n blocks in the slice,
106 // index3_offset and index_offset have n entries. The indices for the ith block
107 // are the values in the following range:
108 // [index3[index3_offset[i-1]], index3[index3_offset[i]]). Similarly for
109 // index_offset.
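//
// Illustrative example (not from the code): suppose row m = 5 of a block has
// nonzero values a, b, c, d at columns k = 1, 4, 7, 9. The scan in Initialize()
// groups nonzeros three at a time: (a, b, c) are appended to data3 with a single
// Index3 {m: 5, k1: 1, k2: 4, k3: 7} appended to index3, and the leftover d is
// appended to data with Index {m: 5, k: 9} appended to index. Leftover groups of
// one or two nonzeros at the end of a row always go to data/index rather than
// data3/index3.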
110 template <typename T>
111 struct SparseSlice {
112 using ConstMatrixMap = BasicMatrixMap<const T>;
113
114 public:
115 // Indices of three elements on the same row.
116 struct Index3 {
117 uint8 m; // row
118 // columns
119 uint8 k1;
120 uint8 k2;
121 uint8 k3;
122 };
123
124 // Index of one element.
125 struct Index {
126 uint8 m;
127 uint8 k;
128 };
129
130 SparseSlice(int nrows, int ncols, int bsize)
131 : num_rows(nrows), num_cols(ncols), block_size(bsize) {
132 DCHECK_LE(nrows, 256);
133 DCHECK_LE(block_size, 256);
134 }
135
136 // Initializes the slice with data starting at mat(0, col_offset) and with
137 // size (num_rows, num_cols).
138 // If Transpose is true, implicitly transposes mat.
139 template <bool Transpose = false>
140 void Initialize(const ConstMatrixMap& mat, int col_offset);
141
142 void Clear();
143
144 // See comments above.
145 std::vector<int> index3_offset;
146 std::vector<Index3> index3;
147 std::vector<T> data3;
148
149 // See comments above. Similar to "index3" except that each element in "index"
150 // corresponds to one element in data.
151 std::vector<int> index_offset;
152 std::vector<Index> index;
153 std::vector<T> data;
154
155 // Number of rows and columns for the slice.
156 const int num_rows;
157 const int num_cols;
158
159 // Block size used to initialize from a matrix.
160 const int block_size;
161 };
162
163 template <typename T>
164 bool IsZero(T v);
165
166 template <>
167 ALWAYS_INLINE bool IsZero(bfloat16 v) {
168 return !static_cast<bool>(v);
169 }
170
171 template <>
172 ALWAYS_INLINE bool IsZero(float v) {
173 return v == 0.0f;
174 }
175
176 template <typename T>
177 template <bool Transpose>
178 void SparseSlice<T>::Initialize(
179 const typename SparseSlice<T>::ConstMatrixMap& mat, int col_offset) {
180 const int mat_rows = Transpose ? mat.dimension(1) : mat.dimension(0);
181 const int mat_cols = Transpose ? mat.dimension(0) : mat.dimension(1);
182 DCHECK_LE(num_rows, mat_rows);
183 DCHECK_LE(num_cols + col_offset, mat_cols);
184
185 int num_blocks = (num_cols + block_size - 1) / block_size;
186 int mat_size = num_rows * num_cols;
187
188 index3_offset.reserve(num_blocks);
189 data3.reserve(mat_size);
190 index3.reserve(mat_size / 3);
191
192 index_offset.reserve(num_blocks);
193 data.reserve(num_blocks * num_rows * 2);
194 index.reserve(num_blocks * num_rows * 2);
195
196 Index3 idx3;
197 const int stride = Transpose ? mat.dimension(1) : 1;
198
199 for (int i = 0; i < num_blocks; ++i) {
200 int num_block_cols = std::min(block_size, num_cols - block_size * i);
201 for (int row = 0; row < num_rows; ++row) {
202 idx3.m = static_cast<uint8>(row);
203 // Safety note: The following code has a race, since it checks whether
204 // *curr is nonzero and then reads it again on use. However, the result
205 // of the race is only that some of the "nonzeros" in the resulting sparse
206 // representation may actually be zero, which is harmless.
207 const auto* start =
208 Transpose ? &mat(col_offset, row) : &mat(row, col_offset);
209 const auto* curr = start;
210 const auto* end = start + stride * num_block_cols;
211 uint8 k = 0;
212 #define NEXT_ELEM \
213 curr += stride; \
214 ++k;
215 #define EAT_ZEROS \
216 while (curr < end && IsZero<T>(*curr)) { \
217 NEXT_ELEM; \
218 }
219 while (true) {
220 EAT_ZEROS
221 if (curr >= end) break;
222 idx3.k1 = k;
223 const T value1 = *curr;
224 NEXT_ELEM;
225
226 EAT_ZEROS
227 if (curr >= end) {
228 data.push_back(value1);
229 index.push_back({idx3.m, idx3.k1});
230 break;
231 }
232 idx3.k2 = k;
233 const T value2 = *curr;
234 NEXT_ELEM;
235
236 EAT_ZEROS
237 if (curr >= end) {
238 data.push_back(value2);
239 index.push_back({idx3.m, idx3.k2});
240 data.push_back(value1);
241 index.push_back({idx3.m, idx3.k1});
242 break;
243 }
244 idx3.k3 = k;
245 data3.push_back(value1);
246 data3.push_back(value2);
247 data3.push_back(*curr);
248 NEXT_ELEM;
249 index3.push_back(idx3);
250 #undef NEXT_ELEM
251 #undef EAT_ZEROS
252 }
253 }
254 col_offset += block_size;
255 index3_offset.push_back(index3.size());
256 index_offset.push_back(index.size());
257 }
258 DCHECK_EQ(index3_offset.size(), num_blocks);
259 DCHECK_EQ(index_offset.size(), num_blocks);
260 DCHECK_EQ(3 * index3.size(), data3.size());
261 DCHECK_EQ(index.size(), data.size());
262 }
263
264 template <typename T>
265 void SparseSlice<T>::Clear() {
266 index3_offset.clear();
267 index3.clear();
268 data3.clear();
269 index_offset.clear();
270 index.clear();
271 data.clear();
272 }
273
274 using Packet = Eigen::internal::packet_traits<float>::type;
275 const int kNumOperands = (sizeof(Packet) / sizeof(float));
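// For example, with AVX the float Packet holds 8 floats, so kNumOperands == 8;
// with SSE it holds 4. (Illustrative; the actual value depends on how Eigen was
// compiled.)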
276 #define LOAD(x) Eigen::internal::pload<Packet>(x);
277 #define EXPAND_BFLOAT_L(x, y) \
278 const auto y = Eigen::internal::pexpand_bf16_l<Packet>(x);
279 #define EXPAND_BFLOAT_U(x, y) \
280 const auto y = Eigen::internal::pexpand_bf16_u<Packet>(x);
281 #define STORE(x, y) Eigen::internal::pstore<float>(x, y);
282 #define FMA(a, b, c, d) d = Eigen::internal::pmadd<Packet>(a, b, c);
283
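// Expands a bfloat16 (the upper 16 bits of a float32) back to float32 by placing
// it in the high half of a zero-initialized float. Illustration (not from the
// code): 1.0f is 0x3F800000, its bfloat16 form is 0x3F80, and writing 0x3F80 into
// the high 16 bits of 0x00000000 reconstructs 0x3F800000 == 1.0f.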
284 ALWAYS_INLINE float ConvertBfloat16ToFloat(const bfloat16* src) {
285 float out = 0;
286 auto tmp = reinterpret_cast<bfloat16*>(&out);
287 #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
288 tmp[0] = *src;
289 #else
290 tmp[1] = *src;
291 #endif
292 return out;
293 }
294
295 ALWAYS_INLINE Packet ConvertFourBfloat16ToFloat(const bfloat16* src) {
296 return Eigen::internal::pload4bf16<Packet>(
297 reinterpret_cast<const float*>(src));
298 }
299
300 ALWAYS_INLINE Packet ConvertTwoBfloat16ToFloat(const bfloat16* src) {
301 return Eigen::internal::pload2bf16<Packet>(
302 reinterpret_cast<const float*>(src));
303 }
304
305 ALWAYS_INLINE void ScalarMulAdd(const float a, const float** inp, float** out) {
306 **out += a * **inp;
307 ++*inp;
308 ++*out;
309 }
310
311 ALWAYS_INLINE void ScalarMulAdd(const float a, const bfloat16** inp,
312 float** out) {
313 float inp_f = ConvertBfloat16ToFloat(*inp);
314 **out += a * inp_f;
315 ++*inp;
316 ++*out;
317 }
318 ALWAYS_INLINE void ScalarMulAdd3Way(const float a1, const float a2,
319 const float a3, const bfloat16** inp1,
320 const bfloat16** inp2,
321 const bfloat16** inp3, float** out) {
322 float inp1_f = ConvertBfloat16ToFloat(*inp1);
323 float inp2_f = ConvertBfloat16ToFloat(*inp2);
324 float inp3_f = ConvertBfloat16ToFloat(*inp3);
325 **out += a1 * inp1_f + a2 * inp2_f + a3 * inp3_f;
326 ++*out;
327 ++*inp1;
328 ++*inp2;
329 ++*inp3;
330 }
331
332 ALWAYS_INLINE void ScalarMulAdd3Way(const float a1, const float a2,
333 const float a3, const float** inp1,
334 const float** inp2, const float** inp3,
335 float** out) {
336 **out += a1 * **inp1 + a2 * **inp2 + a3 * **inp3;
337 ++*out;
338 ++*inp1;
339 ++*inp2;
340 ++*inp3;
341 }
342
343 ALWAYS_INLINE void LoadSingleScalar(const bfloat16** data, Packet* l) {
344 auto tmp = ConvertBfloat16ToFloat(*data);
345 *l = Eigen::internal::pset1<Packet>(tmp);
346 ++*data;
347 }
348
349 ALWAYS_INLINE void LoadTwoScalars(const bfloat16** data, Packet* l1,
350 Packet* l2) {
351 if (kNumOperands >= 2) {
352 auto tmp = ConvertTwoBfloat16ToFloat(*data);
353 *l1 = Eigen::internal::pbroadcast_first<Packet>(tmp);
354 *l2 = Eigen::internal::pbroadcast_second<Packet>(tmp);
355 *data += 2;
356 } else {
357 LoadSingleScalar(data, l1);
358 LoadSingleScalar(data, l2);
359 }
360 }
361
362 ALWAYS_INLINE void LoadFourScalars(const bfloat16** data, Packet* l1,
363 Packet* l2, Packet* l3, Packet* l4) {
364 if (kNumOperands >= 4) {
365 auto tmp = ConvertFourBfloat16ToFloat(*data);
366 *l1 = Eigen::internal::pbroadcast_first<Packet>(tmp);
367 *l2 = Eigen::internal::pbroadcast_second<Packet>(tmp);
368 *l3 = Eigen::internal::pbroadcast_third<Packet>(tmp);
369 *l4 = Eigen::internal::pbroadcast_fourth<Packet>(tmp);
370 *data += 4;
371 } else {
372 LoadTwoScalars(data, l1, l2);
373 LoadTwoScalars(data, l3, l4);
374 }
375 }
376
377 ALWAYS_INLINE void LoadSingleScalar(const float** data, Packet* l) {
378 *l = Eigen::internal::pload1<Packet>(*data);
379 ++(*data);
380 }
381
382 ALWAYS_INLINE void LoadTwoScalars(const float** data, Packet* l1, Packet* l2) {
383 LoadSingleScalar(data, l1);
384 LoadSingleScalar(data, l2);
385 }
386
387 ALWAYS_INLINE void LoadFourScalars(const float** data, Packet* l1, Packet* l2,
388 Packet* l3, Packet* l4) {
389 LoadTwoScalars(data, l1, l2);
390 LoadTwoScalars(data, l3, l4);
391 }
392
393 template <typename T>
394 ALWAYS_INLINE void LoadThreeScalars(const T** data, Packet* l1, Packet* l2,
395 Packet* l3) {
396 LoadTwoScalars(data, l1, l2);
397 LoadSingleScalar(data, l3);
398 }
399
400 template <typename T>
401 ALWAYS_INLINE void LoadSixScalars(const T** data, Packet* l1, Packet* l2,
402 Packet* l3, Packet* l4, Packet* l5,
403 Packet* l6) {
404 LoadFourScalars(data, l1, l2, l3, l4);
405 LoadTwoScalars(data, l5, l6);
406 }
407
408 // Vectorized version of ScalarMulAdd.
409 ALWAYS_INLINE void MulAdd(const Packet a, const bfloat16** binp, float** out) {
410 auto inp = reinterpret_cast<const float*>(*binp);
411 const auto b = LOAD(inp);
412 EXPAND_BFLOAT_L(b, b_0);
413 EXPAND_BFLOAT_U(b, b_1);
414 *binp += 2 * kNumOperands;
415 auto c1 = LOAD(*out);
416 auto c2 = LOAD(*out + kNumOperands);
417 FMA(a, b_0, c1, c1);
418 FMA(a, b_1, c2, c2);
419 STORE(*out, c1);
420 STORE(*out + kNumOperands, c2);
421 *out += 2 * kNumOperands;
422 }
423
424 // Vectorized version of ScalarMulAdd3Way.
425 ALWAYS_INLINE void MulAdd3Way(const Packet a1, const Packet a2, const Packet a3,
426 const bfloat16** binp1, const bfloat16** binp2,
427 const bfloat16** binp3, float** out) {
428 auto inp1 = reinterpret_cast<const float*>(*binp1);
429 auto inp2 = reinterpret_cast<const float*>(*binp2);
430 auto inp3 = reinterpret_cast<const float*>(*binp3);
431 auto c1 = LOAD(*out);
432 auto c2 = LOAD(*out + kNumOperands);
433 const auto b1 = LOAD(inp1);
434 EXPAND_BFLOAT_L(b1, b1_0);
435 EXPAND_BFLOAT_U(b1, b1_1);
436 *binp1 += 2 * kNumOperands;
437 const auto b2 = LOAD(inp2);
438 EXPAND_BFLOAT_L(b2, b2_0);
439 EXPAND_BFLOAT_U(b2, b2_1);
440 *binp2 += 2 * kNumOperands;
441 const auto b3 = LOAD(inp3);
442 EXPAND_BFLOAT_L(b3, b3_0);
443 EXPAND_BFLOAT_U(b3, b3_1);
444 *binp3 += 2 * kNumOperands;
445 FMA(a1, b1_0, c1, c1);
446 FMA(a1, b1_1, c2, c2);
447 FMA(a2, b2_0, c1, c1);
448 FMA(a2, b2_1, c2, c2);
449 FMA(a3, b3_0, c1, c1);
450 FMA(a3, b3_1, c2, c2);
451 STORE(*out, c1);
452 STORE(*out + kNumOperands, c2);
453 *out += 2 * kNumOperands;
454 }
455
456 // Unroll MulAdd3Way for two iterations
457 ALWAYS_INLINE void TwoMulAdd3Way(const Packet a1, const Packet a2,
458 const Packet a3, const bfloat16** binp1,
459 const bfloat16** binp2, const bfloat16** binp3,
460 float** out) {
461 auto inp1 = reinterpret_cast<const float*>(*binp1);
462 auto inp2 = reinterpret_cast<const float*>(*binp2);
463 auto inp3 = reinterpret_cast<const float*>(*binp3);
464 auto c1 = LOAD(*out);
465 auto c2 = LOAD(*out + kNumOperands);
466 const auto b1 = LOAD(inp1);
467 const auto b2 = LOAD(inp2);
468 const auto b3 = LOAD(inp3);
469
470 EXPAND_BFLOAT_L(b1, b1_0);
471 EXPAND_BFLOAT_U(b1, b1_1);
472 EXPAND_BFLOAT_L(b2, b2_0);
473 EXPAND_BFLOAT_U(b2, b2_1);
474 EXPAND_BFLOAT_L(b3, b3_0);
475 EXPAND_BFLOAT_U(b3, b3_1);
476 auto c3 = LOAD(*out + 2 * kNumOperands);
477 auto c4 = LOAD(*out + 3 * kNumOperands);
478 const auto b4 = LOAD(inp1 + kNumOperands);
479 const auto b5 = LOAD(inp2 + kNumOperands);
480 const auto b6 = LOAD(inp3 + kNumOperands);
481
482 EXPAND_BFLOAT_L(b4, b4_0);
483 EXPAND_BFLOAT_U(b4, b4_1);
484 EXPAND_BFLOAT_L(b5, b5_0);
485 EXPAND_BFLOAT_U(b5, b5_1);
486 EXPAND_BFLOAT_L(b6, b6_0);
487 EXPAND_BFLOAT_U(b6, b6_1);
488
489 FMA(a1, b1_0, c1, c1);
490 FMA(a1, b1_1, c2, c2);
491 FMA(a1, b4_0, c3, c3);
492 FMA(a1, b4_1, c4, c4);
493 FMA(a2, b2_0, c1, c1);
494 FMA(a2, b2_1, c2, c2);
495 FMA(a2, b5_0, c3, c3);
496 FMA(a2, b5_1, c4, c4);
497 FMA(a3, b3_0, c1, c1);
498 FMA(a3, b3_1, c2, c2);
499 FMA(a3, b6_0, c3, c3);
500 FMA(a3, b6_1, c4, c4);
501 STORE(*out, c1);
502 STORE(*out + kNumOperands, c2);
503 STORE(*out + 2 * kNumOperands, c3);
504 STORE(*out + 3 * kNumOperands, c4);
505 *out += 4 * kNumOperands;
506 *binp1 += 4 * kNumOperands;
507 *binp2 += 4 * kNumOperands;
508 *binp3 += 4 * kNumOperands;
509 }
510
511 // Apply MulAdd3Way on 128 operands.
512 ALWAYS_INLINE void MulAdd3Way128(const Packet a1, const Packet a2,
513 const Packet a3, const bfloat16** inp1,
514 const bfloat16** inp2, const bfloat16** inp3,
515 float** out) {
516 for (int k = 0; k < 128 / (8 * kNumOperands); ++k) {
517 TwoMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
518 TwoMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
519 }
520 }
521
522 // Vectorized version of ScalarMulAdd
523 ALWAYS_INLINE void MulAdd(const Packet a, const float** inp, float** out) {
524 const auto b = LOAD(*inp);
525 *inp += kNumOperands;
526 auto c = LOAD(*out);
527 FMA(a, b, c, c);
528 STORE(*out, c);
529 *out += kNumOperands;
530 }
531
532 // Vectorized version of ScalarMulAdd3Way
533 ALWAYS_INLINE void MulAdd3Way(const Packet a1, const Packet a2, const Packet a3,
534 const float** inp1, const float** inp2,
535 const float** inp3, float** out) {
536 auto c = LOAD(*out);
537 const auto b1 = LOAD(*inp1);
538 *inp1 += kNumOperands;
539 const auto b2 = LOAD(*inp2);
540 *inp2 += kNumOperands;
541 const auto b3 = LOAD(*inp3);
542 *inp3 += kNumOperands;
543 FMA(a1, b1, c, c);
544 FMA(a2, b2, c, c);
545 FMA(a3, b3, c, c);
546 STORE(*out, c);
547 *out += kNumOperands;
548 }
549
550 // Unroll MulAdd3Way for two iterations
551 ALWAYS_INLINE void TwoMulAdd3Way(const Packet a1, const Packet a2,
552 const Packet a3, const float** inp1,
553 const float** inp2, const float** inp3,
554 float** out) {
555 auto c1 = LOAD(*out);
556 const auto b1 = LOAD(*inp1);
557 const auto b2 = LOAD(*inp2);
558 const auto b3 = LOAD(*inp3);
559
560 auto c2 = LOAD(*out + kNumOperands);
561 const auto b4 = LOAD(*inp1 + kNumOperands);
562 const auto b5 = LOAD(*inp2 + kNumOperands);
563 const auto b6 = LOAD(*inp3 + kNumOperands);
564
565 FMA(a1, b1, c1, c1);
566 FMA(a1, b4, c2, c2);
567 FMA(a2, b2, c1, c1);
568 FMA(a2, b5, c2, c2);
569 FMA(a3, b3, c1, c1);
570 FMA(a3, b6, c2, c2);
571 STORE(*out, c1);
572 STORE(*out + kNumOperands, c2);
573 *out += 2 * kNumOperands;
574 *inp1 += 2 * kNumOperands;
575 *inp2 += 2 * kNumOperands;
576 *inp3 += 2 * kNumOperands;
577 }
578
579 // Unroll MulAdd3Way for four iterations
580 ALWAYS_INLINE void FourMulAdd3Way(const Packet a1, const Packet a2,
581 const Packet a3, const float** inp1,
582 const float** inp2, const float** inp3,
583 float** out) {
584 TwoMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
585 TwoMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
586 }
587
588 // Apply MulAdd3Way on 128 operands.
589 ALWAYS_INLINE void MulAdd3Way128(const Packet a1, const Packet a2,
590 const Packet a3, const float** inp1,
591 const float** inp2, const float** inp3,
592 float** out) {
593 if (kNumOperands == 8) {
594 FourMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
595 FourMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
596 FourMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
597 FourMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
598 } else {
599 DCHECK_LE(4 * kNumOperands, 128);
600 for (int i = 0; i < 128 / (4 * kNumOperands); ++i) {
601 MulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
602 MulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
603 MulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
604 MulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
605 }
606 }
607 }
608 // Computes product of "left_slices" with "num_cols" columns of "right", and
609 // stores the output in *"output".
610 // Note that left_slices is a list of SparseSlices, which are conceptually
611 // assumed to be concatenated along the column dimension. Also, each SparseSlice
612 // is encoded as a list of blocks with up to N columns. See SparseSlice for more
613 // details.
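// Rough sketch of the traversal below (illustrative):
//   for each SparseSlice in left_slices:       // slices cover consecutive k ranges
//     for each column block i of the slice:    // k_offset advances by block_size
//       for each Index3 {m, k1, k2, k3} in the block (unrolled two at a time):
//         output_row(m) += l1 * right_row(k1) + l2 * right_row(k2) + l3 * right_row(k3)
//       for each remaining Index {m, k} in the block (unrolled four at a time):
//         output_row(m) += l * right_row(k)
// where l, l1, l2, l3 are the stored nonzero values and the += is vectorized over
// the "num_cols" columns of "right".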
614 template <typename TL, typename TR, int Cols>
615 inline void GEPP(
616 const std::vector<SparseSlice<TL>*>& left_slices,
617 const Eigen::TensorMap<Eigen::Tensor<const TR, 2, Eigen::RowMajor>,
618 Eigen::Aligned>& right,
619 const int num_cols, Matrix* output) {
620 const int cols = (Cols == -1) ? num_cols : Cols;
621 DCHECK_EQ(num_cols, cols);
622 const int right_num_cols = right.dimension(1);
623 const int output_num_cols = output->dimension(1);
624 static const int kNumOperandsR = kNumOperands * sizeof(float) / sizeof(TR);
625 const int cols_mod = cols % kNumOperandsR;
626 int k_offset = 0;
627 // Pre-compute pointers for output matrix.
628 float* out_ptrs[M];
629 float* const out_start = &(*output)(0, 0);
630 for (int j = 0; j < M; ++j) {
631 out_ptrs[j] = out_start + output_num_cols * j;
632 }
633 for (const auto* left_slice : left_slices) {
634 const auto& left = *left_slice;
635 const auto* data3 = (!left.data3.empty()) ? &left.data3[0] : nullptr;
636 const auto* data = (!left.data.empty()) ? &left.data[0] : nullptr;
637 const int num_blocks = left.index3_offset.size();
638 int begin3 = 0;
639 int begin = 0;
640 for (int i = 0; i < num_blocks; ++i) {
641 // Pre-compute pointers for right matrix
642 const TR* right_ptrs[K];
643 const auto* const right_start = &right(k_offset, 0);
644 DCHECK_LT(k_offset, right.dimension(0));
645 for (int j = 0; j < K; ++j) {
646 right_ptrs[j] = right_start + right_num_cols * j;
647 }
648
649 const int end3 = left.index3_offset[i];
650 int j = begin3;
651 // Loop unrolled for 2 iterations.
652 for (; j + 1 < end3; j += 2) {
653 Packet l1, l2, l3, nl1, nl2, nl3;
654 LoadSixScalars(&data3, &l1, &l2, &l3, &nl1, &nl2, &nl3);
655 const auto& index = left.index3[j];
656 const auto& nindex = left.index3[j + 1];
657 float* out = out_ptrs[index.m];
658 float* nout = out_ptrs[nindex.m];
659 const auto* r1 = right_ptrs[index.k1];
660 const auto* r2 = right_ptrs[index.k2];
661 const auto* r3 = right_ptrs[index.k3];
662
663 const auto* nr1 = right_ptrs[nindex.k1];
664 const auto* nr2 = right_ptrs[nindex.k2];
665 const auto* nr3 = right_ptrs[nindex.k3];
666 if (cols == 128) {
667 MulAdd3Way128(l1, l2, l3, &r1, &r2, &r3, &out);
668 MulAdd3Way128(nl1, nl2, nl3, &nr1, &nr2, &nr3, &nout);
669 } else {
670 for (int n = 0; n < cols / kNumOperandsR; ++n) {
671 MulAdd3Way(l1, l2, l3, &r1, &r2, &r3, &out);
672 MulAdd3Way(nl1, nl2, nl3, &nr1, &nr2, &nr3, &nout);
673 }
674
675 const float sl1 = Eigen::internal::pfirst<Packet>(l1);
676 const float sl2 = Eigen::internal::pfirst<Packet>(l2);
677 const float sl3 = Eigen::internal::pfirst<Packet>(l3);
678 const float nsl1 = Eigen::internal::pfirst<Packet>(nl1);
679 const float nsl2 = Eigen::internal::pfirst<Packet>(nl2);
680 const float nsl3 = Eigen::internal::pfirst<Packet>(nl3);
681 for (int k = 0; k < cols_mod; ++k) {
682 ScalarMulAdd3Way(sl1, sl2, sl3, &r1, &r2, &r3, &out);
683 ScalarMulAdd3Way(nsl1, nsl2, nsl3, &nr1, &nr2, &nr3, &nout);
684 }
685 }
686 }
687 if (j < end3) {
688 Packet l1, l2, l3;
689 LoadThreeScalars(&data3, &l1, &l2, &l3);
690
691 const auto& index = left.index3[j];
692 float* out = out_ptrs[index.m];
693 const auto* r1 = right_ptrs[index.k1];
694 const auto* r2 = right_ptrs[index.k2];
695 const auto* r3 = right_ptrs[index.k3];
696 if (cols == 128) {
697 MulAdd3Way128(l1, l2, l3, &r1, &r2, &r3, &out);
698 } else {
699 for (int n = 0; n < cols / kNumOperandsR; ++n) {
700 MulAdd3Way(l1, l2, l3, &r1, &r2, &r3, &out);
701 }
702 const float sl1 = Eigen::internal::pfirst<Packet>(l1);
703 const float sl2 = Eigen::internal::pfirst<Packet>(l2);
704 const float sl3 = Eigen::internal::pfirst<Packet>(l3);
705 for (int k = 0; k < cols_mod; ++k) {
706 ScalarMulAdd3Way(sl1, sl2, sl3, &r1, &r2, &r3, &out);
707 }
708 }
709 }
710 begin3 = end3;
711 int end = left.index_offset[i];
712 // Loop unrolled for 4 iterations.
713 j = begin;
714 for (; j + 3 < end; j += 4) {
715 Packet l, nl, n2l, n3l;
716 LoadFourScalars(&data, &l, &nl, &n2l, &n3l);
717
718 const auto& index = left.index[j];
719 const auto& nindex = left.index[j + 1];
720 const auto& n2index = left.index[j + 2];
721 const auto& n3index = left.index[j + 3];
722 const auto* r = right_ptrs[index.k];
723 const auto* nr = right_ptrs[nindex.k];
724 const auto* n2r = right_ptrs[n2index.k];
725 const auto* n3r = right_ptrs[n3index.k];
726 float* out = out_ptrs[index.m];
727 float* nout = out_ptrs[nindex.m];
728 float* n2out = out_ptrs[n2index.m];
729 float* n3out = out_ptrs[n3index.m];
730
731 for (int n = 0; n < cols / kNumOperandsR; ++n) {
732 MulAdd(l, &r, &out);
733 MulAdd(nl, &nr, &nout);
734 MulAdd(n2l, &n2r, &n2out);
735 MulAdd(n3l, &n3r, &n3out);
736 }
737
738 const float sl1 = Eigen::internal::pfirst<Packet>(l);
739 const float sl2 = Eigen::internal::pfirst<Packet>(nl);
740 const float sl3 = Eigen::internal::pfirst<Packet>(n2l);
741 const float sl4 = Eigen::internal::pfirst<Packet>(n3l);
742 for (int k = 0; k < cols_mod; ++k) {
743 ScalarMulAdd(sl1, &r, &out);
744 ScalarMulAdd(sl2, &nr, &nout);
745 ScalarMulAdd(sl3, &n2r, &n2out);
746 ScalarMulAdd(sl4, &n3r, &n3out);
747 }
748 }
749 while (j < end) {
750 Packet l;
751 LoadSingleScalar(&data, &l);
752 const auto& index = left.index[j];
753 const auto* r = right_ptrs[index.k];
754 float* out = out_ptrs[index.m];
755 for (int n = 0; n < cols / kNumOperandsR; ++n) {
756 MulAdd(l, &r, &out);
757 }
758 const float sl = Eigen::internal::pfirst<Packet>(l);
759 for (int k = 0; k < cols_mod; ++k) {
760 ScalarMulAdd(sl, &r, &out);
761 }
762 j++;
763 }
764 k_offset += left.block_size;
765 begin = end;
766 }
767 }
768 }
769
770 #undef LOAD
771 #undef EXPAND_BFLOAT_L
772 #undef EXPAND_BFLOAT_U
773 #undef STORE
774 #undef FMA
775
776 } // namespace
777
778 template <typename TL, typename TR>
779 class SparseMatMul {
780 using MatrixL = BasicMatrix<TL>;
781 using MatrixR = BasicMatrix<TR>;
782 using ConstMatrixMapL = BasicMatrixMap<const TL>;
783 using ConstMatrixMapR = BasicMatrixMap<const TR>;
784 using MatrixMapR = BasicMatrixMap<TR>;
785
786 public:
787 // Not used; added to match interface of LibxsmmSparseMatMul
788 struct TensorInfoCache {};
789
790 // Perform matrix multiplication of "left" and "right", and store the result
791 // in *"output".
792 public:
793 static inline void Compute(TensorInfoCache* cache,
794 const ConstMatrixMapL& left,
795 const ConstMatrixMapR& right, bool transpose_left,
796 const DeviceBase::CpuWorkerThreads* thread_pool,
797 bool transpose_output, MatrixMap* output);
798
799 private:
800 // Computes multiplication of left and num_cols columns of right, and stores
801 // the output block in *"output" at offsets "output_row_offset" and
802 // "output_col_offset". If assign is true, the result is assigned to that
803 // block; otherwise it is added to the existing values.
804 static inline void ComputeOutputBlock(
805 const std::vector<SparseSlice<TL>*>& left, const ConstMatrixMapR& right,
806 int num_cols, int output_row_offset, int output_col_offset, bool assign,
807 bool transpose_output, MatrixMap* output);
808
809 // Encodes "mat" using a sparse representation and stores that in
810 // "mat_slices". "mat" is broken into a grid with sizes "slice_num_rows" and
811 // "slice_num_cols", each grid element is converted into a SparseSlice and
812 // stored in mat_slices. "slice_block_size" is used to perform further column
813 // blocking of each slice.
814 static inline std::unique_ptr<BlockingCounter> CreateSparseSlices(
815 const ConstMatrixMapL& mat, bool transpose, int slice_num_rows,
816 int slice_block_size, int slice_num_cols,
817 std::vector<std::vector<SparseSlice<TL>*>>* mat_slices,
818 const DeviceBase::CpuWorkerThreads* thread_pool);
819
820 // This function chops "mat" along the column dimension into pieces with at
821 // most N columns and concatenates the pieces one after the other in "buffer".
822 // The pieces are returned in "slices", along with a BlockingCounter that
823 // should be used to wait for the shuffle operations to complete.
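// For example (illustrative), with N == 128 a 256-row x 300-column piece of "mat"
// is copied into "buffer" as three stacked 256 x 128 pieces (the last piece
// holding only the remaining 44 columns), and "slices" receives one
// ConstMatrixMapR per piece.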
824 static inline std::unique_ptr<BlockingCounter> CreateDenseSlices(
825 const ConstMatrixMapR& mat, int row_start, int num_rows, int col_start,
826 int num_cols, const DeviceBase::CpuWorkerThreads* thread_pool,
827 MatrixR* buffer, std::vector<ConstMatrixMapR*>* slices);
828
829 // Helper function for CreateDenseSlices to move the data around. It returns a
830 // BlockingCounter which should be used to wait for the shuffle operations to
831 // complete.
832 static inline BlockingCounter* ShuffleMatrix(
833 const ConstMatrixMapR& mat, int slice_row_start, int slice_num_rows,
834 int slice_col_start, int slice_num_cols, const int N,
835 const DeviceBase::CpuWorkerThreads* thread_pool, MatrixR* buffer);
836
837 // Helper function for CreateDenseSlices to create slices.
838 static inline void SliceMatrix(const MatrixR& mat, const int num_rows,
839 const int num_slices,
840 std::vector<ConstMatrixMapR*>* slices);
841
842 // Heuristics to compute various block sizes.
843 // KR, NR: block sizes for "right". We run blocking iterations that operate on
844 // matrices with at most this size.
845 // KL: grid size along the column dimension used while encoding left.
846 // IB, JB: number of left and right slices to multiply together. This is used
847 // for ordering different ComputeBlockOutput operations inside each blocking
848 // iteration so as to potentially reduce the working set size.
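//
// Worked example (hypothetical shapes, following the implementation of
// ComputeBlockSizes below): with num_threads = 8, est_num_cores = 4 and
// mem = 4 * 128 * 1024 = 524288. For a 2048 x 2048 "right", KR starts at
// min(2048, mem / 256) = 2048 and NR at 2048; since KR * NR > mem, NR is
// reduced to mem / KR = 256. With a 1024 x 2048 "left" (not transposed),
// KL = 1024 satisfies KR % KL == 0 and max(1, 1024 / 64) * (2048 / 1024) =
// 32 > 4. Finally JB = max(1, sqrt(8) / 2) = 1 and IB = 8 * JB = 8.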
849 static inline void ComputeBlockSizes(const ConstMatrixMapL& left,
850 const ConstMatrixMapR& right,
851 bool transpose_left, int num_threads,
852 int* KR, int* NR, int* KL, int* JB,
853 int* IB);
854
855 TF_DISALLOW_COPY_AND_ASSIGN(SparseMatMul);
856 };
857
858 #ifdef TENSORFLOW_USE_LIBXSMM
859 template <typename TL, typename TR>
860 class LibxsmmSparseMatMul {
861 using MatrixL = BasicMatrix<TL>;
862 using MatrixR = BasicMatrix<TR>;
863 using ConstMatrixMapL = BasicMatrixMap<const TL>;
864 using ConstMatrixMapR = BasicMatrixMap<const TR>;
865 using MatrixMapR = BasicMatrixMap<TR>;
866
867 public:
868 // This structure contains a set of libxsmm kernels for sizes that have been
869 // encountered previously by this operator so that libxsmm does not need to
870 // reallocate its scratchpad memory each time (which hurts performance
871 // substantially).
872 struct TensorInfoCache {
873 struct TensorInfoCacheEntry {
874 // Parameters for kernel
875 int M;
876 int K;
877 int N;
878 int max_threads;
879 // libxsmm handle and matrix data
880 libxsmm_spmdm_handle handle;
881 libxsmm_CSR_sparseslice* output_csr;
882 // Chain to non-libxsmm implementation's cache in case that ever becomes
883 // useful (it is an empty struct right now)
884 typename SparseMatMul<TL, TR>::TensorInfoCache
885 non_libxsmm_cache; // Currently not used
886 };
887 // protects entries; invariant: entries is a valid std::multimap
888 tensorflow::mutex lock;
889 // Because there could be multiple matrix multiplies with the same sizes
890 // going on at the same time, we need to allow multiple cache entries for a
891 // given set of parameters. Taking and returning entries is used to make
892 // sure the same cache entry is not used from two threads at a time.
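// Typical usage (sketch): a Compute() call takes an entry with
// take_cache_entry(M, K, N, max_threads), uses entry->handle for the duration
// of one multiply, and hands it back with return_cache_entry() when done.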
893 std::multimap<std::tuple<int, int, int, int>,
894 std::unique_ptr<TensorInfoCacheEntry>>
895 entries TF_GUARDED_BY(lock);
896
897 TensorInfoCache() : lock(), entries() {}
898 // Look up and remove first entry with these parameters, creating one if
899 // there isn't one
900 std::unique_ptr<TensorInfoCacheEntry> take_cache_entry(int M, int K, int N,
901 int max_threads)
902 TF_LOCKS_EXCLUDED(lock) {
903 tensorflow::mutex_lock ml(lock);
904 auto key = std::make_tuple(M, K, N, max_threads);
905 auto it = entries.find(key);
906 if (it != entries.end()) {
907 auto val = std::move(it->second);
908 entries.erase(it);
909 return val;
910 } else {
911 std::unique_ptr<TensorInfoCacheEntry> e{
912 new TensorInfoCacheEntry{M, K, N, max_threads, {}, nullptr}};
913 // setup scoped allocator, which uses cpu_allocator() for this scope
914 const libxsmm_tf_allocator<libxsmm_scratch_allocator> tf_allocator;
915 libxsmm_spmdm_init(M, N, K, max_threads, &e->handle, &e->output_csr);
916 return e;
917 }
918 }
919 // Add a cache entry with certain parameters
920 void return_cache_entry(std::unique_ptr<TensorInfoCacheEntry> e)
921 TF_LOCKS_EXCLUDED(lock) {
922 tensorflow::mutex_lock ml(lock);
923 auto key = std::make_tuple(e->M, e->K, e->N, e->max_threads);
924 entries.insert(std::make_pair(key, std::move(e)));
925 }
926 ~TensorInfoCache() {
927 tensorflow::mutex_lock ml(lock);
928 for (auto& p : entries) {
929 libxsmm_spmdm_destroy(&p.second->handle);
930 }
931 entries.clear();
932 }
933
934 private:
935 TF_DISALLOW_COPY_AND_ASSIGN(TensorInfoCache);
936 };
937
938 // Perform matrix multiplication of "left" and "right", and store the result
939 // in *"output".
940 public:
941 static inline void Compute(TensorInfoCache* cache,
942 const ConstMatrixMapL& left,
943 const ConstMatrixMapR& right, bool transpose_left,
944 const DeviceBase::CpuWorkerThreads* thread_pool,
945 bool transpose_output, MatrixMap* output);
946
947 private:
948 TF_DISALLOW_COPY_AND_ASSIGN(LibxsmmSparseMatMul);
949 };
950 #endif
951
952 template <typename TL, typename TR,
953 template <typename TL2, typename TR2> class DoMatMul>
954 class SparseMatMulOp : public OpKernel {
955 using MatrixR = BasicMatrix<TR>;
956 using ConstMatrixMapR = BasicMatrixMap<const TR>;
957
958 public:
959 explicit SparseMatMulOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
960 OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_a", &transpose_a_));
961 OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_b", &transpose_b_));
962 OP_REQUIRES_OK(ctx, ctx->GetAttr("a_is_sparse", &a_is_sparse_));
963 OP_REQUIRES_OK(ctx, ctx->GetAttr("b_is_sparse", &b_is_sparse_));
964 }
965
966 void Compute(OpKernelContext* ctx) override {
967 const Tensor& a = ctx->input(0);
968 const Tensor& b = ctx->input(1);
969 OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(a.shape()),
970 errors::InvalidArgument("a is not a matrix"));
971 OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(b.shape()),
972 errors::InvalidArgument("b is not a matrix"));
973
974 const int m = transpose_a_ ? a.dim_size(1) : a.dim_size(0);
975 const int k = transpose_a_ ? a.dim_size(0) : a.dim_size(1);
976 const int n = transpose_b_ ? b.dim_size(0) : b.dim_size(1);
977 const int k2 = transpose_b_ ? b.dim_size(1) : b.dim_size(0);
978
979 OP_REQUIRES(ctx, k == k2,
980 errors::InvalidArgument(
981 "Matrix size incompatible: a: ", a.shape().DebugString(),
982 ", b: ", b.shape().DebugString()));
983 Tensor* output = nullptr;
984 OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({m, n}), &output));
985
986 if (k == 0) {
987 // If the inner dimension k in the matrix multiplication is zero, we fill
988 // the output with zeros.
989 functor::SetZeroFunctor<CPUDevice, float> f;
990 f(ctx->eigen_device<CPUDevice>(), output->flat<float>());
991 return;
992 }
993
994 auto out = output->matrix<float>();
995
996 std::unique_ptr<Tensor> a_float;
997 std::unique_ptr<Tensor> b_float;
998 if (!a_is_sparse_ && !b_is_sparse_) {
999 auto left = &a;
1000 auto right = &b;
1001 // TODO(agarwal): multi-thread the conversions from bfloat16 to float.
1002 if (std::is_same<TL, bfloat16>::value) {
1003 a_float.reset(new Tensor(DT_FLOAT, a.shape()));
1004 BFloat16ToFloat(a.flat<bfloat16>().data(),
1005 a_float->flat<float>().data(), a.NumElements());
1006 left = a_float.get();
1007 }
1008 if (std::is_same<TR, bfloat16>::value) {
1009 b_float.reset(new Tensor(DT_FLOAT, b.shape()));
1010 BFloat16ToFloat(b.flat<bfloat16>().data(),
1011 b_float->flat<float>().data(), b.NumElements());
1012 right = b_float.get();
1013 }
1014 Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
1015 dim_pair[0].first = transpose_a_ ? 0 : 1;
1016 dim_pair[0].second = transpose_b_ ? 1 : 0;
1017
1018 out.device(ctx->template eigen_device<CPUDevice>()) =
1019 left->matrix<float>().contract(right->matrix<float>(), dim_pair);
1020 return;
1021 }
1022
1023 auto left = &a;
1024 auto right = &b;
1025 bool transpose_output = false;
1026 bool transpose_a = transpose_a_;
1027 bool transpose_b = transpose_b_;
1028 if (!a_is_sparse_) {
1029 // Swap the order of multiplications using the identity:
1030 // A * B = (B' * A')'.
1031 std::swap(left, right);
1032 std::swap(transpose_a, transpose_b);
1033 transpose_a = !transpose_a;
1034 transpose_b = !transpose_b;
1035 transpose_output = !transpose_output;
1036 }
1037
1038 std::unique_ptr<Tensor> right_tr;
1039 if (transpose_b) {
1040 // TODO(agarwal): avoid transposing the matrix here and directly handle
1041 // transpose in CreateDenseSlices.
1042 right_tr.reset(
1043 new Tensor(right->dtype(),
1044 TensorShape({right->dim_size(1), right->dim_size(0)})));
1045
1046 const auto perm = dsizes_10();
1047 if (transpose_output) {
1048 right_tr->matrix<TL>().device(ctx->template eigen_device<CPUDevice>()) =
1049 right->matrix<TL>().shuffle(perm);
1050 } else {
1051 right_tr->matrix<TR>().device(ctx->template eigen_device<CPUDevice>()) =
1052 right->matrix<TR>().shuffle(perm);
1053 }
1054 right = right_tr.get();
1055 }
1056
1057 if (transpose_output) {
1058 DoMatMul<TR, TL>::Compute(&this->cache_tr_, left->matrix<TR>(),
1059 right->matrix<TL>(), transpose_a,
1060 ctx->device()->tensorflow_cpu_worker_threads(),
1061 transpose_output, &out);
1062 } else {
1063 DoMatMul<TL, TR>::Compute(&this->cache_nt_, left->matrix<TL>(),
1064 right->matrix<TR>(), transpose_a,
1065 ctx->device()->tensorflow_cpu_worker_threads(),
1066 transpose_output, &out);
1067 }
1068 }
1069
1070 private:
1071 bool transpose_a_;
1072 bool transpose_b_;
1073 bool a_is_sparse_;
1074 bool b_is_sparse_;
1075
1076 // Cache for non-transposed-output multiply
1077 typename DoMatMul<TL, TR>::TensorInfoCache cache_nt_;
1078 // Cache for transposed-output multiply
1079 typename DoMatMul<TR, TL>::TensorInfoCache cache_tr_;
1080
1081 TF_DISALLOW_COPY_AND_ASSIGN(SparseMatMulOp);
1082 };
1083
1084 template <typename TL, typename TR>
1085 inline void SparseMatMul<TL, TR>::ComputeOutputBlock(
1086 const std::vector<SparseSlice<TL>*>& left,
1087 const typename SparseMatMul<TL, TR>::ConstMatrixMapR& right, int num_cols,
1088 int output_row_offset, int output_col_offset, bool assign,
1089 bool transpose_output, MatrixMap* output) {
1090 const auto perm = dsizes_10();
1091 int num_rows = left[0]->num_rows;
1092 const int rhs_num_cols = right.dimension(1);
1093 DCHECK_LE(num_cols, rhs_num_cols);
1094 Matrix out(num_rows, rhs_num_cols);
1095 out.setZero();
1096 if (num_cols == N) {
1097 GEPP<TL, TR, N>(left, right, num_cols, &out);
1098 } else {
1099 GEPP<TL, TR, -1>(left, right, num_cols, &out);
1100 }
1101 if (!assign) {
1102 const DSizes begin(output_row_offset, output_col_offset);
1103 const DSizes sizes(num_rows, num_cols);
1104 if (transpose_output) {
1105 if (num_cols == rhs_num_cols) {
1106 output->shuffle(perm).slice(begin, sizes) += out;
1107 } else {
1108 const auto zero = dsizes_00();
1109 output->shuffle(perm).slice(begin, sizes) += out.slice(zero, sizes);
1110 }
1111 } else {
1112 if (num_cols == rhs_num_cols) {
1113 output->slice(begin, sizes) += out;
1114 } else {
1115 const auto zero = dsizes_00();
1116 output->slice(begin, sizes) += out.slice(zero, sizes);
1117 }
1118 }
1119 } else {
1120 std::unique_ptr<Matrix> out_tr;
1121 if (transpose_output) {
1122 out_tr.reset(new Matrix(rhs_num_cols, num_rows));
1123 *out_tr = out.shuffle(perm);
1124 std::swap(output_row_offset, output_col_offset);
1125 std::swap(num_rows, num_cols);
1126 }
1127 const Matrix& final_out = transpose_output ? *out_tr : out;
1128 for (int i = 0; i < num_rows; ++i) {
1129 memcpy(&(*output)(output_row_offset + i, output_col_offset),
1130 &final_out(i, 0), num_cols * sizeof(float));
1131 }
1132 }
1133 }
1134
1135 template <typename TL, typename TR>
1136 inline std::unique_ptr<BlockingCounter>
1137 SparseMatMul<TL, TR>::CreateSparseSlices(
1138 const typename SparseMatMul<TL, TR>::ConstMatrixMapL& mat, bool transpose,
1139 int slice_num_rows, int slice_block_size, int slice_num_cols,
1140 std::vector<std::vector<SparseSlice<TL>*>>* mat_slices,
1141 const DeviceBase::CpuWorkerThreads* thread_pool) {
1142 const int mat_num_rows = transpose ? mat.dimension(1) : mat.dimension(0);
1143 const int mat_num_cols = transpose ? mat.dimension(0) : mat.dimension(1);
1144 const int num_slices_dim0 =
1145 std::max(1, (mat_num_rows + slice_num_rows - 1) / slice_num_rows);
1146 const int num_slices_dim1 =
1147 std::max(1, (mat_num_cols + slice_num_cols - 1) / slice_num_cols);
1148 mat_slices->resize(num_slices_dim0);
1149 BlockingCounter* counter =
1150 new BlockingCounter(num_slices_dim0 * num_slices_dim1);
1151 auto work = [counter, transpose](SparseSlice<TL>* sparse_slice,
1152 SparseMatMul<TL, TR>::ConstMatrixMapL* slice,
1153 int col_offset) {
1154 if (transpose) {
1155 sparse_slice->template Initialize<true>(*slice, col_offset);
1156 } else {
1157 sparse_slice->template Initialize<false>(*slice, col_offset);
1158 }
1159 delete slice;
1160 counter->DecrementCount();
1161 };
1162 for (int i = 0; i < num_slices_dim0; ++i) {
1163 (*mat_slices)[i].resize(num_slices_dim1);
1164 int num_rows =
1165 std::min<int>(slice_num_rows, mat_num_rows - i * slice_num_rows);
1166 for (int j = 0; j < num_slices_dim1; ++j) {
1167 int num_cols =
1168 std::min<int>(slice_num_cols, mat_num_cols - j * slice_num_cols);
1169 SparseMatMul<TL, TR>::ConstMatrixMapL* slice = nullptr;
1170 if (transpose) {
1171 slice = new SparseMatMul<TL, TR>::ConstMatrixMapL(
1172 &mat(0, i * slice_num_rows), mat.dimensions());
1173 } else {
1174 DSizes d(num_rows, mat_num_cols);
1175 slice = new SparseMatMul<TL, TR>::ConstMatrixMapL(
1176 &mat(i * slice_num_rows, 0), d);
1177 }
1178 auto* sparse_slice =
1179 new SparseSlice<TL>(num_rows, num_cols, slice_block_size);
1180 (*mat_slices)[i][j] = sparse_slice;
1181 thread_pool->workers->Schedule(
1182 [=]() { work(sparse_slice, slice, slice_num_cols * j); });
1183 }
1184 }
1185 return std::unique_ptr<BlockingCounter>(counter);
1186 }
1187 #define LOAD(x) Eigen::internal::ploadu<Packet>((x));
1188 #define INTERLEAVE(x) Eigen::internal::pinterleave4x64<Packet>(x);
1189 #define STORE(x, y) Eigen::internal::pstoreu<float>(x, y);
1190
1191 template <int NUM_ELEM = -1>
1192 ALWAYS_INLINE void CopyAndMayBeInterleaveBfloat16(void* bdst, const void* bsrc,
1193 int num_elements) {
1194 DCHECK_GE(kNumOperands, 8);
1195 static const int kStep = kNumOperands * sizeof(float) / sizeof(bfloat16);
1196 const int num = (NUM_ELEM == -1) ? num_elements : NUM_ELEM;
1197 DCHECK_EQ(num, num_elements);
1198 const float* src = reinterpret_cast<const float*>(bsrc);
1199 float* dst = reinterpret_cast<float*>(bdst);
1200 for (int index = 0; index + kStep <= num; index += kStep) {
1201 auto in = LOAD(src);
1202 auto tmp = INTERLEAVE(in);
1203 STORE(dst, tmp);
1204 src += kNumOperands;
1205 dst += kNumOperands;
1206 }
1207 if (num % kStep != 0) {
1208 memcpy(dst, src, (num % kStep) * sizeof(bfloat16));
1209 }
1210 }
1211
1212 template <typename T>
1213 ALWAYS_INLINE void CopyAndMayBeInterleave(void* dst, const void* src,
1214 int num_elements) {
1215 if (std::is_same<T, float>::value || kNumOperands < 8) {
1216 memcpy(dst, src, num_elements * sizeof(T));
1217 } else if (std::is_same<T, bfloat16>::value) {
1218 if (num_elements == N) {
1219 CopyAndMayBeInterleaveBfloat16<N>(dst, src, num_elements);
1220 } else {
1221 CopyAndMayBeInterleaveBfloat16<-1>(dst, src, num_elements);
1222 }
1223 } else {
1224 LOG(FATAL) << "Unsupported type";
1225 }
1226 }
1227
1228 #undef LOAD
1229 #undef INTERLEAVE
1230 #undef STORE
1231
1232 template <typename TL, typename TR>
1233 inline BlockingCounter* SparseMatMul<TL, TR>::ShuffleMatrix(
1234 const typename SparseMatMul<TL, TR>::ConstMatrixMapR& mat,
1235 int slice_row_start, int slice_num_rows, int slice_col_start,
1236 int slice_num_cols, const int N,
1237 const DeviceBase::CpuWorkerThreads* thread_pool, MatrixR* buffer) {
1238 DCHECK_EQ(N % 2, 0);
1239 DCHECK_LE(kNumOperands * sizeof(float) / sizeof(TR), N);
1240 // Note(nikhilsarda): This heuristic is optimal in benchmarks as of
1241 // Jan 21, 2020.
1242 int num_threads = std::min(thread_pool->num_threads, 8);
1243 BlockingCounter* counter = new BlockingCounter(num_threads);
1244 DCHECK_EQ(N, buffer->dimension(1));
1245 auto shuffle_work = [&mat, slice_row_start, slice_num_rows, slice_col_start,
1246 slice_num_cols, N, buffer, counter](int s, int e) {
1247 const int row_start = s % slice_num_rows + slice_row_start;
1248 const int col_start = s / slice_num_rows * N + slice_col_start;
1249 auto* out_start = &(*buffer)(s, 0);
1250 const auto* input_start = &mat(row_start, col_start);
1251 const auto* input_end = &mat(slice_row_start + slice_num_rows - 1,
1252 slice_col_start + slice_num_cols - 1);
1253 const int mat_num_cols = mat.dimension(1);
1254 const int row_slice_size = slice_num_rows * mat_num_cols;
1255
1256 const int aligned_end = slice_num_cols / N * slice_num_rows;
1257 const int e1 = std::min(e, aligned_end);
1258 while (s < e1) {
1259 CopyAndMayBeInterleave<TR>(out_start, input_start, N);
1260 out_start += N;
1261 input_start += mat_num_cols;
1262 if (input_start > input_end) {
1263 input_start = input_start - row_slice_size + N;
1264 }
1265 ++s;
1266 }
1267 int s1 = std::max(s, aligned_end);
1268 const int copy_num_cols = slice_num_cols % N;
1269 while (s1 < e) {
1270 CopyAndMayBeInterleave<TR>(out_start, input_start, copy_num_cols);
1271 out_start += N;
1272 input_start += mat_num_cols;
1273 ++s1;
1274 }
1275 if (counter) counter->DecrementCount();
1276 };
1277
1278 int start = 0;
1279 int end = 0;
1280 int num_out_rows = (slice_num_cols + N - 1) / N * slice_num_rows;
1281 DCHECK_LE(num_out_rows, buffer->dimension(0));
1282 for (int i = std::max(1, num_threads); i > 0; --i) {
1283 end = start + num_out_rows / i;
1284 thread_pool->workers->Schedule([=]() { shuffle_work(start, end); });
1285 num_out_rows -= (end - start);
1286 start = end;
1287 }
1288 return counter;
1289 }
1290
1291 template <typename TL, typename TR>
1292 inline void SparseMatMul<TL, TR>::SliceMatrix(
1293 const MatrixR& mat, const int num_rows, const int num_slices,
1294 std::vector<typename SparseMatMul<TL, TR>::ConstMatrixMapR*>* slices) {
1295 slices->resize(num_slices);
1296 DSizes d(num_rows, mat.dimension(1));
1297 DCHECK_LE(num_rows * num_slices, mat.dimension(0));
1298 for (int i = 0; i < num_slices; ++i) {
1299 (*slices)[i] = new ConstMatrixMapR(&mat(i * num_rows, 0), d);
1300 }
1301 }
1302
1303 template <typename TL, typename TR>
1304 inline std::unique_ptr<BlockingCounter> SparseMatMul<TL, TR>::CreateDenseSlices(
1305 const typename SparseMatMul<TL, TR>::ConstMatrixMapR& mat, int row_start,
1306 int num_rows, int col_start, int num_cols,
1307 const DeviceBase::CpuWorkerThreads* thread_pool, MatrixR* buffer,
1308 std::vector<typename SparseMatMul<TL, TR>::ConstMatrixMapR*>* slices) {
1309 std::unique_ptr<BlockingCounter> shuffle_counter(ShuffleMatrix(
1310 mat, row_start, num_rows, col_start, num_cols, N, thread_pool, buffer));
1311 const int num_slices = (num_cols + N - 1) / N;
1312 SliceMatrix(*buffer, num_rows, num_slices, slices);
1313 return shuffle_counter;
1314 }
1315
1316 template <typename TL, typename TR>
1317 inline void SparseMatMul<TL, TR>::ComputeBlockSizes(
1318 const typename SparseMatMul<TL, TR>::ConstMatrixMapL& left,
1319 const typename SparseMatMul<TL, TR>::ConstMatrixMapR& right,
1320 bool transpose_left, int num_threads, int* KR, int* NR, int* KL, int* JB,
1321 int* IB) {
1322 // Heuristics for calculating block sizes
1323 // Assume two hyperthreads per core.
1324 const int est_num_cores = std::max(1, (num_threads + 1) / 2);
1325 // Use block of rhs with at most 128K floats per core.
1326 const int mem = est_num_cores * 128 * 1024;
1327 *KR = std::min(static_cast<int>(right.dimension(0)), mem / 256);
1328 *NR = right.dimension(1);
1329 if (*KR * *NR > mem) {
1330 // 4096 may be enough to amortize the cost of writes.
1331 *KR = std::min<int>(*KR, 4096);
1332 }
1333 // Use sizes that are multiples of K and 256.
1334 *KR = std::max(1, *KR / K) * K;
1335 *NR = std::max(1, *NR / 256) * 256;
1336 if (*KR * *NR > mem) {
1337 *NR = mem / *KR;
1338 }
1339 *NR = std::max(1, *NR / 256) * 256;
1340
1341 const int left_dim0 = transpose_left ? left.dimension(1) : left.dimension(0);
1342 const int left_dim1 = transpose_left ? left.dimension(0) : left.dimension(1);
1343 for (*KL = 1024; *KL > K; *KL /= 2) {
1344 if (*KR % *KL == 0 &&
1345 std::max<int>(1, left_dim0 / 64) * (left_dim1 / *KL) > est_num_cores) {
1346 break;
1347 }
1348 }
1349 DCHECK_EQ(*KL % K, 0);
1350 DCHECK_GE(*KR, *KL);
1351 if (*KR < right.dimension(0)) {
1352 CHECK_EQ(*KR % *KL, 0);
1353 }
1354
1355 *JB = std::max(1, static_cast<int>(sqrt(num_threads) / 2.0));
1356 *IB = 8 * *JB;
1357 DCHECK_EQ(N * sizeof(float) % 64, size_t{0});
1358 }
1359
1360 #ifdef TENSORFLOW_USE_LIBXSMM
1361
1362 template <typename F>
1363 void do_on_all_threads(const DeviceBase::CpuWorkerThreads* thread_pool,
1364 const F& f) {
1365 int num_threads = thread_pool->num_threads;
1366 if (num_threads == 0) {
1367 LOG(FATAL) << "Have 0 threads in thread pool";
1368 } else if (num_threads == 1) {
1369 f(0);
1370 } else {
1371 BlockingCounter counter(num_threads - 1);
1372 for (int i = 1; i < num_threads; ++i) {
1373 thread_pool->workers->Schedule([&, i]() {
1374 f(i);
1375 counter.DecrementCount();
1376 });
1377 }
1378 f(0);
1379 counter.Wait();
1380 }
1381 }
1382
1383 template <typename T>
1384 struct empty_type_wrapper {};
1385
1386 // Copies of interface to libxsmm_spmdm_createSparseSlice_*_notrans_thread to
1387 // allow overloading
1388 void wrapper_libxsmm_spmdm_createSparseSlice_generic_thread(
1389 empty_type_wrapper<float>, const libxsmm_spmdm_handle* handle, char transA,
1390 const float* A, libxsmm_CSR_sparseslice* libxsmm_output_csr_a, int block_id,
1391 int tid, int nthreads) {
1392 return libxsmm_spmdm_createSparseSlice_fp32_thread(
1393 handle, transA, A, libxsmm_output_csr_a, block_id, tid, nthreads);
1394 }
1395 void wrapper_libxsmm_spmdm_createSparseSlice_generic_thread(
1396 empty_type_wrapper<bfloat16>, const libxsmm_spmdm_handle* handle,
1397 char transA, const bfloat16* A,
1398 libxsmm_CSR_sparseslice* libxsmm_output_csr_a, int block_id, int tid,
1399 int nthreads) {
1400 return libxsmm_spmdm_createSparseSlice_bfloat16_thread(
1401 handle, transA, reinterpret_cast<const libxsmm_bfloat16*>(A),
1402 libxsmm_output_csr_a, block_id, tid, nthreads);
1403 }
1404
1405 void wrapper_libxsmm_spmdm_compute_generic_thread(
1406 empty_type_wrapper<bfloat16>, const libxsmm_spmdm_handle* handle,
1407 char transA, char transB, const bfloat16* alpha,
1408 libxsmm_CSR_sparseslice* A_sparse, const bfloat16* B, char transC,
1409 const bfloat16* beta, float* C, int block_id, int tid, int nthreads) {
1410 return libxsmm_spmdm_compute_bfloat16_thread(
1411 handle, transA, transB, reinterpret_cast<const libxsmm_bfloat16*>(alpha),
1412 A_sparse, reinterpret_cast<const libxsmm_bfloat16*>(B), transC,
1413 reinterpret_cast<const libxsmm_bfloat16*>(beta), C, block_id, tid,
1414 nthreads);
1415 }
1416 void wrapper_libxsmm_spmdm_compute_generic_thread(
1417 empty_type_wrapper<float>, const libxsmm_spmdm_handle* handle, char transA,
1418 char transB, const float* alpha, libxsmm_CSR_sparseslice* A_sparse,
1419 const float* B, char transC, const float* beta, float* C, int block_id,
1420 int tid, int nthreads) {
1421 return libxsmm_spmdm_compute_fp32_thread(handle, transA, transB, alpha,
1422 A_sparse, B, transC, beta, C,
1423 block_id, tid, nthreads);
1424 }
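// Dispatch sketch (illustrative only): instantiating Compute below with
// TL = bfloat16 resolves a call like
//
//   wrapper_libxsmm_spmdm_createSparseSlice_generic_thread(
//       empty_type_wrapper<bfloat16>{}, &handle, 'N', A, csr, block, tid, n);
//
// to libxsmm_spmdm_createSparseSlice_bfloat16_thread; the fp32 overloads are
// chosen analogously for TL = float. The argument names here (handle, A, csr,
// block, tid, n) are placeholders, not variables from this file.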

template <typename TL, typename TR>
inline void LibxsmmSparseMatMul<TL, TR>::Compute(
    typename LibxsmmSparseMatMul<TL, TR>::TensorInfoCache* cache,
    const typename LibxsmmSparseMatMul<TL, TR>::ConstMatrixMapL& left,
    const typename LibxsmmSparseMatMul<TL, TR>::ConstMatrixMapR& right,
    bool transpose_left, const DeviceBase::CpuWorkerThreads* thread_pool,
    bool transpose_output, MatrixMap* output) {
  const int num_threads = thread_pool->num_threads;
  const int left_dim0 = transpose_left ? left.dimension(1) : left.dimension(0);
  const int left_dim1 = transpose_left ? left.dimension(0) : left.dimension(1);
  const int right_dim0 = right.dimension(0);
  const int right_dim1 = right.dimension(1);
  CHECK_EQ(left_dim1, right_dim0);
  CHECK_EQ(left_dim0,
           (transpose_output ? output->dimension(1) : output->dimension(0)));
  CHECK_EQ(right_dim1,
           (transpose_output ? output->dimension(0) : output->dimension(1)));
#if 0  // this issue seems to be resolved
  if (left_dim0 < 32 || left_dim1 < 32 || right_dim1 < 32) {
    // Causes problems in libxsmm
    SparseMatMul<TL, TR>::Compute(
        nullptr /* Assumes no cached data for fallback */, left, right,
        transpose_left, thread_pool, transpose_output, output);
    return;
  }
#endif
  auto left_data = left.data();
  auto right_data = right.data();
  auto output_data = output->data();
  // Initialize libxsmm for this matrix; make sure another thread doesn't use
  // this handle
  auto entry =
      cache->take_cache_entry(left_dim0, right_dim0, right_dim1, num_threads);
  // Convert the left matrix to compressed sparse row (CSR) format
  ptrdiff_t total_num_creation_blocks =
      libxsmm_spmdm_get_num_createSparseSlice_blocks(&entry->handle);
  std::atomic<int> cur_create_block_number;
  cur_create_block_number.store(0);
  do_on_all_threads(thread_pool, [&](int i) {
    while (true) {
      int work_item = cur_create_block_number.fetch_add(1);
      if (work_item >= total_num_creation_blocks) break;
      wrapper_libxsmm_spmdm_createSparseSlice_generic_thread(
          empty_type_wrapper<TL>{}, &entry->handle,
          (transpose_left ? 'T' : 'N'), left_data, entry->output_csr, work_item,
          i, num_threads);
    }
  });
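  // The pattern above (and in the compute phase below) hands out block ids
  // through an atomic counter, so threads that finish early simply claim the
  // next block rather than waiting on a static partition of the work.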
  // Do matrix-matrix multiplication
  ptrdiff_t total_num_mult_blocks =
      libxsmm_spmdm_get_num_compute_blocks(&entry->handle);
  std::atomic<int> cur_mult_block_number;
  cur_mult_block_number.store(0);
  do_on_all_threads(thread_pool, [&](int i) {
    while (true) {
      int work_item = cur_mult_block_number.fetch_add(1);
      if (work_item >= total_num_mult_blocks) break;
      const TL alpha(1.0);  // Stored in a variable so we can get a pointer
      const TL beta(0.0);   // Stored in a variable so we can get a pointer
      wrapper_libxsmm_spmdm_compute_generic_thread(
          empty_type_wrapper<TL>{}, &entry->handle,
          (transpose_left ? 'T' : 'N'), 'N', &alpha, entry->output_csr,
          right_data, (transpose_output ? 'T' : 'N'), &beta, output_data,
          work_item, i, num_threads);
    }
  });
  // Put handle + CSR storage back into cache
  cache->return_cache_entry(std::move(entry));
}

#endif  // TENSORFLOW_USE_LIBXSMM

// Here is an overview of the SparseMatMul code. Note that we assume that the
// left matrix is sparse.
//
// The matrix "left" is divided into a grid with blocksize of (M, KL). Each
// block is encoded as a SparseSlice. These grid elements are stored as
// std::vector<std::vector<SparseSlice>>. Each element of the outer vector
// represents M rows of the left matrix. Let's call these elements l_i and
// let's call each element of the inner vector L_mk.
//
// The matrix "right" is divided into a grid with block size KR * NR. Let's
// denote the blocks on the right as R_kn. Note that we ensure that KL divides
// KR so that for each element R_kn, we don't need to multiply it with any
// partial L_mk blocks.
//
// We then multiply each right side block R_kn with the full "left" matrix and
// update the output. These iterations are run sequentially since R_kn are
// packed into the same underlying temporary buffer.
//
// In each iteration we do the following:
// 1. Create slices r_j of R_kn: We split R_kn into vertical blocks with N
//    (=128) columns and then concatenate these slices into a buffer. This is
//    done so that each slice r_j of R_kn is stored contiguously in memory.
//    Note that if R_kn has dimensions (KR, NR), we create NR / N slices, and
//    the buffer has dimensions (KR * NR / N, N) (assuming N divides NR).
// 2. For each (l_i, r_j), we compute the inner product using the GEPP function
//    and update the output block o_ij. These calls are further blocked to
//    reduce the working set size. In each iteration we take IB elements from
//    {l_i} and JB elements from {r_j} and compute the IB * JB inner products.
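// Rough shape of the loops below (pseudocode, for orientation only):
//
//   for each NR-wide column block nb of "right":
//     for each KR-tall row block kb of "right":
//       pack the block into `buffer` as contiguous N-column slices r_j
//       for each (IB, JB) tile of (left slices, right slices):
//         for each (l_i, r_j) pair in the tile:
//           queue ComputeOutputBlock(matching L_mk blocks, r_j, ...)
//       run the queued tasks on the thread pool and wait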
template <typename TL, typename TR>
inline void SparseMatMul<TL, TR>::Compute(
    typename SparseMatMul<TL, TR>::TensorInfoCache* /*cache*/,
    const typename SparseMatMul<TL, TR>::ConstMatrixMapL& left,
    const typename SparseMatMul<TL, TR>::ConstMatrixMapR& right,
    bool transpose_left, const DeviceBase::CpuWorkerThreads* thread_pool,
    bool transpose_output, MatrixMap* output) {
  const int num_threads = thread_pool->num_threads;
  int KR, NR, KL, JB, IB;
  ComputeBlockSizes(left, right, transpose_left, num_threads, &KR, &NR, &KL,
                    &JB, &IB);
  // Slice the left matrix
  std::vector<std::vector<SparseSlice<TL>*>> left_slices;
  std::unique_ptr<BlockingCounter> sparse_slice_counter =
      CreateSparseSlices(ConstMatrixMapL(left.data(), left.dimensions()),
                         transpose_left, M, K, KL, &left_slices, thread_pool);
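  // The slicing may still be running on the pool at this point; the returned
  // counter is waited on below, just before the first batch of output tasks
  // executes, so sparse slicing of the left matrix can overlap with packing
  // the first right-hand block.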
  const int num_left_slices = left_slices.size();

  const int right_dim0 = right.dimension(0);
  const int right_dim1 = right.dimension(1);
  // Allocate buffer for storing slices of right matrix.
  // Note buffer needs enough space to hold at most a KR * NR matrix since that
  // is the block size per iteration.
  const int buffer_num_rows =
      std::min(KR, right_dim0) * ((std::min(NR, right_dim1) + N - 1) / N);
  MatrixR buffer(buffer_num_rows, N);
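  // Illustrative sizing (not computed here): with KR = 1024, NR = 1024 and
  // N = 128, the buffer holds 1024 * (1024 / 128) = 8192 rows of 128 elements,
  // i.e. the same KR * NR entries repacked into N-column slices.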
  std::vector<ConstMatrixMapR*> right_slices;

  std::vector<SparseSlice<TL>*> block_left_slices;
  std::vector<std::function<void(void)>> tasks;
  // Number of blocks based on block sizes of KR * NR.
  const int num_k_blocks = (right_dim0 + KR - 1) / KR;
  const int num_n_blocks = (right_dim1 + NR - 1) / NR;
  std::unique_ptr<BlockingCounter> dense_slice_counter;

  for (int nb = 0; nb < num_n_blocks; ++nb) {
    const int right_num_cols =
        std::min(NR, static_cast<int>(right_dim1 - NR * nb));
    for (int kb = 0; kb < num_k_blocks; ++kb) {
      const int right_num_rows =
          std::min(KR, static_cast<int>(right_dim0 - KR * kb));
      dense_slice_counter = CreateDenseSlices(
          right, kb * KR, right_num_rows, nb * NR, right_num_cols, thread_pool,
          &buffer, &right_slices);
      const int num_right_slices = right_slices.size();
      tasks.reserve(num_left_slices * num_right_slices);
      for (int j_outer = 0; j_outer < num_right_slices; j_outer += JB) {
        for (int i_outer = 0; i_outer < num_left_slices; i_outer += IB) {
          for (int j_inner = j_outer;
               j_inner < std::min(num_right_slices, j_outer + JB); ++j_inner) {
            const int num_cols = std::min(N, right_num_cols - N * j_inner);
            for (int i_inner = i_outer;
                 i_inner < std::min(num_left_slices, i_outer + IB); ++i_inner) {
              block_left_slices.clear();
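              // Pick out the L_mk blocks of this l_i that overlap rows
              // [kb * KR, (kb + 1) * KR) of the right matrix; because KL
              // divides KR these are always whole blocks.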
              int begin = kb * KR / KL;
              int end = std::min<int>((kb + 1) * KR / KL,
                                      (right.dimension(0) + KL - 1) / KL);
              DCHECK_LT(begin, end);
              block_left_slices.insert(block_left_slices.begin(),
                                       left_slices[i_inner].begin() + begin,
                                       left_slices[i_inner].begin() + end);
              tasks.push_back(std::bind(
                  &ComputeOutputBlock, block_left_slices,
                  std::ref(*right_slices[j_inner]), num_cols, M * i_inner,
                  N * j_inner + nb * NR, kb == 0, transpose_output, output));
            }
          }
        }
      }
      if (sparse_slice_counter) {
        sparse_slice_counter->Wait();
        sparse_slice_counter.reset(nullptr);
      }
      if (dense_slice_counter) {
        dense_slice_counter->Wait();
        dense_slice_counter.reset(nullptr);
      }
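      // All tasks for this (kb, nb) block must finish before the next
      // iteration, since the next block reuses `buffer` and `right_slices`.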
      BlockingCounter bc(tasks.size());
      for (const auto& t : tasks) {
        thread_pool->workers->Schedule([&bc, &t]() {
          t();
          bc.DecrementCount();
        });
      }
      bc.Wait();
      tasks.clear();
      for (auto& temp : right_slices) {
        delete temp;
      }
      right_slices.clear();
    }
  }
  for (auto& left_slice : left_slices) {
    for (auto& temp : left_slice) {
      delete temp;
    }
    left_slice.clear();
  }
}

#define REGISTER_SPARSE_MATMUL(TA, TB)                   \
  REGISTER_KERNEL_BUILDER(Name("SparseMatMul")           \
                              .Device(DEVICE_CPU)        \
                              .TypeConstraint<TA>("Ta")  \
                              .TypeConstraint<TB>("Tb"), \
                          SparseMatMulOp<TA, TB, SparseMatMul>);
#ifdef TENSORFLOW_USE_LIBXSMM
#define REGISTER_SPARSE_MATMUL_LIBXSMM(TA, TB)           \
  REGISTER_KERNEL_BUILDER(Name("SparseMatMul")           \
                              .Device(DEVICE_CPU)        \
                              .TypeConstraint<TA>("Ta")  \
                              .TypeConstraint<TB>("Tb"), \
                          SparseMatMulOp<TA, TB, LibxsmmSparseMatMul>);
#endif

REGISTER_SPARSE_MATMUL(float, bfloat16);
REGISTER_SPARSE_MATMUL(bfloat16, float);

#ifdef TENSORFLOW_USE_LIBXSMM
REGISTER_SPARSE_MATMUL_LIBXSMM(bfloat16, bfloat16);
REGISTER_SPARSE_MATMUL_LIBXSMM(float, float);
#else
REGISTER_SPARSE_MATMUL(bfloat16, bfloat16);
REGISTER_SPARSE_MATMUL(float, float);
#endif
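// Note: with TENSORFLOW_USE_LIBXSMM defined, only the homogeneous
// (float, float) and (bfloat16, bfloat16) kernels route through
// LibxsmmSparseMatMul; the mixed-precision registrations above always use the
// default SparseMatMul implementation.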

#ifdef TENSORFLOW_USE_LIBXSMM
#undef REGISTER_SPARSE_MATMUL_LIBXSMM
#endif
#undef REGISTER_SPARSE_MATMUL

}  // end namespace tensorflow