1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "actions/lua-ranker.h"
18
19 #include "utils/base/logging.h"
20 #include "utils/lua-utils.h"
21
22 #ifdef __cplusplus
23 extern "C" {
24 #endif
25 #include "lauxlib.h"
26 #include "lualib.h"
27 #ifdef __cplusplus
28 }
29 #endif
30
31 namespace libtextclassifier3 {
32
33 std::unique_ptr<ActionsSuggestionsLuaRanker>
Create(const Conversation & conversation,const std::string & ranker_code,const reflection::Schema * entity_data_schema,const reflection::Schema * annotations_entity_data_schema,ActionsSuggestionsResponse * response)34 ActionsSuggestionsLuaRanker::Create(
35 const Conversation& conversation, const std::string& ranker_code,
36 const reflection::Schema* entity_data_schema,
37 const reflection::Schema* annotations_entity_data_schema,
38 ActionsSuggestionsResponse* response) {
39 auto ranker = std::unique_ptr<ActionsSuggestionsLuaRanker>(
40 new ActionsSuggestionsLuaRanker(
41 conversation, ranker_code, entity_data_schema,
42 annotations_entity_data_schema, response));
43 if (!ranker->Initialize()) {
44 TC3_LOG(ERROR) << "Could not initialize lua environment for ranker.";
45 return nullptr;
46 }
47 return ranker;
48 }
49
Initialize()50 bool ActionsSuggestionsLuaRanker::Initialize() {
51 return RunProtected([this] {
52 LoadDefaultLibraries();
53
54 // Expose generated actions.
55 PushActions(&response_->actions, actions_entity_data_schema_,
56 annotations_entity_data_schema_);
57 lua_setglobal(state_, "actions");
58
59 // Expose conversation message stream.
60 PushConversation(&conversation_.messages,
61 annotations_entity_data_schema_);
62 lua_setglobal(state_, "messages");
63 return LUA_OK;
64 }) == LUA_OK;
65 }
66
ReadActionsRanking()67 int ActionsSuggestionsLuaRanker::ReadActionsRanking() {
68 if (lua_type(state_, /*idx=*/-1) != LUA_TTABLE) {
69 TC3_LOG(ERROR) << "Expected actions table, got: "
70 << lua_type(state_, /*idx=*/-1);
71 lua_pop(state_, 1);
72 lua_error(state_);
73 return LUA_ERRRUN;
74 }
75 std::vector<ActionSuggestion> ranked_actions;
76 lua_pushnil(state_);
77 while (Next(/*index=*/-2)) {
78 const int action_id = Read<int>(/*index=*/-1) - 1;
79 lua_pop(state_, 1);
80 if (action_id < 0 || action_id >= response_->actions.size()) {
81 TC3_LOG(ERROR) << "Invalid action index: " << action_id;
82 lua_error(state_);
83 return LUA_ERRRUN;
84 }
85 ranked_actions.push_back(response_->actions[action_id]);
86 }
87 lua_pop(state_, 1);
88 response_->actions = ranked_actions;
89 return LUA_OK;
90 }
91
RankActions()92 bool ActionsSuggestionsLuaRanker::RankActions() {
93 if (response_->actions.empty()) {
94 // Nothing to do.
95 return true;
96 }
97
98 if (luaL_loadbuffer(state_, ranker_code_.data(), ranker_code_.size(),
99 /*name=*/nullptr) != LUA_OK) {
100 TC3_LOG(ERROR) << "Could not load compiled ranking snippet.";
101 return false;
102 }
103
104 if (lua_pcall(state_, /*nargs=*/0, /*nresults=*/1, /*errfunc=*/0) != LUA_OK) {
105 TC3_LOG(ERROR) << "Could not run ranking snippet.";
106 return false;
107 }
108
109 if (RunProtected([this] { return ReadActionsRanking(); }, /*num_args=*/1) !=
110 LUA_OK) {
111 TC3_LOG(ERROR) << "Could not read lua result.";
112 return false;
113 }
114 return true;
115 }
116
117 } // namespace libtextclassifier3
118