1 /* 2 * Copyright (C) 2023 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #pragma once 18 19 #include <array> 20 #include <cstddef> 21 #include <cstdint> 22 #include <memory> 23 #include <optional> 24 #include <span> 25 26 #include <android-base/mapped_file.h> 27 #include <input/RingBuffer.h> 28 #include <utils/Timers.h> 29 30 #include <tensorflow/lite/core/api/error_reporter.h> 31 #include <tensorflow/lite/interpreter.h> 32 #include <tensorflow/lite/model.h> 33 #include <tensorflow/lite/signature_runner.h> 34 35 namespace android { 36 37 struct TfLiteMotionPredictorSample { 38 // The untransformed AMOTION_EVENT_AXIS_X and AMOTION_EVENT_AXIS_Y of the sample. 39 struct Point { 40 float x; 41 float y; 42 } position; 43 // The AMOTION_EVENT_AXIS_PRESSURE, _TILT, and _ORIENTATION. 44 float pressure; 45 float tilt; 46 float orientation; 47 }; 48 49 inline TfLiteMotionPredictorSample::Point operator-(const TfLiteMotionPredictorSample::Point& lhs, 50 const TfLiteMotionPredictorSample::Point& rhs) { 51 return {.x = lhs.x - rhs.x, .y = lhs.y - rhs.y}; 52 } 53 54 class TfLiteMotionPredictorModel; 55 56 // Buffer storage for a TfLiteMotionPredictorModel. 57 class TfLiteMotionPredictorBuffers { 58 public: 59 // Creates buffer storage for a model with the given input length. 60 TfLiteMotionPredictorBuffers(size_t inputLength); 61 62 // Adds a motion sample to the buffers. 63 void pushSample(int64_t timestamp, TfLiteMotionPredictorSample sample); 64 65 // Returns true if the buffers are complete enough to generate a prediction. isReady()66 bool isReady() const { 67 // Predictions can't be applied unless there are at least two points to determine 68 // the direction to apply them in. 69 return mAxisFrom && mAxisTo; 70 } 71 72 // Resets all buffers to their initial state. 73 void reset(); 74 75 // Copies the buffers to those of a model for prediction. 76 void copyTo(TfLiteMotionPredictorModel& model) const; 77 78 // Returns the current axis of the buffer's samples. Only valid if isReady(). axisFrom()79 TfLiteMotionPredictorSample axisFrom() const { return *mAxisFrom; } axisTo()80 TfLiteMotionPredictorSample axisTo() const { return *mAxisTo; } 81 82 // Returns the timestamp of the last sample. lastTimestamp()83 int64_t lastTimestamp() const { return mTimestamp; } 84 85 private: 86 int64_t mTimestamp = 0; 87 88 RingBuffer<float> mInputR; 89 RingBuffer<float> mInputPhi; 90 RingBuffer<float> mInputPressure; 91 RingBuffer<float> mInputTilt; 92 RingBuffer<float> mInputOrientation; 93 94 // The samples defining the current polar axis. 95 std::optional<TfLiteMotionPredictorSample> mAxisFrom; 96 std::optional<TfLiteMotionPredictorSample> mAxisTo; 97 }; 98 99 // A TFLite model for generating motion predictions. 100 class TfLiteMotionPredictorModel { 101 public: 102 struct Config { 103 // The time between predictions. 104 nsecs_t predictionInterval = 0; 105 // The noise floor for predictions. 106 // Distances (r) less than this should be discarded as noise. 107 float distanceNoiseFloor = 0; 108 109 // Low and high jerk thresholds (with normalized dt = 1) for predictions. 110 // High jerk means more predictions will be pruned, vice versa for low. 111 float lowJerk = 0; 112 float highJerk = 0; 113 }; 114 115 // Creates a model from an encoded Flatbuffer model. 116 static std::unique_ptr<TfLiteMotionPredictorModel> create(); 117 118 ~TfLiteMotionPredictorModel(); 119 120 // Returns the length of the model's input buffers. 121 size_t inputLength() const; 122 123 // Returns the length of the model's output buffers. 124 size_t outputLength() const; 125 config()126 const Config& config() const { return mConfig; } 127 128 // Executes the model. 129 // Returns true if the model successfully executed and the output tensors can be read. 130 bool invoke(); 131 132 // Returns mutable buffers to the input tensors of inputLength() elements. 133 std::span<float> inputR(); 134 std::span<float> inputPhi(); 135 std::span<float> inputPressure(); 136 std::span<float> inputOrientation(); 137 std::span<float> inputTilt(); 138 139 // Returns immutable buffers to the output tensors of identical length. Only valid after a 140 // successful call to invoke(). 141 std::span<const float> outputR() const; 142 std::span<const float> outputPhi() const; 143 std::span<const float> outputPressure() const; 144 145 private: 146 explicit TfLiteMotionPredictorModel(std::unique_ptr<android::base::MappedFile> model, 147 Config config); 148 149 void allocateTensors(); 150 void attachInputTensors(); 151 void attachOutputTensors(); 152 153 TfLiteTensor* mInputR = nullptr; 154 TfLiteTensor* mInputPhi = nullptr; 155 TfLiteTensor* mInputPressure = nullptr; 156 TfLiteTensor* mInputTilt = nullptr; 157 TfLiteTensor* mInputOrientation = nullptr; 158 159 const TfLiteTensor* mOutputR = nullptr; 160 const TfLiteTensor* mOutputPhi = nullptr; 161 const TfLiteTensor* mOutputPressure = nullptr; 162 163 std::unique_ptr<android::base::MappedFile> mFlatBuffer; 164 std::unique_ptr<tflite::ErrorReporter> mErrorReporter; 165 std::unique_ptr<tflite::FlatBufferModel> mModel; 166 std::unique_ptr<tflite::Interpreter> mInterpreter; 167 tflite::SignatureRunner* mRunner = nullptr; 168 169 const Config mConfig = {}; 170 }; 171 172 } // namespace android 173