diff --git a/internal/handshake/crypto_setup.go b/internal/handshake/crypto_setup.go index 4feaaa978..de6ddc684 100644 --- a/internal/handshake/crypto_setup.go +++ b/internal/handshake/crypto_setup.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "net" + "sync" "unsafe" "github.com/lucas-clemente/quic-go/internal/protocol" @@ -63,9 +64,6 @@ type cryptoSetup struct { messageChan chan []byte - readEncLevel protocol.EncryptionLevel - writeEncLevel protocol.EncryptionLevel - paramsChan <-chan []byte handleParamsCallback func([]byte) @@ -80,18 +78,6 @@ type cryptoSetup struct { clientHelloWritten bool clientHelloWrittenChan chan struct{} - initialStream io.Writer - initialOpener Opener - initialSealer Sealer - - handshakeStream io.Writer - handshakeOpener Opener - handshakeSealer Sealer - - oneRTTStream io.Writer - opener Opener - sealer Sealer - receivedWriteKey chan struct{} receivedReadKey chan struct{} // WriteRecord does a non-blocking send on this channel. @@ -104,6 +90,23 @@ type cryptoSetup struct { logger utils.Logger perspective protocol.Perspective + + mutex sync.Mutex // protects all members below + + readEncLevel protocol.EncryptionLevel + writeEncLevel protocol.EncryptionLevel + + initialStream io.Writer + initialOpener Opener + initialSealer Sealer + + handshakeStream io.Writer + handshakeOpener Opener + handshakeSealer Sealer + + oneRTTStream io.Writer + opener Opener + sealer Sealer } var _ qtls.RecordLayer = &cryptoSetup{} @@ -431,6 +434,7 @@ func (h *cryptoSetup) SetReadKey(suite *qtls.CipherSuite, trafficSecret []byte) panic(fmt.Sprintf("error creating new AES cipher: %s", err)) } + h.mutex.Lock() switch h.readEncLevel { case protocol.EncryptionInitial: h.readEncLevel = protocol.EncryptionHandshake @@ -443,6 +447,7 @@ func (h *cryptoSetup) SetReadKey(suite *qtls.CipherSuite, trafficSecret []byte) default: panic("unexpected read encryption level") } + h.mutex.Unlock() h.receivedReadKey <- struct{}{} } @@ -455,6 +460,7 @@ func (h *cryptoSetup) SetWriteKey(suite *qtls.CipherSuite, trafficSecret []byte) panic(fmt.Sprintf("error creating new AES cipher: %s", err)) } + h.mutex.Lock() switch h.writeEncLevel { case protocol.EncryptionInitial: h.writeEncLevel = protocol.EncryptionHandshake @@ -467,6 +473,7 @@ func (h *cryptoSetup) SetWriteKey(suite *qtls.CipherSuite, trafficSecret []byte) default: panic("unexpected write encryption level") } + h.mutex.Unlock() h.receivedWriteKey <- struct{}{} } @@ -479,6 +486,9 @@ func (h *cryptoSetup) WriteRecord(p []byte) (int, error) { } }() + h.mutex.Lock() + defer h.mutex.Unlock() + switch h.writeEncLevel { case protocol.EncryptionInitial: // assume that the first WriteRecord call contains the ClientHello @@ -502,6 +512,9 @@ func (h *cryptoSetup) SendAlert(alert uint8) { } func (h *cryptoSetup) GetSealer() (protocol.EncryptionLevel, Sealer) { + h.mutex.Lock() + defer h.mutex.Unlock() + if h.sealer != nil { return protocol.Encryption1RTT, h.sealer } @@ -514,6 +527,9 @@ func (h *cryptoSetup) GetSealer() (protocol.EncryptionLevel, Sealer) { func (h *cryptoSetup) GetSealerWithEncryptionLevel(level protocol.EncryptionLevel) (Sealer, error) { errNoSealer := fmt.Errorf("CryptoSetup: no sealer with encryption level %s", level.String()) + h.mutex.Lock() + defer h.mutex.Unlock() + switch level { case protocol.EncryptionInitial: return h.initialSealer, nil @@ -533,6 +549,9 @@ func (h *cryptoSetup) GetSealerWithEncryptionLevel(level protocol.EncryptionLeve } func (h *cryptoSetup) GetOpener(level protocol.EncryptionLevel) (Opener, error) { + h.mutex.Lock() + defer h.mutex.Unlock() + switch level { case protocol.EncryptionInitial: return h.initialOpener, nil