Merge pull request #1698 from lucas-clemente/coalesced-packets

implement parsing of coalesced packets
This commit is contained in:
Marten Seemann
2019-01-01 10:03:18 +07:00
committed by GitHub
12 changed files with 377 additions and 220 deletions

View File

@@ -6,22 +6,52 @@ import (
"github.com/lucas-clemente/quic-go/internal/protocol"
)
var bufferPool sync.Pool
type packetBuffer struct {
Slice []byte
func getPacketBuffer() *[]byte {
return bufferPool.Get().(*[]byte)
// refCount counts how many packets the Slice is used in.
// It doesn't support concurrent use.
// It is > 1 when used for coalesced packet.
refCount int
}
func putPacketBuffer(buf *[]byte) {
if cap(*buf) != int(protocol.MaxReceivePacketSize) {
// Split increases the refCount.
// It must be called when a packet buffer is used for more than one packet,
// e.g. when splitting coalesced packets.
func (b *packetBuffer) Split() {
b.refCount++
}
// Release decreases the refCount.
// It should be called when processing the packet is finished.
// When the refCount reaches 0, the packet buffer is put back into the pool.
func (b *packetBuffer) Release() {
if cap(b.Slice) != int(protocol.MaxReceivePacketSize) {
panic("putPacketBuffer called with packet of wrong size!")
}
bufferPool.Put(buf)
b.refCount--
if b.refCount < 0 {
panic("negative packetBuffer refCount")
}
// only put the packetBuffer back if it's not used any more
if b.refCount == 0 {
bufferPool.Put(b)
}
}
var bufferPool sync.Pool
func getPacketBuffer() *packetBuffer {
buf := bufferPool.Get().(*packetBuffer)
buf.refCount = 1
buf.Slice = buf.Slice[:protocol.MaxReceivePacketSize]
return buf
}
func init() {
bufferPool.New = func() interface{} {
b := make([]byte, 0, protocol.MaxReceivePacketSize)
return &b
return &packetBuffer{
Slice: make([]byte, 0, protocol.MaxReceivePacketSize),
}
}
}

View File

@@ -9,13 +9,35 @@ import (
var _ = Describe("Buffer Pool", func() {
It("returns buffers of cap", func() {
buf := *getPacketBuffer()
Expect(buf).To(HaveCap(int(protocol.MaxReceivePacketSize)))
buf := getPacketBuffer()
Expect(buf.Slice).To(HaveCap(int(protocol.MaxReceivePacketSize)))
})
It("releases buffers", func() {
buf := getPacketBuffer()
buf.Release()
})
It("panics if wrong-sized buffers are passed", func() {
Expect(func() {
putPacketBuffer(&[]byte{0})
}).To(Panic())
buf := getPacketBuffer()
buf.Slice = make([]byte, 10)
Expect(func() { buf.Release() }).To(Panic())
})
It("panics if it is released twice", func() {
buf := getPacketBuffer()
buf.Release()
Expect(func() { buf.Release() }).To(Panic())
})
It("waits until all parts have been released", func() {
buf := getPacketBuffer()
buf.Split()
buf.Split()
// now we have 3 parts
buf.Release()
buf.Release()
buf.Release()
Expect(func() { buf.Release() }).To(Panic())
})
})

View File

@@ -144,8 +144,8 @@ func (h *packetHandlerMap) close(e error) error {
func (h *packetHandlerMap) listen() {
for {
data := *getPacketBuffer()
data = data[:protocol.MaxReceivePacketSize]
buffer := getPacketBuffer()
data := buffer.Slice
// The packet size should not exceed protocol.MaxReceivePacketSize bytes
// If it does, we only read a truncated packet, which will then end up undecryptable
n, addr, err := h.conn.ReadFrom(data)
@@ -153,55 +153,110 @@ func (h *packetHandlerMap) listen() {
h.close(err)
return
}
data = data[:n]
if err := h.handlePacket(addr, data); err != nil {
h.logger.Debugf("error handling packet from %s: %s", addr, err)
}
h.handlePacket(addr, buffer, data[:n])
}
}
func (h *packetHandlerMap) handlePacket(addr net.Addr, data []byte) error {
r := bytes.NewReader(data)
hdr, err := wire.ParseHeader(r, h.connIDLen)
// drop the packet if we can't parse the header
func (h *packetHandlerMap) handlePacket(
addr net.Addr,
buffer *packetBuffer,
data []byte,
) {
packets, err := h.parsePacket(addr, buffer, data)
if err != nil {
return fmt.Errorf("error parsing header: %s", err)
h.logger.Debugf("error parsing packets from %s: %s", addr, err)
// This is just the error from parsing the last packet.
// We still need to process the packets that were successfully parsed before.
}
p := &receivedPacket{
remoteAddr: addr,
hdr: hdr,
data: data,
rcvTime: time.Now(),
if len(packets) == 0 {
buffer.Release()
return
}
h.handleParsedPackets(packets)
}
func (h *packetHandlerMap) parsePacket(
addr net.Addr,
buffer *packetBuffer,
data []byte,
) ([]*receivedPacket, error) {
rcvTime := time.Now()
packets := make([]*receivedPacket, 0, 1)
var counter int
var lastConnID protocol.ConnectionID
for len(data) > 0 {
if counter > 0 && h.logger.Debug() {
h.logger.Debugf("Parsed a coalesced packet. Part %d: %d bytes", counter, len(packets[counter-1].data))
}
hdr, err := wire.ParseHeader(bytes.NewReader(data), h.connIDLen)
// drop the packet if we can't parse the header
if err != nil {
return packets, fmt.Errorf("error parsing header: %s", err)
}
if counter > 0 && !hdr.DestConnectionID.Equal(lastConnID) {
return packets, fmt.Errorf("coalesced packet has different destination connection ID: %s, expected %s", hdr.DestConnectionID, lastConnID)
}
lastConnID = hdr.DestConnectionID
var rest []byte
if hdr.IsLongHeader {
if protocol.ByteCount(len(data)) < hdr.Length {
return packets, fmt.Errorf("packet length (%d bytes) is smaller than the expected length (%d bytes)", len(data)-int(hdr.ParsedLen()), hdr.Length)
}
packetLen := int(hdr.ParsedLen() + hdr.Length)
rest = data[packetLen:]
data = data[:packetLen]
}
if counter > 0 {
buffer.Split()
}
counter++
packets = append(packets, &receivedPacket{
remoteAddr: addr,
hdr: hdr,
rcvTime: rcvTime,
data: data,
buffer: buffer,
})
data = rest
}
return packets, nil
}
func (h *packetHandlerMap) handleParsedPackets(packets []*receivedPacket) {
h.mutex.RLock()
defer h.mutex.RUnlock()
handlerEntry, handlerFound := h.handlers[string(hdr.DestConnectionID)]
// coalesced packets all have the same destination connection ID
handlerEntry, handlerFound := h.handlers[string(packets[0].hdr.DestConnectionID)]
if handlerFound { // existing session
handlerEntry.handler.handlePacket(p)
return nil
}
// No session found.
// This might be a stateless reset.
if !hdr.IsLongHeader {
if len(data) >= protocol.MinStatelessResetSize {
var token [16]byte
copy(token[:], data[len(data)-16:])
if sess, ok := h.resetTokens[token]; ok {
sess.destroy(errors.New("received a stateless reset"))
return nil
}
for _, p := range packets {
if handlerFound { // existing session
handlerEntry.handler.handlePacket(p)
continue
}
// TODO(#943): send a stateless reset
return fmt.Errorf("received a short header packet with an unexpected connection ID %s", hdr.DestConnectionID)
// No session found.
// This might be a stateless reset.
if !p.hdr.IsLongHeader {
if len(p.data) >= protocol.MinStatelessResetSize {
var token [16]byte
copy(token[:], p.data[len(p.data)-16:])
if sess, ok := h.resetTokens[token]; ok {
sess.destroy(errors.New("received a stateless reset"))
continue
}
}
// TODO(#943): send a stateless reset
h.logger.Debugf("received a short header packet with an unexpected connection ID %s", p.hdr.DestConnectionID)
break // a short header packet is always the last in a coalesced packet
}
if h.server != nil { // no server set
h.server.handlePacket(p)
}
h.logger.Debugf("received a packet with an unexpected connection ID %s", p.hdr.DestConnectionID)
}
if h.server == nil { // no server set
return fmt.Errorf("received a packet with an unexpected connection ID %s", hdr.DestConnectionID)
}
h.server.handlePacket(p)
return nil
}

View File

@@ -3,6 +3,7 @@ package quic
import (
"bytes"
"errors"
"net"
"time"
"github.com/golang/mock/gomock"
@@ -19,21 +20,25 @@ var _ = Describe("Packet Handler Map", func() {
conn *mockPacketConn
)
getPacket := func(connID protocol.ConnectionID) []byte {
getPacketWithLength := func(connID protocol.ConnectionID, length protocol.ByteCount) []byte {
buf := &bytes.Buffer{}
Expect((&wire.ExtendedHeader{
Header: wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeHandshake,
DestConnectionID: connID,
Length: 1,
Length: length,
Version: protocol.VersionTLS,
},
PacketNumberLen: protocol.PacketNumberLen1,
PacketNumberLen: protocol.PacketNumberLen2,
}).Write(buf, protocol.VersionWhatever)).To(Succeed())
return buf.Bytes()
}
getPacket := func(connID protocol.ConnectionID) []byte {
return getPacketWithLength(connID, 2)
}
BeforeEach(func() {
conn = newMockPacketConn()
handler = newPacketHandlerMap(conn, 5, utils.DefaultLogger).(*packetHandlerMap)
@@ -81,7 +86,7 @@ var _ = Describe("Packet Handler Map", func() {
})
It("drops unparseable packets", func() {
err := handler.handlePacket(nil, []byte{0, 1, 2, 3})
_, err := handler.parsePacket(nil, nil, []byte{0, 1, 2, 3})
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("error parsing header:"))
})
@@ -91,7 +96,8 @@ var _ = Describe("Packet Handler Map", func() {
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
handler.Add(connID, NewMockPacketHandler(mockCtrl))
handler.Remove(connID)
Expect(handler.handlePacket(nil, getPacket(connID))).To(MatchError("received a packet with an unexpected connection ID 0x0102030405060708"))
handler.handlePacket(nil, nil, getPacket(connID))
// don't EXPECT any calls to handlePacket of the MockPacketHandler
})
It("deletes retired session entries after a wait time", func() {
@@ -100,7 +106,8 @@ var _ = Describe("Packet Handler Map", func() {
handler.Add(connID, NewMockPacketHandler(mockCtrl))
handler.Retire(connID)
time.Sleep(scaleDuration(30 * time.Millisecond))
Expect(handler.handlePacket(nil, getPacket(connID))).To(MatchError("received a packet with an unexpected connection ID 0x0102030405060708"))
handler.handlePacket(nil, nil, getPacket(connID))
// don't EXPECT any calls to handlePacket of the MockPacketHandler
})
It("passes packets arriving late for closed sessions to that session", func() {
@@ -110,14 +117,12 @@ var _ = Describe("Packet Handler Map", func() {
packetHandler.EXPECT().handlePacket(gomock.Any())
handler.Add(connID, packetHandler)
handler.Retire(connID)
err := handler.handlePacket(nil, getPacket(connID))
Expect(err).ToNot(HaveOccurred())
handler.handlePacket(nil, nil, getPacket(connID))
})
It("drops packets for unknown receivers", func() {
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
err := handler.handlePacket(nil, getPacket(connID))
Expect(err).To(MatchError("received a packet with an unexpected connection ID 0x0102030405060708"))
handler.handlePacket(nil, nil, getPacket(connID))
})
It("closes the packet handlers when reading from the conn fails", func() {
@@ -131,6 +136,75 @@ var _ = Describe("Packet Handler Map", func() {
conn.Close()
Eventually(done).Should(BeClosed())
})
Context("coalesced packets", func() {
It("errors on packets that are smaller than the length in the packet header", func() {
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
data := append(getPacketWithLength(connID, 1000), make([]byte, 500-2 /* for packet number length */)...)
_, err := handler.parsePacket(nil, nil, data)
Expect(err).To(MatchError("packet length (500 bytes) is smaller than the expected length (1000 bytes)"))
})
It("cuts packets to the right length", func() {
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
data := append(getPacketWithLength(connID, 456), make([]byte, 1000)...)
packetHandler := NewMockPacketHandler(mockCtrl)
packetHandler.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
Expect(p.data).To(HaveLen(456 + int(p.hdr.ParsedLen())))
})
handler.Add(connID, packetHandler)
handler.handlePacket(nil, nil, data)
})
It("handles coalesced packets", func() {
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
packetHandler := NewMockPacketHandler(mockCtrl)
handledPackets := make(chan *receivedPacket, 3)
packetHandler.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
handledPackets <- p
}).Times(3)
handler.Add(connID, packetHandler)
buffer := getPacketBuffer()
packet := buffer.Slice[:0]
packet = append(packet, append(getPacketWithLength(connID, 10), make([]byte, 10-2 /* packet number len */)...)...)
packet = append(packet, append(getPacketWithLength(connID, 20), make([]byte, 20-2 /* packet number len */)...)...)
packet = append(packet, append(getPacketWithLength(connID, 30), make([]byte, 30-2 /* packet number len */)...)...)
conn.dataToRead <- packet
now := time.Now()
for i := 1; i <= 3; i++ {
var p *receivedPacket
Eventually(handledPackets).Should(Receive(&p))
Expect(p.hdr.DestConnectionID).To(Equal(connID))
Expect(p.hdr.Length).To(BeEquivalentTo(10 * i))
Expect(p.data).To(HaveLen(int(p.hdr.ParsedLen() + p.hdr.Length)))
Expect(p.rcvTime).To(BeTemporally("~", now, scaleDuration(20*time.Millisecond)))
Expect(p.buffer.refCount).To(Equal(3))
}
// makes the listen go routine return
packetHandler.EXPECT().destroy(gomock.Any()).AnyTimes()
close(conn.dataToRead)
})
It("ignores coalesced packet parts if the connection IDs don't match", func() {
connID1 := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
connID2 := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}
buffer := getPacketBuffer()
packet := buffer.Slice[:0]
// var packet []byte
packet = append(packet, getPacket(connID1)...)
packet = append(packet, getPacket(connID2)...)
packets, err := handler.parsePacket(&net.UDPAddr{}, buffer, packet)
Expect(err).To(MatchError("coalesced packet has different destination connection ID: 0x0807060504030201, expected 0x0102030405060708"))
Expect(packets).To(HaveLen(1))
Expect(packets[0].hdr.DestConnectionID).To(Equal(connID1))
Expect(packets[0].buffer.refCount).To(Equal(1))
})
})
})
Context("stateless reset handling", func() {
@@ -164,6 +238,24 @@ var _ = Describe("Packet Handler Map", func() {
Eventually(destroyed).Should(BeClosed())
})
It("detects a stateless that is coalesced with another packet", func() {
packetHandler := NewMockPacketHandler(mockCtrl)
connID := protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}
token := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
handler.AddWithResetToken(connID, packetHandler, token)
fakeConnID := protocol.ConnectionID{1, 2, 3, 4, 5}
packet := getPacket(fakeConnID)
reset := append([]byte{0x40} /* short header packet */, fakeConnID...)
reset = append(reset, make([]byte, 50)...) // add some "random" data
reset = append(reset, token[:]...)
destroyed := make(chan struct{})
packetHandler.EXPECT().destroy(errors.New("received a stateless reset")).Do(func(error) {
close(destroyed)
})
conn.dataToRead <- append(packet, reset...)
Eventually(destroyed).Should(BeClosed())
})
It("deletes reset tokens when the session is retired", func() {
handler.deleteRetiredSessionsAfter = scaleDuration(10 * time.Millisecond)
connID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0x42}
@@ -171,10 +263,12 @@ var _ = Describe("Packet Handler Map", func() {
handler.AddWithResetToken(connID, NewMockPacketHandler(mockCtrl), token)
handler.Retire(connID)
time.Sleep(scaleDuration(30 * time.Millisecond))
Expect(handler.handlePacket(nil, getPacket(connID))).To(MatchError("received a packet with an unexpected connection ID 0xdeadbeef42"))
handler.handlePacket(nil, nil, getPacket(connID))
// don't EXPECT any calls to handlePacket of the MockPacketHandler
packet := append([]byte{0x40, 0xde, 0xca, 0xfb, 0xad, 0x99} /* short header packet */, make([]byte, 50)...)
packet = append(packet, token[:]...)
Expect(handler.handlePacket(nil, packet)).To(MatchError("received a short header packet with an unexpected connection ID 0xdecafbad99"))
handler.handlePacket(nil, nil, packet)
// don't EXPECT any calls to handlePacket of the MockPacketHandler
Expect(handler.resetTokens).To(BeEmpty())
})
})
@@ -188,7 +282,7 @@ var _ = Describe("Packet Handler Map", func() {
Expect(p.hdr.DestConnectionID).To(Equal(connID))
})
handler.SetServer(server)
Expect(handler.handlePacket(nil, p)).To(Succeed())
handler.handlePacket(nil, nil, p)
})
It("closes all server sessions", func() {
@@ -207,9 +301,10 @@ var _ = Describe("Packet Handler Map", func() {
connID := protocol.ConnectionID{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88}
p := getPacket(connID)
server := NewMockUnknownPacketHandler(mockCtrl)
// don't EXPECT any calls to server.handlePacket
handler.SetServer(server)
handler.CloseServer()
Expect(handler.handlePacket(nil, p)).To(MatchError("received a packet with an unexpected connection ID 0x1122334455667788"))
handler.handlePacket(nil, nil, p)
})
})
})

View File

@@ -25,10 +25,25 @@ type packer interface {
}
type packedPacket struct {
header *wire.ExtendedHeader
raw []byte
frames []wire.Frame
encryptionLevel protocol.EncryptionLevel
header *wire.ExtendedHeader
raw []byte
frames []wire.Frame
buffer *packetBuffer
}
func (p *packedPacket) EncryptionLevel() protocol.EncryptionLevel {
if !p.header.IsLongHeader {
return protocol.Encryption1RTT
}
switch p.header.Type {
case protocol.PacketTypeInitial:
return protocol.EncryptionInitial
case protocol.PacketTypeHandshake:
return protocol.EncryptionHandshake
default:
return protocol.EncryptionUnspecified
}
}
func (p *packedPacket) ToAckHandlerPacket() *ackhandler.Packet {
@@ -37,7 +52,7 @@ func (p *packedPacket) ToAckHandlerPacket() *ackhandler.Packet {
PacketType: p.header.Type,
Frames: p.frames,
Length: protocol.ByteCount(len(p.raw)),
EncryptionLevel: p.encryptionLevel,
EncryptionLevel: p.EncryptionLevel(),
SendTime: time.Now(),
}
}
@@ -136,13 +151,7 @@ func (p *packetPacker) PackConnectionClose(ccf *wire.ConnectionCloseFrame) (*pac
frames := []wire.Frame{ccf}
encLevel, sealer := p.cryptoSetup.GetSealer()
header := p.getHeader(encLevel)
raw, err := p.writeAndSealPacket(header, frames, sealer)
return &packedPacket{
header: header,
raw: raw,
frames: frames,
encryptionLevel: encLevel,
}, err
return p.writeAndSealPacket(header, frames, sealer)
}
func (p *packetPacker) MaybePackAckPacket() (*packedPacket, error) {
@@ -154,13 +163,7 @@ func (p *packetPacker) MaybePackAckPacket() (*packedPacket, error) {
encLevel, sealer := p.cryptoSetup.GetSealer()
header := p.getHeader(encLevel)
frames := []wire.Frame{ack}
raw, err := p.writeAndSealPacket(header, frames, sealer)
return &packedPacket{
header: header,
raw: raw,
frames: frames,
encryptionLevel: encLevel,
}, err
return p.writeAndSealPacket(header, frames, sealer)
}
// PackRetransmission packs a retransmission
@@ -227,16 +230,11 @@ func (p *packetPacker) PackRetransmission(packet *ackhandler.Packet) ([]*packedP
if sf, ok := frames[len(frames)-1].(*wire.StreamFrame); ok {
sf.DataLenPresent = false
}
raw, err := p.writeAndSealPacket(header, frames, sealer)
p, err := p.writeAndSealPacket(header, frames, sealer)
if err != nil {
return nil, err
}
packets = append(packets, &packedPacket{
header: header,
raw: raw,
frames: frames,
encryptionLevel: encLevel,
})
packets = append(packets, p)
}
return packets, nil
}
@@ -281,16 +279,7 @@ func (p *packetPacker) PackPacket() (*packedPacket, error) {
p.numNonRetransmittableAcks = 0
}
raw, err := p.writeAndSealPacket(header, frames, sealer)
if err != nil {
return nil, err
}
return &packedPacket{
header: header,
raw: raw,
frames: frames,
encryptionLevel: encLevel,
}, nil
return p.writeAndSealPacket(header, frames, sealer)
}
func (p *packetPacker) maybePackCryptoPacket() (*packedPacket, error) {
@@ -320,16 +309,7 @@ func (p *packetPacker) maybePackCryptoPacket() (*packedPacket, error) {
}
cf := s.PopCryptoFrame(p.maxPacketSize - hdrLen - protocol.ByteCount(sealer.Overhead()) - length)
frames = append(frames, cf)
raw, err := p.writeAndSealPacket(hdr, frames, sealer)
if err != nil {
return nil, err
}
return &packedPacket{
header: hdr,
raw: raw,
frames: frames,
encryptionLevel: encLevel,
}, nil
return p.writeAndSealPacket(hdr, frames, sealer)
}
func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount) ([]wire.Frame, error) {
@@ -395,9 +375,9 @@ func (p *packetPacker) writeAndSealPacket(
header *wire.ExtendedHeader,
frames []wire.Frame,
sealer handshake.Sealer,
) ([]byte, error) {
raw := *getPacketBuffer()
buffer := bytes.NewBuffer(raw[:0])
) (*packedPacket, error) {
packetBuffer := getPacketBuffer()
buffer := bytes.NewBuffer(packetBuffer.Slice[:0])
addPaddingForInitial := p.perspective == protocol.PerspectiveClient && header.Type == protocol.PacketTypeInitial
@@ -458,7 +438,7 @@ func (p *packetPacker) writeAndSealPacket(
return nil, fmt.Errorf("PacketPacker BUG: packet too large (%d bytes, allowed %d bytes)", size, p.maxPacketSize)
}
raw = raw[0:buffer.Len()]
raw := buffer.Bytes()
_ = sealer.Seal(raw[payloadOffset:payloadOffset], raw[payloadOffset:], header.PacketNumber, raw[:payloadOffset])
raw = raw[0 : buffer.Len()+sealer.Overhead()]
@@ -473,7 +453,12 @@ func (p *packetPacker) writeAndSealPacket(
if num != header.PacketNumber {
return nil, errors.New("packetPacker BUG: Peeked and Popped packet numbers do not match")
}
return raw, nil
return &packedPacket{
header: header,
raw: raw,
frames: frames,
buffer: packetBuffer,
}, nil
}
func (p *packetPacker) ChangeDestConnectionID(connID protocol.ConnectionID) {

View File

@@ -253,7 +253,7 @@ var _ = Describe("Packet packer", func() {
})
p, err := packer.PackPacket()
Expect(err).ToNot(HaveOccurred())
Expect(p.encryptionLevel).To(Equal(protocol.Encryption1RTT))
Expect(p.EncryptionLevel()).To(Equal(protocol.Encryption1RTT))
})
It("packs a single ACK", func() {
@@ -494,7 +494,7 @@ var _ = Describe("Packet packer", func() {
Expect(err).ToNot(HaveOccurred())
Expect(packets).To(HaveLen(1))
p := packets[0]
Expect(p.encryptionLevel).To(Equal(protocol.Encryption1RTT))
Expect(p.EncryptionLevel()).To(Equal(protocol.Encryption1RTT))
Expect(p.frames).To(Equal(frames))
})
@@ -846,7 +846,7 @@ var _ = Describe("Packet packer", func() {
Expect(p).To(HaveLen(1))
Expect(p[0].header.Type).To(Equal(protocol.PacketTypeInitial))
Expect(p[0].frames).To(Equal([]wire.Frame{f}))
Expect(p[0].encryptionLevel).To(Equal(protocol.EncryptionInitial))
Expect(p[0].EncryptionLevel()).To(Equal(protocol.EncryptionInitial))
})
It("packs a retransmission for an Initial packet", func() {
@@ -864,7 +864,7 @@ var _ = Describe("Packet packer", func() {
Expect(packets).To(HaveLen(1))
p := packets[0]
Expect(p.frames).To(Equal([]wire.Frame{sf}))
Expect(p.encryptionLevel).To(Equal(protocol.EncryptionInitial))
Expect(p.EncryptionLevel()).To(Equal(protocol.EncryptionInitial))
Expect(p.header.Type).To(Equal(protocol.PacketTypeInitial))
Expect(p.header.Token).To(Equal(token))
Expect(p.raw).To(HaveLen(protocol.MinInitialPacketSize))

View File

@@ -39,14 +39,6 @@ func newPacketUnpacker(cs handshake.CryptoSetup, version protocol.VersionNumber)
func (u *packetUnpacker) Unpack(hdr *wire.Header, data []byte) (*unpackedPacket, error) {
r := bytes.NewReader(data)
if hdr.IsLongHeader {
if protocol.ByteCount(r.Len()) < hdr.Length {
return nil, fmt.Errorf("packet length (%d bytes) is smaller than the expected length (%d bytes)", len(data)-int(hdr.ParsedLen()), hdr.Length)
}
data = data[:int(hdr.ParsedLen()+hdr.Length)]
// TODO(#1312): implement parsing of compound packets
}
var encLevel protocol.EncryptionLevel
switch hdr.Type {
case protocol.PacketTypeInitial:
@@ -93,11 +85,7 @@ func (u *packetUnpacker) Unpack(hdr *wire.Header, data []byte) (*unpackedPacket,
extHdr.PacketNumber,
)
buf := *getPacketBuffer()
buf = buf[:0]
defer putPacketBuffer(&buf)
decrypted, err := opener.Open(buf, data, pn, extHdr.Raw)
decrypted, err := opener.Open(data[:0], data, pn, extHdr.Raw)
if err != nil {
return nil, err
}

View File

@@ -75,49 +75,6 @@ var _ = Describe("Packet Unpacker", func() {
Expect(packet.encryptionLevel).To(Equal(protocol.EncryptionInitial))
})
It("errors on packets that are smaller than the length in the packet header", func() {
extHdr := &wire.ExtendedHeader{
Header: wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeHandshake,
Length: 1000,
DestConnectionID: connID,
Version: version,
},
PacketNumberLen: protocol.PacketNumberLen2,
}
hdr, hdrRaw := getHeader(extHdr)
data := append(hdrRaw, make([]byte, 500-2 /* for packet number length */)...)
_, err := unpacker.Unpack(hdr, data)
Expect(err).To(MatchError("packet length (500 bytes) is smaller than the expected length (1000 bytes)"))
})
It("cuts packets to the right length", func() {
pnLen := protocol.PacketNumberLen2
extHdr := &wire.ExtendedHeader{
Header: wire.Header{
IsLongHeader: true,
DestConnectionID: connID,
Type: protocol.PacketTypeHandshake,
Length: 456,
Version: protocol.VersionTLS,
},
PacketNumberLen: pnLen,
}
payloadLen := 456 - int(pnLen)
hdr, hdrRaw := getHeader(extHdr)
data := append(hdrRaw, make([]byte, payloadLen)...)
opener := mocks.NewMockOpener(mockCtrl)
cs.EXPECT().GetOpener(protocol.EncryptionHandshake).Return(opener, nil)
opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any())
opener.EXPECT().Open(gomock.Any(), gomock.Any(), extHdr.PacketNumber, hdrRaw).DoAndReturn(func(_, payload []byte, _ protocol.PacketNumber, _ []byte) ([]byte, error) {
Expect(payload).To(HaveLen(payloadLen))
return []byte{0}, nil
})
_, err := unpacker.Unpack(hdr, data)
Expect(err).ToNot(HaveOccurred())
})
It("returns the error when getting the sealer fails", func() {
extHdr := &wire.ExtendedHeader{
Header: wire.Header{DestConnectionID: connID},

View File

@@ -318,21 +318,27 @@ func (s *server) handlePacket(p *receivedPacket) {
}
if hdr.Type == protocol.PacketTypeInitial {
go s.handleInitial(p)
return
}
// TODO(#943): send Stateless Reset
p.buffer.Release()
}
func (s *server) handleInitial(p *receivedPacket) {
// TODO: add a check that DestConnID == SrcConnID
s.logger.Debugf("<- Received Initial packet.")
sess, connID, err := s.handleInitialImpl(p)
if err != nil {
p.buffer.Release()
s.logger.Errorf("Error occurred handling initial packet: %s", err)
return
}
if sess == nil { // a retry was done
p.buffer.Release()
return
}
// Don't put the packet buffer back if a new session was created.
// The session will handle the packet and take of that.
serverSession := newServerSession(sess, s.config, s.logger)
s.sessionHandler.Add(connID, serverSession)
}
@@ -455,6 +461,7 @@ func (s *server) sendRetry(remoteAddr net.Addr, hdr *wire.Header) error {
}
func (s *server) sendVersionNegotiationPacket(p *receivedPacket) {
defer p.buffer.Release()
hdr := p.hdr
s.logger.Debugf("Client offered version %s, sending Version Negotiation", hdr.Version)
data, err := wire.ComposeVersionNegotiation(hdr.SrcConnectionID, hdr.DestConnectionID, s.config.Versions)

View File

@@ -122,19 +122,19 @@ var _ = Describe("Server", func() {
}
It("drops Initial packets with a too short connection ID", func() {
serv.handlePacket(&receivedPacket{
serv.handlePacket(insertPacketBuffer(&receivedPacket{
hdr: &wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeInitial,
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4},
Version: serv.config.Versions[0],
},
})
}))
Consistently(conn.dataWritten).ShouldNot(Receive())
})
It("drops too small Initial", func() {
serv.handlePacket(&receivedPacket{
serv.handlePacket(insertPacketBuffer(&receivedPacket{
hdr: &wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeInitial,
@@ -142,12 +142,12 @@ var _ = Describe("Server", func() {
Version: serv.config.Versions[0],
},
data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize-100),
})
}))
Consistently(conn.dataWritten).ShouldNot(Receive())
})
It("drops packets with a too short connection ID", func() {
serv.handlePacket(&receivedPacket{
serv.handlePacket(insertPacketBuffer(&receivedPacket{
hdr: &wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeInitial,
@@ -156,19 +156,19 @@ var _ = Describe("Server", func() {
Version: serv.config.Versions[0],
},
data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize),
})
}))
Consistently(conn.dataWritten).ShouldNot(Receive())
})
It("drops non-Initial packets", func() {
serv.logger.SetLogLevel(utils.LogLevelDebug)
serv.handlePacket(&receivedPacket{
serv.handlePacket(insertPacketBuffer(&receivedPacket{
hdr: &wire.Header{
Type: protocol.PacketTypeHandshake,
Version: serv.config.Versions[0],
},
data: []byte("invalid"),
})
}))
})
It("decodes the cookie from the Token field", func() {
@@ -185,7 +185,7 @@ var _ = Describe("Server", func() {
}
token, err := serv.cookieGenerator.NewToken(raddr, nil)
Expect(err).ToNot(HaveOccurred())
serv.handlePacket(&receivedPacket{
serv.handlePacket(insertPacketBuffer(&receivedPacket{
remoteAddr: raddr,
hdr: &wire.Header{
Type: protocol.PacketTypeInitial,
@@ -193,7 +193,7 @@ var _ = Describe("Server", func() {
Version: serv.config.Versions[0],
},
data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize),
})
}))
Eventually(done).Should(BeClosed())
})
@@ -209,7 +209,7 @@ var _ = Describe("Server", func() {
close(done)
return false
}
serv.handlePacket(&receivedPacket{
serv.handlePacket(insertPacketBuffer(&receivedPacket{
remoteAddr: raddr,
hdr: &wire.Header{
Type: protocol.PacketTypeInitial,
@@ -217,14 +217,14 @@ var _ = Describe("Server", func() {
Version: serv.config.Versions[0],
},
data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize),
})
}))
Eventually(done).Should(BeClosed())
})
It("sends a Version Negotiation Packet for unsupported versions", func() {
srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5}
destConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6}
serv.handlePacket(&receivedPacket{
serv.handlePacket(insertPacketBuffer(&receivedPacket{
remoteAddr: &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337},
hdr: &wire.Header{
IsLongHeader: true,
@@ -233,7 +233,7 @@ var _ = Describe("Server", func() {
DestConnectionID: destConnID,
Version: 0x42,
},
})
}))
var write mockPacketConnWrite
Eventually(conn.dataWritten).Should(Receive(&write))
Expect(write.to.String()).To(Equal("127.0.0.1:1337"))
@@ -253,11 +253,11 @@ var _ = Describe("Server", func() {
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
Version: protocol.VersionTLS,
}
serv.handleInitial(&receivedPacket{
serv.handleInitial(insertPacketBuffer(&receivedPacket{
remoteAddr: &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337},
hdr: hdr,
data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize),
})
}))
var write mockPacketConnWrite
Eventually(conn.dataWritten).Should(Receive(&write))
Expect(write.to.String()).To(Equal("127.0.0.1:1337"))
@@ -308,7 +308,7 @@ var _ = Describe("Server", func() {
done := make(chan struct{})
go func() {
defer GinkgoRecover()
serv.handlePacket(p)
serv.handlePacket(insertPacketBuffer(p))
// the Handshake packet is written by the session
Consistently(conn.dataWritten).ShouldNot(Receive())
close(done)

View File

@@ -53,8 +53,10 @@ type cryptoStreamHandler interface {
type receivedPacket struct {
remoteAddr net.Addr
hdr *wire.Header
data []byte
rcvTime time.Time
data []byte
buffer *packetBuffer
}
type closeError struct {
@@ -368,9 +370,6 @@ runLoop:
if wasProcessed := s.handlePacketImpl(p); !wasProcessed {
continue
}
// This is a bit unclean, but works properly, since the packet always
// begins with the public header and we never copy it.
// TODO: putPacketBuffer(&p.extHdr.Raw)
case <-s.handshakeCompleteChan:
s.handleHandshakeComplete()
}
@@ -475,6 +474,15 @@ func (s *session) handleHandshakeComplete() {
}
func (s *session) handlePacketImpl(p *receivedPacket) bool /* was the packet successfully processed */ {
var wasQueued bool
defer func() {
// Put back the packet buffer if the packet wasn't queued for later decryption.
if !wasQueued {
p.buffer.Release()
}
}()
// The server can change the source connection ID with the first Handshake packet.
// After this, all packets with a different source connection have to be ignored.
if s.receivedFirstPacket && p.hdr.IsLongHeader && !p.hdr.SrcConnectionID.Equal(s.destConnID) {
@@ -490,6 +498,7 @@ func (s *session) handlePacketImpl(p *receivedPacket) bool /* was the packet suc
// if the decryption failed, this might be a packet sent by an attacker
if err != nil {
if err == handshake.ErrOpenerNotYetAvailable {
wasQueued = true
s.tryQueueingUndecryptablePacket(p)
return false
}
@@ -953,7 +962,7 @@ func (s *session) sendPacket() (bool, error) {
}
func (s *session) sendPackedPacket(packet *packedPacket) error {
defer putPacketBuffer(&packet.raw)
defer packet.buffer.Release()
s.logPacket(packet)
return s.conn.Write(packet.raw)
}
@@ -976,7 +985,7 @@ func (s *session) logPacket(packet *packedPacket) {
// We don't need to allocate the slices for calling the format functions
return
}
s.logger.Debugf("-> Sending packet 0x%x (%d bytes) for connection %s, %s", packet.header.PacketNumber, len(packet.raw), s.srcConnID, packet.encryptionLevel)
s.logger.Debugf("-> Sending packet 0x%x (%d bytes) for connection %s, %s", packet.header.PacketNumber, len(packet.raw), s.srcConnID, packet.EncryptionLevel())
packet.header.Log(s.logger)
for _, frame := range packet.frames {
wire.LogFrame(s.logger, frame, true)

View File

@@ -61,6 +61,11 @@ func areSessionsRunning() bool {
return strings.Contains(b.String(), "quic-go.(*session).run")
}
func insertPacketBuffer(p *receivedPacket) *receivedPacket {
p.buffer = getPacketBuffer()
return p
}
var _ = Describe("Session", func() {
var (
sess *session
@@ -496,11 +501,11 @@ var _ = Describe("Session", func() {
rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl)
rph.EXPECT().ReceivedPacket(protocol.PacketNumber(0x1337), rcvTime, false)
sess.receivedPacketHandler = rph
Expect(sess.handlePacketImpl(&receivedPacket{
Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{
rcvTime: rcvTime,
hdr: &hdr.Header,
data: getData(hdr),
})).To(BeTrue())
}))).To(BeTrue())
})
It("closes when handling a packet fails", func() {
@@ -518,7 +523,10 @@ var _ = Describe("Session", func() {
close(done)
}()
sessionRunner.EXPECT().retireConnectionID(gomock.Any())
sess.handlePacket(&receivedPacket{hdr: &wire.Header{}, data: getData(&wire.ExtendedHeader{PacketNumberLen: protocol.PacketNumberLen1})})
sess.handlePacket(insertPacketBuffer(&receivedPacket{
hdr: &wire.Header{},
data: getData(&wire.ExtendedHeader{PacketNumberLen: protocol.PacketNumberLen1}),
}))
Eventually(done).Should(BeClosed())
})
@@ -528,18 +536,18 @@ var _ = Describe("Session", func() {
PacketNumberLen: protocol.PacketNumberLen1,
}
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).Return(&unpackedPacket{hdr: hdr}, nil).Times(2)
Expect(sess.handlePacketImpl(&receivedPacket{hdr: &hdr.Header, data: getData(hdr)})).To(BeTrue())
Expect(sess.handlePacketImpl(&receivedPacket{hdr: &hdr.Header, data: getData(hdr)})).To(BeTrue())
Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{hdr: &hdr.Header, data: getData(hdr)}))).To(BeTrue())
Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{hdr: &hdr.Header, data: getData(hdr)}))).To(BeTrue())
})
It("ignores 0-RTT packets", func() {
Expect(sess.handlePacketImpl(&receivedPacket{
Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{
hdr: &wire.Header{
IsLongHeader: true,
Type: protocol.PacketType0RTT,
DestConnectionID: sess.srcConnID,
},
})).To(BeFalse())
}))).To(BeFalse())
})
It("ignores packets with a different source connection ID", func() {
@@ -552,12 +560,12 @@ var _ = Describe("Session", func() {
// Send one packet, which might change the connection ID.
// only EXPECT one call to the unpacker
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).Return(&unpackedPacket{hdr: &wire.ExtendedHeader{Header: *hdr}}, nil)
Expect(sess.handlePacketImpl(&receivedPacket{
Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{
hdr: hdr,
data: getData(&wire.ExtendedHeader{PacketNumberLen: protocol.PacketNumberLen1}),
})).To(BeTrue())
}))).To(BeTrue())
// The next packet has to be ignored, since the source connection ID doesn't match.
Expect(sess.handlePacketImpl(&receivedPacket{
Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{
hdr: &wire.Header{
IsLongHeader: true,
DestConnectionID: sess.destConnID,
@@ -565,7 +573,7 @@ var _ = Describe("Session", func() {
Length: 1,
},
data: getData(&wire.ExtendedHeader{PacketNumberLen: protocol.PacketNumberLen1}),
})).To(BeFalse())
}))).To(BeFalse())
})
Context("updating the remote address", func() {
@@ -574,12 +582,11 @@ var _ = Describe("Session", func() {
origAddr := sess.conn.(*mockConnection).remoteAddr
remoteIP := &net.IPAddr{IP: net.IPv4(192, 168, 0, 100)}
Expect(origAddr).ToNot(Equal(remoteIP))
p := receivedPacket{
Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{
remoteAddr: remoteIP,
hdr: &wire.Header{},
data: getData(&wire.ExtendedHeader{PacketNumberLen: protocol.PacketNumberLen1}),
}
Expect(sess.handlePacketImpl(&p)).To(BeTrue())
}))).To(BeTrue())
Expect(sess.conn.(*mockConnection).remoteAddr).To(Equal(origAddr))
})
})
@@ -587,10 +594,12 @@ var _ = Describe("Session", func() {
Context("sending packets", func() {
getPacket := func(pn protocol.PacketNumber) *packedPacket {
data := *getPacketBuffer()
buffer := getPacketBuffer()
data := buffer.Slice[:0]
data = append(data, []byte("foobar")...)
return &packedPacket{
raw: data,
buffer: buffer,
header: &wire.ExtendedHeader{PacketNumber: pn},
}
}
@@ -963,7 +972,7 @@ var _ = Describe("Session", func() {
defer close(done)
return &packedPacket{
header: &wire.ExtendedHeader{},
raw: *getPacketBuffer(),
buffer: getPacketBuffer(),
}, nil
}),
packer.EXPECT().PackPacket().AnyTimes(),
@@ -1352,7 +1361,7 @@ var _ = Describe("Client Session", func() {
}()
newConnID := protocol.ConnectionID{1, 3, 3, 7, 1, 3, 3, 7}
packer.EXPECT().ChangeDestConnectionID(newConnID)
Expect(sess.handlePacketImpl(&receivedPacket{
Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{
hdr: &wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeHandshake,
@@ -1361,7 +1370,7 @@ var _ = Describe("Client Session", func() {
Length: 1,
},
data: []byte{0},
})).To(BeTrue())
}))).To(BeTrue())
// make sure the go routine returns
packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil)
sessionRunner.EXPECT().retireConnectionID(gomock.Any())