1 /*
2  *  Copyright (c) 2016 The WebM project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #include <arm_neon.h>
12 #include <assert.h>
13 
14 #include "./vpx_dsp_rtcd.h"
15 #include "vpx/vpx_integer.h"
16 #include "vpx_dsp/arm/transpose_neon.h"
17 
18 extern const int16_t vpx_rv[];
19 
// Rounded 5-tap average of the center pixel v0 and its four neighbors.
// The pairing order matters: vrhadd_u8 rounds at every step, so the tree
// (a2,a1) / (b2,b1) -> outer -> v0 must be preserved exactly.
static uint8x8_t average_k_out(const uint8x8_t a2, const uint8x8_t a1,
                               const uint8x8_t v0, const uint8x8_t b1,
                               const uint8x8_t b2) {
  const uint8x8_t above = vrhadd_u8(a2, a1);
  const uint8x8_t below = vrhadd_u8(b2, b1);
  const uint8x8_t outer = vrhadd_u8(above, below);
  return vrhadd_u8(outer, v0);
}
28 
// Per-lane mask: all-ones where every neighbor (a2, a1, b1, b2) differs from
// the center pixel v0 by less than the filter limit, all-zeros elsewhere.
static uint8x8_t generate_mask(const uint8x8_t a2, const uint8x8_t a1,
                               const uint8x8_t v0, const uint8x8_t b1,
                               const uint8x8_t b2, const uint8x8_t filter) {
  // Max absolute neighbor difference; max is order-independent.
  uint8x8_t diff_max = vmax_u8(vabd_u8(a2, v0), vabd_u8(a1, v0));
  diff_max = vmax_u8(diff_max, vabd_u8(b1, v0));
  diff_max = vmax_u8(diff_max, vabd_u8(b2, v0));
  return vclt_u8(diff_max, filter);
}
42 
// Replace v0 with the smoothed 5-tap average in lanes where all neighbors
// are within the filter limit; keep the original pixel elsewhere.
static uint8x8_t generate_output(const uint8x8_t a2, const uint8x8_t a1,
                                 const uint8x8_t v0, const uint8x8_t b1,
                                 const uint8x8_t b2, const uint8x8_t filter) {
  const uint8x8_t mask = generate_mask(a2, a1, v0, b1, b2, filter);
  const uint8x8_t smoothed = average_k_out(a2, a1, v0, b1, b2);
  return vbsl_u8(mask, smoothed, v0);
}
51 
52 // Same functions but for uint8x16_t.
// 16-lane variant of average_k_out(). Same rounding tree; the vrhaddq
// pairing order must not change (rounding makes it non-associative).
static uint8x16_t average_k_outq(const uint8x16_t a2, const uint8x16_t a1,
                                 const uint8x16_t v0, const uint8x16_t b1,
                                 const uint8x16_t b2) {
  const uint8x16_t above = vrhaddq_u8(a2, a1);
  const uint8x16_t below = vrhaddq_u8(b2, b1);
  const uint8x16_t outer = vrhaddq_u8(above, below);
  return vrhaddq_u8(outer, v0);
}
61 
// 16-lane variant of generate_mask(): all-ones where every neighbor differs
// from the center pixel v0 by less than the filter limit.
static uint8x16_t generate_maskq(const uint8x16_t a2, const uint8x16_t a1,
                                 const uint8x16_t v0, const uint8x16_t b1,
                                 const uint8x16_t b2, const uint8x16_t filter) {
  // Max absolute neighbor difference; max is order-independent.
  uint8x16_t diff_max = vmaxq_u8(vabdq_u8(a2, v0), vabdq_u8(a1, v0));
  diff_max = vmaxq_u8(diff_max, vabdq_u8(b1, v0));
  diff_max = vmaxq_u8(diff_max, vabdq_u8(b2, v0));
  return vcltq_u8(diff_max, filter);
}
75 
// 16-lane variant of generate_output(): select the smoothed value where the
// mask passes, the original pixel otherwise.
static uint8x16_t generate_outputq(const uint8x16_t a2, const uint8x16_t a1,
                                   const uint8x16_t v0, const uint8x16_t b1,
                                   const uint8x16_t b2,
                                   const uint8x16_t filter) {
  const uint8x16_t mask = generate_maskq(a2, a1, v0, b1, b2, filter);
  const uint8x16_t smoothed = average_k_outq(a2, a1, v0, b1, b2);
  return vbslq_u8(mask, smoothed, v0);
}
85 
// Two-pass deblocking of one macroblock row.
// Pass 1 filters vertically from src_ptr into dst_ptr; pass 2 filters the
// result horizontally in place in dst_ptr. Each output pixel is the rounded
// 5-tap average of itself and its four neighbors, applied only where every
// neighbor differs from the center by less than the per-column limit in f
// (see generate_output()/generate_outputq()).
void vpx_post_proc_down_and_across_mb_row_neon(uint8_t *src_ptr,
                                               uint8_t *dst_ptr, int src_stride,
                                               int dst_stride, int cols,
                                               uint8_t *f, int size) {
  uint8_t *src, *dst;
  int row;
  int col;

  // Process a stripe of macroblocks. The stripe will be a multiple of 16 (for
  // Y) or 8 (for U/V) wide (cols) and the height (size) will be 16 (for Y) or 8
  // (for U/V).
  assert((size == 8 || size == 16) && cols % 8 == 0);

  // While columns of length 16 can be processed, load them.
  for (col = 0; col < cols - 8; col += 16) {
    uint8x16_t a0, a1, a2, a3, a4, a5, a6, a7;
    // Start two rows above so the 5-tap window is centered on the first row.
    src = src_ptr - 2 * src_stride;
    dst = dst_ptr;

    a0 = vld1q_u8(src);
    src += src_stride;
    a1 = vld1q_u8(src);
    src += src_stride;
    a2 = vld1q_u8(src);
    src += src_stride;
    a3 = vld1q_u8(src);
    src += src_stride;

    for (row = 0; row < size; row += 4) {
      uint8x16_t v_out_0, v_out_1, v_out_2, v_out_3;
      const uint8x16_t filterq = vld1q_u8(f + col);

      a4 = vld1q_u8(src);
      src += src_stride;
      a5 = vld1q_u8(src);
      src += src_stride;
      a6 = vld1q_u8(src);
      src += src_stride;
      a7 = vld1q_u8(src);
      src += src_stride;

      // Produce 4 output rows from a sliding 5-row window over a0..a7.
      v_out_0 = generate_outputq(a0, a1, a2, a3, a4, filterq);
      v_out_1 = generate_outputq(a1, a2, a3, a4, a5, filterq);
      v_out_2 = generate_outputq(a2, a3, a4, a5, a6, filterq);
      v_out_3 = generate_outputq(a3, a4, a5, a6, a7, filterq);

      vst1q_u8(dst, v_out_0);
      dst += dst_stride;
      vst1q_u8(dst, v_out_1);
      dst += dst_stride;
      vst1q_u8(dst, v_out_2);
      dst += dst_stride;
      vst1q_u8(dst, v_out_3);
      dst += dst_stride;

      // Rotate over to the next slot.
      a0 = a4;
      a1 = a5;
      a2 = a6;
      a3 = a7;
    }

    src_ptr += 16;
    dst_ptr += 16;
  }

  // Clean up any left over column of length 8.
  if (col != cols) {
    uint8x8_t a0, a1, a2, a3, a4, a5, a6, a7;
    src = src_ptr - 2 * src_stride;
    dst = dst_ptr;

    a0 = vld1_u8(src);
    src += src_stride;
    a1 = vld1_u8(src);
    src += src_stride;
    a2 = vld1_u8(src);
    src += src_stride;
    a3 = vld1_u8(src);
    src += src_stride;

    for (row = 0; row < size; row += 4) {
      uint8x8_t v_out_0, v_out_1, v_out_2, v_out_3;
      const uint8x8_t filter = vld1_u8(f + col);

      a4 = vld1_u8(src);
      src += src_stride;
      a5 = vld1_u8(src);
      src += src_stride;
      a6 = vld1_u8(src);
      src += src_stride;
      a7 = vld1_u8(src);
      src += src_stride;

      v_out_0 = generate_output(a0, a1, a2, a3, a4, filter);
      v_out_1 = generate_output(a1, a2, a3, a4, a5, filter);
      v_out_2 = generate_output(a2, a3, a4, a5, a6, filter);
      v_out_3 = generate_output(a3, a4, a5, a6, a7, filter);

      vst1_u8(dst, v_out_0);
      dst += dst_stride;
      vst1_u8(dst, v_out_1);
      dst += dst_stride;
      vst1_u8(dst, v_out_2);
      dst += dst_stride;
      vst1_u8(dst, v_out_3);
      dst += dst_stride;

      // Rotate over to the next slot.
      a0 = a4;
      a1 = a5;
      a2 = a6;
      a3 = a7;
    }

    // Not strictly necessary but makes resetting dst_ptr easier.
    dst_ptr += 8;
  }

  dst_ptr -= cols;

  // Horizontal pass: filter dst in place, 8 rows at a time, by transposing
  // 8x8 tiles so the same vertical filter helpers can be reused across rows.
  for (row = 0; row < size; row += 8) {
    uint8x8_t a0, a1, a2, a3;
    uint8x8_t b0, b1, b2, b3, b4, b5, b6, b7;

    src = dst_ptr;
    dst = dst_ptr;

    // Load 8 values, transpose 4 of them, and discard 2 because they will be
    // reloaded later.
    load_and_transpose_u8_4x8(src, dst_stride, &a0, &a1, &a2, &a3);
    a3 = a1;
    a2 = a1 = a0;  // Extend left border.

    // NOTE(review): skip 2 columns so the next 8x8 load lines up after the
    // extended border above — confirm against load_and_transpose_u8_4x8.
    src += 2;

    for (col = 0; col < cols; col += 8) {
      uint8x8_t v_out_0, v_out_1, v_out_2, v_out_3, v_out_4, v_out_5, v_out_6,
          v_out_7;
      // Although the filter is meant to be applied vertically and is instead
      // being applied horizontally here it's OK because it's set in blocks of 8
      // (or 16).
      const uint8x8_t filter = vld1_u8(f + col);

      load_and_transpose_u8_8x8(src, dst_stride, &b0, &b1, &b2, &b3, &b4, &b5,
                                &b6, &b7);

      if (col + 8 == cols) {
        // Last row. Extend border (b5).
        b6 = b7 = b5;
      }

      v_out_0 = generate_output(a0, a1, a2, a3, b0, filter);
      v_out_1 = generate_output(a1, a2, a3, b0, b1, filter);
      v_out_2 = generate_output(a2, a3, b0, b1, b2, filter);
      v_out_3 = generate_output(a3, b0, b1, b2, b3, filter);
      v_out_4 = generate_output(b0, b1, b2, b3, b4, filter);
      v_out_5 = generate_output(b1, b2, b3, b4, b5, filter);
      v_out_6 = generate_output(b2, b3, b4, b5, b6, filter);
      v_out_7 = generate_output(b3, b4, b5, b6, b7, filter);

      transpose_and_store_u8_8x8(dst, dst_stride, v_out_0, v_out_1, v_out_2,
                                 v_out_3, v_out_4, v_out_5, v_out_6, v_out_7);

      // Carry the last 4 columns over as left context for the next tile.
      a0 = b4;
      a1 = b5;
      a2 = b6;
      a3 = b7;

      src += 8;
      dst += 8;
    }

    dst_ptr += 8 * dst_stride;
  }
}
262 
// sum += x;
// sumsq += x * y;
// Beyond the plain accumulation above, each output lane j also picks up the
// prefix of the new values: lane j of *sum receives x[0] + ... + x[j] (and
// *sumsq likewise with xy). This advances a sliding-window total by one
// element per lane, so lane 3 holds the fully caught-up running total.
static void accumulate_sum_sumsq(const int16x4_t x, const int32x4_t xy,
                                 int16x4_t *const sum, int32x4_t *const sumsq) {
  const int16x4_t zero = vdup_n_s16(0);
  const int32x4_t zeroq = vdupq_n_s32(0);

  // Add in the first set because vext doesn't work with '0'.
  *sum = vadd_s16(*sum, x);
  *sumsq = vaddq_s32(*sumsq, xy);

  // Shift x and xy to the right and sum. vext requires an immediate.
  *sum = vadd_s16(*sum, vext_s16(zero, x, 1));
  *sumsq = vaddq_s32(*sumsq, vextq_s32(zeroq, xy, 1));

  *sum = vadd_s16(*sum, vext_s16(zero, x, 2));
  *sumsq = vaddq_s32(*sumsq, vextq_s32(zeroq, xy, 2));

  *sum = vadd_s16(*sum, vext_s16(zero, x, 3));
  *sumsq = vaddq_s32(*sumsq, vextq_s32(zeroq, xy, 3));
}
284 
// Generate mask based on (sumsq * 15 - sum * sum < flimit)
static uint16x4_t calculate_mask(const int16x4_t sum, const int32x4_t sumsq,
                                 const int32x4_t f, const int32x4_t fifteen) {
  // Variance proxy for the 15-tap window: 15 * sum(x^2) - (sum(x))^2.
  const int32x4_t sumsq_x15 = vmulq_s32(sumsq, fifteen);
  const int32x4_t variance = vmlsl_s16(sumsq_x15, sum, sum);
  return vmovn_u32(vcltq_s32(variance, f));
}
293 
// Build an 8-lane byte mask from two 4-lane sum/sumsq halves by comparing
// each half's variance proxy against the filter limit and narrowing.
static uint8x8_t combine_mask(const int16x4_t sum_low, const int16x4_t sum_high,
                              const int32x4_t sumsq_low,
                              const int32x4_t sumsq_high, const int32x4_t f) {
  const int32x4_t fifteen = vdupq_n_s32(15);
  const uint16x4_t lo = calculate_mask(sum_low, sumsq_low, f, fifteen);
  const uint16x4_t hi = calculate_mask(sum_high, sumsq_high, f, fifteen);
  return vmovn_u16(vcombine_u16(lo, hi));
}
303 
// Apply filter of (8 + sum + s[c]) >> 4.
static uint8x8_t filter_pixels(const int16x8_t sum, const uint8x8_t s) {
  // vqrshrun supplies the +8 rounding bias and saturates back to u8.
  const int16x8_t widened = vreinterpretq_s16_u16(vmovl_u8(s));
  return vqrshrun_n_s16(vaddq_s16(sum, widened), 4);
}
311 
// Horizontal ("across") macroblock post-processing.
// For each row, every pixel is replaced by (8 + sum + pixel) >> 4 — where
// sum/sumsq track a sliding window centered on the pixel — but only in lanes
// where the window's variance proxy (15 * sumsq - sum * sum) is below
// flimit (see combine_mask()). Row borders are extended by replicating the
// edge pixels. The filter runs in place, 8 pixels at a time.
void vpx_mbpost_proc_across_ip_neon(uint8_t *src, int pitch, int rows, int cols,
                                    int flimit) {
  int row, col;
  const int32x4_t f = vdupq_n_s32(flimit);

  assert(cols % 8 == 0);

  for (row = 0; row < rows; ++row) {
    // Sum the first 8 elements, which are extended from s[0].
    // sumsq gets primed with +16.
    int sumsq = src[0] * src[0] * 9 + 16;
    int sum = src[0] * 9;

    uint8x8_t left_context, s, right_context;
    int16x4_t sum_low, sum_high;
    int32x4_t sumsq_low, sumsq_high;

    // Sum (+square) the next 6 elements.
    // Skip [0] because it's included above.
    for (col = 1; col <= 6; ++col) {
      sumsq += src[col] * src[col];
      sum += src[col];
    }

    // Prime the sums. Later the loop uses the _high values to prime the new
    // vectors.
    sumsq_high = vdupq_n_s32(sumsq);
    sum_high = vdup_n_s16(sum);

    // Manually extend the left border.
    left_context = vdup_n_u8(src[0]);

    for (col = 0; col < cols; col += 8) {
      uint8x8_t mask, output;
      int16x8_t x, y;
      int32x4_t xy_low, xy_high;

      s = vld1_u8(src + col);

      if (col + 8 == cols) {
        // Last row. Extend border.
        right_context = vdup_n_u8(src[col + 7]);
      } else {
        right_context = vld1_u8(src + col + 7);
      }

      // x = incoming - outgoing pixel, y = incoming + outgoing, so
      // x * y == incoming^2 - outgoing^2: the per-lane deltas that move the
      // sliding sum and sum-of-squares forward by one position.
      x = vreinterpretq_s16_u16(vsubl_u8(right_context, left_context));
      y = vreinterpretq_s16_u16(vaddl_u8(right_context, left_context));
      xy_low = vmull_s16(vget_low_s16(x), vget_low_s16(y));
      xy_high = vmull_s16(vget_high_s16(x), vget_high_s16(y));

      // Catch up to the last sum'd value.
      sum_low = vdup_lane_s16(sum_high, 3);
      sumsq_low = vdupq_lane_s32(vget_high_s32(sumsq_high), 1);

      accumulate_sum_sumsq(vget_low_s16(x), xy_low, &sum_low, &sumsq_low);

      // Need to do this sequentially because we need the max value from
      // sum_low.
      sum_high = vdup_lane_s16(sum_low, 3);
      sumsq_high = vdupq_lane_s32(vget_high_s32(sumsq_low), 1);

      accumulate_sum_sumsq(vget_high_s16(x), xy_high, &sum_high, &sumsq_high);

      mask = combine_mask(sum_low, sum_high, sumsq_low, sumsq_high, f);

      // Filtered value where the mask passes, original pixel elsewhere.
      output = filter_pixels(vcombine_s16(sum_low, sum_high), s);
      output = vbsl_u8(mask, output, s);

      vst1_u8(src + col, output);

      left_context = s;
    }

    src += pitch;
  }
}
389 
// Apply filter of (vpx_rv + sum + s[c]) >> 4.
static uint8x8_t filter_pixels_rv(const int16x8_t sum, const uint8x8_t s,
                                  const int16x8_t rv) {
  // rv already carries the rounding offset, so use the non-rounding
  // saturating narrow shift (vqshrun, not vqrshrun).
  const int16x8_t widened = vreinterpretq_s16_u16(vmovl_u8(s));
  const int16x8_t total = vaddq_s16(vaddq_s16(sum, widened), rv);
  return vqshrun_n_s16(total, 4);
}
399 
// Vertical ("down") macroblock post-processing: the same sliding-window
// filter as the "across" version, applied down each column over vertical
// stripes 8 pixels wide. The top border is extended by replicating the
// first row, the bottom border by holding the last loaded row. vpx_rv
// (indexed by row & 127) supplies the per-row offsets added before the
// shift — presumably a dither/rounding table; see filter_pixels_rv().
void vpx_mbpost_proc_down_neon(uint8_t *dst, int pitch, int rows, int cols,
                               int flimit) {
  int row, col, i;
  const int32x4_t f = vdupq_n_s32(flimit);
  uint8x8_t below_context = vdup_n_u8(0);

  // 8 columns are processed at a time.
  // If rows is less than 8 the bottom border extension fails.
  assert(cols % 8 == 0);
  assert(rows >= 8);

  // Load and keep the first 8 values in memory. Process a vertical stripe that
  // is 8 wide.
  for (col = 0; col < cols; col += 8) {
    uint8x8_t s, above_context[8];
    int16x8_t sum, sum_tmp;
    int32x4_t sumsq_low, sumsq_high;

    // Load and extend the top border.
    s = vld1_u8(dst);
    for (i = 0; i < 8; i++) {
      above_context[i] = s;
    }

    sum_tmp = vreinterpretq_s16_u16(vmovl_u8(s));

    // sum * 9
    sum = vmulq_n_s16(sum_tmp, 9);

    // (sum * 9) * sum == sum * sum * 9
    sumsq_low = vmull_s16(vget_low_s16(sum), vget_low_s16(sum_tmp));
    sumsq_high = vmull_s16(vget_high_s16(sum), vget_high_s16(sum_tmp));

    // Load and discard the next 6 values to prime sum and sumsq.
    for (i = 1; i <= 6; ++i) {
      const uint8x8_t a = vld1_u8(dst + i * pitch);
      const int16x8_t b = vreinterpretq_s16_u16(vmovl_u8(a));
      sum = vaddq_s16(sum, b);

      sumsq_low = vmlal_s16(sumsq_low, vget_low_s16(b), vget_low_s16(b));
      sumsq_high = vmlal_s16(sumsq_high, vget_high_s16(b), vget_high_s16(b));
    }

    for (row = 0; row < rows; ++row) {
      uint8x8_t mask, output;
      int16x8_t x, y;
      int32x4_t xy_low, xy_high;

      s = vld1_u8(dst + row * pitch);

      // Extend the bottom border.
      if (row + 7 < rows) {
        below_context = vld1_u8(dst + (row + 7) * pitch);
      }

      // x = incoming - outgoing row, y = incoming + outgoing, so
      // x * y == incoming^2 - outgoing^2: the deltas for sum and sumsq.
      x = vreinterpretq_s16_u16(vsubl_u8(below_context, above_context[0]));
      y = vreinterpretq_s16_u16(vaddl_u8(below_context, above_context[0]));
      xy_low = vmull_s16(vget_low_s16(x), vget_low_s16(y));
      xy_high = vmull_s16(vget_high_s16(x), vget_high_s16(y));

      sum = vaddq_s16(sum, x);

      sumsq_low = vaddq_s32(sumsq_low, xy_low);
      sumsq_high = vaddq_s32(sumsq_high, xy_high);

      mask = combine_mask(vget_low_s16(sum), vget_high_s16(sum), sumsq_low,
                          sumsq_high, f);

      // Filtered value where the mask passes, original pixel elsewhere.
      output = filter_pixels_rv(sum, s, vld1q_s16(vpx_rv + (row & 127)));
      output = vbsl_u8(mask, output, s);

      vst1_u8(dst + row * pitch, output);

      // Slide the 8-row history window down one row.
      above_context[0] = above_context[1];
      above_context[1] = above_context[2];
      above_context[2] = above_context[3];
      above_context[3] = above_context[4];
      above_context[4] = above_context[5];
      above_context[5] = above_context[6];
      above_context[6] = above_context[7];
      above_context[7] = s;
    }

    dst += 8;
  }
}
486