1// Compute shader to perform ASTC decoding.
2//
3// Usage:
4// #include "AstcDecompressor.glsl"
5//
// void main() {
//   uvec4 astcBlock = ... // read an ASTC block
//   astcDecoderInitialize(astcBlock, blockSize);
//   uvec2 posInBlock = uvec2(0, 0);  // which texel we want to decode in the block
//   uvec4 texel = astcDecodeTexel(posInBlock);
// }
12//
13////////////////////////////////////////////////////////////////////////////////////////////////////
14//
// Please refer to the ASTC spec for all the details:
16// https://www.khronos.org/registry/OpenGL/extensions/KHR/KHR_texture_compression_astc_hdr.txt
17//
18//
19// Quick reminder of an ASTC block layout
20// --------------------------------------
21//
22// Each ASTC block is 128 bits. From top to bottom (from bit 127 to bit 0), we have:
23//    1. weight data (24 - 96 bits). Starts at bit 127 and grows down. (So it needs to be reversed)
24//    2. extra CEM data. 0 bits if only 1 partition OR if CEM selector value in bits [24:23] is 00,
25//       otherwise 2, 5 or 8 bits for 2, 3, or 4 partitions respectively.
26//    3. color component selector (CCS) - 2 bits if dual plane is active, otherwise 0 bits.
27//    4. Color endpoint data - variable length
28//    5. CEM                 4 bits if single partition, else 6 bits.
29//    6. partition seed     10 bits (13-22) - only if more than 1 partition
30//    7. partition count     2 bits (11-12)
31//    8. block mode         11 bits ( 0-10)
32//
33// Optimization ideas
34// ------------------
35//
36//   1. Use a uniform buffer instead of static arrays to load the tables in AstcLookupTables.glsl
37//   2. Investigate using SSBO or sampled image instead of storage image for the input.
38//   3. Make decodeTrit() / decodeQuint() return a pair of values, since we always need at least 2
39//   4. Look into which queue we use to run the shader, some GPUs may have a separate compute queue.
40//   5. Use a `shared` variable to share the block data and common block config, once we change the
41//      local group size to match the block size.
42//
43// Missing features
44// ----------------
45//
46//   1. Make sure we cover all the cases where we should return the error color? See section C.2.24
47//      Illegal Encodings for the full list.
48//   2. Add support for 3D slices.
49//   3. HDR support? Probably not worth it.
50
51#include "AstcLookupTables.glsl"
52
// Color returned for blocks that cannot be decoded: magenta, as mandated by the spec.
const uvec4 kErrorColor = uvec4(255, 0, 255, 255);

// Global decoder state ////////////////////////////////////////////////////////////////////////////
// Written by astcDecoderInitialize() and read by astcDecodeTexel() and its helpers.

// The raw block and its footprint.
uvec4 astcBlock;       // The 128 block bits; astcBlock[3] holds bits 0-31, astcBlock[0] bits 96-127.
uvec2 blockSize;       // Block footprint, in texels.

// Block-level flags.
bool decodeError;      // Set when the block uses an illegal encoding.
bool voidExtent;       // Set for void-extent blocks (a single color for the whole block).
bool dualPlane;        // Set when the block carries two planes of weights.

// Weight grid description.
uvec2 weightGridSize;  // Weight grid width (x) and height (y); checked to be <= blockSize.
uint numWeights;       // Total number of weights (grid size, doubled for dual-plane blocks).
uvec3 weightEncoding;  // Weight ISE: number of trits (x), quints (y) and plain bits (z).
uint weightDataSize;   // Number of bits taken by the weight data.

// Partitioning.
uint numPartitions;    // Partition count, 1 to 4.
uint partitionSeed;    // 10-bit seed selecting the partition pattern.
68
69////////////////////////////////////////////////////////////////////////////////////////////////////
70
71// Returns the number of bits needed to encode `numVals` values using a given encoding.
72// encoding: number of trits (x), quints (y) and bits (z) used for the encoding.
73uint getEncodingSize(uint numVals, uvec3 encoding) {
74    // See section C.2.22.
75    uvec2 tqBits = (numVals * encoding.xy * uvec2(8, 7) + uvec2(4, 2)) / uvec2(5, 3);
76    return numVals * encoding.z + (tqBits.x + tqBits.y);
77}
78
// Parses the block header (void-extent flag, block mode, weight grid, partition count and seed) and
// stores the results in the global variables above. Must be called before astcDecodeTexel().
// blockData: the 128-bit block, with bits 0-31 in blockData[3] and bits 96-127 in blockData[0].
// blockSize_: block footprint in texels.
// On an illegal encoding this sets decodeError and returns early, so later globals may be unset;
// astcDecodeTexel() checks decodeError first.
void astcDecoderInitialize(uvec4 blockData, uvec2 blockSize_) {
    astcBlock = blockData;
    blockSize = blockSize_;
    decodeError = false;

    // Void-extent blocks are identified by bits [8:0] == 1 1111 1100 and have no block mode.
    voidExtent = (astcBlock[3] & 0x1FF) == 0x1FC;
    if (voidExtent) return;

    const uint bits01 = bitfieldExtract(astcBlock[3], 0, 2);
    const uint bits23 = bitfieldExtract(astcBlock[3], 2, 2);
    const uint bit4 = bitfieldExtract(astcBlock[3], 4, 1);
    const uint bits56 = bitfieldExtract(astcBlock[3], 5, 2);
    const uint bits78 = bitfieldExtract(astcBlock[3], 7, 2);

    // Block modes whose low 4 bits are all zero are reserved (last row of Table C.2.8). Without
    // this check they would produce a weight range index r of 0 or 1, which is not a valid weight
    // range, and we would depend on whatever kWeightEncodings holds at those entries.
    if (bits01 == 0 && bits23 == 0) {
        decodeError = true;
        return;
    }

    uint r;                                        // weight range index (R2 R1 R0)
    uint h = bitfieldExtract(astcBlock[3], 9, 1);  // weight range precision bit
    dualPlane = bool(bitfieldExtract(astcBlock[3], 10, 1));

    // Refer to "Table C.2.8 - 2D Block Mode Layout"
    if (bits01 == 0) {
        r = bits23 << 1 | bit4;
        switch (bits78) {
            case 0:
                weightGridSize = uvec2(12, bits56 + 2);
                break;
            case 1:
                weightGridSize = uvec2(bits56 + 2, 12);
                break;
            case 2:
                // Bits 9-10 are the grid height here, so this layout has no dual plane and H = 0.
                weightGridSize = uvec2(bits56 + 6, bitfieldExtract(astcBlock[3], 9, 2) + 6);
                dualPlane = false;
                h = 0;
                break;
            case 3:
                if (bits56 == 0) {
                    weightGridSize = uvec2(6, 10);
                } else if (bits56 == 1) {
                    weightGridSize = uvec2(10, 6);
                } else {
                    // Reserved (void-extent was already handled above).
                    decodeError = true;
                    return;
                }
        }
    } else {
        r = bits01 << 1 | bit4;
        switch (bits23) {
            case 0:
                weightGridSize = uvec2(bits78 + 4, bits56 + 2);
                break;
            case 1:
                weightGridSize = uvec2(bits78 + 8, bits56 + 2);
                break;
            case 2:
                weightGridSize = uvec2(bits56 + 2, bits78 + 8);
                break;
            case 3:
                if (bits78 >> 1 == 0) {
                    weightGridSize = uvec2(bits56 + 2, (bits78 & 1) + 6);
                } else {
                    weightGridSize = uvec2((bits78 & 1) + 2, bits56 + 2);
                }
        }
    }

    // The weight grid may not be larger than the block.
    if (any(greaterThan(weightGridSize, blockSize))) {
        decodeError = true;
        return;
    }

    // Weights
    weightEncoding = kWeightEncodings[h << 3 | r];
    numWeights = (weightGridSize.x * weightGridSize.y) << int(dualPlane);
    weightDataSize = getEncodingSize(numWeights, weightEncoding);
    // Illegal encodings: more than 64 weights, or weight data outside the 24-96 bit range.
    if (weightDataSize < 24 || weightDataSize > 96 || numWeights > 64) {
        decodeError = true;
        return;
    }

    numPartitions = bitfieldExtract(astcBlock[3], 11, 2) + 1;
    // Single-partition blocks have no seed (bits 13+ hold the CEM instead); use 0 so the global is
    // never left uninitialized.
    partitionSeed = (numPartitions > 1) ? bitfieldExtract(astcBlock[3], 13, 10) : 0u;

    // Dual plane mode is illegal with 4 partitions.
    if (dualPlane && numPartitions == 4) {
        decodeError = true;
        return;
    }
}
168
// Reads `numBits` consecutive bits out of `data`, viewed as one 128-bit integer whose bits 0-31
// live in data[3] and bits 96-127 in data[0].
// offset: position of the lowest bit to read (0-127).
// numBits: number of bits to read (0-32). Reading 0 bits yields 0.
// The result is undefined if offset >= 128 or offset + numBits > 128.
uint extractBits(uvec4 data, uint offset, uint numBits) {
    if (numBits == 0) {
        return 0;
    }

    // The most significant word comes first in the vector, hence the "3 -".
    const uint lowWord = 3 - offset / 32;
    const uint highWord = 3 - (offset + numBits - 1) / 32;
    const int bitInWord = int(offset & 31);

    if (lowWord != highWord) {
        // The field straddles two components: take the top of the lower word and the bottom of
        // the upper word, then stitch them together.
        const uint lowCount = 32 - uint(bitInWord);
        const uint lowPart = bitfieldExtract(data[lowWord], bitInWord, int(lowCount));
        const uint highPart = bitfieldExtract(data[highWord], 0, int(numBits - lowCount));
        return (highPart << lowCount) | lowPart;
    }

    // The whole field sits inside one component.
    return bitfieldExtract(data[lowWord], bitInWord, int(numBits));
}
189
// Decodes the Color Endpoint Mode (0-15) of partition `partitionIndex`.
// Output parameters:
// - startOfExtraCem: bit position where the extra CEM bits start. This equals the start of the
//   weight data when the block has no extra CEM bits.
// - totalEndpoints: number of endpoint integers in the block, summed over all partitions.
// - baseEndpointIndex: index of this partition's first endpoint integer.
// See "Section C.2.11  Color Endpoint Mode".
uint decodeCEM(uint partitionIndex, out uint startOfExtraCem, out uint totalEndpoints,
               out uint baseEndpointIndex) {
    const uint weightStart = 128 - weightDataSize;

    // One partition: a plain 4-bit CEM in bits 13-16.
    if (numPartitions == 1) {
        const uint singleCem = bitfieldExtract(astcBlock[3], 13, 4);
        startOfExtraCem = weightStart;
        baseEndpointIndex = 0;
        totalEndpoints = ((singleCem >> 2) + 1) * 2;
        return singleCem;
    }

    // Several partitions: a 2-bit selector in bits 23-24, then 4 more bits in 25-28.
    const uint selector = bitfieldExtract(astcBlock[3], 23, 2);
    const uint lowCemBits = bitfieldExtract(astcBlock[3], 25, 4);

    if (selector == 0) {
        // All partitions share one CEM, stored in full in bits 25-28.
        const uint perPartition = ((lowCemBits >> 2) + 1) * 2;
        startOfExtraCem = weightStart;
        baseEndpointIndex = partitionIndex * perPartition;
        totalEndpoints = numPartitions * perPartition;
        return lowCemBits;
    }

    // Per-partition CEMs (Figure C.4). The data is 3 * numPartitions bits: 4 are in bits 25-28 and
    // the rest sit immediately below the weight data.
    const uint extraSize = 3 * numPartitions - 4;
    startOfExtraCem = weightStart - extraSize;
    const uint fullCemBits = (extractBits(astcBlock, startOfExtraCem, extraSize) << 4) | lowCemBits;

    // fullCemBits holds numPartitions "C" bits (class offsets), then 2 "M" bits per partition.
    const int count = int(numPartitions);
    const int part = int(partitionIndex);
    const uint classBits = bitfieldExtract(fullCemBits, 0, count);
    const uint modeBits = bitfieldExtract(fullCemBits, count + 2 * part, 2);

    // Partition k has class (selector - 1 + C_k) and therefore 2 * (selector + C_k) endpoints.
    totalEndpoints = 2 * (numPartitions * selector + bitCount(classBits));
    baseEndpointIndex =
        2 * (partitionIndex * selector + bitCount(bitfieldExtract(classBits, 0, part)));
    const uint endpointClass = selector + bitfieldExtract(classBits, part, 1) - 1;
    return (endpointClass << 2) | modeBits;
}
240
// Decodes one value of a trit-encoded block of 5 values.
// offset: bit offset of the start of the block within the 128 bits of data.
// numBits: number of plain low bits stored per value (0-6).
// i: which of the 5 values to decode (0-4).
// See section "C.2.12  Integer Sequence Encoding".
uint decodeTrit(uvec4 data, uint offset, uint numBits, uint i) {
    const int n = int(numBits);

    // Block layout, from the least significant bit:
    //   m0 | T1:T0 | m1 | T3:T2 | m2 | T4 | m3 | T6:T5 | m4 | T7
    // where each m is n bits. The block can be 5 * 6 + 8 = 38 bits, so everything after m0
    // (at most 4 * 6 + 8 = 32 bits) is read in one go, and m0 is fetched separately if needed.
    const uint tail = extractBits(data, offset + numBits, 4 * numBits + 8);

    // Positions of m1..m4 inside `tail`.
    const ivec4 mPos = ivec4(2, n + 4, 2 * n + 5, 3 * n + 7);

    // Gather the 8 T bits scattered between the m fields.
    uint t = bitfieldExtract(tail, 0, 2);
    t |= bitfieldExtract(tail, mPos[0] + n, 2) << 2;
    t |= bitfieldExtract(tail, mPos[1] + n, 1) << 4;
    t |= bitfieldExtract(tail, mPos[2] + n, 2) << 5;
    t |= bitfieldExtract(tail, mPos[3] + n, 1) << 7;

    // The low bits of the requested value.
    uint lowBits;
    if (i != 0) {
        lowBits = bitfieldExtract(tail, mPos[i - 1], n);
    } else {
        lowBits = extractBits(data, offset, numBits);
    }

    // kTritEncodings turns the 8 T bits into five 2-bit trits.
    const uint trit = bitfieldExtract(kTritEncodings[t], 2 * int(i), 2);
    return (trit << numBits) | lowBits;
}
275
// Decodes one value of a quint-encoded block of 3 values.
// offset: bit offset of the start of the block within the 128 bits of data.
// numBits: number of plain low bits stored per value (0-5).
// i: which of the 3 values to decode (0-2).
// See section "C.2.12  Integer Sequence Encoding".
uint decodeQuint(uvec4 data, uint offset, uint numBits, uint i) {
    const int n = int(numBits);

    // Block layout, from the least significant bit:
    //   m0 | Q2:Q0 | m1 | Q4:Q3 | m2 | Q6:Q5
    // At most 3 * 5 + 7 = 22 bits, so unlike trits the whole block fits in a single read.
    const uint block = extractBits(data, offset, 3 * numBits + 7);

    // Positions of m0, m1 and m2 inside `block`.
    const ivec3 mPos = ivec3(0, n + 3, 2 * n + 5);

    // Gather the 7 Q bits.
    const uint q = bitfieldExtract(block, n, 3) |
                   (bitfieldExtract(block, mPos[1] + n, 2) << 3) |
                   (bitfieldExtract(block, mPos[2] + n, 2) << 5);

    const uint lowBits = bitfieldExtract(block, mPos[i], n);

    // kQuintEncodings turns the 7 Q bits into three 3-bit quints.
    const uint quint = bitfieldExtract(kQuintEncodings[q], 3 * int(i), 3);
    return (quint << numBits) | lowBits;
}
300
// Decodes weight number `index` from the (bit-reversed) weight data and unquantizes it to 0-64.
// encoding: trit count (x), quint count (y) and plain bits per value (z) of the weight ISE.
// Indices at or past numWeights decode to 0.
uint decode1Weight(uvec4 weightData, uvec3 encoding, uint numWeights, uint index) {
    if (index >= numWeights) {
        return 0;
    }

    const uint n = encoding.z;

    if (encoding.x == 1) {
        // One trit plus n bits per value; 5 values share a block of 5n + 8 bits.
        const uint blockStart = (index / 5) * (5 * n + 8);
        const uint q = decodeTrit(weightData, blockStart, n, index % 5);
        return kUnquantTritWeightMap[3 * ((1 << n) - 1) + q];
    }

    if (encoding.y == 1) {
        // One quint plus n bits per value; 3 values share a block of 3n + 7 bits.
        const uint blockStart = (index / 3) * (3 * n + 7);
        const uint q = decodeQuint(weightData, blockStart, n, index % 3);
        return kUnquantQuintWeightMap[5 * ((1 << n) - 1) + q];
    }

    // Plain bits only (the table covers 1 to 6 bits). Unquantize by replicating the bits up to
    // 6 bits: x = 63 / (2^n - 1) when that is an integer, plus OR-ing in (raw >> y) otherwise.
    const uvec2 kUnquantTable[] = {{63, 8}, {21, 8}, {9, 8}, {4, 2}, {2, 4}, {1, 8}};
    const uvec2 factors = kUnquantTable[n - 1];
    const uint raw = extractBits(weightData, index * n, n);
    const uint w = (raw * factors.x) | (raw >> factors.y);

    // Stretch 0..63 to 0..64.
    return (w > 32) ? w + 1 : w;
}
330
// Bilinearly interpolates the 4 weight-grid points around a texel. This is the spec's
// "effective weight" (section "C.2.18  Weight Infill").
// index: grid index (x + y * gridWidth) of the top-left grid point.
// gridWidth: width of the weight grid.
// stride, offset: map a grid index to a weight index as stride * gridIndex + offset. Single-plane
//   blocks use (1, 0); dual-plane blocks use (2, 0) for plane 0 and (2, 1) for plane 1.
// fractionalPart: position between grid points in sixteenths (0-15), for x and y.
// Returns a weight in 0-64.
uint interpolateWeights(uvec4 weightData, uvec3 encoding, uint numWeights, uint index,
                        uint gridWidth, uint stride, uint offset, uvec2 fractionalPart) {
    // Weight indices of the top-left, top-right, bottom-left and bottom-right grid points.
    uvec4 weightIndices = stride * (uvec4(index) + uvec4(0, 1, gridWidth, gridWidth + 1)) + offset;

    // TODO(gregschlom): Optimization idea: instead of always decoding 4 weights, we could decode
    // just what we need depending on whether fractionalPart.x and fractionalPart.y are 0
    uvec4 weights = uvec4(decode1Weight(weightData, encoding, numWeights, weightIndices[0]),
                          decode1Weight(weightData, encoding, numWeights, weightIndices[1]),
                          decode1Weight(weightData, encoding, numWeights, weightIndices[2]),
                          decode1Weight(weightData, encoding, numWeights, weightIndices[3]));

    uint w11 = (fractionalPart.x * fractionalPart.y + 8) >> 4;
    uvec4 factors = uvec4(16 - fractionalPart.x - fractionalPart.y + w11,  // w00
                          fractionalPart.x - w11,                          // w01
                          fractionalPart.y - w11,                          // w10
                          w11);                                            // w11

    // GLSL's dot() is only defined for floating-point vectors. Calling it on uvec4 implicitly
    // converts to float and back, so do the weighted sum in integer arithmetic instead.
    const uvec4 products = weights * factors;
    return (products.x + products.y + products.z + products.w + 8) >> 4;
}
350
// Returns the effective weights (0-64) of the texel at `posInBlock`, obtained by infilling the
// weight grid (section "C.2.18  Weight Infill"). The plane-0 weight is in .x. For dual-plane blocks
// the plane-1 weight is in .y; otherwise .y is 0.
uvec2 decodeWeights(uvec4 weightData, const uvec2 posInBlock) {
    // TODO(gregschlom): The spec says: "since the block dimensions are constrained, these are
    // easily looked up in a table." - Is it worth doing?
    const uvec2 scale = (1024 + blockSize / 2) / (blockSize - 1);

    // Texel position mapped onto the weight grid, in sixteenths of a grid cell.
    const uvec2 gridPos = (posInBlock * scale * (weightGridSize - 1) + 32) >> 6;
    const uvec2 cell = gridPos >> 4;       // integral part
    const uvec2 fraction = gridPos & 0xf;  // fractional part (0-15)

    const uint gridWidth = weightGridSize.x;
    const uint topLeft = cell.x + cell.y * gridWidth;

    // Dual-plane weights are interleaved: plane 0 at even indices, plane 1 at odd ones.
    const uint stride = dualPlane ? 2u : 1u;
    const uint plane0 = interpolateWeights(weightData, weightEncoding, numWeights, topLeft,
                                           gridWidth, stride, 0u, fraction);
    uint plane1 = 0u;
    if (dualPlane) {
        plane1 = interpolateWeights(weightData, weightEncoding, numWeights, topLeft, gridWidth,
                                    2u, 1u, fraction);
    }
    return uvec2(plane0, plane1);
}
375
// Integer hash used by partition pattern generation. It must reproduce the spec's hash52()
// bit for bit, because it decides which partition each texel lands in.
uint hash52(uint p) {
    uint h = p;
    h = h ^ (h >> 15);
    h = h - (h << 17);
    h = h + (h << 7);
    h = h + (h << 4);
    h = h ^ (h >> 5);
    h = h + (h << 16);
    h = h ^ (h >> 7);
    h = h ^ (h >> 3);
    h = h ^ (h << 6);
    h = h ^ (h >> 17);
    return h;
}
389
// Returns the partition (0 to numPartitions - 1) that the texel at `pos` belongs to. This
// implements the spec's partition pattern generation with z = 0 (2D blocks only).
// seed: the 10-bit partition seed from the block header.
uint selectPartition(uint seed, uvec2 pos, uint numPartitions) {
    if (numPartitions == 1) {
        return 0;
    }
    // Blocks with fewer than 31 texels sample the pattern at doubled coordinates.
    if (blockSize.x * blockSize.y < 31) {
        pos <<= 1;
    }
    // Each partition count uses its own range of 1024 seeds.
    seed += (numPartitions - 1) * 1024;
    const uint rnum = hash52(seed);

    // TODO(gregschlom): micro-optimization: could repetedly halve the bits to extract them in 6
    // calls to bitfieldExtract instead of 8.
    uvec4 seedA = uvec4(bitfieldExtract(rnum, 0, 4), bitfieldExtract(rnum, 4, 4),
                        bitfieldExtract(rnum, 8, 4), bitfieldExtract(rnum, 12, 4));
    uvec4 seedB = uvec4(bitfieldExtract(rnum, 16, 4), bitfieldExtract(rnum, 20, 4),
                        bitfieldExtract(rnum, 24, 4), bitfieldExtract(rnum, 28, 4));

    seedA = seedA * seedA;
    seedB = seedB * seedB;

    uvec2 shifts1 = uvec2((seed & 2) != 0 ? 4 : 5, numPartitions == 3 ? 6 : 5);
    uvec4 shifts2 = (seed & 1) != 0 ? shifts1.xyxy : shifts1.yxyx;

    seedA >>= shifts2;
    seedB >>= shifts2;

    // result = (a, b, c, d) from the spec, each being seedX * x + seedY * y + (rnum >> shift).
    // This is integer arithmetic on purpose. GLSL's dot() has no integer overloads, so
    // dot(uvec2, uvec2) would silently round-trip through float.
    const uvec4 xFactors = uvec4(seedA.x, seedA.z, seedB.x, seedB.z);
    const uvec4 yFactors = uvec4(seedA.y, seedA.w, seedB.y, seedB.w);
    uvec4 result = xFactors * pos.x + yFactors * pos.y + (uvec4(rnum) >> uvec4(14, 10, 6, 2));

    result &= uvec4(0x3F);

    // Partitions that don't exist must never win.
    if (numPartitions == 2) {
        result.zw = uvec2(0);
    } else if (numPartitions == 3) {
        result.w = 0;
    }

    // Return the index of the largest component in `result` (lowest index wins ties).
    if (all(greaterThanEqual(uvec3(result.x), result.yzw))) {
        return 0;
    } else if (all(greaterThanEqual(uvec2(result.y), result.zw))) {
        return 1;
    } else if (result.z >= result.w) {
        return 2;
    } else {
        return 3;
    }
}
440
// Chooses the color endpoint encoding (section "C.2.22  Data Size Determination").
// It scans kColorEncodings in order and returns the first entry whose size for `numEndpoints`
// values fits in `availableEndpointBits`. Because the first fit wins, the table is presumably
// ordered from largest to smallest range - NOTE(review): confirm in AstcLookupTables.glsl.
// actualSize: size in bits of the returned encoding. If nothing fits, uvec3(0) is returned and
// actualSize holds the (too large) size computed for the last table entry.
uvec3 getEndpointEncoding(uint availableEndpointBits, uint numEndpoints, out uint actualSize) {
    // TODO(gregschlom): This could be implemented with a lookup table instead. Or we could use a
    // binary search but not sure if worth it due to the extra cost of branching.
    const int numCandidates = kColorEncodings.length();
    for (int candidate = 0; candidate < numCandidates; ++candidate) {
        const uvec3 encoding = kColorEncodings[candidate];
        actualSize = getEncodingSize(numEndpoints, encoding);
        if (actualSize > availableEndpointBits) {
            continue;  // doesn't fit, try the next entry
        }
        return encoding;
    }
    return uvec3(0);  // no encoding fits
}
454
// Blue contraction: replaces red and green with their average with blue. Blue and alpha are kept.
ivec4 blueContract(ivec4 v) {
    const ivec2 rg = (v.rg + v.bb) >> 1;
    return ivec4(rg, v.b, v.a);
}
456
// Returns the sum of the three components of `v`.
int sum(ivec3 v) {
    int total = v.x;
    total += v.y;
    total += v.z;
    return total;
}
458
// The spec's bit_transfer_signed(), applied to all 4 components.
// `b` is shifted down by one and receives bit 7 of `a` as its new top bit. `a` drops its lowest
// bit and keeps the next 6 bits as a signed offset in [-32, 31].
void bitTransferSigned(inout ivec4 a, inout ivec4 b) {
    b = (b >> 1) | (a & 0x80);
    // bitfieldExtract() on a signed type sign-extends bit 5. This is equivalent to the spec's
    // "a &= 0x3F; if ((a & 0x20) != 0) a -= 0x40;".
    a = bitfieldExtract(a >> 1, 0, 6);
}
468
// Turns the unquantized endpoint integers of one partition into two 8-bit RGBA endpoints.
// vA: the even-numbered integers of the spec (v0, v2, v4, v6).
// vB: the odd-numbered integers of the spec (v1, v3, v5, v7).
// mode: the color endpoint mode (CEM).
// ep0, ep1: receive the two endpoints. HDR modes are not supported; they produce two zero
// endpoints.
void decodeEndpoints(ivec4 vA, ivec4 vB, uint mode, out uvec4 ep0, out uvec4 ep1) {
    switch (mode) {
        case 0: {  // LDR luminance only, direct
            ep0 = uvec4(vA.xxx, 255);
            ep1 = uvec4(vB.xxx, 255);
            break;
        }
        case 1: {  // LDR luminance only, base + offset
            const int lum0 = (vB.x & 0xC0) | (vA.x >> 2);
            const int lum1 = min(lum0 + (vB.x & 0x3F), 255);
            ep0 = uvec4(uvec3(lum0), 255);
            ep1 = uvec4(uvec3(lum1), 255);
            break;
        }
        case 4: {  // LDR luminance + alpha, direct
            ep0 = uvec4(vA.xxxy);
            ep1 = uvec4(vB.xxxy);
            break;
        }
        case 5: {  // LDR luminance + alpha, base + offset
            bitTransferSigned(vB, vA);
            ep0 = uvec4(clamp(vA.xxxy, 0, 255));
            ep1 = uvec4(clamp(vA.xxxy + vB.xxxy, 0, 255));
            break;
        }
        case 6:     // LDR RGB, base + scale
        case 10: {  // LDR RGB, base + scale, plus two alphas
            const bool hasAlpha = (mode == 10);
            // ep1 is the base color. ep0 is the base scaled by vB.y (v3 in the spec) / 256.
            ep1 = uvec4(vA.x, vB.x, vA.y, hasAlpha ? vB.z : 255);
            ep0 = uvec4((ep1.rgb * vB.y) >> 8, hasAlpha ? vA.z : 255);
            break;
        }
        case 8:     // LDR RGB, direct
        case 12: {  // LDR RGBA, direct
            if (mode == 8) {
                // No alpha in the data: both endpoints are opaque.
                vA.a = 255;
                vB.a = 255;
            }
            if (sum(vB.rgb) < sum(vA.rgb)) {
                // The endpoints are stored swapped and blue-contracted.
                ep0 = uvec4(blueContract(vB));
                ep1 = uvec4(blueContract(vA));
            } else {
                ep0 = uvec4(vA);
                ep1 = uvec4(vB);
            }
            break;
        }
        case 9:     // LDR RGB, base + offset
        case 13: {  // LDR RGBA, base + offset
            if (mode == 9) {
                // These values make bitTransferSigned(vB, vA) below yield vA.a = 255 and
                // vB.a = 0, so both endpoints end up opaque.
                vA.a = 255;
                vB.a = -128;
            }
            bitTransferSigned(vB, vA);
            const ivec4 baseplusOffset = vA + vB;
            if (sum(vB.rgb) < 0) {
                ep0 = uvec4(clamp(blueContract(baseplusOffset), 0, 255));
                ep1 = uvec4(clamp(blueContract(vA), 0, 255));
            } else {
                ep0 = uvec4(clamp(vA, 0, 255));
                ep1 = uvec4(clamp(baseplusOffset, 0, 255));
            }
            break;
        }
        default:  // HDR modes (2, 3, 7, 11, 14, 15) are not implemented.
            ep0 = uvec4(0);
            ep1 = uvec4(0);
            break;
    }
}
544
// Decodes endpoint integer number `index` from the ISE-encoded color data that starts at bit
// `startOffset` of `data`, and unquantizes it to 0-255.
// encoding: trit count (x), quint count (y) and plain bits per value (z).
uint decode1Endpoint(uvec4 data, uint startOffset, uint index, uvec3 encoding) {
    const uint n = encoding.z;

    if (encoding.x == 1) {
        // One trit plus n bits per value; 5 values share a block of 5n + 8 bits.
        const uint blockStart = startOffset + (index / 5) * (5 * n + 8);
        const uint value = decodeTrit(data, blockStart, n, index % 5);
        return kUnquantTritColorMap[3 * ((1 << n) - 1) + value];
    }

    if (encoding.y == 1) {
        // One quint plus n bits per value; 3 values share a block of 3n + 7 bits.
        const uint blockStart = startOffset + (index / 3) * (3 * n + 7);
        const uint value = decodeQuint(data, blockStart, n, index % 3);
        return kUnquantQuintColorMap[5 * ((1 << n) - 1) + value];
    }

    // Plain bits only (1 to 8). Unquantize by replicating the bits up to 8 bits: multiply by
    // x = 255 / (2^n - 1) when that is an integer, and OR in (raw >> y) to fill the rest.
    const uvec2 kUnquantTable[] = {{255, 8}, {85, 8}, {36, 1}, {17, 8},
                                   {8, 2},   {4, 4},  {2, 6},  {1, 8}};
    const uvec2 unquant = kUnquantTable[n - 1];
    const uint raw = extractBits(data, startOffset + index * n, n);
    return (raw * unquant.x) | (raw >> unquant.y);
}
570
// Returns a 128-bit mask with its lowest `bits` bits set. It uses the same component order as
// astcBlock: component 3 holds bits 0-31 and component 0 holds bits 96-127.
uvec4 buildBitmask(uint bits) {
    uvec4 mask;
    for (int c = 0; c < 4; ++c) {
        // Component c covers bits [32 * (3 - c), 32 * (4 - c)).
        const int wordBits = int(bits) - 32 * (3 - c);
        const bool wholeWord = bits >= uint(32 * (4 - c));
        const uint partial = (1u << clamp(wordBits, 0, 31)) - 1u;
        mask[c] = wholeWord ? 0xffffffffu : partial;
    }
    return mask;
}
577
// Decodes the texel at `posInBlock` of the block previously passed to astcDecoderInitialize().
// Returns 8-bit RGBA values, or kErrorColor for illegal blocks.
uvec4 astcDecodeTexel(const uvec2 posInBlock) {
    if (decodeError) {
        return kErrorColor;
    }

    if (voidExtent) {
        // The constant color is stored as four 16-bit values in bits 64-127. Keep the top byte of
        // each one.
        return uvec4(bitfieldExtract(astcBlock[1], 8, 8), bitfieldExtract(astcBlock[1], 24, 8),
                     bitfieldExtract(astcBlock[0], 8, 8), bitfieldExtract(astcBlock[0], 24, 8));
    }

    // Weight data grows down from bit 127. Reverse it so it starts at bit 0, then drop everything
    // that isn't weight data.
    const uvec4 weightData = bitfieldReverse(astcBlock.wzyx) & buildBitmask(weightDataSize);
    const uvec2 weights = decodeWeights(weightData, posInBlock);

    const uint partitionIndex = selectPartition(partitionSeed, posInBlock, numPartitions);

    uint startOfExtraCem = 0;
    uint totalEndpoints = 0;
    uint baseEndpointIndex = 0;
    const uint cem = decodeCEM(partitionIndex, startOfExtraCem, totalEndpoints, baseEndpointIndex);

    // Per spec, we must return the error color if we require more than 18 color endpoints
    if (totalEndpoints > 18) {
        return kErrorColor;
    }

    // Endpoint data lies between the header and the CCS (dual plane only) / extra CEM bits.
    const uint endpointsStart = (numPartitions == 1) ? 17 : 29;
    const uint endpointsEnd = startOfExtraCem - (dualPlane ? 2u : 0u);

    // If the weights, extra CEM and CCS leave no room for endpoint data, the block is illegal.
    // Checking before subtracting keeps availableEndpointBits from wrapping to a huge unsigned
    // value. Such a value would select the largest encoding and read unrelated bits.
    if (endpointsEnd <= endpointsStart) {
        return kErrorColor;
    }
    const uint availableEndpointBits = endpointsEnd - endpointsStart;

    uint actualEndpointBits;
    const uvec3 endpointEncoding =
        getEndpointEncoding(availableEndpointBits, totalEndpoints, actualEndpointBits);
    // A size that doesn't fit means no color range was small enough, which is an illegal
    // encoding. Bailing out here also keeps decode1Endpoint() from indexing its unquantization
    // table with "0 bits - 1".
    if (actualEndpointBits > availableEndpointBits) {
        return kErrorColor;
    }

    // Number of endpoints pairs in this partition. (Between 1 and 4)
    // This is the n field from "Table C.2.17 - Color Endpoint Modes" divided by two
    const uint numEndpointPairs = (cem >> 2) + 1;

    ivec4 vA = ivec4(0);  // holds what the spec calls v0, v2, v4 and v6
    ivec4 vB = ivec4(0);  // holds what the spec calls v1, v3, v5 and v7

    // Mask out everything above the endpoint data so that partial ISE blocks read zeros.
    uvec4 epData = astcBlock & buildBitmask(endpointsStart + actualEndpointBits);

    for (uint i = 0; i < numEndpointPairs; ++i) {
        const uint epIdx = 2 * i + baseEndpointIndex;
        vA[i] = int(decode1Endpoint(epData, endpointsStart, epIdx, endpointEncoding));
        vB[i] = int(decode1Endpoint(epData, endpointsStart, epIdx + 1, endpointEncoding));
    }

    uvec4 ep0, ep1;
    decodeEndpoints(vA, vB, cem, ep0, ep1);

    // With dual plane, the channel chosen by the CCS uses the plane-1 weight.
    uvec4 weightsPerChannel = uvec4(weights[0]);
    if (dualPlane) {
        uint ccs = extractBits(astcBlock, endpointsEnd, 2);
        weightsPerChannel[ccs] = weights[1];
    }

    // TODO(gregschlom): Check section "C.2.19  Weight Application" - we're supposed to do something
    // else here, depending on whether we're using sRGB or not. Currently we have a difference of up
    // to 1 when compared against the reference decoder. Probably not worth trying to fix it though.
    return (ep0 * (64 - weightsPerChannel) + ep1 * weightsPerChannel + 32) >> 6;
}
644