1 /*
2 * Copyright (c) 2016 The WebM project authors. All Rights Reserved.
3 *
4 * Use of this source code is governed by a BSD-style license
5 * that can be found in the LICENSE file in the root of the source
6 * tree. An additional intellectual property rights grant can be found
7 * in the file PATENTS. All contributing project authors may
8 * be found in the AUTHORS file in the root of the source tree.
9 */
10
11 #include <arm_neon.h>
12 #include <assert.h>
13
14 #include "./vpx_dsp_rtcd.h"
15 #include "vpx/vpx_integer.h"
16 #include "vpx_dsp/arm/transpose_neon.h"
17
18 extern const int16_t vpx_rv[];
19
// Kernel average of a 5-pixel column/row {a2, a1, v0, b1, b2}:
// pairs the outer neighbors with rounding halving-adds, averages the two
// pair results, then averages once more with the center pixel v0.
static uint8x8_t average_k_out(const uint8x8_t a2, const uint8x8_t a1,
                               const uint8x8_t v0, const uint8x8_t b1,
                               const uint8x8_t b2) {
  const uint8x8_t above = vrhadd_u8(a2, a1);
  const uint8x8_t below = vrhadd_u8(b2, b1);
  const uint8x8_t outer = vrhadd_u8(above, below);
  return vrhadd_u8(outer, v0);
}
28
// Build a per-lane byte mask that is all-ones where every one of the four
// neighbors {a2, a1, b1, b2} differs from the center v0 by less than the
// per-lane filter threshold, all-zeros otherwise.
static uint8x8_t generate_mask(const uint8x8_t a2, const uint8x8_t a1,
                               const uint8x8_t v0, const uint8x8_t b1,
                               const uint8x8_t b2, const uint8x8_t filter) {
  const uint8x8_t diff_a2 = vabd_u8(a2, v0);
  const uint8x8_t diff_a1 = vabd_u8(a1, v0);
  const uint8x8_t diff_b1 = vabd_u8(b1, v0);
  const uint8x8_t diff_b2 = vabd_u8(b2, v0);

  // Largest absolute neighbor difference per lane (vmax is commutative).
  const uint8x8_t max_above = vmax_u8(diff_a2, diff_a1);
  const uint8x8_t max_three = vmax_u8(max_above, diff_b1);
  const uint8x8_t max_all = vmax_u8(max_three, diff_b2);

  return vclt_u8(max_all, filter);
}
42
// Produce the filtered output for the center pixel v0: lanes that pass the
// neighbor-difference test receive the kernel average, the rest keep v0.
static uint8x8_t generate_output(const uint8x8_t a2, const uint8x8_t a1,
                                 const uint8x8_t v0, const uint8x8_t b1,
                                 const uint8x8_t b2, const uint8x8_t filter) {
  const uint8x8_t mask = generate_mask(a2, a1, v0, b1, b2, filter);
  const uint8x8_t averaged = average_k_out(a2, a1, v0, b1, b2);

  // Per-lane select: mask ? averaged : original.
  return vbsl_u8(mask, averaged, v0);
}
51
// 16-lane (quad-register) variants of the helpers above.

// Kernel average of {a2, a1, v0, b1, b2} using rounding halving-adds;
// 16 lanes at a time.
static uint8x16_t average_k_outq(const uint8x16_t a2, const uint8x16_t a1,
                                 const uint8x16_t v0, const uint8x16_t b1,
                                 const uint8x16_t b2) {
  const uint8x16_t above = vrhaddq_u8(a2, a1);
  const uint8x16_t below = vrhaddq_u8(b2, b1);
  const uint8x16_t outer = vrhaddq_u8(above, below);
  return vrhaddq_u8(outer, v0);
}
61
// 16-lane mask: all-ones where every neighbor of v0 differs by less than
// the per-lane filter threshold.
static uint8x16_t generate_maskq(const uint8x16_t a2, const uint8x16_t a1,
                                 const uint8x16_t v0, const uint8x16_t b1,
                                 const uint8x16_t b2, const uint8x16_t filter) {
  const uint8x16_t diff_a2 = vabdq_u8(a2, v0);
  const uint8x16_t diff_a1 = vabdq_u8(a1, v0);
  const uint8x16_t diff_b1 = vabdq_u8(b1, v0);
  const uint8x16_t diff_b2 = vabdq_u8(b2, v0);

  // Largest absolute neighbor difference per lane (vmaxq is commutative).
  const uint8x16_t max_above = vmaxq_u8(diff_a2, diff_a1);
  const uint8x16_t max_three = vmaxq_u8(max_above, diff_b1);
  const uint8x16_t max_all = vmaxq_u8(max_three, diff_b2);

  return vcltq_u8(max_all, filter);
}
75
// 16-lane filtered output: lanes passing the neighbor-difference test get
// the kernel average, the rest keep the original v0.
static uint8x16_t generate_outputq(const uint8x16_t a2, const uint8x16_t a1,
                                   const uint8x16_t v0, const uint8x16_t b1,
                                   const uint8x16_t b2,
                                   const uint8x16_t filter) {
  const uint8x16_t mask = generate_maskq(a2, a1, v0, b1, b2, filter);
  const uint8x16_t averaged = average_k_outq(a2, a1, v0, b1, b2);

  // Per-lane select: mask ? averaged : original.
  return vbslq_u8(mask, averaged, v0);
}
85
// Deblock one macroblock row (size rows tall, cols wide) with a 5-tap
// conditional kernel applied first down each column (src_ptr -> dst_ptr),
// then across each row (dst_ptr in place).  A pixel is only replaced by the
// kernel average when all four of its neighbors differ from it by less than
// the per-column threshold f[] (see generate_output / generate_outputq).
// cols is assumed to be a multiple of 8; f holds one threshold per column.
void vpx_post_proc_down_and_across_mb_row_neon(uint8_t *src_ptr,
                                               uint8_t *dst_ptr, int src_stride,
                                               int dst_stride, int cols,
                                               uint8_t *f, int size) {
  uint8_t *src, *dst;
  int row;
  int col;

  // -------- Vertical pass: filter columns, writing to dst_ptr. --------

  // While columns of length 16 can be processed, load them.
  for (col = 0; col < cols - 8; col += 16) {
    uint8x16_t a0, a1, a2, a3, a4, a5, a6, a7;
    // Start 2 rows above the block so the first output row has its full
    // 5-tap context (a0..a3 prime the pipeline).
    src = src_ptr - 2 * src_stride;
    dst = dst_ptr;

    a0 = vld1q_u8(src);
    src += src_stride;
    a1 = vld1q_u8(src);
    src += src_stride;
    a2 = vld1q_u8(src);
    src += src_stride;
    a3 = vld1q_u8(src);
    src += src_stride;

    // Produce 4 output rows per iteration from a sliding 8-row window.
    for (row = 0; row < size; row += 4) {
      uint8x16_t v_out_0, v_out_1, v_out_2, v_out_3;
      const uint8x16_t filterq = vld1q_u8(f + col);

      a4 = vld1q_u8(src);
      src += src_stride;
      a5 = vld1q_u8(src);
      src += src_stride;
      a6 = vld1q_u8(src);
      src += src_stride;
      a7 = vld1q_u8(src);
      src += src_stride;

      v_out_0 = generate_outputq(a0, a1, a2, a3, a4, filterq);
      v_out_1 = generate_outputq(a1, a2, a3, a4, a5, filterq);
      v_out_2 = generate_outputq(a2, a3, a4, a5, a6, filterq);
      v_out_3 = generate_outputq(a3, a4, a5, a6, a7, filterq);

      vst1q_u8(dst, v_out_0);
      dst += dst_stride;
      vst1q_u8(dst, v_out_1);
      dst += dst_stride;
      vst1q_u8(dst, v_out_2);
      dst += dst_stride;
      vst1q_u8(dst, v_out_3);
      dst += dst_stride;

      // Rotate over to the next slot.
      a0 = a4;
      a1 = a5;
      a2 = a6;
      a3 = a7;
    }

    src_ptr += 16;
    dst_ptr += 16;
  }

  // Clean up any left over column of length 8.
  if (col != cols) {
    uint8x8_t a0, a1, a2, a3, a4, a5, a6, a7;
    src = src_ptr - 2 * src_stride;
    dst = dst_ptr;

    a0 = vld1_u8(src);
    src += src_stride;
    a1 = vld1_u8(src);
    src += src_stride;
    a2 = vld1_u8(src);
    src += src_stride;
    a3 = vld1_u8(src);
    src += src_stride;

    for (row = 0; row < size; row += 4) {
      uint8x8_t v_out_0, v_out_1, v_out_2, v_out_3;
      const uint8x8_t filter = vld1_u8(f + col);

      a4 = vld1_u8(src);
      src += src_stride;
      a5 = vld1_u8(src);
      src += src_stride;
      a6 = vld1_u8(src);
      src += src_stride;
      a7 = vld1_u8(src);
      src += src_stride;

      v_out_0 = generate_output(a0, a1, a2, a3, a4, filter);
      v_out_1 = generate_output(a1, a2, a3, a4, a5, filter);
      v_out_2 = generate_output(a2, a3, a4, a5, a6, filter);
      v_out_3 = generate_output(a3, a4, a5, a6, a7, filter);

      vst1_u8(dst, v_out_0);
      dst += dst_stride;
      vst1_u8(dst, v_out_1);
      dst += dst_stride;
      vst1_u8(dst, v_out_2);
      dst += dst_stride;
      vst1_u8(dst, v_out_3);
      dst += dst_stride;

      // Rotate over to the next slot.
      a0 = a4;
      a1 = a5;
      a2 = a6;
      a3 = a7;
    }

    // Not strictly necessary but makes resetting dst_ptr easier.
    dst_ptr += 8;
  }

  // Rewind to the start of the vertically filtered output.
  dst_ptr -= cols;

  // -------- Horizontal pass: filter dst_ptr in place. --------
  // Works on 8x8 tiles: transpose so rows become vectors, apply the same
  // 5-tap kernel, transpose back on store.
  for (row = 0; row < size; row += 8) {
    uint8x8_t a0, a1, a2, a3;
    uint8x8_t b0, b1, b2, b3, b4, b5, b6, b7;

    src = dst_ptr;
    dst = dst_ptr;

    // Load 8 values, transpose 4 of them, and discard 2 because they will be
    // reloaded later.
    load_and_transpose_u8_4x8(src, dst_stride, &a0, &a1, &a2, &a3);
    a3 = a1;
    a2 = a1 = a0;  // Extend left border.

    src += 2;

    for (col = 0; col < cols; col += 8) {
      uint8x8_t v_out_0, v_out_1, v_out_2, v_out_3, v_out_4, v_out_5, v_out_6,
          v_out_7;
      // Although the filter is meant to be applied vertically and is instead
      // being applied horizontally here it's OK because it's set in blocks of 8
      // (or 16).
      const uint8x8_t filter = vld1_u8(f + col);

      load_and_transpose_u8_8x8(src, dst_stride, &b0, &b1, &b2, &b3, &b4, &b5,
                                &b6, &b7);

      if (col + 8 == cols) {
        // Last tile in the row. Extend the right border (b5).
        b6 = b7 = b5;
      }

      // a0..a3 carry context from the previous tile (or the extended left
      // border on the first tile).
      v_out_0 = generate_output(a0, a1, a2, a3, b0, filter);
      v_out_1 = generate_output(a1, a2, a3, b0, b1, filter);
      v_out_2 = generate_output(a2, a3, b0, b1, b2, filter);
      v_out_3 = generate_output(a3, b0, b1, b2, b3, filter);
      v_out_4 = generate_output(b0, b1, b2, b3, b4, filter);
      v_out_5 = generate_output(b1, b2, b3, b4, b5, filter);
      v_out_6 = generate_output(b2, b3, b4, b5, b6, filter);
      v_out_7 = generate_output(b3, b4, b5, b6, b7, filter);

      transpose_and_store_u8_8x8(dst, dst_stride, v_out_0, v_out_1, v_out_2,
                                 v_out_3, v_out_4, v_out_5, v_out_6, v_out_7);

      a0 = b4;
      a1 = b5;
      a2 = b6;
      a3 = b7;

      src += 8;
      dst += 8;
    }

    dst_ptr += 8 * dst_stride;
  }
}
257
// Prefix-sum accumulation for the sliding window:
//   sum[i]   += x[0]  + ... + x[i]
//   sumsq[i] += xy[0] + ... + xy[i]
// i.e. lane i of the accumulators receives the deltas of its own lane and
// every preceding lane, so each of the 4 lanes ends up holding the running
// window total for its output pixel.
static void accumulate_sum_sumsq(const int16x4_t x, const int32x4_t xy,
                                 int16x4_t *const sum, int32x4_t *const sumsq) {
  const int16x4_t zero = vdup_n_s16(0);
  const int32x4_t zeroq = vdupq_n_s32(0);

  // Add in the first set because vext doesn't work with '0'.
  *sum = vadd_s16(*sum, x);
  *sumsq = vaddq_s32(*sumsq, xy);

  // Shift x and xy to the right and sum. vext requires an immediate.
  // vext(zero, x, k) yields [0, ..., 0, x0, ..., x(3-k)]: x shifted right
  // by (4 - k) lanes with zero fill, so each add contributes x[j] only to
  // lanes i >= j.
  *sum = vadd_s16(*sum, vext_s16(zero, x, 1));
  *sumsq = vaddq_s32(*sumsq, vextq_s32(zeroq, xy, 1));

  *sum = vadd_s16(*sum, vext_s16(zero, x, 2));
  *sumsq = vaddq_s32(*sumsq, vextq_s32(zeroq, xy, 2));

  *sum = vadd_s16(*sum, vext_s16(zero, x, 3));
  *sumsq = vaddq_s32(*sumsq, vextq_s32(zeroq, xy, 3));
}
279
// Variance test for 4 lanes: mask is all-ones where
// sumsq * 15 - sum * sum < flimit, narrowed to 16 bits per lane.
static uint16x4_t calculate_mask(const int16x4_t sum, const int32x4_t sumsq,
                                 const int32x4_t f, const int32x4_t fifteen) {
  const int32x4_t sumsq_x15 = vmulq_s32(sumsq, fifteen);
  // vmlsl: sumsq * 15 - sum * sum, with sum widened to 32 bits.
  const int32x4_t variance = vmlsl_s16(sumsq_x15, sum, sum);
  return vmovn_u32(vcltq_s32(variance, f));
}
288
// Run the variance test on the low and high 4-lane halves and narrow the
// two 16-bit results into a single 8-lane byte mask.
static uint8x8_t combine_mask(const int16x4_t sum_low, const int16x4_t sum_high,
                              const int32x4_t sumsq_low,
                              const int32x4_t sumsq_high, const int32x4_t f) {
  const int32x4_t fifteen = vdupq_n_s32(15);
  const uint16x4_t lo = calculate_mask(sum_low, sumsq_low, f, fifteen);
  const uint16x4_t hi = calculate_mask(sum_high, sumsq_high, f, fifteen);
  return vmovn_u16(vcombine_u16(lo, hi));
}
298
// Apply the filter (8 + sum + s[c]) >> 4; the +8 rounding and the shift
// come from the rounding saturating narrow.
static uint8x8_t filter_pixels(const int16x8_t sum, const uint8x8_t s) {
  const int16x8_t widened = vreinterpretq_s16_u16(vmovl_u8(s));
  return vqrshrun_n_s16(vaddq_s16(sum, widened), 4);
}
306
// In-place horizontal noise filter.  For each pixel a 15-tap window
// (7 left + center + 7 right, borders extended) of sum and sum-of-squares
// is maintained incrementally; the pixel is replaced by
// (8 + sum + pixel) >> 4 only where the variance test
// sumsq * 15 - sum * sum < flimit passes (see combine_mask).
void vpx_mbpost_proc_across_ip_neon(uint8_t *src, int pitch, int rows, int cols,
                                    int flimit) {
  int row, col;
  const int32x4_t f = vdupq_n_s32(flimit);

  assert(cols % 8 == 0);

  for (row = 0; row < rows; ++row) {
    // Prime sum/sumsq with 9 copies of src[0]: the 7 extended left-border
    // pixels, plus window positions -1 and 0 relative to the first output.
    // sumsq gets primed with +16.
    int sumsq = src[0] * src[0] * 9 + 16;
    int sum = src[0] * 9;

    uint8x8_t left_context, s, right_context;
    int16x4_t sum_low, sum_high;
    int32x4_t sumsq_low, sumsq_high;

    // Sum (+square) the next 6 elements.
    // Skip [0] because it's included above.
    for (col = 1; col <= 6; ++col) {
      sumsq += src[col] * src[col];
      sum += src[col];
    }

    // Prime the sums. Later the loop uses the _high values to prime the new
    // vectors.
    sumsq_high = vdupq_n_s32(sumsq);
    sum_high = vdup_n_s16(sum);

    // Manually extend the left border.
    left_context = vdup_n_u8(src[0]);

    for (col = 0; col < cols; col += 8) {
      uint8x8_t mask, output;
      int16x8_t x, y;
      int32x4_t xy_low, xy_high;

      s = vld1_u8(src + col);

      if (col + 8 == cols) {
        // Last block of the row. Extend the right border.
        right_context = vdup_n_u8(src[col + 7]);
      } else {
        right_context = vld1_u8(src + col + 7);
      }

      // Sliding-window update: x = entering pixel - leaving pixel, and
      // x * y = entering^2 - leaving^2 (difference of squares).
      x = vreinterpretq_s16_u16(vsubl_u8(right_context, left_context));
      y = vreinterpretq_s16_u16(vaddl_u8(right_context, left_context));
      xy_low = vmull_s16(vget_low_s16(x), vget_low_s16(y));
      xy_high = vmull_s16(vget_high_s16(x), vget_high_s16(y));

      // Catch up to the last sum'd value.
      sum_low = vdup_lane_s16(sum_high, 3);
      sumsq_low = vdupq_lane_s32(vget_high_s32(sumsq_high), 1);

      accumulate_sum_sumsq(vget_low_s16(x), xy_low, &sum_low, &sumsq_low);

      // Need to do this sequentially because we need the max value from
      // sum_low.
      sum_high = vdup_lane_s16(sum_low, 3);
      sumsq_high = vdupq_lane_s32(vget_high_s32(sumsq_low), 1);

      accumulate_sum_sumsq(vget_high_s16(x), xy_high, &sum_high, &sumsq_high);

      mask = combine_mask(sum_low, sum_high, sumsq_low, sumsq_high, f);

      output = filter_pixels(vcombine_s16(sum_low, sum_high), s);
      output = vbsl_u8(mask, output, s);

      vst1_u8(src + col, output);

      left_context = s;
    }

    src += pitch;
  }
}
384
// Apply the filter (vpx_rv + sum + s[c]) >> 4, where rv supplies the
// per-pixel rounding term (non-rounding narrow, unlike filter_pixels).
static uint8x8_t filter_pixels_rv(const int16x8_t sum, const uint8x8_t s,
                                  const int16x8_t rv) {
  const int16x8_t widened = vreinterpretq_s16_u16(vmovl_u8(s));
  const int16x8_t total = vaddq_s16(vaddq_s16(sum, widened), rv);

  return vqshrun_n_s16(total, 4);
}
394
// In-place vertical noise filter.  Processes the image in 8-wide vertical
// stripes; for each pixel a 15-tap vertical window (borders extended) of
// sum and sum-of-squares is maintained incrementally, and the pixel is
// replaced by (rv + sum + pixel) >> 4 (rv from the vpx_rv dither table)
// only where sumsq * 15 - sum * sum < flimit.
void vpx_mbpost_proc_down_neon(uint8_t *dst, int pitch, int rows, int cols,
                               int flimit) {
  int row, col, i;
  const int32x4_t f = vdupq_n_s32(flimit);
  uint8x8_t below_context = vdup_n_u8(0);

  // 8 columns are processed at a time.
  // If rows is less than 8 the bottom border extension fails.
  assert(cols % 8 == 0);
  assert(rows >= 8);

  // Load and keep the first 8 values in memory. Process a vertical stripe that
  // is 8 wide.
  for (col = 0; col < cols; col += 8) {
    uint8x8_t s, above_context[8];
    int16x8_t sum, sum_tmp;
    int32x4_t sumsq_low, sumsq_high;

    // Load and extend the top border: all 8 context rows start as copies of
    // row 0.
    s = vld1_u8(dst);
    for (i = 0; i < 8; i++) {
      above_context[i] = s;
    }

    sum_tmp = vreinterpretq_s16_u16(vmovl_u8(s));

    // sum * 9
    sum = vmulq_n_s16(sum_tmp, 9);

    // (sum * 9) * sum == sum * sum * 9
    sumsq_low = vmull_s16(vget_low_s16(sum), vget_low_s16(sum_tmp));
    sumsq_high = vmull_s16(vget_high_s16(sum), vget_high_s16(sum_tmp));

    // Load and discard the next 6 values to prime sum and sumsq.
    for (i = 1; i <= 6; ++i) {
      const uint8x8_t a = vld1_u8(dst + i * pitch);
      const int16x8_t b = vreinterpretq_s16_u16(vmovl_u8(a));
      sum = vaddq_s16(sum, b);

      sumsq_low = vmlal_s16(sumsq_low, vget_low_s16(b), vget_low_s16(b));
      sumsq_high = vmlal_s16(sumsq_high, vget_high_s16(b), vget_high_s16(b));
    }

    for (row = 0; row < rows; ++row) {
      uint8x8_t mask, output;
      int16x8_t x, y;
      int32x4_t xy_low, xy_high;

      s = vld1_u8(dst + row * pitch);

      // Extend the bottom border: past the last row, below_context keeps
      // its previous (last loaded) value.
      if (row + 7 < rows) {
        below_context = vld1_u8(dst + (row + 7) * pitch);
      }

      // Sliding-window update: x = entering row - leaving row, and
      // x * y = entering^2 - leaving^2 (difference of squares).
      x = vreinterpretq_s16_u16(vsubl_u8(below_context, above_context[0]));
      y = vreinterpretq_s16_u16(vaddl_u8(below_context, above_context[0]));
      xy_low = vmull_s16(vget_low_s16(x), vget_low_s16(y));
      xy_high = vmull_s16(vget_high_s16(x), vget_high_s16(y));

      sum = vaddq_s16(sum, x);

      sumsq_low = vaddq_s32(sumsq_low, xy_low);
      sumsq_high = vaddq_s32(sumsq_high, xy_high);

      mask = combine_mask(vget_low_s16(sum), vget_high_s16(sum), sumsq_low,
                          sumsq_high, f);

      // vpx_rv supplies the rounding offsets; the table index wraps at 128.
      output = filter_pixels_rv(sum, s, vld1q_s16(vpx_rv + (row & 127)));
      output = vbsl_u8(mask, output, s);

      vst1_u8(dst + row * pitch, output);

      // Shift the 8-row context queue; the oldest row drops off the front.
      above_context[0] = above_context[1];
      above_context[1] = above_context[2];
      above_context[2] = above_context[3];
      above_context[3] = above_context[4];
      above_context[4] = above_context[5];
      above_context[5] = above_context[6];
      above_context[6] = above_context[7];
      above_context[7] = s;
    }

    dst += 8;
  }
}
481