Lines Matching full:rows

41 def GenerateCommonTempsCountersAndConsts(emitter, rows):  argument
56 if rows is not 0:
58 'std::int32_t*', 'zipped_lhs_%d_offsets' % rows,
59 'reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * %d)' % rows)
68 def GenerateQuantized8BitTempsCountersAndConsts(emitter, rows): argument
70 GenerateCommonTempsCountersAndConsts(emitter, rows)
85 def GenerateFullTempsCountersAndConsts(emitter, result_type, rows): argument
87 GenerateCommonTempsCountersAndConsts(emitter, rows)
97 def ZipName(rows, leftovers, aligned): argument
98 return zip_Nx8_neon.BuildName(rows, leftovers, aligned)
118 def MulName(result_type, lhs_add, rhs_add, rows, cols): argument
119 return mul_Nx8_Mx8_neon.BuildName(result_type, lhs_add, rhs_add, rows, cols)
131 rows, cols, leftovers): argument
134 ZipName(rows, leftovers, aligned),
142 MulName(result_type, lhs_add, rhs_add, rows, 3),
151 MulName(result_type, lhs_add, rhs_add, rows, cols),
155 def GenerateQuantized8BitMul(emitter, aligned, rows, cols, leftovers): argument
156 """Emits code for all lhs strips & leftover rows. Quantize after mul code."""
170 if rows is not 0:
171 GenerateMulRows(emitter, 'temp_result', 'int32', False, True, aligned, rows,
174 BuildMultiQuantizeName(aligned, rows),
176 'zipped_lhs_%d_offsets' % rows, 'result_chunk', 'result_stride',
180 def GenerateFullMul(emitter, result_type, aligned, rows, cols, leftovers): argument
189 if rows is not 0:
191 rows, cols, leftovers)
194 def BuildName(output_type, aligned, rows, cols, leftover): argument
195 name = BuildMainGemmName(output_type) + '_%d_%d_%d' % (rows, cols, leftover)
230 def GenerateGemm(emitter, output_type, aligned, rows, cols, leftovers): argument
233 BuildName(output_type, aligned, rows, cols, leftovers),
236 emitter.EmitAssert('m %% 3 == %d' % rows)
241 GenerateQuantized8BitTempsCountersAndConsts(emitter, rows)
243 GenerateQuantized8BitMul(emitter, aligned, rows, cols, leftovers)
245 GenerateFullTempsCountersAndConsts(emitter, 'std::int32_t*', rows)
247 GenerateFullMul(emitter, 'int32', aligned, rows, cols, leftovers)
249 GenerateFullTempsCountersAndConsts(emitter, 'float*', rows)
251 GenerateFullMul(emitter, 'float', aligned, rows, cols, leftovers)
258 def BuildMultiQuantizeName(aligned, rows): argument
259 name = 'multi_qnt_%dx8' % rows
265 def GenerateMultiQuantize(emitter, aligned, rows): argument
267 name = BuildMultiQuantizeName(aligned, rows)
281 qnt_Nx8_neon.BuildName(rows, leftovers, aligned),
328 """First level of main switch, choose optimized version on rows leftover."""
418 for rows in range(1, 4):
419 GenerateMultiQuantize(emitter, aligned, rows)
424 for rows in range(0, 3):
427 GenerateGemm(emitter, output_type, aligned, rows, cols, leftover)