1package hrss
2
3import (
4	"crypto/hmac"
5	"crypto/sha256"
6	"crypto/subtle"
7	"encoding/binary"
8	"io"
9	"math/bits"
10)
11
// PublicKeySize and CiphertextSize are the lengths, in bytes, of a marshaled
// public key and of a ciphertext. Each is a single polynomial serialized
// mod Q.
const (
	PublicKeySize  = modQBytes
	CiphertextSize = modQBytes
)

// N is the number of coefficients in each polynomial and Q is the large
// modulus (2^13, so a coefficient takes 13 bits). mod3Bytes is the size of a
// polynomial serialized mod 3 (700 coefficients packed five to a byte) and
// modQBytes is the size of one serialized mod Q (700 coefficients of 13 bits
// each, rounded up to a whole byte). In both encodings the final coefficient
// is not stored.
const (
	N         = 701
	Q         = 8192
	mod3Bytes = 140
	modQBytes = 1138
)

// These describe how the N single-bit coefficients of a bitsliced polynomial
// are laid out in machine words: bitsPerWord is the width of a uint,
// wordsPerPoly is the number of words needed to hold N bits,
// fullWordsPerPoly is the number of words that are completely used and
// bitsInLastWord is the number of bits used in the final, partial word.
const (
	bitsPerWord      = bits.UintSize
	wordsPerPoly     = (N + bitsPerWord - 1) / bitsPerWord
	fullWordsPerPoly = N / bitsPerWord
	bitsInLastWord   = N % bitsPerWord
)
30
// poly3 represents a polynomial with N coefficients over GF(3). Each
// coefficient is bitsliced across the |s| and |a| arrays, like this:
//
//   s  |  a  | value
//  -----------------
//   0  |  0  | 0
//   0  |  1  | 1
//   1  |  0  | 2 (aka -1)
//   1  |  1  | <invalid>
//
// ('s' is for sign, and 'a' is just a letter.)
//
// Coefficient i lives in bit (i % bitsPerWord) of word (i / bitsPerWord) of
// each array.
//
// Once bitsliced as such, the following circuits can be used to implement
// addition and multiplication mod 3:
//
//   (s3, a3) = (s1, a1) × (s2, a2)
//   s3 = (s2 ∧ a1) ⊕ (s1 ∧ a2)
//   a3 = (s1 ∧ s2) ⊕ (a1 ∧ a2)
//
//   (s3, a3) = (s1, a1) + (s2, a2)
//   t1 = ~(s1 ∨ a1)
//   t2 = ~(s2 ∨ a2)
//   s3 = (a1 ∧ a2) ⊕ (t1 ∧ s2) ⊕ (t2 ∧ s1)
//   a3 = (s1 ∧ s2) ⊕ (t1 ∧ a2) ⊕ (t2 ∧ a1)
//
// Negating a value just involves swapping s and a.
type poly3 struct {
	s [wordsPerPoly]uint // "sign" bits: set where the coefficient is 2 (-1).
	a [wordsPerPoly]uint // set where the coefficient is 1.
}
61
62func (p *poly3) trim() {
63	p.s[wordsPerPoly-1] &= (1 << bitsInLastWord) - 1
64	p.a[wordsPerPoly-1] &= (1 << bitsInLastWord) - 1
65}
66
67func (p *poly3) zero() {
68	for i := range p.a {
69		p.s[i] = 0
70		p.a[i] = 0
71	}
72}
73
74func (p *poly3) fromDiscrete(in *poly) {
75	var shift uint
76	s := p.s[:]
77	a := p.a[:]
78	s[0] = 0
79	a[0] = 0
80
81	for _, v := range in {
82		s[0] >>= 1
83		s[0] |= uint((v>>1)&1) << (bitsPerWord - 1)
84		a[0] >>= 1
85		a[0] |= uint(v&1) << (bitsPerWord - 1)
86		shift++
87		if shift == bitsPerWord {
88			s = s[1:]
89			a = a[1:]
90			s[0] = 0
91			a[0] = 0
92			shift = 0
93		}
94	}
95
96	a[0] >>= bitsPerWord - shift
97	s[0] >>= bitsPerWord - shift
98}
99
100func (p *poly3) fromModQ(in *poly) int {
101	var shift uint
102	s := p.s[:]
103	a := p.a[:]
104	s[0] = 0
105	a[0] = 0
106	ok := 1
107
108	for _, v := range in {
109		vMod3, vOk := modQToMod3(v)
110		ok &= vOk
111
112		s[0] >>= 1
113		s[0] |= uint((vMod3>>1)&1) << (bitsPerWord - 1)
114		a[0] >>= 1
115		a[0] |= uint(vMod3&1) << (bitsPerWord - 1)
116		shift++
117		if shift == bitsPerWord {
118			s = s[1:]
119			a = a[1:]
120			s[0] = 0
121			a[0] = 0
122			shift = 0
123		}
124	}
125
126	a[0] >>= bitsPerWord - shift
127	s[0] >>= bitsPerWord - shift
128
129	return ok
130}
131
132func (p *poly3) fromDiscreteMod3(in *poly) {
133	var shift uint
134	s := p.s[:]
135	a := p.a[:]
136	s[0] = 0
137	a[0] = 0
138
139	for _, v := range in {
140		// This duplicates the 13th bit upwards to the top of the
141		// uint16, essentially treating it as a sign bit and converting
142		// into a signed int16. The signed value is reduced mod 3,
143		// yeilding {-2, -1, 0, 1, 2}.
144		v = uint16((int16(v<<3)>>3)%3) & 7
145
146		// We want to map v thus:
147		// {-2, -1, 0, 1, 2} -> {1, 2, 0, 1, 2}. We take the bottom
148		// three bits and then the constants below, when shifted by
149		// those three bits, perform the required mapping.
150		s[0] >>= 1
151		s[0] |= (0xbc >> v) << (bitsPerWord - 1)
152		a[0] >>= 1
153		a[0] |= (0x7a >> v) << (bitsPerWord - 1)
154		shift++
155		if shift == bitsPerWord {
156			s = s[1:]
157			a = a[1:]
158			s[0] = 0
159			a[0] = 0
160			shift = 0
161		}
162	}
163
164	a[0] >>= bitsPerWord - shift
165	s[0] >>= bitsPerWord - shift
166}
167
168func (p *poly3) marshal(out []byte) {
169	s := p.s[:]
170	a := p.a[:]
171	sw := s[0]
172	aw := a[0]
173	var shift int
174
175	for i := 0; i < 700; i += 5 {
176		acc, scale := 0, 1
177		for j := 0; j < 5; j++ {
178			v := int(aw&1) | int(sw&1)<<1
179			acc += scale * v
180			scale *= 3
181
182			shift++
183			if shift == bitsPerWord {
184				s = s[1:]
185				a = a[1:]
186				sw = s[0]
187				aw = a[0]
188				shift = 0
189			} else {
190				sw >>= 1
191				aw >>= 1
192			}
193		}
194
195		out[0] = byte(acc)
196		out = out[1:]
197	}
198}
199
200func (p *poly) fromMod2(in *poly2) {
201	var shift uint
202	words := in[:]
203	word := words[0]
204
205	for i := range p {
206		p[i] = uint16(word & 1)
207		word >>= 1
208		shift++
209		if shift == bitsPerWord {
210			words = words[1:]
211			word = words[0]
212			shift = 0
213		}
214	}
215}
216
217func (p *poly) fromMod3(in *poly3) {
218	var shift uint
219	s := in.s[:]
220	a := in.a[:]
221	sw := s[0]
222	aw := a[0]
223
224	for i := range p {
225		p[i] = uint16(aw&1 | (sw&1)<<1)
226		aw >>= 1
227		sw >>= 1
228		shift++
229		if shift == bitsPerWord {
230			a = a[1:]
231			s = s[1:]
232			aw = a[0]
233			sw = s[0]
234			shift = 0
235		}
236	}
237}
238
239func (p *poly) fromMod3ToModQ(in *poly3) {
240	var shift uint
241	s := in.s[:]
242	a := in.a[:]
243	sw := s[0]
244	aw := a[0]
245
246	for i := range p {
247		p[i] = mod3ToModQ(uint16(aw&1 | (sw&1)<<1))
248		aw >>= 1
249		sw >>= 1
250		shift++
251		if shift == bitsPerWord {
252			a = a[1:]
253			s = s[1:]
254			aw = a[0]
255			sw = s[0]
256			shift = 0
257		}
258	}
259}
260
261func lsbToAll(v uint) uint {
262	return uint(int(v<<(bitsPerWord-1)) >> (bitsPerWord - 1))
263}
264
265func (p *poly3) mulConst(ms, ma uint) {
266	ms = lsbToAll(ms)
267	ma = lsbToAll(ma)
268
269	for i := range p.a {
270		p.s[i], p.a[i] = (ma&p.s[i])^(ms&p.a[i]), (ma&p.a[i])^(ms&p.s[i])
271	}
272}
273
274func cmovWords(out, in *[wordsPerPoly]uint, mov uint) {
275	for i := range out {
276		out[i] = (out[i] & ^mov) | (in[i] & mov)
277	}
278}
279
280func rotWords(out, in *[wordsPerPoly]uint, bits uint) {
281	start := bits / bitsPerWord
282	n := (N - bits) / bitsPerWord
283
284	for i := uint(0); i < n; i++ {
285		out[i] = in[start+i]
286	}
287
288	carry := in[wordsPerPoly-1]
289
290	for i := uint(0); i < start; i++ {
291		out[n+i] = carry | in[i]<<bitsInLastWord
292		carry = in[i] >> (bitsPerWord - bitsInLastWord)
293	}
294
295	out[wordsPerPoly-1] = carry
296}
297
298// rotBits right-rotates the bits in |in|. bits must be a non-zero power of two
299// and less than bitsPerWord.
300func rotBits(out, in *[wordsPerPoly]uint, bits uint) {
301	if (bits == 0 || (bits & (bits - 1)) != 0 || bits > bitsPerWord/2 || bitsInLastWord < bitsPerWord/2) {
302		panic("internal error");
303	}
304
305	carry := in[wordsPerPoly-1] << (bitsPerWord - bits)
306
307	for i := wordsPerPoly - 2; i >= 0; i-- {
308		out[i] = carry | in[i]>>bits
309		carry = in[i] << (bitsPerWord - bits)
310	}
311
312	out[wordsPerPoly-1] = carry>>(bitsPerWord-bitsInLastWord) | in[wordsPerPoly-1]>>bits
313}
314
315func (p *poly3) rotWords(bits uint, in *poly3) {
316	rotWords(&p.s, &in.s, bits)
317	rotWords(&p.a, &in.a, bits)
318}
319
320func (p *poly3) rotBits(bits uint, in *poly3) {
321	rotBits(&p.s, &in.s, bits)
322	rotBits(&p.a, &in.a, bits)
323}
324
325func (p *poly3) cmov(in *poly3, mov uint) {
326	cmovWords(&p.s, &in.s, mov)
327	cmovWords(&p.a, &in.a, mov)
328}
329
330func (p *poly3) rot(bits uint) {
331	if bits > N {
332		panic("invalid")
333	}
334	var shifted poly3
335
336	shift := uint(9)
337	for ; (1 << shift) >= bitsPerWord; shift-- {
338		shifted.rotWords(1<<shift, p)
339		p.cmov(&shifted, lsbToAll(bits>>shift))
340	}
341	for ; shift < 9; shift-- {
342		shifted.rotBits(1<<shift, p)
343		p.cmov(&shifted, lsbToAll(bits>>shift))
344	}
345}
346
347func (p *poly3) fmadd(ms, ma uint, in *poly3) {
348	ms = lsbToAll(ms)
349	ma = lsbToAll(ma)
350
351	for i := range p.a {
352		products := (ma & in.s[i]) ^ (ms & in.a[i])
353		producta := (ma & in.a[i]) ^ (ms & in.s[i])
354
355		ns1Ana1 := ^p.s[i] & ^p.a[i]
356		ns2Ana2 := ^products & ^producta
357
358		p.s[i], p.a[i] = (p.a[i]&producta)^(ns1Ana1&products)^(p.s[i]&ns2Ana2), (p.s[i]&products)^(ns1Ana1&producta)^(p.a[i]&ns2Ana2)
359	}
360}
361
362func (p *poly3) modPhiN() {
363	factora := uint(int(p.s[wordsPerPoly-1]<<(bitsPerWord-bitsInLastWord)) >> (bitsPerWord - 1))
364	factors := uint(int(p.a[wordsPerPoly-1]<<(bitsPerWord-bitsInLastWord)) >> (bitsPerWord - 1))
365	ns2Ana2 := ^factors & ^factora
366
367	for i := range p.s {
368		ns1Ana1 := ^p.s[i] & ^p.a[i]
369		p.s[i], p.a[i] = (p.a[i]&factora)^(ns1Ana1&factors)^(p.s[i]&ns2Ana2), (p.s[i]&factors)^(ns1Ana1&factora)^(p.a[i]&ns2Ana2)
370	}
371}
372
373func (p *poly3) cswap(other *poly3, swap uint) {
374	for i := range p.s {
375		sums := swap & (p.s[i] ^ other.s[i])
376		p.s[i] ^= sums
377		other.s[i] ^= sums
378
379		suma := swap & (p.a[i] ^ other.a[i])
380		p.a[i] ^= suma
381		other.a[i] ^= suma
382	}
383}
384
385func (p *poly3) mulx() {
386	carrys := (p.s[wordsPerPoly-1] >> (bitsInLastWord - 1)) & 1
387	carrya := (p.a[wordsPerPoly-1] >> (bitsInLastWord - 1)) & 1
388
389	for i := range p.s {
390		outCarrys := p.s[i] >> (bitsPerWord - 1)
391		outCarrya := p.a[i] >> (bitsPerWord - 1)
392		p.s[i] <<= 1
393		p.a[i] <<= 1
394		p.s[i] |= carrys
395		p.a[i] |= carrya
396		carrys = outCarrys
397		carrya = outCarrya
398	}
399}
400
401func (p *poly3) divx() {
402	var carrys, carrya uint
403
404	for i := len(p.s) - 1; i >= 0; i-- {
405		outCarrys := p.s[i] & 1
406		outCarrya := p.a[i] & 1
407		p.s[i] >>= 1
408		p.a[i] >>= 1
409		p.s[i] |= carrys << (bitsPerWord - 1)
410		p.a[i] |= carrya << (bitsPerWord - 1)
411		carrys = outCarrys
412		carrya = outCarrya
413	}
414}
415
// poly2 represents a polynomial with N coefficients over GF(2), packed one
// bit per coefficient. Coefficient i lives in bit (i % bitsPerWord) of word
// (i / bitsPerWord).
type poly2 [wordsPerPoly]uint
417
418func (p *poly2) fromDiscrete(in *poly) {
419	var shift uint
420	words := p[:]
421	words[0] = 0
422
423	for _, v := range in {
424		words[0] >>= 1
425		words[0] |= uint(v&1) << (bitsPerWord - 1)
426		shift++
427		if shift == bitsPerWord {
428			words = words[1:]
429			words[0] = 0
430			shift = 0
431		}
432	}
433
434	words[0] >>= bitsPerWord - shift
435}
436
437func (p *poly2) setPhiN() {
438	for i := range p {
439		p[i] = ^uint(0)
440	}
441	p[wordsPerPoly-1] &= (1 << bitsInLastWord) - 1
442}
443
444func (p *poly2) cswap(other *poly2, swap uint) {
445	for i := range p {
446		sum := swap & (p[i] ^ other[i])
447		p[i] ^= sum
448		other[i] ^= sum
449	}
450}
451
452func (p *poly2) fmadd(m uint, in *poly2) {
453	m = ^(m - 1)
454
455	for i := range p {
456		p[i] ^= in[i] & m
457	}
458}
459
460func (p *poly2) lshift1() {
461	var carry uint
462	for i := range p {
463		nextCarry := p[i] >> (bitsPerWord - 1)
464		p[i] <<= 1
465		p[i] |= carry
466		carry = nextCarry
467	}
468}
469
470func (p *poly2) rshift1() {
471	var carry uint
472	for i := len(p) - 1; i >= 0; i-- {
473		nextCarry := p[i] & 1
474		p[i] >>= 1
475		p[i] |= carry << (bitsPerWord - 1)
476		carry = nextCarry
477	}
478}
479
480func (p *poly2) rot(bits uint) {
481	if bits > N {
482		panic("invalid")
483	}
484	var shifted [wordsPerPoly]uint
485	out := (*[wordsPerPoly]uint)(p)
486
487	shift := uint(9)
488	for ; (1 << shift) >= bitsPerWord; shift-- {
489		rotWords(&shifted, out, 1<<shift)
490		cmovWords(out, &shifted, lsbToAll(bits>>shift))
491	}
492	for ; shift < 9; shift-- {
493		rotBits(&shifted, out, 1<<shift)
494		cmovWords(out, &shifted, lsbToAll(bits>>shift))
495	}
496}
497
// poly is a polynomial with N coefficients, one per uint16. Index zero is the
// coefficient of the constant term. Depending on context the coefficients are
// values mod Q or small values mod 3.
type poly [N]uint16
499
500func (in *poly) marshal(out []byte) {
501	p := in[:]
502
503	for len(p) >= 8 {
504		out[0] = byte(p[0])
505		out[1] = byte(p[0]>>8) | byte((p[1]&0x07)<<5)
506		out[2] = byte(p[1] >> 3)
507		out[3] = byte(p[1]>>11) | byte((p[2]&0x3f)<<2)
508		out[4] = byte(p[2]>>6) | byte((p[3]&0x01)<<7)
509		out[5] = byte(p[3] >> 1)
510		out[6] = byte(p[3]>>9) | byte((p[4]&0x0f)<<4)
511		out[7] = byte(p[4] >> 4)
512		out[8] = byte(p[4]>>12) | byte((p[5]&0x7f)<<1)
513		out[9] = byte(p[5]>>7) | byte((p[6]&0x03)<<6)
514		out[10] = byte(p[6] >> 2)
515		out[11] = byte(p[6]>>10) | byte((p[7]&0x1f)<<3)
516		out[12] = byte(p[7] >> 5)
517
518		p = p[8:]
519		out = out[13:]
520	}
521
522	// There are four remaining values.
523	out[0] = byte(p[0])
524	out[1] = byte(p[0]>>8) | byte((p[1]&0x07)<<5)
525	out[2] = byte(p[1] >> 3)
526	out[3] = byte(p[1]>>11) | byte((p[2]&0x3f)<<2)
527	out[4] = byte(p[2]>>6) | byte((p[3]&0x01)<<7)
528	out[5] = byte(p[3] >> 1)
529	out[6] = byte(p[3] >> 9)
530}
531
532func (out *poly) unmarshal(in []byte) bool {
533	p := out[:]
534	for i := 0; i < 87; i++ {
535		p[0] = uint16(in[0]) | uint16(in[1]&0x1f)<<8
536		p[1] = uint16(in[1]>>5) | uint16(in[2])<<3 | uint16(in[3]&3)<<11
537		p[2] = uint16(in[3]>>2) | uint16(in[4]&0x7f)<<6
538		p[3] = uint16(in[4]>>7) | uint16(in[5])<<1 | uint16(in[6]&0xf)<<9
539		p[4] = uint16(in[6]>>4) | uint16(in[7])<<4 | uint16(in[8]&1)<<12
540		p[5] = uint16(in[8]>>1) | uint16(in[9]&0x3f)<<7
541		p[6] = uint16(in[9]>>6) | uint16(in[10])<<2 | uint16(in[11]&7)<<10
542		p[7] = uint16(in[11]>>3) | uint16(in[12])<<5
543
544		p = p[8:]
545		in = in[13:]
546	}
547
548	// There are four coefficients left over
549	p[0] = uint16(in[0]) | uint16(in[1]&0x1f)<<8
550	p[1] = uint16(in[1]>>5) | uint16(in[2])<<3 | uint16(in[3]&3)<<11
551	p[2] = uint16(in[3]>>2) | uint16(in[4]&0x7f)<<6
552	p[3] = uint16(in[4]>>7) | uint16(in[5])<<1 | uint16(in[6]&0xf)<<9
553
554	if in[6]&0xf0 != 0 {
555		return false
556	}
557
558	out[N-1] = 0
559	var top int
560	for _, v := range out {
561		top += int(v)
562	}
563
564	out[N-1] = uint16(-top) % Q
565	return true
566}
567
568func (in *poly) marshalS3(out []byte) {
569	p := in[:]
570	for len(p) >= 5 {
571		out[0] = byte(p[0] + p[1]*3 + p[2]*9 + p[3]*27 + p[4]*81)
572		out = out[1:]
573		p = p[5:]
574	}
575}
576
577func (out *poly) unmarshalS3(in []byte) bool {
578	p := out[:]
579	for i := 0; i < 140; i++ {
580		c := in[0]
581		if c >= 243 {
582			return false
583		}
584		p[0] = uint16(c % 3)
585		p[1] = uint16((c / 3) % 3)
586		p[2] = uint16((c / 9) % 3)
587		p[3] = uint16((c / 27) % 3)
588		p[4] = uint16((c / 81) % 3)
589
590		p = p[5:]
591		in = in[1:]
592	}
593
594	out[N-1] = 0
595	return true
596}
597
598func (p *poly) modPhiN() {
599	for i := range p {
600		p[i] = (p[i] + Q - p[N-1]) % Q
601	}
602}
603
604func (out *poly) shortSample(in []byte) {
605	//  b  a  result
606	// 00 00 00
607	// 00 01 01
608	// 00 10 10
609	// 00 11 11
610	// 01 00 10
611	// 01 01 00
612	// 01 10 01
613	// 01 11 11
614	// 10 00 01
615	// 10 01 10
616	// 10 10 00
617	// 10 11 11
618	// 11 00 11
619	// 11 01 11
620	// 11 10 11
621	// 11 11 11
622
623	// 1111 1111 1100 1001 1101 0010 1110 0100
624	//   f    f    c    9    d    2    e    4
625	const lookup = uint32(0xffc9d2e4)
626
627	p := out[:]
628	for i := 0; i < 87; i++ {
629		v := binary.LittleEndian.Uint32(in)
630		v2 := (v & 0x55555555) + ((v >> 1) & 0x55555555)
631		for j := 0; j < 8; j++ {
632			p[j] = uint16(lookup >> ((v2 & 15) << 1) & 3)
633			v2 >>= 4
634		}
635		p = p[8:]
636		in = in[4:]
637	}
638
639	// There are four values remaining.
640	v := binary.LittleEndian.Uint32(in)
641	v2 := (v & 0x55555555) + ((v >> 1) & 0x55555555)
642	for j := 0; j < 4; j++ {
643		p[j] = uint16(lookup >> ((v2 & 15) << 1) & 3)
644		v2 >>= 4
645	}
646
647	out[N-1] = 0
648}
649
650func (out *poly) shortSamplePlus(in []byte) {
651	out.shortSample(in)
652
653	var sum uint16
654	for i := 0; i < N-1; i++ {
655		sum += mod3ResultToModQ(out[i] * out[i+1])
656	}
657
658	scale := 1 + (1 & (sum >> 12))
659	for i := 0; i < len(out); i += 2 {
660		out[i] = (out[i] * scale) % 3
661	}
662}
663
// mul sets the first 2×len(a) elements of |out| to the product of the
// polynomials |a| and |b|, which must be the same length, with arithmetic
// performed mod 2^16. Inputs shorter than 32 elements are multiplied with the
// schoolbook method; longer ones are split in two and handled with one level
// of Karatsuba before recursing. |scratch| provides working space for the
// recursion and must not overlap any of the other arguments.
func mul(out, scratch, a, b []uint16) {
	const schoolbookLimit = 32
	n := len(a)

	if n < schoolbookLimit {
		for i := 0; i < 2*n; i++ {
			out[i] = 0
		}
		for i := range a {
			for j := range b {
				out[i+j] += a[i] * b[j]
			}
		}
		return
	}

	// Split both inputs into a low half and a high half. When the length
	// is odd the high half is the longer one.
	lo := n / 2
	hi := n - lo
	uneven := hi != lo

	aLo, aHi := a[:lo], a[lo:]
	bLo, bHi := b[:lo], b[lo:]

	// Use |out| as temporary space for (aLo + aHi) and (bLo + bHi).
	sumA, sumB := out[:hi], out[hi:2*hi]
	for i := 0; i < lo; i++ {
		sumA[i] = aHi[i] + aLo[i]
		sumB[i] = bHi[i] + bLo[i]
	}
	if uneven {
		sumA[lo] = aHi[lo]
		sumB[lo] = bHi[lo]
	}

	// scratch = (aLo + aHi)×(bLo + bHi). The remainder of |scratch| is
	// handed down as the working space for all three recursive calls.
	rest := scratch[2*hi:]
	mul(scratch, rest, sumA, sumB)
	// The high and low products go directly into their final positions,
	// which overwrites the temporary sums.
	mul(out[2*lo:], rest, aHi, bHi)
	mul(out, rest, aLo, bLo)

	// Subtract the low and high products from the middle term...
	for i := 0; i < 2*lo; i++ {
		scratch[i] -= out[i] + out[2*lo+i]
	}
	if uneven {
		scratch[2*lo] -= out[4*lo]
	}

	// ...and add what is left in at its offset.
	for i := 0; i < 2*hi; i++ {
		out[lo+i] += scratch[i]
	}
}
712
713func (out *poly) mul(a, b *poly) {
714	var prod, scratch [2 * N]uint16
715	mul(prod[:], scratch[:], a[:], b[:])
716	for i := range out {
717		out[i] = (prod[i] + prod[i+N]) % Q
718	}
719}
720
721func (p3 *poly3) mulMod3(x, y *poly3) {
722	// (��^n - 1) is a multiple of Φ(N) so we can work mod (��^n - 1) here and
723	// (reduce mod Φ(N) afterwards.
724	x3 := *x
725	y3 := *y
726	s := x3.s[:]
727	a := x3.a[:]
728	sw := s[0]
729	aw := a[0]
730	p3.zero()
731	var shift uint
732	for i := 0; i < N; i++ {
733		p3.fmadd(sw, aw, &y3)
734		sw >>= 1
735		aw >>= 1
736		shift++
737		if shift == bitsPerWord {
738			s = s[1:]
739			a = a[1:]
740			sw = s[0]
741			aw = a[0]
742			shift = 0
743		}
744		y3.mulx()
745	}
746	p3.modPhiN()
747}
748
// mod3ToModQ maps {0, 1, 2, 3} to {0, 1, Q-1, 0xffff}.
// An input of three should never occur but is mapped to a value that is out
// of range mod Q so that modQToMod3 can easily catch invalid inputs.
func mod3ToModQ(n uint16) uint16 {
	// The four results, 16 bits each, with the result for zero in the
	// least-significant position.
	const table = uint64(0xffff1fff00010000)
	shift := 16 * n
	return uint16(table >> shift)
}
755
756// modQToMod3 maps {0, 1, Q-1} to {(0, 0), (0, 1), (1, 0)} and also returns an int
757// which is one if the input is in range and zero otherwise.
758func modQToMod3(n uint16) (uint16, int) {
759	result := (n&3 - (n>>1)&1)
760	return result, subtle.ConstantTimeEq(int32(mod3ToModQ(result)), int32(n))
761}
762
// mod3ResultToModQ maps {0, 1, 2, 4} to {0, 1, Q-1, 1}. These inputs are the
// possible products of two values from {0, 1, 2}.
func mod3ResultToModQ(n uint16) uint16 {
	const (
		notNegative = uint16(0x13) // bits 0, 1 and 4: the inputs that don't map to Q-1.
		isOne       = uint16(0x12) // bits 1 and 4: the inputs that map to one.
	)

	// All thirteen low bits are set when the input maps to Q-1, and none
	// otherwise.
	negative := (((notNegative >> n) & 1) - 1) & 0x1fff
	one := (isOne >> n) & 1
	return negative | one
}
769
770// mulXMinus1 sets out to a×(�� - 1) mod (��^n - 1)
771func (out *poly) mulXMinus1() {
772	// Multiplying by (�� - 1) means negating each coefficient and adding in
773	// the value of the previous one.
774	origOut700 := out[700]
775
776	for i := N - 1; i > 0; i-- {
777		out[i] = (Q - out[i] + out[i-1]) % Q
778	}
779	out[0] = (Q - out[0] + origOut700) % Q
780}
781
// lift sets out to (x - 1) × (a/(x - 1) mod Φ(N) over GF(3)), with the result
// expressed mod Q. The coefficients of |a| are expected to be in {0, 1, 2}.
func (out *poly) lift(a *poly) {
	// We wish to calculate a/(x-1) mod Φ(N) over GF(3), where Φ(N) is the
	// Nth cyclotomic polynomial, i.e. 1 + x + … + x^700 (since N is prime).

	// 1/(x-1) has a fairly basic structure that we can exploit to speed this up:
	//
	// R.<x> = PolynomialRing(GF(3)…)
	// inv = R.cyclotomic_polynomial(1).inverse_mod(R.cyclotomic_polynomial(n))
	// list(inv)[:15]
	//   [1, 0, 2, 1, 0, 2, 1, 0, 2, 1, 0, 2, 1, 0, 2]
	//
	// This three-element pattern of coefficients repeats for the whole
	// polynomial.
	//
	// Next define the overbar operator such that z̅ = z[0] +
	// reverse(z[1:]). (Index zero of a polynomial here is the coefficient
	// of the constant term. So index one is the coefficient of x and so
	// on.)
	//
	// A less odd way to define this is to see that z̅ negates the indexes,
	// so z̅[0] = z[-0], z̅[1] = z[-1] and so on.
	//
	// The use of z̅ is that, when working mod (x^701 - 1), vz[0] = <v,
	// z̅>, vz[1] = <v, xz̅>, …. (Where <a, b> is the inner product: the sum
	// of the point-wise products.) Although we calculated the inverse mod
	// Φ(N), we can work mod (x^N - 1) and reduce mod Φ(N) at the end.
	// (That's because (x^N - 1) is a multiple of Φ(N).)
	//
	// When working mod (x^N - 1), multiplication by x is a right-rotation
	// of the list of coefficients.
	//
	// Thus we can consider what the pattern of z̅, xz̅, x^2z̅, … looks like:
	//
	// def reverse(xs):
	//   suffix = list(xs[1:])
	//   suffix.reverse()
	//   return [xs[0]] + suffix
	//
	// def rotate(xs):
	//   return [xs[-1]] + xs[:-1]
	//
	// zoverbar = reverse(list(inv) + [0])
	// xzoverbar = rotate(reverse(list(inv) + [0]))
	// x2zoverbar = rotate(rotate(reverse(list(inv) + [0])))
	//
	// zoverbar[:15]
	//   [1, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1]
	// xzoverbar[:15]
	//   [0, 1, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0]
	// x2zoverbar[:15]
	//   [2, 0, 1, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2]
	//
	// (For a formula for z̅, see lemma two of appendix B.)
	//
	// After the first three elements have been taken care of, all then have
	// a repeating three-element cycle. The next value (x^3z̅) involves
	// three rotations of the first pattern, thus the three-element cycle
	// lines up. However, the discontinuity in the first three elements
	// obviously moves to a different position. Consider the difference
	// between x^3z̅ and z̅:
	//
	// [x-y for (x,y) in zip(zoverbar, x3zoverbar)][:15]
	//    [0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
	//
	// This pattern of differences is the same for all elements, although it
	// obviously moves right with the rotations.
	//
	// From this, we reach algorithm eight of appendix B.

	// Handle the first three elements of the inner products.
	out[0] = a[0] + a[2]
	out[1] = a[1]
	out[2] = 2*a[0] + a[2]

	// Use the repeating pattern to complete the first three inner products.
	for i := 3; i < 699; i += 3 {
		out[0] += 2*a[i] + a[i+2]
		out[1] += a[i] + 2*a[i+1]
		out[2] += a[i+1] + 2*a[i+2]
	}

	// Handle the fact that the three-element pattern doesn't fill the
	// polynomial exactly (since 701 isn't a multiple of three).
	out[2] += a[700]
	out[0] += 2 * a[699]
	out[1] += a[699] + 2*a[700]

	out[0] = out[0] % 3
	out[1] = out[1] % 3
	out[2] = out[2] % 3

	// Calculate the remaining inner products by taking advantage of the
	// fact that the pattern repeats every three cycles and the pattern of
	// differences moves with the rotation.
	for i := 3; i < N; i++ {
		// Adding twice something is the same as subtracting when working
		// mod 3. Doing it this way avoids underflow. Underflow is bad
		// because "% 3" doesn't work correctly for negative numbers
		// here since underflow will wrap to 2^16-1 and 2^16 isn't a
		// multiple of three.
		out[i] = (out[i-3] + 2*(a[i-2]+a[i-1]+a[i])) % 3
	}

	// Reduce mod Φ(N) by subtracting a multiple of out[700] from every
	// element and convert to mod Q. (See above about adding twice as
	// subtraction.)
	v := out[700] * 2
	for i := range out {
		out[i] = mod3ToModQ((out[i] + v) % 3)
	}

	// Multiply back up by (x - 1), now working mod Q.
	out.mulXMinus1()
}
895
896func (a *poly) cswap(b *poly, swap uint16) {
897	for i := range a {
898		sum := swap & (a[i] ^ b[i])
899		a[i] ^= sum
900		b[i] ^= sum
901	}
902}
903
// lt returns a word of all ones if a < b and zero otherwise.
//
// It is computed without branching on the inputs, in keeping with the other
// mask-producing helpers in this file, because callers pass it values (such
// as polynomial degrees during inversion) that depend on secret data.
func lt(a, b uint) uint {
	// This is the standard branch-free unsigned comparison: the top bit
	// of the expression is set exactly when a < b. If a and b differ in
	// their top bit then the top bit of (a ^ b) is set and the result's
	// top bit is the complement of a's, i.e. b's. Otherwise the top bit
	// of a - b indicates whether the subtraction borrowed.
	borrow := a ^ ((a ^ b) | ((a - b) ^ a))
	// Negating the extracted top bit turns one into all ones.
	return -(borrow >> (bits.UintSize - 1))
}
910
// bsMul multiplies the bitsliced GF(3) values (s1, a1) and (s2, a2),
// returning the product as (s3, a3). Every bit position of the words is an
// independent multiplication.
func bsMul(s1, a1, s2, a2 uint) (s3, a3 uint) {
	return (a1 & s2) ^ (s1 & a2), (a1 & a2) ^ (s1 & s2)
}
916
// invertMod3 sets out to the inverse of |in| over GF(3), modulo Φ(N). The
// main loop always runs for a fixed number of iterations and uses masks
// rather than data-dependent control flow to decide what each iteration does.
func (out *poly3) invertMod3(in *poly3) {
	// This algorithm follows algorithm 10 in the paper. (Although note that
	// the paper appears to have a bug: k should start at zero, not one.)
	// The best explanation for why it works is in the "Why it works"
	// section of
	// https://assets.onboardsecurity.com/static/downloads/NTRU/resources/NTRUTech014.pdf.
	var k uint
	degF, degG := uint(N-1), uint(N-1)

	var b, c, g poly3
	f := *in

	// g starts with every coefficient equal to one.
	for i := range g.a {
		g.a[i] = ^uint(0)
	}

	// b starts as the constant polynomial one and c as zero.
	b.a[0] = 1

	// f0s and f0a record the constant term of f as of the last iteration
	// that was still doing useful work.
	var f0s, f0a uint
	// stillGoing is all ones until degF reaches zero, after which the
	// remaining iterations have no effect on the result.
	stillGoing := ^uint(0)
	for i := 0; i < 2*(N-1)-1; i++ {
		// Multiply the constant terms of f and g and negate the product
		// (by swapping its s and a bits) to get the multiple of g that
		// cancels the constant term of f.
		ss, sa := bsMul(f.s[0], f.a[0], g.s[0], g.a[0])
		ss, sa = sa&stillGoing&1, ss&stillGoing&1
		// Swap f and g (and b and c) if that multiple is non-zero and
		// degF < degG.
		shouldSwap := ^uint(int((ss|sa)-1)>>(bitsPerWord-1)) & lt(degF, degG)
		f.cswap(&g, shouldSwap)
		b.cswap(&c, shouldSwap)
		degF, degG = (degG&shouldSwap)|(degF & ^shouldSwap), (degF&shouldSwap)|(degG&^shouldSwap)
		f.fmadd(ss, sa, &g)
		b.fmadd(ss, sa, &c)

		// Divide f by x, keeping the top coefficient position clear, and
		// multiply c by x, keeping its constant term clear (mulx wraps
		// the top coefficient around).
		f.divx()
		f.s[wordsPerPoly-1] &= ((1 << bitsInLastWord) - 1) >> 1
		f.a[wordsPerPoly-1] &= ((1 << bitsInLastWord) - 1) >> 1
		c.mulx()
		c.s[0] &= ^uint(1)
		c.a[0] &= ^uint(1)

		degF--
		// k counts the iterations performed while still going.
		k += 1 & stillGoing
		f0s = (stillGoing & f.s[0]) | (^stillGoing & f0s)
		f0a = (stillGoing & f.a[0]) | (^stillGoing & f0a)
		// All ones if degF >= 1 (as a signed value) and zero otherwise.
		stillGoing = ^uint(int(degF-1) >> (bitsPerWord - 1))
	}

	// Reduce k by N if it exceeds N, then rotate b by k positions, scale it
	// by the recorded constant term of f and reduce mod Φ(N).
	k -= N & lt(N, k)
	*out = b
	out.rot(k)
	out.mulConst(f0s, f0a)
	out.modPhiN()
}
967
// invertMod2 sets out to the inverse of |a| over GF(2), modulo Φ(N), using
// only the least-significant bit of each coefficient of |a|. Every coefficient
// of the result is zero or one. The main loop always runs for a fixed number
// of iterations and uses masks rather than data-dependent control flow to
// decide what each iteration does.
func (out *poly) invertMod2(a *poly) {
	// This algorithm follows mix of algorithm 10 in the paper and the first
	// page of the PDF linked below. (Although note that the paper appears
	// to have a bug: k should start at zero, not one.) The best explanation
	// for why it works is in the "Why it works" section of
	// https://assets.onboardsecurity.com/static/downloads/NTRU/resources/NTRUTech014.pdf.
	var k uint
	degF, degG := uint(N-1), uint(N-1)

	var f poly2
	f.fromDiscrete(a)
	// g starts with every coefficient equal to one, b as the constant
	// polynomial one and c as zero.
	var b, c, g poly2
	g.setPhiN()
	b[0] = 1

	// stillGoing is all ones until degF reaches zero, after which the
	// remaining iterations have no effect on the result.
	stillGoing := ^uint(0)
	for i := 0; i < 2*(N-1)-1; i++ {
		// s is the constant term of f, or zero once finished.
		s := uint(f[0]&1) & stillGoing
		// Swap f and g (and b and c) if s is one and degF < degG.
		shouldSwap := ^(s - 1) & lt(degF, degG)
		f.cswap(&g, shouldSwap)
		b.cswap(&c, shouldSwap)
		degF, degG = (degG&shouldSwap)|(degF & ^shouldSwap), (degF&shouldSwap)|(degG&^shouldSwap)
		// Cancel the constant term of f by adding g if needed.
		f.fmadd(s, &g)
		b.fmadd(s, &c)

		// Divide f by x and multiply c by x.
		f.rshift1()
		c.lshift1()

		degF--
		// k counts the iterations performed while still going.
		k += 1 & stillGoing
		// All ones if degF >= 1 (as a signed value) and zero otherwise.
		stillGoing = ^uint(int(degF-1) >> (bitsPerWord - 1))
	}

	// Reduce k by N if it exceeds N, then rotate b by k positions.
	k -= N & lt(N, k)
	b.rot(k)
	out.fromMod2(&b)
}
1005
1006func (out *poly) invert(origA *poly) {
1007	// Inversion mod Q, which is done based on the result of inverting mod
1008	// 2. See the NTRU paper, page three.
1009	var a, tmp, tmp2, b poly
1010	b.invertMod2(origA)
1011
1012	// Negate a.
1013	for i := range a {
1014		a[i] = Q - origA[i]
1015	}
1016
1017	// We are working mod Q=2**13 and we need to iterate ceil(log_2(13))
1018	// times, which is four.
1019	for i := 0; i < 4; i++ {
1020		tmp.mul(&a, &b)
1021		tmp[0] += 2
1022		tmp2.mul(&b, &tmp)
1023		b = tmp2
1024	}
1025
1026	*out = b
1027}
1028
// PublicKey is an HRSS public key, which consists of a single polynomial
// mod Q.
type PublicKey struct {
	h poly // the public polynomial; ciphertexts are r×h + lift(m).
}
1032
1033func ParsePublicKey(in []byte) (*PublicKey, bool) {
1034	ret := new(PublicKey)
1035	if !ret.h.unmarshal(in) {
1036		return nil, false
1037	}
1038	return ret, true
1039}
1040
1041func (pub *PublicKey) Marshal() []byte {
1042	ret := make([]byte, modQBytes)
1043	pub.h.marshal(ret)
1044	return ret
1045}
1046
1047func (pub *PublicKey) Encap(rand io.Reader) (ciphertext []byte, sharedKey []byte) {
1048	var randBytes [352 + 352]byte
1049	if _, err := io.ReadFull(rand, randBytes[:]); err != nil {
1050		panic("rand failed")
1051	}
1052
1053	var m, r poly
1054	m.shortSample(randBytes[:352])
1055	r.shortSample(randBytes[352:])
1056
1057	var mBytes, rBytes [mod3Bytes]byte
1058	m.marshalS3(mBytes[:])
1059	r.marshalS3(rBytes[:])
1060
1061	ciphertext = pub.owf(&m, &r)
1062
1063	h := sha256.New()
1064	h.Write([]byte("shared key\x00"))
1065	h.Write(mBytes[:])
1066	h.Write(rBytes[:])
1067	h.Write(ciphertext)
1068	sharedKey = h.Sum(nil)
1069
1070	return ciphertext, sharedKey
1071}
1072
1073func (pub *PublicKey) owf(m, r *poly) []byte {
1074	for i := range r {
1075		r[i] = mod3ToModQ(r[i])
1076	}
1077
1078	var mq poly
1079	mq.lift(m)
1080
1081	var e poly
1082	e.mul(r, &pub.h)
1083	for i := range e {
1084		e[i] = (e[i] + mq[i]) % Q
1085	}
1086
1087	ret := make([]byte, modQBytes)
1088	e.marshal(ret[:])
1089	return ret
1090}
1091
// PrivateKey is an HRSS private key. It embeds the corresponding public key.
type PrivateKey struct {
	PublicKey
	// f is the secret polynomial and fp is its inverse mod 3 (computed
	// with invertMod3).
	f, fp   poly3
	// hInv is multiplied by (ciphertext - lift(m)) during decapsulation in
	// order to recover r.
	hInv    poly
	// hmacKey keys the HMAC of the ciphertext that replaces the shared key
	// when decapsulation detects an invalid ciphertext.
	hmacKey [32]byte
}
1098
1099func (priv *PrivateKey) Marshal() []byte {
1100	var ret [2*mod3Bytes + modQBytes]byte
1101	priv.f.marshal(ret[:])
1102	priv.fp.marshal(ret[mod3Bytes:])
1103	priv.h.marshal(ret[2*mod3Bytes:])
1104	return ret[:]
1105}
1106
1107func (priv *PrivateKey) Decap(ciphertext []byte) (sharedKey []byte, ok bool) {
1108	if len(ciphertext) != modQBytes {
1109		return nil, false
1110	}
1111
1112	var e poly
1113	if !e.unmarshal(ciphertext) {
1114		return nil, false
1115	}
1116
1117	var f poly
1118	f.fromMod3ToModQ(&priv.f)
1119
1120	var v1, m poly
1121	v1.mul(&e, &f)
1122
1123	var v13 poly3
1124	v13.fromDiscreteMod3(&v1)
1125	// Note: v13 is not reduced mod phi(n).
1126
1127	var m3 poly3
1128	m3.mulMod3(&v13, &priv.fp)
1129	m3.modPhiN()
1130	m.fromMod3(&m3)
1131
1132	var mLift, delta poly
1133	mLift.lift(&m)
1134	for i := range delta {
1135		delta[i] = (e[i] - mLift[i] + Q) % Q
1136	}
1137	delta.mul(&delta, &priv.hInv)
1138	delta.modPhiN()
1139
1140	var r poly3
1141	allOk := r.fromModQ(&delta)
1142
1143	var mBytes, rBytes [mod3Bytes]byte
1144	m.marshalS3(mBytes[:])
1145	r.marshal(rBytes[:])
1146
1147	var rPoly poly
1148	rPoly.fromMod3(&r)
1149	expectedCiphertext := priv.PublicKey.owf(&m, &rPoly)
1150
1151	allOk &= subtle.ConstantTimeCompare(ciphertext, expectedCiphertext)
1152
1153	hmacHash := hmac.New(sha256.New, priv.hmacKey[:])
1154	hmacHash.Write(ciphertext)
1155	hmacDigest := hmacHash.Sum(nil)
1156
1157	h := sha256.New()
1158	h.Write([]byte("shared key\x00"))
1159	h.Write(mBytes[:])
1160	h.Write(rBytes[:])
1161	h.Write(ciphertext)
1162	sharedKey = h.Sum(nil)
1163
1164	mask := uint8(allOk - 1)
1165	for i := range sharedKey {
1166		sharedKey[i] = (sharedKey[i] & ^mask) | (hmacDigest[i] & mask)
1167	}
1168
1169	return sharedKey, true
1170}
1171
1172func GenerateKey(rand io.Reader) PrivateKey {
1173	var randBytes [352 + 352]byte
1174	if _, err := io.ReadFull(rand, randBytes[:]); err != nil {
1175		panic("rand failed")
1176	}
1177
1178	var f poly
1179	f.shortSamplePlus(randBytes[:352])
1180	var priv PrivateKey
1181	priv.f.fromDiscrete(&f)
1182	priv.fp.invertMod3(&priv.f)
1183
1184	var g poly
1185	g.shortSamplePlus(randBytes[352:])
1186
1187	var pgPhi1 poly
1188	for i := range g {
1189		pgPhi1[i] = mod3ToModQ(g[i])
1190	}
1191	for i := range pgPhi1 {
1192		pgPhi1[i] = (pgPhi1[i] * 3) % Q
1193	}
1194	pgPhi1.mulXMinus1()
1195
1196	var fModQ poly
1197	fModQ.fromMod3ToModQ(&priv.f)
1198
1199	var pfgPhi1 poly
1200	pfgPhi1.mul(&fModQ, &pgPhi1)
1201
1202	var i poly
1203	i.invert(&pfgPhi1)
1204
1205	priv.h.mul(&i, &pgPhi1)
1206	priv.h.mul(&priv.h, &pgPhi1)
1207
1208	priv.hInv.mul(&i, &fModQ)
1209	priv.hInv.mul(&priv.hInv, &fModQ)
1210
1211	return priv
1212}
1213