simplify and reorganize session tests

This commit is contained in:
Lucas Clemente
2016-05-17 19:37:04 +02:00
parent 2864c97a70
commit d6ef71a54c
2 changed files with 52 additions and 81 deletions

View File

@@ -57,6 +57,8 @@ type Session struct {
unpacker *packetUnpacker
packer *packetPacker
cryptoSetup *handshake.CryptoSetup
receivedPackets chan receivedPacket
sendingScheduled chan struct{}
closeChan chan struct{}
@@ -100,22 +102,16 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol
}
cryptoStream, _ := session.OpenStream(1)
cryptoSetup := handshake.NewCryptoSetup(connectionID, v, sCfg, cryptoStream, session.connectionParametersManager, session.aeadChanged)
go func() {
if err := cryptoSetup.HandleCryptoStream(); err != nil {
session.Close(err)
}
}()
session.cryptoSetup = handshake.NewCryptoSetup(connectionID, v, sCfg, cryptoStream, session.connectionParametersManager, session.aeadChanged)
session.packer = &packetPacker{
aead: cryptoSetup,
aead: session.cryptoSetup,
connectionParametersManager: session.connectionParametersManager,
sentPacketHandler: session.sentPacketHandler,
connectionID: connectionID,
version: v,
}
session.unpacker = &packetUnpacker{aead: cryptoSetup, version: v}
session.unpacker = &packetUnpacker{aead: session.cryptoSetup, version: v}
session.congestion = congestion.NewCubicSender(
congestion.DefaultClock{},
@@ -130,6 +126,12 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol
// run the session main loop
func (s *Session) run() {
go func() {
if err := s.cryptoSetup.HandleCryptoStream(); err != nil {
s.Close(err)
}
}()
for {
// Close immediately if requested
select {

View File

@@ -79,26 +79,31 @@ func (m *mockCongestion) SetSlowStartLargeReduction(enabled bool) {
panic("not implemented")
}
// TODO: Reorganize
var _ = Describe("Session", func() {
var (
session *Session
callbackCalled bool
conn *mockConnection
session *Session
streamCallbackCalled bool
closeCallbackCalled bool
conn *mockConnection
)
BeforeEach(func() {
conn = &mockConnection{}
callbackCalled = false
session = &Session{
conn: conn,
streams: make(map[protocol.StreamID]*stream),
streamCallback: func(*Session, utils.Stream) { callbackCalled = true },
connectionParametersManager: handshake.NewConnectionParamatersManager(),
closeChan: make(chan struct{}, 1),
closeCallback: func(protocol.ConnectionID) {},
packer: &packetPacker{aead: &crypto.NullAEAD{}},
}
streamCallbackCalled = false
closeCallbackCalled = false
signer, err := crypto.NewRSASigner(testdata.GetTLSConfig())
Expect(err).ToNot(HaveOccurred())
scfg := handshake.NewServerConfig(crypto.NewCurve25519KEX(), signer)
session = newSession(
conn,
0,
0,
scfg,
func(*Session, utils.Stream) { streamCallbackCalled = true },
func(protocol.ConnectionID) { closeCallbackCalled = true },
).(*Session)
Expect(session.streams).To(HaveLen(1)) // Crypto stream
})
Context("when handling stream frames", func() {
@@ -107,8 +112,8 @@ var _ = Describe("Session", func() {
StreamID: 5,
Data: []byte{0xde, 0xca, 0xfb, 0xad},
})
Expect(session.streams).To(HaveLen(1))
Expect(callbackCalled).To(BeTrue())
Expect(session.streams).To(HaveLen(2))
Expect(streamCallbackCalled).To(BeTrue())
p := make([]byte, 4)
_, err := session.streams[5].Read(p)
Expect(err).ToNot(HaveOccurred())
@@ -138,14 +143,14 @@ var _ = Describe("Session", func() {
StreamID: 5,
Data: []byte{0xde, 0xca},
})
Expect(session.streams).To(HaveLen(1))
Expect(callbackCalled).To(BeTrue())
Expect(session.streams).To(HaveLen(2))
Expect(streamCallbackCalled).To(BeTrue())
session.handleStreamFrame(&frames.StreamFrame{
StreamID: 5,
Offset: 2,
Data: []byte{0xfb, 0xad},
})
Expect(session.streams).To(HaveLen(1))
Expect(session.streams).To(HaveLen(2))
p := make([]byte, 4)
_, err := session.streams[5].Read(p)
Expect(err).ToNot(HaveOccurred())
@@ -157,7 +162,7 @@ var _ = Describe("Session", func() {
Expect(err).ToNot(HaveOccurred())
str.Close()
session.garbageCollectStreams()
Expect(session.streams).To(HaveLen(1))
Expect(session.streams).To(HaveLen(2))
Expect(session.streams[5]).ToNot(BeNil())
})
@@ -167,15 +172,15 @@ var _ = Describe("Session", func() {
Data: []byte{0xde, 0xca, 0xfb, 0xad},
FinBit: true,
})
Expect(session.streams).To(HaveLen(1))
Expect(session.streams).To(HaveLen(2))
Expect(session.streams[5]).ToNot(BeNil())
Expect(callbackCalled).To(BeTrue())
Expect(streamCallbackCalled).To(BeTrue())
p := make([]byte, 4)
_, err := session.streams[5].Read(p)
Expect(err).To(MatchError(io.EOF))
Expect(p).To(Equal([]byte{0xde, 0xca, 0xfb, 0xad}))
session.garbageCollectStreams()
Expect(session.streams).To(HaveLen(1))
Expect(session.streams).To(HaveLen(2))
Expect(session.streams[5]).ToNot(BeNil())
})
@@ -185,20 +190,20 @@ var _ = Describe("Session", func() {
Data: []byte{0xde, 0xca, 0xfb, 0xad},
FinBit: true,
})
Expect(session.streams).To(HaveLen(1))
Expect(session.streams).To(HaveLen(2))
Expect(session.streams[5]).ToNot(BeNil())
Expect(callbackCalled).To(BeTrue())
Expect(streamCallbackCalled).To(BeTrue())
p := make([]byte, 4)
_, err := session.streams[5].Read(p)
Expect(err).To(MatchError(io.EOF))
Expect(p).To(Equal([]byte{0xde, 0xca, 0xfb, 0xad}))
session.garbageCollectStreams()
Expect(session.streams).To(HaveLen(1))
Expect(session.streams).To(HaveLen(2))
Expect(session.streams[5]).ToNot(BeNil())
// We still need to close the stream locally
session.streams[5].Close()
session.garbageCollectStreams()
Expect(session.streams).To(HaveLen(1))
Expect(session.streams).To(HaveLen(2))
Expect(session.streams[5]).To(BeNil())
})
@@ -208,9 +213,9 @@ var _ = Describe("Session", func() {
StreamID: 5,
Data: []byte{0xde, 0xca, 0xfb, 0xad},
})
Expect(session.streams).To(HaveLen(1))
Expect(session.streams).To(HaveLen(2))
Expect(session.streams[5]).ToNot(BeNil())
Expect(callbackCalled).To(BeTrue())
Expect(streamCallbackCalled).To(BeTrue())
p := make([]byte, 4)
_, err := session.streams[5].Read(p)
Expect(err).ToNot(HaveOccurred())
@@ -218,7 +223,7 @@ var _ = Describe("Session", func() {
_, err = session.streams[5].Read(p)
Expect(err).To(MatchError(testErr))
session.garbageCollectStreams()
Expect(session.streams).To(HaveLen(1))
Expect(session.streams).To(HaveLen(2))
Expect(session.streams[5]).To(BeNil())
})
@@ -227,14 +232,14 @@ var _ = Describe("Session", func() {
session.handleStreamFrame(&frames.StreamFrame{
StreamID: 5,
})
Expect(session.streams).To(HaveLen(1))
Expect(session.streams).To(HaveLen(2))
Expect(session.streams[5]).ToNot(BeNil())
Expect(callbackCalled).To(BeTrue())
Expect(streamCallbackCalled).To(BeTrue())
session.closeStreamsWithError(testErr)
_, err := session.streams[5].Read([]byte{0})
Expect(err).To(MatchError(testErr))
session.garbageCollectStreams()
Expect(session.streams).To(HaveLen(1))
Expect(session.streams).To(HaveLen(2))
Expect(session.streams[5]).To(BeNil())
})
@@ -314,23 +319,18 @@ var _ = Describe("Session", func() {
Context("closing", func() {
var (
nGoRoutinesBefore int
closed bool
)
BeforeEach(func() {
time.Sleep(10 * time.Millisecond) // Wait for old goroutines to finish
signer, err := crypto.NewRSASigner(testdata.GetTLSConfig())
Expect(err).ToNot(HaveOccurred())
scfg := handshake.NewServerConfig(crypto.NewCurve25519KEX(), signer)
nGoRoutinesBefore = runtime.NumGoroutine()
session = newSession(conn, 0, 0, scfg, nil, func(protocol.ConnectionID) { closed = true }).(*Session)
go session.run()
Eventually(func() int { return runtime.NumGoroutine() }).Should(Equal(nGoRoutinesBefore + 2))
})
It("shuts down without error", func() {
session.Close(nil)
Expect(closed).To(BeTrue())
Expect(closeCallbackCalled).To(BeTrue())
Eventually(func() int { return runtime.NumGoroutine() }).Should(Equal(nGoRoutinesBefore))
Expect(conn.written).To(HaveLen(1))
Expect(conn.written[0][len(conn.written[0])-7:]).To(Equal([]byte{0x02, byte(qerr.PeerGoingAway), 0, 0, 0, 0, 0}))
@@ -348,7 +348,7 @@ var _ = Describe("Session", func() {
s, err := session.OpenStream(5)
Expect(err).NotTo(HaveOccurred())
session.Close(testErr)
Expect(closed).To(BeTrue())
Expect(closeCallbackCalled).To(BeTrue())
Eventually(func() int { return runtime.NumGoroutine() }).Should(Equal(nGoRoutinesBefore))
n, err := s.Read([]byte{0})
Expect(n).To(BeZero())
@@ -360,13 +360,6 @@ var _ = Describe("Session", func() {
})
Context("sending packets", func() {
BeforeEach(func() {
signer, err := crypto.NewRSASigner(testdata.GetTLSConfig())
Expect(err).ToNot(HaveOccurred())
scfg := handshake.NewServerConfig(crypto.NewCurve25519KEX(), signer)
session = newSession(conn, 0, 0, scfg, nil, nil).(*Session)
})
It("sends ack frames", func() {
packetNumber := protocol.PacketNumber(0x0135)
var entropy ackhandler.EntropyAccumulator
@@ -453,13 +446,6 @@ var _ = Describe("Session", func() {
})
Context("scheduling sending", func() {
BeforeEach(func() {
signer, err := crypto.NewRSASigner(testdata.GetTLSConfig())
Expect(err).ToNot(HaveOccurred())
scfg := handshake.NewServerConfig(crypto.NewCurve25519KEX(), signer)
session = newSession(conn, 0, 0, scfg, nil, func(protocol.ConnectionID) {}).(*Session)
})
It("sends after queuing a stream frame", func() {
Expect(session.sendingScheduled).NotTo(Receive())
err := session.queueStreamFrame(&frames.StreamFrame{StreamID: 1})
@@ -553,10 +539,7 @@ var _ = Describe("Session", func() {
})
It("closes when crypto stream errors", func() {
signer, err := crypto.NewRSASigner(testdata.GetTLSConfig())
Expect(err).ToNot(HaveOccurred())
scfg := handshake.NewServerConfig(crypto.NewCurve25519KEX(), signer)
session = newSession(conn, 0, 0, scfg, nil, func(protocol.ConnectionID) {}).(*Session)
go session.run()
s, err := session.OpenStream(3)
Expect(err).NotTo(HaveOccurred())
err = session.handleStreamFrame(&frames.StreamFrame{
@@ -570,11 +553,6 @@ var _ = Describe("Session", func() {
})
It("sends public reset after too many undecryptable packets", func() {
signer, err := crypto.NewRSASigner(testdata.GetTLSConfig())
Expect(err).ToNot(HaveOccurred())
scfg := handshake.NewServerConfig(crypto.NewCurve25519KEX(), signer)
session = newSession(conn, 0, 0, scfg, nil, func(protocol.ConnectionID) {}).(*Session)
// Write protocol.MaxUndecryptablePackets and expect a public reset to happen
for i := 0; i < protocol.MaxUndecryptablePackets; i++ {
hdr := &publicHeader{
@@ -589,10 +567,6 @@ var _ = Describe("Session", func() {
})
It("unqueues undecryptable packets for later decryption", func() {
signer, err := crypto.NewRSASigner(testdata.GetTLSConfig())
Expect(err).ToNot(HaveOccurred())
scfg := handshake.NewServerConfig(crypto.NewCurve25519KEX(), signer)
session = newSession(conn, 0, 0, scfg, nil, func(protocol.ConnectionID) {}).(*Session)
session.undecryptablePackets = []receivedPacket{{
nil,
&publicHeader{PacketNumber: protocol.PacketNumber(42)},
@@ -621,11 +595,6 @@ var _ = Describe("Session", func() {
)
BeforeEach(func() {
signer, err := crypto.NewRSASigner(testdata.GetTLSConfig())
Expect(err).ToNot(HaveOccurred())
scfg := handshake.NewServerConfig(crypto.NewCurve25519KEX(), signer)
session = newSession(conn, 0, 0, scfg, nil, func(protocol.ConnectionID) {}).(*Session)
cong = &mockCongestion{}
session.congestion = cong
})