1"""Generates the specialized gemm functions."""
2
3import mul_Nx8_Mx8_neon
4import qnt_Nx8_neon
5import zip_Nx8_neon
6
# Output-type tags selecting which family of gemm functions to generate.
_QUANTIZED_8BIT = "quantized_8bit"  # uint8 result, requantized.
_FULL_32BIT = "full_32bit"  # raw int32 result.
_FULL_FLOAT = "full_float"  # float result (scaled by result_scale).
10
11
class Error(Exception):
  """Root of the exception hierarchy raised by this generator module."""
14
15
class ConfigurationError(Error):
  """Raised when the generator is asked for an unsupported configuration.

  For example, an unknown output type.
  """
18
19
def GenerateCommonTempsCountersAndConsts(emitter, rows):
  """Emits the local declarations shared by every gemm specialization.

  Declares, in order:
    - the chunk counters;
    - the (padded) chunk sizes;
    - the input pointers and the zipped-buffer pointers carved out of
      `scratch`;
    - the result chunk stride.
  A blank line follows.

  Args:
    emitter: code emitter; only EmitDeclare and EmitNewline are used here.
    rows: leftover lhs rows (m % 3). When non-zero, an extra
      zipped_lhs_<rows>_offsets pointer is declared for the final, partial
      lhs strip.
  """
  emitter.EmitDeclare('const std::int32_t', 'row_chunks', 'm / 3')
  emitter.EmitDeclare('const std::int32_t', 'col_chunks', 'n / 3')
  # Depth rounded up to a multiple of 8.
  emitter.EmitDeclare('const std::int32_t', 'padded_k', '((k + 7) / 8) * 8')
  emitter.EmitDeclare('const std::int32_t', 'chunk_size', 'k * 3')
  # NOTE(review): each zipped row appears to reserve 16 bytes beyond padded_k.
  # The strip's int32 offsets are read from just past the packed data (see
  # zipped_lhs_*_offsets below). Confirm against zip_Nx8_neon.
  emitter.EmitDeclare('const std::int32_t', 'zipped_chunk_size',
                      '(padded_k + 16) * 3')
  emitter.EmitDeclare('const std::int32_t', 'zipped_rhs_size',
                      '(padded_k + 16) * n')
  emitter.EmitDeclare('const std::uint8_t*', 'lhs_chunk', 'lhs')
  emitter.EmitDeclare('const std::uint8_t*', 'rhs_chunk', 'rhs')
  # Scratch layout: one zipped 3-row lhs strip, then the whole zipped rhs.
  emitter.EmitDeclare('std::uint8_t*', 'zipped_lhs', 'scratch')
  emitter.EmitDeclare(
      'std::int32_t*', 'zipped_lhs_3_offsets',
      'reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3)')
  # `!=` rather than `is not`: identity comparison with an int literal relies
  # on CPython's small-int cache and is a SyntaxWarning since Python 3.8.
  if rows != 0:
    emitter.EmitDeclare(
        'std::int32_t*', 'zipped_lhs_%d_offsets' % rows,
        'reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * %d)' % rows)
  emitter.EmitDeclare('std::uint8_t*', 'zipped_rhs',
                      'scratch + zipped_chunk_size')
  emitter.EmitDeclare('std::uint8_t*', 'zipped_rhs_chunk', 'zipped_rhs')
  emitter.EmitDeclare('const std::int32_t', 'result_chunk_stride',
                      'result_stride * 3')
  emitter.EmitNewline()
45
46
def GenerateQuantized8BitTempsCountersAndConsts(emitter, rows):
  """Emits the local declarations used by the quantized 8-bit gemm body.

  On top of the shared declarations, this adds:
    - the quantization constants;
    - an int32 temp_result area placed after the zipped lhs/rhs in scratch;
    - the uint8 result pointer;
    - the stride (in bytes) of the temporary int32 rows.

  Args:
    emitter: code emitter receiving the declarations.
    rows: leftover lhs rows (m % 3); forwarded to the common declarations.
  """
  GenerateCommonTempsCountersAndConsts(emitter, rows)
  int32_const = 'const std::int32_t'
  # NOTE(review): the emitted rounding_offset shifts by -1 when shift == 0,
  # which is undefined in C++. Confirm callers always pass shift >= 1.
  declarations = (
      (int32_const, 'const_offset',
       'lhs_offset * rhs_offset * k + result_offset'),
      (int32_const, 'rounding_offset', '(1 << (shift - 1))'),
      ('std::int32_t*', 'temp_result',
       'reinterpret_cast<std::int32_t*>('
       'scratch + zipped_chunk_size + zipped_rhs_size)'),
      ('std::uint8_t*', 'result_chunk', 'result'),
      ('std::int32_t*', 'mul_result_chunk', 'temp_result'),
      # One temp row is n int32 values, rounded up to a multiple of 8 bytes.
      (int32_const, 'mul_result_chunk_stride_bytes', '((n * 4 + 7) / 8) * 8'),
  )
  for c_type, c_name, c_value in declarations:
    emitter.EmitDeclare(c_type, c_name, c_value)
  emitter.EmitNewline()
62
63
def GenerateFullTempsCountersAndConsts(emitter, result_type, rows):
  """Emits the local declarations shared by the int32 and float gemm bodies.

  Unlike the quantized variant, the multiplication writes straight into
  `result`, so both result pointers start there.

  Args:
    emitter: code emitter receiving the declarations.
    result_type: C++ pointer type of the output, e.g. 'std::int32_t*'.
    rows: leftover lhs rows (m % 3); forwarded to the common declarations.
  """
  GenerateCommonTempsCountersAndConsts(emitter, rows)
  int32_const = 'const std::int32_t'
  for c_type, c_name, c_value in (
      (int32_const, 'const_offset', 'lhs_offset * rhs_offset * k'),
      (result_type, 'result_chunk', 'result'),
      (result_type, 'mul_result_chunk', 'result'),
      # Row stride in bytes for 4-byte (int32 / float) result elements.
      (int32_const, 'mul_result_chunk_stride_bytes', 'result_stride * 4')):
    emitter.EmitDeclare(c_type, c_name, c_value)
  emitter.EmitNewline()
74
75
def ZipName(rows, leftovers, aligned):
  """Returns the zip kernel name for `rows` rows and `leftovers` depth."""
  kernel_name = zip_Nx8_neon.BuildName(rows, leftovers, aligned)
  return kernel_name
78
79
def GenerateZipRhs(emitter, aligned, cols, leftovers):
  """Emits the code that zips the whole rhs matrix into zipped_rhs.

  Full 3-column chunks are zipped in a loop. The `cols` leftover columns, if
  any, are then zipped by a dedicated kernel at the position the loop
  stopped at.

  Args:
    emitter: code emitter.
    aligned: whether to call the aligned zip kernels.
    cols: leftover rhs columns (n % 3).
    leftovers: leftover depth (k % 8).
  """
  # NOTE(review): the rhs zip is given lhs_offset (and the lhs zip, in
  # GenerateMulRows, is given rhs_offset). Presumably each zipped strip's sums
  # are pre-scaled by the other operand's offset; confirm in zip_Nx8_neon.
  emitter.EmitOpenBracket('for (int i = 0; i < col_chunks; ++i)')
  emitter.EmitCall(
      ZipName(3, leftovers, aligned),
      ['rhs_chunk', 'k', 'k', 'zipped_rhs_chunk', 'lhs_offset', 0])
  emitter.EmitAssignIncrement('rhs_chunk', 'chunk_size')
  emitter.EmitAssignIncrement('zipped_rhs_chunk', 'zipped_chunk_size')
  emitter.EmitCloseBracket()

  # `!=` rather than `is not`: identity comparison with an int literal relies
  # on CPython's small-int cache and is a SyntaxWarning since Python 3.8.
  if cols != 0:
    emitter.EmitCall(
        ZipName(cols, leftovers, aligned),
        ['rhs_chunk', 'k', 'k', 'zipped_rhs_chunk', 'lhs_offset', 0])
  emitter.EmitNewline()
95
96
def MulName(result_type, lhs_add, rhs_add, rows, cols):
  """Returns the name of the rows x cols multiplication kernel."""
  build_kernel_name = mul_Nx8_Mx8_neon.BuildName
  return build_kernel_name(result_type, lhs_add, rhs_add, rows, cols)
99
100
def GetMulParams(result_type):
  """Returns the argument expressions passed to a generated mul kernel.

  Args:
    result_type: kernel result type tag, e.g. 'int32' or 'float'.

  Returns:
    A new list of C++ argument expressions. Float kernels additionally
    receive 'result_scale' as the last argument.
  """
  params = ['zipped_lhs', 'zipped_rhs_chunk', 'padded_k', 'mul_result_chunk',
            'mul_result_chunk_stride_bytes']
  # Compare by value: `is` against a string literal only works when the
  # argument happens to be the same interned object.
  if result_type == 'float':
    params.append('result_scale')
  return params
107
108
def GenerateMulRows(emitter, result, result_type, lhs_add, rhs_add, aligned,
                    rows, cols, leftovers):
  """Emits code multiplying one horizontal lhs strip by the zipped rhs.

  The strip (`rows` rows starting at lhs_chunk) is zipped into zipped_lhs.
  It is then multiplied against every full 3-column rhs chunk in a loop, and
  finally against the `cols` leftover columns, if any.

  Args:
    emitter: code emitter.
    result: C++ expression that mul_result_chunk is reset to for this strip.
    result_type: kernel result type tag ('int32' or 'float').
    lhs_add: forwarded to MulName.
    rhs_add: forwarded to MulName.
    aligned: whether to use the aligned zip kernel.
    rows: number of rows in the strip.
    cols: leftover rhs columns (n % 3).
    leftovers: leftover depth (k % 8).
  """
  emitter.EmitCall(
      ZipName(rows, leftovers, aligned),
      ['lhs_chunk', 'k', 'k', 'zipped_lhs', 'rhs_offset', 'const_offset'])
  emitter.EmitAssign('zipped_rhs_chunk', 'zipped_rhs')
  emitter.EmitAssign('mul_result_chunk', result)

  emitter.EmitOpenBracket('for (int j = 0; j < col_chunks; ++j)')

  emitter.EmitCall(
      MulName(result_type, lhs_add, rhs_add, rows, 3),
      GetMulParams(result_type))
  emitter.EmitAssignIncrement('zipped_rhs_chunk', 'zipped_chunk_size')
  # Advance the output by the 3 columns just produced.
  emitter.EmitAssignIncrement('mul_result_chunk', 3)

  emitter.EmitCloseBracket()

  # The loop leaves both pointers at the leftover columns.
  # `!=` rather than `is not`: identity comparison with an int literal relies
  # on CPython's small-int cache and is a SyntaxWarning since Python 3.8.
  if cols != 0:
    emitter.EmitCall(
        MulName(result_type, lhs_add, rhs_add, rows, cols),
        GetMulParams(result_type))
132
133
def GenerateQuantized8BitMul(emitter, aligned, rows, cols, leftovers):
  """Emits the q8 main loop: multiply each lhs strip, then quantize it.

  Every full 3-row strip is multiplied into the int32 temp_result buffer and
  then quantized into result_chunk. The `rows` leftover rows, if any, get the
  same treatment after the loop.

  Args:
    emitter: code emitter.
    aligned: whether to use the aligned zip/quantize kernels.
    rows: leftover lhs rows (m % 3).
    cols: leftover rhs columns (n % 3).
    leftovers: leftover depth (k % 8).
  """

  def EmitMulAndQuantize(strip_rows):
    # Multiply one strip into temp_result, then requantize it into
    # result_chunk using that strip's zipped lhs offsets.
    GenerateMulRows(emitter, 'temp_result', 'int32', False, True, aligned,
                    strip_rows, cols, leftovers)
    emitter.EmitCall(
        qnt_Nx8_neon.BuildMultiQuantizeName(aligned, strip_rows),
        ['temp_result', 'n', 'mul_result_chunk_stride_bytes',
         'zipped_lhs_%d_offsets' % strip_rows, 'result_chunk',
         'result_stride', 'multiplicative_offset', 'rounding_offset',
         '-shift'])

  emitter.EmitOpenBracket('for (int i = 0; i < row_chunks; ++i)')
  EmitMulAndQuantize(3)
  emitter.EmitAssignIncrement('lhs_chunk', 'chunk_size')
  emitter.EmitAssignIncrement('result_chunk', 'result_chunk_stride')
  emitter.EmitCloseBracket()
  emitter.EmitNewline()

  # `!=` rather than `is not`: identity comparison with an int literal relies
  # on CPython's small-int cache and is a SyntaxWarning since Python 3.8.
  if rows != 0:
    EmitMulAndQuantize(rows)
157
158
def GenerateFullMul(emitter, result_type, aligned, rows, cols, leftovers):
  """Emits the int32/float main loop over lhs strips.

  Each full 3-row strip is multiplied directly into result_chunk. The `rows`
  leftover rows, if any, are handled after the loop.

  Args:
    emitter: code emitter.
    result_type: kernel result type tag ('int32' or 'float').
    aligned: whether to use the aligned zip kernel.
    rows: leftover lhs rows (m % 3).
    cols: leftover rhs columns (n % 3).
    leftovers: leftover depth (k % 8).
  """
  emitter.EmitOpenBracket('for (int i = 0; i < row_chunks; ++i)')
  GenerateMulRows(emitter, 'result_chunk', result_type, True, True, aligned, 3,
                  cols, leftovers)
  emitter.EmitAssignIncrement('lhs_chunk', 'chunk_size')
  emitter.EmitAssignIncrement('result_chunk', 'result_chunk_stride')
  emitter.EmitCloseBracket()
  emitter.EmitNewline()

  # `!=` rather than `is not`: identity comparison with an int literal relies
  # on CPython's small-int cache and is a SyntaxWarning since Python 3.8.
  if rows != 0:
    GenerateMulRows(emitter, 'result_chunk', result_type, True, True, aligned,
                    rows, cols, leftovers)
171
172
def BuildName(output_type, aligned, rows, cols, leftover):
  """Returns the internal function name of one gemm specialization.

  The format is '<main gemm name>_<rows>_<cols>_<leftover>', with '_aligned'
  appended for the aligned variant.
  """
  aligned_suffix = '_aligned' if aligned else ''
  return '%s_%d_%d_%d%s' % (BuildMainGemmName(output_type), rows, cols,
                            leftover, aligned_suffix)
178
179
def GetCommonGemmParameters():
  """Returns a fresh list of the (type, name) pairs every gemm takes first."""
  pointer_params = [['std::uint8_t*', 'scratch'],
                    ['const std::uint8_t*', 'lhs'],
                    ['const std::uint8_t*', 'rhs']]
  int_params = [['std::int32_t', param_name]
                for param_name in ('m', 'n', 'k', 'lhs_offset', 'rhs_offset')]
  return pointer_params + int_params
185
186
def GetGemmParameters(output_type, extra_params=None):
  """Returns the (type, name) parameter list of a gemm entry point.

  Args:
    output_type: one of _QUANTIZED_8BIT, _FULL_32BIT or _FULL_FLOAT.
    extra_params: optional sequence of (type, name) pairs appended at the end.

  Returns:
    A new list with the common parameters first, then the output-type
    specific ones (ending with 'result'), then extra_params.

  Raises:
    ConfigurationError: if output_type is not supported.
  """
  if extra_params is None:
    extra_params = []
  params = GetCommonGemmParameters()
  # Compare by value so equal strings from any source are accepted, not only
  # the module constant objects themselves.
  if output_type == _QUANTIZED_8BIT:
    params += [['std::int32_t', 'result_offset'],
               ['std::int32_t', 'multiplicative_offset'],
               ['std::int32_t', 'shift'], ['std::uint8_t*', 'result']]
  elif output_type == _FULL_32BIT:
    params += [['std::int32_t*', 'result']]
  elif output_type == _FULL_FLOAT:
    params += [['float', 'result_scale'], ['float*', 'result']]
  else:
    raise ConfigurationError('Unsupported output type: %s' % output_type)
  return params + list(extra_params)
203
204
def GetStridedGemmParameters(output_type):
  """Returns the gemm parameters followed by an explicit result_stride."""
  stride_param = ['std::int32_t', 'result_stride']
  return GetGemmParameters(output_type, [stride_param])
207
208
def GenerateGemm(emitter, output_type, aligned, rows, cols, leftovers):
  """Emits one internal gemm function for given row, col and depth leftovers.

  The emitted function asserts that m % 3, n % 3 and k % 8 match the
  specialization. It then declares its locals, zips the rhs, and runs the
  output-type specific multiplication loop.

  Args:
    emitter: code emitter.
    output_type: one of _QUANTIZED_8BIT, _FULL_32BIT or _FULL_FLOAT.
    aligned: whether to generate the aligned variant.
    rows: m % 3 handled by this specialization.
    cols: n % 3 handled by this specialization.
    leftovers: k % 8 handled by this specialization.

  Raises:
    ConfigurationError: if output_type is not supported.
  """
  emitter.EmitFunctionBeginA(
      BuildName(output_type, aligned, rows, cols, leftovers),
      GetStridedGemmParameters(output_type), 'void')

  emitter.EmitAssert('m %% 3 == %d' % rows)
  emitter.EmitAssert('n %% 3 == %d' % cols)
  emitter.EmitAssert('k %% 8 == %d' % leftovers)

  # Compare by value, not identity, so equal strings are accepted.
  if output_type == _QUANTIZED_8BIT:
    GenerateQuantized8BitTempsCountersAndConsts(emitter, rows)
    GenerateZipRhs(emitter, aligned, cols, leftovers)
    GenerateQuantized8BitMul(emitter, aligned, rows, cols, leftovers)
  elif output_type == _FULL_32BIT:
    GenerateFullTempsCountersAndConsts(emitter, 'std::int32_t*', rows)
    GenerateZipRhs(emitter, aligned, cols, leftovers)
    GenerateFullMul(emitter, 'int32', aligned, rows, cols, leftovers)
  elif output_type == _FULL_FLOAT:
    GenerateFullTempsCountersAndConsts(emitter, 'float*', rows)
    GenerateZipRhs(emitter, aligned, cols, leftovers)
    GenerateFullMul(emitter, 'float', aligned, rows, cols, leftovers)
  else:
    # Normally not reached: BuildName above already raises for unknown types.
    raise ConfigurationError('Unknown output type: %s' % output_type)

  emitter.EmitFunctionEnd()
235
236
def GenerateGemmCall(emitter, output_type, aligned, m_mod, n_mod, leftovers):
  """Emits a call to one internal gemm specialization, forwarding all args."""
  callee = emitter.Scope(
      'internal', BuildName(output_type, aligned, m_mod, n_mod, leftovers))
  forwarded_args = [
      param_name
      for _, param_name in GetStridedGemmParameters(output_type)
  ]
  emitter.EmitCall(callee, forwarded_args)
242
243
def GenerateGemmSwitch3(emitter, output_type, aligned, m_mod, n_mod):
  """Innermost dispatch switch: selects the specialization by k % 8."""
  emitter.EmitSwitch('k % 8')
  for depth_leftover in range(8):
    emitter.EmitCase(depth_leftover)
    emitter.PushIndent()
    GenerateGemmCall(emitter, output_type, aligned, m_mod, n_mod,
                     depth_leftover)
    emitter.EmitBreak()
    emitter.PopIndent()
  emitter.EmitSwitchEnd()
256
257
def GenerateGemmSwitch2(emitter, output_type, aligned, m_mod):
  """Middle dispatch switch: selects by n % 3, then nests the k % 8 switch."""
  emitter.EmitSwitch('n % 3')
  for col_leftover in range(3):
    emitter.EmitCase(col_leftover)
    emitter.PushIndent()
    GenerateGemmSwitch3(emitter, output_type, aligned, m_mod, col_leftover)
    emitter.EmitBreak()
    emitter.PopIndent()
  emitter.EmitSwitchEnd()
270
271
def GenerateGemmSwitch1(emitter, output_type, aligned):
  """Outer dispatch switch: selects by m % 3, then nests the n % 3 switch."""
  emitter.EmitSwitch('m % 3')
  for row_leftover in range(3):
    emitter.EmitCase(row_leftover)
    emitter.PushIndent()
    GenerateGemmSwitch2(emitter, output_type, aligned, row_leftover)
    emitter.EmitBreak()
    emitter.PopIndent()
  emitter.EmitSwitchEnd()
284
285
def BuildMainGemmName(output_type):
  """Returns the base name of the public gemm function for `output_type`.

  Args:
    output_type: one of _QUANTIZED_8BIT, _FULL_32BIT or _FULL_FLOAT.

  Returns:
    'gemm_q8', 'gemm_i32' or 'gemm_f' respectively.

  Raises:
    ConfigurationError: if output_type is not supported.
  """
  # Compare by value: `is` only matches the exact constant objects.
  if output_type == _QUANTIZED_8BIT:
    return 'gemm_q8'
  if output_type == _FULL_32BIT:
    return 'gemm_i32'
  if output_type == _FULL_FLOAT:
    return 'gemm_f'
  raise ConfigurationError('Unsupported output type: %s' % output_type)
295
296
def BuildStridedMainGemmName(output_type):
  """Returns the public name of the strided gemm entry point."""
  return '%s_strided' % BuildMainGemmName(output_type)
299
300
def GenerateMainGemmFunction(emitter, output_type):
  """Emits the public strided gemm that dispatches to an internal variant.

  The emitted function checks alignment at runtime and then, through a
  three-level switch on m % 3, n % 3 and k % 8, calls either the aligned or
  the unaligned internal specialization.

  Args:
    emitter: code emitter.
    output_type: one of _QUANTIZED_8BIT, _FULL_32BIT or _FULL_FLOAT.

  Raises:
    ConfigurationError: if output_type is not supported.
  """
  emitter.EmitFunctionBeginA(
      BuildStridedMainGemmName(output_type),
      GetStridedGemmParameters(output_type), 'void')

  emitter.EmitDeclare('const bool', 'lhs_aligned',
                      '((reinterpret_cast<std::uintptr_t>(lhs) % 8) == 0)')
  emitter.EmitDeclare('const bool', 'rhs_aligned',
                      '((reinterpret_cast<std::uintptr_t>(rhs) % 8) == 0)')
  emitter.EmitDeclare('const bool', 'k_aligned', '((k % 8) == 0)')

  # Compare by value, not identity, so equal strings are accepted.
  if output_type == _QUANTIZED_8BIT:
    # The 8-bit path additionally requires the result pointer and the result
    # stride to be multiples of 8.
    emitter.EmitDeclare('const bool', 'result_aligned',
                        '((reinterpret_cast<std::uintptr_t>(result) % 8) == 0)')
    emitter.EmitDeclare('const bool', 'result_stride_aligned',
                        '((result_stride % 8) == 0)')
    emitter.EmitDeclare('const bool', 'aligned',
                        'lhs_aligned && rhs_aligned && result_aligned '
                        '&& k_aligned && result_stride_aligned')
  else:
    emitter.EmitDeclare('const bool', 'aligned',
                        'lhs_aligned && rhs_aligned && k_aligned')

  emitter.EmitIf('aligned')
  GenerateGemmSwitch1(emitter, output_type, True)
  emitter.EmitElse()
  GenerateGemmSwitch1(emitter, output_type, False)
  emitter.EmitEndif()
  emitter.EmitFunctionEnd()
331
332
def GenerateWrapperGemmFunction(emitter, output_type):
  """Emits the non-strided entry point, which forwards to the strided one.

  Every parameter is forwarded unchanged, and `n` is passed as the
  result_stride argument.
  """
  public_name = BuildMainGemmName(output_type)
  signature = GetGemmParameters(output_type)
  call_args = [param_name for _, param_name in signature] + ['n']
  emitter.EmitFunctionBeginA(public_name, signature, 'void')
  emitter.EmitCall(BuildStridedMainGemmName(output_type), call_args)
  emitter.EmitFunctionEnd()
340
341
def GenerateInternalFunctions(emitter):
  """Emits every gemm specialization hidden in the internal namespace.

  One function is generated per (output type, alignment, m % 3, n % 3,
  k % 8) combination, each followed by a blank line.
  """
  variants = ((output_type, aligned, rows, cols, leftover)
              for output_type in (_QUANTIZED_8BIT, _FULL_32BIT, _FULL_FLOAT)
              for aligned in (True, False)
              for rows in range(3)
              for cols in range(3)
              for leftover in range(8))
  for variant in variants:
    GenerateGemm(emitter, *variant)
    emitter.EmitNewline()
351
352
def GeneratePublicFunctions(emitter):
  """Emits the strided and the wrapper gemm entry points for each type."""
  for output_type in (_QUANTIZED_8BIT, _FULL_32BIT, _FULL_FLOAT):
    for generate in (GenerateMainGemmFunction, GenerateWrapperGemmFunction):
      generate(emitter, output_type)
      emitter.EmitNewline()
360