wait for serve() to return before returning from server.Close()

This commit is contained in:
Marten Seemann
2017-12-20 14:20:46 +07:00
parent 1692958f10
commit 96571b56e5
3 changed files with 22 additions and 11 deletions

View File

@@ -20,13 +20,16 @@ var _ = Describe("Handshake tests", func() {
)
BeforeEach(func() {
server = nil
acceptStopped = make(chan struct{})
serverConfig = &quic.Config{}
})
AfterEach(func() {
Expect(server.Close()).To(Succeed())
<-acceptStopped
if server != nil {
server.Close()
<-acceptStopped
}
})
runServer := func() {
@@ -60,7 +63,7 @@ var _ = Describe("Handshake tests", func() {
Expect(err).ToNot(HaveOccurred())
})
It("when the client supports more versions than the supports", func() {
It("when the client supports more versions than the server supports", func() {
if len(protocol.SupportedVersions) == 1 {
Skip("Test requires at least 2 supported versions.")
}

View File

@@ -40,15 +40,17 @@ type server struct {
certChain crypto.CertChain
scfg *handshake.ServerConfig
sessions map[protocol.ConnectionID]packetHandler
sessionsMutex sync.RWMutex
deleteClosedSessionsAfter time.Duration
sessionsMutex sync.RWMutex
sessions map[protocol.ConnectionID]packetHandler
closed bool
serverError error
sessionQueue chan Session
errorChan chan struct{}
newSession func(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, tlsConf *tls.Config, config *Config) (packetHandler, error)
// set as members, so they can be set in the tests
newSession func(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, tlsConf *tls.Config, config *Config) (packetHandler, error)
deleteClosedSessionsAfter time.Duration
}
var _ Listener = &server{}
@@ -240,6 +242,12 @@ func (s *server) Accept() (Session, error) {
// Close the server
func (s *server) Close() error {
s.sessionsMutex.Lock()
if s.closed {
s.sessionsMutex.Unlock()
return nil
}
s.closed = true
var wg sync.WaitGroup
for _, session := range s.sessions {
if session != nil {
@@ -254,10 +262,9 @@ func (s *server) Close() error {
s.sessionsMutex.Unlock()
wg.Wait()
if s.conn == nil {
return nil
}
return s.conn.Close()
err := s.conn.Close()
<-s.errorChan // wait for serve() to return
return err
}
// Addr returns the server's network address

View File

@@ -222,6 +222,7 @@ var _ = Describe("Server", func() {
})
It("closes sessions and the connection when Close is called", func() {
go serv.serve()
session, _ := newMockSession(nil, 0, 0, nil, nil, nil)
serv.sessions[1] = session
err := serv.Close()