diff --git a/crypto_stream.go b/crypto_stream.go index 879333578..9007a2b03 100644 --- a/crypto_stream.go +++ b/crypto_stream.go @@ -1,6 +1,7 @@ package quic import ( + "errors" "fmt" "io" @@ -13,6 +14,7 @@ type cryptoStream interface { // for receiving data HandleCryptoFrame(*wire.CryptoFrame) error GetCryptoData() []byte + Finish() error // for sending data io.Writer HasData() bool @@ -20,7 +22,11 @@ type cryptoStream interface { } type cryptoStreamImpl struct { - queue *frameSorter + queue *frameSorter + msgBuf []byte + + highestOffset protocol.ByteCount + finished bool writeOffset protocol.ByteCount writeBuf []byte @@ -33,16 +39,53 @@ func newCryptoStream() cryptoStream { } func (s *cryptoStreamImpl) HandleCryptoFrame(f *wire.CryptoFrame) error { - if maxOffset := f.Offset + protocol.ByteCount(len(f.Data)); maxOffset > protocol.MaxCryptoStreamOffset { + highestOffset := f.Offset + protocol.ByteCount(len(f.Data)) + if maxOffset := highestOffset; maxOffset > protocol.MaxCryptoStreamOffset { return fmt.Errorf("received invalid offset %d on crypto stream, maximum allowed %d", maxOffset, protocol.MaxCryptoStreamOffset) } - return s.queue.Push(f.Data, f.Offset, false) + if s.finished { + if highestOffset > s.highestOffset { + // reject crypto data received after this stream was already finished + return errors.New("received crypto data after change of encryption level") + } + // ignore data with a smaller offset than the highest received + // could e.g. be a retransmission + return nil + } + s.highestOffset = utils.MaxByteCount(s.highestOffset, highestOffset) + if err := s.queue.Push(f.Data, f.Offset, false); err != nil { + return err + } + for { + data, _ := s.queue.Pop() + if data == nil { + return nil + } + s.msgBuf = append(s.msgBuf, data...) + } } // GetCryptoData retrieves data that was received in CRYPTO frames func (s *cryptoStreamImpl) GetCryptoData() []byte { - data, _ := s.queue.Pop() - return data + if len(s.msgBuf) < 4 { + return nil + } + msgLen := 4 + int(s.msgBuf[1])<<16 + int(s.msgBuf[2])<<8 + int(s.msgBuf[3]) + if len(s.msgBuf) < msgLen { + return nil + } + msg := make([]byte, msgLen) + copy(msg, s.msgBuf[:msgLen]) + s.msgBuf = s.msgBuf[msgLen:] + return msg +} + +func (s *cryptoStreamImpl) Finish() error { + if s.queue.HasMoreData() { + return errors.New("encryption level changed, but crypto stream has more data to read") + } + s.finished = true + return nil } // Writes writes data that should be sent out in CRYPTO frames diff --git a/crypto_stream_manager.go b/crypto_stream_manager.go index 0c1d46942..330b26daf 100644 --- a/crypto_stream_manager.go +++ b/crypto_stream_manager.go @@ -8,7 +8,7 @@ import ( ) type cryptoDataHandler interface { - HandleData([]byte, protocol.EncryptionLevel) + HandleMessage([]byte, protocol.EncryptionLevel) bool } type cryptoStreamManager struct { @@ -30,7 +30,7 @@ func newCryptoStreamManager( } } -func (m *cryptoStreamManager) HandleCryptoFrame(frame *wire.CryptoFrame, encLevel protocol.EncryptionLevel) error { +func (m *cryptoStreamManager) HandleCryptoFrame(frame *wire.CryptoFrame, encLevel protocol.EncryptionLevel) (bool /* encryption level changed */, error) { var str cryptoStream switch encLevel { case protocol.EncryptionInitial: @@ -38,16 +38,18 @@ func (m *cryptoStreamManager) HandleCryptoFrame(frame *wire.CryptoFrame, encLeve case protocol.EncryptionHandshake: str = m.handshakeStream default: - return fmt.Errorf("received CRYPTO frame with unexpected encryption level: %s", encLevel) + return false, fmt.Errorf("received CRYPTO frame with unexpected encryption level: %s", encLevel) } if err := str.HandleCryptoFrame(frame); err != nil { - return err + return false, err } for { data := str.GetCryptoData() if data == nil { - return nil + return false, nil + } + if encLevelFinished := m.cryptoHandler.HandleMessage(data, encLevel); encLevelFinished { + return true, str.Finish() } - m.cryptoHandler.HandleData(data, encLevel) } } diff --git a/crypto_stream_manager_test.go b/crypto_stream_manager_test.go index a7b777c5b..b57a02993 100644 --- a/crypto_stream_manager_test.go +++ b/crypto_stream_manager_test.go @@ -1,6 +1,8 @@ package quic import ( + "errors" + "github.com/golang/mock/gomock" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/wire" @@ -13,43 +15,91 @@ var _ = Describe("Crypto Stream Manager", func() { var ( csm *cryptoStreamManager cs *MockCryptoDataHandler + + initialStream *MockCryptoStream + handshakeStream *MockCryptoStream ) BeforeEach(func() { - initialStream := newCryptoStream() - handshakeStream := newCryptoStream() + initialStream = NewMockCryptoStream(mockCtrl) + handshakeStream = NewMockCryptoStream(mockCtrl) cs = NewMockCryptoDataHandler(mockCtrl) csm = newCryptoStreamManager(cs, initialStream, handshakeStream) }) - It("handles in in-order crypto frame", func() { - f := &wire.CryptoFrame{Data: []byte("foobar")} - cs.EXPECT().HandleData([]byte("foobar"), protocol.EncryptionInitial) - Expect(csm.HandleCryptoFrame(f, protocol.EncryptionInitial)).To(Succeed()) + It("passes messages to the initial stream", func() { + cf := &wire.CryptoFrame{Data: []byte("foobar")} + initialStream.EXPECT().HandleCryptoFrame(cf) + initialStream.EXPECT().GetCryptoData().Return([]byte("foobar")) + initialStream.EXPECT().GetCryptoData() + cs.EXPECT().HandleMessage([]byte("foobar"), protocol.EncryptionInitial) + encLevelChanged, err := csm.HandleCryptoFrame(cf, protocol.EncryptionInitial) + Expect(err).ToNot(HaveOccurred()) + Expect(encLevelChanged).To(BeFalse()) + }) + + It("passes messages to the handshake stream", func() { + cf := &wire.CryptoFrame{Data: []byte("foobar")} + handshakeStream.EXPECT().HandleCryptoFrame(cf) + handshakeStream.EXPECT().GetCryptoData().Return([]byte("foobar")) + handshakeStream.EXPECT().GetCryptoData() + cs.EXPECT().HandleMessage([]byte("foobar"), protocol.EncryptionHandshake) + encLevelChanged, err := csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake) + Expect(err).ToNot(HaveOccurred()) + Expect(encLevelChanged).To(BeFalse()) + }) + + It("doesn't call the message handler, if there's no message", func() { + cf := &wire.CryptoFrame{Data: []byte("foobar")} + handshakeStream.EXPECT().HandleCryptoFrame(cf) + handshakeStream.EXPECT().GetCryptoData() // don't return any data to handle + // don't EXPECT any calls to HandleMessage() + encLevelChanged, err := csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake) + Expect(err).ToNot(HaveOccurred()) + Expect(encLevelChanged).To(BeFalse()) + }) + + It("processes all messages", func() { + cf := &wire.CryptoFrame{Data: []byte("foobar")} + handshakeStream.EXPECT().HandleCryptoFrame(cf) + handshakeStream.EXPECT().GetCryptoData().Return([]byte("foo")) + handshakeStream.EXPECT().GetCryptoData().Return([]byte("bar")) + handshakeStream.EXPECT().GetCryptoData() + cs.EXPECT().HandleMessage([]byte("foo"), protocol.EncryptionHandshake) + cs.EXPECT().HandleMessage([]byte("bar"), protocol.EncryptionHandshake) + encLevelChanged, err := csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake) + Expect(err).ToNot(HaveOccurred()) + Expect(encLevelChanged).To(BeFalse()) + }) + + It("finishes the crypto stream, when the crypto setup is done with this encryption level", func() { + cf := &wire.CryptoFrame{Data: []byte("foobar")} + gomock.InOrder( + handshakeStream.EXPECT().HandleCryptoFrame(cf), + handshakeStream.EXPECT().GetCryptoData().Return([]byte("foobar")), + cs.EXPECT().HandleMessage([]byte("foobar"), protocol.EncryptionHandshake).Return(true), + handshakeStream.EXPECT().Finish(), + ) + encLevelChanged, err := csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake) + Expect(err).ToNot(HaveOccurred()) + Expect(encLevelChanged).To(BeTrue()) + }) + + It("returns errors that occur when finishing a stream", func() { + testErr := errors.New("test error") + cf := &wire.CryptoFrame{Data: []byte("foobar")} + gomock.InOrder( + handshakeStream.EXPECT().HandleCryptoFrame(cf), + handshakeStream.EXPECT().GetCryptoData().Return([]byte("foobar")), + cs.EXPECT().HandleMessage([]byte("foobar"), protocol.EncryptionHandshake).Return(true), + handshakeStream.EXPECT().Finish().Return(testErr), + ) + _, err := csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake) + Expect(err).To(MatchError(err)) }) It("errors for unknown encryption levels", func() { - err := csm.HandleCryptoFrame(&wire.CryptoFrame{}, protocol.Encryption1RTT) + _, err := csm.HandleCryptoFrame(&wire.CryptoFrame{}, protocol.Encryption1RTT) Expect(err).To(MatchError("received CRYPTO frame with unexpected encryption level: 1-RTT")) }) - - It("handles out-of-order crypto frames", func() { - f1 := &wire.CryptoFrame{Data: []byte("foo")} - f2 := &wire.CryptoFrame{ - Offset: 3, - Data: []byte("bar"), - } - gomock.InOrder( - cs.EXPECT().HandleData([]byte("foo"), protocol.EncryptionInitial), - cs.EXPECT().HandleData([]byte("bar"), protocol.EncryptionInitial), - ) - Expect(csm.HandleCryptoFrame(f1, protocol.EncryptionInitial)).To(Succeed()) - Expect(csm.HandleCryptoFrame(f2, protocol.EncryptionInitial)).To(Succeed()) - }) - - It("handles handshake data", func() { - f := &wire.CryptoFrame{Data: []byte("foobar")} - cs.EXPECT().HandleData([]byte("foobar"), protocol.EncryptionHandshake) - Expect(csm.HandleCryptoFrame(f, protocol.EncryptionHandshake)).To(Succeed()) - }) }) diff --git a/crypto_stream_test.go b/crypto_stream_test.go index 514930dbc..ed8e769e0 100644 --- a/crypto_stream_test.go +++ b/crypto_stream_test.go @@ -1,6 +1,7 @@ package quic import ( + "crypto/rand" "fmt" "github.com/lucas-clemente/quic-go/internal/protocol" @@ -10,6 +11,16 @@ import ( . "github.com/onsi/gomega" ) +func createHandshakeMessage(len int) []byte { + msg := make([]byte, 4+len) + rand.Read(msg[:1]) // random message type + msg[1] = uint8(len >> 16) + msg[2] = uint8(len >> 8) + msg[3] = uint8(len) + rand.Read(msg[4:]) + return msg +} + var _ = Describe("Crypto Stream", func() { var ( str cryptoStream @@ -21,11 +32,21 @@ var _ = Describe("Crypto Stream", func() { Context("handling incoming data", func() { It("handles in-order CRYPTO frames", func() { - err := str.HandleCryptoFrame(&wire.CryptoFrame{ - Data: []byte("foobar"), - }) + msg := createHandshakeMessage(6) + err := str.HandleCryptoFrame(&wire.CryptoFrame{Data: msg}) Expect(err).ToNot(HaveOccurred()) - Expect(str.GetCryptoData()).To(Equal([]byte("foobar"))) + Expect(str.GetCryptoData()).To(Equal(msg)) + Expect(str.GetCryptoData()).To(BeNil()) + }) + + It("handles multiple messages in one CRYPTO frame", func() { + msg1 := createHandshakeMessage(6) + msg2 := createHandshakeMessage(10) + msg := append(append([]byte{}, msg1...), msg2...) + err := str.HandleCryptoFrame(&wire.CryptoFrame{Data: msg}) + Expect(err).ToNot(HaveOccurred()) + Expect(str.GetCryptoData()).To(Equal(msg1)) + Expect(str.GetCryptoData()).To(Equal(msg2)) Expect(str.GetCryptoData()).To(BeNil()) }) @@ -37,21 +58,83 @@ var _ = Describe("Crypto Stream", func() { Expect(err).To(MatchError(fmt.Sprintf("received invalid offset %d on crypto stream, maximum allowed %d", protocol.MaxCryptoStreamOffset+1, protocol.MaxCryptoStreamOffset))) }) - It("handles out-of-order CRYPTO frames", func() { + It("handles messages split over multiple CRYPTO frames", func() { + msg := createHandshakeMessage(6) err := str.HandleCryptoFrame(&wire.CryptoFrame{ - Offset: 3, - Data: []byte("bar"), + Data: msg[:4], }) Expect(err).ToNot(HaveOccurred()) Expect(str.GetCryptoData()).To(BeNil()) err = str.HandleCryptoFrame(&wire.CryptoFrame{ - Data: []byte("foo"), + Offset: 4, + Data: msg[4:], }) Expect(err).ToNot(HaveOccurred()) - Expect(str.GetCryptoData()).To(Equal([]byte("foo"))) - Expect(str.GetCryptoData()).To(Equal([]byte("bar"))) + Expect(str.GetCryptoData()).To(Equal(msg)) Expect(str.GetCryptoData()).To(BeNil()) }) + + It("handles out-of-order CRYPTO frames", func() { + msg := createHandshakeMessage(6) + err := str.HandleCryptoFrame(&wire.CryptoFrame{ + Offset: 4, + Data: msg[4:], + }) + Expect(err).ToNot(HaveOccurred()) + Expect(str.GetCryptoData()).To(BeNil()) + err = str.HandleCryptoFrame(&wire.CryptoFrame{ + Data: msg[:4], + }) + Expect(err).ToNot(HaveOccurred()) + Expect(str.GetCryptoData()).To(Equal(msg)) + Expect(str.GetCryptoData()).To(BeNil()) + }) + + Context("finishing", func() { + It("errors if there's still data to read after finishing", func() { + Expect(str.HandleCryptoFrame(&wire.CryptoFrame{ + Data: createHandshakeMessage(5), + Offset: 10, + })).To(Succeed()) + err := str.Finish() + Expect(err).To(MatchError("encryption level changed, but crypto stream has more data to read")) + }) + + It("works with reordered data", func() { + f1 := &wire.CryptoFrame{ + Data: []byte("foo"), + } + f2 := &wire.CryptoFrame{ + Offset: 3, + Data: []byte("bar"), + } + Expect(str.HandleCryptoFrame(f2)).To(Succeed()) + Expect(str.HandleCryptoFrame(f1)).To(Succeed()) + Expect(str.Finish()).To(Succeed()) + Expect(str.HandleCryptoFrame(f2)).To(Succeed()) + }) + + It("rejects new crypto data after finishing", func() { + Expect(str.Finish()).To(Succeed()) + err := str.HandleCryptoFrame(&wire.CryptoFrame{ + Data: createHandshakeMessage(5), + }) + Expect(err).To(MatchError("received crypto data after change of encryption level")) + }) + + It("ignores crypto data below the maximum offset received before finishing", func() { + msg := createHandshakeMessage(15) + Expect(str.HandleCryptoFrame(&wire.CryptoFrame{ + Data: msg, + })).To(Succeed()) + Expect(str.GetCryptoData()).To(Equal(msg)) + Expect(str.Finish()).To(Succeed()) + Expect(str.HandleCryptoFrame(&wire.CryptoFrame{ + Offset: protocol.ByteCount(len(msg) - 6), + Data: []byte("foobar"), + })).To(Succeed()) + }) + }) }) Context("writing data", func() { diff --git a/frame_sorter.go b/frame_sorter.go index 47062c069..e07dad47f 100644 --- a/frame_sorter.go +++ b/frame_sorter.go @@ -156,3 +156,8 @@ func (s *frameSorter) Pop() ([]byte /* data */, bool /* fin */) { s.readPos += protocol.ByteCount(len(data)) return data, s.readPos >= s.finalOffset } + +// HasMoreData says if there is any more data queued at *any* offset. +func (s *frameSorter) HasMoreData() bool { + return len(s.queue) > 0 +} diff --git a/frame_sorter_test.go b/frame_sorter_test.go index 9def2100a..433b43a2a 100644 --- a/frame_sorter_test.go +++ b/frame_sorter_test.go @@ -55,6 +55,15 @@ var _ = Describe("STREAM frame sorter", func() { Expect(s.Pop()).To(BeNil()) }) + It("says if has more data", func() { + Expect(s.HasMoreData()).To(BeFalse()) + Expect(s.Push([]byte("foo"), 0, false)).To(Succeed()) + Expect(s.HasMoreData()).To(BeTrue()) + data, _ := s.Pop() + Expect(data).To(Equal([]byte("foo"))) + Expect(s.HasMoreData()).To(BeFalse()) + }) + Context("FIN handling", func() { It("saves a FIN at offset 0", func() { Expect(s.Push(nil, 0, true)).To(Succeed()) diff --git a/internal/handshake/crypto_setup_tls.go b/internal/handshake/crypto_setup_tls.go index 3b0970d57..1ced3b47e 100644 --- a/internal/handshake/crypto_setup_tls.go +++ b/internal/handshake/crypto_setup_tls.go @@ -1,7 +1,6 @@ package handshake import ( - "bytes" "crypto/tls" "errors" "fmt" @@ -64,8 +63,6 @@ type cryptoSetupTLS struct { handshakeErrChan chan struct{} // HandleData() sends errors on the messageErrChan messageErrChan chan error - // handshakeEvent signals a change of encryption level to the session - handshakeEvent chan<- struct{} // handshakeComplete is closed when the handshake completes handshakeComplete chan<- struct{} // transport parameters are sent on the receivedTransportParams, as soon as they are received @@ -74,14 +71,12 @@ type cryptoSetupTLS struct { clientHelloWritten bool clientHelloWrittenChan chan struct{} - initialReadBuf bytes.Buffer - initialStream io.Writer - initialAEAD crypto.AEAD + initialStream io.Writer + initialAEAD crypto.AEAD - handshakeReadBuf bytes.Buffer - handshakeStream io.Writer - handshakeOpener Opener - handshakeSealer Sealer + handshakeStream io.Writer + handshakeOpener Opener + handshakeSealer Sealer opener Opener sealer Sealer @@ -111,7 +106,6 @@ func NewCryptoSetupTLSClient( connID protocol.ConnectionID, params *TransportParameters, handleParams func(*TransportParameters), - handshakeEvent chan<- struct{}, handshakeComplete chan<- struct{}, tlsConf *tls.Config, initialVersion protocol.VersionNumber, @@ -126,7 +120,6 @@ func NewCryptoSetupTLSClient( connID, params, handleParams, - handshakeEvent, handshakeComplete, tlsConf, versionInfo{ @@ -146,7 +139,6 @@ func NewCryptoSetupTLSServer( connID protocol.ConnectionID, params *TransportParameters, handleParams func(*TransportParameters), - handshakeEvent chan<- struct{}, handshakeComplete chan<- struct{}, tlsConf *tls.Config, supportedVersions []protocol.VersionNumber, @@ -160,7 +152,6 @@ func NewCryptoSetupTLSServer( connID, params, handleParams, - handshakeEvent, handshakeComplete, tlsConf, versionInfo{ @@ -179,7 +170,6 @@ func newCryptoSetupTLS( connID protocol.ConnectionID, params *TransportParameters, handleParams func(*TransportParameters), - handshakeEvent chan<- struct{}, handshakeComplete chan<- struct{}, tlsConf *tls.Config, versionInfo versionInfo, @@ -197,7 +187,6 @@ func newCryptoSetupTLS( readEncLevel: protocol.EncryptionInitial, writeEncLevel: protocol.EncryptionInitial, handleParamsCallback: handleParams, - handshakeEvent: handshakeEvent, handshakeComplete: handshakeComplete, logger: logger, perspective: perspective, @@ -272,51 +261,25 @@ func (h *cryptoSetupTLS) RunHandshake() error { } } -func (h *cryptoSetupTLS) HandleData(data []byte, encLevel protocol.EncryptionLevel) { - var buf *bytes.Buffer - switch encLevel { - case protocol.EncryptionInitial: - buf = &h.initialReadBuf - case protocol.EncryptionHandshake: - buf = &h.handshakeReadBuf - default: - h.messageErrChan <- fmt.Errorf("received handshake data with unexpected encryption level: %s", encLevel) - return - } - buf.Write(data) - for buf.Len() >= 4 { - b := buf.Bytes() - // read the TLS message length - length := int(b[1])<<16 | int(b[2])<<8 | int(b[3]) - if buf.Len() < 4+length { // message not yet complete - return - } - msg := make([]byte, length+4) - buf.Read(msg) - if err := h.handleMessage(msg, encLevel); err != nil { - h.messageErrChan <- err - } - } -} - // handleMessage handles a TLS handshake message. // It is called by the crypto streams when a new message is available. -func (h *cryptoSetupTLS) handleMessage(data []byte, encLevel protocol.EncryptionLevel) error { +// It returns if it is done with messages on the same encryption level. +func (h *cryptoSetupTLS) HandleMessage(data []byte, encLevel protocol.EncryptionLevel) bool /* stream finished */ { msgType := messageType(data[0]) h.logger.Debugf("Received %s message (%d bytes, encryption level: %s)", msgType, len(data), encLevel) if err := h.checkEncryptionLevel(msgType, encLevel); err != nil { - return err + h.messageErrChan <- err + return false } h.messageChan <- data switch h.perspective { case protocol.PerspectiveClient: - h.handleMessageForClient(msgType) + return h.handleMessageForClient(msgType) case protocol.PerspectiveServer: - h.handleMessageForServer(msgType) + return h.handleMessageForServer(msgType) default: panic("") } - return nil } func (h *cryptoSetupTLS) checkEncryptionLevel(msgType messageType, encLevel protocol.EncryptionLevel) error { @@ -340,78 +303,78 @@ func (h *cryptoSetupTLS) checkEncryptionLevel(msgType messageType, encLevel prot return nil } -func (h *cryptoSetupTLS) handleMessageForServer(msgType messageType) { +func (h *cryptoSetupTLS) handleMessageForServer(msgType messageType) bool { switch msgType { case typeClientHello: select { case params := <-h.receivedTransportParams: h.handleParamsCallback(¶ms) case <-h.handshakeErrChan: - return + return false } // get the handshake write key select { case <-h.receivedWriteKey: case <-h.handshakeErrChan: - return + return false } // get the 1-RTT write key select { case <-h.receivedWriteKey: case <-h.handshakeErrChan: - return + return false } // get the handshake read key // TODO: check that the initial stream doesn't have any more data select { case <-h.receivedReadKey: case <-h.handshakeErrChan: - return + return false } - h.handshakeEvent <- struct{}{} + return true case typeCertificate, typeCertificateVerify: // nothing to do + return false case typeFinished: // get the 1-RTT read key - // TODO: check that the handshake stream doesn't have any more data select { case <-h.receivedReadKey: case <-h.handshakeErrChan: - return + return false } - h.handshakeEvent <- struct{}{} + return true default: panic("unexpected handshake message") } } -func (h *cryptoSetupTLS) handleMessageForClient(msgType messageType) { +func (h *cryptoSetupTLS) handleMessageForClient(msgType messageType) bool { switch msgType { case typeServerHello: // get the handshake read key - // TODO: check that the initial stream doesn't have any more data select { case <-h.receivedReadKey: case <-h.handshakeErrChan: - return + return false } - h.handshakeEvent <- struct{}{} + return true case typeEncryptedExtensions: select { case params := <-h.receivedTransportParams: h.handleParamsCallback(¶ms) case <-h.handshakeErrChan: - return + return false } + return false case typeCertificateRequest, typeCertificate, typeCertificateVerify: // nothing to do + return false case typeFinished: // get the handshake write key - // TODO: check that the initial stream doesn't have any more data select { case <-h.receivedWriteKey: case <-h.handshakeErrChan: - return + return false } // While the order of these two is not defined by the TLS spec, // we have to do it on the same order as our TLS library does it. @@ -419,16 +382,15 @@ func (h *cryptoSetupTLS) handleMessageForClient(msgType messageType) { select { case <-h.receivedWriteKey: case <-h.handshakeErrChan: - return + return false } // get the 1-RTT read key select { case <-h.receivedReadKey: case <-h.handshakeErrChan: - return + return false } - // TODO: check that the handshake stream doesn't have any more data - h.handshakeEvent <- struct{}{} + return true default: panic("unexpected handshake message: ") } diff --git a/internal/handshake/crypto_setup_tls_test.go b/internal/handshake/crypto_setup_tls_test.go index 4ff37d040..c0e2cb16a 100644 --- a/internal/handshake/crypto_setup_tls_test.go +++ b/internal/handshake/crypto_setup_tls_test.go @@ -63,7 +63,6 @@ var _ = Describe("Crypto Setup TLS", func() { protocol.ConnectionID{}, &TransportParameters{}, func(p *TransportParameters) {}, - make(chan struct{}, 100), make(chan struct{}), testdata.GetTLSConfig(), []protocol.VersionNumber{protocol.VersionTLS}, @@ -83,7 +82,7 @@ var _ = Describe("Crypto Setup TLS", func() { }() fakeCH := append([]byte{byte(typeClientHello), 0, 0, 6}, []byte("foobar")...) - server.HandleData(fakeCH, protocol.EncryptionInitial) + server.HandleMessage(fakeCH, protocol.EncryptionInitial) Eventually(done).Should(BeClosed()) }) @@ -95,7 +94,6 @@ var _ = Describe("Crypto Setup TLS", func() { protocol.ConnectionID{}, &TransportParameters{}, func(p *TransportParameters) {}, - make(chan struct{}, 100), make(chan struct{}), testdata.GetTLSConfig(), []protocol.VersionNumber{protocol.VersionTLS}, @@ -114,7 +112,7 @@ var _ = Describe("Crypto Setup TLS", func() { }() fakeCH := append([]byte{byte(typeClientHello), 0, 0, 6}, []byte("foobar")...) - server.HandleData(fakeCH, protocol.EncryptionHandshake) // wrong encryption level + server.HandleMessage(fakeCH, protocol.EncryptionHandshake) // wrong encryption level Eventually(done).Should(BeClosed()) }) @@ -150,9 +148,9 @@ var _ = Describe("Crypto Setup TLS", func() { for { select { case c := <-cChunkChan: - server.HandleData(c.data, c.encLevel) + server.HandleMessage(c.data, c.encLevel) case c := <-sChunkChan: - client.HandleData(c.data, c.encLevel) + client.HandleMessage(c.data, c.encLevel) case <-done: // handshake complete } } @@ -178,7 +176,6 @@ var _ = Describe("Crypto Setup TLS", func() { protocol.ConnectionID{}, &TransportParameters{}, func(p *TransportParameters) {}, - make(chan struct{}, 100), make(chan struct{}), clientConf, protocol.VersionTLS, @@ -196,7 +193,6 @@ var _ = Describe("Crypto Setup TLS", func() { protocol.ConnectionID{}, &TransportParameters{StatelessResetToken: bytes.Repeat([]byte{42}, 16)}, func(p *TransportParameters) {}, - make(chan struct{}, 100), make(chan struct{}), serverConf, []protocol.VersionNumber{protocol.VersionTLS}, @@ -237,7 +233,6 @@ var _ = Describe("Crypto Setup TLS", func() { protocol.ConnectionID{}, &TransportParameters{}, func(p *TransportParameters) {}, - make(chan struct{}, 100), make(chan struct{}), &tls.Config{InsecureSkipVerify: true}, protocol.VersionTLS, @@ -264,7 +259,7 @@ var _ = Describe("Crypto Setup TLS", func() { Expect(len(ch.data) - 4).To(Equal(length)) // make the go routine return - client.HandleData([]byte{42 /* unknown handshake message type */, 0, 0, 1, 0}, protocol.EncryptionInitial) + client.HandleMessage([]byte{42 /* unknown handshake message type */, 0, 0, 1, 0}, protocol.EncryptionInitial) Eventually(done).Should(BeClosed()) }) @@ -278,7 +273,6 @@ var _ = Describe("Crypto Setup TLS", func() { protocol.ConnectionID{}, cTransportParameters, func(p *TransportParameters) { sTransportParametersRcvd = p }, - make(chan struct{}, 100), make(chan struct{}), &tls.Config{ServerName: "quic.clemente.io"}, protocol.VersionTLS, @@ -300,7 +294,6 @@ var _ = Describe("Crypto Setup TLS", func() { protocol.ConnectionID{}, sTransportParameters, func(p *TransportParameters) { cTransportParametersRcvd = p }, - make(chan struct{}, 100), make(chan struct{}), testdata.GetTLSConfig(), []protocol.VersionNumber{protocol.VersionTLS}, diff --git a/internal/handshake/interface.go b/internal/handshake/interface.go index 7acb6a13c..88122dc1b 100644 --- a/internal/handshake/interface.go +++ b/internal/handshake/interface.go @@ -44,7 +44,7 @@ type CryptoSetup interface { type CryptoSetupTLS interface { baseCryptoSetup - HandleData([]byte, protocol.EncryptionLevel) + HandleMessage([]byte, protocol.EncryptionLevel) bool OpenInitial(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) OpenHandshake(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) Open1RTT(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) diff --git a/mock_crypto_data_handler.go b/mock_crypto_data_handler.go index b789ca8b5..37a408008 100644 --- a/mock_crypto_data_handler.go +++ b/mock_crypto_data_handler.go @@ -34,12 +34,14 @@ func (m *MockCryptoDataHandler) EXPECT() *MockCryptoDataHandlerMockRecorder { return m.recorder } -// HandleData mocks base method -func (m *MockCryptoDataHandler) HandleData(arg0 []byte, arg1 protocol.EncryptionLevel) { - m.ctrl.Call(m, "HandleData", arg0, arg1) +// HandleMessage mocks base method +func (m *MockCryptoDataHandler) HandleMessage(arg0 []byte, arg1 protocol.EncryptionLevel) bool { + ret := m.ctrl.Call(m, "HandleMessage", arg0, arg1) + ret0, _ := ret[0].(bool) + return ret0 } -// HandleData indicates an expected call of HandleData -func (mr *MockCryptoDataHandlerMockRecorder) HandleData(arg0, arg1 interface{}) *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleData", reflect.TypeOf((*MockCryptoDataHandler)(nil).HandleData), arg0, arg1) +// HandleMessage indicates an expected call of HandleMessage +func (mr *MockCryptoDataHandlerMockRecorder) HandleMessage(arg0, arg1 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleMessage", reflect.TypeOf((*MockCryptoDataHandler)(nil).HandleMessage), arg0, arg1) } diff --git a/mock_crypto_stream_test.go b/mock_crypto_stream_test.go index 66de8d0f9..0465e7e25 100644 --- a/mock_crypto_stream_test.go +++ b/mock_crypto_stream_test.go @@ -35,6 +35,18 @@ func (m *MockCryptoStream) EXPECT() *MockCryptoStreamMockRecorder { return m.recorder } +// Finish mocks base method +func (m *MockCryptoStream) Finish() error { + ret := m.ctrl.Call(m, "Finish") + ret0, _ := ret[0].(error) + return ret0 +} + +// Finish indicates an expected call of Finish +func (mr *MockCryptoStreamMockRecorder) Finish() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Finish", reflect.TypeOf((*MockCryptoStream)(nil).Finish)) +} + // GetCryptoData mocks base method func (m *MockCryptoStream) GetCryptoData() []byte { ret := m.ctrl.Call(m, "GetCryptoData") diff --git a/session.go b/session.go index c8ca516ab..ab9ce5844 100644 --- a/session.go +++ b/session.go @@ -120,6 +120,7 @@ type session struct { paramsChan <-chan handshake.TransportParameters // the handshakeEvent channel is passed to the CryptoSetup. // It receives when it makes sense to try decrypting undecryptable packets. + // Only used for gQUIC. handshakeEvent <-chan struct{} handshakeCompleteChan <-chan struct{} // is closed when the handshake completes handshakeComplete bool @@ -325,7 +326,6 @@ func newTLSServerSession( logger utils.Logger, v protocol.VersionNumber, ) (quicSession, error) { - handshakeEvent := make(chan struct{}, 2) // TODO: explain cap handshakeCompleteChan := make(chan struct{}) s := &session{ conn: conn, @@ -334,7 +334,6 @@ func newTLSServerSession( srcConnID: srcConnID, destConnID: destConnID, perspective: protocol.PerspectiveServer, - handshakeEvent: handshakeEvent, handshakeCompleteChan: handshakeCompleteChan, logger: logger, version: v, @@ -350,7 +349,6 @@ func newTLSServerSession( origConnID, params, s.processTransportParameters, - handshakeEvent, handshakeCompleteChan, tlsConf, conf.Versions, @@ -403,7 +401,6 @@ var newTLSClientSession = func( logger utils.Logger, v protocol.VersionNumber, ) (quicSession, error) { - handshakeEvent := make(chan struct{}, 2) // TODO: explain cap handshakeCompleteChan := make(chan struct{}) s := &session{ conn: conn, @@ -412,7 +409,6 @@ var newTLSClientSession = func( srcConnID: srcConnID, destConnID: destConnID, perspective: protocol.PerspectiveClient, - handshakeEvent: handshakeEvent, handshakeCompleteChan: handshakeCompleteChan, logger: logger, version: v, @@ -426,7 +422,6 @@ var newTLSClientSession = func( s.destConnID, params, s.processTransportParameters, - handshakeEvent, handshakeCompleteChan, tlsConf, initialVersion, @@ -804,7 +799,14 @@ func (s *session) handlePacket(p *receivedPacket) { } func (s *session) handleCryptoFrame(frame *wire.CryptoFrame, encLevel protocol.EncryptionLevel) error { - return s.cryptoStreamManager.HandleCryptoFrame(frame, encLevel) + encLevelChanged, err := s.cryptoStreamManager.HandleCryptoFrame(frame, encLevel) + if err != nil { + return err + } + if encLevelChanged { + s.tryDecryptingQueuedPackets() + } + return nil } func (s *session) handleStreamFrame(frame *wire.StreamFrame, encLevel protocol.EncryptionLevel) error {