forked from quic-go/quic-go
remove validation enforcing one Transport per net.PacketConn (#4851)
It is invalid to use a net.PacketConn in multiple Transports. However, the validation logic is causing pain when using wrapped net.PacketConns. It was introduce to guard against incorrect uses of the API when the Transport was introduced, but this is probably less relevant now than it was back then.
This commit is contained in:
@@ -17,20 +17,14 @@ import (
|
||||
"go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
type nullMultiplexer struct{}
|
||||
|
||||
func (n nullMultiplexer) AddConn(indexableConn) {}
|
||||
func (n nullMultiplexer) RemoveConn(indexableConn) error { return nil }
|
||||
|
||||
var _ = Describe("Client", func() {
|
||||
var (
|
||||
cl *client
|
||||
packetConn *MockSendConn
|
||||
connID protocol.ConnectionID
|
||||
origMultiplexer multiplexer
|
||||
tlsConf *tls.Config
|
||||
tracer *mocklogging.MockConnectionTracer
|
||||
config *Config
|
||||
cl *client
|
||||
packetConn *MockSendConn
|
||||
connID protocol.ConnectionID
|
||||
tlsConf *tls.Config
|
||||
tracer *mocklogging.MockConnectionTracer
|
||||
config *Config
|
||||
|
||||
originalClientConnConstructor func(
|
||||
ctx context.Context,
|
||||
@@ -74,14 +68,9 @@ var _ = Describe("Client", func() {
|
||||
tracer: tr,
|
||||
logger: utils.DefaultLogger,
|
||||
}
|
||||
getMultiplexer() // make the sync.Once execute
|
||||
// replace the clientMuxer. getMultiplexer will now return the nullMultiplexer
|
||||
origMultiplexer = connMuxer
|
||||
connMuxer = &nullMultiplexer{}
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
connMuxer = origMultiplexer
|
||||
newClientConnection = originalClientConnConstructor
|
||||
})
|
||||
|
||||
|
||||
@@ -1,75 +0,0 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
var (
|
||||
connMuxerOnce sync.Once
|
||||
connMuxer multiplexer
|
||||
)
|
||||
|
||||
type indexableConn interface{ LocalAddr() net.Addr }
|
||||
|
||||
type multiplexer interface {
|
||||
AddConn(conn indexableConn)
|
||||
RemoveConn(indexableConn) error
|
||||
}
|
||||
|
||||
// The connMultiplexer listens on multiple net.PacketConns and dispatches
|
||||
// incoming packets to the connection handler.
|
||||
type connMultiplexer struct {
|
||||
mutex sync.Mutex
|
||||
|
||||
conns map[string] /* LocalAddr().String() */ indexableConn
|
||||
logger utils.Logger
|
||||
}
|
||||
|
||||
var _ multiplexer = &connMultiplexer{}
|
||||
|
||||
func getMultiplexer() multiplexer {
|
||||
connMuxerOnce.Do(func() {
|
||||
connMuxer = &connMultiplexer{
|
||||
conns: make(map[string]indexableConn),
|
||||
logger: utils.DefaultLogger.WithPrefix("muxer"),
|
||||
}
|
||||
})
|
||||
return connMuxer
|
||||
}
|
||||
|
||||
func (m *connMultiplexer) index(addr net.Addr) string {
|
||||
return addr.Network() + " " + addr.String()
|
||||
}
|
||||
|
||||
func (m *connMultiplexer) AddConn(c indexableConn) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
connIndex := m.index(c.LocalAddr())
|
||||
p, ok := m.conns[connIndex]
|
||||
if ok {
|
||||
// Panics if we're already listening on this connection.
|
||||
// This is a safeguard because we're introducing a breaking API change, see
|
||||
// https://github.com/quic-go/quic-go/issues/3727 for details.
|
||||
// We'll remove this at a later time, when most users of the library have made the switch.
|
||||
panic("connection already exists") // TODO: write a nice message
|
||||
}
|
||||
m.conns[connIndex] = p
|
||||
}
|
||||
|
||||
func (m *connMultiplexer) RemoveConn(c indexableConn) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
connIndex := m.index(c.LocalAddr())
|
||||
if _, ok := m.conns[connIndex]; !ok {
|
||||
return fmt.Errorf("cannote remove connection %s, connection is unknown", connIndex)
|
||||
}
|
||||
|
||||
delete(m.conns, connIndex)
|
||||
return nil
|
||||
}
|
||||
@@ -1,33 +0,0 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type mockIndexableConn struct{ addr net.Addr }
|
||||
|
||||
var _ indexableConn = &mockIndexableConn{}
|
||||
|
||||
func (m *mockIndexableConn) LocalAddr() net.Addr { return m.addr }
|
||||
|
||||
func TestMultiplexerAddNewPacketConns(t *testing.T) {
|
||||
conn1 := &mockIndexableConn{addr: &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}}
|
||||
getMultiplexer().AddConn(conn1)
|
||||
conn2 := &mockIndexableConn{addr: &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1235}}
|
||||
getMultiplexer().AddConn(conn2)
|
||||
|
||||
require.NoError(t, getMultiplexer().RemoveConn(conn1))
|
||||
require.NoError(t, getMultiplexer().RemoveConn(conn2))
|
||||
}
|
||||
|
||||
func TestMultiplexerPanicsOnDuplicateConn(t *testing.T) {
|
||||
conn := &mockIndexableConn{addr: &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 4321}}
|
||||
getMultiplexer().AddConn(conn)
|
||||
require.Panics(t, func() { getMultiplexer().AddConn(conn) })
|
||||
|
||||
require.NoError(t, getMultiplexer().RemoveConn(conn))
|
||||
require.ErrorContains(t, getMultiplexer().RemoveConn(conn), "cannote remove connection")
|
||||
}
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
"runtime/pprof"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -29,9 +28,6 @@ var mockCtrl *gomock.Controller
|
||||
|
||||
var _ = BeforeEach(func() {
|
||||
mockCtrl = gomock.NewController(GinkgoT())
|
||||
|
||||
// reset the sync.Once
|
||||
connMuxerOnce = *new(sync.Once)
|
||||
})
|
||||
|
||||
var _ = BeforeSuite(func() {
|
||||
|
||||
@@ -269,7 +269,6 @@ func (t *Transport) init(allowZeroLengthConnIDs bool) error {
|
||||
t.connIDGenerator = &protocol.DefaultConnectionIDGenerator{ConnLen: t.connIDLen}
|
||||
}
|
||||
|
||||
getMultiplexer().AddConn(t.Conn)
|
||||
go t.listen(conn)
|
||||
go t.runSendQueue()
|
||||
})
|
||||
@@ -366,7 +365,6 @@ var setBufferWarningOnce sync.Once
|
||||
|
||||
func (t *Transport) listen(conn rawConn) {
|
||||
defer close(t.listening)
|
||||
defer getMultiplexer().RemoveConn(t.Conn)
|
||||
|
||||
for {
|
||||
p, err := conn.ReadPacket()
|
||||
|
||||
@@ -416,9 +416,6 @@ func TestTransportFaultySyscallConn(t *testing.T) {
|
||||
_, err := tr.Listen(&tls.Config{}, nil)
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, "mocked")
|
||||
|
||||
conns := getMultiplexer().(*connMultiplexer).conns
|
||||
require.Empty(t, conns)
|
||||
}
|
||||
|
||||
func TestTransportSetTLSConfigServerName(t *testing.T) {
|
||||
|
||||
Reference in New Issue
Block a user