drain server's accept queue before returning ErrClosed from Accept (#4846)

This commit is contained in:
Marten Seemann
2025-01-08 21:59:05 +08:00
committed by GitHub
parent 793389b322
commit 420f852f86
2 changed files with 76 additions and 27 deletions

View File

@@ -10,6 +10,7 @@ import (
"github.com/quic-go/quic-go"
quicproxy "github.com/quic-go/quic-go/integrationtests/tools/proxy"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/stretchr/testify/require"
)
@@ -79,3 +80,36 @@ func TestConnectionCloseRetransmission(t *testing.T) {
require.Equal(t, packets[0], packets[i])
}
}
func TestDrainServerAcceptQueue(t *testing.T) {
server, err := quic.Listen(newUPDConnLocalhost(t), getTLSConfig(), getQuicConfig(nil))
require.NoError(t, err)
defer server.Close()
dialer := &quic.Transport{Conn: newUPDConnLocalhost(t)}
defer dialer.Close()
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
// fill up the accept queue
conns := make([]quic.Connection, 0, protocol.MaxAcceptQueueSize)
for i := 0; i < protocol.MaxAcceptQueueSize; i++ {
conn, err := dialer.Dial(ctx, server.Addr(), getTLSClientConfig(), getQuicConfig(nil))
require.NoError(t, err)
conns = append(conns, conn)
}
time.Sleep(scaleDuration(25 * time.Millisecond)) // wait for connections to be queued
server.Close()
for i := range protocol.MaxAcceptQueueSize {
c, err := server.Accept(ctx)
require.NoError(t, err)
// make sure the connection is not closed
require.NoError(t, conns[i].Context().Err(), "client connection closed")
require.NoError(t, c.Context().Err(), "server connection closed")
conns[i].CloseWithError(0, "")
c.CloseWithError(0, "")
}
_, err = server.Accept(ctx)
require.ErrorIs(t, err, quic.ErrServerClosed)
}

View File

@@ -105,10 +105,18 @@ type baseServer struct {
protocol.Version,
) quicConn
closeMx sync.Mutex
errorChan chan struct{} // is closed when the server is closed
closeErr error
running chan struct{} // closed as soon as run() returns
closeMx sync.Mutex
// errorChan is closed when Close is called. This has two effects:
// 1. it cancels handshakes that are still in flight (using CONNECTION_REFUSED) errors
// 2. it stops handling of packets passed to this server
errorChan chan struct{}
// acceptChan is closed when Close returns.
// This only happens once all handshake in flight have either completed and canceled.
// Calls to Accept will first drain the queue of connections that have completed the handshake,
// and then return ErrServerClosed.
stopAccepting chan struct{}
closeErr error
running chan struct{} // closed as soon as run() returns
versionNegotiationQueue chan receivedPacket
invalidTokenQueue chan rejectedPacket
@@ -263,6 +271,7 @@ func newServer(
connHandler: connHandler,
connQueue: make(chan quicConn, protocol.MaxAcceptQueueSize),
errorChan: make(chan struct{}),
stopAccepting: make(chan struct{}),
running: make(chan struct{}),
receivedPackets: make(chan receivedPacket, protocol.MaxServerUnprocessedPackets),
versionNegotiationQueue: make(chan receivedPacket, 4),
@@ -333,15 +342,19 @@ func (s *baseServer) accept(ctx context.Context) (quicConn, error) {
return nil, ctx.Err()
case conn := <-s.connQueue:
return conn, nil
case <-s.errorChan:
case <-s.stopAccepting:
// first drain the queue
select {
case conn := <-s.connQueue:
return conn, nil
default:
}
return nil, s.closeErr
}
}
func (s *baseServer) Close() error {
s.close(ErrServerClosed, true)
// wait until all handshakes in flight have terminated
s.handshakingCount.Wait()
return nil
}
@@ -359,6 +372,9 @@ func (s *baseServer) close(e error, notifyOnClose bool) {
if notifyOnClose {
s.onClose()
}
// wait until all handshakes in flight have terminated
s.handshakingCount.Wait()
close(s.stopAccepting)
}
// Addr returns the server's network address
@@ -369,6 +385,8 @@ func (s *baseServer) Addr() net.Addr {
func (s *baseServer) handlePacket(p receivedPacket) {
select {
case s.receivedPackets <- p:
case <-s.errorChan:
return
default:
s.logger.Debugf("Dropping packet from %s (%d bytes). Server receive queue full.", p.remoteAddr, p.Size())
if s.tracer != nil && s.tracer.DroppedPacket != nil {
@@ -719,42 +737,39 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
s.handshakingCount.Add(1)
go func() {
defer s.handshakingCount.Done()
if completed := s.handleNewConn(conn); !completed {
return
}
select {
case s.connQueue <- conn:
default:
conn.closeWithTransportError(ConnectionRefused)
}
s.handleNewConn(conn)
}()
go conn.run()
return nil
}
func (s *baseServer) handleNewConn(conn quicConn) bool {
func (s *baseServer) handleNewConn(conn quicConn) {
if s.acceptEarlyConns {
// wait until the early connection is ready, the handshake fails, or the server is closed
select {
case <-s.errorChan:
conn.closeWithTransportError(ConnectionRefused)
return false
return
case <-conn.Context().Done():
return false
return
case <-conn.earlyConnReady():
return true
}
} else {
// wait until the handshake completes, fails, or the server is closed
select {
case <-s.errorChan:
conn.closeWithTransportError(ConnectionRefused)
return
case <-conn.Context().Done():
return
case <-conn.HandshakeComplete():
}
}
// wait until the handshake completes, fails, or the server is closed
select {
case <-s.errorChan:
case s.connQueue <- conn:
default:
conn.closeWithTransportError(ConnectionRefused)
return false
case <-conn.Context().Done():
return false
case <-conn.HandshakeComplete():
return true
}
}