diff --git a/stream.go b/stream.go index 8e17ad2ef..57cc60fe8 100644 --- a/stream.go +++ b/stream.go @@ -163,6 +163,9 @@ func (s *stream) Write(p []byte) (int, error) { s.windowUpdateCond.L.Lock() remainingBytesInWindow := int64(s.flowControlWindow) - int64(s.writeOffset) for ; remainingBytesInWindow == 0; remainingBytesInWindow = int64(s.flowControlWindow) - int64(s.writeOffset) { + if s.remoteErr != nil { + return 0, s.remoteErr + } s.windowUpdateCond.Wait() } s.windowUpdateCond.L.Unlock() @@ -206,6 +209,7 @@ func (s *stream) AddStreamFrame(frame *frames.StreamFrame) error { func (s *stream) RegisterError(err error) { s.remoteErr = err s.streamFrames <- nil + s.windowUpdateCond.Broadcast() } func (s *stream) finishedReading() bool { diff --git a/stream_test.go b/stream_test.go index 914a859bc..2ac081159 100644 --- a/stream_test.go +++ b/stream_test.go @@ -305,6 +305,23 @@ var _ = Describe("Stream", func() { Expect(n).To(Equal(2)) Expect(err).ToNot(HaveOccurred()) }) + + It("immediately returns on remote errors", func() { + var b bool + str.flowControlWindow = 1 + + testErr := errors.New("test error") + + go func() { + time.Sleep(time.Millisecond) + b = true + str.RegisterError(testErr) + }() + + _, err := str.Write([]byte{0xDE, 0xCA, 0xFB, 0xAD}) + Expect(b).To(BeTrue()) + Expect(err).To(Equal(testErr)) + }) }) })