forked from quic-go/quic-go
move packet number calculations to protocol package
This commit is contained in:
@@ -1,44 +0,0 @@
|
||||
package quic
|
||||
|
||||
import "github.com/lucas-clemente/quic-go/protocol"
|
||||
|
||||
func inferPacketNumber(packetNumberLength protocol.PacketNumberLen, lastPacketNumber protocol.PacketNumber, wirePacketNumber protocol.PacketNumber) protocol.PacketNumber {
|
||||
epochDelta := protocol.PacketNumber(1) << (uint8(packetNumberLength) * 8)
|
||||
epoch := lastPacketNumber & ^(epochDelta - 1)
|
||||
prevEpochBegin := epoch - epochDelta
|
||||
nextEpochBegin := epoch + epochDelta
|
||||
return closestTo(
|
||||
lastPacketNumber+1,
|
||||
epoch+wirePacketNumber,
|
||||
closestTo(lastPacketNumber+1, prevEpochBegin+wirePacketNumber, nextEpochBegin+wirePacketNumber),
|
||||
)
|
||||
}
|
||||
|
||||
func closestTo(target, a, b protocol.PacketNumber) protocol.PacketNumber {
|
||||
if delta(target, a) < delta(target, b) {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func delta(a, b protocol.PacketNumber) protocol.PacketNumber {
|
||||
if a < b {
|
||||
return b - a
|
||||
}
|
||||
return a - b
|
||||
}
|
||||
|
||||
func getPacketNumberLength(packetNumber protocol.PacketNumber, highestAckedPacketNumber protocol.PacketNumber) protocol.PacketNumberLen {
|
||||
diff := uint64(packetNumber - highestAckedPacketNumber)
|
||||
if diff < (2 << (uint8(protocol.PacketNumberLen1)*8 - 2)) {
|
||||
return protocol.PacketNumberLen1
|
||||
}
|
||||
if diff < (2 << (uint8(protocol.PacketNumberLen2)*8 - 2)) {
|
||||
return protocol.PacketNumberLen2
|
||||
}
|
||||
if diff < (2 << (uint8(protocol.PacketNumberLen4)*8 - 2)) {
|
||||
return protocol.PacketNumberLen4
|
||||
}
|
||||
// we do not check if there are less than 2^46 packets in flight, since flow control and congestion control will limit this number *a lot* sooner
|
||||
return protocol.PacketNumberLen6
|
||||
}
|
||||
@@ -53,7 +53,7 @@ func (p *packetPacker) PackPacket(stopWaitingFrame *frames.StopWaitingFrame, con
|
||||
1,
|
||||
))
|
||||
|
||||
packetNumberLen := getPacketNumberLength(currentPacketNumber, p.sentPacketHandler.GetLargestObserved())
|
||||
packetNumberLen := protocol.GetPacketNumberLengthForPublicHeader(currentPacketNumber, p.sentPacketHandler.GetLargestObserved())
|
||||
responsePublicHeader := &PublicHeader{
|
||||
ConnectionID: p.connectionID,
|
||||
PacketNumber: currentPacketNumber,
|
||||
|
||||
44
protocol/packet_number.go
Normal file
44
protocol/packet_number.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package protocol
|
||||
|
||||
// InferPacketNumber calculates the packet number based on the received packet number, its length and the last seen packet number
|
||||
func InferPacketNumber(packetNumberLength PacketNumberLen, lastPacketNumber PacketNumber, wirePacketNumber PacketNumber) PacketNumber {
|
||||
epochDelta := PacketNumber(1) << (uint8(packetNumberLength) * 8)
|
||||
epoch := lastPacketNumber & ^(epochDelta - 1)
|
||||
prevEpochBegin := epoch - epochDelta
|
||||
nextEpochBegin := epoch + epochDelta
|
||||
return closestTo(
|
||||
lastPacketNumber+1,
|
||||
epoch+wirePacketNumber,
|
||||
closestTo(lastPacketNumber+1, prevEpochBegin+wirePacketNumber, nextEpochBegin+wirePacketNumber),
|
||||
)
|
||||
}
|
||||
|
||||
func closestTo(target, a, b PacketNumber) PacketNumber {
|
||||
if delta(target, a) < delta(target, b) {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func delta(a, b PacketNumber) PacketNumber {
|
||||
if a < b {
|
||||
return b - a
|
||||
}
|
||||
return a - b
|
||||
}
|
||||
|
||||
// GetPacketNumberLengthForPublicHeader gets the length of the packet number for the public header
|
||||
func GetPacketNumberLengthForPublicHeader(packetNumber PacketNumber, highestAckedPacketNumber PacketNumber) PacketNumberLen {
|
||||
diff := uint64(packetNumber - highestAckedPacketNumber)
|
||||
if diff < (2 << (uint8(PacketNumberLen1)*8 - 2)) {
|
||||
return PacketNumberLen1
|
||||
}
|
||||
if diff < (2 << (uint8(PacketNumberLen2)*8 - 2)) {
|
||||
return PacketNumberLen2
|
||||
}
|
||||
if diff < (2 << (uint8(PacketNumberLen4)*8 - 2)) {
|
||||
return PacketNumberLen4
|
||||
}
|
||||
// we do not check if there are less than 2^46 packets in flight, since flow control and congestion control will limit this number *a lot* sooner
|
||||
return PacketNumberLen6
|
||||
}
|
||||
@@ -1,10 +1,9 @@
|
||||
package quic
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
@@ -12,13 +11,13 @@ import (
|
||||
// Tests taken and extended from chrome
|
||||
var _ = Describe("packet number calculation", func() {
|
||||
Context("infering a packet number", func() {
|
||||
check := func(length protocol.PacketNumberLen, expected, last uint64) {
|
||||
check := func(length PacketNumberLen, expected, last uint64) {
|
||||
epoch := uint64(1) << (length * 8)
|
||||
epochMask := epoch - 1
|
||||
wirePacketNumber := expected & epochMask
|
||||
Expect(inferPacketNumber(length, protocol.PacketNumber(last), protocol.PacketNumber(wirePacketNumber))).To(Equal(protocol.PacketNumber(expected)))
|
||||
Expect(InferPacketNumber(length, PacketNumber(last), PacketNumber(wirePacketNumber))).To(Equal(PacketNumber(expected)))
|
||||
}
|
||||
for _, length := range []protocol.PacketNumberLen{protocol.PacketNumberLen1, protocol.PacketNumberLen2, protocol.PacketNumberLen4, protocol.PacketNumberLen6} {
|
||||
for _, length := range []PacketNumberLen{PacketNumberLen1, PacketNumberLen2, PacketNumberLen4, PacketNumberLen6} {
|
||||
Context(fmt.Sprintf("with %d bytes", length), func() {
|
||||
epoch := uint64(1) << (length * 8)
|
||||
epochMask := epoch - 1
|
||||
@@ -129,42 +128,42 @@ var _ = Describe("packet number calculation", func() {
|
||||
|
||||
Context("shortening a packet number", func() {
|
||||
It("sends out low packet numbers as 1 byte", func() {
|
||||
length := getPacketNumberLength(4, 2)
|
||||
Expect(length).To(Equal(protocol.PacketNumberLen1))
|
||||
length := GetPacketNumberLengthForPublicHeader(4, 2)
|
||||
Expect(length).To(Equal(PacketNumberLen1))
|
||||
})
|
||||
|
||||
It("sends out high packet numbers as 1 byte, if all ACKs are received", func() {
|
||||
length := getPacketNumberLength(0xDEADBEEF, 0xDEADBEEF-1)
|
||||
Expect(length).To(Equal(protocol.PacketNumberLen1))
|
||||
length := GetPacketNumberLengthForPublicHeader(0xDEADBEEF, 0xDEADBEEF-1)
|
||||
Expect(length).To(Equal(PacketNumberLen1))
|
||||
})
|
||||
|
||||
It("sends out higher packet numbers as 2 bytes, if a lot of ACKs are missing", func() {
|
||||
length := getPacketNumberLength(200, 2)
|
||||
Expect(length).To(Equal(protocol.PacketNumberLen2))
|
||||
length := GetPacketNumberLengthForPublicHeader(200, 2)
|
||||
Expect(length).To(Equal(PacketNumberLen2))
|
||||
})
|
||||
})
|
||||
|
||||
Context("self-consistency", func() {
|
||||
It("works for small packet numbers", func() {
|
||||
for i := uint64(1); i < 10000; i++ {
|
||||
packetNumber := protocol.PacketNumber(i)
|
||||
highestAcked := protocol.PacketNumber(1)
|
||||
length := getPacketNumberLength(packetNumber, highestAcked)
|
||||
packetNumber := PacketNumber(i)
|
||||
highestAcked := PacketNumber(1)
|
||||
length := GetPacketNumberLengthForPublicHeader(packetNumber, highestAcked)
|
||||
wirePacketNumber := (uint64(packetNumber) << (64 - length*8)) >> (64 - length*8)
|
||||
|
||||
inferedPacketNumber := inferPacketNumber(length, highestAcked, protocol.PacketNumber(wirePacketNumber))
|
||||
inferedPacketNumber := InferPacketNumber(length, highestAcked, PacketNumber(wirePacketNumber))
|
||||
Expect(inferedPacketNumber).To(Equal(packetNumber))
|
||||
}
|
||||
})
|
||||
|
||||
It("works for small packet numbers and increasing ACKed packets", func() {
|
||||
for i := uint64(1); i < 10000; i++ {
|
||||
packetNumber := protocol.PacketNumber(i)
|
||||
highestAcked := protocol.PacketNumber(i / 2)
|
||||
length := getPacketNumberLength(packetNumber, highestAcked)
|
||||
packetNumber := PacketNumber(i)
|
||||
highestAcked := PacketNumber(i / 2)
|
||||
length := GetPacketNumberLengthForPublicHeader(packetNumber, highestAcked)
|
||||
wirePacketNumber := (uint64(packetNumber) << (64 - length*8)) >> (64 - length*8)
|
||||
|
||||
inferedPacketNumber := inferPacketNumber(length, highestAcked, protocol.PacketNumber(wirePacketNumber))
|
||||
inferedPacketNumber := InferPacketNumber(length, highestAcked, PacketNumber(wirePacketNumber))
|
||||
Expect(inferedPacketNumber).To(Equal(packetNumber))
|
||||
}
|
||||
})
|
||||
@@ -172,20 +171,20 @@ var _ = Describe("packet number calculation", func() {
|
||||
It("also works for larger packet numbers", func() {
|
||||
increment := uint64(1)
|
||||
for i := uint64(1); i < (2 << 46); i += increment {
|
||||
packetNumber := protocol.PacketNumber(i)
|
||||
highestAcked := protocol.PacketNumber(1)
|
||||
length := getPacketNumberLength(packetNumber, highestAcked)
|
||||
packetNumber := PacketNumber(i)
|
||||
highestAcked := PacketNumber(1)
|
||||
length := GetPacketNumberLengthForPublicHeader(packetNumber, highestAcked)
|
||||
wirePacketNumber := (uint64(packetNumber) << (64 - length*8)) >> (64 - length*8)
|
||||
|
||||
inferedPacketNumber := inferPacketNumber(length, highestAcked, protocol.PacketNumber(wirePacketNumber))
|
||||
inferedPacketNumber := InferPacketNumber(length, highestAcked, PacketNumber(wirePacketNumber))
|
||||
Expect(inferedPacketNumber).To(Equal(packetNumber))
|
||||
|
||||
switch length {
|
||||
case protocol.PacketNumberLen2:
|
||||
case PacketNumberLen2:
|
||||
increment = 100
|
||||
case protocol.PacketNumberLen4:
|
||||
case PacketNumberLen4:
|
||||
increment = 50000
|
||||
case protocol.PacketNumberLen6:
|
||||
case PacketNumberLen6:
|
||||
increment = 100000000
|
||||
}
|
||||
}
|
||||
@@ -193,12 +192,12 @@ var _ = Describe("packet number calculation", func() {
|
||||
|
||||
It("works for packet numbers larger than 2^48", func() {
|
||||
for i := (uint64(1) << 48); i < ((uint64(1) << 63) - 1); i += (uint64(1) << 45) {
|
||||
packetNumber := protocol.PacketNumber(i)
|
||||
highestAcked := protocol.PacketNumber(i - 1000)
|
||||
length := getPacketNumberLength(packetNumber, highestAcked)
|
||||
packetNumber := PacketNumber(i)
|
||||
highestAcked := PacketNumber(i - 1000)
|
||||
length := GetPacketNumberLengthForPublicHeader(packetNumber, highestAcked)
|
||||
wirePacketNumber := (uint64(packetNumber) << (64 - length*8)) >> (64 - length*8)
|
||||
|
||||
inferedPacketNumber := inferPacketNumber(length, highestAcked, protocol.PacketNumber(wirePacketNumber))
|
||||
inferedPacketNumber := InferPacketNumber(length, highestAcked, PacketNumber(wirePacketNumber))
|
||||
Expect(inferedPacketNumber).To(Equal(packetNumber))
|
||||
}
|
||||
})
|
||||
@@ -160,7 +160,7 @@ func (s *Session) Run() {
|
||||
|
||||
func (s *Session) handlePacket(remoteAddr interface{}, publicHeader *PublicHeader, r *bytes.Reader) error {
|
||||
// Calcualate packet number
|
||||
publicHeader.PacketNumber = inferPacketNumber(
|
||||
publicHeader.PacketNumber = protocol.InferPacketNumber(
|
||||
publicHeader.PacketNumberLen,
|
||||
s.lastRcvdPacketNumber,
|
||||
publicHeader.PacketNumber,
|
||||
|
||||
Reference in New Issue
Block a user