From dfa0137bd63563ffb294ad70e7824afa4c8a553f Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Thu, 26 Jan 2017 14:12:50 +0900 Subject: [PATCH] close the connection when encountering an error on the header stream fixes #403 --- h2quic/server.go | 1 + h2quic/server_test.go | 26 +++++++++++++++++++++++--- 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/h2quic/server.go b/h2quic/server.go index fbc83278..e9e59070 100644 --- a/h2quic/server.go +++ b/h2quic/server.go @@ -112,6 +112,7 @@ func (s *Server) handleStream(session streamCreator, stream utils.Stream) { if _, ok := err.(*qerr.QuicError); !ok { utils.Errorf("error handling h2 request: %s", err.Error()) } + session.Close(qerr.Error(qerr.InvalidHeadersStreamData, err.Error())) return } } diff --git a/h2quic/server_test.go b/h2quic/server_test.go index 94f191e2..3146f18c 100644 --- a/h2quic/server_test.go +++ b/h2quic/server_test.go @@ -1,6 +1,7 @@ package h2quic import ( + "bytes" "net" "net/http" "os" @@ -12,6 +13,7 @@ import ( "golang.org/x/net/http2/hpack" "github.com/lucas-clemente/quic-go/protocol" + "github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/testdata" "github.com/lucas-clemente/quic-go/utils" @@ -20,14 +22,19 @@ import ( ) type mockSession struct { - closed bool - dataStream *mockStream + closed bool + closedWithError error + dataStream *mockStream } func (s *mockSession) GetOrOpenStream(id protocol.StreamID) (utils.Stream, error) { return s.dataStream, nil } -func (s *mockSession) Close(error) error { s.closed = true; return nil } +func (s *mockSession) Close(e error) error { + s.closed = true + s.closedWithError = e + return nil +} func (s *mockSession) RemoteAddr() *net.UDPAddr { return &net.UDPAddr{IP: []byte{127, 0, 0, 1}, Port: 42} } @@ -190,6 +197,19 @@ var _ = Describe("H2 server", func() { Eventually(func() bool { return handlerCalled }).Should(BeTrue()) }) + It("closes the connection if it encounters an error on the header stream", func() { + var handlerCalled bool + s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + handlerCalled = true + }) + headerStream := &mockStream{id: 3} + headerStream.dataToRead.Write(bytes.Repeat([]byte{0}, 100)) + s.handleStream(session, headerStream) + Consistently(func() bool { return handlerCalled }).Should(BeFalse()) + Eventually(func() bool { return session.closed }).Should(BeTrue()) + Expect(session.closedWithError).To(MatchError(qerr.Error(qerr.InvalidHeadersStreamData, "connection error: PROTOCOL_ERROR"))) + }) + It("ignores other streams", func() { var handlerCalled bool s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {