diff --git a/h2quic/server.go b/h2quic/server.go index 7b15a19f..4a653ed6 100644 --- a/h2quic/server.go +++ b/h2quic/server.go @@ -46,17 +46,21 @@ func (s *Server) ListenAndServe(addr string, handler http.Handler) error { return s.server.ListenAndServe(addr) } -func (s *Server) handleStreamCb(session *quic.Session, headerStream utils.Stream) { - s.handleStream(session, headerStream) +func (s *Server) handleStreamCb(session *quic.Session, stream utils.Stream) { + s.handleStream(session, stream) } -func (s *Server) handleStream(session streamCreator, headerStream utils.Stream) { +func (s *Server) handleStream(session streamCreator, stream utils.Stream) { + if stream.StreamID() != 3 { + return + } + hpackDecoder := hpack.NewDecoder(4096, nil) - h2framer := http2.NewFramer(nil, headerStream) + h2framer := http2.NewFramer(nil, stream) go func() { for { - if err := s.handleRequest(session, headerStream, hpackDecoder, h2framer); err != nil { + if err := s.handleRequest(session, stream, hpackDecoder, h2framer); err != nil { utils.Errorf("error handling h2 request: %s", err.Error()) return } diff --git a/h2quic/server_test.go b/h2quic/server_test.go index 92519fe9..943871b6 100644 --- a/h2quic/server_test.go +++ b/h2quic/server_test.go @@ -20,7 +20,7 @@ func (s *mockSession) GetOrCreateStream(id protocol.StreamID) (utils.Stream, err return &mockStream{}, nil } -var _ = Describe("Response Writer", func() { +var _ = Describe("H2 server", func() { var ( s *Server session *mockSession @@ -85,7 +85,7 @@ var _ = Describe("Response Writer", func() { Expect(r.Host).To(Equal("www.example.com")) handlerCalled = true }) - headerStream := &mockStream{} + headerStream := &mockStream{id: 3} headerStream.Write([]byte{ 0x0, 0x0, 0x11, 0x1, 0x4, 0x0, 0x0, 0x0, 0x5, // Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding @@ -94,4 +94,20 @@ var _ = Describe("Response Writer", func() { s.handleStream(session, headerStream) Eventually(func() bool { return handlerCalled }).Should(BeTrue()) }) + + It("ignores other streams", func() { + var handlerCalled bool + s.handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + Expect(r.Host).To(Equal("www.example.com")) + handlerCalled = true + }) + headerStream := &mockStream{id: 5} + headerStream.Write([]byte{ + 0x0, 0x0, 0x11, 0x1, 0x4, 0x0, 0x0, 0x0, 0x5, + // Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding + 0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff, + }) + s.handleStream(session, headerStream) + Consistently(func() bool { return handlerCalled }).Should(BeFalse()) + }) })