start running TLS server session directly when creating them

This commit is contained in:
Marten Seemann
2018-05-20 13:56:26 +08:00
parent e98228df9d
commit ef34d9e85f
4 changed files with 23 additions and 9 deletions

View File

@@ -165,11 +165,9 @@ func (s *server) setupTLS() error {
case <-s.errorChan:
return
case tlsSession := <-sessionChan:
sess := tlsSession.sess
// The connection ID is a randomly chosen 8 byte value.
// It is safe to assume that it doesn't collide with other randomly chosen values.
s.sessionHandler.Add(tlsSession.connID, sess)
go sess.run()
s.sessionHandler.Add(tlsSession.connID, tlsSession.sess)
}
}
}()

View File

@@ -153,17 +153,18 @@ var _ = Describe("Server", func() {
It("accepts new TLS sessions", func() {
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
run := make(chan struct{})
sess := NewMockPacketHandler(mockCtrl)
sess.EXPECT().run().Do(func() { close(run) })
err := serv.setupTLS()
Expect(err).ToNot(HaveOccurred())
sessionHandler.EXPECT().Add(connID, sess)
added := make(chan struct{})
sessionHandler.EXPECT().Add(connID, sess).Do(func(protocol.ConnectionID, packetHandler) {
close(added)
})
serv.serverTLS.sessionChan <- tlsSession{
connID: connID,
sess: sess,
}
Eventually(run).Should(BeClosed())
Eventually(added).Should(BeClosed())
})
It("accepts a session once the connection it is forward secure", func() {

View File

@@ -42,6 +42,8 @@ type serverTLS struct {
params *handshake.TransportParameters
newMintConn func(*handshake.CryptoStreamConn, protocol.VersionNumber) (handshake.MintTLS, <-chan handshake.TransportParameters, error)
newSession func(connection, sessionRunner, protocol.ConnectionID, protocol.ConnectionID, protocol.PacketNumber, *Config, handshake.MintTLS, *handshake.CryptoStreamConn, crypto.AEAD, *handshake.TransportParameters, protocol.VersionNumber, utils.Logger) (packetHandler, error)
sessionRunner sessionRunner
sessionChan chan<- tlsSession
@@ -83,7 +85,8 @@ func newServerTLS(
MaxBidiStreams: uint16(config.MaxIncomingStreams),
MaxUniStreams: uint16(config.MaxIncomingUniStreams),
},
logger: logger,
newSession: newTLSServerSession,
logger: logger,
}
s.newMintConn = s.newMintConnImpl
return s, sessionChan, nil
@@ -225,7 +228,7 @@ func (s *serverTLS) handleUnpackedInitial(remoteAddr net.Addr, hdr *wire.Header,
return nil, nil, err
}
s.logger.Debugf("Changing source connection ID to %s.", connID)
sess, err := newTLSServerSession(
sess, err := s.newSession(
&conn{pconn: s.conn, currentAddr: remoteAddr},
s.sessionRunner,
hdr.SrcConnectionID,
@@ -245,5 +248,6 @@ func (s *serverTLS) handleUnpackedInitial(remoteAddr net.Addr, hdr *wire.Header,
cs := sess.getCryptoStream()
cs.setReadOffset(frame.DataLen())
bc.SetStream(cs)
go sess.run()
return sess, connID, nil
}

View File

@@ -5,6 +5,7 @@ import (
"io"
"github.com/bifurcation/mint"
"github.com/golang/mock/gomock"
"github.com/lucas-clemente/quic-go/internal/crypto"
"github.com/lucas-clemente/quic-go/internal/handshake"
"github.com/lucas-clemente/quic-go/internal/mocks"
@@ -143,6 +144,15 @@ var _ = Describe("Stateless TLS handling", func() {
mintTLS.EXPECT().Handshake().Return(mint.AlertNoAlert).Do(func() {
mintReply.Write([]byte("Server Hello"))
})
run := make(chan struct{})
server.newSession = func(connection, sessionRunner, protocol.ConnectionID, protocol.ConnectionID, protocol.PacketNumber, *Config, handshake.MintTLS, *handshake.CryptoStreamConn, crypto.AEAD, *handshake.TransportParameters, protocol.VersionNumber, utils.Logger) (packetHandler, error) {
sess := NewMockPacketHandler(mockCtrl)
cryptoStream := NewMockCryptoStream(mockCtrl)
cryptoStream.EXPECT().setReadOffset(gomock.Any())
sess.EXPECT().getCryptoStream().Return(cryptoStream)
sess.EXPECT().run().Do(func() { close(run) })
return sess, nil
}
mintTLS.EXPECT().Handshake().Return(mint.AlertNoAlert)
mintTLS.EXPECT().State().Return(mint.StateServerNegotiated)
mintTLS.EXPECT().State().Return(mint.StateServerWaitFlight2)
@@ -163,6 +173,7 @@ var _ = Describe("Stateless TLS handling", func() {
// make sure we're using a server-generated connection ID
Expect(tlsSess.connID).ToNot(Equal(hdr.SrcConnectionID))
Expect(tlsSess.connID).ToNot(Equal(hdr.DestConnectionID))
Eventually(run).Should(BeClosed())
Eventually(done).Should(BeClosed())
})