fix race condition when accessing the encryption level in crypto setup

This commit is contained in:
Marten Seemann
2019-04-01 12:16:31 +09:00
parent e9f7f87063
commit 9ffbd662c1

View File

@@ -7,6 +7,7 @@ import (
"fmt"
"io"
"net"
"sync"
"unsafe"
"github.com/lucas-clemente/quic-go/internal/protocol"
@@ -63,9 +64,6 @@ type cryptoSetup struct {
messageChan chan []byte
readEncLevel protocol.EncryptionLevel
writeEncLevel protocol.EncryptionLevel
paramsChan <-chan []byte
handleParamsCallback func([]byte)
@@ -80,18 +78,6 @@ type cryptoSetup struct {
clientHelloWritten bool
clientHelloWrittenChan chan struct{}
initialStream io.Writer
initialOpener Opener
initialSealer Sealer
handshakeStream io.Writer
handshakeOpener Opener
handshakeSealer Sealer
oneRTTStream io.Writer
opener Opener
sealer Sealer
receivedWriteKey chan struct{}
receivedReadKey chan struct{}
// WriteRecord does a non-blocking send on this channel.
@@ -104,6 +90,23 @@ type cryptoSetup struct {
logger utils.Logger
perspective protocol.Perspective
mutex sync.Mutex // protects all members below
readEncLevel protocol.EncryptionLevel
writeEncLevel protocol.EncryptionLevel
initialStream io.Writer
initialOpener Opener
initialSealer Sealer
handshakeStream io.Writer
handshakeOpener Opener
handshakeSealer Sealer
oneRTTStream io.Writer
opener Opener
sealer Sealer
}
var _ qtls.RecordLayer = &cryptoSetup{}
@@ -431,6 +434,7 @@ func (h *cryptoSetup) SetReadKey(suite *qtls.CipherSuite, trafficSecret []byte)
panic(fmt.Sprintf("error creating new AES cipher: %s", err))
}
h.mutex.Lock()
switch h.readEncLevel {
case protocol.EncryptionInitial:
h.readEncLevel = protocol.EncryptionHandshake
@@ -443,6 +447,7 @@ func (h *cryptoSetup) SetReadKey(suite *qtls.CipherSuite, trafficSecret []byte)
default:
panic("unexpected read encryption level")
}
h.mutex.Unlock()
h.receivedReadKey <- struct{}{}
}
@@ -455,6 +460,7 @@ func (h *cryptoSetup) SetWriteKey(suite *qtls.CipherSuite, trafficSecret []byte)
panic(fmt.Sprintf("error creating new AES cipher: %s", err))
}
h.mutex.Lock()
switch h.writeEncLevel {
case protocol.EncryptionInitial:
h.writeEncLevel = protocol.EncryptionHandshake
@@ -467,6 +473,7 @@ func (h *cryptoSetup) SetWriteKey(suite *qtls.CipherSuite, trafficSecret []byte)
default:
panic("unexpected write encryption level")
}
h.mutex.Unlock()
h.receivedWriteKey <- struct{}{}
}
@@ -479,6 +486,9 @@ func (h *cryptoSetup) WriteRecord(p []byte) (int, error) {
}
}()
h.mutex.Lock()
defer h.mutex.Unlock()
switch h.writeEncLevel {
case protocol.EncryptionInitial:
// assume that the first WriteRecord call contains the ClientHello
@@ -502,6 +512,9 @@ func (h *cryptoSetup) SendAlert(alert uint8) {
}
func (h *cryptoSetup) GetSealer() (protocol.EncryptionLevel, Sealer) {
h.mutex.Lock()
defer h.mutex.Unlock()
if h.sealer != nil {
return protocol.Encryption1RTT, h.sealer
}
@@ -514,6 +527,9 @@ func (h *cryptoSetup) GetSealer() (protocol.EncryptionLevel, Sealer) {
func (h *cryptoSetup) GetSealerWithEncryptionLevel(level protocol.EncryptionLevel) (Sealer, error) {
errNoSealer := fmt.Errorf("CryptoSetup: no sealer with encryption level %s", level.String())
h.mutex.Lock()
defer h.mutex.Unlock()
switch level {
case protocol.EncryptionInitial:
return h.initialSealer, nil
@@ -533,6 +549,9 @@ func (h *cryptoSetup) GetSealerWithEncryptionLevel(level protocol.EncryptionLeve
}
func (h *cryptoSetup) GetOpener(level protocol.EncryptionLevel) (Opener, error) {
h.mutex.Lock()
defer h.mutex.Unlock()
switch level {
case protocol.EncryptionInitial:
return h.initialOpener, nil