1/*
2 * function: kernel_wavelet_coeff_variance
3 *     Calculate wavelet coefficients variance
4 * input:  Wavelet coefficients as read only
5 * output: Wavelet coefficients variance
6 */
7
8#ifndef WAVELET_DENOISE_Y
9#define WAVELET_DENOISE_Y 1
10#endif
11
12#ifndef WAVELET_DENOISE_UV
13#define WAVELET_DENOISE_UV 0
14#endif
15
16#define WG_CELL_X_SIZE 8
17#define WG_CELL_Y_SIZE 8
18
19#define SLM_CELL_X_OFFSET 1
20#define SLM_CELL_Y_OFFSET 2
21
22// 10x12
23#define SLM_CELL_X_SIZE (WG_CELL_X_SIZE + SLM_CELL_X_OFFSET * 2)
24#define SLM_CELL_Y_SIZE (WG_CELL_Y_SIZE + SLM_CELL_Y_OFFSET * 2)
25
26__kernel void kernel_wavelet_coeff_variance (__read_only image2d_t input, __write_only image2d_t output, int layer)
27{
28    sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP_TO_EDGE | CLK_FILTER_NEAREST;
29
30    int g_id_x = get_global_id (0);
31    int g_id_y = get_global_id (1);
32
33    int group_id_x = get_group_id(0);
34    int group_id_y = get_group_id(1);
35
36    int local_id_x = get_local_id(0);
37    int local_id_y = get_local_id(1);
38
39    int g_size_x = get_global_size (0);
40    int g_size_y = get_global_size (1);
41
42    int l_size_x = get_local_size (0);
43    int l_size_y = get_local_size (1);
44
45    int local_index = local_id_y * WG_CELL_X_SIZE + local_id_x;
46
47    float offset = 0.5f;
48    float4 line_sum[5] = {0.0f, 0.0f, 0.0f, 0.0f, 0.0f};
49    float4 line_var = 0.0f;
50
51    __local float4 local_src_data[SLM_CELL_X_SIZE * SLM_CELL_Y_SIZE];
52
53    int i = local_id_x + local_id_y * WG_CELL_X_SIZE;
54    int start_x = mad24(group_id_x, WG_CELL_X_SIZE, -SLM_CELL_X_OFFSET);
55    int start_y = mad24(group_id_y, WG_CELL_Y_SIZE, -SLM_CELL_Y_OFFSET);
56
57    for (int j = i;  j < SLM_CELL_X_SIZE * SLM_CELL_Y_SIZE; j += WG_CELL_X_SIZE * WG_CELL_Y_SIZE)
58    {
59        int x = start_x + (j % SLM_CELL_X_SIZE);
60        int y = start_y + (j / SLM_CELL_X_SIZE);
61        local_src_data[j] = read_imagef (input, sampler, (int2)(x, y)) - offset;
62    }
63    barrier(CLK_LOCAL_MEM_FENCE);
64
65    float16 line0 = *((__local float16 *)(local_src_data + local_id_y * SLM_CELL_X_SIZE + local_id_x));
66    float16 line1 = *((__local float16 *)(local_src_data + (local_id_y + 1) * SLM_CELL_X_SIZE + local_id_x));
67    float16 line2 = *((__local float16 *)(local_src_data + (local_id_y + 2) * SLM_CELL_X_SIZE + local_id_x));
68    float16 line3 = *((__local float16 *)(local_src_data + (local_id_y + 3) * SLM_CELL_X_SIZE + local_id_x));
69    float16 line4 = *((__local float16 *)(local_src_data + (local_id_y + 4) * SLM_CELL_X_SIZE + local_id_x));
70
71#if WAVELET_DENOISE_Y
72    line_sum[0] = mad(line0.s0123, line0.s0123, line_sum[0]);
73    line_sum[0] = mad(line0.s1234, line0.s1234, line_sum[0]);
74    line_sum[0] = mad(line0.s2345, line0.s2345, line_sum[0]);
75    line_sum[0] = mad(line0.s3456, line0.s3456, line_sum[0]);
76    line_sum[0] = mad(line0.s4567, line0.s4567, line_sum[0]);
77    line_sum[0] = mad(line0.s5678, line0.s5678, line_sum[0]);
78    line_sum[0] = mad(line0.s6789, line0.s6789, line_sum[0]);
79    line_sum[0] = mad(line0.s789a, line0.s789a, line_sum[0]);
80    line_sum[0] = mad(line0.s89ab, line0.s89ab, line_sum[0]);
81
82    line_sum[1] = mad(line1.s0123, line1.s0123, line_sum[1]);
83    line_sum[1] = mad(line1.s1234, line1.s1234, line_sum[1]);
84    line_sum[1] = mad(line1.s2345, line1.s2345, line_sum[1]);
85    line_sum[1] = mad(line1.s3456, line1.s3456, line_sum[1]);
86    line_sum[1] = mad(line1.s4567, line1.s4567, line_sum[1]);
87    line_sum[1] = mad(line1.s5678, line1.s5678, line_sum[1]);
88    line_sum[1] = mad(line1.s6789, line1.s6789, line_sum[1]);
89    line_sum[1] = mad(line1.s789a, line1.s789a, line_sum[1]);
90    line_sum[1] = mad(line1.s89ab, line1.s89ab, line_sum[1]);
91
92    line_sum[2] = mad(line2.s0123, line2.s0123, line_sum[2]);
93    line_sum[2] = mad(line2.s1234, line2.s1234, line_sum[2]);
94    line_sum[2] = mad(line2.s2345, line2.s2345, line_sum[2]);
95    line_sum[2] = mad(line2.s3456, line2.s3456, line_sum[2]);
96    line_sum[2] = mad(line2.s4567, line2.s4567, line_sum[2]);
97    line_sum[2] = mad(line2.s5678, line2.s5678, line_sum[2]);
98    line_sum[2] = mad(line2.s6789, line2.s6789, line_sum[2]);
99    line_sum[2] = mad(line2.s789a, line2.s789a, line_sum[2]);
100    line_sum[2] = mad(line2.s89ab, line2.s89ab, line_sum[2]);
101
102    line_sum[3] = mad(line3.s0123, line3.s0123, line_sum[3]);
103    line_sum[3] = mad(line3.s1234, line3.s1234, line_sum[3]);
104    line_sum[3] = mad(line3.s2345, line3.s2345, line_sum[3]);
105    line_sum[3] = mad(line3.s3456, line3.s3456, line_sum[3]);
106    line_sum[3] = mad(line3.s4567, line3.s4567, line_sum[3]);
107    line_sum[3] = mad(line3.s5678, line3.s5678, line_sum[3]);
108    line_sum[3] = mad(line3.s6789, line3.s6789, line_sum[3]);
109    line_sum[3] = mad(line3.s789a, line3.s789a, line_sum[3]);
110    line_sum[3] = mad(line3.s89ab, line3.s89ab, line_sum[3]);
111
112    line_sum[4] = mad(line4.s0123, line4.s0123, line_sum[4]);
113    line_sum[4] = mad(line4.s1234, line4.s1234, line_sum[4]);
114    line_sum[4] = mad(line4.s2345, line4.s2345, line_sum[4]);
115    line_sum[4] = mad(line4.s3456, line4.s3456, line_sum[4]);
116    line_sum[4] = mad(line4.s4567, line4.s4567, line_sum[4]);
117    line_sum[4] = mad(line4.s5678, line4.s5678, line_sum[4]);
118    line_sum[4] = mad(line4.s6789, line4.s6789, line_sum[4]);
119    line_sum[4] = mad(line4.s789a, line4.s789a, line_sum[4]);
120    line_sum[4] = mad(line4.s89ab, line4.s89ab, line_sum[4]);
121
122    line_var = (line_sum[0] + line_sum[1] + line_sum[2] + line_sum[3] + line_sum[4]) / 45;
123#endif
124
125#if WAVELET_DENOISE_UV
126    line_sum[0] = mad(line0.s0123, line0.s0123, line_sum[0]);
127    line_sum[0] = mad(line0.s2345, line0.s2345, line_sum[0]);
128    line_sum[0] = mad(line0.s4567, line0.s4567, line_sum[0]);
129    line_sum[0] = mad(line0.s6789, line0.s6789, line_sum[0]);
130    line_sum[0] = mad(line0.s89ab, line0.s89ab, line_sum[0]);
131    line_sum[0] = mad(line0.sabcd, line0.sabcd, line_sum[0]);
132    line_sum[0] = mad(line0.scdef, line0.scdef, line_sum[0]);
133
134    line_sum[1] = mad(line1.s0123, line1.s0123, line_sum[1]);
135    line_sum[1] = mad(line1.s2345, line1.s2345, line_sum[1]);
136    line_sum[1] = mad(line1.s4567, line1.s4567, line_sum[1]);
137    line_sum[1] = mad(line1.s6789, line1.s6789, line_sum[1]);
138    line_sum[1] = mad(line1.s89ab, line1.s89ab, line_sum[1]);
139    line_sum[1] = mad(line1.sabcd, line1.sabcd, line_sum[1]);
140    line_sum[1] = mad(line1.scdef, line1.scdef, line_sum[1]);
141
142    line_sum[2] = mad(line2.s0123, line2.s0123, line_sum[2]);
143    line_sum[2] = mad(line2.s2345, line2.s2345, line_sum[2]);
144    line_sum[2] = mad(line2.s4567, line2.s4567, line_sum[2]);
145    line_sum[2] = mad(line2.s6789, line2.s6789, line_sum[2]);
146    line_sum[2] = mad(line2.s89ab, line2.s89ab, line_sum[2]);
147    line_sum[2] = mad(line2.sabcd, line2.sabcd, line_sum[2]);
148    line_sum[2] = mad(line2.scdef, line2.scdef, line_sum[2]);
149
150    line_sum[3] = mad(line3.s0123, line3.s0123, line_sum[3]);
151    line_sum[3] = mad(line3.s2345, line3.s2345, line_sum[3]);
152    line_sum[3] = mad(line3.s4567, line3.s4567, line_sum[3]);
153    line_sum[3] = mad(line3.s6789, line3.s6789, line_sum[3]);
154    line_sum[3] = mad(line3.s89ab, line3.s89ab, line_sum[3]);
155    line_sum[3] = mad(line3.sabcd, line3.sabcd, line_sum[3]);
156    line_sum[3] = mad(line3.scdef, line3.scdef, line_sum[3]);
157
158    line_sum[4] = mad(line4.s0123, line4.s0123, line_sum[4]);
159    line_sum[4] = mad(line4.s2345, line4.s2345, line_sum[4]);
160    line_sum[4] = mad(line4.s4567, line4.s4567, line_sum[4]);
161    line_sum[4] = mad(line4.s6789, line4.s6789, line_sum[4]);
162    line_sum[4] = mad(line4.s89ab, line4.s89ab, line_sum[4]);
163    line_sum[4] = mad(line4.sabcd, line4.sabcd, line_sum[4]);
164    line_sum[4] = mad(line4.scdef, line4.scdef, line_sum[4]);
165
166    line_var = ((line_sum[0] + line_sum[1] + line_sum[2] + line_sum[3] + line_sum[4]) / 35);
167#endif
168
169    write_imagef(output, (int2)(g_id_x, g_id_y), line_var);
170}
171
172/*
173 * function: kernel_wavelet_coeff_thresholding
174 *     wavelet coefficient thresholding kernel
175 * hl/lh/hh:  wavelet coefficients
176 * layer:        wavelet decomposition layer
177 * decomLevels:  wavelet decomposition levels
178 */
179
180__kernel void kernel_wavelet_coeff_thresholding (float noise_var1, float noise_var2,
181        __read_only image2d_t in_hl, __read_only image2d_t var_hl, __write_only image2d_t out_hl,
182        __read_only image2d_t in_lh, __read_only image2d_t var_lh, __write_only image2d_t out_lh,
183        __read_only image2d_t in_hh, __read_only image2d_t var_hh, __write_only image2d_t out_hh,
184        int layer, int decomLevels,
185        float hardThresh, float softThresh, float ag_weight)
186{
187    int x = get_global_id (0);
188    int y = get_global_id (1);
189    sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP_TO_EDGE | CLK_FILTER_NEAREST;
190
191    float4 input_hl;
192    float4 input_lh;
193    float4 input_hh;
194
195    float4 output_hl;
196    float4 output_lh;
197    float4 output_hh;
198
199    float4 coeff_var_hl;
200    float4 coeff_var_lh;
201    float4 coeff_var_hh;
202
203    float4 stddev_hl;
204    float4 stddev_lh;
205    float4 stddev_hh;
206
207    float4 thresh_hl;
208    float4 thresh_lh;
209    float4 thresh_hh;
210
211    float4 noise_var = (float4) (noise_var1, noise_var2, noise_var1, noise_var2);
212
213    input_hl = read_imagef(in_hl, sampler, (int2)(x, y)) - 0.5f;
214    input_lh = read_imagef(in_lh, sampler, (int2)(x, y)) - 0.5f;
215    input_hh = read_imagef(in_hh, sampler, (int2)(x, y)) - 0.5f;
216
217    coeff_var_hl = 65025 * (1 << 2 * layer) * read_imagef(var_hl, sampler, (int2)(x, y));
218    coeff_var_lh = 65025 * (1 << 2 * layer) * read_imagef(var_lh, sampler, (int2)(x, y));
219    coeff_var_hh = 65025 * (1 << 2 * layer) * read_imagef(var_hh, sampler, (int2)(x, y));
220
221    stddev_hl = coeff_var_hl - noise_var;
222    stddev_hl = (stddev_hl > 0) ? sqrt(stddev_hl) : 0.000001f;
223
224    stddev_lh = coeff_var_lh - noise_var;
225    stddev_lh = (stddev_lh > 0) ? sqrt(stddev_lh) : 0.000001f;
226
227    stddev_hh = coeff_var_hh - noise_var;
228    stddev_hh = (stddev_hh > 0) ? sqrt(stddev_hh) : 0.000001f;
229
230    thresh_hl = (ag_weight * noise_var / stddev_hl) / (255 * (1 << layer));
231    thresh_lh = (ag_weight * noise_var / stddev_lh) / (255 * (1 << layer));
232    thresh_hh = (ag_weight * noise_var / stddev_hh) / (255 * (1 << layer));
233
234    // Soft thresholding
235    output_hl = (fabs(input_hl) < thresh_hl) ? 0 : ((input_hl > 0) ? fabs(input_hl) - thresh_hl : thresh_hl - fabs(input_hl));
236    output_lh = (fabs(input_lh) < thresh_lh) ? 0 : ((input_lh > 0) ? fabs(input_lh) - thresh_lh : thresh_lh - fabs(input_lh));
237    output_hh = (fabs(input_hh) < thresh_hh) ? 0 : ((input_hh > 0) ? fabs(input_hh) - thresh_hh : thresh_hh - fabs(input_hh));
238
239    write_imagef(out_hl, (int2)(x, y), output_hl + 0.5f);
240    write_imagef(out_lh, (int2)(x, y), output_lh + 0.5f);
241    write_imagef(out_hh, (int2)(x, y), output_hh + 0.5f);
242}
243
244