1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_LITE_TOOLS_TOOL_PARAMS_H_
17 #define TENSORFLOW_LITE_TOOLS_TOOL_PARAMS_H_
18 #include <memory>
19 #include <string>
20 #include <unordered_map>
21 #include <utility>
22 #include <vector>
23 
24 namespace tflite {
25 namespace tools {
26 
// Forward declaration: ToolParam's templated helpers below create and downcast
// to TypedToolParam<T>, whose definition follows ToolParam.
template <class T>
class TypedToolParam;

// Type-erased base class for a single tool parameter. Each instance records
// the ParamType it was created with and whether its value has been explicitly
// set (the flag is written by subclasses through `has_value_set_`).
class ToolParam {
 protected:
  // The value types a ToolParam can carry.
  enum class ParamType { TYPE_INT32, TYPE_FLOAT, TYPE_BOOL, TYPE_STRING };

  // Maps the C++ type T to its ParamType. Only declared here; the definitions
  // live out of line (not visible in this header).
  template <typename T>
  static ParamType GetValueType();

 public:
  // Creates a TypedToolParam<T> initialized with `default_value` and returns
  // it through the type-erased base pointer.
  template <typename T>
  static std::unique_ptr<ToolParam> Create(const T& default_value) {
    std::unique_ptr<ToolParam> created(new TypedToolParam<T>(default_value));
    return created;
  }

  // Downcasts to TypedToolParam<T>. Before the cast, T's ParamType and this
  // parameter's type are passed to AssertHasSameType (defined out of line;
  // presumably it aborts on a mismatch — confirm in the .cc file).
  template <typename T>
  TypedToolParam<T>* AsTyped() {
    const ParamType requested = GetValueType<T>();
    AssertHasSameType(requested, type_);
    return static_cast<TypedToolParam<T>*>(this);
  }

  // Const counterpart of AsTyped(); performs the same type assertion.
  template <typename T>
  const TypedToolParam<T>* AsConstTyped() const {
    const ParamType requested = GetValueType<T>();
    AssertHasSameType(requested, type_);
    return static_cast<const TypedToolParam<T>*>(this);
  }

  virtual ~ToolParam() = default;

  // A freshly constructed parameter reports HasValueSet() == false.
  explicit ToolParam(ParamType type) : type_(type) {}

  // True once a subclass has recorded an explicit value.
  bool HasValueSet() const { return has_value_set_; }

  // Copies the value from another parameter. The base version does nothing;
  // TypedToolParam<T> overrides it.
  virtual void Set(const ToolParam& /*other*/) {}

  // Returns a newly allocated copy of this parameter.
  virtual std::unique_ptr<ToolParam> Clone() const = 0;

 protected:
  bool has_value_set_ = false;

 private:
  static void AssertHasSameType(ParamType a, ParamType b);

  const ParamType type_;
};
71 
72 template <typename T>
73 class TypedToolParam : public ToolParam {
74  public:
TypedToolParam(const T & value)75   explicit TypedToolParam(const T& value)
76       : ToolParam(GetValueType<T>()), value_(value) {}
77 
Set(const T & value)78   void Set(const T& value) {
79     value_ = value;
80     has_value_set_ = true;
81   }
82 
Get()83   T Get() const { return value_; }
84 
Set(const ToolParam & other)85   void Set(const ToolParam& other) override {
86     Set(other.AsConstTyped<T>()->Get());
87   }
88 
Clone()89   std::unique_ptr<ToolParam> Clone() const override {
90     return std::unique_ptr<ToolParam>(new TypedToolParam<T>(value_));
91   }
92 
93  private:
94   T value_;
95 };
96 
97 // A map-like container for holding values of different types.
98 class ToolParams {
99  public:
AddParam(const std::string & name,std::unique_ptr<ToolParam> value)100   void AddParam(const std::string& name, std::unique_ptr<ToolParam> value) {
101     params_[name] = std::move(value);
102   }
103 
HasParam(const std::string & name)104   bool HasParam(const std::string& name) const {
105     return params_.find(name) != params_.end();
106   }
107 
Empty()108   bool Empty() const { return params_.empty(); }
109 
GetParam(const std::string & name)110   const ToolParam* GetParam(const std::string& name) const {
111     const auto& entry = params_.find(name);
112     if (entry == params_.end()) return nullptr;
113     return entry->second.get();
114   }
115 
116   template <typename T>
Set(const std::string & name,const T & value)117   void Set(const std::string& name, const T& value) {
118     AssertParamExists(name);
119     params_.at(name)->AsTyped<T>()->Set(value);
120   }
121 
122   template <typename T>
HasValueSet(const std::string & name)123   bool HasValueSet(const std::string& name) const {
124     AssertParamExists(name);
125     return params_.at(name)->AsConstTyped<T>()->HasValueSet();
126   }
127 
128   template <typename T>
Get(const std::string & name)129   T Get(const std::string& name) const {
130     AssertParamExists(name);
131     return params_.at(name)->AsConstTyped<T>()->Get();
132   }
133 
134   // Set the value of all same parameters from 'other'.
135   void Set(const ToolParams& other);
136 
137   // Merge the value of all parameters from 'other'. 'overwrite' indicates
138   // whether the value of the same paratmeter is overwritten or not.
139   void Merge(const ToolParams& other, bool overwrite = false);
140 
141  private:
142   void AssertParamExists(const std::string& name) const;
143   std::unordered_map<std::string, std::unique_ptr<ToolParam>> params_;
144 };
145 
// Logs "<description>: [<value>]" for the parameter `name` of C++ type `type`
// held by the ToolParams object `params`. The message is emitted when
// `verbose` is true or when the parameter has been explicitly set.
// TFLITE_MAY_LOG is not defined by this header, so code using this macro must
// make it available (presumably via tensorflow/lite/tools/logging.h; confirm).
// `params` and `description` are parenthesized so that arguments such as
// `*params_ptr` or `cond ? "a" : "b"` expand correctly. Note that `params`
// and `name` are evaluated more than once.
#define LOG_TOOL_PARAM(params, type, name, description, verbose)        \
  do {                                                                  \
    TFLITE_MAY_LOG(INFO, (verbose) || (params).HasValueSet<type>(name)) \
        << (description) << ": [" << (params).Get<type>(name) << "]";   \
  } while (0)
151 
152 }  // namespace tools
153 }  // namespace tflite
154 #endif  // TENSORFLOW_LITE_TOOLS_TOOL_PARAMS_H_
155