forked from quic-go/quic-go
use Transport.VerifySourceAddress to control the Retry Mechanism (#4362)
* use Transport.VerifySourceAddress to control the Retry Mechanism This can be used to rate-limit handshakes originating from unverified source addresses. Rate-limiting for handshakes can be implemented using the GetConfigForClient callback on the Config. * pass the remote address to Transport.VerifySourceAddress
This commit is contained in:
@@ -11,11 +11,10 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/quicvarint"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
quicproxy "github.com/quic-go/quic-go/integrationtests/tools/proxy"
|
||||
"github.com/quic-go/quic-go/internal/wire"
|
||||
"github.com/quic-go/quic-go/quicvarint"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
@@ -50,7 +49,7 @@ var _ = Describe("Handshake drop tests", func() {
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
tr := &quic.Transport{Conn: conn}
|
||||
if doRetry {
|
||||
tr.MaxUnvalidatedHandshakes = -1
|
||||
tr.VerifySourceAddress = func(net.Addr) bool { return true }
|
||||
}
|
||||
ln, err = tr.Listen(tlsConf, conf)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
@@ -54,15 +54,15 @@ var _ = Describe("Handshake RTT tests", func() {
|
||||
|
||||
// 1 RTT for verifying the source address
|
||||
// 1 RTT for the TLS handshake
|
||||
It("is forward-secure after 2 RTTs", func() {
|
||||
It("is forward-secure after 2 RTTs with Retry", func() {
|
||||
laddr, err := net.ResolveUDPAddr("udp", "localhost:0")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
udpConn, err := net.ListenUDP("udp", laddr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer udpConn.Close()
|
||||
tr := &quic.Transport{
|
||||
Conn: udpConn,
|
||||
MaxUnvalidatedHandshakes: -1,
|
||||
Conn: udpConn,
|
||||
VerifySourceAddress: func(net.Addr) bool { return true },
|
||||
}
|
||||
addTracer(tr)
|
||||
defer tr.Close()
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
@@ -15,7 +14,6 @@ import (
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/qerr"
|
||||
"github.com/quic-go/quic-go/internal/qtls"
|
||||
"github.com/quic-go/quic-go/logging"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
@@ -464,147 +462,6 @@ var _ = Describe("Handshake tests", func() {
|
||||
})
|
||||
})
|
||||
|
||||
Context("limiting handshakes", func() {
|
||||
var conn *net.UDPConn
|
||||
|
||||
BeforeEach(func() {
|
||||
addr, err := net.ResolveUDPAddr("udp", "localhost:0")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
conn, err = net.ListenUDP("udp", addr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
AfterEach(func() { conn.Close() })
|
||||
|
||||
It("sends a Retry when the number of handshakes reaches MaxUnvalidatedHandshakes", func() {
|
||||
const limit = 3
|
||||
tr := &quic.Transport{
|
||||
Conn: conn,
|
||||
MaxUnvalidatedHandshakes: limit,
|
||||
}
|
||||
addTracer(tr)
|
||||
defer tr.Close()
|
||||
|
||||
// Block all handshakes.
|
||||
handshakes := make(chan struct{})
|
||||
var tlsConf tls.Config
|
||||
tlsConf.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) {
|
||||
handshakes <- struct{}{}
|
||||
return getTLSConfig(), nil
|
||||
}
|
||||
ln, err := tr.Listen(&tlsConf, getQuicConfig(nil))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer ln.Close()
|
||||
|
||||
const additional = 2
|
||||
results := make([]struct{ retry, closed atomic.Bool }, limit+additional)
|
||||
// Dial the server from multiple clients. All handshakes will get blocked on the handshakes channel.
|
||||
// Since we're dialing limit+2 times, we expect limit handshakes to go through with a Retry, and
|
||||
// exactly 2 to experience a Retry.
|
||||
for i := 0; i < limit+additional; i++ {
|
||||
go func(index int) {
|
||||
defer GinkgoRecover()
|
||||
quicConf := getQuicConfig(&quic.Config{
|
||||
Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
|
||||
return &logging.ConnectionTracer{
|
||||
ReceivedRetry: func(*logging.Header) { results[index].retry.Store(true) },
|
||||
ClosedConnection: func(error) { results[index].closed.Store(true) },
|
||||
}
|
||||
},
|
||||
})
|
||||
conn, err := quic.DialAddr(context.Background(), ln.Addr().String(), getTLSClientConfig(), quicConf)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
conn.CloseWithError(0, "")
|
||||
}(i)
|
||||
}
|
||||
numRetries := func() (n int) {
|
||||
for i := 0; i < limit+additional; i++ {
|
||||
if results[i].retry.Load() {
|
||||
n++
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
numClosed := func() (n int) {
|
||||
for i := 0; i < limit+2; i++ {
|
||||
if results[i].closed.Load() {
|
||||
n++
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
Eventually(numRetries).Should(Equal(additional))
|
||||
// allow the handshakes to complete
|
||||
for i := 0; i < limit+additional; i++ {
|
||||
Eventually(handshakes).Should(Receive())
|
||||
}
|
||||
Eventually(numClosed).Should(Equal(limit + additional))
|
||||
Expect(numRetries()).To(Equal(additional)) // just to be on the safe side
|
||||
})
|
||||
|
||||
It("rejects connections when the number of handshakes reaches MaxHandshakes", func() {
|
||||
const limit = 3
|
||||
tr := &quic.Transport{
|
||||
Conn: conn,
|
||||
MaxHandshakes: limit,
|
||||
}
|
||||
addTracer(tr)
|
||||
defer tr.Close()
|
||||
|
||||
// Block all handshakes.
|
||||
handshakes := make(chan struct{})
|
||||
var tlsConf tls.Config
|
||||
tlsConf.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) {
|
||||
handshakes <- struct{}{}
|
||||
return getTLSConfig(), nil
|
||||
}
|
||||
ln, err := tr.Listen(&tlsConf, getQuicConfig(nil))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer ln.Close()
|
||||
|
||||
const additional = 2
|
||||
// Dial the server from multiple clients. All handshakes will get blocked on the handshakes channel.
|
||||
// Since we're dialing limit+2 times, we expect limit handshakes to go through with a Retry, and
|
||||
// exactly 2 to experience a Retry.
|
||||
var numSuccessful, numFailed atomic.Int32
|
||||
for i := 0; i < limit+additional; i++ {
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
quicConf := getQuicConfig(&quic.Config{
|
||||
Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
|
||||
return &logging.ConnectionTracer{
|
||||
ReceivedRetry: func(*logging.Header) { Fail("didn't expect any Retry") },
|
||||
}
|
||||
},
|
||||
})
|
||||
conn, err := quic.DialAddr(context.Background(), ln.Addr().String(), getTLSClientConfig(), quicConf)
|
||||
if err != nil {
|
||||
var transportErr *quic.TransportError
|
||||
if !errors.As(err, &transportErr) || transportErr.ErrorCode != qerr.ConnectionRefused {
|
||||
Fail(fmt.Sprintf("expected CONNECTION_REFUSED error, got %v", err))
|
||||
}
|
||||
numFailed.Add(1)
|
||||
return
|
||||
}
|
||||
numSuccessful.Add(1)
|
||||
conn.CloseWithError(0, "")
|
||||
}()
|
||||
}
|
||||
Eventually(func() int { return int(numFailed.Load()) }).Should(Equal(additional))
|
||||
// allow the handshakes to complete
|
||||
for i := 0; i < limit; i++ {
|
||||
Eventually(handshakes).Should(Receive())
|
||||
}
|
||||
Eventually(func() int { return int(numSuccessful.Load()) }).Should(Equal(limit))
|
||||
|
||||
// make sure that the server is reachable again after these handshakes have completed
|
||||
go func() { <-handshakes }() // allow this handshake to complete immediately
|
||||
conn, err := quic.DialAddr(context.Background(), ln.Addr().String(), getTLSClientConfig(), getQuicConfig(nil))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
conn.CloseWithError(0, "")
|
||||
})
|
||||
})
|
||||
|
||||
Context("ALPN", func() {
|
||||
It("negotiates an application protocol", func() {
|
||||
ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig)
|
||||
@@ -718,8 +575,8 @@ var _ = Describe("Handshake tests", func() {
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer udpConn.Close()
|
||||
tr := &quic.Transport{
|
||||
Conn: udpConn,
|
||||
MaxUnvalidatedHandshakes: -1,
|
||||
Conn: udpConn,
|
||||
VerifySourceAddress: func(net.Addr) bool { return true },
|
||||
}
|
||||
addTracer(tr)
|
||||
defer tr.Close()
|
||||
|
||||
@@ -43,7 +43,7 @@ var _ = Describe("MITM test", func() {
|
||||
}
|
||||
addTracer(serverTransport)
|
||||
if forceAddressValidation {
|
||||
serverTransport.MaxUnvalidatedHandshakes = -1
|
||||
serverTransport.VerifySourceAddress = func(net.Addr) bool { return true }
|
||||
}
|
||||
ln, err := serverTransport.Listen(getTLSConfig(), serverConfig)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
@@ -461,8 +461,8 @@ var _ = Describe("0-RTT", func() {
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer udpConn.Close()
|
||||
tr := &quic.Transport{
|
||||
Conn: udpConn,
|
||||
MaxUnvalidatedHandshakes: -1,
|
||||
Conn: udpConn,
|
||||
VerifySourceAddress: func(net.Addr) bool { return true },
|
||||
}
|
||||
addTracer(tr)
|
||||
defer tr.Close()
|
||||
|
||||
Reference in New Issue
Block a user