1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "actions/lua-ranker.h"

#include <cstddef>
#include <utility>

#include "utils/base/logging.h"
#include "utils/lua-utils.h"
21 
22 #ifdef __cplusplus
23 extern "C" {
24 #endif
25 #include "lauxlib.h"
26 #include "lualib.h"
27 #ifdef __cplusplus
28 }
29 #endif
30 
31 namespace libtextclassifier3 {
32 
33 std::unique_ptr<ActionsSuggestionsLuaRanker>
Create(const Conversation & conversation,const std::string & ranker_code,const reflection::Schema * entity_data_schema,const reflection::Schema * annotations_entity_data_schema,ActionsSuggestionsResponse * response)34 ActionsSuggestionsLuaRanker::Create(
35     const Conversation& conversation, const std::string& ranker_code,
36     const reflection::Schema* entity_data_schema,
37     const reflection::Schema* annotations_entity_data_schema,
38     ActionsSuggestionsResponse* response) {
39   auto ranker = std::unique_ptr<ActionsSuggestionsLuaRanker>(
40       new ActionsSuggestionsLuaRanker(
41           conversation, ranker_code, entity_data_schema,
42           annotations_entity_data_schema, response));
43   if (!ranker->Initialize()) {
44     TC3_LOG(ERROR) << "Could not initialize lua environment for ranker.";
45     return nullptr;
46   }
47   return ranker;
48 }
49 
Initialize()50 bool ActionsSuggestionsLuaRanker::Initialize() {
51   return RunProtected([this] {
52            LoadDefaultLibraries();
53 
54            // Expose generated actions.
55            PushActions(&response_->actions, actions_entity_data_schema_,
56                        annotations_entity_data_schema_);
57            lua_setglobal(state_, "actions");
58 
59            // Expose conversation message stream.
60            PushConversation(&conversation_.messages,
61                             annotations_entity_data_schema_);
62            lua_setglobal(state_, "messages");
63            return LUA_OK;
64          }) == LUA_OK;
65 }
66 
ReadActionsRanking()67 int ActionsSuggestionsLuaRanker::ReadActionsRanking() {
68   if (lua_type(state_, /*idx=*/-1) != LUA_TTABLE) {
69     TC3_LOG(ERROR) << "Expected actions table, got: "
70                    << lua_type(state_, /*idx=*/-1);
71     lua_pop(state_, 1);
72     lua_error(state_);
73     return LUA_ERRRUN;
74   }
75   std::vector<ActionSuggestion> ranked_actions;
76   lua_pushnil(state_);
77   while (Next(/*index=*/-2)) {
78     const int action_id = Read<int>(/*index=*/-1) - 1;
79     lua_pop(state_, 1);
80     if (action_id < 0 || action_id >= response_->actions.size()) {
81       TC3_LOG(ERROR) << "Invalid action index: " << action_id;
82       lua_error(state_);
83       return LUA_ERRRUN;
84     }
85     ranked_actions.push_back(response_->actions[action_id]);
86   }
87   lua_pop(state_, 1);
88   response_->actions = ranked_actions;
89   return LUA_OK;
90 }
91 
RankActions()92 bool ActionsSuggestionsLuaRanker::RankActions() {
93   if (response_->actions.empty()) {
94     // Nothing to do.
95     return true;
96   }
97 
98   if (luaL_loadbuffer(state_, ranker_code_.data(), ranker_code_.size(),
99                       /*name=*/nullptr) != LUA_OK) {
100     TC3_LOG(ERROR) << "Could not load compiled ranking snippet.";
101     return false;
102   }
103 
104   if (lua_pcall(state_, /*nargs=*/0, /*nresults=*/1, /*errfunc=*/0) != LUA_OK) {
105     TC3_LOG(ERROR) << "Could not run ranking snippet.";
106     return false;
107   }
108 
109   if (RunProtected([this] { return ReadActionsRanking(); }, /*num_args=*/1) !=
110       LUA_OK) {
111     TC3_LOG(ERROR) << "Could not read lua result.";
112     return false;
113   }
114   return true;
115 }
116 
117 }  // namespace libtextclassifier3
118