diff --git a/session.go b/session.go index 7677e1d5..144f5eff 100644 --- a/session.go +++ b/session.go @@ -109,9 +109,7 @@ type session struct { sessionCreationTime time.Time lastNetworkActivityTime time.Time - timer *time.Timer - currentDeadline time.Time - timerRead bool + timer *utils.Timer } var _ Session = &session{} @@ -169,7 +167,7 @@ func (s *session) setup( s.sendingScheduled = make(chan struct{}, 1) s.undecryptablePackets = make([]*receivedPacket, 0, protocol.MaxUndecryptablePackets) - s.timer = time.NewTimer(0) + s.timer = utils.NewTimer() now := time.Now() s.lastNetworkActivityTime = now s.sessionCreationTime = now @@ -254,8 +252,8 @@ runLoop: select { case closeErr = <-s.closeChan: break runLoop - case <-s.timer.C: - s.timerRead = true + case <-s.timer.Chan(): + s.timer.SetRead() // We do all the interesting stuff after the switch statement, so // nothing to see here. case <-s.sendingScheduled: @@ -323,36 +321,23 @@ runLoop: } func (s *session) maybeResetTimer() { - nextDeadline := s.lastNetworkActivityTime.Add(s.idleTimeout()) + deadline := s.lastNetworkActivityTime.Add(s.idleTimeout()) if !s.nextAckScheduledTime.IsZero() { - nextDeadline = utils.MinTime(nextDeadline, s.nextAckScheduledTime) + deadline = utils.MinTime(deadline, s.nextAckScheduledTime) } if lossTime := s.sentPacketHandler.GetAlarmTimeout(); !lossTime.IsZero() { - nextDeadline = utils.MinTime(nextDeadline, lossTime) + deadline = utils.MinTime(deadline, lossTime) } if !s.handshakeComplete { handshakeDeadline := s.sessionCreationTime.Add(s.config.HandshakeTimeout) - nextDeadline = utils.MinTime(nextDeadline, handshakeDeadline) + deadline = utils.MinTime(deadline, handshakeDeadline) } if !s.receivedTooManyUndecrytablePacketsTime.IsZero() { - nextDeadline = utils.MinTime(nextDeadline, s.receivedTooManyUndecrytablePacketsTime.Add(protocol.PublicResetTimeout)) + deadline = utils.MinTime(deadline, s.receivedTooManyUndecrytablePacketsTime.Add(protocol.PublicResetTimeout)) } - if nextDeadline.Equal(s.currentDeadline) { - // No need to reset the timer - return - } - - // We need to drain the timer if the value from its channel was not read yet. - // See https://groups.google.com/forum/#!topic/golang-dev/c9UUfASVPoU - if !s.timer.Stop() && !s.timerRead { - <-s.timer.C - } - s.timer.Reset(nextDeadline.Sub(time.Now())) - - s.timerRead = false - s.currentDeadline = nextDeadline + s.timer.Reset(deadline) } func (s *session) idleTimeout() time.Duration { diff --git a/utils/timer.go b/utils/timer.go new file mode 100644 index 00000000..695ad3e7 --- /dev/null +++ b/utils/timer.go @@ -0,0 +1,43 @@ +package utils + +import "time" + +// A Timer wrapper that behaves correctly when resetting +type Timer struct { + t *time.Timer + read bool + deadline time.Time +} + +// NewTimer creates a new timer that is not set +func NewTimer() *Timer { + return &Timer{t: time.NewTimer(0)} +} + +// Chan returns the channel of the wrapped timer +func (t *Timer) Chan() <-chan time.Time { + return t.t.C +} + +// Reset the timer, no matter whether the value was read or not +func (t *Timer) Reset(deadline time.Time) { + if deadline.Equal(t.deadline) { + // No need to reset the timer + return + } + + // We need to drain the timer if the value from its channel was not read yet. + // See https://groups.google.com/forum/#!topic/golang-dev/c9UUfASVPoU + if !t.t.Stop() && !t.read { + <-t.t.C + } + t.t.Reset(deadline.Sub(time.Now())) + + t.read = false + t.deadline = deadline +} + +// SetRead should be called after the value from the chan was read +func (t *Timer) SetRead() { + t.read = true +} diff --git a/utils/timer_test.go b/utils/timer_test.go new file mode 100644 index 00000000..8aa0731c --- /dev/null +++ b/utils/timer_test.go @@ -0,0 +1,45 @@ +package utils + +import ( + "time" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Timer", func() { + const d = 10 * time.Millisecond + + It("works", func() { + t := NewTimer() + t.Reset(time.Now().Add(d)) + Eventually(t.Chan()).Should(Receive()) + }) + + It("works multiple times with reading", func() { + t := NewTimer() + for i := 0; i < 10; i++ { + t.Reset(time.Now().Add(d)) + Eventually(t.Chan()).Should(Receive()) + t.SetRead() + } + }) + + It("works multiple times without reading", func() { + t := NewTimer() + for i := 0; i < 10; i++ { + t.Reset(time.Now().Add(d)) + time.Sleep(d * 2) + } + Eventually(t.Chan()).Should(Receive()) + }) + + It("works when resetting without expiration", func() { + t := NewTimer() + for i := 0; i < 10; i++ { + t.Reset(time.Now().Add(time.Hour)) + } + t.Reset(time.Now().Add(d)) + Eventually(t.Chan()).Should(Receive()) + }) +})