diff --git a/ackhandler/ack_handler.go b/ackhandler/ack_handler.go new file mode 100644 index 00000000..b05637cc --- /dev/null +++ b/ackhandler/ack_handler.go @@ -0,0 +1,52 @@ +package ackhandler + +import "github.com/lucas-clemente/quic-go/protocol" + +// The AckHandler handles ACKs +type AckHandler struct { + LargestObserved protocol.PacketNumber + Observed map[protocol.PacketNumber]bool +} + +// NewAckHandler creates a new AckHandler +func NewAckHandler() (*AckHandler, error) { + ackHandler := &AckHandler{ + Observed: make(map[protocol.PacketNumber]bool), + } + return ackHandler, nil +} + +// HandlePacket handles a packet +func (h *AckHandler) HandlePacket(packetNumber protocol.PacketNumber) { + if packetNumber > h.LargestObserved { + h.LargestObserved = packetNumber + } + h.Observed[packetNumber] = true +} + +// GetNackRanges gets all the NACK ranges +func (h *AckHandler) GetNackRanges() []*NackRange { + // ToDo: improve performance + var ranges []*NackRange + inRange := false + // ToDo: fix types + for i := 0; i < int(h.LargestObserved); i++ { + packetNumber := protocol.PacketNumber(i) + _, ok := h.Observed[packetNumber] + if !ok { + if !inRange { + r := &NackRange{ + FirstPacketNumber: packetNumber, + Length: 1, + } + ranges = append(ranges, r) + inRange = true + } else { + ranges[len(ranges)-1].Length++ + } + } else { + inRange = false + } + } + return ranges +} diff --git a/ackhandler/ack_handler_test.go b/ackhandler/ack_handler_test.go new file mode 100644 index 00000000..145dbd55 --- /dev/null +++ b/ackhandler/ack_handler_test.go @@ -0,0 +1,84 @@ +package ackhandler + +import ( + "github.com/lucas-clemente/quic-go/protocol" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("AckHandler", func() { + It("Returns no NACK ranges for continously received packets", func() { + ackHandler, _ := NewAckHandler() + for i := 0; i < 100; i++ { + ackHandler.HandlePacket(protocol.PacketNumber(i)) + } + Expect(ackHandler.LargestObserved).To(Equal(protocol.PacketNumber(99))) + Expect(len(ackHandler.GetNackRanges())).To(Equal(0)) + }) + + It("handles a single lost package", func() { + ackHandler, _ := NewAckHandler() + for i := 0; i < 10; i++ { + if i == 5 { + continue + } + ackHandler.HandlePacket(protocol.PacketNumber(i)) + } + Expect(ackHandler.LargestObserved).To(Equal(protocol.PacketNumber(9))) + nackRanges := ackHandler.GetNackRanges() + Expect(len(nackRanges)).To(Equal(1)) + Expect(nackRanges[0].FirstPacketNumber).To(Equal(protocol.PacketNumber(5))) + Expect(nackRanges[0].Length).To(Equal(uint8(1))) + }) + + It("handles two consecutive lost packages", func() { + ackHandler, _ := NewAckHandler() + for i := 0; i < 10; i++ { + if i == 5 || i == 6 { + continue + } + ackHandler.HandlePacket(protocol.PacketNumber(i)) + } + Expect(ackHandler.LargestObserved).To(Equal(protocol.PacketNumber(9))) + nackRanges := ackHandler.GetNackRanges() + Expect(len(nackRanges)).To(Equal(1)) + Expect(nackRanges[0].FirstPacketNumber).To(Equal(protocol.PacketNumber(5))) + Expect(nackRanges[0].Length).To(Equal(uint8(2))) + }) + + It("handles two non-consecutively lost packages", func() { + ackHandler, _ := NewAckHandler() + for i := 0; i < 10; i++ { + if i == 3 || i == 7 { + continue + } + ackHandler.HandlePacket(protocol.PacketNumber(i)) + } + Expect(ackHandler.LargestObserved).To(Equal(protocol.PacketNumber(9))) + nackRanges := ackHandler.GetNackRanges() + Expect(len(nackRanges)).To(Equal(2)) + Expect(nackRanges[0].FirstPacketNumber).To(Equal(protocol.PacketNumber(3))) + Expect(nackRanges[0].Length).To(Equal(uint8(1))) + Expect(nackRanges[1].FirstPacketNumber).To(Equal(protocol.PacketNumber(7))) + Expect(nackRanges[1].Length).To(Equal(uint8(1))) + }) + + It("handles two sequences of lost packages", func() { + ackHandler, _ := NewAckHandler() + for i := 0; i < 10; i++ { + if i == 2 || i == 3 || i == 4 || i == 7 || i == 8 { + continue + } + ackHandler.HandlePacket(protocol.PacketNumber(i)) + } + Expect(ackHandler.LargestObserved).To(Equal(protocol.PacketNumber(9))) + nackRanges := ackHandler.GetNackRanges() + Expect(len(nackRanges)).To(Equal(2)) + Expect(nackRanges[0].FirstPacketNumber).To(Equal(protocol.PacketNumber(2))) + Expect(nackRanges[0].Length).To(Equal(uint8(3))) + Expect(nackRanges[1].FirstPacketNumber).To(Equal(protocol.PacketNumber(7))) + Expect(nackRanges[1].Length).To(Equal(uint8(2))) + }) + +}) diff --git a/ackhandler/ackhandler_suite_test.go b/ackhandler/ackhandler_suite_test.go new file mode 100644 index 00000000..53108c19 --- /dev/null +++ b/ackhandler/ackhandler_suite_test.go @@ -0,0 +1,13 @@ +package ackhandler + +import ( + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + + "testing" +) + +func TestCrypto(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "AckHandler Suite") +}