1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "actions/zlib-utils.h"
18
19 #include <memory>
20
21 #include "utils/base/logging.h"
22 #include "utils/intents/zlib-utils.h"
23 #include "utils/resources.h"
24
25 namespace libtextclassifier3 {
26
27 // Compress rule fields in the model.
CompressActionsModel(ActionsModelT * model)28 bool CompressActionsModel(ActionsModelT* model) {
29 std::unique_ptr<ZlibCompressor> zlib_compressor = ZlibCompressor::Instance();
30 if (!zlib_compressor) {
31 TC3_LOG(ERROR) << "Cannot compress model.";
32 return false;
33 }
34
35 // Compress regex rules.
36 if (model->rules != nullptr) {
37 for (int i = 0; i < model->rules->rule.size(); i++) {
38 RulesModel_::RuleT* rule = model->rules->rule[i].get();
39 rule->compressed_pattern.reset(new CompressedBufferT);
40 zlib_compressor->Compress(rule->pattern, rule->compressed_pattern.get());
41 rule->pattern.clear();
42 }
43 }
44
45 if (model->low_confidence_rules != nullptr) {
46 for (int i = 0; i < model->low_confidence_rules->rule.size(); i++) {
47 RulesModel_::RuleT* rule = model->low_confidence_rules->rule[i].get();
48 if (!rule->pattern.empty()) {
49 rule->compressed_pattern.reset(new CompressedBufferT);
50 zlib_compressor->Compress(rule->pattern,
51 rule->compressed_pattern.get());
52 rule->pattern.clear();
53 }
54 if (!rule->output_pattern.empty()) {
55 rule->compressed_output_pattern.reset(new CompressedBufferT);
56 zlib_compressor->Compress(rule->pattern,
57 rule->compressed_output_pattern.get());
58 rule->output_pattern.clear();
59 }
60 }
61 }
62
63 if (!model->lua_actions_script.empty()) {
64 model->compressed_lua_actions_script.reset(new CompressedBufferT);
65 zlib_compressor->Compress(model->lua_actions_script,
66 model->compressed_lua_actions_script.get());
67 }
68
69 if (model->ranking_options != nullptr &&
70 !model->ranking_options->lua_ranking_script.empty()) {
71 model->ranking_options->compressed_lua_ranking_script.reset(
72 new CompressedBufferT);
73 zlib_compressor->Compress(
74 model->ranking_options->lua_ranking_script,
75 model->ranking_options->compressed_lua_ranking_script.get());
76 }
77
78 // Compress resources.
79 if (model->resources != nullptr) {
80 CompressResources(model->resources.get());
81 }
82
83 // Compress intent generator.
84 if (model->android_intent_options != nullptr) {
85 CompressIntentModel(model->android_intent_options.get());
86 }
87
88 return true;
89 }
90
DecompressActionsModel(ActionsModelT * model)91 bool DecompressActionsModel(ActionsModelT* model) {
92 std::unique_ptr<ZlibDecompressor> zlib_decompressor =
93 ZlibDecompressor::Instance();
94 if (!zlib_decompressor) {
95 TC3_LOG(ERROR) << "Cannot initialize decompressor.";
96 return false;
97 }
98
99 // Decompress regex rules.
100 if (model->rules != nullptr) {
101 for (int i = 0; i < model->rules->rule.size(); i++) {
102 RulesModel_::RuleT* rule = model->rules->rule[i].get();
103 if (!zlib_decompressor->MaybeDecompress(rule->compressed_pattern.get(),
104 &rule->pattern)) {
105 TC3_LOG(ERROR) << "Cannot decompress pattern: " << i;
106 return false;
107 }
108 rule->compressed_pattern.reset(nullptr);
109 }
110 }
111
112 // Decompress low confidence rules.
113 if (model->low_confidence_rules != nullptr) {
114 for (int i = 0; i < model->low_confidence_rules->rule.size(); i++) {
115 RulesModel_::RuleT* rule = model->low_confidence_rules->rule[i].get();
116 if (!zlib_decompressor->MaybeDecompress(rule->compressed_pattern.get(),
117 &rule->pattern)) {
118 TC3_LOG(ERROR) << "Cannot decompress pattern: " << i;
119 return false;
120 }
121 if (!zlib_decompressor->MaybeDecompress(
122 rule->compressed_output_pattern.get(), &rule->output_pattern)) {
123 TC3_LOG(ERROR) << "Cannot decompress pattern: " << i;
124 return false;
125 }
126 rule->compressed_pattern.reset(nullptr);
127 rule->compressed_output_pattern.reset(nullptr);
128 }
129 }
130
131 if (!zlib_decompressor->MaybeDecompress(
132 model->compressed_lua_actions_script.get(),
133 &model->lua_actions_script)) {
134 TC3_LOG(ERROR) << "Cannot decompress actions script.";
135 return false;
136 }
137
138 if (model->ranking_options != nullptr &&
139 !zlib_decompressor->MaybeDecompress(
140 model->ranking_options->compressed_lua_ranking_script.get(),
141 &model->ranking_options->lua_ranking_script)) {
142 TC3_LOG(ERROR) << "Cannot decompress actions script.";
143 return false;
144 }
145
146 return true;
147 }
148
CompressSerializedActionsModel(const std::string & model)149 std::string CompressSerializedActionsModel(const std::string& model) {
150 std::unique_ptr<ActionsModelT> unpacked_model =
151 UnPackActionsModel(model.c_str());
152 TC3_CHECK(unpacked_model != nullptr);
153 TC3_CHECK(CompressActionsModel(unpacked_model.get()));
154 flatbuffers::FlatBufferBuilder builder;
155 FinishActionsModelBuffer(builder,
156 ActionsModel::Pack(builder, unpacked_model.get()));
157 return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
158 builder.GetSize());
159 }
160
GetUncompressedString(const flatbuffers::String * uncompressed_buffer,const CompressedBuffer * compressed_buffer,ZlibDecompressor * decompressor,std::string * out)161 bool GetUncompressedString(const flatbuffers::String* uncompressed_buffer,
162 const CompressedBuffer* compressed_buffer,
163 ZlibDecompressor* decompressor, std::string* out) {
164 if (uncompressed_buffer == nullptr && compressed_buffer == nullptr) {
165 out->clear();
166 return true;
167 }
168
169 return decompressor->MaybeDecompressOptionallyCompressedBuffer(
170 uncompressed_buffer, compressed_buffer, out);
171 }
172
173 } // namespace libtextclassifier3
174