Lines Matching full:rows

20 def GenerateCommonTempsCountersAndConsts(emitter, rows):  argument
35 if rows is not 0:
37 'std::int32_t*', 'zipped_lhs_%d_offsets' % rows,
38 'reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * %d)' % rows)
47 def GenerateQuantized8BitTempsCountersAndConsts(emitter, rows): argument
49 GenerateCommonTempsCountersAndConsts(emitter, rows)
64 def GenerateFullTempsCountersAndConsts(emitter, result_type, rows): argument
66 GenerateCommonTempsCountersAndConsts(emitter, rows)
76 def ZipName(rows, leftovers, aligned): argument
77 return zip_Nx8_neon.BuildName(rows, leftovers, aligned)
97 def MulName(result_type, lhs_add, rhs_add, rows, cols): argument
98 return mul_Nx8_Mx8_neon.BuildName(result_type, lhs_add, rhs_add, rows, cols)
110 rows, cols, leftovers): argument
113 ZipName(rows, leftovers, aligned),
121 MulName(result_type, lhs_add, rhs_add, rows, 3),
130 MulName(result_type, lhs_add, rhs_add, rows, cols),
134 def GenerateQuantized8BitMul(emitter, aligned, rows, cols, leftovers): argument
135 """Emits code for all lhs strips & leftover rows. Quantize after mul code."""
149 if rows is not 0:
150 GenerateMulRows(emitter, 'temp_result', 'int32', False, True, aligned, rows,
153 qnt_Nx8_neon.BuildMultiQuantizeName(aligned, rows),
155 'zipped_lhs_%d_offsets' % rows, 'result_chunk', 'result_stride',
159 def GenerateFullMul(emitter, result_type, aligned, rows, cols, leftovers): argument
168 if rows is not 0:
170 rows, cols, leftovers)
173 def BuildName(output_type, aligned, rows, cols, leftover): argument
174 name = BuildMainGemmName(output_type) + '_%d_%d_%d' % (rows, cols, leftover)
209 def GenerateGemm(emitter, output_type, aligned, rows, cols, leftovers): argument
212 BuildName(output_type, aligned, rows, cols, leftovers),
215 emitter.EmitAssert('m %% 3 == %d' % rows)
220 GenerateQuantized8BitTempsCountersAndConsts(emitter, rows)
222 GenerateQuantized8BitMul(emitter, aligned, rows, cols, leftovers)
224 GenerateFullTempsCountersAndConsts(emitter, 'std::int32_t*', rows)
226 GenerateFullMul(emitter, 'int32', aligned, rows, cols, leftovers)
228 GenerateFullTempsCountersAndConsts(emitter, 'float*', rows)
230 GenerateFullMul(emitter, 'float', aligned, rows, cols, leftovers)
273 """First level of main switch, choose optimized version on rows leftover."""
346 for rows in range(0, 3):
349 GenerateGemm(emitter, output_type, aligned, rows, cols, leftover)