Merge pull request #1552 from lucas-clemente/crypto-stream-empty-check

check that the crypto stream is empty when switching encryption levels
This commit is contained in:
Marten Seemann
2018-10-26 17:59:36 +07:00
committed by GitHub
12 changed files with 305 additions and 142 deletions

View File

@@ -1,6 +1,7 @@
package quic
import (
"errors"
"fmt"
"io"
@@ -13,6 +14,7 @@ type cryptoStream interface {
// for receiving data
HandleCryptoFrame(*wire.CryptoFrame) error
GetCryptoData() []byte
Finish() error
// for sending data
io.Writer
HasData() bool
@@ -20,7 +22,11 @@ type cryptoStream interface {
}
type cryptoStreamImpl struct {
queue *frameSorter
queue *frameSorter
msgBuf []byte
highestOffset protocol.ByteCount
finished bool
writeOffset protocol.ByteCount
writeBuf []byte
@@ -33,16 +39,53 @@ func newCryptoStream() cryptoStream {
}
func (s *cryptoStreamImpl) HandleCryptoFrame(f *wire.CryptoFrame) error {
if maxOffset := f.Offset + protocol.ByteCount(len(f.Data)); maxOffset > protocol.MaxCryptoStreamOffset {
highestOffset := f.Offset + protocol.ByteCount(len(f.Data))
if maxOffset := highestOffset; maxOffset > protocol.MaxCryptoStreamOffset {
return fmt.Errorf("received invalid offset %d on crypto stream, maximum allowed %d", maxOffset, protocol.MaxCryptoStreamOffset)
}
return s.queue.Push(f.Data, f.Offset, false)
if s.finished {
if highestOffset > s.highestOffset {
// reject crypto data received after this stream was already finished
return errors.New("received crypto data after change of encryption level")
}
// ignore data with a smaller offset than the highest received
// could e.g. be a retransmission
return nil
}
s.highestOffset = utils.MaxByteCount(s.highestOffset, highestOffset)
if err := s.queue.Push(f.Data, f.Offset, false); err != nil {
return err
}
for {
data, _ := s.queue.Pop()
if data == nil {
return nil
}
s.msgBuf = append(s.msgBuf, data...)
}
}
// GetCryptoData retrieves data that was received in CRYPTO frames
func (s *cryptoStreamImpl) GetCryptoData() []byte {
data, _ := s.queue.Pop()
return data
if len(s.msgBuf) < 4 {
return nil
}
msgLen := 4 + int(s.msgBuf[1])<<16 + int(s.msgBuf[2])<<8 + int(s.msgBuf[3])
if len(s.msgBuf) < msgLen {
return nil
}
msg := make([]byte, msgLen)
copy(msg, s.msgBuf[:msgLen])
s.msgBuf = s.msgBuf[msgLen:]
return msg
}
func (s *cryptoStreamImpl) Finish() error {
if s.queue.HasMoreData() {
return errors.New("encryption level changed, but crypto stream has more data to read")
}
s.finished = true
return nil
}
// Writes writes data that should be sent out in CRYPTO frames

View File

@@ -8,7 +8,7 @@ import (
)
type cryptoDataHandler interface {
HandleData([]byte, protocol.EncryptionLevel)
HandleMessage([]byte, protocol.EncryptionLevel) bool
}
type cryptoStreamManager struct {
@@ -30,7 +30,7 @@ func newCryptoStreamManager(
}
}
func (m *cryptoStreamManager) HandleCryptoFrame(frame *wire.CryptoFrame, encLevel protocol.EncryptionLevel) error {
func (m *cryptoStreamManager) HandleCryptoFrame(frame *wire.CryptoFrame, encLevel protocol.EncryptionLevel) (bool /* encryption level changed */, error) {
var str cryptoStream
switch encLevel {
case protocol.EncryptionInitial:
@@ -38,16 +38,18 @@ func (m *cryptoStreamManager) HandleCryptoFrame(frame *wire.CryptoFrame, encLeve
case protocol.EncryptionHandshake:
str = m.handshakeStream
default:
return fmt.Errorf("received CRYPTO frame with unexpected encryption level: %s", encLevel)
return false, fmt.Errorf("received CRYPTO frame with unexpected encryption level: %s", encLevel)
}
if err := str.HandleCryptoFrame(frame); err != nil {
return err
return false, err
}
for {
data := str.GetCryptoData()
if data == nil {
return nil
return false, nil
}
if encLevelFinished := m.cryptoHandler.HandleMessage(data, encLevel); encLevelFinished {
return true, str.Finish()
}
m.cryptoHandler.HandleData(data, encLevel)
}
}

View File

@@ -1,6 +1,8 @@
package quic
import (
"errors"
"github.com/golang/mock/gomock"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire"
@@ -13,43 +15,91 @@ var _ = Describe("Crypto Stream Manager", func() {
var (
csm *cryptoStreamManager
cs *MockCryptoDataHandler
initialStream *MockCryptoStream
handshakeStream *MockCryptoStream
)
BeforeEach(func() {
initialStream := newCryptoStream()
handshakeStream := newCryptoStream()
initialStream = NewMockCryptoStream(mockCtrl)
handshakeStream = NewMockCryptoStream(mockCtrl)
cs = NewMockCryptoDataHandler(mockCtrl)
csm = newCryptoStreamManager(cs, initialStream, handshakeStream)
})
It("handles in in-order crypto frame", func() {
f := &wire.CryptoFrame{Data: []byte("foobar")}
cs.EXPECT().HandleData([]byte("foobar"), protocol.EncryptionInitial)
Expect(csm.HandleCryptoFrame(f, protocol.EncryptionInitial)).To(Succeed())
It("passes messages to the initial stream", func() {
cf := &wire.CryptoFrame{Data: []byte("foobar")}
initialStream.EXPECT().HandleCryptoFrame(cf)
initialStream.EXPECT().GetCryptoData().Return([]byte("foobar"))
initialStream.EXPECT().GetCryptoData()
cs.EXPECT().HandleMessage([]byte("foobar"), protocol.EncryptionInitial)
encLevelChanged, err := csm.HandleCryptoFrame(cf, protocol.EncryptionInitial)
Expect(err).ToNot(HaveOccurred())
Expect(encLevelChanged).To(BeFalse())
})
It("passes messages to the handshake stream", func() {
cf := &wire.CryptoFrame{Data: []byte("foobar")}
handshakeStream.EXPECT().HandleCryptoFrame(cf)
handshakeStream.EXPECT().GetCryptoData().Return([]byte("foobar"))
handshakeStream.EXPECT().GetCryptoData()
cs.EXPECT().HandleMessage([]byte("foobar"), protocol.EncryptionHandshake)
encLevelChanged, err := csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake)
Expect(err).ToNot(HaveOccurred())
Expect(encLevelChanged).To(BeFalse())
})
It("doesn't call the message handler, if there's no message", func() {
cf := &wire.CryptoFrame{Data: []byte("foobar")}
handshakeStream.EXPECT().HandleCryptoFrame(cf)
handshakeStream.EXPECT().GetCryptoData() // don't return any data to handle
// don't EXPECT any calls to HandleMessage()
encLevelChanged, err := csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake)
Expect(err).ToNot(HaveOccurred())
Expect(encLevelChanged).To(BeFalse())
})
It("processes all messages", func() {
cf := &wire.CryptoFrame{Data: []byte("foobar")}
handshakeStream.EXPECT().HandleCryptoFrame(cf)
handshakeStream.EXPECT().GetCryptoData().Return([]byte("foo"))
handshakeStream.EXPECT().GetCryptoData().Return([]byte("bar"))
handshakeStream.EXPECT().GetCryptoData()
cs.EXPECT().HandleMessage([]byte("foo"), protocol.EncryptionHandshake)
cs.EXPECT().HandleMessage([]byte("bar"), protocol.EncryptionHandshake)
encLevelChanged, err := csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake)
Expect(err).ToNot(HaveOccurred())
Expect(encLevelChanged).To(BeFalse())
})
It("finishes the crypto stream, when the crypto setup is done with this encryption level", func() {
cf := &wire.CryptoFrame{Data: []byte("foobar")}
gomock.InOrder(
handshakeStream.EXPECT().HandleCryptoFrame(cf),
handshakeStream.EXPECT().GetCryptoData().Return([]byte("foobar")),
cs.EXPECT().HandleMessage([]byte("foobar"), protocol.EncryptionHandshake).Return(true),
handshakeStream.EXPECT().Finish(),
)
encLevelChanged, err := csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake)
Expect(err).ToNot(HaveOccurred())
Expect(encLevelChanged).To(BeTrue())
})
It("returns errors that occur when finishing a stream", func() {
testErr := errors.New("test error")
cf := &wire.CryptoFrame{Data: []byte("foobar")}
gomock.InOrder(
handshakeStream.EXPECT().HandleCryptoFrame(cf),
handshakeStream.EXPECT().GetCryptoData().Return([]byte("foobar")),
cs.EXPECT().HandleMessage([]byte("foobar"), protocol.EncryptionHandshake).Return(true),
handshakeStream.EXPECT().Finish().Return(testErr),
)
_, err := csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake)
Expect(err).To(MatchError(err))
})
It("errors for unknown encryption levels", func() {
err := csm.HandleCryptoFrame(&wire.CryptoFrame{}, protocol.Encryption1RTT)
_, err := csm.HandleCryptoFrame(&wire.CryptoFrame{}, protocol.Encryption1RTT)
Expect(err).To(MatchError("received CRYPTO frame with unexpected encryption level: 1-RTT"))
})
It("handles out-of-order crypto frames", func() {
f1 := &wire.CryptoFrame{Data: []byte("foo")}
f2 := &wire.CryptoFrame{
Offset: 3,
Data: []byte("bar"),
}
gomock.InOrder(
cs.EXPECT().HandleData([]byte("foo"), protocol.EncryptionInitial),
cs.EXPECT().HandleData([]byte("bar"), protocol.EncryptionInitial),
)
Expect(csm.HandleCryptoFrame(f1, protocol.EncryptionInitial)).To(Succeed())
Expect(csm.HandleCryptoFrame(f2, protocol.EncryptionInitial)).To(Succeed())
})
It("handles handshake data", func() {
f := &wire.CryptoFrame{Data: []byte("foobar")}
cs.EXPECT().HandleData([]byte("foobar"), protocol.EncryptionHandshake)
Expect(csm.HandleCryptoFrame(f, protocol.EncryptionHandshake)).To(Succeed())
})
})

View File

@@ -1,6 +1,7 @@
package quic
import (
"crypto/rand"
"fmt"
"github.com/lucas-clemente/quic-go/internal/protocol"
@@ -10,6 +11,16 @@ import (
. "github.com/onsi/gomega"
)
func createHandshakeMessage(len int) []byte {
msg := make([]byte, 4+len)
rand.Read(msg[:1]) // random message type
msg[1] = uint8(len >> 16)
msg[2] = uint8(len >> 8)
msg[3] = uint8(len)
rand.Read(msg[4:])
return msg
}
var _ = Describe("Crypto Stream", func() {
var (
str cryptoStream
@@ -21,11 +32,21 @@ var _ = Describe("Crypto Stream", func() {
Context("handling incoming data", func() {
It("handles in-order CRYPTO frames", func() {
err := str.HandleCryptoFrame(&wire.CryptoFrame{
Data: []byte("foobar"),
})
msg := createHandshakeMessage(6)
err := str.HandleCryptoFrame(&wire.CryptoFrame{Data: msg})
Expect(err).ToNot(HaveOccurred())
Expect(str.GetCryptoData()).To(Equal([]byte("foobar")))
Expect(str.GetCryptoData()).To(Equal(msg))
Expect(str.GetCryptoData()).To(BeNil())
})
It("handles multiple messages in one CRYPTO frame", func() {
msg1 := createHandshakeMessage(6)
msg2 := createHandshakeMessage(10)
msg := append(append([]byte{}, msg1...), msg2...)
err := str.HandleCryptoFrame(&wire.CryptoFrame{Data: msg})
Expect(err).ToNot(HaveOccurred())
Expect(str.GetCryptoData()).To(Equal(msg1))
Expect(str.GetCryptoData()).To(Equal(msg2))
Expect(str.GetCryptoData()).To(BeNil())
})
@@ -37,21 +58,83 @@ var _ = Describe("Crypto Stream", func() {
Expect(err).To(MatchError(fmt.Sprintf("received invalid offset %d on crypto stream, maximum allowed %d", protocol.MaxCryptoStreamOffset+1, protocol.MaxCryptoStreamOffset)))
})
It("handles out-of-order CRYPTO frames", func() {
It("handles messages split over multiple CRYPTO frames", func() {
msg := createHandshakeMessage(6)
err := str.HandleCryptoFrame(&wire.CryptoFrame{
Offset: 3,
Data: []byte("bar"),
Data: msg[:4],
})
Expect(err).ToNot(HaveOccurred())
Expect(str.GetCryptoData()).To(BeNil())
err = str.HandleCryptoFrame(&wire.CryptoFrame{
Data: []byte("foo"),
Offset: 4,
Data: msg[4:],
})
Expect(err).ToNot(HaveOccurred())
Expect(str.GetCryptoData()).To(Equal([]byte("foo")))
Expect(str.GetCryptoData()).To(Equal([]byte("bar")))
Expect(str.GetCryptoData()).To(Equal(msg))
Expect(str.GetCryptoData()).To(BeNil())
})
It("handles out-of-order CRYPTO frames", func() {
msg := createHandshakeMessage(6)
err := str.HandleCryptoFrame(&wire.CryptoFrame{
Offset: 4,
Data: msg[4:],
})
Expect(err).ToNot(HaveOccurred())
Expect(str.GetCryptoData()).To(BeNil())
err = str.HandleCryptoFrame(&wire.CryptoFrame{
Data: msg[:4],
})
Expect(err).ToNot(HaveOccurred())
Expect(str.GetCryptoData()).To(Equal(msg))
Expect(str.GetCryptoData()).To(BeNil())
})
Context("finishing", func() {
It("errors if there's still data to read after finishing", func() {
Expect(str.HandleCryptoFrame(&wire.CryptoFrame{
Data: createHandshakeMessage(5),
Offset: 10,
})).To(Succeed())
err := str.Finish()
Expect(err).To(MatchError("encryption level changed, but crypto stream has more data to read"))
})
It("works with reordered data", func() {
f1 := &wire.CryptoFrame{
Data: []byte("foo"),
}
f2 := &wire.CryptoFrame{
Offset: 3,
Data: []byte("bar"),
}
Expect(str.HandleCryptoFrame(f2)).To(Succeed())
Expect(str.HandleCryptoFrame(f1)).To(Succeed())
Expect(str.Finish()).To(Succeed())
Expect(str.HandleCryptoFrame(f2)).To(Succeed())
})
It("rejects new crypto data after finishing", func() {
Expect(str.Finish()).To(Succeed())
err := str.HandleCryptoFrame(&wire.CryptoFrame{
Data: createHandshakeMessage(5),
})
Expect(err).To(MatchError("received crypto data after change of encryption level"))
})
It("ignores crypto data below the maximum offset received before finishing", func() {
msg := createHandshakeMessage(15)
Expect(str.HandleCryptoFrame(&wire.CryptoFrame{
Data: msg,
})).To(Succeed())
Expect(str.GetCryptoData()).To(Equal(msg))
Expect(str.Finish()).To(Succeed())
Expect(str.HandleCryptoFrame(&wire.CryptoFrame{
Offset: protocol.ByteCount(len(msg) - 6),
Data: []byte("foobar"),
})).To(Succeed())
})
})
})
Context("writing data", func() {

View File

@@ -156,3 +156,8 @@ func (s *frameSorter) Pop() ([]byte /* data */, bool /* fin */) {
s.readPos += protocol.ByteCount(len(data))
return data, s.readPos >= s.finalOffset
}
// HasMoreData says if there is any more data queued at *any* offset.
func (s *frameSorter) HasMoreData() bool {
return len(s.queue) > 0
}

View File

@@ -55,6 +55,15 @@ var _ = Describe("STREAM frame sorter", func() {
Expect(s.Pop()).To(BeNil())
})
It("says if has more data", func() {
Expect(s.HasMoreData()).To(BeFalse())
Expect(s.Push([]byte("foo"), 0, false)).To(Succeed())
Expect(s.HasMoreData()).To(BeTrue())
data, _ := s.Pop()
Expect(data).To(Equal([]byte("foo")))
Expect(s.HasMoreData()).To(BeFalse())
})
Context("FIN handling", func() {
It("saves a FIN at offset 0", func() {
Expect(s.Push(nil, 0, true)).To(Succeed())

View File

@@ -1,7 +1,6 @@
package handshake
import (
"bytes"
"crypto/tls"
"errors"
"fmt"
@@ -64,8 +63,6 @@ type cryptoSetupTLS struct {
handshakeErrChan chan struct{}
// HandleData() sends errors on the messageErrChan
messageErrChan chan error
// handshakeEvent signals a change of encryption level to the session
handshakeEvent chan<- struct{}
// handshakeComplete is closed when the handshake completes
handshakeComplete chan<- struct{}
// transport parameters are sent on the receivedTransportParams, as soon as they are received
@@ -74,14 +71,12 @@ type cryptoSetupTLS struct {
clientHelloWritten bool
clientHelloWrittenChan chan struct{}
initialReadBuf bytes.Buffer
initialStream io.Writer
initialAEAD crypto.AEAD
initialStream io.Writer
initialAEAD crypto.AEAD
handshakeReadBuf bytes.Buffer
handshakeStream io.Writer
handshakeOpener Opener
handshakeSealer Sealer
handshakeStream io.Writer
handshakeOpener Opener
handshakeSealer Sealer
opener Opener
sealer Sealer
@@ -111,7 +106,6 @@ func NewCryptoSetupTLSClient(
connID protocol.ConnectionID,
params *TransportParameters,
handleParams func(*TransportParameters),
handshakeEvent chan<- struct{},
handshakeComplete chan<- struct{},
tlsConf *tls.Config,
initialVersion protocol.VersionNumber,
@@ -126,7 +120,6 @@ func NewCryptoSetupTLSClient(
connID,
params,
handleParams,
handshakeEvent,
handshakeComplete,
tlsConf,
versionInfo{
@@ -146,7 +139,6 @@ func NewCryptoSetupTLSServer(
connID protocol.ConnectionID,
params *TransportParameters,
handleParams func(*TransportParameters),
handshakeEvent chan<- struct{},
handshakeComplete chan<- struct{},
tlsConf *tls.Config,
supportedVersions []protocol.VersionNumber,
@@ -160,7 +152,6 @@ func NewCryptoSetupTLSServer(
connID,
params,
handleParams,
handshakeEvent,
handshakeComplete,
tlsConf,
versionInfo{
@@ -179,7 +170,6 @@ func newCryptoSetupTLS(
connID protocol.ConnectionID,
params *TransportParameters,
handleParams func(*TransportParameters),
handshakeEvent chan<- struct{},
handshakeComplete chan<- struct{},
tlsConf *tls.Config,
versionInfo versionInfo,
@@ -197,7 +187,6 @@ func newCryptoSetupTLS(
readEncLevel: protocol.EncryptionInitial,
writeEncLevel: protocol.EncryptionInitial,
handleParamsCallback: handleParams,
handshakeEvent: handshakeEvent,
handshakeComplete: handshakeComplete,
logger: logger,
perspective: perspective,
@@ -272,51 +261,25 @@ func (h *cryptoSetupTLS) RunHandshake() error {
}
}
func (h *cryptoSetupTLS) HandleData(data []byte, encLevel protocol.EncryptionLevel) {
var buf *bytes.Buffer
switch encLevel {
case protocol.EncryptionInitial:
buf = &h.initialReadBuf
case protocol.EncryptionHandshake:
buf = &h.handshakeReadBuf
default:
h.messageErrChan <- fmt.Errorf("received handshake data with unexpected encryption level: %s", encLevel)
return
}
buf.Write(data)
for buf.Len() >= 4 {
b := buf.Bytes()
// read the TLS message length
length := int(b[1])<<16 | int(b[2])<<8 | int(b[3])
if buf.Len() < 4+length { // message not yet complete
return
}
msg := make([]byte, length+4)
buf.Read(msg)
if err := h.handleMessage(msg, encLevel); err != nil {
h.messageErrChan <- err
}
}
}
// handleMessage handles a TLS handshake message.
// It is called by the crypto streams when a new message is available.
func (h *cryptoSetupTLS) handleMessage(data []byte, encLevel protocol.EncryptionLevel) error {
// It returns if it is done with messages on the same encryption level.
func (h *cryptoSetupTLS) HandleMessage(data []byte, encLevel protocol.EncryptionLevel) bool /* stream finished */ {
msgType := messageType(data[0])
h.logger.Debugf("Received %s message (%d bytes, encryption level: %s)", msgType, len(data), encLevel)
if err := h.checkEncryptionLevel(msgType, encLevel); err != nil {
return err
h.messageErrChan <- err
return false
}
h.messageChan <- data
switch h.perspective {
case protocol.PerspectiveClient:
h.handleMessageForClient(msgType)
return h.handleMessageForClient(msgType)
case protocol.PerspectiveServer:
h.handleMessageForServer(msgType)
return h.handleMessageForServer(msgType)
default:
panic("")
}
return nil
}
func (h *cryptoSetupTLS) checkEncryptionLevel(msgType messageType, encLevel protocol.EncryptionLevel) error {
@@ -340,78 +303,78 @@ func (h *cryptoSetupTLS) checkEncryptionLevel(msgType messageType, encLevel prot
return nil
}
func (h *cryptoSetupTLS) handleMessageForServer(msgType messageType) {
func (h *cryptoSetupTLS) handleMessageForServer(msgType messageType) bool {
switch msgType {
case typeClientHello:
select {
case params := <-h.receivedTransportParams:
h.handleParamsCallback(&params)
case <-h.handshakeErrChan:
return
return false
}
// get the handshake write key
select {
case <-h.receivedWriteKey:
case <-h.handshakeErrChan:
return
return false
}
// get the 1-RTT write key
select {
case <-h.receivedWriteKey:
case <-h.handshakeErrChan:
return
return false
}
// get the handshake read key
// TODO: check that the initial stream doesn't have any more data
select {
case <-h.receivedReadKey:
case <-h.handshakeErrChan:
return
return false
}
h.handshakeEvent <- struct{}{}
return true
case typeCertificate, typeCertificateVerify:
// nothing to do
return false
case typeFinished:
// get the 1-RTT read key
// TODO: check that the handshake stream doesn't have any more data
select {
case <-h.receivedReadKey:
case <-h.handshakeErrChan:
return
return false
}
h.handshakeEvent <- struct{}{}
return true
default:
panic("unexpected handshake message")
}
}
func (h *cryptoSetupTLS) handleMessageForClient(msgType messageType) {
func (h *cryptoSetupTLS) handleMessageForClient(msgType messageType) bool {
switch msgType {
case typeServerHello:
// get the handshake read key
// TODO: check that the initial stream doesn't have any more data
select {
case <-h.receivedReadKey:
case <-h.handshakeErrChan:
return
return false
}
h.handshakeEvent <- struct{}{}
return true
case typeEncryptedExtensions:
select {
case params := <-h.receivedTransportParams:
h.handleParamsCallback(&params)
case <-h.handshakeErrChan:
return
return false
}
return false
case typeCertificateRequest, typeCertificate, typeCertificateVerify:
// nothing to do
return false
case typeFinished:
// get the handshake write key
// TODO: check that the initial stream doesn't have any more data
select {
case <-h.receivedWriteKey:
case <-h.handshakeErrChan:
return
return false
}
// While the order of these two is not defined by the TLS spec,
// we have to do it on the same order as our TLS library does it.
@@ -419,16 +382,15 @@ func (h *cryptoSetupTLS) handleMessageForClient(msgType messageType) {
select {
case <-h.receivedWriteKey:
case <-h.handshakeErrChan:
return
return false
}
// get the 1-RTT read key
select {
case <-h.receivedReadKey:
case <-h.handshakeErrChan:
return
return false
}
// TODO: check that the handshake stream doesn't have any more data
h.handshakeEvent <- struct{}{}
return true
default:
panic("unexpected handshake message: ")
}

View File

@@ -63,7 +63,6 @@ var _ = Describe("Crypto Setup TLS", func() {
protocol.ConnectionID{},
&TransportParameters{},
func(p *TransportParameters) {},
make(chan struct{}, 100),
make(chan struct{}),
testdata.GetTLSConfig(),
[]protocol.VersionNumber{protocol.VersionTLS},
@@ -83,7 +82,7 @@ var _ = Describe("Crypto Setup TLS", func() {
}()
fakeCH := append([]byte{byte(typeClientHello), 0, 0, 6}, []byte("foobar")...)
server.HandleData(fakeCH, protocol.EncryptionInitial)
server.HandleMessage(fakeCH, protocol.EncryptionInitial)
Eventually(done).Should(BeClosed())
})
@@ -95,7 +94,6 @@ var _ = Describe("Crypto Setup TLS", func() {
protocol.ConnectionID{},
&TransportParameters{},
func(p *TransportParameters) {},
make(chan struct{}, 100),
make(chan struct{}),
testdata.GetTLSConfig(),
[]protocol.VersionNumber{protocol.VersionTLS},
@@ -114,7 +112,7 @@ var _ = Describe("Crypto Setup TLS", func() {
}()
fakeCH := append([]byte{byte(typeClientHello), 0, 0, 6}, []byte("foobar")...)
server.HandleData(fakeCH, protocol.EncryptionHandshake) // wrong encryption level
server.HandleMessage(fakeCH, protocol.EncryptionHandshake) // wrong encryption level
Eventually(done).Should(BeClosed())
})
@@ -150,9 +148,9 @@ var _ = Describe("Crypto Setup TLS", func() {
for {
select {
case c := <-cChunkChan:
server.HandleData(c.data, c.encLevel)
server.HandleMessage(c.data, c.encLevel)
case c := <-sChunkChan:
client.HandleData(c.data, c.encLevel)
client.HandleMessage(c.data, c.encLevel)
case <-done: // handshake complete
}
}
@@ -178,7 +176,6 @@ var _ = Describe("Crypto Setup TLS", func() {
protocol.ConnectionID{},
&TransportParameters{},
func(p *TransportParameters) {},
make(chan struct{}, 100),
make(chan struct{}),
clientConf,
protocol.VersionTLS,
@@ -196,7 +193,6 @@ var _ = Describe("Crypto Setup TLS", func() {
protocol.ConnectionID{},
&TransportParameters{StatelessResetToken: bytes.Repeat([]byte{42}, 16)},
func(p *TransportParameters) {},
make(chan struct{}, 100),
make(chan struct{}),
serverConf,
[]protocol.VersionNumber{protocol.VersionTLS},
@@ -237,7 +233,6 @@ var _ = Describe("Crypto Setup TLS", func() {
protocol.ConnectionID{},
&TransportParameters{},
func(p *TransportParameters) {},
make(chan struct{}, 100),
make(chan struct{}),
&tls.Config{InsecureSkipVerify: true},
protocol.VersionTLS,
@@ -264,7 +259,7 @@ var _ = Describe("Crypto Setup TLS", func() {
Expect(len(ch.data) - 4).To(Equal(length))
// make the go routine return
client.HandleData([]byte{42 /* unknown handshake message type */, 0, 0, 1, 0}, protocol.EncryptionInitial)
client.HandleMessage([]byte{42 /* unknown handshake message type */, 0, 0, 1, 0}, protocol.EncryptionInitial)
Eventually(done).Should(BeClosed())
})
@@ -278,7 +273,6 @@ var _ = Describe("Crypto Setup TLS", func() {
protocol.ConnectionID{},
cTransportParameters,
func(p *TransportParameters) { sTransportParametersRcvd = p },
make(chan struct{}, 100),
make(chan struct{}),
&tls.Config{ServerName: "quic.clemente.io"},
protocol.VersionTLS,
@@ -300,7 +294,6 @@ var _ = Describe("Crypto Setup TLS", func() {
protocol.ConnectionID{},
sTransportParameters,
func(p *TransportParameters) { cTransportParametersRcvd = p },
make(chan struct{}, 100),
make(chan struct{}),
testdata.GetTLSConfig(),
[]protocol.VersionNumber{protocol.VersionTLS},

View File

@@ -44,7 +44,7 @@ type CryptoSetup interface {
type CryptoSetupTLS interface {
baseCryptoSetup
HandleData([]byte, protocol.EncryptionLevel)
HandleMessage([]byte, protocol.EncryptionLevel) bool
OpenInitial(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error)
OpenHandshake(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error)
Open1RTT(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error)

View File

@@ -34,12 +34,14 @@ func (m *MockCryptoDataHandler) EXPECT() *MockCryptoDataHandlerMockRecorder {
return m.recorder
}
// HandleData mocks base method
func (m *MockCryptoDataHandler) HandleData(arg0 []byte, arg1 protocol.EncryptionLevel) {
m.ctrl.Call(m, "HandleData", arg0, arg1)
// HandleMessage mocks base method
func (m *MockCryptoDataHandler) HandleMessage(arg0 []byte, arg1 protocol.EncryptionLevel) bool {
ret := m.ctrl.Call(m, "HandleMessage", arg0, arg1)
ret0, _ := ret[0].(bool)
return ret0
}
// HandleData indicates an expected call of HandleData
func (mr *MockCryptoDataHandlerMockRecorder) HandleData(arg0, arg1 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleData", reflect.TypeOf((*MockCryptoDataHandler)(nil).HandleData), arg0, arg1)
// HandleMessage indicates an expected call of HandleMessage
func (mr *MockCryptoDataHandlerMockRecorder) HandleMessage(arg0, arg1 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleMessage", reflect.TypeOf((*MockCryptoDataHandler)(nil).HandleMessage), arg0, arg1)
}

View File

@@ -35,6 +35,18 @@ func (m *MockCryptoStream) EXPECT() *MockCryptoStreamMockRecorder {
return m.recorder
}
// Finish mocks base method
func (m *MockCryptoStream) Finish() error {
ret := m.ctrl.Call(m, "Finish")
ret0, _ := ret[0].(error)
return ret0
}
// Finish indicates an expected call of Finish
func (mr *MockCryptoStreamMockRecorder) Finish() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Finish", reflect.TypeOf((*MockCryptoStream)(nil).Finish))
}
// GetCryptoData mocks base method
func (m *MockCryptoStream) GetCryptoData() []byte {
ret := m.ctrl.Call(m, "GetCryptoData")

View File

@@ -120,6 +120,7 @@ type session struct {
paramsChan <-chan handshake.TransportParameters
// the handshakeEvent channel is passed to the CryptoSetup.
// It receives when it makes sense to try decrypting undecryptable packets.
// Only used for gQUIC.
handshakeEvent <-chan struct{}
handshakeCompleteChan <-chan struct{} // is closed when the handshake completes
handshakeComplete bool
@@ -325,7 +326,6 @@ func newTLSServerSession(
logger utils.Logger,
v protocol.VersionNumber,
) (quicSession, error) {
handshakeEvent := make(chan struct{}, 2) // TODO: explain cap
handshakeCompleteChan := make(chan struct{})
s := &session{
conn: conn,
@@ -334,7 +334,6 @@ func newTLSServerSession(
srcConnID: srcConnID,
destConnID: destConnID,
perspective: protocol.PerspectiveServer,
handshakeEvent: handshakeEvent,
handshakeCompleteChan: handshakeCompleteChan,
logger: logger,
version: v,
@@ -350,7 +349,6 @@ func newTLSServerSession(
origConnID,
params,
s.processTransportParameters,
handshakeEvent,
handshakeCompleteChan,
tlsConf,
conf.Versions,
@@ -403,7 +401,6 @@ var newTLSClientSession = func(
logger utils.Logger,
v protocol.VersionNumber,
) (quicSession, error) {
handshakeEvent := make(chan struct{}, 2) // TODO: explain cap
handshakeCompleteChan := make(chan struct{})
s := &session{
conn: conn,
@@ -412,7 +409,6 @@ var newTLSClientSession = func(
srcConnID: srcConnID,
destConnID: destConnID,
perspective: protocol.PerspectiveClient,
handshakeEvent: handshakeEvent,
handshakeCompleteChan: handshakeCompleteChan,
logger: logger,
version: v,
@@ -426,7 +422,6 @@ var newTLSClientSession = func(
s.destConnID,
params,
s.processTransportParameters,
handshakeEvent,
handshakeCompleteChan,
tlsConf,
initialVersion,
@@ -804,7 +799,14 @@ func (s *session) handlePacket(p *receivedPacket) {
}
func (s *session) handleCryptoFrame(frame *wire.CryptoFrame, encLevel protocol.EncryptionLevel) error {
return s.cryptoStreamManager.HandleCryptoFrame(frame, encLevel)
encLevelChanged, err := s.cryptoStreamManager.HandleCryptoFrame(frame, encLevel)
if err != nil {
return err
}
if encLevelChanged {
s.tryDecryptingQueuedPackets()
}
return nil
}
func (s *session) handleStreamFrame(frame *wire.StreamFrame, encLevel protocol.EncryptionLevel) error {