1// Compute shader to convert ASTC textures to BC3 (ie: BC1 for color + BC4 for alpha).
2//
3// A bit of history
4// ----------------
5//
6// The algorithm used here for BC1 compression has a long history. It was originally published by
7// Simon Brown for the Squish encoder:
8//   https://www.sjbrown.co.uk/posts/dxt-compression-techniques/
9//   https://github.com/svn2github/libsquish/blob/c763145a30512c10450954b7a2b5b3a2f9a94e00/rangefit.cpp#L33
10//
11// It was then rewritten and improved upon by Fabian "ryg" Giesen for the stb_dxt encoder:
12//   https://github.com/GammaUNC/FasTC/blob/0f8cef65cf8f0fc5c58a2d06af3e0c3ad2374678/DXTEncoder/src/stb_dxt.h#L283
13//   https://fgiesen.wordpress.com/2022/11/08/whats-that-magic-computation-in-stb__refineblock/
14//
15// That version then made it to many places, including ANGLE, first as a C++ version:
16//   https://source.corp.google.com/android/external/angle/src/image_util/loadimage_etc.cpp;l=1073;bpv=0;bpt=0;rcl=90f88d3bc0d38ef5ec06ddaaef230db2d6e6fc02
17//
18// and then as a compute shader version upon which this shader is based:
19//   http://cs/android/external/angle/src/libANGLE/renderer/vulkan/shaders/src/EtcToBc.comp;rcl=81e45c881c54a7737f6fce95097f6df2f94cd76f
20//
21//
22// Useful links to understand BC1 compression
23// ------------------------------------------
24//
25//  http://www.ludicon.com/castano/blog/2022/11/bc1-compression-revisited/
26//  https://github.com/castano/icbc
27//  https://developer.download.nvidia.com/compute/cuda/1.1-Beta/x86_website/projects/dxtc/doc/cuda_dxtc.pdf
28//  https://fgiesen.wordpress.com/2022/11/08/whats-that-magic-computation-in-stb__refineblock/
29//  https://www.reedbeta.com/blog/understanding-bcn-texture-compression-formats/
30//  https://bartwronski.com/2020/05/21/dimensionality-reduction-for-image-and-texture-set-compression/
31//  https://core.ac.uk/download/pdf/210601023.pdf
32//  https://github.com/microsoft/Xbox-ATG-Samples/blob/main/XDKSamples/Graphics/FastBlockCompress/Shaders/BlockCompress.hlsli
33//  https://github.com/GammaUNC/FasTC/blob/0f8cef65cf8f0fc5c58a2d06af3e0c3ad2374678/DXTEncoder/src/stb_dxt.h
34//  https://github.com/darksylinc/betsy/blob/master/bin/Data/bc1.glsl
35//  https://github.com/GPUOpen-Tools/compressonator/blob/master/cmp_core/shaders/bc1_cmp.h
36//
37//
38// Optimization ideas
39// ------------------
40//
41// - Do the color refinement step from stb_dxt. This is probably the top priority. Currently, we
42//   only do the PCA step and we use the min and max colors as the endpoints. We should instead see
43//   if picking other endpoints on the PCA line would lead to better results.
44//
45// - Use dithering to improve quality. Betsy and FasTC encoders (links above) have examples.
46//
47// - Add a fast path for when all pixels are the same color (speed improvement)
48//
// - Use BC1 instead of BC3 if the image doesn't contain semi-transparent pixels. We will need to
//   add a pre-processing step to determine if there are such pixels. Alternatively, it could be
//   done fairly efficiently as a post-processing step where we discard the BC4 data if all pixels
//   are opaque; however, in that case it would only work for fully opaque images (i.e. we wouldn't
//   be able to take advantage of BC1's punch-through alpha).
54//
55// To-do list
56// ---------------
57//   - TODO(gregschlom): Check that the GPU has gl_SubgroupSize >= 16 before using this shader,
58//     otherwise it will give wrong results.
59//
60//   - TODO(gregschlom): Check if the results are correct for image sizes that aren't multiples of 4
61
62#version 450 core
63#include "AstcDecompressor.glsl"
64#include "Common.comp"
65
// TODO(gregschlom): Check how widespread support for these extensions is.
67#extension GL_KHR_shader_subgroup_clustered : enable
68#extension GL_KHR_shader_subgroup_shuffle : enable
69
// To maximize GPU utilization, we use a local workgroup size of 64 which is a multiple of the
// subgroup size of both AMD and NVIDIA cards.
// Each workgroup handles an 8x8 texel tile, i.e. four 4x4 BC3 blocks of 16 invocations each
// (see the texel coordinate computation in main()).
layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;

// Using 2DArray textures for compatibility with the old ASTC decoder.
// TODO(gregschlom): Once we have texture metrics, check if we want to keep supporting array
// textures.
//
// srcImage: indexed by ASTC block coordinates; each rgba32ui texel is loaded as one 128-bit ASTC
//           block (see decodeRGBA()).
// dstImage: indexed by 4x4 block coordinates; each rgba32ui texel receives one 128-bit BC3 block
//           (see main()).
// WITH_TYPE is not defined in this file (presumably in Common.comp) -- NOTE(review): confirm it
// selects the 2D vs. 2DArray image variant.
layout(binding = 0, rgba32ui) readonly uniform WITH_TYPE(uimage) srcImage;
layout(binding = 1, rgba32ui) writeonly uniform WITH_TYPE(uimage) dstImage;

// Per-dispatch parameters supplied by the host.
layout(push_constant) uniform imagInfo {
    uvec2 blockSize;  // ASTC block footprint in texels; texel -> block coordinate divisor.
    uint baseLayer;   // First array layer; gl_WorkGroupID.z is added to it to pick the layer.
    uint smallBlock;  // TODO(gregschlom) Remove this once we remove the old decoder.
}
u_pushConstant;
85
86// Decodes an ASTC-encoded pixel at `texelPos` to RGBA
// Decodes one texel of the ASTC source image.
//
// texelPos: position of the texel, in texels, within the layer.
// layer:    array layer to read from.
// Returns the decoded RGBA value produced by astcDecodeTexel().
//
// Side effect: overwrites the global `astcBlock` (declared outside this file, presumably in
// AstcDecompressor.glsl) with the 128-bit block that contains the texel, component order reversed
// by the .wzyx swizzle, and re-runs astcDecoderInitialize() on it.
uvec4 decodeRGBA(uvec2 texelPos, uint layer) {
    uvec2 footprint = u_pushConstant.blockSize;

    astcBlock = imageLoad(srcImage, WITH_TYPE(getPos)(ivec3(texelPos / footprint, layer))).wzyx;
    astcDecoderInitialize(astcBlock, footprint);

    // Position of the texel relative to the top-left corner of its ASTC block.
    return astcDecodeTexel(texelPos % footprint);
}
95
96// Returns the 2-bit index of the BC1 color that's the closest to the input color.
97// color: the color that we want to approximate
98// maxEndpoint / minEndpoint: the BC1 endpoint values we've chosen
// Returns the 2-bit index of the BC1 color that's the closest to the input color.
// color: the color that we want to approximate
// maxEndpoint / minEndpoint: the BC1 endpoint values we've chosen
//
// If both endpoints are identical, the projection below would be 0/0. In that case, index 0
// (maxEndpoint, which then equals minEndpoint) is returned instead of letting a NaN reach
// round()/int() and the bitfieldExtract() offset.
uint getColorIndex(vec3 color, vec3 minEndpoint, vec3 maxEndpoint) {
    vec3 colorLine = maxEndpoint - minEndpoint;
    float lineLengthSq = dot(colorLine, colorLine);
    if (lineLengthSq == 0.0) {
        return 0u;
    }

    // Project `color` on the line that goes between `minEndpoint` and `maxEndpoint`.
    //
    // TODO(gregschlom): this doesn't account for the fact that the color palette is actually
    // quantized as RGB565 instead of RGB8. A slower but potentially slightly higher quality
    // approach would be to compute all 4 RGB565 colors in the palette, then find the closest one.
    float x = dot(color - minEndpoint, colorLine) / lineLengthSq;

    // x is now a float in [0, 1] indicating where `color` lies when projected on the line between
    // the min and max endpoint. Remap x as an integer between 0 and 3.
    int index = int(round(clamp(x * 3, 0, 3)));

    // Finally, we need to convert to the somewhat unintuitive BC1 indexing scheme, where:
    //  0 is maxEndpoint, 1 is minEndpoint, 2 is (1/3)*minEndpoint + (2/3)*maxEndpoint and 3 is
    // (2/3)*minEndpoint + (1/3)*maxEndpoint. The lookup table for this is [1, 3, 2, 0], which we
    // bit-pack into 8 bits (2 bits per entry): 1 | 3 << 2 | 2 << 4 | 0 << 6 == 45.
    //
    // Alternatively, we could use this formula:
    // `index = -index & 3; return index ^ uint(index < 2);` but the lookup table method is faster.
    return bitfieldExtract(45u, index * 2, 2);
}
121
122// Same as above, but for alpha values, using BC4's encoding scheme.
// Same as getColorIndex, but for alpha values, using BC4's encoding scheme: returns the 3-bit
// BC4 index of the palette entry closest to `alpha`.
//
// Expects minAlpha <= alpha <= maxAlpha (encodeAlpha passes the cluster-wide min and max). When
// the range is empty (minAlpha == maxAlpha), the interpolation below would be 0/0. In that case,
// index 0 is returned instead of letting a NaN reach round()/int() and the bitfieldExtract() offset.
uint getAlphaIndex(uint alpha, uint minAlpha, uint maxAlpha) {
    if (maxAlpha == minAlpha) {
        return 0u;
    }

    float x = float(alpha - minAlpha) / float(maxAlpha - minAlpha);
    int index = int(round(clamp(x * 7, 0, 7)));

    // Like for getColorIndex, we need to remap the index according to BC4's indexing scheme, where
    //  0 is maxAlpha, 1 is minAlpha, 2 is (1/7)*minAlpha + (6/7)*maxAlpha, etc...
    // The lookup table for this is [1, 7, 6, 5, 4, 3, 2, 0], which we bit-pack into 32 bits using
    // 4 bits for each value (only the low 3 bits of each nibble are extracted).
    //
    // Alternatively, we could use this formula:
    // `index = -index & 7; return index ^ uint(index < 2);` but the lookup table method is faster.
    return bitfieldExtract(36984433u, index * 4, 3);
}
136
137// Computes the color endpoints using Principal Component Analysis to find the best fit line
138// through the colors in the 4x4 block.
// Picks the two RGB endpoints of a 4x4 block using Principal Component Analysis (PCA).
//
// Each of the 16 invocations of a cluster passes the color of its own texel. All of them receive
// the same two endpoints: the block texels whose projections onto the block's principal axis are
// the smallest (minEndpoint) and the largest (maxEndpoint). For a single-color block, both
// endpoints are that color.
//
// See the comment at the top of this file for more details on this algorithm.
void computeEndpoints(uvec3 rgbColor, out uvec3 minEndpoint, out uvec3 maxEndpoint) {
    // Per-channel statistics over the 16 texels of the block.
    uvec3 colorSum = subgroupClusteredAdd(rgbColor, 16);
    uvec3 avgColor = (colorSum + 8u) >> 4;  // Divide by 16, rounding to nearest.
    uvec3 lowest = subgroupClusteredMin(rgbColor, 16);
    uvec3 highest = subgroupClusteredMax(rgbColor, 16);

    // Nothing to analyze when every texel has the same color.
    if (lowest == highest) {
        minEndpoint = lowest;
        maxEndpoint = lowest;
        return;
    }

    // The six distinct entries of the symmetric 3x3 covariance matrix of the r, g, b channels.
    // The sums are accumulated as integers and only then converted to float.
    ivec3 d = ivec3(rgbColor) - ivec3(avgColor);
    vec3 rowR = vec3(subgroupClusteredAdd(d.x * d, 16));          // rr, rg, rb
    vec3 restGB = vec3(subgroupClusteredAdd(d.yyz * d.yzz, 16));  // gg, gb, bb

    mat3 covariance = mat3(rowR.x, rowR.y, rowR.z,        // column 0: rr, rg, rb
                           rowR.y, restGB.x, restGB.y,    // column 1: rg, gg, gb
                           rowR.z, restGB.y, restGB.z);   // column 2: rb, gb, bb

    // Principal axis by power iteration (https://en.wikipedia.org/wiki/Power_iteration), seeded
    // with the bounding-box diagonal. 3 to 8 iterations are sufficient for a good approximation.
    // The vector is intentionally NOT normalized between iterations: that gives noticeably better
    // quality (and is cheaper). TODO(gregschlom): Investigate why that is the case.
    vec3 axis = vec3(highest - lowest);
    for (int step = 0; step < 4; ++step) {
        axis = covariance * axis;
    }

    // Scale so the largest component has magnitude 1, or fall back to the RGB -> luminance weights
    // when the axis is too short. TODO(gregschlom): Investigate if we really need this.
    float largest = max(max(abs(axis.x), abs(axis.y)), abs(axis.z));
    if (largest < 4.0) {
        axis = vec3(0.299f, 0.587f, 0.114f);
    } else {
        axis /= largest;
    }

    // Project every texel onto the axis and find the extreme projections of the cluster.
    float projection = dot(vec3(rgbColor), axis);
    float lowestProjection = subgroupClusteredMin(projection, 16);
    float highestProjection = subgroupClusteredMax(projection, 16);

    // Each lane nominates itself (by subgroup invocation id) for the extreme(s) it matches, and
    // 0 otherwise. The max then yields one matching lane per extreme.
    uvec2 nominee = gl_SubgroupInvocationID *
                    uvec2(equal(vec2(projection), vec2(lowestProjection, highestProjection)));
    uvec2 winner = subgroupClusteredMax(nominee, 16);

    // TODO(gregschlom): these are the original texel colors, not their projections on the axis.
    // Investigate if returning the projected colors would increase quality.
    minEndpoint = subgroupShuffle(rgbColor, winner.x);
    maxEndpoint = subgroupShuffle(rgbColor, winner.y);
}
193
// Encodes the alpha channel of a 4x4 block as a 64-bit BC4 block.
//
// value:   alpha of this invocation's texel (only its low 8 bits are stored as an endpoint).
// texelId: position of this invocation's texel in the block, in [0, 15].
// Every invocation of the 16-wide cluster returns the same block, as (bits 0..31, bits 32..63):
//   bits  0..7  : cluster max alpha
//   bits  8..15 : cluster min alpha
//   bits 16..63 : sixteen 3-bit indices, texel i at bit 16 + 3*i (texel 5 straddles bits 31..33)
uvec2 encodeAlpha(uint value, uint texelId) {
    uint lo = subgroupClusteredMin(value, 16);
    uint hi = subgroupClusteredMax(value, 16);

    // A flat block always uses index 0.
    uint alphaIndex = 0u;
    if (lo != hi) {
        alphaIndex = getAlphaIndex(value, lo, hi);
    }

    // Split this texel's 3 index bits between the two 32-bit halves. Shifting a uint by 32 or more
    // is undefined, so each half is only computed for the texels that actually reach it.
    uint lowHalf = 0u;   // texels 0..5 (texel 5 contributes only its lowest bit)
    uint highHalf = 0u;  // texels 5..15 (texel 5 contributes its two upper bits)
    if (texelId <= 5u) {
        lowHalf = alphaIndex << (3u * texelId + 16u);
    }
    if (texelId >= 5u) {
        highHalf = (alphaIndex << 29) >> (45u - 3u * texelId);
    }

    uvec2 indexBits = subgroupClusteredOr(uvec2(lowHalf, highHalf), 16);
    uint endpointBits = (hi & 0xff) | ((lo & 0xff) << 8);
    return uvec2(endpointBits | indexBits.x, indexBits.y);
}
210
// Quantizes an RGB color (channels expected in [0, 255]) to 5:6:5 bits, rounding to nearest, and
// packs it into a 16-bit value: red in bits 11..15, green in bits 5..10, blue in bits 0..4.
uint packColorToRGB565(uvec3 color) {
    const vec3 levels = vec3(31.0, 63.0, 31.0);
    uvec3 q = uvec3(round(vec3(color) * levels / vec3(255.0)));
    return q.b | (q.g << 5) | (q.r << 11);
}
215
// Entry point. Each workgroup converts one 8x8 texel tile (2x2 blocks of 4x4 texels) of one layer.
// Each invocation decodes one texel, and each 16-invocation cluster cooperatively produces one
// 128-bit BC3 block, which the cluster's first invocation writes to dstImage.
void main() {
    // Linear index of this invocation in the 64-wide workgroup, in [0, 63].
    //
    // gl_LocalInvocationID can't be used: the spec makes no guarantee about how it maps to
    // gl_SubgroupInvocationID (See: https://stackoverflow.com/q/72451338/). The
    // subgroupClusteredXXX calls need every run of 16 consecutive subgroup invocation ids
    // [16n, 16n+15] to cover the same 4x4 block, so the index is rebuilt from the subgroup ids.
    uint localId = gl_SubgroupID * gl_SubgroupSize + gl_SubgroupInvocationID;
    uint blockId = localId >> 4;   // [0-3]  which 4x4 block of the tile
    uint texelId = localId & 15u;  // [0-15] row-major texel position inside that block

    // Top-left texel of this block, then this invocation's texel, in absolute image coordinates.
    uvec2 blockOrigin = 8u * gl_WorkGroupID.xy + 4u * uvec2(blockId & 1u, (blockId >> 1) & 1u);
    uvec2 texelCoord = blockOrigin + uvec2(texelId & 3u, texelId >> 2);

    // Layer, for array textures.
    uint layer = u_pushConstant.baseLayer + gl_WorkGroupID.z;

    uvec4 texel = decodeRGBA(texelCoord, layer);

    // Color endpoints, shared by the whole cluster.
    uvec3 minEndpoint, maxEndpoint;
    computeEndpoints(texel.rgb, minEndpoint, maxEndpoint);
    uint min565 = packColorToRGB565(minEndpoint);
    uint max565 = packColorToRGB565(maxEndpoint);

    // Palette index of this texel. If quantization collapsed both endpoints, index 0 is used.
    // Otherwise, the larger 16-bit endpoint goes into color0 (BC1's 4-color ordering). When that
    // swaps the endpoints, BC1 indices 0<->1 and 2<->3 trade places, which is what `^ 1u` does.
    uint colorIndex = 0u;
    if (min565 != max565) {
        colorIndex = getColorIndex(vec3(texel.rgb), vec3(minEndpoint), vec3(maxEndpoint));
        if (min565 > max565) {
            colorIndex ^= 1u;
        }
    }
    uint colorEndpoints = max(min565, max565) | (min(min565, max565) << 16);

    // BC3 block: 64 bits of BC4 alpha, then color0/color1, then sixteen 2-bit color indices.
    uvec2 alphaBlock = encodeAlpha(texel.a, texelId);
    uint colorIndices = subgroupClusteredOr(colorIndex << (2u * texelId), 16);
    uvec4 bc3Block = uvec4(alphaBlock, colorEndpoints, colorIndices);

    if (texelId == 0u) {
        imageStore(dstImage, WITH_TYPE(getPos)(ivec3(texelCoord / 4u, layer)), bc3Block);
    }
}