1 /*
2  * Copyright (c) 2016-2020, Yann Collet, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under both the BSD-style license (found in the
6  * LICENSE file in the root directory of this source tree) and the GPLv2 (found
7  * in the COPYING file in the root directory of this source tree).
8  * You may select, at your option, one of the above-listed licenses.
9  */
10 
11 #include "zstd_ldm.h"
12 
13 #include "../common/debug.h"
14 #include "zstd_fast.h"          /* ZSTD_fillHashTable() */
15 #include "zstd_double_fast.h"   /* ZSTD_fillDoubleHashTable() */
16 
/* Default bucketSizeLog, applied by ZSTD_ldm_adjustParameters() when
 * params->bucketSizeLog is 0. */
#define LDM_BUCKET_SIZE_LOG 3
/* Default minMatchLength, applied by ZSTD_ldm_adjustParameters() when
 * params->minMatchLength is 0. */
#define LDM_MIN_MATCH_LENGTH 64
/* When params->hashLog is 0 it defaults to windowLog - LDM_HASH_RLOG,
 * clamped from below at ZSTD_HASHLOG_MIN. */
#define LDM_HASH_RLOG 7
/* NOTE(review): not referenced by any code visible in this file; presumably
 * a leftover from an earlier hashing scheme - confirm before removing. */
#define LDM_HASH_CHAR_OFFSET 10
21 
/** ZSTD_ldm_adjustParameters() :
 *  Completes an LDM parameter set from the compression parameters:
 *  - windowLog is always copied from cParams;
 *  - bucketSizeLog, minMatchLength, hashLog and hashRateLog get defaults
 *    when they are 0;
 *  - bucketSizeLog is finally clamped so it never exceeds hashLog. */
void ZSTD_ldm_adjustParameters(ldmParams_t* params,
                               ZSTD_compressionParameters const* cParams)
{
    params->windowLog = cParams->windowLog;
    ZSTD_STATIC_ASSERT(LDM_BUCKET_SIZE_LOG <= ZSTD_LDM_BUCKETSIZELOG_MAX);
    DEBUGLOG(4, "ZSTD_ldm_adjustParameters");

    if (params->bucketSizeLog == 0) {
        params->bucketSizeLog = LDM_BUCKET_SIZE_LOG;
    }
    if (params->minMatchLength == 0) {
        params->minMatchLength = LDM_MIN_MATCH_LENGTH;
    }
    if (params->hashLog == 0) {
        /* Hash table is 2^LDM_HASH_RLOG times smaller than the window. */
        params->hashLog = MAX(ZSTD_HASHLOG_MIN, params->windowLog - LDM_HASH_RLOG);
        assert(params->hashLog <= ZSTD_HASHLOG_MAX);
    }
    if (params->hashRateLog == 0) {
        /* Gap between window and table size, never negative. */
        if (params->windowLog >= params->hashLog) {
            params->hashRateLog = params->windowLog - params->hashLog;
        } else {
            params->hashRateLog = 0;
        }
    }
    params->bucketSizeLog = MIN(params->bucketSizeLog, params->hashLog);
}
41 
/** ZSTD_ldm_getTableSize() :
 *  @return the workspace bytes needed for LDM state: the per-bucket offset
 *          array (one slot per bucket) plus the hash table of
 *          (1 << hashLog) ldmEntry_t, each sized via ZSTD_cwksp_alloc_size().
 *          Returns 0 when LDM is disabled. */
size_t ZSTD_ldm_getTableSize(ldmParams_t params)
{
    size_t bucketOffsetsBytes;
    size_t hashTableBytes;

    if (!params.enableLdm) {
        return 0;
    }
    {   size_t const effectiveBucketSizeLog = MIN(params.bucketSizeLog, params.hashLog);
        size_t const nbBuckets = ((size_t)1) << (params.hashLog - effectiveBucketSizeLog);
        size_t const nbEntries = ((size_t)1) << params.hashLog;
        bucketOffsetsBytes = ZSTD_cwksp_alloc_size(nbBuckets);
        hashTableBytes = ZSTD_cwksp_alloc_size(nbEntries * sizeof(ldmEntry_t));
    }
    return bucketOffsetsBytes + hashTableBytes;
}
51 
/** ZSTD_ldm_getMaxNbSeq() :
 *  @return an upper bound on the number of LDM sequences that can be
 *          produced from maxChunkSize bytes (every emitted match is at
 *          least minMatchLength long), or 0 when LDM is disabled. */
size_t ZSTD_ldm_getMaxNbSeq(ldmParams_t params, size_t maxChunkSize)
{
    if (!params.enableLdm) {
        return 0;
    }
    return maxChunkSize / params.minMatchLength;
}
56 
57 /** ZSTD_ldm_getSmallHash() :
58  *  numBits should be <= 32
59  *  If numBits==0, returns 0.
60  *  @return : the most significant numBits of value. */
ZSTD_ldm_getSmallHash(U64 value,U32 numBits)61 static U32 ZSTD_ldm_getSmallHash(U64 value, U32 numBits)
62 {
63     assert(numBits <= 32);
64     return numBits == 0 ? 0 : (U32)(value >> (64 - numBits));
65 }
66 
67 /** ZSTD_ldm_getChecksum() :
68  *  numBitsToDiscard should be <= 32
69  *  @return : the next most significant 32 bits after numBitsToDiscard */
ZSTD_ldm_getChecksum(U64 hash,U32 numBitsToDiscard)70 static U32 ZSTD_ldm_getChecksum(U64 hash, U32 numBitsToDiscard)
71 {
72     assert(numBitsToDiscard <= 32);
73     return (hash >> (64 - 32 - numBitsToDiscard)) & 0xFFFFFFFF;
74 }
75 
76 /** ZSTD_ldm_getTag() ;
77  *  Given the hash, returns the most significant numTagBits bits
78  *  after (32 + hbits) bits.
79  *
80  *  If there are not enough bits remaining, return the last
81  *  numTagBits bits. */
ZSTD_ldm_getTag(U64 hash,U32 hbits,U32 numTagBits)82 static U32 ZSTD_ldm_getTag(U64 hash, U32 hbits, U32 numTagBits)
83 {
84     assert(numTagBits < 32 && hbits <= 32);
85     if (32 - hbits < numTagBits) {
86         return hash & (((U32)1 << numTagBits) - 1);
87     } else {
88         return (hash >> (32 - hbits - numTagBits)) & (((U32)1 << numTagBits) - 1);
89     }
90 }
91 
92 /** ZSTD_ldm_getBucket() :
93  *  Returns a pointer to the start of the bucket associated with hash. */
ZSTD_ldm_getBucket(ldmState_t * ldmState,size_t hash,ldmParams_t const ldmParams)94 static ldmEntry_t* ZSTD_ldm_getBucket(
95         ldmState_t* ldmState, size_t hash, ldmParams_t const ldmParams)
96 {
97     return ldmState->hashTable + (hash << ldmParams.bucketSizeLog);
98 }
99 
100 /** ZSTD_ldm_insertEntry() :
101  *  Insert the entry with corresponding hash into the hash table */
ZSTD_ldm_insertEntry(ldmState_t * ldmState,size_t const hash,const ldmEntry_t entry,ldmParams_t const ldmParams)102 static void ZSTD_ldm_insertEntry(ldmState_t* ldmState,
103                                  size_t const hash, const ldmEntry_t entry,
104                                  ldmParams_t const ldmParams)
105 {
106     BYTE* const bucketOffsets = ldmState->bucketOffsets;
107     *(ZSTD_ldm_getBucket(ldmState, hash, ldmParams) + bucketOffsets[hash]) = entry;
108     bucketOffsets[hash]++;
109     bucketOffsets[hash] &= ((U32)1 << ldmParams.bucketSizeLog) - 1;
110 }
111 
112 /** ZSTD_ldm_makeEntryAndInsertByTag() :
113  *
114  *  Gets the small hash, checksum, and tag from the rollingHash.
115  *
116  *  If the tag matches (1 << ldmParams.hashRateLog)-1, then
117  *  creates an ldmEntry from the offset, and inserts it into the hash table.
118  *
119  *  hBits is the length of the small hash, which is the most significant hBits
120  *  of rollingHash. The checksum is the next 32 most significant bits, followed
121  *  by ldmParams.hashRateLog bits that make up the tag. */
ZSTD_ldm_makeEntryAndInsertByTag(ldmState_t * ldmState,U64 const rollingHash,U32 const hBits,U32 const offset,ldmParams_t const ldmParams)122 static void ZSTD_ldm_makeEntryAndInsertByTag(ldmState_t* ldmState,
123                                              U64 const rollingHash,
124                                              U32 const hBits,
125                                              U32 const offset,
126                                              ldmParams_t const ldmParams)
127 {
128     U32 const tag = ZSTD_ldm_getTag(rollingHash, hBits, ldmParams.hashRateLog);
129     U32 const tagMask = ((U32)1 << ldmParams.hashRateLog) - 1;
130     if (tag == tagMask) {
131         U32 const hash = ZSTD_ldm_getSmallHash(rollingHash, hBits);
132         U32 const checksum = ZSTD_ldm_getChecksum(rollingHash, hBits);
133         ldmEntry_t entry;
134         entry.offset = offset;
135         entry.checksum = checksum;
136         ZSTD_ldm_insertEntry(ldmState, hash, entry, ldmParams);
137     }
138 }
139 
140 /** ZSTD_ldm_countBackwardsMatch() :
141  *  Returns the number of bytes that match backwards before pIn and pMatch.
142  *
143  *  We count only bytes where pMatch >= pBase and pIn >= pAnchor. */
ZSTD_ldm_countBackwardsMatch(const BYTE * pIn,const BYTE * pAnchor,const BYTE * pMatch,const BYTE * pMatchBase)144 static size_t ZSTD_ldm_countBackwardsMatch(
145             const BYTE* pIn, const BYTE* pAnchor,
146             const BYTE* pMatch, const BYTE* pMatchBase)
147 {
148     size_t matchLength = 0;
149     while (pIn > pAnchor && pMatch > pMatchBase && pIn[-1] == pMatch[-1]) {
150         pIn--;
151         pMatch--;
152         matchLength++;
153     }
154     return matchLength;
155 }
156 
157 /** ZSTD_ldm_countBackwardsMatch_2segments() :
158  *  Returns the number of bytes that match backwards from pMatch,
159  *  even with the backwards match spanning 2 different segments.
160  *
161  *  On reaching `pMatchBase`, start counting from mEnd */
ZSTD_ldm_countBackwardsMatch_2segments(const BYTE * pIn,const BYTE * pAnchor,const BYTE * pMatch,const BYTE * pMatchBase,const BYTE * pExtDictStart,const BYTE * pExtDictEnd)162 static size_t ZSTD_ldm_countBackwardsMatch_2segments(
163                     const BYTE* pIn, const BYTE* pAnchor,
164                     const BYTE* pMatch, const BYTE* pMatchBase,
165                     const BYTE* pExtDictStart, const BYTE* pExtDictEnd)
166 {
167     size_t matchLength = ZSTD_ldm_countBackwardsMatch(pIn, pAnchor, pMatch, pMatchBase);
168     if (pMatch - matchLength != pMatchBase || pMatchBase == pExtDictStart) {
169         /* If backwards match is entirely in the extDict or prefix, immediately return */
170         return matchLength;
171     }
172     DEBUGLOG(7, "ZSTD_ldm_countBackwardsMatch_2segments: found 2-parts backwards match (length in prefix==%zu)", matchLength);
173     matchLength += ZSTD_ldm_countBackwardsMatch(pIn - matchLength, pAnchor, pExtDictEnd, pExtDictStart);
174     DEBUGLOG(7, "final backwards match length = %zu", matchLength);
175     return matchLength;
176 }
177 
178 /** ZSTD_ldm_fillFastTables() :
179  *
180  *  Fills the relevant tables for the ZSTD_fast and ZSTD_dfast strategies.
181  *  This is similar to ZSTD_loadDictionaryContent.
182  *
183  *  The tables for the other strategies are filled within their
184  *  block compressors. */
ZSTD_ldm_fillFastTables(ZSTD_matchState_t * ms,void const * end)185 static size_t ZSTD_ldm_fillFastTables(ZSTD_matchState_t* ms,
186                                       void const* end)
187 {
188     const BYTE* const iend = (const BYTE*)end;
189 
190     switch(ms->cParams.strategy)
191     {
192     case ZSTD_fast:
193         ZSTD_fillHashTable(ms, iend, ZSTD_dtlm_fast);
194         break;
195 
196     case ZSTD_dfast:
197         ZSTD_fillDoubleHashTable(ms, iend, ZSTD_dtlm_fast);
198         break;
199 
200     case ZSTD_greedy:
201     case ZSTD_lazy:
202     case ZSTD_lazy2:
203     case ZSTD_btlazy2:
204     case ZSTD_btopt:
205     case ZSTD_btultra:
206     case ZSTD_btultra2:
207         break;
208     default:
209         assert(0);  /* not possible : not a valid strategy id */
210     }
211 
212     return 0;
213 }
214 
215 /** ZSTD_ldm_fillLdmHashTable() :
216  *
217  *  Fills hashTable from (lastHashed + 1) to iend (non-inclusive).
218  *  lastHash is the rolling hash that corresponds to lastHashed.
219  *
220  *  Returns the rolling hash corresponding to position iend-1. */
ZSTD_ldm_fillLdmHashTable(ldmState_t * state,U64 lastHash,const BYTE * lastHashed,const BYTE * iend,const BYTE * base,U32 hBits,ldmParams_t const ldmParams)221 static U64 ZSTD_ldm_fillLdmHashTable(ldmState_t* state,
222                                      U64 lastHash, const BYTE* lastHashed,
223                                      const BYTE* iend, const BYTE* base,
224                                      U32 hBits, ldmParams_t const ldmParams)
225 {
226     U64 rollingHash = lastHash;
227     const BYTE* cur = lastHashed + 1;
228 
229     while (cur < iend) {
230         rollingHash = ZSTD_rollingHash_rotate(rollingHash, cur[-1],
231                                               cur[ldmParams.minMatchLength-1],
232                                               state->hashPower);
233         ZSTD_ldm_makeEntryAndInsertByTag(state,
234                                          rollingHash, hBits,
235                                          (U32)(cur - base), ldmParams);
236         ++cur;
237     }
238     return rollingHash;
239 }
240 
/** ZSTD_ldm_fillHashTable() :
 *  Indexes [ip, iend) into the LDM hash table. It hashes the first
 *  minMatchLength bytes at ip, then rolls forward and offers each position
 *  up to (but excluding) iend - minMatchLength for insertion.
 *  Does nothing when fewer than params->minMatchLength bytes are available. */
void ZSTD_ldm_fillHashTable(
            ldmState_t* state, const BYTE* ip,
            const BYTE* iend, ldmParams_t const* params)
{
    size_t const available = (size_t)(iend - ip);
    DEBUGLOG(5, "ZSTD_ldm_fillHashTable");
    if (available < params->minMatchLength) {
        return;
    }
    {   U32 const hBits = params->hashLog - params->bucketSizeLog;
        U64 const firstHash = ZSTD_rollingHash_compute(ip, params->minMatchLength);
        ZSTD_ldm_fillLdmHashTable(state, firstHash, ip,
                                  iend - params->minMatchLength,
                                  state->window.base, hBits, *params);
    }
}
254 
255 
256 /** ZSTD_ldm_limitTableUpdate() :
257  *
258  *  Sets cctx->nextToUpdate to a position corresponding closer to anchor
259  *  if it is far way
260  *  (after a long match, only update tables a limited amount). */
ZSTD_ldm_limitTableUpdate(ZSTD_matchState_t * ms,const BYTE * anchor)261 static void ZSTD_ldm_limitTableUpdate(ZSTD_matchState_t* ms, const BYTE* anchor)
262 {
263     U32 const curr = (U32)(anchor - ms->window.base);
264     if (curr > ms->nextToUpdate + 1024) {
265         ms->nextToUpdate =
266             curr - MIN(512, curr - ms->nextToUpdate - 1024);
267     }
268 }
269 
/** ZSTD_ldm_generateSequences_internal() :
 *  Scans [src, src + srcSize) for long matches and appends one rawSeq per
 *  match to rawSeqStore.
 *
 *  Only positions whose rolling-hash tag is all ones are looked up in (and
 *  inserted into) the LDM hash table. Among the entries in the matching
 *  bucket, the candidate with the longest forward + backward match wins.
 *  Its forward part must be at least minMatchLength long.
 *
 *  @return the number of literal bytes after the last emitted sequence
 *          (srcSize when nothing was emitted), or ERROR(dstSize_tooSmall)
 *          when rawSeqStore runs out of capacity. */
static size_t ZSTD_ldm_generateSequences_internal(
        ldmState_t* ldmState, rawSeqStore_t* rawSeqStore,
        ldmParams_t const* params, void const* src, size_t srcSize)
{
    /* LDM parameters */
    int const extDict = ZSTD_window_hasExtDict(ldmState->window);
    U32 const minMatchLength = params->minMatchLength;
    U64 const hashPower = ldmState->hashPower;
    U32 const hBits = params->hashLog - params->bucketSizeLog;
    U32 const ldmBucketSize = 1U << params->bucketSizeLog;
    U32 const hashRateLog = params->hashRateLog;
    U32 const ldmTagMask = (1U << params->hashRateLog) - 1;
    /* Prefix and extDict parameters */
    U32 const dictLimit = ldmState->window.dictLimit;
    U32 const lowestIndex = extDict ? ldmState->window.lowLimit : dictLimit;
    BYTE const* const base = ldmState->window.base;
    BYTE const* const dictBase = extDict ? ldmState->window.dictBase : NULL;
    BYTE const* const dictStart = extDict ? dictBase + lowestIndex : NULL;
    BYTE const* const dictEnd = extDict ? dictBase + dictLimit : NULL;
    BYTE const* const lowPrefixPtr = base + dictLimit;
    /* Input bounds */
    size_t const minInputSize = MAX(minMatchLength, HASH_READ_SIZE);
    BYTE const* const istart = (BYTE const*)src;
    BYTE const* const iend = istart + srcSize;
    /* Last position from which a minInputSize-byte read stays in bounds.
     * For inputs shorter than minInputSize, it is clamped to istart rather
     * than computed as iend - minInputSize: that would point before the
     * buffer, which is undefined pointer arithmetic. Such inputs return
     * early below, before the loop ever reads ilimit. */
    BYTE const* const ilimit = (srcSize >= minInputSize) ? iend - minInputSize : istart;
    /* Input positions */
    BYTE const* anchor = istart;
    BYTE const* ip = istart;
    /* Rolling hash */
    BYTE const* lastHashed = NULL;
    U64 rollingHash = 0;

    /* Too short to hold a single match: the whole input is literals
     * (the same result as iend - anchor with no sequence emitted). */
    if (srcSize < minInputSize) {
        return srcSize;
    }

    while (ip <= ilimit) {
        size_t mLength;
        U32 const curr = (U32)(ip - base);
        size_t forwardMatchLength = 0, backwardMatchLength = 0;
        ldmEntry_t* bestEntry = NULL;
        if (ip != istart) {
            /* lastHashed == ip - 1 here, so a single-byte rotation suffices */
            rollingHash = ZSTD_rollingHash_rotate(rollingHash, lastHashed[0],
                                                  lastHashed[minMatchLength],
                                                  hashPower);
        } else {
            rollingHash = ZSTD_rollingHash_compute(ip, minMatchLength);
        }
        lastHashed = ip;

        /* Do not insert and do not look for a match */
        if (ZSTD_ldm_getTag(rollingHash, hBits, hashRateLog) != ldmTagMask) {
           ip++;
           continue;
        }

        /* Get the best entry and compute the match lengths */
        {
            ldmEntry_t* const bucket =
                ZSTD_ldm_getBucket(ldmState,
                                   ZSTD_ldm_getSmallHash(rollingHash, hBits),
                                   *params);
            ldmEntry_t* cur;
            size_t bestMatchLength = 0;
            U32 const checksum = ZSTD_ldm_getChecksum(rollingHash, hBits);

            for (cur = bucket; cur < bucket + ldmBucketSize; ++cur) {
                size_t curForwardMatchLength, curBackwardMatchLength,
                       curTotalMatchLength;
                /* Skip checksum mismatches and entries outside the valid window */
                if (cur->checksum != checksum || cur->offset <= lowestIndex) {
                    continue;
                }
                if (extDict) {
                    BYTE const* const curMatchBase =
                        cur->offset < dictLimit ? dictBase : base;
                    BYTE const* const pMatch = curMatchBase + cur->offset;
                    BYTE const* const matchEnd =
                        cur->offset < dictLimit ? dictEnd : iend;
                    BYTE const* const lowMatchPtr =
                        cur->offset < dictLimit ? dictStart : lowPrefixPtr;

                    curForwardMatchLength = ZSTD_count_2segments(
                                                ip, pMatch, iend,
                                                matchEnd, lowPrefixPtr);
                    if (curForwardMatchLength < minMatchLength) {
                        continue;
                    }
                    curBackwardMatchLength =
                        ZSTD_ldm_countBackwardsMatch_2segments(ip, anchor,
                                                               pMatch, lowMatchPtr,
                                                               dictStart, dictEnd);
                    curTotalMatchLength = curForwardMatchLength +
                                          curBackwardMatchLength;
                } else { /* !extDict */
                    BYTE const* const pMatch = base + cur->offset;
                    curForwardMatchLength = ZSTD_count(ip, pMatch, iend);
                    if (curForwardMatchLength < minMatchLength) {
                        continue;
                    }
                    curBackwardMatchLength =
                        ZSTD_ldm_countBackwardsMatch(ip, anchor, pMatch,
                                                     lowPrefixPtr);
                    curTotalMatchLength = curForwardMatchLength +
                                          curBackwardMatchLength;
                }

                if (curTotalMatchLength > bestMatchLength) {
                    bestMatchLength = curTotalMatchLength;
                    forwardMatchLength = curForwardMatchLength;
                    backwardMatchLength = curBackwardMatchLength;
                    bestEntry = cur;
                }
            }
        }

        /* No match found -- continue searching */
        if (bestEntry == NULL) {
            ZSTD_ldm_makeEntryAndInsertByTag(ldmState, rollingHash,
                                             hBits, curr,
                                             *params);
            ip++;
            continue;
        }

        /* Match found */
        mLength = forwardMatchLength + backwardMatchLength;
        ip -= backwardMatchLength;

        {
            /* Store the sequence:
             * ip = curr - backwardMatchLength
             * The match is at (bestEntry->offset - backwardMatchLength)
             */
            U32 const matchIndex = bestEntry->offset;
            U32 const offset = curr - matchIndex;
            rawSeq* const seq = rawSeqStore->seq + rawSeqStore->size;

            /* Out of sequence storage */
            if (rawSeqStore->size == rawSeqStore->capacity)
                return ERROR(dstSize_tooSmall);
            seq->litLength = (U32)(ip - anchor);
            seq->matchLength = (U32)mLength;
            seq->offset = offset;
            rawSeqStore->size++;
        }

        /* Insert the current entry into the hash table */
        ZSTD_ldm_makeEntryAndInsertByTag(ldmState, rollingHash, hBits,
                                         (U32)(lastHashed - base),
                                         *params);

        assert(ip + backwardMatchLength == lastHashed);

        /* Fill the hash table from lastHashed+1 to ip+mLength*/
        /* Heuristic: don't need to fill the entire table at end of block */
        if (ip + mLength <= ilimit) {
            rollingHash = ZSTD_ldm_fillLdmHashTable(
                              ldmState, rollingHash, lastHashed,
                              ip + mLength, base, hBits, *params);
            lastHashed = ip + mLength - 1;
        }
        ip += mLength;
        anchor = ip;
    }
    return (size_t)(iend - anchor);
}
431 
432 /*! ZSTD_ldm_reduceTable() :
433  *  reduce table indexes by `reducerValue` */
ZSTD_ldm_reduceTable(ldmEntry_t * const table,U32 const size,U32 const reducerValue)434 static void ZSTD_ldm_reduceTable(ldmEntry_t* const table, U32 const size,
435                                  U32 const reducerValue)
436 {
437     U32 u;
438     for (u = 0; u < size; u++) {
439         if (table[u].offset < reducerValue) table[u].offset = 0;
440         else table[u].offset -= reducerValue;
441     }
442 }
443 
/** ZSTD_ldm_generateSequences() :
 *  Generates LDM sequences for [src, src + srcSize) and appends them to
 *  `sequences`.
 *
 *  The input is processed in chunks of at most kMaxChunkSize (1 << 20)
 *  bytes. Before each chunk, overflow correction and the maximum match
 *  distance (1 << params->windowLog) are applied to ldmState->window.
 *
 *  Literals left at the end of one chunk are carried in leftoverSize and
 *  added to the litLength of the first sequence produced by a later chunk.
 *  NOTE: literals still pending after the last chunk are not recorded in
 *  `sequences`. Processing also stops early once `sequences` is full.
 *
 *  Requires ZSTD_window_update() to have been called for this input.
 *  @return 0 on success, or the error code returned by
 *          ZSTD_ldm_generateSequences_internal(). */
size_t ZSTD_ldm_generateSequences(
        ldmState_t* ldmState, rawSeqStore_t* sequences,
        ldmParams_t const* params, void const* src, size_t srcSize)
{
    U32 const maxDist = 1U << params->windowLog;
    BYTE const* const istart = (BYTE const*)src;
    BYTE const* const iend = istart + srcSize;
    size_t const kMaxChunkSize = 1 << 20;
    /* ceil(srcSize / kMaxChunkSize) */
    size_t const nbChunks = (srcSize / kMaxChunkSize) + ((srcSize % kMaxChunkSize) != 0);
    size_t chunk;
    /* Literals not yet attached to any generated sequence. */
    size_t leftoverSize = 0;

    assert(ZSTD_CHUNKSIZE_MAX >= kMaxChunkSize);
    /* Check that ZSTD_window_update() has been called for this chunk prior
     * to passing it to this function.
     */
    assert(ldmState->window.nextSrc >= (BYTE const*)src + srcSize);
    /* The input could be very large (in zstdmt), so it must be broken up into
     * chunks to enforce the maximum distance and handle overflow correction.
     */
    assert(sequences->pos <= sequences->size);
    assert(sequences->size <= sequences->capacity);
    for (chunk = 0; chunk < nbChunks && sequences->size < sequences->capacity; ++chunk) {
        BYTE const* const chunkStart = istart + chunk * kMaxChunkSize;
        size_t const remaining = (size_t)(iend - chunkStart);
        BYTE const *const chunkEnd =
            (remaining < kMaxChunkSize) ? iend : chunkStart + kMaxChunkSize;
        size_t const chunkSize = chunkEnd - chunkStart;
        size_t newLeftoverSize;
        /* Number of sequences before this chunk, to find the first new one. */
        size_t const prevSize = sequences->size;

        assert(chunkStart < iend);
        /* 1. Perform overflow correction if necessary. */
        if (ZSTD_window_needOverflowCorrection(ldmState->window, chunkEnd)) {
            U32 const ldmHSize = 1U << params->hashLog;
            U32 const correction = ZSTD_window_correctOverflow(
                &ldmState->window, /* cycleLog */ 0, maxDist, chunkStart);
            /* Rebase stored offsets by the same amount as the window. */
            ZSTD_ldm_reduceTable(ldmState->hashTable, ldmHSize, correction);
            /* invalidate dictionaries on overflow correction */
            ldmState->loadedDictEnd = 0;
        }
        /* 2. We enforce the maximum offset allowed.
         *
         * kMaxChunkSize should be small enough that we don't lose too much of
         * the window through early invalidation.
         * TODO: * Test the chunk size.
         *       * Try invalidation after the sequence generation and test
         *         the offset against maxDist directly.
         *
         * NOTE: Because of dictionaries + sequence splitting we MUST make sure
         * that any offset used is valid at the END of the sequence, since it may
         * be split into two sequences. This condition holds when using
         * ZSTD_window_enforceMaxDist(), but if we move to checking offsets
         * against maxDist directly, we'll have to carefully handle that case.
         */
        ZSTD_window_enforceMaxDist(&ldmState->window, chunkEnd, maxDist, &ldmState->loadedDictEnd, NULL);
        /* 3. Generate the sequences for the chunk, and get newLeftoverSize. */
        newLeftoverSize = ZSTD_ldm_generateSequences_internal(
            ldmState, sequences, params, chunkStart, chunkSize);
        if (ZSTD_isError(newLeftoverSize))
            return newLeftoverSize;
        /* 4. We add the leftover literals from previous iterations to the first
         *    newly generated sequence, or add the `newLeftoverSize` if none are
         *    generated.
         */
        /* Prepend the leftover literals from the last call */
        if (prevSize < sequences->size) {
            sequences->seq[prevSize].litLength += (U32)leftoverSize;
            leftoverSize = newLeftoverSize;
        } else {
            /* No sequence in this chunk: the whole chunk is literals. */
            assert(newLeftoverSize == chunkSize);
            leftoverSize += chunkSize;
        }
    }
    return 0;
}
520 
/** ZSTD_ldm_skipSequences() :
 *  Advances rawSeqStore past srcSize bytes of input. Literals and then
 *  match bytes are consumed from the sequence at rawSeqStore->pos onward.
 *  If the skip ends inside a match and fewer than minMatch match bytes
 *  remain, that remainder is moved into the next sequence's literals (when
 *  a next sequence exists), and the shortened sequence is dropped. */
void ZSTD_ldm_skipSequences(rawSeqStore_t* rawSeqStore, size_t srcSize, U32 const minMatch) {
    for (;;) {
        rawSeq* cur;
        if (srcSize == 0 || rawSeqStore->pos >= rawSeqStore->size) {
            return;
        }
        cur = &rawSeqStore->seq[rawSeqStore->pos];

        /* Skip ends within the literals. */
        if (cur->litLength >= srcSize) {
            cur->litLength -= (U32)srcSize;
            return;
        }
        srcSize -= cur->litLength;
        cur->litLength = 0;

        /* Skip ends strictly inside the match. */
        if (cur->matchLength > srcSize) {
            cur->matchLength -= (U32)srcSize;
            if (cur->matchLength < minMatch) {
                /* Too short to keep: hand the bytes to the next sequence as literals. */
                if (rawSeqStore->pos + 1 < rawSeqStore->size) {
                    cur[1].litLength += cur->matchLength;
                }
                rawSeqStore->pos++;
            }
            return;
        }

        /* Whole sequence skipped. */
        srcSize -= cur->matchLength;
        cur->matchLength = 0;
        rawSeqStore->pos++;
    }
}
548 
549 /**
550  * If the sequence length is longer than remaining then the sequence is split
551  * between this block and the next.
552  *
553  * Returns the current sequence to handle, or if the rest of the block should
554  * be literals, it returns a sequence with offset == 0.
555  */
maybeSplitSequence(rawSeqStore_t * rawSeqStore,U32 const remaining,U32 const minMatch)556 static rawSeq maybeSplitSequence(rawSeqStore_t* rawSeqStore,
557                                  U32 const remaining, U32 const minMatch)
558 {
559     rawSeq sequence = rawSeqStore->seq[rawSeqStore->pos];
560     assert(sequence.offset > 0);
561     /* Likely: No partial sequence */
562     if (remaining >= sequence.litLength + sequence.matchLength) {
563         rawSeqStore->pos++;
564         return sequence;
565     }
566     /* Cut the sequence short (offset == 0 ==> rest is literals). */
567     if (remaining <= sequence.litLength) {
568         sequence.offset = 0;
569     } else if (remaining < sequence.litLength + sequence.matchLength) {
570         sequence.matchLength = remaining - sequence.litLength;
571         if (sequence.matchLength < minMatch) {
572             sequence.offset = 0;
573         }
574     }
575     /* Skip past `remaining` bytes for the future sequences. */
576     ZSTD_ldm_skipSequences(rawSeqStore, remaining, minMatch);
577     return sequence;
578 }
579 
/** ZSTD_ldm_skipRawSeqStoreBytes() :
 *  Moves the (pos, posInSequence) cursor of rawSeqStore forward by nbBytes.
 *  Each sequence spans litLength + matchLength bytes. posInSequence is
 *  reset to 0 when the cursor lands exactly on a sequence boundary, or when
 *  all sequences have been consumed. */
void ZSTD_ldm_skipRawSeqStoreBytes(rawSeqStore_t* rawSeqStore, size_t nbBytes) {
    U32 leftToSkip = (U32)(rawSeqStore->posInSequence + nbBytes);

    for (; leftToSkip != 0 && rawSeqStore->pos < rawSeqStore->size; rawSeqStore->pos++) {
        rawSeq const* const seq = &rawSeqStore->seq[rawSeqStore->pos];
        U32 const seqLength = seq->litLength + seq->matchLength;
        if (leftToSkip < seqLength) {
            /* Cursor stops inside this sequence; `break` skips the pos++. */
            rawSeqStore->posInSequence = leftToSkip;
            break;
        }
        leftToSkip -= seqLength;
    }
    if (leftToSkip == 0 || rawSeqStore->pos == rawSeqStore->size) {
        rawSeqStore->posInSequence = 0;
    }
}
596 
/** ZSTD_ldm_blockCompress() :
 *  Compresses one block, combining the LDM sequences in rawSeqStore with
 *  the regular block compressor for ms->cParams.strategy.
 *  - Strategies >= ZSTD_btopt receive the LDM sequences as candidates only
 *    (via ms->ldmSeqStore); rawSeqStore is then advanced past srcSize bytes.
 *  - Other strategies emit every LDM match that fits in the block as-is.
 *    The bytes in front of each match go through the block compressor, and
 *    the match offset is pushed as the newest repcode.
 *  @return the block compressor's result for the final segment (named
 *          lastLLSize in the btopt path; presumably the last literals size). */
size_t ZSTD_ldm_blockCompress(rawSeqStore_t* rawSeqStore,
    ZSTD_matchState_t* ms, seqStore_t* seqStore, U32 rep[ZSTD_REP_NUM],
    void const* src, size_t srcSize)
{
    const ZSTD_compressionParameters* const cParams = &ms->cParams;
    unsigned const minMatch = cParams->minMatch;
    ZSTD_blockCompressor const blockCompressor =
        ZSTD_selectBlockCompressor(cParams->strategy, ZSTD_matchState_dictMode(ms));
    BYTE const* const istart = (BYTE const*)src;
    BYTE const* const iend = istart + srcSize;
    BYTE const* ip = istart;   /* first byte not yet handed to a compressor */

    DEBUGLOG(5, "ZSTD_ldm_blockCompress: srcSize=%zu", srcSize);
    if (cParams->strategy >= ZSTD_btopt) {
        /* Optimal parsers only consider LDM matches as candidates. */
        size_t lastLLSize;
        ms->ldmSeqStore = rawSeqStore;
        lastLLSize = blockCompressor(ms, seqStore, rep, src, srcSize);
        ZSTD_ldm_skipRawSeqStoreBytes(rawSeqStore, srcSize);
        return lastLLSize;
    }

    assert(rawSeqStore->pos <= rawSeqStore->size);
    assert(rawSeqStore->size <= rawSeqStore->capacity);
    for (;;) {
        rawSeq seq;
        size_t trailingLits;
        int r;

        if (rawSeqStore->pos >= rawSeqStore->size || ip >= iend) {
            break;
        }
        /* Consumes the sequence (or the part of it that fits) from rawSeqStore. */
        seq = maybeSplitSequence(rawSeqStore, (U32)(iend - ip), minMatch);
        if (seq.offset == 0) {
            break;   /* end signal: the rest of the block is literals */
        }
        assert(ip + seq.litLength + seq.matchLength <= iend);

        /* Catch the block compressor's tables up to ip, then let it
         * compress the bytes in front of the LDM match. */
        ZSTD_ldm_limitTableUpdate(ms, ip);
        ZSTD_ldm_fillFastTables(ms, ip);
        DEBUGLOG(5, "pos %u : calling block compressor on segment of size %u", (unsigned)(ip-istart), seq.litLength);
        trailingLits = blockCompressor(ms, seqStore, rep, ip, seq.litLength);
        ip += seq.litLength;

        /* The LDM offset becomes the most recent repcode. */
        for (r = ZSTD_REP_NUM - 1; r > 0; --r) {
            rep[r] = rep[r - 1];
        }
        rep[0] = seq.offset;

        /* Emit the match, preceded by the trailingLits bytes just before ip. */
        ZSTD_storeSeq(seqStore, trailingLits, ip - trailingLits, iend,
                      seq.offset + ZSTD_REP_MOVE,
                      seq.matchLength - MINMATCH);
        ip += seq.matchLength;
    }

    /* Catch the tables up once more and compress the remainder of the block. */
    ZSTD_ldm_limitTableUpdate(ms, ip);
    ZSTD_ldm_fillFastTables(ms, ip);
    return blockCompressor(ms, seqStore, rep, ip, (size_t)(iend - ip));
}
661