1 /* Copyright (c) 2018, Google Inc.
2  *
3  * Permission to use, copy, modify, and/or distribute this software for any
4  * purpose with or without fee is hereby granted, provided that the above
5  * copyright notice and this permission notice appear in all copies.
6  *
7  * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
8  * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
9  * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
10  * SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
11  * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
12  * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
13  * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */
14 
15 #include <openssl/hrss.h>
16 
17 #include <assert.h>
18 #include <stdio.h>
19 #include <stdlib.h>
20 
21 #include <openssl/bn.h>
22 #include <openssl/cpu.h>
23 #include <openssl/hmac.h>
24 #include <openssl/mem.h>
25 #include <openssl/sha.h>
26 
27 #if defined(OPENSSL_X86) || defined(OPENSSL_X86_64)
28 #include <emmintrin.h>
29 #endif
30 
31 #if (defined(OPENSSL_ARM) || defined(OPENSSL_AARCH64)) && \
32     (defined(__ARM_NEON__) || defined(__ARM_NEON))
33 #include <arm_neon.h>
34 #endif
35 
36 #if defined(_MSC_VER)
37 #define RESTRICT
38 #else
39 #define RESTRICT restrict
40 #endif
41 
42 #include "../internal.h"
43 #include "internal.h"
44 
45 // This is an implementation of [HRSS], but with a KEM transformation based on
46 // [SXY]. The primary references are:
47 
48 // HRSS: https://eprint.iacr.org/2017/667.pdf
49 // HRSSNIST:
50 // https://csrc.nist.gov/CSRC/media/Projects/Post-Quantum-Cryptography/documents/round-1/submissions/NTRU_HRSS_KEM.zip
51 // SXY: https://eprint.iacr.org/2017/1005.pdf
52 // NTRUTN14:
53 // https://assets.onboardsecurity.com/static/downloads/NTRU/resources/NTRUTech014.pdf
54 // NTRUCOMP:
55 // https://eprint.iacr.org/2018/1174
56 
57 
58 // Vector operations.
59 //
60 // A couple of functions in this file can use vector operations to meaningful
61 // effect. If we're building for a target that has a supported vector unit,
62 // |HRSS_HAVE_VECTOR_UNIT| will be defined and |vec_t| will be typedefed to a
63 // 128-bit vector. The following functions abstract over the differences between
64 // NEON and SSE2 for implementing some vector operations.
65 
66 // TODO: MSVC can likely also be made to work with vector operations.
67 #if ((defined(__SSE__) && defined(OPENSSL_X86)) || defined(OPENSSL_X86_64)) && \
68     (defined(__clang__) || !defined(_MSC_VER))
69 
70 #define HRSS_HAVE_VECTOR_UNIT
71 typedef __m128i vec_t;
72 
// vec_capable reports whether SSE2 can be used: one if so, zero otherwise.
// When the build already targets SSE2 the answer is fixed at compile time.
static int vec_capable(void) {
#if !defined(__SSE2__)
  // SSE2 is not guaranteed by the build flags, so test bit 26 of the first
  // capability word at run time.
  return (OPENSSL_ia32cap_P[0] & (1 << 26)) ? 1 : 0;
#else
  return 1;
#endif
}
82 
83 // vec_add performs a pair-wise addition of four uint16s from |a| and |b|.
vec_add(vec_t a,vec_t b)84 static inline vec_t vec_add(vec_t a, vec_t b) { return _mm_add_epi16(a, b); }
85 
86 // vec_sub performs a pair-wise subtraction of four uint16s from |a| and |b|.
vec_sub(vec_t a,vec_t b)87 static inline vec_t vec_sub(vec_t a, vec_t b) { return _mm_sub_epi16(a, b); }
88 
89 // vec_mul multiplies each uint16_t in |a| by |b| and returns the resulting
90 // vector.
vec_mul(vec_t a,uint16_t b)91 static inline vec_t vec_mul(vec_t a, uint16_t b) {
92   return _mm_mullo_epi16(a, _mm_set1_epi16(b));
93 }
94 
95 // vec_fma multiplies each uint16_t in |b| by |c|, adds the result to |a|, and
96 // returns the resulting vector.
vec_fma(vec_t a,vec_t b,uint16_t c)97 static inline vec_t vec_fma(vec_t a, vec_t b, uint16_t c) {
98   return _mm_add_epi16(a, _mm_mullo_epi16(b, _mm_set1_epi16(c)));
99 }
100 
101 // vec3_rshift_word right-shifts the 24 uint16_t's in |v| by one uint16.
vec3_rshift_word(vec_t v[3])102 static inline void vec3_rshift_word(vec_t v[3]) {
103   // Intel's left and right shifting is backwards compared to the order in
104   // memory because they're based on little-endian order of words (and not just
105   // bytes). So the shifts in this function will be backwards from what one
106   // might expect.
107   const __m128i carry0 = _mm_srli_si128(v[0], 14);
108   v[0] = _mm_slli_si128(v[0], 2);
109 
110   const __m128i carry1 = _mm_srli_si128(v[1], 14);
111   v[1] = _mm_slli_si128(v[1], 2);
112   v[1] |= carry0;
113 
114   v[2] = _mm_slli_si128(v[2], 2);
115   v[2] |= carry1;
116 }
117 
118 // vec4_rshift_word right-shifts the 32 uint16_t's in |v| by one uint16.
vec4_rshift_word(vec_t v[4])119 static inline void vec4_rshift_word(vec_t v[4]) {
120   // Intel's left and right shifting is backwards compared to the order in
121   // memory because they're based on little-endian order of words (and not just
122   // bytes). So the shifts in this function will be backwards from what one
123   // might expect.
124   const __m128i carry0 = _mm_srli_si128(v[0], 14);
125   v[0] = _mm_slli_si128(v[0], 2);
126 
127   const __m128i carry1 = _mm_srli_si128(v[1], 14);
128   v[1] = _mm_slli_si128(v[1], 2);
129   v[1] |= carry0;
130 
131   const __m128i carry2 = _mm_srli_si128(v[2], 14);
132   v[2] = _mm_slli_si128(v[2], 2);
133   v[2] |= carry1;
134 
135   v[3] = _mm_slli_si128(v[3], 2);
136   v[3] |= carry2;
137 }
138 
139 // vec_merge_3_5 takes the final three uint16_t's from |left|, appends the first
140 // five from |right|, and returns the resulting vector.
vec_merge_3_5(vec_t left,vec_t right)141 static inline vec_t vec_merge_3_5(vec_t left, vec_t right) {
142   return _mm_srli_si128(left, 10) | _mm_slli_si128(right, 6);
143 }
144 
145 // poly3_vec_lshift1 left-shifts the 768 bits in |a_s|, and in |a_a|, by one
146 // bit.
poly3_vec_lshift1(vec_t a_s[6],vec_t a_a[6])147 static inline void poly3_vec_lshift1(vec_t a_s[6], vec_t a_a[6]) {
148   vec_t carry_s = {0};
149   vec_t carry_a = {0};
150 
151   for (int i = 0; i < 6; i++) {
152     vec_t next_carry_s = _mm_srli_epi64(a_s[i], 63);
153     a_s[i] = _mm_slli_epi64(a_s[i], 1);
154     a_s[i] |= _mm_slli_si128(next_carry_s, 8);
155     a_s[i] |= carry_s;
156     carry_s = _mm_srli_si128(next_carry_s, 8);
157 
158     vec_t next_carry_a = _mm_srli_epi64(a_a[i], 63);
159     a_a[i] = _mm_slli_epi64(a_a[i], 1);
160     a_a[i] |= _mm_slli_si128(next_carry_a, 8);
161     a_a[i] |= carry_a;
162     carry_a = _mm_srli_si128(next_carry_a, 8);
163   }
164 }
165 
166 // poly3_vec_rshift1 right-shifts the 768 bits in |a_s|, and in |a_a|, by one
167 // bit.
poly3_vec_rshift1(vec_t a_s[6],vec_t a_a[6])168 static inline void poly3_vec_rshift1(vec_t a_s[6], vec_t a_a[6]) {
169   vec_t carry_s = {0};
170   vec_t carry_a = {0};
171 
172   for (int i = 5; i >= 0; i--) {
173     const vec_t next_carry_s = _mm_slli_epi64(a_s[i], 63);
174     a_s[i] = _mm_srli_epi64(a_s[i], 1);
175     a_s[i] |= _mm_srli_si128(next_carry_s, 8);
176     a_s[i] |= carry_s;
177     carry_s = _mm_slli_si128(next_carry_s, 8);
178 
179     const vec_t next_carry_a = _mm_slli_epi64(a_a[i], 63);
180     a_a[i] = _mm_srli_epi64(a_a[i], 1);
181     a_a[i] |= _mm_srli_si128(next_carry_a, 8);
182     a_a[i] |= carry_a;
183     carry_a = _mm_slli_si128(next_carry_a, 8);
184   }
185 }
186 
187 // vec_broadcast_bit duplicates the least-significant bit in |a| to all bits in
188 // a vector and returns the result.
vec_broadcast_bit(vec_t a)189 static inline vec_t vec_broadcast_bit(vec_t a) {
190   return _mm_shuffle_epi32(_mm_srai_epi32(_mm_slli_epi64(a, 63), 31),
191                            0b01010101);
192 }
193 
194 // vec_broadcast_bit15 duplicates the most-significant bit of the first word in
195 // |a| to all bits in a vector and returns the result.
vec_broadcast_bit15(vec_t a)196 static inline vec_t vec_broadcast_bit15(vec_t a) {
197   return _mm_shuffle_epi32(_mm_srai_epi32(_mm_slli_epi64(a, 63 - 15), 31),
198                            0b01010101);
199 }
200 
// vec_get_word evaluates to the |i|th uint16_t of |v|. It is a macro rather
// than a function because the underlying intrinsic needs |i| to be a
// compile-time constant.
#define vec_get_word(v, i) (_mm_extract_epi16((v), (i)))
204 
205 #elif (defined(OPENSSL_ARM) || defined(OPENSSL_AARCH64)) && \
206     (defined(__ARM_NEON__) || defined(__ARM_NEON))
207 
208 #define HRSS_HAVE_VECTOR_UNIT
209 typedef uint16x8_t vec_t;
210 
// These functions perform the same actions as the SSE2 functions of the same
// names, above.
213 
// vec_capable returns whatever |CRYPTO_is_NEON_capable| reports, i.e. whether
// the NEON code paths may be used.
static int vec_capable(void) {
  const int neon_ok = CRYPTO_is_NEON_capable();
  return neon_ok;
}
215 
// vec_add returns the lane-wise sum of the eight uint16_t's in |a| and |b|.
static inline vec_t vec_add(vec_t a, vec_t b) {
  return vaddq_u16(a, b);
}
217 
// vec_sub returns the lane-wise difference |a| - |b| of the eight uint16_t's
// in each vector.
static inline vec_t vec_sub(vec_t a, vec_t b) {
  return vsubq_u16(a, b);
}
219 
// vec_mul multiplies every uint16_t lane of |a| by the scalar |b|, keeping the
// low 16 bits of each product.
static inline vec_t vec_mul(vec_t a, uint16_t b) {
  const vec_t factor = vdupq_n_u16(b);
  return vmulq_u16(a, factor);
}
221 
// vec_fma returns |a| + |b|×|c|, where every uint16_t lane of |b| is scaled by
// the scalar |c| before being added to the matching lane of |a|.
static inline vec_t vec_fma(vec_t a, vec_t b, uint16_t c) {
  const vec_t scale = vdupq_n_u16(c);
  return vmlaq_u16(a, b, scale);
}
225 
// vec3_rshift_word treats |v| as 24 consecutive uint16_t's and moves each one
// along by a single position: the first word becomes zero and the final word
// of |v[2]| is dropped.
static inline void vec3_rshift_word(vec_t v[3]) {
  const uint16x8_t zeros = vdupq_n_u16(0);

  // Work backwards so that each vector's predecessor is read before it is
  // overwritten.
  for (int i = 2; i > 0; i--) {
    v[i] = vextq_u16(v[i - 1], v[i], 7);
  }
  v[0] = vextq_u16(zeros, v[0], 7);
}
232 
// vec4_rshift_word treats |v| as 32 consecutive uint16_t's and moves each one
// along by a single position: the first word becomes zero and the final word
// of |v[3]| is dropped.
static inline void vec4_rshift_word(vec_t v[4]) {
  const uint16x8_t zeros = vdupq_n_u16(0);

  // Work backwards so that each vector's predecessor is read before it is
  // overwritten.
  for (int i = 3; i > 0; i--) {
    v[i] = vextq_u16(v[i - 1], v[i], 7);
  }
  v[0] = vextq_u16(zeros, v[0], 7);
}
240 
// vec_merge_3_5 builds a vector whose first three uint16_t's are the last
// three of |left| and whose remaining five are the first five of |right|.
static inline vec_t vec_merge_3_5(vec_t left, vec_t right) {
  const vec_t merged = vextq_u16(left, right, 5);
  return merged;
}
244 
// vec_get_word returns the |i|th uint16_t lane of |v|.
static inline uint16_t vec_get_word(vec_t v, unsigned i) {
  const uint16_t word = v[i];
  return word;
}
248 
249 #if !defined(OPENSSL_AARCH64)
250 
// vec_broadcast_bit copies the least-significant bit of the first uint16_t in
// |a| to every bit of the returned vector.
static inline vec_t vec_broadcast_bit(vec_t a) {
  // Move bit 0 of every lane into the sign position, then let an arithmetic
  // shift smear it over the whole lane.
  const int16x8_t in_sign = vshlq_n_s16(vreinterpretq_s16_u16(a), 15);
  const vec_t smeared = vreinterpretq_u16_s16(vshrq_n_s16(in_sign, 15));
  // Replicate lane 0 across the vector.
  return vdupq_lane_u16(vget_low_u16(smeared), 0);
}
255 
// vec_broadcast_bit15 copies the most-significant bit of the first uint16_t in
// |a| to every bit of the returned vector.
static inline vec_t vec_broadcast_bit15(vec_t a) {
  // An arithmetic shift smears each lane's sign bit over the whole lane.
  const int16x8_t smeared_signed = vshrq_n_s16(vreinterpretq_s16_u16(a), 15);
  const vec_t smeared = vreinterpretq_u16_s16(smeared_signed);
  // Replicate lane 0 across the vector.
  return vdupq_lane_u16(vget_low_u16(smeared), 0);
}
260 
// poly3_vec_lshift1 shifts the 768-bit value held in |a_s|, and separately the
// one held in |a_a|, left by a single bit. The bit shifted out of the top of
// each value is discarded.
static inline void poly3_vec_lshift1(vec_t a_s[6], vec_t a_a[6]) {
  const vec_t zeros = vdupq_n_u16(0);
  // Carry arriving from the previous vector; lands in bit zero of lane 0.
  vec_t in_s = zeros;
  vec_t in_a = zeros;

  for (int i = 0; i < 6; i++) {
    // The top bit of every 16-bit lane, moved down to bit zero of that lane.
    const vec_t tops_s = vshrq_n_u16(a_s[i], 15);
    const vec_t tops_a = vshrq_n_u16(a_a[i], 15);

    // Shift each lane, feed each lane's top bit into the lane above, and
    // merge in the carry from the previous vector.
    vec_t next_s = vshlq_n_u16(a_s[i], 1);
    next_s = vorrq_u16(next_s, vextq_u16(zeros, tops_s, 7));
    a_s[i] = vorrq_u16(next_s, in_s);

    vec_t next_a = vshlq_n_u16(a_a[i], 1);
    next_a = vorrq_u16(next_a, vextq_u16(zeros, tops_a, 7));
    a_a[i] = vorrq_u16(next_a, in_a);

    // The last lane's top bit carries into lane 0 of the next vector.
    in_s = vextq_u16(tops_s, zeros, 7);
    in_a = vextq_u16(tops_a, zeros, 7);
  }
}
280 
// poly3_vec_rshift1 shifts the 768-bit value held in |a_s|, and separately the
// one held in |a_a|, right by a single bit. The lowest bit of each value is
// discarded.
static inline void poly3_vec_rshift1(vec_t a_s[6], vec_t a_a[6]) {
  const vec_t zeros = vdupq_n_u16(0);
  // Carry arriving from the following vector; lands in the top bit of the
  // last lane.
  vec_t in_s = zeros;
  vec_t in_a = zeros;

  for (int i = 5; i >= 0; i--) {
    // The bottom bit of every 16-bit lane, moved up to bit 15 of that lane.
    const vec_t bottoms_s = vshlq_n_u16(a_s[i], 15);
    const vec_t bottoms_a = vshlq_n_u16(a_a[i], 15);

    // Shift each lane, feed each lane's bottom bit into the lane below, and
    // merge in the carry from the following vector.
    vec_t next_s = vshrq_n_u16(a_s[i], 1);
    next_s = vorrq_u16(next_s, vextq_u16(bottoms_s, zeros, 1));
    a_s[i] = vorrq_u16(next_s, in_s);

    vec_t next_a = vshrq_n_u16(a_a[i], 1);
    next_a = vorrq_u16(next_a, vextq_u16(bottoms_a, zeros, 1));
    a_a[i] = vorrq_u16(next_a, in_a);

    // Lane 0's bottom bit carries into the last lane of the preceding vector.
    in_s = vextq_u16(zeros, bottoms_s, 1);
    in_a = vextq_u16(zeros, bottoms_a, 1);
  }
}
300 
301 #endif  // !OPENSSL_AARCH64
302 
303 #endif  // (ARM || AARCH64) && NEON
304 
// Polynomials in this scheme have N terms. N is not defined in this file
// (presumably it comes from internal.h); its value is:
// #define N 701
307 
308 // Underlying data types and arithmetic operations.
309 // ------------------------------------------------
310 
311 // Binary polynomials.
312 
313 // poly2 represents a degree-N polynomial over GF(2). The words are in little-
314 // endian order, i.e. the coefficient of x^0 is the LSB of the first word. The
315 // final word is only partially used since N is not a multiple of the word size.
316 
317 // Defined in internal.h:
318 // struct poly2 {
319 //  crypto_word_t v[WORDS_PER_POLY];
320 // };
321 
hexdump(const void * void_in,size_t len)322 OPENSSL_UNUSED static void hexdump(const void *void_in, size_t len) {
323   const uint8_t *in = (const uint8_t *)void_in;
324   for (size_t i = 0; i < len; i++) {
325     printf("%02x", in[i]);
326   }
327   printf("\n");
328 }
329 
poly2_zero(struct poly2 * p)330 static void poly2_zero(struct poly2 *p) {
331   OPENSSL_memset(&p->v[0], 0, sizeof(crypto_word_t) * WORDS_PER_POLY);
332 }
333 
334 // poly2_cmov sets |out| to |in| iff |mov| is all ones.
poly2_cmov(struct poly2 * out,const struct poly2 * in,crypto_word_t mov)335 static void poly2_cmov(struct poly2 *out, const struct poly2 *in,
336                        crypto_word_t mov) {
337   for (size_t i = 0; i < WORDS_PER_POLY; i++) {
338     out->v[i] = (out->v[i] & ~mov) | (in->v[i] & mov);
339   }
340 }
341 
342 // poly2_rotr_words performs a right-rotate on |in|, writing the result to
343 // |out|. The shift count, |bits|, must be a non-zero multiple of the word size.
poly2_rotr_words(struct poly2 * out,const struct poly2 * in,size_t bits)344 static void poly2_rotr_words(struct poly2 *out, const struct poly2 *in,
345                              size_t bits) {
346   assert(bits >= BITS_PER_WORD && bits % BITS_PER_WORD == 0);
347   assert(out != in);
348 
349   const size_t start = bits / BITS_PER_WORD;
350   const size_t n = (N - bits) / BITS_PER_WORD;
351 
352   // The rotate is by a whole number of words so the first few words are easy:
353   // just move them down.
354   for (size_t i = 0; i < n; i++) {
355     out->v[i] = in->v[start + i];
356   }
357 
358   // Since the last word is only partially filled, however, the remainder needs
359   // shifting and merging of words to take care of that.
360   crypto_word_t carry = in->v[WORDS_PER_POLY - 1];
361 
362   for (size_t i = 0; i < start; i++) {
363     out->v[n + i] = carry | in->v[i] << BITS_IN_LAST_WORD;
364     carry = in->v[i] >> (BITS_PER_WORD - BITS_IN_LAST_WORD);
365   }
366 
367   out->v[WORDS_PER_POLY - 1] = carry;
368 }
369 
370 // poly2_rotr_bits performs a right-rotate on |in|, writing the result to |out|.
371 // The shift count, |bits|, must be a power of two that is less than
372 // |BITS_PER_WORD|.
poly2_rotr_bits(struct poly2 * out,const struct poly2 * in,size_t bits)373 static void poly2_rotr_bits(struct poly2 *out, const struct poly2 *in,
374                             size_t bits) {
375   assert(bits <= BITS_PER_WORD / 2);
376   assert(bits != 0);
377   assert((bits & (bits - 1)) == 0);
378   assert(out != in);
379 
380   // BITS_PER_WORD/2 is the greatest legal value of |bits|. If
381   // |BITS_IN_LAST_WORD| is smaller than this then the code below doesn't work
382   // because more than the last word needs to carry down in the previous one and
383   // so on.
384   OPENSSL_STATIC_ASSERT(
385       BITS_IN_LAST_WORD >= BITS_PER_WORD / 2,
386       "there are more carry bits than fit in BITS_IN_LAST_WORD");
387 
388   crypto_word_t carry = in->v[WORDS_PER_POLY - 1] << (BITS_PER_WORD - bits);
389 
390   for (size_t i = WORDS_PER_POLY - 2; i < WORDS_PER_POLY; i--) {
391     out->v[i] = carry | in->v[i] >> bits;
392     carry = in->v[i] << (BITS_PER_WORD - bits);
393   }
394 
395   crypto_word_t last_word = carry >> (BITS_PER_WORD - BITS_IN_LAST_WORD) |
396                             in->v[WORDS_PER_POLY - 1] >> bits;
397   last_word &= (UINT64_C(1) << BITS_IN_LAST_WORD) - 1;
398   out->v[WORDS_PER_POLY - 1] = last_word;
399 }
400 
401 // HRSS_poly2_rotr_consttime right-rotates |p| by |bits| in constant-time.
HRSS_poly2_rotr_consttime(struct poly2 * p,size_t bits)402 void HRSS_poly2_rotr_consttime(struct poly2 *p, size_t bits) {
403   assert(bits <= N);
404   assert(p->v[WORDS_PER_POLY-1] >> BITS_IN_LAST_WORD == 0);
405 
406   // Constant-time rotation is implemented by calculating the rotations of
407   // powers-of-two bits and throwing away the unneeded values. 2^9 (i.e. 512) is
408   // the largest power-of-two shift that we need to consider because 2^10 > N.
409 #define HRSS_POLY2_MAX_SHIFT 9
410   size_t shift = HRSS_POLY2_MAX_SHIFT;
411   OPENSSL_STATIC_ASSERT((1 << (HRSS_POLY2_MAX_SHIFT + 1)) > N,
412                         "maximum shift is too small");
413   OPENSSL_STATIC_ASSERT((1 << HRSS_POLY2_MAX_SHIFT) <= N,
414                         "maximum shift is too large");
415   struct poly2 shifted;
416 
417   for (; (UINT64_C(1) << shift) >= BITS_PER_WORD; shift--) {
418     poly2_rotr_words(&shifted, p, UINT64_C(1) << shift);
419     poly2_cmov(p, &shifted, ~((1 & (bits >> shift)) - 1));
420   }
421 
422   for (; shift < HRSS_POLY2_MAX_SHIFT; shift--) {
423     poly2_rotr_bits(&shifted, p, UINT64_C(1) << shift);
424     poly2_cmov(p, &shifted, ~((1 & (bits >> shift)) - 1));
425   }
426 #undef HRSS_POLY2_MAX_SHIFT
427 }
428 
429 // poly2_cswap exchanges the values of |a| and |b| if |swap| is all ones.
poly2_cswap(struct poly2 * a,struct poly2 * b,crypto_word_t swap)430 static void poly2_cswap(struct poly2 *a, struct poly2 *b, crypto_word_t swap) {
431   for (size_t i = 0; i < WORDS_PER_POLY; i++) {
432     const crypto_word_t sum = swap & (a->v[i] ^ b->v[i]);
433     a->v[i] ^= sum;
434     b->v[i] ^= sum;
435   }
436 }
437 
438 // poly2_fmadd sets |out| to |out| + |in| * m, where m is either
439 // |CONSTTIME_TRUE_W| or |CONSTTIME_FALSE_W|.
poly2_fmadd(struct poly2 * out,const struct poly2 * in,crypto_word_t m)440 static void poly2_fmadd(struct poly2 *out, const struct poly2 *in,
441                         crypto_word_t m) {
442   for (size_t i = 0; i < WORDS_PER_POLY; i++) {
443     out->v[i] ^= in->v[i] & m;
444   }
445 }
446 
447 // poly2_lshift1 left-shifts |p| by one bit.
poly2_lshift1(struct poly2 * p)448 static void poly2_lshift1(struct poly2 *p) {
449   crypto_word_t carry = 0;
450   for (size_t i = 0; i < WORDS_PER_POLY; i++) {
451     const crypto_word_t next_carry = p->v[i] >> (BITS_PER_WORD - 1);
452     p->v[i] <<= 1;
453     p->v[i] |= carry;
454     carry = next_carry;
455   }
456 }
457 
458 // poly2_rshift1 right-shifts |p| by one bit.
poly2_rshift1(struct poly2 * p)459 static void poly2_rshift1(struct poly2 *p) {
460   crypto_word_t carry = 0;
461   for (size_t i = WORDS_PER_POLY - 1; i < WORDS_PER_POLY; i--) {
462     const crypto_word_t next_carry = p->v[i] & 1;
463     p->v[i] >>= 1;
464     p->v[i] |= carry << (BITS_PER_WORD - 1);
465     carry = next_carry;
466   }
467 }
468 
469 // poly2_clear_top_bits clears the bits in the final word that are only for
470 // alignment.
poly2_clear_top_bits(struct poly2 * p)471 static void poly2_clear_top_bits(struct poly2 *p) {
472   p->v[WORDS_PER_POLY - 1] &= (UINT64_C(1) << BITS_IN_LAST_WORD) - 1;
473 }
474 
475 // poly2_top_bits_are_clear returns one iff the extra bits in the final words of
476 // |p| are zero.
poly2_top_bits_are_clear(const struct poly2 * p)477 static int poly2_top_bits_are_clear(const struct poly2 *p) {
478   return (p->v[WORDS_PER_POLY - 1] &
479           ~((UINT64_C(1) << BITS_IN_LAST_WORD) - 1)) == 0;
480 }
481 
482 // Ternary polynomials.
483 
484 // poly3 represents a degree-N polynomial over GF(3). Each coefficient is
485 // bitsliced across the |s| and |a| arrays, like this:
486 //
487 //   s  |  a  | value
488 //  -----------------
489 //   0  |  0  | 0
490 //   0  |  1  | 1
491 //   1  |  1  | -1 (aka 2)
492 //   1  |  0  | <invalid>
493 //
494 // ('s' is for sign, and 'a' is the absolute value.)
495 //
496 // Once bitsliced as such, the following circuits can be used to implement
497 // addition and multiplication mod 3:
498 //
499 //   (s3, a3) = (s1, a1) × (s2, a2)
500 //   a3 = a1 ∧ a2
501 //   s3 = (s1 ⊕ s2) ∧ a3
502 //
503 //   (s3, a3) = (s1, a1) + (s2, a2)
504 //   t = s1 ⊕ a2
505 //   s3 = t ∧ (s2 ⊕ a1)
506 //   a3 = (a1 ⊕ a2) ∨ (t ⊕ s2)
507 //
508 //   (s3, a3) = (s1, a1) - (s2, a2)
509 //   t = a1 ⊕ a2
510 //   s3 = (s1 ⊕ a2) ∧ (t ⊕ s2)
511 //   a3 = t ∨ (s1 ⊕ s2)
512 //
513 // Negating a value just involves XORing s by a.
514 //
515 // struct poly3 {
516 //   struct poly2 s, a;
517 // };
518 
poly3_print(const struct poly3 * in)519 OPENSSL_UNUSED static void poly3_print(const struct poly3 *in) {
520   struct poly3 p;
521   OPENSSL_memcpy(&p, in, sizeof(p));
522   p.s.v[WORDS_PER_POLY - 1] &= ((crypto_word_t)1 << BITS_IN_LAST_WORD) - 1;
523   p.a.v[WORDS_PER_POLY - 1] &= ((crypto_word_t)1 << BITS_IN_LAST_WORD) - 1;
524 
525   printf("{[");
526   for (unsigned i = 0; i < WORDS_PER_POLY; i++) {
527     if (i) {
528       printf(" ");
529     }
530     printf(BN_HEX_FMT2, p.s.v[i]);
531   }
532   printf("] [");
533   for (unsigned i = 0; i < WORDS_PER_POLY; i++) {
534     if (i) {
535       printf(" ");
536     }
537     printf(BN_HEX_FMT2, p.a.v[i]);
538   }
539   printf("]}\n");
540 }
541 
poly3_zero(struct poly3 * p)542 static void poly3_zero(struct poly3 *p) {
543   poly2_zero(&p->s);
544   poly2_zero(&p->a);
545 }
546 
547 // poly3_word_mul sets (|out_s|, |out_a) to (|s1|, |a1|) × (|s2|, |a2|).
poly3_word_mul(crypto_word_t * out_s,crypto_word_t * out_a,const crypto_word_t s1,const crypto_word_t a1,const crypto_word_t s2,const crypto_word_t a2)548 static void poly3_word_mul(crypto_word_t *out_s, crypto_word_t *out_a,
549                            const crypto_word_t s1, const crypto_word_t a1,
550                            const crypto_word_t s2, const crypto_word_t a2) {
551   *out_a = a1 & a2;
552   *out_s = (s1 ^ s2) & *out_a;
553 }
554 
555 // poly3_word_add sets (|out_s|, |out_a|) to (|s1|, |a1|) + (|s2|, |a2|).
poly3_word_add(crypto_word_t * out_s,crypto_word_t * out_a,const crypto_word_t s1,const crypto_word_t a1,const crypto_word_t s2,const crypto_word_t a2)556 static void poly3_word_add(crypto_word_t *out_s, crypto_word_t *out_a,
557                            const crypto_word_t s1, const crypto_word_t a1,
558                            const crypto_word_t s2, const crypto_word_t a2) {
559   const crypto_word_t t = s1 ^ a2;
560   *out_s = t & (s2 ^ a1);
561   *out_a = (a1 ^ a2) | (t ^ s2);
562 }
563 
564 // poly3_word_sub sets (|out_s|, |out_a|) to (|s1|, |a1|) - (|s2|, |a2|).
poly3_word_sub(crypto_word_t * out_s,crypto_word_t * out_a,const crypto_word_t s1,const crypto_word_t a1,const crypto_word_t s2,const crypto_word_t a2)565 static void poly3_word_sub(crypto_word_t *out_s, crypto_word_t *out_a,
566                            const crypto_word_t s1, const crypto_word_t a1,
567                            const crypto_word_t s2, const crypto_word_t a2) {
568   const crypto_word_t t = a1 ^ a2;
569   *out_s = (s1 ^ a2) & (t ^ s2);
570   *out_a = t | (s1 ^ s2);
571 }
572 
573 // lsb_to_all replicates the least-significant bit of |v| to all bits of the
574 // word. This is used in bit-slicing operations to make a vector from a fixed
575 // value.
lsb_to_all(crypto_word_t v)576 static crypto_word_t lsb_to_all(crypto_word_t v) { return 0u - (v & 1); }
577 
578 // poly3_mul_const sets |p| to |p|×m, where m = (ms, ma).
poly3_mul_const(struct poly3 * p,crypto_word_t ms,crypto_word_t ma)579 static void poly3_mul_const(struct poly3 *p, crypto_word_t ms,
580                             crypto_word_t ma) {
581   ms = lsb_to_all(ms);
582   ma = lsb_to_all(ma);
583 
584   for (size_t i = 0; i < WORDS_PER_POLY; i++) {
585     poly3_word_mul(&p->s.v[i], &p->a.v[i], p->s.v[i], p->a.v[i], ms, ma);
586   }
587 }
588 
589 // poly3_rotr_consttime right-rotates |p| by |bits| in constant-time.
poly3_rotr_consttime(struct poly3 * p,size_t bits)590 static void poly3_rotr_consttime(struct poly3 *p, size_t bits) {
591   assert(bits <= N);
592   HRSS_poly2_rotr_consttime(&p->s, bits);
593   HRSS_poly2_rotr_consttime(&p->a, bits);
594 }
595 
596 // poly3_fmadd sets |out| to |out| - |in|×m, where m is (ms, ma).
poly3_fmsub(struct poly3 * RESTRICT out,const struct poly3 * RESTRICT in,crypto_word_t ms,crypto_word_t ma)597 static void poly3_fmsub(struct poly3 *RESTRICT out,
598                         const struct poly3 *RESTRICT in, crypto_word_t ms,
599                         crypto_word_t ma) {
600   crypto_word_t product_s, product_a;
601   for (size_t i = 0; i < WORDS_PER_POLY; i++) {
602     poly3_word_mul(&product_s, &product_a, in->s.v[i], in->a.v[i], ms, ma);
603     poly3_word_sub(&out->s.v[i], &out->a.v[i], out->s.v[i], out->a.v[i],
604                    product_s, product_a);
605   }
606 }
607 
608 // final_bit_to_all replicates the bit in the final position of the last word to
609 // all the bits in the word.
final_bit_to_all(crypto_word_t v)610 static crypto_word_t final_bit_to_all(crypto_word_t v) {
611   return lsb_to_all(v >> (BITS_IN_LAST_WORD - 1));
612 }
613 
614 // poly3_top_bits_are_clear returns one iff the extra bits in the final words of
615 // |p| are zero.
poly3_top_bits_are_clear(const struct poly3 * p)616 OPENSSL_UNUSED static int poly3_top_bits_are_clear(const struct poly3 *p) {
617   return poly2_top_bits_are_clear(&p->s) && poly2_top_bits_are_clear(&p->a);
618 }
619 
620 // poly3_mod_phiN reduces |p| by Φ(N).
poly3_mod_phiN(struct poly3 * p)621 static void poly3_mod_phiN(struct poly3 *p) {
622   // In order to reduce by Φ(N) we subtract by the value of the greatest
623   // coefficient.
624   const crypto_word_t factor_s = final_bit_to_all(p->s.v[WORDS_PER_POLY - 1]);
625   const crypto_word_t factor_a = final_bit_to_all(p->a.v[WORDS_PER_POLY - 1]);
626 
627   for (size_t i = 0; i < WORDS_PER_POLY; i++) {
628     poly3_word_sub(&p->s.v[i], &p->a.v[i], p->s.v[i], p->a.v[i], factor_s,
629                    factor_a);
630   }
631 
632   poly2_clear_top_bits(&p->s);
633   poly2_clear_top_bits(&p->a);
634 }
635 
// poly3_cswap exchanges |a| and |b| when |swap| is all ones and leaves them
// unchanged when |swap| is zero, by conditionally swapping each bitsliced
// half.
static void poly3_cswap(struct poly3 *a, struct poly3 *b, crypto_word_t swap) {
  struct poly2 *const a_sign = &a->s;
  struct poly2 *const b_sign = &b->s;
  poly2_cswap(a_sign, b_sign, swap);

  struct poly2 *const a_magnitude = &a->a;
  struct poly2 *const b_magnitude = &b->a;
  poly2_cswap(a_magnitude, b_magnitude, swap);
}
640 
poly3_lshift1(struct poly3 * p)641 static void poly3_lshift1(struct poly3 *p) {
642   poly2_lshift1(&p->s);
643   poly2_lshift1(&p->a);
644 }
645 
poly3_rshift1(struct poly3 * p)646 static void poly3_rshift1(struct poly3 *p) {
647   poly2_rshift1(&p->s);
648   poly2_rshift1(&p->a);
649 }
650 
// poly3_span represents a pointer into a poly3: a pair of word pointers, one
// into the |s| words and one into the |a| words, that are indexed together.
// Code in this file initialises spans positionally as {s, a}, so the field
// order matters.
struct poly3_span {
  // Points into the sign words.
  crypto_word_t *s;
  // Points into the absolute-value words.
  crypto_word_t *a;
};
656 
657 // poly3_span_add adds |n| words of values from |a| and |b| and writes the
658 // result to |out|.
poly3_span_add(const struct poly3_span * out,const struct poly3_span * a,const struct poly3_span * b,size_t n)659 static void poly3_span_add(const struct poly3_span *out,
660                            const struct poly3_span *a,
661                            const struct poly3_span *b, size_t n) {
662   for (size_t i = 0; i < n; i++) {
663     poly3_word_add(&out->s[i], &out->a[i], a->s[i], a->a[i], b->s[i], b->a[i]);
664   }
665 }
666 
667 // poly3_span_sub subtracts |n| words of |b| from |n| words of |a|.
poly3_span_sub(const struct poly3_span * a,const struct poly3_span * b,size_t n)668 static void poly3_span_sub(const struct poly3_span *a,
669                            const struct poly3_span *b, size_t n) {
670   for (size_t i = 0; i < n; i++) {
671     poly3_word_sub(&a->s[i], &a->a[i], a->s[i], a->a[i], b->s[i], b->a[i]);
672   }
673 }
674 
675 // poly3_mul_aux is a recursive function that multiplies |n| words from |a| and
676 // |b| and writes 2×|n| words to |out|. Each call uses 2*ceil(n/2) elements of
677 // |scratch| and the function recurses, except if |n| == 1, when |scratch| isn't
678 // used and the recursion stops. For |n| in {11, 22}, the transitive total
679 // amount of |scratch| needed happens to be 2n+2.
poly3_mul_aux(const struct poly3_span * out,const struct poly3_span * scratch,const struct poly3_span * a,const struct poly3_span * b,size_t n)680 static void poly3_mul_aux(const struct poly3_span *out,
681                           const struct poly3_span *scratch,
682                           const struct poly3_span *a,
683                           const struct poly3_span *b, size_t n) {
684   if (n == 1) {
685     crypto_word_t r_s_low = 0, r_s_high = 0, r_a_low = 0, r_a_high = 0;
686     crypto_word_t b_s = b->s[0], b_a = b->a[0];
687     const crypto_word_t a_s = a->s[0], a_a = a->a[0];
688 
689     for (size_t i = 0; i < BITS_PER_WORD; i++) {
690       // Multiply (s, a) by the next value from (b_s, b_a).
691       crypto_word_t m_s, m_a;
692       poly3_word_mul(&m_s, &m_a, a_s, a_a, lsb_to_all(b_s), lsb_to_all(b_a));
693       b_s >>= 1;
694       b_a >>= 1;
695 
696       if (i == 0) {
697         // Special case otherwise the code tries to shift by BITS_PER_WORD
698         // below, which is undefined.
699         r_s_low = m_s;
700         r_a_low = m_a;
701         continue;
702       }
703 
704       // Shift the multiplication result to the correct position.
705       const crypto_word_t m_s_low = m_s << i;
706       const crypto_word_t m_s_high = m_s >> (BITS_PER_WORD - i);
707       const crypto_word_t m_a_low = m_a << i;
708       const crypto_word_t m_a_high = m_a >> (BITS_PER_WORD - i);
709 
710       // Add into the result.
711       poly3_word_add(&r_s_low, &r_a_low, r_s_low, r_a_low, m_s_low, m_a_low);
712       poly3_word_add(&r_s_high, &r_a_high, r_s_high, r_a_high, m_s_high,
713                      m_a_high);
714     }
715 
716     out->s[0] = r_s_low;
717     out->s[1] = r_s_high;
718     out->a[0] = r_a_low;
719     out->a[1] = r_a_high;
720     return;
721   }
722 
723   // Karatsuba multiplication.
724   // https://en.wikipedia.org/wiki/Karatsuba_algorithm
725 
726   // When |n| is odd, the two "halves" will have different lengths. The first
727   // is always the smaller.
728   const size_t low_len = n / 2;
729   const size_t high_len = n - low_len;
730   const struct poly3_span a_high = {&a->s[low_len], &a->a[low_len]};
731   const struct poly3_span b_high = {&b->s[low_len], &b->a[low_len]};
732 
733   // Store a_1 + a_0 in the first half of |out| and b_1 + b_0 in the second
734   // half.
735   const struct poly3_span a_cross_sum = *out;
736   const struct poly3_span b_cross_sum = {&out->s[high_len], &out->a[high_len]};
737   poly3_span_add(&a_cross_sum, a, &a_high, low_len);
738   poly3_span_add(&b_cross_sum, b, &b_high, low_len);
739   if (high_len != low_len) {
740     a_cross_sum.s[low_len] = a_high.s[low_len];
741     a_cross_sum.a[low_len] = a_high.a[low_len];
742     b_cross_sum.s[low_len] = b_high.s[low_len];
743     b_cross_sum.a[low_len] = b_high.a[low_len];
744   }
745 
746   const struct poly3_span child_scratch = {&scratch->s[2 * high_len],
747                                            &scratch->a[2 * high_len]};
748   const struct poly3_span out_mid = {&out->s[low_len], &out->a[low_len]};
749   const struct poly3_span out_high = {&out->s[2 * low_len],
750                                       &out->a[2 * low_len]};
751 
752   // Calculate (a_1 + a_0) × (b_1 + b_0) and write to scratch buffer.
753   poly3_mul_aux(scratch, &child_scratch, &a_cross_sum, &b_cross_sum, high_len);
754   // Calculate a_1 × b_1.
755   poly3_mul_aux(&out_high, &child_scratch, &a_high, &b_high, high_len);
756   // Calculate a_0 × b_0.
757   poly3_mul_aux(out, &child_scratch, a, b, low_len);
758 
759   // Subtract those last two products from the first.
760   poly3_span_sub(scratch, out, low_len * 2);
761   poly3_span_sub(scratch, &out_high, high_len * 2);
762 
763   // Add the middle product into the output.
764   poly3_span_add(&out_mid, &out_mid, scratch, high_len * 2);
765 }
766 
767 // HRSS_poly3_mul sets |*out| to |x|×|y| mod Φ(N).
HRSS_poly3_mul(struct poly3 * out,const struct poly3 * x,const struct poly3 * y)768 void HRSS_poly3_mul(struct poly3 *out, const struct poly3 *x,
769                     const struct poly3 *y) {
770   crypto_word_t prod_s[WORDS_PER_POLY * 2];
771   crypto_word_t prod_a[WORDS_PER_POLY * 2];
772   crypto_word_t scratch_s[WORDS_PER_POLY * 2 + 2];
773   crypto_word_t scratch_a[WORDS_PER_POLY * 2 + 2];
774   const struct poly3_span prod_span = {prod_s, prod_a};
775   const struct poly3_span scratch_span = {scratch_s, scratch_a};
776   const struct poly3_span x_span = {(crypto_word_t *)x->s.v,
777                                     (crypto_word_t *)x->a.v};
778   const struct poly3_span y_span = {(crypto_word_t *)y->s.v,
779                                     (crypto_word_t *)y->a.v};
780 
781   poly3_mul_aux(&prod_span, &scratch_span, &x_span, &y_span, WORDS_PER_POLY);
782 
783   // |prod| needs to be reduced mod (��^n - 1), which just involves adding the
784   // upper-half to the lower-half. However, N is 701, which isn't a multiple of
785   // BITS_PER_WORD, so the upper-half vectors all have to be shifted before
786   // being added to the lower-half.
787   for (size_t i = 0; i < WORDS_PER_POLY; i++) {
788     crypto_word_t v_s = prod_s[WORDS_PER_POLY + i - 1] >> BITS_IN_LAST_WORD;
789     v_s |= prod_s[WORDS_PER_POLY + i] << (BITS_PER_WORD - BITS_IN_LAST_WORD);
790     crypto_word_t v_a = prod_a[WORDS_PER_POLY + i - 1] >> BITS_IN_LAST_WORD;
791     v_a |= prod_a[WORDS_PER_POLY + i] << (BITS_PER_WORD - BITS_IN_LAST_WORD);
792 
793     poly3_word_add(&out->s.v[i], &out->a.v[i], prod_s[i], prod_a[i], v_s, v_a);
794   }
795 
796   poly3_mod_phiN(out);
797 }
798 
799 #if defined(HRSS_HAVE_VECTOR_UNIT) && !defined(OPENSSL_AARCH64)
800 
801 // poly3_vec_cswap swaps (|a_s|, |a_a|) and (|b_s|, |b_a|) if |swap| is
802 // |0xff..ff|. Otherwise, |swap| must be zero.
poly3_vec_cswap(vec_t a_s[6],vec_t a_a[6],vec_t b_s[6],vec_t b_a[6],const vec_t swap)803 static inline void poly3_vec_cswap(vec_t a_s[6], vec_t a_a[6], vec_t b_s[6],
804                                    vec_t b_a[6], const vec_t swap) {
805   for (int i = 0; i < 6; i++) {
806     const vec_t sum_s = swap & (a_s[i] ^ b_s[i]);
807     a_s[i] ^= sum_s;
808     b_s[i] ^= sum_s;
809 
810     const vec_t sum_a = swap & (a_a[i] ^ b_a[i]);
811     a_a[i] ^= sum_a;
812     b_a[i] ^= sum_a;
813   }
814 }
815 
816 // poly3_vec_fmsub subtracts (|ms|, |ma|) × (|b_s|, |b_a|) from (|a_s|, |a_a|).
poly3_vec_fmsub(vec_t a_s[6],vec_t a_a[6],vec_t b_s[6],vec_t b_a[6],const vec_t ms,const vec_t ma)817 static inline void poly3_vec_fmsub(vec_t a_s[6], vec_t a_a[6], vec_t b_s[6],
818                                    vec_t b_a[6], const vec_t ms,
819                                    const vec_t ma) {
820   for (int i = 0; i < 6; i++) {
821     // See the bitslice formula, above.
822     const vec_t s = b_s[i];
823     const vec_t a = b_a[i];
824     const vec_t product_a = a & ma;
825     const vec_t product_s = (s ^ ms) & product_a;
826 
827     const vec_t out_s = a_s[i];
828     const vec_t out_a = a_a[i];
829     const vec_t t = out_a ^ product_a;
830     a_s[i] = (out_s ^ product_a) & (t ^ product_s);
831     a_a[i] = t | (out_s ^ product_s);
832   }
833 }
834 
// poly3_invert_vec sets |*out| to |in|^-1, i.e. such that |out|×|in| == 1 mod
// Φ(N). It is the vector-register counterpart of the scalar loop in
// |HRSS_poly3_invert| and is only compiled when a vector unit is available.
static void poly3_invert_vec(struct poly3 *out, const struct poly3 *in) {
  // See the comment in |HRSS_poly3_invert| about this algorithm. In addition to
  // the changes described there, this implementation attempts to use vector
  // registers to speed up the computation. Even non-poly3 variables are held in
  // vectors where possible to minimise the amount of data movement between
  // the vector and general-purpose registers.

  // Each polynomial is held as a pair of six-vector arrays (the |s| and |a|
  // halves of the poly3 representation).
  vec_t b_s[6], b_a[6], c_s[6], c_a[6], f_s[6], f_a[6], g_s[6], g_a[6];
  const vec_t kZero = {0};
  const vec_t kOne = {1};
  // Byte patterns used to initialise vectors via memcpy: a single leading 0x01
  // byte, and 61 set bits (seven 0xff bytes followed by 0x1f).
  static const uint8_t kOneBytes[sizeof(vec_t)] = {1};
  static const uint8_t kBottomSixtyOne[sizeof(vec_t)] = {
      0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x1f};

  // b starts out with only the first byte of |b_a| set (the scalar version
  // describes this as "set b to one"); c starts at zero.
  memset(b_s, 0, sizeof(b_s));
  memcpy(b_a, kOneBytes, sizeof(kOneBytes));
  memset(&b_a[1], 0, 5 * sizeof(vec_t));

  memset(c_s, 0, sizeof(c_s));
  memset(c_a, 0, sizeof(c_a));

  // f is a copy of the input. Only WORDS_PER_POLY words are copied, so the
  // final vector is cleared first in case the copy does not cover all of it.
  f_s[5] = kZero;
  memcpy(f_s, in->s.v, WORDS_PER_POLY * sizeof(crypto_word_t));
  f_a[5] = kZero;
  memcpy(f_a, in->a.v, WORDS_PER_POLY * sizeof(crypto_word_t));

  // Set g to all ones: five full vectors plus 61 bits in the sixth.
  memset(g_s, 0, sizeof(g_s));
  memset(g_a, 0xff, 5 * sizeof(vec_t));
  memcpy(&g_a[5], kBottomSixtyOne, sizeof(kBottomSixtyOne));

  // Loop state, all kept in vectors. |k| plays the role of the iteration
  // counter (starting at one here, where the scalar version uses |i| and
  // increments the rotation afterwards). |f0s|/|f0a| remember the low word of
  // f from the most recent iteration in which f's constant term was non-zero.
  vec_t deg_f = {N - 1}, deg_g = {N - 1}, rotation = kZero;
  vec_t k = kOne;
  vec_t f0s = {0}, f0a = {0};
  vec_t still_going;
  memset(&still_going, 0xff, sizeof(still_going));

  // The iteration count is fixed so that the running time does not depend on
  // the input.
  for (unsigned i = 0; i < 2 * (N - 1) - 1; i++) {
    // (|s_s|, |s_a|) is the multiplier for the fmsub steps below, derived from
    // the lowest coefficients of f and g and masked by |still_going|.
    // NOTE(review): |vec_broadcast_bit| presumably spreads the lowest bit over
    // the whole vector, mirroring |lsb_to_all| in the scalar code — confirm.
    const vec_t s_a = vec_broadcast_bit(still_going & (f_a[0] & g_a[0]));
    const vec_t s_s =
        vec_broadcast_bit(still_going & ((f_s[0] ^ g_s[0]) & s_a));
    // NOTE(review): |vec_broadcast_bit15(deg_f - deg_g)| presumably yields an
    // all-ones mask when the difference is negative, i.e. deg_f < deg_g, which
    // is the comparison the scalar code performs — confirm.
    const vec_t should_swap =
        (s_s | s_a) & vec_broadcast_bit15(deg_f - deg_g);

    poly3_vec_cswap(f_s, f_a, g_s, g_a, should_swap);
    poly3_vec_fmsub(f_s, f_a, g_s, g_a, s_s, s_a);
    poly3_vec_rshift1(f_s, f_a);

    // b and c receive the same swap and multiply-subtract as f and g, but c is
    // shifted in the opposite direction.
    poly3_vec_cswap(b_s, b_a, c_s, c_a, should_swap);
    poly3_vec_fmsub(b_s, b_a, c_s, c_a, s_s, s_a);
    poly3_vec_lshift1(c_s, c_a);

    // Swap the degree counters along with the polynomials.
    const vec_t deg_sum = should_swap & (deg_f ^ deg_g);
    deg_f ^= deg_sum;
    deg_g ^= deg_sum;

    deg_f -= kOne;
    // Clears |still_going| once |deg_f - 1| goes negative, i.e. deg_f has
    // reached zero (the scalar code tests deg_f for zero directly).
    still_going &= ~vec_broadcast_bit15(deg_f - kOne);

    const vec_t f0_is_nonzero = vec_broadcast_bit(f_s[0] | f_a[0]);
    // |f0_is_nonzero| implies |still_going|.
    // Masked select: rotation = f0_is_nonzero ? k : rotation.
    rotation ^= f0_is_nonzero & (k ^ rotation);
    k += kOne;

    // Masked selects that latch f's low vector whenever its constant term is
    // non-zero.
    const vec_t f0s_sum = f0_is_nonzero & (f_s[0] ^ f0s);
    f0s ^= f0s_sum;
    const vec_t f0a_sum = f0_is_nonzero & (f_a[0] ^ f0a);
    f0a ^= f0a_sum;
  }

  // Bring the rotation count back into range by subtracting N when it exceeds
  // N (assuming |constant_time_lt_w| returns an all-ones mask when its first
  // argument is the smaller one).
  crypto_word_t rotation_word = vec_get_word(rotation, 0);
  rotation_word -= N & constant_time_lt_w(N, rotation_word);
  memcpy(out->s.v, b_s, WORDS_PER_POLY * sizeof(crypto_word_t));
  memcpy(out->a.v, b_a, WORDS_PER_POLY * sizeof(crypto_word_t));
  assert(poly3_top_bits_are_clear(out));
  // Finish as the scalar version does: rotate b, scale by the latched
  // constant term of f, and reduce mod Φ(N).
  poly3_rotr_consttime(out, rotation_word);
  poly3_mul_const(out, vec_get_word(f0s, 0), vec_get_word(f0a, 0));
  poly3_mod_phiN(out);
}
916 
917 #endif  // HRSS_HAVE_VECTOR_UNIT
918 
// HRSS_poly3_invert sets |*out| to |in|^-1, i.e. such that |out|×|in| == 1 mod
// Φ(N). When the build has a usable vector unit (and is not AArch64) and
// |vec_capable| reports support at runtime, the work is delegated to
// |poly3_invert_vec|; otherwise the scalar loop below is used.
void HRSS_poly3_invert(struct poly3 *out, const struct poly3 *in) {
  // The vector version of this function seems slightly slower on AArch64, but
  // is useful on ARMv7 and x86-64.
#if defined(HRSS_HAVE_VECTOR_UNIT) && !defined(OPENSSL_AARCH64)
  if (vec_capable()) {
    poly3_invert_vec(out, in);
    return;
  }
#endif

  // This algorithm mostly follows algorithm 10 in the paper. Some changes:
  //   1) k should start at zero, not one. In the code below k is omitted and
  //      the loop counter, |i|, is used instead.
  //   2) The rotation count is conditionally updated to handle trailing zero
  //      coefficients.
  // The best explanation for why it works is in the "Why it works" section of
  // [NTRUTN14].

  struct poly3 c, f, g;
  OPENSSL_memcpy(&f, in, sizeof(f));

  // Set g to all ones. The final word is shifted down so that only its low
  // BITS_IN_LAST_WORD bits remain set.
  OPENSSL_memset(&g.s, 0, sizeof(struct poly2));
  OPENSSL_memset(&g.a, 0xff, sizeof(struct poly2));
  g.a.v[WORDS_PER_POLY - 1] >>= BITS_PER_WORD - BITS_IN_LAST_WORD;

  // b is accumulated directly in the caller's output buffer.
  struct poly3 *b = out;
  poly3_zero(b);
  poly3_zero(&c);
  // Set b to one.
  b->a.v[0] = 1;

  // |f0s|/|f0a| remember the low word of f from the most recent iteration in
  // which f's constant term was non-zero; |rotation| remembers that
  // iteration's index.
  crypto_word_t deg_f = N - 1, deg_g = N - 1, rotation = 0;
  crypto_word_t f0s = 0, f0a = 0;
  crypto_word_t still_going = CONSTTIME_TRUE_W;

  // The iteration count is fixed so that the running time does not depend on
  // the input; all data-dependent decisions are made with masks.
  for (unsigned i = 0; i < 2 * (N - 1) - 1; i++) {
    // (|s_s|, |s_a|) is the multiplier for the fmsub steps below, derived from
    // the lowest coefficients of f and g and masked by |still_going|.
    const crypto_word_t s_a = lsb_to_all(
        still_going & (f.a.v[0] & g.a.v[0]));
    const crypto_word_t s_s = lsb_to_all(
        still_going & ((f.s.v[0] ^ g.s.v[0]) & s_a));
    // Swap only when the multiplier is non-zero and deg_f < deg_g.
    const crypto_word_t should_swap =
        (s_s | s_a) & constant_time_lt_w(deg_f, deg_g);

    poly3_cswap(&f, &g, should_swap);
    poly3_cswap(b, &c, should_swap);

    // Swap the degree counters along with the polynomials.
    const crypto_word_t deg_sum = should_swap & (deg_f ^ deg_g);
    deg_f ^= deg_sum;
    deg_g ^= deg_sum;
    assert(deg_g >= 1);

    // b and c receive the same multiply-subtract as f and g; f is then
    // shifted right by one while c is shifted left by one.
    poly3_fmsub(&f, &g, s_s, s_a);
    poly3_fmsub(b, &c, s_s, s_a);
    poly3_rshift1(&f);
    poly3_lshift1(&c);

    deg_f--;
    const crypto_word_t f0_is_nonzero =
        lsb_to_all(f.s.v[0]) | lsb_to_all(f.a.v[0]);
    // |f0_is_nonzero| implies |still_going|.
    assert(!(f0_is_nonzero && !still_going));
    // Once deg_f reaches zero, |still_going| is cleared, which forces the
    // multiplier to zero in all later iterations.
    still_going &= ~constant_time_is_zero_w(deg_f);

    rotation = constant_time_select_w(f0_is_nonzero, i, rotation);
    f0s = constant_time_select_w(f0_is_nonzero, f.s.v[0], f0s);
    f0a = constant_time_select_w(f0_is_nonzero, f.a.v[0], f0a);
  }

  // |rotation| recorded a zero-based loop index; convert it to a count and
  // bring it back into range by subtracting N when it exceeds N.
  rotation++;
  rotation -= N & constant_time_lt_w(N, rotation);
  assert(poly3_top_bits_are_clear(out));
  // Rotate b, scale by the latched constant term of f, and reduce mod Φ(N).
  poly3_rotr_consttime(out, rotation);
  poly3_mul_const(out, f0s, f0a);
  poly3_mod_phiN(out);
}
997 
998 // Polynomials in Q.
999 
// Coefficients are reduced mod Q. (Q is clearly not prime, therefore the
// coefficients do not form a field.)
#define Q 8192

// COEFFICIENTS_PER_VEC is the number of 16-bit coefficients that fit in a
// single |vec_t|.
//
// VECS_PER_POLY is the number of 128-bit vectors needed to represent a
// polynomial, i.e. N divided by COEFFICIENTS_PER_VEC, rounded up.
#define COEFFICIENTS_PER_VEC (sizeof(vec_t) / sizeof(uint16_t))
#define VECS_PER_POLY ((N + COEFFICIENTS_PER_VEC - 1) / COEFFICIENTS_PER_VEC)
1008 
// poly represents a polynomial with coefficients mod Q. Note that, while Q is a
// power of two, this does not operate in GF(Q). That would be a binary field
// but this is simply mod Q. Thus the coefficients are not a field.
//
// Coefficients are ordered little-endian, thus the coefficient of x^0 is the
// first element of the array.
//
// The array holds N + 3 elements: N coefficients followed by three padding
// elements, which the multiplication routines in this file zero.
struct poly {
#if defined(HRSS_HAVE_VECTOR_UNIT)
  union {
    // N + 3 = 704, which is a multiple of 64 and thus aligns things, esp for
    // the vector code.
    uint16_t v[N + 3];
    // The same storage viewed as whole vectors, for the vector code.
    vec_t vectors[VECS_PER_POLY];
  };
#else
  // Even if !HRSS_HAVE_VECTOR_UNIT, external assembly may be called that
  // requires alignment.
  alignas(16) uint16_t v[N + 3];
#endif
};
1029 
poly_print(const struct poly * p)1030 OPENSSL_UNUSED static void poly_print(const struct poly *p) {
1031   printf("[");
1032   for (unsigned i = 0; i < N; i++) {
1033     if (i) {
1034       printf(" ");
1035     }
1036     printf("%d", p->v[i]);
1037   }
1038   printf("]\n");
1039 }
1040 
1041 #if defined(HRSS_HAVE_VECTOR_UNIT)
1042 
// poly_mul_vec_aux is a recursive function that multiplies |n| words from |a|
// and |b| and writes 2×|n| words to |out|. Each call uses 2*ceil(n/2) elements
// of |scratch| and the function recurses, except if |n| is 2 or 3, when
// |scratch| isn't used and the recursion stops. If |n| == |VECS_PER_POLY| then
// |scratch| needs 172 elements.
//
// |n| must be at least two: a value of one would split into halves of zero and
// one and recurse on one again without terminating.
static void poly_mul_vec_aux(vec_t *restrict out, vec_t *restrict scratch,
                             const vec_t *restrict a, const vec_t *restrict b,
                             const size_t n) {
  // In [HRSS], the technique they used for polynomial multiplication is
  // described: they start with Toom-4 at the top level and then two layers of
  // Karatsuba. Karatsuba is a specific instance of the general Toom–Cook
  // decomposition, which splits an input n-ways and produces 2n-1
  // multiplications of those parts. So, starting with 704 coefficients (rounded
  // up from 701 to have more factors of two), Toom-4 gives seven
  // multiplications of degree-174 polynomials. Each round of Karatsuba (which
  // is Toom-2) increases the number of multiplications by a factor of three
  // while halving the size of the values being multiplied. So two rounds gives
  // 63 multiplications of degree-44 polynomials. Then they (I think) form
  // vectors by gathering all 63 coefficients of each power together, for each
  // input, and doing more rounds of Karatsuba on the vectors until they bottom-
  // out somewhere with schoolbook multiplication.
  //
  // I tried something like that for NEON. NEON vectors are 128 bits so hold
  // eight coefficients. I wrote a function that did Karatsuba on eight
  // multiplications at the same time, using such vectors, and a Go script that
  // decomposed from degree-704, with Karatsuba in non-transposed form, until it
  // reached multiplications of degree-44. It batched up those 81
  // multiplications into lots of eight with a single one left over (which was
  // handled directly).
  //
  // It worked, but it was significantly slower than the dumb algorithm used
  // below. Potentially that was because I misunderstood how [HRSS] did it, or
  // because Clang is bad at generating good code from NEON intrinsics on ARMv7.
  // (Which is true: the code generated by Clang for the below is pretty crap.)
  //
  // This algorithm is much simpler. It just does Karatsuba decomposition all
  // the way down and never transposes. When it gets down to degree-16 or
  // degree-24 values, they are multiplied using schoolbook multiplication and
  // vector intrinsics. The vector operations form each of the eight phase-
  // shifts of one of the inputs, point-wise multiply, and then add into the
  // result at the correct place. This means that 33% (degree-16) or 25%
  // (degree-24) of the multiplies and adds are wasted, but it does ok.

  // Base case: two vectors (16 coefficients) per input, four vectors out.
  if (n == 2) {
    vec_t result[4];
    // |vec_a| is |a| with one extra zero vector on top, giving room for the
    // word-by-word shifts below.
    vec_t vec_a[3];
    static const vec_t kZero = {0};
    vec_a[0] = a[0];
    vec_a[1] = a[1];
    vec_a[2] = kZero;

    // Unshifted phase: multiply |a| by word 0 of b[0] and word 0 of b[1]. The
    // first touch of each |result| element is a plain multiply, which also
    // initialises it.
    result[0] = vec_mul(vec_a[0], vec_get_word(b[0], 0));
    result[1] = vec_mul(vec_a[1], vec_get_word(b[0], 0));

    result[1] = vec_fma(result[1], vec_a[0], vec_get_word(b[1], 0));
    result[2] = vec_mul(vec_a[1], vec_get_word(b[1], 0));
    result[3] = kZero;

    // NOTE(review): |vec3_rshift_word| presumably moves the three-vector value
    // along by one 16-bit word so that the next phase lines up with the next
    // coefficient of |b| — confirm against its definition.
    vec3_rshift_word(vec_a);

    // BLOCK(x, y) accumulates (shifted |a|) × (coefficient |y| of |b|) into
    // result[x..x+2]. |y| selects word y % 8 of vector b[y / 8].
#define BLOCK(x, y)                                                      \
  do {                                                                   \
    result[x + 0] =                                                      \
        vec_fma(result[x + 0], vec_a[0], vec_get_word(b[y / 8], y % 8)); \
    result[x + 1] =                                                      \
        vec_fma(result[x + 1], vec_a[1], vec_get_word(b[y / 8], y % 8)); \
    result[x + 2] =                                                      \
        vec_fma(result[x + 2], vec_a[2], vec_get_word(b[y / 8], y % 8)); \
  } while (0)

    // Phases one to seven: each handles word k of b[0] and word k of b[1],
    // with |vec_a| shifted once more between phases.
    BLOCK(0, 1);
    BLOCK(1, 9);

    vec3_rshift_word(vec_a);

    BLOCK(0, 2);
    BLOCK(1, 10);

    vec3_rshift_word(vec_a);

    BLOCK(0, 3);
    BLOCK(1, 11);

    vec3_rshift_word(vec_a);

    BLOCK(0, 4);
    BLOCK(1, 12);

    vec3_rshift_word(vec_a);

    BLOCK(0, 5);
    BLOCK(1, 13);

    vec3_rshift_word(vec_a);

    BLOCK(0, 6);
    BLOCK(1, 14);

    vec3_rshift_word(vec_a);

    BLOCK(0, 7);
    BLOCK(1, 15);

#undef BLOCK

    // The product is built in a local array and copied out at the end.
    memcpy(out, result, sizeof(result));
    return;
  }

  // Base case: three vectors (24 coefficients) per input, six vectors out.
  if (n == 3) {
    vec_t result[6];
    // |vec_a| is |a| with one extra zero vector on top, as in the n == 2 case.
    vec_t vec_a[4];
    static const vec_t kZero = {0};
    vec_a[0] = a[0];
    vec_a[1] = a[1];
    vec_a[2] = a[2];
    vec_a[3] = kZero;

    // Unshifted phase for word 0 of b[0]; plain multiplies initialise
    // result[0..2].
    result[0] = vec_mul(a[0], vec_get_word(b[0], 0));
    result[1] = vec_mul(a[1], vec_get_word(b[0], 0));
    result[2] = vec_mul(a[2], vec_get_word(b[0], 0));

    // BLOCK_PRE(x, y) is the unshifted phase for the remaining vectors of |b|:
    // it accumulates into result[x] and result[x + 1], which already hold
    // values, and initialises result[x + 2] with a plain multiply.
#define BLOCK_PRE(x, y)                                                  \
  do {                                                                   \
    result[x + 0] =                                                      \
        vec_fma(result[x + 0], vec_a[0], vec_get_word(b[y / 8], y % 8)); \
    result[x + 1] =                                                      \
        vec_fma(result[x + 1], vec_a[1], vec_get_word(b[y / 8], y % 8)); \
    result[x + 2] = vec_mul(vec_a[2], vec_get_word(b[y / 8], y % 8));    \
  } while (0)

    BLOCK_PRE(1, 8);
    BLOCK_PRE(2, 16);

    // The top vector has not been written yet; the shifted phases below
    // accumulate into it.
    result[5] = kZero;

    vec4_rshift_word(vec_a);

    // BLOCK(x, y) accumulates (shifted |a|) × (coefficient |y| of |b|) into
    // result[x..x+3]. |y| selects word y % 8 of vector b[y / 8].
#define BLOCK(x, y)                                                      \
  do {                                                                   \
    result[x + 0] =                                                      \
        vec_fma(result[x + 0], vec_a[0], vec_get_word(b[y / 8], y % 8)); \
    result[x + 1] =                                                      \
        vec_fma(result[x + 1], vec_a[1], vec_get_word(b[y / 8], y % 8)); \
    result[x + 2] =                                                      \
        vec_fma(result[x + 2], vec_a[2], vec_get_word(b[y / 8], y % 8)); \
    result[x + 3] =                                                      \
        vec_fma(result[x + 3], vec_a[3], vec_get_word(b[y / 8], y % 8)); \
  } while (0)

    // Phases one to seven: each handles word k of b[0], b[1] and b[2], with
    // |vec_a| shifted once more between phases.
    BLOCK(0, 1);
    BLOCK(1, 9);
    BLOCK(2, 17);

    vec4_rshift_word(vec_a);

    BLOCK(0, 2);
    BLOCK(1, 10);
    BLOCK(2, 18);

    vec4_rshift_word(vec_a);

    BLOCK(0, 3);
    BLOCK(1, 11);
    BLOCK(2, 19);

    vec4_rshift_word(vec_a);

    BLOCK(0, 4);
    BLOCK(1, 12);
    BLOCK(2, 20);

    vec4_rshift_word(vec_a);

    BLOCK(0, 5);
    BLOCK(1, 13);
    BLOCK(2, 21);

    vec4_rshift_word(vec_a);

    BLOCK(0, 6);
    BLOCK(1, 14);
    BLOCK(2, 22);

    vec4_rshift_word(vec_a);

    BLOCK(0, 7);
    BLOCK(1, 15);
    BLOCK(2, 23);

#undef BLOCK
#undef BLOCK_PRE

    memcpy(out, result, sizeof(result));

    return;
  }

  // Karatsuba multiplication.
  // https://en.wikipedia.org/wiki/Karatsuba_algorithm

  // When |n| is odd, the two "halves" will have different lengths. The first is
  // always the smaller.
  const size_t low_len = n / 2;
  const size_t high_len = n - low_len;
  const vec_t *a_high = &a[low_len];
  const vec_t *b_high = &b[low_len];

  // Store a_1 + a_0 in the first half of |out| and b_1 + b_0 in the second
  // half.
  for (size_t i = 0; i < low_len; i++) {
    out[i] = vec_add(a_high[i], a[i]);
    out[high_len + i] = vec_add(b_high[i], b[i]);
  }
  // With an odd |n| the high halves have one extra vector that has no
  // counterpart in the low halves, so it is copied as-is.
  if (high_len != low_len) {
    out[low_len] = a_high[low_len];
    out[high_len + low_len] = b_high[low_len];
  }

  // The first 2×|high_len| elements of |scratch| hold the cross product; the
  // remainder is passed down to the recursive calls.
  vec_t *const child_scratch = &scratch[2 * high_len];
  // Calculate (a_1 + a_0) × (b_1 + b_0) and write to scratch buffer. This must
  // come first because its inputs live in |out|, which the next two calls
  // overwrite.
  poly_mul_vec_aux(scratch, child_scratch, out, &out[high_len], high_len);
  // Calculate a_1 × b_1.
  poly_mul_vec_aux(&out[low_len * 2], child_scratch, a_high, b_high, high_len);
  // Calculate a_0 × b_0.
  poly_mul_vec_aux(out, child_scratch, a, b, low_len);

  // Subtract those last two products from the first.
  for (size_t i = 0; i < low_len * 2; i++) {
    scratch[i] = vec_sub(scratch[i], vec_add(out[i], out[low_len * 2 + i]));
  }
  // With an odd |n|, a_1 × b_1 is two vectors longer than a_0 × b_0; subtract
  // those two extra vectors as well.
  if (low_len != high_len) {
    scratch[low_len * 2] = vec_sub(scratch[low_len * 2], out[low_len * 4]);
    scratch[low_len * 2 + 1] =
        vec_sub(scratch[low_len * 2 + 1], out[low_len * 4 + 1]);
  }

  // Add the middle product into the output.
  for (size_t i = 0; i < high_len * 2; i++) {
    out[low_len + i] = vec_add(out[low_len + i], scratch[i]);
  }
}
1284 
1285 // poly_mul_vec sets |*out| to |x|×|y| mod (��^n - 1).
poly_mul_vec(struct poly * out,const struct poly * x,const struct poly * y)1286 static void poly_mul_vec(struct poly *out, const struct poly *x,
1287                          const struct poly *y) {
1288   OPENSSL_memset((uint16_t *)&x->v[N], 0, 3 * sizeof(uint16_t));
1289   OPENSSL_memset((uint16_t *)&y->v[N], 0, 3 * sizeof(uint16_t));
1290 
1291   OPENSSL_STATIC_ASSERT(sizeof(out->v) == sizeof(vec_t) * VECS_PER_POLY,
1292                         "struct poly is the wrong size");
1293   OPENSSL_STATIC_ASSERT(alignof(struct poly) == alignof(vec_t),
1294                         "struct poly has incorrect alignment");
1295 
1296   vec_t prod[VECS_PER_POLY * 2];
1297   vec_t scratch[172];
1298   poly_mul_vec_aux(prod, scratch, x->vectors, y->vectors, VECS_PER_POLY);
1299 
1300   // |prod| needs to be reduced mod (��^n - 1), which just involves adding the
1301   // upper-half to the lower-half. However, N is 701, which isn't a multiple of
1302   // the vector size, so the upper-half vectors all have to be shifted before
1303   // being added to the lower-half.
1304   vec_t *out_vecs = (vec_t *)out->v;
1305 
1306   for (size_t i = 0; i < VECS_PER_POLY; i++) {
1307     const vec_t prev = prod[VECS_PER_POLY - 1 + i];
1308     const vec_t this = prod[VECS_PER_POLY + i];
1309     out_vecs[i] = vec_add(prod[i], vec_merge_3_5(prev, this));
1310   }
1311 
1312   OPENSSL_memset(&out->v[N], 0, 3 * sizeof(uint16_t));
1313 }
1314 
1315 #endif  // HRSS_HAVE_VECTOR_UNIT
1316 
// poly_mul_novec_aux multiplies the |n|-coefficient polynomials |a| and |b|
// and writes the 2×|n|-coefficient product to |out|. Inputs of 64 coefficients
// or more are split with Karatsuba; smaller ones use schoolbook
// multiplication, which does not touch |scratch| and ends the recursion. Each
// Karatsuba level consumes 2*ceil(n/2) elements of |scratch| and hands the
// rest to its children. For |n| == |N|, |scratch| needs 1318 elements.
static void poly_mul_novec_aux(uint16_t *out, uint16_t *scratch,
                               const uint16_t *a, const uint16_t *b, size_t n) {
  static const size_t kSchoolbookLimit = 64;
  if (n < kSchoolbookLimit) {
    // Schoolbook: accumulate every pairwise product. Arithmetic is done in
    // |unsigned| and truncated to 16 bits on store.
    OPENSSL_memset(out, 0, sizeof(uint16_t) * n * 2);
    for (size_t i = 0; i < n; i++) {
      for (size_t j = 0; j < n; j++) {
        out[i + j] += (unsigned)a[i] * b[j];
      }
    }

    return;
  }

  // Karatsuba multiplication.
  // https://en.wikipedia.org/wiki/Karatsuba_algorithm

  // For odd |n| the split is uneven; the low part is the shorter one.
  const size_t lo = n / 2;
  const size_t hi = n - lo;
  const uint16_t *const a_upper = a + lo;
  const uint16_t *const b_upper = b + lo;

  // |out| is borrowed to hold the sums a_1 + a_0 and b_1 + b_0, each |hi|
  // coefficients long, back to back.
  uint16_t *const a_sum = out;
  uint16_t *const b_sum = out + hi;
  for (size_t i = 0; i < lo; i++) {
    a_sum[i] = a_upper[i] + a[i];
    b_sum[i] = b_upper[i] + b[i];
  }
  if (hi != lo) {
    // The extra top coefficient of each upper part has nothing to be added
    // to.
    a_sum[lo] = a_upper[lo];
    b_sum[lo] = b_upper[lo];
  }

  uint16_t *const cross = scratch;
  uint16_t *const sub_scratch = scratch + 2 * hi;
  uint16_t *const upper_prod = out + 2 * lo;

  // (a_1 + a_0) × (b_1 + b_0). This has to run first because its inputs live
  // in |out|, which the following two calls overwrite.
  poly_mul_novec_aux(cross, sub_scratch, a_sum, b_sum, hi);
  // a_1 × b_1.
  poly_mul_novec_aux(upper_prod, sub_scratch, a_upper, b_upper, hi);
  // a_0 × b_0.
  poly_mul_novec_aux(out, sub_scratch, a, b, lo);

  // Turn the cross product into the middle term by removing both outer
  // products.
  for (size_t i = 0; i < 2 * lo; i++) {
    cross[i] -= out[i] + upper_prod[i];
  }
  if (lo != hi) {
    // a_1 × b_1 is two coefficients longer than a_0 × b_0. The last of those
    // is asserted to be zero, so only one needs subtracting.
    cross[2 * lo] -= upper_prod[2 * lo];
    assert(upper_prod[2 * lo + 1] == 0);
  }

  // Add the middle term into the output at offset |lo|.
  for (size_t i = 0; i < 2 * hi; i++) {
    out[lo + i] += cross[i];
  }
}
1373 
1374 // poly_mul_novec sets |*out| to |x|×|y| mod (��^n - 1).
poly_mul_novec(struct poly * out,const struct poly * x,const struct poly * y)1375 static void poly_mul_novec(struct poly *out, const struct poly *x,
1376                            const struct poly *y) {
1377   uint16_t prod[2 * N];
1378   uint16_t scratch[1318];
1379   poly_mul_novec_aux(prod, scratch, x->v, y->v, N);
1380 
1381   for (size_t i = 0; i < N; i++) {
1382     out->v[i] = prod[i] + prod[i + N];
1383   }
1384   OPENSSL_memset(&out->v[N], 0, 3 * sizeof(uint16_t));
1385 }
1386 
// poly_mul sets |*r| to the product of |a| and |b|, picking an implementation
// in order of preference: the assembly routine (when built in and the CPU
// reports AVX2), the vector implementation (when built in and usable at
// runtime), and finally the portable fallback.
static void poly_mul(struct poly *r, const struct poly *a,
                     const struct poly *b) {
#if defined(POLY_RQ_MUL_ASM)
  // Bit 5 of the third capability word signals AVX2.
  if ((OPENSSL_ia32cap_P[2] & (1 << 5)) != 0) {
    poly_Rq_mul(r->v, a->v, b->v);
    return;
  }
#endif

#if defined(HRSS_HAVE_VECTOR_UNIT)
  if (vec_capable()) {
    poly_mul_vec(r, a, b);
    return;
  }
#endif

  // Neither accelerated path applies.
  poly_mul_novec(r, a, b);
}
1407 
1408 // poly_mul_x_minus_1 sets |p| to |p|×(�� - 1) mod (��^n - 1).
poly_mul_x_minus_1(struct poly * p)1409 static void poly_mul_x_minus_1(struct poly *p) {
1410   // Multiplying by (�� - 1) means negating each coefficient and adding in
1411   // the value of the previous one.
1412   const uint16_t orig_final_coefficient = p->v[N - 1];
1413 
1414   for (size_t i = N - 1; i > 0; i--) {
1415     p->v[i] = p->v[i - 1] - p->v[i];
1416   }
1417   p->v[0] = orig_final_coefficient - p->v[0];
1418 }
1419 
1420 // poly_mod_phiN sets |p| to |p| mod Φ(N).
poly_mod_phiN(struct poly * p)1421 static void poly_mod_phiN(struct poly *p) {
1422   const uint16_t coeff700 = p->v[N - 1];
1423 
1424   for (unsigned i = 0; i < N; i++) {
1425     p->v[i] -= coeff700;
1426   }
1427 }
1428 
1429 // poly_clamp reduces each coefficient mod Q.
poly_clamp(struct poly * p)1430 static void poly_clamp(struct poly *p) {
1431   for (unsigned i = 0; i < N; i++) {
1432     p->v[i] &= Q - 1;
1433   }
1434 }
1435 
1436 
1437 // Conversion functions
1438 // --------------------
1439 
1440 // poly2_from_poly sets |*out| to |in| mod 2.
poly2_from_poly(struct poly2 * out,const struct poly * in)1441 static void poly2_from_poly(struct poly2 *out, const struct poly *in) {
1442   crypto_word_t *words = out->v;
1443   unsigned shift = 0;
1444   crypto_word_t word = 0;
1445 
1446   for (unsigned i = 0; i < N; i++) {
1447     word >>= 1;
1448     word |= (crypto_word_t)(in->v[i] & 1) << (BITS_PER_WORD - 1);
1449     shift++;
1450 
1451     if (shift == BITS_PER_WORD) {
1452       *words = word;
1453       words++;
1454       word = 0;
1455       shift = 0;
1456     }
1457   }
1458 
1459   word >>= BITS_PER_WORD - shift;
1460   *words = word;
1461 }
1462 
// mod3 interprets |a| as a signed number and returns |a| mod 3, a value in
// {0, 1, 2}, without branching or dividing.
static uint16_t mod3(int16_t a) {
  // 21845 is floor(2^16 / 3), so the shifted product estimates |a| / 3. The
  // estimate is either exact or one too small.
  const int32_t scaled = (int32_t)a * 21845;
  const int16_t quotient = scaled >> 16;
  int16_t remainder = a - 3 * quotient;

  // |remainder| is now in {0, 1, 2, 3}; 3 has to become 0. Only the value 3
  // has both of its low bits set, so |is_three| is 1 for it and 0 otherwise.
  // Subtracting one gives a mask of all-zeros or all-ones respectively.
  const int is_three = remainder & (remainder >> 1);
  return remainder & (is_three - 1);
}
1471 
1472 // poly3_from_poly sets |*out| to |in|.
poly3_from_poly(struct poly3 * out,const struct poly * in)1473 static void poly3_from_poly(struct poly3 *out, const struct poly *in) {
1474   crypto_word_t *words_s = out->s.v;
1475   crypto_word_t *words_a = out->a.v;
1476   crypto_word_t s = 0;
1477   crypto_word_t a = 0;
1478   unsigned shift = 0;
1479 
1480   for (unsigned i = 0; i < N; i++) {
1481     // This duplicates the 13th bit upwards to the top of the uint16,
1482     // essentially treating it as a sign bit and converting into a signed int16.
1483     // The signed value is reduced mod 3, yielding {0, 1, 2}.
1484     const uint16_t v = mod3((int16_t)(in->v[i] << 3) >> 3);
1485     s >>= 1;
1486     const crypto_word_t s_bit = (crypto_word_t)(v & 2) << (BITS_PER_WORD - 2);
1487     s |= s_bit;
1488     a >>= 1;
1489     a |= s_bit | (crypto_word_t)(v & 1) << (BITS_PER_WORD - 1);
1490     shift++;
1491 
1492     if (shift == BITS_PER_WORD) {
1493       *words_s = s;
1494       words_s++;
1495       *words_a = a;
1496       words_a++;
1497       s = a = 0;
1498       shift = 0;
1499     }
1500   }
1501 
1502   s >>= BITS_PER_WORD - shift;
1503   a >>= BITS_PER_WORD - shift;
1504   *words_s = s;
1505   *words_a = a;
1506 }
1507 
1508 // poly3_from_poly_checked sets |*out| to |in|, which has coefficients in {0, 1,
1509 // Q-1}. It returns a mask indicating whether all coefficients were found to be
1510 // in that set.
poly3_from_poly_checked(struct poly3 * out,const struct poly * in)1511 static crypto_word_t poly3_from_poly_checked(struct poly3 *out,
1512                                              const struct poly *in) {
1513   crypto_word_t *words_s = out->s.v;
1514   crypto_word_t *words_a = out->a.v;
1515   crypto_word_t s = 0;
1516   crypto_word_t a = 0;
1517   unsigned shift = 0;
1518   crypto_word_t ok = CONSTTIME_TRUE_W;
1519 
1520   for (unsigned i = 0; i < N; i++) {
1521     const uint16_t v = in->v[i];
1522     // Maps {0, 1, Q-1} to {0, 1, 2}.
1523     uint16_t mod3 = v & 3;
1524     mod3 ^= mod3 >> 1;
1525     const uint16_t expected = (uint16_t)((~((mod3 >> 1) - 1)) | mod3) % Q;
1526     ok &= constant_time_eq_w(v, expected);
1527 
1528     s >>= 1;
1529     const crypto_word_t s_bit = (crypto_word_t)(mod3 & 2)
1530                                 << (BITS_PER_WORD - 2);
1531     s |= s_bit;
1532     a >>= 1;
1533     a |= s_bit | (crypto_word_t)(mod3 & 1) << (BITS_PER_WORD - 1);
1534     shift++;
1535 
1536     if (shift == BITS_PER_WORD) {
1537       *words_s = s;
1538       words_s++;
1539       *words_a = a;
1540       words_a++;
1541       s = a = 0;
1542       shift = 0;
1543     }
1544   }
1545 
1546   s >>= BITS_PER_WORD - shift;
1547   a >>= BITS_PER_WORD - shift;
1548   *words_s = s;
1549   *words_a = a;
1550 
1551   return ok;
1552 }
1553 
poly_from_poly2(struct poly * out,const struct poly2 * in)1554 static void poly_from_poly2(struct poly *out, const struct poly2 *in) {
1555   const crypto_word_t *words = in->v;
1556   unsigned shift = 0;
1557   crypto_word_t word = *words;
1558 
1559   for (unsigned i = 0; i < N; i++) {
1560     out->v[i] = word & 1;
1561     word >>= 1;
1562     shift++;
1563 
1564     if (shift == BITS_PER_WORD) {
1565       words++;
1566       word = *words;
1567       shift = 0;
1568     }
1569   }
1570 }
1571 
poly_from_poly3(struct poly * out,const struct poly3 * in)1572 static void poly_from_poly3(struct poly *out, const struct poly3 *in) {
1573   const crypto_word_t *words_s = in->s.v;
1574   const crypto_word_t *words_a = in->a.v;
1575   crypto_word_t word_s = ~(*words_s);
1576   crypto_word_t word_a = *words_a;
1577   unsigned shift = 0;
1578 
1579   for (unsigned i = 0; i < N; i++) {
1580     out->v[i] = (uint16_t)(word_s & 1) - 1;
1581     out->v[i] |= word_a & 1;
1582     word_s >>= 1;
1583     word_a >>= 1;
1584     shift++;
1585 
1586     if (shift == BITS_PER_WORD) {
1587       words_s++;
1588       words_a++;
1589       word_s = ~(*words_s);
1590       word_a = *words_a;
1591       shift = 0;
1592     }
1593   }
1594 }
1595 
1596 // Polynomial inversion
1597 // --------------------
1598 
// poly_invert_mod2 sets |*out| to |in^-1| (i.e. such that |*out|×|in| = 1 mod
// Φ(N)), all mod 2. This isn't useful in itself, but is part of doing inversion
// mod Q.
//
// The loop below always executes the same number of iterations and selects
// between alternatives with masks rather than branches.
static void poly_invert_mod2(struct poly *out, const struct poly *in) {
  // This algorithm follows algorithm 10 in the paper. (Although, in contrast to
  // the paper, k should start at zero, not one, and the rotation count needs
  // to handle trailing zero coefficients.) The best explanation for why it
  // works is in the "Why it works" section of [NTRUTN14].

  // |f| is |in| in packed form, |b| starts as the constant one and |c| as zero.
  struct poly2 b, c, f, g;
  poly2_from_poly(&f, in);
  OPENSSL_memset(&b, 0, sizeof(b));
  b.v[0] = 1;
  OPENSSL_memset(&c, 0, sizeof(c));

  // Set g to all ones.
  OPENSSL_memset(&g, 0xff, sizeof(struct poly2));
  // Clear the bits of the final word that lie beyond the polynomial.
  g.v[WORDS_PER_POLY - 1] >>= BITS_PER_WORD - BITS_IN_LAST_WORD;

  crypto_word_t deg_f = N - 1, deg_g = N - 1, rotation = 0;
  // |still_going| is an all-ones mask until |deg_f| reaches zero, after which
  // |s| (and therefore |should_swap|) is forced to zero for every remaining
  // iteration.
  crypto_word_t still_going = CONSTTIME_TRUE_W;

  for (unsigned i = 0; i < 2 * (N - 1) - 1; i++) {
    // |s| is all-ones iff the constant term of |f| is set and we're still
    // going.
    const crypto_word_t s = still_going & lsb_to_all(f.v[0]);
    const crypto_word_t should_swap = s & constant_time_lt_w(deg_f, deg_g);
    poly2_cswap(&f, &g, should_swap);
    poly2_cswap(&b, &c, should_swap);
    // Exchange |deg_f| and |deg_g| when |should_swap| is set: XORing both with
    // (deg_f ^ deg_g) swaps them, and XORing with zero leaves them unchanged.
    const crypto_word_t deg_sum = should_swap & (deg_f ^ deg_g);
    deg_f ^= deg_sum;
    deg_g ^= deg_sum;
    assert(deg_g >= 1);
    poly2_fmadd(&f, &g, s);
    poly2_fmadd(&b, &c, s);

    poly2_rshift1(&f);
    poly2_lshift1(&c);

    deg_f--;
    const crypto_word_t f0_is_nonzero = lsb_to_all(f.v[0]);
    // |f0_is_nonzero| implies |still_going|.
    assert(!(f0_is_nonzero && !still_going));
    // Remember the index of the most recent iteration that left |f| with a
    // non-zero constant term.
    rotation = constant_time_select_w(f0_is_nonzero, i, rotation);
    still_going &= ~constant_time_is_zero_w(deg_f);
  }

  rotation++;
  // Subtract |N| when |N| < |rotation|, without branching on |rotation|.
  rotation -= N & constant_time_lt_w(N, rotation);
  assert(poly2_top_bits_are_clear(&b));
  HRSS_poly2_rotr_consttime(&b, rotation);
  poly_from_poly2(out, &b);
}
1650 
1651 // poly_invert sets |*out| to |in^-1| (i.e. such that |*out|×|in| = 1 mod Φ(N)).
poly_invert(struct poly * out,const struct poly * in)1652 static void poly_invert(struct poly *out, const struct poly *in) {
1653   // Inversion mod Q, which is done based on the result of inverting mod
1654   // 2. See [NTRUTN14] paper, bottom of page two.
1655   struct poly a, *b, tmp;
1656 
1657   // a = -in.
1658   for (unsigned i = 0; i < N; i++) {
1659     a.v[i] = -in->v[i];
1660   }
1661 
1662   // b = in^-1 mod 2.
1663   b = out;
1664   poly_invert_mod2(b, in);
1665 
1666   // We are working mod Q=2**13 and we need to iterate ceil(log_2(13))
1667   // times, which is four.
1668   for (unsigned i = 0; i < 4; i++) {
1669     poly_mul(&tmp, &a, b);
1670     tmp.v[0] += 2;
1671     poly_mul(b, b, &tmp);
1672   }
1673 }
1674 
1675 // Marshal and unmarshal functions for various basic types.
1676 // --------------------------------------------------------
1677 
// POLY_BYTES is the number of bytes written by |poly_marshal| and read by
// |poly_unmarshal|: 700 coefficients of 13 bits each is 9100 bits, which is
// 1137 whole bytes plus four bits in a final byte.
#define POLY_BYTES 1138
1679 
1680 // poly_marshal serialises all but the final coefficient of |in| to |out|.
poly_marshal(uint8_t out[POLY_BYTES],const struct poly * in)1681 static void poly_marshal(uint8_t out[POLY_BYTES], const struct poly *in) {
1682   const uint16_t *p = in->v;
1683 
1684   for (size_t i = 0; i < N / 8; i++) {
1685     out[0] = p[0];
1686     out[1] = (0x1f & (p[0] >> 8)) | ((p[1] & 0x07) << 5);
1687     out[2] = p[1] >> 3;
1688     out[3] = (3 & (p[1] >> 11)) | ((p[2] & 0x3f) << 2);
1689     out[4] = (0x7f & (p[2] >> 6)) | ((p[3] & 0x01) << 7);
1690     out[5] = p[3] >> 1;
1691     out[6] = (0xf & (p[3] >> 9)) | ((p[4] & 0x0f) << 4);
1692     out[7] = p[4] >> 4;
1693     out[8] = (1 & (p[4] >> 12)) | ((p[5] & 0x7f) << 1);
1694     out[9] = (0x3f & (p[5] >> 7)) | ((p[6] & 0x03) << 6);
1695     out[10] = p[6] >> 2;
1696     out[11] = (7 & (p[6] >> 10)) | ((p[7] & 0x1f) << 3);
1697     out[12] = p[7] >> 5;
1698 
1699     p += 8;
1700     out += 13;
1701   }
1702 
1703   // There are four remaining values.
1704   out[0] = p[0];
1705   out[1] = (0x1f & (p[0] >> 8)) | ((p[1] & 0x07) << 5);
1706   out[2] = p[1] >> 3;
1707   out[3] = (3 & (p[1] >> 11)) | ((p[2] & 0x3f) << 2);
1708   out[4] = (0x7f & (p[2] >> 6)) | ((p[3] & 0x01) << 7);
1709   out[5] = p[3] >> 1;
1710   out[6] = 0xf & (p[3] >> 9);
1711 }
1712 
1713 // poly_unmarshal parses the output of |poly_marshal| and sets |out| such that
1714 // all but the final coefficients match, and the final coefficient is calculated
1715 // such that evaluating |out| at one results in zero. It returns one on success
1716 // or zero if |in| is an invalid encoding.
poly_unmarshal(struct poly * out,const uint8_t in[POLY_BYTES])1717 static int poly_unmarshal(struct poly *out, const uint8_t in[POLY_BYTES]) {
1718   uint16_t *p = out->v;
1719 
1720   for (size_t i = 0; i < N / 8; i++) {
1721     p[0] = (uint16_t)(in[0]) | (uint16_t)(in[1] & 0x1f) << 8;
1722     p[1] = (uint16_t)(in[1] >> 5) | (uint16_t)(in[2]) << 3 |
1723            (uint16_t)(in[3] & 3) << 11;
1724     p[2] = (uint16_t)(in[3] >> 2) | (uint16_t)(in[4] & 0x7f) << 6;
1725     p[3] = (uint16_t)(in[4] >> 7) | (uint16_t)(in[5]) << 1 |
1726            (uint16_t)(in[6] & 0xf) << 9;
1727     p[4] = (uint16_t)(in[6] >> 4) | (uint16_t)(in[7]) << 4 |
1728            (uint16_t)(in[8] & 1) << 12;
1729     p[5] = (uint16_t)(in[8] >> 1) | (uint16_t)(in[9] & 0x3f) << 7;
1730     p[6] = (uint16_t)(in[9] >> 6) | (uint16_t)(in[10]) << 2 |
1731            (uint16_t)(in[11] & 7) << 10;
1732     p[7] = (uint16_t)(in[11] >> 3) | (uint16_t)(in[12]) << 5;
1733 
1734     p += 8;
1735     in += 13;
1736   }
1737 
1738   // There are four coefficients remaining.
1739   p[0] = (uint16_t)(in[0]) | (uint16_t)(in[1] & 0x1f) << 8;
1740   p[1] = (uint16_t)(in[1] >> 5) | (uint16_t)(in[2]) << 3 |
1741          (uint16_t)(in[3] & 3) << 11;
1742   p[2] = (uint16_t)(in[3] >> 2) | (uint16_t)(in[4] & 0x7f) << 6;
1743   p[3] = (uint16_t)(in[4] >> 7) | (uint16_t)(in[5]) << 1 |
1744          (uint16_t)(in[6] & 0xf) << 9;
1745 
1746   for (unsigned i = 0; i < N - 1; i++) {
1747     out->v[i] = (int16_t)(out->v[i] << 3) >> 3;
1748   }
1749 
1750   // There are four unused bits in the last byte. We require them to be zero.
1751   if ((in[6] & 0xf0) != 0) {
1752     return 0;
1753   }
1754 
1755   // Set the final coefficient as specifed in [HRSSNIST] 1.9.2 step 6.
1756   uint32_t sum = 0;
1757   for (size_t i = 0; i < N - 1; i++) {
1758     sum += out->v[i];
1759   }
1760 
1761   out->v[N - 1] = (uint16_t)(0u - sum);
1762 
1763   return 1;
1764 }
1765 
// mod3_from_modQ maps {0, 1, Q-1, 65535} -> {0, 1, 2, 2}. Only the two
// least-significant bits of |v| affect the result. Note that |v| may have an
// invalid value when processing attacker-controlled inputs.
static uint16_t mod3_from_modQ(uint16_t v) {
  const uint16_t low_bits = v & 3;
  const uint16_t high_bit = low_bits >> 1;
  return low_bits ^ high_bit;
}
1772 
1773 // poly_marshal_mod3 marshals |in| to |out| where the coefficients of |in| are
1774 // all in {0, 1, Q-1, 65535} and |in| is mod Φ(N). (Note that coefficients may
1775 // have invalid values when processing attacker-controlled inputs.)
poly_marshal_mod3(uint8_t out[HRSS_POLY3_BYTES],const struct poly * in)1776 static void poly_marshal_mod3(uint8_t out[HRSS_POLY3_BYTES],
1777                               const struct poly *in) {
1778   const uint16_t *coeffs = in->v;
1779 
1780   // Only 700 coefficients are marshaled because in[700] must be zero.
1781   assert(coeffs[N-1] == 0);
1782 
1783   for (size_t i = 0; i < HRSS_POLY3_BYTES; i++) {
1784     const uint16_t coeffs0 = mod3_from_modQ(coeffs[0]);
1785     const uint16_t coeffs1 = mod3_from_modQ(coeffs[1]);
1786     const uint16_t coeffs2 = mod3_from_modQ(coeffs[2]);
1787     const uint16_t coeffs3 = mod3_from_modQ(coeffs[3]);
1788     const uint16_t coeffs4 = mod3_from_modQ(coeffs[4]);
1789     out[i] = coeffs0 + coeffs1 * 3 + coeffs2 * 9 + coeffs3 * 27 + coeffs4 * 81;
1790     coeffs += 5;
1791   }
1792 }
1793 
1794 // HRSS-specific functions
1795 // -----------------------
1796 
1797 // poly_short_sample samples a vector of values in {0xffff (i.e. -1), 0, 1}.
1798 // This is the same action as the algorithm in [HRSSNIST] section 1.8.1, but
1799 // with HRSS-SXY the sampling algorithm is now a private detail of the
1800 // implementation (previously it had to match between two parties). This
1801 // function uses that freedom to implement a flatter distribution of values.
poly_short_sample(struct poly * out,const uint8_t in[HRSS_SAMPLE_BYTES])1802 static void poly_short_sample(struct poly *out,
1803                               const uint8_t in[HRSS_SAMPLE_BYTES]) {
1804   OPENSSL_STATIC_ASSERT(HRSS_SAMPLE_BYTES == N - 1,
1805                         "HRSS_SAMPLE_BYTES incorrect");
1806   for (size_t i = 0; i < N - 1; i++) {
1807     uint16_t v = mod3(in[i]);
1808     // Map {0, 1, 2} -> {0, 1, 0xffff}
1809     v |= ((v >> 1) ^ 1) - 1;
1810     out->v[i] = v;
1811   }
1812   out->v[N - 1] = 0;
1813 }
1814 
1815 // poly_short_sample_plus performs the T+ sample as defined in [HRSSNIST],
1816 // section 1.8.2.
poly_short_sample_plus(struct poly * out,const uint8_t in[HRSS_SAMPLE_BYTES])1817 static void poly_short_sample_plus(struct poly *out,
1818                                    const uint8_t in[HRSS_SAMPLE_BYTES]) {
1819   poly_short_sample(out, in);
1820 
1821   // sum (and the product in the for loop) will overflow. But that's fine
1822   // because |sum| is bound by +/- (N-2), and N < 2^15 so it works out.
1823   uint16_t sum = 0;
1824   for (unsigned i = 0; i < N - 2; i++) {
1825     sum += (unsigned) out->v[i] * out->v[i + 1];
1826   }
1827 
1828   // If the sum is negative, flip the sign of even-positioned coefficients. (See
1829   // page 8 of [HRSS].)
1830   sum = ((int16_t) sum) >> 15;
1831   const uint16_t scale = sum | (~sum & 1);
1832   for (unsigned i = 0; i < N; i += 2) {
1833     out->v[i] = (unsigned) out->v[i] * scale;
1834   }
1835 }
1836 
// poly_lift computes the function discussed in [HRSS], appendix B. It reads
// the coefficients of |a| and writes the result to |out|, with coefficients
// in {0, 1, 0xffff} before the final multiplication by (x-1).
static void poly_lift(struct poly *out, const struct poly *a) {
  // We wish to calculate a/(x-1) mod Φ(N) over GF(3), where Φ(N) is the
  // Nth cyclotomic polynomial, i.e. 1 + x + … + x^700 (since N is prime).

  // 1/(x-1) has a fairly basic structure that we can exploit to speed this up:
  //
  // R.<x> = PolynomialRing(GF(3)…)
  // inv = R.cyclotomic_polynomial(1).inverse_mod(R.cyclotomic_polynomial(n))
  // list(inv)[:15]
  //   [1, 0, 2, 1, 0, 2, 1, 0, 2, 1, 0, 2, 1, 0, 2]
  //
  // This three-element pattern of coefficients repeats for the whole
  // polynomial.
  //
  // Next define the overbar operator such that z̅ = z[0] +
  // reverse(z[1:]). (Index zero of a polynomial here is the coefficient
  // of the constant term. So index one is the coefficient of x and so
  // on.)
  //
  // A less odd way to define this is to see that z̅ negates the indexes,
  // so z̅[0] = z[-0], z̅[1] = z[-1] and so on.
  //
  // The use of z̅ is that, when working mod (x^701 - 1), vz[0] = <v,
  // z̅>, vz[1] = <v, xz̅>, …. (Where <a, b> is the inner product: the sum
  // of the point-wise products.) Although we calculated the inverse mod
  // Φ(N), we can work mod (x^N - 1) and reduce mod Φ(N) at the end.
  // (That's because (x^N - 1) is a multiple of Φ(N).)
  //
  // When working mod (x^N - 1), multiplication by x is a right-rotation
  // of the list of coefficients.
  //
  // Thus we can consider what the pattern of z̅, xz̅, x^2z̅, … looks like:
  //
  // def reverse(xs):
  //   suffix = list(xs[1:])
  //   suffix.reverse()
  //   return [xs[0]] + suffix
  //
  // def rotate(xs):
  //   return [xs[-1]] + xs[:-1]
  //
  // zoverbar = reverse(list(inv) + [0])
  // xzoverbar = rotate(reverse(list(inv) + [0]))
  // x2zoverbar = rotate(rotate(reverse(list(inv) + [0])))
  //
  // zoverbar[:15]
  //   [1, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1]
  // xzoverbar[:15]
  //   [0, 1, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0]
  // x2zoverbar[:15]
  //   [2, 0, 1, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2]
  //
  // (For a formula for z̅, see lemma two of appendix B.)
  //
  // After the first three elements have been taken care of, all then have
  // a repeating three-element cycle. The next value (x^3z̅) involves
  // three rotations of the first pattern, thus the three-element cycle
  // lines up. However, the discontinuity in the first three elements
  // obviously moves to a different position. Consider the difference
  // between x^3z̅ and z̅:
  //
  // [x-y for (x,y) in zip(zoverbar, x3zoverbar)][:15]
  //    [0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
  //
  // This pattern of differences is the same for all elements, although it
  // obviously moves right with the rotations.
  //
  // From this, we reach algorithm eight of appendix B.

  // Handle the first three elements of the inner products.
  out->v[0] = a->v[0] + a->v[2];
  out->v[1] = a->v[1];
  out->v[2] = -a->v[0] + a->v[2];

  // s0, s1, s2 are added into out->v[0], out->v[1], and out->v[2],
  // respectively. We do not compute s1 because it's just -(s0 + s2).
  uint16_t s0 = 0, s2 = 0;
  for (size_t i = 3; i < 699; i += 3) {
    s0 += -a->v[i] + a->v[i + 2];
    // s1 += a->v[i] - a->v[i + 1];
    s2 += a->v[i + 1] - a->v[i + 2];
  }

  // Handle the fact that the three-element pattern doesn't fill the
  // polynomial exactly (since 701 isn't a multiple of three).
  s0 -= a->v[699];
  // s1 += a->v[699] - a->v[700];
  s2 += a->v[700];

  // Note that s0 + s1 + s2 = 0.
  out->v[0] += s0;
  out->v[1] -= (s0 + s2); // = s1
  out->v[2] += s2;

  // Calculate the remaining inner products by taking advantage of the
  // fact that the pattern repeats every three cycles and the pattern of
  // differences moves with the rotation.
  for (size_t i = 3; i < N; i++) {
    out->v[i] = (out->v[i - 3] - (a->v[i - 2] + a->v[i - 1] + a->v[i]));
  }

  // Reduce mod Φ(N) by subtracting a multiple of out[700] from every
  // element and convert to mod Q. (See above about adding twice as
  // subtraction.)
  // |v| is captured before the loop because the loop overwrites out->v[700]
  // in its final iteration.
  const crypto_word_t v = out->v[700];
  for (unsigned i = 0; i < N; i++) {
    const uint16_t vi_mod3 = mod3(out->v[i] - v);
    // Map {0, 1, 2} to {0, 1, 0xffff}.
    out->v[i] = (~((vi_mod3 >> 1) - 1)) | vi_mod3;
  }

  // Multiply by (x-1) so that the result is the lift rather than a/(x-1).
  poly_mul_x_minus_1(out);
}
1951 
// public_key is the internal form of an HRSS public key. It is reached from a
// |struct HRSS_public_key| via |public_key_from_external|.
struct public_key {
  // ph is the polynomial computed by |HRSS_generate_key| (or parsed by
  // |HRSS_parse_public_key|) that |HRSS_encap| multiplies by the sampled |r|.
  struct poly ph;
};
1955 
// private_key is the internal form of an HRSS private key. It is reached from
// a |struct HRSS_private_key| via |private_key_from_external|.
struct private_key {
  // f is the sampled secret polynomial in |poly3| form and f_inverse is the
  // result of |HRSS_poly3_invert| applied to it.
  struct poly3 f, f_inverse;
  // ph_inverse is computed in |HRSS_generate_key| and is used by |HRSS_decap|
  // to recover |r| from |c - lift(m)|.
  struct poly ph_inverse;
  // hmac_key is copied from the final 32 bytes of the key-generation input.
  // |HRSS_decap| uses it to key the HMAC-SHA-256 of the ciphertext, which is
  // the shared key returned when the ciphertext is rejected.
  uint8_t hmac_key[32];
};
1961 
1962 // public_key_from_external converts an external public key pointer into an
1963 // internal one. Externally the alignment is only specified to be eight bytes
1964 // but we need 16-byte alignment. We could annotate the external struct with
1965 // that alignment but we can only assume that malloced pointers are 8-byte
1966 // aligned in any case. (Even if the underlying malloc returns values with
1967 // 16-byte alignment, |OPENSSL_malloc| will store an 8-byte size prefix and mess
1968 // that up.)
public_key_from_external(struct HRSS_public_key * ext)1969 static struct public_key *public_key_from_external(
1970     struct HRSS_public_key *ext) {
1971   OPENSSL_STATIC_ASSERT(
1972       sizeof(struct HRSS_public_key) >= sizeof(struct public_key) + 15,
1973       "HRSS public key too small");
1974 
1975   uintptr_t p = (uintptr_t)ext;
1976   p = (p + 15) & ~15;
1977   return (struct public_key *)p;
1978 }
1979 
1980 // private_key_from_external does the same thing as |public_key_from_external|,
1981 // but for private keys. See the comment on that function about alignment
1982 // issues.
private_key_from_external(struct HRSS_private_key * ext)1983 static struct private_key *private_key_from_external(
1984     struct HRSS_private_key *ext) {
1985   OPENSSL_STATIC_ASSERT(
1986       sizeof(struct HRSS_private_key) >= sizeof(struct private_key) + 15,
1987       "HRSS private key too small");
1988 
1989   uintptr_t p = (uintptr_t)ext;
1990   p = (p + 15) & ~15;
1991   return (struct private_key *)p;
1992 }
1993 
HRSS_generate_key(struct HRSS_public_key * out_pub,struct HRSS_private_key * out_priv,const uint8_t in[HRSS_SAMPLE_BYTES+HRSS_SAMPLE_BYTES+32])1994 void HRSS_generate_key(
1995     struct HRSS_public_key *out_pub, struct HRSS_private_key *out_priv,
1996     const uint8_t in[HRSS_SAMPLE_BYTES + HRSS_SAMPLE_BYTES + 32]) {
1997   struct public_key *pub = public_key_from_external(out_pub);
1998   struct private_key *priv = private_key_from_external(out_priv);
1999 
2000   OPENSSL_memcpy(priv->hmac_key, in + 2 * HRSS_SAMPLE_BYTES,
2001                  sizeof(priv->hmac_key));
2002 
2003   struct poly f;
2004   poly_short_sample_plus(&f, in);
2005   poly3_from_poly(&priv->f, &f);
2006   HRSS_poly3_invert(&priv->f_inverse, &priv->f);
2007 
2008   // pg_phi1 is p (i.e. 3) × g × Φ(1) (i.e. ��-1).
2009   struct poly pg_phi1;
2010   poly_short_sample_plus(&pg_phi1, in + HRSS_SAMPLE_BYTES);
2011   for (unsigned i = 0; i < N; i++) {
2012     pg_phi1.v[i] *= 3;
2013   }
2014   poly_mul_x_minus_1(&pg_phi1);
2015 
2016   struct poly pfg_phi1;
2017   poly_mul(&pfg_phi1, &f, &pg_phi1);
2018 
2019   struct poly pfg_phi1_inverse;
2020   poly_invert(&pfg_phi1_inverse, &pfg_phi1);
2021 
2022   poly_mul(&pub->ph, &pfg_phi1_inverse, &pg_phi1);
2023   poly_mul(&pub->ph, &pub->ph, &pg_phi1);
2024   poly_clamp(&pub->ph);
2025 
2026   poly_mul(&priv->ph_inverse, &pfg_phi1_inverse, &f);
2027   poly_mul(&priv->ph_inverse, &priv->ph_inverse, &f);
2028   poly_clamp(&priv->ph_inverse);
2029 }
2030 
// kSharedKey is hashed, including its terminating NUL byte (callers pass
// |sizeof(kSharedKey)|), ahead of the marshaled |m|, |r| and ciphertext when
// |HRSS_encap| and |HRSS_decap| derive the shared key.
static const char kSharedKey[] = "shared key";
2032 
HRSS_encap(uint8_t out_ciphertext[POLY_BYTES],uint8_t out_shared_key[32],const struct HRSS_public_key * in_pub,const uint8_t in[HRSS_SAMPLE_BYTES+HRSS_SAMPLE_BYTES])2033 void HRSS_encap(uint8_t out_ciphertext[POLY_BYTES],
2034                 uint8_t out_shared_key[32],
2035                 const struct HRSS_public_key *in_pub,
2036                 const uint8_t in[HRSS_SAMPLE_BYTES + HRSS_SAMPLE_BYTES]) {
2037   const struct public_key *pub =
2038       public_key_from_external((struct HRSS_public_key *)in_pub);
2039   struct poly m, r, m_lifted;
2040   poly_short_sample(&m, in);
2041   poly_short_sample(&r, in + HRSS_SAMPLE_BYTES);
2042   poly_lift(&m_lifted, &m);
2043 
2044   struct poly prh_plus_m;
2045   poly_mul(&prh_plus_m, &r, &pub->ph);
2046   for (unsigned i = 0; i < N; i++) {
2047     prh_plus_m.v[i] += m_lifted.v[i];
2048   }
2049 
2050   poly_marshal(out_ciphertext, &prh_plus_m);
2051 
2052   uint8_t m_bytes[HRSS_POLY3_BYTES], r_bytes[HRSS_POLY3_BYTES];
2053   poly_marshal_mod3(m_bytes, &m);
2054   poly_marshal_mod3(r_bytes, &r);
2055 
2056   SHA256_CTX hash_ctx;
2057   SHA256_Init(&hash_ctx);
2058   SHA256_Update(&hash_ctx, kSharedKey, sizeof(kSharedKey));
2059   SHA256_Update(&hash_ctx, m_bytes, sizeof(m_bytes));
2060   SHA256_Update(&hash_ctx, r_bytes, sizeof(r_bytes));
2061   SHA256_Update(&hash_ctx, out_ciphertext, POLY_BYTES);
2062   SHA256_Final(out_shared_key, &hash_ctx);
2063 }
2064 
// HRSS_decap writes a shared key to |out_shared_key| for the given
// |ciphertext| and private key. It cannot fail: |out_shared_key| is first set
// to HMAC-SHA-256 of |ciphertext| keyed by the private |hmac_key|, and is only
// replaced with SHA-256(kSharedKey || m || r || ciphertext) if the ciphertext
// passes every check.
void HRSS_decap(uint8_t out_shared_key[HRSS_KEY_BYTES],
                const struct HRSS_private_key *in_priv,
                const uint8_t *ciphertext, size_t ciphertext_len) {
  const struct private_key *priv =
      private_key_from_external((struct HRSS_private_key *)in_priv);

  // This is HMAC, expanded inline rather than using the |HMAC| function so that
  // we can avoid dealing with possible allocation failures and so keep this
  // function infallible.
  // |masked_key| first holds the key XORed with the inner pad (0x36), padded
  // out to the SHA-256 block size.
  uint8_t masked_key[SHA256_CBLOCK];
  OPENSSL_STATIC_ASSERT(sizeof(priv->hmac_key) <= sizeof(masked_key),
                        "HRSS HMAC key larger than SHA-256 block size");
  for (size_t i = 0; i < sizeof(priv->hmac_key); i++) {
    masked_key[i] = priv->hmac_key[i] ^ 0x36;
  }
  OPENSSL_memset(masked_key + sizeof(priv->hmac_key), 0x36,
                 sizeof(masked_key) - sizeof(priv->hmac_key));

  // Inner hash: SHA-256((key ^ ipad) || ciphertext).
  SHA256_CTX hash_ctx;
  SHA256_Init(&hash_ctx);
  SHA256_Update(&hash_ctx, masked_key, sizeof(masked_key));
  SHA256_Update(&hash_ctx, ciphertext, ciphertext_len);
  uint8_t inner_digest[SHA256_DIGEST_LENGTH];
  SHA256_Final(inner_digest, &hash_ctx);

  // Convert the inner-padded key into the outer-padded key (0x5c) in place:
  // XORing with (0x5c ^ 0x36) cancels the earlier 0x36 mask.
  for (size_t i = 0; i < sizeof(priv->hmac_key); i++) {
    masked_key[i] ^= (0x5c ^ 0x36);
  }
  OPENSSL_memset(masked_key + sizeof(priv->hmac_key), 0x5c,
                 sizeof(masked_key) - sizeof(priv->hmac_key));

  // Outer hash: SHA-256((key ^ opad) || inner_digest). The result is written
  // straight to |out_shared_key| and is the value returned on rejection.
  SHA256_Init(&hash_ctx);
  SHA256_Update(&hash_ctx, masked_key, sizeof(masked_key));
  SHA256_Update(&hash_ctx, inner_digest, sizeof(inner_digest));
  OPENSSL_STATIC_ASSERT(HRSS_KEY_BYTES == SHA256_DIGEST_LENGTH,
                        "HRSS shared key length incorrect");
  SHA256_Final(out_shared_key, &hash_ctx);

  struct poly c;
  // If the ciphertext is publicly invalid then a random shared key is still
  // returned to simplify the logic of the caller, but this path is not constant
  // time.
  if (ciphertext_len != HRSS_CIPHERTEXT_BYTES ||
      !poly_unmarshal(&c, ciphertext)) {
    return;
  }

  // Recover m: multiply c by f, convert to |poly3| form and multiply by
  // f_inverse.
  struct poly f, cf;
  struct poly3 cf3, m3;
  poly_from_poly3(&f, &priv->f);
  poly_mul(&cf, &c, &f);
  poly3_from_poly(&cf3, &cf);
  // Note that cf3 is not reduced mod Φ(N). That reduction is deferred.
  HRSS_poly3_mul(&m3, &cf3, &priv->f_inverse);

  struct poly m, m_lifted;
  poly_from_poly3(&m, &m3);
  poly_lift(&m_lifted, &m);

  // Recover r = (c - lift(m)) × ph_inverse, reduced mod Φ(N).
  struct poly r;
  for (unsigned i = 0; i < N; i++) {
    r.v[i] = c.v[i] - m_lifted.v[i];
  }
  poly_mul(&r, &r, &priv->ph_inverse);
  poly_mod_phiN(&r);
  poly_clamp(&r);

  // |ok| accumulates the validity checks as a mask. Only the return value of
  // |poly3_from_poly_checked| is used; |r3| itself is not read again.
  struct poly3 r3;
  crypto_word_t ok = poly3_from_poly_checked(&r3, &r);

  // [NTRUCOMP] section 5.1 includes ReEnc2 and a proof that it's valid. Rather
  // than do an expensive |poly_mul|, it rebuilds |c'| from |c - lift(m)|
  // (called |b|) with:
  //   t = (−b(1)/N) mod Q
  //   c' = b + tΦ(N) + lift(m) mod Q
  //
  // When polynomials are transmitted, the final coefficient is omitted and
  // |poly_unmarshal| sets it such that f(1) == 0. Thus c(1) == 0. Also,
  // |poly_lift| multiplies the result by (x-1) and therefore evaluating a
  // lifted polynomial at 1 is also zero. Thus lift(m)(1) == 0 and so
  // (c - lift(m))(1) == 0.
  //
  // Although we defer the reduction above, |b| is conceptually reduced mod
  // Φ(N). In order to do that reduction one subtracts |c[N-1]| from every
  // coefficient. Therefore b(1) = -c[N-1]×N. The value of |t|, above, then is
  // just recovering |c[N-1]|, and adding tΦ(N) is simply undoing the reduction.
  // Therefore b + tΦ(N) + lift(m) = c by construction and we don't need to
  // recover |c| at all so long as we do the checks in
  // |poly3_from_poly_checked|.
  //
  // The |poly_marshal| here then is just confirming that |poly_unmarshal| is
  // strict and could be omitted.

  uint8_t expected_ciphertext[HRSS_CIPHERTEXT_BYTES];
  OPENSSL_STATIC_ASSERT(HRSS_CIPHERTEXT_BYTES == POLY_BYTES,
                        "ciphertext is the wrong size");
  assert(ciphertext_len == sizeof(expected_ciphertext));
  poly_marshal(expected_ciphertext, &c);

  uint8_t m_bytes[HRSS_POLY3_BYTES];
  uint8_t r_bytes[HRSS_POLY3_BYTES];
  poly_marshal_mod3(m_bytes, &m);
  poly_marshal_mod3(r_bytes, &r);

  // The re-marshaled ciphertext must equal the input byte-for-byte.
  ok &= constant_time_is_zero_w(CRYPTO_memcmp(ciphertext, expected_ciphertext,
                                              sizeof(expected_ciphertext)));

  // Compute the real shared key unconditionally; whether it is used is decided
  // by the masked select below.
  uint8_t shared_key[32];
  SHA256_Init(&hash_ctx);
  SHA256_Update(&hash_ctx, kSharedKey, sizeof(kSharedKey));
  SHA256_Update(&hash_ctx, m_bytes, sizeof(m_bytes));
  SHA256_Update(&hash_ctx, r_bytes, sizeof(r_bytes));
  SHA256_Update(&hash_ctx, expected_ciphertext, sizeof(expected_ciphertext));
  SHA256_Final(shared_key, &hash_ctx);

  // Keep the HMAC-derived key already in |out_shared_key| unless |ok| is set.
  for (unsigned i = 0; i < sizeof(shared_key); i++) {
    out_shared_key[i] =
        constant_time_select_8(ok, shared_key[i], out_shared_key[i]);
  }
}
2185 
HRSS_marshal_public_key(uint8_t out[HRSS_PUBLIC_KEY_BYTES],const struct HRSS_public_key * in_pub)2186 void HRSS_marshal_public_key(uint8_t out[HRSS_PUBLIC_KEY_BYTES],
2187                              const struct HRSS_public_key *in_pub) {
2188   const struct public_key *pub =
2189       public_key_from_external((struct HRSS_public_key *)in_pub);
2190   poly_marshal(out, &pub->ph);
2191 }
2192 
HRSS_parse_public_key(struct HRSS_public_key * out,const uint8_t in[HRSS_PUBLIC_KEY_BYTES])2193 int HRSS_parse_public_key(struct HRSS_public_key *out,
2194                           const uint8_t in[HRSS_PUBLIC_KEY_BYTES]) {
2195   struct public_key *pub = public_key_from_external(out);
2196   if (!poly_unmarshal(&pub->ph, in)) {
2197     return 0;
2198   }
2199   OPENSSL_memset(&pub->ph.v[N], 0, 3 * sizeof(uint16_t));
2200   return 1;
2201 }
2202