1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "actions/zlib-utils.h"
18 
19 #include <memory>
20 
21 #include "utils/base/logging.h"
22 #include "utils/intents/zlib-utils.h"
23 #include "utils/resources.h"
24 
25 namespace libtextclassifier3 {
26 
27 // Compress rule fields in the model.
CompressActionsModel(ActionsModelT * model)28 bool CompressActionsModel(ActionsModelT* model) {
29   std::unique_ptr<ZlibCompressor> zlib_compressor = ZlibCompressor::Instance();
30   if (!zlib_compressor) {
31     TC3_LOG(ERROR) << "Cannot compress model.";
32     return false;
33   }
34 
35   // Compress regex rules.
36   if (model->rules != nullptr) {
37     for (int i = 0; i < model->rules->rule.size(); i++) {
38       RulesModel_::RuleT* rule = model->rules->rule[i].get();
39       rule->compressed_pattern.reset(new CompressedBufferT);
40       zlib_compressor->Compress(rule->pattern, rule->compressed_pattern.get());
41       rule->pattern.clear();
42     }
43   }
44 
45   if (model->low_confidence_rules != nullptr) {
46     for (int i = 0; i < model->low_confidence_rules->rule.size(); i++) {
47       RulesModel_::RuleT* rule = model->low_confidence_rules->rule[i].get();
48       if (!rule->pattern.empty()) {
49         rule->compressed_pattern.reset(new CompressedBufferT);
50         zlib_compressor->Compress(rule->pattern,
51                                   rule->compressed_pattern.get());
52         rule->pattern.clear();
53       }
54       if (!rule->output_pattern.empty()) {
55         rule->compressed_output_pattern.reset(new CompressedBufferT);
56         zlib_compressor->Compress(rule->pattern,
57                                   rule->compressed_output_pattern.get());
58         rule->output_pattern.clear();
59       }
60     }
61   }
62 
63   if (!model->lua_actions_script.empty()) {
64     model->compressed_lua_actions_script.reset(new CompressedBufferT);
65     zlib_compressor->Compress(model->lua_actions_script,
66                               model->compressed_lua_actions_script.get());
67   }
68 
69   if (model->ranking_options != nullptr &&
70       !model->ranking_options->lua_ranking_script.empty()) {
71     model->ranking_options->compressed_lua_ranking_script.reset(
72         new CompressedBufferT);
73     zlib_compressor->Compress(
74         model->ranking_options->lua_ranking_script,
75         model->ranking_options->compressed_lua_ranking_script.get());
76   }
77 
78   // Compress resources.
79   if (model->resources != nullptr) {
80     CompressResources(model->resources.get());
81   }
82 
83   // Compress intent generator.
84   if (model->android_intent_options != nullptr) {
85     CompressIntentModel(model->android_intent_options.get());
86   }
87 
88   return true;
89 }
90 
DecompressActionsModel(ActionsModelT * model)91 bool DecompressActionsModel(ActionsModelT* model) {
92   std::unique_ptr<ZlibDecompressor> zlib_decompressor =
93       ZlibDecompressor::Instance();
94   if (!zlib_decompressor) {
95     TC3_LOG(ERROR) << "Cannot initialize decompressor.";
96     return false;
97   }
98 
99   // Decompress regex rules.
100   if (model->rules != nullptr) {
101     for (int i = 0; i < model->rules->rule.size(); i++) {
102       RulesModel_::RuleT* rule = model->rules->rule[i].get();
103       if (!zlib_decompressor->MaybeDecompress(rule->compressed_pattern.get(),
104                                               &rule->pattern)) {
105         TC3_LOG(ERROR) << "Cannot decompress pattern: " << i;
106         return false;
107       }
108       rule->compressed_pattern.reset(nullptr);
109     }
110   }
111 
112   // Decompress low confidence rules.
113   if (model->low_confidence_rules != nullptr) {
114     for (int i = 0; i < model->low_confidence_rules->rule.size(); i++) {
115       RulesModel_::RuleT* rule = model->low_confidence_rules->rule[i].get();
116       if (!zlib_decompressor->MaybeDecompress(rule->compressed_pattern.get(),
117                                               &rule->pattern)) {
118         TC3_LOG(ERROR) << "Cannot decompress pattern: " << i;
119         return false;
120       }
121       if (!zlib_decompressor->MaybeDecompress(
122               rule->compressed_output_pattern.get(), &rule->output_pattern)) {
123         TC3_LOG(ERROR) << "Cannot decompress pattern: " << i;
124         return false;
125       }
126       rule->compressed_pattern.reset(nullptr);
127       rule->compressed_output_pattern.reset(nullptr);
128     }
129   }
130 
131   if (!zlib_decompressor->MaybeDecompress(
132           model->compressed_lua_actions_script.get(),
133           &model->lua_actions_script)) {
134     TC3_LOG(ERROR) << "Cannot decompress actions script.";
135     return false;
136   }
137 
138   if (model->ranking_options != nullptr &&
139       !zlib_decompressor->MaybeDecompress(
140           model->ranking_options->compressed_lua_ranking_script.get(),
141           &model->ranking_options->lua_ranking_script)) {
142     TC3_LOG(ERROR) << "Cannot decompress actions script.";
143     return false;
144   }
145 
146   return true;
147 }
148 
CompressSerializedActionsModel(const std::string & model)149 std::string CompressSerializedActionsModel(const std::string& model) {
150   std::unique_ptr<ActionsModelT> unpacked_model =
151       UnPackActionsModel(model.c_str());
152   TC3_CHECK(unpacked_model != nullptr);
153   TC3_CHECK(CompressActionsModel(unpacked_model.get()));
154   flatbuffers::FlatBufferBuilder builder;
155   FinishActionsModelBuffer(builder,
156                            ActionsModel::Pack(builder, unpacked_model.get()));
157   return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
158                      builder.GetSize());
159 }
160 
GetUncompressedString(const flatbuffers::String * uncompressed_buffer,const CompressedBuffer * compressed_buffer,ZlibDecompressor * decompressor,std::string * out)161 bool GetUncompressedString(const flatbuffers::String* uncompressed_buffer,
162                            const CompressedBuffer* compressed_buffer,
163                            ZlibDecompressor* decompressor, std::string* out) {
164   if (uncompressed_buffer == nullptr && compressed_buffer == nullptr) {
165     out->clear();
166     return true;
167   }
168 
169   return decompressor->MaybeDecompressOptionallyCompressedBuffer(
170       uncompressed_buffer, compressed_buffer, out);
171 }
172 
173 }  // namespace libtextclassifier3
174