1 /*
2  * Copyright (c) 2016-2020, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under both the BSD-style license (found in the
6  * LICENSE file in the root directory of this source tree) and the GPLv2 (found
7  * in the COPYING file in the root directory of this source tree).
8  * You may select, at your option, one of the above-listed licenses.
9  */
10 
11 #include "method.h"
12 
13 #include <stdio.h>
14 #include <stdlib.h>
15 
16 #define ZSTD_STATIC_LINKING_ONLY
17 #include <zstd.h>
18 
19 #define MIN(x, y) ((x) < (y) ? (x) : (y))
20 
/** Path of the zstd CLI binary consumed by cli_compress(); NULL until set. */
static char const* g_zstdcli = NULL;

/**
 * Records the path of the zstd command line binary.
 *
 * Only the pointer is stored (the string is not copied), so the caller has to
 * keep the string alive for as long as the cli method may run.
 */
void method_set_zstdcli(char const* zstdcli)
{
    g_zstdcli = zstdcli;
}
26 
/**
 * Macro to get a pointer to the enclosing object of the given type, given ptr,
 * which points at the member variable with the given name, member.
 * Evaluates to NULL when ptr is NULL.
 *
 * ptr is expanded twice, so it must not have side effects. It is fully
 * parenthesized, so any pointer expression (e.g. `cond ? a : b`) is accepted.
 *
 *     method_state_t* base = ...;
 *     buffer_state_t* state = container_of(base, buffer_state_t, base);
 */
#define container_of(ptr, type, member) \
    ((type*)((ptr) == NULL ? NULL : (char*)(ptr)-offsetof(type, member)))
36 
/**
 * State to reuse the same buffers between compression calls.
 *
 * Callers only ever see a pointer to `base`; container_of() is used to get
 * back to the enclosing buffer_state_t.
 */
typedef struct {
    method_state_t base; /**< Generic state; its address is what buffer_state_create() returns. */
    data_buffers_t inputs; /**< The input buffer for each file. */
    data_buffer_t dictionary; /**< The dictionary. */
    data_buffer_t compressed; /**< The compressed data buffer. */
    data_buffer_t decompressed; /**< The decompressed data buffer. */
} buffer_state_t;
45 
/** Returns the size of the largest buffer in buffers, or 0 when it is empty. */
static size_t buffers_max_size(data_buffers_t buffers) {
    size_t largest = 0;
    size_t idx = buffers.size;
    while (idx-- > 0) {
        size_t const candidate = buffers.buffers[idx].size;
        largest = candidate > largest ? candidate : largest;
    }
    return largest;
}
54 
buffer_state_create(data_t const * data)55 static method_state_t* buffer_state_create(data_t const* data) {
56     buffer_state_t* state = (buffer_state_t*)calloc(1, sizeof(buffer_state_t));
57     if (state == NULL)
58         return NULL;
59     state->base.data = data;
60     state->inputs = data_buffers_get(data);
61     state->dictionary = data_buffer_get_dict(data);
62     size_t const max_size = buffers_max_size(state->inputs);
63     state->compressed = data_buffer_create(ZSTD_compressBound(max_size));
64     state->decompressed = data_buffer_create(max_size);
65     return &state->base;
66 }
67 
/**
 * Releases a state returned by buffer_state_create(). NULL is accepted.
 *
 * NOTE(review): only the buffer_state_t allocation is freed; the buffers
 * stored inside it are not released here. Confirm whether they are owned and
 * released elsewhere.
 */
static void buffer_state_destroy(method_state_t* base) {
    /* container_of() maps NULL to NULL, and free(NULL) is a no-op. */
    free(container_of(base, buffer_state_t, base));
}
74 
/**
 * Validates that state is usable with config.
 *
 * Checks, in order: the state exists; there is at least one input and both
 * scratch buffers have data; and, when the config uses a dictionary, that the
 * dictionary has data. The first failed check is reported on stderr.
 *
 * @return 0 when the state is usable, 1 otherwise.
 */
static int buffer_state_bad(
    buffer_state_t const* state,
    config_t const* config) {
    char const* problem = NULL;

    if (state == NULL) {
        problem = "buffer_state_t is NULL\n";
    } else if (
        state->inputs.size == 0 || state->compressed.data == NULL ||
        state->decompressed.data == NULL) {
        problem = "buffer state allocation failure\n";
    } else if (config->use_dictionary && state->dictionary.data == NULL) {
        problem = "dictionary loading failed\n";
    }

    if (problem == NULL)
        return 0;

    fputs(problem, stderr);
    return 1;
}
93 
simple_compress(method_state_t * base,config_t const * config)94 static result_t simple_compress(method_state_t* base, config_t const* config) {
95     buffer_state_t* state = container_of(base, buffer_state_t, base);
96 
97     if (buffer_state_bad(state, config))
98         return result_error(result_error_system_error);
99 
100     /* Keep the tests short by skipping directories, since behavior shouldn't
101      * change.
102      */
103     if (base->data->type != data_type_file)
104         return result_error(result_error_skip);
105 
106     if (config->use_dictionary || config->no_pledged_src_size)
107         return result_error(result_error_skip);
108 
109     /* If the config doesn't specify a level, skip. */
110     int const level = config_get_level(config);
111     if (level == CONFIG_NO_LEVEL)
112         return result_error(result_error_skip);
113 
114     data_buffer_t const input = state->inputs.buffers[0];
115 
116     /* Compress, decompress, and check the result. */
117     state->compressed.size = ZSTD_compress(
118         state->compressed.data,
119         state->compressed.capacity,
120         input.data,
121         input.size,
122         level);
123     if (ZSTD_isError(state->compressed.size))
124         return result_error(result_error_compression_error);
125 
126     state->decompressed.size = ZSTD_decompress(
127         state->decompressed.data,
128         state->decompressed.capacity,
129         state->compressed.data,
130         state->compressed.size);
131     if (ZSTD_isError(state->decompressed.size))
132         return result_error(result_error_decompression_error);
133     if (data_buffer_compare(input, state->decompressed))
134         return result_error(result_error_round_trip_error);
135 
136     result_data_t data;
137     data.total_size = state->compressed.size;
138     return result_data(data);
139 }
140 
compress_cctx_compress(method_state_t * base,config_t const * config)141 static result_t compress_cctx_compress(
142     method_state_t* base,
143     config_t const* config) {
144     buffer_state_t* state = container_of(base, buffer_state_t, base);
145 
146     if (buffer_state_bad(state, config))
147         return result_error(result_error_system_error);
148 
149     if (config->no_pledged_src_size)
150         return result_error(result_error_skip);
151 
152     if (base->data->type != data_type_dir)
153         return result_error(result_error_skip);
154 
155     int const level = config_get_level(config);
156 
157     ZSTD_CCtx* cctx = ZSTD_createCCtx();
158     ZSTD_DCtx* dctx = ZSTD_createDCtx();
159     if (cctx == NULL || dctx == NULL) {
160         fprintf(stderr, "context creation failed\n");
161         return result_error(result_error_system_error);
162     }
163 
164     result_t result;
165     result_data_t data = {.total_size = 0};
166     for (size_t i = 0; i < state->inputs.size; ++i) {
167         data_buffer_t const input = state->inputs.buffers[i];
168         ZSTD_parameters const params =
169             config_get_zstd_params(config, input.size, state->dictionary.size);
170 
171         if (level == CONFIG_NO_LEVEL)
172             state->compressed.size = ZSTD_compress_advanced(
173                 cctx,
174                 state->compressed.data,
175                 state->compressed.capacity,
176                 input.data,
177                 input.size,
178                 config->use_dictionary ? state->dictionary.data : NULL,
179                 config->use_dictionary ? state->dictionary.size : 0,
180                 params);
181         else if (config->use_dictionary)
182             state->compressed.size = ZSTD_compress_usingDict(
183                 cctx,
184                 state->compressed.data,
185                 state->compressed.capacity,
186                 input.data,
187                 input.size,
188                 state->dictionary.data,
189                 state->dictionary.size,
190                 level);
191         else
192             state->compressed.size = ZSTD_compressCCtx(
193                 cctx,
194                 state->compressed.data,
195                 state->compressed.capacity,
196                 input.data,
197                 input.size,
198                 level);
199 
200         if (ZSTD_isError(state->compressed.size)) {
201             result = result_error(result_error_compression_error);
202             goto out;
203         }
204 
205         if (config->use_dictionary)
206             state->decompressed.size = ZSTD_decompress_usingDict(
207                 dctx,
208                 state->decompressed.data,
209                 state->decompressed.capacity,
210                 state->compressed.data,
211                 state->compressed.size,
212                 state->dictionary.data,
213                 state->dictionary.size);
214         else
215             state->decompressed.size = ZSTD_decompressDCtx(
216                 dctx,
217                 state->decompressed.data,
218                 state->decompressed.capacity,
219                 state->compressed.data,
220                 state->compressed.size);
221         if (ZSTD_isError(state->decompressed.size)) {
222             result = result_error(result_error_decompression_error);
223             goto out;
224         }
225         if (data_buffer_compare(input, state->decompressed)) {
226             result = result_error(result_error_round_trip_error);
227             goto out;
228         }
229 
230         data.total_size += state->compressed.size;
231     }
232 
233     result = result_data(data);
234 out:
235     ZSTD_freeCCtx(cctx);
236     ZSTD_freeDCtx(dctx);
237     return result;
238 }
239 
240 /** Generic state creation function. */
method_state_create(data_t const * data)241 static method_state_t* method_state_create(data_t const* data) {
242     method_state_t* state = (method_state_t*)malloc(sizeof(method_state_t));
243     if (state == NULL)
244         return NULL;
245     state->data = data;
246     return state;
247 }
248 
/**
 * Generic state destruction function: frees the given state.
 * A NULL state is accepted, since free(NULL) is a no-op.
 */
static void method_state_destroy(method_state_t* state)
{
    free(state);
}
252 
cli_compress(method_state_t * state,config_t const * config)253 static result_t cli_compress(method_state_t* state, config_t const* config) {
254     if (config->cli_args == NULL)
255         return result_error(result_error_skip);
256 
257     /* We don't support no pledged source size with directories. Too slow. */
258     if (state->data->type == data_type_dir && config->no_pledged_src_size)
259         return result_error(result_error_skip);
260 
261     if (g_zstdcli == NULL)
262         return result_error(result_error_system_error);
263 
264     /* '<zstd>' -cqr <args> [-D '<dict>'] '<file/dir>' */
265     char cmd[1024];
266     size_t const cmd_size = snprintf(
267         cmd,
268         sizeof(cmd),
269         "'%s' -cqr %s %s%s%s %s '%s'",
270         g_zstdcli,
271         config->cli_args,
272         config->use_dictionary ? "-D '" : "",
273         config->use_dictionary ? state->data->dict.path : "",
274         config->use_dictionary ? "'" : "",
275         config->no_pledged_src_size ? "<" : "",
276         state->data->data.path);
277     if (cmd_size >= sizeof(cmd)) {
278         fprintf(stderr, "command too large: %s\n", cmd);
279         return result_error(result_error_system_error);
280     }
281     FILE* zstd = popen(cmd, "r");
282     if (zstd == NULL) {
283         fprintf(stderr, "failed to popen command: %s\n", cmd);
284         return result_error(result_error_system_error);
285     }
286 
287     char out[4096];
288     size_t total_size = 0;
289     while (1) {
290         size_t const size = fread(out, 1, sizeof(out), zstd);
291         total_size += size;
292         if (size != sizeof(out))
293             break;
294     }
295     if (ferror(zstd) || pclose(zstd) != 0) {
296         fprintf(stderr, "zstd failed with command: %s\n", cmd);
297         return result_error(result_error_compression_error);
298     }
299 
300     result_data_t const data = {.total_size = total_size};
301     return result_data(data);
302 }
303 
/**
 * Resets cctx (session and parameters), then applies every parameter of
 * config and, when the config uses a dictionary, loads the dictionary held by
 * state.
 *
 * @return 0 on success, 1 as soon as setting a parameter or loading the
 *         dictionary reports an error.
 */
static int advanced_config(
    ZSTD_CCtx* cctx,
    buffer_state_t* state,
    config_t const* config) {
    ZSTD_CCtx_reset(cctx, ZSTD_reset_session_and_parameters);

    size_t const count = config->param_values.size;
    for (size_t idx = 0; idx != count; ++idx) {
        param_value_t const entry = config->param_values.data[idx];
        size_t const zret =
            ZSTD_CCtx_setParameter(cctx, entry.param, entry.value);
        if (ZSTD_isError(zret))
            return 1;
    }

    if (!config->use_dictionary)
        return 0;

    size_t const zret = ZSTD_CCtx_loadDictionary(
        cctx, state->dictionary.data, state->dictionary.size);
    return ZSTD_isError(zret) ? 1 : 0;
}
323 
advanced_one_pass_compress_output_adjustment(method_state_t * base,config_t const * config,size_t const subtract)324 static result_t advanced_one_pass_compress_output_adjustment(
325     method_state_t* base,
326     config_t const* config,
327     size_t const subtract) {
328     buffer_state_t* state = container_of(base, buffer_state_t, base);
329 
330     if (buffer_state_bad(state, config))
331         return result_error(result_error_system_error);
332 
333     ZSTD_CCtx* cctx = ZSTD_createCCtx();
334     result_t result;
335 
336     if (!cctx || advanced_config(cctx, state, config)) {
337         result = result_error(result_error_compression_error);
338         goto out;
339     }
340 
341     result_data_t data = {.total_size = 0};
342     for (size_t i = 0; i < state->inputs.size; ++i) {
343         data_buffer_t const input = state->inputs.buffers[i];
344 
345         if (!config->no_pledged_src_size) {
346             if (ZSTD_isError(ZSTD_CCtx_setPledgedSrcSize(cctx, input.size))) {
347                 result = result_error(result_error_compression_error);
348                 goto out;
349             }
350         }
351         size_t const size = ZSTD_compress2(
352             cctx,
353             state->compressed.data,
354             ZSTD_compressBound(input.size) - subtract,
355             input.data,
356             input.size);
357         if (ZSTD_isError(size)) {
358             result = result_error(result_error_compression_error);
359             goto out;
360         }
361         data.total_size += size;
362     }
363 
364     result = result_data(data);
365 out:
366     ZSTD_freeCCtx(cctx);
367     return result;
368 }
369 
advanced_one_pass_compress(method_state_t * base,config_t const * config)370 static result_t advanced_one_pass_compress(
371     method_state_t* base,
372     config_t const* config) {
373   return advanced_one_pass_compress_output_adjustment(base, config, 0);
374 }
375 
advanced_one_pass_compress_small_output(method_state_t * base,config_t const * config)376 static result_t advanced_one_pass_compress_small_output(
377     method_state_t* base,
378     config_t const* config) {
379   return advanced_one_pass_compress_output_adjustment(base, config, 1);
380 }
381 
advanced_streaming_compress(method_state_t * base,config_t const * config)382 static result_t advanced_streaming_compress(
383     method_state_t* base,
384     config_t const* config) {
385     buffer_state_t* state = container_of(base, buffer_state_t, base);
386 
387     if (buffer_state_bad(state, config))
388         return result_error(result_error_system_error);
389 
390     ZSTD_CCtx* cctx = ZSTD_createCCtx();
391     result_t result;
392 
393     if (!cctx || advanced_config(cctx, state, config)) {
394         result = result_error(result_error_compression_error);
395         goto out;
396     }
397 
398     result_data_t data = {.total_size = 0};
399     for (size_t i = 0; i < state->inputs.size; ++i) {
400         data_buffer_t input = state->inputs.buffers[i];
401 
402         if (!config->no_pledged_src_size) {
403             if (ZSTD_isError(ZSTD_CCtx_setPledgedSrcSize(cctx, input.size))) {
404                 result = result_error(result_error_compression_error);
405                 goto out;
406             }
407         }
408 
409         while (input.size > 0) {
410             ZSTD_inBuffer in = {input.data, MIN(input.size, 4096)};
411             input.data += in.size;
412             input.size -= in.size;
413             ZSTD_EndDirective const op =
414                 input.size > 0 ? ZSTD_e_continue : ZSTD_e_end;
415             size_t ret = 0;
416             while (in.pos < in.size || (op == ZSTD_e_end && ret != 0)) {
417                 ZSTD_outBuffer out = {state->compressed.data,
418                                       MIN(state->compressed.capacity, 1024)};
419                 ret = ZSTD_compressStream2(cctx, &out, &in, op);
420                 if (ZSTD_isError(ret)) {
421                     result = result_error(result_error_compression_error);
422                     goto out;
423                 }
424                 data.total_size += out.pos;
425             }
426         }
427     }
428 
429     result = result_data(data);
430 out:
431     ZSTD_freeCCtx(cctx);
432     return result;
433 }
434 
/**
 * Initializes zcs through one of the ZSTD_initCStream*() entry points,
 * selected by advanced, cdict and config->use_dictionary.
 *
 * @param state    Supplies the dictionary buffer.
 * @param zcs      The stream to initialize.
 * @param config   Source of the level or parameters and of use_dictionary.
 * @param advanced Non-zero to initialize from config_get_zstd_params(config,
 *                 0, 0); zero to initialize from config_get_level(config).
 * @param cdict    When non-NULL, a ZSTD_CDict is created from the dictionary
 *                 and stored in *cdict. It is never freed here, not even when
 *                 the stream initialization fails afterwards, so the caller
 *                 must free *cdict.
 * @return 0 on success; 1 when a cdict is requested but the config does not
 *         use a dictionary, when !advanced and the config has no level, when
 *         the cdict cannot be created, or when initialization reports an
 *         error.
 */
static int init_cstream(
    buffer_state_t* state,
    ZSTD_CStream* zcs,
    config_t const* config,
    int const advanced,
    ZSTD_CDict** cdict)
{
    size_t zret;
    if (advanced) {
        ZSTD_parameters const params = config_get_zstd_params(config, 0, 0);
        if (cdict) {
            if (!config->use_dictionary)
                return 1;
            *cdict = ZSTD_createCDict_advanced(
                state->dictionary.data,
                state->dictionary.size,
                ZSTD_dlm_byRef,
                ZSTD_dct_auto,
                params.cParams,
                ZSTD_defaultCMem);
            if (!*cdict) {
                return 1;
            }
            zret = ZSTD_initCStream_usingCDict_advanced(
                zcs, *cdict, params.fParams, ZSTD_CONTENTSIZE_UNKNOWN);
        } else {
            zret = ZSTD_initCStream_advanced(
                zcs,
                config->use_dictionary ? state->dictionary.data : NULL,
                config->use_dictionary ? state->dictionary.size : 0,
                params,
                ZSTD_CONTENTSIZE_UNKNOWN);
        }
    } else {
        int const level = config_get_level(config);
        if (level == CONFIG_NO_LEVEL)
            return 1;
        if (cdict) {
            if (!config->use_dictionary)
                return 1;
            *cdict = ZSTD_createCDict(
                state->dictionary.data,
                state->dictionary.size,
                level);
            if (!*cdict) {
                return 1;
            }
            zret = ZSTD_initCStream_usingCDict(zcs, *cdict);
        } else if (config->use_dictionary) {
            zret = ZSTD_initCStream_usingDict(
                zcs,
                state->dictionary.data,
                state->dictionary.size,
                level);
        } else {
            zret = ZSTD_initCStream(zcs, level);
        }
    }
    if (ZSTD_isError(zret)) {
        return 1;
    }
    return 0;
}
499 
old_streaming_compress_internal(method_state_t * base,config_t const * config,int const advanced,int const cdict)500 static result_t old_streaming_compress_internal(
501     method_state_t* base,
502     config_t const* config,
503     int const advanced,
504     int const cdict) {
505   buffer_state_t* state = container_of(base, buffer_state_t, base);
506 
507   if (buffer_state_bad(state, config))
508     return result_error(result_error_system_error);
509 
510 
511   ZSTD_CStream* zcs = ZSTD_createCStream();
512   ZSTD_CDict* cd = NULL;
513   result_t result;
514   if (zcs == NULL) {
515     result = result_error(result_error_compression_error);
516     goto out;
517   }
518   if (!advanced && config_get_level(config) == CONFIG_NO_LEVEL) {
519     result = result_error(result_error_skip);
520     goto out;
521   }
522   if (cdict && !config->use_dictionary) {
523     result = result_error(result_error_skip);
524     goto out;
525   }
526   if (init_cstream(state, zcs, config, advanced, cdict ? &cd : NULL)) {
527     result = result_error(result_error_compression_error);
528     goto out;
529   }
530 
531   result_data_t data = {.total_size = 0};
532   for (size_t i = 0; i < state->inputs.size; ++i) {
533     data_buffer_t input = state->inputs.buffers[i];
534     size_t zret = ZSTD_resetCStream(
535         zcs,
536         config->no_pledged_src_size ? ZSTD_CONTENTSIZE_UNKNOWN : input.size);
537     if (ZSTD_isError(zret)) {
538       result = result_error(result_error_compression_error);
539       goto out;
540     }
541 
542     while (input.size > 0) {
543       ZSTD_inBuffer in = {input.data, MIN(input.size, 4096)};
544       input.data += in.size;
545       input.size -= in.size;
546       ZSTD_EndDirective const op =
547           input.size > 0 ? ZSTD_e_continue : ZSTD_e_end;
548       zret = 0;
549       while (in.pos < in.size || (op == ZSTD_e_end && zret != 0)) {
550         ZSTD_outBuffer out = {state->compressed.data,
551                               MIN(state->compressed.capacity, 1024)};
552         if (op == ZSTD_e_continue || in.pos < in.size)
553           zret = ZSTD_compressStream(zcs, &out, &in);
554         else
555           zret = ZSTD_endStream(zcs, &out);
556         if (ZSTD_isError(zret)) {
557           result = result_error(result_error_compression_error);
558           goto out;
559         }
560         data.total_size += out.pos;
561       }
562     }
563   }
564 
565   result = result_data(data);
566 out:
567     ZSTD_freeCStream(zcs);
568     ZSTD_freeCDict(cd);
569     return result;
570 }
571 
old_streaming_compress(method_state_t * base,config_t const * config)572 static result_t old_streaming_compress(
573     method_state_t* base,
574     config_t const* config)
575 {
576     return old_streaming_compress_internal(
577         base, config, /* advanced */ 0, /* cdict */ 0);
578 }
579 
old_streaming_compress_advanced(method_state_t * base,config_t const * config)580 static result_t old_streaming_compress_advanced(
581     method_state_t* base,
582     config_t const* config)
583 {
584     return old_streaming_compress_internal(
585         base, config, /* advanced */ 1, /* cdict */ 0);
586 }
587 
old_streaming_compress_cdict(method_state_t * base,config_t const * config)588 static result_t old_streaming_compress_cdict(
589     method_state_t* base,
590     config_t const* config)
591 {
592     return old_streaming_compress_internal(
593         base, config, /* advanced */ 0, /* cdict */ 1);
594 }
595 
old_streaming_compress_cdict_advanced(method_state_t * base,config_t const * config)596 static result_t old_streaming_compress_cdict_advanced(
597     method_state_t* base,
598     config_t const* config)
599 {
600     return old_streaming_compress_internal(
601         base, config, /* advanced */ 1, /* cdict */ 1);
602 }
603 
/** Method that round-trips a single file via simple_compress(). */
method_t const simple = {
    .name = "compress simple",
    .create = buffer_state_create,
    .compress = simple_compress,
    .destroy = buffer_state_destroy,
};
610 
/** Method that round-trips a directory via compress_cctx_compress(). */
method_t const compress_cctx = {
    .name = "compress cctx",
    .create = buffer_state_create,
    .compress = compress_cctx_compress,
    .destroy = buffer_state_destroy,
};
617 
/** Method that compresses each input in one pass via advanced_one_pass_compress(). */
method_t const advanced_one_pass = {
    .name = "advanced one pass",
    .create = buffer_state_create,
    .compress = advanced_one_pass_compress,
    .destroy = buffer_state_destroy,
};
624 
625 method_t const advanced_one_pass_small_out = {
626     .name = "advanced one pass small out",
627     .create = buffer_state_create,
628     .compress = advanced_one_pass_compress,
629     .destroy = buffer_state_destroy,
630 };
631 
/** Method that streams each input via advanced_streaming_compress(). */
method_t const advanced_streaming = {
    .name = "advanced streaming",
    .create = buffer_state_create,
    .compress = advanced_streaming_compress,
    .destroy = buffer_state_destroy,
};
638 
/** Old streaming method: level-based initialization, no cdict. */
method_t const old_streaming = {
    .name = "old streaming",
    .create = buffer_state_create,
    .compress = old_streaming_compress,
    .destroy = buffer_state_destroy,
};
645 
/** Old streaming method: advanced-parameter initialization, no cdict. */
method_t const old_streaming_advanced = {
    .name = "old streaming advanced",
    .create = buffer_state_create,
    .compress = old_streaming_compress_advanced,
    .destroy = buffer_state_destroy,
};
652 
/**
 * Old streaming method: level-based initialization through a cdict.
 *
 * NOTE(review): the name string reads "cdcit", which looks like a typo for
 * "cdict". It is left unchanged because the string is runtime data; confirm
 * that nothing keys on the current spelling before correcting it.
 */
method_t const old_streaming_cdict = {
    .name = "old streaming cdcit",
    .create = buffer_state_create,
    .compress = old_streaming_compress_cdict,
    .destroy = buffer_state_destroy,
};
659 
/** Old streaming method: advanced-parameter initialization through a cdict. */
method_t const old_streaming_advanced_cdict = {
    .name = "old streaming advanced cdict",
    .create = buffer_state_create,
    .compress = old_streaming_compress_cdict_advanced,
    .destroy = buffer_state_destroy,
};
666 
/**
 * Method that runs the zstd CLI via cli_compress(). It uses the generic
 * state functions rather than the buffer state used by the other methods.
 */
method_t const cli = {
    .name = "zstdcli",
    .create = method_state_create,
    .compress = cli_compress,
    .destroy = method_state_destroy,
};
673 
/** NULL-terminated list of every method defined in this file. */
static method_t const* g_methods[] = {
    &simple,
    &compress_cctx,
    &cli,
    &advanced_one_pass,
    &advanced_one_pass_small_out,
    &advanced_streaming,
    &old_streaming,
    &old_streaming_advanced,
    &old_streaming_cdict,
    &old_streaming_advanced_cdict,
    NULL,
};

/** Externally visible pointer to the NULL-terminated method list above. */
method_t const* const* methods = g_methods;
689