1 /*
2  * Copyright (C) 2022 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "MotionPredictor"
18 
19 #include <input/MotionPredictor.h>
20 
21 #include <algorithm>
22 #include <array>
23 #include <cinttypes>
24 #include <cmath>
25 #include <cstddef>
26 #include <cstdint>
27 #include <limits>
28 #include <optional>
29 #include <string>
30 #include <utility>
31 #include <vector>
32 
33 #include <android-base/logging.h>
34 #include <android-base/strings.h>
35 #include <android/input.h>
36 #include <com_android_input_flags.h>
37 
38 #include <attestation/HmacKeyManager.h>
39 #include <ftl/enum.h>
40 #include <input/TfLiteMotionPredictor.h>
41 
42 namespace input_flags = com::android::input::flags;
43 
44 namespace android {
45 namespace {
46 
47 /**
48  * Log debug messages about predictions.
49  * Enable this via "adb shell setprop log.tag.MotionPredictor DEBUG"
50  */
isDebug()51 bool isDebug() {
52     return __android_log_is_loggable(ANDROID_LOG_DEBUG, LOG_TAG, ANDROID_LOG_INFO);
53 }
54 
55 // Converts a prediction of some polar (r, phi) to Cartesian (x, y) when applied to an axis.
convertPrediction(const TfLiteMotionPredictorSample::Point & axisFrom,const TfLiteMotionPredictorSample::Point & axisTo,float r,float phi)56 TfLiteMotionPredictorSample::Point convertPrediction(
57         const TfLiteMotionPredictorSample::Point& axisFrom,
58         const TfLiteMotionPredictorSample::Point& axisTo, float r, float phi) {
59     const TfLiteMotionPredictorSample::Point axis = axisTo - axisFrom;
60     const float axis_phi = std::atan2(axis.y, axis.x);
61     const float x_delta = r * std::cos(axis_phi + phi);
62     const float y_delta = r * std::sin(axis_phi + phi);
63     return {.x = axisTo.x + x_delta, .y = axisTo.y + y_delta};
64 }
65 
// Linearly maps x from [min, max] onto [0, 1], saturating at both ends.
// The std::max-then-std::min ordering matters: a NaN ratio (e.g. x == min == max) yields 0.
float normalizeRange(float x, float min, float max) {
    const float width = max - min;
    const float offset = x - min;
    const float clampedBelow = std::max(0.0f, offset / width);
    return std::min(1.0f, clampedBelow);
}
70 
71 } // namespace
72 
73 // --- JerkTracker ---
74 
JerkTracker(bool normalizedDt)75 JerkTracker::JerkTracker(bool normalizedDt) : mNormalizedDt(normalizedDt) {}
76 
pushSample(int64_t timestamp,float xPos,float yPos)77 void JerkTracker::pushSample(int64_t timestamp, float xPos, float yPos) {
78     mTimestamps.pushBack(timestamp);
79     const int numSamples = mTimestamps.size();
80 
81     std::array<float, 4> newXDerivatives;
82     std::array<float, 4> newYDerivatives;
83 
84     /**
85      * Diagram showing the calculation of higher order derivatives of sample x3
86      * collected at time=t3.
87      * Terms in parentheses are not stored (and not needed for calculations)
88      *  t0 ----- t1  ----- t2 ----- t3
89      * (x0)-----(x1) ----- x2 ----- x3
90      * (x'0) --- x'1 ---  x'2
91      *  x''0  -  x''1
92      *  x'''0
93      *
94      * In this example:
95      * x'2 = (x3 - x2) / (t3 - t2)
96      * x''1 = (x'2 - x'1) / (t2 - t1)
97      * x'''0 = (x''1 - x''0) / (t1 - t0)
98      * Therefore, timestamp history is needed to calculate higher order derivatives,
99      * compared to just the last calculated derivative sample.
100      *
101      * If mNormalizedDt = true, then dt = 1 and the division is moot.
102      */
103     for (int i = 0; i < numSamples; ++i) {
104         if (i == 0) {
105             newXDerivatives[i] = xPos;
106             newYDerivatives[i] = yPos;
107         } else {
108             newXDerivatives[i] = newXDerivatives[i - 1] - mXDerivatives[i - 1];
109             newYDerivatives[i] = newYDerivatives[i - 1] - mYDerivatives[i - 1];
110             if (!mNormalizedDt) {
111                 const float dt = mTimestamps[numSamples - i] - mTimestamps[numSamples - i - 1];
112                 newXDerivatives[i] = newXDerivatives[i] / dt;
113                 newYDerivatives[i] = newYDerivatives[i] / dt;
114             }
115         }
116     }
117 
118     std::swap(newXDerivatives, mXDerivatives);
119     std::swap(newYDerivatives, mYDerivatives);
120 }
121 
reset()122 void JerkTracker::reset() {
123     mTimestamps.clear();
124 }
125 
jerkMagnitude() const126 std::optional<float> JerkTracker::jerkMagnitude() const {
127     if (mTimestamps.size() == mTimestamps.capacity()) {
128         return std::hypot(mXDerivatives[3], mYDerivatives[3]);
129     }
130     return std::nullopt;
131 }
132 
133 // --- MotionPredictor ---
134 
MotionPredictor(nsecs_t predictionTimestampOffsetNanos,std::function<bool ()> checkMotionPredictionEnabled,ReportAtomFunction reportAtomFunction)135 MotionPredictor::MotionPredictor(nsecs_t predictionTimestampOffsetNanos,
136                                  std::function<bool()> checkMotionPredictionEnabled,
137                                  ReportAtomFunction reportAtomFunction)
138       : mPredictionTimestampOffsetNanos(predictionTimestampOffsetNanos),
139         mCheckMotionPredictionEnabled(std::move(checkMotionPredictionEnabled)),
140         mReportAtomFunction(reportAtomFunction) {}
141 
record(const MotionEvent & event)142 android::base::Result<void> MotionPredictor::record(const MotionEvent& event) {
143     if (mLastEvent && mLastEvent->getDeviceId() != event.getDeviceId()) {
144         // We still have an active gesture for another device. The provided MotionEvent is not
145         // consistent with the previous gesture.
146         LOG(ERROR) << "Inconsistent event stream: last event is " << *mLastEvent << ", but "
147                    << __func__ << " is called with " << event;
148         return android::base::Error()
149                 << "Inconsistent event stream: still have an active gesture from device "
150                 << mLastEvent->getDeviceId() << ", but received " << event;
151     }
152     if (!isPredictionAvailable(event.getDeviceId(), event.getSource())) {
153         ALOGE("Prediction not supported for device %d's %s source", event.getDeviceId(),
154               inputEventSourceToString(event.getSource()).c_str());
155         return {};
156     }
157 
158     // Initialise the model now that it's likely to be used.
159     if (!mModel) {
160         mModel = TfLiteMotionPredictorModel::create();
161         LOG_ALWAYS_FATAL_IF(!mModel);
162     }
163 
164     if (!mBuffers) {
165         mBuffers = std::make_unique<TfLiteMotionPredictorBuffers>(mModel->inputLength());
166     }
167 
168     // Pass input event to the MetricsManager.
169     if (!mMetricsManager) {
170         mMetricsManager.emplace(mModel->config().predictionInterval, mModel->outputLength(),
171                                 mReportAtomFunction);
172     }
173     mMetricsManager->onRecord(event);
174 
175     const int32_t action = event.getActionMasked();
176     if (action == AMOTION_EVENT_ACTION_UP || action == AMOTION_EVENT_ACTION_CANCEL) {
177         ALOGD_IF(isDebug(), "End of event stream");
178         mBuffers->reset();
179         mJerkTracker.reset();
180         mLastEvent.reset();
181         return {};
182     } else if (action != AMOTION_EVENT_ACTION_DOWN && action != AMOTION_EVENT_ACTION_MOVE) {
183         ALOGD_IF(isDebug(), "Skipping unsupported %s action",
184                  MotionEvent::actionToString(action).c_str());
185         return {};
186     }
187 
188     if (event.getPointerCount() != 1) {
189         ALOGD_IF(isDebug(), "Prediction not supported for multiple pointers");
190         return {};
191     }
192 
193     const ToolType toolType = event.getPointerProperties(0)->toolType;
194     if (toolType != ToolType::STYLUS) {
195         ALOGD_IF(isDebug(), "Prediction not supported for non-stylus tool: %s",
196                  ftl::enum_string(toolType).c_str());
197         return {};
198     }
199 
200     for (size_t i = 0; i <= event.getHistorySize(); ++i) {
201         if (event.isResampled(0, i)) {
202             continue;
203         }
204         const PointerCoords* coords = event.getHistoricalRawPointerCoords(0, i);
205         mBuffers->pushSample(event.getHistoricalEventTime(i),
206                              {
207                                      .position.x = coords->getAxisValue(AMOTION_EVENT_AXIS_X),
208                                      .position.y = coords->getAxisValue(AMOTION_EVENT_AXIS_Y),
209                                      .pressure = event.getHistoricalPressure(0, i),
210                                      .tilt = event.getHistoricalAxisValue(AMOTION_EVENT_AXIS_TILT,
211                                                                           0, i),
212                                      .orientation = event.getHistoricalOrientation(0, i),
213                              });
214         mJerkTracker.pushSample(event.getHistoricalEventTime(i),
215                                 coords->getAxisValue(AMOTION_EVENT_AXIS_X),
216                                 coords->getAxisValue(AMOTION_EVENT_AXIS_Y));
217     }
218 
219     if (!mLastEvent) {
220         mLastEvent = MotionEvent();
221     }
222     mLastEvent->copyFrom(&event, /*keepHistory=*/false);
223 
224     return {};
225 }
226 
predict(nsecs_t timestamp)227 std::unique_ptr<MotionEvent> MotionPredictor::predict(nsecs_t timestamp) {
228     if (mBuffers == nullptr || !mBuffers->isReady()) {
229         return nullptr;
230     }
231 
232     LOG_ALWAYS_FATAL_IF(!mModel);
233     mBuffers->copyTo(*mModel);
234     LOG_ALWAYS_FATAL_IF(!mModel->invoke());
235 
236     // Read out the predictions.
237     const std::span<const float> predictedR = mModel->outputR();
238     const std::span<const float> predictedPhi = mModel->outputPhi();
239     const std::span<const float> predictedPressure = mModel->outputPressure();
240 
241     TfLiteMotionPredictorSample::Point axisFrom = mBuffers->axisFrom().position;
242     TfLiteMotionPredictorSample::Point axisTo = mBuffers->axisTo().position;
243 
244     if (isDebug()) {
245         ALOGD("axisFrom: %f, %f", axisFrom.x, axisFrom.y);
246         ALOGD("axisTo: %f, %f", axisTo.x, axisTo.y);
247         ALOGD("mInputR: %s", base::Join(mModel->inputR(), ", ").c_str());
248         ALOGD("mInputPhi: %s", base::Join(mModel->inputPhi(), ", ").c_str());
249         ALOGD("mInputPressure: %s", base::Join(mModel->inputPressure(), ", ").c_str());
250         ALOGD("mInputTilt: %s", base::Join(mModel->inputTilt(), ", ").c_str());
251         ALOGD("mInputOrientation: %s", base::Join(mModel->inputOrientation(), ", ").c_str());
252         ALOGD("predictedR: %s", base::Join(predictedR, ", ").c_str());
253         ALOGD("predictedPhi: %s", base::Join(predictedPhi, ", ").c_str());
254         ALOGD("predictedPressure: %s", base::Join(predictedPressure, ", ").c_str());
255     }
256 
257     LOG_ALWAYS_FATAL_IF(!mLastEvent);
258     const MotionEvent& event = *mLastEvent;
259     bool hasPredictions = false;
260     std::unique_ptr<MotionEvent> prediction = std::make_unique<MotionEvent>();
261     int64_t predictionTime = mBuffers->lastTimestamp();
262     const int64_t futureTime = timestamp + mPredictionTimestampOffsetNanos;
263 
264     const float jerkMagnitude = mJerkTracker.jerkMagnitude().value_or(0);
265     const float fractionKept =
266             1 - normalizeRange(jerkMagnitude, mModel->config().lowJerk, mModel->config().highJerk);
267     // float to ensure proper division below.
268     const float predictionTimeWindow = futureTime - predictionTime;
269     const int maxNumPredictions = static_cast<int>(
270             std::ceil(predictionTimeWindow / mModel->config().predictionInterval * fractionKept));
271     ALOGD_IF(isDebug(),
272              "jerk (d^3p/normalizedDt^3): %f, fraction of prediction window pruned: %f, max number "
273              "of predictions: %d",
274              jerkMagnitude, 1 - fractionKept, maxNumPredictions);
275     for (size_t i = 0; i < static_cast<size_t>(predictedR.size()) && predictionTime <= futureTime;
276          ++i) {
277         if (predictedR[i] < mModel->config().distanceNoiseFloor) {
278             // Stop predicting when the predicted output is below the model's noise floor.
279             //
280             // We assume that all subsequent predictions in the batch are unreliable because later
281             // predictions are conditional on earlier predictions, and a state of noise is not a
282             // good basis for prediction.
283             //
284             // The UX trade-off is that this potentially sacrifices some predictions when the input
285             // device starts to speed up, but avoids producing noisy predictions as it slows down.
286             break;
287         }
288         if (input_flags::enable_prediction_pruning_via_jerk_thresholding()) {
289             if (i >= static_cast<size_t>(maxNumPredictions)) {
290                 break;
291             }
292         }
293         // TODO(b/266747654): Stop predictions if confidence is < some
294         // threshold. Currently predictions are pruned via jerk thresholding.
295 
296         const TfLiteMotionPredictorSample::Point predictedPoint =
297                 convertPrediction(axisFrom, axisTo, predictedR[i], predictedPhi[i]);
298 
299         ALOGD_IF(isDebug(), "prediction %zu: %f, %f", i, predictedPoint.x, predictedPoint.y);
300         PointerCoords coords;
301         coords.clear();
302         coords.setAxisValue(AMOTION_EVENT_AXIS_X, predictedPoint.x);
303         coords.setAxisValue(AMOTION_EVENT_AXIS_Y, predictedPoint.y);
304         coords.setAxisValue(AMOTION_EVENT_AXIS_PRESSURE, predictedPressure[i]);
305         // Copy forward tilt and orientation from the last event until they are predicted
306         // (b/291789258).
307         coords.setAxisValue(AMOTION_EVENT_AXIS_TILT,
308                             event.getAxisValue(AMOTION_EVENT_AXIS_TILT, 0));
309         coords.setAxisValue(AMOTION_EVENT_AXIS_ORIENTATION,
310                             event.getRawPointerCoords(0)->getAxisValue(
311                                     AMOTION_EVENT_AXIS_ORIENTATION));
312 
313         predictionTime += mModel->config().predictionInterval;
314         if (i == 0) {
315             hasPredictions = true;
316             prediction->initialize(InputEvent::nextId(), event.getDeviceId(), event.getSource(),
317                                    event.getDisplayId(), INVALID_HMAC, AMOTION_EVENT_ACTION_MOVE,
318                                    event.getActionButton(), event.getFlags(), event.getEdgeFlags(),
319                                    event.getMetaState(), event.getButtonState(),
320                                    event.getClassification(), event.getTransform(),
321                                    event.getXPrecision(), event.getYPrecision(),
322                                    event.getRawXCursorPosition(), event.getRawYCursorPosition(),
323                                    event.getRawTransform(), event.getDownTime(), predictionTime,
324                                    event.getPointerCount(), event.getPointerProperties(), &coords);
325         } else {
326             prediction->addSample(predictionTime, &coords);
327         }
328 
329         axisFrom = axisTo;
330         axisTo = predictedPoint;
331     }
332 
333     if (!hasPredictions) {
334         return nullptr;
335     }
336 
337     // Pass predictions to the MetricsManager.
338     LOG_ALWAYS_FATAL_IF(!mMetricsManager);
339     mMetricsManager->onPredict(*prediction);
340 
341     return prediction;
342 }
343 
isPredictionAvailable(int32_t,int32_t source)344 bool MotionPredictor::isPredictionAvailable(int32_t /*deviceId*/, int32_t source) {
345     // Global flag override
346     if (!mCheckMotionPredictionEnabled()) {
347         ALOGD_IF(isDebug(), "Prediction not available due to flag override");
348         return false;
349     }
350 
351     // Prediction is only supported for stylus sources.
352     if (!isFromSource(source, AINPUT_SOURCE_STYLUS)) {
353         ALOGD_IF(isDebug(), "Prediction not available for non-stylus source: %s",
354                  inputEventSourceToString(source).c_str());
355         return false;
356     }
357     return true;
358 }
359 
360 } // namespace android
361