1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_TOCO_TFLITE_EXPORT_H_
16 #define TENSORFLOW_LITE_TOCO_TFLITE_EXPORT_H_
17
#include <tuple>
#include <utility>

#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/tflite/operator.h"
#include "tensorflow/lite/util.h"
21
22 namespace toco {
23
24 namespace tflite {
25
// Options controlling how a toco Model is exported to a TF Lite flatbuffer.
// Every option defaults to false.
struct ExportParams {
  // NOTE(review): presumably permits ops without a builtin TF Lite mapping to
  // be emitted as custom ops — confirm in export.cc.
  bool allow_custom_ops{false};
  // NOTE(review): presumably enables emitting TensorFlow ("Select TF" / Flex)
  // ops — confirm in export.cc.
  bool enable_select_tf_ops{false};
  // NOTE(review): presumably requests quantization of weight buffers during
  // export — confirm in export.cc.
  bool quantize_weights{false};
};
32
33 // Transform the given tf.mini model into a TF Lite flatbuffer and deposit the
34 // result in the given string.
35 tensorflow::Status Export(const Model& model, string* output_file_contents,
36 const ExportParams& params);
37
38 // Export API with custom TFLite operator mapping.
39 tensorflow::Status Export(
40 const Model& model, string* output_file_contents,
41 const ExportParams& params,
42 const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type);
43
44 // This is for backward-compatibility.
45 // TODO(ycling): Remove the deprecated entry functions.
Export(const Model & model,bool allow_custom_ops,bool quantize_weights,string * output_file_contents)46 inline void Export(const Model& model, bool allow_custom_ops,
47 bool quantize_weights, string* output_file_contents) {
48 ExportParams params;
49 params.allow_custom_ops = allow_custom_ops;
50 params.quantize_weights = quantize_weights;
51 auto status = Export(model, output_file_contents, params);
52 if (!status.ok()) LOG(QFATAL) << status.error_message();
53 }
54
55 // This is for backward-compatibility.
56 // TODO(ycling): Remove the deprecated entry functions.
Export(const Model & model,bool allow_custom_ops,bool quantize_weights,string * output_file_contents,const std::map<OperatorType,std::unique_ptr<BaseOperator>> & ops_by_type)57 inline void Export(
58 const Model& model, bool allow_custom_ops, bool quantize_weights,
59 string* output_file_contents,
60 const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type) {
61 ExportParams params;
62 params.allow_custom_ops = allow_custom_ops;
63 params.quantize_weights = quantize_weights;
64 auto status = Export(model, output_file_contents, params, ops_by_type);
65 if (!status.ok()) LOG(QFATAL) << status.error_message();
66 }
67
68 // This is for backward-compatibility.
69 // TODO(ycling): Remove the deprecated entry functions.
Export(const Model & model,string * output_file_contents)70 inline void Export(const Model& model, string* output_file_contents) {
71 ExportParams params;
72 params.allow_custom_ops = true;
73 auto status = Export(model, output_file_contents, params);
74 if (!status.ok()) LOG(QFATAL) << status.error_message();
75 }
76
77 namespace details {
78
79 // A map from tensor name to its final position in the TF Lite buffer.
80 using TensorsMap = std::unordered_map<string, int>;
81
82 // A key to identify an operator.
83 // Only when `type` is `kUnsupported`, `custom_code` is filled to
84 // identify which operation is used.
85 class OperatorKey {
86 public:
OperatorKey()87 OperatorKey() {}
88
89 // Construct OperatorKey by Toco op.
90 OperatorKey(
91 const ::toco::OperatorSignature& op_signature,
92 const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type,
93 bool enable_select_tf_ops);
94
95 // Construct OperatorKey by type, custom code and version.
96 // Note that this construct doesn't set the additional information including
97 // `is_custom_op`, `is_flex_op`, `is_unsupported_flex_op`.
OperatorKey(::tflite::BuiltinOperator type,const std::string & custom_code,int version)98 OperatorKey(::tflite::BuiltinOperator type, const std::string& custom_code,
99 int version)
100 : type_(type), custom_code_(custom_code), version_(version) {}
101
102 // Only `type`, `custom_code` and `version` is used to compute hash and
103 // identity.
type()104 ::tflite::BuiltinOperator type() const { return type_; }
custom_code()105 const std::string& custom_code() const { return custom_code_; }
version()106 int version() const { return version_; }
107
108 // The attributes below are not used to compute hash and identity.
109 //
110 // Return true if the op is a custom op. Note it will return false for Flex
111 // ops.
is_custom_op()112 bool is_custom_op() const { return is_custom_op_; }
113 // Return true if the op is a Flex op.
is_flex_op()114 bool is_flex_op() const { return is_flex_op_; }
115 // Return true if the op is a Flex op but it's knwon that the op is not
116 // supported by Flex runtime.
is_unsupported_flex_op()117 bool is_unsupported_flex_op() const { return is_unsupported_flex_op_; }
118 // Return the original TensorFlow op name for a Flex op.
flex_tensorflow_op()119 const std::string& flex_tensorflow_op() const { return flex_tensorflow_op_; }
120
121 bool operator<(const OperatorKey& other) const {
122 if (type_ < other.type_)
123 return true;
124 else if (type_ > other.type_)
125 return false;
126 else if (custom_code_ < other.custom_code_)
127 return true;
128 else if (custom_code_ > other.custom_code_)
129 return false;
130 else
131 return version_ < other.version_;
132 }
133
134 bool operator==(const OperatorKey& other) const {
135 return type_ == other.type_ && custom_code_ == other.custom_code_ &&
136 version_ == other.version_;
137 }
138
139 struct Hash {
operatorHash140 size_t operator()(const OperatorKey& key) const {
141 return ::tflite::CombineHashes(
142 {std::hash<size_t>()(static_cast<size_t>(key.type())),
143 std::hash<std::string>()(key.custom_code()),
144 std::hash<int>()(key.version())});
145 }
146 };
147
148 private:
149 ::tflite::BuiltinOperator type_ = ::tflite::BuiltinOperator_CUSTOM;
150 std::string custom_code_;
151 int version_ = 1;
152
153 bool is_custom_op_ = false;
154 bool is_flex_op_ = false;
155 bool is_unsupported_flex_op_ = false;
156 // The original TensorFlow op name for the flex op. Filled only when
157 // `is_flex_op` is true.
158 std::string flex_tensorflow_op_;
159 };
160
161 // A map from OperatorKey to its final position in the TF Lite buffer.
162 using OperatorsMap = std::unordered_map<OperatorKey, int, OperatorKey::Hash>;
163
164 void LoadTensorsMap(const Model& model, TensorsMap* tensors_map);
165 void LoadOperatorsMap(
166 const Model& model, OperatorsMap* operators_map,
167 const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type,
168 bool enable_select_tf_ops);
169
170 } // namespace details
171 } // namespace tflite
172 } // namespace toco
173
174 #endif // TENSORFLOW_LITE_TOCO_TFLITE_EXPORT_H_
175