forked from quic-go/quic-go
refactor session to use a much smaller crypto setup
This commit is contained in:
100
session.go
100
session.go
@@ -44,6 +44,11 @@ type streamManager interface {
|
||||
CloseWithError(error)
|
||||
}
|
||||
|
||||
type cryptoStreamHandler interface {
|
||||
HandleCryptoStream() error
|
||||
ConnectionState() handshake.ConnectionState
|
||||
}
|
||||
|
||||
type receivedPacket struct {
|
||||
remoteAddr net.Addr
|
||||
header *wire.Header
|
||||
@@ -84,8 +89,8 @@ type session struct {
|
||||
unpacker unpacker
|
||||
packer *packetPacker
|
||||
|
||||
cryptoSetup handshake.CryptoSetup
|
||||
divNonceChan chan<- []byte // only set for the client
|
||||
cryptoStreamHandler cryptoStreamHandler
|
||||
divNonceChan chan<- []byte // only set for the client
|
||||
|
||||
receivedPackets chan *receivedPacket
|
||||
sendingScheduled chan struct{}
|
||||
@@ -177,8 +182,20 @@ func newSession(
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.cryptoSetup = cs
|
||||
return s, s.postSetup(1)
|
||||
s.cryptoStreamHandler = cs
|
||||
s.unpacker = &packetUnpacker{aead: cs, version: s.version}
|
||||
s.streamsMap = newStreamsMapLegacy(s.newStream, s.config.MaxIncomingStreams, s.perspective)
|
||||
s.streamFramer = newStreamFramer(s.cryptoStream, s.streamsMap, s.version)
|
||||
s.packer = newPacketPacker(s.connectionID,
|
||||
1,
|
||||
s.sentPacketHandler.GetPacketNumberLen,
|
||||
s.RemoteAddr(),
|
||||
cs,
|
||||
s.streamFramer,
|
||||
s.perspective,
|
||||
s.version,
|
||||
)
|
||||
return s, s.postSetup()
|
||||
}
|
||||
|
||||
// declare this as a variable, so that we can it mock it in the tests
|
||||
@@ -226,9 +243,21 @@ var newClientSession = func(
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.cryptoSetup = cs
|
||||
s.cryptoStreamHandler = cs
|
||||
s.divNonceChan = divNonceChan
|
||||
return s, s.postSetup(1)
|
||||
s.unpacker = &packetUnpacker{aead: cs, version: s.version}
|
||||
s.streamsMap = newStreamsMapLegacy(s.newStream, s.config.MaxIncomingStreams, s.perspective)
|
||||
s.streamFramer = newStreamFramer(s.cryptoStream, s.streamsMap, s.version)
|
||||
s.packer = newPacketPacker(s.connectionID,
|
||||
1,
|
||||
s.sentPacketHandler.GetPacketNumberLen,
|
||||
s.RemoteAddr(),
|
||||
cs,
|
||||
s.streamFramer,
|
||||
s.perspective,
|
||||
s.version,
|
||||
)
|
||||
return s, s.postSetup()
|
||||
}
|
||||
|
||||
func newTLSServerSession(
|
||||
@@ -252,19 +281,31 @@ func newTLSServerSession(
|
||||
handshakeEvent: handshakeEvent,
|
||||
}
|
||||
s.preSetup()
|
||||
s.cryptoSetup = handshake.NewCryptoSetupTLSServer(
|
||||
cs := handshake.NewCryptoSetupTLSServer(
|
||||
tls,
|
||||
cryptoStreamConn,
|
||||
nullAEAD,
|
||||
handshakeEvent,
|
||||
v,
|
||||
)
|
||||
if err := s.postSetup(initialPacketNumber); err != nil {
|
||||
s.cryptoStreamHandler = cs
|
||||
s.streamsMap = newStreamsMap(s, s.newFlowController, s.config.MaxIncomingStreams, s.config.MaxIncomingUniStreams, s.perspective, s.version)
|
||||
s.streamFramer = newStreamFramer(s.cryptoStream, s.streamsMap, s.version)
|
||||
s.packer = newPacketPacker(s.connectionID,
|
||||
initialPacketNumber,
|
||||
s.sentPacketHandler.GetPacketNumberLen,
|
||||
s.RemoteAddr(),
|
||||
cs,
|
||||
s.streamFramer,
|
||||
s.perspective,
|
||||
s.version,
|
||||
)
|
||||
if err := s.postSetup(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.peerParams = peerParams
|
||||
s.processTransportParameters(peerParams)
|
||||
s.unpacker = &packetUnpacker{aead: s.cryptoSetup, version: s.version}
|
||||
s.unpacker = &packetUnpacker{aead: cs, version: s.version}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
@@ -302,12 +343,25 @@ var newTLSClientSession = func(
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.cryptoSetup = cs
|
||||
return s, s.postSetup(initialPacketNumber)
|
||||
s.cryptoStreamHandler = cs
|
||||
s.unpacker = &packetUnpacker{aead: cs, version: s.version}
|
||||
s.streamsMap = newStreamsMap(s, s.newFlowController, s.config.MaxIncomingStreams, s.config.MaxIncomingUniStreams, s.perspective, s.version)
|
||||
s.streamFramer = newStreamFramer(s.cryptoStream, s.streamsMap, s.version)
|
||||
s.packer = newPacketPacker(s.connectionID,
|
||||
initialPacketNumber,
|
||||
s.sentPacketHandler.GetPacketNumberLen,
|
||||
s.RemoteAddr(),
|
||||
cs,
|
||||
s.streamFramer,
|
||||
s.perspective,
|
||||
s.version,
|
||||
)
|
||||
return s, s.postSetup()
|
||||
}
|
||||
|
||||
func (s *session) preSetup() {
|
||||
s.rttStats = &congestion.RTTStats{}
|
||||
s.sentPacketHandler = ackhandler.NewSentPacketHandler(s.rttStats)
|
||||
s.connFlowController = flowcontrol.NewConnectionFlowController(
|
||||
protocol.ReceiveConnectionFlowControlWindow,
|
||||
protocol.ByteCount(s.config.MaxReceiveConnectionFlowControlWindow),
|
||||
@@ -316,7 +370,7 @@ func (s *session) preSetup() {
|
||||
s.cryptoStream = s.newCryptoStream()
|
||||
}
|
||||
|
||||
func (s *session) postSetup(initialPacketNumber protocol.PacketNumber) error {
|
||||
func (s *session) postSetup() error {
|
||||
s.handshakeChan = make(chan error, 1)
|
||||
s.receivedPackets = make(chan *receivedPacket, protocol.MaxSessionUnprocessedPackets)
|
||||
s.closeChan = make(chan closeError, 1)
|
||||
@@ -329,26 +383,8 @@ func (s *session) postSetup(initialPacketNumber protocol.PacketNumber) error {
|
||||
s.lastNetworkActivityTime = now
|
||||
s.sessionCreationTime = now
|
||||
|
||||
s.sentPacketHandler = ackhandler.NewSentPacketHandler(s.rttStats)
|
||||
s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.rttStats, s.version)
|
||||
|
||||
if s.version.UsesTLS() {
|
||||
s.streamsMap = newStreamsMap(s, s.newFlowController, s.config.MaxIncomingStreams, s.config.MaxIncomingUniStreams, s.perspective, s.version)
|
||||
} else {
|
||||
s.streamsMap = newStreamsMapLegacy(s.newStream, s.config.MaxIncomingStreams, s.perspective)
|
||||
}
|
||||
s.streamFramer = newStreamFramer(s.cryptoStream, s.streamsMap, s.version)
|
||||
s.packer = newPacketPacker(s.connectionID,
|
||||
initialPacketNumber,
|
||||
s.sentPacketHandler.GetPacketNumberLen,
|
||||
s.RemoteAddr(),
|
||||
s.cryptoSetup,
|
||||
s.streamFramer,
|
||||
s.perspective,
|
||||
s.version,
|
||||
)
|
||||
s.windowUpdateQueue = newWindowUpdateQueue(s.streamsMap, s.cryptoStream, s.packer.QueueControlFrame)
|
||||
s.unpacker = &packetUnpacker{aead: s.cryptoSetup, version: s.version}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -357,7 +393,7 @@ func (s *session) run() error {
|
||||
defer s.ctxCancel()
|
||||
|
||||
go func() {
|
||||
if err := s.cryptoSetup.HandleCryptoStream(); err != nil {
|
||||
if err := s.cryptoStreamHandler.HandleCryptoStream(); err != nil {
|
||||
s.Close(err)
|
||||
}
|
||||
}()
|
||||
@@ -472,7 +508,7 @@ func (s *session) Context() context.Context {
|
||||
}
|
||||
|
||||
func (s *session) ConnectionState() ConnectionState {
|
||||
return s.cryptoSetup.ConnectionState()
|
||||
return s.cryptoStreamHandler.ConnectionState()
|
||||
}
|
||||
|
||||
func (s *session) maybeResetTimer() {
|
||||
|
||||
@@ -1274,7 +1274,7 @@ var _ = Describe("Session", func() {
|
||||
|
||||
BeforeEach(func() {
|
||||
sess.unpacker = &mockUnpacker{unpackErr: qerr.Error(qerr.DecryptionFailure, "")}
|
||||
sess.cryptoSetup = &mockCryptoSetup{}
|
||||
sess.cryptoStreamHandler = &mockCryptoSetup{}
|
||||
streamManager.EXPECT().CloseWithError(gomock.Any()).MaxTimes(1)
|
||||
})
|
||||
|
||||
|
||||
Reference in New Issue
Block a user