// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2014-2015 Benoit Steiner <benoit.steiner.goog@gmail.com>
// Copyright (C) 2015 Navdeep Jaitly <ndjaitly@google.com>
// Copyright (C) 2014 Eric Martin <eric@ericmart.in>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_CUDA_H
#define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_CUDA_H

#if defined(EIGEN_USE_GPU) && defined(__CUDACC__)

namespace Eigen {

template<typename Scalar, typename Index, typename LhsMapper,
         typename RhsMapper, typename OutputMapper, bool needs_edge_check>
__device__ EIGEN_STRONG_INLINE void
EigenContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
                               const OutputMapper output, Scalar* lhs_shmem, Scalar* rhs_shmem,
                               const Index m_size, const Index n_size, const Index k_size) {

  const Index m_block_idx = blockIdx.x;
  const Index n_block_idx = blockIdx.y;

  const Index base_m = 64 * m_block_idx;
  const Index base_n = 64 * n_block_idx;

  // declare and initialize 64 registers for output 8x8 block

  // prefetch registers
  Scalar lhs_pf0;
  Scalar lhs_pf1;
  Scalar lhs_pf2;
  Scalar lhs_pf3;
  Scalar lhs_pf4;
  Scalar lhs_pf5;
  Scalar lhs_pf6;
  Scalar lhs_pf7;

  Scalar rhs_pf0;
  Scalar rhs_pf1;
  Scalar rhs_pf2;
  Scalar rhs_pf3;
  Scalar rhs_pf4;
  Scalar rhs_pf5;
  Scalar rhs_pf6;
  Scalar rhs_pf7;

  // shared memory is formatted
  // (contract idx in block, nocontract idx in block, block idx)
  // where block idx is column major. This transposition limits the number of
  // bank conflicts when reading the LHS. The core idea is that since the
  // contracting index is shared by both sides, it should live in threadIdx.x.

  // On the LHS, we pad each row inside each block with an extra element. This makes
  // each block 8 rows of 9 elements, which is 72 elements. This gives no bank conflicts
  // on writes and very few 2-way conflicts on reads. There is an 8x8 grid of these blocks.

  // On the RHS we just add 8 padding elements to the end of each block. This gives no bank
  // conflicts on writes and also none on reads.
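
  // Worked bank check (illustrative, assuming 4-byte scalars and 32 shared-memory
  // banks): an LHS write goes to word threadIdx.y * 72 + threadIdx.x * 9 + threadIdx.z,
  // i.e. bank (8 * y + 9 * x + z) mod 32 within a warp (x in 0..7, y in 0..3, fixed z).
  // Two lanes collide only if 9 * dx == 8 * dy (mod 32), which forces dx = dy = 0,
  // so the 9-element row padding is exactly what removes write conflicts.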

  // storage indices
  const Index lhs_store_idx_base = threadIdx.y * 72 + threadIdx.x * 9 + threadIdx.z;
  const Index rhs_store_idx_base = threadIdx.y * 72 + threadIdx.z * 8 + threadIdx.x;

  const Index lhs_store_idx_0 = lhs_store_idx_base + 576 * 0;
  const Index lhs_store_idx_1 = lhs_store_idx_base + 576 * 1;
  const Index lhs_store_idx_2 = lhs_store_idx_base + 576 * 2;
  const Index lhs_store_idx_3 = lhs_store_idx_base + 576 * 3;
  const Index lhs_store_idx_4 = lhs_store_idx_base + 576 * 4;
  const Index lhs_store_idx_5 = lhs_store_idx_base + 576 * 5;
  const Index lhs_store_idx_6 = lhs_store_idx_base + 576 * 6;
  const Index lhs_store_idx_7 = lhs_store_idx_base + 576 * 7;

  const Index rhs_store_idx_0 = rhs_store_idx_base + 576 * 0;
  const Index rhs_store_idx_1 = rhs_store_idx_base + 576 * 1;
  const Index rhs_store_idx_2 = rhs_store_idx_base + 576 * 2;
  const Index rhs_store_idx_3 = rhs_store_idx_base + 576 * 3;
  const Index rhs_store_idx_4 = rhs_store_idx_base + 576 * 4;
  const Index rhs_store_idx_5 = rhs_store_idx_base + 576 * 5;
  const Index rhs_store_idx_6 = rhs_store_idx_base + 576 * 6;
  const Index rhs_store_idx_7 = rhs_store_idx_base + 576 * 7;

  // in the loading code, the following variables are important:
  // threadIdx.x: the vertical position in an 8x8 block
  // threadIdx.y: the vertical index of the 8x8 block in the grid
  // threadIdx.z: the horizontal position in an 8x8 block
  // k: the horizontal index of the 8x8 block in the grid
  //
  // The k parameter is implicit (it was the loop counter for a loop that went
  // from 0 to <8, but that loop is now unrolled in the code below).
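  //
  // A concrete load (illustrative): thread (x=3, y=2, z=1) handles LHS row
  // base_m + 3 + 8 * 2 = base_m + 19; for the k-block 0 term it reads column
  // base_k + 1 into lhs_pf0 and stores it at lhs_store_idx_0 =
  // 2 * 72 + 3 * 9 + 1 = 172, the (contract=3, nocontract=2) slot of block 0.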

  const Index load_idx_vert = threadIdx.x + 8 * threadIdx.y;
  const Index lhs_vert = base_m + load_idx_vert;

#define prefetchIntoRegisters(base_k)                           \
  {                                                             \
    lhs_pf0 = conv(0);                                          \
    lhs_pf1 = conv(0);                                          \
    lhs_pf2 = conv(0);                                          \
    lhs_pf3 = conv(0);                                          \
    lhs_pf4 = conv(0);                                          \
    lhs_pf5 = conv(0);                                          \
    lhs_pf6 = conv(0);                                          \
    lhs_pf7 = conv(0);                                          \
                                                                \
    rhs_pf0 = conv(0);                                          \
    rhs_pf1 = conv(0);                                          \
    rhs_pf2 = conv(0);                                          \
    rhs_pf3 = conv(0);                                          \
    rhs_pf4 = conv(0);                                          \
    rhs_pf5 = conv(0);                                          \
    rhs_pf6 = conv(0);                                          \
    rhs_pf7 = conv(0);                                          \
                                                                \
    if (!needs_edge_check || lhs_vert < m_size) {               \
      const Index lhs_horiz_0 = base_k + threadIdx.z + 0 * 8;   \
      const Index lhs_horiz_1 = base_k + threadIdx.z + 1 * 8;   \
      const Index lhs_horiz_2 = base_k + threadIdx.z + 2 * 8;   \
      const Index lhs_horiz_3 = base_k + threadIdx.z + 3 * 8;   \
      const Index lhs_horiz_4 = base_k + threadIdx.z + 4 * 8;   \
      const Index lhs_horiz_5 = base_k + threadIdx.z + 5 * 8;   \
      const Index lhs_horiz_6 = base_k + threadIdx.z + 6 * 8;   \
      const Index lhs_horiz_7 = base_k + threadIdx.z + 7 * 8;   \
                                                                \
      if (!needs_edge_check || lhs_horiz_7 < k_size) {          \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0);                   \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1);                   \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2);                   \
        lhs_pf3 = lhs(lhs_vert, lhs_horiz_3);                   \
        lhs_pf4 = lhs(lhs_vert, lhs_horiz_4);                   \
        lhs_pf5 = lhs(lhs_vert, lhs_horiz_5);                   \
        lhs_pf6 = lhs(lhs_vert, lhs_horiz_6);                   \
        lhs_pf7 = lhs(lhs_vert, lhs_horiz_7);                   \
      } else if (lhs_horiz_6 < k_size) {                        \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0);                   \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1);                   \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2);                   \
        lhs_pf3 = lhs(lhs_vert, lhs_horiz_3);                   \
        lhs_pf4 = lhs(lhs_vert, lhs_horiz_4);                   \
        lhs_pf5 = lhs(lhs_vert, lhs_horiz_5);                   \
        lhs_pf6 = lhs(lhs_vert, lhs_horiz_6);                   \
      } else if (lhs_horiz_5 < k_size) {                        \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0);                   \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1);                   \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2);                   \
        lhs_pf3 = lhs(lhs_vert, lhs_horiz_3);                   \
        lhs_pf4 = lhs(lhs_vert, lhs_horiz_4);                   \
        lhs_pf5 = lhs(lhs_vert, lhs_horiz_5);                   \
      } else if (lhs_horiz_4 < k_size) {                        \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0);                   \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1);                   \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2);                   \
        lhs_pf3 = lhs(lhs_vert, lhs_horiz_3);                   \
        lhs_pf4 = lhs(lhs_vert, lhs_horiz_4);                   \
      } else if (lhs_horiz_3 < k_size) {                        \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0);                   \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1);                   \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2);                   \
        lhs_pf3 = lhs(lhs_vert, lhs_horiz_3);                   \
      } else if (lhs_horiz_2 < k_size) {                        \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0);                   \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1);                   \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2);                   \
      } else if (lhs_horiz_1 < k_size) {                        \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0);                   \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1);                   \
      } else if (lhs_horiz_0 < k_size) {                        \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0);                   \
      }                                                         \
    }                                                           \
                                                                \
    const Index rhs_vert = base_k + load_idx_vert;              \
    if (!needs_edge_check || rhs_vert < k_size) {               \
      const Index rhs_horiz_0 = base_n + threadIdx.z + 0 * 8;   \
      const Index rhs_horiz_1 = base_n + threadIdx.z + 1 * 8;   \
      const Index rhs_horiz_2 = base_n + threadIdx.z + 2 * 8;   \
      const Index rhs_horiz_3 = base_n + threadIdx.z + 3 * 8;   \
      const Index rhs_horiz_4 = base_n + threadIdx.z + 4 * 8;   \
      const Index rhs_horiz_5 = base_n + threadIdx.z + 5 * 8;   \
      const Index rhs_horiz_6 = base_n + threadIdx.z + 6 * 8;   \
      const Index rhs_horiz_7 = base_n + threadIdx.z + 7 * 8;   \
                                                                \
      if (rhs_horiz_7 < n_size) {                               \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0);                   \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1);                   \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2);                   \
        rhs_pf3 = rhs(rhs_vert, rhs_horiz_3);                   \
        rhs_pf4 = rhs(rhs_vert, rhs_horiz_4);                   \
        rhs_pf5 = rhs(rhs_vert, rhs_horiz_5);                   \
        rhs_pf6 = rhs(rhs_vert, rhs_horiz_6);                   \
        rhs_pf7 = rhs(rhs_vert, rhs_horiz_7);                   \
      } else if (rhs_horiz_6 < n_size) {                        \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0);                   \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1);                   \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2);                   \
        rhs_pf3 = rhs(rhs_vert, rhs_horiz_3);                   \
        rhs_pf4 = rhs(rhs_vert, rhs_horiz_4);                   \
        rhs_pf5 = rhs(rhs_vert, rhs_horiz_5);                   \
        rhs_pf6 = rhs(rhs_vert, rhs_horiz_6);                   \
      } else if (rhs_horiz_5 < n_size) {                        \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0);                   \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1);                   \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2);                   \
        rhs_pf3 = rhs(rhs_vert, rhs_horiz_3);                   \
        rhs_pf4 = rhs(rhs_vert, rhs_horiz_4);                   \
        rhs_pf5 = rhs(rhs_vert, rhs_horiz_5);                   \
      } else if (rhs_horiz_4 < n_size) {                        \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0);                   \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1);                   \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2);                   \
        rhs_pf3 = rhs(rhs_vert, rhs_horiz_3);                   \
        rhs_pf4 = rhs(rhs_vert, rhs_horiz_4);                   \
      } else if (rhs_horiz_3 < n_size) {                        \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0);                   \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1);                   \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2);                   \
        rhs_pf3 = rhs(rhs_vert, rhs_horiz_3);                   \
      } else if (rhs_horiz_2 < n_size) {                        \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0);                   \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1);                   \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2);                   \
      } else if (rhs_horiz_1 < n_size) {                        \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0);                   \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1);                   \
      } else if (rhs_horiz_0 < n_size) {                        \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0);                   \
      }                                                         \
    }                                                           \
  }                                                             \

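  // Tail handling (illustrative): with k_size = 70 and base_k = 64, a thread with
  // threadIdx.z = 1 sees lhs_horiz_0 = 65 but lhs_horiz_1 = 73, so the cascade above
  // fetches only lhs_pf0; the remaining prefetch registers keep their conv(0) value,
  // giving the zero padding of partial blocks for free.
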
#define writeRegToShmem(_)                      \
  lhs_shmem[lhs_store_idx_0] = lhs_pf0;         \
  rhs_shmem[rhs_store_idx_0] = rhs_pf0;         \
                                                \
  lhs_shmem[lhs_store_idx_1] = lhs_pf1;         \
  rhs_shmem[rhs_store_idx_1] = rhs_pf1;         \
                                                \
  lhs_shmem[lhs_store_idx_2] = lhs_pf2;         \
  rhs_shmem[rhs_store_idx_2] = rhs_pf2;         \
                                                \
  lhs_shmem[lhs_store_idx_3] = lhs_pf3;         \
  rhs_shmem[rhs_store_idx_3] = rhs_pf3;         \
                                                \
  lhs_shmem[lhs_store_idx_4] = lhs_pf4;         \
  rhs_shmem[rhs_store_idx_4] = rhs_pf4;         \
                                                \
  lhs_shmem[lhs_store_idx_5] = lhs_pf5;         \
  rhs_shmem[rhs_store_idx_5] = rhs_pf5;         \
                                                \
  lhs_shmem[lhs_store_idx_6] = lhs_pf6;         \
  rhs_shmem[rhs_store_idx_6] = rhs_pf6;         \
                                                \
  lhs_shmem[lhs_store_idx_7] = lhs_pf7;         \
  rhs_shmem[rhs_store_idx_7] = rhs_pf7;         \

  // declare and initialize result array
#define res(i, j) _res_##i##j
#define initResultRow(i)                        \
  Scalar res(i, 0) = conv(0);                   \
  Scalar res(i, 1) = conv(0);                   \
  Scalar res(i, 2) = conv(0);                   \
  Scalar res(i, 3) = conv(0);                   \
  Scalar res(i, 4) = conv(0);                   \
  Scalar res(i, 5) = conv(0);                   \
  Scalar res(i, 6) = conv(0);                   \
  Scalar res(i, 7) = conv(0);                   \

  internal::scalar_cast_op<int, Scalar> conv;
  initResultRow(0);
  initResultRow(1);
  initResultRow(2);
  initResultRow(3);
  initResultRow(4);
  initResultRow(5);
  initResultRow(6);
  initResultRow(7);
#undef initResultRow

  for (Index base_k = 0; base_k < k_size; base_k += 64) {
    // wait for previous iteration to finish with shmem. Despite common sense,
    // the code is a bit faster with this here than at the bottom of the loop
    __syncthreads();

    prefetchIntoRegisters(base_k);
    writeRegToShmem();

    #undef prefetchIntoRegisters
    #undef writeRegToShmem

    // wait for shared mem packing to be done before starting computation
    __syncthreads();

    // compute 8x8 matrix product by outer product. This involves packing one column
    // of LHS and one row of RHS into registers (takes 16 registers).
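    //
    // Each "pass" below is a rank-1 update: res(i, j) += lcol(i) * rrow(j) for all
    // 0 <= i, j < 8, so the eight passes accumulate this thread's 8x8 tile of the
    // product as a sum of eight outer products (LHS little column p times RHS
    // little row p).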

#define lcol(i) _lcol##i
    Scalar lcol(0);
    Scalar lcol(1);
    Scalar lcol(2);
    Scalar lcol(3);
    Scalar lcol(4);
    Scalar lcol(5);
    Scalar lcol(6);
    Scalar lcol(7);

#define rrow(j) _rrow##j
    Scalar rrow(0);
    Scalar rrow(1);
    Scalar rrow(2);
    Scalar rrow(3);
    Scalar rrow(4);
    Scalar rrow(5);
    Scalar rrow(6);
    Scalar rrow(7);

    // Now x corresponds to k, y to m, and z to n
    const Scalar* lhs_block = &lhs_shmem[threadIdx.x + 9 * threadIdx.y];
    const Scalar* rhs_block = &rhs_shmem[threadIdx.x + 8 * threadIdx.z];

#define lhs_element(i, j) lhs_block[72 * ((i) + 8 * (j))]
#define rhs_element(i, j) rhs_block[72 * ((i) + 8 * (j))]
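
    // Index math (illustrative): a stride of 72 scalars is one whole padded 8x9
    // block, so lhs_element(i, j) reads this thread's slot in the block i steps
    // down and j steps across the 8x8 block grid; e.g. lhs_element(1, 2) sits
    // 72 * (1 + 8 * 2) = 1224 elements past lhs_block.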

#define loadData(i, j)                          \
    lcol(0) = lhs_element(0, j);                \
    rrow(0) = rhs_element(i, 0);                \
    lcol(1) = lhs_element(1, j);                \
    rrow(1) = rhs_element(i, 1);                \
    lcol(2) = lhs_element(2, j);                \
    rrow(2) = rhs_element(i, 2);                \
    lcol(3) = lhs_element(3, j);                \
    rrow(3) = rhs_element(i, 3);                \
    lcol(4) = lhs_element(4, j);                \
    rrow(4) = rhs_element(i, 4);                \
    lcol(5) = lhs_element(5, j);                \
    rrow(5) = rhs_element(i, 5);                \
    lcol(6) = lhs_element(6, j);                \
    rrow(6) = rhs_element(i, 6);                \
    lcol(7) = lhs_element(7, j);                \
    rrow(7) = rhs_element(i, 7);                \

#define computeCol(j)                           \
    res(0, j) += lcol(0) * rrow(j);             \
    res(1, j) += lcol(1) * rrow(j);             \
    res(2, j) += lcol(2) * rrow(j);             \
    res(3, j) += lcol(3) * rrow(j);             \
    res(4, j) += lcol(4) * rrow(j);             \
    res(5, j) += lcol(5) * rrow(j);             \
    res(6, j) += lcol(6) * rrow(j);             \
    res(7, j) += lcol(7) * rrow(j);             \

#define computePass(i)                          \
    loadData(i, i);                             \
                                                \
    computeCol(0);                              \
    computeCol(1);                              \
    computeCol(2);                              \
    computeCol(3);                              \
    computeCol(4);                              \
    computeCol(5);                              \
    computeCol(6);                              \
    computeCol(7);                              \

    computePass(0);
    computePass(1);
    computePass(2);
    computePass(3);
    computePass(4);
    computePass(5);
    computePass(6);
    computePass(7);

#undef lcol
#undef rrow
#undef lhs_element
#undef rhs_element
#undef loadData
#undef computeCol
#undef computePass
  } // end loop over k

  // we've now iterated over all of the large (i.e. width 64) k blocks and
  // accumulated results in registers. At this point thread (x, y, z) contains
  // the sum across all big k blocks of the product of the little k block of index (x, y)
  // with the block of index (x, z). To compute the final output, we need to reduce
  // the 8 threads over x by summation.
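  //
  // __shfl_xor(v, mask) exchanges v between lanes whose IDs differ in the bits of
  // mask, so applying masks 1, 2 and 4 below is a butterfly sum over the three bits
  // of threadIdx.x (the lane id is threadIdx.x + 8 * (threadIdx.y % 4), so these
  // masks never leave threadIdx.x). E.g. x = 3 pairs with x = 2, then x = 1, then
  // x = 7, and after the three rounds every x holds the full 8-way sum.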
#define shuffleInc(i, j, mask) res(i, j) += __shfl_xor(res(i, j), mask)

#define reduceRow(i, mask)                      \
  shuffleInc(i, 0, mask);                       \
  shuffleInc(i, 1, mask);                       \
  shuffleInc(i, 2, mask);                       \
  shuffleInc(i, 3, mask);                       \
  shuffleInc(i, 4, mask);                       \
  shuffleInc(i, 5, mask);                       \
  shuffleInc(i, 6, mask);                       \
  shuffleInc(i, 7, mask);                       \

#define reduceMatrix(mask)                      \
  reduceRow(0, mask);                           \
  reduceRow(1, mask);                           \
  reduceRow(2, mask);                           \
  reduceRow(3, mask);                           \
  reduceRow(4, mask);                           \
  reduceRow(5, mask);                           \
  reduceRow(6, mask);                           \
  reduceRow(7, mask);                           \

  // actually perform the reduction, now each thread of index (_, y, z)
  // contains the correct values in its registers that belong in the output
  // block
  reduceMatrix(1);
  reduceMatrix(2);
  reduceMatrix(4);

#undef shuffleInc
#undef reduceRow
#undef reduceMatrix

  // now we need to copy the 64 values into main memory. We can't split the work
  // among threads because all the variables are in registers. There are two ways
  // to do this:
  // (1) have 1 thread do 64 writes from registers into global memory
  // (2) have 1 thread do 64 writes into shared memory, and then 8 threads
  //     each do 8 writes into global memory. We can just overwrite the shared
  //     memory from the problem we just solved.
  // (2) is slightly faster than (1) due to less branching and more ILP

  // TODO: won't yield much gain, but could just use currently unused shared mem
  //       and then we won't have to sync
  // wait for shared mem to be out of use
  __syncthreads();

#define writeResultShmem(i, j)                                          \
  lhs_shmem[i + 8 * threadIdx.y + 64 * threadIdx.z + 512 * j] = res(i, j); \

#define writeRow(i)                             \
  writeResultShmem(i, 0);                       \
  writeResultShmem(i, 1);                       \
  writeResultShmem(i, 2);                       \
  writeResultShmem(i, 3);                       \
  writeResultShmem(i, 4);                       \
  writeResultShmem(i, 5);                       \
  writeResultShmem(i, 6);                       \
  writeResultShmem(i, 7);                       \

  if (threadIdx.x == 0) {
    writeRow(0);
    writeRow(1);
    writeRow(2);
    writeRow(3);
    writeRow(4);
    writeRow(5);
    writeRow(6);
    writeRow(7);
  }
#undef writeResultShmem
#undef writeRow

  const int max_i_write = numext::mini((int)((m_size - base_m - threadIdx.y + 7) / 8), 8);
  const int max_j_write = numext::mini((int)((n_size - base_n - threadIdx.z + 7) / 8), 8);
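
  // The (... + 7) / 8 divisions above are ceiling divisions over this thread's
  // row/column stride: e.g. with m_size - base_m = 5 and threadIdx.y = 2,
  // max_i_write = (5 - 2 + 7) / 8 = 1, so only threadIdx.x == 0 still owns an
  // in-range output row (base_m + 2).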

  if (threadIdx.x < max_i_write) {
    if (max_j_write == 8) {
      // TODO: can I trade bank conflicts for coalesced writes?
      Scalar val0 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 0];
      Scalar val1 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 1];
      Scalar val2 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 2];
      Scalar val3 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 3];
      Scalar val4 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 4];
      Scalar val5 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 5];
      Scalar val6 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 6];
      Scalar val7 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 7];

      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 0) = val0;
      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 1) = val1;
      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 2) = val2;
      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 3) = val3;
      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 4) = val4;
      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 5) = val5;
      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 6) = val6;
      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 7) = val7;
    } else {
#pragma unroll 7
      for (int j = 0; j < max_j_write; j++) {
        Scalar val = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * j];
        output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * j) = val;
      }
    }
  }
#undef res
}

template<typename Scalar, typename Index, typename LhsMapper,
         typename RhsMapper, typename OutputMapper>
__global__ void
__launch_bounds__(512)
EigenContractionKernel(const LhsMapper lhs, const RhsMapper rhs,
                       const OutputMapper output,
                       const Index m_size, const Index n_size, const Index k_size) {
  __shared__ Scalar lhs_shmem[72 * 64];
  __shared__ Scalar rhs_shmem[72 * 64];

  const Index m_block_idx = blockIdx.x;
  const Index n_block_idx = blockIdx.y;

  const Index base_m = 64 * m_block_idx;
  const Index base_n = 64 * n_block_idx;

  if (base_m + 63 < m_size && base_n + 63 < n_size) {
    EigenContractionKernelInternal<Scalar, Index, LhsMapper, RhsMapper, OutputMapper, false>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size);
  } else {
    EigenContractionKernelInternal<Scalar, Index, LhsMapper, RhsMapper, OutputMapper, true>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size);
  }
}
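
// Launch sketch (illustrative; the 64x64 output tiles and __launch_bounds__(512)
// above imply an 8x8x8 thread block):
//
//   dim3 block(8, 8, 8);
//   dim3 grid((m_size + 63) / 64, (n_size + 63) / 64);
//   EigenContractionKernel<Scalar, Index, LhsMapper, RhsMapper, OutputMapper>
//       <<<grid, block, 0, stream>>>(lhs, rhs, output, m_size, n_size, k_size);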


template<typename Index, typename LhsMapper,
         typename RhsMapper, typename OutputMapper, bool CHECK_LHS_BOUNDARY,
         bool CHECK_RHS_BOUNDARY>
__device__ EIGEN_STRONG_INLINE void
EigenFloatContractionKernelInternal16x16(const LhsMapper lhs, const RhsMapper rhs,
                       const OutputMapper output, float2 lhs_shmem2[][16],
                       float2 rhs_shmem2[][8], const Index m_size,
                       const Index n_size, const Index k_size,
                       const Index base_m, const Index base_n) {
  typedef float Scalar;

  // prefetch registers
  float4 lhs_pf0, rhs_pf0;

  float4 results[4];
  for (int i=0; i < 4; i++) {
    results[i].x = results[i].y = results[i].z = results[i].w = 0;
  }


#define prefetch_lhs(reg, row, col)                    \
    if (!CHECK_LHS_BOUNDARY) {                         \
      if (col < k_size) {                              \
        reg = lhs.loadPacket<Unaligned>(row, col);     \
      }                                                \
    } else {                                           \
      if (col < k_size) {                              \
        if (row + 3 < m_size) {                        \
          reg = lhs.loadPacket<Unaligned>(row, col);   \
        } else if (row + 2 < m_size) {                 \
          reg.x = lhs(row + 0, col);                   \
          reg.y = lhs(row + 1, col);                   \
          reg.z = lhs(row + 2, col);                   \
        } else if (row + 1 < m_size) {                 \
          reg.x = lhs(row + 0, col);                   \
          reg.y = lhs(row + 1, col);                   \
        } else if (row < m_size) {                     \
          reg.x = lhs(row + 0, col);                   \
        }                                              \
      }                                                \
    }                                                  \


  Index lhs_vert = base_m+threadIdx.x*4;

  for (Index k = 0; k < k_size; k += 16) {
    lhs_pf0 = internal::pset1<float4>(0);
    rhs_pf0 = internal::pset1<float4>(0);

    Index lhs_horiz = threadIdx.y+k;
    prefetch_lhs(lhs_pf0, lhs_vert, lhs_horiz)

    Index rhs_vert = k+(threadIdx.x%4)*4;
    Index rhs_horiz0 = (threadIdx.x>>2)+threadIdx.y*4+base_n;

    if (!CHECK_RHS_BOUNDARY) {
      if ((rhs_vert + 3) < k_size) {
        // no n check needed in this branch; all four k rows are in range
        rhs_pf0 = rhs.loadPacket<Unaligned>(rhs_vert, rhs_horiz0);
      } else if (rhs_vert + 2 < k_size) {
        // guard only the k tail
        rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
        rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
      } else if (rhs_vert + 1 < k_size) {
        rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
      } else if (rhs_vert < k_size) {
        rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
      }
    } else {
      if (rhs_horiz0 < n_size) {
        if ((rhs_vert + 3) < k_size) {
          rhs_pf0 = rhs.loadPacket<Unaligned>(rhs_vert, rhs_horiz0);
        } else if ((rhs_vert + 2) < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
          rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
        } else if ((rhs_vert + 1) < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
        } else if (rhs_vert < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        }
      }
    }
    float x1, x2;
    // the following could be done with bitwise operations... some day.
    if((threadIdx.x%8) < 4) {
      x1 = rhs_pf0.y;
      x2 = rhs_pf0.w;
    } else {
      x1 = rhs_pf0.x;
      x2 = rhs_pf0.z;
    }
    x1 = __shfl_xor(x1, 4);
    x2 = __shfl_xor(x2, 4);
    if((threadIdx.x%8) < 4) {
      rhs_pf0.y = x1;
      rhs_pf0.w = x2;
    } else {
      rhs_pf0.x = x1;
      rhs_pf0.z = x2;
    }
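
    // In effect, lanes x and x ^ 4 (which loaded the same four k rows for two
    // adjacent output columns) have swapped half of their float4s, so each lane's
    // (.x, .y) and (.z, .w) now pair one k row across two neighboring columns;
    // this is what lets the float2 stores below pack feature pairs in one shot.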

    // We have 64 features.
    // Row 0 -> times (0, 4, 8, 12, 1, 5, 9, 13) for features 0, 1.
    // Row 1 -> times (0, 4, 8, 12, 1, 5, 9, 13) for features 2, 3.
    // ...
    // Row 31 -> times (0, 4, 8, 12, 1, 5, 9, 13) for features 62, 63
    // Row 32 -> times (2, 6, 10, 14, 3, 7, 11, 15) for features 0, 1
    // ...
    rhs_shmem2[(threadIdx.x>>3)+ threadIdx.y*2][threadIdx.x%8] = make_float2(rhs_pf0.x, rhs_pf0.y);
    rhs_shmem2[(threadIdx.x>>3)+ threadIdx.y*2+32][threadIdx.x%8] = make_float2(rhs_pf0.z, rhs_pf0.w);

    // Row 0 (time 0) -> features (0, 1), (4, 5), .. (28, 29), (32, 33), ..  (60, 61)
    // Row 1 (time 1) -> features (0, 1), (4, 5), .. (28, 29), (32, 33), ..  (60, 61)
    // ...
    // Row 15 (time 15) -> features (0, 1), (4, 5), .. (28, 29), (32, 33), ..  (60, 61)
    // Row 16 (time 0) -> features (2, 3), (6, 7), .. (30, 31), (34, 35), ..  (62, 63)
    // ...

    lhs_shmem2[threadIdx.y][threadIdx.x] = make_float2(lhs_pf0.x, lhs_pf0.y);
    lhs_shmem2[threadIdx.y+16][threadIdx.x] = make_float2(lhs_pf0.z, lhs_pf0.w);


#define add_vals(fl1, fl2, fr1, fr2)\
    results[0].x += fl1.x * fr1.x;\
    results[0].y += fl1.y * fr1.x;\
    results[0].z += fl2.x * fr1.x;\
    results[0].w += fl2.y * fr1.x;\
\
    results[1].x += fl1.x * fr1.y;\
    results[1].y += fl1.y * fr1.y;\
    results[1].z += fl2.x * fr1.y;\
    results[1].w += fl2.y * fr1.y;\
\
    results[2].x += fl1.x * fr2.x;\
    results[2].y += fl1.y * fr2.x;\
    results[2].z += fl2.x * fr2.x;\
    results[2].w += fl2.y * fr2.x;\
\
    results[3].x += fl1.x * fr2.y;\
    results[3].y += fl1.y * fr2.y;\
    results[3].z += fl2.x * fr2.y;\
    results[3].w += fl2.y * fr2.y;\

    __syncthreads();

    // Do the multiplies.
    #pragma unroll
    for (int koff = 0; koff < 16; koff ++) {
      // 32 x threads.
      float2 fl1 = lhs_shmem2[koff][threadIdx.x];
      float2 fl2 = lhs_shmem2[koff + 16][threadIdx.x];

      int start_feature = threadIdx.y * 4;
      float2 fr1 = rhs_shmem2[(start_feature>>1) + 32*((koff%4)/2)][koff/4 + (koff%2)*4];
      float2 fr2 = rhs_shmem2[(start_feature>>1) + 1 + 32*((koff%4)/2)][koff/4 + (koff%2)*4];
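
      // e.g. (illustrative) koff = 5: row offset 32 * ((5 % 4) / 2) = 0 and column
      // 5 / 4 + (5 % 2) * 4 = 5, which per the layout notes above is the slot
      // holding time 5 for this thread's feature pair.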

      add_vals(fl1, fl2, fr1, fr2)
    }
    __syncthreads();
  }

#undef prefetch_lhs
#undef add_vals

  Index horiz_base = threadIdx.y*4+base_n;
  if (!CHECK_LHS_BOUNDARY && !CHECK_RHS_BOUNDARY) {
    for (int i = 0; i < 4; i++) {
      output(lhs_vert, horiz_base + i) = results[i].x;
      output(lhs_vert + 1, horiz_base + i) = results[i].y;
      output(lhs_vert + 2, horiz_base + i) = results[i].z;
      output(lhs_vert + 3, horiz_base + i) = results[i].w;
    }
  } else if (!CHECK_RHS_BOUNDARY) {
    // CHECK LHS
    if (lhs_vert + 3 < m_size) {
      for (int i = 0; i < 4; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
        output(lhs_vert + 2, horiz_base + i) = results[i].z;
        output(lhs_vert + 3, horiz_base + i) = results[i].w;
      }
    } else if (lhs_vert + 2 < m_size) {
      for (int i = 0; i < 4; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
        output(lhs_vert + 2, horiz_base + i) = results[i].z;
      }
    } else if (lhs_vert + 1 < m_size) {
      for (int i = 0; i < 4; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
      }
    } else if (lhs_vert < m_size) {
      for (int i = 0; i < 4; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
      }
    }
  } else if (!CHECK_LHS_BOUNDARY) {
    // CHECK RHS
    /*
    int ncols_rem = fminf(n_size- horiz_base, 4);
    for (int i = 0; i < ncols_rem; i++) {
      output(lhs_vert, horiz_base + i) = results[i].x;
      output(lhs_vert + 1, horiz_base + i) = results[i].y;
      output(lhs_vert + 2, horiz_base + i) = results[i].z;
      output(lhs_vert + 3, horiz_base + i) = results[i].w;
    }*/
    for (int i = 0; i < 4; i++) {
      if (horiz_base+i < n_size) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
        output(lhs_vert + 2, horiz_base + i) = results[i].z;
        output(lhs_vert + 3, horiz_base + i) = results[i].w;
      }
    }
  } else {
    // CHECK both boundaries.
    for (int i = 0; i < 4; i++) {
      if (horiz_base+i < n_size) {
        if (lhs_vert < m_size)
          output(lhs_vert, horiz_base + i) = results[i].x;
        if (lhs_vert + 1 < m_size)
          output(lhs_vert + 1, horiz_base + i) = results[i].y;
        if (lhs_vert + 2 < m_size)
          output(lhs_vert + 2, horiz_base + i) = results[i].z;
        if (lhs_vert + 3 < m_size)
          output(lhs_vert + 3, horiz_base + i) = results[i].w;
      }
    }
  }
}


template<typename Index, typename LhsMapper,
         typename RhsMapper, typename OutputMapper, bool CHECK_LHS_BOUNDARY,
         bool CHECK_RHS_BOUNDARY>
__device__ EIGEN_STRONG_INLINE void
EigenFloatContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
                       const OutputMapper output, float2 lhs_shmem2[][32],
                       float2 rhs_shmem2[][8], const Index m_size,
                       const Index n_size, const Index k_size,
                       const Index base_m, const Index base_n) {
  typedef float Scalar;

  // prefetch registers
  float4 lhs_pf0, lhs_pf1, lhs_pf2, lhs_pf3;
  float4 rhs_pf0, rhs_pf1;

  float4 results[8];
  for (int i=0; i < 8; i++) {
    results[i].x = results[i].y = results[i].z = results[i].w = 0;
  }


  Index lhs_vert = base_m+threadIdx.x*4+(threadIdx.y%4)*32;
  for (Index k = 0; k < k_size; k += 32) {
    lhs_pf0 = internal::pset1<float4>(0);
    lhs_pf1 = internal::pset1<float4>(0);
    lhs_pf2 = internal::pset1<float4>(0);
    lhs_pf3 = internal::pset1<float4>(0);

    rhs_pf0 = internal::pset1<float4>(0);
    rhs_pf1 = internal::pset1<float4>(0);

    if (!CHECK_LHS_BOUNDARY) {
      if ((threadIdx.y/4+k+24) < k_size) {
        lhs_pf0 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
        lhs_pf1 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
        lhs_pf2 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
        lhs_pf3 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+24));
      } else if ((threadIdx.y/4+k+16) < k_size) {
        lhs_pf0 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
        lhs_pf1 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
        lhs_pf2 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
      } else if ((threadIdx.y/4+k+8) < k_size) {
        lhs_pf0 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
        lhs_pf1 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
      } else if ((threadIdx.y/4+k) < k_size) {
        lhs_pf0 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
      }
    } else {
      // just CHECK_LHS_BOUNDARY
      if (lhs_vert + 3 < m_size) {
        if ((threadIdx.y/4+k+24) < k_size) {
          lhs_pf0 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
          lhs_pf1 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
          lhs_pf2 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
          lhs_pf3 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+24));
        } else if ((threadIdx.y/4+k+16) < k_size) {
          lhs_pf0 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
          lhs_pf1 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
          lhs_pf2 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
        } else if ((threadIdx.y/4+k+8) < k_size) {
          lhs_pf0 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
          lhs_pf1 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
        } else if ((threadIdx.y/4+k) < k_size) {
          lhs_pf0 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
        }
      } else if (lhs_vert + 2 < m_size) {
        if ((threadIdx.y/4+k+24) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
          lhs_pf0.z = lhs(lhs_vert + 2, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
          lhs_pf1.z = lhs(lhs_vert + 2, (threadIdx.y/4+k+8));
          lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
          lhs_pf2.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+16));
          lhs_pf2.z = lhs(lhs_vert + 2, (threadIdx.y/4+k+16));
          lhs_pf3.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+24));
          lhs_pf3.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+24));
          lhs_pf3.z = lhs(lhs_vert + 2, (threadIdx.y/4+k+24));
        } else if ((threadIdx.y/4+k+16) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
          lhs_pf0.z = lhs(lhs_vert + 2, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
          lhs_pf1.z = lhs(lhs_vert + 2, (threadIdx.y/4+k+8));
          lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
          lhs_pf2.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+16));
          lhs_pf2.z = lhs(lhs_vert + 2, (threadIdx.y/4+k+16));
        } else if ((threadIdx.y/4+k+8) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
          lhs_pf0.z = lhs(lhs_vert + 2, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
          lhs_pf1.z = lhs(lhs_vert + 2, (threadIdx.y/4+k+8));
        } else if ((threadIdx.y/4+k) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
          lhs_pf0.z = lhs(lhs_vert + 2, (threadIdx.y/4+k));
        }
      } else if (lhs_vert + 1 < m_size) {
        if ((threadIdx.y/4+k+24) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
          lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
          lhs_pf2.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+16));
          lhs_pf3.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+24));
          lhs_pf3.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+24));
        } else if ((threadIdx.y/4+k+16) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
          lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
          lhs_pf2.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+16));
        } else if ((threadIdx.y/4+k+8) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
        } else if ((threadIdx.y/4+k) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
        }
      } else if (lhs_vert < m_size) {
        if ((threadIdx.y/4+k+24) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
          lhs_pf3.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+24));
        } else if ((threadIdx.y/4+k+16) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
        } else if ((threadIdx.y/4+k+8) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
        } else if ((threadIdx.y/4+k) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
        }
      }
    }
    __syncthreads();
    Index rhs_vert = k+threadIdx.x*4;
    Index rhs_horiz0 = threadIdx.y*2+base_n;
    Index rhs_horiz1 = threadIdx.y*2+1+base_n;
    if (!CHECK_RHS_BOUNDARY) {
      if ((rhs_vert + 3) < k_size) {
        // no n check needed in this branch; all four k rows are in range
        rhs_pf0 = rhs.loadPacket<Unaligned>(rhs_vert, rhs_horiz0);
        rhs_pf1 = rhs.loadPacket<Unaligned>(rhs_vert, rhs_horiz1);
      } else if (rhs_vert + 2 < k_size) {
        // guard only the k tail
        rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
        rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
        rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
        rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
        rhs_pf1.z = rhs(rhs_vert + 2, rhs_horiz1);
      } else if (rhs_vert + 1 < k_size) {
        rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
        rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
        rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
      } else if (rhs_vert < k_size) {
        rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
      }
    } else {
      if (rhs_horiz1 < n_size) {
        if ((rhs_vert + 3) < k_size) {
          // both columns in range; guard only the k tail
          rhs_pf0 = rhs.loadPacket<Unaligned>(rhs_vert, rhs_horiz0);
          rhs_pf1 = rhs.loadPacket<Unaligned>(rhs_vert, rhs_horiz1);
        } else if (rhs_vert + 2 < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
          rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
          rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
          rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
          rhs_pf1.z = rhs(rhs_vert + 2, rhs_horiz1);
        } else if (rhs_vert + 1 < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
          rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
          rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
        } else if (rhs_vert < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
        }
      } else if (rhs_horiz0 < n_size) {
        // only the first column is in range
        if ((rhs_vert + 3) < k_size) {
          rhs_pf0 = rhs.loadPacket<Unaligned>(rhs_vert, rhs_horiz0);
        } else if ((rhs_vert + 2) < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
          rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
        } else if ((rhs_vert + 1) < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
        } else if (rhs_vert < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        }
      }
    }
    __syncthreads();
    // Loaded. Do computation
    // Row 0 -> times (0, 4, 8, .. 28) for features 0, 1.
    // Row 1 -> times (0, 4, 8, .. 28) for features 2, 3.
    // ..
    // Row 31 -> times (0, 4, 8, .. 28) for features 62, 63
    rhs_shmem2[threadIdx.y][threadIdx.x] = make_float2(rhs_pf0.x, rhs_pf1.x);
    // Row 32 -> times (1, 5, 9, .. 29) for features 0, 1.
    // Row 33 -> times (1, 5, 9, .. 29) for features 2, 3.
    // ..
    rhs_shmem2[threadIdx.y+32][threadIdx.x] = make_float2(rhs_pf0.y, rhs_pf1.y);
    // Row 64 -> times (2, 6, 10, .. 30) for features 0, 1.
    // Row 65 -> times (2, 6, 10, .. 30) for features 2, 3.
    rhs_shmem2[threadIdx.y+64][threadIdx.x] = make_float2(rhs_pf0.z, rhs_pf1.z);
    // Row 96 -> times (3, 7, 11, .. 31) for features 0, 1.
    // Row 97 -> times (3, 7, 11, .. 31) for features 2, 3.
    rhs_shmem2[threadIdx.y+96][threadIdx.x] = make_float2(rhs_pf0.w, rhs_pf1.w);

    // LHS.
    // Row 0 (time 0) -> features (0, 1), (4, 5), .. (28, 29), (32, 33), ..  (60, 61) .. (124, 125)
    // Row 1 (time 1) -> features (0, 1), (4, 5), .. (28, 29), (32, 33), ..  (60, 61) .. (124, 125)
    // ...
    // Row 8 (time 0) -> features (2, 3), (6, 7), .. (30, 31), (34, 35), ..  (62, 63) .. (126, 127)
    // Row 15 (time 7) -> features (2, 3), (6, 7), .. (30, 31), (34, 35), ..  (62, 63) .. (126, 127)
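
    // Worked example (illustrative): in the multiply loop further below, koff = 5
    // reads br1..br4 from rows start_feature/2 + (5 % 4) * 32 + {0, 1, 2, 3} at
    // column 5 / 4 = 1, i.e. the "times (1, 5, 9, .. 29)" rows stored above, in
    // the column holding time 5.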


#define add_vals(a_feat1, a_feat2, f1, f2, f3, f4)\
      results[0].x += a_feat1.x * f1.x;\
      results[1].x += a_feat1.x * f1.y;\
      results[2].x += a_feat1.x * f2.x;\
      results[3].x += a_feat1.x * f2.y;\
      results[4].x += a_feat1.x * f3.x;\
      results[5].x += a_feat1.x * f3.y;\
      results[6].x += a_feat1.x * f4.x;\
      results[7].x += a_feat1.x * f4.y;\
\
      results[0].y += a_feat1.y * f1.x;\
      results[1].y += a_feat1.y * f1.y;\
      results[2].y += a_feat1.y * f2.x;\
      results[3].y += a_feat1.y * f2.y;\
      results[4].y += a_feat1.y * f3.x;\
      results[5].y += a_feat1.y * f3.y;\
      results[6].y += a_feat1.y * f4.x;\
      results[7].y += a_feat1.y * f4.y;\
\
      results[0].z += a_feat2.x * f1.x;\
      results[1].z += a_feat2.x * f1.y;\
      results[2].z += a_feat2.x * f2.x;\
      results[3].z += a_feat2.x * f2.y;\
      results[4].z += a_feat2.x * f3.x;\
      results[5].z += a_feat2.x * f3.y;\
      results[6].z += a_feat2.x * f4.x;\
      results[7].z += a_feat2.x * f4.y;\
\
      results[0].w += a_feat2.y * f1.x;\
      results[1].w += a_feat2.y * f1.y;\
      results[2].w += a_feat2.y * f2.x;\
      results[3].w += a_feat2.y * f2.y;\
      results[4].w += a_feat2.y * f3.x;\
      results[5].w += a_feat2.y * f3.y;\
      results[6].w += a_feat2.y * f4.x;\
      results[7].w += a_feat2.y * f4.y;\

    lhs_shmem2[threadIdx.y/4][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf0.x, lhs_pf0.y);
    lhs_shmem2[threadIdx.y/4+8][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf1.x, lhs_pf1.y);
    lhs_shmem2[threadIdx.y/4+16][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf2.x, lhs_pf2.y);
    lhs_shmem2[threadIdx.y/4+24][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf3.x, lhs_pf3.y);

    lhs_shmem2[threadIdx.y/4 + 32][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf0.z, lhs_pf0.w);
    lhs_shmem2[threadIdx.y/4 + 40][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf1.z, lhs_pf1.w);
    lhs_shmem2[threadIdx.y/4 + 48][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf2.z, lhs_pf2.w);
    lhs_shmem2[threadIdx.y/4 + 56][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf3.z, lhs_pf3.w);

    __syncthreads();

    // Do the multiplies.
    #pragma unroll
    for (int koff = 0; koff < 32; koff ++) {
      float2 a3 = lhs_shmem2[koff][threadIdx.x + (threadIdx.y % 4) * 8];
      float2 a4 = lhs_shmem2[koff + 32][threadIdx.x + (threadIdx.y % 4) * 8];

      // first feature is at (threadIdx.y/4) * 8; the last is at start + 7.
      int start_feature = (threadIdx.y / 4) * 8;

      float2 br1 = rhs_shmem2[start_feature/2 +     (koff % 4) * 32][koff/4];
      float2 br2 = rhs_shmem2[start_feature/2 + 1 + (koff % 4) * 32][koff/4];
      float2 br3 = rhs_shmem2[start_feature/2 + 2 + (koff % 4) * 32][koff/4];
      float2 br4 = rhs_shmem2[start_feature/2 + 3 + (koff % 4) * 32][koff/4];

      add_vals(a3, a4, br1, br2, br3, br4)
    }
    __syncthreads();
  } // end loop over k


  __syncthreads();
  Index horiz_base = (threadIdx.y/4)*8+base_n;
  if (!CHECK_LHS_BOUNDARY && !CHECK_RHS_BOUNDARY) {
    for (int i = 0; i < 8; i++) {
      output(lhs_vert, horiz_base + i) = results[i].x;
      output(lhs_vert + 1, horiz_base + i) = results[i].y;
      output(lhs_vert + 2, horiz_base + i) = results[i].z;
      output(lhs_vert + 3, horiz_base + i) = results[i].w;
    }
  } else if (!CHECK_RHS_BOUNDARY) {
    if (lhs_vert + 3 < m_size) {
      for (int i = 0; i < 8; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
        output(lhs_vert + 2, horiz_base + i) = results[i].z;
        output(lhs_vert + 3, horiz_base + i) = results[i].w;
      }
    } else if (lhs_vert + 2 < m_size) {
      for (int i = 0; i < 8; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
        output(lhs_vert + 2, horiz_base + i) = results[i].z;
      }
    } else if (lhs_vert + 1 < m_size) {
      for (int i = 0; i < 8; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
      }
    } else if (lhs_vert < m_size) {
      for (int i = 0; i < 8; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
      }
    }
  } else if (!CHECK_LHS_BOUNDARY) {
    // CHECK RHS
    for (int i = 0; i < 8; i++) {
      if (horiz_base + i < n_size) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
        output(lhs_vert + 2, horiz_base + i) = results[i].z;
        output(lhs_vert + 3, horiz_base + i) = results[i].w;
      }
    }
  } else {
    // CHECK both boundaries.
    for (int i = 0; i < 8; i++) {
      if (horiz_base + i < n_size) {
        if (lhs_vert < m_size)
          output(lhs_vert, horiz_base + i) = results[i].x;
        if (lhs_vert + 1 < m_size)
          output(lhs_vert + 1, horiz_base + i) = results[i].y;
        if (lhs_vert + 2 < m_size)
          output(lhs_vert + 2, horiz_base + i) = results[i].z;
        if (lhs_vert + 3 < m_size)
          output(lhs_vert + 3, horiz_base + i) = results[i].w;
      }
    }
  }
}


template<typename Index, typename LhsMapper,
         typename RhsMapper, typename OutputMapper>
__global__ void
__launch_bounds__(256)
EigenFloatContractionKernel(const LhsMapper lhs, const RhsMapper rhs,
                       const OutputMapper output,
                       const Index m_size, const Index n_size, const Index k_size) {
  __shared__ float2 lhs_shmem[64*32];
  __shared__ float2 rhs_shmem[128*8];

  typedef float2 LHS_MEM[64][32];
  typedef float2 RHS_MEM[128][8];

  typedef float2 LHS_MEM16x16[32][16];
  typedef float2 RHS_MEM16x16[64][8];

  const Index m_block_idx = blockIdx.x;
  const Index n_block_idx = blockIdx.y;

  const Index base_m = 128 * m_block_idx;
  const Index base_n = 64 * n_block_idx;

  bool check_rhs = (base_n + 63) >= n_size;
  bool check_lhs128 = (base_m + 127) >= m_size;

  if (!check_rhs) {
    if (!check_lhs128) {
      // >= 128 rows left
      EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, false, false>(
                     lhs, rhs, output, *((LHS_MEM *) lhs_shmem), *((RHS_MEM *) rhs_shmem), m_size, n_size, k_size, base_m, base_n);
    } else {
      EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, true, false>(
                     lhs, rhs, output, *((LHS_MEM *) lhs_shmem), *((RHS_MEM *) rhs_shmem), m_size, n_size, k_size, base_m, base_n);
    }
  } else {
    if (!check_lhs128) {
      // >= 128 rows left
      EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, false, true>(
                     lhs, rhs, output, *((LHS_MEM *) lhs_shmem), *((RHS_MEM *) rhs_shmem), m_size, n_size, k_size, base_m, base_n);
    } else {
      EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, true, true>(
                     lhs, rhs, output, *((LHS_MEM *) lhs_shmem), *((RHS_MEM *) rhs_shmem), m_size, n_size, k_size, base_m, base_n);
    }
  }
}

template<typename Index, typename LhsMapper,
         typename RhsMapper, typename OutputMapper>
__global__ void
__launch_bounds__(256)
EigenFloatContractionKernel16x16(const LhsMapper lhs, const RhsMapper rhs,
                       const OutputMapper output,
                       const Index m_size, const Index n_size, const Index k_size) {
  __shared__ float2 lhs_shmem[32][16];
  __shared__ float2 rhs_shmem[64][8];

  const Index m_block_idx = blockIdx.x;
  const Index n_block_idx = blockIdx.y;

  const Index base_m = 64 * m_block_idx;
  const Index base_n = 64 * n_block_idx;

  if (base_m + 63 < m_size) {
    if (base_n + 63 < n_size) {
      EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, false, false>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
    } else {
      EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, false, true>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
    }
  } else {
    if (base_n + 63 < n_size) {
      EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, true, false>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
    } else {
      EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, true, true>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
    }
  }
}


template<typename Indices, typename LeftArgType, typename RightArgType>
struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, GpuDevice> :
    public TensorContractionEvaluatorBase<TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, GpuDevice> > {

  typedef GpuDevice Device;

  typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> Self;
  typedef TensorContractionEvaluatorBase<Self> Base;

  typedef TensorContractionOp<Indices, LeftArgType, RightArgType> XprType;
  typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar;
  typedef typename XprType::Index Index;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, GpuDevice>::type PacketReturnType;

  enum {
    Layout = TensorEvaluator<LeftArgType, Device>::Layout,
  };

  // Most of the code assumes that both input tensors are ColMajor. If the
  // inputs are RowMajor, we will "cheat" by swapping the LHS and RHS:
  // if we want to compute A * B = C, where A is LHS and B is RHS, the code
  // will pretend B is LHS and A is RHS.
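  // This works because C = A * B implies C^T = B^T * A^T, and a RowMajor
  // matrix has the same memory layout as its ColMajor transpose, so swapping
  // the operands lets the ColMajor kernels produce the RowMajor result.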
  typedef typename internal::conditional<
    static_cast<int>(Layout) == static_cast<int>(ColMajor), LeftArgType, RightArgType>::type EvalLeftArgType;
  typedef typename internal::conditional<
    static_cast<int>(Layout) == static_cast<int>(ColMajor), RightArgType, LeftArgType>::type EvalRightArgType;

  static const int LDims =
      internal::array_size<typename TensorEvaluator<EvalLeftArgType, Device>::Dimensions>::value;
  static const int RDims =
      internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>::value;
  static const int ContractDims = internal::array_size<Indices>::value;

  typedef array<Index, LDims> left_dim_mapper_t;
  typedef array<Index, RDims> right_dim_mapper_t;

  typedef array<Index, ContractDims> contract_t;
  typedef array<Index, LDims - ContractDims> left_nocontract_t;
  typedef array<Index, RDims - ContractDims> right_nocontract_t;

  static const int NumDims = LDims + RDims - 2 * ContractDims;

  typedef DSizes<Index, NumDims> Dimensions;

  // typedefs needed in evalTo
  typedef typename internal::remove_const<typename EvalLeftArgType::Scalar>::type LhsScalar;
  typedef typename internal::remove_const<typename EvalRightArgType::Scalar>::type RhsScalar;

  typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator;
  typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;

  typedef typename LeftEvaluator::Dimensions LeftDimensions;
  typedef typename RightEvaluator::Dimensions RightDimensions;

  EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device) :
      Base(op, device) {}

  // We need to redefine this method to make nvcc happy
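  // Return contract, following the base evaluator: returning false tells the
  // caller the result was evaluated directly into the provided buffer and no
  // further assignment is needed; returning true means the coefficients must
  // be read from this evaluator (backed here by m_result).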
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* data) {
    this->m_leftImpl.evalSubExprsIfNeeded(NULL);
    this->m_rightImpl.evalSubExprsIfNeeded(NULL);
    if (data) {
      evalTo(data);
      return false;
    } else {
      this->m_result = static_cast<Scalar *>(this->m_device.allocate(this->dimensions().TotalSize() * sizeof(Scalar)));
      evalTo(this->m_result);
      return true;
    }
  }

  void evalTo(Scalar* buffer) const {
    if (this->m_lhs_inner_dim_contiguous) {
      if (this->m_rhs_inner_dim_contiguous) {
        if (this->m_rhs_inner_dim_reordered) {
          evalTyped<true, true, true, Unaligned>(buffer);
        }
        else {
          evalTyped<true, true, false, Unaligned>(buffer);
        }
      }
      else {
        if (this->m_rhs_inner_dim_reordered) {
          evalTyped<true, false, true, Unaligned>(buffer);
        }
        else {
          evalTyped<true, false, false, Unaligned>(buffer);
        }
      }
    }
    else {
      if (this->m_rhs_inner_dim_contiguous) {
        if (this->m_rhs_inner_dim_reordered) {
          evalTyped<false, true, true, Unaligned>(buffer);
        }
        else {
          evalTyped<false, true, false, Unaligned>(buffer);
        }
      }
      else {
        if (this->m_rhs_inner_dim_reordered) {
          evalTyped<false, false, true, Unaligned>(buffer);
        }
        else {
          evalTyped<false, false, false, Unaligned>(buffer);
        }
      }
    }
  }
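  // The nest above lifts three runtime flags into compile-time template
  // arguments, instantiating evalTyped once per combination (2^3 = 8) so the
  // input mappers can specialize their addressing without per-coefficient
  // branching.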

  template <typename LhsScalar, typename RhsScalar, typename Index, typename LhsMapper, typename RhsMapper, typename OutputMapper> struct LaunchKernels {
    static void Run(const LhsMapper& lhs, const RhsMapper& rhs, const OutputMapper& output, Index m, Index n, Index k, const GpuDevice& device) {
      const Index m_blocks = (m + 63) / 64;
      const Index n_blocks = (n + 63) / 64;
      const dim3 num_blocks(m_blocks, n_blocks, 1);
      const dim3 block_size(8, 8, 8);
      LAUNCH_CUDA_KERNEL((EigenContractionKernel<Scalar, Index, LhsMapper, RhsMapper, OutputMapper>), num_blocks, block_size, 0, device, lhs, rhs, output, m, n, k);
    }
  };
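  // Grid sizing above is a ceiling division over the 64x64 output tile each
  // block owns: e.g. m = 100 gives m_blocks = (100 + 63) / 64 = 2, and the
  // second block covers the 36 leftover rows through the boundary-checked
  // kernel path.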

  template <typename Index, typename LhsMapper, typename RhsMapper, typename OutputMapper> struct LaunchKernels<float, float, Index, LhsMapper, RhsMapper, OutputMapper> {
    static void Run(const LhsMapper& lhs, const RhsMapper& rhs, const OutputMapper& output, Index m, Index n, Index k, const GpuDevice& device) {
      if (m < 768 || n < 768) {
        const Index m_blocks = (m + 63) / 64;
        const Index n_blocks = (n + 63) / 64;
        const dim3 num_blocks(m_blocks, n_blocks, 1);
        const dim3 block_size(16, 16, 1);
        LAUNCH_CUDA_KERNEL((EigenFloatContractionKernel16x16<Index, LhsMapper, RhsMapper, OutputMapper>), num_blocks, block_size, 0, device, lhs, rhs, output, m, n, k);
      } else {
        const Index m_blocks = (m + 127) / 128;
        const Index n_blocks = (n + 63) / 64;
        const dim3 num_blocks(m_blocks, n_blocks, 1);
        const dim3 block_size(8, 32, 1);
        LAUNCH_CUDA_KERNEL((EigenFloatContractionKernel<Index, LhsMapper, RhsMapper, OutputMapper>), num_blocks, block_size, 0, device, lhs, rhs, output, m, n, k);
      }
    }
  };
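  // Both float kernels launch 256 threads per block (16*16 and 8*32). The 768
  // threshold is a tuning heuristic rather than a hard requirement: small
  // outputs take the 64x64-tile kernel so more blocks stay in flight, while
  // large outputs take the 128x64-tile kernel, which reuses each shared-memory
  // load across more output rows.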

  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
  void evalTyped(Scalar* buffer) const {
    // columns in left side, rows in right side
    const Index k = this->m_k_size;
    EIGEN_UNUSED_VARIABLE(k)

    // rows in left side
    const Index m = this->m_i_size;

    // columns in right side
    const Index n = this->m_j_size;

    // zero out the result buffer (which must be of size at least m * n * sizeof(Scalar))
    this->m_device.memset(buffer, 0, m * n * sizeof(Scalar));

    typedef internal::TensorContractionInputMapper<LhsScalar, Index, internal::Lhs,
                                                   LeftEvaluator, left_nocontract_t,
                                                   contract_t, 4,
                                                   lhs_inner_dim_contiguous,
                                                   false, Unaligned> LhsMapper;

    typedef internal::TensorContractionInputMapper<RhsScalar, Index, internal::Rhs,
                                                   RightEvaluator, right_nocontract_t,
                                                   contract_t, 4,
                                                   rhs_inner_dim_contiguous,
                                                   rhs_inner_dim_reordered, Unaligned> RhsMapper;

    typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;

    // initialize data mappers
    LhsMapper lhs(this->m_leftImpl, this->m_left_nocontract_strides, this->m_i_strides,
                  this->m_left_contracting_strides, this->m_k_strides);

    RhsMapper rhs(this->m_rightImpl, this->m_right_nocontract_strides, this->m_j_strides,
                  this->m_right_contracting_strides, this->m_k_strides);

    OutputMapper output(buffer, m);

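    // Eight-byte shared-memory banks match the float2/double-wide accesses the
    // kernels make, which should reduce bank conflicts on devices with a
    // configurable bank size.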
    setCudaSharedMemConfig(cudaSharedMemBankSizeEightByte);
    LaunchKernels<LhsScalar, RhsScalar, Index, LhsMapper, RhsMapper, OutputMapper>::Run(lhs, rhs, output, m, n, k, this->m_device);
  }
};

} // end namespace Eigen

#endif // EIGEN_USE_GPU and __CUDACC__
#endif // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_CUDA_H