make sure not to return closed session from Listener.Accept()

This commit is contained in:
Marten Seemann
2019-01-06 14:13:13 +07:00
parent 90514d53d1
commit 181aa493e0
3 changed files with 122 additions and 17 deletions

View File

@@ -2,6 +2,7 @@ package quic
import (
"bytes"
"context"
"crypto/tls"
"errors"
"net"
@@ -349,6 +350,7 @@ var _ = Describe("Server", func() {
sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().handlePacket(p)
sess.EXPECT().run()
sess.EXPECT().Context().Return(context.Background())
runner.onHandshakeComplete(sess)
return sess, nil
}
@@ -375,6 +377,64 @@ var _ = Describe("Server", func() {
Expect(rejectHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
Expect(rejectHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID))
})
It("doesn't accept new sessions if they were closed in the mean time", func() {
serv.config.AcceptCookie = func(_ net.Addr, _ *Cookie) bool { return true }
senderAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 42}
hdr := &wire.Header{
Type: protocol.PacketTypeInitial,
SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1},
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
Version: protocol.VersionTLS,
}
p := &receivedPacket{
remoteAddr: senderAddr,
hdr: hdr,
data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize),
}
ctx, cancel := context.WithCancel(context.Background())
sessionCreated := make(chan struct{})
sess := NewMockQuicSession(mockCtrl)
serv.newSession = func(
_ connection,
runner sessionRunner,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ *Config,
_ *tls.Config,
_ *handshake.TransportParameters,
_ utils.Logger,
_ protocol.VersionNumber,
) (quicSession, error) {
sess.EXPECT().handlePacket(p)
sess.EXPECT().run()
sess.EXPECT().Context().Return(ctx)
runner.onHandshakeComplete(sess)
close(sessionCreated)
return sess, nil
}
serv.handlePacket(insertPacketBuffer(p))
Consistently(conn.dataWritten).ShouldNot(Receive())
Eventually(sessionCreated).Should(BeClosed())
cancel()
time.Sleep(scaleDuration(200 * time.Millisecond))
done := make(chan struct{})
go func() {
defer GinkgoRecover()
serv.Accept()
close(done)
}()
Consistently(done).ShouldNot(BeClosed())
// make the go routine return
sess.EXPECT().Close()
Expect(serv.Close()).To(Succeed())
Eventually(done).Should(BeClosed())
})
})
Context("accepting sessions", func() {
@@ -440,6 +500,7 @@ var _ = Describe("Server", func() {
runner.onHandshakeComplete(sess)
}()
sess.EXPECT().run().Do(func() {})
sess.EXPECT().Context().Return(context.Background())
return sess, nil
}
_, err := serv.createNewSession(&net.UDPAddr{}, nil, nil, nil, nil, protocol.VersionWhatever)
@@ -466,8 +527,9 @@ var _ = Describe("Server", func() {
_ protocol.VersionNumber,
) (quicSession, error) {
sess := NewMockQuicSession(mockCtrl)
runner.onHandshakeComplete(sess)
sess.EXPECT().run().Do(func() {})
sess.EXPECT().Context().Return(context.Background())
runner.onHandshakeComplete(sess)
done <- struct{}{}
return sess, nil
}