1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "utils/lua-utils.h"
18 
19 namespace libtextclassifier3 {
20 namespace {
21 static constexpr luaL_Reg defaultlibs[] = {{"_G", luaopen_base},
22                                            {LUA_TABLIBNAME, luaopen_table},
23                                            {LUA_STRLIBNAME, luaopen_string},
24                                            {LUA_MATHLIBNAME, luaopen_math},
25                                            {nullptr, nullptr}};
26 
// Field names of the Lua tables exchanged with scripts by the Push*/Read*
// methods of LuaEnvironment.
constexpr char kTextKey[] = "text";
// NOTE: despite "Usec" in the constant's name, the key is "parsed_time_ms_utc"
// and it carries DatetimeParseResult::time_ms_utc (milliseconds).
constexpr char kTimeUsecKey[] = "parsed_time_ms_utc";
constexpr char kGranularityKey[] = "granularity";
constexpr char kCollectionKey[] = "collection";
constexpr char kNameKey[] = "name";
constexpr char kScoreKey[] = "score";
constexpr char kPriorityScoreKey[] = "priority_score";
constexpr char kTypeKey[] = "type";
constexpr char kResponseTextKey[] = "response_text";
constexpr char kAnnotationKey[] = "annotation";
// Keys of span tables.
constexpr char kSpanKey[] = "span";
constexpr char kMessageKey[] = "message";
constexpr char kBeginKey[] = "begin";
constexpr char kEndKey[] = "end";
constexpr char kClassificationKey[] = "classification";
// Raw serialized flatbuffer entity data (a Lua string).
constexpr char kSerializedEntity[] = "serialized_entity";
// Entity data given as a Lua table, serialized via ReadFlatbuffer().
constexpr char kEntityKey[] = "entity";
44 
45 // Implementation of a lua_Writer that appends the data to a string.
LuaStringWriter(lua_State * state,const void * data,size_t size,void * result)46 int LuaStringWriter(lua_State* state, const void* data, size_t size,
47                     void* result) {
48   std::string* const result_string = static_cast<std::string*>(result);
49   result_string->insert(result_string->size(), static_cast<const char*>(data),
50                         size);
51   return LUA_OK;
52 }
53 
54 }  // namespace
55 
LuaEnvironment()56 LuaEnvironment::LuaEnvironment() { state_ = luaL_newstate(); }
57 
~LuaEnvironment()58 LuaEnvironment::~LuaEnvironment() {
59   if (state_ != nullptr) {
60     lua_close(state_);
61   }
62 }
63 
PushFlatbuffer(const reflection::Schema * schema,const reflection::Object * type,const flatbuffers::Table * table) const64 void LuaEnvironment::PushFlatbuffer(const reflection::Schema* schema,
65                                     const reflection::Object* type,
66                                     const flatbuffers::Table* table) const {
67   PushLazyObject(
68       std::bind(&LuaEnvironment::GetField, this, schema, type, table));
69 }
70 
GetField(const reflection::Schema * schema,const reflection::Object * type,const flatbuffers::Table * table) const71 int LuaEnvironment::GetField(const reflection::Schema* schema,
72                              const reflection::Object* type,
73                              const flatbuffers::Table* table) const {
74   const char* field_name = lua_tostring(state_, /*idx=*/kIndexStackTop);
75   const reflection::Field* field = type->fields()->LookupByKey(field_name);
76   if (field == nullptr) {
77     lua_error(state_);
78     return 0;
79   }
80   // Provide primitive fields directly.
81   const reflection::BaseType field_type = field->type()->base_type();
82   switch (field_type) {
83     case reflection::Bool:
84       Push(table->GetField<bool>(field->offset(), field->default_integer()));
85       break;
86     case reflection::UByte:
87       Push(table->GetField<uint8>(field->offset(), field->default_integer()));
88       break;
89     case reflection::Byte:
90       Push(table->GetField<int8>(field->offset(), field->default_integer()));
91       break;
92     case reflection::Int:
93       Push(table->GetField<int32>(field->offset(), field->default_integer()));
94       break;
95     case reflection::UInt:
96       Push(table->GetField<uint32>(field->offset(), field->default_integer()));
97       break;
98     case reflection::Long:
99       Push(table->GetField<int64>(field->offset(), field->default_integer()));
100       break;
101     case reflection::ULong:
102       Push(table->GetField<uint64>(field->offset(), field->default_integer()));
103       break;
104     case reflection::Float:
105       Push(table->GetField<float>(field->offset(), field->default_real()));
106       break;
107     case reflection::Double:
108       Push(table->GetField<double>(field->offset(), field->default_real()));
109       break;
110     case reflection::String: {
111       Push(table->GetPointer<const flatbuffers::String*>(field->offset()));
112       break;
113     }
114     case reflection::Obj: {
115       const flatbuffers::Table* field_table =
116           table->GetPointer<const flatbuffers::Table*>(field->offset());
117       if (field_table == nullptr) {
118         // Field was not set in entity data.
119         return 0;
120       }
121       const reflection::Object* field_type =
122           schema->objects()->Get(field->type()->index());
123       PushFlatbuffer(schema, field_type, field_table);
124       break;
125     }
126     case reflection::Vector: {
127       const flatbuffers::Vector<flatbuffers::Offset<void>>* field_vector =
128           table->GetPointer<
129               const flatbuffers::Vector<flatbuffers::Offset<void>>*>(
130               field->offset());
131       if (field_vector == nullptr) {
132         // Repeated field was not set in flatbuffer.
133         PushEmptyVector();
134         break;
135       }
136       switch (field->type()->element()) {
137         case reflection::Bool:
138           PushRepeatedField(table->GetPointer<const flatbuffers::Vector<bool>*>(
139               field->offset()));
140           break;
141         case reflection::UByte:
142           PushRepeatedField(
143               table->GetPointer<const flatbuffers::Vector<uint8>*>(
144                   field->offset()));
145           break;
146         case reflection::Byte:
147           PushRepeatedField(table->GetPointer<const flatbuffers::Vector<int8>*>(
148               field->offset()));
149           break;
150         case reflection::Int:
151           PushRepeatedField(
152               table->GetPointer<const flatbuffers::Vector<int32>*>(
153                   field->offset()));
154           break;
155         case reflection::UInt:
156           PushRepeatedField(
157               table->GetPointer<const flatbuffers::Vector<uint32>*>(
158                   field->offset()));
159           break;
160         case reflection::Long:
161           PushRepeatedField(
162               table->GetPointer<const flatbuffers::Vector<int64>*>(
163                   field->offset()));
164           break;
165         case reflection::ULong:
166           PushRepeatedField(
167               table->GetPointer<const flatbuffers::Vector<uint64>*>(
168                   field->offset()));
169           break;
170         case reflection::Float:
171           PushRepeatedField(
172               table->GetPointer<const flatbuffers::Vector<float>*>(
173                   field->offset()));
174           break;
175         case reflection::Double:
176           PushRepeatedField(
177               table->GetPointer<const flatbuffers::Vector<double>*>(
178                   field->offset()));
179           break;
180         case reflection::String:
181           PushRepeatedField(
182               table->GetPointer<const flatbuffers::Vector<
183                   flatbuffers::Offset<flatbuffers::String>>*>(field->offset()));
184           break;
185         case reflection::Obj:
186           PushRepeatedFlatbufferField(
187               schema, schema->objects()->Get(field->type()->index()),
188               table->GetPointer<const flatbuffers::Vector<
189                   flatbuffers::Offset<flatbuffers::Table>>*>(field->offset()));
190           break;
191         default:
192           TC3_LOG(ERROR) << "Unsupported repeated type: "
193                          << field->type()->element();
194           lua_error(state_);
195           return 0;
196       }
197       break;
198     }
199     default:
200       TC3_LOG(ERROR) << "Unsupported type: " << field_type;
201       lua_error(state_);
202       return 0;
203   }
204   return 1;
205 }
206 
ReadFlatbuffer(const int index,MutableFlatbuffer * buffer) const207 int LuaEnvironment::ReadFlatbuffer(const int index,
208                                    MutableFlatbuffer* buffer) const {
209   if (buffer == nullptr) {
210     TC3_LOG(ERROR) << "Called ReadFlatbuffer with null buffer: " << index;
211     lua_error(state_);
212     return LUA_ERRRUN;
213   }
214   if (lua_type(state_, /*idx=*/index) != LUA_TTABLE) {
215     TC3_LOG(ERROR) << "Expected table, got: "
216                    << lua_type(state_, /*idx=*/kIndexStackTop);
217     lua_error(state_);
218     return LUA_ERRRUN;
219   }
220 
221   lua_pushnil(state_);
222   while (Next(index - 1)) {
223     const StringPiece key = ReadString(/*index=*/index - 1);
224     const reflection::Field* field = buffer->GetFieldOrNull(key);
225     if (field == nullptr) {
226       TC3_LOG(ERROR) << "Unknown field: " << key;
227       lua_error(state_);
228       return LUA_ERRRUN;
229     }
230     switch (field->type()->base_type()) {
231       case reflection::Obj:
232         ReadFlatbuffer(/*index=*/kIndexStackTop, buffer->Mutable(field));
233         break;
234       case reflection::Bool:
235         buffer->Set(field, Read<bool>(/*index=*/kIndexStackTop));
236         break;
237       case reflection::Byte:
238         buffer->Set(field, Read<int8>(/*index=*/kIndexStackTop));
239         break;
240       case reflection::UByte:
241         buffer->Set(field, Read<uint8>(/*index=*/kIndexStackTop));
242         break;
243       case reflection::Int:
244         buffer->Set(field, Read<int32>(/*index=*/kIndexStackTop));
245         break;
246       case reflection::UInt:
247         buffer->Set(field, Read<uint32>(/*index=*/kIndexStackTop));
248         break;
249       case reflection::Long:
250         buffer->Set(field, Read<int64>(/*index=*/kIndexStackTop));
251         break;
252       case reflection::ULong:
253         buffer->Set(field, Read<uint64>(/*index=*/kIndexStackTop));
254         break;
255       case reflection::Float:
256         buffer->Set(field, Read<float>(/*index=*/kIndexStackTop));
257         break;
258       case reflection::Double:
259         buffer->Set(field, Read<double>(/*index=*/kIndexStackTop));
260         break;
261       case reflection::String: {
262         buffer->Set(field, ReadString(/*index=*/kIndexStackTop));
263         break;
264       }
265       case reflection::Vector: {
266         // Read repeated field.
267         switch (field->type()->element()) {
268           case reflection::Bool:
269             ReadRepeatedField<bool>(/*index=*/kIndexStackTop,
270                                     buffer->Repeated(field));
271             break;
272           case reflection::Byte:
273             ReadRepeatedField<int8>(/*index=*/kIndexStackTop,
274                                     buffer->Repeated(field));
275             break;
276           case reflection::UByte:
277             ReadRepeatedField<uint8>(/*index=*/kIndexStackTop,
278                                      buffer->Repeated(field));
279             break;
280           case reflection::Int:
281             ReadRepeatedField<int32>(/*index=*/kIndexStackTop,
282                                      buffer->Repeated(field));
283             break;
284           case reflection::UInt:
285             ReadRepeatedField<uint32>(/*index=*/kIndexStackTop,
286                                       buffer->Repeated(field));
287             break;
288           case reflection::Long:
289             ReadRepeatedField<int64>(/*index=*/kIndexStackTop,
290                                      buffer->Repeated(field));
291             break;
292           case reflection::ULong:
293             ReadRepeatedField<uint64>(/*index=*/kIndexStackTop,
294                                       buffer->Repeated(field));
295             break;
296           case reflection::Float:
297             ReadRepeatedField<float>(/*index=*/kIndexStackTop,
298                                      buffer->Repeated(field));
299             break;
300           case reflection::Double:
301             ReadRepeatedField<double>(/*index=*/kIndexStackTop,
302                                       buffer->Repeated(field));
303             break;
304           case reflection::String:
305             ReadRepeatedField<std::string>(/*index=*/kIndexStackTop,
306                                            buffer->Repeated(field));
307             break;
308           case reflection::Obj:
309             ReadRepeatedField<MutableFlatbuffer>(/*index=*/kIndexStackTop,
310                                                  buffer->Repeated(field));
311             break;
312           default:
313             TC3_LOG(ERROR) << "Unsupported repeated field type: "
314                            << field->type()->element();
315             lua_error(state_);
316             return LUA_ERRRUN;
317         }
318         break;
319       }
320       default:
321         TC3_LOG(ERROR) << "Unsupported type: " << field->type()->base_type();
322         lua_error(state_);
323         return LUA_ERRRUN;
324     }
325     lua_pop(state_, 1);
326   }
327   return LUA_OK;
328 }
329 
LoadDefaultLibraries()330 void LuaEnvironment::LoadDefaultLibraries() {
331   for (const luaL_Reg* lib = defaultlibs; lib->func; lib++) {
332     luaL_requiref(state_, lib->name, lib->func, 1);
333     lua_pop(state_, 1);  // Remove lib.
334   }
335 }
336 
ReadString(const int index) const337 StringPiece LuaEnvironment::ReadString(const int index) const {
338   size_t length = 0;
339   const char* data = lua_tolstring(state_, index, &length);
340   return StringPiece(data, length);
341 }
342 
PushString(const StringPiece str) const343 void LuaEnvironment::PushString(const StringPiece str) const {
344   lua_pushlstring(state_, str.data(), str.size());
345 }
346 
Compile(StringPiece snippet,std::string * bytecode) const347 bool LuaEnvironment::Compile(StringPiece snippet, std::string* bytecode) const {
348   if (luaL_loadbuffer(state_, snippet.data(), snippet.size(),
349                       /*name=*/nullptr) != LUA_OK) {
350     TC3_LOG(ERROR) << "Could not compile lua snippet: "
351                    << ReadString(/*index=*/kIndexStackTop);
352     lua_pop(state_, 1);
353     return false;
354   }
355   if (lua_dump(state_, LuaStringWriter, bytecode, /*strip*/ 1) != LUA_OK) {
356     TC3_LOG(ERROR) << "Could not dump compiled lua snippet.";
357     lua_pop(state_, 1);
358     return false;
359   }
360   lua_pop(state_, 1);
361   return true;
362 }
363 
PushAnnotation(const ClassificationResult & classification,const reflection::Schema * entity_data_schema) const364 void LuaEnvironment::PushAnnotation(
365     const ClassificationResult& classification,
366     const reflection::Schema* entity_data_schema) const {
367   if (entity_data_schema == nullptr ||
368       classification.serialized_entity_data.empty()) {
369     // Empty table.
370     lua_newtable(state_);
371   } else {
372     PushFlatbuffer(entity_data_schema,
373                    flatbuffers::GetRoot<flatbuffers::Table>(
374                        classification.serialized_entity_data.data()));
375   }
376   Push(classification.datetime_parse_result.time_ms_utc);
377   lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kTimeUsecKey);
378   Push(classification.datetime_parse_result.granularity);
379   lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kGranularityKey);
380   Push(classification.collection);
381   lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kCollectionKey);
382   Push(classification.score);
383   lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kScoreKey);
384   Push(classification.serialized_entity_data);
385   lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kSerializedEntity);
386 }
387 
PushAnnotation(const ClassificationResult & classification,StringPiece text,const reflection::Schema * entity_data_schema) const388 void LuaEnvironment::PushAnnotation(
389     const ClassificationResult& classification, StringPiece text,
390     const reflection::Schema* entity_data_schema) const {
391   PushAnnotation(classification, entity_data_schema);
392   Push(text);
393   lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kTextKey);
394 }
395 
PushAnnotation(const ActionSuggestionAnnotation & annotation,const reflection::Schema * entity_data_schema) const396 void LuaEnvironment::PushAnnotation(
397     const ActionSuggestionAnnotation& annotation,
398     const reflection::Schema* entity_data_schema) const {
399   PushAnnotation(annotation.entity, annotation.span.text, entity_data_schema);
400   PushString(annotation.name);
401   lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kNameKey);
402   {
403     lua_newtable(state_);
404     Push(annotation.span.message_index);
405     lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kMessageKey);
406     Push(annotation.span.span.first);
407     lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kBeginKey);
408     Push(annotation.span.span.second);
409     lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kEndKey);
410   }
411   lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kSpanKey);
412 }
413 
PushAnnotatedSpan(const AnnotatedSpan & annotated_span,const reflection::Schema * entity_data_schema) const414 void LuaEnvironment::PushAnnotatedSpan(
415     const AnnotatedSpan& annotated_span,
416     const reflection::Schema* entity_data_schema) const {
417   lua_newtable(state_);
418   {
419     lua_newtable(state_);
420     Push(annotated_span.span.first);
421     lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kBeginKey);
422     Push(annotated_span.span.second);
423     lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kEndKey);
424   }
425   lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kSpanKey);
426   PushAnnotations(&annotated_span.classification, entity_data_schema);
427   lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kClassificationKey);
428 }
429 
PushAnnotatedSpans(const std::vector<AnnotatedSpan> * annotated_spans,const reflection::Schema * entity_data_schema) const430 void LuaEnvironment::PushAnnotatedSpans(
431     const std::vector<AnnotatedSpan>* annotated_spans,
432     const reflection::Schema* entity_data_schema) const {
433   PushIterator(annotated_spans ? annotated_spans->size() : 0,
434                [this, annotated_spans, entity_data_schema](const int64 index) {
435                  PushAnnotatedSpan(annotated_spans->at(index),
436                                    entity_data_schema);
437                  return 1;
438                });
439 }
440 
ReadSpan() const441 MessageTextSpan LuaEnvironment::ReadSpan() const {
442   MessageTextSpan span;
443   lua_pushnil(state_);
444   while (Next(/*index=*/kIndexStackTop - 1)) {
445     const StringPiece key = ReadString(/*index=*/kIndexStackTop - 1);
446     if (key.Equals(kMessageKey)) {
447       span.message_index = Read<int>(/*index=*/kIndexStackTop);
448     } else if (key.Equals(kBeginKey)) {
449       span.span.first = Read<int>(/*index=*/kIndexStackTop);
450     } else if (key.Equals(kEndKey)) {
451       span.span.second = Read<int>(/*index=*/kIndexStackTop);
452     } else if (key.Equals(kTextKey)) {
453       span.text = Read<std::string>(/*index=*/kIndexStackTop);
454     } else {
455       TC3_LOG(INFO) << "Unknown span field: " << key;
456     }
457     lua_pop(state_, 1);
458   }
459   return span;
460 }
461 
ReadAnnotations(const reflection::Schema * entity_data_schema,std::vector<ActionSuggestionAnnotation> * annotations) const462 int LuaEnvironment::ReadAnnotations(
463     const reflection::Schema* entity_data_schema,
464     std::vector<ActionSuggestionAnnotation>* annotations) const {
465   if (lua_type(state_, /*idx=*/kIndexStackTop) != LUA_TTABLE) {
466     TC3_LOG(ERROR) << "Expected annotations table, got: "
467                    << lua_type(state_, /*idx=*/kIndexStackTop);
468     lua_pop(state_, 1);
469     lua_error(state_);
470     return LUA_ERRRUN;
471   }
472 
473   // Read actions.
474   lua_pushnil(state_);
475   while (Next(/*index=*/kIndexStackTop - 1)) {
476     if (lua_type(state_, /*idx=*/kIndexStackTop) != LUA_TTABLE) {
477       TC3_LOG(ERROR) << "Expected annotation table, got: "
478                      << lua_type(state_, /*idx=*/kIndexStackTop);
479       lua_pop(state_, 1);
480       continue;
481     }
482     annotations->push_back(ReadAnnotation(entity_data_schema));
483     lua_pop(state_, 1);
484   }
485   return LUA_OK;
486 }
487 
ReadAnnotation(const reflection::Schema * entity_data_schema) const488 ActionSuggestionAnnotation LuaEnvironment::ReadAnnotation(
489     const reflection::Schema* entity_data_schema) const {
490   ActionSuggestionAnnotation annotation;
491   lua_pushnil(state_);
492   while (Next(/*index=*/kIndexStackTop - 1)) {
493     const StringPiece key = ReadString(/*index=*/kIndexStackTop - 1);
494     if (key.Equals(kNameKey)) {
495       annotation.name = Read<std::string>(/*index=*/kIndexStackTop);
496     } else if (key.Equals(kSpanKey)) {
497       annotation.span = ReadSpan();
498     } else if (key.Equals(kEntityKey)) {
499       annotation.entity = ReadClassificationResult(entity_data_schema);
500     } else {
501       TC3_LOG(ERROR) << "Unknown annotation field: " << key;
502     }
503     lua_pop(state_, 1);
504   }
505   return annotation;
506 }
507 
ReadClassificationResult(const reflection::Schema * entity_data_schema) const508 ClassificationResult LuaEnvironment::ReadClassificationResult(
509     const reflection::Schema* entity_data_schema) const {
510   ClassificationResult classification;
511   lua_pushnil(state_);
512   while (Next(/*index=*/kIndexStackTop - 1)) {
513     const StringPiece key = ReadString(/*index=*/kIndexStackTop - 1);
514     if (key.Equals(kCollectionKey)) {
515       classification.collection = Read<std::string>(/*index=*/kIndexStackTop);
516     } else if (key.Equals(kScoreKey)) {
517       classification.score = Read<float>(/*index=*/kIndexStackTop);
518     } else if (key.Equals(kTimeUsecKey)) {
519       classification.datetime_parse_result.time_ms_utc =
520           Read<int64>(/*index=*/kIndexStackTop);
521     } else if (key.Equals(kGranularityKey)) {
522       classification.datetime_parse_result.granularity =
523           static_cast<DatetimeGranularity>(
524               lua_tonumber(state_, /*idx=*/kIndexStackTop));
525     } else if (key.Equals(kSerializedEntity)) {
526       classification.serialized_entity_data =
527           Read<std::string>(/*index=*/kIndexStackTop);
528     } else if (key.Equals(kEntityKey)) {
529       auto buffer = MutableFlatbufferBuilder(entity_data_schema).NewRoot();
530       ReadFlatbuffer(/*index=*/kIndexStackTop, buffer.get());
531       classification.serialized_entity_data = buffer->Serialize();
532     } else {
533       TC3_LOG(INFO) << "Unknown classification result field: " << key;
534     }
535     lua_pop(state_, 1);
536   }
537   return classification;
538 }
539 
PushAction(const ActionSuggestion & action,const reflection::Schema * actions_entity_data_schema,const reflection::Schema * annotations_entity_data_schema) const540 void LuaEnvironment::PushAction(
541     const ActionSuggestion& action,
542     const reflection::Schema* actions_entity_data_schema,
543     const reflection::Schema* annotations_entity_data_schema) const {
544   if (actions_entity_data_schema == nullptr ||
545       action.serialized_entity_data.empty()) {
546     // Empty table.
547     lua_newtable(state_);
548   } else {
549     PushFlatbuffer(actions_entity_data_schema,
550                    flatbuffers::GetRoot<flatbuffers::Table>(
551                        action.serialized_entity_data.data()));
552   }
553   PushString(action.type);
554   lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kTypeKey);
555   PushString(action.response_text);
556   lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kResponseTextKey);
557   Push(action.score);
558   lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kScoreKey);
559   Push(action.priority_score);
560   lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kPriorityScoreKey);
561   PushAnnotations(&action.annotations, annotations_entity_data_schema);
562   lua_setfield(state_, /*idx=*/kIndexStackTop - 1, kAnnotationKey);
563 }
564 
PushActions(const std::vector<ActionSuggestion> * actions,const reflection::Schema * actions_entity_data_schema,const reflection::Schema * annotations_entity_data_schema) const565 void LuaEnvironment::PushActions(
566     const std::vector<ActionSuggestion>* actions,
567     const reflection::Schema* actions_entity_data_schema,
568     const reflection::Schema* annotations_entity_data_schema) const {
569   PushIterator(actions ? actions->size() : 0,
570                [this, actions, actions_entity_data_schema,
571                 annotations_entity_data_schema](const int64 index) {
572                  PushAction(actions->at(index), actions_entity_data_schema,
573                             annotations_entity_data_schema);
574                  return 1;
575                });
576 }
577 
ReadAction(const reflection::Schema * actions_entity_data_schema,const reflection::Schema * annotations_entity_data_schema) const578 ActionSuggestion LuaEnvironment::ReadAction(
579     const reflection::Schema* actions_entity_data_schema,
580     const reflection::Schema* annotations_entity_data_schema) const {
581   ActionSuggestion action;
582   lua_pushnil(state_);
583   while (Next(/*index=*/kIndexStackTop - 1)) {
584     const StringPiece key = ReadString(/*index=*/kIndexStackTop - 1);
585     if (key.Equals(kResponseTextKey)) {
586       action.response_text = Read<std::string>(/*index=*/kIndexStackTop);
587     } else if (key.Equals(kTypeKey)) {
588       action.type = Read<std::string>(/*index=*/kIndexStackTop);
589     } else if (key.Equals(kScoreKey)) {
590       action.score = Read<float>(/*index=*/kIndexStackTop);
591     } else if (key.Equals(kPriorityScoreKey)) {
592       action.priority_score = Read<float>(/*index=*/kIndexStackTop);
593     } else if (key.Equals(kAnnotationKey)) {
594       ReadAnnotations(actions_entity_data_schema, &action.annotations);
595     } else if (key.Equals(kEntityKey)) {
596       auto buffer =
597           MutableFlatbufferBuilder(actions_entity_data_schema).NewRoot();
598       ReadFlatbuffer(/*index=*/kIndexStackTop, buffer.get());
599       action.serialized_entity_data = buffer->Serialize();
600     } else {
601       TC3_LOG(INFO) << "Unknown action field: " << key;
602     }
603     lua_pop(state_, 1);
604   }
605   return action;
606 }
607 
ReadActions(const reflection::Schema * actions_entity_data_schema,const reflection::Schema * annotations_entity_data_schema,std::vector<ActionSuggestion> * actions) const608 int LuaEnvironment::ReadActions(
609     const reflection::Schema* actions_entity_data_schema,
610     const reflection::Schema* annotations_entity_data_schema,
611     std::vector<ActionSuggestion>* actions) const {
612   // Read actions.
613   lua_pushnil(state_);
614   while (Next(/*index=*/kIndexStackTop - 1)) {
615     if (lua_type(state_, /*idx=*/kIndexStackTop) != LUA_TTABLE) {
616       TC3_LOG(ERROR) << "Expected action table, got: "
617                      << lua_type(state_, /*idx=*/kIndexStackTop);
618       lua_pop(state_, 1);
619       continue;
620     }
621     actions->push_back(
622         ReadAction(actions_entity_data_schema, annotations_entity_data_schema));
623     lua_pop(state_, /*n=*/1);
624   }
625   lua_pop(state_, /*n=*/1);
626 
627   return LUA_OK;
628 }
629 
PushConversation(const std::vector<ConversationMessage> * conversation,const reflection::Schema * annotations_entity_data_schema) const630 void LuaEnvironment::PushConversation(
631     const std::vector<ConversationMessage>* conversation,
632     const reflection::Schema* annotations_entity_data_schema) const {
633   PushIterator(
634       conversation ? conversation->size() : 0,
635       [this, conversation, annotations_entity_data_schema](const int64 index) {
636         const ConversationMessage& message = conversation->at(index);
637         lua_newtable(state_);
638         Push(message.user_id);
639         lua_setfield(state_, /*idx=*/kIndexStackTop - 1, "user_id");
640         Push(message.text);
641         lua_setfield(state_, /*idx=*/kIndexStackTop - 1, "text");
642         Push(message.reference_time_ms_utc);
643         lua_setfield(state_, /*idx=*/kIndexStackTop - 1, "time_ms_utc");
644         Push(message.reference_timezone);
645         lua_setfield(state_, /*idx=*/kIndexStackTop - 1, "timezone");
646         PushAnnotatedSpans(&message.annotations,
647                            annotations_entity_data_schema);
648         lua_setfield(state_, /*idx=*/kIndexStackTop - 1, "annotation");
649         return 1;
650       });
651 }
652 
Compile(StringPiece snippet,std::string * bytecode)653 bool Compile(StringPiece snippet, std::string* bytecode) {
654   return LuaEnvironment().Compile(snippet, bytecode);
655 }
656 
657 }  // namespace libtextclassifier3
658