From d6ef71a54c793bd604f63de3c210bace6b565c79 Mon Sep 17 00:00:00 2001 From: Lucas Clemente Date: Tue, 17 May 2016 19:37:04 +0200 Subject: [PATCH] simplify and reorganize session tests --- session.go | 20 +++++---- session_test.go | 113 ++++++++++++++++++------------------------------ 2 files changed, 52 insertions(+), 81 deletions(-) diff --git a/session.go b/session.go index a3cf1e829..902d6acbd 100644 --- a/session.go +++ b/session.go @@ -57,6 +57,8 @@ type Session struct { unpacker *packetUnpacker packer *packetPacker + cryptoSetup *handshake.CryptoSetup + receivedPackets chan receivedPacket sendingScheduled chan struct{} closeChan chan struct{} @@ -100,22 +102,16 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol } cryptoStream, _ := session.OpenStream(1) - cryptoSetup := handshake.NewCryptoSetup(connectionID, v, sCfg, cryptoStream, session.connectionParametersManager, session.aeadChanged) - - go func() { - if err := cryptoSetup.HandleCryptoStream(); err != nil { - session.Close(err) - } - }() + session.cryptoSetup = handshake.NewCryptoSetup(connectionID, v, sCfg, cryptoStream, session.connectionParametersManager, session.aeadChanged) session.packer = &packetPacker{ - aead: cryptoSetup, + aead: session.cryptoSetup, connectionParametersManager: session.connectionParametersManager, sentPacketHandler: session.sentPacketHandler, connectionID: connectionID, version: v, } - session.unpacker = &packetUnpacker{aead: cryptoSetup, version: v} + session.unpacker = &packetUnpacker{aead: session.cryptoSetup, version: v} session.congestion = congestion.NewCubicSender( congestion.DefaultClock{}, @@ -130,6 +126,12 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol // run the session main loop func (s *Session) run() { + go func() { + if err := s.cryptoSetup.HandleCryptoStream(); err != nil { + s.Close(err) + } + }() + for { // Close immediately if requested select { diff --git a/session_test.go b/session_test.go index b7e4252fe..c48af4353 100644 --- a/session_test.go +++ b/session_test.go @@ -79,26 +79,31 @@ func (m *mockCongestion) SetSlowStartLargeReduction(enabled bool) { panic("not implemented") } -// TODO: Reorganize var _ = Describe("Session", func() { var ( - session *Session - callbackCalled bool - conn *mockConnection + session *Session + streamCallbackCalled bool + closeCallbackCalled bool + conn *mockConnection ) BeforeEach(func() { conn = &mockConnection{} - callbackCalled = false - session = &Session{ - conn: conn, - streams: make(map[protocol.StreamID]*stream), - streamCallback: func(*Session, utils.Stream) { callbackCalled = true }, - connectionParametersManager: handshake.NewConnectionParamatersManager(), - closeChan: make(chan struct{}, 1), - closeCallback: func(protocol.ConnectionID) {}, - packer: &packetPacker{aead: &crypto.NullAEAD{}}, - } + streamCallbackCalled = false + closeCallbackCalled = false + + signer, err := crypto.NewRSASigner(testdata.GetTLSConfig()) + Expect(err).ToNot(HaveOccurred()) + scfg := handshake.NewServerConfig(crypto.NewCurve25519KEX(), signer) + session = newSession( + conn, + 0, + 0, + scfg, + func(*Session, utils.Stream) { streamCallbackCalled = true }, + func(protocol.ConnectionID) { closeCallbackCalled = true }, + ).(*Session) + Expect(session.streams).To(HaveLen(1)) // Crypto stream }) Context("when handling stream frames", func() { @@ -107,8 +112,8 @@ var _ = Describe("Session", func() { StreamID: 5, Data: []byte{0xde, 0xca, 0xfb, 0xad}, }) - Expect(session.streams).To(HaveLen(1)) - Expect(callbackCalled).To(BeTrue()) + Expect(session.streams).To(HaveLen(2)) + Expect(streamCallbackCalled).To(BeTrue()) p := make([]byte, 4) _, err := session.streams[5].Read(p) Expect(err).ToNot(HaveOccurred()) @@ -138,14 +143,14 @@ var _ = Describe("Session", func() { StreamID: 5, Data: []byte{0xde, 0xca}, }) - Expect(session.streams).To(HaveLen(1)) - Expect(callbackCalled).To(BeTrue()) + Expect(session.streams).To(HaveLen(2)) + Expect(streamCallbackCalled).To(BeTrue()) session.handleStreamFrame(&frames.StreamFrame{ StreamID: 5, Offset: 2, Data: []byte{0xfb, 0xad}, }) - Expect(session.streams).To(HaveLen(1)) + Expect(session.streams).To(HaveLen(2)) p := make([]byte, 4) _, err := session.streams[5].Read(p) Expect(err).ToNot(HaveOccurred()) @@ -157,7 +162,7 @@ var _ = Describe("Session", func() { Expect(err).ToNot(HaveOccurred()) str.Close() session.garbageCollectStreams() - Expect(session.streams).To(HaveLen(1)) + Expect(session.streams).To(HaveLen(2)) Expect(session.streams[5]).ToNot(BeNil()) }) @@ -167,15 +172,15 @@ var _ = Describe("Session", func() { Data: []byte{0xde, 0xca, 0xfb, 0xad}, FinBit: true, }) - Expect(session.streams).To(HaveLen(1)) + Expect(session.streams).To(HaveLen(2)) Expect(session.streams[5]).ToNot(BeNil()) - Expect(callbackCalled).To(BeTrue()) + Expect(streamCallbackCalled).To(BeTrue()) p := make([]byte, 4) _, err := session.streams[5].Read(p) Expect(err).To(MatchError(io.EOF)) Expect(p).To(Equal([]byte{0xde, 0xca, 0xfb, 0xad})) session.garbageCollectStreams() - Expect(session.streams).To(HaveLen(1)) + Expect(session.streams).To(HaveLen(2)) Expect(session.streams[5]).ToNot(BeNil()) }) @@ -185,20 +190,20 @@ var _ = Describe("Session", func() { Data: []byte{0xde, 0xca, 0xfb, 0xad}, FinBit: true, }) - Expect(session.streams).To(HaveLen(1)) + Expect(session.streams).To(HaveLen(2)) Expect(session.streams[5]).ToNot(BeNil()) - Expect(callbackCalled).To(BeTrue()) + Expect(streamCallbackCalled).To(BeTrue()) p := make([]byte, 4) _, err := session.streams[5].Read(p) Expect(err).To(MatchError(io.EOF)) Expect(p).To(Equal([]byte{0xde, 0xca, 0xfb, 0xad})) session.garbageCollectStreams() - Expect(session.streams).To(HaveLen(1)) + Expect(session.streams).To(HaveLen(2)) Expect(session.streams[5]).ToNot(BeNil()) // We still need to close the stream locally session.streams[5].Close() session.garbageCollectStreams() - Expect(session.streams).To(HaveLen(1)) + Expect(session.streams).To(HaveLen(2)) Expect(session.streams[5]).To(BeNil()) }) @@ -208,9 +213,9 @@ var _ = Describe("Session", func() { StreamID: 5, Data: []byte{0xde, 0xca, 0xfb, 0xad}, }) - Expect(session.streams).To(HaveLen(1)) + Expect(session.streams).To(HaveLen(2)) Expect(session.streams[5]).ToNot(BeNil()) - Expect(callbackCalled).To(BeTrue()) + Expect(streamCallbackCalled).To(BeTrue()) p := make([]byte, 4) _, err := session.streams[5].Read(p) Expect(err).ToNot(HaveOccurred()) @@ -218,7 +223,7 @@ var _ = Describe("Session", func() { _, err = session.streams[5].Read(p) Expect(err).To(MatchError(testErr)) session.garbageCollectStreams() - Expect(session.streams).To(HaveLen(1)) + Expect(session.streams).To(HaveLen(2)) Expect(session.streams[5]).To(BeNil()) }) @@ -227,14 +232,14 @@ var _ = Describe("Session", func() { session.handleStreamFrame(&frames.StreamFrame{ StreamID: 5, }) - Expect(session.streams).To(HaveLen(1)) + Expect(session.streams).To(HaveLen(2)) Expect(session.streams[5]).ToNot(BeNil()) - Expect(callbackCalled).To(BeTrue()) + Expect(streamCallbackCalled).To(BeTrue()) session.closeStreamsWithError(testErr) _, err := session.streams[5].Read([]byte{0}) Expect(err).To(MatchError(testErr)) session.garbageCollectStreams() - Expect(session.streams).To(HaveLen(1)) + Expect(session.streams).To(HaveLen(2)) Expect(session.streams[5]).To(BeNil()) }) @@ -314,23 +319,18 @@ var _ = Describe("Session", func() { Context("closing", func() { var ( nGoRoutinesBefore int - closed bool ) BeforeEach(func() { time.Sleep(10 * time.Millisecond) // Wait for old goroutines to finish - signer, err := crypto.NewRSASigner(testdata.GetTLSConfig()) - Expect(err).ToNot(HaveOccurred()) - scfg := handshake.NewServerConfig(crypto.NewCurve25519KEX(), signer) nGoRoutinesBefore = runtime.NumGoroutine() - session = newSession(conn, 0, 0, scfg, nil, func(protocol.ConnectionID) { closed = true }).(*Session) go session.run() Eventually(func() int { return runtime.NumGoroutine() }).Should(Equal(nGoRoutinesBefore + 2)) }) It("shuts down without error", func() { session.Close(nil) - Expect(closed).To(BeTrue()) + Expect(closeCallbackCalled).To(BeTrue()) Eventually(func() int { return runtime.NumGoroutine() }).Should(Equal(nGoRoutinesBefore)) Expect(conn.written).To(HaveLen(1)) Expect(conn.written[0][len(conn.written[0])-7:]).To(Equal([]byte{0x02, byte(qerr.PeerGoingAway), 0, 0, 0, 0, 0})) @@ -348,7 +348,7 @@ var _ = Describe("Session", func() { s, err := session.OpenStream(5) Expect(err).NotTo(HaveOccurred()) session.Close(testErr) - Expect(closed).To(BeTrue()) + Expect(closeCallbackCalled).To(BeTrue()) Eventually(func() int { return runtime.NumGoroutine() }).Should(Equal(nGoRoutinesBefore)) n, err := s.Read([]byte{0}) Expect(n).To(BeZero()) @@ -360,13 +360,6 @@ var _ = Describe("Session", func() { }) Context("sending packets", func() { - BeforeEach(func() { - signer, err := crypto.NewRSASigner(testdata.GetTLSConfig()) - Expect(err).ToNot(HaveOccurred()) - scfg := handshake.NewServerConfig(crypto.NewCurve25519KEX(), signer) - session = newSession(conn, 0, 0, scfg, nil, nil).(*Session) - }) - It("sends ack frames", func() { packetNumber := protocol.PacketNumber(0x0135) var entropy ackhandler.EntropyAccumulator @@ -453,13 +446,6 @@ var _ = Describe("Session", func() { }) Context("scheduling sending", func() { - BeforeEach(func() { - signer, err := crypto.NewRSASigner(testdata.GetTLSConfig()) - Expect(err).ToNot(HaveOccurred()) - scfg := handshake.NewServerConfig(crypto.NewCurve25519KEX(), signer) - session = newSession(conn, 0, 0, scfg, nil, func(protocol.ConnectionID) {}).(*Session) - }) - It("sends after queuing a stream frame", func() { Expect(session.sendingScheduled).NotTo(Receive()) err := session.queueStreamFrame(&frames.StreamFrame{StreamID: 1}) @@ -553,10 +539,7 @@ var _ = Describe("Session", func() { }) It("closes when crypto stream errors", func() { - signer, err := crypto.NewRSASigner(testdata.GetTLSConfig()) - Expect(err).ToNot(HaveOccurred()) - scfg := handshake.NewServerConfig(crypto.NewCurve25519KEX(), signer) - session = newSession(conn, 0, 0, scfg, nil, func(protocol.ConnectionID) {}).(*Session) + go session.run() s, err := session.OpenStream(3) Expect(err).NotTo(HaveOccurred()) err = session.handleStreamFrame(&frames.StreamFrame{ @@ -570,11 +553,6 @@ var _ = Describe("Session", func() { }) It("sends public reset after too many undecryptable packets", func() { - signer, err := crypto.NewRSASigner(testdata.GetTLSConfig()) - Expect(err).ToNot(HaveOccurred()) - scfg := handshake.NewServerConfig(crypto.NewCurve25519KEX(), signer) - session = newSession(conn, 0, 0, scfg, nil, func(protocol.ConnectionID) {}).(*Session) - // Write protocol.MaxUndecryptablePackets and expect a public reset to happen for i := 0; i < protocol.MaxUndecryptablePackets; i++ { hdr := &publicHeader{ @@ -589,10 +567,6 @@ var _ = Describe("Session", func() { }) It("unqueues undecryptable packets for later decryption", func() { - signer, err := crypto.NewRSASigner(testdata.GetTLSConfig()) - Expect(err).ToNot(HaveOccurred()) - scfg := handshake.NewServerConfig(crypto.NewCurve25519KEX(), signer) - session = newSession(conn, 0, 0, scfg, nil, func(protocol.ConnectionID) {}).(*Session) session.undecryptablePackets = []receivedPacket{{ nil, &publicHeader{PacketNumber: protocol.PacketNumber(42)}, @@ -621,11 +595,6 @@ var _ = Describe("Session", func() { ) BeforeEach(func() { - signer, err := crypto.NewRSASigner(testdata.GetTLSConfig()) - Expect(err).ToNot(HaveOccurred()) - scfg := handshake.NewServerConfig(crypto.NewCurve25519KEX(), signer) - session = newSession(conn, 0, 0, scfg, nil, func(protocol.ConnectionID) {}).(*Session) - cong = &mockCongestion{} session.congestion = cong })