1 // Copyright 2016 The Gemmlowp Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #ifndef GEMMLOWP_META_SINGLE_THREAD_GEMM_H_
16 #define GEMMLOWP_META_SINGLE_THREAD_GEMM_H_
17 
#include <cassert>
#include <iostream>

#include "base.h"
20 
21 namespace gemmlowp {
22 namespace meta {
23 
// Forward declaration; defined at the bottom of this file. Dispatches a
// single-threaded GEMM described by `params` to `Executor`, with the kernel
// tile dimensions fixed at compile time.
template <typename Executor, typename Params, int kernel_m, int kernel_n,
          int kernel_k>
void Gemm(const Params& params);
27 
// GEMM executor that packs the entire RHS matrix into scratch memory up
// front, then streams the LHS through it one kernel_m-row chunk at a time.
// GemmExecutorPackLHS below is the mirror image (pack all of LHS, stream
// RHS chunks).
class GemmExecutorPackRHS {
 public:
  // Estimates the scratch-buffer size in bytes needed by ExecuteDispatch3D:
  // room for one packed LHS chunk plus ceil(n / kernel_n) packed RHS chunks,
  // rounded up via AlignTo<64 * 1024> (64KB granularity).
  template <typename P>
  static int EstimateScratchSize(const P& params, int kernel_m, int kernel_n,
                                 int kernel_k) {
    const int lhs_scratch =
        StreamUtil<typename P::InType, typename P::LeftStream>::Scratch(
            params.left_stream, kernel_m, kernel_k);
    // Ceil-divide: count of kernel_n-wide RHS chunks, including a trailing
    // partial chunk.
    const int rhs_chunks = ((params.n + kernel_n - 1) / kernel_n);
    const int rhs_scratch =
        rhs_chunks *
        StreamUtil<typename P::InType, typename P::RightStream>::Scratch(
            params.right_stream, kernel_n, kernel_k);
    return AlignTo<64 * 1024>(lhs_scratch + rhs_scratch);
  }

  // Runs the GEMM with the tile shape (m x n x k) and all leftover sizes
  // known at compile time. Suffix convention for the typedefs below:
  // F = "full" tile, L = "leftover" (partial) tile; e.g. KernelFL multiplies
  // a full-m by leftover-n block.
  template <typename P, int m, int n, int k, int m_leftovers, int n_leftovers,
            int k_leftovers>
  static void ExecuteDispatch3D(const P& params) {
    // Shorthand typedefs for streams and multiply kernels.
    typedef typename P::InType InType;
    typedef typename P::OutType OutType;

    // LHS streams for full and leftover row chunks.
    typedef Stream<typename P::InType, m, k, k_leftovers,
                   typename P::LeftStream>
        LeftStreamF;
    typedef Stream<typename P::InType, m_leftovers, k, k_leftovers,
                   typename P::LeftStream>
        LeftStreamL;

    // RHS streams for full and leftover column chunks.
    typedef Stream<typename P::InType, n, k, k_leftovers,
                   typename P::RightStream>
        RightStreamF;
    typedef Stream<typename P::InType, n_leftovers, k, k_leftovers,
                   typename P::RightStream>
        RightStreamL;

    // Output streams for full-m and leftover-m result strips.
    typedef Stream<typename P::OutType, m, n, 0, typename P::OutputStream>
        OutputStreamFF;
    typedef Stream<typename P::OutType, m_leftovers, n, 0,
                   typename P::OutputStream>
        OutputStreamLF;

    // Multiply kernels for each full/leftover combination of m and n.
    typedef MulKernel<typename P::InType, typename P::OutType,
                      typename P::Kernel, typename P::OutputStream, m, n, k>
        KernelFF;
    typedef MulKernel<typename P::InType, typename P::OutType,
                      typename P::Kernel, typename P::OutputStream, m,
                      n_leftovers, k>
        KernelFL;
    typedef MulKernel<typename P::InType, typename P::OutType,
                      typename P::Kernel, typename P::OutputStream, m_leftovers,
                      n, k>
        KernelLF;
    typedef MulKernel<typename P::InType, typename P::OutType,
                      typename P::Kernel, typename P::OutputStream, m_leftovers,
                      n_leftovers, k>
        KernelLL;

#ifdef DEBUG
#ifdef DEBUG_METAGEMM_VERBOSE
    std::cout << "GemmExecutor(" << typeid(P).name() << "): " << m << "x" << n
              << "x" << k << " -- " << m_leftovers << "x" << n_leftovers << "x"
              << k_leftovers << " -- " << params.m << "x" << params.n << "x"
              << params.k << std::endl;
    LeftStreamF::Debug(params.left_stream);
    LeftStreamL::Debug(params.left_stream);

    RightStreamF::Debug(params.right_stream);
    RightStreamL::Debug(params.right_stream);

    OutputStreamFF::Debug(params.fused_kernel.output_stream);
    OutputStreamLF::Debug(params.fused_kernel.output_stream);

    KernelFF::Debug(params.fused_kernel);
    KernelFL::Debug(params.fused_kernel);
    KernelLF::Debug(params.fused_kernel);
    KernelLL::Debug(params.fused_kernel);
#endif
#endif

    // Counts of *full* tiles along M and N; the partial tiles are handled by
    // the leftover streams/kernels afterwards.
    int lhs_chunks = params.m / m;
    int rhs_chunks = params.n / n;

    // Scratch memory for packed LHS & RHS chunks.
    // Layout: [one packed LHS chunk][all packed RHS chunks].

    std::uint8_t* packed_lhs = params.scratch;
    std::uint8_t* packed_rhs =
        params.scratch + LeftStreamF::Scratch(params.left_stream);

    // Pack full RHS first.

    std::uint8_t* packed_rhs_chunk = packed_rhs;
    const int packed_rhs_chunk_size =
        RightStreamF::PackedStride(params.right_stream);

    {
      const std::uint8_t* rhs_chunk =
          reinterpret_cast<const std::uint8_t*>(params.rhs);
      const int rhs_chunk_size =
          RightStreamF::UnpackedStride(params.right_stream);

      // Pack each full RHS chunk into its slot in scratch.
      for (int i = 0; i < rhs_chunks; ++i) {
        RightStreamF::Pack(reinterpret_cast<const InType*>(rhs_chunk),
                           params.right_stream,
                           reinterpret_cast<InType*>(packed_rhs_chunk));

        rhs_chunk += rhs_chunk_size;
        packed_rhs_chunk += packed_rhs_chunk_size;
      }

      // Pack the trailing (n_leftovers-wide) RHS chunk. If n_leftovers == 0
      // this is expected to be a no-op of the RightStreamL instantiation.
      RightStreamL::Pack(reinterpret_cast<const InType*>(rhs_chunk),
                         params.right_stream,
                         reinterpret_cast<InType*>(packed_rhs_chunk));
    }

    // Multiply RHS by LHS one LHS chunk at a time.

    const std::uint8_t* lhs_chunk =
        reinterpret_cast<const std::uint8_t*>(params.lhs);
    std::uint8_t* result_strip = reinterpret_cast<std::uint8_t*>(params.result);
    std::uint8_t* result_chunk = result_strip;

    {
      const int lhs_chunk_size =
          LeftStreamF::UnpackedStride(params.left_stream);
      const int result_strip_size =
          OutputStreamFF::UnpackedStride(params.fused_kernel.output_stream);
      const int result_chunk_size =
          OutputStreamFF::UnpackedAdvance(params.fused_kernel.output_stream);

      // For each full LHS row chunk: pack it once, then multiply it against
      // every packed RHS chunk (full ones, then the leftover one).
      for (int i = 0; i < lhs_chunks; ++i) {
        LeftStreamF::Pack(reinterpret_cast<const InType*>(lhs_chunk),
                          params.left_stream,
                          reinterpret_cast<InType*>(packed_lhs));

        // Restart at the beginning of the result strip and the packed RHS.
        result_chunk = result_strip;
        packed_rhs_chunk = packed_rhs;

        for (int j = 0; j < rhs_chunks; ++j) {
          KernelFF::Multiply(reinterpret_cast<const InType*>(packed_lhs),
                             reinterpret_cast<const InType*>(packed_rhs_chunk),
                             params.fused_kernel,
                             reinterpret_cast<OutType*>(result_chunk));

          result_chunk += result_chunk_size;
          packed_rhs_chunk += packed_rhs_chunk_size;
        }

        // Leftover-n block of this full-m strip.
        KernelFL::Multiply(reinterpret_cast<const InType*>(packed_lhs),
                           reinterpret_cast<const InType*>(packed_rhs_chunk),
                           params.fused_kernel,
                           reinterpret_cast<OutType*>(result_chunk));

        lhs_chunk += lhs_chunk_size;
        result_strip += result_strip_size;
      }
    }

    // Leftover LHS chunk.
    // The condition is a compile-time constant, so the dead branch is
    // eliminated ("static if").
    if (m_leftovers > 0) {  // static if
      const int result_chunk_size =
          OutputStreamLF::UnpackedAdvance(params.fused_kernel.output_stream);

      // Pack the final (m_leftovers-tall) LHS chunk.
      LeftStreamL::Pack(reinterpret_cast<const InType*>(lhs_chunk),
                        params.left_stream,
                        reinterpret_cast<InType*>(packed_lhs));

      result_chunk = result_strip;
      packed_rhs_chunk = packed_rhs;

      // Leftover-m strip against all full-n RHS chunks...
      for (int i = 0; i < rhs_chunks; ++i) {
        KernelLF::Multiply(reinterpret_cast<const InType*>(packed_lhs),
                           reinterpret_cast<const InType*>(packed_rhs_chunk),
                           params.fused_kernel,
                           reinterpret_cast<OutType*>(result_chunk));

        result_chunk += result_chunk_size;
        packed_rhs_chunk += packed_rhs_chunk_size;
      }

      // ...and finally the leftover-m by leftover-n corner block.
      KernelLL::Multiply(reinterpret_cast<const InType*>(packed_lhs),
                         reinterpret_cast<const InType*>(packed_rhs_chunk),
                         params.fused_kernel,
                         reinterpret_cast<OutType*>(result_chunk));
    }
  }
};
216 
// GEMM executor that packs the entire LHS matrix into scratch memory up
// front, then streams the RHS through it one kernel_n-column chunk at a
// time. Mirror image of GemmExecutorPackRHS above.
class GemmExecutorPackLHS {
 public:
  // Estimates the scratch-buffer size in bytes needed by ExecuteDispatch3D:
  // room for ceil(m / kernel_m) packed LHS chunks plus one packed RHS chunk,
  // rounded up via AlignTo<64 * 1024> (64KB granularity).
  template <typename P>
  static int EstimateScratchSize(const P& params, int kernel_m, int kernel_n,
                                 int kernel_k) {
    // Ceil-divide: count of kernel_m-tall LHS chunks, including a trailing
    // partial chunk.
    const int lhs_chunks = ((params.m + kernel_m - 1) / kernel_m);
    const int lhs_scratch =
        lhs_chunks *
        StreamUtil<typename P::InType, typename P::LeftStream>::Scratch(
            params.left_stream, kernel_m, kernel_k);
    const int rhs_scratch =
        StreamUtil<typename P::InType, typename P::RightStream>::Scratch(
            params.right_stream, kernel_n, kernel_k);
    return AlignTo<64 * 1024>(lhs_scratch + rhs_scratch);
  }

  // Runs the GEMM with the tile shape (m x n x k) and all leftover sizes
  // known at compile time. Suffix convention for the typedefs below:
  // F = "full" tile, L = "leftover" (partial) tile; e.g. KernelLF multiplies
  // a leftover-m by full-n block.
  template <typename P, int m, int n, int k, int m_leftovers, int n_leftovers,
            int k_leftovers>
  static void ExecuteDispatch3D(const P& params) {
    // Shorthand typedefs for streams and multiply kernels.
    typedef typename P::InType InType;
    typedef typename P::OutType OutType;

    // LHS streams for full and leftover row chunks.
    typedef Stream<typename P::InType, m, k, k_leftovers,
                   typename P::LeftStream>
        LeftStreamF;
    typedef Stream<typename P::InType, m_leftovers, k, k_leftovers,
                   typename P::LeftStream>
        LeftStreamL;

    // RHS streams for full and leftover column chunks.
    typedef Stream<typename P::InType, n, k, k_leftovers,
                   typename P::RightStream>
        RightStreamF;
    typedef Stream<typename P::InType, n_leftovers, k, k_leftovers,
                   typename P::RightStream>
        RightStreamL;

    // Output streams for full-n and leftover-n result strips.
    typedef Stream<typename P::OutType, m, n, 0, typename P::OutputStream>
        OutputStreamFF;
    typedef Stream<typename P::OutType, m, n_leftovers, 0,
                   typename P::OutputStream>
        OutputStreamFL;

    // Multiply kernels for each full/leftover combination of m and n.
    typedef MulKernel<typename P::InType, typename P::OutType,
                      typename P::Kernel, typename P::OutputStream, m, n, k>
        KernelFF;
    typedef MulKernel<typename P::InType, typename P::OutType,
                      typename P::Kernel, typename P::OutputStream, m,
                      n_leftovers, k>
        KernelFL;
    typedef MulKernel<typename P::InType, typename P::OutType,
                      typename P::Kernel, typename P::OutputStream, m_leftovers,
                      n, k>
        KernelLF;
    typedef MulKernel<typename P::InType, typename P::OutType,
                      typename P::Kernel, typename P::OutputStream, m_leftovers,
                      n_leftovers, k>
        KernelLL;
#ifdef DEBUG
#ifdef DEBUG_METAGEMM_VERBOSE
    std::cout << "GemmExecutor(" << typeid(P).name() << "): " << m << "x" << n
              << "x" << k << " -- " << m_leftovers << "x" << n_leftovers << "x"
              << k_leftovers << " -- " << params.m << "x" << params.n << "x"
              << params.k << std::endl;
    LeftStreamF::Debug(params.left_stream);
    LeftStreamL::Debug(params.left_stream);

    RightStreamF::Debug(params.right_stream);
    RightStreamL::Debug(params.right_stream);

    OutputStreamFF::Debug(params.fused_kernel.output_stream);
    OutputStreamFL::Debug(params.fused_kernel.output_stream);

    KernelFF::Debug(params.fused_kernel);
    KernelFL::Debug(params.fused_kernel);
    KernelLF::Debug(params.fused_kernel);
    KernelLL::Debug(params.fused_kernel);
#endif
#endif

    // Counts of *full* tiles along M and N; the partial tiles are handled by
    // the leftover streams/kernels afterwards.
    int lhs_chunks = params.m / m;
    int rhs_chunks = params.n / n;

    // Scratch memory for packed LHS & RHS chunks.
    // Layout: [one packed RHS chunk][all packed LHS chunks].
    std::uint8_t* packed_rhs = params.scratch;
    std::uint8_t* packed_lhs =
        params.scratch + RightStreamF::Scratch(params.right_stream);

    // Pack full LHS first.

    std::uint8_t* packed_lhs_chunk = packed_lhs;
    const int packed_lhs_chunk_size =
        LeftStreamF::PackedStride(params.left_stream);

    {
      const std::uint8_t* lhs_chunk =
          reinterpret_cast<const std::uint8_t*>(params.lhs);
      const int lhs_chunk_size =
          LeftStreamF::UnpackedStride(params.left_stream);

      // Pack each full LHS chunk into its slot in scratch.
      for (int i = 0; i < lhs_chunks; ++i) {
        LeftStreamF::Pack(reinterpret_cast<const InType*>(lhs_chunk),
                          params.left_stream,
                          reinterpret_cast<InType*>(packed_lhs_chunk));

        lhs_chunk += lhs_chunk_size;
        packed_lhs_chunk += packed_lhs_chunk_size;
      }

      // Pack the trailing (m_leftovers-tall) LHS chunk.
      LeftStreamL::Pack(reinterpret_cast<const InType*>(lhs_chunk),
                        params.left_stream,
                        reinterpret_cast<InType*>(packed_lhs_chunk));
    }

    // Multiply RHS by LHS one RHS chunk at a time.

    const std::uint8_t* rhs_chunk =
        reinterpret_cast<const std::uint8_t*>(params.rhs);
    std::uint8_t* result_strip = reinterpret_cast<std::uint8_t*>(params.result);
    std::uint8_t* result_chunk = result_strip;

    {
      const int rhs_chunk_size =
          RightStreamF::UnpackedStride(params.right_stream);
      // Note: strip/chunk advances are mirrored relative to
      // GemmExecutorPackRHS — strips move along N (UnpackedAdvance), chunks
      // within a strip move along M (UnpackedStride).
      const int result_strip_size =
          OutputStreamFF::UnpackedAdvance(params.fused_kernel.output_stream);
      const int result_chunk_size =
          OutputStreamFF::UnpackedStride(params.fused_kernel.output_stream);

      // For each full RHS column chunk: pack it once, then multiply every
      // packed LHS chunk against it (full ones, then the leftover one).
      for (int i = 0; i < rhs_chunks; ++i) {
        RightStreamF::Pack(reinterpret_cast<const InType*>(rhs_chunk),
                           params.right_stream,
                           reinterpret_cast<InType*>(packed_rhs));

        // Restart at the beginning of the result strip and the packed LHS.
        result_chunk = result_strip;
        packed_lhs_chunk = packed_lhs;

        for (int j = 0; j < lhs_chunks; ++j) {
          KernelFF::Multiply(reinterpret_cast<const InType*>(packed_lhs_chunk),
                             reinterpret_cast<const InType*>(packed_rhs),
                             params.fused_kernel,
                             reinterpret_cast<OutType*>(result_chunk));

          result_chunk += result_chunk_size;
          packed_lhs_chunk += packed_lhs_chunk_size;
        }

        // Leftover-m block of this full-n strip.
        KernelLF::Multiply(reinterpret_cast<const InType*>(packed_lhs_chunk),
                           reinterpret_cast<const InType*>(packed_rhs),
                           params.fused_kernel,
                           reinterpret_cast<OutType*>(result_chunk));

        rhs_chunk += rhs_chunk_size;
        result_strip += result_strip_size;
      }
    }

    // Leftover RHS chunk.
    // The condition is a compile-time constant, so the dead branch is
    // eliminated ("static if").
    if (n_leftovers > 0) {  // static if
      const int result_chunk_size =
          OutputStreamFL::UnpackedStride(params.fused_kernel.output_stream);

      // Pack the final (n_leftovers-wide) RHS chunk.
      RightStreamL::Pack(reinterpret_cast<const InType*>(rhs_chunk),
                         params.right_stream,
                         reinterpret_cast<InType*>(packed_rhs));

      result_chunk = result_strip;
      packed_lhs_chunk = packed_lhs;

      // All full-m LHS chunks against the leftover-n strip...
      for (int i = 0; i < lhs_chunks; ++i) {
        KernelFL::Multiply(reinterpret_cast<const InType*>(packed_lhs_chunk),
                           reinterpret_cast<const InType*>(packed_rhs),
                           params.fused_kernel,
                           reinterpret_cast<OutType*>(result_chunk));

        result_chunk += result_chunk_size;
        packed_lhs_chunk += packed_lhs_chunk_size;
      }

      // ...and finally the leftover-m by leftover-n corner block.
      KernelLL::Multiply(reinterpret_cast<const InType*>(packed_lhs_chunk),
                         reinterpret_cast<const InType*>(packed_rhs),
                         params.fused_kernel,
                         reinterpret_cast<OutType*>(result_chunk));
    }
  }
};
403 
404 namespace internal {
405 
// Computes how many sequential sub-GEMM tasks the work along one dimension
// must be split into so that each task's working set fits in cache_size
// bytes: constant_memory for the always-resident operand, plus
// per_chunk_memory for each chunk of the streamed operand.
//
// cache_size: target cache footprint in bytes.
// constant_memory: bytes that are resident regardless of the split.
// per_chunk_memory: bytes needed per chunk of the split dimension.
// total_dim: total size of the dimension being split.
// chunk_dim: size of one chunk along that dimension.
//
// Returns the task count (>= 1), ceil-dividing the chunks over the number
// of chunks that fit in the cache left over after constant_memory.
inline int CalculateCacheFriendlyTasksCount(int cache_size, int constant_memory,
                                            int per_chunk_memory, int total_dim,
                                            int chunk_dim) {
  assert(per_chunk_memory > 0);
  assert(constant_memory + per_chunk_memory < cache_size);
  const int available_cache = cache_size - constant_memory;
  const int available_chunks = available_cache / per_chunk_memory;
  const int chunks_count = (total_dim + chunk_dim - 1) / chunk_dim;
  if (available_chunks < 1) {
    // Degenerate case (the asserts above would have fired in debug builds):
    // avoid division by zero in NDEBUG builds by processing one chunk per
    // task. Behavior is unchanged whenever the asserts hold.
    return chunks_count;
  }
  return (chunks_count + available_chunks - 1) / available_chunks;
}
415 
416 template <typename Params>
UpdateCacheFriendlyTask(int m_offset,int m,int n_offset,int n,const Params & params,Params * task_params)417 inline void UpdateCacheFriendlyTask(int m_offset, int m, int n_offset, int n,
418                                     const Params& params, Params* task_params) {
419   task_params->m = m;
420   task_params->lhs =
421       StreamUtil<typename Params::InType, typename Params::LeftStream>::Offset(
422           params.left_stream, params.lhs, m_offset, 0);
423 
424   task_params->n = n;
425   task_params->rhs =
426       StreamUtil<typename Params::InType, typename Params::RightStream>::Offset(
427           params.right_stream, params.rhs, n_offset, 0);
428 
429   task_params->result =
430       StreamUtil<typename Params::OutType, typename Params::OutputStream>::
431           Offset(params.fused_kernel.output_stream, params.result, m_offset,
432                  n_offset);
433 }
434 
435 }  // namespace internal
436 
437 template <int cache_size = 256 * 1024>
438 class GemmExecutorPackRHSCacheFriendly {
439  public:
440   template <typename P>
EstimateScratchSize(const P & params,int kernel_m,int kernel_n,int kernel_k)441   static int EstimateScratchSize(const P& params, int kernel_m, int kernel_n,
442                                  int kernel_k) {
443     return cache_size;
444   }
445 
446   template <typename P, int m, int n, int k, int m_leftovers, int n_leftovers,
447             int k_leftovers>
ExecuteDispatch3D(const P & params)448   static void ExecuteDispatch3D(const P& params) {
449     typedef Stream<typename P::InType, m, k, k_leftovers,
450                    typename P::LeftStream>
451         LeftStream;
452 
453     typedef Stream<typename P::InType, n, k, k_leftovers,
454                    typename P::RightStream>
455         RightStream;
456 
457     const int lhs_scratch = LeftStream::Scratch(params.left_stream);
458     const int rhs_scratch = RightStream::Scratch(params.right_stream);
459 
460     const int cache_friendly_tasks_count =
461         internal::CalculateCacheFriendlyTasksCount(cache_size, lhs_scratch,
462                                                    rhs_scratch, params.n, n);
463 
464     if (cache_friendly_tasks_count == 1) {
465       GemmExecutorPackRHS::ExecuteDispatch3D<P, m, n, k, m_leftovers,
466                                              n_leftovers, k_leftovers>(params);
467       return;
468     }
469 
470     const int cache_friendly_dim = params.n / cache_friendly_tasks_count;
471 
472     P task_params = params;
473     for (int i = 0; i < cache_friendly_tasks_count - 1; ++i) {
474       internal::UpdateCacheFriendlyTask(0, params.m, i * cache_friendly_dim,
475                                         cache_friendly_dim, params,
476                                         &task_params);
477       Gemm<GemmExecutorPackRHS, P, m, n, k>(task_params);
478     }
479     const int dim_sum = (cache_friendly_tasks_count - 1) * cache_friendly_dim;
480     internal::UpdateCacheFriendlyTask(0, params.m, dim_sum, params.n - dim_sum,
481                                       params, &task_params);
482     Gemm<GemmExecutorPackRHS, P, m, n, k>(task_params);
483   }
484 };
485 
486 template <int cache_size = 256 * 1024>
487 class GemmExecutorPackLHSCacheFriendly {
488  public:
489   template <typename P>
EstimateScratchSize(const P & params,int kernel_m,int kernel_n,int kernel_k)490   static int EstimateScratchSize(const P& params, int kernel_m, int kernel_n,
491                                  int kernel_k) {
492     return cache_size;
493   }
494 
495   template <typename P, int m, int n, int k, int m_leftovers, int n_leftovers,
496             int k_leftovers>
ExecuteDispatch3D(const P & params)497   static void ExecuteDispatch3D(const P& params) {
498     typedef Stream<typename P::InType, m, k, k_leftovers,
499                    typename P::LeftStream>
500         LeftStream;
501 
502     typedef Stream<typename P::InType, n, k, k_leftovers,
503                    typename P::RightStream>
504         RightStream;
505 
506     const int lhs_scratch = LeftStream::Scratch(params.left_stream);
507     const int rhs_scratch = RightStream::Scratch(params.right_stream);
508 
509     const int cache_friendly_tasks_count =
510         internal::CalculateCacheFriendlyTasksCount(cache_size, rhs_scratch,
511                                                    lhs_scratch, params.m, m);
512 
513     if (cache_friendly_tasks_count == 1) {
514       GemmExecutorPackLHS::ExecuteDispatch3D<P, m, n, k, m_leftovers,
515                                              n_leftovers, k_leftovers>(params);
516       return;
517     }
518 
519     const int cache_friendly_dim = params.m / cache_friendly_tasks_count;
520 
521     P task_params = params;
522     for (int i = 0; i < cache_friendly_tasks_count - 1; ++i) {
523       internal::UpdateCacheFriendlyTask(i * cache_friendly_dim,
524                                         cache_friendly_dim, 0, params.n, params,
525                                         &task_params);
526       Gemm<GemmExecutorPackLHS, P, m, n, k>(task_params);
527     }
528     const int dim_sum = (cache_friendly_tasks_count - 1) * cache_friendly_dim;
529     internal::UpdateCacheFriendlyTask(dim_sum, params.m - dim_sum, 0, params.n,
530                                       params, &task_params);
531     Gemm<GemmExecutorPackLHS, P, m, n, k>(task_params);
532   }
533 };
534 
535 namespace internal {
536 
537 // Stage 3.
538 
// Stage 3 of the runtime-to-compile-time dispatch: resolves the k-leftover
// value. Linearly peels variable_k downward until it matches the runtime k,
// then invokes the executor with all six dimensions fixed at compile time.
template <typename E, typename P, int dim_m, int dim_n, int dim_k, int fixed_m,
          int fixed_n, int variable_k>
struct Dispatch3DStage3 {
  static void Execute(const P& params, int k) {
#ifdef DEBUG
#ifdef DEBUG_METAGEMM_VERBOSE
    std::cout << "Dispatch(3): " << dim_m << "x" << dim_n << "x" << dim_k
              << " : " << fixed_m << "x" << fixed_n << "x" << variable_k
              << std::endl
              << std::flush;
#endif
#endif
    if (k != variable_k) {
      // Not this case: recurse with the next smaller candidate.
      Dispatch3DStage3<E, P, dim_m, dim_n, dim_k, fixed_m, fixed_n,
                       variable_k - 1>::Execute(params, k);
      return;
    }
    E::template ExecuteDispatch3D<P, dim_m, dim_n, dim_k, fixed_m, fixed_n,
                                  variable_k>(params);
  }
};
560 
561 template <typename E, typename P, int dim_m, int dim_n, int dim_k, int fixed_m,
562           int fixed_n>
563 struct Dispatch3DStage3<E, P, dim_m, dim_n, dim_k, fixed_m, fixed_n, 0> {
564   static void Execute(const P& params, int k) {
565 #ifdef DEBUG
566 #ifdef DEBUG_METAGEMM_VERBOSE
567     std::cout << "Dispatch(3): " << dim_m << "x" << dim_n << "x" << dim_k
568               << " : " << fixed_m << "x" << fixed_n << "x" << 0 << std::endl
569               << std::flush;
570 #endif
571 #endif
572     if (k == 0) {
573       E::template ExecuteDispatch3D<P, dim_m, dim_n, dim_k, fixed_m, fixed_n,
574                                     0>(params);
575     } else {
576       std::cerr << "FATAL: dispatch3DStage3 failed: ran out of cases."
577                 << std::endl
578                 << std::flush;
579       std::exit(1);
580     }
581   }
582 };
583 
584 // Stage 2.
585 
586 template <typename E, typename P, int dim_m, int dim_n, int dim_k, int fixed_m,
587           int variable_n>
588 struct Dispatch3DStage2 {
589   static void Execute(const P& params, int n, int k) {
590 #ifdef DEBUG
591 #ifdef DEBUG_METAGEMM_VERBOSE
592     std::cout << "Dispatch(2): " << dim_m << "x" << dim_n << "x" << dim_k
593               << " : " << fixed_m << "x" << variable_n << std::endl
594               << std::flush;
595 #endif
596 #endif
597     if (n == variable_n) {
598       Dispatch3DStage3<E, P, dim_m, dim_n, dim_k, fixed_m, variable_n,
599                        dim_k - 1>::Execute(params, k);
600     } else {
601       Dispatch3DStage2<E, P, dim_m, dim_n, dim_k, fixed_m,
602                        variable_n - 1>::Execute(params, n, k);
603     }
604   }
605 };
606 
607 template <typename E, typename P, int dim_m, int dim_n, int dim_k, int fixed_m>
608 struct Dispatch3DStage2<E, P, dim_m, dim_n, dim_k, fixed_m, 0> {
609   static void Execute(const P& params, int n, int k) {
610 #ifdef DEBUG
611 #ifdef DEBUG_METAGEMM_VERBOSE
612     std::cout << "Dispatch(2): " << dim_m << "x" << dim_n << "x" << dim_k
613               << " : " << fixed_m << "x" << 0 << std::endl
614               << std::flush;
615 #endif
616 #endif
617     if (n == 0) {
618       Dispatch3DStage3<E, P, dim_m, dim_n, dim_k, fixed_m, 0,
619                        dim_k - 1>::Execute(params, k);
620     } else {
621       std::cerr << "FATAL: dispatch3DStage2 failed: ran out of cases."
622                 << std::endl
623                 << std::flush;
624       std::exit(1);
625     }
626   }
627 };
628 
629 // Stage 1.
630 
631 template <typename E, typename P, int dim_m, int dim_n, int dim_k,
632           int variable_m>
633 struct Dispatch3DStage1 {
634   static void Execute(const P& params, int m, int n, int k) {
635 #ifdef DEBUG
636 #ifdef DEBUG_METAGEMM_VERBOSE
637     std::cout << "Dispatch(1): " << dim_m << "x" << dim_n << "x" << dim_k
638               << " : " << variable_m << std::endl
639               << std::flush;
640 #endif
641 #endif
642     if (m == variable_m) {
643       Dispatch3DStage2<E, P, dim_m, dim_n, dim_k, variable_m,
644                        dim_n - 1>::Execute(params, n, k);
645     } else {
646       Dispatch3DStage1<E, P, dim_m, dim_n, dim_k, variable_m - 1>::Execute(
647           params, m, n, k);
648     }
649   }
650 };
651 
652 template <typename E, typename P, int dim_m, int dim_n, int dim_k>
653 struct Dispatch3DStage1<E, P, dim_m, dim_n, dim_k, 0> {
654   static void Execute(const P& params, int m, int n, int k) {
655 #ifdef DEBUG
656 #ifdef DEBUG_METAGEMM_VERBOSE
657     std::cout << "Dispatch(1): " << dim_m << "x" << dim_n << "x" << dim_k
658               << " : " << 0 << std::endl
659               << std::flush;
660 #endif
661 #endif
662     if (m == 0) {
663       Dispatch3DStage2<E, P, dim_m, dim_n, dim_k, 0, dim_n - 1>::Execute(params,
664                                                                          n, k);
665     } else {
666       std::cerr << "FATAL: dispatch3DStage1 failed: ran out of cases."
667                 << std::endl
668                 << std::flush;
669       std::exit(1);
670     }
671   }
672 };
673 
674 }  // namespace internal
675 
676 template <typename Executor, typename Params, int kernel_m, int kernel_n,
677           int kernel_k>
678 inline void Gemm(const Params& params) {
679   internal::Dispatch3DStage1<Executor, Params, kernel_m, kernel_n, kernel_k,
680                              kernel_m - 1>::Execute(params, params.m % kernel_m,
681                                                     params.n % kernel_n,
682                                                     params.k % kernel_k);
683 }
684 
685 }  // namespace meta
686 }  // namespace gemmlowp
687 
688 #endif  // GEMMLOWP_META_SINGLE_THREAD_GEMM_H_
689