make sure not to return closed session from Listener.Accept()

This commit is contained in:
Marten Seemann
2019-01-06 14:13:13 +07:00
parent 90514d53d1
commit 181aa493e0
3 changed files with 122 additions and 17 deletions

View File

@@ -147,23 +147,29 @@ var _ = Describe("Handshake tests", func() {
}) })
Context("rate limiting", func() { Context("rate limiting", func() {
It("rejects new connection attempts if connections don't get accepted", func() { var server quic.Listener
dial := func() (quic.Session, error) {
return quic.DialAddr(
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
&tls.Config{RootCAs: testdata.GetRootCA()},
nil,
)
}
BeforeEach(func() {
serverConfig.AcceptCookie = func(net.Addr, *quic.Cookie) bool { return true }
var err error
// start the server, but don't call Accept // start the server, but don't call Accept
serverConfig.AcceptCookie = func(net.Addr, *quic.Cookie) bool { server, err = quic.ListenAddr("localhost:0", testdata.GetTLSConfig(), serverConfig)
return true
}
server, err := quic.ListenAddr("localhost:0", testdata.GetTLSConfig(), serverConfig)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
defer server.Close() })
dial := func() (quic.Session, error) { AfterEach(func() {
return quic.DialAddr( Expect(server.Close()).To(Succeed())
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), })
&tls.Config{RootCAs: testdata.GetRootCA()},
nil,
)
}
It("rejects new connection attempts if connections don't get accepted", func() {
for i := 0; i < protocol.MaxAcceptQueueSize; i++ { for i := 0; i < protocol.MaxAcceptQueueSize; i++ {
sess, err := dial() sess, err := dial()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
@@ -171,7 +177,7 @@ var _ = Describe("Handshake tests", func() {
} }
time.Sleep(25 * time.Millisecond) // wait a bit for the sessions to be queued time.Sleep(25 * time.Millisecond) // wait a bit for the sessions to be queued
_, err = dial() _, err := dial()
Expect(err).To(HaveOccurred()) Expect(err).To(HaveOccurred())
// TODO(#1567): use the SERVER_BUSY error code // TODO(#1567): use the SERVER_BUSY error code
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.PeerGoingAway)) Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.PeerGoingAway))
@@ -190,5 +196,37 @@ var _ = Describe("Handshake tests", func() {
// TODO(#1567): use the SERVER_BUSY error code // TODO(#1567): use the SERVER_BUSY error code
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.PeerGoingAway)) Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.PeerGoingAway))
}) })
It("rejects new connection attempts if connections don't get accepted", func() {
firstSess, err := dial()
Expect(err).ToNot(HaveOccurred())
for i := 1; i < protocol.MaxAcceptQueueSize; i++ {
sess, err := dial()
Expect(err).ToNot(HaveOccurred())
defer sess.Close()
}
time.Sleep(25 * time.Millisecond) // wait a bit for the sessions to be queued
_, err = dial()
Expect(err).To(HaveOccurred())
// TODO(#1567): use the SERVER_BUSY error code
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.PeerGoingAway))
// Now close the one of the session that are waiting to be accepted.
// This should free one spot in the queue.
Expect(firstSess.Close())
time.Sleep(25 * time.Millisecond)
// dial again, and expect that this dial succeeds
_, err = dial()
Expect(err).ToNot(HaveOccurred())
time.Sleep(25 * time.Millisecond) // wait a bit for the session to be queued
_, err = dial()
// TODO(#1567): use the SERVER_BUSY error code
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.PeerGoingAway))
})
}) })
}) })

View File

@@ -170,8 +170,13 @@ func (s *server) setup() error {
onHandshakeCompleteImpl: func(sess Session) { onHandshakeCompleteImpl: func(sess Session) {
go func() { go func() {
atomic.AddInt32(&s.sessionQueueLen, 1) atomic.AddInt32(&s.sessionQueueLen, 1)
s.sessionQueue <- sess // blocks until the session is accepted defer atomic.AddInt32(&s.sessionQueueLen, -1)
atomic.AddInt32(&s.sessionQueueLen, -1) select {
case s.sessionQueue <- sess:
// blocks until the session is accepted
case <-sess.Context().Done():
// don't pass sessions that were already closed to Accept()
}
}() }()
}, },
retireConnectionIDImpl: s.sessionHandler.Retire, retireConnectionIDImpl: s.sessionHandler.Retire,

View File

@@ -2,6 +2,7 @@ package quic
import ( import (
"bytes" "bytes"
"context"
"crypto/tls" "crypto/tls"
"errors" "errors"
"net" "net"
@@ -349,6 +350,7 @@ var _ = Describe("Server", func() {
sess := NewMockQuicSession(mockCtrl) sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().handlePacket(p) sess.EXPECT().handlePacket(p)
sess.EXPECT().run() sess.EXPECT().run()
sess.EXPECT().Context().Return(context.Background())
runner.onHandshakeComplete(sess) runner.onHandshakeComplete(sess)
return sess, nil return sess, nil
} }
@@ -375,6 +377,64 @@ var _ = Describe("Server", func() {
Expect(rejectHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID)) Expect(rejectHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
Expect(rejectHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID)) Expect(rejectHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID))
}) })
It("doesn't accept new sessions if they were closed in the mean time", func() {
serv.config.AcceptCookie = func(_ net.Addr, _ *Cookie) bool { return true }
senderAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 42}
hdr := &wire.Header{
Type: protocol.PacketTypeInitial,
SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1},
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
Version: protocol.VersionTLS,
}
p := &receivedPacket{
remoteAddr: senderAddr,
hdr: hdr,
data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize),
}
ctx, cancel := context.WithCancel(context.Background())
sessionCreated := make(chan struct{})
sess := NewMockQuicSession(mockCtrl)
serv.newSession = func(
_ connection,
runner sessionRunner,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ *Config,
_ *tls.Config,
_ *handshake.TransportParameters,
_ utils.Logger,
_ protocol.VersionNumber,
) (quicSession, error) {
sess.EXPECT().handlePacket(p)
sess.EXPECT().run()
sess.EXPECT().Context().Return(ctx)
runner.onHandshakeComplete(sess)
close(sessionCreated)
return sess, nil
}
serv.handlePacket(insertPacketBuffer(p))
Consistently(conn.dataWritten).ShouldNot(Receive())
Eventually(sessionCreated).Should(BeClosed())
cancel()
time.Sleep(scaleDuration(200 * time.Millisecond))
done := make(chan struct{})
go func() {
defer GinkgoRecover()
serv.Accept()
close(done)
}()
Consistently(done).ShouldNot(BeClosed())
// make the go routine return
sess.EXPECT().Close()
Expect(serv.Close()).To(Succeed())
Eventually(done).Should(BeClosed())
})
}) })
Context("accepting sessions", func() { Context("accepting sessions", func() {
@@ -440,6 +500,7 @@ var _ = Describe("Server", func() {
runner.onHandshakeComplete(sess) runner.onHandshakeComplete(sess)
}() }()
sess.EXPECT().run().Do(func() {}) sess.EXPECT().run().Do(func() {})
sess.EXPECT().Context().Return(context.Background())
return sess, nil return sess, nil
} }
_, err := serv.createNewSession(&net.UDPAddr{}, nil, nil, nil, nil, protocol.VersionWhatever) _, err := serv.createNewSession(&net.UDPAddr{}, nil, nil, nil, nil, protocol.VersionWhatever)
@@ -466,8 +527,9 @@ var _ = Describe("Server", func() {
_ protocol.VersionNumber, _ protocol.VersionNumber,
) (quicSession, error) { ) (quicSession, error) {
sess := NewMockQuicSession(mockCtrl) sess := NewMockQuicSession(mockCtrl)
runner.onHandshakeComplete(sess)
sess.EXPECT().run().Do(func() {}) sess.EXPECT().run().Do(func() {})
sess.EXPECT().Context().Return(context.Background())
runner.onHandshakeComplete(sess)
done <- struct{}{} done <- struct{}{}
return sess, nil return sess, nil
} }