From be4bdbfad92faba8518890b6f63fc3eae8335fc1 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sun, 30 Dec 2018 12:36:01 +0700 Subject: [PATCH 1/7] remove the encryption level from the packed packet struct This was only needed in gQUIC. In IETF QUIC, We can always derive the encryption level from the header type when needed. --- packet_packer.go | 58 ++++++++++++++++++++++++------------------- packet_packer_test.go | 8 +++--- session.go | 2 +- 3 files changed, 38 insertions(+), 30 deletions(-) diff --git a/packet_packer.go b/packet_packer.go index 25e3cdf18..48292eeda 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -25,10 +25,23 @@ type packer interface { } type packedPacket struct { - header *wire.ExtendedHeader - raw []byte - frames []wire.Frame - encryptionLevel protocol.EncryptionLevel + header *wire.ExtendedHeader + raw []byte + frames []wire.Frame +} + +func (p *packedPacket) EncryptionLevel() protocol.EncryptionLevel { + if !p.header.IsLongHeader { + return protocol.Encryption1RTT + } + switch p.header.Type { + case protocol.PacketTypeInitial: + return protocol.EncryptionInitial + case protocol.PacketTypeHandshake: + return protocol.EncryptionHandshake + default: + return protocol.EncryptionUnspecified + } } func (p *packedPacket) ToAckHandlerPacket() *ackhandler.Packet { @@ -37,7 +50,7 @@ func (p *packedPacket) ToAckHandlerPacket() *ackhandler.Packet { PacketType: p.header.Type, Frames: p.frames, Length: protocol.ByteCount(len(p.raw)), - EncryptionLevel: p.encryptionLevel, + EncryptionLevel: p.EncryptionLevel(), SendTime: time.Now(), } } @@ -138,10 +151,9 @@ func (p *packetPacker) PackConnectionClose(ccf *wire.ConnectionCloseFrame) (*pac header := p.getHeader(encLevel) raw, err := p.writeAndSealPacket(header, frames, sealer) return &packedPacket{ - header: header, - raw: raw, - frames: frames, - encryptionLevel: encLevel, + header: header, + raw: raw, + frames: frames, }, err } @@ -156,10 +168,9 @@ func (p *packetPacker) MaybePackAckPacket() (*packedPacket, error) { frames := []wire.Frame{ack} raw, err := p.writeAndSealPacket(header, frames, sealer) return &packedPacket{ - header: header, - raw: raw, - frames: frames, - encryptionLevel: encLevel, + header: header, + raw: raw, + frames: frames, }, err } @@ -232,10 +243,9 @@ func (p *packetPacker) PackRetransmission(packet *ackhandler.Packet) ([]*packedP return nil, err } packets = append(packets, &packedPacket{ - header: header, - raw: raw, - frames: frames, - encryptionLevel: encLevel, + header: header, + raw: raw, + frames: frames, }) } return packets, nil @@ -286,10 +296,9 @@ func (p *packetPacker) PackPacket() (*packedPacket, error) { return nil, err } return &packedPacket{ - header: header, - raw: raw, - frames: frames, - encryptionLevel: encLevel, + header: header, + raw: raw, + frames: frames, }, nil } @@ -325,10 +334,9 @@ func (p *packetPacker) maybePackCryptoPacket() (*packedPacket, error) { return nil, err } return &packedPacket{ - header: hdr, - raw: raw, - frames: frames, - encryptionLevel: encLevel, + header: hdr, + raw: raw, + frames: frames, }, nil } diff --git a/packet_packer_test.go b/packet_packer_test.go index f262f703a..1971700ed 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -253,7 +253,7 @@ var _ = Describe("Packet packer", func() { }) p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) - Expect(p.encryptionLevel).To(Equal(protocol.Encryption1RTT)) + Expect(p.EncryptionLevel()).To(Equal(protocol.Encryption1RTT)) }) It("packs a single ACK", func() { @@ -494,7 +494,7 @@ var _ = Describe("Packet packer", func() { Expect(err).ToNot(HaveOccurred()) Expect(packets).To(HaveLen(1)) p := packets[0] - Expect(p.encryptionLevel).To(Equal(protocol.Encryption1RTT)) + Expect(p.EncryptionLevel()).To(Equal(protocol.Encryption1RTT)) Expect(p.frames).To(Equal(frames)) }) @@ -846,7 +846,7 @@ var _ = Describe("Packet packer", func() { Expect(p).To(HaveLen(1)) Expect(p[0].header.Type).To(Equal(protocol.PacketTypeInitial)) Expect(p[0].frames).To(Equal([]wire.Frame{f})) - Expect(p[0].encryptionLevel).To(Equal(protocol.EncryptionInitial)) + Expect(p[0].EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) }) It("packs a retransmission for an Initial packet", func() { @@ -864,7 +864,7 @@ var _ = Describe("Packet packer", func() { Expect(packets).To(HaveLen(1)) p := packets[0] Expect(p.frames).To(Equal([]wire.Frame{sf})) - Expect(p.encryptionLevel).To(Equal(protocol.EncryptionInitial)) + Expect(p.EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) Expect(p.header.Type).To(Equal(protocol.PacketTypeInitial)) Expect(p.header.Token).To(Equal(token)) Expect(p.raw).To(HaveLen(protocol.MinInitialPacketSize)) diff --git a/session.go b/session.go index 2cdd6d8c0..b16083b7a 100644 --- a/session.go +++ b/session.go @@ -976,7 +976,7 @@ func (s *session) logPacket(packet *packedPacket) { // We don't need to allocate the slices for calling the format functions return } - s.logger.Debugf("-> Sending packet 0x%x (%d bytes) for connection %s, %s", packet.header.PacketNumber, len(packet.raw), s.srcConnID, packet.encryptionLevel) + s.logger.Debugf("-> Sending packet 0x%x (%d bytes) for connection %s, %s", packet.header.PacketNumber, len(packet.raw), s.srcConnID, packet.EncryptionLevel()) packet.header.Log(s.logger) for _, frame := range packet.frames { wire.LogFrame(s.logger, frame, true) From e3d2c81729673c2a96483ea36f661336969cdc25 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sun, 30 Dec 2018 12:41:30 +0700 Subject: [PATCH 2/7] simplify packing of packets --- packet_packer.go | 50 ++++++++++++------------------------------------ 1 file changed, 12 insertions(+), 38 deletions(-) diff --git a/packet_packer.go b/packet_packer.go index 48292eeda..269c77db5 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -149,12 +149,7 @@ func (p *packetPacker) PackConnectionClose(ccf *wire.ConnectionCloseFrame) (*pac frames := []wire.Frame{ccf} encLevel, sealer := p.cryptoSetup.GetSealer() header := p.getHeader(encLevel) - raw, err := p.writeAndSealPacket(header, frames, sealer) - return &packedPacket{ - header: header, - raw: raw, - frames: frames, - }, err + return p.writeAndSealPacket(header, frames, sealer) } func (p *packetPacker) MaybePackAckPacket() (*packedPacket, error) { @@ -166,12 +161,7 @@ func (p *packetPacker) MaybePackAckPacket() (*packedPacket, error) { encLevel, sealer := p.cryptoSetup.GetSealer() header := p.getHeader(encLevel) frames := []wire.Frame{ack} - raw, err := p.writeAndSealPacket(header, frames, sealer) - return &packedPacket{ - header: header, - raw: raw, - frames: frames, - }, err + return p.writeAndSealPacket(header, frames, sealer) } // PackRetransmission packs a retransmission @@ -238,15 +228,11 @@ func (p *packetPacker) PackRetransmission(packet *ackhandler.Packet) ([]*packedP if sf, ok := frames[len(frames)-1].(*wire.StreamFrame); ok { sf.DataLenPresent = false } - raw, err := p.writeAndSealPacket(header, frames, sealer) + p, err := p.writeAndSealPacket(header, frames, sealer) if err != nil { return nil, err } - packets = append(packets, &packedPacket{ - header: header, - raw: raw, - frames: frames, - }) + packets = append(packets, p) } return packets, nil } @@ -291,15 +277,7 @@ func (p *packetPacker) PackPacket() (*packedPacket, error) { p.numNonRetransmittableAcks = 0 } - raw, err := p.writeAndSealPacket(header, frames, sealer) - if err != nil { - return nil, err - } - return &packedPacket{ - header: header, - raw: raw, - frames: frames, - }, nil + return p.writeAndSealPacket(header, frames, sealer) } func (p *packetPacker) maybePackCryptoPacket() (*packedPacket, error) { @@ -329,15 +307,7 @@ func (p *packetPacker) maybePackCryptoPacket() (*packedPacket, error) { } cf := s.PopCryptoFrame(p.maxPacketSize - hdrLen - protocol.ByteCount(sealer.Overhead()) - length) frames = append(frames, cf) - raw, err := p.writeAndSealPacket(hdr, frames, sealer) - if err != nil { - return nil, err - } - return &packedPacket{ - header: hdr, - raw: raw, - frames: frames, - }, nil + return p.writeAndSealPacket(hdr, frames, sealer) } func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount) ([]wire.Frame, error) { @@ -403,7 +373,7 @@ func (p *packetPacker) writeAndSealPacket( header *wire.ExtendedHeader, frames []wire.Frame, sealer handshake.Sealer, -) ([]byte, error) { +) (*packedPacket, error) { raw := *getPacketBuffer() buffer := bytes.NewBuffer(raw[:0]) @@ -481,7 +451,11 @@ func (p *packetPacker) writeAndSealPacket( if num != header.PacketNumber { return nil, errors.New("packetPacker BUG: Peeked and Popped packet numbers do not match") } - return raw, nil + return &packedPacket{ + header: header, + raw: raw, + frames: frames, + }, nil } func (p *packetPacker) ChangeDestConnectionID(connID protocol.ConnectionID) { From 413844d0bc024961bccd30e426693826ef05efaa Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sun, 30 Dec 2018 12:16:53 +0700 Subject: [PATCH 3/7] use in-place decryption when opening packets --- packet_unpacker.go | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/packet_unpacker.go b/packet_unpacker.go index 5cde3fa7e..b9eb8ea01 100644 --- a/packet_unpacker.go +++ b/packet_unpacker.go @@ -93,11 +93,7 @@ func (u *packetUnpacker) Unpack(hdr *wire.Header, data []byte) (*unpackedPacket, extHdr.PacketNumber, ) - buf := *getPacketBuffer() - buf = buf[:0] - defer putPacketBuffer(&buf) - - decrypted, err := opener.Open(buf, data, pn, extHdr.Raw) + decrypted, err := opener.Open(data[:0], data, pn, extHdr.Raw) if err != nil { return nil, err } From 767dbdd5450165e14c7546e4a4effbb5eb222594 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Wed, 26 Dec 2018 20:57:30 +0630 Subject: [PATCH 4/7] put back packet buffers after processing a packet This introduces a reference counter in the packet buffer, which will be used to process coalesced packets. --- buffer_pool.go | 36 +++++++++++++++++++++++------- buffer_pool_test.go | 21 +++++++++++++----- packet_handler_map.go | 15 ++++++++----- packet_handler_map_test.go | 18 +++++++-------- packet_packer.go | 9 +++++--- server.go | 9 +++++++- server_test.go | 34 ++++++++++++++-------------- session.go | 19 +++++++++++----- session_test.go | 45 +++++++++++++++++++++++--------------- 9 files changed, 135 insertions(+), 71 deletions(-) diff --git a/buffer_pool.go b/buffer_pool.go index 6b233696f..e56468203 100644 --- a/buffer_pool.go +++ b/buffer_pool.go @@ -6,22 +6,42 @@ import ( "github.com/lucas-clemente/quic-go/internal/protocol" ) -var bufferPool sync.Pool +type packetBuffer struct { + Slice []byte -func getPacketBuffer() *[]byte { - return bufferPool.Get().(*[]byte) + // refCount counts how many packets the Slice is used in. + // It doesn't support concurrent use. + // It is > 1 when used for coalesced packet. + refCount int } -func putPacketBuffer(buf *[]byte) { - if cap(*buf) != int(protocol.MaxReceivePacketSize) { +var bufferPool sync.Pool + +func getPacketBuffer() *packetBuffer { + buf := bufferPool.Get().(*packetBuffer) + buf.refCount = 1 + buf.Slice = buf.Slice[:protocol.MaxReceivePacketSize] + return buf +} + +func putPacketBuffer(buf *packetBuffer) { + if cap(buf.Slice) != int(protocol.MaxReceivePacketSize) { panic("putPacketBuffer called with packet of wrong size!") } - bufferPool.Put(buf) + buf.refCount-- + if buf.refCount < 0 { + panic("negative packetBuffer refCount") + } + // only put the packetBuffer back if it's not used any more + if buf.refCount == 0 { + bufferPool.Put(buf) + } } func init() { bufferPool.New = func() interface{} { - b := make([]byte, 0, protocol.MaxReceivePacketSize) - return &b + return &packetBuffer{ + Slice: make([]byte, 0, protocol.MaxReceivePacketSize), + } } } diff --git a/buffer_pool_test.go b/buffer_pool_test.go index 7413ceaf6..8b00b4d55 100644 --- a/buffer_pool_test.go +++ b/buffer_pool_test.go @@ -9,13 +9,24 @@ import ( var _ = Describe("Buffer Pool", func() { It("returns buffers of cap", func() { - buf := *getPacketBuffer() - Expect(buf).To(HaveCap(int(protocol.MaxReceivePacketSize))) + buf := getPacketBuffer() + Expect(buf.Slice).To(HaveCap(int(protocol.MaxReceivePacketSize))) + }) + + It("puts buffers back", func() { + buf := getPacketBuffer() + putPacketBuffer(buf) }) It("panics if wrong-sized buffers are passed", func() { - Expect(func() { - putPacketBuffer(&[]byte{0}) - }).To(Panic()) + buf := getPacketBuffer() + buf.Slice = make([]byte, 10) + Expect(func() { putPacketBuffer(buf) }).To(Panic()) + }) + + It("panics if it is put pack twice", func() { + buf := getPacketBuffer() + putPacketBuffer(buf) + Expect(func() { putPacketBuffer(buf) }).To(Panic()) }) }) diff --git a/packet_handler_map.go b/packet_handler_map.go index 7e411df04..0f77dd9de 100644 --- a/packet_handler_map.go +++ b/packet_handler_map.go @@ -144,8 +144,8 @@ func (h *packetHandlerMap) close(e error) error { func (h *packetHandlerMap) listen() { for { - data := *getPacketBuffer() - data = data[:protocol.MaxReceivePacketSize] + buffer := getPacketBuffer() + data := buffer.Slice // The packet size should not exceed protocol.MaxReceivePacketSize bytes // If it does, we only read a truncated packet, which will then end up undecryptable n, addr, err := h.conn.ReadFrom(data) @@ -155,13 +155,17 @@ func (h *packetHandlerMap) listen() { } data = data[:n] - if err := h.handlePacket(addr, data); err != nil { + if err := h.handlePacket(addr, buffer, data); err != nil { h.logger.Debugf("error handling packet from %s: %s", addr, err) } } } -func (h *packetHandlerMap) handlePacket(addr net.Addr, data []byte) error { +func (h *packetHandlerMap) handlePacket( + addr net.Addr, + buffer *packetBuffer, + data []byte, +) error { r := bytes.NewReader(data) hdr, err := wire.ParseHeader(r, h.connIDLen) // drop the packet if we can't parse the header @@ -172,8 +176,9 @@ func (h *packetHandlerMap) handlePacket(addr net.Addr, data []byte) error { p := &receivedPacket{ remoteAddr: addr, hdr: hdr, - data: data, rcvTime: time.Now(), + data: data, + buffer: buffer, } h.mutex.RLock() diff --git a/packet_handler_map_test.go b/packet_handler_map_test.go index a955259b4..31bc7513f 100644 --- a/packet_handler_map_test.go +++ b/packet_handler_map_test.go @@ -81,7 +81,7 @@ var _ = Describe("Packet Handler Map", func() { }) It("drops unparseable packets", func() { - err := handler.handlePacket(nil, []byte{0, 1, 2, 3}) + err := handler.handlePacket(nil, nil, []byte{0, 1, 2, 3}) Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("error parsing header:")) }) @@ -91,7 +91,7 @@ var _ = Describe("Packet Handler Map", func() { connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} handler.Add(connID, NewMockPacketHandler(mockCtrl)) handler.Remove(connID) - Expect(handler.handlePacket(nil, getPacket(connID))).To(MatchError("received a packet with an unexpected connection ID 0x0102030405060708")) + Expect(handler.handlePacket(nil, nil, getPacket(connID))).To(MatchError("received a packet with an unexpected connection ID 0x0102030405060708")) }) It("deletes retired session entries after a wait time", func() { @@ -100,7 +100,7 @@ var _ = Describe("Packet Handler Map", func() { handler.Add(connID, NewMockPacketHandler(mockCtrl)) handler.Retire(connID) time.Sleep(scaleDuration(30 * time.Millisecond)) - Expect(handler.handlePacket(nil, getPacket(connID))).To(MatchError("received a packet with an unexpected connection ID 0x0102030405060708")) + Expect(handler.handlePacket(nil, nil, getPacket(connID))).To(MatchError("received a packet with an unexpected connection ID 0x0102030405060708")) }) It("passes packets arriving late for closed sessions to that session", func() { @@ -110,13 +110,13 @@ var _ = Describe("Packet Handler Map", func() { packetHandler.EXPECT().handlePacket(gomock.Any()) handler.Add(connID, packetHandler) handler.Retire(connID) - err := handler.handlePacket(nil, getPacket(connID)) + err := handler.handlePacket(nil, nil, getPacket(connID)) Expect(err).ToNot(HaveOccurred()) }) It("drops packets for unknown receivers", func() { connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - err := handler.handlePacket(nil, getPacket(connID)) + err := handler.handlePacket(nil, nil, getPacket(connID)) Expect(err).To(MatchError("received a packet with an unexpected connection ID 0x0102030405060708")) }) @@ -171,10 +171,10 @@ var _ = Describe("Packet Handler Map", func() { handler.AddWithResetToken(connID, NewMockPacketHandler(mockCtrl), token) handler.Retire(connID) time.Sleep(scaleDuration(30 * time.Millisecond)) - Expect(handler.handlePacket(nil, getPacket(connID))).To(MatchError("received a packet with an unexpected connection ID 0xdeadbeef42")) + Expect(handler.handlePacket(nil, nil, getPacket(connID))).To(MatchError("received a packet with an unexpected connection ID 0xdeadbeef42")) packet := append([]byte{0x40, 0xde, 0xca, 0xfb, 0xad, 0x99} /* short header packet */, make([]byte, 50)...) packet = append(packet, token[:]...) - Expect(handler.handlePacket(nil, packet)).To(MatchError("received a short header packet with an unexpected connection ID 0xdecafbad99")) + Expect(handler.handlePacket(nil, nil, packet)).To(MatchError("received a short header packet with an unexpected connection ID 0xdecafbad99")) Expect(handler.resetTokens).To(BeEmpty()) }) }) @@ -188,7 +188,7 @@ var _ = Describe("Packet Handler Map", func() { Expect(p.hdr.DestConnectionID).To(Equal(connID)) }) handler.SetServer(server) - Expect(handler.handlePacket(nil, p)).To(Succeed()) + Expect(handler.handlePacket(nil, nil, p)).To(Succeed()) }) It("closes all server sessions", func() { @@ -209,7 +209,7 @@ var _ = Describe("Packet Handler Map", func() { server := NewMockUnknownPacketHandler(mockCtrl) handler.SetServer(server) handler.CloseServer() - Expect(handler.handlePacket(nil, p)).To(MatchError("received a packet with an unexpected connection ID 0x1122334455667788")) + Expect(handler.handlePacket(nil, nil, p)).To(MatchError("received a packet with an unexpected connection ID 0x1122334455667788")) }) }) }) diff --git a/packet_packer.go b/packet_packer.go index 269c77db5..aa55262b8 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -28,6 +28,8 @@ type packedPacket struct { header *wire.ExtendedHeader raw []byte frames []wire.Frame + + buffer *packetBuffer } func (p *packedPacket) EncryptionLevel() protocol.EncryptionLevel { @@ -374,8 +376,8 @@ func (p *packetPacker) writeAndSealPacket( frames []wire.Frame, sealer handshake.Sealer, ) (*packedPacket, error) { - raw := *getPacketBuffer() - buffer := bytes.NewBuffer(raw[:0]) + packetBuffer := getPacketBuffer() + buffer := bytes.NewBuffer(packetBuffer.Slice[:0]) addPaddingForInitial := p.perspective == protocol.PerspectiveClient && header.Type == protocol.PacketTypeInitial @@ -436,7 +438,7 @@ func (p *packetPacker) writeAndSealPacket( return nil, fmt.Errorf("PacketPacker BUG: packet too large (%d bytes, allowed %d bytes)", size, p.maxPacketSize) } - raw = raw[0:buffer.Len()] + raw := buffer.Bytes() _ = sealer.Seal(raw[payloadOffset:payloadOffset], raw[payloadOffset:], header.PacketNumber, raw[:payloadOffset]) raw = raw[0 : buffer.Len()+sealer.Overhead()] @@ -455,6 +457,7 @@ func (p *packetPacker) writeAndSealPacket( header: header, raw: raw, frames: frames, + buffer: packetBuffer, }, nil } diff --git a/server.go b/server.go index 14b43cff4..c665283db 100644 --- a/server.go +++ b/server.go @@ -318,21 +318,27 @@ func (s *server) handlePacket(p *receivedPacket) { } if hdr.Type == protocol.PacketTypeInitial { go s.handleInitial(p) + return } + + putPacketBuffer(p.buffer) // TODO(#943): send Stateless Reset } func (s *server) handleInitial(p *receivedPacket) { - // TODO: add a check that DestConnID == SrcConnID s.logger.Debugf("<- Received Initial packet.") sess, connID, err := s.handleInitialImpl(p) if err != nil { + putPacketBuffer(p.buffer) s.logger.Errorf("Error occurred handling initial packet: %s", err) return } if sess == nil { // a retry was done + putPacketBuffer(p.buffer) return } + // Don't put the packet buffer back if a new session was created. + // The session will handle the packet and take of that. serverSession := newServerSession(sess, s.config, s.logger) s.sessionHandler.Add(connID, serverSession) } @@ -455,6 +461,7 @@ func (s *server) sendRetry(remoteAddr net.Addr, hdr *wire.Header) error { } func (s *server) sendVersionNegotiationPacket(p *receivedPacket) { + defer putPacketBuffer(p.buffer) hdr := p.hdr s.logger.Debugf("Client offered version %s, sending Version Negotiation", hdr.Version) data, err := wire.ComposeVersionNegotiation(hdr.SrcConnectionID, hdr.DestConnectionID, s.config.Versions) diff --git a/server_test.go b/server_test.go index 479423ab2..ddbac18f6 100644 --- a/server_test.go +++ b/server_test.go @@ -122,19 +122,19 @@ var _ = Describe("Server", func() { } It("drops Initial packets with a too short connection ID", func() { - serv.handlePacket(&receivedPacket{ + serv.handlePacket(insertPacketBuffer(&receivedPacket{ hdr: &wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeInitial, DestConnectionID: protocol.ConnectionID{1, 2, 3, 4}, Version: serv.config.Versions[0], }, - }) + })) Consistently(conn.dataWritten).ShouldNot(Receive()) }) It("drops too small Initial", func() { - serv.handlePacket(&receivedPacket{ + serv.handlePacket(insertPacketBuffer(&receivedPacket{ hdr: &wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeInitial, @@ -142,12 +142,12 @@ var _ = Describe("Server", func() { Version: serv.config.Versions[0], }, data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize-100), - }) + })) Consistently(conn.dataWritten).ShouldNot(Receive()) }) It("drops packets with a too short connection ID", func() { - serv.handlePacket(&receivedPacket{ + serv.handlePacket(insertPacketBuffer(&receivedPacket{ hdr: &wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeInitial, @@ -156,19 +156,19 @@ var _ = Describe("Server", func() { Version: serv.config.Versions[0], }, data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize), - }) + })) Consistently(conn.dataWritten).ShouldNot(Receive()) }) It("drops non-Initial packets", func() { serv.logger.SetLogLevel(utils.LogLevelDebug) - serv.handlePacket(&receivedPacket{ + serv.handlePacket(insertPacketBuffer(&receivedPacket{ hdr: &wire.Header{ Type: protocol.PacketTypeHandshake, Version: serv.config.Versions[0], }, data: []byte("invalid"), - }) + })) }) It("decodes the cookie from the Token field", func() { @@ -185,7 +185,7 @@ var _ = Describe("Server", func() { } token, err := serv.cookieGenerator.NewToken(raddr, nil) Expect(err).ToNot(HaveOccurred()) - serv.handlePacket(&receivedPacket{ + serv.handlePacket(insertPacketBuffer(&receivedPacket{ remoteAddr: raddr, hdr: &wire.Header{ Type: protocol.PacketTypeInitial, @@ -193,7 +193,7 @@ var _ = Describe("Server", func() { Version: serv.config.Versions[0], }, data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize), - }) + })) Eventually(done).Should(BeClosed()) }) @@ -209,7 +209,7 @@ var _ = Describe("Server", func() { close(done) return false } - serv.handlePacket(&receivedPacket{ + serv.handlePacket(insertPacketBuffer(&receivedPacket{ remoteAddr: raddr, hdr: &wire.Header{ Type: protocol.PacketTypeInitial, @@ -217,14 +217,14 @@ var _ = Describe("Server", func() { Version: serv.config.Versions[0], }, data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize), - }) + })) Eventually(done).Should(BeClosed()) }) It("sends a Version Negotiation Packet for unsupported versions", func() { srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5} destConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6} - serv.handlePacket(&receivedPacket{ + serv.handlePacket(insertPacketBuffer(&receivedPacket{ remoteAddr: &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}, hdr: &wire.Header{ IsLongHeader: true, @@ -233,7 +233,7 @@ var _ = Describe("Server", func() { DestConnectionID: destConnID, Version: 0x42, }, - }) + })) var write mockPacketConnWrite Eventually(conn.dataWritten).Should(Receive(&write)) Expect(write.to.String()).To(Equal("127.0.0.1:1337")) @@ -253,11 +253,11 @@ var _ = Describe("Server", func() { DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, Version: protocol.VersionTLS, } - serv.handleInitial(&receivedPacket{ + serv.handleInitial(insertPacketBuffer(&receivedPacket{ remoteAddr: &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}, hdr: hdr, data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize), - }) + })) var write mockPacketConnWrite Eventually(conn.dataWritten).Should(Receive(&write)) Expect(write.to.String()).To(Equal("127.0.0.1:1337")) @@ -308,7 +308,7 @@ var _ = Describe("Server", func() { done := make(chan struct{}) go func() { defer GinkgoRecover() - serv.handlePacket(p) + serv.handlePacket(insertPacketBuffer(p)) // the Handshake packet is written by the session Consistently(conn.dataWritten).ShouldNot(Receive()) close(done) diff --git a/session.go b/session.go index b16083b7a..511a2b24e 100644 --- a/session.go +++ b/session.go @@ -53,8 +53,10 @@ type cryptoStreamHandler interface { type receivedPacket struct { remoteAddr net.Addr hdr *wire.Header - data []byte rcvTime time.Time + data []byte + + buffer *packetBuffer } type closeError struct { @@ -368,9 +370,6 @@ runLoop: if wasProcessed := s.handlePacketImpl(p); !wasProcessed { continue } - // This is a bit unclean, but works properly, since the packet always - // begins with the public header and we never copy it. - // TODO: putPacketBuffer(&p.extHdr.Raw) case <-s.handshakeCompleteChan: s.handleHandshakeComplete() } @@ -475,6 +474,15 @@ func (s *session) handleHandshakeComplete() { } func (s *session) handlePacketImpl(p *receivedPacket) bool /* was the packet successfully processed */ { + var wasQueued bool + + defer func() { + // Put back the packet buffer if the packet wasn't queued for later decryption. + if !wasQueued { + putPacketBuffer(p.buffer) + } + }() + // The server can change the source connection ID with the first Handshake packet. // After this, all packets with a different source connection have to be ignored. if s.receivedFirstPacket && p.hdr.IsLongHeader && !p.hdr.SrcConnectionID.Equal(s.destConnID) { @@ -490,6 +498,7 @@ func (s *session) handlePacketImpl(p *receivedPacket) bool /* was the packet suc // if the decryption failed, this might be a packet sent by an attacker if err != nil { if err == handshake.ErrOpenerNotYetAvailable { + wasQueued = true s.tryQueueingUndecryptablePacket(p) return false } @@ -953,7 +962,7 @@ func (s *session) sendPacket() (bool, error) { } func (s *session) sendPackedPacket(packet *packedPacket) error { - defer putPacketBuffer(&packet.raw) + defer putPacketBuffer(packet.buffer) s.logPacket(packet) return s.conn.Write(packet.raw) } diff --git a/session_test.go b/session_test.go index 02fb35792..3b0eb59d6 100644 --- a/session_test.go +++ b/session_test.go @@ -61,6 +61,11 @@ func areSessionsRunning() bool { return strings.Contains(b.String(), "quic-go.(*session).run") } +func insertPacketBuffer(p *receivedPacket) *receivedPacket { + p.buffer = getPacketBuffer() + return p +} + var _ = Describe("Session", func() { var ( sess *session @@ -496,11 +501,11 @@ var _ = Describe("Session", func() { rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl) rph.EXPECT().ReceivedPacket(protocol.PacketNumber(0x1337), rcvTime, false) sess.receivedPacketHandler = rph - Expect(sess.handlePacketImpl(&receivedPacket{ + Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{ rcvTime: rcvTime, hdr: &hdr.Header, data: getData(hdr), - })).To(BeTrue()) + }))).To(BeTrue()) }) It("closes when handling a packet fails", func() { @@ -518,7 +523,10 @@ var _ = Describe("Session", func() { close(done) }() sessionRunner.EXPECT().retireConnectionID(gomock.Any()) - sess.handlePacket(&receivedPacket{hdr: &wire.Header{}, data: getData(&wire.ExtendedHeader{PacketNumberLen: protocol.PacketNumberLen1})}) + sess.handlePacket(insertPacketBuffer(&receivedPacket{ + hdr: &wire.Header{}, + data: getData(&wire.ExtendedHeader{PacketNumberLen: protocol.PacketNumberLen1}), + })) Eventually(done).Should(BeClosed()) }) @@ -528,18 +536,18 @@ var _ = Describe("Session", func() { PacketNumberLen: protocol.PacketNumberLen1, } unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).Return(&unpackedPacket{hdr: hdr}, nil).Times(2) - Expect(sess.handlePacketImpl(&receivedPacket{hdr: &hdr.Header, data: getData(hdr)})).To(BeTrue()) - Expect(sess.handlePacketImpl(&receivedPacket{hdr: &hdr.Header, data: getData(hdr)})).To(BeTrue()) + Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{hdr: &hdr.Header, data: getData(hdr)}))).To(BeTrue()) + Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{hdr: &hdr.Header, data: getData(hdr)}))).To(BeTrue()) }) It("ignores 0-RTT packets", func() { - Expect(sess.handlePacketImpl(&receivedPacket{ + Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{ hdr: &wire.Header{ IsLongHeader: true, Type: protocol.PacketType0RTT, DestConnectionID: sess.srcConnID, }, - })).To(BeFalse()) + }))).To(BeFalse()) }) It("ignores packets with a different source connection ID", func() { @@ -552,12 +560,12 @@ var _ = Describe("Session", func() { // Send one packet, which might change the connection ID. // only EXPECT one call to the unpacker unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).Return(&unpackedPacket{hdr: &wire.ExtendedHeader{Header: *hdr}}, nil) - Expect(sess.handlePacketImpl(&receivedPacket{ + Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{ hdr: hdr, data: getData(&wire.ExtendedHeader{PacketNumberLen: protocol.PacketNumberLen1}), - })).To(BeTrue()) + }))).To(BeTrue()) // The next packet has to be ignored, since the source connection ID doesn't match. - Expect(sess.handlePacketImpl(&receivedPacket{ + Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{ hdr: &wire.Header{ IsLongHeader: true, DestConnectionID: sess.destConnID, @@ -565,7 +573,7 @@ var _ = Describe("Session", func() { Length: 1, }, data: getData(&wire.ExtendedHeader{PacketNumberLen: protocol.PacketNumberLen1}), - })).To(BeFalse()) + }))).To(BeFalse()) }) Context("updating the remote address", func() { @@ -574,12 +582,11 @@ var _ = Describe("Session", func() { origAddr := sess.conn.(*mockConnection).remoteAddr remoteIP := &net.IPAddr{IP: net.IPv4(192, 168, 0, 100)} Expect(origAddr).ToNot(Equal(remoteIP)) - p := receivedPacket{ + Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{ remoteAddr: remoteIP, hdr: &wire.Header{}, data: getData(&wire.ExtendedHeader{PacketNumberLen: protocol.PacketNumberLen1}), - } - Expect(sess.handlePacketImpl(&p)).To(BeTrue()) + }))).To(BeTrue()) Expect(sess.conn.(*mockConnection).remoteAddr).To(Equal(origAddr)) }) }) @@ -587,10 +594,12 @@ var _ = Describe("Session", func() { Context("sending packets", func() { getPacket := func(pn protocol.PacketNumber) *packedPacket { - data := *getPacketBuffer() + buffer := getPacketBuffer() + data := buffer.Slice[:0] data = append(data, []byte("foobar")...) return &packedPacket{ raw: data, + buffer: buffer, header: &wire.ExtendedHeader{PacketNumber: pn}, } } @@ -963,7 +972,7 @@ var _ = Describe("Session", func() { defer close(done) return &packedPacket{ header: &wire.ExtendedHeader{}, - raw: *getPacketBuffer(), + buffer: getPacketBuffer(), }, nil }), packer.EXPECT().PackPacket().AnyTimes(), @@ -1352,7 +1361,7 @@ var _ = Describe("Client Session", func() { }() newConnID := protocol.ConnectionID{1, 3, 3, 7, 1, 3, 3, 7} packer.EXPECT().ChangeDestConnectionID(newConnID) - Expect(sess.handlePacketImpl(&receivedPacket{ + Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{ hdr: &wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeHandshake, @@ -1361,7 +1370,7 @@ var _ = Describe("Client Session", func() { Length: 1, }, data: []byte{0}, - })).To(BeTrue()) + }))).To(BeTrue()) // make sure the go routine returns packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) sessionRunner.EXPECT().retireConnectionID(gomock.Any()) From 6ce7a204fc8f5f9f71049c19a0e134cfd0091860 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Wed, 26 Dec 2018 22:01:27 +0700 Subject: [PATCH 5/7] move cutting of packets with a length to the packet handler map --- packet_handler_map.go | 8 +++++++ packet_handler_map_test.go | 28 ++++++++++++++++++++++--- packet_unpacker.go | 8 ------- packet_unpacker_test.go | 43 -------------------------------------- 4 files changed, 33 insertions(+), 54 deletions(-) diff --git a/packet_handler_map.go b/packet_handler_map.go index 0f77dd9de..114d2a18e 100644 --- a/packet_handler_map.go +++ b/packet_handler_map.go @@ -173,6 +173,14 @@ func (h *packetHandlerMap) handlePacket( return fmt.Errorf("error parsing header: %s", err) } + if hdr.IsLongHeader { + if protocol.ByteCount(r.Len()) < hdr.Length { + return fmt.Errorf("packet length (%d bytes) is smaller than the expected length (%d bytes)", len(data)-int(hdr.ParsedLen()), hdr.Length) + } + data = data[:int(hdr.ParsedLen()+hdr.Length)] + // TODO(#1312): implement parsing of compound packets + } + p := &receivedPacket{ remoteAddr: addr, hdr: hdr, diff --git a/packet_handler_map_test.go b/packet_handler_map_test.go index 31bc7513f..8a0141d3d 100644 --- a/packet_handler_map_test.go +++ b/packet_handler_map_test.go @@ -19,21 +19,25 @@ var _ = Describe("Packet Handler Map", func() { conn *mockPacketConn ) - getPacket := func(connID protocol.ConnectionID) []byte { + getPacketWithLength := func(connID protocol.ConnectionID, length protocol.ByteCount) []byte { buf := &bytes.Buffer{} Expect((&wire.ExtendedHeader{ Header: wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeHandshake, DestConnectionID: connID, - Length: 1, + Length: length, Version: protocol.VersionTLS, }, - PacketNumberLen: protocol.PacketNumberLen1, + PacketNumberLen: protocol.PacketNumberLen2, }).Write(buf, protocol.VersionWhatever)).To(Succeed()) return buf.Bytes() } + getPacket := func(connID protocol.ConnectionID) []byte { + return getPacketWithLength(connID, 1) + } + BeforeEach(func() { conn = newMockPacketConn() handler = newPacketHandlerMap(conn, 5, utils.DefaultLogger).(*packetHandlerMap) @@ -131,6 +135,24 @@ var _ = Describe("Packet Handler Map", func() { conn.Close() Eventually(done).Should(BeClosed()) }) + + It("errors on packets that are smaller than the length in the packet header", func() { + connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + data := append(getPacketWithLength(connID, 1000), make([]byte, 500-2 /* for packet number length */)...) + err := handler.handlePacket(nil, nil, data) + Expect(err).To(MatchError("packet length (500 bytes) is smaller than the expected length (1000 bytes)")) + }) + + It("cuts packets to the right length", func() { + connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + data := append(getPacketWithLength(connID, 456), make([]byte, 1000)...) + packetHandler := NewMockPacketHandler(mockCtrl) + packetHandler.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { + Expect(p.data).To(HaveLen(456 + int(p.hdr.ParsedLen()))) + }) + handler.Add(connID, packetHandler) + Expect(handler.handlePacket(nil, nil, data)).To(Succeed()) + }) }) Context("stateless reset handling", func() { diff --git a/packet_unpacker.go b/packet_unpacker.go index b9eb8ea01..cda222324 100644 --- a/packet_unpacker.go +++ b/packet_unpacker.go @@ -39,14 +39,6 @@ func newPacketUnpacker(cs handshake.CryptoSetup, version protocol.VersionNumber) func (u *packetUnpacker) Unpack(hdr *wire.Header, data []byte) (*unpackedPacket, error) { r := bytes.NewReader(data) - if hdr.IsLongHeader { - if protocol.ByteCount(r.Len()) < hdr.Length { - return nil, fmt.Errorf("packet length (%d bytes) is smaller than the expected length (%d bytes)", len(data)-int(hdr.ParsedLen()), hdr.Length) - } - data = data[:int(hdr.ParsedLen()+hdr.Length)] - // TODO(#1312): implement parsing of compound packets - } - var encLevel protocol.EncryptionLevel switch hdr.Type { case protocol.PacketTypeInitial: diff --git a/packet_unpacker_test.go b/packet_unpacker_test.go index 8aa4e446b..c9a085d08 100644 --- a/packet_unpacker_test.go +++ b/packet_unpacker_test.go @@ -75,49 +75,6 @@ var _ = Describe("Packet Unpacker", func() { Expect(packet.encryptionLevel).To(Equal(protocol.EncryptionInitial)) }) - It("errors on packets that are smaller than the length in the packet header", func() { - extHdr := &wire.ExtendedHeader{ - Header: wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeHandshake, - Length: 1000, - DestConnectionID: connID, - Version: version, - }, - PacketNumberLen: protocol.PacketNumberLen2, - } - hdr, hdrRaw := getHeader(extHdr) - data := append(hdrRaw, make([]byte, 500-2 /* for packet number length */)...) - _, err := unpacker.Unpack(hdr, data) - Expect(err).To(MatchError("packet length (500 bytes) is smaller than the expected length (1000 bytes)")) - }) - - It("cuts packets to the right length", func() { - pnLen := protocol.PacketNumberLen2 - extHdr := &wire.ExtendedHeader{ - Header: wire.Header{ - IsLongHeader: true, - DestConnectionID: connID, - Type: protocol.PacketTypeHandshake, - Length: 456, - Version: protocol.VersionTLS, - }, - PacketNumberLen: pnLen, - } - payloadLen := 456 - int(pnLen) - hdr, hdrRaw := getHeader(extHdr) - data := append(hdrRaw, make([]byte, payloadLen)...) - opener := mocks.NewMockOpener(mockCtrl) - cs.EXPECT().GetOpener(protocol.EncryptionHandshake).Return(opener, nil) - opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()) - opener.EXPECT().Open(gomock.Any(), gomock.Any(), extHdr.PacketNumber, hdrRaw).DoAndReturn(func(_, payload []byte, _ protocol.PacketNumber, _ []byte) ([]byte, error) { - Expect(payload).To(HaveLen(payloadLen)) - return []byte{0}, nil - }) - _, err := unpacker.Unpack(hdr, data) - Expect(err).ToNot(HaveOccurred()) - }) - It("returns the error when getting the sealer fails", func() { extHdr := &wire.ExtendedHeader{ Header: wire.Header{DestConnectionID: connID}, From d3ea7c0c4c4a1f8d1e217c991998a401b16024df Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 29 Dec 2018 17:48:28 +0700 Subject: [PATCH 6/7] implement handling of coalesced packets --- buffer_pool.go | 7 ++ buffer_pool_test.go | 11 ++++ packet_handler_map.go | 132 ++++++++++++++++++++++++------------- packet_handler_map_test.go | 127 +++++++++++++++++++++++++++-------- 4 files changed, 205 insertions(+), 72 deletions(-) diff --git a/buffer_pool.go b/buffer_pool.go index e56468203..c890d32ba 100644 --- a/buffer_pool.go +++ b/buffer_pool.go @@ -15,6 +15,13 @@ type packetBuffer struct { refCount int } +// Split increases the refCount. +// It must be called when a packet buffer is used for more than one packet, +// e.g. when splitting coalesced packets. +func (b *packetBuffer) Split() { + b.refCount++ +} + var bufferPool sync.Pool func getPacketBuffer() *packetBuffer { diff --git a/buffer_pool_test.go b/buffer_pool_test.go index 8b00b4d55..ef6b48523 100644 --- a/buffer_pool_test.go +++ b/buffer_pool_test.go @@ -29,4 +29,15 @@ var _ = Describe("Buffer Pool", func() { putPacketBuffer(buf) Expect(func() { putPacketBuffer(buf) }).To(Panic()) }) + + It("waits until all parts have been put back", func() { + buf := getPacketBuffer() + buf.Split() + buf.Split() + // now we have 3 parts + putPacketBuffer(buf) + putPacketBuffer(buf) + putPacketBuffer(buf) + Expect(func() { putPacketBuffer(buf) }).To(Panic()) + }) }) diff --git a/packet_handler_map.go b/packet_handler_map.go index 114d2a18e..5c2a1948e 100644 --- a/packet_handler_map.go +++ b/packet_handler_map.go @@ -153,11 +153,7 @@ func (h *packetHandlerMap) listen() { h.close(err) return } - data = data[:n] - - if err := h.handlePacket(addr, buffer, data); err != nil { - h.logger.Debugf("error handling packet from %s: %s", addr, err) - } + h.handlePacket(addr, buffer, data[:n]) } } @@ -165,56 +161,102 @@ func (h *packetHandlerMap) handlePacket( addr net.Addr, buffer *packetBuffer, data []byte, -) error { - r := bytes.NewReader(data) - hdr, err := wire.ParseHeader(r, h.connIDLen) - // drop the packet if we can't parse the header +) { + packets, err := h.parsePacket(addr, buffer, data) if err != nil { - return fmt.Errorf("error parsing header: %s", err) + h.logger.Debugf("error parsing packets from %s: %s", addr, err) + // This is just the error from parsing the last packet. + // We still need to process the packets that were successfully parsed before. } + if len(packets) == 0 { + putPacketBuffer(buffer) + return + } + h.handleParsedPackets(packets) +} - if hdr.IsLongHeader { - if protocol.ByteCount(r.Len()) < hdr.Length { - return fmt.Errorf("packet length (%d bytes) is smaller than the expected length (%d bytes)", len(data)-int(hdr.ParsedLen()), hdr.Length) +func (h *packetHandlerMap) parsePacket( + addr net.Addr, + buffer *packetBuffer, + data []byte, +) ([]*receivedPacket, error) { + rcvTime := time.Now() + packets := make([]*receivedPacket, 0, 1) + + var counter int + var lastConnID protocol.ConnectionID + for len(data) > 0 { + if counter > 0 && h.logger.Debug() { + h.logger.Debugf("Parsed a coalesced packet. Part %d: %d bytes", counter, len(packets[counter-1].data)) } - data = data[:int(hdr.ParsedLen()+hdr.Length)] - // TODO(#1312): implement parsing of compound packets - } - p := &receivedPacket{ - remoteAddr: addr, - hdr: hdr, - rcvTime: time.Now(), - data: data, - buffer: buffer, - } + hdr, err := wire.ParseHeader(bytes.NewReader(data), h.connIDLen) + // drop the packet if we can't parse the header + if err != nil { + return packets, fmt.Errorf("error parsing header: %s", err) + } + if counter > 0 && !hdr.DestConnectionID.Equal(lastConnID) { + return packets, fmt.Errorf("coalesced packet has different destination connection ID: %s, expected %s", hdr.DestConnectionID, lastConnID) + } + lastConnID = hdr.DestConnectionID + var rest []byte + if hdr.IsLongHeader { + if protocol.ByteCount(len(data)) < hdr.Length { + return packets, fmt.Errorf("packet length (%d bytes) is smaller than the expected length (%d bytes)", len(data)-int(hdr.ParsedLen()), hdr.Length) + } + packetLen := int(hdr.ParsedLen() + hdr.Length) + rest = data[packetLen:] + data = data[:packetLen] + } + + if counter > 0 { + buffer.Split() + } + counter++ + packets = append(packets, &receivedPacket{ + remoteAddr: addr, + hdr: hdr, + rcvTime: rcvTime, + data: data, + buffer: buffer, + }) + data = rest + } + return packets, nil +} + +func (h *packetHandlerMap) handleParsedPackets(packets []*receivedPacket) { h.mutex.RLock() defer h.mutex.RUnlock() - handlerEntry, handlerFound := h.handlers[string(hdr.DestConnectionID)] + // coalesced packets all have the same destination connection ID + handlerEntry, handlerFound := h.handlers[string(packets[0].hdr.DestConnectionID)] - if handlerFound { // existing session - handlerEntry.handler.handlePacket(p) - return nil - } - // No session found. - // This might be a stateless reset. - if !hdr.IsLongHeader { - if len(data) >= protocol.MinStatelessResetSize { - var token [16]byte - copy(token[:], data[len(data)-16:]) - if sess, ok := h.resetTokens[token]; ok { - sess.destroy(errors.New("received a stateless reset")) - return nil - } + for _, p := range packets { + if handlerFound { // existing session + handlerEntry.handler.handlePacket(p) + continue } - // TODO(#943): send a stateless reset - return fmt.Errorf("received a short header packet with an unexpected connection ID %s", hdr.DestConnectionID) + // No session found. + // This might be a stateless reset. + if !p.hdr.IsLongHeader { + if len(p.data) >= protocol.MinStatelessResetSize { + var token [16]byte + copy(token[:], p.data[len(p.data)-16:]) + if sess, ok := h.resetTokens[token]; ok { + sess.destroy(errors.New("received a stateless reset")) + continue + } + } + // TODO(#943): send a stateless reset + h.logger.Debugf("received a short header packet with an unexpected connection ID %s", p.hdr.DestConnectionID) + break // a short header packet is always the last in a coalesced packet + + } + if h.server != nil { // no server set + h.server.handlePacket(p) + } + h.logger.Debugf("received a packet with an unexpected connection ID %s", p.hdr.DestConnectionID) } - if h.server == nil { // no server set - return fmt.Errorf("received a packet with an unexpected connection ID %s", hdr.DestConnectionID) - } - h.server.handlePacket(p) - return nil } diff --git a/packet_handler_map_test.go b/packet_handler_map_test.go index 8a0141d3d..6a4e9ecee 100644 --- a/packet_handler_map_test.go +++ b/packet_handler_map_test.go @@ -3,6 +3,7 @@ package quic import ( "bytes" "errors" + "net" "time" "github.com/golang/mock/gomock" @@ -35,7 +36,7 @@ var _ = Describe("Packet Handler Map", func() { } getPacket := func(connID protocol.ConnectionID) []byte { - return getPacketWithLength(connID, 1) + return getPacketWithLength(connID, 2) } BeforeEach(func() { @@ -85,7 +86,7 @@ var _ = Describe("Packet Handler Map", func() { }) It("drops unparseable packets", func() { - err := handler.handlePacket(nil, nil, []byte{0, 1, 2, 3}) + _, err := handler.parsePacket(nil, nil, []byte{0, 1, 2, 3}) Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("error parsing header:")) }) @@ -95,7 +96,8 @@ var _ = Describe("Packet Handler Map", func() { connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} handler.Add(connID, NewMockPacketHandler(mockCtrl)) handler.Remove(connID) - Expect(handler.handlePacket(nil, nil, getPacket(connID))).To(MatchError("received a packet with an unexpected connection ID 0x0102030405060708")) + handler.handlePacket(nil, nil, getPacket(connID)) + // don't EXPECT any calls to handlePacket of the MockPacketHandler }) It("deletes retired session entries after a wait time", func() { @@ -104,7 +106,8 @@ var _ = Describe("Packet Handler Map", func() { handler.Add(connID, NewMockPacketHandler(mockCtrl)) handler.Retire(connID) time.Sleep(scaleDuration(30 * time.Millisecond)) - Expect(handler.handlePacket(nil, nil, getPacket(connID))).To(MatchError("received a packet with an unexpected connection ID 0x0102030405060708")) + handler.handlePacket(nil, nil, getPacket(connID)) + // don't EXPECT any calls to handlePacket of the MockPacketHandler }) It("passes packets arriving late for closed sessions to that session", func() { @@ -114,14 +117,12 @@ var _ = Describe("Packet Handler Map", func() { packetHandler.EXPECT().handlePacket(gomock.Any()) handler.Add(connID, packetHandler) handler.Retire(connID) - err := handler.handlePacket(nil, nil, getPacket(connID)) - Expect(err).ToNot(HaveOccurred()) + handler.handlePacket(nil, nil, getPacket(connID)) }) It("drops packets for unknown receivers", func() { connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - err := handler.handlePacket(nil, nil, getPacket(connID)) - Expect(err).To(MatchError("received a packet with an unexpected connection ID 0x0102030405060708")) + handler.handlePacket(nil, nil, getPacket(connID)) }) It("closes the packet handlers when reading from the conn fails", func() { @@ -136,22 +137,73 @@ var _ = Describe("Packet Handler Map", func() { Eventually(done).Should(BeClosed()) }) - It("errors on packets that are smaller than the length in the packet header", func() { - connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - data := append(getPacketWithLength(connID, 1000), make([]byte, 500-2 /* for packet number length */)...) - err := handler.handlePacket(nil, nil, data) - Expect(err).To(MatchError("packet length (500 bytes) is smaller than the expected length (1000 bytes)")) - }) - - It("cuts packets to the right length", func() { - connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - data := append(getPacketWithLength(connID, 456), make([]byte, 1000)...) - packetHandler := NewMockPacketHandler(mockCtrl) - packetHandler.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { - Expect(p.data).To(HaveLen(456 + int(p.hdr.ParsedLen()))) + Context("coalesced packets", func() { + It("errors on packets that are smaller than the length in the packet header", func() { + connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + data := append(getPacketWithLength(connID, 1000), make([]byte, 500-2 /* for packet number length */)...) + _, err := handler.parsePacket(nil, nil, data) + Expect(err).To(MatchError("packet length (500 bytes) is smaller than the expected length (1000 bytes)")) + }) + + It("cuts packets to the right length", func() { + connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + data := append(getPacketWithLength(connID, 456), make([]byte, 1000)...) + packetHandler := NewMockPacketHandler(mockCtrl) + packetHandler.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { + Expect(p.data).To(HaveLen(456 + int(p.hdr.ParsedLen()))) + }) + handler.Add(connID, packetHandler) + handler.handlePacket(nil, nil, data) + }) + + It("handles coalesced packets", func() { + connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + packetHandler := NewMockPacketHandler(mockCtrl) + handledPackets := make(chan *receivedPacket, 3) + packetHandler.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { + handledPackets <- p + }).Times(3) + handler.Add(connID, packetHandler) + + buffer := getPacketBuffer() + packet := buffer.Slice[:0] + packet = append(packet, append(getPacketWithLength(connID, 10), make([]byte, 10-2 /* packet number len */)...)...) + packet = append(packet, append(getPacketWithLength(connID, 20), make([]byte, 20-2 /* packet number len */)...)...) + packet = append(packet, append(getPacketWithLength(connID, 30), make([]byte, 30-2 /* packet number len */)...)...) + conn.dataToRead <- packet + + now := time.Now() + for i := 1; i <= 3; i++ { + var p *receivedPacket + Eventually(handledPackets).Should(Receive(&p)) + Expect(p.hdr.DestConnectionID).To(Equal(connID)) + Expect(p.hdr.Length).To(BeEquivalentTo(10 * i)) + Expect(p.data).To(HaveLen(int(p.hdr.ParsedLen() + p.hdr.Length))) + Expect(p.rcvTime).To(BeTemporally("~", now, scaleDuration(20*time.Millisecond))) + Expect(p.buffer.refCount).To(Equal(3)) + } + + // makes the listen go routine return + packetHandler.EXPECT().destroy(gomock.Any()).AnyTimes() + close(conn.dataToRead) + }) + + It("ignores coalesced packet parts if the connection IDs don't match", func() { + connID1 := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + connID2 := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1} + + buffer := getPacketBuffer() + packet := buffer.Slice[:0] + // var packet []byte + packet = append(packet, getPacket(connID1)...) + packet = append(packet, getPacket(connID2)...) + + packets, err := handler.parsePacket(&net.UDPAddr{}, buffer, packet) + Expect(err).To(MatchError("coalesced packet has different destination connection ID: 0x0807060504030201, expected 0x0102030405060708")) + Expect(packets).To(HaveLen(1)) + Expect(packets[0].hdr.DestConnectionID).To(Equal(connID1)) + Expect(packets[0].buffer.refCount).To(Equal(1)) }) - handler.Add(connID, packetHandler) - Expect(handler.handlePacket(nil, nil, data)).To(Succeed()) }) }) @@ -186,6 +238,24 @@ var _ = Describe("Packet Handler Map", func() { Eventually(destroyed).Should(BeClosed()) }) + It("detects a stateless that is coalesced with another packet", func() { + packetHandler := NewMockPacketHandler(mockCtrl) + connID := protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad} + token := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} + handler.AddWithResetToken(connID, packetHandler, token) + fakeConnID := protocol.ConnectionID{1, 2, 3, 4, 5} + packet := getPacket(fakeConnID) + reset := append([]byte{0x40} /* short header packet */, fakeConnID...) + reset = append(reset, make([]byte, 50)...) // add some "random" data + reset = append(reset, token[:]...) + destroyed := make(chan struct{}) + packetHandler.EXPECT().destroy(errors.New("received a stateless reset")).Do(func(error) { + close(destroyed) + }) + conn.dataToRead <- append(packet, reset...) + Eventually(destroyed).Should(BeClosed()) + }) + It("deletes reset tokens when the session is retired", func() { handler.deleteRetiredSessionsAfter = scaleDuration(10 * time.Millisecond) connID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0x42} @@ -193,10 +263,12 @@ var _ = Describe("Packet Handler Map", func() { handler.AddWithResetToken(connID, NewMockPacketHandler(mockCtrl), token) handler.Retire(connID) time.Sleep(scaleDuration(30 * time.Millisecond)) - Expect(handler.handlePacket(nil, nil, getPacket(connID))).To(MatchError("received a packet with an unexpected connection ID 0xdeadbeef42")) + handler.handlePacket(nil, nil, getPacket(connID)) + // don't EXPECT any calls to handlePacket of the MockPacketHandler packet := append([]byte{0x40, 0xde, 0xca, 0xfb, 0xad, 0x99} /* short header packet */, make([]byte, 50)...) packet = append(packet, token[:]...) - Expect(handler.handlePacket(nil, nil, packet)).To(MatchError("received a short header packet with an unexpected connection ID 0xdecafbad99")) + handler.handlePacket(nil, nil, packet) + // don't EXPECT any calls to handlePacket of the MockPacketHandler Expect(handler.resetTokens).To(BeEmpty()) }) }) @@ -210,7 +282,7 @@ var _ = Describe("Packet Handler Map", func() { Expect(p.hdr.DestConnectionID).To(Equal(connID)) }) handler.SetServer(server) - Expect(handler.handlePacket(nil, nil, p)).To(Succeed()) + handler.handlePacket(nil, nil, p) }) It("closes all server sessions", func() { @@ -229,9 +301,10 @@ var _ = Describe("Packet Handler Map", func() { connID := protocol.ConnectionID{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88} p := getPacket(connID) server := NewMockUnknownPacketHandler(mockCtrl) + // don't EXPECT any calls to server.handlePacket handler.SetServer(server) handler.CloseServer() - Expect(handler.handlePacket(nil, nil, p)).To(MatchError("received a packet with an unexpected connection ID 0x1122334455667788")) + handler.handlePacket(nil, nil, p) }) }) }) From b32fb438afe2a3497bb5d39fcc2b3aa2dc216f3d Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sun, 30 Dec 2018 15:28:52 +0700 Subject: [PATCH 7/7] make putting back packet buffers a member function of packetBuffer --- buffer_pool.go | 31 +++++++++++++++++-------------- buffer_pool_test.go | 22 +++++++++++----------- packet_handler_map.go | 2 +- server.go | 8 ++++---- session.go | 4 ++-- 5 files changed, 35 insertions(+), 32 deletions(-) diff --git a/buffer_pool.go b/buffer_pool.go index c890d32ba..204eff2bc 100644 --- a/buffer_pool.go +++ b/buffer_pool.go @@ -22,6 +22,23 @@ func (b *packetBuffer) Split() { b.refCount++ } +// Release decreases the refCount. +// It should be called when processing the packet is finished. +// When the refCount reaches 0, the packet buffer is put back into the pool. +func (b *packetBuffer) Release() { + if cap(b.Slice) != int(protocol.MaxReceivePacketSize) { + panic("putPacketBuffer called with packet of wrong size!") + } + b.refCount-- + if b.refCount < 0 { + panic("negative packetBuffer refCount") + } + // only put the packetBuffer back if it's not used any more + if b.refCount == 0 { + bufferPool.Put(b) + } +} + var bufferPool sync.Pool func getPacketBuffer() *packetBuffer { @@ -31,20 +48,6 @@ func getPacketBuffer() *packetBuffer { return buf } -func putPacketBuffer(buf *packetBuffer) { - if cap(buf.Slice) != int(protocol.MaxReceivePacketSize) { - panic("putPacketBuffer called with packet of wrong size!") - } - buf.refCount-- - if buf.refCount < 0 { - panic("negative packetBuffer refCount") - } - // only put the packetBuffer back if it's not used any more - if buf.refCount == 0 { - bufferPool.Put(buf) - } -} - func init() { bufferPool.New = func() interface{} { return &packetBuffer{ diff --git a/buffer_pool_test.go b/buffer_pool_test.go index ef6b48523..c49141d90 100644 --- a/buffer_pool_test.go +++ b/buffer_pool_test.go @@ -13,31 +13,31 @@ var _ = Describe("Buffer Pool", func() { Expect(buf.Slice).To(HaveCap(int(protocol.MaxReceivePacketSize))) }) - It("puts buffers back", func() { + It("releases buffers", func() { buf := getPacketBuffer() - putPacketBuffer(buf) + buf.Release() }) It("panics if wrong-sized buffers are passed", func() { buf := getPacketBuffer() buf.Slice = make([]byte, 10) - Expect(func() { putPacketBuffer(buf) }).To(Panic()) + Expect(func() { buf.Release() }).To(Panic()) }) - It("panics if it is put pack twice", func() { + It("panics if it is released twice", func() { buf := getPacketBuffer() - putPacketBuffer(buf) - Expect(func() { putPacketBuffer(buf) }).To(Panic()) + buf.Release() + Expect(func() { buf.Release() }).To(Panic()) }) - It("waits until all parts have been put back", func() { + It("waits until all parts have been released", func() { buf := getPacketBuffer() buf.Split() buf.Split() // now we have 3 parts - putPacketBuffer(buf) - putPacketBuffer(buf) - putPacketBuffer(buf) - Expect(func() { putPacketBuffer(buf) }).To(Panic()) + buf.Release() + buf.Release() + buf.Release() + Expect(func() { buf.Release() }).To(Panic()) }) }) diff --git a/packet_handler_map.go b/packet_handler_map.go index 5c2a1948e..ba810b44f 100644 --- a/packet_handler_map.go +++ b/packet_handler_map.go @@ -169,7 +169,7 @@ func (h *packetHandlerMap) handlePacket( // We still need to process the packets that were successfully parsed before. } if len(packets) == 0 { - putPacketBuffer(buffer) + buffer.Release() return } h.handleParsedPackets(packets) diff --git a/server.go b/server.go index c665283db..3fb1ca7a3 100644 --- a/server.go +++ b/server.go @@ -321,20 +321,20 @@ func (s *server) handlePacket(p *receivedPacket) { return } - putPacketBuffer(p.buffer) // TODO(#943): send Stateless Reset + p.buffer.Release() } func (s *server) handleInitial(p *receivedPacket) { s.logger.Debugf("<- Received Initial packet.") sess, connID, err := s.handleInitialImpl(p) if err != nil { - putPacketBuffer(p.buffer) + p.buffer.Release() s.logger.Errorf("Error occurred handling initial packet: %s", err) return } if sess == nil { // a retry was done - putPacketBuffer(p.buffer) + p.buffer.Release() return } // Don't put the packet buffer back if a new session was created. @@ -461,7 +461,7 @@ func (s *server) sendRetry(remoteAddr net.Addr, hdr *wire.Header) error { } func (s *server) sendVersionNegotiationPacket(p *receivedPacket) { - defer putPacketBuffer(p.buffer) + defer p.buffer.Release() hdr := p.hdr s.logger.Debugf("Client offered version %s, sending Version Negotiation", hdr.Version) data, err := wire.ComposeVersionNegotiation(hdr.SrcConnectionID, hdr.DestConnectionID, s.config.Versions) diff --git a/session.go b/session.go index 511a2b24e..63aa3d9d7 100644 --- a/session.go +++ b/session.go @@ -479,7 +479,7 @@ func (s *session) handlePacketImpl(p *receivedPacket) bool /* was the packet suc defer func() { // Put back the packet buffer if the packet wasn't queued for later decryption. if !wasQueued { - putPacketBuffer(p.buffer) + p.buffer.Release() } }() @@ -962,7 +962,7 @@ func (s *session) sendPacket() (bool, error) { } func (s *session) sendPackedPacket(packet *packedPacket) error { - defer putPacketBuffer(packet.buffer) + defer packet.buffer.Release() s.logPacket(packet) return s.conn.Write(packet.raw) }