diff --git a/packet_number.go b/packet_number.go new file mode 100644 index 00000000..b8eac122 --- /dev/null +++ b/packet_number.go @@ -0,0 +1,29 @@ +package quic + +import "github.com/lucas-clemente/quic-go/protocol" + +func calculatePacketNumber(packetNumberLength uint8, lastPacketNumber protocol.PacketNumber, wirePacketNumber protocol.PacketNumber) protocol.PacketNumber { + epochDelta := protocol.PacketNumber(1) << (packetNumberLength * 8) + epoch := lastPacketNumber & ^(epochDelta - 1) + prevEpochBegin := epoch - epochDelta + nextEpochBegin := epoch + epochDelta + return closestTo( + lastPacketNumber+1, + epoch+wirePacketNumber, + closestTo(lastPacketNumber+1, prevEpochBegin+wirePacketNumber, nextEpochBegin+wirePacketNumber), + ) +} + +func closestTo(target, a, b protocol.PacketNumber) protocol.PacketNumber { + if delta(target, a) < delta(target, b) { + return a + } + return b +} + +func delta(a, b protocol.PacketNumber) protocol.PacketNumber { + if a < b { + return b - a + } + return a - b +} diff --git a/packet_number_test.go b/packet_number_test.go new file mode 100644 index 00000000..2d706cd8 --- /dev/null +++ b/packet_number_test.go @@ -0,0 +1,127 @@ +package quic + +import ( + "fmt" + "math" + + "github.com/lucas-clemente/quic-go/protocol" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +// Tests taken and extended from chrome +var _ = Describe("packet number calculation", func() { + check := func(length uint8, expected, last uint64) { + epoch := uint64(1) << (length * 8) + epochMask := epoch - 1 + wirePacketNumber := expected & epochMask + Expect(calculatePacketNumber(length, protocol.PacketNumber(last), protocol.PacketNumber(wirePacketNumber))).To(Equal(protocol.PacketNumber(expected))) + } + for _, length := range []uint8{1, 2, 4, 6} { + Context(fmt.Sprintf("with %d bytes", length), func() { + epoch := uint64(1) << (length * 8) + epochMask := epoch - 1 + + It("works near epoch start", func() { + // A few quick manual sanity check + check(length, 1, 0) + check(length, epoch+1, epochMask) + check(length, epoch, epochMask) + + // Cases where the last number was close to the start of the range. + for last := uint64(0); last < 10; last++ { + // Small numbers should not wrap (even if they're out of order). + for j := uint64(0); j < 10; j++ { + check(length, j, last) + } + + // Large numbers should not wrap either (because we're near 0 already). + for j := uint64(0); j < 10; j++ { + check(length, epoch-1-j, last) + } + } + }) + + It("works near epoch end", func() { + // Cases where the last number was close to the end of the range + for i := uint64(0); i < 10; i++ { + last := epoch - i + + // Small numbers should wrap. + for j := uint64(0); j < 10; j++ { + check(length, epoch+j, last) + } + + // Large numbers should not (even if they're out of order). + for j := uint64(0); j < 10; j++ { + check(length, epoch-1-j, last) + } + } + }) + + // Next check where we're in a non-zero epoch to verify we handle + // reverse wrapping, too. + It("works near previous epoch", func() { + prevEpoch := 1 * epoch + curEpoch := 2 * epoch + // Cases where the last number was close to the start of the range + for i := uint64(0); i < 10; i++ { + last := curEpoch + i + // Small number should not wrap (even if they're out of order). + for j := uint64(0); j < 10; j++ { + check(length, curEpoch+j, last) + } + + // But large numbers should reverse wrap. + for j := uint64(0); j < 10; j++ { + num := epoch - 1 - j + check(length, prevEpoch+num, last) + } + } + }) + + It("works near next epoch", func() { + curEpoch := 2 * epoch + nextEpoch := 3 * epoch + // Cases where the last number was close to the end of the range + for i := uint64(0); i < 10; i++ { + last := nextEpoch - 1 - i + + // Small numbers should wrap. + for j := uint64(0); j < 10; j++ { + check(length, nextEpoch+j, last) + } + + // but large numbers should not (even if they're out of order). + for j := uint64(0); j < 10; j++ { + num := epoch - 1 - j + check(length, curEpoch+num, last) + } + } + }) + + It("works near next max", func() { + maxNumber := uint64(math.MaxUint64) + maxEpoch := maxNumber & ^epochMask + + // Cases where the last number was close to the end of the range + for i := uint64(0); i < 10; i++ { + // Subtract 1, because the expected next packet number is 1 more than the + // last packet number. + last := maxNumber - i - 1 + + // Small numbers should not wrap, because they have nowhere to go. + for j := uint64(0); j < 10; j++ { + check(length, maxEpoch+j, last) + } + + // Large numbers should not wrap either. + for j := uint64(0); j < 10; j++ { + num := epoch - 1 - j + check(length, maxEpoch+num, last) + } + } + }) + }) + } +}) diff --git a/session.go b/session.go index d3b1719b..d2317313 100644 --- a/session.go +++ b/session.go @@ -42,6 +42,9 @@ type Session struct { packer *packetPacker receivedPackets chan receivedPacket + + // Used to calculate the next packet number from the truncated wire representation + lastRcvdPacketNumber protocol.PacketNumber } // NewSession makes a new session @@ -90,6 +93,15 @@ func (s *Session) Run() { } func (s *Session) handlePacket(addr *net.UDPAddr, publicHeader *PublicHeader, r *bytes.Reader) error { + // Calcualate packet number + publicHeader.PacketNumber = calculatePacketNumber( + publicHeader.PacketNumberLen, + s.lastRcvdPacketNumber, + publicHeader.PacketNumber, + ) + s.lastRcvdPacketNumber = publicHeader.PacketNumber + fmt.Printf("<- Reading packet %d for connection %d\n", publicHeader.PacketNumber, publicHeader.ConnectionID) + // TODO: Only do this after authenticating if addr != s.currentRemoteAddr { s.currentRemoteAddr = addr