forked from quic-go/quic-go
212
client.go
212
client.go
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
@@ -13,20 +12,19 @@ import (
|
||||
)
|
||||
|
||||
type client struct {
|
||||
sconn sendConn
|
||||
// If the client is created with DialAddr, we create a packet conn.
|
||||
// If it is started with Dial, we take a packet conn as a parameter.
|
||||
createdPacketConn bool
|
||||
sendConn sendConn
|
||||
|
||||
use0RTT bool
|
||||
|
||||
packetHandlers packetHandlerManager
|
||||
onClose func()
|
||||
|
||||
tlsConf *tls.Config
|
||||
config *Config
|
||||
|
||||
srcConnID protocol.ConnectionID
|
||||
destConnID protocol.ConnectionID
|
||||
connIDGenerator ConnectionIDGenerator
|
||||
srcConnID protocol.ConnectionID
|
||||
destConnID protocol.ConnectionID
|
||||
|
||||
initialPacketNumber protocol.PacketNumber
|
||||
hasNegotiatedVersion bool
|
||||
@@ -46,32 +44,58 @@ var generateConnectionIDForInitial = protocol.GenerateConnectionIDForInitial
|
||||
|
||||
// DialAddr establishes a new QUIC connection to a server.
|
||||
// It uses a new UDP connection and closes this connection when the QUIC connection is closed.
|
||||
// The hostname for SNI is taken from the given address.
|
||||
func DialAddr(ctx context.Context, addr string, tlsConf *tls.Config, config *Config) (Connection, error) {
|
||||
return dialAddrContext(ctx, addr, tlsConf, config, false)
|
||||
}
|
||||
|
||||
// DialAddrEarly establishes a new 0-RTT QUIC connection to a server.
|
||||
// It uses a new UDP connection and closes this connection when the QUIC connection is closed.
|
||||
func DialAddrEarly(ctx context.Context, addr string, tlsConf *tls.Config, config *Config) (EarlyConnection, error) {
|
||||
conn, err := dialAddrContext(ctx, addr, tlsConf, config, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
utils.Logger.WithPrefix(utils.DefaultLogger, "client").Debugf("Returning early connection")
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func dialAddrContext(ctx context.Context, addr string, tlsConf *tls.Config, config *Config, use0RTT bool) (quicConn, error) {
|
||||
udpAddr, err := net.ResolveUDPAddr("udp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
func DialAddr(ctx context.Context, addr string, tlsConf *tls.Config, conf *Config) (Connection, error) {
|
||||
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return dialContext(ctx, udpConn, udpAddr, tlsConf, config, use0RTT, true)
|
||||
udpAddr, err := net.ResolveUDPAddr("udp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dl, err := setupTransport(udpConn, tlsConf, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return dl.Dial(ctx, udpAddr, tlsConf, conf)
|
||||
}
|
||||
|
||||
// DialAddrEarly establishes a new 0-RTT QUIC connection to a server.
|
||||
// It uses a new UDP connection and closes this connection when the QUIC connection is closed.
|
||||
func DialAddrEarly(ctx context.Context, addr string, tlsConf *tls.Config, conf *Config) (EarlyConnection, error) {
|
||||
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
udpAddr, err := net.ResolveUDPAddr("udp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dl, err := setupTransport(udpConn, tlsConf, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conn, err := dl.DialEarly(ctx, udpAddr, tlsConf, conf)
|
||||
if err != nil {
|
||||
dl.Close()
|
||||
return nil, err
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// DialEarly establishes a new 0-RTT QUIC connection to a server using a net.PacketConn using the provided context.
|
||||
// See DialEarly for details.
|
||||
func DialEarly(ctx context.Context, c net.PacketConn, addr net.Addr, tlsConf *tls.Config, conf *Config) (EarlyConnection, error) {
|
||||
dl, err := setupTransport(c, tlsConf, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conn, err := dl.DialEarly(ctx, addr, tlsConf, conf)
|
||||
if err != nil {
|
||||
dl.Close()
|
||||
return nil, err
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// Dial establishes a new QUIC connection to a server using a net.PacketConn. If
|
||||
@@ -79,34 +103,42 @@ func dialAddrContext(ctx context.Context, addr string, tlsConf *tls.Config, conf
|
||||
// does), ECN and packet info support will be enabled. In this case, ReadMsgUDP
|
||||
// and WriteMsgUDP will be used instead of ReadFrom and WriteTo to read/write
|
||||
// packets.
|
||||
// The same PacketConn can be used for multiple calls to Dial and Listen.
|
||||
// QUIC connection IDs are used for demultiplexing the different connections.
|
||||
// The tls.Config must define an application protocol (using NextProtos).
|
||||
func Dial(ctx context.Context, pconn net.PacketConn, addr net.Addr, tlsConf *tls.Config, config *Config) (Connection, error) {
|
||||
return dialContext(ctx, pconn, addr, tlsConf, config, false, false)
|
||||
}
|
||||
|
||||
// DialEarly establishes a new 0-RTT QUIC connection to a server using a net.PacketConn.
|
||||
// The same PacketConn can be used for multiple calls to Dial and Listen,
|
||||
// QUIC connection IDs are used for demultiplexing the different connections.
|
||||
// The tls.Config must define an application protocol (using NextProtos).
|
||||
func DialEarly(ctx context.Context, pconn net.PacketConn, addr net.Addr, tlsConf *tls.Config, config *Config) (EarlyConnection, error) {
|
||||
return dialContext(ctx, pconn, addr, tlsConf, config, true, false)
|
||||
}
|
||||
|
||||
func dialContext(ctx context.Context, pconn net.PacketConn, addr net.Addr, tlsConf *tls.Config, config *Config, use0RTT bool, createdPacketConn bool) (quicConn, error) {
|
||||
if tlsConf == nil {
|
||||
return nil, errors.New("quic: tls.Config not set")
|
||||
}
|
||||
if err := validateConfig(config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
config = populateClientConfig(config, createdPacketConn)
|
||||
packetHandlers, err := getMultiplexer().AddConn(pconn, config.ConnectionIDGenerator.ConnectionIDLen(), config.StatelessResetKey, config.Tracer)
|
||||
func Dial(ctx context.Context, c net.PacketConn, addr net.Addr, tlsConf *tls.Config, conf *Config) (Connection, error) {
|
||||
dl, err := setupTransport(c, tlsConf, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c, err := newClient(pconn, addr, config, tlsConf, use0RTT, createdPacketConn)
|
||||
conn, err := dl.Dial(ctx, addr, tlsConf, conf)
|
||||
if err != nil {
|
||||
dl.Close()
|
||||
return nil, err
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func setupTransport(c net.PacketConn, tlsConf *tls.Config, createdPacketConn bool) (*Transport, error) {
|
||||
if tlsConf == nil {
|
||||
return nil, errors.New("quic: tls.Config not set")
|
||||
}
|
||||
return &Transport{
|
||||
Conn: c,
|
||||
createdConn: createdPacketConn,
|
||||
isSingleUse: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func dial(
|
||||
ctx context.Context,
|
||||
conn sendConn,
|
||||
connIDGenerator ConnectionIDGenerator,
|
||||
packetHandlers packetHandlerManager,
|
||||
tlsConf *tls.Config,
|
||||
config *Config,
|
||||
onClose func(),
|
||||
use0RTT bool,
|
||||
) (quicConn, error) {
|
||||
c, err := newClient(conn, connIDGenerator, config, tlsConf, onClose, use0RTT)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -114,14 +146,10 @@ func dialContext(ctx context.Context, pconn net.PacketConn, addr net.Addr, tlsCo
|
||||
|
||||
c.tracingID = nextConnTracingID()
|
||||
if c.config.Tracer != nil {
|
||||
c.tracer = c.config.Tracer.TracerForConnection(
|
||||
context.WithValue(ctx, ConnectionTracingKey, c.tracingID),
|
||||
protocol.PerspectiveClient,
|
||||
c.destConnID,
|
||||
)
|
||||
c.tracer = c.config.Tracer(context.WithValue(ctx, ConnectionTracingKey, c.tracingID), protocol.PerspectiveClient, c.destConnID)
|
||||
}
|
||||
if c.tracer != nil {
|
||||
c.tracer.StartedConnection(c.sconn.LocalAddr(), c.sconn.RemoteAddr(), c.srcConnID, c.destConnID)
|
||||
c.tracer.StartedConnection(c.sendConn.LocalAddr(), c.sendConn.RemoteAddr(), c.srcConnID, c.destConnID)
|
||||
}
|
||||
if err := c.dial(ctx); err != nil {
|
||||
return nil, err
|
||||
@@ -129,23 +157,14 @@ func dialContext(ctx context.Context, pconn net.PacketConn, addr net.Addr, tlsCo
|
||||
return c.conn, nil
|
||||
}
|
||||
|
||||
func newClient(pconn net.PacketConn, remoteAddr net.Addr, config *Config, tlsConf *tls.Config, use0RTT bool, createdPacketConn bool) (*client, error) {
|
||||
func newClient(sendConn sendConn, connIDGenerator ConnectionIDGenerator, config *Config, tlsConf *tls.Config, onClose func(), use0RTT bool) (*client, error) {
|
||||
if tlsConf == nil {
|
||||
tlsConf = &tls.Config{}
|
||||
} else {
|
||||
tlsConf = tlsConf.Clone()
|
||||
}
|
||||
|
||||
// check that all versions are actually supported
|
||||
if config != nil {
|
||||
for _, v := range config.Versions {
|
||||
if !protocol.IsValidVersion(v) {
|
||||
return nil, fmt.Errorf("%s is not a valid QUIC version", v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
srcConnID, err := config.ConnectionIDGenerator.GenerateConnectionID()
|
||||
srcConnID, err := connIDGenerator.GenerateConnectionID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -154,28 +173,30 @@ func newClient(pconn net.PacketConn, remoteAddr net.Addr, config *Config, tlsCon
|
||||
return nil, err
|
||||
}
|
||||
c := &client{
|
||||
srcConnID: srcConnID,
|
||||
destConnID: destConnID,
|
||||
sconn: newSendPconn(pconn, remoteAddr),
|
||||
createdPacketConn: createdPacketConn,
|
||||
use0RTT: use0RTT,
|
||||
tlsConf: tlsConf,
|
||||
config: config,
|
||||
version: config.Versions[0],
|
||||
handshakeChan: make(chan struct{}),
|
||||
logger: utils.DefaultLogger.WithPrefix("client"),
|
||||
connIDGenerator: connIDGenerator,
|
||||
srcConnID: srcConnID,
|
||||
destConnID: destConnID,
|
||||
sendConn: sendConn,
|
||||
use0RTT: use0RTT,
|
||||
onClose: onClose,
|
||||
tlsConf: tlsConf,
|
||||
config: config,
|
||||
version: config.Versions[0],
|
||||
handshakeChan: make(chan struct{}),
|
||||
logger: utils.DefaultLogger.WithPrefix("client"),
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *client) dial(ctx context.Context) error {
|
||||
c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.sconn.LocalAddr(), c.sconn.RemoteAddr(), c.srcConnID, c.destConnID, c.version)
|
||||
c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.sendConn.LocalAddr(), c.sendConn.RemoteAddr(), c.srcConnID, c.destConnID, c.version)
|
||||
|
||||
c.conn = newClientConnection(
|
||||
c.sconn,
|
||||
c.sendConn,
|
||||
c.packetHandlers,
|
||||
c.destConnID,
|
||||
c.srcConnID,
|
||||
c.connIDGenerator,
|
||||
c.config,
|
||||
c.tlsConf,
|
||||
c.initialPacketNumber,
|
||||
@@ -189,13 +210,18 @@ func (c *client) dial(ctx context.Context) error {
|
||||
c.packetHandlers.Add(c.srcConnID, c.conn)
|
||||
|
||||
errorChan := make(chan error, 1)
|
||||
recreateChan := make(chan errCloseForRecreating)
|
||||
go func() {
|
||||
err := c.conn.run() // returns as soon as the connection is closed
|
||||
|
||||
if e := (&errCloseForRecreating{}); !errors.As(err, &e) && c.createdPacketConn {
|
||||
c.packetHandlers.Destroy()
|
||||
err := c.conn.run()
|
||||
var recreateErr *errCloseForRecreating
|
||||
if errors.As(err, &recreateErr) {
|
||||
recreateChan <- *recreateErr
|
||||
return
|
||||
}
|
||||
errorChan <- err
|
||||
if c.onClose != nil {
|
||||
c.onClose()
|
||||
}
|
||||
errorChan <- err // returns as soon as the connection is closed
|
||||
}()
|
||||
|
||||
// only set when we're using 0-RTT
|
||||
@@ -210,14 +236,12 @@ func (c *client) dial(ctx context.Context) error {
|
||||
c.conn.shutdown()
|
||||
return ctx.Err()
|
||||
case err := <-errorChan:
|
||||
var recreateErr *errCloseForRecreating
|
||||
if errors.As(err, &recreateErr) {
|
||||
c.initialPacketNumber = recreateErr.nextPacketNumber
|
||||
c.version = recreateErr.nextVersion
|
||||
c.hasNegotiatedVersion = true
|
||||
return c.dial(ctx)
|
||||
}
|
||||
return err
|
||||
case recreateErr := <-recreateChan:
|
||||
c.initialPacketNumber = recreateErr.nextPacketNumber
|
||||
c.version = recreateErr.nextVersion
|
||||
c.hasNegotiatedVersion = true
|
||||
return c.dial(ctx)
|
||||
case <-earlyConnChan:
|
||||
// ready to send 0-RTT data
|
||||
return nil
|
||||
|
||||
287
client_test.go
287
client_test.go
@@ -18,13 +18,16 @@ import (
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
type nullMultiplexer struct{}
|
||||
|
||||
func (n nullMultiplexer) AddConn(indexableConn) {}
|
||||
func (n nullMultiplexer) RemoveConn(indexableConn) error { return nil }
|
||||
|
||||
var _ = Describe("Client", func() {
|
||||
var (
|
||||
cl *client
|
||||
packetConn *MockPacketConn
|
||||
addr net.Addr
|
||||
packetConn *MockSendConn
|
||||
connID protocol.ConnectionID
|
||||
mockMultiplexer *MockMultiplexer
|
||||
origMultiplexer multiplexer
|
||||
tlsConf *tls.Config
|
||||
tracer *mocklogging.MockConnectionTracer
|
||||
@@ -35,6 +38,7 @@ var _ = Describe("Client", func() {
|
||||
runner connRunner,
|
||||
destConnID protocol.ConnectionID,
|
||||
srcConnID protocol.ConnectionID,
|
||||
connIDGenerator ConnectionIDGenerator,
|
||||
conf *Config,
|
||||
tlsConf *tls.Config,
|
||||
initialPacketNumber protocol.PacketNumber,
|
||||
@@ -52,26 +56,28 @@ var _ = Describe("Client", func() {
|
||||
connID = protocol.ParseConnectionID([]byte{0, 0, 0, 0, 0, 0, 0x13, 0x37})
|
||||
originalClientConnConstructor = newClientConnection
|
||||
tracer = mocklogging.NewMockConnectionTracer(mockCtrl)
|
||||
tr := mocklogging.NewMockTracer(mockCtrl)
|
||||
tr.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveClient, gomock.Any()).Return(tracer).MaxTimes(1)
|
||||
config = &Config{Tracer: tr, Versions: []protocol.VersionNumber{protocol.Version1}}
|
||||
config = &Config{
|
||||
Tracer: func(ctx context.Context, perspective logging.Perspective, id ConnectionID) logging.ConnectionTracer {
|
||||
return tracer
|
||||
},
|
||||
Versions: []protocol.VersionNumber{protocol.Version1},
|
||||
}
|
||||
Eventually(areConnsRunning).Should(BeFalse())
|
||||
addr = &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337}
|
||||
packetConn = NewMockPacketConn(mockCtrl)
|
||||
packetConn = NewMockSendConn(mockCtrl)
|
||||
packetConn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes()
|
||||
packetConn.EXPECT().RemoteAddr().Return(&net.UDPAddr{}).AnyTimes()
|
||||
cl = &client{
|
||||
srcConnID: connID,
|
||||
destConnID: connID,
|
||||
version: protocol.Version1,
|
||||
sconn: newSendPconn(packetConn, addr),
|
||||
sendConn: packetConn,
|
||||
tracer: tracer,
|
||||
logger: utils.DefaultLogger,
|
||||
}
|
||||
getMultiplexer() // make the sync.Once execute
|
||||
// replace the clientMuxer. getClientMultiplexer will now return the MockMultiplexer
|
||||
mockMultiplexer = NewMockMultiplexer(mockCtrl)
|
||||
// replace the clientMuxer. getMultiplexer will now return the nullMultiplexer
|
||||
origMultiplexer = connMuxer
|
||||
connMuxer = mockMultiplexer
|
||||
connMuxer = &nullMultiplexer{}
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
@@ -100,50 +106,17 @@ var _ = Describe("Client", func() {
|
||||
generateConnectionIDForInitial = origGenerateConnectionIDForInitial
|
||||
})
|
||||
|
||||
It("resolves the address", func() {
|
||||
manager := NewMockPacketHandlerManager(mockCtrl)
|
||||
manager.EXPECT().Add(gomock.Any(), gomock.Any())
|
||||
manager.EXPECT().Destroy()
|
||||
mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)
|
||||
|
||||
remoteAddrChan := make(chan string, 1)
|
||||
newClientConnection = func(
|
||||
sconn sendConn,
|
||||
_ connRunner,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
_ protocol.PacketNumber,
|
||||
_ bool,
|
||||
_ bool,
|
||||
_ logging.ConnectionTracer,
|
||||
_ uint64,
|
||||
_ utils.Logger,
|
||||
_ protocol.VersionNumber,
|
||||
) quicConn {
|
||||
remoteAddrChan <- sconn.RemoteAddr().String()
|
||||
conn := NewMockQUICConn(mockCtrl)
|
||||
conn.EXPECT().run()
|
||||
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
|
||||
return conn
|
||||
}
|
||||
_, err := DialAddr(context.Background(), "localhost:17890", tlsConf, &Config{HandshakeIdleTimeout: time.Millisecond})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Eventually(remoteAddrChan).Should(Receive(Equal("127.0.0.1:17890")))
|
||||
})
|
||||
|
||||
It("returns after the handshake is complete", func() {
|
||||
manager := NewMockPacketHandlerManager(mockCtrl)
|
||||
manager.EXPECT().Add(gomock.Any(), gomock.Any())
|
||||
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)
|
||||
|
||||
run := make(chan struct{})
|
||||
newClientConnection = func(
|
||||
_ sendConn,
|
||||
runner connRunner,
|
||||
_ connRunner,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ ConnectionIDGenerator,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
_ protocol.PacketNumber,
|
||||
@@ -162,18 +135,17 @@ var _ = Describe("Client", func() {
|
||||
conn.EXPECT().HandshakeComplete().Return(c)
|
||||
return conn
|
||||
}
|
||||
tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
||||
s, err := Dial(context.Background(), packetConn, addr, tlsConf, config)
|
||||
cl, err := newClient(packetConn, &protocol.DefaultConnectionIDGenerator{}, populateConfig(config), tlsConf, nil, false)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(s).ToNot(BeNil())
|
||||
cl.packetHandlers = manager
|
||||
Expect(cl).ToNot(BeNil())
|
||||
Expect(cl.dial(context.Background())).To(Succeed())
|
||||
Eventually(run).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("returns early connections", func() {
|
||||
manager := NewMockPacketHandlerManager(mockCtrl)
|
||||
manager.EXPECT().Add(gomock.Any(), gomock.Any())
|
||||
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)
|
||||
|
||||
readyChan := make(chan struct{})
|
||||
done := make(chan struct{})
|
||||
newClientConnection = func(
|
||||
@@ -181,6 +153,7 @@ var _ = Describe("Client", func() {
|
||||
runner connRunner,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ ConnectionIDGenerator,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
_ protocol.PacketNumber,
|
||||
@@ -193,29 +166,23 @@ var _ = Describe("Client", func() {
|
||||
) quicConn {
|
||||
Expect(enable0RTT).To(BeTrue())
|
||||
conn := NewMockQUICConn(mockCtrl)
|
||||
conn.EXPECT().run().Do(func() { <-done })
|
||||
conn.EXPECT().run().Do(func() { close(done) })
|
||||
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
|
||||
conn.EXPECT().earlyConnReady().Return(readyChan)
|
||||
return conn
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
defer close(done)
|
||||
tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
||||
s, err := DialEarly(context.Background(), packetConn, addr, tlsConf, config)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(s).ToNot(BeNil())
|
||||
}()
|
||||
Consistently(done).ShouldNot(BeClosed())
|
||||
close(readyChan)
|
||||
cl, err := newClient(packetConn, &protocol.DefaultConnectionIDGenerator{}, populateConfig(config), tlsConf, nil, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
cl.packetHandlers = manager
|
||||
Expect(cl).ToNot(BeNil())
|
||||
Expect(cl.dial(context.Background())).To(Succeed())
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("returns an error that occurs while waiting for the handshake to complete", func() {
|
||||
manager := NewMockPacketHandlerManager(mockCtrl)
|
||||
manager.EXPECT().Add(gomock.Any(), gomock.Any())
|
||||
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)
|
||||
|
||||
testErr := errors.New("early handshake error")
|
||||
newClientConnection = func(
|
||||
@@ -223,6 +190,7 @@ var _ = Describe("Client", func() {
|
||||
_ connRunner,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ ConnectionIDGenerator,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
_ protocol.PacketNumber,
|
||||
@@ -236,150 +204,44 @@ var _ = Describe("Client", func() {
|
||||
conn := NewMockQUICConn(mockCtrl)
|
||||
conn.EXPECT().run().Return(testErr)
|
||||
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
|
||||
conn.EXPECT().earlyConnReady().Return(make(chan struct{}))
|
||||
return conn
|
||||
}
|
||||
tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
||||
_, err := Dial(context.Background(), packetConn, addr, tlsConf, config)
|
||||
Expect(err).To(MatchError(testErr))
|
||||
})
|
||||
|
||||
It("closes the connection when the context is canceled", func() {
|
||||
manager := NewMockPacketHandlerManager(mockCtrl)
|
||||
manager.EXPECT().Add(gomock.Any(), gomock.Any())
|
||||
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)
|
||||
|
||||
connRunning := make(chan struct{})
|
||||
defer close(connRunning)
|
||||
conn := NewMockQUICConn(mockCtrl)
|
||||
conn.EXPECT().run().Do(func() {
|
||||
<-connRunning
|
||||
})
|
||||
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
|
||||
newClientConnection = func(
|
||||
_ sendConn,
|
||||
_ connRunner,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
_ protocol.PacketNumber,
|
||||
_ bool,
|
||||
_ bool,
|
||||
_ logging.ConnectionTracer,
|
||||
_ uint64,
|
||||
_ utils.Logger,
|
||||
_ protocol.VersionNumber,
|
||||
) quicConn {
|
||||
return conn
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
dialed := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
||||
_, err := Dial(ctx, packetConn, addr, tlsConf, config)
|
||||
Expect(err).To(MatchError(context.Canceled))
|
||||
close(dialed)
|
||||
}()
|
||||
Consistently(dialed).ShouldNot(BeClosed())
|
||||
conn.EXPECT().shutdown()
|
||||
cancel()
|
||||
Eventually(dialed).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("closes the connection when it was created by DialAddr", func() {
|
||||
manager := NewMockPacketHandlerManager(mockCtrl)
|
||||
mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)
|
||||
manager.EXPECT().Add(gomock.Any(), gomock.Any())
|
||||
|
||||
var sconn sendConn
|
||||
run := make(chan struct{})
|
||||
connCreated := make(chan struct{})
|
||||
conn := NewMockQUICConn(mockCtrl)
|
||||
newClientConnection = func(
|
||||
connP sendConn,
|
||||
_ connRunner,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
_ protocol.PacketNumber,
|
||||
_ bool,
|
||||
_ bool,
|
||||
_ logging.ConnectionTracer,
|
||||
_ uint64,
|
||||
_ utils.Logger,
|
||||
_ protocol.VersionNumber,
|
||||
) quicConn {
|
||||
sconn = connP
|
||||
close(connCreated)
|
||||
return conn
|
||||
}
|
||||
conn.EXPECT().run().Do(func() {
|
||||
<-run
|
||||
})
|
||||
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_, err := DialAddr(context.Background(), "localhost:1337", tlsConf, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
close(done)
|
||||
}()
|
||||
|
||||
Eventually(connCreated).Should(BeClosed())
|
||||
|
||||
// check that the connection is not closed
|
||||
Expect(sconn.Write([]byte("foobar"))).To(Succeed())
|
||||
|
||||
manager.EXPECT().Destroy()
|
||||
close(run)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
Eventually(done).Should(BeClosed())
|
||||
var closed bool
|
||||
cl, err := newClient(packetConn, &protocol.DefaultConnectionIDGenerator{}, populateConfig(config), tlsConf, func() { closed = true }, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
cl.packetHandlers = manager
|
||||
Expect(cl).ToNot(BeNil())
|
||||
Expect(cl.dial(context.Background())).To(MatchError(testErr))
|
||||
Expect(closed).To(BeTrue())
|
||||
})
|
||||
|
||||
Context("quic.Config", func() {
|
||||
It("setups with the right values", func() {
|
||||
srk := &StatelessResetKey{'f', 'o', 'o', 'b', 'a', 'r'}
|
||||
tokenStore := NewLRUTokenStore(10, 4)
|
||||
config := &Config{
|
||||
HandshakeIdleTimeout: 1337 * time.Minute,
|
||||
MaxIdleTimeout: 42 * time.Hour,
|
||||
MaxIncomingStreams: 1234,
|
||||
MaxIncomingUniStreams: 4321,
|
||||
ConnectionIDLength: 13,
|
||||
StatelessResetKey: srk,
|
||||
TokenStore: tokenStore,
|
||||
EnableDatagrams: true,
|
||||
}
|
||||
c := populateClientConfig(config, false)
|
||||
c := populateConfig(config)
|
||||
Expect(c.HandshakeIdleTimeout).To(Equal(1337 * time.Minute))
|
||||
Expect(c.MaxIdleTimeout).To(Equal(42 * time.Hour))
|
||||
Expect(c.MaxIncomingStreams).To(BeEquivalentTo(1234))
|
||||
Expect(c.MaxIncomingUniStreams).To(BeEquivalentTo(4321))
|
||||
Expect(c.ConnectionIDLength).To(Equal(13))
|
||||
Expect(c.StatelessResetKey).To(Equal(srk))
|
||||
Expect(c.TokenStore).To(Equal(tokenStore))
|
||||
Expect(c.EnableDatagrams).To(BeTrue())
|
||||
})
|
||||
|
||||
It("errors when the Config contains an invalid version", func() {
|
||||
manager := NewMockPacketHandlerManager(mockCtrl)
|
||||
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)
|
||||
|
||||
version := protocol.VersionNumber(0x1234)
|
||||
_, err := Dial(context.Background(), packetConn, nil, tlsConf, &Config{Versions: []protocol.VersionNumber{version}})
|
||||
Expect(err).To(MatchError("0x1234 is not a valid QUIC version"))
|
||||
})
|
||||
|
||||
It("disables bidirectional streams", func() {
|
||||
config := &Config{
|
||||
MaxIncomingStreams: -1,
|
||||
MaxIncomingUniStreams: 4321,
|
||||
}
|
||||
c := populateClientConfig(config, false)
|
||||
c := populateConfig(config)
|
||||
Expect(c.MaxIncomingStreams).To(BeZero())
|
||||
Expect(c.MaxIncomingUniStreams).To(BeEquivalentTo(4321))
|
||||
})
|
||||
@@ -389,18 +251,13 @@ var _ = Describe("Client", func() {
|
||||
MaxIncomingStreams: 1234,
|
||||
MaxIncomingUniStreams: -1,
|
||||
}
|
||||
c := populateClientConfig(config, false)
|
||||
c := populateConfig(config)
|
||||
Expect(c.MaxIncomingStreams).To(BeEquivalentTo(1234))
|
||||
Expect(c.MaxIncomingUniStreams).To(BeZero())
|
||||
})
|
||||
|
||||
It("uses 0-byte connection IDs when dialing an address", func() {
|
||||
c := populateClientConfig(&Config{}, true)
|
||||
Expect(c.ConnectionIDLength).To(BeZero())
|
||||
})
|
||||
|
||||
It("fills in default values if options are not set in the Config", func() {
|
||||
c := populateClientConfig(&Config{}, false)
|
||||
c := populateConfig(&Config{})
|
||||
Expect(c.Versions).To(Equal(protocol.SupportedVersions))
|
||||
Expect(c.HandshakeIdleTimeout).To(Equal(protocol.DefaultHandshakeIdleTimeout))
|
||||
Expect(c.MaxIdleTimeout).To(Equal(protocol.DefaultIdleTimeout))
|
||||
@@ -408,20 +265,17 @@ var _ = Describe("Client", func() {
|
||||
})
|
||||
|
||||
It("creates new connections with the right parameters", func() {
|
||||
manager := NewMockPacketHandlerManager(mockCtrl)
|
||||
manager.EXPECT().Add(connID, gomock.Any())
|
||||
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)
|
||||
|
||||
config := &Config{Versions: []protocol.VersionNumber{protocol.Version1}, ConnectionIDGenerator: &mockConnIDGenerator{ConnID: connID}}
|
||||
config := &Config{Versions: []protocol.VersionNumber{protocol.Version1}}
|
||||
c := make(chan struct{})
|
||||
var cconn sendConn
|
||||
var version protocol.VersionNumber
|
||||
var conf *Config
|
||||
done := make(chan struct{})
|
||||
newClientConnection = func(
|
||||
connP sendConn,
|
||||
_ connRunner,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ ConnectionIDGenerator,
|
||||
configP *Config,
|
||||
_ *tls.Config,
|
||||
_ protocol.PacketNumber,
|
||||
@@ -432,7 +286,6 @@ var _ = Describe("Client", func() {
|
||||
_ utils.Logger,
|
||||
versionP protocol.VersionNumber,
|
||||
) quicConn {
|
||||
cconn = connP
|
||||
version = versionP
|
||||
conf = configP
|
||||
close(c)
|
||||
@@ -440,28 +293,32 @@ var _ = Describe("Client", func() {
|
||||
conn := NewMockQUICConn(mockCtrl)
|
||||
conn.EXPECT().run()
|
||||
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
|
||||
conn.EXPECT().destroy(gomock.Any())
|
||||
close(done)
|
||||
return conn
|
||||
}
|
||||
_, err := Dial(context.Background(), packetConn, addr, tlsConf, config)
|
||||
packetConn := NewMockPacketConn(mockCtrl)
|
||||
packetConn.EXPECT().ReadFrom(gomock.Any()).DoAndReturn(func([]byte) (int, net.Addr, error) {
|
||||
<-done
|
||||
return 0, nil, errors.New("closed")
|
||||
})
|
||||
packetConn.EXPECT().LocalAddr()
|
||||
packetConn.EXPECT().SetReadDeadline(gomock.Any()).AnyTimes()
|
||||
_, err := Dial(context.Background(), packetConn, &net.UDPAddr{}, tlsConf, config)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Eventually(c).Should(BeClosed())
|
||||
Expect(cconn.(*spconn).PacketConn).To(Equal(packetConn))
|
||||
Expect(version).To(Equal(config.Versions[0]))
|
||||
Expect(conf.Versions).To(Equal(config.Versions))
|
||||
})
|
||||
|
||||
It("creates a new connections after version negotiation", func() {
|
||||
manager := NewMockPacketHandlerManager(mockCtrl)
|
||||
manager.EXPECT().Add(connID, gomock.Any()).Times(2)
|
||||
manager.EXPECT().Destroy()
|
||||
mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)
|
||||
|
||||
var counter int
|
||||
newClientConnection = func(
|
||||
_ sendConn,
|
||||
_ connRunner,
|
||||
_ protocol.ConnectionID,
|
||||
runner connRunner,
|
||||
_ protocol.ConnectionID,
|
||||
connID protocol.ConnectionID,
|
||||
_ ConnectionIDGenerator,
|
||||
configP *Config,
|
||||
_ *tls.Config,
|
||||
pn protocol.PacketNumber,
|
||||
@@ -477,20 +334,24 @@ var _ = Describe("Client", func() {
|
||||
if counter == 0 {
|
||||
Expect(pn).To(BeZero())
|
||||
Expect(hasNegotiatedVersion).To(BeFalse())
|
||||
conn.EXPECT().run().Return(&errCloseForRecreating{
|
||||
nextPacketNumber: 109,
|
||||
nextVersion: 789,
|
||||
conn.EXPECT().run().DoAndReturn(func() error {
|
||||
runner.Remove(connID)
|
||||
return &errCloseForRecreating{
|
||||
nextPacketNumber: 109,
|
||||
nextVersion: 789,
|
||||
}
|
||||
})
|
||||
} else {
|
||||
Expect(pn).To(Equal(protocol.PacketNumber(109)))
|
||||
Expect(hasNegotiatedVersion).To(BeTrue())
|
||||
conn.EXPECT().run()
|
||||
conn.EXPECT().destroy(gomock.Any())
|
||||
}
|
||||
counter++
|
||||
return conn
|
||||
}
|
||||
|
||||
config := &Config{Tracer: config.Tracer, Versions: []protocol.VersionNumber{protocol.Version1}, ConnectionIDGenerator: &mockConnIDGenerator{ConnID: connID}}
|
||||
config := &Config{Tracer: config.Tracer, Versions: []protocol.VersionNumber{protocol.Version1}}
|
||||
tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
||||
_, err := DialAddr(context.Background(), "localhost:7890", tlsConf, config)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
@@ -498,15 +359,3 @@ var _ = Describe("Client", func() {
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
type mockConnIDGenerator struct {
|
||||
ConnID protocol.ConnectionID
|
||||
}
|
||||
|
||||
func (m *mockConnIDGenerator) GenerateConnectionID() (protocol.ConnectionID, error) {
|
||||
return m.ConnID, nil
|
||||
}
|
||||
|
||||
func (m *mockConnIDGenerator) ConnectionIDLen() int {
|
||||
return m.ConnID.Len()
|
||||
}
|
||||
|
||||
35
config.go
35
config.go
@@ -2,6 +2,7 @@ package quic
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
@@ -29,13 +30,19 @@ func validateConfig(config *Config) error {
|
||||
if config.MaxIncomingUniStreams > 1<<60 {
|
||||
return errors.New("invalid value for Config.MaxIncomingUniStreams")
|
||||
}
|
||||
// check that all QUIC versions are actually supported
|
||||
for _, v := range config.Versions {
|
||||
if !protocol.IsValidVersion(v) {
|
||||
return fmt.Errorf("invalid QUIC version: %s", v)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// populateServerConfig populates fields in the quic.Config with their default values, if none are set
|
||||
// it may be called with nil
|
||||
func populateServerConfig(config *Config) *Config {
|
||||
config = populateConfig(config, protocol.DefaultConnectionIDLength)
|
||||
config = populateConfig(config)
|
||||
if config.MaxTokenAge == 0 {
|
||||
config.MaxTokenAge = protocol.TokenValidity
|
||||
}
|
||||
@@ -48,19 +55,9 @@ func populateServerConfig(config *Config) *Config {
|
||||
return config
|
||||
}
|
||||
|
||||
// populateClientConfig populates fields in the quic.Config with their default values, if none are set
|
||||
// populateConfig populates fields in the quic.Config with their default values, if none are set
|
||||
// it may be called with nil
|
||||
func populateClientConfig(config *Config, createdPacketConn bool) *Config {
|
||||
defaultConnIDLen := protocol.DefaultConnectionIDLength
|
||||
if createdPacketConn {
|
||||
defaultConnIDLen = 0
|
||||
}
|
||||
|
||||
config = populateConfig(config, defaultConnIDLen)
|
||||
return config
|
||||
}
|
||||
|
||||
func populateConfig(config *Config, defaultConnIDLen int) *Config {
|
||||
func populateConfig(config *Config) *Config {
|
||||
if config == nil {
|
||||
config = &Config{}
|
||||
}
|
||||
@@ -68,10 +65,6 @@ func populateConfig(config *Config, defaultConnIDLen int) *Config {
|
||||
if len(versions) == 0 {
|
||||
versions = protocol.SupportedVersions
|
||||
}
|
||||
conIDLen := config.ConnectionIDLength
|
||||
if config.ConnectionIDLength == 0 {
|
||||
conIDLen = defaultConnIDLen
|
||||
}
|
||||
handshakeIdleTimeout := protocol.DefaultHandshakeIdleTimeout
|
||||
if config.HandshakeIdleTimeout != 0 {
|
||||
handshakeIdleTimeout = config.HandshakeIdleTimeout
|
||||
@@ -108,12 +101,9 @@ func populateConfig(config *Config, defaultConnIDLen int) *Config {
|
||||
} else if maxIncomingUniStreams < 0 {
|
||||
maxIncomingUniStreams = 0
|
||||
}
|
||||
connIDGenerator := config.ConnectionIDGenerator
|
||||
if connIDGenerator == nil {
|
||||
connIDGenerator = &protocol.DefaultConnectionIDGenerator{ConnLen: conIDLen}
|
||||
}
|
||||
|
||||
return &Config{
|
||||
GetConfigForClient: config.GetConfigForClient,
|
||||
Versions: versions,
|
||||
HandshakeIdleTimeout: handshakeIdleTimeout,
|
||||
MaxIdleTimeout: idleTimeout,
|
||||
@@ -128,9 +118,6 @@ func populateConfig(config *Config, defaultConnIDLen int) *Config {
|
||||
AllowConnectionWindowIncrease: config.AllowConnectionWindowIncrease,
|
||||
MaxIncomingStreams: maxIncomingStreams,
|
||||
MaxIncomingUniStreams: maxIncomingUniStreams,
|
||||
ConnectionIDLength: conIDLen,
|
||||
ConnectionIDGenerator: connIDGenerator,
|
||||
StatelessResetKey: config.StatelessResetKey,
|
||||
TokenStore: config.TokenStore,
|
||||
EnableDatagrams: config.EnableDatagrams,
|
||||
DisablePathMTUDiscovery: config.DisablePathMTUDiscovery,
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
mocklogging "github.com/quic-go/quic-go/internal/mocks/logging"
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/logging"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
@@ -45,7 +47,7 @@ var _ = Describe("Config", func() {
|
||||
}
|
||||
|
||||
switch fn := typ.Field(i).Name; fn {
|
||||
case "RequireAddressValidation", "GetLogWriter", "AllowConnectionWindowIncrease", "Allow0RTT":
|
||||
case "GetConfigForClient", "RequireAddressValidation", "GetLogWriter", "AllowConnectionWindowIncrease", "Tracer":
|
||||
// Can't compare functions.
|
||||
case "Versions":
|
||||
f.Set(reflect.ValueOf([]VersionNumber{1, 2, 3}))
|
||||
@@ -85,8 +87,8 @@ var _ = Describe("Config", func() {
|
||||
f.Set(reflect.ValueOf(true))
|
||||
case "DisablePathMTUDiscovery":
|
||||
f.Set(reflect.ValueOf(true))
|
||||
case "Tracer":
|
||||
f.Set(reflect.ValueOf(mocklogging.NewMockTracer(mockCtrl)))
|
||||
case "Allow0RTT":
|
||||
f.Set(reflect.ValueOf(true))
|
||||
default:
|
||||
Fail(fmt.Sprintf("all fields must be accounted for, but saw unknown field %q", fn))
|
||||
}
|
||||
@@ -106,16 +108,25 @@ var _ = Describe("Config", func() {
|
||||
|
||||
Context("cloning", func() {
|
||||
It("clones function fields", func() {
|
||||
var calledAddrValidation, calledAllowConnectionWindowIncrease bool
|
||||
var calledAddrValidation, calledAllowConnectionWindowIncrease, calledTracer bool
|
||||
c1 := &Config{
|
||||
GetConfigForClient: func(info *ClientHelloInfo) (*Config, error) { return nil, errors.New("nope") },
|
||||
AllowConnectionWindowIncrease: func(Connection, uint64) bool { calledAllowConnectionWindowIncrease = true; return true },
|
||||
RequireAddressValidation: func(net.Addr) bool { calledAddrValidation = true; return true },
|
||||
Tracer: func(context.Context, logging.Perspective, ConnectionID) logging.ConnectionTracer {
|
||||
calledTracer = true
|
||||
return nil
|
||||
},
|
||||
}
|
||||
c2 := c1.Clone()
|
||||
c2.RequireAddressValidation(&net.UDPAddr{})
|
||||
Expect(calledAddrValidation).To(BeTrue())
|
||||
c2.AllowConnectionWindowIncrease(nil, 1234)
|
||||
Expect(calledAllowConnectionWindowIncrease).To(BeTrue())
|
||||
_, err := c2.GetConfigForClient(&ClientHelloInfo{})
|
||||
Expect(err).To(MatchError("nope"))
|
||||
c2.Tracer(context.Background(), logging.PerspectiveClient, protocol.ConnectionID{})
|
||||
Expect(calledTracer).To(BeTrue())
|
||||
})
|
||||
|
||||
It("clones non-function fields", func() {
|
||||
@@ -142,18 +153,18 @@ var _ = Describe("Config", func() {
|
||||
var calledAddrValidation bool
|
||||
c1 := &Config{}
|
||||
c1.RequireAddressValidation = func(net.Addr) bool { calledAddrValidation = true; return true }
|
||||
c2 := populateConfig(c1, protocol.DefaultConnectionIDLength)
|
||||
c2 := populateConfig(c1)
|
||||
c2.RequireAddressValidation(&net.UDPAddr{})
|
||||
Expect(calledAddrValidation).To(BeTrue())
|
||||
})
|
||||
|
||||
It("copies non-function fields", func() {
|
||||
c := configWithNonZeroNonFunctionFields()
|
||||
Expect(populateConfig(c, protocol.DefaultConnectionIDLength)).To(Equal(c))
|
||||
Expect(populateConfig(c)).To(Equal(c))
|
||||
})
|
||||
|
||||
It("populates empty fields with default values", func() {
|
||||
c := populateConfig(&Config{}, protocol.DefaultConnectionIDLength)
|
||||
c := populateConfig(&Config{})
|
||||
Expect(c.Versions).To(Equal(protocol.SupportedVersions))
|
||||
Expect(c.HandshakeIdleTimeout).To(Equal(protocol.DefaultHandshakeIdleTimeout))
|
||||
Expect(c.InitialStreamReceiveWindow).To(BeEquivalentTo(protocol.DefaultInitialMaxStreamData))
|
||||
@@ -164,22 +175,12 @@ var _ = Describe("Config", func() {
|
||||
Expect(c.MaxIncomingUniStreams).To(BeEquivalentTo(protocol.DefaultMaxIncomingUniStreams))
|
||||
Expect(c.DisableVersionNegotiationPackets).To(BeFalse())
|
||||
Expect(c.DisablePathMTUDiscovery).To(BeFalse())
|
||||
Expect(c.GetConfigForClient).To(BeNil())
|
||||
})
|
||||
|
||||
It("populates empty fields with default values, for the server", func() {
|
||||
c := populateServerConfig(&Config{})
|
||||
Expect(c.ConnectionIDLength).To(Equal(protocol.DefaultConnectionIDLength))
|
||||
Expect(c.RequireAddressValidation).ToNot(BeNil())
|
||||
})
|
||||
|
||||
It("sets a default connection ID length if we didn't create the conn, for the client", func() {
|
||||
c := populateClientConfig(&Config{}, false)
|
||||
Expect(c.ConnectionIDLength).To(Equal(protocol.DefaultConnectionIDLength))
|
||||
})
|
||||
|
||||
It("doesn't set a default connection ID length if we created the conn, for the client", func() {
|
||||
c := populateClientConfig(&Config{}, true)
|
||||
Expect(c.ConnectionIDLength).To(BeZero())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -240,6 +240,7 @@ var newConnection = func(
|
||||
clientDestConnID protocol.ConnectionID,
|
||||
destConnID protocol.ConnectionID,
|
||||
srcConnID protocol.ConnectionID,
|
||||
connIDGenerator ConnectionIDGenerator,
|
||||
statelessResetToken protocol.StatelessResetToken,
|
||||
conf *Config,
|
||||
tlsConf *tls.Config,
|
||||
@@ -283,7 +284,7 @@ var newConnection = func(
|
||||
runner.Retire,
|
||||
runner.ReplaceWithClosed,
|
||||
s.queueControlFrame,
|
||||
s.config.ConnectionIDGenerator,
|
||||
connIDGenerator,
|
||||
)
|
||||
s.preSetup()
|
||||
s.ctx, s.ctxCancel = context.WithCancel(context.WithValue(context.Background(), ConnectionTracingKey, tracingID))
|
||||
@@ -323,10 +324,6 @@ var newConnection = func(
|
||||
if s.tracer != nil {
|
||||
s.tracer.SentTransportParameters(params)
|
||||
}
|
||||
var allow0RTT func() bool
|
||||
if conf.Allow0RTT != nil {
|
||||
allow0RTT = func() bool { return conf.Allow0RTT(conn.RemoteAddr()) }
|
||||
}
|
||||
cs := handshake.NewCryptoSetupServer(
|
||||
initialStream,
|
||||
handshakeStream,
|
||||
@@ -344,7 +341,7 @@ var newConnection = func(
|
||||
},
|
||||
},
|
||||
tlsConf,
|
||||
allow0RTT,
|
||||
conf.Allow0RTT,
|
||||
s.rttStats,
|
||||
tracer,
|
||||
logger,
|
||||
@@ -363,6 +360,7 @@ var newClientConnection = func(
|
||||
runner connRunner,
|
||||
destConnID protocol.ConnectionID,
|
||||
srcConnID protocol.ConnectionID,
|
||||
connIDGenerator ConnectionIDGenerator,
|
||||
conf *Config,
|
||||
tlsConf *tls.Config,
|
||||
initialPacketNumber protocol.PacketNumber,
|
||||
@@ -402,7 +400,7 @@ var newClientConnection = func(
|
||||
runner.Retire,
|
||||
runner.ReplaceWithClosed,
|
||||
s.queueControlFrame,
|
||||
s.config.ConnectionIDGenerator,
|
||||
connIDGenerator,
|
||||
)
|
||||
s.preSetup()
|
||||
s.ctx, s.ctxCancel = context.WithCancel(context.WithValue(context.Background(), ConnectionTracingKey, tracingID))
|
||||
|
||||
@@ -113,6 +113,7 @@ var _ = Describe("Connection", func() {
|
||||
clientDestConnID,
|
||||
destConnID,
|
||||
srcConnID,
|
||||
&protocol.DefaultConnectionIDGenerator{},
|
||||
protocol.StatelessResetToken{},
|
||||
populateServerConfig(&Config{DisablePathMTUDiscovery: true}),
|
||||
nil, // tls.Config
|
||||
@@ -2015,8 +2016,6 @@ var _ = Describe("Connection", func() {
|
||||
packer.EXPECT().HandleTransportParameters(params)
|
||||
packer.EXPECT().PackCoalescedPacket(false, conn.version).MaxTimes(3)
|
||||
Expect(conn.earlyConnReady()).ToNot(BeClosed())
|
||||
connRunner.EXPECT().GetStatelessResetToken(gomock.Any()).Times(2)
|
||||
connRunner.EXPECT().Add(gomock.Any(), conn).Times(2)
|
||||
tracer.EXPECT().ReceivedTransportParameters(params)
|
||||
conn.handleTransportParameters(params)
|
||||
Expect(conn.earlyConnReady()).To(BeClosed())
|
||||
@@ -2378,7 +2377,7 @@ var _ = Describe("Client Connection", func() {
|
||||
}
|
||||
|
||||
BeforeEach(func() {
|
||||
quicConf = populateClientConfig(&Config{}, true)
|
||||
quicConf = populateConfig(&Config{})
|
||||
tlsConf = nil
|
||||
})
|
||||
|
||||
@@ -2402,6 +2401,7 @@ var _ = Describe("Client Connection", func() {
|
||||
connRunner,
|
||||
destConnID,
|
||||
protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}),
|
||||
&protocol.DefaultConnectionIDGenerator{},
|
||||
quicConf,
|
||||
tlsConf,
|
||||
42, // initial packet number
|
||||
|
||||
@@ -3,6 +3,7 @@ package main
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"flag"
|
||||
@@ -57,15 +58,15 @@ func main() {
|
||||
|
||||
var qconf quic.Config
|
||||
if *enableQlog {
|
||||
qconf.Tracer = qlog.NewTracer(func(_ logging.Perspective, connID []byte) io.WriteCloser {
|
||||
qconf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer {
|
||||
filename := fmt.Sprintf("client_%x.qlog", connID)
|
||||
f, err := os.Create(filename)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
log.Printf("Creating qlog file %s.\n", filename)
|
||||
return utils.NewBufferedWriteCloser(bufio.NewWriter(f), f)
|
||||
})
|
||||
return qlog.NewConnectionTracer(utils.NewBufferedWriteCloser(bufio.NewWriter(f), f), p, connID)
|
||||
}
|
||||
}
|
||||
roundTripper := &http3.RoundTripper{
|
||||
TLSClientConfig: &tls.Config{
|
||||
|
||||
@@ -2,6 +2,7 @@ package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"crypto/md5"
|
||||
"errors"
|
||||
"flag"
|
||||
@@ -162,15 +163,15 @@ func main() {
|
||||
handler := setupHandler(*www)
|
||||
quicConf := &quic.Config{}
|
||||
if *enableQlog {
|
||||
quicConf.Tracer = qlog.NewTracer(func(_ logging.Perspective, connID []byte) io.WriteCloser {
|
||||
quicConf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer {
|
||||
filename := fmt.Sprintf("server_%x.qlog", connID)
|
||||
f, err := os.Create(filename)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
log.Printf("Creating qlog file %s.\n", filename)
|
||||
return utils.NewBufferedWriteCloser(bufio.NewWriter(f), f)
|
||||
})
|
||||
return qlog.NewConnectionTracer(utils.NewBufferedWriteCloser(bufio.NewWriter(f), f), p, connID)
|
||||
}
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
@@ -105,7 +105,7 @@ func main() {
|
||||
&wire.TransportParameters{ActiveConnectionIDLimit: 2},
|
||||
runner,
|
||||
config,
|
||||
nil,
|
||||
false,
|
||||
utils.NewRTTStats(),
|
||||
nil,
|
||||
utils.DefaultLogger.WithPrefix("server"),
|
||||
|
||||
@@ -390,10 +390,6 @@ func runHandshake(runConfig [confLen]byte, messageConfig uint8, clientConf *tls.
|
||||
protocol.Version1,
|
||||
)
|
||||
|
||||
var allow0RTT func() bool
|
||||
if enable0RTTServer {
|
||||
allow0RTT = func() bool { return true }
|
||||
}
|
||||
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
|
||||
server = handshake.NewCryptoSetupServer(
|
||||
sInitialStream,
|
||||
@@ -404,7 +400,7 @@ func runHandshake(runConfig [confLen]byte, messageConfig uint8, clientConf *tls.
|
||||
serverTP,
|
||||
runner,
|
||||
serverConf,
|
||||
allow0RTT,
|
||||
enable0RTTServer,
|
||||
utils.NewRTTStats(),
|
||||
nil,
|
||||
utils.DefaultLogger.WithPrefix("server"),
|
||||
|
||||
@@ -288,7 +288,7 @@ func (s *Server) serveConn(tlsConf *tls.Config, conn net.PacketConn) error {
|
||||
baseConf := ConfigureTLSConfig(tlsConf)
|
||||
quicConf := s.QuicConfig
|
||||
if quicConf == nil {
|
||||
quicConf = &quic.Config{Allow0RTT: func(net.Addr) bool { return true }}
|
||||
quicConf = &quic.Config{Allow0RTT: true}
|
||||
} else {
|
||||
quicConf = s.QuicConfig.Clone()
|
||||
}
|
||||
|
||||
@@ -660,6 +660,7 @@ var _ = Describe("Stream Cancellations", func() {
|
||||
getQuicConfig(&quic.Config{MaxIncomingStreams: maxIncomingStreams, MaxIdleTimeout: 10 * time.Second}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer server.Close()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2 * 4 * maxIncomingStreams)
|
||||
|
||||
@@ -24,6 +24,7 @@ var _ = Describe("Connection ID lengths tests", func() {
|
||||
}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer server.Close()
|
||||
|
||||
var drop atomic.Bool
|
||||
dropped := make(chan []byte, 100)
|
||||
@@ -50,6 +51,7 @@ var _ = Describe("Connection ID lengths tests", func() {
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer conn.CloseWithError(0, "")
|
||||
|
||||
sconn, err := server.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
@@ -34,9 +34,23 @@ func (c *connIDGenerator) ConnectionIDLen() int {
|
||||
var _ = Describe("Connection ID lengths tests", func() {
|
||||
randomConnIDLen := func() int { return 4 + int(mrand.Int31n(15)) }
|
||||
|
||||
runServer := func(conf *quic.Config) *quic.Listener {
|
||||
GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID for the server\n", conf.ConnectionIDLength)))
|
||||
ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), conf)
|
||||
// connIDLen is ignored when connIDGenerator is set
|
||||
runServer := func(connIDLen int, connIDGenerator quic.ConnectionIDGenerator) (*quic.Listener, func()) {
|
||||
if connIDGenerator != nil {
|
||||
GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID generator for the server\n", connIDGenerator.ConnectionIDLen())))
|
||||
} else {
|
||||
GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID for the server\n", connIDLen)))
|
||||
}
|
||||
addr, err := net.ResolveUDPAddr("udp", "localhost:0")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
conn, err := net.ListenUDP("udp", addr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
tr := &quic.Transport{
|
||||
Conn: conn,
|
||||
ConnectionIDLength: connIDLen,
|
||||
ConnectionIDGenerator: connIDGenerator,
|
||||
}
|
||||
ln, err := tr.Listen(getTLSConfig(), getQuicConfig(nil))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
@@ -55,16 +69,35 @@ var _ = Describe("Connection ID lengths tests", func() {
|
||||
}()
|
||||
}
|
||||
}()
|
||||
return ln
|
||||
return ln, func() {
|
||||
ln.Close()
|
||||
tr.Close()
|
||||
}
|
||||
}
|
||||
|
||||
runClient := func(addr net.Addr, conf *quic.Config) {
|
||||
GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID for the client\n", conf.ConnectionIDLength)))
|
||||
cl, err := quic.DialAddr(
|
||||
// connIDLen is ignored when connIDGenerator is set
|
||||
runClient := func(addr net.Addr, connIDLen int, connIDGenerator quic.ConnectionIDGenerator) {
|
||||
if connIDGenerator != nil {
|
||||
GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID generator for the client\n", connIDGenerator.ConnectionIDLen())))
|
||||
} else {
|
||||
GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID for the client\n", connIDLen)))
|
||||
}
|
||||
laddr, err := net.ResolveUDPAddr("udp", "localhost:0")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
conn, err := net.ListenUDP("udp", laddr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer conn.Close()
|
||||
tr := &quic.Transport{
|
||||
Conn: conn,
|
||||
ConnectionIDLength: connIDLen,
|
||||
ConnectionIDGenerator: connIDGenerator,
|
||||
}
|
||||
defer tr.Close()
|
||||
cl, err := tr.Dial(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", addr.(*net.UDPAddr).Port),
|
||||
&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: addr.(*net.UDPAddr).Port},
|
||||
getTLSClientConfig(),
|
||||
conf,
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer cl.CloseWithError(0, "")
|
||||
@@ -76,32 +109,20 @@ var _ = Describe("Connection ID lengths tests", func() {
|
||||
}
|
||||
|
||||
It("downloads a file using a 0-byte connection ID for the client", func() {
|
||||
serverConf := getQuicConfig(&quic.Config{ConnectionIDLength: randomConnIDLen()})
|
||||
ln := runServer(serverConf)
|
||||
defer ln.Close()
|
||||
|
||||
runClient(ln.Addr(), getQuicConfig(nil))
|
||||
ln, closeFn := runServer(randomConnIDLen(), nil)
|
||||
defer closeFn()
|
||||
runClient(ln.Addr(), 0, nil)
|
||||
})
|
||||
|
||||
It("downloads a file when both client and server use a random connection ID length", func() {
|
||||
serverConf := getQuicConfig(&quic.Config{ConnectionIDLength: randomConnIDLen()})
|
||||
ln := runServer(serverConf)
|
||||
defer ln.Close()
|
||||
|
||||
runClient(ln.Addr(), getQuicConfig(nil))
|
||||
ln, closeFn := runServer(randomConnIDLen(), nil)
|
||||
defer closeFn()
|
||||
runClient(ln.Addr(), randomConnIDLen(), nil)
|
||||
})
|
||||
|
||||
It("downloads a file when both client and server use a custom connection ID generator", func() {
|
||||
serverConf := getQuicConfig(&quic.Config{
|
||||
ConnectionIDGenerator: &connIDGenerator{length: randomConnIDLen()},
|
||||
})
|
||||
clientConf := getQuicConfig(&quic.Config{
|
||||
ConnectionIDGenerator: &connIDGenerator{length: randomConnIDLen()},
|
||||
})
|
||||
|
||||
ln := runServer(serverConf)
|
||||
defer ln.Close()
|
||||
|
||||
runClient(ln.Addr(), clientConf)
|
||||
ln, closeFn := runServer(0, &connIDGenerator{length: randomConnIDLen()})
|
||||
defer closeFn()
|
||||
runClient(ln.Addr(), 0, &connIDGenerator{length: randomConnIDLen()})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -22,12 +22,11 @@ var _ = Describe("Datagram test", func() {
|
||||
const num = 100
|
||||
|
||||
var (
|
||||
proxy *quicproxy.QuicProxy
|
||||
serverConn, clientConn *net.UDPConn
|
||||
dropped, total int32
|
||||
)
|
||||
|
||||
startServerAndProxy := func(enableDatagram, expectDatagramSupport bool) {
|
||||
startServerAndProxy := func(enableDatagram, expectDatagramSupport bool) (port int, closeFn func()) {
|
||||
addr, err := net.ResolveUDPAddr("udp", "localhost:0")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
serverConn, err = net.ListenUDP("udp", addr)
|
||||
@@ -39,8 +38,10 @@ var _ = Describe("Datagram test", func() {
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
accepted := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
defer close(accepted)
|
||||
conn, err := ln.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
@@ -67,7 +68,7 @@ var _ = Describe("Datagram test", func() {
|
||||
}()
|
||||
|
||||
serverPort := ln.Addr().(*net.UDPAddr).Port
|
||||
proxy, err = quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
|
||||
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
|
||||
RemoteAddr: fmt.Sprintf("localhost:%d", serverPort),
|
||||
// drop 10% of Short Header packets sent from the server
|
||||
DropPacket: func(dir quicproxy.Direction, packet []byte) bool {
|
||||
@@ -87,6 +88,11 @@ var _ = Describe("Datagram test", func() {
|
||||
},
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
return proxy.LocalPort(), func() {
|
||||
Eventually(accepted).Should(BeClosed())
|
||||
proxy.Close()
|
||||
ln.Close()
|
||||
}
|
||||
}
|
||||
|
||||
BeforeEach(func() {
|
||||
@@ -96,13 +102,10 @@ var _ = Describe("Datagram test", func() {
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
Expect(proxy.Close()).To(Succeed())
|
||||
})
|
||||
|
||||
It("sends datagrams", func() {
|
||||
startServerAndProxy(true, true)
|
||||
raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxy.LocalPort()))
|
||||
proxyPort, close := startServerAndProxy(true, true)
|
||||
defer close()
|
||||
raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
conn, err := quic.Dial(
|
||||
context.Background(),
|
||||
@@ -117,6 +120,7 @@ var _ = Describe("Datagram test", func() {
|
||||
for {
|
||||
// Close the connection if no message is received for 100 ms.
|
||||
timer := time.AfterFunc(scaleDuration(100*time.Millisecond), func() {
|
||||
fmt.Println("closing conn")
|
||||
conn.CloseWithError(0, "")
|
||||
})
|
||||
if _, err := conn.ReceiveMessage(); err != nil {
|
||||
@@ -134,11 +138,12 @@ var _ = Describe("Datagram test", func() {
|
||||
BeNumerically(">", expVal*9/10),
|
||||
BeNumerically("<", num),
|
||||
))
|
||||
Eventually(conn.Context().Done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("server can disable datagram", func() {
|
||||
startServerAndProxy(false, true)
|
||||
raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxy.LocalPort()))
|
||||
proxyPort, close := startServerAndProxy(false, true)
|
||||
raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
conn, err := quic.Dial(
|
||||
context.Background(),
|
||||
@@ -150,13 +155,13 @@ var _ = Describe("Datagram test", func() {
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(conn.ConnectionState().SupportsDatagrams).To(BeFalse())
|
||||
|
||||
close()
|
||||
conn.CloseWithError(0, "")
|
||||
<-time.After(10 * time.Millisecond)
|
||||
})
|
||||
|
||||
It("client can disable datagram", func() {
|
||||
startServerAndProxy(false, true)
|
||||
raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxy.LocalPort()))
|
||||
proxyPort, close := startServerAndProxy(false, true)
|
||||
raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
conn, err := quic.Dial(
|
||||
context.Background(),
|
||||
@@ -169,7 +174,8 @@ var _ = Describe("Datagram test", func() {
|
||||
Expect(conn.ConnectionState().SupportsDatagrams).To(BeFalse())
|
||||
|
||||
Expect(conn.SendMessage([]byte{0})).To(HaveOccurred())
|
||||
|
||||
close()
|
||||
conn.CloseWithError(0, "")
|
||||
<-time.After(10 * time.Millisecond)
|
||||
})
|
||||
})
|
||||
|
||||
@@ -24,6 +24,7 @@ var _ = Describe("early data", func() {
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer ln.Close()
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
|
||||
@@ -8,10 +8,9 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
go120 = false
|
||||
errNotSupported = errors.New("not supported")
|
||||
)
|
||||
const go120 = false
|
||||
|
||||
var errNotSupported = errors.New("not supported")
|
||||
|
||||
func setReadDeadline(w http.ResponseWriter, deadline time.Time) error {
|
||||
return errNotSupported
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
var go120 = true
|
||||
const go120 = true
|
||||
|
||||
func setReadDeadline(w http.ResponseWriter, deadline time.Time) error {
|
||||
rc := http.NewResponseController(w)
|
||||
|
||||
@@ -62,13 +62,14 @@ var _ = Describe("Handshake RTT tests", func() {
|
||||
|
||||
runProxy(ln.Addr())
|
||||
startTime := time.Now()
|
||||
_, err = quic.DialAddr(
|
||||
conn, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer conn.CloseWithError(0, "")
|
||||
expectDurationInRTTs(startTime, 2)
|
||||
})
|
||||
|
||||
@@ -79,13 +80,14 @@ var _ = Describe("Handshake RTT tests", func() {
|
||||
|
||||
runProxy(ln.Addr())
|
||||
startTime := time.Now()
|
||||
_, err = quic.DialAddr(
|
||||
conn, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer conn.CloseWithError(0, "")
|
||||
expectDurationInRTTs(startTime, 1)
|
||||
})
|
||||
|
||||
@@ -97,13 +99,14 @@ var _ = Describe("Handshake RTT tests", func() {
|
||||
|
||||
runProxy(ln.Addr())
|
||||
startTime := time.Now()
|
||||
_, err = quic.DialAddr(
|
||||
conn, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer conn.CloseWithError(0, "")
|
||||
expectDurationInRTTs(startTime, 2)
|
||||
})
|
||||
|
||||
@@ -131,6 +134,7 @@ var _ = Describe("Handshake RTT tests", func() {
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer conn.CloseWithError(0, "")
|
||||
str, err := conn.AcceptUniStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
data, err := io.ReadAll(str)
|
||||
@@ -166,6 +170,7 @@ var _ = Describe("Handshake RTT tests", func() {
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer conn.CloseWithError(0, "")
|
||||
str, err := conn.AcceptUniStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
data, err := io.ReadAll(str)
|
||||
|
||||
@@ -114,7 +114,7 @@ var _ = Describe("Handshake tests", func() {
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port),
|
||||
getTLSClientConfig(),
|
||||
nil,
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str, err := conn.AcceptStream(context.Background())
|
||||
@@ -223,13 +223,14 @@ var _ = Describe("Handshake tests", func() {
|
||||
var (
|
||||
server *quic.Listener
|
||||
pconn net.PacketConn
|
||||
dialer *quic.Transport
|
||||
)
|
||||
|
||||
dial := func() (quic.Connection, error) {
|
||||
remoteAddr := fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port)
|
||||
raddr, err := net.ResolveUDPAddr("udp", remoteAddr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
return quic.Dial(context.Background(), pconn, raddr, getTLSClientConfig(), nil)
|
||||
return dialer.Dial(context.Background(), raddr, getTLSClientConfig(), getQuicConfig(nil))
|
||||
}
|
||||
|
||||
BeforeEach(func() {
|
||||
@@ -243,11 +244,13 @@ var _ = Describe("Handshake tests", func() {
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
pconn, err = net.ListenUDP("udp", laddr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
dialer = &quic.Transport{Conn: pconn, ConnectionIDLength: 4}
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
Expect(server.Close()).To(Succeed())
|
||||
Expect(pconn.Close()).To(Succeed())
|
||||
Expect(dialer.Close()).To(Succeed())
|
||||
})
|
||||
|
||||
It("rejects new connection attempts if connections don't get accepted", func() {
|
||||
@@ -300,7 +303,7 @@ var _ = Describe("Handshake tests", func() {
|
||||
// This should free one spot in the queue.
|
||||
Expect(firstConn.CloseWithError(0, ""))
|
||||
Eventually(firstConn.Context().Done()).Should(BeClosed())
|
||||
time.Sleep(scaleDuration(20 * time.Millisecond))
|
||||
time.Sleep(scaleDuration(200 * time.Millisecond))
|
||||
|
||||
// dial again, and expect that this dial succeeds
|
||||
_, err = dial()
|
||||
@@ -366,6 +369,7 @@ var _ = Describe("Handshake tests", func() {
|
||||
It("uses tokens provided in NEW_TOKEN frames", func() {
|
||||
server, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer server.Close()
|
||||
|
||||
// dial the first connection and receive the token
|
||||
go func() {
|
||||
@@ -432,6 +436,72 @@ var _ = Describe("Handshake tests", func() {
|
||||
})
|
||||
})
|
||||
|
||||
Context("GetConfigForClient", func() {
|
||||
It("uses the quic.Config returned by GetConfigForClient", func() {
|
||||
serverConfig.EnableDatagrams = false
|
||||
var calledFrom net.Addr
|
||||
serverConfig.GetConfigForClient = func(info *quic.ClientHelloInfo) (*quic.Config, error) {
|
||||
conf := serverConfig.Clone()
|
||||
conf.EnableDatagrams = true
|
||||
calledFrom = info.RemoteAddr
|
||||
return getQuicConfig(conf), nil
|
||||
}
|
||||
ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_, err := ln.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
close(done)
|
||||
}()
|
||||
|
||||
conn, err := quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(&quic.Config{EnableDatagrams: true}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer conn.CloseWithError(0, "")
|
||||
cs := conn.ConnectionState()
|
||||
Expect(cs.SupportsDatagrams).To(BeTrue())
|
||||
Eventually(done).Should(BeClosed())
|
||||
Expect(ln.Close()).To(Succeed())
|
||||
Expect(calledFrom.(*net.UDPAddr).Port).To(Equal(conn.LocalAddr().(*net.UDPAddr).Port))
|
||||
})
|
||||
|
||||
It("rejects the connection attempt if GetConfigForClient errors", func() {
|
||||
serverConfig.EnableDatagrams = false
|
||||
serverConfig.GetConfigForClient = func(info *quic.ClientHelloInfo) (*quic.Config, error) {
|
||||
return nil, errors.New("rejected")
|
||||
}
|
||||
ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer ln.Close()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_, err := ln.Accept(context.Background())
|
||||
Expect(err).To(HaveOccurred()) // we don't expect to accept any connection
|
||||
close(done)
|
||||
}()
|
||||
|
||||
_, err = quic.DialAddr(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(&quic.Config{EnableDatagrams: true}),
|
||||
)
|
||||
Expect(err).To(HaveOccurred())
|
||||
var transportErr *quic.TransportError
|
||||
Expect(errors.As(err, &transportErr)).To(BeTrue())
|
||||
Expect(transportErr.ErrorCode).To(Equal(qerr.ConnectionRefused))
|
||||
})
|
||||
})
|
||||
|
||||
It("doesn't send any packets when generating the ClientHello fails", func() {
|
||||
ln, err := net.ListenUDP("udp", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
@@ -382,6 +382,7 @@ var _ = Describe("HTTP tests", func() {
|
||||
tlsConf.NextProtos = []string{"h3"}
|
||||
ln, err := quic.ListenAddr("localhost:0", tlsConf, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer ln.Close()
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
@@ -398,57 +399,51 @@ var _ = Describe("HTTP tests", func() {
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("supports read deadlines", func() {
|
||||
if !go120 {
|
||||
Skip("This test requires Go 1.20+")
|
||||
}
|
||||
if go120 {
|
||||
It("supports read deadlines", func() {
|
||||
mux.HandleFunc("/read-deadline", func(w http.ResponseWriter, r *http.Request) {
|
||||
defer GinkgoRecover()
|
||||
err := setReadDeadline(w, time.Now().Add(deadlineDelay))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
mux.HandleFunc("/read-deadline", func(w http.ResponseWriter, r *http.Request) {
|
||||
defer GinkgoRecover()
|
||||
err := setReadDeadline(w, time.Now().Add(deadlineDelay))
|
||||
body, err := io.ReadAll(r.Body)
|
||||
Expect(err).To(MatchError(os.ErrDeadlineExceeded))
|
||||
Expect(body).To(ContainSubstring("aa"))
|
||||
|
||||
w.Write([]byte("ok"))
|
||||
})
|
||||
|
||||
expectedEnd := time.Now().Add(deadlineDelay)
|
||||
resp, err := client.Post("https://localhost:"+port+"/read-deadline", "text/plain", neverEnding('a'))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(resp.StatusCode).To(Equal(200))
|
||||
|
||||
body, err := io.ReadAll(r.Body)
|
||||
Expect(err).To(MatchError(os.ErrDeadlineExceeded))
|
||||
Expect(body).To(ContainSubstring("aa"))
|
||||
|
||||
w.Write([]byte("ok"))
|
||||
body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 2*deadlineDelay))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(time.Now().After(expectedEnd)).To(BeTrue())
|
||||
Expect(string(body)).To(Equal("ok"))
|
||||
})
|
||||
|
||||
expectedEnd := time.Now().Add(deadlineDelay)
|
||||
resp, err := client.Post("https://localhost:"+port+"/read-deadline", "text/plain", neverEnding('a'))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(resp.StatusCode).To(Equal(200))
|
||||
It("supports write deadlines", func() {
|
||||
mux.HandleFunc("/write-deadline", func(w http.ResponseWriter, r *http.Request) {
|
||||
defer GinkgoRecover()
|
||||
err := setWriteDeadline(w, time.Now().Add(deadlineDelay))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 2*deadlineDelay))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(time.Now().After(expectedEnd)).To(BeTrue())
|
||||
Expect(string(body)).To(Equal("ok"))
|
||||
})
|
||||
_, err = io.Copy(w, neverEnding('a'))
|
||||
Expect(err).To(MatchError(os.ErrDeadlineExceeded))
|
||||
})
|
||||
|
||||
It("supports write deadlines", func() {
|
||||
if !go120 {
|
||||
Skip("This test requires Go 1.20+")
|
||||
}
|
||||
expectedEnd := time.Now().Add(deadlineDelay)
|
||||
|
||||
mux.HandleFunc("/write-deadline", func(w http.ResponseWriter, r *http.Request) {
|
||||
defer GinkgoRecover()
|
||||
err := setWriteDeadline(w, time.Now().Add(deadlineDelay))
|
||||
resp, err := client.Get("https://localhost:" + port + "/write-deadline")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(resp.StatusCode).To(Equal(200))
|
||||
|
||||
_, err = io.Copy(w, neverEnding('a'))
|
||||
Expect(err).To(MatchError(os.ErrDeadlineExceeded))
|
||||
body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 2*deadlineDelay))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(time.Now().After(expectedEnd)).To(BeTrue())
|
||||
Expect(string(body)).To(ContainSubstring("aa"))
|
||||
})
|
||||
|
||||
expectedEnd := time.Now().Add(deadlineDelay)
|
||||
|
||||
resp, err := client.Get("https://localhost:" + port + "/write-deadline")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(resp.StatusCode).To(Equal(200))
|
||||
|
||||
body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 2*deadlineDelay))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(time.Now().After(expectedEnd)).To(BeTrue())
|
||||
Expect(string(body)).To(ContainSubstring("aa"))
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
@@ -75,7 +75,9 @@ var _ = Describe("Key Update tests", func() {
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(&quic.Config{Tracer: newTracer(func() logging.ConnectionTracer { return &keyUpdateConnTracer{} })}),
|
||||
getQuicConfig(&quic.Config{Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer {
|
||||
return &keyUpdateConnTracer{}
|
||||
}}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str, err := conn.AcceptUniStream(context.Background())
|
||||
|
||||
@@ -35,7 +35,11 @@ var _ = Describe("MITM test", func() {
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
serverUDPConn, err = net.ListenUDP("udp", addr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
ln, err := quic.Listen(serverUDPConn, getTLSConfig(), serverConfig)
|
||||
tr := &quic.Transport{
|
||||
Conn: serverUDPConn,
|
||||
ConnectionIDLength: connIDLen,
|
||||
}
|
||||
ln, err := tr.Listen(getTLSConfig(), serverConfig)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
@@ -68,7 +72,7 @@ var _ = Describe("MITM test", func() {
|
||||
}
|
||||
|
||||
BeforeEach(func() {
|
||||
serverConfig = getQuicConfig(&quic.Config{ConnectionIDLength: connIDLen})
|
||||
serverConfig = getQuicConfig(nil)
|
||||
addr, err := net.ResolveUDPAddr("udp", "localhost:0")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
clientUDPConn, err = net.ListenUDP("udp", addr)
|
||||
@@ -146,12 +150,15 @@ var _ = Describe("MITM test", func() {
|
||||
defer closeFn()
|
||||
raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
conn, err := quic.Dial(
|
||||
tr := &quic.Transport{
|
||||
Conn: clientUDPConn,
|
||||
ConnectionIDLength: connIDLen,
|
||||
}
|
||||
conn, err := tr.Dial(
|
||||
context.Background(),
|
||||
clientUDPConn,
|
||||
raddr,
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(&quic.Config{ConnectionIDLength: connIDLen}),
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str, err := conn.AcceptUniStream(context.Background())
|
||||
@@ -190,12 +197,15 @@ var _ = Describe("MITM test", func() {
|
||||
defer closeFn()
|
||||
raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
conn, err := quic.Dial(
|
||||
tr := &quic.Transport{
|
||||
Conn: clientUDPConn,
|
||||
ConnectionIDLength: connIDLen,
|
||||
}
|
||||
conn, err := tr.Dial(
|
||||
context.Background(),
|
||||
clientUDPConn,
|
||||
raddr,
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(&quic.Config{ConnectionIDLength: connIDLen}),
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str, err := conn.AcceptUniStream(context.Background())
|
||||
@@ -302,20 +312,20 @@ var _ = Describe("MITM test", func() {
|
||||
const rtt = 20 * time.Millisecond
|
||||
|
||||
runTest := func(delayCb quicproxy.DelayCallback) (closeFn func(), err error) {
|
||||
proxyPort, closeFn := startServerAndProxy(delayCb, nil)
|
||||
proxyPort, serverCloseFn := startServerAndProxy(delayCb, nil)
|
||||
raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = quic.Dial(
|
||||
tr := &quic.Transport{
|
||||
Conn: clientUDPConn,
|
||||
ConnectionIDLength: connIDLen,
|
||||
}
|
||||
_, err = tr.Dial(
|
||||
context.Background(),
|
||||
clientUDPConn,
|
||||
raddr,
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(&quic.Config{
|
||||
ConnectionIDLength: connIDLen,
|
||||
HandshakeIdleTimeout: 2 * time.Second,
|
||||
}),
|
||||
getQuicConfig(&quic.Config{HandshakeIdleTimeout: 2 * time.Second}),
|
||||
)
|
||||
return closeFn, err
|
||||
return func() { tr.Close(); serverCloseFn() }, err
|
||||
}
|
||||
|
||||
// fails immediately because client connection closes when it can't find compatible version
|
||||
|
||||
@@ -34,10 +34,9 @@ var _ = Describe("Multiplexing", func() {
|
||||
}()
|
||||
}
|
||||
|
||||
dial := func(pconn net.PacketConn, addr net.Addr) {
|
||||
conn, err := quic.Dial(
|
||||
dial := func(tr *quic.Transport, addr net.Addr) {
|
||||
conn, err := tr.Dial(
|
||||
context.Background(),
|
||||
pconn,
|
||||
addr,
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(nil),
|
||||
@@ -72,17 +71,18 @@ var _ = Describe("Multiplexing", func() {
|
||||
conn, err := net.ListenUDP("udp", addr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer conn.Close()
|
||||
tr := &quic.Transport{Conn: conn}
|
||||
|
||||
done1 := make(chan struct{})
|
||||
done2 := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
dial(conn, server.Addr())
|
||||
dial(tr, server.Addr())
|
||||
close(done1)
|
||||
}()
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
dial(conn, server.Addr())
|
||||
dial(tr, server.Addr())
|
||||
close(done2)
|
||||
}()
|
||||
timeout := 30 * time.Second
|
||||
@@ -106,17 +106,18 @@ var _ = Describe("Multiplexing", func() {
|
||||
conn, err := net.ListenUDP("udp", addr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer conn.Close()
|
||||
tr := &quic.Transport{Conn: conn}
|
||||
|
||||
done1 := make(chan struct{})
|
||||
done2 := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
dial(conn, server1.Addr())
|
||||
dial(tr, server1.Addr())
|
||||
close(done1)
|
||||
}()
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
dial(conn, server2.Addr())
|
||||
dial(tr, server2.Addr())
|
||||
close(done2)
|
||||
}()
|
||||
timeout := 30 * time.Second
|
||||
@@ -135,9 +136,8 @@ var _ = Describe("Multiplexing", func() {
|
||||
conn, err := net.ListenUDP("udp", addr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer conn.Close()
|
||||
|
||||
server, err := quic.Listen(
|
||||
conn,
|
||||
tr := &quic.Transport{Conn: conn}
|
||||
server, err := tr.Listen(
|
||||
getTLSConfig(),
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
@@ -146,7 +146,7 @@ var _ = Describe("Multiplexing", func() {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
dial(conn, server.Addr())
|
||||
dial(tr, server.Addr())
|
||||
close(done)
|
||||
}()
|
||||
timeout := 30 * time.Second
|
||||
@@ -165,15 +165,16 @@ var _ = Describe("Multiplexing", func() {
|
||||
conn1, err := net.ListenUDP("udp", addr1)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer conn1.Close()
|
||||
tr1 := &quic.Transport{Conn: conn1}
|
||||
|
||||
addr2, err := net.ResolveUDPAddr("udp", "localhost:0")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
conn2, err := net.ListenUDP("udp", addr2)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer conn2.Close()
|
||||
tr2 := &quic.Transport{Conn: conn2}
|
||||
|
||||
server1, err := quic.Listen(
|
||||
conn1,
|
||||
server1, err := tr1.Listen(
|
||||
getTLSConfig(),
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
@@ -181,8 +182,7 @@ var _ = Describe("Multiplexing", func() {
|
||||
runServer(server1)
|
||||
defer server1.Close()
|
||||
|
||||
server2, err := quic.Listen(
|
||||
conn2,
|
||||
server2, err := tr2.Listen(
|
||||
getTLSConfig(),
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
@@ -194,12 +194,12 @@ var _ = Describe("Multiplexing", func() {
|
||||
done2 := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
dial(conn2, server1.Addr())
|
||||
dial(tr2, server1.Addr())
|
||||
close(done1)
|
||||
}()
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
dial(conn1, server2.Addr())
|
||||
dial(tr1, server2.Addr())
|
||||
close(done2)
|
||||
}()
|
||||
timeout := 30 * time.Second
|
||||
|
||||
@@ -27,12 +27,12 @@ var _ = Describe("Packetization", func() {
|
||||
getTLSConfig(),
|
||||
getQuicConfig(&quic.Config{
|
||||
DisablePathMTUDiscovery: true,
|
||||
Tracer: newTracer(func() logging.ConnectionTracer { return serverTracer }),
|
||||
Tracer: newTracer(serverTracer),
|
||||
}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
serverAddr := fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port)
|
||||
defer server.Close()
|
||||
serverAddr := fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port)
|
||||
|
||||
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
|
||||
RemoteAddr: serverAddr,
|
||||
@@ -50,10 +50,11 @@ var _ = Describe("Packetization", func() {
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(&quic.Config{
|
||||
DisablePathMTUDiscovery: true,
|
||||
Tracer: newTracer(func() logging.ConnectionTracer { return clientTracer }),
|
||||
Tracer: newTracer(clientTracer),
|
||||
}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer conn.CloseWithError(0, "")
|
||||
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
|
||||
@@ -87,7 +87,7 @@ var (
|
||||
logBuf *syncedBuffer
|
||||
versionParam string
|
||||
|
||||
qlogTracer logging.Tracer
|
||||
qlogTracer func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer
|
||||
enableQlog bool
|
||||
|
||||
version quic.VersionNumber
|
||||
@@ -175,7 +175,13 @@ func getQuicConfig(conf *quic.Config) *quic.Config {
|
||||
if conf.Tracer == nil {
|
||||
conf.Tracer = qlogTracer
|
||||
} else if qlogTracer != nil {
|
||||
conf.Tracer = logging.NewMultiplexedTracer(qlogTracer, conf.Tracer)
|
||||
origTracer := conf.Tracer
|
||||
conf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer {
|
||||
return logging.NewMultiplexedConnectionTracer(
|
||||
qlogTracer(ctx, p, connID),
|
||||
origTracer(ctx, p, connID),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
return conf
|
||||
@@ -199,8 +205,16 @@ func areHandshakesRunning() bool {
|
||||
return strings.Contains(b.String(), "RunHandshake")
|
||||
}
|
||||
|
||||
func areTransportsRunning() bool {
|
||||
var b bytes.Buffer
|
||||
pprof.Lookup("goroutine").WriteTo(&b, 1)
|
||||
return strings.Contains(b.String(), "quic-go.(*Transport).listen")
|
||||
}
|
||||
|
||||
var _ = AfterEach(func() {
|
||||
Expect(areHandshakesRunning()).To(BeFalse())
|
||||
Eventually(areTransportsRunning).Should(BeFalse())
|
||||
|
||||
if debugLog() {
|
||||
logFile, err := os.Create(logFileName)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
@@ -224,19 +238,8 @@ func scaleDuration(d time.Duration) time.Duration {
|
||||
return time.Duration(scaleFactor) * d
|
||||
}
|
||||
|
||||
type tracer struct {
|
||||
logging.NullTracer
|
||||
createNewConnTracer func() logging.ConnectionTracer
|
||||
}
|
||||
|
||||
var _ logging.Tracer = &tracer{}
|
||||
|
||||
func newTracer(c func() logging.ConnectionTracer) logging.Tracer {
|
||||
return &tracer{createNewConnTracer: c}
|
||||
}
|
||||
|
||||
func (t *tracer) TracerForConnection(context.Context, logging.Perspective, logging.ConnectionID) logging.ConnectionTracer {
|
||||
return t.createNewConnTracer()
|
||||
func newTracer(tracer logging.ConnectionTracer) func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer {
|
||||
return func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer { return tracer }
|
||||
}
|
||||
|
||||
type packet struct {
|
||||
|
||||
@@ -2,7 +2,6 @@ package self_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net"
|
||||
@@ -25,9 +24,16 @@ var _ = Describe("Stateless Resets", func() {
|
||||
It(fmt.Sprintf("sends and recognizes stateless resets, for %d byte connection IDs", connIDLen), func() {
|
||||
var statelessResetKey quic.StatelessResetKey
|
||||
rand.Read(statelessResetKey[:])
|
||||
serverConfig := getQuicConfig(&quic.Config{StatelessResetKey: &statelessResetKey})
|
||||
|
||||
ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig)
|
||||
c, err := net.ListenUDP("udp", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
tr := &quic.Transport{
|
||||
Conn: c,
|
||||
StatelessResetKey: &statelessResetKey,
|
||||
ConnectionIDLength: connIDLen,
|
||||
}
|
||||
defer tr.Close()
|
||||
ln, err := tr.Listen(getTLSConfig(), getQuicConfig(nil))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
serverPort := ln.Addr().(*net.UDPAddr).Port
|
||||
|
||||
@@ -42,7 +48,8 @@ var _ = Describe("Stateless Resets", func() {
|
||||
_, err = str.Write([]byte("foobar"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
<-closeServer
|
||||
ln.Close()
|
||||
Expect(ln.Close()).To(Succeed())
|
||||
Expect(tr.Close()).To(Succeed())
|
||||
}()
|
||||
|
||||
var drop atomic.Bool
|
||||
@@ -55,14 +62,21 @@ var _ = Describe("Stateless Resets", func() {
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer proxy.Close()
|
||||
|
||||
conn, err := quic.DialAddr(
|
||||
addr, err := net.ResolveUDPAddr("udp", "localhost:0")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
udpConn, err := net.ListenUDP("udp", addr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer udpConn.Close()
|
||||
cl := &quic.Transport{
|
||||
Conn: udpConn,
|
||||
ConnectionIDLength: connIDLen,
|
||||
}
|
||||
defer cl.Close()
|
||||
conn, err := cl.Dial(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", proxy.LocalPort()),
|
||||
&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: proxy.LocalPort()},
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(&quic.Config{
|
||||
ConnectionIDLength: connIDLen,
|
||||
MaxIdleTimeout: 2 * time.Second,
|
||||
}),
|
||||
getQuicConfig(&quic.Config{MaxIdleTimeout: 2 * time.Second}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str, err := conn.AcceptStream(context.Background())
|
||||
@@ -77,11 +91,15 @@ var _ = Describe("Stateless Resets", func() {
|
||||
close(closeServer)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
ln2, err := quic.ListenAddr(
|
||||
fmt.Sprintf("localhost:%d", serverPort),
|
||||
getTLSConfig(),
|
||||
serverConfig,
|
||||
)
|
||||
// We need to create a new Transport here, since the old one is still sending out
|
||||
// CONNECTION_CLOSE packets for (recently) closed connections).
|
||||
tr2 := &quic.Transport{
|
||||
Conn: c,
|
||||
ConnectionIDLength: connIDLen,
|
||||
StatelessResetKey: &statelessResetKey,
|
||||
}
|
||||
defer tr2.Close()
|
||||
ln2, err := tr2.Listen(getTLSConfig(), getQuicConfig(nil))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
drop.Store(false)
|
||||
|
||||
@@ -100,8 +118,7 @@ var _ = Describe("Stateless Resets", func() {
|
||||
_, serr = str.Read([]byte{0})
|
||||
}
|
||||
Expect(serr).To(HaveOccurred())
|
||||
statelessResetErr := &quic.StatelessResetError{}
|
||||
Expect(errors.As(serr, &statelessResetErr)).To(BeTrue())
|
||||
Expect(serr).To(BeAssignableToTypeOf(&quic.StatelessResetError{}))
|
||||
Expect(ln2.Close()).To(Succeed())
|
||||
Eventually(acceptStopped).Should(BeClosed())
|
||||
})
|
||||
|
||||
@@ -94,6 +94,8 @@ var _ = Describe("Bidirectional streams", func() {
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
runSendingPeer(client)
|
||||
client.CloseWithError(0, "")
|
||||
<-conn.Context().Done()
|
||||
})
|
||||
|
||||
It(fmt.Sprintf("server opening %d streams to a client", numStreams), func() {
|
||||
@@ -149,5 +151,6 @@ var _ = Describe("Bidirectional streams", func() {
|
||||
runReceivingPeer(client)
|
||||
<-done1
|
||||
<-done2
|
||||
client.CloseWithError(0, "")
|
||||
})
|
||||
})
|
||||
|
||||
@@ -201,7 +201,7 @@ var _ = Describe("Timeout tests", func() {
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(&quic.Config{
|
||||
MaxIdleTimeout: idleTimeout,
|
||||
Tracer: newTracer(func() logging.ConnectionTracer { return tr }),
|
||||
Tracer: newTracer(tr),
|
||||
DisablePathMTUDiscovery: true,
|
||||
}),
|
||||
)
|
||||
@@ -473,6 +473,7 @@ var _ = Describe("Timeout tests", func() {
|
||||
}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer ln.Close()
|
||||
|
||||
serverErrChan := make(chan error, 1)
|
||||
go func() {
|
||||
|
||||
@@ -19,14 +19,6 @@ import (
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
type customTracer struct{ logging.NullTracer }
|
||||
|
||||
func (t *customTracer) TracerForConnection(context.Context, logging.Perspective, logging.ConnectionID) logging.ConnectionTracer {
|
||||
return &customConnTracer{}
|
||||
}
|
||||
|
||||
type customConnTracer struct{ logging.NullConnectionTracer }
|
||||
|
||||
var _ = Describe("Handshake tests", func() {
|
||||
addTracers := func(pers protocol.Perspective, conf *quic.Config) *quic.Config {
|
||||
enableQlog := mrand.Int()%3 != 0
|
||||
@@ -34,22 +26,32 @@ var _ = Describe("Handshake tests", func() {
|
||||
|
||||
fmt.Fprintf(GinkgoWriter, "%s using qlog: %t, custom: %t\n", pers, enableQlog, enableCustomTracer)
|
||||
|
||||
var tracers []logging.Tracer
|
||||
var tracerConstructors []func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer
|
||||
if enableQlog {
|
||||
tracers = append(tracers, qlog.NewTracer(func(p logging.Perspective, connectionID []byte) io.WriteCloser {
|
||||
tracerConstructors = append(tracerConstructors, func(_ context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer {
|
||||
if mrand.Int()%2 == 0 { // simulate that a qlog collector might only want to log some connections
|
||||
fmt.Fprintf(GinkgoWriter, "%s qlog tracer deciding to not trace connection %x\n", p, connectionID)
|
||||
fmt.Fprintf(GinkgoWriter, "%s qlog tracer deciding to not trace connection %x\n", p, connID)
|
||||
return nil
|
||||
}
|
||||
fmt.Fprintf(GinkgoWriter, "%s qlog tracing connection %x\n", p, connectionID)
|
||||
return utils.NewBufferedWriteCloser(bufio.NewWriter(&bytes.Buffer{}), io.NopCloser(nil))
|
||||
}))
|
||||
fmt.Fprintf(GinkgoWriter, "%s qlog tracing connection %x\n", p, connID)
|
||||
return qlog.NewConnectionTracer(utils.NewBufferedWriteCloser(bufio.NewWriter(&bytes.Buffer{}), io.NopCloser(nil)), p, connID)
|
||||
})
|
||||
}
|
||||
if enableCustomTracer {
|
||||
tracers = append(tracers, &customTracer{})
|
||||
tracerConstructors = append(tracerConstructors, func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer {
|
||||
return logging.NullConnectionTracer{}
|
||||
})
|
||||
}
|
||||
c := conf.Clone()
|
||||
c.Tracer = logging.NewMultiplexedTracer(tracers...)
|
||||
c.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer {
|
||||
tracers := make([]logging.ConnectionTracer, 0, len(tracerConstructors))
|
||||
for _, c := range tracerConstructors {
|
||||
if tr := c(ctx, p, connID); tr != nil {
|
||||
tracers = append(tracers, tr)
|
||||
}
|
||||
}
|
||||
return logging.NewMultiplexedConnectionTracer(tracers...)
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
|
||||
@@ -88,11 +88,14 @@ var _ = Describe("Unidirectional Streams", func() {
|
||||
})
|
||||
|
||||
It(fmt.Sprintf("server opening %d streams to a client", numStreams), func() {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
defer close(done)
|
||||
conn, err := server.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
runSendingPeer(conn)
|
||||
<-conn.Context().Done()
|
||||
}()
|
||||
|
||||
client, err := quic.DialAddr(
|
||||
@@ -103,6 +106,7 @@ var _ = Describe("Unidirectional Streams", func() {
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
runReceivingPeer(client)
|
||||
client.CloseWithError(0, "")
|
||||
})
|
||||
|
||||
It(fmt.Sprintf("client and server opening %d streams each and sending data to the peer", numStreams), func() {
|
||||
|
||||
@@ -54,7 +54,7 @@ var _ = Describe("0-RTT", func() {
|
||||
if serverConf == nil {
|
||||
serverConf = getQuicConfig(nil)
|
||||
}
|
||||
serverConf.Allow0RTT = func(addr net.Addr) bool { return true }
|
||||
serverConf.Allow0RTT = true
|
||||
ln, err := quic.ListenAddrEarly(
|
||||
"localhost:0",
|
||||
tlsConf,
|
||||
@@ -101,6 +101,7 @@ var _ = Describe("0-RTT", func() {
|
||||
transfer0RTTData := func(
|
||||
ln *quic.EarlyListener,
|
||||
proxyPort int,
|
||||
connIDLen int,
|
||||
clientTLSConf *tls.Config,
|
||||
clientConf *quic.Config,
|
||||
testdata []byte, // data to transfer
|
||||
@@ -125,13 +126,35 @@ var _ = Describe("0-RTT", func() {
|
||||
if clientConf == nil {
|
||||
clientConf = getQuicConfig(nil)
|
||||
}
|
||||
conn, err := quic.DialAddrEarly(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", proxyPort),
|
||||
clientTLSConf,
|
||||
clientConf,
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
var conn quic.EarlyConnection
|
||||
if connIDLen == 0 {
|
||||
var err error
|
||||
conn, err = quic.DialAddrEarly(
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", proxyPort),
|
||||
clientTLSConf,
|
||||
clientConf,
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
} else {
|
||||
addr, err := net.ResolveUDPAddr("udp", "localhost:0")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
udpConn, err := net.ListenUDP("udp", addr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer udpConn.Close()
|
||||
tr := &quic.Transport{
|
||||
Conn: udpConn,
|
||||
ConnectionIDLength: connIDLen,
|
||||
}
|
||||
defer tr.Close()
|
||||
conn, err = tr.DialEarly(
|
||||
context.Background(),
|
||||
&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: proxyPort},
|
||||
clientTLSConf,
|
||||
clientConf,
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
defer conn.CloseWithError(0, "")
|
||||
str, err := conn.OpenStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
@@ -199,8 +222,8 @@ var _ = Describe("0-RTT", func() {
|
||||
"localhost:0",
|
||||
tlsConf,
|
||||
getQuicConfig(&quic.Config{
|
||||
Allow0RTT: func(addr net.Addr) bool { return true },
|
||||
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
|
||||
Allow0RTT: true,
|
||||
Tracer: newTracer(tracer),
|
||||
}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
@@ -212,8 +235,9 @@ var _ = Describe("0-RTT", func() {
|
||||
transfer0RTTData(
|
||||
ln,
|
||||
proxy.LocalPort(),
|
||||
connIDLen,
|
||||
clientTLSConf,
|
||||
getQuicConfig(&quic.Config{ConnectionIDLength: connIDLen}),
|
||||
getQuicConfig(nil),
|
||||
PRData,
|
||||
)
|
||||
|
||||
@@ -252,8 +276,8 @@ var _ = Describe("0-RTT", func() {
|
||||
"localhost:0",
|
||||
tlsConf,
|
||||
getQuicConfig(&quic.Config{
|
||||
Allow0RTT: func(net.Addr) bool { return true },
|
||||
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
|
||||
Allow0RTT: true,
|
||||
Tracer: newTracer(tracer),
|
||||
}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
@@ -334,8 +358,8 @@ var _ = Describe("0-RTT", func() {
|
||||
"localhost:0",
|
||||
tlsConf,
|
||||
getQuicConfig(&quic.Config{
|
||||
Allow0RTT: func(net.Addr) bool { return true },
|
||||
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
|
||||
Allow0RTT: true,
|
||||
Tracer: newTracer(tracer),
|
||||
}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
@@ -373,7 +397,7 @@ var _ = Describe("0-RTT", func() {
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer proxy.Close()
|
||||
|
||||
transfer0RTTData(ln, proxy.LocalPort(), clientConf, nil, PRData)
|
||||
transfer0RTTData(ln, proxy.LocalPort(), protocol.DefaultConnectionIDLength, clientConf, nil, PRData)
|
||||
|
||||
num0RTT := atomic.LoadUint32(&num0RTTPackets)
|
||||
numDropped := atomic.LoadUint32(&num0RTTDropped)
|
||||
@@ -410,8 +434,8 @@ var _ = Describe("0-RTT", func() {
|
||||
tlsConf,
|
||||
getQuicConfig(&quic.Config{
|
||||
RequireAddressValidation: func(net.Addr) bool { return true },
|
||||
Allow0RTT: func(net.Addr) bool { return true },
|
||||
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
|
||||
Allow0RTT: true,
|
||||
Tracer: newTracer(tracer),
|
||||
}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
@@ -448,7 +472,7 @@ var _ = Describe("0-RTT", func() {
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer proxy.Close()
|
||||
|
||||
transfer0RTTData(ln, proxy.LocalPort(), clientConf, nil, GeneratePRData(5000)) // ~5 packets
|
||||
transfer0RTTData(ln, proxy.LocalPort(), protocol.DefaultConnectionIDLength, clientConf, nil, GeneratePRData(5000)) // ~5 packets
|
||||
|
||||
mutex.Lock()
|
||||
defer mutex.Unlock()
|
||||
@@ -471,8 +495,8 @@ var _ = Describe("0-RTT", func() {
|
||||
tlsConf,
|
||||
getQuicConfig(&quic.Config{
|
||||
MaxIncomingUniStreams: maxStreams + 1,
|
||||
Allow0RTT: func(net.Addr) bool { return true },
|
||||
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
|
||||
Allow0RTT: true,
|
||||
Tracer: newTracer(tracer),
|
||||
}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
@@ -516,8 +540,8 @@ var _ = Describe("0-RTT", func() {
|
||||
tlsConf,
|
||||
getQuicConfig(&quic.Config{
|
||||
MaxIncomingStreams: maxStreams - 1,
|
||||
Allow0RTT: func(net.Addr) bool { return true },
|
||||
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
|
||||
Allow0RTT: true,
|
||||
Tracer: newTracer(tracer),
|
||||
}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
@@ -544,8 +568,8 @@ var _ = Describe("0-RTT", func() {
|
||||
"localhost:0",
|
||||
tlsConf,
|
||||
getQuicConfig(&quic.Config{
|
||||
Allow0RTT: func(net.Addr) bool { return true },
|
||||
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
|
||||
Allow0RTT: true,
|
||||
Tracer: newTracer(tracer),
|
||||
}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
@@ -571,8 +595,8 @@ var _ = Describe("0-RTT", func() {
|
||||
"localhost:0",
|
||||
tlsConf,
|
||||
getQuicConfig(&quic.Config{
|
||||
Allow0RTT: func(net.Addr) bool { return false }, // application rejects 0-RTT
|
||||
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
|
||||
Allow0RTT: false, // application rejects 0-RTT
|
||||
Tracer: newTracer(tracer),
|
||||
}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
@@ -592,13 +616,13 @@ var _ = Describe("0-RTT", func() {
|
||||
DescribeTable("flow control limits",
|
||||
func(addFlowControlLimit func(*quic.Config, uint64)) {
|
||||
tracer := newPacketTracer()
|
||||
firstConf := getQuicConfig(&quic.Config{Allow0RTT: func(net.Addr) bool { return true }})
|
||||
firstConf := getQuicConfig(&quic.Config{Allow0RTT: true})
|
||||
addFlowControlLimit(firstConf, 3)
|
||||
tlsConf, clientConf := dialAndReceiveSessionTicket(firstConf)
|
||||
|
||||
secondConf := getQuicConfig(&quic.Config{
|
||||
Allow0RTT: func(net.Addr) bool { return true },
|
||||
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
|
||||
Allow0RTT: true,
|
||||
Tracer: newTracer(tracer),
|
||||
})
|
||||
addFlowControlLimit(secondConf, 100)
|
||||
ln, err := quic.ListenAddrEarly(
|
||||
@@ -675,7 +699,7 @@ var _ = Describe("0-RTT", func() {
|
||||
tlsConf,
|
||||
getQuicConfig(&quic.Config{
|
||||
MaxIncomingUniStreams: 1,
|
||||
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
|
||||
Tracer: newTracer(tracer),
|
||||
}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
@@ -750,8 +774,8 @@ var _ = Describe("0-RTT", func() {
|
||||
"localhost:0",
|
||||
tlsConf,
|
||||
getQuicConfig(&quic.Config{
|
||||
Allow0RTT: func(net.Addr) bool { return true },
|
||||
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
|
||||
Allow0RTT: true,
|
||||
Tracer: newTracer(tracer),
|
||||
}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
@@ -768,7 +792,7 @@ var _ = Describe("0-RTT", func() {
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer proxy.Close()
|
||||
|
||||
transfer0RTTData(ln, proxy.LocalPort(), clientConf, nil, PRData)
|
||||
transfer0RTTData(ln, proxy.LocalPort(), protocol.DefaultConnectionIDLength, clientConf, nil, PRData)
|
||||
|
||||
Expect(tracer.getRcvdLongHeaderPackets()[0].hdr.Type).To(Equal(protocol.PacketTypeInitial))
|
||||
zeroRTTPackets := get0RTTPackets(tracer.getRcvdLongHeaderPackets())
|
||||
|
||||
@@ -2,23 +2,25 @@ package tools
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
"github.com/quic-go/quic-go/logging"
|
||||
"github.com/quic-go/quic-go/qlog"
|
||||
)
|
||||
|
||||
func NewQlogger(logger io.Writer) logging.Tracer {
|
||||
return qlog.NewTracer(func(p logging.Perspective, connectionID []byte) io.WriteCloser {
|
||||
func NewQlogger(logger io.Writer) func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer {
|
||||
return func(_ context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer {
|
||||
role := "server"
|
||||
if p == logging.PerspectiveClient {
|
||||
role = "client"
|
||||
}
|
||||
filename := fmt.Sprintf("log_%x_%s.qlog", connectionID, role)
|
||||
filename := fmt.Sprintf("log_%x_%s.qlog", connID.Bytes(), role)
|
||||
fmt.Fprintf(logger, "Creating %s.\n", filename)
|
||||
f, err := os.Create(filename)
|
||||
if err != nil {
|
||||
@@ -26,6 +28,6 @@ func NewQlogger(logger io.Writer) logging.Tracer {
|
||||
return nil
|
||||
}
|
||||
bw := bufio.NewWriter(f)
|
||||
return utils.NewBufferedWriteCloser(bw, f)
|
||||
})
|
||||
return qlog.NewConnectionTracer(utils.NewBufferedWriteCloser(bw, f), p, connID)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -85,7 +85,9 @@ var _ = Describe("Handshake tests", func() {
|
||||
serverConfig := &quic.Config{}
|
||||
serverConfig.Versions = []protocol.VersionNumber{7, 8, protocol.SupportedVersions[0], 9}
|
||||
serverTracer := &versionNegotiationTracer{}
|
||||
serverConfig.Tracer = newTracer(func() logging.ConnectionTracer { return serverTracer })
|
||||
serverConfig.Tracer = func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer {
|
||||
return serverTracer
|
||||
}
|
||||
server, cl := startServer(getTLSConfig(), serverConfig)
|
||||
defer cl()
|
||||
clientTracer := &versionNegotiationTracer{}
|
||||
@@ -93,7 +95,9 @@ var _ = Describe("Handshake tests", func() {
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
||||
getTLSClientConfig(),
|
||||
maybeAddQlogTracer(&quic.Config{Tracer: newTracer(func() logging.ConnectionTracer { return clientTracer })}),
|
||||
maybeAddQLOGTracer(&quic.Config{Tracer: func(ctx context.Context, perspective logging.Perspective, id quic.ConnectionID) logging.ConnectionTracer {
|
||||
return clientTracer
|
||||
}}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(conn.(versioner).GetVersion()).To(Equal(expectedVersion))
|
||||
@@ -111,10 +115,12 @@ var _ = Describe("Handshake tests", func() {
|
||||
expectedVersion := protocol.SupportedVersions[0]
|
||||
// the server doesn't support the highest supported version, which is the first one the client will try
|
||||
// but it supports a bunch of versions that the client doesn't speak
|
||||
serverTracer := &versionNegotiationTracer{}
|
||||
serverConfig := &quic.Config{}
|
||||
serverConfig.Versions = supportedVersions
|
||||
serverTracer := &versionNegotiationTracer{}
|
||||
serverConfig.Tracer = newTracer(func() logging.ConnectionTracer { return serverTracer })
|
||||
serverConfig.Tracer = func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer {
|
||||
return serverTracer
|
||||
}
|
||||
server, cl := startServer(getTLSConfig(), serverConfig)
|
||||
defer cl()
|
||||
clientVersions := []protocol.VersionNumber{7, 8, 9, protocol.SupportedVersions[0], 10}
|
||||
@@ -123,9 +129,11 @@ var _ = Describe("Handshake tests", func() {
|
||||
context.Background(),
|
||||
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
|
||||
getTLSClientConfig(),
|
||||
maybeAddQlogTracer(&quic.Config{
|
||||
maybeAddQLOGTracer(&quic.Config{
|
||||
Versions: clientVersions,
|
||||
Tracer: newTracer(func() logging.ConnectionTracer { return clientTracer }),
|
||||
Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer {
|
||||
return clientTracer
|
||||
},
|
||||
}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
@@ -47,7 +47,7 @@ var _ = Describe("Handshake RTT tests", func() {
|
||||
context.Background(),
|
||||
proxy.LocalAddr().String(),
|
||||
getTLSClientConfig(),
|
||||
maybeAddQlogTracer(&quic.Config{Versions: protocol.SupportedVersions[1:2]}),
|
||||
maybeAddQLOGTracer(&quic.Config{Versions: protocol.SupportedVersions[1:2]}),
|
||||
)
|
||||
Expect(err).To(HaveOccurred())
|
||||
expectDurationInRTTs(startTime, 1)
|
||||
|
||||
@@ -58,7 +58,7 @@ func TestQuicVersionNegotiation(t *testing.T) {
|
||||
RunSpecs(t, "Version Negotiation Suite")
|
||||
}
|
||||
|
||||
func maybeAddQlogTracer(c *quic.Config) *quic.Config {
|
||||
func maybeAddQLOGTracer(c *quic.Config) *quic.Config {
|
||||
if c == nil {
|
||||
c = &quic.Config{}
|
||||
}
|
||||
@@ -69,22 +69,13 @@ func maybeAddQlogTracer(c *quic.Config) *quic.Config {
|
||||
if c.Tracer == nil {
|
||||
c.Tracer = qlogger
|
||||
} else if qlogger != nil {
|
||||
c.Tracer = logging.NewMultiplexedTracer(qlogger, c.Tracer)
|
||||
origTracer := c.Tracer
|
||||
c.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer {
|
||||
return logging.NewMultiplexedConnectionTracer(
|
||||
qlogger(ctx, p, connID),
|
||||
origTracer(ctx, p, connID),
|
||||
)
|
||||
}
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
type tracer struct {
|
||||
logging.NullTracer
|
||||
createNewConnTracer func() logging.ConnectionTracer
|
||||
}
|
||||
|
||||
var _ logging.Tracer = &tracer{}
|
||||
|
||||
func newTracer(c func() logging.ConnectionTracer) logging.Tracer {
|
||||
return &tracer{createNewConnTracer: c}
|
||||
}
|
||||
|
||||
func (t *tracer) TracerForConnection(context.Context, logging.Perspective, logging.ConnectionID) logging.ConnectionTracer {
|
||||
return t.createNewConnTracer()
|
||||
}
|
||||
|
||||
28
interface.go
28
interface.go
@@ -239,21 +239,12 @@ type ConnectionIDGenerator interface {
|
||||
|
||||
// Config contains all configuration data needed for a QUIC server or client.
|
||||
type Config struct {
|
||||
// GetConfigForClient is called for incoming connections.
|
||||
// If the error is not nil, the connection attempt is refused.
|
||||
GetConfigForClient func(info *ClientHelloInfo) (*Config, error)
|
||||
// The QUIC versions that can be negotiated.
|
||||
// If not set, it uses all versions available.
|
||||
Versions []VersionNumber
|
||||
// The length of the connection ID in bytes.
|
||||
// It can be 0, or any value between 4 and 18.
|
||||
// If not set, the interpretation depends on where the Config is used:
|
||||
// If used for dialing an address, a 0 byte connection ID will be used.
|
||||
// If used for a server, or dialing on a packet conn, a 4 byte connection ID will be used.
|
||||
// When dialing on a packet conn, the ConnectionIDLength value must be the same for every Dial call.
|
||||
ConnectionIDLength int
|
||||
// An optional ConnectionIDGenerator to be used for ConnectionIDs generated during the lifecycle of a QUIC connection.
|
||||
// The goal is to give some control on how connection IDs, which can be useful in some scenarios, in particular for servers.
|
||||
// By default, if not provided, random connection IDs with the length given by ConnectionIDLength is used.
|
||||
// Otherwise, if one is provided, then ConnectionIDLength is ignored.
|
||||
ConnectionIDGenerator ConnectionIDGenerator
|
||||
// HandshakeIdleTimeout is the idle timeout before completion of the handshake.
|
||||
// Specifically, if we don't receive any packet from the peer within this time, the connection attempt is aborted.
|
||||
// If this value is zero, the timeout is set to 5 seconds.
|
||||
@@ -314,9 +305,6 @@ type Config struct {
|
||||
// If not set, it will default to 100.
|
||||
// If set to a negative value, it doesn't allow any unidirectional streams.
|
||||
MaxIncomingUniStreams int64
|
||||
// The StatelessResetKey is used to generate stateless reset tokens.
|
||||
// If no key is configured, sending of stateless resets is disabled.
|
||||
StatelessResetKey *StatelessResetKey
|
||||
// KeepAlivePeriod defines whether this peer will periodically send a packet to keep the connection alive.
|
||||
// If set to 0, then no keep alive is sent. Otherwise, the keep alive is sent on that period (or at most
|
||||
// every half of MaxIdleTimeout, whichever is smaller).
|
||||
@@ -330,13 +318,15 @@ type Config struct {
|
||||
// It has no effect for a client.
|
||||
DisableVersionNegotiationPackets bool
|
||||
// Allow0RTT allows the application to decide if a 0-RTT connection attempt should be accepted.
|
||||
// When set, 0-RTT is enabled. When not set, 0-RTT is disabled.
|
||||
// Only valid for the server.
|
||||
// Warning: This API should not be considered stable and might change soon.
|
||||
Allow0RTT func(net.Addr) bool
|
||||
Allow0RTT bool
|
||||
// Enable QUIC datagram support (RFC 9221).
|
||||
EnableDatagrams bool
|
||||
Tracer logging.Tracer
|
||||
Tracer func(context.Context, logging.Perspective, ConnectionID) logging.ConnectionTracer
|
||||
}
|
||||
|
||||
type ClientHelloInfo struct {
|
||||
RemoteAddr net.Addr
|
||||
}
|
||||
|
||||
// ConnectionState records basic details about a QUIC connection
|
||||
|
||||
@@ -116,7 +116,7 @@ type cryptoSetup struct {
|
||||
clientHelloWritten bool
|
||||
clientHelloWrittenChan chan struct{} // is closed as soon as the ClientHello is written
|
||||
zeroRTTParametersChan chan<- *wire.TransportParameters
|
||||
allow0RTT func() bool
|
||||
allow0RTT bool
|
||||
|
||||
rttStats *utils.RTTStats
|
||||
|
||||
@@ -197,7 +197,7 @@ func NewCryptoSetupServer(
|
||||
tp *wire.TransportParameters,
|
||||
runner handshakeRunner,
|
||||
tlsConf *tls.Config,
|
||||
allow0RTT func() bool,
|
||||
allow0RTT bool,
|
||||
rttStats *utils.RTTStats,
|
||||
tracer logging.ConnectionTracer,
|
||||
logger utils.Logger,
|
||||
@@ -210,14 +210,13 @@ func NewCryptoSetupServer(
|
||||
tp,
|
||||
runner,
|
||||
tlsConf,
|
||||
allow0RTT != nil,
|
||||
allow0RTT,
|
||||
rttStats,
|
||||
tracer,
|
||||
logger,
|
||||
protocol.PerspectiveServer,
|
||||
version,
|
||||
)
|
||||
cs.allow0RTT = allow0RTT
|
||||
cs.conn = qtls.Server(newConn(localAddr, remoteAddr), cs.tlsConf, cs.extraConf)
|
||||
return cs
|
||||
}
|
||||
@@ -253,6 +252,7 @@ func newCryptoSetup(
|
||||
readEncLevel: protocol.EncryptionInitial,
|
||||
writeEncLevel: protocol.EncryptionInitial,
|
||||
runner: runner,
|
||||
allow0RTT: enable0RTT,
|
||||
ourParams: tp,
|
||||
paramsChan: extHandler.TransportParameters(),
|
||||
rttStats: rttStats,
|
||||
@@ -503,7 +503,7 @@ func (h *cryptoSetup) accept0RTT(sessionTicketData []byte) bool {
|
||||
h.logger.Debugf("Transport parameters changed. Rejecting 0-RTT.")
|
||||
return false
|
||||
}
|
||||
if !h.allow0RTT() {
|
||||
if !h.allow0RTT {
|
||||
h.logger.Debugf("0-RTT not allowed. Rejecting 0-RTT.")
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -95,7 +95,7 @@ var _ = Describe("Crypto Setup TLS", func() {
|
||||
&wire.TransportParameters{StatelessResetToken: &token},
|
||||
runner,
|
||||
testdata.GetTLSConfig(),
|
||||
nil,
|
||||
false,
|
||||
&utils.RTTStats{},
|
||||
nil,
|
||||
utils.DefaultLogger.WithPrefix("server"),
|
||||
@@ -177,7 +177,7 @@ var _ = Describe("Crypto Setup TLS", func() {
|
||||
&wire.TransportParameters{StatelessResetToken: &token},
|
||||
runner,
|
||||
testdata.GetTLSConfig(),
|
||||
nil,
|
||||
false,
|
||||
&utils.RTTStats{},
|
||||
nil,
|
||||
utils.DefaultLogger.WithPrefix("server"),
|
||||
@@ -218,7 +218,7 @@ var _ = Describe("Crypto Setup TLS", func() {
|
||||
&wire.TransportParameters{StatelessResetToken: &token},
|
||||
runner,
|
||||
serverConf,
|
||||
nil,
|
||||
false,
|
||||
&utils.RTTStats{},
|
||||
nil,
|
||||
utils.DefaultLogger.WithPrefix("server"),
|
||||
@@ -253,7 +253,7 @@ var _ = Describe("Crypto Setup TLS", func() {
|
||||
&wire.TransportParameters{StatelessResetToken: &token},
|
||||
NewMockHandshakeRunner(mockCtrl),
|
||||
serverConf,
|
||||
nil,
|
||||
false,
|
||||
&utils.RTTStats{},
|
||||
nil,
|
||||
utils.DefaultLogger.WithPrefix("server"),
|
||||
@@ -378,10 +378,6 @@ var _ = Describe("Crypto Setup TLS", func() {
|
||||
protocol.Version1,
|
||||
)
|
||||
|
||||
var allow0RTT func() bool
|
||||
if enable0RTT {
|
||||
allow0RTT = func() bool { return true }
|
||||
}
|
||||
var sHandshakeComplete bool
|
||||
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
|
||||
sErrChan := make(chan error, 1)
|
||||
@@ -402,7 +398,7 @@ var _ = Describe("Crypto Setup TLS", func() {
|
||||
serverTransportParameters,
|
||||
sRunner,
|
||||
serverConf,
|
||||
allow0RTT,
|
||||
enable0RTT,
|
||||
serverRTTStats,
|
||||
nil,
|
||||
utils.DefaultLogger.WithPrefix("server"),
|
||||
@@ -541,7 +537,7 @@ var _ = Describe("Crypto Setup TLS", func() {
|
||||
sTransportParameters,
|
||||
sRunner,
|
||||
serverConf,
|
||||
nil,
|
||||
false,
|
||||
&utils.RTTStats{},
|
||||
nil,
|
||||
utils.DefaultLogger.WithPrefix("server"),
|
||||
@@ -596,7 +592,7 @@ var _ = Describe("Crypto Setup TLS", func() {
|
||||
&wire.TransportParameters{ActiveConnectionIDLimit: 2, StatelessResetToken: &token},
|
||||
sRunner,
|
||||
serverConf,
|
||||
nil,
|
||||
false,
|
||||
&utils.RTTStats{},
|
||||
nil,
|
||||
utils.DefaultLogger.WithPrefix("server"),
|
||||
@@ -655,7 +651,7 @@ var _ = Describe("Crypto Setup TLS", func() {
|
||||
&wire.TransportParameters{ActiveConnectionIDLimit: 2, StatelessResetToken: &token},
|
||||
sRunner,
|
||||
serverConf,
|
||||
nil,
|
||||
false,
|
||||
&utils.RTTStats{},
|
||||
nil,
|
||||
utils.DefaultLogger.WithPrefix("server"),
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
package mocklogging
|
||||
|
||||
import (
|
||||
context "context"
|
||||
net "net"
|
||||
reflect "reflect"
|
||||
|
||||
@@ -73,17 +72,3 @@ func (mr *MockTracerMockRecorder) SentVersionNegotiationPacket(arg0, arg1, arg2,
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentVersionNegotiationPacket", reflect.TypeOf((*MockTracer)(nil).SentVersionNegotiationPacket), arg0, arg1, arg2, arg3)
|
||||
}
|
||||
|
||||
// TracerForConnection mocks base method.
|
||||
func (m *MockTracer) TracerForConnection(arg0 context.Context, arg1 protocol.Perspective, arg2 protocol.ConnectionID) logging.ConnectionTracer {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "TracerForConnection", arg0, arg1, arg2)
|
||||
ret0, _ := ret[0].(logging.ConnectionTracer)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// TracerForConnection indicates an expected call of TracerForConnection.
|
||||
func (mr *MockTracerMockRecorder) TracerForConnection(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TracerForConnection", reflect.TypeOf((*MockTracer)(nil).TracerForConnection), arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
"github.com/quic-go/quic-go/internal/qtls"
|
||||
"github.com/quic-go/quic-go/interop/http09"
|
||||
"github.com/quic-go/quic-go/interop/utils"
|
||||
"github.com/quic-go/quic-go/qlog"
|
||||
)
|
||||
|
||||
var errUnsupported = errors.New("unsupported test case")
|
||||
@@ -65,11 +64,7 @@ func runTestcase(testcase string) error {
|
||||
flag.Parse()
|
||||
urls := flag.Args()
|
||||
|
||||
getLogWriter, err := utils.GetQLOGWriter()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
quicConf := &quic.Config{Tracer: qlog.NewTracer(getLogWriter)}
|
||||
quicConf := &quic.Config{Tracer: utils.NewQLOGConnectionTracer}
|
||||
|
||||
if testcase == "http3" {
|
||||
r := &http3.RoundTripper{
|
||||
|
||||
@@ -13,7 +13,6 @@ import (
|
||||
"github.com/quic-go/quic-go/internal/qtls"
|
||||
"github.com/quic-go/quic-go/interop/http09"
|
||||
"github.com/quic-go/quic-go/interop/utils"
|
||||
"github.com/quic-go/quic-go/qlog"
|
||||
)
|
||||
|
||||
var tlsConf *tls.Config
|
||||
@@ -38,15 +37,10 @@ func main() {
|
||||
|
||||
testcase := os.Getenv("TESTCASE")
|
||||
|
||||
getLogWriter, err := utils.GetQLOGWriter()
|
||||
if err != nil {
|
||||
fmt.Println(err.Error())
|
||||
os.Exit(1)
|
||||
}
|
||||
// a quic.Config that doesn't do a Retry
|
||||
quicConf := &quic.Config{
|
||||
RequireAddressValidation: func(net.Addr) bool { return testcase == "retry" },
|
||||
Tracer: qlog.NewTracer(getLogWriter),
|
||||
Allow0RTT: testcase == "zerortt",
|
||||
Tracer: utils.NewQLOGConnectionTracer,
|
||||
}
|
||||
cert, err := tls.LoadX509KeyPair("/certs/cert.pem", "/certs/priv.key")
|
||||
if err != nil {
|
||||
@@ -59,10 +53,7 @@ func main() {
|
||||
}
|
||||
|
||||
switch testcase {
|
||||
case "zerortt":
|
||||
quicConf.Allow0RTT = func(net.Addr) bool { return true }
|
||||
fallthrough
|
||||
case "versionnegotiation", "handshake", "retry", "transfer", "resumption", "multiconnect":
|
||||
case "versionnegotiation", "handshake", "retry", "transfer", "resumption", "multiconnect", "zerortt":
|
||||
err = runHTTP09Server(quicConf)
|
||||
case "chacha20":
|
||||
reset := qtls.SetCipherSuite(tls.TLS_CHACHA20_POLY1305_SHA256)
|
||||
|
||||
@@ -2,14 +2,17 @@ package utils
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
"github.com/quic-go/quic-go/logging"
|
||||
"github.com/quic-go/quic-go/qlog"
|
||||
)
|
||||
|
||||
// GetSSLKeyLog creates a file for the TLS key log
|
||||
@@ -25,25 +28,23 @@ func GetSSLKeyLog() (io.WriteCloser, error) {
|
||||
return f, nil
|
||||
}
|
||||
|
||||
// GetQLOGWriter creates the QLOGDIR and returns the GetLogWriter callback
|
||||
func GetQLOGWriter() (func(perspective logging.Perspective, connID []byte) io.WriteCloser, error) {
|
||||
// NewQLOGConnectionTracer create a qlog file in QLOGDIR
|
||||
func NewQLOGConnectionTracer(_ context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer {
|
||||
qlogDir := os.Getenv("QLOGDIR")
|
||||
if len(qlogDir) == 0 {
|
||||
return nil, nil
|
||||
return nil
|
||||
}
|
||||
if _, err := os.Stat(qlogDir); os.IsNotExist(err) {
|
||||
if err := os.MkdirAll(qlogDir, 0o666); err != nil {
|
||||
return nil, fmt.Errorf("failed to create qlog dir %s: %s", qlogDir, err.Error())
|
||||
log.Fatalf("failed to create qlog dir %s: %v", qlogDir, err)
|
||||
}
|
||||
}
|
||||
return func(_ logging.Perspective, connID []byte) io.WriteCloser {
|
||||
path := fmt.Sprintf("%s/%x.qlog", strings.TrimRight(qlogDir, "/"), connID)
|
||||
f, err := os.Create(path)
|
||||
if err != nil {
|
||||
log.Printf("Failed to create qlog file %s: %s", path, err.Error())
|
||||
return nil
|
||||
}
|
||||
log.Printf("Created qlog file: %s\n", path)
|
||||
return utils.NewBufferedWriteCloser(bufio.NewWriter(f), f)
|
||||
}, nil
|
||||
path := fmt.Sprintf("%s/%x.qlog", strings.TrimRight(qlogDir, "/"), connID)
|
||||
f, err := os.Create(path)
|
||||
if err != nil {
|
||||
log.Printf("Failed to create qlog file %s: %s", path, err.Error())
|
||||
return nil
|
||||
}
|
||||
log.Printf("Created qlog file: %s\n", path)
|
||||
return qlog.NewConnectionTracer(utils.NewBufferedWriteCloser(bufio.NewWriter(f), f), p, connID)
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
@@ -101,12 +100,6 @@ type ShortHeader struct {
|
||||
|
||||
// A Tracer traces events.
|
||||
type Tracer interface {
|
||||
// TracerForConnection requests a new tracer for a connection.
|
||||
// The ODCID is the original destination connection ID:
|
||||
// The destination connection ID that the client used on the first Initial packet it sent on this connection.
|
||||
// If nil is returned, tracing will be disabled for this connection.
|
||||
TracerForConnection(ctx context.Context, p Perspective, odcid ConnectionID) ConnectionTracer
|
||||
|
||||
SentPacket(net.Addr, *Header, ByteCount, []Frame)
|
||||
SentVersionNegotiationPacket(_ net.Addr, dest, src ArbitraryLenConnectionID, _ []VersionNumber)
|
||||
DroppedPacket(net.Addr, PacketType, ByteCount, PacketDropReason)
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
context "context"
|
||||
net "net"
|
||||
reflect "reflect"
|
||||
|
||||
@@ -72,17 +71,3 @@ func (mr *MockTracerMockRecorder) SentVersionNegotiationPacket(arg0, arg1, arg2,
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentVersionNegotiationPacket", reflect.TypeOf((*MockTracer)(nil).SentVersionNegotiationPacket), arg0, arg1, arg2, arg3)
|
||||
}
|
||||
|
||||
// TracerForConnection mocks base method.
|
||||
func (m *MockTracer) TracerForConnection(arg0 context.Context, arg1 protocol.Perspective, arg2 protocol.ConnectionID) ConnectionTracer {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "TracerForConnection", arg0, arg1, arg2)
|
||||
ret0, _ := ret[0].(ConnectionTracer)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// TracerForConnection indicates an expected call of TracerForConnection.
|
||||
func (mr *MockTracerMockRecorder) TracerForConnection(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TracerForConnection", reflect.TypeOf((*MockTracer)(nil).TracerForConnection), arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
@@ -23,16 +22,6 @@ func NewMultiplexedTracer(tracers ...Tracer) Tracer {
|
||||
return &tracerMultiplexer{tracers}
|
||||
}
|
||||
|
||||
func (m *tracerMultiplexer) TracerForConnection(ctx context.Context, p Perspective, odcid ConnectionID) ConnectionTracer {
|
||||
var connTracers []ConnectionTracer
|
||||
for _, t := range m.tracers {
|
||||
if ct := t.TracerForConnection(ctx, p, odcid); ct != nil {
|
||||
connTracers = append(connTracers, ct)
|
||||
}
|
||||
}
|
||||
return NewMultiplexedConnectionTracer(connTracers...)
|
||||
}
|
||||
|
||||
func (m *tracerMultiplexer) SentPacket(remote net.Addr, hdr *Header, size ByteCount, frames []Frame) {
|
||||
for _, t := range m.tracers {
|
||||
t.SentPacket(remote, hdr, size, frames)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"time"
|
||||
@@ -37,46 +36,6 @@ var _ = Describe("Tracing", func() {
|
||||
tracer = NewMultiplexedTracer(tr1, tr2)
|
||||
})
|
||||
|
||||
It("multiplexes the TracerForConnection call", func() {
|
||||
ctx := context.Background()
|
||||
connID := protocol.ParseConnectionID([]byte{1, 2, 3})
|
||||
tr1.EXPECT().TracerForConnection(ctx, PerspectiveClient, connID)
|
||||
tr2.EXPECT().TracerForConnection(ctx, PerspectiveClient, connID)
|
||||
tracer.TracerForConnection(ctx, PerspectiveClient, connID)
|
||||
})
|
||||
|
||||
It("uses multiple connection tracers", func() {
|
||||
ctx := context.Background()
|
||||
ctr1 := NewMockConnectionTracer(mockCtrl)
|
||||
ctr2 := NewMockConnectionTracer(mockCtrl)
|
||||
connID := protocol.ParseConnectionID([]byte{1, 2, 3})
|
||||
tr1.EXPECT().TracerForConnection(ctx, PerspectiveServer, connID).Return(ctr1)
|
||||
tr2.EXPECT().TracerForConnection(ctx, PerspectiveServer, connID).Return(ctr2)
|
||||
tr := tracer.TracerForConnection(ctx, PerspectiveServer, connID)
|
||||
ctr1.EXPECT().LossTimerCanceled()
|
||||
ctr2.EXPECT().LossTimerCanceled()
|
||||
tr.LossTimerCanceled()
|
||||
})
|
||||
|
||||
It("handles tracers that return a nil ConnectionTracer", func() {
|
||||
ctx := context.Background()
|
||||
ctr1 := NewMockConnectionTracer(mockCtrl)
|
||||
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
|
||||
tr1.EXPECT().TracerForConnection(ctx, PerspectiveServer, connID).Return(ctr1)
|
||||
tr2.EXPECT().TracerForConnection(ctx, PerspectiveServer, connID)
|
||||
tr := tracer.TracerForConnection(ctx, PerspectiveServer, connID)
|
||||
ctr1.EXPECT().LossTimerCanceled()
|
||||
tr.LossTimerCanceled()
|
||||
})
|
||||
|
||||
It("returns nil when all tracers return a nil ConnectionTracer", func() {
|
||||
ctx := context.Background()
|
||||
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5})
|
||||
tr1.EXPECT().TracerForConnection(ctx, PerspectiveClient, connID)
|
||||
tr2.EXPECT().TracerForConnection(ctx, PerspectiveClient, connID)
|
||||
Expect(tracer.TracerForConnection(ctx, PerspectiveClient, connID)).To(BeNil())
|
||||
})
|
||||
|
||||
It("traces the PacketSent event", func() {
|
||||
remote := &net.UDPAddr{IP: net.IPv4(4, 3, 2, 1)}
|
||||
hdr := &Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3})}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
@@ -12,9 +11,6 @@ type NullTracer struct{}
|
||||
|
||||
var _ Tracer = &NullTracer{}
|
||||
|
||||
func (n NullTracer) TracerForConnection(context.Context, Perspective, ConnectionID) ConnectionTracer {
|
||||
return NullConnectionTracer{}
|
||||
}
|
||||
func (n NullTracer) SentPacket(net.Addr, *Header, ByteCount, []Frame) {}
|
||||
func (n NullTracer) SentVersionNegotiationPacket(_ net.Addr, dest, src ArbitraryLenConnectionID, _ []VersionNumber) {
|
||||
}
|
||||
|
||||
@@ -1,65 +0,0 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/quic-go/quic-go (interfaces: Multiplexer)
|
||||
|
||||
// Package quic is a generated GoMock package.
|
||||
package quic
|
||||
|
||||
import (
|
||||
net "net"
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
logging "github.com/quic-go/quic-go/logging"
|
||||
)
|
||||
|
||||
// MockMultiplexer is a mock of Multiplexer interface.
|
||||
type MockMultiplexer struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockMultiplexerMockRecorder
|
||||
}
|
||||
|
||||
// MockMultiplexerMockRecorder is the mock recorder for MockMultiplexer.
|
||||
type MockMultiplexerMockRecorder struct {
|
||||
mock *MockMultiplexer
|
||||
}
|
||||
|
||||
// NewMockMultiplexer creates a new mock instance.
|
||||
func NewMockMultiplexer(ctrl *gomock.Controller) *MockMultiplexer {
|
||||
mock := &MockMultiplexer{ctrl: ctrl}
|
||||
mock.recorder = &MockMultiplexerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockMultiplexer) EXPECT() *MockMultiplexerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// AddConn mocks base method.
|
||||
func (m *MockMultiplexer) AddConn(arg0 net.PacketConn, arg1 int, arg2 *StatelessResetKey, arg3 logging.Tracer) (packetHandlerManager, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AddConn", arg0, arg1, arg2, arg3)
|
||||
ret0, _ := ret[0].(packetHandlerManager)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// AddConn indicates an expected call of AddConn.
|
||||
func (mr *MockMultiplexerMockRecorder) AddConn(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddConn", reflect.TypeOf((*MockMultiplexer)(nil).AddConn), arg0, arg1, arg2, arg3)
|
||||
}
|
||||
|
||||
// RemoveConn mocks base method.
|
||||
func (m *MockMultiplexer) RemoveConn(arg0 indexableConn) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "RemoveConn", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// RemoveConn indicates an expected call of RemoveConn.
|
||||
func (mr *MockMultiplexerMockRecorder) RemoveConn(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveConn", reflect.TypeOf((*MockMultiplexer)(nil).RemoveConn), arg0)
|
||||
}
|
||||
@@ -61,7 +61,7 @@ func (mr *MockPacketHandlerManagerMockRecorder) AddResetToken(arg0, arg1 interfa
|
||||
}
|
||||
|
||||
// AddWithConnID mocks base method.
|
||||
func (m *MockPacketHandlerManager) AddWithConnID(arg0, arg1 protocol.ConnectionID, arg2 func() packetHandler) bool {
|
||||
func (m *MockPacketHandlerManager) AddWithConnID(arg0, arg1 protocol.ConnectionID, arg2 func() (packetHandler, bool)) bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AddWithConnID", arg0, arg1, arg2)
|
||||
ret0, _ := ret[0].(bool)
|
||||
@@ -74,6 +74,18 @@ func (mr *MockPacketHandlerManagerMockRecorder) AddWithConnID(arg0, arg1, arg2 i
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddWithConnID", reflect.TypeOf((*MockPacketHandlerManager)(nil).AddWithConnID), arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// Close mocks base method.
|
||||
func (m *MockPacketHandlerManager) Close(arg0 error) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "Close", arg0)
|
||||
}
|
||||
|
||||
// Close indicates an expected call of Close.
|
||||
func (mr *MockPacketHandlerManagerMockRecorder) Close(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockPacketHandlerManager)(nil).Close), arg0)
|
||||
}
|
||||
|
||||
// CloseServer mocks base method.
|
||||
func (m *MockPacketHandlerManager) CloseServer() {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -86,20 +98,6 @@ func (mr *MockPacketHandlerManagerMockRecorder) CloseServer() *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseServer", reflect.TypeOf((*MockPacketHandlerManager)(nil).CloseServer))
|
||||
}
|
||||
|
||||
// Destroy mocks base method.
|
||||
func (m *MockPacketHandlerManager) Destroy() error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Destroy")
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Destroy indicates an expected call of Destroy.
|
||||
func (mr *MockPacketHandlerManagerMockRecorder) Destroy() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Destroy", reflect.TypeOf((*MockPacketHandlerManager)(nil).Destroy))
|
||||
}
|
||||
|
||||
// Get mocks base method.
|
||||
func (m *MockPacketHandlerManager) Get(arg0 protocol.ConnectionID) (packetHandler, bool) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -115,6 +113,21 @@ func (mr *MockPacketHandlerManagerMockRecorder) Get(arg0 interface{}) *gomock.Ca
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockPacketHandlerManager)(nil).Get), arg0)
|
||||
}
|
||||
|
||||
// GetByResetToken mocks base method.
|
||||
func (m *MockPacketHandlerManager) GetByResetToken(arg0 protocol.StatelessResetToken) (packetHandler, bool) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetByResetToken", arg0)
|
||||
ret0, _ := ret[0].(packetHandler)
|
||||
ret1, _ := ret[1].(bool)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetByResetToken indicates an expected call of GetByResetToken.
|
||||
func (mr *MockPacketHandlerManagerMockRecorder) GetByResetToken(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetByResetToken", reflect.TypeOf((*MockPacketHandlerManager)(nil).GetByResetToken), arg0)
|
||||
}
|
||||
|
||||
// GetStatelessResetToken mocks base method.
|
||||
func (m *MockPacketHandlerManager) GetStatelessResetToken(arg0 protocol.ConnectionID) protocol.StatelessResetToken {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -176,15 +189,3 @@ func (mr *MockPacketHandlerManagerMockRecorder) Retire(arg0 interface{}) *gomock
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Retire", reflect.TypeOf((*MockPacketHandlerManager)(nil).Retire), arg0)
|
||||
}
|
||||
|
||||
// SetServer mocks base method.
|
||||
func (m *MockPacketHandlerManager) SetServer(arg0 unknownPacketHandler) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "SetServer", arg0)
|
||||
}
|
||||
|
||||
// SetServer indicates an expected call of SetServer.
|
||||
func (mr *MockPacketHandlerManagerMockRecorder) SetServer(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetServer", reflect.TypeOf((*MockPacketHandlerManager)(nil).SetServer), arg0)
|
||||
}
|
||||
|
||||
@@ -65,9 +65,6 @@ type UnknownPacketHandler = unknownPacketHandler
|
||||
//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_packet_handler_manager_test.go github.com/quic-go/quic-go PacketHandlerManager"
|
||||
type PacketHandlerManager = packetHandlerManager
|
||||
|
||||
//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_multiplexer_test.go github.com/quic-go/quic-go Multiplexer"
|
||||
type Multiplexer = multiplexer
|
||||
|
||||
// Need to use source mode for the batchConn, since reflect mode follows type aliases.
|
||||
// See https://github.com/golang/mock/issues/244 for details.
|
||||
//
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"sync"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
"github.com/quic-go/quic-go/logging"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -14,30 +13,19 @@ var (
|
||||
connMuxer multiplexer
|
||||
)
|
||||
|
||||
type indexableConn interface {
|
||||
LocalAddr() net.Addr
|
||||
}
|
||||
type indexableConn interface{ LocalAddr() net.Addr }
|
||||
|
||||
type multiplexer interface {
|
||||
AddConn(c net.PacketConn, connIDLen int, statelessResetKey *StatelessResetKey, tracer logging.Tracer) (packetHandlerManager, error)
|
||||
AddConn(conn indexableConn)
|
||||
RemoveConn(indexableConn) error
|
||||
}
|
||||
|
||||
type connManager struct {
|
||||
connIDLen int
|
||||
statelessResetKey *StatelessResetKey
|
||||
tracer logging.Tracer
|
||||
manager packetHandlerManager
|
||||
}
|
||||
|
||||
// The connMultiplexer listens on multiple net.PacketConns and dispatches
|
||||
// incoming packets to the connection handler.
|
||||
type connMultiplexer struct {
|
||||
mutex sync.Mutex
|
||||
|
||||
conns map[string] /* LocalAddr().String() */ connManager
|
||||
newPacketHandlerManager func(net.PacketConn, int, *StatelessResetKey, logging.Tracer, utils.Logger) (packetHandlerManager, error) // so it can be replaced in the tests
|
||||
|
||||
conns map[string] /* LocalAddr().String() */ indexableConn
|
||||
logger utils.Logger
|
||||
}
|
||||
|
||||
@@ -46,57 +34,38 @@ var _ multiplexer = &connMultiplexer{}
|
||||
func getMultiplexer() multiplexer {
|
||||
connMuxerOnce.Do(func() {
|
||||
connMuxer = &connMultiplexer{
|
||||
conns: make(map[string]connManager),
|
||||
logger: utils.DefaultLogger.WithPrefix("muxer"),
|
||||
newPacketHandlerManager: newPacketHandlerMap,
|
||||
conns: make(map[string]indexableConn),
|
||||
logger: utils.DefaultLogger.WithPrefix("muxer"),
|
||||
}
|
||||
})
|
||||
return connMuxer
|
||||
}
|
||||
|
||||
func (m *connMultiplexer) AddConn(
|
||||
c net.PacketConn,
|
||||
connIDLen int,
|
||||
statelessResetKey *StatelessResetKey,
|
||||
tracer logging.Tracer,
|
||||
) (packetHandlerManager, error) {
|
||||
func (m *connMultiplexer) index(addr net.Addr) string {
|
||||
return addr.Network() + " " + addr.String()
|
||||
}
|
||||
|
||||
func (m *connMultiplexer) AddConn(c indexableConn) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
addr := c.LocalAddr()
|
||||
connIndex := addr.Network() + " " + addr.String()
|
||||
connIndex := m.index(c.LocalAddr())
|
||||
p, ok := m.conns[connIndex]
|
||||
if !ok {
|
||||
manager, err := m.newPacketHandlerManager(c, connIDLen, statelessResetKey, tracer, m.logger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
p = connManager{
|
||||
connIDLen: connIDLen,
|
||||
statelessResetKey: statelessResetKey,
|
||||
manager: manager,
|
||||
tracer: tracer,
|
||||
}
|
||||
m.conns[connIndex] = p
|
||||
} else {
|
||||
if p.connIDLen != connIDLen {
|
||||
return nil, fmt.Errorf("cannot use %d byte connection IDs on a connection that is already using %d byte connction IDs", connIDLen, p.connIDLen)
|
||||
}
|
||||
if statelessResetKey != nil && p.statelessResetKey != statelessResetKey {
|
||||
return nil, fmt.Errorf("cannot use different stateless reset keys on the same packet conn")
|
||||
}
|
||||
if tracer != p.tracer {
|
||||
return nil, fmt.Errorf("cannot use different tracers on the same packet conn")
|
||||
}
|
||||
if ok {
|
||||
// Panics if we're already listening on this connection.
|
||||
// This is a safeguard because we're introducing a breaking API change, see
|
||||
// https://github.com/quic-go/quic-go/issues/3727 for details.
|
||||
// We'll remove this at a later time, when most users of the library have made the switch.
|
||||
panic("connection already exists") // TODO: write a nice message
|
||||
}
|
||||
return p.manager, nil
|
||||
m.conns[connIndex] = p
|
||||
}
|
||||
|
||||
func (m *connMultiplexer) RemoveConn(c indexableConn) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
connIndex := c.LocalAddr().Network() + " " + c.LocalAddr().String()
|
||||
connIndex := m.index(c.LocalAddr())
|
||||
if _, ok := m.conns[connIndex]; !ok {
|
||||
return fmt.Errorf("cannote remove connection, connection is unknown")
|
||||
}
|
||||
|
||||
@@ -3,71 +3,24 @@ package quic
|
||||
import (
|
||||
"net"
|
||||
|
||||
mocklogging "github.com/quic-go/quic-go/internal/mocks/logging"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
type testConn struct {
|
||||
counter int
|
||||
net.PacketConn
|
||||
}
|
||||
|
||||
var _ = Describe("Multiplexer", func() {
|
||||
It("adds a new packet conn ", func() {
|
||||
conn := NewMockPacketConn(mockCtrl)
|
||||
conn.EXPECT().ReadFrom(gomock.Any()).Do(func([]byte) { <-(make(chan struct{})) }).MaxTimes(1)
|
||||
conn.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234})
|
||||
_, err := getMultiplexer().AddConn(conn, 8, nil, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
It("adds new packet conns", func() {
|
||||
conn1 := NewMockPacketConn(mockCtrl)
|
||||
conn1.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234})
|
||||
getMultiplexer().AddConn(conn1)
|
||||
conn2 := NewMockPacketConn(mockCtrl)
|
||||
conn2.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1235})
|
||||
getMultiplexer().AddConn(conn2)
|
||||
})
|
||||
|
||||
It("recognizes when the same connection is added twice", func() {
|
||||
srk := &StatelessResetKey{'f', 'o', 'o', 'b', 'a', 'r'}
|
||||
pconn := NewMockPacketConn(mockCtrl)
|
||||
pconn.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 4321}).Times(2)
|
||||
pconn.EXPECT().ReadFrom(gomock.Any()).Do(func([]byte) { <-(make(chan struct{})) }).MaxTimes(1)
|
||||
conn := testConn{PacketConn: pconn}
|
||||
tracer := mocklogging.NewMockTracer(mockCtrl)
|
||||
_, err := getMultiplexer().AddConn(conn, 8, srk, tracer)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
conn.counter++
|
||||
_, err = getMultiplexer().AddConn(conn, 8, srk, tracer)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(getMultiplexer().(*connMultiplexer).conns).To(HaveLen(1))
|
||||
})
|
||||
|
||||
It("errors when adding an existing conn with a different connection ID length", func() {
|
||||
It("panics when the same connection is added twice", func() {
|
||||
conn := NewMockPacketConn(mockCtrl)
|
||||
conn.EXPECT().ReadFrom(gomock.Any()).Do(func([]byte) { <-(make(chan struct{})) }).MaxTimes(1)
|
||||
conn.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}).Times(2)
|
||||
_, err := getMultiplexer().AddConn(conn, 5, nil, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = getMultiplexer().AddConn(conn, 6, nil, nil)
|
||||
Expect(err).To(MatchError("cannot use 6 byte connection IDs on a connection that is already using 5 byte connction IDs"))
|
||||
})
|
||||
|
||||
It("errors when adding an existing conn with a different stateless rest key", func() {
|
||||
srk1 := &StatelessResetKey{'f', 'o', 'o'}
|
||||
srk2 := &StatelessResetKey{'b', 'a', 'r'}
|
||||
conn := NewMockPacketConn(mockCtrl)
|
||||
conn.EXPECT().ReadFrom(gomock.Any()).Do(func([]byte) { <-(make(chan struct{})) }).MaxTimes(1)
|
||||
conn.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}).Times(2)
|
||||
_, err := getMultiplexer().AddConn(conn, 7, srk1, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = getMultiplexer().AddConn(conn, 7, srk2, nil)
|
||||
Expect(err).To(MatchError("cannot use different stateless reset keys on the same packet conn"))
|
||||
})
|
||||
|
||||
It("errors when adding an existing conn with different tracers", func() {
|
||||
conn := NewMockPacketConn(mockCtrl)
|
||||
conn.EXPECT().ReadFrom(gomock.Any()).Do(func([]byte) { <-(make(chan struct{})) }).MaxTimes(1)
|
||||
conn.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}).Times(2)
|
||||
_, err := getMultiplexer().AddConn(conn, 7, nil, mocklogging.NewMockTracer(mockCtrl))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = getMultiplexer().AddConn(conn, 7, nil, mocklogging.NewMockTracer(mockCtrl))
|
||||
Expect(err).To(MatchError("cannot use different tracers on the same packet conn"))
|
||||
conn.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 4321}).Times(2)
|
||||
getMultiplexer().AddConn(conn)
|
||||
Expect(func() { getMultiplexer().AddConn(conn) }).To(Panic())
|
||||
})
|
||||
})
|
||||
|
||||
@@ -5,28 +5,22 @@ import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
"github.com/quic-go/quic-go/internal/wire"
|
||||
"github.com/quic-go/quic-go/logging"
|
||||
)
|
||||
|
||||
// rawConn is a connection that allow reading of a receivedPacket.
|
||||
// rawConn is a connection that allow reading of a receivedPackeh.
|
||||
type rawConn interface {
|
||||
ReadPacket() (*receivedPacket, error)
|
||||
WritePacket(b []byte, addr net.Addr, oob []byte) (int, error)
|
||||
LocalAddr() net.Addr
|
||||
SetReadDeadline(time.Time) error
|
||||
io.Closer
|
||||
}
|
||||
|
||||
@@ -36,113 +30,49 @@ type closePacket struct {
|
||||
info *packetInfo
|
||||
}
|
||||
|
||||
// The packetHandlerMap stores packetHandlers, identified by connection ID.
|
||||
// It is used:
|
||||
// * by the server to store connections
|
||||
// * when multiplexing outgoing connections to store clients
|
||||
type unknownPacketHandler interface {
|
||||
handlePacket(*receivedPacket)
|
||||
setCloseError(error)
|
||||
}
|
||||
|
||||
var errListenerAlreadySet = errors.New("listener already set")
|
||||
|
||||
type packetHandlerMap struct {
|
||||
mutex sync.Mutex
|
||||
|
||||
conn rawConn
|
||||
connIDLen int
|
||||
|
||||
closeQueue chan closePacket
|
||||
|
||||
mutex sync.Mutex
|
||||
handlers map[protocol.ConnectionID]packetHandler
|
||||
resetTokens map[protocol.StatelessResetToken] /* stateless reset token */ packetHandler
|
||||
server unknownPacketHandler
|
||||
|
||||
listening chan struct{} // is closed when listen returns
|
||||
closed bool
|
||||
closeChan chan struct{}
|
||||
|
||||
enqueueClosePacket func(closePacket)
|
||||
|
||||
deleteRetiredConnsAfter time.Duration
|
||||
|
||||
statelessResetEnabled bool
|
||||
statelessResetMutex sync.Mutex
|
||||
statelessResetHasher hash.Hash
|
||||
statelessResetMutex sync.Mutex
|
||||
statelessResetHasher hash.Hash
|
||||
|
||||
tracer logging.Tracer
|
||||
logger utils.Logger
|
||||
}
|
||||
|
||||
var _ packetHandlerManager = &packetHandlerMap{}
|
||||
|
||||
func setReceiveBuffer(c net.PacketConn, logger utils.Logger) error {
|
||||
conn, ok := c.(interface{ SetReadBuffer(int) error })
|
||||
if !ok {
|
||||
return errors.New("connection doesn't allow setting of receive buffer size. Not a *net.UDPConn?")
|
||||
}
|
||||
size, err := inspectReadBuffer(c)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to determine receive buffer size: %w", err)
|
||||
}
|
||||
if size >= protocol.DesiredReceiveBufferSize {
|
||||
logger.Debugf("Conn has receive buffer of %d kiB (wanted: at least %d kiB)", size/1024, protocol.DesiredReceiveBufferSize/1024)
|
||||
return nil
|
||||
}
|
||||
if err := conn.SetReadBuffer(protocol.DesiredReceiveBufferSize); err != nil {
|
||||
return fmt.Errorf("failed to increase receive buffer size: %w", err)
|
||||
}
|
||||
newSize, err := inspectReadBuffer(c)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to determine receive buffer size: %w", err)
|
||||
}
|
||||
if newSize == size {
|
||||
return fmt.Errorf("failed to increase receive buffer size (wanted: %d kiB, got %d kiB)", protocol.DesiredReceiveBufferSize/1024, newSize/1024)
|
||||
}
|
||||
if newSize < protocol.DesiredReceiveBufferSize {
|
||||
return fmt.Errorf("failed to sufficiently increase receive buffer size (was: %d kiB, wanted: %d kiB, got: %d kiB)", size/1024, protocol.DesiredReceiveBufferSize/1024, newSize/1024)
|
||||
}
|
||||
logger.Debugf("Increased receive buffer size to %d kiB", newSize/1024)
|
||||
return nil
|
||||
}
|
||||
|
||||
// only print warnings about the UDP receive buffer size once
|
||||
var receiveBufferWarningOnce sync.Once
|
||||
|
||||
func newPacketHandlerMap(
|
||||
c net.PacketConn,
|
||||
connIDLen int,
|
||||
statelessResetKey *StatelessResetKey,
|
||||
tracer logging.Tracer,
|
||||
logger utils.Logger,
|
||||
) (packetHandlerManager, error) {
|
||||
if err := setReceiveBuffer(c, logger); err != nil {
|
||||
if !strings.Contains(err.Error(), "use of closed network connection") {
|
||||
receiveBufferWarningOnce.Do(func() {
|
||||
if disable, _ := strconv.ParseBool(os.Getenv("QUIC_GO_DISABLE_RECEIVE_BUFFER_WARNING")); disable {
|
||||
return
|
||||
}
|
||||
log.Printf("%s. See https://github.com/quic-go/quic-go/wiki/UDP-Receive-Buffer-Size for details.", err)
|
||||
})
|
||||
}
|
||||
}
|
||||
conn, err := wrapConn(c)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m := &packetHandlerMap{
|
||||
conn: conn,
|
||||
connIDLen: connIDLen,
|
||||
listening: make(chan struct{}),
|
||||
func newPacketHandlerMap(key *StatelessResetKey, enqueueClosePacket func(closePacket), logger utils.Logger) *packetHandlerMap {
|
||||
h := &packetHandlerMap{
|
||||
closeChan: make(chan struct{}),
|
||||
handlers: make(map[protocol.ConnectionID]packetHandler),
|
||||
resetTokens: make(map[protocol.StatelessResetToken]packetHandler),
|
||||
deleteRetiredConnsAfter: protocol.RetiredConnectionIDDeleteTimeout,
|
||||
closeQueue: make(chan closePacket, 4),
|
||||
statelessResetEnabled: statelessResetKey != nil,
|
||||
tracer: tracer,
|
||||
enqueueClosePacket: enqueueClosePacket,
|
||||
logger: logger,
|
||||
}
|
||||
if m.statelessResetEnabled {
|
||||
m.statelessResetHasher = hmac.New(sha256.New, statelessResetKey[:])
|
||||
if key != nil {
|
||||
h.statelessResetHasher = hmac.New(sha256.New, key[:])
|
||||
}
|
||||
go m.listen()
|
||||
go m.runCloseQueue()
|
||||
|
||||
if logger.Debug() {
|
||||
go m.logUsage()
|
||||
if h.logger.Debug() {
|
||||
go h.logUsage()
|
||||
}
|
||||
return m, nil
|
||||
return h
|
||||
}
|
||||
|
||||
func (h *packetHandlerMap) logUsage() {
|
||||
@@ -150,7 +80,7 @@ func (h *packetHandlerMap) logUsage() {
|
||||
var printedZero bool
|
||||
for {
|
||||
select {
|
||||
case <-h.listening:
|
||||
case <-h.closeChan:
|
||||
return
|
||||
case <-ticker.C:
|
||||
}
|
||||
@@ -192,7 +122,7 @@ func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler)
|
||||
return true
|
||||
}
|
||||
|
||||
func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.ConnectionID, fn func() packetHandler) bool {
|
||||
func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
|
||||
@@ -200,7 +130,10 @@ func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.Co
|
||||
h.logger.Debugf("Not adding connection ID %s for a new connection, as it already exists.", clientDestConnID)
|
||||
return false
|
||||
}
|
||||
conn := fn()
|
||||
conn, ok := fn()
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
h.handlers[clientDestConnID] = conn
|
||||
h.handlers[newConnID] = conn
|
||||
h.logger.Debugf("Adding connection IDs %s and %s for a new connection.", clientDestConnID, newConnID)
|
||||
@@ -233,12 +166,7 @@ func (h *packetHandlerMap) ReplaceWithClosed(ids []protocol.ConnectionID, pers p
|
||||
if connClosePacket != nil {
|
||||
handler = newClosedLocalConn(
|
||||
func(addr net.Addr, info *packetInfo) {
|
||||
select {
|
||||
case h.closeQueue <- closePacket{payload: connClosePacket, addr: addr, info: info}:
|
||||
default:
|
||||
// Oops, we're backlogged.
|
||||
// Just drop the packet, sending CONNECTION_CLOSE copies is best effort anyway.
|
||||
}
|
||||
h.enqueueClosePacket(closePacket{payload: connClosePacket, addr: addr, info: info})
|
||||
},
|
||||
pers,
|
||||
h.logger,
|
||||
@@ -265,17 +193,6 @@ func (h *packetHandlerMap) ReplaceWithClosed(ids []protocol.ConnectionID, pers p
|
||||
})
|
||||
}
|
||||
|
||||
func (h *packetHandlerMap) runCloseQueue() {
|
||||
for {
|
||||
select {
|
||||
case <-h.listening:
|
||||
return
|
||||
case p := <-h.closeQueue:
|
||||
h.conn.WritePacket(p.payload, p.addr, p.info.OOB())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *packetHandlerMap) AddResetToken(token protocol.StatelessResetToken, handler packetHandler) {
|
||||
h.mutex.Lock()
|
||||
h.resetTokens[token] = handler
|
||||
@@ -288,19 +205,16 @@ func (h *packetHandlerMap) RemoveResetToken(token protocol.StatelessResetToken)
|
||||
h.mutex.Unlock()
|
||||
}
|
||||
|
||||
func (h *packetHandlerMap) SetServer(s unknownPacketHandler) {
|
||||
func (h *packetHandlerMap) GetByResetToken(token protocol.StatelessResetToken) (packetHandler, bool) {
|
||||
h.mutex.Lock()
|
||||
h.server = s
|
||||
h.mutex.Unlock()
|
||||
defer h.mutex.Unlock()
|
||||
|
||||
handler, ok := h.resetTokens[token]
|
||||
return handler, ok
|
||||
}
|
||||
|
||||
func (h *packetHandlerMap) CloseServer() {
|
||||
h.mutex.Lock()
|
||||
if h.server == nil {
|
||||
h.mutex.Unlock()
|
||||
return
|
||||
}
|
||||
h.server = nil
|
||||
var wg sync.WaitGroup
|
||||
for _, handler := range h.handlers {
|
||||
if handler.getPerspective() == protocol.PerspectiveServer {
|
||||
@@ -316,23 +230,16 @@ func (h *packetHandlerMap) CloseServer() {
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// Destroy closes the underlying connection and waits until listen() has returned.
|
||||
// It does not close active connections.
|
||||
func (h *packetHandlerMap) Destroy() error {
|
||||
if err := h.conn.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
<-h.listening // wait until listening returns
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *packetHandlerMap) close(e error) error {
|
||||
func (h *packetHandlerMap) Close(e error) {
|
||||
h.mutex.Lock()
|
||||
|
||||
if h.closed {
|
||||
h.mutex.Unlock()
|
||||
return nil
|
||||
return
|
||||
}
|
||||
|
||||
close(h.closeChan)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for _, handler := range h.handlers {
|
||||
wg.Add(1)
|
||||
@@ -341,89 +248,14 @@ func (h *packetHandlerMap) close(e error) error {
|
||||
wg.Done()
|
||||
}(handler)
|
||||
}
|
||||
|
||||
if h.server != nil {
|
||||
h.server.setCloseError(e)
|
||||
}
|
||||
h.closed = true
|
||||
h.mutex.Unlock()
|
||||
wg.Wait()
|
||||
return getMultiplexer().RemoveConn(h.conn)
|
||||
}
|
||||
|
||||
func (h *packetHandlerMap) listen() {
|
||||
defer close(h.listening)
|
||||
for {
|
||||
p, err := h.conn.ReadPacket()
|
||||
//nolint:staticcheck // SA1019 ignore this!
|
||||
// TODO: This code is used to ignore wsa errors on Windows.
|
||||
// Since net.Error.Temporary is deprecated as of Go 1.18, we should find a better solution.
|
||||
// See https://github.com/quic-go/quic-go/issues/1737 for details.
|
||||
if nerr, ok := err.(net.Error); ok && nerr.Temporary() {
|
||||
h.logger.Debugf("Temporary error reading from conn: %w", err)
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
h.close(err)
|
||||
return
|
||||
}
|
||||
h.handlePacket(p)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *packetHandlerMap) handlePacket(p *receivedPacket) {
|
||||
connID, err := wire.ParseConnectionID(p.data, h.connIDLen)
|
||||
if err != nil {
|
||||
h.logger.Debugf("error parsing connection ID on packet from %s: %s", p.remoteAddr, err)
|
||||
if h.tracer != nil {
|
||||
h.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError)
|
||||
}
|
||||
p.buffer.MaybeRelease()
|
||||
return
|
||||
}
|
||||
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
|
||||
if isStatelessReset := h.maybeHandleStatelessReset(p.data); isStatelessReset {
|
||||
return
|
||||
}
|
||||
if handler, ok := h.handlers[connID]; ok {
|
||||
handler.handlePacket(p)
|
||||
return
|
||||
}
|
||||
if !wire.IsLongHeaderPacket(p.data[0]) {
|
||||
go h.maybeSendStatelessReset(p, connID)
|
||||
return
|
||||
}
|
||||
if h.server == nil { // no server set
|
||||
h.logger.Debugf("received a packet with an unexpected connection ID %s", connID)
|
||||
return
|
||||
}
|
||||
h.server.handlePacket(p)
|
||||
}
|
||||
|
||||
func (h *packetHandlerMap) maybeHandleStatelessReset(data []byte) bool {
|
||||
// stateless resets are always short header packets
|
||||
if wire.IsLongHeaderPacket(data[0]) {
|
||||
return false
|
||||
}
|
||||
if len(data) < 17 /* type byte + 16 bytes for the reset token */ {
|
||||
return false
|
||||
}
|
||||
|
||||
token := *(*protocol.StatelessResetToken)(data[len(data)-16:])
|
||||
if sess, ok := h.resetTokens[token]; ok {
|
||||
h.logger.Debugf("Received a stateless reset with token %#x. Closing connection.", token)
|
||||
go sess.destroy(&StatelessResetError{Token: token})
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (h *packetHandlerMap) GetStatelessResetToken(connID protocol.ConnectionID) protocol.StatelessResetToken {
|
||||
var token protocol.StatelessResetToken
|
||||
if !h.statelessResetEnabled {
|
||||
if h.statelessResetHasher == nil {
|
||||
// Return a random stateless reset token.
|
||||
// This token will be sent in the server's transport parameters.
|
||||
// By using a random token, an off-path attacker won't be able to disrupt the connection.
|
||||
@@ -437,24 +269,3 @@ func (h *packetHandlerMap) GetStatelessResetToken(connID protocol.ConnectionID)
|
||||
h.statelessResetMutex.Unlock()
|
||||
return token
|
||||
}
|
||||
|
||||
func (h *packetHandlerMap) maybeSendStatelessReset(p *receivedPacket, connID protocol.ConnectionID) {
|
||||
defer p.buffer.Release()
|
||||
if !h.statelessResetEnabled {
|
||||
return
|
||||
}
|
||||
// Don't send a stateless reset in response to very small packets.
|
||||
// This includes packets that could be stateless resets.
|
||||
if len(p.data) <= protocol.MinStatelessResetSize {
|
||||
return
|
||||
}
|
||||
token := h.GetStatelessResetToken(connID)
|
||||
h.logger.Debugf("Sending stateless reset to %s (connection ID: %s). Token: %#x", p.remoteAddr, connID, token)
|
||||
data := make([]byte, protocol.MinStatelessResetSize-16, protocol.MinStatelessResetSize)
|
||||
rand.Read(data)
|
||||
data[0] = (data[0] & 0x7f) | 0x40
|
||||
data = append(data, token[:]...)
|
||||
if _, err := h.conn.WritePacket(data, p.remoteAddr, p.info.OOB()); err != nil {
|
||||
h.logger.Debugf("Error sending Stateless Reset: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,405 +6,188 @@ import (
|
||||
"net"
|
||||
"time"
|
||||
|
||||
mocklogging "github.com/quic-go/quic-go/internal/mocks/logging"
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
"github.com/quic-go/quic-go/internal/wire"
|
||||
"github.com/quic-go/quic-go/logging"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Packet Handler Map", func() {
|
||||
type packetToRead struct {
|
||||
addr net.Addr
|
||||
data []byte
|
||||
err error
|
||||
}
|
||||
|
||||
var (
|
||||
handler *packetHandlerMap
|
||||
conn *MockPacketConn
|
||||
tracer *mocklogging.MockTracer
|
||||
packetChan chan packetToRead
|
||||
|
||||
connIDLen int
|
||||
statelessResetKey *StatelessResetKey
|
||||
)
|
||||
|
||||
getPacketWithPacketType := func(connID protocol.ConnectionID, t protocol.PacketType, length protocol.ByteCount) []byte {
|
||||
b, err := (&wire.ExtendedHeader{
|
||||
Header: wire.Header{
|
||||
Type: t,
|
||||
DestConnectionID: connID,
|
||||
Length: length,
|
||||
Version: protocol.Version1,
|
||||
},
|
||||
PacketNumberLen: protocol.PacketNumberLen2,
|
||||
}).Append(nil, protocol.Version1)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
return b
|
||||
}
|
||||
|
||||
getPacket := func(connID protocol.ConnectionID) []byte {
|
||||
return getPacketWithPacketType(connID, protocol.PacketTypeHandshake, 2)
|
||||
}
|
||||
|
||||
BeforeEach(func() {
|
||||
statelessResetKey = nil
|
||||
connIDLen = 0
|
||||
tracer = mocklogging.NewMockTracer(mockCtrl)
|
||||
packetChan = make(chan packetToRead, 10)
|
||||
It("adds and gets", func() {
|
||||
m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
|
||||
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
|
||||
handler := NewMockPacketHandler(mockCtrl)
|
||||
Expect(m.Add(connID, handler)).To(BeTrue())
|
||||
h, ok := m.Get(connID)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(h).To(Equal(handler))
|
||||
})
|
||||
|
||||
JustBeforeEach(func() {
|
||||
conn = NewMockPacketConn(mockCtrl)
|
||||
conn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes()
|
||||
conn.EXPECT().ReadFrom(gomock.Any()).DoAndReturn(func(b []byte) (int, net.Addr, error) {
|
||||
p, ok := <-packetChan
|
||||
if !ok {
|
||||
return 0, nil, errors.New("closed")
|
||||
It("refused to add duplicates", func() {
|
||||
m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
|
||||
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
|
||||
handler := NewMockPacketHandler(mockCtrl)
|
||||
Expect(m.Add(connID, handler)).To(BeTrue())
|
||||
Expect(m.Add(connID, handler)).To(BeFalse())
|
||||
})
|
||||
|
||||
It("removes", func() {
|
||||
m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
|
||||
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
|
||||
handler := NewMockPacketHandler(mockCtrl)
|
||||
Expect(m.Add(connID, handler)).To(BeTrue())
|
||||
m.Remove(connID)
|
||||
_, ok := m.Get(connID)
|
||||
Expect(ok).To(BeFalse())
|
||||
Expect(m.Add(connID, handler)).To(BeTrue())
|
||||
})
|
||||
|
||||
It("retires", func() {
|
||||
m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
|
||||
dur := scaleDuration(50 * time.Millisecond)
|
||||
m.deleteRetiredConnsAfter = dur
|
||||
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
|
||||
handler := NewMockPacketHandler(mockCtrl)
|
||||
Expect(m.Add(connID, handler)).To(BeTrue())
|
||||
m.Retire(connID)
|
||||
_, ok := m.Get(connID)
|
||||
Expect(ok).To(BeTrue())
|
||||
time.Sleep(dur)
|
||||
Eventually(func() bool { _, ok := m.Get(connID); return ok }).Should(BeFalse())
|
||||
})
|
||||
|
||||
It("adds newly to-be-constructed handlers", func() {
|
||||
m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
|
||||
var called bool
|
||||
connID1 := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
|
||||
connID2 := protocol.ParseConnectionID([]byte{4, 3, 2, 1})
|
||||
Expect(m.AddWithConnID(connID1, connID2, func() (packetHandler, bool) {
|
||||
called = true
|
||||
return NewMockPacketHandler(mockCtrl), true
|
||||
})).To(BeTrue())
|
||||
Expect(called).To(BeTrue())
|
||||
Expect(m.AddWithConnID(connID1, protocol.ParseConnectionID([]byte{1, 2, 3}), func() (packetHandler, bool) {
|
||||
Fail("didn't expect the constructor to be executed")
|
||||
return nil, false
|
||||
})).To(BeFalse())
|
||||
})
|
||||
|
||||
It("adds, gets and removes reset tokens", func() {
|
||||
m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
|
||||
token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf}
|
||||
handler := NewMockPacketHandler(mockCtrl)
|
||||
m.AddResetToken(token, handler)
|
||||
h, ok := m.GetByResetToken(token)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(h).To(Equal(h))
|
||||
m.RemoveResetToken(token)
|
||||
_, ok = m.GetByResetToken(token)
|
||||
Expect(ok).To(BeFalse())
|
||||
})
|
||||
|
||||
It("generates stateless reset token, if no key is set", func() {
|
||||
m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
|
||||
b := make([]byte, 8)
|
||||
rand.Read(b)
|
||||
connID := protocol.ParseConnectionID(b)
|
||||
token := m.GetStatelessResetToken(connID)
|
||||
for i := 0; i < 1000; i++ {
|
||||
to := m.GetStatelessResetToken(connID)
|
||||
Expect(to).ToNot(Equal(token))
|
||||
token = to
|
||||
}
|
||||
})
|
||||
|
||||
It("generates stateless reset token, if a key is set", func() {
|
||||
var key StatelessResetKey
|
||||
rand.Read(key[:])
|
||||
m := newPacketHandlerMap(&key, nil, utils.DefaultLogger)
|
||||
b := make([]byte, 8)
|
||||
rand.Read(b)
|
||||
connID := protocol.ParseConnectionID(b)
|
||||
token := m.GetStatelessResetToken(connID)
|
||||
Expect(token).ToNot(BeZero())
|
||||
Expect(m.GetStatelessResetToken(connID)).To(Equal(token))
|
||||
// generate a new connection ID
|
||||
rand.Read(b)
|
||||
connID2 := protocol.ParseConnectionID(b)
|
||||
Expect(m.GetStatelessResetToken(connID2)).ToNot(Equal(token))
|
||||
})
|
||||
|
||||
It("replaces locally closed connections", func() {
|
||||
var closePackets []closePacket
|
||||
m := newPacketHandlerMap(nil, func(p closePacket) { closePackets = append(closePackets, p) }, utils.DefaultLogger)
|
||||
dur := scaleDuration(50 * time.Millisecond)
|
||||
m.deleteRetiredConnsAfter = dur
|
||||
|
||||
handler := NewMockPacketHandler(mockCtrl)
|
||||
connID := protocol.ParseConnectionID([]byte{4, 3, 2, 1})
|
||||
Expect(m.Add(connID, handler)).To(BeTrue())
|
||||
m.ReplaceWithClosed([]protocol.ConnectionID{connID}, protocol.PerspectiveClient, []byte("foobar"))
|
||||
h, ok := m.Get(connID)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(h).ToNot(Equal(handler))
|
||||
addr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}
|
||||
h.handlePacket(&receivedPacket{remoteAddr: addr})
|
||||
Expect(closePackets).To(HaveLen(1))
|
||||
Expect(closePackets[0].addr).To(Equal(addr))
|
||||
Expect(closePackets[0].payload).To(Equal([]byte("foobar")))
|
||||
|
||||
time.Sleep(dur)
|
||||
Eventually(func() bool { _, ok := m.Get(connID); return ok }).Should(BeFalse())
|
||||
})
|
||||
|
||||
It("replaces remote closed connections", func() {
|
||||
var closePackets []closePacket
|
||||
m := newPacketHandlerMap(nil, func(p closePacket) { closePackets = append(closePackets, p) }, utils.DefaultLogger)
|
||||
dur := scaleDuration(50 * time.Millisecond)
|
||||
m.deleteRetiredConnsAfter = dur
|
||||
|
||||
handler := NewMockPacketHandler(mockCtrl)
|
||||
connID := protocol.ParseConnectionID([]byte{4, 3, 2, 1})
|
||||
Expect(m.Add(connID, handler)).To(BeTrue())
|
||||
m.ReplaceWithClosed([]protocol.ConnectionID{connID}, protocol.PerspectiveClient, nil)
|
||||
h, ok := m.Get(connID)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(h).ToNot(Equal(handler))
|
||||
addr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}
|
||||
h.handlePacket(&receivedPacket{remoteAddr: addr})
|
||||
Expect(closePackets).To(BeEmpty())
|
||||
|
||||
time.Sleep(dur)
|
||||
Eventually(func() bool { _, ok := m.Get(connID); return ok }).Should(BeFalse())
|
||||
})
|
||||
|
||||
It("closes the server", func() {
|
||||
m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
|
||||
for i := 0; i < 10; i++ {
|
||||
conn := NewMockPacketHandler(mockCtrl)
|
||||
if i%2 == 0 {
|
||||
conn.EXPECT().getPerspective().Return(protocol.PerspectiveClient)
|
||||
} else {
|
||||
conn.EXPECT().getPerspective().Return(protocol.PerspectiveServer)
|
||||
conn.EXPECT().shutdown()
|
||||
}
|
||||
return copy(b, p.data), p.addr, p.err
|
||||
}).AnyTimes()
|
||||
phm, err := newPacketHandlerMap(conn, connIDLen, statelessResetKey, tracer, utils.DefaultLogger)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
handler = phm.(*packetHandlerMap)
|
||||
b := make([]byte, 12)
|
||||
rand.Read(b)
|
||||
m.Add(protocol.ParseConnectionID(b), conn)
|
||||
}
|
||||
m.CloseServer()
|
||||
})
|
||||
|
||||
It("closes", func() {
|
||||
getMultiplexer() // make the sync.Once execute
|
||||
// replace the clientMuxer. getClientMultiplexer will now return the MockMultiplexer
|
||||
mockMultiplexer := NewMockMultiplexer(mockCtrl)
|
||||
origMultiplexer := connMuxer
|
||||
connMuxer = mockMultiplexer
|
||||
|
||||
defer func() {
|
||||
connMuxer = origMultiplexer
|
||||
}()
|
||||
|
||||
testErr := errors.New("test error ")
|
||||
conn1 := NewMockPacketHandler(mockCtrl)
|
||||
conn1.EXPECT().destroy(testErr)
|
||||
conn2 := NewMockPacketHandler(mockCtrl)
|
||||
conn2.EXPECT().destroy(testErr)
|
||||
handler.Add(protocol.ParseConnectionID([]byte{1, 1, 1, 1}), conn1)
|
||||
handler.Add(protocol.ParseConnectionID([]byte{2, 2, 2, 2}), conn2)
|
||||
mockMultiplexer.EXPECT().RemoveConn(gomock.Any())
|
||||
handler.close(testErr)
|
||||
close(packetChan)
|
||||
Eventually(handler.listening).Should(BeClosed())
|
||||
})
|
||||
|
||||
Context("other operations", func() {
|
||||
AfterEach(func() {
|
||||
// delete connections and the server before closing
|
||||
// They might be mock implementations, and we'd have to register the expected calls before otherwise.
|
||||
handler.mutex.Lock()
|
||||
for connID := range handler.handlers {
|
||||
delete(handler.handlers, connID)
|
||||
}
|
||||
handler.server = nil
|
||||
handler.mutex.Unlock()
|
||||
conn.EXPECT().Close().MaxTimes(1)
|
||||
close(packetChan)
|
||||
handler.Destroy()
|
||||
Eventually(handler.listening).Should(BeClosed())
|
||||
})
|
||||
|
||||
Context("handling packets", func() {
|
||||
BeforeEach(func() {
|
||||
connIDLen = 5
|
||||
})
|
||||
|
||||
It("handles packets for different packet handlers on the same packet conn", func() {
|
||||
connID1 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
|
||||
connID2 := protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1})
|
||||
packetHandler1 := NewMockPacketHandler(mockCtrl)
|
||||
packetHandler2 := NewMockPacketHandler(mockCtrl)
|
||||
handledPacket1 := make(chan struct{})
|
||||
handledPacket2 := make(chan struct{})
|
||||
packetHandler1.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
|
||||
connID, err := wire.ParseConnectionID(p.data, 0)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(connID).To(Equal(connID1))
|
||||
close(handledPacket1)
|
||||
})
|
||||
packetHandler2.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
|
||||
connID, err := wire.ParseConnectionID(p.data, 0)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(connID).To(Equal(connID2))
|
||||
close(handledPacket2)
|
||||
})
|
||||
handler.Add(connID1, packetHandler1)
|
||||
handler.Add(connID2, packetHandler2)
|
||||
packetChan <- packetToRead{data: getPacket(connID1)}
|
||||
packetChan <- packetToRead{data: getPacket(connID2)}
|
||||
|
||||
Eventually(handledPacket1).Should(BeClosed())
|
||||
Eventually(handledPacket2).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("drops unparseable packets", func() {
|
||||
addr := &net.UDPAddr{IP: net.IPv4(9, 8, 7, 6), Port: 1234}
|
||||
tracer.EXPECT().DroppedPacket(addr, logging.PacketTypeNotDetermined, protocol.ByteCount(4), logging.PacketDropHeaderParseError)
|
||||
handler.handlePacket(&receivedPacket{
|
||||
buffer: getPacketBuffer(),
|
||||
remoteAddr: addr,
|
||||
data: []byte{0, 1, 2, 3},
|
||||
})
|
||||
})
|
||||
|
||||
It("deletes removed connections immediately", func() {
|
||||
handler.deleteRetiredConnsAfter = time.Hour
|
||||
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
|
||||
handler.Add(connID, NewMockPacketHandler(mockCtrl))
|
||||
handler.Remove(connID)
|
||||
handler.handlePacket(&receivedPacket{data: getPacket(connID)})
|
||||
// don't EXPECT any calls to handlePacket of the MockPacketHandler
|
||||
})
|
||||
|
||||
It("deletes retired connection entries after a wait time", func() {
|
||||
handler.deleteRetiredConnsAfter = scaleDuration(10 * time.Millisecond)
|
||||
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
|
||||
conn := NewMockPacketHandler(mockCtrl)
|
||||
handler.Add(connID, conn)
|
||||
handler.Retire(connID)
|
||||
time.Sleep(scaleDuration(30 * time.Millisecond))
|
||||
handler.handlePacket(&receivedPacket{data: getPacket(connID)})
|
||||
// don't EXPECT any calls to handlePacket of the MockPacketHandler
|
||||
})
|
||||
|
||||
It("passes packets arriving late for closed connections to that connection", func() {
|
||||
handler.deleteRetiredConnsAfter = time.Hour
|
||||
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
|
||||
packetHandler := NewMockPacketHandler(mockCtrl)
|
||||
handled := make(chan struct{})
|
||||
packetHandler.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
|
||||
close(handled)
|
||||
})
|
||||
handler.Add(connID, packetHandler)
|
||||
handler.Retire(connID)
|
||||
handler.handlePacket(&receivedPacket{data: getPacket(connID)})
|
||||
Eventually(handled).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("drops packets for unknown receivers", func() {
|
||||
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
|
||||
handler.handlePacket(&receivedPacket{data: getPacket(connID)})
|
||||
})
|
||||
|
||||
It("closes the packet handlers when reading from the conn fails", func() {
|
||||
done := make(chan struct{})
|
||||
packetHandler := NewMockPacketHandler(mockCtrl)
|
||||
packetHandler.EXPECT().destroy(gomock.Any()).Do(func(e error) {
|
||||
Expect(e).To(HaveOccurred())
|
||||
close(done)
|
||||
})
|
||||
handler.Add(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), packetHandler)
|
||||
packetChan <- packetToRead{err: errors.New("read failed")}
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("continues listening for temporary errors", func() {
|
||||
packetHandler := NewMockPacketHandler(mockCtrl)
|
||||
handler.Add(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), packetHandler)
|
||||
err := deadlineError{}
|
||||
Expect(err.Temporary()).To(BeTrue())
|
||||
packetChan <- packetToRead{err: err}
|
||||
// don't EXPECT any calls to packetHandler.destroy
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
})
|
||||
|
||||
It("says if a connection ID is already taken", func() {
|
||||
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
|
||||
Expect(handler.Add(connID, NewMockPacketHandler(mockCtrl))).To(BeTrue())
|
||||
Expect(handler.Add(connID, NewMockPacketHandler(mockCtrl))).To(BeFalse())
|
||||
})
|
||||
|
||||
It("says if a connection ID is already taken, for AddWithConnID", func() {
|
||||
clientDestConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
|
||||
newConnID1 := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
|
||||
newConnID2 := protocol.ParseConnectionID([]byte{4, 3, 2, 1})
|
||||
Expect(handler.AddWithConnID(clientDestConnID, newConnID1, func() packetHandler { return NewMockPacketHandler(mockCtrl) })).To(BeTrue())
|
||||
Expect(handler.AddWithConnID(clientDestConnID, newConnID2, func() packetHandler { return NewMockPacketHandler(mockCtrl) })).To(BeFalse())
|
||||
})
|
||||
})
|
||||
|
||||
Context("running a server", func() {
|
||||
It("adds a server", func() {
|
||||
connID := protocol.ParseConnectionID([]byte{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88})
|
||||
p := getPacket(connID)
|
||||
server := NewMockUnknownPacketHandler(mockCtrl)
|
||||
server.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
|
||||
cid, err := wire.ParseConnectionID(p.data, 0)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(cid).To(Equal(connID))
|
||||
})
|
||||
handler.SetServer(server)
|
||||
handler.handlePacket(&receivedPacket{data: p})
|
||||
})
|
||||
|
||||
It("closes all server connections", func() {
|
||||
handler.SetServer(NewMockUnknownPacketHandler(mockCtrl))
|
||||
clientConn := NewMockPacketHandler(mockCtrl)
|
||||
clientConn.EXPECT().getPerspective().Return(protocol.PerspectiveClient)
|
||||
serverConn := NewMockPacketHandler(mockCtrl)
|
||||
serverConn.EXPECT().getPerspective().Return(protocol.PerspectiveServer)
|
||||
serverConn.EXPECT().shutdown()
|
||||
|
||||
handler.Add(protocol.ParseConnectionID([]byte{1, 1, 1, 1}), clientConn)
|
||||
handler.Add(protocol.ParseConnectionID([]byte{2, 2, 2, 2}), serverConn)
|
||||
handler.CloseServer()
|
||||
})
|
||||
|
||||
It("stops handling packets with unknown connection IDs after the server is closed", func() {
|
||||
connID := protocol.ParseConnectionID([]byte{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88})
|
||||
p := getPacket(connID)
|
||||
server := NewMockUnknownPacketHandler(mockCtrl)
|
||||
// don't EXPECT any calls to server.handlePacket
|
||||
handler.SetServer(server)
|
||||
handler.CloseServer()
|
||||
handler.handlePacket(&receivedPacket{data: p})
|
||||
})
|
||||
})
|
||||
|
||||
Context("stateless resets", func() {
|
||||
BeforeEach(func() {
|
||||
connIDLen = 5
|
||||
})
|
||||
|
||||
Context("handling", func() {
|
||||
It("handles stateless resets", func() {
|
||||
packetHandler := NewMockPacketHandler(mockCtrl)
|
||||
token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
|
||||
handler.AddResetToken(token, packetHandler)
|
||||
destroyed := make(chan struct{})
|
||||
packet := append([]byte{0x40} /* short header packet */, make([]byte, 50)...)
|
||||
packet = append(packet, token[:]...)
|
||||
packetHandler.EXPECT().destroy(gomock.Any()).Do(func(err error) {
|
||||
defer GinkgoRecover()
|
||||
defer close(destroyed)
|
||||
Expect(err).To(HaveOccurred())
|
||||
var resetErr *StatelessResetError
|
||||
Expect(errors.As(err, &resetErr)).To(BeTrue())
|
||||
Expect(err.Error()).To(ContainSubstring("received a stateless reset"))
|
||||
Expect(resetErr.Token).To(Equal(token))
|
||||
})
|
||||
packetChan <- packetToRead{data: packet}
|
||||
Eventually(destroyed).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("handles stateless resets for 0-length connection IDs", func() {
|
||||
handler.connIDLen = 0
|
||||
packetHandler := NewMockPacketHandler(mockCtrl)
|
||||
token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
|
||||
handler.AddResetToken(token, packetHandler)
|
||||
destroyed := make(chan struct{})
|
||||
packet := append([]byte{0x40} /* short header packet */, make([]byte, 50)...)
|
||||
packet = append(packet, token[:]...)
|
||||
packetHandler.EXPECT().destroy(gomock.Any()).Do(func(err error) {
|
||||
defer GinkgoRecover()
|
||||
Expect(err).To(HaveOccurred())
|
||||
var resetErr *StatelessResetError
|
||||
Expect(errors.As(err, &resetErr)).To(BeTrue())
|
||||
Expect(err.Error()).To(ContainSubstring("received a stateless reset"))
|
||||
Expect(resetErr.Token).To(Equal(token))
|
||||
close(destroyed)
|
||||
})
|
||||
packetChan <- packetToRead{data: packet}
|
||||
Eventually(destroyed).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("removes reset tokens", func() {
|
||||
connID := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0x42})
|
||||
packetHandler := NewMockPacketHandler(mockCtrl)
|
||||
handler.Add(connID, packetHandler)
|
||||
token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
|
||||
handler.AddResetToken(token, NewMockPacketHandler(mockCtrl))
|
||||
handler.RemoveResetToken(token)
|
||||
// don't EXPECT any call to packetHandler.destroy()
|
||||
packetHandler.EXPECT().handlePacket(gomock.Any())
|
||||
p := append([]byte{0x40} /* short header packet */, connID.Bytes()...)
|
||||
p = append(p, make([]byte, 50)...)
|
||||
p = append(p, token[:]...)
|
||||
|
||||
handler.handlePacket(&receivedPacket{data: p})
|
||||
})
|
||||
|
||||
It("ignores packets too small to contain a stateless reset", func() {
|
||||
handler.connIDLen = 0
|
||||
packetHandler := NewMockPacketHandler(mockCtrl)
|
||||
token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
|
||||
handler.AddResetToken(token, packetHandler)
|
||||
done := make(chan struct{})
|
||||
// don't EXPECT any calls here, but register the closing of the done channel
|
||||
packetHandler.EXPECT().destroy(gomock.Any()).Do(func(error) {
|
||||
close(done)
|
||||
}).AnyTimes()
|
||||
packetChan <- packetToRead{data: append([]byte{0x40} /* short header packet */, token[:15]...)}
|
||||
Consistently(done).ShouldNot(BeClosed())
|
||||
})
|
||||
})
|
||||
|
||||
Context("generating", func() {
|
||||
BeforeEach(func() {
|
||||
var key StatelessResetKey
|
||||
rand.Read(key[:])
|
||||
statelessResetKey = &key
|
||||
})
|
||||
|
||||
It("generates stateless reset tokens", func() {
|
||||
connID1 := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef})
|
||||
connID2 := protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad})
|
||||
Expect(handler.GetStatelessResetToken(connID1)).ToNot(Equal(handler.GetStatelessResetToken(connID2)))
|
||||
})
|
||||
|
||||
It("sends stateless resets", func() {
|
||||
addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337}
|
||||
p := append([]byte{40}, make([]byte, 100)...)
|
||||
done := make(chan struct{})
|
||||
conn.EXPECT().WriteTo(gomock.Any(), addr).Do(func(b []byte, _ net.Addr) {
|
||||
defer close(done)
|
||||
Expect(wire.IsLongHeaderPacket(b[0])).To(BeFalse()) // short header packet
|
||||
Expect(b).To(HaveLen(protocol.MinStatelessResetSize))
|
||||
})
|
||||
handler.handlePacket(&receivedPacket{
|
||||
buffer: getPacketBuffer(),
|
||||
remoteAddr: addr,
|
||||
data: p,
|
||||
})
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("doesn't send stateless resets for small packets", func() {
|
||||
addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337}
|
||||
p := append([]byte{40}, make([]byte, protocol.MinStatelessResetSize-2)...)
|
||||
handler.handlePacket(&receivedPacket{
|
||||
buffer: getPacketBuffer(),
|
||||
remoteAddr: addr,
|
||||
data: p,
|
||||
})
|
||||
// make sure there are no Write calls on the packet conn
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
})
|
||||
})
|
||||
|
||||
Context("if no key is configured", func() {
|
||||
It("doesn't send stateless resets", func() {
|
||||
addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337}
|
||||
p := append([]byte{40}, make([]byte, 100)...)
|
||||
handler.handlePacket(&receivedPacket{
|
||||
buffer: getPacketBuffer(),
|
||||
remoteAddr: addr,
|
||||
data: p,
|
||||
})
|
||||
// make sure there are no Write calls on the packet conn
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
})
|
||||
})
|
||||
})
|
||||
m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
|
||||
testErr := errors.New("shutdown")
|
||||
for i := 0; i < 10; i++ {
|
||||
conn := NewMockPacketHandler(mockCtrl)
|
||||
conn.EXPECT().destroy(testErr)
|
||||
b := make([]byte, 12)
|
||||
rand.Read(b)
|
||||
m.Add(protocol.ParseConnectionID(b), conn)
|
||||
}
|
||||
m.Close(testErr)
|
||||
// check that Close can be called multiple times
|
||||
m.Close(errors.New("close"))
|
||||
})
|
||||
})
|
||||
|
||||
21
qlog/qlog.go
21
qlog/qlog.go
@@ -2,7 +2,6 @@ package qlog
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
@@ -49,26 +48,6 @@ func init() {
|
||||
|
||||
const eventChanSize = 50
|
||||
|
||||
type tracer struct {
|
||||
logging.NullTracer
|
||||
|
||||
getLogWriter func(p logging.Perspective, connectionID []byte) io.WriteCloser
|
||||
}
|
||||
|
||||
var _ logging.Tracer = &tracer{}
|
||||
|
||||
// NewTracer creates a new qlog tracer.
|
||||
func NewTracer(getLogWriter func(p logging.Perspective, connectionID []byte) io.WriteCloser) logging.Tracer {
|
||||
return &tracer{getLogWriter: getLogWriter}
|
||||
}
|
||||
|
||||
func (t *tracer) TracerForConnection(_ context.Context, p logging.Perspective, odcid protocol.ConnectionID) logging.ConnectionTracer {
|
||||
if w := t.getLogWriter(p, odcid.Bytes()); w != nil {
|
||||
return NewConnectionTracer(w, p, odcid)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type connectionTracer struct {
|
||||
mutex sync.Mutex
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@ package qlog
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
@@ -51,17 +50,6 @@ type entry struct {
|
||||
}
|
||||
|
||||
var _ = Describe("Tracing", func() {
|
||||
Context("tracer", func() {
|
||||
It("returns nil when there's no io.WriteCloser", func() {
|
||||
t := NewTracer(func(logging.Perspective, []byte) io.WriteCloser { return nil })
|
||||
Expect(t.TracerForConnection(
|
||||
context.Background(),
|
||||
logging.PerspectiveClient,
|
||||
protocol.ParseConnectionID([]byte{1, 2, 3, 4}),
|
||||
)).To(BeNil())
|
||||
})
|
||||
})
|
||||
|
||||
It("stops writing when encountering an error", func() {
|
||||
buf := &bytes.Buffer{}
|
||||
t := NewConnectionTracer(
|
||||
@@ -88,9 +76,8 @@ var _ = Describe("Tracing", func() {
|
||||
|
||||
BeforeEach(func() {
|
||||
buf = &bytes.Buffer{}
|
||||
t := NewTracer(func(logging.Perspective, []byte) io.WriteCloser { return nopWriteCloser(buf) })
|
||||
tracer = t.TracerForConnection(
|
||||
context.Background(),
|
||||
tracer = NewConnectionTracer(
|
||||
nopWriteCloser(buf),
|
||||
logging.PerspectiveServer,
|
||||
protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}),
|
||||
)
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"log"
|
||||
"runtime/pprof"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
@@ -29,6 +32,20 @@ var _ = BeforeSuite(func() {
|
||||
log.SetOutput(io.Discard)
|
||||
})
|
||||
|
||||
func areServersRunning() bool {
|
||||
var b bytes.Buffer
|
||||
pprof.Lookup("goroutine").WriteTo(&b, 1)
|
||||
return strings.Contains(b.String(), "quic-go.(*baseServer).run")
|
||||
}
|
||||
|
||||
func areTransportsRunning() bool {
|
||||
var b bytes.Buffer
|
||||
pprof.Lookup("goroutine").WriteTo(&b, 1)
|
||||
return strings.Contains(b.String(), "quic-go.(*Transport).listen")
|
||||
}
|
||||
|
||||
var _ = AfterEach(func() {
|
||||
mockCtrl.Finish()
|
||||
Eventually(areServersRunning).Should(BeFalse())
|
||||
Eventually(areTransportsRunning()).Should(BeFalse())
|
||||
})
|
||||
|
||||
23
send_conn.go
23
send_conn.go
@@ -22,7 +22,7 @@ type sconn struct {
|
||||
|
||||
var _ sendConn = &sconn{}
|
||||
|
||||
func newSendConn(c rawConn, remote net.Addr, info *packetInfo) sendConn {
|
||||
func newSendConn(c rawConn, remote net.Addr, info *packetInfo) *sconn {
|
||||
return &sconn{
|
||||
rawConn: c,
|
||||
remoteAddr: remote,
|
||||
@@ -51,24 +51,3 @@ func (c *sconn) LocalAddr() net.Addr {
|
||||
}
|
||||
return addr
|
||||
}
|
||||
|
||||
type spconn struct {
|
||||
net.PacketConn
|
||||
|
||||
remoteAddr net.Addr
|
||||
}
|
||||
|
||||
var _ sendConn = &spconn{}
|
||||
|
||||
func newSendPconn(c net.PacketConn, remote net.Addr) sendConn {
|
||||
return &spconn{PacketConn: c, remoteAddr: remote}
|
||||
}
|
||||
|
||||
func (c *spconn) Write(p []byte) error {
|
||||
_, err := c.WriteTo(p, c.remoteAddr)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *spconn) RemoteAddr() net.Addr {
|
||||
return c.remoteAddr
|
||||
}
|
||||
|
||||
@@ -17,7 +17,9 @@ var _ = Describe("Connection (for sending packets)", func() {
|
||||
BeforeEach(func() {
|
||||
addr = &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337}
|
||||
packetConn = NewMockPacketConn(mockCtrl)
|
||||
c = newSendPconn(packetConn, addr)
|
||||
rawConn, err := wrapConn(packetConn)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
c = newSendConn(rawConn, addr, nil)
|
||||
})
|
||||
|
||||
It("writes", func() {
|
||||
|
||||
224
server.go
224
server.go
@@ -20,7 +20,7 @@ import (
|
||||
)
|
||||
|
||||
// ErrServerClosed is returned by the Listener or EarlyListener's Accept method after a call to Close.
|
||||
var ErrServerClosed = errors.New("quic: Server closed")
|
||||
var ErrServerClosed = errors.New("quic: server closed")
|
||||
|
||||
// packetHandler handles packets
|
||||
type packetHandler interface {
|
||||
@@ -30,18 +30,13 @@ type packetHandler interface {
|
||||
getPerspective() protocol.Perspective
|
||||
}
|
||||
|
||||
type unknownPacketHandler interface {
|
||||
handlePacket(*receivedPacket)
|
||||
setCloseError(error)
|
||||
}
|
||||
|
||||
type packetHandlerManager interface {
|
||||
Get(protocol.ConnectionID) (packetHandler, bool)
|
||||
AddWithConnID(protocol.ConnectionID, protocol.ConnectionID, func() packetHandler) bool
|
||||
Destroy() error
|
||||
connRunner
|
||||
SetServer(unknownPacketHandler)
|
||||
GetByResetToken(protocol.StatelessResetToken) (packetHandler, bool)
|
||||
AddWithConnID(protocol.ConnectionID, protocol.ConnectionID, func() (packetHandler, bool)) bool
|
||||
Close(error)
|
||||
CloseServer()
|
||||
connRunner
|
||||
}
|
||||
|
||||
type quicConn interface {
|
||||
@@ -70,13 +65,12 @@ type baseServer struct {
|
||||
config *Config
|
||||
|
||||
conn rawConn
|
||||
// If the server is started with ListenAddr, we create a packet conn.
|
||||
// If it is started with Listen, we take a packet conn as a parameter.
|
||||
createdPacketConn bool
|
||||
|
||||
tokenGenerator *handshake.TokenGenerator
|
||||
|
||||
connHandler packetHandlerManager
|
||||
connIDGenerator ConnectionIDGenerator
|
||||
connHandler packetHandlerManager
|
||||
onClose func()
|
||||
|
||||
receivedPackets chan *receivedPacket
|
||||
|
||||
@@ -92,6 +86,7 @@ type baseServer struct {
|
||||
protocol.ConnectionID, /* client dest connection ID */
|
||||
protocol.ConnectionID, /* destination connection ID */
|
||||
protocol.ConnectionID, /* source connection ID */
|
||||
ConnectionIDGenerator,
|
||||
protocol.StatelessResetToken,
|
||||
*Config,
|
||||
*tls.Config,
|
||||
@@ -111,11 +106,11 @@ type baseServer struct {
|
||||
connQueue chan quicConn
|
||||
connQueueLen int32 // to be used as an atomic
|
||||
|
||||
tracer logging.Tracer
|
||||
|
||||
logger utils.Logger
|
||||
}
|
||||
|
||||
var _ unknownPacketHandler = &baseServer{}
|
||||
|
||||
// A Listener listens for incoming QUIC connections.
|
||||
// It returns connections once the handshake has completed.
|
||||
type Listener struct {
|
||||
@@ -166,37 +161,36 @@ func (l *EarlyListener) Addr() net.Addr {
|
||||
// The tls.Config must not be nil and must contain a certificate configuration.
|
||||
// The quic.Config may be nil, in that case the default values will be used.
|
||||
func ListenAddr(addr string, tlsConf *tls.Config, config *Config) (*Listener, error) {
|
||||
s, err := listenAddr(addr, tlsConf, config, false)
|
||||
conn, err := listenUDP(addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Listener{baseServer: s}, nil
|
||||
return (&Transport{
|
||||
Conn: conn,
|
||||
createdConn: true,
|
||||
isSingleUse: true,
|
||||
}).Listen(tlsConf, config)
|
||||
}
|
||||
|
||||
// ListenAddrEarly works like ListenAddr, but it returns connections before the handshake completes.
|
||||
func ListenAddrEarly(addr string, tlsConf *tls.Config, config *Config) (*EarlyListener, error) {
|
||||
s, err := listenAddr(addr, tlsConf, config, true)
|
||||
conn, err := listenUDP(addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &EarlyListener{baseServer: s}, nil
|
||||
return (&Transport{
|
||||
Conn: conn,
|
||||
createdConn: true,
|
||||
isSingleUse: true,
|
||||
}).ListenEarly(tlsConf, config)
|
||||
}
|
||||
|
||||
func listenAddr(addr string, tlsConf *tls.Config, config *Config, acceptEarly bool) (*baseServer, error) {
|
||||
func listenUDP(addr string) (*net.UDPConn, error) {
|
||||
udpAddr, err := net.ResolveUDPAddr("udp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conn, err := net.ListenUDP("udp", udpAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
serv, err := listen(conn, tlsConf, config, acceptEarly)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
serv.createdPacketConn = true
|
||||
return serv, nil
|
||||
return net.ListenUDP("udp", udpAddr)
|
||||
}
|
||||
|
||||
// Listen listens for QUIC connections on a given net.PacketConn. If the
|
||||
@@ -210,67 +204,51 @@ func listenAddr(addr string, tlsConf *tls.Config, config *Config, acceptEarly bo
|
||||
// Furthermore, it must define an application control (using NextProtos).
|
||||
// The quic.Config may be nil, in that case the default values will be used.
|
||||
func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (*Listener, error) {
|
||||
s, err := listen(conn, tlsConf, config, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Listener{baseServer: s}, nil
|
||||
tr := &Transport{Conn: conn, isSingleUse: true}
|
||||
return tr.Listen(tlsConf, config)
|
||||
}
|
||||
|
||||
// ListenEarly works like Listen, but it returns connections before the handshake completes.
|
||||
func ListenEarly(conn net.PacketConn, tlsConf *tls.Config, config *Config) (*EarlyListener, error) {
|
||||
s, err := listen(conn, tlsConf, config, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &EarlyListener{baseServer: s}, nil
|
||||
tr := &Transport{Conn: conn, isSingleUse: true}
|
||||
return tr.ListenEarly(tlsConf, config)
|
||||
}
|
||||
|
||||
func listen(conn net.PacketConn, tlsConf *tls.Config, config *Config, acceptEarly bool) (*baseServer, error) {
|
||||
if tlsConf == nil {
|
||||
return nil, errors.New("quic: tls.Config not set")
|
||||
}
|
||||
if err := validateConfig(config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
config = populateServerConfig(config)
|
||||
for _, v := range config.Versions {
|
||||
if !protocol.IsValidVersion(v) {
|
||||
return nil, fmt.Errorf("%s is not a valid QUIC version", v)
|
||||
}
|
||||
}
|
||||
|
||||
connHandler, err := getMultiplexer().AddConn(conn, config.ConnectionIDGenerator.ConnectionIDLen(), config.StatelessResetKey, config.Tracer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
func newServer(
|
||||
conn rawConn,
|
||||
connHandler packetHandlerManager,
|
||||
connIDGenerator ConnectionIDGenerator,
|
||||
tlsConf *tls.Config,
|
||||
config *Config,
|
||||
tracer logging.Tracer,
|
||||
onClose func(),
|
||||
acceptEarly bool,
|
||||
) (*baseServer, error) {
|
||||
tokenGenerator, err := handshake.NewTokenGenerator(rand.Reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c, err := wrapConn(conn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s := &baseServer{
|
||||
conn: c,
|
||||
conn: conn,
|
||||
tlsConf: tlsConf,
|
||||
config: config,
|
||||
tokenGenerator: tokenGenerator,
|
||||
connIDGenerator: connIDGenerator,
|
||||
connHandler: connHandler,
|
||||
connQueue: make(chan quicConn),
|
||||
errorChan: make(chan struct{}),
|
||||
running: make(chan struct{}),
|
||||
receivedPackets: make(chan *receivedPacket, protocol.MaxServerUnprocessedPackets),
|
||||
newConn: newConnection,
|
||||
tracer: tracer,
|
||||
logger: utils.DefaultLogger.WithPrefix("server"),
|
||||
acceptEarlyConns: acceptEarly,
|
||||
onClose: onClose,
|
||||
}
|
||||
if acceptEarly {
|
||||
s.zeroRTTQueues = map[protocol.ConnectionID]*zeroRTTQueue{}
|
||||
}
|
||||
go s.run()
|
||||
connHandler.SetServer(s)
|
||||
s.logger.Debugf("Listening for %s connections on %s", conn.LocalAddr().Network(), conn.LocalAddr().String())
|
||||
return s, nil
|
||||
}
|
||||
@@ -322,18 +300,12 @@ func (s *baseServer) Close() error {
|
||||
if s.serverError == nil {
|
||||
s.serverError = ErrServerClosed
|
||||
}
|
||||
// If the server was started with ListenAddr, we created the packet conn.
|
||||
// We need to close it in order to make the go routine reading from that conn return.
|
||||
createdPacketConn := s.createdPacketConn
|
||||
s.closed = true
|
||||
close(s.errorChan)
|
||||
s.mutex.Unlock()
|
||||
|
||||
<-s.running
|
||||
s.connHandler.CloseServer()
|
||||
if createdPacketConn {
|
||||
return s.connHandler.Destroy()
|
||||
}
|
||||
s.onClose()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -358,8 +330,8 @@ func (s *baseServer) handlePacket(p *receivedPacket) {
|
||||
case s.receivedPackets <- p:
|
||||
default:
|
||||
s.logger.Debugf("Dropping packet from %s (%d bytes). Server receive queue full.", p.remoteAddr, p.Size())
|
||||
if s.config.Tracer != nil {
|
||||
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention)
|
||||
if s.tracer != nil {
|
||||
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -371,8 +343,8 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer s
|
||||
|
||||
if wire.IsVersionNegotiationPacket(p.data) {
|
||||
s.logger.Debugf("Dropping Version Negotiation packet.")
|
||||
if s.config.Tracer != nil {
|
||||
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeVersionNegotiation, p.Size(), logging.PacketDropUnexpectedPacket)
|
||||
if s.tracer != nil {
|
||||
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeVersionNegotiation, p.Size(), logging.PacketDropUnexpectedPacket)
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -385,16 +357,16 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer s
|
||||
if err != nil || !protocol.IsSupportedVersion(s.config.Versions, v) {
|
||||
if err != nil || p.Size() < protocol.MinUnknownVersionPacketSize {
|
||||
s.logger.Debugf("Dropping a packet with an unknown version that is too small (%d bytes)", p.Size())
|
||||
if s.config.Tracer != nil {
|
||||
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket)
|
||||
if s.tracer != nil {
|
||||
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket)
|
||||
}
|
||||
return false
|
||||
}
|
||||
_, src, dest, err := wire.ParseArbitraryLenConnectionIDs(p.data)
|
||||
if err != nil { // should never happen
|
||||
s.logger.Debugf("Dropping a packet with an unknown version for which we failed to parse connection IDs")
|
||||
if s.config.Tracer != nil {
|
||||
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket)
|
||||
if s.tracer != nil {
|
||||
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket)
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -406,8 +378,8 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer s
|
||||
|
||||
if wire.Is0RTTPacket(p.data) {
|
||||
if !s.acceptEarlyConns {
|
||||
if s.config.Tracer != nil {
|
||||
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropUnexpectedPacket)
|
||||
if s.tracer != nil {
|
||||
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropUnexpectedPacket)
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -418,16 +390,16 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer s
|
||||
// The header will then be parsed again.
|
||||
hdr, _, _, err := wire.ParsePacket(p.data)
|
||||
if err != nil {
|
||||
if s.config.Tracer != nil {
|
||||
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError)
|
||||
if s.tracer != nil {
|
||||
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError)
|
||||
}
|
||||
s.logger.Debugf("Error parsing packet: %s", err)
|
||||
return false
|
||||
}
|
||||
if hdr.Type == protocol.PacketTypeInitial && p.Size() < protocol.MinInitialPacketSize {
|
||||
s.logger.Debugf("Dropping a packet that is too small to be a valid Initial (%d bytes)", p.Size())
|
||||
if s.config.Tracer != nil {
|
||||
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket)
|
||||
if s.tracer != nil {
|
||||
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket)
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -437,8 +409,8 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer s
|
||||
// There's little point in sending a Stateless Reset, since the client
|
||||
// might not have received the token yet.
|
||||
s.logger.Debugf("Dropping long header packet of type %s (%d bytes)", hdr.Type, len(p.data))
|
||||
if s.config.Tracer != nil {
|
||||
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeFromHeader(hdr), p.Size(), logging.PacketDropUnexpectedPacket)
|
||||
if s.tracer != nil {
|
||||
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeFromHeader(hdr), p.Size(), logging.PacketDropUnexpectedPacket)
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -456,8 +428,8 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer s
|
||||
func (s *baseServer) handle0RTTPacket(p *receivedPacket) bool {
|
||||
connID, err := wire.ParseConnectionID(p.data, 0)
|
||||
if err != nil {
|
||||
if s.config.Tracer != nil {
|
||||
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropHeaderParseError)
|
||||
if s.tracer != nil {
|
||||
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropHeaderParseError)
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -470,8 +442,8 @@ func (s *baseServer) handle0RTTPacket(p *receivedPacket) bool {
|
||||
|
||||
if q, ok := s.zeroRTTQueues[connID]; ok {
|
||||
if len(q.packets) >= protocol.Max0RTTQueueLen {
|
||||
if s.config.Tracer != nil {
|
||||
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention)
|
||||
if s.tracer != nil {
|
||||
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention)
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -480,8 +452,8 @@ func (s *baseServer) handle0RTTPacket(p *receivedPacket) bool {
|
||||
}
|
||||
|
||||
if len(s.zeroRTTQueues) >= protocol.Max0RTTQueues {
|
||||
if s.config.Tracer != nil {
|
||||
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention)
|
||||
if s.tracer != nil {
|
||||
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention)
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -508,8 +480,8 @@ func (s *baseServer) cleanupZeroRTTQueues(now time.Time) {
|
||||
continue
|
||||
}
|
||||
for _, p := range q.packets {
|
||||
if s.config.Tracer != nil {
|
||||
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention)
|
||||
if s.tracer != nil {
|
||||
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention)
|
||||
}
|
||||
p.buffer.Release()
|
||||
}
|
||||
@@ -544,8 +516,8 @@ func (s *baseServer) validateToken(token *handshake.Token, addr net.Addr) bool {
|
||||
func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) error {
|
||||
if len(hdr.Token) == 0 && hdr.DestConnectionID.Len() < protocol.MinConnectionIDLenInitial {
|
||||
p.buffer.Release()
|
||||
if s.config.Tracer != nil {
|
||||
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket)
|
||||
if s.tracer != nil {
|
||||
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket)
|
||||
}
|
||||
return errors.New("too short connection ID")
|
||||
}
|
||||
@@ -617,26 +589,31 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
|
||||
return nil
|
||||
}
|
||||
|
||||
connID, err := s.config.ConnectionIDGenerator.GenerateConnectionID()
|
||||
connID, err := s.connIDGenerator.GenerateConnectionID()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.logger.Debugf("Changing connection ID to %s.", connID)
|
||||
var conn quicConn
|
||||
tracingID := nextConnTracingID()
|
||||
if added := s.connHandler.AddWithConnID(hdr.DestConnectionID, connID, func() packetHandler {
|
||||
if added := s.connHandler.AddWithConnID(hdr.DestConnectionID, connID, func() (packetHandler, bool) {
|
||||
config := s.config
|
||||
if s.config.GetConfigForClient != nil {
|
||||
conf, err := s.config.GetConfigForClient(&ClientHelloInfo{RemoteAddr: p.remoteAddr})
|
||||
if err != nil {
|
||||
s.logger.Debugf("Rejecting new connection due to GetConfigForClient callback")
|
||||
return nil, false
|
||||
}
|
||||
config = populateConfig(conf)
|
||||
}
|
||||
var tracer logging.ConnectionTracer
|
||||
if s.config.Tracer != nil {
|
||||
if config.Tracer != nil {
|
||||
// Use the same connection ID that is passed to the client's GetLogWriter callback.
|
||||
connID := hdr.DestConnectionID
|
||||
if origDestConnID.Len() > 0 {
|
||||
connID = origDestConnID
|
||||
}
|
||||
tracer = s.config.Tracer.TracerForConnection(
|
||||
context.WithValue(context.Background(), ConnectionTracingKey, tracingID),
|
||||
protocol.PerspectiveServer,
|
||||
connID,
|
||||
)
|
||||
tracer = config.Tracer(context.WithValue(context.Background(), ConnectionTracingKey, tracingID), protocol.PerspectiveServer, connID)
|
||||
}
|
||||
conn = s.newConn(
|
||||
newSendConn(s.conn, p.remoteAddr, p.info),
|
||||
@@ -646,8 +623,9 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
|
||||
hdr.DestConnectionID,
|
||||
hdr.SrcConnectionID,
|
||||
connID,
|
||||
s.connIDGenerator,
|
||||
s.connHandler.GetStatelessResetToken(connID),
|
||||
s.config,
|
||||
config,
|
||||
s.tlsConf,
|
||||
s.tokenGenerator,
|
||||
clientAddrIsValid,
|
||||
@@ -665,10 +643,14 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
|
||||
delete(s.zeroRTTQueues, hdr.DestConnectionID)
|
||||
}
|
||||
|
||||
return conn
|
||||
return conn, true
|
||||
}); !added {
|
||||
// TODO: don't just drop the packet
|
||||
// Properly reject the connection attempt.
|
||||
go func() {
|
||||
defer p.buffer.Release()
|
||||
if err := s.sendConnectionRefused(p.remoteAddr, hdr, p.info); err != nil {
|
||||
s.logger.Debugf("Error rejecting connection: %s", err)
|
||||
}
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
go conn.run()
|
||||
@@ -712,7 +694,7 @@ func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header, info *pack
|
||||
// Log the Initial packet now.
|
||||
// If no Retry is sent, the packet will be logged by the connection.
|
||||
(&wire.ExtendedHeader{Header: *hdr}).Log(s.logger)
|
||||
srcConnID, err := s.config.ConnectionIDGenerator.GenerateConnectionID()
|
||||
srcConnID, err := s.connIDGenerator.GenerateConnectionID()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -741,8 +723,8 @@ func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header, info *pack
|
||||
// append the Retry integrity tag
|
||||
tag := handshake.GetRetryIntegrityTag(buf.Data, hdr.DestConnectionID, hdr.Version)
|
||||
buf.Data = append(buf.Data, tag[:]...)
|
||||
if s.config.Tracer != nil {
|
||||
s.config.Tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(buf.Data)), nil)
|
||||
if s.tracer != nil {
|
||||
s.tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(buf.Data)), nil)
|
||||
}
|
||||
_, err = s.conn.WritePacket(buf.Data, remoteAddr, info.OOB())
|
||||
return err
|
||||
@@ -755,8 +737,8 @@ func (s *baseServer) maybeSendInvalidToken(p *receivedPacket, hdr *wire.Header)
|
||||
data := p.data[:hdr.ParsedLen()+hdr.Length]
|
||||
extHdr, err := unpackLongHeader(opener, hdr, data, hdr.Version)
|
||||
if err != nil {
|
||||
if s.config.Tracer != nil {
|
||||
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropHeaderParseError)
|
||||
if s.tracer != nil {
|
||||
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropHeaderParseError)
|
||||
}
|
||||
// don't return the error here. Just drop the packet.
|
||||
return nil
|
||||
@@ -764,8 +746,8 @@ func (s *baseServer) maybeSendInvalidToken(p *receivedPacket, hdr *wire.Header)
|
||||
hdrLen := extHdr.ParsedLen()
|
||||
if _, err := opener.Open(data[hdrLen:hdrLen], data[hdrLen:], extHdr.PacketNumber, data[:hdrLen]); err != nil {
|
||||
// don't return the error here. Just drop the packet.
|
||||
if s.config.Tracer != nil {
|
||||
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropPayloadDecryptError)
|
||||
if s.tracer != nil {
|
||||
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropPayloadDecryptError)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -818,8 +800,8 @@ func (s *baseServer) sendError(remoteAddr net.Addr, hdr *wire.Header, sealer han
|
||||
|
||||
replyHdr.Log(s.logger)
|
||||
wire.LogFrame(s.logger, ccf, true)
|
||||
if s.config.Tracer != nil {
|
||||
s.config.Tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(b.Data)), []logging.Frame{ccf})
|
||||
if s.tracer != nil {
|
||||
s.tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(b.Data)), []logging.Frame{ccf})
|
||||
}
|
||||
_, err = s.conn.WritePacket(b.Data, remoteAddr, info.OOB())
|
||||
return err
|
||||
@@ -829,8 +811,8 @@ func (s *baseServer) sendVersionNegotiationPacket(remote net.Addr, src, dest pro
|
||||
s.logger.Debugf("Client offered version %s, sending Version Negotiation", v)
|
||||
|
||||
data := wire.ComposeVersionNegotiation(dest, src, s.config.Versions)
|
||||
if s.config.Tracer != nil {
|
||||
s.config.Tracer.SentVersionNegotiationPacket(remote, src, dest, s.config.Versions)
|
||||
if s.tracer != nil {
|
||||
s.tracer.SentVersionNegotiationPacket(remote, src, dest, s.config.Versions)
|
||||
}
|
||||
if _, err := s.conn.WritePacket(data, remote, oob); err != nil {
|
||||
s.logger.Debugf("Error sending Version Negotiation: %s", err)
|
||||
|
||||
216
server_test.go
216
server_test.go
@@ -1,15 +1,12 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net"
|
||||
"reflect"
|
||||
"runtime/pprof"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@@ -24,17 +21,10 @@ import (
|
||||
"github.com/quic-go/quic-go/logging"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func areServersRunning() bool {
|
||||
var b bytes.Buffer
|
||||
pprof.Lookup("goroutine").WriteTo(&b, 1)
|
||||
return strings.Contains(b.String(), "quic-go.(*baseServer).run")
|
||||
}
|
||||
|
||||
var _ = Describe("Server", func() {
|
||||
var (
|
||||
conn *MockPacketConn
|
||||
@@ -96,15 +86,19 @@ var _ = Describe("Server", func() {
|
||||
BeforeEach(func() {
|
||||
conn = NewMockPacketConn(mockCtrl)
|
||||
conn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes()
|
||||
conn.EXPECT().ReadFrom(gomock.Any()).Do(func(_ []byte) { <-(make(chan struct{})) }).MaxTimes(1)
|
||||
wait := make(chan struct{})
|
||||
conn.EXPECT().ReadFrom(gomock.Any()).DoAndReturn(func(_ []byte) (int, net.Addr, error) {
|
||||
<-wait
|
||||
return 0, nil, errors.New("done")
|
||||
}).MaxTimes(1)
|
||||
conn.EXPECT().SetReadDeadline(gomock.Any()).Do(func(time.Time) {
|
||||
close(wait)
|
||||
conn.EXPECT().SetReadDeadline(time.Time{})
|
||||
}).MaxTimes(1)
|
||||
tlsConf = testdata.GetTLSConfig()
|
||||
tlsConf.NextProtos = []string{"proto1"}
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
Eventually(areServersRunning).Should(BeFalse())
|
||||
})
|
||||
|
||||
It("errors when no tls.Config is given", func() {
|
||||
_, err := ListenAddr("localhost:0", nil, nil)
|
||||
Expect(err).To(HaveOccurred())
|
||||
@@ -114,7 +108,7 @@ var _ = Describe("Server", func() {
|
||||
It("errors when the Config contains an invalid version", func() {
|
||||
version := protocol.VersionNumber(0x1234)
|
||||
_, err := Listen(nil, tlsConf, &Config{Versions: []protocol.VersionNumber{version}})
|
||||
Expect(err).To(MatchError("0x1234 is not a valid QUIC version"))
|
||||
Expect(err).To(MatchError("invalid QUIC version: 0x1234"))
|
||||
})
|
||||
|
||||
It("fills in default values if options are not set in the Config", func() {
|
||||
@@ -138,7 +132,6 @@ var _ = Describe("Server", func() {
|
||||
HandshakeIdleTimeout: 1337 * time.Hour,
|
||||
MaxIdleTimeout: 42 * time.Minute,
|
||||
KeepAlivePeriod: 5 * time.Second,
|
||||
StatelessResetKey: &StatelessResetKey{'f', 'o', 'o', 'b', 'a', 'r'},
|
||||
RequireAddressValidation: requireAddrVal,
|
||||
}
|
||||
ln, err := Listen(conn, tlsConf, &config)
|
||||
@@ -150,7 +143,6 @@ var _ = Describe("Server", func() {
|
||||
Expect(server.config.MaxIdleTimeout).To(Equal(42 * time.Minute))
|
||||
Expect(reflect.ValueOf(server.config.RequireAddressValidation)).To(Equal(reflect.ValueOf(requireAddrVal)))
|
||||
Expect(server.config.KeepAlivePeriod).To(Equal(5 * time.Second))
|
||||
Expect(server.config.StatelessResetKey).To(Equal(&StatelessResetKey{'f', 'o', 'o', 'b', 'a', 'r'}))
|
||||
// stop the listener
|
||||
Expect(ln.Close()).To(Succeed())
|
||||
})
|
||||
@@ -178,6 +170,7 @@ var _ = Describe("Server", func() {
|
||||
|
||||
Context("server accepting connections that completed the handshake", func() {
|
||||
var (
|
||||
tr *Transport
|
||||
serv *baseServer
|
||||
phm *MockPacketHandlerManager
|
||||
tracer *mocklogging.MockTracer
|
||||
@@ -185,7 +178,8 @@ var _ = Describe("Server", func() {
|
||||
|
||||
BeforeEach(func() {
|
||||
tracer = mocklogging.NewMockTracer(mockCtrl)
|
||||
ln, err := Listen(conn, tlsConf, &Config{Tracer: tracer})
|
||||
tr = &Transport{Conn: conn, Tracer: tracer}
|
||||
ln, err := tr.Listen(tlsConf, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
serv = ln.baseServer
|
||||
phm = NewMockPacketHandlerManager(mockCtrl)
|
||||
@@ -193,8 +187,7 @@ var _ = Describe("Server", func() {
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
phm.EXPECT().CloseServer().MaxTimes(1)
|
||||
serv.Close()
|
||||
tr.Close()
|
||||
})
|
||||
|
||||
Context("handling packets", func() {
|
||||
@@ -274,16 +267,15 @@ var _ = Describe("Server", func() {
|
||||
var newConnID protocol.ConnectionID
|
||||
|
||||
phm.EXPECT().Get(connID)
|
||||
phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() packetHandler) bool {
|
||||
phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
|
||||
newConnID = c
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any()).DoAndReturn(func(c protocol.ConnectionID) protocol.StatelessResetToken {
|
||||
newConnID = c
|
||||
return token
|
||||
})
|
||||
fn()
|
||||
return true
|
||||
_, ok := fn()
|
||||
return ok
|
||||
})
|
||||
tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde}))
|
||||
conn := NewMockQUICConn(mockCtrl)
|
||||
serv.newConn = func(
|
||||
_ sendConn,
|
||||
@@ -293,6 +285,7 @@ var _ = Describe("Server", func() {
|
||||
clientDestConnID protocol.ConnectionID,
|
||||
destConnID protocol.ConnectionID,
|
||||
srcConnID protocol.ConnectionID,
|
||||
_ ConnectionIDGenerator,
|
||||
tokenP protocol.StatelessResetToken,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
@@ -474,17 +467,16 @@ var _ = Describe("Server", func() {
|
||||
var newConnID protocol.ConnectionID
|
||||
gomock.InOrder(
|
||||
phm.EXPECT().Get(connID),
|
||||
phm.EXPECT().AddWithConnID(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() packetHandler) bool {
|
||||
phm.EXPECT().AddWithConnID(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
|
||||
newConnID = c
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any()).DoAndReturn(func(c protocol.ConnectionID) protocol.StatelessResetToken {
|
||||
newConnID = c
|
||||
return token
|
||||
})
|
||||
fn()
|
||||
return true
|
||||
_, ok := fn()
|
||||
return ok
|
||||
}),
|
||||
)
|
||||
tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}))
|
||||
|
||||
conn := NewMockQUICConn(mockCtrl)
|
||||
serv.newConn = func(
|
||||
@@ -495,6 +487,7 @@ var _ = Describe("Server", func() {
|
||||
clientDestConnID protocol.ConnectionID,
|
||||
destConnID protocol.ConnectionID,
|
||||
srcConnID protocol.ConnectionID,
|
||||
_ ConnectionIDGenerator,
|
||||
tokenP protocol.StatelessResetToken,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
@@ -537,12 +530,11 @@ var _ = Describe("Server", func() {
|
||||
|
||||
It("drops packets if the receive queue is full", func() {
|
||||
phm.EXPECT().Get(gomock.Any()).AnyTimes()
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||
fn()
|
||||
return true
|
||||
_, ok := fn()
|
||||
return ok
|
||||
}).AnyTimes()
|
||||
tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any()).AnyTimes()
|
||||
|
||||
acceptConn := make(chan struct{})
|
||||
var counter uint32 // to be used as an atomic, so we query it in Eventually
|
||||
@@ -554,6 +546,7 @@ var _ = Describe("Server", func() {
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ ConnectionIDGenerator,
|
||||
_ protocol.StatelessResetToken,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
@@ -598,7 +591,6 @@ var _ = Describe("Server", func() {
|
||||
|
||||
It("only creates a single connection for a duplicate Initial", func() {
|
||||
var createdConn bool
|
||||
conn := NewMockQUICConn(mockCtrl)
|
||||
serv.newConn = func(
|
||||
_ sendConn,
|
||||
runner connRunner,
|
||||
@@ -607,6 +599,7 @@ var _ = Describe("Server", func() {
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ ConnectionIDGenerator,
|
||||
_ protocol.StatelessResetToken,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
@@ -618,15 +611,19 @@ var _ = Describe("Server", func() {
|
||||
_ protocol.VersionNumber,
|
||||
) quicConn {
|
||||
createdConn = true
|
||||
return conn
|
||||
return NewMockQUICConn(mockCtrl)
|
||||
}
|
||||
|
||||
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9})
|
||||
p := getInitial(connID)
|
||||
phm.EXPECT().Get(connID)
|
||||
phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).Return(false)
|
||||
tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
||||
done := make(chan struct{})
|
||||
conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).Do(func([]byte, net.Addr) { close(done) })
|
||||
Expect(serv.handlePacketImpl(p)).To(BeTrue())
|
||||
Expect(createdConn).To(BeFalse())
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("rejects new connection attempts if the accept queue is full", func() {
|
||||
@@ -638,6 +635,7 @@ var _ = Describe("Server", func() {
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ ConnectionIDGenerator,
|
||||
_ protocol.StatelessResetToken,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
@@ -659,12 +657,11 @@ var _ = Describe("Server", func() {
|
||||
}
|
||||
|
||||
phm.EXPECT().Get(gomock.Any()).Times(protocol.MaxAcceptQueueSize + 1)
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||
fn()
|
||||
return true
|
||||
_, ok := fn()
|
||||
return ok
|
||||
}).Times(protocol.MaxAcceptQueueSize)
|
||||
tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any()).Times(protocol.MaxAcceptQueueSize)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(protocol.MaxAcceptQueueSize)
|
||||
@@ -709,6 +706,7 @@ var _ = Describe("Server", func() {
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ ConnectionIDGenerator,
|
||||
_ protocol.StatelessResetToken,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
@@ -730,12 +728,11 @@ var _ = Describe("Server", func() {
|
||||
}
|
||||
|
||||
phm.EXPECT().Get(gomock.Any())
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||
fn()
|
||||
return true
|
||||
_, ok := fn()
|
||||
return ok
|
||||
})
|
||||
tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any())
|
||||
|
||||
serv.handlePacket(p)
|
||||
// make sure there are no Write calls on the packet conn
|
||||
@@ -753,8 +750,7 @@ var _ = Describe("Server", func() {
|
||||
Consistently(done).ShouldNot(BeClosed())
|
||||
|
||||
// make the go routine return
|
||||
phm.EXPECT().CloseServer()
|
||||
conn.EXPECT().getPerspective().MaxTimes(2) // once for every conn ID
|
||||
conn.EXPECT().getPerspective().MaxTimes(2) // initOnce for every conn ID
|
||||
Expect(serv.Close()).To(Succeed())
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
@@ -794,7 +790,7 @@ var _ = Describe("Server", func() {
|
||||
|
||||
done := make(chan struct{})
|
||||
phm.EXPECT().Get(gomock.Any())
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_, _ protocol.ConnectionID, _ func() packetHandler) { close(done) })
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_, _ protocol.ConnectionID, _ func() (packetHandler, bool)) { close(done) })
|
||||
serv.handlePacket(packet)
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
@@ -968,6 +964,7 @@ var _ = Describe("Server", func() {
|
||||
|
||||
serv.setCloseError(testErr)
|
||||
Eventually(done).Should(BeClosed())
|
||||
serv.onClose() // shutdown
|
||||
})
|
||||
|
||||
It("returns immediately, if an error occurred before", func() {
|
||||
@@ -977,6 +974,7 @@ var _ = Describe("Server", func() {
|
||||
_, err := serv.Accept(context.Background())
|
||||
Expect(err).To(MatchError(testErr))
|
||||
}
|
||||
serv.onClose() // shutdown
|
||||
})
|
||||
|
||||
It("returns when the context is canceled", func() {
|
||||
@@ -994,6 +992,85 @@ var _ = Describe("Server", func() {
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("uses the config returned by GetConfigClient", func() {
|
||||
conn := NewMockQUICConn(mockCtrl)
|
||||
|
||||
conf := &Config{MaxIncomingStreams: 1234}
|
||||
serv.config = populateServerConfig(&Config{GetConfigForClient: func(*ClientHelloInfo) (*Config, error) { return conf, nil }})
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
s, err := serv.Accept(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(s).To(Equal(conn))
|
||||
close(done)
|
||||
}()
|
||||
|
||||
handshakeChan := make(chan struct{})
|
||||
serv.newConn = func(
|
||||
_ sendConn,
|
||||
_ connRunner,
|
||||
_ protocol.ConnectionID,
|
||||
_ *protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ ConnectionIDGenerator,
|
||||
_ protocol.StatelessResetToken,
|
||||
conf *Config,
|
||||
_ *tls.Config,
|
||||
_ *handshake.TokenGenerator,
|
||||
_ bool,
|
||||
_ logging.ConnectionTracer,
|
||||
_ uint64,
|
||||
_ utils.Logger,
|
||||
_ protocol.VersionNumber,
|
||||
) quicConn {
|
||||
Expect(conf.MaxIncomingStreams).To(BeEquivalentTo(1234))
|
||||
conn.EXPECT().handlePacket(gomock.Any())
|
||||
conn.EXPECT().HandshakeComplete().Return(handshakeChan)
|
||||
conn.EXPECT().run().Do(func() {})
|
||||
conn.EXPECT().Context().Return(context.Background())
|
||||
return conn
|
||||
}
|
||||
phm.EXPECT().Get(gomock.Any())
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||
_, ok := fn()
|
||||
return ok
|
||||
})
|
||||
serv.handleInitialImpl(
|
||||
&receivedPacket{buffer: getPacketBuffer()},
|
||||
&wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})},
|
||||
)
|
||||
Consistently(done).ShouldNot(BeClosed())
|
||||
close(handshakeChan) // complete the handshake
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("rejects a connection attempt when GetConfigClient returns an error", func() {
|
||||
serv.config = populateServerConfig(&Config{GetConfigForClient: func(*ClientHelloInfo) (*Config, error) { return nil, errors.New("rejected") }})
|
||||
|
||||
phm.EXPECT().Get(gomock.Any())
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
|
||||
_, ok := fn()
|
||||
return ok
|
||||
})
|
||||
done := make(chan struct{})
|
||||
tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
||||
conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
|
||||
defer close(done)
|
||||
rejectHdr := parseHeader(b)
|
||||
Expect(rejectHdr.Type).To(Equal(protocol.PacketTypeInitial))
|
||||
return len(b), nil
|
||||
})
|
||||
serv.handleInitialImpl(
|
||||
&receivedPacket{buffer: getPacketBuffer()},
|
||||
&wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), Version: protocol.Version1},
|
||||
)
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("accepts new connections when the handshake completes", func() {
|
||||
conn := NewMockQUICConn(mockCtrl)
|
||||
|
||||
@@ -1015,6 +1092,7 @@ var _ = Describe("Server", func() {
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ ConnectionIDGenerator,
|
||||
_ protocol.StatelessResetToken,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
@@ -1032,12 +1110,11 @@ var _ = Describe("Server", func() {
|
||||
return conn
|
||||
}
|
||||
phm.EXPECT().Get(gomock.Any())
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||
fn()
|
||||
return true
|
||||
_, ok := fn()
|
||||
return ok
|
||||
})
|
||||
tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any())
|
||||
serv.handleInitialImpl(
|
||||
&receivedPacket{buffer: getPacketBuffer()},
|
||||
&wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})},
|
||||
@@ -1064,7 +1141,6 @@ var _ = Describe("Server", func() {
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
phm.EXPECT().CloseServer().MaxTimes(1)
|
||||
serv.Close()
|
||||
})
|
||||
|
||||
@@ -1089,6 +1165,7 @@ var _ = Describe("Server", func() {
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ ConnectionIDGenerator,
|
||||
_ protocol.StatelessResetToken,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
@@ -1106,10 +1183,10 @@ var _ = Describe("Server", func() {
|
||||
return conn
|
||||
}
|
||||
phm.EXPECT().Get(gomock.Any())
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||
fn()
|
||||
return true
|
||||
_, ok := fn()
|
||||
return ok
|
||||
})
|
||||
serv.baseServer.handleInitialImpl(
|
||||
&receivedPacket{buffer: getPacketBuffer()},
|
||||
@@ -1131,6 +1208,7 @@ var _ = Describe("Server", func() {
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ ConnectionIDGenerator,
|
||||
_ protocol.StatelessResetToken,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
@@ -1152,10 +1230,10 @@ var _ = Describe("Server", func() {
|
||||
}
|
||||
|
||||
phm.EXPECT().Get(gomock.Any()).AnyTimes()
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||
fn()
|
||||
return true
|
||||
_, ok := fn()
|
||||
return ok
|
||||
}).Times(protocol.MaxAcceptQueueSize)
|
||||
for i := 0; i < protocol.MaxAcceptQueueSize; i++ {
|
||||
serv.baseServer.handlePacket(getInitialWithRandomDestConnID())
|
||||
@@ -1194,6 +1272,7 @@ var _ = Describe("Server", func() {
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ ConnectionIDGenerator,
|
||||
_ protocol.StatelessResetToken,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
@@ -1213,10 +1292,10 @@ var _ = Describe("Server", func() {
|
||||
}
|
||||
|
||||
phm.EXPECT().Get(gomock.Any())
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||
fn()
|
||||
return true
|
||||
_, ok := fn()
|
||||
return ok
|
||||
})
|
||||
serv.baseServer.handlePacket(p)
|
||||
// make sure there are no Write calls on the packet conn
|
||||
@@ -1234,8 +1313,7 @@ var _ = Describe("Server", func() {
|
||||
Consistently(done).ShouldNot(BeClosed())
|
||||
|
||||
// make the go routine return
|
||||
phm.EXPECT().CloseServer()
|
||||
conn.EXPECT().getPerspective().MaxTimes(2) // once for every conn ID
|
||||
conn.EXPECT().getPerspective().MaxTimes(2) // initOnce for every conn ID
|
||||
Expect(serv.Close()).To(Succeed())
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
@@ -1243,6 +1321,7 @@ var _ = Describe("Server", func() {
|
||||
|
||||
Context("0-RTT", func() {
|
||||
var (
|
||||
tr *Transport
|
||||
serv *baseServer
|
||||
phm *MockPacketHandlerManager
|
||||
tracer *mocklogging.MockTracer
|
||||
@@ -1250,7 +1329,8 @@ var _ = Describe("Server", func() {
|
||||
|
||||
BeforeEach(func() {
|
||||
tracer = mocklogging.NewMockTracer(mockCtrl)
|
||||
ln, err := ListenEarly(conn, tlsConf, &Config{Tracer: tracer})
|
||||
tr = &Transport{Conn: conn, Tracer: tracer}
|
||||
ln, err := tr.ListenEarly(tlsConf, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
phm = NewMockPacketHandlerManager(mockCtrl)
|
||||
serv = ln.baseServer
|
||||
@@ -1259,7 +1339,7 @@ var _ = Describe("Server", func() {
|
||||
|
||||
AfterEach(func() {
|
||||
phm.EXPECT().CloseServer().MaxTimes(1)
|
||||
serv.Close()
|
||||
tr.Close()
|
||||
})
|
||||
|
||||
It("passes packets to existing connections", func() {
|
||||
@@ -1317,6 +1397,7 @@ var _ = Describe("Server", func() {
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ ConnectionIDGenerator,
|
||||
_ protocol.StatelessResetToken,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
@@ -1341,12 +1422,11 @@ var _ = Describe("Server", func() {
|
||||
return conn
|
||||
}
|
||||
|
||||
tracer.EXPECT().TracerForConnection(gomock.Any(), gomock.Any(), gomock.Any())
|
||||
phm.EXPECT().Get(connID)
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
|
||||
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
|
||||
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||
fn()
|
||||
return true
|
||||
_, ok := fn()
|
||||
return ok
|
||||
})
|
||||
serv.handlePacket(initial)
|
||||
Eventually(called).Should(BeClosed())
|
||||
|
||||
414
transport.go
Normal file
414
transport.go
Normal file
@@ -0,0 +1,414 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/wire"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
"github.com/quic-go/quic-go/logging"
|
||||
)
|
||||
|
||||
type Transport struct {
|
||||
// A single net.PacketConn can only be handled by one Transport.
|
||||
// Bad things will happen if passed to multiple Transports.
|
||||
//
|
||||
// If the connection satisfies the OOBCapablePacketConn interface
|
||||
// (as a net.UDPConn does), ECN and packet info support will be enabled.
|
||||
// In this case, optimized syscalls might be used, skipping the
|
||||
// ReadFrom and WriteTo calls to read / write packets.
|
||||
Conn net.PacketConn
|
||||
|
||||
// The length of the connection ID in bytes.
|
||||
// It can be 0, or any value between 4 and 18.
|
||||
// If unset, a 4 byte connection ID will be used.
|
||||
ConnectionIDLength int
|
||||
|
||||
// Use for generating new connection IDs.
|
||||
// This allows the application to control of the connection IDs used,
|
||||
// which allows routing / load balancing based on connection IDs.
|
||||
// All Connection IDs returned by the ConnectionIDGenerator MUST
|
||||
// have the same length.
|
||||
ConnectionIDGenerator ConnectionIDGenerator
|
||||
|
||||
// The StatelessResetKey is used to generate stateless reset tokens.
|
||||
// If no key is configured, sending of stateless resets is disabled.
|
||||
StatelessResetKey *StatelessResetKey
|
||||
|
||||
// A Tracer traces events that don't belong to a single QUIC connection.
|
||||
Tracer logging.Tracer
|
||||
|
||||
handlerMap packetHandlerManager
|
||||
|
||||
mutex sync.Mutex
|
||||
initOnce sync.Once
|
||||
initErr error
|
||||
|
||||
// Set in init.
|
||||
// If no ConnectionIDGenerator is set, this is the ConnectionIDLength.
|
||||
connIDLen int
|
||||
// Set in init.
|
||||
// If no ConnectionIDGenerator is set, this is set to a default.
|
||||
connIDGenerator ConnectionIDGenerator
|
||||
|
||||
server unknownPacketHandler
|
||||
|
||||
conn rawConn
|
||||
|
||||
closeQueue chan closePacket
|
||||
|
||||
listening chan struct{} // is closed when listen returns
|
||||
closed bool
|
||||
createdConn bool
|
||||
isSingleUse bool // was created for a single server or client, i.e. by calling quic.Listen or quic.Dial
|
||||
|
||||
logger utils.Logger
|
||||
}
|
||||
|
||||
// Listen starts listening for incoming QUIC connections.
|
||||
// There can only be a single listener on any net.PacketConn.
|
||||
// Listen may only be called again after the current Listener was closed.
|
||||
func (t *Transport) Listen(tlsConf *tls.Config, conf *Config) (*Listener, error) {
|
||||
if tlsConf == nil {
|
||||
return nil, errors.New("quic: tls.Config not set")
|
||||
}
|
||||
if err := validateConfig(conf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
t.mutex.Lock()
|
||||
defer t.mutex.Unlock()
|
||||
|
||||
if t.server != nil {
|
||||
return nil, errListenerAlreadySet
|
||||
}
|
||||
conf = populateServerConfig(conf)
|
||||
if err := t.init(true); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s, err := newServer(t.conn, t.handlerMap, t.connIDGenerator, tlsConf, conf, t.Tracer, t.closeServer, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t.server = s
|
||||
return &Listener{baseServer: s}, nil
|
||||
}
|
||||
|
||||
// ListenEarly starts listening for incoming QUIC connections.
|
||||
// There can only be a single listener on any net.PacketConn.
|
||||
// Listen may only be called again after the current Listener was closed.
|
||||
func (t *Transport) ListenEarly(tlsConf *tls.Config, conf *Config) (*EarlyListener, error) {
|
||||
if tlsConf == nil {
|
||||
return nil, errors.New("quic: tls.Config not set")
|
||||
}
|
||||
if err := validateConfig(conf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
t.mutex.Lock()
|
||||
defer t.mutex.Unlock()
|
||||
|
||||
if t.server != nil {
|
||||
return nil, errListenerAlreadySet
|
||||
}
|
||||
conf = populateServerConfig(conf)
|
||||
if err := t.init(true); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s, err := newServer(t.conn, t.handlerMap, t.connIDGenerator, tlsConf, conf, t.Tracer, t.closeServer, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t.server = s
|
||||
return &EarlyListener{baseServer: s}, nil
|
||||
}
|
||||
|
||||
// Dial dials a new connection to a remote host (not using 0-RTT).
|
||||
func (t *Transport) Dial(ctx context.Context, addr net.Addr, tlsConf *tls.Config, conf *Config) (Connection, error) {
|
||||
if err := validateConfig(conf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conf = populateConfig(conf)
|
||||
if err := t.init(false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var onClose func()
|
||||
if t.isSingleUse {
|
||||
onClose = func() { t.Close() }
|
||||
}
|
||||
return dial(ctx, newSendConn(t.conn, addr, nil), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, false)
|
||||
}
|
||||
|
||||
// DialEarly dials a new connection, attempting to use 0-RTT if possible.
|
||||
func (t *Transport) DialEarly(ctx context.Context, addr net.Addr, tlsConf *tls.Config, conf *Config) (EarlyConnection, error) {
|
||||
if err := validateConfig(conf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conf = populateConfig(conf)
|
||||
if err := t.init(false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var onClose func()
|
||||
if t.isSingleUse {
|
||||
onClose = func() { t.Close() }
|
||||
}
|
||||
return dial(ctx, newSendConn(t.conn, addr, nil), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, true)
|
||||
}
|
||||
|
||||
func setReceiveBuffer(c net.PacketConn, logger utils.Logger) error {
|
||||
conn, ok := c.(interface{ SetReadBuffer(int) error })
|
||||
if !ok {
|
||||
return errors.New("connection doesn't allow setting of receive buffer size. Not a *net.UDPConn?")
|
||||
}
|
||||
size, err := inspectReadBuffer(c)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to determine receive buffer size: %w", err)
|
||||
}
|
||||
if size >= protocol.DesiredReceiveBufferSize {
|
||||
logger.Debugf("Conn has receive buffer of %d kiB (wanted: at least %d kiB)", size/1024, protocol.DesiredReceiveBufferSize/1024)
|
||||
return nil
|
||||
}
|
||||
if err := conn.SetReadBuffer(protocol.DesiredReceiveBufferSize); err != nil {
|
||||
return fmt.Errorf("failed to increase receive buffer size: %w", err)
|
||||
}
|
||||
newSize, err := inspectReadBuffer(c)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to determine receive buffer size: %w", err)
|
||||
}
|
||||
if newSize == size {
|
||||
return fmt.Errorf("failed to increase receive buffer size (wanted: %d kiB, got %d kiB)", protocol.DesiredReceiveBufferSize/1024, newSize/1024)
|
||||
}
|
||||
if newSize < protocol.DesiredReceiveBufferSize {
|
||||
return fmt.Errorf("failed to sufficiently increase receive buffer size (was: %d kiB, wanted: %d kiB, got: %d kiB)", size/1024, protocol.DesiredReceiveBufferSize/1024, newSize/1024)
|
||||
}
|
||||
logger.Debugf("Increased receive buffer size to %d kiB", newSize/1024)
|
||||
return nil
|
||||
}
|
||||
|
||||
// only print warnings about the UDP receive buffer size once
|
||||
var receiveBufferWarningOnce sync.Once
|
||||
|
||||
func (t *Transport) init(isServer bool) error {
|
||||
t.initOnce.Do(func() {
|
||||
getMultiplexer().AddConn(t.Conn)
|
||||
|
||||
conn, err := wrapConn(t.Conn)
|
||||
if err != nil {
|
||||
t.initErr = err
|
||||
return
|
||||
}
|
||||
|
||||
t.logger = utils.DefaultLogger // TODO: make this configurable
|
||||
t.conn = conn
|
||||
t.handlerMap = newPacketHandlerMap(t.StatelessResetKey, t.enqueueClosePacket, t.logger)
|
||||
t.listening = make(chan struct{})
|
||||
|
||||
t.closeQueue = make(chan closePacket, 4)
|
||||
|
||||
if t.ConnectionIDGenerator != nil {
|
||||
t.connIDGenerator = t.ConnectionIDGenerator
|
||||
t.connIDLen = t.ConnectionIDGenerator.ConnectionIDLen()
|
||||
} else {
|
||||
connIDLen := t.ConnectionIDLength
|
||||
if t.ConnectionIDLength == 0 && (!t.isSingleUse || isServer) {
|
||||
connIDLen = protocol.DefaultConnectionIDLength
|
||||
}
|
||||
t.connIDLen = connIDLen
|
||||
t.connIDGenerator = &protocol.DefaultConnectionIDGenerator{ConnLen: t.connIDLen}
|
||||
}
|
||||
|
||||
go t.listen(conn)
|
||||
go t.runCloseQueue()
|
||||
})
|
||||
return t.initErr
|
||||
}
|
||||
|
||||
func (t *Transport) enqueueClosePacket(p closePacket) {
|
||||
select {
|
||||
case t.closeQueue <- p:
|
||||
default:
|
||||
// Oops, we're backlogged.
|
||||
// Just drop the packet, sending CONNECTION_CLOSE copies is best effort anyway.
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Transport) runCloseQueue() {
|
||||
for {
|
||||
select {
|
||||
case <-t.listening:
|
||||
return
|
||||
case p := <-t.closeQueue:
|
||||
t.conn.WritePacket(p.payload, p.addr, p.info.OOB())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes the underlying connection and waits until listen has returned.
|
||||
// It is invalid to start new listeners or connections after that.
|
||||
func (t *Transport) Close() error {
|
||||
t.close(errors.New("closing"))
|
||||
if t.createdConn {
|
||||
if err := t.conn.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
t.conn.SetReadDeadline(time.Now())
|
||||
defer func() { t.conn.SetReadDeadline(time.Time{}) }()
|
||||
}
|
||||
<-t.listening // wait until listening returns
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *Transport) closeServer() {
|
||||
t.handlerMap.CloseServer()
|
||||
t.mutex.Lock()
|
||||
t.server = nil
|
||||
if t.isSingleUse {
|
||||
t.closed = true
|
||||
}
|
||||
t.mutex.Unlock()
|
||||
if t.createdConn {
|
||||
t.Conn.Close()
|
||||
}
|
||||
if t.isSingleUse {
|
||||
t.conn.SetReadDeadline(time.Now())
|
||||
defer func() { t.conn.SetReadDeadline(time.Time{}) }()
|
||||
<-t.listening // wait until listening returns
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Transport) close(e error) {
|
||||
t.mutex.Lock()
|
||||
defer t.mutex.Unlock()
|
||||
if t.closed {
|
||||
return
|
||||
}
|
||||
|
||||
t.handlerMap.Close(e)
|
||||
if t.server != nil {
|
||||
t.server.setCloseError(e)
|
||||
}
|
||||
t.closed = true
|
||||
}
|
||||
|
||||
func (t *Transport) listen(conn rawConn) {
|
||||
defer close(t.listening)
|
||||
defer getMultiplexer().RemoveConn(t.Conn)
|
||||
|
||||
if err := setReceiveBuffer(t.Conn, t.logger); err != nil {
|
||||
if !strings.Contains(err.Error(), "use of closed network connection") {
|
||||
receiveBufferWarningOnce.Do(func() {
|
||||
if disable, _ := strconv.ParseBool(os.Getenv("QUIC_GO_DISABLE_RECEIVE_BUFFER_WARNING")); disable {
|
||||
return
|
||||
}
|
||||
log.Printf("%s. See https://github.com/quic-go/quic-go/wiki/UDP-Receive-Buffer-Size for details.", err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
for {
|
||||
p, err := conn.ReadPacket()
|
||||
//nolint:staticcheck // SA1019 ignore this!
|
||||
// TODO: This code is used to ignore wsa errors on Windows.
|
||||
// Since net.Error.Temporary is deprecated as of Go 1.18, we should find a better solution.
|
||||
// See https://github.com/quic-go/quic-go/issues/1737 for details.
|
||||
if nerr, ok := err.(net.Error); ok && nerr.Temporary() {
|
||||
t.mutex.Lock()
|
||||
closed := t.closed
|
||||
t.mutex.Unlock()
|
||||
if closed {
|
||||
return
|
||||
}
|
||||
t.logger.Debugf("Temporary error reading from conn: %w", err)
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
t.close(err)
|
||||
return
|
||||
}
|
||||
t.handlePacket(p)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Transport) handlePacket(p *receivedPacket) {
|
||||
connID, err := wire.ParseConnectionID(p.data, t.connIDLen)
|
||||
if err != nil {
|
||||
t.logger.Debugf("error parsing connection ID on packet from %s: %s", p.remoteAddr, err)
|
||||
if t.Tracer != nil {
|
||||
t.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError)
|
||||
}
|
||||
p.buffer.MaybeRelease()
|
||||
return
|
||||
}
|
||||
|
||||
if isStatelessReset := t.maybeHandleStatelessReset(p.data); isStatelessReset {
|
||||
return
|
||||
}
|
||||
if handler, ok := t.handlerMap.Get(connID); ok {
|
||||
handler.handlePacket(p)
|
||||
return
|
||||
}
|
||||
if !wire.IsLongHeaderPacket(p.data[0]) {
|
||||
go t.maybeSendStatelessReset(p, connID)
|
||||
return
|
||||
}
|
||||
|
||||
t.mutex.Lock()
|
||||
defer t.mutex.Unlock()
|
||||
if t.server == nil { // no server set
|
||||
t.logger.Debugf("received a packet with an unexpected connection ID %s", connID)
|
||||
return
|
||||
}
|
||||
t.server.handlePacket(p)
|
||||
}
|
||||
|
||||
func (t *Transport) maybeSendStatelessReset(p *receivedPacket, connID protocol.ConnectionID) {
|
||||
defer p.buffer.Release()
|
||||
if t.StatelessResetKey == nil {
|
||||
return
|
||||
}
|
||||
// Don't send a stateless reset in response to very small packets.
|
||||
// This includes packets that could be stateless resets.
|
||||
if len(p.data) <= protocol.MinStatelessResetSize {
|
||||
return
|
||||
}
|
||||
token := t.handlerMap.GetStatelessResetToken(connID)
|
||||
t.logger.Debugf("Sending stateless reset to %s (connection ID: %s). Token: %#x", p.remoteAddr, connID, token)
|
||||
data := make([]byte, protocol.MinStatelessResetSize-16, protocol.MinStatelessResetSize)
|
||||
rand.Read(data)
|
||||
data[0] = (data[0] & 0x7f) | 0x40
|
||||
data = append(data, token[:]...)
|
||||
if _, err := t.conn.WritePacket(data, p.remoteAddr, p.info.OOB()); err != nil {
|
||||
t.logger.Debugf("Error sending Stateless Reset: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Transport) maybeHandleStatelessReset(data []byte) bool {
|
||||
// stateless resets are always short header packets
|
||||
if wire.IsLongHeaderPacket(data[0]) {
|
||||
return false
|
||||
}
|
||||
if len(data) < 17 /* type byte + 16 bytes for the reset token */ {
|
||||
return false
|
||||
}
|
||||
|
||||
token := *(*protocol.StatelessResetToken)(data[len(data)-16:])
|
||||
if conn, ok := t.handlerMap.GetByResetToken(token); ok {
|
||||
t.logger.Debugf("Received a stateless reset with token %#x. Closing connection.", token)
|
||||
go conn.destroy(&StatelessResetError{Token: token})
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
291
transport_test.go
Normal file
291
transport_test.go
Normal file
@@ -0,0 +1,291 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
mocklogging "github.com/quic-go/quic-go/internal/mocks/logging"
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/wire"
|
||||
"github.com/quic-go/quic-go/logging"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Transport", func() {
|
||||
type packetToRead struct {
|
||||
addr net.Addr
|
||||
data []byte
|
||||
err error
|
||||
}
|
||||
|
||||
getPacketWithPacketType := func(connID protocol.ConnectionID, t protocol.PacketType, length protocol.ByteCount) []byte {
|
||||
b, err := (&wire.ExtendedHeader{
|
||||
Header: wire.Header{
|
||||
Type: t,
|
||||
DestConnectionID: connID,
|
||||
Length: length,
|
||||
Version: protocol.Version1,
|
||||
},
|
||||
PacketNumberLen: protocol.PacketNumberLen2,
|
||||
}).Append(nil, protocol.Version1)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
return b
|
||||
}
|
||||
|
||||
getPacket := func(connID protocol.ConnectionID) []byte {
|
||||
return getPacketWithPacketType(connID, protocol.PacketTypeHandshake, 2)
|
||||
}
|
||||
|
||||
newMockPacketConn := func(packetChan <-chan packetToRead) *MockPacketConn {
|
||||
conn := NewMockPacketConn(mockCtrl)
|
||||
conn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes()
|
||||
conn.EXPECT().ReadFrom(gomock.Any()).DoAndReturn(func(b []byte) (int, net.Addr, error) {
|
||||
p, ok := <-packetChan
|
||||
if !ok {
|
||||
return 0, nil, errors.New("closed")
|
||||
}
|
||||
return copy(b, p.data), p.addr, p.err
|
||||
}).AnyTimes()
|
||||
// for shutdown
|
||||
conn.EXPECT().SetReadDeadline(gomock.Any()).AnyTimes()
|
||||
return conn
|
||||
}
|
||||
|
||||
It("handles packets for different packet handlers on the same packet conn", func() {
|
||||
packetChan := make(chan packetToRead)
|
||||
tr := &Transport{Conn: newMockPacketConn(packetChan)}
|
||||
tr.init(true)
|
||||
phm := NewMockPacketHandlerManager(mockCtrl)
|
||||
tr.handlerMap = phm
|
||||
connID1 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
|
||||
connID2 := protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1})
|
||||
|
||||
handled := make(chan struct{}, 2)
|
||||
phm.EXPECT().Get(connID1).DoAndReturn(func(protocol.ConnectionID) (packetHandler, bool) {
|
||||
h := NewMockPacketHandler(mockCtrl)
|
||||
h.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
|
||||
defer GinkgoRecover()
|
||||
connID, err := wire.ParseConnectionID(p.data, 0)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(connID).To(Equal(connID1))
|
||||
handled <- struct{}{}
|
||||
})
|
||||
return h, true
|
||||
})
|
||||
phm.EXPECT().Get(connID2).DoAndReturn(func(protocol.ConnectionID) (packetHandler, bool) {
|
||||
h := NewMockPacketHandler(mockCtrl)
|
||||
h.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
|
||||
defer GinkgoRecover()
|
||||
connID, err := wire.ParseConnectionID(p.data, 0)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(connID).To(Equal(connID2))
|
||||
handled <- struct{}{}
|
||||
})
|
||||
return h, true
|
||||
})
|
||||
|
||||
packetChan <- packetToRead{data: getPacket(connID1)}
|
||||
packetChan <- packetToRead{data: getPacket(connID2)}
|
||||
|
||||
Eventually(handled).Should(Receive())
|
||||
Eventually(handled).Should(Receive())
|
||||
|
||||
// shutdown
|
||||
phm.EXPECT().Close(gomock.Any())
|
||||
close(packetChan)
|
||||
tr.Close()
|
||||
})
|
||||
|
||||
It("closes listeners", func() {
|
||||
packetChan := make(chan packetToRead)
|
||||
tr := &Transport{Conn: newMockPacketConn(packetChan)}
|
||||
defer tr.Close()
|
||||
ln, err := tr.Listen(&tls.Config{}, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
phm := NewMockPacketHandlerManager(mockCtrl)
|
||||
tr.handlerMap = phm
|
||||
|
||||
phm.EXPECT().CloseServer()
|
||||
Expect(ln.Close()).To(Succeed())
|
||||
|
||||
// shutdown
|
||||
phm.EXPECT().Close(gomock.Any())
|
||||
close(packetChan)
|
||||
tr.Close()
|
||||
})
|
||||
|
||||
It("drops unparseable packets", func() {
|
||||
addr := &net.UDPAddr{IP: net.IPv4(9, 8, 7, 6), Port: 1234}
|
||||
packetChan := make(chan packetToRead)
|
||||
tracer := mocklogging.NewMockTracer(mockCtrl)
|
||||
tr := &Transport{
|
||||
Conn: newMockPacketConn(packetChan),
|
||||
ConnectionIDLength: 10,
|
||||
Tracer: tracer,
|
||||
}
|
||||
tr.init(true)
|
||||
dropped := make(chan struct{})
|
||||
tracer.EXPECT().DroppedPacket(addr, logging.PacketTypeNotDetermined, protocol.ByteCount(4), logging.PacketDropHeaderParseError).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { close(dropped) })
|
||||
packetChan <- packetToRead{
|
||||
addr: addr,
|
||||
data: []byte{0, 1, 2, 3},
|
||||
}
|
||||
Eventually(dropped).Should(BeClosed())
|
||||
|
||||
// shutdown
|
||||
close(packetChan)
|
||||
tr.Close()
|
||||
})
|
||||
|
||||
It("closes when reading from the conn fails", func() {
|
||||
packetChan := make(chan packetToRead)
|
||||
tr := Transport{Conn: newMockPacketConn(packetChan)}
|
||||
defer tr.Close()
|
||||
phm := NewMockPacketHandlerManager(mockCtrl)
|
||||
tr.init(true)
|
||||
tr.handlerMap = phm
|
||||
|
||||
done := make(chan struct{})
|
||||
phm.EXPECT().Close(gomock.Any()).Do(func(error) { close(done) })
|
||||
packetChan <- packetToRead{err: errors.New("read failed")}
|
||||
Eventually(done).Should(BeClosed())
|
||||
|
||||
// shutdown
|
||||
close(packetChan)
|
||||
tr.Close()
|
||||
})
|
||||
|
||||
It("continues listening after temporary errors", func() {
|
||||
packetChan := make(chan packetToRead)
|
||||
tr := Transport{Conn: newMockPacketConn(packetChan)}
|
||||
defer tr.Close()
|
||||
phm := NewMockPacketHandlerManager(mockCtrl)
|
||||
tr.init(true)
|
||||
tr.handlerMap = phm
|
||||
|
||||
tempErr := deadlineError{}
|
||||
Expect(tempErr.Temporary()).To(BeTrue())
|
||||
packetChan <- packetToRead{err: tempErr}
|
||||
// don't expect any calls to phm.Close
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// shutdown
|
||||
phm.EXPECT().Close(gomock.Any())
|
||||
close(packetChan)
|
||||
tr.Close()
|
||||
})
|
||||
|
||||
It("handles short header packets resets", func() {
|
||||
connID := protocol.ParseConnectionID([]byte{2, 3, 4, 5})
|
||||
packetChan := make(chan packetToRead)
|
||||
tr := Transport{
|
||||
Conn: newMockPacketConn(packetChan),
|
||||
ConnectionIDLength: connID.Len(),
|
||||
}
|
||||
tr.init(true)
|
||||
defer tr.Close()
|
||||
phm := NewMockPacketHandlerManager(mockCtrl)
|
||||
tr.handlerMap = phm
|
||||
|
||||
var token protocol.StatelessResetToken
|
||||
rand.Read(token[:])
|
||||
|
||||
var b []byte
|
||||
b, err := wire.AppendShortHeader(b, connID, 1337, 2, protocol.KeyPhaseOne)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
b = append(b, token[:]...)
|
||||
conn := NewMockPacketHandler(mockCtrl)
|
||||
gomock.InOrder(
|
||||
phm.EXPECT().GetByResetToken(token),
|
||||
phm.EXPECT().Get(connID).Return(conn, true),
|
||||
conn.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
|
||||
Expect(p.data).To(Equal(b))
|
||||
Expect(p.rcvTime).To(BeTemporally("~", time.Now(), time.Second))
|
||||
}),
|
||||
)
|
||||
packetChan <- packetToRead{data: b}
|
||||
|
||||
// shutdown
|
||||
phm.EXPECT().Close(gomock.Any())
|
||||
close(packetChan)
|
||||
tr.Close()
|
||||
})
|
||||
|
||||
It("handles stateless resets", func() {
|
||||
connID := protocol.ParseConnectionID([]byte{2, 3, 4, 5})
|
||||
packetChan := make(chan packetToRead)
|
||||
tr := Transport{Conn: newMockPacketConn(packetChan)}
|
||||
tr.init(true)
|
||||
defer tr.Close()
|
||||
phm := NewMockPacketHandlerManager(mockCtrl)
|
||||
tr.handlerMap = phm
|
||||
|
||||
var token protocol.StatelessResetToken
|
||||
rand.Read(token[:])
|
||||
|
||||
var b []byte
|
||||
b, err := wire.AppendShortHeader(b, connID, 1337, 2, protocol.KeyPhaseOne)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
b = append(b, token[:]...)
|
||||
conn := NewMockPacketHandler(mockCtrl)
|
||||
gomock.InOrder(
|
||||
phm.EXPECT().GetByResetToken(token).Return(conn, true),
|
||||
conn.EXPECT().destroy(gomock.Any()).Do(func(err error) {
|
||||
Expect(err).To(MatchError(&StatelessResetError{Token: token}))
|
||||
}),
|
||||
)
|
||||
packetChan <- packetToRead{data: b}
|
||||
|
||||
// shutdown
|
||||
phm.EXPECT().Close(gomock.Any())
|
||||
close(packetChan)
|
||||
tr.Close()
|
||||
})
|
||||
|
||||
It("sends stateless resets", func() {
|
||||
connID := protocol.ParseConnectionID([]byte{2, 3, 4, 5})
|
||||
packetChan := make(chan packetToRead)
|
||||
conn := newMockPacketConn(packetChan)
|
||||
tr := Transport{
|
||||
Conn: conn,
|
||||
StatelessResetKey: &StatelessResetKey{1, 2, 3, 4},
|
||||
ConnectionIDLength: connID.Len(),
|
||||
}
|
||||
tr.init(true)
|
||||
defer tr.Close()
|
||||
phm := NewMockPacketHandlerManager(mockCtrl)
|
||||
tr.handlerMap = phm
|
||||
|
||||
var b []byte
|
||||
b, err := wire.AppendShortHeader(b, connID, 1337, 2, protocol.KeyPhaseOne)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
b = append(b, make([]byte, protocol.MinStatelessResetSize-len(b)+1)...)
|
||||
|
||||
var token protocol.StatelessResetToken
|
||||
rand.Read(token[:])
|
||||
written := make(chan struct{})
|
||||
gomock.InOrder(
|
||||
phm.EXPECT().GetByResetToken(gomock.Any()),
|
||||
phm.EXPECT().Get(connID),
|
||||
phm.EXPECT().GetStatelessResetToken(connID).Return(token),
|
||||
conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).Do(func(b []byte, _ net.Addr) {
|
||||
defer close(written)
|
||||
Expect(bytes.Contains(b, token[:])).To(BeTrue())
|
||||
}),
|
||||
)
|
||||
packetChan <- packetToRead{data: b}
|
||||
Eventually(written).Should(BeClosed())
|
||||
|
||||
// shutdown
|
||||
phm.EXPECT().Close(gomock.Any())
|
||||
close(packetChan)
|
||||
tr.Close()
|
||||
})
|
||||
})
|
||||
Reference in New Issue
Block a user