diff --git a/buffer_pool.go b/buffer_pool.go index 6b233696f..204eff2bc 100644 --- a/buffer_pool.go +++ b/buffer_pool.go @@ -6,22 +6,52 @@ import ( "github.com/lucas-clemente/quic-go/internal/protocol" ) -var bufferPool sync.Pool +type packetBuffer struct { + Slice []byte -func getPacketBuffer() *[]byte { - return bufferPool.Get().(*[]byte) + // refCount counts how many packets the Slice is used in. + // It doesn't support concurrent use. + // It is > 1 when used for coalesced packet. + refCount int } -func putPacketBuffer(buf *[]byte) { - if cap(*buf) != int(protocol.MaxReceivePacketSize) { +// Split increases the refCount. +// It must be called when a packet buffer is used for more than one packet, +// e.g. when splitting coalesced packets. +func (b *packetBuffer) Split() { + b.refCount++ +} + +// Release decreases the refCount. +// It should be called when processing the packet is finished. +// When the refCount reaches 0, the packet buffer is put back into the pool. +func (b *packetBuffer) Release() { + if cap(b.Slice) != int(protocol.MaxReceivePacketSize) { panic("putPacketBuffer called with packet of wrong size!") } - bufferPool.Put(buf) + b.refCount-- + if b.refCount < 0 { + panic("negative packetBuffer refCount") + } + // only put the packetBuffer back if it's not used any more + if b.refCount == 0 { + bufferPool.Put(b) + } +} + +var bufferPool sync.Pool + +func getPacketBuffer() *packetBuffer { + buf := bufferPool.Get().(*packetBuffer) + buf.refCount = 1 + buf.Slice = buf.Slice[:protocol.MaxReceivePacketSize] + return buf } func init() { bufferPool.New = func() interface{} { - b := make([]byte, 0, protocol.MaxReceivePacketSize) - return &b + return &packetBuffer{ + Slice: make([]byte, 0, protocol.MaxReceivePacketSize), + } } } diff --git a/buffer_pool_test.go b/buffer_pool_test.go index 7413ceaf6..c49141d90 100644 --- a/buffer_pool_test.go +++ b/buffer_pool_test.go @@ -9,13 +9,35 @@ import ( var _ = Describe("Buffer Pool", func() { It("returns buffers of cap", func() { - buf := *getPacketBuffer() - Expect(buf).To(HaveCap(int(protocol.MaxReceivePacketSize))) + buf := getPacketBuffer() + Expect(buf.Slice).To(HaveCap(int(protocol.MaxReceivePacketSize))) + }) + + It("releases buffers", func() { + buf := getPacketBuffer() + buf.Release() }) It("panics if wrong-sized buffers are passed", func() { - Expect(func() { - putPacketBuffer(&[]byte{0}) - }).To(Panic()) + buf := getPacketBuffer() + buf.Slice = make([]byte, 10) + Expect(func() { buf.Release() }).To(Panic()) + }) + + It("panics if it is released twice", func() { + buf := getPacketBuffer() + buf.Release() + Expect(func() { buf.Release() }).To(Panic()) + }) + + It("waits until all parts have been released", func() { + buf := getPacketBuffer() + buf.Split() + buf.Split() + // now we have 3 parts + buf.Release() + buf.Release() + buf.Release() + Expect(func() { buf.Release() }).To(Panic()) }) }) diff --git a/packet_handler_map.go b/packet_handler_map.go index 7e411df04..ba810b44f 100644 --- a/packet_handler_map.go +++ b/packet_handler_map.go @@ -144,8 +144,8 @@ func (h *packetHandlerMap) close(e error) error { func (h *packetHandlerMap) listen() { for { - data := *getPacketBuffer() - data = data[:protocol.MaxReceivePacketSize] + buffer := getPacketBuffer() + data := buffer.Slice // The packet size should not exceed protocol.MaxReceivePacketSize bytes // If it does, we only read a truncated packet, which will then end up undecryptable n, addr, err := h.conn.ReadFrom(data) @@ -153,55 +153,110 @@ func (h *packetHandlerMap) listen() { h.close(err) return } - data = data[:n] - - if err := h.handlePacket(addr, data); err != nil { - h.logger.Debugf("error handling packet from %s: %s", addr, err) - } + h.handlePacket(addr, buffer, data[:n]) } } -func (h *packetHandlerMap) handlePacket(addr net.Addr, data []byte) error { - r := bytes.NewReader(data) - hdr, err := wire.ParseHeader(r, h.connIDLen) - // drop the packet if we can't parse the header +func (h *packetHandlerMap) handlePacket( + addr net.Addr, + buffer *packetBuffer, + data []byte, +) { + packets, err := h.parsePacket(addr, buffer, data) if err != nil { - return fmt.Errorf("error parsing header: %s", err) + h.logger.Debugf("error parsing packets from %s: %s", addr, err) + // This is just the error from parsing the last packet. + // We still need to process the packets that were successfully parsed before. } - - p := &receivedPacket{ - remoteAddr: addr, - hdr: hdr, - data: data, - rcvTime: time.Now(), + if len(packets) == 0 { + buffer.Release() + return } + h.handleParsedPackets(packets) +} +func (h *packetHandlerMap) parsePacket( + addr net.Addr, + buffer *packetBuffer, + data []byte, +) ([]*receivedPacket, error) { + rcvTime := time.Now() + packets := make([]*receivedPacket, 0, 1) + + var counter int + var lastConnID protocol.ConnectionID + for len(data) > 0 { + if counter > 0 && h.logger.Debug() { + h.logger.Debugf("Parsed a coalesced packet. Part %d: %d bytes", counter, len(packets[counter-1].data)) + } + + hdr, err := wire.ParseHeader(bytes.NewReader(data), h.connIDLen) + // drop the packet if we can't parse the header + if err != nil { + return packets, fmt.Errorf("error parsing header: %s", err) + } + if counter > 0 && !hdr.DestConnectionID.Equal(lastConnID) { + return packets, fmt.Errorf("coalesced packet has different destination connection ID: %s, expected %s", hdr.DestConnectionID, lastConnID) + } + lastConnID = hdr.DestConnectionID + + var rest []byte + if hdr.IsLongHeader { + if protocol.ByteCount(len(data)) < hdr.Length { + return packets, fmt.Errorf("packet length (%d bytes) is smaller than the expected length (%d bytes)", len(data)-int(hdr.ParsedLen()), hdr.Length) + } + packetLen := int(hdr.ParsedLen() + hdr.Length) + rest = data[packetLen:] + data = data[:packetLen] + } + + if counter > 0 { + buffer.Split() + } + counter++ + packets = append(packets, &receivedPacket{ + remoteAddr: addr, + hdr: hdr, + rcvTime: rcvTime, + data: data, + buffer: buffer, + }) + data = rest + } + return packets, nil +} + +func (h *packetHandlerMap) handleParsedPackets(packets []*receivedPacket) { h.mutex.RLock() defer h.mutex.RUnlock() - handlerEntry, handlerFound := h.handlers[string(hdr.DestConnectionID)] + // coalesced packets all have the same destination connection ID + handlerEntry, handlerFound := h.handlers[string(packets[0].hdr.DestConnectionID)] - if handlerFound { // existing session - handlerEntry.handler.handlePacket(p) - return nil - } - // No session found. - // This might be a stateless reset. - if !hdr.IsLongHeader { - if len(data) >= protocol.MinStatelessResetSize { - var token [16]byte - copy(token[:], data[len(data)-16:]) - if sess, ok := h.resetTokens[token]; ok { - sess.destroy(errors.New("received a stateless reset")) - return nil - } + for _, p := range packets { + if handlerFound { // existing session + handlerEntry.handler.handlePacket(p) + continue } - // TODO(#943): send a stateless reset - return fmt.Errorf("received a short header packet with an unexpected connection ID %s", hdr.DestConnectionID) + // No session found. + // This might be a stateless reset. + if !p.hdr.IsLongHeader { + if len(p.data) >= protocol.MinStatelessResetSize { + var token [16]byte + copy(token[:], p.data[len(p.data)-16:]) + if sess, ok := h.resetTokens[token]; ok { + sess.destroy(errors.New("received a stateless reset")) + continue + } + } + // TODO(#943): send a stateless reset + h.logger.Debugf("received a short header packet with an unexpected connection ID %s", p.hdr.DestConnectionID) + break // a short header packet is always the last in a coalesced packet + + } + if h.server != nil { // no server set + h.server.handlePacket(p) + } + h.logger.Debugf("received a packet with an unexpected connection ID %s", p.hdr.DestConnectionID) } - if h.server == nil { // no server set - return fmt.Errorf("received a packet with an unexpected connection ID %s", hdr.DestConnectionID) - } - h.server.handlePacket(p) - return nil } diff --git a/packet_handler_map_test.go b/packet_handler_map_test.go index a955259b4..6a4e9ecee 100644 --- a/packet_handler_map_test.go +++ b/packet_handler_map_test.go @@ -3,6 +3,7 @@ package quic import ( "bytes" "errors" + "net" "time" "github.com/golang/mock/gomock" @@ -19,21 +20,25 @@ var _ = Describe("Packet Handler Map", func() { conn *mockPacketConn ) - getPacket := func(connID protocol.ConnectionID) []byte { + getPacketWithLength := func(connID protocol.ConnectionID, length protocol.ByteCount) []byte { buf := &bytes.Buffer{} Expect((&wire.ExtendedHeader{ Header: wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeHandshake, DestConnectionID: connID, - Length: 1, + Length: length, Version: protocol.VersionTLS, }, - PacketNumberLen: protocol.PacketNumberLen1, + PacketNumberLen: protocol.PacketNumberLen2, }).Write(buf, protocol.VersionWhatever)).To(Succeed()) return buf.Bytes() } + getPacket := func(connID protocol.ConnectionID) []byte { + return getPacketWithLength(connID, 2) + } + BeforeEach(func() { conn = newMockPacketConn() handler = newPacketHandlerMap(conn, 5, utils.DefaultLogger).(*packetHandlerMap) @@ -81,7 +86,7 @@ var _ = Describe("Packet Handler Map", func() { }) It("drops unparseable packets", func() { - err := handler.handlePacket(nil, []byte{0, 1, 2, 3}) + _, err := handler.parsePacket(nil, nil, []byte{0, 1, 2, 3}) Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("error parsing header:")) }) @@ -91,7 +96,8 @@ var _ = Describe("Packet Handler Map", func() { connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} handler.Add(connID, NewMockPacketHandler(mockCtrl)) handler.Remove(connID) - Expect(handler.handlePacket(nil, getPacket(connID))).To(MatchError("received a packet with an unexpected connection ID 0x0102030405060708")) + handler.handlePacket(nil, nil, getPacket(connID)) + // don't EXPECT any calls to handlePacket of the MockPacketHandler }) It("deletes retired session entries after a wait time", func() { @@ -100,7 +106,8 @@ var _ = Describe("Packet Handler Map", func() { handler.Add(connID, NewMockPacketHandler(mockCtrl)) handler.Retire(connID) time.Sleep(scaleDuration(30 * time.Millisecond)) - Expect(handler.handlePacket(nil, getPacket(connID))).To(MatchError("received a packet with an unexpected connection ID 0x0102030405060708")) + handler.handlePacket(nil, nil, getPacket(connID)) + // don't EXPECT any calls to handlePacket of the MockPacketHandler }) It("passes packets arriving late for closed sessions to that session", func() { @@ -110,14 +117,12 @@ var _ = Describe("Packet Handler Map", func() { packetHandler.EXPECT().handlePacket(gomock.Any()) handler.Add(connID, packetHandler) handler.Retire(connID) - err := handler.handlePacket(nil, getPacket(connID)) - Expect(err).ToNot(HaveOccurred()) + handler.handlePacket(nil, nil, getPacket(connID)) }) It("drops packets for unknown receivers", func() { connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - err := handler.handlePacket(nil, getPacket(connID)) - Expect(err).To(MatchError("received a packet with an unexpected connection ID 0x0102030405060708")) + handler.handlePacket(nil, nil, getPacket(connID)) }) It("closes the packet handlers when reading from the conn fails", func() { @@ -131,6 +136,75 @@ var _ = Describe("Packet Handler Map", func() { conn.Close() Eventually(done).Should(BeClosed()) }) + + Context("coalesced packets", func() { + It("errors on packets that are smaller than the length in the packet header", func() { + connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + data := append(getPacketWithLength(connID, 1000), make([]byte, 500-2 /* for packet number length */)...) + _, err := handler.parsePacket(nil, nil, data) + Expect(err).To(MatchError("packet length (500 bytes) is smaller than the expected length (1000 bytes)")) + }) + + It("cuts packets to the right length", func() { + connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + data := append(getPacketWithLength(connID, 456), make([]byte, 1000)...) + packetHandler := NewMockPacketHandler(mockCtrl) + packetHandler.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { + Expect(p.data).To(HaveLen(456 + int(p.hdr.ParsedLen()))) + }) + handler.Add(connID, packetHandler) + handler.handlePacket(nil, nil, data) + }) + + It("handles coalesced packets", func() { + connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + packetHandler := NewMockPacketHandler(mockCtrl) + handledPackets := make(chan *receivedPacket, 3) + packetHandler.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { + handledPackets <- p + }).Times(3) + handler.Add(connID, packetHandler) + + buffer := getPacketBuffer() + packet := buffer.Slice[:0] + packet = append(packet, append(getPacketWithLength(connID, 10), make([]byte, 10-2 /* packet number len */)...)...) + packet = append(packet, append(getPacketWithLength(connID, 20), make([]byte, 20-2 /* packet number len */)...)...) + packet = append(packet, append(getPacketWithLength(connID, 30), make([]byte, 30-2 /* packet number len */)...)...) + conn.dataToRead <- packet + + now := time.Now() + for i := 1; i <= 3; i++ { + var p *receivedPacket + Eventually(handledPackets).Should(Receive(&p)) + Expect(p.hdr.DestConnectionID).To(Equal(connID)) + Expect(p.hdr.Length).To(BeEquivalentTo(10 * i)) + Expect(p.data).To(HaveLen(int(p.hdr.ParsedLen() + p.hdr.Length))) + Expect(p.rcvTime).To(BeTemporally("~", now, scaleDuration(20*time.Millisecond))) + Expect(p.buffer.refCount).To(Equal(3)) + } + + // makes the listen go routine return + packetHandler.EXPECT().destroy(gomock.Any()).AnyTimes() + close(conn.dataToRead) + }) + + It("ignores coalesced packet parts if the connection IDs don't match", func() { + connID1 := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + connID2 := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1} + + buffer := getPacketBuffer() + packet := buffer.Slice[:0] + // var packet []byte + packet = append(packet, getPacket(connID1)...) + packet = append(packet, getPacket(connID2)...) + + packets, err := handler.parsePacket(&net.UDPAddr{}, buffer, packet) + Expect(err).To(MatchError("coalesced packet has different destination connection ID: 0x0807060504030201, expected 0x0102030405060708")) + Expect(packets).To(HaveLen(1)) + Expect(packets[0].hdr.DestConnectionID).To(Equal(connID1)) + Expect(packets[0].buffer.refCount).To(Equal(1)) + }) + }) }) Context("stateless reset handling", func() { @@ -164,6 +238,24 @@ var _ = Describe("Packet Handler Map", func() { Eventually(destroyed).Should(BeClosed()) }) + It("detects a stateless that is coalesced with another packet", func() { + packetHandler := NewMockPacketHandler(mockCtrl) + connID := protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad} + token := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} + handler.AddWithResetToken(connID, packetHandler, token) + fakeConnID := protocol.ConnectionID{1, 2, 3, 4, 5} + packet := getPacket(fakeConnID) + reset := append([]byte{0x40} /* short header packet */, fakeConnID...) + reset = append(reset, make([]byte, 50)...) // add some "random" data + reset = append(reset, token[:]...) + destroyed := make(chan struct{}) + packetHandler.EXPECT().destroy(errors.New("received a stateless reset")).Do(func(error) { + close(destroyed) + }) + conn.dataToRead <- append(packet, reset...) + Eventually(destroyed).Should(BeClosed()) + }) + It("deletes reset tokens when the session is retired", func() { handler.deleteRetiredSessionsAfter = scaleDuration(10 * time.Millisecond) connID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0x42} @@ -171,10 +263,12 @@ var _ = Describe("Packet Handler Map", func() { handler.AddWithResetToken(connID, NewMockPacketHandler(mockCtrl), token) handler.Retire(connID) time.Sleep(scaleDuration(30 * time.Millisecond)) - Expect(handler.handlePacket(nil, getPacket(connID))).To(MatchError("received a packet with an unexpected connection ID 0xdeadbeef42")) + handler.handlePacket(nil, nil, getPacket(connID)) + // don't EXPECT any calls to handlePacket of the MockPacketHandler packet := append([]byte{0x40, 0xde, 0xca, 0xfb, 0xad, 0x99} /* short header packet */, make([]byte, 50)...) packet = append(packet, token[:]...) - Expect(handler.handlePacket(nil, packet)).To(MatchError("received a short header packet with an unexpected connection ID 0xdecafbad99")) + handler.handlePacket(nil, nil, packet) + // don't EXPECT any calls to handlePacket of the MockPacketHandler Expect(handler.resetTokens).To(BeEmpty()) }) }) @@ -188,7 +282,7 @@ var _ = Describe("Packet Handler Map", func() { Expect(p.hdr.DestConnectionID).To(Equal(connID)) }) handler.SetServer(server) - Expect(handler.handlePacket(nil, p)).To(Succeed()) + handler.handlePacket(nil, nil, p) }) It("closes all server sessions", func() { @@ -207,9 +301,10 @@ var _ = Describe("Packet Handler Map", func() { connID := protocol.ConnectionID{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88} p := getPacket(connID) server := NewMockUnknownPacketHandler(mockCtrl) + // don't EXPECT any calls to server.handlePacket handler.SetServer(server) handler.CloseServer() - Expect(handler.handlePacket(nil, p)).To(MatchError("received a packet with an unexpected connection ID 0x1122334455667788")) + handler.handlePacket(nil, nil, p) }) }) }) diff --git a/packet_packer.go b/packet_packer.go index 25e3cdf18..aa55262b8 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -25,10 +25,25 @@ type packer interface { } type packedPacket struct { - header *wire.ExtendedHeader - raw []byte - frames []wire.Frame - encryptionLevel protocol.EncryptionLevel + header *wire.ExtendedHeader + raw []byte + frames []wire.Frame + + buffer *packetBuffer +} + +func (p *packedPacket) EncryptionLevel() protocol.EncryptionLevel { + if !p.header.IsLongHeader { + return protocol.Encryption1RTT + } + switch p.header.Type { + case protocol.PacketTypeInitial: + return protocol.EncryptionInitial + case protocol.PacketTypeHandshake: + return protocol.EncryptionHandshake + default: + return protocol.EncryptionUnspecified + } } func (p *packedPacket) ToAckHandlerPacket() *ackhandler.Packet { @@ -37,7 +52,7 @@ func (p *packedPacket) ToAckHandlerPacket() *ackhandler.Packet { PacketType: p.header.Type, Frames: p.frames, Length: protocol.ByteCount(len(p.raw)), - EncryptionLevel: p.encryptionLevel, + EncryptionLevel: p.EncryptionLevel(), SendTime: time.Now(), } } @@ -136,13 +151,7 @@ func (p *packetPacker) PackConnectionClose(ccf *wire.ConnectionCloseFrame) (*pac frames := []wire.Frame{ccf} encLevel, sealer := p.cryptoSetup.GetSealer() header := p.getHeader(encLevel) - raw, err := p.writeAndSealPacket(header, frames, sealer) - return &packedPacket{ - header: header, - raw: raw, - frames: frames, - encryptionLevel: encLevel, - }, err + return p.writeAndSealPacket(header, frames, sealer) } func (p *packetPacker) MaybePackAckPacket() (*packedPacket, error) { @@ -154,13 +163,7 @@ func (p *packetPacker) MaybePackAckPacket() (*packedPacket, error) { encLevel, sealer := p.cryptoSetup.GetSealer() header := p.getHeader(encLevel) frames := []wire.Frame{ack} - raw, err := p.writeAndSealPacket(header, frames, sealer) - return &packedPacket{ - header: header, - raw: raw, - frames: frames, - encryptionLevel: encLevel, - }, err + return p.writeAndSealPacket(header, frames, sealer) } // PackRetransmission packs a retransmission @@ -227,16 +230,11 @@ func (p *packetPacker) PackRetransmission(packet *ackhandler.Packet) ([]*packedP if sf, ok := frames[len(frames)-1].(*wire.StreamFrame); ok { sf.DataLenPresent = false } - raw, err := p.writeAndSealPacket(header, frames, sealer) + p, err := p.writeAndSealPacket(header, frames, sealer) if err != nil { return nil, err } - packets = append(packets, &packedPacket{ - header: header, - raw: raw, - frames: frames, - encryptionLevel: encLevel, - }) + packets = append(packets, p) } return packets, nil } @@ -281,16 +279,7 @@ func (p *packetPacker) PackPacket() (*packedPacket, error) { p.numNonRetransmittableAcks = 0 } - raw, err := p.writeAndSealPacket(header, frames, sealer) - if err != nil { - return nil, err - } - return &packedPacket{ - header: header, - raw: raw, - frames: frames, - encryptionLevel: encLevel, - }, nil + return p.writeAndSealPacket(header, frames, sealer) } func (p *packetPacker) maybePackCryptoPacket() (*packedPacket, error) { @@ -320,16 +309,7 @@ func (p *packetPacker) maybePackCryptoPacket() (*packedPacket, error) { } cf := s.PopCryptoFrame(p.maxPacketSize - hdrLen - protocol.ByteCount(sealer.Overhead()) - length) frames = append(frames, cf) - raw, err := p.writeAndSealPacket(hdr, frames, sealer) - if err != nil { - return nil, err - } - return &packedPacket{ - header: hdr, - raw: raw, - frames: frames, - encryptionLevel: encLevel, - }, nil + return p.writeAndSealPacket(hdr, frames, sealer) } func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount) ([]wire.Frame, error) { @@ -395,9 +375,9 @@ func (p *packetPacker) writeAndSealPacket( header *wire.ExtendedHeader, frames []wire.Frame, sealer handshake.Sealer, -) ([]byte, error) { - raw := *getPacketBuffer() - buffer := bytes.NewBuffer(raw[:0]) +) (*packedPacket, error) { + packetBuffer := getPacketBuffer() + buffer := bytes.NewBuffer(packetBuffer.Slice[:0]) addPaddingForInitial := p.perspective == protocol.PerspectiveClient && header.Type == protocol.PacketTypeInitial @@ -458,7 +438,7 @@ func (p *packetPacker) writeAndSealPacket( return nil, fmt.Errorf("PacketPacker BUG: packet too large (%d bytes, allowed %d bytes)", size, p.maxPacketSize) } - raw = raw[0:buffer.Len()] + raw := buffer.Bytes() _ = sealer.Seal(raw[payloadOffset:payloadOffset], raw[payloadOffset:], header.PacketNumber, raw[:payloadOffset]) raw = raw[0 : buffer.Len()+sealer.Overhead()] @@ -473,7 +453,12 @@ func (p *packetPacker) writeAndSealPacket( if num != header.PacketNumber { return nil, errors.New("packetPacker BUG: Peeked and Popped packet numbers do not match") } - return raw, nil + return &packedPacket{ + header: header, + raw: raw, + frames: frames, + buffer: packetBuffer, + }, nil } func (p *packetPacker) ChangeDestConnectionID(connID protocol.ConnectionID) { diff --git a/packet_packer_test.go b/packet_packer_test.go index f262f703a..1971700ed 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -253,7 +253,7 @@ var _ = Describe("Packet packer", func() { }) p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) - Expect(p.encryptionLevel).To(Equal(protocol.Encryption1RTT)) + Expect(p.EncryptionLevel()).To(Equal(protocol.Encryption1RTT)) }) It("packs a single ACK", func() { @@ -494,7 +494,7 @@ var _ = Describe("Packet packer", func() { Expect(err).ToNot(HaveOccurred()) Expect(packets).To(HaveLen(1)) p := packets[0] - Expect(p.encryptionLevel).To(Equal(protocol.Encryption1RTT)) + Expect(p.EncryptionLevel()).To(Equal(protocol.Encryption1RTT)) Expect(p.frames).To(Equal(frames)) }) @@ -846,7 +846,7 @@ var _ = Describe("Packet packer", func() { Expect(p).To(HaveLen(1)) Expect(p[0].header.Type).To(Equal(protocol.PacketTypeInitial)) Expect(p[0].frames).To(Equal([]wire.Frame{f})) - Expect(p[0].encryptionLevel).To(Equal(protocol.EncryptionInitial)) + Expect(p[0].EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) }) It("packs a retransmission for an Initial packet", func() { @@ -864,7 +864,7 @@ var _ = Describe("Packet packer", func() { Expect(packets).To(HaveLen(1)) p := packets[0] Expect(p.frames).To(Equal([]wire.Frame{sf})) - Expect(p.encryptionLevel).To(Equal(protocol.EncryptionInitial)) + Expect(p.EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) Expect(p.header.Type).To(Equal(protocol.PacketTypeInitial)) Expect(p.header.Token).To(Equal(token)) Expect(p.raw).To(HaveLen(protocol.MinInitialPacketSize)) diff --git a/packet_unpacker.go b/packet_unpacker.go index 5cde3fa7e..cda222324 100644 --- a/packet_unpacker.go +++ b/packet_unpacker.go @@ -39,14 +39,6 @@ func newPacketUnpacker(cs handshake.CryptoSetup, version protocol.VersionNumber) func (u *packetUnpacker) Unpack(hdr *wire.Header, data []byte) (*unpackedPacket, error) { r := bytes.NewReader(data) - if hdr.IsLongHeader { - if protocol.ByteCount(r.Len()) < hdr.Length { - return nil, fmt.Errorf("packet length (%d bytes) is smaller than the expected length (%d bytes)", len(data)-int(hdr.ParsedLen()), hdr.Length) - } - data = data[:int(hdr.ParsedLen()+hdr.Length)] - // TODO(#1312): implement parsing of compound packets - } - var encLevel protocol.EncryptionLevel switch hdr.Type { case protocol.PacketTypeInitial: @@ -93,11 +85,7 @@ func (u *packetUnpacker) Unpack(hdr *wire.Header, data []byte) (*unpackedPacket, extHdr.PacketNumber, ) - buf := *getPacketBuffer() - buf = buf[:0] - defer putPacketBuffer(&buf) - - decrypted, err := opener.Open(buf, data, pn, extHdr.Raw) + decrypted, err := opener.Open(data[:0], data, pn, extHdr.Raw) if err != nil { return nil, err } diff --git a/packet_unpacker_test.go b/packet_unpacker_test.go index 8aa4e446b..c9a085d08 100644 --- a/packet_unpacker_test.go +++ b/packet_unpacker_test.go @@ -75,49 +75,6 @@ var _ = Describe("Packet Unpacker", func() { Expect(packet.encryptionLevel).To(Equal(protocol.EncryptionInitial)) }) - It("errors on packets that are smaller than the length in the packet header", func() { - extHdr := &wire.ExtendedHeader{ - Header: wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeHandshake, - Length: 1000, - DestConnectionID: connID, - Version: version, - }, - PacketNumberLen: protocol.PacketNumberLen2, - } - hdr, hdrRaw := getHeader(extHdr) - data := append(hdrRaw, make([]byte, 500-2 /* for packet number length */)...) - _, err := unpacker.Unpack(hdr, data) - Expect(err).To(MatchError("packet length (500 bytes) is smaller than the expected length (1000 bytes)")) - }) - - It("cuts packets to the right length", func() { - pnLen := protocol.PacketNumberLen2 - extHdr := &wire.ExtendedHeader{ - Header: wire.Header{ - IsLongHeader: true, - DestConnectionID: connID, - Type: protocol.PacketTypeHandshake, - Length: 456, - Version: protocol.VersionTLS, - }, - PacketNumberLen: pnLen, - } - payloadLen := 456 - int(pnLen) - hdr, hdrRaw := getHeader(extHdr) - data := append(hdrRaw, make([]byte, payloadLen)...) - opener := mocks.NewMockOpener(mockCtrl) - cs.EXPECT().GetOpener(protocol.EncryptionHandshake).Return(opener, nil) - opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()) - opener.EXPECT().Open(gomock.Any(), gomock.Any(), extHdr.PacketNumber, hdrRaw).DoAndReturn(func(_, payload []byte, _ protocol.PacketNumber, _ []byte) ([]byte, error) { - Expect(payload).To(HaveLen(payloadLen)) - return []byte{0}, nil - }) - _, err := unpacker.Unpack(hdr, data) - Expect(err).ToNot(HaveOccurred()) - }) - It("returns the error when getting the sealer fails", func() { extHdr := &wire.ExtendedHeader{ Header: wire.Header{DestConnectionID: connID}, diff --git a/server.go b/server.go index 14b43cff4..3fb1ca7a3 100644 --- a/server.go +++ b/server.go @@ -318,21 +318,27 @@ func (s *server) handlePacket(p *receivedPacket) { } if hdr.Type == protocol.PacketTypeInitial { go s.handleInitial(p) + return } + // TODO(#943): send Stateless Reset + p.buffer.Release() } func (s *server) handleInitial(p *receivedPacket) { - // TODO: add a check that DestConnID == SrcConnID s.logger.Debugf("<- Received Initial packet.") sess, connID, err := s.handleInitialImpl(p) if err != nil { + p.buffer.Release() s.logger.Errorf("Error occurred handling initial packet: %s", err) return } if sess == nil { // a retry was done + p.buffer.Release() return } + // Don't put the packet buffer back if a new session was created. + // The session will handle the packet and take of that. serverSession := newServerSession(sess, s.config, s.logger) s.sessionHandler.Add(connID, serverSession) } @@ -455,6 +461,7 @@ func (s *server) sendRetry(remoteAddr net.Addr, hdr *wire.Header) error { } func (s *server) sendVersionNegotiationPacket(p *receivedPacket) { + defer p.buffer.Release() hdr := p.hdr s.logger.Debugf("Client offered version %s, sending Version Negotiation", hdr.Version) data, err := wire.ComposeVersionNegotiation(hdr.SrcConnectionID, hdr.DestConnectionID, s.config.Versions) diff --git a/server_test.go b/server_test.go index 479423ab2..ddbac18f6 100644 --- a/server_test.go +++ b/server_test.go @@ -122,19 +122,19 @@ var _ = Describe("Server", func() { } It("drops Initial packets with a too short connection ID", func() { - serv.handlePacket(&receivedPacket{ + serv.handlePacket(insertPacketBuffer(&receivedPacket{ hdr: &wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeInitial, DestConnectionID: protocol.ConnectionID{1, 2, 3, 4}, Version: serv.config.Versions[0], }, - }) + })) Consistently(conn.dataWritten).ShouldNot(Receive()) }) It("drops too small Initial", func() { - serv.handlePacket(&receivedPacket{ + serv.handlePacket(insertPacketBuffer(&receivedPacket{ hdr: &wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeInitial, @@ -142,12 +142,12 @@ var _ = Describe("Server", func() { Version: serv.config.Versions[0], }, data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize-100), - }) + })) Consistently(conn.dataWritten).ShouldNot(Receive()) }) It("drops packets with a too short connection ID", func() { - serv.handlePacket(&receivedPacket{ + serv.handlePacket(insertPacketBuffer(&receivedPacket{ hdr: &wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeInitial, @@ -156,19 +156,19 @@ var _ = Describe("Server", func() { Version: serv.config.Versions[0], }, data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize), - }) + })) Consistently(conn.dataWritten).ShouldNot(Receive()) }) It("drops non-Initial packets", func() { serv.logger.SetLogLevel(utils.LogLevelDebug) - serv.handlePacket(&receivedPacket{ + serv.handlePacket(insertPacketBuffer(&receivedPacket{ hdr: &wire.Header{ Type: protocol.PacketTypeHandshake, Version: serv.config.Versions[0], }, data: []byte("invalid"), - }) + })) }) It("decodes the cookie from the Token field", func() { @@ -185,7 +185,7 @@ var _ = Describe("Server", func() { } token, err := serv.cookieGenerator.NewToken(raddr, nil) Expect(err).ToNot(HaveOccurred()) - serv.handlePacket(&receivedPacket{ + serv.handlePacket(insertPacketBuffer(&receivedPacket{ remoteAddr: raddr, hdr: &wire.Header{ Type: protocol.PacketTypeInitial, @@ -193,7 +193,7 @@ var _ = Describe("Server", func() { Version: serv.config.Versions[0], }, data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize), - }) + })) Eventually(done).Should(BeClosed()) }) @@ -209,7 +209,7 @@ var _ = Describe("Server", func() { close(done) return false } - serv.handlePacket(&receivedPacket{ + serv.handlePacket(insertPacketBuffer(&receivedPacket{ remoteAddr: raddr, hdr: &wire.Header{ Type: protocol.PacketTypeInitial, @@ -217,14 +217,14 @@ var _ = Describe("Server", func() { Version: serv.config.Versions[0], }, data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize), - }) + })) Eventually(done).Should(BeClosed()) }) It("sends a Version Negotiation Packet for unsupported versions", func() { srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5} destConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6} - serv.handlePacket(&receivedPacket{ + serv.handlePacket(insertPacketBuffer(&receivedPacket{ remoteAddr: &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}, hdr: &wire.Header{ IsLongHeader: true, @@ -233,7 +233,7 @@ var _ = Describe("Server", func() { DestConnectionID: destConnID, Version: 0x42, }, - }) + })) var write mockPacketConnWrite Eventually(conn.dataWritten).Should(Receive(&write)) Expect(write.to.String()).To(Equal("127.0.0.1:1337")) @@ -253,11 +253,11 @@ var _ = Describe("Server", func() { DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, Version: protocol.VersionTLS, } - serv.handleInitial(&receivedPacket{ + serv.handleInitial(insertPacketBuffer(&receivedPacket{ remoteAddr: &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}, hdr: hdr, data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize), - }) + })) var write mockPacketConnWrite Eventually(conn.dataWritten).Should(Receive(&write)) Expect(write.to.String()).To(Equal("127.0.0.1:1337")) @@ -308,7 +308,7 @@ var _ = Describe("Server", func() { done := make(chan struct{}) go func() { defer GinkgoRecover() - serv.handlePacket(p) + serv.handlePacket(insertPacketBuffer(p)) // the Handshake packet is written by the session Consistently(conn.dataWritten).ShouldNot(Receive()) close(done) diff --git a/session.go b/session.go index 2cdd6d8c0..63aa3d9d7 100644 --- a/session.go +++ b/session.go @@ -53,8 +53,10 @@ type cryptoStreamHandler interface { type receivedPacket struct { remoteAddr net.Addr hdr *wire.Header - data []byte rcvTime time.Time + data []byte + + buffer *packetBuffer } type closeError struct { @@ -368,9 +370,6 @@ runLoop: if wasProcessed := s.handlePacketImpl(p); !wasProcessed { continue } - // This is a bit unclean, but works properly, since the packet always - // begins with the public header and we never copy it. - // TODO: putPacketBuffer(&p.extHdr.Raw) case <-s.handshakeCompleteChan: s.handleHandshakeComplete() } @@ -475,6 +474,15 @@ func (s *session) handleHandshakeComplete() { } func (s *session) handlePacketImpl(p *receivedPacket) bool /* was the packet successfully processed */ { + var wasQueued bool + + defer func() { + // Put back the packet buffer if the packet wasn't queued for later decryption. + if !wasQueued { + p.buffer.Release() + } + }() + // The server can change the source connection ID with the first Handshake packet. // After this, all packets with a different source connection have to be ignored. if s.receivedFirstPacket && p.hdr.IsLongHeader && !p.hdr.SrcConnectionID.Equal(s.destConnID) { @@ -490,6 +498,7 @@ func (s *session) handlePacketImpl(p *receivedPacket) bool /* was the packet suc // if the decryption failed, this might be a packet sent by an attacker if err != nil { if err == handshake.ErrOpenerNotYetAvailable { + wasQueued = true s.tryQueueingUndecryptablePacket(p) return false } @@ -953,7 +962,7 @@ func (s *session) sendPacket() (bool, error) { } func (s *session) sendPackedPacket(packet *packedPacket) error { - defer putPacketBuffer(&packet.raw) + defer packet.buffer.Release() s.logPacket(packet) return s.conn.Write(packet.raw) } @@ -976,7 +985,7 @@ func (s *session) logPacket(packet *packedPacket) { // We don't need to allocate the slices for calling the format functions return } - s.logger.Debugf("-> Sending packet 0x%x (%d bytes) for connection %s, %s", packet.header.PacketNumber, len(packet.raw), s.srcConnID, packet.encryptionLevel) + s.logger.Debugf("-> Sending packet 0x%x (%d bytes) for connection %s, %s", packet.header.PacketNumber, len(packet.raw), s.srcConnID, packet.EncryptionLevel()) packet.header.Log(s.logger) for _, frame := range packet.frames { wire.LogFrame(s.logger, frame, true) diff --git a/session_test.go b/session_test.go index 02fb35792..3b0eb59d6 100644 --- a/session_test.go +++ b/session_test.go @@ -61,6 +61,11 @@ func areSessionsRunning() bool { return strings.Contains(b.String(), "quic-go.(*session).run") } +func insertPacketBuffer(p *receivedPacket) *receivedPacket { + p.buffer = getPacketBuffer() + return p +} + var _ = Describe("Session", func() { var ( sess *session @@ -496,11 +501,11 @@ var _ = Describe("Session", func() { rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl) rph.EXPECT().ReceivedPacket(protocol.PacketNumber(0x1337), rcvTime, false) sess.receivedPacketHandler = rph - Expect(sess.handlePacketImpl(&receivedPacket{ + Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{ rcvTime: rcvTime, hdr: &hdr.Header, data: getData(hdr), - })).To(BeTrue()) + }))).To(BeTrue()) }) It("closes when handling a packet fails", func() { @@ -518,7 +523,10 @@ var _ = Describe("Session", func() { close(done) }() sessionRunner.EXPECT().retireConnectionID(gomock.Any()) - sess.handlePacket(&receivedPacket{hdr: &wire.Header{}, data: getData(&wire.ExtendedHeader{PacketNumberLen: protocol.PacketNumberLen1})}) + sess.handlePacket(insertPacketBuffer(&receivedPacket{ + hdr: &wire.Header{}, + data: getData(&wire.ExtendedHeader{PacketNumberLen: protocol.PacketNumberLen1}), + })) Eventually(done).Should(BeClosed()) }) @@ -528,18 +536,18 @@ var _ = Describe("Session", func() { PacketNumberLen: protocol.PacketNumberLen1, } unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).Return(&unpackedPacket{hdr: hdr}, nil).Times(2) - Expect(sess.handlePacketImpl(&receivedPacket{hdr: &hdr.Header, data: getData(hdr)})).To(BeTrue()) - Expect(sess.handlePacketImpl(&receivedPacket{hdr: &hdr.Header, data: getData(hdr)})).To(BeTrue()) + Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{hdr: &hdr.Header, data: getData(hdr)}))).To(BeTrue()) + Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{hdr: &hdr.Header, data: getData(hdr)}))).To(BeTrue()) }) It("ignores 0-RTT packets", func() { - Expect(sess.handlePacketImpl(&receivedPacket{ + Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{ hdr: &wire.Header{ IsLongHeader: true, Type: protocol.PacketType0RTT, DestConnectionID: sess.srcConnID, }, - })).To(BeFalse()) + }))).To(BeFalse()) }) It("ignores packets with a different source connection ID", func() { @@ -552,12 +560,12 @@ var _ = Describe("Session", func() { // Send one packet, which might change the connection ID. // only EXPECT one call to the unpacker unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).Return(&unpackedPacket{hdr: &wire.ExtendedHeader{Header: *hdr}}, nil) - Expect(sess.handlePacketImpl(&receivedPacket{ + Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{ hdr: hdr, data: getData(&wire.ExtendedHeader{PacketNumberLen: protocol.PacketNumberLen1}), - })).To(BeTrue()) + }))).To(BeTrue()) // The next packet has to be ignored, since the source connection ID doesn't match. - Expect(sess.handlePacketImpl(&receivedPacket{ + Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{ hdr: &wire.Header{ IsLongHeader: true, DestConnectionID: sess.destConnID, @@ -565,7 +573,7 @@ var _ = Describe("Session", func() { Length: 1, }, data: getData(&wire.ExtendedHeader{PacketNumberLen: protocol.PacketNumberLen1}), - })).To(BeFalse()) + }))).To(BeFalse()) }) Context("updating the remote address", func() { @@ -574,12 +582,11 @@ var _ = Describe("Session", func() { origAddr := sess.conn.(*mockConnection).remoteAddr remoteIP := &net.IPAddr{IP: net.IPv4(192, 168, 0, 100)} Expect(origAddr).ToNot(Equal(remoteIP)) - p := receivedPacket{ + Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{ remoteAddr: remoteIP, hdr: &wire.Header{}, data: getData(&wire.ExtendedHeader{PacketNumberLen: protocol.PacketNumberLen1}), - } - Expect(sess.handlePacketImpl(&p)).To(BeTrue()) + }))).To(BeTrue()) Expect(sess.conn.(*mockConnection).remoteAddr).To(Equal(origAddr)) }) }) @@ -587,10 +594,12 @@ var _ = Describe("Session", func() { Context("sending packets", func() { getPacket := func(pn protocol.PacketNumber) *packedPacket { - data := *getPacketBuffer() + buffer := getPacketBuffer() + data := buffer.Slice[:0] data = append(data, []byte("foobar")...) return &packedPacket{ raw: data, + buffer: buffer, header: &wire.ExtendedHeader{PacketNumber: pn}, } } @@ -963,7 +972,7 @@ var _ = Describe("Session", func() { defer close(done) return &packedPacket{ header: &wire.ExtendedHeader{}, - raw: *getPacketBuffer(), + buffer: getPacketBuffer(), }, nil }), packer.EXPECT().PackPacket().AnyTimes(), @@ -1352,7 +1361,7 @@ var _ = Describe("Client Session", func() { }() newConnID := protocol.ConnectionID{1, 3, 3, 7, 1, 3, 3, 7} packer.EXPECT().ChangeDestConnectionID(newConnID) - Expect(sess.handlePacketImpl(&receivedPacket{ + Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{ hdr: &wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeHandshake, @@ -1361,7 +1370,7 @@ var _ = Describe("Client Session", func() { Length: 1, }, data: []byte{0}, - })).To(BeTrue()) + }))).To(BeTrue()) // make sure the go routine returns packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) sessionRunner.EXPECT().retireConnectionID(gomock.Any())