From 996fad14f814f1c227dbdd95db4ae72787d7d73b Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sun, 14 May 2017 13:43:45 +0800 Subject: [PATCH] remove unsafe from the session tests for the crypto setup --- packet_packer_test.go | 6 ++-- session.go | 18 ++++++++-- session_test.go | 76 ++++++++++++++++++++++++++++++------------- 3 files changed, 71 insertions(+), 29 deletions(-) diff --git a/packet_packer_test.go b/packet_packer_test.go index 7b347aab..1c7f07c0 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -34,10 +34,8 @@ func (m *mockCryptoSetup) GetSealerWithEncryptionLevel(protocol.EncryptionLevel) return append(src, bytes.Repeat([]byte{0}, 12)...) }, nil } -func (m *mockCryptoSetup) DiversificationNonce() []byte { - return m.divNonce -} -func (m *mockCryptoSetup) SetDiversificationNonce([]byte) { panic("not implemented") } +func (m *mockCryptoSetup) DiversificationNonce() []byte { return m.divNonce } +func (m *mockCryptoSetup) SetDiversificationNonce(divNonce []byte) { m.divNonce = divNonce } var _ handshake.CryptoSetup = &mockCryptoSetup{} diff --git a/session.go b/session.go index 70622620..9eea8e26 100644 --- a/session.go +++ b/session.go @@ -34,6 +34,11 @@ var ( errSessionAlreadyClosed = errors.New("cannot close session; it was already closed before") ) +var ( + newCryptoSetup = handshake.NewCryptoSetup + newCryptoSetupClient = handshake.NewCryptoSetupClient +) + type handshakeEvent struct { encLevel protocol.EncryptionLevel err error @@ -143,7 +148,16 @@ func newSession( handshakeChan := make(chan handshakeEvent, 3) s.handshakeChan = handshakeChan var err error - s.cryptoSetup, err = handshake.NewCryptoSetup(connectionID, sourceAddr, v, sCfg, cryptoStream, s.connectionParameters, config.Versions, aeadChanged) + s.cryptoSetup, err = newCryptoSetup( + connectionID, + sourceAddr, + v, + sCfg, + cryptoStream, + s.connectionParameters, + config.Versions, + aeadChanged, + ) if err != nil { return nil, nil, err } @@ -182,7 +196,7 @@ var newClientSession = func( s.handshakeChan = handshakeChan cryptoStream, _ := s.OpenStream() var err error - s.cryptoSetup, err = handshake.NewCryptoSetupClient( + s.cryptoSetup, err = newCryptoSetupClient( hostname, connectionID, v, diff --git a/session_test.go b/session_test.go index 70c330aa..ce43c1f4 100644 --- a/session_test.go +++ b/session_test.go @@ -2,15 +2,14 @@ package quic import ( "bytes" + "crypto/tls" "errors" "io" "net" - "reflect" "runtime/pprof" "strings" "sync/atomic" "time" - "unsafe" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" @@ -129,13 +128,32 @@ var _ = Describe("Session", func() { scfg *handshake.ServerConfig mconn *mockConnection cpm *mockConnectionParametersManager - aeadChanged chan<- protocol.EncryptionLevel + cryptoSetup *mockCryptoSetup handshakeChan <-chan handshakeEvent + aeadChanged chan<- protocol.EncryptionLevel + + cryptoSetupSourceAddr []byte ) BeforeEach(func() { Eventually(areSessionsRunning).Should(BeFalse()) + cryptoSetup = &mockCryptoSetup{} + newCryptoSetup = func( + _ protocol.ConnectionID, + sourceAddr []byte, + _ protocol.VersionNumber, + _ *handshake.ServerConfig, + _ io.ReadWriter, + _ handshake.ConnectionParametersManager, + _ []protocol.VersionNumber, + aeadChangedP chan<- protocol.EncryptionLevel, + ) (handshake.CryptoSetup, error) { + cryptoSetupSourceAddr = sourceAddr + aeadChanged = aeadChangedP + return cryptoSetup, nil + } + mconn = &mockConnection{ remoteAddr: &net.UDPAddr{}, } @@ -155,22 +173,20 @@ var _ = Describe("Session", func() { Expect(err).NotTo(HaveOccurred()) sess = pSess.(*session) Expect(sess.streamsMap.openStreams).To(HaveLen(1)) // Crypto stream - // we need an aeadChanged chan that we can write to - // since type assertions on chans are not possible, we have to extract it from the CryptoSetup - aeadChanged = *(*chan<- protocol.EncryptionLevel)(unsafe.Pointer(reflect.ValueOf(sess.cryptoSetup).Elem().FieldByName("aeadChanged").UnsafeAddr())) cpm = &mockConnectionParametersManager{idleTime: 60 * time.Second} sess.connectionParameters = cpm }) AfterEach(func() { + newCryptoSetup = handshake.NewCryptoSetup Eventually(areSessionsRunning).Should(BeFalse()) }) Context("source address", func() { It("uses the IP address if given an UDP connection", func() { conn := &conn{currentAddr: &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200)[12:], Port: 1337}} - sess, _, err := newSession( + _, _, err := newSession( conn, protocol.VersionWhatever, 0, @@ -178,14 +194,14 @@ var _ = Describe("Session", func() { populateServerConfig(&Config{}), ) Expect(err).ToNot(HaveOccurred()) - Expect(*(*[]byte)(unsafe.Pointer(reflect.ValueOf(sess.(*session).cryptoSetup).Elem().FieldByName("sourceAddr").UnsafeAddr()))).To(Equal([]byte{192, 168, 100, 200})) + Expect(cryptoSetupSourceAddr).To(Equal([]byte{192, 168, 100, 200})) }) It("uses the string representation of the remote addresses if not given a UDP connection", func() { conn := &conn{ currentAddr: &net.TCPAddr{IP: net.IPv4(192, 168, 100, 200)[12:], Port: 1337}, } - sess, _, err := newSession( + _, _, err := newSession( conn, protocol.VersionWhatever, 0, @@ -193,7 +209,7 @@ var _ = Describe("Session", func() { populateServerConfig(&Config{}), ) Expect(err).ToNot(HaveOccurred()) - Expect(*(*[]byte)(unsafe.Pointer(reflect.ValueOf(sess.(*session).cryptoSetup).Elem().FieldByName("sourceAddr").UnsafeAddr()))).To(Equal([]byte("192.168.100.200:1337"))) + Expect(cryptoSetupSourceAddr).To(Equal([]byte("192.168.100.200:1337"))) }) }) @@ -688,7 +704,7 @@ var _ = Describe("Session", func() { sess.Close(nil) Eventually(areSessionsRunning).Should(BeFalse()) Expect(mconn.written).To(HaveLen(1)) - Expect(mconn.written[0][len(mconn.written[0])-7:]).To(Equal([]byte{0x02, byte(qerr.PeerGoingAway), 0, 0, 0, 0, 0})) + Expect(mconn.written[0]).To(ContainSubstring(string([]byte{0x02, byte(qerr.PeerGoingAway), 0, 0, 0, 0, 0}))) Expect(sess.runClosed).To(BeClosed()) }) @@ -1216,17 +1232,14 @@ var _ = Describe("Session", func() { }) It("closes when crypto stream errors", func() { + testErr := errors.New("crypto setup error") + cryptoSetup.handleErr = testErr var runErr error go func() { runErr = sess.run() }() - err := sess.handleStreamFrame(&frames.StreamFrame{ - StreamID: 1, - Data: []byte("4242\x00\x00\x00\x00"), - }) - Expect(err).NotTo(HaveOccurred()) Eventually(func() error { return runErr }).Should(HaveOccurred()) - Expect(runErr.(qerr.ErrorCode)).To(Equal(qerr.InvalidCryptoMessageType)) + Expect(runErr).To(MatchError(testErr)) }) Context("sending a Public Reset when receiving undecryptable packets during the handshake", func() { @@ -1512,11 +1525,29 @@ var _ = Describe("Client Session", func() { sess *session mconn *mockConnection aeadChanged chan<- protocol.EncryptionLevel + + cryptoSetup *mockCryptoSetup ) BeforeEach(func() { Eventually(areSessionsRunning).Should(BeFalse()) + cryptoSetup = &mockCryptoSetup{} + newCryptoSetupClient = func( + _ string, + _ protocol.ConnectionID, + _ protocol.VersionNumber, + _ io.ReadWriter, + _ *tls.Config, + _ handshake.ConnectionParametersManager, + aeadChangedP chan<- protocol.EncryptionLevel, + _ *handshake.TransportParameters, + _ []protocol.VersionNumber, + ) (handshake.CryptoSetup, error) { + aeadChanged = aeadChangedP + return cryptoSetup, nil + } + mconn = &mockConnection{ remoteAddr: &net.UDPAddr{}, } @@ -1531,9 +1562,10 @@ var _ = Describe("Client Session", func() { sess = sessP.(*session) Expect(err).ToNot(HaveOccurred()) Expect(sess.streamsMap.openStreams).To(HaveLen(1)) // Crypto stream - // we need an aeadChanged chan that we can write to - // since type assertions on chans are not possible, we have to extract it from the CryptoSetup - aeadChanged = *(*chan<- protocol.EncryptionLevel)(unsafe.Pointer(reflect.ValueOf(sess.cryptoSetup).Elem().FieldByName("aeadChanged").UnsafeAddr())) + }) + + AfterEach(func() { + newCryptoSetupClient = handshake.NewCryptoSetupClient }) Context("receiving packets", func() { @@ -1550,9 +1582,7 @@ var _ = Describe("Client Session", func() { hdr.DiversificationNonce = []byte("foobar") err := sess.handlePacketImpl(&receivedPacket{publicHeader: hdr}) Expect(err).ToNot(HaveOccurred()) - Eventually(func() []byte { - return *(*[]byte)(unsafe.Pointer(reflect.ValueOf(sess.cryptoSetup).Elem().FieldByName("diversificationNonce").UnsafeAddr())) - }).Should(Equal(hdr.DiversificationNonce)) + Eventually(func() []byte { return cryptoSetup.divNonce }).Should(Equal(hdr.DiversificationNonce)) Expect(sess.Close(nil)).To(Succeed()) }) })