1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/tools/benchmark/benchmark_model.h"
17 
18 #include <iostream>
19 #include <sstream>
20 
21 #include "tensorflow/lite/profiling/memory_info.h"
22 #include "tensorflow/lite/profiling/time.h"
23 #include "tensorflow/lite/tools/benchmark/benchmark_utils.h"
24 #include "tensorflow/lite/tools/logging.h"
25 
26 namespace tflite {
27 namespace benchmark {
28 using tensorflow::Stat;
29 
DefaultParams()30 BenchmarkParams BenchmarkModel::DefaultParams() {
31   BenchmarkParams params;
32   params.AddParam("num_runs", BenchmarkParam::Create<int32_t>(50));
33   params.AddParam("min_secs", BenchmarkParam::Create<float>(1.0f));
34   params.AddParam("max_secs", BenchmarkParam::Create<float>(150.0f));
35   params.AddParam("run_delay", BenchmarkParam::Create<float>(-1.0f));
36   params.AddParam("run_frequency", BenchmarkParam::Create<float>(-1.0f));
37   params.AddParam("num_threads", BenchmarkParam::Create<int32_t>(1));
38   params.AddParam("use_caching", BenchmarkParam::Create<bool>(false));
39   params.AddParam("benchmark_name", BenchmarkParam::Create<std::string>(""));
40   params.AddParam("output_prefix", BenchmarkParam::Create<std::string>(""));
41   params.AddParam("warmup_runs", BenchmarkParam::Create<int32_t>(1));
42   params.AddParam("warmup_min_secs", BenchmarkParam::Create<float>(0.5f));
43   params.AddParam("verbose", BenchmarkParam::Create<bool>(false));
44   return params;
45 }
46 
BenchmarkModel()47 BenchmarkModel::BenchmarkModel() : params_(DefaultParams()) {}
48 
OnBenchmarkEnd(const BenchmarkResults & results)49 void BenchmarkLoggingListener::OnBenchmarkEnd(const BenchmarkResults& results) {
50   auto inference_us = results.inference_time_us();
51   auto init_us = results.startup_latency_us();
52   auto warmup_us = results.warmup_time_us();
53   auto init_mem_usage = results.init_mem_usage();
54   auto overall_mem_usage = results.overall_mem_usage();
55   TFLITE_LOG(INFO) << "Inference timings in us: "
56                    << "Init: " << init_us << ", "
57                    << "First inference: " << warmup_us.first() << ", "
58                    << "Warmup (avg): " << warmup_us.avg() << ", "
59                    << "Inference (avg): " << inference_us.avg();
60 
61   if (!init_mem_usage.IsSupported()) return;
62   TFLITE_LOG(INFO)
63       << "Note: as the benchmark tool itself affects memory footprint, the "
64          "following is only APPROXIMATE to the actual memory footprint of the "
65          "model at runtime. Take the information at your discretion.";
66   TFLITE_LOG(INFO) << "Peak memory footprint (MB): init="
67                    << init_mem_usage.max_rss_kb / 1024.0
68                    << " overall=" << overall_mem_usage.max_rss_kb / 1024.0;
69 }
70 
GetFlags()71 std::vector<Flag> BenchmarkModel::GetFlags() {
72   return {
73       CreateFlag<int32_t>(
74           "num_runs", &params_,
75           "expected number of runs, see also min_secs, max_secs"),
76       CreateFlag<float>(
77           "min_secs", &params_,
78           "minimum number of seconds to rerun for, potentially making the "
79           "actual number of runs to be greater than num_runs"),
80       CreateFlag<float>(
81           "max_secs", &params_,
82           "maximum number of seconds to rerun for, potentially making the "
83           "actual number of runs to be less than num_runs. Note if --max-secs "
84           "is exceeded in the middle of a run, the benchmark will continue to "
85           "the end of the run but will not start the next run."),
86       CreateFlag<float>("run_delay", &params_, "delay between runs in seconds"),
87       CreateFlag<float>(
88           "run_frequency", &params_,
89           "Execute at a fixed frequency, instead of a fixed delay."
90           "Note if the targeted rate per second cannot be reached, the "
91           "benchmark would start the next run immediately, trying its best to "
92           "catch up. If set, this will override run_delay."),
93       CreateFlag<int32_t>("num_threads", &params_, "number of threads"),
94       CreateFlag<bool>(
95           "use_caching", &params_,
96           "Enable caching of prepacked weights matrices in matrix "
97           "multiplication routines. Currently implies the use of the Ruy "
98           "library."),
99       CreateFlag<std::string>("benchmark_name", &params_, "benchmark name"),
100       CreateFlag<std::string>("output_prefix", &params_,
101                               "benchmark output prefix"),
102       CreateFlag<int32_t>(
103           "warmup_runs", &params_,
104           "minimum number of runs performed on initialization, to "
105           "allow performance characteristics to settle, see also "
106           "warmup_min_secs"),
107       CreateFlag<float>(
108           "warmup_min_secs", &params_,
109           "minimum number of seconds to rerun for, potentially making the "
110           "actual number of warm-up runs to be greater than warmup_runs"),
111       CreateFlag<bool>("verbose", &params_,
112                        "Whether to log parameters whose values are not set. "
113                        "By default, only log those parameters that are set by "
114                        "parsing their values from the commandline flags."),
115   };
116 }
117 
LogParams()118 void BenchmarkModel::LogParams() {
119   const bool verbose = params_.Get<bool>("verbose");
120   TFLITE_LOG(INFO) << "Log parameter values verbosely: [" << verbose << "]";
121 
122   LOG_BENCHMARK_PARAM(int32_t, "num_runs", "Min num runs", verbose);
123   LOG_BENCHMARK_PARAM(float, "min_secs", "Min runs duration (seconds)",
124                       verbose);
125   LOG_BENCHMARK_PARAM(float, "max_secs", "Max runs duration (seconds)",
126                       verbose);
127   LOG_BENCHMARK_PARAM(float, "run_delay", "Inter-run delay (seconds)", verbose);
128   LOG_BENCHMARK_PARAM(float, "run_frequency",
129                       "Number of prorated runs per second", verbose);
130   LOG_BENCHMARK_PARAM(int32_t, "num_threads", "Num threads", verbose);
131   LOG_BENCHMARK_PARAM(bool, "use_caching", "Use caching", verbose);
132   LOG_BENCHMARK_PARAM(std::string, "benchmark_name", "Benchmark name", verbose);
133   LOG_BENCHMARK_PARAM(std::string, "output_prefix", "Output prefix", verbose);
134   LOG_BENCHMARK_PARAM(int32_t, "warmup_runs", "Min warmup runs", verbose);
135   LOG_BENCHMARK_PARAM(float, "warmup_min_secs",
136                       "Min warmup runs duration (seconds)", verbose);
137 }
138 
PrepareInputData()139 TfLiteStatus BenchmarkModel::PrepareInputData() { return kTfLiteOk; }
140 
ResetInputsAndOutputs()141 TfLiteStatus BenchmarkModel::ResetInputsAndOutputs() { return kTfLiteOk; }
142 
Run(int min_num_times,float min_secs,float max_secs,RunType run_type,TfLiteStatus * invoke_status)143 Stat<int64_t> BenchmarkModel::Run(int min_num_times, float min_secs,
144                                   float max_secs, RunType run_type,
145                                   TfLiteStatus* invoke_status) {
146   Stat<int64_t> run_stats;
147   TFLITE_LOG(INFO) << "Running benchmark for at least " << min_num_times
148                    << " iterations and at least " << min_secs << " seconds but"
149                    << " terminate if exceeding " << max_secs << " seconds.";
150   int64_t now_us = profiling::time::NowMicros();
151   int64_t min_finish_us = now_us + static_cast<int64_t>(min_secs * 1.e6f);
152   int64_t max_finish_us = now_us + static_cast<int64_t>(max_secs * 1.e6f);
153 
154   *invoke_status = kTfLiteOk;
155   float inter_run_sleep_time = params_.Get<float>("run_delay");
156   auto run_frequency = params_.Get<float>("run_frequency");
157   double manual_inter_run_gap = 1.0 / run_frequency;
158   // float doesn't have sufficient precision for storing this number
159   double next_run_finish_time = now_us * 1e-6 + manual_inter_run_gap;
160   for (int run = 0; (run < min_num_times || now_us < min_finish_us) &&
161                     now_us <= max_finish_us;
162        run++) {
163     ResetInputsAndOutputs();
164     listeners_.OnSingleRunStart(run_type);
165     int64_t start_us = profiling::time::NowMicros();
166     TfLiteStatus status = RunImpl();
167     int64_t end_us = profiling::time::NowMicros();
168     listeners_.OnSingleRunEnd();
169 
170     run_stats.UpdateStat(end_us - start_us);
171     if (run_frequency > 0) {
172       inter_run_sleep_time =
173           next_run_finish_time - profiling::time::NowMicros() * 1e-6;
174       next_run_finish_time += manual_inter_run_gap;
175     }
176     // Note when "inter_run_sleep_time" is negative or 0.0,
177     // the function will return immediately.
178     util::SleepForSeconds(inter_run_sleep_time);
179     now_us = profiling::time::NowMicros();
180 
181     if (status != kTfLiteOk) {
182       *invoke_status = status;
183     }
184   }
185 
186   std::stringstream stream;
187   run_stats.OutputToStream(&stream);
188   TFLITE_LOG(INFO) << stream.str() << std::endl;
189 
190   return run_stats;
191 }
192 
ValidateParams()193 TfLiteStatus BenchmarkModel::ValidateParams() { return kTfLiteOk; }
194 
Run(int argc,char ** argv)195 TfLiteStatus BenchmarkModel::Run(int argc, char** argv) {
196   TF_LITE_ENSURE_STATUS(ParseFlags(argc, argv));
197   return Run();
198 }
199 
Run()200 TfLiteStatus BenchmarkModel::Run() {
201   TF_LITE_ENSURE_STATUS(ValidateParams());
202 
203   LogParams();
204 
205   const double model_size_mb = MayGetModelFileSize() / 1e6;
206   const auto start_mem_usage = profiling::memory::GetMemoryUsage();
207   int64_t initialization_start_us = profiling::time::NowMicros();
208   TF_LITE_ENSURE_STATUS(Init());
209   const auto init_end_mem_usage = profiling::memory::GetMemoryUsage();
210   int64_t initialization_end_us = profiling::time::NowMicros();
211   int64_t startup_latency_us = initialization_end_us - initialization_start_us;
212   const auto init_mem_usage = init_end_mem_usage - start_mem_usage;
213 
214   if (model_size_mb > 0) {
215     TFLITE_LOG(INFO) << "The input model file size (MB): " << model_size_mb;
216   }
217   TFLITE_LOG(INFO) << "Initialized session in " << startup_latency_us / 1e3
218                    << "ms.";
219 
220   TF_LITE_ENSURE_STATUS(PrepareInputData());
221 
222   TfLiteStatus status = kTfLiteOk;
223   uint64_t input_bytes = ComputeInputBytes();
224   listeners_.OnBenchmarkStart(params_);
225   Stat<int64_t> warmup_time_us =
226       Run(params_.Get<int32_t>("warmup_runs"),
227           params_.Get<float>("warmup_min_secs"), params_.Get<float>("max_secs"),
228           WARMUP, &status);
229   if (status != kTfLiteOk) {
230     return status;
231   }
232 
233   Stat<int64_t> inference_time_us =
234       Run(params_.Get<int32_t>("num_runs"), params_.Get<float>("min_secs"),
235           params_.Get<float>("max_secs"), REGULAR, &status);
236   const auto overall_mem_usage =
237       profiling::memory::GetMemoryUsage() - start_mem_usage;
238 
239   listeners_.OnBenchmarkEnd({model_size_mb, startup_latency_us, input_bytes,
240                              warmup_time_us, inference_time_us, init_mem_usage,
241                              overall_mem_usage});
242   return status;
243 }
244 
ParseFlags(int * argc,char ** argv)245 TfLiteStatus BenchmarkModel::ParseFlags(int* argc, char** argv) {
246   auto flag_list = GetFlags();
247   const bool parse_result =
248       Flags::Parse(argc, const_cast<const char**>(argv), flag_list);
249   if (!parse_result) {
250     std::string usage = Flags::Usage(argv[0], flag_list);
251     TFLITE_LOG(ERROR) << usage;
252     return kTfLiteError;
253   }
254 
255   std::string unconsumed_args =
256       Flags::ArgsToString(*argc, const_cast<const char**>(argv));
257   if (!unconsumed_args.empty()) {
258     TFLITE_LOG(WARN) << "Unconsumed cmdline flags: " << unconsumed_args;
259   }
260 
261   return kTfLiteOk;
262 }
263 
264 }  // namespace benchmark
265 }  // namespace tflite
266