1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "utils/flatbuffers.h"
18 
19 #include <vector>
20 #include "utils/strings/numbers.h"
21 #include "utils/variant.h"
22 
23 namespace libtextclassifier3 {
24 namespace {
CreateRepeatedField(const reflection::Schema * schema,const reflection::Type * type,std::unique_ptr<ReflectiveFlatbuffer::RepeatedField> * repeated_field)25 bool CreateRepeatedField(
26     const reflection::Schema* schema, const reflection::Type* type,
27     std::unique_ptr<ReflectiveFlatbuffer::RepeatedField>* repeated_field) {
28   switch (type->element()) {
29     case reflection::Bool:
30       repeated_field->reset(new ReflectiveFlatbuffer::TypedRepeatedField<bool>);
31       return true;
32     case reflection::Int:
33       repeated_field->reset(new ReflectiveFlatbuffer::TypedRepeatedField<int>);
34       return true;
35     case reflection::Long:
36       repeated_field->reset(
37           new ReflectiveFlatbuffer::TypedRepeatedField<int64>);
38       return true;
39     case reflection::Float:
40       repeated_field->reset(
41           new ReflectiveFlatbuffer::TypedRepeatedField<float>);
42       return true;
43     case reflection::Double:
44       repeated_field->reset(
45           new ReflectiveFlatbuffer::TypedRepeatedField<double>);
46       return true;
47     case reflection::String:
48       repeated_field->reset(
49           new ReflectiveFlatbuffer::TypedRepeatedField<std::string>);
50       return true;
51     case reflection::Obj:
52       repeated_field->reset(
53           new ReflectiveFlatbuffer::TypedRepeatedField<ReflectiveFlatbuffer>(
54               schema, type));
55       return true;
56     default:
57       TC3_LOG(ERROR) << "Unsupported type: " << type->element();
58       return false;
59   }
60 }
61 }  // namespace
62 
63 template <>
FlatbufferFileIdentifier()64 const char* FlatbufferFileIdentifier<Model>() {
65   return ModelIdentifier();
66 }
67 
NewRoot() const68 std::unique_ptr<ReflectiveFlatbuffer> ReflectiveFlatbufferBuilder::NewRoot()
69     const {
70   if (!schema_->root_table()) {
71     TC3_LOG(ERROR) << "No root table specified.";
72     return nullptr;
73   }
74   return std::unique_ptr<ReflectiveFlatbuffer>(
75       new ReflectiveFlatbuffer(schema_, schema_->root_table()));
76 }
77 
NewTable(StringPiece table_name) const78 std::unique_ptr<ReflectiveFlatbuffer> ReflectiveFlatbufferBuilder::NewTable(
79     StringPiece table_name) const {
80   for (const reflection::Object* object : *schema_->objects()) {
81     if (table_name.Equals(object->name()->str())) {
82       return std::unique_ptr<ReflectiveFlatbuffer>(
83           new ReflectiveFlatbuffer(schema_, object));
84     }
85   }
86   return nullptr;
87 }
88 
GetFieldOrNull(const StringPiece field_name) const89 const reflection::Field* ReflectiveFlatbuffer::GetFieldOrNull(
90     const StringPiece field_name) const {
91   return type_->fields()->LookupByKey(field_name.data());
92 }
93 
GetFieldOrNull(const FlatbufferField * field) const94 const reflection::Field* ReflectiveFlatbuffer::GetFieldOrNull(
95     const FlatbufferField* field) const {
96   // Lookup by name might be faster as the fields are sorted by name in the
97   // schema data, so try that first.
98   if (field->field_name() != nullptr) {
99     return GetFieldOrNull(field->field_name()->str());
100   }
101   return GetFieldByOffsetOrNull(field->field_offset());
102 }
103 
GetFieldWithParent(const FlatbufferFieldPath * field_path,ReflectiveFlatbuffer ** parent,reflection::Field const ** field)104 bool ReflectiveFlatbuffer::GetFieldWithParent(
105     const FlatbufferFieldPath* field_path, ReflectiveFlatbuffer** parent,
106     reflection::Field const** field) {
107   const auto* path = field_path->field();
108   if (path == nullptr || path->size() == 0) {
109     return false;
110   }
111 
112   for (int i = 0; i < path->size(); i++) {
113     *parent = (i == 0 ? this : (*parent)->Mutable(*field));
114     if (*parent == nullptr) {
115       return false;
116     }
117     *field = (*parent)->GetFieldOrNull(path->Get(i));
118     if (*field == nullptr) {
119       return false;
120     }
121   }
122 
123   return true;
124 }
125 
GetFieldByOffsetOrNull(const int field_offset) const126 const reflection::Field* ReflectiveFlatbuffer::GetFieldByOffsetOrNull(
127     const int field_offset) const {
128   if (type_->fields() == nullptr) {
129     return nullptr;
130   }
131   for (const reflection::Field* field : *type_->fields()) {
132     if (field->offset() == field_offset) {
133       return field;
134     }
135   }
136   return nullptr;
137 }
138 
IsMatchingType(const reflection::Field * field,const Variant & value) const139 bool ReflectiveFlatbuffer::IsMatchingType(const reflection::Field* field,
140                                           const Variant& value) const {
141   switch (field->type()->base_type()) {
142     case reflection::Bool:
143       return value.HasBool();
144     case reflection::Int:
145       return value.HasInt();
146     case reflection::Long:
147       return value.HasInt64();
148     case reflection::Float:
149       return value.HasFloat();
150     case reflection::Double:
151       return value.HasDouble();
152     case reflection::String:
153       return value.HasString();
154     default:
155       return false;
156   }
157 }
158 
ParseAndSet(const reflection::Field * field,const std::string & value)159 bool ReflectiveFlatbuffer::ParseAndSet(const reflection::Field* field,
160                                        const std::string& value) {
161   switch (field->type()->base_type()) {
162     case reflection::String:
163       return Set(field, value);
164     case reflection::Int: {
165       int32 int_value;
166       if (!ParseInt32(value.data(), &int_value)) {
167         TC3_LOG(ERROR) << "Could not parse '" << value << "' as int32.";
168         return false;
169       }
170       return Set(field, int_value);
171     }
172     case reflection::Long: {
173       int64 int_value;
174       if (!ParseInt64(value.data(), &int_value)) {
175         TC3_LOG(ERROR) << "Could not parse '" << value << "' as int64.";
176         return false;
177       }
178       return Set(field, int_value);
179     }
180     case reflection::Float: {
181       double double_value;
182       if (!ParseDouble(value.data(), &double_value)) {
183         TC3_LOG(ERROR) << "Could not parse '" << value << "' as float.";
184         return false;
185       }
186       return Set(field, static_cast<float>(double_value));
187     }
188     case reflection::Double: {
189       double double_value;
190       if (!ParseDouble(value.data(), &double_value)) {
191         TC3_LOG(ERROR) << "Could not parse '" << value << "' as double.";
192         return false;
193       }
194       return Set(field, double_value);
195     }
196     default:
197       TC3_LOG(ERROR) << "Unhandled field type: " << field->type()->base_type();
198       return false;
199   }
200 }
201 
ParseAndSet(const FlatbufferFieldPath * path,const std::string & value)202 bool ReflectiveFlatbuffer::ParseAndSet(const FlatbufferFieldPath* path,
203                                        const std::string& value) {
204   ReflectiveFlatbuffer* parent;
205   const reflection::Field* field;
206   if (!GetFieldWithParent(path, &parent, &field)) {
207     return false;
208   }
209   return parent->ParseAndSet(field, value);
210 }
211 
Mutable(const StringPiece field_name)212 ReflectiveFlatbuffer* ReflectiveFlatbuffer::Mutable(
213     const StringPiece field_name) {
214   if (const reflection::Field* field = GetFieldOrNull(field_name)) {
215     return Mutable(field);
216   }
217   TC3_LOG(ERROR) << "Unknown field: " << field_name.ToString();
218   return nullptr;
219 }
220 
Mutable(const reflection::Field * field)221 ReflectiveFlatbuffer* ReflectiveFlatbuffer::Mutable(
222     const reflection::Field* field) {
223   if (field->type()->base_type() != reflection::Obj) {
224     TC3_LOG(ERROR) << "Field is not of type Object.";
225     return nullptr;
226   }
227   const auto entry = children_.find(field);
228   if (entry != children_.end()) {
229     return entry->second.get();
230   }
231   const auto it = children_.insert(
232       /*hint=*/entry,
233       std::make_pair(
234           field,
235           std::unique_ptr<ReflectiveFlatbuffer>(new ReflectiveFlatbuffer(
236               schema_, schema_->objects()->Get(field->type()->index())))));
237   return it->second.get();
238 }
239 
Repeated(StringPiece field_name)240 ReflectiveFlatbuffer::RepeatedField* ReflectiveFlatbuffer::Repeated(
241     StringPiece field_name) {
242   if (const reflection::Field* field = GetFieldOrNull(field_name)) {
243     return Repeated(field);
244   }
245   TC3_LOG(ERROR) << "Unknown field: " << field_name.ToString();
246   return nullptr;
247 }
248 
Repeated(const reflection::Field * field)249 ReflectiveFlatbuffer::RepeatedField* ReflectiveFlatbuffer::Repeated(
250     const reflection::Field* field) {
251   if (field->type()->base_type() != reflection::Vector) {
252     TC3_LOG(ERROR) << "Field is not of type Vector.";
253     return nullptr;
254   }
255 
256   // If the repeated field was already set, return its instance.
257   const auto entry = repeated_fields_.find(field);
258   if (entry != repeated_fields_.end()) {
259     return entry->second.get();
260   }
261 
262   // Otherwise, create a new instance and store it.
263   std::unique_ptr<RepeatedField> repeated_field;
264   if (!CreateRepeatedField(schema_, field->type(), &repeated_field)) {
265     TC3_LOG(ERROR) << "Could not create repeated field.";
266     return nullptr;
267   }
268   const auto it = repeated_fields_.insert(
269       /*hint=*/entry, std::make_pair(field, std::move(repeated_field)));
270   return it->second.get();
271 }
272 
Serialize(flatbuffers::FlatBufferBuilder * builder) const273 flatbuffers::uoffset_t ReflectiveFlatbuffer::Serialize(
274     flatbuffers::FlatBufferBuilder* builder) const {
275   // Build all children before we can start with this table.
276   std::vector<
277       std::pair</* field vtable offset */ int,
278                 /* field data offset in buffer */ flatbuffers::uoffset_t>>
279       offsets;
280   offsets.reserve(children_.size() + repeated_fields_.size());
281   for (const auto& it : children_) {
282     offsets.push_back({it.first->offset(), it.second->Serialize(builder)});
283   }
284 
285   // Create strings.
286   for (const auto& it : fields_) {
287     if (it.second.HasString()) {
288       offsets.push_back({it.first->offset(),
289                          builder->CreateString(it.second.StringValue()).o});
290     }
291   }
292 
293   // Build the repeated fields.
294   for (const auto& it : repeated_fields_) {
295     offsets.push_back({it.first->offset(), it.second->Serialize(builder)});
296   }
297 
298   // Build the table now.
299   const flatbuffers::uoffset_t table_start = builder->StartTable();
300 
301   // Add scalar fields.
302   for (const auto& it : fields_) {
303     switch (it.second.GetType()) {
304       case Variant::TYPE_BOOL_VALUE:
305         builder->AddElement<uint8_t>(
306             it.first->offset(), static_cast<uint8_t>(it.second.BoolValue()),
307             static_cast<uint8_t>(it.first->default_integer()));
308         continue;
309       case Variant::TYPE_INT_VALUE:
310         builder->AddElement<int32>(
311             it.first->offset(), it.second.IntValue(),
312             static_cast<int32>(it.first->default_integer()));
313         continue;
314       case Variant::TYPE_INT64_VALUE:
315         builder->AddElement<int64>(it.first->offset(), it.second.Int64Value(),
316                                    it.first->default_integer());
317         continue;
318       case Variant::TYPE_FLOAT_VALUE:
319         builder->AddElement<float>(
320             it.first->offset(), it.second.FloatValue(),
321             static_cast<float>(it.first->default_real()));
322         continue;
323       case Variant::TYPE_DOUBLE_VALUE:
324         builder->AddElement<double>(it.first->offset(), it.second.DoubleValue(),
325                                     it.first->default_real());
326         continue;
327       default:
328         continue;
329     }
330   }
331 
332   // Add strings, subtables and repeated fields.
333   for (const auto& it : offsets) {
334     builder->AddOffset(it.first, flatbuffers::Offset<void>(it.second));
335   }
336 
337   return builder->EndTable(table_start);
338 }
339 
Serialize() const340 std::string ReflectiveFlatbuffer::Serialize() const {
341   flatbuffers::FlatBufferBuilder builder;
342   builder.Finish(flatbuffers::Offset<void>(Serialize(&builder)));
343   return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
344                      builder.GetSize());
345 }
346 
MergeFrom(const flatbuffers::Table * from)347 bool ReflectiveFlatbuffer::MergeFrom(const flatbuffers::Table* from) {
348   // No fields to set.
349   if (type_->fields() == nullptr) {
350     return true;
351   }
352 
353   for (const reflection::Field* field : *type_->fields()) {
354     // Skip fields that are not explicitly set.
355     if (!from->CheckField(field->offset())) {
356       continue;
357     }
358     const reflection::BaseType type = field->type()->base_type();
359     switch (type) {
360       case reflection::Bool:
361         Set<bool>(field, from->GetField<uint8_t>(field->offset(),
362                                                  field->default_integer()));
363         break;
364       case reflection::Int:
365         Set<int32>(field, from->GetField<int32>(field->offset(),
366                                                 field->default_integer()));
367         break;
368       case reflection::Long:
369         Set<int64>(field, from->GetField<int64>(field->offset(),
370                                                 field->default_integer()));
371         break;
372       case reflection::Float:
373         Set<float>(field, from->GetField<float>(field->offset(),
374                                                 field->default_real()));
375         break;
376       case reflection::Double:
377         Set<double>(field, from->GetField<double>(field->offset(),
378                                                   field->default_real()));
379         break;
380       case reflection::String:
381         Set<std::string>(
382             field, from->GetPointer<const flatbuffers::String*>(field->offset())
383                        ->str());
384         break;
385       case reflection::Obj:
386         if (!Mutable(field)->MergeFrom(
387                 from->GetPointer<const flatbuffers::Table* const>(
388                     field->offset()))) {
389           return false;
390         }
391         break;
392       default:
393         TC3_LOG(ERROR) << "Unsupported type: " << type;
394         return false;
395     }
396   }
397   return true;
398 }
399 
MergeFromSerializedFlatbuffer(StringPiece from)400 bool ReflectiveFlatbuffer::MergeFromSerializedFlatbuffer(StringPiece from) {
401   return MergeFrom(flatbuffers::GetAnyRoot(
402       reinterpret_cast<const unsigned char*>(from.data())));
403 }
404 
AsFlatMap(const std::string & key_separator,const std::string & key_prefix,std::map<std::string,Variant> * result) const405 void ReflectiveFlatbuffer::AsFlatMap(
406     const std::string& key_separator, const std::string& key_prefix,
407     std::map<std::string, Variant>* result) const {
408   // Add direct fields.
409   for (auto it : fields_) {
410     (*result)[key_prefix + it.first->name()->str()] = it.second;
411   }
412 
413   // Add nested messages.
414   for (auto& it : children_) {
415     it.second->AsFlatMap(key_separator,
416                          key_prefix + it.first->name()->str() + key_separator,
417                          result);
418   }
419 }
420 
421 }  // namespace libtextclassifier3
422