diff --git a/send_stream.go b/send_stream.go index df1b753b5..c3191ec64 100644 --- a/send_stream.go +++ b/send_stream.go @@ -166,6 +166,10 @@ func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount) (*ackhandler.Fr } func (s *sendStream) popNewOrRetransmittedStreamFrame(maxBytes protocol.ByteCount) (*wire.StreamFrame, bool /* has more data to send */) { + if s.canceledWrite || s.closeForShutdownErr != nil { + return nil, false + } + if len(s.retransmissionQueue) > 0 { f, hasMoreRetransmissions := s.maybeGetRetransmission(maxBytes) if f != nil || hasMoreRetransmissions { @@ -195,10 +199,6 @@ func (s *sendStream) popNewOrRetransmittedStreamFrame(maxBytes protocol.ByteCoun } func (s *sendStream) popNewStreamFrame(f *wire.StreamFrame, maxBytes protocol.ByteCount) bool { - if s.canceledWrite || s.closeForShutdownErr != nil { - return false - } - maxDataLen := f.MaxDataLen(maxBytes, s.version) if maxDataLen == 0 { // a STREAM frame must have at least one byte of data return s.dataForWriting != nil @@ -329,7 +329,6 @@ func (s *sendStream) Close() error { func (s *sendStream) CancelWrite(errorCode protocol.ApplicationErrorCode) { s.cancelWriteImpl(errorCode, fmt.Errorf("Write on stream %d canceled with error code %d", s.streamID, errorCode)) - } // must be called after locking the mutex diff --git a/send_stream_test.go b/send_stream_test.go index 3bb4421e2..57b92e9db 100644 --- a/send_stream_test.go +++ b/send_stream_test.go @@ -755,6 +755,22 @@ var _ = Describe("Send Stream", func() { Expect(newFrame).ToNot(BeNil()) Expect(newFrame.Frame.(*wire.StreamFrame).Data).To(Equal([]byte("foobar"))) }) + + It("doesn't get a retransmission after a stream was canceled", func() { + str.numOutstandingFrames = 1 + f := &wire.StreamFrame{ + Data: []byte("foobar"), + Offset: 0x42, + DataLenPresent: false, + } + mockSender.EXPECT().onHasStreamData(streamID) + str.queueRetransmission(f) + mockSender.EXPECT().queueControlFrame(gomock.Any()) + str.CancelWrite(0) + frame, hasMoreData := str.popStreamFrame(protocol.MaxByteCount) + Expect(hasMoreData).To(BeFalse()) + Expect(frame).To(BeNil()) + }) }) Context("determining when a stream is completed", func() {