forked from quic-go/quic-go
add context to EarlyConnection.NextConnection, handle handshake failures (#4551)
This commit is contained in:
@@ -2424,10 +2424,17 @@ func (s *connection) GetVersion() protocol.Version {
|
|||||||
return s.version
|
return s.version
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *connection) NextConnection() Connection {
|
func (s *connection) NextConnection(ctx context.Context) (Connection, error) {
|
||||||
<-s.HandshakeComplete()
|
// The handshake might fail after the server rejected 0-RTT.
|
||||||
|
// This could happen if the Finished message is malformed or never received.
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, context.Cause(ctx)
|
||||||
|
case <-s.Context().Done():
|
||||||
|
case <-s.HandshakeComplete():
|
||||||
s.streamsMap.UseResetMaps()
|
s.streamsMap.UseResetMaps()
|
||||||
return s
|
}
|
||||||
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// estimateMaxPayloadSize estimates the maximum payload size for short header packets.
|
// estimateMaxPayloadSize estimates the maximum payload size for short header packets.
|
||||||
|
|||||||
@@ -871,7 +871,8 @@ var _ = Describe("0-RTT", func() {
|
|||||||
_, err = conn.AcceptStream(ctx)
|
_, err = conn.AcceptStream(ctx)
|
||||||
Expect(err).To(Equal(quic.Err0RTTRejected))
|
Expect(err).To(Equal(quic.Err0RTTRejected))
|
||||||
|
|
||||||
newConn := conn.NextConnection()
|
newConn, err := conn.NextConnection(context.Background())
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
str, err := newConn.OpenUniStream()
|
str, err := newConn.OpenUniStream()
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
_, err = newConn.OpenUniStream()
|
_, err = newConn.OpenUniStream()
|
||||||
|
|||||||
@@ -224,7 +224,7 @@ type EarlyConnection interface {
|
|||||||
// however the client's identity is only verified once the handshake completes.
|
// however the client's identity is only verified once the handshake completes.
|
||||||
HandshakeComplete() <-chan struct{}
|
HandshakeComplete() <-chan struct{}
|
||||||
|
|
||||||
NextConnection() Connection
|
NextConnection(context.Context) (Connection, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// StatelessResetKey is a key used to derive stateless reset tokens.
|
// StatelessResetKey is a key used to derive stateless reset tokens.
|
||||||
|
|||||||
@@ -311,17 +311,18 @@ func (c *MockEarlyConnectionLocalAddrCall) DoAndReturn(f func() net.Addr) *MockE
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NextConnection mocks base method.
|
// NextConnection mocks base method.
|
||||||
func (m *MockEarlyConnection) NextConnection() quic.Connection {
|
func (m *MockEarlyConnection) NextConnection(arg0 context.Context) (quic.Connection, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "NextConnection")
|
ret := m.ctrl.Call(m, "NextConnection", arg0)
|
||||||
ret0, _ := ret[0].(quic.Connection)
|
ret0, _ := ret[0].(quic.Connection)
|
||||||
return ret0
|
ret1, _ := ret[1].(error)
|
||||||
|
return ret0, ret1
|
||||||
}
|
}
|
||||||
|
|
||||||
// NextConnection indicates an expected call of NextConnection.
|
// NextConnection indicates an expected call of NextConnection.
|
||||||
func (mr *MockEarlyConnectionMockRecorder) NextConnection() *MockEarlyConnectionNextConnectionCall {
|
func (mr *MockEarlyConnectionMockRecorder) NextConnection(arg0 any) *MockEarlyConnectionNextConnectionCall {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NextConnection", reflect.TypeOf((*MockEarlyConnection)(nil).NextConnection))
|
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NextConnection", reflect.TypeOf((*MockEarlyConnection)(nil).NextConnection), arg0)
|
||||||
return &MockEarlyConnectionNextConnectionCall{Call: call}
|
return &MockEarlyConnectionNextConnectionCall{Call: call}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -331,19 +332,19 @@ type MockEarlyConnectionNextConnectionCall struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Return rewrite *gomock.Call.Return
|
// Return rewrite *gomock.Call.Return
|
||||||
func (c *MockEarlyConnectionNextConnectionCall) Return(arg0 quic.Connection) *MockEarlyConnectionNextConnectionCall {
|
func (c *MockEarlyConnectionNextConnectionCall) Return(arg0 quic.Connection, arg1 error) *MockEarlyConnectionNextConnectionCall {
|
||||||
c.Call = c.Call.Return(arg0)
|
c.Call = c.Call.Return(arg0, arg1)
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
// Do rewrite *gomock.Call.Do
|
// Do rewrite *gomock.Call.Do
|
||||||
func (c *MockEarlyConnectionNextConnectionCall) Do(f func() quic.Connection) *MockEarlyConnectionNextConnectionCall {
|
func (c *MockEarlyConnectionNextConnectionCall) Do(f func(context.Context) (quic.Connection, error)) *MockEarlyConnectionNextConnectionCall {
|
||||||
c.Call = c.Call.Do(f)
|
c.Call = c.Call.Do(f)
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
// DoAndReturn rewrite *gomock.Call.DoAndReturn
|
// DoAndReturn rewrite *gomock.Call.DoAndReturn
|
||||||
func (c *MockEarlyConnectionNextConnectionCall) DoAndReturn(f func() quic.Connection) *MockEarlyConnectionNextConnectionCall {
|
func (c *MockEarlyConnectionNextConnectionCall) DoAndReturn(f func(context.Context) (quic.Connection, error)) *MockEarlyConnectionNextConnectionCall {
|
||||||
c.Call = c.Call.DoAndReturn(f)
|
c.Call = c.Call.DoAndReturn(f)
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -310,17 +310,18 @@ func (c *MockQUICConnLocalAddrCall) DoAndReturn(f func() net.Addr) *MockQUICConn
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NextConnection mocks base method.
|
// NextConnection mocks base method.
|
||||||
func (m *MockQUICConn) NextConnection() Connection {
|
func (m *MockQUICConn) NextConnection(arg0 context.Context) (Connection, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "NextConnection")
|
ret := m.ctrl.Call(m, "NextConnection", arg0)
|
||||||
ret0, _ := ret[0].(Connection)
|
ret0, _ := ret[0].(Connection)
|
||||||
return ret0
|
ret1, _ := ret[1].(error)
|
||||||
|
return ret0, ret1
|
||||||
}
|
}
|
||||||
|
|
||||||
// NextConnection indicates an expected call of NextConnection.
|
// NextConnection indicates an expected call of NextConnection.
|
||||||
func (mr *MockQUICConnMockRecorder) NextConnection() *MockQUICConnNextConnectionCall {
|
func (mr *MockQUICConnMockRecorder) NextConnection(arg0 any) *MockQUICConnNextConnectionCall {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NextConnection", reflect.TypeOf((*MockQUICConn)(nil).NextConnection))
|
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NextConnection", reflect.TypeOf((*MockQUICConn)(nil).NextConnection), arg0)
|
||||||
return &MockQUICConnNextConnectionCall{Call: call}
|
return &MockQUICConnNextConnectionCall{Call: call}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -330,19 +331,19 @@ type MockQUICConnNextConnectionCall struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Return rewrite *gomock.Call.Return
|
// Return rewrite *gomock.Call.Return
|
||||||
func (c *MockQUICConnNextConnectionCall) Return(arg0 Connection) *MockQUICConnNextConnectionCall {
|
func (c *MockQUICConnNextConnectionCall) Return(arg0 Connection, arg1 error) *MockQUICConnNextConnectionCall {
|
||||||
c.Call = c.Call.Return(arg0)
|
c.Call = c.Call.Return(arg0, arg1)
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
// Do rewrite *gomock.Call.Do
|
// Do rewrite *gomock.Call.Do
|
||||||
func (c *MockQUICConnNextConnectionCall) Do(f func() Connection) *MockQUICConnNextConnectionCall {
|
func (c *MockQUICConnNextConnectionCall) Do(f func(context.Context) (Connection, error)) *MockQUICConnNextConnectionCall {
|
||||||
c.Call = c.Call.Do(f)
|
c.Call = c.Call.Do(f)
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
// DoAndReturn rewrite *gomock.Call.DoAndReturn
|
// DoAndReturn rewrite *gomock.Call.DoAndReturn
|
||||||
func (c *MockQUICConnNextConnectionCall) DoAndReturn(f func() Connection) *MockQUICConnNextConnectionCall {
|
func (c *MockQUICConnNextConnectionCall) DoAndReturn(f func(context.Context) (Connection, error)) *MockQUICConnNextConnectionCall {
|
||||||
c.Call = c.Call.DoAndReturn(f)
|
c.Call = c.Call.DoAndReturn(f)
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user