forked from quic-go/quic-go
Merge pull request #1698 from lucas-clemente/coalesced-packets
implement parsing of coalesced packets
This commit is contained in:
@@ -6,22 +6,52 @@ import (
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
var bufferPool sync.Pool
|
||||
type packetBuffer struct {
|
||||
Slice []byte
|
||||
|
||||
func getPacketBuffer() *[]byte {
|
||||
return bufferPool.Get().(*[]byte)
|
||||
// refCount counts how many packets the Slice is used in.
|
||||
// It doesn't support concurrent use.
|
||||
// It is > 1 when used for coalesced packet.
|
||||
refCount int
|
||||
}
|
||||
|
||||
func putPacketBuffer(buf *[]byte) {
|
||||
if cap(*buf) != int(protocol.MaxReceivePacketSize) {
|
||||
// Split increases the refCount.
|
||||
// It must be called when a packet buffer is used for more than one packet,
|
||||
// e.g. when splitting coalesced packets.
|
||||
func (b *packetBuffer) Split() {
|
||||
b.refCount++
|
||||
}
|
||||
|
||||
// Release decreases the refCount.
|
||||
// It should be called when processing the packet is finished.
|
||||
// When the refCount reaches 0, the packet buffer is put back into the pool.
|
||||
func (b *packetBuffer) Release() {
|
||||
if cap(b.Slice) != int(protocol.MaxReceivePacketSize) {
|
||||
panic("putPacketBuffer called with packet of wrong size!")
|
||||
}
|
||||
bufferPool.Put(buf)
|
||||
b.refCount--
|
||||
if b.refCount < 0 {
|
||||
panic("negative packetBuffer refCount")
|
||||
}
|
||||
// only put the packetBuffer back if it's not used any more
|
||||
if b.refCount == 0 {
|
||||
bufferPool.Put(b)
|
||||
}
|
||||
}
|
||||
|
||||
var bufferPool sync.Pool
|
||||
|
||||
func getPacketBuffer() *packetBuffer {
|
||||
buf := bufferPool.Get().(*packetBuffer)
|
||||
buf.refCount = 1
|
||||
buf.Slice = buf.Slice[:protocol.MaxReceivePacketSize]
|
||||
return buf
|
||||
}
|
||||
|
||||
func init() {
|
||||
bufferPool.New = func() interface{} {
|
||||
b := make([]byte, 0, protocol.MaxReceivePacketSize)
|
||||
return &b
|
||||
return &packetBuffer{
|
||||
Slice: make([]byte, 0, protocol.MaxReceivePacketSize),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,13 +9,35 @@ import (
|
||||
|
||||
var _ = Describe("Buffer Pool", func() {
|
||||
It("returns buffers of cap", func() {
|
||||
buf := *getPacketBuffer()
|
||||
Expect(buf).To(HaveCap(int(protocol.MaxReceivePacketSize)))
|
||||
buf := getPacketBuffer()
|
||||
Expect(buf.Slice).To(HaveCap(int(protocol.MaxReceivePacketSize)))
|
||||
})
|
||||
|
||||
It("releases buffers", func() {
|
||||
buf := getPacketBuffer()
|
||||
buf.Release()
|
||||
})
|
||||
|
||||
It("panics if wrong-sized buffers are passed", func() {
|
||||
Expect(func() {
|
||||
putPacketBuffer(&[]byte{0})
|
||||
}).To(Panic())
|
||||
buf := getPacketBuffer()
|
||||
buf.Slice = make([]byte, 10)
|
||||
Expect(func() { buf.Release() }).To(Panic())
|
||||
})
|
||||
|
||||
It("panics if it is released twice", func() {
|
||||
buf := getPacketBuffer()
|
||||
buf.Release()
|
||||
Expect(func() { buf.Release() }).To(Panic())
|
||||
})
|
||||
|
||||
It("waits until all parts have been released", func() {
|
||||
buf := getPacketBuffer()
|
||||
buf.Split()
|
||||
buf.Split()
|
||||
// now we have 3 parts
|
||||
buf.Release()
|
||||
buf.Release()
|
||||
buf.Release()
|
||||
Expect(func() { buf.Release() }).To(Panic())
|
||||
})
|
||||
})
|
||||
|
||||
@@ -144,8 +144,8 @@ func (h *packetHandlerMap) close(e error) error {
|
||||
|
||||
func (h *packetHandlerMap) listen() {
|
||||
for {
|
||||
data := *getPacketBuffer()
|
||||
data = data[:protocol.MaxReceivePacketSize]
|
||||
buffer := getPacketBuffer()
|
||||
data := buffer.Slice
|
||||
// The packet size should not exceed protocol.MaxReceivePacketSize bytes
|
||||
// If it does, we only read a truncated packet, which will then end up undecryptable
|
||||
n, addr, err := h.conn.ReadFrom(data)
|
||||
@@ -153,55 +153,110 @@ func (h *packetHandlerMap) listen() {
|
||||
h.close(err)
|
||||
return
|
||||
}
|
||||
data = data[:n]
|
||||
|
||||
if err := h.handlePacket(addr, data); err != nil {
|
||||
h.logger.Debugf("error handling packet from %s: %s", addr, err)
|
||||
}
|
||||
h.handlePacket(addr, buffer, data[:n])
|
||||
}
|
||||
}
|
||||
|
||||
func (h *packetHandlerMap) handlePacket(addr net.Addr, data []byte) error {
|
||||
r := bytes.NewReader(data)
|
||||
hdr, err := wire.ParseHeader(r, h.connIDLen)
|
||||
// drop the packet if we can't parse the header
|
||||
func (h *packetHandlerMap) handlePacket(
|
||||
addr net.Addr,
|
||||
buffer *packetBuffer,
|
||||
data []byte,
|
||||
) {
|
||||
packets, err := h.parsePacket(addr, buffer, data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error parsing header: %s", err)
|
||||
h.logger.Debugf("error parsing packets from %s: %s", addr, err)
|
||||
// This is just the error from parsing the last packet.
|
||||
// We still need to process the packets that were successfully parsed before.
|
||||
}
|
||||
|
||||
p := &receivedPacket{
|
||||
remoteAddr: addr,
|
||||
hdr: hdr,
|
||||
data: data,
|
||||
rcvTime: time.Now(),
|
||||
if len(packets) == 0 {
|
||||
buffer.Release()
|
||||
return
|
||||
}
|
||||
h.handleParsedPackets(packets)
|
||||
}
|
||||
|
||||
func (h *packetHandlerMap) parsePacket(
|
||||
addr net.Addr,
|
||||
buffer *packetBuffer,
|
||||
data []byte,
|
||||
) ([]*receivedPacket, error) {
|
||||
rcvTime := time.Now()
|
||||
packets := make([]*receivedPacket, 0, 1)
|
||||
|
||||
var counter int
|
||||
var lastConnID protocol.ConnectionID
|
||||
for len(data) > 0 {
|
||||
if counter > 0 && h.logger.Debug() {
|
||||
h.logger.Debugf("Parsed a coalesced packet. Part %d: %d bytes", counter, len(packets[counter-1].data))
|
||||
}
|
||||
|
||||
hdr, err := wire.ParseHeader(bytes.NewReader(data), h.connIDLen)
|
||||
// drop the packet if we can't parse the header
|
||||
if err != nil {
|
||||
return packets, fmt.Errorf("error parsing header: %s", err)
|
||||
}
|
||||
if counter > 0 && !hdr.DestConnectionID.Equal(lastConnID) {
|
||||
return packets, fmt.Errorf("coalesced packet has different destination connection ID: %s, expected %s", hdr.DestConnectionID, lastConnID)
|
||||
}
|
||||
lastConnID = hdr.DestConnectionID
|
||||
|
||||
var rest []byte
|
||||
if hdr.IsLongHeader {
|
||||
if protocol.ByteCount(len(data)) < hdr.Length {
|
||||
return packets, fmt.Errorf("packet length (%d bytes) is smaller than the expected length (%d bytes)", len(data)-int(hdr.ParsedLen()), hdr.Length)
|
||||
}
|
||||
packetLen := int(hdr.ParsedLen() + hdr.Length)
|
||||
rest = data[packetLen:]
|
||||
data = data[:packetLen]
|
||||
}
|
||||
|
||||
if counter > 0 {
|
||||
buffer.Split()
|
||||
}
|
||||
counter++
|
||||
packets = append(packets, &receivedPacket{
|
||||
remoteAddr: addr,
|
||||
hdr: hdr,
|
||||
rcvTime: rcvTime,
|
||||
data: data,
|
||||
buffer: buffer,
|
||||
})
|
||||
data = rest
|
||||
}
|
||||
return packets, nil
|
||||
}
|
||||
|
||||
func (h *packetHandlerMap) handleParsedPackets(packets []*receivedPacket) {
|
||||
h.mutex.RLock()
|
||||
defer h.mutex.RUnlock()
|
||||
|
||||
handlerEntry, handlerFound := h.handlers[string(hdr.DestConnectionID)]
|
||||
// coalesced packets all have the same destination connection ID
|
||||
handlerEntry, handlerFound := h.handlers[string(packets[0].hdr.DestConnectionID)]
|
||||
|
||||
if handlerFound { // existing session
|
||||
handlerEntry.handler.handlePacket(p)
|
||||
return nil
|
||||
}
|
||||
// No session found.
|
||||
// This might be a stateless reset.
|
||||
if !hdr.IsLongHeader {
|
||||
if len(data) >= protocol.MinStatelessResetSize {
|
||||
var token [16]byte
|
||||
copy(token[:], data[len(data)-16:])
|
||||
if sess, ok := h.resetTokens[token]; ok {
|
||||
sess.destroy(errors.New("received a stateless reset"))
|
||||
return nil
|
||||
}
|
||||
for _, p := range packets {
|
||||
if handlerFound { // existing session
|
||||
handlerEntry.handler.handlePacket(p)
|
||||
continue
|
||||
}
|
||||
// TODO(#943): send a stateless reset
|
||||
return fmt.Errorf("received a short header packet with an unexpected connection ID %s", hdr.DestConnectionID)
|
||||
// No session found.
|
||||
// This might be a stateless reset.
|
||||
if !p.hdr.IsLongHeader {
|
||||
if len(p.data) >= protocol.MinStatelessResetSize {
|
||||
var token [16]byte
|
||||
copy(token[:], p.data[len(p.data)-16:])
|
||||
if sess, ok := h.resetTokens[token]; ok {
|
||||
sess.destroy(errors.New("received a stateless reset"))
|
||||
continue
|
||||
}
|
||||
}
|
||||
// TODO(#943): send a stateless reset
|
||||
h.logger.Debugf("received a short header packet with an unexpected connection ID %s", p.hdr.DestConnectionID)
|
||||
break // a short header packet is always the last in a coalesced packet
|
||||
|
||||
}
|
||||
if h.server != nil { // no server set
|
||||
h.server.handlePacket(p)
|
||||
}
|
||||
h.logger.Debugf("received a packet with an unexpected connection ID %s", p.hdr.DestConnectionID)
|
||||
}
|
||||
if h.server == nil { // no server set
|
||||
return fmt.Errorf("received a packet with an unexpected connection ID %s", hdr.DestConnectionID)
|
||||
}
|
||||
h.server.handlePacket(p)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package quic
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
@@ -19,21 +20,25 @@ var _ = Describe("Packet Handler Map", func() {
|
||||
conn *mockPacketConn
|
||||
)
|
||||
|
||||
getPacket := func(connID protocol.ConnectionID) []byte {
|
||||
getPacketWithLength := func(connID protocol.ConnectionID, length protocol.ByteCount) []byte {
|
||||
buf := &bytes.Buffer{}
|
||||
Expect((&wire.ExtendedHeader{
|
||||
Header: wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeHandshake,
|
||||
DestConnectionID: connID,
|
||||
Length: 1,
|
||||
Length: length,
|
||||
Version: protocol.VersionTLS,
|
||||
},
|
||||
PacketNumberLen: protocol.PacketNumberLen1,
|
||||
PacketNumberLen: protocol.PacketNumberLen2,
|
||||
}).Write(buf, protocol.VersionWhatever)).To(Succeed())
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
getPacket := func(connID protocol.ConnectionID) []byte {
|
||||
return getPacketWithLength(connID, 2)
|
||||
}
|
||||
|
||||
BeforeEach(func() {
|
||||
conn = newMockPacketConn()
|
||||
handler = newPacketHandlerMap(conn, 5, utils.DefaultLogger).(*packetHandlerMap)
|
||||
@@ -81,7 +86,7 @@ var _ = Describe("Packet Handler Map", func() {
|
||||
})
|
||||
|
||||
It("drops unparseable packets", func() {
|
||||
err := handler.handlePacket(nil, []byte{0, 1, 2, 3})
|
||||
_, err := handler.parsePacket(nil, nil, []byte{0, 1, 2, 3})
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("error parsing header:"))
|
||||
})
|
||||
@@ -91,7 +96,8 @@ var _ = Describe("Packet Handler Map", func() {
|
||||
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
|
||||
handler.Add(connID, NewMockPacketHandler(mockCtrl))
|
||||
handler.Remove(connID)
|
||||
Expect(handler.handlePacket(nil, getPacket(connID))).To(MatchError("received a packet with an unexpected connection ID 0x0102030405060708"))
|
||||
handler.handlePacket(nil, nil, getPacket(connID))
|
||||
// don't EXPECT any calls to handlePacket of the MockPacketHandler
|
||||
})
|
||||
|
||||
It("deletes retired session entries after a wait time", func() {
|
||||
@@ -100,7 +106,8 @@ var _ = Describe("Packet Handler Map", func() {
|
||||
handler.Add(connID, NewMockPacketHandler(mockCtrl))
|
||||
handler.Retire(connID)
|
||||
time.Sleep(scaleDuration(30 * time.Millisecond))
|
||||
Expect(handler.handlePacket(nil, getPacket(connID))).To(MatchError("received a packet with an unexpected connection ID 0x0102030405060708"))
|
||||
handler.handlePacket(nil, nil, getPacket(connID))
|
||||
// don't EXPECT any calls to handlePacket of the MockPacketHandler
|
||||
})
|
||||
|
||||
It("passes packets arriving late for closed sessions to that session", func() {
|
||||
@@ -110,14 +117,12 @@ var _ = Describe("Packet Handler Map", func() {
|
||||
packetHandler.EXPECT().handlePacket(gomock.Any())
|
||||
handler.Add(connID, packetHandler)
|
||||
handler.Retire(connID)
|
||||
err := handler.handlePacket(nil, getPacket(connID))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
handler.handlePacket(nil, nil, getPacket(connID))
|
||||
})
|
||||
|
||||
It("drops packets for unknown receivers", func() {
|
||||
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
|
||||
err := handler.handlePacket(nil, getPacket(connID))
|
||||
Expect(err).To(MatchError("received a packet with an unexpected connection ID 0x0102030405060708"))
|
||||
handler.handlePacket(nil, nil, getPacket(connID))
|
||||
})
|
||||
|
||||
It("closes the packet handlers when reading from the conn fails", func() {
|
||||
@@ -131,6 +136,75 @@ var _ = Describe("Packet Handler Map", func() {
|
||||
conn.Close()
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
Context("coalesced packets", func() {
|
||||
It("errors on packets that are smaller than the length in the packet header", func() {
|
||||
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
|
||||
data := append(getPacketWithLength(connID, 1000), make([]byte, 500-2 /* for packet number length */)...)
|
||||
_, err := handler.parsePacket(nil, nil, data)
|
||||
Expect(err).To(MatchError("packet length (500 bytes) is smaller than the expected length (1000 bytes)"))
|
||||
})
|
||||
|
||||
It("cuts packets to the right length", func() {
|
||||
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
|
||||
data := append(getPacketWithLength(connID, 456), make([]byte, 1000)...)
|
||||
packetHandler := NewMockPacketHandler(mockCtrl)
|
||||
packetHandler.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
|
||||
Expect(p.data).To(HaveLen(456 + int(p.hdr.ParsedLen())))
|
||||
})
|
||||
handler.Add(connID, packetHandler)
|
||||
handler.handlePacket(nil, nil, data)
|
||||
})
|
||||
|
||||
It("handles coalesced packets", func() {
|
||||
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
|
||||
packetHandler := NewMockPacketHandler(mockCtrl)
|
||||
handledPackets := make(chan *receivedPacket, 3)
|
||||
packetHandler.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
|
||||
handledPackets <- p
|
||||
}).Times(3)
|
||||
handler.Add(connID, packetHandler)
|
||||
|
||||
buffer := getPacketBuffer()
|
||||
packet := buffer.Slice[:0]
|
||||
packet = append(packet, append(getPacketWithLength(connID, 10), make([]byte, 10-2 /* packet number len */)...)...)
|
||||
packet = append(packet, append(getPacketWithLength(connID, 20), make([]byte, 20-2 /* packet number len */)...)...)
|
||||
packet = append(packet, append(getPacketWithLength(connID, 30), make([]byte, 30-2 /* packet number len */)...)...)
|
||||
conn.dataToRead <- packet
|
||||
|
||||
now := time.Now()
|
||||
for i := 1; i <= 3; i++ {
|
||||
var p *receivedPacket
|
||||
Eventually(handledPackets).Should(Receive(&p))
|
||||
Expect(p.hdr.DestConnectionID).To(Equal(connID))
|
||||
Expect(p.hdr.Length).To(BeEquivalentTo(10 * i))
|
||||
Expect(p.data).To(HaveLen(int(p.hdr.ParsedLen() + p.hdr.Length)))
|
||||
Expect(p.rcvTime).To(BeTemporally("~", now, scaleDuration(20*time.Millisecond)))
|
||||
Expect(p.buffer.refCount).To(Equal(3))
|
||||
}
|
||||
|
||||
// makes the listen go routine return
|
||||
packetHandler.EXPECT().destroy(gomock.Any()).AnyTimes()
|
||||
close(conn.dataToRead)
|
||||
})
|
||||
|
||||
It("ignores coalesced packet parts if the connection IDs don't match", func() {
|
||||
connID1 := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
|
||||
connID2 := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}
|
||||
|
||||
buffer := getPacketBuffer()
|
||||
packet := buffer.Slice[:0]
|
||||
// var packet []byte
|
||||
packet = append(packet, getPacket(connID1)...)
|
||||
packet = append(packet, getPacket(connID2)...)
|
||||
|
||||
packets, err := handler.parsePacket(&net.UDPAddr{}, buffer, packet)
|
||||
Expect(err).To(MatchError("coalesced packet has different destination connection ID: 0x0807060504030201, expected 0x0102030405060708"))
|
||||
Expect(packets).To(HaveLen(1))
|
||||
Expect(packets[0].hdr.DestConnectionID).To(Equal(connID1))
|
||||
Expect(packets[0].buffer.refCount).To(Equal(1))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Context("stateless reset handling", func() {
|
||||
@@ -164,6 +238,24 @@ var _ = Describe("Packet Handler Map", func() {
|
||||
Eventually(destroyed).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("detects a stateless that is coalesced with another packet", func() {
|
||||
packetHandler := NewMockPacketHandler(mockCtrl)
|
||||
connID := protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}
|
||||
token := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
|
||||
handler.AddWithResetToken(connID, packetHandler, token)
|
||||
fakeConnID := protocol.ConnectionID{1, 2, 3, 4, 5}
|
||||
packet := getPacket(fakeConnID)
|
||||
reset := append([]byte{0x40} /* short header packet */, fakeConnID...)
|
||||
reset = append(reset, make([]byte, 50)...) // add some "random" data
|
||||
reset = append(reset, token[:]...)
|
||||
destroyed := make(chan struct{})
|
||||
packetHandler.EXPECT().destroy(errors.New("received a stateless reset")).Do(func(error) {
|
||||
close(destroyed)
|
||||
})
|
||||
conn.dataToRead <- append(packet, reset...)
|
||||
Eventually(destroyed).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("deletes reset tokens when the session is retired", func() {
|
||||
handler.deleteRetiredSessionsAfter = scaleDuration(10 * time.Millisecond)
|
||||
connID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0x42}
|
||||
@@ -171,10 +263,12 @@ var _ = Describe("Packet Handler Map", func() {
|
||||
handler.AddWithResetToken(connID, NewMockPacketHandler(mockCtrl), token)
|
||||
handler.Retire(connID)
|
||||
time.Sleep(scaleDuration(30 * time.Millisecond))
|
||||
Expect(handler.handlePacket(nil, getPacket(connID))).To(MatchError("received a packet with an unexpected connection ID 0xdeadbeef42"))
|
||||
handler.handlePacket(nil, nil, getPacket(connID))
|
||||
// don't EXPECT any calls to handlePacket of the MockPacketHandler
|
||||
packet := append([]byte{0x40, 0xde, 0xca, 0xfb, 0xad, 0x99} /* short header packet */, make([]byte, 50)...)
|
||||
packet = append(packet, token[:]...)
|
||||
Expect(handler.handlePacket(nil, packet)).To(MatchError("received a short header packet with an unexpected connection ID 0xdecafbad99"))
|
||||
handler.handlePacket(nil, nil, packet)
|
||||
// don't EXPECT any calls to handlePacket of the MockPacketHandler
|
||||
Expect(handler.resetTokens).To(BeEmpty())
|
||||
})
|
||||
})
|
||||
@@ -188,7 +282,7 @@ var _ = Describe("Packet Handler Map", func() {
|
||||
Expect(p.hdr.DestConnectionID).To(Equal(connID))
|
||||
})
|
||||
handler.SetServer(server)
|
||||
Expect(handler.handlePacket(nil, p)).To(Succeed())
|
||||
handler.handlePacket(nil, nil, p)
|
||||
})
|
||||
|
||||
It("closes all server sessions", func() {
|
||||
@@ -207,9 +301,10 @@ var _ = Describe("Packet Handler Map", func() {
|
||||
connID := protocol.ConnectionID{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88}
|
||||
p := getPacket(connID)
|
||||
server := NewMockUnknownPacketHandler(mockCtrl)
|
||||
// don't EXPECT any calls to server.handlePacket
|
||||
handler.SetServer(server)
|
||||
handler.CloseServer()
|
||||
Expect(handler.handlePacket(nil, p)).To(MatchError("received a packet with an unexpected connection ID 0x1122334455667788"))
|
||||
handler.handlePacket(nil, nil, p)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -25,10 +25,25 @@ type packer interface {
|
||||
}
|
||||
|
||||
type packedPacket struct {
|
||||
header *wire.ExtendedHeader
|
||||
raw []byte
|
||||
frames []wire.Frame
|
||||
encryptionLevel protocol.EncryptionLevel
|
||||
header *wire.ExtendedHeader
|
||||
raw []byte
|
||||
frames []wire.Frame
|
||||
|
||||
buffer *packetBuffer
|
||||
}
|
||||
|
||||
func (p *packedPacket) EncryptionLevel() protocol.EncryptionLevel {
|
||||
if !p.header.IsLongHeader {
|
||||
return protocol.Encryption1RTT
|
||||
}
|
||||
switch p.header.Type {
|
||||
case protocol.PacketTypeInitial:
|
||||
return protocol.EncryptionInitial
|
||||
case protocol.PacketTypeHandshake:
|
||||
return protocol.EncryptionHandshake
|
||||
default:
|
||||
return protocol.EncryptionUnspecified
|
||||
}
|
||||
}
|
||||
|
||||
func (p *packedPacket) ToAckHandlerPacket() *ackhandler.Packet {
|
||||
@@ -37,7 +52,7 @@ func (p *packedPacket) ToAckHandlerPacket() *ackhandler.Packet {
|
||||
PacketType: p.header.Type,
|
||||
Frames: p.frames,
|
||||
Length: protocol.ByteCount(len(p.raw)),
|
||||
EncryptionLevel: p.encryptionLevel,
|
||||
EncryptionLevel: p.EncryptionLevel(),
|
||||
SendTime: time.Now(),
|
||||
}
|
||||
}
|
||||
@@ -136,13 +151,7 @@ func (p *packetPacker) PackConnectionClose(ccf *wire.ConnectionCloseFrame) (*pac
|
||||
frames := []wire.Frame{ccf}
|
||||
encLevel, sealer := p.cryptoSetup.GetSealer()
|
||||
header := p.getHeader(encLevel)
|
||||
raw, err := p.writeAndSealPacket(header, frames, sealer)
|
||||
return &packedPacket{
|
||||
header: header,
|
||||
raw: raw,
|
||||
frames: frames,
|
||||
encryptionLevel: encLevel,
|
||||
}, err
|
||||
return p.writeAndSealPacket(header, frames, sealer)
|
||||
}
|
||||
|
||||
func (p *packetPacker) MaybePackAckPacket() (*packedPacket, error) {
|
||||
@@ -154,13 +163,7 @@ func (p *packetPacker) MaybePackAckPacket() (*packedPacket, error) {
|
||||
encLevel, sealer := p.cryptoSetup.GetSealer()
|
||||
header := p.getHeader(encLevel)
|
||||
frames := []wire.Frame{ack}
|
||||
raw, err := p.writeAndSealPacket(header, frames, sealer)
|
||||
return &packedPacket{
|
||||
header: header,
|
||||
raw: raw,
|
||||
frames: frames,
|
||||
encryptionLevel: encLevel,
|
||||
}, err
|
||||
return p.writeAndSealPacket(header, frames, sealer)
|
||||
}
|
||||
|
||||
// PackRetransmission packs a retransmission
|
||||
@@ -227,16 +230,11 @@ func (p *packetPacker) PackRetransmission(packet *ackhandler.Packet) ([]*packedP
|
||||
if sf, ok := frames[len(frames)-1].(*wire.StreamFrame); ok {
|
||||
sf.DataLenPresent = false
|
||||
}
|
||||
raw, err := p.writeAndSealPacket(header, frames, sealer)
|
||||
p, err := p.writeAndSealPacket(header, frames, sealer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
packets = append(packets, &packedPacket{
|
||||
header: header,
|
||||
raw: raw,
|
||||
frames: frames,
|
||||
encryptionLevel: encLevel,
|
||||
})
|
||||
packets = append(packets, p)
|
||||
}
|
||||
return packets, nil
|
||||
}
|
||||
@@ -281,16 +279,7 @@ func (p *packetPacker) PackPacket() (*packedPacket, error) {
|
||||
p.numNonRetransmittableAcks = 0
|
||||
}
|
||||
|
||||
raw, err := p.writeAndSealPacket(header, frames, sealer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &packedPacket{
|
||||
header: header,
|
||||
raw: raw,
|
||||
frames: frames,
|
||||
encryptionLevel: encLevel,
|
||||
}, nil
|
||||
return p.writeAndSealPacket(header, frames, sealer)
|
||||
}
|
||||
|
||||
func (p *packetPacker) maybePackCryptoPacket() (*packedPacket, error) {
|
||||
@@ -320,16 +309,7 @@ func (p *packetPacker) maybePackCryptoPacket() (*packedPacket, error) {
|
||||
}
|
||||
cf := s.PopCryptoFrame(p.maxPacketSize - hdrLen - protocol.ByteCount(sealer.Overhead()) - length)
|
||||
frames = append(frames, cf)
|
||||
raw, err := p.writeAndSealPacket(hdr, frames, sealer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &packedPacket{
|
||||
header: hdr,
|
||||
raw: raw,
|
||||
frames: frames,
|
||||
encryptionLevel: encLevel,
|
||||
}, nil
|
||||
return p.writeAndSealPacket(hdr, frames, sealer)
|
||||
}
|
||||
|
||||
func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount) ([]wire.Frame, error) {
|
||||
@@ -395,9 +375,9 @@ func (p *packetPacker) writeAndSealPacket(
|
||||
header *wire.ExtendedHeader,
|
||||
frames []wire.Frame,
|
||||
sealer handshake.Sealer,
|
||||
) ([]byte, error) {
|
||||
raw := *getPacketBuffer()
|
||||
buffer := bytes.NewBuffer(raw[:0])
|
||||
) (*packedPacket, error) {
|
||||
packetBuffer := getPacketBuffer()
|
||||
buffer := bytes.NewBuffer(packetBuffer.Slice[:0])
|
||||
|
||||
addPaddingForInitial := p.perspective == protocol.PerspectiveClient && header.Type == protocol.PacketTypeInitial
|
||||
|
||||
@@ -458,7 +438,7 @@ func (p *packetPacker) writeAndSealPacket(
|
||||
return nil, fmt.Errorf("PacketPacker BUG: packet too large (%d bytes, allowed %d bytes)", size, p.maxPacketSize)
|
||||
}
|
||||
|
||||
raw = raw[0:buffer.Len()]
|
||||
raw := buffer.Bytes()
|
||||
_ = sealer.Seal(raw[payloadOffset:payloadOffset], raw[payloadOffset:], header.PacketNumber, raw[:payloadOffset])
|
||||
raw = raw[0 : buffer.Len()+sealer.Overhead()]
|
||||
|
||||
@@ -473,7 +453,12 @@ func (p *packetPacker) writeAndSealPacket(
|
||||
if num != header.PacketNumber {
|
||||
return nil, errors.New("packetPacker BUG: Peeked and Popped packet numbers do not match")
|
||||
}
|
||||
return raw, nil
|
||||
return &packedPacket{
|
||||
header: header,
|
||||
raw: raw,
|
||||
frames: frames,
|
||||
buffer: packetBuffer,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *packetPacker) ChangeDestConnectionID(connID protocol.ConnectionID) {
|
||||
|
||||
@@ -253,7 +253,7 @@ var _ = Describe("Packet packer", func() {
|
||||
})
|
||||
p, err := packer.PackPacket()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(p.encryptionLevel).To(Equal(protocol.Encryption1RTT))
|
||||
Expect(p.EncryptionLevel()).To(Equal(protocol.Encryption1RTT))
|
||||
})
|
||||
|
||||
It("packs a single ACK", func() {
|
||||
@@ -494,7 +494,7 @@ var _ = Describe("Packet packer", func() {
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(packets).To(HaveLen(1))
|
||||
p := packets[0]
|
||||
Expect(p.encryptionLevel).To(Equal(protocol.Encryption1RTT))
|
||||
Expect(p.EncryptionLevel()).To(Equal(protocol.Encryption1RTT))
|
||||
Expect(p.frames).To(Equal(frames))
|
||||
})
|
||||
|
||||
@@ -846,7 +846,7 @@ var _ = Describe("Packet packer", func() {
|
||||
Expect(p).To(HaveLen(1))
|
||||
Expect(p[0].header.Type).To(Equal(protocol.PacketTypeInitial))
|
||||
Expect(p[0].frames).To(Equal([]wire.Frame{f}))
|
||||
Expect(p[0].encryptionLevel).To(Equal(protocol.EncryptionInitial))
|
||||
Expect(p[0].EncryptionLevel()).To(Equal(protocol.EncryptionInitial))
|
||||
})
|
||||
|
||||
It("packs a retransmission for an Initial packet", func() {
|
||||
@@ -864,7 +864,7 @@ var _ = Describe("Packet packer", func() {
|
||||
Expect(packets).To(HaveLen(1))
|
||||
p := packets[0]
|
||||
Expect(p.frames).To(Equal([]wire.Frame{sf}))
|
||||
Expect(p.encryptionLevel).To(Equal(protocol.EncryptionInitial))
|
||||
Expect(p.EncryptionLevel()).To(Equal(protocol.EncryptionInitial))
|
||||
Expect(p.header.Type).To(Equal(protocol.PacketTypeInitial))
|
||||
Expect(p.header.Token).To(Equal(token))
|
||||
Expect(p.raw).To(HaveLen(protocol.MinInitialPacketSize))
|
||||
|
||||
@@ -39,14 +39,6 @@ func newPacketUnpacker(cs handshake.CryptoSetup, version protocol.VersionNumber)
|
||||
func (u *packetUnpacker) Unpack(hdr *wire.Header, data []byte) (*unpackedPacket, error) {
|
||||
r := bytes.NewReader(data)
|
||||
|
||||
if hdr.IsLongHeader {
|
||||
if protocol.ByteCount(r.Len()) < hdr.Length {
|
||||
return nil, fmt.Errorf("packet length (%d bytes) is smaller than the expected length (%d bytes)", len(data)-int(hdr.ParsedLen()), hdr.Length)
|
||||
}
|
||||
data = data[:int(hdr.ParsedLen()+hdr.Length)]
|
||||
// TODO(#1312): implement parsing of compound packets
|
||||
}
|
||||
|
||||
var encLevel protocol.EncryptionLevel
|
||||
switch hdr.Type {
|
||||
case protocol.PacketTypeInitial:
|
||||
@@ -93,11 +85,7 @@ func (u *packetUnpacker) Unpack(hdr *wire.Header, data []byte) (*unpackedPacket,
|
||||
extHdr.PacketNumber,
|
||||
)
|
||||
|
||||
buf := *getPacketBuffer()
|
||||
buf = buf[:0]
|
||||
defer putPacketBuffer(&buf)
|
||||
|
||||
decrypted, err := opener.Open(buf, data, pn, extHdr.Raw)
|
||||
decrypted, err := opener.Open(data[:0], data, pn, extHdr.Raw)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -75,49 +75,6 @@ var _ = Describe("Packet Unpacker", func() {
|
||||
Expect(packet.encryptionLevel).To(Equal(protocol.EncryptionInitial))
|
||||
})
|
||||
|
||||
It("errors on packets that are smaller than the length in the packet header", func() {
|
||||
extHdr := &wire.ExtendedHeader{
|
||||
Header: wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeHandshake,
|
||||
Length: 1000,
|
||||
DestConnectionID: connID,
|
||||
Version: version,
|
||||
},
|
||||
PacketNumberLen: protocol.PacketNumberLen2,
|
||||
}
|
||||
hdr, hdrRaw := getHeader(extHdr)
|
||||
data := append(hdrRaw, make([]byte, 500-2 /* for packet number length */)...)
|
||||
_, err := unpacker.Unpack(hdr, data)
|
||||
Expect(err).To(MatchError("packet length (500 bytes) is smaller than the expected length (1000 bytes)"))
|
||||
})
|
||||
|
||||
It("cuts packets to the right length", func() {
|
||||
pnLen := protocol.PacketNumberLen2
|
||||
extHdr := &wire.ExtendedHeader{
|
||||
Header: wire.Header{
|
||||
IsLongHeader: true,
|
||||
DestConnectionID: connID,
|
||||
Type: protocol.PacketTypeHandshake,
|
||||
Length: 456,
|
||||
Version: protocol.VersionTLS,
|
||||
},
|
||||
PacketNumberLen: pnLen,
|
||||
}
|
||||
payloadLen := 456 - int(pnLen)
|
||||
hdr, hdrRaw := getHeader(extHdr)
|
||||
data := append(hdrRaw, make([]byte, payloadLen)...)
|
||||
opener := mocks.NewMockOpener(mockCtrl)
|
||||
cs.EXPECT().GetOpener(protocol.EncryptionHandshake).Return(opener, nil)
|
||||
opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any())
|
||||
opener.EXPECT().Open(gomock.Any(), gomock.Any(), extHdr.PacketNumber, hdrRaw).DoAndReturn(func(_, payload []byte, _ protocol.PacketNumber, _ []byte) ([]byte, error) {
|
||||
Expect(payload).To(HaveLen(payloadLen))
|
||||
return []byte{0}, nil
|
||||
})
|
||||
_, err := unpacker.Unpack(hdr, data)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("returns the error when getting the sealer fails", func() {
|
||||
extHdr := &wire.ExtendedHeader{
|
||||
Header: wire.Header{DestConnectionID: connID},
|
||||
|
||||
@@ -318,21 +318,27 @@ func (s *server) handlePacket(p *receivedPacket) {
|
||||
}
|
||||
if hdr.Type == protocol.PacketTypeInitial {
|
||||
go s.handleInitial(p)
|
||||
return
|
||||
}
|
||||
|
||||
// TODO(#943): send Stateless Reset
|
||||
p.buffer.Release()
|
||||
}
|
||||
|
||||
func (s *server) handleInitial(p *receivedPacket) {
|
||||
// TODO: add a check that DestConnID == SrcConnID
|
||||
s.logger.Debugf("<- Received Initial packet.")
|
||||
sess, connID, err := s.handleInitialImpl(p)
|
||||
if err != nil {
|
||||
p.buffer.Release()
|
||||
s.logger.Errorf("Error occurred handling initial packet: %s", err)
|
||||
return
|
||||
}
|
||||
if sess == nil { // a retry was done
|
||||
p.buffer.Release()
|
||||
return
|
||||
}
|
||||
// Don't put the packet buffer back if a new session was created.
|
||||
// The session will handle the packet and take of that.
|
||||
serverSession := newServerSession(sess, s.config, s.logger)
|
||||
s.sessionHandler.Add(connID, serverSession)
|
||||
}
|
||||
@@ -455,6 +461,7 @@ func (s *server) sendRetry(remoteAddr net.Addr, hdr *wire.Header) error {
|
||||
}
|
||||
|
||||
func (s *server) sendVersionNegotiationPacket(p *receivedPacket) {
|
||||
defer p.buffer.Release()
|
||||
hdr := p.hdr
|
||||
s.logger.Debugf("Client offered version %s, sending Version Negotiation", hdr.Version)
|
||||
data, err := wire.ComposeVersionNegotiation(hdr.SrcConnectionID, hdr.DestConnectionID, s.config.Versions)
|
||||
|
||||
@@ -122,19 +122,19 @@ var _ = Describe("Server", func() {
|
||||
}
|
||||
|
||||
It("drops Initial packets with a too short connection ID", func() {
|
||||
serv.handlePacket(&receivedPacket{
|
||||
serv.handlePacket(insertPacketBuffer(&receivedPacket{
|
||||
hdr: &wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeInitial,
|
||||
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4},
|
||||
Version: serv.config.Versions[0],
|
||||
},
|
||||
})
|
||||
}))
|
||||
Consistently(conn.dataWritten).ShouldNot(Receive())
|
||||
})
|
||||
|
||||
It("drops too small Initial", func() {
|
||||
serv.handlePacket(&receivedPacket{
|
||||
serv.handlePacket(insertPacketBuffer(&receivedPacket{
|
||||
hdr: &wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeInitial,
|
||||
@@ -142,12 +142,12 @@ var _ = Describe("Server", func() {
|
||||
Version: serv.config.Versions[0],
|
||||
},
|
||||
data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize-100),
|
||||
})
|
||||
}))
|
||||
Consistently(conn.dataWritten).ShouldNot(Receive())
|
||||
})
|
||||
|
||||
It("drops packets with a too short connection ID", func() {
|
||||
serv.handlePacket(&receivedPacket{
|
||||
serv.handlePacket(insertPacketBuffer(&receivedPacket{
|
||||
hdr: &wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeInitial,
|
||||
@@ -156,19 +156,19 @@ var _ = Describe("Server", func() {
|
||||
Version: serv.config.Versions[0],
|
||||
},
|
||||
data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize),
|
||||
})
|
||||
}))
|
||||
Consistently(conn.dataWritten).ShouldNot(Receive())
|
||||
})
|
||||
|
||||
It("drops non-Initial packets", func() {
|
||||
serv.logger.SetLogLevel(utils.LogLevelDebug)
|
||||
serv.handlePacket(&receivedPacket{
|
||||
serv.handlePacket(insertPacketBuffer(&receivedPacket{
|
||||
hdr: &wire.Header{
|
||||
Type: protocol.PacketTypeHandshake,
|
||||
Version: serv.config.Versions[0],
|
||||
},
|
||||
data: []byte("invalid"),
|
||||
})
|
||||
}))
|
||||
})
|
||||
|
||||
It("decodes the cookie from the Token field", func() {
|
||||
@@ -185,7 +185,7 @@ var _ = Describe("Server", func() {
|
||||
}
|
||||
token, err := serv.cookieGenerator.NewToken(raddr, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
serv.handlePacket(&receivedPacket{
|
||||
serv.handlePacket(insertPacketBuffer(&receivedPacket{
|
||||
remoteAddr: raddr,
|
||||
hdr: &wire.Header{
|
||||
Type: protocol.PacketTypeInitial,
|
||||
@@ -193,7 +193,7 @@ var _ = Describe("Server", func() {
|
||||
Version: serv.config.Versions[0],
|
||||
},
|
||||
data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize),
|
||||
})
|
||||
}))
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
@@ -209,7 +209,7 @@ var _ = Describe("Server", func() {
|
||||
close(done)
|
||||
return false
|
||||
}
|
||||
serv.handlePacket(&receivedPacket{
|
||||
serv.handlePacket(insertPacketBuffer(&receivedPacket{
|
||||
remoteAddr: raddr,
|
||||
hdr: &wire.Header{
|
||||
Type: protocol.PacketTypeInitial,
|
||||
@@ -217,14 +217,14 @@ var _ = Describe("Server", func() {
|
||||
Version: serv.config.Versions[0],
|
||||
},
|
||||
data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize),
|
||||
})
|
||||
}))
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("sends a Version Negotiation Packet for unsupported versions", func() {
|
||||
srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5}
|
||||
destConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6}
|
||||
serv.handlePacket(&receivedPacket{
|
||||
serv.handlePacket(insertPacketBuffer(&receivedPacket{
|
||||
remoteAddr: &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337},
|
||||
hdr: &wire.Header{
|
||||
IsLongHeader: true,
|
||||
@@ -233,7 +233,7 @@ var _ = Describe("Server", func() {
|
||||
DestConnectionID: destConnID,
|
||||
Version: 0x42,
|
||||
},
|
||||
})
|
||||
}))
|
||||
var write mockPacketConnWrite
|
||||
Eventually(conn.dataWritten).Should(Receive(&write))
|
||||
Expect(write.to.String()).To(Equal("127.0.0.1:1337"))
|
||||
@@ -253,11 +253,11 @@ var _ = Describe("Server", func() {
|
||||
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
|
||||
Version: protocol.VersionTLS,
|
||||
}
|
||||
serv.handleInitial(&receivedPacket{
|
||||
serv.handleInitial(insertPacketBuffer(&receivedPacket{
|
||||
remoteAddr: &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337},
|
||||
hdr: hdr,
|
||||
data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize),
|
||||
})
|
||||
}))
|
||||
var write mockPacketConnWrite
|
||||
Eventually(conn.dataWritten).Should(Receive(&write))
|
||||
Expect(write.to.String()).To(Equal("127.0.0.1:1337"))
|
||||
@@ -308,7 +308,7 @@ var _ = Describe("Server", func() {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
serv.handlePacket(p)
|
||||
serv.handlePacket(insertPacketBuffer(p))
|
||||
// the Handshake packet is written by the session
|
||||
Consistently(conn.dataWritten).ShouldNot(Receive())
|
||||
close(done)
|
||||
|
||||
21
session.go
21
session.go
@@ -53,8 +53,10 @@ type cryptoStreamHandler interface {
|
||||
type receivedPacket struct {
|
||||
remoteAddr net.Addr
|
||||
hdr *wire.Header
|
||||
data []byte
|
||||
rcvTime time.Time
|
||||
data []byte
|
||||
|
||||
buffer *packetBuffer
|
||||
}
|
||||
|
||||
type closeError struct {
|
||||
@@ -368,9 +370,6 @@ runLoop:
|
||||
if wasProcessed := s.handlePacketImpl(p); !wasProcessed {
|
||||
continue
|
||||
}
|
||||
// This is a bit unclean, but works properly, since the packet always
|
||||
// begins with the public header and we never copy it.
|
||||
// TODO: putPacketBuffer(&p.extHdr.Raw)
|
||||
case <-s.handshakeCompleteChan:
|
||||
s.handleHandshakeComplete()
|
||||
}
|
||||
@@ -475,6 +474,15 @@ func (s *session) handleHandshakeComplete() {
|
||||
}
|
||||
|
||||
func (s *session) handlePacketImpl(p *receivedPacket) bool /* was the packet successfully processed */ {
|
||||
var wasQueued bool
|
||||
|
||||
defer func() {
|
||||
// Put back the packet buffer if the packet wasn't queued for later decryption.
|
||||
if !wasQueued {
|
||||
p.buffer.Release()
|
||||
}
|
||||
}()
|
||||
|
||||
// The server can change the source connection ID with the first Handshake packet.
|
||||
// After this, all packets with a different source connection have to be ignored.
|
||||
if s.receivedFirstPacket && p.hdr.IsLongHeader && !p.hdr.SrcConnectionID.Equal(s.destConnID) {
|
||||
@@ -490,6 +498,7 @@ func (s *session) handlePacketImpl(p *receivedPacket) bool /* was the packet suc
|
||||
// if the decryption failed, this might be a packet sent by an attacker
|
||||
if err != nil {
|
||||
if err == handshake.ErrOpenerNotYetAvailable {
|
||||
wasQueued = true
|
||||
s.tryQueueingUndecryptablePacket(p)
|
||||
return false
|
||||
}
|
||||
@@ -953,7 +962,7 @@ func (s *session) sendPacket() (bool, error) {
|
||||
}
|
||||
|
||||
func (s *session) sendPackedPacket(packet *packedPacket) error {
|
||||
defer putPacketBuffer(&packet.raw)
|
||||
defer packet.buffer.Release()
|
||||
s.logPacket(packet)
|
||||
return s.conn.Write(packet.raw)
|
||||
}
|
||||
@@ -976,7 +985,7 @@ func (s *session) logPacket(packet *packedPacket) {
|
||||
// We don't need to allocate the slices for calling the format functions
|
||||
return
|
||||
}
|
||||
s.logger.Debugf("-> Sending packet 0x%x (%d bytes) for connection %s, %s", packet.header.PacketNumber, len(packet.raw), s.srcConnID, packet.encryptionLevel)
|
||||
s.logger.Debugf("-> Sending packet 0x%x (%d bytes) for connection %s, %s", packet.header.PacketNumber, len(packet.raw), s.srcConnID, packet.EncryptionLevel())
|
||||
packet.header.Log(s.logger)
|
||||
for _, frame := range packet.frames {
|
||||
wire.LogFrame(s.logger, frame, true)
|
||||
|
||||
@@ -61,6 +61,11 @@ func areSessionsRunning() bool {
|
||||
return strings.Contains(b.String(), "quic-go.(*session).run")
|
||||
}
|
||||
|
||||
func insertPacketBuffer(p *receivedPacket) *receivedPacket {
|
||||
p.buffer = getPacketBuffer()
|
||||
return p
|
||||
}
|
||||
|
||||
var _ = Describe("Session", func() {
|
||||
var (
|
||||
sess *session
|
||||
@@ -496,11 +501,11 @@ var _ = Describe("Session", func() {
|
||||
rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl)
|
||||
rph.EXPECT().ReceivedPacket(protocol.PacketNumber(0x1337), rcvTime, false)
|
||||
sess.receivedPacketHandler = rph
|
||||
Expect(sess.handlePacketImpl(&receivedPacket{
|
||||
Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{
|
||||
rcvTime: rcvTime,
|
||||
hdr: &hdr.Header,
|
||||
data: getData(hdr),
|
||||
})).To(BeTrue())
|
||||
}))).To(BeTrue())
|
||||
})
|
||||
|
||||
It("closes when handling a packet fails", func() {
|
||||
@@ -518,7 +523,10 @@ var _ = Describe("Session", func() {
|
||||
close(done)
|
||||
}()
|
||||
sessionRunner.EXPECT().retireConnectionID(gomock.Any())
|
||||
sess.handlePacket(&receivedPacket{hdr: &wire.Header{}, data: getData(&wire.ExtendedHeader{PacketNumberLen: protocol.PacketNumberLen1})})
|
||||
sess.handlePacket(insertPacketBuffer(&receivedPacket{
|
||||
hdr: &wire.Header{},
|
||||
data: getData(&wire.ExtendedHeader{PacketNumberLen: protocol.PacketNumberLen1}),
|
||||
}))
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
@@ -528,18 +536,18 @@ var _ = Describe("Session", func() {
|
||||
PacketNumberLen: protocol.PacketNumberLen1,
|
||||
}
|
||||
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).Return(&unpackedPacket{hdr: hdr}, nil).Times(2)
|
||||
Expect(sess.handlePacketImpl(&receivedPacket{hdr: &hdr.Header, data: getData(hdr)})).To(BeTrue())
|
||||
Expect(sess.handlePacketImpl(&receivedPacket{hdr: &hdr.Header, data: getData(hdr)})).To(BeTrue())
|
||||
Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{hdr: &hdr.Header, data: getData(hdr)}))).To(BeTrue())
|
||||
Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{hdr: &hdr.Header, data: getData(hdr)}))).To(BeTrue())
|
||||
})
|
||||
|
||||
It("ignores 0-RTT packets", func() {
|
||||
Expect(sess.handlePacketImpl(&receivedPacket{
|
||||
Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{
|
||||
hdr: &wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketType0RTT,
|
||||
DestConnectionID: sess.srcConnID,
|
||||
},
|
||||
})).To(BeFalse())
|
||||
}))).To(BeFalse())
|
||||
})
|
||||
|
||||
It("ignores packets with a different source connection ID", func() {
|
||||
@@ -552,12 +560,12 @@ var _ = Describe("Session", func() {
|
||||
// Send one packet, which might change the connection ID.
|
||||
// only EXPECT one call to the unpacker
|
||||
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).Return(&unpackedPacket{hdr: &wire.ExtendedHeader{Header: *hdr}}, nil)
|
||||
Expect(sess.handlePacketImpl(&receivedPacket{
|
||||
Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{
|
||||
hdr: hdr,
|
||||
data: getData(&wire.ExtendedHeader{PacketNumberLen: protocol.PacketNumberLen1}),
|
||||
})).To(BeTrue())
|
||||
}))).To(BeTrue())
|
||||
// The next packet has to be ignored, since the source connection ID doesn't match.
|
||||
Expect(sess.handlePacketImpl(&receivedPacket{
|
||||
Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{
|
||||
hdr: &wire.Header{
|
||||
IsLongHeader: true,
|
||||
DestConnectionID: sess.destConnID,
|
||||
@@ -565,7 +573,7 @@ var _ = Describe("Session", func() {
|
||||
Length: 1,
|
||||
},
|
||||
data: getData(&wire.ExtendedHeader{PacketNumberLen: protocol.PacketNumberLen1}),
|
||||
})).To(BeFalse())
|
||||
}))).To(BeFalse())
|
||||
})
|
||||
|
||||
Context("updating the remote address", func() {
|
||||
@@ -574,12 +582,11 @@ var _ = Describe("Session", func() {
|
||||
origAddr := sess.conn.(*mockConnection).remoteAddr
|
||||
remoteIP := &net.IPAddr{IP: net.IPv4(192, 168, 0, 100)}
|
||||
Expect(origAddr).ToNot(Equal(remoteIP))
|
||||
p := receivedPacket{
|
||||
Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{
|
||||
remoteAddr: remoteIP,
|
||||
hdr: &wire.Header{},
|
||||
data: getData(&wire.ExtendedHeader{PacketNumberLen: protocol.PacketNumberLen1}),
|
||||
}
|
||||
Expect(sess.handlePacketImpl(&p)).To(BeTrue())
|
||||
}))).To(BeTrue())
|
||||
Expect(sess.conn.(*mockConnection).remoteAddr).To(Equal(origAddr))
|
||||
})
|
||||
})
|
||||
@@ -587,10 +594,12 @@ var _ = Describe("Session", func() {
|
||||
|
||||
Context("sending packets", func() {
|
||||
getPacket := func(pn protocol.PacketNumber) *packedPacket {
|
||||
data := *getPacketBuffer()
|
||||
buffer := getPacketBuffer()
|
||||
data := buffer.Slice[:0]
|
||||
data = append(data, []byte("foobar")...)
|
||||
return &packedPacket{
|
||||
raw: data,
|
||||
buffer: buffer,
|
||||
header: &wire.ExtendedHeader{PacketNumber: pn},
|
||||
}
|
||||
}
|
||||
@@ -963,7 +972,7 @@ var _ = Describe("Session", func() {
|
||||
defer close(done)
|
||||
return &packedPacket{
|
||||
header: &wire.ExtendedHeader{},
|
||||
raw: *getPacketBuffer(),
|
||||
buffer: getPacketBuffer(),
|
||||
}, nil
|
||||
}),
|
||||
packer.EXPECT().PackPacket().AnyTimes(),
|
||||
@@ -1352,7 +1361,7 @@ var _ = Describe("Client Session", func() {
|
||||
}()
|
||||
newConnID := protocol.ConnectionID{1, 3, 3, 7, 1, 3, 3, 7}
|
||||
packer.EXPECT().ChangeDestConnectionID(newConnID)
|
||||
Expect(sess.handlePacketImpl(&receivedPacket{
|
||||
Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{
|
||||
hdr: &wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeHandshake,
|
||||
@@ -1361,7 +1370,7 @@ var _ = Describe("Client Session", func() {
|
||||
Length: 1,
|
||||
},
|
||||
data: []byte{0},
|
||||
})).To(BeTrue())
|
||||
}))).To(BeTrue())
|
||||
// make sure the go routine returns
|
||||
packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil)
|
||||
sessionRunner.EXPECT().retireConnectionID(gomock.Any())
|
||||
|
||||
Reference in New Issue
Block a user