forked from quic-go/quic-go
rewrite the proxy to avoid packet reordering
This commit is contained in:
@@ -2,6 +2,7 @@ package quicproxy
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -13,6 +14,15 @@ import (
|
||||
type connection struct {
|
||||
ClientAddr *net.UDPAddr // Address of the client
|
||||
ServerConn *net.UDPConn // UDP connection to server
|
||||
|
||||
incomingPackets chan packetEntry
|
||||
|
||||
Incoming *queue
|
||||
Outgoing *queue
|
||||
}
|
||||
|
||||
func (c *connection) queuePacket(t time.Time, b []byte) {
|
||||
c.incomingPackets <- packetEntry{Time: t, Raw: b}
|
||||
}
|
||||
|
||||
// Direction is the direction a packet is sent.
|
||||
@@ -27,12 +37,63 @@ const (
|
||||
DirectionBoth
|
||||
)
|
||||
|
||||
type packetEntry struct {
|
||||
Time time.Time
|
||||
Raw []byte
|
||||
}
|
||||
|
||||
type packetEntries []packetEntry
|
||||
|
||||
func (e packetEntries) Len() int { return len(e) }
|
||||
func (e packetEntries) Less(i, j int) bool { return e[i].Time.Before(e[j].Time) }
|
||||
func (e packetEntries) Swap(i, j int) { e[i], e[j] = e[j], e[i] }
|
||||
|
||||
type queue struct {
|
||||
sync.Mutex
|
||||
|
||||
timer *utils.Timer
|
||||
Packets packetEntries
|
||||
}
|
||||
|
||||
func newQueue() *queue {
|
||||
return &queue{timer: utils.NewTimer()}
|
||||
}
|
||||
|
||||
func (q *queue) Add(e packetEntry) {
|
||||
q.Lock()
|
||||
q.Packets = append(q.Packets, e)
|
||||
if len(q.Packets) > 1 {
|
||||
lastIndex := len(q.Packets) - 1
|
||||
if q.Packets[lastIndex].Time.Before(q.Packets[lastIndex-1].Time) {
|
||||
sort.Stable(q.Packets)
|
||||
}
|
||||
}
|
||||
q.timer.Reset(q.Packets[0].Time)
|
||||
q.Unlock()
|
||||
}
|
||||
|
||||
func (q *queue) Get() []byte {
|
||||
q.Lock()
|
||||
raw := q.Packets[0].Raw
|
||||
q.Packets = q.Packets[1:]
|
||||
if len(q.Packets) > 0 {
|
||||
q.timer.Reset(q.Packets[0].Time)
|
||||
}
|
||||
q.Unlock()
|
||||
return raw
|
||||
}
|
||||
|
||||
func (q *queue) Timer() <-chan time.Time { return q.timer.Chan() }
|
||||
func (q *queue) SetTimerRead() { q.timer.SetRead() }
|
||||
|
||||
func (q *queue) Close() { q.timer.Stop() }
|
||||
|
||||
func (d Direction) String() string {
|
||||
switch d {
|
||||
case DirectionIncoming:
|
||||
return "incoming"
|
||||
return "Incoming"
|
||||
case DirectionOutgoing:
|
||||
return "outgoing"
|
||||
return "Outgoing"
|
||||
case DirectionBoth:
|
||||
return "both"
|
||||
default:
|
||||
@@ -81,15 +142,14 @@ type Opts struct {
|
||||
type QuicProxy struct {
|
||||
mutex sync.Mutex
|
||||
|
||||
closeChan chan struct{}
|
||||
|
||||
conn *net.UDPConn
|
||||
serverAddr *net.UDPAddr
|
||||
|
||||
dropPacket DropCallback
|
||||
delayPacket DelayCallback
|
||||
|
||||
timerID uint64
|
||||
timers map[uint64]*time.Timer
|
||||
|
||||
// Mapping from client addresses (as host:port) to connection
|
||||
clientDict map[string]*connection
|
||||
|
||||
@@ -127,10 +187,10 @@ func NewQuicProxy(local string, opts *Opts) (*QuicProxy, error) {
|
||||
p := QuicProxy{
|
||||
clientDict: make(map[string]*connection),
|
||||
conn: conn,
|
||||
closeChan: make(chan struct{}),
|
||||
serverAddr: raddr,
|
||||
dropPacket: packetDropper,
|
||||
delayPacket: packetDelayer,
|
||||
timers: make(map[uint64]*time.Timer),
|
||||
logger: utils.DefaultLogger.WithPrefix("proxy"),
|
||||
}
|
||||
|
||||
@@ -143,13 +203,13 @@ func NewQuicProxy(local string, opts *Opts) (*QuicProxy, error) {
|
||||
func (p *QuicProxy) Close() error {
|
||||
p.mutex.Lock()
|
||||
defer p.mutex.Unlock()
|
||||
close(p.closeChan)
|
||||
for _, c := range p.clientDict {
|
||||
if err := c.ServerConn.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
for _, t := range p.timers {
|
||||
t.Stop()
|
||||
c.Incoming.Close()
|
||||
c.Outgoing.Close()
|
||||
}
|
||||
return p.conn.Close()
|
||||
}
|
||||
@@ -170,8 +230,11 @@ func (p *QuicProxy) newConnection(cliAddr *net.UDPAddr) (*connection, error) {
|
||||
return nil, err
|
||||
}
|
||||
return &connection{
|
||||
ClientAddr: cliAddr,
|
||||
ServerConn: srvudp,
|
||||
ClientAddr: cliAddr,
|
||||
ServerConn: srvudp,
|
||||
incomingPackets: make(chan packetEntry, 10),
|
||||
Incoming: newQueue(),
|
||||
Outgoing: newQueue(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -196,7 +259,8 @@ func (p *QuicProxy) runProxy() error {
|
||||
return err
|
||||
}
|
||||
p.clientDict[saddr] = conn
|
||||
go p.runConnection(conn)
|
||||
go p.runIncomingConnection(conn)
|
||||
go p.runOutgoingConnection(conn)
|
||||
}
|
||||
p.mutex.Unlock()
|
||||
|
||||
@@ -207,75 +271,87 @@ func (p *QuicProxy) runProxy() error {
|
||||
continue
|
||||
}
|
||||
|
||||
// Send the packet to the server
|
||||
delay := p.delayPacket(DirectionIncoming, raw)
|
||||
if delay != 0 {
|
||||
if delay == 0 {
|
||||
if p.logger.Debug() {
|
||||
p.logger.Debugf("delaying incoming packet (%d bytes) to %s by %s", n, conn.ServerConn.RemoteAddr(), delay)
|
||||
}
|
||||
|
||||
p.mutex.Lock()
|
||||
p.timerID++
|
||||
id := p.timerID
|
||||
timer := time.AfterFunc(delay, func() {
|
||||
_, _ = conn.ServerConn.Write(raw) // TODO: handle error
|
||||
|
||||
p.mutex.Lock()
|
||||
delete(p.timers, id)
|
||||
p.mutex.Unlock()
|
||||
})
|
||||
p.timers[id] = timer
|
||||
p.mutex.Unlock()
|
||||
} else {
|
||||
if p.logger.Debug() {
|
||||
p.logger.Debugf("forwarding incoming packet (%d bytes) to %s", n, conn.ServerConn.RemoteAddr())
|
||||
p.logger.Debugf("forwarding incoming packet (%d bytes) to %s", len(raw), conn.ServerConn.RemoteAddr())
|
||||
}
|
||||
if _, err := conn.ServerConn.Write(raw); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
now := time.Now()
|
||||
if p.logger.Debug() {
|
||||
p.logger.Debugf("delaying incoming packet (%d bytes) to %s by %s", len(raw), conn.ServerConn.RemoteAddr(), delay)
|
||||
}
|
||||
conn.queuePacket(now.Add(delay), raw)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// runConnection handles packets from server to a single client
|
||||
func (p *QuicProxy) runConnection(conn *connection) error {
|
||||
func (p *QuicProxy) runOutgoingConnection(conn *connection) error {
|
||||
outgoingPackets := make(chan packetEntry, 10)
|
||||
go func() {
|
||||
for {
|
||||
buffer := make([]byte, protocol.MaxReceivePacketSize)
|
||||
n, err := conn.ServerConn.Read(buffer)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
raw := buffer[0:n]
|
||||
|
||||
if p.dropPacket(DirectionOutgoing, raw) {
|
||||
if p.logger.Debug() {
|
||||
p.logger.Debugf("dropping outgoing packet(%d bytes)", n)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
delay := p.delayPacket(DirectionOutgoing, raw)
|
||||
if delay == 0 {
|
||||
if p.logger.Debug() {
|
||||
p.logger.Debugf("forwarding outgoing packet (%d bytes) to %s", len(raw), conn.ClientAddr)
|
||||
}
|
||||
if _, err := p.conn.WriteToUDP(raw, conn.ClientAddr); err != nil {
|
||||
return
|
||||
}
|
||||
} else {
|
||||
now := time.Now()
|
||||
if p.logger.Debug() {
|
||||
p.logger.Debugf("delaying outgoing packet (%d bytes) to %s by %s", len(raw), conn.ClientAddr, delay)
|
||||
}
|
||||
outgoingPackets <- packetEntry{Time: now.Add(delay), Raw: raw}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
buffer := make([]byte, protocol.MaxReceivePacketSize)
|
||||
n, err := conn.ServerConn.Read(buffer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
raw := buffer[0:n]
|
||||
|
||||
if p.dropPacket(DirectionOutgoing, raw) {
|
||||
if p.logger.Debug() {
|
||||
p.logger.Debugf("dropping outgoing packet(%d bytes)", n)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
delay := p.delayPacket(DirectionOutgoing, raw)
|
||||
if delay != 0 {
|
||||
if p.logger.Debug() {
|
||||
p.logger.Debugf("delaying outgoing packet (%d bytes) to %s by %s", n, conn.ClientAddr, delay)
|
||||
}
|
||||
|
||||
p.mutex.Lock()
|
||||
p.timerID++
|
||||
id := p.timerID
|
||||
timer := time.AfterFunc(delay, func() {
|
||||
_, _ = p.conn.WriteToUDP(raw, conn.ClientAddr) // TODO: handle error
|
||||
p.mutex.Lock()
|
||||
delete(p.timers, id)
|
||||
p.mutex.Unlock()
|
||||
})
|
||||
p.timers[id] = timer
|
||||
p.mutex.Unlock()
|
||||
} else {
|
||||
if p.logger.Debug() {
|
||||
p.logger.Debugf("forwarding outgoing packet (%d bytes) to %s", n, conn.ClientAddr)
|
||||
}
|
||||
if _, err := p.conn.WriteToUDP(raw, conn.ClientAddr); err != nil {
|
||||
select {
|
||||
case <-p.closeChan:
|
||||
return nil
|
||||
case e := <-outgoingPackets:
|
||||
conn.Outgoing.Add(e)
|
||||
case <-conn.Outgoing.Timer():
|
||||
conn.Outgoing.SetTimerRead()
|
||||
if _, err := p.conn.WriteTo(conn.Outgoing.Get(), conn.ClientAddr); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *QuicProxy) runIncomingConnection(conn *connection) error {
|
||||
for {
|
||||
select {
|
||||
case <-p.closeChan:
|
||||
return nil
|
||||
case e := <-conn.incomingPackets:
|
||||
// Send the packet to the server
|
||||
conn.Incoming.Add(e)
|
||||
case <-conn.Incoming.Timer():
|
||||
conn.Incoming.SetTimerRead()
|
||||
if _, err := conn.ServerConn.Write(conn.Incoming.Get()); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user