diff --git a/session.go b/session.go index a0178180..9996e3d5 100644 --- a/session.go +++ b/session.go @@ -44,6 +44,11 @@ type streamManager interface { CloseWithError(error) } +type cryptoStreamHandler interface { + HandleCryptoStream() error + ConnectionState() handshake.ConnectionState +} + type receivedPacket struct { remoteAddr net.Addr header *wire.Header @@ -84,8 +89,8 @@ type session struct { unpacker unpacker packer *packetPacker - cryptoSetup handshake.CryptoSetup - divNonceChan chan<- []byte // only set for the client + cryptoStreamHandler cryptoStreamHandler + divNonceChan chan<- []byte // only set for the client receivedPackets chan *receivedPacket sendingScheduled chan struct{} @@ -177,8 +182,20 @@ func newSession( if err != nil { return nil, err } - s.cryptoSetup = cs - return s, s.postSetup(1) + s.cryptoStreamHandler = cs + s.unpacker = &packetUnpacker{aead: cs, version: s.version} + s.streamsMap = newStreamsMapLegacy(s.newStream, s.config.MaxIncomingStreams, s.perspective) + s.streamFramer = newStreamFramer(s.cryptoStream, s.streamsMap, s.version) + s.packer = newPacketPacker(s.connectionID, + 1, + s.sentPacketHandler.GetPacketNumberLen, + s.RemoteAddr(), + cs, + s.streamFramer, + s.perspective, + s.version, + ) + return s, s.postSetup() } // declare this as a variable, so that we can it mock it in the tests @@ -226,9 +243,21 @@ var newClientSession = func( if err != nil { return nil, err } - s.cryptoSetup = cs + s.cryptoStreamHandler = cs s.divNonceChan = divNonceChan - return s, s.postSetup(1) + s.unpacker = &packetUnpacker{aead: cs, version: s.version} + s.streamsMap = newStreamsMapLegacy(s.newStream, s.config.MaxIncomingStreams, s.perspective) + s.streamFramer = newStreamFramer(s.cryptoStream, s.streamsMap, s.version) + s.packer = newPacketPacker(s.connectionID, + 1, + s.sentPacketHandler.GetPacketNumberLen, + s.RemoteAddr(), + cs, + s.streamFramer, + s.perspective, + s.version, + ) + return s, s.postSetup() } func newTLSServerSession( @@ -252,19 +281,31 @@ func newTLSServerSession( handshakeEvent: handshakeEvent, } s.preSetup() - s.cryptoSetup = handshake.NewCryptoSetupTLSServer( + cs := handshake.NewCryptoSetupTLSServer( tls, cryptoStreamConn, nullAEAD, handshakeEvent, v, ) - if err := s.postSetup(initialPacketNumber); err != nil { + s.cryptoStreamHandler = cs + s.streamsMap = newStreamsMap(s, s.newFlowController, s.config.MaxIncomingStreams, s.config.MaxIncomingUniStreams, s.perspective, s.version) + s.streamFramer = newStreamFramer(s.cryptoStream, s.streamsMap, s.version) + s.packer = newPacketPacker(s.connectionID, + initialPacketNumber, + s.sentPacketHandler.GetPacketNumberLen, + s.RemoteAddr(), + cs, + s.streamFramer, + s.perspective, + s.version, + ) + if err := s.postSetup(); err != nil { return nil, err } s.peerParams = peerParams s.processTransportParameters(peerParams) - s.unpacker = &packetUnpacker{aead: s.cryptoSetup, version: s.version} + s.unpacker = &packetUnpacker{aead: cs, version: s.version} return s, nil } @@ -302,12 +343,25 @@ var newTLSClientSession = func( if err != nil { return nil, err } - s.cryptoSetup = cs - return s, s.postSetup(initialPacketNumber) + s.cryptoStreamHandler = cs + s.unpacker = &packetUnpacker{aead: cs, version: s.version} + s.streamsMap = newStreamsMap(s, s.newFlowController, s.config.MaxIncomingStreams, s.config.MaxIncomingUniStreams, s.perspective, s.version) + s.streamFramer = newStreamFramer(s.cryptoStream, s.streamsMap, s.version) + s.packer = newPacketPacker(s.connectionID, + initialPacketNumber, + s.sentPacketHandler.GetPacketNumberLen, + s.RemoteAddr(), + cs, + s.streamFramer, + s.perspective, + s.version, + ) + return s, s.postSetup() } func (s *session) preSetup() { s.rttStats = &congestion.RTTStats{} + s.sentPacketHandler = ackhandler.NewSentPacketHandler(s.rttStats) s.connFlowController = flowcontrol.NewConnectionFlowController( protocol.ReceiveConnectionFlowControlWindow, protocol.ByteCount(s.config.MaxReceiveConnectionFlowControlWindow), @@ -316,7 +370,7 @@ func (s *session) preSetup() { s.cryptoStream = s.newCryptoStream() } -func (s *session) postSetup(initialPacketNumber protocol.PacketNumber) error { +func (s *session) postSetup() error { s.handshakeChan = make(chan error, 1) s.receivedPackets = make(chan *receivedPacket, protocol.MaxSessionUnprocessedPackets) s.closeChan = make(chan closeError, 1) @@ -329,26 +383,8 @@ func (s *session) postSetup(initialPacketNumber protocol.PacketNumber) error { s.lastNetworkActivityTime = now s.sessionCreationTime = now - s.sentPacketHandler = ackhandler.NewSentPacketHandler(s.rttStats) s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.rttStats, s.version) - - if s.version.UsesTLS() { - s.streamsMap = newStreamsMap(s, s.newFlowController, s.config.MaxIncomingStreams, s.config.MaxIncomingUniStreams, s.perspective, s.version) - } else { - s.streamsMap = newStreamsMapLegacy(s.newStream, s.config.MaxIncomingStreams, s.perspective) - } - s.streamFramer = newStreamFramer(s.cryptoStream, s.streamsMap, s.version) - s.packer = newPacketPacker(s.connectionID, - initialPacketNumber, - s.sentPacketHandler.GetPacketNumberLen, - s.RemoteAddr(), - s.cryptoSetup, - s.streamFramer, - s.perspective, - s.version, - ) s.windowUpdateQueue = newWindowUpdateQueue(s.streamsMap, s.cryptoStream, s.packer.QueueControlFrame) - s.unpacker = &packetUnpacker{aead: s.cryptoSetup, version: s.version} return nil } @@ -357,7 +393,7 @@ func (s *session) run() error { defer s.ctxCancel() go func() { - if err := s.cryptoSetup.HandleCryptoStream(); err != nil { + if err := s.cryptoStreamHandler.HandleCryptoStream(); err != nil { s.Close(err) } }() @@ -472,7 +508,7 @@ func (s *session) Context() context.Context { } func (s *session) ConnectionState() ConnectionState { - return s.cryptoSetup.ConnectionState() + return s.cryptoStreamHandler.ConnectionState() } func (s *session) maybeResetTimer() { diff --git a/session_test.go b/session_test.go index 07db95c9..4a21f589 100644 --- a/session_test.go +++ b/session_test.go @@ -1274,7 +1274,7 @@ var _ = Describe("Session", func() { BeforeEach(func() { sess.unpacker = &mockUnpacker{unpackErr: qerr.Error(qerr.DecryptionFailure, "")} - sess.cryptoSetup = &mockCryptoSetup{} + sess.cryptoStreamHandler = &mockCryptoSetup{} streamManager.EXPECT().CloseWithError(gomock.Any()).MaxTimes(1) })