From 80534c094419555f50d5e6882e2f8bc559c952d7 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 2 Nov 2020 16:20:26 +0700 Subject: [PATCH] wait until the handshake is complete before updating the connection ID --- conn_id_manager.go | 10 +++++++++- conn_id_manager_test.go | 19 +++++++++++++++++++ session.go | 1 + session_test.go | 2 ++ 4 files changed, 31 insertions(+), 1 deletion(-) diff --git a/conn_id_manager.go b/conn_id_manager.go index 8681b54a0..c10ee2cb6 100644 --- a/conn_id_manager.go +++ b/conn_id_manager.go @@ -15,6 +15,7 @@ import ( type connIDManager struct { queue utils.NewConnectionIDList + handshakeComplete bool activeSequenceNumber uint64 highestRetired uint64 activeConnectionID protocol.ConnectionID @@ -190,7 +191,10 @@ func (h *connIDManager) SentPacket() { } func (h *connIDManager) shouldUpdateConnID() bool { - // initiate the first change as early as possible + if !h.handshakeComplete { + return false + } + // initiate the first change as early as possible (after handshake completion) if h.queue.Len() > 0 && h.activeSequenceNumber == 0 { return true } @@ -207,3 +211,7 @@ func (h *connIDManager) Get() protocol.ConnectionID { } return h.activeConnectionID } + +func (h *connIDManager) SetHandshakeComplete() { + h.handshakeComplete = true +} diff --git a/conn_id_manager_test.go b/conn_id_manager_test.go index 76e2b4ae5..3cd406a82 100644 --- a/conn_id_manager_test.go +++ b/conn_id_manager_test.go @@ -104,6 +104,7 @@ var _ = Describe("Connection ID Manager", func() { ConnectionID: protocol.ConnectionID{1, 2, 3, 4}, StatelessResetToken: protocol.StatelessResetToken{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0xa, 0xb, 0xc, 0xd, 0xe}, } + m.SetHandshakeComplete() Expect(m.Add(f)).To(Succeed()) Expect(m.Get()).To(Equal(protocol.ConnectionID{1, 2, 3, 4})) c, _ := get() @@ -206,6 +207,7 @@ var _ = Describe("Connection ID Manager", func() { SequenceNumber: 1, ConnectionID: connID, })).To(Succeed()) + m.SetHandshakeComplete() Expect(frameQueue).To(BeEmpty()) Expect(m.Get()).To(Equal(connID)) Expect(frameQueue).To(HaveLen(1)) @@ -237,6 +239,7 @@ var _ = Describe("Connection ID Manager", func() { It("initiates the first connection ID update as soon as possible", func() { Expect(m.Get()).To(Equal(initialConnID)) + m.SetHandshakeComplete() Expect(m.Add(&wire.NewConnectionIDFrame{ SequenceNumber: 1, ConnectionID: protocol.ConnectionID{1, 2, 3, 4}, @@ -245,6 +248,18 @@ var _ = Describe("Connection ID Manager", func() { Expect(m.Get()).To(Equal(protocol.ConnectionID{1, 2, 3, 4})) }) + It("waits until handshake completion before initiating a connection ID update", func() { + Expect(m.Get()).To(Equal(initialConnID)) + Expect(m.Add(&wire.NewConnectionIDFrame{ + SequenceNumber: 1, + ConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + StatelessResetToken: protocol.StatelessResetToken{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}, + })).To(Succeed()) + Expect(m.Get()).To(Equal(initialConnID)) + m.SetHandshakeComplete() + Expect(m.Get()).To(Equal(protocol.ConnectionID{1, 2, 3, 4})) + }) + It("initiates subsequent updates when enough packets are sent", func() { var s uint8 for s = uint8(1); s < protocol.MaxActiveConnectionIDs; s++ { @@ -255,6 +270,7 @@ var _ = Describe("Connection ID Manager", func() { })).To(Succeed()) } + m.SetHandshakeComplete() lastConnID := m.Get() Expect(lastConnID).To(Equal(protocol.ConnectionID{1, 1, 1, 1})) @@ -287,6 +303,7 @@ var _ = Describe("Connection ID Manager", func() { StatelessResetToken: protocol.StatelessResetToken{s, s, s, s, s, s, s, s, s, s, s, s, s, s, s, s}, })).To(Succeed()) } + m.SetHandshakeComplete() Expect(m.Get()).To(Equal(protocol.ConnectionID{10, 10, 10, 10})) for { m.SentPacket() @@ -316,6 +333,7 @@ var _ = Describe("Connection ID Manager", func() { StatelessResetToken: protocol.StatelessResetToken{i, i, i, i, i, i, i, i, i, i, i, i, i, i, i, i}, })).To(Succeed()) } + m.SetHandshakeComplete() Expect(m.Get()).To(Equal(protocol.ConnectionID{1, 1, 1, 1})) for i := 0; i < 2*protocol.PacketsPerConnectionID; i++ { m.SentPacket() @@ -339,6 +357,7 @@ var _ = Describe("Connection ID Manager", func() { ConnectionID: protocol.ConnectionID{1, 2, 3, 4}, StatelessResetToken: protocol.StatelessResetToken{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}, })).To(Succeed()) + m.SetHandshakeComplete() Expect(m.Get()).To(Equal(protocol.ConnectionID{1, 2, 3, 4})) m.Close() Expect(retiredTokens).To(BeEmpty()) diff --git a/session.go b/session.go index 2275c834f..1a750ada1 100644 --- a/session.go +++ b/session.go @@ -691,6 +691,7 @@ func (s *session) handleHandshakeComplete() { s.handshakeCompleteChan = nil // prevent this case from ever being selected again s.handshakeCtxCancel() + s.connIDManager.SetHandshakeComplete() s.connIDGenerator.SetHandshakeComplete() if s.perspective == protocol.PerspectiveServer { diff --git a/session_test.go b/session_test.go index 81fdfee0e..e50d2422f 100644 --- a/session_test.go +++ b/session_test.go @@ -2251,6 +2251,7 @@ var _ = Describe("Client Session", func() { unpacker := NewMockUnpacker(mockCtrl) sess.unpacker = unpacker sessionRunner.EXPECT().AddResetToken(gomock.Any(), gomock.Any()) + sess.connIDManager.SetHandshakeComplete() sess.handleNewConnectionIDFrame(&wire.NewConnectionIDFrame{ SequenceNumber: 1, ConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5}, @@ -2494,6 +2495,7 @@ var _ = Describe("Client Session", func() { packer.EXPECT().PackCoalescedPacket(protocol.MaxByteCount).MaxTimes(1) tracer.EXPECT().ReceivedTransportParameters(params) sess.processTransportParameters(params) + sess.connIDManager.SetHandshakeComplete() // make sure the connection ID is not retired cf, _ := sess.framer.AppendControlFrames(nil, protocol.MaxByteCount) Expect(cf).To(BeEmpty())