Lines Matching refs:emitter

21 def GenerateCommonTempsCountersAndConsts(emitter):  argument
23 emitter.EmitDeclare('const std::int32_t', 'col_chunks', 'n / 8')
24 emitter.EmitDeclare('const std::int32_t', 'padded_k', '((k + 7) / 8) * 8')
25 emitter.EmitDeclare('const std::int32_t', 'chunk_size', 'k * 4')
26 emitter.EmitDeclare('const std::int32_t', 'zipped_chunk_size',
28 emitter.EmitDeclare('const std::uint8_t*', 'rhs_chunk', 'rhs')
29 emitter.EmitDeclare('std::uint8_t*', 'zipped_lhs', 'scratch')
30 emitter.EmitDeclare('std::int32_t*', 'zipped_lhs_offsets',
32 emitter.EmitDeclare('std::uint8_t*', 'zipped_rhs_1',
34 emitter.EmitDeclare('std::uint8_t*', 'zipped_rhs_2',
36 emitter.EmitNewline()
39 def GenerateQuantized8BitTempsCountersAndConsts(emitter): argument
41 GenerateCommonTempsCountersAndConsts(emitter)
42 emitter.EmitDeclare('const std::int32_t', 'const_offset',
44 emitter.EmitDeclare('const std::int32_t', 'rounding_offset',
46 emitter.EmitDeclare('std::int32_t*', 'temp_result',
49 emitter.EmitDeclare('std::int32_t*', 'mul_result_chunk', 'temp_result')
50 emitter.EmitNewline()
53 def GenerateFullTempsCountersAndConsts(emitter, result_type): argument
55 GenerateCommonTempsCountersAndConsts(emitter)
56 emitter.EmitDeclare('const std::int32_t', 'const_offset',
58 emitter.EmitDeclare(result_type, 'mul_result_chunk', 'result')
59 emitter.EmitNewline()
62 def GenerateZipVector(emitter, aligned, leftovers): argument
63 emitter.EmitCall(
83 def GenerateMulCols(emitter, result_type, lhs_add, rhs_add, aligned, cols, argument
86 emitter.EmitOpenBracket('for (int i = 0; i < col_chunks; ++i)')
87 emitter.EmitCall(
90 emitter.EmitAssignIncrement('rhs_chunk', 'chunk_size')
92 emitter.EmitCall(
95 emitter.EmitAssignIncrement('rhs_chunk', 'chunk_size')
97 emitter.EmitCall(
101 emitter.EmitAssignIncrement('mul_result_chunk', 8)
102 emitter.EmitCloseBracket()
105 emitter.EmitCall(
108 emitter.EmitAssignIncrement('rhs_chunk', 'chunk_size')
110 emitter.EmitCall(
114 emitter.EmitCall(
118 emitter.EmitCall(
122 emitter.EmitCall(
127 def GenerateQuantized8BitMul(emitter, aligned, cols, leftovers): argument
129 GenerateMulCols(emitter, 'int32', False, True, aligned, cols, leftovers)
130 emitter.EmitCall(
136 def GenerateFullMul(emitter, result_type, aligned, cols, leftovers): argument
137 GenerateMulCols(emitter, result_type, True, True, aligned, cols, leftovers)
170 def GenerateGemv(emitter, output_type, aligned, cols, leftovers): argument
172 emitter.EmitFunctionBeginA(
176 emitter.EmitAssert('n %% 8 == %d' % cols)
177 emitter.EmitAssert('k %% 8 == %d' % leftovers)
180 GenerateQuantized8BitTempsCountersAndConsts(emitter)
181 GenerateZipVector(emitter, aligned, leftovers)
182 GenerateQuantized8BitMul(emitter, aligned, cols, leftovers)
184 GenerateFullTempsCountersAndConsts(emitter, 'std::int32_t*')
185 GenerateZipVector(emitter, aligned, leftovers)
186 GenerateFullMul(emitter, 'int32', aligned, cols, leftovers)
188 GenerateFullTempsCountersAndConsts(emitter, 'float*')
189 GenerateZipVector(emitter, aligned, leftovers)
190 GenerateFullMul(emitter, 'float', aligned, cols, leftovers)
194 emitter.EmitFunctionEnd()
197 def GenerateGemvCall(emitter, output_type, aligned, m_mod, leftovers): argument
198 emitter.EmitCall(
199 emitter.Scope('internal',
204 def GenerateGemvSwitch2(emitter, output_type, aligned, n_mod): argument
206 emitter.EmitSwitch('k % 8')
209 emitter.EmitCase(leftovers)
210 emitter.PushIndent()
211 GenerateGemvCall(emitter, output_type, aligned, n_mod, leftovers)
212 emitter.EmitBreak()
213 emitter.PopIndent()
215 emitter.EmitSwitchEnd()
218 def GenerateGemvSwitch1(emitter, output_type, aligned): argument
220 emitter.EmitSwitch('n % 8')
223 emitter.EmitCase(n_mod)
224 emitter.PushIndent()
225 GenerateGemvSwitch2(emitter, output_type, aligned, n_mod)
226 emitter.EmitBreak()
227 emitter.PopIndent()
229 emitter.EmitSwitchEnd()
243 def GenerateMainGemvFunction(emitter, output_type): argument
245 emitter.EmitFunctionBeginA(
248 emitter.EmitDeclare('const bool', 'lhs_aligned',
250 emitter.EmitDeclare('const bool', 'rhs_aligned',
252 emitter.EmitDeclare('const bool', 'k_aligned', '((k % 8) == 0)')
255 emitter.EmitDeclare('const bool', 'result_aligned',
257 emitter.EmitDeclare('const bool', 'aligned',
261 emitter.EmitDeclare('const bool', 'aligned',
264 emitter.EmitIf('aligned')
265 GenerateGemvSwitch1(emitter, output_type, True)
266 emitter.EmitElse()
267 GenerateGemvSwitch1(emitter, output_type, False)
268 emitter.EmitEndif()
269 emitter.EmitFunctionEnd()
272 def GenerateInternalFunctions(emitter): argument
278 GenerateGemv(emitter, output_type, aligned, cols, leftover)
279 emitter.EmitNewline()
282 def GeneratePublicFunctions(emitter): argument
284 GenerateMainGemvFunction(emitter, output_type)
285 emitter.EmitNewline()