1 // Copyright 2019 The libgav1 Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "src/utils/entropy_decoder.h"
16 
17 #include <cassert>
18 #include <cstring>
19 
20 #include "src/utils/common.h"
21 #include "src/utils/compiler_attributes.h"
22 #include "src/utils/constants.h"
23 #include "src/utils/cpu.h"
24 
// Detect the ARM targets on which the NEON versions of the UpdateCdfN
// functions below are compiled: a compiler that defines __ARM_NEON__, any
// AArch64 target (__aarch64__), or MSVC targeting ARM (_M_ARM).
#if defined(__ARM_NEON__) || defined(__aarch64__) || \
    (defined(_MSC_VER) && defined(_M_ARM))
#define LIBGAV1_ENTROPY_DECODER_ENABLE_NEON 1
#else
#define LIBGAV1_ENTROPY_DECODER_ENABLE_NEON 0
#endif

// The NEON intrinsics header is only pulled in when the NEON path is enabled.
#if LIBGAV1_ENTROPY_DECODER_ENABLE_NEON
#include <arm_neon.h>
#endif
35 
// Enable the SSE2 versions of the UpdateCdfN functions when the compiler
// targets SSE2 (__SSE2__) or when LIBGAV1_X86_MSVC is defined. The SSE2 code
// is only compiled when the NEON path is disabled.
// NOTE(review): LIBGAV1_X86_MSVC is not defined in this file; presumably it
// is provided by src/utils/cpu.h for MSVC x86 builds -- confirm.
#if defined(__SSE2__) || defined(LIBGAV1_X86_MSVC)
#define LIBGAV1_ENTROPY_DECODER_ENABLE_SSE2 1
#else
#define LIBGAV1_ENTROPY_DECODER_ENABLE_SSE2 0
#endif

// The SSE2 intrinsics header is only pulled in when the SSE2 path is enabled.
#if LIBGAV1_ENTROPY_DECODER_ENABLE_SSE2
#include <emmintrin.h>
#endif
45 
46 namespace libgav1 {
47 namespace {
48 
// Mask that clears the low 8 bits of a 32-bit value (0xFFFFFF00). ReadBit()
// applies it to the current range before halving it.
constexpr uint32_t kReadBitMask = ~uint32_t{0xFF};
// Number of low-order bits dropped from a CDF entry before it is multiplied
// by the (shifted) range.
constexpr int kCdfPrecision = 6;
// Amount added to a scaled CDF value for every symbol that follows the
// current index.
constexpr int kMinimumProbabilityPerSymbol = 4;
52 
53 // This function computes the "cur" variable as specified inside the do-while
54 // loop in Section 8.2.6 of the spec. This function is monotonically
55 // decreasing as the values of index increases (note that the |cdf| array is
56 // sorted in decreasing order).
ScaleCdf(uint32_t values_in_range_shifted,const uint16_t * const cdf,int index,int symbol_count)57 uint32_t ScaleCdf(uint32_t values_in_range_shifted, const uint16_t* const cdf,
58                   int index, int symbol_count) {
59   return ((values_in_range_shifted * (cdf[index] >> kCdfPrecision)) >> 1) +
60          (kMinimumProbabilityPerSymbol * (symbol_count - index));
61 }
62 
UpdateCdf(uint16_t * const cdf,const int symbol_count,const int symbol)63 void UpdateCdf(uint16_t* const cdf, const int symbol_count, const int symbol) {
64   const uint16_t count = cdf[symbol_count];
65   // rate is computed in the spec as:
66   //  3 + ( cdf[N] > 15 ) + ( cdf[N] > 31 ) + Min(FloorLog2(N), 2)
67   // In this case cdf[N] is |count|.
68   // Min(FloorLog2(N), 2) is 1 for symbol_count == {2, 3} and 2 for all
69   // symbol_count > 3. So the equation becomes:
70   //  4 + (count > 15) + (count > 31) + (symbol_count > 3).
71   // Note that the largest value for count is 32 (it is not incremented beyond
72   // 32). So using that information:
73   //  count >> 4 is 0 for count from 0 to 15.
74   //  count >> 4 is 1 for count from 16 to 31.
75   //  count >> 4 is 2 for count == 31.
76   // Now, the equation becomes:
77   //  4 + (count >> 4) + (symbol_count > 3).
78   // Since (count >> 4) can only be 0 or 1 or 2, the addition could be replaced
79   // with bitwise or:
80   //  (4 | (count >> 4)) + (symbol_count > 3).
81   // but using addition will allow the compiler to eliminate an operation when
82   // symbol_count is known and this function is inlined.
83   const int rate = (count >> 4) + 4 + static_cast<int>(symbol_count > 3);
84   // Hints for further optimizations:
85   //
86   // 1. clang can vectorize this for loop with width 4, even though the loop
87   // contains an if-else statement. Therefore, it may be advantageous to use
88   // "i < symbol_count" as the loop condition when symbol_count is 8, 12, or 16
89   // (a multiple of 4 that's not too small).
90   //
91   // 2. The for loop can be rewritten in the following form, which would enable
92   // clang to vectorize the loop with width 8:
93   //
94   //   const int rounding = (1 << rate) - 1;
95   //   for (int i = 0; i < symbol_count - 1; ++i) {
96   //     const uint16_t a = (i < symbol) ? kCdfMaxProbability : rounding;
97   //     cdf[i] += static_cast<int16_t>(a - cdf[i]) >> rate;
98   //   }
99   //
100   // The subtraction (a - cdf[i]) relies on the overflow semantics of unsigned
101   // integer arithmetic. The result of the unsigned subtraction is cast to a
102   // signed integer and right-shifted. This requires the right shift of a
103   // signed integer be an arithmetic shift, which is true for clang, gcc, and
104   // Visual C++.
105   assert(symbol_count - 1 > 0);
106   int i = 0;
107   do {
108     if (i < symbol) {
109       cdf[i] += (kCdfMaxProbability - cdf[i]) >> rate;
110     } else {
111       cdf[i] -= cdf[i] >> rate;
112     }
113   } while (++i < symbol_count - 1);
114   cdf[symbol_count] += static_cast<uint16_t>(count < 32);
115 }
116 
117 // Define the UpdateCdfN functions. UpdateCdfN is a specialized implementation
118 // of UpdateCdf based on the fact that symbol_count == N. UpdateCdfN uses the
119 // SIMD instruction sets if available.
120 
121 #if LIBGAV1_ENTROPY_DECODER_ENABLE_NEON
122 
123 // The UpdateCdf() method contains the following for loop:
124 //
125 //   for (int i = 0; i < symbol_count - 1; ++i) {
126 //     if (i < symbol) {
127 //       cdf[i] += (kCdfMaxProbability - cdf[i]) >> rate;
128 //     } else {
129 //       cdf[i] -= cdf[i] >> rate;
130 //     }
131 //   }
132 //
133 // It can be rewritten in the following two forms, which are amenable to SIMD
134 // implementations:
135 //
136 //   const int rounding = (1 << rate) - 1;
137 //   for (int i = 0; i < symbol_count - 1; ++i) {
138 //     const uint16_t a = (i < symbol) ? kCdfMaxProbability : rounding;
139 //     cdf[i] += static_cast<int16_t>(a - cdf[i]) >> rate;
140 //   }
141 //
142 // or:
143 //
144 //   const int rounding = (1 << rate) - 1;
145 //   for (int i = 0; i < symbol_count - 1; ++i) {
146 //     const uint16_t a = (i < symbol) ? (kCdfMaxProbability - rounding) : 0;
147 //     cdf[i] -= static_cast<int16_t>(cdf[i] - a) >> rate;
148 //   }
149 //
150 // The following ARM NEON implementations use a modified version of the first
151 // form, using the comparison mask and unsigned rollover to avoid the need to
152 // calculate rounding.
153 //
154 // The cdf array has symbol_count + 1 elements. The first symbol_count elements
155 // are the CDF. The last element is a count that is initialized to 0 and may
156 // grow up to 32. The for loop in UpdateCdf updates the CDF in the array. Since
157 // cdf[symbol_count - 1] is always 0, the for loop does not update
158 // cdf[symbol_count - 1]. However, it would be correct to have the for loop
159 // update cdf[symbol_count - 1] anyway: since symbol_count - 1 >= symbol, the
160 // for loop would take the else branch when i is symbol_count - 1:
161 //      cdf[i] -= cdf[i] >> rate;
162 // Since cdf[symbol_count - 1] is 0, cdf[symbol_count - 1] would still be 0
163 // after the update. The ARM NEON implementations take advantage of this in the
164 // following two cases:
165 // 1. When symbol_count is 8 or 16, the vectorized code updates the first
166 //    symbol_count elements in the array.
167 // 2. When symbol_count is 7, the vectorized code updates all the 8 elements in
168 //    the cdf array. Since an invalid CDF value is written into cdf[7], the
169 //    count in cdf[7] needs to be fixed up after the vectorized code.
170 
UpdateCdf5(uint16_t * const cdf,const int symbol)171 void UpdateCdf5(uint16_t* const cdf, const int symbol) {
172   uint16x4_t cdf_vec = vld1_u16(cdf);
173   const uint16_t count = cdf[5];
174   const int rate = (count >> 4) + 5;
175   const uint16x4_t cdf_max_probability = vdup_n_u16(kCdfMaxProbability);
176   const uint16x4_t index = vcreate_u16(0x0003000200010000);
177   const uint16x4_t symbol_vec = vdup_n_u16(symbol);
178   const uint16x4_t mask = vcge_u16(index, symbol_vec);
179   // i < symbol: 32768, i >= symbol: 65535.
180   const uint16x4_t a = vorr_u16(mask, cdf_max_probability);
181   // i < symbol: 32768 - cdf, i >= symbol: 65535 - cdf.
182   const int16x4_t diff = vreinterpret_s16_u16(vsub_u16(a, cdf_vec));
183   // i < symbol: cdf - 0, i >= symbol: cdf - 65535.
184   const uint16x4_t cdf_offset = vsub_u16(cdf_vec, mask);
185   const int16x4_t negative_rate = vdup_n_s16(-rate);
186   // i < symbol: (32768 - cdf) >> rate, i >= symbol: (65535 (-1) - cdf) >> rate.
187   const uint16x4_t delta = vreinterpret_u16_s16(vshl_s16(diff, negative_rate));
188   // i < symbol: (cdf - 0) + ((32768 - cdf) >> rate).
189   // i >= symbol: (cdf - 65535) + ((65535 - cdf) >> rate).
190   cdf_vec = vadd_u16(cdf_offset, delta);
191   vst1_u16(cdf, cdf_vec);
192   cdf[5] = count + static_cast<uint16_t>(count < 32);
193 }
194 
// This version works for |symbol_count| = 7, 8, or 9.
// See UpdateCdf5 for implementation details.
//
// One 8-lane vector covers cdf[0..7]:
//   * symbol_count == 7: lane 7 is the counter slot cdf[7]. The vector store
//     overwrites it, so the counter is read before the store and rewritten
//     after it. Do not reorder those statements.
//   * symbol_count == 8: lane 7 is cdf[7], the last CDF entry.
//   * symbol_count == 9: cdf[8] is outside the vector and is not modified.
template <int symbol_count>
void UpdateCdf7To9(uint16_t* const cdf, const int symbol) {
  static_assert(symbol_count >= 7 && symbol_count <= 9, "");
  uint16x8_t cdf_vec = vld1q_u16(cdf);
  // Read the counter before cdf[0..7] is stored back (see above).
  const uint16_t count = cdf[symbol_count];
  const int rate = (count >> 4) + 5;
  const uint16x8_t cdf_max_probability = vdupq_n_u16(kCdfMaxProbability);
  const uint16x8_t index = vcombine_u16(vcreate_u16(0x0003000200010000),
                                        vcreate_u16(0x0007000600050004));
  const uint16x8_t symbol_vec = vdupq_n_u16(symbol);
  // All ones in the lanes where i >= symbol, zero elsewhere.
  const uint16x8_t mask = vcgeq_u16(index, symbol_vec);
  // i < symbol: 32768, i >= symbol: 65535.
  const uint16x8_t a = vorrq_u16(mask, cdf_max_probability);
  // i < symbol: 32768 - cdf, i >= symbol: 65535 - cdf.
  const int16x8_t diff = vreinterpretq_s16_u16(vsubq_u16(a, cdf_vec));
  // i < symbol: cdf - 0, i >= symbol: cdf - 65535.
  const uint16x8_t cdf_offset = vsubq_u16(cdf_vec, mask);
  // A negative shift count makes vshlq_s16 shift right.
  const int16x8_t negative_rate = vdupq_n_s16(-rate);
  const uint16x8_t delta =
      vreinterpretq_u16_s16(vshlq_s16(diff, negative_rate));
  cdf_vec = vaddq_u16(cdf_offset, delta);
  vst1q_u16(cdf, cdf_vec);
  // Written last so that it also repairs cdf[7] when symbol_count == 7.
  cdf[symbol_count] = count + static_cast<uint16_t>(count < 32);
}
218 
UpdateCdf7(uint16_t * const cdf,const int symbol)219 void UpdateCdf7(uint16_t* const cdf, const int symbol) {
220   UpdateCdf7To9<7>(cdf, symbol);
221 }
222 
UpdateCdf8(uint16_t * const cdf,const int symbol)223 void UpdateCdf8(uint16_t* const cdf, const int symbol) {
224   UpdateCdf7To9<8>(cdf, symbol);
225 }
226 
UpdateCdf9(uint16_t * const cdf,const int symbol)227 void UpdateCdf9(uint16_t* const cdf, const int symbol) {
228   UpdateCdf7To9<9>(cdf, symbol);
229 }
230 
231 // See UpdateCdf5 for implementation details.
UpdateCdf11(uint16_t * const cdf,const int symbol)232 void UpdateCdf11(uint16_t* const cdf, const int symbol) {
233   uint16x8_t cdf_vec = vld1q_u16(cdf + 2);
234   const uint16_t count = cdf[11];
235   cdf[11] = count + static_cast<uint16_t>(count < 32);
236   const int rate = (count >> 4) + 5;
237   if (symbol > 1) {
238     cdf[0] += (kCdfMaxProbability - cdf[0]) >> rate;
239     cdf[1] += (kCdfMaxProbability - cdf[1]) >> rate;
240     const uint16x8_t cdf_max_probability = vdupq_n_u16(kCdfMaxProbability);
241     const uint16x8_t symbol_vec = vdupq_n_u16(symbol);
242     const int16x8_t negative_rate = vdupq_n_s16(-rate);
243     const uint16x8_t index = vcombine_u16(vcreate_u16(0x0005000400030002),
244                                           vcreate_u16(0x0009000800070006));
245     const uint16x8_t mask = vcgeq_u16(index, symbol_vec);
246     const uint16x8_t a = vorrq_u16(mask, cdf_max_probability);
247     const int16x8_t diff = vreinterpretq_s16_u16(vsubq_u16(a, cdf_vec));
248     const uint16x8_t cdf_offset = vsubq_u16(cdf_vec, mask);
249     const uint16x8_t delta =
250         vreinterpretq_u16_s16(vshlq_s16(diff, negative_rate));
251     cdf_vec = vaddq_u16(cdf_offset, delta);
252     vst1q_u16(cdf + 2, cdf_vec);
253   } else {
254     if (symbol != 0) {
255       cdf[0] += (kCdfMaxProbability - cdf[0]) >> rate;
256       cdf[1] -= cdf[1] >> rate;
257     } else {
258       cdf[0] -= cdf[0] >> rate;
259       cdf[1] -= cdf[1] >> rate;
260     }
261     const int16x8_t negative_rate = vdupq_n_s16(-rate);
262     const uint16x8_t delta = vshlq_u16(cdf_vec, negative_rate);
263     cdf_vec = vsubq_u16(cdf_vec, delta);
264     vst1q_u16(cdf + 2, cdf_vec);
265   }
266 }
267 
// See UpdateCdf5 for implementation details.
//
// NEON UpdateCdf for symbol_count == 13. Two overlapping 8-lane vectors are
// used: |cdf_vec0| covers cdf[0..7] and |cdf_vec1| covers cdf[4..11]. Both
// vectors are loaded before either one is stored, so the shared elements
// cdf[4..7] are computed from the original values in both vectors and the
// second store rewrites them with the same results. Loading |cdf_vec1| after
// the first store would adapt cdf[4..7] twice. cdf[12] is not modified and
// cdf[13] holds the adaptation counter.
void UpdateCdf13(uint16_t* const cdf, const int symbol) {
  uint16x8_t cdf_vec0 = vld1q_u16(cdf);
  uint16x8_t cdf_vec1 = vld1q_u16(cdf + 4);
  const uint16_t count = cdf[13];
  const int rate = (count >> 4) + 5;
  const uint16x8_t cdf_max_probability = vdupq_n_u16(kCdfMaxProbability);
  const uint16x8_t symbol_vec = vdupq_n_u16(symbol);
  // A negative shift count makes vshlq_s16 shift right.
  const int16x8_t negative_rate = vdupq_n_s16(-rate);

  // First vector: CDF indices 0..7.
  uint16x8_t index = vcombine_u16(vcreate_u16(0x0003000200010000),
                                  vcreate_u16(0x0007000600050004));
  // All ones in the lanes where i >= symbol, zero elsewhere.
  uint16x8_t mask = vcgeq_u16(index, symbol_vec);
  uint16x8_t a = vorrq_u16(mask, cdf_max_probability);
  int16x8_t diff = vreinterpretq_s16_u16(vsubq_u16(a, cdf_vec0));
  uint16x8_t cdf_offset = vsubq_u16(cdf_vec0, mask);
  uint16x8_t delta = vreinterpretq_u16_s16(vshlq_s16(diff, negative_rate));
  cdf_vec0 = vaddq_u16(cdf_offset, delta);
  vst1q_u16(cdf, cdf_vec0);

  // Second vector: CDF indices 4..11, using the values loaded before the
  // store above.
  index = vcombine_u16(vcreate_u16(0x0007000600050004),
                       vcreate_u16(0x000b000a00090008));
  mask = vcgeq_u16(index, symbol_vec);
  a = vorrq_u16(mask, cdf_max_probability);
  diff = vreinterpretq_s16_u16(vsubq_u16(a, cdf_vec1));
  cdf_offset = vsubq_u16(cdf_vec1, mask);
  delta = vreinterpretq_u16_s16(vshlq_s16(diff, negative_rate));
  cdf_vec1 = vaddq_u16(cdf_offset, delta);
  vst1q_u16(cdf + 4, cdf_vec1);

  // Increment the counter while it is below 32.
  cdf[13] = count + static_cast<uint16_t>(count < 32);
}
300 
301 // See UpdateCdf5 for implementation details.
UpdateCdf16(uint16_t * const cdf,const int symbol)302 void UpdateCdf16(uint16_t* const cdf, const int symbol) {
303   uint16x8_t cdf_vec = vld1q_u16(cdf);
304   const uint16_t count = cdf[16];
305   const int rate = (count >> 4) + 5;
306   const uint16x8_t cdf_max_probability = vdupq_n_u16(kCdfMaxProbability);
307   const uint16x8_t symbol_vec = vdupq_n_u16(symbol);
308   const int16x8_t negative_rate = vdupq_n_s16(-rate);
309 
310   uint16x8_t index = vcombine_u16(vcreate_u16(0x0003000200010000),
311                                   vcreate_u16(0x0007000600050004));
312   uint16x8_t mask = vcgeq_u16(index, symbol_vec);
313   uint16x8_t a = vorrq_u16(mask, cdf_max_probability);
314   int16x8_t diff = vreinterpretq_s16_u16(vsubq_u16(a, cdf_vec));
315   uint16x8_t cdf_offset = vsubq_u16(cdf_vec, mask);
316   uint16x8_t delta = vreinterpretq_u16_s16(vshlq_s16(diff, negative_rate));
317   cdf_vec = vaddq_u16(cdf_offset, delta);
318   vst1q_u16(cdf, cdf_vec);
319 
320   cdf_vec = vld1q_u16(cdf + 8);
321   index = vcombine_u16(vcreate_u16(0x000b000a00090008),
322                        vcreate_u16(0x000f000e000d000c));
323   mask = vcgeq_u16(index, symbol_vec);
324   a = vorrq_u16(mask, cdf_max_probability);
325   diff = vreinterpretq_s16_u16(vsubq_u16(a, cdf_vec));
326   cdf_offset = vsubq_u16(cdf_vec, mask);
327   delta = vreinterpretq_u16_s16(vshlq_s16(diff, negative_rate));
328   cdf_vec = vaddq_u16(cdf_offset, delta);
329   vst1q_u16(cdf + 8, cdf_vec);
330 
331   cdf[16] = count + static_cast<uint16_t>(count < 32);
332 }
333 
334 #else  // !LIBGAV1_ENTROPY_DECODER_ENABLE_NEON
335 
336 #if LIBGAV1_ENTROPY_DECODER_ENABLE_SSE2
337 
// Loads 8 bytes from |a| into the low half of a vector; the upper half is
// zero. |a| does not need to be aligned.
inline __m128i LoadLo8(const void* a) {
  const auto* const source = static_cast<const __m128i*>(a);
  return _mm_loadl_epi64(source);
}
341 
// Loads 16 bytes from |a|, which does not need to be aligned.
inline __m128i LoadUnaligned16(const void* a) {
  const auto* const source = static_cast<const __m128i*>(a);
  return _mm_loadu_si128(source);
}
345 
// Stores the low 8 bytes of |v| to |a|, which does not need to be aligned.
inline void StoreLo8(void* a, const __m128i v) {
  auto* const destination = static_cast<__m128i*>(a);
  _mm_storel_epi64(destination, v);
}
349 
// Stores all 16 bytes of |v| to |a|, which does not need to be aligned.
inline void StoreUnaligned16(void* a, const __m128i v) {
  auto* const destination = static_cast<__m128i*>(a);
  _mm_storeu_si128(destination, v);
}
353 
UpdateCdf5(uint16_t * const cdf,const int symbol)354 void UpdateCdf5(uint16_t* const cdf, const int symbol) {
355   __m128i cdf_vec = LoadLo8(cdf);
356   const uint16_t count = cdf[5];
357   const int rate = (count >> 4) + 5;
358   const __m128i cdf_max_probability =
359       _mm_shufflelo_epi16(_mm_cvtsi32_si128(kCdfMaxProbability), 0);
360   const __m128i index = _mm_set_epi32(0x0, 0x0, 0x00040003, 0x00020001);
361   const __m128i symbol_vec = _mm_shufflelo_epi16(_mm_cvtsi32_si128(symbol), 0);
362   // i >= symbol.
363   const __m128i mask = _mm_cmpgt_epi16(index, symbol_vec);
364   // i < symbol: 32768, i >= symbol: 65535.
365   const __m128i a = _mm_or_si128(mask, cdf_max_probability);
366   // i < symbol: 32768 - cdf, i >= symbol: 65535 - cdf.
367   const __m128i diff = _mm_sub_epi16(a, cdf_vec);
368   // i < symbol: cdf - 0, i >= symbol: cdf - 65535.
369   const __m128i cdf_offset = _mm_sub_epi16(cdf_vec, mask);
370   // i < symbol: (32768 - cdf) >> rate, i >= symbol: (65535 (-1) - cdf) >> rate.
371   const __m128i delta = _mm_sra_epi16(diff, _mm_cvtsi32_si128(rate));
372   // i < symbol: (cdf - 0) + ((32768 - cdf) >> rate).
373   // i >= symbol: (cdf - 65535) + ((65535 - cdf) >> rate).
374   cdf_vec = _mm_add_epi16(cdf_offset, delta);
375   StoreLo8(cdf, cdf_vec);
376   cdf[5] = count + static_cast<uint16_t>(count < 32);
377 }
378 
// This version works for |symbol_count| = 7, 8, or 9.
// See UpdateCdf5 for implementation details.
//
// One 8-lane vector covers cdf[0..7]. When symbol_count == 7, lane 7 is the
// counter slot cdf[7]: the vector store overwrites it, so the counter is read
// before the store and rewritten after it. Do not reorder those statements.
// When symbol_count == 9, cdf[8] is outside the vector and is not modified.
template <int symbol_count>
void UpdateCdf7To9(uint16_t* const cdf, const int symbol) {
  static_assert(symbol_count >= 7 && symbol_count <= 9, "");
  __m128i cdf_vec = LoadUnaligned16(cdf);
  // Read the counter before cdf[0..7] is stored back (see above).
  const uint16_t count = cdf[symbol_count];
  const int rate = (count >> 4) + 5;
  const __m128i cdf_max_probability =
      _mm_set1_epi16(static_cast<int16_t>(kCdfMaxProbability));
  // The lanes hold i + 1 so that the signed "greater than" comparison below
  // selects the lanes where i >= symbol.
  const __m128i index =
      _mm_set_epi32(0x00080007, 0x00060005, 0x00040003, 0x00020001);
  const __m128i symbol_vec = _mm_set1_epi16(static_cast<int16_t>(symbol));
  // All ones in the lanes where i >= symbol, zero elsewhere.
  const __m128i mask = _mm_cmpgt_epi16(index, symbol_vec);
  // i < symbol: 32768, i >= symbol: 65535.
  const __m128i a = _mm_or_si128(mask, cdf_max_probability);
  // i < symbol: 32768 - cdf, i >= symbol: 65535 - cdf.
  const __m128i diff = _mm_sub_epi16(a, cdf_vec);
  // i < symbol: cdf - 0, i >= symbol: cdf - 65535.
  const __m128i cdf_offset = _mm_sub_epi16(cdf_vec, mask);
  // Arithmetic right shift of the signed differences.
  const __m128i delta = _mm_sra_epi16(diff, _mm_cvtsi32_si128(rate));
  cdf_vec = _mm_add_epi16(cdf_offset, delta);
  StoreUnaligned16(cdf, cdf_vec);
  // Written last so that it also repairs cdf[7] when symbol_count == 7.
  cdf[symbol_count] = count + static_cast<uint16_t>(count < 32);
}
401 
UpdateCdf7(uint16_t * const cdf,const int symbol)402 void UpdateCdf7(uint16_t* const cdf, const int symbol) {
403   UpdateCdf7To9<7>(cdf, symbol);
404 }
405 
UpdateCdf8(uint16_t * const cdf,const int symbol)406 void UpdateCdf8(uint16_t* const cdf, const int symbol) {
407   UpdateCdf7To9<8>(cdf, symbol);
408 }
409 
UpdateCdf9(uint16_t * const cdf,const int symbol)410 void UpdateCdf9(uint16_t* const cdf, const int symbol) {
411   UpdateCdf7To9<9>(cdf, symbol);
412 }
413 
414 // See UpdateCdf5 for implementation details.
UpdateCdf11(uint16_t * const cdf,const int symbol)415 void UpdateCdf11(uint16_t* const cdf, const int symbol) {
416   __m128i cdf_vec = LoadUnaligned16(cdf + 2);
417   const uint16_t count = cdf[11];
418   cdf[11] = count + static_cast<uint16_t>(count < 32);
419   const int rate = (count >> 4) + 5;
420   if (symbol > 1) {
421     cdf[0] += (kCdfMaxProbability - cdf[0]) >> rate;
422     cdf[1] += (kCdfMaxProbability - cdf[1]) >> rate;
423     const __m128i cdf_max_probability =
424         _mm_set1_epi16(static_cast<int16_t>(kCdfMaxProbability));
425     const __m128i index =
426         _mm_set_epi32(0x000a0009, 0x00080007, 0x00060005, 0x00040003);
427     const __m128i symbol_vec = _mm_set1_epi16(static_cast<int16_t>(symbol));
428     const __m128i mask = _mm_cmpgt_epi16(index, symbol_vec);
429     const __m128i a = _mm_or_si128(mask, cdf_max_probability);
430     const __m128i diff = _mm_sub_epi16(a, cdf_vec);
431     const __m128i cdf_offset = _mm_sub_epi16(cdf_vec, mask);
432     const __m128i delta = _mm_sra_epi16(diff, _mm_cvtsi32_si128(rate));
433     cdf_vec = _mm_add_epi16(cdf_offset, delta);
434     StoreUnaligned16(cdf + 2, cdf_vec);
435   } else {
436     if (symbol != 0) {
437       cdf[0] += (kCdfMaxProbability - cdf[0]) >> rate;
438       cdf[1] -= cdf[1] >> rate;
439     } else {
440       cdf[0] -= cdf[0] >> rate;
441       cdf[1] -= cdf[1] >> rate;
442     }
443     const __m128i delta = _mm_sra_epi16(cdf_vec, _mm_cvtsi32_si128(rate));
444     cdf_vec = _mm_sub_epi16(cdf_vec, delta);
445     StoreUnaligned16(cdf + 2, cdf_vec);
446   }
447 }
448 
449 // See UpdateCdf5 for implementation details.
UpdateCdf13(uint16_t * const cdf,const int symbol)450 void UpdateCdf13(uint16_t* const cdf, const int symbol) {
451   __m128i cdf_vec0 = LoadLo8(cdf);
452   __m128i cdf_vec1 = LoadUnaligned16(cdf + 4);
453   const uint16_t count = cdf[13];
454   const int rate = (count >> 4) + 5;
455   const __m128i cdf_max_probability =
456       _mm_set1_epi16(static_cast<int16_t>(kCdfMaxProbability));
457   const __m128i symbol_vec = _mm_set1_epi16(static_cast<int16_t>(symbol));
458 
459   const __m128i index = _mm_set_epi32(0x0, 0x0, 0x00040003, 0x00020001);
460   const __m128i mask = _mm_cmpgt_epi16(index, symbol_vec);
461   const __m128i a = _mm_or_si128(mask, cdf_max_probability);
462   const __m128i diff = _mm_sub_epi16(a, cdf_vec0);
463   const __m128i cdf_offset = _mm_sub_epi16(cdf_vec0, mask);
464   const __m128i delta = _mm_sra_epi16(diff, _mm_cvtsi32_si128(rate));
465   cdf_vec0 = _mm_add_epi16(cdf_offset, delta);
466   StoreLo8(cdf, cdf_vec0);
467 
468   const __m128i index1 =
469       _mm_set_epi32(0x000c000b, 0x000a0009, 0x00080007, 0x00060005);
470   const __m128i mask1 = _mm_cmpgt_epi16(index1, symbol_vec);
471   const __m128i a1 = _mm_or_si128(mask1, cdf_max_probability);
472   const __m128i diff1 = _mm_sub_epi16(a1, cdf_vec1);
473   const __m128i cdf_offset1 = _mm_sub_epi16(cdf_vec1, mask1);
474   const __m128i delta1 = _mm_sra_epi16(diff1, _mm_cvtsi32_si128(rate));
475   cdf_vec1 = _mm_add_epi16(cdf_offset1, delta1);
476   StoreUnaligned16(cdf + 4, cdf_vec1);
477 
478   cdf[13] = count + static_cast<uint16_t>(count < 32);
479 }
480 
UpdateCdf16(uint16_t * const cdf,const int symbol)481 void UpdateCdf16(uint16_t* const cdf, const int symbol) {
482   __m128i cdf_vec0 = LoadUnaligned16(cdf);
483   const uint16_t count = cdf[16];
484   const int rate = (count >> 4) + 5;
485   const __m128i cdf_max_probability =
486       _mm_set1_epi16(static_cast<int16_t>(kCdfMaxProbability));
487   const __m128i symbol_vec = _mm_set1_epi16(static_cast<int16_t>(symbol));
488 
489   const __m128i index =
490       _mm_set_epi32(0x00080007, 0x00060005, 0x00040003, 0x00020001);
491   const __m128i mask = _mm_cmpgt_epi16(index, symbol_vec);
492   const __m128i a = _mm_or_si128(mask, cdf_max_probability);
493   const __m128i diff = _mm_sub_epi16(a, cdf_vec0);
494   const __m128i cdf_offset = _mm_sub_epi16(cdf_vec0, mask);
495   const __m128i delta = _mm_sra_epi16(diff, _mm_cvtsi32_si128(rate));
496   cdf_vec0 = _mm_add_epi16(cdf_offset, delta);
497   StoreUnaligned16(cdf, cdf_vec0);
498 
499   __m128i cdf_vec1 = LoadUnaligned16(cdf + 8);
500   const __m128i index1 =
501       _mm_set_epi32(0x0010000f, 0x000e000d, 0x000c000b, 0x000a0009);
502   const __m128i mask1 = _mm_cmpgt_epi16(index1, symbol_vec);
503   const __m128i a1 = _mm_or_si128(mask1, cdf_max_probability);
504   const __m128i diff1 = _mm_sub_epi16(a1, cdf_vec1);
505   const __m128i cdf_offset1 = _mm_sub_epi16(cdf_vec1, mask1);
506   const __m128i delta1 = _mm_sra_epi16(diff1, _mm_cvtsi32_si128(rate));
507   cdf_vec1 = _mm_add_epi16(cdf_offset1, delta1);
508   StoreUnaligned16(cdf + 8, cdf_vec1);
509 
510   cdf[16] = count + static_cast<uint16_t>(count < 32);
511 }
512 
513 #else  // !LIBGAV1_ENTROPY_DECODER_ENABLE_SSE2
514 
UpdateCdf5(uint16_t * const cdf,const int symbol)515 void UpdateCdf5(uint16_t* const cdf, const int symbol) {
516   UpdateCdf(cdf, 5, symbol);
517 }
518 
UpdateCdf7(uint16_t * const cdf,const int symbol)519 void UpdateCdf7(uint16_t* const cdf, const int symbol) {
520   UpdateCdf(cdf, 7, symbol);
521 }
522 
UpdateCdf8(uint16_t * const cdf,const int symbol)523 void UpdateCdf8(uint16_t* const cdf, const int symbol) {
524   UpdateCdf(cdf, 8, symbol);
525 }
526 
UpdateCdf9(uint16_t * const cdf,const int symbol)527 void UpdateCdf9(uint16_t* const cdf, const int symbol) {
528   UpdateCdf(cdf, 9, symbol);
529 }
530 
UpdateCdf11(uint16_t * const cdf,const int symbol)531 void UpdateCdf11(uint16_t* const cdf, const int symbol) {
532   UpdateCdf(cdf, 11, symbol);
533 }
534 
UpdateCdf13(uint16_t * const cdf,const int symbol)535 void UpdateCdf13(uint16_t* const cdf, const int symbol) {
536   UpdateCdf(cdf, 13, symbol);
537 }
538 
UpdateCdf16(uint16_t * const cdf,const int symbol)539 void UpdateCdf16(uint16_t* const cdf, const int symbol) {
540   UpdateCdf(cdf, 16, symbol);
541 }
542 
543 #endif  // LIBGAV1_ENTROPY_DECODER_ENABLE_SSE2
544 #endif  // LIBGAV1_ENTROPY_DECODER_ENABLE_NEON
545 
// Converts |x| from host byte order to big-endian byte order. On little-endian
// hosts the bytes of the 4- or 8-byte value are reversed; on big-endian hosts
// (GCC-compatible compilers only) the value is returned unchanged.
inline DaalaBitReader::WindowSize HostToBigEndian(
    const DaalaBitReader::WindowSize x) {
  using WindowSize = DaalaBitReader::WindowSize;
  static_assert(sizeof(x) == 4 || sizeof(x) == 8, "");
#if defined(__GNUC__)
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
  if (sizeof(x) == 8) {
    return static_cast<WindowSize>(__builtin_bswap64(x));
  }
  return static_cast<WindowSize>(__builtin_bswap32(x));
#else
  return x;
#endif
#elif defined(_WIN32)
  // Note Windows targets are assumed to be little endian.
  if (sizeof(x) == 8) {
    return static_cast<WindowSize>(
        _byteswap_uint64(static_cast<unsigned __int64>(x)));
  }
  return static_cast<WindowSize>(
      _byteswap_ulong(static_cast<unsigned long>(x)));
#else
#error Unknown compiler!
#endif  // defined(__GNUC__)
}
564 
565 }  // namespace
566 
#if !LIBGAV1_CXX17
// Before C++17, a static constexpr data member that is odr-used needs a
// namespace-scope definition (without an initializer) in one translation
// unit. From C++17 on such members are implicitly inline, so the definition
// is only compiled for older language versions.
constexpr int DaalaBitReader::kWindowSize;  // static.
#endif
570 
DaalaBitReader(const uint8_t * data,size_t size,bool allow_update_cdf)571 DaalaBitReader::DaalaBitReader(const uint8_t* data, size_t size,
572                                bool allow_update_cdf)
573     : data_(data),
574       data_end_(data + size),
575       data_memcpy_end_((size >= sizeof(WindowSize))
576                            ? data + size - sizeof(WindowSize) + 1
577                            : data),
578       allow_update_cdf_(allow_update_cdf),
579       values_in_range_(kCdfMaxProbability) {
580   if (data_ < data_memcpy_end_) {
581     // This is a simplified version of PopulateBits() which loads 8 extra bits
582     // and skips the unnecessary shifts of value and window_diff_.
583     WindowSize value;
584     memcpy(&value, data_, sizeof(value));
585     data_ += sizeof(value);
586     window_diff_ = HostToBigEndian(value) ^ -1;
587     // Note the initial value of bits_ is larger than kMaxCachedBits as it's
588     // used to restore the most significant 0 bit that would be present after
589     // PopulateBits() when we extract the first symbol value.
590     // As shown in Section 8.2.2 Initialization process for symbol decoder,
591     // which uses a fixed offset to read the symbol values, the most
592     // significant bit is always 0:
593     //   The variable numBits is set equal to Min( sz * 8, 15).
594     //   The variable buf is read using the f(numBits) parsing process.
595     //   The variable paddedBuf is set equal to ( buf << (15 - numBits) ).
596     //   The variable SymbolValue is set to ((1 << 15) - 1) ^ paddedBuf.
597     bits_ = kWindowSize - 15;
598     return;
599   }
600   window_diff_ = 0;
601   bits_ = -15;
602   PopulateBits();
603 }
604 
605 // This is similar to the ReadSymbol() implementation but it is optimized based
606 // on the following facts:
607 //   * The probability is fixed at half. So some multiplications can be replaced
608 //     with bit operations.
609 //   * Symbol count is fixed at 2.
ReadBit()610 int DaalaBitReader::ReadBit() {
611   const uint32_t curr =
612       ((values_in_range_ & kReadBitMask) >> 1) + kMinimumProbabilityPerSymbol;
613   const auto symbol_value = static_cast<uint16_t>(window_diff_ >> bits_);
614   int bit = 1;
615   if (symbol_value >= curr) {
616     values_in_range_ -= curr;
617     window_diff_ -= static_cast<WindowSize>(curr) << bits_;
618     bit = 0;
619   } else {
620     values_in_range_ = curr;
621   }
622   NormalizeRange();
623   return bit;
624 }
625 
ReadLiteral(int num_bits)626 int64_t DaalaBitReader::ReadLiteral(int num_bits) {
627   assert(num_bits <= 32);
628   assert(num_bits > 0);
629   uint32_t literal = 0;
630   int bit = num_bits - 1;
631   do {
632     // ARM can combine a shift operation with a constant number of bits with
633     // some other operations, such as the OR operation.
634     // Here is an ARM disassembly example:
635     // orr w1, w0, w1, lsl #1
636     // which left shifts register w1 by 1 bit and OR the shift result with
637     // register w0.
638     // The next 2 lines are equivalent to:
639     // literal |= static_cast<uint32_t>(ReadBit()) << bit;
640     literal <<= 1;
641     literal |= static_cast<uint32_t>(ReadBit());
642   } while (--bit >= 0);
643   return literal;
644 }
645 
ReadSymbol(uint16_t * const cdf,int symbol_count)646 int DaalaBitReader::ReadSymbol(uint16_t* const cdf, int symbol_count) {
647   const int symbol = ReadSymbolImpl(cdf, symbol_count);
648   if (allow_update_cdf_) {
649     UpdateCdf(cdf, symbol_count, symbol);
650   }
651   return symbol;
652 }
653 
ReadSymbol(uint16_t * cdf)654 bool DaalaBitReader::ReadSymbol(uint16_t* cdf) {
655   assert(cdf[1] == 0);
656   const bool symbol = ReadSymbolImpl(cdf[0]) != 0;
657   if (allow_update_cdf_) {
658     const uint16_t count = cdf[2];
659     // rate is computed in the spec as:
660     //  3 + ( cdf[N] > 15 ) + ( cdf[N] > 31 ) + Min(FloorLog2(N), 2)
661     // In this case N is 2 and cdf[N] is |count|. So the equation becomes:
662     //  4 + (count > 15) + (count > 31)
663     // Note that the largest value for count is 32 (it is not incremented beyond
664     // 32). So using that information:
665     //  count >> 4 is 0 for count from 0 to 15.
666     //  count >> 4 is 1 for count from 16 to 31.
667     //  count >> 4 is 2 for count == 32.
668     // Now, the equation becomes:
669     //  4 + (count >> 4).
670     // Since (count >> 4) can only be 0 or 1 or 2, the addition can be replaced
671     // with bitwise or. So the final equation is:
672     //  4 | (count >> 4).
673     const int rate = 4 | (count >> 4);
674     if (symbol) {
675       cdf[0] += (kCdfMaxProbability - cdf[0]) >> rate;
676     } else {
677       cdf[0] -= cdf[0] >> rate;
678     }
679     cdf[2] += static_cast<uint16_t>(count < 32);
680   }
681   return symbol;
682 }
683 
ReadSymbolWithoutCdfUpdate(uint16_t cdf)684 bool DaalaBitReader::ReadSymbolWithoutCdfUpdate(uint16_t cdf) {
685   return ReadSymbolImpl(cdf) != 0;
686 }
687 
688 template <int symbol_count>
ReadSymbol(uint16_t * const cdf)689 int DaalaBitReader::ReadSymbol(uint16_t* const cdf) {
690   static_assert(symbol_count >= 3 && symbol_count <= 16, "");
691   if (symbol_count == 3 || symbol_count == 4) {
692     return ReadSymbol3Or4(cdf, symbol_count);
693   }
694   int symbol;
695   if (symbol_count == 8) {
696     symbol = ReadSymbolImpl8(cdf);
697   } else if (symbol_count <= 13) {
698     symbol = ReadSymbolImpl(cdf, symbol_count);
699   } else {
700     symbol = ReadSymbolImplBinarySearch(cdf, symbol_count);
701   }
702   if (allow_update_cdf_) {
703     if (symbol_count == 5) {
704       UpdateCdf5(cdf, symbol);
705     } else if (symbol_count == 7) {
706       UpdateCdf7(cdf, symbol);
707     } else if (symbol_count == 8) {
708       UpdateCdf8(cdf, symbol);
709     } else if (symbol_count == 9) {
710       UpdateCdf9(cdf, symbol);
711     } else if (symbol_count == 11) {
712       UpdateCdf11(cdf, symbol);
713     } else if (symbol_count == 13) {
714       UpdateCdf13(cdf, symbol);
715     } else if (symbol_count == 16) {
716       UpdateCdf16(cdf, symbol);
717     } else {
718       UpdateCdf(cdf, symbol_count, symbol);
719     }
720   }
721   return symbol;
722 }
723 
ReadSymbolImpl(const uint16_t * const cdf,int symbol_count)724 int DaalaBitReader::ReadSymbolImpl(const uint16_t* const cdf,
725                                    int symbol_count) {
726   assert(cdf[symbol_count - 1] == 0);
727   --symbol_count;
728   uint32_t curr = values_in_range_;
729   int symbol = -1;
730   uint32_t prev;
731   const auto symbol_value = static_cast<uint16_t>(window_diff_ >> bits_);
732   uint32_t delta = kMinimumProbabilityPerSymbol * symbol_count;
733   // Search through the |cdf| array to determine where the scaled cdf value and
734   // |symbol_value| cross over.
735   do {
736     prev = curr;
737     curr = (((values_in_range_ >> 8) * (cdf[++symbol] >> kCdfPrecision)) >> 1) +
738            delta;
739     delta -= kMinimumProbabilityPerSymbol;
740   } while (symbol_value < curr);
741   values_in_range_ = prev - curr;
742   window_diff_ -= static_cast<WindowSize>(curr) << bits_;
743   NormalizeRange();
744   return symbol;
745 }
746 
ReadSymbolImplBinarySearch(const uint16_t * const cdf,int symbol_count)747 int DaalaBitReader::ReadSymbolImplBinarySearch(const uint16_t* const cdf,
748                                                int symbol_count) {
749   assert(cdf[symbol_count - 1] == 0);
750   assert(symbol_count > 1 && symbol_count <= 16);
751   --symbol_count;
752   const auto symbol_value = static_cast<uint16_t>(window_diff_ >> bits_);
753   // Search through the |cdf| array to determine where the scaled cdf value and
754   // |symbol_value| cross over. Since the CDFs are sorted, we can use binary
755   // search to do this. Let |symbol| be the index of the first |cdf| array
756   // entry whose scaled cdf value is less than or equal to |symbol_value|. The
757   // binary search maintains the invariant:
758   //   low <= symbol <= high + 1
759   // and terminates when low == high + 1.
760   int low = 0;
761   int high = symbol_count - 1;
762   // The binary search maintains the invariants that |prev| is the scaled cdf
763   // value for low - 1 and |curr| is the scaled cdf value for high + 1. (By
764   // convention, the scaled cdf value for -1 is values_in_range_.) When the
765   // binary search terminates, |prev| is the scaled cdf value for symbol - 1
766   // and |curr| is the scaled cdf value for |symbol|.
767   uint32_t prev = values_in_range_;
768   uint32_t curr = 0;
769   const uint32_t values_in_range_shifted = values_in_range_ >> 8;
770   do {
771     const int mid = DivideBy2(low + high);
772     const uint32_t scaled_cdf =
773         ScaleCdf(values_in_range_shifted, cdf, mid, symbol_count);
774     if (symbol_value < scaled_cdf) {
775       low = mid + 1;
776       prev = scaled_cdf;
777     } else {
778       high = mid - 1;
779       curr = scaled_cdf;
780     }
781   } while (low <= high);
782   assert(low == high + 1);
783   // At this point, |low| is the symbol that has been decoded.
784   values_in_range_ = prev - curr;
785   window_diff_ -= static_cast<WindowSize>(curr) << bits_;
786   NormalizeRange();
787   return low;
788 }
789 
ReadSymbolImpl(uint16_t cdf)790 int DaalaBitReader::ReadSymbolImpl(uint16_t cdf) {
791   const auto symbol_value = static_cast<uint16_t>(window_diff_ >> bits_);
792   const uint32_t curr =
793       (((values_in_range_ >> 8) * (cdf >> kCdfPrecision)) >> 1) +
794       kMinimumProbabilityPerSymbol;
795   const int symbol = static_cast<int>(symbol_value < curr);
796   if (symbol == 1) {
797     values_in_range_ = curr;
798   } else {
799     values_in_range_ -= curr;
800     window_diff_ -= static_cast<WindowSize>(curr) << bits_;
801   }
802   NormalizeRange();
803   return symbol;
804 }
805 
// Equivalent to ReadSymbol(cdf, [3,4]), with the ReadSymbolImpl and UpdateCdf
// calls inlined.
//
// |cdf| holds |symbol_count| cdf entries (the last of which must be 0)
// followed by the adaptation counter at cdf[symbol_count]. |symbol_count| must
// be 3 or 4. Returns the decoded symbol in [0, symbol_count - 1]. When
// allow_update_cdf_ is true, the cdf entries and the counter are updated in
// place. The range state (values_in_range_, window_diff_) is always updated
// and renormalized before returning.
//
// In every update below, the rate follows the spec formula
//   3 + (count > 15) + (count > 31) + Min(FloorLog2(symbol_count), 2)
// which is 4 + (count >> 4) for symbol_count == 3 and 5 + (count >> 4) for
// symbol_count == 4 (count never exceeds 32). Entries at indices below the
// decoded symbol move towards kCdfMaxProbability; entries at or above it
// move towards 0.
int DaalaBitReader::ReadSymbol3Or4(uint16_t* const cdf,
                                   const int symbol_count) {
  assert(cdf[symbol_count - 1] == 0);
  uint32_t curr = values_in_range_;
  uint32_t prev;
  const auto symbol_value = static_cast<uint16_t>(window_diff_ >> bits_);
  uint32_t delta = kMinimumProbabilityPerSymbol * (symbol_count - 1);
  const uint32_t values_in_range_shifted = values_in_range_ >> 8;

  // Search through the |cdf| array to determine where the scaled cdf value and
  // |symbol_value| cross over. If allow_update_cdf_ is true, update the |cdf|
  // array.
  //
  // The original code is:
  //
  //  int symbol = -1;
  //  do {
  //    prev = curr;
  //    curr =
  //        ((values_in_range_shifted * (cdf[++symbol] >> kCdfPrecision)) >> 1)
  //        + delta;
  //    delta -= kMinimumProbabilityPerSymbol;
  //  } while (symbol_value < curr);
  //  if (allow_update_cdf_) {
  //    UpdateCdf(cdf, [3,4], symbol);
  //  }
  //
  // The do-while loop is unrolled with three or four iterations, and the
  // UpdateCdf call is inlined and merged into the iterations.
  int symbol = 0;
  // Iteration 0.
  prev = curr;
  curr =
      ((values_in_range_shifted * (cdf[symbol] >> kCdfPrecision)) >> 1) + delta;
  if (symbol_value >= curr) {
    // symbol == 0.
    if (allow_update_cdf_) {
      // Inlined version of UpdateCdf(cdf, [3,4], /*symbol=*/0).
      const uint16_t count = cdf[symbol_count];
      cdf[symbol_count] += static_cast<uint16_t>(count < 32);
      const int rate = (count >> 4) + 4 + static_cast<int>(symbol_count == 4);
      // With symbol == 0 every entry decays towards 0. The vector paths below
      // process all four 16-bit entries at once; cdf[3] is 0 on entry (see
      // the assert above) and 0 - (0 >> rate) is still 0, so it needs no
      // fix-up afterwards.
      if (symbol_count == 4) {
#if LIBGAV1_ENTROPY_DECODER_ENABLE_NEON
        // 1. On Motorola Moto G5 Plus (running 32-bit Android 8.1.0), the ARM
        // NEON code is slower. Consider using the C version if __arm__ is
        // defined.
        // 2. The ARM NEON code (compiled for arm64) is slightly slower on
        // Samsung Galaxy S8+ (SM-G955FD).
        uint16x4_t cdf_vec = vld1_u16(cdf);
        // vshl with a negative per-lane count performs a right shift.
        const int16x4_t negative_rate = vdup_n_s16(-rate);
        const uint16x4_t delta = vshl_u16(cdf_vec, negative_rate);
        cdf_vec = vsub_u16(cdf_vec, delta);
        vst1_u16(cdf, cdf_vec);
#elif LIBGAV1_ENTROPY_DECODER_ENABLE_SSE2
        // NOTE(review): LoadLo8/StoreLo8 are helpers defined elsewhere;
        // presumably they load/store the low 8 bytes (four uint16_t values).
        __m128i cdf_vec = LoadLo8(cdf);
        const __m128i delta = _mm_sra_epi16(cdf_vec, _mm_cvtsi32_si128(rate));
        cdf_vec = _mm_sub_epi16(cdf_vec, delta);
        StoreLo8(cdf, cdf_vec);
#else  // !LIBGAV1_ENTROPY_DECODER_ENABLE_SSE2
        cdf[0] -= cdf[0] >> rate;
        cdf[1] -= cdf[1] >> rate;
        cdf[2] -= cdf[2] >> rate;
#endif
      } else {  // symbol_count == 3.
        cdf[0] -= cdf[0] >> rate;
        cdf[1] -= cdf[1] >> rate;
      }
    }
    goto found;
  }
  ++symbol;
  delta -= kMinimumProbabilityPerSymbol;
  // Iteration 1.
  prev = curr;
  curr =
      ((values_in_range_shifted * (cdf[symbol] >> kCdfPrecision)) >> 1) + delta;
  if (symbol_value >= curr) {
    // symbol == 1.
    if (allow_update_cdf_) {
      // Inlined version of UpdateCdf(cdf, [3,4], /*symbol=*/1).
      const uint16_t count = cdf[symbol_count];
      cdf[symbol_count] += static_cast<uint16_t>(count < 32);
      const int rate = (count >> 4) + 4 + static_cast<int>(symbol_count == 4);
      // Entry 0 (below the symbol) grows; entries 1 and, for four symbols, 2
      // decay.
      cdf[0] += (kCdfMaxProbability - cdf[0]) >> rate;
      cdf[1] -= cdf[1] >> rate;
      if (symbol_count == 4) cdf[2] -= cdf[2] >> rate;
    }
    goto found;
  }
  ++symbol;
  if (symbol_count == 4) {
    delta -= kMinimumProbabilityPerSymbol;
    // Iteration 2.
    prev = curr;
    curr = ((values_in_range_shifted * (cdf[symbol] >> kCdfPrecision)) >> 1) +
           delta;
    if (symbol_value >= curr) {
      // symbol == 2.
      if (allow_update_cdf_) {
        // Inlined version of UpdateCdf(cdf, 4, /*symbol=*/2).
        const uint16_t count = cdf[4];
        cdf[4] += static_cast<uint16_t>(count < 32);
        const int rate = (count >> 4) + 5;
        cdf[0] += (kCdfMaxProbability - cdf[0]) >> rate;
        cdf[1] += (kCdfMaxProbability - cdf[1]) >> rate;
        cdf[2] -= cdf[2] >> rate;
      }
      goto found;
    }
    ++symbol;
  }
  // |delta| is 0 for the last iteration.
  // Iteration 2 (symbol_count == 3) or 3 (symbol_count == 4).
  prev = curr;
  // Since cdf[symbol_count - 1] is 0 and |delta| is 0, |curr| is also 0.
  curr = 0;
  // symbol == [2,3].
  if (allow_update_cdf_) {
    // Inlined version of UpdateCdf(cdf, [3,4], /*symbol=*/[2,3]).
    const uint16_t count = cdf[symbol_count];
    cdf[symbol_count] += static_cast<uint16_t>(count < 32);
    // (count >> 4) is at most 2, so OR-ing it with 4 equals adding it to 4.
    const int rate = (4 | (count >> 4)) + static_cast<int>(symbol_count == 4);
    // The last symbol was decoded, so every updated entry grows towards
    // kCdfMaxProbability.
    if (symbol_count == 4) {
#if LIBGAV1_ENTROPY_DECODER_ENABLE_NEON
      // On Motorola Moto G5 Plus (running 32-bit Android 8.1.0), the ARM NEON
      // code is a tiny bit slower. Consider using the C version if __arm__ is
      // defined.
      uint16x4_t cdf_vec = vld1_u16(cdf);
      const uint16x4_t cdf_max_probability = vdup_n_u16(kCdfMaxProbability);
      const int16x4_t diff =
          vreinterpret_s16_u16(vsub_u16(cdf_max_probability, cdf_vec));
      // vshl with a negative per-lane count performs a right shift.
      const int16x4_t negative_rate = vdup_n_s16(-rate);
      const uint16x4_t delta =
          vreinterpret_u16_s16(vshl_s16(diff, negative_rate));
      cdf_vec = vadd_u16(cdf_vec, delta);
      vst1_u16(cdf, cdf_vec);
      // The vector update also modified lane 3; restore the required 0
      // terminator.
      cdf[3] = 0;
#elif LIBGAV1_ENTROPY_DECODER_ENABLE_SSE2
      __m128i cdf_vec = LoadLo8(cdf);
      // Broadcast kCdfMaxProbability to the four low 16-bit lanes.
      const __m128i cdf_max_probability =
          _mm_shufflelo_epi16(_mm_cvtsi32_si128(kCdfMaxProbability), 0);
      const __m128i diff = _mm_sub_epi16(cdf_max_probability, cdf_vec);
      const __m128i delta = _mm_sra_epi16(diff, _mm_cvtsi32_si128(rate));
      cdf_vec = _mm_add_epi16(cdf_vec, delta);
      StoreLo8(cdf, cdf_vec);
      // The vector update also modified lane 3; restore the required 0
      // terminator.
      cdf[3] = 0;
#else  // !LIBGAV1_ENTROPY_DECODER_ENABLE_SSE2
      cdf[0] += (kCdfMaxProbability - cdf[0]) >> rate;
      cdf[1] += (kCdfMaxProbability - cdf[1]) >> rate;
      cdf[2] += (kCdfMaxProbability - cdf[2]) >> rate;
#endif
    } else {  // symbol_count == 3.
      cdf[0] += (kCdfMaxProbability - cdf[0]) >> rate;
      cdf[1] += (kCdfMaxProbability - cdf[1]) >> rate;
    }
  }
found:
  // End of unrolled do-while loop.

  // |prev| and |curr| are the scaled cdf values bracketing the decoded
  // symbol: narrow the range to that interval and rebase the window.
  values_in_range_ = prev - curr;
  window_diff_ -= static_cast<WindowSize>(curr) << bits_;
  NormalizeRange();
  return symbol;
}
972 
ReadSymbolImpl8(const uint16_t * const cdf)973 int DaalaBitReader::ReadSymbolImpl8(const uint16_t* const cdf) {
974   assert(cdf[7] == 0);
975   uint32_t curr = values_in_range_;
976   uint32_t prev;
977   const auto symbol_value = static_cast<uint16_t>(window_diff_ >> bits_);
978   uint32_t delta = kMinimumProbabilityPerSymbol * 7;
979   // Search through the |cdf| array to determine where the scaled cdf value and
980   // |symbol_value| cross over.
981   //
982   // The original code is:
983   //
984   // int symbol = -1;
985   // do {
986   //   prev = curr;
987   //   curr =
988   //       (((values_in_range_ >> 8) * (cdf[++symbol] >> kCdfPrecision)) >> 1)
989   //       + delta;
990   //   delta -= kMinimumProbabilityPerSymbol;
991   // } while (symbol_value < curr);
992   //
993   // The do-while loop is unrolled with eight iterations.
994   int symbol = 0;
995 
996 #define READ_SYMBOL_ITERATION                                                \
997   prev = curr;                                                               \
998   curr = (((values_in_range_ >> 8) * (cdf[symbol] >> kCdfPrecision)) >> 1) + \
999          delta;                                                              \
1000   if (symbol_value >= curr) goto found;                                      \
1001   ++symbol;                                                                  \
1002   delta -= kMinimumProbabilityPerSymbol
1003 
1004   READ_SYMBOL_ITERATION;  // Iteration 0.
1005   READ_SYMBOL_ITERATION;  // Iteration 1.
1006   READ_SYMBOL_ITERATION;  // Iteration 2.
1007   READ_SYMBOL_ITERATION;  // Iteration 3.
1008   READ_SYMBOL_ITERATION;  // Iteration 4.
1009   READ_SYMBOL_ITERATION;  // Iteration 5.
1010 
1011   // The last two iterations can be simplified, so they don't use the
1012   // READ_SYMBOL_ITERATION macro.
1013 #undef READ_SYMBOL_ITERATION
1014 
1015   // Iteration 6.
1016   prev = curr;
1017   curr =
1018       (((values_in_range_ >> 8) * (cdf[symbol] >> kCdfPrecision)) >> 1) + delta;
1019   if (symbol_value >= curr) goto found;  // symbol == 6.
1020   ++symbol;
1021   // |delta| is 0 for the last iteration.
1022   // Iteration 7.
1023   prev = curr;
1024   // Since cdf[7] is 0 and |delta| is 0, |curr| is also 0.
1025   curr = 0;
1026   // symbol == 7.
1027 found:
1028   // End of unrolled do-while loop.
1029 
1030   values_in_range_ = prev - curr;
1031   window_diff_ -= static_cast<WindowSize>(curr) << bits_;
1032   NormalizeRange();
1033   return symbol;
1034 }
1035 
// Refills the decoding window. Bytes are read from data_, bitwise
// complemented, and shifted into the low end of window_diff_; bits_ grows by
// 8 for every byte consumed. Once the input is exhausted, the window is
// padded with 1 bits and bits_ is set to kMaxCachedBits.
void DaalaBitReader::PopulateBits() {
  // Upper bound on bits_ after a refill.
  constexpr int kMaxCachedBits = kWindowSize - 16;
#if defined(__aarch64__)
  // Fast path: read eight bytes and add the first six bytes to window_diff_.
  // This fast path makes the following assumptions.
  // 1. We assume that unaligned load of uint64_t is fast.
  // 2. When there are enough bytes in data_, the for loop below reads 6 or 7
  //    bytes depending on the value of bits_. This fast path always reads 6
  //    bytes, which results in more calls to PopulateBits(). We assume that
  //    making more calls to a faster PopulateBits() is overall a win.
  // NOTE: Although this fast path could also be used on x86_64, it hurts
  // performance (measured on Lenovo ThinkStation P920 running Linux). (The
  // reason is still unknown.) Therefore this fast path is only used on arm64.
  static_assert(kWindowSize == 64, "");
  // NOTE(review): data_memcpy_end_ is set elsewhere; presumably it marks the
  // last position from which a full 8-byte read stays inside the buffer.
  if (data_ < data_memcpy_end_) {
    uint64_t value;
    // arm64 supports unaligned loads, so this memcpy call is compiled to a
    // single ldr instruction.
    memcpy(&value, data_, sizeof(value));
    // Only kMaxCachedBits / 8 of the eight bytes read are consumed.
    data_ += kMaxCachedBits >> 3;
    // Put the first byte read in the most significant position and complement
    // all bits, then drop the low bytes that are not consumed.
    value = HostToBigEndian(value) ^ -1;
    value >>= kWindowSize - kMaxCachedBits;
    window_diff_ = value | (window_diff_ << kMaxCachedBits);
    bits_ += kMaxCachedBits;
    return;
  }
#endif

  // Work on local copies and write the members back once at the end.
  const uint8_t* data = data_;
  int bits = bits_;
  WindowSize window_diff = window_diff_;

  // |count| bounds the number of bytes read: the loop below runs while
  // |count| is non-negative and input remains, subtracting 8 per byte.
  int count = kWindowSize - 9 - (bits + 15);
  // The fast path above, if compiled, would cause clang 8.0.7 to vectorize
  // this loop. Since -15 <= bits_ <= -1, this loop has at most 6 or 7
  // iterations when WindowSize is 64 bits. So it is not profitable to
  // vectorize this loop. Note that clang 8.0.7 does not vectorize this loop if
  // the fast path above is not compiled.

#ifdef __clang__
#pragma clang loop vectorize(disable) interleave(disable)
#endif
  for (; count >= 0 && data < data_end_; count -= 8) {
    // Complement the byte and append it below the bits already in the window.
    const uint8_t value = *data++ ^ -1;
    window_diff = static_cast<WindowSize>(value) | (window_diff << 8);
    bits += 8;
  }
  assert(bits <= kMaxCachedBits);
  if (data == data_end_) {
    // Shift in some 1s. This is equivalent to providing fake 0 data bits.
    // ((w + 1) << n) - 1 equals (w << n) with its n low bits set to 1.
    window_diff = ((window_diff + 1) << (kMaxCachedBits - bits)) - 1;
    bits = kMaxCachedBits;
  }

  data_ = data;
  bits_ = bits;
  window_diff_ = window_diff;
}
1094 
NormalizeRange()1095 void DaalaBitReader::NormalizeRange() {
1096   const int bits_used = 15 ^ FloorLog2(values_in_range_);
1097   bits_ -= bits_used;
1098   values_in_range_ <<= bits_used;
1099   if (bits_ < 0) PopulateBits();
1100 }
1101 
// Explicit instantiations.
// The ReadSymbol<symbol_count>() template is defined in this file, so every
// symbol count that is used from other translation units has to be
// instantiated here. The list covers 3 through 14 and 16; there is no
// instantiation for 15.
// NOTE(review): presumably no caller needs a 15-symbol CDF - confirm before
// adding one.
template int DaalaBitReader::ReadSymbol<3>(uint16_t* cdf);
template int DaalaBitReader::ReadSymbol<4>(uint16_t* cdf);
template int DaalaBitReader::ReadSymbol<5>(uint16_t* cdf);
template int DaalaBitReader::ReadSymbol<6>(uint16_t* cdf);
template int DaalaBitReader::ReadSymbol<7>(uint16_t* cdf);
template int DaalaBitReader::ReadSymbol<8>(uint16_t* cdf);
template int DaalaBitReader::ReadSymbol<9>(uint16_t* cdf);
template int DaalaBitReader::ReadSymbol<10>(uint16_t* cdf);
template int DaalaBitReader::ReadSymbol<11>(uint16_t* cdf);
template int DaalaBitReader::ReadSymbol<12>(uint16_t* cdf);
template int DaalaBitReader::ReadSymbol<13>(uint16_t* cdf);
template int DaalaBitReader::ReadSymbol<14>(uint16_t* cdf);
template int DaalaBitReader::ReadSymbol<16>(uint16_t* cdf);
1116 
1117 }  // namespace libgav1
1118