1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include "tensorflow/lite/toco/tflite/types.h"

#include <cstring>

#include "tensorflow/lite/string_util.h"
17
18 namespace toco {
19
20 namespace tflite {
21
22 namespace {
23
CopyStringToBuffer(const Array & array,flatbuffers::FlatBufferBuilder * builder)24 DataBuffer::FlatBufferOffset CopyStringToBuffer(
25 const Array& array, flatbuffers::FlatBufferBuilder* builder) {
26 const auto& src_data = array.GetBuffer<ArrayDataType::kString>().data;
27 ::tflite::DynamicBuffer dyn_buffer;
28 for (const string& str : src_data) {
29 dyn_buffer.AddString(str.c_str(), str.length());
30 }
31 char* tensor_buffer;
32 int bytes = dyn_buffer.WriteToBuffer(&tensor_buffer);
33 std::vector<uint8_t> dst_data(bytes);
34 memcpy(dst_data.data(), tensor_buffer, bytes);
35 free(tensor_buffer);
36 return builder->CreateVector(dst_data.data(), bytes);
37 }
38
39 // vector<bool> may be implemented using a bit-set, so we can't just
40 // reinterpret_cast, accessing its data as vector<bool> and let flatbuffer
41 // CreateVector handle it.
42 // Background: https://isocpp.org/blog/2012/11/on-vectorbool
CopyBoolToBuffer(const Array & array,flatbuffers::FlatBufferBuilder * builder)43 DataBuffer::FlatBufferOffset CopyBoolToBuffer(
44 const Array& array, flatbuffers::FlatBufferBuilder* builder) {
45 const auto& src_data = array.GetBuffer<ArrayDataType::kBool>().data;
46 return builder->CreateVector(src_data);
47 }
48
49 template <ArrayDataType T>
CopyBuffer(const Array & array,flatbuffers::FlatBufferBuilder * builder)50 DataBuffer::FlatBufferOffset CopyBuffer(
51 const Array& array, flatbuffers::FlatBufferBuilder* builder) {
52 using NativeT = ::toco::DataType<T>;
53 const auto& src_data = array.GetBuffer<T>().data;
54 const uint8_t* dst_data = reinterpret_cast<const uint8_t*>(src_data.data());
55 auto size = src_data.size() * sizeof(NativeT);
56 return builder->CreateVector(dst_data, size);
57 }
58
CopyStringFromBuffer(const::tflite::Buffer & buffer,Array * array)59 void CopyStringFromBuffer(const ::tflite::Buffer& buffer, Array* array) {
60 auto* src_data = reinterpret_cast<const char*>(buffer.data()->data());
61 std::vector<string>* dst_data =
62 &array->GetMutableBuffer<ArrayDataType::kString>().data;
63 int32_t num_strings = ::tflite::GetStringCount(src_data);
64 for (int i = 0; i < num_strings; i++) {
65 ::tflite::StringRef str_ref = ::tflite::GetString(src_data, i);
66 string this_str(str_ref.str, str_ref.len);
67 dst_data->push_back(this_str);
68 }
69 }
70
71 template <ArrayDataType T>
CopyBuffer(const::tflite::Buffer & buffer,Array * array)72 void CopyBuffer(const ::tflite::Buffer& buffer, Array* array) {
73 using NativeT = ::toco::DataType<T>;
74 auto* src_buffer = buffer.data();
75 const NativeT* src_data =
76 reinterpret_cast<const NativeT*>(src_buffer->data());
77 int num_items = src_buffer->size() / sizeof(NativeT);
78
79 std::vector<NativeT>* dst_data = &array->GetMutableBuffer<T>().data;
80 for (int i = 0; i < num_items; ++i) {
81 dst_data->push_back(*src_data);
82 ++src_data;
83 }
84 }
85 } // namespace
86
Serialize(ArrayDataType array_data_type)87 ::tflite::TensorType DataType::Serialize(ArrayDataType array_data_type) {
88 switch (array_data_type) {
89 case ArrayDataType::kFloat:
90 return ::tflite::TensorType_FLOAT32;
91 case ArrayDataType::kInt16:
92 return ::tflite::TensorType_INT16;
93 case ArrayDataType::kInt32:
94 return ::tflite::TensorType_INT32;
95 case ArrayDataType::kInt64:
96 return ::tflite::TensorType_INT64;
97 case ArrayDataType::kUint8:
98 return ::tflite::TensorType_UINT8;
99 case ArrayDataType::kString:
100 return ::tflite::TensorType_STRING;
101 case ArrayDataType::kBool:
102 return ::tflite::TensorType_BOOL;
103 case ArrayDataType::kComplex64:
104 return ::tflite::TensorType_COMPLEX64;
105 default:
106 // FLOAT32 is filled for unknown data types.
107 // TODO(ycling): Implement type inference in TF Lite interpreter.
108 return ::tflite::TensorType_FLOAT32;
109 }
110 }
111
Deserialize(int tensor_type)112 ArrayDataType DataType::Deserialize(int tensor_type) {
113 switch (::tflite::TensorType(tensor_type)) {
114 case ::tflite::TensorType_FLOAT32:
115 return ArrayDataType::kFloat;
116 case ::tflite::TensorType_INT16:
117 return ArrayDataType::kInt16;
118 case ::tflite::TensorType_INT32:
119 return ArrayDataType::kInt32;
120 case ::tflite::TensorType_INT64:
121 return ArrayDataType::kInt64;
122 case ::tflite::TensorType_STRING:
123 return ArrayDataType::kString;
124 case ::tflite::TensorType_UINT8:
125 return ArrayDataType::kUint8;
126 case ::tflite::TensorType_BOOL:
127 return ArrayDataType::kBool;
128 case ::tflite::TensorType_COMPLEX64:
129 return ArrayDataType::kComplex64;
130 default:
131 LOG(FATAL) << "Unhandled tensor type '" << tensor_type << "'.";
132 }
133 }
134
Serialize(const Array & array,flatbuffers::FlatBufferBuilder * builder)135 flatbuffers::Offset<flatbuffers::Vector<uint8_t>> DataBuffer::Serialize(
136 const Array& array, flatbuffers::FlatBufferBuilder* builder) {
137 if (!array.buffer) return 0; // an empty buffer, usually an output.
138
139 switch (array.data_type) {
140 case ArrayDataType::kFloat:
141 return CopyBuffer<ArrayDataType::kFloat>(array, builder);
142 case ArrayDataType::kInt16:
143 return CopyBuffer<ArrayDataType::kInt16>(array, builder);
144 case ArrayDataType::kInt32:
145 return CopyBuffer<ArrayDataType::kInt32>(array, builder);
146 case ArrayDataType::kInt64:
147 return CopyBuffer<ArrayDataType::kInt64>(array, builder);
148 case ArrayDataType::kString:
149 return CopyStringToBuffer(array, builder);
150 case ArrayDataType::kUint8:
151 return CopyBuffer<ArrayDataType::kUint8>(array, builder);
152 case ArrayDataType::kBool:
153 return CopyBoolToBuffer(array, builder);
154 case ArrayDataType::kComplex64:
155 return CopyBuffer<ArrayDataType::kComplex64>(array, builder);
156 default:
157 LOG(FATAL) << "Unhandled array data type.";
158 }
159 }
160
Deserialize(const::tflite::Tensor & tensor,const::tflite::Buffer & buffer,Array * array)161 void DataBuffer::Deserialize(const ::tflite::Tensor& tensor,
162 const ::tflite::Buffer& buffer, Array* array) {
163 if (tensor.buffer() == 0) return; // an empty buffer, usually an output.
164 if (buffer.data() == nullptr) return; // a non-defined buffer.
165
166 switch (tensor.type()) {
167 case ::tflite::TensorType_FLOAT32:
168 return CopyBuffer<ArrayDataType::kFloat>(buffer, array);
169 case ::tflite::TensorType_INT16:
170 return CopyBuffer<ArrayDataType::kInt16>(buffer, array);
171 case ::tflite::TensorType_INT32:
172 return CopyBuffer<ArrayDataType::kInt32>(buffer, array);
173 case ::tflite::TensorType_INT64:
174 return CopyBuffer<ArrayDataType::kInt64>(buffer, array);
175 case ::tflite::TensorType_STRING:
176 return CopyStringFromBuffer(buffer, array);
177 case ::tflite::TensorType_UINT8:
178 return CopyBuffer<ArrayDataType::kUint8>(buffer, array);
179 case ::tflite::TensorType_BOOL:
180 return CopyBuffer<ArrayDataType::kBool>(buffer, array);
181 case ::tflite::TensorType_COMPLEX64:
182 return CopyBuffer<ArrayDataType::kComplex64>(buffer, array);
183 default:
184 LOG(FATAL) << "Unhandled tensor type.";
185 }
186 }
187
Serialize(PaddingType padding_type)188 ::tflite::Padding Padding::Serialize(PaddingType padding_type) {
189 switch (padding_type) {
190 case PaddingType::kSame:
191 return ::tflite::Padding_SAME;
192 case PaddingType::kValid:
193 return ::tflite::Padding_VALID;
194 default:
195 LOG(FATAL) << "Unhandled padding type.";
196 }
197 }
198
Deserialize(int padding)199 PaddingType Padding::Deserialize(int padding) {
200 switch (::tflite::Padding(padding)) {
201 case ::tflite::Padding_SAME:
202 return PaddingType::kSame;
203 case ::tflite::Padding_VALID:
204 return PaddingType::kValid;
205 default:
206 LOG(FATAL) << "Unhandled padding.";
207 }
208 }
209
Serialize(FusedActivationFunctionType faf_type)210 ::tflite::ActivationFunctionType ActivationFunction::Serialize(
211 FusedActivationFunctionType faf_type) {
212 switch (faf_type) {
213 case FusedActivationFunctionType::kNone:
214 return ::tflite::ActivationFunctionType_NONE;
215 case FusedActivationFunctionType::kRelu:
216 return ::tflite::ActivationFunctionType_RELU;
217 case FusedActivationFunctionType::kRelu6:
218 return ::tflite::ActivationFunctionType_RELU6;
219 case FusedActivationFunctionType::kRelu1:
220 return ::tflite::ActivationFunctionType_RELU_N1_TO_1;
221 default:
222 LOG(FATAL) << "Unhandled fused activation function type.";
223 }
224 }
225
Deserialize(int activation_function)226 FusedActivationFunctionType ActivationFunction::Deserialize(
227 int activation_function) {
228 switch (::tflite::ActivationFunctionType(activation_function)) {
229 case ::tflite::ActivationFunctionType_NONE:
230 return FusedActivationFunctionType::kNone;
231 case ::tflite::ActivationFunctionType_RELU:
232 return FusedActivationFunctionType::kRelu;
233 case ::tflite::ActivationFunctionType_RELU6:
234 return FusedActivationFunctionType::kRelu6;
235 case ::tflite::ActivationFunctionType_RELU_N1_TO_1:
236 return FusedActivationFunctionType::kRelu1;
237 default:
238 LOG(FATAL) << "Unhandled fused activation function type.";
239 }
240 }
241
242 } // namespace tflite
243
244 } // namespace toco
245