1 /*
2  * Copyright (c) 2016-present, Yann Collet, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under both the BSD-style license (found in the
6  * LICENSE file in the root directory of this source tree) and the GPLv2 (found
7  * in the COPYING file in the root directory of this source tree).
8  * You may select, at your option, one of the above-listed licenses.
9  */
10 
11 /*
 * This program takes a file as input,
 * performs an LZ4 round-trip test (compress + decompress),
 * compares the result with the original,
 * and calls abort() when it detects corruption,
 * so that afl registers the event as a crash.
17 */
18 
19 
20 /*===========================================
21 *   Tuning Constant
22 *==========================================*/
#if !defined(MIN_CLEVEL)
   /* Lowest level select_clevel() may pick (can be overridden at build time).
    * Levels below 0 are intended as LZ4_compress_fast() acceleration values,
    * as stated in roundTripTest(). */
#  define MIN_CLEVEL ((int)-5)
#endif
26 
27 
28 
29 /*===========================================
30 *   Dependencies
31 *==========================================*/
#include <stddef.h>     /* size_t */
#include <stdlib.h>     /* malloc, free, exit */
#include <stdio.h>      /* fprintf */
#include <string.h>     /* strcmp */
#include <assert.h>
#include <errno.h>      /* errno, ERANGE */
#include <limits.h>     /* INT_MAX */
#include <sys/types.h>  /* stat */
#include <sys/stat.h>   /* stat */
#include "xxhash.h"

#include "lz4.h"
#include "lz4hc.h"
43 
44 
45 /*===========================================
46 *   Macros
47 *==========================================*/
/* MIN() : smaller of two values.
 * Arguments may be evaluated twice : do not pass expressions with side effects. */
#define MIN(a,b)  ( (a) < (b) ? (a) : (b) )

/* MSG() : printf-style diagnostic written to stderr */
#define MSG(...)    fprintf(stderr, __VA_ARGS__)

/* CONTROL_MSG() : when condition `c` is true, print the printf-style message
 * (followed by " \n") on stderr, then abort() so that afl registers a crash.
 * The do { } while (0) wrapper makes `CONTROL_MSG(...);` a single statement,
 * usable as the body of an unbraced if/else (a bare { } block is not). */
#define CONTROL_MSG(c, ...) do {    \
    if ((c)) {                      \
        MSG(__VA_ARGS__);           \
        MSG(" \n");                 \
        abort();                    \
    }                               \
} while (0)
59 
60 
/* checkBuffers() :
 * compares the first `buffSize` bytes of `buff1` and `buff2`.
 * @return index of the first differing byte,
 *         or `buffSize` when both ranges are identical */
static size_t checkBuffers(const void* buff1, const void* buff2, size_t buffSize)
{
    const unsigned char* const lhs = (const unsigned char*)buff1;
    const unsigned char* const rhs = (const unsigned char*)buff2;
    size_t idx = 0;

    while (idx < buffSize && lhs[idx] == rhs[idx])
        idx++;

    return idx;
}
73 
74 
75 /* select a compression level
76  * based on first bytes present in a reference buffer */
select_clevel(const void * refBuff,size_t refBuffSize)77 static int select_clevel(const void* refBuff, size_t refBuffSize)
78 {
79     const int minCLevel = MIN_CLEVEL;
80     const int maxClevel = LZ4HC_CLEVEL_MAX;
81     const int cLevelSpan = maxClevel - minCLevel;
82     size_t const hashLength = MIN(16, refBuffSize);
83     unsigned const h32 = XXH32(refBuff, hashLength, 0);
84     int const randL = h32 % (cLevelSpan+1);
85 
86     return minCLevel + randL;
87 }
88 
89 
90 typedef int (*compressFn)(const char* src, char* dst, int srcSize, int dstSize, int cLevel);
91 
92 
93 /** roundTripTest() :
94  *  Compresses `srcBuff` into `compressedBuff`,
95  *  then decompresses `compressedBuff` into `resultBuff`.
96  *  If clevel==0, compression level is derived from srcBuff's content head bytes.
97  *  This function abort() if it detects any round-trip error.
98  *  Therefore, if it returns, round trip is considered successfully validated.
99  *  Note : `compressedBuffCapacity` should be `>= LZ4_compressBound(srcSize)`
100  *         for compression to be guaranteed to work */
roundTripTest(void * resultBuff,size_t resultBuffCapacity,void * compressedBuff,size_t compressedBuffCapacity,const void * srcBuff,size_t srcSize,int clevel)101 static void roundTripTest(void* resultBuff, size_t resultBuffCapacity,
102                           void* compressedBuff, size_t compressedBuffCapacity,
103                     const void* srcBuff, size_t srcSize,
104                           int clevel)
105 {
106     int const proposed_clevel = clevel ? clevel : select_clevel(srcBuff, srcSize);
107     int const selected_clevel = proposed_clevel < 0 ? -proposed_clevel : proposed_clevel;   /* if level < 0, it becomes an accelearion value */
108     compressFn compress = selected_clevel >= LZ4HC_CLEVEL_MIN ? LZ4_compress_HC : LZ4_compress_fast;
109     int const cSize = compress((const char*)srcBuff, (char*)compressedBuff, (int)srcSize, (int)compressedBuffCapacity, selected_clevel);
110     CONTROL_MSG(cSize == 0, "Compression error !");
111 
112     {   int const dSize = LZ4_decompress_safe((const char*)compressedBuff, (char*)resultBuff, cSize, (int)resultBuffCapacity);
113         CONTROL_MSG(dSize < 0, "Decompression detected an error !");
114         CONTROL_MSG(dSize != (int)srcSize, "Decompression corruption error : wrong decompressed size !");
115     }
116 
117     /* check potential content corruption error */
118     assert(resultBuffCapacity >= srcSize);
119     {   size_t const errorPos = checkBuffers(srcBuff, resultBuff, srcSize);
120         CONTROL_MSG(errorPos != srcSize,
121                     "Silent decoding corruption, at pos %u !!!",
122                     (unsigned)errorPos);
123     }
124 
125 }
126 
roundTripCheck(const void * srcBuff,size_t srcSize,int clevel)127 static void roundTripCheck(const void* srcBuff, size_t srcSize, int clevel)
128 {
129     size_t const cBuffSize = LZ4_compressBound((int)srcSize);
130     void* const cBuff = malloc(cBuffSize);
131     void* const rBuff = malloc(cBuffSize);
132 
133     if (!cBuff || !rBuff) {
134         fprintf(stderr, "not enough memory ! \n");
135         exit(1);
136     }
137 
138     roundTripTest(rBuff, cBuffSize,
139                   cBuff, cBuffSize,
140                   srcBuff, srcSize,
141                   clevel);
142 
143     free(rBuff);
144     free(cBuff);
145 }
146 
147 
/* getFileSize() :
 * @return size in bytes of `infilename` when it is a regular file,
 *         0 when it cannot be stat()-ed or is not a regular file. */
static size_t getFileSize(const char* infilename)
{
#if defined(_MSC_VER)
    struct _stat64 info;
    int const isRegularFile = (_stat64(infilename, &info) == 0)
                           && (info.st_mode & S_IFREG);
#else
    struct stat info;
    int const isRegularFile = (stat(infilename, &info) == 0)
                           && S_ISREG(info.st_mode);
#endif
    return isRegularFile ? (size_t)info.st_size : 0;
}
162 
163 
/* isDirectory() :
 * @return 1 when `infilename` can be stat()-ed and is a directory, 0 otherwise. */
static int isDirectory(const char* infilename)
{
#if defined(_MSC_VER)
    struct _stat64 info;
    return (_stat64(infilename, &info) == 0) && (info.st_mode & _S_IFDIR);
#else
    struct stat info;
    return (stat(infilename, &info) == 0) && S_ISDIR(info.st_mode);
#endif
}
178 
179 
/** loadFile() :
 *  reads exactly `fileSize` bytes of file `fileName` into `buffer`.
 *  requirement : `buffer` size >= `fileSize`
 *  Exits with code 2 if `fileName` is a directory,
 *  3 if it cannot be opened, 5 if fewer than `fileSize` bytes are read. */
static void loadFile(void* buffer, const char* fileName, size_t fileSize)
{
    FILE* f;

    /* Reject directories before fopen() : a read-only open of a directory may
     * succeed (e.g. on POSIX systems), and that stream would never be used or closed. */
    if (isDirectory(fileName)) {
        MSG("Ignoring %s directory \n", fileName);
        exit(2);
    }

    f = fopen(fileName, "rb");
    if (f==NULL) {
        MSG("Impossible to open %s \n", fileName);
        exit(3);
    }

    {   size_t const readSize = fread(buffer, 1, fileSize, f);
        fclose(f);   /* read-only stream : close it before any error exit */
        if (readSize != fileSize) {
            MSG("Error reading %s \n", fileName);
            exit(5);
    }   }
}
200 
201 
/* fileCheck() :
 * loads the whole content of `fileName` into memory,
 * then round-trip tests it at compression level `clevel` (0 == derived from content).
 * Exits with code 4 if the load buffer cannot be allocated. */
static void fileCheck(const char* fileName, int clevel)
{
    size_t const srcSize = getFileSize(fileName);
    /* request at least 1 byte : malloc(0) is allowed to return NULL */
    size_t const allocSize = srcSize ? srcSize : 1;
    void* const srcBuffer = malloc(allocSize);

    if (srcBuffer == NULL) {
        MSG("not enough memory \n");
        exit(4);
    }

    loadFile(srcBuffer, fileName, srcSize);
    roundTripCheck(srcBuffer, srcSize, clevel);
    free(srcBuffer);
}
214 
215 
/* bad_usage() :
 * prints command-line help for program `exeName` on stderr.
 * @return 1, which main() uses as its exit code */
int bad_usage(const char* exeName)
{
    MSG(" \n"
        "bad usage: \n"
        " \n"
        "%s [Options] fileName \n"
        " \n"
        "Options: \n"
        "-#     : use #=[0-9] compression level (default:0 == random) \n",
        exeName);
    return 1;
}
227 
228 
/* main() :
 * usage : exeName [-#] fileName
 * `-#` : everything after the leading '-' is parsed as a decimal integer
 *        and used as compression level (0, the default, derives the level
 *        from the file's first bytes). Malformed or out-of-range values
 *        print the help message via bad_usage().
 * Prints "no pb detected" and returns 0 once the round trip has been validated. */
int main(int argCount, const char** argv)
{
    const char* const exeName = argv[0];
    int argNb = 1;
    int clevel = 0;

    assert(argCount >= 1);
    if (argCount < 2) return bad_usage(exeName);

    if (argv[1][0] == '-') {
        /* Parse the whole argument : reading only argv[1][1] turned "-12" into
         * level 1, and "-" or "-x" into arbitrary levels ('\0' - '0' == -48). */
        const char* const levelStr = argv[1] + 1;
        char* end = NULL;
        long level;

        errno = 0;
        level = strtol(levelStr, &end, 10);
        /* also reject INT_MIN, so that negating the level later stays defined */
        if (end == levelStr || *end != '\0' || errno == ERANGE
          || level < -INT_MAX || level > INT_MAX)
            return bad_usage(exeName);

        clevel = (int)level;
        argNb = 2;
    }

    if (argNb >= argCount) return bad_usage(exeName);

    fileCheck(argv[argNb], clevel);
    MSG("no pb detected \n");
    return 0;
}
249