1"""Zip primitive used by the GEMM function.
2
Takes 1 to 4 rows of data and interleaves them in 8 byte chunks. Pads to a
multiple of 8 length with zeros. Calculates row sums, multiplies them by a
multiplicative offset, adds an additive offset and appends the results at the
end.
6"""
7
8import neon_emitter
9
10
class Error(Exception):
  """Base class for every error raised by this module."""
13
14
class ConfigurationError(Error):
  """Raised when a generator is asked for a configuration it cannot emit."""
17
18
class ZipLane(object):
  """Registers associated with one input row of the zip operation.

  Attributes:
    input_address: register holding the address the row is read from.
    load: register that receives each chunk loaded from the row.
    aggregator: register in which the loaded values are summed up.
  """

  def __init__(self, input_address, load, aggregator):
    """Store the three registers that make up the lane."""
    self.input_address = input_address
    self.load = load
    self.aggregator = aggregator
25
26
def GenerateZipLanes(emitter, registers, zip_lanes, input_address, stride):
  """Allocates registers for every lane and emits their address setup.

  The first lane reads straight from input_address. Every further lane gets a
  freshly allocated address register, and an add of the previous lane's
  address and stride into that register is emitted.

  Args:
    emitter: ARM/NEON emitter.
    registers: ARM/NEON registers state.
    zip_lanes: number of lanes to prepare.
    input_address: register that contains the input address for the first lane.
    stride: memory stride for lane inputs.

  Returns:
    List of ZipLane objects, one per lane, in lane order.
  """
  prepared = []
  previous_address = input_address
  for index in range(zip_lanes):
    is_first = index == 0
    # Allocation order (address, load, aggregator) is kept stable because it
    # decides which physical registers each lane ends up with.
    address = input_address if is_first else registers.GeneralRegister()
    load = registers.DoubleRegister()
    aggregator = registers.QuadRegister(2)
    prepared.append(ZipLane(address, load, aggregator))
    if not is_first:
      emitter.EmitAdd(address, previous_address, stride)
      previous_address = address
  return prepared
53
54
def BuildName(zip_lanes, leftovers, aligned):
  """Composes the name of a zip function variant.

  Args:
    zip_lanes: number of lanes, rendered as 'zip_<lanes>x8'.
    leftovers: leftover count; appended as '_<leftovers>' when non-zero.
    aligned: when true, '_aligned' is appended.

  Returns:
    The function name as a string.
  """
  parts = ['zip_%dx8' % zip_lanes]
  if leftovers:
    parts.append('_%d' % leftovers)
  if aligned:
    parts.append('_aligned')
  return ''.join(parts)
62
63
def GenerateClearAggregators(emitter, lanes):
  """Emits a move of the immediate constant 0 into every lane's aggregator."""
  for lane in lanes:
    zero = emitter.ImmediateConstant(0)
    emitter.EmitVMov('i16', lane.aggregator, zero)
67
68
def GenerateLoadAggregateStore(emitter, lanes, output_address, alignment):
  """Emits the inner loop body: load every lane, sum it up, store the chunks.

  Args:
    emitter: ARM/NEON emitter.
    lanes: ZipLane objects to read from.
    output_address: register holding the address the chunks are stored to.
    alignment: alignment argument passed along with each lane's input address
      (None when no alignment is specified).
  """
  emitter.EmitNewline()
  emitter.EmitComment('Load Aggregate Store.')

  # All loads are emitted first, then all accumulations.
  for lane in lanes:
    source = emitter.DereferenceIncrement(lane.input_address, alignment)
    emitter.EmitVLoad('1.8', lane.load, source)

  for lane in lanes:
    emitter.EmitVAddw('u8', lane.aggregator, lane.aggregator, lane.load)

  # A single store writes the load registers of all lanes, in lane order.
  loaded = [lane.load for lane in lanes]
  destination = emitter.DereferenceIncrement(output_address, 64)
  emitter.EmitVStoreA('1.8', loaded, destination)
86
87
def GenerateLeftoverLoadAggregateStore(emitter, leftovers, lanes,
                                       output_address):
  """Emits the load/aggregate/store code for the 1 to 7 trailing values.

  The load registers are cleared first, so whatever the partial loads do not
  overwrite is stored as zero. The leftover values are then fetched with up
  to three single lane loads (32, 16 and 8 bit wide), accumulated and stored.

  Args:
    emitter: ARM/NEON emitter.
    leftovers: number of trailing values to handle, 1 to 7.
    lanes: ZipLane objects to read from.
    output_address: register holding the address the chunk is stored to.

  Raises:
    ConfigurationError: leftovers is not in the 1 to 7 range. The comment and
      the register clears have already been emitted at that point.
  """
  emitter.EmitNewline()
  emitter.EmitComment('Leftover Load Aggregate Store.')

  # Clear load registers.
  for lane in lanes:
    emitter.EmitVMov('i8', lane.load, emitter.ImmediateConstant(0))

  # For each leftover count: the (load type, destination lane index) of every
  # partial load, in emission order.
  partial_loads = {
      1: (('1.8', 0),),
      2: (('1.16', 0),),
      3: (('1.16', 0), ('1.8', 2)),
      4: (('1.32', 0),),
      5: (('1.32', 0), ('1.8', 4)),
      6: (('1.32', 0), ('1.16', 2)),
      7: (('1.32', 0), ('1.16', 2), ('1.8', 6)),
  }.get(leftovers)
  if partial_loads is None:
    raise ConfigurationError('Unsupported leftover num: %d' % leftovers)

  final_step = len(partial_loads) - 1
  for step, (load_type, lane_index) in enumerate(partial_loads):
    # Only the last partial load uses the plain dereference; the earlier ones
    # use the incrementing form (presumably so the next load continues right
    # after the bytes just read).
    if step < final_step:
      address_of = emitter.DereferenceIncrement
    else:
      address_of = emitter.Dereference
    for lane in lanes:
      emitter.EmitVLoad(load_type, emitter.Lane(lane.load, lane_index),
                        address_of(lane.input_address, None))

  # Aggregate.
  for lane in lanes:
    emitter.EmitVAddw('u8', lane.aggregator, lane.aggregator, lane.load)

  # Store.
  emitter.EmitVStoreA('1.8', [lane.load for lane in lanes],
                      emitter.DereferenceIncrement(output_address, 64))
165
166
def GenerateAggregatorReduction(emitter, registers, lanes, output_address,
                                multiplicative_offset, additive_offset):
  """Emits the reduction of every lane aggregator and the store of the sums.

  Each aggregator is folded with pairwise adds into one value per lane. The
  per lane values are gathered in one quad register, multiplied by
  multiplicative_offset, increased by additive_offset and stored at
  output_address.

  Args:
    emitter: ARM/NEON emitter.
    registers: ARM/NEON registers state.
    lanes: 1 to 4 ZipLane objects whose aggregators are reduced.
    output_address: register holding the address the sums are stored to.
    multiplicative_offset: value every sum is multiplied by.
    additive_offset: value added to every sum after the multiplication.

  Raises:
    ConfigurationError: the number of lanes is not 1, 2, 3 or 4.
  """
  emitter.EmitNewline()
  emitter.EmitComment('Aggregator Reduction.')

  scale = registers.DoubleRegister()
  emitter.EmitVMov('32', emitter.Lane(scale, 0), multiplicative_offset)
  bias = registers.QuadRegister()
  emitter.EmitVDup('32', bias, additive_offset)

  # First folding step, done in place on each aggregator.
  for lane in lanes:
    emitter.EmitVPaddl('u16', lane.aggregator, lane.aggregator)

  # Second folding step: halves of each aggregator into a fresh register.
  partial_sums = []
  for lane in lanes:
    partial = registers.DoubleRegister()
    partial_sums.append(partial)
    emitter.EmitVPadd('u32', partial, registers.Low(lane.aggregator),
                      registers.High(lane.aggregator))

  sums = registers.QuadRegister()
  sums_low = registers.Low(sums)
  sums_high = registers.High(sums)

  # Final folding step, keyed by lane count: (destination, first, second)
  # where first/second index into partial_sums. An odd last lane is paired
  # with itself.
  lane_count = len(lanes)
  final_adds = {
      1: ((sums_low, 0, 0),),
      2: ((sums_low, 0, 1),),
      3: ((sums_low, 0, 1), (sums_high, 2, 2)),
      4: ((sums_low, 0, 1), (sums_high, 2, 3)),
  }
  if lane_count not in final_adds:
    raise ConfigurationError('Unexpected number of aggregators to reduce: %d' %
                             lane_count)
  for target, first, second in final_adds[lane_count]:
    emitter.EmitVPadd('u32', target, partial_sums[first],
                      partial_sums[second])

  emitter.EmitVMul('i32', sums, sums, emitter.Lane(scale, 0))
  emitter.EmitVAdd('i32', sums, sums, bias)

  # Store exactly lane_count 32 bit values.
  if lane_count == 1:
    emitter.EmitVStore('1.32', emitter.Lane(sums_low, 0),
                       emitter.Dereference(output_address, None))
  elif lane_count == 2:
    emitter.EmitVStore('1.32', sums_low,
                       emitter.Dereference(output_address, 64))
  elif lane_count == 3:
    emitter.EmitVStore('1.32', sums_low,
                       emitter.DereferenceIncrement(output_address, 64))
    emitter.EmitVStore('1.32', emitter.Lane(sums_high, 0),
                       emitter.Dereference(output_address, None))
  else:
    # lane_count == 4; other values were rejected above.
    emitter.EmitVStoreA('1.32', [sums_low, sums_high],
                        emitter.DereferenceIncrement(output_address, 64))
222
223
def GenerateZipNx8(emitter, zip_lanes, leftovers, aligned):
  """Emit the zip function for a given number of rows and row size leftovers.

  The emitted function takes (source, count, stride, destination,
  multiplicative_offset, additive_offset). Its body consists of a main loop
  that handles 8 values per lane per iteration, an optional block for the
  leftover values, and the aggregator reduction that stores the row sums.

  Args:
    emitter: ARM/NEON emitter.
    zip_lanes: number of rows zipped together, 1 to 4.
    leftovers: value of count % 8 the emitted function handles, 0 to 7.
    aligned: when true, the emitted function asserts that source (and stride,
      when there is more than one lane) is a multiple of 8, and the main loop
      loads are given an alignment argument of 64.

  Raises:
    ConfigurationError: leftovers or zip_lanes is outside the supported range.
  """
  if leftovers < 0 or leftovers > 7:
    raise ConfigurationError('Leftovers should be between 0 and 7 inclusive.')
  if zip_lanes < 1 or zip_lanes > 4:
    raise ConfigurationError('Zip_lanes should be 1, 2, 3 or 4.')

  name = BuildName(zip_lanes, leftovers, aligned)

  emitter.EmitFunctionBeginA(
      name, [['const std::uint8_t*', 'source'], ['std::int32_t', 'count'],
             ['std::int32_t', 'stride'], ['std::uint8_t*', 'destination'],
             ['std::int32_t', 'multiplicative_offset'],
             ['std::int32_t', 'additive_offset']], 'void')

  # Preconditions of the emitted function, checked before the assembly block.
  emitter.EmitAssert('count %% 8 == %d' % leftovers)
  emitter.EmitAssert('count <= 2048')
  emitter.EmitAssert('count >= 8')
  emitter.EmitAssert('reinterpret_cast<std::uintptr_t>(destination) % 8 == 0')
  if aligned:
    emitter.EmitAssert('reinterpret_cast<std::uintptr_t>(source) % 8 == 0')
    if zip_lanes > 1:
      emitter.EmitAssert('stride % 8 == 0')
  emitter.EmitAsmBegin()

  registers = neon_emitter.NeonRegisters()

  # Parameters are mapped in a fixed order: count, destination, source, stride
  # here; the two offsets are mapped later, right before the reduction.
  count = registers.MapParameter('count')
  output_address = registers.MapParameter('destination')

  lanes = GenerateZipLanes(emitter, registers, zip_lanes,
                           registers.MapParameter('source'),
                           registers.MapParameter('stride'))

  # Take the leftovers out of count so that the main loop only has to deal
  # with whole 8 value chunks.
  if leftovers:
    emitter.EmitSub(count, count, emitter.ImmediateConstant(leftovers))

  GenerateClearAggregators(emitter, lanes)

  # Main loop: label 1, decrement count by 8, process one chunk per lane, then
  # branch back to label 1 (presumably until count reaches zero).
  emitter.EmitNewline()
  emitter.EmitNumericalLabel(1)
  emitter.EmitSubs(count, count, emitter.ImmediateConstant(8))

  GenerateLoadAggregateStore(emitter, lanes, output_address, 64 if aligned else
                             None)

  emitter.EmitNewline()
  emitter.EmitBneBack(1)

  if leftovers:
    GenerateLeftoverLoadAggregateStore(emitter, leftovers, lanes,
                                       output_address)

  GenerateAggregatorReduction(emitter, registers, lanes, output_address,
                              registers.MapParameter('multiplicative_offset'),
                              registers.MapParameter('additive_offset'))

  emitter.EmitAsmEnd(registers.MappedParameters(), [],
                     registers.Clobbers() + ['cc', 'memory'])
  emitter.EmitFunctionEnd()
283
284
def GenerateFunctions(emitter):
  """Emits every zip function variant, each followed by a newline.

  Variants are generated aligned first, then unaligned; within each, for 1 to
  4 lanes and, per lane count, for 0 to 7 leftovers.

  Args:
    emitter: ARM/NEON emitter.
  """
  for aligned in (True, False):
    for zip_lanes in range(1, 5):
      for leftovers in range(8):
        GenerateZipNx8(emitter, zip_lanes, leftovers, aligned)
        emitter.EmitNewline()
291