1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "utils/regex-match.h"
18
19 #include <memory>
20
21 #include "annotator/types.h"
22 #include "utils/lua-utils.h"
23
24 #ifdef __cplusplus
25 extern "C" {
26 #endif
27 #include "lauxlib.h"
28 #include "lualib.h"
29 #ifdef __cplusplus
30 }
31 #endif
32
33 namespace libtextclassifier3 {
34 namespace {
35
36 // Provide a lua environment for running regex match post verification.
37 // It sets up and exposes the match data as well as the context.
38 class LuaVerifier : private LuaEnvironment {
39 public:
40 static std::unique_ptr<LuaVerifier> Create(
41 const std::string& context, const std::string& verifier_code,
42 const UniLib::RegexMatcher* matcher);
43
44 bool Verify(bool* result);
45
46 private:
LuaVerifier(const std::string & context,const std::string & verifier_code,const UniLib::RegexMatcher * matcher)47 explicit LuaVerifier(const std::string& context,
48 const std::string& verifier_code,
49 const UniLib::RegexMatcher* matcher)
50 : context_(context), verifier_code_(verifier_code), matcher_(matcher) {}
51 bool Initialize();
52
53 // Provides details of a capturing group to lua.
54 int GetCapturingGroup();
55
56 const std::string& context_;
57 const std::string& verifier_code_;
58 const UniLib::RegexMatcher* matcher_;
59 };
60
Initialize()61 bool LuaVerifier::Initialize() {
62 // Run protected to not lua panic in case of setup failure.
63 return RunProtected([this] {
64 LoadDefaultLibraries();
65
66 // Expose context of the match as `context` global variable.
67 PushString(context_);
68 lua_setglobal(state_, "context");
69
70 // Expose match array as `match` global variable.
71 // Each entry `match[i]` exposes the ith capturing group as:
72 // * `begin`: span start
73 // * `end`: span end
74 // * `text`: the text
75 BindTable<LuaVerifier, &LuaVerifier::GetCapturingGroup>("match");
76 lua_setglobal(state_, "match");
77 return LUA_OK;
78 }) == LUA_OK;
79 }
80
Create(const std::string & context,const std::string & verifier_code,const UniLib::RegexMatcher * matcher)81 std::unique_ptr<LuaVerifier> LuaVerifier::Create(
82 const std::string& context, const std::string& verifier_code,
83 const UniLib::RegexMatcher* matcher) {
84 auto verifier = std::unique_ptr<LuaVerifier>(
85 new LuaVerifier(context, verifier_code, matcher));
86 if (!verifier->Initialize()) {
87 TC3_LOG(ERROR) << "Could not initialize lua environment.";
88 return nullptr;
89 }
90 return verifier;
91 }
92
GetCapturingGroup()93 int LuaVerifier::GetCapturingGroup() {
94 if (lua_type(state_, /*idx=*/-1) != LUA_TNUMBER) {
95 TC3_LOG(ERROR) << "Unexpected type for match group lookup: "
96 << lua_type(state_, /*idx=*/-1);
97 lua_error(state_);
98 return 0;
99 }
100 const int group_id = static_cast<int>(lua_tonumber(state_, /*idx=*/-1));
101 int status = UniLib::RegexMatcher::kNoError;
102 const CodepointSpan span = {matcher_->Start(group_id, &status),
103 matcher_->End(group_id, &status)};
104 std::string text = matcher_->Group(group_id, &status).ToUTF8String();
105 if (status != UniLib::RegexMatcher::kNoError) {
106 TC3_LOG(ERROR) << "Could not extract span from capturing group.";
107 lua_error(state_);
108 return 0;
109 }
110 lua_newtable(state_);
111 lua_pushinteger(state_, span.first);
112 lua_setfield(state_, /*idx=*/-2, "begin");
113 lua_pushinteger(state_, span.second);
114 lua_setfield(state_, /*idx=*/-2, "end");
115 PushString(text);
116 lua_setfield(state_, /*idx=*/-2, "text");
117 return 1;
118 }
119
Verify(bool * result)120 bool LuaVerifier::Verify(bool* result) {
121 if (luaL_loadbuffer(state_, verifier_code_.data(), verifier_code_.size(),
122 /*name=*/nullptr) != LUA_OK) {
123 TC3_LOG(ERROR) << "Could not load verifier snippet.";
124 return false;
125 }
126
127 if (lua_pcall(state_, /*nargs=*/0, /*nresults=*/1, /*errfunc=*/0) != LUA_OK) {
128 TC3_LOG(ERROR) << "Could not run verifier snippet.";
129 return false;
130 }
131
132 if (RunProtected(
133 [this, result] {
134 if (lua_type(state_, /*idx=*/-1) != LUA_TBOOLEAN) {
135 TC3_LOG(ERROR) << "Unexpected verification result type: "
136 << lua_type(state_, /*idx=*/-1);
137 lua_error(state_);
138 return LUA_ERRRUN;
139 }
140 *result = lua_toboolean(state_, /*idx=*/-1);
141 return LUA_OK;
142 },
143 /*num_args=*/1) != LUA_OK) {
144 TC3_LOG(ERROR) << "Could not read lua result.";
145 return false;
146 }
147 return true;
148 }
149
150 } // namespace
151
SetFieldFromCapturingGroup(const int group_id,const FlatbufferFieldPath * field_path,const UniLib::RegexMatcher * matcher,ReflectiveFlatbuffer * flatbuffer)152 bool SetFieldFromCapturingGroup(const int group_id,
153 const FlatbufferFieldPath* field_path,
154 const UniLib::RegexMatcher* matcher,
155 ReflectiveFlatbuffer* flatbuffer) {
156 int status = UniLib::RegexMatcher::kNoError;
157 std::string group_text = matcher->Group(group_id, &status).ToUTF8String();
158 if (status != UniLib::RegexMatcher::kNoError || group_text.empty()) {
159 return false;
160 }
161 return flatbuffer->ParseAndSet(field_path, group_text);
162 }
163
VerifyMatch(const std::string & context,const UniLib::RegexMatcher * matcher,const std::string & lua_verifier_code)164 bool VerifyMatch(const std::string& context,
165 const UniLib::RegexMatcher* matcher,
166 const std::string& lua_verifier_code) {
167 bool status = false;
168 auto verifier = LuaVerifier::Create(context, lua_verifier_code, matcher);
169 if (verifier == nullptr) {
170 TC3_LOG(ERROR) << "Could not create verifier.";
171 return false;
172 }
173 if (!verifier->Verify(&status)) {
174 TC3_LOG(ERROR) << "Could not create verifier.";
175 return false;
176 }
177 return status;
178 }
179
180 } // namespace libtextclassifier3
181