Merge pull request #3794 from quic-go/new-api

introduce a Transport
This commit is contained in:
Marten Seemann
2023-05-02 16:08:04 +02:00
committed by GitHub
64 changed files with 1940 additions and 1855 deletions

212
client.go
View File

@@ -4,7 +4,6 @@ import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"github.com/quic-go/quic-go/internal/protocol"
@@ -13,20 +12,19 @@ import (
)
type client struct {
sconn sendConn
// If the client is created with DialAddr, we create a packet conn.
// If it is started with Dial, we take a packet conn as a parameter.
createdPacketConn bool
sendConn sendConn
use0RTT bool
packetHandlers packetHandlerManager
onClose func()
tlsConf *tls.Config
config *Config
srcConnID protocol.ConnectionID
destConnID protocol.ConnectionID
connIDGenerator ConnectionIDGenerator
srcConnID protocol.ConnectionID
destConnID protocol.ConnectionID
initialPacketNumber protocol.PacketNumber
hasNegotiatedVersion bool
@@ -46,32 +44,58 @@ var generateConnectionIDForInitial = protocol.GenerateConnectionIDForInitial
// DialAddr establishes a new QUIC connection to a server.
// It uses a new UDP connection and closes this connection when the QUIC connection is closed.
// The hostname for SNI is taken from the given address.
func DialAddr(ctx context.Context, addr string, tlsConf *tls.Config, config *Config) (Connection, error) {
return dialAddrContext(ctx, addr, tlsConf, config, false)
}
// DialAddrEarly establishes a new 0-RTT QUIC connection to a server.
// It uses a new UDP connection and closes this connection when the QUIC connection is closed.
func DialAddrEarly(ctx context.Context, addr string, tlsConf *tls.Config, config *Config) (EarlyConnection, error) {
conn, err := dialAddrContext(ctx, addr, tlsConf, config, true)
if err != nil {
return nil, err
}
utils.Logger.WithPrefix(utils.DefaultLogger, "client").Debugf("Returning early connection")
return conn, nil
}
func dialAddrContext(ctx context.Context, addr string, tlsConf *tls.Config, config *Config, use0RTT bool) (quicConn, error) {
udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
}
func DialAddr(ctx context.Context, addr string, tlsConf *tls.Config, conf *Config) (Connection, error) {
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
return nil, err
}
return dialContext(ctx, udpConn, udpAddr, tlsConf, config, use0RTT, true)
udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
}
dl, err := setupTransport(udpConn, tlsConf, true)
if err != nil {
return nil, err
}
return dl.Dial(ctx, udpAddr, tlsConf, conf)
}
// DialAddrEarly establishes a new 0-RTT QUIC connection to a server.
// It uses a new UDP connection and closes this connection when the QUIC connection is closed.
func DialAddrEarly(ctx context.Context, addr string, tlsConf *tls.Config, conf *Config) (EarlyConnection, error) {
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
return nil, err
}
udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
}
dl, err := setupTransport(udpConn, tlsConf, true)
if err != nil {
return nil, err
}
conn, err := dl.DialEarly(ctx, udpAddr, tlsConf, conf)
if err != nil {
dl.Close()
return nil, err
}
return conn, nil
}
// DialEarly establishes a new 0-RTT QUIC connection to a server using a net.PacketConn using the provided context.
// See DialEarly for details.
func DialEarly(ctx context.Context, c net.PacketConn, addr net.Addr, tlsConf *tls.Config, conf *Config) (EarlyConnection, error) {
dl, err := setupTransport(c, tlsConf, false)
if err != nil {
return nil, err
}
conn, err := dl.DialEarly(ctx, addr, tlsConf, conf)
if err != nil {
dl.Close()
return nil, err
}
return conn, nil
}
// Dial establishes a new QUIC connection to a server using a net.PacketConn. If
@@ -79,34 +103,42 @@ func dialAddrContext(ctx context.Context, addr string, tlsConf *tls.Config, conf
// does), ECN and packet info support will be enabled. In this case, ReadMsgUDP
// and WriteMsgUDP will be used instead of ReadFrom and WriteTo to read/write
// packets.
// The same PacketConn can be used for multiple calls to Dial and Listen.
// QUIC connection IDs are used for demultiplexing the different connections.
// The tls.Config must define an application protocol (using NextProtos).
func Dial(ctx context.Context, pconn net.PacketConn, addr net.Addr, tlsConf *tls.Config, config *Config) (Connection, error) {
return dialContext(ctx, pconn, addr, tlsConf, config, false, false)
}
// DialEarly establishes a new 0-RTT QUIC connection to a server using a net.PacketConn.
// The same PacketConn can be used for multiple calls to Dial and Listen,
// QUIC connection IDs are used for demultiplexing the different connections.
// The tls.Config must define an application protocol (using NextProtos).
func DialEarly(ctx context.Context, pconn net.PacketConn, addr net.Addr, tlsConf *tls.Config, config *Config) (EarlyConnection, error) {
return dialContext(ctx, pconn, addr, tlsConf, config, true, false)
}
func dialContext(ctx context.Context, pconn net.PacketConn, addr net.Addr, tlsConf *tls.Config, config *Config, use0RTT bool, createdPacketConn bool) (quicConn, error) {
if tlsConf == nil {
return nil, errors.New("quic: tls.Config not set")
}
if err := validateConfig(config); err != nil {
return nil, err
}
config = populateClientConfig(config, createdPacketConn)
packetHandlers, err := getMultiplexer().AddConn(pconn, config.ConnectionIDGenerator.ConnectionIDLen(), config.StatelessResetKey, config.Tracer)
func Dial(ctx context.Context, c net.PacketConn, addr net.Addr, tlsConf *tls.Config, conf *Config) (Connection, error) {
dl, err := setupTransport(c, tlsConf, false)
if err != nil {
return nil, err
}
c, err := newClient(pconn, addr, config, tlsConf, use0RTT, createdPacketConn)
conn, err := dl.Dial(ctx, addr, tlsConf, conf)
if err != nil {
dl.Close()
return nil, err
}
return conn, nil
}
func setupTransport(c net.PacketConn, tlsConf *tls.Config, createdPacketConn bool) (*Transport, error) {
if tlsConf == nil {
return nil, errors.New("quic: tls.Config not set")
}
return &Transport{
Conn: c,
createdConn: createdPacketConn,
isSingleUse: true,
}, nil
}
func dial(
ctx context.Context,
conn sendConn,
connIDGenerator ConnectionIDGenerator,
packetHandlers packetHandlerManager,
tlsConf *tls.Config,
config *Config,
onClose func(),
use0RTT bool,
) (quicConn, error) {
c, err := newClient(conn, connIDGenerator, config, tlsConf, onClose, use0RTT)
if err != nil {
return nil, err
}
@@ -114,14 +146,10 @@ func dialContext(ctx context.Context, pconn net.PacketConn, addr net.Addr, tlsCo
c.tracingID = nextConnTracingID()
if c.config.Tracer != nil {
c.tracer = c.config.Tracer.TracerForConnection(
context.WithValue(ctx, ConnectionTracingKey, c.tracingID),
protocol.PerspectiveClient,
c.destConnID,
)
c.tracer = c.config.Tracer(context.WithValue(ctx, ConnectionTracingKey, c.tracingID), protocol.PerspectiveClient, c.destConnID)
}
if c.tracer != nil {
c.tracer.StartedConnection(c.sconn.LocalAddr(), c.sconn.RemoteAddr(), c.srcConnID, c.destConnID)
c.tracer.StartedConnection(c.sendConn.LocalAddr(), c.sendConn.RemoteAddr(), c.srcConnID, c.destConnID)
}
if err := c.dial(ctx); err != nil {
return nil, err
@@ -129,23 +157,14 @@ func dialContext(ctx context.Context, pconn net.PacketConn, addr net.Addr, tlsCo
return c.conn, nil
}
func newClient(pconn net.PacketConn, remoteAddr net.Addr, config *Config, tlsConf *tls.Config, use0RTT bool, createdPacketConn bool) (*client, error) {
func newClient(sendConn sendConn, connIDGenerator ConnectionIDGenerator, config *Config, tlsConf *tls.Config, onClose func(), use0RTT bool) (*client, error) {
if tlsConf == nil {
tlsConf = &tls.Config{}
} else {
tlsConf = tlsConf.Clone()
}
// check that all versions are actually supported
if config != nil {
for _, v := range config.Versions {
if !protocol.IsValidVersion(v) {
return nil, fmt.Errorf("%s is not a valid QUIC version", v)
}
}
}
srcConnID, err := config.ConnectionIDGenerator.GenerateConnectionID()
srcConnID, err := connIDGenerator.GenerateConnectionID()
if err != nil {
return nil, err
}
@@ -154,28 +173,30 @@ func newClient(pconn net.PacketConn, remoteAddr net.Addr, config *Config, tlsCon
return nil, err
}
c := &client{
srcConnID: srcConnID,
destConnID: destConnID,
sconn: newSendPconn(pconn, remoteAddr),
createdPacketConn: createdPacketConn,
use0RTT: use0RTT,
tlsConf: tlsConf,
config: config,
version: config.Versions[0],
handshakeChan: make(chan struct{}),
logger: utils.DefaultLogger.WithPrefix("client"),
connIDGenerator: connIDGenerator,
srcConnID: srcConnID,
destConnID: destConnID,
sendConn: sendConn,
use0RTT: use0RTT,
onClose: onClose,
tlsConf: tlsConf,
config: config,
version: config.Versions[0],
handshakeChan: make(chan struct{}),
logger: utils.DefaultLogger.WithPrefix("client"),
}
return c, nil
}
func (c *client) dial(ctx context.Context) error {
c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.sconn.LocalAddr(), c.sconn.RemoteAddr(), c.srcConnID, c.destConnID, c.version)
c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.sendConn.LocalAddr(), c.sendConn.RemoteAddr(), c.srcConnID, c.destConnID, c.version)
c.conn = newClientConnection(
c.sconn,
c.sendConn,
c.packetHandlers,
c.destConnID,
c.srcConnID,
c.connIDGenerator,
c.config,
c.tlsConf,
c.initialPacketNumber,
@@ -189,13 +210,18 @@ func (c *client) dial(ctx context.Context) error {
c.packetHandlers.Add(c.srcConnID, c.conn)
errorChan := make(chan error, 1)
recreateChan := make(chan errCloseForRecreating)
go func() {
err := c.conn.run() // returns as soon as the connection is closed
if e := (&errCloseForRecreating{}); !errors.As(err, &e) && c.createdPacketConn {
c.packetHandlers.Destroy()
err := c.conn.run()
var recreateErr *errCloseForRecreating
if errors.As(err, &recreateErr) {
recreateChan <- *recreateErr
return
}
errorChan <- err
if c.onClose != nil {
c.onClose()
}
errorChan <- err // returns as soon as the connection is closed
}()
// only set when we're using 0-RTT
@@ -210,14 +236,12 @@ func (c *client) dial(ctx context.Context) error {
c.conn.shutdown()
return ctx.Err()
case err := <-errorChan:
var recreateErr *errCloseForRecreating
if errors.As(err, &recreateErr) {
c.initialPacketNumber = recreateErr.nextPacketNumber
c.version = recreateErr.nextVersion
c.hasNegotiatedVersion = true
return c.dial(ctx)
}
return err
case recreateErr := <-recreateChan:
c.initialPacketNumber = recreateErr.nextPacketNumber
c.version = recreateErr.nextVersion
c.hasNegotiatedVersion = true
return c.dial(ctx)
case <-earlyConnChan:
// ready to send 0-RTT data
return nil

View File

@@ -18,13 +18,16 @@ import (
. "github.com/onsi/gomega"
)
type nullMultiplexer struct{}
func (n nullMultiplexer) AddConn(indexableConn) {}
func (n nullMultiplexer) RemoveConn(indexableConn) error { return nil }
var _ = Describe("Client", func() {
var (
cl *client
packetConn *MockPacketConn
addr net.Addr
packetConn *MockSendConn
connID protocol.ConnectionID
mockMultiplexer *MockMultiplexer
origMultiplexer multiplexer
tlsConf *tls.Config
tracer *mocklogging.MockConnectionTracer
@@ -35,6 +38,7 @@ var _ = Describe("Client", func() {
runner connRunner,
destConnID protocol.ConnectionID,
srcConnID protocol.ConnectionID,
connIDGenerator ConnectionIDGenerator,
conf *Config,
tlsConf *tls.Config,
initialPacketNumber protocol.PacketNumber,
@@ -52,26 +56,28 @@ var _ = Describe("Client", func() {
connID = protocol.ParseConnectionID([]byte{0, 0, 0, 0, 0, 0, 0x13, 0x37})
originalClientConnConstructor = newClientConnection
tracer = mocklogging.NewMockConnectionTracer(mockCtrl)
tr := mocklogging.NewMockTracer(mockCtrl)
tr.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveClient, gomock.Any()).Return(tracer).MaxTimes(1)
config = &Config{Tracer: tr, Versions: []protocol.VersionNumber{protocol.Version1}}
config = &Config{
Tracer: func(ctx context.Context, perspective logging.Perspective, id ConnectionID) logging.ConnectionTracer {
return tracer
},
Versions: []protocol.VersionNumber{protocol.Version1},
}
Eventually(areConnsRunning).Should(BeFalse())
addr = &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337}
packetConn = NewMockPacketConn(mockCtrl)
packetConn = NewMockSendConn(mockCtrl)
packetConn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes()
packetConn.EXPECT().RemoteAddr().Return(&net.UDPAddr{}).AnyTimes()
cl = &client{
srcConnID: connID,
destConnID: connID,
version: protocol.Version1,
sconn: newSendPconn(packetConn, addr),
sendConn: packetConn,
tracer: tracer,
logger: utils.DefaultLogger,
}
getMultiplexer() // make the sync.Once execute
// replace the clientMuxer. getClientMultiplexer will now return the MockMultiplexer
mockMultiplexer = NewMockMultiplexer(mockCtrl)
// replace the clientMuxer. getMultiplexer will now return the nullMultiplexer
origMultiplexer = connMuxer
connMuxer = mockMultiplexer
connMuxer = &nullMultiplexer{}
})
AfterEach(func() {
@@ -100,50 +106,17 @@ var _ = Describe("Client", func() {
generateConnectionIDForInitial = origGenerateConnectionIDForInitial
})
It("resolves the address", func() {
manager := NewMockPacketHandlerManager(mockCtrl)
manager.EXPECT().Add(gomock.Any(), gomock.Any())
manager.EXPECT().Destroy()
mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)
remoteAddrChan := make(chan string, 1)
newClientConnection = func(
sconn sendConn,
_ connRunner,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ *Config,
_ *tls.Config,
_ protocol.PacketNumber,
_ bool,
_ bool,
_ logging.ConnectionTracer,
_ uint64,
_ utils.Logger,
_ protocol.VersionNumber,
) quicConn {
remoteAddrChan <- sconn.RemoteAddr().String()
conn := NewMockQUICConn(mockCtrl)
conn.EXPECT().run()
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
return conn
}
_, err := DialAddr(context.Background(), "localhost:17890", tlsConf, &Config{HandshakeIdleTimeout: time.Millisecond})
Expect(err).ToNot(HaveOccurred())
Eventually(remoteAddrChan).Should(Receive(Equal("127.0.0.1:17890")))
})
It("returns after the handshake is complete", func() {
manager := NewMockPacketHandlerManager(mockCtrl)
manager.EXPECT().Add(gomock.Any(), gomock.Any())
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)
run := make(chan struct{})
newClientConnection = func(
_ sendConn,
runner connRunner,
_ connRunner,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ ConnectionIDGenerator,
_ *Config,
_ *tls.Config,
_ protocol.PacketNumber,
@@ -162,18 +135,17 @@ var _ = Describe("Client", func() {
conn.EXPECT().HandshakeComplete().Return(c)
return conn
}
tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
s, err := Dial(context.Background(), packetConn, addr, tlsConf, config)
cl, err := newClient(packetConn, &protocol.DefaultConnectionIDGenerator{}, populateConfig(config), tlsConf, nil, false)
Expect(err).ToNot(HaveOccurred())
Expect(s).ToNot(BeNil())
cl.packetHandlers = manager
Expect(cl).ToNot(BeNil())
Expect(cl.dial(context.Background())).To(Succeed())
Eventually(run).Should(BeClosed())
})
It("returns early connections", func() {
manager := NewMockPacketHandlerManager(mockCtrl)
manager.EXPECT().Add(gomock.Any(), gomock.Any())
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)
readyChan := make(chan struct{})
done := make(chan struct{})
newClientConnection = func(
@@ -181,6 +153,7 @@ var _ = Describe("Client", func() {
runner connRunner,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ ConnectionIDGenerator,
_ *Config,
_ *tls.Config,
_ protocol.PacketNumber,
@@ -193,29 +166,23 @@ var _ = Describe("Client", func() {
) quicConn {
Expect(enable0RTT).To(BeTrue())
conn := NewMockQUICConn(mockCtrl)
conn.EXPECT().run().Do(func() { <-done })
conn.EXPECT().run().Do(func() { close(done) })
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
conn.EXPECT().earlyConnReady().Return(readyChan)
return conn
}
go func() {
defer GinkgoRecover()
defer close(done)
tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
s, err := DialEarly(context.Background(), packetConn, addr, tlsConf, config)
Expect(err).ToNot(HaveOccurred())
Expect(s).ToNot(BeNil())
}()
Consistently(done).ShouldNot(BeClosed())
close(readyChan)
cl, err := newClient(packetConn, &protocol.DefaultConnectionIDGenerator{}, populateConfig(config), tlsConf, nil, true)
Expect(err).ToNot(HaveOccurred())
cl.packetHandlers = manager
Expect(cl).ToNot(BeNil())
Expect(cl.dial(context.Background())).To(Succeed())
Eventually(done).Should(BeClosed())
})
It("returns an error that occurs while waiting for the handshake to complete", func() {
manager := NewMockPacketHandlerManager(mockCtrl)
manager.EXPECT().Add(gomock.Any(), gomock.Any())
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)
testErr := errors.New("early handshake error")
newClientConnection = func(
@@ -223,6 +190,7 @@ var _ = Describe("Client", func() {
_ connRunner,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ ConnectionIDGenerator,
_ *Config,
_ *tls.Config,
_ protocol.PacketNumber,
@@ -236,150 +204,44 @@ var _ = Describe("Client", func() {
conn := NewMockQUICConn(mockCtrl)
conn.EXPECT().run().Return(testErr)
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
conn.EXPECT().earlyConnReady().Return(make(chan struct{}))
return conn
}
tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
_, err := Dial(context.Background(), packetConn, addr, tlsConf, config)
Expect(err).To(MatchError(testErr))
})
It("closes the connection when the context is canceled", func() {
manager := NewMockPacketHandlerManager(mockCtrl)
manager.EXPECT().Add(gomock.Any(), gomock.Any())
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)
connRunning := make(chan struct{})
defer close(connRunning)
conn := NewMockQUICConn(mockCtrl)
conn.EXPECT().run().Do(func() {
<-connRunning
})
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
newClientConnection = func(
_ sendConn,
_ connRunner,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ *Config,
_ *tls.Config,
_ protocol.PacketNumber,
_ bool,
_ bool,
_ logging.ConnectionTracer,
_ uint64,
_ utils.Logger,
_ protocol.VersionNumber,
) quicConn {
return conn
}
ctx, cancel := context.WithCancel(context.Background())
dialed := make(chan struct{})
go func() {
defer GinkgoRecover()
tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
_, err := Dial(ctx, packetConn, addr, tlsConf, config)
Expect(err).To(MatchError(context.Canceled))
close(dialed)
}()
Consistently(dialed).ShouldNot(BeClosed())
conn.EXPECT().shutdown()
cancel()
Eventually(dialed).Should(BeClosed())
})
It("closes the connection when it was created by DialAddr", func() {
manager := NewMockPacketHandlerManager(mockCtrl)
mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)
manager.EXPECT().Add(gomock.Any(), gomock.Any())
var sconn sendConn
run := make(chan struct{})
connCreated := make(chan struct{})
conn := NewMockQUICConn(mockCtrl)
newClientConnection = func(
connP sendConn,
_ connRunner,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ *Config,
_ *tls.Config,
_ protocol.PacketNumber,
_ bool,
_ bool,
_ logging.ConnectionTracer,
_ uint64,
_ utils.Logger,
_ protocol.VersionNumber,
) quicConn {
sconn = connP
close(connCreated)
return conn
}
conn.EXPECT().run().Do(func() {
<-run
})
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := DialAddr(context.Background(), "localhost:1337", tlsConf, nil)
Expect(err).ToNot(HaveOccurred())
close(done)
}()
Eventually(connCreated).Should(BeClosed())
// check that the connection is not closed
Expect(sconn.Write([]byte("foobar"))).To(Succeed())
manager.EXPECT().Destroy()
close(run)
time.Sleep(50 * time.Millisecond)
Eventually(done).Should(BeClosed())
var closed bool
cl, err := newClient(packetConn, &protocol.DefaultConnectionIDGenerator{}, populateConfig(config), tlsConf, func() { closed = true }, true)
Expect(err).ToNot(HaveOccurred())
cl.packetHandlers = manager
Expect(cl).ToNot(BeNil())
Expect(cl.dial(context.Background())).To(MatchError(testErr))
Expect(closed).To(BeTrue())
})
Context("quic.Config", func() {
It("setups with the right values", func() {
srk := &StatelessResetKey{'f', 'o', 'o', 'b', 'a', 'r'}
tokenStore := NewLRUTokenStore(10, 4)
config := &Config{
HandshakeIdleTimeout: 1337 * time.Minute,
MaxIdleTimeout: 42 * time.Hour,
MaxIncomingStreams: 1234,
MaxIncomingUniStreams: 4321,
ConnectionIDLength: 13,
StatelessResetKey: srk,
TokenStore: tokenStore,
EnableDatagrams: true,
}
c := populateClientConfig(config, false)
c := populateConfig(config)
Expect(c.HandshakeIdleTimeout).To(Equal(1337 * time.Minute))
Expect(c.MaxIdleTimeout).To(Equal(42 * time.Hour))
Expect(c.MaxIncomingStreams).To(BeEquivalentTo(1234))
Expect(c.MaxIncomingUniStreams).To(BeEquivalentTo(4321))
Expect(c.ConnectionIDLength).To(Equal(13))
Expect(c.StatelessResetKey).To(Equal(srk))
Expect(c.TokenStore).To(Equal(tokenStore))
Expect(c.EnableDatagrams).To(BeTrue())
})
It("errors when the Config contains an invalid version", func() {
manager := NewMockPacketHandlerManager(mockCtrl)
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)
version := protocol.VersionNumber(0x1234)
_, err := Dial(context.Background(), packetConn, nil, tlsConf, &Config{Versions: []protocol.VersionNumber{version}})
Expect(err).To(MatchError("0x1234 is not a valid QUIC version"))
})
It("disables bidirectional streams", func() {
config := &Config{
MaxIncomingStreams: -1,
MaxIncomingUniStreams: 4321,
}
c := populateClientConfig(config, false)
c := populateConfig(config)
Expect(c.MaxIncomingStreams).To(BeZero())
Expect(c.MaxIncomingUniStreams).To(BeEquivalentTo(4321))
})
@@ -389,18 +251,13 @@ var _ = Describe("Client", func() {
MaxIncomingStreams: 1234,
MaxIncomingUniStreams: -1,
}
c := populateClientConfig(config, false)
c := populateConfig(config)
Expect(c.MaxIncomingStreams).To(BeEquivalentTo(1234))
Expect(c.MaxIncomingUniStreams).To(BeZero())
})
It("uses 0-byte connection IDs when dialing an address", func() {
c := populateClientConfig(&Config{}, true)
Expect(c.ConnectionIDLength).To(BeZero())
})
It("fills in default values if options are not set in the Config", func() {
c := populateClientConfig(&Config{}, false)
c := populateConfig(&Config{})
Expect(c.Versions).To(Equal(protocol.SupportedVersions))
Expect(c.HandshakeIdleTimeout).To(Equal(protocol.DefaultHandshakeIdleTimeout))
Expect(c.MaxIdleTimeout).To(Equal(protocol.DefaultIdleTimeout))
@@ -408,20 +265,17 @@ var _ = Describe("Client", func() {
})
It("creates new connections with the right parameters", func() {
manager := NewMockPacketHandlerManager(mockCtrl)
manager.EXPECT().Add(connID, gomock.Any())
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)
config := &Config{Versions: []protocol.VersionNumber{protocol.Version1}, ConnectionIDGenerator: &mockConnIDGenerator{ConnID: connID}}
config := &Config{Versions: []protocol.VersionNumber{protocol.Version1}}
c := make(chan struct{})
var cconn sendConn
var version protocol.VersionNumber
var conf *Config
done := make(chan struct{})
newClientConnection = func(
connP sendConn,
_ connRunner,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ ConnectionIDGenerator,
configP *Config,
_ *tls.Config,
_ protocol.PacketNumber,
@@ -432,7 +286,6 @@ var _ = Describe("Client", func() {
_ utils.Logger,
versionP protocol.VersionNumber,
) quicConn {
cconn = connP
version = versionP
conf = configP
close(c)
@@ -440,28 +293,32 @@ var _ = Describe("Client", func() {
conn := NewMockQUICConn(mockCtrl)
conn.EXPECT().run()
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
conn.EXPECT().destroy(gomock.Any())
close(done)
return conn
}
_, err := Dial(context.Background(), packetConn, addr, tlsConf, config)
packetConn := NewMockPacketConn(mockCtrl)
packetConn.EXPECT().ReadFrom(gomock.Any()).DoAndReturn(func([]byte) (int, net.Addr, error) {
<-done
return 0, nil, errors.New("closed")
})
packetConn.EXPECT().LocalAddr()
packetConn.EXPECT().SetReadDeadline(gomock.Any()).AnyTimes()
_, err := Dial(context.Background(), packetConn, &net.UDPAddr{}, tlsConf, config)
Expect(err).ToNot(HaveOccurred())
Eventually(c).Should(BeClosed())
Expect(cconn.(*spconn).PacketConn).To(Equal(packetConn))
Expect(version).To(Equal(config.Versions[0]))
Expect(conf.Versions).To(Equal(config.Versions))
})
It("creates a new connections after version negotiation", func() {
manager := NewMockPacketHandlerManager(mockCtrl)
manager.EXPECT().Add(connID, gomock.Any()).Times(2)
manager.EXPECT().Destroy()
mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)
var counter int
newClientConnection = func(
_ sendConn,
_ connRunner,
_ protocol.ConnectionID,
runner connRunner,
_ protocol.ConnectionID,
connID protocol.ConnectionID,
_ ConnectionIDGenerator,
configP *Config,
_ *tls.Config,
pn protocol.PacketNumber,
@@ -477,20 +334,24 @@ var _ = Describe("Client", func() {
if counter == 0 {
Expect(pn).To(BeZero())
Expect(hasNegotiatedVersion).To(BeFalse())
conn.EXPECT().run().Return(&errCloseForRecreating{
nextPacketNumber: 109,
nextVersion: 789,
conn.EXPECT().run().DoAndReturn(func() error {
runner.Remove(connID)
return &errCloseForRecreating{
nextPacketNumber: 109,
nextVersion: 789,
}
})
} else {
Expect(pn).To(Equal(protocol.PacketNumber(109)))
Expect(hasNegotiatedVersion).To(BeTrue())
conn.EXPECT().run()
conn.EXPECT().destroy(gomock.Any())
}
counter++
return conn
}
config := &Config{Tracer: config.Tracer, Versions: []protocol.VersionNumber{protocol.Version1}, ConnectionIDGenerator: &mockConnIDGenerator{ConnID: connID}}
config := &Config{Tracer: config.Tracer, Versions: []protocol.VersionNumber{protocol.Version1}}
tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
_, err := DialAddr(context.Background(), "localhost:7890", tlsConf, config)
Expect(err).ToNot(HaveOccurred())
@@ -498,15 +359,3 @@ var _ = Describe("Client", func() {
})
})
})
type mockConnIDGenerator struct {
ConnID protocol.ConnectionID
}
func (m *mockConnIDGenerator) GenerateConnectionID() (protocol.ConnectionID, error) {
return m.ConnID, nil
}
func (m *mockConnIDGenerator) ConnectionIDLen() int {
return m.ConnID.Len()
}

View File

@@ -2,6 +2,7 @@ package quic
import (
"errors"
"fmt"
"net"
"time"
@@ -29,13 +30,19 @@ func validateConfig(config *Config) error {
if config.MaxIncomingUniStreams > 1<<60 {
return errors.New("invalid value for Config.MaxIncomingUniStreams")
}
// check that all QUIC versions are actually supported
for _, v := range config.Versions {
if !protocol.IsValidVersion(v) {
return fmt.Errorf("invalid QUIC version: %s", v)
}
}
return nil
}
// populateServerConfig populates fields in the quic.Config with their default values, if none are set
// it may be called with nil
func populateServerConfig(config *Config) *Config {
config = populateConfig(config, protocol.DefaultConnectionIDLength)
config = populateConfig(config)
if config.MaxTokenAge == 0 {
config.MaxTokenAge = protocol.TokenValidity
}
@@ -48,19 +55,9 @@ func populateServerConfig(config *Config) *Config {
return config
}
// populateClientConfig populates fields in the quic.Config with their default values, if none are set
// populateConfig populates fields in the quic.Config with their default values, if none are set
// it may be called with nil
func populateClientConfig(config *Config, createdPacketConn bool) *Config {
defaultConnIDLen := protocol.DefaultConnectionIDLength
if createdPacketConn {
defaultConnIDLen = 0
}
config = populateConfig(config, defaultConnIDLen)
return config
}
func populateConfig(config *Config, defaultConnIDLen int) *Config {
func populateConfig(config *Config) *Config {
if config == nil {
config = &Config{}
}
@@ -68,10 +65,6 @@ func populateConfig(config *Config, defaultConnIDLen int) *Config {
if len(versions) == 0 {
versions = protocol.SupportedVersions
}
conIDLen := config.ConnectionIDLength
if config.ConnectionIDLength == 0 {
conIDLen = defaultConnIDLen
}
handshakeIdleTimeout := protocol.DefaultHandshakeIdleTimeout
if config.HandshakeIdleTimeout != 0 {
handshakeIdleTimeout = config.HandshakeIdleTimeout
@@ -108,12 +101,9 @@ func populateConfig(config *Config, defaultConnIDLen int) *Config {
} else if maxIncomingUniStreams < 0 {
maxIncomingUniStreams = 0
}
connIDGenerator := config.ConnectionIDGenerator
if connIDGenerator == nil {
connIDGenerator = &protocol.DefaultConnectionIDGenerator{ConnLen: conIDLen}
}
return &Config{
GetConfigForClient: config.GetConfigForClient,
Versions: versions,
HandshakeIdleTimeout: handshakeIdleTimeout,
MaxIdleTimeout: idleTimeout,
@@ -128,9 +118,6 @@ func populateConfig(config *Config, defaultConnIDLen int) *Config {
AllowConnectionWindowIncrease: config.AllowConnectionWindowIncrease,
MaxIncomingStreams: maxIncomingStreams,
MaxIncomingUniStreams: maxIncomingUniStreams,
ConnectionIDLength: conIDLen,
ConnectionIDGenerator: connIDGenerator,
StatelessResetKey: config.StatelessResetKey,
TokenStore: config.TokenStore,
EnableDatagrams: config.EnableDatagrams,
DisablePathMTUDiscovery: config.DisablePathMTUDiscovery,

View File

@@ -1,13 +1,15 @@
package quic
import (
"context"
"errors"
"fmt"
"net"
"reflect"
"time"
mocklogging "github.com/quic-go/quic-go/internal/mocks/logging"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/logging"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
@@ -45,7 +47,7 @@ var _ = Describe("Config", func() {
}
switch fn := typ.Field(i).Name; fn {
case "RequireAddressValidation", "GetLogWriter", "AllowConnectionWindowIncrease", "Allow0RTT":
case "GetConfigForClient", "RequireAddressValidation", "GetLogWriter", "AllowConnectionWindowIncrease", "Tracer":
// Can't compare functions.
case "Versions":
f.Set(reflect.ValueOf([]VersionNumber{1, 2, 3}))
@@ -85,8 +87,8 @@ var _ = Describe("Config", func() {
f.Set(reflect.ValueOf(true))
case "DisablePathMTUDiscovery":
f.Set(reflect.ValueOf(true))
case "Tracer":
f.Set(reflect.ValueOf(mocklogging.NewMockTracer(mockCtrl)))
case "Allow0RTT":
f.Set(reflect.ValueOf(true))
default:
Fail(fmt.Sprintf("all fields must be accounted for, but saw unknown field %q", fn))
}
@@ -106,16 +108,25 @@ var _ = Describe("Config", func() {
Context("cloning", func() {
It("clones function fields", func() {
var calledAddrValidation, calledAllowConnectionWindowIncrease bool
var calledAddrValidation, calledAllowConnectionWindowIncrease, calledTracer bool
c1 := &Config{
GetConfigForClient: func(info *ClientHelloInfo) (*Config, error) { return nil, errors.New("nope") },
AllowConnectionWindowIncrease: func(Connection, uint64) bool { calledAllowConnectionWindowIncrease = true; return true },
RequireAddressValidation: func(net.Addr) bool { calledAddrValidation = true; return true },
Tracer: func(context.Context, logging.Perspective, ConnectionID) logging.ConnectionTracer {
calledTracer = true
return nil
},
}
c2 := c1.Clone()
c2.RequireAddressValidation(&net.UDPAddr{})
Expect(calledAddrValidation).To(BeTrue())
c2.AllowConnectionWindowIncrease(nil, 1234)
Expect(calledAllowConnectionWindowIncrease).To(BeTrue())
_, err := c2.GetConfigForClient(&ClientHelloInfo{})
Expect(err).To(MatchError("nope"))
c2.Tracer(context.Background(), logging.PerspectiveClient, protocol.ConnectionID{})
Expect(calledTracer).To(BeTrue())
})
It("clones non-function fields", func() {
@@ -142,18 +153,18 @@ var _ = Describe("Config", func() {
var calledAddrValidation bool
c1 := &Config{}
c1.RequireAddressValidation = func(net.Addr) bool { calledAddrValidation = true; return true }
c2 := populateConfig(c1, protocol.DefaultConnectionIDLength)
c2 := populateConfig(c1)
c2.RequireAddressValidation(&net.UDPAddr{})
Expect(calledAddrValidation).To(BeTrue())
})
It("copies non-function fields", func() {
c := configWithNonZeroNonFunctionFields()
Expect(populateConfig(c, protocol.DefaultConnectionIDLength)).To(Equal(c))
Expect(populateConfig(c)).To(Equal(c))
})
It("populates empty fields with default values", func() {
c := populateConfig(&Config{}, protocol.DefaultConnectionIDLength)
c := populateConfig(&Config{})
Expect(c.Versions).To(Equal(protocol.SupportedVersions))
Expect(c.HandshakeIdleTimeout).To(Equal(protocol.DefaultHandshakeIdleTimeout))
Expect(c.InitialStreamReceiveWindow).To(BeEquivalentTo(protocol.DefaultInitialMaxStreamData))
@@ -164,22 +175,12 @@ var _ = Describe("Config", func() {
Expect(c.MaxIncomingUniStreams).To(BeEquivalentTo(protocol.DefaultMaxIncomingUniStreams))
Expect(c.DisableVersionNegotiationPackets).To(BeFalse())
Expect(c.DisablePathMTUDiscovery).To(BeFalse())
Expect(c.GetConfigForClient).To(BeNil())
})
It("populates empty fields with default values, for the server", func() {
c := populateServerConfig(&Config{})
Expect(c.ConnectionIDLength).To(Equal(protocol.DefaultConnectionIDLength))
Expect(c.RequireAddressValidation).ToNot(BeNil())
})
It("sets a default connection ID length if we didn't create the conn, for the client", func() {
c := populateClientConfig(&Config{}, false)
Expect(c.ConnectionIDLength).To(Equal(protocol.DefaultConnectionIDLength))
})
It("doesn't set a default connection ID length if we created the conn, for the client", func() {
c := populateClientConfig(&Config{}, true)
Expect(c.ConnectionIDLength).To(BeZero())
})
})
})

View File

@@ -240,6 +240,7 @@ var newConnection = func(
clientDestConnID protocol.ConnectionID,
destConnID protocol.ConnectionID,
srcConnID protocol.ConnectionID,
connIDGenerator ConnectionIDGenerator,
statelessResetToken protocol.StatelessResetToken,
conf *Config,
tlsConf *tls.Config,
@@ -283,7 +284,7 @@ var newConnection = func(
runner.Retire,
runner.ReplaceWithClosed,
s.queueControlFrame,
s.config.ConnectionIDGenerator,
connIDGenerator,
)
s.preSetup()
s.ctx, s.ctxCancel = context.WithCancel(context.WithValue(context.Background(), ConnectionTracingKey, tracingID))
@@ -323,10 +324,6 @@ var newConnection = func(
if s.tracer != nil {
s.tracer.SentTransportParameters(params)
}
var allow0RTT func() bool
if conf.Allow0RTT != nil {
allow0RTT = func() bool { return conf.Allow0RTT(conn.RemoteAddr()) }
}
cs := handshake.NewCryptoSetupServer(
initialStream,
handshakeStream,
@@ -344,7 +341,7 @@ var newConnection = func(
},
},
tlsConf,
allow0RTT,
conf.Allow0RTT,
s.rttStats,
tracer,
logger,
@@ -363,6 +360,7 @@ var newClientConnection = func(
runner connRunner,
destConnID protocol.ConnectionID,
srcConnID protocol.ConnectionID,
connIDGenerator ConnectionIDGenerator,
conf *Config,
tlsConf *tls.Config,
initialPacketNumber protocol.PacketNumber,
@@ -402,7 +400,7 @@ var newClientConnection = func(
runner.Retire,
runner.ReplaceWithClosed,
s.queueControlFrame,
s.config.ConnectionIDGenerator,
connIDGenerator,
)
s.preSetup()
s.ctx, s.ctxCancel = context.WithCancel(context.WithValue(context.Background(), ConnectionTracingKey, tracingID))

View File

@@ -113,6 +113,7 @@ var _ = Describe("Connection", func() {
clientDestConnID,
destConnID,
srcConnID,
&protocol.DefaultConnectionIDGenerator{},
protocol.StatelessResetToken{},
populateServerConfig(&Config{DisablePathMTUDiscovery: true}),
nil, // tls.Config
@@ -2015,8 +2016,6 @@ var _ = Describe("Connection", func() {
packer.EXPECT().HandleTransportParameters(params)
packer.EXPECT().PackCoalescedPacket(false, conn.version).MaxTimes(3)
Expect(conn.earlyConnReady()).ToNot(BeClosed())
connRunner.EXPECT().GetStatelessResetToken(gomock.Any()).Times(2)
connRunner.EXPECT().Add(gomock.Any(), conn).Times(2)
tracer.EXPECT().ReceivedTransportParameters(params)
conn.handleTransportParameters(params)
Expect(conn.earlyConnReady()).To(BeClosed())
@@ -2378,7 +2377,7 @@ var _ = Describe("Client Connection", func() {
}
BeforeEach(func() {
quicConf = populateClientConfig(&Config{}, true)
quicConf = populateConfig(&Config{})
tlsConf = nil
})
@@ -2402,6 +2401,7 @@ var _ = Describe("Client Connection", func() {
connRunner,
destConnID,
protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}),
&protocol.DefaultConnectionIDGenerator{},
quicConf,
tlsConf,
42, // initial packet number

View File

@@ -3,6 +3,7 @@ package main
import (
"bufio"
"bytes"
"context"
"crypto/tls"
"crypto/x509"
"flag"
@@ -57,15 +58,15 @@ func main() {
var qconf quic.Config
if *enableQlog {
qconf.Tracer = qlog.NewTracer(func(_ logging.Perspective, connID []byte) io.WriteCloser {
qconf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer {
filename := fmt.Sprintf("client_%x.qlog", connID)
f, err := os.Create(filename)
if err != nil {
log.Fatal(err)
}
log.Printf("Creating qlog file %s.\n", filename)
return utils.NewBufferedWriteCloser(bufio.NewWriter(f), f)
})
return qlog.NewConnectionTracer(utils.NewBufferedWriteCloser(bufio.NewWriter(f), f), p, connID)
}
}
roundTripper := &http3.RoundTripper{
TLSClientConfig: &tls.Config{

View File

@@ -2,6 +2,7 @@ package main
import (
"bufio"
"context"
"crypto/md5"
"errors"
"flag"
@@ -162,15 +163,15 @@ func main() {
handler := setupHandler(*www)
quicConf := &quic.Config{}
if *enableQlog {
quicConf.Tracer = qlog.NewTracer(func(_ logging.Perspective, connID []byte) io.WriteCloser {
quicConf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer {
filename := fmt.Sprintf("server_%x.qlog", connID)
f, err := os.Create(filename)
if err != nil {
log.Fatal(err)
}
log.Printf("Creating qlog file %s.\n", filename)
return utils.NewBufferedWriteCloser(bufio.NewWriter(f), f)
})
return qlog.NewConnectionTracer(utils.NewBufferedWriteCloser(bufio.NewWriter(f), f), p, connID)
}
}
var wg sync.WaitGroup

View File

@@ -105,7 +105,7 @@ func main() {
&wire.TransportParameters{ActiveConnectionIDLimit: 2},
runner,
config,
nil,
false,
utils.NewRTTStats(),
nil,
utils.DefaultLogger.WithPrefix("server"),

View File

@@ -390,10 +390,6 @@ func runHandshake(runConfig [confLen]byte, messageConfig uint8, clientConf *tls.
protocol.Version1,
)
var allow0RTT func() bool
if enable0RTTServer {
allow0RTT = func() bool { return true }
}
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
server = handshake.NewCryptoSetupServer(
sInitialStream,
@@ -404,7 +400,7 @@ func runHandshake(runConfig [confLen]byte, messageConfig uint8, clientConf *tls.
serverTP,
runner,
serverConf,
allow0RTT,
enable0RTTServer,
utils.NewRTTStats(),
nil,
utils.DefaultLogger.WithPrefix("server"),

View File

@@ -288,7 +288,7 @@ func (s *Server) serveConn(tlsConf *tls.Config, conn net.PacketConn) error {
baseConf := ConfigureTLSConfig(tlsConf)
quicConf := s.QuicConfig
if quicConf == nil {
quicConf = &quic.Config{Allow0RTT: func(net.Addr) bool { return true }}
quicConf = &quic.Config{Allow0RTT: true}
} else {
quicConf = s.QuicConfig.Clone()
}

View File

@@ -660,6 +660,7 @@ var _ = Describe("Stream Cancellations", func() {
getQuicConfig(&quic.Config{MaxIncomingStreams: maxIncomingStreams, MaxIdleTimeout: 10 * time.Second}),
)
Expect(err).ToNot(HaveOccurred())
defer server.Close()
var wg sync.WaitGroup
wg.Add(2 * 4 * maxIncomingStreams)

View File

@@ -24,6 +24,7 @@ var _ = Describe("Connection ID lengths tests", func() {
}),
)
Expect(err).ToNot(HaveOccurred())
defer server.Close()
var drop atomic.Bool
dropped := make(chan []byte, 100)
@@ -50,6 +51,7 @@ var _ = Describe("Connection ID lengths tests", func() {
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
defer conn.CloseWithError(0, "")
sconn, err := server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())

View File

@@ -34,9 +34,23 @@ func (c *connIDGenerator) ConnectionIDLen() int {
var _ = Describe("Connection ID lengths tests", func() {
randomConnIDLen := func() int { return 4 + int(mrand.Int31n(15)) }
runServer := func(conf *quic.Config) *quic.Listener {
GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID for the server\n", conf.ConnectionIDLength)))
ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), conf)
// connIDLen is ignored when connIDGenerator is set
runServer := func(connIDLen int, connIDGenerator quic.ConnectionIDGenerator) (*quic.Listener, func()) {
if connIDGenerator != nil {
GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID generator for the server\n", connIDGenerator.ConnectionIDLen())))
} else {
GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID for the server\n", connIDLen)))
}
addr, err := net.ResolveUDPAddr("udp", "localhost:0")
Expect(err).ToNot(HaveOccurred())
conn, err := net.ListenUDP("udp", addr)
Expect(err).ToNot(HaveOccurred())
tr := &quic.Transport{
Conn: conn,
ConnectionIDLength: connIDLen,
ConnectionIDGenerator: connIDGenerator,
}
ln, err := tr.Listen(getTLSConfig(), getQuicConfig(nil))
Expect(err).ToNot(HaveOccurred())
go func() {
defer GinkgoRecover()
@@ -55,16 +69,35 @@ var _ = Describe("Connection ID lengths tests", func() {
}()
}
}()
return ln
return ln, func() {
ln.Close()
tr.Close()
}
}
runClient := func(addr net.Addr, conf *quic.Config) {
GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID for the client\n", conf.ConnectionIDLength)))
cl, err := quic.DialAddr(
// connIDLen is ignored when connIDGenerator is set
runClient := func(addr net.Addr, connIDLen int, connIDGenerator quic.ConnectionIDGenerator) {
if connIDGenerator != nil {
GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID generator for the client\n", connIDGenerator.ConnectionIDLen())))
} else {
GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID for the client\n", connIDLen)))
}
laddr, err := net.ResolveUDPAddr("udp", "localhost:0")
Expect(err).ToNot(HaveOccurred())
conn, err := net.ListenUDP("udp", laddr)
Expect(err).ToNot(HaveOccurred())
defer conn.Close()
tr := &quic.Transport{
Conn: conn,
ConnectionIDLength: connIDLen,
ConnectionIDGenerator: connIDGenerator,
}
defer tr.Close()
cl, err := tr.Dial(
context.Background(),
fmt.Sprintf("localhost:%d", addr.(*net.UDPAddr).Port),
&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: addr.(*net.UDPAddr).Port},
getTLSClientConfig(),
conf,
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
defer cl.CloseWithError(0, "")
@@ -76,32 +109,20 @@ var _ = Describe("Connection ID lengths tests", func() {
}
It("downloads a file using a 0-byte connection ID for the client", func() {
serverConf := getQuicConfig(&quic.Config{ConnectionIDLength: randomConnIDLen()})
ln := runServer(serverConf)
defer ln.Close()
runClient(ln.Addr(), getQuicConfig(nil))
ln, closeFn := runServer(randomConnIDLen(), nil)
defer closeFn()
runClient(ln.Addr(), 0, nil)
})
It("downloads a file when both client and server use a random connection ID length", func() {
serverConf := getQuicConfig(&quic.Config{ConnectionIDLength: randomConnIDLen()})
ln := runServer(serverConf)
defer ln.Close()
runClient(ln.Addr(), getQuicConfig(nil))
ln, closeFn := runServer(randomConnIDLen(), nil)
defer closeFn()
runClient(ln.Addr(), randomConnIDLen(), nil)
})
It("downloads a file when both client and server use a custom connection ID generator", func() {
serverConf := getQuicConfig(&quic.Config{
ConnectionIDGenerator: &connIDGenerator{length: randomConnIDLen()},
})
clientConf := getQuicConfig(&quic.Config{
ConnectionIDGenerator: &connIDGenerator{length: randomConnIDLen()},
})
ln := runServer(serverConf)
defer ln.Close()
runClient(ln.Addr(), clientConf)
ln, closeFn := runServer(0, &connIDGenerator{length: randomConnIDLen()})
defer closeFn()
runClient(ln.Addr(), 0, &connIDGenerator{length: randomConnIDLen()})
})
})

View File

@@ -22,12 +22,11 @@ var _ = Describe("Datagram test", func() {
const num = 100
var (
proxy *quicproxy.QuicProxy
serverConn, clientConn *net.UDPConn
dropped, total int32
)
startServerAndProxy := func(enableDatagram, expectDatagramSupport bool) {
startServerAndProxy := func(enableDatagram, expectDatagramSupport bool) (port int, closeFn func()) {
addr, err := net.ResolveUDPAddr("udp", "localhost:0")
Expect(err).ToNot(HaveOccurred())
serverConn, err = net.ListenUDP("udp", addr)
@@ -39,8 +38,10 @@ var _ = Describe("Datagram test", func() {
)
Expect(err).ToNot(HaveOccurred())
accepted := make(chan struct{})
go func() {
defer GinkgoRecover()
defer close(accepted)
conn, err := ln.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
@@ -67,7 +68,7 @@ var _ = Describe("Datagram test", func() {
}()
serverPort := ln.Addr().(*net.UDPAddr).Port
proxy, err = quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
RemoteAddr: fmt.Sprintf("localhost:%d", serverPort),
// drop 10% of Short Header packets sent from the server
DropPacket: func(dir quicproxy.Direction, packet []byte) bool {
@@ -87,6 +88,11 @@ var _ = Describe("Datagram test", func() {
},
})
Expect(err).ToNot(HaveOccurred())
return proxy.LocalPort(), func() {
Eventually(accepted).Should(BeClosed())
proxy.Close()
ln.Close()
}
}
BeforeEach(func() {
@@ -96,13 +102,10 @@ var _ = Describe("Datagram test", func() {
Expect(err).ToNot(HaveOccurred())
})
AfterEach(func() {
Expect(proxy.Close()).To(Succeed())
})
It("sends datagrams", func() {
startServerAndProxy(true, true)
raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxy.LocalPort()))
proxyPort, close := startServerAndProxy(true, true)
defer close()
raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort))
Expect(err).ToNot(HaveOccurred())
conn, err := quic.Dial(
context.Background(),
@@ -117,6 +120,7 @@ var _ = Describe("Datagram test", func() {
for {
// Close the connection if no message is received for 100 ms.
timer := time.AfterFunc(scaleDuration(100*time.Millisecond), func() {
fmt.Println("closing conn")
conn.CloseWithError(0, "")
})
if _, err := conn.ReceiveMessage(); err != nil {
@@ -134,11 +138,12 @@ var _ = Describe("Datagram test", func() {
BeNumerically(">", expVal*9/10),
BeNumerically("<", num),
))
Eventually(conn.Context().Done).Should(BeClosed())
})
It("server can disable datagram", func() {
startServerAndProxy(false, true)
raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxy.LocalPort()))
proxyPort, close := startServerAndProxy(false, true)
raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort))
Expect(err).ToNot(HaveOccurred())
conn, err := quic.Dial(
context.Background(),
@@ -150,13 +155,13 @@ var _ = Describe("Datagram test", func() {
Expect(err).ToNot(HaveOccurred())
Expect(conn.ConnectionState().SupportsDatagrams).To(BeFalse())
close()
conn.CloseWithError(0, "")
<-time.After(10 * time.Millisecond)
})
It("client can disable datagram", func() {
startServerAndProxy(false, true)
raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxy.LocalPort()))
proxyPort, close := startServerAndProxy(false, true)
raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort))
Expect(err).ToNot(HaveOccurred())
conn, err := quic.Dial(
context.Background(),
@@ -169,7 +174,8 @@ var _ = Describe("Datagram test", func() {
Expect(conn.ConnectionState().SupportsDatagrams).To(BeFalse())
Expect(conn.SendMessage([]byte{0})).To(HaveOccurred())
close()
conn.CloseWithError(0, "")
<-time.After(10 * time.Millisecond)
})
})

View File

@@ -24,6 +24,7 @@ var _ = Describe("early data", func() {
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
defer ln.Close()
done := make(chan struct{})
go func() {
defer GinkgoRecover()

View File

@@ -8,10 +8,9 @@ import (
"time"
)
var (
go120 = false
errNotSupported = errors.New("not supported")
)
const go120 = false
var errNotSupported = errors.New("not supported")
func setReadDeadline(w http.ResponseWriter, deadline time.Time) error {
return errNotSupported

View File

@@ -7,7 +7,7 @@ import (
"time"
)
var go120 = true
const go120 = true
func setReadDeadline(w http.ResponseWriter, deadline time.Time) error {
rc := http.NewResponseController(w)

View File

@@ -62,13 +62,14 @@ var _ = Describe("Handshake RTT tests", func() {
runProxy(ln.Addr())
startTime := time.Now()
_, err = quic.DialAddr(
conn, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port),
getTLSClientConfig(),
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
defer conn.CloseWithError(0, "")
expectDurationInRTTs(startTime, 2)
})
@@ -79,13 +80,14 @@ var _ = Describe("Handshake RTT tests", func() {
runProxy(ln.Addr())
startTime := time.Now()
_, err = quic.DialAddr(
conn, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port),
getTLSClientConfig(),
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
defer conn.CloseWithError(0, "")
expectDurationInRTTs(startTime, 1)
})
@@ -97,13 +99,14 @@ var _ = Describe("Handshake RTT tests", func() {
runProxy(ln.Addr())
startTime := time.Now()
_, err = quic.DialAddr(
conn, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port),
getTLSClientConfig(),
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
defer conn.CloseWithError(0, "")
expectDurationInRTTs(startTime, 2)
})
@@ -131,6 +134,7 @@ var _ = Describe("Handshake RTT tests", func() {
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
defer conn.CloseWithError(0, "")
str, err := conn.AcceptUniStream(context.Background())
Expect(err).ToNot(HaveOccurred())
data, err := io.ReadAll(str)
@@ -166,6 +170,7 @@ var _ = Describe("Handshake RTT tests", func() {
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
defer conn.CloseWithError(0, "")
str, err := conn.AcceptUniStream(context.Background())
Expect(err).ToNot(HaveOccurred())
data, err := io.ReadAll(str)

View File

@@ -114,7 +114,7 @@ var _ = Describe("Handshake tests", func() {
context.Background(),
fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port),
getTLSClientConfig(),
nil,
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
str, err := conn.AcceptStream(context.Background())
@@ -223,13 +223,14 @@ var _ = Describe("Handshake tests", func() {
var (
server *quic.Listener
pconn net.PacketConn
dialer *quic.Transport
)
dial := func() (quic.Connection, error) {
remoteAddr := fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port)
raddr, err := net.ResolveUDPAddr("udp", remoteAddr)
Expect(err).ToNot(HaveOccurred())
return quic.Dial(context.Background(), pconn, raddr, getTLSClientConfig(), nil)
return dialer.Dial(context.Background(), raddr, getTLSClientConfig(), getQuicConfig(nil))
}
BeforeEach(func() {
@@ -243,11 +244,13 @@ var _ = Describe("Handshake tests", func() {
Expect(err).ToNot(HaveOccurred())
pconn, err = net.ListenUDP("udp", laddr)
Expect(err).ToNot(HaveOccurred())
dialer = &quic.Transport{Conn: pconn, ConnectionIDLength: 4}
})
AfterEach(func() {
Expect(server.Close()).To(Succeed())
Expect(pconn.Close()).To(Succeed())
Expect(dialer.Close()).To(Succeed())
})
It("rejects new connection attempts if connections don't get accepted", func() {
@@ -300,7 +303,7 @@ var _ = Describe("Handshake tests", func() {
// This should free one spot in the queue.
Expect(firstConn.CloseWithError(0, ""))
Eventually(firstConn.Context().Done()).Should(BeClosed())
time.Sleep(scaleDuration(20 * time.Millisecond))
time.Sleep(scaleDuration(200 * time.Millisecond))
// dial again, and expect that this dial succeeds
_, err = dial()
@@ -366,6 +369,7 @@ var _ = Describe("Handshake tests", func() {
It("uses tokens provided in NEW_TOKEN frames", func() {
server, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig)
Expect(err).ToNot(HaveOccurred())
defer server.Close()
// dial the first connection and receive the token
go func() {
@@ -432,6 +436,72 @@ var _ = Describe("Handshake tests", func() {
})
})
Context("GetConfigForClient", func() {
It("uses the quic.Config returned by GetConfigForClient", func() {
serverConfig.EnableDatagrams = false
var calledFrom net.Addr
serverConfig.GetConfigForClient = func(info *quic.ClientHelloInfo) (*quic.Config, error) {
conf := serverConfig.Clone()
conf.EnableDatagrams = true
calledFrom = info.RemoteAddr
return getQuicConfig(conf), nil
}
ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig)
Expect(err).ToNot(HaveOccurred())
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := ln.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
close(done)
}()
conn, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port),
getTLSClientConfig(),
getQuicConfig(&quic.Config{EnableDatagrams: true}),
)
Expect(err).ToNot(HaveOccurred())
defer conn.CloseWithError(0, "")
cs := conn.ConnectionState()
Expect(cs.SupportsDatagrams).To(BeTrue())
Eventually(done).Should(BeClosed())
Expect(ln.Close()).To(Succeed())
Expect(calledFrom.(*net.UDPAddr).Port).To(Equal(conn.LocalAddr().(*net.UDPAddr).Port))
})
It("rejects the connection attempt if GetConfigForClient errors", func() {
serverConfig.EnableDatagrams = false
serverConfig.GetConfigForClient = func(info *quic.ClientHelloInfo) (*quic.Config, error) {
return nil, errors.New("rejected")
}
ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig)
Expect(err).ToNot(HaveOccurred())
defer ln.Close()
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := ln.Accept(context.Background())
Expect(err).To(HaveOccurred()) // we don't expect to accept any connection
close(done)
}()
_, err = quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port),
getTLSClientConfig(),
getQuicConfig(&quic.Config{EnableDatagrams: true}),
)
Expect(err).To(HaveOccurred())
var transportErr *quic.TransportError
Expect(errors.As(err, &transportErr)).To(BeTrue())
Expect(transportErr.ErrorCode).To(Equal(qerr.ConnectionRefused))
})
})
It("doesn't send any packets when generating the ClientHello fails", func() {
ln, err := net.ListenUDP("udp", nil)
Expect(err).ToNot(HaveOccurred())

View File

@@ -382,6 +382,7 @@ var _ = Describe("HTTP tests", func() {
tlsConf.NextProtos = []string{"h3"}
ln, err := quic.ListenAddr("localhost:0", tlsConf, nil)
Expect(err).ToNot(HaveOccurred())
defer ln.Close()
done := make(chan struct{})
go func() {
defer GinkgoRecover()
@@ -398,57 +399,51 @@ var _ = Describe("HTTP tests", func() {
Eventually(done).Should(BeClosed())
})
It("supports read deadlines", func() {
if !go120 {
Skip("This test requires Go 1.20+")
}
if go120 {
It("supports read deadlines", func() {
mux.HandleFunc("/read-deadline", func(w http.ResponseWriter, r *http.Request) {
defer GinkgoRecover()
err := setReadDeadline(w, time.Now().Add(deadlineDelay))
Expect(err).ToNot(HaveOccurred())
mux.HandleFunc("/read-deadline", func(w http.ResponseWriter, r *http.Request) {
defer GinkgoRecover()
err := setReadDeadline(w, time.Now().Add(deadlineDelay))
body, err := io.ReadAll(r.Body)
Expect(err).To(MatchError(os.ErrDeadlineExceeded))
Expect(body).To(ContainSubstring("aa"))
w.Write([]byte("ok"))
})
expectedEnd := time.Now().Add(deadlineDelay)
resp, err := client.Post("https://localhost:"+port+"/read-deadline", "text/plain", neverEnding('a'))
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(200))
body, err := io.ReadAll(r.Body)
Expect(err).To(MatchError(os.ErrDeadlineExceeded))
Expect(body).To(ContainSubstring("aa"))
w.Write([]byte("ok"))
body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 2*deadlineDelay))
Expect(err).ToNot(HaveOccurred())
Expect(time.Now().After(expectedEnd)).To(BeTrue())
Expect(string(body)).To(Equal("ok"))
})
expectedEnd := time.Now().Add(deadlineDelay)
resp, err := client.Post("https://localhost:"+port+"/read-deadline", "text/plain", neverEnding('a'))
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(200))
It("supports write deadlines", func() {
mux.HandleFunc("/write-deadline", func(w http.ResponseWriter, r *http.Request) {
defer GinkgoRecover()
err := setWriteDeadline(w, time.Now().Add(deadlineDelay))
Expect(err).ToNot(HaveOccurred())
body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 2*deadlineDelay))
Expect(err).ToNot(HaveOccurred())
Expect(time.Now().After(expectedEnd)).To(BeTrue())
Expect(string(body)).To(Equal("ok"))
})
_, err = io.Copy(w, neverEnding('a'))
Expect(err).To(MatchError(os.ErrDeadlineExceeded))
})
It("supports write deadlines", func() {
if !go120 {
Skip("This test requires Go 1.20+")
}
expectedEnd := time.Now().Add(deadlineDelay)
mux.HandleFunc("/write-deadline", func(w http.ResponseWriter, r *http.Request) {
defer GinkgoRecover()
err := setWriteDeadline(w, time.Now().Add(deadlineDelay))
resp, err := client.Get("https://localhost:" + port + "/write-deadline")
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(200))
_, err = io.Copy(w, neverEnding('a'))
Expect(err).To(MatchError(os.ErrDeadlineExceeded))
body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 2*deadlineDelay))
Expect(err).ToNot(HaveOccurred())
Expect(time.Now().After(expectedEnd)).To(BeTrue())
Expect(string(body)).To(ContainSubstring("aa"))
})
expectedEnd := time.Now().Add(deadlineDelay)
resp, err := client.Get("https://localhost:" + port + "/write-deadline")
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(200))
body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 2*deadlineDelay))
Expect(err).ToNot(HaveOccurred())
Expect(time.Now().After(expectedEnd)).To(BeTrue())
Expect(string(body)).To(ContainSubstring("aa"))
})
}
})

View File

@@ -75,7 +75,9 @@ var _ = Describe("Key Update tests", func() {
context.Background(),
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
getTLSClientConfig(),
getQuicConfig(&quic.Config{Tracer: newTracer(func() logging.ConnectionTracer { return &keyUpdateConnTracer{} })}),
getQuicConfig(&quic.Config{Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer {
return &keyUpdateConnTracer{}
}}),
)
Expect(err).ToNot(HaveOccurred())
str, err := conn.AcceptUniStream(context.Background())

View File

@@ -35,7 +35,11 @@ var _ = Describe("MITM test", func() {
Expect(err).ToNot(HaveOccurred())
serverUDPConn, err = net.ListenUDP("udp", addr)
Expect(err).ToNot(HaveOccurred())
ln, err := quic.Listen(serverUDPConn, getTLSConfig(), serverConfig)
tr := &quic.Transport{
Conn: serverUDPConn,
ConnectionIDLength: connIDLen,
}
ln, err := tr.Listen(getTLSConfig(), serverConfig)
Expect(err).ToNot(HaveOccurred())
done := make(chan struct{})
go func() {
@@ -68,7 +72,7 @@ var _ = Describe("MITM test", func() {
}
BeforeEach(func() {
serverConfig = getQuicConfig(&quic.Config{ConnectionIDLength: connIDLen})
serverConfig = getQuicConfig(nil)
addr, err := net.ResolveUDPAddr("udp", "localhost:0")
Expect(err).ToNot(HaveOccurred())
clientUDPConn, err = net.ListenUDP("udp", addr)
@@ -146,12 +150,15 @@ var _ = Describe("MITM test", func() {
defer closeFn()
raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort))
Expect(err).ToNot(HaveOccurred())
conn, err := quic.Dial(
tr := &quic.Transport{
Conn: clientUDPConn,
ConnectionIDLength: connIDLen,
}
conn, err := tr.Dial(
context.Background(),
clientUDPConn,
raddr,
getTLSClientConfig(),
getQuicConfig(&quic.Config{ConnectionIDLength: connIDLen}),
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
str, err := conn.AcceptUniStream(context.Background())
@@ -190,12 +197,15 @@ var _ = Describe("MITM test", func() {
defer closeFn()
raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort))
Expect(err).ToNot(HaveOccurred())
conn, err := quic.Dial(
tr := &quic.Transport{
Conn: clientUDPConn,
ConnectionIDLength: connIDLen,
}
conn, err := tr.Dial(
context.Background(),
clientUDPConn,
raddr,
getTLSClientConfig(),
getQuicConfig(&quic.Config{ConnectionIDLength: connIDLen}),
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
str, err := conn.AcceptUniStream(context.Background())
@@ -302,20 +312,20 @@ var _ = Describe("MITM test", func() {
const rtt = 20 * time.Millisecond
runTest := func(delayCb quicproxy.DelayCallback) (closeFn func(), err error) {
proxyPort, closeFn := startServerAndProxy(delayCb, nil)
proxyPort, serverCloseFn := startServerAndProxy(delayCb, nil)
raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort))
Expect(err).ToNot(HaveOccurred())
_, err = quic.Dial(
tr := &quic.Transport{
Conn: clientUDPConn,
ConnectionIDLength: connIDLen,
}
_, err = tr.Dial(
context.Background(),
clientUDPConn,
raddr,
getTLSClientConfig(),
getQuicConfig(&quic.Config{
ConnectionIDLength: connIDLen,
HandshakeIdleTimeout: 2 * time.Second,
}),
getQuicConfig(&quic.Config{HandshakeIdleTimeout: 2 * time.Second}),
)
return closeFn, err
return func() { tr.Close(); serverCloseFn() }, err
}
// fails immediately because client connection closes when it can't find compatible version

View File

@@ -34,10 +34,9 @@ var _ = Describe("Multiplexing", func() {
}()
}
dial := func(pconn net.PacketConn, addr net.Addr) {
conn, err := quic.Dial(
dial := func(tr *quic.Transport, addr net.Addr) {
conn, err := tr.Dial(
context.Background(),
pconn,
addr,
getTLSClientConfig(),
getQuicConfig(nil),
@@ -72,17 +71,18 @@ var _ = Describe("Multiplexing", func() {
conn, err := net.ListenUDP("udp", addr)
Expect(err).ToNot(HaveOccurred())
defer conn.Close()
tr := &quic.Transport{Conn: conn}
done1 := make(chan struct{})
done2 := make(chan struct{})
go func() {
defer GinkgoRecover()
dial(conn, server.Addr())
dial(tr, server.Addr())
close(done1)
}()
go func() {
defer GinkgoRecover()
dial(conn, server.Addr())
dial(tr, server.Addr())
close(done2)
}()
timeout := 30 * time.Second
@@ -106,17 +106,18 @@ var _ = Describe("Multiplexing", func() {
conn, err := net.ListenUDP("udp", addr)
Expect(err).ToNot(HaveOccurred())
defer conn.Close()
tr := &quic.Transport{Conn: conn}
done1 := make(chan struct{})
done2 := make(chan struct{})
go func() {
defer GinkgoRecover()
dial(conn, server1.Addr())
dial(tr, server1.Addr())
close(done1)
}()
go func() {
defer GinkgoRecover()
dial(conn, server2.Addr())
dial(tr, server2.Addr())
close(done2)
}()
timeout := 30 * time.Second
@@ -135,9 +136,8 @@ var _ = Describe("Multiplexing", func() {
conn, err := net.ListenUDP("udp", addr)
Expect(err).ToNot(HaveOccurred())
defer conn.Close()
server, err := quic.Listen(
conn,
tr := &quic.Transport{Conn: conn}
server, err := tr.Listen(
getTLSConfig(),
getQuicConfig(nil),
)
@@ -146,7 +146,7 @@ var _ = Describe("Multiplexing", func() {
done := make(chan struct{})
go func() {
defer GinkgoRecover()
dial(conn, server.Addr())
dial(tr, server.Addr())
close(done)
}()
timeout := 30 * time.Second
@@ -165,15 +165,16 @@ var _ = Describe("Multiplexing", func() {
conn1, err := net.ListenUDP("udp", addr1)
Expect(err).ToNot(HaveOccurred())
defer conn1.Close()
tr1 := &quic.Transport{Conn: conn1}
addr2, err := net.ResolveUDPAddr("udp", "localhost:0")
Expect(err).ToNot(HaveOccurred())
conn2, err := net.ListenUDP("udp", addr2)
Expect(err).ToNot(HaveOccurred())
defer conn2.Close()
tr2 := &quic.Transport{Conn: conn2}
server1, err := quic.Listen(
conn1,
server1, err := tr1.Listen(
getTLSConfig(),
getQuicConfig(nil),
)
@@ -181,8 +182,7 @@ var _ = Describe("Multiplexing", func() {
runServer(server1)
defer server1.Close()
server2, err := quic.Listen(
conn2,
server2, err := tr2.Listen(
getTLSConfig(),
getQuicConfig(nil),
)
@@ -194,12 +194,12 @@ var _ = Describe("Multiplexing", func() {
done2 := make(chan struct{})
go func() {
defer GinkgoRecover()
dial(conn2, server1.Addr())
dial(tr2, server1.Addr())
close(done1)
}()
go func() {
defer GinkgoRecover()
dial(conn1, server2.Addr())
dial(tr1, server2.Addr())
close(done2)
}()
timeout := 30 * time.Second

View File

@@ -27,12 +27,12 @@ var _ = Describe("Packetization", func() {
getTLSConfig(),
getQuicConfig(&quic.Config{
DisablePathMTUDiscovery: true,
Tracer: newTracer(func() logging.ConnectionTracer { return serverTracer }),
Tracer: newTracer(serverTracer),
}),
)
Expect(err).ToNot(HaveOccurred())
serverAddr := fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port)
defer server.Close()
serverAddr := fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port)
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
RemoteAddr: serverAddr,
@@ -50,10 +50,11 @@ var _ = Describe("Packetization", func() {
getTLSClientConfig(),
getQuicConfig(&quic.Config{
DisablePathMTUDiscovery: true,
Tracer: newTracer(func() logging.ConnectionTracer { return clientTracer }),
Tracer: newTracer(clientTracer),
}),
)
Expect(err).ToNot(HaveOccurred())
defer conn.CloseWithError(0, "")
go func() {
defer GinkgoRecover()

View File

@@ -87,7 +87,7 @@ var (
logBuf *syncedBuffer
versionParam string
qlogTracer logging.Tracer
qlogTracer func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer
enableQlog bool
version quic.VersionNumber
@@ -175,7 +175,13 @@ func getQuicConfig(conf *quic.Config) *quic.Config {
if conf.Tracer == nil {
conf.Tracer = qlogTracer
} else if qlogTracer != nil {
conf.Tracer = logging.NewMultiplexedTracer(qlogTracer, conf.Tracer)
origTracer := conf.Tracer
conf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer {
return logging.NewMultiplexedConnectionTracer(
qlogTracer(ctx, p, connID),
origTracer(ctx, p, connID),
)
}
}
}
return conf
@@ -199,8 +205,16 @@ func areHandshakesRunning() bool {
return strings.Contains(b.String(), "RunHandshake")
}
func areTransportsRunning() bool {
var b bytes.Buffer
pprof.Lookup("goroutine").WriteTo(&b, 1)
return strings.Contains(b.String(), "quic-go.(*Transport).listen")
}
var _ = AfterEach(func() {
Expect(areHandshakesRunning()).To(BeFalse())
Eventually(areTransportsRunning).Should(BeFalse())
if debugLog() {
logFile, err := os.Create(logFileName)
Expect(err).ToNot(HaveOccurred())
@@ -224,19 +238,8 @@ func scaleDuration(d time.Duration) time.Duration {
return time.Duration(scaleFactor) * d
}
type tracer struct {
logging.NullTracer
createNewConnTracer func() logging.ConnectionTracer
}
var _ logging.Tracer = &tracer{}
func newTracer(c func() logging.ConnectionTracer) logging.Tracer {
return &tracer{createNewConnTracer: c}
}
func (t *tracer) TracerForConnection(context.Context, logging.Perspective, logging.ConnectionID) logging.ConnectionTracer {
return t.createNewConnTracer()
func newTracer(tracer logging.ConnectionTracer) func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer {
return func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer { return tracer }
}
type packet struct {

View File

@@ -2,7 +2,6 @@ package self_test
import (
"context"
"errors"
"fmt"
"math/rand"
"net"
@@ -25,9 +24,16 @@ var _ = Describe("Stateless Resets", func() {
It(fmt.Sprintf("sends and recognizes stateless resets, for %d byte connection IDs", connIDLen), func() {
var statelessResetKey quic.StatelessResetKey
rand.Read(statelessResetKey[:])
serverConfig := getQuicConfig(&quic.Config{StatelessResetKey: &statelessResetKey})
ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig)
c, err := net.ListenUDP("udp", nil)
Expect(err).ToNot(HaveOccurred())
tr := &quic.Transport{
Conn: c,
StatelessResetKey: &statelessResetKey,
ConnectionIDLength: connIDLen,
}
defer tr.Close()
ln, err := tr.Listen(getTLSConfig(), getQuicConfig(nil))
Expect(err).ToNot(HaveOccurred())
serverPort := ln.Addr().(*net.UDPAddr).Port
@@ -42,7 +48,8 @@ var _ = Describe("Stateless Resets", func() {
_, err = str.Write([]byte("foobar"))
Expect(err).ToNot(HaveOccurred())
<-closeServer
ln.Close()
Expect(ln.Close()).To(Succeed())
Expect(tr.Close()).To(Succeed())
}()
var drop atomic.Bool
@@ -55,14 +62,21 @@ var _ = Describe("Stateless Resets", func() {
Expect(err).ToNot(HaveOccurred())
defer proxy.Close()
conn, err := quic.DialAddr(
addr, err := net.ResolveUDPAddr("udp", "localhost:0")
Expect(err).ToNot(HaveOccurred())
udpConn, err := net.ListenUDP("udp", addr)
Expect(err).ToNot(HaveOccurred())
defer udpConn.Close()
cl := &quic.Transport{
Conn: udpConn,
ConnectionIDLength: connIDLen,
}
defer cl.Close()
conn, err := cl.Dial(
context.Background(),
fmt.Sprintf("localhost:%d", proxy.LocalPort()),
&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: proxy.LocalPort()},
getTLSClientConfig(),
getQuicConfig(&quic.Config{
ConnectionIDLength: connIDLen,
MaxIdleTimeout: 2 * time.Second,
}),
getQuicConfig(&quic.Config{MaxIdleTimeout: 2 * time.Second}),
)
Expect(err).ToNot(HaveOccurred())
str, err := conn.AcceptStream(context.Background())
@@ -77,11 +91,15 @@ var _ = Describe("Stateless Resets", func() {
close(closeServer)
time.Sleep(100 * time.Millisecond)
ln2, err := quic.ListenAddr(
fmt.Sprintf("localhost:%d", serverPort),
getTLSConfig(),
serverConfig,
)
// We need to create a new Transport here, since the old one is still sending out
// CONNECTION_CLOSE packets for (recently) closed connections).
tr2 := &quic.Transport{
Conn: c,
ConnectionIDLength: connIDLen,
StatelessResetKey: &statelessResetKey,
}
defer tr2.Close()
ln2, err := tr2.Listen(getTLSConfig(), getQuicConfig(nil))
Expect(err).ToNot(HaveOccurred())
drop.Store(false)
@@ -100,8 +118,7 @@ var _ = Describe("Stateless Resets", func() {
_, serr = str.Read([]byte{0})
}
Expect(serr).To(HaveOccurred())
statelessResetErr := &quic.StatelessResetError{}
Expect(errors.As(serr, &statelessResetErr)).To(BeTrue())
Expect(serr).To(BeAssignableToTypeOf(&quic.StatelessResetError{}))
Expect(ln2.Close()).To(Succeed())
Eventually(acceptStopped).Should(BeClosed())
})

View File

@@ -94,6 +94,8 @@ var _ = Describe("Bidirectional streams", func() {
)
Expect(err).ToNot(HaveOccurred())
runSendingPeer(client)
client.CloseWithError(0, "")
<-conn.Context().Done()
})
It(fmt.Sprintf("server opening %d streams to a client", numStreams), func() {
@@ -149,5 +151,6 @@ var _ = Describe("Bidirectional streams", func() {
runReceivingPeer(client)
<-done1
<-done2
client.CloseWithError(0, "")
})
})

View File

@@ -201,7 +201,7 @@ var _ = Describe("Timeout tests", func() {
getTLSClientConfig(),
getQuicConfig(&quic.Config{
MaxIdleTimeout: idleTimeout,
Tracer: newTracer(func() logging.ConnectionTracer { return tr }),
Tracer: newTracer(tr),
DisablePathMTUDiscovery: true,
}),
)
@@ -473,6 +473,7 @@ var _ = Describe("Timeout tests", func() {
}),
)
Expect(err).ToNot(HaveOccurred())
defer ln.Close()
serverErrChan := make(chan error, 1)
go func() {

View File

@@ -19,14 +19,6 @@ import (
. "github.com/onsi/gomega"
)
type customTracer struct{ logging.NullTracer }
func (t *customTracer) TracerForConnection(context.Context, logging.Perspective, logging.ConnectionID) logging.ConnectionTracer {
return &customConnTracer{}
}
type customConnTracer struct{ logging.NullConnectionTracer }
var _ = Describe("Handshake tests", func() {
addTracers := func(pers protocol.Perspective, conf *quic.Config) *quic.Config {
enableQlog := mrand.Int()%3 != 0
@@ -34,22 +26,32 @@ var _ = Describe("Handshake tests", func() {
fmt.Fprintf(GinkgoWriter, "%s using qlog: %t, custom: %t\n", pers, enableQlog, enableCustomTracer)
var tracers []logging.Tracer
var tracerConstructors []func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer
if enableQlog {
tracers = append(tracers, qlog.NewTracer(func(p logging.Perspective, connectionID []byte) io.WriteCloser {
tracerConstructors = append(tracerConstructors, func(_ context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer {
if mrand.Int()%2 == 0 { // simulate that a qlog collector might only want to log some connections
fmt.Fprintf(GinkgoWriter, "%s qlog tracer deciding to not trace connection %x\n", p, connectionID)
fmt.Fprintf(GinkgoWriter, "%s qlog tracer deciding to not trace connection %x\n", p, connID)
return nil
}
fmt.Fprintf(GinkgoWriter, "%s qlog tracing connection %x\n", p, connectionID)
return utils.NewBufferedWriteCloser(bufio.NewWriter(&bytes.Buffer{}), io.NopCloser(nil))
}))
fmt.Fprintf(GinkgoWriter, "%s qlog tracing connection %x\n", p, connID)
return qlog.NewConnectionTracer(utils.NewBufferedWriteCloser(bufio.NewWriter(&bytes.Buffer{}), io.NopCloser(nil)), p, connID)
})
}
if enableCustomTracer {
tracers = append(tracers, &customTracer{})
tracerConstructors = append(tracerConstructors, func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer {
return logging.NullConnectionTracer{}
})
}
c := conf.Clone()
c.Tracer = logging.NewMultiplexedTracer(tracers...)
c.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer {
tracers := make([]logging.ConnectionTracer, 0, len(tracerConstructors))
for _, c := range tracerConstructors {
if tr := c(ctx, p, connID); tr != nil {
tracers = append(tracers, tr)
}
}
return logging.NewMultiplexedConnectionTracer(tracers...)
}
return c
}

View File

@@ -88,11 +88,14 @@ var _ = Describe("Unidirectional Streams", func() {
})
It(fmt.Sprintf("server opening %d streams to a client", numStreams), func() {
done := make(chan struct{})
go func() {
defer GinkgoRecover()
defer close(done)
conn, err := server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
runSendingPeer(conn)
<-conn.Context().Done()
}()
client, err := quic.DialAddr(
@@ -103,6 +106,7 @@ var _ = Describe("Unidirectional Streams", func() {
)
Expect(err).ToNot(HaveOccurred())
runReceivingPeer(client)
client.CloseWithError(0, "")
})
It(fmt.Sprintf("client and server opening %d streams each and sending data to the peer", numStreams), func() {

View File

@@ -54,7 +54,7 @@ var _ = Describe("0-RTT", func() {
if serverConf == nil {
serverConf = getQuicConfig(nil)
}
serverConf.Allow0RTT = func(addr net.Addr) bool { return true }
serverConf.Allow0RTT = true
ln, err := quic.ListenAddrEarly(
"localhost:0",
tlsConf,
@@ -101,6 +101,7 @@ var _ = Describe("0-RTT", func() {
transfer0RTTData := func(
ln *quic.EarlyListener,
proxyPort int,
connIDLen int,
clientTLSConf *tls.Config,
clientConf *quic.Config,
testdata []byte, // data to transfer
@@ -125,13 +126,35 @@ var _ = Describe("0-RTT", func() {
if clientConf == nil {
clientConf = getQuicConfig(nil)
}
conn, err := quic.DialAddrEarly(
context.Background(),
fmt.Sprintf("localhost:%d", proxyPort),
clientTLSConf,
clientConf,
)
Expect(err).ToNot(HaveOccurred())
var conn quic.EarlyConnection
if connIDLen == 0 {
var err error
conn, err = quic.DialAddrEarly(
context.Background(),
fmt.Sprintf("localhost:%d", proxyPort),
clientTLSConf,
clientConf,
)
Expect(err).ToNot(HaveOccurred())
} else {
addr, err := net.ResolveUDPAddr("udp", "localhost:0")
Expect(err).ToNot(HaveOccurred())
udpConn, err := net.ListenUDP("udp", addr)
Expect(err).ToNot(HaveOccurred())
defer udpConn.Close()
tr := &quic.Transport{
Conn: udpConn,
ConnectionIDLength: connIDLen,
}
defer tr.Close()
conn, err = tr.DialEarly(
context.Background(),
&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: proxyPort},
clientTLSConf,
clientConf,
)
Expect(err).ToNot(HaveOccurred())
}
defer conn.CloseWithError(0, "")
str, err := conn.OpenStream()
Expect(err).ToNot(HaveOccurred())
@@ -199,8 +222,8 @@ var _ = Describe("0-RTT", func() {
"localhost:0",
tlsConf,
getQuicConfig(&quic.Config{
Allow0RTT: func(addr net.Addr) bool { return true },
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
Allow0RTT: true,
Tracer: newTracer(tracer),
}),
)
Expect(err).ToNot(HaveOccurred())
@@ -212,8 +235,9 @@ var _ = Describe("0-RTT", func() {
transfer0RTTData(
ln,
proxy.LocalPort(),
connIDLen,
clientTLSConf,
getQuicConfig(&quic.Config{ConnectionIDLength: connIDLen}),
getQuicConfig(nil),
PRData,
)
@@ -252,8 +276,8 @@ var _ = Describe("0-RTT", func() {
"localhost:0",
tlsConf,
getQuicConfig(&quic.Config{
Allow0RTT: func(net.Addr) bool { return true },
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
Allow0RTT: true,
Tracer: newTracer(tracer),
}),
)
Expect(err).ToNot(HaveOccurred())
@@ -334,8 +358,8 @@ var _ = Describe("0-RTT", func() {
"localhost:0",
tlsConf,
getQuicConfig(&quic.Config{
Allow0RTT: func(net.Addr) bool { return true },
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
Allow0RTT: true,
Tracer: newTracer(tracer),
}),
)
Expect(err).ToNot(HaveOccurred())
@@ -373,7 +397,7 @@ var _ = Describe("0-RTT", func() {
Expect(err).ToNot(HaveOccurred())
defer proxy.Close()
transfer0RTTData(ln, proxy.LocalPort(), clientConf, nil, PRData)
transfer0RTTData(ln, proxy.LocalPort(), protocol.DefaultConnectionIDLength, clientConf, nil, PRData)
num0RTT := atomic.LoadUint32(&num0RTTPackets)
numDropped := atomic.LoadUint32(&num0RTTDropped)
@@ -410,8 +434,8 @@ var _ = Describe("0-RTT", func() {
tlsConf,
getQuicConfig(&quic.Config{
RequireAddressValidation: func(net.Addr) bool { return true },
Allow0RTT: func(net.Addr) bool { return true },
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
Allow0RTT: true,
Tracer: newTracer(tracer),
}),
)
Expect(err).ToNot(HaveOccurred())
@@ -448,7 +472,7 @@ var _ = Describe("0-RTT", func() {
Expect(err).ToNot(HaveOccurred())
defer proxy.Close()
transfer0RTTData(ln, proxy.LocalPort(), clientConf, nil, GeneratePRData(5000)) // ~5 packets
transfer0RTTData(ln, proxy.LocalPort(), protocol.DefaultConnectionIDLength, clientConf, nil, GeneratePRData(5000)) // ~5 packets
mutex.Lock()
defer mutex.Unlock()
@@ -471,8 +495,8 @@ var _ = Describe("0-RTT", func() {
tlsConf,
getQuicConfig(&quic.Config{
MaxIncomingUniStreams: maxStreams + 1,
Allow0RTT: func(net.Addr) bool { return true },
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
Allow0RTT: true,
Tracer: newTracer(tracer),
}),
)
Expect(err).ToNot(HaveOccurred())
@@ -516,8 +540,8 @@ var _ = Describe("0-RTT", func() {
tlsConf,
getQuicConfig(&quic.Config{
MaxIncomingStreams: maxStreams - 1,
Allow0RTT: func(net.Addr) bool { return true },
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
Allow0RTT: true,
Tracer: newTracer(tracer),
}),
)
Expect(err).ToNot(HaveOccurred())
@@ -544,8 +568,8 @@ var _ = Describe("0-RTT", func() {
"localhost:0",
tlsConf,
getQuicConfig(&quic.Config{
Allow0RTT: func(net.Addr) bool { return true },
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
Allow0RTT: true,
Tracer: newTracer(tracer),
}),
)
Expect(err).ToNot(HaveOccurred())
@@ -571,8 +595,8 @@ var _ = Describe("0-RTT", func() {
"localhost:0",
tlsConf,
getQuicConfig(&quic.Config{
Allow0RTT: func(net.Addr) bool { return false }, // application rejects 0-RTT
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
Allow0RTT: false, // application rejects 0-RTT
Tracer: newTracer(tracer),
}),
)
Expect(err).ToNot(HaveOccurred())
@@ -592,13 +616,13 @@ var _ = Describe("0-RTT", func() {
DescribeTable("flow control limits",
func(addFlowControlLimit func(*quic.Config, uint64)) {
tracer := newPacketTracer()
firstConf := getQuicConfig(&quic.Config{Allow0RTT: func(net.Addr) bool { return true }})
firstConf := getQuicConfig(&quic.Config{Allow0RTT: true})
addFlowControlLimit(firstConf, 3)
tlsConf, clientConf := dialAndReceiveSessionTicket(firstConf)
secondConf := getQuicConfig(&quic.Config{
Allow0RTT: func(net.Addr) bool { return true },
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
Allow0RTT: true,
Tracer: newTracer(tracer),
})
addFlowControlLimit(secondConf, 100)
ln, err := quic.ListenAddrEarly(
@@ -675,7 +699,7 @@ var _ = Describe("0-RTT", func() {
tlsConf,
getQuicConfig(&quic.Config{
MaxIncomingUniStreams: 1,
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
Tracer: newTracer(tracer),
}),
)
Expect(err).ToNot(HaveOccurred())
@@ -750,8 +774,8 @@ var _ = Describe("0-RTT", func() {
"localhost:0",
tlsConf,
getQuicConfig(&quic.Config{
Allow0RTT: func(net.Addr) bool { return true },
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
Allow0RTT: true,
Tracer: newTracer(tracer),
}),
)
Expect(err).ToNot(HaveOccurred())
@@ -768,7 +792,7 @@ var _ = Describe("0-RTT", func() {
Expect(err).ToNot(HaveOccurred())
defer proxy.Close()
transfer0RTTData(ln, proxy.LocalPort(), clientConf, nil, PRData)
transfer0RTTData(ln, proxy.LocalPort(), protocol.DefaultConnectionIDLength, clientConf, nil, PRData)
Expect(tracer.getRcvdLongHeaderPackets()[0].hdr.Type).To(Equal(protocol.PacketTypeInitial))
zeroRTTPackets := get0RTTPackets(tracer.getRcvdLongHeaderPackets())

View File

@@ -2,23 +2,25 @@ package tools
import (
"bufio"
"context"
"fmt"
"io"
"log"
"os"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/logging"
"github.com/quic-go/quic-go/qlog"
)
func NewQlogger(logger io.Writer) logging.Tracer {
return qlog.NewTracer(func(p logging.Perspective, connectionID []byte) io.WriteCloser {
func NewQlogger(logger io.Writer) func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer {
return func(_ context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer {
role := "server"
if p == logging.PerspectiveClient {
role = "client"
}
filename := fmt.Sprintf("log_%x_%s.qlog", connectionID, role)
filename := fmt.Sprintf("log_%x_%s.qlog", connID.Bytes(), role)
fmt.Fprintf(logger, "Creating %s.\n", filename)
f, err := os.Create(filename)
if err != nil {
@@ -26,6 +28,6 @@ func NewQlogger(logger io.Writer) logging.Tracer {
return nil
}
bw := bufio.NewWriter(f)
return utils.NewBufferedWriteCloser(bw, f)
})
return qlog.NewConnectionTracer(utils.NewBufferedWriteCloser(bw, f), p, connID)
}
}

View File

@@ -85,7 +85,9 @@ var _ = Describe("Handshake tests", func() {
serverConfig := &quic.Config{}
serverConfig.Versions = []protocol.VersionNumber{7, 8, protocol.SupportedVersions[0], 9}
serverTracer := &versionNegotiationTracer{}
serverConfig.Tracer = newTracer(func() logging.ConnectionTracer { return serverTracer })
serverConfig.Tracer = func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer {
return serverTracer
}
server, cl := startServer(getTLSConfig(), serverConfig)
defer cl()
clientTracer := &versionNegotiationTracer{}
@@ -93,7 +95,9 @@ var _ = Describe("Handshake tests", func() {
context.Background(),
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
getTLSClientConfig(),
maybeAddQlogTracer(&quic.Config{Tracer: newTracer(func() logging.ConnectionTracer { return clientTracer })}),
maybeAddQLOGTracer(&quic.Config{Tracer: func(ctx context.Context, perspective logging.Perspective, id quic.ConnectionID) logging.ConnectionTracer {
return clientTracer
}}),
)
Expect(err).ToNot(HaveOccurred())
Expect(conn.(versioner).GetVersion()).To(Equal(expectedVersion))
@@ -111,10 +115,12 @@ var _ = Describe("Handshake tests", func() {
expectedVersion := protocol.SupportedVersions[0]
// the server doesn't support the highest supported version, which is the first one the client will try
// but it supports a bunch of versions that the client doesn't speak
serverTracer := &versionNegotiationTracer{}
serverConfig := &quic.Config{}
serverConfig.Versions = supportedVersions
serverTracer := &versionNegotiationTracer{}
serverConfig.Tracer = newTracer(func() logging.ConnectionTracer { return serverTracer })
serverConfig.Tracer = func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer {
return serverTracer
}
server, cl := startServer(getTLSConfig(), serverConfig)
defer cl()
clientVersions := []protocol.VersionNumber{7, 8, 9, protocol.SupportedVersions[0], 10}
@@ -123,9 +129,11 @@ var _ = Describe("Handshake tests", func() {
context.Background(),
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
getTLSClientConfig(),
maybeAddQlogTracer(&quic.Config{
maybeAddQLOGTracer(&quic.Config{
Versions: clientVersions,
Tracer: newTracer(func() logging.ConnectionTracer { return clientTracer }),
Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer {
return clientTracer
},
}),
)
Expect(err).ToNot(HaveOccurred())

View File

@@ -47,7 +47,7 @@ var _ = Describe("Handshake RTT tests", func() {
context.Background(),
proxy.LocalAddr().String(),
getTLSClientConfig(),
maybeAddQlogTracer(&quic.Config{Versions: protocol.SupportedVersions[1:2]}),
maybeAddQLOGTracer(&quic.Config{Versions: protocol.SupportedVersions[1:2]}),
)
Expect(err).To(HaveOccurred())
expectDurationInRTTs(startTime, 1)

View File

@@ -58,7 +58,7 @@ func TestQuicVersionNegotiation(t *testing.T) {
RunSpecs(t, "Version Negotiation Suite")
}
func maybeAddQlogTracer(c *quic.Config) *quic.Config {
func maybeAddQLOGTracer(c *quic.Config) *quic.Config {
if c == nil {
c = &quic.Config{}
}
@@ -69,22 +69,13 @@ func maybeAddQlogTracer(c *quic.Config) *quic.Config {
if c.Tracer == nil {
c.Tracer = qlogger
} else if qlogger != nil {
c.Tracer = logging.NewMultiplexedTracer(qlogger, c.Tracer)
origTracer := c.Tracer
c.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer {
return logging.NewMultiplexedConnectionTracer(
qlogger(ctx, p, connID),
origTracer(ctx, p, connID),
)
}
}
return c
}
type tracer struct {
logging.NullTracer
createNewConnTracer func() logging.ConnectionTracer
}
var _ logging.Tracer = &tracer{}
func newTracer(c func() logging.ConnectionTracer) logging.Tracer {
return &tracer{createNewConnTracer: c}
}
func (t *tracer) TracerForConnection(context.Context, logging.Perspective, logging.ConnectionID) logging.ConnectionTracer {
return t.createNewConnTracer()
}

View File

@@ -239,21 +239,12 @@ type ConnectionIDGenerator interface {
// Config contains all configuration data needed for a QUIC server or client.
type Config struct {
// GetConfigForClient is called for incoming connections.
// If the error is not nil, the connection attempt is refused.
GetConfigForClient func(info *ClientHelloInfo) (*Config, error)
// The QUIC versions that can be negotiated.
// If not set, it uses all versions available.
Versions []VersionNumber
// The length of the connection ID in bytes.
// It can be 0, or any value between 4 and 18.
// If not set, the interpretation depends on where the Config is used:
// If used for dialing an address, a 0 byte connection ID will be used.
// If used for a server, or dialing on a packet conn, a 4 byte connection ID will be used.
// When dialing on a packet conn, the ConnectionIDLength value must be the same for every Dial call.
ConnectionIDLength int
// An optional ConnectionIDGenerator to be used for ConnectionIDs generated during the lifecycle of a QUIC connection.
// The goal is to give some control on how connection IDs, which can be useful in some scenarios, in particular for servers.
// By default, if not provided, random connection IDs with the length given by ConnectionIDLength is used.
// Otherwise, if one is provided, then ConnectionIDLength is ignored.
ConnectionIDGenerator ConnectionIDGenerator
// HandshakeIdleTimeout is the idle timeout before completion of the handshake.
// Specifically, if we don't receive any packet from the peer within this time, the connection attempt is aborted.
// If this value is zero, the timeout is set to 5 seconds.
@@ -314,9 +305,6 @@ type Config struct {
// If not set, it will default to 100.
// If set to a negative value, it doesn't allow any unidirectional streams.
MaxIncomingUniStreams int64
// The StatelessResetKey is used to generate stateless reset tokens.
// If no key is configured, sending of stateless resets is disabled.
StatelessResetKey *StatelessResetKey
// KeepAlivePeriod defines whether this peer will periodically send a packet to keep the connection alive.
// If set to 0, then no keep alive is sent. Otherwise, the keep alive is sent on that period (or at most
// every half of MaxIdleTimeout, whichever is smaller).
@@ -330,13 +318,15 @@ type Config struct {
// It has no effect for a client.
DisableVersionNegotiationPackets bool
// Allow0RTT allows the application to decide if a 0-RTT connection attempt should be accepted.
// When set, 0-RTT is enabled. When not set, 0-RTT is disabled.
// Only valid for the server.
// Warning: This API should not be considered stable and might change soon.
Allow0RTT func(net.Addr) bool
Allow0RTT bool
// Enable QUIC datagram support (RFC 9221).
EnableDatagrams bool
Tracer logging.Tracer
Tracer func(context.Context, logging.Perspective, ConnectionID) logging.ConnectionTracer
}
type ClientHelloInfo struct {
RemoteAddr net.Addr
}
// ConnectionState records basic details about a QUIC connection

View File

@@ -116,7 +116,7 @@ type cryptoSetup struct {
clientHelloWritten bool
clientHelloWrittenChan chan struct{} // is closed as soon as the ClientHello is written
zeroRTTParametersChan chan<- *wire.TransportParameters
allow0RTT func() bool
allow0RTT bool
rttStats *utils.RTTStats
@@ -197,7 +197,7 @@ func NewCryptoSetupServer(
tp *wire.TransportParameters,
runner handshakeRunner,
tlsConf *tls.Config,
allow0RTT func() bool,
allow0RTT bool,
rttStats *utils.RTTStats,
tracer logging.ConnectionTracer,
logger utils.Logger,
@@ -210,14 +210,13 @@ func NewCryptoSetupServer(
tp,
runner,
tlsConf,
allow0RTT != nil,
allow0RTT,
rttStats,
tracer,
logger,
protocol.PerspectiveServer,
version,
)
cs.allow0RTT = allow0RTT
cs.conn = qtls.Server(newConn(localAddr, remoteAddr), cs.tlsConf, cs.extraConf)
return cs
}
@@ -253,6 +252,7 @@ func newCryptoSetup(
readEncLevel: protocol.EncryptionInitial,
writeEncLevel: protocol.EncryptionInitial,
runner: runner,
allow0RTT: enable0RTT,
ourParams: tp,
paramsChan: extHandler.TransportParameters(),
rttStats: rttStats,
@@ -503,7 +503,7 @@ func (h *cryptoSetup) accept0RTT(sessionTicketData []byte) bool {
h.logger.Debugf("Transport parameters changed. Rejecting 0-RTT.")
return false
}
if !h.allow0RTT() {
if !h.allow0RTT {
h.logger.Debugf("0-RTT not allowed. Rejecting 0-RTT.")
return false
}

View File

@@ -95,7 +95,7 @@ var _ = Describe("Crypto Setup TLS", func() {
&wire.TransportParameters{StatelessResetToken: &token},
runner,
testdata.GetTLSConfig(),
nil,
false,
&utils.RTTStats{},
nil,
utils.DefaultLogger.WithPrefix("server"),
@@ -177,7 +177,7 @@ var _ = Describe("Crypto Setup TLS", func() {
&wire.TransportParameters{StatelessResetToken: &token},
runner,
testdata.GetTLSConfig(),
nil,
false,
&utils.RTTStats{},
nil,
utils.DefaultLogger.WithPrefix("server"),
@@ -218,7 +218,7 @@ var _ = Describe("Crypto Setup TLS", func() {
&wire.TransportParameters{StatelessResetToken: &token},
runner,
serverConf,
nil,
false,
&utils.RTTStats{},
nil,
utils.DefaultLogger.WithPrefix("server"),
@@ -253,7 +253,7 @@ var _ = Describe("Crypto Setup TLS", func() {
&wire.TransportParameters{StatelessResetToken: &token},
NewMockHandshakeRunner(mockCtrl),
serverConf,
nil,
false,
&utils.RTTStats{},
nil,
utils.DefaultLogger.WithPrefix("server"),
@@ -378,10 +378,6 @@ var _ = Describe("Crypto Setup TLS", func() {
protocol.Version1,
)
var allow0RTT func() bool
if enable0RTT {
allow0RTT = func() bool { return true }
}
var sHandshakeComplete bool
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
sErrChan := make(chan error, 1)
@@ -402,7 +398,7 @@ var _ = Describe("Crypto Setup TLS", func() {
serverTransportParameters,
sRunner,
serverConf,
allow0RTT,
enable0RTT,
serverRTTStats,
nil,
utils.DefaultLogger.WithPrefix("server"),
@@ -541,7 +537,7 @@ var _ = Describe("Crypto Setup TLS", func() {
sTransportParameters,
sRunner,
serverConf,
nil,
false,
&utils.RTTStats{},
nil,
utils.DefaultLogger.WithPrefix("server"),
@@ -596,7 +592,7 @@ var _ = Describe("Crypto Setup TLS", func() {
&wire.TransportParameters{ActiveConnectionIDLimit: 2, StatelessResetToken: &token},
sRunner,
serverConf,
nil,
false,
&utils.RTTStats{},
nil,
utils.DefaultLogger.WithPrefix("server"),
@@ -655,7 +651,7 @@ var _ = Describe("Crypto Setup TLS", func() {
&wire.TransportParameters{ActiveConnectionIDLimit: 2, StatelessResetToken: &token},
sRunner,
serverConf,
nil,
false,
&utils.RTTStats{},
nil,
utils.DefaultLogger.WithPrefix("server"),

View File

@@ -5,7 +5,6 @@
package mocklogging
import (
context "context"
net "net"
reflect "reflect"
@@ -73,17 +72,3 @@ func (mr *MockTracerMockRecorder) SentVersionNegotiationPacket(arg0, arg1, arg2,
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentVersionNegotiationPacket", reflect.TypeOf((*MockTracer)(nil).SentVersionNegotiationPacket), arg0, arg1, arg2, arg3)
}
// TracerForConnection mocks base method.
func (m *MockTracer) TracerForConnection(arg0 context.Context, arg1 protocol.Perspective, arg2 protocol.ConnectionID) logging.ConnectionTracer {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "TracerForConnection", arg0, arg1, arg2)
ret0, _ := ret[0].(logging.ConnectionTracer)
return ret0
}
// TracerForConnection indicates an expected call of TracerForConnection.
func (mr *MockTracerMockRecorder) TracerForConnection(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TracerForConnection", reflect.TypeOf((*MockTracer)(nil).TracerForConnection), arg0, arg1, arg2)
}

View File

@@ -21,7 +21,6 @@ import (
"github.com/quic-go/quic-go/internal/qtls"
"github.com/quic-go/quic-go/interop/http09"
"github.com/quic-go/quic-go/interop/utils"
"github.com/quic-go/quic-go/qlog"
)
var errUnsupported = errors.New("unsupported test case")
@@ -65,11 +64,7 @@ func runTestcase(testcase string) error {
flag.Parse()
urls := flag.Args()
getLogWriter, err := utils.GetQLOGWriter()
if err != nil {
return err
}
quicConf := &quic.Config{Tracer: qlog.NewTracer(getLogWriter)}
quicConf := &quic.Config{Tracer: utils.NewQLOGConnectionTracer}
if testcase == "http3" {
r := &http3.RoundTripper{

View File

@@ -13,7 +13,6 @@ import (
"github.com/quic-go/quic-go/internal/qtls"
"github.com/quic-go/quic-go/interop/http09"
"github.com/quic-go/quic-go/interop/utils"
"github.com/quic-go/quic-go/qlog"
)
var tlsConf *tls.Config
@@ -38,15 +37,10 @@ func main() {
testcase := os.Getenv("TESTCASE")
getLogWriter, err := utils.GetQLOGWriter()
if err != nil {
fmt.Println(err.Error())
os.Exit(1)
}
// a quic.Config that doesn't do a Retry
quicConf := &quic.Config{
RequireAddressValidation: func(net.Addr) bool { return testcase == "retry" },
Tracer: qlog.NewTracer(getLogWriter),
Allow0RTT: testcase == "zerortt",
Tracer: utils.NewQLOGConnectionTracer,
}
cert, err := tls.LoadX509KeyPair("/certs/cert.pem", "/certs/priv.key")
if err != nil {
@@ -59,10 +53,7 @@ func main() {
}
switch testcase {
case "zerortt":
quicConf.Allow0RTT = func(net.Addr) bool { return true }
fallthrough
case "versionnegotiation", "handshake", "retry", "transfer", "resumption", "multiconnect":
case "versionnegotiation", "handshake", "retry", "transfer", "resumption", "multiconnect", "zerortt":
err = runHTTP09Server(quicConf)
case "chacha20":
reset := qtls.SetCipherSuite(tls.TLS_CHACHA20_POLY1305_SHA256)

View File

@@ -2,14 +2,17 @@ package utils
import (
"bufio"
"context"
"fmt"
"io"
"log"
"os"
"strings"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/logging"
"github.com/quic-go/quic-go/qlog"
)
// GetSSLKeyLog creates a file for the TLS key log
@@ -25,25 +28,23 @@ func GetSSLKeyLog() (io.WriteCloser, error) {
return f, nil
}
// GetQLOGWriter creates the QLOGDIR and returns the GetLogWriter callback
func GetQLOGWriter() (func(perspective logging.Perspective, connID []byte) io.WriteCloser, error) {
// NewQLOGConnectionTracer create a qlog file in QLOGDIR
func NewQLOGConnectionTracer(_ context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer {
qlogDir := os.Getenv("QLOGDIR")
if len(qlogDir) == 0 {
return nil, nil
return nil
}
if _, err := os.Stat(qlogDir); os.IsNotExist(err) {
if err := os.MkdirAll(qlogDir, 0o666); err != nil {
return nil, fmt.Errorf("failed to create qlog dir %s: %s", qlogDir, err.Error())
log.Fatalf("failed to create qlog dir %s: %v", qlogDir, err)
}
}
return func(_ logging.Perspective, connID []byte) io.WriteCloser {
path := fmt.Sprintf("%s/%x.qlog", strings.TrimRight(qlogDir, "/"), connID)
f, err := os.Create(path)
if err != nil {
log.Printf("Failed to create qlog file %s: %s", path, err.Error())
return nil
}
log.Printf("Created qlog file: %s\n", path)
return utils.NewBufferedWriteCloser(bufio.NewWriter(f), f)
}, nil
path := fmt.Sprintf("%s/%x.qlog", strings.TrimRight(qlogDir, "/"), connID)
f, err := os.Create(path)
if err != nil {
log.Printf("Failed to create qlog file %s: %s", path, err.Error())
return nil
}
log.Printf("Created qlog file: %s\n", path)
return qlog.NewConnectionTracer(utils.NewBufferedWriteCloser(bufio.NewWriter(f), f), p, connID)
}

View File

@@ -3,7 +3,6 @@
package logging
import (
"context"
"net"
"time"
@@ -101,12 +100,6 @@ type ShortHeader struct {
// A Tracer traces events.
type Tracer interface {
// TracerForConnection requests a new tracer for a connection.
// The ODCID is the original destination connection ID:
// The destination connection ID that the client used on the first Initial packet it sent on this connection.
// If nil is returned, tracing will be disabled for this connection.
TracerForConnection(ctx context.Context, p Perspective, odcid ConnectionID) ConnectionTracer
SentPacket(net.Addr, *Header, ByteCount, []Frame)
SentVersionNegotiationPacket(_ net.Addr, dest, src ArbitraryLenConnectionID, _ []VersionNumber)
DroppedPacket(net.Addr, PacketType, ByteCount, PacketDropReason)

View File

@@ -5,7 +5,6 @@
package logging
import (
context "context"
net "net"
reflect "reflect"
@@ -72,17 +71,3 @@ func (mr *MockTracerMockRecorder) SentVersionNegotiationPacket(arg0, arg1, arg2,
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentVersionNegotiationPacket", reflect.TypeOf((*MockTracer)(nil).SentVersionNegotiationPacket), arg0, arg1, arg2, arg3)
}
// TracerForConnection mocks base method.
func (m *MockTracer) TracerForConnection(arg0 context.Context, arg1 protocol.Perspective, arg2 protocol.ConnectionID) ConnectionTracer {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "TracerForConnection", arg0, arg1, arg2)
ret0, _ := ret[0].(ConnectionTracer)
return ret0
}
// TracerForConnection indicates an expected call of TracerForConnection.
func (mr *MockTracerMockRecorder) TracerForConnection(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TracerForConnection", reflect.TypeOf((*MockTracer)(nil).TracerForConnection), arg0, arg1, arg2)
}

View File

@@ -1,7 +1,6 @@
package logging
import (
"context"
"net"
"time"
)
@@ -23,16 +22,6 @@ func NewMultiplexedTracer(tracers ...Tracer) Tracer {
return &tracerMultiplexer{tracers}
}
func (m *tracerMultiplexer) TracerForConnection(ctx context.Context, p Perspective, odcid ConnectionID) ConnectionTracer {
var connTracers []ConnectionTracer
for _, t := range m.tracers {
if ct := t.TracerForConnection(ctx, p, odcid); ct != nil {
connTracers = append(connTracers, ct)
}
}
return NewMultiplexedConnectionTracer(connTracers...)
}
func (m *tracerMultiplexer) SentPacket(remote net.Addr, hdr *Header, size ByteCount, frames []Frame) {
for _, t := range m.tracers {
t.SentPacket(remote, hdr, size, frames)

View File

@@ -1,7 +1,6 @@
package logging
import (
"context"
"errors"
"net"
"time"
@@ -37,46 +36,6 @@ var _ = Describe("Tracing", func() {
tracer = NewMultiplexedTracer(tr1, tr2)
})
It("multiplexes the TracerForConnection call", func() {
ctx := context.Background()
connID := protocol.ParseConnectionID([]byte{1, 2, 3})
tr1.EXPECT().TracerForConnection(ctx, PerspectiveClient, connID)
tr2.EXPECT().TracerForConnection(ctx, PerspectiveClient, connID)
tracer.TracerForConnection(ctx, PerspectiveClient, connID)
})
It("uses multiple connection tracers", func() {
ctx := context.Background()
ctr1 := NewMockConnectionTracer(mockCtrl)
ctr2 := NewMockConnectionTracer(mockCtrl)
connID := protocol.ParseConnectionID([]byte{1, 2, 3})
tr1.EXPECT().TracerForConnection(ctx, PerspectiveServer, connID).Return(ctr1)
tr2.EXPECT().TracerForConnection(ctx, PerspectiveServer, connID).Return(ctr2)
tr := tracer.TracerForConnection(ctx, PerspectiveServer, connID)
ctr1.EXPECT().LossTimerCanceled()
ctr2.EXPECT().LossTimerCanceled()
tr.LossTimerCanceled()
})
It("handles tracers that return a nil ConnectionTracer", func() {
ctx := context.Background()
ctr1 := NewMockConnectionTracer(mockCtrl)
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
tr1.EXPECT().TracerForConnection(ctx, PerspectiveServer, connID).Return(ctr1)
tr2.EXPECT().TracerForConnection(ctx, PerspectiveServer, connID)
tr := tracer.TracerForConnection(ctx, PerspectiveServer, connID)
ctr1.EXPECT().LossTimerCanceled()
tr.LossTimerCanceled()
})
It("returns nil when all tracers return a nil ConnectionTracer", func() {
ctx := context.Background()
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5})
tr1.EXPECT().TracerForConnection(ctx, PerspectiveClient, connID)
tr2.EXPECT().TracerForConnection(ctx, PerspectiveClient, connID)
Expect(tracer.TracerForConnection(ctx, PerspectiveClient, connID)).To(BeNil())
})
It("traces the PacketSent event", func() {
remote := &net.UDPAddr{IP: net.IPv4(4, 3, 2, 1)}
hdr := &Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3})}

View File

@@ -1,7 +1,6 @@
package logging
import (
"context"
"net"
"time"
)
@@ -12,9 +11,6 @@ type NullTracer struct{}
var _ Tracer = &NullTracer{}
func (n NullTracer) TracerForConnection(context.Context, Perspective, ConnectionID) ConnectionTracer {
return NullConnectionTracer{}
}
func (n NullTracer) SentPacket(net.Addr, *Header, ByteCount, []Frame) {}
func (n NullTracer) SentVersionNegotiationPacket(_ net.Addr, dest, src ArbitraryLenConnectionID, _ []VersionNumber) {
}

View File

@@ -1,65 +0,0 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/quic-go/quic-go (interfaces: Multiplexer)
// Package quic is a generated GoMock package.
package quic
import (
net "net"
reflect "reflect"
gomock "github.com/golang/mock/gomock"
logging "github.com/quic-go/quic-go/logging"
)
// MockMultiplexer is a mock of Multiplexer interface.
type MockMultiplexer struct {
ctrl *gomock.Controller
recorder *MockMultiplexerMockRecorder
}
// MockMultiplexerMockRecorder is the mock recorder for MockMultiplexer.
type MockMultiplexerMockRecorder struct {
mock *MockMultiplexer
}
// NewMockMultiplexer creates a new mock instance.
func NewMockMultiplexer(ctrl *gomock.Controller) *MockMultiplexer {
mock := &MockMultiplexer{ctrl: ctrl}
mock.recorder = &MockMultiplexerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockMultiplexer) EXPECT() *MockMultiplexerMockRecorder {
return m.recorder
}
// AddConn mocks base method.
func (m *MockMultiplexer) AddConn(arg0 net.PacketConn, arg1 int, arg2 *StatelessResetKey, arg3 logging.Tracer) (packetHandlerManager, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AddConn", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].(packetHandlerManager)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// AddConn indicates an expected call of AddConn.
func (mr *MockMultiplexerMockRecorder) AddConn(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddConn", reflect.TypeOf((*MockMultiplexer)(nil).AddConn), arg0, arg1, arg2, arg3)
}
// RemoveConn mocks base method.
func (m *MockMultiplexer) RemoveConn(arg0 indexableConn) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RemoveConn", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// RemoveConn indicates an expected call of RemoveConn.
func (mr *MockMultiplexerMockRecorder) RemoveConn(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveConn", reflect.TypeOf((*MockMultiplexer)(nil).RemoveConn), arg0)
}

View File

@@ -61,7 +61,7 @@ func (mr *MockPacketHandlerManagerMockRecorder) AddResetToken(arg0, arg1 interfa
}
// AddWithConnID mocks base method.
func (m *MockPacketHandlerManager) AddWithConnID(arg0, arg1 protocol.ConnectionID, arg2 func() packetHandler) bool {
func (m *MockPacketHandlerManager) AddWithConnID(arg0, arg1 protocol.ConnectionID, arg2 func() (packetHandler, bool)) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AddWithConnID", arg0, arg1, arg2)
ret0, _ := ret[0].(bool)
@@ -74,6 +74,18 @@ func (mr *MockPacketHandlerManagerMockRecorder) AddWithConnID(arg0, arg1, arg2 i
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddWithConnID", reflect.TypeOf((*MockPacketHandlerManager)(nil).AddWithConnID), arg0, arg1, arg2)
}
// Close mocks base method.
func (m *MockPacketHandlerManager) Close(arg0 error) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Close", arg0)
}
// Close indicates an expected call of Close.
func (mr *MockPacketHandlerManagerMockRecorder) Close(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockPacketHandlerManager)(nil).Close), arg0)
}
// CloseServer mocks base method.
func (m *MockPacketHandlerManager) CloseServer() {
m.ctrl.T.Helper()
@@ -86,20 +98,6 @@ func (mr *MockPacketHandlerManagerMockRecorder) CloseServer() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseServer", reflect.TypeOf((*MockPacketHandlerManager)(nil).CloseServer))
}
// Destroy mocks base method.
func (m *MockPacketHandlerManager) Destroy() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Destroy")
ret0, _ := ret[0].(error)
return ret0
}
// Destroy indicates an expected call of Destroy.
func (mr *MockPacketHandlerManagerMockRecorder) Destroy() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Destroy", reflect.TypeOf((*MockPacketHandlerManager)(nil).Destroy))
}
// Get mocks base method.
func (m *MockPacketHandlerManager) Get(arg0 protocol.ConnectionID) (packetHandler, bool) {
m.ctrl.T.Helper()
@@ -115,6 +113,21 @@ func (mr *MockPacketHandlerManagerMockRecorder) Get(arg0 interface{}) *gomock.Ca
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockPacketHandlerManager)(nil).Get), arg0)
}
// GetByResetToken mocks base method.
func (m *MockPacketHandlerManager) GetByResetToken(arg0 protocol.StatelessResetToken) (packetHandler, bool) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetByResetToken", arg0)
ret0, _ := ret[0].(packetHandler)
ret1, _ := ret[1].(bool)
return ret0, ret1
}
// GetByResetToken indicates an expected call of GetByResetToken.
func (mr *MockPacketHandlerManagerMockRecorder) GetByResetToken(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetByResetToken", reflect.TypeOf((*MockPacketHandlerManager)(nil).GetByResetToken), arg0)
}
// GetStatelessResetToken mocks base method.
func (m *MockPacketHandlerManager) GetStatelessResetToken(arg0 protocol.ConnectionID) protocol.StatelessResetToken {
m.ctrl.T.Helper()
@@ -176,15 +189,3 @@ func (mr *MockPacketHandlerManagerMockRecorder) Retire(arg0 interface{}) *gomock
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Retire", reflect.TypeOf((*MockPacketHandlerManager)(nil).Retire), arg0)
}
// SetServer mocks base method.
func (m *MockPacketHandlerManager) SetServer(arg0 unknownPacketHandler) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetServer", arg0)
}
// SetServer indicates an expected call of SetServer.
func (mr *MockPacketHandlerManagerMockRecorder) SetServer(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetServer", reflect.TypeOf((*MockPacketHandlerManager)(nil).SetServer), arg0)
}

View File

@@ -65,9 +65,6 @@ type UnknownPacketHandler = unknownPacketHandler
//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_packet_handler_manager_test.go github.com/quic-go/quic-go PacketHandlerManager"
type PacketHandlerManager = packetHandlerManager
//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_multiplexer_test.go github.com/quic-go/quic-go Multiplexer"
type Multiplexer = multiplexer
// Need to use source mode for the batchConn, since reflect mode follows type aliases.
// See https://github.com/golang/mock/issues/244 for details.
//

View File

@@ -6,7 +6,6 @@ import (
"sync"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/logging"
)
var (
@@ -14,30 +13,19 @@ var (
connMuxer multiplexer
)
type indexableConn interface {
LocalAddr() net.Addr
}
type indexableConn interface{ LocalAddr() net.Addr }
type multiplexer interface {
AddConn(c net.PacketConn, connIDLen int, statelessResetKey *StatelessResetKey, tracer logging.Tracer) (packetHandlerManager, error)
AddConn(conn indexableConn)
RemoveConn(indexableConn) error
}
type connManager struct {
connIDLen int
statelessResetKey *StatelessResetKey
tracer logging.Tracer
manager packetHandlerManager
}
// The connMultiplexer listens on multiple net.PacketConns and dispatches
// incoming packets to the connection handler.
type connMultiplexer struct {
mutex sync.Mutex
conns map[string] /* LocalAddr().String() */ connManager
newPacketHandlerManager func(net.PacketConn, int, *StatelessResetKey, logging.Tracer, utils.Logger) (packetHandlerManager, error) // so it can be replaced in the tests
conns map[string] /* LocalAddr().String() */ indexableConn
logger utils.Logger
}
@@ -46,57 +34,38 @@ var _ multiplexer = &connMultiplexer{}
func getMultiplexer() multiplexer {
connMuxerOnce.Do(func() {
connMuxer = &connMultiplexer{
conns: make(map[string]connManager),
logger: utils.DefaultLogger.WithPrefix("muxer"),
newPacketHandlerManager: newPacketHandlerMap,
conns: make(map[string]indexableConn),
logger: utils.DefaultLogger.WithPrefix("muxer"),
}
})
return connMuxer
}
func (m *connMultiplexer) AddConn(
c net.PacketConn,
connIDLen int,
statelessResetKey *StatelessResetKey,
tracer logging.Tracer,
) (packetHandlerManager, error) {
func (m *connMultiplexer) index(addr net.Addr) string {
return addr.Network() + " " + addr.String()
}
func (m *connMultiplexer) AddConn(c indexableConn) {
m.mutex.Lock()
defer m.mutex.Unlock()
addr := c.LocalAddr()
connIndex := addr.Network() + " " + addr.String()
connIndex := m.index(c.LocalAddr())
p, ok := m.conns[connIndex]
if !ok {
manager, err := m.newPacketHandlerManager(c, connIDLen, statelessResetKey, tracer, m.logger)
if err != nil {
return nil, err
}
p = connManager{
connIDLen: connIDLen,
statelessResetKey: statelessResetKey,
manager: manager,
tracer: tracer,
}
m.conns[connIndex] = p
} else {
if p.connIDLen != connIDLen {
return nil, fmt.Errorf("cannot use %d byte connection IDs on a connection that is already using %d byte connction IDs", connIDLen, p.connIDLen)
}
if statelessResetKey != nil && p.statelessResetKey != statelessResetKey {
return nil, fmt.Errorf("cannot use different stateless reset keys on the same packet conn")
}
if tracer != p.tracer {
return nil, fmt.Errorf("cannot use different tracers on the same packet conn")
}
if ok {
// Panics if we're already listening on this connection.
// This is a safeguard because we're introducing a breaking API change, see
// https://github.com/quic-go/quic-go/issues/3727 for details.
// We'll remove this at a later time, when most users of the library have made the switch.
panic("connection already exists") // TODO: write a nice message
}
return p.manager, nil
m.conns[connIndex] = p
}
func (m *connMultiplexer) RemoveConn(c indexableConn) error {
m.mutex.Lock()
defer m.mutex.Unlock()
connIndex := c.LocalAddr().Network() + " " + c.LocalAddr().String()
connIndex := m.index(c.LocalAddr())
if _, ok := m.conns[connIndex]; !ok {
return fmt.Errorf("cannote remove connection, connection is unknown")
}

View File

@@ -3,71 +3,24 @@ package quic
import (
"net"
mocklogging "github.com/quic-go/quic-go/internal/mocks/logging"
"github.com/golang/mock/gomock"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
type testConn struct {
counter int
net.PacketConn
}
var _ = Describe("Multiplexer", func() {
It("adds a new packet conn ", func() {
conn := NewMockPacketConn(mockCtrl)
conn.EXPECT().ReadFrom(gomock.Any()).Do(func([]byte) { <-(make(chan struct{})) }).MaxTimes(1)
conn.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234})
_, err := getMultiplexer().AddConn(conn, 8, nil, nil)
Expect(err).ToNot(HaveOccurred())
It("adds new packet conns", func() {
conn1 := NewMockPacketConn(mockCtrl)
conn1.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234})
getMultiplexer().AddConn(conn1)
conn2 := NewMockPacketConn(mockCtrl)
conn2.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1235})
getMultiplexer().AddConn(conn2)
})
It("recognizes when the same connection is added twice", func() {
srk := &StatelessResetKey{'f', 'o', 'o', 'b', 'a', 'r'}
pconn := NewMockPacketConn(mockCtrl)
pconn.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 4321}).Times(2)
pconn.EXPECT().ReadFrom(gomock.Any()).Do(func([]byte) { <-(make(chan struct{})) }).MaxTimes(1)
conn := testConn{PacketConn: pconn}
tracer := mocklogging.NewMockTracer(mockCtrl)
_, err := getMultiplexer().AddConn(conn, 8, srk, tracer)
Expect(err).ToNot(HaveOccurred())
conn.counter++
_, err = getMultiplexer().AddConn(conn, 8, srk, tracer)
Expect(err).ToNot(HaveOccurred())
Expect(getMultiplexer().(*connMultiplexer).conns).To(HaveLen(1))
})
It("errors when adding an existing conn with a different connection ID length", func() {
It("panics when the same connection is added twice", func() {
conn := NewMockPacketConn(mockCtrl)
conn.EXPECT().ReadFrom(gomock.Any()).Do(func([]byte) { <-(make(chan struct{})) }).MaxTimes(1)
conn.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}).Times(2)
_, err := getMultiplexer().AddConn(conn, 5, nil, nil)
Expect(err).ToNot(HaveOccurred())
_, err = getMultiplexer().AddConn(conn, 6, nil, nil)
Expect(err).To(MatchError("cannot use 6 byte connection IDs on a connection that is already using 5 byte connction IDs"))
})
It("errors when adding an existing conn with a different stateless rest key", func() {
srk1 := &StatelessResetKey{'f', 'o', 'o'}
srk2 := &StatelessResetKey{'b', 'a', 'r'}
conn := NewMockPacketConn(mockCtrl)
conn.EXPECT().ReadFrom(gomock.Any()).Do(func([]byte) { <-(make(chan struct{})) }).MaxTimes(1)
conn.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}).Times(2)
_, err := getMultiplexer().AddConn(conn, 7, srk1, nil)
Expect(err).ToNot(HaveOccurred())
_, err = getMultiplexer().AddConn(conn, 7, srk2, nil)
Expect(err).To(MatchError("cannot use different stateless reset keys on the same packet conn"))
})
It("errors when adding an existing conn with different tracers", func() {
conn := NewMockPacketConn(mockCtrl)
conn.EXPECT().ReadFrom(gomock.Any()).Do(func([]byte) { <-(make(chan struct{})) }).MaxTimes(1)
conn.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}).Times(2)
_, err := getMultiplexer().AddConn(conn, 7, nil, mocklogging.NewMockTracer(mockCtrl))
Expect(err).ToNot(HaveOccurred())
_, err = getMultiplexer().AddConn(conn, 7, nil, mocklogging.NewMockTracer(mockCtrl))
Expect(err).To(MatchError("cannot use different tracers on the same packet conn"))
conn.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 4321}).Times(2)
getMultiplexer().AddConn(conn)
Expect(func() { getMultiplexer().AddConn(conn) }).To(Panic())
})
})

View File

@@ -5,28 +5,22 @@ import (
"crypto/rand"
"crypto/sha256"
"errors"
"fmt"
"hash"
"io"
"log"
"net"
"os"
"strconv"
"strings"
"sync"
"time"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/internal/wire"
"github.com/quic-go/quic-go/logging"
)
// rawConn is a connection that allow reading of a receivedPacket.
// rawConn is a connection that allow reading of a receivedPackeh.
type rawConn interface {
ReadPacket() (*receivedPacket, error)
WritePacket(b []byte, addr net.Addr, oob []byte) (int, error)
LocalAddr() net.Addr
SetReadDeadline(time.Time) error
io.Closer
}
@@ -36,113 +30,49 @@ type closePacket struct {
info *packetInfo
}
// The packetHandlerMap stores packetHandlers, identified by connection ID.
// It is used:
// * by the server to store connections
// * when multiplexing outgoing connections to store clients
type unknownPacketHandler interface {
handlePacket(*receivedPacket)
setCloseError(error)
}
var errListenerAlreadySet = errors.New("listener already set")
type packetHandlerMap struct {
mutex sync.Mutex
conn rawConn
connIDLen int
closeQueue chan closePacket
mutex sync.Mutex
handlers map[protocol.ConnectionID]packetHandler
resetTokens map[protocol.StatelessResetToken] /* stateless reset token */ packetHandler
server unknownPacketHandler
listening chan struct{} // is closed when listen returns
closed bool
closeChan chan struct{}
enqueueClosePacket func(closePacket)
deleteRetiredConnsAfter time.Duration
statelessResetEnabled bool
statelessResetMutex sync.Mutex
statelessResetHasher hash.Hash
statelessResetMutex sync.Mutex
statelessResetHasher hash.Hash
tracer logging.Tracer
logger utils.Logger
}
var _ packetHandlerManager = &packetHandlerMap{}
func setReceiveBuffer(c net.PacketConn, logger utils.Logger) error {
conn, ok := c.(interface{ SetReadBuffer(int) error })
if !ok {
return errors.New("connection doesn't allow setting of receive buffer size. Not a *net.UDPConn?")
}
size, err := inspectReadBuffer(c)
if err != nil {
return fmt.Errorf("failed to determine receive buffer size: %w", err)
}
if size >= protocol.DesiredReceiveBufferSize {
logger.Debugf("Conn has receive buffer of %d kiB (wanted: at least %d kiB)", size/1024, protocol.DesiredReceiveBufferSize/1024)
return nil
}
if err := conn.SetReadBuffer(protocol.DesiredReceiveBufferSize); err != nil {
return fmt.Errorf("failed to increase receive buffer size: %w", err)
}
newSize, err := inspectReadBuffer(c)
if err != nil {
return fmt.Errorf("failed to determine receive buffer size: %w", err)
}
if newSize == size {
return fmt.Errorf("failed to increase receive buffer size (wanted: %d kiB, got %d kiB)", protocol.DesiredReceiveBufferSize/1024, newSize/1024)
}
if newSize < protocol.DesiredReceiveBufferSize {
return fmt.Errorf("failed to sufficiently increase receive buffer size (was: %d kiB, wanted: %d kiB, got: %d kiB)", size/1024, protocol.DesiredReceiveBufferSize/1024, newSize/1024)
}
logger.Debugf("Increased receive buffer size to %d kiB", newSize/1024)
return nil
}
// only print warnings about the UDP receive buffer size once
var receiveBufferWarningOnce sync.Once
func newPacketHandlerMap(
c net.PacketConn,
connIDLen int,
statelessResetKey *StatelessResetKey,
tracer logging.Tracer,
logger utils.Logger,
) (packetHandlerManager, error) {
if err := setReceiveBuffer(c, logger); err != nil {
if !strings.Contains(err.Error(), "use of closed network connection") {
receiveBufferWarningOnce.Do(func() {
if disable, _ := strconv.ParseBool(os.Getenv("QUIC_GO_DISABLE_RECEIVE_BUFFER_WARNING")); disable {
return
}
log.Printf("%s. See https://github.com/quic-go/quic-go/wiki/UDP-Receive-Buffer-Size for details.", err)
})
}
}
conn, err := wrapConn(c)
if err != nil {
return nil, err
}
m := &packetHandlerMap{
conn: conn,
connIDLen: connIDLen,
listening: make(chan struct{}),
func newPacketHandlerMap(key *StatelessResetKey, enqueueClosePacket func(closePacket), logger utils.Logger) *packetHandlerMap {
h := &packetHandlerMap{
closeChan: make(chan struct{}),
handlers: make(map[protocol.ConnectionID]packetHandler),
resetTokens: make(map[protocol.StatelessResetToken]packetHandler),
deleteRetiredConnsAfter: protocol.RetiredConnectionIDDeleteTimeout,
closeQueue: make(chan closePacket, 4),
statelessResetEnabled: statelessResetKey != nil,
tracer: tracer,
enqueueClosePacket: enqueueClosePacket,
logger: logger,
}
if m.statelessResetEnabled {
m.statelessResetHasher = hmac.New(sha256.New, statelessResetKey[:])
if key != nil {
h.statelessResetHasher = hmac.New(sha256.New, key[:])
}
go m.listen()
go m.runCloseQueue()
if logger.Debug() {
go m.logUsage()
if h.logger.Debug() {
go h.logUsage()
}
return m, nil
return h
}
func (h *packetHandlerMap) logUsage() {
@@ -150,7 +80,7 @@ func (h *packetHandlerMap) logUsage() {
var printedZero bool
for {
select {
case <-h.listening:
case <-h.closeChan:
return
case <-ticker.C:
}
@@ -192,7 +122,7 @@ func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler)
return true
}
func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.ConnectionID, fn func() packetHandler) bool {
func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
h.mutex.Lock()
defer h.mutex.Unlock()
@@ -200,7 +130,10 @@ func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.Co
h.logger.Debugf("Not adding connection ID %s for a new connection, as it already exists.", clientDestConnID)
return false
}
conn := fn()
conn, ok := fn()
if !ok {
return false
}
h.handlers[clientDestConnID] = conn
h.handlers[newConnID] = conn
h.logger.Debugf("Adding connection IDs %s and %s for a new connection.", clientDestConnID, newConnID)
@@ -233,12 +166,7 @@ func (h *packetHandlerMap) ReplaceWithClosed(ids []protocol.ConnectionID, pers p
if connClosePacket != nil {
handler = newClosedLocalConn(
func(addr net.Addr, info *packetInfo) {
select {
case h.closeQueue <- closePacket{payload: connClosePacket, addr: addr, info: info}:
default:
// Oops, we're backlogged.
// Just drop the packet, sending CONNECTION_CLOSE copies is best effort anyway.
}
h.enqueueClosePacket(closePacket{payload: connClosePacket, addr: addr, info: info})
},
pers,
h.logger,
@@ -265,17 +193,6 @@ func (h *packetHandlerMap) ReplaceWithClosed(ids []protocol.ConnectionID, pers p
})
}
func (h *packetHandlerMap) runCloseQueue() {
for {
select {
case <-h.listening:
return
case p := <-h.closeQueue:
h.conn.WritePacket(p.payload, p.addr, p.info.OOB())
}
}
}
func (h *packetHandlerMap) AddResetToken(token protocol.StatelessResetToken, handler packetHandler) {
h.mutex.Lock()
h.resetTokens[token] = handler
@@ -288,19 +205,16 @@ func (h *packetHandlerMap) RemoveResetToken(token protocol.StatelessResetToken)
h.mutex.Unlock()
}
func (h *packetHandlerMap) SetServer(s unknownPacketHandler) {
func (h *packetHandlerMap) GetByResetToken(token protocol.StatelessResetToken) (packetHandler, bool) {
h.mutex.Lock()
h.server = s
h.mutex.Unlock()
defer h.mutex.Unlock()
handler, ok := h.resetTokens[token]
return handler, ok
}
func (h *packetHandlerMap) CloseServer() {
h.mutex.Lock()
if h.server == nil {
h.mutex.Unlock()
return
}
h.server = nil
var wg sync.WaitGroup
for _, handler := range h.handlers {
if handler.getPerspective() == protocol.PerspectiveServer {
@@ -316,23 +230,16 @@ func (h *packetHandlerMap) CloseServer() {
wg.Wait()
}
// Destroy closes the underlying connection and waits until listen() has returned.
// It does not close active connections.
func (h *packetHandlerMap) Destroy() error {
if err := h.conn.Close(); err != nil {
return err
}
<-h.listening // wait until listening returns
return nil
}
func (h *packetHandlerMap) close(e error) error {
func (h *packetHandlerMap) Close(e error) {
h.mutex.Lock()
if h.closed {
h.mutex.Unlock()
return nil
return
}
close(h.closeChan)
var wg sync.WaitGroup
for _, handler := range h.handlers {
wg.Add(1)
@@ -341,89 +248,14 @@ func (h *packetHandlerMap) close(e error) error {
wg.Done()
}(handler)
}
if h.server != nil {
h.server.setCloseError(e)
}
h.closed = true
h.mutex.Unlock()
wg.Wait()
return getMultiplexer().RemoveConn(h.conn)
}
func (h *packetHandlerMap) listen() {
defer close(h.listening)
for {
p, err := h.conn.ReadPacket()
//nolint:staticcheck // SA1019 ignore this!
// TODO: This code is used to ignore wsa errors on Windows.
// Since net.Error.Temporary is deprecated as of Go 1.18, we should find a better solution.
// See https://github.com/quic-go/quic-go/issues/1737 for details.
if nerr, ok := err.(net.Error); ok && nerr.Temporary() {
h.logger.Debugf("Temporary error reading from conn: %w", err)
continue
}
if err != nil {
h.close(err)
return
}
h.handlePacket(p)
}
}
func (h *packetHandlerMap) handlePacket(p *receivedPacket) {
connID, err := wire.ParseConnectionID(p.data, h.connIDLen)
if err != nil {
h.logger.Debugf("error parsing connection ID on packet from %s: %s", p.remoteAddr, err)
if h.tracer != nil {
h.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError)
}
p.buffer.MaybeRelease()
return
}
h.mutex.Lock()
defer h.mutex.Unlock()
if isStatelessReset := h.maybeHandleStatelessReset(p.data); isStatelessReset {
return
}
if handler, ok := h.handlers[connID]; ok {
handler.handlePacket(p)
return
}
if !wire.IsLongHeaderPacket(p.data[0]) {
go h.maybeSendStatelessReset(p, connID)
return
}
if h.server == nil { // no server set
h.logger.Debugf("received a packet with an unexpected connection ID %s", connID)
return
}
h.server.handlePacket(p)
}
func (h *packetHandlerMap) maybeHandleStatelessReset(data []byte) bool {
// stateless resets are always short header packets
if wire.IsLongHeaderPacket(data[0]) {
return false
}
if len(data) < 17 /* type byte + 16 bytes for the reset token */ {
return false
}
token := *(*protocol.StatelessResetToken)(data[len(data)-16:])
if sess, ok := h.resetTokens[token]; ok {
h.logger.Debugf("Received a stateless reset with token %#x. Closing connection.", token)
go sess.destroy(&StatelessResetError{Token: token})
return true
}
return false
}
func (h *packetHandlerMap) GetStatelessResetToken(connID protocol.ConnectionID) protocol.StatelessResetToken {
var token protocol.StatelessResetToken
if !h.statelessResetEnabled {
if h.statelessResetHasher == nil {
// Return a random stateless reset token.
// This token will be sent in the server's transport parameters.
// By using a random token, an off-path attacker won't be able to disrupt the connection.
@@ -437,24 +269,3 @@ func (h *packetHandlerMap) GetStatelessResetToken(connID protocol.ConnectionID)
h.statelessResetMutex.Unlock()
return token
}
func (h *packetHandlerMap) maybeSendStatelessReset(p *receivedPacket, connID protocol.ConnectionID) {
defer p.buffer.Release()
if !h.statelessResetEnabled {
return
}
// Don't send a stateless reset in response to very small packets.
// This includes packets that could be stateless resets.
if len(p.data) <= protocol.MinStatelessResetSize {
return
}
token := h.GetStatelessResetToken(connID)
h.logger.Debugf("Sending stateless reset to %s (connection ID: %s). Token: %#x", p.remoteAddr, connID, token)
data := make([]byte, protocol.MinStatelessResetSize-16, protocol.MinStatelessResetSize)
rand.Read(data)
data[0] = (data[0] & 0x7f) | 0x40
data = append(data, token[:]...)
if _, err := h.conn.WritePacket(data, p.remoteAddr, p.info.OOB()); err != nil {
h.logger.Debugf("Error sending Stateless Reset: %s", err)
}
}

View File

@@ -6,405 +6,188 @@ import (
"net"
"time"
mocklogging "github.com/quic-go/quic-go/internal/mocks/logging"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/internal/wire"
"github.com/quic-go/quic-go/logging"
"github.com/golang/mock/gomock"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("Packet Handler Map", func() {
type packetToRead struct {
addr net.Addr
data []byte
err error
}
var (
handler *packetHandlerMap
conn *MockPacketConn
tracer *mocklogging.MockTracer
packetChan chan packetToRead
connIDLen int
statelessResetKey *StatelessResetKey
)
getPacketWithPacketType := func(connID protocol.ConnectionID, t protocol.PacketType, length protocol.ByteCount) []byte {
b, err := (&wire.ExtendedHeader{
Header: wire.Header{
Type: t,
DestConnectionID: connID,
Length: length,
Version: protocol.Version1,
},
PacketNumberLen: protocol.PacketNumberLen2,
}).Append(nil, protocol.Version1)
Expect(err).ToNot(HaveOccurred())
return b
}
getPacket := func(connID protocol.ConnectionID) []byte {
return getPacketWithPacketType(connID, protocol.PacketTypeHandshake, 2)
}
BeforeEach(func() {
statelessResetKey = nil
connIDLen = 0
tracer = mocklogging.NewMockTracer(mockCtrl)
packetChan = make(chan packetToRead, 10)
It("adds and gets", func() {
m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
handler := NewMockPacketHandler(mockCtrl)
Expect(m.Add(connID, handler)).To(BeTrue())
h, ok := m.Get(connID)
Expect(ok).To(BeTrue())
Expect(h).To(Equal(handler))
})
JustBeforeEach(func() {
conn = NewMockPacketConn(mockCtrl)
conn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes()
conn.EXPECT().ReadFrom(gomock.Any()).DoAndReturn(func(b []byte) (int, net.Addr, error) {
p, ok := <-packetChan
if !ok {
return 0, nil, errors.New("closed")
It("refused to add duplicates", func() {
m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
handler := NewMockPacketHandler(mockCtrl)
Expect(m.Add(connID, handler)).To(BeTrue())
Expect(m.Add(connID, handler)).To(BeFalse())
})
It("removes", func() {
m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
handler := NewMockPacketHandler(mockCtrl)
Expect(m.Add(connID, handler)).To(BeTrue())
m.Remove(connID)
_, ok := m.Get(connID)
Expect(ok).To(BeFalse())
Expect(m.Add(connID, handler)).To(BeTrue())
})
It("retires", func() {
m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
dur := scaleDuration(50 * time.Millisecond)
m.deleteRetiredConnsAfter = dur
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
handler := NewMockPacketHandler(mockCtrl)
Expect(m.Add(connID, handler)).To(BeTrue())
m.Retire(connID)
_, ok := m.Get(connID)
Expect(ok).To(BeTrue())
time.Sleep(dur)
Eventually(func() bool { _, ok := m.Get(connID); return ok }).Should(BeFalse())
})
It("adds newly to-be-constructed handlers", func() {
m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
var called bool
connID1 := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
connID2 := protocol.ParseConnectionID([]byte{4, 3, 2, 1})
Expect(m.AddWithConnID(connID1, connID2, func() (packetHandler, bool) {
called = true
return NewMockPacketHandler(mockCtrl), true
})).To(BeTrue())
Expect(called).To(BeTrue())
Expect(m.AddWithConnID(connID1, protocol.ParseConnectionID([]byte{1, 2, 3}), func() (packetHandler, bool) {
Fail("didn't expect the constructor to be executed")
return nil, false
})).To(BeFalse())
})
It("adds, gets and removes reset tokens", func() {
m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf}
handler := NewMockPacketHandler(mockCtrl)
m.AddResetToken(token, handler)
h, ok := m.GetByResetToken(token)
Expect(ok).To(BeTrue())
Expect(h).To(Equal(h))
m.RemoveResetToken(token)
_, ok = m.GetByResetToken(token)
Expect(ok).To(BeFalse())
})
It("generates stateless reset token, if no key is set", func() {
m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
b := make([]byte, 8)
rand.Read(b)
connID := protocol.ParseConnectionID(b)
token := m.GetStatelessResetToken(connID)
for i := 0; i < 1000; i++ {
to := m.GetStatelessResetToken(connID)
Expect(to).ToNot(Equal(token))
token = to
}
})
It("generates stateless reset token, if a key is set", func() {
var key StatelessResetKey
rand.Read(key[:])
m := newPacketHandlerMap(&key, nil, utils.DefaultLogger)
b := make([]byte, 8)
rand.Read(b)
connID := protocol.ParseConnectionID(b)
token := m.GetStatelessResetToken(connID)
Expect(token).ToNot(BeZero())
Expect(m.GetStatelessResetToken(connID)).To(Equal(token))
// generate a new connection ID
rand.Read(b)
connID2 := protocol.ParseConnectionID(b)
Expect(m.GetStatelessResetToken(connID2)).ToNot(Equal(token))
})
It("replaces locally closed connections", func() {
var closePackets []closePacket
m := newPacketHandlerMap(nil, func(p closePacket) { closePackets = append(closePackets, p) }, utils.DefaultLogger)
dur := scaleDuration(50 * time.Millisecond)
m.deleteRetiredConnsAfter = dur
handler := NewMockPacketHandler(mockCtrl)
connID := protocol.ParseConnectionID([]byte{4, 3, 2, 1})
Expect(m.Add(connID, handler)).To(BeTrue())
m.ReplaceWithClosed([]protocol.ConnectionID{connID}, protocol.PerspectiveClient, []byte("foobar"))
h, ok := m.Get(connID)
Expect(ok).To(BeTrue())
Expect(h).ToNot(Equal(handler))
addr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}
h.handlePacket(&receivedPacket{remoteAddr: addr})
Expect(closePackets).To(HaveLen(1))
Expect(closePackets[0].addr).To(Equal(addr))
Expect(closePackets[0].payload).To(Equal([]byte("foobar")))
time.Sleep(dur)
Eventually(func() bool { _, ok := m.Get(connID); return ok }).Should(BeFalse())
})
It("replaces remote closed connections", func() {
var closePackets []closePacket
m := newPacketHandlerMap(nil, func(p closePacket) { closePackets = append(closePackets, p) }, utils.DefaultLogger)
dur := scaleDuration(50 * time.Millisecond)
m.deleteRetiredConnsAfter = dur
handler := NewMockPacketHandler(mockCtrl)
connID := protocol.ParseConnectionID([]byte{4, 3, 2, 1})
Expect(m.Add(connID, handler)).To(BeTrue())
m.ReplaceWithClosed([]protocol.ConnectionID{connID}, protocol.PerspectiveClient, nil)
h, ok := m.Get(connID)
Expect(ok).To(BeTrue())
Expect(h).ToNot(Equal(handler))
addr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}
h.handlePacket(&receivedPacket{remoteAddr: addr})
Expect(closePackets).To(BeEmpty())
time.Sleep(dur)
Eventually(func() bool { _, ok := m.Get(connID); return ok }).Should(BeFalse())
})
It("closes the server", func() {
m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
for i := 0; i < 10; i++ {
conn := NewMockPacketHandler(mockCtrl)
if i%2 == 0 {
conn.EXPECT().getPerspective().Return(protocol.PerspectiveClient)
} else {
conn.EXPECT().getPerspective().Return(protocol.PerspectiveServer)
conn.EXPECT().shutdown()
}
return copy(b, p.data), p.addr, p.err
}).AnyTimes()
phm, err := newPacketHandlerMap(conn, connIDLen, statelessResetKey, tracer, utils.DefaultLogger)
Expect(err).ToNot(HaveOccurred())
handler = phm.(*packetHandlerMap)
b := make([]byte, 12)
rand.Read(b)
m.Add(protocol.ParseConnectionID(b), conn)
}
m.CloseServer()
})
It("closes", func() {
getMultiplexer() // make the sync.Once execute
// replace the clientMuxer. getClientMultiplexer will now return the MockMultiplexer
mockMultiplexer := NewMockMultiplexer(mockCtrl)
origMultiplexer := connMuxer
connMuxer = mockMultiplexer
defer func() {
connMuxer = origMultiplexer
}()
testErr := errors.New("test error ")
conn1 := NewMockPacketHandler(mockCtrl)
conn1.EXPECT().destroy(testErr)
conn2 := NewMockPacketHandler(mockCtrl)
conn2.EXPECT().destroy(testErr)
handler.Add(protocol.ParseConnectionID([]byte{1, 1, 1, 1}), conn1)
handler.Add(protocol.ParseConnectionID([]byte{2, 2, 2, 2}), conn2)
mockMultiplexer.EXPECT().RemoveConn(gomock.Any())
handler.close(testErr)
close(packetChan)
Eventually(handler.listening).Should(BeClosed())
})
Context("other operations", func() {
AfterEach(func() {
// delete connections and the server before closing
// They might be mock implementations, and we'd have to register the expected calls before otherwise.
handler.mutex.Lock()
for connID := range handler.handlers {
delete(handler.handlers, connID)
}
handler.server = nil
handler.mutex.Unlock()
conn.EXPECT().Close().MaxTimes(1)
close(packetChan)
handler.Destroy()
Eventually(handler.listening).Should(BeClosed())
})
Context("handling packets", func() {
BeforeEach(func() {
connIDLen = 5
})
It("handles packets for different packet handlers on the same packet conn", func() {
connID1 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
connID2 := protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1})
packetHandler1 := NewMockPacketHandler(mockCtrl)
packetHandler2 := NewMockPacketHandler(mockCtrl)
handledPacket1 := make(chan struct{})
handledPacket2 := make(chan struct{})
packetHandler1.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
connID, err := wire.ParseConnectionID(p.data, 0)
Expect(err).ToNot(HaveOccurred())
Expect(connID).To(Equal(connID1))
close(handledPacket1)
})
packetHandler2.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
connID, err := wire.ParseConnectionID(p.data, 0)
Expect(err).ToNot(HaveOccurred())
Expect(connID).To(Equal(connID2))
close(handledPacket2)
})
handler.Add(connID1, packetHandler1)
handler.Add(connID2, packetHandler2)
packetChan <- packetToRead{data: getPacket(connID1)}
packetChan <- packetToRead{data: getPacket(connID2)}
Eventually(handledPacket1).Should(BeClosed())
Eventually(handledPacket2).Should(BeClosed())
})
It("drops unparseable packets", func() {
addr := &net.UDPAddr{IP: net.IPv4(9, 8, 7, 6), Port: 1234}
tracer.EXPECT().DroppedPacket(addr, logging.PacketTypeNotDetermined, protocol.ByteCount(4), logging.PacketDropHeaderParseError)
handler.handlePacket(&receivedPacket{
buffer: getPacketBuffer(),
remoteAddr: addr,
data: []byte{0, 1, 2, 3},
})
})
It("deletes removed connections immediately", func() {
handler.deleteRetiredConnsAfter = time.Hour
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
handler.Add(connID, NewMockPacketHandler(mockCtrl))
handler.Remove(connID)
handler.handlePacket(&receivedPacket{data: getPacket(connID)})
// don't EXPECT any calls to handlePacket of the MockPacketHandler
})
It("deletes retired connection entries after a wait time", func() {
handler.deleteRetiredConnsAfter = scaleDuration(10 * time.Millisecond)
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
conn := NewMockPacketHandler(mockCtrl)
handler.Add(connID, conn)
handler.Retire(connID)
time.Sleep(scaleDuration(30 * time.Millisecond))
handler.handlePacket(&receivedPacket{data: getPacket(connID)})
// don't EXPECT any calls to handlePacket of the MockPacketHandler
})
It("passes packets arriving late for closed connections to that connection", func() {
handler.deleteRetiredConnsAfter = time.Hour
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
packetHandler := NewMockPacketHandler(mockCtrl)
handled := make(chan struct{})
packetHandler.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
close(handled)
})
handler.Add(connID, packetHandler)
handler.Retire(connID)
handler.handlePacket(&receivedPacket{data: getPacket(connID)})
Eventually(handled).Should(BeClosed())
})
It("drops packets for unknown receivers", func() {
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
handler.handlePacket(&receivedPacket{data: getPacket(connID)})
})
It("closes the packet handlers when reading from the conn fails", func() {
done := make(chan struct{})
packetHandler := NewMockPacketHandler(mockCtrl)
packetHandler.EXPECT().destroy(gomock.Any()).Do(func(e error) {
Expect(e).To(HaveOccurred())
close(done)
})
handler.Add(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), packetHandler)
packetChan <- packetToRead{err: errors.New("read failed")}
Eventually(done).Should(BeClosed())
})
It("continues listening for temporary errors", func() {
packetHandler := NewMockPacketHandler(mockCtrl)
handler.Add(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), packetHandler)
err := deadlineError{}
Expect(err.Temporary()).To(BeTrue())
packetChan <- packetToRead{err: err}
// don't EXPECT any calls to packetHandler.destroy
time.Sleep(50 * time.Millisecond)
})
It("says if a connection ID is already taken", func() {
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
Expect(handler.Add(connID, NewMockPacketHandler(mockCtrl))).To(BeTrue())
Expect(handler.Add(connID, NewMockPacketHandler(mockCtrl))).To(BeFalse())
})
It("says if a connection ID is already taken, for AddWithConnID", func() {
clientDestConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
newConnID1 := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
newConnID2 := protocol.ParseConnectionID([]byte{4, 3, 2, 1})
Expect(handler.AddWithConnID(clientDestConnID, newConnID1, func() packetHandler { return NewMockPacketHandler(mockCtrl) })).To(BeTrue())
Expect(handler.AddWithConnID(clientDestConnID, newConnID2, func() packetHandler { return NewMockPacketHandler(mockCtrl) })).To(BeFalse())
})
})
Context("running a server", func() {
It("adds a server", func() {
connID := protocol.ParseConnectionID([]byte{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88})
p := getPacket(connID)
server := NewMockUnknownPacketHandler(mockCtrl)
server.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
cid, err := wire.ParseConnectionID(p.data, 0)
Expect(err).ToNot(HaveOccurred())
Expect(cid).To(Equal(connID))
})
handler.SetServer(server)
handler.handlePacket(&receivedPacket{data: p})
})
It("closes all server connections", func() {
handler.SetServer(NewMockUnknownPacketHandler(mockCtrl))
clientConn := NewMockPacketHandler(mockCtrl)
clientConn.EXPECT().getPerspective().Return(protocol.PerspectiveClient)
serverConn := NewMockPacketHandler(mockCtrl)
serverConn.EXPECT().getPerspective().Return(protocol.PerspectiveServer)
serverConn.EXPECT().shutdown()
handler.Add(protocol.ParseConnectionID([]byte{1, 1, 1, 1}), clientConn)
handler.Add(protocol.ParseConnectionID([]byte{2, 2, 2, 2}), serverConn)
handler.CloseServer()
})
It("stops handling packets with unknown connection IDs after the server is closed", func() {
connID := protocol.ParseConnectionID([]byte{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88})
p := getPacket(connID)
server := NewMockUnknownPacketHandler(mockCtrl)
// don't EXPECT any calls to server.handlePacket
handler.SetServer(server)
handler.CloseServer()
handler.handlePacket(&receivedPacket{data: p})
})
})
Context("stateless resets", func() {
BeforeEach(func() {
connIDLen = 5
})
Context("handling", func() {
It("handles stateless resets", func() {
packetHandler := NewMockPacketHandler(mockCtrl)
token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
handler.AddResetToken(token, packetHandler)
destroyed := make(chan struct{})
packet := append([]byte{0x40} /* short header packet */, make([]byte, 50)...)
packet = append(packet, token[:]...)
packetHandler.EXPECT().destroy(gomock.Any()).Do(func(err error) {
defer GinkgoRecover()
defer close(destroyed)
Expect(err).To(HaveOccurred())
var resetErr *StatelessResetError
Expect(errors.As(err, &resetErr)).To(BeTrue())
Expect(err.Error()).To(ContainSubstring("received a stateless reset"))
Expect(resetErr.Token).To(Equal(token))
})
packetChan <- packetToRead{data: packet}
Eventually(destroyed).Should(BeClosed())
})
It("handles stateless resets for 0-length connection IDs", func() {
handler.connIDLen = 0
packetHandler := NewMockPacketHandler(mockCtrl)
token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
handler.AddResetToken(token, packetHandler)
destroyed := make(chan struct{})
packet := append([]byte{0x40} /* short header packet */, make([]byte, 50)...)
packet = append(packet, token[:]...)
packetHandler.EXPECT().destroy(gomock.Any()).Do(func(err error) {
defer GinkgoRecover()
Expect(err).To(HaveOccurred())
var resetErr *StatelessResetError
Expect(errors.As(err, &resetErr)).To(BeTrue())
Expect(err.Error()).To(ContainSubstring("received a stateless reset"))
Expect(resetErr.Token).To(Equal(token))
close(destroyed)
})
packetChan <- packetToRead{data: packet}
Eventually(destroyed).Should(BeClosed())
})
It("removes reset tokens", func() {
connID := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0x42})
packetHandler := NewMockPacketHandler(mockCtrl)
handler.Add(connID, packetHandler)
token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
handler.AddResetToken(token, NewMockPacketHandler(mockCtrl))
handler.RemoveResetToken(token)
// don't EXPECT any call to packetHandler.destroy()
packetHandler.EXPECT().handlePacket(gomock.Any())
p := append([]byte{0x40} /* short header packet */, connID.Bytes()...)
p = append(p, make([]byte, 50)...)
p = append(p, token[:]...)
handler.handlePacket(&receivedPacket{data: p})
})
It("ignores packets too small to contain a stateless reset", func() {
handler.connIDLen = 0
packetHandler := NewMockPacketHandler(mockCtrl)
token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
handler.AddResetToken(token, packetHandler)
done := make(chan struct{})
// don't EXPECT any calls here, but register the closing of the done channel
packetHandler.EXPECT().destroy(gomock.Any()).Do(func(error) {
close(done)
}).AnyTimes()
packetChan <- packetToRead{data: append([]byte{0x40} /* short header packet */, token[:15]...)}
Consistently(done).ShouldNot(BeClosed())
})
})
Context("generating", func() {
BeforeEach(func() {
var key StatelessResetKey
rand.Read(key[:])
statelessResetKey = &key
})
It("generates stateless reset tokens", func() {
connID1 := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef})
connID2 := protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad})
Expect(handler.GetStatelessResetToken(connID1)).ToNot(Equal(handler.GetStatelessResetToken(connID2)))
})
It("sends stateless resets", func() {
addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337}
p := append([]byte{40}, make([]byte, 100)...)
done := make(chan struct{})
conn.EXPECT().WriteTo(gomock.Any(), addr).Do(func(b []byte, _ net.Addr) {
defer close(done)
Expect(wire.IsLongHeaderPacket(b[0])).To(BeFalse()) // short header packet
Expect(b).To(HaveLen(protocol.MinStatelessResetSize))
})
handler.handlePacket(&receivedPacket{
buffer: getPacketBuffer(),
remoteAddr: addr,
data: p,
})
Eventually(done).Should(BeClosed())
})
It("doesn't send stateless resets for small packets", func() {
addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337}
p := append([]byte{40}, make([]byte, protocol.MinStatelessResetSize-2)...)
handler.handlePacket(&receivedPacket{
buffer: getPacketBuffer(),
remoteAddr: addr,
data: p,
})
// make sure there are no Write calls on the packet conn
time.Sleep(50 * time.Millisecond)
})
})
Context("if no key is configured", func() {
It("doesn't send stateless resets", func() {
addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337}
p := append([]byte{40}, make([]byte, 100)...)
handler.handlePacket(&receivedPacket{
buffer: getPacketBuffer(),
remoteAddr: addr,
data: p,
})
// make sure there are no Write calls on the packet conn
time.Sleep(50 * time.Millisecond)
})
})
})
m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
testErr := errors.New("shutdown")
for i := 0; i < 10; i++ {
conn := NewMockPacketHandler(mockCtrl)
conn.EXPECT().destroy(testErr)
b := make([]byte, 12)
rand.Read(b)
m.Add(protocol.ParseConnectionID(b), conn)
}
m.Close(testErr)
// check that Close can be called multiple times
m.Close(errors.New("close"))
})
})

View File

@@ -2,7 +2,6 @@ package qlog
import (
"bytes"
"context"
"fmt"
"io"
"log"
@@ -49,26 +48,6 @@ func init() {
const eventChanSize = 50
type tracer struct {
logging.NullTracer
getLogWriter func(p logging.Perspective, connectionID []byte) io.WriteCloser
}
var _ logging.Tracer = &tracer{}
// NewTracer creates a new qlog tracer.
func NewTracer(getLogWriter func(p logging.Perspective, connectionID []byte) io.WriteCloser) logging.Tracer {
return &tracer{getLogWriter: getLogWriter}
}
func (t *tracer) TracerForConnection(_ context.Context, p logging.Perspective, odcid protocol.ConnectionID) logging.ConnectionTracer {
if w := t.getLogWriter(p, odcid.Bytes()); w != nil {
return NewConnectionTracer(w, p, odcid)
}
return nil
}
type connectionTracer struct {
mutex sync.Mutex

View File

@@ -2,7 +2,6 @@ package qlog
import (
"bytes"
"context"
"encoding/json"
"errors"
"io"
@@ -51,17 +50,6 @@ type entry struct {
}
var _ = Describe("Tracing", func() {
Context("tracer", func() {
It("returns nil when there's no io.WriteCloser", func() {
t := NewTracer(func(logging.Perspective, []byte) io.WriteCloser { return nil })
Expect(t.TracerForConnection(
context.Background(),
logging.PerspectiveClient,
protocol.ParseConnectionID([]byte{1, 2, 3, 4}),
)).To(BeNil())
})
})
It("stops writing when encountering an error", func() {
buf := &bytes.Buffer{}
t := NewConnectionTracer(
@@ -88,9 +76,8 @@ var _ = Describe("Tracing", func() {
BeforeEach(func() {
buf = &bytes.Buffer{}
t := NewTracer(func(logging.Perspective, []byte) io.WriteCloser { return nopWriteCloser(buf) })
tracer = t.TracerForConnection(
context.Background(),
tracer = NewConnectionTracer(
nopWriteCloser(buf),
logging.PerspectiveServer,
protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}),
)

View File

@@ -1,8 +1,11 @@
package quic
import (
"bytes"
"io"
"log"
"runtime/pprof"
"strings"
"sync"
"testing"
@@ -29,6 +32,20 @@ var _ = BeforeSuite(func() {
log.SetOutput(io.Discard)
})
func areServersRunning() bool {
var b bytes.Buffer
pprof.Lookup("goroutine").WriteTo(&b, 1)
return strings.Contains(b.String(), "quic-go.(*baseServer).run")
}
func areTransportsRunning() bool {
var b bytes.Buffer
pprof.Lookup("goroutine").WriteTo(&b, 1)
return strings.Contains(b.String(), "quic-go.(*Transport).listen")
}
var _ = AfterEach(func() {
mockCtrl.Finish()
Eventually(areServersRunning).Should(BeFalse())
Eventually(areTransportsRunning()).Should(BeFalse())
})

View File

@@ -22,7 +22,7 @@ type sconn struct {
var _ sendConn = &sconn{}
func newSendConn(c rawConn, remote net.Addr, info *packetInfo) sendConn {
func newSendConn(c rawConn, remote net.Addr, info *packetInfo) *sconn {
return &sconn{
rawConn: c,
remoteAddr: remote,
@@ -51,24 +51,3 @@ func (c *sconn) LocalAddr() net.Addr {
}
return addr
}
type spconn struct {
net.PacketConn
remoteAddr net.Addr
}
var _ sendConn = &spconn{}
func newSendPconn(c net.PacketConn, remote net.Addr) sendConn {
return &spconn{PacketConn: c, remoteAddr: remote}
}
func (c *spconn) Write(p []byte) error {
_, err := c.WriteTo(p, c.remoteAddr)
return err
}
func (c *spconn) RemoteAddr() net.Addr {
return c.remoteAddr
}

View File

@@ -17,7 +17,9 @@ var _ = Describe("Connection (for sending packets)", func() {
BeforeEach(func() {
addr = &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337}
packetConn = NewMockPacketConn(mockCtrl)
c = newSendPconn(packetConn, addr)
rawConn, err := wrapConn(packetConn)
Expect(err).ToNot(HaveOccurred())
c = newSendConn(rawConn, addr, nil)
})
It("writes", func() {

224
server.go
View File

@@ -20,7 +20,7 @@ import (
)
// ErrServerClosed is returned by the Listener or EarlyListener's Accept method after a call to Close.
var ErrServerClosed = errors.New("quic: Server closed")
var ErrServerClosed = errors.New("quic: server closed")
// packetHandler handles packets
type packetHandler interface {
@@ -30,18 +30,13 @@ type packetHandler interface {
getPerspective() protocol.Perspective
}
type unknownPacketHandler interface {
handlePacket(*receivedPacket)
setCloseError(error)
}
type packetHandlerManager interface {
Get(protocol.ConnectionID) (packetHandler, bool)
AddWithConnID(protocol.ConnectionID, protocol.ConnectionID, func() packetHandler) bool
Destroy() error
connRunner
SetServer(unknownPacketHandler)
GetByResetToken(protocol.StatelessResetToken) (packetHandler, bool)
AddWithConnID(protocol.ConnectionID, protocol.ConnectionID, func() (packetHandler, bool)) bool
Close(error)
CloseServer()
connRunner
}
type quicConn interface {
@@ -70,13 +65,12 @@ type baseServer struct {
config *Config
conn rawConn
// If the server is started with ListenAddr, we create a packet conn.
// If it is started with Listen, we take a packet conn as a parameter.
createdPacketConn bool
tokenGenerator *handshake.TokenGenerator
connHandler packetHandlerManager
connIDGenerator ConnectionIDGenerator
connHandler packetHandlerManager
onClose func()
receivedPackets chan *receivedPacket
@@ -92,6 +86,7 @@ type baseServer struct {
protocol.ConnectionID, /* client dest connection ID */
protocol.ConnectionID, /* destination connection ID */
protocol.ConnectionID, /* source connection ID */
ConnectionIDGenerator,
protocol.StatelessResetToken,
*Config,
*tls.Config,
@@ -111,11 +106,11 @@ type baseServer struct {
connQueue chan quicConn
connQueueLen int32 // to be used as an atomic
tracer logging.Tracer
logger utils.Logger
}
var _ unknownPacketHandler = &baseServer{}
// A Listener listens for incoming QUIC connections.
// It returns connections once the handshake has completed.
type Listener struct {
@@ -166,37 +161,36 @@ func (l *EarlyListener) Addr() net.Addr {
// The tls.Config must not be nil and must contain a certificate configuration.
// The quic.Config may be nil, in that case the default values will be used.
func ListenAddr(addr string, tlsConf *tls.Config, config *Config) (*Listener, error) {
s, err := listenAddr(addr, tlsConf, config, false)
conn, err := listenUDP(addr)
if err != nil {
return nil, err
}
return &Listener{baseServer: s}, nil
return (&Transport{
Conn: conn,
createdConn: true,
isSingleUse: true,
}).Listen(tlsConf, config)
}
// ListenAddrEarly works like ListenAddr, but it returns connections before the handshake completes.
func ListenAddrEarly(addr string, tlsConf *tls.Config, config *Config) (*EarlyListener, error) {
s, err := listenAddr(addr, tlsConf, config, true)
conn, err := listenUDP(addr)
if err != nil {
return nil, err
}
return &EarlyListener{baseServer: s}, nil
return (&Transport{
Conn: conn,
createdConn: true,
isSingleUse: true,
}).ListenEarly(tlsConf, config)
}
func listenAddr(addr string, tlsConf *tls.Config, config *Config, acceptEarly bool) (*baseServer, error) {
func listenUDP(addr string) (*net.UDPConn, error) {
udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
}
conn, err := net.ListenUDP("udp", udpAddr)
if err != nil {
return nil, err
}
serv, err := listen(conn, tlsConf, config, acceptEarly)
if err != nil {
return nil, err
}
serv.createdPacketConn = true
return serv, nil
return net.ListenUDP("udp", udpAddr)
}
// Listen listens for QUIC connections on a given net.PacketConn. If the
@@ -210,67 +204,51 @@ func listenAddr(addr string, tlsConf *tls.Config, config *Config, acceptEarly bo
// Furthermore, it must define an application control (using NextProtos).
// The quic.Config may be nil, in that case the default values will be used.
func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (*Listener, error) {
s, err := listen(conn, tlsConf, config, false)
if err != nil {
return nil, err
}
return &Listener{baseServer: s}, nil
tr := &Transport{Conn: conn, isSingleUse: true}
return tr.Listen(tlsConf, config)
}
// ListenEarly works like Listen, but it returns connections before the handshake completes.
func ListenEarly(conn net.PacketConn, tlsConf *tls.Config, config *Config) (*EarlyListener, error) {
s, err := listen(conn, tlsConf, config, true)
if err != nil {
return nil, err
}
return &EarlyListener{baseServer: s}, nil
tr := &Transport{Conn: conn, isSingleUse: true}
return tr.ListenEarly(tlsConf, config)
}
func listen(conn net.PacketConn, tlsConf *tls.Config, config *Config, acceptEarly bool) (*baseServer, error) {
if tlsConf == nil {
return nil, errors.New("quic: tls.Config not set")
}
if err := validateConfig(config); err != nil {
return nil, err
}
config = populateServerConfig(config)
for _, v := range config.Versions {
if !protocol.IsValidVersion(v) {
return nil, fmt.Errorf("%s is not a valid QUIC version", v)
}
}
connHandler, err := getMultiplexer().AddConn(conn, config.ConnectionIDGenerator.ConnectionIDLen(), config.StatelessResetKey, config.Tracer)
if err != nil {
return nil, err
}
func newServer(
conn rawConn,
connHandler packetHandlerManager,
connIDGenerator ConnectionIDGenerator,
tlsConf *tls.Config,
config *Config,
tracer logging.Tracer,
onClose func(),
acceptEarly bool,
) (*baseServer, error) {
tokenGenerator, err := handshake.NewTokenGenerator(rand.Reader)
if err != nil {
return nil, err
}
c, err := wrapConn(conn)
if err != nil {
return nil, err
}
s := &baseServer{
conn: c,
conn: conn,
tlsConf: tlsConf,
config: config,
tokenGenerator: tokenGenerator,
connIDGenerator: connIDGenerator,
connHandler: connHandler,
connQueue: make(chan quicConn),
errorChan: make(chan struct{}),
running: make(chan struct{}),
receivedPackets: make(chan *receivedPacket, protocol.MaxServerUnprocessedPackets),
newConn: newConnection,
tracer: tracer,
logger: utils.DefaultLogger.WithPrefix("server"),
acceptEarlyConns: acceptEarly,
onClose: onClose,
}
if acceptEarly {
s.zeroRTTQueues = map[protocol.ConnectionID]*zeroRTTQueue{}
}
go s.run()
connHandler.SetServer(s)
s.logger.Debugf("Listening for %s connections on %s", conn.LocalAddr().Network(), conn.LocalAddr().String())
return s, nil
}
@@ -322,18 +300,12 @@ func (s *baseServer) Close() error {
if s.serverError == nil {
s.serverError = ErrServerClosed
}
// If the server was started with ListenAddr, we created the packet conn.
// We need to close it in order to make the go routine reading from that conn return.
createdPacketConn := s.createdPacketConn
s.closed = true
close(s.errorChan)
s.mutex.Unlock()
<-s.running
s.connHandler.CloseServer()
if createdPacketConn {
return s.connHandler.Destroy()
}
s.onClose()
return nil
}
@@ -358,8 +330,8 @@ func (s *baseServer) handlePacket(p *receivedPacket) {
case s.receivedPackets <- p:
default:
s.logger.Debugf("Dropping packet from %s (%d bytes). Server receive queue full.", p.remoteAddr, p.Size())
if s.config.Tracer != nil {
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention)
if s.tracer != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention)
}
}
}
@@ -371,8 +343,8 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer s
if wire.IsVersionNegotiationPacket(p.data) {
s.logger.Debugf("Dropping Version Negotiation packet.")
if s.config.Tracer != nil {
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeVersionNegotiation, p.Size(), logging.PacketDropUnexpectedPacket)
if s.tracer != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeVersionNegotiation, p.Size(), logging.PacketDropUnexpectedPacket)
}
return false
}
@@ -385,16 +357,16 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer s
if err != nil || !protocol.IsSupportedVersion(s.config.Versions, v) {
if err != nil || p.Size() < protocol.MinUnknownVersionPacketSize {
s.logger.Debugf("Dropping a packet with an unknown version that is too small (%d bytes)", p.Size())
if s.config.Tracer != nil {
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket)
if s.tracer != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket)
}
return false
}
_, src, dest, err := wire.ParseArbitraryLenConnectionIDs(p.data)
if err != nil { // should never happen
s.logger.Debugf("Dropping a packet with an unknown version for which we failed to parse connection IDs")
if s.config.Tracer != nil {
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket)
if s.tracer != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket)
}
return false
}
@@ -406,8 +378,8 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer s
if wire.Is0RTTPacket(p.data) {
if !s.acceptEarlyConns {
if s.config.Tracer != nil {
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropUnexpectedPacket)
if s.tracer != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropUnexpectedPacket)
}
return false
}
@@ -418,16 +390,16 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer s
// The header will then be parsed again.
hdr, _, _, err := wire.ParsePacket(p.data)
if err != nil {
if s.config.Tracer != nil {
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError)
if s.tracer != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError)
}
s.logger.Debugf("Error parsing packet: %s", err)
return false
}
if hdr.Type == protocol.PacketTypeInitial && p.Size() < protocol.MinInitialPacketSize {
s.logger.Debugf("Dropping a packet that is too small to be a valid Initial (%d bytes)", p.Size())
if s.config.Tracer != nil {
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket)
if s.tracer != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket)
}
return false
}
@@ -437,8 +409,8 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer s
// There's little point in sending a Stateless Reset, since the client
// might not have received the token yet.
s.logger.Debugf("Dropping long header packet of type %s (%d bytes)", hdr.Type, len(p.data))
if s.config.Tracer != nil {
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeFromHeader(hdr), p.Size(), logging.PacketDropUnexpectedPacket)
if s.tracer != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeFromHeader(hdr), p.Size(), logging.PacketDropUnexpectedPacket)
}
return false
}
@@ -456,8 +428,8 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer s
func (s *baseServer) handle0RTTPacket(p *receivedPacket) bool {
connID, err := wire.ParseConnectionID(p.data, 0)
if err != nil {
if s.config.Tracer != nil {
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropHeaderParseError)
if s.tracer != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropHeaderParseError)
}
return false
}
@@ -470,8 +442,8 @@ func (s *baseServer) handle0RTTPacket(p *receivedPacket) bool {
if q, ok := s.zeroRTTQueues[connID]; ok {
if len(q.packets) >= protocol.Max0RTTQueueLen {
if s.config.Tracer != nil {
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention)
if s.tracer != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention)
}
return false
}
@@ -480,8 +452,8 @@ func (s *baseServer) handle0RTTPacket(p *receivedPacket) bool {
}
if len(s.zeroRTTQueues) >= protocol.Max0RTTQueues {
if s.config.Tracer != nil {
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention)
if s.tracer != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention)
}
return false
}
@@ -508,8 +480,8 @@ func (s *baseServer) cleanupZeroRTTQueues(now time.Time) {
continue
}
for _, p := range q.packets {
if s.config.Tracer != nil {
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention)
if s.tracer != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention)
}
p.buffer.Release()
}
@@ -544,8 +516,8 @@ func (s *baseServer) validateToken(token *handshake.Token, addr net.Addr) bool {
func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) error {
if len(hdr.Token) == 0 && hdr.DestConnectionID.Len() < protocol.MinConnectionIDLenInitial {
p.buffer.Release()
if s.config.Tracer != nil {
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket)
if s.tracer != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket)
}
return errors.New("too short connection ID")
}
@@ -617,26 +589,31 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
return nil
}
connID, err := s.config.ConnectionIDGenerator.GenerateConnectionID()
connID, err := s.connIDGenerator.GenerateConnectionID()
if err != nil {
return err
}
s.logger.Debugf("Changing connection ID to %s.", connID)
var conn quicConn
tracingID := nextConnTracingID()
if added := s.connHandler.AddWithConnID(hdr.DestConnectionID, connID, func() packetHandler {
if added := s.connHandler.AddWithConnID(hdr.DestConnectionID, connID, func() (packetHandler, bool) {
config := s.config
if s.config.GetConfigForClient != nil {
conf, err := s.config.GetConfigForClient(&ClientHelloInfo{RemoteAddr: p.remoteAddr})
if err != nil {
s.logger.Debugf("Rejecting new connection due to GetConfigForClient callback")
return nil, false
}
config = populateConfig(conf)
}
var tracer logging.ConnectionTracer
if s.config.Tracer != nil {
if config.Tracer != nil {
// Use the same connection ID that is passed to the client's GetLogWriter callback.
connID := hdr.DestConnectionID
if origDestConnID.Len() > 0 {
connID = origDestConnID
}
tracer = s.config.Tracer.TracerForConnection(
context.WithValue(context.Background(), ConnectionTracingKey, tracingID),
protocol.PerspectiveServer,
connID,
)
tracer = config.Tracer(context.WithValue(context.Background(), ConnectionTracingKey, tracingID), protocol.PerspectiveServer, connID)
}
conn = s.newConn(
newSendConn(s.conn, p.remoteAddr, p.info),
@@ -646,8 +623,9 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
hdr.DestConnectionID,
hdr.SrcConnectionID,
connID,
s.connIDGenerator,
s.connHandler.GetStatelessResetToken(connID),
s.config,
config,
s.tlsConf,
s.tokenGenerator,
clientAddrIsValid,
@@ -665,10 +643,14 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
delete(s.zeroRTTQueues, hdr.DestConnectionID)
}
return conn
return conn, true
}); !added {
// TODO: don't just drop the packet
// Properly reject the connection attempt.
go func() {
defer p.buffer.Release()
if err := s.sendConnectionRefused(p.remoteAddr, hdr, p.info); err != nil {
s.logger.Debugf("Error rejecting connection: %s", err)
}
}()
return nil
}
go conn.run()
@@ -712,7 +694,7 @@ func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header, info *pack
// Log the Initial packet now.
// If no Retry is sent, the packet will be logged by the connection.
(&wire.ExtendedHeader{Header: *hdr}).Log(s.logger)
srcConnID, err := s.config.ConnectionIDGenerator.GenerateConnectionID()
srcConnID, err := s.connIDGenerator.GenerateConnectionID()
if err != nil {
return err
}
@@ -741,8 +723,8 @@ func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header, info *pack
// append the Retry integrity tag
tag := handshake.GetRetryIntegrityTag(buf.Data, hdr.DestConnectionID, hdr.Version)
buf.Data = append(buf.Data, tag[:]...)
if s.config.Tracer != nil {
s.config.Tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(buf.Data)), nil)
if s.tracer != nil {
s.tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(buf.Data)), nil)
}
_, err = s.conn.WritePacket(buf.Data, remoteAddr, info.OOB())
return err
@@ -755,8 +737,8 @@ func (s *baseServer) maybeSendInvalidToken(p *receivedPacket, hdr *wire.Header)
data := p.data[:hdr.ParsedLen()+hdr.Length]
extHdr, err := unpackLongHeader(opener, hdr, data, hdr.Version)
if err != nil {
if s.config.Tracer != nil {
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropHeaderParseError)
if s.tracer != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropHeaderParseError)
}
// don't return the error here. Just drop the packet.
return nil
@@ -764,8 +746,8 @@ func (s *baseServer) maybeSendInvalidToken(p *receivedPacket, hdr *wire.Header)
hdrLen := extHdr.ParsedLen()
if _, err := opener.Open(data[hdrLen:hdrLen], data[hdrLen:], extHdr.PacketNumber, data[:hdrLen]); err != nil {
// don't return the error here. Just drop the packet.
if s.config.Tracer != nil {
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropPayloadDecryptError)
if s.tracer != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropPayloadDecryptError)
}
return nil
}
@@ -818,8 +800,8 @@ func (s *baseServer) sendError(remoteAddr net.Addr, hdr *wire.Header, sealer han
replyHdr.Log(s.logger)
wire.LogFrame(s.logger, ccf, true)
if s.config.Tracer != nil {
s.config.Tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(b.Data)), []logging.Frame{ccf})
if s.tracer != nil {
s.tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(b.Data)), []logging.Frame{ccf})
}
_, err = s.conn.WritePacket(b.Data, remoteAddr, info.OOB())
return err
@@ -829,8 +811,8 @@ func (s *baseServer) sendVersionNegotiationPacket(remote net.Addr, src, dest pro
s.logger.Debugf("Client offered version %s, sending Version Negotiation", v)
data := wire.ComposeVersionNegotiation(dest, src, s.config.Versions)
if s.config.Tracer != nil {
s.config.Tracer.SentVersionNegotiationPacket(remote, src, dest, s.config.Versions)
if s.tracer != nil {
s.tracer.SentVersionNegotiationPacket(remote, src, dest, s.config.Versions)
}
if _, err := s.conn.WritePacket(data, remote, oob); err != nil {
s.logger.Debugf("Error sending Version Negotiation: %s", err)

View File

@@ -1,15 +1,12 @@
package quic
import (
"bytes"
"context"
"crypto/rand"
"crypto/tls"
"errors"
"net"
"reflect"
"runtime/pprof"
"strings"
"sync"
"sync/atomic"
"time"
@@ -24,17 +21,10 @@ import (
"github.com/quic-go/quic-go/logging"
"github.com/golang/mock/gomock"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
func areServersRunning() bool {
var b bytes.Buffer
pprof.Lookup("goroutine").WriteTo(&b, 1)
return strings.Contains(b.String(), "quic-go.(*baseServer).run")
}
var _ = Describe("Server", func() {
var (
conn *MockPacketConn
@@ -96,15 +86,19 @@ var _ = Describe("Server", func() {
BeforeEach(func() {
conn = NewMockPacketConn(mockCtrl)
conn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes()
conn.EXPECT().ReadFrom(gomock.Any()).Do(func(_ []byte) { <-(make(chan struct{})) }).MaxTimes(1)
wait := make(chan struct{})
conn.EXPECT().ReadFrom(gomock.Any()).DoAndReturn(func(_ []byte) (int, net.Addr, error) {
<-wait
return 0, nil, errors.New("done")
}).MaxTimes(1)
conn.EXPECT().SetReadDeadline(gomock.Any()).Do(func(time.Time) {
close(wait)
conn.EXPECT().SetReadDeadline(time.Time{})
}).MaxTimes(1)
tlsConf = testdata.GetTLSConfig()
tlsConf.NextProtos = []string{"proto1"}
})
AfterEach(func() {
Eventually(areServersRunning).Should(BeFalse())
})
It("errors when no tls.Config is given", func() {
_, err := ListenAddr("localhost:0", nil, nil)
Expect(err).To(HaveOccurred())
@@ -114,7 +108,7 @@ var _ = Describe("Server", func() {
It("errors when the Config contains an invalid version", func() {
version := protocol.VersionNumber(0x1234)
_, err := Listen(nil, tlsConf, &Config{Versions: []protocol.VersionNumber{version}})
Expect(err).To(MatchError("0x1234 is not a valid QUIC version"))
Expect(err).To(MatchError("invalid QUIC version: 0x1234"))
})
It("fills in default values if options are not set in the Config", func() {
@@ -138,7 +132,6 @@ var _ = Describe("Server", func() {
HandshakeIdleTimeout: 1337 * time.Hour,
MaxIdleTimeout: 42 * time.Minute,
KeepAlivePeriod: 5 * time.Second,
StatelessResetKey: &StatelessResetKey{'f', 'o', 'o', 'b', 'a', 'r'},
RequireAddressValidation: requireAddrVal,
}
ln, err := Listen(conn, tlsConf, &config)
@@ -150,7 +143,6 @@ var _ = Describe("Server", func() {
Expect(server.config.MaxIdleTimeout).To(Equal(42 * time.Minute))
Expect(reflect.ValueOf(server.config.RequireAddressValidation)).To(Equal(reflect.ValueOf(requireAddrVal)))
Expect(server.config.KeepAlivePeriod).To(Equal(5 * time.Second))
Expect(server.config.StatelessResetKey).To(Equal(&StatelessResetKey{'f', 'o', 'o', 'b', 'a', 'r'}))
// stop the listener
Expect(ln.Close()).To(Succeed())
})
@@ -178,6 +170,7 @@ var _ = Describe("Server", func() {
Context("server accepting connections that completed the handshake", func() {
var (
tr *Transport
serv *baseServer
phm *MockPacketHandlerManager
tracer *mocklogging.MockTracer
@@ -185,7 +178,8 @@ var _ = Describe("Server", func() {
BeforeEach(func() {
tracer = mocklogging.NewMockTracer(mockCtrl)
ln, err := Listen(conn, tlsConf, &Config{Tracer: tracer})
tr = &Transport{Conn: conn, Tracer: tracer}
ln, err := tr.Listen(tlsConf, nil)
Expect(err).ToNot(HaveOccurred())
serv = ln.baseServer
phm = NewMockPacketHandlerManager(mockCtrl)
@@ -193,8 +187,7 @@ var _ = Describe("Server", func() {
})
AfterEach(func() {
phm.EXPECT().CloseServer().MaxTimes(1)
serv.Close()
tr.Close()
})
Context("handling packets", func() {
@@ -274,16 +267,15 @@ var _ = Describe("Server", func() {
var newConnID protocol.ConnectionID
phm.EXPECT().Get(connID)
phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() packetHandler) bool {
phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
newConnID = c
phm.EXPECT().GetStatelessResetToken(gomock.Any()).DoAndReturn(func(c protocol.ConnectionID) protocol.StatelessResetToken {
newConnID = c
return token
})
fn()
return true
_, ok := fn()
return ok
})
tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde}))
conn := NewMockQUICConn(mockCtrl)
serv.newConn = func(
_ sendConn,
@@ -293,6 +285,7 @@ var _ = Describe("Server", func() {
clientDestConnID protocol.ConnectionID,
destConnID protocol.ConnectionID,
srcConnID protocol.ConnectionID,
_ ConnectionIDGenerator,
tokenP protocol.StatelessResetToken,
_ *Config,
_ *tls.Config,
@@ -474,17 +467,16 @@ var _ = Describe("Server", func() {
var newConnID protocol.ConnectionID
gomock.InOrder(
phm.EXPECT().Get(connID),
phm.EXPECT().AddWithConnID(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() packetHandler) bool {
phm.EXPECT().AddWithConnID(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
newConnID = c
phm.EXPECT().GetStatelessResetToken(gomock.Any()).DoAndReturn(func(c protocol.ConnectionID) protocol.StatelessResetToken {
newConnID = c
return token
})
fn()
return true
_, ok := fn()
return ok
}),
)
tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}))
conn := NewMockQUICConn(mockCtrl)
serv.newConn = func(
@@ -495,6 +487,7 @@ var _ = Describe("Server", func() {
clientDestConnID protocol.ConnectionID,
destConnID protocol.ConnectionID,
srcConnID protocol.ConnectionID,
_ ConnectionIDGenerator,
tokenP protocol.StatelessResetToken,
_ *Config,
_ *tls.Config,
@@ -537,12 +530,11 @@ var _ = Describe("Server", func() {
It("drops packets if the receive queue is full", func() {
phm.EXPECT().Get(gomock.Any()).AnyTimes()
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
phm.EXPECT().GetStatelessResetToken(gomock.Any())
fn()
return true
_, ok := fn()
return ok
}).AnyTimes()
tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any()).AnyTimes()
acceptConn := make(chan struct{})
var counter uint32 // to be used as an atomic, so we query it in Eventually
@@ -554,6 +546,7 @@ var _ = Describe("Server", func() {
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ ConnectionIDGenerator,
_ protocol.StatelessResetToken,
_ *Config,
_ *tls.Config,
@@ -598,7 +591,6 @@ var _ = Describe("Server", func() {
It("only creates a single connection for a duplicate Initial", func() {
var createdConn bool
conn := NewMockQUICConn(mockCtrl)
serv.newConn = func(
_ sendConn,
runner connRunner,
@@ -607,6 +599,7 @@ var _ = Describe("Server", func() {
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ ConnectionIDGenerator,
_ protocol.StatelessResetToken,
_ *Config,
_ *tls.Config,
@@ -618,15 +611,19 @@ var _ = Describe("Server", func() {
_ protocol.VersionNumber,
) quicConn {
createdConn = true
return conn
return NewMockQUICConn(mockCtrl)
}
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9})
p := getInitial(connID)
phm.EXPECT().Get(connID)
phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).Return(false)
tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
done := make(chan struct{})
conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).Do(func([]byte, net.Addr) { close(done) })
Expect(serv.handlePacketImpl(p)).To(BeTrue())
Expect(createdConn).To(BeFalse())
Eventually(done).Should(BeClosed())
})
It("rejects new connection attempts if the accept queue is full", func() {
@@ -638,6 +635,7 @@ var _ = Describe("Server", func() {
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ ConnectionIDGenerator,
_ protocol.StatelessResetToken,
_ *Config,
_ *tls.Config,
@@ -659,12 +657,11 @@ var _ = Describe("Server", func() {
}
phm.EXPECT().Get(gomock.Any()).Times(protocol.MaxAcceptQueueSize + 1)
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
phm.EXPECT().GetStatelessResetToken(gomock.Any())
fn()
return true
_, ok := fn()
return ok
}).Times(protocol.MaxAcceptQueueSize)
tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any()).Times(protocol.MaxAcceptQueueSize)
var wg sync.WaitGroup
wg.Add(protocol.MaxAcceptQueueSize)
@@ -709,6 +706,7 @@ var _ = Describe("Server", func() {
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ ConnectionIDGenerator,
_ protocol.StatelessResetToken,
_ *Config,
_ *tls.Config,
@@ -730,12 +728,11 @@ var _ = Describe("Server", func() {
}
phm.EXPECT().Get(gomock.Any())
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
phm.EXPECT().GetStatelessResetToken(gomock.Any())
fn()
return true
_, ok := fn()
return ok
})
tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any())
serv.handlePacket(p)
// make sure there are no Write calls on the packet conn
@@ -753,8 +750,7 @@ var _ = Describe("Server", func() {
Consistently(done).ShouldNot(BeClosed())
// make the go routine return
phm.EXPECT().CloseServer()
conn.EXPECT().getPerspective().MaxTimes(2) // once for every conn ID
conn.EXPECT().getPerspective().MaxTimes(2) // initOnce for every conn ID
Expect(serv.Close()).To(Succeed())
Eventually(done).Should(BeClosed())
})
@@ -794,7 +790,7 @@ var _ = Describe("Server", func() {
done := make(chan struct{})
phm.EXPECT().Get(gomock.Any())
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_, _ protocol.ConnectionID, _ func() packetHandler) { close(done) })
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_, _ protocol.ConnectionID, _ func() (packetHandler, bool)) { close(done) })
serv.handlePacket(packet)
Eventually(done).Should(BeClosed())
})
@@ -968,6 +964,7 @@ var _ = Describe("Server", func() {
serv.setCloseError(testErr)
Eventually(done).Should(BeClosed())
serv.onClose() // shutdown
})
It("returns immediately, if an error occurred before", func() {
@@ -977,6 +974,7 @@ var _ = Describe("Server", func() {
_, err := serv.Accept(context.Background())
Expect(err).To(MatchError(testErr))
}
serv.onClose() // shutdown
})
It("returns when the context is canceled", func() {
@@ -994,6 +992,85 @@ var _ = Describe("Server", func() {
Eventually(done).Should(BeClosed())
})
It("uses the config returned by GetConfigClient", func() {
conn := NewMockQUICConn(mockCtrl)
conf := &Config{MaxIncomingStreams: 1234}
serv.config = populateServerConfig(&Config{GetConfigForClient: func(*ClientHelloInfo) (*Config, error) { return conf, nil }})
done := make(chan struct{})
go func() {
defer GinkgoRecover()
s, err := serv.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
Expect(s).To(Equal(conn))
close(done)
}()
handshakeChan := make(chan struct{})
serv.newConn = func(
_ sendConn,
_ connRunner,
_ protocol.ConnectionID,
_ *protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ ConnectionIDGenerator,
_ protocol.StatelessResetToken,
conf *Config,
_ *tls.Config,
_ *handshake.TokenGenerator,
_ bool,
_ logging.ConnectionTracer,
_ uint64,
_ utils.Logger,
_ protocol.VersionNumber,
) quicConn {
Expect(conf.MaxIncomingStreams).To(BeEquivalentTo(1234))
conn.EXPECT().handlePacket(gomock.Any())
conn.EXPECT().HandshakeComplete().Return(handshakeChan)
conn.EXPECT().run().Do(func() {})
conn.EXPECT().Context().Return(context.Background())
return conn
}
phm.EXPECT().Get(gomock.Any())
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
phm.EXPECT().GetStatelessResetToken(gomock.Any())
_, ok := fn()
return ok
})
serv.handleInitialImpl(
&receivedPacket{buffer: getPacketBuffer()},
&wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})},
)
Consistently(done).ShouldNot(BeClosed())
close(handshakeChan) // complete the handshake
Eventually(done).Should(BeClosed())
})
It("rejects a connection attempt when GetConfigClient returns an error", func() {
serv.config = populateServerConfig(&Config{GetConfigForClient: func(*ClientHelloInfo) (*Config, error) { return nil, errors.New("rejected") }})
phm.EXPECT().Get(gomock.Any())
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
_, ok := fn()
return ok
})
done := make(chan struct{})
tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).DoAndReturn(func(b []byte, _ net.Addr) (int, error) {
defer close(done)
rejectHdr := parseHeader(b)
Expect(rejectHdr.Type).To(Equal(protocol.PacketTypeInitial))
return len(b), nil
})
serv.handleInitialImpl(
&receivedPacket{buffer: getPacketBuffer()},
&wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), Version: protocol.Version1},
)
Eventually(done).Should(BeClosed())
})
It("accepts new connections when the handshake completes", func() {
conn := NewMockQUICConn(mockCtrl)
@@ -1015,6 +1092,7 @@ var _ = Describe("Server", func() {
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ ConnectionIDGenerator,
_ protocol.StatelessResetToken,
_ *Config,
_ *tls.Config,
@@ -1032,12 +1110,11 @@ var _ = Describe("Server", func() {
return conn
}
phm.EXPECT().Get(gomock.Any())
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
phm.EXPECT().GetStatelessResetToken(gomock.Any())
fn()
return true
_, ok := fn()
return ok
})
tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any())
serv.handleInitialImpl(
&receivedPacket{buffer: getPacketBuffer()},
&wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})},
@@ -1064,7 +1141,6 @@ var _ = Describe("Server", func() {
})
AfterEach(func() {
phm.EXPECT().CloseServer().MaxTimes(1)
serv.Close()
})
@@ -1089,6 +1165,7 @@ var _ = Describe("Server", func() {
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ ConnectionIDGenerator,
_ protocol.StatelessResetToken,
_ *Config,
_ *tls.Config,
@@ -1106,10 +1183,10 @@ var _ = Describe("Server", func() {
return conn
}
phm.EXPECT().Get(gomock.Any())
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
phm.EXPECT().GetStatelessResetToken(gomock.Any())
fn()
return true
_, ok := fn()
return ok
})
serv.baseServer.handleInitialImpl(
&receivedPacket{buffer: getPacketBuffer()},
@@ -1131,6 +1208,7 @@ var _ = Describe("Server", func() {
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ ConnectionIDGenerator,
_ protocol.StatelessResetToken,
_ *Config,
_ *tls.Config,
@@ -1152,10 +1230,10 @@ var _ = Describe("Server", func() {
}
phm.EXPECT().Get(gomock.Any()).AnyTimes()
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
phm.EXPECT().GetStatelessResetToken(gomock.Any())
fn()
return true
_, ok := fn()
return ok
}).Times(protocol.MaxAcceptQueueSize)
for i := 0; i < protocol.MaxAcceptQueueSize; i++ {
serv.baseServer.handlePacket(getInitialWithRandomDestConnID())
@@ -1194,6 +1272,7 @@ var _ = Describe("Server", func() {
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ ConnectionIDGenerator,
_ protocol.StatelessResetToken,
_ *Config,
_ *tls.Config,
@@ -1213,10 +1292,10 @@ var _ = Describe("Server", func() {
}
phm.EXPECT().Get(gomock.Any())
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
phm.EXPECT().GetStatelessResetToken(gomock.Any())
fn()
return true
_, ok := fn()
return ok
})
serv.baseServer.handlePacket(p)
// make sure there are no Write calls on the packet conn
@@ -1234,8 +1313,7 @@ var _ = Describe("Server", func() {
Consistently(done).ShouldNot(BeClosed())
// make the go routine return
phm.EXPECT().CloseServer()
conn.EXPECT().getPerspective().MaxTimes(2) // once for every conn ID
conn.EXPECT().getPerspective().MaxTimes(2) // initOnce for every conn ID
Expect(serv.Close()).To(Succeed())
Eventually(done).Should(BeClosed())
})
@@ -1243,6 +1321,7 @@ var _ = Describe("Server", func() {
Context("0-RTT", func() {
var (
tr *Transport
serv *baseServer
phm *MockPacketHandlerManager
tracer *mocklogging.MockTracer
@@ -1250,7 +1329,8 @@ var _ = Describe("Server", func() {
BeforeEach(func() {
tracer = mocklogging.NewMockTracer(mockCtrl)
ln, err := ListenEarly(conn, tlsConf, &Config{Tracer: tracer})
tr = &Transport{Conn: conn, Tracer: tracer}
ln, err := tr.ListenEarly(tlsConf, nil)
Expect(err).ToNot(HaveOccurred())
phm = NewMockPacketHandlerManager(mockCtrl)
serv = ln.baseServer
@@ -1259,7 +1339,7 @@ var _ = Describe("Server", func() {
AfterEach(func() {
phm.EXPECT().CloseServer().MaxTimes(1)
serv.Close()
tr.Close()
})
It("passes packets to existing connections", func() {
@@ -1317,6 +1397,7 @@ var _ = Describe("Server", func() {
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ ConnectionIDGenerator,
_ protocol.StatelessResetToken,
_ *Config,
_ *tls.Config,
@@ -1341,12 +1422,11 @@ var _ = Describe("Server", func() {
return conn
}
tracer.EXPECT().TracerForConnection(gomock.Any(), gomock.Any(), gomock.Any())
phm.EXPECT().Get(connID)
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool {
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
phm.EXPECT().GetStatelessResetToken(gomock.Any())
fn()
return true
_, ok := fn()
return ok
})
serv.handlePacket(initial)
Eventually(called).Should(BeClosed())

414
transport.go Normal file
View File

@@ -0,0 +1,414 @@
package quic
import (
"context"
"crypto/rand"
"crypto/tls"
"errors"
"fmt"
"log"
"net"
"os"
"strconv"
"strings"
"sync"
"time"
"github.com/quic-go/quic-go/internal/wire"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/logging"
)
type Transport struct {
// A single net.PacketConn can only be handled by one Transport.
// Bad things will happen if passed to multiple Transports.
//
// If the connection satisfies the OOBCapablePacketConn interface
// (as a net.UDPConn does), ECN and packet info support will be enabled.
// In this case, optimized syscalls might be used, skipping the
// ReadFrom and WriteTo calls to read / write packets.
Conn net.PacketConn
// The length of the connection ID in bytes.
// It can be 0, or any value between 4 and 18.
// If unset, a 4 byte connection ID will be used.
ConnectionIDLength int
// Use for generating new connection IDs.
// This allows the application to control of the connection IDs used,
// which allows routing / load balancing based on connection IDs.
// All Connection IDs returned by the ConnectionIDGenerator MUST
// have the same length.
ConnectionIDGenerator ConnectionIDGenerator
// The StatelessResetKey is used to generate stateless reset tokens.
// If no key is configured, sending of stateless resets is disabled.
StatelessResetKey *StatelessResetKey
// A Tracer traces events that don't belong to a single QUIC connection.
Tracer logging.Tracer
handlerMap packetHandlerManager
mutex sync.Mutex
initOnce sync.Once
initErr error
// Set in init.
// If no ConnectionIDGenerator is set, this is the ConnectionIDLength.
connIDLen int
// Set in init.
// If no ConnectionIDGenerator is set, this is set to a default.
connIDGenerator ConnectionIDGenerator
server unknownPacketHandler
conn rawConn
closeQueue chan closePacket
listening chan struct{} // is closed when listen returns
closed bool
createdConn bool
isSingleUse bool // was created for a single server or client, i.e. by calling quic.Listen or quic.Dial
logger utils.Logger
}
// Listen starts listening for incoming QUIC connections.
// There can only be a single listener on any net.PacketConn.
// Listen may only be called again after the current Listener was closed.
func (t *Transport) Listen(tlsConf *tls.Config, conf *Config) (*Listener, error) {
if tlsConf == nil {
return nil, errors.New("quic: tls.Config not set")
}
if err := validateConfig(conf); err != nil {
return nil, err
}
t.mutex.Lock()
defer t.mutex.Unlock()
if t.server != nil {
return nil, errListenerAlreadySet
}
conf = populateServerConfig(conf)
if err := t.init(true); err != nil {
return nil, err
}
s, err := newServer(t.conn, t.handlerMap, t.connIDGenerator, tlsConf, conf, t.Tracer, t.closeServer, false)
if err != nil {
return nil, err
}
t.server = s
return &Listener{baseServer: s}, nil
}
// ListenEarly starts listening for incoming QUIC connections.
// There can only be a single listener on any net.PacketConn.
// Listen may only be called again after the current Listener was closed.
func (t *Transport) ListenEarly(tlsConf *tls.Config, conf *Config) (*EarlyListener, error) {
if tlsConf == nil {
return nil, errors.New("quic: tls.Config not set")
}
if err := validateConfig(conf); err != nil {
return nil, err
}
t.mutex.Lock()
defer t.mutex.Unlock()
if t.server != nil {
return nil, errListenerAlreadySet
}
conf = populateServerConfig(conf)
if err := t.init(true); err != nil {
return nil, err
}
s, err := newServer(t.conn, t.handlerMap, t.connIDGenerator, tlsConf, conf, t.Tracer, t.closeServer, true)
if err != nil {
return nil, err
}
t.server = s
return &EarlyListener{baseServer: s}, nil
}
// Dial dials a new connection to a remote host (not using 0-RTT).
func (t *Transport) Dial(ctx context.Context, addr net.Addr, tlsConf *tls.Config, conf *Config) (Connection, error) {
if err := validateConfig(conf); err != nil {
return nil, err
}
conf = populateConfig(conf)
if err := t.init(false); err != nil {
return nil, err
}
var onClose func()
if t.isSingleUse {
onClose = func() { t.Close() }
}
return dial(ctx, newSendConn(t.conn, addr, nil), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, false)
}
// DialEarly dials a new connection, attempting to use 0-RTT if possible.
func (t *Transport) DialEarly(ctx context.Context, addr net.Addr, tlsConf *tls.Config, conf *Config) (EarlyConnection, error) {
if err := validateConfig(conf); err != nil {
return nil, err
}
conf = populateConfig(conf)
if err := t.init(false); err != nil {
return nil, err
}
var onClose func()
if t.isSingleUse {
onClose = func() { t.Close() }
}
return dial(ctx, newSendConn(t.conn, addr, nil), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, true)
}
func setReceiveBuffer(c net.PacketConn, logger utils.Logger) error {
conn, ok := c.(interface{ SetReadBuffer(int) error })
if !ok {
return errors.New("connection doesn't allow setting of receive buffer size. Not a *net.UDPConn?")
}
size, err := inspectReadBuffer(c)
if err != nil {
return fmt.Errorf("failed to determine receive buffer size: %w", err)
}
if size >= protocol.DesiredReceiveBufferSize {
logger.Debugf("Conn has receive buffer of %d kiB (wanted: at least %d kiB)", size/1024, protocol.DesiredReceiveBufferSize/1024)
return nil
}
if err := conn.SetReadBuffer(protocol.DesiredReceiveBufferSize); err != nil {
return fmt.Errorf("failed to increase receive buffer size: %w", err)
}
newSize, err := inspectReadBuffer(c)
if err != nil {
return fmt.Errorf("failed to determine receive buffer size: %w", err)
}
if newSize == size {
return fmt.Errorf("failed to increase receive buffer size (wanted: %d kiB, got %d kiB)", protocol.DesiredReceiveBufferSize/1024, newSize/1024)
}
if newSize < protocol.DesiredReceiveBufferSize {
return fmt.Errorf("failed to sufficiently increase receive buffer size (was: %d kiB, wanted: %d kiB, got: %d kiB)", size/1024, protocol.DesiredReceiveBufferSize/1024, newSize/1024)
}
logger.Debugf("Increased receive buffer size to %d kiB", newSize/1024)
return nil
}
// only print warnings about the UDP receive buffer size once
var receiveBufferWarningOnce sync.Once
func (t *Transport) init(isServer bool) error {
t.initOnce.Do(func() {
getMultiplexer().AddConn(t.Conn)
conn, err := wrapConn(t.Conn)
if err != nil {
t.initErr = err
return
}
t.logger = utils.DefaultLogger // TODO: make this configurable
t.conn = conn
t.handlerMap = newPacketHandlerMap(t.StatelessResetKey, t.enqueueClosePacket, t.logger)
t.listening = make(chan struct{})
t.closeQueue = make(chan closePacket, 4)
if t.ConnectionIDGenerator != nil {
t.connIDGenerator = t.ConnectionIDGenerator
t.connIDLen = t.ConnectionIDGenerator.ConnectionIDLen()
} else {
connIDLen := t.ConnectionIDLength
if t.ConnectionIDLength == 0 && (!t.isSingleUse || isServer) {
connIDLen = protocol.DefaultConnectionIDLength
}
t.connIDLen = connIDLen
t.connIDGenerator = &protocol.DefaultConnectionIDGenerator{ConnLen: t.connIDLen}
}
go t.listen(conn)
go t.runCloseQueue()
})
return t.initErr
}
func (t *Transport) enqueueClosePacket(p closePacket) {
select {
case t.closeQueue <- p:
default:
// Oops, we're backlogged.
// Just drop the packet, sending CONNECTION_CLOSE copies is best effort anyway.
}
}
func (t *Transport) runCloseQueue() {
for {
select {
case <-t.listening:
return
case p := <-t.closeQueue:
t.conn.WritePacket(p.payload, p.addr, p.info.OOB())
}
}
}
// Close closes the underlying connection and waits until listen has returned.
// It is invalid to start new listeners or connections after that.
func (t *Transport) Close() error {
t.close(errors.New("closing"))
if t.createdConn {
if err := t.conn.Close(); err != nil {
return err
}
} else {
t.conn.SetReadDeadline(time.Now())
defer func() { t.conn.SetReadDeadline(time.Time{}) }()
}
<-t.listening // wait until listening returns
return nil
}
func (t *Transport) closeServer() {
t.handlerMap.CloseServer()
t.mutex.Lock()
t.server = nil
if t.isSingleUse {
t.closed = true
}
t.mutex.Unlock()
if t.createdConn {
t.Conn.Close()
}
if t.isSingleUse {
t.conn.SetReadDeadline(time.Now())
defer func() { t.conn.SetReadDeadline(time.Time{}) }()
<-t.listening // wait until listening returns
}
}
func (t *Transport) close(e error) {
t.mutex.Lock()
defer t.mutex.Unlock()
if t.closed {
return
}
t.handlerMap.Close(e)
if t.server != nil {
t.server.setCloseError(e)
}
t.closed = true
}
func (t *Transport) listen(conn rawConn) {
defer close(t.listening)
defer getMultiplexer().RemoveConn(t.Conn)
if err := setReceiveBuffer(t.Conn, t.logger); err != nil {
if !strings.Contains(err.Error(), "use of closed network connection") {
receiveBufferWarningOnce.Do(func() {
if disable, _ := strconv.ParseBool(os.Getenv("QUIC_GO_DISABLE_RECEIVE_BUFFER_WARNING")); disable {
return
}
log.Printf("%s. See https://github.com/quic-go/quic-go/wiki/UDP-Receive-Buffer-Size for details.", err)
})
}
}
for {
p, err := conn.ReadPacket()
//nolint:staticcheck // SA1019 ignore this!
// TODO: This code is used to ignore wsa errors on Windows.
// Since net.Error.Temporary is deprecated as of Go 1.18, we should find a better solution.
// See https://github.com/quic-go/quic-go/issues/1737 for details.
if nerr, ok := err.(net.Error); ok && nerr.Temporary() {
t.mutex.Lock()
closed := t.closed
t.mutex.Unlock()
if closed {
return
}
t.logger.Debugf("Temporary error reading from conn: %w", err)
continue
}
if err != nil {
t.close(err)
return
}
t.handlePacket(p)
}
}
func (t *Transport) handlePacket(p *receivedPacket) {
connID, err := wire.ParseConnectionID(p.data, t.connIDLen)
if err != nil {
t.logger.Debugf("error parsing connection ID on packet from %s: %s", p.remoteAddr, err)
if t.Tracer != nil {
t.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError)
}
p.buffer.MaybeRelease()
return
}
if isStatelessReset := t.maybeHandleStatelessReset(p.data); isStatelessReset {
return
}
if handler, ok := t.handlerMap.Get(connID); ok {
handler.handlePacket(p)
return
}
if !wire.IsLongHeaderPacket(p.data[0]) {
go t.maybeSendStatelessReset(p, connID)
return
}
t.mutex.Lock()
defer t.mutex.Unlock()
if t.server == nil { // no server set
t.logger.Debugf("received a packet with an unexpected connection ID %s", connID)
return
}
t.server.handlePacket(p)
}
func (t *Transport) maybeSendStatelessReset(p *receivedPacket, connID protocol.ConnectionID) {
defer p.buffer.Release()
if t.StatelessResetKey == nil {
return
}
// Don't send a stateless reset in response to very small packets.
// This includes packets that could be stateless resets.
if len(p.data) <= protocol.MinStatelessResetSize {
return
}
token := t.handlerMap.GetStatelessResetToken(connID)
t.logger.Debugf("Sending stateless reset to %s (connection ID: %s). Token: %#x", p.remoteAddr, connID, token)
data := make([]byte, protocol.MinStatelessResetSize-16, protocol.MinStatelessResetSize)
rand.Read(data)
data[0] = (data[0] & 0x7f) | 0x40
data = append(data, token[:]...)
if _, err := t.conn.WritePacket(data, p.remoteAddr, p.info.OOB()); err != nil {
t.logger.Debugf("Error sending Stateless Reset: %s", err)
}
}
func (t *Transport) maybeHandleStatelessReset(data []byte) bool {
// stateless resets are always short header packets
if wire.IsLongHeaderPacket(data[0]) {
return false
}
if len(data) < 17 /* type byte + 16 bytes for the reset token */ {
return false
}
token := *(*protocol.StatelessResetToken)(data[len(data)-16:])
if conn, ok := t.handlerMap.GetByResetToken(token); ok {
t.logger.Debugf("Received a stateless reset with token %#x. Closing connection.", token)
go conn.destroy(&StatelessResetError{Token: token})
return true
}
return false
}

291
transport_test.go Normal file
View File

@@ -0,0 +1,291 @@
package quic
import (
"bytes"
"crypto/rand"
"crypto/tls"
"errors"
"net"
"time"
mocklogging "github.com/quic-go/quic-go/internal/mocks/logging"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/wire"
"github.com/quic-go/quic-go/logging"
"github.com/golang/mock/gomock"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("Transport", func() {
type packetToRead struct {
addr net.Addr
data []byte
err error
}
getPacketWithPacketType := func(connID protocol.ConnectionID, t protocol.PacketType, length protocol.ByteCount) []byte {
b, err := (&wire.ExtendedHeader{
Header: wire.Header{
Type: t,
DestConnectionID: connID,
Length: length,
Version: protocol.Version1,
},
PacketNumberLen: protocol.PacketNumberLen2,
}).Append(nil, protocol.Version1)
Expect(err).ToNot(HaveOccurred())
return b
}
getPacket := func(connID protocol.ConnectionID) []byte {
return getPacketWithPacketType(connID, protocol.PacketTypeHandshake, 2)
}
newMockPacketConn := func(packetChan <-chan packetToRead) *MockPacketConn {
conn := NewMockPacketConn(mockCtrl)
conn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes()
conn.EXPECT().ReadFrom(gomock.Any()).DoAndReturn(func(b []byte) (int, net.Addr, error) {
p, ok := <-packetChan
if !ok {
return 0, nil, errors.New("closed")
}
return copy(b, p.data), p.addr, p.err
}).AnyTimes()
// for shutdown
conn.EXPECT().SetReadDeadline(gomock.Any()).AnyTimes()
return conn
}
It("handles packets for different packet handlers on the same packet conn", func() {
packetChan := make(chan packetToRead)
tr := &Transport{Conn: newMockPacketConn(packetChan)}
tr.init(true)
phm := NewMockPacketHandlerManager(mockCtrl)
tr.handlerMap = phm
connID1 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
connID2 := protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1})
handled := make(chan struct{}, 2)
phm.EXPECT().Get(connID1).DoAndReturn(func(protocol.ConnectionID) (packetHandler, bool) {
h := NewMockPacketHandler(mockCtrl)
h.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
defer GinkgoRecover()
connID, err := wire.ParseConnectionID(p.data, 0)
Expect(err).ToNot(HaveOccurred())
Expect(connID).To(Equal(connID1))
handled <- struct{}{}
})
return h, true
})
phm.EXPECT().Get(connID2).DoAndReturn(func(protocol.ConnectionID) (packetHandler, bool) {
h := NewMockPacketHandler(mockCtrl)
h.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
defer GinkgoRecover()
connID, err := wire.ParseConnectionID(p.data, 0)
Expect(err).ToNot(HaveOccurred())
Expect(connID).To(Equal(connID2))
handled <- struct{}{}
})
return h, true
})
packetChan <- packetToRead{data: getPacket(connID1)}
packetChan <- packetToRead{data: getPacket(connID2)}
Eventually(handled).Should(Receive())
Eventually(handled).Should(Receive())
// shutdown
phm.EXPECT().Close(gomock.Any())
close(packetChan)
tr.Close()
})
It("closes listeners", func() {
packetChan := make(chan packetToRead)
tr := &Transport{Conn: newMockPacketConn(packetChan)}
defer tr.Close()
ln, err := tr.Listen(&tls.Config{}, nil)
Expect(err).ToNot(HaveOccurred())
phm := NewMockPacketHandlerManager(mockCtrl)
tr.handlerMap = phm
phm.EXPECT().CloseServer()
Expect(ln.Close()).To(Succeed())
// shutdown
phm.EXPECT().Close(gomock.Any())
close(packetChan)
tr.Close()
})
It("drops unparseable packets", func() {
addr := &net.UDPAddr{IP: net.IPv4(9, 8, 7, 6), Port: 1234}
packetChan := make(chan packetToRead)
tracer := mocklogging.NewMockTracer(mockCtrl)
tr := &Transport{
Conn: newMockPacketConn(packetChan),
ConnectionIDLength: 10,
Tracer: tracer,
}
tr.init(true)
dropped := make(chan struct{})
tracer.EXPECT().DroppedPacket(addr, logging.PacketTypeNotDetermined, protocol.ByteCount(4), logging.PacketDropHeaderParseError).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { close(dropped) })
packetChan <- packetToRead{
addr: addr,
data: []byte{0, 1, 2, 3},
}
Eventually(dropped).Should(BeClosed())
// shutdown
close(packetChan)
tr.Close()
})
It("closes when reading from the conn fails", func() {
packetChan := make(chan packetToRead)
tr := Transport{Conn: newMockPacketConn(packetChan)}
defer tr.Close()
phm := NewMockPacketHandlerManager(mockCtrl)
tr.init(true)
tr.handlerMap = phm
done := make(chan struct{})
phm.EXPECT().Close(gomock.Any()).Do(func(error) { close(done) })
packetChan <- packetToRead{err: errors.New("read failed")}
Eventually(done).Should(BeClosed())
// shutdown
close(packetChan)
tr.Close()
})
It("continues listening after temporary errors", func() {
packetChan := make(chan packetToRead)
tr := Transport{Conn: newMockPacketConn(packetChan)}
defer tr.Close()
phm := NewMockPacketHandlerManager(mockCtrl)
tr.init(true)
tr.handlerMap = phm
tempErr := deadlineError{}
Expect(tempErr.Temporary()).To(BeTrue())
packetChan <- packetToRead{err: tempErr}
// don't expect any calls to phm.Close
time.Sleep(50 * time.Millisecond)
// shutdown
phm.EXPECT().Close(gomock.Any())
close(packetChan)
tr.Close()
})
It("handles short header packets resets", func() {
connID := protocol.ParseConnectionID([]byte{2, 3, 4, 5})
packetChan := make(chan packetToRead)
tr := Transport{
Conn: newMockPacketConn(packetChan),
ConnectionIDLength: connID.Len(),
}
tr.init(true)
defer tr.Close()
phm := NewMockPacketHandlerManager(mockCtrl)
tr.handlerMap = phm
var token protocol.StatelessResetToken
rand.Read(token[:])
var b []byte
b, err := wire.AppendShortHeader(b, connID, 1337, 2, protocol.KeyPhaseOne)
Expect(err).ToNot(HaveOccurred())
b = append(b, token[:]...)
conn := NewMockPacketHandler(mockCtrl)
gomock.InOrder(
phm.EXPECT().GetByResetToken(token),
phm.EXPECT().Get(connID).Return(conn, true),
conn.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
Expect(p.data).To(Equal(b))
Expect(p.rcvTime).To(BeTemporally("~", time.Now(), time.Second))
}),
)
packetChan <- packetToRead{data: b}
// shutdown
phm.EXPECT().Close(gomock.Any())
close(packetChan)
tr.Close()
})
It("handles stateless resets", func() {
connID := protocol.ParseConnectionID([]byte{2, 3, 4, 5})
packetChan := make(chan packetToRead)
tr := Transport{Conn: newMockPacketConn(packetChan)}
tr.init(true)
defer tr.Close()
phm := NewMockPacketHandlerManager(mockCtrl)
tr.handlerMap = phm
var token protocol.StatelessResetToken
rand.Read(token[:])
var b []byte
b, err := wire.AppendShortHeader(b, connID, 1337, 2, protocol.KeyPhaseOne)
Expect(err).ToNot(HaveOccurred())
b = append(b, token[:]...)
conn := NewMockPacketHandler(mockCtrl)
gomock.InOrder(
phm.EXPECT().GetByResetToken(token).Return(conn, true),
conn.EXPECT().destroy(gomock.Any()).Do(func(err error) {
Expect(err).To(MatchError(&StatelessResetError{Token: token}))
}),
)
packetChan <- packetToRead{data: b}
// shutdown
phm.EXPECT().Close(gomock.Any())
close(packetChan)
tr.Close()
})
It("sends stateless resets", func() {
connID := protocol.ParseConnectionID([]byte{2, 3, 4, 5})
packetChan := make(chan packetToRead)
conn := newMockPacketConn(packetChan)
tr := Transport{
Conn: conn,
StatelessResetKey: &StatelessResetKey{1, 2, 3, 4},
ConnectionIDLength: connID.Len(),
}
tr.init(true)
defer tr.Close()
phm := NewMockPacketHandlerManager(mockCtrl)
tr.handlerMap = phm
var b []byte
b, err := wire.AppendShortHeader(b, connID, 1337, 2, protocol.KeyPhaseOne)
Expect(err).ToNot(HaveOccurred())
b = append(b, make([]byte, protocol.MinStatelessResetSize-len(b)+1)...)
var token protocol.StatelessResetToken
rand.Read(token[:])
written := make(chan struct{})
gomock.InOrder(
phm.EXPECT().GetByResetToken(gomock.Any()),
phm.EXPECT().Get(connID),
phm.EXPECT().GetStatelessResetToken(connID).Return(token),
conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).Do(func(b []byte, _ net.Addr) {
defer close(written)
Expect(bytes.Contains(b, token[:])).To(BeTrue())
}),
)
packetChan <- packetToRead{data: b}
Eventually(written).Should(BeClosed())
// shutdown
phm.EXPECT().Close(gomock.Any())
close(packetChan)
tr.Close()
})
})