1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "utils/flatbuffers/reflection.h"

#include <string>
18 
19 namespace libtextclassifier3 {
20 
// Looks up the field named `field_name` in `type` through the field vector's
// LookupByKey() and returns its result (null when no field matches).
// Aborts via TC3_CHECK if `type` or its field list is null.
const reflection::Field* GetFieldOrNull(const reflection::Object* type,
                                        const StringPiece field_name) {
  TC3_CHECK(type != nullptr && type->fields() != nullptr);
  // LookupByKey() receives a plain C string, but a StringPiece is a
  // (pointer, length) view that is not guaranteed to be NUL-terminated.
  // Copy it into an owning string so the key ends exactly where the piece
  // does instead of wherever the next NUL byte happens to be.
  const std::string key(field_name.data(), field_name.size());
  return type->fields()->LookupByKey(key.c_str());
}
26 
// Scans the fields of `type` for the one whose offset equals `field_offset`.
// Returns nullptr if the type has no field list or no field has that offset.
const reflection::Field* GetFieldOrNull(const reflection::Object* type,
                                        const int field_offset) {
  const auto* all_fields = type->fields();
  if (all_fields == nullptr) {
    return nullptr;
  }
  for (const reflection::Field* candidate : *all_fields) {
    if (candidate->offset() != field_offset) {
      continue;
    }
    return candidate;
  }
  return nullptr;
}
39 
// Resolves a field of `type` from either its name or its offset.
//
// A non-empty `field_name` takes precedence: the lookup is done by name only
// and `field_offset` is ignored (there is no fallback to the offset when the
// name is not found). With an empty name the lookup is done by offset.
// Returns nullptr when the selected lookup finds nothing.
const reflection::Field* GetFieldOrNull(const reflection::Object* type,
                                        const StringPiece field_name,
                                        const int field_offset) {
  // Lookup by name might be faster as the fields are sorted by name in the
  // schema data, so try that first.
  if (!field_name.empty()) {
    // Forward the piece itself rather than `field_name.data()`: going through
    // a raw `const char*` would drop the known length and rebuild the
    // StringPiece from the pointer alone.
    return GetFieldOrNull(type, field_name);
  }
  return GetFieldOrNull(type, field_offset);
}
50 
// Resolves the field of `type` described by the serialized `field` spec.
// When the spec carries a name, the name/offset overload decides which one to
// use; otherwise the field is looked up by offset alone.
// Aborts via TC3_CHECK if `type` or `field` is null.
const reflection::Field* GetFieldOrNull(const reflection::Object* type,
                                        const FlatbufferField* field) {
  TC3_CHECK(type != nullptr && field != nullptr);
  const auto* name = field->field_name();
  if (name != nullptr) {
    const StringPiece name_piece(name->data(), name->size());
    return GetFieldOrNull(type, name_piece, field->field_offset());
  }
  return GetFieldOrNull(type, field->field_offset());
}
62 
// Resolves the field of `type` described by the object-API `field` spec by
// delegating to the name/offset overload.
// Aborts via TC3_CHECK if `type` or `field` is null.
const reflection::Field* GetFieldOrNull(const reflection::Object* type,
                                        const FlatbufferFieldT* field) {
  TC3_CHECK(type != nullptr && field != nullptr);
  const StringPiece name = field->field_name;
  const int offset = field->field_offset;
  return GetFieldOrNull(type, name, offset);
}
68 
// Returns the first object in `schema` whose name equals `type_name`, or
// nullptr if there is none.
const reflection::Object* TypeForName(const reflection::Schema* schema,
                                      const StringPiece type_name) {
  for (const reflection::Object* object : *schema->objects()) {
    // Compare through a non-owning view of the stored name; materializing a
    // std::string with str() for every object would allocate needlessly.
    const auto* name = object->name();
    if (type_name.Equals(StringPiece(name->c_str(), name->size()))) {
      return object;
    }
  }
  return nullptr;
}
78 
TypeIdForObject(const reflection::Schema * schema,const reflection::Object * type)79 Optional<int> TypeIdForObject(const reflection::Schema* schema,
80                               const reflection::Object* type) {
81   for (int i = 0; i < schema->objects()->size(); i++) {
82     if (schema->objects()->Get(i) == type) {
83       return Optional<int>(i);
84     }
85   }
86   return Optional<int>();
87 }
88 
TypeIdForName(const reflection::Schema * schema,const StringPiece type_name)89 Optional<int> TypeIdForName(const reflection::Schema* schema,
90                             const StringPiece type_name) {
91   for (int i = 0; i < schema->objects()->size(); i++) {
92     if (type_name.Equals(schema->objects()->Get(i)->name()->str())) {
93       return Optional<int>(i);
94     }
95   }
96   return Optional<int>();
97 }
98 
SwapFieldNamesForOffsetsInPath(const reflection::Schema * schema,FlatbufferFieldPathT * path)99 bool SwapFieldNamesForOffsetsInPath(const reflection::Schema* schema,
100                                     FlatbufferFieldPathT* path) {
101   if (schema == nullptr || !schema->root_table()) {
102     TC3_LOG(ERROR) << "Empty schema provided.";
103     return false;
104   }
105 
106   reflection::Object const* type = schema->root_table();
107   for (int i = 0; i < path->field.size(); i++) {
108     const reflection::Field* field = GetFieldOrNull(type, path->field[i].get());
109     if (field == nullptr) {
110       TC3_LOG(ERROR) << "Could not find field: " << path->field[i]->field_name;
111       return false;
112     }
113     path->field[i]->field_name.clear();
114     path->field[i]->field_offset = field->offset();
115 
116     // Descend.
117     if (i < path->field.size() - 1) {
118       if (field->type()->base_type() != reflection::Obj) {
119         TC3_LOG(ERROR) << "Field: " << field->name()->str()
120                        << " is not of type `Object`.";
121         return false;
122       }
123       type = schema->objects()->Get(field->type()->index());
124     }
125   }
126   return true;
127 }
128 
ParseEnumValue(const reflection::Schema * schema,const reflection::Type * type,StringPiece value)129 Variant ParseEnumValue(const reflection::Schema* schema,
130                        const reflection::Type* type, StringPiece value) {
131   TC3_DCHECK(IsEnum(type));
132   TC3_CHECK_NE(schema->enums(), nullptr);
133   const auto* enum_values = schema->enums()->Get(type->index())->values();
134   if (enum_values == nullptr) {
135     TC3_LOG(ERROR) << "Enum has no specified values.";
136     return Variant();
137   }
138   for (const reflection::EnumVal* enum_value : *enum_values) {
139     if (value.Equals(StringPiece(enum_value->name()->c_str(),
140                                  enum_value->name()->size()))) {
141       const int64 value = enum_value->value();
142       switch (type->base_type()) {
143         case reflection::BaseType::Byte:
144           return Variant(static_cast<int8>(value));
145         case reflection::BaseType::UByte:
146           return Variant(static_cast<uint8>(value));
147         case reflection::BaseType::Short:
148           return Variant(static_cast<int16>(value));
149         case reflection::BaseType::UShort:
150           return Variant(static_cast<uint16>(value));
151         case reflection::BaseType::Int:
152           return Variant(static_cast<int32>(value));
153         case reflection::BaseType::UInt:
154           return Variant(static_cast<uint32>(value));
155         case reflection::BaseType::Long:
156           return Variant(value);
157         case reflection::BaseType::ULong:
158           return Variant(static_cast<uint64>(value));
159         default:
160           break;
161       }
162     }
163   }
164   return Variant();
165 }
166 
167 }  // namespace libtextclassifier3
168