diff --git a/http3/client.go b/http3/client.go index 23ac088f..e60acffe 100644 --- a/http3/client.go +++ b/http3/client.go @@ -88,6 +88,7 @@ func (c *SingleDestinationRoundTripper) init() { c.EnableDatagrams, protocol.PerspectiveClient, c.Logger, + 0, ) // send the SETTINGs frame, using 0-RTT data, if possible go func() { diff --git a/http3/conn.go b/http3/conn.go index df7fb282..0fd9412f 100644 --- a/http3/conn.go +++ b/http3/conn.go @@ -7,6 +7,7 @@ import ( "net" "sync" "sync/atomic" + "time" "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/internal/protocol" @@ -51,6 +52,9 @@ type connection struct { settings *Settings receivedSettings chan struct{} + + idleTimeout time.Duration + idleTimer *time.Timer } func newConnection( @@ -59,17 +63,27 @@ func newConnection( enableDatagrams bool, perspective protocol.Perspective, logger *slog.Logger, + idleTimeout time.Duration, ) *connection { - return &connection{ + c := &connection{ ctx: ctx, Connection: quicConn, perspective: perspective, logger: logger, + idleTimeout: idleTimeout, enableDatagrams: enableDatagrams, decoder: qpack.NewDecoder(func(hf qpack.HeaderField) {}), receivedSettings: make(chan struct{}), streams: make(map[protocol.StreamID]*datagrammer), } + if idleTimeout > 0 { + c.idleTimer = time.AfterFunc(idleTimeout, c.onIdleTimer) + } + return c +} + +func (c *connection) onIdleTimer() { + c.CloseWithError(quic.ApplicationErrorCode(ErrCodeNoError), "idle timeout") } func (c *connection) clearStream(id quic.StreamID) { @@ -77,6 +91,9 @@ func (c *connection) clearStream(id quic.StreamID) { defer c.streamMx.Unlock() delete(c.streams, id) + if c.idleTimeout > 0 && len(c.streams) == 0 { + c.idleTimer.Reset(c.idleTimeout) + } } func (c *connection) openRequestStream( @@ -109,12 +126,24 @@ func (c *connection) acceptStream(ctx context.Context) (quic.Stream, *datagramme strID := str.StreamID() c.streamMx.Lock() c.streams[strID] = datagrams + if c.idleTimeout > 0 { + if len(c.streams) == 1 { + c.idleTimer.Stop() + } + } c.streamMx.Unlock() str = newStateTrackingStream(str, c, datagrams) } return str, datagrams, nil } +func (c *connection) CloseWithError(code quic.ApplicationErrorCode, msg string) error { + if c.idleTimer != nil { + c.idleTimer.Stop() + } + return c.Connection.CloseWithError(code, msg) +} + func (c *connection) HandleUnidirectionalStreams(hijack func(StreamType, quic.ConnectionTracingID, quic.ReceiveStream, error) (hijacked bool)) { var ( rcvdControlStr atomic.Bool diff --git a/http3/conn_test.go b/http3/conn_test.go index 4be8b7ae..1b1ff257 100644 --- a/http3/conn_test.go +++ b/http3/conn_test.go @@ -29,6 +29,7 @@ var _ = Describe("Connection", func() { false, protocol.PerspectiveServer, nil, + 0, ) b := quicvarint.Append(nil, streamTypeControlStream) b = (&settingsFrame{ @@ -62,6 +63,7 @@ var _ = Describe("Connection", func() { false, protocol.PerspectiveServer, nil, + 0, ) b := quicvarint.Append(nil, streamTypeControlStream) b = (&settingsFrame{}).Append(b) @@ -104,6 +106,7 @@ var _ = Describe("Connection", func() { false, protocol.PerspectiveClient, nil, + 0, ) buf := bytes.NewBuffer(quicvarint.Append(nil, streamType)) str := mockquic.NewMockStream(mockCtrl) @@ -133,6 +136,7 @@ var _ = Describe("Connection", func() { false, protocol.PerspectiveClient, nil, + 0, ) buf := bytes.NewBuffer(quicvarint.Append(nil, streamType)) str1 := mockquic.NewMockStream(mockCtrl) @@ -169,6 +173,7 @@ var _ = Describe("Connection", func() { false, protocol.PerspectiveServer, nil, + 0, ) buf := bytes.NewBuffer(quicvarint.Append(nil, 0x1337)) str := mockquic.NewMockStream(mockCtrl) @@ -195,6 +200,7 @@ var _ = Describe("Connection", func() { false, protocol.PerspectiveServer, nil, + 0, ) b := quicvarint.Append(nil, streamTypeControlStream) b = (&dataFrame{}).Append(b) @@ -226,6 +232,7 @@ var _ = Describe("Connection", func() { false, protocol.PerspectiveServer, nil, + 0, ) b := quicvarint.Append(nil, streamTypeControlStream) b = (&settingsFrame{}).Append(b) @@ -264,6 +271,7 @@ var _ = Describe("Connection", func() { false, pers.Opposite(), nil, + 0, ) buf := bytes.NewBuffer(quicvarint.Append(nil, streamTypePushStream)) controlStr := mockquic.NewMockStream(mockCtrl) @@ -294,6 +302,7 @@ var _ = Describe("Connection", func() { true, protocol.PerspectiveClient, nil, + 0, ) b := quicvarint.Append(nil, streamTypeControlStream) b = (&settingsFrame{Datagram: true}).Append(b) @@ -333,6 +342,7 @@ var _ = Describe("Connection", func() { true, protocol.PerspectiveClient, nil, + 0, ) b := quicvarint.Append(nil, streamTypeControlStream) b = (&settingsFrame{Datagram: true}).Append(b) diff --git a/http3/http_stream_test.go b/http3/http_stream_test.go index a4c889e2..59f68279 100644 --- a/http3/http_stream_test.go +++ b/http3/http_stream_test.go @@ -43,7 +43,7 @@ var _ = Describe("Stream", func() { errorCbCalled = true return nil }).AnyTimes() - str = newStream(qstr, newConnection(context.Background(), conn, false, protocol.PerspectiveClient, nil), nil) + str = newStream(qstr, newConnection(context.Background(), conn, false, protocol.PerspectiveClient, nil, 0), nil) }) It("reads DATA frames in a single run", func() { @@ -171,7 +171,7 @@ var _ = Describe("Request Stream", func() { requestWriter := newRequestWriter() conn := mockquic.NewMockEarlyConnection(mockCtrl) str = newRequestStream( - newStream(qstr, newConnection(context.Background(), conn, false, protocol.PerspectiveClient, nil), nil), + newStream(qstr, newConnection(context.Background(), conn, false, protocol.PerspectiveClient, nil, 0), nil), requestWriter, make(chan struct{}), qpack.NewDecoder(func(qpack.HeaderField) {}), diff --git a/http3/server.go b/http3/server.go index cf52c174..9e7cd644 100644 --- a/http3/server.go +++ b/http3/server.go @@ -198,6 +198,12 @@ type Server struct { // In that case, the stream type will not be set. UniStreamHijacker func(StreamType, quic.ConnectionTracingID, quic.ReceiveStream, error) (hijacked bool) + // IdleTimeout specifies how long until idle clients connection should be + // closed. Idle refers only to the HTTP/3 layer, activity at the QUIC layer + // like PING frames are not considered. + // If zero or negative, there is no timeout. + IdleTimeout time.Duration + // ConnContext optionally specifies a function that modifies the context used for a new connection c. // The provided ctx has a ServerContextKey value. ConnContext func(ctx context.Context, c quic.Connection) context.Context @@ -479,8 +485,10 @@ func (s *Server) handleConn(conn quic.Connection) error { s.EnableDatagrams, protocol.PerspectiveServer, s.Logger, + s.IdleTimeout, ) go hconn.HandleUnidirectionalStreams(s.UniStreamHijacker) + // Process all requests immediately. // It's the client's responsibility to decide which requests are eligible for 0-RTT. for { diff --git a/http3/server_test.go b/http3/server_test.go index 7e2138d5..f7a4c9ce 100644 --- a/http3/server_test.go +++ b/http3/server_test.go @@ -148,7 +148,7 @@ var _ = Describe("Server", func() { qconn.EXPECT().LocalAddr().AnyTimes() qconn.EXPECT().ConnectionState().Return(quic.ConnectionState{}).AnyTimes() qconn.EXPECT().Context().Return(context.Background()).AnyTimes() - conn = newConnection(context.Background(), qconn, false, protocol.PerspectiveServer, nil) + conn = newConnection(context.Background(), qconn, false, protocol.PerspectiveServer, nil, 0) }) It("calls the HTTP handler function", func() { diff --git a/integrationtests/self/http_test.go b/integrationtests/self/http_test.go index e2cf5fb2..cd97896b 100644 --- a/integrationtests/self/http_test.go +++ b/integrationtests/self/http_test.go @@ -128,6 +128,18 @@ var _ = Describe("HTTP tests", func() { client = &http.Client{Transport: rt} }) + It("closes the connection after idle timeout", func() { + server.IdleTimeout = 100 * time.Millisecond + _, err := client.Get(fmt.Sprintf("https://localhost:%d/hello", port)) + Expect(err).ToNot(HaveOccurred()) + + time.Sleep(150 * time.Millisecond) + + _, err = client.Get(fmt.Sprintf("https://localhost:%d/hello", port)) + Expect(err).ToNot(MatchError("idle timeout")) + server.IdleTimeout = 0 + }) + It("downloads a hello", func() { resp, err := client.Get(fmt.Sprintf("https://localhost:%d/hello", port)) Expect(err).ToNot(HaveOccurred())