1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "actions/lua-utils.h"
18 
19 namespace libtextclassifier3 {
namespace {
// Keys of the Lua tables that the Push*/Read* helpers in this file write and
// read. The enclosing anonymous namespace already gives them internal linkage.

// Classification result / annotation fields.
constexpr const char* kTextKey = "text";
// NOTE(review): the identifier says "Usec" but the key string names a
// millisecond value; it is paired with `time_ms_utc` below.
constexpr const char* kTimeUsecKey = "parsed_time_ms_utc";
constexpr const char* kGranularityKey = "granularity";
constexpr const char* kCollectionKey = "collection";
constexpr const char* kNameKey = "name";
constexpr const char* kScoreKey = "score";

// Action suggestion fields.
constexpr const char* kPriorityScoreKey = "priority_score";
constexpr const char* kTypeKey = "type";
constexpr const char* kResponseTextKey = "response_text";
constexpr const char* kAnnotationKey = "annotation";

// Span fields.
constexpr const char* kSpanKey = "span";
constexpr const char* kMessageKey = "message";
constexpr const char* kBeginKey = "begin";
constexpr const char* kEndKey = "end";

// Entity data fields.
constexpr const char* kClassificationKey = "classification";
constexpr const char* kSerializedEntity = "serialized_entity";
constexpr const char* kEntityKey = "entity";
}  // namespace
39 
40 template <>
Item(const std::vector<ClassificationResult> * annotations,StringPiece key,lua_State * state) const41 int AnnotationIterator<ClassificationResult>::Item(
42     const std::vector<ClassificationResult>* annotations, StringPiece key,
43     lua_State* state) const {
44   // Lookup annotation by collection.
45   for (const ClassificationResult& annotation : *annotations) {
46     if (key.Equals(annotation.collection)) {
47       PushAnnotation(annotation, entity_data_schema_, env_);
48       return 1;
49     }
50   }
51   TC3_LOG(ERROR) << "No annotation with collection: " << key.ToString()
52                  << " found.";
53   lua_error(state);
54   return 0;
55 }
56 
57 template <>
Item(const std::vector<ActionSuggestionAnnotation> * annotations,StringPiece key,lua_State * state) const58 int AnnotationIterator<ActionSuggestionAnnotation>::Item(
59     const std::vector<ActionSuggestionAnnotation>* annotations, StringPiece key,
60     lua_State* state) const {
61   // Lookup annotation by name.
62   for (const ActionSuggestionAnnotation& annotation : *annotations) {
63     if (key.Equals(annotation.name)) {
64       PushAnnotation(annotation, entity_data_schema_, env_);
65       return 1;
66     }
67   }
68   TC3_LOG(ERROR) << "No annotation with name: " << key.ToString() << " found.";
69   lua_error(state);
70   return 0;
71 }
72 
PushAnnotation(const ClassificationResult & classification,const reflection::Schema * entity_data_schema,LuaEnvironment * env)73 void PushAnnotation(const ClassificationResult& classification,
74                     const reflection::Schema* entity_data_schema,
75                     LuaEnvironment* env) {
76   if (entity_data_schema == nullptr ||
77       classification.serialized_entity_data.empty()) {
78     // Empty table.
79     lua_newtable(env->state());
80   } else {
81     env->PushFlatbuffer(entity_data_schema,
82                         flatbuffers::GetRoot<flatbuffers::Table>(
83                             classification.serialized_entity_data.data()));
84   }
85   lua_pushinteger(env->state(),
86                   classification.datetime_parse_result.time_ms_utc);
87   lua_setfield(env->state(), /*idx=*/-2, kTimeUsecKey);
88   lua_pushinteger(env->state(),
89                   classification.datetime_parse_result.granularity);
90   lua_setfield(env->state(), /*idx=*/-2, kGranularityKey);
91   env->PushString(classification.collection);
92   lua_setfield(env->state(), /*idx=*/-2, kCollectionKey);
93   lua_pushnumber(env->state(), classification.score);
94   lua_setfield(env->state(), /*idx=*/-2, kScoreKey);
95   env->PushString(classification.serialized_entity_data);
96   lua_setfield(env->state(), /*idx=*/-2, kSerializedEntity);
97 }
98 
PushAnnotation(const ClassificationResult & classification,StringPiece text,const reflection::Schema * entity_data_schema,LuaEnvironment * env)99 void PushAnnotation(const ClassificationResult& classification,
100                     StringPiece text,
101                     const reflection::Schema* entity_data_schema,
102                     LuaEnvironment* env) {
103   PushAnnotation(classification, entity_data_schema, env);
104   env->PushString(text);
105   lua_setfield(env->state(), /*idx=*/-2, kTextKey);
106 }
107 
PushAnnotatedSpan(const AnnotatedSpan & annotated_span,const AnnotationIterator<ClassificationResult> & annotation_iterator,LuaEnvironment * env)108 void PushAnnotatedSpan(
109     const AnnotatedSpan& annotated_span,
110     const AnnotationIterator<ClassificationResult>& annotation_iterator,
111     LuaEnvironment* env) {
112   lua_newtable(env->state());
113   {
114     lua_newtable(env->state());
115     lua_pushinteger(env->state(), annotated_span.span.first);
116     lua_setfield(env->state(), /*idx=*/-2, kBeginKey);
117     lua_pushinteger(env->state(), annotated_span.span.second);
118     lua_setfield(env->state(), /*idx=*/-2, kEndKey);
119   }
120   lua_setfield(env->state(), /*idx=*/-2, kSpanKey);
121   annotation_iterator.NewIterator(kClassificationKey,
122                                   &annotated_span.classification, env->state());
123   lua_setfield(env->state(), /*idx=*/-2, kClassificationKey);
124 }
125 
ReadSpan(LuaEnvironment * env)126 MessageTextSpan ReadSpan(LuaEnvironment* env) {
127   MessageTextSpan span;
128   lua_pushnil(env->state());
129   while (lua_next(env->state(), /*idx=*/-2)) {
130     const StringPiece key = env->ReadString(/*index=*/-2);
131     if (key.Equals(kMessageKey)) {
132       span.message_index =
133           static_cast<int>(lua_tonumber(env->state(), /*idx=*/-1));
134     } else if (key.Equals(kBeginKey)) {
135       span.span.first =
136           static_cast<int>(lua_tonumber(env->state(), /*idx=*/-1));
137     } else if (key.Equals(kEndKey)) {
138       span.span.second =
139           static_cast<int>(lua_tonumber(env->state(), /*idx=*/-1));
140     } else if (key.Equals(kTextKey)) {
141       span.text = env->ReadString(/*index=*/-1).ToString();
142     } else {
143       TC3_LOG(INFO) << "Unknown span field: " << key.ToString();
144     }
145     lua_pop(env->state(), 1);
146   }
147   return span;
148 }
149 
ReadAnnotations(const reflection::Schema * entity_data_schema,LuaEnvironment * env,std::vector<ActionSuggestionAnnotation> * annotations)150 int ReadAnnotations(const reflection::Schema* entity_data_schema,
151                     LuaEnvironment* env,
152                     std::vector<ActionSuggestionAnnotation>* annotations) {
153   if (lua_type(env->state(), /*idx=*/-1) != LUA_TTABLE) {
154     TC3_LOG(ERROR) << "Expected annotations table, got: "
155                    << lua_type(env->state(), /*idx=*/-1);
156     lua_pop(env->state(), 1);
157     lua_error(env->state());
158     return LUA_ERRRUN;
159   }
160 
161   // Read actions.
162   lua_pushnil(env->state());
163   while (lua_next(env->state(), /*idx=*/-2)) {
164     if (lua_type(env->state(), /*idx=*/-1) != LUA_TTABLE) {
165       TC3_LOG(ERROR) << "Expected annotation table, got: "
166                      << lua_type(env->state(), /*idx=*/-1);
167       lua_pop(env->state(), 1);
168       continue;
169     }
170     annotations->push_back(ReadAnnotation(entity_data_schema, env));
171     lua_pop(env->state(), 1);
172   }
173   return LUA_OK;
174 }
175 
ReadAnnotation(const reflection::Schema * entity_data_schema,LuaEnvironment * env)176 ActionSuggestionAnnotation ReadAnnotation(
177     const reflection::Schema* entity_data_schema, LuaEnvironment* env) {
178   ActionSuggestionAnnotation annotation;
179   lua_pushnil(env->state());
180   while (lua_next(env->state(), /*idx=*/-2)) {
181     const StringPiece key = env->ReadString(/*index=*/-2);
182     if (key.Equals(kNameKey)) {
183       annotation.name = env->ReadString(/*index=*/-1).ToString();
184     } else if (key.Equals(kSpanKey)) {
185       annotation.span = ReadSpan(env);
186     } else if (key.Equals(kEntityKey)) {
187       annotation.entity = ReadClassificationResult(entity_data_schema, env);
188     } else {
189       TC3_LOG(ERROR) << "Unknown annotation field: " << key.ToString();
190     }
191     lua_pop(env->state(), 1);
192   }
193   return annotation;
194 }
195 
ReadClassificationResult(const reflection::Schema * entity_data_schema,LuaEnvironment * env)196 ClassificationResult ReadClassificationResult(
197     const reflection::Schema* entity_data_schema, LuaEnvironment* env) {
198   ClassificationResult classification;
199   lua_pushnil(env->state());
200   while (lua_next(env->state(), /*idx=*/-2)) {
201     const StringPiece key = env->ReadString(/*index=*/-2);
202     if (key.Equals(kCollectionKey)) {
203       classification.collection = env->ReadString(/*index=*/-1).ToString();
204     } else if (key.Equals(kScoreKey)) {
205       classification.score =
206           static_cast<float>(lua_tonumber(env->state(), /*idx=*/-1));
207     } else if (key.Equals(kTimeUsecKey)) {
208       classification.datetime_parse_result.time_ms_utc =
209           static_cast<int64>(lua_tonumber(env->state(), /*idx=*/-1));
210     } else if (key.Equals(kGranularityKey)) {
211       classification.datetime_parse_result.granularity =
212           static_cast<DatetimeGranularity>(
213               lua_tonumber(env->state(), /*idx=*/-1));
214     } else if (key.Equals(kSerializedEntity)) {
215       classification.serialized_entity_data =
216           env->ReadString(/*index=*/-1).ToString();
217     } else if (key.Equals(kEntityKey)) {
218       auto buffer = ReflectiveFlatbufferBuilder(entity_data_schema).NewRoot();
219       env->ReadFlatbuffer(buffer.get());
220       classification.serialized_entity_data = buffer->Serialize();
221     } else {
222       TC3_LOG(INFO) << "Unknown classification result field: "
223                     << key.ToString();
224     }
225     lua_pop(env->state(), 1);
226   }
227   return classification;
228 }
229 
PushAnnotation(const ActionSuggestionAnnotation & annotation,const reflection::Schema * entity_data_schema,LuaEnvironment * env)230 void PushAnnotation(const ActionSuggestionAnnotation& annotation,
231                     const reflection::Schema* entity_data_schema,
232                     LuaEnvironment* env) {
233   PushAnnotation(annotation.entity, annotation.span.text, entity_data_schema,
234                  env);
235   env->PushString(annotation.name);
236   lua_setfield(env->state(), /*idx=*/-2, kNameKey);
237   {
238     lua_newtable(env->state());
239     lua_pushinteger(env->state(), annotation.span.message_index);
240     lua_setfield(env->state(), /*idx=*/-2, kMessageKey);
241     lua_pushinteger(env->state(), annotation.span.span.first);
242     lua_setfield(env->state(), /*idx=*/-2, kBeginKey);
243     lua_pushinteger(env->state(), annotation.span.span.second);
244     lua_setfield(env->state(), /*idx=*/-2, kEndKey);
245   }
246   lua_setfield(env->state(), /*idx=*/-2, kSpanKey);
247 }
248 
PushAction(const ActionSuggestion & action,const reflection::Schema * entity_data_schema,const AnnotationIterator<ActionSuggestionAnnotation> & annotation_iterator,LuaEnvironment * env)249 void PushAction(
250     const ActionSuggestion& action,
251     const reflection::Schema* entity_data_schema,
252     const AnnotationIterator<ActionSuggestionAnnotation>& annotation_iterator,
253     LuaEnvironment* env) {
254   if (entity_data_schema == nullptr || action.serialized_entity_data.empty()) {
255     // Empty table.
256     lua_newtable(env->state());
257   } else {
258     env->PushFlatbuffer(entity_data_schema,
259                         flatbuffers::GetRoot<flatbuffers::Table>(
260                             action.serialized_entity_data.data()));
261   }
262   env->PushString(action.type);
263   lua_setfield(env->state(), /*idx=*/-2, kTypeKey);
264   env->PushString(action.response_text);
265   lua_setfield(env->state(), /*idx=*/-2, kResponseTextKey);
266   lua_pushnumber(env->state(), action.score);
267   lua_setfield(env->state(), /*idx=*/-2, kScoreKey);
268   lua_pushnumber(env->state(), action.priority_score);
269   lua_setfield(env->state(), /*idx=*/-2, kPriorityScoreKey);
270   annotation_iterator.NewIterator(kAnnotationKey, &action.annotations,
271                                   env->state());
272   lua_setfield(env->state(), /*idx=*/-2, kAnnotationKey);
273 }
274 
ReadAction(const reflection::Schema * actions_entity_data_schema,const reflection::Schema * annotations_entity_data_schema,LuaEnvironment * env)275 ActionSuggestion ReadAction(
276     const reflection::Schema* actions_entity_data_schema,
277     const reflection::Schema* annotations_entity_data_schema,
278     LuaEnvironment* env) {
279   ActionSuggestion action;
280   lua_pushnil(env->state());
281   while (lua_next(env->state(), /*idx=*/-2)) {
282     const StringPiece key = env->ReadString(/*index=*/-2);
283     if (key.Equals(kResponseTextKey)) {
284       action.response_text = env->ReadString(/*index=*/-1).ToString();
285     } else if (key.Equals(kTypeKey)) {
286       action.type = env->ReadString(/*index=*/-1).ToString();
287     } else if (key.Equals(kScoreKey)) {
288       action.score = static_cast<float>(lua_tonumber(env->state(), /*idx=*/-1));
289     } else if (key.Equals(kPriorityScoreKey)) {
290       action.priority_score =
291           static_cast<float>(lua_tonumber(env->state(), /*idx=*/-1));
292     } else if (key.Equals(kAnnotationKey)) {
293       ReadAnnotations(actions_entity_data_schema, env, &action.annotations);
294     } else if (key.Equals(kEntityKey)) {
295       auto buffer =
296           ReflectiveFlatbufferBuilder(actions_entity_data_schema).NewRoot();
297       env->ReadFlatbuffer(buffer.get());
298       action.serialized_entity_data = buffer->Serialize();
299     } else {
300       TC3_LOG(INFO) << "Unknown action field: " << key.ToString();
301     }
302     lua_pop(env->state(), 1);
303   }
304   return action;
305 }
306 
ReadActions(const reflection::Schema * actions_entity_data_schema,const reflection::Schema * annotations_entity_data_schema,LuaEnvironment * env,std::vector<ActionSuggestion> * actions)307 int ReadActions(const reflection::Schema* actions_entity_data_schema,
308                 const reflection::Schema* annotations_entity_data_schema,
309                 LuaEnvironment* env, std::vector<ActionSuggestion>* actions) {
310   if (lua_type(env->state(), /*idx=*/-1) != LUA_TTABLE) {
311     TC3_LOG(ERROR) << "Expected actions table, got: "
312                    << lua_type(env->state(), /*idx=*/-1);
313     lua_pop(env->state(), 1);
314     lua_error(env->state());
315     return LUA_ERRRUN;
316   }
317 
318   // Read actions.
319   lua_pushnil(env->state());
320   while (lua_next(env->state(), /*idx=*/-2)) {
321     if (lua_type(env->state(), /*idx=*/-1) != LUA_TTABLE) {
322       TC3_LOG(ERROR) << "Expected action table, got: "
323                      << lua_type(env->state(), /*idx=*/-1);
324       lua_pop(env->state(), 1);
325       continue;
326     }
327     actions->push_back(ReadAction(actions_entity_data_schema,
328                                   annotations_entity_data_schema, env));
329     lua_pop(env->state(), /*n=1*/ 1);
330   }
331   lua_pop(env->state(), /*n=*/1);
332 
333   return LUA_OK;
334 }
335 
Item(const std::vector<ConversationMessage> * messages,const int64 pos,lua_State * state) const336 int ConversationIterator::Item(const std::vector<ConversationMessage>* messages,
337                                const int64 pos, lua_State* state) const {
338   const ConversationMessage& message = (*messages)[pos];
339   lua_newtable(state);
340   lua_pushinteger(state, message.user_id);
341   lua_setfield(state, /*idx=*/-2, "user_id");
342   env_->PushString(message.text);
343   lua_setfield(state, /*idx=*/-2, "text");
344   lua_pushinteger(state, message.reference_time_ms_utc);
345   lua_setfield(state, /*idx=*/-2, "time_ms_utc");
346   env_->PushString(message.reference_timezone);
347   lua_setfield(state, /*idx=*/-2, "timezone");
348   annotated_span_iterator_.NewIterator("annotation", &message.annotations,
349                                        state);
350   lua_setfield(state, /*idx=*/-2, "annotation");
351   return 1;
352 }
353 
354 }  // namespace libtextclassifier3
355