diff --git a/frames/ack_frame.go b/frames/ack_frame.go index 274ab177..4e718311 100644 --- a/frames/ack_frame.go +++ b/frames/ack_frame.go @@ -21,7 +21,7 @@ type AckFrame struct { } // Write writes an ACK frame. -func (f *AckFrame) Write(b *bytes.Buffer, packetNumber protocol.PacketNumber, packetNumberLen uint8) error { +func (f *AckFrame) Write(b *bytes.Buffer, packetNumber protocol.PacketNumber, packetNumberLen protocol.PacketNumberLen) error { typeByte := uint8(0x40 | 0x0C) if f.HasNACK() { diff --git a/frames/ack_frame_test.go b/frames/ack_frame_test.go index ad8963bb..53dc640f 100644 --- a/frames/ack_frame_test.go +++ b/frames/ack_frame_test.go @@ -238,7 +238,7 @@ var _ = Describe("AckFrame", func() { Entropy: 2, LargestObserved: 1, } - err := frame.Write(b, 1, 6) + err := frame.Write(b, 1, protocol.PacketNumberLen6) Expect(err).ToNot(HaveOccurred()) Expect(b.Bytes()).To(Equal([]byte{0x4c, 0x02, 0x01, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0})) }) @@ -249,7 +249,7 @@ var _ = Describe("AckFrame", func() { LargestObserved: 4, NackRanges: []NackRange{NackRange{FirstPacketNumber: 2, LastPacketNumber: 2}}, } - err := frame.Write(b, 1, 6) + err := frame.Write(b, 1, protocol.PacketNumberLen6) Expect(err).ToNot(HaveOccurred()) missingPacketBytes := b.Bytes()[b.Len()-8:] Expect(missingPacketBytes[0]).To(Equal(uint8(1))) // numRanges @@ -267,7 +267,7 @@ var _ = Describe("AckFrame", func() { LargestObserved: 7, NackRanges: []NackRange{nackRange1, nackRange2}, } - err := frame.Write(b, 1, 6) + err := frame.Write(b, 1, protocol.PacketNumberLen6) Expect(err).ToNot(HaveOccurred()) missingPacketBytes := b.Bytes()[b.Len()-(1+2*7):] Expect(missingPacketBytes[0]).To(Equal(uint8(2))) // numRanges @@ -288,7 +288,7 @@ var _ = Describe("AckFrame", func() { LargestObserved: 258, NackRanges: []NackRange{NackRange{FirstPacketNumber: 2, LastPacketNumber: 257}}, } - err := frame.Write(b, 1, 6) + err := frame.Write(b, 1, protocol.PacketNumberLen6) Expect(err).ToNot(HaveOccurred()) missingPacketBytes := b.Bytes()[b.Len()-(1+7):] Expect(missingPacketBytes[0]).To(Equal(uint8(1))) // numRanges @@ -302,7 +302,7 @@ var _ = Describe("AckFrame", func() { LargestObserved: 302, NackRanges: []NackRange{NackRange{FirstPacketNumber: 2, LastPacketNumber: 301}}, } - err := frame.Write(b, 1, 6) + err := frame.Write(b, 1, protocol.PacketNumberLen6) Expect(err).ToNot(HaveOccurred()) missingPacketBytes := b.Bytes()[b.Len()-(1+2*7):] Expect(missingPacketBytes[0]).To(Equal(uint8(2))) // numRanges @@ -318,7 +318,7 @@ var _ = Describe("AckFrame", func() { LargestObserved: 259, NackRanges: []NackRange{NackRange{FirstPacketNumber: 2, LastPacketNumber: 258}}, } - err := frame.Write(b, 1, 6) + err := frame.Write(b, 1, protocol.PacketNumberLen6) Expect(err).ToNot(HaveOccurred()) missingPacketBytes := b.Bytes()[b.Len()-(1+2*7):] Expect(missingPacketBytes[0]).To(Equal(uint8(2))) // numRanges @@ -334,7 +334,7 @@ var _ = Describe("AckFrame", func() { LargestObserved: 603, NackRanges: []NackRange{NackRange{FirstPacketNumber: 2, LastPacketNumber: 601}}, } - err := frame.Write(b, 1, 6) + err := frame.Write(b, 1, protocol.PacketNumberLen6) Expect(err).ToNot(HaveOccurred()) missingPacketBytes := b.Bytes()[b.Len()-(1+3*7):] Expect(missingPacketBytes[0]).To(Equal(uint8(3))) // numRanges @@ -354,7 +354,7 @@ var _ = Describe("AckFrame", func() { LargestObserved: 655, NackRanges: []NackRange{nackRange2, nackRange1}, } - err := frame.Write(b, 1, 6) + err := frame.Write(b, 1, protocol.PacketNumberLen6) Expect(err).ToNot(HaveOccurred()) missingPacketBytes := b.Bytes()[b.Len()-(1+4*7):] Expect(missingPacketBytes[0]).To(Equal(uint8(4))) // numRanges @@ -375,7 +375,7 @@ var _ = Describe("AckFrame", func() { Entropy: 2, LargestObserved: 1, } - f.Write(b, 1, 6) + f.Write(b, 1, protocol.PacketNumberLen6) Expect(f.MinLength()).To(Equal(b.Len())) }) @@ -385,7 +385,7 @@ var _ = Describe("AckFrame", func() { LargestObserved: 4, NackRanges: []NackRange{NackRange{FirstPacketNumber: 2, LastPacketNumber: 2}}, } - err := f.Write(b, 1, 6) + err := f.Write(b, 1, protocol.PacketNumberLen6) Expect(err).ToNot(HaveOccurred()) Expect(f.MinLength()).To(Equal(b.Len())) }) @@ -403,7 +403,7 @@ var _ = Describe("AckFrame", func() { Entropy: 0xDE, LargestObserved: 6789, } - err := frameOrig.Write(b, 1, 6) + err := frameOrig.Write(b, 1, protocol.PacketNumberLen6) Expect(err).ToNot(HaveOccurred()) frame, err := ParseAckFrame(bytes.NewReader(b.Bytes())) Expect(err).ToNot(HaveOccurred()) @@ -421,7 +421,7 @@ var _ = Describe("AckFrame", func() { LargestObserved: 15, NackRanges: nackRanges, } - err := frameOrig.Write(b, 1, 6) + err := frameOrig.Write(b, 1, protocol.PacketNumberLen6) Expect(err).ToNot(HaveOccurred()) r := bytes.NewReader(b.Bytes()) frame, err := ParseAckFrame(r) @@ -441,7 +441,7 @@ var _ = Describe("AckFrame", func() { LargestObserved: 1600, NackRanges: nackRanges, } - err := frameOrig.Write(b, 1, 6) + err := frameOrig.Write(b, 1, protocol.PacketNumberLen6) Expect(err).ToNot(HaveOccurred()) r := bytes.NewReader(b.Bytes()) frame, err := ParseAckFrame(r) diff --git a/frames/blocked_frame.go b/frames/blocked_frame.go index a0ac1dd4..14bdb1ee 100644 --- a/frames/blocked_frame.go +++ b/frames/blocked_frame.go @@ -13,7 +13,7 @@ type BlockedFrame struct { } //Write writes a RST_STREAM frame -func (f *BlockedFrame) Write(b *bytes.Buffer, packetNumber protocol.PacketNumber, packetNumberLen uint8) error { +func (f *BlockedFrame) Write(b *bytes.Buffer, packetNumber protocol.PacketNumber, packetNumberLen protocol.PacketNumberLen) error { b.WriteByte(0x05) if f.StreamID == 0 { diff --git a/frames/blocked_frame_test.go b/frames/blocked_frame_test.go index bac38376..8d19971e 100644 --- a/frames/blocked_frame_test.go +++ b/frames/blocked_frame_test.go @@ -22,7 +22,7 @@ var _ = Describe("BlockedFrame", func() { It("writes a sample frame", func() { b := &bytes.Buffer{} frame := BlockedFrame{StreamID: 0x1337} - frame.Write(b, 10, 6) + frame.Write(b, 10, protocol.PacketNumberLen6) Expect(b.Bytes()).To(Equal([]byte{0x05, 0x37, 0x13, 0x0, 0x0})) }) diff --git a/frames/connection_close_frame.go b/frames/connection_close_frame.go index 101f37bf..0590034f 100644 --- a/frames/connection_close_frame.go +++ b/frames/connection_close_frame.go @@ -52,7 +52,7 @@ func (f *ConnectionCloseFrame) MinLength() int { } // Write writes an CONNECTION_CLOSE frame. -func (f *ConnectionCloseFrame) Write(b *bytes.Buffer, packetNumber protocol.PacketNumber, packetNumberLen uint8) error { +func (f *ConnectionCloseFrame) Write(b *bytes.Buffer, packetNumber protocol.PacketNumber, packetNumberLen protocol.PacketNumberLen) error { b.WriteByte(0x02) utils.WriteUint32(b, uint32(f.ErrorCode)) diff --git a/frames/connection_close_frame_test.go b/frames/connection_close_frame_test.go index 8a318532..87a4dddf 100644 --- a/frames/connection_close_frame_test.go +++ b/frames/connection_close_frame_test.go @@ -35,7 +35,7 @@ var _ = Describe("ConnectionCloseFrame", func() { frame := &ConnectionCloseFrame{ ErrorCode: 0xDEADBEEF, } - err := frame.Write(b, 1, 6) + err := frame.Write(b, 1, protocol.PacketNumberLen6) Expect(err).ToNot(HaveOccurred()) Expect(b.Len()).To(Equal(1 + 2 + 4)) Expect(b.Bytes()).To(Equal([]byte{0x02, 0xEF, 0xBE, 0xAD, 0xDE, 0x00, 0x00})) @@ -47,7 +47,7 @@ var _ = Describe("ConnectionCloseFrame", func() { ErrorCode: 0xDEADBEEF, ReasonPhrase: "foobar", } - err := frame.Write(b, 1, 6) + err := frame.Write(b, 1, protocol.PacketNumberLen6) Expect(err).ToNot(HaveOccurred()) Expect(b.Len()).To(Equal(1 + 2 + 4 + len(frame.ReasonPhrase))) Expect(b.Bytes()[:5]).To(Equal([]byte{0x02, 0xEF, 0xBE, 0xAD, 0xDE})) @@ -67,7 +67,7 @@ var _ = Describe("ConnectionCloseFrame", func() { ErrorCode: 0xDEADBEEF, ReasonPhrase: reasonPhrase, } - err := frame.Write(b, 1, 6) + err := frame.Write(b, 1, protocol.PacketNumberLen6) Expect(err).To(HaveOccurred()) }) @@ -77,7 +77,7 @@ var _ = Describe("ConnectionCloseFrame", func() { ErrorCode: 0xDEADBEEF, ReasonPhrase: "foobar", } - f.Write(b, 1, 6) + f.Write(b, 1, protocol.PacketNumberLen6) Expect(f.MinLength()).To(Equal(b.Len())) }) }) @@ -88,7 +88,7 @@ var _ = Describe("ConnectionCloseFrame", func() { ErrorCode: 0xDEADBEEF, ReasonPhrase: "Lorem ipsum dolor sit amet.", } - err := frame.Write(b, 1, 6) + err := frame.Write(b, 1, protocol.PacketNumberLen6) Expect(err).ToNot(HaveOccurred()) readframe, err := ParseConnectionCloseFrame(bytes.NewReader(b.Bytes())) Expect(err).ToNot(HaveOccurred()) diff --git a/frames/frame.go b/frames/frame.go index 1484d69d..c5aebd5f 100644 --- a/frames/frame.go +++ b/frames/frame.go @@ -8,6 +8,6 @@ import ( // A Frame in QUIC type Frame interface { - Write(b *bytes.Buffer, packetNumber protocol.PacketNumber, packetNumberLen uint8) error + Write(b *bytes.Buffer, packetNumber protocol.PacketNumber, packetNumberLen protocol.PacketNumberLen) error MinLength() int } diff --git a/frames/ping_frame.go b/frames/ping_frame.go index 5cfd63f7..6069f7c4 100644 --- a/frames/ping_frame.go +++ b/frames/ping_frame.go @@ -21,7 +21,7 @@ func ParsePingFrame(r *bytes.Reader) (*PingFrame, error) { return frame, nil } -func (f *PingFrame) Write(b *bytes.Buffer, packetNumber protocol.PacketNumber, packetNumberLen uint8) error { +func (f *PingFrame) Write(b *bytes.Buffer, packetNumber protocol.PacketNumber, packetNumberLen protocol.PacketNumberLen) error { typeByte := uint8(0x07) b.WriteByte(typeByte) return nil diff --git a/frames/ping_frame_test.go b/frames/ping_frame_test.go index 5cec3134..9ded2341 100644 --- a/frames/ping_frame_test.go +++ b/frames/ping_frame_test.go @@ -3,6 +3,7 @@ package frames import ( "bytes" + "github.com/lucas-clemente/quic-go/protocol" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) @@ -21,7 +22,7 @@ var _ = Describe("PingFrame", func() { It("writes a sample frame", func() { b := &bytes.Buffer{} frame := PingFrame{} - frame.Write(b, 10, 6) + frame.Write(b, 10, protocol.PacketNumberLen6) Expect(b.Bytes()).To(Equal([]byte{0x07})) }) diff --git a/frames/rst_stream_frame.go b/frames/rst_stream_frame.go index 63917bd5..e93d7a51 100644 --- a/frames/rst_stream_frame.go +++ b/frames/rst_stream_frame.go @@ -15,7 +15,7 @@ type RstStreamFrame struct { } //Write writes a RST_STREAM frame -func (f *RstStreamFrame) Write(b *bytes.Buffer, packetNumber protocol.PacketNumber, packetNumberLen uint8) error { +func (f *RstStreamFrame) Write(b *bytes.Buffer, packetNumber protocol.PacketNumber, packetNumberLen protocol.PacketNumberLen) error { panic("RstStreamFrame: Write not yet implemented") } diff --git a/frames/stop_waiting_frame.go b/frames/stop_waiting_frame.go index e1a41ba0..24e945a2 100644 --- a/frames/stop_waiting_frame.go +++ b/frames/stop_waiting_frame.go @@ -14,7 +14,7 @@ type StopWaitingFrame struct { LeastUnacked protocol.PacketNumber } -func (f *StopWaitingFrame) Write(b *bytes.Buffer, packetNumber protocol.PacketNumber, packetNumberLen uint8) error { +func (f *StopWaitingFrame) Write(b *bytes.Buffer, packetNumber protocol.PacketNumber, packetNumberLen protocol.PacketNumberLen) error { // packetNumber is the packet number of the packet that this StopWaitingFrame will be sent with typeByte := uint8(0x06) b.WriteByte(typeByte) @@ -36,7 +36,7 @@ func (f *StopWaitingFrame) MinLength() int { } // ParseStopWaitingFrame parses a StopWaiting frame -func ParseStopWaitingFrame(r *bytes.Reader, packetNumber protocol.PacketNumber, packetNumberLen uint8) (*StopWaitingFrame, error) { +func ParseStopWaitingFrame(r *bytes.Reader, packetNumber protocol.PacketNumber, packetNumberLen protocol.PacketNumberLen) (*StopWaitingFrame, error) { frame := &StopWaitingFrame{} // read the TypeByte @@ -50,7 +50,7 @@ func ParseStopWaitingFrame(r *bytes.Reader, packetNumber protocol.PacketNumber, return nil, err } - leastUnackedDelta, err := utils.ReadUintN(r, packetNumberLen) + leastUnackedDelta, err := utils.ReadUintN(r, uint8(packetNumberLen)) if err != nil { return nil, err } diff --git a/frames/stop_waiting_frame_test.go b/frames/stop_waiting_frame_test.go index 38e11cb6..a4ce9592 100644 --- a/frames/stop_waiting_frame_test.go +++ b/frames/stop_waiting_frame_test.go @@ -34,7 +34,7 @@ var _ = Describe("StopWaitingFrame", func() { LeastUnacked: 10, Entropy: 0xE, } - frame.Write(b, packetNumber, 6) + frame.Write(b, packetNumber, protocol.PacketNumberLen6) Expect(b.Bytes()[0]).To(Equal(uint8(0x06))) // todo: check more }) @@ -48,7 +48,7 @@ var _ = Describe("StopWaitingFrame", func() { Entropy: 0xE, } b := &bytes.Buffer{} - frame.Write(b, packetNumber, 6) + frame.Write(b, packetNumber, protocol.PacketNumberLen6) readframe, err := ParseStopWaitingFrame(bytes.NewReader(b.Bytes()), packetNumber, 6) Expect(err).ToNot(HaveOccurred()) Expect(readframe.Entropy).To(Equal(frame.Entropy)) diff --git a/frames/stream_frame.go b/frames/stream_frame.go index c1dd43a3..a88efd2c 100644 --- a/frames/stream_frame.go +++ b/frames/stream_frame.go @@ -70,7 +70,7 @@ func ParseStreamFrame(r *bytes.Reader) (*StreamFrame, error) { } // WriteStreamFrame writes a stream frame. -func (f *StreamFrame) Write(b *bytes.Buffer, packetNumber protocol.PacketNumber, packetNumberLen uint8) error { +func (f *StreamFrame) Write(b *bytes.Buffer, packetNumber protocol.PacketNumber, packetNumberLen protocol.PacketNumberLen) error { typeByte := uint8(0x80) if f.FinBit { typeByte ^= 0x40 diff --git a/frames/stream_frame_test.go b/frames/stream_frame_test.go index 61860233..ee8f79b3 100644 --- a/frames/stream_frame_test.go +++ b/frames/stream_frame_test.go @@ -37,7 +37,7 @@ var _ = Describe("StreamFrame", func() { (&StreamFrame{ StreamID: 1, Data: []byte("foobar"), - }).Write(b, 1, 6) + }).Write(b, 1, protocol.PacketNumberLen6) Expect(b.Bytes()).To(Equal([]byte{0xa3, 0x1, 0, 0, 0, 0x06, 0x00, 'f', 'o', 'o', 'b', 'a', 'r'})) }) @@ -47,7 +47,7 @@ var _ = Describe("StreamFrame", func() { StreamID: 1, Offset: 16, Data: []byte("foobar"), - }).Write(b, 1, 6) + }).Write(b, 1, protocol.PacketNumberLen6) Expect(b.Bytes()).To(Equal([]byte{0xbf, 0x1, 0, 0, 0, 0x10, 0, 0, 0, 0, 0, 0, 0, 0x06, 0x00, 'f', 'o', 'o', 'b', 'a', 'r'})) }) @@ -58,7 +58,7 @@ var _ = Describe("StreamFrame", func() { Data: []byte("f"), Offset: 1, } - f.Write(b, 1, 6) + f.Write(b, 1, protocol.PacketNumberLen6) Expect(f.MinLength()).To(Equal(b.Len())) }) }) diff --git a/frames/window_update_frame.go b/frames/window_update_frame.go index 886c5860..e9604931 100644 --- a/frames/window_update_frame.go +++ b/frames/window_update_frame.go @@ -14,7 +14,7 @@ type WindowUpdateFrame struct { } //Write writes a RST_STREAM frame -func (f *WindowUpdateFrame) Write(b *bytes.Buffer, packetNumber protocol.PacketNumber, packetNumberLen uint8) error { +func (f *WindowUpdateFrame) Write(b *bytes.Buffer, packetNumber protocol.PacketNumber, packetNumberLen protocol.PacketNumberLen) error { panic("WindowUpdateFrame: Write not yet implemented") } diff --git a/packet_number.go b/packet_number.go index b8eac122..67b82eb9 100644 --- a/packet_number.go +++ b/packet_number.go @@ -2,8 +2,8 @@ package quic import "github.com/lucas-clemente/quic-go/protocol" -func calculatePacketNumber(packetNumberLength uint8, lastPacketNumber protocol.PacketNumber, wirePacketNumber protocol.PacketNumber) protocol.PacketNumber { - epochDelta := protocol.PacketNumber(1) << (packetNumberLength * 8) +func calculatePacketNumber(packetNumberLength protocol.PacketNumberLen, lastPacketNumber protocol.PacketNumber, wirePacketNumber protocol.PacketNumber) protocol.PacketNumber { + epochDelta := protocol.PacketNumber(1) << (uint8(packetNumberLength) * 8) epoch := lastPacketNumber & ^(epochDelta - 1) prevEpochBegin := epoch - epochDelta nextEpochBegin := epoch + epochDelta diff --git a/packet_number_test.go b/packet_number_test.go index 2d706cd8..f1066b86 100644 --- a/packet_number_test.go +++ b/packet_number_test.go @@ -11,13 +11,13 @@ import ( // Tests taken and extended from chrome var _ = Describe("packet number calculation", func() { - check := func(length uint8, expected, last uint64) { + check := func(length protocol.PacketNumberLen, expected, last uint64) { epoch := uint64(1) << (length * 8) epochMask := epoch - 1 wirePacketNumber := expected & epochMask Expect(calculatePacketNumber(length, protocol.PacketNumber(last), protocol.PacketNumber(wirePacketNumber))).To(Equal(protocol.PacketNumber(expected))) } - for _, length := range []uint8{1, 2, 4, 6} { + for _, length := range []protocol.PacketNumberLen{protocol.PacketNumberLen1, protocol.PacketNumberLen2, protocol.PacketNumberLen4, protocol.PacketNumberLen6} { Context(fmt.Sprintf("with %d bytes", length), func() { epoch := uint64(1) << (length * 8) epochMask := epoch - 1 diff --git a/protocol/protocol.go b/protocol/protocol.go index 32c7b2a3..40428eb3 100644 --- a/protocol/protocol.go +++ b/protocol/protocol.go @@ -3,6 +3,20 @@ package protocol // A PacketNumber in QUIC type PacketNumber uint64 +// PacketNumberLen is the length of the packet number in bytes +type PacketNumberLen uint8 + +const ( + // PacketNumberLen1 is a packet number length of 1 byte + PacketNumberLen1 PacketNumberLen = 1 + // PacketNumberLen2 is a packet number length of 2 bytes + PacketNumberLen2 PacketNumberLen = 2 + // PacketNumberLen4 is a packet number length of 4 bytes + PacketNumberLen4 PacketNumberLen = 4 + // PacketNumberLen6 is a packet number length of 6 bytes + PacketNumberLen6 PacketNumberLen = 6 +) + // A ConnectionID in QUIC type ConnectionID uint64 diff --git a/public_header.go b/public_header.go index 8989278e..1da588e0 100644 --- a/public_header.go +++ b/public_header.go @@ -24,7 +24,7 @@ type PublicHeader struct { TruncateConnectionID bool VersionNumber protocol.VersionNumber QuicVersion uint32 - PacketNumberLen uint8 + PacketNumberLen protocol.PacketNumberLen PacketNumber protocol.PacketNumber } @@ -77,13 +77,13 @@ func ParsePublicHeader(b io.ByteReader) (*PublicHeader, error) { switch publicFlagByte & 0x30 { case 0x30: - header.PacketNumberLen = 6 + header.PacketNumberLen = protocol.PacketNumberLen6 case 0x20: - header.PacketNumberLen = 4 + header.PacketNumberLen = protocol.PacketNumberLen4 case 0x10: - header.PacketNumberLen = 2 + header.PacketNumberLen = protocol.PacketNumberLen2 case 0x00: - header.PacketNumberLen = 1 + header.PacketNumberLen = protocol.PacketNumberLen1 } // Connection ID @@ -107,7 +107,7 @@ func ParsePublicHeader(b io.ByteReader) (*PublicHeader, error) { } // Packet number - packetNumber, err := utils.ReadUintN(b, header.PacketNumberLen) + packetNumber, err := utils.ReadUintN(b, uint8(header.PacketNumberLen)) if err != nil { return nil, err }