use a separate function to close the session after VN and retry

This commit is contained in:
Marten Seemann
2018-12-21 23:47:10 +06:30
parent b166757fd9
commit f9218444a9
6 changed files with 54 additions and 17 deletions

View File

@@ -3,7 +3,6 @@ package quic
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"sync"
@@ -54,8 +53,6 @@ var (
// make it possible to mock connection ID generation in the tests
generateConnectionID = protocol.GenerateConnectionID
generateConnectionIDForInitial = protocol.GenerateConnectionIDForInitial
errCloseSessionForNewVersion = errors.New("closing session in order to recreate it with a new version")
errCloseSessionForRetry = errors.New("closing session in response to a stateless retry")
)
// DialAddr establishes a new QUIC connection to a server.
@@ -255,7 +252,7 @@ func (c *client) dial(ctx context.Context) error {
return err
}
err := c.establishSecureConnection(ctx)
if err == errCloseSessionForRetry || err == errCloseSessionForNewVersion {
if err == errCloseForRecreating {
return c.dial(ctx)
}
return err
@@ -263,8 +260,7 @@ func (c *client) dial(ctx context.Context) error {
// establishSecureConnection runs the session, and tries to establish a secure connection
// It returns:
// - errCloseSessionForNewVersion when the server sends a version negotiation packet
// - handshake.ErrCloseSessionForRetry when the server performs a stateless retry
// - errCloseSessionRecreating when the server sends a version negotiation packet, or a stateless retry is performed
// - any other error that might occur
// - when the connection is forward-secure
func (c *client) establishSecureConnection(ctx context.Context) error {
@@ -272,7 +268,7 @@ func (c *client) establishSecureConnection(ctx context.Context) error {
go func() {
err := c.session.run() // returns as soon as the session is closed
if err != errCloseSessionForRetry && err != errCloseSessionForNewVersion && c.createdPacketConn {
if err != errCloseForRecreating && c.createdPacketConn {
c.conn.Close()
}
errorChan <- err
@@ -344,7 +340,7 @@ func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) {
c.version = newVersion
c.logger.Infof("Switching to QUIC version %s. New connection ID: %s", newVersion, c.destConnID)
c.session.destroy(errCloseSessionForNewVersion)
c.session.closeForRecreating()
}
func (c *client) handleRetryPacket(hdr *wire.Header) {
@@ -370,7 +366,7 @@ func (c *client) handleRetryPacket(hdr *wire.Header) {
c.origDestConnID = c.destConnID
c.destConnID = hdr.SrcConnectionID
c.token = hdr.Token
c.session.destroy(errCloseSessionForRetry)
c.session.closeForRecreating()
}
func (c *client) createNewTLSSession(version protocol.VersionNumber) error {

View File

@@ -530,8 +530,8 @@ var _ = Describe("Client", func() {
sess1.EXPECT().run().DoAndReturn(func() error {
return <-run1
})
sess1.EXPECT().destroy(errCloseSessionForRetry).Do(func(e error) {
run1 <- e
sess1.EXPECT().closeForRecreating().Do(func() {
run1 <- errCloseForRecreating
})
sess2 := NewMockQuicSession(mockCtrl)
sess2.EXPECT().run()
@@ -597,8 +597,8 @@ var _ = Describe("Client", func() {
Eventually(run).Should(Receive(&err))
return err
})
sess.EXPECT().destroy(gomock.Any()).Do(func(e error) {
run <- e
sess.EXPECT().closeForRecreating().Do(func() {
run <- errCloseForRecreating
})
sessions <- sess
doneErr := errors.New("nothing to do")
@@ -717,7 +717,7 @@ var _ = Describe("Client", func() {
sess := NewMockQuicSession(mockCtrl)
destroyed := make(chan struct{})
sess.EXPECT().destroy(errCloseSessionForNewVersion).Do(func(error) {
sess.EXPECT().closeForRecreating().Do(func() {
close(destroyed)
})
cl.session = sess

View File

@@ -199,6 +199,16 @@ func (mr *MockQuicSessionMockRecorder) RemoteAddr() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoteAddr", reflect.TypeOf((*MockQuicSession)(nil).RemoteAddr))
}
// closeForRecreating mocks base method
func (m *MockQuicSession) closeForRecreating() {
m.ctrl.Call(m, "closeForRecreating")
}
// closeForRecreating indicates an expected call of closeForRecreating
func (mr *MockQuicSessionMockRecorder) closeForRecreating() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "closeForRecreating", reflect.TypeOf((*MockQuicSession)(nil).closeForRecreating))
}
// closeRemote mocks base method
func (m *MockQuicSession) closeRemote(arg0 error) {
m.ctrl.Call(m, "closeRemote", arg0)

View File

@@ -43,6 +43,7 @@ type quicSession interface {
GetVersion() protocol.VersionNumber
run() error
destroy(error)
closeForRecreating()
closeRemote(error)
}

View File

@@ -63,6 +63,8 @@ type closeError struct {
sendClose bool
}
var errCloseForRecreating = errors.New("closing session in order to recreate it")
// A Session is a QUIC session
type session struct {
sessionRunner sessionRunner
@@ -718,6 +720,10 @@ func (s *session) destroy(e error) {
})
}
func (s *session) closeForRecreating() {
s.destroy(errCloseForRecreating)
}
func (s *session) closeRemote(e error) {
s.closeOnce.Do(func() {
s.sessionRunner.removeConnectionID(s.srcConnID)

View File

@@ -352,14 +352,26 @@ var _ = Describe("Session", func() {
})
Context("closing", func() {
var (
runErr error
expectedRunErr error
)
BeforeEach(func() {
Eventually(areSessionsRunning).Should(BeFalse())
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().RunHandshake().Do(func() { <-sess.Context().Done() })
sess.run()
runErr = sess.run()
}()
Eventually(areSessionsRunning).Should(BeTrue())
expectedRunErr = nil
})
AfterEach(func() {
if expectedRunErr != nil {
Expect(runErr).To(MatchError(expectedRunErr))
}
})
It("shuts down without error", func() {
@@ -397,13 +409,25 @@ var _ = Describe("Session", func() {
Expect(sess.Context().Done()).To(BeClosed())
})
It("closes the session in order to replace it with another QUIC version", func() {
It("closes the session in order to recreate it", func() {
streamManager.EXPECT().CloseWithError(gomock.Any())
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
cryptoSetup.EXPECT().Close()
sess.destroy(errCloseSessionForNewVersion)
sess.closeForRecreating()
Expect(mconn.written).To(BeEmpty()) // no CONNECTION_CLOSE or PUBLIC_RESET sent
Eventually(areSessionsRunning).Should(BeFalse())
expectedRunErr = errCloseForRecreating
})
It("destroys the session", func() {
testErr := errors.New("close")
streamManager.EXPECT().CloseWithError(gomock.Any())
sessionRunner.EXPECT().removeConnectionID(gomock.Any())
cryptoSetup.EXPECT().Close()
sess.destroy(testErr)
Eventually(areSessionsRunning).Should(BeFalse())
Expect(mconn.written).To(BeEmpty()) // no CONNECTION_CLOSE or PUBLIC_RESET sent
expectedRunErr = testErr
})
It("cancels the context when the run loop exists", func() {