From 57c6f3ceb56720b5eae494dc10598c5bf8ed3d40 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 11 Nov 2017 10:12:11 +0800 Subject: [PATCH] don't return the handshakeChan from the session constructor --- client.go | 6 ++---- client_test.go | 37 +++++++++++++++++++------------------ server.go | 8 ++++---- server_test.go | 21 +++++++++++---------- session.go | 24 +++++++++++++----------- session_test.go | 25 ++++++++++++------------- 6 files changed, 61 insertions(+), 60 deletions(-) diff --git a/client.go b/client.go index 21d19fed..dc962c83 100644 --- a/client.go +++ b/client.go @@ -22,8 +22,6 @@ type client struct { conn connection hostname string - handshakeChan <-chan handshakeEvent - versionNegotiationChan chan struct{} // the versionNegotiationChan is closed as soon as the server accepted the suggested version versionNegotiated bool // has version negotiation completed yet receivedVersionNegotiationPacket bool @@ -210,7 +208,7 @@ func (c *client) establishSecureConnection() error { select { case <-errorChan: return runErr - case ev := <-c.handshakeChan: + case ev := <-c.session.handshakeStatus(): if ev.err != nil { return ev.err } @@ -352,7 +350,7 @@ func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error { func (c *client) createNewSession(initialVersion protocol.VersionNumber, negotiatedVersions []protocol.VersionNumber) error { var err error utils.Debugf("createNewSession with initial version %s", initialVersion) - c.session, c.handshakeChan, err = newClientSession( + c.session, err = newClientSession( c.conn, c.hostname, c.version, diff --git a/client_test.go b/client_test.go index b1286923..0b2438fa 100644 --- a/client_test.go +++ b/client_test.go @@ -24,7 +24,7 @@ var _ = Describe("Client", func() { packetConn *mockPacketConn addr net.Addr - originalClientSessConstructor func(conn connection, hostname string, v protocol.VersionNumber, connectionID protocol.ConnectionID, tlsConf *tls.Config, config *Config, initialVersion protocol.VersionNumber, negotiatedVersions []protocol.VersionNumber) (packetHandler, <-chan handshakeEvent, error) + originalClientSessConstructor func(conn connection, hostname string, v protocol.VersionNumber, connectionID protocol.ConnectionID, tlsConf *tls.Config, config *Config, initialVersion protocol.VersionNumber, negotiatedVersions []protocol.VersionNumber) (packetHandler, error) ) // generate a packet sent by the server that accepts the QUIC version suggested by the client @@ -42,7 +42,7 @@ var _ = Describe("Client", func() { BeforeEach(func() { originalClientSessConstructor = newClientSession Eventually(areSessionsRunning).Should(BeFalse()) - msess, _, _ := newMockSession(nil, 0, 0, nil, nil, nil) + msess, _ := newMockSession(nil, 0, 0, nil, nil, nil) sess = msess.(*mockSession) addr = &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337} packetConn = &mockPacketConn{ @@ -86,9 +86,9 @@ var _ = Describe("Client", func() { _ *Config, _ protocol.VersionNumber, _ []protocol.VersionNumber, - ) (packetHandler, <-chan handshakeEvent, error) { + ) (packetHandler, error) { Expect(conn.Write([]byte("fake CHLO"))).To(Succeed()) - return sess, sess.handshakeChan, nil + return sess, nil } origGenerateConnectionID = generateConnectionID generateConnectionID = func() (protocol.ConnectionID, error) { @@ -178,9 +178,9 @@ var _ = Describe("Client", func() { _ *Config, _ protocol.VersionNumber, _ []protocol.VersionNumber, - ) (packetHandler, <-chan handshakeEvent, error) { + ) (packetHandler, error) { remoteAddrChan <- conn.RemoteAddr().String() - return sess, nil, nil + return sess, nil } dialed := make(chan struct{}) go func() { @@ -206,9 +206,9 @@ var _ = Describe("Client", func() { _ *Config, _ protocol.VersionNumber, _ []protocol.VersionNumber, - ) (packetHandler, <-chan handshakeEvent, error) { + ) (packetHandler, error) { hostnameChan <- h - return sess, nil, nil + return sess, nil } dialed := make(chan struct{}) go func() { @@ -302,8 +302,8 @@ var _ = Describe("Client", func() { _ *Config, _ protocol.VersionNumber, _ []protocol.VersionNumber, - ) (packetHandler, <-chan handshakeEvent, error) { - return nil, nil, testErr + ) (packetHandler, error) { + return nil, testErr } _, err := DialNonFWSecure(packetConn, addr, "quic.clemente.io:1337", nil, config) Expect(err).To(MatchError(testErr)) @@ -345,7 +345,7 @@ var _ = Describe("Client", func() { _ *Config, initialVersionP protocol.VersionNumber, negotiatedVersionsP []protocol.VersionNumber, - ) (packetHandler, <-chan handshakeEvent, error) { + ) (packetHandler, error) { initialVersion = initialVersionP negotiatedVersions = negotiatedVersionsP // make the server accept the new version @@ -353,11 +353,12 @@ var _ = Describe("Client", func() { packetConn.dataToRead = acceptClientVersionPacket(connectionID) } sess := &mockSession{ - connectionID: connectionID, - stopRunLoop: make(chan struct{}), + connectionID: connectionID, + stopRunLoop: make(chan struct{}), + handshakeChan: handshakeChan, } sessionChan <- sess - return sess, handshakeChan, nil + return sess, nil } established := make(chan struct{}) @@ -395,9 +396,9 @@ var _ = Describe("Client", func() { _ *Config, _ protocol.VersionNumber, _ []protocol.VersionNumber, - ) (packetHandler, <-chan handshakeEvent, error) { + ) (packetHandler, error) { atomic.AddUint32(&sessionCounter, 1) - return sess, nil, nil + return sess, nil } go cl.establishSecureConnection() Eventually(func() uint32 { return atomic.LoadUint32(&sessionCounter) }).Should(BeEquivalentTo(1)) @@ -497,13 +498,13 @@ var _ = Describe("Client", func() { configP *Config, _ protocol.VersionNumber, _ []protocol.VersionNumber, - ) (packetHandler, <-chan handshakeEvent, error) { + ) (packetHandler, error) { cconn = connP hostname = hostnameP version = versionP conf = configP close(c) - return sess, nil, nil + return sess, nil } dialed := make(chan struct{}) go func() { diff --git a/server.go b/server.go index 10c55805..85303f3c 100644 --- a/server.go +++ b/server.go @@ -19,6 +19,7 @@ import ( // packetHandler handles packets type packetHandler interface { Session + handshakeStatus() <-chan handshakeEvent handlePacket(*receivedPacket) GetVersion() protocol.VersionNumber run() error @@ -43,7 +44,7 @@ type server struct { sessionQueue chan Session errorChan chan struct{} - newSession func(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, tlsConf *tls.Config, config *Config) (packetHandler, <-chan handshakeEvent, error) + newSession func(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, tlsConf *tls.Config, config *Config) (packetHandler, error) } var _ Listener = &server{} @@ -291,8 +292,7 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet } utils.Infof("Serving new connection: %x, version %s from %v", hdr.ConnectionID, version, remoteAddr) - var handshakeChan <-chan handshakeEvent - session, handshakeChan, err = s.newSession( + session, err = s.newSession( &conn{pconn: pconn, currentAddr: remoteAddr}, version, hdr.ConnectionID, @@ -315,7 +315,7 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet go func() { for { - ev := <-handshakeChan + ev := <-session.handshakeStatus() if ev.err != nil { return } diff --git a/server_test.go b/server_test.go index 252507f5..189d3653 100644 --- a/server_test.go +++ b/server_test.go @@ -60,12 +60,13 @@ func (s *mockSession) closeRemote(e error) { func (s *mockSession) OpenStream() (Stream, error) { return &stream{streamID: 1337}, nil } -func (s *mockSession) AcceptStream() (Stream, error) { panic("not implemented") } -func (s *mockSession) OpenStreamSync() (Stream, error) { panic("not implemented") } -func (s *mockSession) LocalAddr() net.Addr { panic("not implemented") } -func (s *mockSession) RemoteAddr() net.Addr { panic("not implemented") } -func (*mockSession) Context() context.Context { panic("not implemented") } -func (*mockSession) GetVersion() protocol.VersionNumber { return protocol.VersionWhatever } +func (s *mockSession) AcceptStream() (Stream, error) { panic("not implemented") } +func (s *mockSession) OpenStreamSync() (Stream, error) { panic("not implemented") } +func (s *mockSession) LocalAddr() net.Addr { panic("not implemented") } +func (s *mockSession) RemoteAddr() net.Addr { panic("not implemented") } +func (*mockSession) Context() context.Context { panic("not implemented") } +func (*mockSession) GetVersion() protocol.VersionNumber { return protocol.VersionWhatever } +func (s *mockSession) handshakeStatus() <-chan handshakeEvent { return s.handshakeChan } var _ Session = &mockSession{} var _ NonFWSession = &mockSession{} @@ -77,14 +78,14 @@ func newMockSession( _ *handshake.ServerConfig, _ *tls.Config, _ *Config, -) (packetHandler, <-chan handshakeEvent, error) { +) (packetHandler, error) { s := mockSession{ connectionID: connectionID, handshakeChan: make(chan handshakeEvent), handshakeComplete: make(chan error), stopRunLoop: make(chan struct{}), } - return &s, s.handshakeChan, nil + return &s, nil } var _ = Describe("Server", func() { @@ -217,7 +218,7 @@ var _ = Describe("Server", func() { }) It("closes sessions and the connection when Close is called", func() { - session, _, _ := newMockSession(nil, 0, 0, nil, nil, nil) + session, _ := newMockSession(nil, 0, 0, nil, nil, nil) serv.sessions[1] = session err := serv.Close() Expect(err).NotTo(HaveOccurred()) @@ -267,7 +268,7 @@ var _ = Describe("Server", func() { }, 0.5) It("closes all sessions when encountering a connection error", func() { - session, _, _ := newMockSession(nil, 0, 0, nil, nil, nil) + session, _ := newMockSession(nil, 0, 0, nil, nil, nil) serv.sessions[0x12345] = session Expect(serv.sessions[0x12345].(*mockSession).closed).To(BeFalse()) testErr := errors.New("connection error") diff --git a/session.go b/session.go index 2fb39016..418d38ad 100644 --- a/session.go +++ b/session.go @@ -101,7 +101,7 @@ type session struct { // handshakeChan receives handshake events and is closed as soon the handshake completes // the receiving end of this channel is passed to the creator of the session // it receives at most 3 handshake events: 2 when the encryption level changes, and one error - handshakeChan chan<- handshakeEvent + handshakeChan chan handshakeEvent lastRcvdPacketNumber protocol.PacketNumber // Used to calculate the next packet number from the truncated wire @@ -129,7 +129,7 @@ func newSession( sCfg *handshake.ServerConfig, tlsConf *tls.Config, config *Config, -) (packetHandler, <-chan handshakeEvent, error) { +) (packetHandler, error) { s := &session{ conn: conn, connectionID: connectionID, @@ -137,7 +137,7 @@ func newSession( version: v, config: config, } - return s.setup(sCfg, "", tlsConf, v, nil) + return s, s.setup(sCfg, "", tlsConf, v, nil) } // declare this as a variable, such that we can it mock it in the tests @@ -150,7 +150,7 @@ var newClientSession = func( config *Config, initialVersion protocol.VersionNumber, negotiatedVersions []protocol.VersionNumber, // needed for validation of the GQUIC version negotiaton -) (packetHandler, <-chan handshakeEvent, error) { +) (packetHandler, error) { s := &session{ conn: conn, connectionID: connectionID, @@ -158,7 +158,7 @@ var newClientSession = func( version: v, config: config, } - return s.setup(nil, hostname, tlsConf, initialVersion, negotiatedVersions) + return s, s.setup(nil, hostname, tlsConf, initialVersion, negotiatedVersions) } func (s *session) setup( @@ -167,13 +167,12 @@ func (s *session) setup( tlsConf *tls.Config, initialVersion protocol.VersionNumber, negotiatedVersions []protocol.VersionNumber, -) (packetHandler, <-chan handshakeEvent, error) { +) error { aeadChanged := make(chan protocol.EncryptionLevel, 2) paramsChan := make(chan handshake.TransportParameters) s.aeadChanged = aeadChanged s.paramsChan = paramsChan - handshakeChan := make(chan handshakeEvent, 3) - s.handshakeChan = handshakeChan + s.handshakeChan = make(chan handshakeEvent, 3) s.handshakeCompleteChan = make(chan error, 1) s.receivedPackets = make(chan *receivedPacket, protocol.MaxSessionUnprocessedPackets) s.closeChan = make(chan closeError, 1) @@ -267,7 +266,7 @@ func (s *session) setup( } } if err != nil { - return nil, nil, err + return err } s.packer = newPacketPacker(s.connectionID, @@ -277,8 +276,7 @@ func (s *session) setup( s.version, ) s.unpacker = &packetUnpacker{aead: s.cryptoSetup, version: s.version} - - return s, handshakeChan, nil + return nil } // run the session main loop @@ -891,6 +889,10 @@ func (s *session) RemoteAddr() net.Addr { return s.conn.RemoteAddr() } +func (s *session) handshakeStatus() <-chan handshakeEvent { + return s.handshakeChan +} + func (s *session) GetVersion() protocol.VersionNumber { return s.version } diff --git a/session_test.go b/session_test.go index 5b83dd37..3fff01fe 100644 --- a/session_test.go +++ b/session_test.go @@ -145,12 +145,11 @@ func areSessionsRunning() bool { var _ = Describe("Session", func() { var ( - sess *session - scfg *handshake.ServerConfig - mconn *mockConnection - cryptoSetup *mockCryptoSetup - handshakeChan <-chan handshakeEvent - aeadChanged chan<- protocol.EncryptionLevel + sess *session + scfg *handshake.ServerConfig + mconn *mockConnection + cryptoSetup *mockCryptoSetup + aeadChanged chan<- protocol.EncryptionLevel ) BeforeEach(func() { @@ -180,7 +179,7 @@ var _ = Describe("Session", func() { scfg, err = handshake.NewServerConfig(kex, certChain) Expect(err).NotTo(HaveOccurred()) var pSess Session - pSess, handshakeChan, err = newSession( + pSess, err = newSession( mconn, protocol.Version39, 0, @@ -229,7 +228,7 @@ var _ = Describe("Session", func() { paramCookie = cookie return false } - pSess, _, err := newSession( + pSess, err := newSession( mconn, protocol.Version39, 0, @@ -1219,7 +1218,7 @@ var _ = Describe("Session", func() { It("send a handshake event on the handshakeChan when the AEAD changes to secure", func(done Done) { go sess.run() aeadChanged <- protocol.EncryptionSecure - Eventually(handshakeChan).Should(Receive(&handshakeEvent{encLevel: protocol.EncryptionSecure})) + Eventually(sess.handshakeStatus()).Should(Receive(&handshakeEvent{encLevel: protocol.EncryptionSecure})) Expect(sess.Close(nil)).To(Succeed()) close(done) }) @@ -1227,7 +1226,7 @@ var _ = Describe("Session", func() { It("send a handshake event on the handshakeChan when the AEAD changes to forward-secure", func(done Done) { go sess.run() aeadChanged <- protocol.EncryptionForwardSecure - Eventually(handshakeChan).Should(Receive(&handshakeEvent{encLevel: protocol.EncryptionForwardSecure})) + Eventually(sess.handshakeStatus()).Should(Receive(&handshakeEvent{encLevel: protocol.EncryptionForwardSecure})) Expect(sess.Close(nil)).To(Succeed()) close(done) }) @@ -1235,7 +1234,7 @@ var _ = Describe("Session", func() { It("closes the handshakeChan when the handshake completes", func(done Done) { go sess.run() close(aeadChanged) - Eventually(handshakeChan).Should(BeClosed()) + Eventually(sess.handshakeStatus()).Should(BeClosed()) Expect(sess.Close(nil)).To(Succeed()) close(done) }) @@ -1244,7 +1243,7 @@ var _ = Describe("Session", func() { testErr := errors.New("handshake error") go sess.run() Expect(sess.Close(nil)).To(Succeed()) - Expect(handshakeChan).To(Receive(&handshakeEvent{err: testErr})) + Expect(sess.handshakeStatus()).To(Receive(&handshakeEvent{err: testErr})) close(done) }) @@ -1524,7 +1523,7 @@ var _ = Describe("Client Session", func() { } mconn = newMockConnection() - sessP, _, err := newClientSession( + sessP, err := newClientSession( mconn, "hostname", protocol.Version39,