1 //
2 // Copyright (C) 2020 The Android Open Source Project
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //
16 
17 #include "cow_decompress.h"
18 
19 #include <utility>
20 
21 #include <android-base/logging.h>
22 #include <brotli/decode.h>
23 #include <zlib.h>
24 
25 namespace android {
26 namespace snapshot {
27 
28 class NoDecompressor final : public IDecompressor {
29   public:
30     bool Decompress(size_t) override;
31 };
32 
Decompress(size_t)33 bool NoDecompressor::Decompress(size_t) {
34     size_t stream_remaining = stream_->Size();
35     while (stream_remaining) {
36         size_t buffer_size = stream_remaining;
37         uint8_t* buffer = reinterpret_cast<uint8_t*>(sink_->GetBuffer(buffer_size, &buffer_size));
38         if (!buffer) {
39             LOG(ERROR) << "Could not acquire buffer from sink";
40             return false;
41         }
42 
43         // Read until we can fill the buffer.
44         uint8_t* buffer_pos = buffer;
45         size_t bytes_to_read = std::min(buffer_size, stream_remaining);
46         while (bytes_to_read) {
47             size_t read;
48             if (!stream_->Read(buffer_pos, bytes_to_read, &read)) {
49                 return false;
50             }
51             if (!read) {
52                 LOG(ERROR) << "Stream ended prematurely";
53                 return false;
54             }
55             if (!sink_->ReturnData(buffer_pos, read)) {
56                 LOG(ERROR) << "Could not return buffer to sink";
57                 return false;
58             }
59             buffer_pos += read;
60             bytes_to_read -= read;
61             stream_remaining -= read;
62         }
63     }
64     return true;
65 }
66 
Uncompressed()67 std::unique_ptr<IDecompressor> IDecompressor::Uncompressed() {
68     return std::unique_ptr<IDecompressor>(new NoDecompressor());
69 }
70 
71 // Read chunks of the COW and incrementally stream them to the decoder.
72 class StreamDecompressor : public IDecompressor {
73   public:
74     bool Decompress(size_t output_bytes) override;
75 
76     virtual bool Init() = 0;
77     virtual bool DecompressInput(const uint8_t* data, size_t length) = 0;
78     virtual bool Done() = 0;
79 
80   protected:
81     bool GetFreshBuffer();
82 
83     size_t output_bytes_;
84     size_t stream_remaining_;
85     uint8_t* output_buffer_ = nullptr;
86     size_t output_buffer_remaining_ = 0;
87 };
88 
// Size in bytes of the staging buffer used to read compressed input, and the
// upper bound on the size requested for each output buffer.
static constexpr size_t kChunkSize{4096};
90 
Decompress(size_t output_bytes)91 bool StreamDecompressor::Decompress(size_t output_bytes) {
92     if (!Init()) {
93         return false;
94     }
95 
96     stream_remaining_ = stream_->Size();
97     output_bytes_ = output_bytes;
98 
99     uint8_t chunk[kChunkSize];
100     while (stream_remaining_) {
101         size_t read = std::min(stream_remaining_, sizeof(chunk));
102         if (!stream_->Read(chunk, read, &read)) {
103             return false;
104         }
105         if (!read) {
106             LOG(ERROR) << "Stream ended prematurely";
107             return false;
108         }
109         if (!DecompressInput(chunk, read)) {
110             return false;
111         }
112 
113         stream_remaining_ -= read;
114 
115         if (stream_remaining_ && Done()) {
116             LOG(ERROR) << "Decompressor terminated early";
117             return false;
118         }
119     }
120     if (!Done()) {
121         LOG(ERROR) << "Decompressor expected more bytes";
122         return false;
123     }
124     return true;
125 }
126 
GetFreshBuffer()127 bool StreamDecompressor::GetFreshBuffer() {
128     size_t request_size = std::min(output_bytes_, kChunkSize);
129     output_buffer_ =
130             reinterpret_cast<uint8_t*>(sink_->GetBuffer(request_size, &output_buffer_remaining_));
131     if (!output_buffer_) {
132         LOG(ERROR) << "Could not acquire buffer from sink";
133         return false;
134     }
135     return true;
136 }
137 
138 class GzDecompressor final : public StreamDecompressor {
139   public:
140     ~GzDecompressor();
141 
142     bool Init() override;
143     bool DecompressInput(const uint8_t* data, size_t length) override;
Done()144     bool Done() override { return ended_; }
145 
146   private:
147     z_stream z_ = {};
148     bool ended_ = false;
149 };
150 
Init()151 bool GzDecompressor::Init() {
152     if (int rv = inflateInit(&z_); rv != Z_OK) {
153         LOG(ERROR) << "inflateInit returned error code " << rv;
154         return false;
155     }
156     return true;
157 }
158 
~GzDecompressor()159 GzDecompressor::~GzDecompressor() {
160     inflateEnd(&z_);
161 }
162 
DecompressInput(const uint8_t * data,size_t length)163 bool GzDecompressor::DecompressInput(const uint8_t* data, size_t length) {
164     z_.next_in = reinterpret_cast<Bytef*>(const_cast<uint8_t*>(data));
165     z_.avail_in = length;
166 
167     while (z_.avail_in) {
168         // If no more output buffer, grab a new buffer.
169         if (z_.avail_out == 0) {
170             if (!GetFreshBuffer()) {
171                 return false;
172             }
173             z_.next_out = reinterpret_cast<Bytef*>(output_buffer_);
174             z_.avail_out = output_buffer_remaining_;
175         }
176 
177         // Remember the position of the output buffer so we can call ReturnData.
178         auto avail_out = z_.avail_out;
179 
180         // Decompress.
181         int rv = inflate(&z_, Z_NO_FLUSH);
182         if (rv != Z_OK && rv != Z_STREAM_END) {
183             LOG(ERROR) << "inflate returned error code " << rv;
184             return false;
185         }
186 
187         size_t returned = avail_out - z_.avail_out;
188         if (!sink_->ReturnData(output_buffer_, returned)) {
189             LOG(ERROR) << "Could not return buffer to sink";
190             return false;
191         }
192         output_buffer_ += returned;
193         output_buffer_remaining_ -= returned;
194 
195         if (rv == Z_STREAM_END) {
196             if (z_.avail_in) {
197                 LOG(ERROR) << "Gz stream ended prematurely";
198                 return false;
199             }
200             ended_ = true;
201             return true;
202         }
203     }
204     return true;
205 }
206 
Gz()207 std::unique_ptr<IDecompressor> IDecompressor::Gz() {
208     return std::unique_ptr<IDecompressor>(new GzDecompressor());
209 }
210 
211 class BrotliDecompressor final : public StreamDecompressor {
212   public:
213     ~BrotliDecompressor();
214 
215     bool Init() override;
216     bool DecompressInput(const uint8_t* data, size_t length) override;
Done()217     bool Done() override { return BrotliDecoderIsFinished(decoder_); }
218 
219   private:
220     BrotliDecoderState* decoder_ = nullptr;
221 };
222 
Init()223 bool BrotliDecompressor::Init() {
224     decoder_ = BrotliDecoderCreateInstance(nullptr, nullptr, nullptr);
225     return true;
226 }
227 
~BrotliDecompressor()228 BrotliDecompressor::~BrotliDecompressor() {
229     if (decoder_) {
230         BrotliDecoderDestroyInstance(decoder_);
231     }
232 }
233 
DecompressInput(const uint8_t * data,size_t length)234 bool BrotliDecompressor::DecompressInput(const uint8_t* data, size_t length) {
235     size_t available_in = length;
236     const uint8_t* next_in = data;
237 
238     bool needs_more_output = false;
239     while (available_in || needs_more_output) {
240         if (!output_buffer_remaining_ && !GetFreshBuffer()) {
241             return false;
242         }
243 
244         auto output_buffer = output_buffer_;
245         auto r = BrotliDecoderDecompressStream(decoder_, &available_in, &next_in,
246                                                &output_buffer_remaining_, &output_buffer_, nullptr);
247         if (r == BROTLI_DECODER_RESULT_ERROR) {
248             LOG(ERROR) << "brotli decode failed";
249             return false;
250         }
251         if (!sink_->ReturnData(output_buffer, output_buffer_ - output_buffer)) {
252             return false;
253         }
254         needs_more_output = (r == BROTLI_DECODER_RESULT_NEEDS_MORE_OUTPUT);
255     }
256     return true;
257 }
258 
Brotli()259 std::unique_ptr<IDecompressor> IDecompressor::Brotli() {
260     return std::unique_ptr<IDecompressor>(new BrotliDecompressor());
261 }
262 
263 }  // namespace snapshot
264 }  // namespace android
265