diff --git a/server.go b/server.go index 93cea4d87..dc23b8687 100644 --- a/server.go +++ b/server.go @@ -133,9 +133,12 @@ func (s *server) setupTLS() error { select { case <-s.errorChan: return - case sess := <-sessionChan: - // TODO: think about what to do with connection ID collisions - connID := sess.(*session).connectionID + case tlsSession := <-sessionChan: + connID := tlsSession.connID + sess := tlsSession.sess + if _, ok := s.sessions[connID]; ok { // drop this session if it already exists + return + } s.sessionsMutex.Lock() s.sessions[connID] = sess s.sessionsMutex.Unlock() diff --git a/server_test.go b/server_test.go index 467a95664..f48c11cfb 100644 --- a/server_test.go +++ b/server_test.go @@ -138,6 +138,51 @@ var _ = Describe("Server", func() { Expect(sess.packetCount).To(Equal(1)) }) + It("accepts new TLS sessions", func() { + connID := protocol.ConnectionID(0x12345) + sess, err := newMockSession(nil, protocol.VersionTLS, connID, nil, nil, nil) + Expect(err).ToNot(HaveOccurred()) + err = serv.setupTLS() + Expect(err).ToNot(HaveOccurred()) + serv.serverTLS.sessionChan <- tlsSession{ + connID: connID, + sess: sess, + } + Eventually(func() packetHandler { + serv.sessionsMutex.Lock() + defer serv.sessionsMutex.Unlock() + return serv.sessions[connID] + }).Should(Equal(sess)) + }) + + It("only accepts one new TLS sessions for one connection ID", func() { + connID := protocol.ConnectionID(0x12345) + sess1, err := newMockSession(nil, protocol.VersionTLS, connID, nil, nil, nil) + Expect(err).ToNot(HaveOccurred()) + sess2, err := newMockSession(nil, protocol.VersionTLS, connID, nil, nil, nil) + Expect(err).ToNot(HaveOccurred()) + err = serv.setupTLS() + Expect(err).ToNot(HaveOccurred()) + serv.serverTLS.sessionChan <- tlsSession{ + connID: connID, + sess: sess1, + } + Eventually(func() packetHandler { + serv.sessionsMutex.Lock() + defer serv.sessionsMutex.Unlock() + return serv.sessions[connID] + }).Should(Equal(sess1)) + serv.serverTLS.sessionChan <- tlsSession{ + connID: connID, + sess: sess2, + } + Eventually(func() packetHandler { + serv.sessionsMutex.Lock() + defer serv.sessionsMutex.Unlock() + return serv.sessions[connID] + }).Should(Equal(sess1)) + }) + It("accepts a session once the connection it is forward secure", func(done Done) { var acceptedSess Session go func() { diff --git a/server_tls.go b/server_tls.go index c45de35b7..5f270e349 100644 --- a/server_tls.go +++ b/server_tls.go @@ -26,6 +26,11 @@ func (n *nullAEAD) Open(dst, src []byte, packetNumber protocol.PacketNumber, ass return data, protocol.EncryptionUnencrypted, err } +type tlsSession struct { + connID protocol.ConnectionID + sess packetHandler +} + type serverTLS struct { conn net.PacketConn config *Config @@ -35,7 +40,7 @@ type serverTLS struct { params *handshake.TransportParameters newMintConn func(*handshake.CryptoStreamConn, protocol.VersionNumber) (handshake.MintTLS, <-chan handshake.TransportParameters, error) - sessionChan chan<- packetHandler + sessionChan chan<- tlsSession } func newServerTLS( @@ -43,7 +48,7 @@ func newServerTLS( config *Config, cookieHandler *handshake.CookieHandler, tlsConf *tls.Config, -) (*serverTLS, <-chan packetHandler, error) { +) (*serverTLS, <-chan tlsSession, error) { mconf, err := tlsToMintConfig(tlsConf, protocol.PerspectiveServer) if err != nil { return nil, nil, err @@ -56,7 +61,7 @@ func newServerTLS( mconf.CookieProtector = cs mconf.CookieHandler = cookieHandler - sessionChan := make(chan packetHandler) + sessionChan := make(chan tlsSession) s := &serverTLS{ conn: conn, config: config, @@ -86,7 +91,10 @@ func (s *serverTLS) HandleInitial(remoteAddr net.Addr, hdr *wire.Header, data [] if sess == nil { // a stateless reset was done return } - s.sessionChan <- sess + s.sessionChan <- tlsSession{ + connID: hdr.ConnectionID, + sess: sess, + } } // will be set to s.newMintConn by the constructor diff --git a/server_tls_test.go b/server_tls_test.go index 7cead5cd8..54f94372a 100644 --- a/server_tls_test.go +++ b/server_tls_test.go @@ -22,7 +22,7 @@ var _ = Describe("Stateless TLS handling", func() { var ( conn *mockPacketConn server *serverTLS - sessionChan <-chan packetHandler + sessionChan <-chan tlsSession mintTLS *mockhandshake.MockMintTLS extHandler *mocks.MockTLSExtensionHandler mintReply io.Writer