diff --git a/connection.go b/connection.go index e603ce38..872ba79a 100644 --- a/connection.go +++ b/connection.go @@ -58,6 +58,7 @@ type cryptoStreamHandler interface { GetSessionTicket() ([]byte, error) NextEvent() handshake.Event DiscardInitialKeys() + HandleMessage([]byte, protocol.EncryptionLevel) error io.Closer ConnectionState() handshake.ConnectionState } @@ -333,7 +334,7 @@ var newConnection = func( s.cryptoStreamHandler = cs s.packer = newPacketPacker(srcConnID, s.connIDManager.Get, s.initialStream, s.handshakeStream, s.sentPacketHandler, s.retransmissionQueue, cs, s.framer, s.receivedPacketHandler, s.datagramQueue, s.perspective) s.unpacker = newPacketUnpacker(cs, s.srcConnIDLen) - s.cryptoStreamManager = newCryptoStreamManager(cs, s.initialStream, s.handshakeStream, s.oneRTTStream) + s.cryptoStreamManager = newCryptoStreamManager(s.initialStream, s.handshakeStream, s.oneRTTStream) return s } @@ -437,7 +438,7 @@ var newClientConnection = func( s.version, ) s.cryptoStreamHandler = cs - s.cryptoStreamManager = newCryptoStreamManager(cs, s.initialStream, s.handshakeStream, oneRTTStream) + s.cryptoStreamManager = newCryptoStreamManager(s.initialStream, s.handshakeStream, oneRTTStream) s.unpacker = newPacketUnpacker(cs, s.srcConnIDLen) s.packer = newPacketPacker(srcConnID, s.connIDManager.Get, s.initialStream, s.handshakeStream, s.sentPacketHandler, s.retransmissionQueue, cs, s.framer, s.receivedPacketHandler, s.datagramQueue, s.perspective) if len(tlsConf.ServerName) > 0 { @@ -1377,6 +1378,15 @@ func (s *connection) handleCryptoFrame(frame *wire.CryptoFrame, encLevel protoco if err := s.cryptoStreamManager.HandleCryptoFrame(frame, encLevel); err != nil { return err } + for { + data := s.cryptoStreamManager.GetCryptoData(encLevel) + if data == nil { + break + } + if err := s.cryptoStreamHandler.HandleMessage(data, encLevel); err != nil { + return err + } + } return s.handleHandshakeEvents() } diff --git a/crypto_stream_manager.go b/crypto_stream_manager.go index c48e238a..b29eb408 100644 --- a/crypto_stream_manager.go +++ b/crypto_stream_manager.go @@ -3,32 +3,22 @@ package quic import ( "fmt" - "github.com/quic-go/quic-go/internal/handshake" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/wire" ) -type cryptoDataHandler interface { - HandleMessage([]byte, protocol.EncryptionLevel) error - NextEvent() handshake.Event -} - type cryptoStreamManager struct { - cryptoHandler cryptoDataHandler - initialStream cryptoStream handshakeStream cryptoStream oneRTTStream cryptoStream } func newCryptoStreamManager( - cryptoHandler cryptoDataHandler, initialStream cryptoStream, handshakeStream cryptoStream, oneRTTStream cryptoStream, ) *cryptoStreamManager { return &cryptoStreamManager{ - cryptoHandler: cryptoHandler, initialStream: initialStream, handshakeStream: handshakeStream, oneRTTStream: oneRTTStream, @@ -48,18 +38,23 @@ func (m *cryptoStreamManager) HandleCryptoFrame(frame *wire.CryptoFrame, encLeve default: return fmt.Errorf("received CRYPTO frame with unexpected encryption level: %s", encLevel) } - if err := str.HandleCryptoFrame(frame); err != nil { - return err - } - for { - data := str.GetCryptoData() - if data == nil { - return nil - } - if err := m.cryptoHandler.HandleMessage(data, encLevel); err != nil { - return err - } + return str.HandleCryptoFrame(frame) +} + +func (m *cryptoStreamManager) GetCryptoData(encLevel protocol.EncryptionLevel) []byte { + var str cryptoStream + //nolint:exhaustive // CRYPTO frames cannot be sent in 0-RTT packets. + switch encLevel { + case protocol.EncryptionInitial: + str = m.initialStream + case protocol.EncryptionHandshake: + str = m.handshakeStream + case protocol.Encryption1RTT: + str = m.oneRTTStream + default: + panic(fmt.Sprintf("received CRYPTO frame with unexpected encryption level: %s", encLevel)) } + return str.GetCryptoData() } func (m *cryptoStreamManager) GetPostHandshakeData(maxSize protocol.ByteCount) *wire.CryptoFrame { diff --git a/crypto_stream_manager_test.go b/crypto_stream_manager_test.go index daffffe6..5d32d016 100644 --- a/crypto_stream_manager_test.go +++ b/crypto_stream_manager_test.go @@ -11,7 +11,6 @@ import ( var _ = Describe("Crypto Stream Manager", func() { var ( csm *cryptoStreamManager - cs *MockCryptoDataHandler initialStream *MockCryptoStream handshakeStream *MockCryptoStream @@ -22,43 +21,31 @@ var _ = Describe("Crypto Stream Manager", func() { initialStream = NewMockCryptoStream(mockCtrl) handshakeStream = NewMockCryptoStream(mockCtrl) oneRTTStream = NewMockCryptoStream(mockCtrl) - cs = NewMockCryptoDataHandler(mockCtrl) - csm = newCryptoStreamManager(cs, initialStream, handshakeStream, oneRTTStream) + csm = newCryptoStreamManager(initialStream, handshakeStream, oneRTTStream) }) It("passes messages to the initial stream", func() { cf := &wire.CryptoFrame{Data: []byte("foobar")} initialStream.EXPECT().HandleCryptoFrame(cf) initialStream.EXPECT().GetCryptoData().Return([]byte("foobar")) - initialStream.EXPECT().GetCryptoData() - cs.EXPECT().HandleMessage([]byte("foobar"), protocol.EncryptionInitial) Expect(csm.HandleCryptoFrame(cf, protocol.EncryptionInitial)).To(Succeed()) + Expect(csm.GetCryptoData(protocol.EncryptionInitial)).To(Equal([]byte("foobar"))) }) It("passes messages to the handshake stream", func() { cf := &wire.CryptoFrame{Data: []byte("foobar")} handshakeStream.EXPECT().HandleCryptoFrame(cf) handshakeStream.EXPECT().GetCryptoData().Return([]byte("foobar")) - handshakeStream.EXPECT().GetCryptoData() - cs.EXPECT().HandleMessage([]byte("foobar"), protocol.EncryptionHandshake) Expect(csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake)).To(Succeed()) + Expect(csm.GetCryptoData(protocol.EncryptionHandshake)).To(Equal([]byte("foobar"))) }) It("passes messages to the 1-RTT stream", func() { cf := &wire.CryptoFrame{Data: []byte("foobar")} oneRTTStream.EXPECT().HandleCryptoFrame(cf) oneRTTStream.EXPECT().GetCryptoData().Return([]byte("foobar")) - oneRTTStream.EXPECT().GetCryptoData() - cs.EXPECT().HandleMessage([]byte("foobar"), protocol.Encryption1RTT) Expect(csm.HandleCryptoFrame(cf, protocol.Encryption1RTT)).To(Succeed()) - }) - - It("doesn't call the message handler, if there's no message", func() { - cf := &wire.CryptoFrame{Data: []byte("foobar")} - handshakeStream.EXPECT().HandleCryptoFrame(cf) - handshakeStream.EXPECT().GetCryptoData() // don't return any data to handle - // don't EXPECT any calls to HandleMessage() - Expect(csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake)).To(Succeed()) + Expect(csm.GetCryptoData(protocol.Encryption1RTT)).To(Equal([]byte("foobar"))) }) It("processes all messages", func() { @@ -67,9 +54,16 @@ var _ = Describe("Crypto Stream Manager", func() { handshakeStream.EXPECT().GetCryptoData().Return([]byte("foo")) handshakeStream.EXPECT().GetCryptoData().Return([]byte("bar")) handshakeStream.EXPECT().GetCryptoData() - cs.EXPECT().HandleMessage([]byte("foo"), protocol.EncryptionHandshake) - cs.EXPECT().HandleMessage([]byte("bar"), protocol.EncryptionHandshake) Expect(csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake)).To(Succeed()) + var data []byte + for { + b := csm.GetCryptoData(protocol.EncryptionHandshake) + if len(b) == 0 { + break + } + data = append(data, b...) + } + Expect(data).To(Equal([]byte("foobar"))) }) It("errors for unknown encryption levels", func() { diff --git a/mock_crypto_data_handler_test.go b/mock_crypto_data_handler_test.go deleted file mode 100644 index 96fb3d8e..00000000 --- a/mock_crypto_data_handler_test.go +++ /dev/null @@ -1,117 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/quic-go/quic-go (interfaces: CryptoDataHandler) -// -// Generated by this command: -// -// mockgen -typed -build_flags=-tags=gomock -package quic -self_package github.com/quic-go/quic-go -destination mock_crypto_data_handler_test.go github.com/quic-go/quic-go CryptoDataHandler -// - -// Package quic is a generated GoMock package. -package quic - -import ( - reflect "reflect" - - handshake "github.com/quic-go/quic-go/internal/handshake" - protocol "github.com/quic-go/quic-go/internal/protocol" - gomock "go.uber.org/mock/gomock" -) - -// MockCryptoDataHandler is a mock of CryptoDataHandler interface. -type MockCryptoDataHandler struct { - ctrl *gomock.Controller - recorder *MockCryptoDataHandlerMockRecorder -} - -// MockCryptoDataHandlerMockRecorder is the mock recorder for MockCryptoDataHandler. -type MockCryptoDataHandlerMockRecorder struct { - mock *MockCryptoDataHandler -} - -// NewMockCryptoDataHandler creates a new mock instance. -func NewMockCryptoDataHandler(ctrl *gomock.Controller) *MockCryptoDataHandler { - mock := &MockCryptoDataHandler{ctrl: ctrl} - mock.recorder = &MockCryptoDataHandlerMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockCryptoDataHandler) EXPECT() *MockCryptoDataHandlerMockRecorder { - return m.recorder -} - -// HandleMessage mocks base method. -func (m *MockCryptoDataHandler) HandleMessage(arg0 []byte, arg1 protocol.EncryptionLevel) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "HandleMessage", arg0, arg1) - ret0, _ := ret[0].(error) - return ret0 -} - -// HandleMessage indicates an expected call of HandleMessage. -func (mr *MockCryptoDataHandlerMockRecorder) HandleMessage(arg0, arg1 any) *MockCryptoDataHandlerHandleMessageCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleMessage", reflect.TypeOf((*MockCryptoDataHandler)(nil).HandleMessage), arg0, arg1) - return &MockCryptoDataHandlerHandleMessageCall{Call: call} -} - -// MockCryptoDataHandlerHandleMessageCall wrap *gomock.Call -type MockCryptoDataHandlerHandleMessageCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockCryptoDataHandlerHandleMessageCall) Return(arg0 error) *MockCryptoDataHandlerHandleMessageCall { - c.Call = c.Call.Return(arg0) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockCryptoDataHandlerHandleMessageCall) Do(f func([]byte, protocol.EncryptionLevel) error) *MockCryptoDataHandlerHandleMessageCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockCryptoDataHandlerHandleMessageCall) DoAndReturn(f func([]byte, protocol.EncryptionLevel) error) *MockCryptoDataHandlerHandleMessageCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// NextEvent mocks base method. -func (m *MockCryptoDataHandler) NextEvent() handshake.Event { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "NextEvent") - ret0, _ := ret[0].(handshake.Event) - return ret0 -} - -// NextEvent indicates an expected call of NextEvent. -func (mr *MockCryptoDataHandlerMockRecorder) NextEvent() *MockCryptoDataHandlerNextEventCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NextEvent", reflect.TypeOf((*MockCryptoDataHandler)(nil).NextEvent)) - return &MockCryptoDataHandlerNextEventCall{Call: call} -} - -// MockCryptoDataHandlerNextEventCall wrap *gomock.Call -type MockCryptoDataHandlerNextEventCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockCryptoDataHandlerNextEventCall) Return(arg0 handshake.Event) *MockCryptoDataHandlerNextEventCall { - c.Call = c.Call.Return(arg0) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockCryptoDataHandlerNextEventCall) Do(f func() handshake.Event) *MockCryptoDataHandlerNextEventCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockCryptoDataHandlerNextEventCall) DoAndReturn(f func() handshake.Event) *MockCryptoDataHandlerNextEventCall { - c.Call = c.Call.DoAndReturn(f) - return c -} diff --git a/mockgen.go b/mockgen.go index 81cc4a5e..22355d88 100644 --- a/mockgen.go +++ b/mockgen.go @@ -29,9 +29,6 @@ type StreamGetter = streamGetter //go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_stream_sender_test.go github.com/quic-go/quic-go StreamSender" type StreamSender = streamSender -//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_crypto_data_handler_test.go github.com/quic-go/quic-go CryptoDataHandler" -type CryptoDataHandler = cryptoDataHandler - //go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_frame_source_test.go github.com/quic-go/quic-go FrameSource" type FrameSource = frameSource