add context to EarlyConnection.NextConnection, handle handshake failures (#4551)

This commit is contained in:
Marten Seemann
2024-06-05 11:51:54 +08:00
committed by GitHub
parent 0db354456a
commit 07acaad2f7
5 changed files with 34 additions and 24 deletions

View File

@@ -2424,10 +2424,17 @@ func (s *connection) GetVersion() protocol.Version {
return s.version return s.version
} }
func (s *connection) NextConnection() Connection { func (s *connection) NextConnection(ctx context.Context) (Connection, error) {
<-s.HandshakeComplete() // The handshake might fail after the server rejected 0-RTT.
// This could happen if the Finished message is malformed or never received.
select {
case <-ctx.Done():
return nil, context.Cause(ctx)
case <-s.Context().Done():
case <-s.HandshakeComplete():
s.streamsMap.UseResetMaps() s.streamsMap.UseResetMaps()
return s }
return s, nil
} }
// estimateMaxPayloadSize estimates the maximum payload size for short header packets. // estimateMaxPayloadSize estimates the maximum payload size for short header packets.

View File

@@ -871,7 +871,8 @@ var _ = Describe("0-RTT", func() {
_, err = conn.AcceptStream(ctx) _, err = conn.AcceptStream(ctx)
Expect(err).To(Equal(quic.Err0RTTRejected)) Expect(err).To(Equal(quic.Err0RTTRejected))
newConn := conn.NextConnection() newConn, err := conn.NextConnection(context.Background())
Expect(err).ToNot(HaveOccurred())
str, err := newConn.OpenUniStream() str, err := newConn.OpenUniStream()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
_, err = newConn.OpenUniStream() _, err = newConn.OpenUniStream()

View File

@@ -224,7 +224,7 @@ type EarlyConnection interface {
// however the client's identity is only verified once the handshake completes. // however the client's identity is only verified once the handshake completes.
HandshakeComplete() <-chan struct{} HandshakeComplete() <-chan struct{}
NextConnection() Connection NextConnection(context.Context) (Connection, error)
} }
// StatelessResetKey is a key used to derive stateless reset tokens. // StatelessResetKey is a key used to derive stateless reset tokens.

View File

@@ -311,17 +311,18 @@ func (c *MockEarlyConnectionLocalAddrCall) DoAndReturn(f func() net.Addr) *MockE
} }
// NextConnection mocks base method. // NextConnection mocks base method.
func (m *MockEarlyConnection) NextConnection() quic.Connection { func (m *MockEarlyConnection) NextConnection(arg0 context.Context) (quic.Connection, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "NextConnection") ret := m.ctrl.Call(m, "NextConnection", arg0)
ret0, _ := ret[0].(quic.Connection) ret0, _ := ret[0].(quic.Connection)
return ret0 ret1, _ := ret[1].(error)
return ret0, ret1
} }
// NextConnection indicates an expected call of NextConnection. // NextConnection indicates an expected call of NextConnection.
func (mr *MockEarlyConnectionMockRecorder) NextConnection() *MockEarlyConnectionNextConnectionCall { func (mr *MockEarlyConnectionMockRecorder) NextConnection(arg0 any) *MockEarlyConnectionNextConnectionCall {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NextConnection", reflect.TypeOf((*MockEarlyConnection)(nil).NextConnection)) call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NextConnection", reflect.TypeOf((*MockEarlyConnection)(nil).NextConnection), arg0)
return &MockEarlyConnectionNextConnectionCall{Call: call} return &MockEarlyConnectionNextConnectionCall{Call: call}
} }
@@ -331,19 +332,19 @@ type MockEarlyConnectionNextConnectionCall struct {
} }
// Return rewrite *gomock.Call.Return // Return rewrite *gomock.Call.Return
func (c *MockEarlyConnectionNextConnectionCall) Return(arg0 quic.Connection) *MockEarlyConnectionNextConnectionCall { func (c *MockEarlyConnectionNextConnectionCall) Return(arg0 quic.Connection, arg1 error) *MockEarlyConnectionNextConnectionCall {
c.Call = c.Call.Return(arg0) c.Call = c.Call.Return(arg0, arg1)
return c return c
} }
// Do rewrite *gomock.Call.Do // Do rewrite *gomock.Call.Do
func (c *MockEarlyConnectionNextConnectionCall) Do(f func() quic.Connection) *MockEarlyConnectionNextConnectionCall { func (c *MockEarlyConnectionNextConnectionCall) Do(f func(context.Context) (quic.Connection, error)) *MockEarlyConnectionNextConnectionCall {
c.Call = c.Call.Do(f) c.Call = c.Call.Do(f)
return c return c
} }
// DoAndReturn rewrite *gomock.Call.DoAndReturn // DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *MockEarlyConnectionNextConnectionCall) DoAndReturn(f func() quic.Connection) *MockEarlyConnectionNextConnectionCall { func (c *MockEarlyConnectionNextConnectionCall) DoAndReturn(f func(context.Context) (quic.Connection, error)) *MockEarlyConnectionNextConnectionCall {
c.Call = c.Call.DoAndReturn(f) c.Call = c.Call.DoAndReturn(f)
return c return c
} }

View File

@@ -310,17 +310,18 @@ func (c *MockQUICConnLocalAddrCall) DoAndReturn(f func() net.Addr) *MockQUICConn
} }
// NextConnection mocks base method. // NextConnection mocks base method.
func (m *MockQUICConn) NextConnection() Connection { func (m *MockQUICConn) NextConnection(arg0 context.Context) (Connection, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "NextConnection") ret := m.ctrl.Call(m, "NextConnection", arg0)
ret0, _ := ret[0].(Connection) ret0, _ := ret[0].(Connection)
return ret0 ret1, _ := ret[1].(error)
return ret0, ret1
} }
// NextConnection indicates an expected call of NextConnection. // NextConnection indicates an expected call of NextConnection.
func (mr *MockQUICConnMockRecorder) NextConnection() *MockQUICConnNextConnectionCall { func (mr *MockQUICConnMockRecorder) NextConnection(arg0 any) *MockQUICConnNextConnectionCall {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NextConnection", reflect.TypeOf((*MockQUICConn)(nil).NextConnection)) call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NextConnection", reflect.TypeOf((*MockQUICConn)(nil).NextConnection), arg0)
return &MockQUICConnNextConnectionCall{Call: call} return &MockQUICConnNextConnectionCall{Call: call}
} }
@@ -330,19 +331,19 @@ type MockQUICConnNextConnectionCall struct {
} }
// Return rewrite *gomock.Call.Return // Return rewrite *gomock.Call.Return
func (c *MockQUICConnNextConnectionCall) Return(arg0 Connection) *MockQUICConnNextConnectionCall { func (c *MockQUICConnNextConnectionCall) Return(arg0 Connection, arg1 error) *MockQUICConnNextConnectionCall {
c.Call = c.Call.Return(arg0) c.Call = c.Call.Return(arg0, arg1)
return c return c
} }
// Do rewrite *gomock.Call.Do // Do rewrite *gomock.Call.Do
func (c *MockQUICConnNextConnectionCall) Do(f func() Connection) *MockQUICConnNextConnectionCall { func (c *MockQUICConnNextConnectionCall) Do(f func(context.Context) (Connection, error)) *MockQUICConnNextConnectionCall {
c.Call = c.Call.Do(f) c.Call = c.Call.Do(f)
return c return c
} }
// DoAndReturn rewrite *gomock.Call.DoAndReturn // DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *MockQUICConnNextConnectionCall) DoAndReturn(f func() Connection) *MockQUICConnNextConnectionCall { func (c *MockQUICConnNextConnectionCall) DoAndReturn(f func(context.Context) (Connection, error)) *MockQUICConnNextConnectionCall {
c.Call = c.Call.DoAndReturn(f) c.Call = c.Call.DoAndReturn(f)
return c return c
} }