// Compute shader to perform ASTC decoding.
//
// Usage:
//   #include "AstcDecompressor.glsl"
//
//   main() {
//     uvec4 astcBlock = ... // read an ASTC block
//     astcDecoderInitialize(astcBlock, blockSize);
//     uvec2 posInBlock = uvec2(0, 0);  // which texel we want to decode in the block
//     uvec4 texel = astcDecodeTexel(posInBlock);
//   }
//
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// Please refer to the ASTC spec for all the details:
// https://www.khronos.org/registry/OpenGL/extensions/KHR/KHR_texture_compression_astc_hdr.txt
//
//
// Quick reminder of an ASTC block layout
// --------------------------------------
//
// Each ASTC block is 128 bits. From top to bottom (from bit 127 to bit 0), we have:
//  1. weight data (24 - 96 bits). Starts at bit 127 and grows down. (So it needs to be reversed)
//  2. extra CEM data. 0 bits if only 1 partition OR if CEM selector value in bits [24:23] is 00,
//     otherwise 2, 5 or 8 bits for 2, 3, or 4 partitions respectively.
//  3. color component selector (CCS) - 2 bits if dual plane is active, otherwise 0 bits.
//  4. Color endpoint data - variable length
//  5. CEM 4 bits if single partition, else 6 bits.
//  6. partition seed 10 bits (13-22) - only if more than 1 partition
//  7. partition count 2 bits (11-12)
//  8. block mode 11 bits ( 0-10)
//
// Optimization ideas
// ------------------
//
//  1. Use a uniform buffer instead of static arrays to load the tables in AstcLookupTables.glsl
//  2. Investigate using SSBO or sampled image instead of storage image for the input.
//  3. Make decodeTrit() / decodeQuint() return a pair of values, since we always need at least 2
//  4. Look into which queue we use to run the shader, some GPUs may have a separate compute queue.
//  5. Use a `shared` variable to share the block data and common block config, once we change the
//     local group size to match the block size.
//
// Missing features
// ----------------
//
//  1. Make sure we cover all the cases where we should return the error color? See section C.2.24
//     Illegal Encodings for the full list.
//  2. Add support for 3D slices.
//  3. HDR support? Probably not worth it.

#include "AstcLookupTables.glsl"

const uvec4 kErrorColor = uvec4(255, 0, 255, 255);  // beautiful magenta, as per the spec

// Global variables ////////////////////////////////////////////////////////////////////////////////

uvec4 astcBlock;       // Full ASTC block data
uvec2 blockSize;       // Size of the ASTC block
bool decodeError;      // True if there's an error in the block.
bool voidExtent;       // True if void-extent block (all pixels are the same color)
bool dualPlane;        // True for dual plane blocks (a block with 2 sets of weights)
uvec2 weightGridSize;  // Width and height of the weight grid. Always <= blockSize.
uint numWeights;       // Number of weights
uvec3 weightEncoding;  // Number of trits (x), quints (y) and bits (z) to encode the weights.
uint weightDataSize;   // Size of the weight data in bits
uint numPartitions;    // Number of partitions (1-4)
uint partitionSeed;    // Determines which partition pattern we use (10 bits)

////////////////////////////////////////////////////////////////////////////////////////////////////

// Computes how many bits a sequence of `numVals` values occupies for a given encoding.
// encoding: number of trits (x), quints (y) and plain bits (z) used per value.
uint getEncodingSize(uint numVals, uvec3 encoding) {
    // See section C.2.22. Five trits are packed into 8 bits and three quints into 7 bits; both
    // divisions round up.
    const uint tritBits = (numVals * encoding.x * 8u + 4u) / 5u;
    const uint quintBits = (numVals * encoding.y * 7u + 2u) / 3u;
    const uint plainBits = numVals * encoding.z;
    return plainBits + (tritBits + quintBits);
}

// Parses the block mode and partition fields of `blockData` and fills in all the global
// variables declared above. Sets `decodeError` when the block mode is invalid for `blockSize_`.
void astcDecoderInitialize(uvec4 blockData, uvec2 blockSize_) {
    astcBlock = blockData;
    blockSize = blockSize_;
    decodeError = false;

    // The block mode lives in the lowest 32 bits of the block.
    const uint modeWord = astcBlock[3];

    // Void-extent blocks carry no weights or partitions, so there is nothing more to parse.
    voidExtent = (modeWord & 0x1FF) == 0x1FC;
    if (voidExtent) return;

    const uint b01 = bitfieldExtract(modeWord, 0, 2);
    const uint b23 = bitfieldExtract(modeWord, 2, 2);
    const uint b4 = bitfieldExtract(modeWord, 4, 1);
    const uint b56 = bitfieldExtract(modeWord, 5, 2);
    const uint b78 = bitfieldExtract(modeWord, 7, 2);

    uint rangeBits;                                        // "R" in the spec
    uint highPrecision = bitfieldExtract(modeWord, 9, 1);  // "H" in the spec
    dualPlane = bool(bitfieldExtract(modeWord, 10, 1));

    // Refer to "Table C.2.8 - 2D Block Mode Layout"
    if (b01 == 0) {
        rangeBits = b23 << 1 | b4;
        if (b78 == 0) {
            weightGridSize = uvec2(12, b56 + 2);
        } else if (b78 == 1) {
            weightGridSize = uvec2(b56 + 2, 12);
        } else if (b78 == 2) {
            // This layout reuses bits 9-10 for the grid height, so D and H are implicitly 0.
            weightGridSize = uvec2(b56 + 6, bitfieldExtract(modeWord, 9, 2) + 6);
            dualPlane = false;
            highPrecision = 0;
        } else if (b56 == 0) {
            weightGridSize = uvec2(6, 10);
        } else if (b56 == 1) {
            weightGridSize = uvec2(10, 6);
        } else {
            decodeError = true;
            return;
        }
    } else {
        rangeBits = b01 << 1 | b4;
        if (b23 == 0) {
            weightGridSize = uvec2(b78 + 4, b56 + 2);
        } else if (b23 == 1) {
            weightGridSize = uvec2(b78 + 8, b56 + 2);
        } else if (b23 == 2) {
            weightGridSize = uvec2(b56 + 2, b78 + 8);
        } else if ((b78 >> 1) == 0) {
            weightGridSize = uvec2(b56 + 2, (b78 & 1) + 6);
        } else {
            weightGridSize = uvec2((b78 & 1) + 2, b56 + 2);
        }
    }

    // The weight grid can never be larger than the block itself.
    if (any(greaterThan(weightGridSize, blockSize))) {
        decodeError = true;
        return;
    }

    // weights
    weightEncoding = kWeightEncodings[highPrecision << 3 | rangeBits];
    numWeights = (weightGridSize.x * weightGridSize.y) << int(dualPlane);
    weightDataSize = getEncodingSize(numWeights, weightEncoding);
    const bool weightsValid = numWeights <= 64 && weightDataSize >= 24 && weightDataSize <= 96;
    if (!weightsValid) {
        decodeError = true;
        return;
    }

    // partitions
    numPartitions = bitfieldExtract(modeWord, 11, 2) + 1;
    if (numPartitions > 1) {
        partitionSeed = bitfieldExtract(modeWord, 13, 10);
    }

    // Dual plane cannot be combined with 4 partitions.
    if (dualPlane && numPartitions == 4) {
        decodeError = true;
        return;
    }
}

// Reads a bit range out of a uvec4 that is viewed as one 128-bit field (data[3] holds bits 0-31,
// data[0] holds bits 96-127).
// offset: index of the first bit to extract (0-127).
// numBits: number of bits to extract (0-32). If numBits is 0, this returns 0.
// Result is undefined if offset >= 128 or offset + numBits > 128
uint extractBits(uvec4 data, uint offset, uint numBits) {
    if (numBits == 0) return 0;

    const uint firstWord = 3 - offset / 32;
    const uint lastWord = 3 - (offset + numBits - 1) / 32;
    const uint bitInWord = offset & 31;

    if (firstWord != lastWord) {
        // The range straddles two components: stitch the top of one to the bottom of the next.
        const uint bitsInFirst = 32 - bitInWord;
        const uint lowPart = bitfieldExtract(data[firstWord], int(bitInWord), int(bitsInFirst));
        const uint highPart = bitfieldExtract(data[lastWord], 0, int(numBits - bitsInFirst));
        return lowPart | (highPart << bitsInFirst);
    }

    // The whole range sits inside a single component.
    return bitfieldExtract(data[firstWord], int(bitInWord), int(numBits));
}

// Returns the CEM, a number between 0 and 15 that determines how the endpoints are encoded.
// Also sets a couple of output parameters:
//  - startOfExtraCem: bit position of the start of the extra CEM
//  - totalEndpoints: number of endpoints in the block, for all partitions.
//  - baseEndpointIndex: index of the first endpoint for this partition
// Refer to "Section C.2.11 Color Endpoint Mode" for decoding details
uint decodeCEM(uint partitionIndex, out uint startOfExtraCem, out uint totalEndpoints,
               out uint baseEndpointIndex) {
    // Single partition: the CEM is stored directly in 4 bits, no extra CEM data.
    if (numPartitions == 1) {
        const uint singleCem = bitfieldExtract(astcBlock[3], 13, 4);
        startOfExtraCem = 128 - weightDataSize;
        totalEndpoints = 2 * (singleCem >> 2) + 2;
        baseEndpointIndex = 0;
        return singleCem;
    }

    const uint selector = bitfieldExtract(astcBlock[3], 23, 2);
    const uint cemLow = bitfieldExtract(astcBlock[3], 25, 4);

    // Multiple partitions all sharing the same CEM.
    if (selector == 0) {
        const uint perPartition = 2 * (cemLow >> 2) + 2;
        startOfExtraCem = 128 - weightDataSize;
        totalEndpoints = perPartition * numPartitions;
        baseEndpointIndex = perPartition * partitionIndex;
        return cemLow;
    }

    // Multiple partitions, each with its own CEM. Refer to "Figure C.4" for the encoding.

    // The extra CEM bits are located right after the weight data.
    const uint extraCemBits = 3 * numPartitions - 4;
    startOfExtraCem = 128 - weightDataSize - extraCemBits;
    const uint combined = (extractBits(astcBlock, startOfExtraCem, extraCemBits) << 4) | cemLow;

    // One class bit (C) per partition comes first, followed by 2 mode bits (M) per partition.
    const uint classBits = bitfieldExtract(combined, 0, int(numPartitions));
    const uint modeBits = bitfieldExtract(combined, int(2 * partitionIndex + numPartitions), 2);
    const uint classBitsBelow = bitfieldExtract(classBits, 0, int(partitionIndex));
    const uint ownClassBit = bitfieldExtract(classBits, int(partitionIndex), 1);

    // TODO(gregschlom): investigate whether a couple of small lookup tables would be more
    // efficient here.
    totalEndpoints = 2 * (selector * numPartitions + bitCount(classBits));
    baseEndpointIndex = 2 * (selector * partitionIndex + bitCount(classBitsBelow));
    return ((selector - 1 + ownClassBit) << 2) | modeBits;
}

// Decodes one trit-encoded value out of a group of 5.
// offset: bit offset where the group of trits starts, within the 128 bits of data
// numBits: how many bits are used to encode the LSB (0-6)
// i: index of the value within the group (0-4)
// See section "C.2.12 Integer Sequence Encoding"
uint decodeTrit(uvec4 data, uint offset, uint numBits, uint i) {
    const int n = int(numBits);

    // The largest group (1 trit + 6 bits) is 5 * 6 + 8 = 38 bits long, which doesn't fit in a
    // uint. We therefore skip the low bits of value 0 here and fetch them separately below, so
    // that at most 4 * 6 + 8 = 32 bits are needed.
    const uint tail = extractBits(data, offset + numBits, 4 * numBits + 8);

    // Gather the 8 bits of the trit pack, which are interleaved with the low bits
    // TODO(gregschlom): Optimization idea: if numbits == 0, then the trit pack == tail. Worth
    // doing?
    const uint t01 = bitfieldExtract(tail, 0, 2);
    const uint t23 = bitfieldExtract(tail, 1 * n + 2, 2);
    const uint t4 = bitfieldExtract(tail, 2 * n + 4, 1);
    const uint t56 = bitfieldExtract(tail, 3 * n + 5, 2);
    const uint t7 = bitfieldExtract(tail, 4 * n + 7, 1);
    const uint tritPack = t01 | t23 << 2 | t4 << 4 | t56 << 5 | t7 << 7;

    // Fetch the LSB of the requested value
    uint lsb;
    if (i == 0) {
        lsb = extractBits(data, offset, numBits);
    } else {
        const int k = int(i) - 1;
        const ivec4 gaps = ivec4(2, 4, 5, 7);  // trit-pack bits preceding each value in `tail`
        lsb = bitfieldExtract(tail, k * n + gaps[k], n);
    }

    const uint trits = kTritEncodings[tritPack];
    return bitfieldExtract(trits, 2 * int(i), 2) << numBits | lsb;
}

// Decodes one quint-encoded value out of a group of 3.
// offset: bit offset where the group of quints starts, within the 128 bits of data
// numBits: how many bits are used to encode the LSB (0-5)
// i: index of the value within the group (0-2)
// See section "C.2.12 Integer Sequence Encoding"
uint decodeQuint(uvec4 data, uint offset, uint numBits, uint i) {
    const int n = int(numBits);

    // Unlike trits, the largest group here (1 quint + 5 bits) is only 3 * 5 + 7 = 22 bits long,
    // so the whole group can be fetched in one go.
    const uint group = extractBits(data, offset, 3 * numBits + 7);

    // Gather the 7 bits of the quint pack, which are interleaved with the low bits
    const uint q02 = bitfieldExtract(group, n, 3);
    const uint q34 = bitfieldExtract(group, 2 * n + 3, 2);
    const uint q56 = bitfieldExtract(group, 3 * n + 5, 2);
    const uint quintPack = q02 | q34 << 3 | q56 << 5;

    // Fetch the LSB of the requested value
    const ivec3 gaps = ivec3(0, 3, 5);  // quint-pack bits preceding each value in `group`
    const uint lsb = bitfieldExtract(group, int(i) * n + gaps[i], n);

    const uint quints = kQuintEncodings[quintPack];
    return bitfieldExtract(quints, 3 * int(i), 3) << numBits | lsb;
}

// Decodes and unquantizes the weight at position `index` in the weight stream.
// Returns 0 if `index` is past the last weight.
uint decode1Weight(uvec4 weightData, uvec3 encoding, uint numWeights, uint index) {
    if (index >= numWeights) return 0;

    const uint numBits = encoding.z;

    if (encoding.x == 1) {
        // 1 trit: values come in groups of 5
        const uint groupStart = (index / 5) * (5 * numBits + 8);
        const uint quantized = decodeTrit(weightData, groupStart, numBits, index % 5);
        return kUnquantTritWeightMap[3 * ((1 << numBits) - 1) + quantized];
    }

    if (encoding.y == 1) {
        // 1 quint: values come in groups of 3
        const uint groupStart = (index / 3) * (3 * numBits + 7);
        const uint quantized = decodeQuint(weightData, groupStart, numBits, index % 3);
        return kUnquantQuintWeightMap[5 * ((1 << numBits) - 1) + quantized];
    }

    // only bits, no trits or quints. We can have between 1 and 6 bits.
    const uint quantized = extractBits(weightData, index * numBits, numBits);

    // The first number in the table is the multiplication factor: 63 / (2^numBits - 1)
    // The second number is a shift factor to adjust when the previous result isn't an integer.
    const uvec2 kUnquantTable[] = {{63, 8}, {21, 8}, {9, 8}, {4, 2}, {2, 4}, {1, 8}};
    const uvec2 unquant = kUnquantTable[numBits - 1];
    uint weight = quantized * unquant.x | quantized >> unquant.y;
    if (weight > 32) weight += 1;
    return weight;
}

// Bilinearly interpolates between the 4 weights of the grid cell whose top-left weight is
// `index`. `stride` and `offset` select the plane when the weights of two planes are interleaved.
uint interpolateWeights(uvec4 weightData, uvec3 encoding, uint numWeights, uint index,
                        uint gridWidth, uint stride, uint offset, uvec2 fractionalPart) {
    const uvec4 corners = uvec4(0, 1, gridWidth, gridWidth + 1);
    const uvec4 cornerIndices = stride * (uvec4(index) + corners) + offset;

    // TODO(gregschlom): Optimization idea: instead of always decoding 4 weights, we could decode
    // just what we need depending on whether fractionalPart.x and fractionalPart.y are 0
    uvec4 cornerWeights;
    for (int c = 0; c < 4; ++c) {
        cornerWeights[c] = decode1Weight(weightData, encoding, numWeights, cornerIndices[c]);
    }

    const uint fx = fractionalPart.x;
    const uint fy = fractionalPart.y;
    const uint w11 = (fx * fy + 8) >> 4;
    const uvec4 blend = uvec4(16 - fx - fy + w11,  // w00
                              fx - w11,            // w01
                              fy - w11,            // w10
                              w11);                // w11

    // this is what the spec calls "effective weight"
    return uint(dot(cornerWeights, blend) + 8) >> 4;
}

// Returns the effective weight of each plane (x: first plane, y: second plane or 0 when the block
// is not dual plane) for the texel at `posInBlock`.
uvec2 decodeWeights(uvec4 weightData, const uvec2 posInBlock) {
    // Refer to "C.2.18 Weight Infill to interpolate between 4 grid points"

    // TODO(gregschlom): The spec says: "since the block dimensions are constrained, these are
    // easily looked up in a table." - Is it worth doing?
    const uvec2 scale = (1024 + blockSize / 2) / (blockSize - 1);

    // Map the texel position onto the weight grid, as a fixed-point number with 4 fractional bits
    const uvec2 gridPos = (posInBlock * scale * (weightGridSize - 1) + 32) >> 6;
    const uvec2 cell = gridPos >> 4;
    const uvec2 fraction = gridPos & 0xf;

    const uint rowLength = weightGridSize.x;
    const uint topLeft = cell.y * rowLength + cell.x;

    const uint plane0 = interpolateWeights(weightData, weightEncoding, numWeights, topLeft,
                                           rowLength, 1 << int(dualPlane), 0, fraction);
    uint plane1 = 0;
    if (dualPlane) {
        plane1 = interpolateWeights(weightData, weightEncoding, numWeights, topLeft, rowLength, 2,
                                    1, fraction);
    }
    return uvec2(plane0, plane1);
}

// Bit-mixing hash used to pick the partition pattern.
uint hash52(uint p) {
    uint h = p;
    h ^= h >> 15;
    h -= h << 17;
    h += h << 7;
    h += h << 4;
    h ^= h >> 5;
    h += h << 16;
    h ^= h >> 7;
    h ^= h >> 3;
    h ^= h << 6;
    h ^= h >> 17;
    return h;
}

// Returns the index of the partition (0-3) that the texel at `pos` belongs to.
uint selectPartition(uint seed, uvec2 pos, uint numPartitions) {
    if (numPartitions == 1) return 0;

    // Small blocks use doubled coordinates
    if (blockSize.x * blockSize.y < 31) pos <<= 1;

    seed = 1024 * numPartitions + (seed - 1024);
    const uint rnum = hash52(seed);

    // Split the hash into eight 4-bit seeds
    // TODO(gregschlom): micro-optimization: could repeatedly halve the bits to extract them in 6
    // calls to bitfieldExtract instead of 8.
    uvec4 seedA = uvec4(bitfieldExtract(rnum, 0, 4), bitfieldExtract(rnum, 4, 4),
                        bitfieldExtract(rnum, 8, 4), bitfieldExtract(rnum, 12, 4));
    uvec4 seedB = uvec4(bitfieldExtract(rnum, 16, 4), bitfieldExtract(rnum, 20, 4),
                        bitfieldExtract(rnum, 24, 4), bitfieldExtract(rnum, 28, 4));
    seedA *= seedA;
    seedB *= seedB;

    // Two shift amounts alternate across the seeds; bit 0 of the seed decides which goes first.
    const uint shiftFromSeed = (seed & 2) != 0 ? 4u : 5u;
    const uint shiftFromCount = numPartitions == 3 ? 6u : 5u;
    const uvec2 shiftPair = (seed & 1) != 0 ? uvec2(shiftFromSeed, shiftFromCount)
                                            : uvec2(shiftFromCount, shiftFromSeed);
    seedA >>= shiftPair.xyxy;
    seedB >>= shiftPair.xyxy;

    // Note: this could be implemented as matrix multiplication, but we'd have to use floats and
    // I'm not sure if the values are always small enough to stay accurate.
    uvec4 score =
            uvec4(dot(seedA.xy, pos), dot(seedA.zw, pos), dot(seedB.xy, pos), dot(seedB.zw, pos)) +
            (uvec4(rnum) >> uvec4(14, 10, 6, 2));
    score &= uvec4(0x3F);

    // Partitions that don't exist can never win
    if (numPartitions == 2 || numPartitions == 3) score.w = 0;
    if (numPartitions == 2) score.z = 0;

    // The winner is the largest component of `score`; ties go to the lowest index
    if (score.x >= score.y && score.x >= score.z && score.x >= score.w) return 0;
    if (score.y >= score.z && score.y >= score.w) return 1;
    return score.z >= score.w ? 2 : 3;
}

// Returns the first encoding in kColorEncodings able to store `numEndpoints` values in at most
// `availableEndpointBits` bits, and writes the number of bits it really uses to `actualSize`.
uvec3 getEndpointEncoding(uint availableEndpointBits, uint numEndpoints, out uint actualSize) {
    // This implements the algorithm described in section "C.2.22 Data Size Determination"
    // TODO(gregschlom): This could be implemented with a lookup table instead. Or we could use a
    // binary search but not sure if worth it due to the extra cost of branching.
    for (uint i = 0; i < kColorEncodings.length(); ++i) {
        const uvec3 candidate = kColorEncodings[i];
        actualSize = getEncodingSize(numEndpoints, candidate);
        if (actualSize > availableEndpointBits) continue;
        return candidate;
    }
    return uvec3(0);  // this should never happen
}

// Averages the red and green channels with blue; blue and alpha are left untouched.
ivec4 blueContract(ivec4 v) { return ivec4((v.rg + v.bb) >> 1, v.b, v.a); }

// Sum of the three components of `v`.
int sum(ivec3 v) { return v.r + v.g + v.b; }

// Moves the top bit of `a` into the top bit of `b`, then turns what is left of `a` into a signed
// 6-bit value.
void bitTransferSigned(inout ivec4 a, inout ivec4 b) {
    b = (b >> 1) | (a & 0x80);
    // Sign-extending the low 6 bits is equivalent to: "a &= 0x3f; if ((a & 0x20) != 0) a -= 0x40;"
    // in the spec. It treats "a" as a 6-bit signed integer, converting it from (0, 63) to
    // (-32, 31)
    a = bitfieldExtract(a >> 1, 0, 6);
}

// Decodes the endpoints and writes them to ep0 and ep1.
// vA: even-numbered values in the spec (ie: v0, v2, v4 and v6)
// vB: odd-numbered values in the spec (ie: v1, v3, v5 and v7)
// mode: the CEM (color endpoint mode)
// Note: HDR modes are not supported.
void decodeEndpoints(ivec4 vA, ivec4 vB, uint mode, out uvec4 ep0, out uvec4 ep1) {
    switch (mode) {
        case 0:  // LDR luminance only, direct
            ep0 = uvec4(vA.xxx, 255);
            ep1 = uvec4(vB.xxx, 255);
            return;

        case 1: {  // LDR luminance only, base + offset
            const int l0 = (vA.x >> 2) | (vB.x & 0xC0);
            const int l1 = min(l0 + (vB.x & 0x3F), 255);
            ep0 = uvec4(uvec3(l0), 255);
            ep1 = uvec4(uvec3(l1), 255);
            return;
        }

        case 4:  // LDR luminance + alpha, direct
            ep0 = vA.xxxy;
            ep1 = vB.xxxy;
            return;

        case 5:  // LDR luminance + alpha, base + offset
            bitTransferSigned(vB, vA);
            ep0 = clamp(vA.xxxy, 0, 255);
            ep1 = clamp(vA.xxxy + vB.xxxy, 0, 255);
            return;

        case 6:  // LDR RGB, base + scale
            ep1 = uvec4(vA.x, vB.x, vA.y, 255);
            ep0 = uvec4((ep1.rgb * vB.y) >> 8, 255);
            return;

        case 10:  // LDR RGB, base + scale, plus alphas
            ep1 = uvec4(vA.x, vB.x, vA.y, vB.z);
            ep0 = uvec4((ep1.rgb * vB.y) >> 8, vA.z);
            return;

        case 8:  // LDR RGB, direct
            vA.a = 255;
            vB.a = 255;
            // Intentional fallthrough: same as RGBA direct, with opaque alpha.
        case 12:  // LDR RGBA, direct
            if (sum(vB.rgb) >= sum(vA.rgb)) {
                ep0 = vA;
                ep1 = vB;
            } else {
                ep0 = blueContract(vB);
                ep1 = blueContract(vA);
            }
            return;

        case 9:  // LDR RGB, base + offset
            // We will end up with vA.a = 255 and vB.a = 0 after calling bitTransferSigned(vB, vA)
            vA.a = 255;
            vB.a = -128;
            // Intentional fallthrough: same as RGBA base + offset, with opaque alpha.
        case 13:  // LDR RGBA, base + offset
            bitTransferSigned(vB, vA);
            if (sum(vB.rgb) >= 0) {
                ep0 = clamp(vA, 0, 255);
                ep1 = clamp(vA + vB, 0, 255);
            } else {
                ep0 = clamp(blueContract(vA + vB), 0, 255);
                ep1 = clamp(blueContract(vA), 0, 255);
            }
            return;

        default:
            // Unimplemented color encoding. (HDR)
            ep0 = uvec4(0);
            ep1 = uvec4(0);
    }
}

// Decodes and unquantizes a single color endpoint value.
// data: the ASTC block, with everything above the endpoint data masked out
// startOffset: bit position where the endpoint data starts
// index: index of the value to decode within the endpoint data
// encoding: number of trits (x), quints (y) and bits (z) used to encode each value
uint decode1Endpoint(uvec4 data, uint startOffset, uint index, uvec3 encoding) {
    uint numBits = encoding.z;

    if (encoding.x == 1) {
        // 1 trit
        uint offset = (index / 5) * (5 * numBits + 8) + startOffset;
        uint ep = decodeTrit(data, offset, numBits, index % 5);
        return kUnquantTritColorMap[3 * ((1 << numBits) - 1) + ep];
    } else if (encoding.y == 1) {
        // 1 quint
        uint offset = (index / 3) * (3 * numBits + 7) + startOffset;
        uint ep = decodeQuint(data, offset, numBits, index % 3);
        return kUnquantQuintColorMap[5 * ((1 << numBits) - 1) + ep];
    } else {
        // only bits, no trits or quints. We can have between 1 and 8 bits.
        uint offset = index * numBits + startOffset;
        uint w = extractBits(data, offset, numBits);
        // The first number in the table is the multiplication factor. 255 / (2^numBits - 1)
        // The second number is a shift factor to adjust when the previous result isn't an integer.
        const uvec2 kUnquantTable[] = {{255, 8}, {85, 8}, {36, 1}, {17, 8},
                                       {8, 2},   {4, 4},  {2, 6},  {1, 8}};
        const uvec2 unquant = kUnquantTable[numBits - 1];
        return w * unquant.x | w >> unquant.y;
    }
}

// Creates a 128-bit mask with the lower n bits set to 1
// The mask uses the same layout as the ASTC block: component 3 holds bits 0-31.
uvec4 buildBitmask(uint bits) {
    ivec4 numBits = int(bits) - ivec4(96, 64, 32, 0);
    uvec4 mask = (uvec4(1) << clamp(numBits, ivec4(0), ivec4(31))) - 1;
    // A shift can't produce a full 32-bit mask, so fully covered components are set explicitly.
    return mix(mask, uvec4(0xffffffffu), greaterThanEqual(uvec4(bits), uvec4(128, 96, 64, 32)));
}

// Main function to decode the texel at a given position in the block
// astcDecoderInitialize() must have been called first. Returns the decoded RGBA color, or
// kErrorColor if the block is not a legal encoding.
uvec4 astcDecodeTexel(const uvec2 posInBlock) {
    if (decodeError) {
        return kErrorColor;
    }

    if (voidExtent) {
        // The constant color is stored as four 16-bit values; keep the top 8 bits of each.
        return uvec4(bitfieldExtract(astcBlock[1], 8, 8), bitfieldExtract(astcBlock[1], 24, 8),
                     bitfieldExtract(astcBlock[0], 8, 8), bitfieldExtract(astcBlock[0], 24, 8));
    }

    // The weights are stored from bit 127 downwards, so reverse the whole block to read them from
    // bit 0 upwards.
    const uvec4 weightData = bitfieldReverse(astcBlock.wzyx) & buildBitmask(weightDataSize);
    const uvec2 weights = decodeWeights(weightData, posInBlock);

    const uint partitionIndex = selectPartition(partitionSeed, posInBlock, numPartitions);

    uint startOfExtraCem = 0;
    uint totalEndpoints = 0;
    uint baseEndpointIndex = 0;
    uint cem = decodeCEM(partitionIndex, startOfExtraCem, totalEndpoints, baseEndpointIndex);

    // Per spec, we must return the error color if we require more than 18 color endpoints
    if (totalEndpoints > 18) {
        return kErrorColor;
    }

    const uint endpointsStart = (numPartitions == 1) ? 17 : 29;
    // The 2 bits of the color component selector sit right below the extra CEM data.
    const uint endpointsEnd = startOfExtraCem - 2u * uint(dualPlane);

    // With large weight data and several partitions, the fields can overlap. Bail out before the
    // unsigned subtraction below wraps around to a huge bit count.
    if (endpointsEnd < endpointsStart) {
        return kErrorColor;
    }
    const uint availableEndpointBits = endpointsEnd - endpointsStart;

    // Per section C.2.24, a block is illegal if it has fewer than ceil(13 * C / 5) bits for its C
    // endpoint values, which is what the smallest color encoding (1 trit + 1 bit) needs.
    if (availableEndpointBits < (13u * totalEndpoints + 4u) / 5u) {
        return kErrorColor;
    }

    uint actualEndpointBits;
    const uvec3 endpointEncoding =
            getEndpointEncoding(availableEndpointBits, totalEndpoints, actualEndpointBits);
    // No encoding fits: decoding with zero bits per value would index the unquantization tables
    // out of bounds.
    if (endpointEncoding == uvec3(0)) {
        return kErrorColor;
    }

    // Number of endpoints pairs in this partition. (Between 1 and 4)
    // This is the n field from "Table C.2.17 - Color Endpoint Modes" divided by two
    const uint numEndpointPairs = (cem >> 2) + 1;

    ivec4 vA = ivec4(0);  // holds what the spec calls v0, v2, v4 and v6
    ivec4 vB = ivec4(0);  // holds what the spec calls v1, v3, v5 and v7

    // Mask out everything above the endpoint data, so that a partial trit / quint group at the
    // end of the data reads zeros.
    uvec4 epData = astcBlock & buildBitmask(endpointsStart + actualEndpointBits);

    for (uint i = 0; i < numEndpointPairs; ++i) {
        const uint epIdx = 2 * i + baseEndpointIndex;
        vA[i] = int(decode1Endpoint(epData, endpointsStart, epIdx, endpointEncoding));
        vB[i] = int(decode1Endpoint(epData, endpointsStart, epIdx + 1, endpointEncoding));
    }

    uvec4 ep0, ep1;
    decodeEndpoints(vA, vB, cem, ep0, ep1);

    uvec4 weightsPerChannel = uvec4(weights[0]);
    if (dualPlane) {
        // The color component selector picks the channel that uses the second set of weights.
        uint ccs = extractBits(astcBlock, endpointsEnd, 2);
        weightsPerChannel[ccs] = weights[1];
    }

    // TODO(gregschlom): Check section "C.2.19 Weight Application" - we're supposed to do something
    // else here, depending on whether we're using sRGB or not. Currently we have a difference of up
    // to 1 when compared against the reference decoder. Probably not worth trying to fix it though.
    return (ep0 * (64 - weightsPerChannel) + ep1 * weightsPerChannel + 32) >> 6;
}