forked from quic-go/quic-go
drain server's accept queue before returning ErrClosed from Accept (#4846)
This commit is contained in:
@@ -10,6 +10,7 @@ import (
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
quicproxy "github.com/quic-go/quic-go/integrationtests/tools/proxy"
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@@ -79,3 +80,36 @@ func TestConnectionCloseRetransmission(t *testing.T) {
|
||||
require.Equal(t, packets[0], packets[i])
|
||||
}
|
||||
}
|
||||
|
||||
func TestDrainServerAcceptQueue(t *testing.T) {
|
||||
server, err := quic.Listen(newUPDConnLocalhost(t), getTLSConfig(), getQuicConfig(nil))
|
||||
require.NoError(t, err)
|
||||
defer server.Close()
|
||||
|
||||
dialer := &quic.Transport{Conn: newUPDConnLocalhost(t)}
|
||||
defer dialer.Close()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
// fill up the accept queue
|
||||
conns := make([]quic.Connection, 0, protocol.MaxAcceptQueueSize)
|
||||
for i := 0; i < protocol.MaxAcceptQueueSize; i++ {
|
||||
conn, err := dialer.Dial(ctx, server.Addr(), getTLSClientConfig(), getQuicConfig(nil))
|
||||
require.NoError(t, err)
|
||||
conns = append(conns, conn)
|
||||
}
|
||||
time.Sleep(scaleDuration(25 * time.Millisecond)) // wait for connections to be queued
|
||||
|
||||
server.Close()
|
||||
for i := range protocol.MaxAcceptQueueSize {
|
||||
c, err := server.Accept(ctx)
|
||||
require.NoError(t, err)
|
||||
// make sure the connection is not closed
|
||||
require.NoError(t, conns[i].Context().Err(), "client connection closed")
|
||||
require.NoError(t, c.Context().Err(), "server connection closed")
|
||||
conns[i].CloseWithError(0, "")
|
||||
c.CloseWithError(0, "")
|
||||
}
|
||||
_, err = server.Accept(ctx)
|
||||
require.ErrorIs(t, err, quic.ErrServerClosed)
|
||||
}
|
||||
|
||||
69
server.go
69
server.go
@@ -105,10 +105,18 @@ type baseServer struct {
|
||||
protocol.Version,
|
||||
) quicConn
|
||||
|
||||
closeMx sync.Mutex
|
||||
errorChan chan struct{} // is closed when the server is closed
|
||||
closeErr error
|
||||
running chan struct{} // closed as soon as run() returns
|
||||
closeMx sync.Mutex
|
||||
// errorChan is closed when Close is called. This has two effects:
|
||||
// 1. it cancels handshakes that are still in flight (using CONNECTION_REFUSED) errors
|
||||
// 2. it stops handling of packets passed to this server
|
||||
errorChan chan struct{}
|
||||
// acceptChan is closed when Close returns.
|
||||
// This only happens once all handshake in flight have either completed and canceled.
|
||||
// Calls to Accept will first drain the queue of connections that have completed the handshake,
|
||||
// and then return ErrServerClosed.
|
||||
stopAccepting chan struct{}
|
||||
closeErr error
|
||||
running chan struct{} // closed as soon as run() returns
|
||||
|
||||
versionNegotiationQueue chan receivedPacket
|
||||
invalidTokenQueue chan rejectedPacket
|
||||
@@ -263,6 +271,7 @@ func newServer(
|
||||
connHandler: connHandler,
|
||||
connQueue: make(chan quicConn, protocol.MaxAcceptQueueSize),
|
||||
errorChan: make(chan struct{}),
|
||||
stopAccepting: make(chan struct{}),
|
||||
running: make(chan struct{}),
|
||||
receivedPackets: make(chan receivedPacket, protocol.MaxServerUnprocessedPackets),
|
||||
versionNegotiationQueue: make(chan receivedPacket, 4),
|
||||
@@ -333,15 +342,19 @@ func (s *baseServer) accept(ctx context.Context) (quicConn, error) {
|
||||
return nil, ctx.Err()
|
||||
case conn := <-s.connQueue:
|
||||
return conn, nil
|
||||
case <-s.errorChan:
|
||||
case <-s.stopAccepting:
|
||||
// first drain the queue
|
||||
select {
|
||||
case conn := <-s.connQueue:
|
||||
return conn, nil
|
||||
default:
|
||||
}
|
||||
return nil, s.closeErr
|
||||
}
|
||||
}
|
||||
|
||||
func (s *baseServer) Close() error {
|
||||
s.close(ErrServerClosed, true)
|
||||
// wait until all handshakes in flight have terminated
|
||||
s.handshakingCount.Wait()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -359,6 +372,9 @@ func (s *baseServer) close(e error, notifyOnClose bool) {
|
||||
if notifyOnClose {
|
||||
s.onClose()
|
||||
}
|
||||
// wait until all handshakes in flight have terminated
|
||||
s.handshakingCount.Wait()
|
||||
close(s.stopAccepting)
|
||||
}
|
||||
|
||||
// Addr returns the server's network address
|
||||
@@ -369,6 +385,8 @@ func (s *baseServer) Addr() net.Addr {
|
||||
func (s *baseServer) handlePacket(p receivedPacket) {
|
||||
select {
|
||||
case s.receivedPackets <- p:
|
||||
case <-s.errorChan:
|
||||
return
|
||||
default:
|
||||
s.logger.Debugf("Dropping packet from %s (%d bytes). Server receive queue full.", p.remoteAddr, p.Size())
|
||||
if s.tracer != nil && s.tracer.DroppedPacket != nil {
|
||||
@@ -719,42 +737,39 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
|
||||
s.handshakingCount.Add(1)
|
||||
go func() {
|
||||
defer s.handshakingCount.Done()
|
||||
if completed := s.handleNewConn(conn); !completed {
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case s.connQueue <- conn:
|
||||
default:
|
||||
conn.closeWithTransportError(ConnectionRefused)
|
||||
}
|
||||
s.handleNewConn(conn)
|
||||
}()
|
||||
go conn.run()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *baseServer) handleNewConn(conn quicConn) bool {
|
||||
func (s *baseServer) handleNewConn(conn quicConn) {
|
||||
if s.acceptEarlyConns {
|
||||
// wait until the early connection is ready, the handshake fails, or the server is closed
|
||||
select {
|
||||
case <-s.errorChan:
|
||||
conn.closeWithTransportError(ConnectionRefused)
|
||||
return false
|
||||
return
|
||||
case <-conn.Context().Done():
|
||||
return false
|
||||
return
|
||||
case <-conn.earlyConnReady():
|
||||
return true
|
||||
}
|
||||
} else {
|
||||
// wait until the handshake completes, fails, or the server is closed
|
||||
select {
|
||||
case <-s.errorChan:
|
||||
conn.closeWithTransportError(ConnectionRefused)
|
||||
return
|
||||
case <-conn.Context().Done():
|
||||
return
|
||||
case <-conn.HandshakeComplete():
|
||||
}
|
||||
}
|
||||
// wait until the handshake completes, fails, or the server is closed
|
||||
|
||||
select {
|
||||
case <-s.errorChan:
|
||||
case s.connQueue <- conn:
|
||||
default:
|
||||
conn.closeWithTransportError(ConnectionRefused)
|
||||
return false
|
||||
case <-conn.Context().Done():
|
||||
return false
|
||||
case <-conn.HandshakeComplete():
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user