From 19e5feef57b1c6b386228dc5d70d76f99f24933e Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 20 Oct 2018 10:11:25 +0900 Subject: [PATCH 1/4] move TLS message header parsing logic to the crypto stream --- crypto_stream.go | 27 ++++++++-- crypto_stream_manager.go | 4 +- crypto_stream_manager_test.go | 50 +++++++++--------- crypto_stream_test.go | 57 +++++++++++++++++---- internal/handshake/crypto_setup_tls.go | 46 +++-------------- internal/handshake/crypto_setup_tls_test.go | 10 ++-- internal/handshake/interface.go | 2 +- mock_crypto_data_handler.go | 12 ++--- 8 files changed, 117 insertions(+), 91 deletions(-) diff --git a/crypto_stream.go b/crypto_stream.go index 879333578..fbd41d7e2 100644 --- a/crypto_stream.go +++ b/crypto_stream.go @@ -20,7 +20,8 @@ type cryptoStream interface { } type cryptoStreamImpl struct { - queue *frameSorter + queue *frameSorter + msgBuf []byte writeOffset protocol.ByteCount writeBuf []byte @@ -36,13 +37,31 @@ func (s *cryptoStreamImpl) HandleCryptoFrame(f *wire.CryptoFrame) error { if maxOffset := f.Offset + protocol.ByteCount(len(f.Data)); maxOffset > protocol.MaxCryptoStreamOffset { return fmt.Errorf("received invalid offset %d on crypto stream, maximum allowed %d", maxOffset, protocol.MaxCryptoStreamOffset) } - return s.queue.Push(f.Data, f.Offset, false) + if err := s.queue.Push(f.Data, f.Offset, false); err != nil { + return err + } + for { + data, _ := s.queue.Pop() + if data == nil { + return nil + } + s.msgBuf = append(s.msgBuf, data...) + } } // GetCryptoData retrieves data that was received in CRYPTO frames func (s *cryptoStreamImpl) GetCryptoData() []byte { - data, _ := s.queue.Pop() - return data + if len(s.msgBuf) < 4 { + return nil + } + msgLen := 4 + int(s.msgBuf[1])<<16 + int(s.msgBuf[2])<<8 + int(s.msgBuf[3]) + if len(s.msgBuf) < msgLen { + return nil + } + msg := make([]byte, msgLen) + copy(msg, s.msgBuf[:msgLen]) + s.msgBuf = s.msgBuf[msgLen:] + return msg } // Writes writes data that should be sent out in CRYPTO frames diff --git a/crypto_stream_manager.go b/crypto_stream_manager.go index 0c1d46942..764bc2f23 100644 --- a/crypto_stream_manager.go +++ b/crypto_stream_manager.go @@ -8,7 +8,7 @@ import ( ) type cryptoDataHandler interface { - HandleData([]byte, protocol.EncryptionLevel) + HandleMessage([]byte, protocol.EncryptionLevel) } type cryptoStreamManager struct { @@ -48,6 +48,6 @@ func (m *cryptoStreamManager) HandleCryptoFrame(frame *wire.CryptoFrame, encLeve if data == nil { return nil } - m.cryptoHandler.HandleData(data, encLevel) + m.cryptoHandler.HandleMessage(data, encLevel) } } diff --git a/crypto_stream_manager_test.go b/crypto_stream_manager_test.go index a7b777c5b..733d741cf 100644 --- a/crypto_stream_manager_test.go +++ b/crypto_stream_manager_test.go @@ -1,7 +1,6 @@ package quic import ( - "github.com/golang/mock/gomock" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/wire" @@ -22,34 +21,35 @@ var _ = Describe("Crypto Stream Manager", func() { csm = newCryptoStreamManager(cs, initialStream, handshakeStream) }) - It("handles in in-order crypto frame", func() { - f := &wire.CryptoFrame{Data: []byte("foobar")} - cs.EXPECT().HandleData([]byte("foobar"), protocol.EncryptionInitial) - Expect(csm.HandleCryptoFrame(f, protocol.EncryptionInitial)).To(Succeed()) + It("passes messages to the right stream", func() { + initialMsg := createHandshakeMessage(10) + handshakeMsg := createHandshakeMessage(20) + + // only pass in a part of the message, to make sure they get assembled in the right crypto stream + Expect(csm.HandleCryptoFrame(&wire.CryptoFrame{ + Data: initialMsg[:5], + }, protocol.EncryptionInitial)).To(Succeed()) + Expect(csm.HandleCryptoFrame(&wire.CryptoFrame{ + Data: handshakeMsg[:5], + }, protocol.EncryptionHandshake)).To(Succeed()) + + // now pass in the rest of the initial message + cs.EXPECT().HandleMessage(initialMsg, protocol.EncryptionInitial) + Expect(csm.HandleCryptoFrame(&wire.CryptoFrame{ + Data: initialMsg[5:], + Offset: 5, + }, protocol.EncryptionInitial)).To(Succeed()) + + // now pass in the rest of the handshake message + cs.EXPECT().HandleMessage(handshakeMsg, protocol.EncryptionHandshake) + Expect(csm.HandleCryptoFrame(&wire.CryptoFrame{ + Data: handshakeMsg[5:], + Offset: 5, + }, protocol.EncryptionHandshake)).To(Succeed()) }) It("errors for unknown encryption levels", func() { err := csm.HandleCryptoFrame(&wire.CryptoFrame{}, protocol.Encryption1RTT) Expect(err).To(MatchError("received CRYPTO frame with unexpected encryption level: 1-RTT")) }) - - It("handles out-of-order crypto frames", func() { - f1 := &wire.CryptoFrame{Data: []byte("foo")} - f2 := &wire.CryptoFrame{ - Offset: 3, - Data: []byte("bar"), - } - gomock.InOrder( - cs.EXPECT().HandleData([]byte("foo"), protocol.EncryptionInitial), - cs.EXPECT().HandleData([]byte("bar"), protocol.EncryptionInitial), - ) - Expect(csm.HandleCryptoFrame(f1, protocol.EncryptionInitial)).To(Succeed()) - Expect(csm.HandleCryptoFrame(f2, protocol.EncryptionInitial)).To(Succeed()) - }) - - It("handles handshake data", func() { - f := &wire.CryptoFrame{Data: []byte("foobar")} - cs.EXPECT().HandleData([]byte("foobar"), protocol.EncryptionHandshake) - Expect(csm.HandleCryptoFrame(f, protocol.EncryptionHandshake)).To(Succeed()) - }) }) diff --git a/crypto_stream_test.go b/crypto_stream_test.go index 514930dbc..98f35619f 100644 --- a/crypto_stream_test.go +++ b/crypto_stream_test.go @@ -1,6 +1,7 @@ package quic import ( + "crypto/rand" "fmt" "github.com/lucas-clemente/quic-go/internal/protocol" @@ -10,6 +11,16 @@ import ( . "github.com/onsi/gomega" ) +func createHandshakeMessage(len int) []byte { + msg := make([]byte, 4+len) + rand.Read(msg[:1]) // random message type + msg[1] = uint8(len >> 16) + msg[2] = uint8(len >> 8) + msg[3] = uint8(len) + rand.Read(msg[4:]) + return msg +} + var _ = Describe("Crypto Stream", func() { var ( str cryptoStream @@ -21,11 +32,21 @@ var _ = Describe("Crypto Stream", func() { Context("handling incoming data", func() { It("handles in-order CRYPTO frames", func() { - err := str.HandleCryptoFrame(&wire.CryptoFrame{ - Data: []byte("foobar"), - }) + msg := createHandshakeMessage(6) + err := str.HandleCryptoFrame(&wire.CryptoFrame{Data: msg}) Expect(err).ToNot(HaveOccurred()) - Expect(str.GetCryptoData()).To(Equal([]byte("foobar"))) + Expect(str.GetCryptoData()).To(Equal(msg)) + Expect(str.GetCryptoData()).To(BeNil()) + }) + + It("handles multiple messages in one CRYPTO frame", func() { + msg1 := createHandshakeMessage(6) + msg2 := createHandshakeMessage(10) + msg := append(append([]byte{}, msg1...), msg2...) + err := str.HandleCryptoFrame(&wire.CryptoFrame{Data: msg}) + Expect(err).ToNot(HaveOccurred()) + Expect(str.GetCryptoData()).To(Equal(msg1)) + Expect(str.GetCryptoData()).To(Equal(msg2)) Expect(str.GetCryptoData()).To(BeNil()) }) @@ -37,19 +58,35 @@ var _ = Describe("Crypto Stream", func() { Expect(err).To(MatchError(fmt.Sprintf("received invalid offset %d on crypto stream, maximum allowed %d", protocol.MaxCryptoStreamOffset+1, protocol.MaxCryptoStreamOffset))) }) - It("handles out-of-order CRYPTO frames", func() { + It("handles messages split over multiple CRYPTO frames", func() { + msg := createHandshakeMessage(6) err := str.HandleCryptoFrame(&wire.CryptoFrame{ - Offset: 3, - Data: []byte("bar"), + Data: msg[:4], }) Expect(err).ToNot(HaveOccurred()) Expect(str.GetCryptoData()).To(BeNil()) err = str.HandleCryptoFrame(&wire.CryptoFrame{ - Data: []byte("foo"), + Offset: 4, + Data: msg[4:], }) Expect(err).ToNot(HaveOccurred()) - Expect(str.GetCryptoData()).To(Equal([]byte("foo"))) - Expect(str.GetCryptoData()).To(Equal([]byte("bar"))) + Expect(str.GetCryptoData()).To(Equal(msg)) + Expect(str.GetCryptoData()).To(BeNil()) + }) + + It("handles out-of-order CRYPTO frames", func() { + msg := createHandshakeMessage(6) + err := str.HandleCryptoFrame(&wire.CryptoFrame{ + Offset: 4, + Data: msg[4:], + }) + Expect(err).ToNot(HaveOccurred()) + Expect(str.GetCryptoData()).To(BeNil()) + err = str.HandleCryptoFrame(&wire.CryptoFrame{ + Data: msg[:4], + }) + Expect(err).ToNot(HaveOccurred()) + Expect(str.GetCryptoData()).To(Equal(msg)) Expect(str.GetCryptoData()).To(BeNil()) }) }) diff --git a/internal/handshake/crypto_setup_tls.go b/internal/handshake/crypto_setup_tls.go index 3b0970d57..efe0ba3f1 100644 --- a/internal/handshake/crypto_setup_tls.go +++ b/internal/handshake/crypto_setup_tls.go @@ -1,7 +1,6 @@ package handshake import ( - "bytes" "crypto/tls" "errors" "fmt" @@ -74,14 +73,12 @@ type cryptoSetupTLS struct { clientHelloWritten bool clientHelloWrittenChan chan struct{} - initialReadBuf bytes.Buffer - initialStream io.Writer - initialAEAD crypto.AEAD + initialStream io.Writer + initialAEAD crypto.AEAD - handshakeReadBuf bytes.Buffer - handshakeStream io.Writer - handshakeOpener Opener - handshakeSealer Sealer + handshakeStream io.Writer + handshakeOpener Opener + handshakeSealer Sealer opener Opener sealer Sealer @@ -272,40 +269,14 @@ func (h *cryptoSetupTLS) RunHandshake() error { } } -func (h *cryptoSetupTLS) HandleData(data []byte, encLevel protocol.EncryptionLevel) { - var buf *bytes.Buffer - switch encLevel { - case protocol.EncryptionInitial: - buf = &h.initialReadBuf - case protocol.EncryptionHandshake: - buf = &h.handshakeReadBuf - default: - h.messageErrChan <- fmt.Errorf("received handshake data with unexpected encryption level: %s", encLevel) - return - } - buf.Write(data) - for buf.Len() >= 4 { - b := buf.Bytes() - // read the TLS message length - length := int(b[1])<<16 | int(b[2])<<8 | int(b[3]) - if buf.Len() < 4+length { // message not yet complete - return - } - msg := make([]byte, length+4) - buf.Read(msg) - if err := h.handleMessage(msg, encLevel); err != nil { - h.messageErrChan <- err - } - } -} - // handleMessage handles a TLS handshake message. // It is called by the crypto streams when a new message is available. -func (h *cryptoSetupTLS) handleMessage(data []byte, encLevel protocol.EncryptionLevel) error { +func (h *cryptoSetupTLS) HandleMessage(data []byte, encLevel protocol.EncryptionLevel) { msgType := messageType(data[0]) h.logger.Debugf("Received %s message (%d bytes, encryption level: %s)", msgType, len(data), encLevel) if err := h.checkEncryptionLevel(msgType, encLevel); err != nil { - return err + h.messageErrChan <- err + return } h.messageChan <- data switch h.perspective { @@ -316,7 +287,6 @@ func (h *cryptoSetupTLS) handleMessage(data []byte, encLevel protocol.Encryption default: panic("") } - return nil } func (h *cryptoSetupTLS) checkEncryptionLevel(msgType messageType, encLevel protocol.EncryptionLevel) error { diff --git a/internal/handshake/crypto_setup_tls_test.go b/internal/handshake/crypto_setup_tls_test.go index 4ff37d040..0f5206776 100644 --- a/internal/handshake/crypto_setup_tls_test.go +++ b/internal/handshake/crypto_setup_tls_test.go @@ -83,7 +83,7 @@ var _ = Describe("Crypto Setup TLS", func() { }() fakeCH := append([]byte{byte(typeClientHello), 0, 0, 6}, []byte("foobar")...) - server.HandleData(fakeCH, protocol.EncryptionInitial) + server.HandleMessage(fakeCH, protocol.EncryptionInitial) Eventually(done).Should(BeClosed()) }) @@ -114,7 +114,7 @@ var _ = Describe("Crypto Setup TLS", func() { }() fakeCH := append([]byte{byte(typeClientHello), 0, 0, 6}, []byte("foobar")...) - server.HandleData(fakeCH, protocol.EncryptionHandshake) // wrong encryption level + server.HandleMessage(fakeCH, protocol.EncryptionHandshake) // wrong encryption level Eventually(done).Should(BeClosed()) }) @@ -150,9 +150,9 @@ var _ = Describe("Crypto Setup TLS", func() { for { select { case c := <-cChunkChan: - server.HandleData(c.data, c.encLevel) + server.HandleMessage(c.data, c.encLevel) case c := <-sChunkChan: - client.HandleData(c.data, c.encLevel) + client.HandleMessage(c.data, c.encLevel) case <-done: // handshake complete } } @@ -264,7 +264,7 @@ var _ = Describe("Crypto Setup TLS", func() { Expect(len(ch.data) - 4).To(Equal(length)) // make the go routine return - client.HandleData([]byte{42 /* unknown handshake message type */, 0, 0, 1, 0}, protocol.EncryptionInitial) + client.HandleMessage([]byte{42 /* unknown handshake message type */, 0, 0, 1, 0}, protocol.EncryptionInitial) Eventually(done).Should(BeClosed()) }) diff --git a/internal/handshake/interface.go b/internal/handshake/interface.go index 7acb6a13c..264e9a658 100644 --- a/internal/handshake/interface.go +++ b/internal/handshake/interface.go @@ -44,7 +44,7 @@ type CryptoSetup interface { type CryptoSetupTLS interface { baseCryptoSetup - HandleData([]byte, protocol.EncryptionLevel) + HandleMessage([]byte, protocol.EncryptionLevel) OpenInitial(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) OpenHandshake(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) Open1RTT(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) diff --git a/mock_crypto_data_handler.go b/mock_crypto_data_handler.go index b789ca8b5..8c7400664 100644 --- a/mock_crypto_data_handler.go +++ b/mock_crypto_data_handler.go @@ -34,12 +34,12 @@ func (m *MockCryptoDataHandler) EXPECT() *MockCryptoDataHandlerMockRecorder { return m.recorder } -// HandleData mocks base method -func (m *MockCryptoDataHandler) HandleData(arg0 []byte, arg1 protocol.EncryptionLevel) { - m.ctrl.Call(m, "HandleData", arg0, arg1) +// HandleMessage mocks base method +func (m *MockCryptoDataHandler) HandleMessage(arg0 []byte, arg1 protocol.EncryptionLevel) { + m.ctrl.Call(m, "HandleMessage", arg0, arg1) } -// HandleData indicates an expected call of HandleData -func (mr *MockCryptoDataHandlerMockRecorder) HandleData(arg0, arg1 interface{}) *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleData", reflect.TypeOf((*MockCryptoDataHandler)(nil).HandleData), arg0, arg1) +// HandleMessage indicates an expected call of HandleMessage +func (mr *MockCryptoDataHandlerMockRecorder) HandleMessage(arg0, arg1 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleMessage", reflect.TypeOf((*MockCryptoDataHandler)(nil).HandleMessage), arg0, arg1) } From fe442e4d195aade2864a88c147f275c804ac7b03 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 20 Oct 2018 11:13:17 +0900 Subject: [PATCH 2/4] use a mock crypto stream in the crypto stream manager tests --- crypto_stream_manager_test.go | 61 +++++++++++++++++++++-------------- 1 file changed, 37 insertions(+), 24 deletions(-) diff --git a/crypto_stream_manager_test.go b/crypto_stream_manager_test.go index 733d741cf..21785f5e3 100644 --- a/crypto_stream_manager_test.go +++ b/crypto_stream_manager_test.go @@ -12,40 +12,53 @@ var _ = Describe("Crypto Stream Manager", func() { var ( csm *cryptoStreamManager cs *MockCryptoDataHandler + + initialStream *MockCryptoStream + handshakeStream *MockCryptoStream ) BeforeEach(func() { - initialStream := newCryptoStream() - handshakeStream := newCryptoStream() + initialStream = NewMockCryptoStream(mockCtrl) + handshakeStream = NewMockCryptoStream(mockCtrl) cs = NewMockCryptoDataHandler(mockCtrl) csm = newCryptoStreamManager(cs, initialStream, handshakeStream) }) - It("passes messages to the right stream", func() { - initialMsg := createHandshakeMessage(10) - handshakeMsg := createHandshakeMessage(20) + It("passes messages to the initial stream", func() { + cf := &wire.CryptoFrame{Data: []byte("foobar")} + initialStream.EXPECT().HandleCryptoFrame(cf) + initialStream.EXPECT().GetCryptoData().Return([]byte("foobar")) + initialStream.EXPECT().GetCryptoData() + cs.EXPECT().HandleMessage([]byte("foobar"), protocol.EncryptionInitial) + Expect(csm.HandleCryptoFrame(cf, protocol.EncryptionInitial)).To(Succeed()) + }) - // only pass in a part of the message, to make sure they get assembled in the right crypto stream - Expect(csm.HandleCryptoFrame(&wire.CryptoFrame{ - Data: initialMsg[:5], - }, protocol.EncryptionInitial)).To(Succeed()) - Expect(csm.HandleCryptoFrame(&wire.CryptoFrame{ - Data: handshakeMsg[:5], - }, protocol.EncryptionHandshake)).To(Succeed()) + It("passes messages to the handshake stream", func() { + cf := &wire.CryptoFrame{Data: []byte("foobar")} + handshakeStream.EXPECT().HandleCryptoFrame(cf) + handshakeStream.EXPECT().GetCryptoData().Return([]byte("foobar")) + handshakeStream.EXPECT().GetCryptoData() + cs.EXPECT().HandleMessage([]byte("foobar"), protocol.EncryptionHandshake) + Expect(csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake)).To(Succeed()) + }) - // now pass in the rest of the initial message - cs.EXPECT().HandleMessage(initialMsg, protocol.EncryptionInitial) - Expect(csm.HandleCryptoFrame(&wire.CryptoFrame{ - Data: initialMsg[5:], - Offset: 5, - }, protocol.EncryptionInitial)).To(Succeed()) + It("doesn't call the message handler, if there's no message", func() { + cf := &wire.CryptoFrame{Data: []byte("foobar")} + handshakeStream.EXPECT().HandleCryptoFrame(cf) + handshakeStream.EXPECT().GetCryptoData() // don't return any data to handle + // don't EXPECT any calls to HandleMessage() + Expect(csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake)).To(Succeed()) + }) - // now pass in the rest of the handshake message - cs.EXPECT().HandleMessage(handshakeMsg, protocol.EncryptionHandshake) - Expect(csm.HandleCryptoFrame(&wire.CryptoFrame{ - Data: handshakeMsg[5:], - Offset: 5, - }, protocol.EncryptionHandshake)).To(Succeed()) + It("processes all messages", func() { + cf := &wire.CryptoFrame{Data: []byte("foobar")} + handshakeStream.EXPECT().HandleCryptoFrame(cf) + handshakeStream.EXPECT().GetCryptoData().Return([]byte("foo")) + handshakeStream.EXPECT().GetCryptoData().Return([]byte("bar")) + handshakeStream.EXPECT().GetCryptoData() + cs.EXPECT().HandleMessage([]byte("foo"), protocol.EncryptionHandshake) + cs.EXPECT().HandleMessage([]byte("bar"), protocol.EncryptionHandshake) + Expect(csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake)).To(Succeed()) }) It("errors for unknown encryption levels", func() { From 387c28d707749e6316d7ee34079e5ac7b922adce Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 20 Oct 2018 11:22:05 +0900 Subject: [PATCH 3/4] when the encryption level changes, reject data on that crypto stream There are two checks that need to be performed: 1. the crypto stream must not have any more data queued for reading 2. when receiving CRYPTO frames for that crypto stream afterwards, they must not exceed the highest offset received on that stream --- crypto_stream.go | 26 ++++++++++++++- crypto_stream_manager.go | 6 ++-- crypto_stream_manager_test.go | 26 +++++++++++++++ crypto_stream_test.go | 46 ++++++++++++++++++++++++++ frame_sorter.go | 5 +++ frame_sorter_test.go | 9 +++++ internal/handshake/crypto_setup_tls.go | 44 +++++++++++++----------- internal/handshake/interface.go | 2 +- mock_crypto_data_handler.go | 6 ++-- mock_crypto_stream_test.go | 12 +++++++ 10 files changed, 156 insertions(+), 26 deletions(-) diff --git a/crypto_stream.go b/crypto_stream.go index fbd41d7e2..9007a2b03 100644 --- a/crypto_stream.go +++ b/crypto_stream.go @@ -1,6 +1,7 @@ package quic import ( + "errors" "fmt" "io" @@ -13,6 +14,7 @@ type cryptoStream interface { // for receiving data HandleCryptoFrame(*wire.CryptoFrame) error GetCryptoData() []byte + Finish() error // for sending data io.Writer HasData() bool @@ -23,6 +25,9 @@ type cryptoStreamImpl struct { queue *frameSorter msgBuf []byte + highestOffset protocol.ByteCount + finished bool + writeOffset protocol.ByteCount writeBuf []byte } @@ -34,9 +39,20 @@ func newCryptoStream() cryptoStream { } func (s *cryptoStreamImpl) HandleCryptoFrame(f *wire.CryptoFrame) error { - if maxOffset := f.Offset + protocol.ByteCount(len(f.Data)); maxOffset > protocol.MaxCryptoStreamOffset { + highestOffset := f.Offset + protocol.ByteCount(len(f.Data)) + if maxOffset := highestOffset; maxOffset > protocol.MaxCryptoStreamOffset { return fmt.Errorf("received invalid offset %d on crypto stream, maximum allowed %d", maxOffset, protocol.MaxCryptoStreamOffset) } + if s.finished { + if highestOffset > s.highestOffset { + // reject crypto data received after this stream was already finished + return errors.New("received crypto data after change of encryption level") + } + // ignore data with a smaller offset than the highest received + // could e.g. be a retransmission + return nil + } + s.highestOffset = utils.MaxByteCount(s.highestOffset, highestOffset) if err := s.queue.Push(f.Data, f.Offset, false); err != nil { return err } @@ -64,6 +80,14 @@ func (s *cryptoStreamImpl) GetCryptoData() []byte { return msg } +func (s *cryptoStreamImpl) Finish() error { + if s.queue.HasMoreData() { + return errors.New("encryption level changed, but crypto stream has more data to read") + } + s.finished = true + return nil +} + // Writes writes data that should be sent out in CRYPTO frames func (s *cryptoStreamImpl) Write(p []byte) (int, error) { s.writeBuf = append(s.writeBuf, p...) diff --git a/crypto_stream_manager.go b/crypto_stream_manager.go index 764bc2f23..0498b5162 100644 --- a/crypto_stream_manager.go +++ b/crypto_stream_manager.go @@ -8,7 +8,7 @@ import ( ) type cryptoDataHandler interface { - HandleMessage([]byte, protocol.EncryptionLevel) + HandleMessage([]byte, protocol.EncryptionLevel) bool } type cryptoStreamManager struct { @@ -48,6 +48,8 @@ func (m *cryptoStreamManager) HandleCryptoFrame(frame *wire.CryptoFrame, encLeve if data == nil { return nil } - m.cryptoHandler.HandleMessage(data, encLevel) + if encLevelFinished := m.cryptoHandler.HandleMessage(data, encLevel); encLevelFinished { + return str.Finish() + } } } diff --git a/crypto_stream_manager_test.go b/crypto_stream_manager_test.go index 21785f5e3..aada3197c 100644 --- a/crypto_stream_manager_test.go +++ b/crypto_stream_manager_test.go @@ -1,6 +1,9 @@ package quic import ( + "errors" + + "github.com/golang/mock/gomock" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/wire" @@ -61,6 +64,29 @@ var _ = Describe("Crypto Stream Manager", func() { Expect(csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake)).To(Succeed()) }) + It("finishes the crypto stream, when the crypto setup is done with this encryption level", func() { + cf := &wire.CryptoFrame{Data: []byte("foobar")} + gomock.InOrder( + handshakeStream.EXPECT().HandleCryptoFrame(cf), + handshakeStream.EXPECT().GetCryptoData().Return([]byte("foobar")), + cs.EXPECT().HandleMessage([]byte("foobar"), protocol.EncryptionHandshake).Return(true), + handshakeStream.EXPECT().Finish(), + ) + Expect(csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake)).To(Succeed()) + }) + + It("returns errors that occur when finishing a stream", func() { + testErr := errors.New("test error") + cf := &wire.CryptoFrame{Data: []byte("foobar")} + gomock.InOrder( + handshakeStream.EXPECT().HandleCryptoFrame(cf), + handshakeStream.EXPECT().GetCryptoData().Return([]byte("foobar")), + cs.EXPECT().HandleMessage([]byte("foobar"), protocol.EncryptionHandshake).Return(true), + handshakeStream.EXPECT().Finish().Return(testErr), + ) + Expect(csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake)).To(MatchError(testErr)) + }) + It("errors for unknown encryption levels", func() { err := csm.HandleCryptoFrame(&wire.CryptoFrame{}, protocol.Encryption1RTT) Expect(err).To(MatchError("received CRYPTO frame with unexpected encryption level: 1-RTT")) diff --git a/crypto_stream_test.go b/crypto_stream_test.go index 98f35619f..ed8e769e0 100644 --- a/crypto_stream_test.go +++ b/crypto_stream_test.go @@ -89,6 +89,52 @@ var _ = Describe("Crypto Stream", func() { Expect(str.GetCryptoData()).To(Equal(msg)) Expect(str.GetCryptoData()).To(BeNil()) }) + + Context("finishing", func() { + It("errors if there's still data to read after finishing", func() { + Expect(str.HandleCryptoFrame(&wire.CryptoFrame{ + Data: createHandshakeMessage(5), + Offset: 10, + })).To(Succeed()) + err := str.Finish() + Expect(err).To(MatchError("encryption level changed, but crypto stream has more data to read")) + }) + + It("works with reordered data", func() { + f1 := &wire.CryptoFrame{ + Data: []byte("foo"), + } + f2 := &wire.CryptoFrame{ + Offset: 3, + Data: []byte("bar"), + } + Expect(str.HandleCryptoFrame(f2)).To(Succeed()) + Expect(str.HandleCryptoFrame(f1)).To(Succeed()) + Expect(str.Finish()).To(Succeed()) + Expect(str.HandleCryptoFrame(f2)).To(Succeed()) + }) + + It("rejects new crypto data after finishing", func() { + Expect(str.Finish()).To(Succeed()) + err := str.HandleCryptoFrame(&wire.CryptoFrame{ + Data: createHandshakeMessage(5), + }) + Expect(err).To(MatchError("received crypto data after change of encryption level")) + }) + + It("ignores crypto data below the maximum offset received before finishing", func() { + msg := createHandshakeMessage(15) + Expect(str.HandleCryptoFrame(&wire.CryptoFrame{ + Data: msg, + })).To(Succeed()) + Expect(str.GetCryptoData()).To(Equal(msg)) + Expect(str.Finish()).To(Succeed()) + Expect(str.HandleCryptoFrame(&wire.CryptoFrame{ + Offset: protocol.ByteCount(len(msg) - 6), + Data: []byte("foobar"), + })).To(Succeed()) + }) + }) }) Context("writing data", func() { diff --git a/frame_sorter.go b/frame_sorter.go index 47062c069..e07dad47f 100644 --- a/frame_sorter.go +++ b/frame_sorter.go @@ -156,3 +156,8 @@ func (s *frameSorter) Pop() ([]byte /* data */, bool /* fin */) { s.readPos += protocol.ByteCount(len(data)) return data, s.readPos >= s.finalOffset } + +// HasMoreData says if there is any more data queued at *any* offset. +func (s *frameSorter) HasMoreData() bool { + return len(s.queue) > 0 +} diff --git a/frame_sorter_test.go b/frame_sorter_test.go index 9def2100a..433b43a2a 100644 --- a/frame_sorter_test.go +++ b/frame_sorter_test.go @@ -55,6 +55,15 @@ var _ = Describe("STREAM frame sorter", func() { Expect(s.Pop()).To(BeNil()) }) + It("says if has more data", func() { + Expect(s.HasMoreData()).To(BeFalse()) + Expect(s.Push([]byte("foo"), 0, false)).To(Succeed()) + Expect(s.HasMoreData()).To(BeTrue()) + data, _ := s.Pop() + Expect(data).To(Equal([]byte("foo"))) + Expect(s.HasMoreData()).To(BeFalse()) + }) + Context("FIN handling", func() { It("saves a FIN at offset 0", func() { Expect(s.Push(nil, 0, true)).To(Succeed()) diff --git a/internal/handshake/crypto_setup_tls.go b/internal/handshake/crypto_setup_tls.go index efe0ba3f1..e8a06ea83 100644 --- a/internal/handshake/crypto_setup_tls.go +++ b/internal/handshake/crypto_setup_tls.go @@ -271,19 +271,20 @@ func (h *cryptoSetupTLS) RunHandshake() error { // handleMessage handles a TLS handshake message. // It is called by the crypto streams when a new message is available. -func (h *cryptoSetupTLS) HandleMessage(data []byte, encLevel protocol.EncryptionLevel) { +// It returns if it is done with messages on the same encryption level. +func (h *cryptoSetupTLS) HandleMessage(data []byte, encLevel protocol.EncryptionLevel) bool /* stream finished */ { msgType := messageType(data[0]) h.logger.Debugf("Received %s message (%d bytes, encryption level: %s)", msgType, len(data), encLevel) if err := h.checkEncryptionLevel(msgType, encLevel); err != nil { h.messageErrChan <- err - return + return false } h.messageChan <- data switch h.perspective { case protocol.PerspectiveClient: - h.handleMessageForClient(msgType) + return h.handleMessageForClient(msgType) case protocol.PerspectiveServer: - h.handleMessageForServer(msgType) + return h.handleMessageForServer(msgType) default: panic("") } @@ -310,78 +311,81 @@ func (h *cryptoSetupTLS) checkEncryptionLevel(msgType messageType, encLevel prot return nil } -func (h *cryptoSetupTLS) handleMessageForServer(msgType messageType) { +func (h *cryptoSetupTLS) handleMessageForServer(msgType messageType) bool { switch msgType { case typeClientHello: select { case params := <-h.receivedTransportParams: h.handleParamsCallback(¶ms) case <-h.handshakeErrChan: - return + return false } // get the handshake write key select { case <-h.receivedWriteKey: case <-h.handshakeErrChan: - return + return false } // get the 1-RTT write key select { case <-h.receivedWriteKey: case <-h.handshakeErrChan: - return + return false } // get the handshake read key // TODO: check that the initial stream doesn't have any more data select { case <-h.receivedReadKey: case <-h.handshakeErrChan: - return + return false } h.handshakeEvent <- struct{}{} + return true case typeCertificate, typeCertificateVerify: // nothing to do + return false case typeFinished: // get the 1-RTT read key - // TODO: check that the handshake stream doesn't have any more data select { case <-h.receivedReadKey: case <-h.handshakeErrChan: - return + return false } h.handshakeEvent <- struct{}{} + return true default: panic("unexpected handshake message") } } -func (h *cryptoSetupTLS) handleMessageForClient(msgType messageType) { +func (h *cryptoSetupTLS) handleMessageForClient(msgType messageType) bool { switch msgType { case typeServerHello: // get the handshake read key - // TODO: check that the initial stream doesn't have any more data select { case <-h.receivedReadKey: case <-h.handshakeErrChan: - return + return false } h.handshakeEvent <- struct{}{} + return true case typeEncryptedExtensions: select { case params := <-h.receivedTransportParams: h.handleParamsCallback(¶ms) case <-h.handshakeErrChan: - return + return false } + return false case typeCertificateRequest, typeCertificate, typeCertificateVerify: // nothing to do + return false case typeFinished: // get the handshake write key - // TODO: check that the initial stream doesn't have any more data select { case <-h.receivedWriteKey: case <-h.handshakeErrChan: - return + return false } // While the order of these two is not defined by the TLS spec, // we have to do it on the same order as our TLS library does it. @@ -389,16 +393,16 @@ func (h *cryptoSetupTLS) handleMessageForClient(msgType messageType) { select { case <-h.receivedWriteKey: case <-h.handshakeErrChan: - return + return false } // get the 1-RTT read key select { case <-h.receivedReadKey: case <-h.handshakeErrChan: - return + return false } - // TODO: check that the handshake stream doesn't have any more data h.handshakeEvent <- struct{}{} + return true default: panic("unexpected handshake message: ") } diff --git a/internal/handshake/interface.go b/internal/handshake/interface.go index 264e9a658..88122dc1b 100644 --- a/internal/handshake/interface.go +++ b/internal/handshake/interface.go @@ -44,7 +44,7 @@ type CryptoSetup interface { type CryptoSetupTLS interface { baseCryptoSetup - HandleMessage([]byte, protocol.EncryptionLevel) + HandleMessage([]byte, protocol.EncryptionLevel) bool OpenInitial(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) OpenHandshake(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) Open1RTT(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) diff --git a/mock_crypto_data_handler.go b/mock_crypto_data_handler.go index 8c7400664..37a408008 100644 --- a/mock_crypto_data_handler.go +++ b/mock_crypto_data_handler.go @@ -35,8 +35,10 @@ func (m *MockCryptoDataHandler) EXPECT() *MockCryptoDataHandlerMockRecorder { } // HandleMessage mocks base method -func (m *MockCryptoDataHandler) HandleMessage(arg0 []byte, arg1 protocol.EncryptionLevel) { - m.ctrl.Call(m, "HandleMessage", arg0, arg1) +func (m *MockCryptoDataHandler) HandleMessage(arg0 []byte, arg1 protocol.EncryptionLevel) bool { + ret := m.ctrl.Call(m, "HandleMessage", arg0, arg1) + ret0, _ := ret[0].(bool) + return ret0 } // HandleMessage indicates an expected call of HandleMessage diff --git a/mock_crypto_stream_test.go b/mock_crypto_stream_test.go index 66de8d0f9..0465e7e25 100644 --- a/mock_crypto_stream_test.go +++ b/mock_crypto_stream_test.go @@ -35,6 +35,18 @@ func (m *MockCryptoStream) EXPECT() *MockCryptoStreamMockRecorder { return m.recorder } +// Finish mocks base method +func (m *MockCryptoStream) Finish() error { + ret := m.ctrl.Call(m, "Finish") + ret0, _ := ret[0].(error) + return ret0 +} + +// Finish indicates an expected call of Finish +func (mr *MockCryptoStreamMockRecorder) Finish() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Finish", reflect.TypeOf((*MockCryptoStream)(nil).Finish)) +} + // GetCryptoData mocks base method func (m *MockCryptoStream) GetCryptoData() []byte { ret := m.ctrl.Call(m, "GetCryptoData") From b63c81f0bf5485a84714e38fa81a00573199c8db Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 20 Oct 2018 11:40:33 +0900 Subject: [PATCH 4/4] try decrypting undecryptable packets when the encryption level changes There's no need to do this asynchronously any more when using TLS. --- crypto_stream_manager.go | 10 ++++----- crypto_stream_manager_test.go | 25 +++++++++++++++------ internal/handshake/crypto_setup_tls.go | 12 ---------- internal/handshake/crypto_setup_tls_test.go | 7 ------ session.go | 16 +++++++------ 5 files changed, 32 insertions(+), 38 deletions(-) diff --git a/crypto_stream_manager.go b/crypto_stream_manager.go index 0498b5162..330b26daf 100644 --- a/crypto_stream_manager.go +++ b/crypto_stream_manager.go @@ -30,7 +30,7 @@ func newCryptoStreamManager( } } -func (m *cryptoStreamManager) HandleCryptoFrame(frame *wire.CryptoFrame, encLevel protocol.EncryptionLevel) error { +func (m *cryptoStreamManager) HandleCryptoFrame(frame *wire.CryptoFrame, encLevel protocol.EncryptionLevel) (bool /* encryption level changed */, error) { var str cryptoStream switch encLevel { case protocol.EncryptionInitial: @@ -38,18 +38,18 @@ func (m *cryptoStreamManager) HandleCryptoFrame(frame *wire.CryptoFrame, encLeve case protocol.EncryptionHandshake: str = m.handshakeStream default: - return fmt.Errorf("received CRYPTO frame with unexpected encryption level: %s", encLevel) + return false, fmt.Errorf("received CRYPTO frame with unexpected encryption level: %s", encLevel) } if err := str.HandleCryptoFrame(frame); err != nil { - return err + return false, err } for { data := str.GetCryptoData() if data == nil { - return nil + return false, nil } if encLevelFinished := m.cryptoHandler.HandleMessage(data, encLevel); encLevelFinished { - return str.Finish() + return true, str.Finish() } } } diff --git a/crypto_stream_manager_test.go b/crypto_stream_manager_test.go index aada3197c..b57a02993 100644 --- a/crypto_stream_manager_test.go +++ b/crypto_stream_manager_test.go @@ -33,7 +33,9 @@ var _ = Describe("Crypto Stream Manager", func() { initialStream.EXPECT().GetCryptoData().Return([]byte("foobar")) initialStream.EXPECT().GetCryptoData() cs.EXPECT().HandleMessage([]byte("foobar"), protocol.EncryptionInitial) - Expect(csm.HandleCryptoFrame(cf, protocol.EncryptionInitial)).To(Succeed()) + encLevelChanged, err := csm.HandleCryptoFrame(cf, protocol.EncryptionInitial) + Expect(err).ToNot(HaveOccurred()) + Expect(encLevelChanged).To(BeFalse()) }) It("passes messages to the handshake stream", func() { @@ -42,7 +44,9 @@ var _ = Describe("Crypto Stream Manager", func() { handshakeStream.EXPECT().GetCryptoData().Return([]byte("foobar")) handshakeStream.EXPECT().GetCryptoData() cs.EXPECT().HandleMessage([]byte("foobar"), protocol.EncryptionHandshake) - Expect(csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake)).To(Succeed()) + encLevelChanged, err := csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake) + Expect(err).ToNot(HaveOccurred()) + Expect(encLevelChanged).To(BeFalse()) }) It("doesn't call the message handler, if there's no message", func() { @@ -50,7 +54,9 @@ var _ = Describe("Crypto Stream Manager", func() { handshakeStream.EXPECT().HandleCryptoFrame(cf) handshakeStream.EXPECT().GetCryptoData() // don't return any data to handle // don't EXPECT any calls to HandleMessage() - Expect(csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake)).To(Succeed()) + encLevelChanged, err := csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake) + Expect(err).ToNot(HaveOccurred()) + Expect(encLevelChanged).To(BeFalse()) }) It("processes all messages", func() { @@ -61,7 +67,9 @@ var _ = Describe("Crypto Stream Manager", func() { handshakeStream.EXPECT().GetCryptoData() cs.EXPECT().HandleMessage([]byte("foo"), protocol.EncryptionHandshake) cs.EXPECT().HandleMessage([]byte("bar"), protocol.EncryptionHandshake) - Expect(csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake)).To(Succeed()) + encLevelChanged, err := csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake) + Expect(err).ToNot(HaveOccurred()) + Expect(encLevelChanged).To(BeFalse()) }) It("finishes the crypto stream, when the crypto setup is done with this encryption level", func() { @@ -72,7 +80,9 @@ var _ = Describe("Crypto Stream Manager", func() { cs.EXPECT().HandleMessage([]byte("foobar"), protocol.EncryptionHandshake).Return(true), handshakeStream.EXPECT().Finish(), ) - Expect(csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake)).To(Succeed()) + encLevelChanged, err := csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake) + Expect(err).ToNot(HaveOccurred()) + Expect(encLevelChanged).To(BeTrue()) }) It("returns errors that occur when finishing a stream", func() { @@ -84,11 +94,12 @@ var _ = Describe("Crypto Stream Manager", func() { cs.EXPECT().HandleMessage([]byte("foobar"), protocol.EncryptionHandshake).Return(true), handshakeStream.EXPECT().Finish().Return(testErr), ) - Expect(csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake)).To(MatchError(testErr)) + _, err := csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake) + Expect(err).To(MatchError(err)) }) It("errors for unknown encryption levels", func() { - err := csm.HandleCryptoFrame(&wire.CryptoFrame{}, protocol.Encryption1RTT) + _, err := csm.HandleCryptoFrame(&wire.CryptoFrame{}, protocol.Encryption1RTT) Expect(err).To(MatchError("received CRYPTO frame with unexpected encryption level: 1-RTT")) }) }) diff --git a/internal/handshake/crypto_setup_tls.go b/internal/handshake/crypto_setup_tls.go index e8a06ea83..1ced3b47e 100644 --- a/internal/handshake/crypto_setup_tls.go +++ b/internal/handshake/crypto_setup_tls.go @@ -63,8 +63,6 @@ type cryptoSetupTLS struct { handshakeErrChan chan struct{} // HandleData() sends errors on the messageErrChan messageErrChan chan error - // handshakeEvent signals a change of encryption level to the session - handshakeEvent chan<- struct{} // handshakeComplete is closed when the handshake completes handshakeComplete chan<- struct{} // transport parameters are sent on the receivedTransportParams, as soon as they are received @@ -108,7 +106,6 @@ func NewCryptoSetupTLSClient( connID protocol.ConnectionID, params *TransportParameters, handleParams func(*TransportParameters), - handshakeEvent chan<- struct{}, handshakeComplete chan<- struct{}, tlsConf *tls.Config, initialVersion protocol.VersionNumber, @@ -123,7 +120,6 @@ func NewCryptoSetupTLSClient( connID, params, handleParams, - handshakeEvent, handshakeComplete, tlsConf, versionInfo{ @@ -143,7 +139,6 @@ func NewCryptoSetupTLSServer( connID protocol.ConnectionID, params *TransportParameters, handleParams func(*TransportParameters), - handshakeEvent chan<- struct{}, handshakeComplete chan<- struct{}, tlsConf *tls.Config, supportedVersions []protocol.VersionNumber, @@ -157,7 +152,6 @@ func NewCryptoSetupTLSServer( connID, params, handleParams, - handshakeEvent, handshakeComplete, tlsConf, versionInfo{ @@ -176,7 +170,6 @@ func newCryptoSetupTLS( connID protocol.ConnectionID, params *TransportParameters, handleParams func(*TransportParameters), - handshakeEvent chan<- struct{}, handshakeComplete chan<- struct{}, tlsConf *tls.Config, versionInfo versionInfo, @@ -194,7 +187,6 @@ func newCryptoSetupTLS( readEncLevel: protocol.EncryptionInitial, writeEncLevel: protocol.EncryptionInitial, handleParamsCallback: handleParams, - handshakeEvent: handshakeEvent, handshakeComplete: handshakeComplete, logger: logger, perspective: perspective, @@ -339,7 +331,6 @@ func (h *cryptoSetupTLS) handleMessageForServer(msgType messageType) bool { case <-h.handshakeErrChan: return false } - h.handshakeEvent <- struct{}{} return true case typeCertificate, typeCertificateVerify: // nothing to do @@ -351,7 +342,6 @@ func (h *cryptoSetupTLS) handleMessageForServer(msgType messageType) bool { case <-h.handshakeErrChan: return false } - h.handshakeEvent <- struct{}{} return true default: panic("unexpected handshake message") @@ -367,7 +357,6 @@ func (h *cryptoSetupTLS) handleMessageForClient(msgType messageType) bool { case <-h.handshakeErrChan: return false } - h.handshakeEvent <- struct{}{} return true case typeEncryptedExtensions: select { @@ -401,7 +390,6 @@ func (h *cryptoSetupTLS) handleMessageForClient(msgType messageType) bool { case <-h.handshakeErrChan: return false } - h.handshakeEvent <- struct{}{} return true default: panic("unexpected handshake message: ") diff --git a/internal/handshake/crypto_setup_tls_test.go b/internal/handshake/crypto_setup_tls_test.go index 0f5206776..c0e2cb16a 100644 --- a/internal/handshake/crypto_setup_tls_test.go +++ b/internal/handshake/crypto_setup_tls_test.go @@ -63,7 +63,6 @@ var _ = Describe("Crypto Setup TLS", func() { protocol.ConnectionID{}, &TransportParameters{}, func(p *TransportParameters) {}, - make(chan struct{}, 100), make(chan struct{}), testdata.GetTLSConfig(), []protocol.VersionNumber{protocol.VersionTLS}, @@ -95,7 +94,6 @@ var _ = Describe("Crypto Setup TLS", func() { protocol.ConnectionID{}, &TransportParameters{}, func(p *TransportParameters) {}, - make(chan struct{}, 100), make(chan struct{}), testdata.GetTLSConfig(), []protocol.VersionNumber{protocol.VersionTLS}, @@ -178,7 +176,6 @@ var _ = Describe("Crypto Setup TLS", func() { protocol.ConnectionID{}, &TransportParameters{}, func(p *TransportParameters) {}, - make(chan struct{}, 100), make(chan struct{}), clientConf, protocol.VersionTLS, @@ -196,7 +193,6 @@ var _ = Describe("Crypto Setup TLS", func() { protocol.ConnectionID{}, &TransportParameters{StatelessResetToken: bytes.Repeat([]byte{42}, 16)}, func(p *TransportParameters) {}, - make(chan struct{}, 100), make(chan struct{}), serverConf, []protocol.VersionNumber{protocol.VersionTLS}, @@ -237,7 +233,6 @@ var _ = Describe("Crypto Setup TLS", func() { protocol.ConnectionID{}, &TransportParameters{}, func(p *TransportParameters) {}, - make(chan struct{}, 100), make(chan struct{}), &tls.Config{InsecureSkipVerify: true}, protocol.VersionTLS, @@ -278,7 +273,6 @@ var _ = Describe("Crypto Setup TLS", func() { protocol.ConnectionID{}, cTransportParameters, func(p *TransportParameters) { sTransportParametersRcvd = p }, - make(chan struct{}, 100), make(chan struct{}), &tls.Config{ServerName: "quic.clemente.io"}, protocol.VersionTLS, @@ -300,7 +294,6 @@ var _ = Describe("Crypto Setup TLS", func() { protocol.ConnectionID{}, sTransportParameters, func(p *TransportParameters) { cTransportParametersRcvd = p }, - make(chan struct{}, 100), make(chan struct{}), testdata.GetTLSConfig(), []protocol.VersionNumber{protocol.VersionTLS}, diff --git a/session.go b/session.go index c8ca516ab..ab9ce5844 100644 --- a/session.go +++ b/session.go @@ -120,6 +120,7 @@ type session struct { paramsChan <-chan handshake.TransportParameters // the handshakeEvent channel is passed to the CryptoSetup. // It receives when it makes sense to try decrypting undecryptable packets. + // Only used for gQUIC. handshakeEvent <-chan struct{} handshakeCompleteChan <-chan struct{} // is closed when the handshake completes handshakeComplete bool @@ -325,7 +326,6 @@ func newTLSServerSession( logger utils.Logger, v protocol.VersionNumber, ) (quicSession, error) { - handshakeEvent := make(chan struct{}, 2) // TODO: explain cap handshakeCompleteChan := make(chan struct{}) s := &session{ conn: conn, @@ -334,7 +334,6 @@ func newTLSServerSession( srcConnID: srcConnID, destConnID: destConnID, perspective: protocol.PerspectiveServer, - handshakeEvent: handshakeEvent, handshakeCompleteChan: handshakeCompleteChan, logger: logger, version: v, @@ -350,7 +349,6 @@ func newTLSServerSession( origConnID, params, s.processTransportParameters, - handshakeEvent, handshakeCompleteChan, tlsConf, conf.Versions, @@ -403,7 +401,6 @@ var newTLSClientSession = func( logger utils.Logger, v protocol.VersionNumber, ) (quicSession, error) { - handshakeEvent := make(chan struct{}, 2) // TODO: explain cap handshakeCompleteChan := make(chan struct{}) s := &session{ conn: conn, @@ -412,7 +409,6 @@ var newTLSClientSession = func( srcConnID: srcConnID, destConnID: destConnID, perspective: protocol.PerspectiveClient, - handshakeEvent: handshakeEvent, handshakeCompleteChan: handshakeCompleteChan, logger: logger, version: v, @@ -426,7 +422,6 @@ var newTLSClientSession = func( s.destConnID, params, s.processTransportParameters, - handshakeEvent, handshakeCompleteChan, tlsConf, initialVersion, @@ -804,7 +799,14 @@ func (s *session) handlePacket(p *receivedPacket) { } func (s *session) handleCryptoFrame(frame *wire.CryptoFrame, encLevel protocol.EncryptionLevel) error { - return s.cryptoStreamManager.HandleCryptoFrame(frame, encLevel) + encLevelChanged, err := s.cryptoStreamManager.HandleCryptoFrame(frame, encLevel) + if err != nil { + return err + } + if encLevelChanged { + s.tryDecryptingQueuedPackets() + } + return nil } func (s *session) handleStreamFrame(frame *wire.StreamFrame, encLevel protocol.EncryptionLevel) error {