From 9fad63ff5058f872e00db5852c08490ed2a1f7bf Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Fri, 12 May 2017 18:40:53 +0800 Subject: [PATCH] improve client tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Use a mock newClientSession. That way, it’s a lot easier to test dialing new connections. --- client_test.go | 356 +++++++++++++++++++++++++++++++----------------- server_test.go | 42 +++--- session.go | 5 +- session_test.go | 4 +- 4 files changed, 263 insertions(+), 144 deletions(-) diff --git a/client_test.go b/client_test.go index de20e6eb..bf40d044 100644 --- a/client_test.go +++ b/client_test.go @@ -4,8 +4,6 @@ import ( "bytes" "errors" "net" - "reflect" - "unsafe" "github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/qerr" @@ -21,27 +19,34 @@ var _ = Describe("Client", func() { sess *mockSession packetConn *mockPacketConn addr net.Addr + + originalClientSessConstructor func(conn connection, hostname string, v protocol.VersionNumber, connectionID protocol.ConnectionID, config *Config, negotiatedVersions []protocol.VersionNumber) (packetHandler, <-chan handshakeEvent, error) ) BeforeEach(func() { + originalClientSessConstructor = newClientSession Eventually(areSessionsRunning).Should(BeFalse()) + msess, _, _ := newMockSession(nil, 0, 0, nil, nil) + sess = msess.(*mockSession) packetConn = &mockPacketConn{} config = &Config{ Versions: []protocol.VersionNumber{protocol.SupportedVersions[0], 77, 78}, } addr = &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337} - sess = &mockSession{connectionID: 0x1337} cl = &client{ - config: config, - connectionID: 0x1337, - session: sess, - version: protocol.SupportedVersions[0], - conn: &conn{pconn: packetConn, currentAddr: addr}, - errorChan: make(chan struct{}), - handshakeChan: make(chan handshakeEvent), + config: config, + connectionID: 0x1337, + session: sess, + version: protocol.SupportedVersions[0], + conn: &conn{pconn: packetConn, currentAddr: addr}, + errorChan: make(chan struct{}), } }) + AfterEach(func() { + newClientSession = originalClientSessConstructor + }) + AfterEach(func() { if s, ok := cl.session.(*session); ok { s.Close(nil) @@ -50,13 +55,88 @@ var _ = Describe("Client", func() { }) Context("Dialing", func() { - PIt("creates a new client", func() { - packetConn.dataToRead = []byte{0x0, 0x1, 0x0} - sess, err := Dial(packetConn, addr, "quic.clemente.io:1337", config) - Expect(err).ToNot(HaveOccurred()) - Expect(*(*[]protocol.VersionNumber)(unsafe.Pointer(reflect.ValueOf(sess.(*session).cryptoSetup).Elem().FieldByName("negotiatedVersions").UnsafeAddr()))).To(BeNil()) - Expect(*(*string)(unsafe.Pointer(reflect.ValueOf(sess.(*session).cryptoSetup).Elem().FieldByName("hostname").UnsafeAddr()))).To(Equal("quic.clemente.io")) - sess.Close(nil) + BeforeEach(func() { + newClientSession = func( + _ connection, + _ string, + _ protocol.VersionNumber, + _ protocol.ConnectionID, + _ *Config, + _ []protocol.VersionNumber, + ) (packetHandler, <-chan handshakeEvent, error) { + return sess, sess.handshakeChan, nil + } + }) + + It("dials non-forward-secure", func(done Done) { + var dialedSess Session + go func() { + defer GinkgoRecover() + var err error + dialedSess, err = DialNonFWSecure(packetConn, addr, "quic.clemente.io:1337", config) + Expect(err).ToNot(HaveOccurred()) + }() + Consistently(func() Session { return dialedSess }).Should(BeNil()) + sess.handshakeChan <- handshakeEvent{encLevel: protocol.EncryptionSecure} + Eventually(func() Session { return dialedSess }).ShouldNot(BeNil()) + close(done) + }) + + It("Dial only returns after the handshake is complete", func(done Done) { + var dialedSess Session + go func() { + defer GinkgoRecover() + var err error + dialedSess, err = Dial(packetConn, addr, "quic.clemente.io:1337", config) + Expect(err).ToNot(HaveOccurred()) + }() + sess.handshakeChan <- handshakeEvent{encLevel: protocol.EncryptionSecure} + Consistently(func() Session { return dialedSess }).Should(BeNil()) + close(sess.handshakeComplete) + Eventually(func() Session { return dialedSess }).ShouldNot(BeNil()) + close(done) + }) + + It("resolves the address", func(done Done) { + var cconn connection + newClientSession = func( + conn connection, + _ string, + _ protocol.VersionNumber, + _ protocol.ConnectionID, + _ *Config, + _ []protocol.VersionNumber, + ) (packetHandler, <-chan handshakeEvent, error) { + cconn = conn + return sess, nil, nil + } + go DialAddr("localhost:17890", &Config{}) + Eventually(func() connection { return cconn }).ShouldNot(BeNil()) + Expect(cconn.RemoteAddr().String()).To(Equal("127.0.0.1:17890")) + close(done) + }) + + It("returns an error that occurs while waiting for the connection to become secure", func(done Done) { + testErr := errors.New("early handshake error") + var dialErr error + go func() { + _, dialErr = Dial(packetConn, addr, "quic.clemente.io:1337", config) + }() + sess.handshakeChan <- handshakeEvent{err: testErr} + Eventually(func() error { return dialErr }).Should(MatchError(testErr)) + close(done) + }) + + It("returns an error that occurs while waiting for the handshake to complete", func(done Done) { + testErr := errors.New("late handshake error") + var dialErr error + go func() { + _, dialErr = Dial(packetConn, addr, "quic.clemente.io:1337", config) + }() + sess.handshakeChan <- handshakeEvent{encLevel: protocol.EncryptionSecure} + sess.handshakeComplete <- testErr + Eventually(func() error { return dialErr }).Should(MatchError(testErr)) + close(done) }) It("uses all supported versions, if none are specified in the quic.Config", func() { @@ -64,18 +144,121 @@ var _ = Describe("Client", func() { Expect(c.Versions).To(Equal(protocol.SupportedVersions)) }) - It("errors when receiving an invalid first packet from the server", func() { + It("errors when receiving an invalid first packet from the server", func(done Done) { packetConn.dataToRead = []byte{0xff} - sess, err := Dial(packetConn, addr, "quic.clemente.io:1337", config) + _, err := Dial(packetConn, addr, "quic.clemente.io:1337", config) Expect(err).To(HaveOccurred()) - Expect(sess).To(BeNil()) + close(done) }) - It("errors when receiving an error from the connection", func() { + It("errors when receiving an error from the connection", func(done Done) { testErr := errors.New("connection error") packetConn.readErr = testErr _, err := Dial(packetConn, addr, "quic.clemente.io:1337", config) Expect(err).To(MatchError(testErr)) + close(done) + }) + + It("errors if it can't create a session", func() { + testErr := errors.New("error creating session") + newClientSession = func( + _ connection, + _ string, + _ protocol.VersionNumber, + _ protocol.ConnectionID, + _ *Config, + _ []protocol.VersionNumber, + ) (packetHandler, <-chan handshakeEvent, error) { + return nil, nil, testErr + } + _, err := DialNonFWSecure(packetConn, addr, "quic.clemente.io:1337", config) + Expect(err).To(MatchError(testErr)) + }) + + Context("version negotiation", func() { + It("recognizes that a packet without VersionFlag means that the server accepted the suggested version", func() { + ph := PublicHeader{ + PacketNumber: 1, + PacketNumberLen: protocol.PacketNumberLen2, + ConnectionID: 0x1337, + } + b := &bytes.Buffer{} + err := ph.Write(b, protocol.VersionWhatever, protocol.PerspectiveServer) + Expect(err).ToNot(HaveOccurred()) + err = cl.handlePacket(nil, b.Bytes()) + Expect(err).ToNot(HaveOccurred()) + Expect(cl.versionNegotiated).To(BeTrue()) + }) + + It("changes the version after receiving a version negotiation packet", func() { + var negotiatedVersions []protocol.VersionNumber + newClientSession = func( + _ connection, + _ string, + _ protocol.VersionNumber, + connectionID protocol.ConnectionID, + _ *Config, + negotiatedVersionsP []protocol.VersionNumber, + ) (packetHandler, <-chan handshakeEvent, error) { + negotiatedVersions = negotiatedVersionsP + return &mockSession{ + connectionID: connectionID, + }, nil, nil + } + + newVersion := protocol.VersionNumber(77) + Expect(config.Versions).To(ContainElement(newVersion)) + Expect(newVersion).ToNot(Equal(cl.version)) + Expect(sess.packetCount).To(BeZero()) + cl.connectionID = 0x1337 + err := cl.handlePacket(nil, composeVersionNegotiation(0x1337, []protocol.VersionNumber{newVersion})) + Expect(err).ToNot(HaveOccurred()) + Expect(cl.version).To(Equal(newVersion)) + Expect(cl.versionNegotiated).To(BeTrue()) + // it swapped the sessions + // Expect(cl.session).ToNot(Equal(sess)) + Expect(cl.connectionID).ToNot(Equal(0x1337)) // it generated a new connection ID + Expect(err).ToNot(HaveOccurred()) + // it didn't pass the version negoation packet to the old session (since it has no payload) + Expect(sess.packetCount).To(BeZero()) + Expect(negotiatedVersions).To(Equal([]protocol.VersionNumber{newVersion})) + }) + + It("errors if no matching version is found", func() { + err := cl.handlePacket(nil, composeVersionNegotiation(0x1337, []protocol.VersionNumber{1})) + Expect(err).To(MatchError(qerr.InvalidVersion)) + }) + + It("errors if the version is supported by quic-go, but disabled by the quic.Config", func() { + v := protocol.SupportedVersions[1] + Expect(v).ToNot(Equal(cl.version)) + Expect(config.Versions).ToNot(ContainElement(v)) + err := cl.handlePacket(nil, composeVersionNegotiation(0x1337, []protocol.VersionNumber{v})) + Expect(err).To(MatchError(qerr.InvalidVersion)) + }) + + It("changes to the version preferred by the quic.Config", func() { + err := cl.handlePacket(nil, composeVersionNegotiation(0x1337, []protocol.VersionNumber{config.Versions[2], config.Versions[1]})) + Expect(err).ToNot(HaveOccurred()) + Expect(cl.version).To(Equal(config.Versions[1])) + }) + + It("ignores delayed version negotiation packets", func() { + // if the version was not yet negotiated, handlePacket would return a VersionNegotiationMismatch error, see above test + cl.versionNegotiated = true + Expect(sess.packetCount).To(BeZero()) + err := cl.handlePacket(nil, composeVersionNegotiation(0x1337, []protocol.VersionNumber{1})) + Expect(err).ToNot(HaveOccurred()) + Expect(cl.versionNegotiated).To(BeTrue()) + Expect(sess.packetCount).To(BeZero()) + }) + + It("drops version negotiation packets that contain the offered version", func() { + ver := cl.version + err := cl.handlePacket(nil, composeVersionNegotiation(0x1337, []protocol.VersionNumber{ver})) + Expect(err).ToNot(HaveOccurred()) + Expect(cl.version).To(Equal(ver)) + }) }) }) @@ -84,39 +267,37 @@ var _ = Describe("Client", func() { Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.InvalidPacketHeader)) }) - // this test requires a real session - // and a real UDP conn (because it unblocks and errors when it is closed) - PIt("properly closes", func(done Done) { - Eventually(areSessionsRunning).Should(BeFalse()) - udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)}) - Expect(err).ToNot(HaveOccurred()) - cl.conn = &conn{pconn: udpConn, currentAddr: &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}} - err = cl.createNewSession(nil) - Expect(err).ToNot(HaveOccurred()) - Eventually(areSessionsRunning).Should(BeTrue()) - - var stoppedListening bool + It("creates new sessions with the right parameters", func(done Done) { + c := make(chan struct{}) + var cconn connection + var hostname string + var version protocol.VersionNumber + var conf *Config + newClientSession = func( + connP connection, + hostnameP string, + versionP protocol.VersionNumber, + _ protocol.ConnectionID, + configP *Config, + _ []protocol.VersionNumber, + ) (packetHandler, <-chan handshakeEvent, error) { + cconn = connP + hostname = hostnameP + version = versionP + conf = configP + close(c) + return sess, nil, nil + } go func() { - cl.listen() - stoppedListening = true + _, err := Dial(packetConn, addr, "quic.clemente.io:1337", config) + Expect(err).ToNot(HaveOccurred()) }() - - testErr := errors.New("test error") - err = cl.session.Close(testErr) - Expect(err).ToNot(HaveOccurred()) - Eventually(func() bool { return stoppedListening }).Should(BeTrue()) - Eventually(areSessionsRunning).Should(BeFalse()) + <-c + Expect(cconn.(*conn).pconn).To(Equal(packetConn)) + Expect(hostname).To(Equal("quic.clemente.io")) + Expect(version).To(Equal(cl.version)) + Expect(conf).To(Equal(config)) close(done) - }, 10) - - It("creates new sessions with the right parameters", func() { - cl.session = nil - cl.hostname = "hostname" - err := cl.createNewSession(nil) - Expect(err).ToNot(HaveOccurred()) - Expect(cl.session).ToNot(BeNil()) - Expect(cl.session.(*session).connectionID).To(Equal(cl.connectionID)) - Expect(cl.session.(*session).version).To(Equal(cl.version)) }) Context("handling packets", func() { @@ -160,77 +341,4 @@ var _ = Describe("Client", func() { Expect(sess.closeReason).To(MatchError(testErr)) }) }) - - Context("version negotiation", func() { - It("recognizes that a packet without VersionFlag means that the server accepted the suggested version", func() { - ph := PublicHeader{ - PacketNumber: 1, - PacketNumberLen: protocol.PacketNumberLen2, - ConnectionID: 0x1337, - } - b := &bytes.Buffer{} - err := ph.Write(b, protocol.VersionWhatever, protocol.PerspectiveServer) - Expect(err).ToNot(HaveOccurred()) - err = cl.handlePacket(nil, b.Bytes()) - Expect(err).ToNot(HaveOccurred()) - Expect(cl.versionNegotiated).To(BeTrue()) - }) - - It("changes the version after receiving a version negotiation packet", func() { - newVersion := protocol.VersionNumber(77) - Expect(config.Versions).To(ContainElement(newVersion)) - Expect(newVersion).ToNot(Equal(cl.version)) - Expect(sess.packetCount).To(BeZero()) - cl.connectionID = 0x1337 - err := cl.handlePacket(nil, composeVersionNegotiation(0x1337, []protocol.VersionNumber{newVersion})) - Expect(err).ToNot(HaveOccurred()) - Expect(cl.version).To(Equal(newVersion)) - Expect(cl.versionNegotiated).To(BeTrue()) - // it swapped the sessions - Expect(cl.session).ToNot(Equal(sess)) - Expect(cl.connectionID).ToNot(Equal(0x1337)) // it generated a new connection ID - Expect(err).ToNot(HaveOccurred()) - // it didn't pass the version negoation packet to the old session (since it has no payload) - Expect(sess.packetCount).To(BeZero()) - // if the version negotiation packet was passed to the new session, it would end up as an undecryptable packet there - Expect(cl.session.(*session).undecryptablePackets).To(BeEmpty()) - Expect(*(*[]protocol.VersionNumber)(unsafe.Pointer(reflect.ValueOf(cl.session.(*session).cryptoSetup).Elem().FieldByName("negotiatedVersions").UnsafeAddr()))).To(Equal([]protocol.VersionNumber{newVersion})) - }) - - It("errors if no matching version is found", func() { - err := cl.handlePacket(nil, composeVersionNegotiation(0x1337, []protocol.VersionNumber{1})) - Expect(err).To(MatchError(qerr.InvalidVersion)) - }) - - It("errors if the version is supported by quic-go, but disabled by the quic.Config", func() { - v := protocol.SupportedVersions[1] - Expect(v).ToNot(Equal(cl.version)) - Expect(config.Versions).ToNot(ContainElement(v)) - err := cl.handlePacket(nil, composeVersionNegotiation(0x1337, []protocol.VersionNumber{v})) - Expect(err).To(MatchError(qerr.InvalidVersion)) - }) - - It("changes to the version preferred by the quic.Config", func() { - err := cl.handlePacket(nil, composeVersionNegotiation(0x1337, []protocol.VersionNumber{config.Versions[2], config.Versions[1]})) - Expect(err).ToNot(HaveOccurred()) - Expect(cl.version).To(Equal(config.Versions[1])) - }) - - It("ignores delayed version negotiation packets", func() { - // if the version was not yet negotiated, handlePacket would return a VersionNegotiationMismatch error, see above test - cl.versionNegotiated = true - Expect(sess.packetCount).To(BeZero()) - err := cl.handlePacket(nil, composeVersionNegotiation(0x1337, []protocol.VersionNumber{1})) - Expect(err).ToNot(HaveOccurred()) - Expect(cl.versionNegotiated).To(BeTrue()) - Expect(sess.packetCount).To(BeZero()) - }) - - It("drops version negotiation packets that contain the offered version", func() { - ver := cl.version - err := cl.handlePacket(nil, composeVersionNegotiation(0x1337, []protocol.VersionNumber{ver})) - Expect(err).ToNot(HaveOccurred()) - Expect(cl.version).To(Equal(ver)) - }) - }) }) diff --git a/server_test.go b/server_test.go index 248385ae..75addde6 100644 --- a/server_test.go +++ b/server_test.go @@ -18,12 +18,13 @@ import ( ) type mockSession struct { - connectionID protocol.ConnectionID - packetCount int - closed bool - closeReason error - stopRunLoop chan struct{} // run returns as soon as this channel receives a value - handshakeChan chan handshakeEvent + connectionID protocol.ConnectionID + packetCount int + closed bool + closeReason error + stopRunLoop chan struct{} // run returns as soon as this channel receives a value + handshakeChan chan handshakeEvent + handshakeComplete chan error // for WaitUntilHandshakeComplete } func (s *mockSession) handlePacket(*receivedPacket) { @@ -34,9 +35,16 @@ func (s *mockSession) run() error { <-s.stopRunLoop return s.closeReason } +func (s *mockSession) WaitUntilHandshakeComplete() error { + return <-s.handshakeComplete +} func (s *mockSession) Close(e error) error { + if s.closed { + return nil + } s.closeReason = e s.closed = true + close(s.stopRunLoop) return nil } func (s *mockSession) AcceptStream() (Stream, error) { @@ -56,6 +64,7 @@ func (s *mockSession) RemoteAddr() net.Addr { } var _ Session = &mockSession{} +var _ NonFWSession = &mockSession{} func newMockSession( _ connection, @@ -65,9 +74,10 @@ func newMockSession( _ *Config, ) (packetHandler, <-chan handshakeEvent, error) { s := mockSession{ - connectionID: connectionID, - handshakeChan: make(chan handshakeEvent), - stopRunLoop: make(chan struct{}), + connectionID: connectionID, + handshakeChan: make(chan handshakeEvent), + handshakeComplete: make(chan error), + stopRunLoop: make(chan struct{}), } return &s, s.handshakeChan, nil } @@ -211,11 +221,11 @@ var _ = Describe("Server", func() { }) It("closes sessions and the connection when Close is called", func() { - session := &mockSession{} + session, _, _ := newMockSession(nil, 0, 0, nil, nil) serv.sessions[1] = session err := serv.Close() Expect(err).NotTo(HaveOccurred()) - Expect(session.closed).To(BeTrue()) + Expect(session.(*mockSession).closed).To(BeTrue()) Expect(conn.closed).To(BeTrue()) }) @@ -254,14 +264,14 @@ var _ = Describe("Server", func() { }, 0.5) It("closes all sessions when encountering a connection error", func() { - err := serv.handlePacket(nil, nil, firstPacket) - Expect(err).ToNot(HaveOccurred()) - Expect(serv.sessions).To(HaveKey(connID)) - Expect(serv.sessions[connID].(*mockSession).closed).To(BeFalse()) + session, _, _ := newMockSession(nil, 0, 0, nil, nil) + serv.sessions[0x12345] = session + Expect(serv.sessions[0x12345].(*mockSession).closed).To(BeFalse()) testErr := errors.New("connection error") conn.readErr = testErr go serv.serve() - Eventually(func() bool { return serv.sessions[connID].(*mockSession).closed }).Should(BeTrue()) + Eventually(func() Session { return serv.sessions[connID] }).Should(BeNil()) + Eventually(func() bool { return session.(*mockSession).closed }).Should(BeTrue()) Expect(serv.Close()).To(Succeed()) }) diff --git a/session.go b/session.go index 27ec2602..70622620 100644 --- a/session.go +++ b/session.go @@ -154,14 +154,15 @@ func newSession( return s, handshakeChan, err } -func newClientSession( +// declare this as a variable, such that we can it mock it in the tests +var newClientSession = func( conn connection, hostname string, v protocol.VersionNumber, connectionID protocol.ConnectionID, config *Config, negotiatedVersions []protocol.VersionNumber, -) (*session, <-chan handshakeEvent, error) { +) (packetHandler, <-chan handshakeEvent, error) { s := &session{ conn: conn, connectionID: connectionID, diff --git a/session_test.go b/session_test.go index 9469d5fa..70c330aa 100644 --- a/session_test.go +++ b/session_test.go @@ -1520,8 +1520,7 @@ var _ = Describe("Client Session", func() { mconn = &mockConnection{ remoteAddr: &net.UDPAddr{}, } - var err error - sess, _, err = newClientSession( + sessP, _, err := newClientSession( mconn, "hostname", protocol.Version35, @@ -1529,6 +1528,7 @@ var _ = Describe("Client Session", func() { populateClientConfig(&Config{}), nil, ) + sess = sessP.(*session) Expect(err).ToNot(HaveOccurred()) Expect(sess.streamsMap.openStreams).To(HaveLen(1)) // Crypto stream // we need an aeadChanged chan that we can write to