1 /*
2  * Copyright (C) 2023 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17 
#include <json/value.h>

#include <memory>
#include <sstream>
#include <string>
#include <string_view>
#include <vector>

#include "virtualtemp_estimator_data.h"
24 
25 namespace thermal {
26 namespace vtestimator {
27 
// Result codes returned by the VirtualTempEstimator API.
// Values are explicit and stable; kVtEstimatorOk (0) is the only success code.
enum VtEstimatorStatus {
    kVtEstimatorOk = 0,             // Success.
    kVtEstimatorInvalidArgs = 1,    // Caller supplied invalid arguments.
    kVtEstimatorInitFailed = 2,     // Initialization failed.
    kVtEstimatorInvokeFailed = 3,   // Running the estimator failed.
    kVtEstimatorUnSupported = 4,    // Operation not supported by this estimator.
    kVtEstimatorLowConfidence = 5,  // Presumably: estimate deemed unreliable — confirm in .cpp.
    kVtEstimatorUnderSampling = 6,  // Presumably: not enough samples yet — confirm in .cpp.
};
37 
// Selects the backend used by a VirtualTempEstimator.
enum VtEstimationType {
    kUseMLModel = 0,             // TFLite model, configured via MLModelInitData.
    kUseLinearModel = 1,         // Linear model, configured via LinearModelInitData.
    kInvalidEstimationType = 2,  // Not a usable backend.
};
39 
// Configuration for the TFLite-backed (kUseMLModel) estimator.
//
// Every field has a default member initializer. These match the defaults that
// VtEstimationInitData seeds for kUseMLModel, so a default-constructed
// instance never carries indeterminate bool/size_t values. The struct is still
// an aggregate (C++14+), so existing brace-initialization keeps compiling.
struct MLModelInitData {
    std::string model_path{};
    bool use_prev_samples{false};
    size_t prev_samples_order{1};
    size_t output_label_count{1};
    size_t num_hot_spots{1};
    bool enable_input_validation{false};
    std::vector<float> offset_thresholds{};
    std::vector<float> offset_values{};
    bool support_under_sampling{false};
};
51 
// Configuration for the linear-model (kUseLinearModel) estimator.
//
// Default member initializers match the defaults that VtEstimationInitData
// seeds for kUseLinearModel, so a default-constructed instance is fully
// initialized. The struct remains an aggregate (C++14+).
struct LinearModelInitData {
    bool use_prev_samples{false};
    size_t prev_samples_order{1};
    std::vector<float> coefficients{};
    std::vector<float> offset_thresholds{};
    std::vector<float> offset_values{};
};
59 
60 union VtEstimationInitData {
VtEstimationInitData(VtEstimationType type)61     VtEstimationInitData(VtEstimationType type) {
62         if (type == kUseMLModel) {
63             ml_model_init_data.model_path = "";
64             ml_model_init_data.use_prev_samples = false;
65             ml_model_init_data.prev_samples_order = 1;
66             ml_model_init_data.output_label_count = 1;
67             ml_model_init_data.num_hot_spots = 1;
68             ml_model_init_data.enable_input_validation = false;
69             ml_model_init_data.support_under_sampling = false;
70         } else if (type == kUseLinearModel) {
71             linear_model_init_data.use_prev_samples = false;
72             linear_model_init_data.prev_samples_order = 1;
73         }
74     }
~VtEstimationInitData()75     ~VtEstimationInitData() {}
76 
77     MLModelInitData ml_model_init_data;
78     LinearModelInitData linear_model_init_data;
79 };
80 
81 // Class to estimate virtual temperature
82 class VirtualTempEstimator {
83   public:
84     // Implicit copy-move headers.
85     VirtualTempEstimator(const VirtualTempEstimator &) = delete;
86     VirtualTempEstimator(VirtualTempEstimator &&) = default;
87     VirtualTempEstimator &operator=(const VirtualTempEstimator &) = delete;
88     VirtualTempEstimator &operator=(VirtualTempEstimator &&) = default;
89 
90     VirtualTempEstimator(std::string_view sensor_name, VtEstimationType type,
91                          size_t num_linked_sensors);
92     ~VirtualTempEstimator();
93 
94     // Initializes the estimator based on init_data
95     VtEstimatorStatus Initialize(const VtEstimationInitData &init_data);
96 
97     // Performs the prediction and returns estimated value in output
98     VtEstimatorStatus Estimate(const std::vector<float> &thermistors, std::vector<float> *output);
99 
100     // Dump estimator status
101     VtEstimatorStatus DumpStatus(std::string_view sensor_name, std::ostringstream *dump_buf);
102     // Get predict window width in milliseconds
103     VtEstimatorStatus GetMaxPredictWindowMs(size_t *predict_window_ms);
104     // Predict temperature after desired milliseconds
105     VtEstimatorStatus PredictAfterTimeMs(const size_t time_ms, float *output);
106     // Get entire output buffer of the estimator
107     VtEstimatorStatus GetAllPredictions(std::vector<float> *output);
108 
109     // Adds traces to help debug
110     VtEstimatorStatus DumpTraces();
111 
112   private:
113     void LoadTFLiteWrapper();
114     VtEstimationType type;
115     std::unique_ptr<VtEstimatorCommonData> common_instance_;
116     std::unique_ptr<VtEstimatorTFLiteData> tflite_instance_;
117     std::unique_ptr<VtEstimatorLinearModelData> linear_model_instance_;
118 
119     VtEstimatorStatus LinearModelInitialize(LinearModelInitData data);
120     VtEstimatorStatus TFliteInitialize(MLModelInitData data);
121 
122     VtEstimatorStatus LinearModelEstimate(const std::vector<float> &thermistors,
123                                           std::vector<float> *output);
124     VtEstimatorStatus TFliteEstimate(const std::vector<float> &thermistors,
125                                      std::vector<float> *output);
126     VtEstimatorStatus TFliteGetMaxPredictWindowMs(size_t *predict_window_ms);
127     VtEstimatorStatus TFlitePredictAfterTimeMs(const size_t time_ms, float *output);
128     VtEstimatorStatus TFliteGetAllPredictions(std::vector<float> *output);
129 
130     VtEstimatorStatus TFLiteDumpStatus(std::string_view sensor_name, std::ostringstream *dump_buf);
131     bool GetInputConfig(Json::Value *config);
132     bool ParseInputConfig(const Json::Value &config);
133 };
134 
135 }  // namespace vtestimator
136 }  // namespace thermal
137