diff --git a/server.go b/server.go index d050fd7e..77011df0 100644 --- a/server.go +++ b/server.go @@ -13,7 +13,7 @@ import ( ) type PacketHandler interface { - HandlePacket(addr *net.UDPAddr, publicHeader *PublicHeader, r *bytes.Reader) + HandlePacket(addr interface{}, publicHeader *PublicHeader, r *bytes.Reader) Run() } @@ -28,7 +28,7 @@ type Server struct { streamCallback StreamCallback - newSession func(conn *net.UDPConn, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, streamCallback StreamCallback) PacketHandler + newSession func(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, streamCallback StreamCallback) PacketHandler } // NewServer makes a new server @@ -106,7 +106,13 @@ func (s *Server) handlePacket(conn *net.UDPConn, remoteAddr *net.UDPAddr, packet session, ok := s.sessions[publicHeader.ConnectionID] if !ok { fmt.Printf("Serving new connection: %d from %v\n", publicHeader.ConnectionID, remoteAddr) - session = s.newSession(conn, publicHeader.VersionNumber, publicHeader.ConnectionID, s.scfg, s.streamCallback) + session = s.newSession( + &udpConn{conn: conn, currentAddr: remoteAddr}, + publicHeader.VersionNumber, + publicHeader.ConnectionID, + s.scfg, + s.streamCallback, + ) go session.Run() s.sessions[publicHeader.ConnectionID] = session } diff --git a/server_test.go b/server_test.go index d63a0983..1aa987ef 100644 --- a/server_test.go +++ b/server_test.go @@ -17,14 +17,14 @@ type mockSession struct { packetCount int } -func (s *mockSession) HandlePacket(addr *net.UDPAddr, publicHeader *PublicHeader, r *bytes.Reader) { +func (s *mockSession) HandlePacket(addr interface{}, publicHeader *PublicHeader, r *bytes.Reader) { s.packetCount++ } func (s *mockSession) Run() { } -func newMockSession(conn *net.UDPConn, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, streamCallback StreamCallback) PacketHandler { +func newMockSession(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, streamCallback StreamCallback) PacketHandler { return &mockSession{ connectionID: connectionID, } diff --git a/session.go b/session.go index 34dea457..76161d10 100644 --- a/session.go +++ b/session.go @@ -4,7 +4,6 @@ import ( "bytes" "errors" "fmt" - "net" "sync" "time" @@ -17,7 +16,7 @@ import ( ) type receivedPacket struct { - addr *net.UDPAddr + remoteAddr interface{} publicHeader *PublicHeader r *bytes.Reader } @@ -33,8 +32,7 @@ type StreamCallback func(*Session, utils.Stream) type Session struct { streamCallback StreamCallback - connection *net.UDPConn - currentRemoteAddr *net.UDPAddr + conn connection streams map[protocol.StreamID]*stream streamsMutex sync.RWMutex @@ -52,9 +50,9 @@ type Session struct { } // NewSession makes a new session -func NewSession(conn *net.UDPConn, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, streamCallback StreamCallback) PacketHandler { +func NewSession(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, streamCallback StreamCallback) PacketHandler { session := &Session{ - connection: conn, + conn: conn, streamCallback: streamCallback, streams: make(map[protocol.StreamID]*stream), sentPacketHandler: ackhandler.NewSentPacketHandler(), @@ -79,7 +77,7 @@ func (s *Session) Run() { var err error select { case p := <-s.receivedPackets: - err = s.handlePacket(p.addr, p.publicHeader, p.r) + err = s.handlePacket(p.remoteAddr, p.publicHeader, p.r) case <-time.After(sendTimeout): err = s.sendPacket() } @@ -100,7 +98,7 @@ func (s *Session) Run() { } } -func (s *Session) handlePacket(addr *net.UDPAddr, publicHeader *PublicHeader, r *bytes.Reader) error { +func (s *Session) handlePacket(remoteAddr interface{}, publicHeader *PublicHeader, r *bytes.Reader) error { // Calcualate packet number publicHeader.PacketNumber = calculatePacketNumber( publicHeader.PacketNumberLen, @@ -111,9 +109,7 @@ func (s *Session) handlePacket(addr *net.UDPAddr, publicHeader *PublicHeader, r fmt.Printf("<- Reading packet %d for connection %d\n", publicHeader.PacketNumber, publicHeader.ConnectionID) // TODO: Only do this after authenticating - if addr != s.currentRemoteAddr { - s.currentRemoteAddr = addr - } + s.conn.setCurrentRemoteAddr(remoteAddr) packet, err := s.unpacker.Unpack(publicHeader.Raw, publicHeader, r) if err != nil { @@ -149,8 +145,8 @@ func (s *Session) handlePacket(addr *net.UDPAddr, publicHeader *PublicHeader, r } // HandlePacket handles a packet -func (s *Session) HandlePacket(addr *net.UDPAddr, publicHeader *PublicHeader, r *bytes.Reader) { - s.receivedPackets <- receivedPacket{addr: addr, publicHeader: publicHeader, r: r} +func (s *Session) HandlePacket(remoteAddr interface{}, publicHeader *PublicHeader, r *bytes.Reader) { + s.receivedPackets <- receivedPacket{remoteAddr: remoteAddr, publicHeader: publicHeader, r: r} } // TODO: Ignore data for closed streams @@ -244,7 +240,7 @@ func (s *Session) sendPacket() error { EntropyBit: packet.entropyBit, }) fmt.Printf("-> Sending packet %d (%d bytes)\n", packet.number, len(packet.raw)) - _, err = s.connection.WriteToUDP(packet.raw, s.currentRemoteAddr) + err = s.conn.write(packet.raw) if err != nil { return err } diff --git a/session_test.go b/session_test.go index f5d1387f..abf68f3b 100644 --- a/session_test.go +++ b/session_test.go @@ -17,6 +17,11 @@ import ( "github.com/lucas-clemente/quic-go/utils" ) +type mockConnection struct{} + +func (*mockConnection) write(p []byte) error { return nil } +func (*mockConnection) setCurrentRemoteAddr(addr interface{}) {} + var _ = Describe("Session", func() { var ( session *Session @@ -172,7 +177,7 @@ var _ = Describe("Session", func() { signer, err := crypto.NewRSASigner(path+"cert.der", path+"key.der") Expect(err).ToNot(HaveOccurred()) scfg := handshake.NewServerConfig(crypto.NewCurve25519KEX(), signer) - session = NewSession(nil, 0, 0, scfg, nil).(*Session) + session = NewSession(&mockConnection{}, 0, 0, scfg, nil).(*Session) }) It("shuts down without error", func() { diff --git a/udp_conn.go b/udp_conn.go new file mode 100644 index 00000000..1e2938ee --- /dev/null +++ b/udp_conn.go @@ -0,0 +1,24 @@ +package quic + +import "net" + +type connection interface { + write([]byte) error + setCurrentRemoteAddr(interface{}) +} + +type udpConn struct { + conn *net.UDPConn + currentAddr *net.UDPAddr +} + +var _ connection = &udpConn{} + +func (c *udpConn) write(p []byte) error { + _, err := c.conn.WriteToUDP(p, c.currentAddr) + return err +} + +func (c *udpConn) setCurrentRemoteAddr(addr interface{}) { + c.currentAddr = addr.(*net.UDPAddr) +}