forked from quic-go/quic-go
Merge pull request #1570 from lucas-clemente/move-pngen
move the packet number generator to the ackhandler package
This commit is contained in:
@@ -423,7 +423,6 @@ func (c *client) createNewTLSSession(version protocol.VersionNumber) error {
|
||||
c.tlsConf,
|
||||
params,
|
||||
c.initialVersion,
|
||||
1,
|
||||
c.logger,
|
||||
c.version,
|
||||
)
|
||||
|
||||
@@ -40,7 +40,6 @@ var _ = Describe("Client", func() {
|
||||
tlsConf *tls.Config,
|
||||
params *handshake.TransportParameters,
|
||||
initialVersion protocol.VersionNumber,
|
||||
initialPacketNumber protocol.PacketNumber,
|
||||
logger utils.Logger,
|
||||
v protocol.VersionNumber,
|
||||
) (quicSession, error)
|
||||
@@ -143,7 +142,6 @@ var _ = Describe("Client", func() {
|
||||
_ *tls.Config,
|
||||
_ *handshake.TransportParameters,
|
||||
_ protocol.VersionNumber,
|
||||
_ protocol.PacketNumber,
|
||||
_ utils.Logger,
|
||||
_ protocol.VersionNumber,
|
||||
) (quicSession, error) {
|
||||
@@ -173,7 +171,6 @@ var _ = Describe("Client", func() {
|
||||
tlsConf *tls.Config,
|
||||
_ *handshake.TransportParameters,
|
||||
_ protocol.VersionNumber,
|
||||
_ protocol.PacketNumber,
|
||||
_ utils.Logger,
|
||||
_ protocol.VersionNumber,
|
||||
) (quicSession, error) {
|
||||
@@ -203,7 +200,6 @@ var _ = Describe("Client", func() {
|
||||
_ *tls.Config,
|
||||
_ *handshake.TransportParameters,
|
||||
_ protocol.VersionNumber,
|
||||
_ protocol.PacketNumber,
|
||||
_ utils.Logger,
|
||||
_ protocol.VersionNumber,
|
||||
) (quicSession, error) {
|
||||
@@ -240,7 +236,6 @@ var _ = Describe("Client", func() {
|
||||
_ *tls.Config,
|
||||
_ *handshake.TransportParameters,
|
||||
_ protocol.VersionNumber,
|
||||
_ protocol.PacketNumber,
|
||||
_ utils.Logger,
|
||||
_ protocol.VersionNumber,
|
||||
) (quicSession, error) {
|
||||
@@ -280,7 +275,6 @@ var _ = Describe("Client", func() {
|
||||
_ *tls.Config,
|
||||
_ *handshake.TransportParameters,
|
||||
_ protocol.VersionNumber,
|
||||
_ protocol.PacketNumber,
|
||||
_ utils.Logger,
|
||||
_ protocol.VersionNumber,
|
||||
) (quicSession, error) {
|
||||
@@ -325,7 +319,6 @@ var _ = Describe("Client", func() {
|
||||
_ *tls.Config,
|
||||
_ *handshake.TransportParameters,
|
||||
_ protocol.VersionNumber,
|
||||
_ protocol.PacketNumber,
|
||||
_ utils.Logger,
|
||||
_ protocol.VersionNumber,
|
||||
) (quicSession, error) {
|
||||
@@ -366,7 +359,6 @@ var _ = Describe("Client", func() {
|
||||
_ *tls.Config,
|
||||
_ *handshake.TransportParameters,
|
||||
_ protocol.VersionNumber,
|
||||
_ protocol.PacketNumber,
|
||||
_ utils.Logger,
|
||||
_ protocol.VersionNumber,
|
||||
) (quicSession, error) {
|
||||
@@ -481,7 +473,6 @@ var _ = Describe("Client", func() {
|
||||
_ *tls.Config,
|
||||
params *handshake.TransportParameters,
|
||||
_ protocol.VersionNumber, /* initial version */
|
||||
_ protocol.PacketNumber,
|
||||
_ utils.Logger,
|
||||
versionP protocol.VersionNumber,
|
||||
) (quicSession, error) {
|
||||
@@ -543,7 +534,6 @@ var _ = Describe("Client", func() {
|
||||
_ *tls.Config,
|
||||
_ *handshake.TransportParameters,
|
||||
_ protocol.VersionNumber,
|
||||
_ protocol.PacketNumber,
|
||||
_ utils.Logger,
|
||||
_ protocol.VersionNumber,
|
||||
) (quicSession, error) {
|
||||
@@ -603,7 +593,6 @@ var _ = Describe("Client", func() {
|
||||
_ *tls.Config,
|
||||
_ *handshake.TransportParameters,
|
||||
_ protocol.VersionNumber,
|
||||
_ protocol.PacketNumber,
|
||||
_ utils.Logger,
|
||||
_ protocol.VersionNumber,
|
||||
) (quicSession, error) {
|
||||
@@ -642,7 +631,6 @@ var _ = Describe("Client", func() {
|
||||
_ *tls.Config,
|
||||
_ *handshake.TransportParameters,
|
||||
_ protocol.VersionNumber,
|
||||
_ protocol.PacketNumber,
|
||||
_ utils.Logger,
|
||||
_ protocol.VersionNumber,
|
||||
) (quicSession, error) {
|
||||
|
||||
@@ -30,7 +30,9 @@ type SentPacketHandler interface {
|
||||
GetLowestPacketNotConfirmedAcked() protocol.PacketNumber
|
||||
DequeuePacketForRetransmission() *Packet
|
||||
DequeueProbePacket() (*Packet, error)
|
||||
GetPacketNumberLen(protocol.PacketNumber) protocol.PacketNumberLen
|
||||
|
||||
PeekPacketNumber() (protocol.PacketNumber, protocol.PacketNumberLen)
|
||||
PopPacketNumber() protocol.PacketNumber
|
||||
|
||||
GetAlarmTimeout() time.Time
|
||||
OnAlarm() error
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
package quic
|
||||
package ackhandler
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"math"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
// The packetNumberGenerator generates the packet number for the next packet
|
||||
@@ -15,6 +16,8 @@ type packetNumberGenerator struct {
|
||||
|
||||
next protocol.PacketNumber
|
||||
nextToSkip protocol.PacketNumber
|
||||
|
||||
history []protocol.PacketNumber
|
||||
}
|
||||
|
||||
func newPacketNumberGenerator(initial, averagePeriod protocol.PacketNumber) *packetNumberGenerator {
|
||||
@@ -37,6 +40,10 @@ func (p *packetNumberGenerator) Pop() protocol.PacketNumber {
|
||||
p.next++
|
||||
|
||||
if p.next == p.nextToSkip {
|
||||
if len(p.history)+1 > protocol.MaxTrackedSkippedPackets {
|
||||
p.history = p.history[1:]
|
||||
}
|
||||
p.history = append(p.history, p.next)
|
||||
p.next++
|
||||
p.generateNewSkip()
|
||||
}
|
||||
@@ -60,3 +67,12 @@ func (p *packetNumberGenerator) getRandomNumber() uint16 {
|
||||
num := uint16(b[0])<<8 + uint16(b[1])
|
||||
return num
|
||||
}
|
||||
|
||||
func (p *packetNumberGenerator) Validate(ack *wire.AckFrame) bool {
|
||||
for _, pn := range p.history {
|
||||
if ack.AcksPacket(pn) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
@@ -1,9 +1,10 @@
|
||||
package quic
|
||||
package ackhandler
|
||||
|
||||
import (
|
||||
"math"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
@@ -97,4 +98,45 @@ var _ = Describe("Packet Number Generator", func() {
|
||||
Expect(largest).To(BeNumerically(">", math.MaxUint16-300))
|
||||
Expect(sum / uint64(rep)).To(BeNumerically("==", uint64(math.MaxUint16/2), 1000))
|
||||
})
|
||||
|
||||
It("validates ACK frames", func() {
|
||||
var skipped []protocol.PacketNumber
|
||||
var lastPN protocol.PacketNumber
|
||||
for len(skipped) < 3 {
|
||||
if png.Peek() > lastPN+1 {
|
||||
skipped = append(skipped, lastPN+1)
|
||||
}
|
||||
lastPN = png.Pop()
|
||||
}
|
||||
invalidACK := &wire.AckFrame{
|
||||
AckRanges: []wire.AckRange{{Smallest: 1, Largest: lastPN}},
|
||||
}
|
||||
Expect(png.Validate(invalidACK)).To(BeFalse())
|
||||
validACK1 := &wire.AckFrame{
|
||||
AckRanges: []wire.AckRange{{Smallest: 1, Largest: skipped[0] - 1}},
|
||||
}
|
||||
Expect(png.Validate(validACK1)).To(BeTrue())
|
||||
validACK2 := &wire.AckFrame{
|
||||
AckRanges: []wire.AckRange{
|
||||
{Smallest: 1, Largest: skipped[0] - 1},
|
||||
{Smallest: skipped[0] + 1, Largest: skipped[1] - 1},
|
||||
{Smallest: skipped[1] + 1, Largest: skipped[2] - 1},
|
||||
{Smallest: skipped[2] + 1, Largest: skipped[2] + 100},
|
||||
},
|
||||
}
|
||||
Expect(png.Validate(validACK2)).To(BeTrue())
|
||||
})
|
||||
|
||||
It("tracks a maximum number of protocol.MaxTrackedSkippedPackets packets", func() {
|
||||
var skipped []protocol.PacketNumber
|
||||
var lastPN protocol.PacketNumber
|
||||
for len(skipped) < protocol.MaxTrackedSkippedPackets+3 {
|
||||
if png.Peek() > lastPN+1 {
|
||||
skipped = append(skipped, lastPN+1)
|
||||
}
|
||||
lastPN = png.Pop()
|
||||
Expect(len(png.history)).To(BeNumerically("<=", protocol.MaxTrackedSkippedPackets))
|
||||
}
|
||||
Expect(len(png.history)).To(Equal(protocol.MaxTrackedSkippedPackets))
|
||||
})
|
||||
})
|
||||
@@ -30,12 +30,13 @@ const (
|
||||
)
|
||||
|
||||
type sentPacketHandler struct {
|
||||
lastSentPacketNumber protocol.PacketNumber
|
||||
lastSentPacketNumber protocol.PacketNumber
|
||||
packetNumberGenerator *packetNumberGenerator
|
||||
|
||||
lastSentRetransmittablePacketTime time.Time
|
||||
lastSentHandshakePacketTime time.Time
|
||||
|
||||
nextPacketSendTime time.Time
|
||||
skippedPackets []protocol.PacketNumber
|
||||
|
||||
largestAcked protocol.PacketNumber
|
||||
largestReceivedPacketWithAck protocol.PacketNumber
|
||||
@@ -89,11 +90,12 @@ func NewSentPacketHandler(rttStats *congestion.RTTStats, logger utils.Logger, ve
|
||||
)
|
||||
|
||||
return &sentPacketHandler{
|
||||
packetHistory: newSentPacketHistory(),
|
||||
rttStats: rttStats,
|
||||
congestion: congestion,
|
||||
logger: logger,
|
||||
version: version,
|
||||
packetNumberGenerator: newPacketNumberGenerator(1, protocol.SkipPacketAveragePeriodLength),
|
||||
packetHistory: newSentPacketHistory(),
|
||||
rttStats: rttStats,
|
||||
congestion: congestion,
|
||||
logger: logger,
|
||||
version: version,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -147,10 +149,6 @@ func (h *sentPacketHandler) SentPacketsAsRetransmission(packets []*Packet, retra
|
||||
func (h *sentPacketHandler) sentPacketImpl(packet *Packet) bool /* isRetransmittable */ {
|
||||
for p := h.lastSentPacketNumber + 1; p < packet.PacketNumber; p++ {
|
||||
h.logger.Debugf("Skipping packet number %#x", p)
|
||||
h.skippedPackets = append(h.skippedPackets, p)
|
||||
if len(h.skippedPackets) > protocol.MaxTrackedSkippedPackets {
|
||||
h.skippedPackets = h.skippedPackets[1:]
|
||||
}
|
||||
}
|
||||
|
||||
h.lastSentPacketNumber = packet.PacketNumber
|
||||
@@ -197,7 +195,7 @@ func (h *sentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumbe
|
||||
h.largestReceivedPacketWithAck = withPacketNumber
|
||||
h.largestAcked = utils.MaxPacketNumber(h.largestAcked, largestAcked)
|
||||
|
||||
if h.skippedPacketsAcked(ackFrame) {
|
||||
if !h.packetNumberGenerator.Validate(ackFrame) {
|
||||
return qerr.Error(qerr.InvalidAckData, "Received an ACK for a skipped packet number")
|
||||
}
|
||||
|
||||
@@ -235,8 +233,6 @@ func (h *sentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumbe
|
||||
return err
|
||||
}
|
||||
h.updateLossDetectionAlarm()
|
||||
|
||||
h.garbageCollectSkippedPackets()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -518,8 +514,13 @@ func (h *sentPacketHandler) DequeueProbePacket() (*Packet, error) {
|
||||
return h.DequeuePacketForRetransmission(), nil
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) GetPacketNumberLen(p protocol.PacketNumber) protocol.PacketNumberLen {
|
||||
return protocol.GetPacketNumberLengthForHeader(p, h.lowestUnacked(), h.version)
|
||||
func (h *sentPacketHandler) PeekPacketNumber() (protocol.PacketNumber, protocol.PacketNumberLen) {
|
||||
pn := h.packetNumberGenerator.Peek()
|
||||
return pn, protocol.GetPacketNumberLengthForHeader(pn, h.lowestUnacked(), h.version)
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) PopPacketNumber() protocol.PacketNumber {
|
||||
return h.packetNumberGenerator.Pop()
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) SendMode() SendMode {
|
||||
@@ -630,23 +631,3 @@ func (h *sentPacketHandler) computeRTOTimeout() time.Duration {
|
||||
rto <<= h.rtoCount
|
||||
return utils.MinDuration(rto, maxRTOTimeout)
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) skippedPacketsAcked(ackFrame *wire.AckFrame) bool {
|
||||
for _, p := range h.skippedPackets {
|
||||
if ackFrame.AcksPacket(p) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) garbageCollectSkippedPackets() {
|
||||
lowestUnacked := h.lowestUnacked()
|
||||
deleteIndex := 0
|
||||
for i, p := range h.skippedPackets {
|
||||
if p < lowestUnacked {
|
||||
deleteIndex = i + 1
|
||||
}
|
||||
}
|
||||
h.skippedPackets = h.skippedPackets[deleteIndex:]
|
||||
}
|
||||
|
||||
@@ -89,12 +89,6 @@ var _ = Describe("SentPacketHandler", func() {
|
||||
ExpectWithOffset(1, handler.rttStats.SmoothedRTT()).To(Equal(rtt))
|
||||
}
|
||||
|
||||
It("determines the packet number length", func() {
|
||||
handler.largestAcked = 0x1337
|
||||
Expect(handler.GetPacketNumberLen(0x1338)).To(Equal(protocol.PacketNumberLen2))
|
||||
Expect(handler.GetPacketNumberLen(0xfffffff)).To(Equal(protocol.PacketNumberLen4))
|
||||
})
|
||||
|
||||
Context("registering sent packets", func() {
|
||||
It("accepts two consecutive packets", func() {
|
||||
handler.SentPacket(retransmittablePacket(&Packet{PacketNumber: 1}))
|
||||
@@ -102,7 +96,6 @@ var _ = Describe("SentPacketHandler", func() {
|
||||
Expect(handler.lastSentPacketNumber).To(Equal(protocol.PacketNumber(2)))
|
||||
expectInPacketHistory([]protocol.PacketNumber{1, 2})
|
||||
Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(2)))
|
||||
Expect(handler.skippedPackets).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("accepts packet number 0", func() {
|
||||
@@ -112,7 +105,6 @@ var _ = Describe("SentPacketHandler", func() {
|
||||
Expect(handler.lastSentPacketNumber).To(Equal(protocol.PacketNumber(1)))
|
||||
expectInPacketHistory([]protocol.PacketNumber{0, 1})
|
||||
Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(2)))
|
||||
Expect(handler.skippedPackets).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("stores the sent time", func() {
|
||||
@@ -134,94 +126,6 @@ var _ = Describe("SentPacketHandler", func() {
|
||||
Expect(handler.lastSentRetransmittablePacketTime).To(BeZero())
|
||||
Expect(handler.bytesInFlight).To(BeZero())
|
||||
})
|
||||
|
||||
Context("skipped packet numbers", func() {
|
||||
It("works with non-consecutive packet numbers", func() {
|
||||
handler.SentPacket(retransmittablePacket(&Packet{PacketNumber: 1}))
|
||||
handler.SentPacket(retransmittablePacket(&Packet{PacketNumber: 3}))
|
||||
Expect(handler.lastSentPacketNumber).To(Equal(protocol.PacketNumber(3)))
|
||||
expectInPacketHistory([]protocol.PacketNumber{1, 3})
|
||||
Expect(handler.skippedPackets).To(Equal([]protocol.PacketNumber{2}))
|
||||
})
|
||||
|
||||
It("works with non-retransmittable packets", func() {
|
||||
handler.SentPacket(nonRetransmittablePacket(&Packet{PacketNumber: 1}))
|
||||
handler.SentPacket(nonRetransmittablePacket(&Packet{PacketNumber: 3}))
|
||||
Expect(handler.skippedPackets).To(Equal([]protocol.PacketNumber{2}))
|
||||
})
|
||||
|
||||
It("recognizes multiple skipped packets", func() {
|
||||
handler.SentPacket(retransmittablePacket(&Packet{PacketNumber: 1}))
|
||||
handler.SentPacket(retransmittablePacket(&Packet{PacketNumber: 3}))
|
||||
handler.SentPacket(retransmittablePacket(&Packet{PacketNumber: 5}))
|
||||
Expect(handler.skippedPackets).To(Equal([]protocol.PacketNumber{2, 4}))
|
||||
})
|
||||
|
||||
It("recognizes multiple consecutive skipped packets", func() {
|
||||
handler.SentPacket(retransmittablePacket(&Packet{PacketNumber: 1}))
|
||||
handler.SentPacket(retransmittablePacket(&Packet{PacketNumber: 4}))
|
||||
Expect(handler.skippedPackets).To(Equal([]protocol.PacketNumber{2, 3}))
|
||||
})
|
||||
|
||||
It("limits the lengths of the skipped packet slice", func() {
|
||||
for i := protocol.PacketNumber(0); i < protocol.MaxTrackedSkippedPackets+5; i++ {
|
||||
handler.SentPacket(retransmittablePacket(&Packet{PacketNumber: 2*i + 1}))
|
||||
}
|
||||
Expect(handler.skippedPackets).To(HaveLen(protocol.MaxUndecryptablePackets))
|
||||
Expect(handler.skippedPackets[0]).To(Equal(protocol.PacketNumber(10)))
|
||||
Expect(handler.skippedPackets[protocol.MaxTrackedSkippedPackets-1]).To(Equal(protocol.PacketNumber(10 + 2*(protocol.MaxTrackedSkippedPackets-1))))
|
||||
})
|
||||
|
||||
Context("garbage collection", func() {
|
||||
It("keeps all packet numbers above the LargestAcked", func() {
|
||||
handler.skippedPackets = []protocol.PacketNumber{2, 5, 8, 10}
|
||||
handler.largestAcked = 1
|
||||
handler.garbageCollectSkippedPackets()
|
||||
Expect(handler.skippedPackets).To(Equal([]protocol.PacketNumber{2, 5, 8, 10}))
|
||||
})
|
||||
|
||||
It("doesn't keep packet numbers below the LargestAcked", func() {
|
||||
handler.skippedPackets = []protocol.PacketNumber{1, 5, 8, 10}
|
||||
handler.largestAcked = 5
|
||||
handler.garbageCollectSkippedPackets()
|
||||
Expect(handler.skippedPackets).To(Equal([]protocol.PacketNumber{8, 10}))
|
||||
})
|
||||
|
||||
It("deletes all packet numbers if LargestAcked is sufficiently high", func() {
|
||||
handler.skippedPackets = []protocol.PacketNumber{1, 5, 10}
|
||||
handler.largestAcked = 15
|
||||
handler.garbageCollectSkippedPackets()
|
||||
Expect(handler.skippedPackets).To(BeEmpty())
|
||||
})
|
||||
})
|
||||
|
||||
Context("ACK handling", func() {
|
||||
BeforeEach(func() {
|
||||
handler.SentPacket(retransmittablePacket(&Packet{PacketNumber: 10}))
|
||||
handler.SentPacket(retransmittablePacket(&Packet{PacketNumber: 12}))
|
||||
})
|
||||
|
||||
It("rejects ACKs for skipped packets", func() {
|
||||
ack := &wire.AckFrame{
|
||||
AckRanges: []wire.AckRange{{Smallest: 10, Largest: 12}},
|
||||
}
|
||||
err := handler.ReceivedAck(ack, 1337, protocol.Encryption1RTT, time.Now())
|
||||
Expect(err).To(MatchError("InvalidAckData: Received an ACK for a skipped packet number"))
|
||||
})
|
||||
|
||||
It("accepts an ACK that correctly nacks a skipped packet", func() {
|
||||
ack := &wire.AckFrame{
|
||||
AckRanges: []wire.AckRange{
|
||||
{Smallest: 12, Largest: 12},
|
||||
{Smallest: 10, Largest: 10},
|
||||
},
|
||||
}
|
||||
err := handler.ReceivedAck(ack, 1337, protocol.Encryption1RTT, time.Now())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(handler.largestAcked).ToNot(BeZero())
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Context("ACK processing", func() {
|
||||
|
||||
@@ -86,18 +86,6 @@ func (mr *MockSentPacketHandlerMockRecorder) GetLowestPacketNotConfirmedAcked()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLowestPacketNotConfirmedAcked", reflect.TypeOf((*MockSentPacketHandler)(nil).GetLowestPacketNotConfirmedAcked))
|
||||
}
|
||||
|
||||
// GetPacketNumberLen mocks base method
|
||||
func (m *MockSentPacketHandler) GetPacketNumberLen(arg0 protocol.PacketNumber) protocol.PacketNumberLen {
|
||||
ret := m.ctrl.Call(m, "GetPacketNumberLen", arg0)
|
||||
ret0, _ := ret[0].(protocol.PacketNumberLen)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetPacketNumberLen indicates an expected call of GetPacketNumberLen
|
||||
func (mr *MockSentPacketHandlerMockRecorder) GetPacketNumberLen(arg0 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPacketNumberLen", reflect.TypeOf((*MockSentPacketHandler)(nil).GetPacketNumberLen), arg0)
|
||||
}
|
||||
|
||||
// OnAlarm mocks base method
|
||||
func (m *MockSentPacketHandler) OnAlarm() error {
|
||||
ret := m.ctrl.Call(m, "OnAlarm")
|
||||
@@ -110,6 +98,31 @@ func (mr *MockSentPacketHandlerMockRecorder) OnAlarm() *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnAlarm", reflect.TypeOf((*MockSentPacketHandler)(nil).OnAlarm))
|
||||
}
|
||||
|
||||
// PeekPacketNumber mocks base method
|
||||
func (m *MockSentPacketHandler) PeekPacketNumber() (protocol.PacketNumber, protocol.PacketNumberLen) {
|
||||
ret := m.ctrl.Call(m, "PeekPacketNumber")
|
||||
ret0, _ := ret[0].(protocol.PacketNumber)
|
||||
ret1, _ := ret[1].(protocol.PacketNumberLen)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// PeekPacketNumber indicates an expected call of PeekPacketNumber
|
||||
func (mr *MockSentPacketHandlerMockRecorder) PeekPacketNumber() *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PeekPacketNumber", reflect.TypeOf((*MockSentPacketHandler)(nil).PeekPacketNumber))
|
||||
}
|
||||
|
||||
// PopPacketNumber mocks base method
|
||||
func (m *MockSentPacketHandler) PopPacketNumber() protocol.PacketNumber {
|
||||
ret := m.ctrl.Call(m, "PopPacketNumber")
|
||||
ret0, _ := ret[0].(protocol.PacketNumber)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// PopPacketNumber indicates an expected call of PopPacketNumber
|
||||
func (mr *MockSentPacketHandlerMockRecorder) PopPacketNumber() *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PopPacketNumber", reflect.TypeOf((*MockSentPacketHandler)(nil).PopPacketNumber))
|
||||
}
|
||||
|
||||
// ReceivedAck mocks base method
|
||||
func (m *MockSentPacketHandler) ReceivedAck(arg0 *wire.AckFrame, arg1 protocol.PacketNumber, arg2 protocol.EncryptionLevel, arg3 time.Time) error {
|
||||
ret := m.ctrl.Call(m, "ReceivedAck", arg0, arg1, arg2, arg3)
|
||||
|
||||
@@ -59,6 +59,11 @@ func getMaxPacketSize(addr net.Addr) protocol.ByteCount {
|
||||
return maxSize
|
||||
}
|
||||
|
||||
type packetNumberManager interface {
|
||||
PeekPacketNumber() (protocol.PacketNumber, protocol.PacketNumberLen)
|
||||
PopPacketNumber() protocol.PacketNumber
|
||||
}
|
||||
|
||||
type sealingManager interface {
|
||||
GetSealer() (protocol.EncryptionLevel, handshake.Sealer)
|
||||
GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (handshake.Sealer, error)
|
||||
@@ -86,10 +91,9 @@ type packetPacker struct {
|
||||
|
||||
token []byte
|
||||
|
||||
packetNumberGenerator *packetNumberGenerator
|
||||
getPacketNumberLen func(protocol.PacketNumber) protocol.PacketNumberLen
|
||||
framer frameSource
|
||||
acks ackFrameSource
|
||||
pnManager packetNumberManager
|
||||
framer frameSource
|
||||
acks ackFrameSource
|
||||
|
||||
maxPacketSize protocol.ByteCount
|
||||
hasSentPacket bool // has the packetPacker already sent a packet
|
||||
@@ -103,8 +107,7 @@ func newPacketPacker(
|
||||
srcConnID protocol.ConnectionID,
|
||||
initialStream cryptoStream,
|
||||
handshakeStream cryptoStream,
|
||||
initialPacketNumber protocol.PacketNumber,
|
||||
getPacketNumberLen func(protocol.PacketNumber) protocol.PacketNumberLen,
|
||||
packetNumberManager packetNumberManager,
|
||||
remoteAddr net.Addr, // only used for determining the max packet size
|
||||
token []byte,
|
||||
cryptoSetup sealingManager,
|
||||
@@ -114,19 +117,18 @@ func newPacketPacker(
|
||||
version protocol.VersionNumber,
|
||||
) *packetPacker {
|
||||
return &packetPacker{
|
||||
cryptoSetup: cryptoSetup,
|
||||
token: token,
|
||||
destConnID: destConnID,
|
||||
srcConnID: srcConnID,
|
||||
initialStream: initialStream,
|
||||
handshakeStream: handshakeStream,
|
||||
perspective: perspective,
|
||||
version: version,
|
||||
framer: framer,
|
||||
acks: acks,
|
||||
getPacketNumberLen: getPacketNumberLen,
|
||||
packetNumberGenerator: newPacketNumberGenerator(initialPacketNumber, protocol.SkipPacketAveragePeriodLength),
|
||||
maxPacketSize: getMaxPacketSize(remoteAddr),
|
||||
cryptoSetup: cryptoSetup,
|
||||
token: token,
|
||||
destConnID: destConnID,
|
||||
srcConnID: srcConnID,
|
||||
initialStream: initialStream,
|
||||
handshakeStream: handshakeStream,
|
||||
perspective: perspective,
|
||||
version: version,
|
||||
framer: framer,
|
||||
acks: acks,
|
||||
pnManager: packetNumberManager,
|
||||
maxPacketSize: getMaxPacketSize(remoteAddr),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -396,12 +398,10 @@ func (p *packetPacker) composeNextPacket(
|
||||
}
|
||||
|
||||
func (p *packetPacker) getHeader(encLevel protocol.EncryptionLevel) *wire.Header {
|
||||
pnum := p.packetNumberGenerator.Peek()
|
||||
packetNumberLen := p.getPacketNumberLen(pnum)
|
||||
|
||||
pn, pnLen := p.pnManager.PeekPacketNumber()
|
||||
header := &wire.Header{
|
||||
PacketNumber: pnum,
|
||||
PacketNumberLen: packetNumberLen,
|
||||
PacketNumber: pn,
|
||||
PacketNumberLen: pnLen,
|
||||
Version: p.version,
|
||||
DestConnectionID: p.destConnID,
|
||||
}
|
||||
@@ -482,7 +482,7 @@ func (p *packetPacker) writeAndSealPacket(
|
||||
_ = sealer.Seal(raw[payloadStartIndex:payloadStartIndex], raw[payloadStartIndex:], header.PacketNumber, raw[:payloadStartIndex])
|
||||
raw = raw[0 : buffer.Len()+sealer.Overhead()]
|
||||
|
||||
num := p.packetNumberGenerator.Pop()
|
||||
num := p.pnManager.PopPacketNumber()
|
||||
if num != header.PacketNumber {
|
||||
return nil, errors.New("packetPacker BUG: Peeked and Popped packet numbers do not match")
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"github.com/lucas-clemente/quic-go/internal/ackhandler"
|
||||
"github.com/lucas-clemente/quic-go/internal/handshake"
|
||||
"github.com/lucas-clemente/quic-go/internal/mocks"
|
||||
"github.com/lucas-clemente/quic-go/internal/mocks/ackhandler"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
. "github.com/onsi/ginkgo"
|
||||
@@ -25,6 +26,7 @@ var _ = Describe("Packet packer", func() {
|
||||
handshakeStream *MockCryptoStream
|
||||
sealingManager *MockSealingManager
|
||||
sealer *mocks.MockSealer
|
||||
pnManager *mockackhandler.MockSentPacketHandler
|
||||
token []byte
|
||||
)
|
||||
|
||||
@@ -63,6 +65,7 @@ var _ = Describe("Packet packer", func() {
|
||||
framer = NewMockFrameSource(mockCtrl)
|
||||
ackFramer = NewMockAckFrameSource(mockCtrl)
|
||||
sealingManager = NewMockSealingManager(mockCtrl)
|
||||
pnManager = mockackhandler.NewMockSentPacketHandler(mockCtrl)
|
||||
sealer = mocks.NewMockSealer(mockCtrl)
|
||||
sealer.EXPECT().Overhead().Return(7).AnyTimes()
|
||||
sealer.EXPECT().Seal(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(dst, src []byte, pn protocol.PacketNumber, associatedData []byte) []byte {
|
||||
@@ -76,8 +79,7 @@ var _ = Describe("Packet packer", func() {
|
||||
protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8},
|
||||
initialStream,
|
||||
handshakeStream,
|
||||
1,
|
||||
func(protocol.PacketNumber) protocol.PacketNumberLen { return protocol.PacketNumberLen2 },
|
||||
pnManager,
|
||||
&net.TCPAddr{},
|
||||
token, // token
|
||||
sealingManager,
|
||||
@@ -110,12 +112,16 @@ var _ = Describe("Packet packer", func() {
|
||||
|
||||
Context("generating a packet header", func() {
|
||||
It("uses the Long Header format", func() {
|
||||
pnManager.EXPECT().PeekPacketNumber().Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2)
|
||||
h := packer.getHeader(protocol.EncryptionHandshake)
|
||||
Expect(h.IsLongHeader).To(BeTrue())
|
||||
Expect(h.PacketNumber).To(Equal(protocol.PacketNumber(0x42)))
|
||||
Expect(h.PacketNumberLen).To(Equal(protocol.PacketNumberLen2))
|
||||
Expect(h.Version).To(Equal(packer.version))
|
||||
})
|
||||
|
||||
It("sets source and destination connection ID", func() {
|
||||
pnManager.EXPECT().PeekPacketNumber().Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2)
|
||||
srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
|
||||
destConnID := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}
|
||||
packer.srcConnID = srcConnID
|
||||
@@ -126,6 +132,7 @@ var _ = Describe("Packet packer", func() {
|
||||
})
|
||||
|
||||
It("changes the destination connection ID", func() {
|
||||
pnManager.EXPECT().PeekPacketNumber().Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2).Times(2)
|
||||
srcConnID := protocol.ConnectionID{1, 1, 1, 1, 1, 1, 1, 1}
|
||||
packer.srcConnID = srcConnID
|
||||
dest1 := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
|
||||
@@ -141,9 +148,11 @@ var _ = Describe("Packet packer", func() {
|
||||
})
|
||||
|
||||
It("uses the Short Header format for 1-RTT packets", func() {
|
||||
pnManager.EXPECT().PeekPacketNumber().Return(protocol.PacketNumber(0x1337), protocol.PacketNumberLen4)
|
||||
h := packer.getHeader(protocol.Encryption1RTT)
|
||||
Expect(h.IsLongHeader).To(BeFalse())
|
||||
Expect(h.PacketNumberLen).To(BeNumerically(">", 0))
|
||||
Expect(h.PacketNumber).To(Equal(protocol.PacketNumber(0x1337)))
|
||||
Expect(h.PacketNumberLen).To(Equal(protocol.PacketNumberLen4))
|
||||
})
|
||||
})
|
||||
|
||||
@@ -154,6 +163,8 @@ var _ = Describe("Packet packer", func() {
|
||||
})
|
||||
|
||||
It("returns nil when no packet is queued", func() {
|
||||
pnManager.EXPECT().PeekPacketNumber().Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2)
|
||||
// don't expect any calls to PopPacketNumber
|
||||
sealingManager.EXPECT().GetSealer().Return(protocol.Encryption1RTT, sealer)
|
||||
ackFramer.EXPECT().GetAckFrame()
|
||||
framer.EXPECT().AppendControlFrames(nil, gomock.Any())
|
||||
@@ -164,6 +175,8 @@ var _ = Describe("Packet packer", func() {
|
||||
})
|
||||
|
||||
It("packs single packets", func() {
|
||||
pnManager.EXPECT().PeekPacketNumber().Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2)
|
||||
pnManager.EXPECT().PopPacketNumber().Return(protocol.PacketNumber(0x42))
|
||||
sealingManager.EXPECT().GetSealer().Return(protocol.Encryption1RTT, sealer)
|
||||
ackFramer.EXPECT().GetAckFrame()
|
||||
expectAppendControlFrames()
|
||||
@@ -182,6 +195,8 @@ var _ = Describe("Packet packer", func() {
|
||||
})
|
||||
|
||||
It("stores the encryption level a packet was sealed with", func() {
|
||||
pnManager.EXPECT().PeekPacketNumber().Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2)
|
||||
pnManager.EXPECT().PopPacketNumber().Return(protocol.PacketNumber(0x42))
|
||||
sealingManager.EXPECT().GetSealer().Return(protocol.Encryption1RTT, sealer)
|
||||
ackFramer.EXPECT().GetAckFrame()
|
||||
expectAppendControlFrames()
|
||||
@@ -195,6 +210,8 @@ var _ = Describe("Packet packer", func() {
|
||||
})
|
||||
|
||||
It("packs a single ACK", func() {
|
||||
pnManager.EXPECT().PeekPacketNumber().Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2)
|
||||
pnManager.EXPECT().PopPacketNumber().Return(protocol.PacketNumber(0x42))
|
||||
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Largest: 42, Smallest: 1}}}
|
||||
ackFramer.EXPECT().GetAckFrame().Return(ack)
|
||||
sealingManager.EXPECT().GetSealer().Return(protocol.Encryption1RTT, sealer)
|
||||
@@ -207,6 +224,9 @@ var _ = Describe("Packet packer", func() {
|
||||
})
|
||||
|
||||
It("packs a CONNECTION_CLOSE", func() {
|
||||
pnManager.EXPECT().PeekPacketNumber().Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2)
|
||||
pnManager.EXPECT().PopPacketNumber().Return(protocol.PacketNumber(0x42))
|
||||
// expect no framer.PopStreamFrames
|
||||
ccf := wire.ConnectionCloseFrame{
|
||||
ErrorCode: 0x1337,
|
||||
ReasonPhrase: "foobar",
|
||||
@@ -218,19 +238,9 @@ var _ = Describe("Packet packer", func() {
|
||||
Expect(p.frames[0]).To(Equal(&ccf))
|
||||
})
|
||||
|
||||
It("doesn't send any other frames when sending a CONNECTION_CLOSE", func() {
|
||||
// expect no framer.PopStreamFrames
|
||||
ccf := &wire.ConnectionCloseFrame{
|
||||
ErrorCode: 0x1337,
|
||||
ReasonPhrase: "foobar",
|
||||
}
|
||||
sealingManager.EXPECT().GetSealer().Return(protocol.Encryption1RTT, sealer)
|
||||
p, err := packer.PackConnectionClose(ccf)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(p.frames).To(Equal([]wire.Frame{ccf}))
|
||||
})
|
||||
|
||||
It("packs control frames", func() {
|
||||
pnManager.EXPECT().PeekPacketNumber().Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2)
|
||||
pnManager.EXPECT().PopPacketNumber().Return(protocol.PacketNumber(0x42))
|
||||
sealingManager.EXPECT().GetSealer().Return(protocol.Encryption1RTT, sealer)
|
||||
ackFramer.EXPECT().GetAckFrame()
|
||||
frames := []wire.Frame{&wire.RstStreamFrame{}, &wire.MaxDataFrame{}}
|
||||
@@ -243,23 +253,8 @@ var _ = Describe("Packet packer", func() {
|
||||
Expect(p.raw).NotTo(BeEmpty())
|
||||
})
|
||||
|
||||
It("increases the packet number", func() {
|
||||
sealingManager.EXPECT().GetSealer().Return(protocol.Encryption1RTT, sealer).Times(2)
|
||||
ackFramer.EXPECT().GetAckFrame().Times(2)
|
||||
expectAppendControlFrames()
|
||||
expectAppendStreamFrames(&wire.StreamFrame{Data: []byte("foobar")})
|
||||
expectAppendControlFrames()
|
||||
expectAppendStreamFrames(&wire.StreamFrame{Data: []byte("raboof")})
|
||||
p1, err := packer.PackPacket()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(p1).ToNot(BeNil())
|
||||
p2, err := packer.PackPacket()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(p2).ToNot(BeNil())
|
||||
Expect(p2.header.PacketNumber).To(BeNumerically(">", p1.header.PacketNumber))
|
||||
})
|
||||
|
||||
It("accounts for the space consumed by control frames", func() {
|
||||
pnManager.EXPECT().PeekPacketNumber().Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2)
|
||||
sealingManager.EXPECT().GetSealer().Return(protocol.Encryption1RTT, sealer)
|
||||
ackFramer.EXPECT().GetAckFrame()
|
||||
var maxSize protocol.ByteCount
|
||||
@@ -277,25 +272,6 @@ var _ = Describe("Packet packer", func() {
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("only increases the packet number when there is an actual packet to send", func() {
|
||||
sealingManager.EXPECT().GetSealer().Return(protocol.Encryption1RTT, sealer).Times(2)
|
||||
ackFramer.EXPECT().GetAckFrame().Times(2)
|
||||
expectAppendStreamFrames()
|
||||
expectAppendControlFrames()
|
||||
packer.packetNumberGenerator.nextToSkip = 1000
|
||||
p, err := packer.PackPacket()
|
||||
Expect(p).To(BeNil())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(packer.packetNumberGenerator.Peek()).To(Equal(protocol.PacketNumber(1)))
|
||||
expectAppendControlFrames()
|
||||
expectAppendStreamFrames(&wire.StreamFrame{Data: []byte("foobar")})
|
||||
p, err = packer.PackPacket()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(p).ToNot(BeNil())
|
||||
Expect(p.header.PacketNumber).To(Equal(protocol.PacketNumber(1)))
|
||||
Expect(packer.packetNumberGenerator.Peek()).To(Equal(protocol.PacketNumber(2)))
|
||||
})
|
||||
|
||||
Context("packing ACK packets", func() {
|
||||
It("doesn't pack a packet if there's no ACK to send", func() {
|
||||
ackFramer.EXPECT().GetAckFrame()
|
||||
@@ -305,6 +281,8 @@ var _ = Describe("Packet packer", func() {
|
||||
})
|
||||
|
||||
It("packs ACK packets", func() {
|
||||
pnManager.EXPECT().PeekPacketNumber().Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2)
|
||||
pnManager.EXPECT().PopPacketNumber().Return(protocol.PacketNumber(0x42))
|
||||
sealingManager.EXPECT().GetSealer().Return(protocol.Encryption1RTT, sealer)
|
||||
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}}
|
||||
ackFramer.EXPECT().GetAckFrame().Return(ack)
|
||||
@@ -317,6 +295,8 @@ var _ = Describe("Packet packer", func() {
|
||||
Context("making ACK packets retransmittable", func() {
|
||||
sendMaxNumNonRetransmittableAcks := func() {
|
||||
for i := 0; i < protocol.MaxNonRetransmittableAcks; i++ {
|
||||
pnManager.EXPECT().PeekPacketNumber().Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2)
|
||||
pnManager.EXPECT().PopPacketNumber().Return(protocol.PacketNumber(0x42))
|
||||
sealingManager.EXPECT().GetSealer().Return(protocol.Encryption1RTT, sealer)
|
||||
ackFramer.EXPECT().GetAckFrame().Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}})
|
||||
expectAppendControlFrames()
|
||||
@@ -330,6 +310,8 @@ var _ = Describe("Packet packer", func() {
|
||||
|
||||
It("adds a PING frame when it's supposed to send a retransmittable packet", func() {
|
||||
sendMaxNumNonRetransmittableAcks()
|
||||
pnManager.EXPECT().PeekPacketNumber().Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2)
|
||||
pnManager.EXPECT().PopPacketNumber().Return(protocol.PacketNumber(0x42))
|
||||
sealingManager.EXPECT().GetSealer().Return(protocol.Encryption1RTT, sealer)
|
||||
ackFramer.EXPECT().GetAckFrame().Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}})
|
||||
expectAppendControlFrames()
|
||||
@@ -339,6 +321,8 @@ var _ = Describe("Packet packer", func() {
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(p.frames).To(ContainElement(&wire.PingFrame{}))
|
||||
// make sure the next packet doesn't contain another PING
|
||||
pnManager.EXPECT().PeekPacketNumber().Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2)
|
||||
pnManager.EXPECT().PopPacketNumber().Return(protocol.PacketNumber(0x42))
|
||||
sealingManager.EXPECT().GetSealer().Return(protocol.Encryption1RTT, sealer)
|
||||
ackFramer.EXPECT().GetAckFrame().Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}})
|
||||
expectAppendControlFrames()
|
||||
@@ -352,6 +336,7 @@ var _ = Describe("Packet packer", func() {
|
||||
It("waits until there's something to send before adding a PING frame", func() {
|
||||
sendMaxNumNonRetransmittableAcks()
|
||||
// nothing to send
|
||||
pnManager.EXPECT().PeekPacketNumber().Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2)
|
||||
sealingManager.EXPECT().GetSealer().Return(protocol.Encryption1RTT, sealer)
|
||||
expectAppendControlFrames()
|
||||
expectAppendStreamFrames()
|
||||
@@ -362,6 +347,8 @@ var _ = Describe("Packet packer", func() {
|
||||
// now add some frame to send
|
||||
expectAppendControlFrames()
|
||||
expectAppendStreamFrames()
|
||||
pnManager.EXPECT().PeekPacketNumber().Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2)
|
||||
pnManager.EXPECT().PopPacketNumber().Return(protocol.PacketNumber(0x42))
|
||||
sealingManager.EXPECT().GetSealer().Return(protocol.Encryption1RTT, sealer)
|
||||
ackFramer.EXPECT().GetAckFrame().Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}})
|
||||
p, err = packer.PackPacket()
|
||||
@@ -372,19 +359,23 @@ var _ = Describe("Packet packer", func() {
|
||||
|
||||
It("doesn't send a PING if it already sent another retransmittable frame", func() {
|
||||
sendMaxNumNonRetransmittableAcks()
|
||||
pnManager.EXPECT().PeekPacketNumber().Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2)
|
||||
pnManager.EXPECT().PopPacketNumber().Return(protocol.PacketNumber(0x42))
|
||||
sealingManager.EXPECT().GetSealer().Return(protocol.Encryption1RTT, sealer)
|
||||
ackFramer.EXPECT().GetAckFrame()
|
||||
expectAppendStreamFrames()
|
||||
expectAppendControlFrames(&wire.MaxDataFrame{})
|
||||
p, err := packer.PackPacket()
|
||||
Expect(p).ToNot(BeNil())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(p).ToNot(BeNil())
|
||||
Expect(p.frames).ToNot(ContainElement(&wire.PingFrame{}))
|
||||
})
|
||||
})
|
||||
|
||||
Context("STREAM frame handling", func() {
|
||||
It("does not split a STREAM frame with maximum size", func() {
|
||||
pnManager.EXPECT().PeekPacketNumber().Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2)
|
||||
pnManager.EXPECT().PopPacketNumber().Return(protocol.PacketNumber(0x42))
|
||||
ackFramer.EXPECT().GetAckFrame()
|
||||
sealingManager.EXPECT().GetSealer().Return(protocol.Encryption1RTT, sealer)
|
||||
expectAppendControlFrames()
|
||||
@@ -421,6 +412,8 @@ var _ = Describe("Packet packer", func() {
|
||||
Data: []byte("frame 3"),
|
||||
DataLenPresent: true,
|
||||
}
|
||||
pnManager.EXPECT().PeekPacketNumber().Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2)
|
||||
pnManager.EXPECT().PopPacketNumber().Return(protocol.PacketNumber(0x42))
|
||||
sealingManager.EXPECT().GetSealer().Return(protocol.Encryption1RTT, sealer)
|
||||
ackFramer.EXPECT().GetAckFrame()
|
||||
expectAppendControlFrames()
|
||||
@@ -438,6 +431,7 @@ var _ = Describe("Packet packer", func() {
|
||||
})
|
||||
|
||||
It("doesn't send unencrypted stream data on a data stream", func() {
|
||||
pnManager.EXPECT().PeekPacketNumber().Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2)
|
||||
sealingManager.EXPECT().GetSealer().Return(protocol.EncryptionInitial, sealer)
|
||||
ackFramer.EXPECT().GetAckFrame()
|
||||
expectAppendControlFrames()
|
||||
@@ -450,7 +444,8 @@ var _ = Describe("Packet packer", func() {
|
||||
|
||||
Context("retransmissions", func() {
|
||||
It("retransmits a small packet", func() {
|
||||
packer.packetNumberGenerator.next = 10
|
||||
pnManager.EXPECT().PeekPacketNumber().Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2)
|
||||
pnManager.EXPECT().PopPacketNumber().Return(protocol.PacketNumber(0x42))
|
||||
sealingManager.EXPECT().GetSealerWithEncryptionLevel(protocol.Encryption1RTT).Return(sealer, nil)
|
||||
frames := []wire.Frame{
|
||||
&wire.MaxDataFrame{ByteOffset: 0x1234},
|
||||
@@ -468,6 +463,8 @@ var _ = Describe("Packet packer", func() {
|
||||
})
|
||||
|
||||
It("packs two packets for retransmission if the original packet contained many control frames", func() {
|
||||
pnManager.EXPECT().PeekPacketNumber().Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2).Times(2)
|
||||
pnManager.EXPECT().PopPacketNumber().Return(protocol.PacketNumber(0x42)).Times(2)
|
||||
sealingManager.EXPECT().GetSealerWithEncryptionLevel(protocol.Encryption1RTT).Return(sealer, nil)
|
||||
var frames []wire.Frame
|
||||
var totalLen protocol.ByteCount
|
||||
@@ -480,7 +477,6 @@ var _ = Describe("Packet packer", func() {
|
||||
frames = append(frames, f)
|
||||
totalLen += f.Length(packer.version)
|
||||
}
|
||||
packer.packetNumberGenerator.next = 10
|
||||
packets, err := packer.PackRetransmission(&ackhandler.Packet{
|
||||
EncryptionLevel: protocol.Encryption1RTT,
|
||||
Frames: frames,
|
||||
@@ -495,6 +491,8 @@ var _ = Describe("Packet packer", func() {
|
||||
})
|
||||
|
||||
It("splits a STREAM frame that doesn't fit", func() {
|
||||
pnManager.EXPECT().PeekPacketNumber().Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2).Times(2)
|
||||
pnManager.EXPECT().PopPacketNumber().Return(protocol.PacketNumber(0x42)).Times(2)
|
||||
sealingManager.EXPECT().GetSealerWithEncryptionLevel(protocol.Encryption1RTT).Return(sealer, nil)
|
||||
packets, err := packer.PackRetransmission(&ackhandler.Packet{
|
||||
EncryptionLevel: protocol.Encryption1RTT,
|
||||
@@ -521,6 +519,8 @@ var _ = Describe("Packet packer", func() {
|
||||
})
|
||||
|
||||
It("splits STREAM frames, if necessary", func() {
|
||||
pnManager.EXPECT().PeekPacketNumber().Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2).AnyTimes()
|
||||
pnManager.EXPECT().PopPacketNumber().Return(protocol.PacketNumber(0x42)).AnyTimes()
|
||||
for i := 0; i < 100; i++ {
|
||||
sealingManager.EXPECT().GetSealerWithEncryptionLevel(protocol.Encryption1RTT).Return(sealer, nil).MaxTimes(2)
|
||||
sf1 := &wire.StreamFrame{
|
||||
@@ -556,6 +556,8 @@ var _ = Describe("Packet packer", func() {
|
||||
})
|
||||
|
||||
It("packs two packets for retransmission if the original packet contained many STREAM frames", func() {
|
||||
pnManager.EXPECT().PeekPacketNumber().Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2).Times(2)
|
||||
pnManager.EXPECT().PopPacketNumber().Return(protocol.PacketNumber(0x42)).Times(2)
|
||||
sealingManager.EXPECT().GetSealerWithEncryptionLevel(protocol.Encryption1RTT).Return(sealer, nil)
|
||||
var frames []wire.Frame
|
||||
var totalLen protocol.ByteCount
|
||||
@@ -583,6 +585,8 @@ var _ = Describe("Packet packer", func() {
|
||||
})
|
||||
|
||||
It("correctly sets the DataLenPresent on STREAM frames", func() {
|
||||
pnManager.EXPECT().PeekPacketNumber().Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2)
|
||||
pnManager.EXPECT().PopPacketNumber().Return(protocol.PacketNumber(0x42))
|
||||
sealingManager.EXPECT().GetSealerWithEncryptionLevel(protocol.Encryption1RTT).Return(sealer, nil)
|
||||
frames := []wire.Frame{
|
||||
&wire.StreamFrame{StreamID: 4, Data: []byte("foobar"), DataLenPresent: true},
|
||||
@@ -609,6 +613,7 @@ var _ = Describe("Packet packer", func() {
|
||||
|
||||
Context("max packet size", func() {
|
||||
It("sets the maximum packet size", func() {
|
||||
pnManager.EXPECT().PeekPacketNumber().Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2).Times(2)
|
||||
sealingManager.EXPECT().GetSealer().Return(protocol.Encryption1RTT, sealer).Times(2)
|
||||
ackFramer.EXPECT().GetAckFrame().Times(2)
|
||||
var initialMaxPacketSize protocol.ByteCount
|
||||
@@ -633,6 +638,7 @@ var _ = Describe("Packet packer", func() {
|
||||
})
|
||||
|
||||
It("doesn't increase the max packet size", func() {
|
||||
pnManager.EXPECT().PeekPacketNumber().Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2).Times(2)
|
||||
sealingManager.EXPECT().GetSealer().Return(protocol.Encryption1RTT, sealer).Times(2)
|
||||
ackFramer.EXPECT().GetAckFrame().Times(2)
|
||||
var initialMaxPacketSize protocol.ByteCount
|
||||
@@ -660,6 +666,8 @@ var _ = Describe("Packet packer", func() {
|
||||
|
||||
Context("packing crypto packets", func() {
|
||||
It("sets the payload length", func() {
|
||||
pnManager.EXPECT().PeekPacketNumber().Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2)
|
||||
pnManager.EXPECT().PopPacketNumber().Return(protocol.PacketNumber(0x42))
|
||||
f := &wire.CryptoFrame{
|
||||
Offset: 0x1337,
|
||||
Data: []byte("foobar"),
|
||||
@@ -675,6 +683,8 @@ var _ = Describe("Packet packer", func() {
|
||||
|
||||
It("packs a maximum size crypto packet", func() {
|
||||
var f *wire.CryptoFrame
|
||||
pnManager.EXPECT().PeekPacketNumber().Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2)
|
||||
pnManager.EXPECT().PopPacketNumber().Return(protocol.PacketNumber(0x42))
|
||||
sealingManager.EXPECT().GetSealerWithEncryptionLevel(protocol.EncryptionHandshake).Return(sealer, nil)
|
||||
ackFramer.EXPECT().GetAckFrame()
|
||||
initialStream.EXPECT().HasData()
|
||||
@@ -696,6 +706,8 @@ var _ = Describe("Packet packer", func() {
|
||||
|
||||
It("pads Initial packets to the required minimum packet size", func() {
|
||||
f := &wire.CryptoFrame{Data: []byte("foobar")}
|
||||
pnManager.EXPECT().PeekPacketNumber().Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2)
|
||||
pnManager.EXPECT().PopPacketNumber().Return(protocol.PacketNumber(0x42))
|
||||
sealingManager.EXPECT().GetSealerWithEncryptionLevel(protocol.EncryptionInitial).Return(sealer, nil)
|
||||
ackFramer.EXPECT().GetAckFrame()
|
||||
initialStream.EXPECT().HasData().Return(true)
|
||||
@@ -712,6 +724,8 @@ var _ = Describe("Packet packer", func() {
|
||||
})
|
||||
|
||||
It("sets the correct payload length for an Initial packet", func() {
|
||||
pnManager.EXPECT().PeekPacketNumber().Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2)
|
||||
pnManager.EXPECT().PopPacketNumber().Return(protocol.PacketNumber(0x42))
|
||||
sealingManager.EXPECT().GetSealerWithEncryptionLevel(protocol.EncryptionInitial).Return(sealer, nil)
|
||||
ackFramer.EXPECT().GetAckFrame()
|
||||
initialStream.EXPECT().HasData().Return(true)
|
||||
@@ -728,6 +742,8 @@ var _ = Describe("Packet packer", func() {
|
||||
It("adds an ACK frame", func() {
|
||||
f := &wire.CryptoFrame{Data: []byte("foobar")}
|
||||
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 42, Largest: 1337}}}
|
||||
pnManager.EXPECT().PeekPacketNumber().Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2)
|
||||
pnManager.EXPECT().PopPacketNumber().Return(protocol.PacketNumber(0x42))
|
||||
sealingManager.EXPECT().GetSealerWithEncryptionLevel(protocol.EncryptionInitial).Return(sealer, nil)
|
||||
ackFramer.EXPECT().GetAckFrame().Return(ack)
|
||||
initialStream.EXPECT().HasData().Return(true)
|
||||
@@ -747,6 +763,8 @@ var _ = Describe("Packet packer", func() {
|
||||
sf := &wire.StreamFrame{Data: []byte("foobar")}
|
||||
|
||||
It("packs a retransmission with the right encryption level", func() {
|
||||
pnManager.EXPECT().PeekPacketNumber().Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2)
|
||||
pnManager.EXPECT().PopPacketNumber().Return(protocol.PacketNumber(0x42))
|
||||
sealingManager.EXPECT().GetSealerWithEncryptionLevel(protocol.EncryptionInitial).Return(sealer, nil)
|
||||
packet := &ackhandler.Packet{
|
||||
PacketType: protocol.PacketTypeHandshake,
|
||||
@@ -763,6 +781,7 @@ var _ = Describe("Packet packer", func() {
|
||||
|
||||
// this should never happen, since non forward-secure packets are limited to a size smaller than MaxPacketSize, such that it is always possible to retransmit them without splitting the StreamFrame
|
||||
It("refuses to send a packet larger than MaxPacketSize", func() {
|
||||
pnManager.EXPECT().PeekPacketNumber().Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2)
|
||||
sealingManager.EXPECT().GetSealerWithEncryptionLevel(gomock.Any()).Return(sealer, nil)
|
||||
packet := &ackhandler.Packet{
|
||||
EncryptionLevel: protocol.EncryptionHandshake,
|
||||
@@ -779,6 +798,8 @@ var _ = Describe("Packet packer", func() {
|
||||
})
|
||||
|
||||
It("packs a retransmission for an Initial packet", func() {
|
||||
pnManager.EXPECT().PeekPacketNumber().Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2)
|
||||
pnManager.EXPECT().PopPacketNumber().Return(protocol.PacketNumber(0x42))
|
||||
sealingManager.EXPECT().GetSealerWithEncryptionLevel(protocol.EncryptionInitial).Return(sealer, nil)
|
||||
packer.perspective = protocol.PerspectiveClient
|
||||
packet := &ackhandler.Packet{
|
||||
|
||||
@@ -78,7 +78,7 @@ type server struct {
|
||||
sessionHandler packetHandlerManager
|
||||
|
||||
// set as a member, so they can be set in the tests
|
||||
newSession func(connection, sessionRunner, protocol.ConnectionID /* original connection ID */, protocol.ConnectionID /* destination connection ID */, protocol.ConnectionID /* source connection ID */, protocol.PacketNumber, *Config, *tls.Config, *handshake.TransportParameters, utils.Logger, protocol.VersionNumber) (quicSession, error)
|
||||
newSession func(connection, sessionRunner, protocol.ConnectionID /* original connection ID */, protocol.ConnectionID /* destination connection ID */, protocol.ConnectionID /* source connection ID */, *Config, *tls.Config, *handshake.TransportParameters, utils.Logger, protocol.VersionNumber) (quicSession, error)
|
||||
|
||||
serverError error
|
||||
errorChan chan struct{}
|
||||
@@ -392,7 +392,6 @@ func (s *server) createNewSession(
|
||||
origConnID,
|
||||
destConnID,
|
||||
srcConnID,
|
||||
1,
|
||||
s.config,
|
||||
s.tlsConf,
|
||||
params,
|
||||
|
||||
@@ -262,7 +262,6 @@ var _ = Describe("Server", func() {
|
||||
origConnID protocol.ConnectionID,
|
||||
destConnID protocol.ConnectionID,
|
||||
srcConnID protocol.ConnectionID,
|
||||
_ protocol.PacketNumber,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
_ *handshake.TransportParameters,
|
||||
@@ -346,7 +345,6 @@ var _ = Describe("Server", func() {
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.PacketNumber,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
_ *handshake.TransportParameters,
|
||||
|
||||
@@ -138,7 +138,6 @@ var newSession = func(
|
||||
origConnID protocol.ConnectionID,
|
||||
destConnID protocol.ConnectionID,
|
||||
srcConnID protocol.ConnectionID,
|
||||
initialPacketNumber protocol.PacketNumber,
|
||||
conf *Config,
|
||||
tlsConf *tls.Config,
|
||||
params *handshake.TransportParameters,
|
||||
@@ -184,8 +183,7 @@ var newSession = func(
|
||||
s.srcConnID,
|
||||
initialStream,
|
||||
handshakeStream,
|
||||
initialPacketNumber,
|
||||
s.sentPacketHandler.GetPacketNumberLen,
|
||||
s.sentPacketHandler,
|
||||
s.RemoteAddr(),
|
||||
nil, // no token
|
||||
cs,
|
||||
@@ -214,7 +212,6 @@ var newClientSession = func(
|
||||
tlsConf *tls.Config,
|
||||
params *handshake.TransportParameters,
|
||||
initialVersion protocol.VersionNumber,
|
||||
initialPacketNumber protocol.PacketNumber,
|
||||
logger utils.Logger,
|
||||
v protocol.VersionNumber,
|
||||
) (quicSession, error) {
|
||||
@@ -259,8 +256,7 @@ var newClientSession = func(
|
||||
s.srcConnID,
|
||||
initialStream,
|
||||
handshakeStream,
|
||||
initialPacketNumber,
|
||||
s.sentPacketHandler.GetPacketNumberLen,
|
||||
s.sentPacketHandler,
|
||||
s.RemoteAddr(),
|
||||
token,
|
||||
cs,
|
||||
|
||||
@@ -84,7 +84,6 @@ var _ = Describe("Session", func() {
|
||||
protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
|
||||
protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1},
|
||||
protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8},
|
||||
1,
|
||||
populateServerConfig(&Config{}),
|
||||
nil, // tls.Config
|
||||
nil, // handshake.TransportParameters,
|
||||
@@ -611,7 +610,6 @@ var _ = Describe("Session", func() {
|
||||
newPacket := getPacket(234)
|
||||
sess.windowUpdateQueue.callback(&wire.MaxDataFrame{})
|
||||
sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
|
||||
sph.EXPECT().GetPacketNumberLen(gomock.Any()).Return(protocol.PacketNumberLen2).AnyTimes()
|
||||
sph.EXPECT().DequeuePacketForRetransmission().Return(packetToRetransmit)
|
||||
sph.EXPECT().SendMode().Return(ackhandler.SendRetransmission)
|
||||
sph.EXPECT().SendMode().Return(ackhandler.SendAny)
|
||||
@@ -665,7 +663,6 @@ var _ = Describe("Session", func() {
|
||||
retransmittedPacket := getPacket(123)
|
||||
sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
|
||||
sph.EXPECT().TimeUntilSend()
|
||||
sph.EXPECT().GetPacketNumberLen(gomock.Any()).Return(protocol.PacketNumberLen2).AnyTimes()
|
||||
sph.EXPECT().SendMode().Return(ackhandler.SendTLP)
|
||||
sph.EXPECT().ShouldSendNumPackets().Return(1)
|
||||
sph.EXPECT().DequeueProbePacket().Return(packetToRetransmit, nil)
|
||||
@@ -692,7 +689,6 @@ var _ = Describe("Session", func() {
|
||||
BeforeEach(func() {
|
||||
sph = mockackhandler.NewMockSentPacketHandler(mockCtrl)
|
||||
sph.EXPECT().GetAlarmTimeout().AnyTimes()
|
||||
sph.EXPECT().GetPacketNumberLen(gomock.Any()).Return(protocol.PacketNumberLen2).AnyTimes()
|
||||
sph.EXPECT().DequeuePacketForRetransmission().AnyTimes()
|
||||
sess.sentPacketHandler = sph
|
||||
streamManager.EXPECT().CloseWithError(gomock.Any())
|
||||
@@ -836,7 +832,6 @@ var _ = Describe("Session", func() {
|
||||
sph.EXPECT().TimeUntilSend().AnyTimes()
|
||||
sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes()
|
||||
sph.EXPECT().ShouldSendNumPackets().AnyTimes().Return(1)
|
||||
sph.EXPECT().GetPacketNumberLen(gomock.Any()).Return(protocol.PacketNumberLen2).AnyTimes()
|
||||
sph.EXPECT().SentPacket(gomock.Any())
|
||||
sess.sentPacketHandler = sph
|
||||
packer.EXPECT().PackPacket().Return(getPacket(1), nil)
|
||||
@@ -1297,7 +1292,6 @@ var _ = Describe("Client Session", func() {
|
||||
nil, // tls.Config
|
||||
nil, // transport parameters
|
||||
protocol.VersionWhatever,
|
||||
1,
|
||||
utils.DefaultLogger,
|
||||
protocol.VersionWhatever,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user