diff --git a/retransmission_queue.go b/retransmission_queue.go index 7b6cb5330..323af5b0c 100644 --- a/retransmission_queue.go +++ b/retransmission_queue.go @@ -9,124 +9,106 @@ import ( "github.com/quic-go/quic-go/internal/wire" ) +type framesToRetransmit struct { + crypto []*wire.CryptoFrame + other []wire.Frame +} + type retransmissionQueue struct { - initial []wire.Frame - initialCryptoData []*wire.CryptoFrame - - handshake []wire.Frame - handshakeCryptoData []*wire.CryptoFrame - - appData []wire.Frame + initial *framesToRetransmit + handshake *framesToRetransmit + appData framesToRetransmit } func newRetransmissionQueue() *retransmissionQueue { - return &retransmissionQueue{} + return &retransmissionQueue{ + initial: &framesToRetransmit{}, + handshake: &framesToRetransmit{}, + } } func (q *retransmissionQueue) addInitial(f wire.Frame) { - if cf, ok := f.(*wire.CryptoFrame); ok { - q.initialCryptoData = append(q.initialCryptoData, cf) + if q.initial == nil { return } - q.initial = append(q.initial, f) + if cf, ok := f.(*wire.CryptoFrame); ok { + q.initial.crypto = append(q.initial.crypto, cf) + return + } + q.initial.other = append(q.initial.other, f) } func (q *retransmissionQueue) addHandshake(f wire.Frame) { - if cf, ok := f.(*wire.CryptoFrame); ok { - q.handshakeCryptoData = append(q.handshakeCryptoData, cf) + if q.handshake == nil { return } - q.handshake = append(q.handshake, f) + if cf, ok := f.(*wire.CryptoFrame); ok { + q.handshake.crypto = append(q.handshake.crypto, cf) + return + } + q.handshake.other = append(q.handshake.other, f) +} + +func (q *retransmissionQueue) addAppData(f wire.Frame) { + switch f := f.(type) { + case *wire.StreamFrame: + panic("STREAM frames are handled with their respective streams.") + case *wire.CryptoFrame: + q.appData.crypto = append(q.appData.crypto, f) + default: + q.appData.other = append(q.appData.other, f) + } } func (q *retransmissionQueue) HasData(encLevel protocol.EncryptionLevel) bool { //nolint:exhaustive // 0-RTT data is retransmitted in 1-RTT packets. switch encLevel { case protocol.EncryptionInitial: - return len(q.initialCryptoData) > 0 || len(q.initial) > 0 + return q.initial != nil && + (len(q.initial.crypto) > 0 || len(q.initial.other) > 0) case protocol.EncryptionHandshake: - return len(q.handshakeCryptoData) > 0 || len(q.handshake) > 0 + return q.handshake != nil && + (len(q.handshake.crypto) > 0 || len(q.handshake.other) > 0) case protocol.Encryption1RTT: - return len(q.appData) > 0 + return len(q.appData.crypto) > 0 || len(q.appData.other) > 0 } return false } -func (q *retransmissionQueue) addAppData(f wire.Frame) { - if _, ok := f.(*wire.StreamFrame); ok { - panic("STREAM frames are handled with their respective streams.") - } - q.appData = append(q.appData, f) -} - func (q *retransmissionQueue) GetFrame(encLevel protocol.EncryptionLevel, maxLen protocol.ByteCount, v protocol.Version) wire.Frame { - //nolint:exhaustive // 0-RTT packets can't contain retransmissions + var r *framesToRetransmit + //nolint:exhaustive // 0-RTT data is retransmitted in 1-RTT packets. switch encLevel { case protocol.EncryptionInitial: - return q.getInitialFrame(maxLen, v) + r = q.initial case protocol.EncryptionHandshake: - return q.getHandshakeFrame(maxLen, v) + r = q.handshake case protocol.Encryption1RTT: - return q.getAppDataFrame(maxLen, v) + r = &q.appData + } + if r == nil { + return nil } - return nil -} -func (q *retransmissionQueue) getInitialFrame(maxLen protocol.ByteCount, v protocol.Version) wire.Frame { - if len(q.initialCryptoData) > 0 { - f := q.initialCryptoData[0] + if len(r.crypto) > 0 { + f := r.crypto[0] newFrame, needsSplit := f.MaybeSplitOffFrame(maxLen, v) if newFrame == nil && !needsSplit { // the whole frame fits - q.initialCryptoData = q.initialCryptoData[1:] + r.crypto = r.crypto[1:] return f } if newFrame != nil { // frame was split. Leave the original frame in the queue. return newFrame } } - if len(q.initial) == 0 { + if len(r.other) == 0 { return nil } - f := q.initial[0] + f := r.other[0] if f.Length(v) > maxLen { return nil } - q.initial = q.initial[1:] - return f -} - -func (q *retransmissionQueue) getHandshakeFrame(maxLen protocol.ByteCount, v protocol.Version) wire.Frame { - if len(q.handshakeCryptoData) > 0 { - f := q.handshakeCryptoData[0] - newFrame, needsSplit := f.MaybeSplitOffFrame(maxLen, v) - if newFrame == nil && !needsSplit { // the whole frame fits - q.handshakeCryptoData = q.handshakeCryptoData[1:] - return f - } - if newFrame != nil { // frame was split. Leave the original frame in the queue. - return newFrame - } - } - if len(q.handshake) == 0 { - return nil - } - f := q.handshake[0] - if f.Length(v) > maxLen { - return nil - } - q.handshake = q.handshake[1:] - return f -} - -func (q *retransmissionQueue) getAppDataFrame(maxLen protocol.ByteCount, v protocol.Version) wire.Frame { - if len(q.appData) == 0 { - return nil - } - f := q.appData[0] - if f.Length(v) > maxLen { - return nil - } - q.appData = q.appData[1:] + r.other = r.other[1:] return f } @@ -135,10 +117,8 @@ func (q *retransmissionQueue) DropPackets(encLevel protocol.EncryptionLevel) { switch encLevel { case protocol.EncryptionInitial: q.initial = nil - q.initialCryptoData = nil case protocol.EncryptionHandshake: q.handshake = nil - q.handshakeCryptoData = nil default: panic(fmt.Sprintf("unexpected encryption level: %s", encLevel)) }