1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef LIBTEXTCLASSIFIER_UTILS_LUA_UTILS_H_ 18 #define LIBTEXTCLASSIFIER_UTILS_LUA_UTILS_H_ 19 20 #include <functional> 21 #include <vector> 22 23 #include "utils/flatbuffers.h" 24 #include "utils/strings/stringpiece.h" 25 #include "utils/variant.h" 26 #include "flatbuffers/reflection_generated.h" 27 28 #ifdef __cplusplus 29 extern "C" { 30 #endif 31 #include "lauxlib.h" 32 #include "lua.h" 33 #include "lualib.h" 34 #ifdef __cplusplus 35 } 36 #endif 37 38 namespace libtextclassifier3 { 39 40 static constexpr const char *kLengthKey = "__len"; 41 static constexpr const char *kPairsKey = "__pairs"; 42 static constexpr const char *kIndexKey = "__index"; 43 44 // Casts to the lua user data type. 45 template <typename T> AsUserData(const T * value)46 void *AsUserData(const T *value) { 47 return static_cast<void *>(const_cast<T *>(value)); 48 } 49 template <typename T> AsUserData(const T value)50 void *AsUserData(const T value) { 51 return reinterpret_cast<void *>(value); 52 } 53 54 // Retrieves up-values. 55 template <typename T> FromUpValue(const int index,lua_State * state)56 T FromUpValue(const int index, lua_State *state) { 57 return static_cast<T>(lua_touserdata(state, lua_upvalueindex(index))); 58 } 59 60 class LuaEnvironment { 61 public: 62 // Wrapper for handling an iterator. 63 class Iterator { 64 public: ~Iterator()65 virtual ~Iterator() {} 66 static int NextCallback(lua_State *state); 67 static int LengthCallback(lua_State *state); 68 static int ItemCallback(lua_State *state); 69 static int IteritemsCallback(lua_State *state); 70 71 // Called when the next element of an iterator is fetched. 72 virtual int Next(lua_State *state) const = 0; 73 74 // Called when the length of the iterator is queried. 75 virtual int Length(lua_State *state) const = 0; 76 77 // Called when an item is queried. 78 virtual int Item(lua_State *state) const = 0; 79 80 // Called when a new iterator is started. 81 virtual int Iteritems(lua_State *state) const = 0; 82 83 protected: 84 static constexpr int kIteratorArgId = 1; 85 }; 86 87 template <typename T> 88 class ItemIterator : public Iterator { 89 public: NewIterator(StringPiece name,const T * items,lua_State * state)90 void NewIterator(StringPiece name, const T *items, lua_State *state) const { 91 lua_newtable(state); 92 luaL_newmetatable(state, name.data()); 93 lua_pushlightuserdata(state, AsUserData(this)); 94 lua_pushlightuserdata(state, AsUserData(items)); 95 lua_pushcclosure(state, &Iterator::ItemCallback, 2); 96 lua_setfield(state, -2, kIndexKey); 97 lua_pushlightuserdata(state, AsUserData(this)); 98 lua_pushlightuserdata(state, AsUserData(items)); 99 lua_pushcclosure(state, &Iterator::LengthCallback, 2); 100 lua_setfield(state, -2, kLengthKey); 101 lua_pushlightuserdata(state, AsUserData(this)); 102 lua_pushlightuserdata(state, AsUserData(items)); 103 lua_pushcclosure(state, &Iterator::IteritemsCallback, 2); 104 lua_setfield(state, -2, kPairsKey); 105 lua_setmetatable(state, -2); 106 } 107 Iteritems(lua_State * state)108 int Iteritems(lua_State *state) const override { 109 lua_pushlightuserdata(state, AsUserData(this)); 110 lua_pushlightuserdata( 111 state, lua_touserdata(state, lua_upvalueindex(kItemsArgId))); 112 lua_pushnumber(state, 0); 113 lua_pushcclosure(state, &Iterator::NextCallback, 3); 114 return /*num results=*/1; 115 } 116 Length(lua_State * state)117 int Length(lua_State *state) const override { 118 lua_pushinteger(state, FromUpValue<T *>(kItemsArgId, state)->size()); 119 return /*num results=*/1; 120 } 121 Next(lua_State * state)122 int Next(lua_State *state) const override { 123 return Next(FromUpValue<T *>(kItemsArgId, state), 124 lua_tointeger(state, lua_upvalueindex(kIterValueArgId)), 125 state); 126 } 127 Next(const T * items,const int64 pos,lua_State * state)128 int Next(const T *items, const int64 pos, lua_State *state) const { 129 if (pos >= items->size()) { 130 return 0; 131 } 132 133 // Update iterator value. 134 lua_pushnumber(state, pos + 1); 135 lua_replace(state, lua_upvalueindex(3)); 136 137 // Push key. 138 lua_pushinteger(state, pos + 1); 139 140 // Push item. 141 return 1 + Item(items, pos, state); 142 } 143 Item(lua_State * state)144 int Item(lua_State *state) const override { 145 const T *items = FromUpValue<T *>(kItemsArgId, state); 146 switch (lua_type(state, -1)) { 147 case LUA_TNUMBER: { 148 // Lua is one based, so adjust the index here. 149 const int64 index = 150 static_cast<int64>(lua_tonumber(state, /*idx=*/-1)) - 1; 151 if (index < 0 || index >= items->size()) { 152 TC3_LOG(ERROR) << "Invalid index: " << index; 153 lua_error(state); 154 return 0; 155 } 156 return Item(items, index, state); 157 } 158 case LUA_TSTRING: { 159 size_t key_length = 0; 160 const char *key = lua_tolstring(state, /*idx=*/-1, &key_length); 161 return Item(items, StringPiece(key, key_length), state); 162 } 163 default: 164 TC3_LOG(ERROR) << "Unexpected access type: " << lua_type(state, -1); 165 lua_error(state); 166 return 0; 167 } 168 } 169 170 virtual int Item(const T *items, const int64 pos, 171 lua_State *state) const = 0; 172 Item(const T * items,StringPiece key,lua_State * state)173 virtual int Item(const T *items, StringPiece key, lua_State *state) const { 174 TC3_LOG(ERROR) << "Unexpected key access: " << key.ToString(); 175 lua_error(state); 176 return 0; 177 } 178 179 protected: 180 static constexpr int kItemsArgId = 2; 181 static constexpr int kIterValueArgId = 3; 182 }; 183 184 virtual ~LuaEnvironment(); 185 LuaEnvironment(); 186 187 // Compile a lua snippet into binary bytecode. 188 // NOTE: The compiled bytecode might not be compatible across Lua versions 189 // and platforms. 190 bool Compile(StringPiece snippet, std::string *bytecode); 191 192 typedef int (*CallbackHandler)(lua_State *); 193 194 // Loads default libraries. 195 void LoadDefaultLibraries(); 196 197 // Provides a callback to Lua. 198 template <typename T, int (T::*handler)()> Bind()199 void Bind() { 200 lua_pushlightuserdata(state_, static_cast<void *>(this)); 201 lua_pushcclosure(state_, &Dispatch<T, handler>, 1); 202 } 203 204 // Setup a named table that callsback whenever a member is accessed. 205 // This allows to lazily provide required information to the script. 206 template <typename T, int (T::*handler)()> BindTable(const char * name)207 void BindTable(const char *name) { 208 lua_newtable(state_); 209 luaL_newmetatable(state_, name); 210 lua_pushlightuserdata(state_, static_cast<void *>(this)); 211 lua_pushcclosure(state_, &Dispatch<T, handler>, 1); 212 lua_setfield(state_, -2, kIndexKey); 213 lua_setmetatable(state_, -2); 214 } 215 216 void PushValue(const Variant &value); 217 218 // Reads a string from the stack. 219 StringPiece ReadString(const int index) const; 220 221 // Pushes a string to the stack. 222 void PushString(const StringPiece str); 223 224 // Pushes a flatbuffer to the stack. 225 void PushFlatbuffer(const reflection::Schema *schema, 226 const flatbuffers::Table *table); 227 228 // Reads a flatbuffer from the stack. 229 int ReadFlatbuffer(ReflectiveFlatbuffer *buffer); 230 231 // Runs a closure in protected mode. 232 // `func`: closure to run in protected mode. 233 // `num_lua_args`: number of arguments from the lua stack to process. 234 // `num_results`: number of result values pushed on the stack. 235 int RunProtected(const std::function<int()> &func, const int num_args = 0, 236 const int num_results = 0); 237 state()238 lua_State *state() const { return state_; } 239 240 protected: 241 lua_State *state_; 242 243 private: 244 // Auxiliary methods to expose (reflective) flatbuffer based data to Lua. 245 static void PushFlatbuffer(const char *name, const reflection::Schema *schema, 246 const reflection::Object *type, 247 const flatbuffers::Table *table, lua_State *state); 248 static int GetFieldCallback(lua_State *state); 249 static int GetField(const reflection::Schema *schema, 250 const reflection::Object *type, 251 const flatbuffers::Table *table, lua_State *state); 252 253 template <typename T, int (T::*handler)()> Dispatch(lua_State * state)254 static int Dispatch(lua_State *state) { 255 T *env = FromUpValue<T *>(1, state); 256 return ((*env).*handler)(); 257 } 258 }; 259 260 bool Compile(StringPiece snippet, std::string *bytecode); 261 262 } // namespace libtextclassifier3 263 264 #endif // LIBTEXTCLASSIFIER_UTILS_LUA_UTILS_H_ 265