From 71a27d40c96a4d7a97fb9854e0f090c9b6f4f5e6 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Tue, 3 Dec 2024 11:10:13 +0800 Subject: [PATCH] wait until handshaking connections have terminated when closing server (#4743) --- integrationtests/self/handshake_test.go | 29 ++++++++++++++++++------- server.go | 7 +++++- 2 files changed, 27 insertions(+), 9 deletions(-) diff --git a/integrationtests/self/handshake_test.go b/integrationtests/self/handshake_test.go index e4a53991..5e038ab3 100644 --- a/integrationtests/self/handshake_test.go +++ b/integrationtests/self/handshake_test.go @@ -640,27 +640,40 @@ var _ = Describe("Handshake tests", func() { tr := &quic.Transport{Conn: udpConn} addTracer(tr) defer tr.Close() - tlsConf := &tls.Config{} - done := make(chan struct{}) - tlsConf.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) { - <-done - return nil, errors.New("closed") + + rtt := scaleDuration(40 * time.Millisecond) + connQueued := make(chan struct{}) + tlsConf := &tls.Config{ + GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) { + close(connQueued) + // Sleep for a bit. + // This allows the server to close the connection before the handshake completes. + time.Sleep(rtt / 2) + return getTLSConfig(), nil + }, } ln, err := tr.Listen(tlsConf, getQuicConfig(nil)) Expect(err).ToNot(HaveOccurred()) + serverPort := ln.Addr().(*net.UDPAddr).Port + proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ + RemoteAddr: fmt.Sprintf("localhost:%d", serverPort), + DelayPacket: func(quicproxy.Direction, []byte) time.Duration { return rtt / 2 }, + }) + Expect(err).ToNot(HaveOccurred()) + defer proxy.Close() errChan := make(chan error, 1) - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() go func() { defer GinkgoRecover() _, err := quic.DialAddr(ctx, ln.Addr().String(), getTLSClientConfig(), getQuicConfig(nil)) errChan <- err }() - time.Sleep(scaleDuration(20 * time.Millisecond)) // wait a bit for the connection to be queued + Eventually(connQueued, 5*rtt).Should(BeClosed()) Expect(ln.Close()).To(Succeed()) - close(done) err = <-errChan + Expect(err).To(HaveOccurred()) var transportErr *quic.TransportError Expect(errors.As(err, &transportErr)).To(BeTrue()) Expect(transportErr.ErrorCode).To(Equal(quic.ConnectionRefused)) diff --git a/server.go b/server.go index 0cf45aca..c9395edf 100644 --- a/server.go +++ b/server.go @@ -114,6 +114,7 @@ type baseServer struct { invalidTokenQueue chan rejectedPacket connectionRefusedQueue chan rejectedPacket retryQueue chan rejectedPacket + handshakingCount sync.WaitGroup verifySourceAddress func(net.Addr) bool @@ -339,6 +340,8 @@ func (s *baseServer) accept(ctx context.Context) (quicConn, error) { func (s *baseServer) Close() error { s.close(ErrServerClosed, true) + // wait until all handshakes in flight have terminated + s.handshakingCount.Wait() return nil } @@ -713,8 +716,9 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error delete(s.zeroRTTQueues, hdr.DestConnectionID) } - go conn.run() + s.handshakingCount.Add(1) go func() { + defer s.handshakingCount.Done() if completed := s.handleNewConn(conn); !completed { return } @@ -725,6 +729,7 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error conn.closeWithTransportError(ConnectionRefused) } }() + go conn.run() return nil }