1 /* 2 * Copyright 2023 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include <cstddef> 18 #include <cstdint> 19 #include <functional> 20 #include <limits> 21 #include <vector> 22 23 #include <input/Input.h> // for MotionEvent 24 #include <input/RingBuffer.h> 25 #include <utils/Timers.h> // for nsecs_t 26 27 #include "Eigen/Core" 28 29 namespace android { 30 31 /** 32 * Class to handle computing and reporting metrics for MotionPredictor. 33 * 34 * The public API provides two methods: `onRecord` and `onPredict`, which expect to receive the 35 * MotionEvents from the corresponding methods in MotionPredictor. 36 * 37 * This class stores AggregatedStrokeMetrics, updating them as new MotionEvents are passed in. When 38 * onRecord receives an UP or CANCEL event, this indicates the end of the stroke, and the final 39 * AtomFields are computed and reported to the stats library. The number of atoms reported is equal 40 * to the value of `maxNumPredictions` passed to the constructor. Each atom corresponds to one 41 * "prediction time bucket" — the amount of time into the future being predicted. 42 * 43 * If mMockLoggedAtomFields is set, the batch of AtomFields that are reported to the stats library 44 * for one stroke are also stored in mMockLoggedAtomFields at the time they're reported. 45 */ 46 class MotionPredictorMetricsManager { 47 public: 48 struct AtomFields; 49 50 using ReportAtomFunction = std::function<void(const AtomFields&)>; 51 52 static void defaultReportAtomFunction(const AtomFields& atomFields); 53 54 // Parameters: 55 // • predictionInterval: the time interval between successive prediction target timestamps. 56 // Note: the MetricsManager assumes that the input interval equals the prediction interval. 57 // • maxNumPredictions: the maximum number of distinct target timestamps the prediction model 58 // will generate predictions for. The MetricsManager reports this many atoms per stroke. 59 // • [Optional] reportAtomFunction: the function that will be called to report metrics. If 60 // omitted (or if an empty function is given), the `stats_write(…)` function from the Android 61 // stats library will be used. 62 MotionPredictorMetricsManager( 63 nsecs_t predictionInterval, 64 size_t maxNumPredictions, 65 ReportAtomFunction reportAtomFunction = defaultReportAtomFunction); 66 67 // This method should be called once for each call to MotionPredictor::record, receiving the 68 // forwarded MotionEvent argument. 69 void onRecord(const MotionEvent& inputEvent); 70 71 // This method should be called once for each call to MotionPredictor::predict, receiving the 72 // MotionEvent that will be returned by MotionPredictor::predict. 73 void onPredict(const MotionEvent& predictionEvent); 74 75 // Simple structs to hold relevant touch input information. Public so they can be used in tests. 76 77 struct TouchPoint { 78 Eigen::Vector2f position; // (y, x) in pixels 79 float pressure; 80 }; 81 82 struct GroundTruthPoint : TouchPoint { 83 nsecs_t timestamp; 84 }; 85 86 struct PredictionPoint : TouchPoint { 87 // The timestamp of the last ground truth point when the prediction was made. 88 nsecs_t originTimestamp; 89 90 nsecs_t targetTimestamp; 91 92 // Order by targetTimestamp when sorting. 93 bool operator<(const PredictionPoint& other) const { 94 return this->targetTimestamp < other.targetTimestamp; 95 } 96 }; 97 98 // Metrics aggregated so far for the current stroke. These are not the final fields to be 99 // reported in the atom (see AtomFields below), but rather an intermediate representation of the 100 // data that can be conveniently aggregated and from which the atom fields can be derived later. 101 // 102 // Displacement units are in pixels. 103 // 104 // "Along-trajectory error" is the dot product of the prediction error with the unit vector 105 // pointing towards the ground truth point whose timestamp corresponds to the prediction 106 // target timestamp, originating from the preceding ground truth point. 107 // 108 // "Off-trajectory error" is the component of the prediction error orthogonal to the 109 // "along-trajectory" unit vector described above. 110 // 111 // "High-velocity" errors are errors that are only accumulated when the velocity between the 112 // most recent two input events exceeds a certain threshold. 113 // 114 // "Scale-invariant errors" are the errors produced when the path length of the stroke is 115 // scaled to 1. (In other words, the error distances are normalized by the path length.) 116 struct AggregatedStrokeMetrics { 117 // General errors 118 float alongTrajectoryErrorSum = 0; 119 float alongTrajectorySumSquaredErrors = 0; 120 float offTrajectorySumSquaredErrors = 0; 121 float pressureSumSquaredErrors = 0; 122 size_t generalErrorsCount = 0; 123 124 // High-velocity errors 125 float highVelocityAlongTrajectorySse = 0; 126 float highVelocityOffTrajectorySse = 0; 127 size_t highVelocityErrorsCount = 0; 128 129 // Scale-invariant errors 130 float scaleInvariantAlongTrajectorySse = 0; 131 float scaleInvariantOffTrajectorySse = 0; 132 size_t scaleInvariantErrorsCount = 0; 133 }; 134 135 // In order to explicitly indicate "no relevant data" for a metric, we report this 136 // large-magnitude negative sentinel value. (Most metrics are non-negative, so this value is 137 // completely unobtainable. For along-trajectory error mean, which can be negative, the 138 // magnitude makes it unobtainable in practice.) 139 static const int NO_DATA_SENTINEL = std::numeric_limits<int32_t>::min(); 140 141 // Final metric values reported in the atom. 142 struct AtomFields { 143 int deltaTimeBucketMilliseconds = 0; 144 145 // General errors 146 int alongTrajectoryErrorMeanMillipixels = NO_DATA_SENTINEL; 147 int alongTrajectoryErrorStdMillipixels = NO_DATA_SENTINEL; 148 int offTrajectoryRmseMillipixels = NO_DATA_SENTINEL; 149 int pressureRmseMilliunits = NO_DATA_SENTINEL; 150 151 // High-velocity errors 152 int highVelocityAlongTrajectoryRmse = NO_DATA_SENTINEL; // millipixels 153 int highVelocityOffTrajectoryRmse = NO_DATA_SENTINEL; // millipixels 154 155 // Scale-invariant errors 156 int scaleInvariantAlongTrajectoryRmse = NO_DATA_SENTINEL; // millipixels 157 int scaleInvariantOffTrajectoryRmse = NO_DATA_SENTINEL; // millipixels 158 }; 159 160 private: 161 // The interval between consecutive predictions' target timestamps. We assume that the input 162 // interval also equals this value. 163 const nsecs_t mPredictionInterval; 164 165 // The maximum number of input frames into the future the model can predict. 166 // Used to perform time-bucketing of metrics. 167 const size_t mMaxNumPredictions; 168 169 // History of mMaxNumPredictions + 1 ground truth points, used to compute scale-invariant 170 // error. (Also, the last two points are used to compute the ground truth trajectory.) 171 RingBuffer<GroundTruthPoint> mRecentGroundTruthPoints; 172 173 // Predictions having a targetTimestamp after the most recent ground truth point's timestamp. 174 // Invariant: sorted in ascending order of targetTimestamp. 175 std::vector<PredictionPoint> mRecentPredictions; 176 177 // Containers for the intermediate representation of stroke metrics and the final atom fields. 178 // These are indexed by the number of input frames into the future being predicted minus one, 179 // and always have size mMaxNumPredictions. 180 std::vector<AggregatedStrokeMetrics> mAggregatedMetrics; 181 std::vector<AtomFields> mAtomFields; 182 183 const ReportAtomFunction mReportAtomFunction; 184 185 // Helper methods for the implementation of onRecord and onPredict. 186 187 // Clears stored ground truth and prediction points, as well as all stored metrics for the 188 // current stroke. 189 void clearStrokeData(); 190 191 // Adds the new ground truth point to mRecentGroundTruths, removes outdated predictions from 192 // mRecentPredictions, and updates the aggregated metrics to include the recent predictions that 193 // fuzzily match with the new ground truth point. 194 void incorporateNewGroundTruth(const GroundTruthPoint& groundTruthPoint); 195 196 // Given a new prediction with targetTimestamp matching the latest ground truth point's 197 // timestamp, computes the corresponding metrics and updates mAggregatedMetrics. 198 void updateAggregatedMetrics(const PredictionPoint& predictionPoint); 199 200 // Computes the atom fields to mAtomFields from the values in mAggregatedMetrics. 201 void computeAtomFields(); 202 203 // Reports the current data in mAtomFields by calling mReportAtomFunction. 204 void reportMetrics(); 205 }; 206 207 } // namespace android 208