forked from quic-go/quic-go
accept truncated packet numbers
This commit is contained in:
29
packet_number.go
Normal file
29
packet_number.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package quic
|
||||
|
||||
import "github.com/lucas-clemente/quic-go/protocol"
|
||||
|
||||
func calculatePacketNumber(packetNumberLength uint8, lastPacketNumber protocol.PacketNumber, wirePacketNumber protocol.PacketNumber) protocol.PacketNumber {
|
||||
epochDelta := protocol.PacketNumber(1) << (packetNumberLength * 8)
|
||||
epoch := lastPacketNumber & ^(epochDelta - 1)
|
||||
prevEpochBegin := epoch - epochDelta
|
||||
nextEpochBegin := epoch + epochDelta
|
||||
return closestTo(
|
||||
lastPacketNumber+1,
|
||||
epoch+wirePacketNumber,
|
||||
closestTo(lastPacketNumber+1, prevEpochBegin+wirePacketNumber, nextEpochBegin+wirePacketNumber),
|
||||
)
|
||||
}
|
||||
|
||||
func closestTo(target, a, b protocol.PacketNumber) protocol.PacketNumber {
|
||||
if delta(target, a) < delta(target, b) {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func delta(a, b protocol.PacketNumber) protocol.PacketNumber {
|
||||
if a < b {
|
||||
return b - a
|
||||
}
|
||||
return a - b
|
||||
}
|
||||
127
packet_number_test.go
Normal file
127
packet_number_test.go
Normal file
@@ -0,0 +1,127 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// Tests taken and extended from chrome
|
||||
var _ = Describe("packet number calculation", func() {
|
||||
check := func(length uint8, expected, last uint64) {
|
||||
epoch := uint64(1) << (length * 8)
|
||||
epochMask := epoch - 1
|
||||
wirePacketNumber := expected & epochMask
|
||||
Expect(calculatePacketNumber(length, protocol.PacketNumber(last), protocol.PacketNumber(wirePacketNumber))).To(Equal(protocol.PacketNumber(expected)))
|
||||
}
|
||||
for _, length := range []uint8{1, 2, 4, 6} {
|
||||
Context(fmt.Sprintf("with %d bytes", length), func() {
|
||||
epoch := uint64(1) << (length * 8)
|
||||
epochMask := epoch - 1
|
||||
|
||||
It("works near epoch start", func() {
|
||||
// A few quick manual sanity check
|
||||
check(length, 1, 0)
|
||||
check(length, epoch+1, epochMask)
|
||||
check(length, epoch, epochMask)
|
||||
|
||||
// Cases where the last number was close to the start of the range.
|
||||
for last := uint64(0); last < 10; last++ {
|
||||
// Small numbers should not wrap (even if they're out of order).
|
||||
for j := uint64(0); j < 10; j++ {
|
||||
check(length, j, last)
|
||||
}
|
||||
|
||||
// Large numbers should not wrap either (because we're near 0 already).
|
||||
for j := uint64(0); j < 10; j++ {
|
||||
check(length, epoch-1-j, last)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
It("works near epoch end", func() {
|
||||
// Cases where the last number was close to the end of the range
|
||||
for i := uint64(0); i < 10; i++ {
|
||||
last := epoch - i
|
||||
|
||||
// Small numbers should wrap.
|
||||
for j := uint64(0); j < 10; j++ {
|
||||
check(length, epoch+j, last)
|
||||
}
|
||||
|
||||
// Large numbers should not (even if they're out of order).
|
||||
for j := uint64(0); j < 10; j++ {
|
||||
check(length, epoch-1-j, last)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Next check where we're in a non-zero epoch to verify we handle
|
||||
// reverse wrapping, too.
|
||||
It("works near previous epoch", func() {
|
||||
prevEpoch := 1 * epoch
|
||||
curEpoch := 2 * epoch
|
||||
// Cases where the last number was close to the start of the range
|
||||
for i := uint64(0); i < 10; i++ {
|
||||
last := curEpoch + i
|
||||
// Small number should not wrap (even if they're out of order).
|
||||
for j := uint64(0); j < 10; j++ {
|
||||
check(length, curEpoch+j, last)
|
||||
}
|
||||
|
||||
// But large numbers should reverse wrap.
|
||||
for j := uint64(0); j < 10; j++ {
|
||||
num := epoch - 1 - j
|
||||
check(length, prevEpoch+num, last)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
It("works near next epoch", func() {
|
||||
curEpoch := 2 * epoch
|
||||
nextEpoch := 3 * epoch
|
||||
// Cases where the last number was close to the end of the range
|
||||
for i := uint64(0); i < 10; i++ {
|
||||
last := nextEpoch - 1 - i
|
||||
|
||||
// Small numbers should wrap.
|
||||
for j := uint64(0); j < 10; j++ {
|
||||
check(length, nextEpoch+j, last)
|
||||
}
|
||||
|
||||
// but large numbers should not (even if they're out of order).
|
||||
for j := uint64(0); j < 10; j++ {
|
||||
num := epoch - 1 - j
|
||||
check(length, curEpoch+num, last)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
It("works near next max", func() {
|
||||
maxNumber := uint64(math.MaxUint64)
|
||||
maxEpoch := maxNumber & ^epochMask
|
||||
|
||||
// Cases where the last number was close to the end of the range
|
||||
for i := uint64(0); i < 10; i++ {
|
||||
// Subtract 1, because the expected next packet number is 1 more than the
|
||||
// last packet number.
|
||||
last := maxNumber - i - 1
|
||||
|
||||
// Small numbers should not wrap, because they have nowhere to go.
|
||||
for j := uint64(0); j < 10; j++ {
|
||||
check(length, maxEpoch+j, last)
|
||||
}
|
||||
|
||||
// Large numbers should not wrap either.
|
||||
for j := uint64(0); j < 10; j++ {
|
||||
num := epoch - 1 - j
|
||||
check(length, maxEpoch+num, last)
|
||||
}
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
})
|
||||
12
session.go
12
session.go
@@ -42,6 +42,9 @@ type Session struct {
|
||||
packer *packetPacker
|
||||
|
||||
receivedPackets chan receivedPacket
|
||||
|
||||
// Used to calculate the next packet number from the truncated wire representation
|
||||
lastRcvdPacketNumber protocol.PacketNumber
|
||||
}
|
||||
|
||||
// NewSession makes a new session
|
||||
@@ -90,6 +93,15 @@ func (s *Session) Run() {
|
||||
}
|
||||
|
||||
func (s *Session) handlePacket(addr *net.UDPAddr, publicHeader *PublicHeader, r *bytes.Reader) error {
|
||||
// Calcualate packet number
|
||||
publicHeader.PacketNumber = calculatePacketNumber(
|
||||
publicHeader.PacketNumberLen,
|
||||
s.lastRcvdPacketNumber,
|
||||
publicHeader.PacketNumber,
|
||||
)
|
||||
s.lastRcvdPacketNumber = publicHeader.PacketNumber
|
||||
fmt.Printf("<- Reading packet %d for connection %d\n", publicHeader.PacketNumber, publicHeader.ConnectionID)
|
||||
|
||||
// TODO: Only do this after authenticating
|
||||
if addr != s.currentRemoteAddr {
|
||||
s.currentRemoteAddr = addr
|
||||
|
||||
Reference in New Issue
Block a user