1 /*
2  * Copyright (C) 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #pragma once
18 
19 #include <algorithm>
20 #include <memory>
21 #include <span>
22 
23 #include <android-base/logging.h>
24 
25 #include <brotli/decode.h>
26 #include <brotli/encode.h>
27 #include <lz4frame.h>
28 #include <zstd.h>
29 
30 #include "types.h"
31 
// Outcome of a single Decoder::Decode() call.
enum class DecodeResult {
    // Decoding failed (the codec reported an error or the stream ended with input left over).
    Error = 0,
    // Finish() has been called and the queued input has been fully decoded.
    Done = 1,
    // The decoder is waiting for more input (or for Finish()).
    NeedInput = 2,
    // Decode() should be called again to retrieve further output.
    MoreOutput = 3,
};
38 
// Outcome of a single Encoder::Encode() call.
enum class EncodeResult {
    // Encoding failed (the codec reported an error or input arrived after finalization).
    Error = 0,
    // Finish() has been called and all encoded data has been handed out.
    Done = 1,
    // The encoder is waiting for more input (or for Finish()).
    NeedInput = 2,
    // Encode() should be called again to retrieve further output.
    MoreOutput = 3,
};
45 
46 struct Decoder {
47     void Append(Block&& block) { input_buffer_.append(std::move(block)); }
48     bool Finish() {
49         bool old = std::exchange(finished_, true);
50         if (old) {
51             LOG(FATAL) << "Decoder::Finish called while already finished?";
52             return false;
53         }
54         return true;
55     }
56 
57     virtual DecodeResult Decode(std::span<char>* output) = 0;
58 
59   protected:
60     Decoder(std::span<char> output_buffer) : output_buffer_(output_buffer) {}
61     ~Decoder() = default;
62 
63     bool finished_ = false;
64     IOVector input_buffer_;
65     std::span<char> output_buffer_;
66 };
67 
68 struct Encoder {
69     void Append(Block input) { input_buffer_.append(std::move(input)); }
70     bool Finish() {
71         bool old = std::exchange(finished_, true);
72         if (old) {
73             LOG(FATAL) << "Decoder::Finish called while already finished?";
74             return false;
75         }
76         return true;
77     }
78 
79     virtual EncodeResult Encode(Block* output) = 0;
80 
81   protected:
82     explicit Encoder(size_t output_block_size) : output_block_size_(output_block_size) {}
83     ~Encoder() = default;
84 
85     const size_t output_block_size_;
86     bool finished_ = false;
87     IOVector input_buffer_;
88 };
89 
90 struct NullDecoder final : public Decoder {
91     explicit NullDecoder(std::span<char> output_buffer) : Decoder(output_buffer) {}
92 
93     DecodeResult Decode(std::span<char>* output) final {
94         size_t available_out = output_buffer_.size();
95         void* p = output_buffer_.data();
96         while (available_out > 0 && !input_buffer_.empty()) {
97             size_t len = std::min(available_out, input_buffer_.front_size());
98             p = mempcpy(p, input_buffer_.front_data(), len);
99             available_out -= len;
100             input_buffer_.drop_front(len);
101         }
102         *output = std::span(output_buffer_.data(), static_cast<char*>(p));
103         if (input_buffer_.empty()) {
104             return finished_ ? DecodeResult::Done : DecodeResult::NeedInput;
105         }
106         return DecodeResult::MoreOutput;
107     }
108 };
109 
110 struct NullEncoder final : public Encoder {
111     explicit NullEncoder(size_t output_block_size) : Encoder(output_block_size) {}
112 
113     EncodeResult Encode(Block* output) final {
114         output->clear();
115         output->resize(output_block_size_);
116 
117         size_t available_out = output->size();
118         void* p = output->data();
119 
120         while (available_out > 0 && !input_buffer_.empty()) {
121             size_t len = std::min(available_out, input_buffer_.front_size());
122             p = mempcpy(p, input_buffer_.front_data(), len);
123             available_out -= len;
124             input_buffer_.drop_front(len);
125         }
126 
127         output->resize(output->size() - available_out);
128 
129         if (input_buffer_.empty()) {
130             return finished_ ? EncodeResult::Done : EncodeResult::NeedInput;
131         }
132         return EncodeResult::MoreOutput;
133     }
134 };
135 
136 struct BrotliDecoder final : public Decoder {
137     explicit BrotliDecoder(std::span<char> output_buffer)
138         : Decoder(output_buffer),
139           decoder_(BrotliDecoderCreateInstance(nullptr, nullptr, nullptr),
140                    BrotliDecoderDestroyInstance) {}
141 
142     DecodeResult Decode(std::span<char>* output) final {
143         size_t available_in = input_buffer_.front_size();
144         const uint8_t* next_in = reinterpret_cast<const uint8_t*>(input_buffer_.front_data());
145 
146         size_t available_out = output_buffer_.size();
147         uint8_t* next_out = reinterpret_cast<uint8_t*>(output_buffer_.data());
148 
149         BrotliDecoderResult r = BrotliDecoderDecompressStream(
150                 decoder_.get(), &available_in, &next_in, &available_out, &next_out, nullptr);
151 
152         size_t bytes_consumed = input_buffer_.front_size() - available_in;
153         input_buffer_.drop_front(bytes_consumed);
154 
155         size_t bytes_emitted = output_buffer_.size() - available_out;
156         *output = std::span<char>(output_buffer_.data(), bytes_emitted);
157 
158         switch (r) {
159             case BROTLI_DECODER_RESULT_SUCCESS:
160                 // We need to wait for ID_DONE from the other end.
161                 return finished_ ? DecodeResult::Done : DecodeResult::NeedInput;
162             case BROTLI_DECODER_RESULT_ERROR:
163                 return DecodeResult::Error;
164             case BROTLI_DECODER_RESULT_NEEDS_MORE_INPUT:
165                 // Brotli guarantees as one of its invariants that if it returns NEEDS_MORE_INPUT,
166                 // it will consume the entire input buffer passed in, so we don't have to worry
167                 // about bytes left over in the front block with more input remaining.
168                 return DecodeResult::NeedInput;
169             case BROTLI_DECODER_RESULT_NEEDS_MORE_OUTPUT:
170                 return DecodeResult::MoreOutput;
171         }
172     }
173 
174   private:
175     std::unique_ptr<BrotliDecoderState, void (*)(BrotliDecoderState*)> decoder_;
176 };
177 
178 struct BrotliEncoder final : public Encoder {
179     explicit BrotliEncoder(size_t output_block_size)
180         : Encoder(output_block_size),
181           output_block_(output_block_size_),
182           output_bytes_left_(output_block_size_),
183           encoder_(BrotliEncoderCreateInstance(nullptr, nullptr, nullptr),
184                    BrotliEncoderDestroyInstance) {
185         BrotliEncoderSetParameter(encoder_.get(), BROTLI_PARAM_QUALITY, 1);
186     }
187 
188     EncodeResult Encode(Block* output) final {
189         output->clear();
190 
191         while (true) {
192             size_t available_in = input_buffer_.front_size();
193             const uint8_t* next_in = reinterpret_cast<const uint8_t*>(input_buffer_.front_data());
194 
195             size_t available_out = output_bytes_left_;
196             uint8_t* next_out = reinterpret_cast<uint8_t*>(
197                     output_block_.data() + (output_block_size_ - output_bytes_left_));
198 
199             BrotliEncoderOperation op = BROTLI_OPERATION_PROCESS;
200             if (finished_) {
201                 op = BROTLI_OPERATION_FINISH;
202             }
203 
204             if (!BrotliEncoderCompressStream(encoder_.get(), op, &available_in, &next_in,
205                                              &available_out, &next_out, nullptr)) {
206                 return EncodeResult::Error;
207             }
208 
209             size_t bytes_consumed = input_buffer_.front_size() - available_in;
210             input_buffer_.drop_front(bytes_consumed);
211 
212             output_bytes_left_ = available_out;
213 
214             if (BrotliEncoderIsFinished(encoder_.get())) {
215                 output_block_.resize(output_block_size_ - output_bytes_left_);
216                 *output = std::move(output_block_);
217                 return EncodeResult::Done;
218             } else if (output_bytes_left_ == 0) {
219                 *output = std::move(output_block_);
220                 output_block_.resize(output_block_size_);
221                 output_bytes_left_ = output_block_size_;
222                 return EncodeResult::MoreOutput;
223             } else if (input_buffer_.empty()) {
224                 return EncodeResult::NeedInput;
225             }
226         }
227     }
228 
229   private:
230     Block output_block_;
231     size_t output_bytes_left_;
232     std::unique_ptr<BrotliEncoderState, void (*)(BrotliEncoderState*)> encoder_;
233 };
234 
235 struct LZ4Decoder final : public Decoder {
236     explicit LZ4Decoder(std::span<char> output_buffer)
237         : Decoder(output_buffer), decoder_(nullptr, nullptr) {
238         LZ4F_dctx* dctx;
239         if (LZ4F_createDecompressionContext(&dctx, LZ4F_VERSION) != 0) {
240             LOG(FATAL) << "failed to initialize LZ4 decompression context";
241         }
242         decoder_ = std::unique_ptr<LZ4F_dctx, decltype(&LZ4F_freeDecompressionContext)>(
243                 dctx, LZ4F_freeDecompressionContext);
244     }
245 
246     DecodeResult Decode(std::span<char>* output) final {
247         size_t available_in = input_buffer_.front_size();
248         const char* next_in = input_buffer_.front_data();
249 
250         size_t available_out = output_buffer_.size();
251         char* next_out = output_buffer_.data();
252 
253         size_t rc = LZ4F_decompress(decoder_.get(), next_out, &available_out, next_in,
254                                     &available_in, nullptr);
255         if (LZ4F_isError(rc)) {
256             LOG(ERROR) << "LZ4F_decompress failed: " << LZ4F_getErrorName(rc);
257             return DecodeResult::Error;
258         }
259 
260         input_buffer_.drop_front(available_in);
261 
262         if (rc == 0) {
263             if (!input_buffer_.empty()) {
264                 LOG(ERROR) << "LZ4 stream hit end before reading all data";
265                 return DecodeResult::Error;
266             }
267             lz4_done_ = true;
268         }
269 
270         *output = std::span<char>(output_buffer_.data(), available_out);
271 
272         if (finished_) {
273             return input_buffer_.empty() && lz4_done_ ? DecodeResult::Done
274                                                       : DecodeResult::MoreOutput;
275         }
276 
277         return DecodeResult::NeedInput;
278     }
279 
280   private:
281     bool lz4_done_ = false;
282     std::unique_ptr<LZ4F_dctx, LZ4F_errorCode_t (*)(LZ4F_dctx*)> decoder_;
283 };
284 
285 struct LZ4Encoder final : public Encoder {
286     explicit LZ4Encoder(size_t output_block_size)
287         : Encoder(output_block_size), encoder_(nullptr, nullptr) {
288         LZ4F_cctx* cctx;
289         if (LZ4F_createCompressionContext(&cctx, LZ4F_VERSION) != 0) {
290             LOG(FATAL) << "failed to initialize LZ4 compression context";
291         }
292         encoder_ = std::unique_ptr<LZ4F_cctx, decltype(&LZ4F_freeCompressionContext)>(
293                 cctx, LZ4F_freeCompressionContext);
294         Block header(LZ4F_HEADER_SIZE_MAX);
295         size_t rc = LZ4F_compressBegin(encoder_.get(), header.data(), header.size(), nullptr);
296         if (LZ4F_isError(rc)) {
297             LOG(FATAL) << "LZ4F_compressBegin failed: %s", LZ4F_getErrorName(rc);
298         }
299         header.resize(rc);
300         output_buffer_.append(std::move(header));
301     }
302 
303     // As an optimization, only emit a block if we have an entire output block ready, or we're done.
304     bool OutputReady() const {
305         return output_buffer_.size() >= output_block_size_ || lz4_finalized_;
306     }
307 
308     // TODO: Switch the output type to IOVector to remove a copy?
309     EncodeResult Encode(Block* output) final {
310         size_t available_in = input_buffer_.front_size();
311         const char* next_in = input_buffer_.front_data();
312 
313         // LZ4 makes no guarantees about being able to recover from trying to compress with an
314         // insufficiently large output buffer. LZ4F_compressBound tells us how much buffer we
315         // need to compress a given number of bytes, but the smallest value seems to be bigger
316         // than SYNC_DATA_MAX, so we need to buffer ourselves.
317 
318         // Input size chosen to be a local maximum for LZ4F_compressBound (i.e. the block size).
319         constexpr size_t max_input_size = 65536;
320         const size_t encode_block_size = LZ4F_compressBound(max_input_size, nullptr);
321 
322         if (available_in != 0) {
323             if (lz4_finalized_) {
324                 LOG(ERROR) << "LZ4Encoder received data after Finish?";
325                 return EncodeResult::Error;
326             }
327 
328             available_in = std::min(available_in, max_input_size);
329 
330             Block encode_block(encode_block_size);
331             size_t available_out = encode_block.capacity();
332             char* next_out = encode_block.data();
333 
334             size_t rc = LZ4F_compressUpdate(encoder_.get(), next_out, available_out, next_in,
335                                             available_in, nullptr);
336             if (LZ4F_isError(rc)) {
337                 LOG(ERROR) << "LZ4F_compressUpdate failed: " << LZ4F_getErrorName(rc);
338                 return EncodeResult::Error;
339             }
340 
341             input_buffer_.drop_front(available_in);
342 
343             available_out -= rc;
344             next_out += rc;
345 
346             encode_block.resize(encode_block_size - available_out);
347             output_buffer_.append(std::move(encode_block));
348         }
349 
350         if (finished_ && !lz4_finalized_) {
351             lz4_finalized_ = true;
352 
353             Block final_block(encode_block_size + 4);
354             size_t rc = LZ4F_compressEnd(encoder_.get(), final_block.data(), final_block.size(),
355                                          nullptr);
356             if (LZ4F_isError(rc)) {
357                 LOG(ERROR) << "LZ4F_compressEnd failed: " << LZ4F_getErrorName(rc);
358                 return EncodeResult::Error;
359             }
360 
361             final_block.resize(rc);
362             output_buffer_.append(std::move(final_block));
363         }
364 
365         if (OutputReady()) {
366             size_t len = std::min(output_block_size_, output_buffer_.size());
367             *output = output_buffer_.take_front(len).coalesce();
368         } else {
369             output->clear();
370         }
371 
372         if (lz4_finalized_ && output_buffer_.empty()) {
373             return EncodeResult::Done;
374         } else if (OutputReady()) {
375             return EncodeResult::MoreOutput;
376         }
377         return EncodeResult::NeedInput;
378     }
379 
380   private:
381     bool lz4_finalized_ = false;
382     std::unique_ptr<LZ4F_cctx, LZ4F_errorCode_t (*)(LZ4F_cctx*)> encoder_;
383     IOVector output_buffer_;
384 };
385 
386 struct ZstdDecoder final : public Decoder {
387     explicit ZstdDecoder(std::span<char> output_buffer)
388         : Decoder(output_buffer), decoder_(ZSTD_createDStream(), ZSTD_freeDStream) {
389         if (!decoder_) {
390             LOG(FATAL) << "failed to initialize Zstd decompression context";
391         }
392     }
393 
394     DecodeResult Decode(std::span<char>* output) final {
395         ZSTD_inBuffer in;
396         in.src = input_buffer_.front_data();
397         in.size = input_buffer_.front_size();
398         in.pos = 0;
399 
400         ZSTD_outBuffer out;
401         out.dst = output_buffer_.data();
402         // The standard specifies size() as returning size_t, but our current version of
403         // libc++ returns a signed value instead.
404         out.size = static_cast<size_t>(output_buffer_.size());
405         out.pos = 0;
406 
407         size_t rc = ZSTD_decompressStream(decoder_.get(), &out, &in);
408         if (ZSTD_isError(rc)) {
409             LOG(ERROR) << "ZSTD_decompressStream failed: " << ZSTD_getErrorName(rc);
410             return DecodeResult::Error;
411         }
412 
413         input_buffer_.drop_front(in.pos);
414         if (rc == 0) {
415             if (!input_buffer_.empty()) {
416                 LOG(ERROR) << "Zstd stream hit end before reading all data";
417                 return DecodeResult::Error;
418             }
419             zstd_done_ = true;
420         }
421 
422         *output = std::span<char>(output_buffer_.data(), out.pos);
423 
424         if (finished_) {
425             return input_buffer_.empty() && zstd_done_ ? DecodeResult::Done
426                                                        : DecodeResult::MoreOutput;
427         }
428         return DecodeResult::NeedInput;
429     }
430 
431   private:
432     bool zstd_done_ = false;
433     std::unique_ptr<ZSTD_DStream, size_t (*)(ZSTD_DStream*)> decoder_;
434 };
435 
436 struct ZstdEncoder final : public Encoder {
437     explicit ZstdEncoder(size_t output_block_size)
438         : Encoder(output_block_size), encoder_(ZSTD_createCStream(), ZSTD_freeCStream) {
439         if (!encoder_) {
440             LOG(FATAL) << "failed to initialize Zstd compression context";
441         }
442         ZSTD_CCtx_setParameter(encoder_.get(), ZSTD_c_compressionLevel, 1);
443     }
444 
445     EncodeResult Encode(Block* output) final {
446         ZSTD_inBuffer in;
447         in.src = input_buffer_.front_data();
448         in.size = input_buffer_.front_size();
449         in.pos = 0;
450 
451         output->resize(output_block_size_);
452 
453         ZSTD_outBuffer out;
454         out.dst = output->data();
455         out.size = static_cast<size_t>(output->size());
456         out.pos = 0;
457 
458         ZSTD_EndDirective end_directive = finished_ ? ZSTD_e_end : ZSTD_e_continue;
459         size_t rc = ZSTD_compressStream2(encoder_.get(), &out, &in, end_directive);
460         if (ZSTD_isError(rc)) {
461             LOG(ERROR) << "ZSTD_compressStream2 failed: " << ZSTD_getErrorName(rc);
462             return EncodeResult::Error;
463         }
464 
465         input_buffer_.drop_front(in.pos);
466         output->resize(out.pos);
467 
468         if (rc == 0) {
469             // Zstd finished flushing its data.
470             if (finished_) {
471                 if (!input_buffer_.empty()) {
472                     LOG(ERROR) << "ZSTD_compressStream2 finished early";
473                     return EncodeResult::Error;
474                 }
475                 return EncodeResult::Done;
476             } else {
477                 return input_buffer_.empty() ? EncodeResult::NeedInput : EncodeResult::MoreOutput;
478             }
479         } else {
480             return EncodeResult::MoreOutput;
481         }
482     }
483 
484   private:
485     std::unique_ptr<ZSTD_CStream, size_t (*)(ZSTD_CStream*)> encoder_;
486 };
487