diff --git a/h2quic/server.go b/h2quic/server.go index 4a653ed6d..1f0db09d9 100644 --- a/h2quic/server.go +++ b/h2quic/server.go @@ -5,6 +5,7 @@ import ( "errors" "io/ioutil" "net/http" + "time" "github.com/lucas-clemente/quic-go" "github.com/lucas-clemente/quic-go/protocol" @@ -15,12 +16,16 @@ import ( type streamCreator interface { GetOrCreateStream(protocol.StreamID) (utils.Stream, error) + Close(error, bool) error } // Server is a HTTP2 server listening for QUIC connections type Server struct { server *quic.Server handler http.Handler + + // Private flag for demo, do not use + CloseAfterFirstRequest bool } // NewServer creates a new server instance @@ -104,6 +109,10 @@ func (s *Server) handleRequest(session streamCreator, headerStream utils.Stream, if responseWriter.dataStream != nil { responseWriter.dataStream.Close() } + if s.CloseAfterFirstRequest { + time.Sleep(100 * time.Millisecond) + session.Close(nil, true) + } }() return nil diff --git a/h2quic/server_test.go b/h2quic/server_test.go index 943871b6c..0cc63dc28 100644 --- a/h2quic/server_test.go +++ b/h2quic/server_test.go @@ -14,12 +14,16 @@ import ( . "github.com/onsi/gomega" ) -type mockSession struct{} +type mockSession struct { + closed bool +} func (s *mockSession) GetOrCreateStream(id protocol.StreamID) (utils.Stream, error) { return &mockStream{}, nil } +func (s *mockSession) Close(error, bool) error { s.closed = true; return nil } + var _ = Describe("H2 server", func() { var ( s *Server @@ -110,4 +114,18 @@ var _ = Describe("H2 server", func() { s.handleStream(session, headerStream) Consistently(func() bool { return handlerCalled }).Should(BeFalse()) }) + + It("supports closing after first request", func() { + s.CloseAfterFirstRequest = true + s.handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + headerStream := &mockStream{id: 3} + headerStream.Write([]byte{ + 0x0, 0x0, 0x11, 0x1, 0x4, 0x0, 0x0, 0x0, 0x5, + // Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding + 0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff, + }) + Expect(session.closed).To(BeFalse()) + s.handleStream(session, headerStream) + Eventually(func() bool { return session.closed }).Should(BeTrue()) + }) })