diff --git a/send_stream.go b/send_stream.go index 624e1ff47..5ffd6efc6 100644 --- a/send_stream.go +++ b/send_stream.go @@ -41,8 +41,10 @@ type sendStream struct { finSent bool // set when a STREAM_FRAME with FIN bit has b dataForWriting []byte - writeChan chan struct{} - writeDeadline time.Time + + writeChan chan struct{} + deadline time.Time + deadlineTimer *time.Timer // initialized by SetReadDeadline() flowController flowcontrol.StreamFlowController @@ -86,7 +88,7 @@ func (s *sendStream) Write(p []byte) (int, error) { if s.closeForShutdownErr != nil { return 0, s.closeForShutdownErr } - if !s.writeDeadline.IsZero() && !time.Now().Before(s.writeDeadline) { + if !s.deadline.IsZero() && !time.Now().Before(s.deadline) { return 0, errDeadline } if len(p) == 0 { @@ -101,8 +103,7 @@ func (s *sendStream) Write(p []byte) (int, error) { var err error for { bytesWritten = len(p) - len(s.dataForWriting) - deadline := s.writeDeadline - if !deadline.IsZero() && !time.Now().Before(deadline) { + if !s.deadline.IsZero() && !time.Now().Before(s.deadline) { s.dataForWriting = nil err = errDeadline break @@ -112,12 +113,12 @@ func (s *sendStream) Write(p []byte) (int, error) { } s.mutex.Unlock() - if deadline.IsZero() { + if s.deadline.IsZero() { <-s.writeChan } else { select { case <-s.writeChan: - case <-time.After(time.Until(deadline)): + case <-s.deadlineTimer.C: } } s.mutex.Lock() @@ -298,12 +299,22 @@ func (s *sendStream) Context() context.Context { func (s *sendStream) SetWriteDeadline(t time.Time) error { s.mutex.Lock() - oldDeadline := s.writeDeadline - s.writeDeadline = t - s.mutex.Unlock() - if t.Before(oldDeadline) { + defer s.mutex.Unlock() + s.deadline = t + if s.deadline.IsZero() { // skip if there's no deadline to set s.signalWrite() + return nil } + // Lazily initialize the deadline timer. + if s.deadlineTimer == nil { + s.deadlineTimer = time.NewTimer(time.Until(t)) + return nil + } + // reset the timer to the new deadline + if !s.deadlineTimer.Stop() { + <-s.deadlineTimer.C + } + s.deadlineTimer.Reset(time.Until(t)) return nil } diff --git a/send_stream_test.go b/send_stream_test.go index 76aabda11..f6d304374 100644 --- a/send_stream_test.go +++ b/send_stream_test.go @@ -255,6 +255,21 @@ var _ = Describe("Send Stream", func() { Expect(time.Now()).To(BeTemporally("~", deadline, scaleDuration(20*time.Millisecond))) }) + It("unblocks when the deadline is changed to the past", func() { + mockSender.EXPECT().onHasStreamData(streamID) + str.SetWriteDeadline(time.Now().Add(time.Hour)) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := str.Write([]byte("foobar")) + Expect(err).To(MatchError(errDeadline)) + close(done) + }() + Consistently(done).ShouldNot(BeClosed()) + str.SetWriteDeadline(time.Now().Add(-time.Hour)) + Eventually(done).Should(BeClosed()) + }) + It("returns the number of bytes written, when the deadline expires", func() { mockSender.EXPECT().onHasStreamData(streamID) mockFC.EXPECT().SendWindowSize().Return(protocol.ByteCount(10000)).AnyTimes()