1 /*
2  * Copyright (C) 2023 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
#pragma once

#include <android-base/chrono_utils.h>

#include <chrono>
#include <cstddef>
#include <limits>
#include <mutex>
#include <string>
#include <string_view>
#include <vector>
23 
24 namespace thermal {
25 namespace vtestimator {
26 
// Brings android::base::boot_clock into this namespace so the timestamp members
// below can be spelled boot_clock::time_point.
using android::base::boot_clock;
28 
// The estimator currently drives exactly one input tensor and one output tensor.
constexpr int kNumInputTensors = 1;
constexpr int kNumOutputTensors = 1;

// Signatures of the entry points exposed by the TFLite wrapper library.
// Every call after create() receives the opaque handle that create() returned.
// NOTE(review): presumably these are resolved at runtime from a shared library —
// confirm against the loader code.
using tflitewrapper_create = void *(*)(int num_input_tensors, int num_output_tensors);
using tflitewrapper_init = bool (*)(void *handle, const char *model_path);
using tflitewrapper_invoke = bool (*)(void *handle, float *input_samples, int num_input_samples,
                                      float *output_samples, int num_output_samples);
using tflitewrapper_destroy = void (*)(void *handle);
using tflitewrapper_get_input_config_size = bool (*)(void *handle, int *config_size);
using tflitewrapper_get_input_config = bool (*)(void *handle, char *config_buffer,
                                                int config_buffer_size);
41 
42 struct TFLiteWrapperMethods {
43     tflitewrapper_create create;
44     tflitewrapper_init init;
45     tflitewrapper_invoke invoke;
46     tflitewrapper_destroy destroy;
47     tflitewrapper_get_input_config_size get_input_config_size;
48     tflitewrapper_get_input_config get_input_config;
49     mutable std::mutex mutex;
50 };
51 
// Accepted value range [min_threshold, max_threshold] for one model input.
// The defaults span the entire finite float range, so an input with no
// configured bounds is unbounded in both directions.
// The lower bound must be lowest(), not min(): numeric_limits<float>::min() is
// the smallest *positive* normalized float (~1.18e-38). Using min() would
// silently place zero and every negative value outside the default range.
struct InputRangeInfo {
    float max_threshold = std::numeric_limits<float>::max();
    float min_threshold = std::numeric_limits<float>::lowest();
};
56 
// Per-sensor state shared by the estimator: the sensor's name, how many input
// sensors are linked to it, offset tables, and previous-sample bookkeeping.
struct VtEstimatorCommonData {
    // Copies `name` into an owned string and records the linked-sensor count.
    // All other scalar fields take the in-class defaults declared below.
    VtEstimatorCommonData(std::string_view name, size_t num_input_sensors)
        : sensor_name(name), num_linked_sensors(num_input_sensors) {}

    std::string sensor_name;

    std::vector<float> offset_thresholds;
    std::vector<float> offset_values;

    size_t num_linked_sensors;
    size_t prev_samples_order = 1;
    size_t cur_sample_count = 0;
    bool use_prev_samples = false;
    bool is_initialized = false;
};
77 
78 struct VtEstimatorTFLiteData {
VtEstimatorTFLiteDataVtEstimatorTFLiteData79     VtEstimatorTFLiteData() {
80         scratch_buffer = nullptr;
81         input_buffer = nullptr;
82         input_buffer_size = 0;
83         output_label_count = 1;
84         num_hot_spots = 1;
85         output_buffer = nullptr;
86         output_buffer_size = 1;
87         support_under_sampling = false;
88         sample_interval = std::chrono::milliseconds{0};
89         max_sample_interval = std::chrono::milliseconds{std::numeric_limits<int>::max()};
90         predict_window_ms = 0;
91         last_update_time = boot_clock::time_point::min();
92         prev_sample_time = boot_clock::time_point::min();
93         enable_input_validation = false;
94 
95         tflite_wrapper = nullptr;
96         tflite_methods.create = nullptr;
97         tflite_methods.init = nullptr;
98         tflite_methods.get_input_config_size = nullptr;
99         tflite_methods.get_input_config = nullptr;
100         tflite_methods.invoke = nullptr;
101         tflite_methods.destroy = nullptr;
102     }
103 
104     void *tflite_wrapper;
105     float *scratch_buffer;
106     float *input_buffer;
107     size_t input_buffer_size;
108     size_t num_hot_spots;
109     size_t output_label_count;
110     float *output_buffer;
111     size_t output_buffer_size;
112     std::string model_path;
113     TFLiteWrapperMethods tflite_methods;
114     std::vector<InputRangeInfo> input_range;
115     bool support_under_sampling;
116     std::chrono::milliseconds sample_interval{};
117     std::chrono::milliseconds max_sample_interval{};
118     size_t predict_window_ms;
119     boot_clock::time_point last_update_time;
120     boot_clock::time_point prev_sample_time;
121     bool enable_input_validation;
122 
~VtEstimatorTFLiteDataVtEstimatorTFLiteData123     ~VtEstimatorTFLiteData() {
124         if (tflite_wrapper && tflite_methods.destroy) {
125             tflite_methods.destroy(tflite_wrapper);
126         }
127 
128         if (scratch_buffer) {
129             delete scratch_buffer;
130         }
131 
132         if (input_buffer) {
133             delete input_buffer;
134         }
135 
136         if (output_buffer) {
137             delete output_buffer;
138         }
139     }
140 };
141 
// State for the linear-model estimator: buffered input samples, model
// coefficients, and a mutex.
// NOTE(review): the mutex presumably guards the two vectors — confirm at use sites.
struct VtEstimatorLinearModelData {
    VtEstimatorLinearModelData() = default;
    ~VtEstimatorLinearModelData() = default;

    std::vector<std::vector<float>> input_samples;
    std::vector<std::vector<float>> coefficients;
    // `mutable` so the mutex can be locked through a const reference.
    mutable std::mutex mutex;
};
151 
152 }  // namespace vtestimator
153 }  // namespace thermal
154