diff --git a/interface.go b/interface.go index 6b5a0176..84ae906c 100644 --- a/interface.go +++ b/interface.go @@ -2,6 +2,7 @@ package quic import ( "context" + "errors" "io" "net" "time" @@ -65,6 +66,13 @@ type TokenStore interface { // Valid values range between 0 and MAX_UINT62. type ErrorCode = protocol.ApplicationErrorCode +// Err0RTTRejected is the returned from: +// * Open{Uni}Stream{Sync} +// * Accept{Uni}Stream +// * Stream.Read and Stream.Write +// when the server rejects a 0-RTT connection attempt. +var Err0RTTRejected = errors.New("0-RTT rejected") + // Stream is the interface implemented by QUIC streams type Stream interface { ReceiveStream diff --git a/mock_stream_manager_test.go b/mock_stream_manager_test.go index 44226e02..92c31da9 100644 --- a/mock_stream_manager_test.go +++ b/mock_stream_manager_test.go @@ -194,6 +194,18 @@ func (mr *MockStreamManagerMockRecorder) OpenUniStreamSync(arg0 interface{}) *go return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStreamSync", reflect.TypeOf((*MockStreamManager)(nil).OpenUniStreamSync), arg0) } +// ResetFor0RTT mocks base method. +func (m *MockStreamManager) ResetFor0RTT() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ResetFor0RTT") +} + +// ResetFor0RTT indicates an expected call of ResetFor0RTT. +func (mr *MockStreamManagerMockRecorder) ResetFor0RTT() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResetFor0RTT", reflect.TypeOf((*MockStreamManager)(nil).ResetFor0RTT)) +} + // UpdateLimits mocks base method. func (m *MockStreamManager) UpdateLimits(arg0 *wire.TransportParameters) { m.ctrl.T.Helper() @@ -205,3 +217,15 @@ func (mr *MockStreamManagerMockRecorder) UpdateLimits(arg0 interface{}) *gomock. mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateLimits", reflect.TypeOf((*MockStreamManager)(nil).UpdateLimits), arg0) } + +// UseResetMaps mocks base method. +func (m *MockStreamManager) UseResetMaps() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "UseResetMaps") +} + +// UseResetMaps indicates an expected call of UseResetMaps. +func (mr *MockStreamManagerMockRecorder) UseResetMaps() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UseResetMaps", reflect.TypeOf((*MockStreamManager)(nil).UseResetMaps)) +} diff --git a/session.go b/session.go index 7146aa2a..395eadfb 100644 --- a/session.go +++ b/session.go @@ -45,6 +45,8 @@ type streamManager interface { UpdateLimits(*wire.TransportParameters) HandleMaxStreamsFrame(*wire.MaxStreamsFrame) CloseWithError(error) + ResetFor0RTT() + UseResetMaps() } type cryptoStreamHandler interface { diff --git a/streams_map.go b/streams_map.go index de4d850b..61a5398e 100644 --- a/streams_map.go +++ b/streams_map.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "net" + "sync" "github.com/lucas-clemente/quic-go/internal/flowcontrol" "github.com/lucas-clemente/quic-go/internal/protocol" @@ -45,14 +46,20 @@ var errTooManyOpenStreams = errors.New("too many open streams") type streamsMap struct { perspective protocol.Perspective + version protocol.VersionNumber + + maxIncomingBidiStreams uint64 + maxIncomingUniStreams uint64 sender streamSender newFlowController func(protocol.StreamID) flowcontrol.StreamFlowController + mutex sync.Mutex outgoingBidiStreams *outgoingBidiStreamsMap outgoingUniStreams *outgoingUniStreamsMap incomingBidiStreams *incomingBidiStreamsMap incomingUniStreams *incomingUniStreamsMap + reset bool } var _ streamManager = &streamsMap{} @@ -66,70 +73,119 @@ func newStreamsMap( version protocol.VersionNumber, ) streamManager { m := &streamsMap{ - perspective: perspective, - newFlowController: newFlowController, - sender: sender, + perspective: perspective, + newFlowController: newFlowController, + maxIncomingBidiStreams: maxIncomingBidiStreams, + maxIncomingUniStreams: maxIncomingUniStreams, + sender: sender, + version: version, } - m.outgoingBidiStreams = newOutgoingBidiStreamsMap( - func(num protocol.StreamNum) streamI { - id := num.StreamID(protocol.StreamTypeBidi, perspective) - return newStream(id, m.sender, m.newFlowController(id), version) - }, - sender.queueControlFrame, - ) - m.incomingBidiStreams = newIncomingBidiStreamsMap( - func(num protocol.StreamNum) streamI { - id := num.StreamID(protocol.StreamTypeBidi, perspective.Opposite()) - return newStream(id, m.sender, m.newFlowController(id), version) - }, - maxIncomingBidiStreams, - sender.queueControlFrame, - ) - m.outgoingUniStreams = newOutgoingUniStreamsMap( - func(num protocol.StreamNum) sendStreamI { - id := num.StreamID(protocol.StreamTypeUni, perspective) - return newSendStream(id, m.sender, m.newFlowController(id), version) - }, - sender.queueControlFrame, - ) - m.incomingUniStreams = newIncomingUniStreamsMap( - func(num protocol.StreamNum) receiveStreamI { - id := num.StreamID(protocol.StreamTypeUni, perspective.Opposite()) - return newReceiveStream(id, m.sender, m.newFlowController(id), version) - }, - maxIncomingUniStreams, - sender.queueControlFrame, - ) + m.initMaps() return m } +func (m *streamsMap) initMaps() { + m.outgoingBidiStreams = newOutgoingBidiStreamsMap( + func(num protocol.StreamNum) streamI { + id := num.StreamID(protocol.StreamTypeBidi, m.perspective) + return newStream(id, m.sender, m.newFlowController(id), m.version) + }, + m.sender.queueControlFrame, + ) + m.incomingBidiStreams = newIncomingBidiStreamsMap( + func(num protocol.StreamNum) streamI { + id := num.StreamID(protocol.StreamTypeBidi, m.perspective.Opposite()) + return newStream(id, m.sender, m.newFlowController(id), m.version) + }, + m.maxIncomingBidiStreams, + m.sender.queueControlFrame, + ) + m.outgoingUniStreams = newOutgoingUniStreamsMap( + func(num protocol.StreamNum) sendStreamI { + id := num.StreamID(protocol.StreamTypeUni, m.perspective) + return newSendStream(id, m.sender, m.newFlowController(id), m.version) + }, + m.sender.queueControlFrame, + ) + m.incomingUniStreams = newIncomingUniStreamsMap( + func(num protocol.StreamNum) receiveStreamI { + id := num.StreamID(protocol.StreamTypeUni, m.perspective.Opposite()) + return newReceiveStream(id, m.sender, m.newFlowController(id), m.version) + }, + m.maxIncomingUniStreams, + m.sender.queueControlFrame, + ) +} + func (m *streamsMap) OpenStream() (Stream, error) { - str, err := m.outgoingBidiStreams.OpenStream() + m.mutex.Lock() + reset := m.reset + mm := m.outgoingBidiStreams + m.mutex.Unlock() + if reset { + return nil, Err0RTTRejected + } + str, err := mm.OpenStream() return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective) } func (m *streamsMap) OpenStreamSync(ctx context.Context) (Stream, error) { - str, err := m.outgoingBidiStreams.OpenStreamSync(ctx) + m.mutex.Lock() + reset := m.reset + mm := m.outgoingBidiStreams + m.mutex.Unlock() + if reset { + return nil, Err0RTTRejected + } + str, err := mm.OpenStreamSync(ctx) return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective) } func (m *streamsMap) OpenUniStream() (SendStream, error) { - str, err := m.outgoingUniStreams.OpenStream() + m.mutex.Lock() + reset := m.reset + mm := m.outgoingUniStreams + m.mutex.Unlock() + if reset { + return nil, Err0RTTRejected + } + str, err := mm.OpenStream() return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective) } func (m *streamsMap) OpenUniStreamSync(ctx context.Context) (SendStream, error) { - str, err := m.outgoingUniStreams.OpenStreamSync(ctx) + m.mutex.Lock() + reset := m.reset + mm := m.outgoingUniStreams + m.mutex.Unlock() + if reset { + return nil, Err0RTTRejected + } + str, err := mm.OpenStreamSync(ctx) return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective) } func (m *streamsMap) AcceptStream(ctx context.Context) (Stream, error) { - str, err := m.incomingBidiStreams.AcceptStream(ctx) + m.mutex.Lock() + reset := m.reset + mm := m.incomingBidiStreams + m.mutex.Unlock() + if reset { + return nil, Err0RTTRejected + } + str, err := mm.AcceptStream(ctx) return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective.Opposite()) } func (m *streamsMap) AcceptUniStream(ctx context.Context) (ReceiveStream, error) { - str, err := m.incomingUniStreams.AcceptStream(ctx) + m.mutex.Lock() + reset := m.reset + mm := m.incomingUniStreams + m.mutex.Unlock() + if reset { + return nil, Err0RTTRejected + } + str, err := mm.AcceptStream(ctx) return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective.Opposite()) } @@ -232,3 +288,22 @@ func (m *streamsMap) CloseWithError(err error) { m.incomingBidiStreams.CloseWithError(err) m.incomingUniStreams.CloseWithError(err) } + +// ResetFor0RTT resets is used when 0-RTT is rejected. In that case, the streams maps are +// 1. closed with an Err0RTTRejected, making calls to Open{Uni}Stream{Sync} / Accept{Uni}Stream return that error. +// 2. reset to their initial state, such that we can immediately process new incoming stream data. +// Afterwards, calls to Open{Uni}Stream{Sync} / Accept{Uni}Stream will continue to return the error, +// until UseResetMaps() has been called. +func (m *streamsMap) ResetFor0RTT() { + m.mutex.Lock() + defer m.mutex.Unlock() + m.reset = true + m.CloseWithError(Err0RTTRejected) + m.initMaps() +} + +func (m *streamsMap) UseResetMaps() { + m.mutex.Lock() + m.reset = false + m.mutex.Unlock() +} diff --git a/streams_map_test.go b/streams_map_test.go index 0ba630b5..81e4b128 100644 --- a/streams_map_test.go +++ b/streams_map_test.go @@ -320,39 +320,32 @@ var _ = Describe("Streams Map", func() { }) }) - Context("updating stream ID limits", func() { - for _, p := range []protocol.Perspective{protocol.PerspectiveClient, protocol.PerspectiveServer} { - pers := p + It("processes the parameter for outgoing streams", func() { + mockSender.EXPECT().queueControlFrame(gomock.Any()) + _, err := m.OpenStream() + expectTooManyStreamsError(err) + m.UpdateLimits(&wire.TransportParameters{ + MaxBidiStreamNum: 5, + MaxUniStreamNum: 8, + }) - It(fmt.Sprintf("processes the parameter for outgoing streams, as a %s", pers), func() { - mockSender.EXPECT().queueControlFrame(gomock.Any()) - m.perspective = pers - _, err := m.OpenStream() - expectTooManyStreamsError(err) - m.UpdateLimits(&wire.TransportParameters{ - MaxBidiStreamNum: 5, - MaxUniStreamNum: 8, - }) - - mockSender.EXPECT().queueControlFrame(gomock.Any()).Times(2) - // test we can only 5 bidirectional streams - for i := 0; i < 5; i++ { - str, err := m.OpenStream() - Expect(err).ToNot(HaveOccurred()) - Expect(str.StreamID()).To(Equal(ids.firstOutgoingBidiStream + protocol.StreamID(4*i))) - } - _, err = m.OpenStream() - expectTooManyStreamsError(err) - // test we can only 8 unidirectional streams - for i := 0; i < 8; i++ { - str, err := m.OpenUniStream() - Expect(err).ToNot(HaveOccurred()) - Expect(str.StreamID()).To(Equal(ids.firstOutgoingUniStream + protocol.StreamID(4*i))) - } - _, err = m.OpenUniStream() - expectTooManyStreamsError(err) - }) + mockSender.EXPECT().queueControlFrame(gomock.Any()).Times(2) + // test we can only 5 bidirectional streams + for i := 0; i < 5; i++ { + str, err := m.OpenStream() + Expect(err).ToNot(HaveOccurred()) + Expect(str.StreamID()).To(Equal(ids.firstOutgoingBidiStream + protocol.StreamID(4*i))) } + _, err = m.OpenStream() + expectTooManyStreamsError(err) + // test we can only 8 unidirectional streams + for i := 0; i < 8; i++ { + str, err := m.OpenUniStream() + Expect(err).ToNot(HaveOccurred()) + Expect(str.StreamID()).To(Equal(ids.firstOutgoingUniStream + protocol.StreamID(4*i))) + } + _, err = m.OpenUniStream() + expectTooManyStreamsError(err) }) Context("handling MAX_STREAMS frames", func() { @@ -431,6 +424,28 @@ var _ = Describe("Streams Map", func() { Expect(err).To(HaveOccurred()) Expect(err.Error()).To(Equal(testErr.Error())) }) + + if perspective == protocol.PerspectiveClient { + It("resets for 0-RTT", func() { + mockSender.EXPECT().queueControlFrame(gomock.Any()).AnyTimes() + m.ResetFor0RTT() + // make sure that calls to open / accept streams fail + _, err := m.OpenStream() + Expect(err).To(MatchError(Err0RTTRejected)) + _, err = m.AcceptStream(context.Background()) + Expect(err).To(MatchError(Err0RTTRejected)) + // make sure that we can still get new streams, as the server might be sending us data + str, err := m.GetOrOpenReceiveStream(3) + Expect(err).ToNot(HaveOccurred()) + Expect(str).ToNot(BeNil()) + + // now switch to using the new streams map + m.UseResetMaps() + _, err = m.OpenStream() + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("too many open streams")) + }) + } }) } })