1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef LIBTEXTCLASSIFIER_UTILS_LUA_UTILS_H_
18 #define LIBTEXTCLASSIFIER_UTILS_LUA_UTILS_H_
19 
20 #include <functional>
21 #include <vector>
22 
23 #include "utils/flatbuffers.h"
24 #include "utils/strings/stringpiece.h"
25 #include "utils/variant.h"
26 #include "flatbuffers/reflection_generated.h"
27 
28 #ifdef __cplusplus
29 extern "C" {
30 #endif
31 #include "lauxlib.h"
32 #include "lua.h"
33 #include "lualib.h"
34 #ifdef __cplusplus
35 }
36 #endif
37 
38 namespace libtextclassifier3 {
39 
// Field names of the Lua metamethods that the helpers below install on
// metatables. A namespace-scope constexpr object is implicitly const and
// therefore already has internal linkage, so no `static` is needed here.
constexpr const char *kIndexKey = "__index";
constexpr const char *kLengthKey = "__len";
constexpr const char *kPairsKey = "__pairs";
43 
// Converts a pointer-to-const into the untyped `void *` expected by Lua's
// light-userdata API (`lua_pushlightuserdata` takes a non-const `void *`).
// Constness is discarded; callers must not write through the result when the
// pointee is genuinely const.
template <typename T>
void *AsUserData(const T *value) {
  T *const writable = const_cast<T *>(value);
  return static_cast<void *>(writable);
}
// Converts any other value (e.g. a pointer to non-const, which binds here as
// an exact match) into light userdata via reinterpret_cast.
template <typename T>
void *AsUserData(const T value) {
  void *const user_data = reinterpret_cast<void *>(value);
  return user_data;
}
53 
54 // Retrieves up-values.
55 template <typename T>
FromUpValue(const int index,lua_State * state)56 T FromUpValue(const int index, lua_State *state) {
57   return static_cast<T>(lua_touserdata(state, lua_upvalueindex(index)));
58 }
59 
60 class LuaEnvironment {
61  public:
62   // Wrapper for handling an iterator.
63   class Iterator {
64    public:
~Iterator()65     virtual ~Iterator() {}
66     static int NextCallback(lua_State *state);
67     static int LengthCallback(lua_State *state);
68     static int ItemCallback(lua_State *state);
69     static int IteritemsCallback(lua_State *state);
70 
71     // Called when the next element of an iterator is fetched.
72     virtual int Next(lua_State *state) const = 0;
73 
74     // Called when the length of the iterator is queried.
75     virtual int Length(lua_State *state) const = 0;
76 
77     // Called when an item is queried.
78     virtual int Item(lua_State *state) const = 0;
79 
80     // Called when a new iterator is started.
81     virtual int Iteritems(lua_State *state) const = 0;
82 
83    protected:
84     static constexpr int kIteratorArgId = 1;
85   };
86 
87   template <typename T>
88   class ItemIterator : public Iterator {
89    public:
NewIterator(StringPiece name,const T * items,lua_State * state)90     void NewIterator(StringPiece name, const T *items, lua_State *state) const {
91       lua_newtable(state);
92       luaL_newmetatable(state, name.data());
93       lua_pushlightuserdata(state, AsUserData(this));
94       lua_pushlightuserdata(state, AsUserData(items));
95       lua_pushcclosure(state, &Iterator::ItemCallback, 2);
96       lua_setfield(state, -2, kIndexKey);
97       lua_pushlightuserdata(state, AsUserData(this));
98       lua_pushlightuserdata(state, AsUserData(items));
99       lua_pushcclosure(state, &Iterator::LengthCallback, 2);
100       lua_setfield(state, -2, kLengthKey);
101       lua_pushlightuserdata(state, AsUserData(this));
102       lua_pushlightuserdata(state, AsUserData(items));
103       lua_pushcclosure(state, &Iterator::IteritemsCallback, 2);
104       lua_setfield(state, -2, kPairsKey);
105       lua_setmetatable(state, -2);
106     }
107 
Iteritems(lua_State * state)108     int Iteritems(lua_State *state) const override {
109       lua_pushlightuserdata(state, AsUserData(this));
110       lua_pushlightuserdata(
111           state, lua_touserdata(state, lua_upvalueindex(kItemsArgId)));
112       lua_pushnumber(state, 0);
113       lua_pushcclosure(state, &Iterator::NextCallback, 3);
114       return /*num results=*/1;
115     }
116 
Length(lua_State * state)117     int Length(lua_State *state) const override {
118       lua_pushinteger(state, FromUpValue<T *>(kItemsArgId, state)->size());
119       return /*num results=*/1;
120     }
121 
Next(lua_State * state)122     int Next(lua_State *state) const override {
123       return Next(FromUpValue<T *>(kItemsArgId, state),
124                   lua_tointeger(state, lua_upvalueindex(kIterValueArgId)),
125                   state);
126     }
127 
Next(const T * items,const int64 pos,lua_State * state)128     int Next(const T *items, const int64 pos, lua_State *state) const {
129       if (pos >= items->size()) {
130         return 0;
131       }
132 
133       // Update iterator value.
134       lua_pushnumber(state, pos + 1);
135       lua_replace(state, lua_upvalueindex(3));
136 
137       // Push key.
138       lua_pushinteger(state, pos + 1);
139 
140       // Push item.
141       return 1 + Item(items, pos, state);
142     }
143 
Item(lua_State * state)144     int Item(lua_State *state) const override {
145       const T *items = FromUpValue<T *>(kItemsArgId, state);
146       switch (lua_type(state, -1)) {
147         case LUA_TNUMBER: {
148           // Lua is one based, so adjust the index here.
149           const int64 index =
150               static_cast<int64>(lua_tonumber(state, /*idx=*/-1)) - 1;
151           if (index < 0 || index >= items->size()) {
152             TC3_LOG(ERROR) << "Invalid index: " << index;
153             lua_error(state);
154             return 0;
155           }
156           return Item(items, index, state);
157         }
158         case LUA_TSTRING: {
159           size_t key_length = 0;
160           const char *key = lua_tolstring(state, /*idx=*/-1, &key_length);
161           return Item(items, StringPiece(key, key_length), state);
162         }
163         default:
164           TC3_LOG(ERROR) << "Unexpected access type: " << lua_type(state, -1);
165           lua_error(state);
166           return 0;
167       }
168     }
169 
170     virtual int Item(const T *items, const int64 pos,
171                      lua_State *state) const = 0;
172 
Item(const T * items,StringPiece key,lua_State * state)173     virtual int Item(const T *items, StringPiece key, lua_State *state) const {
174       TC3_LOG(ERROR) << "Unexpected key access: " << key.ToString();
175       lua_error(state);
176       return 0;
177     }
178 
179    protected:
180     static constexpr int kItemsArgId = 2;
181     static constexpr int kIterValueArgId = 3;
182   };
183 
184   virtual ~LuaEnvironment();
185   LuaEnvironment();
186 
187   // Compile a lua snippet into binary bytecode.
188   // NOTE: The compiled bytecode might not be compatible across Lua versions
189   // and platforms.
190   bool Compile(StringPiece snippet, std::string *bytecode);
191 
192   typedef int (*CallbackHandler)(lua_State *);
193 
194   // Loads default libraries.
195   void LoadDefaultLibraries();
196 
197   // Provides a callback to Lua.
198   template <typename T, int (T::*handler)()>
Bind()199   void Bind() {
200     lua_pushlightuserdata(state_, static_cast<void *>(this));
201     lua_pushcclosure(state_, &Dispatch<T, handler>, 1);
202   }
203 
204   // Setup a named table that callsback whenever a member is accessed.
205   // This allows to lazily provide required information to the script.
206   template <typename T, int (T::*handler)()>
BindTable(const char * name)207   void BindTable(const char *name) {
208     lua_newtable(state_);
209     luaL_newmetatable(state_, name);
210     lua_pushlightuserdata(state_, static_cast<void *>(this));
211     lua_pushcclosure(state_, &Dispatch<T, handler>, 1);
212     lua_setfield(state_, -2, kIndexKey);
213     lua_setmetatable(state_, -2);
214   }
215 
216   void PushValue(const Variant &value);
217 
218   // Reads a string from the stack.
219   StringPiece ReadString(const int index) const;
220 
221   // Pushes a string to the stack.
222   void PushString(const StringPiece str);
223 
224   // Pushes a flatbuffer to the stack.
225   void PushFlatbuffer(const reflection::Schema *schema,
226                       const flatbuffers::Table *table);
227 
228   // Reads a flatbuffer from the stack.
229   int ReadFlatbuffer(ReflectiveFlatbuffer *buffer);
230 
231   // Runs a closure in protected mode.
232   // `func`: closure to run in protected mode.
233   // `num_lua_args`: number of arguments from the lua stack to process.
234   // `num_results`: number of result values pushed on the stack.
235   int RunProtected(const std::function<int()> &func, const int num_args = 0,
236                    const int num_results = 0);
237 
state()238   lua_State *state() const { return state_; }
239 
240  protected:
241   lua_State *state_;
242 
243  private:
244   // Auxiliary methods to expose (reflective) flatbuffer based data to Lua.
245   static void PushFlatbuffer(const char *name, const reflection::Schema *schema,
246                              const reflection::Object *type,
247                              const flatbuffers::Table *table, lua_State *state);
248   static int GetFieldCallback(lua_State *state);
249   static int GetField(const reflection::Schema *schema,
250                       const reflection::Object *type,
251                       const flatbuffers::Table *table, lua_State *state);
252 
253   template <typename T, int (T::*handler)()>
Dispatch(lua_State * state)254   static int Dispatch(lua_State *state) {
255     T *env = FromUpValue<T *>(1, state);
256     return ((*env).*handler)();
257   }
258 };
259 
// Compiles a Lua `snippet` into binary bytecode written to `*bytecode`.
// Free-function variant with the same signature as LuaEnvironment::Compile;
// the return value presumably reports success — confirm in the implementation.
// NOTE: The compiled bytecode might not be compatible across Lua versions
// and platforms.
bool Compile(StringPiece snippet, std::string *bytecode);
261 
262 }  // namespace libtextclassifier3
263 
264 #endif  // LIBTEXTCLASSIFIER_UTILS_LUA_UTILS_H_
265