1 /*
2  * Copyright (C) 2021 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.internal.graphics.palette;
18 
19 import static java.lang.System.arraycopy;
20 
21 import android.annotation.NonNull;
22 import android.annotation.Nullable;
23 import android.graphics.Color;
24 
25 import java.util.ArrayList;
26 import java.util.List;
27 import java.util.Map;
28 import java.util.Set;
29 
30 
31 /**
32  * Wu's quantization algorithm is a box-cut quantizer that minimizes variance. It takes longer to
33  * run than, say, median color cut, but provides the highest quality results currently known.
34  *
35  * Prefer `QuantizerCelebi`: coupled with Kmeans, this provides the best-known results for image
36  * quantization.
37  *
38  * Seemingly all Wu implementations are based off of one C code snippet that cites a book from 1992
39  * Graphics Gems vol. II, pp. 126-133. As a result, it is very hard to understand the mechanics of
40  * the algorithm, beyond the commentary provided in the C code. Comments on the methods of this
41  * class are avoided in favor of finding another implementation and reading the commentary there,
42  * avoiding perpetuating the same incomplete and somewhat confusing commentary here.
43  */
44 public final class WuQuantizer implements Quantizer {
45     // A histogram of all the input colors is constructed. It has the shape of a
46     // cube. The cube would be too large if it contained all 16 million colors:
47     // historical best practice is to use 5 bits  of the 8 in each channel,
48     // reducing the histogram to a volume of ~32,000.
49     private static final int BITS = 5;
50     private static final int MAX_INDEX = 32;
51     private static final int SIDE_LENGTH = 33;
52     private static final int TOTAL_SIZE = 35937;
53 
54     private int[] mWeights;
55     private int[] mMomentsR;
56     private int[] mMomentsG;
57     private int[] mMomentsB;
58     private double[] mMoments;
59     private Box[] mCubes;
60     private Palette mPalette;
61     private int[] mColors;
62     private Map<Integer, Integer> mInputPixelToCount;
63 
64     @Override
getQuantizedColors()65     public List<Palette.Swatch> getQuantizedColors() {
66         return mPalette.getSwatches();
67     }
68 
69     @Override
quantize(@onNull int[] pixels, int colorCount)70     public void quantize(@NonNull int[] pixels, int colorCount) {
71         assert (pixels.length > 0);
72 
73         QuantizerMap quantizerMap = new QuantizerMap();
74         quantizerMap.quantize(pixels, colorCount);
75         mInputPixelToCount = quantizerMap.getColorToCount();
76         // Extraction should not be run on using a color count higher than the number of colors
77         // in the pixels. The algorithm doesn't expect that to be the case, unexpected results and
78         // exceptions may occur.
79         Set<Integer> uniqueColors = mInputPixelToCount.keySet();
80         if (uniqueColors.size() <= colorCount) {
81             mColors = new int[mInputPixelToCount.keySet().size()];
82             int index = 0;
83             for (int color : uniqueColors) {
84                 mColors[index++] = color;
85             }
86         } else {
87             constructHistogram(mInputPixelToCount);
88             createMoments();
89             CreateBoxesResult createBoxesResult = createBoxes(colorCount);
90             mColors = createResult(createBoxesResult.mResultCount);
91         }
92 
93         List<Palette.Swatch> swatches = new ArrayList<>();
94         for (int color : mColors) {
95             swatches.add(new Palette.Swatch(color, 0));
96         }
97         mPalette = Palette.from(swatches);
98     }
99 
100     @Nullable
getColors()101     public int[] getColors() {
102         return mColors;
103     }
104 
105     /** Keys are color ints, values are the number of pixels in the image matching that color int */
106     @Nullable
inputPixelToCount()107     public Map<Integer, Integer> inputPixelToCount() {
108         return mInputPixelToCount;
109     }
110 
getIndex(int r, int g, int b)111     private static int getIndex(int r, int g, int b) {
112         return (r << 10) + (r << 6) + (g << 5) + r + g + b;
113     }
114 
constructHistogram(Map<Integer, Integer> pixels)115     private void constructHistogram(Map<Integer, Integer> pixels) {
116         mWeights = new int[TOTAL_SIZE];
117         mMomentsR = new int[TOTAL_SIZE];
118         mMomentsG = new int[TOTAL_SIZE];
119         mMomentsB = new int[TOTAL_SIZE];
120         mMoments = new double[TOTAL_SIZE];
121 
122         for (Map.Entry<Integer, Integer> pair : pixels.entrySet()) {
123             int pixel = pair.getKey();
124             int count = pair.getValue();
125             int red = Color.red(pixel);
126             int green = Color.green(pixel);
127             int blue = Color.blue(pixel);
128             int bitsToRemove = 8 - BITS;
129             int iR = (red >> bitsToRemove) + 1;
130             int iG = (green >> bitsToRemove) + 1;
131             int iB = (blue >> bitsToRemove) + 1;
132             int index = getIndex(iR, iG, iB);
133             mWeights[index] += count;
134             mMomentsR[index] += (red * count);
135             mMomentsG[index] += (green * count);
136             mMomentsB[index] += (blue * count);
137             mMoments[index] += (count * ((red * red) + (green * green) + (blue * blue)));
138         }
139     }
140 
createMoments()141     private void createMoments() {
142         for (int r = 1; r < SIDE_LENGTH; ++r) {
143             int[] area = new int[SIDE_LENGTH];
144             int[] areaR = new int[SIDE_LENGTH];
145             int[] areaG = new int[SIDE_LENGTH];
146             int[] areaB = new int[SIDE_LENGTH];
147             double[] area2 = new double[SIDE_LENGTH];
148 
149             for (int g = 1; g < SIDE_LENGTH; ++g) {
150                 int line = 0;
151                 int lineR = 0;
152                 int lineG = 0;
153                 int lineB = 0;
154 
155                 double line2 = 0.0;
156                 for (int b = 1; b < SIDE_LENGTH; ++b) {
157                     int index = getIndex(r, g, b);
158                     line += mWeights[index];
159                     lineR += mMomentsR[index];
160                     lineG += mMomentsG[index];
161                     lineB += mMomentsB[index];
162                     line2 += mMoments[index];
163 
164                     area[b] += line;
165                     areaR[b] += lineR;
166                     areaG[b] += lineG;
167                     areaB[b] += lineB;
168                     area2[b] += line2;
169 
170                     int previousIndex = getIndex(r - 1, g, b);
171                     mWeights[index] = mWeights[previousIndex] + area[b];
172                     mMomentsR[index] = mMomentsR[previousIndex] + areaR[b];
173                     mMomentsG[index] = mMomentsG[previousIndex] + areaG[b];
174                     mMomentsB[index] = mMomentsB[previousIndex] + areaB[b];
175                     mMoments[index] = mMoments[previousIndex] + area2[b];
176                 }
177             }
178         }
179     }
180 
createBoxes(int maxColorCount)181     private CreateBoxesResult createBoxes(int maxColorCount) {
182         mCubes = new Box[maxColorCount];
183         for (int i = 0; i < maxColorCount; i++) {
184             mCubes[i] = new Box();
185         }
186         double[] volumeVariance = new double[maxColorCount];
187         Box firstBox = mCubes[0];
188         firstBox.r1 = MAX_INDEX;
189         firstBox.g1 = MAX_INDEX;
190         firstBox.b1 = MAX_INDEX;
191 
192         int generatedColorCount = 0;
193         int next = 0;
194 
195         for (int i = 1; i < maxColorCount; i++) {
196             if (cut(mCubes[next], mCubes[i])) {
197                 volumeVariance[next] = (mCubes[next].vol > 1) ? variance(mCubes[next]) : 0.0;
198                 volumeVariance[i] = (mCubes[i].vol > 1) ? variance(mCubes[i]) : 0.0;
199             } else {
200                 volumeVariance[next] = 0.0;
201                 i--;
202             }
203 
204             next = 0;
205 
206             double temp = volumeVariance[0];
207             for (int k = 1; k <= i; k++) {
208                 if (volumeVariance[k] > temp) {
209                     temp = volumeVariance[k];
210                     next = k;
211                 }
212             }
213             generatedColorCount = i + 1;
214             if (temp <= 0.0) {
215                 break;
216             }
217         }
218 
219         return new CreateBoxesResult(maxColorCount, generatedColorCount);
220     }
221 
createResult(int colorCount)222     private int[] createResult(int colorCount) {
223         int[] colors = new int[colorCount];
224         int nextAvailableIndex = 0;
225         for (int i = 0; i < colorCount; ++i) {
226             Box cube = mCubes[i];
227             int weight = volume(cube, mWeights);
228             if (weight > 0) {
229                 int r = (volume(cube, mMomentsR) / weight);
230                 int g = (volume(cube, mMomentsG) / weight);
231                 int b = (volume(cube, mMomentsB) / weight);
232                 int color = Color.rgb(r, g, b);
233                 colors[nextAvailableIndex++] = color;
234             }
235         }
236         int[] resultArray = new int[nextAvailableIndex];
237         arraycopy(colors, 0, resultArray, 0, nextAvailableIndex);
238         return resultArray;
239     }
240 
variance(Box cube)241     private double variance(Box cube) {
242         int dr = volume(cube, mMomentsR);
243         int dg = volume(cube, mMomentsG);
244         int db = volume(cube, mMomentsB);
245         double xx =
246                 mMoments[getIndex(cube.r1, cube.g1, cube.b1)]
247                         - mMoments[getIndex(cube.r1, cube.g1, cube.b0)]
248                         - mMoments[getIndex(cube.r1, cube.g0, cube.b1)]
249                         + mMoments[getIndex(cube.r1, cube.g0, cube.b0)]
250                         - mMoments[getIndex(cube.r0, cube.g1, cube.b1)]
251                         + mMoments[getIndex(cube.r0, cube.g1, cube.b0)]
252                         + mMoments[getIndex(cube.r0, cube.g0, cube.b1)]
253                         - mMoments[getIndex(cube.r0, cube.g0, cube.b0)];
254 
255         int hypotenuse = (dr * dr + dg * dg + db * db);
256         int volume2 = volume(cube, mWeights);
257         double variance2 = xx - ((double) hypotenuse / (double) volume2);
258         return variance2;
259     }
260 
cut(Box one, Box two)261     private boolean cut(Box one, Box two) {
262         int wholeR = volume(one, mMomentsR);
263         int wholeG = volume(one, mMomentsG);
264         int wholeB = volume(one, mMomentsB);
265         int wholeW = volume(one, mWeights);
266 
267         MaximizeResult maxRResult =
268                 maximize(one, Direction.RED, one.r0 + 1, one.r1, wholeR, wholeG, wholeB, wholeW);
269         MaximizeResult maxGResult =
270                 maximize(one, Direction.GREEN, one.g0 + 1, one.g1, wholeR, wholeG, wholeB, wholeW);
271         MaximizeResult maxBResult =
272                 maximize(one, Direction.BLUE, one.b0 + 1, one.b1, wholeR, wholeG, wholeB, wholeW);
273         Direction cutDirection;
274         double maxR = maxRResult.mMaximum;
275         double maxG = maxGResult.mMaximum;
276         double maxB = maxBResult.mMaximum;
277         if (maxR >= maxG && maxR >= maxB) {
278             if (maxRResult.mCutLocation < 0) {
279                 return false;
280             }
281             cutDirection = Direction.RED;
282         } else if (maxG >= maxR && maxG >= maxB) {
283             cutDirection = Direction.GREEN;
284         } else {
285             cutDirection = Direction.BLUE;
286         }
287 
288         two.r1 = one.r1;
289         two.g1 = one.g1;
290         two.b1 = one.b1;
291 
292         switch (cutDirection) {
293             case RED:
294                 one.r1 = maxRResult.mCutLocation;
295                 two.r0 = one.r1;
296                 two.g0 = one.g0;
297                 two.b0 = one.b0;
298                 break;
299             case GREEN:
300                 one.g1 = maxGResult.mCutLocation;
301                 two.r0 = one.r0;
302                 two.g0 = one.g1;
303                 two.b0 = one.b0;
304                 break;
305             case BLUE:
306                 one.b1 = maxBResult.mCutLocation;
307                 two.r0 = one.r0;
308                 two.g0 = one.g0;
309                 two.b0 = one.b1;
310                 break;
311             default:
312                 throw new IllegalArgumentException("unexpected direction " + cutDirection);
313         }
314 
315         one.vol = (one.r1 - one.r0) * (one.g1 - one.g0) * (one.b1 - one.b0);
316         two.vol = (two.r1 - two.r0) * (two.g1 - two.g0) * (two.b1 - two.b0);
317 
318         return true;
319     }
320 
maximize( Box cube, Direction direction, int first, int last, int wholeR, int wholeG, int wholeB, int wholeW)321     private MaximizeResult maximize(
322             Box cube,
323             Direction direction,
324             int first,
325             int last,
326             int wholeR,
327             int wholeG,
328             int wholeB,
329             int wholeW) {
330         int baseR = bottom(cube, direction, mMomentsR);
331         int baseG = bottom(cube, direction, mMomentsG);
332         int baseB = bottom(cube, direction, mMomentsB);
333         int baseW = bottom(cube, direction, mWeights);
334 
335         double max = 0.0;
336         int cut = -1;
337         for (int i = first; i < last; i++) {
338             int halfR = baseR + top(cube, direction, i, mMomentsR);
339             int halfG = baseG + top(cube, direction, i, mMomentsG);
340             int halfB = baseB + top(cube, direction, i, mMomentsB);
341             int halfW = baseW + top(cube, direction, i, mWeights);
342 
343             if (halfW == 0) {
344                 continue;
345             }
346             double tempNumerator = halfR * halfR + halfG * halfG + halfB * halfB;
347             double tempDenominator = halfW;
348             double temp = tempNumerator / tempDenominator;
349 
350             halfR = wholeR - halfR;
351             halfG = wholeG - halfG;
352             halfB = wholeB - halfB;
353             halfW = wholeW - halfW;
354             if (halfW == 0) {
355                 continue;
356             }
357 
358             tempNumerator = halfR * halfR + halfG * halfG + halfB * halfB;
359             tempDenominator = halfW;
360             temp += (tempNumerator / tempDenominator);
361             if (temp > max) {
362                 max = temp;
363                 cut = i;
364             }
365         }
366         return new MaximizeResult(cut, max);
367     }
368 
volume(Box cube, int[] moment)369     private static int volume(Box cube, int[] moment) {
370         return (moment[getIndex(cube.r1, cube.g1, cube.b1)]
371                 - moment[getIndex(cube.r1, cube.g1, cube.b0)]
372                 - moment[getIndex(cube.r1, cube.g0, cube.b1)]
373                 + moment[getIndex(cube.r1, cube.g0, cube.b0)]
374                 - moment[getIndex(cube.r0, cube.g1, cube.b1)]
375                 + moment[getIndex(cube.r0, cube.g1, cube.b0)]
376                 + moment[getIndex(cube.r0, cube.g0, cube.b1)]
377                 - moment[getIndex(cube.r0, cube.g0, cube.b0)]);
378     }
379 
bottom(Box cube, Direction direction, int[] moment)380     private static int bottom(Box cube, Direction direction, int[] moment) {
381         switch (direction) {
382             case RED:
383                 return -moment[getIndex(cube.r0, cube.g1, cube.b1)]
384                         + moment[getIndex(cube.r0, cube.g1, cube.b0)]
385                         + moment[getIndex(cube.r0, cube.g0, cube.b1)]
386                         - moment[getIndex(cube.r0, cube.g0, cube.b0)];
387             case GREEN:
388                 return -moment[getIndex(cube.r1, cube.g0, cube.b1)]
389                         + moment[getIndex(cube.r1, cube.g0, cube.b0)]
390                         + moment[getIndex(cube.r0, cube.g0, cube.b1)]
391                         - moment[getIndex(cube.r0, cube.g0, cube.b0)];
392             case BLUE:
393                 return -moment[getIndex(cube.r1, cube.g1, cube.b0)]
394                         + moment[getIndex(cube.r1, cube.g0, cube.b0)]
395                         + moment[getIndex(cube.r0, cube.g1, cube.b0)]
396                         - moment[getIndex(cube.r0, cube.g0, cube.b0)];
397             default:
398                 throw new IllegalArgumentException("unexpected direction " + direction);
399         }
400     }
401 
top(Box cube, Direction direction, int position, int[] moment)402     private static int top(Box cube, Direction direction, int position, int[] moment) {
403         switch (direction) {
404             case RED:
405                 return (moment[getIndex(position, cube.g1, cube.b1)]
406                         - moment[getIndex(position, cube.g1, cube.b0)]
407                         - moment[getIndex(position, cube.g0, cube.b1)]
408                         + moment[getIndex(position, cube.g0, cube.b0)]);
409             case GREEN:
410                 return (moment[getIndex(cube.r1, position, cube.b1)]
411                         - moment[getIndex(cube.r1, position, cube.b0)]
412                         - moment[getIndex(cube.r0, position, cube.b1)]
413                         + moment[getIndex(cube.r0, position, cube.b0)]);
414             case BLUE:
415                 return (moment[getIndex(cube.r1, cube.g1, position)]
416                         - moment[getIndex(cube.r1, cube.g0, position)]
417                         - moment[getIndex(cube.r0, cube.g1, position)]
418                         + moment[getIndex(cube.r0, cube.g0, position)]);
419             default:
420                 throw new IllegalArgumentException("unexpected direction " + direction);
421         }
422     }
423 
424     private enum Direction {
425         RED,
426         GREEN,
427         BLUE
428     }
429 
430     private static class MaximizeResult {
431         // < 0 if cut impossible
432         final int mCutLocation;
433         final double mMaximum;
434 
MaximizeResult(int cut, double max)435         MaximizeResult(int cut, double max) {
436             mCutLocation = cut;
437             mMaximum = max;
438         }
439     }
440 
441     private static class CreateBoxesResult {
442         final int mRequestedCount;
443         final int mResultCount;
444 
CreateBoxesResult(int requestedCount, int resultCount)445         CreateBoxesResult(int requestedCount, int resultCount) {
446             mRequestedCount = requestedCount;
447             mResultCount = resultCount;
448         }
449     }
450 
451     private static class Box {
452         public int r0 = 0;
453         public int r1 = 0;
454         public int g0 = 0;
455         public int g1 = 0;
456         public int b0 = 0;
457         public int b1 = 0;
458         public int vol = 0;
459     }
460 }
461 
462 
463