use an array instead of a byte slice for Connection IDs

This commit is contained in:
Marten Seemann
2022-08-28 16:05:56 +03:00
parent 9e0f9e62ff
commit 1aced95d41
47 changed files with 530 additions and 487 deletions

View File

@@ -1,12 +1,14 @@
package protocol
import (
"bytes"
"crypto/rand"
"errors"
"fmt"
"io"
)
var ErrInvalidConnectionIDLen = errors.New("invalid Connection ID length")
// An ArbitraryLenConnectionID is a QUIC Connection ID able to represent Connection IDs according to RFC 8999.
// Future QUIC versions might allow connection ID lengths up to 255 bytes, while QUIC v1
// restricts the length to 20 bytes.
@@ -27,18 +29,32 @@ func (c ArbitraryLenConnectionID) String() string {
return fmt.Sprintf("%x", c.Bytes())
}
// A ConnectionID in QUIC
type ConnectionID []byte
const maxConnectionIDLen = 20
// A ConnectionID in QUIC
type ConnectionID struct {
b [20]byte
l uint8
}
// GenerateConnectionID generates a connection ID using cryptographic random
func GenerateConnectionID(len int) (ConnectionID, error) {
b := make([]byte, len)
if _, err := rand.Read(b); err != nil {
return nil, err
func GenerateConnectionID(l int) (ConnectionID, error) {
var c ConnectionID
c.l = uint8(l)
_, err := rand.Read(c.b[:l])
return c, err
}
// ParseConnectionID interprets b as a Connection ID.
// It panics if b is longer than 20 bytes.
func ParseConnectionID(b []byte) ConnectionID {
if len(b) > maxConnectionIDLen {
panic("invalid conn id length")
}
return ConnectionID(b), nil
var c ConnectionID
c.l = uint8(len(b))
copy(c.b[:c.l], b)
return c
}
// GenerateConnectionIDForInitial generates a connection ID for the Initial packet.
@@ -46,39 +62,43 @@ func GenerateConnectionID(len int) (ConnectionID, error) {
func GenerateConnectionIDForInitial() (ConnectionID, error) {
r := make([]byte, 1)
if _, err := rand.Read(r); err != nil {
return nil, err
return ConnectionID{}, err
}
len := MinConnectionIDLenInitial + int(r[0])%(maxConnectionIDLen-MinConnectionIDLenInitial+1)
return GenerateConnectionID(len)
l := MinConnectionIDLenInitial + int(r[0])%(maxConnectionIDLen-MinConnectionIDLenInitial+1)
return GenerateConnectionID(l)
}
// ReadConnectionID reads a connection ID of length len from the given io.Reader.
// It returns io.EOF if there are not enough bytes to read.
func ReadConnectionID(r io.Reader, len int) (ConnectionID, error) {
if len == 0 {
return nil, nil
func ReadConnectionID(r io.Reader, l int) (ConnectionID, error) {
var c ConnectionID
if l == 0 {
return c, nil
}
c := make(ConnectionID, len)
_, err := io.ReadFull(r, c)
if l > maxConnectionIDLen {
return c, ErrInvalidConnectionIDLen
}
c.l = uint8(l)
_, err := io.ReadFull(r, c.b[:l])
if err == io.ErrUnexpectedEOF {
return nil, io.EOF
return c, io.EOF
}
return c, err
}
// Equal says if two connection IDs are equal
func (c ConnectionID) Equal(other ConnectionID) bool {
return bytes.Equal(c, other)
return c == other
}
// Len returns the length of the connection ID in bytes
func (c ConnectionID) Len() int {
return len(c)
return int(c.l)
}
// Bytes returns the byte representation
func (c ConnectionID) Bytes() []byte {
return []byte(c)
return c.b[:c.l]
}
func (c ConnectionID) String() string {
@@ -92,7 +112,7 @@ type DefaultConnectionIDGenerator struct {
ConnLen int
}
func (d *DefaultConnectionIDGenerator) GenerateConnectionID() ([]byte, error) {
func (d *DefaultConnectionIDGenerator) GenerateConnectionID() (ConnectionID, error) {
return GenerateConnectionID(d.ConnLen)
}

View File

@@ -44,10 +44,13 @@ var _ = Describe("Connection ID generation", func() {
})
It("says if connection IDs are equal", func() {
c1 := ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
c2 := ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}
c1 := ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
c2 := ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1})
c3 := ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
Expect(c1.Equal(c1)).To(BeTrue())
Expect(c1.Equal(c3)).To(BeTrue())
Expect(c2.Equal(c2)).To(BeTrue())
Expect(c2.Equal(c3)).To(BeFalse())
Expect(c1.Equal(c2)).To(BeFalse())
Expect(c2.Equal(c1)).To(BeFalse())
})
@@ -65,15 +68,21 @@ var _ = Describe("Connection ID generation", func() {
Expect(err).To(MatchError(io.EOF))
})
It("returns nil for a 0 length connection ID", func() {
It("returns a 0 length connection ID", func() {
buf := bytes.NewBuffer([]byte{1, 2, 3, 4})
c, err := ReadConnectionID(buf, 0)
Expect(err).ToNot(HaveOccurred())
Expect(c).To(BeNil())
Expect(c.Len()).To(BeZero())
})
It("errors when trying to read a too long connection ID", func() {
buf := bytes.NewBuffer(make([]byte, 21))
_, err := ReadConnectionID(buf, 21)
Expect(err).To(MatchError(ErrInvalidConnectionIDLen))
})
It("returns the length", func() {
c := ConnectionID{1, 2, 3, 4, 5, 6, 7}
c := ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7})
Expect(c.Len()).To(Equal(7))
})
@@ -83,22 +92,22 @@ var _ = Describe("Connection ID generation", func() {
})
It("returns the bytes", func() {
c := ConnectionID([]byte{1, 2, 3, 4, 5, 6, 7})
c := ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7})
Expect(c.Bytes()).To(Equal([]byte{1, 2, 3, 4, 5, 6, 7}))
})
It("returns a nil byte slice for the default value", func() {
var c ConnectionID
Expect(c.Bytes()).To(BeNil())
Expect(c.Bytes()).To(HaveLen(0))
})
It("has a string representation", func() {
c := ConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0x42})
c := ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0x42})
Expect(c.String()).To(Equal("deadbeef42"))
})
It("has a long string representation", func() {
c := ConnectionID{0x13, 0x37, 0, 0, 0xde, 0xca, 0xfb, 0xad}
c := ParseConnectionID([]byte{0x13, 0x37, 0, 0, 0xde, 0xca, 0xfb, 0xad})
Expect(c.String()).To(Equal("13370000decafbad"))
})