1// Copyright 2014 The Go Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5// DTLS implementation. 6// 7// NOTE: This is a not even a remotely production-quality DTLS 8// implementation. It is the bare minimum necessary to be able to 9// achieve coverage on BoringSSL's implementation. Of note is that 10// this implementation assumes the underlying net.PacketConn is not 11// only reliable but also ordered. BoringSSL will be expected to deal 12// with simulated loss, but there is no point in forcing the test 13// driver to. 14 15package main 16 17import ( 18 "bytes" 19 "errors" 20 "fmt" 21 "io" 22 "math/rand" 23 "net" 24) 25 26func versionToWire(vers uint16, isDTLS bool) uint16 { 27 if isDTLS { 28 return ^(vers - 0x0201) 29 } 30 return vers 31} 32 33func wireToVersion(vers uint16, isDTLS bool) uint16 { 34 if isDTLS { 35 return ^vers + 0x0201 36 } 37 return vers 38} 39 40func (c *Conn) dtlsDoReadRecord(want recordType) (recordType, *block, error) { 41 recordHeaderLen := dtlsRecordHeaderLen 42 43 if c.rawInput == nil { 44 c.rawInput = c.in.newBlock() 45 } 46 b := c.rawInput 47 48 // Read a new packet only if the current one is empty. 49 if len(b.data) == 0 { 50 // Pick some absurdly large buffer size. 51 b.resize(maxCiphertext + recordHeaderLen) 52 n, err := c.conn.Read(c.rawInput.data) 53 if err != nil { 54 return 0, nil, err 55 } 56 if c.config.Bugs.MaxPacketLength != 0 && n > c.config.Bugs.MaxPacketLength { 57 return 0, nil, fmt.Errorf("dtls: exceeded maximum packet length") 58 } 59 c.rawInput.resize(n) 60 } 61 62 // Read out one record. 63 // 64 // A real DTLS implementation should be tolerant of errors, 65 // but this is test code. We should not be tolerant of our 66 // peer sending garbage. 67 if len(b.data) < recordHeaderLen { 68 return 0, nil, errors.New("dtls: failed to read record header") 69 } 70 typ := recordType(b.data[0]) 71 vers := wireToVersion(uint16(b.data[1])<<8|uint16(b.data[2]), c.isDTLS) 72 if c.haveVers { 73 if vers != c.vers { 74 c.sendAlert(alertProtocolVersion) 75 return 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: received record with version %x when expecting version %x", vers, c.vers)) 76 } 77 } else { 78 if expect := c.config.Bugs.ExpectInitialRecordVersion; expect != 0 && vers != expect { 79 c.sendAlert(alertProtocolVersion) 80 return 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: received record with version %x when expecting version %x", vers, expect)) 81 } 82 } 83 seq := b.data[3:11] 84 // For test purposes, we assume a reliable channel. Require 85 // that the explicit sequence number matches the incrementing 86 // one we maintain. A real implementation would maintain a 87 // replay window and such. 88 if !bytes.Equal(seq, c.in.seq[:]) { 89 c.sendAlert(alertIllegalParameter) 90 return 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: bad sequence number")) 91 } 92 n := int(b.data[11])<<8 | int(b.data[12]) 93 if n > maxCiphertext || len(b.data) < recordHeaderLen+n { 94 c.sendAlert(alertRecordOverflow) 95 return 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: oversized record received with length %d", n)) 96 } 97 98 // Process message. 99 b, c.rawInput = c.in.splitBlock(b, recordHeaderLen+n) 100 ok, off, err := c.in.decrypt(b) 101 if !ok { 102 c.in.setErrorLocked(c.sendAlert(err)) 103 } 104 b.off = off 105 return typ, b, nil 106} 107 108func (c *Conn) makeFragment(header, data []byte, fragOffset, fragLen int) []byte { 109 fragment := make([]byte, 0, 12+fragLen) 110 fragment = append(fragment, header...) 111 fragment = append(fragment, byte(c.sendHandshakeSeq>>8), byte(c.sendHandshakeSeq)) 112 fragment = append(fragment, byte(fragOffset>>16), byte(fragOffset>>8), byte(fragOffset)) 113 fragment = append(fragment, byte(fragLen>>16), byte(fragLen>>8), byte(fragLen)) 114 fragment = append(fragment, data[fragOffset:fragOffset+fragLen]...) 115 return fragment 116} 117 118func (c *Conn) dtlsWriteRecord(typ recordType, data []byte) (n int, err error) { 119 if typ != recordTypeHandshake { 120 // Only handshake messages are fragmented. 121 return c.dtlsWriteRawRecord(typ, data) 122 } 123 124 maxLen := c.config.Bugs.MaxHandshakeRecordLength 125 if maxLen <= 0 { 126 maxLen = 1024 127 } 128 129 // Handshake messages have to be modified to include fragment 130 // offset and length and with the header replicated. Save the 131 // TLS header here. 132 // 133 // TODO(davidben): This assumes that data contains exactly one 134 // handshake message. This is incompatible with 135 // FragmentAcrossChangeCipherSpec. (Which is unfortunate 136 // because OpenSSL's DTLS implementation will probably accept 137 // such fragmentation and could do with a fix + tests.) 138 header := data[:4] 139 data = data[4:] 140 141 isFinished := header[0] == typeFinished 142 143 if c.config.Bugs.SendEmptyFragments { 144 fragment := c.makeFragment(header, data, 0, 0) 145 c.pendingFragments = append(c.pendingFragments, fragment) 146 } 147 148 firstRun := true 149 fragOffset := 0 150 for firstRun || fragOffset < len(data) { 151 firstRun = false 152 fragLen := len(data) - fragOffset 153 if fragLen > maxLen { 154 fragLen = maxLen 155 } 156 157 fragment := c.makeFragment(header, data, fragOffset, fragLen) 158 if c.config.Bugs.FragmentMessageTypeMismatch && fragOffset > 0 { 159 fragment[0]++ 160 } 161 if c.config.Bugs.FragmentMessageLengthMismatch && fragOffset > 0 { 162 fragment[3]++ 163 } 164 165 // Buffer the fragment for later. They will be sent (and 166 // reordered) on flush. 167 c.pendingFragments = append(c.pendingFragments, fragment) 168 if c.config.Bugs.ReorderHandshakeFragments { 169 // Don't duplicate Finished to avoid the peer 170 // interpreting it as a retransmit request. 171 if !isFinished { 172 c.pendingFragments = append(c.pendingFragments, fragment) 173 } 174 175 if fragLen > (maxLen+1)/2 { 176 // Overlap each fragment by half. 177 fragLen = (maxLen + 1) / 2 178 } 179 } 180 fragOffset += fragLen 181 n += fragLen 182 } 183 if !isFinished && c.config.Bugs.MixCompleteMessageWithFragments { 184 fragment := c.makeFragment(header, data, 0, len(data)) 185 c.pendingFragments = append(c.pendingFragments, fragment) 186 } 187 188 // Increment the handshake sequence number for the next 189 // handshake message. 190 c.sendHandshakeSeq++ 191 return 192} 193 194func (c *Conn) dtlsFlushHandshake() error { 195 if !c.isDTLS { 196 return nil 197 } 198 199 // This is a test-only DTLS implementation, so there is no need to 200 // retain |c.pendingFragments| for a future retransmit. 201 var fragments [][]byte 202 fragments, c.pendingFragments = c.pendingFragments, fragments 203 204 if c.config.Bugs.ReorderHandshakeFragments { 205 perm := rand.New(rand.NewSource(0)).Perm(len(fragments)) 206 tmp := make([][]byte, len(fragments)) 207 for i := range tmp { 208 tmp[i] = fragments[perm[i]] 209 } 210 fragments = tmp 211 } 212 213 maxRecordLen := c.config.Bugs.PackHandshakeFragments 214 maxPacketLen := c.config.Bugs.PackHandshakeRecords 215 216 // Pack handshake fragments into records. 217 var records [][]byte 218 for _, fragment := range fragments { 219 if c.config.Bugs.SplitFragmentHeader { 220 records = append(records, fragment[:2]) 221 records = append(records, fragment[2:]) 222 } else if c.config.Bugs.SplitFragmentBody { 223 if len(fragment) > 12 { 224 records = append(records, fragment[:13]) 225 records = append(records, fragment[13:]) 226 } else { 227 records = append(records, fragment) 228 } 229 } else if i := len(records) - 1; len(records) > 0 && len(records[i])+len(fragment) <= maxRecordLen { 230 records[i] = append(records[i], fragment...) 231 } else { 232 // The fragment will be appended to, so copy it. 233 records = append(records, append([]byte{}, fragment...)) 234 } 235 } 236 237 // Format them into packets. 238 var packets [][]byte 239 for _, record := range records { 240 b, err := c.dtlsSealRecord(recordTypeHandshake, record) 241 if err != nil { 242 return err 243 } 244 245 if i := len(packets) - 1; len(packets) > 0 && len(packets[i])+len(b.data) <= maxPacketLen { 246 packets[i] = append(packets[i], b.data...) 247 } else { 248 // The sealed record will be appended to and reused by 249 // |c.out|, so copy it. 250 packets = append(packets, append([]byte{}, b.data...)) 251 } 252 c.out.freeBlock(b) 253 } 254 255 // Send all the packets. 256 for _, packet := range packets { 257 if _, err := c.conn.Write(packet); err != nil { 258 return err 259 } 260 } 261 return nil 262} 263 264// dtlsSealRecord seals a record into a block from |c.out|'s pool. 265func (c *Conn) dtlsSealRecord(typ recordType, data []byte) (b *block, err error) { 266 recordHeaderLen := dtlsRecordHeaderLen 267 maxLen := c.config.Bugs.MaxHandshakeRecordLength 268 if maxLen <= 0 { 269 maxLen = 1024 270 } 271 272 b = c.out.newBlock() 273 274 explicitIVLen := 0 275 explicitIVIsSeq := false 276 277 if cbc, ok := c.out.cipher.(cbcMode); ok { 278 // Block cipher modes have an explicit IV. 279 explicitIVLen = cbc.BlockSize() 280 } else if aead, ok := c.out.cipher.(*tlsAead); ok { 281 if aead.explicitNonce { 282 explicitIVLen = 8 283 // The AES-GCM construction in TLS has an explicit nonce so that 284 // the nonce can be random. However, the nonce is only 8 bytes 285 // which is too small for a secure, random nonce. Therefore we 286 // use the sequence number as the nonce. 287 explicitIVIsSeq = true 288 } 289 } else if c.out.cipher != nil { 290 panic("Unknown cipher") 291 } 292 b.resize(recordHeaderLen + explicitIVLen + len(data)) 293 b.data[0] = byte(typ) 294 vers := c.vers 295 if vers == 0 { 296 // Some TLS servers fail if the record version is greater than 297 // TLS 1.0 for the initial ClientHello. 298 vers = VersionTLS10 299 } 300 vers = versionToWire(vers, c.isDTLS) 301 b.data[1] = byte(vers >> 8) 302 b.data[2] = byte(vers) 303 // DTLS records include an explicit sequence number. 304 copy(b.data[3:11], c.out.seq[0:]) 305 b.data[11] = byte(len(data) >> 8) 306 b.data[12] = byte(len(data)) 307 if explicitIVLen > 0 { 308 explicitIV := b.data[recordHeaderLen : recordHeaderLen+explicitIVLen] 309 if explicitIVIsSeq { 310 copy(explicitIV, c.out.seq[:]) 311 } else { 312 if _, err = io.ReadFull(c.config.rand(), explicitIV); err != nil { 313 return 314 } 315 } 316 } 317 copy(b.data[recordHeaderLen+explicitIVLen:], data) 318 c.out.encrypt(b, explicitIVLen) 319 return 320} 321 322func (c *Conn) dtlsWriteRawRecord(typ recordType, data []byte) (n int, err error) { 323 b, err := c.dtlsSealRecord(typ, data) 324 if err != nil { 325 return 326 } 327 328 _, err = c.conn.Write(b.data) 329 if err != nil { 330 return 331 } 332 n = len(data) 333 334 c.out.freeBlock(b) 335 336 if typ == recordTypeChangeCipherSpec { 337 err = c.out.changeCipherSpec(c.config) 338 if err != nil { 339 // Cannot call sendAlert directly, 340 // because we already hold c.out.Mutex. 341 c.tmp[0] = alertLevelError 342 c.tmp[1] = byte(err.(alert)) 343 c.writeRecord(recordTypeAlert, c.tmp[0:2]) 344 return n, c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err}) 345 } 346 } 347 return 348} 349 350func (c *Conn) dtlsDoReadHandshake() ([]byte, error) { 351 // Assemble a full handshake message. For test purposes, this 352 // implementation assumes fragments arrive in order. It may 353 // need to be cleverer if we ever test BoringSSL's retransmit 354 // behavior. 355 for len(c.handMsg) < 4+c.handMsgLen { 356 // Get a new handshake record if the previous has been 357 // exhausted. 358 if c.hand.Len() == 0 { 359 if err := c.in.err; err != nil { 360 return nil, err 361 } 362 if err := c.readRecord(recordTypeHandshake); err != nil { 363 return nil, err 364 } 365 } 366 367 // Read the next fragment. It must fit entirely within 368 // the record. 369 if c.hand.Len() < 12 { 370 return nil, errors.New("dtls: bad handshake record") 371 } 372 header := c.hand.Next(12) 373 fragN := int(header[1])<<16 | int(header[2])<<8 | int(header[3]) 374 fragSeq := uint16(header[4])<<8 | uint16(header[5]) 375 fragOff := int(header[6])<<16 | int(header[7])<<8 | int(header[8]) 376 fragLen := int(header[9])<<16 | int(header[10])<<8 | int(header[11]) 377 378 if c.hand.Len() < fragLen { 379 return nil, errors.New("dtls: fragment length too long") 380 } 381 fragment := c.hand.Next(fragLen) 382 383 // Check it's a fragment for the right message. 384 if fragSeq != c.recvHandshakeSeq { 385 return nil, errors.New("dtls: bad handshake sequence number") 386 } 387 388 // Check that the length is consistent. 389 if c.handMsg == nil { 390 c.handMsgLen = fragN 391 if c.handMsgLen > maxHandshake { 392 return nil, c.in.setErrorLocked(c.sendAlert(alertInternalError)) 393 } 394 // Start with the TLS handshake header, 395 // without the DTLS bits. 396 c.handMsg = append([]byte{}, header[:4]...) 397 } else if fragN != c.handMsgLen { 398 return nil, errors.New("dtls: bad handshake length") 399 } 400 401 // Add the fragment to the pending message. 402 if 4+fragOff != len(c.handMsg) { 403 return nil, errors.New("dtls: bad fragment offset") 404 } 405 if fragOff+fragLen > c.handMsgLen { 406 return nil, errors.New("dtls: bad fragment length") 407 } 408 c.handMsg = append(c.handMsg, fragment...) 409 } 410 c.recvHandshakeSeq++ 411 ret := c.handMsg 412 c.handMsg, c.handMsgLen = nil, 0 413 return ret, nil 414} 415 416// DTLSServer returns a new DTLS server side connection 417// using conn as the underlying transport. 418// The configuration config must be non-nil and must have 419// at least one certificate. 420func DTLSServer(conn net.Conn, config *Config) *Conn { 421 c := &Conn{config: config, isDTLS: true, conn: conn} 422 c.init() 423 return c 424} 425 426// DTLSClient returns a new DTLS client side connection 427// using conn as the underlying transport. 428// The config cannot be nil: users must set either ServerHostname or 429// InsecureSkipVerify in the config. 430func DTLSClient(conn net.Conn, config *Config) *Conn { 431 c := &Conn{config: config, isClient: true, isDTLS: true, conn: conn} 432 c.init() 433 return c 434} 435