From 595c396dc6588507a103d151631912e265d6e022 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Wed, 4 Jul 2018 12:11:16 +0700 Subject: [PATCH] refactor packet handling in the multiplexer --- client_multiplexer.go | 69 ++++++++++++++++++++------------------ client_multiplexer_test.go | 47 +++++++------------------- 2 files changed, 49 insertions(+), 67 deletions(-) diff --git a/client_multiplexer.go b/client_multiplexer.go index b53c315f..4b7c7ecf 100644 --- a/client_multiplexer.go +++ b/client_multiplexer.go @@ -63,7 +63,7 @@ func (m *clientMultiplexer) AddConn(c net.PacketConn, connIDLen int) (packetHand m.conns[c] = p // If we didn't know this packet conn before, listen for incoming packets // and dispatch them to the right sessions. - go m.listen(c, p) + go m.listen(c, &p) } if p.connIDLen != connIDLen { return nil, fmt.Errorf("cannot use %d byte connection IDs on a connection that is already using %d byte connction IDs", connIDLen, p.connIDLen) @@ -80,7 +80,7 @@ func (m *clientMultiplexer) AddHandler(c net.PacketConn, connID protocol.Connect return nil } -func (m *clientMultiplexer) listen(c net.PacketConn, p connManager) { +func (m *clientMultiplexer) listen(c net.PacketConn, p *connManager) { for { data := *getPacketBuffer() data = data[:protocol.MaxReceivePacketSize] @@ -94,37 +94,42 @@ func (m *clientMultiplexer) listen(c net.PacketConn, p connManager) { return } data = data[:n] - rcvTime := time.Now() - r := bytes.NewReader(data) - iHdr, err := wire.ParseInvariantHeader(r, p.connIDLen) - // drop the packet if we can't parse the header - if err != nil { - m.logger.Debugf("error parsing invariant header from %s: %s", addr, err) - continue + if err := m.handlePacket(addr, data, p); err != nil { + m.logger.Debugf("error handling packet from %s: %s", addr, err) } - client, ok := p.manager.Get(iHdr.DestConnectionID) - if !ok { - m.logger.Debugf("received a packet with an unexpected connection ID %s", iHdr.DestConnectionID) - continue - } - if client == nil { - // Late packet for closed session - continue - } - hdr, err := iHdr.Parse(r, protocol.PerspectiveServer, client.GetVersion()) - if err != nil { - m.logger.Debugf("error parsing header from %s: %s", addr, err) - continue - } - hdr.Raw = data[:len(data)-r.Len()] - packetData := data[len(data)-r.Len():] - - client.handlePacket(&receivedPacket{ - remoteAddr: addr, - header: hdr, - data: packetData, - rcvTime: rcvTime, - }) } } + +func (m *clientMultiplexer) handlePacket(addr net.Addr, data []byte, p *connManager) error { + rcvTime := time.Now() + + r := bytes.NewReader(data) + iHdr, err := wire.ParseInvariantHeader(r, p.connIDLen) + // drop the packet if we can't parse the header + if err != nil { + return fmt.Errorf("error parsing invariant header: %s", err) + } + client, ok := p.manager.Get(iHdr.DestConnectionID) + if !ok { + return fmt.Errorf("received a packet with an unexpected connection ID %s", iHdr.DestConnectionID) + } + if client == nil { + // Late packet for closed session + return nil + } + hdr, err := iHdr.Parse(r, protocol.PerspectiveServer, client.GetVersion()) + if err != nil { + return fmt.Errorf("error parsing header: %s", err) + } + hdr.Raw = data[:len(data)-r.Len()] + packetData := data[len(data)-r.Len():] + + client.handlePacket(&receivedPacket{ + remoteAddr: addr, + header: hdr, + data: packetData, + rcvTime: rcvTime, + }) + return nil +} diff --git a/client_multiplexer_test.go b/client_multiplexer_test.go index b5f1d405..56038841 100644 --- a/client_multiplexer_test.go +++ b/client_multiplexer_test.go @@ -3,7 +3,6 @@ package quic import ( "bytes" "errors" - "time" "github.com/golang/mock/gomock" "github.com/lucas-clemente/quic-go/internal/protocol" @@ -92,47 +91,25 @@ var _ = Describe("Client Multiplexer", func() { }) It("drops unparseable packets", func() { - conn := newMockPacketConn() - connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7} - conn.dataToRead <- []byte("invalid header") - packetHandler := NewMockQuicSession(mockCtrl) - getClientMultiplexer().AddConn(conn, 7) - Expect(getClientMultiplexer().AddHandler(conn, connID, packetHandler)).To(Succeed()) - time.Sleep(100 * time.Millisecond) // give the listen go routine some time to process the packet - packetHandler.EXPECT().Close(gomock.Any()).AnyTimes() - close(conn.dataToRead) + err := getClientMultiplexer().(*clientMultiplexer).handlePacket(nil, []byte("invalid"), &connManager{connIDLen: 8}) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("error parsing invariant header:")) }) It("ignores packets arriving late for closed sessions", func() { - manager := NewMockPacketHandlerManager(mockCtrl) - origNewPacketHandlerManager := getClientMultiplexer().(*clientMultiplexer).newPacketHandlerManager - defer func() { - getClientMultiplexer().(*clientMultiplexer).newPacketHandlerManager = origNewPacketHandlerManager - }() - getClientMultiplexer().(*clientMultiplexer).newPacketHandlerManager = func() packetHandlerManager { return manager } - - conn := newMockPacketConn() connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - done := make(chan struct{}) - manager.EXPECT().Get(connID).Do(func(protocol.ConnectionID) { close(done) }).Return(nil, true) - getClientMultiplexer().AddConn(conn, 8) - conn.dataToRead <- getPacket(connID) - Eventually(done).Should(BeClosed()) - // makes the listen go routine return - manager.EXPECT().Close(gomock.Any()).AnyTimes() - close(conn.dataToRead) + manager := NewMockPacketHandlerManager(mockCtrl) + manager.EXPECT().Get(connID).Return(nil, true) + err := getClientMultiplexer().(*clientMultiplexer).handlePacket(nil, getPacket(connID), &connManager{manager: manager, connIDLen: 8}) + Expect(err).ToNot(HaveOccurred()) }) It("drops packets for unknown receivers", func() { - conn := newMockPacketConn() - conn.dataToRead <- getPacket(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}) - packetHandler := NewMockQuicSession(mockCtrl) - getClientMultiplexer().AddConn(conn, 8) - Expect(getClientMultiplexer().AddHandler(conn, protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}, packetHandler)).To(Succeed()) - time.Sleep(100 * time.Millisecond) // give the listen go routine some time to process the packet - // makes the listen go routine return - packetHandler.EXPECT().Close(gomock.Any()).AnyTimes() - close(conn.dataToRead) + connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + manager := NewMockPacketHandlerManager(mockCtrl) + manager.EXPECT().Get(connID).Return(nil, false) + err := getClientMultiplexer().(*clientMultiplexer).handlePacket(nil, getPacket(connID), &connManager{manager: manager, connIDLen: 8}) + Expect(err).To(MatchError("received a packet with an unexpected connection ID 0x0102030405060708")) }) It("closes the packet handlers when reading from the conn fails", func() {