diff --git a/conn_id_manager.go b/conn_id_manager.go index a4fbd93c..0fa862c3 100644 --- a/conn_id_manager.go +++ b/conn_id_manager.go @@ -2,11 +2,11 @@ package quic import ( "fmt" + "slices" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qerr" "github.com/quic-go/quic-go/internal/utils" - list "github.com/quic-go/quic-go/internal/utils/linkedlist" "github.com/quic-go/quic-go/internal/wire" ) @@ -17,7 +17,7 @@ type newConnID struct { } type connIDManager struct { - queue list.List[newConnID] + queue []newConnID highestProbingID uint64 pathProbing map[pathID]newConnID // initialized lazily @@ -53,6 +53,7 @@ func newConnIDManager( addStatelessResetToken: addStatelessResetToken, removeStatelessResetToken: removeStatelessResetToken, queueControlFrame: queueControlFrame, + queue: make([]newConnID, 0, protocol.MaxActiveConnectionIDs), } } @@ -64,7 +65,7 @@ func (h *connIDManager) Add(f *wire.NewConnectionIDFrame) error { if err := h.add(f); err != nil { return err } - if h.queue.Len() >= protocol.MaxActiveConnectionIDs { + if len(h.queue) >= protocol.MaxActiveConnectionIDs { return &qerr.TransportError{ErrorCode: qerr.ConnectionIDLimitError} } return nil @@ -99,17 +100,15 @@ func (h *connIDManager) add(f *wire.NewConnectionIDFrame) error { // Retire elements in the queue. // Doesn't retire the active connection ID. if f.RetirePriorTo > h.highestRetired { - var next *list.Element[newConnID] - for el := h.queue.Front(); el != nil; el = next { - if el.Value.SequenceNumber >= f.RetirePriorTo { - break + var newQueue []newConnID + for _, entry := range h.queue { + if entry.SequenceNumber >= f.RetirePriorTo { + newQueue = append(newQueue, entry) + } else { + h.queueControlFrame(&wire.RetireConnectionIDFrame{SequenceNumber: entry.SequenceNumber}) } - next = el.Next() - h.queueControlFrame(&wire.RetireConnectionIDFrame{ - SequenceNumber: el.Value.SequenceNumber, - }) - h.queue.Remove(el) } + h.queue = newQueue h.highestRetired = f.RetirePriorTo } @@ -130,36 +129,39 @@ func (h *connIDManager) add(f *wire.NewConnectionIDFrame) error { } func (h *connIDManager) addConnectionID(seq uint64, connID protocol.ConnectionID, resetToken protocol.StatelessResetToken) error { - // insert a new element at the end - if h.queue.Len() == 0 || h.queue.Back().Value.SequenceNumber < seq { - h.queue.PushBack(newConnID{ + // fast path: add to the end of the queue + if len(h.queue) == 0 || h.queue[len(h.queue)-1].SequenceNumber < seq { + h.queue = append(h.queue, newConnID{ SequenceNumber: seq, ConnectionID: connID, StatelessResetToken: resetToken, }) return nil } - // insert a new element somewhere in the middle - for el := h.queue.Front(); el != nil; el = el.Next() { - if el.Value.SequenceNumber == seq { - if el.Value.ConnectionID != connID { + + // slow path: insert in the middle + for i, entry := range h.queue { + if entry.SequenceNumber == seq { + if entry.ConnectionID != connID { return fmt.Errorf("received conflicting connection IDs for sequence number %d", seq) } - if el.Value.StatelessResetToken != resetToken { + if entry.StatelessResetToken != resetToken { return fmt.Errorf("received conflicting stateless reset tokens for sequence number %d", seq) } - break + return nil } - if el.Value.SequenceNumber > seq { - h.queue.InsertBefore(newConnID{ + + // insert at the correct position to maintain sorted order + if entry.SequenceNumber > seq { + h.queue = slices.Insert(h.queue, i, newConnID{ SequenceNumber: seq, ConnectionID: connID, StatelessResetToken: resetToken, - }, el) - break + }) + return nil } } - return nil + return nil // unreachable } func (h *connIDManager) updateConnectionID() { @@ -172,7 +174,8 @@ func (h *connIDManager) updateConnectionID() { h.removeStatelessResetToken(*h.activeStatelessResetToken) } - front := h.queue.Remove(h.queue.Front()) + front := h.queue[0] + h.queue = h.queue[1:] h.activeSequenceNumber = front.SequenceNumber h.activeConnectionID = front.ConnectionID h.activeStatelessResetToken = &front.StatelessResetToken @@ -216,13 +219,13 @@ func (h *connIDManager) shouldUpdateConnID() bool { return false } // initiate the first change as early as possible (after handshake completion) - if h.queue.Len() > 0 && h.activeSequenceNumber == 0 { + if len(h.queue) > 0 && h.activeSequenceNumber == 0 { return true } // For later changes, only change if // 1. The queue of connection IDs is filled more than 50%. // 2. We sent at least PacketsPerConnectionID packets - return 2*h.queue.Len() >= protocol.MaxActiveConnectionIDs && + return 2*len(h.queue) >= protocol.MaxActiveConnectionIDs && h.packetsSinceLastChange >= h.packetsPerConnectionID } @@ -256,10 +259,11 @@ func (h *connIDManager) GetConnIDForPath(id pathID) (protocol.ConnectionID, bool if ok { return entry.ConnectionID, true } - if h.queue.Len() == 0 { + if len(h.queue) == 0 { return protocol.ConnectionID{}, false } - front := h.queue.Remove(h.queue.Front()) + front := h.queue[0] + h.queue = h.queue[1:] h.pathProbing[id] = front h.highestProbingID = front.SequenceNumber return front.ConnectionID, true diff --git a/conn_id_manager_test.go b/conn_id_manager_test.go index 405abc98..94748568 100644 --- a/conn_id_manager_test.go +++ b/conn_id_manager_test.go @@ -329,3 +329,52 @@ func TestConnIDManagerClose(t *testing.T) { require.Panics(t, func() { m.Get() }) require.Panics(t, func() { m.SetStatelessResetToken(protocol.StatelessResetToken{}) }) } + +func BenchmarkConnIDManagerReordered(b *testing.B) { + benchmarkConnIDManager(b, true) +} + +func BenchmarkConnIDManagerInOrder(b *testing.B) { + benchmarkConnIDManager(b, false) +} + +func benchmarkConnIDManager(b *testing.B, reordered bool) { + m := newConnIDManager( + protocol.ParseConnectionID([]byte{1, 2, 3, 4}), + func(protocol.StatelessResetToken) {}, + func(protocol.StatelessResetToken) {}, + func(f wire.Frame) {}, + ) + connIDs := make([]protocol.ConnectionID, 0, protocol.MaxActiveConnectionIDs) + statelessResetTokens := make([]protocol.StatelessResetToken, 0, protocol.MaxActiveConnectionIDs) + for range protocol.MaxActiveConnectionIDs { + b := make([]byte, 8) + rand.Read(b) + connIDs = append(connIDs, protocol.ParseConnectionID(b)) + var statelessResetToken protocol.StatelessResetToken + rand.Read(statelessResetToken[:]) + statelessResetTokens = append(statelessResetTokens, statelessResetToken) + } + + // 1 -> 3 + // 2 -> 1 + // 3 -> 2 + // 4 -> 4 + offsets := []int{2, -1, -1, 0} + + b.ResetTimer() + for i := range b.N { + seq := i + if reordered { + seq += offsets[i%len(offsets)] + } + m.Add(&wire.NewConnectionIDFrame{ + SequenceNumber: uint64(seq), + ConnectionID: connIDs[i%len(connIDs)], + StatelessResetToken: statelessResetTokens[i%len(statelessResetTokens)], + }) + if i > protocol.MaxActiveConnectionIDs-2 { + m.updateConnectionID() + } + } +}