1package main
2
3import (
4	"crypto"
5	"crypto/aes"
6	"crypto/cipher"
7	"crypto/des"
8	"crypto/hmac"
9	_ "crypto/md5"
10	"crypto/rc4"
11	_ "crypto/sha1"
12	_ "crypto/sha256"
13	_ "crypto/sha512"
14	"encoding/hex"
15	"flag"
16	"fmt"
17	"os"
18)
19
// Command-line flags selecting which legacy AEAD construction to generate
// test vectors for.
var (
	// bulkCipher names the record cipher; getKeySize recognizes "rc4",
	// "aes128", "aes256" and "3des".
	bulkCipher = flag.String("cipher", "", "The bulk cipher to use")
	// mac names the hash used by the record MAC; getHash recognizes "md5",
	// "sha1", "sha256" and "sha384".
	mac = flag.String("mac", "", "The hash function to use in the MAC")
	// implicitIV makes CBC tests carry the IV at the end of KEY instead of
	// emitting it as NONCE.
	implicitIV = flag.Bool("implicit-iv", false, "If true, generate tests for a cipher using a pre-TLS-1.0 implicit IV")
	// ssl3 switches from the TLS HMAC and padding rules to the SSLv3 ones.
	ssl3 = flag.Bool("ssl3", false, "If true, use the SSLv3 MAC and padding rather than TLS")
)
24
// rc4Stream is a deterministic source of pseudorandom bytes: an RC4 keystream
// keyed by a fixed seed. Using it keeps this generator idempotent, so repeated
// runs emit byte-identical test files.
type rc4Stream struct {
	cipher *rc4.Cipher
}

// newRc4Stream returns a stream keyed with the bytes of seed. It passes
// through the error from rc4.NewCipher, e.g. for an empty seed.
func newRc4Stream(seed string) (*rc4Stream, error) {
	c, err := rc4.NewCipher([]byte(seed))
	if err != nil {
		return nil, err
	}
	return &rc4Stream{cipher: c}, nil
}

// fillBytes overwrites p with the next len(p) keystream bytes, discarding
// whatever p held before.
func (rs *rc4Stream) fillBytes(p []byte) {
	for i := 0; i < len(p); i++ {
		p[i] = 0
	}
	// XOR-ing an all-zero buffer with the keystream leaves the raw keystream.
	rs.cipher.XORKeyStream(p, p)
}
45
// getHash maps a -mac flag value to its crypto.Hash. For an unrecognized name
// it reports false and returns the zero Hash.
func getHash(name string) (crypto.Hash, bool) {
	known := map[string]crypto.Hash{
		"md5":    crypto.MD5,
		"sha1":   crypto.SHA1,
		"sha256": crypto.SHA256,
		"sha384": crypto.SHA384,
	}
	h, found := known[name]
	return h, found
}
60
// getKeySize returns the encryption key length in bytes for the named bulk
// cipher, or 0 when the name is not recognized.
func getKeySize(name string) int {
	size := 0
	switch name {
	case "rc4", "aes128":
		size = 16
	case "aes256":
		size = 32
	case "3des":
		size = 24
	}
	return size
}
75
// newBlockCipher keys the block cipher named by a -cipher value ("aes128",
// "aes256" or "3des") with key. Key-length checks are left to the underlying
// constructor; any other name, including "rc4", yields an error.
func newBlockCipher(name string, key []byte) (cipher.Block, error) {
	if name == "aes128" || name == "aes256" {
		return aes.NewCipher(key)
	}
	if name == "3des" {
		return des.NewTripleDESCipher(key)
	}
	return nil, fmt.Errorf("unknown cipher '%s'", name)
}
88
// ssl30Pad1 and ssl30Pad2 are the SSLv3 MAC pad constants: 48 copies of 0x36
// (used in the inner hash) and 0x5c (used in the outer hash).
var (
	ssl30Pad1 = ssl30Pad(0x36)
	ssl30Pad2 = ssl30Pad(0x5c)
)

// ssl30Pad returns a 48-byte array with every element set to b.
func ssl30Pad(b byte) [48]byte {
	var pad [48]byte
	for i := range pad {
		pad[i] = b
	}
	return pad
}

// ssl30MAC computes the SSLv3 record MAC of ad followed by input:
//
//	hash(key || pad2 || hash(key || pad1 || ad || input))
//
// Both pads are cut to 40 bytes when the hash output is 20 bytes (as for
// SHA-1) and used in full (48 bytes) otherwise. The caller in this file only
// passes MD5 or SHA-1. hash must be linked into the binary or hash.New panics.
func ssl30MAC(hash crypto.Hash, key, input, ad []byte) []byte {
	n := 48
	if hash.Size() == 20 {
		n = 40
	}

	inner := hash.New()
	for _, part := range [][]byte{key, ssl30Pad1[:n], ad, input} {
		inner.Write(part)
	}
	innerSum := inner.Sum(nil)

	outer := hash.New()
	for _, part := range [][]byte{key, ssl30Pad2[:n], innerSum} {
		outer.Write(part)
	}
	return outer.Sum(nil)
}
112
// testCase holds one generated test vector; printTestCase prints each field
// as one line of the test file.
type testCase struct {
	digest     []byte // unencrypted MAC, printed only as a "# DIGEST" comment
	key        []byte // MAC key || encryption key || fixed IV (when present)
	nonce      []byte // explicit CBC IV; empty for RC4, implicit-IV and SSLv3 modes
	input      []byte // plaintext
	ad         []byte // random MAC header prefix, without the trailing 2-byte length
	ciphertext []byte // sealed record minus its last hash-size bytes
	tag        []byte // last hash-size bytes of the sealed record

	noSeal bool // record uses non-canonical padding; printed as NO_SEAL
	fails  bool // record is expected to be rejected; printed as FAILS
}
124
// options tweaks the CBC padding of a generated test case. The zero value
// yields an ordinary, minimally padded record.
type options struct {
	// extraPadding appends one additional full block of padding.
	extraPadding bool
	// wrongPadding corrupts one of the padding bytes.
	wrongPadding bool
	// noPadding omits padding entirely. The plaintext plus MAC must already be
	// a multiple of the block size.
	noPadding bool
}
135
136func makeTestCase(length int, options options) (*testCase, error) {
137	rand, err := newRc4Stream("input stream")
138	if err != nil {
139		return nil, err
140	}
141
142	input := make([]byte, length)
143	rand.fillBytes(input)
144
145	var adFull []byte
146	if *ssl3 {
147		adFull = make([]byte, 11)
148	} else {
149		adFull = make([]byte, 13)
150	}
151	ad := adFull[:len(adFull)-2]
152	rand.fillBytes(ad)
153	adFull[len(adFull)-2] = uint8(length >> 8)
154	adFull[len(adFull)-1] = uint8(length & 0xff)
155
156	hash, ok := getHash(*mac)
157	if !ok {
158		return nil, fmt.Errorf("unknown hash function '%s'", *mac)
159	}
160
161	macKey := make([]byte, hash.Size())
162	rand.fillBytes(macKey)
163
164	var digest []byte
165	if *ssl3 {
166		if hash != crypto.SHA1 && hash != crypto.MD5 {
167			return nil, fmt.Errorf("invalid hash for SSLv3: '%s'", *mac)
168		}
169		digest = ssl30MAC(hash, macKey, input, adFull)
170	} else {
171		h := hmac.New(hash.New, macKey)
172		h.Write(adFull)
173		h.Write(input)
174		digest = h.Sum(nil)
175	}
176
177	size := getKeySize(*bulkCipher)
178	if size == 0 {
179		return nil, fmt.Errorf("unknown cipher '%s'", *bulkCipher)
180	}
181	encKey := make([]byte, size)
182	rand.fillBytes(encKey)
183
184	var fixedIV []byte
185	var nonce []byte
186	var sealed []byte
187	var noSeal, fails bool
188	if *bulkCipher == "rc4" {
189		if *implicitIV {
190			return nil, fmt.Errorf("implicit IV enabled on a stream cipher")
191		}
192
193		stream, err := rc4.NewCipher(encKey)
194		if err != nil {
195			return nil, err
196		}
197
198		sealed = make([]byte, 0, len(input)+len(digest))
199		sealed = append(sealed, input...)
200		sealed = append(sealed, digest...)
201		stream.XORKeyStream(sealed, sealed)
202	} else {
203		block, err := newBlockCipher(*bulkCipher, encKey)
204		if err != nil {
205			return nil, err
206		}
207
208		iv := make([]byte, block.BlockSize())
209		rand.fillBytes(iv)
210		if *implicitIV || *ssl3 {
211			fixedIV = iv
212		} else {
213			nonce = iv
214		}
215
216		cbc := cipher.NewCBCEncrypter(block, iv)
217
218		sealed = make([]byte, 0, len(input)+len(digest)+cbc.BlockSize())
219		sealed = append(sealed, input...)
220		sealed = append(sealed, digest...)
221		paddingLen := cbc.BlockSize() - (len(sealed) % cbc.BlockSize())
222		if options.noPadding {
223			if paddingLen != cbc.BlockSize() {
224				return nil, fmt.Errorf("invalid length for noPadding")
225			}
226			noSeal = true
227			fails = true
228		} else {
229			if options.extraPadding {
230				paddingLen += cbc.BlockSize()
231				noSeal = true
232				if *ssl3 {
233					// SSLv3 padding must be minimal.
234					fails = true
235				}
236			}
237			if *ssl3 {
238				sealed = append(sealed, make([]byte, paddingLen-1)...)
239				sealed = append(sealed, byte(paddingLen-1))
240			} else {
241				pad := make([]byte, paddingLen)
242				for i := range pad {
243					pad[i] = byte(paddingLen - 1)
244				}
245				sealed = append(sealed, pad...)
246			}
247			if options.wrongPadding && paddingLen > 1 {
248				sealed[len(sealed)-2]++
249				noSeal = true
250				if !*ssl3 {
251					// TLS specifies the all the padding bytes.
252					fails = true
253				}
254			}
255		}
256		cbc.CryptBlocks(sealed, sealed)
257	}
258
259	key := make([]byte, 0, len(macKey)+len(encKey)+len(fixedIV))
260	key = append(key, macKey...)
261	key = append(key, encKey...)
262	key = append(key, fixedIV...)
263	t := &testCase{
264		digest:     digest,
265		key:        key,
266		nonce:      nonce,
267		input:      input,
268		ad:         ad,
269		ciphertext: sealed[:len(sealed)-hash.Size()],
270		tag:        sealed[len(sealed)-hash.Size():],
271		noSeal:     noSeal,
272		fails:      fails,
273	}
274	return t, nil
275}
276
277func printTestCase(t *testCase) {
278	fmt.Printf("# DIGEST: %s\n", hex.EncodeToString(t.digest))
279	fmt.Printf("KEY: %s\n", hex.EncodeToString(t.key))
280	fmt.Printf("NONCE: %s\n", hex.EncodeToString(t.nonce))
281	fmt.Printf("IN: %s\n", hex.EncodeToString(t.input))
282	fmt.Printf("AD: %s\n", hex.EncodeToString(t.ad))
283	fmt.Printf("CT: %s\n", hex.EncodeToString(t.ciphertext))
284	fmt.Printf("TAG: %s\n", hex.EncodeToString(t.tag))
285	if t.noSeal {
286		fmt.Printf("NO_SEAL: 01\n")
287	}
288	if t.fails {
289		fmt.Printf("FAILS: 01\n")
290	}
291}
292
293func main() {
294	flag.Parse()
295
296	commandLine := fmt.Sprintf("go run make_legacy_aead_tests.go -cipher %s -mac %s", *bulkCipher, *mac)
297	if *implicitIV {
298		commandLine += " -implicit-iv"
299	}
300	if *ssl3 {
301		commandLine += " -ssl3"
302	}
303	fmt.Printf("# Generated by\n")
304	fmt.Printf("#   %s\n", commandLine)
305	fmt.Printf("#\n")
306	fmt.Printf("# Note: aead_test's input format splits the ciphertext and tag positions of the sealed\n")
307	fmt.Printf("# input. But these legacy AEADs are MAC-then-encrypt and may include padding, so this\n")
308	fmt.Printf("# split isn't meaningful. The unencrypted MAC is included in the 'DIGEST' tag above\n")
309	fmt.Printf("# each test case.\n")
310	fmt.Printf("\n")
311
312	// For CBC-mode ciphers, emit tests for padding flexibility.
313	if *bulkCipher != "rc4" {
314		fmt.Printf("# Test with non-minimal padding.\n")
315		t, err := makeTestCase(5, options{extraPadding: true})
316		if err != nil {
317			fmt.Fprintf(os.Stderr, "%s\n", err)
318			os.Exit(1)
319		}
320		printTestCase(t)
321		fmt.Printf("\n")
322
323		fmt.Printf("# Test with bad padding values.\n")
324		t, err = makeTestCase(5, options{wrongPadding: true})
325		if err != nil {
326			fmt.Fprintf(os.Stderr, "%s\n", err)
327			os.Exit(1)
328		}
329		printTestCase(t)
330		fmt.Printf("\n")
331
332		fmt.Printf("# Test with no padding.\n")
333		hash, ok := getHash(*mac)
334		if !ok {
335			panic("unknown hash")
336		}
337		t, err = makeTestCase(64-hash.Size(), options{noPadding: true})
338		if err != nil {
339			fmt.Fprintf(os.Stderr, "%s\n", err)
340			os.Exit(1)
341		}
342		printTestCase(t)
343		fmt.Printf("\n")
344	}
345
346	// Generate long enough of input to cover a non-zero num_starting_blocks
347	// value in the constant-time CBC logic.
348	for l := 0; l < 500; l += 5 {
349		t, err := makeTestCase(l, options{})
350		if err != nil {
351			fmt.Fprintf(os.Stderr, "%s\n", err)
352			os.Exit(1)
353		}
354		printTestCase(t)
355		fmt.Printf("\n")
356	}
357}
358