/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "annotator/zlib-utils.h"

#include <memory>
#include <string>

#include "utils/base/logging.h"
#include "utils/intents/zlib-utils.h"
#include "utils/zlib/zlib.h"

namespace libtextclassifier3 {
namespace {

// Compresses `entry->pattern` into a freshly allocated
// `entry->compressed_pattern` and clears the plain-text copy.
//
// An entry whose plain text is empty but which already carries a compressed
// payload (e.g. a model that was compressed before) is left untouched:
// recompressing it would overwrite the stored payload with the compression of
// an empty string and silently drop the rule.
template <typename Entry>
void CompressPattern(ZlibCompressor* compressor, Entry* entry) {
  if (entry->pattern.empty() && entry->compressed_pattern != nullptr) {
    return;
  }
  entry->compressed_pattern.reset(new CompressedBufferT);
  compressor->Compress(entry->pattern, entry->compressed_pattern.get());
  entry->pattern.clear();
}

// Restores `entry->pattern` from `entry->compressed_pattern` via
// MaybeDecompress. The compressed copy is released only on success.
// Returns false (with the compressed copy still in place) on failure.
template <typename Entry>
bool DecompressPattern(ZlibDecompressor* decompressor, Entry* entry) {
  if (!decompressor->MaybeDecompress(entry->compressed_pattern.get(),
                                     &entry->pattern)) {
    return false;
  }
  entry->compressed_pattern.reset(nullptr);
  return true;
}

}  // namespace

// Compresses the rule fields of `model` in place: regex-model patterns,
// date-time pattern regexes and extractors, and the intent generator options
// (delegated to CompressIntentModel). Plain-text patterns are cleared once
// they have been compressed.
//
// Returns false if a compressor cannot be created or if intent-model
// compression reports failure.
bool CompressModel(ModelT* model) {
  std::unique_ptr<ZlibCompressor> zlib_compressor = ZlibCompressor::Instance();
  if (!zlib_compressor) {
    TC3_LOG(ERROR) << "Cannot compress model.";
    return false;
  }

  // Compress regex rules.
  if (model->regex_model != nullptr) {
    for (auto& pattern : model->regex_model->patterns) {
      CompressPattern(zlib_compressor.get(), pattern.get());
    }
  }

  // Compress date-time rules.
  if (model->datetime_model != nullptr) {
    for (auto& pattern : model->datetime_model->patterns) {
      for (auto& regex : pattern->regexes) {
        CompressPattern(zlib_compressor.get(), regex.get());
      }
    }
    for (auto& extractor : model->datetime_model->extractors) {
      CompressPattern(zlib_compressor.get(), extractor.get());
    }
  }

  // Compress intent generator. Its result used to be ignored, which hid
  // failures from callers.
  if (model->intent_options != nullptr &&
      !CompressIntentModel(model->intent_options.get())) {
    TC3_LOG(ERROR) << "Cannot compress intent model.";
    return false;
  }

  return true;
}

// Inverse of CompressModel: restores the plain-text rule fields of `model` in
// place from their compressed buffers and releases those buffers.
//
// Returns false if a decompressor cannot be created, if any pattern fails to
// decompress (the failing index is logged), or if intent-model decompression
// reports failure. On failure the model may be partially decompressed.
bool DecompressModel(ModelT* model) {
  std::unique_ptr<ZlibDecompressor> zlib_decompressor =
      ZlibDecompressor::Instance();
  if (!zlib_decompressor) {
    TC3_LOG(ERROR) << "Cannot initialize decompressor.";
    return false;
  }

  // Decompress regex rules.
  if (model->regex_model != nullptr) {
    const auto& patterns = model->regex_model->patterns;
    for (size_t i = 0; i < patterns.size(); i++) {
      if (!DecompressPattern(zlib_decompressor.get(), patterns[i].get())) {
        TC3_LOG(ERROR) << "Cannot decompress pattern: " << i;
        return false;
      }
    }
  }

  // Decompress date-time rules.
  if (model->datetime_model != nullptr) {
    const auto& patterns = model->datetime_model->patterns;
    for (size_t i = 0; i < patterns.size(); i++) {
      const auto& regexes = patterns[i]->regexes;
      for (size_t j = 0; j < regexes.size(); j++) {
        if (!DecompressPattern(zlib_decompressor.get(), regexes[j].get())) {
          TC3_LOG(ERROR) << "Cannot decompress pattern: " << i << " " << j;
          return false;
        }
      }
    }
    const auto& extractors = model->datetime_model->extractors;
    for (size_t i = 0; i < extractors.size(); i++) {
      if (!DecompressPattern(zlib_decompressor.get(), extractors[i].get())) {
        TC3_LOG(ERROR) << "Cannot decompress pattern: " << i;
        return false;
      }
    }
  }

  // Decompress intent generator. Its result used to be ignored, so a failed
  // intent decompression was reported as overall success.
  if (model->intent_options != nullptr &&
      !DecompressIntentModel(model->intent_options.get())) {
    TC3_LOG(ERROR) << "Cannot decompress intent model.";
    return false;
  }

  return true;
}

// Unpacks the serialized flatbuffer `model`, compresses its rule fields with
// CompressModel, and returns the re-serialized model bytes.
// Fails a TC3_CHECK if the model cannot be unpacked or compressed.
std::string CompressSerializedModel(const std::string& model) {
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(model.c_str());
  TC3_CHECK(unpacked_model != nullptr);
  TC3_CHECK(CompressModel(unpacked_model.get()));
  flatbuffers::FlatBufferBuilder builder;
  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
  return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
                     builder.GetSize());
}

}  // namespace libtextclassifier3