diff --git a/session.go b/session.go index 45e243451..7677e1d50 100644 --- a/session.go +++ b/session.go @@ -130,46 +130,8 @@ func newSession( perspective: protocol.PerspectiveServer, version: v, config: config, - - connectionParameters: handshake.NewConnectionParamatersManager(protocol.PerspectiveServer, v), } - - s.setup() - cryptoStream, _ := s.GetOrOpenStream(1) - _, _ = s.AcceptStream() // don't expose the crypto stream - aeadChanged := make(chan protocol.EncryptionLevel, 2) - s.aeadChanged = aeadChanged - handshakeChan := make(chan handshakeEvent, 3) - s.handshakeChan = handshakeChan - verifySourceAddr := func(clientAddr net.Addr, hstk *handshake.STK) bool { - if hstk == nil { - return config.AcceptSTK(clientAddr, nil) - } - return config.AcceptSTK( - clientAddr, - &STK{remoteAddr: hstk.RemoteAddr, sentTime: hstk.SentTime}, - ) - } - var err error - s.cryptoSetup, err = newCryptoSetup( - connectionID, - conn.RemoteAddr(), - v, - sCfg, - cryptoStream, - s.connectionParameters, - config.Versions, - verifySourceAddr, - aeadChanged, - ) - if err != nil { - return nil, nil, err - } - - s.packer = newPacketPacker(connectionID, s.cryptoSetup, s.connectionParameters, s.streamFramer, s.perspective, s.version) - s.unpacker = &packetUnpacker{aead: s.cryptoSetup, version: s.version} - - return s, handshakeChan, err + return s.setup(sCfg, "", nil) } // declare this as a variable, such that we can it mock it in the tests @@ -187,67 +149,83 @@ var newClientSession = func( perspective: protocol.PerspectiveClient, version: v, config: config, - - connectionParameters: handshake.NewConnectionParamatersManager(protocol.PerspectiveClient, v), } + return s.setup(nil, hostname, negotiatedVersions) +} - s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.ackAlarmChanged) - s.setup() - +func (s *session) setup( + scfg *handshake.ServerConfig, + hostname string, + negotiatedVersions []protocol.VersionNumber, +) (packetHandler, <-chan handshakeEvent, error) { aeadChanged := make(chan protocol.EncryptionLevel, 2) s.aeadChanged = aeadChanged handshakeChan := make(chan handshakeEvent, 3) s.handshakeChan = handshakeChan - cryptoStream, _ := s.OpenStream() - var err error - s.cryptoSetup, err = newCryptoSetupClient( - hostname, - connectionID, - v, - cryptoStream, - config.TLSConfig, - s.connectionParameters, - aeadChanged, - &handshake.TransportParameters{RequestConnectionIDTruncation: config.RequestConnectionIDTruncation}, - negotiatedVersions, - ) - if err != nil { - return nil, nil, err - } - - s.packer = newPacketPacker(connectionID, s.cryptoSetup, s.connectionParameters, s.streamFramer, s.perspective, s.version) - s.unpacker = &packetUnpacker{aead: s.cryptoSetup, version: s.version} - - return s, handshakeChan, err -} - -// setup is called from newSession and newClientSession and initializes values that are independent of the perspective -func (s *session) setup() { - s.rttStats = &congestion.RTTStats{} - flowControlManager := flowcontrol.NewFlowControlManager(s.connectionParameters, s.rttStats) - - sentPacketHandler := ackhandler.NewSentPacketHandler(s.rttStats) - - now := time.Now() - - s.sentPacketHandler = sentPacketHandler - s.flowControlManager = flowControlManager - s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.ackAlarmChanged) - + s.runClosed = make(chan struct{}) + s.handshakeCompleteChan = make(chan error, 1) s.receivedPackets = make(chan *receivedPacket, protocol.MaxSessionUnprocessedPackets) s.closeChan = make(chan closeError, 1) s.sendingScheduled = make(chan struct{}, 1) s.undecryptablePackets = make([]*receivedPacket, 0, protocol.MaxUndecryptablePackets) - s.aeadChanged = make(chan protocol.EncryptionLevel, 2) - s.runClosed = make(chan struct{}) - s.handshakeCompleteChan = make(chan error, 1) s.timer = time.NewTimer(0) + now := time.Now() s.lastNetworkActivityTime = now s.sessionCreationTime = now + s.rttStats = &congestion.RTTStats{} + s.connectionParameters = handshake.NewConnectionParamatersManager(s.perspective, s.version) + s.sentPacketHandler = ackhandler.NewSentPacketHandler(s.rttStats) + s.flowControlManager = flowcontrol.NewFlowControlManager(s.connectionParameters, s.rttStats) + s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.ackAlarmChanged) s.streamsMap = newStreamsMap(s.newStream, s.perspective, s.connectionParameters) s.streamFramer = newStreamFramer(s.streamsMap, s.flowControlManager) + + var err error + if s.perspective == protocol.PerspectiveServer { + cryptoStream, _ := s.GetOrOpenStream(1) + _, _ = s.AcceptStream() // don't expose the crypto stream + verifySourceAddr := func(clientAddr net.Addr, hstk *handshake.STK) bool { + var stk *STK + if hstk != nil { + stk = &STK{remoteAddr: hstk.RemoteAddr, sentTime: hstk.SentTime} + } + return s.config.AcceptSTK(clientAddr, stk) + } + s.cryptoSetup, err = newCryptoSetup( + s.connectionID, + s.conn.RemoteAddr(), + s.version, + scfg, + cryptoStream, + s.connectionParameters, + s.config.Versions, + verifySourceAddr, + aeadChanged, + ) + } else { + cryptoStream, _ := s.OpenStream() + s.cryptoSetup, err = newCryptoSetupClient( + hostname, + s.connectionID, + s.version, + cryptoStream, + s.config.TLSConfig, + s.connectionParameters, + aeadChanged, + &handshake.TransportParameters{RequestConnectionIDTruncation: s.config.RequestConnectionIDTruncation}, + negotiatedVersions, + ) + } + if err != nil { + return nil, nil, err + } + + s.packer = newPacketPacker(s.connectionID, s.cryptoSetup, s.connectionParameters, s.streamFramer, s.perspective, s.version) + s.unpacker = &packetUnpacker{aead: s.cryptoSetup, version: s.version} + + return s, handshakeChan, nil } // run the session main loop