refactor session to use a much smaller crypto setup

This commit is contained in:
Marten Seemann
2018-03-23 11:00:21 +00:00
parent c6526ad927
commit 69f6427d3c
2 changed files with 69 additions and 33 deletions

View File

@@ -44,6 +44,11 @@ type streamManager interface {
CloseWithError(error)
}
type cryptoStreamHandler interface {
HandleCryptoStream() error
ConnectionState() handshake.ConnectionState
}
type receivedPacket struct {
remoteAddr net.Addr
header *wire.Header
@@ -84,8 +89,8 @@ type session struct {
unpacker unpacker
packer *packetPacker
cryptoSetup handshake.CryptoSetup
divNonceChan chan<- []byte // only set for the client
cryptoStreamHandler cryptoStreamHandler
divNonceChan chan<- []byte // only set for the client
receivedPackets chan *receivedPacket
sendingScheduled chan struct{}
@@ -177,8 +182,20 @@ func newSession(
if err != nil {
return nil, err
}
s.cryptoSetup = cs
return s, s.postSetup(1)
s.cryptoStreamHandler = cs
s.unpacker = &packetUnpacker{aead: cs, version: s.version}
s.streamsMap = newStreamsMapLegacy(s.newStream, s.config.MaxIncomingStreams, s.perspective)
s.streamFramer = newStreamFramer(s.cryptoStream, s.streamsMap, s.version)
s.packer = newPacketPacker(s.connectionID,
1,
s.sentPacketHandler.GetPacketNumberLen,
s.RemoteAddr(),
cs,
s.streamFramer,
s.perspective,
s.version,
)
return s, s.postSetup()
}
// declare this as a variable, so that we can it mock it in the tests
@@ -226,9 +243,21 @@ var newClientSession = func(
if err != nil {
return nil, err
}
s.cryptoSetup = cs
s.cryptoStreamHandler = cs
s.divNonceChan = divNonceChan
return s, s.postSetup(1)
s.unpacker = &packetUnpacker{aead: cs, version: s.version}
s.streamsMap = newStreamsMapLegacy(s.newStream, s.config.MaxIncomingStreams, s.perspective)
s.streamFramer = newStreamFramer(s.cryptoStream, s.streamsMap, s.version)
s.packer = newPacketPacker(s.connectionID,
1,
s.sentPacketHandler.GetPacketNumberLen,
s.RemoteAddr(),
cs,
s.streamFramer,
s.perspective,
s.version,
)
return s, s.postSetup()
}
func newTLSServerSession(
@@ -252,19 +281,31 @@ func newTLSServerSession(
handshakeEvent: handshakeEvent,
}
s.preSetup()
s.cryptoSetup = handshake.NewCryptoSetupTLSServer(
cs := handshake.NewCryptoSetupTLSServer(
tls,
cryptoStreamConn,
nullAEAD,
handshakeEvent,
v,
)
if err := s.postSetup(initialPacketNumber); err != nil {
s.cryptoStreamHandler = cs
s.streamsMap = newStreamsMap(s, s.newFlowController, s.config.MaxIncomingStreams, s.config.MaxIncomingUniStreams, s.perspective, s.version)
s.streamFramer = newStreamFramer(s.cryptoStream, s.streamsMap, s.version)
s.packer = newPacketPacker(s.connectionID,
initialPacketNumber,
s.sentPacketHandler.GetPacketNumberLen,
s.RemoteAddr(),
cs,
s.streamFramer,
s.perspective,
s.version,
)
if err := s.postSetup(); err != nil {
return nil, err
}
s.peerParams = peerParams
s.processTransportParameters(peerParams)
s.unpacker = &packetUnpacker{aead: s.cryptoSetup, version: s.version}
s.unpacker = &packetUnpacker{aead: cs, version: s.version}
return s, nil
}
@@ -302,12 +343,25 @@ var newTLSClientSession = func(
if err != nil {
return nil, err
}
s.cryptoSetup = cs
return s, s.postSetup(initialPacketNumber)
s.cryptoStreamHandler = cs
s.unpacker = &packetUnpacker{aead: cs, version: s.version}
s.streamsMap = newStreamsMap(s, s.newFlowController, s.config.MaxIncomingStreams, s.config.MaxIncomingUniStreams, s.perspective, s.version)
s.streamFramer = newStreamFramer(s.cryptoStream, s.streamsMap, s.version)
s.packer = newPacketPacker(s.connectionID,
initialPacketNumber,
s.sentPacketHandler.GetPacketNumberLen,
s.RemoteAddr(),
cs,
s.streamFramer,
s.perspective,
s.version,
)
return s, s.postSetup()
}
func (s *session) preSetup() {
s.rttStats = &congestion.RTTStats{}
s.sentPacketHandler = ackhandler.NewSentPacketHandler(s.rttStats)
s.connFlowController = flowcontrol.NewConnectionFlowController(
protocol.ReceiveConnectionFlowControlWindow,
protocol.ByteCount(s.config.MaxReceiveConnectionFlowControlWindow),
@@ -316,7 +370,7 @@ func (s *session) preSetup() {
s.cryptoStream = s.newCryptoStream()
}
func (s *session) postSetup(initialPacketNumber protocol.PacketNumber) error {
func (s *session) postSetup() error {
s.handshakeChan = make(chan error, 1)
s.receivedPackets = make(chan *receivedPacket, protocol.MaxSessionUnprocessedPackets)
s.closeChan = make(chan closeError, 1)
@@ -329,26 +383,8 @@ func (s *session) postSetup(initialPacketNumber protocol.PacketNumber) error {
s.lastNetworkActivityTime = now
s.sessionCreationTime = now
s.sentPacketHandler = ackhandler.NewSentPacketHandler(s.rttStats)
s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.rttStats, s.version)
if s.version.UsesTLS() {
s.streamsMap = newStreamsMap(s, s.newFlowController, s.config.MaxIncomingStreams, s.config.MaxIncomingUniStreams, s.perspective, s.version)
} else {
s.streamsMap = newStreamsMapLegacy(s.newStream, s.config.MaxIncomingStreams, s.perspective)
}
s.streamFramer = newStreamFramer(s.cryptoStream, s.streamsMap, s.version)
s.packer = newPacketPacker(s.connectionID,
initialPacketNumber,
s.sentPacketHandler.GetPacketNumberLen,
s.RemoteAddr(),
s.cryptoSetup,
s.streamFramer,
s.perspective,
s.version,
)
s.windowUpdateQueue = newWindowUpdateQueue(s.streamsMap, s.cryptoStream, s.packer.QueueControlFrame)
s.unpacker = &packetUnpacker{aead: s.cryptoSetup, version: s.version}
return nil
}
@@ -357,7 +393,7 @@ func (s *session) run() error {
defer s.ctxCancel()
go func() {
if err := s.cryptoSetup.HandleCryptoStream(); err != nil {
if err := s.cryptoStreamHandler.HandleCryptoStream(); err != nil {
s.Close(err)
}
}()
@@ -472,7 +508,7 @@ func (s *session) Context() context.Context {
}
func (s *session) ConnectionState() ConnectionState {
return s.cryptoSetup.ConnectionState()
return s.cryptoStreamHandler.ConnectionState()
}
func (s *session) maybeResetTimer() {

View File

@@ -1274,7 +1274,7 @@ var _ = Describe("Session", func() {
BeforeEach(func() {
sess.unpacker = &mockUnpacker{unpackErr: qerr.Error(qerr.DecryptionFailure, "")}
sess.cryptoSetup = &mockCryptoSetup{}
sess.cryptoStreamHandler = &mockCryptoSetup{}
streamManager.EXPECT().CloseWithError(gomock.Any()).MaxTimes(1)
})