remove udp references from session to simplify testing

This commit is contained in:
Lucas Clemente
2016-04-26 18:00:41 +02:00
parent 8339f210cb
commit a5a06a25c2
5 changed files with 51 additions and 20 deletions

View File

@@ -13,7 +13,7 @@ import (
)
type PacketHandler interface {
HandlePacket(addr *net.UDPAddr, publicHeader *PublicHeader, r *bytes.Reader)
HandlePacket(addr interface{}, publicHeader *PublicHeader, r *bytes.Reader)
Run()
}
@@ -28,7 +28,7 @@ type Server struct {
streamCallback StreamCallback
newSession func(conn *net.UDPConn, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, streamCallback StreamCallback) PacketHandler
newSession func(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, streamCallback StreamCallback) PacketHandler
}
// NewServer makes a new server
@@ -106,7 +106,13 @@ func (s *Server) handlePacket(conn *net.UDPConn, remoteAddr *net.UDPAddr, packet
session, ok := s.sessions[publicHeader.ConnectionID]
if !ok {
fmt.Printf("Serving new connection: %d from %v\n", publicHeader.ConnectionID, remoteAddr)
session = s.newSession(conn, publicHeader.VersionNumber, publicHeader.ConnectionID, s.scfg, s.streamCallback)
session = s.newSession(
&udpConn{conn: conn, currentAddr: remoteAddr},
publicHeader.VersionNumber,
publicHeader.ConnectionID,
s.scfg,
s.streamCallback,
)
go session.Run()
s.sessions[publicHeader.ConnectionID] = session
}

View File

@@ -17,14 +17,14 @@ type mockSession struct {
packetCount int
}
func (s *mockSession) HandlePacket(addr *net.UDPAddr, publicHeader *PublicHeader, r *bytes.Reader) {
func (s *mockSession) HandlePacket(addr interface{}, publicHeader *PublicHeader, r *bytes.Reader) {
s.packetCount++
}
func (s *mockSession) Run() {
}
func newMockSession(conn *net.UDPConn, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, streamCallback StreamCallback) PacketHandler {
func newMockSession(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, streamCallback StreamCallback) PacketHandler {
return &mockSession{
connectionID: connectionID,
}

View File

@@ -4,7 +4,6 @@ import (
"bytes"
"errors"
"fmt"
"net"
"sync"
"time"
@@ -17,7 +16,7 @@ import (
)
type receivedPacket struct {
addr *net.UDPAddr
remoteAddr interface{}
publicHeader *PublicHeader
r *bytes.Reader
}
@@ -33,8 +32,7 @@ type StreamCallback func(*Session, utils.Stream)
type Session struct {
streamCallback StreamCallback
connection *net.UDPConn
currentRemoteAddr *net.UDPAddr
conn connection
streams map[protocol.StreamID]*stream
streamsMutex sync.RWMutex
@@ -52,9 +50,9 @@ type Session struct {
}
// NewSession makes a new session
func NewSession(conn *net.UDPConn, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, streamCallback StreamCallback) PacketHandler {
func NewSession(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, streamCallback StreamCallback) PacketHandler {
session := &Session{
connection: conn,
conn: conn,
streamCallback: streamCallback,
streams: make(map[protocol.StreamID]*stream),
sentPacketHandler: ackhandler.NewSentPacketHandler(),
@@ -79,7 +77,7 @@ func (s *Session) Run() {
var err error
select {
case p := <-s.receivedPackets:
err = s.handlePacket(p.addr, p.publicHeader, p.r)
err = s.handlePacket(p.remoteAddr, p.publicHeader, p.r)
case <-time.After(sendTimeout):
err = s.sendPacket()
}
@@ -100,7 +98,7 @@ func (s *Session) Run() {
}
}
func (s *Session) handlePacket(addr *net.UDPAddr, publicHeader *PublicHeader, r *bytes.Reader) error {
func (s *Session) handlePacket(remoteAddr interface{}, publicHeader *PublicHeader, r *bytes.Reader) error {
// Calcualate packet number
publicHeader.PacketNumber = calculatePacketNumber(
publicHeader.PacketNumberLen,
@@ -111,9 +109,7 @@ func (s *Session) handlePacket(addr *net.UDPAddr, publicHeader *PublicHeader, r
fmt.Printf("<- Reading packet %d for connection %d\n", publicHeader.PacketNumber, publicHeader.ConnectionID)
// TODO: Only do this after authenticating
if addr != s.currentRemoteAddr {
s.currentRemoteAddr = addr
}
s.conn.setCurrentRemoteAddr(remoteAddr)
packet, err := s.unpacker.Unpack(publicHeader.Raw, publicHeader, r)
if err != nil {
@@ -149,8 +145,8 @@ func (s *Session) handlePacket(addr *net.UDPAddr, publicHeader *PublicHeader, r
}
// HandlePacket handles a packet
func (s *Session) HandlePacket(addr *net.UDPAddr, publicHeader *PublicHeader, r *bytes.Reader) {
s.receivedPackets <- receivedPacket{addr: addr, publicHeader: publicHeader, r: r}
func (s *Session) HandlePacket(remoteAddr interface{}, publicHeader *PublicHeader, r *bytes.Reader) {
s.receivedPackets <- receivedPacket{remoteAddr: remoteAddr, publicHeader: publicHeader, r: r}
}
// TODO: Ignore data for closed streams
@@ -244,7 +240,7 @@ func (s *Session) sendPacket() error {
EntropyBit: packet.entropyBit,
})
fmt.Printf("-> Sending packet %d (%d bytes)\n", packet.number, len(packet.raw))
_, err = s.connection.WriteToUDP(packet.raw, s.currentRemoteAddr)
err = s.conn.write(packet.raw)
if err != nil {
return err
}

View File

@@ -17,6 +17,11 @@ import (
"github.com/lucas-clemente/quic-go/utils"
)
type mockConnection struct{}
func (*mockConnection) write(p []byte) error { return nil }
func (*mockConnection) setCurrentRemoteAddr(addr interface{}) {}
var _ = Describe("Session", func() {
var (
session *Session
@@ -172,7 +177,7 @@ var _ = Describe("Session", func() {
signer, err := crypto.NewRSASigner(path+"cert.der", path+"key.der")
Expect(err).ToNot(HaveOccurred())
scfg := handshake.NewServerConfig(crypto.NewCurve25519KEX(), signer)
session = NewSession(nil, 0, 0, scfg, nil).(*Session)
session = NewSession(&mockConnection{}, 0, 0, scfg, nil).(*Session)
})
It("shuts down without error", func() {

24
udp_conn.go Normal file
View File

@@ -0,0 +1,24 @@
package quic
import "net"
type connection interface {
write([]byte) error
setCurrentRemoteAddr(interface{})
}
type udpConn struct {
conn *net.UDPConn
currentAddr *net.UDPAddr
}
var _ connection = &udpConn{}
func (c *udpConn) write(p []byte) error {
_, err := c.conn.WriteToUDP(p, c.currentAddr)
return err
}
func (c *udpConn) setCurrentRemoteAddr(addr interface{}) {
c.currentAddr = addr.(*net.UDPAddr)
}