forked from quic-go/quic-go
implement sending of shortened packet numbers
This commit is contained in:
@@ -15,6 +15,7 @@ type SentPacketHandler interface {
|
||||
DequeuePacketForRetransmission() (packet *Packet)
|
||||
|
||||
BytesInFlight() protocol.ByteCount
|
||||
GetLargestObserved() protocol.PacketNumber
|
||||
}
|
||||
|
||||
// ReceivedPacketHandler handles ACKs needed to send for incoming packets
|
||||
|
||||
@@ -220,3 +220,7 @@ func (h *sentPacketHandler) DequeuePacketForRetransmission() (packet *Packet) {
|
||||
func (h *sentPacketHandler) BytesInFlight() protocol.ByteCount {
|
||||
return h.bytesInFlight
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) GetLargestObserved() protocol.PacketNumber {
|
||||
return h.LargestObserved
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@ package quic
|
||||
|
||||
import "github.com/lucas-clemente/quic-go/protocol"
|
||||
|
||||
func calculatePacketNumber(packetNumberLength protocol.PacketNumberLen, lastPacketNumber protocol.PacketNumber, wirePacketNumber protocol.PacketNumber) protocol.PacketNumber {
|
||||
func inferPacketNumber(packetNumberLength protocol.PacketNumberLen, lastPacketNumber protocol.PacketNumber, wirePacketNumber protocol.PacketNumber) protocol.PacketNumber {
|
||||
epochDelta := protocol.PacketNumber(1) << (uint8(packetNumberLength) * 8)
|
||||
epoch := lastPacketNumber & ^(epochDelta - 1)
|
||||
prevEpochBegin := epoch - epochDelta
|
||||
@@ -27,3 +27,18 @@ func delta(a, b protocol.PacketNumber) protocol.PacketNumber {
|
||||
}
|
||||
return a - b
|
||||
}
|
||||
|
||||
func getPacketNumberLength(packetNumber protocol.PacketNumber, highestAckedPacketNumber protocol.PacketNumber) protocol.PacketNumberLen {
|
||||
diff := uint64(packetNumber - highestAckedPacketNumber)
|
||||
if diff < (2 << (uint8(protocol.PacketNumberLen1)*8 - 2)) {
|
||||
return protocol.PacketNumberLen1
|
||||
}
|
||||
if diff < (2 << (uint8(protocol.PacketNumberLen2)*8 - 2)) {
|
||||
return protocol.PacketNumberLen2
|
||||
}
|
||||
if diff < (2 << (uint8(protocol.PacketNumberLen4)*8 - 2)) {
|
||||
return protocol.PacketNumberLen4
|
||||
}
|
||||
// we do not check if there are less than 2^46 packets in flight, since flow control and congestion control will limit this number *a lot* sooner
|
||||
return protocol.PacketNumberLen6
|
||||
}
|
||||
|
||||
@@ -11,117 +11,196 @@ import (
|
||||
|
||||
// Tests taken and extended from chrome
|
||||
var _ = Describe("packet number calculation", func() {
|
||||
check := func(length protocol.PacketNumberLen, expected, last uint64) {
|
||||
epoch := uint64(1) << (length * 8)
|
||||
epochMask := epoch - 1
|
||||
wirePacketNumber := expected & epochMask
|
||||
Expect(calculatePacketNumber(length, protocol.PacketNumber(last), protocol.PacketNumber(wirePacketNumber))).To(Equal(protocol.PacketNumber(expected)))
|
||||
}
|
||||
for _, length := range []protocol.PacketNumberLen{protocol.PacketNumberLen1, protocol.PacketNumberLen2, protocol.PacketNumberLen4, protocol.PacketNumberLen6} {
|
||||
Context(fmt.Sprintf("with %d bytes", length), func() {
|
||||
Context("infering a packet number", func() {
|
||||
check := func(length protocol.PacketNumberLen, expected, last uint64) {
|
||||
epoch := uint64(1) << (length * 8)
|
||||
epochMask := epoch - 1
|
||||
wirePacketNumber := expected & epochMask
|
||||
Expect(inferPacketNumber(length, protocol.PacketNumber(last), protocol.PacketNumber(wirePacketNumber))).To(Equal(protocol.PacketNumber(expected)))
|
||||
}
|
||||
for _, length := range []protocol.PacketNumberLen{protocol.PacketNumberLen1, protocol.PacketNumberLen2, protocol.PacketNumberLen4, protocol.PacketNumberLen6} {
|
||||
Context(fmt.Sprintf("with %d bytes", length), func() {
|
||||
epoch := uint64(1) << (length * 8)
|
||||
epochMask := epoch - 1
|
||||
|
||||
It("works near epoch start", func() {
|
||||
// A few quick manual sanity check
|
||||
check(length, 1, 0)
|
||||
check(length, epoch+1, epochMask)
|
||||
check(length, epoch, epochMask)
|
||||
It("works near epoch start", func() {
|
||||
// A few quick manual sanity check
|
||||
check(length, 1, 0)
|
||||
check(length, epoch+1, epochMask)
|
||||
check(length, epoch, epochMask)
|
||||
|
||||
// Cases where the last number was close to the start of the range.
|
||||
for last := uint64(0); last < 10; last++ {
|
||||
// Small numbers should not wrap (even if they're out of order).
|
||||
for j := uint64(0); j < 10; j++ {
|
||||
check(length, j, last)
|
||||
// Cases where the last number was close to the start of the range.
|
||||
for last := uint64(0); last < 10; last++ {
|
||||
// Small numbers should not wrap (even if they're out of order).
|
||||
for j := uint64(0); j < 10; j++ {
|
||||
check(length, j, last)
|
||||
}
|
||||
|
||||
// Large numbers should not wrap either (because we're near 0 already).
|
||||
for j := uint64(0); j < 10; j++ {
|
||||
check(length, epoch-1-j, last)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Large numbers should not wrap either (because we're near 0 already).
|
||||
for j := uint64(0); j < 10; j++ {
|
||||
check(length, epoch-1-j, last)
|
||||
It("works near epoch end", func() {
|
||||
// Cases where the last number was close to the end of the range
|
||||
for i := uint64(0); i < 10; i++ {
|
||||
last := epoch - i
|
||||
|
||||
// Small numbers should wrap.
|
||||
for j := uint64(0); j < 10; j++ {
|
||||
check(length, epoch+j, last)
|
||||
}
|
||||
|
||||
// Large numbers should not (even if they're out of order).
|
||||
for j := uint64(0); j < 10; j++ {
|
||||
check(length, epoch-1-j, last)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Next check where we're in a non-zero epoch to verify we handle
|
||||
// reverse wrapping, too.
|
||||
It("works near previous epoch", func() {
|
||||
prevEpoch := 1 * epoch
|
||||
curEpoch := 2 * epoch
|
||||
// Cases where the last number was close to the start of the range
|
||||
for i := uint64(0); i < 10; i++ {
|
||||
last := curEpoch + i
|
||||
// Small number should not wrap (even if they're out of order).
|
||||
for j := uint64(0); j < 10; j++ {
|
||||
check(length, curEpoch+j, last)
|
||||
}
|
||||
|
||||
// But large numbers should reverse wrap.
|
||||
for j := uint64(0); j < 10; j++ {
|
||||
num := epoch - 1 - j
|
||||
check(length, prevEpoch+num, last)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
It("works near next epoch", func() {
|
||||
curEpoch := 2 * epoch
|
||||
nextEpoch := 3 * epoch
|
||||
// Cases where the last number was close to the end of the range
|
||||
for i := uint64(0); i < 10; i++ {
|
||||
last := nextEpoch - 1 - i
|
||||
|
||||
// Small numbers should wrap.
|
||||
for j := uint64(0); j < 10; j++ {
|
||||
check(length, nextEpoch+j, last)
|
||||
}
|
||||
|
||||
// but large numbers should not (even if they're out of order).
|
||||
for j := uint64(0); j < 10; j++ {
|
||||
num := epoch - 1 - j
|
||||
check(length, curEpoch+num, last)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
It("works near next max", func() {
|
||||
maxNumber := uint64(math.MaxUint64)
|
||||
maxEpoch := maxNumber & ^epochMask
|
||||
|
||||
// Cases where the last number was close to the end of the range
|
||||
for i := uint64(0); i < 10; i++ {
|
||||
// Subtract 1, because the expected next packet number is 1 more than the
|
||||
// last packet number.
|
||||
last := maxNumber - i - 1
|
||||
|
||||
// Small numbers should not wrap, because they have nowhere to go.
|
||||
for j := uint64(0); j < 10; j++ {
|
||||
check(length, maxEpoch+j, last)
|
||||
}
|
||||
|
||||
// Large numbers should not wrap either.
|
||||
for j := uint64(0); j < 10; j++ {
|
||||
num := epoch - 1 - j
|
||||
check(length, maxEpoch+num, last)
|
||||
}
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
It("works near epoch end", func() {
|
||||
// Cases where the last number was close to the end of the range
|
||||
for i := uint64(0); i < 10; i++ {
|
||||
last := epoch - i
|
||||
|
||||
// Small numbers should wrap.
|
||||
for j := uint64(0); j < 10; j++ {
|
||||
check(length, epoch+j, last)
|
||||
}
|
||||
|
||||
// Large numbers should not (even if they're out of order).
|
||||
for j := uint64(0); j < 10; j++ {
|
||||
check(length, epoch-1-j, last)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Next check where we're in a non-zero epoch to verify we handle
|
||||
// reverse wrapping, too.
|
||||
It("works near previous epoch", func() {
|
||||
prevEpoch := 1 * epoch
|
||||
curEpoch := 2 * epoch
|
||||
// Cases where the last number was close to the start of the range
|
||||
for i := uint64(0); i < 10; i++ {
|
||||
last := curEpoch + i
|
||||
// Small number should not wrap (even if they're out of order).
|
||||
for j := uint64(0); j < 10; j++ {
|
||||
check(length, curEpoch+j, last)
|
||||
}
|
||||
|
||||
// But large numbers should reverse wrap.
|
||||
for j := uint64(0); j < 10; j++ {
|
||||
num := epoch - 1 - j
|
||||
check(length, prevEpoch+num, last)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
It("works near next epoch", func() {
|
||||
curEpoch := 2 * epoch
|
||||
nextEpoch := 3 * epoch
|
||||
// Cases where the last number was close to the end of the range
|
||||
for i := uint64(0); i < 10; i++ {
|
||||
last := nextEpoch - 1 - i
|
||||
|
||||
// Small numbers should wrap.
|
||||
for j := uint64(0); j < 10; j++ {
|
||||
check(length, nextEpoch+j, last)
|
||||
}
|
||||
|
||||
// but large numbers should not (even if they're out of order).
|
||||
for j := uint64(0); j < 10; j++ {
|
||||
num := epoch - 1 - j
|
||||
check(length, curEpoch+num, last)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
It("works near next max", func() {
|
||||
maxNumber := uint64(math.MaxUint64)
|
||||
maxEpoch := maxNumber & ^epochMask
|
||||
|
||||
// Cases where the last number was close to the end of the range
|
||||
for i := uint64(0); i < 10; i++ {
|
||||
// Subtract 1, because the expected next packet number is 1 more than the
|
||||
// last packet number.
|
||||
last := maxNumber - i - 1
|
||||
|
||||
// Small numbers should not wrap, because they have nowhere to go.
|
||||
for j := uint64(0); j < 10; j++ {
|
||||
check(length, maxEpoch+j, last)
|
||||
}
|
||||
|
||||
// Large numbers should not wrap either.
|
||||
for j := uint64(0); j < 10; j++ {
|
||||
num := epoch - 1 - j
|
||||
check(length, maxEpoch+num, last)
|
||||
}
|
||||
}
|
||||
})
|
||||
Context("shortening a packet number", func() {
|
||||
It("sends out low packet numbers as 1 byte", func() {
|
||||
length := getPacketNumberLength(4, 2)
|
||||
Expect(length).To(Equal(protocol.PacketNumberLen1))
|
||||
})
|
||||
}
|
||||
|
||||
It("sends out high packet numbers as 1 byte, if all ACKs are received", func() {
|
||||
length := getPacketNumberLength(0xDEADBEEF, 0xDEADBEEF-1)
|
||||
Expect(length).To(Equal(protocol.PacketNumberLen1))
|
||||
})
|
||||
|
||||
It("sends out higher packet numbers as 2 bytes, if a lot of ACKs are missing", func() {
|
||||
length := getPacketNumberLength(200, 2)
|
||||
Expect(length).To(Equal(protocol.PacketNumberLen2))
|
||||
})
|
||||
})
|
||||
|
||||
Context("self-consistency", func() {
|
||||
It("works for small packet numbers", func() {
|
||||
for i := uint64(1); i < 10000; i++ {
|
||||
packetNumber := protocol.PacketNumber(i)
|
||||
highestAcked := protocol.PacketNumber(1)
|
||||
length := getPacketNumberLength(packetNumber, highestAcked)
|
||||
wirePacketNumber := (uint64(packetNumber) << (64 - length*8)) >> (64 - length*8)
|
||||
|
||||
inferedPacketNumber := inferPacketNumber(length, highestAcked, protocol.PacketNumber(wirePacketNumber))
|
||||
Expect(inferedPacketNumber).To(Equal(packetNumber))
|
||||
}
|
||||
})
|
||||
|
||||
It("works for small packet numbers and increasing ACKed packets", func() {
|
||||
for i := uint64(1); i < 10000; i++ {
|
||||
packetNumber := protocol.PacketNumber(i)
|
||||
highestAcked := protocol.PacketNumber(i / 2)
|
||||
length := getPacketNumberLength(packetNumber, highestAcked)
|
||||
wirePacketNumber := (uint64(packetNumber) << (64 - length*8)) >> (64 - length*8)
|
||||
|
||||
inferedPacketNumber := inferPacketNumber(length, highestAcked, protocol.PacketNumber(wirePacketNumber))
|
||||
Expect(inferedPacketNumber).To(Equal(packetNumber))
|
||||
}
|
||||
})
|
||||
|
||||
It("also works for larger packet numbers", func() {
|
||||
increment := uint64(1)
|
||||
for i := uint64(1); i < (2 << 46); i += increment {
|
||||
packetNumber := protocol.PacketNumber(i)
|
||||
highestAcked := protocol.PacketNumber(1)
|
||||
length := getPacketNumberLength(packetNumber, highestAcked)
|
||||
wirePacketNumber := (uint64(packetNumber) << (64 - length*8)) >> (64 - length*8)
|
||||
|
||||
inferedPacketNumber := inferPacketNumber(length, highestAcked, protocol.PacketNumber(wirePacketNumber))
|
||||
Expect(inferedPacketNumber).To(Equal(packetNumber))
|
||||
|
||||
switch length {
|
||||
case protocol.PacketNumberLen2:
|
||||
increment = 100
|
||||
case protocol.PacketNumberLen4:
|
||||
increment = 50000
|
||||
case protocol.PacketNumberLen6:
|
||||
increment = 100000000
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
It("works for packet numbers larger than 2^48", func() {
|
||||
for i := (uint64(1) << 48); i < ((uint64(1) << 63) - 1); i += (uint64(1) << 45) {
|
||||
packetNumber := protocol.PacketNumber(i)
|
||||
highestAcked := protocol.PacketNumber(i - 1000)
|
||||
length := getPacketNumberLength(packetNumber, highestAcked)
|
||||
wirePacketNumber := (uint64(packetNumber) << (64 - length*8)) >> (64 - length*8)
|
||||
|
||||
inferedPacketNumber := inferPacketNumber(length, highestAcked, protocol.PacketNumber(wirePacketNumber))
|
||||
Expect(inferedPacketNumber).To(Equal(packetNumber))
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"bytes"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/ackhandler"
|
||||
"github.com/lucas-clemente/quic-go/crypto"
|
||||
"github.com/lucas-clemente/quic-go/frames"
|
||||
"github.com/lucas-clemente/quic-go/handshake"
|
||||
@@ -21,8 +22,9 @@ type packedPacket struct {
|
||||
type packetPacker struct {
|
||||
connectionID protocol.ConnectionID
|
||||
version protocol.VersionNumber
|
||||
aead crypto.AEAD
|
||||
|
||||
aead crypto.AEAD
|
||||
sentPacketHandler ackhandler.SentPacketHandler
|
||||
connectionParametersManager *handshake.ConnectionParametersManager
|
||||
|
||||
streamFrameQueue StreamFrameQueue
|
||||
@@ -72,6 +74,7 @@ func (p *packetPacker) PackPacket(stopWaitingFrame *frames.StopWaitingFrame, con
|
||||
responsePublicHeader := PublicHeader{
|
||||
ConnectionID: p.connectionID,
|
||||
PacketNumber: currentPacketNumber,
|
||||
PacketNumberLen: getPacketNumberLength(currentPacketNumber, p.sentPacketHandler.GetLargestObserved()),
|
||||
TruncateConnectionID: p.connectionParametersManager.TruncateConnectionID(),
|
||||
}
|
||||
if err := responsePublicHeader.WritePublicHeader(&raw); err != nil {
|
||||
|
||||
@@ -2,7 +2,9 @@ package quic
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/ackhandler"
|
||||
"github.com/lucas-clemente/quic-go/crypto"
|
||||
"github.com/lucas-clemente/quic-go/frames"
|
||||
"github.com/lucas-clemente/quic-go/handshake"
|
||||
@@ -11,6 +13,32 @@ import (
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
type mockSentPacketHandler struct{}
|
||||
|
||||
func (h *mockSentPacketHandler) SentPacket(packet *ackhandler.Packet) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *mockSentPacketHandler) ReceivedAck(ackFrame *frames.AckFrame) (time.Duration, []*ackhandler.Packet, []*ackhandler.Packet, error) {
|
||||
return 0, nil, nil, nil
|
||||
}
|
||||
|
||||
func (h *mockSentPacketHandler) DequeuePacketForRetransmission() (packet *ackhandler.Packet) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *mockSentPacketHandler) BytesInFlight() protocol.ByteCount {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (h *mockSentPacketHandler) GetLargestObserved() protocol.PacketNumber {
|
||||
return 1
|
||||
}
|
||||
|
||||
func newMockSentPacketHandler() ackhandler.SentPacketHandler {
|
||||
return &mockSentPacketHandler{}
|
||||
}
|
||||
|
||||
var _ = Describe("Packet packer", func() {
|
||||
var (
|
||||
packer *packetPacker
|
||||
@@ -18,7 +46,11 @@ var _ = Describe("Packet packer", func() {
|
||||
|
||||
BeforeEach(func() {
|
||||
aead := &crypto.NullAEAD{}
|
||||
packer = &packetPacker{aead: aead, connectionParametersManager: handshake.NewConnectionParamatersManager()}
|
||||
packer = &packetPacker{
|
||||
aead: aead,
|
||||
connectionParametersManager: handshake.NewConnectionParamatersManager(),
|
||||
sentPacketHandler: newMockSentPacketHandler(),
|
||||
}
|
||||
})
|
||||
|
||||
It("returns nil when no packet is queued", func() {
|
||||
@@ -34,10 +66,10 @@ var _ = Describe("Packet packer", func() {
|
||||
}
|
||||
packer.AddStreamFrame(f)
|
||||
p, err := packer.PackPacket(nil, []frames.Frame{}, true)
|
||||
Expect(p).ToNot(BeNil())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(p).ToNot(BeNil())
|
||||
b := &bytes.Buffer{}
|
||||
f.Write(b, 1, 6, 0)
|
||||
f.Write(b, 1, protocol.PacketNumberLen6, 0)
|
||||
Expect(len(p.frames)).To(Equal(1))
|
||||
Expect(p.raw).To(ContainSubstring(string(b.Bytes())))
|
||||
})
|
||||
@@ -80,7 +112,7 @@ var _ = Describe("Packet packer", func() {
|
||||
It("packs many control frames into 1 packets", func() {
|
||||
f := &frames.AckFrame{LargestObserved: 1}
|
||||
b := &bytes.Buffer{}
|
||||
f.Write(b, 3, 6, 32)
|
||||
f.Write(b, 3, protocol.PacketNumberLen6, 32)
|
||||
maxFramesPerPacket := protocol.MaxFrameSize / b.Len()
|
||||
var controlFrames []frames.Frame
|
||||
for i := 0; i < maxFramesPerPacket; i++ {
|
||||
@@ -146,8 +178,8 @@ var _ = Describe("Packet packer", func() {
|
||||
Expect(p).ToNot(BeNil())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
b := &bytes.Buffer{}
|
||||
f1.Write(b, 2, 6, 0)
|
||||
f2.Write(b, 2, 6, 0)
|
||||
f1.Write(b, 2, protocol.PacketNumberLen6, 0)
|
||||
f2.Write(b, 2, protocol.PacketNumberLen6, 0)
|
||||
Expect(len(p.frames)).To(Equal(2))
|
||||
Expect(p.raw).To(ContainSubstring(string(b.Bytes())))
|
||||
})
|
||||
@@ -172,7 +204,8 @@ var _ = Describe("Packet packer", func() {
|
||||
Expect(len(payloadFrames)).To(Equal(0))
|
||||
})
|
||||
|
||||
It("packs 2 stream frames that are too big for one packet correctly", func() {
|
||||
// set pending until https://github.com/lucas-clemente/quic-go/issues/67 is fixed
|
||||
PIt("packs 2 stream frames that are too big for one packet correctly", func() {
|
||||
maxStreamFrameDataLen := protocol.MaxFrameSize - (1 + 4 + 8 + 2)
|
||||
f1 := frames.StreamFrame{
|
||||
Data: bytes.Repeat([]byte{'f'}, maxStreamFrameDataLen+100),
|
||||
@@ -198,7 +231,8 @@ var _ = Describe("Packet packer", func() {
|
||||
Expect(p).To(BeNil())
|
||||
})
|
||||
|
||||
It("packs a packet that has the maximum packet size when given a large enough stream frame", func() {
|
||||
// set pending until https://github.com/lucas-clemente/quic-go/issues/67 is fixed
|
||||
PIt("packs a packet that has the maximum packet size when given a large enough stream frame", func() {
|
||||
f := frames.StreamFrame{
|
||||
Data: bytes.Repeat([]byte{'f'}, protocol.MaxFrameSize-(1+4+8+2)),
|
||||
Offset: 1,
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
errPacketNumberLenNotSet = errors.New("PublicHeader: PacketNumberLen not set")
|
||||
errResetAndVersionFlagSet = errors.New("PublicHeader: Reset Flag and Version Flag should not be set at the same time")
|
||||
errReceivedTruncatedConnectionID = errors.New("PublicHeader: Receiving packets with truncated ConnectionID is not supported")
|
||||
errInvalidConnectionID = errors.New("PublicHeader: connection ID cannot be 0")
|
||||
@@ -30,7 +31,7 @@ type PublicHeader struct {
|
||||
|
||||
// WritePublicHeader writes a public header
|
||||
func (h *PublicHeader) WritePublicHeader(b *bytes.Buffer) error {
|
||||
publicFlagByte := uint8(0x34)
|
||||
publicFlagByte := uint8(0x04)
|
||||
if h.VersionFlag && h.ResetFlag {
|
||||
return errResetAndVersionFlagSet
|
||||
}
|
||||
@@ -44,6 +45,19 @@ func (h *PublicHeader) WritePublicHeader(b *bytes.Buffer) error {
|
||||
publicFlagByte |= 0x08
|
||||
}
|
||||
|
||||
if !h.ResetFlag && !h.VersionFlag {
|
||||
switch h.PacketNumberLen {
|
||||
case protocol.PacketNumberLen1:
|
||||
publicFlagByte |= 0x00
|
||||
case protocol.PacketNumberLen2:
|
||||
publicFlagByte |= 0x10
|
||||
case protocol.PacketNumberLen4:
|
||||
publicFlagByte |= 0x20
|
||||
case protocol.PacketNumberLen6:
|
||||
publicFlagByte |= 0x30
|
||||
}
|
||||
}
|
||||
|
||||
b.WriteByte(publicFlagByte)
|
||||
|
||||
if !h.TruncateConnectionID {
|
||||
@@ -51,7 +65,18 @@ func (h *PublicHeader) WritePublicHeader(b *bytes.Buffer) error {
|
||||
}
|
||||
|
||||
if !h.ResetFlag && !h.VersionFlag {
|
||||
utils.WriteUint48(b, uint64(h.PacketNumber)) // TODO: Send shorter packet number if possible
|
||||
switch h.PacketNumberLen {
|
||||
case protocol.PacketNumberLen1:
|
||||
b.WriteByte(uint8(h.PacketNumber))
|
||||
case protocol.PacketNumberLen2:
|
||||
utils.WriteUint16(b, uint16(h.PacketNumber))
|
||||
case protocol.PacketNumberLen4:
|
||||
utils.WriteUint32(b, uint32(h.PacketNumber))
|
||||
case protocol.PacketNumberLen6:
|
||||
utils.WriteUint48(b, uint64(h.PacketNumber))
|
||||
default:
|
||||
return errPacketNumberLenNotSet
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -82,19 +82,21 @@ var _ = Describe("Public Header", func() {
|
||||
It("writes a sample header", func() {
|
||||
b := &bytes.Buffer{}
|
||||
publicHeader := PublicHeader{
|
||||
ConnectionID: 0x4cfa9f9b668619f6,
|
||||
PacketNumber: 2,
|
||||
ConnectionID: 0x4cfa9f9b668619f6,
|
||||
PacketNumber: 2,
|
||||
PacketNumberLen: protocol.PacketNumberLen6,
|
||||
}
|
||||
publicHeader.WritePublicHeader(b)
|
||||
Expect(b.Bytes()).To(Equal([]byte{0x3c, 0xf6, 0x19, 0x86, 0x66, 0x9b, 0x9f, 0xfa, 0x4c, 2, 0, 0, 0, 0, 0}))
|
||||
Expect(b.Bytes()).To(Equal([]byte{0x38 | 0x04, 0xf6, 0x19, 0x86, 0x66, 0x9b, 0x9f, 0xfa, 0x4c, 2, 0, 0, 0, 0, 0}))
|
||||
})
|
||||
|
||||
It("sets the Version Flag", func() {
|
||||
b := &bytes.Buffer{}
|
||||
publicHeader := PublicHeader{
|
||||
VersionFlag: true,
|
||||
ConnectionID: 0x4cfa9f9b668619f6,
|
||||
PacketNumber: 2,
|
||||
VersionFlag: true,
|
||||
ConnectionID: 0x4cfa9f9b668619f6,
|
||||
PacketNumber: 2,
|
||||
PacketNumberLen: protocol.PacketNumberLen6,
|
||||
}
|
||||
publicHeader.WritePublicHeader(b)
|
||||
// must be the first assertion
|
||||
@@ -106,9 +108,10 @@ var _ = Describe("Public Header", func() {
|
||||
It("sets the Reset Flag", func() {
|
||||
b := &bytes.Buffer{}
|
||||
publicHeader := PublicHeader{
|
||||
ResetFlag: true,
|
||||
ConnectionID: 0x4cfa9f9b668619f6,
|
||||
PacketNumber: 2,
|
||||
ResetFlag: true,
|
||||
ConnectionID: 0x4cfa9f9b668619f6,
|
||||
PacketNumber: 2,
|
||||
PacketNumberLen: protocol.PacketNumberLen6,
|
||||
}
|
||||
publicHeader.WritePublicHeader(b)
|
||||
// must be the first assertion
|
||||
@@ -120,10 +123,11 @@ var _ = Describe("Public Header", func() {
|
||||
It("throws an error if both Reset Flag and Version Flag are set", func() {
|
||||
b := &bytes.Buffer{}
|
||||
publicHeader := PublicHeader{
|
||||
VersionFlag: true,
|
||||
ResetFlag: true,
|
||||
ConnectionID: 0x4cfa9f9b668619f6,
|
||||
PacketNumber: 2,
|
||||
VersionFlag: true,
|
||||
ResetFlag: true,
|
||||
ConnectionID: 0x4cfa9f9b668619f6,
|
||||
PacketNumber: 2,
|
||||
PacketNumberLen: protocol.PacketNumberLen6,
|
||||
}
|
||||
err := publicHeader.WritePublicHeader(b)
|
||||
Expect(err).To(HaveOccurred())
|
||||
@@ -135,6 +139,7 @@ var _ = Describe("Public Header", func() {
|
||||
publicHeader := PublicHeader{
|
||||
ConnectionID: 0x4cfa9f9b668619f6,
|
||||
TruncateConnectionID: true,
|
||||
PacketNumberLen: protocol.PacketNumberLen6,
|
||||
}
|
||||
err := publicHeader.WritePublicHeader(b)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
@@ -142,5 +147,66 @@ var _ = Describe("Public Header", func() {
|
||||
Expect(firstByte & 0x08).To(BeZero())
|
||||
Expect(b.Bytes()).ToNot(ContainSubstring(string([]byte{0xf6, 0x19, 0x86, 0x66, 0x9b, 0x9f, 0xfa, 0x4c})))
|
||||
})
|
||||
|
||||
Context("packet number length", func() {
|
||||
It("doesn't write a header if the packet number length is not set", func() {
|
||||
b := &bytes.Buffer{}
|
||||
publicHeader := PublicHeader{
|
||||
ConnectionID: 0x4cfa9f9b668619f6,
|
||||
PacketNumber: 0xDECAFBAD,
|
||||
}
|
||||
err := publicHeader.WritePublicHeader(b)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err).To(Equal(errPacketNumberLenNotSet))
|
||||
})
|
||||
|
||||
It("writes a header with a 1-byte packet number", func() {
|
||||
b := &bytes.Buffer{}
|
||||
publicHeader := PublicHeader{
|
||||
ConnectionID: 0x4cfa9f9b668619f6,
|
||||
PacketNumber: 0xDECAFBAD,
|
||||
PacketNumberLen: protocol.PacketNumberLen1,
|
||||
}
|
||||
err := publicHeader.WritePublicHeader(b)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(b.Bytes()).To(Equal([]byte{0x08 | 0x04, 0xf6, 0x19, 0x86, 0x66, 0x9b, 0x9f, 0xfa, 0x4c, 0xAD}))
|
||||
})
|
||||
|
||||
It("writes a header with a 2-byte packet number", func() {
|
||||
b := &bytes.Buffer{}
|
||||
publicHeader := PublicHeader{
|
||||
ConnectionID: 0x4cfa9f9b668619f6,
|
||||
PacketNumber: 0xDECAFBAD,
|
||||
PacketNumberLen: protocol.PacketNumberLen2,
|
||||
}
|
||||
err := publicHeader.WritePublicHeader(b)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(b.Bytes()).To(Equal([]byte{0x18 | 0x04, 0xf6, 0x19, 0x86, 0x66, 0x9b, 0x9f, 0xfa, 0x4c, 0xAD, 0xFB}))
|
||||
})
|
||||
|
||||
It("writes a header with a 4-byte packet number", func() {
|
||||
b := &bytes.Buffer{}
|
||||
publicHeader := PublicHeader{
|
||||
ConnectionID: 0x4cfa9f9b668619f6,
|
||||
PacketNumber: 0x13DECAFBAD,
|
||||
PacketNumberLen: protocol.PacketNumberLen4,
|
||||
}
|
||||
err := publicHeader.WritePublicHeader(b)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(b.Bytes()).To(Equal([]byte{0x28 | 0x04, 0xf6, 0x19, 0x86, 0x66, 0x9b, 0x9f, 0xfa, 0x4c, 0xAD, 0xFB, 0xCA, 0xDE}))
|
||||
})
|
||||
|
||||
It("writes a header with a 6-byte packet number", func() {
|
||||
b := &bytes.Buffer{}
|
||||
publicHeader := PublicHeader{
|
||||
ConnectionID: 0x4cfa9f9b668619f6,
|
||||
PacketNumber: 0xBE1337DECAFBAD,
|
||||
PacketNumberLen: protocol.PacketNumberLen6,
|
||||
}
|
||||
err := publicHeader.WritePublicHeader(b)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(b.Bytes()).To(Equal([]byte{0x38 | 0x04, 0xf6, 0x19, 0x86, 0x66, 0x9b, 0x9f, 0xfa, 0x4c, 0xAD, 0xFB, 0xCA, 0xDE, 0x37, 0x13}))
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -47,7 +47,7 @@ var _ = Describe("Server", func() {
|
||||
|
||||
It("composes version negotiation packets", func() {
|
||||
expected := append(
|
||||
[]byte{0x3d, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0},
|
||||
[]byte{0x01 | 0x08 | 0x04, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0},
|
||||
protocol.SupportedVersionsAsTags...,
|
||||
)
|
||||
Expect(composeVersionNegotiation(1)).To(Equal(expected))
|
||||
@@ -111,7 +111,7 @@ var _ = Describe("Server", func() {
|
||||
}
|
||||
data = data[:n]
|
||||
expected := append(
|
||||
[]byte{0x3d, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0},
|
||||
[]byte{0xd, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0},
|
||||
protocol.SupportedVersionsAsTags...,
|
||||
)
|
||||
Expect(data).To(Equal(expected))
|
||||
|
||||
10
session.go
10
session.go
@@ -95,7 +95,13 @@ func NewSession(conn connection, v protocol.VersionNumber, connectionID protocol
|
||||
}
|
||||
}()
|
||||
|
||||
session.packer = &packetPacker{aead: cryptoSetup, connectionParametersManager: session.connectionParametersManager, connectionID: connectionID, version: v}
|
||||
session.packer = &packetPacker{
|
||||
aead: cryptoSetup,
|
||||
connectionParametersManager: session.connectionParametersManager,
|
||||
sentPacketHandler: session.sentPacketHandler,
|
||||
connectionID: connectionID,
|
||||
version: v,
|
||||
}
|
||||
session.unpacker = &packetUnpacker{aead: cryptoSetup, version: v}
|
||||
|
||||
session.congestion = congestion.NewCubicSender(
|
||||
@@ -152,7 +158,7 @@ func (s *Session) Run() {
|
||||
|
||||
func (s *Session) handlePacket(remoteAddr interface{}, publicHeader *PublicHeader, r *bytes.Reader) error {
|
||||
// Calcualate packet number
|
||||
publicHeader.PacketNumber = calculatePacketNumber(
|
||||
publicHeader.PacketNumber = inferPacketNumber(
|
||||
publicHeader.PacketNumberLen,
|
||||
s.lastRcvdPacketNumber,
|
||||
publicHeader.PacketNumber,
|
||||
|
||||
@@ -398,6 +398,7 @@ var _ = Describe("Session", func() {
|
||||
handshake.TagICSL: {0, 0, 0, 0},
|
||||
})
|
||||
session.packer.connectionParametersManager = session.connectionParametersManager
|
||||
session.packer.sentPacketHandler = newMockSentPacketHandler()
|
||||
session.Run() // Would normally not return
|
||||
Expect(conn.written[0]).To(ContainSubstring("No recent network activity."))
|
||||
close(done)
|
||||
@@ -422,9 +423,9 @@ var _ = Describe("Session", func() {
|
||||
session.QueueStreamFrame(&frames.StreamFrame{})
|
||||
session.sendPacket()
|
||||
Expect(cong.nCalls).To(Equal(2)) // OnPacketSent + GetCongestionWindow
|
||||
Expect(cong.argsOnPacketSent[1]).To(Equal(protocol.ByteCount(35)))
|
||||
Expect(cong.argsOnPacketSent[1]).To(Equal(protocol.ByteCount(30)))
|
||||
Expect(cong.argsOnPacketSent[2]).To(Equal(protocol.PacketNumber(1)))
|
||||
Expect(cong.argsOnPacketSent[3]).To(Equal(protocol.ByteCount(35)))
|
||||
Expect(cong.argsOnPacketSent[3]).To(Equal(protocol.ByteCount(30)))
|
||||
Expect(cong.argsOnPacketSent[4]).To(BeTrue())
|
||||
})
|
||||
|
||||
|
||||
Reference in New Issue
Block a user