wait until handshaking connections have terminated when closing server (#4743)

This commit is contained in:
Marten Seemann
2024-12-03 11:10:13 +08:00
committed by GitHub
parent 363533dc7a
commit 71a27d40c9
2 changed files with 27 additions and 9 deletions

View File

@@ -640,27 +640,40 @@ var _ = Describe("Handshake tests", func() {
tr := &quic.Transport{Conn: udpConn}
addTracer(tr)
defer tr.Close()
tlsConf := &tls.Config{}
done := make(chan struct{})
tlsConf.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) {
<-done
return nil, errors.New("closed")
rtt := scaleDuration(40 * time.Millisecond)
connQueued := make(chan struct{})
tlsConf := &tls.Config{
GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) {
close(connQueued)
// Sleep for a bit.
// This allows the server to close the connection before the handshake completes.
time.Sleep(rtt / 2)
return getTLSConfig(), nil
},
}
ln, err := tr.Listen(tlsConf, getQuicConfig(nil))
Expect(err).ToNot(HaveOccurred())
serverPort := ln.Addr().(*net.UDPAddr).Port
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
RemoteAddr: fmt.Sprintf("localhost:%d", serverPort),
DelayPacket: func(quicproxy.Direction, []byte) time.Duration { return rtt / 2 },
})
Expect(err).ToNot(HaveOccurred())
defer proxy.Close()
errChan := make(chan error, 1)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
go func() {
defer GinkgoRecover()
_, err := quic.DialAddr(ctx, ln.Addr().String(), getTLSClientConfig(), getQuicConfig(nil))
errChan <- err
}()
time.Sleep(scaleDuration(20 * time.Millisecond)) // wait a bit for the connection to be queued
Eventually(connQueued, 5*rtt).Should(BeClosed())
Expect(ln.Close()).To(Succeed())
close(done)
err = <-errChan
Expect(err).To(HaveOccurred())
var transportErr *quic.TransportError
Expect(errors.As(err, &transportErr)).To(BeTrue())
Expect(transportErr.ErrorCode).To(Equal(quic.ConnectionRefused))

View File

@@ -114,6 +114,7 @@ type baseServer struct {
invalidTokenQueue chan rejectedPacket
connectionRefusedQueue chan rejectedPacket
retryQueue chan rejectedPacket
handshakingCount sync.WaitGroup
verifySourceAddress func(net.Addr) bool
@@ -339,6 +340,8 @@ func (s *baseServer) accept(ctx context.Context) (quicConn, error) {
func (s *baseServer) Close() error {
s.close(ErrServerClosed, true)
// wait until all handshakes in flight have terminated
s.handshakingCount.Wait()
return nil
}
@@ -713,8 +716,9 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
delete(s.zeroRTTQueues, hdr.DestConnectionID)
}
go conn.run()
s.handshakingCount.Add(1)
go func() {
defer s.handshakingCount.Done()
if completed := s.handleNewConn(conn); !completed {
return
}
@@ -725,6 +729,7 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
conn.closeWithTransportError(ConnectionRefused)
}
}()
go conn.run()
return nil
}