add more methods to the protocol.ConnectionID

This will make it easier to change the underlying type of the connection
ID (potentially an array is faster than a byte slice).
This commit is contained in:
Marten Seemann
2018-04-19 08:47:43 +09:00
parent 74ccd194f7
commit d47124d14d
8 changed files with 124 additions and 47 deletions

View File

@@ -3,6 +3,8 @@ package protocol
import (
"bytes"
"crypto/rand"
"fmt"
"io"
)
// A ConnectionID in QUIC
@@ -17,7 +19,38 @@ func GenerateConnectionID() (ConnectionID, error) {
return ConnectionID(b), nil
}
// ReadConnectionID reads a connection ID of length len from the given io.Reader.
// It returns io.EOF if there are not enough bytes to read.
func ReadConnectionID(r io.Reader, len int) (ConnectionID, error) {
if len == 0 {
return nil, nil
}
c := make(ConnectionID, len)
_, err := io.ReadFull(r, c)
if err == io.ErrUnexpectedEOF {
return nil, io.EOF
}
return c, err
}
// Equal says if two connection IDs are equal
func (c ConnectionID) Equal(other ConnectionID) bool {
return bytes.Equal(c, other)
}
// Len returns the length of the connection ID in bytes
func (c ConnectionID) Len() int {
return len(c)
}
// Bytes returns the byte representation
func (c ConnectionID) Bytes() []byte {
return []byte(c)
}
func (c ConnectionID) String() string {
if c.Len() == 0 {
return "(empty)"
}
return fmt.Sprintf("%#x", c.Bytes())
}

View File

@@ -1,6 +1,9 @@
package protocol
import (
"bytes"
"io"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
@@ -23,4 +26,59 @@ var _ = Describe("Connection ID generation", func() {
Expect(c1.Equal(c2)).To(BeFalse())
Expect(c2.Equal(c1)).To(BeFalse())
})
It("reads the connection ID", func() {
buf := bytes.NewBuffer([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9})
c, err := ReadConnectionID(buf, 9)
Expect(err).ToNot(HaveOccurred())
Expect(c.Bytes()).To(Equal([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9}))
})
It("returns io.EOF if there's not enough data to read", func() {
buf := bytes.NewBuffer([]byte{1, 2, 3, 4})
_, err := ReadConnectionID(buf, 5)
Expect(err).To(MatchError(io.EOF))
})
It("returns nil for a 0 length connection ID", func() {
buf := bytes.NewBuffer([]byte{1, 2, 3, 4})
c, err := ReadConnectionID(buf, 0)
Expect(err).ToNot(HaveOccurred())
Expect(c).To(BeNil())
})
It("returns the length", func() {
c := ConnectionID{1, 2, 3, 4, 5, 6, 7}
Expect(c.Len()).To(Equal(7))
})
It("has 0 length for the default value", func() {
var c ConnectionID
Expect(c.Len()).To(BeZero())
})
It("returns the bytes", func() {
c := ConnectionID([]byte{1, 2, 3, 4, 5, 6, 7})
Expect(c.Bytes()).To(Equal([]byte{1, 2, 3, 4, 5, 6, 7}))
})
It("returns a nil byte slice for the default value", func() {
var c ConnectionID
Expect(c.Bytes()).To(BeNil())
})
It("has a string representation", func() {
c := ConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0x42})
Expect(c.String()).To(Equal("0xdeadbeef42"))
})
It("has a long string representation", func() {
c := ConnectionID{0x13, 0x37, 0, 0, 0xde, 0xca, 0xfb, 0xad}
Expect(c.String()).To(Equal("0x13370000decafbad"))
})
It("has a string representation for the default value", func() {
var c ConnectionID
Expect(c.String()).To(Equal("(empty)"))
})
})

View File

@@ -35,24 +35,13 @@ func parseLongHeader(b *bytes.Reader, sentBy protocol.Perspective, typeByte byte
return nil, err
}
dcil, scil := decodeConnIDLen(connIDLenByte)
var destConnID, srcConnID protocol.ConnectionID
if dcil > 0 {
destConnID = make(protocol.ConnectionID, dcil)
if _, err := io.ReadFull(b, destConnID); err != nil {
if err == io.ErrUnexpectedEOF {
err = io.EOF
}
return nil, err
}
destConnID, err := protocol.ReadConnectionID(b, dcil)
if err != nil {
return nil, err
}
if scil > 0 {
srcConnID = make(protocol.ConnectionID, scil)
if _, err := io.ReadFull(b, srcConnID); err != nil {
if err == io.ErrUnexpectedEOF {
err = io.EOF
}
return nil, err
}
srcConnID, err := protocol.ReadConnectionID(b, scil)
if err != nil {
return nil, err
}
h := &Header{
@@ -145,8 +134,8 @@ func (h *Header) writeLongHeader(b *bytes.Buffer) error {
if !h.DestConnectionID.Equal(h.SrcConnectionID) {
return errors.New("Header: can't write a header with different source and destination connection ID")
}
if len(h.SrcConnectionID) != 8 {
return fmt.Errorf("Header: source connection ID must be 8 bytes, is %d", len(h.SrcConnectionID))
if h.SrcConnectionID.Len() != 8 {
return fmt.Errorf("Header: source connection ID must be 8 bytes, is %d", h.SrcConnectionID.Len())
}
b.WriteByte(byte(0x80 | h.Type))
utils.BigEndian.WriteUint32(b, uint32(h.Version))
@@ -155,8 +144,8 @@ func (h *Header) writeLongHeader(b *bytes.Buffer) error {
return err
}
b.WriteByte(connIDLen)
b.Write(h.DestConnectionID)
b.Write(h.SrcConnectionID)
b.Write(h.DestConnectionID.Bytes())
b.Write(h.SrcConnectionID.Bytes())
utils.BigEndian.WriteUint32(b, uint32(h.PacketNumber))
return nil
}
@@ -175,7 +164,7 @@ func (h *Header) writeShortHeader(b *bytes.Buffer) error {
}
b.WriteByte(typeByte)
b.Write(h.DestConnectionID)
b.Write(h.DestConnectionID.Bytes())
switch h.PacketNumberLen {
case protocol.PacketNumberLen1:
b.WriteByte(uint8(h.PacketNumber))
@@ -190,10 +179,10 @@ func (h *Header) writeShortHeader(b *bytes.Buffer) error {
// getHeaderLength gets the length of the Header in bytes.
func (h *Header) getHeaderLength() (protocol.ByteCount, error) {
if h.IsLongHeader {
return 1 /* type byte */ + 4 /* version */ + 1 /* conn id len byte */ + protocol.ByteCount(len(h.DestConnectionID)+len(h.SrcConnectionID)) + 4 /* packet number */, nil
return 1 /* type byte */ + 4 /* version */ + 1 /* conn id len byte */ + protocol.ByteCount(h.DestConnectionID.Len()+h.SrcConnectionID.Len()) + 4 /* packet number */, nil
}
length := protocol.ByteCount(1 /* type byte */ + len(h.DestConnectionID))
length := protocol.ByteCount(1 /* type byte */ + h.DestConnectionID.Len())
if h.PacketNumberLen != protocol.PacketNumberLen1 && h.PacketNumberLen != protocol.PacketNumberLen2 && h.PacketNumberLen != protocol.PacketNumberLen4 {
return 0, fmt.Errorf("invalid packet number length: %d", h.PacketNumberLen)
}
@@ -203,9 +192,9 @@ func (h *Header) getHeaderLength() (protocol.ByteCount, error) {
func (h *Header) logHeader(logger utils.Logger) {
if h.IsLongHeader {
logger.Debugf(" Long Header{Type: %s, DestConnectionID: %#x, SrcConnectionID: %#x, PacketNumber: %#x, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, h.PacketNumber, h.Version)
logger.Debugf(" Long Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, PacketNumber: %#x, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, h.PacketNumber, h.Version)
} else {
logger.Debugf(" Short Header{DestConnectionID: %#x, PacketNumber: %#x, PacketNumberLen: %d, KeyPhase: %d}", h.DestConnectionID, h.PacketNumber, h.PacketNumberLen, h.KeyPhase)
logger.Debugf(" Short Header{DestConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, KeyPhase: %d}", h.DestConnectionID, h.PacketNumber, h.PacketNumberLen, h.KeyPhase)
}
}
@@ -222,13 +211,14 @@ func encodeConnIDLen(dest, src protocol.ConnectionID) (byte, error) {
}
func encodeSingleConnIDLen(id protocol.ConnectionID) (byte, error) {
if len(id) == 0 {
len := id.Len()
if len == 0 {
return 0, nil
}
if len(id) < 4 || len(id) > 18 {
if len < 4 || len > 18 {
return 0, errors.New("invalid connection ID length")
}
return byte(len(id) - 3), nil
return byte(len - 3), nil
}
func decodeConnIDLen(enc byte) (int /*dest conn id len*/, int /*src conn id len*/) {

View File

@@ -241,13 +241,9 @@ func (h *Header) hasPacketNumber(packetSentBy protocol.Perspective) bool {
}
func (h *Header) logPublicHeader(logger utils.Logger) {
connID := "(omitted)"
if !h.OmitConnectionID {
connID = fmt.Sprintf("%#x", h.DestConnectionID)
}
ver := "(unset)"
if h.Version != 0 {
ver = h.Version.String()
}
logger.Debugf(" Public Header{ConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, Version: %s, DiversificationNonce: %#v}", connID, h.PacketNumber, h.PacketNumberLen, ver, h.DiversificationNonce)
logger.Debugf(" Public Header{ConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, Version: %s, DiversificationNonce: %#v}", h.DestConnectionID, h.PacketNumber, h.PacketNumberLen, ver, h.DiversificationNonce)
}

View File

@@ -554,7 +554,7 @@ var _ = Describe("Public Header", func() {
PacketNumberLen: 6,
Version: protocol.Version39,
}).logPublicHeader(logger)
Expect(buf.String()).To(ContainSubstring("Public Header{ConnectionID: (omitted)"))
Expect(buf.String()).To(ContainSubstring("Public Header{ConnectionID: (empty)"))
})
It("logs a Public Header without a version", func() {