Files
quic-go/integrationtests/udp_proxy.go
2016-06-08 11:22:24 +07:00

151 lines
3.1 KiB
Go

package integrationtests
import (
"net"
"strconv"
"sync"
)
// connection holds the per-client state kept by the proxy: the address of the
// client, the UDP connection used to reach the server on that client's behalf,
// and one packet counter per direction.
type connection struct {
ClientAddr *net.UDPAddr // Address of the client
ServerConn *net.UDPConn // UDP connection to server
// incomingPacketCounter is incremented for every packet received from the
// client, before the incoming drop callback is consulted.
incomingPacketCounter PacketNumber
// outgoingPacketCounter is incremented for every packet received from the
// server, before the outgoing drop callback is consulted.
outgoingPacketCounter PacketNumber
}
// PacketNumber is the ordinal of a packet within one direction of a proxied
// connection. The counters start at zero and are incremented before use, so
// the first packet in each direction has number 1.
type PacketNumber uint64
// dropCallback decides what happens to the packet with the given number: the
// proxy relays the packet when the callback returns false and skips relaying
// it when the callback returns true.
type dropCallback func(PacketNumber) bool
// UDPProxy is a UDP proxy that relays packets between clients and a single
// server, consulting a drop callback for each packet in either direction.
type UDPProxy struct {
// serverAddr is the resolved address of the server packets are relayed to.
serverAddr *net.UDPAddr
// mutex is held while clientDict is looked up or modified.
mutex sync.Mutex
// proxyConn is the socket on which the proxy receives client packets and
// through which it sends server packets back to clients.
proxyConn *net.UDPConn
// dropIncomingPacket is consulted for packets going from a client to the server.
dropIncomingPacket dropCallback
// dropOutgoingPacket is consulted for packets going from the server to a client.
dropOutgoingPacket dropCallback
// Mapping from client addresses (as host:port) to connection
clientDict map[string]*connection
}
// NewUDPProxy creates a new UDP proxy and starts relaying in a background
// goroutine.
//
// The proxy listens on proxyPort and forwards packets to
// serverAddress:serverPort. dropIncomingPacket is consulted for every packet
// travelling from a client to the server, dropOutgoingPacket for every packet
// travelling from the server to a client; a packet is not relayed when its
// callback returns true. A nil callback means that no packet is dropped in
// that direction.
//
// An error is returned when either address cannot be resolved or the proxy
// port cannot be bound. No socket is left open in that case.
func NewUDPProxy(proxyPort int, serverAddress string, serverPort int, dropIncomingPacket, dropOutgoingPacket dropCallback) (*UDPProxy, error) {
	dontDrop := func(PacketNumber) bool {
		return false
	}
	if dropIncomingPacket == nil {
		dropIncomingPacket = dontDrop
	}
	if dropOutgoingPacket == nil {
		dropOutgoingPacket = dontDrop
	}

	listenAddr, err := net.ResolveUDPAddr("udp", ":"+strconv.Itoa(proxyPort))
	if err != nil {
		return nil, err
	}
	proxyConn, err := net.ListenUDP("udp", listenAddr)
	if err != nil {
		return nil, err
	}
	serverAddr, err := net.ResolveUDPAddr("udp", serverAddress+":"+strconv.Itoa(serverPort))
	if err != nil {
		// The proxy port is already bound at this point; release it so a
		// failed construction does not leak the socket. The close error is
		// ignored because the resolve error is the one the caller needs.
		_ = proxyConn.Close()
		return nil, err
	}

	p := &UDPProxy{
		serverAddr:         serverAddr,
		proxyConn:          proxyConn,
		dropIncomingPacket: dropIncomingPacket,
		dropOutgoingPacket: dropOutgoingPacket,
		clientDict:         make(map[string]*connection),
	}
	// The error returned by runProxy is discarded; the loop ends when reading
	// from proxyConn fails.
	go p.runProxy()
	return p, nil
}
// Stop stops the UDP Proxy.
//
// It closes the listening socket and the server-side socket of every client
// the proxy has seen, so that the sockets are released and the goroutines
// blocked reading from them get an error and return.
func (p *UDPProxy) Stop() {
	// Close errors are deliberately ignored: Stop has no error return and
	// there is nothing useful to do about a failed close during shutdown.
	_ = p.proxyConn.Close()

	// clientDict is modified by runProxy under the same mutex.
	p.mutex.Lock()
	defer p.mutex.Unlock()
	for _, conn := range p.clientDict {
		_ = conn.ServerConn.Close()
	}
}
// newConnection dials the server over UDP and returns a connection that pairs
// the resulting socket with the given client address. It returns the dial
// error, and no connection, when the server cannot be dialed.
func (p *UDPProxy) newConnection(cliAddr *net.UDPAddr) (*connection, error) {
	serverConn, err := net.DialUDP("udp", nil, p.serverAddr)
	if err != nil {
		return nil, err
	}
	return &connection{
		ClientAddr: cliAddr,
		ServerConn: serverConn,
	}, nil
}
// runProxy is the receive loop for the proxy port. Every packet read from
// proxyConn is attributed to a connection keyed by the sender's address; the
// first packet from an unknown address creates that connection and starts a
// runConnection goroutine for it. The packet is then written to the server
// unless dropIncomingPacket returns true for its packet number.
//
// The loop runs until reading, creating a connection or writing to the server
// fails, and returns that error.
func (p *UDPProxy) runProxy() error {
	packet := make([]byte, 1500)
	for {
		size, clientAddr, err := p.proxyConn.ReadFromUDP(packet)
		if err != nil {
			return err
		}
		key := clientAddr.String()

		p.mutex.Lock()
		conn, known := p.clientDict[key]
		if !known {
			if conn, err = p.newConnection(clientAddr); err != nil {
				p.mutex.Unlock()
				return err
			}
			p.clientDict[key] = conn
		}
		p.mutex.Unlock()

		// Start the server-to-client relay only after the lock is released.
		if !known {
			go p.runConnection(conn)
		}

		conn.incomingPacketCounter++
		if p.dropIncomingPacket(conn.incomingPacketCounter) {
			continue
		}
		// Relay to server
		if _, err = conn.ServerConn.Write(packet[:size]); err != nil {
			return err
		}
	}
}
// runConnection is the receive loop for one client's server-side socket. Each
// packet read from the server is sent to the client through proxyConn unless
// dropOutgoingPacket returns true for its packet number.
//
// The loop runs until reading from the server or writing to the client fails,
// and returns that error.
func (p *UDPProxy) runConnection(conn *connection) error {
	packet := make([]byte, 1500)
	for {
		size, err := conn.ServerConn.Read(packet)
		if err != nil {
			return err
		}

		conn.outgoingPacketCounter++
		if p.dropOutgoingPacket(conn.outgoingPacketCounter) {
			continue
		}
		// Relay it to client
		if _, err = p.proxyConn.WriteToUDP(packet[:size], conn.ClientAddr); err != nil {
			return err
		}
	}
}