1 /*
 * Copyright (C) 2024 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package platform.test.motion.truth
18 
19 import com.google.common.truth.Fact
20 import com.google.common.truth.Fact.fact
21 import com.google.common.truth.Fact.simpleFact
22 import com.google.common.truth.FailureMetadata
23 import com.google.common.truth.Subject
24 import com.google.common.truth.Subject.Factory
25 import com.google.common.truth.Truth
26 import platform.test.motion.golden.TimeSeries
27 
/**
 * Truth [Subject] for [TimeSeries] that produces a structured diff on failure.
 *
 * Instead of printing both values, a failed [isEqualTo] reports which frames and feature keys
 * differ. For every feature present on both sides, it also reports the data points that differ at
 * frames present on both sides.
 *
 * Create instances via [timeSeries] (for use with [Truth.assertAbout]) or the [assertThat]
 * shortcut.
 */
class TimeSeriesSubject
private constructor(failureMetadata: FailureMetadata, private val actual: TimeSeries?) :
    Subject(failureMetadata, actual) {

    /**
     * Fails unless the subject is equal to [expected].
     *
     * When both the subject and [expected] are [TimeSeries], the check fails with the facts from
     * [compareTimeSeries], and passes if that comparison reports no differences. In every other
     * case (a `null` subject, or an [expected] of another type), this defers to
     * [Subject.isEqualTo].
     */
    override fun isEqualTo(expected: Any?) {
        if (actual !is TimeSeries || expected !is TimeSeries) {
            super.isEqualTo(expected)
            return
        }

        val facts = compareTimeSeries(expected, actual)
        if (facts.isNotEmpty()) {
            failWithoutActual(facts.first(), *facts.drop(1).toTypedArray())
        }
    }

    /**
     * Returns facts describing every difference found between [expected] and [actual], or an
     * empty list when none is found.
     *
     * Facts are ordered as follows: frame differences, then feature-key differences, then
     * data-point differences grouped per feature. Data points are compared only for features and
     * frames present in both series.
     */
    private fun compareTimeSeries(expected: TimeSeries, actual: TimeSeries): List<Fact> =
        buildList {
            val sharedFrameIndices = addFrameFacts(expected, actual)
            val sharedFeatureKeys = addFeatureKeyFacts(expected, actual)
            sharedFeatureKeys.forEach { featureKey ->
                addDataPointFacts(featureKey, expected, actual, sharedFrameIndices)
            }
        }

    /**
     * Appends facts describing how the frame ids of [actual] differ from those of [expected].
     *
     * @return `(actualIndex, expectedIndex)` pairs for every frame id present in both series, in
     *   the iteration order of [actual]'s frames. If both frame lists are identical, this is the
     *   identity mapping over all frames and no facts are appended.
     */
    private fun MutableList<Fact>.addFrameFacts(
        expected: TimeSeries,
        actual: TimeSeries,
    ): List<Pair<Int, Int>> {
        if (actual.frameIds == expected.frameIds) {
            return List(actual.frameIds.size) { it to it }
        }

        add(simpleFact("TimeSeries.frames does not match"))
        add(fact("|  expected", expected.frameIds.map { it.label }))
        add(fact("|  but got", actual.frameIds.map { it.label }))

        val actualFrameIds = actual.frameIds.toSet()
        val expectedFrameIds = expected.frameIds.toSet()
        val sharedFrameIds = actualFrameIds.intersect(expectedFrameIds)

        if (sharedFrameIds != actualFrameIds) {
            val unexpected = actualFrameIds - sharedFrameIds
            add(fact("|  unexpected (${unexpected.size})", unexpected.map { it.label }))
        }

        if (sharedFrameIds != expectedFrameIds) {
            val missing = expectedFrameIds - sharedFrameIds
            add(fact("|  missing (${missing.size})", missing.map { it.label }))
        }

        // Precompute first-occurrence indices (same result as List.indexOf) so that building the
        // mapping is linear. The alternative rescans both frame lists once per shared frame.
        val actualIndexOf = actual.frameIds.firstIndexOfEach()
        val expectedIndexOf = expected.frameIds.firstIndexOfEach()
        return sharedFrameIds.map { frameId ->
            actualIndexOf.getValue(frameId) to expectedIndexOf.getValue(frameId)
        }
    }

    /**
     * Appends facts describing how the feature keys of [actual] differ from those of [expected].
     *
     * @return the feature keys present in both series, in the iteration order of [actual]'s
     *   features. If both key sets are equal, no facts are appended.
     */
    private fun MutableList<Fact>.addFeatureKeyFacts(
        expected: TimeSeries,
        actual: TimeSeries,
    ): Set<String> {
        val actualKeys = actual.features.keys
        val expectedKeys = expected.features.keys
        if (actualKeys == expectedKeys) {
            return actualKeys
        }

        val sharedKeys = actualKeys.intersect(expectedKeys)
        add(simpleFact("TimeSeries.features does not match"))

        if (sharedKeys != actualKeys) {
            val unexpected = actualKeys - sharedKeys
            add(fact("|  unexpected (${unexpected.size})", unexpected))
        }

        if (sharedKeys != expectedKeys) {
            val missing = expectedKeys - sharedKeys
            add(fact("|  missing (${missing.size})", missing))
        }

        return sharedKeys
    }

    /**
     * Appends facts for each of [frameIndices] at which feature [featureKey] has different data
     * points in [expected] and [actual]. Appends nothing when all compared data points match.
     *
     * @param featureKey a key present in the `features` of both series.
     * @param frameIndices `(actualIndex, expectedIndex)` pairs identifying the frames to compare.
     * @throws IllegalStateException if [featureKey] is absent from either series.
     */
    private fun MutableList<Fact>.addDataPointFacts(
        featureKey: String,
        expected: TimeSeries,
        actual: TimeSeries,
        frameIndices: List<Pair<Int, Int>>,
    ) {
        val actualDataPoints = checkNotNull(actual.features[featureKey]).dataPoints
        val expectedDataPoints = checkNotNull(expected.features[featureKey]).dataPoints

        val mismatches =
            frameIndices.filter { (actualIndex, expectedIndex) ->
                actualDataPoints[actualIndex] != expectedDataPoints[expectedIndex]
            }
        if (mismatches.isEmpty()) {
            return
        }

        add(simpleFact("TimeSeries.features[$featureKey].dataPoints do not match"))
        mismatches.forEach { (actualIndex, expectedIndex) ->
            add(simpleFact("|  @${actual.frameIds[actualIndex].label}"))
            add(fact("|    expected", expectedDataPoints[expectedIndex]))
            add(fact("|    but was", actualDataPoints[actualIndex]))
        }
    }

    /**
     * Maps each distinct element to the index of its first occurrence. This is the same value
     * that [List.indexOf] would return for that element.
     */
    private fun <T> List<T>.firstIndexOfEach(): Map<T, Int> {
        val firstIndex = HashMap<T, Int>()
        forEachIndexed { index, element ->
            if (!firstIndex.containsKey(element)) {
                firstIndex[element] = index
            }
        }
        return firstIndex
    }

    companion object {
        /** Returns a factory to be used with [Truth.assertAbout]. */
        fun timeSeries(): Factory<TimeSeriesSubject, TimeSeries> {
            return Factory { failureMetadata: FailureMetadata, subject: TimeSeries? ->
                TimeSeriesSubject(failureMetadata, subject)
            }
        }

        /** Shortcut for `Truth.assertAbout(timeSeries()).that(timeSeries)`. */
        fun assertThat(timeSeries: TimeSeries): TimeSeriesSubject =
            Truth.assertAbout(timeSeries()).that(timeSeries)
    }
}
126