From b6c9c9b6a44f068012226c31eb613d5d7373a7ce Mon Sep 17 00:00:00 2001 From: Lucas Clemente Date: Fri, 15 Apr 2016 18:46:39 +0200 Subject: [PATCH] introduce protocol/version.go and remove big endian functions --- example/main.go | 11 ++--------- protocol/protocol_suite_test.go | 13 +++++++++++++ protocol/version.go | 15 +++++++++++++++ protocol/version_test.go | 18 ++++++++++++++++++ public_header.go | 14 ++++++++------ public_header_test.go | 3 +-- utils/utils.go | 27 --------------------------- utils/utils_test.go | 30 ------------------------------ 8 files changed, 57 insertions(+), 74 deletions(-) create mode 100644 protocol/protocol_suite_test.go create mode 100644 protocol/version.go create mode 100644 protocol/version_test.go diff --git a/example/main.go b/example/main.go index 2060971a..60979688 100644 --- a/example/main.go +++ b/example/main.go @@ -15,14 +15,7 @@ import ( "github.com/lucas-clemente/quic-go/utils" ) -const ( - // QuicVersionNumber32 is the QUIC protocol version - QuicVersionNumber32 = 32 -) - func main() { - QuicVersion32, _ := utils.ReadUint32BigEndian(bytes.NewReader([]byte{'Q', '0', 48 + (QuicVersionNumber32/10)%10, 48 + QuicVersionNumber32%10})) - path := os.Getenv("GOPATH") + "/src/github.com/lucas-clemente/quic-go/example/" keyData, err := crypto.LoadKeyData(path+"cert.der", path+"key.der") if err != nil { @@ -62,7 +55,7 @@ func main() { fmt.Printf("Got packet # %d\n", publicHeader.PacketNumber) // Send Version Negotiation Packet if the client is speaking a different protocol version - if publicHeader.VersionFlag && publicHeader.QuicVersion != QuicVersion32 { + if publicHeader.VersionFlag && publicHeader.VersionNumber != 32 { fmt.Println("Sending VersionNegotiationPacket") fullReply := &bytes.Buffer{} responsePublicHeader := quic.PublicHeader{ConnectionID: publicHeader.ConnectionID, PacketNumber: 1, VersionFlag: true} @@ -70,7 +63,7 @@ func main() { if err != nil { panic(err) } - utils.WriteUint32BigEndian(fullReply, QuicVersion32) + utils.WriteUint32(fullReply, protocol.VersionNumberToTag(protocol.VersionNumber(32))) _, err = conn.WriteToUDP(fullReply.Bytes(), remoteAddr) if err != nil { panic(err) diff --git a/protocol/protocol_suite_test.go b/protocol/protocol_suite_test.go new file mode 100644 index 00000000..8e38c4c8 --- /dev/null +++ b/protocol/protocol_suite_test.go @@ -0,0 +1,13 @@ +package protocol_test + +import ( + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + + "testing" +) + +func TestProtocol(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Protocol Suite") +} diff --git a/protocol/version.go b/protocol/version.go new file mode 100644 index 00000000..9b73f612 --- /dev/null +++ b/protocol/version.go @@ -0,0 +1,15 @@ +package protocol + +// VersionNumber is a version number as int +type VersionNumber int + +// VersionNumberToTag maps version numbers ('32') to tags ('Q032') +func VersionNumberToTag(vn VersionNumber) uint32 { + v := uint32(vn) + return 'Q' + ((v/100%10)+'0')<<8 + ((v/10%10)+'0')<<16 + ((v%10)+'0')<<24 +} + +// VersionTagToNumber is built from VersionNumberToTag in init() +func VersionTagToNumber(v uint32) VersionNumber { + return VersionNumber(((v>>8)&0xff-'0')*100 + ((v>>16)&0xff-'0')*10 + ((v>>24)&0xff - '0')) +} diff --git a/protocol/version_test.go b/protocol/version_test.go new file mode 100644 index 00000000..b51a43c8 --- /dev/null +++ b/protocol/version_test.go @@ -0,0 +1,18 @@ +package protocol_test + +import ( + "github.com/lucas-clemente/quic-go/protocol" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Version", func() { + It("converts tags to numbers", func() { + Expect(protocol.VersionTagToNumber('Q' + '1'<<8 + '2'<<16 + '3'<<24)).To(Equal(protocol.VersionNumber(123))) + }) + + It("converts number to tag", func() { + Expect(protocol.VersionNumberToTag(protocol.VersionNumber(123))).To(Equal(uint32('Q' + '1'<<8 + '2'<<16 + '3'<<24))) + }) +}) diff --git a/public_header.go b/public_header.go index 4333c14a..52863398 100644 --- a/public_header.go +++ b/public_header.go @@ -11,11 +11,11 @@ import ( // The PublicHeader of a QUIC packet type PublicHeader struct { - VersionFlag bool - ResetFlag bool - ConnectionID protocol.ConnectionID - QuicVersion uint32 - PacketNumber protocol.PacketNumber + VersionFlag bool + ResetFlag bool + ConnectionID protocol.ConnectionID + VersionNumber protocol.VersionNumber + PacketNumber protocol.PacketNumber // packetNumberLen uint8 } @@ -83,10 +83,12 @@ func ParsePublicHeader(b io.ByteReader) (*PublicHeader, error) { // Version (optional) if header.VersionFlag { - header.QuicVersion, err = utils.ReadUint32BigEndian(b) + var versionTag uint32 + versionTag, err = utils.ReadUint32(b) if err != nil { return nil, err } + header.VersionNumber = protocol.VersionTagToNumber(versionTag) } // Packet number diff --git a/public_header_test.go b/public_header_test.go index 8defb9eb..15fde1de 100644 --- a/public_header_test.go +++ b/public_header_test.go @@ -2,7 +2,6 @@ package quic import ( "bytes" - "encoding/binary" "github.com/lucas-clemente/quic-go/protocol" . "github.com/onsi/ginkgo" @@ -18,7 +17,7 @@ var _ = Describe("Public Header", func() { Expect(publicHeader.VersionFlag).To(BeTrue()) Expect(publicHeader.ResetFlag).To(BeFalse()) Expect(publicHeader.ConnectionID).To(Equal(protocol.ConnectionID(0x4cfa9f9b668619f6))) - Expect(publicHeader.QuicVersion).To(Equal(binary.BigEndian.Uint32([]byte("Q030")))) + Expect(publicHeader.VersionNumber).To(Equal(protocol.VersionNumber(30))) Expect(publicHeader.PacketNumber).To(Equal(protocol.PacketNumber(1))) Expect(b.Len()).To(BeZero()) }) diff --git a/utils/utils.go b/utils/utils.go index 46b05cd1..5438dbfe 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -38,25 +38,6 @@ func ReadUint32(b io.ByteReader) (uint32, error) { return uint32(b1) + uint32(b2)<<8 + uint32(b3)<<16 + uint32(b4)<<24, nil } -// ReadUint32BigEndian reads a uint32 Big Endian -func ReadUint32BigEndian(b io.ByteReader) (uint32, error) { - var b1, b2, b3, b4 uint8 - var err error - if b1, err = b.ReadByte(); err != nil { - return 0, err - } - if b2, err = b.ReadByte(); err != nil { - return 0, err - } - if b3, err = b.ReadByte(); err != nil { - return 0, err - } - if b4, err = b.ReadByte(); err != nil { - return 0, err - } - return uint32(b4) + uint32(b3)<<8 + uint32(b2)<<16 + uint32(b1)<<24, nil -} - // ReadUint16 reads a uint16 func ReadUint16(b io.ByteReader) (uint16, error) { var b1, b2 uint8 @@ -90,14 +71,6 @@ func WriteUint32(b *bytes.Buffer, i uint32) { b.WriteByte(uint8((i >> 24) & 0xff)) } -// WriteUint32BigEndian writes a uint32 -func WriteUint32BigEndian(b *bytes.Buffer, i uint32) { - b.WriteByte(uint8((i >> 24) & 0xff)) - b.WriteByte(uint8((i >> 16) & 0xff)) - b.WriteByte(uint8((i >> 8) & 0xff)) - b.WriteByte(uint8(i & 0xff)) -} - // WriteUint16 writes a uint16 func WriteUint16(b *bytes.Buffer, i uint16) { b.WriteByte(uint8(i & 0xff)) diff --git a/utils/utils_test.go b/utils/utils_test.go index d45550f4..a70b52e3 100644 --- a/utils/utils_test.go +++ b/utils/utils_test.go @@ -53,21 +53,6 @@ var _ = Describe("Utils", func() { }) }) - Context("ReadUint32BigEndian", func() { - It("reads a big endian", func() { - b := []byte{0x12, 0x35, 0xAB, 0xFF} - val, err := ReadUint32BigEndian(bytes.NewReader(b)) - Expect(err).ToNot(HaveOccurred()) - Expect(val).To(Equal(uint32(0x1235ABFF))) - }) - - It("throws an error if less than 4 bytes are passed", func() { - b := []byte{0x13, 0x34, 0xEA} - _, err := ReadUint32(bytes.NewReader(b)) - Expect(err).To(HaveOccurred()) - }) - }) - Context("WriteUint16", func() { It("outputs 2 bytes", func() { b := &bytes.Buffer{} @@ -97,19 +82,4 @@ var _ = Describe("Utils", func() { Expect(b.Bytes()).To(Equal([]byte{0x12, 0x35, 0xAC, 0xEF})) }) }) - - Context("WriteUint32BigEndian", func() { - It("outputs 4 bytes", func() { - b := &bytes.Buffer{} - WriteUint32BigEndian(b, uint32(1)) - Expect(b.Len()).To(Equal(4)) - }) - - It("outputs a big endian", func() { - num := uint32(0xEFAC3512) - b := &bytes.Buffer{} - WriteUint32BigEndian(b, num) - Expect(b.Bytes()).To(Equal([]byte{0xEF, 0xAC, 0x35, 0x12})) - }) - }) })