diff --git a/frames/ack_frame.go b/frames/ack_frame.go index d88f00c2..22fa55a2 100644 --- a/frames/ack_frame.go +++ b/frames/ack_frame.go @@ -25,7 +25,7 @@ type AckFrame struct { } // Write writes an ACK frame. -func (f *AckFrame) Write(b *bytes.Buffer) error { +func (f *AckFrame) Write(b *bytes.Buffer, packetNumber protocol.PacketNumber, packetNumberLen uint8) error { typeByte := uint8(0x40 | 0x0C) if f.HasNACK() { diff --git a/frames/ack_frame_test.go b/frames/ack_frame_test.go index b519b1b6..ebbf38d1 100644 --- a/frames/ack_frame_test.go +++ b/frames/ack_frame_test.go @@ -192,7 +192,7 @@ var _ = Describe("AckFrame", func() { Entropy: 2, LargestObserved: 1, } - err := frame.Write(b) + err := frame.Write(b, 1, 6) Expect(err).ToNot(HaveOccurred()) Expect(b.Bytes()).To(Equal([]byte{0x4c, 0x02, 0x01, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0})) }) @@ -208,7 +208,7 @@ var _ = Describe("AckFrame", func() { LargestObserved: 4, NackRanges: []NackRange{nackRange}, } - err := frame.Write(b) + err := frame.Write(b, 1, 6) Expect(err).ToNot(HaveOccurred()) missingPacketBytes := b.Bytes()[b.Len()-8:] Expect(missingPacketBytes[0]).To(Equal(uint8(1))) // numRanges @@ -233,7 +233,7 @@ var _ = Describe("AckFrame", func() { LargestObserved: 7, NackRanges: []NackRange{nackRange1, nackRange2}, } - err := frame.Write(b) + err := frame.Write(b, 1, 6) Expect(err).ToNot(HaveOccurred()) missingPacketBytes := b.Bytes()[b.Len()-(1+2*7):] Expect(missingPacketBytes[0]).To(Equal(uint8(2))) // numRanges @@ -253,7 +253,7 @@ var _ = Describe("AckFrame", func() { Entropy: 2, LargestObserved: 1, } - f.Write(b) + f.Write(b, 1, 6) Expect(f.MaxLength()).To(Equal(b.Len())) }) @@ -269,7 +269,7 @@ var _ = Describe("AckFrame", func() { }, }, } - err := f.Write(b) + err := f.Write(b, 1, 6) Expect(err).ToNot(HaveOccurred()) Expect(f.MaxLength()).To(Equal(b.Len())) }) @@ -282,7 +282,7 @@ var _ = Describe("AckFrame", func() { Entropy: 0xDE, LargestObserved: 6789, } - err := frameOrig.Write(b) + err := frameOrig.Write(b, 1, 6) Expect(err).ToNot(HaveOccurred()) frame, err := ParseAckFrame(bytes.NewReader(b.Bytes())) Expect(err).ToNot(HaveOccurred()) @@ -301,7 +301,7 @@ var _ = Describe("AckFrame", func() { LargestObserved: 15, NackRanges: nackRanges, } - err := frameOrig.Write(b) + err := frameOrig.Write(b, 1, 6) Expect(err).ToNot(HaveOccurred()) r := bytes.NewReader(b.Bytes()) frame, err := ParseAckFrame(r) diff --git a/frames/connection_close_frame.go b/frames/connection_close_frame.go index faa1b305..86e14020 100644 --- a/frames/connection_close_frame.go +++ b/frames/connection_close_frame.go @@ -52,7 +52,7 @@ func (f *ConnectionCloseFrame) MaxLength() int { } // Write writes an CONNECTION_CLOSE frame. -func (f *ConnectionCloseFrame) Write(b *bytes.Buffer) error { +func (f *ConnectionCloseFrame) Write(b *bytes.Buffer, packetNumber protocol.PacketNumber, packetNumberLen uint8) error { b.WriteByte(0x02) utils.WriteUint32(b, uint32(f.ErrorCode)) diff --git a/frames/connection_close_frame_test.go b/frames/connection_close_frame_test.go index 8ec7bc40..e653275e 100644 --- a/frames/connection_close_frame_test.go +++ b/frames/connection_close_frame_test.go @@ -35,7 +35,7 @@ var _ = Describe("ConnectionCloseFrame", func() { frame := &ConnectionCloseFrame{ ErrorCode: 0xDEADBEEF, } - err := frame.Write(b) + err := frame.Write(b, 1, 6) Expect(err).ToNot(HaveOccurred()) Expect(b.Len()).To(Equal(1 + 2 + 4)) Expect(b.Bytes()).To(Equal([]byte{0x02, 0xEF, 0xBE, 0xAD, 0xDE, 0x00, 0x00})) @@ -47,7 +47,7 @@ var _ = Describe("ConnectionCloseFrame", func() { ErrorCode: 0xDEADBEEF, ReasonPhrase: "foobar", } - err := frame.Write(b) + err := frame.Write(b, 1, 6) Expect(err).ToNot(HaveOccurred()) Expect(b.Len()).To(Equal(1 + 2 + 4 + len(frame.ReasonPhrase))) Expect(b.Bytes()[:5]).To(Equal([]byte{0x02, 0xEF, 0xBE, 0xAD, 0xDE})) @@ -67,7 +67,7 @@ var _ = Describe("ConnectionCloseFrame", func() { ErrorCode: 0xDEADBEEF, ReasonPhrase: reasonPhrase, } - err := frame.Write(b) + err := frame.Write(b, 1, 6) Expect(err).To(HaveOccurred()) }) @@ -77,7 +77,7 @@ var _ = Describe("ConnectionCloseFrame", func() { ErrorCode: 0xDEADBEEF, ReasonPhrase: "foobar", } - f.Write(b) + f.Write(b, 1, 6) Expect(f.MaxLength()).To(Equal(b.Len())) }) }) @@ -88,7 +88,7 @@ var _ = Describe("ConnectionCloseFrame", func() { ErrorCode: 0xDEADBEEF, ReasonPhrase: "Lorem ipsum dolor sit amet.", } - err := frame.Write(b) + err := frame.Write(b, 1, 6) Expect(err).ToNot(HaveOccurred()) readframe, err := ParseConnectionCloseFrame(bytes.NewReader(b.Bytes())) Expect(err).ToNot(HaveOccurred()) diff --git a/frames/frame.go b/frames/frame.go index 43a17bc8..a69c9421 100644 --- a/frames/frame.go +++ b/frames/frame.go @@ -1,9 +1,13 @@ package frames -import "bytes" +import ( + "bytes" + + "github.com/lucas-clemente/quic-go/protocol" +) // A Frame in QUIC type Frame interface { - Write(b *bytes.Buffer) error + Write(b *bytes.Buffer, packetNumber protocol.PacketNumber, packetNumberLen uint8) error MaxLength() int } diff --git a/frames/rst_stream_frame.go b/frames/rst_stream_frame.go index 634bd780..a1aa35e0 100644 --- a/frames/rst_stream_frame.go +++ b/frames/rst_stream_frame.go @@ -15,7 +15,7 @@ type RstStreamFrame struct { } //Write writes a RST_STREAM frame -func (f *RstStreamFrame) Write(b *bytes.Buffer) error { +func (f *RstStreamFrame) Write(b *bytes.Buffer, packetNumber protocol.PacketNumber, packetNumberLen uint8) error { panic("RstStreamFrame: Write not yet implemented") } diff --git a/frames/stop_waiting_frame.go b/frames/stop_waiting_frame.go index 3b163deb..7810c20c 100644 --- a/frames/stop_waiting_frame.go +++ b/frames/stop_waiting_frame.go @@ -3,6 +3,7 @@ package frames import ( "bytes" + "github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/utils" ) @@ -12,7 +13,7 @@ type StopWaitingFrame struct { LeastUnackedDelta uint64 } -func (f *StopWaitingFrame) Write(b *bytes.Buffer) error { +func (f *StopWaitingFrame) Write(b *bytes.Buffer, packetNumber protocol.PacketNumber, packetNumberLen uint8) error { panic("StopWaitingFrame: Write not yet implemented") } diff --git a/frames/stream_frame.go b/frames/stream_frame.go index 37d93e68..b0e529d0 100644 --- a/frames/stream_frame.go +++ b/frames/stream_frame.go @@ -69,7 +69,7 @@ func ParseStreamFrame(r *bytes.Reader) (*StreamFrame, error) { } // WriteStreamFrame writes a stream frame. -func (f *StreamFrame) Write(b *bytes.Buffer) error { +func (f *StreamFrame) Write(b *bytes.Buffer, packetNumber protocol.PacketNumber, packetNumberLen uint8) error { typeByte := uint8(0x80) if f.FinBit { typeByte ^= 0x40 diff --git a/frames/stream_frame_test.go b/frames/stream_frame_test.go index 5698ad17..efd8df35 100644 --- a/frames/stream_frame_test.go +++ b/frames/stream_frame_test.go @@ -37,7 +37,7 @@ var _ = Describe("StreamFrame", func() { (&StreamFrame{ StreamID: 1, Data: []byte("foobar"), - }).Write(b) + }).Write(b, 1, 6) Expect(b.Bytes()).To(Equal([]byte{0xa3, 0x1, 0, 0, 0, 0x06, 0x00, 'f', 'o', 'o', 'b', 'a', 'r'})) }) @@ -47,7 +47,7 @@ var _ = Describe("StreamFrame", func() { StreamID: 1, Offset: 16, Data: []byte("foobar"), - }).Write(b) + }).Write(b, 1, 6) Expect(b.Bytes()).To(Equal([]byte{0xbf, 0x1, 0, 0, 0, 0x10, 0, 0, 0, 0, 0, 0, 0, 0x06, 0x00, 'f', 'o', 'o', 'b', 'a', 'r'})) }) @@ -58,7 +58,7 @@ var _ = Describe("StreamFrame", func() { Data: []byte("f"), Offset: 1, } - f.Write(b) + f.Write(b, 1, 6) Expect(f.MaxLength()).To(Equal(b.Len())) }) }) diff --git a/frames/window_update_frame.go b/frames/window_update_frame.go index f0ff4345..df32feff 100644 --- a/frames/window_update_frame.go +++ b/frames/window_update_frame.go @@ -14,7 +14,7 @@ type WindowUpdateFrame struct { } //Write writes a RST_STREAM frame -func (f *WindowUpdateFrame) Write(b *bytes.Buffer) error { +func (f *WindowUpdateFrame) Write(b *bytes.Buffer, packetNumber protocol.PacketNumber, packetNumberLen uint8) error { panic("WindowUpdateFrame: Write not yet implemented") } diff --git a/packet_packer.go b/packet_packer.go index 53c0c020..fba9dd48 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -42,7 +42,12 @@ func (p *packetPacker) PackPacket() (*packedPacket, error) { return nil, nil } - payload, err := p.composeNextPayload() + currentPacketNumber := protocol.PacketNumber(atomic.AddUint64( + (*uint64)(&p.lastPacketNumber), + 1, + )) + + payload, err := p.composeNextPayload(currentPacketNumber) if err != nil { return nil, err } @@ -55,10 +60,6 @@ func (p *packetPacker) PackPacket() (*packedPacket, error) { payload[0] = 1 } - currentPacketNumber := protocol.PacketNumber(atomic.AddUint64( - (*uint64)(&p.lastPacketNumber), - 1, - )) var raw bytes.Buffer responsePublicHeader := PublicHeader{ ConnectionID: p.connectionID, @@ -83,7 +84,7 @@ func (p *packetPacker) PackPacket() (*packedPacket, error) { }, nil } -func (p *packetPacker) composeNextPayload() ([]byte, error) { +func (p *packetPacker) composeNextPayload(currentPacketNumber protocol.PacketNumber) ([]byte, error) { var payload bytes.Buffer payload.WriteByte(0) // The entropy bit is set in sendPayload @@ -112,7 +113,7 @@ func (p *packetPacker) composeNextPayload() ([]byte, error) { p.queuedFrames = p.queuedFrames[1:] } - if err := frame.Write(&payload); err != nil { + if err := frame.Write(&payload, currentPacketNumber, 6); err != nil { return nil, err } } diff --git a/packet_packer_test.go b/packet_packer_test.go index 548a5d6a..d540fce1 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -33,7 +33,7 @@ var _ = Describe("Packet packer", func() { Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) b := &bytes.Buffer{} - f.Write(b) + f.Write(b, 1, 6) Expect(p.payload).To(Equal(b.Bytes())) Expect(p.raw).To(ContainSubstring(string(b.Bytes()))) }) @@ -47,8 +47,8 @@ var _ = Describe("Packet packer", func() { Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) b := &bytes.Buffer{} - f1.Write(b) - f2.Write(b) + f1.Write(b, 2, 6) + f2.Write(b, 2, 6) Expect(p.payload).To(Equal(b.Bytes())) Expect(p.raw).To(ContainSubstring(string(b.Bytes()))) }) @@ -56,7 +56,7 @@ var _ = Describe("Packet packer", func() { It("packs many normal frames into 2 packets", func() { f := &frames.AckFrame{LargestObserved: 1} b := &bytes.Buffer{} - f.Write(b) + f.Write(b, 3, 6) for i := 0; i <= (protocol.MaxFrameSize-1)/b.Len()+1; i++ { packer.AddFrame(f) } @@ -78,7 +78,7 @@ var _ = Describe("Packet packer", func() { Offset: 1, } b := &bytes.Buffer{} - f.Write(b) + f.Write(b, 4, 6) packer.AddFrame(f) p, err := packer.PackPacket() Expect(p).ToNot(BeNil()) diff --git a/packet_unpacker_test.go b/packet_unpacker_test.go index bf59c2ba..c8c11498 100644 --- a/packet_unpacker_test.go +++ b/packet_unpacker_test.go @@ -49,7 +49,7 @@ var _ = Describe("Packet unpacker", func() { StreamID: 1, Data: []byte("foobar"), } - err := f.Write(buf) + err := f.Write(buf, 3, 6) Expect(err).ToNot(HaveOccurred()) setReader(buf.Bytes()) packet, err := unpacker.Unpack(hdrBin, hdr, r) @@ -62,7 +62,7 @@ var _ = Describe("Packet unpacker", func() { LargestObserved: 1, DelayTime: 1, } - err := f.Write(buf) + err := f.Write(buf, 3, 6) Expect(err).ToNot(HaveOccurred()) setReader(buf.Bytes()) packet, err := unpacker.Unpack(hdrBin, hdr, r) @@ -75,7 +75,7 @@ var _ = Describe("Packet unpacker", func() { LargestObserved: 1, DelayTime: 1, } - err := f.Write(buf) + err := f.Write(buf, 3, 6) Expect(err).ToNot(HaveOccurred()) setReader(buf.Bytes()) packet, err := unpacker.Unpack(hdrBin, hdr, r) @@ -111,7 +111,7 @@ var _ = Describe("Packet unpacker", func() { It("unpacks CONNECTION_CLOSE frames", func() { f := &frames.ConnectionCloseFrame{ReasonPhrase: "foo"} - err := f.Write(buf) + err := f.Write(buf, 6, 6) Expect(err).ToNot(HaveOccurred()) setReader(buf.Bytes()) packet, err := unpacker.Unpack(hdrBin, hdr, r)