forked from quic-go/quic-go
simplify and reorganize session tests
This commit is contained in:
20
session.go
20
session.go
@@ -57,6 +57,8 @@ type Session struct {
|
||||
unpacker *packetUnpacker
|
||||
packer *packetPacker
|
||||
|
||||
cryptoSetup *handshake.CryptoSetup
|
||||
|
||||
receivedPackets chan receivedPacket
|
||||
sendingScheduled chan struct{}
|
||||
closeChan chan struct{}
|
||||
@@ -100,22 +102,16 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol
|
||||
}
|
||||
|
||||
cryptoStream, _ := session.OpenStream(1)
|
||||
cryptoSetup := handshake.NewCryptoSetup(connectionID, v, sCfg, cryptoStream, session.connectionParametersManager, session.aeadChanged)
|
||||
|
||||
go func() {
|
||||
if err := cryptoSetup.HandleCryptoStream(); err != nil {
|
||||
session.Close(err)
|
||||
}
|
||||
}()
|
||||
session.cryptoSetup = handshake.NewCryptoSetup(connectionID, v, sCfg, cryptoStream, session.connectionParametersManager, session.aeadChanged)
|
||||
|
||||
session.packer = &packetPacker{
|
||||
aead: cryptoSetup,
|
||||
aead: session.cryptoSetup,
|
||||
connectionParametersManager: session.connectionParametersManager,
|
||||
sentPacketHandler: session.sentPacketHandler,
|
||||
connectionID: connectionID,
|
||||
version: v,
|
||||
}
|
||||
session.unpacker = &packetUnpacker{aead: cryptoSetup, version: v}
|
||||
session.unpacker = &packetUnpacker{aead: session.cryptoSetup, version: v}
|
||||
|
||||
session.congestion = congestion.NewCubicSender(
|
||||
congestion.DefaultClock{},
|
||||
@@ -130,6 +126,12 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol
|
||||
|
||||
// run the session main loop
|
||||
func (s *Session) run() {
|
||||
go func() {
|
||||
if err := s.cryptoSetup.HandleCryptoStream(); err != nil {
|
||||
s.Close(err)
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
// Close immediately if requested
|
||||
select {
|
||||
|
||||
113
session_test.go
113
session_test.go
@@ -79,26 +79,31 @@ func (m *mockCongestion) SetSlowStartLargeReduction(enabled bool) {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
// TODO: Reorganize
|
||||
var _ = Describe("Session", func() {
|
||||
var (
|
||||
session *Session
|
||||
callbackCalled bool
|
||||
conn *mockConnection
|
||||
session *Session
|
||||
streamCallbackCalled bool
|
||||
closeCallbackCalled bool
|
||||
conn *mockConnection
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
conn = &mockConnection{}
|
||||
callbackCalled = false
|
||||
session = &Session{
|
||||
conn: conn,
|
||||
streams: make(map[protocol.StreamID]*stream),
|
||||
streamCallback: func(*Session, utils.Stream) { callbackCalled = true },
|
||||
connectionParametersManager: handshake.NewConnectionParamatersManager(),
|
||||
closeChan: make(chan struct{}, 1),
|
||||
closeCallback: func(protocol.ConnectionID) {},
|
||||
packer: &packetPacker{aead: &crypto.NullAEAD{}},
|
||||
}
|
||||
streamCallbackCalled = false
|
||||
closeCallbackCalled = false
|
||||
|
||||
signer, err := crypto.NewRSASigner(testdata.GetTLSConfig())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
scfg := handshake.NewServerConfig(crypto.NewCurve25519KEX(), signer)
|
||||
session = newSession(
|
||||
conn,
|
||||
0,
|
||||
0,
|
||||
scfg,
|
||||
func(*Session, utils.Stream) { streamCallbackCalled = true },
|
||||
func(protocol.ConnectionID) { closeCallbackCalled = true },
|
||||
).(*Session)
|
||||
Expect(session.streams).To(HaveLen(1)) // Crypto stream
|
||||
})
|
||||
|
||||
Context("when handling stream frames", func() {
|
||||
@@ -107,8 +112,8 @@ var _ = Describe("Session", func() {
|
||||
StreamID: 5,
|
||||
Data: []byte{0xde, 0xca, 0xfb, 0xad},
|
||||
})
|
||||
Expect(session.streams).To(HaveLen(1))
|
||||
Expect(callbackCalled).To(BeTrue())
|
||||
Expect(session.streams).To(HaveLen(2))
|
||||
Expect(streamCallbackCalled).To(BeTrue())
|
||||
p := make([]byte, 4)
|
||||
_, err := session.streams[5].Read(p)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
@@ -138,14 +143,14 @@ var _ = Describe("Session", func() {
|
||||
StreamID: 5,
|
||||
Data: []byte{0xde, 0xca},
|
||||
})
|
||||
Expect(session.streams).To(HaveLen(1))
|
||||
Expect(callbackCalled).To(BeTrue())
|
||||
Expect(session.streams).To(HaveLen(2))
|
||||
Expect(streamCallbackCalled).To(BeTrue())
|
||||
session.handleStreamFrame(&frames.StreamFrame{
|
||||
StreamID: 5,
|
||||
Offset: 2,
|
||||
Data: []byte{0xfb, 0xad},
|
||||
})
|
||||
Expect(session.streams).To(HaveLen(1))
|
||||
Expect(session.streams).To(HaveLen(2))
|
||||
p := make([]byte, 4)
|
||||
_, err := session.streams[5].Read(p)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
@@ -157,7 +162,7 @@ var _ = Describe("Session", func() {
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str.Close()
|
||||
session.garbageCollectStreams()
|
||||
Expect(session.streams).To(HaveLen(1))
|
||||
Expect(session.streams).To(HaveLen(2))
|
||||
Expect(session.streams[5]).ToNot(BeNil())
|
||||
})
|
||||
|
||||
@@ -167,15 +172,15 @@ var _ = Describe("Session", func() {
|
||||
Data: []byte{0xde, 0xca, 0xfb, 0xad},
|
||||
FinBit: true,
|
||||
})
|
||||
Expect(session.streams).To(HaveLen(1))
|
||||
Expect(session.streams).To(HaveLen(2))
|
||||
Expect(session.streams[5]).ToNot(BeNil())
|
||||
Expect(callbackCalled).To(BeTrue())
|
||||
Expect(streamCallbackCalled).To(BeTrue())
|
||||
p := make([]byte, 4)
|
||||
_, err := session.streams[5].Read(p)
|
||||
Expect(err).To(MatchError(io.EOF))
|
||||
Expect(p).To(Equal([]byte{0xde, 0xca, 0xfb, 0xad}))
|
||||
session.garbageCollectStreams()
|
||||
Expect(session.streams).To(HaveLen(1))
|
||||
Expect(session.streams).To(HaveLen(2))
|
||||
Expect(session.streams[5]).ToNot(BeNil())
|
||||
})
|
||||
|
||||
@@ -185,20 +190,20 @@ var _ = Describe("Session", func() {
|
||||
Data: []byte{0xde, 0xca, 0xfb, 0xad},
|
||||
FinBit: true,
|
||||
})
|
||||
Expect(session.streams).To(HaveLen(1))
|
||||
Expect(session.streams).To(HaveLen(2))
|
||||
Expect(session.streams[5]).ToNot(BeNil())
|
||||
Expect(callbackCalled).To(BeTrue())
|
||||
Expect(streamCallbackCalled).To(BeTrue())
|
||||
p := make([]byte, 4)
|
||||
_, err := session.streams[5].Read(p)
|
||||
Expect(err).To(MatchError(io.EOF))
|
||||
Expect(p).To(Equal([]byte{0xde, 0xca, 0xfb, 0xad}))
|
||||
session.garbageCollectStreams()
|
||||
Expect(session.streams).To(HaveLen(1))
|
||||
Expect(session.streams).To(HaveLen(2))
|
||||
Expect(session.streams[5]).ToNot(BeNil())
|
||||
// We still need to close the stream locally
|
||||
session.streams[5].Close()
|
||||
session.garbageCollectStreams()
|
||||
Expect(session.streams).To(HaveLen(1))
|
||||
Expect(session.streams).To(HaveLen(2))
|
||||
Expect(session.streams[5]).To(BeNil())
|
||||
})
|
||||
|
||||
@@ -208,9 +213,9 @@ var _ = Describe("Session", func() {
|
||||
StreamID: 5,
|
||||
Data: []byte{0xde, 0xca, 0xfb, 0xad},
|
||||
})
|
||||
Expect(session.streams).To(HaveLen(1))
|
||||
Expect(session.streams).To(HaveLen(2))
|
||||
Expect(session.streams[5]).ToNot(BeNil())
|
||||
Expect(callbackCalled).To(BeTrue())
|
||||
Expect(streamCallbackCalled).To(BeTrue())
|
||||
p := make([]byte, 4)
|
||||
_, err := session.streams[5].Read(p)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
@@ -218,7 +223,7 @@ var _ = Describe("Session", func() {
|
||||
_, err = session.streams[5].Read(p)
|
||||
Expect(err).To(MatchError(testErr))
|
||||
session.garbageCollectStreams()
|
||||
Expect(session.streams).To(HaveLen(1))
|
||||
Expect(session.streams).To(HaveLen(2))
|
||||
Expect(session.streams[5]).To(BeNil())
|
||||
})
|
||||
|
||||
@@ -227,14 +232,14 @@ var _ = Describe("Session", func() {
|
||||
session.handleStreamFrame(&frames.StreamFrame{
|
||||
StreamID: 5,
|
||||
})
|
||||
Expect(session.streams).To(HaveLen(1))
|
||||
Expect(session.streams).To(HaveLen(2))
|
||||
Expect(session.streams[5]).ToNot(BeNil())
|
||||
Expect(callbackCalled).To(BeTrue())
|
||||
Expect(streamCallbackCalled).To(BeTrue())
|
||||
session.closeStreamsWithError(testErr)
|
||||
_, err := session.streams[5].Read([]byte{0})
|
||||
Expect(err).To(MatchError(testErr))
|
||||
session.garbageCollectStreams()
|
||||
Expect(session.streams).To(HaveLen(1))
|
||||
Expect(session.streams).To(HaveLen(2))
|
||||
Expect(session.streams[5]).To(BeNil())
|
||||
})
|
||||
|
||||
@@ -314,23 +319,18 @@ var _ = Describe("Session", func() {
|
||||
Context("closing", func() {
|
||||
var (
|
||||
nGoRoutinesBefore int
|
||||
closed bool
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
time.Sleep(10 * time.Millisecond) // Wait for old goroutines to finish
|
||||
signer, err := crypto.NewRSASigner(testdata.GetTLSConfig())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
scfg := handshake.NewServerConfig(crypto.NewCurve25519KEX(), signer)
|
||||
nGoRoutinesBefore = runtime.NumGoroutine()
|
||||
session = newSession(conn, 0, 0, scfg, nil, func(protocol.ConnectionID) { closed = true }).(*Session)
|
||||
go session.run()
|
||||
Eventually(func() int { return runtime.NumGoroutine() }).Should(Equal(nGoRoutinesBefore + 2))
|
||||
})
|
||||
|
||||
It("shuts down without error", func() {
|
||||
session.Close(nil)
|
||||
Expect(closed).To(BeTrue())
|
||||
Expect(closeCallbackCalled).To(BeTrue())
|
||||
Eventually(func() int { return runtime.NumGoroutine() }).Should(Equal(nGoRoutinesBefore))
|
||||
Expect(conn.written).To(HaveLen(1))
|
||||
Expect(conn.written[0][len(conn.written[0])-7:]).To(Equal([]byte{0x02, byte(qerr.PeerGoingAway), 0, 0, 0, 0, 0}))
|
||||
@@ -348,7 +348,7 @@ var _ = Describe("Session", func() {
|
||||
s, err := session.OpenStream(5)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
session.Close(testErr)
|
||||
Expect(closed).To(BeTrue())
|
||||
Expect(closeCallbackCalled).To(BeTrue())
|
||||
Eventually(func() int { return runtime.NumGoroutine() }).Should(Equal(nGoRoutinesBefore))
|
||||
n, err := s.Read([]byte{0})
|
||||
Expect(n).To(BeZero())
|
||||
@@ -360,13 +360,6 @@ var _ = Describe("Session", func() {
|
||||
})
|
||||
|
||||
Context("sending packets", func() {
|
||||
BeforeEach(func() {
|
||||
signer, err := crypto.NewRSASigner(testdata.GetTLSConfig())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
scfg := handshake.NewServerConfig(crypto.NewCurve25519KEX(), signer)
|
||||
session = newSession(conn, 0, 0, scfg, nil, nil).(*Session)
|
||||
})
|
||||
|
||||
It("sends ack frames", func() {
|
||||
packetNumber := protocol.PacketNumber(0x0135)
|
||||
var entropy ackhandler.EntropyAccumulator
|
||||
@@ -453,13 +446,6 @@ var _ = Describe("Session", func() {
|
||||
})
|
||||
|
||||
Context("scheduling sending", func() {
|
||||
BeforeEach(func() {
|
||||
signer, err := crypto.NewRSASigner(testdata.GetTLSConfig())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
scfg := handshake.NewServerConfig(crypto.NewCurve25519KEX(), signer)
|
||||
session = newSession(conn, 0, 0, scfg, nil, func(protocol.ConnectionID) {}).(*Session)
|
||||
})
|
||||
|
||||
It("sends after queuing a stream frame", func() {
|
||||
Expect(session.sendingScheduled).NotTo(Receive())
|
||||
err := session.queueStreamFrame(&frames.StreamFrame{StreamID: 1})
|
||||
@@ -553,10 +539,7 @@ var _ = Describe("Session", func() {
|
||||
})
|
||||
|
||||
It("closes when crypto stream errors", func() {
|
||||
signer, err := crypto.NewRSASigner(testdata.GetTLSConfig())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
scfg := handshake.NewServerConfig(crypto.NewCurve25519KEX(), signer)
|
||||
session = newSession(conn, 0, 0, scfg, nil, func(protocol.ConnectionID) {}).(*Session)
|
||||
go session.run()
|
||||
s, err := session.OpenStream(3)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
err = session.handleStreamFrame(&frames.StreamFrame{
|
||||
@@ -570,11 +553,6 @@ var _ = Describe("Session", func() {
|
||||
})
|
||||
|
||||
It("sends public reset after too many undecryptable packets", func() {
|
||||
signer, err := crypto.NewRSASigner(testdata.GetTLSConfig())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
scfg := handshake.NewServerConfig(crypto.NewCurve25519KEX(), signer)
|
||||
session = newSession(conn, 0, 0, scfg, nil, func(protocol.ConnectionID) {}).(*Session)
|
||||
|
||||
// Write protocol.MaxUndecryptablePackets and expect a public reset to happen
|
||||
for i := 0; i < protocol.MaxUndecryptablePackets; i++ {
|
||||
hdr := &publicHeader{
|
||||
@@ -589,10 +567,6 @@ var _ = Describe("Session", func() {
|
||||
})
|
||||
|
||||
It("unqueues undecryptable packets for later decryption", func() {
|
||||
signer, err := crypto.NewRSASigner(testdata.GetTLSConfig())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
scfg := handshake.NewServerConfig(crypto.NewCurve25519KEX(), signer)
|
||||
session = newSession(conn, 0, 0, scfg, nil, func(protocol.ConnectionID) {}).(*Session)
|
||||
session.undecryptablePackets = []receivedPacket{{
|
||||
nil,
|
||||
&publicHeader{PacketNumber: protocol.PacketNumber(42)},
|
||||
@@ -621,11 +595,6 @@ var _ = Describe("Session", func() {
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
signer, err := crypto.NewRSASigner(testdata.GetTLSConfig())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
scfg := handshake.NewServerConfig(crypto.NewCurve25519KEX(), signer)
|
||||
session = newSession(conn, 0, 0, scfg, nil, func(protocol.ConnectionID) {}).(*Session)
|
||||
|
||||
cong = &mockCongestion{}
|
||||
session.congestion = cong
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user