diff --git a/server.go b/server.go index f25796a7..e6317556 100644 --- a/server.go +++ b/server.go @@ -67,6 +67,9 @@ type server struct { config *Config conn net.PacketConn + // If the server is started with ListenAddr, we create a packet conn. + // If it is started with Listen, we take a packet conn as a parameter. + createdPacketConn bool supportsTLS bool serverTLS *serverTLS @@ -102,12 +105,21 @@ func ListenAddr(addr string, tlsConf *tls.Config, config *Config) (Listener, err if err != nil { return nil, err } - return Listen(conn, tlsConf, config) + serv, err := listen(conn, tlsConf, config) + if err != nil { + return nil, err + } + serv.createdPacketConn = true + return serv, nil } // Listen listens for QUIC connections on a given net.PacketConn. // The tls.Config must not be nil, the quic.Config may be nil. func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener, error) { + return listen(conn, tlsConf, config) +} + +func listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (*server, error) { certChain := crypto.NewCertChain(tlsConf) kex, err := crypto.NewCurve25519KEX() if err != nil { @@ -287,12 +299,17 @@ func (s *server) Accept() (Session, error) { // Close the server func (s *server) Close() error { s.sessionHandler.CloseServer() - // TODO: close the conn if this server was started with ListenAddr() (but not with Listen(net.PacketConn)) if s.serverError == nil { s.serverError = errors.New("server closed") } + var err error + // If the server was started with ListenAddr, we created the packet conn. + // We need to close it in order to make the go routine reading from that conn return. + if s.createdPacketConn { + err = s.conn.Close() + } close(s.errorChan) - return nil + return err } // Addr returns the server's network address diff --git a/server_test.go b/server_test.go index ac55f598..10f58ced 100644 --- a/server_test.go +++ b/server_test.go @@ -254,6 +254,21 @@ var _ = Describe("Server", func() { Eventually(done).Should(BeClosed()) }) + It("closes the connection when it was created with ListenAddr", func() { + addr, err := net.ResolveUDPAddr("udp", "localhost:12345") + Expect(err).ToNot(HaveOccurred()) + + serv, err := ListenAddr("localhost:0", nil, nil) + Expect(err).ToNot(HaveOccurred()) + // test that we can write on the packet conn + _, err = serv.(*server).conn.WriteTo([]byte("foobar"), addr) + Expect(err).ToNot(HaveOccurred()) + Expect(serv.Close()).To(Succeed()) + // test that we can't write any more on the packet conn + _, err = serv.(*server).conn.WriteTo([]byte("foobar"), addr) + Expect(err.Error()).To(ContainSubstring("use of closed network connection")) + }) + It("returns Accept when it is closed", func() { done := make(chan struct{}) go func() {