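"""Code generators for the qnt (quantization) primitives used by GEMM.

Emits NEON kernels that quantize rows of int32 GEMM results to uint8, plus the
C++ dispatch functions that pick a kernel based on count % 8.
"""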
4
5import neon_emitter
6
7
class Error(Exception):
  """Root of this module's exception hierarchy."""
  pass
10
11
class ConfigurationError(Error):
  """Raised when asked to generate code for an unsupported lane/leftover setup."""
  pass
14
15
class QntLane(object):
  """Bundle of register operands used to process one row ("lane").

  Attributes:
    source: operand holding this row's int32 input pointer.
    output: operand holding this row's uint8 output pointer.
    offset: quad register holding this row's offset in every lane.
    load_1: quad register that receives the first four int32 inputs of a
      chunk.
    load_2: quad register that receives the last four int32 inputs of a
      chunk.
  """

  def __init__(self, source, output, offset, load_1, load_2):
    (self.source, self.output, self.offset,
     self.load_1, self.load_2) = (source, output, offset, load_1, load_2)
24
25
def BuildName(lanes, leftovers, aligned):
  """Return the kernel name, e.g. 'qnt_2x8', 'qnt_2x8_3' or 'qnt_2x8_3_aligned'.

  The leftover suffix is omitted when `leftovers` is falsy.
  """
  parts = ['qnt_%dx8' % lanes]
  if leftovers:
    parts.append('%d' % leftovers)
  if aligned:
    parts.append('aligned')
  return '_'.join(parts)
33
34
def LoadAndDuplicateOffsets(emitter, registers, lanes, offsets):
  """Load one 32-bit offset per row and replicate it across a quad register.

  Args:
    emitter: code emitter receiving the generated loads.
    registers: register allocator that supplies the quad registers.
    lanes: number of rows; only 1, 2 or 3 is supported.
    offsets: operand holding the offsets address; each load uses
      DereferenceIncrement on it, so it advances past every value read.

  Returns:
    A list with one quad register per row, in load order.

  Raises:
    ConfigurationError: if `lanes` is not 1, 2 or 3.
  """
  if lanes not in (1, 2, 3):
    raise ConfigurationError('Unsupported number of lanes: %d' % lanes)

  duplicated = []
  for _ in range(0, lanes):
    quad = registers.QuadRegister()
    # One 32-bit value loaded into all lanes of both D halves of the quad.
    emitter.EmitVLoadA('1.32', [emitter.AllLanes(registers.Low(quad)),
                                emitter.AllLanes(registers.High(quad))],
                       emitter.DereferenceIncrement(offsets, 32))
    duplicated.append(quad)
  return duplicated
47
48
def GenerateQntLanes(emitter, registers, qnt_lanes, source, stride, destination,
                     destination_stride, offsets):
  """Prepare lanes for reading unquantized multiplication results.

  Row 0 uses the `source`/`destination` operands directly. Each later row
  gets fresh general registers, which are set to the previous row's pointers
  plus `stride` / `destination_stride`.

  Returns:
    A list of `qnt_lanes` QntLane objects, each with its own offset register
    and two freshly allocated load quads.
  """
  offset_registers = LoadAndDuplicateOffsets(emitter, registers, qnt_lanes,
                                             offsets)

  result = []
  previous_input, previous_output = source, destination
  for index in range(qnt_lanes):
    if index == 0:
      lane_input, lane_output = source, destination
    else:
      lane_input = registers.GeneralRegister()
      lane_output = registers.GeneralRegister()
    result.append(QntLane(lane_input,
                          lane_output,
                          offset_registers[index],
                          registers.QuadRegister(),   # load_1
                          registers.QuadRegister()))  # load_2
    if index != 0:
      # Derive this row's pointers from the previous row's.
      emitter.EmitAdd(lane_input, previous_input, stride)
      emitter.EmitAdd(lane_output, previous_output, destination_stride)
      previous_input, previous_output = lane_input, lane_output
  return result
78
79
def DuplicateRegister(emitter, registers, value):
  """Allocate a quad register and VDUP the 32-bit `value` into it.

  Returns:
    The newly allocated quad register.
  """
  quad = registers.QuadRegister()
  emitter.EmitVDup('32', quad, value)
  return quad
84
85
def GenerateQuantize(emitter, registers, lanes, lane_temps,
                     multiplicative_offset, rounding_offset, shift):
  """Inner loop for quantization: add offsets, multiply, round, shift.

  Each step is emitted for every entry of `lanes` before the next step
  starts. The steps are: add the row offset, multiply by
  `multiplicative_offset`, add `rounding_offset`, VSHL by `shift`, then
  saturating-narrow to 16 bits. Finally every quad in `lane_temps` is
  saturating-narrowed to unsigned 8 bits in its low half.

  Args:
    lanes: sequence of [int32 quad, offset quad, destination D half] triples.
    lane_temps: quads whose halves are the destination D halves above.
    shift: quad used as the VSHL shift operand. NOTE(review): presumably
      holds a negative amount to get a right shift; confirm with callers.
  """
  steps = (
      lambda lane: emitter.EmitVAdd('i32', lane[0], lane[0], lane[1]),
      lambda lane: emitter.EmitVMul('i32', lane[0], lane[0],
                                    multiplicative_offset),
      lambda lane: emitter.EmitVAdd('i32', lane[0], lane[0], rounding_offset),
      lambda lane: emitter.EmitVShl('s32', lane[0], lane[0], shift),
      lambda lane: emitter.EmitVQmovn('s32', lane[2], lane[0]),
  )
  for step in steps:
    for lane in lanes:
      step(lane)

  for temp in lane_temps:
    emitter.EmitVQmovun('s16', registers.Low(temp), temp)
106
107
def GenerateLoadQuantizeStore(emitter, registers, lanes, multiplicative_offset,
                              rounding_offset, shift, alignment):
  """Load unquantized data from lanes, quantize, store final result.

  For every lane this reads 8 int32 values (into load_1:load_2, advancing the
  source pointer) and issues a PLD on the advanced source. It then quantizes
  the values and stores 8 bytes from the low half of a scratch quad,
  advancing the output pointer.

  Args:
    alignment: alignment argument for the output stores (the caller passes
      64 or None).
  """
  # One scratch quad per lane collects the narrowed results.
  lane_temps = [registers.QuadRegister() for _ in lanes]

  for lane in lanes:
    halves = [registers.Low(lane.load_1), registers.High(lane.load_1),
              registers.Low(lane.load_2), registers.High(lane.load_2)]
    emitter.EmitVLoadA('1.32', halves,
                       emitter.DereferenceIncrement(lane.source, 64))

  for lane in lanes:
    emitter.EmitPld(lane.source)

  quantize_setup = []
  for temp, lane in zip(lane_temps, lanes):
    quantize_setup.extend((
        [lane.load_1, lane.offset, registers.Low(temp)],
        [lane.load_2, lane.offset, registers.High(temp)],
    ))

  GenerateQuantize(emitter, registers, quantize_setup, lane_temps,
                   multiplicative_offset, rounding_offset, shift)

  for temp, lane in zip(lane_temps, lanes):
    emitter.EmitVStore('1.8', registers.Low(temp),
                       emitter.DereferenceIncrement(lane.output, alignment))

  for temp in lane_temps:
    registers.FreeRegister(temp)
138
139
def GenerateLoadLeftovers(emitter, registers, leftovers, lanes):
  """Handle non multiple of 8 leftover loading.

  For every lane, emits loads of the final `leftovers` int32 values into the
  D halves of load_1:load_2. The halves are filled in the order Low(load_1),
  High(load_1), Low(load_2), High(load_2).

  Whole pairs of values are read with one (multi-)register VLD1 using
  alignment argument 64. An odd trailing value is read into lane 0 of the
  next D half. The source pointer is post-incremented only when that
  trailing single-value load still follows.

  Args:
    emitter: code emitter receiving the generated instructions.
    registers: register helper used to name the halves of quad registers.
    leftovers: number of int32 values remaining per lane, 1 to 7.
    lanes: QntLane objects whose source/load_1/load_2 are used.

  Raises:
    ConfigurationError: if leftovers is not between 1 and 7.
  """
  if leftovers < 1 or leftovers > 7:
    raise ConfigurationError('Unsupported leftover count: %d' % leftovers)

  def _Half(lane, index):
    """Name D half `index` (0-3) of the lane's load_1:load_2 register pair."""
    quad = lane.load_1 if index < 2 else lane.load_2
    return registers.Low(quad) if index % 2 == 0 else registers.High(quad)

  # Each D half holds two int32 values.
  pairs, odd = divmod(leftovers, 2)

  if pairs:
    for lane in lanes:
      halves = [_Half(lane, index) for index in range(pairs)]
      # Advance the pointer only if the odd trailing value is loaded next.
      dereference = emitter.DereferenceIncrement if odd else emitter.Dereference
      if pairs == 1:
        emitter.EmitVLoad('1.32', halves[0], dereference(lane.source, 64))
      else:
        emitter.EmitVLoadA('1.32', halves, dereference(lane.source, 64))

  if odd:
    for lane in lanes:
      emitter.EmitVLoad('1.32', emitter.Lane(_Half(lane, pairs), 0),
                        emitter.Dereference(lane.source, None))
191
192
193def GenerateStoreLeftovers(emitter, registers, leftovers, lane_temps, lanes):
194  """Handle non multiply of 8 leftover storing."""
195  setup = []
196  for (temp, lane) in zip(lane_temps, lanes):
197    setup.append([registers.Low(temp), lane.output])
198
199  if leftovers == 1:
200    for lane in setup:
201      emitter.EmitVStore('1.8', emitter.Lane(lane[0], 0),
202                         emitter.Dereference(lane[1], None))
203  elif leftovers == 2:
204    for lane in setup:
205      emitter.EmitVStore('1.16', emitter.Lane(lane[0], 0),
206                         emitter.Dereference(lane[1], None))
207  elif leftovers == 3:
208    for lane in setup:
209      emitter.EmitVStore('1.16', emitter.Lane(lane[0], 0),
210                         emitter.DereferenceIncrement(lane[1], None))
211    for lane in setup:
212      emitter.EmitVStore('1.8', emitter.Lane(lane[0], 2),
213                         emitter.Dereference(lane[1], None))
214  elif leftovers == 4:
215    for lane in setup:
216      emitter.EmitVStore('1.32', emitter.Lane(lane[0], 0),
217                         emitter.Dereference(lane[1], None))
218  elif leftovers == 5:
219    for lane in setup:
220      emitter.EmitVStore('1.32', emitter.Lane(lane[0], 0),
221                         emitter.DereferenceIncrement(lane[1], None))
222    for lane in setup:
223      emitter.EmitVStore('1.8', emitter.Lane(lane[0], 4),
224                         emitter.Dereference(lane[1], None))
225  elif leftovers == 6:
226    for lane in setup:
227      emitter.EmitVStore('1.32', emitter.Lane(lane[0], 0),
228                         emitter.DereferenceIncrement(lane[1], None))
229    for lane in setup:
230      emitter.EmitVStore('1.16', emitter.Lane(lane[0], 2),
231                         emitter.Dereference(lane[1], None))
232  elif leftovers == 7:
233    for lane in setup:
234      emitter.EmitVStore('1.32', emitter.Lane(lane[0], 0),
235                         emitter.DereferenceIncrement(lane[1], None))
236    for lane in setup:
237      emitter.EmitVStore('1.16', emitter.Lane(lane[0], 2),
238                         emitter.DereferenceIncrement(lane[1], None))
239    for lane in setup:
240      emitter.EmitVStore('1.8', emitter.Lane(lane[0], 6),
241                         emitter.DereferenceIncrement(lane[1], None))
242  else:
243    raise ConfigurationError('Unsupported leftovers count: %d' % leftovers)
244
245
def GenerateLeftoverLoadQuantizeStore(emitter, registers, leftovers, lanes,
                                      multiplicative_offset, rounding_offset,
                                      shift):
  """Handle leftovers if row size not a multiple of 8.

  Allocates one scratch quad per lane and loads the trailing `leftovers`
  int32 values. It then quantizes them and stores the resulting bytes.
  """
  lane_temps = [registers.QuadRegister() for _ in lanes]

  GenerateLoadLeftovers(emitter, registers, leftovers, lanes)

  # load_2 only receives data when more than four values remain.
  uses_load_2 = leftovers > 4
  quantize_setup = []
  for temp, lane in zip(lane_temps, lanes):
    quantize_setup.append([lane.load_1, lane.offset, registers.Low(temp)])
    if uses_load_2:
      quantize_setup.append([lane.load_2, lane.offset, registers.High(temp)])

  GenerateQuantize(emitter, registers, quantize_setup, lane_temps,
                   multiplicative_offset, rounding_offset, shift)

  GenerateStoreLeftovers(emitter, registers, leftovers, lane_temps, lanes)
267
268
def GenerateQntNx8(emitter, qnt_lanes, leftovers, aligned):
  """Emits optimized quantization code for given lanes and row size.

  The generated C++ function quantizes `qnt_lanes` rows of int32 values into
  uint8 rows. Input rows are `stride` bytes apart and output rows are
  `destination_stride` bytes apart. The inline assembly handles 8 values per
  row per loop iteration, then handles the final `leftovers` (count % 8)
  values in a separate tail.

  Args:
    emitter: code emitter that receives the generated function.
    qnt_lanes: number of rows processed together; 1, 2 or 3.
    leftovers: the count % 8 this specialization handles; 0 to 7.
    aligned: if true, the main-loop stores use alignment argument 64. The
      function then asserts that destination is a multiple of 8, and that
      destination_stride is too when qnt_lanes > 1.

  Raises:
    ConfigurationError: if leftovers or qnt_lanes is out of range.
  """
  if leftovers < 0 or leftovers > 7:
    raise ConfigurationError('Leftovers should be between 0 and 7 inclusive.')
  if qnt_lanes < 1 or qnt_lanes > 3:
    raise ConfigurationError('Qnt_lanes should be 1, 2 or 3.')

  name = BuildName(qnt_lanes, leftovers, aligned)

  emitter.EmitFunctionBeginA(
      name,
      [['const std::int32_t*', 'source'], ['std::int32_t', 'count'],
       ['std::int32_t', 'stride'], ['const std::int32_t*', 'offsets'],
       ['std::uint8_t*', 'destination'], ['std::int32_t', 'destination_stride'],
       ['std::int32_t', 'multiplicative_offset'],
       ['std::int32_t', 'rounding_offset'], ['std::int32_t', 'shift']], 'void')
  emitter.EmitAssert('count %% 8 == %d' % leftovers)
  emitter.EmitAssert('count >= 8')
  emitter.EmitAssert('reinterpret_cast<std::uintptr_t>(source) % 8 == 0')
  if qnt_lanes > 1:
    # Every row's input is loaded with alignment argument 64, and rows after
    # the first start at source + k * stride. The stride must therefore
    # preserve the 8-byte alignment asserted for source.
    emitter.EmitAssert('stride % 8 == 0')
  if aligned:
    emitter.EmitAssert('reinterpret_cast<std::uintptr_t>(destination) % 8 == 0')
    if qnt_lanes > 1:
      emitter.EmitAssert('destination_stride % 8 == 0')
  emitter.EmitAsmBegin()

  registers = neon_emitter.NeonRegisters()

  count = registers.MapParameter('count')

  # Broadcast the scalar quantization parameters into quad registers.
  multiplicative_offset = DuplicateRegister(
      emitter, registers, registers.MapParameter('multiplicative_offset'))
  rounding_offset = DuplicateRegister(emitter, registers,
                                      registers.MapParameter('rounding_offset'))
  shift = DuplicateRegister(emitter, registers, registers.MapParameter('shift'))

  lanes = GenerateQntLanes(
      emitter, registers, qnt_lanes, registers.MapParameter('source'),
      registers.MapParameter('stride'), registers.MapParameter('destination'),
      registers.MapParameter('destination_stride'),
      registers.MapParameter('offsets'))

  if leftovers:
    # Reserve the tail so the main loop below counts down to exactly zero.
    # Jump straight to the tail (label 2) if only leftovers remain.
    emitter.EmitSubs(count, count, emitter.ImmediateConstant(leftovers))
    emitter.EmitBeqFront(2)

  emitter.EmitNewline()
  emitter.EmitNumericalLabel(1)
  emitter.EmitSubs(count, count, emitter.ImmediateConstant(8))

  GenerateLoadQuantizeStore(emitter, registers, lanes, multiplicative_offset,
                            rounding_offset, shift, 64 if aligned else None)

  emitter.EmitNewline()
  # Branches on the flags set by the SUBS at label 1. NOTE(review): this
  # assumes nothing emitted in the loop body modifies the condition flags.
  emitter.EmitBneBack(1)

  if leftovers:
    emitter.EmitNumericalLabel(2)
    GenerateLeftoverLoadQuantizeStore(emitter, registers, leftovers, lanes,
                                      multiplicative_offset, rounding_offset,
                                      shift)

  emitter.EmitAsmEnd(registers.MappedParameters(), [],
                     registers.Clobbers() + ['cc', 'memory'])
  emitter.EmitFunctionEnd()
333
334
def BuildMultiQuantizeName(aligned, rows):
  """Return the dispatcher name for `rows` rows.

  Examples: 'multi_qnt_2x8', or 'multi_qnt_2x8_aligned' when `aligned` is
  truthy.
  """
  suffix = '_aligned' if aligned else ''
  return 'multi_qnt_%dx8%s' % (rows, suffix)
340
341
def GenerateMultiQuantize(emitter, aligned, rows):
  """Emit main quantization code that switches between optimized versions.

  The emitted C++ function switches on `count % 8`. Each case 0-7 forwards
  all arguments to the kernel named by BuildName(rows, case, aligned).
  """
  forwarded = ['source', 'count', 'stride', 'offsets', 'destination',
               'destination_stride', 'multiplicative_offset',
               'rounding_offset', 'shift']

  emitter.EmitFunctionBeginA(
      BuildMultiQuantizeName(aligned, rows),
      [['const std::int32_t*', 'source'], ['std::int32_t', 'count'],
       ['std::int32_t', 'stride'], ['const std::int32_t*', 'offsets'],
       ['std::uint8_t*', 'destination'], ['std::int32_t', 'destination_stride'],
       ['std::int32_t', 'multiplicative_offset'],
       ['std::int32_t', 'rounding_offset'], ['std::int32_t', 'shift']], 'void')
  emitter.EmitSwitch('count % 8')

  for case in range(8):
    emitter.EmitCase(case)
    emitter.PushIndent()
    emitter.EmitCall(BuildName(rows, case, aligned), list(forwarded))
    emitter.EmitBreak()
    emitter.PopIndent()

  emitter.EmitSwitchEnd()
  emitter.EmitFunctionEnd()
367
368
def GenerateFunctions(neon, cc):
  """Emit all qnt kernels into `neon` and all dispatchers into `cc`.

  Kernels cover every combination of alignment, 1-3 lanes and 0-7
  leftovers. Dispatchers cover every combination of alignment and 1-3 rows.
  Aligned variants are emitted first, and each function is followed by a
  newline.
  """
  alignments = (True, False)

  for is_aligned in alignments:
    for lane_count in range(1, 4):
      for leftover_count in range(8):
        GenerateQntNx8(neon, lane_count, leftover_count, is_aligned)
        neon.EmitNewline()

  for is_aligned in alignments:
    for row_count in range(1, 4):
      GenerateMultiQuantize(cc, is_aligned, row_count)
      cc.EmitNewline()
380