add Allow0RTT opt in the quic.Config to control 0-RTT on the server side (#3635)

This commit is contained in:
Marten Seemann
2023-01-04 16:18:11 -08:00
committed by GitHub
parent 421893b1c4
commit b52d34008f
12 changed files with 98 additions and 51 deletions

View File

@@ -6,6 +6,7 @@ import (
"errors"
"fmt"
"io"
"math"
"net"
"sync"
"time"
@@ -115,6 +116,7 @@ type cryptoSetup struct {
clientHelloWritten bool
clientHelloWrittenChan chan struct{} // is closed as soon as the ClientHello is written
zeroRTTParametersChan chan<- *wire.TransportParameters
allow0RTT func() bool
rttStats *utils.RTTStats
@@ -195,7 +197,7 @@ func NewCryptoSetupServer(
tp *wire.TransportParameters,
runner handshakeRunner,
tlsConf *tls.Config,
enable0RTT bool,
allow0RTT func() bool,
rttStats *utils.RTTStats,
tracer logging.ConnectionTracer,
logger utils.Logger,
@@ -208,13 +210,14 @@ func NewCryptoSetupServer(
tp,
runner,
tlsConf,
enable0RTT,
allow0RTT != nil,
rttStats,
tracer,
logger,
protocol.PerspectiveServer,
version,
)
cs.allow0RTT = allow0RTT
cs.conn = qtls.Server(newConn(localAddr, remoteAddr, version), cs.tlsConf, cs.extraConf)
return cs
}
@@ -267,7 +270,7 @@ func newCryptoSetup(
}
var maxEarlyData uint32
if enable0RTT {
maxEarlyData = 0xffffffff
maxEarlyData = math.MaxUint32
}
cs.extraConf = &qtls.ExtraConfig{
GetExtensions: extHandler.GetExtensions,
@@ -490,13 +493,17 @@ func (h *cryptoSetup) accept0RTT(sessionTicketData []byte) bool {
return false
}
valid := h.ourParams.ValidFor0RTT(t.Parameters)
if valid {
h.logger.Debugf("Accepting 0-RTT. Restoring RTT from session ticket: %s", t.RTT)
h.rttStats.SetInitialRTT(t.RTT)
} else {
if !valid {
h.logger.Debugf("Transport parameters changed. Rejecting 0-RTT.")
return false
}
return valid
if !h.allow0RTT() {
h.logger.Debugf("0-RTT not allowed. Rejecting 0-RTT.")
return false
}
h.logger.Debugf("Accepting 0-RTT. Restoring RTT from session ticket: %s", t.RTT)
h.rttStats.SetInitialRTT(t.RTT)
return true
}
// rejected0RTT is called for the client when the server rejects 0-RTT.

View File

@@ -95,7 +95,7 @@ var _ = Describe("Crypto Setup TLS", func() {
&wire.TransportParameters{StatelessResetToken: &token},
runner,
testdata.GetTLSConfig(),
false,
nil,
&utils.RTTStats{},
nil,
utils.DefaultLogger.WithPrefix("server"),
@@ -177,7 +177,7 @@ var _ = Describe("Crypto Setup TLS", func() {
&wire.TransportParameters{StatelessResetToken: &token},
runner,
testdata.GetTLSConfig(),
false,
nil,
&utils.RTTStats{},
nil,
utils.DefaultLogger.WithPrefix("server"),
@@ -218,7 +218,7 @@ var _ = Describe("Crypto Setup TLS", func() {
&wire.TransportParameters{StatelessResetToken: &token},
runner,
serverConf,
false,
nil,
&utils.RTTStats{},
nil,
utils.DefaultLogger.WithPrefix("server"),
@@ -253,7 +253,7 @@ var _ = Describe("Crypto Setup TLS", func() {
&wire.TransportParameters{StatelessResetToken: &token},
NewMockHandshakeRunner(mockCtrl),
serverConf,
false,
nil,
&utils.RTTStats{},
nil,
utils.DefaultLogger.WithPrefix("server"),
@@ -378,6 +378,10 @@ var _ = Describe("Crypto Setup TLS", func() {
protocol.VersionTLS,
)
var allow0RTT func() bool
if enable0RTT {
allow0RTT = func() bool { return true }
}
var sHandshakeComplete bool
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
sErrChan := make(chan error, 1)
@@ -398,7 +402,7 @@ var _ = Describe("Crypto Setup TLS", func() {
serverTransportParameters,
sRunner,
serverConf,
enable0RTT,
allow0RTT,
serverRTTStats,
nil,
utils.DefaultLogger.WithPrefix("server"),
@@ -536,7 +540,7 @@ var _ = Describe("Crypto Setup TLS", func() {
sTransportParameters,
sRunner,
serverConf,
false,
nil,
&utils.RTTStats{},
nil,
utils.DefaultLogger.WithPrefix("server"),
@@ -591,7 +595,7 @@ var _ = Describe("Crypto Setup TLS", func() {
&wire.TransportParameters{StatelessResetToken: &token},
sRunner,
serverConf,
false,
nil,
&utils.RTTStats{},
nil,
utils.DefaultLogger.WithPrefix("server"),
@@ -650,7 +654,7 @@ var _ = Describe("Crypto Setup TLS", func() {
&wire.TransportParameters{StatelessResetToken: &token},
sRunner,
serverConf,
false,
nil,
&utils.RTTStats{},
nil,
utils.DefaultLogger.WithPrefix("server"),