1// This file is part of OpenCV project.
2// It is subject to the license terms in the LICENSE file found in the top-level directory
3// of this distribution and at http://opencv.org/license.html.
4
5// Copyright (C) 2014, Advanced Micro Devices, Inc., all rights reserved.
6// Third party copyrights are property of their respective owners.
7
// Enable optional device extensions when the compiler advertises them.
// The directive is spelled "OPENCL EXTENSION" (two tokens); an underscore-joined
// spelling is an unknown pragma and is silently ignored by OpenCL compilers.
#ifdef cl_amd_printf
#pragma OPENCL EXTENSION cl_amd_printf:enable
#endif

// Double precision is only requested when the host build defines DOUBLE_SUPPORT.
#ifdef DOUBLE_SUPPORT
#ifdef cl_amd_fp64
#pragma OPENCL EXTENSION cl_amd_fp64:enable
#elif defined cl_khr_fp64
#pragma OPENCL EXTENSION cl_khr_fp64:enable
#endif
#endif
19
20
#ifdef OP_CALC_WEIGHTS

/*
 * Fills the distance -> weight lookup table, one work-item per table entry.
 *
 * Entry k corresponds to the distance d = k * almostDist2ActualDistMultiplier and
 * holds fixedPointMult * exp(-d * den), or fixedPointMult * exp(-d * d * den) when
 * ABS is defined, converted to wlut_t. A NaN weight is replaced by 1.0, and any
 * entry below WEIGHT_THRESHOLD * fixedPointMult is stored as 0.
 * Work-items with k >= almostMaxDist do nothing.
 */
__kernel void calcAlmostDist2Weight(__global wlut_t * almostDist2Weight, int almostMaxDist,
                                    FT almostDist2ActualDistMultiplier, int fixedPointMult,
                                    w_t den, FT WEIGHT_THRESHOLD)
{
    const int entry = get_global_id(0);
    if (entry >= almostMaxDist)
        return;

    const FT actualDist = entry * almostDist2ActualDistMultiplier;

#ifdef ABS
    const w_t exponent = (w_t)(-actualDist * actualDist) * den;
#else
    const w_t exponent = (w_t)(-actualDist) * den;
#endif

    const w_t raw = exp(exponent);
    const wlut_t scaled = convert_wlut_t(fixedPointMult * (isnan(raw) ? (w_t)1.0 : raw));
    const wlut_t cutoff = (wlut_t)(WEIGHT_THRESHOLD * fixedPointMult);

    almostDist2Weight[entry] = scaled < cutoff ? (wlut_t)0 : scaled;
}
42
#elif defined OP_CALC_FASTNLMEANS

// Expands to nothing. NOTE(review): presumably the host substitutes `noconvert`
// for a convert_* helper name when no type conversion is needed, so that
// `noconvert(expr)` becomes `(expr)` -- confirm against the host-side build options.
#define noconvert

// Number of candidate positions in the SEARCH_SIZE x SEARCH_SIZE search window.
// The operand is parenthesised so the macro stays correct even if SEARCH_SIZE is
// defined as an expression rather than a plain literal.
#define SEARCH_SIZE_SQ ((SEARCH_SIZE) * (SEARCH_SIZE))
48
/*
 * Distance between two pixels, summed over their cn channels.
 * With ABS defined this is the sum of per-channel absolute differences;
 * otherwise it is the sum of per-channel squared differences.
 */
inline int calcDist(pixel_t a, pixel_t b)
{
#ifdef ABS
    const int_t perChannel = convert_int_t(abs_diff(a, b));
#else
    const int_t delta = convert_int_t(a) - convert_int_t(b);
    const int_t perChannel = delta * delta;
#endif

#if cn == 1
    return perChannel;
#elif cn == 2
    return perChannel.s0 + perChannel.s1;
#elif cn == 3
    return perChannel.s0 + perChannel.s1 + perChannel.s2;
#elif cn == 4
    return perChannel.s0 + perChannel.s1 + perChannel.s2 + perChannel.s3;
#else
#error "cn should be either 1, 2, 3 or 4"
#endif
}
70
#ifdef ABS
/*
 * Returns calcDist(down_value, down_value_t) - calcDist(up_value, up_value_t):
 * the change in a template column's distance when the "down" pixel pair is added
 * and the "up" pixel pair is removed. In calcElement the down pair is the row
 * entering the template and the up pair is the row leaving it.
 */
inline int calcDistUpDown(pixel_t down_value, pixel_t down_value_t, pixel_t up_value, pixel_t up_value_t)
{
    const int added = calcDist(down_value, down_value_t);
    const int removed = calcDist(up_value, up_value_t);
    return added - removed;
}
#else
/*
 * Squared-difference version of the same quantity. Per channel,
 * A^2 - B^2 is evaluated as (A - B) * (A + B), where A = down - down_t and
 * B = up - up_t. The result is then summed over the cn channels.
 */
inline int calcDistUpDown(pixel_t down_value, pixel_t down_value_t, pixel_t up_value, pixel_t up_value_t)
{
    const int_t downDelta = convert_int_t(down_value) - convert_int_t(down_value_t);
    const int_t upDelta = convert_int_t(up_value) - convert_int_t(up_value_t);
    const int_t perChannel = (downDelta - upDelta) * (downDelta + upDelta);

#if cn == 1
    return perChannel;
#elif cn == 2
    return perChannel.s0 + perChannel.s1;
#elif cn == 3
    return perChannel.s0 + perChannel.s1 + perChannel.s2;
#elif cn == 4
    return perChannel.s0 + perChannel.s1 + perChannel.s2 + perChannel.s3;
#else
#error "cn should be either 1, 2, 3 or 4"
#endif
}
#endif
96
/*
 * Computes from scratch the template distances for pixel (x, y), the first pixel
 * of a tile row, against every candidate of the search window.
 *
 * Candidate i (i = id, id + CTA_SIZE, ... < SEARCH_SIZE_SQ) is the template centred
 * at (x - SEARCH_SIZE2 + i % SEARCH_SIZE, y - SEARCH_SIZE2 + i / SEARCH_SIZE).
 * For each candidate handled by this work-item:
 *   - col_dists[i * TEMPLATE_SIZE + c] = distance summed over the TEMPLATE_SIZE rows
 *     of template column c (c = 0 is the leftmost column);
 *   - dists[i] = total template distance (sum over all columns);
 *   - up_col_dists[i] = this candidate's rightmost column sum.
 *
 * Each work-item writes only the entries of its own candidates.
 */
inline void calcFirstElementInRow(__global const uchar * src, int src_step, int src_offset,
                                  __local int * dists, int y, int x, int id,
                                  __global int * col_dists, __global int * up_col_dists)
{
    y -= TEMPLATE_SIZE2; // top row of the template around (x, y)
    int sx = x - SEARCH_SIZE2, sy = y - SEARCH_SIZE2;
    int col_dists_current_private[TEMPLATE_SIZE];

    for (int i = id; i < SEARCH_SIZE_SQ; i += CTA_SIZE)
    {
        int dist = 0, value;

        // Both pointers start at the template's centre column on its top row and advance
        // one row per iteration below; columns are reached via the [tx] offset.
        __global const pixel_t * src_template = (__global const pixel_t *)(src +
            mad24(sy + i / SEARCH_SIZE, src_step, mad24(psz, sx + i % SEARCH_SIZE, src_offset)));
        __global const pixel_t * src_current = (__global const pixel_t *)(src + mad24(y, src_step, mad24(psz, x, src_offset)));
        __global int * col_dists_current = col_dists + i * TEMPLATE_SIZE;

        #pragma unroll
        for (int j = 0; j < TEMPLATE_SIZE; ++j)
            col_dists_current_private[j] = 0;

        for (int ty = 0; ty < TEMPLATE_SIZE; ++ty)
        {
            #pragma unroll
            for (int tx = -TEMPLATE_SIZE2; tx <= TEMPLATE_SIZE2; ++tx)
            {
                value = calcDist(src_template[tx], src_current[tx]);

                col_dists_current_private[tx + TEMPLATE_SIZE2] += value;
                dist += value;
            }

            src_current = (__global const pixel_t *)((__global const uchar *)src_current + src_step);
            src_template = (__global const pixel_t *)((__global const uchar *)src_template + src_step);
        }

        #pragma unroll
        for (int j = 0; j < TEMPLATE_SIZE; ++j)
            col_dists_current[j] = col_dists_current_private[j];

        dists[i] = dist;
        // Store this candidate's own rightmost column sum. Reading col_dists[TEMPLATE_SIZE - 1]
        // instead would fetch candidate 0's value, which another work-item writes to global
        // memory with no barrier in between.
        up_col_dists[i] = col_dists_current_private[TEMPLATE_SIZE - 1];
    }
}
143
/*
 * Incremental update for pixel (x, y) in the first row of a tile when x is not the
 * tile's first column.
 *
 * For each candidate i handled by this work-item, the template column at
 * x + TEMPLATE_SIZE2 is summed over the template's TEMPLATE_SIZE rows. The sum is
 * added to dists[i] in place of the value in slot `first` of the candidate's
 * col_dists entries, which is then overwritten. The sum is also saved in
 * up_col_dists at tile column x0 for reuse by the next row (see calcElement).
 *
 * x0: column of (x, y) within the tile (callers pass x minus the tile start).
 * first: col_dists slot (0 .. TEMPLATE_SIZE-1) being replaced.
 */
inline void calcElementInFirstRow(__global const uchar * src, int src_step, int src_offset,
                                  __local int * dists, int y, int x0, int x, int id, int first,
                                  __global int * col_dists, __global int * up_col_dists)
{
    const int newCol = x + TEMPLATE_SIZE2;   // template column entering on the right
    const int topRow = y - TEMPLATE_SIZE2;   // template's top row
    const int searchCol0 = newCol - SEARCH_SIZE2;
    const int searchRow0 = topRow - SEARCH_SIZE2;

    for (int i = id; i < SEARCH_SIZE_SQ; i += CTA_SIZE)
    {
        __global const uchar * curRow = src + mad24(topRow, src_step, mad24(psz, newCol, src_offset));
        __global const uchar * tplRow = src + mad24(searchRow0 + i / SEARCH_SIZE, src_step,
                                                     mad24(psz, searchCol0 + i % SEARCH_SIZE, src_offset));
        int colSum = 0;

        #pragma unroll
        for (int ty = 0; ty < TEMPLATE_SIZE; ++ty, curRow += src_step, tplRow += src_step)
            colSum += calcDist(*(__global const pixel_t *)curRow, *(__global const pixel_t *)tplRow);

        __global int * colSlot = col_dists + mad24(TEMPLATE_SIZE, i, first);
        dists[i] += colSum - *colSlot;
        *colSlot = colSum;
        up_col_dists[mad24(x0, SEARCH_SIZE_SQ, i)] = colSum;
    }
}
175
/*
 * Incremental update for pixel (x, y) below the first tile row, when x is not the
 * tile's first column.
 *
 * The sum of template column x + TEMPLATE_SIZE2 is derived from the value stored
 * in up_col_dists by the row above. calcDistUpDown adds the pair on row
 * y + TEMPLATE_SIZE2 and removes the pair on row y - TEMPLATE_SIZE2 - 1.
 * The new column sum then:
 *   - replaces slot `first` of the candidate's col_dists entries in dists[i];
 *   - is written back to both col_dists and up_col_dists.
 *
 * x0: column of (x, y) within the tile (callers pass x minus the tile start).
 * first: col_dists slot (0 .. TEMPLATE_SIZE-1) being replaced.
 */
inline void calcElement(__global const uchar * src, int src_step, int src_offset,
                        __local int * dists, int y, int x0, int x, int id, int first,
                        __global int * col_dists, __global int * up_col_dists)
{
    const int col = x + TEMPLATE_SIZE2;
    const int rowOut = y - TEMPLATE_SIZE2 - 1;   // row removed from the column
    const int rowIn = y + TEMPLATE_SIZE2;        // row added to the column

    const pixel_t outPx = *(__global const pixel_t *)(src + mad24(rowOut, src_step, mad24(psz, col, src_offset)));
    const pixel_t inPx = *(__global const pixel_t *)(src + mad24(rowIn, src_step, mad24(psz, col, src_offset)));

    const int searchCol0 = col - SEARCH_SIZE2;
    const int rowOut0 = rowOut - SEARCH_SIZE2;
    const int rowIn0 = rowIn - SEARCH_SIZE2;

    for (int i = id; i < SEARCH_SIZE_SQ; i += CTA_SIZE)
    {
        const int dx = i % SEARCH_SIZE;
        const int dy = i / SEARCH_SIZE;
        const int colOffset = mad24(psz, searchCol0 + dx, src_offset);

        const pixel_t outPxT = *(__global const pixel_t *)(src + mad24(rowOut0 + dy, src_step, colOffset));
        const pixel_t inPxT = *(__global const pixel_t *)(src + mad24(rowIn0 + dy, src_step, colOffset));

        __global int * colSlot = col_dists + mad24(i, TEMPLATE_SIZE, first);
        __global int * upSlot = up_col_dists + mad24(x0, SEARCH_SIZE_SQ, i);

        const int colSum = *upSlot + calcDistUpDown(inPx, inPxT, outPx, outPxT);

        dists[i] += colSum - *colSlot;
        *colSlot = colSum;
        *upSlot = colSum;
    }
}
208
/*
 * Writes the denoised value of dst pixel (x, y): a weighted mean over the
 * SEARCH_SIZE x SEARCH_SIZE window of src centred on (x, y).
 *
 * Each work-item accumulates its candidates (i = id, id + CTA_SIZE, ...).
 * Candidate i gets weight almostDist2Weight[dists[i] >> almostTemplateWindowSizeSqBinShift].
 * The per-work-item partial sums are tree-reduced in local memory
 * (weights_local / weighted_sum_local, CTA_SIZE entries each) down to four
 * partials, which work-item 0 combines and stores to dst.
 *
 * Contains barriers, so every work-item of the work-group must call it.
 * NOTE(review): the reduction assumes CTA_SIZE is a power of two >= 4 -- confirm with host.
 */
inline void convolveWindow(__global const uchar * src, int src_step, int src_offset,
                           __local int * dists, __global const wlut_t * almostDist2Weight,
                           __global uchar * dst, int dst_step, int dst_offset,
                           int y, int x, int id, __local weight_t * weights_local,
                           __local sum_t * weighted_sum_local, int almostTemplateWindowSizeSqBinShift)
{
    int sx = x - SEARCH_SIZE2, sy = y - SEARCH_SIZE2;
    weight_t weights = (weight_t)0;
    sum_t weighted_sum = (sum_t)0;

    for (int i = id; i < SEARCH_SIZE_SQ; i += CTA_SIZE)
    {
        int src_index = mad24(sy + i / SEARCH_SIZE, src_step, mad24(i % SEARCH_SIZE + sx, psz, src_offset));
        sum_t src_value = convert_sum_t(*(__global const pixel_t *)(src + src_index));

        int almostAvgDist = dists[i] >> almostTemplateWindowSizeSqBinShift;
        weight_t weight = convert_weight_t(almostDist2Weight[almostAvgDist]);

        weights += weight;
        weighted_sum += (sum_t)weight * src_value;
    }

    weights_local[id] = weights;
    weighted_sum_local[id] = weighted_sum;
    barrier(CLK_LOCAL_MEM_FENCE);

    // Halve the number of active partial sums until four remain in slots 0..3.
    for (int lsize = CTA_SIZE >> 1; lsize > 2; lsize >>= 1)
    {
        if (id < lsize)
        {
           int id2 = lsize + id;
           weights_local[id] += weights_local[id2];
           weighted_sum_local[id] += weighted_sum_local[id2];
        }
        barrier(CLK_LOCAL_MEM_FENCE);
    }

    if (id == 0)
    {
        int dst_index = mad24(y, dst_step, mad24(psz, x, dst_offset));
        sum_t weighted_sum_local_0 = weighted_sum_local[0] + weighted_sum_local[1] +
            weighted_sum_local[2] + weighted_sum_local[3];
        weight_t weights_local_0 = weights_local[0] + weights_local[1] + weights_local[2] + weights_local[3];

        *(__global pixel_t *)(dst + dst_index) = convert_pixel_t(weighted_sum_local_0 / (sum_t)weights_local_0);
    }

    // Without this barrier, work-items 1..3 could return, start the next pixel and
    // overwrite weights_local[1..3] / weighted_sum_local[1..3] before work-item 0
    // has finished reading them above.
    barrier(CLK_LOCAL_MEM_FENCE);
}
256
/*
 * Non-local means denoising kernel.
 *
 * Each work-group handles the tile of dst with columns [x0, x1) and rows [y0, y1),
 * x0 = group_x * BLOCK_COLS and y0 = group_y * BLOCK_ROWS, clipped to
 * dst_cols / dst_rows. Pixels are visited in row-major order. For each pixel, the
 * work-items of the group share the SEARCH_SIZE_SQ search-window candidates.
 *
 * Template distances are maintained incrementally:
 *   - the first pixel of every tile row is computed from scratch;
 *   - the other pixels of the first tile row add one new template column;
 *   - later rows derive that new column from the sum saved for the row above.
 *
 * buffer: per-group scratch of SEARCH_SIZE_SQ * TEMPLATE_SIZE ints (col_dists)
 * followed by SEARCH_SIZE_SQ ints per tile column (up_col_dists).
 */
__kernel void fastNlMeansDenoising(__global const uchar * src, int src_step, int src_offset,
                                   __global uchar * dst, int dst_step, int dst_offset, int dst_rows, int dst_cols,
                                   __global const wlut_t * almostDist2Weight, __global uchar * buffer,
                                   int almostTemplateWindowSizeSqBinShift)
{
    __local int dists[SEARCH_SIZE_SQ];
    __local weight_t weights[CTA_SIZE];
    __local sum_t weighted_sum[CTA_SIZE];

    const int block_x = get_group_id(0);
    const int block_y = get_group_id(1);
    const int nblocks_x = get_num_groups(0);
    const int id = get_local_id(0);

    const int x0 = block_x * BLOCK_COLS;
    const int y0 = block_y * BLOCK_ROWS;
    const int x1 = min(x0 + BLOCK_COLS, dst_cols);
    const int y1 = min(y0 + BLOCK_ROWS, dst_rows);

    // No tile column lies inside dst: nothing to compute. These bounds depend only on
    // the group id, so the whole work-group takes this exit together.
    if (x1 <= x0)
        return;

    // Per-group scratch, offset in ints: SEARCH_SIZE_SQ * TEMPLATE_SIZE partial column
    // sums for the current pixel, then SEARCH_SIZE_SQ * BLOCK_COLS last-column sums for
    // every search-window candidate of the row above.
    const int block_data_start = SEARCH_SIZE_SQ *
        (mad24(block_y, dst_cols, x0) + mad24(block_y, nblocks_x, block_x) * TEMPLATE_SIZE);
    __global int * col_dists = (__global int *)(buffer + block_data_start * sizeof(int));
    __global int * up_col_dists = col_dists + SEARCH_SIZE_SQ * TEMPLATE_SIZE;

    for (int y = y0; y < y1; ++y)
    {
        // col_dists slot currently holding the leftmost (oldest) template column.
        int first = 0;

        // First pixel of the tile row: every template column is computed from scratch.
        calcFirstElementInRow(src, src_step, src_offset, dists, y, x0, id, col_dists, up_col_dists);
        convolveWindow(src, src_step, src_offset, dists, almostDist2Weight, dst, dst_step, dst_offset,
                       y, x0, id, weights, weighted_sum, almostTemplateWindowSizeSqBinShift);

        // Remaining pixels: the template slides right by one column each step.
        for (int x = x0 + 1; x < x1; ++x)
        {
            const int tile_col = x - x0;

            if (y == y0)
                calcElementInFirstRow(src, src_step, src_offset, dists, y, tile_col, x, id, first, col_dists, up_col_dists);
            else
                calcElement(src, src_step, src_offset, dists, y, tile_col, x, id, first, col_dists, up_col_dists);

            first = (first + 1) % TEMPLATE_SIZE;

            convolveWindow(src, src_step, src_offset, dists, almostDist2Weight, dst, dst_step, dst_offset,
                           y, x, id, weights, weighted_sum, almostTemplateWindowSizeSqBinShift);
        }
    }
}

#endif
303