forked from quic-go/quic-go
use an array instead of a byte slice for Connection IDs
This commit is contained in:
@@ -1,12 +1,14 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
var ErrInvalidConnectionIDLen = errors.New("invalid Connection ID length")
|
||||
|
||||
// An ArbitraryLenConnectionID is a QUIC Connection ID able to represent Connection IDs according to RFC 8999.
|
||||
// Future QUIC versions might allow connection ID lengths up to 255 bytes, while QUIC v1
|
||||
// restricts the length to 20 bytes.
|
||||
@@ -27,18 +29,32 @@ func (c ArbitraryLenConnectionID) String() string {
|
||||
return fmt.Sprintf("%x", c.Bytes())
|
||||
}
|
||||
|
||||
// A ConnectionID in QUIC
|
||||
type ConnectionID []byte
|
||||
|
||||
const maxConnectionIDLen = 20
|
||||
|
||||
// A ConnectionID in QUIC
|
||||
type ConnectionID struct {
|
||||
b [20]byte
|
||||
l uint8
|
||||
}
|
||||
|
||||
// GenerateConnectionID generates a connection ID using cryptographic random
|
||||
func GenerateConnectionID(len int) (ConnectionID, error) {
|
||||
b := make([]byte, len)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return nil, err
|
||||
func GenerateConnectionID(l int) (ConnectionID, error) {
|
||||
var c ConnectionID
|
||||
c.l = uint8(l)
|
||||
_, err := rand.Read(c.b[:l])
|
||||
return c, err
|
||||
}
|
||||
|
||||
// ParseConnectionID interprets b as a Connection ID.
|
||||
// It panics if b is longer than 20 bytes.
|
||||
func ParseConnectionID(b []byte) ConnectionID {
|
||||
if len(b) > maxConnectionIDLen {
|
||||
panic("invalid conn id length")
|
||||
}
|
||||
return ConnectionID(b), nil
|
||||
var c ConnectionID
|
||||
c.l = uint8(len(b))
|
||||
copy(c.b[:c.l], b)
|
||||
return c
|
||||
}
|
||||
|
||||
// GenerateConnectionIDForInitial generates a connection ID for the Initial packet.
|
||||
@@ -46,39 +62,43 @@ func GenerateConnectionID(len int) (ConnectionID, error) {
|
||||
func GenerateConnectionIDForInitial() (ConnectionID, error) {
|
||||
r := make([]byte, 1)
|
||||
if _, err := rand.Read(r); err != nil {
|
||||
return nil, err
|
||||
return ConnectionID{}, err
|
||||
}
|
||||
len := MinConnectionIDLenInitial + int(r[0])%(maxConnectionIDLen-MinConnectionIDLenInitial+1)
|
||||
return GenerateConnectionID(len)
|
||||
l := MinConnectionIDLenInitial + int(r[0])%(maxConnectionIDLen-MinConnectionIDLenInitial+1)
|
||||
return GenerateConnectionID(l)
|
||||
}
|
||||
|
||||
// ReadConnectionID reads a connection ID of length len from the given io.Reader.
|
||||
// It returns io.EOF if there are not enough bytes to read.
|
||||
func ReadConnectionID(r io.Reader, len int) (ConnectionID, error) {
|
||||
if len == 0 {
|
||||
return nil, nil
|
||||
func ReadConnectionID(r io.Reader, l int) (ConnectionID, error) {
|
||||
var c ConnectionID
|
||||
if l == 0 {
|
||||
return c, nil
|
||||
}
|
||||
c := make(ConnectionID, len)
|
||||
_, err := io.ReadFull(r, c)
|
||||
if l > maxConnectionIDLen {
|
||||
return c, ErrInvalidConnectionIDLen
|
||||
}
|
||||
c.l = uint8(l)
|
||||
_, err := io.ReadFull(r, c.b[:l])
|
||||
if err == io.ErrUnexpectedEOF {
|
||||
return nil, io.EOF
|
||||
return c, io.EOF
|
||||
}
|
||||
return c, err
|
||||
}
|
||||
|
||||
// Equal says if two connection IDs are equal
|
||||
func (c ConnectionID) Equal(other ConnectionID) bool {
|
||||
return bytes.Equal(c, other)
|
||||
return c == other
|
||||
}
|
||||
|
||||
// Len returns the length of the connection ID in bytes
|
||||
func (c ConnectionID) Len() int {
|
||||
return len(c)
|
||||
return int(c.l)
|
||||
}
|
||||
|
||||
// Bytes returns the byte representation
|
||||
func (c ConnectionID) Bytes() []byte {
|
||||
return []byte(c)
|
||||
return c.b[:c.l]
|
||||
}
|
||||
|
||||
func (c ConnectionID) String() string {
|
||||
@@ -92,7 +112,7 @@ type DefaultConnectionIDGenerator struct {
|
||||
ConnLen int
|
||||
}
|
||||
|
||||
func (d *DefaultConnectionIDGenerator) GenerateConnectionID() ([]byte, error) {
|
||||
func (d *DefaultConnectionIDGenerator) GenerateConnectionID() (ConnectionID, error) {
|
||||
return GenerateConnectionID(d.ConnLen)
|
||||
}
|
||||
|
||||
|
||||
@@ -44,10 +44,13 @@ var _ = Describe("Connection ID generation", func() {
|
||||
})
|
||||
|
||||
It("says if connection IDs are equal", func() {
|
||||
c1 := ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
|
||||
c2 := ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}
|
||||
c1 := ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
|
||||
c2 := ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1})
|
||||
c3 := ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
|
||||
Expect(c1.Equal(c1)).To(BeTrue())
|
||||
Expect(c1.Equal(c3)).To(BeTrue())
|
||||
Expect(c2.Equal(c2)).To(BeTrue())
|
||||
Expect(c2.Equal(c3)).To(BeFalse())
|
||||
Expect(c1.Equal(c2)).To(BeFalse())
|
||||
Expect(c2.Equal(c1)).To(BeFalse())
|
||||
})
|
||||
@@ -65,15 +68,21 @@ var _ = Describe("Connection ID generation", func() {
|
||||
Expect(err).To(MatchError(io.EOF))
|
||||
})
|
||||
|
||||
It("returns nil for a 0 length connection ID", func() {
|
||||
It("returns a 0 length connection ID", func() {
|
||||
buf := bytes.NewBuffer([]byte{1, 2, 3, 4})
|
||||
c, err := ReadConnectionID(buf, 0)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(c).To(BeNil())
|
||||
Expect(c.Len()).To(BeZero())
|
||||
})
|
||||
|
||||
It("errors when trying to read a too long connection ID", func() {
|
||||
buf := bytes.NewBuffer(make([]byte, 21))
|
||||
_, err := ReadConnectionID(buf, 21)
|
||||
Expect(err).To(MatchError(ErrInvalidConnectionIDLen))
|
||||
})
|
||||
|
||||
It("returns the length", func() {
|
||||
c := ConnectionID{1, 2, 3, 4, 5, 6, 7}
|
||||
c := ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7})
|
||||
Expect(c.Len()).To(Equal(7))
|
||||
})
|
||||
|
||||
@@ -83,22 +92,22 @@ var _ = Describe("Connection ID generation", func() {
|
||||
})
|
||||
|
||||
It("returns the bytes", func() {
|
||||
c := ConnectionID([]byte{1, 2, 3, 4, 5, 6, 7})
|
||||
c := ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7})
|
||||
Expect(c.Bytes()).To(Equal([]byte{1, 2, 3, 4, 5, 6, 7}))
|
||||
})
|
||||
|
||||
It("returns a nil byte slice for the default value", func() {
|
||||
var c ConnectionID
|
||||
Expect(c.Bytes()).To(BeNil())
|
||||
Expect(c.Bytes()).To(HaveLen(0))
|
||||
})
|
||||
|
||||
It("has a string representation", func() {
|
||||
c := ConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0x42})
|
||||
c := ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0x42})
|
||||
Expect(c.String()).To(Equal("deadbeef42"))
|
||||
})
|
||||
|
||||
It("has a long string representation", func() {
|
||||
c := ConnectionID{0x13, 0x37, 0, 0, 0xde, 0xca, 0xfb, 0xad}
|
||||
c := ParseConnectionID([]byte{0x13, 0x37, 0, 0, 0xde, 0xca, 0xfb, 0xad})
|
||||
Expect(c.String()).To(Equal("13370000decafbad"))
|
||||
})
|
||||
|
||||
|
||||
Reference in New Issue
Block a user