wire: use quicvarint.Parse to when parsing transport parameters (#4482)

* wire: add a benchmark for parsing of transport parameters

* wire: use quicvarint.Parse to when parsing transport parameters
This commit is contained in:
Marten Seemann
2024-05-05 19:26:51 +08:00
committed by GitHub
parent bb6f066aa5
commit 1514095afb
5 changed files with 156 additions and 88 deletions

View File

@@ -1,7 +1,6 @@
package transportparameters
import (
"bytes"
"errors"
"fmt"
@@ -55,12 +54,12 @@ func fuzzTransportParameters(data []byte, sentByServer bool) int {
func fuzzTransportParametersForSessionTicket(data []byte) int {
tp := &wire.TransportParameters{}
if err := tp.UnmarshalFromSessionTicket(bytes.NewReader(data)); err != nil {
if err := tp.UnmarshalFromSessionTicket(data); err != nil {
return 0
}
b := tp.MarshalForSessionTicket(nil)
tp2 := &wire.TransportParameters{}
if err := tp2.UnmarshalFromSessionTicket(bytes.NewReader(b)); err != nil {
if err := tp2.UnmarshalFromSessionTicket(b); err != nil {
panic(err)
}
return 1

View File

@@ -1,7 +1,6 @@
package handshake
import (
"bytes"
"context"
"crypto/tls"
"errors"
@@ -338,25 +337,26 @@ func (h *cryptoSetup) handleDataFromSessionState(data []byte, earlyData bool) (a
return false
}
func decodeDataFromSessionState(data []byte, earlyData bool) (time.Duration, *wire.TransportParameters, error) {
r := bytes.NewReader(data)
ver, err := quicvarint.Read(r)
func decodeDataFromSessionState(b []byte, earlyData bool) (time.Duration, *wire.TransportParameters, error) {
ver, l, err := quicvarint.Parse(b)
if err != nil {
return 0, nil, err
}
b = b[l:]
if ver != clientSessionStateRevision {
return 0, nil, fmt.Errorf("mismatching version. Got %d, expected %d", ver, clientSessionStateRevision)
}
rttEncoded, err := quicvarint.Read(r)
rttEncoded, l, err := quicvarint.Parse(b)
if err != nil {
return 0, nil, err
}
b = b[l:]
rtt := time.Duration(rttEncoded) * time.Microsecond
if !earlyData {
return rtt, nil, nil
}
var tp wire.TransportParameters
if err := tp.UnmarshalFromSessionTicket(r); err != nil {
if err := tp.UnmarshalFromSessionTicket(b); err != nil {
return 0, nil, err
}
return rtt, &tp, nil

View File

@@ -1,7 +1,6 @@
package handshake
import (
"bytes"
"errors"
"fmt"
"time"
@@ -28,25 +27,26 @@ func (t *sessionTicket) Marshal() []byte {
}
func (t *sessionTicket) Unmarshal(b []byte, using0RTT bool) error {
r := bytes.NewReader(b)
rev, err := quicvarint.Read(r)
rev, l, err := quicvarint.Parse(b)
if err != nil {
return errors.New("failed to read session ticket revision")
}
b = b[l:]
if rev != sessionTicketRevision {
return fmt.Errorf("unknown session ticket revision: %d", rev)
}
rtt, err := quicvarint.Read(r)
rtt, l, err := quicvarint.Parse(b)
if err != nil {
return errors.New("failed to read RTT")
}
b = b[l:]
if using0RTT {
var tp wire.TransportParameters
if err := tp.UnmarshalFromSessionTicket(r); err != nil {
if err := tp.UnmarshalFromSessionTicket(b); err != nil {
return fmt.Errorf("unmarshaling transport parameters from session ticket failed: %s", err.Error())
}
t.Parameters = &tp
} else if r.Len() > 0 {
} else if len(b) > 0 {
return fmt.Errorf("the session ticket has more bytes than expected")
}
t.RTT = time.Duration(rtt) * time.Microsecond

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"math"
"net/netip"
"testing"
"time"
"golang.org/x/exp/rand"
@@ -17,20 +18,20 @@ import (
. "github.com/onsi/gomega"
)
func getRandomValueUpTo(max int64) uint64 {
maxVals := []int64{math.MaxUint8 / 4, math.MaxUint16 / 4, math.MaxUint32 / 4, math.MaxUint64 / 4}
m := maxVals[int(rand.Int31n(4))]
if m > max {
m = max
}
return uint64(rand.Int63n(m))
}
func getRandomValue() uint64 {
return getRandomValueUpTo(math.MaxInt64)
}
var _ = Describe("Transport Parameters", func() {
getRandomValueUpTo := func(max int64) uint64 {
maxVals := []int64{math.MaxUint8 / 4, math.MaxUint16 / 4, math.MaxUint32 / 4, math.MaxUint64 / 4}
m := maxVals[int(rand.Int31n(4))]
if m > max {
m = max
}
return uint64(rand.Int63n(m))
}
getRandomValue := func() uint64 {
return getRandomValueUpTo(math.MaxInt64)
}
BeforeEach(func() {
rand.Seed(uint64(GinkgoRandomSeed()))
})
@@ -504,7 +505,7 @@ var _ = Describe("Transport Parameters", func() {
Expect(params.ValidFor0RTT(params)).To(BeTrue())
b := params.MarshalForSessionTicket(nil)
var tp TransportParameters
Expect(tp.UnmarshalFromSessionTicket(bytes.NewReader(b))).To(Succeed())
Expect(tp.UnmarshalFromSessionTicket(b)).To(Succeed())
Expect(tp.InitialMaxStreamDataBidiLocal).To(Equal(params.InitialMaxStreamDataBidiLocal))
Expect(tp.InitialMaxStreamDataBidiRemote).To(Equal(params.InitialMaxStreamDataBidiRemote))
Expect(tp.InitialMaxStreamDataUni).To(Equal(params.InitialMaxStreamDataUni))
@@ -517,7 +518,7 @@ var _ = Describe("Transport Parameters", func() {
It("rejects the parameters if it can't parse them", func() {
var p TransportParameters
Expect(p.UnmarshalFromSessionTicket(bytes.NewReader([]byte("foobar")))).ToNot(Succeed())
Expect(p.UnmarshalFromSessionTicket([]byte("foobar"))).ToNot(Succeed())
})
It("rejects the parameters if the version changed", func() {
@@ -525,7 +526,7 @@ var _ = Describe("Transport Parameters", func() {
data := p.MarshalForSessionTicket(nil)
b := quicvarint.Append(nil, transportParameterMarshalingVersion+1)
b = append(b, data[quicvarint.Len(transportParameterMarshalingVersion):]...)
Expect(p.UnmarshalFromSessionTicket(bytes.NewReader(b))).To(MatchError(fmt.Sprintf("unknown transport parameter marshaling version: %d", transportParameterMarshalingVersion+1)))
Expect(p.UnmarshalFromSessionTicket(b)).To(MatchError(fmt.Sprintf("unknown transport parameter marshaling version: %d", transportParameterMarshalingVersion+1)))
})
Context("rejects the parameters if they changed", func() {
@@ -722,3 +723,66 @@ var _ = Describe("Transport Parameters", func() {
})
})
})
func BenchmarkTransportParameters(b *testing.B) {
b.Run("without preferred address", func(b *testing.B) { benchmarkTransportParameters(b, false) })
b.Run("with preferred address", func(b *testing.B) { benchmarkTransportParameters(b, true) })
}
func benchmarkTransportParameters(b *testing.B, withPreferredAddress bool) {
var token protocol.StatelessResetToken
rand.Read(token[:])
rcid := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde})
params := &TransportParameters{
InitialMaxStreamDataBidiLocal: protocol.ByteCount(getRandomValue()),
InitialMaxStreamDataBidiRemote: protocol.ByteCount(getRandomValue()),
InitialMaxStreamDataUni: protocol.ByteCount(getRandomValue()),
InitialMaxData: protocol.ByteCount(getRandomValue()),
MaxIdleTimeout: 0xcafe * time.Second,
MaxBidiStreamNum: protocol.StreamNum(getRandomValueUpTo(int64(protocol.MaxStreamCount))),
MaxUniStreamNum: protocol.StreamNum(getRandomValueUpTo(int64(protocol.MaxStreamCount))),
DisableActiveMigration: true,
StatelessResetToken: &token,
OriginalDestinationConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}),
InitialSourceConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}),
RetrySourceConnectionID: &rcid,
AckDelayExponent: 13,
MaxAckDelay: 42 * time.Millisecond,
ActiveConnectionIDLimit: 2 + getRandomValueUpTo(math.MaxInt64-2),
MaxDatagramFrameSize: protocol.ByteCount(getRandomValue()),
}
var token2 protocol.StatelessResetToken
rand.Read(token2[:])
if withPreferredAddress {
var ip4 [4]byte
var ip6 [16]byte
rand.Read(ip4[:])
rand.Read(ip6[:])
params.PreferredAddress = &PreferredAddress{
IPv4: netip.AddrPortFrom(netip.AddrFrom4(ip4), 1234),
IPv6: netip.AddrPortFrom(netip.AddrFrom16(ip6), 4321),
ConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}),
StatelessResetToken: token2,
}
}
data := params.Marshal(protocol.PerspectiveServer)
b.ResetTimer()
b.ReportAllocs()
var p TransportParameters
for i := 0; i < b.N; i++ {
if err := p.Unmarshal(data, protocol.PerspectiveServer); err != nil {
b.Fatal(err)
}
// check a few fields
if p.DisableActiveMigration != params.DisableActiveMigration ||
p.InitialMaxStreamDataBidiLocal != params.InitialMaxStreamDataBidiLocal ||
*p.StatelessResetToken != *params.StatelessResetToken ||
p.AckDelayExponent != params.AckDelayExponent {
b.Fatalf("params mismatch: %v vs %v", p, params)
}
if withPreferredAddress && *p.PreferredAddress != *params.PreferredAddress {
b.Fatalf("preferred address mismatch: %v vs %v", p.PreferredAddress, params.PreferredAddress)
}
}
}

View File

@@ -1,7 +1,6 @@
package wire
import (
"bytes"
"crypto/rand"
"encoding/binary"
"errors"
@@ -13,7 +12,6 @@ import (
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/quicvarint"
)
@@ -89,7 +87,7 @@ type TransportParameters struct {
// Unmarshal the transport parameters
func (p *TransportParameters) Unmarshal(data []byte, sentBy protocol.Perspective) error {
if err := p.unmarshal(bytes.NewReader(data), sentBy, false); err != nil {
if err := p.unmarshal(data, sentBy, false); err != nil {
return &qerr.TransportError{
ErrorCode: qerr.TransportParameterError,
ErrorMessage: err.Error(),
@@ -98,7 +96,7 @@ func (p *TransportParameters) Unmarshal(data []byte, sentBy protocol.Perspective
return nil
}
func (p *TransportParameters) unmarshal(r *bytes.Reader, sentBy protocol.Perspective, fromSessionTicket bool) error {
func (p *TransportParameters) unmarshal(b []byte, sentBy protocol.Perspective, fromSessionTicket bool) error {
// needed to check that every parameter is only sent at most once
var parameterIDs []transportParameterID
@@ -112,18 +110,20 @@ func (p *TransportParameters) unmarshal(r *bytes.Reader, sentBy protocol.Perspec
p.MaxAckDelay = protocol.DefaultMaxAckDelay
p.MaxDatagramFrameSize = protocol.InvalidByteCount
for r.Len() > 0 {
paramIDInt, err := quicvarint.Read(r)
for len(b) > 0 {
paramIDInt, l, err := quicvarint.Parse(b)
if err != nil {
return err
}
paramID := transportParameterID(paramIDInt)
paramLen, err := quicvarint.Read(r)
b = b[l:]
paramLen, l, err := quicvarint.Parse(b)
if err != nil {
return err
}
if uint64(r.Len()) < paramLen {
return fmt.Errorf("remaining length (%d) smaller than parameter length (%d)", r.Len(), paramLen)
b = b[l:]
if uint64(len(b)) < paramLen {
return fmt.Errorf("remaining length (%d) smaller than parameter length (%d)", len(b), paramLen)
}
parameterIDs = append(parameterIDs, paramID)
switch paramID {
@@ -141,16 +141,18 @@ func (p *TransportParameters) unmarshal(r *bytes.Reader, sentBy protocol.Perspec
maxAckDelayParameterID,
maxDatagramFrameSizeParameterID,
ackDelayExponentParameterID:
if err := p.readNumericTransportParameter(r, paramID, int(paramLen)); err != nil {
if err := p.readNumericTransportParameter(b, paramID, int(paramLen)); err != nil {
return err
}
b = b[paramLen:]
case preferredAddressParameterID:
if sentBy == protocol.PerspectiveClient {
return errors.New("client sent a preferred_address")
}
if err := p.readPreferredAddress(r, int(paramLen)); err != nil {
if err := p.readPreferredAddress(b, int(paramLen)); err != nil {
return err
}
b = b[paramLen:]
case disableActiveMigrationParameterID:
if paramLen != 0 {
return fmt.Errorf("wrong length for disable_active_migration: %d (expected empty)", paramLen)
@@ -164,25 +166,41 @@ func (p *TransportParameters) unmarshal(r *bytes.Reader, sentBy protocol.Perspec
return fmt.Errorf("wrong length for stateless_reset_token: %d (expected 16)", paramLen)
}
var token protocol.StatelessResetToken
r.Read(token[:])
if len(b) < len(token) {
return io.EOF
}
copy(token[:], b)
b = b[len(token):]
p.StatelessResetToken = &token
case originalDestinationConnectionIDParameterID:
if sentBy == protocol.PerspectiveClient {
return errors.New("client sent an original_destination_connection_id")
}
p.OriginalDestinationConnectionID, _ = protocol.ReadConnectionID(r, int(paramLen))
if paramLen > protocol.MaxConnIDLen {
return protocol.ErrInvalidConnectionIDLen
}
p.OriginalDestinationConnectionID = protocol.ParseConnectionID(b[:paramLen])
b = b[paramLen:]
readOriginalDestinationConnectionID = true
case initialSourceConnectionIDParameterID:
p.InitialSourceConnectionID, _ = protocol.ReadConnectionID(r, int(paramLen))
if paramLen > protocol.MaxConnIDLen {
return protocol.ErrInvalidConnectionIDLen
}
p.InitialSourceConnectionID = protocol.ParseConnectionID(b[:paramLen])
b = b[paramLen:]
readInitialSourceConnectionID = true
case retrySourceConnectionIDParameterID:
if sentBy == protocol.PerspectiveClient {
return errors.New("client sent a retry_source_connection_id")
}
connID, _ := protocol.ReadConnectionID(r, int(paramLen))
if paramLen > protocol.MaxConnIDLen {
return protocol.ErrInvalidConnectionIDLen
}
connID := protocol.ParseConnectionID(b[:paramLen])
b = b[paramLen:]
p.RetrySourceConnectionID = &connID
default:
r.Seek(int64(paramLen), io.SeekCurrent)
b = b[paramLen:]
}
}
@@ -212,60 +230,47 @@ func (p *TransportParameters) unmarshal(r *bytes.Reader, sentBy protocol.Perspec
return nil
}
func (p *TransportParameters) readPreferredAddress(r *bytes.Reader, expectedLen int) error {
remainingLen := r.Len()
func (p *TransportParameters) readPreferredAddress(b []byte, expectedLen int) error {
remainingLen := len(b)
pa := &PreferredAddress{}
if len(b) < 4+2+16+2+1 {
return io.EOF
}
var ipv4 [4]byte
if _, err := io.ReadFull(r, ipv4[:]); err != nil {
return err
}
port, err := utils.BigEndian.ReadUint16(r)
if err != nil {
return err
}
pa.IPv4 = netip.AddrPortFrom(netip.AddrFrom4(ipv4), port)
copy(ipv4[:], b[:4])
port4 := binary.BigEndian.Uint16(b[4:])
b = b[4+2:]
pa.IPv4 = netip.AddrPortFrom(netip.AddrFrom4(ipv4), port4)
var ipv6 [16]byte
if _, err := io.ReadFull(r, ipv6[:]); err != nil {
return err
}
port, err = utils.BigEndian.ReadUint16(r)
if err != nil {
return err
}
pa.IPv6 = netip.AddrPortFrom(netip.AddrFrom16(ipv6), port)
connIDLen, err := r.ReadByte()
if err != nil {
return err
}
copy(ipv6[:], b[:16])
port6 := binary.BigEndian.Uint16(b[16:])
pa.IPv6 = netip.AddrPortFrom(netip.AddrFrom16(ipv6), port6)
b = b[16+2:]
connIDLen := int(b[0])
b = b[1:]
if connIDLen == 0 || connIDLen > protocol.MaxConnIDLen {
return fmt.Errorf("invalid connection ID length: %d", connIDLen)
}
connID, err := protocol.ReadConnectionID(r, int(connIDLen))
if err != nil {
return err
if len(b) < connIDLen+len(pa.StatelessResetToken) {
return io.EOF
}
pa.ConnectionID = connID
if _, err := io.ReadFull(r, pa.StatelessResetToken[:]); err != nil {
return err
}
if bytesRead := remainingLen - r.Len(); bytesRead != expectedLen {
pa.ConnectionID = protocol.ParseConnectionID(b[:connIDLen])
b = b[connIDLen:]
copy(pa.StatelessResetToken[:], b)
b = b[len(pa.StatelessResetToken):]
if bytesRead := remainingLen - len(b); bytesRead != expectedLen {
return fmt.Errorf("expected preferred_address to be %d long, read %d bytes", expectedLen, bytesRead)
}
p.PreferredAddress = pa
return nil
}
func (p *TransportParameters) readNumericTransportParameter(
r *bytes.Reader,
paramID transportParameterID,
expectedLen int,
) error {
remainingLen := r.Len()
val, err := quicvarint.Read(r)
func (p *TransportParameters) readNumericTransportParameter(b []byte, paramID transportParameterID, expectedLen int) error {
val, l, err := quicvarint.Parse(b)
if err != nil {
return fmt.Errorf("error while reading transport parameter %d: %s", paramID, err)
}
if remainingLen-r.Len() != expectedLen {
if l != expectedLen {
return fmt.Errorf("inconsistent transport parameter length for transport parameter %#x", paramID)
}
//nolint:exhaustive // This only covers the numeric transport parameters.
@@ -457,15 +462,15 @@ func (p *TransportParameters) MarshalForSessionTicket(b []byte) []byte {
}
// UnmarshalFromSessionTicket unmarshals transport parameters from a session ticket.
func (p *TransportParameters) UnmarshalFromSessionTicket(r *bytes.Reader) error {
version, err := quicvarint.Read(r)
func (p *TransportParameters) UnmarshalFromSessionTicket(b []byte) error {
version, l, err := quicvarint.Parse(b)
if err != nil {
return err
}
if version != transportParameterMarshalingVersion {
return fmt.Errorf("unknown transport parameter marshaling version: %d", version)
}
return p.unmarshal(r, protocol.PerspectiveServer, true)
return p.unmarshal(b[l:], protocol.PerspectiveServer, true)
}
// ValidFor0RTT checks if the transport parameters match those saved in the session ticket.