From 2a37c531438f0d1512f39be9c768c363389be1ae Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Fri, 26 Apr 2024 20:21:04 +0200 Subject: [PATCH] http3: add support for HTTP Datagrams (RFC 9297) (#4452) * http3: add support for HTTP Datagrams (RFC 9297) * README: reference HTTP Datagrams (RFC 9297) --- README.md | 2 +- http3/README.md | 2 +- http3/client.go | 63 ++++-------- http3/client_test.go | 16 +-- http3/conn.go | 145 ++++++++++++++++++++++++---- http3/conn_test.go | 145 ++++++++++++++++++++++++---- http3/datagram.go | 100 +++++++++++++++++++ http3/datagram_test.go | 76 +++++++++++++++ http3/http_stream.go | 33 +++++-- http3/http_stream_test.go | 7 +- http3/response_writer_test.go | 2 +- http3/server.go | 15 ++- http3/server_test.go | 31 +++--- http3/state_tracking_stream.go | 91 +++++++++++++++++ http3/state_tracking_stream_test.go | 90 +++++++++++++++++ integrationtests/self/http_test.go | 93 +++++++++++++++++- 16 files changed, 790 insertions(+), 121 deletions(-) create mode 100644 http3/datagram.go create mode 100644 http3/datagram_test.go create mode 100644 http3/state_tracking_stream.go create mode 100644 http3/state_tracking_stream_test.go diff --git a/README.md b/README.md index faba82f3..1efa820f 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,7 @@ [![Code Coverage](https://img.shields.io/codecov/c/github/quic-go/quic-go/master.svg?style=flat-square)](https://codecov.io/gh/quic-go/quic-go/) [![Fuzzing Status](https://oss-fuzz-build-logs.storage.googleapis.com/badges/quic-go.svg)](https://bugs.chromium.org/p/oss-fuzz/issues/list?sort=-opened&can=1&q=proj:quic-go) -quic-go is an implementation of the QUIC protocol ([RFC 9000](https://datatracker.ietf.org/doc/html/rfc9000), [RFC 9001](https://datatracker.ietf.org/doc/html/rfc9001), [RFC 9002](https://datatracker.ietf.org/doc/html/rfc9002)) in Go. It has support for HTTP/3 ([RFC 9114](https://datatracker.ietf.org/doc/html/rfc9114)), including QPACK ([RFC 9204](https://datatracker.ietf.org/doc/html/rfc9204)). +quic-go is an implementation of the QUIC protocol ([RFC 9000](https://datatracker.ietf.org/doc/html/rfc9000), [RFC 9001](https://datatracker.ietf.org/doc/html/rfc9001), [RFC 9002](https://datatracker.ietf.org/doc/html/rfc9002)) in Go. It has support for HTTP/3 ([RFC 9114](https://datatracker.ietf.org/doc/html/rfc9114)), including QPACK ([RFC 9204](https://datatracker.ietf.org/doc/html/rfc9204)) and HTTP Datagrams ([RFC 9297](https://datatracker.ietf.org/doc/html/rfc9297)). In addition to these base RFCs, it also implements the following RFCs: * Unreliable Datagram Extension ([RFC 9221](https://datatracker.ietf.org/doc/html/rfc9221)) diff --git a/http3/README.md b/http3/README.md index a6128516..660e441e 100644 --- a/http3/README.md +++ b/http3/README.md @@ -2,7 +2,7 @@ [![PkgGoDev](https://pkg.go.dev/badge/github.com/quic-go/quic-go/http3)](https://pkg.go.dev/github.com/quic-go/quic-go/http3) -This package implements HTTP/3 ([RFC 9114](https://datatracker.ietf.org/doc/html/rfc9114)), including QPACK ([RFC 9204](https://datatracker.ietf.org/doc/html/rfc9204)). +This package implements HTTP/3 ([RFC 9114](https://datatracker.ietf.org/doc/html/rfc9114)), including QPACK ([RFC 9204](https://datatracker.ietf.org/doc/html/rfc9204)) and HTTP Datagrams ([RFC 9297](https://datatracker.ietf.org/doc/html/rfc9297)). It aims to provide feature parity with the standard library's HTTP/1.1 and HTTP/2 implementation. ## Serving HTTP/3 diff --git a/http3/client.go b/http3/client.go index 61ded624..ee9d4ec3 100644 --- a/http3/client.go +++ b/http3/client.go @@ -80,25 +80,25 @@ func (c *SingleDestinationRoundTripper) Start() Connection { } func (c *SingleDestinationRoundTripper) init() { - c.requestWriter = newRequestWriter() c.decoder = qpack.NewDecoder(func(hf qpack.HeaderField) {}) - c.hconn = newConnection(c.Connection, c.EnableDatagrams, c.UniStreamHijacker, protocol.PerspectiveClient, c.Logger) + c.requestWriter = newRequestWriter() + c.hconn = newConnection(c.Connection, c.EnableDatagrams, protocol.PerspectiveClient, c.Logger) // send the SETTINGs frame, using 0-RTT data, if possible go func() { - if err := c.setupConn(c.Connection); err != nil { + if err := c.setupConn(c.hconn); err != nil { if c.Logger != nil { - c.Logger.Debug("setting up connection failed", "error", err) + c.Logger.Debug("Setting up connection failed", "error", err) } - c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeInternalError), "") + c.hconn.CloseWithError(quic.ApplicationErrorCode(ErrCodeInternalError), "") } }() if c.StreamHijacker != nil { go c.handleBidirectionalStreams() } - go c.hconn.HandleUnidirectionalStreams() + go c.hconn.HandleUnidirectionalStreams(c.UniStreamHijacker) } -func (c *SingleDestinationRoundTripper) setupConn(conn quic.Connection) error { +func (c *SingleDestinationRoundTripper) setupConn(conn *connection) error { // open the control stream str, err := conn.OpenUniStream() if err != nil { @@ -198,7 +198,8 @@ func (c *SingleDestinationRoundTripper) roundTrip(req *http.Request) (*http.Resp } } - str, err := c.Connection.OpenStreamSync(req.Context()) + reqDone := make(chan struct{}) + str, err := c.hconn.openRequestStream(req.Context(), c.requestWriter, reqDone, c.DisableCompression, c.maxHeaderBytes()) if err != nil { return nil, err } @@ -206,7 +207,6 @@ func (c *SingleDestinationRoundTripper) roundTrip(req *http.Request) (*http.Resp // Request Cancellation: // This go routine keeps running even after RoundTripOpt() returns. // It is shut down when the application is done processing the body. - reqDone := make(chan struct{}) done := make(chan struct{}) go func() { defer close(done) @@ -218,7 +218,7 @@ func (c *SingleDestinationRoundTripper) roundTrip(req *http.Request) (*http.Resp } }() - rsp, err := c.doRequest(req, str, reqDone) + rsp, err := c.doRequest(req, str) if err != nil { // if any error occurred close(reqDone) <-done @@ -230,18 +230,7 @@ func (c *SingleDestinationRoundTripper) roundTrip(req *http.Request) (*http.Resp func (c *SingleDestinationRoundTripper) OpenRequestStream(ctx context.Context) (RequestStream, error) { c.initOnce.Do(func() { c.init() }) - str, err := c.Connection.OpenStreamSync(ctx) - if err != nil { - return nil, err - } - return newRequestStream( - newStream(str, c.hconn), - c.requestWriter, - nil, - c.decoder, - c.DisableCompression, - c.maxHeaderBytes(), - ), nil + return c.hconn.openRequestStream(ctx, c.requestWriter, nil, c.DisableCompression, c.maxHeaderBytes()) } // cancelingReader reads from the io.Reader. @@ -283,20 +272,12 @@ func (c *SingleDestinationRoundTripper) sendRequestBody(str Stream, body io.Read return err } -func (c *SingleDestinationRoundTripper) doRequest(req *http.Request, str quic.Stream, reqDone chan<- struct{}) (*http.Response, error) { - hstr := newRequestStream( - newStream(str, c.hconn), - c.requestWriter, - reqDone, - c.decoder, - c.DisableCompression, - c.maxHeaderBytes(), - ) - if err := hstr.SendRequestHeader(req); err != nil { +func (c *SingleDestinationRoundTripper) doRequest(req *http.Request, str *requestStream) (*http.Response, error) { + if err := str.SendRequestHeader(req); err != nil { return nil, err } if req.Body == nil { - hstr.Close() + str.Close() } else { // send the request body asynchronously go func() { @@ -306,27 +287,25 @@ func (c *SingleDestinationRoundTripper) doRequest(req *http.Request, str quic.St if req.ContentLength > 0 { contentLength = req.ContentLength } - if err := c.sendRequestBody(hstr, req.Body, contentLength); err != nil { + if err := c.sendRequestBody(str, req.Body, contentLength); err != nil { if c.Logger != nil { c.Logger.Debug("error writing request", "error", err) } } - hstr.Close() + str.Close() }() } - var ( - res *http.Response - err error - ) - // copy from net/http: support 1xx responses trace := httptrace.ContextClientTrace(req.Context()) num1xx := 0 // number of informational 1xx headers received const max1xxResponses = 5 // arbitrary bound on number of informational responses + var res *http.Response for { - if res, err = hstr.ReadResponse(); err != nil { + var err error + res, err = str.ReadResponse() + if err != nil { return nil, err } resCode := res.StatusCode @@ -347,7 +326,7 @@ func (c *SingleDestinationRoundTripper) doRequest(req *http.Request, str quic.St } break } - connState := c.Connection.ConnectionState().TLS + connState := c.hconn.ConnectionState().TLS res.TLS = &connState res.Request = req return res, nil diff --git a/http3/client_test.go b/http3/client_test.go index 443087a7..20ef9696 100644 --- a/http3/client_test.go +++ b/http3/client_test.go @@ -27,7 +27,7 @@ func encodeResponse(status int) []byte { buf := &bytes.Buffer{} rstr := mockquic.NewMockStream(mockCtrl) rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes() - rw := newResponseWriter(newStream(rstr, nil), nil, false, nil) + rw := newResponseWriter(newStream(rstr, nil, nil), nil, false, nil) if status == http.StatusEarlyHints { rw.header.Add("Link", "; rel=preload; as=style") rw.header.Add("Link", "; rel=preload; as=script") @@ -361,7 +361,7 @@ var _ = Describe("Client", func() { return nil, errors.New("test done") }).MaxTimes(1) b := quicvarint.Append(nil, streamTypeControlStream) - b = (&settingsFrame{Datagram: true}).Append(b) + b = (&settingsFrame{ExtendedConnect: true}).Append(b) r := bytes.NewReader(b) controlStr := mockquic.NewMockStream(mockCtrl) controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() @@ -375,7 +375,7 @@ var _ = Describe("Client", func() { hconn := rt.Start() Eventually(hconn.ReceivedSettings()).Should(BeClosed()) settings := hconn.Settings() - Expect(settings.EnableDatagram).To(BeTrue()) + Expect(settings.EnableExtendedConnect).To(BeTrue()) // test shutdown conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).MaxTimes(1) close(done) @@ -425,7 +425,7 @@ var _ = Describe("Client", func() { return nil, errors.New("test done") }).MaxTimes(1) b := quicvarint.Append(nil, streamTypeControlStream) - b = (&settingsFrame{Datagram: true}).Append(b) + b = (&settingsFrame{}).Append(b) r := bytes.NewReader(b) controlStr := mockquic.NewMockStream(mockCtrl) controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() @@ -493,6 +493,7 @@ var _ = Describe("Client", func() { return len(b), nil }) // SETTINGS frame str = mockquic.NewMockStream(mockCtrl) + str.EXPECT().StreamID().AnyTimes() conn = mockquic.NewMockEarlyConnection(mockCtrl) conn.EXPECT().OpenUniStream().Return(controlStr, nil) conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { @@ -729,7 +730,6 @@ var _ = Describe("Client", func() { conn.EXPECT().OpenStreamSync(ctx).Return(str, nil) buf := &bytes.Buffer{} str.EXPECT().Close().MaxTimes(1) - str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write) done := make(chan struct{}) @@ -823,8 +823,9 @@ var _ = Describe("Client", func() { conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}) buf := &bytes.Buffer{} rstr := mockquic.NewMockStream(mockCtrl) + rstr.EXPECT().StreamID().AnyTimes() rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes() - rw := newResponseWriter(newStream(rstr, nil), nil, false, nil) + rw := newResponseWriter(newStream(rstr, nil, nil), nil, false, nil) rw.Header().Set("Content-Encoding", "gzip") gz := gzip.NewWriter(rw) gz.Write([]byte("gzipped response")) @@ -849,8 +850,9 @@ var _ = Describe("Client", func() { conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}) buf := &bytes.Buffer{} rstr := mockquic.NewMockStream(mockCtrl) + rstr.EXPECT().StreamID().AnyTimes() rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes() - rw := newResponseWriter(newStream(rstr, nil), nil, false, nil) + rw := newResponseWriter(newStream(rstr, nil, nil), nil, false, nil) rw.Write([]byte("not gzipped")) rw.Flush() str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil }) diff --git a/http3/conn.go b/http3/conn.go index 1898d92a..55e199a8 100644 --- a/http3/conn.go +++ b/http3/conn.go @@ -1,14 +1,19 @@ package http3 import ( + "bytes" "context" + "fmt" "log/slog" "net" + "sync" "sync/atomic" "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/quicvarint" + + "github.com/quic-go/qpack" ) // Connection is an HTTP/3 connection. @@ -37,8 +42,12 @@ type connection struct { perspective protocol.Perspective logger *slog.Logger - enableDatagrams bool - uniStreamHijacker func(StreamType, quic.ConnectionTracingID, quic.ReceiveStream, error) (hijacked bool) + enableDatagrams bool + + decoder *qpack.Decoder + + streamMx sync.Mutex + streams map[protocol.StreamID]*datagrammer settings *Settings receivedSettings chan struct{} @@ -47,21 +56,80 @@ type connection struct { func newConnection( quicConn quic.Connection, enableDatagrams bool, - uniStreamHijacker func(StreamType, quic.ConnectionTracingID, quic.ReceiveStream, error) (hijacked bool), perspective protocol.Perspective, logger *slog.Logger, ) *connection { - return &connection{ - Connection: quicConn, - perspective: perspective, - logger: logger, - enableDatagrams: enableDatagrams, - uniStreamHijacker: uniStreamHijacker, - receivedSettings: make(chan struct{}), + c := &connection{ + Connection: quicConn, + perspective: perspective, + logger: logger, + enableDatagrams: enableDatagrams, + decoder: qpack.NewDecoder(func(hf qpack.HeaderField) {}), + receivedSettings: make(chan struct{}), + streams: make(map[protocol.StreamID]*datagrammer), + } + return c +} + +func (c *connection) onStreamStateChange(id quic.StreamID, state streamState, e error) { + c.streamMx.Lock() + defer c.streamMx.Unlock() + + d, ok := c.streams[id] + if !ok { // should never happen + return + } + var isDone bool + //nolint:exhaustive // These are all the cases we care about. + switch state { + case streamStateReceiveClosed: + isDone = d.SetReceiveError(e) + case streamStateSendClosed: + isDone = d.SetSendError(e) + default: + return + } + if isDone { + delete(c.streams, id) } } -func (c *connection) HandleUnidirectionalStreams() { +func (c *connection) openRequestStream( + ctx context.Context, + requestWriter *requestWriter, + reqDone chan<- struct{}, + disableCompression bool, + maxHeaderBytes uint64, +) (*requestStream, error) { + str, err := c.Connection.OpenStreamSync(ctx) + if err != nil { + return nil, err + } + datagrams := newDatagrammer(func(b []byte) error { return c.sendDatagram(str.StreamID(), b) }) + c.streamMx.Lock() + c.streams[str.StreamID()] = datagrams + c.streamMx.Unlock() + qstr := newStateTrackingStream(str, func(s streamState, e error) { c.onStreamStateChange(str.StreamID(), s, e) }) + hstr := newStream(qstr, c, datagrams) + return newRequestStream(hstr, requestWriter, reqDone, c.decoder, disableCompression, maxHeaderBytes), nil +} + +func (c *connection) acceptStream(ctx context.Context) (quic.Stream, *datagrammer, error) { + str, err := c.AcceptStream(ctx) + if err != nil { + return nil, nil, err + } + datagrams := newDatagrammer(func(b []byte) error { return c.sendDatagram(str.StreamID(), b) }) + if c.perspective == protocol.PerspectiveServer { + c.streamMx.Lock() + c.streams[str.StreamID()] = datagrams + c.streamMx.Unlock() + str = newStateTrackingStream(str, func(s streamState, e error) { c.onStreamStateChange(str.StreamID(), s, e) }) + } + return str, datagrams, nil +} + +func (c *connection) HandleUnidirectionalStreams(hijack func(StreamType, quic.ConnectionTracingID, quic.ReceiveStream, error) (hijacked bool)) { var ( rcvdControlStr atomic.Bool rcvdQPACKEncoderStr atomic.Bool @@ -81,7 +149,7 @@ func (c *connection) HandleUnidirectionalStreams() { streamType, err := quicvarint.Read(quicvarint.NewReader(str)) if err != nil { id := c.Connection.Context().Value(quic.ConnectionTracingKey).(quic.ConnectionTracingID) - if c.uniStreamHijacker != nil && c.uniStreamHijacker(StreamType(streamType), id, str, err) { + if hijack != nil && hijack(StreamType(streamType), id, str, err) { return } if c.logger != nil { @@ -115,8 +183,8 @@ func (c *connection) HandleUnidirectionalStreams() { } return default: - if c.uniStreamHijacker != nil { - if c.uniStreamHijacker( + if hijack != nil { + if hijack( StreamType(streamType), c.Connection.Context().Value(quic.ConnectionTracingKey).(quic.ConnectionTracingID), str, @@ -148,9 +216,7 @@ func (c *connection) HandleUnidirectionalStreams() { EnableExtendedConnect: sf.ExtendedConnect, Other: sf.Other, } - if c.receivedSettings != nil { - close(c.receivedSettings) - } + close(c.receivedSettings) if !sf.Datagram { return } @@ -159,11 +225,56 @@ func (c *connection) HandleUnidirectionalStreams() { // Note: ConnectionState() will block until the handshake is complete (relevant when using 0-RTT). if c.enableDatagrams && !c.Connection.ConnectionState().SupportsDatagrams { c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeSettingsError), "missing QUIC Datagram support") + return } + go func() { + if err := c.receiveDatagrams(); err != nil { + if c.logger != nil { + c.logger.Debug("receiving datagrams failed", "error", err) + } + } + }() }(str) } } +func (c *connection) sendDatagram(streamID protocol.StreamID, b []byte) error { + // TODO: this creates a lot of garbage and an additional copy + data := make([]byte, 0, len(b)+8) + data = quicvarint.Append(data, uint64(streamID/4)) + data = append(data, b...) + return c.Connection.SendDatagram(data) +} + +func (c *connection) receiveDatagrams() error { + for { + b, err := c.Connection.ReceiveDatagram(context.Background()) + if err != nil { + return err + } + // TODO: this is quite wasteful in terms of allocations + r := bytes.NewReader(b) + quarterStreamID, err := quicvarint.Read(r) + if err != nil { + c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeDatagramError), "") + return fmt.Errorf("could not read quarter stream id: %w", err) + } + if quarterStreamID > maxQuarterStreamID { + c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeDatagramError), "") + return fmt.Errorf("invalid quarter stream id: %w", err) + } + streamID := protocol.StreamID(4 * quarterStreamID) + c.streamMx.Lock() + dg, ok := c.streams[streamID] + if !ok { + c.streamMx.Unlock() + return nil + } + c.streamMx.Unlock() + dg.enqueue(b[len(b)-r.Len():]) + } +} + // ReceivedSettings returns a channel that is closed once the peer's SETTINGS frame was received. func (c *connection) ReceivedSettings() <-chan struct{} { return c.receivedSettings } diff --git a/http3/conn_test.go b/http3/conn_test.go index 45c66170..2ca9473b 100644 --- a/http3/conn_test.go +++ b/http3/conn_test.go @@ -22,10 +22,10 @@ var _ = Describe("Connection", func() { Context("control stream handling", func() { It("parses the SETTINGS frame", func() { qconn := mockquic.NewMockEarlyConnection(mockCtrl) + qconn.EXPECT().ReceiveDatagram(gomock.Any()).Return(nil, errors.New("no datagrams")) conn := newConnection( qconn, false, - nil, protocol.PerspectiveServer, nil, ) @@ -44,7 +44,7 @@ var _ = Describe("Connection", func() { go func() { defer GinkgoRecover() defer close(done) - conn.HandleUnidirectionalStreams() + conn.HandleUnidirectionalStreams(nil) }() Eventually(conn.ReceivedSettings()).Should(BeClosed()) Expect(conn.Settings().EnableDatagram).To(BeTrue()) @@ -58,7 +58,6 @@ var _ = Describe("Connection", func() { conn := newConnection( qconn, false, - nil, protocol.PerspectiveServer, nil, ) @@ -82,7 +81,7 @@ var _ = Describe("Connection", func() { go func() { defer GinkgoRecover() defer close(done) - conn.HandleUnidirectionalStreams() + conn.HandleUnidirectionalStreams(nil) }() Eventually(closed).Should(BeClosed()) Eventually(done).Should(BeClosed()) @@ -100,7 +99,6 @@ var _ = Describe("Connection", func() { conn := newConnection( qconn, false, - nil, protocol.PerspectiveClient, nil, ) @@ -119,7 +117,7 @@ var _ = Describe("Connection", func() { go func() { defer GinkgoRecover() defer close(done) - conn.HandleUnidirectionalStreams() + conn.HandleUnidirectionalStreams(nil) }() Eventually(done).Should(BeClosed()) }) @@ -129,7 +127,6 @@ var _ = Describe("Connection", func() { conn := newConnection( qconn, false, - nil, protocol.PerspectiveClient, nil, ) @@ -154,7 +151,7 @@ var _ = Describe("Connection", func() { go func() { defer GinkgoRecover() defer close(done) - conn.HandleUnidirectionalStreams() + conn.HandleUnidirectionalStreams(nil) }() Eventually(done).Should(BeClosed()) }) @@ -165,7 +162,6 @@ var _ = Describe("Connection", func() { conn := newConnection( qconn, false, - nil, protocol.PerspectiveServer, nil, ) @@ -180,7 +176,7 @@ var _ = Describe("Connection", func() { go func() { defer GinkgoRecover() defer close(done) - conn.HandleUnidirectionalStreams() + conn.HandleUnidirectionalStreams(nil) }() Eventually(done).Should(BeClosed()) Eventually(reset).Should(BeClosed()) @@ -191,7 +187,6 @@ var _ = Describe("Connection", func() { conn := newConnection( qconn, false, - nil, protocol.PerspectiveServer, nil, ) @@ -211,7 +206,7 @@ var _ = Describe("Connection", func() { go func() { defer GinkgoRecover() defer close(done) - conn.HandleUnidirectionalStreams() + conn.HandleUnidirectionalStreams(nil) }() Eventually(done).Should(BeClosed()) Eventually(closed).Should(BeClosed()) @@ -222,7 +217,6 @@ var _ = Describe("Connection", func() { conn := newConnection( qconn, false, - nil, protocol.PerspectiveServer, nil, ) @@ -242,7 +236,7 @@ var _ = Describe("Connection", func() { go func() { defer GinkgoRecover() defer close(done) - conn.HandleUnidirectionalStreams() + conn.HandleUnidirectionalStreams(nil) }() Eventually(done).Should(BeClosed()) Eventually(closed).Should(BeClosed()) @@ -260,7 +254,6 @@ var _ = Describe("Connection", func() { conn := newConnection( qconn, false, - nil, pers.Opposite(), nil, ) @@ -278,7 +271,7 @@ var _ = Describe("Connection", func() { go func() { defer GinkgoRecover() defer close(done) - conn.HandleUnidirectionalStreams() + conn.HandleUnidirectionalStreams(nil) }() Eventually(done).Should(BeClosed()) Eventually(closed).Should(BeClosed()) @@ -290,7 +283,6 @@ var _ = Describe("Connection", func() { conn := newConnection( qconn, true, - nil, protocol.PerspectiveClient, nil, ) @@ -311,10 +303,127 @@ var _ = Describe("Connection", func() { go func() { defer GinkgoRecover() defer close(done) - conn.HandleUnidirectionalStreams() + conn.HandleUnidirectionalStreams(nil) }() Eventually(done).Should(BeClosed()) Eventually(closed).Should(BeClosed()) }) }) + + Context("datagram handling", func() { + var ( + qconn *mockquic.MockEarlyConnection + conn *connection + ) + + BeforeEach(func() { + qconn = mockquic.NewMockEarlyConnection(mockCtrl) + conn = newConnection( + qconn, + true, + protocol.PerspectiveClient, + nil, + ) + b := quicvarint.Append(nil, streamTypeControlStream) + b = (&settingsFrame{Datagram: true}).Append(b) + r := bytes.NewReader(b) + controlStr := mockquic.NewMockStream(mockCtrl) + controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() + qconn.EXPECT().AcceptUniStream(gomock.Any()).Return(controlStr, nil).MaxTimes(1) + qconn.EXPECT().AcceptUniStream(gomock.Any()).Return(nil, errors.New("test done")).MaxTimes(1) + qconn.EXPECT().ConnectionState().Return(quic.ConnectionState{SupportsDatagrams: true}).MaxTimes(1) + }) + + It("closes the connection if it can't parse the quarter stream ID", func() { + qconn.EXPECT().ReceiveDatagram(gomock.Any()).Return([]byte{128}, nil) // return an invalid varint + done := make(chan struct{}) + qconn.EXPECT().CloseWithError(qerr.ApplicationErrorCode(ErrCodeDatagramError), gomock.Any()).Do(func(qerr.ApplicationErrorCode, string) error { + close(done) + return nil + }) + go func() { + defer GinkgoRecover() + conn.HandleUnidirectionalStreams(nil) + }() + Eventually(done).Should(BeClosed()) + }) + + It("closes the connection if the quarter stream ID is invalid", func() { + b := quicvarint.Append([]byte{}, maxQuarterStreamID+1) + qconn.EXPECT().ReceiveDatagram(gomock.Any()).Return(b, nil) + done := make(chan struct{}) + qconn.EXPECT().CloseWithError(qerr.ApplicationErrorCode(ErrCodeDatagramError), gomock.Any()).Do(func(qerr.ApplicationErrorCode, string) error { + close(done) + return nil + }) + go func() { + defer GinkgoRecover() + conn.HandleUnidirectionalStreams(nil) + }() + Eventually(done).Should(BeClosed()) + }) + + It("drops datagrams for non-existent streams", func() { + const strID = 4 + // first deliver the datagram... + b := quicvarint.Append([]byte{}, strID/4) + b = append(b, []byte("foobar")...) + delivered := make(chan struct{}) + qconn.EXPECT().ReceiveDatagram(gomock.Any()).DoAndReturn(func(context.Context) ([]byte, error) { + close(delivered) + return b, nil + }) + go func() { + defer GinkgoRecover() + conn.HandleUnidirectionalStreams(nil) + }() + Eventually(delivered).Should(BeClosed()) + + // ... then open the stream + qstr := mockquic.NewMockStream(mockCtrl) + qstr.EXPECT().StreamID().Return(strID).MinTimes(1) + qconn.EXPECT().OpenStreamSync(gomock.Any()).Return(qstr, nil) + str, err := conn.openRequestStream(context.Background(), nil, nil, true, 1000) + Expect(err).ToNot(HaveOccurred()) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + _, err = str.ReceiveDatagram(ctx) + Expect(err).To(MatchError(context.Canceled)) + }) + + It("delivers datagrams for existing streams", func() { + const strID = 4 + + // first open the stream... + qstr := mockquic.NewMockStream(mockCtrl) + qstr.EXPECT().StreamID().Return(strID).MinTimes(1) + qconn.EXPECT().OpenStreamSync(gomock.Any()).Return(qstr, nil) + str, err := conn.openRequestStream(context.Background(), nil, nil, true, 1000) + Expect(err).ToNot(HaveOccurred()) + + // ... then deliver the datagram + b := quicvarint.Append([]byte{}, strID/4) + b = append(b, []byte("foobar")...) + qconn.EXPECT().ReceiveDatagram(gomock.Any()).Return(b, nil) + qconn.EXPECT().ReceiveDatagram(gomock.Any()).Return(nil, errors.New("test done")) + go func() { + defer GinkgoRecover() + conn.HandleUnidirectionalStreams(nil) + }() + + data, err := str.ReceiveDatagram(context.Background()) + Expect(err).ToNot(HaveOccurred()) + Expect(data).To(Equal([]byte("foobar"))) + }) + + It("sends datagrams", func() { + const strID = 404 + expected := quicvarint.Append([]byte{}, strID/4) + expected = append(expected, []byte("foobar")...) + testErr := errors.New("test error") + qconn.EXPECT().SendDatagram(expected).Return(testErr) + + Expect(conn.sendDatagram(strID, []byte("foobar"))).To(MatchError(testErr)) + }) + }) }) diff --git a/http3/datagram.go b/http3/datagram.go new file mode 100644 index 00000000..0a502881 --- /dev/null +++ b/http3/datagram.go @@ -0,0 +1,100 @@ +package http3 + +import ( + "context" + "sync" +) + +const maxQuarterStreamID = 1<<60 - 1 + +const streamDatagramQueueLen = 32 + +type datagrammer struct { + sendDatagram func([]byte) error + + hasData chan struct{} + queue [][]byte // TODO: use a ring buffer + + mx sync.Mutex + sendErr error + receiveErr error +} + +func newDatagrammer(sendDatagram func([]byte) error) *datagrammer { + return &datagrammer{ + sendDatagram: sendDatagram, + hasData: make(chan struct{}, 1), + } +} + +func (d *datagrammer) SetReceiveError(err error) (isDone bool) { + d.mx.Lock() + defer d.mx.Unlock() + + d.receiveErr = err + d.signalHasData() + return d.sendErr != nil +} + +func (d *datagrammer) SetSendError(err error) (isDone bool) { + d.mx.Lock() + defer d.mx.Unlock() + + d.sendErr = err + return d.receiveErr != nil +} + +func (d *datagrammer) Send(b []byte) error { + d.mx.Lock() + sendErr := d.sendErr + d.mx.Unlock() + if sendErr != nil { + return sendErr + } + + return d.sendDatagram(b) +} + +func (d *datagrammer) signalHasData() { + select { + case d.hasData <- struct{}{}: + default: + } +} + +func (d *datagrammer) enqueue(data []byte) { + d.mx.Lock() + defer d.mx.Unlock() + + if d.receiveErr != nil { + return + } + if len(d.queue) >= streamDatagramQueueLen { + return + } + d.queue = append(d.queue, data) + d.signalHasData() +} + +func (d *datagrammer) Receive(ctx context.Context) ([]byte, error) { +start: + d.mx.Lock() + if len(d.queue) >= 1 { + data := d.queue[0] + d.queue = d.queue[1:] + d.mx.Unlock() + return data, nil + } + if d.receiveErr != nil { + d.mx.Unlock() + return nil, d.receiveErr + } + d.mx.Unlock() + + select { + case <-ctx.Done(): + return nil, context.Cause(ctx) + case <-d.hasData: + } + goto start +} diff --git a/http3/datagram_test.go b/http3/datagram_test.go new file mode 100644 index 00000000..85a6a823 --- /dev/null +++ b/http3/datagram_test.go @@ -0,0 +1,76 @@ +package http3 + +import ( + "context" + "errors" + "time" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("Datagrams", func() { + It("receives a datagram", func() { + dg := newDatagrammer(nil) + dg.enqueue([]byte("foobar")) + data, err := dg.Receive(context.Background()) + Expect(err).ToNot(HaveOccurred()) + Expect(data).To(Equal([]byte("foobar"))) + }) + + It("queues up to 32 datagrams", func() { + dg := newDatagrammer(nil) + for i := 0; i < streamDatagramQueueLen+1; i++ { + dg.enqueue([]byte{uint8(i)}) + } + for i := 0; i < streamDatagramQueueLen; i++ { + data, err := dg.Receive(context.Background()) + Expect(err).ToNot(HaveOccurred()) + Expect(data[0]).To(BeEquivalentTo(i)) + } + ctx, cancel := context.WithCancel(context.Background()) + cancel() + _, err := dg.Receive(ctx) + Expect(err).To(MatchError(context.Canceled)) + }) + + It("blocks until a new datagram is received", func() { + dg := newDatagrammer(nil) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + defer close(done) + data, err := dg.Receive(context.Background()) + Expect(err).ToNot(HaveOccurred()) + Expect(data).To(Equal([]byte("foobar"))) + }() + + Consistently(done, 50*time.Millisecond).ShouldNot(BeClosed()) + dg.enqueue([]byte("foobar")) + Eventually(done).Should(BeClosed()) + }) + + It("drops datagrams when the stream's receive side is closed", func() { + dg := newDatagrammer(nil) + dg.enqueue([]byte("foo")) + testErr := errors.New("test error") + Expect(dg.SetReceiveError(testErr)).To(BeFalse()) + dg.enqueue([]byte("bar")) + data, err := dg.Receive(context.Background()) + Expect(err).ToNot(HaveOccurred()) + Expect(data).To(Equal([]byte("foo"))) + _, err = dg.Receive(context.Background()) + Expect(err).To(MatchError(testErr)) + }) + + It("sends datagrams", func() { + var sent []byte + testErr := errors.New("test error") + dg := newDatagrammer(func(b []byte) error { + sent = b + return testErr + }) + Expect(dg.Send([]byte("foobar"))).To(MatchError(testErr)) + Expect(sent).To(Equal([]byte("foobar"))) + }) +}) diff --git a/http3/http_stream.go b/http3/http_stream.go index 32792a4f..aa860076 100644 --- a/http3/http_stream.go +++ b/http3/http_stream.go @@ -1,6 +1,7 @@ package http3 import ( + "context" "errors" "fmt" "io" @@ -15,12 +16,17 @@ import ( // A Stream is an HTTP/3 request stream. // When writing to and reading from the stream, data is framed in HTTP/3 DATA frames. -type Stream = quic.Stream +type Stream interface { + quic.Stream + + SendDatagram([]byte) error + ReceiveDatagram(context.Context) ([]byte, error) +} // A RequestStream is an HTTP/3 request stream. // When writing to and reading from the stream, data is framed in HTTP/3 DATA frames. type RequestStream interface { - quic.Stream + Stream // SendRequestHeader sends the HTTP request. // It is invalid to call it more than once. @@ -36,20 +42,23 @@ type RequestStream interface { type stream struct { quic.Stream - conn quic.Connection + conn *connection buf []byte // used as a temporary buffer when writing the HTTP/3 frame headers bytesRemainingInFrame uint64 + + datagrams *datagrammer } var _ Stream = &stream{} -func newStream(str quic.Stream, conn quic.Connection) *stream { +func newStream(str quic.Stream, conn *connection, datagrams *datagrammer) *stream { return &stream{ - Stream: str, - conn: conn, - buf: make([]byte, 0, 16), + Stream: str, + conn: conn, + buf: make([]byte, 16), + datagrams: datagrams, } } @@ -238,3 +247,13 @@ func (s *requestStream) ReadResponse() (*http.Response, error) { res.Body = s.responseBody return res, nil } + +func (s *stream) SendDatagram(b []byte) error { + // TODO: reject if datagrams are not negotiated (yet) + return s.conn.sendDatagram(s.Stream.StreamID(), b) +} + +func (s *stream) ReceiveDatagram(ctx context.Context) ([]byte, error) { + // TODO: reject if datagrams are not negotiated (yet) + return s.datagrams.Receive(ctx) +} diff --git a/http3/http_stream_test.go b/http3/http_stream_test.go index 6e989ec3..6be1c6bd 100644 --- a/http3/http_stream_test.go +++ b/http3/http_stream_test.go @@ -7,6 +7,7 @@ import ( "net/http" mockquic "github.com/quic-go/quic-go/internal/mocks/quic" + "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qerr" "github.com/quic-go/qpack" @@ -41,7 +42,7 @@ var _ = Describe("Stream", func() { errorCbCalled = true return nil }).AnyTimes() - str = newStream(qstr, conn) + str = newStream(qstr, newConnection(conn, false, protocol.PerspectiveClient, nil), nil) }) It("reads DATA frames in a single run", func() { @@ -135,7 +136,7 @@ var _ = Describe("Stream", func() { buf := &bytes.Buffer{} qstr := mockquic.NewMockStream(mockCtrl) qstr.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes() - str := newStream(qstr, nil) + str := newStream(qstr, nil, nil) str.Write([]byte("foo")) str.Write([]byte("foobar")) @@ -167,7 +168,7 @@ var _ = Describe("Request Stream", func() { requestWriter := newRequestWriter() conn := mockquic.NewMockEarlyConnection(mockCtrl) str = newRequestStream( - newStream(qstr, conn), + newStream(qstr, newConnection(conn, false, protocol.PerspectiveClient, nil), nil), requestWriter, make(chan struct{}), qpack.NewDecoder(func(qpack.HeaderField) {}), diff --git a/http3/response_writer_test.go b/http3/response_writer_test.go index 290eb7f3..f49e0730 100644 --- a/http3/response_writer_test.go +++ b/http3/response_writer_test.go @@ -26,7 +26,7 @@ var _ = Describe("Response Writer", func() { str.EXPECT().Write(gomock.Any()).DoAndReturn(strBuf.Write).AnyTimes() str.EXPECT().SetReadDeadline(gomock.Any()).Return(nil).AnyTimes() str.EXPECT().SetWriteDeadline(gomock.Any()).Return(nil).AnyTimes() - rw = newResponseWriter(newStream(str, nil), nil, false, nil) + rw = newResponseWriter(newStream(str, nil, nil), nil, false, nil) }) decodeHeader := func(str io.Reader) map[string][]string { diff --git a/http3/server.go b/http3/server.go index 37cf1956..7935db47 100644 --- a/http3/server.go +++ b/http3/server.go @@ -422,8 +422,6 @@ func (s *Server) removeListener(l *QUICEarlyListener) { } func (s *Server) handleConn(conn quic.Connection) error { - decoder := qpack.NewDecoder(nil) - // send a SETTINGS frame str, err := conn.OpenUniStream() if err != nil { @@ -441,16 +439,14 @@ func (s *Server) handleConn(conn quic.Connection) error { hconn := newConnection( conn, s.EnableDatagrams, - s.UniStreamHijacker, protocol.PerspectiveServer, s.Logger, ) - go hconn.HandleUnidirectionalStreams() - + go hconn.HandleUnidirectionalStreams(s.UniStreamHijacker) // Process all requests immediately. // It's the client's responsibility to decide which requests are eligible for 0-RTT. for { - str, err := conn.AcceptStream(context.Background()) + str, datagrams, err := hconn.acceptStream(context.Background()) if err != nil { var appErr *quic.ApplicationError if errors.As(err, &appErr) && appErr.ErrorCode == quic.ApplicationErrorCode(ErrCodeNoError) { @@ -458,7 +454,7 @@ func (s *Server) handleConn(conn quic.Connection) error { } return fmt.Errorf("accepting stream failed: %w", err) } - go s.handleRequest(hconn, str, decoder) + go s.handleRequest(hconn, str, datagrams, hconn.decoder) } } @@ -469,7 +465,7 @@ func (s *Server) maxHeaderBytes() uint64 { return uint64(s.MaxHeaderBytes) } -func (s *Server) handleRequest(conn *connection, str quic.Stream, decoder *qpack.Decoder) { +func (s *Server) handleRequest(conn *connection, str quic.Stream, datagrams *datagrammer, decoder *qpack.Decoder) { var ufh unknownFrameHandlerFunc if s.StreamHijacker != nil { ufh = func(ft FrameType, e error) (processed bool, err error) { @@ -528,7 +524,7 @@ func (s *Server) handleRequest(conn *connection, str quic.Stream, decoder *qpack if _, ok := req.Header["Content-Length"]; ok && req.ContentLength >= 0 { contentLength = req.ContentLength } - hstr := newStream(str, conn) + hstr := newStream(str, conn, datagrams) body := newRequestBody(hstr, contentLength, conn.Context(), conn.ReceivedSettings(), conn.Settings) req.Body = body @@ -553,6 +549,7 @@ func (s *Server) handleRequest(conn *connection, str quic.Stream, decoder *qpack handler = http.DefaultServeMux } + // It's the client's responsibility to decide which requests are eligible for 0-RTT. var panicked bool func() { defer func() { diff --git a/http3/server_test.go b/http3/server_test.go index d8c9c358..892006db 100644 --- a/http3/server_test.go +++ b/http3/server_test.go @@ -139,13 +139,14 @@ var _ = Describe("Server", func() { qpackDecoder = qpack.NewDecoder(nil) str = mockquic.NewMockStream(mockCtrl) + str.EXPECT().StreamID().AnyTimes() qconn := mockquic.NewMockEarlyConnection(mockCtrl) addr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} qconn.EXPECT().RemoteAddr().Return(addr).AnyTimes() qconn.EXPECT().LocalAddr().AnyTimes() qconn.EXPECT().ConnectionState().Return(quic.ConnectionState{}).AnyTimes() qconn.EXPECT().Context().Return(context.Background()).AnyTimes() - conn = newConnection(qconn, false, nil, protocol.PerspectiveServer, nil) + conn = newConnection(qconn, false, protocol.PerspectiveServer, nil) }) It("calls the HTTP handler function", func() { @@ -162,7 +163,7 @@ var _ = Describe("Server", func() { str.EXPECT().CancelRead(gomock.Any()) str.EXPECT().Close() - s.handleRequest(conn, str, qpackDecoder) + s.handleRequest(conn, str, nil, qpackDecoder) var req *http.Request Eventually(requestChan).Should(Receive(&req)) Expect(req.Host).To(Equal("www.example.com")) @@ -181,7 +182,7 @@ var _ = Describe("Server", func() { str.EXPECT().CancelRead(gomock.Any()) str.EXPECT().Close() - s.handleRequest(conn, str, qpackDecoder) + s.handleRequest(conn, str, nil, qpackDecoder) hfs := decodeHeader(responseBuf) Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"})) }) @@ -198,7 +199,7 @@ var _ = Describe("Server", func() { str.EXPECT().CancelRead(gomock.Any()) str.EXPECT().Close() - s.handleRequest(conn, str, qpackDecoder) + s.handleRequest(conn, str, nil, qpackDecoder) hfs := decodeHeader(responseBuf) Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"})) Expect(hfs).To(HaveKeyWithValue("content-length", []string{"6"})) @@ -220,7 +221,7 @@ var _ = Describe("Server", func() { str.EXPECT().CancelRead(gomock.Any()) str.EXPECT().Close() - s.handleRequest(conn, str, qpackDecoder) + s.handleRequest(conn, str, nil, qpackDecoder) hfs := decodeHeader(responseBuf) Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"})) // status, date, content-type @@ -241,7 +242,7 @@ var _ = Describe("Server", func() { str.EXPECT().CancelRead(gomock.Any()) str.EXPECT().Close() - s.handleRequest(conn, str, qpackDecoder) + s.handleRequest(conn, str, nil, qpackDecoder) hfs := decodeHeader(responseBuf) Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"})) Expect(responseBuf.Bytes()).To(BeEmpty()) @@ -261,7 +262,7 @@ var _ = Describe("Server", func() { str.EXPECT().CancelRead(gomock.Any()) str.EXPECT().Close() - s.handleRequest(conn, str, qpackDecoder) + s.handleRequest(conn, str, nil, qpackDecoder) hfs := decodeHeader(responseBuf) Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"})) Expect(hfs).To(HaveKeyWithValue("content-length", []string{"13"})) @@ -280,7 +281,7 @@ var _ = Describe("Server", func() { str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeInternalError)) str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeInternalError)) - s.handleRequest(conn, str, qpackDecoder) + s.handleRequest(conn, str, nil, qpackDecoder) Expect(responseBuf.Bytes()).To(HaveLen(0)) }) @@ -298,7 +299,7 @@ var _ = Describe("Server", func() { str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeInternalError)) str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeInternalError)) - s.handleRequest(conn, str, qpackDecoder) + s.handleRequest(conn, str, nil, qpackDecoder) Expect(responseBuf.Bytes()).To(HaveLen(0)) Expect(logBuf.String()).To(ContainSubstring("http: panic serving")) Expect(logBuf.String()).To(ContainSubstring("foobar")) @@ -334,6 +335,7 @@ var _ = Describe("Server", func() { buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41)) unknownStr := mockquic.NewMockStream(mockCtrl) unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() + unknownStr.EXPECT().StreamID().AnyTimes() conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done")) conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { @@ -357,6 +359,7 @@ var _ = Describe("Server", func() { buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41)) unknownStr := mockquic.NewMockStream(mockCtrl) + unknownStr.EXPECT().StreamID().AnyTimes() unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() unknownStr.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeRequestIncomplete)) unknownStr.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestIncomplete)) @@ -383,6 +386,7 @@ var _ = Describe("Server", func() { buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41)) unknownStr := mockquic.NewMockStream(mockCtrl) + unknownStr.EXPECT().StreamID().AnyTimes() unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() unknownStr.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeRequestIncomplete)) unknownStr.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestIncomplete)) @@ -400,17 +404,18 @@ var _ = Describe("Server", func() { }) It("handles errors that occur when reading the stream type", func() { + const strID = protocol.StreamID(1234 * 4) testErr := errors.New("test error") done := make(chan struct{}) unknownStr := mockquic.NewMockStream(mockCtrl) s.StreamHijacker = func(ft FrameType, _ quic.ConnectionTracingID, str quic.Stream, err error) (bool, error) { defer close(done) Expect(ft).To(BeZero()) - Expect(str).To(Equal(unknownStr)) + Expect(str.StreamID()).To(Equal(strID)) Expect(err).To(MatchError(testErr)) return true, nil } - + unknownStr.EXPECT().StreamID().Return(strID).AnyTimes() unknownStr.EXPECT().Read(gomock.Any()).Return(0, testErr).AnyTimes() conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done")) @@ -690,7 +695,7 @@ var _ = Describe("Server", func() { str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeNoError)) str.EXPECT().Close() - s.handleRequest(conn, str, qpackDecoder) + s.handleRequest(conn, str, nil, qpackDecoder) Eventually(handlerCalled).Should(BeClosed()) }) @@ -713,7 +718,7 @@ var _ = Describe("Server", func() { str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeNoError)) str.EXPECT().Close() - s.handleRequest(conn, str, qpackDecoder) + s.handleRequest(conn, str, nil, qpackDecoder) Eventually(handlerCalled).Should(BeClosed()) }) }) diff --git a/http3/state_tracking_stream.go b/http3/state_tracking_stream.go new file mode 100644 index 00000000..90914ebc --- /dev/null +++ b/http3/state_tracking_stream.go @@ -0,0 +1,91 @@ +package http3 + +import ( + "errors" + "sync" + + "github.com/quic-go/quic-go" +) + +type streamState uint8 + +const ( + streamStateOpen streamState = iota + streamStateReceiveClosed + streamStateSendClosed + streamStateSendAndReceiveClosed +) + +type stateTrackingStream struct { + quic.Stream + + mx sync.Mutex + state streamState + + onStateChange func(streamState, error) +} + +func newStateTrackingStream(s quic.Stream, onStateChange func(streamState, error)) *stateTrackingStream { + return &stateTrackingStream{ + Stream: s, + state: streamStateOpen, + onStateChange: onStateChange, + } +} + +var _ quic.Stream = &stateTrackingStream{} + +func (s *stateTrackingStream) closeSend(e error) { + s.mx.Lock() + defer s.mx.Unlock() + + if s.state == streamStateReceiveClosed || s.state == streamStateSendAndReceiveClosed { + s.state = streamStateSendAndReceiveClosed + } else { + s.state = streamStateSendClosed + } + s.onStateChange(s.state, e) +} + +func (s *stateTrackingStream) closeReceive(e error) { + s.mx.Lock() + defer s.mx.Unlock() + + if s.state == streamStateSendClosed || s.state == streamStateSendAndReceiveClosed { + s.state = streamStateSendAndReceiveClosed + } else { + s.state = streamStateReceiveClosed + } + s.onStateChange(s.state, e) +} + +func (s *stateTrackingStream) Close() error { + s.closeSend(errors.New("write on closed stream")) + return s.Stream.Close() +} + +func (s *stateTrackingStream) CancelWrite(e quic.StreamErrorCode) { + s.closeSend(&quic.StreamError{StreamID: s.Stream.StreamID(), ErrorCode: e}) + s.Stream.CancelWrite(e) +} + +func (s *stateTrackingStream) Write(b []byte) (int, error) { + n, err := s.Stream.Write(b) + if err != nil { + s.closeSend(err) + } + return n, err +} + +func (s *stateTrackingStream) CancelRead(e quic.StreamErrorCode) { + s.closeReceive(&quic.StreamError{StreamID: s.Stream.StreamID(), ErrorCode: e}) + s.Stream.CancelRead(e) +} + +func (s *stateTrackingStream) Read(b []byte) (int, error) { + n, err := s.Stream.Read(b) + if err != nil { + s.closeReceive(err) + } + return n, err +} diff --git a/http3/state_tracking_stream_test.go b/http3/state_tracking_stream_test.go new file mode 100644 index 00000000..7526467a --- /dev/null +++ b/http3/state_tracking_stream_test.go @@ -0,0 +1,90 @@ +package http3 + +import ( + "bytes" + "errors" + "io" + + "github.com/quic-go/quic-go" + mockquic "github.com/quic-go/quic-go/internal/mocks/quic" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "go.uber.org/mock/gomock" +) + +type stateTransition struct { + state streamState + err error +} + +var _ = Describe("State Tracking Stream", func() { + var ( + qstr *mockquic.MockStream + str *stateTrackingStream + states []stateTransition + ) + + BeforeEach(func() { + states = nil + qstr = mockquic.NewMockStream(mockCtrl) + qstr.EXPECT().StreamID().AnyTimes() + str = newStateTrackingStream(qstr, func(state streamState, err error) { + states = append(states, stateTransition{state, err}) + }) + }) + + It("recognizes when the receive side is closed", func() { + buf := bytes.NewBuffer([]byte("foobar")) + qstr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() + for i := 0; i < 3; i++ { + _, err := str.Read([]byte{0}) + Expect(err).ToNot(HaveOccurred()) + Expect(states).To(BeEmpty()) + } + _, err := io.ReadAll(str) + Expect(err).ToNot(HaveOccurred()) + Expect(states).To(HaveLen(1)) + Expect(states[0].state).To(Equal(streamStateReceiveClosed)) + Expect(states[0].err).To(Equal(io.EOF)) + }) + + It("recognizes read cancellations", func() { + buf := bytes.NewBuffer([]byte("foobar")) + qstr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() + qstr.EXPECT().CancelRead(quic.StreamErrorCode(1337)) + _, err := str.Read(make([]byte, 3)) + Expect(err).ToNot(HaveOccurred()) + Expect(states).To(BeEmpty()) + str.CancelRead(1337) + Expect(states).To(HaveLen(1)) + Expect(states[0].state).To(Equal(streamStateReceiveClosed)) + Expect(states[0].err).To(Equal(&quic.StreamError{ErrorCode: 1337})) + }) + + It("recognizes when the send side is closed", func() { + testErr := errors.New("test error") + qstr.EXPECT().Write([]byte("foo")).Return(3, nil) + qstr.EXPECT().Write([]byte("bar")).Return(0, testErr) + _, err := str.Write([]byte("foo")) + Expect(err).ToNot(HaveOccurred()) + Expect(states).To(BeEmpty()) + _, err = str.Write([]byte("bar")) + Expect(err).To(MatchError(testErr)) + Expect(states).To(HaveLen(1)) + Expect(states[0].state).To(Equal(streamStateSendClosed)) + Expect(states[0].err).To(Equal(testErr)) + }) + + It("recognizes write cancellations", func() { + qstr.EXPECT().Write(gomock.Any()) + qstr.EXPECT().CancelWrite(quic.StreamErrorCode(1337)) + _, err := str.Write([]byte("foobar")) + Expect(err).ToNot(HaveOccurred()) + Expect(states).To(BeEmpty()) + str.CancelWrite(1337) + Expect(states).To(HaveLen(1)) + Expect(states[0].state).To(Equal(streamStateSendClosed)) + Expect(states[0].err).To(Equal(&quic.StreamError{ErrorCode: 1337})) + }) +}) diff --git a/integrationtests/self/http_test.go b/integrationtests/self/http_test.go index cf0cfb6f..73113d2d 100644 --- a/integrationtests/self/http_test.go +++ b/integrationtests/self/http_test.go @@ -6,6 +6,7 @@ import ( "compress/gzip" "context" "crypto/tls" + "encoding/binary" "errors" "fmt" "io" @@ -13,6 +14,7 @@ import ( "net/http" "net/http/httptrace" "net/textproto" + "net/url" "os" "strconv" "sync/atomic" @@ -91,7 +93,7 @@ var _ = Describe("HTTP tests", func() { server = &http3.Server{ Handler: mux, TLSConfig: getTLSConfig(), - QUICConfig: getQuicConfig(&quic.Config{Allow0RTT: true}), + QUICConfig: getQuicConfig(&quic.Config{Allow0RTT: true, EnableDatagrams: true}), } addr, err := net.ResolveUDPAddr("udp", "0.0.0.0:0") @@ -415,9 +417,11 @@ var _ = Describe("HTTP tests", func() { }) It("allows taking over the stream", func() { + handlerCalled := make(chan struct{}) mux.HandleFunc("/httpstreamer", func(w http.ResponseWriter, r *http.Request) { defer GinkgoRecover() - w.WriteHeader(200) + close(handlerCalled) + w.WriteHeader(http.StatusOK) w.(http.Flusher).Flush() str := r.Body.(http3.HTTPStreamer).HTTPStream() @@ -449,6 +453,9 @@ var _ = Describe("HTTP tests", func() { str, err := rt.OpenRequestStream(context.Background()) Expect(err).ToNot(HaveOccurred()) Expect(str.SendRequestHeader(req)).To(Succeed()) + // make sure the request is received (and not stuck in some buffer, for example) + Eventually(handlerCalled).Should(BeClosed()) + rsp, err := str.ReadResponse() Expect(err).ToNot(HaveOccurred()) Expect(rsp.StatusCode).To(Equal(200)) @@ -715,6 +722,88 @@ var _ = Describe("HTTP tests", func() { Expect(cnt).To(Equal(0)) }) + Context("HTTP datagrams", func() { + It("sends an receives HTTP datagrams", func() { + errChan := make(chan error, 1) + const num = 5 + datagramChan := make(chan struct{}, num) + mux.HandleFunc("/datagrams", func(w http.ResponseWriter, r *http.Request) { + defer GinkgoRecover() + Expect(r.Method).To(Equal(http.MethodConnect)) + conn := w.(http3.Hijacker).Connection() + Eventually(conn.ReceivedSettings()).Should(BeClosed()) + Expect(conn.Settings().EnableDatagram).To(BeTrue()) + w.WriteHeader(http.StatusOK) + w.(http.Flusher).Flush() + + str := r.Body.(http3.HTTPStreamer).HTTPStream() + go str.Read([]byte{0}) // need to continue reading from stream to observe state transitions + + for { + if _, err := str.ReceiveDatagram(context.Background()); err != nil { + errChan <- err + return + } + datagramChan <- struct{}{} + } + }) + + tlsConf := getTLSClientConfigWithoutServerName() + tlsConf.NextProtos = []string{http3.NextProtoH3} + conn, err := quic.DialAddr( + context.Background(), + fmt.Sprintf("localhost:%d", port), + tlsConf, + getQuicConfig(&quic.Config{EnableDatagrams: true}), + ) + Expect(err).ToNot(HaveOccurred()) + defer conn.CloseWithError(0, "") + rt := &http3.SingleDestinationRoundTripper{ + Connection: conn, + EnableDatagrams: true, + } + str, err := rt.OpenRequestStream(context.Background()) + Expect(err).ToNot(HaveOccurred()) + u, err := url.Parse(fmt.Sprintf("https://localhost:%d/datagrams", port)) + Expect(err).ToNot(HaveOccurred()) + req := &http.Request{ + Method: http.MethodConnect, + Proto: "datagrams", + Host: u.Host, + URL: u, + } + Expect(str.SendRequestHeader(req)).To(Succeed()) + + rsp, err := str.ReadResponse() + Expect(err).ToNot(HaveOccurred()) + Expect(rsp.StatusCode).To(Equal(http.StatusOK)) + + for i := 0; i < num; i++ { + b := make([]byte, 8) + binary.BigEndian.PutUint64(b, uint64(i)) + Expect(str.SendDatagram(bytes.Repeat(b, 100))).To(Succeed()) + } + var count int + loop: + for { + select { + case <-datagramChan: + count++ + if count >= num*4/5 { + break loop + } + case err := <-errChan: + Fail(fmt.Sprintf("receiving datagrams failed: %s", err)) + } + } + str.CancelWrite(42) + + var resetErr error + Eventually(errChan).Should(Receive(&resetErr)) + Expect(resetErr.(*quic.StreamError).ErrorCode).To(BeEquivalentTo(42)) + }) + }) + Context("0-RTT", func() { runCountingProxy := func(serverPort int, rtt time.Duration) (*quicproxy.QuicProxy, *atomic.Uint32) { var num0RTTPackets atomic.Uint32