diff --git a/frames/ack_frame.go b/frames/ack_frame.go index da0b4f4e..be721624 100644 --- a/frames/ack_frame.go +++ b/frames/ack_frame.go @@ -219,7 +219,7 @@ func (f *AckFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error case protocol.PacketNumberLen4: utils.LittleEndian.WriteUint32(b, uint32(f.LargestAcked)) case protocol.PacketNumberLen6: - utils.LittleEndian.WriteUint48(b, uint64(f.LargestAcked)) + utils.LittleEndian.WriteUint48(b, uint64(f.LargestAcked)&(1<<48-1)) } f.DelayTime = time.Since(f.PacketReceivedTime) @@ -257,7 +257,7 @@ func (f *AckFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error case protocol.PacketNumberLen4: utils.LittleEndian.WriteUint32(b, uint32(firstAckBlockLength)) case protocol.PacketNumberLen6: - utils.LittleEndian.WriteUint48(b, uint64(firstAckBlockLength)) + utils.LittleEndian.WriteUint48(b, uint64(firstAckBlockLength)&(1<<48-1)) } for i, ackRange := range f.AckRanges { @@ -283,7 +283,7 @@ func (f *AckFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error case protocol.PacketNumberLen4: utils.LittleEndian.WriteUint32(b, uint32(length)) case protocol.PacketNumberLen6: - utils.LittleEndian.WriteUint48(b, uint64(length)) + utils.LittleEndian.WriteUint48(b, uint64(length)&(1<<48-1)) } numRangesWritten++ } else { @@ -308,7 +308,7 @@ func (f *AckFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error case protocol.PacketNumberLen4: utils.LittleEndian.WriteUint32(b, uint32(lengthWritten)) case protocol.PacketNumberLen6: - utils.LittleEndian.WriteUint48(b, lengthWritten) + utils.LittleEndian.WriteUint48(b, lengthWritten&(1<<48-1)) } numRangesWritten++ diff --git a/frames/stop_waiting_frame.go b/frames/stop_waiting_frame.go index eadcfe2d..8b141b72 100644 --- a/frames/stop_waiting_frame.go +++ b/frames/stop_waiting_frame.go @@ -46,7 +46,7 @@ func (f *StopWaitingFrame) Write(b *bytes.Buffer, version protocol.VersionNumber case protocol.PacketNumberLen4: utils.LittleEndian.WriteUint32(b, uint32(leastUnackedDelta)) case protocol.PacketNumberLen6: - utils.LittleEndian.WriteUint48(b, leastUnackedDelta) + utils.LittleEndian.WriteUint48(b, leastUnackedDelta&(1<<48-1)) default: return errPacketNumberLenNotSet } diff --git a/frames/stop_waiting_frame_test.go b/frames/stop_waiting_frame_test.go index 5a2dc2e0..e7a73b65 100644 --- a/frames/stop_waiting_frame_test.go +++ b/frames/stop_waiting_frame_test.go @@ -95,7 +95,8 @@ var _ = Describe("StopWaitingFrame", func() { PacketNumber: 13, PacketNumberLen: protocol.PacketNumberLen1, } - frame.Write(b, 0) + err := frame.Write(b, 0) + Expect(err).ToNot(HaveOccurred()) Expect(b.Len()).To(Equal(2)) Expect(b.Bytes()[1]).To(Equal(uint8(3))) }) @@ -107,7 +108,8 @@ var _ = Describe("StopWaitingFrame", func() { PacketNumber: 0x1300, PacketNumberLen: protocol.PacketNumberLen2, } - frame.Write(b, 0) + err := frame.Write(b, 0) + Expect(err).ToNot(HaveOccurred()) Expect(b.Len()).To(Equal(3)) Expect(b.Bytes()[1:3]).To(Equal([]byte{0xF0, 0x12})) }) @@ -119,19 +121,21 @@ var _ = Describe("StopWaitingFrame", func() { PacketNumber: 0x12345678, PacketNumberLen: protocol.PacketNumberLen4, } - frame.Write(b, 0) + err := frame.Write(b, 0) + Expect(err).ToNot(HaveOccurred()) Expect(b.Len()).To(Equal(5)) Expect(b.Bytes()[1:5]).To(Equal([]byte{0x78, 0x46, 0x34, 0x12})) }) - It("writes a 6-byte LeastUnackedDelta", func() { + It("writes a 6-byte LeastUnackedDelta, for a delta that fits into 6 bytes", func() { b := &bytes.Buffer{} frame := &StopWaitingFrame{ LeastUnacked: 0x10, PacketNumber: 0x123456789ABC, PacketNumberLen: protocol.PacketNumberLen6, } - frame.Write(b, 0) + err := frame.Write(b, 0) + Expect(err).ToNot(HaveOccurred()) Expect(b.Len()).To(Equal(7)) Expect(b.Bytes()[1:7]).To(Equal([]byte{0xAC, 0x9A, 0x78, 0x56, 0x34, 0x12})) }) diff --git a/internal/utils/byteorder_little_endian.go b/internal/utils/byteorder_little_endian.go index 56517f95..71ff95d5 100644 --- a/internal/utils/byteorder_little_endian.go +++ b/internal/utils/byteorder_little_endian.go @@ -2,6 +2,7 @@ package utils import ( "bytes" + "fmt" "io" ) @@ -98,6 +99,9 @@ func (littleEndian) WriteUint64(b *bytes.Buffer, i uint64) { // WriteUint56 writes 56 bit of a uint64 func (littleEndian) WriteUint56(b *bytes.Buffer, i uint64) { + if i >= (1 << 56) { + panic(fmt.Sprintf("%#x doesn't fit into 56 bits", i)) + } b.Write([]byte{ uint8(i), uint8(i >> 8), uint8(i >> 16), uint8(i >> 24), uint8(i >> 32), uint8(i >> 40), uint8(i >> 48), @@ -106,6 +110,9 @@ func (littleEndian) WriteUint56(b *bytes.Buffer, i uint64) { // WriteUint48 writes 48 bit of a uint64 func (littleEndian) WriteUint48(b *bytes.Buffer, i uint64) { + if i >= (1 << 48) { + panic(fmt.Sprintf("%#x doesn't fit into 48 bits", i)) + } b.Write([]byte{ uint8(i), uint8(i >> 8), uint8(i >> 16), uint8(i >> 24), uint8(i >> 32), uint8(i >> 40), @@ -114,6 +121,9 @@ func (littleEndian) WriteUint48(b *bytes.Buffer, i uint64) { // WriteUint40 writes 40 bit of a uint64 func (littleEndian) WriteUint40(b *bytes.Buffer, i uint64) { + if i >= (1 << 40) { + panic(fmt.Sprintf("%#x doesn't fit into 40 bits", i)) + } b.Write([]byte{ uint8(i), uint8(i >> 8), uint8(i >> 16), uint8(i >> 24), uint8(i >> 32), @@ -127,6 +137,9 @@ func (littleEndian) WriteUint32(b *bytes.Buffer, i uint32) { // WriteUint24 writes 24 bit of a uint32 func (littleEndian) WriteUint24(b *bytes.Buffer, i uint32) { + if i >= (1 << 24) { + panic(fmt.Sprintf("%#x doesn't fit into 24 bits", i)) + } b.Write([]byte{uint8(i), uint8(i >> 8), uint8(i >> 16)}) } diff --git a/internal/utils/byteorder_little_endian_test.go b/internal/utils/byteorder_little_endian_test.go index 6e8cfcef..60508aaf 100644 --- a/internal/utils/byteorder_little_endian_test.go +++ b/internal/utils/byteorder_little_endian_test.go @@ -83,10 +83,16 @@ var _ = Describe("Little Endian encoding / decoding", func() { }) It("outputs a little endian", func() { - num := uint32(0xEFAC3512) + num := uint32(0x010203) b := &bytes.Buffer{} LittleEndian.WriteUint24(b, num) - Expect(b.Bytes()).To(Equal([]byte{0x12, 0x35, 0xAC})) + Expect(b.Bytes()).To(Equal([]byte{0x03, 0x02, 0x01})) + }) + + It("panics if the value doesn't fit into 24 bits", func() { + num := uint32(0x01020304) + b := &bytes.Buffer{} + Expect(func() { LittleEndian.WriteUint24(b, num) }).Should(Panic()) }) }) @@ -113,10 +119,16 @@ var _ = Describe("Little Endian encoding / decoding", func() { }) It("outputs a little endian", func() { - num := uint64(0xDEADBEEFCAFE) + num := uint64(0x0102030405) b := &bytes.Buffer{} LittleEndian.WriteUint40(b, num) - Expect(b.Bytes()).To(Equal([]byte{0xFE, 0xCA, 0xEF, 0xBE, 0xAD})) + Expect(b.Bytes()).To(Equal([]byte{0x05, 0x04, 0x03, 0x02, 0x01})) + }) + + It("panics if the value doesn't fit into 40 bits", func() { + num := uint64(0x010203040506) + b := &bytes.Buffer{} + Expect(func() { LittleEndian.WriteUint40(b, num) }).Should(Panic()) }) }) @@ -134,12 +146,10 @@ var _ = Describe("Little Endian encoding / decoding", func() { Expect(b.Bytes()).To(Equal([]byte{0xFE, 0xCA, 0xEF, 0xBE, 0xAD, 0xDE})) }) - It("doesn't care about the two higher order bytes", func() { - num := uint64(0x1337DEADBEEFCAFE) + It("panics if the value doesn't fit into 48 bits", func() { + num := uint64(0xDEADBEEFCAFE01) b := &bytes.Buffer{} - LittleEndian.WriteUint48(b, num) - Expect(b.Len()).To(Equal(6)) - Expect(b.Bytes()).To(Equal([]byte{0xFE, 0xCA, 0xEF, 0xBE, 0xAD, 0xDE})) + Expect(func() { LittleEndian.WriteUint48(b, num) }).Should(Panic()) }) }) @@ -151,11 +161,17 @@ var _ = Describe("Little Endian encoding / decoding", func() { }) It("outputs a little endian", func() { - num := uint64(0xFFEEDDCCBBAA9988) + num := uint64(0xEEDDCCBBAA9988) b := &bytes.Buffer{} LittleEndian.WriteUint56(b, num) Expect(b.Bytes()).To(Equal([]byte{0x88, 0x99, 0xAA, 0xBB, 0xCC, 0xDD, 0xEE})) }) + + It("panics if the value doesn't fit into 56 bits", func() { + num := uint64(0xEEDDCCBBAA998801) + b := &bytes.Buffer{} + Expect(func() { LittleEndian.WriteUint56(b, num) }).Should(Panic()) + }) }) Context("WriteUint64", func() { diff --git a/public_header.go b/public_header.go index 0c89e940..c78a395e 100644 --- a/public_header.go +++ b/public_header.go @@ -102,7 +102,7 @@ func (h *PublicHeader) Write(b *bytes.Buffer, version protocol.VersionNumber, pe case protocol.PacketNumberLen4: utils.LittleEndian.WriteUint32(b, uint32(h.PacketNumber)) case protocol.PacketNumberLen6: - utils.LittleEndian.WriteUint48(b, uint64(h.PacketNumber)) + utils.LittleEndian.WriteUint48(b, uint64(h.PacketNumber)&(1<<48-1)) default: return errPacketNumberLenNotSet }