1// Copyright 2012 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package runner
6
7import (
8	"bytes"
9	"crypto/aes"
10	"crypto/cipher"
11	"crypto/hmac"
12	"crypto/sha256"
13	"crypto/subtle"
14	"errors"
15	"io"
16)
17
// sessionState contains the information that is serialized into a session
// ticket in order to later resume a connection. The marshal and unmarshal
// methods define its big-endian byte encoding; byte-slice fields filled in by
// unmarshal are sub-slices of the buffer it was given, not copies.
type sessionState struct {
	// vers is the version recorded for the session (2 bytes when serialized).
	vers                 uint16
	// cipherSuite is the cipher suite identifier recorded for the session
	// (2 bytes when serialized).
	cipherSuite          uint16
	// masterSecret is serialized with a 16-bit length prefix.
	masterSecret         []byte
	// handshakeHash is serialized with a 16-bit length prefix.
	// NOTE(review): how it is consumed on resumption is not visible here —
	// confirm against callers.
	handshakeHash        []byte
	// certificates is serialized as a 16-bit count followed by one 32-bit
	// length-prefixed entry per certificate. Presumably the peer's chain in
	// DER form — confirm against callers.
	certificates         [][]byte
	// extendedMasterSecret is serialized as a single trailing byte
	// (1 for true, 0 for false).
	extendedMasterSecret bool
}
28
29func (s *sessionState) equal(i interface{}) bool {
30	s1, ok := i.(*sessionState)
31	if !ok {
32		return false
33	}
34
35	if s.vers != s1.vers ||
36		s.cipherSuite != s1.cipherSuite ||
37		!bytes.Equal(s.masterSecret, s1.masterSecret) ||
38		!bytes.Equal(s.handshakeHash, s1.handshakeHash) ||
39		s.extendedMasterSecret != s1.extendedMasterSecret {
40		return false
41	}
42
43	if len(s.certificates) != len(s1.certificates) {
44		return false
45	}
46
47	for i := range s.certificates {
48		if !bytes.Equal(s.certificates[i], s1.certificates[i]) {
49			return false
50		}
51	}
52
53	return true
54}
55
56func (s *sessionState) marshal() []byte {
57	length := 2 + 2 + 2 + len(s.masterSecret) + 2 + len(s.handshakeHash) + 2
58	for _, cert := range s.certificates {
59		length += 4 + len(cert)
60	}
61	length++
62
63	ret := make([]byte, length)
64	x := ret
65	x[0] = byte(s.vers >> 8)
66	x[1] = byte(s.vers)
67	x[2] = byte(s.cipherSuite >> 8)
68	x[3] = byte(s.cipherSuite)
69	x[4] = byte(len(s.masterSecret) >> 8)
70	x[5] = byte(len(s.masterSecret))
71	x = x[6:]
72	copy(x, s.masterSecret)
73	x = x[len(s.masterSecret):]
74
75	x[0] = byte(len(s.handshakeHash) >> 8)
76	x[1] = byte(len(s.handshakeHash))
77	x = x[2:]
78	copy(x, s.handshakeHash)
79	x = x[len(s.handshakeHash):]
80
81	x[0] = byte(len(s.certificates) >> 8)
82	x[1] = byte(len(s.certificates))
83	x = x[2:]
84
85	for _, cert := range s.certificates {
86		x[0] = byte(len(cert) >> 24)
87		x[1] = byte(len(cert) >> 16)
88		x[2] = byte(len(cert) >> 8)
89		x[3] = byte(len(cert))
90		copy(x[4:], cert)
91		x = x[4+len(cert):]
92	}
93
94	if s.extendedMasterSecret {
95		x[0] = 1
96	}
97	x = x[1:]
98
99	return ret
100}
101
102func (s *sessionState) unmarshal(data []byte) bool {
103	if len(data) < 8 {
104		return false
105	}
106
107	s.vers = uint16(data[0])<<8 | uint16(data[1])
108	s.cipherSuite = uint16(data[2])<<8 | uint16(data[3])
109	masterSecretLen := int(data[4])<<8 | int(data[5])
110	data = data[6:]
111	if len(data) < masterSecretLen {
112		return false
113	}
114
115	s.masterSecret = data[:masterSecretLen]
116	data = data[masterSecretLen:]
117
118	if len(data) < 2 {
119		return false
120	}
121
122	handshakeHashLen := int(data[0])<<8 | int(data[1])
123	data = data[2:]
124	if len(data) < handshakeHashLen {
125		return false
126	}
127
128	s.handshakeHash = data[:handshakeHashLen]
129	data = data[handshakeHashLen:]
130
131	if len(data) < 2 {
132		return false
133	}
134
135	numCerts := int(data[0])<<8 | int(data[1])
136	data = data[2:]
137
138	s.certificates = make([][]byte, numCerts)
139	for i := range s.certificates {
140		if len(data) < 4 {
141			return false
142		}
143		certLen := int(data[0])<<24 | int(data[1])<<16 | int(data[2])<<8 | int(data[3])
144		data = data[4:]
145		if certLen < 0 {
146			return false
147		}
148		if len(data) < certLen {
149			return false
150		}
151		s.certificates[i] = data[:certLen]
152		data = data[certLen:]
153	}
154
155	if len(data) < 1 {
156		return false
157	}
158
159	s.extendedMasterSecret = false
160	if data[0] == 1 {
161		s.extendedMasterSecret = true
162	}
163	data = data[1:]
164
165	if len(data) > 0 {
166		return false
167	}
168
169	return true
170}
171
172func (c *Conn) encryptTicket(state *sessionState) ([]byte, error) {
173	serialized := state.marshal()
174	encrypted := make([]byte, aes.BlockSize+len(serialized)+sha256.Size)
175	iv := encrypted[:aes.BlockSize]
176	macBytes := encrypted[len(encrypted)-sha256.Size:]
177
178	if _, err := io.ReadFull(c.config.rand(), iv); err != nil {
179		return nil, err
180	}
181	block, err := aes.NewCipher(c.config.SessionTicketKey[:16])
182	if err != nil {
183		return nil, errors.New("tls: failed to create cipher while encrypting ticket: " + err.Error())
184	}
185	cipher.NewCTR(block, iv).XORKeyStream(encrypted[aes.BlockSize:], serialized)
186
187	mac := hmac.New(sha256.New, c.config.SessionTicketKey[16:32])
188	mac.Write(encrypted[:len(encrypted)-sha256.Size])
189	mac.Sum(macBytes[:0])
190
191	return encrypted, nil
192}
193
194func (c *Conn) decryptTicket(encrypted []byte) (*sessionState, bool) {
195	if len(encrypted) < aes.BlockSize+sha256.Size {
196		return nil, false
197	}
198
199	iv := encrypted[:aes.BlockSize]
200	macBytes := encrypted[len(encrypted)-sha256.Size:]
201
202	mac := hmac.New(sha256.New, c.config.SessionTicketKey[16:32])
203	mac.Write(encrypted[:len(encrypted)-sha256.Size])
204	expected := mac.Sum(nil)
205
206	if subtle.ConstantTimeCompare(macBytes, expected) != 1 {
207		return nil, false
208	}
209
210	block, err := aes.NewCipher(c.config.SessionTicketKey[:16])
211	if err != nil {
212		return nil, false
213	}
214	ciphertext := encrypted[aes.BlockSize : len(encrypted)-sha256.Size]
215	plaintext := make([]byte, len(ciphertext))
216	cipher.NewCTR(block, iv).XORKeyStream(plaintext, ciphertext)
217
218	state := new(sessionState)
219	ok := state.unmarshal(plaintext)
220	return state, ok
221}
222