diff --git a/client.go b/client.go index 297b95f9..787fa89c 100644 --- a/client.go +++ b/client.go @@ -37,6 +37,8 @@ type client struct { destConnID protocol.ConnectionID origDestConnID protocol.ConnectionID // the destination conn ID used on the first Initial (before a Retry) + initialPacketNumber protocol.PacketNumber + initialVersion protocol.VersionNumber version protocol.VersionNumber @@ -340,7 +342,7 @@ func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) { c.version = newVersion c.logger.Infof("Switching to QUIC version %s. New connection ID: %s", newVersion, c.destConnID) - c.session.closeForRecreating() + c.initialPacketNumber = c.session.closeForRecreating() } func (c *client) handleRetryPacket(hdr *wire.Header) { @@ -366,7 +368,7 @@ func (c *client) handleRetryPacket(hdr *wire.Header) { c.origDestConnID = c.destConnID c.destConnID = hdr.SrcConnectionID c.token = hdr.Token - c.session.closeForRecreating() + c.initialPacketNumber = c.session.closeForRecreating() } func (c *client) createNewTLSSession(version protocol.VersionNumber) error { @@ -397,6 +399,7 @@ func (c *client) createNewTLSSession(version protocol.VersionNumber) error { c.srcConnID, c.config, c.tlsConf, + c.initialPacketNumber, params, c.initialVersion, c.logger, diff --git a/client_test.go b/client_test.go index 847e1b99..875abcd6 100644 --- a/client_test.go +++ b/client_test.go @@ -38,6 +38,7 @@ var _ = Describe("Client", func() { srcConnID protocol.ConnectionID, conf *Config, tlsConf *tls.Config, + initialPacketNumber protocol.PacketNumber, params *handshake.TransportParameters, initialVersion protocol.VersionNumber, logger utils.Logger, @@ -142,6 +143,7 @@ var _ = Describe("Client", func() { _ protocol.ConnectionID, _ *Config, _ *tls.Config, + _ protocol.PacketNumber, _ *handshake.TransportParameters, _ protocol.VersionNumber, _ utils.Logger, @@ -172,6 +174,7 @@ var _ = Describe("Client", func() { _ protocol.ConnectionID, _ *Config, tlsConf *tls.Config, + _ protocol.PacketNumber, _ *handshake.TransportParameters, _ protocol.VersionNumber, _ utils.Logger, @@ -202,6 +205,7 @@ var _ = Describe("Client", func() { _ protocol.ConnectionID, _ *Config, _ *tls.Config, + _ protocol.PacketNumber, _ *handshake.TransportParameters, _ protocol.VersionNumber, _ utils.Logger, @@ -239,6 +243,7 @@ var _ = Describe("Client", func() { _ protocol.ConnectionID, _ *Config, _ *tls.Config, + _ protocol.PacketNumber, _ *handshake.TransportParameters, _ protocol.VersionNumber, _ utils.Logger, @@ -279,6 +284,7 @@ var _ = Describe("Client", func() { _ protocol.ConnectionID, _ *Config, _ *tls.Config, + _ protocol.PacketNumber, _ *handshake.TransportParameters, _ protocol.VersionNumber, _ utils.Logger, @@ -324,6 +330,7 @@ var _ = Describe("Client", func() { _ protocol.ConnectionID, _ *Config, _ *tls.Config, + _ protocol.PacketNumber, _ *handshake.TransportParameters, _ protocol.VersionNumber, _ utils.Logger, @@ -369,6 +376,7 @@ var _ = Describe("Client", func() { _ protocol.ConnectionID, _ *Config, _ *tls.Config, + _ protocol.PacketNumber, _ *handshake.TransportParameters, _ protocol.VersionNumber, _ utils.Logger, @@ -484,6 +492,7 @@ var _ = Describe("Client", func() { _ protocol.ConnectionID, configP *Config, _ *tls.Config, + _ protocol.PacketNumber, params *handshake.TransportParameters, _ protocol.VersionNumber, /* initial version */ _ utils.Logger, @@ -530,8 +539,9 @@ var _ = Describe("Client", func() { sess1.EXPECT().run().DoAndReturn(func() error { return <-run1 }) - sess1.EXPECT().closeForRecreating().Do(func() { + sess1.EXPECT().closeForRecreating().DoAndReturn(func() protocol.PacketNumber { run1 <- errCloseForRecreating + return 42 }) sess2 := NewMockQuicSession(mockCtrl) sess2.EXPECT().run() @@ -547,6 +557,7 @@ var _ = Describe("Client", func() { _ protocol.ConnectionID, _ *Config, _ *tls.Config, + initialPacketNumber protocol.PacketNumber, _ *handshake.TransportParameters, _ protocol.VersionNumber, _ utils.Logger, @@ -554,9 +565,11 @@ var _ = Describe("Client", func() { ) (quicSession, error) { switch len(sessions) { case 2: // for the first session + Expect(initialPacketNumber).To(BeZero()) Expect(origDestConnID).To(BeNil()) Expect(destConnID).ToNot(BeNil()) case 1: // for the second session + Expect(initialPacketNumber).To(Equal(protocol.PacketNumber(42))) Expect(origDestConnID).To(Equal(connID)) Expect(destConnID).ToNot(Equal(connID)) } @@ -615,6 +628,7 @@ var _ = Describe("Client", func() { _ protocol.ConnectionID, _ *Config, _ *tls.Config, + _ protocol.PacketNumber, _ *handshake.TransportParameters, _ protocol.VersionNumber, _ utils.Logger, @@ -654,6 +668,7 @@ var _ = Describe("Client", func() { _ protocol.ConnectionID, _ *Config, _ *tls.Config, + _ protocol.PacketNumber, _ *handshake.TransportParameters, _ protocol.VersionNumber, _ utils.Logger, diff --git a/internal/ackhandler/sent_packet_handler.go b/internal/ackhandler/sent_packet_handler.go index c5516ece..223daea4 100644 --- a/internal/ackhandler/sent_packet_handler.go +++ b/internal/ackhandler/sent_packet_handler.go @@ -78,7 +78,11 @@ type sentPacketHandler struct { } // NewSentPacketHandler creates a new sentPacketHandler -func NewSentPacketHandler(rttStats *congestion.RTTStats, logger utils.Logger) SentPacketHandler { +func NewSentPacketHandler( + initialPacketNumber protocol.PacketNumber, + rttStats *congestion.RTTStats, + logger utils.Logger, +) SentPacketHandler { congestion := congestion.NewCubicSender( congestion.DefaultClock{}, rttStats, @@ -88,7 +92,7 @@ func NewSentPacketHandler(rttStats *congestion.RTTStats, logger utils.Logger) Se ) return &sentPacketHandler{ - packetNumberGenerator: newPacketNumberGenerator(0, protocol.SkipPacketAveragePeriodLength), + packetNumberGenerator: newPacketNumberGenerator(initialPacketNumber, protocol.SkipPacketAveragePeriodLength), packetHistory: newSentPacketHistory(), rttStats: rttStats, congestion: congestion, @@ -144,8 +148,10 @@ func (h *sentPacketHandler) SentPacketsAsRetransmission(packets []*Packet, retra } func (h *sentPacketHandler) sentPacketImpl(packet *Packet) bool /* isRetransmittable */ { - for p := h.lastSentPacketNumber + 1; p < packet.PacketNumber; p++ { - h.logger.Debugf("Skipping packet number %#x", p) + if h.logger.Debug() && h.lastSentPacketNumber != 0 { + for p := h.lastSentPacketNumber + 1; p < packet.PacketNumber; p++ { + h.logger.Debugf("Skipping packet number %#x", p) + } } h.lastSentPacketNumber = packet.PacketNumber diff --git a/internal/ackhandler/sent_packet_handler_test.go b/internal/ackhandler/sent_packet_handler_test.go index fabdd133..d8e43f73 100644 --- a/internal/ackhandler/sent_packet_handler_test.go +++ b/internal/ackhandler/sent_packet_handler_test.go @@ -49,7 +49,7 @@ var _ = Describe("SentPacketHandler", func() { BeforeEach(func() { rttStats := &congestion.RTTStats{} - handler = NewSentPacketHandler(rttStats, utils.DefaultLogger).(*sentPacketHandler) + handler = NewSentPacketHandler(42, rttStats, utils.DefaultLogger).(*sentPacketHandler) handler.SetHandshakeComplete() streamFrame = wire.StreamFrame{ StreamID: 5, @@ -962,4 +962,17 @@ var _ = Describe("SentPacketHandler", func() { Expect(packet).To(BeNil()) }) }) + + Context("peeking and popping packet number", func() { + It("peeks and pops the initial packet number", func() { + pn, _ := handler.PeekPacketNumber() + Expect(pn).To(Equal(protocol.PacketNumber(42))) + Expect(handler.PopPacketNumber()).To(Equal(protocol.PacketNumber(42))) + }) + + It("peeks and pops beyond the initial packet number", func() { + Expect(handler.PopPacketNumber()).To(Equal(protocol.PacketNumber(42))) + Expect(handler.PopPacketNumber()).To(BeNumerically(">", 42)) + }) + }) }) diff --git a/mock_quic_session_test.go b/mock_quic_session_test.go index d3b034c9..226d3beb 100644 --- a/mock_quic_session_test.go +++ b/mock_quic_session_test.go @@ -200,8 +200,10 @@ func (mr *MockQuicSessionMockRecorder) RemoteAddr() *gomock.Call { } // closeForRecreating mocks base method -func (m *MockQuicSession) closeForRecreating() { - m.ctrl.Call(m, "closeForRecreating") +func (m *MockQuicSession) closeForRecreating() protocol.PacketNumber { + ret := m.ctrl.Call(m, "closeForRecreating") + ret0, _ := ret[0].(protocol.PacketNumber) + return ret0 } // closeForRecreating indicates an expected call of closeForRecreating diff --git a/server.go b/server.go index 6bfc51a3..14b43cff 100644 --- a/server.go +++ b/server.go @@ -43,7 +43,7 @@ type quicSession interface { GetVersion() protocol.VersionNumber run() error destroy(error) - closeForRecreating() + closeForRecreating() protocol.PacketNumber closeRemote(error) } diff --git a/session.go b/session.go index 0552d9bb..2cdd6d8c 100644 --- a/session.go +++ b/session.go @@ -158,6 +158,7 @@ var newSession = func( version: v, } s.preSetup() + s.sentPacketHandler = ackhandler.NewSentPacketHandler(0, s.rttStats, s.logger) initialStream := newCryptoStream() handshakeStream := newCryptoStream() s.streamsMap = newStreamsMap( @@ -218,6 +219,7 @@ var newClientSession = func( srcConnID protocol.ConnectionID, conf *Config, tlsConf *tls.Config, + initialPacketNumber protocol.PacketNumber, params *handshake.TransportParameters, initialVersion protocol.VersionNumber, logger utils.Logger, @@ -235,6 +237,7 @@ var newClientSession = func( version: v, } s.preSetup() + s.sentPacketHandler = ackhandler.NewSentPacketHandler(initialPacketNumber, s.rttStats, s.logger) initialStream := newCryptoStream() handshakeStream := newCryptoStream() cs, clientHelloWritten, err := handshake.NewCryptoSetupClient( @@ -286,7 +289,6 @@ var newClientSession = func( func (s *session) preSetup() { s.rttStats = &congestion.RTTStats{} - s.sentPacketHandler = ackhandler.NewSentPacketHandler(s.rttStats, s.logger) s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.rttStats, s.logger, s.version) s.connFlowController = flowcontrol.NewConnectionFlowController( protocol.InitialMaxData, @@ -720,8 +722,12 @@ func (s *session) destroy(e error) { }) } -func (s *session) closeForRecreating() { +// closeForRecreating closes the session in order to recreate it immediately afterwards +// It returns the first packet number that should be used in the new session. +func (s *session) closeForRecreating() protocol.PacketNumber { s.destroy(errCloseForRecreating) + nextPN, _ := s.sentPacketHandler.PeekPacketNumber() + return nextPN } func (s *session) closeRemote(e error) { diff --git a/session_test.go b/session_test.go index 3d6cc069..02fb3579 100644 --- a/session_test.go +++ b/session_test.go @@ -1325,6 +1325,7 @@ var _ = Describe("Client Session", func() { protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}, populateClientConfig(&Config{}, true), nil, // tls.Config + 42, // initial packet number nil, // transport parameters protocol.VersionWhatever, utils.DefaultLogger,