forked from quic-go/quic-go
fix race condition when accessing the encryption level in crypto setup
This commit is contained in:
@@ -7,6 +7,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
"sync"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
@@ -63,9 +64,6 @@ type cryptoSetup struct {
|
|||||||
|
|
||||||
messageChan chan []byte
|
messageChan chan []byte
|
||||||
|
|
||||||
readEncLevel protocol.EncryptionLevel
|
|
||||||
writeEncLevel protocol.EncryptionLevel
|
|
||||||
|
|
||||||
paramsChan <-chan []byte
|
paramsChan <-chan []byte
|
||||||
handleParamsCallback func([]byte)
|
handleParamsCallback func([]byte)
|
||||||
|
|
||||||
@@ -80,18 +78,6 @@ type cryptoSetup struct {
|
|||||||
clientHelloWritten bool
|
clientHelloWritten bool
|
||||||
clientHelloWrittenChan chan struct{}
|
clientHelloWrittenChan chan struct{}
|
||||||
|
|
||||||
initialStream io.Writer
|
|
||||||
initialOpener Opener
|
|
||||||
initialSealer Sealer
|
|
||||||
|
|
||||||
handshakeStream io.Writer
|
|
||||||
handshakeOpener Opener
|
|
||||||
handshakeSealer Sealer
|
|
||||||
|
|
||||||
oneRTTStream io.Writer
|
|
||||||
opener Opener
|
|
||||||
sealer Sealer
|
|
||||||
|
|
||||||
receivedWriteKey chan struct{}
|
receivedWriteKey chan struct{}
|
||||||
receivedReadKey chan struct{}
|
receivedReadKey chan struct{}
|
||||||
// WriteRecord does a non-blocking send on this channel.
|
// WriteRecord does a non-blocking send on this channel.
|
||||||
@@ -104,6 +90,23 @@ type cryptoSetup struct {
|
|||||||
logger utils.Logger
|
logger utils.Logger
|
||||||
|
|
||||||
perspective protocol.Perspective
|
perspective protocol.Perspective
|
||||||
|
|
||||||
|
mutex sync.Mutex // protects all members below
|
||||||
|
|
||||||
|
readEncLevel protocol.EncryptionLevel
|
||||||
|
writeEncLevel protocol.EncryptionLevel
|
||||||
|
|
||||||
|
initialStream io.Writer
|
||||||
|
initialOpener Opener
|
||||||
|
initialSealer Sealer
|
||||||
|
|
||||||
|
handshakeStream io.Writer
|
||||||
|
handshakeOpener Opener
|
||||||
|
handshakeSealer Sealer
|
||||||
|
|
||||||
|
oneRTTStream io.Writer
|
||||||
|
opener Opener
|
||||||
|
sealer Sealer
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ qtls.RecordLayer = &cryptoSetup{}
|
var _ qtls.RecordLayer = &cryptoSetup{}
|
||||||
@@ -431,6 +434,7 @@ func (h *cryptoSetup) SetReadKey(suite *qtls.CipherSuite, trafficSecret []byte)
|
|||||||
panic(fmt.Sprintf("error creating new AES cipher: %s", err))
|
panic(fmt.Sprintf("error creating new AES cipher: %s", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
h.mutex.Lock()
|
||||||
switch h.readEncLevel {
|
switch h.readEncLevel {
|
||||||
case protocol.EncryptionInitial:
|
case protocol.EncryptionInitial:
|
||||||
h.readEncLevel = protocol.EncryptionHandshake
|
h.readEncLevel = protocol.EncryptionHandshake
|
||||||
@@ -443,6 +447,7 @@ func (h *cryptoSetup) SetReadKey(suite *qtls.CipherSuite, trafficSecret []byte)
|
|||||||
default:
|
default:
|
||||||
panic("unexpected read encryption level")
|
panic("unexpected read encryption level")
|
||||||
}
|
}
|
||||||
|
h.mutex.Unlock()
|
||||||
h.receivedReadKey <- struct{}{}
|
h.receivedReadKey <- struct{}{}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -455,6 +460,7 @@ func (h *cryptoSetup) SetWriteKey(suite *qtls.CipherSuite, trafficSecret []byte)
|
|||||||
panic(fmt.Sprintf("error creating new AES cipher: %s", err))
|
panic(fmt.Sprintf("error creating new AES cipher: %s", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
h.mutex.Lock()
|
||||||
switch h.writeEncLevel {
|
switch h.writeEncLevel {
|
||||||
case protocol.EncryptionInitial:
|
case protocol.EncryptionInitial:
|
||||||
h.writeEncLevel = protocol.EncryptionHandshake
|
h.writeEncLevel = protocol.EncryptionHandshake
|
||||||
@@ -467,6 +473,7 @@ func (h *cryptoSetup) SetWriteKey(suite *qtls.CipherSuite, trafficSecret []byte)
|
|||||||
default:
|
default:
|
||||||
panic("unexpected write encryption level")
|
panic("unexpected write encryption level")
|
||||||
}
|
}
|
||||||
|
h.mutex.Unlock()
|
||||||
h.receivedWriteKey <- struct{}{}
|
h.receivedWriteKey <- struct{}{}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -479,6 +486,9 @@ func (h *cryptoSetup) WriteRecord(p []byte) (int, error) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
h.mutex.Lock()
|
||||||
|
defer h.mutex.Unlock()
|
||||||
|
|
||||||
switch h.writeEncLevel {
|
switch h.writeEncLevel {
|
||||||
case protocol.EncryptionInitial:
|
case protocol.EncryptionInitial:
|
||||||
// assume that the first WriteRecord call contains the ClientHello
|
// assume that the first WriteRecord call contains the ClientHello
|
||||||
@@ -502,6 +512,9 @@ func (h *cryptoSetup) SendAlert(alert uint8) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *cryptoSetup) GetSealer() (protocol.EncryptionLevel, Sealer) {
|
func (h *cryptoSetup) GetSealer() (protocol.EncryptionLevel, Sealer) {
|
||||||
|
h.mutex.Lock()
|
||||||
|
defer h.mutex.Unlock()
|
||||||
|
|
||||||
if h.sealer != nil {
|
if h.sealer != nil {
|
||||||
return protocol.Encryption1RTT, h.sealer
|
return protocol.Encryption1RTT, h.sealer
|
||||||
}
|
}
|
||||||
@@ -514,6 +527,9 @@ func (h *cryptoSetup) GetSealer() (protocol.EncryptionLevel, Sealer) {
|
|||||||
func (h *cryptoSetup) GetSealerWithEncryptionLevel(level protocol.EncryptionLevel) (Sealer, error) {
|
func (h *cryptoSetup) GetSealerWithEncryptionLevel(level protocol.EncryptionLevel) (Sealer, error) {
|
||||||
errNoSealer := fmt.Errorf("CryptoSetup: no sealer with encryption level %s", level.String())
|
errNoSealer := fmt.Errorf("CryptoSetup: no sealer with encryption level %s", level.String())
|
||||||
|
|
||||||
|
h.mutex.Lock()
|
||||||
|
defer h.mutex.Unlock()
|
||||||
|
|
||||||
switch level {
|
switch level {
|
||||||
case protocol.EncryptionInitial:
|
case protocol.EncryptionInitial:
|
||||||
return h.initialSealer, nil
|
return h.initialSealer, nil
|
||||||
@@ -533,6 +549,9 @@ func (h *cryptoSetup) GetSealerWithEncryptionLevel(level protocol.EncryptionLeve
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *cryptoSetup) GetOpener(level protocol.EncryptionLevel) (Opener, error) {
|
func (h *cryptoSetup) GetOpener(level protocol.EncryptionLevel) (Opener, error) {
|
||||||
|
h.mutex.Lock()
|
||||||
|
defer h.mutex.Unlock()
|
||||||
|
|
||||||
switch level {
|
switch level {
|
||||||
case protocol.EncryptionInitial:
|
case protocol.EncryptionInitial:
|
||||||
return h.initialOpener, nil
|
return h.initialOpener, nil
|
||||||
|
|||||||
Reference in New Issue
Block a user