• Home
  • History
  • Annotate
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1  /*
2   * Copyright (C) 2018 The Android Open Source Project
3   *
4   * Licensed under the Apache License, Version 2.0 (the "License");
5   * you may not use this file except in compliance with the License.
6   * You may obtain a copy of the License at
7   *
8   *      http://www.apache.org/licenses/LICENSE-2.0
9   *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS,
12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13   * See the License for the specific language governing permissions and
14   * limitations under the License.
15   */
16  
17  #ifndef LIBTEXTCLASSIFIER_UTILS_LUA_UTILS_H_
18  #define LIBTEXTCLASSIFIER_UTILS_LUA_UTILS_H_
19  
20  #include <functional>
21  #include <vector>
22  
23  #include "utils/flatbuffers.h"
24  #include "utils/strings/stringpiece.h"
25  #include "utils/variant.h"
26  #include "flatbuffers/reflection_generated.h"
27  
28  #ifdef __cplusplus
29  extern "C" {
30  #endif
31  #include "lauxlib.h"
32  #include "lua.h"
33  #include "lualib.h"
34  #ifdef __cplusplus
35  }
36  #endif
37  
38  namespace libtextclassifier3 {
39  
// Lua metamethod names installed on the metatables built below:
//   "__len"   - consulted by the length operator `#`.
//   "__pairs" - consulted by `pairs()` to obtain an iterator.
//   "__index" - consulted when a key is missing from the table itself.
static constexpr char const *kLengthKey{"__len"};
static constexpr char const *kPairsKey{"__pairs"};
static constexpr char const *kIndexKey{"__index"};
43  
// Converts an object pointer into the untyped `void *` that Lua light userdata
// is made of. Constness is cast away because `lua_pushlightuserdata` only
// accepts a non-const `void *`.
template <typename T>
void *AsUserData(const T *value) {
  T *const mutable_value = const_cast<T *>(value);
  return static_cast<void *>(mutable_value);
}

// Fallback overload: reinterprets `value` itself as a `void *`. Only compiles
// for types `reinterpret_cast<void *>` accepts (e.g. integers or pointers).
template <typename T>
void *AsUserData(const T value) {
  void *const raw = reinterpret_cast<void *>(value);
  return raw;
}
53  
54  // Retrieves up-values.
55  template <typename T>
FromUpValue(const int index,lua_State * state)56  T FromUpValue(const int index, lua_State *state) {
57    return static_cast<T>(lua_touserdata(state, lua_upvalueindex(index)));
58  }
59  
60  class LuaEnvironment {
61   public:
62    // Wrapper for handling an iterator.
63    class Iterator {
64     public:
~Iterator()65      virtual ~Iterator() {}
66      static int NextCallback(lua_State *state);
67      static int LengthCallback(lua_State *state);
68      static int ItemCallback(lua_State *state);
69      static int IteritemsCallback(lua_State *state);
70  
71      // Called when the next element of an iterator is fetched.
72      virtual int Next(lua_State *state) const = 0;
73  
74      // Called when the length of the iterator is queried.
75      virtual int Length(lua_State *state) const = 0;
76  
77      // Called when an item is queried.
78      virtual int Item(lua_State *state) const = 0;
79  
80      // Called when a new iterator is started.
81      virtual int Iteritems(lua_State *state) const = 0;
82  
83     protected:
84      static constexpr int kIteratorArgId = 1;
85    };
86  
87    template <typename T>
88    class ItemIterator : public Iterator {
89     public:
NewIterator(StringPiece name,const T * items,lua_State * state)90      void NewIterator(StringPiece name, const T *items, lua_State *state) const {
91        lua_newtable(state);
92        luaL_newmetatable(state, name.data());
93        lua_pushlightuserdata(state, AsUserData(this));
94        lua_pushlightuserdata(state, AsUserData(items));
95        lua_pushcclosure(state, &Iterator::ItemCallback, 2);
96        lua_setfield(state, -2, kIndexKey);
97        lua_pushlightuserdata(state, AsUserData(this));
98        lua_pushlightuserdata(state, AsUserData(items));
99        lua_pushcclosure(state, &Iterator::LengthCallback, 2);
100        lua_setfield(state, -2, kLengthKey);
101        lua_pushlightuserdata(state, AsUserData(this));
102        lua_pushlightuserdata(state, AsUserData(items));
103        lua_pushcclosure(state, &Iterator::IteritemsCallback, 2);
104        lua_setfield(state, -2, kPairsKey);
105        lua_setmetatable(state, -2);
106      }
107  
Iteritems(lua_State * state)108      int Iteritems(lua_State *state) const override {
109        lua_pushlightuserdata(state, AsUserData(this));
110        lua_pushlightuserdata(
111            state, lua_touserdata(state, lua_upvalueindex(kItemsArgId)));
112        lua_pushnumber(state, 0);
113        lua_pushcclosure(state, &Iterator::NextCallback, 3);
114        return /*num results=*/1;
115      }
116  
Length(lua_State * state)117      int Length(lua_State *state) const override {
118        lua_pushinteger(state, FromUpValue<T *>(kItemsArgId, state)->size());
119        return /*num results=*/1;
120      }
121  
Next(lua_State * state)122      int Next(lua_State *state) const override {
123        return Next(FromUpValue<T *>(kItemsArgId, state),
124                    lua_tointeger(state, lua_upvalueindex(kIterValueArgId)),
125                    state);
126      }
127  
Next(const T * items,const int64 pos,lua_State * state)128      int Next(const T *items, const int64 pos, lua_State *state) const {
129        if (pos >= items->size()) {
130          return 0;
131        }
132  
133        // Update iterator value.
134        lua_pushnumber(state, pos + 1);
135        lua_replace(state, lua_upvalueindex(3));
136  
137        // Push key.
138        lua_pushinteger(state, pos + 1);
139  
140        // Push item.
141        return 1 + Item(items, pos, state);
142      }
143  
Item(lua_State * state)144      int Item(lua_State *state) const override {
145        const T *items = FromUpValue<T *>(kItemsArgId, state);
146        switch (lua_type(state, -1)) {
147          case LUA_TNUMBER: {
148            // Lua is one based, so adjust the index here.
149            const int64 index =
150                static_cast<int64>(lua_tonumber(state, /*idx=*/-1)) - 1;
151            if (index < 0 || index >= items->size()) {
152              TC3_LOG(ERROR) << "Invalid index: " << index;
153              lua_error(state);
154              return 0;
155            }
156            return Item(items, index, state);
157          }
158          case LUA_TSTRING: {
159            size_t key_length = 0;
160            const char *key = lua_tolstring(state, /*idx=*/-1, &key_length);
161            return Item(items, StringPiece(key, key_length), state);
162          }
163          default:
164            TC3_LOG(ERROR) << "Unexpected access type: " << lua_type(state, -1);
165            lua_error(state);
166            return 0;
167        }
168      }
169  
170      virtual int Item(const T *items, const int64 pos,
171                       lua_State *state) const = 0;
172  
Item(const T * items,StringPiece key,lua_State * state)173      virtual int Item(const T *items, StringPiece key, lua_State *state) const {
174        TC3_LOG(ERROR) << "Unexpected key access: " << key.ToString();
175        lua_error(state);
176        return 0;
177      }
178  
179     protected:
180      static constexpr int kItemsArgId = 2;
181      static constexpr int kIterValueArgId = 3;
182    };
183  
184    virtual ~LuaEnvironment();
185    LuaEnvironment();
186  
187    // Compile a lua snippet into binary bytecode.
188    // NOTE: The compiled bytecode might not be compatible across Lua versions
189    // and platforms.
190    bool Compile(StringPiece snippet, std::string *bytecode);
191  
192    typedef int (*CallbackHandler)(lua_State *);
193  
194    // Loads default libraries.
195    void LoadDefaultLibraries();
196  
197    // Provides a callback to Lua.
198    template <typename T, int (T::*handler)()>
Bind()199    void Bind() {
200      lua_pushlightuserdata(state_, static_cast<void *>(this));
201      lua_pushcclosure(state_, &Dispatch<T, handler>, 1);
202    }
203  
204    // Setup a named table that callsback whenever a member is accessed.
205    // This allows to lazily provide required information to the script.
206    template <typename T, int (T::*handler)()>
BindTable(const char * name)207    void BindTable(const char *name) {
208      lua_newtable(state_);
209      luaL_newmetatable(state_, name);
210      lua_pushlightuserdata(state_, static_cast<void *>(this));
211      lua_pushcclosure(state_, &Dispatch<T, handler>, 1);
212      lua_setfield(state_, -2, kIndexKey);
213      lua_setmetatable(state_, -2);
214    }
215  
216    void PushValue(const Variant &value);
217  
218    // Reads a string from the stack.
219    StringPiece ReadString(const int index) const;
220  
221    // Pushes a string to the stack.
222    void PushString(const StringPiece str);
223  
224    // Pushes a flatbuffer to the stack.
225    void PushFlatbuffer(const reflection::Schema *schema,
226                        const flatbuffers::Table *table);
227  
228    // Reads a flatbuffer from the stack.
229    int ReadFlatbuffer(ReflectiveFlatbuffer *buffer);
230  
231    // Runs a closure in protected mode.
232    // `func`: closure to run in protected mode.
233    // `num_lua_args`: number of arguments from the lua stack to process.
234    // `num_results`: number of result values pushed on the stack.
235    int RunProtected(const std::function<int()> &func, const int num_args = 0,
236                     const int num_results = 0);
237  
state()238    lua_State *state() const { return state_; }
239  
240   protected:
241    lua_State *state_;
242  
243   private:
244    // Auxiliary methods to expose (reflective) flatbuffer based data to Lua.
245    static void PushFlatbuffer(const char *name, const reflection::Schema *schema,
246                               const reflection::Object *type,
247                               const flatbuffers::Table *table, lua_State *state);
248    static int GetFieldCallback(lua_State *state);
249    static int GetField(const reflection::Schema *schema,
250                        const reflection::Object *type,
251                        const flatbuffers::Table *table, lua_State *state);
252  
253    template <typename T, int (T::*handler)()>
Dispatch(lua_State * state)254    static int Dispatch(lua_State *state) {
255      T *env = FromUpValue<T *>(1, state);
256      return ((*env).*handler)();
257    }
258  };
259  
260  bool Compile(StringPiece snippet, std::string *bytecode);
261  
262  }  // namespace libtextclassifier3
263  
264  #endif  // LIBTEXTCLASSIFIER_UTILS_LUA_UTILS_H_
265