1// Copyright 2012 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package runner
6
7import (
8	"crypto/aes"
9	"crypto/cipher"
10	"crypto/hmac"
11	"crypto/sha256"
12	"crypto/subtle"
13	"encoding/binary"
14	"errors"
15	"io"
16	"time"
17)
18
// sessionState contains the information that is serialized into a session
// ticket in order to later resume a connection.
//
// marshal and unmarshal define the wire encoding of each field; the byte
// slices populated by unmarshal alias the buffer that was parsed.
type sessionState struct {
	// vers is encoded as a 16-bit value. The ticket* fields below are only
	// part of the encoding when vers >= VersionTLS13.
	vers                 uint16
	// cipherSuite is encoded as a 16-bit value.
	cipherSuite          uint16
	// masterSecret is encoded with a 16-bit length prefix.
	masterSecret         []byte
	// handshakeHash is encoded with a 16-bit length prefix.
	handshakeHash        []byte
	// certificates is encoded as a 16-bit count followed by each entry with
	// a 32-bit length prefix.
	certificates         [][]byte
	// extendedMasterSecret is encoded as a single byte, 1 for true and 0
	// for false.
	extendedMasterSecret bool
	// earlyALPN is encoded with a 16-bit length prefix and is the last field
	// of the encoding.
	earlyALPN            []byte
	// ticketCreationTime is encoded as 64-bit Unix nanoseconds (TLS 1.3 and
	// later only).
	ticketCreationTime   time.Time
	// ticketExpiration is encoded as 64-bit Unix nanoseconds (TLS 1.3 and
	// later only).
	ticketExpiration     time.Time
	// ticketFlags is encoded as a 32-bit value (TLS 1.3 and later only).
	ticketFlags          uint32
	// ticketAgeAdd is encoded as a 32-bit value (TLS 1.3 and later only).
	ticketAgeAdd         uint32
}
34
35func (s *sessionState) marshal() []byte {
36	msg := newByteBuilder()
37	msg.addU16(s.vers)
38	msg.addU16(s.cipherSuite)
39	masterSecret := msg.addU16LengthPrefixed()
40	masterSecret.addBytes(s.masterSecret)
41	handshakeHash := msg.addU16LengthPrefixed()
42	handshakeHash.addBytes(s.handshakeHash)
43	msg.addU16(uint16(len(s.certificates)))
44	for _, cert := range s.certificates {
45		certMsg := msg.addU32LengthPrefixed()
46		certMsg.addBytes(cert)
47	}
48
49	if s.extendedMasterSecret {
50		msg.addU8(1)
51	} else {
52		msg.addU8(0)
53	}
54
55	if s.vers >= VersionTLS13 {
56		msg.addU64(uint64(s.ticketCreationTime.UnixNano()))
57		msg.addU64(uint64(s.ticketExpiration.UnixNano()))
58		msg.addU32(s.ticketFlags)
59		msg.addU32(s.ticketAgeAdd)
60	}
61
62	earlyALPN := msg.addU16LengthPrefixed()
63	earlyALPN.addBytes(s.earlyALPN)
64
65	return msg.finish()
66}
67
68func (s *sessionState) unmarshal(data []byte) bool {
69	if len(data) < 8 {
70		return false
71	}
72
73	s.vers = uint16(data[0])<<8 | uint16(data[1])
74	s.cipherSuite = uint16(data[2])<<8 | uint16(data[3])
75	masterSecretLen := int(data[4])<<8 | int(data[5])
76	data = data[6:]
77	if len(data) < masterSecretLen {
78		return false
79	}
80
81	s.masterSecret = data[:masterSecretLen]
82	data = data[masterSecretLen:]
83
84	if len(data) < 2 {
85		return false
86	}
87
88	handshakeHashLen := int(data[0])<<8 | int(data[1])
89	data = data[2:]
90	if len(data) < handshakeHashLen {
91		return false
92	}
93
94	s.handshakeHash = data[:handshakeHashLen]
95	data = data[handshakeHashLen:]
96
97	if len(data) < 2 {
98		return false
99	}
100
101	numCerts := int(data[0])<<8 | int(data[1])
102	data = data[2:]
103
104	s.certificates = make([][]byte, numCerts)
105	for i := range s.certificates {
106		if len(data) < 4 {
107			return false
108		}
109		certLen := int(data[0])<<24 | int(data[1])<<16 | int(data[2])<<8 | int(data[3])
110		data = data[4:]
111		if certLen < 0 {
112			return false
113		}
114		if len(data) < certLen {
115			return false
116		}
117		s.certificates[i] = data[:certLen]
118		data = data[certLen:]
119	}
120
121	if len(data) < 1 {
122		return false
123	}
124
125	s.extendedMasterSecret = false
126	if data[0] == 1 {
127		s.extendedMasterSecret = true
128	}
129	data = data[1:]
130
131	if s.vers >= VersionTLS13 {
132		if len(data) < 24 {
133			return false
134		}
135		s.ticketCreationTime = time.Unix(0, int64(binary.BigEndian.Uint64(data)))
136		data = data[8:]
137		s.ticketExpiration = time.Unix(0, int64(binary.BigEndian.Uint64(data)))
138		data = data[8:]
139		s.ticketFlags = binary.BigEndian.Uint32(data)
140		data = data[4:]
141		s.ticketAgeAdd = binary.BigEndian.Uint32(data)
142		data = data[4:]
143	}
144
145	earlyALPNLen := int(data[0])<<8 | int(data[1])
146	data = data[2:]
147	if len(data) < earlyALPNLen {
148		return false
149	}
150	s.earlyALPN = data[:earlyALPNLen]
151	data = data[earlyALPNLen:]
152
153	if len(data) > 0 {
154		return false
155	}
156
157	return true
158}
159
160func (c *Conn) encryptTicket(state *sessionState) ([]byte, error) {
161	serialized := state.marshal()
162	encrypted := make([]byte, aes.BlockSize+len(serialized)+sha256.Size)
163	iv := encrypted[:aes.BlockSize]
164	macBytes := encrypted[len(encrypted)-sha256.Size:]
165
166	if _, err := io.ReadFull(c.config.rand(), iv); err != nil {
167		return nil, err
168	}
169	block, err := aes.NewCipher(c.config.SessionTicketKey[:16])
170	if err != nil {
171		return nil, errors.New("tls: failed to create cipher while encrypting ticket: " + err.Error())
172	}
173	cipher.NewCTR(block, iv).XORKeyStream(encrypted[aes.BlockSize:], serialized)
174
175	mac := hmac.New(sha256.New, c.config.SessionTicketKey[16:32])
176	mac.Write(encrypted[:len(encrypted)-sha256.Size])
177	mac.Sum(macBytes[:0])
178
179	return encrypted, nil
180}
181
182func (c *Conn) decryptTicket(encrypted []byte) (*sessionState, bool) {
183	if len(encrypted) < aes.BlockSize+sha256.Size {
184		return nil, false
185	}
186
187	iv := encrypted[:aes.BlockSize]
188	macBytes := encrypted[len(encrypted)-sha256.Size:]
189
190	mac := hmac.New(sha256.New, c.config.SessionTicketKey[16:32])
191	mac.Write(encrypted[:len(encrypted)-sha256.Size])
192	expected := mac.Sum(nil)
193
194	if subtle.ConstantTimeCompare(macBytes, expected) != 1 {
195		return nil, false
196	}
197
198	block, err := aes.NewCipher(c.config.SessionTicketKey[:16])
199	if err != nil {
200		return nil, false
201	}
202	ciphertext := encrypted[aes.BlockSize : len(encrypted)-sha256.Size]
203	plaintext := make([]byte, len(ciphertext))
204	cipher.NewCTR(block, iv).XORKeyStream(plaintext, ciphertext)
205
206	state := new(sessionState)
207	ok := state.unmarshal(plaintext)
208	return state, ok
209}
210