diff --git a/conn_id_manager.go b/conn_id_manager.go index 8681b54a0..c10ee2cb6 100644 --- a/conn_id_manager.go +++ b/conn_id_manager.go @@ -15,6 +15,7 @@ import ( type connIDManager struct { queue utils.NewConnectionIDList + handshakeComplete bool activeSequenceNumber uint64 highestRetired uint64 activeConnectionID protocol.ConnectionID @@ -190,7 +191,10 @@ func (h *connIDManager) SentPacket() { } func (h *connIDManager) shouldUpdateConnID() bool { - // initiate the first change as early as possible + if !h.handshakeComplete { + return false + } + // initiate the first change as early as possible (after handshake completion) if h.queue.Len() > 0 && h.activeSequenceNumber == 0 { return true } @@ -207,3 +211,7 @@ func (h *connIDManager) Get() protocol.ConnectionID { } return h.activeConnectionID } + +func (h *connIDManager) SetHandshakeComplete() { + h.handshakeComplete = true +} diff --git a/conn_id_manager_test.go b/conn_id_manager_test.go index 76e2b4ae5..3cd406a82 100644 --- a/conn_id_manager_test.go +++ b/conn_id_manager_test.go @@ -104,6 +104,7 @@ var _ = Describe("Connection ID Manager", func() { ConnectionID: protocol.ConnectionID{1, 2, 3, 4}, StatelessResetToken: protocol.StatelessResetToken{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0xa, 0xb, 0xc, 0xd, 0xe}, } + m.SetHandshakeComplete() Expect(m.Add(f)).To(Succeed()) Expect(m.Get()).To(Equal(protocol.ConnectionID{1, 2, 3, 4})) c, _ := get() @@ -206,6 +207,7 @@ var _ = Describe("Connection ID Manager", func() { SequenceNumber: 1, ConnectionID: connID, })).To(Succeed()) + m.SetHandshakeComplete() Expect(frameQueue).To(BeEmpty()) Expect(m.Get()).To(Equal(connID)) Expect(frameQueue).To(HaveLen(1)) @@ -237,6 +239,7 @@ var _ = Describe("Connection ID Manager", func() { It("initiates the first connection ID update as soon as possible", func() { Expect(m.Get()).To(Equal(initialConnID)) + m.SetHandshakeComplete() Expect(m.Add(&wire.NewConnectionIDFrame{ SequenceNumber: 1, ConnectionID: protocol.ConnectionID{1, 2, 3, 4}, @@ -245,6 +248,18 @@ var _ = Describe("Connection ID Manager", func() { Expect(m.Get()).To(Equal(protocol.ConnectionID{1, 2, 3, 4})) }) + It("waits until handshake completion before initiating a connection ID update", func() { + Expect(m.Get()).To(Equal(initialConnID)) + Expect(m.Add(&wire.NewConnectionIDFrame{ + SequenceNumber: 1, + ConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + StatelessResetToken: protocol.StatelessResetToken{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}, + })).To(Succeed()) + Expect(m.Get()).To(Equal(initialConnID)) + m.SetHandshakeComplete() + Expect(m.Get()).To(Equal(protocol.ConnectionID{1, 2, 3, 4})) + }) + It("initiates subsequent updates when enough packets are sent", func() { var s uint8 for s = uint8(1); s < protocol.MaxActiveConnectionIDs; s++ { @@ -255,6 +270,7 @@ var _ = Describe("Connection ID Manager", func() { })).To(Succeed()) } + m.SetHandshakeComplete() lastConnID := m.Get() Expect(lastConnID).To(Equal(protocol.ConnectionID{1, 1, 1, 1})) @@ -287,6 +303,7 @@ var _ = Describe("Connection ID Manager", func() { StatelessResetToken: protocol.StatelessResetToken{s, s, s, s, s, s, s, s, s, s, s, s, s, s, s, s}, })).To(Succeed()) } + m.SetHandshakeComplete() Expect(m.Get()).To(Equal(protocol.ConnectionID{10, 10, 10, 10})) for { m.SentPacket() @@ -316,6 +333,7 @@ var _ = Describe("Connection ID Manager", func() { StatelessResetToken: protocol.StatelessResetToken{i, i, i, i, i, i, i, i, i, i, i, i, i, i, i, i}, })).To(Succeed()) } + m.SetHandshakeComplete() Expect(m.Get()).To(Equal(protocol.ConnectionID{1, 1, 1, 1})) for i := 0; i < 2*protocol.PacketsPerConnectionID; i++ { m.SentPacket() @@ -339,6 +357,7 @@ var _ = Describe("Connection ID Manager", func() { ConnectionID: protocol.ConnectionID{1, 2, 3, 4}, StatelessResetToken: protocol.StatelessResetToken{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}, })).To(Succeed()) + m.SetHandshakeComplete() Expect(m.Get()).To(Equal(protocol.ConnectionID{1, 2, 3, 4})) m.Close() Expect(retiredTokens).To(BeEmpty()) diff --git a/session.go b/session.go index 13a99fbc0..c3b0abe94 100644 --- a/session.go +++ b/session.go @@ -691,6 +691,7 @@ func (s *session) handleHandshakeComplete() { s.handshakeCompleteChan = nil // prevent this case from ever being selected again s.handshakeCtxCancel() + s.connIDManager.SetHandshakeComplete() s.connIDGenerator.SetHandshakeComplete() if s.perspective == protocol.PerspectiveServer { diff --git a/session_test.go b/session_test.go index a94bee623..2d42847d2 100644 --- a/session_test.go +++ b/session_test.go @@ -2248,6 +2248,7 @@ var _ = Describe("Client Session", func() { unpacker := NewMockUnpacker(mockCtrl) sess.unpacker = unpacker sessionRunner.EXPECT().AddResetToken(gomock.Any(), gomock.Any()) + sess.connIDManager.SetHandshakeComplete() sess.handleNewConnectionIDFrame(&wire.NewConnectionIDFrame{ SequenceNumber: 1, ConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5}, @@ -2491,6 +2492,7 @@ var _ = Describe("Client Session", func() { packer.EXPECT().PackCoalescedPacket().MaxTimes(1) tracer.EXPECT().ReceivedTransportParameters(params) sess.processTransportParameters(params) + sess.connIDManager.SetHandshakeComplete() // make sure the connection ID is not retired cf, _ := sess.framer.AppendControlFrames(nil, protocol.MaxByteCount) Expect(cf).To(BeEmpty())