rewrite the proxy to avoid packet reordering

This commit is contained in:
Marten Seemann
2020-06-19 22:57:07 +07:00
parent c956ca4447
commit 0baf16ea4e
3 changed files with 248 additions and 96 deletions

View File

@@ -2,6 +2,7 @@ package quicproxy
import (
"net"
"sort"
"sync"
"time"
@@ -13,6 +14,15 @@ import (
type connection struct {
ClientAddr *net.UDPAddr // Address of the client
ServerConn *net.UDPConn // UDP connection to server
incomingPackets chan packetEntry
Incoming *queue
Outgoing *queue
}
func (c *connection) queuePacket(t time.Time, b []byte) {
c.incomingPackets <- packetEntry{Time: t, Raw: b}
}
// Direction is the direction a packet is sent.
@@ -27,12 +37,63 @@ const (
DirectionBoth
)
type packetEntry struct {
Time time.Time
Raw []byte
}
type packetEntries []packetEntry
func (e packetEntries) Len() int { return len(e) }
func (e packetEntries) Less(i, j int) bool { return e[i].Time.Before(e[j].Time) }
func (e packetEntries) Swap(i, j int) { e[i], e[j] = e[j], e[i] }
type queue struct {
sync.Mutex
timer *utils.Timer
Packets packetEntries
}
func newQueue() *queue {
return &queue{timer: utils.NewTimer()}
}
func (q *queue) Add(e packetEntry) {
q.Lock()
q.Packets = append(q.Packets, e)
if len(q.Packets) > 1 {
lastIndex := len(q.Packets) - 1
if q.Packets[lastIndex].Time.Before(q.Packets[lastIndex-1].Time) {
sort.Stable(q.Packets)
}
}
q.timer.Reset(q.Packets[0].Time)
q.Unlock()
}
func (q *queue) Get() []byte {
q.Lock()
raw := q.Packets[0].Raw
q.Packets = q.Packets[1:]
if len(q.Packets) > 0 {
q.timer.Reset(q.Packets[0].Time)
}
q.Unlock()
return raw
}
func (q *queue) Timer() <-chan time.Time { return q.timer.Chan() }
func (q *queue) SetTimerRead() { q.timer.SetRead() }
func (q *queue) Close() { q.timer.Stop() }
func (d Direction) String() string {
switch d {
case DirectionIncoming:
return "incoming"
return "Incoming"
case DirectionOutgoing:
return "outgoing"
return "Outgoing"
case DirectionBoth:
return "both"
default:
@@ -81,15 +142,14 @@ type Opts struct {
type QuicProxy struct {
mutex sync.Mutex
closeChan chan struct{}
conn *net.UDPConn
serverAddr *net.UDPAddr
dropPacket DropCallback
delayPacket DelayCallback
timerID uint64
timers map[uint64]*time.Timer
// Mapping from client addresses (as host:port) to connection
clientDict map[string]*connection
@@ -127,10 +187,10 @@ func NewQuicProxy(local string, opts *Opts) (*QuicProxy, error) {
p := QuicProxy{
clientDict: make(map[string]*connection),
conn: conn,
closeChan: make(chan struct{}),
serverAddr: raddr,
dropPacket: packetDropper,
delayPacket: packetDelayer,
timers: make(map[uint64]*time.Timer),
logger: utils.DefaultLogger.WithPrefix("proxy"),
}
@@ -143,13 +203,13 @@ func NewQuicProxy(local string, opts *Opts) (*QuicProxy, error) {
func (p *QuicProxy) Close() error {
p.mutex.Lock()
defer p.mutex.Unlock()
close(p.closeChan)
for _, c := range p.clientDict {
if err := c.ServerConn.Close(); err != nil {
return err
}
}
for _, t := range p.timers {
t.Stop()
c.Incoming.Close()
c.Outgoing.Close()
}
return p.conn.Close()
}
@@ -170,8 +230,11 @@ func (p *QuicProxy) newConnection(cliAddr *net.UDPAddr) (*connection, error) {
return nil, err
}
return &connection{
ClientAddr: cliAddr,
ServerConn: srvudp,
ClientAddr: cliAddr,
ServerConn: srvudp,
incomingPackets: make(chan packetEntry, 10),
Incoming: newQueue(),
Outgoing: newQueue(),
}, nil
}
@@ -196,7 +259,8 @@ func (p *QuicProxy) runProxy() error {
return err
}
p.clientDict[saddr] = conn
go p.runConnection(conn)
go p.runIncomingConnection(conn)
go p.runOutgoingConnection(conn)
}
p.mutex.Unlock()
@@ -207,75 +271,87 @@ func (p *QuicProxy) runProxy() error {
continue
}
// Send the packet to the server
delay := p.delayPacket(DirectionIncoming, raw)
if delay != 0 {
if delay == 0 {
if p.logger.Debug() {
p.logger.Debugf("delaying incoming packet (%d bytes) to %s by %s", n, conn.ServerConn.RemoteAddr(), delay)
}
p.mutex.Lock()
p.timerID++
id := p.timerID
timer := time.AfterFunc(delay, func() {
_, _ = conn.ServerConn.Write(raw) // TODO: handle error
p.mutex.Lock()
delete(p.timers, id)
p.mutex.Unlock()
})
p.timers[id] = timer
p.mutex.Unlock()
} else {
if p.logger.Debug() {
p.logger.Debugf("forwarding incoming packet (%d bytes) to %s", n, conn.ServerConn.RemoteAddr())
p.logger.Debugf("forwarding incoming packet (%d bytes) to %s", len(raw), conn.ServerConn.RemoteAddr())
}
if _, err := conn.ServerConn.Write(raw); err != nil {
return err
}
} else {
now := time.Now()
if p.logger.Debug() {
p.logger.Debugf("delaying incoming packet (%d bytes) to %s by %s", len(raw), conn.ServerConn.RemoteAddr(), delay)
}
conn.queuePacket(now.Add(delay), raw)
}
}
}
// runConnection handles packets from server to a single client
func (p *QuicProxy) runConnection(conn *connection) error {
func (p *QuicProxy) runOutgoingConnection(conn *connection) error {
outgoingPackets := make(chan packetEntry, 10)
go func() {
for {
buffer := make([]byte, protocol.MaxReceivePacketSize)
n, err := conn.ServerConn.Read(buffer)
if err != nil {
return
}
raw := buffer[0:n]
if p.dropPacket(DirectionOutgoing, raw) {
if p.logger.Debug() {
p.logger.Debugf("dropping outgoing packet(%d bytes)", n)
}
continue
}
delay := p.delayPacket(DirectionOutgoing, raw)
if delay == 0 {
if p.logger.Debug() {
p.logger.Debugf("forwarding outgoing packet (%d bytes) to %s", len(raw), conn.ClientAddr)
}
if _, err := p.conn.WriteToUDP(raw, conn.ClientAddr); err != nil {
return
}
} else {
now := time.Now()
if p.logger.Debug() {
p.logger.Debugf("delaying outgoing packet (%d bytes) to %s by %s", len(raw), conn.ClientAddr, delay)
}
outgoingPackets <- packetEntry{Time: now.Add(delay), Raw: raw}
}
}
}()
for {
buffer := make([]byte, protocol.MaxReceivePacketSize)
n, err := conn.ServerConn.Read(buffer)
if err != nil {
return err
}
raw := buffer[0:n]
if p.dropPacket(DirectionOutgoing, raw) {
if p.logger.Debug() {
p.logger.Debugf("dropping outgoing packet(%d bytes)", n)
}
continue
}
delay := p.delayPacket(DirectionOutgoing, raw)
if delay != 0 {
if p.logger.Debug() {
p.logger.Debugf("delaying outgoing packet (%d bytes) to %s by %s", n, conn.ClientAddr, delay)
}
p.mutex.Lock()
p.timerID++
id := p.timerID
timer := time.AfterFunc(delay, func() {
_, _ = p.conn.WriteToUDP(raw, conn.ClientAddr) // TODO: handle error
p.mutex.Lock()
delete(p.timers, id)
p.mutex.Unlock()
})
p.timers[id] = timer
p.mutex.Unlock()
} else {
if p.logger.Debug() {
p.logger.Debugf("forwarding outgoing packet (%d bytes) to %s", n, conn.ClientAddr)
}
if _, err := p.conn.WriteToUDP(raw, conn.ClientAddr); err != nil {
select {
case <-p.closeChan:
return nil
case e := <-outgoingPackets:
conn.Outgoing.Add(e)
case <-conn.Outgoing.Timer():
conn.Outgoing.SetTimerRead()
if _, err := p.conn.WriteTo(conn.Outgoing.Get(), conn.ClientAddr); err != nil {
return err
}
}
}
}
func (p *QuicProxy) runIncomingConnection(conn *connection) error {
for {
select {
case <-p.closeChan:
return nil
case e := <-conn.incomingPackets:
// Send the packet to the server
conn.Incoming.Add(e)
case <-conn.Incoming.Timer():
conn.Incoming.SetTimerRead()
if _, err := conn.ServerConn.Write(conn.Incoming.Get()); err != nil {
return err
}
}