1 /*
2  * Copyright (C) 2023 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #define ATRACE_TAG (ATRACE_TAG_THERMAL | ATRACE_TAG_HAL)
17 
18 #include "virtualtemp_estimator.h"
19 
20 #include <android-base/logging.h>
21 #include <android-base/stringprintf.h>
22 #include <dlfcn.h>
23 #include <json/reader.h>
24 #include <utils/Trace.h>
25 
26 #include <cmath>
27 #include <sstream>
28 #include <vector>
29 
30 namespace thermal {
31 namespace vtestimator {
32 namespace {
getFloatFromValue(const Json::Value & value)33 float getFloatFromValue(const Json::Value &value) {
34     if (value.isString()) {
35         return std::atof(value.asString().c_str());
36     } else {
37         return value.asFloat();
38     }
39 }
40 
getInputRangeInfoFromJsonValues(const Json::Value & values,InputRangeInfo * input_range_info)41 bool getInputRangeInfoFromJsonValues(const Json::Value &values, InputRangeInfo *input_range_info) {
42     if (values.size() != 2) {
43         LOG(ERROR) << "Data Range Values size: " << values.size() << "is invalid.";
44         return false;
45     }
46 
47     float min_val = getFloatFromValue(values[0]);
48     float max_val = getFloatFromValue(values[1]);
49 
50     if (std::isnan(min_val) || std::isnan(max_val)) {
51         LOG(ERROR) << "Illegal data range: thresholds not defined properly " << min_val << " : "
52                    << max_val;
53         return false;
54     }
55 
56     if (min_val > max_val) {
57         LOG(ERROR) << "Illegal data range: data_min_threshold(" << min_val
58                    << ") > data_max_threshold(" << max_val << ")";
59         return false;
60     }
61     input_range_info->min_threshold = min_val;
62     input_range_info->max_threshold = max_val;
63     LOG(INFO) << "Data Range Info: " << input_range_info->min_threshold
64               << " <= val <= " << input_range_info->max_threshold;
65     return true;
66 }
67 
CalculateOffset(const std::vector<float> & offset_thresholds,const std::vector<float> & offset_values,const float value)68 float CalculateOffset(const std::vector<float> &offset_thresholds,
69                       const std::vector<float> &offset_values, const float value) {
70     for (int i = offset_thresholds.size(); i > 0; --i) {
71         if (offset_thresholds[i - 1] < value) {
72             return offset_values[i - 1];
73         }
74     }
75 
76     return 0;
77 }
78 }  // namespace
79 
DumpTraces()80 VtEstimatorStatus VirtualTempEstimator::DumpTraces() {
81     if (type != kUseMLModel) {
82         return kVtEstimatorUnSupported;
83     }
84 
85     if (tflite_instance_ == nullptr || common_instance_ == nullptr) {
86         LOG(ERROR) << "tflite_instance_ or common_instance_ is nullptr during DumpTraces\n";
87         return kVtEstimatorInitFailed;
88     }
89 
90     std::unique_lock<std::mutex> lock(tflite_instance_->tflite_methods.mutex);
91 
92     if (!common_instance_->is_initialized) {
93         LOG(ERROR) << "tflite_instance_ not initialized for " << tflite_instance_->model_path;
94         return kVtEstimatorInitFailed;
95     }
96 
97     // get model input/output buffers
98     float *model_input = tflite_instance_->input_buffer;
99     float *model_output = tflite_instance_->output_buffer;
100     auto input_buffer_size = tflite_instance_->input_buffer_size;
101     auto output_buffer_size = tflite_instance_->output_buffer_size;
102 
103     // In Case of use_prev_samples, inputs are available in order in scratch buffer
104     if (common_instance_->use_prev_samples) {
105         model_input = tflite_instance_->scratch_buffer;
106     }
107 
108     // Add traces for model input/output buffers
109     std::string sensor_name = common_instance_->sensor_name;
110     for (size_t i = 0; i < input_buffer_size; ++i) {
111         ATRACE_INT((sensor_name + "_input_" + std::to_string(i)).c_str(),
112                    static_cast<int>(model_input[i]));
113     }
114 
115     for (size_t i = 0; i < output_buffer_size; ++i) {
116         ATRACE_INT((sensor_name + "_output_" + std::to_string(i)).c_str(),
117                    static_cast<int>(model_output[i]));
118     }
119 
120     // log input data and output data buffers
121     std::string input_data_str = "model_input_buffer: [";
122     for (size_t i = 0; i < input_buffer_size; ++i) {
123         input_data_str += ::android::base::StringPrintf("%0.2f ", model_input[i]);
124     }
125     input_data_str += "]";
126     LOG(INFO) << input_data_str;
127 
128     std::string output_data_str = "model_output_buffer: [";
129     for (size_t i = 0; i < output_buffer_size; ++i) {
130         output_data_str += ::android::base::StringPrintf("%0.2f ", model_output[i]);
131     }
132     output_data_str += "]";
133     LOG(INFO) << output_data_str;
134 
135     return kVtEstimatorOk;
136 }
137 
LoadTFLiteWrapper()138 void VirtualTempEstimator::LoadTFLiteWrapper() {
139     if (!tflite_instance_) {
140         LOG(ERROR) << "tflite_instance_ is nullptr during LoadTFLiteWrapper";
141         return;
142     }
143 
144     std::unique_lock<std::mutex> lock(tflite_instance_->tflite_methods.mutex);
145 
146     void *mLibHandle = dlopen("/vendor/lib64/libthermal_tflite_wrapper.so", 0);
147     if (mLibHandle == nullptr) {
148         LOG(ERROR) << "Could not load libthermal_tflite_wrapper library with error: " << dlerror();
149         return;
150     }
151 
152     tflite_instance_->tflite_methods.create =
153             reinterpret_cast<tflitewrapper_create>(dlsym(mLibHandle, "ThermalTfliteCreate"));
154     if (!tflite_instance_->tflite_methods.create) {
155         LOG(ERROR) << "Could not link and cast tflitewrapper_create with error: " << dlerror();
156     }
157 
158     tflite_instance_->tflite_methods.init =
159             reinterpret_cast<tflitewrapper_init>(dlsym(mLibHandle, "ThermalTfliteInit"));
160     if (!tflite_instance_->tflite_methods.init) {
161         LOG(ERROR) << "Could not link and cast tflitewrapper_init with error: " << dlerror();
162     }
163 
164     tflite_instance_->tflite_methods.invoke =
165             reinterpret_cast<tflitewrapper_invoke>(dlsym(mLibHandle, "ThermalTfliteInvoke"));
166     if (!tflite_instance_->tflite_methods.invoke) {
167         LOG(ERROR) << "Could not link and cast tflitewrapper_invoke with error: " << dlerror();
168     }
169 
170     tflite_instance_->tflite_methods.destroy =
171             reinterpret_cast<tflitewrapper_destroy>(dlsym(mLibHandle, "ThermalTfliteDestroy"));
172     if (!tflite_instance_->tflite_methods.destroy) {
173         LOG(ERROR) << "Could not link and cast tflitewrapper_destroy with error: " << dlerror();
174     }
175 
176     tflite_instance_->tflite_methods.get_input_config_size =
177             reinterpret_cast<tflitewrapper_get_input_config_size>(
178                     dlsym(mLibHandle, "ThermalTfliteGetInputConfigSize"));
179     if (!tflite_instance_->tflite_methods.get_input_config_size) {
180         LOG(ERROR) << "Could not link and cast tflitewrapper_get_input_config_size with error: "
181                    << dlerror();
182     }
183 
184     tflite_instance_->tflite_methods.get_input_config =
185             reinterpret_cast<tflitewrapper_get_input_config>(
186                     dlsym(mLibHandle, "ThermalTfliteGetInputConfig"));
187     if (!tflite_instance_->tflite_methods.get_input_config) {
188         LOG(ERROR) << "Could not link and cast tflitewrapper_get_input_config with error: "
189                    << dlerror();
190     }
191 }
192 
VirtualTempEstimator(std::string_view sensor_name,VtEstimationType estimationType,size_t num_linked_sensors)193 VirtualTempEstimator::VirtualTempEstimator(std::string_view sensor_name,
194                                            VtEstimationType estimationType,
195                                            size_t num_linked_sensors) {
196     type = estimationType;
197 
198     common_instance_ = std::make_unique<VtEstimatorCommonData>(sensor_name, num_linked_sensors);
199     if (estimationType == kUseMLModel) {
200         tflite_instance_ = std::make_unique<VtEstimatorTFLiteData>();
201         LoadTFLiteWrapper();
202     } else if (estimationType == kUseLinearModel) {
203         linear_model_instance_ = std::make_unique<VtEstimatorLinearModelData>();
204     } else {
205         LOG(ERROR) << "Unsupported estimationType [" << estimationType << "]";
206     }
207 }
208 
~VirtualTempEstimator()209 VirtualTempEstimator::~VirtualTempEstimator() {
210     LOG(INFO) << "VirtualTempEstimator destructor";
211 }
212 
LinearModelInitialize(LinearModelInitData data)213 VtEstimatorStatus VirtualTempEstimator::LinearModelInitialize(LinearModelInitData data) {
214     if (linear_model_instance_ == nullptr || common_instance_ == nullptr) {
215         LOG(ERROR) << "linear_model_instance_ or common_instance_ is nullptr during Initialize";
216         return kVtEstimatorInitFailed;
217     }
218 
219     size_t num_linked_sensors = common_instance_->num_linked_sensors;
220     std::unique_lock<std::mutex> lock(linear_model_instance_->mutex);
221 
222     if ((num_linked_sensors == 0) || (data.coefficients.size() == 0) ||
223         (data.prev_samples_order == 0)) {
224         LOG(ERROR) << "Invalid num_linked_sensors [" << num_linked_sensors
225                    << "] or coefficients.size() [" << data.coefficients.size()
226                    << "] or prev_samples_order [" << data.prev_samples_order << "]";
227         return kVtEstimatorInitFailed;
228     }
229 
230     if (data.coefficients.size() != (num_linked_sensors * data.prev_samples_order)) {
231         LOG(ERROR) << "In valid args coefficients.size()[" << data.coefficients.size()
232                    << "] num_linked_sensors [" << num_linked_sensors << "] prev_samples_order["
233                    << data.prev_samples_order << "]";
234         return kVtEstimatorInvalidArgs;
235     }
236 
237     common_instance_->use_prev_samples = data.use_prev_samples;
238     common_instance_->prev_samples_order = data.prev_samples_order;
239 
240     linear_model_instance_->input_samples.reserve(common_instance_->prev_samples_order);
241     linear_model_instance_->coefficients.reserve(common_instance_->prev_samples_order);
242 
243     // Store coefficients
244     for (size_t i = 0; i < data.prev_samples_order; ++i) {
245         std::vector<float> single_order_coefficients;
246         for (size_t j = 0; j < num_linked_sensors; ++j) {
247             single_order_coefficients.emplace_back(data.coefficients[i * num_linked_sensors + j]);
248         }
249         linear_model_instance_->coefficients.emplace_back(single_order_coefficients);
250     }
251 
252     common_instance_->offset_thresholds = data.offset_thresholds;
253     common_instance_->offset_values = data.offset_values;
254     common_instance_->is_initialized = true;
255 
256     return kVtEstimatorOk;
257 }
258 
TFliteInitialize(MLModelInitData data)259 VtEstimatorStatus VirtualTempEstimator::TFliteInitialize(MLModelInitData data) {
260     if (!tflite_instance_ || !common_instance_) {
261         LOG(ERROR) << "tflite_instance_ or common_instance_ is nullptr during Initialize\n";
262         return kVtEstimatorInitFailed;
263     }
264 
265     std::string model_path = data.model_path;
266     size_t num_linked_sensors = common_instance_->num_linked_sensors;
267     bool use_prev_samples = data.use_prev_samples;
268     size_t prev_samples_order = data.prev_samples_order;
269     size_t num_hot_spots = data.num_hot_spots;
270     size_t output_label_count = data.output_label_count;
271 
272     std::unique_lock<std::mutex> lock(tflite_instance_->tflite_methods.mutex);
273 
274     if (model_path.empty()) {
275         LOG(ERROR) << "Invalid model_path:" << model_path;
276         return kVtEstimatorInvalidArgs;
277     }
278 
279     if (num_linked_sensors == 0 || prev_samples_order < 1 ||
280         (!use_prev_samples && prev_samples_order > 1)) {
281         LOG(ERROR) << "Invalid tflite_instance_ config: "
282                    << "number of linked sensor: " << num_linked_sensors
283                    << " use previous: " << use_prev_samples
284                    << " previous sample order: " << prev_samples_order;
285         return kVtEstimatorInitFailed;
286     }
287 
288     common_instance_->use_prev_samples = data.use_prev_samples;
289     common_instance_->prev_samples_order = prev_samples_order;
290     tflite_instance_->support_under_sampling = data.support_under_sampling;
291     tflite_instance_->enable_input_validation = data.enable_input_validation;
292     tflite_instance_->input_buffer_size = num_linked_sensors * prev_samples_order;
293     tflite_instance_->input_buffer = new float[tflite_instance_->input_buffer_size];
294     if (common_instance_->use_prev_samples) {
295         tflite_instance_->scratch_buffer = new float[tflite_instance_->input_buffer_size];
296     }
297 
298     if (output_label_count < 1 || num_hot_spots < 1) {
299         LOG(ERROR) << "Invalid tflite_instance_ config:"
300                    << "number of hot spots: " << num_hot_spots
301                    << " predicted sample order: " << output_label_count;
302         return kVtEstimatorInitFailed;
303     }
304 
305     tflite_instance_->output_label_count = output_label_count;
306     tflite_instance_->num_hot_spots = num_hot_spots;
307     tflite_instance_->output_buffer_size = output_label_count * num_hot_spots;
308     tflite_instance_->output_buffer = new float[tflite_instance_->output_buffer_size];
309 
310     if (!tflite_instance_->tflite_methods.create || !tflite_instance_->tflite_methods.init ||
311         !tflite_instance_->tflite_methods.invoke || !tflite_instance_->tflite_methods.destroy ||
312         !tflite_instance_->tflite_methods.get_input_config_size ||
313         !tflite_instance_->tflite_methods.get_input_config) {
314         LOG(ERROR) << "Invalid tflite methods";
315         return kVtEstimatorInitFailed;
316     }
317 
318     tflite_instance_->tflite_wrapper =
319             tflite_instance_->tflite_methods.create(kNumInputTensors, kNumOutputTensors);
320     if (!tflite_instance_->tflite_wrapper) {
321         LOG(ERROR) << "Failed to create tflite wrapper";
322         return kVtEstimatorInitFailed;
323     }
324 
325     int ret = tflite_instance_->tflite_methods.init(tflite_instance_->tflite_wrapper,
326                                                     model_path.c_str());
327     if (ret) {
328         LOG(ERROR) << "Failed to Init tflite_wrapper for " << model_path << " (ret: )" << ret
329                    << ")";
330         return kVtEstimatorInitFailed;
331     }
332 
333     Json::Value input_config;
334     if (!GetInputConfig(&input_config)) {
335         LOG(ERROR) << "Get Input Config failed for " << model_path;
336         return kVtEstimatorInitFailed;
337     }
338 
339     if (!ParseInputConfig(input_config)) {
340         LOG(ERROR) << "Parse Input Config failed for " << model_path;
341         return kVtEstimatorInitFailed;
342     }
343 
344     if (tflite_instance_->enable_input_validation && !tflite_instance_->input_range.size()) {
345         LOG(ERROR) << "Input ranges missing when input data validation is enabled for "
346                    << common_instance_->sensor_name;
347         return kVtEstimatorInitFailed;
348     }
349 
350     common_instance_->offset_thresholds = data.offset_thresholds;
351     common_instance_->offset_values = data.offset_values;
352     tflite_instance_->model_path = model_path;
353 
354     common_instance_->is_initialized = true;
355     LOG(INFO) << "Successfully initialized VirtualTempEstimator for " << model_path;
356     return kVtEstimatorOk;
357 }
358 
LinearModelEstimate(const std::vector<float> & thermistors,std::vector<float> * output)359 VtEstimatorStatus VirtualTempEstimator::LinearModelEstimate(const std::vector<float> &thermistors,
360                                                             std::vector<float> *output) {
361     if (linear_model_instance_ == nullptr || common_instance_ == nullptr) {
362         LOG(ERROR) << "linear_model_instance_ or common_instance_ is nullptr during Initialize";
363         return kVtEstimatorInitFailed;
364     }
365 
366     size_t prev_samples_order = common_instance_->prev_samples_order;
367     size_t num_linked_sensors = common_instance_->num_linked_sensors;
368 
369     std::unique_lock<std::mutex> lock(linear_model_instance_->mutex);
370 
371     if ((thermistors.size() != num_linked_sensors) || (output == nullptr)) {
372         LOG(ERROR) << "Invalid args Thermistors size[" << thermistors.size()
373                    << "] num_linked_sensors[" << num_linked_sensors << "] output[" << output << "]";
374         return kVtEstimatorInvalidArgs;
375     }
376 
377     if (common_instance_->is_initialized == false) {
378         LOG(ERROR) << "VirtualTempEstimator not initialized to estimate";
379         return kVtEstimatorInitFailed;
380     }
381 
382     // For the first iteration copy current inputs to all previous inputs
383     // This would allow the estimator to have previous samples from the first iteration itself
384     // and provide a valid predicted value
385     if (common_instance_->cur_sample_count == 0) {
386         for (size_t i = 0; i < prev_samples_order; ++i) {
387             linear_model_instance_->input_samples[i] = thermistors;
388         }
389     }
390 
391     size_t cur_sample_index = common_instance_->cur_sample_count % prev_samples_order;
392     linear_model_instance_->input_samples[cur_sample_index] = thermistors;
393 
394     // Calculate Weighted Average Value
395     int input_level = cur_sample_index;
396     float estimated_value = 0;
397     for (size_t i = 0; i < prev_samples_order; ++i) {
398         for (size_t j = 0; j < num_linked_sensors; ++j) {
399             estimated_value += linear_model_instance_->coefficients[i][j] *
400                                linear_model_instance_->input_samples[input_level][j];
401         }
402         input_level--;  // go to previous samples
403         input_level = (input_level >= 0) ? input_level : (prev_samples_order - 1);
404     }
405 
406     // Update sample count
407     common_instance_->cur_sample_count++;
408 
409     // add offset to estimated value if applicable
410     estimated_value += CalculateOffset(common_instance_->offset_thresholds,
411                                        common_instance_->offset_values, estimated_value);
412 
413     std::vector<float> data = {estimated_value};
414     *output = data;
415     return kVtEstimatorOk;
416 }
417 
TFliteEstimate(const std::vector<float> & thermistors,std::vector<float> * output)418 VtEstimatorStatus VirtualTempEstimator::TFliteEstimate(const std::vector<float> &thermistors,
419                                                        std::vector<float> *output) {
420     if (tflite_instance_ == nullptr || common_instance_ == nullptr) {
421         LOG(ERROR) << "tflite_instance_ or common_instance_ is nullptr during Estimate\n";
422         return kVtEstimatorInitFailed;
423     }
424 
425     std::unique_lock<std::mutex> lock(tflite_instance_->tflite_methods.mutex);
426 
427     if (!common_instance_->is_initialized) {
428         LOG(ERROR) << "tflite_instance_ not initialized for " << tflite_instance_->model_path;
429         return kVtEstimatorInitFailed;
430     }
431 
432     size_t num_linked_sensors = common_instance_->num_linked_sensors;
433     if ((thermistors.size() != num_linked_sensors) || (output == nullptr)) {
434         LOG(ERROR) << "Invalid args for " << tflite_instance_->model_path
435                    << " thermistors.size(): " << thermistors.size()
436                    << " num_linked_sensors: " << num_linked_sensors << " output: " << output;
437         return kVtEstimatorInvalidArgs;
438     }
439 
440     // log input data
441     std::string input_data_str = "model_input: [";
442     for (size_t i = 0; i < num_linked_sensors; ++i) {
443         input_data_str += ::android::base::StringPrintf("%0.2f ", thermistors[i]);
444     }
445     input_data_str += "]";
446     LOG(INFO) << input_data_str;
447 
448     // check time gap between samples and ignore stale previous samples
449     if (std::chrono::duration_cast<std::chrono::milliseconds>(boot_clock::now() -
450                                                               tflite_instance_->prev_sample_time) >=
451         tflite_instance_->max_sample_interval) {
452         LOG(INFO) << "Ignoring stale previous samples for " << common_instance_->sensor_name;
453         common_instance_->cur_sample_count = 0;
454     }
455 
456     // copy input data into input tensors
457     size_t prev_samples_order = common_instance_->prev_samples_order;
458     size_t cur_sample_index = common_instance_->cur_sample_count % prev_samples_order;
459     size_t sample_start_index = cur_sample_index * num_linked_sensors;
460     for (size_t i = 0; i < num_linked_sensors; ++i) {
461         if (tflite_instance_->enable_input_validation) {
462             if (thermistors[i] < tflite_instance_->input_range[i].min_threshold ||
463                 thermistors[i] > tflite_instance_->input_range[i].max_threshold) {
464                 LOG(INFO) << "thermistors[" << i << "] value: " << thermistors[i]
465                           << " not in range: " << tflite_instance_->input_range[i].min_threshold
466                           << " <= val <= " << tflite_instance_->input_range[i].max_threshold;
467                 common_instance_->cur_sample_count = 0;
468                 return kVtEstimatorLowConfidence;
469             }
470         }
471         tflite_instance_->input_buffer[sample_start_index + i] = thermistors[i];
472         if (cur_sample_index == 0 && tflite_instance_->support_under_sampling) {
473             // fill previous samples if support under sampling
474             for (size_t j = 1; j < prev_samples_order; ++j) {
475                 size_t copy_start_index = j * num_linked_sensors;
476                 tflite_instance_->input_buffer[copy_start_index + i] = thermistors[i];
477             }
478         }
479     }
480 
481     // Update sample count
482     common_instance_->cur_sample_count++;
483     tflite_instance_->prev_sample_time = boot_clock::now();
484     if ((common_instance_->cur_sample_count < prev_samples_order) &&
485         !(tflite_instance_->support_under_sampling)) {
486         return kVtEstimatorUnderSampling;
487     }
488 
489     // prepare model input
490     float *model_input;
491     size_t input_buffer_size = tflite_instance_->input_buffer_size;
492     size_t output_buffer_size = tflite_instance_->output_buffer_size;
493     if (!common_instance_->use_prev_samples) {
494         model_input = tflite_instance_->input_buffer;
495     } else {
496         sample_start_index = ((cur_sample_index + 1) * num_linked_sensors) % input_buffer_size;
497         for (size_t i = 0; i < input_buffer_size; ++i) {
498             size_t input_index = (sample_start_index + i) % input_buffer_size;
499             tflite_instance_->scratch_buffer[i] = tflite_instance_->input_buffer[input_index];
500         }
501         model_input = tflite_instance_->scratch_buffer;
502     }
503 
504     int ret = tflite_instance_->tflite_methods.invoke(
505             tflite_instance_->tflite_wrapper, model_input, input_buffer_size,
506             tflite_instance_->output_buffer, output_buffer_size);
507     if (ret) {
508         LOG(ERROR) << "Failed to Invoke for " << tflite_instance_->model_path << " (ret: " << ret
509                    << ")";
510         return kVtEstimatorInvokeFailed;
511     }
512     tflite_instance_->last_update_time = boot_clock::now();
513 
514     // prepare output
515     std::vector<float> data;
516     std::ostringstream model_out_log, predict_log;
517     data.reserve(output_buffer_size);
518     for (size_t i = 0; i < output_buffer_size; ++i) {
519         // add offset to predicted value
520         float predicted_value = tflite_instance_->output_buffer[i];
521         model_out_log << predicted_value << " ";
522         predicted_value += CalculateOffset(common_instance_->offset_thresholds,
523                                            common_instance_->offset_values, predicted_value);
524         predict_log << predicted_value << " ";
525         data.emplace_back(predicted_value);
526     }
527     LOG(INFO) << "model_output: [" << model_out_log.str() << "]";
528     LOG(INFO) << "predicted_value: [" << predict_log.str() << "]";
529     *output = data;
530 
531     return kVtEstimatorOk;
532 }
533 
Estimate(const std::vector<float> & thermistors,std::vector<float> * output)534 VtEstimatorStatus VirtualTempEstimator::Estimate(const std::vector<float> &thermistors,
535                                                  std::vector<float> *output) {
536     if (type == kUseMLModel) {
537         return TFliteEstimate(thermistors, output);
538     } else if (type == kUseLinearModel) {
539         return LinearModelEstimate(thermistors, output);
540     }
541 
542     LOG(ERROR) << "Unsupported estimationType [" << type << "]";
543     return kVtEstimatorUnSupported;
544 }
545 
TFliteGetMaxPredictWindowMs(size_t * predict_window_ms)546 VtEstimatorStatus VirtualTempEstimator::TFliteGetMaxPredictWindowMs(size_t *predict_window_ms) {
547     if (tflite_instance_ == nullptr || common_instance_ == nullptr) {
548         LOG(ERROR) << "tflite_instance_ or common_instance_ is nullptr for predict window\n";
549         return kVtEstimatorInitFailed;
550     }
551 
552     if (!common_instance_->is_initialized) {
553         LOG(ERROR) << "tflite_instance_ not initialized for " << tflite_instance_->model_path;
554         return kVtEstimatorInitFailed;
555     }
556 
557     size_t window = tflite_instance_->predict_window_ms;
558     if (window == 0) {
559         return kVtEstimatorUnSupported;
560     }
561     *predict_window_ms = window;
562     return kVtEstimatorOk;
563 }
564 
GetMaxPredictWindowMs(size_t * predict_window_ms)565 VtEstimatorStatus VirtualTempEstimator::GetMaxPredictWindowMs(size_t *predict_window_ms) {
566     if (type == kUseMLModel) {
567         return TFliteGetMaxPredictWindowMs(predict_window_ms);
568     }
569 
570     LOG(ERROR) << "Unsupported estimationType [" << type << "]";
571     return kVtEstimatorUnSupported;
572 }
573 
TFlitePredictAfterTimeMs(const size_t time_ms,float * output)574 VtEstimatorStatus VirtualTempEstimator::TFlitePredictAfterTimeMs(const size_t time_ms,
575                                                                  float *output) {
576     if (tflite_instance_ == nullptr || common_instance_ == nullptr) {
577         LOG(ERROR) << "tflite_instance_ or common_instance_ is nullptr for predict window\n";
578         return kVtEstimatorInitFailed;
579     }
580 
581     if (!common_instance_->is_initialized) {
582         LOG(ERROR) << "tflite_instance_ not initialized for " << tflite_instance_->model_path;
583         return kVtEstimatorInitFailed;
584     }
585 
586     std::unique_lock<std::mutex> lock(tflite_instance_->tflite_methods.mutex);
587 
588     size_t window = tflite_instance_->predict_window_ms;
589     auto sample_interval = tflite_instance_->sample_interval;
590     auto last_update_time = tflite_instance_->last_update_time;
591     auto request_time_ms = std::chrono::duration_cast<std::chrono::milliseconds>(boot_clock::now() -
592                                                                                  last_update_time);
593     // check for under sampling
594     if ((common_instance_->cur_sample_count < common_instance_->prev_samples_order) &&
595         !(tflite_instance_->support_under_sampling)) {
596         LOG(INFO) << tflite_instance_->model_path
597                   << " cannot provide prediction while under sampling";
598         return kVtEstimatorUnderSampling;
599     }
600 
601     // calculate requested time since last update
602     request_time_ms = request_time_ms + std::chrono::milliseconds{time_ms};
603     if (sample_interval.count() == 0 || window == 0 ||
604         window < static_cast<size_t>(request_time_ms.count())) {
605         LOG(INFO) << tflite_instance_->model_path << " cannot predict temperature after ("
606                   << time_ms << " + " << request_time_ms.count() - time_ms
607                   << ") ms since last update with sample interval [" << sample_interval.count()
608                   << "] ms and predict window [" << window << "] ms";
609         return kVtEstimatorUnSupported;
610     }
611 
612     size_t request_step = request_time_ms / sample_interval;
613     size_t output_label_count = tflite_instance_->output_label_count;
614     float *output_buffer = tflite_instance_->output_buffer;
615     float prediction;
616     if (request_step == output_label_count - 1) {
617         // request prediction is on the right boundary of the window
618         prediction = output_buffer[output_label_count - 1];
619     } else {
620         float left = output_buffer[request_step], right = output_buffer[request_step + 1];
621         prediction = left;
622         if (left != right) {
623             prediction += (request_time_ms - sample_interval * request_step) * (right - left) /
624                           sample_interval;
625         }
626     }
627 
628     *output = prediction;
629 
630     return kVtEstimatorOk;
631 }
632 
PredictAfterTimeMs(const size_t time_ms,float * output)633 VtEstimatorStatus VirtualTempEstimator::PredictAfterTimeMs(const size_t time_ms, float *output) {
634     if (type == kUseMLModel) {
635         return TFlitePredictAfterTimeMs(time_ms, output);
636     }
637 
638     LOG(ERROR) << "PredictAfterTimeMs not supported for type [" << type << "]";
639     return kVtEstimatorUnSupported;
640 }
641 
TFliteGetAllPredictions(std::vector<float> * output)642 VtEstimatorStatus VirtualTempEstimator::TFliteGetAllPredictions(std::vector<float> *output) {
643     if (tflite_instance_ == nullptr || common_instance_ == nullptr) {
644         LOG(ERROR) << "tflite_instance_ or common_instance_ is nullptr for predict window\n";
645         return kVtEstimatorInitFailed;
646     }
647 
648     std::unique_lock<std::mutex> lock(tflite_instance_->tflite_methods.mutex);
649 
650     if (!common_instance_->is_initialized) {
651         LOG(ERROR) << "tflite_instance_ not initialized for " << tflite_instance_->model_path;
652         return kVtEstimatorInitFailed;
653     }
654 
655     if (output == nullptr) {
656         LOG(ERROR) << "output is nullptr";
657         return kVtEstimatorInvalidArgs;
658     }
659 
660     std::vector<float> tflite_output;
661     size_t output_buffer_size = tflite_instance_->output_buffer_size;
662     tflite_output.reserve(output_buffer_size);
663     for (size_t i = 0; i < output_buffer_size; ++i) {
664         tflite_output.emplace_back(tflite_instance_->output_buffer[i]);
665     }
666     *output = tflite_output;
667 
668     return kVtEstimatorOk;
669 }
670 
GetAllPredictions(std::vector<float> * output)671 VtEstimatorStatus VirtualTempEstimator::GetAllPredictions(std::vector<float> *output) {
672     if (type == kUseMLModel) {
673         return TFliteGetAllPredictions(output);
674     }
675 
676     LOG(INFO) << "GetAllPredicts not supported by estimationType [" << type << "]";
677     return kVtEstimatorUnSupported;
678 }
679 
TFLiteDumpStatus(std::string_view sensor_name,std::ostringstream * dump_buf)680 VtEstimatorStatus VirtualTempEstimator::TFLiteDumpStatus(std::string_view sensor_name,
681                                                          std::ostringstream *dump_buf) {
682     if (dump_buf == nullptr) {
683         LOG(ERROR) << "dump_buf is nullptr for " << sensor_name;
684         return kVtEstimatorInvalidArgs;
685     }
686 
687     if (!common_instance_->is_initialized) {
688         LOG(ERROR) << "tflite_instance_ not initialized for " << tflite_instance_->model_path;
689         return kVtEstimatorInitFailed;
690     }
691 
692     std::unique_lock<std::mutex> lock(tflite_instance_->tflite_methods.mutex);
693 
694     *dump_buf << " Sensor Name: " << sensor_name << std::endl;
695     *dump_buf << "  Current Values: ";
696     size_t output_buffer_size = tflite_instance_->output_buffer_size;
697     for (size_t i = 0; i < output_buffer_size; ++i) {
698         // add offset to predicted value
699         float predicted_value = tflite_instance_->output_buffer[i];
700         predicted_value += CalculateOffset(common_instance_->offset_thresholds,
701                                            common_instance_->offset_values, predicted_value);
702         *dump_buf << predicted_value << ", ";
703     }
704     *dump_buf << std::endl;
705 
706     *dump_buf << "  Model Path: \"" << tflite_instance_->model_path << "\"" << std::endl;
707 
708     return kVtEstimatorOk;
709 }
710 
DumpStatus(std::string_view sensor_name,std::ostringstream * dump_buff)711 VtEstimatorStatus VirtualTempEstimator::DumpStatus(std::string_view sensor_name,
712                                                    std::ostringstream *dump_buff) {
713     if (type == kUseMLModel) {
714         return TFLiteDumpStatus(sensor_name, dump_buff);
715     }
716 
717     LOG(INFO) << "DumpStatus not supported by estimationType [" << type << "]";
718     return kVtEstimatorUnSupported;
719 }
720 
Initialize(const VtEstimationInitData & data)721 VtEstimatorStatus VirtualTempEstimator::Initialize(const VtEstimationInitData &data) {
722     LOG(INFO) << "Initialize VirtualTempEstimator for " << type;
723 
724     if (type == kUseMLModel) {
725         return TFliteInitialize(data.ml_model_init_data);
726     } else if (type == kUseLinearModel) {
727         return LinearModelInitialize(data.linear_model_init_data);
728     }
729 
730     LOG(ERROR) << "Unsupported estimationType [" << type << "]";
731     return kVtEstimatorUnSupported;
732 }
733 
ParseInputConfig(const Json::Value & input_config)734 bool VirtualTempEstimator::ParseInputConfig(const Json::Value &input_config) {
735     if (!input_config["ModelConfig"].empty()) {
736         if (!input_config["ModelConfig"]["sample_interval_ms"].empty()) {
737             // read input sample interval
738             int sample_interval_ms = input_config["ModelConfig"]["sample_interval_ms"].asInt();
739             if (sample_interval_ms <= 0) {
740                 LOG(ERROR) << "Invalid sample_interval_ms: " << sample_interval_ms;
741                 return false;
742             }
743 
744             tflite_instance_->sample_interval = std::chrono::milliseconds{sample_interval_ms};
745             LOG(INFO) << "Parsed tflite model input sample_interval: " << sample_interval_ms
746                       << " for " << common_instance_->sensor_name;
747 
748             // determine predict window
749             tflite_instance_->predict_window_ms =
750                     sample_interval_ms * (tflite_instance_->output_label_count - 1);
751             LOG(INFO) << "Max prediction window size: " << tflite_instance_->predict_window_ms
752                       << " ms for " << common_instance_->sensor_name;
753         }
754 
755         if (!input_config["ModelConfig"]["max_sample_interval_ms"].empty()) {
756             // read input max sample interval
757             int max_sample_interval_ms =
758                     input_config["ModelConfig"]["max_sample_interval_ms"].asInt();
759             if (max_sample_interval_ms <= 0) {
760                 LOG(ERROR) << "Invalid max_sample_interval_ms " << max_sample_interval_ms;
761                 return false;
762             }
763 
764             tflite_instance_->max_sample_interval =
765                     std::chrono::milliseconds{max_sample_interval_ms};
766             LOG(INFO) << "Parsed tflite model max_sample_interval: " << max_sample_interval_ms
767                       << " for " << common_instance_->sensor_name;
768         }
769     }
770 
771     if (!input_config["InputData"].empty()) {
772         Json::Value input_data = input_config["InputData"];
773         if (input_data.size() != common_instance_->num_linked_sensors) {
774             LOG(ERROR) << "Input ranges size: " << input_data.size()
775                        << " does not match num_linked_sensors: "
776                        << common_instance_->num_linked_sensors;
777             return false;
778         }
779 
780         LOG(INFO) << "Start to parse tflite model input config for "
781                   << common_instance_->num_linked_sensors;
782         tflite_instance_->input_range.assign(input_data.size(), InputRangeInfo());
783         for (Json::Value::ArrayIndex i = 0; i < input_data.size(); ++i) {
784             const std::string &name = input_data[i]["Name"].asString();
785             LOG(INFO) << "Sensor[" << i << "] Name: " << name;
786             if (!getInputRangeInfoFromJsonValues(input_data[i]["Range"],
787                                                  &tflite_instance_->input_range[i])) {
788                 LOG(ERROR) << "Failed to parse tflite model temp range for sensor: [" << name
789                            << "]";
790                 return false;
791             }
792         }
793     }
794 
795     return true;
796 }
797 
GetInputConfig(Json::Value * config)798 bool VirtualTempEstimator::GetInputConfig(Json::Value *config) {
799     int config_size = 0;
800     int ret = tflite_instance_->tflite_methods.get_input_config_size(
801             tflite_instance_->tflite_wrapper, &config_size);
802     if (ret || config_size <= 0) {
803         LOG(ERROR) << "Failed to get tflite input config size (ret: " << ret
804                    << ") with size: " << config_size;
805         return false;
806     }
807 
808     LOG(INFO) << "Model input config_size: " << config_size << " for "
809               << common_instance_->sensor_name;
810 
811     char *config_str = new char[config_size];
812     ret = tflite_instance_->tflite_methods.get_input_config(tflite_instance_->tflite_wrapper,
813                                                             config_str, config_size);
814     if (ret) {
815         LOG(ERROR) << "Failed to get tflite input config (ret: " << ret << ")";
816         delete[] config_str;
817         return false;
818     }
819 
820     Json::CharReaderBuilder builder;
821     std::unique_ptr<Json::CharReader> reader(builder.newCharReader());
822     std::string errorMessage;
823 
824     bool success = true;
825     if (!reader->parse(config_str, config_str + config_size, config, &errorMessage)) {
826         LOG(ERROR) << "Failed to parse tflite JSON input config: " << errorMessage;
827         success = false;
828     }
829     delete[] config_str;
830     return success;
831 }
832 
833 }  // namespace vtestimator
834 }  // namespace thermal
835