1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/toco/tflite/types.h"
16 #include "tensorflow/lite/string_util.h"
17 
18 namespace toco {
19 
20 namespace tflite {
21 
22 namespace {
23 
CopyStringToBuffer(const Array & array,flatbuffers::FlatBufferBuilder * builder)24 DataBuffer::FlatBufferOffset CopyStringToBuffer(
25     const Array& array, flatbuffers::FlatBufferBuilder* builder) {
26   const auto& src_data = array.GetBuffer<ArrayDataType::kString>().data;
27   ::tflite::DynamicBuffer dyn_buffer;
28   for (const string& str : src_data) {
29     dyn_buffer.AddString(str.c_str(), str.length());
30   }
31   char* tensor_buffer;
32   int bytes = dyn_buffer.WriteToBuffer(&tensor_buffer);
33   std::vector<uint8_t> dst_data(bytes);
34   memcpy(dst_data.data(), tensor_buffer, bytes);
35   free(tensor_buffer);
36   return builder->CreateVector(dst_data.data(), bytes);
37 }
38 
39 // vector<bool> may be implemented using a bit-set, so we can't just
40 // reinterpret_cast, accessing its data as vector<bool> and let flatbuffer
41 // CreateVector handle it.
42 // Background: https://isocpp.org/blog/2012/11/on-vectorbool
CopyBoolToBuffer(const Array & array,flatbuffers::FlatBufferBuilder * builder)43 DataBuffer::FlatBufferOffset CopyBoolToBuffer(
44     const Array& array, flatbuffers::FlatBufferBuilder* builder) {
45   const auto& src_data = array.GetBuffer<ArrayDataType::kBool>().data;
46   return builder->CreateVector(src_data);
47 }
48 
49 template <ArrayDataType T>
CopyBuffer(const Array & array,flatbuffers::FlatBufferBuilder * builder)50 DataBuffer::FlatBufferOffset CopyBuffer(
51     const Array& array, flatbuffers::FlatBufferBuilder* builder) {
52   using NativeT = ::toco::DataType<T>;
53   const auto& src_data = array.GetBuffer<T>().data;
54   const uint8_t* dst_data = reinterpret_cast<const uint8_t*>(src_data.data());
55   auto size = src_data.size() * sizeof(NativeT);
56   return builder->CreateVector(dst_data, size);
57 }
58 
CopyStringFromBuffer(const::tflite::Buffer & buffer,Array * array)59 void CopyStringFromBuffer(const ::tflite::Buffer& buffer, Array* array) {
60   auto* src_data = reinterpret_cast<const char*>(buffer.data()->data());
61   std::vector<string>* dst_data =
62       &array->GetMutableBuffer<ArrayDataType::kString>().data;
63   int32_t num_strings = ::tflite::GetStringCount(src_data);
64   for (int i = 0; i < num_strings; i++) {
65     ::tflite::StringRef str_ref = ::tflite::GetString(src_data, i);
66     string this_str(str_ref.str, str_ref.len);
67     dst_data->push_back(this_str);
68   }
69 }
70 
71 template <ArrayDataType T>
CopyBuffer(const::tflite::Buffer & buffer,Array * array)72 void CopyBuffer(const ::tflite::Buffer& buffer, Array* array) {
73   using NativeT = ::toco::DataType<T>;
74   auto* src_buffer = buffer.data();
75   const NativeT* src_data =
76       reinterpret_cast<const NativeT*>(src_buffer->data());
77   int num_items = src_buffer->size() / sizeof(NativeT);
78 
79   std::vector<NativeT>* dst_data = &array->GetMutableBuffer<T>().data;
80   for (int i = 0; i < num_items; ++i) {
81     dst_data->push_back(*src_data);
82     ++src_data;
83   }
84 }
85 }  // namespace
86 
Serialize(ArrayDataType array_data_type)87 ::tflite::TensorType DataType::Serialize(ArrayDataType array_data_type) {
88   switch (array_data_type) {
89     case ArrayDataType::kFloat:
90       return ::tflite::TensorType_FLOAT32;
91     case ArrayDataType::kInt16:
92       return ::tflite::TensorType_INT16;
93     case ArrayDataType::kInt32:
94       return ::tflite::TensorType_INT32;
95     case ArrayDataType::kInt64:
96       return ::tflite::TensorType_INT64;
97     case ArrayDataType::kUint8:
98       return ::tflite::TensorType_UINT8;
99     case ArrayDataType::kString:
100       return ::tflite::TensorType_STRING;
101     case ArrayDataType::kBool:
102       return ::tflite::TensorType_BOOL;
103     case ArrayDataType::kComplex64:
104       return ::tflite::TensorType_COMPLEX64;
105     default:
106       // FLOAT32 is filled for unknown data types.
107       // TODO(ycling): Implement type inference in TF Lite interpreter.
108       return ::tflite::TensorType_FLOAT32;
109   }
110 }
111 
Deserialize(int tensor_type)112 ArrayDataType DataType::Deserialize(int tensor_type) {
113   switch (::tflite::TensorType(tensor_type)) {
114     case ::tflite::TensorType_FLOAT32:
115       return ArrayDataType::kFloat;
116     case ::tflite::TensorType_INT16:
117       return ArrayDataType::kInt16;
118     case ::tflite::TensorType_INT32:
119       return ArrayDataType::kInt32;
120     case ::tflite::TensorType_INT64:
121       return ArrayDataType::kInt64;
122     case ::tflite::TensorType_STRING:
123       return ArrayDataType::kString;
124     case ::tflite::TensorType_UINT8:
125       return ArrayDataType::kUint8;
126     case ::tflite::TensorType_BOOL:
127       return ArrayDataType::kBool;
128     case ::tflite::TensorType_COMPLEX64:
129       return ArrayDataType::kComplex64;
130     default:
131       LOG(FATAL) << "Unhandled tensor type '" << tensor_type << "'.";
132   }
133 }
134 
Serialize(const Array & array,flatbuffers::FlatBufferBuilder * builder)135 flatbuffers::Offset<flatbuffers::Vector<uint8_t>> DataBuffer::Serialize(
136     const Array& array, flatbuffers::FlatBufferBuilder* builder) {
137   if (!array.buffer) return 0;  // an empty buffer, usually an output.
138 
139   switch (array.data_type) {
140     case ArrayDataType::kFloat:
141       return CopyBuffer<ArrayDataType::kFloat>(array, builder);
142     case ArrayDataType::kInt16:
143       return CopyBuffer<ArrayDataType::kInt16>(array, builder);
144     case ArrayDataType::kInt32:
145       return CopyBuffer<ArrayDataType::kInt32>(array, builder);
146     case ArrayDataType::kInt64:
147       return CopyBuffer<ArrayDataType::kInt64>(array, builder);
148     case ArrayDataType::kString:
149       return CopyStringToBuffer(array, builder);
150     case ArrayDataType::kUint8:
151       return CopyBuffer<ArrayDataType::kUint8>(array, builder);
152     case ArrayDataType::kBool:
153       return CopyBoolToBuffer(array, builder);
154     case ArrayDataType::kComplex64:
155       return CopyBuffer<ArrayDataType::kComplex64>(array, builder);
156     default:
157       LOG(FATAL) << "Unhandled array data type.";
158   }
159 }
160 
Deserialize(const::tflite::Tensor & tensor,const::tflite::Buffer & buffer,Array * array)161 void DataBuffer::Deserialize(const ::tflite::Tensor& tensor,
162                              const ::tflite::Buffer& buffer, Array* array) {
163   if (tensor.buffer() == 0) return;      // an empty buffer, usually an output.
164   if (buffer.data() == nullptr) return;  // a non-defined buffer.
165 
166   switch (tensor.type()) {
167     case ::tflite::TensorType_FLOAT32:
168       return CopyBuffer<ArrayDataType::kFloat>(buffer, array);
169     case ::tflite::TensorType_INT16:
170       return CopyBuffer<ArrayDataType::kInt16>(buffer, array);
171     case ::tflite::TensorType_INT32:
172       return CopyBuffer<ArrayDataType::kInt32>(buffer, array);
173     case ::tflite::TensorType_INT64:
174       return CopyBuffer<ArrayDataType::kInt64>(buffer, array);
175     case ::tflite::TensorType_STRING:
176       return CopyStringFromBuffer(buffer, array);
177     case ::tflite::TensorType_UINT8:
178       return CopyBuffer<ArrayDataType::kUint8>(buffer, array);
179     case ::tflite::TensorType_BOOL:
180       return CopyBuffer<ArrayDataType::kBool>(buffer, array);
181     case ::tflite::TensorType_COMPLEX64:
182       return CopyBuffer<ArrayDataType::kComplex64>(buffer, array);
183     default:
184       LOG(FATAL) << "Unhandled tensor type.";
185   }
186 }
187 
Serialize(PaddingType padding_type)188 ::tflite::Padding Padding::Serialize(PaddingType padding_type) {
189   switch (padding_type) {
190     case PaddingType::kSame:
191       return ::tflite::Padding_SAME;
192     case PaddingType::kValid:
193       return ::tflite::Padding_VALID;
194     default:
195       LOG(FATAL) << "Unhandled padding type.";
196   }
197 }
198 
Deserialize(int padding)199 PaddingType Padding::Deserialize(int padding) {
200   switch (::tflite::Padding(padding)) {
201     case ::tflite::Padding_SAME:
202       return PaddingType::kSame;
203     case ::tflite::Padding_VALID:
204       return PaddingType::kValid;
205     default:
206       LOG(FATAL) << "Unhandled padding.";
207   }
208 }
209 
Serialize(FusedActivationFunctionType faf_type)210 ::tflite::ActivationFunctionType ActivationFunction::Serialize(
211     FusedActivationFunctionType faf_type) {
212   switch (faf_type) {
213     case FusedActivationFunctionType::kNone:
214       return ::tflite::ActivationFunctionType_NONE;
215     case FusedActivationFunctionType::kRelu:
216       return ::tflite::ActivationFunctionType_RELU;
217     case FusedActivationFunctionType::kRelu6:
218       return ::tflite::ActivationFunctionType_RELU6;
219     case FusedActivationFunctionType::kRelu1:
220       return ::tflite::ActivationFunctionType_RELU_N1_TO_1;
221     default:
222       LOG(FATAL) << "Unhandled fused activation function type.";
223   }
224 }
225 
Deserialize(int activation_function)226 FusedActivationFunctionType ActivationFunction::Deserialize(
227     int activation_function) {
228   switch (::tflite::ActivationFunctionType(activation_function)) {
229     case ::tflite::ActivationFunctionType_NONE:
230       return FusedActivationFunctionType::kNone;
231     case ::tflite::ActivationFunctionType_RELU:
232       return FusedActivationFunctionType::kRelu;
233     case ::tflite::ActivationFunctionType_RELU6:
234       return FusedActivationFunctionType::kRelu6;
235     case ::tflite::ActivationFunctionType_RELU_N1_TO_1:
236       return FusedActivationFunctionType::kRelu1;
237     default:
238       LOG(FATAL) << "Unhandled fused activation function type.";
239   }
240 }
241 
242 }  // namespace tflite
243 
244 }  // namespace toco
245