From 81be522bf3074904cea5a6b984dc2bfe57aef161 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Tue, 3 Sep 2019 19:40:02 +0700 Subject: [PATCH] identify connections by their local addr when adding to the multiplexer --- conn_test.go | 1 + multiplexer.go | 14 ++++++++------ multiplexer_test.go | 19 +++++++++++++++++++ quic_suite_test.go | 5 +++++ 4 files changed, 33 insertions(+), 6 deletions(-) diff --git a/conn_test.go b/conn_test.go index d5dedfdae..8fa330dbe 100644 --- a/conn_test.go +++ b/conn_test.go @@ -25,6 +25,7 @@ type mockPacketConn struct { func newMockPacketConn() *mockPacketConn { return &mockPacketConn{ + addr: &net.UDPAddr{IP: net.IPv6zero, Port: 0x42}, dataToRead: make(chan []byte, 1000), dataWritten: make(chan mockPacketConnWrite, 1000), } diff --git a/multiplexer.go b/multiplexer.go index eeffca53b..1ba8892d1 100644 --- a/multiplexer.go +++ b/multiplexer.go @@ -30,7 +30,7 @@ type connManager struct { type connMultiplexer struct { mutex sync.Mutex - conns map[net.PacketConn]connManager + conns map[string] /* LocalAddr().String() */ connManager newPacketHandlerManager func(net.PacketConn, int, []byte, utils.Logger) packetHandlerManager // so it can be replaced in the tests logger utils.Logger @@ -41,7 +41,7 @@ var _ multiplexer = &connMultiplexer{} func getMultiplexer() multiplexer { connMuxerOnce.Do(func() { connMuxer = &connMultiplexer{ - conns: make(map[net.PacketConn]connManager), + conns: make(map[string]connManager), logger: utils.DefaultLogger.WithPrefix("muxer"), newPacketHandlerManager: newPacketHandlerMap, } @@ -57,7 +57,8 @@ func (m *connMultiplexer) AddConn( m.mutex.Lock() defer m.mutex.Unlock() - p, ok := m.conns[c] + laddr := c.LocalAddr().String() + p, ok := m.conns[laddr] if !ok { manager := m.newPacketHandlerManager(c, connIDLen, statelessResetKey, m.logger) p = connManager{ @@ -65,7 +66,7 @@ func (m *connMultiplexer) AddConn( statelessResetKey: statelessResetKey, manager: manager, } - m.conns[c] = p + m.conns[laddr] = p } if p.connIDLen != connIDLen { return nil, fmt.Errorf("cannot use %d byte connection IDs on a connection that is already using %d byte connction IDs", connIDLen, p.connIDLen) @@ -80,10 +81,11 @@ func (m *connMultiplexer) RemoveConn(c net.PacketConn) error { m.mutex.Lock() defer m.mutex.Unlock() - if _, ok := m.conns[c]; !ok { + laddr := c.LocalAddr().String() + if _, ok := m.conns[laddr]; !ok { return fmt.Errorf("cannote remove connection, connection is unknown") } - delete(m.conns, c) + delete(m.conns, laddr) return nil } diff --git a/multiplexer_test.go b/multiplexer_test.go index 1b40cf114..f92d2aa9f 100644 --- a/multiplexer_test.go +++ b/multiplexer_test.go @@ -1,10 +1,17 @@ package quic import ( + "net" + . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) +type testConn struct { + counter int + net.PacketConn +} + var _ = Describe("Client Multiplexer", func() { It("adds a new packet conn ", func() { conn := newMockPacketConn() @@ -12,6 +19,18 @@ var _ = Describe("Client Multiplexer", func() { Expect(err).ToNot(HaveOccurred()) }) + It("recognizes when the same connection is added twice", func() { + pconn := newMockPacketConn() + pconn.addr = &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 4321} + conn := testConn{PacketConn: pconn} + _, err := getMultiplexer().AddConn(conn, 8, nil) + Expect(err).ToNot(HaveOccurred()) + conn.counter++ + _, err = getMultiplexer().AddConn(conn, 8, nil) + Expect(err).ToNot(HaveOccurred()) + Expect(getMultiplexer().(*connMultiplexer).conns).To(HaveLen(1)) + }) + It("errors when adding an existing conn with a different connection ID length", func() { conn := newMockPacketConn() _, err := getMultiplexer().AddConn(conn, 5, nil) diff --git a/quic_suite_test.go b/quic_suite_test.go index 7059345fb..1798fff7d 100644 --- a/quic_suite_test.go +++ b/quic_suite_test.go @@ -1,6 +1,8 @@ package quic import ( + "sync" + "github.com/golang/mock/gomock" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" @@ -17,6 +19,9 @@ var mockCtrl *gomock.Controller var _ = BeforeEach(func() { mockCtrl = gomock.NewController(GinkgoT()) + + // reset the sync.Once + connMuxerOnce = *new(sync.Once) }) var _ = AfterEach(func() {