diff --git a/http3/client.go b/http3/client.go index ee9d4ec36..f5370549b 100644 --- a/http3/client.go +++ b/http3/client.go @@ -121,12 +121,16 @@ func (c *SingleDestinationRoundTripper) handleBidirectionalStreams() { } return } - go func(str quic.Stream) { - _, err := parseNextFrame(str, func(ft FrameType, e error) (processed bool, err error) { + fp := &frameParser{ + r: str, + conn: c.hconn, + unknownFrameHandler: func(ft FrameType, e error) (processed bool, err error) { id := c.hconn.Context().Value(quic.ConnectionTracingKey).(quic.ConnectionTracingID) return c.StreamHijacker(ft, id, str, e) - }) - if err == errHijacked { + }, + } + go func() { + if _, err := fp.ParseNext(); err == errHijacked { return } if err != nil { @@ -135,7 +139,7 @@ func (c *SingleDestinationRoundTripper) handleBidirectionalStreams() { } } c.hconn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "received HTTP/3 frame on bidirectional stream") - }(str) + }() } } diff --git a/http3/client_test.go b/http3/client_test.go index 20ef96962..e3863afe9 100644 --- a/http3/client_test.go +++ b/http3/client_test.go @@ -465,7 +465,8 @@ var _ = Describe("Client", func() { fields := make(map[string]string) decoder := qpack.NewDecoder(nil) - frame, err := parseNextFrame(str, nil) + fp := frameParser{r: str} + frame, err := fp.ParseNext() ExpectWithOffset(1, err).ToNot(HaveOccurred()) ExpectWithOffset(1, frame).To(BeAssignableToTypeOf(&headersFrame{})) headersFrame := frame.(*headersFrame) diff --git a/http3/conn.go b/http3/conn.go index a153e408e..58ce46527 100644 --- a/http3/conn.go +++ b/http3/conn.go @@ -200,7 +200,8 @@ func (c *connection) HandleUnidirectionalStreams(hijack func(StreamType, quic.Co c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeStreamCreationError), "duplicate control stream") return } - f, err := parseNextFrame(str, nil) + fp := &frameParser{conn: c.Connection, r: str} + f, err := fp.ParseNext() if err != nil { c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameError), "") return diff --git a/http3/frames.go b/http3/frames.go index 515302ef5..66cba68ca 100644 --- a/http3/frames.go +++ b/http3/frames.go @@ -6,6 +6,7 @@ import ( "fmt" "io" + "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/quicvarint" ) @@ -18,13 +19,19 @@ type frame interface{} var errHijacked = errors.New("hijacked") -func parseNextFrame(r io.Reader, unknownFrameHandler unknownFrameHandlerFunc) (frame, error) { - qr := quicvarint.NewReader(r) +type frameParser struct { + r io.Reader + conn quic.Connection + unknownFrameHandler unknownFrameHandlerFunc +} + +func (p *frameParser) ParseNext() (frame, error) { + qr := quicvarint.NewReader(p.r) for { t, err := quicvarint.Read(qr) if err != nil { - if unknownFrameHandler != nil { - hijacked, err := unknownFrameHandler(0, err) + if p.unknownFrameHandler != nil { + hijacked, err := p.unknownFrameHandler(0, err) if err != nil { return nil, err } @@ -35,8 +42,8 @@ func parseNextFrame(r io.Reader, unknownFrameHandler unknownFrameHandlerFunc) (f return nil, err } // Call the unknownFrameHandler for frames not defined in the HTTP/3 spec - if t > 0xd && unknownFrameHandler != nil { - hijacked, err := unknownFrameHandler(FrameType(t), nil) + if t > 0xd && p.unknownFrameHandler != nil { + hijacked, err := p.unknownFrameHandler(FrameType(t), nil) if err != nil { return nil, err } @@ -56,11 +63,14 @@ func parseNextFrame(r io.Reader, unknownFrameHandler unknownFrameHandlerFunc) (f case 0x1: return &headersFrame{Length: l}, nil case 0x4: - return parseSettingsFrame(r, l) + return parseSettingsFrame(p.r, l) case 0x3: // CANCEL_PUSH case 0x5: // PUSH_PROMISE case 0x7: // GOAWAY case 0xd: // MAX_PUSH_ID + case 0x2, 0x6, 0x8, 0x9: + p.conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "") + return nil, fmt.Errorf("http3: reserved frame type: %d", t) } // skip over unknown frames if _, err := io.CopyN(io.Discard, qr, int64(l)); err != nil { diff --git a/http3/frames_test.go b/http3/frames_test.go index fbb1d752e..fc1984963 100644 --- a/http3/frames_test.go +++ b/http3/frames_test.go @@ -6,10 +6,13 @@ import ( "fmt" "io" + "github.com/quic-go/quic-go" + mockquic "github.com/quic-go/quic-go/internal/mocks/quic" "github.com/quic-go/quic-go/quicvarint" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "go.uber.org/mock/gomock" ) type errReader struct{ err error } @@ -18,30 +21,49 @@ func (e errReader) Read([]byte) (int, error) { return 0, e.err } var _ = Describe("Frames", func() { It("skips unknown frames", func() { - b := quicvarint.Append(nil, 0xdeadbeef) // type byte - b = quicvarint.Append(b, 0x42) - b = append(b, make([]byte, 0x42)...) - b = (&dataFrame{Length: 0x1234}).Append(b) - r := bytes.NewReader(b) - frame, err := parseNextFrame(r, nil) + data := quicvarint.Append(nil, 0xdeadbeef) // type byte + data = quicvarint.Append(data, 0x42) + data = append(data, make([]byte, 0x42)...) + data = (&dataFrame{Length: 0x1234}).Append(data) + fp := frameParser{r: bytes.NewReader(data)} + frame, err := fp.ParseNext() Expect(err).ToNot(HaveOccurred()) Expect(frame).To(BeAssignableToTypeOf(&dataFrame{})) Expect(frame.(*dataFrame).Length).To(Equal(uint64(0x1234))) }) + It("closes the connection when encountering a reserved frame type", func() { + conn := mockquic.NewMockEarlyConnection(mockCtrl) + for _, ft := range []uint64{0x2, 0x6, 0x8, 0x9} { + data := quicvarint.Append(nil, ft) + data = quicvarint.Append(data, 6) + data = append(data, []byte("foobar")...) + + conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any()) + fp := frameParser{ + r: bytes.NewReader(data), + conn: conn, + } + _, err := fp.ParseNext() + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("http3: reserved frame type")) + } + }) + Context("DATA frames", func() { It("parses", func() { data := quicvarint.Append(nil, 0) // type byte data = quicvarint.Append(data, 0x1337) - frame, err := parseNextFrame(bytes.NewReader(data), nil) + fp := frameParser{r: bytes.NewReader(data)} + frame, err := fp.ParseNext() Expect(err).ToNot(HaveOccurred()) Expect(frame).To(BeAssignableToTypeOf(&dataFrame{})) Expect(frame.(*dataFrame).Length).To(Equal(uint64(0x1337))) }) It("writes", func() { - b := (&dataFrame{Length: 0xdeadbeef}).Append(nil) - frame, err := parseNextFrame(bytes.NewReader(b), nil) + fp := frameParser{r: bytes.NewReader((&dataFrame{Length: 0xdeadbeef}).Append(nil))} + frame, err := fp.ParseNext() Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(BeAssignableToTypeOf(&dataFrame{})) @@ -53,15 +75,17 @@ var _ = Describe("Frames", func() { It("parses", func() { data := quicvarint.Append(nil, 1) // type byte data = quicvarint.Append(data, 0x1337) - frame, err := parseNextFrame(bytes.NewReader(data), nil) + fp := frameParser{r: bytes.NewReader(data)} + frame, err := fp.ParseNext() Expect(err).ToNot(HaveOccurred()) Expect(frame).To(BeAssignableToTypeOf(&headersFrame{})) Expect(frame.(*headersFrame).Length).To(Equal(uint64(0x1337))) }) It("writes", func() { - b := (&headersFrame{Length: 0xdeadbeef}).Append(nil) - frame, err := parseNextFrame(bytes.NewReader(b), nil) + data := (&headersFrame{Length: 0xdeadbeef}).Append(nil) + fp := frameParser{r: bytes.NewReader(data)} + frame, err := fp.ParseNext() Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(BeAssignableToTypeOf(&headersFrame{})) @@ -78,7 +102,8 @@ var _ = Describe("Frames", func() { data := quicvarint.Append(nil, 4) // type byte data = quicvarint.Append(data, uint64(len(settings))) data = append(data, settings...) - frame, err := parseNextFrame(bytes.NewReader(data), nil) + fp := frameParser{r: bytes.NewReader(data)} + frame, err := fp.ParseNext() Expect(err).ToNot(HaveOccurred()) Expect(frame).To(BeAssignableToTypeOf(&settingsFrame{})) sf := frame.(*settingsFrame) @@ -94,7 +119,8 @@ var _ = Describe("Frames", func() { data := quicvarint.Append(nil, 4) // type byte data = quicvarint.Append(data, uint64(len(settings))) data = append(data, settings...) - _, err := parseNextFrame(bytes.NewReader(data), nil) + fp := frameParser{r: bytes.NewReader(data)} + _, err := fp.ParseNext() Expect(err).To(MatchError("duplicate setting: 13")) }) @@ -104,7 +130,8 @@ var _ = Describe("Frames", func() { 99: 999, 13: 37, }} - frame, err := parseNextFrame(bytes.NewReader(sf.Append(nil)), nil) + fp := frameParser{r: bytes.NewReader(sf.Append(nil))} + frame, err := fp.ParseNext() Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(sf)) }) @@ -115,14 +142,15 @@ var _ = Describe("Frames", func() { 0xdeadbeef: 0xdecafbad, }} data := sf.Append(nil) - - _, err := parseNextFrame(bytes.NewReader(data), nil) + fp := frameParser{r: bytes.NewReader(data)} + _, err := fp.ParseNext() Expect(err).ToNot(HaveOccurred()) for i := range data { b := make([]byte, i) copy(b, data[:i]) - _, err := parseNextFrame(bytes.NewReader(b), nil) + fp := frameParser{r: bytes.NewReader(b)} + _, err := fp.ParseNext() Expect(err).To(MatchError(io.EOF)) } }) @@ -134,7 +162,8 @@ var _ = Describe("Frames", func() { data := quicvarint.Append(nil, 4) // type byte data = quicvarint.Append(data, uint64(len(settings))) data = append(data, settings...) - f, err := parseNextFrame(bytes.NewReader(data), nil) + fp := frameParser{r: bytes.NewReader(data)} + f, err := fp.ParseNext() Expect(err).ToNot(HaveOccurred()) Expect(f).To(BeAssignableToTypeOf(&settingsFrame{})) sf := f.(*settingsFrame) @@ -149,7 +178,8 @@ var _ = Describe("Frames", func() { data := quicvarint.Append(nil, 4) // type byte data = quicvarint.Append(data, uint64(len(settings))) data = append(data, settings...) - _, err := parseNextFrame(bytes.NewReader(data), nil) + fp := frameParser{r: bytes.NewReader(data)} + _, err := fp.ParseNext() Expect(err).To(MatchError(fmt.Sprintf("duplicate setting: %d", settingDatagram))) }) @@ -159,13 +189,15 @@ var _ = Describe("Frames", func() { data := quicvarint.Append(nil, 4) // type byte data = quicvarint.Append(data, uint64(len(settings))) data = append(data, settings...) - _, err := parseNextFrame(bytes.NewReader(data), nil) + fp := frameParser{r: bytes.NewReader(data)} + _, err := fp.ParseNext() Expect(err).To(MatchError("invalid value for SETTINGS_H3_DATAGRAM: 1337")) }) It("writes the SETTINGS_H3_DATAGRAM setting", func() { sf := &settingsFrame{Datagram: true} - frame, err := parseNextFrame(bytes.NewReader(sf.Append(nil)), nil) + fp := frameParser{r: bytes.NewReader(sf.Append(nil))} + frame, err := fp.ParseNext() Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(sf)) }) @@ -178,7 +210,8 @@ var _ = Describe("Frames", func() { data := quicvarint.Append(nil, 4) // type byte data = quicvarint.Append(data, uint64(len(settings))) data = append(data, settings...) - f, err := parseNextFrame(bytes.NewReader(data), nil) + fp := frameParser{r: bytes.NewReader(data)} + f, err := fp.ParseNext() Expect(err).ToNot(HaveOccurred()) Expect(f).To(BeAssignableToTypeOf(&settingsFrame{})) sf := f.(*settingsFrame) @@ -193,7 +226,8 @@ var _ = Describe("Frames", func() { data := quicvarint.Append(nil, 4) // type byte data = quicvarint.Append(data, uint64(len(settings))) data = append(data, settings...) - _, err := parseNextFrame(bytes.NewReader(data), nil) + fp := frameParser{r: bytes.NewReader(data)} + _, err := fp.ParseNext() Expect(err).To(MatchError(fmt.Sprintf("duplicate setting: %d", settingExtendedConnect))) }) @@ -203,13 +237,15 @@ var _ = Describe("Frames", func() { data := quicvarint.Append(nil, 4) // type byte data = quicvarint.Append(data, uint64(len(settings))) data = append(data, settings...) - _, err := parseNextFrame(bytes.NewReader(data), nil) + fp := frameParser{r: bytes.NewReader(data)} + _, err := fp.ParseNext() Expect(err).To(MatchError("invalid value for SETTINGS_ENABLE_CONNECT_PROTOCOL: 1337")) }) It("writes the SETTINGS_ENABLE_CONNECT_PROTOCOL setting", func() { sf := &settingsFrame{ExtendedConnect: true} - frame, err := parseNextFrame(bytes.NewReader(sf.Append(nil)), nil) + fp := frameParser{r: bytes.NewReader(sf.Append(nil))} + frame, err := fp.ParseNext() Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(sf)) }) @@ -223,16 +259,20 @@ var _ = Describe("Frames", func() { buf.Write(customFrameContents) var called bool - _, err := parseNextFrame(buf, func(ft FrameType, e error) (hijacked bool, err error) { - Expect(e).ToNot(HaveOccurred()) - Expect(ft).To(BeEquivalentTo(1337)) - called = true - b := make([]byte, 3) - _, err = io.ReadFull(buf, b) - Expect(err).ToNot(HaveOccurred()) - Expect(string(b)).To(Equal("foo")) - return true, nil - }) + fp := frameParser{ + r: buf, + unknownFrameHandler: func(ft FrameType, e error) (hijacked bool, err error) { + Expect(e).ToNot(HaveOccurred()) + Expect(ft).To(BeEquivalentTo(1337)) + called = true + b := make([]byte, 3) + _, err = io.ReadFull(buf, b) + Expect(err).ToNot(HaveOccurred()) + Expect(string(b)).To(Equal("foo")) + return true, nil + }, + } + _, err := fp.ParseNext() Expect(err).To(MatchError(errHijacked)) Expect(called).To(BeTrue()) }) @@ -240,12 +280,16 @@ var _ = Describe("Frames", func() { It("passes on errors that occur when reading the frame type", func() { testErr := errors.New("test error") var called bool - _, err := parseNextFrame(errReader{err: testErr}, func(ft FrameType, e error) (hijacked bool, err error) { - Expect(e).To(MatchError(testErr)) - Expect(ft).To(BeZero()) - called = true - return true, nil - }) + fp := frameParser{ + r: errReader{err: testErr}, + unknownFrameHandler: func(ft FrameType, e error) (hijacked bool, err error) { + Expect(e).To(MatchError(testErr)) + Expect(ft).To(BeZero()) + called = true + return true, nil + }, + } + _, err := fp.ParseNext() Expect(err).To(MatchError(errHijacked)) Expect(called).To(BeTrue()) }) @@ -259,12 +303,16 @@ var _ = Describe("Frames", func() { b = append(b, []byte("foobar")...) var called bool - frame, err := parseNextFrame(bytes.NewReader(b), func(ft FrameType, e error) (hijacked bool, err error) { - Expect(e).ToNot(HaveOccurred()) - Expect(ft).To(BeEquivalentTo(1337)) - called = true - return false, nil - }) + fp := frameParser{ + r: bytes.NewReader(b), + unknownFrameHandler: func(ft FrameType, e error) (hijacked bool, err error) { + Expect(e).ToNot(HaveOccurred()) + Expect(ft).To(BeEquivalentTo(1337)) + called = true + return false, nil + }, + } + frame, err := fp.ParseNext() Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(&dataFrame{Length: 6})) Expect(called).To(BeTrue()) diff --git a/http3/http_stream.go b/http3/http_stream.go index aa8600761..32ba93494 100644 --- a/http3/http_stream.go +++ b/http3/http_stream.go @@ -63,10 +63,14 @@ func newStream(str quic.Stream, conn *connection, datagrams *datagrammer) *strea } func (s *stream) Read(b []byte) (int, error) { + fp := &frameParser{ + r: s.Stream, + conn: s.conn, + } if s.bytesRemainingInFrame == 0 { parseLoop: for { - frame, err := parseNextFrame(s.Stream, nil) + frame, err := fp.ParseNext() if err != nil { return 0, err } @@ -177,7 +181,11 @@ func (s *requestStream) SendRequestHeader(req *http.Request) error { } func (s *requestStream) ReadResponse() (*http.Response, error) { - frame, err := parseNextFrame(s.Stream, nil) + fp := &frameParser{ + r: s.Stream, + conn: s.conn, + } + frame, err := fp.ParseNext() if err != nil { s.Stream.CancelRead(quic.StreamErrorCode(ErrCodeFrameError)) s.Stream.CancelWrite(quic.StreamErrorCode(ErrCodeFrameError)) diff --git a/http3/http_stream_test.go b/http3/http_stream_test.go index 6be1c6bd7..a97dcb7b4 100644 --- a/http3/http_stream_test.go +++ b/http3/http_stream_test.go @@ -140,7 +140,8 @@ var _ = Describe("Stream", func() { str.Write([]byte("foo")) str.Write([]byte("foobar")) - f, err := parseNextFrame(buf, nil) + fp := frameParser{r: buf} + f, err := fp.ParseNext() Expect(err).ToNot(HaveOccurred()) Expect(f).To(Equal(&dataFrame{Length: 3})) b := make([]byte, 3) @@ -148,7 +149,8 @@ var _ = Describe("Stream", func() { Expect(err).ToNot(HaveOccurred()) Expect(b).To(Equal([]byte("foo"))) - f, err = parseNextFrame(buf, nil) + fp = frameParser{r: buf} + f, err = fp.ParseNext() Expect(err).ToNot(HaveOccurred()) Expect(f).To(Equal(&dataFrame{Length: 6})) b = make([]byte, 6) diff --git a/http3/request_writer_test.go b/http3/request_writer_test.go index 9459db8dd..4f2af74be 100644 --- a/http3/request_writer_test.go +++ b/http3/request_writer_test.go @@ -22,7 +22,8 @@ var _ = Describe("Request Writer", func() { ) decode := func(str io.Reader) map[string]string { - frame, err := parseNextFrame(str, nil) + fp := frameParser{r: str} + frame, err := fp.ParseNext() ExpectWithOffset(1, err).ToNot(HaveOccurred()) ExpectWithOffset(1, frame).To(BeAssignableToTypeOf(&headersFrame{})) headersFrame := frame.(*headersFrame) diff --git a/http3/response_writer_test.go b/http3/response_writer_test.go index f68e5acee..dbac17615 100644 --- a/http3/response_writer_test.go +++ b/http3/response_writer_test.go @@ -34,7 +34,8 @@ var _ = Describe("Response Writer", func() { fields := make(map[string][]string) decoder := qpack.NewDecoder(nil) - frame, err := parseNextFrame(str, nil) + fp := frameParser{r: str} + frame, err := fp.ParseNext() Expect(err).ToNot(HaveOccurred()) Expect(frame).To(BeAssignableToTypeOf(&headersFrame{})) headersFrame := frame.(*headersFrame) @@ -50,7 +51,8 @@ var _ = Describe("Response Writer", func() { } getData := func(str io.Reader) []byte { - frame, err := parseNextFrame(str, nil) + fp := frameParser{r: str} + frame, err := fp.ParseNext() Expect(err).ToNot(HaveOccurred()) Expect(frame).To(BeAssignableToTypeOf(&dataFrame{})) df := frame.(*dataFrame) diff --git a/http3/server.go b/http3/server.go index 7a16c6a85..4042da334 100644 --- a/http3/server.go +++ b/http3/server.go @@ -477,7 +477,8 @@ func (s *Server) handleRequest(conn *connection, str quic.Stream, datagrams *dat ) } } - frame, err := parseNextFrame(str, ufh) + fp := &frameParser{conn: conn, r: str, unknownFrameHandler: ufh} + frame, err := fp.ParseNext() if err != nil { if !errors.Is(err, errHijacked) { str.CancelRead(quic.StreamErrorCode(ErrCodeRequestIncomplete)) diff --git a/http3/server_test.go b/http3/server_test.go index d77671447..5597a2815 100644 --- a/http3/server_test.go +++ b/http3/server_test.go @@ -96,7 +96,8 @@ var _ = Describe("Server", func() { fields := make(map[string][]string) decoder := qpack.NewDecoder(nil) - frame, err := parseNextFrame(str, nil) + fp := frameParser{r: str} + frame, err := fp.ParseNext() ExpectWithOffset(1, err).ToNot(HaveOccurred()) ExpectWithOffset(1, frame).To(BeAssignableToTypeOf(&headersFrame{})) headersFrame := frame.(*headersFrame) @@ -619,7 +620,8 @@ var _ = Describe("Server", func() { // The buffer is expected to contain: // 1. The response header (in a HEADERS frame) // 2. the "foobar" (unframed) - frame, err := parseNextFrame(&buf, nil) + fp := frameParser{r: &buf} + frame, err := fp.ParseNext() Expect(err).ToNot(HaveOccurred()) Expect(frame).To(BeAssignableToTypeOf(&headersFrame{})) df := frame.(*headersFrame)