forked from quic-go/quic-go
use a separate function to close the session after VN and retry
This commit is contained in:
14
client.go
14
client.go
@@ -3,7 +3,6 @@ package quic
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
@@ -54,8 +53,6 @@ var (
|
||||
// make it possible to mock connection ID generation in the tests
|
||||
generateConnectionID = protocol.GenerateConnectionID
|
||||
generateConnectionIDForInitial = protocol.GenerateConnectionIDForInitial
|
||||
errCloseSessionForNewVersion = errors.New("closing session in order to recreate it with a new version")
|
||||
errCloseSessionForRetry = errors.New("closing session in response to a stateless retry")
|
||||
)
|
||||
|
||||
// DialAddr establishes a new QUIC connection to a server.
|
||||
@@ -255,7 +252,7 @@ func (c *client) dial(ctx context.Context) error {
|
||||
return err
|
||||
}
|
||||
err := c.establishSecureConnection(ctx)
|
||||
if err == errCloseSessionForRetry || err == errCloseSessionForNewVersion {
|
||||
if err == errCloseForRecreating {
|
||||
return c.dial(ctx)
|
||||
}
|
||||
return err
|
||||
@@ -263,8 +260,7 @@ func (c *client) dial(ctx context.Context) error {
|
||||
|
||||
// establishSecureConnection runs the session, and tries to establish a secure connection
|
||||
// It returns:
|
||||
// - errCloseSessionForNewVersion when the server sends a version negotiation packet
|
||||
// - handshake.ErrCloseSessionForRetry when the server performs a stateless retry
|
||||
// - errCloseSessionRecreating when the server sends a version negotiation packet, or a stateless retry is performed
|
||||
// - any other error that might occur
|
||||
// - when the connection is forward-secure
|
||||
func (c *client) establishSecureConnection(ctx context.Context) error {
|
||||
@@ -272,7 +268,7 @@ func (c *client) establishSecureConnection(ctx context.Context) error {
|
||||
|
||||
go func() {
|
||||
err := c.session.run() // returns as soon as the session is closed
|
||||
if err != errCloseSessionForRetry && err != errCloseSessionForNewVersion && c.createdPacketConn {
|
||||
if err != errCloseForRecreating && c.createdPacketConn {
|
||||
c.conn.Close()
|
||||
}
|
||||
errorChan <- err
|
||||
@@ -344,7 +340,7 @@ func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) {
|
||||
c.version = newVersion
|
||||
|
||||
c.logger.Infof("Switching to QUIC version %s. New connection ID: %s", newVersion, c.destConnID)
|
||||
c.session.destroy(errCloseSessionForNewVersion)
|
||||
c.session.closeForRecreating()
|
||||
}
|
||||
|
||||
func (c *client) handleRetryPacket(hdr *wire.Header) {
|
||||
@@ -370,7 +366,7 @@ func (c *client) handleRetryPacket(hdr *wire.Header) {
|
||||
c.origDestConnID = c.destConnID
|
||||
c.destConnID = hdr.SrcConnectionID
|
||||
c.token = hdr.Token
|
||||
c.session.destroy(errCloseSessionForRetry)
|
||||
c.session.closeForRecreating()
|
||||
}
|
||||
|
||||
func (c *client) createNewTLSSession(version protocol.VersionNumber) error {
|
||||
|
||||
@@ -530,8 +530,8 @@ var _ = Describe("Client", func() {
|
||||
sess1.EXPECT().run().DoAndReturn(func() error {
|
||||
return <-run1
|
||||
})
|
||||
sess1.EXPECT().destroy(errCloseSessionForRetry).Do(func(e error) {
|
||||
run1 <- e
|
||||
sess1.EXPECT().closeForRecreating().Do(func() {
|
||||
run1 <- errCloseForRecreating
|
||||
})
|
||||
sess2 := NewMockQuicSession(mockCtrl)
|
||||
sess2.EXPECT().run()
|
||||
@@ -597,8 +597,8 @@ var _ = Describe("Client", func() {
|
||||
Eventually(run).Should(Receive(&err))
|
||||
return err
|
||||
})
|
||||
sess.EXPECT().destroy(gomock.Any()).Do(func(e error) {
|
||||
run <- e
|
||||
sess.EXPECT().closeForRecreating().Do(func() {
|
||||
run <- errCloseForRecreating
|
||||
})
|
||||
sessions <- sess
|
||||
doneErr := errors.New("nothing to do")
|
||||
@@ -717,7 +717,7 @@ var _ = Describe("Client", func() {
|
||||
|
||||
sess := NewMockQuicSession(mockCtrl)
|
||||
destroyed := make(chan struct{})
|
||||
sess.EXPECT().destroy(errCloseSessionForNewVersion).Do(func(error) {
|
||||
sess.EXPECT().closeForRecreating().Do(func() {
|
||||
close(destroyed)
|
||||
})
|
||||
cl.session = sess
|
||||
|
||||
@@ -199,6 +199,16 @@ func (mr *MockQuicSessionMockRecorder) RemoteAddr() *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoteAddr", reflect.TypeOf((*MockQuicSession)(nil).RemoteAddr))
|
||||
}
|
||||
|
||||
// closeForRecreating mocks base method
|
||||
func (m *MockQuicSession) closeForRecreating() {
|
||||
m.ctrl.Call(m, "closeForRecreating")
|
||||
}
|
||||
|
||||
// closeForRecreating indicates an expected call of closeForRecreating
|
||||
func (mr *MockQuicSessionMockRecorder) closeForRecreating() *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "closeForRecreating", reflect.TypeOf((*MockQuicSession)(nil).closeForRecreating))
|
||||
}
|
||||
|
||||
// closeRemote mocks base method
|
||||
func (m *MockQuicSession) closeRemote(arg0 error) {
|
||||
m.ctrl.Call(m, "closeRemote", arg0)
|
||||
|
||||
@@ -43,6 +43,7 @@ type quicSession interface {
|
||||
GetVersion() protocol.VersionNumber
|
||||
run() error
|
||||
destroy(error)
|
||||
closeForRecreating()
|
||||
closeRemote(error)
|
||||
}
|
||||
|
||||
|
||||
@@ -63,6 +63,8 @@ type closeError struct {
|
||||
sendClose bool
|
||||
}
|
||||
|
||||
var errCloseForRecreating = errors.New("closing session in order to recreate it")
|
||||
|
||||
// A Session is a QUIC session
|
||||
type session struct {
|
||||
sessionRunner sessionRunner
|
||||
@@ -718,6 +720,10 @@ func (s *session) destroy(e error) {
|
||||
})
|
||||
}
|
||||
|
||||
func (s *session) closeForRecreating() {
|
||||
s.destroy(errCloseForRecreating)
|
||||
}
|
||||
|
||||
func (s *session) closeRemote(e error) {
|
||||
s.closeOnce.Do(func() {
|
||||
s.sessionRunner.removeConnectionID(s.srcConnID)
|
||||
|
||||
@@ -352,14 +352,26 @@ var _ = Describe("Session", func() {
|
||||
})
|
||||
|
||||
Context("closing", func() {
|
||||
var (
|
||||
runErr error
|
||||
expectedRunErr error
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
Eventually(areSessionsRunning).Should(BeFalse())
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
cryptoSetup.EXPECT().RunHandshake().Do(func() { <-sess.Context().Done() })
|
||||
sess.run()
|
||||
runErr = sess.run()
|
||||
}()
|
||||
Eventually(areSessionsRunning).Should(BeTrue())
|
||||
expectedRunErr = nil
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
if expectedRunErr != nil {
|
||||
Expect(runErr).To(MatchError(expectedRunErr))
|
||||
}
|
||||
})
|
||||
|
||||
It("shuts down without error", func() {
|
||||
@@ -397,13 +409,25 @@ var _ = Describe("Session", func() {
|
||||
Expect(sess.Context().Done()).To(BeClosed())
|
||||
})
|
||||
|
||||
It("closes the session in order to replace it with another QUIC version", func() {
|
||||
It("closes the session in order to recreate it", func() {
|
||||
streamManager.EXPECT().CloseWithError(gomock.Any())
|
||||
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
|
||||
cryptoSetup.EXPECT().Close()
|
||||
sess.destroy(errCloseSessionForNewVersion)
|
||||
sess.closeForRecreating()
|
||||
Expect(mconn.written).To(BeEmpty()) // no CONNECTION_CLOSE or PUBLIC_RESET sent
|
||||
Eventually(areSessionsRunning).Should(BeFalse())
|
||||
expectedRunErr = errCloseForRecreating
|
||||
})
|
||||
|
||||
It("destroys the session", func() {
|
||||
testErr := errors.New("close")
|
||||
streamManager.EXPECT().CloseWithError(gomock.Any())
|
||||
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
|
||||
cryptoSetup.EXPECT().Close()
|
||||
sess.destroy(testErr)
|
||||
Eventually(areSessionsRunning).Should(BeFalse())
|
||||
Expect(mconn.written).To(BeEmpty()) // no CONNECTION_CLOSE or PUBLIC_RESET sent
|
||||
expectedRunErr = testErr
|
||||
})
|
||||
|
||||
It("cancels the context when the run loop exists", func() {
|
||||
|
||||
Reference in New Issue
Block a user