add context to EarlyConnection.NextConnection, handle handshake failures (#4551)

This commit is contained in:
Marten Seemann
2024-06-05 11:51:54 +08:00
committed by GitHub
parent 0db354456a
commit 07acaad2f7
5 changed files with 34 additions and 24 deletions

View File

@@ -2424,10 +2424,17 @@ func (s *connection) GetVersion() protocol.Version {
return s.version
}
func (s *connection) NextConnection() Connection {
<-s.HandshakeComplete()
func (s *connection) NextConnection(ctx context.Context) (Connection, error) {
// The handshake might fail after the server rejected 0-RTT.
// This could happen if the Finished message is malformed or never received.
select {
case <-ctx.Done():
return nil, context.Cause(ctx)
case <-s.Context().Done():
case <-s.HandshakeComplete():
s.streamsMap.UseResetMaps()
return s
}
return s, nil
}
// estimateMaxPayloadSize estimates the maximum payload size for short header packets.

View File

@@ -871,7 +871,8 @@ var _ = Describe("0-RTT", func() {
_, err = conn.AcceptStream(ctx)
Expect(err).To(Equal(quic.Err0RTTRejected))
newConn := conn.NextConnection()
newConn, err := conn.NextConnection(context.Background())
Expect(err).ToNot(HaveOccurred())
str, err := newConn.OpenUniStream()
Expect(err).ToNot(HaveOccurred())
_, err = newConn.OpenUniStream()

View File

@@ -224,7 +224,7 @@ type EarlyConnection interface {
// however the client's identity is only verified once the handshake completes.
HandshakeComplete() <-chan struct{}
NextConnection() Connection
NextConnection(context.Context) (Connection, error)
}
// StatelessResetKey is a key used to derive stateless reset tokens.

View File

@@ -311,17 +311,18 @@ func (c *MockEarlyConnectionLocalAddrCall) DoAndReturn(f func() net.Addr) *MockE
}
// NextConnection mocks base method.
func (m *MockEarlyConnection) NextConnection() quic.Connection {
func (m *MockEarlyConnection) NextConnection(arg0 context.Context) (quic.Connection, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "NextConnection")
ret := m.ctrl.Call(m, "NextConnection", arg0)
ret0, _ := ret[0].(quic.Connection)
return ret0
ret1, _ := ret[1].(error)
return ret0, ret1
}
// NextConnection indicates an expected call of NextConnection.
func (mr *MockEarlyConnectionMockRecorder) NextConnection() *MockEarlyConnectionNextConnectionCall {
func (mr *MockEarlyConnectionMockRecorder) NextConnection(arg0 any) *MockEarlyConnectionNextConnectionCall {
mr.mock.ctrl.T.Helper()
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NextConnection", reflect.TypeOf((*MockEarlyConnection)(nil).NextConnection))
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NextConnection", reflect.TypeOf((*MockEarlyConnection)(nil).NextConnection), arg0)
return &MockEarlyConnectionNextConnectionCall{Call: call}
}
@@ -331,19 +332,19 @@ type MockEarlyConnectionNextConnectionCall struct {
}
// Return rewrite *gomock.Call.Return
func (c *MockEarlyConnectionNextConnectionCall) Return(arg0 quic.Connection) *MockEarlyConnectionNextConnectionCall {
c.Call = c.Call.Return(arg0)
func (c *MockEarlyConnectionNextConnectionCall) Return(arg0 quic.Connection, arg1 error) *MockEarlyConnectionNextConnectionCall {
c.Call = c.Call.Return(arg0, arg1)
return c
}
// Do rewrite *gomock.Call.Do
func (c *MockEarlyConnectionNextConnectionCall) Do(f func() quic.Connection) *MockEarlyConnectionNextConnectionCall {
func (c *MockEarlyConnectionNextConnectionCall) Do(f func(context.Context) (quic.Connection, error)) *MockEarlyConnectionNextConnectionCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *MockEarlyConnectionNextConnectionCall) DoAndReturn(f func() quic.Connection) *MockEarlyConnectionNextConnectionCall {
func (c *MockEarlyConnectionNextConnectionCall) DoAndReturn(f func(context.Context) (quic.Connection, error)) *MockEarlyConnectionNextConnectionCall {
c.Call = c.Call.DoAndReturn(f)
return c
}

View File

@@ -310,17 +310,18 @@ func (c *MockQUICConnLocalAddrCall) DoAndReturn(f func() net.Addr) *MockQUICConn
}
// NextConnection mocks base method.
func (m *MockQUICConn) NextConnection() Connection {
func (m *MockQUICConn) NextConnection(arg0 context.Context) (Connection, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "NextConnection")
ret := m.ctrl.Call(m, "NextConnection", arg0)
ret0, _ := ret[0].(Connection)
return ret0
ret1, _ := ret[1].(error)
return ret0, ret1
}
// NextConnection indicates an expected call of NextConnection.
func (mr *MockQUICConnMockRecorder) NextConnection() *MockQUICConnNextConnectionCall {
func (mr *MockQUICConnMockRecorder) NextConnection(arg0 any) *MockQUICConnNextConnectionCall {
mr.mock.ctrl.T.Helper()
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NextConnection", reflect.TypeOf((*MockQUICConn)(nil).NextConnection))
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NextConnection", reflect.TypeOf((*MockQUICConn)(nil).NextConnection), arg0)
return &MockQUICConnNextConnectionCall{Call: call}
}
@@ -330,19 +331,19 @@ type MockQUICConnNextConnectionCall struct {
}
// Return rewrite *gomock.Call.Return
func (c *MockQUICConnNextConnectionCall) Return(arg0 Connection) *MockQUICConnNextConnectionCall {
c.Call = c.Call.Return(arg0)
func (c *MockQUICConnNextConnectionCall) Return(arg0 Connection, arg1 error) *MockQUICConnNextConnectionCall {
c.Call = c.Call.Return(arg0, arg1)
return c
}
// Do rewrite *gomock.Call.Do
func (c *MockQUICConnNextConnectionCall) Do(f func() Connection) *MockQUICConnNextConnectionCall {
func (c *MockQUICConnNextConnectionCall) Do(f func(context.Context) (Connection, error)) *MockQUICConnNextConnectionCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *MockQUICConnNextConnectionCall) DoAndReturn(f func() Connection) *MockQUICConnNextConnectionCall {
func (c *MockQUICConnNextConnectionCall) DoAndReturn(f func(context.Context) (Connection, error)) *MockQUICConnNextConnectionCall {
c.Call = c.Call.DoAndReturn(f)
return c
}