diff --git a/conn_id_manager.go b/conn_id_manager.go index 4aa3f749..0e65231c 100644 --- a/conn_id_manager.go +++ b/conn_id_manager.go @@ -66,6 +66,12 @@ func (h *connIDManager) Add(f *wire.NewConnectionIDFrame) error { } func (h *connIDManager) add(f *wire.NewConnectionIDFrame) error { + if h.activeConnectionID.Len() == 0 { + return &qerr.TransportError{ + ErrorCode: qerr.ProtocolViolation, + ErrorMessage: "received NEW_CONNECTION_ID frame but zero-length connection IDs are in use", + } + } // If the NEW_CONNECTION_ID frame is reordered, such that its sequence number is smaller than the currently active // connection ID or if it was already retired, send the RETIRE_CONNECTION_ID frame immediately. if f.SequenceNumber < h.activeSequenceNumber || f.SequenceNumber < h.highestRetired { diff --git a/conn_id_manager_test.go b/conn_id_manager_test.go index ffe435de..193d74c2 100644 --- a/conn_id_manager_test.go +++ b/conn_id_manager_test.go @@ -209,6 +209,26 @@ func TestConnIDManagerConnIDRotation(t *testing.T) { require.Equal(t, []wire.Frame{&wire.RetireConnectionIDFrame{SequenceNumber: 2}}, frameQueue) } +func TestConnIDManagerZeroLengthConnectionID(t *testing.T) { + m := newConnIDManager( + protocol.ConnectionID{}, + func(protocol.StatelessResetToken) {}, + func(protocol.StatelessResetToken) {}, + func(f wire.Frame) {}, + ) + require.Equal(t, protocol.ConnectionID{}, m.Get()) + for i := 0; i < 5*protocol.PacketsPerConnectionID; i++ { + m.SentPacket() + require.Equal(t, protocol.ConnectionID{}, m.Get()) + } + + require.ErrorIs(t, m.Add(&wire.NewConnectionIDFrame{ + SequenceNumber: 1, + ConnectionID: protocol.ConnectionID{}, + StatelessResetToken: protocol.StatelessResetToken{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}, + }), &qerr.TransportError{ErrorCode: qerr.ProtocolViolation}) +} + func TestConnIDManagerClose(t *testing.T) { var addedTokens, removedTokens []protocol.StatelessResetToken m := newConnIDManager(