forked from quic-go/quic-go
wait until handshaking connections have terminated when closing server (#4743)
This commit is contained in:
@@ -640,27 +640,40 @@ var _ = Describe("Handshake tests", func() {
|
||||
tr := &quic.Transport{Conn: udpConn}
|
||||
addTracer(tr)
|
||||
defer tr.Close()
|
||||
tlsConf := &tls.Config{}
|
||||
done := make(chan struct{})
|
||||
tlsConf.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) {
|
||||
<-done
|
||||
return nil, errors.New("closed")
|
||||
|
||||
rtt := scaleDuration(40 * time.Millisecond)
|
||||
connQueued := make(chan struct{})
|
||||
tlsConf := &tls.Config{
|
||||
GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) {
|
||||
close(connQueued)
|
||||
// Sleep for a bit.
|
||||
// This allows the server to close the connection before the handshake completes.
|
||||
time.Sleep(rtt / 2)
|
||||
return getTLSConfig(), nil
|
||||
},
|
||||
}
|
||||
ln, err := tr.Listen(tlsConf, getQuicConfig(nil))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
serverPort := ln.Addr().(*net.UDPAddr).Port
|
||||
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
|
||||
RemoteAddr: fmt.Sprintf("localhost:%d", serverPort),
|
||||
DelayPacket: func(quicproxy.Direction, []byte) time.Duration { return rtt / 2 },
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer proxy.Close()
|
||||
|
||||
errChan := make(chan error, 1)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_, err := quic.DialAddr(ctx, ln.Addr().String(), getTLSClientConfig(), getQuicConfig(nil))
|
||||
errChan <- err
|
||||
}()
|
||||
time.Sleep(scaleDuration(20 * time.Millisecond)) // wait a bit for the connection to be queued
|
||||
Eventually(connQueued, 5*rtt).Should(BeClosed())
|
||||
Expect(ln.Close()).To(Succeed())
|
||||
close(done)
|
||||
err = <-errChan
|
||||
Expect(err).To(HaveOccurred())
|
||||
var transportErr *quic.TransportError
|
||||
Expect(errors.As(err, &transportErr)).To(BeTrue())
|
||||
Expect(transportErr.ErrorCode).To(Equal(quic.ConnectionRefused))
|
||||
|
||||
@@ -114,6 +114,7 @@ type baseServer struct {
|
||||
invalidTokenQueue chan rejectedPacket
|
||||
connectionRefusedQueue chan rejectedPacket
|
||||
retryQueue chan rejectedPacket
|
||||
handshakingCount sync.WaitGroup
|
||||
|
||||
verifySourceAddress func(net.Addr) bool
|
||||
|
||||
@@ -339,6 +340,8 @@ func (s *baseServer) accept(ctx context.Context) (quicConn, error) {
|
||||
|
||||
func (s *baseServer) Close() error {
|
||||
s.close(ErrServerClosed, true)
|
||||
// wait until all handshakes in flight have terminated
|
||||
s.handshakingCount.Wait()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -713,8 +716,9 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
|
||||
delete(s.zeroRTTQueues, hdr.DestConnectionID)
|
||||
}
|
||||
|
||||
go conn.run()
|
||||
s.handshakingCount.Add(1)
|
||||
go func() {
|
||||
defer s.handshakingCount.Done()
|
||||
if completed := s.handleNewConn(conn); !completed {
|
||||
return
|
||||
}
|
||||
@@ -725,6 +729,7 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
|
||||
conn.closeWithTransportError(ConnectionRefused)
|
||||
}
|
||||
}()
|
||||
go conn.run()
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user