/*
 * function: kernel_3d_denoise
 *     3D (spatial + temporal) noise reduction.
 *
 * Every image texel is read with read_imageui and reinterpreted as 8 unsigned
 * 8-bit pixels (4 x 16-bit channels -> uchar8); each work-item restores one
 * texel, i.e. two independent 4-pixel halves (.lo / .hi).
 *
 * gain: filtering strength for the reference blocks; a block contributes with
 *       weight native_exp(-gain * distance), and the gain is doubled for a
 *       4-pixel half whose local gradient reaches `threshold`.
 * threshold: gradient threshold that selects the stronger filtering
 *            (described upstream as the noise variance of the observed image).
 * restoredPrev: the previous restored image, image2d_t as read only
 * output: restored image, image2d_t as write only
 * input: observed image, image2d_t as read only
 * inputPrev1: reference image, image2d_t as read only
 * inputPrev2: reference image, image2d_t as read only
 */

#ifndef REFERENCE_FRAME_COUNT
#define REFERENCE_FRAME_COUNT 2
#endif

/*
 * NOTE: the misspelled name is kept on purpose. It is a build-time override,
 * presumably passed by the host as a -D option; renaming it would silently
 * ignore such options.
 */
#ifndef ENABLE_IIR_FILERING
#define ENABLE_IIR_FILERING 1
#endif

#define ENABLE_GRADIENT 1

#ifndef WORKGROUP_WIDTH
#define WORKGROUP_WIDTH 2
#endif

#ifndef WORKGROUP_HEIGHT
#define WORKGROUP_HEIGHT 32
#endif

/* Halo around the work-group tile: texels in x, rows in y. */
#define REF_BLOCK_X_OFFSET 1
#define REF_BLOCK_Y_OFFSET 4

#define REF_BLOCK_WIDTH (WORKGROUP_WIDTH + 2 * REF_BLOCK_X_OFFSET)
#define REF_BLOCK_HEIGHT (WORKGROUP_HEIGHT + 2 * REF_BLOCK_Y_OFFSET)

/*
 * Maps (sub-group id, lane) to a position inside the work-group tile: the 8
 * lanes of a sub-group cover a 2-wide x 4-tall block, and consecutive
 * sub-groups alternate between the left and right 2-wide column pair.
 * Only used when WORKGROUP_WIDTH == 4.
 */
inline int2 subgroup_pos(const int sg_id, const int sg_lid)
{
    int2 pos;
    pos.x = mad24(2, sg_id % 2, sg_lid % 2);
    pos.y = mad24(4, sg_id / 2, sg_lid / 2);

    return pos;
}

/*
 * Accumulates one reference slice into the running weighted average.
 *
 * ref, observe: 8 pixels each, processed as two independent 4-pixel halves.
 * restore, sum_weight: running weighted sum (per pixel) and total weight
 *                      (per half); updated in place.
 * gain: base filter strength; doubled for a half whose gradient >= threshold.
 * sg_id: unused here; kept so the helpers share one argument list.
 * sg_lid: lane index in [0, 8) (lane % 2 = column, lane / 2 = row of the
 *         2x4 lane block).
 *
 * Uses sub-group shuffles, so every lane of the sub-group must call it.
 * NOTE(review): the fixed source lanes 0..7 assume a sub-group size of 8
 * (sg_lid is computed modulo 8 by the kernel); nothing here enforces it.
 */
inline void average_slice(float8 ref,
                          float8 observe,
                          float8* restore,
                          float2* sum_weight,
                          float gain,
                          float threshold,
                          uint sg_id,
                          uint sg_lid)
{
    const bool even_lane = (sg_lid % 2 == 0);
    float8 grad = 0.0f;

#if ENABLE_GRADIENT
    /*
     * Gradient: compare every pixel with pixel 1 (low half) / pixel 5 (high
     * half) of lane 4 (even lanes) or lane 5 (odd lanes). Both shuffles are
     * executed by all lanes so the source lane is always active.
     */
    const float8 grad_src_even = intel_sub_group_shuffle(ref, 4);
    const float8 grad_src_odd = intel_sub_group_shuffle(ref, 5);
    float8 grad_src;
    if (even_lane) {
        grad_src = grad_src_even;
    } else {
        grad_src = grad_src_odd;
    }
    const float8 gradient = (float8)(grad_src.s1, grad_src.s1, grad_src.s1, grad_src.s1,
                                     grad_src.s5, grad_src.s5, grad_src.s5, grad_src.s5);

    // normalize gradient "1/(4*255.0f) = 0.00098039f"
    grad = fabs(gradient - ref) * 0.00098039f;

    // Normalized mean absolute gradient of each 4-pixel half.
    grad.s0 = (grad.s0 + grad.s1 + grad.s2 + grad.s3);
    grad.s4 = (grad.s4 + grad.s5 + grad.s6 + grad.s7);
#endif

    // calculate & normalize distance "1/255.0f = 0.00392157f"
    float8 dist = (observe - ref) * 0.00392157f;
    dist = dist * dist;

    /*
     * Sum the squared differences of the 4 lanes sharing this lane's column
     * (lanes 0,2,4,6 or 1,3,5,7). All shuffles run in converged control flow
     * before the parity-dependent selection.
     */
    const float8 d0 = intel_sub_group_shuffle(dist, 0);
    const float8 d1 = intel_sub_group_shuffle(dist, 1);
    const float8 d2 = intel_sub_group_shuffle(dist, 2);
    const float8 d3 = intel_sub_group_shuffle(dist, 3);
    const float8 d4 = intel_sub_group_shuffle(dist, 4);
    const float8 d5 = intel_sub_group_shuffle(dist, 5);
    const float8 d6 = intel_sub_group_shuffle(dist, 6);
    const float8 d7 = intel_sub_group_shuffle(dist, 7);

    float8 distance;
    if (even_lane) {
        distance = d0 + d2 + d4 + d6;
    } else {
        distance = d1 + d3 + d5 + d7;
    }

    // Patch distance of each 4-pixel half.
    const float dist_lo = (distance.s0 + distance.s1 + distance.s2 + distance.s3);
    const float dist_hi = (distance.s4 + distance.s5 + distance.s6 + distance.s7);

    /*
     * Each half selects its strength independently. Deriving the high half's
     * gain from one already doubled for the low half would make it up to 4x
     * stronger depending on unrelated pixels.
     */
    const float gain_lo = (grad.s0 < threshold) ? gain : 2.0f * gain;
    const float gain_hi = (grad.s4 < threshold) ? gain : 2.0f * gain;

    float weight = native_exp(-gain_lo * dist_lo);
    (*restore).lo = mad(weight, ref.lo, (*restore).lo);
    (*sum_weight).lo = (*sum_weight).lo + weight;

    weight = native_exp(-gain_hi * dist_hi);
    (*restore).hi = mad(weight, ref.hi, (*restore).hi);
    (*sum_weight).hi = (*sum_weight).hi + weight;
}

/*
 * Reads 8 consecutive pixels starting at a uchar4-granular offset of the local
 * reference tile. vload8 only requires byte alignment; dereferencing a
 * __local uchar8* at an odd uchar4 offset would be a misaligned vector access.
 */
inline float8 load_ref_slice(__local const uchar4* p_ref, int offset)
{
    return convert_float8(vload8(0, (__local const uchar*)(p_ref + offset)));
}

/*
 * Caches the reference tile of `input` (work-group tile plus a halo of
 * REF_BLOCK_X_OFFSET texels and REF_BLOCK_Y_OFFSET rows) in ref_cache, then
 * accumulates the slices shifted by -4/0/+4 pixels horizontally and
 * -REF_BLOCK_Y_OFFSET/0/+REF_BLOCK_Y_OFFSET rows into restore/sum_weight.
 *
 * load_observe: when true, the centre texel of the tile becomes the observed
 *   pixels, restore/sum_weight are initialised with it at weight 1, and the
 *   unshifted slice is skipped (it would compare observe with itself).
 *
 * Contains work-group barriers: every work-item must call it, and ref_cache
 * may be reused by successive calls.
 */
inline void weighted_average (__read_only image2d_t input,
                              __local uchar8* ref_cache,
                              bool load_observe,
                              float8* observe,
                              float8* restore,
                              float2* sum_weight,
                              float gain,
                              float threshold,
                              uint sg_id,
                              uint sg_lid)
{
    sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP_TO_EDGE | CLK_FILTER_NEAREST;

    int local_id_x = get_local_id(0);
    int local_id_y = get_local_id(1);
    const int group_id_x = get_group_id(0);
    const int group_id_y = get_group_id(1);

    const int start_x = mad24(group_id_x, WORKGROUP_WIDTH, -REF_BLOCK_X_OFFSET);
    const int start_y = mad24(group_id_y, WORKGROUP_HEIGHT, -REF_BLOCK_Y_OFFSET);

    /*
     * ref_cache is shared by successive calls: wait until every work-item has
     * finished reading the previous tile before it is overwritten.
     */
    barrier(CLK_LOCAL_MEM_FENCE);

    // Cooperative load of the tile, strided by the work-group size.
    const int first = local_id_x + local_id_y * WORKGROUP_WIDTH;
    for (int j = first; j < (REF_BLOCK_HEIGHT * REF_BLOCK_WIDTH);
            j += (WORKGROUP_HEIGHT * WORKGROUP_WIDTH)) {
        const int coord_x = start_x + (j % REF_BLOCK_WIDTH);
        const int coord_y = start_y + (j / REF_BLOCK_WIDTH);

        ref_cache[j] = as_uchar8(convert_ushort4(read_imageui(input,
                                 sampler,
                                 (int2)(coord_x, coord_y))));
    }
    barrier(CLK_LOCAL_MEM_FENCE);

#if WORKGROUP_WIDTH == 4
    int2 pos = subgroup_pos(sg_id, sg_lid);
    local_id_x = pos.x;
    local_id_y = pos.y;
#endif

    if (load_observe) {
        (*observe) = convert_float8(
                         ref_cache[mad24(local_id_y + REF_BLOCK_Y_OFFSET,
                                         REF_BLOCK_WIDTH,
                                         local_id_x + REF_BLOCK_X_OFFSET)]);
        (*restore) = (*observe);
        // The observed pixels carry weight 1, so sum_weight never drops below 1.
        (*sum_weight) = 1.0f;
    }

    __local const uchar4* p_ref = (__local const uchar4*)(ref_cache);

    // Offsets in uchar4 (4-pixel) units; a cache row holds 2 * REF_BLOCK_WIDTH.
    const int row_stride = 2 * REF_BLOCK_WIDTH;
    const int col_left = mad24(2, local_id_x, 1);   // 4 pixels left of centre
    const int col_right = mad24(2, local_id_x, 3);  // 4 pixels right of centre
    const int row_top = local_id_y;
    const int row_mid = local_id_y + REF_BLOCK_Y_OFFSET;
    const int row_bottom = local_id_y + 2 * REF_BLOCK_Y_OFFSET;

    float8 ref_left;
    float8 ref_right;

    // top-left
    ref_left = load_ref_slice(p_ref, mad24(row_top, row_stride, col_left));
    average_slice(ref_left, *observe, restore, sum_weight, gain, threshold, sg_id, sg_lid);

    // top-right
    ref_right = load_ref_slice(p_ref, mad24(row_top, row_stride, col_right));
    average_slice(ref_right, *observe, restore, sum_weight, gain, threshold, sg_id, sg_lid);

    // top-mid
    average_slice((float8)(ref_left.hi, ref_right.lo), *observe, restore, sum_weight, gain, threshold, sg_id, sg_lid);

    // mid-left
    ref_left = load_ref_slice(p_ref, mad24(row_mid, row_stride, col_left));
    average_slice(ref_left, *observe, restore, sum_weight, gain, threshold, sg_id, sg_lid);

    // mid-right
    ref_right = load_ref_slice(p_ref, mad24(row_mid, row_stride, col_right));
    average_slice(ref_right, *observe, restore, sum_weight, gain, threshold, sg_id, sg_lid);

    // mid-mid: identical to observe for the current frame, already counted.
    if (!load_observe) {
        average_slice((float8)(ref_left.hi, ref_right.lo), *observe, restore, sum_weight, gain, threshold, sg_id, sg_lid);
    }

    // bottom-left
    ref_left = load_ref_slice(p_ref, mad24(row_bottom, row_stride, col_left));
    average_slice(ref_left, *observe, restore, sum_weight, gain, threshold, sg_id, sg_lid);

    // bottom-right
    ref_right = load_ref_slice(p_ref, mad24(row_bottom, row_stride, col_right));
    average_slice(ref_right, *observe, restore, sum_weight, gain, threshold, sg_id, sg_lid);

    // bottom-mid
    average_slice((float8)(ref_left.hi, ref_right.lo), *observe, restore, sum_weight, gain, threshold, sg_id, sg_lid);
}

__kernel void kernel_3d_denoise ( float gain,
                                  float threshold,
                                  __read_only image2d_t restoredPrev,
                                  __write_only image2d_t output,
                                  __read_only image2d_t input,
                                  __read_only image2d_t inputPrev1,
                                  __read_only image2d_t inputPrev2)
{
    float8 restore = 0.0f;
    float8 observe = 0.0f;
    float2 sum_weight = 0.0f;

    const int sg_id = get_sub_group_id();
    // NOTE(review): assumes 8-lane sub-groups; see average_slice.
    const int sg_lid = (get_local_id(1) * WORKGROUP_WIDTH + get_local_id(0)) % 8;

    __local uchar8 ref_cache[REF_BLOCK_HEIGHT * REF_BLOCK_WIDTH];

    // The current frame is always filtered; it also provides `observe`.
    weighted_average (input, ref_cache, true, &observe, &restore, &sum_weight, gain, threshold, sg_id, sg_lid);

#if ENABLE_IIR_FILERING
    weighted_average (restoredPrev, ref_cache, false, &observe, &restore, &sum_weight, gain, threshold, sg_id, sg_lid);
#else
#if REFERENCE_FRAME_COUNT > 1
    weighted_average (inputPrev1, ref_cache, false, &observe, &restore, &sum_weight, gain, threshold, sg_id, sg_lid);
#endif

#if REFERENCE_FRAME_COUNT > 2
    weighted_average (inputPrev2, ref_cache, false, &observe, &restore, &sum_weight, gain, threshold, sg_id, sg_lid);
#endif
#endif

    // sum_weight >= 1 (the observed pixels have weight 1), so this is safe.
    restore.lo = restore.lo / sum_weight.lo;
    restore.hi = restore.hi / sum_weight.hi;

    int local_id_x = get_local_id(0);
    int local_id_y = get_local_id(1);
    const int group_id_x = get_group_id(0);
    const int group_id_y = get_group_id(1);

#if WORKGROUP_WIDTH == 4
    int2 pos = subgroup_pos(sg_id, sg_lid);
    local_id_x = pos.x;
    local_id_y = pos.y;
#endif

    const int coor_x = mad24(group_id_x, WORKGROUP_WIDTH, local_id_x);
    const int coor_y = mad24(group_id_y, WORKGROUP_HEIGHT, local_id_y);

    /*
     * Writes outside the image are undefined; the guard is placed after all
     * barriers, so skipping the write cannot deadlock the work-group.
     * Round to nearest and saturate instead of truncating: truncation biases
     * every pixel downwards, and with IIR filtering the previous restored
     * image is reused as a reference.
     */
    if (coor_x < get_image_width(output) && coor_y < get_image_height(output)) {
        write_imageui(output,
                      (int2)(coor_x, coor_y),
                      convert_uint4(as_ushort4(convert_uchar8_sat_rte(restore))));
    }
}