1// Copyright (c) 2020, Google Inc.
2//
3// Permission to use, copy, modify, and/or distribute this software for any
4// purpose with or without fee is hereby granted, provided that the above
5// copyright notice and this permission notice appear in all copies.
6//
7// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
8// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
9// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
10// SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
11// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
12// OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
13// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
14
15// Package hpke implements Hybrid Public Key Encryption (HPKE).
16//
17// See https://tools.ietf.org/html/draft-irtf-cfrg-hpke-07.
18package hpke
19
20import (
21	"crypto/aes"
22	"crypto/cipher"
23	"encoding/binary"
24	"errors"
25
26	"golang.org/x/crypto/chacha20poly1305"
27)
28
// KEM scheme IDs. The KEM ID is one of the three 16-bit values that
// buildSuiteID encodes into the suite ID.
const (
	X25519WithHKDFSHA256 uint16 = 0x0020 // the only KEM used by the Setup*X25519 functions in this file
)
33
// HPKE AEAD IDs. These are the aeadID values accepted by newAEAD and
// expectedKeyLength.
const (
	AES128GCM        uint16 = 0x0001 // AES-GCM with a 16-byte key
	AES256GCM        uint16 = 0x0002 // AES-GCM with a 32-byte key
	ChaCha20Poly1305 uint16 = 0x0003 // chacha20poly1305 with a chacha20poly1305.KeySize-byte key
)
40
// HPKE KDF IDs. A KDF ID is passed as kdfID to the setup functions, where it
// is encoded into the suite ID and handed to getKDFHash to pick the hash.
const (
	HKDFSHA256 uint16 = 0x0001
	HKDFSHA384 uint16 = 0x0002
	HKDFSHA512 uint16 = 0x0003
)
47
// Internal constants. The mode is validated by keySchedule and written as the
// first byte of the key schedule context.
const (
	hpkeModeBase uint8 = 0 // no PSK inputs may be supplied
	hpkeModePSK  uint8 = 1 // both psk and pskID must be non-empty
)
53
// GenerateKeyPairFunc returns a public/secret key pair, or an error. The
// sender setup functions accept one as ephemKeygen and pass it on to
// x25519Encap.
type GenerateKeyPairFunc func() (public []byte, secret []byte, e error)
55
// Context holds the HPKE state for a sender or a receiver. It is created by
// keySchedule and mutated by Seal and Open, which advance seq.
type Context struct {
	kemID  uint16 // KEM ID, encoded into the suite ID
	kdfID  uint16 // KDF ID, encoded into the suite ID and used to pick the hash
	aeadID uint16 // AEAD ID, encoded into the suite ID

	aead cipher.AEAD // AEAD instance built from key by newAEAD

	key            []byte // AEAD key derived by keySchedule
	baseNonce      []byte // XORed with seq by computeNonce to form each nonce
	seq            uint64 // incremented after each Seal and each successful Open
	exporterSecret []byte // input keying material for Export
}
69
70// SetupBaseSenderX25519 corresponds to the spec's SetupBaseS(), but only
71// supports X25519.
72func SetupBaseSenderX25519(kdfID, aeadID uint16, publicKeyR, info []byte, ephemKeygen GenerateKeyPairFunc) (context *Context, enc []byte, err error) {
73	sharedSecret, enc, err := x25519Encap(publicKeyR, ephemKeygen)
74	if err != nil {
75		return nil, nil, err
76	}
77	context, err = keySchedule(hpkeModeBase, X25519WithHKDFSHA256, kdfID, aeadID, sharedSecret, info, nil, nil)
78	return
79}
80
81// SetupBaseReceiverX25519 corresponds to the spec's SetupBaseR(), but only
82// supports X25519.
83func SetupBaseReceiverX25519(kdfID, aeadID uint16, enc, secretKeyR, info []byte) (context *Context, err error) {
84	sharedSecret, err := x25519Decap(enc, secretKeyR)
85	if err != nil {
86		return nil, err
87	}
88	return keySchedule(hpkeModeBase, X25519WithHKDFSHA256, kdfID, aeadID, sharedSecret, info, nil, nil)
89}
90
91// SetupPSKSenderX25519 corresponds to the spec's SetupPSKS(), but only supports
92// X25519.
93func SetupPSKSenderX25519(kdfID, aeadID uint16, publicKeyR, info, psk, pskID []byte, ephemKeygen GenerateKeyPairFunc) (context *Context, enc []byte, err error) {
94	sharedSecret, enc, err := x25519Encap(publicKeyR, ephemKeygen)
95	if err != nil {
96		return nil, nil, err
97	}
98	context, err = keySchedule(hpkeModePSK, X25519WithHKDFSHA256, kdfID, aeadID, sharedSecret, info, psk, pskID)
99	return
100}
101
102// SetupPSKReceiverX25519 corresponds to the spec's SetupPSKR(), but only
103// supports X25519.
104func SetupPSKReceiverX25519(kdfID, aeadID uint16, enc, secretKeyR, info, psk, pskID []byte) (context *Context, err error) {
105	sharedSecret, err := x25519Decap(enc, secretKeyR)
106	if err != nil {
107		return nil, err
108	}
109	context, err = keySchedule(hpkeModePSK, X25519WithHKDFSHA256, kdfID, aeadID, sharedSecret, info, psk, pskID)
110	if err != nil {
111		return nil, err
112	}
113	return context, nil
114}
115
116func (c *Context) Seal(additionalData, plaintext []byte) []byte {
117	ciphertext := c.aead.Seal(nil, c.computeNonce(), plaintext, additionalData)
118	c.incrementSeq()
119	return ciphertext
120}
121
122func (c *Context) Open(additionalData, ciphertext []byte) ([]byte, error) {
123	plaintext, err := c.aead.Open(nil, c.computeNonce(), ciphertext, additionalData)
124	if err != nil {
125		return nil, err
126	}
127	c.incrementSeq()
128	return plaintext, nil
129}
130
131func (c *Context) Export(exporterContext []byte, length int) []byte {
132	suiteID := buildSuiteID(c.kemID, c.kdfID, c.aeadID)
133	kdfHash := getKDFHash(c.kdfID)
134	return labeledExpand(kdfHash, c.exporterSecret, suiteID, []byte("sec"), exporterContext, length)
135}
136
137func buildSuiteID(kemID, kdfID, aeadID uint16) []byte {
138	ret := make([]byte, 0, 10)
139	ret = append(ret, "HPKE"...)
140	ret = appendBigEndianUint16(ret, kemID)
141	ret = appendBigEndianUint16(ret, kdfID)
142	ret = appendBigEndianUint16(ret, aeadID)
143	return ret
144}
145
146func newAEAD(aeadID uint16, key []byte) (cipher.AEAD, error) {
147	if len(key) != expectedKeyLength(aeadID) {
148		return nil, errors.New("wrong key length for specified AEAD")
149	}
150	switch aeadID {
151	case AES128GCM, AES256GCM:
152		block, err := aes.NewCipher(key)
153		if err != nil {
154			return nil, err
155		}
156		aead, err := cipher.NewGCM(block)
157		if err != nil {
158			return nil, err
159		}
160		return aead, nil
161	case ChaCha20Poly1305:
162		aead, err := chacha20poly1305.New(key)
163		if err != nil {
164			return nil, err
165		}
166		return aead, nil
167	}
168	return nil, errors.New("unsupported AEAD")
169}
170
171func keySchedule(mode uint8, kemID, kdfID, aeadID uint16, sharedSecret, info, psk, pskID []byte) (*Context, error) {
172	// Verify the PSK inputs.
173	switch mode {
174	case hpkeModeBase:
175		if len(psk) > 0 || len(pskID) > 0 {
176			panic("unnecessary psk inputs were provided")
177		}
178	case hpkeModePSK:
179		if len(psk) == 0 || len(pskID) == 0 {
180			panic("missing psk inputs")
181		}
182	default:
183		panic("unknown mode")
184	}
185
186	kdfHash := getKDFHash(kdfID)
187	suiteID := buildSuiteID(kemID, kdfID, aeadID)
188	pskIDHash := labeledExtract(kdfHash, nil, suiteID, []byte("psk_id_hash"), pskID)
189	infoHash := labeledExtract(kdfHash, nil, suiteID, []byte("info_hash"), info)
190
191	keyScheduleContext := make([]byte, 0)
192	keyScheduleContext = append(keyScheduleContext, mode)
193	keyScheduleContext = append(keyScheduleContext, pskIDHash...)
194	keyScheduleContext = append(keyScheduleContext, infoHash...)
195
196	secret := labeledExtract(kdfHash, sharedSecret, suiteID, []byte("secret"), psk)
197	key := labeledExpand(kdfHash, secret, suiteID, []byte("key"), keyScheduleContext, expectedKeyLength(aeadID))
198
199	aead, err := newAEAD(aeadID, key)
200	if err != nil {
201		return nil, err
202	}
203
204	baseNonce := labeledExpand(kdfHash, secret, suiteID, []byte("base_nonce"), keyScheduleContext, aead.NonceSize())
205	exporterSecret := labeledExpand(kdfHash, secret, suiteID, []byte("exp"), keyScheduleContext, kdfHash.Size())
206
207	return &Context{
208		kemID:          kemID,
209		kdfID:          kdfID,
210		aeadID:         aeadID,
211		aead:           aead,
212		key:            key,
213		baseNonce:      baseNonce,
214		seq:            0,
215		exporterSecret: exporterSecret,
216	}, nil
217}
218
219func (c Context) computeNonce() []byte {
220	nonce := make([]byte, len(c.baseNonce))
221	// Write the big-endian |c.seq| value at the *end* of |baseNonce|.
222	binary.BigEndian.PutUint64(nonce[len(nonce)-8:], c.seq)
223	// XOR the big-endian |seq| with |c.baseNonce|.
224	for i, b := range c.baseNonce {
225		nonce[i] ^= b
226	}
227	return nonce
228}
229
230func (c *Context) incrementSeq() {
231	c.seq++
232	if c.seq == 0 {
233		panic("sequence overflow")
234	}
235}
236
237func expectedKeyLength(aeadID uint16) int {
238	switch aeadID {
239	case AES128GCM:
240		return 128 / 8
241	case AES256GCM:
242		return 256 / 8
243	case ChaCha20Poly1305:
244		return chacha20poly1305.KeySize
245	}
246	panic("unsupported AEAD")
247}
248
// appendBigEndianUint16 appends v to b as two bytes, most significant byte
// first, and returns the extended slice.
func appendBigEndianUint16(b []byte, v uint16) []byte {
	hi := byte(v >> 8)
	lo := byte(v)
	return append(b, hi, lo)
}
252