From 420f852f8608cc0e94f1cd90162eb51108d8ca9d Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Wed, 8 Jan 2025 21:59:05 +0800 Subject: [PATCH] drain server's accept queue before returning ErrClosed from Accept (#4846) --- integrationtests/self/close_test.go | 34 ++++++++++++++ server.go | 69 ++++++++++++++++++----------- 2 files changed, 76 insertions(+), 27 deletions(-) diff --git a/integrationtests/self/close_test.go b/integrationtests/self/close_test.go index b03b72ea..583a804d 100644 --- a/integrationtests/self/close_test.go +++ b/integrationtests/self/close_test.go @@ -10,6 +10,7 @@ import ( "github.com/quic-go/quic-go" quicproxy "github.com/quic-go/quic-go/integrationtests/tools/proxy" + "github.com/quic-go/quic-go/internal/protocol" "github.com/stretchr/testify/require" ) @@ -79,3 +80,36 @@ func TestConnectionCloseRetransmission(t *testing.T) { require.Equal(t, packets[0], packets[i]) } } + +func TestDrainServerAcceptQueue(t *testing.T) { + server, err := quic.Listen(newUPDConnLocalhost(t), getTLSConfig(), getQuicConfig(nil)) + require.NoError(t, err) + defer server.Close() + + dialer := &quic.Transport{Conn: newUPDConnLocalhost(t)} + defer dialer.Close() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + // fill up the accept queue + conns := make([]quic.Connection, 0, protocol.MaxAcceptQueueSize) + for i := 0; i < protocol.MaxAcceptQueueSize; i++ { + conn, err := dialer.Dial(ctx, server.Addr(), getTLSClientConfig(), getQuicConfig(nil)) + require.NoError(t, err) + conns = append(conns, conn) + } + time.Sleep(scaleDuration(25 * time.Millisecond)) // wait for connections to be queued + + server.Close() + for i := range protocol.MaxAcceptQueueSize { + c, err := server.Accept(ctx) + require.NoError(t, err) + // make sure the connection is not closed + require.NoError(t, conns[i].Context().Err(), "client connection closed") + require.NoError(t, c.Context().Err(), "server connection closed") + conns[i].CloseWithError(0, "") + c.CloseWithError(0, "") + } + _, err = server.Accept(ctx) + require.ErrorIs(t, err, quic.ErrServerClosed) +} diff --git a/server.go b/server.go index c9395edf..cece797b 100644 --- a/server.go +++ b/server.go @@ -105,10 +105,18 @@ type baseServer struct { protocol.Version, ) quicConn - closeMx sync.Mutex - errorChan chan struct{} // is closed when the server is closed - closeErr error - running chan struct{} // closed as soon as run() returns + closeMx sync.Mutex + // errorChan is closed when Close is called. This has two effects: + // 1. it cancels handshakes that are still in flight (using CONNECTION_REFUSED) errors + // 2. it stops handling of packets passed to this server + errorChan chan struct{} + // acceptChan is closed when Close returns. + // This only happens once all handshake in flight have either completed and canceled. + // Calls to Accept will first drain the queue of connections that have completed the handshake, + // and then return ErrServerClosed. + stopAccepting chan struct{} + closeErr error + running chan struct{} // closed as soon as run() returns versionNegotiationQueue chan receivedPacket invalidTokenQueue chan rejectedPacket @@ -263,6 +271,7 @@ func newServer( connHandler: connHandler, connQueue: make(chan quicConn, protocol.MaxAcceptQueueSize), errorChan: make(chan struct{}), + stopAccepting: make(chan struct{}), running: make(chan struct{}), receivedPackets: make(chan receivedPacket, protocol.MaxServerUnprocessedPackets), versionNegotiationQueue: make(chan receivedPacket, 4), @@ -333,15 +342,19 @@ func (s *baseServer) accept(ctx context.Context) (quicConn, error) { return nil, ctx.Err() case conn := <-s.connQueue: return conn, nil - case <-s.errorChan: + case <-s.stopAccepting: + // first drain the queue + select { + case conn := <-s.connQueue: + return conn, nil + default: + } return nil, s.closeErr } } func (s *baseServer) Close() error { s.close(ErrServerClosed, true) - // wait until all handshakes in flight have terminated - s.handshakingCount.Wait() return nil } @@ -359,6 +372,9 @@ func (s *baseServer) close(e error, notifyOnClose bool) { if notifyOnClose { s.onClose() } + // wait until all handshakes in flight have terminated + s.handshakingCount.Wait() + close(s.stopAccepting) } // Addr returns the server's network address @@ -369,6 +385,8 @@ func (s *baseServer) Addr() net.Addr { func (s *baseServer) handlePacket(p receivedPacket) { select { case s.receivedPackets <- p: + case <-s.errorChan: + return default: s.logger.Debugf("Dropping packet from %s (%d bytes). Server receive queue full.", p.remoteAddr, p.Size()) if s.tracer != nil && s.tracer.DroppedPacket != nil { @@ -719,42 +737,39 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error s.handshakingCount.Add(1) go func() { defer s.handshakingCount.Done() - if completed := s.handleNewConn(conn); !completed { - return - } - - select { - case s.connQueue <- conn: - default: - conn.closeWithTransportError(ConnectionRefused) - } + s.handleNewConn(conn) }() go conn.run() return nil } -func (s *baseServer) handleNewConn(conn quicConn) bool { +func (s *baseServer) handleNewConn(conn quicConn) { if s.acceptEarlyConns { // wait until the early connection is ready, the handshake fails, or the server is closed select { case <-s.errorChan: conn.closeWithTransportError(ConnectionRefused) - return false + return case <-conn.Context().Done(): - return false + return case <-conn.earlyConnReady(): - return true + } + } else { + // wait until the handshake completes, fails, or the server is closed + select { + case <-s.errorChan: + conn.closeWithTransportError(ConnectionRefused) + return + case <-conn.Context().Done(): + return + case <-conn.HandshakeComplete(): } } - // wait until the handshake completes, fails, or the server is closed + select { - case <-s.errorChan: + case s.connQueue <- conn: + default: conn.closeWithTransportError(ConnectionRefused) - return false - case <-conn.Context().Done(): - return false - case <-conn.HandshakeComplete(): - return true } }