1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "utils/regex-match.h"

#include <limits>
#include <memory>
#include <string>

#include "annotator/types.h"
#include "utils/lua-utils.h"

#ifdef __cplusplus
extern "C" {
#endif
#include "lauxlib.h"
#include "lualib.h"
#ifdef __cplusplus
}
#endif
32 
33 namespace libtextclassifier3 {
34 namespace {
35 
36 // Provide a lua environment for running regex match post verification.
37 // It sets up and exposes the match data as well as the context.
38 class LuaVerifier : private LuaEnvironment {
39  public:
40   static std::unique_ptr<LuaVerifier> Create(
41       const std::string& context, const std::string& verifier_code,
42       const UniLib::RegexMatcher* matcher);
43 
44   bool Verify(bool* result);
45 
46  private:
LuaVerifier(const std::string & context,const std::string & verifier_code,const UniLib::RegexMatcher * matcher)47   explicit LuaVerifier(const std::string& context,
48                        const std::string& verifier_code,
49                        const UniLib::RegexMatcher* matcher)
50       : context_(context), verifier_code_(verifier_code), matcher_(matcher) {}
51   bool Initialize();
52 
53   // Provides details of a capturing group to lua.
54   int GetCapturingGroup();
55 
56   const std::string& context_;
57   const std::string& verifier_code_;
58   const UniLib::RegexMatcher* matcher_;
59 };
60 
Initialize()61 bool LuaVerifier::Initialize() {
62   // Run protected to not lua panic in case of setup failure.
63   return RunProtected([this] {
64            LoadDefaultLibraries();
65 
66            // Expose context of the match as `context` global variable.
67            PushString(context_);
68            lua_setglobal(state_, "context");
69 
70            // Expose match array as `match` global variable.
71            // Each entry `match[i]` exposes the ith capturing group as:
72            //   * `begin`: span start
73            //   * `end`: span end
74            //   * `text`: the text
75            BindTable<LuaVerifier, &LuaVerifier::GetCapturingGroup>("match");
76            lua_setglobal(state_, "match");
77            return LUA_OK;
78          }) == LUA_OK;
79 }
80 
Create(const std::string & context,const std::string & verifier_code,const UniLib::RegexMatcher * matcher)81 std::unique_ptr<LuaVerifier> LuaVerifier::Create(
82     const std::string& context, const std::string& verifier_code,
83     const UniLib::RegexMatcher* matcher) {
84   auto verifier = std::unique_ptr<LuaVerifier>(
85       new LuaVerifier(context, verifier_code, matcher));
86   if (!verifier->Initialize()) {
87     TC3_LOG(ERROR) << "Could not initialize lua environment.";
88     return nullptr;
89   }
90   return verifier;
91 }
92 
GetCapturingGroup()93 int LuaVerifier::GetCapturingGroup() {
94   if (lua_type(state_, /*idx=*/-1) != LUA_TNUMBER) {
95     TC3_LOG(ERROR) << "Unexpected type for match group lookup: "
96                    << lua_type(state_, /*idx=*/-1);
97     lua_error(state_);
98     return 0;
99   }
100   const int group_id = static_cast<int>(lua_tonumber(state_, /*idx=*/-1));
101   int status = UniLib::RegexMatcher::kNoError;
102   const CodepointSpan span = {matcher_->Start(group_id, &status),
103                               matcher_->End(group_id, &status)};
104   std::string text = matcher_->Group(group_id, &status).ToUTF8String();
105   if (status != UniLib::RegexMatcher::kNoError) {
106     TC3_LOG(ERROR) << "Could not extract span from capturing group.";
107     lua_error(state_);
108     return 0;
109   }
110   lua_newtable(state_);
111   lua_pushinteger(state_, span.first);
112   lua_setfield(state_, /*idx=*/-2, "begin");
113   lua_pushinteger(state_, span.second);
114   lua_setfield(state_, /*idx=*/-2, "end");
115   PushString(text);
116   lua_setfield(state_, /*idx=*/-2, "text");
117   return 1;
118 }
119 
Verify(bool * result)120 bool LuaVerifier::Verify(bool* result) {
121   if (luaL_loadbuffer(state_, verifier_code_.data(), verifier_code_.size(),
122                       /*name=*/nullptr) != LUA_OK) {
123     TC3_LOG(ERROR) << "Could not load verifier snippet.";
124     return false;
125   }
126 
127   if (lua_pcall(state_, /*nargs=*/0, /*nresults=*/1, /*errfunc=*/0) != LUA_OK) {
128     TC3_LOG(ERROR) << "Could not run verifier snippet.";
129     return false;
130   }
131 
132   if (RunProtected(
133           [this, result] {
134             if (lua_type(state_, /*idx=*/-1) != LUA_TBOOLEAN) {
135               TC3_LOG(ERROR) << "Unexpected verification result type: "
136                              << lua_type(state_, /*idx=*/-1);
137               lua_error(state_);
138               return LUA_ERRRUN;
139             }
140             *result = lua_toboolean(state_, /*idx=*/-1);
141             return LUA_OK;
142           },
143           /*num_args=*/1) != LUA_OK) {
144     TC3_LOG(ERROR) << "Could not read lua result.";
145     return false;
146   }
147   return true;
148 }
149 
150 }  // namespace
151 
SetFieldFromCapturingGroup(const int group_id,const FlatbufferFieldPath * field_path,const UniLib::RegexMatcher * matcher,ReflectiveFlatbuffer * flatbuffer)152 bool SetFieldFromCapturingGroup(const int group_id,
153                                 const FlatbufferFieldPath* field_path,
154                                 const UniLib::RegexMatcher* matcher,
155                                 ReflectiveFlatbuffer* flatbuffer) {
156   int status = UniLib::RegexMatcher::kNoError;
157   std::string group_text = matcher->Group(group_id, &status).ToUTF8String();
158   if (status != UniLib::RegexMatcher::kNoError || group_text.empty()) {
159     return false;
160   }
161   return flatbuffer->ParseAndSet(field_path, group_text);
162 }
163 
VerifyMatch(const std::string & context,const UniLib::RegexMatcher * matcher,const std::string & lua_verifier_code)164 bool VerifyMatch(const std::string& context,
165                  const UniLib::RegexMatcher* matcher,
166                  const std::string& lua_verifier_code) {
167   bool status = false;
168   auto verifier = LuaVerifier::Create(context, lua_verifier_code, matcher);
169   if (verifier == nullptr) {
170     TC3_LOG(ERROR) << "Could not create verifier.";
171     return false;
172   }
173   if (!verifier->Verify(&status)) {
174     TC3_LOG(ERROR) << "Could not create verifier.";
175     return false;
176   }
177   return status;
178 }
179 
180 }  // namespace libtextclassifier3
181