forked from quic-go/quic-go
make sure not to return closed session from Listener.Accept()
This commit is contained in:
@@ -147,23 +147,29 @@ var _ = Describe("Handshake tests", func() {
|
|||||||
})
|
})
|
||||||
|
|
||||||
Context("rate limiting", func() {
|
Context("rate limiting", func() {
|
||||||
It("rejects new connection attempts if connections don't get accepted", func() {
|
var server quic.Listener
|
||||||
|
|
||||||
|
dial := func() (quic.Session, error) {
|
||||||
|
return quic.DialAddr(
|
||||||
|
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
||||||
|
&tls.Config{RootCAs: testdata.GetRootCA()},
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
BeforeEach(func() {
|
||||||
|
serverConfig.AcceptCookie = func(net.Addr, *quic.Cookie) bool { return true }
|
||||||
|
var err error
|
||||||
// start the server, but don't call Accept
|
// start the server, but don't call Accept
|
||||||
serverConfig.AcceptCookie = func(net.Addr, *quic.Cookie) bool {
|
server, err = quic.ListenAddr("localhost:0", testdata.GetTLSConfig(), serverConfig)
|
||||||
return true
|
|
||||||
}
|
|
||||||
server, err := quic.ListenAddr("localhost:0", testdata.GetTLSConfig(), serverConfig)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
defer server.Close()
|
})
|
||||||
|
|
||||||
dial := func() (quic.Session, error) {
|
AfterEach(func() {
|
||||||
return quic.DialAddr(
|
Expect(server.Close()).To(Succeed())
|
||||||
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
})
|
||||||
&tls.Config{RootCAs: testdata.GetRootCA()},
|
|
||||||
nil,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
|
It("rejects new connection attempts if connections don't get accepted", func() {
|
||||||
for i := 0; i < protocol.MaxAcceptQueueSize; i++ {
|
for i := 0; i < protocol.MaxAcceptQueueSize; i++ {
|
||||||
sess, err := dial()
|
sess, err := dial()
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
@@ -171,7 +177,7 @@ var _ = Describe("Handshake tests", func() {
|
|||||||
}
|
}
|
||||||
time.Sleep(25 * time.Millisecond) // wait a bit for the sessions to be queued
|
time.Sleep(25 * time.Millisecond) // wait a bit for the sessions to be queued
|
||||||
|
|
||||||
_, err = dial()
|
_, err := dial()
|
||||||
Expect(err).To(HaveOccurred())
|
Expect(err).To(HaveOccurred())
|
||||||
// TODO(#1567): use the SERVER_BUSY error code
|
// TODO(#1567): use the SERVER_BUSY error code
|
||||||
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.PeerGoingAway))
|
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.PeerGoingAway))
|
||||||
@@ -190,5 +196,37 @@ var _ = Describe("Handshake tests", func() {
|
|||||||
// TODO(#1567): use the SERVER_BUSY error code
|
// TODO(#1567): use the SERVER_BUSY error code
|
||||||
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.PeerGoingAway))
|
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.PeerGoingAway))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
It("rejects new connection attempts if connections don't get accepted", func() {
|
||||||
|
firstSess, err := dial()
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
|
for i := 1; i < protocol.MaxAcceptQueueSize; i++ {
|
||||||
|
sess, err := dial()
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
defer sess.Close()
|
||||||
|
}
|
||||||
|
time.Sleep(25 * time.Millisecond) // wait a bit for the sessions to be queued
|
||||||
|
|
||||||
|
_, err = dial()
|
||||||
|
Expect(err).To(HaveOccurred())
|
||||||
|
// TODO(#1567): use the SERVER_BUSY error code
|
||||||
|
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.PeerGoingAway))
|
||||||
|
|
||||||
|
// Now close the one of the session that are waiting to be accepted.
|
||||||
|
// This should free one spot in the queue.
|
||||||
|
Expect(firstSess.Close())
|
||||||
|
time.Sleep(25 * time.Millisecond)
|
||||||
|
|
||||||
|
// dial again, and expect that this dial succeeds
|
||||||
|
_, err = dial()
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
time.Sleep(25 * time.Millisecond) // wait a bit for the session to be queued
|
||||||
|
|
||||||
|
_, err = dial()
|
||||||
|
// TODO(#1567): use the SERVER_BUSY error code
|
||||||
|
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.PeerGoingAway))
|
||||||
|
})
|
||||||
|
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -170,8 +170,13 @@ func (s *server) setup() error {
|
|||||||
onHandshakeCompleteImpl: func(sess Session) {
|
onHandshakeCompleteImpl: func(sess Session) {
|
||||||
go func() {
|
go func() {
|
||||||
atomic.AddInt32(&s.sessionQueueLen, 1)
|
atomic.AddInt32(&s.sessionQueueLen, 1)
|
||||||
s.sessionQueue <- sess // blocks until the session is accepted
|
defer atomic.AddInt32(&s.sessionQueueLen, -1)
|
||||||
atomic.AddInt32(&s.sessionQueueLen, -1)
|
select {
|
||||||
|
case s.sessionQueue <- sess:
|
||||||
|
// blocks until the session is accepted
|
||||||
|
case <-sess.Context().Done():
|
||||||
|
// don't pass sessions that were already closed to Accept()
|
||||||
|
}
|
||||||
}()
|
}()
|
||||||
},
|
},
|
||||||
retireConnectionIDImpl: s.sessionHandler.Retire,
|
retireConnectionIDImpl: s.sessionHandler.Retire,
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package quic
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"errors"
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
@@ -349,6 +350,7 @@ var _ = Describe("Server", func() {
|
|||||||
sess := NewMockQuicSession(mockCtrl)
|
sess := NewMockQuicSession(mockCtrl)
|
||||||
sess.EXPECT().handlePacket(p)
|
sess.EXPECT().handlePacket(p)
|
||||||
sess.EXPECT().run()
|
sess.EXPECT().run()
|
||||||
|
sess.EXPECT().Context().Return(context.Background())
|
||||||
runner.onHandshakeComplete(sess)
|
runner.onHandshakeComplete(sess)
|
||||||
return sess, nil
|
return sess, nil
|
||||||
}
|
}
|
||||||
@@ -375,6 +377,64 @@ var _ = Describe("Server", func() {
|
|||||||
Expect(rejectHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
|
Expect(rejectHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
|
||||||
Expect(rejectHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID))
|
Expect(rejectHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
It("doesn't accept new sessions if they were closed in the mean time", func() {
|
||||||
|
serv.config.AcceptCookie = func(_ net.Addr, _ *Cookie) bool { return true }
|
||||||
|
senderAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 42}
|
||||||
|
|
||||||
|
hdr := &wire.Header{
|
||||||
|
Type: protocol.PacketTypeInitial,
|
||||||
|
SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1},
|
||||||
|
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
|
||||||
|
Version: protocol.VersionTLS,
|
||||||
|
}
|
||||||
|
p := &receivedPacket{
|
||||||
|
remoteAddr: senderAddr,
|
||||||
|
hdr: hdr,
|
||||||
|
data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize),
|
||||||
|
}
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
sessionCreated := make(chan struct{})
|
||||||
|
sess := NewMockQuicSession(mockCtrl)
|
||||||
|
serv.newSession = func(
|
||||||
|
_ connection,
|
||||||
|
runner sessionRunner,
|
||||||
|
_ protocol.ConnectionID,
|
||||||
|
_ protocol.ConnectionID,
|
||||||
|
_ protocol.ConnectionID,
|
||||||
|
_ *Config,
|
||||||
|
_ *tls.Config,
|
||||||
|
_ *handshake.TransportParameters,
|
||||||
|
_ utils.Logger,
|
||||||
|
_ protocol.VersionNumber,
|
||||||
|
) (quicSession, error) {
|
||||||
|
sess.EXPECT().handlePacket(p)
|
||||||
|
sess.EXPECT().run()
|
||||||
|
sess.EXPECT().Context().Return(ctx)
|
||||||
|
runner.onHandshakeComplete(sess)
|
||||||
|
close(sessionCreated)
|
||||||
|
return sess, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
serv.handlePacket(insertPacketBuffer(p))
|
||||||
|
Consistently(conn.dataWritten).ShouldNot(Receive())
|
||||||
|
Eventually(sessionCreated).Should(BeClosed())
|
||||||
|
cancel()
|
||||||
|
time.Sleep(scaleDuration(200 * time.Millisecond))
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
defer GinkgoRecover()
|
||||||
|
serv.Accept()
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
Consistently(done).ShouldNot(BeClosed())
|
||||||
|
|
||||||
|
// make the go routine return
|
||||||
|
sess.EXPECT().Close()
|
||||||
|
Expect(serv.Close()).To(Succeed())
|
||||||
|
Eventually(done).Should(BeClosed())
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
Context("accepting sessions", func() {
|
Context("accepting sessions", func() {
|
||||||
@@ -440,6 +500,7 @@ var _ = Describe("Server", func() {
|
|||||||
runner.onHandshakeComplete(sess)
|
runner.onHandshakeComplete(sess)
|
||||||
}()
|
}()
|
||||||
sess.EXPECT().run().Do(func() {})
|
sess.EXPECT().run().Do(func() {})
|
||||||
|
sess.EXPECT().Context().Return(context.Background())
|
||||||
return sess, nil
|
return sess, nil
|
||||||
}
|
}
|
||||||
_, err := serv.createNewSession(&net.UDPAddr{}, nil, nil, nil, nil, protocol.VersionWhatever)
|
_, err := serv.createNewSession(&net.UDPAddr{}, nil, nil, nil, nil, protocol.VersionWhatever)
|
||||||
@@ -466,8 +527,9 @@ var _ = Describe("Server", func() {
|
|||||||
_ protocol.VersionNumber,
|
_ protocol.VersionNumber,
|
||||||
) (quicSession, error) {
|
) (quicSession, error) {
|
||||||
sess := NewMockQuicSession(mockCtrl)
|
sess := NewMockQuicSession(mockCtrl)
|
||||||
runner.onHandshakeComplete(sess)
|
|
||||||
sess.EXPECT().run().Do(func() {})
|
sess.EXPECT().run().Do(func() {})
|
||||||
|
sess.EXPECT().Context().Return(context.Background())
|
||||||
|
runner.onHandshakeComplete(sess)
|
||||||
done <- struct{}{}
|
done <- struct{}{}
|
||||||
return sess, nil
|
return sess, nil
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user