1// Copyright (c) 2020, Google Inc.
2//
3// Permission to use, copy, modify, and/or distribute this software for any
4// purpose with or without fee is hereby granted, provided that the above
5// copyright notice and this permission notice appear in all copies.
6//
7// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
8// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
9// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
10// SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
11// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
12// OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
13// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
14
15package hpke
16
17import (
18	"bytes"
19	_ "crypto/sha256"
20	_ "crypto/sha512"
21	"encoding/hex"
22	"encoding/json"
23	"errors"
24	"flag"
25	"fmt"
26	"io/ioutil"
27	"path/filepath"
28	"testing"
29)
30
// exportOnlyAEAD is the aead_id value that marks an export-only context in
// test-vectors.json. TestVectors skips every vector that carries it.
const exportOnlyAEAD uint16 = 0xffff
34
35var (
36	testDataDir = flag.String("testdata", "testdata", "The path to the test vector JSON file.")
37)
38
39// Simple round-trip test for fixed inputs.
40func TestRoundTrip(t *testing.T) {
41	publicKeyR, secretKeyR, err := GenerateKeyPair()
42	if err != nil {
43		t.Errorf("failed to generate key pair: %s", err)
44		return
45	}
46
47	// Set up the sender and receiver contexts.
48	senderContext, enc, err := SetupBaseSenderX25519(HKDFSHA256, AES256GCM, publicKeyR, nil, nil)
49	if err != nil {
50		t.Errorf("failed to set up sender: %s", err)
51		return
52	}
53	receiverContext, err := SetupBaseReceiverX25519(HKDFSHA256, AES256GCM, enc, secretKeyR, nil)
54	if err != nil {
55		t.Errorf("failed to set up receiver: %s", err)
56		return
57	}
58
59	// Seal() our plaintext with the sender context, then Open() the
60	// ciphertext with the receiver context.
61	plaintext := []byte("foobar")
62	ciphertext := senderContext.Seal(nil, plaintext)
63	decrypted, err := receiverContext.Open(nil, ciphertext)
64	if err != nil {
65		t.Errorf("encryption round trip failed: %s", err)
66		return
67	}
68	checkBytesEqual(t, "decrypted", decrypted, plaintext)
69}
70
71// HpkeTestVector defines the subset of test-vectors.json that we read.
72type HpkeTestVector struct {
73	KEM         uint16                 `json:"kem_id"`
74	Mode        uint8                  `json:"mode"`
75	KDF         uint16                 `json:"kdf_id"`
76	AEAD        uint16                 `json:"aead_id"`
77	Info        HexString              `json:"info"`
78	PSK         HexString              `json:"psk"`
79	PSKID       HexString              `json:"psk_id"`
80	SecretKeyR  HexString              `json:"skRm"`
81	SecretKeyE  HexString              `json:"skEm"`
82	PublicKeyR  HexString              `json:"pkRm"`
83	PublicKeyE  HexString              `json:"pkEm"`
84	Enc         HexString              `json:"enc"`
85	Encryptions []EncryptionTestVector `json:"encryptions"`
86	Exports     []ExportTestVector     `json:"exports"`
87}
88type EncryptionTestVector struct {
89	Plaintext      HexString `json:"plaintext"`
90	AdditionalData HexString `json:"aad"`
91	Ciphertext     HexString `json:"ciphertext"`
92}
93type ExportTestVector struct {
94	ExportContext HexString `json:"exporter_context"`
95	ExportLength  int       `json:"L"`
96	ExportValue   HexString `json:"exported_value"`
97}
98
99// TestVectors checks all relevant test vectors in test-vectors.json.
100func TestVectors(t *testing.T) {
101	jsonStr, err := ioutil.ReadFile(filepath.Join(*testDataDir, "test-vectors.json"))
102	if err != nil {
103		t.Errorf("error reading test vectors: %s", err)
104		return
105	}
106
107	var testVectors []HpkeTestVector
108	err = json.Unmarshal(jsonStr, &testVectors)
109	if err != nil {
110		t.Errorf("error parsing test vectors: %s", err)
111		return
112	}
113
114	var numSkippedTests = 0
115
116	for testNum, testVec := range testVectors {
117		// Skip this vector if it specifies an unsupported parameter.
118		if testVec.KEM != X25519WithHKDFSHA256 ||
119			(testVec.Mode != hpkeModeBase && testVec.Mode != hpkeModePSK) ||
120			testVec.AEAD == exportOnlyAEAD {
121			numSkippedTests++
122			continue
123		}
124
125		testVec := testVec // capture the range variable
126		t.Run(fmt.Sprintf("test%d,Mode=%d,KDF=%d,AEAD=%d", testNum, testVec.Mode, testVec.KDF, testVec.AEAD), func(t *testing.T) {
127			var senderContext *Context
128			var receiverContext *Context
129			var enc []byte
130			var err error
131
132			switch testVec.Mode {
133			case hpkeModeBase:
134				senderContext, enc, err = SetupBaseSenderX25519(testVec.KDF, testVec.AEAD, testVec.PublicKeyR, testVec.Info,
135					func() ([]byte, []byte, error) {
136						return testVec.PublicKeyE, testVec.SecretKeyE, nil
137					})
138				if err != nil {
139					t.Errorf("failed to set up sender: %s", err)
140					return
141				}
142				checkBytesEqual(t, "sender enc", enc, testVec.Enc)
143
144				receiverContext, err = SetupBaseReceiverX25519(testVec.KDF, testVec.AEAD, enc, testVec.SecretKeyR, testVec.Info)
145				if err != nil {
146					t.Errorf("failed to set up receiver: %s", err)
147					return
148				}
149			case hpkeModePSK:
150				senderContext, enc, err = SetupPSKSenderX25519(testVec.KDF, testVec.AEAD, testVec.PublicKeyR, testVec.Info, testVec.PSK, testVec.PSKID,
151					func() ([]byte, []byte, error) {
152						return testVec.PublicKeyE, testVec.SecretKeyE, nil
153					})
154				if err != nil {
155					t.Errorf("failed to set up sender: %s", err)
156					return
157				}
158				checkBytesEqual(t, "sender enc", enc, testVec.Enc)
159
160				receiverContext, err = SetupPSKReceiverX25519(testVec.KDF, testVec.AEAD, enc, testVec.SecretKeyR, testVec.Info, testVec.PSK, testVec.PSKID)
161				if err != nil {
162					t.Errorf("failed to set up receiver: %s", err)
163					return
164				}
165			default:
166				panic("unsupported mode")
167			}
168
169			for encryptionNum, e := range testVec.Encryptions {
170				ciphertext := senderContext.Seal(e.AdditionalData, e.Plaintext)
171				checkBytesEqual(t, "ciphertext", ciphertext, e.Ciphertext)
172
173				decrypted, err := receiverContext.Open(e.AdditionalData, ciphertext)
174				if err != nil {
175					t.Errorf("decryption %d failed: %s", encryptionNum, err)
176					return
177				}
178				checkBytesEqual(t, "decrypted plaintext", decrypted, e.Plaintext)
179			}
180
181			for _, ex := range testVec.Exports {
182				exportValue := senderContext.Export(ex.ExportContext, ex.ExportLength)
183				checkBytesEqual(t, "exportValue", exportValue, ex.ExportValue)
184
185				exportValue = receiverContext.Export(ex.ExportContext, ex.ExportLength)
186				checkBytesEqual(t, "exportValue", exportValue, ex.ExportValue)
187			}
188		})
189	}
190
191	if numSkippedTests == len(testVectors) {
192		panic("no test vectors were used")
193	}
194}
195
// HexString is a byte string that unmarshals from a JSON string of hex
// digits, such as "00ff10".
type HexString []byte

// UnmarshalJSON decodes a JSON string of hex digits into h.
//
// Behavior for other inputs:
//   - JSON null is a no-op, following the encoding/json convention for
//     Unmarshaler implementations.
//   - Any other non-string JSON value returns a "missing double quotes" error.
//   - Invalid hex returns the error from encoding/hex.
//
// On any error, h is left unchanged.
func (h *HexString) UnmarshalJSON(data []byte) error {
	if string(data) == "null" {
		return nil
	}
	if len(data) < 2 || data[0] != '"' || data[len(data)-1] != '"' {
		return errors.New("missing double quotes")
	}
	// Let encoding/json resolve any escape sequences in the string literal
	// instead of handing the raw quoted bytes to the hex decoder.
	var s string
	if err := json.Unmarshal(data, &s); err != nil {
		return err
	}
	// hex.DecodeString returns partially decoded bytes on error, so only
	// assign to *h once decoding has fully succeeded.
	decoded, err := hex.DecodeString(s)
	if err != nil {
		return err
	}
	*h = decoded
	return nil
}
207
// checkBytesEqual reports a test error, without stopping the test, when
// actual and expected differ. A nil slice and an empty slice compare equal.
// name labels the value in the failure message.
func checkBytesEqual(t *testing.T, name string, actual, expected []byte) {
	t.Helper() // attribute failures to the caller's line, not this helper
	if !bytes.Equal(actual, expected) {
		t.Errorf("%s = %x; want %x", name, actual, expected)
	}
}
213