Merge pull request #2536 from lucas-clemente/amplification-protection

implement the 3x amplification limit
This commit is contained in:
Marten Seemann
2020-05-27 10:06:32 +07:00
committed by GitHub
15 changed files with 531 additions and 135 deletions

View File

@@ -2,6 +2,7 @@ package self_test
import (
"context"
"crypto/tls"
"fmt"
mrand "math/rand"
"net"
@@ -32,7 +33,7 @@ var _ = Describe("Handshake drop tests", func() {
const timeout = 10 * time.Minute
startListenerAndProxy := func(dropCallback quicproxy.DropCallback, doRetry bool, version protocol.VersionNumber) {
startListenerAndProxy := func(dropCallback quicproxy.DropCallback, doRetry bool, longCertChain bool, version protocol.VersionNumber) {
conf := getQuicConfigForServer(&quic.Config{
MaxIdleTimeout: timeout,
HandshakeTimeout: timeout,
@@ -41,8 +42,14 @@ var _ = Describe("Handshake drop tests", func() {
if !doRetry {
conf.AcceptToken = func(net.Addr, *quic.Token) bool { return true }
}
var tlsConf *tls.Config
if longCertChain {
tlsConf = getTLSConfigWithLongCertChain()
} else {
tlsConf = getTLSConfig()
}
var err error
ln, err = quic.ListenAddr("localhost:0", getTLSConfig(), conf)
ln, err = quic.ListenAddr("localhost:0", tlsConf, conf)
Expect(err).ToNot(HaveOccurred())
serverPort := ln.Addr().(*net.UDPAddr).Port
proxy, err = quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
@@ -184,46 +191,52 @@ var _ = Describe("Handshake drop tests", func() {
}
Context(desc, func() {
for _, a := range []*applicationProtocol{clientSpeaksFirst, serverSpeaksFirst, nobodySpeaks} {
app := a
for _, lcc := range []bool{false, true} {
longCertChain := lcc
Context(app.name, func() {
It(fmt.Sprintf("establishes a connection when the first packet is lost in %s direction", direction), func() {
var incoming, outgoing int32
startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool {
var p int32
switch d {
case quicproxy.DirectionIncoming:
p = atomic.AddInt32(&incoming, 1)
case quicproxy.DirectionOutgoing:
p = atomic.AddInt32(&outgoing, 1)
}
return p == 1 && d.Is(direction)
}, doRetry, version)
app.run(version)
})
Context(fmt.Sprintf("using a long certificate chain: %t", longCertChain), func() {
for _, a := range []*applicationProtocol{clientSpeaksFirst, serverSpeaksFirst, nobodySpeaks} {
app := a
It(fmt.Sprintf("establishes a connection when the second packet is lost in %s direction", direction), func() {
var incoming, outgoing int32
startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool {
var p int32
switch d {
case quicproxy.DirectionIncoming:
p = atomic.AddInt32(&incoming, 1)
case quicproxy.DirectionOutgoing:
p = atomic.AddInt32(&outgoing, 1)
}
return p == 2 && d.Is(direction)
}, doRetry, version)
app.run(version)
})
Context(app.name, func() {
It(fmt.Sprintf("establishes a connection when the first packet is lost in %s direction", direction), func() {
var incoming, outgoing int32
startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool {
var p int32
switch d {
case quicproxy.DirectionIncoming:
p = atomic.AddInt32(&incoming, 1)
case quicproxy.DirectionOutgoing:
p = atomic.AddInt32(&outgoing, 1)
}
return p == 1 && d.Is(direction)
}, doRetry, longCertChain, version)
app.run(version)
})
It(fmt.Sprintf("establishes a connection when 1/3 of the packets are lost in %s direction", direction), func() {
startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool {
return d.Is(direction) && stochasticDropper(3)
}, doRetry, version)
app.run(version)
})
It(fmt.Sprintf("establishes a connection when the second packet is lost in %s direction", direction), func() {
var incoming, outgoing int32
startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool {
var p int32
switch d {
case quicproxy.DirectionIncoming:
p = atomic.AddInt32(&incoming, 1)
case quicproxy.DirectionOutgoing:
p = atomic.AddInt32(&outgoing, 1)
}
return p == 2 && d.Is(direction)
}, doRetry, longCertChain, version)
app.run(version)
})
It(fmt.Sprintf("establishes a connection when 1/3 of the packets are lost in %s direction", direction), func() {
startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool {
return d.Is(direction) && stochasticDropper(3)
}, doRetry, longCertChain, version)
app.run(version)
})
})
}
})
}
})

View File

@@ -51,14 +51,12 @@ var _ = Describe("Handshake tests", func() {
server quic.Listener
serverConfig *quic.Config
acceptStopped chan struct{}
tlsServerConf *tls.Config
)
BeforeEach(func() {
server = nil
acceptStopped = make(chan struct{})
serverConfig = getQuicConfigForServer(nil)
tlsServerConf = getTLSConfig()
})
AfterEach(func() {
@@ -68,10 +66,10 @@ var _ = Describe("Handshake tests", func() {
}
})
runServer := func() quic.Listener {
runServer := func(tlsConf *tls.Config) {
var err error
// start the server
server, err = quic.ListenAddr("localhost:0", tlsServerConf, serverConfig)
server, err = quic.ListenAddr("localhost:0", tlsConf, serverConfig)
Expect(err).ToNot(HaveOccurred())
go func() {
@@ -83,7 +81,6 @@ var _ = Describe("Handshake tests", func() {
}
}
}()
return server
}
if !israce.Enabled {
@@ -103,7 +100,7 @@ var _ = Describe("Handshake tests", func() {
// the server doesn't support the highest supported version, which is the first one the client will try
// but it supports a bunch of versions that the client doesn't speak
serverConfig.Versions = []protocol.VersionNumber{7, 8, protocol.SupportedVersions[0], 9}
server := runServer()
runServer(getTLSConfig())
defer server.Close()
sess, err := quic.DialAddr(
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
@@ -119,7 +116,7 @@ var _ = Describe("Handshake tests", func() {
// the server doesn't support the highest supported version, which is the first one the client will try
// but it supports a bunch of versions that the client doesn't speak
serverConfig.Versions = supportedVersions
server := runServer()
runServer(getTLSConfig())
defer server.Close()
sess, err := quic.DialAddr(
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
@@ -145,9 +142,11 @@ var _ = Describe("Handshake tests", func() {
suiteID := id
It(fmt.Sprintf("using %s", name), func() {
tlsServerConf.CipherSuites = []uint16{suiteID}
ln, err := quic.ListenAddr("localhost:0", tlsServerConf, serverConfig)
tlsConf := getTLSConfig()
tlsConf.CipherSuites = []uint16{suiteID}
ln, err := quic.ListenAddr("localhost:0", tlsConf, serverConfig)
Expect(err).ToNot(HaveOccurred())
defer ln.Close()
go func() {
defer GinkgoRecover()
@@ -177,7 +176,7 @@ var _ = Describe("Handshake tests", func() {
}
})
Context("Certifiate validation", func() {
Context("Certificate validation", func() {
for _, v := range protocol.SupportedVersions {
version := v
@@ -189,11 +188,8 @@ var _ = Describe("Handshake tests", func() {
clientConfig = getQuicConfigForClient(&quic.Config{Versions: []protocol.VersionNumber{version}})
})
JustBeforeEach(func() {
runServer()
})
It("accepts the certificate", func() {
runServer(getTLSConfig())
_, err := quic.DialAddr(
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
getTLSClientConfig(),
@@ -202,7 +198,18 @@ var _ = Describe("Handshake tests", func() {
Expect(err).ToNot(HaveOccurred())
})
It("works with a long certificate chain", func() {
runServer(getTLSConfigWithLongCertChain())
_, err := quic.DialAddr(
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
getTLSClientConfig(),
getQuicConfigForClient(&quic.Config{Versions: []protocol.VersionNumber{version}}),
)
Expect(err).ToNot(HaveOccurred())
})
It("errors if the server name doesn't match", func() {
runServer(getTLSConfig())
_, err := quic.DialAddr(
fmt.Sprintf("127.0.0.1:%d", server.Addr().(*net.UDPAddr).Port),
getTLSClientConfig(),
@@ -212,7 +219,10 @@ var _ = Describe("Handshake tests", func() {
})
It("fails the handshake if the client fails to provide the requested client cert", func() {
tlsServerConf.ClientAuth = tls.RequireAndVerifyClientCert
tlsConf := getTLSConfig()
tlsConf.ClientAuth = tls.RequireAndVerifyClientCert
runServer(tlsConf)
sess, err := quic.DialAddr(
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
getTLSClientConfig(),
@@ -234,6 +244,7 @@ var _ = Describe("Handshake tests", func() {
})
It("uses the ServerName in the tls.Config", func() {
runServer(getTLSConfig())
tlsConf := getTLSClientConfig()
tlsConf.ServerName = "localhost"
_, err := quic.DialAddr(
@@ -350,7 +361,7 @@ var _ = Describe("Handshake tests", func() {
Context("ALPN", func() {
It("negotiates an application protocol", func() {
ln, err := quic.ListenAddr("localhost:0", tlsServerConf, serverConfig)
ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig)
Expect(err).ToNot(HaveOccurred())
done := make(chan struct{})
@@ -379,7 +390,7 @@ var _ = Describe("Handshake tests", func() {
})
It("errors if application protocol negotiation fails", func() {
server := runServer()
runServer(getTLSConfig())
tlsConf := getTLSClientConfig()
tlsConf.NextProtos = []string{"foobar"}
@@ -391,7 +402,6 @@ var _ = Describe("Handshake tests", func() {
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("CRYPTO_ERROR"))
Expect(err.Error()).To(ContainSubstring("no application protocol"))
Expect(server.Close()).To(Succeed())
})
})

View File

@@ -3,19 +3,23 @@ package self_test
import (
"bufio"
"bytes"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"flag"
"fmt"
"io"
"log"
"math/rand"
"math/big"
mrand "math/rand"
"os"
"sync"
"testing"
"time"
"github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/internal/testdata"
"github.com/lucas-clemente/quic-go/internal/utils"
. "github.com/onsi/ginkgo"
@@ -24,19 +28,6 @@ import (
const alpn = "quic-go integration tests"
func getTLSConfig() *tls.Config {
conf := testdata.GetTLSConfig()
conf.NextProtos = []string{alpn}
return conf
}
func getTLSClientConfig() *tls.Config {
return &tls.Config{
RootCAs: testdata.GetRootCA(),
NextProtos: []string{alpn},
}
}
const (
dataLen = 500 * 1024 // 500 KB
dataLenLong = 50 * 1024 * 1024 // 50 MB
@@ -93,6 +84,10 @@ var (
logBufOnce sync.Once
logBuf *syncedBuffer
enableQlog bool
tlsConfig *tls.Config
tlsConfigLongChain *tls.Config
tlsClientConfig *tls.Config
)
// read the logfile command line flag
@@ -100,6 +95,151 @@ var (
func init() {
flag.StringVar(&logFileName, "logfile", "", "log file")
flag.BoolVar(&enableQlog, "qlog", false, "enable qlog")
ca, caPrivateKey, err := generateCA()
if err != nil {
panic(err)
}
leafCert, leafPrivateKey, err := generateLeafCert(ca, caPrivateKey)
if err != nil {
panic(err)
}
tlsConfig = &tls.Config{
Certificates: []tls.Certificate{tls.Certificate{
Certificate: [][]byte{leafCert.Raw},
PrivateKey: leafPrivateKey,
}},
NextProtos: []string{alpn},
}
tlsConfLongChain, err := generateTLSConfigWithLongCertChain(ca, caPrivateKey)
if err != nil {
panic(err)
}
tlsConfigLongChain = tlsConfLongChain
root := x509.NewCertPool()
root.AddCert(ca)
tlsClientConfig = &tls.Config{
RootCAs: root,
NextProtos: []string{alpn},
}
}
func generateCA() (*x509.Certificate, *rsa.PrivateKey, error) {
certTempl := &x509.Certificate{
SerialNumber: big.NewInt(2019),
Subject: pkix.Name{},
NotBefore: time.Now(),
NotAfter: time.Now().Add(24 * time.Hour),
IsCA: true,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
BasicConstraintsValid: true,
}
caPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return nil, nil, err
}
caBytes, err := x509.CreateCertificate(rand.Reader, certTempl, certTempl, &caPrivateKey.PublicKey, caPrivateKey)
if err != nil {
return nil, nil, err
}
ca, err := x509.ParseCertificate(caBytes)
if err != nil {
return nil, nil, err
}
return ca, caPrivateKey, nil
}
func generateLeafCert(ca *x509.Certificate, caPrivateKey *rsa.PrivateKey) (*x509.Certificate, *rsa.PrivateKey, error) {
certTempl := &x509.Certificate{
SerialNumber: big.NewInt(1),
DNSNames: []string{"localhost"},
NotBefore: time.Now(),
NotAfter: time.Now().Add(24 * time.Hour),
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
KeyUsage: x509.KeyUsageDigitalSignature,
}
privKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return nil, nil, err
}
certBytes, err := x509.CreateCertificate(rand.Reader, certTempl, ca, &privKey.PublicKey, caPrivateKey)
if err != nil {
return nil, nil, err
}
cert, err := x509.ParseCertificate(certBytes)
if err != nil {
return nil, nil, err
}
return cert, privKey, nil
}
// getTLSConfigWithLongCertChain generates a tls.Config that uses a long certificate chain.
// The Root CA used is the same as for the config returned from getTLSConfig().
func generateTLSConfigWithLongCertChain(ca *x509.Certificate, caPrivateKey *rsa.PrivateKey) (*tls.Config, error) {
const chainLen = 7
certTempl := &x509.Certificate{
SerialNumber: big.NewInt(2019),
Subject: pkix.Name{},
NotBefore: time.Now(),
NotAfter: time.Now().Add(24 * time.Hour),
IsCA: true,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
BasicConstraintsValid: true,
}
lastCA := ca
lastCAPrivKey := caPrivateKey
privKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return nil, err
}
certs := make([]*x509.Certificate, chainLen)
for i := 0; i < chainLen; i++ {
caBytes, err := x509.CreateCertificate(rand.Reader, certTempl, lastCA, &privKey.PublicKey, lastCAPrivKey)
if err != nil {
return nil, err
}
ca, err := x509.ParseCertificate(caBytes)
if err != nil {
return nil, err
}
certs[i] = ca
lastCA = ca
lastCAPrivKey = privKey
}
leafCert, leafPrivateKey, err := generateLeafCert(lastCA, lastCAPrivKey)
if err != nil {
return nil, err
}
rawCerts := make([][]byte, chainLen+1)
for i, cert := range certs {
rawCerts[chainLen-i] = cert.Raw
}
rawCerts[0] = leafCert.Raw
return &tls.Config{
Certificates: []tls.Certificate{tls.Certificate{
Certificate: rawCerts,
PrivateKey: leafPrivateKey,
}},
NextProtos: []string{alpn},
}, nil
}
func getTLSConfig() *tls.Config {
return tlsConfig.Clone()
}
func getTLSConfigWithLongCertChain() *tls.Config {
return tlsConfigLongChain.Clone()
}
func getTLSClientConfig() *tls.Config {
return tlsClientConfig.Clone()
}
func getQuicConfigForClient(conf *quic.Config) *quic.Config {
@@ -163,5 +303,5 @@ func TestSelf(t *testing.T) {
}
var _ = BeforeSuite(func() {
rand.Seed(GinkgoRandomSeed())
mrand.Seed(GinkgoRandomSeed())
})

View File

@@ -25,12 +25,14 @@ type SentPacketHandler interface {
// SentPacket may modify the packet
SentPacket(packet *Packet)
ReceivedAck(ackFrame *wire.AckFrame, encLevel protocol.EncryptionLevel, recvTime time.Time) error
ReceivedBytes(protocol.ByteCount)
DropPackets(protocol.EncryptionLevel)
ResetForRetry() error
SetHandshakeComplete()
// The SendMode determines if and what kind of packets can be sent.
SendMode() SendMode
AmplificationWindow() protocol.ByteCount
// TimeUntilSend is the time when the next packet should be sent.
// It is used for pacing packets.
TimeUntilSend() time.Time
@@ -56,6 +58,7 @@ type SentPacketHandler interface {
type sentPacketTracker interface {
GetLowestPacketNotConfirmedAcked() protocol.PacketNumber
ReceivedPacket(protocol.EncryptionLevel)
}
// ReceivedPacketHandler handles ACKs needed to send for incoming packets

View File

@@ -47,3 +47,15 @@ func (mr *MockSentPacketTrackerMockRecorder) GetLowestPacketNotConfirmedAcked()
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLowestPacketNotConfirmedAcked", reflect.TypeOf((*MockSentPacketTracker)(nil).GetLowestPacketNotConfirmedAcked))
}
// ReceivedPacket mocks base method
func (m *MockSentPacketTracker) ReceivedPacket(arg0 protocol.EncryptionLevel) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "ReceivedPacket", arg0)
}
// ReceivedPacket indicates an expected call of ReceivedPacket
func (mr *MockSentPacketTrackerMockRecorder) ReceivedPacket(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedPacket", reflect.TypeOf((*MockSentPacketTracker)(nil).ReceivedPacket), arg0)
}

View File

@@ -64,6 +64,7 @@ func (h *receivedPacketHandler) ReceivedPacket(
rcvTime time.Time,
shouldInstigateAck bool,
) error {
h.sentPackets.ReceivedPacket(encLevel)
switch encLevel {
case protocol.EncryptionInitial:
h.initialPackets.ReceivedPacket(pn, rcvTime, shouldInstigateAck)

View File

@@ -3,6 +3,8 @@ package ackhandler
import (
"time"
"github.com/golang/mock/gomock"
"github.com/lucas-clemente/quic-go/internal/congestion"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
@@ -29,6 +31,9 @@ var _ = Describe("Received Packet Handler", func() {
It("generates ACKs for different packet number spaces", func() {
sentPackets.EXPECT().GetLowestPacketNotConfirmedAcked().AnyTimes()
sendTime := time.Now().Add(-time.Second)
sentPackets.EXPECT().ReceivedPacket(protocol.EncryptionInitial).Times(2)
sentPackets.EXPECT().ReceivedPacket(protocol.EncryptionHandshake).Times(2)
sentPackets.EXPECT().ReceivedPacket(protocol.Encryption1RTT).Times(2)
Expect(handler.ReceivedPacket(2, protocol.EncryptionInitial, sendTime, true)).To(Succeed())
Expect(handler.ReceivedPacket(1, protocol.EncryptionHandshake, sendTime, true)).To(Succeed())
Expect(handler.ReceivedPacket(5, protocol.Encryption1RTT, sendTime, true)).To(Succeed())
@@ -54,6 +59,8 @@ var _ = Describe("Received Packet Handler", func() {
It("uses the same packet number space for 0-RTT and 1-RTT packets", func() {
sentPackets.EXPECT().GetLowestPacketNotConfirmedAcked().AnyTimes()
sentPackets.EXPECT().ReceivedPacket(protocol.Encryption0RTT)
sentPackets.EXPECT().ReceivedPacket(protocol.Encryption1RTT)
sendTime := time.Now().Add(-time.Second)
Expect(handler.ReceivedPacket(2, protocol.Encryption0RTT, sendTime, true)).To(Succeed())
Expect(handler.ReceivedPacket(3, protocol.Encryption1RTT, sendTime, true)).To(Succeed())
@@ -64,6 +71,7 @@ var _ = Describe("Received Packet Handler", func() {
})
It("rejects 0-RTT packets with higher packet numbers than 1-RTT packets", func() {
sentPackets.EXPECT().ReceivedPacket(gomock.Any()).Times(3)
sentPackets.EXPECT().GetLowestPacketNotConfirmedAcked().AnyTimes()
sendTime := time.Now()
Expect(handler.ReceivedPacket(10, protocol.Encryption0RTT, sendTime, true)).To(Succeed())
@@ -72,6 +80,7 @@ var _ = Describe("Received Packet Handler", func() {
})
It("allows reordered 0-RTT packets", func() {
sentPackets.EXPECT().ReceivedPacket(gomock.Any()).Times(3)
sentPackets.EXPECT().GetLowestPacketNotConfirmedAcked().AnyTimes()
sendTime := time.Now()
Expect(handler.ReceivedPacket(10, protocol.Encryption0RTT, sendTime, true)).To(Succeed())
@@ -80,6 +89,7 @@ var _ = Describe("Received Packet Handler", func() {
})
It("drops Initial packets", func() {
sentPackets.EXPECT().ReceivedPacket(gomock.Any()).Times(2)
sendTime := time.Now().Add(-time.Second)
Expect(handler.ReceivedPacket(2, protocol.EncryptionInitial, sendTime, true)).To(Succeed())
Expect(handler.ReceivedPacket(1, protocol.EncryptionHandshake, sendTime, true)).To(Succeed())
@@ -90,6 +100,7 @@ var _ = Describe("Received Packet Handler", func() {
})
It("drops Handshake packets", func() {
sentPackets.EXPECT().ReceivedPacket(gomock.Any()).Times(2)
sentPackets.EXPECT().GetLowestPacketNotConfirmedAcked().AnyTimes()
sendTime := time.Now().Add(-time.Second)
Expect(handler.ReceivedPacket(1, protocol.EncryptionHandshake, sendTime, true)).To(Succeed())
@@ -105,6 +116,7 @@ var _ = Describe("Received Packet Handler", func() {
})
It("drops old ACK ranges", func() {
sentPackets.EXPECT().ReceivedPacket(gomock.Any()).AnyTimes()
sendTime := time.Now()
sentPackets.EXPECT().GetLowestPacketNotConfirmedAcked().Times(2)
Expect(handler.ReceivedPacket(1, protocol.Encryption1RTT, sendTime, true)).To(Succeed())

View File

@@ -20,6 +20,8 @@ const (
timeThreshold = 9.0 / 8
// Maximum reordering in packets before packet threshold loss detection considers a packet lost.
packetThreshold = 3
// Before validating the client's address, the server won't send more than 3x bytes than it received.
amplificationFactor = 3
)
type packetNumberSpace struct {
@@ -49,8 +51,16 @@ type sentPacketHandler struct {
handshakePackets *packetNumberSpace
appDataPackets *packetNumberSpace
// Do we know that the peer completed address validation yet?
// Always true for the server.
peerCompletedAddressValidation bool
handshakeComplete bool
bytesReceived protocol.ByteCount
bytesSent protocol.ByteCount
// Have we validated the peer's address yet?
// Always true for the client.
peerAddressValidated bool
handshakeComplete bool
// lowestNotConfirmedAcked is the lowest packet number that we sent an ACK for, but haven't received confirmation, that this ACK actually arrived
// example: we send an ACK for packets 90-100 with packet number 20
@@ -99,6 +109,7 @@ func newSentPacketHandler(
return &sentPacketHandler{
peerCompletedAddressValidation: pers == protocol.PerspectiveServer,
peerAddressValidated: pers == protocol.PerspectiveClient,
initialPackets: newPacketNumberSpace(initialPacketNumber),
handshakePackets: newPacketNumberSpace(0),
appDataPackets: newPacketNumberSpace(0),
@@ -168,6 +179,16 @@ func (h *sentPacketHandler) dropPackets(encLevel protocol.EncryptionLevel) {
h.ptoMode = SendNone
}
func (h *sentPacketHandler) ReceivedBytes(n protocol.ByteCount) {
h.bytesReceived += n
}
func (h *sentPacketHandler) ReceivedPacket(encLevel protocol.EncryptionLevel) {
if h.perspective == protocol.PerspectiveServer && encLevel == protocol.EncryptionHandshake {
h.peerAddressValidated = true
}
}
func (h *sentPacketHandler) packetsInFlight() int {
packetsInFlight := h.appDataPackets.history.Len()
if h.handshakePackets != nil {
@@ -180,6 +201,7 @@ func (h *sentPacketHandler) packetsInFlight() int {
}
func (h *sentPacketHandler) SentPacket(packet *Packet) {
h.bytesSent += packet.Length
// For the client, drop the Initial packet number space when the first Handshake packet is sent.
if h.perspective == protocol.PerspectiveClient && packet.EncryptionLevel == protocol.EncryptionHandshake && h.initialPackets != nil {
h.dropPackets(protocol.EncryptionInitial)
@@ -638,6 +660,10 @@ func (h *sentPacketHandler) SendMode() SendMode {
numTrackedPackets += h.handshakePackets.history.Len()
}
if h.AmplificationWindow() == 0 {
h.logger.Debugf("Amplification window limited. Received %d bytes, already sent out %d bytes", h.bytesReceived, h.bytesSent)
return SendNone
}
// Don't send any packets if we're keeping track of the maximum number of packets.
// Note that since MaxOutstandingSentPackets is smaller than MaxTrackedSentPackets,
// we will stop sending out new data when reaching MaxOutstandingSentPackets,
@@ -683,6 +709,16 @@ func (h *sentPacketHandler) ShouldSendNumPackets() int {
return int(math.Ceil(float64(protocol.MinPacingDelay) / float64(delay)))
}
func (h *sentPacketHandler) AmplificationWindow() protocol.ByteCount {
if h.peerAddressValidated {
return protocol.MaxByteCount
}
if h.bytesSent >= amplificationFactor*h.bytesReceived {
return 0
}
return amplificationFactor*h.bytesReceived - h.bytesSent
}
func (h *sentPacketHandler) QueueProbePacket(encLevel protocol.EncryptionLevel) bool {
pnSpace := h.getPacketNumberSpace(encLevel)
p := pnSpace.history.FirstOutstanding()

View File

@@ -495,13 +495,51 @@ var _ = Describe("SentPacketHandler", func() {
Expect(handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now())).To(Succeed())
})
It("passes the bytes in flight to CanSend", func() {
handler.bytesInFlight = 42
cong.EXPECT().CanSend(protocol.ByteCount(42))
It("passes the bytes in flight to the congestion controller", func() {
handler.ReceivedPacket(protocol.EncryptionHandshake)
cong.EXPECT().OnPacketSent(gomock.Any(), protocol.ByteCount(42), gomock.Any(), protocol.ByteCount(42), true)
cong.EXPECT().TimeUntilSend(gomock.Any())
handler.SentPacket(&Packet{
Length: 42,
EncryptionLevel: protocol.EncryptionInitial,
Frames: []Frame{{Frame: &wire.PingFrame{}}},
SendTime: time.Now(),
})
cong.EXPECT().CanSend(protocol.ByteCount(42)).Return(true)
handler.SendMode()
})
It("returns SendNone if limited by the 3x limit", func() {
handler.ReceivedBytes(100)
cong.EXPECT().OnPacketSent(gomock.Any(), protocol.ByteCount(300), gomock.Any(), protocol.ByteCount(300), true)
cong.EXPECT().TimeUntilSend(gomock.Any())
handler.SentPacket(&Packet{
Length: 300,
EncryptionLevel: protocol.EncryptionInitial,
Frames: []Frame{{Frame: &wire.PingFrame{}}},
SendTime: time.Now(),
})
cong.EXPECT().CanSend(protocol.ByteCount(300)).Return(true).AnyTimes()
Expect(handler.AmplificationWindow()).To(BeZero())
Expect(handler.SendMode()).To(Equal(SendNone))
})
It("limits the window to 3x the bytes received, to avoid amplification attacks", func() {
handler.ReceivedPacket(protocol.EncryptionInitial) // receiving an Initial packet doesn't validate the client's address
cong.EXPECT().OnPacketSent(gomock.Any(), protocol.ByteCount(50), gomock.Any(), protocol.ByteCount(50), true)
cong.EXPECT().TimeUntilSend(gomock.Any())
handler.SentPacket(&Packet{
Length: 50,
EncryptionLevel: protocol.EncryptionInitial,
Frames: []Frame{{Frame: &wire.PingFrame{}}},
SendTime: time.Now(),
})
handler.ReceivedBytes(100)
Expect(handler.AmplificationWindow()).To(Equal(protocol.ByteCount(3*100 - 50)))
})
It("allows sending of ACKs when congestion limited", func() {
handler.ReceivedPacket(protocol.EncryptionHandshake)
cong.EXPECT().CanSend(gomock.Any()).Return(true)
Expect(handler.SendMode()).To(Equal(SendAny))
cong.EXPECT().CanSend(gomock.Any()).Return(false)
@@ -509,6 +547,7 @@ var _ = Describe("SentPacketHandler", func() {
})
It("allows sending of ACKs when we're keeping track of MaxOutstandingSentPackets packets", func() {
handler.ReceivedPacket(protocol.EncryptionHandshake)
cong.EXPECT().CanSend(gomock.Any()).Return(true).AnyTimes()
cong.EXPECT().TimeUntilSend(gomock.Any()).AnyTimes()
cong.EXPECT().OnPacketSent(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
@@ -521,6 +560,7 @@ var _ = Describe("SentPacketHandler", func() {
})
It("allows PTOs, even when congestion limited", func() {
handler.ReceivedPacket(protocol.EncryptionHandshake)
// note that we don't EXPECT a call to GetCongestionWindow
// that means retransmissions are sent without considering the congestion window
handler.numProbesToSend = 1
@@ -561,6 +601,7 @@ var _ = Describe("SentPacketHandler", func() {
})
It("doesn't set an alarm if there are no outstanding packets", func() {
handler.ReceivedPacket(protocol.EncryptionHandshake)
handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 10}))
handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 11}))
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 10, Largest: 11}}}
@@ -569,6 +610,7 @@ var _ = Describe("SentPacketHandler", func() {
})
It("does nothing on OnAlarm if there are no outstanding packets", func() {
handler.ReceivedPacket(protocol.EncryptionHandshake)
Expect(handler.OnLossDetectionTimeout()).To(Succeed())
Expect(handler.SendMode()).To(Equal(SendAny))
})
@@ -602,6 +644,7 @@ var _ = Describe("SentPacketHandler", func() {
})
It("reset the PTO count when receiving an ACK", func() {
handler.ReceivedPacket(protocol.EncryptionHandshake)
now := time.Now()
handler.SetHandshakeComplete()
handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, SendTime: now.Add(-time.Minute)}))
@@ -615,6 +658,7 @@ var _ = Describe("SentPacketHandler", func() {
})
It("resets the PTO mode and PTO count when a packet number space is dropped", func() {
handler.ReceivedPacket(protocol.EncryptionHandshake)
now := time.Now()
handler.SentPacket(ackElicitingPacket(&Packet{
PacketNumber: 1,
@@ -638,6 +682,7 @@ var _ = Describe("SentPacketHandler", func() {
})
It("allows two 1-RTT PTOs", func() {
handler.ReceivedPacket(protocol.EncryptionHandshake)
handler.SetHandshakeComplete()
var lostPackets []protocol.PacketNumber
handler.SentPacket(ackElicitingPacket(&Packet{
@@ -657,6 +702,7 @@ var _ = Describe("SentPacketHandler", func() {
})
It("only counts ack-eliciting packets as probe packets", func() {
handler.ReceivedPacket(protocol.EncryptionHandshake)
handler.SetHandshakeComplete()
handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, SendTime: time.Now().Add(-time.Hour)}))
Expect(handler.OnLossDetectionTimeout()).To(Succeed())
@@ -672,7 +718,8 @@ var _ = Describe("SentPacketHandler", func() {
Expect(handler.SendMode()).ToNot(Equal(SendPTOAppData))
})
It("gets two probe packets if RTO expires", func() {
It("gets two probe packets if PTO expires", func() {
handler.ReceivedPacket(protocol.EncryptionHandshake)
handler.SetHandshakeComplete()
handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1}))
handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 2}))
@@ -698,6 +745,7 @@ var _ = Describe("SentPacketHandler", func() {
})
It("gets two probe packets if PTO expires, for Handshake packets", func() {
handler.ReceivedPacket(protocol.EncryptionHandshake)
handler.SentPacket(initialPacket(&Packet{PacketNumber: 1}))
handler.SentPacket(initialPacket(&Packet{PacketNumber: 2}))
@@ -714,6 +762,7 @@ var _ = Describe("SentPacketHandler", func() {
})
It("doesn't send 1-RTT probe packets before the handshake completes", func() {
handler.ReceivedPacket(protocol.EncryptionHandshake)
handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1}))
updateRTT(time.Hour)
Expect(handler.OnLossDetectionTimeout()).To(Succeed()) // TLP
@@ -726,6 +775,7 @@ var _ = Describe("SentPacketHandler", func() {
})
It("resets the send mode when it receives an acknowledgement after queueing probe packets", func() {
handler.ReceivedPacket(protocol.EncryptionHandshake)
handler.SetHandshakeComplete()
handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, SendTime: time.Now().Add(-time.Hour)}))
handler.rttStats.UpdateRTT(time.Second, 0, time.Now())
@@ -737,6 +787,7 @@ var _ = Describe("SentPacketHandler", func() {
})
It("handles ACKs for the original packet", func() {
handler.ReceivedPacket(protocol.EncryptionHandshake)
handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 5, SendTime: time.Now().Add(-time.Hour)}))
handler.rttStats.UpdateRTT(time.Second, 0, time.Now())
Expect(handler.OnLossDetectionTimeout()).To(Succeed())
@@ -993,6 +1044,7 @@ var _ = Describe("SentPacketHandler", func() {
})
It("cancels the PTO when dropping a packet number space", func() {
handler.ReceivedPacket(protocol.EncryptionHandshake)
now := time.Now()
handler.SentPacket(handshakePacket(&Packet{PacketNumber: 1, SendTime: now.Add(-time.Minute)}))
handler.SentPacket(handshakePacket(&Packet{PacketNumber: 2, SendTime: now.Add(-time.Minute)}))
@@ -1028,12 +1080,15 @@ var _ = Describe("SentPacketHandler", func() {
})
})
Context("resetting for retry", func() {
Context("for the client", func() {
BeforeEach(func() {
perspective = protocol.PerspectiveClient
})
It("queues outstanding packets for retransmission, cancels alarms and resets PTO count", func() {
It("considers the server's address validated right away", func() {
})
It("queues outstanding packets for retransmission, cancels alarms and resets PTO count when receiving a Retry", func() {
handler.SentPacket(initialPacket(&Packet{PacketNumber: 42}))
Expect(handler.GetLossDetectionTimeout()).ToNot(BeZero())
Expect(handler.bytesInFlight).ToNot(BeZero())
@@ -1047,7 +1102,7 @@ var _ = Describe("SentPacketHandler", func() {
Expect(handler.ptoCount).To(BeZero())
})
It("queues outstanding frames for retransmission and cancels alarms", func() {
It("queues outstanding frames for retransmission and cancels alarms when receiving a Retry", func() {
var lostInitial, lost0RTT bool
handler.SentPacket(&Packet{
PacketNumber: 13,

View File

@@ -38,6 +38,20 @@ func (m *MockSentPacketHandler) EXPECT() *MockSentPacketHandlerMockRecorder {
return m.recorder
}
// AmplificationWindow mocks base method
func (m *MockSentPacketHandler) AmplificationWindow() protocol.ByteCount {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AmplificationWindow")
ret0, _ := ret[0].(protocol.ByteCount)
return ret0
}
// AmplificationWindow indicates an expected call of AmplificationWindow
func (mr *MockSentPacketHandlerMockRecorder) AmplificationWindow() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AmplificationWindow", reflect.TypeOf((*MockSentPacketHandler)(nil).AmplificationWindow))
}
// DropPackets mocks base method
func (m *MockSentPacketHandler) DropPackets(arg0 protocol.EncryptionLevel) {
m.ctrl.T.Helper()
@@ -149,6 +163,18 @@ func (mr *MockSentPacketHandlerMockRecorder) ReceivedAck(arg0, arg1, arg2 interf
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedAck", reflect.TypeOf((*MockSentPacketHandler)(nil).ReceivedAck), arg0, arg1, arg2)
}
// ReceivedBytes mocks base method
func (m *MockSentPacketHandler) ReceivedBytes(arg0 protocol.ByteCount) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "ReceivedBytes", arg0)
}
// ReceivedBytes indicates an expected call of ReceivedBytes
func (mr *MockSentPacketHandlerMockRecorder) ReceivedBytes(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedBytes", reflect.TypeOf((*MockSentPacketHandler)(nil).ReceivedBytes), arg0)
}
// ResetForRetry mocks base method
func (m *MockSentPacketHandler) ResetForRetry() error {
m.ctrl.T.Helper()

View File

@@ -79,18 +79,18 @@ func (mr *MockPackerMockRecorder) MaybePackProbePacket(arg0 interface{}) *gomock
}
// PackCoalescedPacket mocks base method
func (m *MockPacker) PackCoalescedPacket() (*coalescedPacket, error) {
func (m *MockPacker) PackCoalescedPacket(arg0 protocol.ByteCount) (*coalescedPacket, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "PackCoalescedPacket")
ret := m.ctrl.Call(m, "PackCoalescedPacket", arg0)
ret0, _ := ret[0].(*coalescedPacket)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// PackCoalescedPacket indicates an expected call of PackCoalescedPacket
func (mr *MockPackerMockRecorder) PackCoalescedPacket() *gomock.Call {
func (mr *MockPackerMockRecorder) PackCoalescedPacket(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackCoalescedPacket", reflect.TypeOf((*MockPacker)(nil).PackCoalescedPacket))
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackCoalescedPacket", reflect.TypeOf((*MockPacker)(nil).PackCoalescedPacket), arg0)
}
// PackConnectionClose mocks base method

View File

@@ -17,7 +17,7 @@ import (
)
type packer interface {
PackCoalescedPacket() (*coalescedPacket, error)
PackCoalescedPacket(protocol.ByteCount) (*coalescedPacket, error)
PackPacket() (*packedPacket, error)
MaybePackProbePacket(protocol.EncryptionLevel) (*packedPacket, error)
MaybePackAckPacket(handshakeConfirmed bool) (*packedPacket, error)
@@ -323,14 +323,14 @@ func (p *packetPacker) padPacket(buffer *packetBuffer) {
// PackCoalescedPacket packs a new packet.
// It packs an Initial / Handshake if there is data to send in these packet number spaces.
// It should only be called before the handshake is confirmed.
func (p *packetPacker) PackCoalescedPacket() (*coalescedPacket, error) {
func (p *packetPacker) PackCoalescedPacket(maxPacketSize protocol.ByteCount) (*coalescedPacket, error) {
buffer := getPacketBuffer()
packet, err := p.packCoalescedPacket(buffer)
packet, err := p.packCoalescedPacket(buffer, maxPacketSize)
if err != nil {
return nil, err
}
if len(packet.packets) == 0 { // nothing to send
if packet == nil || len(packet.packets) == 0 { // nothing to send
buffer.Release()
return nil, nil
}
@@ -342,37 +342,45 @@ func (p *packetPacker) PackCoalescedPacket() (*coalescedPacket, error) {
return packet, nil
}
func (p *packetPacker) packCoalescedPacket(buffer *packetBuffer) (*coalescedPacket, error) {
func (p *packetPacker) packCoalescedPacket(buffer *packetBuffer, maxPacketSize protocol.ByteCount) (*coalescedPacket, error) {
maxPacketSize = utils.MinByteCount(maxPacketSize, p.maxPacketSize)
if p.perspective == protocol.PerspectiveClient {
maxPacketSize = protocol.MinInitialPacketSize
}
if maxPacketSize < protocol.MinCoalescedPacketSize {
return nil, nil
}
packet := &coalescedPacket{
buffer: buffer,
packets: make([]*packetContents, 0, 3),
}
// Try packing an Initial packet.
contents, err := p.maybeAppendCryptoPacket(buffer, protocol.EncryptionInitial)
contents, err := p.maybeAppendCryptoPacket(buffer, maxPacketSize, protocol.EncryptionInitial)
if err != nil && err != handshake.ErrKeysDropped {
return nil, err
}
if contents != nil {
packet.packets = append(packet.packets, contents)
}
if buffer.Len() >= p.maxPacketSize-protocol.MinCoalescedPacketSize {
if buffer.Len() >= maxPacketSize-protocol.MinCoalescedPacketSize {
return packet, nil
}
// Add a Handshake packet.
contents, err = p.maybeAppendCryptoPacket(buffer, protocol.EncryptionHandshake)
contents, err = p.maybeAppendCryptoPacket(buffer, maxPacketSize, protocol.EncryptionHandshake)
if err != nil && err != handshake.ErrKeysDropped && err != handshake.ErrKeysNotYetAvailable {
return nil, err
}
if contents != nil {
packet.packets = append(packet.packets, contents)
}
if buffer.Len() >= p.maxPacketSize-protocol.MinCoalescedPacketSize {
if buffer.Len() >= maxPacketSize-protocol.MinCoalescedPacketSize {
return packet, nil
}
// Add a 0-RTT / 1-RTT packet.
contents, err = p.maybeAppendAppDataPacket(buffer)
contents, err = p.maybeAppendAppDataPacket(buffer, maxPacketSize)
if err == handshake.ErrKeysNotYetAvailable {
return packet, nil
}
@@ -389,7 +397,7 @@ func (p *packetPacker) packCoalescedPacket(buffer *packetBuffer) (*coalescedPack
// It should be called after the handshake is confirmed.
func (p *packetPacker) PackPacket() (*packedPacket, error) {
buffer := getPacketBuffer()
contents, err := p.maybeAppendAppDataPacket(buffer)
contents, err := p.maybeAppendAppDataPacket(buffer, p.maxPacketSize)
if err != nil || contents == nil {
buffer.Release()
return nil, err
@@ -400,16 +408,12 @@ func (p *packetPacker) PackPacket() (*packedPacket, error) {
}, nil
}
func (p *packetPacker) maybeAppendCryptoPacket(buffer *packetBuffer, encLevel protocol.EncryptionLevel) (*packetContents, error) {
func (p *packetPacker) maybeAppendCryptoPacket(buffer *packetBuffer, maxPacketSize protocol.ByteCount, encLevel protocol.EncryptionLevel) (*packetContents, error) {
var sealer sealer
var s cryptoStream
var hasRetransmission bool
maxPacketSize := p.maxPacketSize
switch encLevel {
case protocol.EncryptionInitial:
if p.perspective == protocol.PerspectiveClient {
maxPacketSize = protocol.MinInitialPacketSize
}
s = p.initialStream
hasRetransmission = p.retransmissionQueue.HasInitialData()
var err error
@@ -471,7 +475,7 @@ func (p *packetPacker) maybeAppendCryptoPacket(buffer *packetBuffer, encLevel pr
return p.appendPacket(buffer, hdr, payload, encLevel, sealer)
}
func (p *packetPacker) maybeAppendAppDataPacket(buffer *packetBuffer) (*packetContents, error) {
func (p *packetPacker) maybeAppendAppDataPacket(buffer *packetBuffer, maxPacketSize protocol.ByteCount) (*packetContents, error) {
var sealer sealer
var header *wire.ExtendedHeader
var encLevel protocol.EncryptionLevel
@@ -494,7 +498,7 @@ func (p *packetPacker) maybeAppendAppDataPacket(buffer *packetBuffer) (*packetCo
}
headerLen := header.GetLength(p.version)
maxSize := p.maxPacketSize - buffer.Len() - protocol.ByteCount(sealer.Overhead()) - headerLen
maxSize := maxPacketSize - buffer.Len() - protocol.ByteCount(sealer.Overhead()) - headerLen
payload := p.composeNextPacket(maxSize, encLevel != protocol.Encryption0RTT && buffer.Len() == 0)
// check if we have anything to send
@@ -557,11 +561,11 @@ func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel) (
buffer := getPacketBuffer()
switch encLevel {
case protocol.EncryptionInitial:
contents, err = p.maybeAppendCryptoPacket(buffer, protocol.EncryptionInitial)
contents, err = p.maybeAppendCryptoPacket(buffer, p.maxPacketSize, protocol.EncryptionInitial)
case protocol.EncryptionHandshake:
contents, err = p.maybeAppendCryptoPacket(buffer, protocol.EncryptionHandshake)
contents, err = p.maybeAppendCryptoPacket(buffer, p.maxPacketSize, protocol.EncryptionHandshake)
case protocol.Encryption1RTT:
contents, err = p.maybeAppendAppDataPacket(buffer)
contents, err = p.maybeAppendAppDataPacket(buffer, p.maxPacketSize)
default:
panic("unknown encryption level")
}

View File

@@ -176,7 +176,7 @@ var _ = Describe("Packet packer", func() {
expectAppendControlFrames()
f := &wire.StreamFrame{Data: []byte{0xde, 0xca, 0xfb, 0xad}}
expectAppendStreamFrames(ackhandler.Frame{Frame: f})
p, err := packer.PackCoalescedPacket()
p, err := packer.PackCoalescedPacket(protocol.MaxByteCount)
Expect(err).ToNot(HaveOccurred())
Expect(p).ToNot(BeNil())
Expect(p.packets).To(HaveLen(1))
@@ -266,7 +266,7 @@ var _ = Describe("Packet packer", func() {
framer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any()).DoAndReturn(func(frames []ackhandler.Frame, _ protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) {
return frames, 0
})
p, err := packer.PackCoalescedPacket()
p, err := packer.PackCoalescedPacket(protocol.MaxByteCount)
Expect(p).ToNot(BeNil())
Expect(err).ToNot(HaveOccurred())
Expect(p.packets).To(HaveLen(1))
@@ -535,7 +535,7 @@ var _ = Describe("Packet packer", func() {
sealingManager.EXPECT().GetHandshakeSealer().Return(sealer, nil)
sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable)
packer.retransmissionQueue.AddHandshake(&wire.PingFrame{})
packet, err := packer.PackCoalescedPacket()
packet, err := packer.PackCoalescedPacket(protocol.MaxByteCount)
Expect(err).ToNot(HaveOccurred())
Expect(packet).ToNot(BeNil())
Expect(packet.packets).To(HaveLen(1))
@@ -784,7 +784,7 @@ var _ = Describe("Packet packer", func() {
sealingManager.EXPECT().GetInitialSealer().Return(getSealer(), nil)
sealingManager.EXPECT().GetHandshakeSealer().Return(nil, handshake.ErrKeysNotYetAvailable)
sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable)
p, err := packer.PackCoalescedPacket()
p, err := packer.PackCoalescedPacket(protocol.MaxByteCount)
Expect(err).ToNot(HaveOccurred())
checkLength(p.buffer.Data)
})
@@ -805,7 +805,7 @@ var _ = Describe("Packet packer", func() {
Expect(f.Length(packer.version)).To(Equal(size))
return f
})
p, err := packer.PackCoalescedPacket()
p, err := packer.PackCoalescedPacket(protocol.MaxByteCount)
Expect(err).ToNot(HaveOccurred())
Expect(p.packets).To(HaveLen(1))
Expect(p.packets[0].frames).To(HaveLen(1))
@@ -832,7 +832,7 @@ var _ = Describe("Packet packer", func() {
handshakeStream.EXPECT().PopCryptoFrame(gomock.Any()).DoAndReturn(func(size protocol.ByteCount) *wire.CryptoFrame {
return &wire.CryptoFrame{Offset: 0x1337, Data: []byte("handshake")}
})
p, err := packer.PackCoalescedPacket()
p, err := packer.PackCoalescedPacket(protocol.MaxByteCount)
Expect(err).ToNot(HaveOccurred())
Expect(p.packets).To(HaveLen(2))
Expect(p.packets[0].EncryptionLevel()).To(Equal(protocol.EncryptionInitial))
@@ -868,7 +868,7 @@ var _ = Describe("Packet packer", func() {
})
expectAppendControlFrames()
expectAppendStreamFrames(ackhandler.Frame{Frame: &wire.StreamFrame{Data: []byte("foobar")}})
p, err := packer.PackCoalescedPacket()
p, err := packer.PackCoalescedPacket(protocol.MaxByteCount)
Expect(err).ToNot(HaveOccurred())
Expect(p.buffer.Data).To(HaveLen(protocol.MinInitialPacketSize))
Expect(p.packets).To(HaveLen(2))
@@ -903,7 +903,7 @@ var _ = Describe("Packet packer", func() {
})
expectAppendControlFrames()
expectAppendStreamFrames(ackhandler.Frame{Frame: &wire.StreamFrame{Data: []byte("foobar")}})
p, err := packer.PackCoalescedPacket()
p, err := packer.PackCoalescedPacket(protocol.MaxByteCount)
Expect(err).ToNot(HaveOccurred())
Expect(p.packets).To(HaveLen(2))
Expect(p.packets[0].EncryptionLevel()).To(Equal(protocol.EncryptionHandshake))
@@ -935,7 +935,7 @@ var _ = Describe("Packet packer", func() {
Expect(f.Length(packer.version)).To(Equal(s))
return f
})
p, err := packer.PackCoalescedPacket()
p, err := packer.PackCoalescedPacket(protocol.MaxByteCount)
Expect(err).ToNot(HaveOccurred())
Expect(p.packets).To(HaveLen(1))
Expect(p.packets[0].EncryptionLevel()).To(Equal(protocol.EncryptionInitial))
@@ -943,6 +943,62 @@ var _ = Describe("Packet packer", func() {
checkLength(p.buffer.Data)
})
It("doesn't pack a coalesced packet if there's not enough space", func() {
p, err := packer.PackCoalescedPacket(protocol.MinCoalescedPacketSize - 1)
Expect(err).ToNot(HaveOccurred())
Expect(p).To(BeNil())
})
It("packs a small packet", func() {
const size = protocol.MinCoalescedPacketSize + 10
pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x24), protocol.PacketNumberLen2)
pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x24))
sealingManager.EXPECT().GetInitialSealer().Return(getSealer(), nil)
// don't EXPECT any calls to GetHandshakeSealer and Get1RTTSealer
ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial)
initialStream.EXPECT().HasData().Return(true).Times(2)
initialStream.EXPECT().PopCryptoFrame(gomock.Any()).DoAndReturn(func(s protocol.ByteCount) *wire.CryptoFrame {
f := &wire.CryptoFrame{Offset: 0x1337}
f.Data = bytes.Repeat([]byte{'f'}, int(s-f.Length(packer.version)-1))
Expect(f.Length(packer.version)).To(Equal(s))
return f
})
p, err := packer.PackCoalescedPacket(size)
Expect(err).ToNot(HaveOccurred())
Expect(p).ToNot(BeNil())
Expect(len(p.buffer.Data)).To(Equal(size))
})
It("packs a small packet, that includes a 1-RTT packet", func() {
const size = 2 * protocol.MinCoalescedPacketSize
pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x24), protocol.PacketNumberLen2)
pnManager.EXPECT().PopPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x24))
pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x24), protocol.PacketNumberLen2)
pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x24))
sealingManager.EXPECT().GetInitialSealer().Return(nil, handshake.ErrKeysDropped)
sealingManager.EXPECT().GetHandshakeSealer().Return(getSealer(), nil)
oneRTTSealer := getSealer()
sealingManager.EXPECT().Get1RTTSealer().Return(oneRTTSealer, nil)
ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake)
handshakeStream.EXPECT().HasData().Return(true).Times(2)
handshakeStream.EXPECT().PopCryptoFrame(gomock.Any()).Return(&wire.CryptoFrame{
Offset: 0x1337,
Data: []byte("foobar"),
})
expectAppendControlFrames()
var appDataSize protocol.ByteCount
framer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any()).DoAndReturn(func(frames []ackhandler.Frame, maxSize protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) {
appDataSize = maxSize
f := &wire.StreamFrame{Data: []byte("foobar")}
return append(frames, ackhandler.Frame{Frame: f}), f.Length(packer.version)
})
p, err := packer.PackCoalescedPacket(size)
Expect(err).ToNot(HaveOccurred())
Expect(p).ToNot(BeNil())
Expect(p.packets).To(HaveLen(2))
Expect(appDataSize).To(Equal(size - p.packets[0].length - p.packets[1].header.GetLength(packer.version) - protocol.ByteCount(oneRTTSealer.Overhead())))
})
It("adds retransmissions", func() {
f := &wire.CryptoFrame{Data: []byte("Initial")}
retransmissionQueue.AddInitial(f)
@@ -954,7 +1010,7 @@ var _ = Describe("Packet packer", func() {
sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable)
ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial)
initialStream.EXPECT().HasData()
p, err := packer.PackCoalescedPacket()
p, err := packer.PackCoalescedPacket(protocol.MaxByteCount)
Expect(err).ToNot(HaveOccurred())
Expect(p.packets).To(HaveLen(1))
Expect(p.packets[0].EncryptionLevel()).To(Equal(protocol.EncryptionInitial))
@@ -972,7 +1028,7 @@ var _ = Describe("Packet packer", func() {
sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable)
pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2)
pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42))
p, err := packer.PackCoalescedPacket()
p, err := packer.PackCoalescedPacket(protocol.MaxByteCount)
Expect(err).ToNot(HaveOccurred())
Expect(p.packets).To(HaveLen(1))
Expect(p.packets[0].ack).To(Equal(ack))
@@ -984,7 +1040,7 @@ var _ = Describe("Packet packer", func() {
sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable)
initialStream.EXPECT().HasData()
ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial)
p, err := packer.PackCoalescedPacket()
p, err := packer.PackCoalescedPacket(protocol.MaxByteCount)
Expect(err).ToNot(HaveOccurred())
Expect(p).To(BeNil())
})
@@ -1000,7 +1056,7 @@ var _ = Describe("Packet packer", func() {
sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable)
pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2)
pnManager.EXPECT().PopPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42))
p, err := packer.PackCoalescedPacket()
p, err := packer.PackCoalescedPacket(protocol.MaxByteCount)
Expect(err).ToNot(HaveOccurred())
Expect(p.packets).To(HaveLen(1))
Expect(p.packets[0].ack).To(Equal(ack))
@@ -1020,7 +1076,7 @@ var _ = Describe("Packet packer", func() {
initialStream.EXPECT().HasData().Return(true).Times(2)
initialStream.EXPECT().PopCryptoFrame(gomock.Any()).Return(f)
packer.perspective = protocol.PerspectiveClient
p, err := packer.PackCoalescedPacket()
p, err := packer.PackCoalescedPacket(protocol.MaxByteCount)
Expect(err).ToNot(HaveOccurred())
Expect(p.buffer.Len()).To(BeEquivalentTo(protocol.MinInitialPacketSize))
Expect(p.packets).To(HaveLen(1))
@@ -1044,7 +1100,7 @@ var _ = Describe("Packet packer", func() {
initialStream.EXPECT().PopCryptoFrame(gomock.Any()).Return(f)
packer.version = protocol.VersionTLS
packer.perspective = protocol.PerspectiveClient
p, err := packer.PackCoalescedPacket()
p, err := packer.PackCoalescedPacket(protocol.MaxByteCount)
Expect(err).ToNot(HaveOccurred())
Expect(p.packets).To(HaveLen(1))
Expect(p.packets[0].ack).To(Equal(ack))

View File

@@ -700,6 +700,7 @@ func (s *session) handlePacketImpl(rp *receivedPacket) bool {
var processed bool
data := rp.data
p := rp
s.sentPacketHandler.ReceivedBytes(protocol.ByteCount(len(data)))
for len(data) > 0 {
if counter > 0 {
p = p.Clone()
@@ -1427,7 +1428,7 @@ func (s *session) sendPacket() (bool, error) {
if !s.handshakeConfirmed {
now := time.Now()
packet, err := s.packer.PackCoalescedPacket()
packet, err := s.packer.PackCoalescedPacket(s.sentPacketHandler.AmplificationWindow())
if err != nil || packet == nil {
return false, err
}

View File

@@ -1033,6 +1033,13 @@ var _ = Describe("Session", func() {
It("sends packets", func() {
sess.handshakeConfirmed = true
sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
sph.EXPECT().TimeUntilSend().AnyTimes()
sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes()
sph.EXPECT().ShouldSendNumPackets().Return(1000)
sph.EXPECT().SentPacket(gomock.Any())
sess.sentPacketHandler = sph
runSession()
p := getPacket(1)
packer.EXPECT().PackPacket().Return(p, nil)
@@ -1069,6 +1076,13 @@ var _ = Describe("Session", func() {
It("adds a BLOCKED frame when it is connection-level flow control blocked", func() {
sess.handshakeConfirmed = true
sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
sph.EXPECT().TimeUntilSend().AnyTimes()
sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes()
sph.EXPECT().ShouldSendNumPackets().Return(1000)
sph.EXPECT().SentPacket(gomock.Any())
sess.sentPacketHandler = sph
fc := mocks.NewMockConnectionFlowController(mockCtrl)
fc.EXPECT().IsNewlyBlocked().Return(true, protocol.ByteCount(1337))
fc.EXPECT().IsNewlyBlocked()
@@ -1366,10 +1380,12 @@ var _ = Describe("Session", func() {
It("sends coalesced packets before the handshake is confirmed", func() {
sess.handshakeConfirmed = false
sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
const window protocol.ByteCount = 321
sph.EXPECT().AmplificationWindow().Return(window).AnyTimes()
sess.sentPacketHandler = sph
buffer := getPacketBuffer()
buffer.Data = append(buffer.Data, []byte("foobar")...)
packer.EXPECT().PackCoalescedPacket().Return(&coalescedPacket{
packer.EXPECT().PackCoalescedPacket(window).Return(&coalescedPacket{
buffer: buffer,
packets: []*packetContents{
{
@@ -1394,7 +1410,7 @@ var _ = Describe("Session", func() {
},
},
}, nil)
packer.EXPECT().PackCoalescedPacket().AnyTimes()
packer.EXPECT().PackCoalescedPacket(window).AnyTimes()
sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes()
@@ -1445,7 +1461,7 @@ var _ = Describe("Session", func() {
})
It("cancels the HandshakeComplete context and informs the SentPacketHandler when the handshake completes", func() {
packer.EXPECT().PackCoalescedPacket().AnyTimes()
packer.EXPECT().PackCoalescedPacket(protocol.MaxByteCount).AnyTimes()
finishHandshake := make(chan struct{})
sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
sess.sentPacketHandler = sph
@@ -1482,7 +1498,7 @@ var _ = Describe("Session", func() {
It("sends a session ticket when the handshake completes", func() {
const size = protocol.MaxPostHandshakeCryptoFrameSize * 3 / 2
packer.EXPECT().PackCoalescedPacket().AnyTimes()
packer.EXPECT().PackCoalescedPacket(protocol.MaxByteCount).AnyTimes()
finishHandshake := make(chan struct{})
sessionRunner.EXPECT().Retire(clientDestConnID)
go func() {
@@ -1525,7 +1541,7 @@ var _ = Describe("Session", func() {
})
It("doesn't cancel the HandshakeComplete context when the handshake fails", func() {
packer.EXPECT().PackCoalescedPacket().AnyTimes()
packer.EXPECT().PackCoalescedPacket(protocol.MaxByteCount).AnyTimes()
streamManager.EXPECT().CloseWithError(gomock.Any())
expectReplaceWithClosed()
packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil)
@@ -1545,9 +1561,17 @@ var _ = Describe("Session", func() {
})
It("sends a HANDSHAKE_DONE frame when the handshake completes", func() {
sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes()
sph.EXPECT().AmplificationWindow().Return(protocol.MaxByteCount)
sph.EXPECT().SetHandshakeComplete()
sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
sph.EXPECT().TimeUntilSend().AnyTimes()
sph.EXPECT().ShouldSendNumPackets().AnyTimes().Return(10)
sess.sentPacketHandler = sph
done := make(chan struct{})
sessionRunner.EXPECT().Retire(clientDestConnID)
packer.EXPECT().PackCoalescedPacket().DoAndReturn(func() (*packedPacket, error) {
packer.EXPECT().PackCoalescedPacket(gomock.Any()).DoAndReturn(func(protocol.ByteCount) (*packedPacket, error) {
frames, _ := sess.framer.AppendControlFrames(nil, protocol.MaxByteCount)
Expect(frames).ToNot(BeEmpty())
Expect(frames[0].Frame).To(BeEquivalentTo(&wire.HandshakeDoneFrame{}))
@@ -1559,7 +1583,7 @@ var _ = Describe("Session", func() {
buffer: getPacketBuffer(),
}, nil
})
packer.EXPECT().PackCoalescedPacket().AnyTimes()
packer.EXPECT().PackCoalescedPacket(gomock.Any()).AnyTimes()
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().RunHandshake()
@@ -1630,7 +1654,7 @@ var _ = Describe("Session", func() {
}
streamManager.EXPECT().UpdateLimits(params)
packer.EXPECT().HandleTransportParameters(params)
packer.EXPECT().PackCoalescedPacket().MaxTimes(3)
packer.EXPECT().PackCoalescedPacket(protocol.MaxByteCount).MaxTimes(3)
Expect(sess.earlySessionReady()).ToNot(BeClosed())
sessionRunner.EXPECT().GetStatelessResetToken(gomock.Any()).Times(2)
sessionRunner.EXPECT().Add(gomock.Any(), sess).Times(2)
@@ -1659,6 +1683,7 @@ var _ = Describe("Session", func() {
BeforeEach(func() {
sess.config.MaxIdleTimeout = 30 * time.Second
sess.config.KeepAlive = true
sess.receivedPacketHandler.ReceivedPacket(0, protocol.EncryptionHandshake, time.Now(), true)
})
AfterEach(func() {
@@ -1677,7 +1702,7 @@ var _ = Describe("Session", func() {
setRemoteIdleTimeout(5 * time.Second)
sess.lastPacketReceivedTime = time.Now().Add(-5 * time.Second / 2)
sent := make(chan struct{})
packer.EXPECT().PackCoalescedPacket().Do(func() (*packedPacket, error) {
packer.EXPECT().PackCoalescedPacket(protocol.MaxByteCount).Do(func(protocol.ByteCount) (*packedPacket, error) {
close(sent)
return nil, nil
})
@@ -1690,7 +1715,7 @@ var _ = Describe("Session", func() {
setRemoteIdleTimeout(time.Hour)
sess.lastPacketReceivedTime = time.Now().Add(-protocol.MaxKeepAliveInterval).Add(-time.Millisecond)
sent := make(chan struct{})
packer.EXPECT().PackCoalescedPacket().Do(func() (*packedPacket, error) {
packer.EXPECT().PackCoalescedPacket(protocol.MaxByteCount).Do(func(protocol.ByteCount) (*packedPacket, error) {
close(sent)
return nil, nil
})
@@ -1794,7 +1819,7 @@ var _ = Describe("Session", func() {
})
It("closes the session due to the idle timeout after handshake", func() {
packer.EXPECT().PackCoalescedPacket().AnyTimes()
packer.EXPECT().PackCoalescedPacket(protocol.MaxByteCount).AnyTimes()
gomock.InOrder(
sessionRunner.EXPECT().Retire(clientDestConnID),
sessionRunner.EXPECT().Remove(gomock.Any()),
@@ -2098,6 +2123,7 @@ var _ = Describe("Client Session", func() {
It("handles Retry packets", func() {
sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
sess.sentPacketHandler = sph
sph.EXPECT().ReceivedBytes(gomock.Any())
sph.EXPECT().ResetForRetry()
cryptoSetup.EXPECT().ChangeConnectionID(protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef})
packer.EXPECT().SetToken([]byte("foobar"))
@@ -2180,7 +2206,7 @@ var _ = Describe("Client Session", func() {
},
}
packer.EXPECT().HandleTransportParameters(gomock.Any())
packer.EXPECT().PackCoalescedPacket().MaxTimes(1)
packer.EXPECT().PackCoalescedPacket(protocol.MaxByteCount).MaxTimes(1)
qlogger.EXPECT().ReceivedTransportParameters(params)
sess.processTransportParameters(params)
// make sure the connection ID is not retired
@@ -2333,6 +2359,7 @@ var _ = Describe("Client Session", func() {
It("ignores Initial packets which use original source id, after accepting a Retry", func() {
sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
sess.sentPacketHandler = sph
sph.EXPECT().ReceivedBytes(gomock.Any()).Times(2)
sph.EXPECT().ResetForRetry()
newSrcConnID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}
cryptoSetup.EXPECT().ChangeConnectionID(newSrcConnID)