1 /*
2  * Copyright (c) 7-2020, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under both the BSD-style license (found in the
6  * LICENSE file in the root directory of this source tree) and the GPLv2 (found
7  * in the COPYING file in the root directory of this source tree).
8  * You may select, at your option, one of the above-listed licenses.
9  */
10 #include <stddef.h>
11 #include <stdio.h>
12 #include <stdlib.h>
13 #include <string.h>
14 
15 #include <linux/zstd.h>
16 
/*
 * CONTROL(x): unconditional test assertion.
 *
 * When x evaluates to false, prints "<function>:<line>: <expression> failed!"
 * to stderr and calls abort(). Nothing in the macro depends on NDEBUG.
 * The do { } while (0) wrapper makes the expansion a single statement, so the
 * macro is safe in unbraced if/else bodies.
 *
 * __func__ is the standard (C99) spelling of the function name, and __LINE__
 * is an int, which is why it is printed with %d.
 */
#define CONTROL(x)                                                             \
  do {                                                                         \
    if (!(x)) {                                                                \
      fprintf(stderr, "%s:%d: %s failed!\n", __func__, __LINE__, #x);          \
      abort();                                                                 \
    }                                                                          \
  } while (0)
24 
/*
 * Buffers shared by the tests in this file. Filled in by create_test_data()
 * and released by free_test_data().
 */
typedef struct {
  char *data;      /* input buffer, dataSize bytes */
  char *data2;     /* second dataSize-byte buffer; tests decompress into it */
  size_t dataSize; /* size in bytes of data and of data2 */
  char *comp;      /* buffer that receives compressed output */
  size_t compSize; /* capacity of comp in bytes */
} test_data_t;
32 
create_test_data(void)33 test_data_t create_test_data(void) {
34   test_data_t data;
35   data.dataSize = 128 * 1024;
36   data.data = malloc(data.dataSize);
37   CONTROL(data.data != NULL);
38   data.data2 = malloc(data.dataSize);
39   CONTROL(data.data2 != NULL);
40   data.compSize = zstd_compress_bound(data.dataSize);
41   data.comp = malloc(data.compSize);
42   CONTROL(data.comp != NULL);
43   memset(data.data, 0, data.dataSize);
44   return data;
45 }
46 
/*
 * Releases the three heap buffers referenced by *data (data, data2, comp),
 * in that order. The struct itself is not freed and its pointers are left
 * unchanged.
 */
static void free_test_data(test_data_t const *data) {
  char *const buffers[] = {data->data, data->data2, data->comp};
  size_t const count = sizeof(buffers) / sizeof(buffers[0]);

  for (size_t i = 0; i < count; ++i) {
    free(buffers[i]);
  }
}
52 
/*
 * MIN(a, b) / MAX(a, b): the smaller / larger of two values.
 *
 * Each argument is evaluated exactly once, so operands with side effects or
 * costly function calls are safe. The operands keep their own types and the
 * result has the type the conditional operator yields for them, exactly as
 * with the classic ((a) < (b) ? (a) : (b)) form.
 *
 * Implemented with GNU statement expressions and __typeof__, so the macros
 * are only usable inside function bodies (not in constant expressions).
 */
#define MIN(a, b)                                                              \
  __extension__({                                                              \
    __typeof__(a) min_lhs_ = (a);                                              \
    __typeof__(b) min_rhs_ = (b);                                              \
    min_lhs_ < min_rhs_ ? min_lhs_ : min_rhs_;                                 \
  })
#define MAX(a, b)                                                              \
  __extension__({                                                              \
    __typeof__(a) max_lhs_ = (a);                                              \
    __typeof__(b) max_rhs_ = (b);                                              \
    max_lhs_ > max_rhs_ ? max_lhs_ : max_rhs_;                                 \
  })
55 
/*
 * Streaming round trip (reported as "btrfs use cases" in the output).
 *
 * For every compression level from -1 to 15 the first `size` bytes of
 * data->data (at most 128 KiB) are compressed into data->comp through a
 * zstd_cstream and decompressed into data->data2 through a zstd_dstream.
 * Input and output are handed to the streams in chunks of at most 4096
 * bytes. A single malloc'd workspace, sized for the larger of the two
 * streams, is reused for both directions.
 *
 * The test aborts through CONTROL() if any zstd call reports an error, if
 * the number of decompressed bytes differs from `size`, or if the
 * decompressed bytes differ from the input.
 */
static void test_btrfs(test_data_t const *data) {
  fprintf(stderr, "testing btrfs use cases... ");
  size_t const size = MIN(data->dataSize, 128 * 1024);
  for (int level = -1; level < 16; ++level) {
    struct zstd_parameters params = zstd_get_params(level, size);
    /* size never exceeds 128 KiB == 1 << 17, and neither may the window. */
    CONTROL(params.cparams.window_log <= 17);
    size_t const workspaceSize =
        MAX(zstd_cstream_workspace_bound(&params.cparams),
            zstd_dstream_workspace_bound(size));
    void *workspace = malloc(workspaceSize);
    CONTROL(workspace != NULL);

    char const *ip = data->data;
    char const *iend = ip + size;
    char *op = data->comp;
    char *oend = op + data->compSize;
    {
      zstd_cstream *cctx = zstd_init_cstream(&params, size, workspace, workspaceSize);
      CONTROL(cctx != NULL);
      struct zstd_out_buffer out = {NULL, 0, 0};
      struct zstd_in_buffer in = {NULL, 0, 0};
      for (;;) {
        /* Current input chunk fully consumed: hand over the next one. */
        if (in.pos == in.size) {
          in.src = ip;
          in.size = MIN(4096, iend - ip);
          in.pos = 0;
          ip += in.size;
        }

        /*
         * Current output chunk full: start the next one. Note that op is
         * advanced to the END of the new chunk right away.
         */
        if (out.pos == out.size) {
          out.dst = op;
          out.size = MIN(4096, oend - op);
          out.pos = 0;
          op += out.size;
        }

        if (ip != iend || in.pos < in.size) {
          /* Input is still pending (not yet loaded, or loaded but unread). */
          CONTROL(!zstd_is_error(zstd_compress_stream(cctx, &out, &in)));
        } else {
          /* All input consumed: finish the frame; loop until it returns 0. */
          size_t const ret = zstd_end_stream(cctx, &out);
          CONTROL(!zstd_is_error(ret));
          if (ret == 0) {
            break;
          }
        }
      }
      /*
       * The compressed data ends out.pos bytes into the last chunk. op was
       * already moved past that chunk when it was set up, so adding out.pos
       * to op would overshoot the real end by out.size bytes.
       */
      op = (char *)out.dst + out.pos;
    }

    /*
     * data2 still holds whatever an earlier round trip wrote. Fill it with
     * the bitwise complement of the input so that stale contents can never
     * satisfy the final memcmp().
     */
    for (size_t i = 0; i < size; ++i) {
      data->data2[i] = (char)~data->data[i];
    }

    ip = data->comp;
    iend = op;
    op = data->data2;
    oend = op + size;
    size_t decompressedSize = 0;
    {
      zstd_dstream *dctx = zstd_init_dstream(1ULL << params.cparams.window_log, workspace, workspaceSize);
      CONTROL(dctx != NULL);
      struct zstd_out_buffer out = {NULL, 0, 0};
      struct zstd_in_buffer in = {NULL, 0, 0};
      for (;;) {
        if (in.pos == in.size) {
          in.src = ip;
          in.size = MIN(4096, iend - ip);
          in.pos = 0;
          ip += in.size;
        }

        if (out.pos == out.size) {
          out.dst = op;
          out.size = MIN(4096, oend - op);
          out.pos = 0;
          op += out.size;
        }

        size_t const ret = zstd_decompress_stream(dctx, &out, &in);
        CONTROL(!zstd_is_error(ret));
        if (ret == 0) {
          break;
        }
      }
      /* Count the bytes actually written, not the chunk capacity handed out. */
      decompressedSize = (size_t)(((char *)out.dst + out.pos) - data->data2);
    }
    /* Only the first `size` bytes were round-tripped, so check exactly those. */
    CONTROL(decompressedSize == size);
    CONTROL(!memcmp(data->data, data->data2, size));
    free(workspace);
  }
  fprintf(stderr, "Ok\n");
}
142 
/*
 * One-shot round trip.
 *
 * Compresses all of data->data into data->comp with a zstd_cctx, using the
 * parameters returned by zstd_get_params(19, 0), then decompresses the
 * result into data->data2 with a zstd_dctx. Each context is placed in its
 * own malloc'd workspace, which is freed as soon as that step is done.
 *
 * Aborts through CONTROL() on allocation failure, on any zstd error, or when
 * the decompressed size or contents differ from the input.
 */
static void test_decompress_unzstd(test_data_t const *data) {
  size_t cSize;

  fprintf(stderr, "Testing decompress unzstd... ");

  /* Step 1: compress data->data into data->comp. */
  {
    struct zstd_parameters cfg = zstd_get_params(19, 0);
    size_t const wkspBytes = zstd_cctx_workspace_bound(&cfg.cparams);
    void *wksp = malloc(wkspBytes);
    CONTROL(wksp != NULL);

    zstd_cctx *cctx = zstd_init_cctx(wksp, wkspBytes);
    CONTROL(cctx != NULL);

    cSize = zstd_compress_cctx(cctx, data->comp, data->compSize, data->data,
                               data->dataSize, &cfg);
    CONTROL(!zstd_is_error(cSize));
    free(wksp);
  }

  /* Step 2: decompress the cSize bytes back into data->data2 and compare. */
  {
    size_t const wkspBytes = zstd_dctx_workspace_bound();
    void *wksp = malloc(wkspBytes);
    CONTROL(wksp != NULL);

    zstd_dctx *dctx = zstd_init_dctx(wksp, wkspBytes);
    CONTROL(dctx != NULL);

    size_t const dSize = zstd_decompress_dctx(dctx, data->data2, data->dataSize,
                                              data->comp, cSize);
    CONTROL(!zstd_is_error(dSize));
    CONTROL(dSize == data->dataSize);
    CONTROL(!memcmp(data->data, data->data2, data->dataSize));
    free(wksp);
  }

  fprintf(stderr, "Ok\n");
}
171 
/*
 * Lowest address of the probe buffer painted by set_stack(); check_stack()
 * reads the paint back through this pointer. NULL until set_stack() has run.
 */
static char *g_stack = NULL;
173 
/*
 * Optimization barrier for a pointer.
 *
 * The empty asm statement claims to read and modify x ("+r"), and the
 * "memory" clobber tells the compiler that the asm may read or write
 * arbitrary memory, so stores made through x before the call cannot be
 * treated as dead. The function is noinline and emits no instructions of
 * its own inside the asm.
 *
 * Spelled __asm__ __volatile__ because plain `asm` is not a keyword in
 * strict ISO C modes (e.g. -std=c11).
 */
static void __attribute__((noinline)) use(void *x) {
  __asm__ __volatile__("" : "+r"(x) : : "memory");
}
177 
set_stack()178 static void __attribute__((noinline)) set_stack() {
179 
180   char stack[8192];
181   g_stack = stack;
182   memset(g_stack, 0x33, 8192);
183   use(g_stack);
184 }
185 
check_stack()186 static void __attribute__((noinline)) check_stack() {
187   size_t cleanStack = 0;
188   while (cleanStack < 8192 && g_stack[cleanStack] == 0x33) {
189     ++cleanStack;
190   }
191   size_t const stackSize = 8192 - cleanStack;
192   fprintf(stderr, "Maximum stack size: %zu\n", stackSize);
193   CONTROL(stackSize <= 2048 + 512);
194 }
195 
/*
 * Measures the stack consumed by the two round-trip tests.
 *
 * The sequence matters: set_stack() paints the stack first, both tests then
 * run on top of the painted region, and check_stack() finally reports the
 * overwritten amount and enforces the limit.
 */
static void test_stack_usage(test_data_t const *data) {
  /* 1. Paint. */
  set_stack();

  /* 2. Run the code under measurement. */
  test_btrfs(data);
  test_decompress_unzstd(data);

  /* 3. Inspect what is left of the paint. */
  check_stack();
}
202 
main(void)203 int main(void) {
204   test_data_t data = create_test_data();
205   test_btrfs(&data);
206   test_decompress_unzstd(&data);
207   test_stack_usage(&data);
208   free_test_data(&data);
209   return 0;
210 }
211