diff --git a/stream_framer.go b/stream_framer.go index e03bb43e..29cdf867 100644 --- a/stream_framer.go +++ b/stream_framer.go @@ -1,6 +1,8 @@ package quic import ( + "runtime" + "github.com/lucas-clemente/quic-go/flowcontrol" "github.com/lucas-clemente/quic-go/frames" "github.com/lucas-clemente/quic-go/protocol" @@ -96,17 +98,24 @@ func (f *streamFramer) maybePopNormalFrames(maxBytes protocol.ByteCount) (res [] } data := s.getDataForWriting(maxLen) - if data == nil { - if s.shouldSendFin() { - frame.FinBit = true - s.sentFin() - res = append(res, frame) - currentLen += frameHeaderBytes + frame.DataLen() - frame = &frames.StreamFrame{DataLenPresent: true} - } + + // Here, stream.Write() may return in parallel. Afterwards, the user may + // call stream.Close(). We want to pack the FIN into the same frame, + // so we speculatively allow the other goroutines to run. + // In tests, this increased the percentage of FINs packed into the same + // frame from ~20% to ~97%. + runtime.Gosched() + + shouldSendFin := s.shouldSendFin() + if data == nil && !shouldSendFin { return true, nil } + if shouldSendFin { + frame.FinBit = true + s.sentFin() + } + frame.Data = data f.flowControlManager.AddBytesSent(s.streamID, protocol.ByteCount(len(data))) diff --git a/stream_framer_test.go b/stream_framer_test.go index 3eff626a..c6957e1e 100644 --- a/stream_framer_test.go +++ b/stream_framer_test.go @@ -2,6 +2,7 @@ package quic import ( "bytes" + "time" "github.com/lucas-clemente/quic-go/frames" "github.com/lucas-clemente/quic-go/protocol" @@ -253,6 +254,38 @@ var _ = Describe("Stream Framer", func() { Expect(fs[0].FinBit).To(BeTrue()) Expect(fs[0].Data).To(BeEmpty()) }) + + It("bundles FINs with data", func() { + // Since this is non-deterministic (see the comment in maybePopNormalFrames), + // we give it a few tries and assert that FINs were packed at least + // some times. + stream1.onData = func() {} + stream1.doneWritingOrErrCond.L = &stream1.mutex + const n = 1000 + nFins := 0 + for i := 0; i < n; { + go func() { + defer GinkgoRecover() + _, err := stream1.Write([]byte("foobar")) + Expect(err).NotTo(HaveOccurred()) + stream1.Close() + }() + time.Sleep(time.Microsecond) + fs := framer.PopStreamFrames(1000) + if len(fs) != 1 { + continue + } + Expect(fs[0].StreamID).To(Equal(stream1.streamID)) + if fs[0].FinBit { + nFins++ + } + stream1.closed = 0 + stream1.finSent = false + stream1.dataForWriting = nil + i++ + } + Expect(nFins).To(BeNumerically(">", n/2)) + }, 5) }) })