forked from quic-go/quic-go
proxy: add function to simulate NAT rebinding (#4922)
This commit is contained in:
@@ -1,7 +1,10 @@
|
||||
package quicproxy
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -13,6 +16,9 @@ import (
|
||||
// Connection is a UDP connection
|
||||
type connection struct {
|
||||
ClientAddr *net.UDPAddr // Address of the client
|
||||
ServerAddr *net.UDPAddr // Address of the server
|
||||
|
||||
mx sync.Mutex
|
||||
ServerConn *net.UDPConn // UDP connection to server
|
||||
|
||||
incomingPackets chan packetEntry
|
||||
@@ -25,6 +31,22 @@ func (c *connection) queuePacket(t time.Time, b []byte) {
|
||||
c.incomingPackets <- packetEntry{Time: t, Raw: b}
|
||||
}
|
||||
|
||||
func (c *connection) SwitchConn(conn *net.UDPConn) {
|
||||
c.mx.Lock()
|
||||
defer c.mx.Unlock()
|
||||
|
||||
old := c.ServerConn
|
||||
old.SetReadDeadline(time.Now())
|
||||
c.ServerConn = conn
|
||||
}
|
||||
|
||||
func (c *connection) GetServerConn() *net.UDPConn {
|
||||
c.mx.Lock()
|
||||
defer c.mx.Unlock()
|
||||
|
||||
return c.ServerConn
|
||||
}
|
||||
|
||||
// Direction is the direction a packet is sent.
|
||||
type Direction int
|
||||
|
||||
@@ -118,8 +140,7 @@ type DelayCallback func(dir Direction, packet []byte) time.Duration
|
||||
|
||||
// Proxy is a QUIC proxy that can drop and delay packets.
|
||||
type Proxy struct {
|
||||
// Conn is the UDP socket that the proxy listens on for incoming packets
|
||||
// from clients.
|
||||
// Conn is the UDP socket that the proxy listens on for incoming packets from clients.
|
||||
Conn *net.UDPConn
|
||||
|
||||
// ServerAddr is the address of the server that the proxy forwards packets to.
|
||||
@@ -139,7 +160,6 @@ type Proxy struct {
|
||||
clientDict map[string]*connection
|
||||
}
|
||||
|
||||
// NewQuicProxy creates a new UDP proxy
|
||||
func (p *Proxy) Start() error {
|
||||
p.clientDict = make(map[string]*connection)
|
||||
p.closeChan = make(chan struct{})
|
||||
@@ -157,6 +177,25 @@ func (p *Proxy) Start() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SwitchConn switches the connection for a client,
|
||||
// identified the address that the client is sending from.
|
||||
func (p *Proxy) SwitchConn(clientAddr *net.UDPAddr, conn *net.UDPConn) error {
|
||||
if err := conn.SetReadBuffer(protocol.DesiredReceiveBufferSize); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := conn.SetWriteBuffer(protocol.DesiredSendBufferSize); err != nil {
|
||||
return err
|
||||
}
|
||||
p.mutex.Lock()
|
||||
defer p.mutex.Unlock()
|
||||
c, ok := p.clientDict[clientAddr.String()]
|
||||
if !ok {
|
||||
return fmt.Errorf("client %s not found", clientAddr)
|
||||
}
|
||||
c.SwitchConn(conn)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close stops the UDP Proxy
|
||||
func (p *Proxy) Close() error {
|
||||
p.mutex.Lock()
|
||||
@@ -164,7 +203,7 @@ func (p *Proxy) Close() error {
|
||||
|
||||
close(p.closeChan)
|
||||
for _, c := range p.clientDict {
|
||||
if err := c.ServerConn.Close(); err != nil {
|
||||
if err := c.GetServerConn().Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
c.Incoming.Close()
|
||||
@@ -177,7 +216,7 @@ func (p *Proxy) Close() error {
|
||||
func (p *Proxy) LocalAddr() net.Addr { return p.Conn.LocalAddr() }
|
||||
|
||||
func (p *Proxy) newConnection(cliAddr *net.UDPAddr) (*connection, error) {
|
||||
conn, err := net.DialUDP("udp", nil, p.ServerAddr)
|
||||
conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -189,10 +228,11 @@ func (p *Proxy) newConnection(cliAddr *net.UDPAddr) (*connection, error) {
|
||||
}
|
||||
return &connection{
|
||||
ClientAddr: cliAddr,
|
||||
ServerConn: conn,
|
||||
ServerAddr: p.ServerAddr,
|
||||
incomingPackets: make(chan packetEntry, 10),
|
||||
Incoming: newQueue(),
|
||||
Outgoing: newQueue(),
|
||||
ServerConn: conn,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -204,11 +244,10 @@ func (p *Proxy) runProxy() error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
raw := buffer[0:n]
|
||||
raw := buffer[:n]
|
||||
|
||||
saddr := cliaddr.String()
|
||||
p.mutex.Lock()
|
||||
conn, ok := p.clientDict[saddr]
|
||||
conn, ok := p.clientDict[cliaddr.String()]
|
||||
|
||||
if !ok {
|
||||
conn, err = p.newConnection(cliaddr)
|
||||
@@ -216,7 +255,7 @@ func (p *Proxy) runProxy() error {
|
||||
p.mutex.Unlock()
|
||||
return err
|
||||
}
|
||||
p.clientDict[saddr] = conn
|
||||
p.clientDict[cliaddr.String()] = conn
|
||||
go p.runIncomingConnection(conn)
|
||||
go p.runOutgoingConnection(conn)
|
||||
}
|
||||
@@ -235,15 +274,15 @@ func (p *Proxy) runProxy() error {
|
||||
}
|
||||
if delay == 0 {
|
||||
if p.logger.Debug() {
|
||||
p.logger.Debugf("forwarding incoming packet (%d bytes) to %s", len(raw), conn.ServerConn.RemoteAddr())
|
||||
p.logger.Debugf("forwarding incoming packet (%d bytes) to %s", len(raw), conn.ServerAddr)
|
||||
}
|
||||
if _, err := conn.ServerConn.Write(raw); err != nil {
|
||||
if _, err := conn.GetServerConn().WriteTo(raw, conn.ServerAddr); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
now := time.Now()
|
||||
if p.logger.Debug() {
|
||||
p.logger.Debugf("delaying incoming packet (%d bytes) to %s by %s", len(raw), conn.ServerConn.RemoteAddr(), delay)
|
||||
p.logger.Debugf("delaying incoming packet (%d bytes) to %s by %s", len(raw), conn.ServerAddr, delay)
|
||||
}
|
||||
conn.queuePacket(now.Add(delay), raw)
|
||||
}
|
||||
@@ -256,8 +295,13 @@ func (p *Proxy) runOutgoingConnection(conn *connection) error {
|
||||
go func() {
|
||||
for {
|
||||
buffer := make([]byte, protocol.MaxPacketBufferSize)
|
||||
n, err := conn.ServerConn.Read(buffer)
|
||||
n, err := conn.GetServerConn().Read(buffer)
|
||||
if err != nil {
|
||||
// when the connection is switched out, we set a deadline on the old connection,
|
||||
// in order to return it immediately
|
||||
if errors.Is(err, os.ErrDeadlineExceeded) {
|
||||
continue
|
||||
}
|
||||
return
|
||||
}
|
||||
raw := buffer[0:n]
|
||||
@@ -315,7 +359,7 @@ func (p *Proxy) runIncomingConnection(conn *connection) error {
|
||||
conn.Incoming.Add(e)
|
||||
case <-conn.Incoming.Timer():
|
||||
conn.Incoming.SetTimerRead()
|
||||
if _, err := conn.ServerConn.Write(conn.Incoming.Get()); err != nil {
|
||||
if _, err := conn.GetServerConn().WriteTo(conn.Incoming.Get(), conn.ServerAddr); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user