1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_LITE_TOOLS_DELEGATES_DELEGATE_PROVIDER_H_
17 #define TENSORFLOW_LITE_TOOLS_DELEGATES_DELEGATE_PROVIDER_H_
18 
#include <memory>
#include <string>
#include <vector>

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/tools/command_line_flags.h"
#include "tensorflow/lite/tools/logging.h"
#include "tensorflow/lite/tools/tool_params.h"
26 
27 namespace tflite {
28 namespace tools {
29 
30 // Same w/ Interpreter::TfLiteDelegatePtr to avoid pulling
31 // tensorflow/lite/interpreter.h dependency
32 using TfLiteDelegatePtr =
33     std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>;
34 
35 class DelegateProvider {
36  public:
~DelegateProvider()37   virtual ~DelegateProvider() {}
38 
39   // Create a list of command-line parsable flags based on tool params inside
40   // 'params' whose value will be set to the corresponding runtime flag value.
41   virtual std::vector<Flag> CreateFlags(ToolParams* params) const = 0;
42 
43   // Log tool params. If 'verbose' is set to false, the param is going to be
44   // only logged if its value has been set, say via being parsed from
45   // commandline flags.
46   virtual void LogParams(const ToolParams& params, bool verbose) const = 0;
47 
48   // Create a TfLiteDelegate based on tool params.
49   virtual TfLiteDelegatePtr CreateTfLiteDelegate(
50       const ToolParams& params) const = 0;
51 
52   virtual std::string GetName() const = 0;
53 
DefaultParams()54   const ToolParams& DefaultParams() const { return default_params_; }
55 
56  protected:
57   template <typename T>
CreateFlag(const char * name,ToolParams * params,const std::string & usage)58   Flag CreateFlag(const char* name, ToolParams* params,
59                   const std::string& usage) const {
60     return Flag(
61         name, [params, name](const T& val) { params->Set<T>(name, val); },
62         default_params_.Get<T>(name), usage, Flag::kOptional);
63   }
64   ToolParams default_params_;
65 };
66 
67 using DelegateProviderPtr = std::unique_ptr<DelegateProvider>;
68 using DelegateProviderList = std::vector<DelegateProviderPtr>;
69 
70 class DelegateProviderRegistrar {
71  public:
72   template <typename T>
73   struct Register {
RegisterRegister74     Register() {
75       auto* const instance = DelegateProviderRegistrar::GetSingleton();
76       instance->providers_.emplace_back(DelegateProviderPtr(new T()));
77     }
78   };
79 
GetProviders()80   static const DelegateProviderList& GetProviders() {
81     return GetSingleton()->providers_;
82   }
83 
84  private:
DelegateProviderRegistrar()85   DelegateProviderRegistrar() {}
86   DelegateProviderRegistrar(const DelegateProviderRegistrar&) = delete;
87   DelegateProviderRegistrar& operator=(const DelegateProviderRegistrar&) =
88       delete;
89 
GetSingleton()90   static DelegateProviderRegistrar* GetSingleton() {
91     static auto* instance = new DelegateProviderRegistrar();
92     return instance;
93   }
94   DelegateProviderList providers_;
95 };
96 
// Identifier of the static registration object created for provider type T.
// T is token-pasted into an identifier, so it must be a plain, unqualified
// type name (e.g. MyProvider, not ns::MyProvider).
#define REGISTER_DELEGATE_PROVIDER_VNAME(T) gDelegateProvider_##T##_
// Registers a new T (a DelegateProvider subclass constructed with 'new T()')
// with DelegateProviderRegistrar. It does so by defining a 'static'
// (internal-linkage) Register<T> object. Intended for use at namespace scope
// in a source file. The registrar is named with a leading '::' so the
// expansion resolves correctly from whichever namespace the macro is invoked
// in.
#define REGISTER_DELEGATE_PROVIDER(T)                            \
  static ::tflite::tools::DelegateProviderRegistrar::Register<T> \
      REGISTER_DELEGATE_PROVIDER_VNAME(T);
101 
102 // A global helper function to get all registered delegate providers.
GetRegisteredDelegateProviders()103 inline const DelegateProviderList& GetRegisteredDelegateProviders() {
104   return DelegateProviderRegistrar::GetProviders();
105 }
106 }  // namespace tools
107 }  // namespace tflite
108 
109 #endif  // TENSORFLOW_LITE_TOOLS_DELEGATES_DELEGATE_PROVIDER_H_
110