1// Copyright (c) 2016, Google Inc.
2//
3// Permission to use, copy, modify, and/or distribute this software for any
4// purpose with or without fee is hereby granted, provided that the above
5// copyright notice and this permission notice appear in all copies.
6//
7// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
8// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
9// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
10// SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
11// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
12// OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
13// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
14
15package runner
16
17import (
18	"bytes"
19	"crypto/aes"
20	"crypto/cipher"
21	"crypto/hmac"
22	"crypto/sha256"
23	"encoding/asn1"
24	"errors"
25)
26
// TestShimTicketKey is the ticket key the shim is assumed to be configured
// with. It is the concatenation of a 16-byte key name, a 16-byte HMAC-SHA256
// key and a 16-byte AES-128 key, and is initialized to all zeros.
var TestShimTicketKey = make([]byte, 16+16+16)
29
30func DecryptShimTicket(in []byte) ([]byte, error) {
31	name := TestShimTicketKey[:16]
32	macKey := TestShimTicketKey[16:32]
33	encKey := TestShimTicketKey[32:48]
34
35	h := hmac.New(sha256.New, macKey)
36
37	block, err := aes.NewCipher(encKey)
38	if err != nil {
39		panic(err)
40	}
41
42	if len(in) < len(name)+block.BlockSize()+1+h.Size() {
43		return nil, errors.New("tls: shim ticket too short")
44	}
45
46	// Check the key name.
47	if !bytes.Equal(name, in[:len(name)]) {
48		return nil, errors.New("tls: shim ticket name mismatch")
49	}
50
51	// Check the MAC at the end of the ticket.
52	mac := in[len(in)-h.Size():]
53	in = in[:len(in)-h.Size()]
54	h.Write(in)
55	if !hmac.Equal(mac, h.Sum(nil)) {
56		return nil, errors.New("tls: shim ticket MAC mismatch")
57	}
58
59	// The MAC covers the key name, but the encryption does not.
60	in = in[len(name):]
61
62	// Decrypt in-place.
63	iv := in[:block.BlockSize()]
64	in = in[block.BlockSize():]
65	if l := len(in); l == 0 || l%block.BlockSize() != 0 {
66		return nil, errors.New("tls: ticket ciphertext not a multiple of the block size")
67	}
68	out := make([]byte, len(in))
69	cbc := cipher.NewCBCDecrypter(block, iv)
70	cbc.CryptBlocks(out, in)
71
72	// Remove the padding.
73	pad := int(out[len(out)-1])
74	if pad == 0 || pad > block.BlockSize() || pad > len(in) {
75		return nil, errors.New("tls: bad shim ticket CBC pad")
76	}
77
78	for i := 0; i < pad; i++ {
79		if out[len(out)-1-i] != byte(pad) {
80			return nil, errors.New("tls: bad shim ticket CBC pad")
81		}
82	}
83
84	return out[:len(out)-pad], nil
85}
86
87func EncryptShimTicket(in []byte) []byte {
88	name := TestShimTicketKey[:16]
89	macKey := TestShimTicketKey[16:32]
90	encKey := TestShimTicketKey[32:48]
91
92	h := hmac.New(sha256.New, macKey)
93
94	block, err := aes.NewCipher(encKey)
95	if err != nil {
96		panic(err)
97	}
98
99	// Use the zero IV for rewritten tickets.
100	iv := make([]byte, block.BlockSize())
101	cbc := cipher.NewCBCEncrypter(block, iv)
102	pad := block.BlockSize() - (len(in) % block.BlockSize())
103
104	out := make([]byte, 0, len(name)+len(iv)+len(in)+pad+h.Size())
105	out = append(out, name...)
106	out = append(out, iv...)
107	out = append(out, in...)
108	for i := 0; i < pad; i++ {
109		out = append(out, byte(pad))
110	}
111
112	ciphertext := out[len(name)+len(iv):]
113	cbc.CryptBlocks(ciphertext, ciphertext)
114
115	h.Write(out)
116	return h.Sum(out)
117}
118
// asn1Constructed is the "constructed" bit (0x20) of a DER identifier octet.
// OR it with a universal tag number, such as asn1.TagSequence, to get the
// identifier octet of the constructed encoding.
const asn1Constructed = 1 << 5
120
// parseDERElement parses one DER element from the front of in. It returns the
// element's identifier octet (tag), its contents (body) and the bytes that
// follow the element (rest). body and rest are subslices of in, so writes to
// them modify in.
//
// Only the subset of DER needed here is supported: single-octet (low-form)
// tags and definite lengths. Long-form tags, indefinite lengths,
// non-minimally-encoded lengths and truncated input are all rejected. On any
// failure, ok is false and the other results are zero values.
func parseDERElement(in []byte) (tag byte, body, rest []byte, ok bool) {
	// Every element needs at least an identifier octet and a length octet.
	if len(in) < 2 {
		return 0, nil, nil, false
	}

	tag = in[0]
	if tag&0x1f == 0x1f {
		// Long-form (multi-octet) tags are not supported.
		return 0, nil, nil, false
	}

	length := int(in[1])
	rest = in[2:]
	if length&0x80 != 0 {
		// Long-form length: the low seven bits give the number of big-endian
		// length octets that follow.
		numOctets := length & 0x7f
		if numOctets == 0 {
			// 0x80 denotes an indefinite length, which DER forbids.
			return 0, nil, nil, false
		}
		if len(rest) < numOctets {
			return 0, nil, nil, false
		}

		length = 0
		for _, b := range rest[:numOctets] {
			if length == 0 && b == 0 {
				// A leading zero octet means the length is not minimally
				// encoded.
				return 0, nil, nil, false
			}
			if (length<<8)>>8 != length {
				// Shifting in another octet would overflow int.
				return 0, nil, nil, false
			}
			length = length<<8 | int(b)
		}
		rest = rest[numOctets:]

		if length < 0x80 {
			// Lengths below 0x80 must use the single-octet short form.
			return 0, nil, nil, false
		}
	}

	if len(rest) < length {
		return 0, nil, nil, false
	}
	return tag, rest[:length], rest[length:], true
}
179
180func SetShimTicketVersion(in []byte, vers uint16) ([]byte, error) {
181	plaintext, err := DecryptShimTicket(in)
182	if err != nil {
183		return nil, err
184	}
185
186	tag, session, _, ok := parseDERElement(plaintext)
187	if !ok || tag != asn1.TagSequence|asn1Constructed {
188		return nil, errors.New("tls: could not decode shim session")
189	}
190
191	// Skip the session version.
192	tag, _, session, ok = parseDERElement(session)
193	if !ok || tag != asn1.TagInteger {
194		return nil, errors.New("tls: could not decode shim session")
195	}
196
197	// Next field is the protocol version.
198	tag, version, _, ok := parseDERElement(session)
199	if !ok || tag != asn1.TagInteger {
200		return nil, errors.New("tls: could not decode shim session")
201	}
202
203	// This code assumes both old and new versions are encoded in two
204	// bytes. This isn't quite right as INTEGERs are minimally-encoded, but
205	// we do not need to support other caess for now.
206	if len(version) != 2 || vers < 0x80 || vers >= 0x8000 {
207		return nil, errors.New("tls: unsupported version in shim session")
208	}
209
210	version[0] = byte(vers >> 8)
211	version[1] = byte(vers)
212
213	return EncryptShimTicket(plaintext), nil
214}
215
216func SetShimTicketCipherSuite(in []byte, id uint16) ([]byte, error) {
217	plaintext, err := DecryptShimTicket(in)
218	if err != nil {
219		return nil, err
220	}
221
222	tag, session, _, ok := parseDERElement(plaintext)
223	if !ok || tag != asn1.TagSequence|asn1Constructed {
224		return nil, errors.New("tls: could not decode shim session")
225	}
226
227	// Skip the session version.
228	tag, _, session, ok = parseDERElement(session)
229	if !ok || tag != asn1.TagInteger {
230		return nil, errors.New("tls: could not decode shim session")
231	}
232
233	// Skip the protocol version.
234	tag, _, session, ok = parseDERElement(session)
235	if !ok || tag != asn1.TagInteger {
236		return nil, errors.New("tls: could not decode shim session")
237	}
238
239	// Next field is the cipher suite.
240	tag, cipherSuite, _, ok := parseDERElement(session)
241	if !ok || tag != asn1.TagOctetString || len(cipherSuite) != 2 {
242		return nil, errors.New("tls: could not decode shim session")
243	}
244
245	cipherSuite[0] = byte(id >> 8)
246	cipherSuite[1] = byte(id)
247
248	return EncryptShimTicket(plaintext), nil
249}
250