diff --git a/h2quic/response_writer_test.go b/h2quic/response_writer_test.go index 90851817..16b0d817 100644 --- a/h2quic/response_writer_test.go +++ b/h2quic/response_writer_test.go @@ -14,9 +14,9 @@ type mockStream struct { bytes.Buffer } -func (mockStream) Close() error { return nil } - -func (s mockStream) StreamID() protocol.StreamID { return s.id } +func (mockStream) Close() error { return nil } +func (mockStream) CloseRemote(offset protocol.ByteCount) { panic("not implemented") } +func (s mockStream) StreamID() protocol.StreamID { return s.id } var _ = Describe("Response Writer", func() { var ( diff --git a/handshake/crypto_setup_test.go b/handshake/crypto_setup_test.go index d5e4dcec..6ae0ebd8 100644 --- a/handshake/crypto_setup_test.go +++ b/handshake/crypto_setup_test.go @@ -88,11 +88,9 @@ func (s *mockStream) Write(p []byte) (int, error) { return s.dataWritten.Write(p) } -func (s *mockStream) Close() error { - panic("not implemented") -} - -func (s mockStream) StreamID() protocol.StreamID { panic("not implemented") } +func (s *mockStream) Close() error { panic("not implemented") } +func (mockStream) CloseRemote(offset protocol.ByteCount) { panic("not implemented") } +func (s mockStream) StreamID() protocol.StreamID { panic("not implemented") } var _ = Describe("Crypto setup", func() { var ( diff --git a/stream.go b/stream.go index f017d9ae..b8787032 100644 --- a/stream.go +++ b/stream.go @@ -254,6 +254,11 @@ func (s *stream) AddStreamFrame(frame *frames.StreamFrame) error { return nil } +// CloseRemote makes the stream receive a "virtual" FIN stream frame at a given offset +func (s *stream) CloseRemote(offset protocol.ByteCount) { + s.AddStreamFrame(&frames.StreamFrame{FinBit: true, Offset: offset}) +} + func (s *stream) maybeTriggerWindowUpdate() { // check for stream level window updates doUpdate, byteOffset := s.flowController.MaybeTriggerWindowUpdate() diff --git a/stream_test.go b/stream_test.go index 39ff6894..934a5d48 100644 --- a/stream_test.go +++ b/stream_test.go @@ -655,5 +655,15 @@ var _ = Describe("Stream", func() { Expect(err).To(MatchError(testErr)) }) }) + + Context("when CloseRemote is called", func() { + It("closes", func() { + str.CloseRemote(0) + b := make([]byte, 8) + n, err := str.Read(b) + Expect(n).To(BeZero()) + Expect(err).To(MatchError(io.EOF)) + }) + }) }) }) diff --git a/utils/utils.go b/utils/utils.go index 9c4aa5f2..e4716986 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -21,6 +21,7 @@ type Stream interface { io.Writer io.Closer StreamID() protocol.StreamID + CloseRemote(offset protocol.ByteCount) } // ReadUintN reads N bytes