1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_THREAD_POOL_H
11 #define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_THREAD_POOL_H
12 
13 // evaluator for thread pool device
14 #ifdef EIGEN_USE_THREADS
15 
16 namespace Eigen {
17 
#ifdef EIGEN_USE_SIMPLE_THREAD_POOL
namespace internal {

// Argument bundle for an lhs packing task of the simple-thread-pool
// contraction path (only compiled when EIGEN_USE_SIMPLE_THREAD_POOL is
// defined).
// NOTE(review): the code that fills and consumes this struct is outside this
// chunk; the per-field notes below are inferred from names — confirm against
// the packing code.
// `lhs` is a reference member, so the mapper must outlive every instance, and
// the const/reference members make the struct non-assignable.
template<typename LhsScalar, typename LhsMapper, typename Index>
struct packLhsArg {
  LhsScalar* blockA;  // presumably the destination buffer for the packed block
  const LhsMapper& lhs;  // lhs input mapper (held by reference)
  const Index m_start;  // presumably first row of the block to pack
  const Index k_start;  // presumably first depth (k) index of the block
  const Index mc;  // presumably block extent along m
  const Index kc;  // presumably block extent along k
};

// Argument bundle for a task that packs an rhs block and runs the kernels
// against it (simple-thread-pool path only).
// NOTE(review): consumers are outside this chunk; the notes below are inferred
// from field names — confirm against the task implementation.
// `rhs` and `output` are reference members and must outlive every instance.
template<typename LhsScalar, typename RhsScalar, typename RhsMapper, typename OutputMapper, typename Index>
struct packRhsAndKernelArg {
  const MaxSizeVector<LhsScalar*>* blockAs;  // presumably the packed lhs blocks
  RhsScalar* blockB;  // presumably the destination buffer for the packed rhs
  const RhsMapper& rhs;
  OutputMapper& output;
  // Presumably the full problem sizes.
  const Index m;
  const Index k;
  const Index n;
  // Presumably the blocking sizes along m, k and n.
  const Index mc;
  const Index kc;
  const Index nc;
  const Index num_threads;
  const Index num_blockAs;
  const Index max_m;
  // Presumably the indices of the block this task works on.
  const Index k_block_idx;
  const Index m_block_idx;
  const Index n_block_idx;
  const Index m_blocks;
  const Index n_blocks;
  // Presumably used to order kernels and to wait for lhs packing; the
  // notification protocol lives outside this chunk.
  MaxSizeVector<Notification*>* kernel_notifications;
  const MaxSizeVector<Notification*>* lhs_notifications;
  const bool need_to_pack;
};

}  // end namespace internal
#endif  // EIGEN_USE_SIMPLE_THREAD_POOL
58 
59 template<typename Indices, typename LeftArgType, typename RightArgType>
60 struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, ThreadPoolDevice> :
61     public TensorContractionEvaluatorBase<TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, ThreadPoolDevice> > {
62 
63   typedef ThreadPoolDevice Device;
64 
65   typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> Self;
66   typedef TensorContractionEvaluatorBase<Self> Base;
67 
68   typedef TensorContractionOp<Indices, LeftArgType, RightArgType> XprType;
69   typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar;
70   typedef typename XprType::Index Index;
71   typedef typename XprType::CoeffReturnType CoeffReturnType;
72   typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
73 
74   enum {
75     Layout = TensorEvaluator<LeftArgType, Device>::Layout,
76   };
77 
78   // Most of the code is assuming that both input tensors are ColMajor. If the
79   // inputs are RowMajor, we will "cheat" by swapping the LHS and RHS:
80   // If we want to compute A * B = C, where A is LHS and B is RHS, the code
81   // will pretend B is LHS and A is RHS.
82   typedef typename internal::conditional<
83     static_cast<int>(Layout) == static_cast<int>(ColMajor), LeftArgType, RightArgType>::type EvalLeftArgType;
84   typedef typename internal::conditional<
85     static_cast<int>(Layout) == static_cast<int>(ColMajor), RightArgType, LeftArgType>::type EvalRightArgType;
86 
87   static const int LDims =
88       internal::array_size<typename TensorEvaluator<EvalLeftArgType, Device>::Dimensions>::value;
89   static const int RDims =
90       internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>::value;
91   static const int ContractDims = internal::array_size<Indices>::value;
92 
93   typedef array<Index, LDims> left_dim_mapper_t;
94   typedef array<Index, RDims> right_dim_mapper_t;
95 
96   typedef array<Index, ContractDims> contract_t;
97   typedef array<Index, LDims - ContractDims> left_nocontract_t;
98   typedef array<Index, RDims - ContractDims> right_nocontract_t;
99 
100   static const int NumDims = LDims + RDims - 2 * ContractDims;
101 
102   typedef DSizes<Index, NumDims> Dimensions;
103 
104   // typedefs needed in evalTo
105   typedef typename internal::remove_const<typename EvalLeftArgType::Scalar>::type LhsScalar;
106   typedef typename internal::remove_const<typename EvalRightArgType::Scalar>::type RhsScalar;
107   typedef typename internal::gebp_traits<LhsScalar, RhsScalar> Traits;
108 
109   typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator;
110   typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;
111 
112   TensorEvaluator(const XprType& op, const Device& device) :
113       Base(op, device) {}
114 
115 #ifndef EIGEN_USE_SIMPLE_THREAD_POOL
116   template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous,
117             bool rhs_inner_dim_reordered, int Alignment>
118   void evalProduct(Scalar* buffer) const {
119     typedef
120         typename internal::remove_const<typename EvalLeftArgType::Scalar>::type
121             LhsScalar;
122     typedef
123         typename internal::remove_const<typename EvalRightArgType::Scalar>::type
124             RhsScalar;
125     typedef typename internal::gebp_traits<LhsScalar, RhsScalar> Traits;
126     typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator;
127     typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;
128     typedef internal::TensorContractionInputMapper<
129         LhsScalar, Index, internal::Lhs, LeftEvaluator, left_nocontract_t,
130         contract_t, internal::packet_traits<LhsScalar>::size,
131         lhs_inner_dim_contiguous, false, Unaligned>
132         LhsMapper;
133     typedef internal::TensorContractionInputMapper<
134         RhsScalar, Index, internal::Rhs, RightEvaluator, right_nocontract_t,
135         contract_t, internal::packet_traits<RhsScalar>::size,
136         rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Unaligned>
137         RhsMapper;
138     typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;
139     typedef internal::gemm_pack_lhs<LhsScalar, Index,
140                                     typename LhsMapper::SubMapper, Traits::mr,
141                                     Traits::LhsProgress, ColMajor>
142         LhsPacker;
143     typedef internal::gemm_pack_rhs<
144         RhsScalar, Index, typename RhsMapper::SubMapper, Traits::nr, ColMajor>
145         RhsPacker;
146     typedef internal::gebp_kernel<LhsScalar, RhsScalar, Index, OutputMapper,
147                                   Traits::mr, Traits::nr, false, false>
148         GebpKernel;
149 
150     const Index m = this->m_i_size;
151     const Index n = this->m_j_size;
152     const Index k = this->m_k_size;
153     if (m == 0 || n == 0 || k == 0) return;
154 
155     // Compute a set of algorithm parameters:
156     // - kernel block sizes (bm, bn, bk)
157     // - task grain sizes (number of kernels executed per task: gm, gn)
158     // - number of threads
159     // - sharding by row/column
160     // - parallel packing or first lhs then rhs
161     // and some derived parameters:
162     // - number of tasks (nm, nn, nk)
163     // - number of kernels (nm0, nn0)
164     // Unfortunately, all these parameters are tightly interdependent.
165     // So in some cases we first compute approximate values, then compute other
166     // values based on these approximations and then refine the approximations.
167 
168     // There are lots of heuristics here. There is some reasoning behind them,
169     // but ultimately they are just tuned on contraction benchmarks for
170     // different input configurations, thread counts and instruction sets.
171     // So feel free to question any of them.
172 
173     // Compute whether we want to shard by row or by column.
174     // This is a first approximation, it will be refined later. Since we don't
175     // know number of threads yet we use 2, because what's we are most
176     // interested in at this point is whether it makes sense to use
177     // parallelization at all or not.
178     bool shard_by_col = shardByCol(m, n, 2);
179 
180     // First approximation of kernel blocking sizes.
181     // Again, we don't know number of threads yet, so we use 2.
182     Index bm, bn, bk;
183     if (shard_by_col) {
184       internal::TensorContractionBlocking<LhsMapper, RhsMapper, Index,
185                                           internal::ShardByCol>
186           blocking(k, m, n, 2);
187       bm = blocking.mc();
188       bn = blocking.nc();
189       bk = blocking.kc();
190     } else {
191       internal::TensorContractionBlocking<LhsMapper, RhsMapper, Index,
192                                           internal::ShardByRow>
193           blocking(k, m, n, 2);
194       bm = blocking.mc();
195       bn = blocking.nc();
196       bk = blocking.kc();
197     }
198 
199     // Compute optimal number of threads.
200     // Note: we use bk instead of k here because we are interested in amount of
201     // _parallelizable_ computations, and computations are not parallelizable
202     // across k dimension.
203     const TensorOpCost cost =
204         contractionCost(m, n, bm, bn, bk, shard_by_col, false);
205     int num_threads = TensorCostModel<ThreadPoolDevice>::numThreads(
206         static_cast<double>(n) * m, cost, this->m_device.numThreads());
207 
208     // TODO(dvyukov): this is a stop-gap to prevent regressions while the cost
209     // model is not tuned. Remove this when the cost model is tuned.
210     if (n == 1) num_threads = 1;
211 
212     if (num_threads == 1) {
213       // The single-threaded algorithm should be faster in this case.
214       if (n == 1)
215         this->template evalGemv<lhs_inner_dim_contiguous,
216                                 rhs_inner_dim_contiguous,
217                                 rhs_inner_dim_reordered, Alignment>(buffer);
218       else
219         this->template evalGemm<lhs_inner_dim_contiguous,
220                                 rhs_inner_dim_contiguous,
221                                 rhs_inner_dim_reordered, Alignment>(buffer);
222       return;
223     }
224 
225     // Now that we know number of threads, recalculate sharding and blocking.
226     shard_by_col = shardByCol(m, n, num_threads);
227     if (shard_by_col) {
228       internal::TensorContractionBlocking<LhsMapper, RhsMapper, Index,
229                                           internal::ShardByCol>
230           blocking(k, m, n, num_threads);
231       bm = blocking.mc();
232       bn = blocking.nc();
233       bk = blocking.kc();
234     } else {
235       internal::TensorContractionBlocking<LhsMapper, RhsMapper, Index,
236                                           internal::ShardByRow>
237           blocking(k, m, n, num_threads);
238       bm = blocking.mc();
239       bn = blocking.nc();
240       bk = blocking.kc();
241     }
242 
243     // Number of kernels for each dimension.
244     Index nm0 = divup(m, bm);
245     Index nn0 = divup(n, bn);
246     Index nk = divup(k, bk);
247 
248     // Calculate task grain size (number of kernels executed per task).
249     // This task size coarsening serves two purposes:
250     // 1. It reduces per-task overheads including synchronization overheads.
251     // 2. It allows to use caches better (reuse the same packed rhs in several
252     // consecutive kernels).
253     Index gm = 1;
254     Index gn = 1;
255     // If we are sharding by column, then we prefer to reduce rows first.
256     if (shard_by_col) {
257       gm = coarsenM(m, n, bm, bn, bk, gn, num_threads, shard_by_col);
258       gn = coarsenN(m, n, bm, bn, bk, gm, num_threads, shard_by_col);
259     } else {
260       gn = coarsenN(m, n, bm, bn, bk, gm, num_threads, shard_by_col);
261       gm = coarsenM(m, n, bm, bn, bk, gn, num_threads, shard_by_col);
262     }
263     // Number of tasks in each dimension.
264     Index nm = divup(nm0, gm);
265     Index nn = divup(nn0, gn);
266 
267     // Last by not least, decide whether we want to issue both lhs and rhs
268     // packing in parallel; or issue lhs packing first, and then issue rhs
269     // packing when lhs packing completes (for !shard_by_col lhs and rhs are
270     // swapped). Parallel packing allows more parallelism (for both packing and
271     // kernels), while sequential packing provides better locality (once
272     // a thread finishes rhs packing it proceed to kernels with that rhs).
273     // First, we are interested in parallel packing if there are few tasks.
274     bool parallel_pack = num_threads >= nm * nn;
275     // Also do parallel packing if all data fits into L2$.
276     if (m * bk * Index(sizeof(LhsScalar)) + n * bk * Index(sizeof(RhsScalar)) <=
277         l2CacheSize() * num_threads)
278       parallel_pack = true;
279     // But don't do it if we will use each rhs only once. Locality seems to be
280     // more important in this case.
281     if ((shard_by_col ? nm : nn) == 1) parallel_pack = false;
282 
283     LhsMapper lhs(this->m_leftImpl, this->m_left_nocontract_strides,
284                   this->m_i_strides, this->m_left_contracting_strides,
285                   this->m_k_strides);
286 
287     RhsMapper rhs(this->m_rightImpl, this->m_right_nocontract_strides,
288                   this->m_j_strides, this->m_right_contracting_strides,
289                   this->m_k_strides);
290 
291     Context<LhsPacker, RhsPacker, GebpKernel, LhsMapper, RhsMapper,
292             OutputMapper>(this->m_device, num_threads, lhs, rhs, buffer, m, n,
293                           k, bm, bn, bk, nm, nn, nk, gm, gn, nm0, nn0,
294                           shard_by_col, parallel_pack)
295         .run();
296   }
297 
  // Context coordinates a single parallel gemm operation.
  //
  // It allocates the per-slice synchronization counters and the packed
  // lhs/rhs buffers (both released in the destructor), schedules packing and
  // kernel tasks on `device_` via enqueueNoNotification(), and run() blocks
  // the caller until the whole contraction has completed. See the
  // "Parallelization strategy" comment among the private members for the
  // dependency graph between tasks.
  template <typename LhsPacker, typename RhsPacker, typename GebpKernel,
            typename LhsMapper, typename RhsMapper, typename OutputMapper>
  class Context {
   public:
    // Parameters:
    //  - device: device whose enqueueNoNotification() runs the tasks.
    //  - num_threads: stored in num_threads_; not otherwise read in this
    //    class.
    //  - lhs, rhs: input mappers. They are held by reference, so they must
    //    outlive the Context.
    //  - buffer: output storage. pack_rhs zeroes it during the first k slice,
    //    and output_ is constructed from (buffer, tm).
    //  - tm, tn, tk: matrix sizes along m, n and k.
    //  - bm, bn, bk: kernel block sizes.
    //  - nm, nn: number of tasks along m and n; nk: number of k slices.
    //  - gm, gn: task grain sizes (kernel blocks per task along m and n).
    //  - nm0, nn0: number of kernel blocks along m and n.
    //  - shard_by_col, parallel_pack: scheduling strategy flags.
    Context(const Device& device, int num_threads, LhsMapper& lhs,
            RhsMapper& rhs, Scalar* buffer, Index tm, Index tn, Index tk, Index bm,
            Index bn, Index bk, Index nm, Index nn, Index nk, Index gm,
            Index gn, Index nm0, Index nn0, bool shard_by_col,
            bool parallel_pack)
        : device_(device),
          lhs_(lhs),
          rhs_(rhs),
          buffer_(buffer),
          output_(buffer, tm),
          num_threads_(num_threads),
          shard_by_col_(shard_by_col),
          parallel_pack_(parallel_pack),
          m_(tm),
          n_(tn),
          k_(tk),
          bm_(bm),
          bn_(bn),
          bk_(bk),
          nm_(nm),
          nn_(nn),
          nk_(nk),
          gm_(gm),
          gn_(gn),
          nm0_(nm0),
          nn0_(nn0)
  {
      // Initialize the counters of all P rolling slots.
      for (Index x = 0; x < P; x++) {
        // Number of notifications slice x must receive before we switch to
        // it. In steady state this is the number of packing tasks that signal
        // the switch (nm_ + nn_ with parallel packing; otherwise only the
        // tasks packed second: nn_ when sharding by column, nm_ when sharding
        // by row) plus nm_ * nn_ kernels of slice x - 2. Slice 0 receives a
        // single notification from run(), and slices 1 .. P - 2 receive only
        // packing notifications because there are no preceding kernels to
        // signal them.
        state_switch_[x] =
            x == 0
                ? 1
                : (parallel_pack_ ? nn_ + nm_ : (shard_by_col_ ? nn_ : nm_)) +
                      (x == P - 1 ? nm_ * nn_ : 0);
        // With sequential packing: number of first-phase packing tasks (lhs
        // when sharding by column, rhs otherwise) that must finish before the
        // second phase is enqueued. Unused with parallel packing.
        state_packing_ready_[x] =
            parallel_pack_ ? 0 : (shard_by_col_ ? nm_ : nn_);
        state_kernel_[x] = new std::atomic<uint8_t>*[nm_];
        for (Index m = 0; m < nm_; m++) {
          state_kernel_[x][m] = new std::atomic<uint8_t>[nn_];
          // Kernels generally receive 3 notifications (previous kernel + 2
          // packing), but the first slice won't get notifications from previous
          // kernels. Without parallel packing only the second-phase packing
          // task signals the kernel, so there is 1 packing notification.
          for (Index n = 0; n < nn_; n++)
            state_kernel_[x][m][n].store(
                (x == 0 ? 0 : 1) + (parallel_pack_ ? 2 : 1),
                std::memory_order_relaxed);
        }
      }

      // Allocate memory for packed rhs/lhs matrices.
      // Packed data is needed for at most P - 1 slices at a time (slots are
      // indexed by k % (P - 1)). Each block size is rounded up to a multiple
      // of EIGEN_MAX_ALIGN_BYTES (at least 1), so given an aligned base
      // pointer from aligned_malloc every block starts aligned.
      size_t align = numext::maxi(EIGEN_MAX_ALIGN_BYTES, 1);
      size_t lhs_size =
          divup<size_t>(bm_ * bk_ * sizeof(LhsScalar), align) * align;
      size_t rhs_size =
          divup<size_t>(bn_ * bk_ * sizeof(RhsScalar), align) * align;
      packed_mem_ = static_cast<char*>(internal::aligned_malloc(
          (nm0_ * lhs_size + nn0_ * rhs_size) * std::min<size_t>(nk_, P - 1)));
      char* mem = static_cast<char*>(packed_mem_);
      // Carve the single allocation into per-slot, per-block pointers.
      for (Index x = 0; x < numext::mini<Index>(nk_, P - 1); x++) {
        packed_lhs_[x].resize(nm0_);
        for (Index m = 0; m < nm0_; m++) {
          packed_lhs_[x][m] = reinterpret_cast<LhsScalar*>(mem);
          mem += lhs_size;
        }
        packed_rhs_[x].resize(nn0_);
        for (Index n = 0; n < nn0_; n++) {
          packed_rhs_[x][n] = reinterpret_cast<RhsScalar*>(mem);
          mem += rhs_size;
        }
      }
    }

    // Frees the kernel counters and the packed-buffer allocation made in the
    // constructor.
    ~Context() {
      for (Index x = 0; x < P; x++) {
        for (Index m = 0; m < nm_; m++) delete[] state_kernel_[x][m];
        delete[] state_kernel_[x];
      }
      internal::aligned_free(packed_mem_);
    }

    // Runs the whole contraction and blocks the calling thread until
    // signal_switch() reports completion through done_.
    void run() {
      // Kick off packing of the first slice (slice 0's switch counter was
      // initialized to 1).
      signal_switch(0, 1);
      // Wait for overall completion.
      // TODO(dvyukov): this wait can lead to deadlock.
      // If nthreads contractions are concurrently submitted from worker
      // threads, this wait will block all worker threads and the system will
      // deadlock.
      done_.Wait();
    }

   private:
    Notification done_;
    const Device& device_;
    LhsMapper& lhs_;
    RhsMapper& rhs_;
    Scalar* const buffer_;
    OutputMapper output_;
    // Not read anywhere in this class.
    const int num_threads_;
    const bool shard_by_col_;
    const bool parallel_pack_;
    // Matrix sizes.
    const Index m_;
    const Index n_;
    const Index k_;
    // Block sizes.
    const Index bm_;
    const Index bn_;
    const Index bk_;
    // Number of tasks.
    const Index nm_;
    const Index nn_;
    const Index nk_;
    // Task grain sizes (number of kernels executed per task).
    const Index gm_;
    const Index gn_;
    // Number of blocks (this is different from nm_/nn_ because of task size
    // coarsening).
    const Index nm0_;
    const Index nn0_;

    // Parallelization strategy.
    //
    // Blocks related to the same k block can run in parallel because they write
    // to different output blocks. So we parallelize within k slices, this
    // gives us parallelism level of m x n. Before we can start any kernels
    // related to k-th slice, we need to issue m lhs packing tasks and n rhs
    // packing tasks.
    //
    // However, there is a bottleneck when we are finishing kernels for k-th
    // slice (at the very end there is only 1 runnable kernel). To mitigate this
    // bottleneck we allow kernels from k-th and k+1-th slices to run in
    // parallel. Note that (m, n, k) and (m, n, k+1) kernels write to the same
    // output block, so they must not run in parallel.
    //
    // This gives us the following dependency graph.
    // On each k slice we have m x n kernel tasks, m lhs packing tasks and n rhs
    // packing tasks.
    // Kernel (m, n, k) can start when:
    //  - kernel (m, n, k-1) has finished
    //  - lhs packing (m, k) has finished
    //  - rhs packing (n, k) has finished
    // Lhs/rhs packing can start when:
    //  - all k-1 packing has finished (artificially imposed to limit amount of
    //  parallel packing)
    //
    // On top of that we limit runnable tasks to two consecutive k slices.
    // This is done to limit amount of memory we need for packed lhs/rhs
    // (for each k slice we need m*bk + n*bk memory in packed_lhs_/packed_rhs_).
    //
    // state_switch_ tracks when we are ready to switch to the next k slice.
    // state_kernel_[k % P][m][n] tracks when we are ready to kick off kernel
    // (m, n, k).
    // These variables are rolling over 3 consecutive k slices: first two we are
    // actively executing + one to track completion of kernels in the second
    // slice.
    static const Index P = 3;
    // Single allocation backing every packed_lhs_/packed_rhs_ block.
    void* packed_mem_;
    // packed_lhs_[k % (P - 1)][i] is the packed lhs of kernel block i for
    // slice k; packed_rhs_ is laid out the same way.
    std::vector<LhsScalar*> packed_lhs_[P - 1];
    std::vector<RhsScalar*> packed_rhs_[P - 1];
    std::atomic<uint8_t>** state_kernel_[P];
    // state_switch_ is frequently modified by worker threads, while other
    // fields are read-only after constructor. Let's move it to a separate cache
    // line to reduce cache-coherency traffic.
    char pad_[128];
    std::atomic<Index> state_packing_ready_[P];
    std::atomic<Index> state_switch_[P];

    // Packs the lhs blocks of task m (blocks m * gm_ .. m * gm_ + gm(m) - 1)
    // for slice k, then notifies dependents. If lhs is the first packing
    // phase (sequential packing, sharding by column), it only counts down the
    // packing barrier. Otherwise it signals the switch to slice k + 1 and
    // kernels (m, *, k). The loop runs n downwards so that the kernel that may
    // execute inline (n == 0) is signalled last.
    void pack_lhs(Index m, Index k) {
      const Index mend = m * gm_ + gm(m);
      for (Index m1 = m * gm_; m1 < mend; m1++)
        LhsPacker()(packed_lhs_[k % (P - 1)][m1],
                    lhs_.getSubMapper(m1 * bm_, k * bk_), bk(k), bm(m1));

      if (!parallel_pack_ && shard_by_col_) {
        signal_packing(k);
      } else {
        signal_switch(k + 1);
        for (Index n = nn_ - 1; n >= 0; n--) signal_kernel(m, n, k, n == 0);
      }
    }

    // Packs the rhs blocks of task n for slice k (zeroing the corresponding
    // output columns on the first slice), then notifies dependents. This
    // mirrors pack_lhs: rhs is the first packing phase only with sequential
    // packing while sharding by row.
    void pack_rhs(Index n, Index k) {
      const Index nend = n * gn_ + gn(n);
      for (Index n1 = n * gn_; n1 < nend; n1++) {
        if (k == 0) {
          // Zero the output memory in parallel.
          // On 10000x2x10000 mm zeroing can easily take half of time.
          // Zero (bn x m) row. Safe to do here because all kernels that will
          // write to this memory depend on completion of this task.
          // Note: don't call device_.memset() here. device_.memset() blocks on
          // thread pool worker thread, which can lead to underutilization and
          // deadlocks.
          memset(buffer_ + n1 * bn_ * m_, 0, bn(n1) * m_ * sizeof(Scalar));
        }
        RhsPacker()(packed_rhs_[k % (P - 1)][n1],
                    rhs_.getSubMapper(k * bk_, n1 * bn_), bk(k), bn(n1));
      }

      if (parallel_pack_ || shard_by_col_) {
        signal_switch(k + 1);
        for (Index m = nm_ - 1; m >= 0; m--) signal_kernel(m, n, k, m == 0);
      } else {
        signal_packing(k);
      }
    }

    // Runs all gebp kernels of task (m, n) for slice k, then signals kernel
    // (m, n, k + 1) and the switch to slice k + 2.
    void kernel(Index m, Index n, Index k) {
      // Note: order of iteration matters here. When sharding by column,
      // iteration over m is innermost because we want to reuse the same packed
      // rhs in consecutive tasks (rhs fits into L2$ while lhs only into L3$).
      // When sharding by row the roles are swapped and n is innermost.
      const Index nend = n * gn_ + gn(n);
      const Index mend = m * gm_ + gm(m);
      if (shard_by_col_) {
        for (Index n1 = n * gn_; n1 < nend; n1++) {
          for (Index m1 = m * gm_; m1 < mend; m1++)
            GebpKernel()(output_.getSubMapper(m1 * bm_, n1 * bn_),
                         packed_lhs_[k % (P - 1)][m1],
                         packed_rhs_[k % (P - 1)][n1], bm(m1), bk(k), bn(n1),
                         Scalar(1), -1, -1, 0, 0);
        }
      } else {
        for (Index m1 = m * gm_; m1 < mend; m1++)
          for (Index n1 = n * gn_; n1 < nend; n1++) {
            GebpKernel()(output_.getSubMapper(m1 * bm_, n1 * bn_),
                         packed_lhs_[k % (P - 1)][m1],
                         packed_rhs_[k % (P - 1)][n1], bm(m1), bk(k), bn(n1),
                         Scalar(1), -1, -1, 0, 0);
          }
      }
      signal_kernel(m, n, k + 1, false);
      signal_switch(k + 2);
    }

    // Sequential packing only: counts down the first-phase packing barrier of
    // slice k. The last first-phase task resets the counter and enqueues the
    // second phase (rhs when sharding by column, lhs otherwise).
    void signal_packing(Index k) {
      eigen_assert(!parallel_pack_);
      Index s = state_packing_ready_[k % P].fetch_sub(1);
      eigen_assert(s > 0);
      if (s != 1) return;
      state_packing_ready_[k % P] = shard_by_col_ ? nm_ : nn_;
      enqueue_packing(k, shard_by_col_);
    }

    // Delivers one dependency notification to kernel (m, n, k). When it was
    // the last outstanding one, the counter is reset for the slice that next
    // reuses slot k % P (1 previous kernel + 2 or 1 packing notifications)
    // and the kernel is run inline (sync) or enqueued. If the counter already
    // reads 1, this call is the final notification, so the atomic decrement
    // is skipped.
    void signal_kernel(Index m, Index n, Index k, bool sync) {
      std::atomic<uint8_t>* state = &state_kernel_[k % P][m][n];
      Index s = state->load();
      eigen_assert(s > 0);
      if (s != 1 && state->fetch_sub(1) != 1) return;
      state->store(parallel_pack_ ? 3 : 2, std::memory_order_relaxed);
      if (sync)
        kernel(m, n, k);
      else
        device_.enqueueNoNotification([=]() { kernel(m, n, k); });
    }

    // Delivers v notifications to the switch counter of slice k. When it
    // reaches zero the counter is reset to its steady-state value and then:
    // packing for slice k is issued (k < nk_); or the packing notifications
    // of the non-existent slice nk_ are faked (k == nk_); or overall
    // completion is signalled to run() (k > nk_).
    void signal_switch(Index k, Index v = 1) {
      Index s = state_switch_[k % P].fetch_sub(v);
      eigen_assert(s >= v);
      if (s != v) return;

      // Ready to switch to the next k slice.
      // Reset counter for the next iteration.
      state_switch_[k % P] =
          (parallel_pack_ ? nm_ + nn_ : (shard_by_col_ ? nn_ : nm_)) +
          nm_ * nn_;
      if (k < nk_) {
        // Issue lhs/rhs packing. Their completion will in turn kick off
        // kernels.
        if (parallel_pack_) {
          enqueue_packing(k, !shard_by_col_);
          enqueue_packing(k, shard_by_col_);
        } else if (shard_by_col_) {
          enqueue_packing(k, false);
        } else {
          enqueue_packing(k, true);
        }

        // Termination handling.
        // Because kernel completion signals k + 2 switch, we need to finish nk
        // + 2 slices without issuing any tasks on nk + 1 slice. So here we
        // pretend that all nk + 1 packing tasks just finish instantly; so that
        // nk + 2 switch only waits for completion of nk kernels.
      } else if (k == nk_) {
        signal_switch(k + 1,
                      parallel_pack_ ? nm_ + nn_ : (shard_by_col_ ? nn_ : nm_));
      } else {
        done_.Notify();
      }
    }

    // Enqueue all rhs/lhs packing for k-th slice.
    void enqueue_packing(Index k, bool rhs) {
      enqueue_packing_helper(0, rhs ? nn_ : nm_, k, rhs);
    }

    // Splits the task range [start, end) in halves, enqueueing each half,
    // until a single task remains; that task is then run on the current
    // thread. Presumably this spreads task submission across the workers
    // instead of enqueueing every task from one thread.
    void enqueue_packing_helper(Index start, Index end, Index k, bool rhs) {
      if (end - start == 1) {
        if (rhs)
          pack_rhs(start, k);
        else
          pack_lhs(start, k);
      } else {
        Index mid = (start + end) / 2;
        device_.enqueueNoNotification(
            [=]() { enqueue_packing_helper(mid, end, k, rhs); });
        device_.enqueueNoNotification(
            [=]() { enqueue_packing_helper(start, mid, k, rhs); });
      }
    }

    // Block sizes with accounting for potentially incomplete last block.
    Index bm(Index m) const { return m + 1 < nm0_ ? bm_ : m_ + bm_ - bm_ * nm0_; }
    Index bn(Index n) const { return n + 1 < nn0_ ? bn_ : n_ + bn_ - bn_ * nn0_; }
    Index bk(Index k) const { return k + 1 < nk_ ? bk_ : k_ + bk_ - bk_ * nk_; }
    // Task grain sizes accounting for potentially incomplete last task.
    Index gm(Index m) const { return m + 1 < nm_ ? gm_ : nm0_ + gm_ - gm_ * nm_; }
    Index gn(Index n) const { return n + 1 < nn_ ? gn_ : nn0_ + gn_ - gn_ * nn_; }

    // Non-copyable: enqueued tasks capture `this`.
    Context(const Context&) = delete;
    void operator=(const Context&) = delete;
  };
626 
627   // Decide whether we want to shard m x n contraction by columns or by rows.
628   static bool shardByCol(Index m, Index n, Index num_threads) {
629     // Note: we are comparing both n and m against Traits::nr, it is not
630     // a mistake. We are trying to figure out how both n and m will fit into
631     // the main sharding dimension.
632 
633     // Sharding by column is the default
634     // ... unless there is enough data for vectorization over rows
635     if (m / num_threads >= Traits::nr &&
636         // and not enough data for vectorization over columns
637         (n / num_threads < Traits::nr ||
638          // ... or barely enough data for vectorization over columns,
639          // but it is not evenly dividable across threads
640          (n / num_threads < 4 * Traits::nr &&
641           (n % (num_threads * Traits::nr)) != 0 &&
642           // ... and it is evenly dividable across threads for rows
643           ((m % (num_threads * Traits::nr)) == 0 ||
644            // .. or it is not evenly dividable for both dimensions but
645            // there is much more data over rows so that corner effects are
646            // mitigated.
647            (m / n >= 6)))))
648       return false;
649     // Wait, or if matrices are just substantially prolonged over the other
650     // dimension.
651     if (n / num_threads < 16 * Traits::nr && m > n * 32) return false;
652     return true;
653   }
654 
655   Index coarsenM(Index m, Index n, Index bm, Index bn, Index bk, Index gn,
656                  int num_threads, bool shard_by_col) const {
657     Index gm = 1;
658     Index gm1 = 1;
659     Index nm0 = divup(m, bm);
660     Index nm1 = nm0;
661     for (;;) {
662       // Find the next candidate for m grain size. It needs to result in
663       // different number of blocks. E.g. if we have 10 kernels, we want to try
664       // 5 and 10, but not 6, 7, 8 and 9.
665       while (gm1 <= nm0 && nm1 == divup(nm0, gm1)) gm1++;
666       if (gm1 > nm0) break;
667       // Check the candidate.
668       int res = checkGrain(m, n, bm, bn, bk, gm1, gn, gm, gn, num_threads,
669                            shard_by_col);
670       if (res < 0) break;
671       nm1 = divup(nm0, gm1);
672       if (res == 0) continue;
673       // Commit new grain size.
674       gm = gm1;
675     }
676     return gm;
677   }
678 
679   Index coarsenN(Index m, Index n, Index bm, Index bn, Index bk, Index gm,
680                  int num_threads, bool shard_by_col) const {
681     Index gn = 1;
682     Index gn1 = 1;
683     Index nn0 = divup(n, bn);
684     Index nn1 = nn0;
685     for (;;) {
686       while (gn1 <= nn0 && nn1 == divup(nn0, gn1)) gn1++;
687       if (gn1 > nn0) break;
688       int res = checkGrain(m, n, bm, bn, bk, gm, gn1, gm, gn, num_threads,
689                            shard_by_col);
690       if (res < 0) break;
691       nn1 = divup(nn0, gn1);
692       if (res == 0) continue;
693       gn = gn1;
694     }
695     return gn;
696   }
697 
698   // checkGrain checks whether grain (gm, gn) is suitable and is better than
699   // (oldgm, oldgn).
700   int checkGrain(Index m, Index n, Index bm, Index bn, Index bk, Index gm,
701                  Index gn, Index oldgm, Index oldgn, int num_threads,
702                  bool shard_by_col) const {
703     const TensorOpCost cost =
704         contractionCost(bm * gm, bn * gn, bm, bn, bk, shard_by_col, true);
705     double taskSize = TensorCostModel<ThreadPoolDevice>::taskSize(
706         static_cast<double>(bm) * gm * bn * gn, cost);
707     // If the task is too small, then we agree on it regardless of anything
708     // else. Otherwise synchronization overheads will dominate.
709     if (taskSize < 1) return 1;
710     // If it is too large, then we reject it and all larger tasks.
711     if (taskSize > 2) return -1;
712     // Now we are in presumably good task size range.
713     // The main deciding factor here is parallelism. Consider that we have 12
714     // kernels and 4 threads. Grains of 2, 3 and 4 all yield good task sizes.
715     // But 2/4 yield 6/3 tasks, which gives us parallelism of 0.75 (at most 3/4
716     // of cores will be busy). While grain size 3 gives us 4 tasks, which gives
717     // us parallelism of 1 (we can load all cores).
718     Index nm0 = divup(m, bm);
719     Index nn0 = divup(n, bn);
720     Index new_tasks = divup(nm0, gm) * divup(nn0, gn);
721     double new_parallelism = static_cast<double>(new_tasks) /
722                              (divup<int>(new_tasks, num_threads) * num_threads);
723     Index old_tasks = divup(nm0, oldgm) * divup(nn0, oldgn);
724     double old_parallelism = static_cast<double>(old_tasks) /
725                              (divup<int>(old_tasks, num_threads) * num_threads);
726     if (new_parallelism > old_parallelism || new_parallelism == 1) return 1;
727     return 0;
728   }
729 
730 #else  // EIGEN_USE_SIMPLE_THREAD_POOL
731 
732   template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
733   void evalProduct(Scalar* buffer) const {
734     if (this->m_j_size == 1) {
735       this->template evalGemv<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Alignment>(buffer);
736       return;
737     }
738 
739     evalGemm<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Alignment>(buffer);
740   }
741 
  // Blocked, multithreaded GEMM used when EIGEN_USE_SIMPLE_THREAD_POOL is
  // defined. The contraction is treated as an (m x k) by (k x n) product with
  // m = m_i_size, k = m_k_size, n = m_j_size. The result is written to
  // `buffer` as a column-major m x n matrix (OutputMapper(buffer, m)).
  //
  // Scheduling outline:
  //  * `buffer` is zeroed first. The kernels for successive k blocks all write
  //    into the same output region (presumably accumulating via GEBP).
  //  * k is split into k_blocks blocks of size kc. Within each k block, the
  //    m blocks are handled in groups of numBlockAs = min(num_threads,
  //    m_blocks).
  //  * Each m block of a group is packed by an asynchronous packLhs task into
  //    one of num_threads LHS slots, chosen as
  //    (k_block_idx * m_blocks + mt_block_idx) % num_threads.
  //  * For each n block a packRhsAndKernel task is enqueued. It packs the RHS
  //    panel only for the first group of a k block (need_to_pack), then runs
  //    GEBP for every m block of the current group.
  //  * Completion is tracked with Notification objects:
  //    lhs_notifications[slot] for LHS packing and
  //    kernel_notifications[slot * n_blocks + n_block_idx] for kernels. All
  //    kernels using a slot are waited on before that slot is repacked.
  //  * Before returning, all kernel notifications are waited on, and every
  //    notification and packing buffer is released.
  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
  void evalGemm(Scalar* buffer) const {
    // columns in left side, rows in right side
    const Index k = this->m_k_size;

    // rows in left side
    const Index m = this->m_i_size;

    // columns in right side
    const Index n = this->m_j_size;

    // zero out the result buffer (which must hold at least m * n Scalars)
    this->m_device.memset(buffer, 0, m * n * sizeof(Scalar));


    const int lhs_packet_size = internal::unpacket_traits<typename LeftEvaluator::PacketReturnType>::size;
    const int rhs_packet_size = internal::unpacket_traits<typename RightEvaluator::PacketReturnType>::size;

    typedef internal::TensorContractionInputMapper<LhsScalar, Index, internal::Lhs,
                                                   LeftEvaluator, left_nocontract_t,
                                                   contract_t, lhs_packet_size,
                                                   lhs_inner_dim_contiguous,
                                                   false, Unaligned> LhsMapper;

    typedef internal::TensorContractionInputMapper<RhsScalar, Index, internal::Rhs,
                                                   RightEvaluator, right_nocontract_t,
                                                   contract_t, rhs_packet_size,
                                                   rhs_inner_dim_contiguous,
                                                   rhs_inner_dim_reordered, Unaligned> RhsMapper;

    typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;

    // TODO: packing could be faster sometimes if we supported row major tensor mappers
    typedef internal::gemm_pack_lhs<LhsScalar, Index, typename LhsMapper::SubMapper, Traits::mr,
                                    Traits::LhsProgress, ColMajor> LhsPacker;
    typedef internal::gemm_pack_rhs<RhsScalar, Index, typename RhsMapper::SubMapper, Traits::nr, ColMajor> RhsPacker;

    // TODO: replace false, false with conjugate values?
    typedef internal::gebp_kernel<LhsScalar, RhsScalar, Index, OutputMapper,
                                  Traits::mr, Traits::nr, false, false> GebpKernel;

    typedef internal::packLhsArg<LhsScalar, LhsMapper, Index> packLArg;
    typedef internal::packRhsAndKernelArg<LhsScalar, RhsScalar, RhsMapper, OutputMapper, Index> packRKArg;

    // initialize data mappers
    LhsMapper lhs(this->m_leftImpl, this->m_left_nocontract_strides, this->m_i_strides,
                  this->m_left_contracting_strides, this->m_k_strides);

    RhsMapper rhs(this->m_rightImpl, this->m_right_nocontract_strides, this->m_j_strides,
                  this->m_right_contracting_strides, this->m_k_strides);

    OutputMapper output(buffer, m);

    // compute block sizes (which depend on number of threads)
    const Index num_threads = this->m_device.numThreads();
    internal::TensorContractionBlocking<LhsMapper, RhsMapper, Index, internal::ShardByCol> blocking(k, m, n, num_threads);
    Index mc = blocking.mc();
    Index nc = blocking.nc();
    Index kc = blocking.kc();
    eigen_assert(mc <= m);
    eigen_assert(nc <= n);
    eigen_assert(kc <= k);

    // Local ceiling-division helper; it is #undef'd at the end of this function.
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
    const Index k_blocks = CEIL_DIV(k, kc);
    const Index n_blocks = CEIL_DIV(n, nc);
    const Index m_blocks = CEIL_DIV(m, mc);
    // Element counts of one packed LHS panel (mc x kc) and one packed RHS panel (kc x nc).
    const Index sizeA = mc * kc;
    const Index sizeB = kc * nc;

    /*    cout << "m: " << m << " n: " << n << " k: " << k << endl;
    cout << "mc: " << mc << " nc: " << nc << " kc: " << kc << endl;
    cout << "m_blocks: " << m_blocks << " n_blocks: " << n_blocks << " k_blocks: " << k_blocks << endl;
    cout << "num threads: " << num_threads << endl;
    */

    // note: m_device.allocate should return 16 byte aligned pointers, but if blockA and blockB
    //       aren't 16 byte aligned, segfaults will happen due to SIMD instructions.
    // note: You could get away with allocating just a single blockA plus offsets and still meet
    //       the alignment requirements, under the assumption that
    //       (Traits::mr * sizeof(ResScalar)) % 16 == 0
    // Number of m blocks packed and processed together in one group.
    const Index numBlockAs = numext::mini(num_threads, m_blocks);
    // One LHS packing buffer (slot) of sizeA elements per thread.
    MaxSizeVector<LhsScalar *> blockAs(num_threads);
    for (int i = 0; i < num_threads; i++) {
      blockAs.push_back(static_cast<LhsScalar *>(this->m_device.allocate(sizeA * sizeof(LhsScalar))));
    }

    // To circumvent alignment issues, the memory for each n block's packed RHS panel is
    // allocated separately (one buffer of sizeB elements per n block).
    // TODO: is this too much memory to allocate? This simplifies coding a lot, but is wasteful.
    //       Other options: (1) reuse memory when a thread finishes. con: tricky
    //                      (2) allocate block B memory in each thread. con: overhead
    MaxSizeVector<RhsScalar *> blockBs(n_blocks);
    for (int i = 0; i < n_blocks; i++) {
      blockBs.push_back(static_cast<RhsScalar *>(this->m_device.allocate(sizeB * sizeof(RhsScalar))));
    }

    // lhs_notifications starts with all null Notifications (one per LHS slot)
    MaxSizeVector<Notification*> lhs_notifications(num_threads, nullptr);

    // One kernel notification per (LHS slot, n block) pair, indexed as
    // slot * n_blocks + n_block_idx.
    // The original author noted: "this should really be numBlockAs * n_blocks".
    // NOTE(review): slot ids are taken modulo num_threads below, so as written
    // indices up to num_threads * n_blocks - 1 can occur.
    const Index num_kernel_notifications = num_threads * n_blocks;
    MaxSizeVector<Notification*> kernel_notifications(num_kernel_notifications,
                                                    nullptr);

    // Outer loop over the contraction (k) dimension.
    for (Index k_block_idx = 0; k_block_idx < k_blocks; k_block_idx++) {
      const Index k_start = k_block_idx * kc;
      // make sure we don't overshoot right edge of left matrix
      const Index actual_kc = numext::mini(k_start + kc, k) - k_start;

      // Process m blocks in groups of at most numBlockAs.
      for (Index m_block_idx = 0; m_block_idx < m_blocks; m_block_idx += numBlockAs) {
        // Size of this group (the last group may be smaller).
        const Index num_blocks = numext::mini(m_blocks-m_block_idx, numBlockAs);

        // Enqueue one asynchronous LHS packing task per m block of the group.
        for (Index mt_block_idx = m_block_idx; mt_block_idx < m_block_idx+num_blocks; mt_block_idx++) {
          const Index m_start = mt_block_idx * mc;
          const Index actual_mc = numext::mini(m_start + mc, m) - m_start;
          eigen_assert(actual_mc > 0);

          // LHS slot for this m block, cycling through the num_threads buffers.
          Index blockAId = (k_block_idx * m_blocks + mt_block_idx) % num_threads;

          for (int i = 0; i < n_blocks; ++i) {
            Index notification_id = (blockAId * n_blocks + i);
            // Wait for any current kernels using this slot to complete
            // before using it.
            if (kernel_notifications[notification_id]) {
              wait_until_ready(kernel_notifications[notification_id]);
              delete kernel_notifications[notification_id];
            }
            kernel_notifications[notification_id] = new Notification();
          }
          const packLArg arg = {
            blockAs[blockAId], // blockA
            lhs,        // lhs
            m_start,    // m
            k_start,    // k
            actual_mc,  // mc
            actual_kc,  // kc
          };

          // Delete any existing notification since we may be
          // replacing it.  The algorithm should ensure that there are
          // no existing waiters on this notification.
          delete lhs_notifications[blockAId];
          lhs_notifications[blockAId] =
          this->m_device.enqueue(&Self::packLhs<packLArg, LhsPacker>, arg);
        }

        // now start kernels.
        const Index m_base_start = m_block_idx * mc;
        // The RHS panels are packed only by the first group of each k block
        // and reused by the following groups.
        const bool need_to_pack = m_block_idx == 0;

        for (Index n_block_idx = 0; n_block_idx < n_blocks; n_block_idx++) {
          const Index n_start = n_block_idx * nc;
          const Index actual_nc = numext::mini(n_start + nc, n) - n_start;

          // first make sure the previous kernels are all done before overwriting rhs. Also wait if
          // we're going to start new k. In both cases need_to_pack is true.
          // The slots of the current group were already drained in the LHS loop above;
          // this waits on the remaining slots' kernels for this n block.
          if (need_to_pack) {
            for (Index i = num_blocks; i < num_threads; ++i) {
              Index blockAId = (k_block_idx * m_blocks + i + m_block_idx) % num_threads;
              Index future_id = (blockAId * n_blocks + n_block_idx);
              wait_until_ready(kernel_notifications[future_id]);
            }
          }

          packRKArg arg = {
            &blockAs, // blockA
            blockBs[n_block_idx], // blockB
            rhs,          // rhs
            output,       // output
            m_base_start, // m
            k_start,      // k
            n_start,      // n
            mc,           // mc
            actual_kc,    // kc
            actual_nc,    // nc
            num_threads,
            numBlockAs,
            m,
            k_block_idx,
            m_block_idx,
            n_block_idx, // n_block_idx
            m_blocks, // m_blocks
            n_blocks, // n_blocks
            &kernel_notifications, // kernel notifications
            &lhs_notifications,    // lhs notifications
            need_to_pack, // need_to_pack
          };

          // We asynchronously kick off this function, which ends up
          // notifying the appropriate kernel_notifications objects,
          // which this thread waits on before exiting.
          this->m_device.enqueueNoNotification(&Self::packRhsAndKernel<packRKArg, RhsPacker, GebpKernel>, arg);
        }
      }
    }

    // Make sure all the kernels are done.
    for (size_t i = 0; i < kernel_notifications.size(); ++i) {
      wait_until_ready(kernel_notifications[i]);
      delete kernel_notifications[i];
    }

    // No need to wait for lhs notifications since they should have
    // already been waited on.  Just clean them up.
    for (size_t i = 0; i < lhs_notifications.size(); ++i) {
      delete lhs_notifications[i];
    }

    // deallocate all of the memory for both A and B's
    for (size_t i = 0; i < blockAs.size(); i++) {
      this->m_device.deallocate(blockAs[i]);
    }
    for (size_t i = 0; i < blockBs.size(); i++) {
      this->m_device.deallocate(blockBs[i]);
    }

#undef CEIL_DIV
  }
960 
961   /*
962    * Packs a LHS block of size (mt, kc) starting at lhs(m, k). Before packing
963    * the LHS block, check that all of the kernels that worked on the same
964    * mt_block_idx in the previous m_block are done.
965    */
966   template <typename packLArg, typename LhsPacker>
967   static void packLhs(const packLArg arg) {
968     // perform actual packing
969     LhsPacker pack_lhs;
970     pack_lhs(arg.blockA, arg.lhs.getSubMapper(arg.m_start, arg.k_start), arg.kc, arg.mc);
971   }
972 
973   /*
974    * Packs a RHS block of size (kc, nc) starting at (k, n) after checking that
975    * all kernels in the previous block are done.
976    * Then for each LHS future, we wait on the future and then call GEBP
977    * on the area packed by the future (which starts at
978    * blockA + future_idx * mt * kc) on the LHS and with the full packed
979    * RHS block.
980    * The output of this GEBP is written to output(m + i * mt, n).
981    */
982   template <typename packRKArg, typename RhsPacker, typename GebpKernel>
983   static void packRhsAndKernel(packRKArg arg) {
984     if (arg.need_to_pack) {
985       RhsPacker pack_rhs;
986       pack_rhs(arg.blockB, arg.rhs.getSubMapper(arg.k, arg.n), arg.kc, arg.nc);
987     }
988 
989     GebpKernel gebp;
990     for (Index mt_block_idx = 0; mt_block_idx < arg.num_blockAs; mt_block_idx++) {
991       const Index m_base_start = arg.m + arg.mc*mt_block_idx;
992       if (m_base_start < arg.max_m) {
993         Index blockAId = (arg.k_block_idx * arg.m_blocks + mt_block_idx + arg.m_block_idx) % arg.num_threads;
994         wait_until_ready((*arg.lhs_notifications)[blockAId]);
995         const Index actual_mc = numext::mini(m_base_start + arg.mc, arg.max_m) - m_base_start;
996         gebp(arg.output.getSubMapper(m_base_start, arg.n),
997              (*arg.blockAs)[blockAId], arg.blockB,
998              actual_mc, arg.kc, arg.nc, Scalar(1), -1, -1, 0, 0);
999 
1000         // Notify that the kernel is done.
1001         const Index set_idx = blockAId * arg.n_blocks + arg.n_block_idx;
1002         (*arg.kernel_notifications)[set_idx]->Notify();
1003       }
1004     }
1005   }
1006 #endif  // EIGEN_USE_SIMPLE_THREAD_POOL
1007 
1008   TensorOpCost contractionCost(Index m, Index n, Index bm, Index bn, Index bk,
1009                                bool shard_by_col, bool prepacked) const {
1010     const int packed_size = std::min<int>(PacketType<LhsScalar, Device>::size,
1011                                           PacketType<RhsScalar, Device>::size);
1012     const int output_packet_size = internal::unpacket_traits<PacketReturnType>::size;
1013     const double kd = static_cast<double>(bk);
1014     // Peak VFMA bandwidth is 0.5. However if we have not enough data for
1015     // vectorization bandwidth drops. The 4.0 and 2.0 bandwidth is determined
1016     // experimentally.
1017     double computeBandwidth = bk == 1 ? 4.0 :
1018           (shard_by_col ? bn : bm) < Traits::nr ||
1019           (shard_by_col ? bm : bn) < Traits::mr ? 2.0 : 0.5;
1020 #ifndef EIGEN_VECTORIZE_FMA
1021     // Bandwidth of all of VFMA/MULPS/ADDPS is 0.5 on latest Intel processors.
1022     // However for MULPS/ADDPS we have dependent sequence of 2 such instructions,
1023     // so overall bandwidth is 1.0.
1024     if (computeBandwidth == 0.5) computeBandwidth = 1.0;
1025 #endif
1026     // Computations.
1027     TensorOpCost cost = TensorOpCost(0, 0, kd * computeBandwidth, true, packed_size);
1028     // Output stores.
1029     cost += TensorOpCost(0, sizeof(CoeffReturnType), 0, true, output_packet_size);
1030     if (prepacked) {
1031       // Packing and kernels are executed in different tasks. When we calculate
1032       // task grain size we look only at kernel cost assuming that kernel
1033       // is more expensive than packing.
1034       return cost;
1035     }
1036     // Lhs/rhs loads + computations.
1037     TensorOpCost lhsCost = this->m_leftImpl.costPerCoeff(true) * (kd / n);
1038     TensorOpCost rhsCost = this->m_rightImpl.costPerCoeff(true) * (kd / m);
1039     // Lhs packing memory cost does not contribute considerably to overall
1040     // execution time because lhs is prefetched early and accessed sequentially.
1041     if (shard_by_col)
1042       lhsCost.dropMemoryCost();
1043     else
1044       rhsCost.dropMemoryCost();
1045     return cost + lhsCost + rhsCost;
1046   }
1047 };
1048 
1049 } // end namespace Eigen
1050 
1051 #endif  // EIGEN_USE_THREADS
1052 #endif // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_THREAD_POOL_H
1053