1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_TOCO_TFLITE_EXPORT_H_
16 #define TENSORFLOW_LITE_TOCO_TFLITE_EXPORT_H_
17 
18 #include "tensorflow/lite/toco/model.h"
19 #include "tensorflow/lite/toco/tflite/operator.h"
20 #include "tensorflow/lite/util.h"
21 
22 namespace toco {
23 
24 namespace tflite {
25 
// Options controlling how a tf.mini Model is exported to a TFLite flatbuffer.
// Every flag defaults to false.
struct ExportParams {
  // Whether custom ops may appear in the exported model.
  // NOTE(review): presumably, when false, ops without a TFLite builtin make
  // the export fail -- confirm against Export()'s implementation.
  bool allow_custom_ops{false};
  // Whether "select TF ops" (Flex ops) may be emitted.
  bool enable_select_tf_ops{false};
  // Whether weights are quantized during export.
  bool quantize_weights{false};
};
32 
// Transforms the given tf.mini model into a TF Lite flatbuffer and deposits
// the serialized result in `*output_file_contents`.
//
// `params` carries the export options (see ExportParams). Failures are
// reported through the returned status rather than by aborting.
tensorflow::Status Export(const Model& model, string* output_file_contents,
                          const ExportParams& params);
37 
// Export API with a caller-supplied TFLite operator mapping.
//
// Like the overload above, but `ops_by_type` maps each Toco OperatorType to
// the BaseOperator used for it.
// NOTE(review): presumably the overload without `ops_by_type` uses a built-in
// mapping -- confirm in the implementation.
tensorflow::Status Export(
    const Model& model, string* output_file_contents,
    const ExportParams& params,
    const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type);
43 
44 // This is for backward-compatibility.
45 // TODO(ycling): Remove the deprecated entry functions.
Export(const Model & model,bool allow_custom_ops,bool quantize_weights,string * output_file_contents)46 inline void Export(const Model& model, bool allow_custom_ops,
47                    bool quantize_weights, string* output_file_contents) {
48   ExportParams params;
49   params.allow_custom_ops = allow_custom_ops;
50   params.quantize_weights = quantize_weights;
51   auto status = Export(model, output_file_contents, params);
52   if (!status.ok()) LOG(QFATAL) << status.error_message();
53 }
54 
55 // This is for backward-compatibility.
56 // TODO(ycling): Remove the deprecated entry functions.
Export(const Model & model,bool allow_custom_ops,bool quantize_weights,string * output_file_contents,const std::map<OperatorType,std::unique_ptr<BaseOperator>> & ops_by_type)57 inline void Export(
58     const Model& model, bool allow_custom_ops, bool quantize_weights,
59     string* output_file_contents,
60     const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type) {
61   ExportParams params;
62   params.allow_custom_ops = allow_custom_ops;
63   params.quantize_weights = quantize_weights;
64   auto status = Export(model, output_file_contents, params, ops_by_type);
65   if (!status.ok()) LOG(QFATAL) << status.error_message();
66 }
67 
68 // This is for backward-compatibility.
69 // TODO(ycling): Remove the deprecated entry functions.
Export(const Model & model,string * output_file_contents)70 inline void Export(const Model& model, string* output_file_contents) {
71   ExportParams params;
72   params.allow_custom_ops = true;
73   auto status = Export(model, output_file_contents, params);
74   if (!status.ok()) LOG(QFATAL) << status.error_message();
75 }
76 
77 namespace details {
78 
// A map from tensor name to its final position in the TF Lite buffer.
// See LoadTensorsMap().
using TensorsMap = std::unordered_map<string, int>;
81 
82 // A key to identify an operator.
83 // Only when `type` is `kUnsupported`, `custom_code` is filled to
84 // identify which operation is used.
85 class OperatorKey {
86  public:
OperatorKey()87   OperatorKey() {}
88 
89   // Construct OperatorKey by Toco op.
90   OperatorKey(
91       const ::toco::OperatorSignature& op_signature,
92       const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type,
93       bool enable_select_tf_ops);
94 
95   // Construct OperatorKey by type, custom code and version.
96   // Note that this construct doesn't set the additional information including
97   // `is_custom_op`, `is_flex_op`, `is_unsupported_flex_op`.
OperatorKey(::tflite::BuiltinOperator type,const std::string & custom_code,int version)98   OperatorKey(::tflite::BuiltinOperator type, const std::string& custom_code,
99               int version)
100       : type_(type), custom_code_(custom_code), version_(version) {}
101 
102   // Only `type`, `custom_code` and `version` is used to compute hash and
103   // identity.
type()104   ::tflite::BuiltinOperator type() const { return type_; }
custom_code()105   const std::string& custom_code() const { return custom_code_; }
version()106   int version() const { return version_; }
107 
108   // The attributes below are not used to compute hash and identity.
109   //
110   // Return true if the op is a custom op. Note it will return false for Flex
111   // ops.
is_custom_op()112   bool is_custom_op() const { return is_custom_op_; }
113   // Return true if the op is a Flex op.
is_flex_op()114   bool is_flex_op() const { return is_flex_op_; }
115   // Return true if the op is a Flex op but it's knwon that the op is not
116   // supported by Flex runtime.
is_unsupported_flex_op()117   bool is_unsupported_flex_op() const { return is_unsupported_flex_op_; }
118   // Return the original TensorFlow op name for a Flex op.
flex_tensorflow_op()119   const std::string& flex_tensorflow_op() const { return flex_tensorflow_op_; }
120 
121   bool operator<(const OperatorKey& other) const {
122     if (type_ < other.type_)
123       return true;
124     else if (type_ > other.type_)
125       return false;
126     else if (custom_code_ < other.custom_code_)
127       return true;
128     else if (custom_code_ > other.custom_code_)
129       return false;
130     else
131       return version_ < other.version_;
132   }
133 
134   bool operator==(const OperatorKey& other) const {
135     return type_ == other.type_ && custom_code_ == other.custom_code_ &&
136            version_ == other.version_;
137   }
138 
139   struct Hash {
operatorHash140     size_t operator()(const OperatorKey& key) const {
141       return ::tflite::CombineHashes(
142           {std::hash<size_t>()(static_cast<size_t>(key.type())),
143            std::hash<std::string>()(key.custom_code()),
144            std::hash<int>()(key.version())});
145     }
146   };
147 
148  private:
149   ::tflite::BuiltinOperator type_ = ::tflite::BuiltinOperator_CUSTOM;
150   std::string custom_code_;
151   int version_ = 1;
152 
153   bool is_custom_op_ = false;
154   bool is_flex_op_ = false;
155   bool is_unsupported_flex_op_ = false;
156   // The original TensorFlow op name for the flex op. Filled only when
157   // `is_flex_op` is true.
158   std::string flex_tensorflow_op_;
159 };
160 
// A map from OperatorKey to its final position in the TF Lite buffer.
// Keys are hashed and compared on (type, custom_code, version) only, so keys
// differing only in their auxiliary attributes collide.
using OperatorsMap = std::unordered_map<OperatorKey, int, OperatorKey::Hash>;

// Loads the tensor-name -> position map for `model` into `*tensors_map`.
// NOTE(review): assumed to be an output-only parameter -- confirm in the
// implementation.
void LoadTensorsMap(const Model& model, TensorsMap* tensors_map);
// Loads the OperatorKey -> position map for `model` into `*operators_map`.
// `ops_by_type` and `enable_select_tf_ops` mirror the arguments of the
// OperatorKey Toco-op constructor.
// NOTE(review): presumably used to build each op's key -- confirm.
void LoadOperatorsMap(
    const Model& model, OperatorsMap* operators_map,
    const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type,
    bool enable_select_tf_ops);
169 
170 }  // namespace details
171 }  // namespace tflite
172 }  // namespace toco
173 
174 #endif  // TENSORFLOW_LITE_TOCO_TFLITE_EXPORT_H_
175