proxy: add function to simulate NAT rebinding (#4922)

This commit is contained in:
Marten Seemann
2025-01-26 05:03:08 +01:00
committed by GitHub
parent 79bae396b4
commit 3e87ea3f50
3 changed files with 149 additions and 35 deletions

View File

@@ -1,7 +1,10 @@
package quicproxy
import (
"errors"
"fmt"
"net"
"os"
"sort"
"sync"
"time"
@@ -13,6 +16,9 @@ import (
// Connection is a UDP connection
type connection struct {
ClientAddr *net.UDPAddr // Address of the client
ServerAddr *net.UDPAddr // Address of the server
mx sync.Mutex
ServerConn *net.UDPConn // UDP connection to server
incomingPackets chan packetEntry
@@ -25,6 +31,22 @@ func (c *connection) queuePacket(t time.Time, b []byte) {
c.incomingPackets <- packetEntry{Time: t, Raw: b}
}
func (c *connection) SwitchConn(conn *net.UDPConn) {
c.mx.Lock()
defer c.mx.Unlock()
old := c.ServerConn
old.SetReadDeadline(time.Now())
c.ServerConn = conn
}
func (c *connection) GetServerConn() *net.UDPConn {
c.mx.Lock()
defer c.mx.Unlock()
return c.ServerConn
}
// Direction is the direction a packet is sent.
type Direction int
@@ -118,8 +140,7 @@ type DelayCallback func(dir Direction, packet []byte) time.Duration
// Proxy is a QUIC proxy that can drop and delay packets.
type Proxy struct {
// Conn is the UDP socket that the proxy listens on for incoming packets
// from clients.
// Conn is the UDP socket that the proxy listens on for incoming packets from clients.
Conn *net.UDPConn
// ServerAddr is the address of the server that the proxy forwards packets to.
@@ -139,7 +160,6 @@ type Proxy struct {
clientDict map[string]*connection
}
// NewQuicProxy creates a new UDP proxy
func (p *Proxy) Start() error {
p.clientDict = make(map[string]*connection)
p.closeChan = make(chan struct{})
@@ -157,6 +177,25 @@ func (p *Proxy) Start() error {
return nil
}
// SwitchConn switches the connection for a client,
// identified the address that the client is sending from.
func (p *Proxy) SwitchConn(clientAddr *net.UDPAddr, conn *net.UDPConn) error {
if err := conn.SetReadBuffer(protocol.DesiredReceiveBufferSize); err != nil {
return err
}
if err := conn.SetWriteBuffer(protocol.DesiredSendBufferSize); err != nil {
return err
}
p.mutex.Lock()
defer p.mutex.Unlock()
c, ok := p.clientDict[clientAddr.String()]
if !ok {
return fmt.Errorf("client %s not found", clientAddr)
}
c.SwitchConn(conn)
return nil
}
// Close stops the UDP Proxy
func (p *Proxy) Close() error {
p.mutex.Lock()
@@ -164,7 +203,7 @@ func (p *Proxy) Close() error {
close(p.closeChan)
for _, c := range p.clientDict {
if err := c.ServerConn.Close(); err != nil {
if err := c.GetServerConn().Close(); err != nil {
return err
}
c.Incoming.Close()
@@ -177,7 +216,7 @@ func (p *Proxy) Close() error {
func (p *Proxy) LocalAddr() net.Addr { return p.Conn.LocalAddr() }
func (p *Proxy) newConnection(cliAddr *net.UDPAddr) (*connection, error) {
conn, err := net.DialUDP("udp", nil, p.ServerAddr)
conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
if err != nil {
return nil, err
}
@@ -189,10 +228,11 @@ func (p *Proxy) newConnection(cliAddr *net.UDPAddr) (*connection, error) {
}
return &connection{
ClientAddr: cliAddr,
ServerConn: conn,
ServerAddr: p.ServerAddr,
incomingPackets: make(chan packetEntry, 10),
Incoming: newQueue(),
Outgoing: newQueue(),
ServerConn: conn,
}, nil
}
@@ -204,11 +244,10 @@ func (p *Proxy) runProxy() error {
if err != nil {
return err
}
raw := buffer[0:n]
raw := buffer[:n]
saddr := cliaddr.String()
p.mutex.Lock()
conn, ok := p.clientDict[saddr]
conn, ok := p.clientDict[cliaddr.String()]
if !ok {
conn, err = p.newConnection(cliaddr)
@@ -216,7 +255,7 @@ func (p *Proxy) runProxy() error {
p.mutex.Unlock()
return err
}
p.clientDict[saddr] = conn
p.clientDict[cliaddr.String()] = conn
go p.runIncomingConnection(conn)
go p.runOutgoingConnection(conn)
}
@@ -235,15 +274,15 @@ func (p *Proxy) runProxy() error {
}
if delay == 0 {
if p.logger.Debug() {
p.logger.Debugf("forwarding incoming packet (%d bytes) to %s", len(raw), conn.ServerConn.RemoteAddr())
p.logger.Debugf("forwarding incoming packet (%d bytes) to %s", len(raw), conn.ServerAddr)
}
if _, err := conn.ServerConn.Write(raw); err != nil {
if _, err := conn.GetServerConn().WriteTo(raw, conn.ServerAddr); err != nil {
return err
}
} else {
now := time.Now()
if p.logger.Debug() {
p.logger.Debugf("delaying incoming packet (%d bytes) to %s by %s", len(raw), conn.ServerConn.RemoteAddr(), delay)
p.logger.Debugf("delaying incoming packet (%d bytes) to %s by %s", len(raw), conn.ServerAddr, delay)
}
conn.queuePacket(now.Add(delay), raw)
}
@@ -256,8 +295,13 @@ func (p *Proxy) runOutgoingConnection(conn *connection) error {
go func() {
for {
buffer := make([]byte, protocol.MaxPacketBufferSize)
n, err := conn.ServerConn.Read(buffer)
n, err := conn.GetServerConn().Read(buffer)
if err != nil {
// when the connection is switched out, we set a deadline on the old connection,
// in order to return it immediately
if errors.Is(err, os.ErrDeadlineExceeded) {
continue
}
return
}
raw := buffer[0:n]
@@ -315,7 +359,7 @@ func (p *Proxy) runIncomingConnection(conn *connection) error {
conn.Incoming.Add(e)
case <-conn.Incoming.Timer():
conn.Incoming.SetTimerRead()
if _, err := conn.ServerConn.Write(conn.Incoming.Get()); err != nil {
if _, err := conn.GetServerConn().WriteTo(conn.Incoming.Get(), conn.ServerAddr); err != nil {
return err
}
}