diff --git a/client_test.go b/client_test.go index f21b6e92..69c2b29c 100644 --- a/client_test.go +++ b/client_test.go @@ -162,7 +162,7 @@ var _ = Describe("Client", func() { var listenErr error go func() { defer GinkgoRecover() - _, err = serverConn.Write(bytes.Repeat([]byte{'f'}, 100)) + _, err = serverConn.Write(bytes.Repeat([]byte{0xff}, 100)) Expect(err).ToNot(HaveOccurred()) }() diff --git a/public_header.go b/public_header.go index 9613ea25..3c7585e0 100644 --- a/public_header.go +++ b/public_header.go @@ -128,7 +128,8 @@ func ParsePublicHeader(b *bytes.Reader, packetSentBy protocol.Perspective) (*Pub // return nil, errors.New("diversification nonces should only be sent by servers") // } - if publicFlagByte&0x08 == 0 { + header.TruncateConnectionID = publicFlagByte&0x08 == 0 + if header.TruncateConnectionID && packetSentBy == protocol.PerspectiveClient { return nil, errReceivedTruncatedConnectionID } @@ -146,14 +147,16 @@ func ParsePublicHeader(b *bytes.Reader, packetSentBy protocol.Perspective) (*Pub } // Connection ID - connID, err := utils.ReadUint64(b) - if err != nil { - return nil, err - } - - header.ConnectionID = protocol.ConnectionID(connID) - if header.ConnectionID == 0 { - return nil, errInvalidConnectionID + if !header.TruncateConnectionID { + var connID uint64 + connID, err = utils.ReadUint64(b) + if err != nil { + return nil, err + } + header.ConnectionID = protocol.ConnectionID(connID) + if header.ConnectionID == 0 { + return nil, errInvalidConnectionID + } } if packetSentBy == protocol.PerspectiveServer && publicFlagByte&0x04 > 0 { @@ -181,9 +184,9 @@ func ParsePublicHeader(b *bytes.Reader, packetSentBy protocol.Perspective) (*Pub } header.VersionNumber = protocol.VersionTagToNumber(versionTag) } else { // parse the version negotiaton packet - if b.Len()%4 != 0 { - return nil, qerr.InvalidVersionNegotiationPacket - } + if b.Len()%4 != 0 { + return nil, qerr.InvalidVersionNegotiationPacket + } header.SupportedVersions = make([]protocol.VersionNumber, 0) for { var versionTag uint32 diff --git a/public_header_test.go b/public_header_test.go index fffb9f28..cf28a2a9 100644 --- a/public_header_test.go +++ b/public_header_test.go @@ -25,12 +25,21 @@ var _ = Describe("Public Header", func() { Expect(b.Len()).To(BeZero()) }) - It("does not accept 0-byte connection ID", func() { + It("does not accept truncated connection ID as a server", func() { b := bytes.NewReader([]byte{0x00, 0x01}) _, err := ParsePublicHeader(b, protocol.PerspectiveClient) Expect(err).To(MatchError(errReceivedTruncatedConnectionID)) }) + It("accepts a truncated connection ID as a client", func() { + b := bytes.NewReader([]byte{0x00, 0x01}) + hdr, err := ParsePublicHeader(b, protocol.PerspectiveServer) + Expect(err).ToNot(HaveOccurred()) + Expect(hdr.TruncateConnectionID).To(BeTrue()) + Expect(hdr.ConnectionID).To(BeZero()) + Expect(b.Len()).To(BeZero()) + }) + It("rejects 0 as a connection ID", func() { b := bytes.NewReader([]byte{0x09, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x51, 0x30, 0x33, 0x30, 0x01}) _, err := ParsePublicHeader(b, protocol.PerspectiveClient)