diff --git a/server.go b/server.go index ea0ab833f..77df23ce4 100644 --- a/server.go +++ b/server.go @@ -165,11 +165,9 @@ func (s *server) setupTLS() error { case <-s.errorChan: return case tlsSession := <-sessionChan: - sess := tlsSession.sess // The connection ID is a randomly chosen 8 byte value. // It is safe to assume that it doesn't collide with other randomly chosen values. - s.sessionHandler.Add(tlsSession.connID, sess) - go sess.run() + s.sessionHandler.Add(tlsSession.connID, tlsSession.sess) } } }() diff --git a/server_test.go b/server_test.go index 1871dc775..7449f708f 100644 --- a/server_test.go +++ b/server_test.go @@ -153,17 +153,18 @@ var _ = Describe("Server", func() { It("accepts new TLS sessions", func() { connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - run := make(chan struct{}) sess := NewMockPacketHandler(mockCtrl) - sess.EXPECT().run().Do(func() { close(run) }) err := serv.setupTLS() Expect(err).ToNot(HaveOccurred()) - sessionHandler.EXPECT().Add(connID, sess) + added := make(chan struct{}) + sessionHandler.EXPECT().Add(connID, sess).Do(func(protocol.ConnectionID, packetHandler) { + close(added) + }) serv.serverTLS.sessionChan <- tlsSession{ connID: connID, sess: sess, } - Eventually(run).Should(BeClosed()) + Eventually(added).Should(BeClosed()) }) It("accepts a session once the connection it is forward secure", func() { diff --git a/server_tls.go b/server_tls.go index f240b870b..4a60f4ba5 100644 --- a/server_tls.go +++ b/server_tls.go @@ -42,6 +42,8 @@ type serverTLS struct { params *handshake.TransportParameters newMintConn func(*handshake.CryptoStreamConn, protocol.VersionNumber) (handshake.MintTLS, <-chan handshake.TransportParameters, error) + newSession func(connection, sessionRunner, protocol.ConnectionID, protocol.ConnectionID, protocol.PacketNumber, *Config, handshake.MintTLS, *handshake.CryptoStreamConn, crypto.AEAD, *handshake.TransportParameters, protocol.VersionNumber, utils.Logger) (packetHandler, error) + sessionRunner sessionRunner sessionChan chan<- tlsSession @@ -83,7 +85,8 @@ func newServerTLS( MaxBidiStreams: uint16(config.MaxIncomingStreams), MaxUniStreams: uint16(config.MaxIncomingUniStreams), }, - logger: logger, + newSession: newTLSServerSession, + logger: logger, } s.newMintConn = s.newMintConnImpl return s, sessionChan, nil @@ -225,7 +228,7 @@ func (s *serverTLS) handleUnpackedInitial(remoteAddr net.Addr, hdr *wire.Header, return nil, nil, err } s.logger.Debugf("Changing source connection ID to %s.", connID) - sess, err := newTLSServerSession( + sess, err := s.newSession( &conn{pconn: s.conn, currentAddr: remoteAddr}, s.sessionRunner, hdr.SrcConnectionID, @@ -245,5 +248,6 @@ func (s *serverTLS) handleUnpackedInitial(remoteAddr net.Addr, hdr *wire.Header, cs := sess.getCryptoStream() cs.setReadOffset(frame.DataLen()) bc.SetStream(cs) + go sess.run() return sess, connID, nil } diff --git a/server_tls_test.go b/server_tls_test.go index 37dedb65d..214fe5bec 100644 --- a/server_tls_test.go +++ b/server_tls_test.go @@ -5,6 +5,7 @@ import ( "io" "github.com/bifurcation/mint" + "github.com/golang/mock/gomock" "github.com/lucas-clemente/quic-go/internal/crypto" "github.com/lucas-clemente/quic-go/internal/handshake" "github.com/lucas-clemente/quic-go/internal/mocks" @@ -143,6 +144,15 @@ var _ = Describe("Stateless TLS handling", func() { mintTLS.EXPECT().Handshake().Return(mint.AlertNoAlert).Do(func() { mintReply.Write([]byte("Server Hello")) }) + run := make(chan struct{}) + server.newSession = func(connection, sessionRunner, protocol.ConnectionID, protocol.ConnectionID, protocol.PacketNumber, *Config, handshake.MintTLS, *handshake.CryptoStreamConn, crypto.AEAD, *handshake.TransportParameters, protocol.VersionNumber, utils.Logger) (packetHandler, error) { + sess := NewMockPacketHandler(mockCtrl) + cryptoStream := NewMockCryptoStream(mockCtrl) + cryptoStream.EXPECT().setReadOffset(gomock.Any()) + sess.EXPECT().getCryptoStream().Return(cryptoStream) + sess.EXPECT().run().Do(func() { close(run) }) + return sess, nil + } mintTLS.EXPECT().Handshake().Return(mint.AlertNoAlert) mintTLS.EXPECT().State().Return(mint.StateServerNegotiated) mintTLS.EXPECT().State().Return(mint.StateServerWaitFlight2) @@ -163,6 +173,7 @@ var _ = Describe("Stateless TLS handling", func() { // make sure we're using a server-generated connection ID Expect(tlsSess.connID).ToNot(Equal(hdr.SrcConnectionID)) Expect(tlsSess.connID).ToNot(Equal(hdr.DestConnectionID)) + Eventually(run).Should(BeClosed()) Eventually(done).Should(BeClosed()) })