diff --git a/server.go b/server.go index 0c2410df..06c258c3 100644 --- a/server.go +++ b/server.go @@ -38,7 +38,7 @@ type server struct { sessionQueue chan Session errorChan chan struct{} - newSession func(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, cryptoChangeCallback cryptoChangeCallback, config *Config) (packetHandler, error) + newSession func(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, config *Config) (packetHandler, <-chan handshakeEvent, error) } var _ Listener = &server{} @@ -211,12 +211,12 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet } utils.Infof("Serving new connection: %x, version %d from %v", hdr.ConnectionID, version, remoteAddr) - session, err = s.newSession( + var handshakeChan <-chan handshakeEvent + session, handshakeChan, err = s.newSession( &conn{pconn: pconn, currentAddr: remoteAddr}, version, hdr.ConnectionID, s.scfg, - s.cryptoChangeCallback, s.config, ) if err != nil { @@ -229,9 +229,21 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet go func() { // session.run() returns as soon as the session is closed _ = session.run() - s.removeConnection(hdr.ConnectionID) }() + + go func() { + for { + ev := <-handshakeChan + if ev.err != nil { + return + } + if ev.encLevel == protocol.EncryptionForwardSecure { + break + } + } + s.sessionQueue <- session + }() } if session == nil { // Late packet for closed session @@ -246,12 +258,6 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet return nil } -func (s *server) cryptoChangeCallback(session Session, isForwardSecure bool) { - if isForwardSecure { - s.sessionQueue <- session - } -} - func (s *server) removeConnection(id protocol.ConnectionID) { s.sessionsMutex.Lock() s.sessions[id] = nil diff --git a/server_test.go b/server_test.go index 4afec7c4..248385ae 100644 --- a/server_test.go +++ b/server_test.go @@ -18,11 +18,12 @@ import ( ) type mockSession struct { - connectionID protocol.ConnectionID - packetCount int - closed bool - closeReason error - stopRunLoop chan struct{} // run returns as soon as this channel receives a value + connectionID protocol.ConnectionID + packetCount int + closed bool + closeReason error + stopRunLoop chan struct{} // run returns as soon as this channel receives a value + handshakeChan chan handshakeEvent } func (s *mockSession) handlePacket(*receivedPacket) { @@ -56,11 +57,19 @@ func (s *mockSession) RemoteAddr() net.Addr { var _ Session = &mockSession{} -func newMockSession(_ connection, _ protocol.VersionNumber, connectionID protocol.ConnectionID, _ *handshake.ServerConfig, _ cryptoChangeCallback, _ *Config) (packetHandler, error) { - return &mockSession{ - connectionID: connectionID, - stopRunLoop: make(chan struct{}), - }, nil +func newMockSession( + _ connection, + _ protocol.VersionNumber, + connectionID protocol.ConnectionID, + _ *handshake.ServerConfig, + _ *Config, +) (packetHandler, <-chan handshakeEvent, error) { + s := mockSession{ + connectionID: connectionID, + handshakeChan: make(chan handshakeEvent), + stopRunLoop: make(chan struct{}), + } + return &s, s.handshakeChan, nil } var _ = Describe("Server", func() { @@ -133,14 +142,33 @@ var _ = Describe("Server", func() { acceptedSess, err = serv.Accept() Expect(err).ToNot(HaveOccurred()) }() - sess := &mockSession{} - // serv.cryptoChangeCallback(sess, false) - // Consistently(func() Session { return acceptedSess }).Should(BeNil()) - serv.cryptoChangeCallback(sess, true) + err := serv.handlePacket(nil, nil, firstPacket) + Expect(err).ToNot(HaveOccurred()) + Expect(serv.sessions).To(HaveLen(1)) + sess := serv.sessions[connID].(*mockSession) + sess.handshakeChan <- handshakeEvent{encLevel: protocol.EncryptionSecure} + Consistently(func() Session { return acceptedSess }).Should(BeNil()) + sess.handshakeChan <- handshakeEvent{encLevel: protocol.EncryptionForwardSecure} Eventually(func() Session { return acceptedSess }).Should(Equal(sess)) close(done) }, 0.5) + It("doesn't accept session that error during the handshake", func(done Done) { + var accepted bool + go func() { + defer GinkgoRecover() + serv.Accept() + accepted = true + }() + err := serv.handlePacket(nil, nil, firstPacket) + Expect(err).ToNot(HaveOccurred()) + Expect(serv.sessions).To(HaveLen(1)) + sess := serv.sessions[connID].(*mockSession) + sess.handshakeChan <- handshakeEvent{err: errors.New("handshake failed")} + Consistently(func() bool { return accepted }).Should(BeFalse()) + close(done) + }) + It("assigns packets to existing sessions", func() { err := serv.handlePacket(nil, nil, firstPacket) Expect(err).ToNot(HaveOccurred()) diff --git a/session.go b/session.go index b8f96d42..27ec2602 100644 --- a/session.go +++ b/session.go @@ -34,10 +34,6 @@ var ( errSessionAlreadyClosed = errors.New("cannot close session; it was already closed before") ) -// cryptoChangeCallback is called every time the encryption level changes -// Once the callback has been called with isForwardSecure = true, it is guarantueed to not be called with isForwardSecure = false after that -type cryptoChangeCallback func(session Session, isForwardSecure bool) - type handshakeEvent struct { encLevel protocol.EncryptionLevel err error @@ -55,8 +51,6 @@ type session struct { version protocol.VersionNumber config *Config - cryptoChangeCallback cryptoChangeCallback - conn connection streamsMap *streamsMap @@ -95,6 +89,7 @@ type session struct { handshakeCompleteChan chan error // handshakeChan receives handshake events and is closed as soon the handshake completes // the receiving end of this channel is passed to the creator of the session + // it receives at most 3 handshake events: 2 when the encryption level changes, and one error handshakeChan chan<- handshakeEvent nextAckScheduledTime time.Time @@ -122,9 +117,8 @@ func newSession( v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, - cryptoChangeCallback cryptoChangeCallback, config *Config, -) (packetHandler, error) { +) (packetHandler, <-chan handshakeEvent, error) { s := &session{ conn: conn, connectionID: connectionID, @@ -132,7 +126,6 @@ func newSession( version: v, config: config, - cryptoChangeCallback: cryptoChangeCallback, connectionParameters: handshake.NewConnectionParamatersManager(protocol.PerspectiveServer, v), } @@ -147,18 +140,18 @@ func newSession( } aeadChanged := make(chan protocol.EncryptionLevel, 2) s.aeadChanged = aeadChanged - handshakeChan := make(chan handshakeEvent, 1) + handshakeChan := make(chan handshakeEvent, 3) s.handshakeChan = handshakeChan var err error s.cryptoSetup, err = handshake.NewCryptoSetup(connectionID, sourceAddr, v, sCfg, cryptoStream, s.connectionParameters, config.Versions, aeadChanged) if err != nil { - return nil, err + return nil, nil, err } s.packer = newPacketPacker(connectionID, s.cryptoSetup, s.connectionParameters, s.streamFramer, s.perspective, s.version) s.unpacker = &packetUnpacker{aead: s.cryptoSetup, version: s.version} - return s, err + return s, handshakeChan, err } func newClientSession( @@ -184,7 +177,7 @@ func newClientSession( aeadChanged := make(chan protocol.EncryptionLevel, 2) s.aeadChanged = aeadChanged - handshakeChan := make(chan handshakeEvent, 1) + handshakeChan := make(chan handshakeEvent, 3) s.handshakeChan = handshakeChan cryptoStream, _ := s.OpenStream() var err error @@ -295,12 +288,7 @@ runLoop: s.packer.SetForwardSecure() } s.tryDecryptingQueuedPackets() - // TODO: remove this, when removing the cryptoChangeCallback for the server - if s.perspective == protocol.PerspectiveServer { - s.cryptoChangeCallback(s, l == protocol.EncryptionForwardSecure) - } else { - s.handshakeChan <- handshakeEvent{encLevel: l} - } + s.handshakeChan <- handshakeEvent{encLevel: l} } } diff --git a/session_test.go b/session_test.go index dbb52a74..e476d0af 100644 --- a/session_test.go +++ b/session_test.go @@ -125,12 +125,13 @@ func areSessionsRunning() bool { var _ = Describe("Session", func() { var ( - sess *session - clientSess *session - scfg *handshake.ServerConfig - mconn *mockConnection - cpm *mockConnectionParametersManager - aeadChanged chan<- protocol.EncryptionLevel + sess *session + clientSess *session + scfg *handshake.ServerConfig + mconn *mockConnection + cpm *mockConnectionParametersManager + aeadChanged chan<- protocol.EncryptionLevel + handshakeChan <-chan handshakeEvent ) BeforeEach(func() { @@ -144,12 +145,12 @@ var _ = Describe("Session", func() { Expect(err).NotTo(HaveOccurred()) scfg, err = handshake.NewServerConfig(kex, certChain) Expect(err).NotTo(HaveOccurred()) - pSess, err := newSession( + var pSess Session + pSess, handshakeChan, err = newSession( mconn, protocol.Version35, 0, scfg, - func(Session, bool) {}, populateServerConfig(&Config{}), ) Expect(err).NotTo(HaveOccurred()) @@ -181,12 +182,11 @@ var _ = Describe("Session", func() { Context("source address", func() { It("uses the IP address if given an UDP connection", func() { conn := &conn{currentAddr: &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200)[12:], Port: 1337}} - sess, err := newSession( + sess, _, err := newSession( conn, protocol.VersionWhatever, 0, scfg, - func(Session, bool) {}, populateServerConfig(&Config{}), ) Expect(err).ToNot(HaveOccurred()) @@ -197,12 +197,11 @@ var _ = Describe("Session", func() { conn := &conn{ currentAddr: &net.TCPAddr{IP: net.IPv4(192, 168, 100, 200)[12:], Port: 1337}, } - sess, err := newSession( + sess, _, err := newSession( conn, protocol.VersionWhatever, 0, scfg, - func(Session, bool) {}, populateServerConfig(&Config{}), ) Expect(err).ToNot(HaveOccurred()) @@ -1244,13 +1243,14 @@ var _ = Describe("Session", func() { }) }) - It("tells the packetPacker when forward-secure encryption is used", func() { + It("tells the packetPacker when forward-secure encryption is used", func(done Done) { go sess.run() - defer sess.Close(nil) aeadChanged <- protocol.EncryptionSecure Consistently(func() bool { return sess.packer.isForwardSecure }).Should(BeFalse()) aeadChanged <- protocol.EncryptionForwardSecure Eventually(func() bool { return sess.packer.isForwardSecure }).Should(BeTrue()) + Expect(sess.Close(nil)).To(Succeed()) + close(done) }) It("closes when crypto stream errors", func() { @@ -1350,39 +1350,45 @@ var _ = Describe("Session", func() { }) }) - It("calls the cryptoChangeCallback when the AEAD changes", func(done Done) { - var callbackCalled bool - var callbackCalledWith bool - var callbackSession Session - cb := func(s Session, p bool) { - callbackCalled = true - callbackCalledWith = p - callbackSession = s - } - sess.cryptoChangeCallback = cb - aeadChanged <- protocol.EncryptionSecure + It("send a handshake event on the handshakeChan when the AEAD changes to secure", func(done Done) { go sess.run() - defer sess.Close(nil) - Eventually(func() bool { return callbackCalled }).Should(BeTrue()) - Expect(callbackCalledWith).To(BeFalse()) - Expect(callbackSession).To(Equal(sess)) + aeadChanged <- protocol.EncryptionSecure + Eventually(handshakeChan).Should(Receive(&handshakeEvent{encLevel: protocol.EncryptionSecure})) + Expect(sess.Close(nil)).To(Succeed()) close(done) }) - It("calls the cryptoChangeCallback when the AEAD changes to forward secure encryption", func(done Done) { - var callbackCalledWith bool - var callbackSession Session - cb := func(s Session, p bool) { - callbackSession = s - callbackCalledWith = p - } - sess.cryptoChangeCallback = cb - aeadChanged <- protocol.EncryptionForwardSecure - close(aeadChanged) + It("send a handshake event on the handshakeChan when the AEAD changes to forward-secure", func(done Done) { go sess.run() - defer sess.Close(nil) - Eventually(func() bool { return callbackCalledWith }).Should(BeTrue()) - Expect(callbackSession).To(Equal(sess)) + aeadChanged <- protocol.EncryptionForwardSecure + Eventually(handshakeChan).Should(Receive(&handshakeEvent{encLevel: protocol.EncryptionForwardSecure})) + Expect(sess.Close(nil)).To(Succeed()) + close(done) + }) + + It("closes the handshakeChan when the handshake completes", func(done Done) { + go sess.run() + close(aeadChanged) + Eventually(handshakeChan).Should(BeClosed()) + Expect(sess.Close(nil)).To(Succeed()) + close(done) + }) + + It("passes errors to the handshakeChan", func(done Done) { + testErr := errors.New("handshake error") + go sess.run() + Expect(sess.Close(nil)).To(Succeed()) + Expect(handshakeChan).To(Receive(&handshakeEvent{err: testErr})) + close(done) + }) + + It("does not block if an error occurs", func(done Done) { + // this test basically tests that the handshakeChan has a capacity of 3 + // The session needs to run (and close) properly, even if no one is receiving from the handshakeChan + go sess.run() + aeadChanged <- protocol.EncryptionSecure + aeadChanged <- protocol.EncryptionForwardSecure + Expect(sess.Close(nil)).To(Succeed()) close(done) })