diff --git a/conn_id_manager.go b/conn_id_manager.go index 0fa862c3..4ca436a6 100644 --- a/conn_id_manager.go +++ b/conn_id_manager.go @@ -93,6 +93,7 @@ func (h *connIDManager) add(f *wire.NewConnectionIDFrame) error { h.queueControlFrame(&wire.RetireConnectionIDFrame{ SequenceNumber: entry.SequenceNumber, }) + h.removeStatelessResetToken(entry.StatelessResetToken) delete(h.pathProbing, id) } } @@ -189,6 +190,11 @@ func (h *connIDManager) Close() { if h.activeStatelessResetToken != nil { h.removeStatelessResetToken(*h.activeStatelessResetToken) } + if h.pathProbing != nil { + for _, entry := range h.pathProbing { + h.removeStatelessResetToken(entry.StatelessResetToken) + } + } } // is called when the server performs a Retry @@ -266,6 +272,7 @@ func (h *connIDManager) GetConnIDForPath(id pathID) (protocol.ConnectionID, bool h.queue = h.queue[1:] h.pathProbing[id] = front h.highestProbingID = front.SequenceNumber + h.addStatelessResetToken(front.StatelessResetToken) return front.ConnectionID, true } @@ -283,6 +290,7 @@ func (h *connIDManager) RetireConnIDForPath(pathID pathID) { h.queueControlFrame(&wire.RetireConnectionIDFrame{ SequenceNumber: entry.SequenceNumber, }) + h.removeStatelessResetToken(entry.StatelessResetToken) delete(h.pathProbing, pathID) } diff --git a/conn_id_manager_test.go b/conn_id_manager_test.go index 94748568..4e7b1ca4 100644 --- a/conn_id_manager_test.go +++ b/conn_id_manager_test.go @@ -211,10 +211,11 @@ func TestConnIDManagerConnIDRotation(t *testing.T) { func TestConnIDManagerPathMigration(t *testing.T) { var frameQueue []wire.Frame + var addedTokens, removedTokens []protocol.StatelessResetToken m := newConnIDManager( protocol.ParseConnectionID([]byte{1, 2, 3, 4}), - func(protocol.StatelessResetToken) {}, - func(protocol.StatelessResetToken) {}, + func(token protocol.StatelessResetToken) { addedTokens = append(addedTokens, token) }, + func(token protocol.StatelessResetToken) { removedTokens = append(removedTokens, token) }, func(f wire.Frame) { frameQueue = append(frameQueue, f) }, ) @@ -226,35 +227,50 @@ func TestConnIDManagerPathMigration(t *testing.T) { require.NoError(t, m.Add(&wire.NewConnectionIDFrame{ SequenceNumber: 1, ConnectionID: protocol.ParseConnectionID([]byte{4, 3, 2, 1}), - StatelessResetToken: protocol.StatelessResetToken{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}, + StatelessResetToken: protocol.StatelessResetToken{4, 3, 2, 1, 4, 3, 2, 1}, })) require.NoError(t, m.Add(&wire.NewConnectionIDFrame{ SequenceNumber: 2, ConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2}), - StatelessResetToken: protocol.StatelessResetToken{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}, + StatelessResetToken: protocol.StatelessResetToken{5, 4, 3, 2, 5, 4, 3, 2}, })) connID, ok := m.GetConnIDForPath(1) require.True(t, ok) require.Equal(t, protocol.ParseConnectionID([]byte{4, 3, 2, 1}), connID) + require.Equal(t, []protocol.StatelessResetToken{{4, 3, 2, 1, 4, 3, 2, 1}}, addedTokens) + require.Empty(t, removedTokens) + + addedTokens = addedTokens[:0] connID, ok = m.GetConnIDForPath(2) require.True(t, ok) require.Equal(t, protocol.ParseConnectionID([]byte{5, 4, 3, 2}), connID) + require.Equal(t, []protocol.StatelessResetToken{{5, 4, 3, 2, 5, 4, 3, 2}}, addedTokens) + require.Empty(t, removedTokens) + + addedTokens = addedTokens[:0] // asking for the connection for path 1 again returns the same connection ID connID, ok = m.GetConnIDForPath(1) require.True(t, ok) require.Equal(t, protocol.ParseConnectionID([]byte{4, 3, 2, 1}), connID) + require.Empty(t, addedTokens) // if the connection ID is retired, the path will use another connection ID require.NoError(t, m.Add(&wire.NewConnectionIDFrame{ SequenceNumber: 3, RetirePriorTo: 2, ConnectionID: protocol.ParseConnectionID([]byte{6, 5, 4, 3}), - StatelessResetToken: protocol.StatelessResetToken{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}, + StatelessResetToken: protocol.StatelessResetToken{6, 5, 4, 3, 6, 5, 4, 3}, })) require.Len(t, frameQueue, 2) + require.Equal(t, []protocol.StatelessResetToken{{4, 3, 2, 1, 4, 3, 2, 1}}, removedTokens) frameQueue = nil + removedTokens = removedTokens[:0] require.Equal(t, protocol.ParseConnectionID([]byte{6, 5, 4, 3}), m.Get()) + require.Equal(t, []protocol.StatelessResetToken{{6, 5, 4, 3, 6, 5, 4, 3}}, addedTokens) + require.Empty(t, removedTokens) + addedTokens = addedTokens[:0] + // the connection ID is not used for new paths _, ok = m.GetConnIDForPath(3) require.False(t, ok) @@ -265,20 +281,31 @@ func TestConnIDManagerPathMigration(t *testing.T) { require.Empty(t, frameQueue) _, ok = m.GetConnIDForPath(1) require.False(t, ok) + require.Empty(t, removedTokens) // only after a new connection ID is added, it will be used for path 1 require.NoError(t, m.Add(&wire.NewConnectionIDFrame{ SequenceNumber: 4, ConnectionID: protocol.ParseConnectionID([]byte{7, 6, 5, 4}), - StatelessResetToken: protocol.StatelessResetToken{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}, + StatelessResetToken: protocol.StatelessResetToken{16, 15, 14, 13}, })) connID, ok = m.GetConnIDForPath(1) require.True(t, ok) require.Equal(t, protocol.ParseConnectionID([]byte{7, 6, 5, 4}), connID) + require.Equal(t, []protocol.StatelessResetToken{{16, 15, 14, 13}}, addedTokens) + require.Empty(t, removedTokens) // a RETIRE_CONNECTION_ID frame for path 1 is queued when retiring the connection ID m.RetireConnIDForPath(1) require.Equal(t, []wire.Frame{&wire.RetireConnectionIDFrame{SequenceNumber: 4}}, frameQueue) + require.Equal(t, []protocol.StatelessResetToken{{16, 15, 14, 13}}, removedTokens) + removedTokens = removedTokens[:0] + + m.Close() + require.Equal(t, []protocol.StatelessResetToken{ + {6, 5, 4, 3, 6, 5, 4, 3}, // currently active connection ID + {5, 4, 3, 2, 5, 4, 3, 2}, // path 2 + }, removedTokens) } func TestConnIDManagerZeroLengthConnectionID(t *testing.T) { @@ -289,7 +316,7 @@ func TestConnIDManagerZeroLengthConnectionID(t *testing.T) { func(f wire.Frame) {}, ) require.Equal(t, protocol.ConnectionID{}, m.Get()) - for i := 0; i < 5*protocol.PacketsPerConnectionID; i++ { + for range 5 * protocol.PacketsPerConnectionID { m.SentPacket() require.Equal(t, protocol.ConnectionID{}, m.Get()) } diff --git a/connection_test.go b/connection_test.go index ae83f488..81008d75 100644 --- a/connection_test.go +++ b/connection_test.go @@ -3063,6 +3063,7 @@ func testConnectionMigration(t *testing.T, enabled bool) { return shortHeaderPacket{IsPathProbePacket: true}, getPacketBuffer(), nil }, ).AnyTimes() + tc.connRunner.EXPECT().AddResetToken(gomock.Any(), gomock.Any()) // add a new connection ID, so the path can be probed require.NoError(t, tc.conn.handleNewConnectionIDFrame(&wire.NewConnectionIDFrame{ SequenceNumber: 1, @@ -3089,6 +3090,7 @@ func testConnectionMigration(t *testing.T, enabled bool) { // teardown tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() + tc.connRunner.EXPECT().RemoveResetToken(gomock.Any()).MaxTimes(1) tc.conn.destroy(nil) select { case <-errChan: