diff --git a/connection.go b/connection.go index 33ab14b8..13f905c5 100644 --- a/connection.go +++ b/connection.go @@ -2424,10 +2424,17 @@ func (s *connection) GetVersion() protocol.Version { return s.version } -func (s *connection) NextConnection() Connection { - <-s.HandshakeComplete() - s.streamsMap.UseResetMaps() - return s +func (s *connection) NextConnection(ctx context.Context) (Connection, error) { + // The handshake might fail after the server rejected 0-RTT. + // This could happen if the Finished message is malformed or never received. + select { + case <-ctx.Done(): + return nil, context.Cause(ctx) + case <-s.Context().Done(): + case <-s.HandshakeComplete(): + s.streamsMap.UseResetMaps() + } + return s, nil } // estimateMaxPayloadSize estimates the maximum payload size for short header packets. diff --git a/integrationtests/self/zero_rtt_test.go b/integrationtests/self/zero_rtt_test.go index bf966b9c..ccdeab6d 100644 --- a/integrationtests/self/zero_rtt_test.go +++ b/integrationtests/self/zero_rtt_test.go @@ -871,7 +871,8 @@ var _ = Describe("0-RTT", func() { _, err = conn.AcceptStream(ctx) Expect(err).To(Equal(quic.Err0RTTRejected)) - newConn := conn.NextConnection() + newConn, err := conn.NextConnection(context.Background()) + Expect(err).ToNot(HaveOccurred()) str, err := newConn.OpenUniStream() Expect(err).ToNot(HaveOccurred()) _, err = newConn.OpenUniStream() diff --git a/interface.go b/interface.go index cdefee44..56155f71 100644 --- a/interface.go +++ b/interface.go @@ -224,7 +224,7 @@ type EarlyConnection interface { // however the client's identity is only verified once the handshake completes. HandshakeComplete() <-chan struct{} - NextConnection() Connection + NextConnection(context.Context) (Connection, error) } // StatelessResetKey is a key used to derive stateless reset tokens. diff --git a/internal/mocks/quic/early_conn.go b/internal/mocks/quic/early_conn.go index f08f0481..a6b17341 100644 --- a/internal/mocks/quic/early_conn.go +++ b/internal/mocks/quic/early_conn.go @@ -311,17 +311,18 @@ func (c *MockEarlyConnectionLocalAddrCall) DoAndReturn(f func() net.Addr) *MockE } // NextConnection mocks base method. -func (m *MockEarlyConnection) NextConnection() quic.Connection { +func (m *MockEarlyConnection) NextConnection(arg0 context.Context) (quic.Connection, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "NextConnection") + ret := m.ctrl.Call(m, "NextConnection", arg0) ret0, _ := ret[0].(quic.Connection) - return ret0 + ret1, _ := ret[1].(error) + return ret0, ret1 } // NextConnection indicates an expected call of NextConnection. -func (mr *MockEarlyConnectionMockRecorder) NextConnection() *MockEarlyConnectionNextConnectionCall { +func (mr *MockEarlyConnectionMockRecorder) NextConnection(arg0 any) *MockEarlyConnectionNextConnectionCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NextConnection", reflect.TypeOf((*MockEarlyConnection)(nil).NextConnection)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NextConnection", reflect.TypeOf((*MockEarlyConnection)(nil).NextConnection), arg0) return &MockEarlyConnectionNextConnectionCall{Call: call} } @@ -331,19 +332,19 @@ type MockEarlyConnectionNextConnectionCall struct { } // Return rewrite *gomock.Call.Return -func (c *MockEarlyConnectionNextConnectionCall) Return(arg0 quic.Connection) *MockEarlyConnectionNextConnectionCall { - c.Call = c.Call.Return(arg0) +func (c *MockEarlyConnectionNextConnectionCall) Return(arg0 quic.Connection, arg1 error) *MockEarlyConnectionNextConnectionCall { + c.Call = c.Call.Return(arg0, arg1) return c } // Do rewrite *gomock.Call.Do -func (c *MockEarlyConnectionNextConnectionCall) Do(f func() quic.Connection) *MockEarlyConnectionNextConnectionCall { +func (c *MockEarlyConnectionNextConnectionCall) Do(f func(context.Context) (quic.Connection, error)) *MockEarlyConnectionNextConnectionCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockEarlyConnectionNextConnectionCall) DoAndReturn(f func() quic.Connection) *MockEarlyConnectionNextConnectionCall { +func (c *MockEarlyConnectionNextConnectionCall) DoAndReturn(f func(context.Context) (quic.Connection, error)) *MockEarlyConnectionNextConnectionCall { c.Call = c.Call.DoAndReturn(f) return c } diff --git a/mock_quic_conn_test.go b/mock_quic_conn_test.go index 179f26ee..785ef5d3 100644 --- a/mock_quic_conn_test.go +++ b/mock_quic_conn_test.go @@ -310,17 +310,18 @@ func (c *MockQUICConnLocalAddrCall) DoAndReturn(f func() net.Addr) *MockQUICConn } // NextConnection mocks base method. -func (m *MockQUICConn) NextConnection() Connection { +func (m *MockQUICConn) NextConnection(arg0 context.Context) (Connection, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "NextConnection") + ret := m.ctrl.Call(m, "NextConnection", arg0) ret0, _ := ret[0].(Connection) - return ret0 + ret1, _ := ret[1].(error) + return ret0, ret1 } // NextConnection indicates an expected call of NextConnection. -func (mr *MockQUICConnMockRecorder) NextConnection() *MockQUICConnNextConnectionCall { +func (mr *MockQUICConnMockRecorder) NextConnection(arg0 any) *MockQUICConnNextConnectionCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NextConnection", reflect.TypeOf((*MockQUICConn)(nil).NextConnection)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NextConnection", reflect.TypeOf((*MockQUICConn)(nil).NextConnection), arg0) return &MockQUICConnNextConnectionCall{Call: call} } @@ -330,19 +331,19 @@ type MockQUICConnNextConnectionCall struct { } // Return rewrite *gomock.Call.Return -func (c *MockQUICConnNextConnectionCall) Return(arg0 Connection) *MockQUICConnNextConnectionCall { - c.Call = c.Call.Return(arg0) +func (c *MockQUICConnNextConnectionCall) Return(arg0 Connection, arg1 error) *MockQUICConnNextConnectionCall { + c.Call = c.Call.Return(arg0, arg1) return c } // Do rewrite *gomock.Call.Do -func (c *MockQUICConnNextConnectionCall) Do(f func() Connection) *MockQUICConnNextConnectionCall { +func (c *MockQUICConnNextConnectionCall) Do(f func(context.Context) (Connection, error)) *MockQUICConnNextConnectionCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockQUICConnNextConnectionCall) DoAndReturn(f func() Connection) *MockQUICConnNextConnectionCall { +func (c *MockQUICConnNextConnectionCall) DoAndReturn(f func(context.Context) (Connection, error)) *MockQUICConnNextConnectionCall { c.Call = c.Call.DoAndReturn(f) return c }