1"""Multiply primitive optimized for the gemv operation."""
2
3import neon_emitter
4
5
class Error(Exception):
  """Base class for the exceptions defined in this module."""
8
9
class ConfigurationError(Error):
  """Raised when a generator is asked for a configuration it cannot emit."""
12
13
def GenerateLoadMultiplyAggregate(emitter, registers, lanes_count, aggregators,
                                  count, lhs, rhs_1, rhs_2):
  """Emit the inner multiply-accumulate loop of the 1 x M primitive.

  The emitted loop starts at numerical label 1, subtracts 8 from `count` and
  branches back to the label at the end (EmitBneBack). One iteration:

    * loads one double register through `lhs` and four through `rhs_1`,
    * multiplies each rhs register with the lhs register ('u8' EmitVMull) and
      accumulates the products into aggregators[0..3] ('u16' EmitVPadal),
    * reloads the first `lanes_count` rhs registers through `rhs_2` and
      accumulates those products into aggregators[4..4 + lanes_count - 1].

  All temporary registers allocated here are released before returning.

  Args:
    emitter: object receiving the Emit* calls.
    registers: register allocator (DoubleRegister/QuadRegister/Free*).
    lanes_count: number of registers loaded through rhs_2 per iteration.
    aggregators: accumulator registers; at least 4 + lanes_count entries are
      indexed.
    count: register holding the loop counter.
    lhs: register used as the lhs load address.
    rhs_1: register used as the load address for the first four rhs registers.
    rhs_2: register used as the load address for the remaining rhs registers.
  """
  emitter.EmitComment('General 1xM lanes loop.')
  emitter.EmitNumericalLabel(1)
  emitter.EmitNewline()
  emitter.EmitComment('Subtract counter.')
  step = emitter.ImmediateConstant(8)
  emitter.EmitSubs(count, count, step)
  emitter.EmitNewline()

  # Allocation order matters for register assignment: rhs inputs first, then
  # the single lhs input.
  rhs_inputs = []
  for _ in range(4):
    rhs_inputs.append(registers.DoubleRegister())
  lhs_input = registers.DoubleRegister()

  lhs_source = emitter.DereferenceIncrement(lhs, 64)
  emitter.EmitVLoad('1.8', lhs_input, lhs_source)
  rhs_1_source = emitter.DereferenceIncrement(rhs_1, 64)
  emitter.EmitVLoadA('1.8', rhs_inputs, rhs_1_source)

  emitter.EmitPldOffset(lhs, emitter.ImmediateConstant(64))
  emitter.EmitPldOffset(rhs_1, emitter.ImmediateConstant(128))

  products = []
  for _ in range(4):
    products.append(registers.QuadRegister())

  # First group: the four registers loaded through rhs_1.
  for product, rhs_input in zip(products, rhs_inputs):
    emitter.EmitVMull('u8', product, rhs_input, lhs_input)

  # The rhs_2 load reuses the leading rhs input registers.
  second_group = rhs_inputs[:lanes_count]
  rhs_2_source = emitter.DereferenceIncrement(rhs_2, 64)
  emitter.EmitVLoadA('1.8', second_group, rhs_2_source)
  emitter.EmitPldOffset(rhs_2, emitter.ImmediateConstant(lanes_count * 32))

  for index, product in enumerate(products):
    emitter.EmitVPadal('u16', aggregators[index], product)

  # Second group: product registers are reused once their values have been
  # accumulated above.
  for lane in range(lanes_count):
    emitter.EmitVMull('u8', products[lane], rhs_inputs[lane], lhs_input)

  for lane in range(lanes_count):
    emitter.EmitVPadal('u16', aggregators[4 + lane], products[lane])

  emitter.EmitNewline()
  emitter.EmitComment('Loop break.')
  emitter.EmitBneBack(1)
  emitter.EmitNewline()

  registers.FreeRegister(lhs_input)
  registers.FreeRegisters(rhs_inputs)
  registers.FreeRegisters(products)
59
60
def ReadLeft(emitter, registers, lhs):
  """Emit a load of the lhs offset into a newly allocated quad register.

  Both halves of the register are wrapped with emitter.AllLanes before being
  passed to the '1.32' load (presumably a load-to-all-lanes form, so every
  32-bit lane receives the same value -- confirm against neon_emitter).

  Args:
    emitter: object receiving the Emit* calls.
    registers: register allocator.
    lhs: register used as the load address (dereferenced without alignment).

  Returns:
    The quad register that receives the loaded value.
  """
  offsets = registers.QuadRegister()
  low_half = emitter.AllLanes(registers.Low(offsets))
  high_half = emitter.AllLanes(registers.High(offsets))
  source = emitter.Dereference(lhs, None)
  emitter.EmitVLoadA('1.32', [low_half, high_half], source)
  return offsets
67
68
def ReadRight(emitter, registers, rhs, count):
  """Emit a '1.32' load of rhs offsets into a register sized for `count`.

  Args:
    emitter: object receiving the Emit* calls.
    registers: register allocator.
    rhs: register used as the load address (dereferenced with the value 64,
      presumably an alignment hint -- confirm against neon_emitter).
    count: element count; 1 or 2 allocates a double register, 3 or 4 a quad
      register.

  Returns:
    The register that receives the loaded offsets.

  Raises:
    ConfigurationError: if count is not 1, 2, 3 or 4.
  """
  if count in (1, 2):
    allocate = registers.DoubleRegister
  elif count in (3, 4):
    allocate = registers.QuadRegister
  else:
    raise ConfigurationError('Unsupported elements no: %d' % count)
  offsets = allocate()
  emitter.EmitVLoad('1.32', offsets, emitter.Dereference(rhs, 64))
  return offsets
78
79
def DuplicateGeneralRegister(emitter, registers, general_register,
                             min_register):
  """Emit a '32' EmitVDup of a general register into a new quad register.

  Args:
    emitter: object receiving the Emit* calls.
    registers: register allocator.
    general_register: source register handed to EmitVDup.
    min_register: value forwarded to registers.QuadRegister (presumably the
      lowest register index the allocator may pick -- confirm).

  Returns:
    The quad register that is the destination of the duplicate.
  """
  target = registers.QuadRegister(min_register)
  emitter.EmitVDup('32', target, general_register)
  return target
85
86
def GenerateAggregatorReduceStore(emitter, registers, lanes_count, aggregators,
                                  result_type, lhs_add, rhs_add, lhs, rhs_1,
                                  rhs_2, results):
  """Generates assembly responsible for reducing the 4 way aggregators.

  Emits, in order: the optional offset loads, a horizontal reduction of every
  aggregator, the optional addition of lhs / rhs offsets, the optional
  conversion to float with scaling, and the store of 4 + lanes_count results.

  The reduced values are gathered in place: aggregators[0] receives the first
  four results and aggregators[1] (or its low half for 1-2 lanes) the
  remaining ones.

  Args:
    emitter: object receiving the Emit* calls.
    registers: register allocator (Low/High/MapParameter/...).
    lanes_count: number of results beyond the first four; 1 to 4 have
      dedicated handling.
    aggregators: list of accumulator registers; 4 + lanes_count entries are
      read.
    result_type: 'float' enables the convert-and-scale step (compared by
      value); any other value skips it.
    lhs_add: when true, the lhs offset is loaded through `lhs` and added.
    rhs_add: when true, offsets are loaded through `rhs_1` / `rhs_2` and added.
    lhs: register used as the address of the lhs offset.
    rhs_1: register used as the address of the first four rhs offsets.
    rhs_2: register used as the address of the remaining rhs offsets.
    results: register used as the store address.
  """
  # Compare by value: `is` on a string literal only works when both strings
  # happen to be interned, and is a SyntaxWarning on Python 3.8+.
  is_float = result_type == 'float'

  if lhs_add:
    left_offset = ReadLeft(emitter, registers, lhs)
  else:
    left_offset = None

  if rhs_add:
    right_offset_1 = ReadRight(emitter, registers, rhs_1, 4)
    right_offset_2 = ReadRight(emitter, registers, rhs_2, lanes_count)
  else:
    right_offset_1 = None
    right_offset_2 = None

  if is_float:
    result_scale = DuplicateGeneralRegister(
        emitter, registers, registers.MapParameter('result_scale'), 4)
  else:
    result_scale = None

  emitter.EmitNewline()
  emitter.EmitComment('Horizontal reduce aggregators.')
  # First pass: fold the high half of each aggregator into its low half.
  for aggregator in aggregators:
    emitter.EmitVPadd('u32', registers.Low(aggregator),
                      registers.Low(aggregator), registers.High(aggregator))

  # Second pass: pack the first four results into aggregators[0].
  temp = aggregators[0]
  emitter.EmitVPadd('u32', registers.Low(temp), registers.Low(aggregators[0]),
                    registers.Low(aggregators[1]))
  emitter.EmitVPadd('u32', registers.High(temp), registers.Low(aggregators[2]),
                    registers.Low(aggregators[3]))

  # Pack the remaining results into aggregators[1], which has already been
  # consumed above. An odd lane count pairs the last aggregator with itself.
  if lanes_count == 1:
    temp_2 = registers.Low(aggregators[1])
    emitter.EmitVPadd('u32', temp_2, registers.Low(aggregators[4]),
                      registers.Low(aggregators[4]))
  elif lanes_count == 2:
    temp_2 = registers.Low(aggregators[1])
    emitter.EmitVPadd('u32', temp_2, registers.Low(aggregators[4]),
                      registers.Low(aggregators[5]))
  elif lanes_count == 3:
    temp_2 = aggregators[1]
    emitter.EmitVPadd('u32', registers.Low(temp_2),
                      registers.Low(aggregators[4]),
                      registers.Low(aggregators[5]))
    emitter.EmitVPadd('u32', registers.High(temp_2),
                      registers.Low(aggregators[6]),
                      registers.Low(aggregators[6]))
  elif lanes_count == 4:
    temp_2 = aggregators[1]
    emitter.EmitVPadd('u32', registers.Low(temp_2),
                      registers.Low(aggregators[4]),
                      registers.Low(aggregators[5]))
    emitter.EmitVPadd('u32', registers.High(temp_2),
                      registers.Low(aggregators[6]),
                      registers.Low(aggregators[7]))
  else:
    # Other lane counts are not handled here; GenerateMul1x8Mx8 rejects them
    # before this function is reached.
    temp_2 = None

  if lhs_add:
    emitter.EmitNewline()
    emitter.EmitComment('Add lhs offsets to aggregated rows.')
    emitter.EmitVAdd('s32', temp, temp, left_offset)
    # temp_2 is a double register for 1-2 lanes, so only the low half of the
    # quad offset register is used there.
    if lanes_count == 1 or lanes_count == 2:
      emitter.EmitVAdd('s32', temp_2, temp_2, registers.Low(left_offset))
    elif lanes_count == 3 or lanes_count == 4:
      emitter.EmitVAdd('s32', temp_2, temp_2, left_offset)

  if rhs_add:
    emitter.EmitNewline()
    emitter.EmitComment('Add rhs offset to aggregated rows.')
    emitter.EmitVAdd('s32', temp, temp, right_offset_1)
    emitter.EmitVAdd('s32', temp_2, temp_2, right_offset_2)

  if is_float:
    emitter.EmitNewline()
    emitter.EmitComment('Convert to float and scale.')
    emitter.EmitVCvt('f32', 's32', temp, temp)
    emitter.EmitVCvt('f32', 's32', temp_2, temp_2)
    emitter.EmitVMul('f32', temp, temp, result_scale)
    if lanes_count == 1 or lanes_count == 2:
      emitter.EmitVMul('f32', temp_2, temp_2, registers.Low(result_scale))
    elif lanes_count == 3 or lanes_count == 4:
      emitter.EmitVMul('f32', temp_2, temp_2, result_scale)

  emitter.EmitNewline()
  emitter.EmitComment('Store results.')
  # Odd lane counts store the full registers first (advancing the address),
  # then a single lane for the last result.
  if lanes_count == 1:
    emitter.EmitVStoreA('1.32', [registers.Low(temp), registers.High(temp)],
                        emitter.DereferenceIncrement(results, None))
    emitter.EmitVStore('1.32', emitter.Lane(temp_2, 0),
                       emitter.Dereference(results, None))
  elif lanes_count == 2:
    emitter.EmitVStoreA('1.32', [registers.Low(temp), registers.High(temp),
                                 temp_2], emitter.Dereference(results, None))
  elif lanes_count == 3:
    emitter.EmitVStoreA(
        '1.32',
        [registers.Low(temp), registers.High(temp), registers.Low(temp_2)],
        emitter.DereferenceIncrement(results, None))
    emitter.EmitVStore('1.32', emitter.Lane(
        registers.High(temp_2), 0), emitter.Dereference(results, None))
  elif lanes_count == 4:
    emitter.EmitVStoreA('1.32', [registers.Low(temp), registers.High(temp),
                                 registers.Low(temp_2), registers.High(temp_2)],
                        emitter.Dereference(results, None))
195
196
def BuildName(result_type, lhs_add, rhs_add, lanes):
  """Compose the name of the generated function.

  Args:
    result_type: result type label embedded in the name.
    lhs_add: when true, '_lhsadd' is appended.
    rhs_add: when true, '_rhsadd' is appended (after '_lhsadd' if both).
    lanes: lane count formatted into the 'mul_1x8_<lanes>x8_' prefix.

  Returns:
    The function name string.
  """
  pieces = ['mul_1x8_%dx8_%s' % (lanes, result_type)]
  if lhs_add:
    pieces.append('_lhsadd')
  if rhs_add:
    pieces.append('_rhsadd')
  return ''.join(pieces)
204
205
def CppResultType(result_type):
  """Map a result type label to the C++ pointer type of the result buffer.

  Args:
    result_type: 'int32' or 'float'. Compared by value, so strings built at
      runtime (parsed flags, concatenations) are accepted too.

  Returns:
    'std::int32_t*' for 'int32', 'float*' for 'float'.

  Raises:
    ConfigurationError: for any other result type.
  """
  # `==` rather than `is`: identity of equal strings depends on interning.
  if result_type == 'int32':
    return 'std::int32_t*'
  elif result_type == 'float':
    return 'float*'
  else:
    raise ConfigurationError('Unsupported result type: %s' % result_type)
213
214
def GetParameters(result_type):
  """Build the parameter list of the generated function.

  Args:
    result_type: result type label; forwarded to CppResultType for the type
      of the 'result' parameter. 'float' (compared by value) adds a trailing
      'result_scale' parameter.

  Returns:
    A list of [c++ type, name] pairs in declaration order.

  Raises:
    ConfigurationError: propagated from CppResultType for unsupported types.
  """
  params = [['const std::uint8_t*', 'lhs'], ['const std::uint8_t*', 'rhs_1'],
            ['const std::uint8_t*', 'rhs_2'], ['std::int32_t', 'count'],
            [CppResultType(result_type), 'result']]
  # `==` rather than `is`: identity of equal strings depends on interning.
  if result_type == 'float':
    params.append(['float', 'result_scale'])
  return params
222
223
def GenerateAndClearAggregators(emitter, registers, aggregator_count):
  """Allocate the aggregator registers and emit the code that clears them.

  The first three registers are cleared with an immediate zero; every later
  one is emitted as a move from the register allocated three positions
  earlier.

  Args:
    emitter: object receiving the Emit* calls.
    registers: register allocator.
    aggregator_count: number of quad registers to allocate.

  Returns:
    The list of allocated quad registers, in allocation order.
  """
  emitter.EmitNewline()
  emitter.EmitComment('Clear aggregators.')
  cleared = []
  for index in range(aggregator_count):
    current = registers.QuadRegister()
    cleared.append(current)
    if index >= 3:
      source = cleared[index - 3]
    else:
      source = emitter.ImmediateConstant(0)
    emitter.EmitVMov('i32', current, source)
  emitter.EmitNewline()
  return cleared
238
239
def GenerateMul1x8Mx8(emitter, result_type, lhs_add, rhs_add, lanes_count):
  """Emit one complete multiplication function for 4 + lanes_count lanes.

  The function is named by BuildName, asserts that `count` is a non-zero
  multiple of 8, and wraps an asm block made of: prefetches, aggregator
  clearing, the multiply-accumulate loop and the reduce-and-store tail.

  Args:
    emitter: object receiving the Emit* calls.
    result_type: result type label ('int32' or 'float').
    lhs_add: whether lhs offsets are added to the results.
    rhs_add: whether rhs offsets are added to the results.
    lanes_count: number of lanes read through rhs_2; must be 1, 2, 3 or 4.

  Raises:
    ConfigurationError: if lanes_count is outside 1..4.
  """
  if lanes_count < 1 or lanes_count > 4:
    raise ConfigurationError('Lanes should be: 1, 2, 3 or 4.')

  # Four lanes always come from rhs_1; lanes_count more come from rhs_2.
  total_lanes = lanes_count + 4

  function_name = BuildName(result_type, lhs_add, rhs_add, total_lanes)
  parameters = GetParameters(result_type)
  emitter.EmitFunctionBeginA(function_name, parameters, 'inline void')

  for condition in ('count % 8 == 0', 'count >= 8'):
    emitter.EmitAssert(condition)
  emitter.EmitAsmBegin()

  registers = neon_emitter.NeonRegisters()

  # Parameters are mapped in a fixed order: count, lhs, rhs_1, rhs_2.
  count = registers.MapParameter('count')
  lhs = registers.MapParameter('lhs')
  rhs_1 = registers.MapParameter('rhs_1')
  rhs_2 = registers.MapParameter('rhs_2')

  for pointer in (lhs, rhs_1, rhs_2):
    emitter.EmitPld(pointer)

  aggregators = GenerateAndClearAggregators(emitter, registers, total_lanes)

  GenerateLoadMultiplyAggregate(emitter, registers, lanes_count, aggregators,
                                count, lhs, rhs_1, rhs_2)

  # 'result' is mapped only after the loop has been generated.
  result = registers.MapParameter('result')
  GenerateAggregatorReduceStore(emitter, registers, lanes_count, aggregators,
                                result_type, lhs_add, rhs_add, lhs, rhs_1,
                                rhs_2, result)

  mapped = registers.MappedParameters()
  clobbers = registers.Clobbers() + ['cc', 'memory']
  emitter.EmitAsmEnd(mapped, [], clobbers)
  emitter.EmitFunctionEnd()
276
277
def GenerateFunctions(emitter, result_type, lhs_add, rhs_add):
  """Emit the primitive for every supported lane count (1 through 4).

  Each generated function is followed by a newline.

  Args:
    emitter: object receiving the Emit* calls.
    result_type: result type label forwarded to GenerateMul1x8Mx8.
    lhs_add: forwarded to GenerateMul1x8Mx8.
    rhs_add: forwarded to GenerateMul1x8Mx8.
  """
  for lanes_count in (1, 2, 3, 4):
    GenerateMul1x8Mx8(emitter, result_type, lhs_add, rhs_add, lanes_count)
    emitter.EmitNewline()
282
283
if __name__ == '__main__':
  # Script entry point: emit the int32 variants with both offsets added.
  GenerateFunctions(neon_emitter.NeonEmitter(), result_type='int32',
                    lhs_add=True, rhs_add=True)
286