1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef LIBTEXTCLASSIFIER_ACTIONS_LUA_UTILS_H_
18 #define LIBTEXTCLASSIFIER_ACTIONS_LUA_UTILS_H_
19 
20 #include "actions/types.h"
21 #include "annotator/types.h"
22 #include "utils/flatbuffers.h"
23 #include "utils/lua-utils.h"
24 
25 #ifdef __cplusplus
26 extern "C" {
27 #endif
28 #include "lauxlib.h"
29 #include "lua.h"
30 #include "lualib.h"
31 #ifdef __cplusplus
32 }
33 #endif
34 
// Action-specific shared Lua utilities.
36 namespace libtextclassifier3 {
37 
// Provides an annotation to lua.
//
// Overloads exist for a ClassificationResult (optionally accompanied by a
// StringPiece `text`) and for an ActionSuggestionAnnotation. Each one receives
// the entity-data schema and the LuaEnvironment to push into. How `text` and
// `entity_data_schema` are used is defined in the implementation file.
// NOTE(review): presumably exactly one value is pushed onto the lua stack,
// since AnnotationIterator::Item below calls PushAnnotation and then reports
// one result — confirm against the implementation.
void PushAnnotation(const ClassificationResult& classification,
                    const reflection::Schema* entity_data_schema,
                    LuaEnvironment* env);
// Same as above, additionally given the `text` the classification refers to
// (presumably; TODO confirm how the implementation uses it).
void PushAnnotation(const ClassificationResult& classification,
                    StringPiece text,
                    const reflection::Schema* entity_data_schema,
                    LuaEnvironment* env);
// Overload for annotations attached to an action suggestion.
void PushAnnotation(const ActionSuggestionAnnotation& annotation,
                    const reflection::Schema* entity_data_schema,
                    LuaEnvironment* env);
49 
50 // A lua iterator to enumerate annotation.
51 template <typename Annotation>
52 class AnnotationIterator
53     : public LuaEnvironment::ItemIterator<std::vector<Annotation>> {
54  public:
AnnotationIterator(const reflection::Schema * entity_data_schema,LuaEnvironment * env)55   AnnotationIterator(const reflection::Schema* entity_data_schema,
56                      LuaEnvironment* env)
57       : env_(env), entity_data_schema_(entity_data_schema) {}
Item(const std::vector<Annotation> * annotations,const int64 pos,lua_State * state)58   int Item(const std::vector<Annotation>* annotations, const int64 pos,
59            lua_State* state) const override {
60     PushAnnotation((*annotations)[pos], entity_data_schema_, env_);
61     return 1;
62   }
63   int Item(const std::vector<Annotation>* annotations, StringPiece key,
64            lua_State* state) const override;
65 
66  private:
67   LuaEnvironment* env_;
68   const reflection::Schema* entity_data_schema_;
69 };
70 
// Explicit specializations of the keyed AnnotationIterator::Item lookup.
// These are the only keyed Item implementations declared in this header; their
// definitions live out of line (not in this header).
template <>
int AnnotationIterator<ClassificationResult>::Item(
    const std::vector<ClassificationResult>* annotations, StringPiece key,
    lua_State* state) const;

template <>
int AnnotationIterator<ActionSuggestionAnnotation>::Item(
    const std::vector<ActionSuggestionAnnotation>* annotations, StringPiece key,
    lua_State* state) const;
80 
// Provides an annotated span to lua. `annotation_iterator` is passed along,
// presumably to expose the span's classification results (TODO confirm
// against the implementation).
void PushAnnotatedSpan(
    const AnnotatedSpan& annotated_span,
    const AnnotationIterator<ClassificationResult>& annotation_iterator,
    LuaEnvironment* env);

// Helpers that read values from lua back into C++ types.
// NOTE(review): presumably they read from the top of `env`'s lua stack —
// confirm against the implementation.
MessageTextSpan ReadSpan(LuaEnvironment* env);
ActionSuggestionAnnotation ReadAnnotation(
    const reflection::Schema* entity_data_schema, LuaEnvironment* env);
// Writes the read annotations into `annotations`. The meaning of the int
// return value is defined in the implementation (TODO document it here).
int ReadAnnotations(const reflection::Schema* entity_data_schema,
                    LuaEnvironment* env,
                    std::vector<ActionSuggestionAnnotation>* annotations);
ClassificationResult ReadClassificationResult(
    const reflection::Schema* entity_data_schema, LuaEnvironment* env);
94 
95 // A lua iterator to enumerate annotated spans.
96 class AnnotatedSpanIterator
97     : public LuaEnvironment::ItemIterator<std::vector<AnnotatedSpan>> {
98  public:
AnnotatedSpanIterator(const AnnotationIterator<ClassificationResult> & annotation_iterator,LuaEnvironment * env)99   AnnotatedSpanIterator(
100       const AnnotationIterator<ClassificationResult>& annotation_iterator,
101       LuaEnvironment* env)
102       : env_(env), annotation_iterator_(annotation_iterator) {}
AnnotatedSpanIterator(const reflection::Schema * entity_data_schema,LuaEnvironment * env)103   AnnotatedSpanIterator(const reflection::Schema* entity_data_schema,
104                         LuaEnvironment* env)
105       : env_(env), annotation_iterator_(entity_data_schema, env) {}
106 
Item(const std::vector<AnnotatedSpan> * spans,const int64 pos,lua_State * state)107   int Item(const std::vector<AnnotatedSpan>* spans, const int64 pos,
108            lua_State* state) const override {
109     PushAnnotatedSpan((*spans)[pos], annotation_iterator_, env_);
110     return /*num results=*/1;
111   }
112 
113  private:
114   LuaEnvironment* env_;
115   AnnotationIterator<ClassificationResult> annotation_iterator_;
116 };
117 
// Provides an action to lua. `annotation_iterator` is passed along,
// presumably to expose the action's annotations (TODO confirm against the
// implementation).
void PushAction(
    const ActionSuggestion& action,
    const reflection::Schema* entity_data_schema,
    const AnnotationIterator<ActionSuggestionAnnotation>& annotation_iterator,
    LuaEnvironment* env);

// Reads an action back from lua. Separate schemas are taken for the action's
// own entity data and for its annotations' entity data.
ActionSuggestion ReadAction(
    const reflection::Schema* actions_entity_data_schema,
    const reflection::Schema* annotations_entity_data_schema,
    LuaEnvironment* env);
// Writes the read actions into `actions`. The meaning of the int return value
// is defined in the implementation (TODO document it here).
int ReadActions(const reflection::Schema* actions_entity_data_schema,
                const reflection::Schema* annotations_entity_data_schema,
                LuaEnvironment* env, std::vector<ActionSuggestion>* actions);
132 
133 // A lua iterator to enumerate actions suggestions.
134 class ActionsIterator
135     : public LuaEnvironment::ItemIterator<std::vector<ActionSuggestion>> {
136  public:
ActionsIterator(const reflection::Schema * entity_data_schema,const reflection::Schema * annotations_entity_data_schema,LuaEnvironment * env)137   ActionsIterator(const reflection::Schema* entity_data_schema,
138                   const reflection::Schema* annotations_entity_data_schema,
139                   LuaEnvironment* env)
140       : env_(env),
141         entity_data_schema_(entity_data_schema),
142         annotation_iterator_(annotations_entity_data_schema, env) {}
Item(const std::vector<ActionSuggestion> * actions,const int64 pos,lua_State * state)143   int Item(const std::vector<ActionSuggestion>* actions, const int64 pos,
144            lua_State* state) const override {
145     PushAction((*actions)[pos], entity_data_schema_, annotation_iterator_,
146                env_);
147     return /*num results=*/1;
148   }
149 
150  private:
151   LuaEnvironment* env_;
152   const reflection::Schema* entity_data_schema_;
153   AnnotationIterator<ActionSuggestionAnnotation> annotation_iterator_;
154 };
155 
156 // Conversation message lua iterator.
157 class ConversationIterator
158     : public LuaEnvironment::ItemIterator<std::vector<ConversationMessage>> {
159  public:
ConversationIterator(const AnnotationIterator<ClassificationResult> & annotation_iterator,LuaEnvironment * env)160   ConversationIterator(
161       const AnnotationIterator<ClassificationResult>& annotation_iterator,
162       LuaEnvironment* env)
163       : env_(env),
164         annotated_span_iterator_(
165             AnnotatedSpanIterator(annotation_iterator, env)) {}
ConversationIterator(const reflection::Schema * entity_data_schema,LuaEnvironment * env)166   ConversationIterator(const reflection::Schema* entity_data_schema,
167                        LuaEnvironment* env)
168       : env_(env),
169         annotated_span_iterator_(
170             AnnotatedSpanIterator(entity_data_schema, env)) {}
171 
172   int Item(const std::vector<ConversationMessage>* messages, const int64 pos,
173            lua_State* state) const override;
174 
175  private:
176   LuaEnvironment* env_;
177   AnnotatedSpanIterator annotated_span_iterator_;
178 };
179 
180 }  // namespace libtextclassifier3
181 
182 #endif  // LIBTEXTCLASSIFIER_ACTIONS_LUA_UTILS_H_
183