diff --git a/integrationtests/self/handshake_test.go b/integrationtests/self/handshake_test.go index b17104c9..44b508ad 100644 --- a/integrationtests/self/handshake_test.go +++ b/integrationtests/self/handshake_test.go @@ -20,13 +20,16 @@ var _ = Describe("Handshake tests", func() { ) BeforeEach(func() { + server = nil acceptStopped = make(chan struct{}) serverConfig = &quic.Config{} }) AfterEach(func() { - Expect(server.Close()).To(Succeed()) - <-acceptStopped + if server != nil { + server.Close() + <-acceptStopped + } }) runServer := func() { @@ -60,7 +63,7 @@ var _ = Describe("Handshake tests", func() { Expect(err).ToNot(HaveOccurred()) }) - It("when the client supports more versions than the supports", func() { + It("when the client supports more versions than the server supports", func() { if len(protocol.SupportedVersions) == 1 { Skip("Test requires at least 2 supported versions.") } diff --git a/server.go b/server.go index 03766f3e..e7ddae8f 100644 --- a/server.go +++ b/server.go @@ -40,15 +40,17 @@ type server struct { certChain crypto.CertChain scfg *handshake.ServerConfig - sessions map[protocol.ConnectionID]packetHandler - sessionsMutex sync.RWMutex - deleteClosedSessionsAfter time.Duration + sessionsMutex sync.RWMutex + sessions map[protocol.ConnectionID]packetHandler + closed bool serverError error sessionQueue chan Session errorChan chan struct{} - newSession func(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, tlsConf *tls.Config, config *Config) (packetHandler, error) + // set as members, so they can be set in the tests + newSession func(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, tlsConf *tls.Config, config *Config) (packetHandler, error) + deleteClosedSessionsAfter time.Duration } var _ Listener = &server{} @@ -240,6 +242,12 @@ func (s *server) Accept() (Session, error) { // Close the server func (s *server) Close() error { s.sessionsMutex.Lock() + if s.closed { + s.sessionsMutex.Unlock() + return nil + } + s.closed = true + var wg sync.WaitGroup for _, session := range s.sessions { if session != nil { @@ -254,10 +262,9 @@ func (s *server) Close() error { s.sessionsMutex.Unlock() wg.Wait() - if s.conn == nil { - return nil - } - return s.conn.Close() + err := s.conn.Close() + <-s.errorChan // wait for serve() to return + return err } // Addr returns the server's network address diff --git a/server_test.go b/server_test.go index b82a14e4..1ca7e652 100644 --- a/server_test.go +++ b/server_test.go @@ -222,6 +222,7 @@ var _ = Describe("Server", func() { }) It("closes sessions and the connection when Close is called", func() { + go serv.serve() session, _ := newMockSession(nil, 0, 0, nil, nil, nil) serv.sessions[1] = session err := serv.Close()