fix race condition in ReceivedPacketHistory

This commit is contained in:
Marten Seemann
2016-06-27 17:47:27 +07:00
parent 497c57d54a
commit d14f85d4ec

View File

@@ -1,6 +1,8 @@
package ackhandlernew package ackhandlernew
import ( import (
"sync"
"github.com/lucas-clemente/quic-go/frames" "github.com/lucas-clemente/quic-go/frames"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/utils" "github.com/lucas-clemente/quic-go/utils"
@@ -8,6 +10,8 @@ import (
type receivedPacketHistory struct { type receivedPacketHistory struct {
ranges *utils.PacketIntervalList ranges *utils.PacketIntervalList
mutex sync.RWMutex
} }
// newReceivedPacketHistory creates a new received packet history // newReceivedPacketHistory creates a new received packet history
@@ -19,6 +23,9 @@ func newReceivedPacketHistory() *receivedPacketHistory {
// ReceivedPacket registers a packet with PacketNumber p and updates the ranges // ReceivedPacket registers a packet with PacketNumber p and updates the ranges
func (h *receivedPacketHistory) ReceivedPacket(p protocol.PacketNumber) { func (h *receivedPacketHistory) ReceivedPacket(p protocol.PacketNumber) {
h.mutex.Lock()
defer h.mutex.Unlock()
if h.ranges.Len() == 0 { if h.ranges.Len() == 0 {
h.ranges.PushBack(utils.PacketInterval{Start: p, End: p}) h.ranges.PushBack(utils.PacketInterval{Start: p, End: p})
return return
@@ -62,6 +69,9 @@ func (h *receivedPacketHistory) ReceivedPacket(p protocol.PacketNumber) {
} }
func (h *receivedPacketHistory) DeleteBelow(leastUnacked protocol.PacketNumber) { func (h *receivedPacketHistory) DeleteBelow(leastUnacked protocol.PacketNumber) {
h.mutex.Lock()
defer h.mutex.Unlock()
nextEl := h.ranges.Front() nextEl := h.ranges.Front()
for el := h.ranges.Front(); nextEl != nil; el = nextEl { for el := h.ranges.Front(); nextEl != nil; el = nextEl {
nextEl = el.Next() nextEl = el.Next()
@@ -79,6 +89,9 @@ func (h *receivedPacketHistory) DeleteBelow(leastUnacked protocol.PacketNumber)
// GetAckRanges gets a slice of all AckRanges that can be used in an AckFrame // GetAckRanges gets a slice of all AckRanges that can be used in an AckFrame
func (h *receivedPacketHistory) GetAckRanges() []frames.AckRange { func (h *receivedPacketHistory) GetAckRanges() []frames.AckRange {
h.mutex.RLock()
defer h.mutex.RUnlock()
if h.ranges.Len() == 0 { if h.ranges.Len() == 0 {
return nil return nil
} }