1 /*
2  * Copyright 2022 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package platform.test.screenshot.matchers
18 
19 import android.graphics.Color
20 import android.graphics.Rect
21 import androidx.annotation.FloatRange
22 import kotlin.collections.List
23 import kotlin.math.pow
24 import platform.test.screenshot.proto.ScreenshotResultProto
25 
/**
 * Image comparison using the mean Structural Similarity Index (MSSIM), developed by Wang, Bovik,
 * Sheikh, and Simoncelli. Details can be read in their paper:
 * https://ece.uwaterloo.ca/~z70wang/publications/ssim.pdf
 *
 * Both images are split into non-overlapping windows of `WINDOW_SIZE` x `WINDOW_SIZE` pixels
 * (truncated at the right and bottom edges). SSIM is computed per window on the luminosity of each
 * pixel, and the per-window scores are averaged, weighted by the number of pixels compared in each
 * window. Windows whose non-filtered pixels are all white in both images are skipped and counted
 * as ignored.
 *
 * @param threshold fraction of compared pixels that must be reported as similar for the images to
 *   match. Since the similar-pixel count is the mean SSIM scaled by the compared-pixel count, this
 *   is effectively a lower bound on the mean SSIM.
 */
class MSSIMMatcher(@FloatRange(from = 0.0, to = 1.0) private val threshold: Double = 0.98) :
    BitmapMatcher() {

    companion object {
        // Stabilizing constants of the SSIM formula.
        // NOTE(review): these are NOT the values suggested in the publication (K1 = 0.01,
        // K2 = 0.03, L = 255 for 8-bit data), and intensities here are normalized to [0, 1].
        // They are kept unchanged because existing thresholds are presumably tuned against them.
        private const val CONSTANT_L = 254.0
        private const val CONSTANT_K1 = 0.00001
        private const val CONSTANT_K2 = 0.00003
        private val CONSTANT_C1 = (CONSTANT_L * CONSTANT_K1).pow(2.0)
        private val CONSTANT_C2 = (CONSTANT_L * CONSTANT_K2).pow(2.0)

        /** Side length, in pixels, of the square windows SSIM is computed over. */
        private const val WINDOW_SIZE = 10

        // Luminosity weights (l = 0.21R + 0.72G + 0.07B). Declared as Float and widened to Double
        // at use, so the computed intensities are numerically identical to Float-literal weights.
        private const val LUMA_RED = 0.21f
        private const val LUMA_GREEN = 0.72f
        private const val LUMA_BLUE = 0.07f

        /** Maximum value of an 8-bit color channel; used to normalize channels to [0, 1]. */
        private const val CHANNEL_MAX = 255.0
    }

    /**
     * Compares [expected] and [given] using MSSIM, restricted to the pixels selected by [regions].
     *
     * On a mismatch the diff image is produced by [PixelPerfectMatcher], because SSIM itself does
     * not yield a per-pixel diff. The comparison statistics always come from the SSIM computation.
     */
    override fun compareBitmaps(
        expected: IntArray,
        given: IntArray,
        width: Int,
        height: Int,
        regions: List<Rect>
    ): MatchResult {
        val filter = getFilter(width, height, regions)
        val ssimResult = calculateSSIM(expected, given, width, height, filter)

        val stats =
            ScreenshotResultProto.DiffResult.ComparisonStatistics.newBuilder()
                .setNumberPixelsCompared(ssimResult.numPixelsCompared)
                .setNumberPixelsSimilar(ssimResult.numPixelsSimilar)
                .setNumberPixelsIgnored(ssimResult.numPixelsIgnored)
                .setNumberPixelsDifferent(ssimResult.numPixelsCompared - ssimResult.numPixelsSimilar)
                .build()

        if (ssimResult.numPixelsSimilar >= threshold * ssimResult.numPixelsCompared) {
            return MatchResult(matches = true, diff = null, comparisonStatistics = stats)
        }

        // SSIM gives no per-pixel diff; delegate to the pixel-perfect matcher to build one.
        val pixelDiff = PixelPerfectMatcher().compareBitmaps(expected, given, width, height, regions)
        return MatchResult(matches = false, diff = pixelDiff.diff, comparisonStatistics = stats)
    }

    /**
     * Computes the mean SSIM of [ideal] and [given], both [width] x [height] images stored
     * row-major with a stride of [width].
     *
     * @param filter per-pixel mask indexed like the images; pixels whose entry is `false` are
     *   excluded from the comparison and counted as ignored.
     */
    internal fun calculateSSIM(
        ideal: IntArray,
        given: IntArray,
        width: Int,
        height: Int,
        filter: BooleanArray
    ): SSIMResult = calculateSSIM(ideal, given, 0, width, width, height, filter)

    /**
     * Core MSSIM loop over a [width] x [height] region starting at index [offset] of row-major
     * buffers with row length [stride].
     *
     * Guarantees `numPixelsCompared + numPixelsIgnored == width * height`, and that
     * `numPixelsSimilar` lies in `0..numPixelsCompared`.
     */
    private fun calculateSSIM(
        ideal: IntArray,
        given: IntArray,
        offset: Int,
        stride: Int,
        width: Int,
        height: Int,
        filter: BooleanArray
    ): SSIMResult {
        var weightedSsimSum = 0.0
        var pixelsCompared = 0
        var pixelsIgnored = 0

        for (windowY in 0 until height step WINDOW_SIZE) {
            val windowHeight = computeWindowSize(windowY, height)
            for (windowX in 0 until width step WINDOW_SIZE) {
                val windowWidth = computeWindowSize(windowX, width)
                val start = indexFromXAndY(windowX, windowY, stride, offset)

                // Windows that are blank (white) in both images carry no structure; skip them.
                if (
                    shouldIgnoreWindow(ideal, start, stride, windowWidth, windowHeight, filter) &&
                        shouldIgnoreWindow(given, start, stride, windowWidth, windowHeight, filter)
                ) {
                    pixelsIgnored += windowWidth * windowHeight
                    continue
                }

                // A non-skipped window always has at least one non-filtered pixel, so the means
                // below never divide by zero.
                val (meanX, meanY) =
                    getMeans(ideal, given, filter, start, stride, windowWidth, windowHeight)
                val (varX, varY, covarianceXY) =
                    getVariances(
                        ideal,
                        given,
                        filter,
                        meanX,
                        meanY,
                        start,
                        stride,
                        windowWidth,
                        windowHeight
                    )
                val windowSsim = computeSsim(meanX, meanY, varX, varY, covarianceXY)
                val windowPixelsCompared =
                    numPixelsToCompareInWindow(start, stride, windowWidth, windowHeight, filter)

                weightedSsimSum += windowSsim * windowPixelsCompared
                pixelsCompared += windowPixelsCompared
                // Filtered-out pixels inside a compared window are ignored pixels as well.
                pixelsIgnored += windowWidth * windowHeight - windowPixelsCompared
            }
        }

        // Nothing was compared when every window was skipped (all compared pixels were white in
        // both images) or the region is empty. Report a perfect score instead of the NaN that
        // 0.0 / 0.0 would produce.
        val averageSsim =
            if (pixelsCompared == 0) 1.0 else weightedSsimSum / pixelsCompared.toDouble()
        // SSIM can be negative for anti-correlated windows; keep the count within a valid range.
        val pixelsSimilar =
            (averageSsim * pixelsCompared + 0.5).toInt().coerceIn(0, pixelsCompared)

        return SSIMResult(
            SSIM = averageSsim,
            numPixelsSimilar = pixelsSimilar,
            numPixelsIgnored = pixelsIgnored,
            numPixelsCompared = pixelsCompared
        )
    }

    /**
     * Size of the window starting at [coordinateStart]: `WINDOW_SIZE`, truncated so the window
     * stays within [dimension].
     */
    private fun computeWindowSize(coordinateStart: Int, dimension: Int): Int =
        minOf(WINDOW_SIZE, dimension - coordinateStart)

    /**
     * Whether the pixel at window-relative ([x], [y]) is excluded from comparison, i.e. its
     * [filter] entry is `false`.
     */
    private fun shouldIgnorePixel(
        x: Int,
        y: Int,
        start: Int,
        stride: Int,
        filter: BooleanArray
    ): Boolean = !filter[indexFromXAndY(x, y, stride, start)]

    /**
     * Whether a whole window of [colors] can be skipped: every pixel is either filtered out or
     * exactly [Color.WHITE].
     */
    private fun shouldIgnoreWindow(
        colors: IntArray,
        start: Int,
        stride: Int,
        windowWidth: Int,
        windowHeight: Int,
        filter: BooleanArray
    ): Boolean {
        for (y in 0 until windowHeight) {
            for (x in 0 until windowWidth) {
                val index = indexFromXAndY(x, y, stride, start)
                if (filter[index] && colors[index] != Color.WHITE) {
                    return false
                }
            }
        }
        return true
    }

    /** Counts the pixels in the window whose [filter] entry is `true`. */
    private fun numPixelsToCompareInWindow(
        start: Int,
        stride: Int,
        windowWidth: Int,
        windowHeight: Int,
        filter: BooleanArray
    ): Int {
        var count = 0
        for (y in 0 until windowHeight) {
            for (x in 0 until windowWidth) {
                if (filter[indexFromXAndY(x, y, stride, start)]) {
                    count++
                }
            }
        }
        return count
    }

    /** Index of pixel ([x], [y]) in a row-major buffer with row length [stride], from [offset]. */
    private fun indexFromXAndY(x: Int, y: Int, stride: Int, offset: Int): Int =
        x + y * stride + offset

    /**
     * SSIM of one window:
     * `((2 muX muY + C1)(2 covXY + C2)) / ((muX^2 + muY^2 + C1)(varX + varY + C2))`.
     *
     * @param muX mean intensity of the first image.
     * @param muY mean intensity of the second image.
     * @param sigX variance (not standard deviation) of the first image.
     * @param sigY variance of the second image.
     * @param sigXY covariance of the two images.
     */
    private fun computeSsim(
        muX: Double,
        muY: Double,
        sigX: Double,
        sigY: Double,
        sigXY: Double
    ): Double {
        val numerator = (2 * muX * muY + CONSTANT_C1) * (2 * sigXY + CONSTANT_C2)
        val denominator = (muX * muX + muY * muY + CONSTANT_C1) * (sigX + sigY + CONSTANT_C2)
        return numerator / denominator
    }

    /**
     * Mean intensity of the non-filtered pixels of the window in both images.
     *
     * @return `[mean of pixels0, mean of pixels1]`; NaN if no pixel of the window is selected.
     */
    private fun getMeans(
        pixels0: IntArray,
        pixels1: IntArray,
        filter: BooleanArray,
        start: Int,
        stride: Int,
        windowWidth: Int,
        windowHeight: Int
    ): DoubleArray {
        var sum0 = 0.0
        var sum1 = 0.0
        var count = 0.0
        for (y in 0 until windowHeight) {
            for (x in 0 until windowWidth) {
                if (shouldIgnorePixel(x, y, start, stride, filter)) {
                    continue
                }
                val index = indexFromXAndY(x, y, stride, start)
                sum0 += getIntensity(pixels0[index])
                sum1 += getIntensity(pixels1[index])
                count += 1.0
            }
        }
        return doubleArrayOf(sum0 / count, sum1 / count)
    }

    /**
     * Sample (Bessel-corrected) variances and covariance of the non-filtered pixels of the window.
     *
     * @return `[variance of pixels0, variance of pixels1, covariance]`, all zero when fewer than
     *   two pixels are selected.
     */
    private fun getVariances(
        pixels0: IntArray,
        pixels1: IntArray,
        filter: BooleanArray,
        mean0: Double,
        mean1: Double,
        start: Int,
        stride: Int,
        windowWidth: Int,
        windowHeight: Int
    ): DoubleArray {
        var sumSquares0 = 0.0
        var sumSquares1 = 0.0
        var sumCross = 0.0
        var count = 0
        for (y in 0 until windowHeight) {
            for (x in 0 until windowWidth) {
                if (shouldIgnorePixel(x, y, start, stride, filter)) {
                    continue
                }
                val index = indexFromXAndY(x, y, stride, start)
                val delta0 = getIntensity(pixels0[index]) - mean0
                val delta1 = getIntensity(pixels1[index]) - mean1
                sumSquares0 += delta0 * delta0
                sumSquares1 += delta1 * delta1
                sumCross += delta0 * delta1
                count++
            }
        }
        // Sample statistics divide by (n - 1) and need at least two pixels.
        if (count <= 1) {
            return doubleArrayOf(0.0, 0.0, 0.0)
        }
        val degreesOfFreedom = (count - 1).toDouble()
        return doubleArrayOf(
            sumSquares0 / degreesOfFreedom,
            sumSquares1 / degreesOfFreedom,
            sumCross / degreesOfFreedom
        )
    }

    /**
     * Luminosity of an ARGB [pixel] in [0, 1]: `l = 0.21R + 0.72G + 0.07B`, with each channel
     * normalized to [0, 1] and a gamma of 1 (so no gamma correction is applied).
     */
    private fun getIntensity(pixel: Int): Double =
        LUMA_RED * (Color.red(pixel) / CHANNEL_MAX) +
            LUMA_GREEN * (Color.green(pixel) / CHANNEL_MAX) +
            LUMA_BLUE * (Color.blue(pixel) / CHANNEL_MAX)
}
328 
/**
 * Result of the calculation of SSIM.
 *
 * This is an immutable value holder, so it is a data class: results can be compared, printed, and
 * copied structurally.
 *
 * @param SSIM The mean SSIM over the compared pixels.
 * @param numPixelsSimilar The number of similar pixels, i.e. [SSIM] scaled by [numPixelsCompared]
 *   and rounded.
 * @param numPixelsIgnored The number of ignored pixels.
 * @param numPixelsCompared The number of compared pixels.
 */
data class SSIMResult(
    val SSIM: Double,
    val numPixelsSimilar: Int,
    val numPixelsIgnored: Int,
    val numPixelsCompared: Int
)
342