Merge pull request #2777 from lucas-clemente/fix-tls-handshake-message-handling

fix handling of multiple handshake messages in the case of errors
This commit is contained in:
Marten Seemann
2020-09-14 13:59:22 +07:00
committed by GitHub
2 changed files with 75 additions and 156 deletions

View File

@@ -87,7 +87,9 @@ type cryptoSetup struct {
extraConf *qtls.ExtraConfig
conn *qtls.Conn
messageChan chan []byte
messageChan chan []byte
isReadingHandshakeMessage chan struct{}
readFirstHandshakeMessage bool
ourParams *wire.TransportParameters
peerParams *wire.TransportParameters
@@ -105,15 +107,6 @@ type cryptoSetup struct {
clientHelloWritten bool
clientHelloWrittenChan chan *wire.TransportParameters
receivedWriteKey chan struct{}
receivedReadKey chan struct{}
// WriteRecord does a non-blocking send on this channel.
// This way, handleMessage can see if qtls tries to write a message.
// This is necessary:
// for servers: to see if a HelloRetryRequest should be sent in response to a ClientHello
// for clients: to see if a ServerHello is a HelloRetryRequest
writeRecord chan struct{}
rttStats *utils.RTTStats
tracer logging.ConnectionTracer
@@ -231,29 +224,27 @@ func newCryptoSetup(
}
extHandler := newExtensionHandler(tp.Marshal(perspective), perspective)
cs := &cryptoSetup{
tlsConf: tlsConf,
initialStream: initialStream,
initialSealer: initialSealer,
initialOpener: initialOpener,
handshakeStream: handshakeStream,
aead: newUpdatableAEAD(rttStats, tracer, logger),
readEncLevel: protocol.EncryptionInitial,
writeEncLevel: protocol.EncryptionInitial,
runner: runner,
ourParams: tp,
paramsChan: extHandler.TransportParameters(),
rttStats: rttStats,
tracer: tracer,
logger: logger,
perspective: perspective,
handshakeDone: make(chan struct{}),
alertChan: make(chan uint8),
clientHelloWrittenChan: make(chan *wire.TransportParameters, 1),
messageChan: make(chan []byte, 100),
receivedReadKey: make(chan struct{}),
receivedWriteKey: make(chan struct{}),
writeRecord: make(chan struct{}, 1),
closeChan: make(chan struct{}),
tlsConf: tlsConf,
initialStream: initialStream,
initialSealer: initialSealer,
initialOpener: initialOpener,
handshakeStream: handshakeStream,
aead: newUpdatableAEAD(rttStats, tracer, logger),
readEncLevel: protocol.EncryptionInitial,
writeEncLevel: protocol.EncryptionInitial,
runner: runner,
ourParams: tp,
paramsChan: extHandler.TransportParameters(),
rttStats: rttStats,
tracer: tracer,
logger: logger,
perspective: perspective,
handshakeDone: make(chan struct{}),
alertChan: make(chan uint8),
clientHelloWrittenChan: make(chan *wire.TransportParameters, 1),
messageChan: make(chan []byte, 100),
isReadingHandshakeMessage: make(chan struct{}),
closeChan: make(chan struct{}),
}
var maxEarlyData uint32
if enable0RTT {
@@ -344,20 +335,25 @@ func (h *cryptoSetup) HandleMessage(data []byte, encLevel protocol.EncryptionLev
h.messageChan <- data
if encLevel == protocol.Encryption1RTT {
h.handlePostHandshakeMessage()
return false
}
var strFinished bool
switch h.perspective {
case protocol.PerspectiveClient:
strFinished = h.handleMessageForClient(msgType)
case protocol.PerspectiveServer:
strFinished = h.handleMessageForServer(msgType)
default:
panic("")
readLoop:
for {
select {
case data := <-h.paramsChan:
h.handleTransportParameters(data)
case <-h.isReadingHandshakeMessage:
break readLoop
case <-h.handshakeDone:
break readLoop
}
}
if strFinished {
h.logger.Debugf("Done with encryption level %s.", encLevel)
}
return strFinished
// We're done with the Initial encryption level after processing a ClientHello / ServerHello,
// but only if a handshake opener and sealer was created.
// Otherwise, a HelloRetryRequest was performed.
// We're done with the Handshake encryption level after processing the Finished message.
return ((msgType == typeClientHello || msgType == typeServerHello) && h.handshakeOpener != nil && h.handshakeSealer != nil) ||
msgType == typeFinished
}
func (h *cryptoSetup) checkEncryptionLevel(msgType messageType, encLevel protocol.EncryptionLevel) error {
@@ -383,108 +379,6 @@ func (h *cryptoSetup) checkEncryptionLevel(msgType messageType, encLevel protoco
return nil
}
func (h *cryptoSetup) handleMessageForServer(msgType messageType) bool {
switch msgType {
case typeClientHello:
select {
case <-h.writeRecord:
// If qtls sends a HelloRetryRequest, it will only write the record.
// If it accepts the ClientHello, it will first read the transport parameters.
h.logger.Debugf("Sending HelloRetryRequest")
return false
case data := <-h.paramsChan:
h.handleTransportParameters(data)
case <-h.handshakeDone:
return false
}
// get the handshake read key
select {
case <-h.receivedReadKey:
case <-h.handshakeDone:
return false
}
// get the handshake write key
select {
case <-h.receivedWriteKey:
case <-h.handshakeDone:
return false
}
// get the 1-RTT write key
select {
case <-h.receivedWriteKey:
case <-h.handshakeDone:
return false
}
return true
case typeCertificate, typeCertificateVerify:
// nothing to do
return false
case typeFinished:
// get the 1-RTT read key
select {
case <-h.receivedReadKey:
case <-h.handshakeDone:
return false
}
return true
default:
// unexpected message
return false
}
}
func (h *cryptoSetup) handleMessageForClient(msgType messageType) bool {
switch msgType {
case typeServerHello:
// get the handshake write key
select {
case <-h.writeRecord:
// If qtls writes in response to a ServerHello, this means that this ServerHello
// is a HelloRetryRequest.
// Otherwise, we'd just wait for the Certificate message.
h.logger.Debugf("ServerHello is a HelloRetryRequest")
return false
case <-h.receivedWriteKey:
case <-h.handshakeDone:
return false
}
// get the handshake read key
select {
case <-h.receivedReadKey:
case <-h.handshakeDone:
return false
}
return true
case typeEncryptedExtensions:
select {
case data := <-h.paramsChan:
h.handleTransportParameters(data)
case <-h.handshakeDone:
return false
}
return false
case typeCertificateRequest, typeCertificate, typeCertificateVerify:
// nothing to do
return false
case typeFinished:
// get the 1-RTT read key
select {
case <-h.receivedReadKey:
case <-h.handshakeDone:
return false
}
// get the handshake write key
select {
case <-h.receivedWriteKey:
case <-h.handshakeDone:
return false
}
return true
default:
return false
}
}
func (h *cryptoSetup) handleTransportParameters(data []byte) {
var tp wire.TransportParameters
if err := tp.Unmarshal(data, h.perspective.Opposite()); err != nil {
@@ -591,6 +485,7 @@ func (h *cryptoSetup) handlePostHandshakeMessage() {
// Read it from a go-routine so that HandlePostHandshakeMessage doesn't deadlock.
alertChan := make(chan uint8, 1)
go func() {
<-h.isReadingHandshakeMessage
select {
case alert := <-h.alertChan:
alertChan <- alert
@@ -606,6 +501,11 @@ func (h *cryptoSetup) handlePostHandshakeMessage() {
// ReadHandshakeMessage is called by TLS.
// It blocks until a new handshake message is available.
func (h *cryptoSetup) ReadHandshakeMessage() ([]byte, error) {
if !h.readFirstHandshakeMessage {
h.readFirstHandshakeMessage = true
} else {
h.isReadingHandshakeMessage <- struct{}{}
}
msg, ok := <-h.messageChan
if !ok {
return nil, errors.New("error while handling the handshake message")
@@ -651,7 +551,6 @@ func (h *cryptoSetup) SetReadKey(encLevel qtls.EncryptionLevel, suite *qtls.Ciph
if h.tracer != nil {
h.tracer.UpdatedKeyFromTLS(h.readEncLevel, h.perspective.Opposite())
}
h.receivedReadKey <- struct{}{}
}
func (h *cryptoSetup) SetWriteKey(encLevel qtls.EncryptionLevel, suite *qtls.CipherSuiteTLS13, trafficSecret []byte) {
@@ -696,7 +595,6 @@ func (h *cryptoSetup) SetWriteKey(encLevel qtls.EncryptionLevel, suite *qtls.Cip
if h.tracer != nil {
h.tracer.UpdatedKeyFromTLS(h.writeEncLevel, h.perspective)
}
h.receivedWriteKey <- struct{}{}
}
// WriteRecord is called when TLS writes data
@@ -717,11 +615,6 @@ func (h *cryptoSetup) WriteRecord(p []byte) (int, error) {
h.logger.Debugf("Not doing 0-RTT.")
h.clientHelloWrittenChan <- nil
}
} else {
// We need additional signaling to properly detect HelloRetryRequests.
// For servers: when the ServerHello is written.
// For clients: when a reply is sent in response to a ServerHello.
h.writeRecord <- struct{}{}
}
return n, err
case protocol.EncryptionHandshake:

View File

@@ -1,6 +1,7 @@
package handshake
import (
"bytes"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
@@ -22,6 +23,13 @@ import (
. "github.com/onsi/gomega"
)
var helloRetryRequestRandom = []byte{ // See RFC 8446, Section 4.1.3.
0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11,
0xBE, 0x1D, 0x8C, 0x02, 0x1E, 0x65, 0xB8, 0x91,
0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB, 0x8C, 0x5E,
0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C,
}
type chunk struct {
data []byte
encLevel protocol.EncryptionLevel
@@ -257,9 +265,27 @@ var _ = Describe("Crypto Setup TLS", func() {
for {
select {
case c := <-cChunkChan:
server.HandleMessage(c.data, c.encLevel)
msgType := messageType(c.data[0])
finished := server.HandleMessage(c.data, c.encLevel)
if msgType == typeFinished {
Expect(finished).To(BeTrue())
} else if msgType == typeClientHello {
// If this ClientHello didn't elicit a HelloRetryRequest, we're done with Initial keys.
_, err := server.GetHandshakeOpener()
Expect(finished).To(Equal(err == nil))
} else {
Expect(finished).To(BeFalse())
}
case c := <-sChunkChan:
client.HandleMessage(c.data, c.encLevel)
msgType := messageType(c.data[0])
finished := client.HandleMessage(c.data, c.encLevel)
if msgType == typeFinished {
Expect(finished).To(BeTrue())
} else if msgType == typeServerHello {
Expect(finished).To(Equal(!bytes.Equal(c.data[6:6+32], helloRetryRequestRandom)))
} else {
Expect(finished).To(BeFalse())
}
case <-done: // handshake complete
return
}