diff --git a/integrationtests/udp_proxy.go b/integrationtests/udp_proxy.go index 26f7a0a1..18947b9b 100644 --- a/integrationtests/udp_proxy.go +++ b/integrationtests/udp_proxy.go @@ -10,23 +10,44 @@ import ( type connection struct { ClientAddr *net.UDPAddr // Address of the client ServerConn *net.UDPConn // UDP connection to server + + incomingPacketCounter packetNumber + outgoingPacketCounter packetNumber } +type packetNumber uint64 +type dropCallback func(packetNumber) bool + // UDPProxy is a UDP proxy type UDPProxy struct { serverAddr *net.UDPAddr mutex sync.Mutex - proxyConn *net.UDPConn + proxyConn *net.UDPConn + dropIncomingPacket dropCallback + dropOutgoingPacket dropCallback // Mapping from client addresses (as host:port) to connection clientDict map[string]*connection } // NewUDPProxy creates a new UDP proxy -func NewUDPProxy(proxyPort int, serverAddress string, serverPort int) (*UDPProxy, error) { +func NewUDPProxy(proxyPort int, serverAddress string, serverPort int, dropIncomingPacket, dropOutgoingPacket dropCallback) (*UDPProxy, error) { + dontDrop := func(p packetNumber) bool { + return false + } + + if dropIncomingPacket == nil { + dropIncomingPacket = dontDrop + } + if dropOutgoingPacket == nil { + dropOutgoingPacket = dontDrop + } + p := UDPProxy{ - clientDict: make(map[string]*connection), + clientDict: make(map[string]*connection), + dropIncomingPacket: dropIncomingPacket, + dropOutgoingPacket: dropOutgoingPacket, } saddr, err := net.ResolveUDPAddr("udp", ":"+strconv.Itoa(proxyPort)) @@ -94,10 +115,14 @@ func (p *UDPProxy) runProxy() error { p.mutex.Unlock() } - // Relay to server - _, err = conn.ServerConn.Write(buffer[0:n]) - if err != nil { - return err + conn.incomingPacketCounter++ + + if !p.dropIncomingPacket(conn.incomingPacketCounter) { + // Relay to server + _, err = conn.ServerConn.Write(buffer[0:n]) + if err != nil { + return err + } } } } @@ -112,10 +137,14 @@ func (p *UDPProxy) runConnection(conn *connection) error { return err } - // Relay it to client - _, err = p.proxyConn.WriteToUDP(buffer[0:n], conn.ClientAddr) - if err != nil { - return err + conn.outgoingPacketCounter++ + + if !p.dropOutgoingPacket(conn.outgoingPacketCounter) { + // Relay it to client + _, err = p.proxyConn.WriteToUDP(buffer[0:n], conn.ClientAddr) + if err != nil { + return err + } } } } diff --git a/integrationtests/udp_proxy_test.go b/integrationtests/udp_proxy_test.go index 746d225f..999f2051 100644 --- a/integrationtests/udp_proxy_test.go +++ b/integrationtests/udp_proxy_test.go @@ -2,6 +2,7 @@ package integrationtests import ( "net" + "strconv" "time" . "github.com/onsi/ginkgo" @@ -11,9 +12,16 @@ import ( type packetData []byte var _ = Describe("Integrationtests", func() { + var serverPort int + + BeforeEach(func() { + serverPort = 7331 + }) + It("sets up the UDPProxy", func() { - proxy, err := NewUDPProxy(13370, "localhost", 7331) + proxy, err := NewUDPProxy(13370, "localhost", serverPort, nil, nil) Expect(err).ToNot(HaveOccurred()) + Expect(proxy.clientDict).To(HaveLen(0)) // check that port 13370 is in use addr, err := net.ResolveUDPAddr("udp", ":13370") @@ -26,7 +34,7 @@ var _ = Describe("Integrationtests", func() { }) It("stops the UDPProxy", func() { - proxy, err := NewUDPProxy(13370, "localhost", 7331) + proxy, err := NewUDPProxy(13370, "localhost", serverPort, nil, nil) Expect(err).ToNot(HaveOccurred()) proxy.Stop() @@ -51,14 +59,11 @@ var _ = Describe("Integrationtests", func() { serverNumPacketsSent = 0 // setup a UDP server on port 7331 - serverAddr, err := net.ResolveUDPAddr("udp", ":7331") + serverAddr, err := net.ResolveUDPAddr("udp", ":"+strconv.Itoa(serverPort)) Expect(err).ToNot(HaveOccurred()) serverConn, err = net.ListenUDP("udp", serverAddr) Expect(err).ToNot(HaveOccurred()) - // setup the proxy - proxy, err = NewUDPProxy(10001, "localhost", 7331) - Expect(err).ToNot(HaveOccurred()) proxyAddr, err := net.ResolveUDPAddr("udp", ":10001") Expect(err).ToNot(HaveOccurred()) @@ -93,49 +98,135 @@ var _ = Describe("Integrationtests", func() { time.Sleep(time.Millisecond) }) - It("relays packets from the client to the server", func() { - _, err := clientConn.Write([]byte("foobar")) - Expect(err).ToNot(HaveOccurred()) - time.Sleep(time.Millisecond) - _, err = clientConn.Write([]byte("decafbad")) - Expect(err).ToNot(HaveOccurred()) - time.Sleep(time.Millisecond) - Expect(serverReceivedPackets).To(HaveLen(2)) - Expect(serverReceivedPackets[0]).To(Equal(packetData("foobar"))) - Expect(serverReceivedPackets[1]).To(Equal(packetData("decafbad"))) + Context("no packet drop", func() { + BeforeEach(func() { + // setup the proxy + var err error + proxy, err = NewUDPProxy(10001, "localhost", serverPort, nil, nil) + Expect(err).ToNot(HaveOccurred()) + }) + + It("relays packets from the client to the server", func() { + _, err := clientConn.Write([]byte("foobar")) + Expect(err).ToNot(HaveOccurred()) + time.Sleep(time.Millisecond) + Expect(proxy.clientDict).To(HaveLen(1)) + var key string + var conn *connection + for key, conn = range proxy.clientDict { + Expect(conn.incomingPacketCounter).To(Equal(packetNumber(1))) + } + _, err = clientConn.Write([]byte("decafbad")) + Expect(err).ToNot(HaveOccurred()) + time.Sleep(time.Millisecond) + Expect(proxy.clientDict).To(HaveLen(1)) + Expect(proxy.clientDict[key].incomingPacketCounter).To(Equal(packetNumber(2))) + Expect(serverReceivedPackets).To(HaveLen(2)) + Expect(serverReceivedPackets[0]).To(Equal(packetData("foobar"))) + Expect(serverReceivedPackets[1]).To(Equal(packetData("decafbad"))) + }) + + It("relays packets from the server to the client", func() { + _, err := clientConn.Write([]byte("foobar")) + Expect(err).ToNot(HaveOccurred()) + time.Sleep(time.Millisecond) + Expect(proxy.clientDict).To(HaveLen(1)) + var key string + var conn *connection + for key, conn = range proxy.clientDict { + Expect(conn.outgoingPacketCounter).To(Equal(packetNumber(1))) + } + _, err = clientConn.Write([]byte("decafbad")) + Expect(err).ToNot(HaveOccurred()) + time.Sleep(time.Millisecond) + Expect(proxy.clientDict).To(HaveLen(1)) + Expect(proxy.clientDict[key].outgoingPacketCounter).To(Equal(packetNumber(2))) + + var clientReceivedPackets []packetData + + // receive the packets echoed by the server on client side + go func() { + defer GinkgoRecover() + + for { + buf := make([]byte, 1500) + n, _, err2 := clientConn.ReadFromUDP(buf) + if err2 != nil { + return + } + data := buf[0:n] + clientReceivedPackets = append(clientReceivedPackets, packetData(data)) + } + }() + + time.Sleep(time.Millisecond) + Expect(serverReceivedPackets).To(HaveLen(2)) + Expect(serverNumPacketsSent).To(Equal(2)) + Expect(clientReceivedPackets).To(HaveLen(2)) + Expect(clientReceivedPackets[0]).To(Equal(packetData("foobar"))) + Expect(clientReceivedPackets[1]).To(Equal(packetData("decafbad"))) + }) }) - It("relays packets from the server to the client", func() { - _, err := clientConn.Write([]byte("foobar")) - Expect(err).ToNot(HaveOccurred()) - time.Sleep(time.Millisecond) - _, err = clientConn.Write([]byte("decafbad")) - Expect(err).ToNot(HaveOccurred()) - time.Sleep(time.Millisecond) - - var clientReceivedPackets []packetData - - // receive the packets echoed by the server on client side - go func() { - defer GinkgoRecover() - - for { - buf := make([]byte, 1500) - n, _, err2 := clientConn.ReadFromUDP(buf) - if err2 != nil { - return - } - data := buf[0:n] - clientReceivedPackets = append(clientReceivedPackets, packetData(data)) + Context("Drop Callbacks", func() { + It("drops incoming packets", func() { + dropper := func(p packetNumber) bool { + return p%2 == 0 } - }() - time.Sleep(time.Millisecond) - Expect(serverReceivedPackets).To(HaveLen(2)) - Expect(serverNumPacketsSent).To(Equal(2)) - Expect(clientReceivedPackets).To(HaveLen(2)) - Expect(clientReceivedPackets[0]).To(Equal(packetData("foobar"))) - Expect(clientReceivedPackets[1]).To(Equal(packetData("decafbad"))) + var err error + proxy, err = NewUDPProxy(10001, "localhost", serverPort, dropper, nil) + Expect(err).ToNot(HaveOccurred()) + + for i := 1; i <= 6; i++ { + _, err := clientConn.Write([]byte("foobar" + strconv.Itoa(i))) + Expect(err).ToNot(HaveOccurred()) + time.Sleep(time.Millisecond) + } + Expect(serverReceivedPackets).To(HaveLen(3)) + Expect(serverReceivedPackets[0]).To(Equal(packetData("foobar1"))) + Expect(serverReceivedPackets[1]).To(Equal(packetData("foobar3"))) + Expect(serverReceivedPackets[2]).To(Equal(packetData("foobar5"))) + }) + + It("drops outgoing packets", func() { + dropper := func(p packetNumber) bool { + return p%2 == 0 + } + + var err error + proxy, err = NewUDPProxy(10001, "localhost", serverPort, nil, dropper) + Expect(err).ToNot(HaveOccurred()) + + for i := 1; i <= 6; i++ { + _, err := clientConn.Write([]byte("foobar" + strconv.Itoa(i))) + Expect(err).ToNot(HaveOccurred()) + time.Sleep(time.Millisecond) + } + + var clientReceivedPackets []packetData + + // receive the packets echoed by the server on client side + go func() { + defer GinkgoRecover() + + for { + buf := make([]byte, 1500) + n, _, err2 := clientConn.ReadFromUDP(buf) + if err2 != nil { + return + } + data := buf[0:n] + clientReceivedPackets = append(clientReceivedPackets, packetData(data)) + } + }() + + time.Sleep(time.Millisecond) + Expect(clientReceivedPackets).To(HaveLen(3)) + Expect(clientReceivedPackets[0]).To(Equal(packetData("foobar1"))) + Expect(clientReceivedPackets[1]).To(Equal(packetData("foobar3"))) + Expect(clientReceivedPackets[2]).To(Equal(packetData("foobar5"))) + }) }) }) })