forked from quic-go/quic-go
remove udp references from session to simplify testing
This commit is contained in:
12
server.go
12
server.go
@@ -13,7 +13,7 @@ import (
|
||||
)
|
||||
|
||||
type PacketHandler interface {
|
||||
HandlePacket(addr *net.UDPAddr, publicHeader *PublicHeader, r *bytes.Reader)
|
||||
HandlePacket(addr interface{}, publicHeader *PublicHeader, r *bytes.Reader)
|
||||
Run()
|
||||
}
|
||||
|
||||
@@ -28,7 +28,7 @@ type Server struct {
|
||||
|
||||
streamCallback StreamCallback
|
||||
|
||||
newSession func(conn *net.UDPConn, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, streamCallback StreamCallback) PacketHandler
|
||||
newSession func(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, streamCallback StreamCallback) PacketHandler
|
||||
}
|
||||
|
||||
// NewServer makes a new server
|
||||
@@ -106,7 +106,13 @@ func (s *Server) handlePacket(conn *net.UDPConn, remoteAddr *net.UDPAddr, packet
|
||||
session, ok := s.sessions[publicHeader.ConnectionID]
|
||||
if !ok {
|
||||
fmt.Printf("Serving new connection: %d from %v\n", publicHeader.ConnectionID, remoteAddr)
|
||||
session = s.newSession(conn, publicHeader.VersionNumber, publicHeader.ConnectionID, s.scfg, s.streamCallback)
|
||||
session = s.newSession(
|
||||
&udpConn{conn: conn, currentAddr: remoteAddr},
|
||||
publicHeader.VersionNumber,
|
||||
publicHeader.ConnectionID,
|
||||
s.scfg,
|
||||
s.streamCallback,
|
||||
)
|
||||
go session.Run()
|
||||
s.sessions[publicHeader.ConnectionID] = session
|
||||
}
|
||||
|
||||
@@ -17,14 +17,14 @@ type mockSession struct {
|
||||
packetCount int
|
||||
}
|
||||
|
||||
func (s *mockSession) HandlePacket(addr *net.UDPAddr, publicHeader *PublicHeader, r *bytes.Reader) {
|
||||
func (s *mockSession) HandlePacket(addr interface{}, publicHeader *PublicHeader, r *bytes.Reader) {
|
||||
s.packetCount++
|
||||
}
|
||||
|
||||
func (s *mockSession) Run() {
|
||||
}
|
||||
|
||||
func newMockSession(conn *net.UDPConn, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, streamCallback StreamCallback) PacketHandler {
|
||||
func newMockSession(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, streamCallback StreamCallback) PacketHandler {
|
||||
return &mockSession{
|
||||
connectionID: connectionID,
|
||||
}
|
||||
|
||||
24
session.go
24
session.go
@@ -4,7 +4,6 @@ import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -17,7 +16,7 @@ import (
|
||||
)
|
||||
|
||||
type receivedPacket struct {
|
||||
addr *net.UDPAddr
|
||||
remoteAddr interface{}
|
||||
publicHeader *PublicHeader
|
||||
r *bytes.Reader
|
||||
}
|
||||
@@ -33,8 +32,7 @@ type StreamCallback func(*Session, utils.Stream)
|
||||
type Session struct {
|
||||
streamCallback StreamCallback
|
||||
|
||||
connection *net.UDPConn
|
||||
currentRemoteAddr *net.UDPAddr
|
||||
conn connection
|
||||
|
||||
streams map[protocol.StreamID]*stream
|
||||
streamsMutex sync.RWMutex
|
||||
@@ -52,9 +50,9 @@ type Session struct {
|
||||
}
|
||||
|
||||
// NewSession makes a new session
|
||||
func NewSession(conn *net.UDPConn, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, streamCallback StreamCallback) PacketHandler {
|
||||
func NewSession(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, streamCallback StreamCallback) PacketHandler {
|
||||
session := &Session{
|
||||
connection: conn,
|
||||
conn: conn,
|
||||
streamCallback: streamCallback,
|
||||
streams: make(map[protocol.StreamID]*stream),
|
||||
sentPacketHandler: ackhandler.NewSentPacketHandler(),
|
||||
@@ -79,7 +77,7 @@ func (s *Session) Run() {
|
||||
var err error
|
||||
select {
|
||||
case p := <-s.receivedPackets:
|
||||
err = s.handlePacket(p.addr, p.publicHeader, p.r)
|
||||
err = s.handlePacket(p.remoteAddr, p.publicHeader, p.r)
|
||||
case <-time.After(sendTimeout):
|
||||
err = s.sendPacket()
|
||||
}
|
||||
@@ -100,7 +98,7 @@ func (s *Session) Run() {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Session) handlePacket(addr *net.UDPAddr, publicHeader *PublicHeader, r *bytes.Reader) error {
|
||||
func (s *Session) handlePacket(remoteAddr interface{}, publicHeader *PublicHeader, r *bytes.Reader) error {
|
||||
// Calcualate packet number
|
||||
publicHeader.PacketNumber = calculatePacketNumber(
|
||||
publicHeader.PacketNumberLen,
|
||||
@@ -111,9 +109,7 @@ func (s *Session) handlePacket(addr *net.UDPAddr, publicHeader *PublicHeader, r
|
||||
fmt.Printf("<- Reading packet %d for connection %d\n", publicHeader.PacketNumber, publicHeader.ConnectionID)
|
||||
|
||||
// TODO: Only do this after authenticating
|
||||
if addr != s.currentRemoteAddr {
|
||||
s.currentRemoteAddr = addr
|
||||
}
|
||||
s.conn.setCurrentRemoteAddr(remoteAddr)
|
||||
|
||||
packet, err := s.unpacker.Unpack(publicHeader.Raw, publicHeader, r)
|
||||
if err != nil {
|
||||
@@ -149,8 +145,8 @@ func (s *Session) handlePacket(addr *net.UDPAddr, publicHeader *PublicHeader, r
|
||||
}
|
||||
|
||||
// HandlePacket handles a packet
|
||||
func (s *Session) HandlePacket(addr *net.UDPAddr, publicHeader *PublicHeader, r *bytes.Reader) {
|
||||
s.receivedPackets <- receivedPacket{addr: addr, publicHeader: publicHeader, r: r}
|
||||
func (s *Session) HandlePacket(remoteAddr interface{}, publicHeader *PublicHeader, r *bytes.Reader) {
|
||||
s.receivedPackets <- receivedPacket{remoteAddr: remoteAddr, publicHeader: publicHeader, r: r}
|
||||
}
|
||||
|
||||
// TODO: Ignore data for closed streams
|
||||
@@ -244,7 +240,7 @@ func (s *Session) sendPacket() error {
|
||||
EntropyBit: packet.entropyBit,
|
||||
})
|
||||
fmt.Printf("-> Sending packet %d (%d bytes)\n", packet.number, len(packet.raw))
|
||||
_, err = s.connection.WriteToUDP(packet.raw, s.currentRemoteAddr)
|
||||
err = s.conn.write(packet.raw)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -17,6 +17,11 @@ import (
|
||||
"github.com/lucas-clemente/quic-go/utils"
|
||||
)
|
||||
|
||||
type mockConnection struct{}
|
||||
|
||||
func (*mockConnection) write(p []byte) error { return nil }
|
||||
func (*mockConnection) setCurrentRemoteAddr(addr interface{}) {}
|
||||
|
||||
var _ = Describe("Session", func() {
|
||||
var (
|
||||
session *Session
|
||||
@@ -172,7 +177,7 @@ var _ = Describe("Session", func() {
|
||||
signer, err := crypto.NewRSASigner(path+"cert.der", path+"key.der")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
scfg := handshake.NewServerConfig(crypto.NewCurve25519KEX(), signer)
|
||||
session = NewSession(nil, 0, 0, scfg, nil).(*Session)
|
||||
session = NewSession(&mockConnection{}, 0, 0, scfg, nil).(*Session)
|
||||
})
|
||||
|
||||
It("shuts down without error", func() {
|
||||
|
||||
24
udp_conn.go
Normal file
24
udp_conn.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package quic
|
||||
|
||||
import "net"
|
||||
|
||||
type connection interface {
|
||||
write([]byte) error
|
||||
setCurrentRemoteAddr(interface{})
|
||||
}
|
||||
|
||||
type udpConn struct {
|
||||
conn *net.UDPConn
|
||||
currentAddr *net.UDPAddr
|
||||
}
|
||||
|
||||
var _ connection = &udpConn{}
|
||||
|
||||
func (c *udpConn) write(p []byte) error {
|
||||
_, err := c.conn.WriteToUDP(p, c.currentAddr)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *udpConn) setCurrentRemoteAddr(addr interface{}) {
|
||||
c.currentAddr = addr.(*net.UDPAddr)
|
||||
}
|
||||
Reference in New Issue
Block a user