From eaa879f32f141d455519c8666736b2ede15307e6 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Fri, 11 Oct 2024 23:47:17 -0500 Subject: [PATCH] http3: send GOAWAY when server is shutting down (#4691) * send goaway when server is shutting down * http3: track next stream ID instead of last stream ID for GOAWAYs * refactor the graceful shutdown integration tests * http3: improve GOAWAY frame parsing tests * http3: simplify server graceful shutdown logic * http3: simplify parsing of GOAWAY frames * http3: clean up initialization of server contexts * http3: fix race condition in graceful shutdown logic --------- Co-authored-by: WeidiDeng --- http3/frames.go | 27 ++++- http3/frames_test.go | 43 +++++++ http3/server.go | 155 +++++++++++++++++++++----- http3/server_test.go | 30 ++--- integrationtests/self/hotswap_test.go | 1 - integrationtests/self/http_test.go | 82 +++++++++++++- 6 files changed, 295 insertions(+), 43 deletions(-) diff --git a/http3/frames.go b/http3/frames.go index 66cba68ca..b54afb313 100644 --- a/http3/frames.go +++ b/http3/frames.go @@ -66,7 +66,8 @@ func (p *frameParser) ParseNext() (frame, error) { return parseSettingsFrame(p.r, l) case 0x3: // CANCEL_PUSH case 0x5: // PUSH_PROMISE - case 0x7: // GOAWAY + case 0x7: + return parseGoAwayFrame(qr, l) case 0xd: // MAX_PUSH_ID case 0x2, 0x6, 0x8, 0x9: p.conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "") @@ -194,3 +195,27 @@ func (f *settingsFrame) Append(b []byte) []byte { } return b } + +type goAwayFrame struct { + StreamID quic.StreamID +} + +func parseGoAwayFrame(r io.ByteReader, l uint64) (*goAwayFrame, error) { + frame := &goAwayFrame{} + cbr := countingByteReader{ByteReader: r} + id, err := quicvarint.Read(&cbr) + if err != nil { + return nil, err + } + if cbr.Read != int(l) { + return nil, errors.New("GOAWAY frame: inconsistent length") + } + frame.StreamID = quic.StreamID(id) + return frame, nil +} + +func (f *goAwayFrame) Append(b []byte) []byte { + b = 
quicvarint.Append(b, 0x7) + b = quicvarint.Append(b, uint64(quicvarint.Len(uint64(f.StreamID)))) + return quicvarint.Append(b, uint64(f.StreamID)) +} diff --git a/http3/frames_test.go b/http3/frames_test.go index fc1984963..418876ca7 100644 --- a/http3/frames_test.go +++ b/http3/frames_test.go @@ -252,6 +252,49 @@ var _ = Describe("Frames", func() { }) }) + Context("GOAWAY frames", func() { + It("parses", func() { + data := quicvarint.Append(nil, 0x7) // type byte + data = quicvarint.Append(data, uint64(quicvarint.Len(100))) + data = quicvarint.Append(data, 100) + fp := frameParser{r: bytes.NewReader(data)} + frame, err := fp.ParseNext() + Expect(err).ToNot(HaveOccurred()) + Expect(frame).To(BeAssignableToTypeOf(&goAwayFrame{})) + Expect(frame.(*goAwayFrame).StreamID).To(Equal(quic.StreamID(100))) + }) + + It("errors on inconsistent lengths", func() { + data := quicvarint.Append(nil, 0x7) // type byte + data = quicvarint.Append(data, uint64(quicvarint.Len(100))+1) + data = quicvarint.Append(data, 100) + fp := frameParser{r: bytes.NewReader(data)} + _, err := fp.ParseNext() + Expect(err).To(MatchError("GOAWAY frame: inconsistent length")) + }) + + It("writes", func() { + data := (&goAwayFrame{StreamID: 200}).Append(nil) + fp := frameParser{r: bytes.NewReader(data)} + frame, err := fp.ParseNext() + Expect(err).ToNot(HaveOccurred()) + Expect(frame).To(BeAssignableToTypeOf(&goAwayFrame{})) + Expect(frame.(*goAwayFrame).StreamID).To(Equal(quic.StreamID(200))) + }) + + It("errors on EOF", func() { + data := (&goAwayFrame{StreamID: 1337}).Append(nil) + fp := frameParser{r: bytes.NewReader(data)} + _, err := fp.ParseNext() + Expect(err).ToNot(HaveOccurred()) + for i := range data { + fp := frameParser{r: bytes.NewReader(data[:i])} + _, err := fp.ParseNext() + Expect(err).To(MatchError(io.EOF)) + } + }) + }) + Context("hijacking", func() { It("reads a frame without hijacking the stream", func() { buf := bytes.NewBuffer(quicvarint.Append(nil, 1337)) diff --git 
a/http3/server.go b/http3/server.go index dd4970bff..6f8f5e47a 100644 --- a/http3/server.go +++ b/http3/server.go @@ -13,6 +13,7 @@ import ( "strconv" "strings" "sync" + "sync/atomic" "time" "github.com/quic-go/quic-go" @@ -45,6 +46,8 @@ const ( streamTypeQPACKDecoderStream = 3 ) +const goawayTimeout = 5 * time.Second + // A QUICEarlyListener listens for incoming QUIC connections. type QUICEarlyListener interface { Accept(context.Context) (quic.EarlyConnection, error) @@ -213,7 +216,13 @@ type Server struct { mutex sync.RWMutex listeners map[*QUICEarlyListener]listenerInfo - closed bool + closed bool + closeCtx context.Context // canceled when the server is closed + closeCancel context.CancelFunc // cancels the closeCtx + graceCtx context.Context // canceled when the server is closed or gracefully closed + graceCancel context.CancelFunc // cancels the graceCtx + connCount atomic.Int64 + connHandlingDone chan struct{} altSvcHeader string } @@ -265,11 +274,32 @@ func (s *Server) Serve(conn net.PacketConn) error { return s.serveListener(ln) } +// init initializes the contexts used for shutting down the server. +// It must be called with the mutex held. +func (s *Server) init() { + if s.closeCtx == nil { + s.closeCtx, s.closeCancel = context.WithCancel(context.Background()) + s.graceCtx, s.graceCancel = context.WithCancel(s.closeCtx) + } + s.connHandlingDone = make(chan struct{}, 1) +} + +func (s *Server) decreaseConnCount() { + if s.connCount.Add(-1) == 0 { + close(s.connHandlingDone) + } +} + // ServeQUICConn serves a single QUIC connection. -// It is the caller's responsibility to close the connection. -// Specifically, closing the server does not close the connection. func (s *Server) ServeQUICConn(conn quic.Connection) error { - return s.handleConn(context.Background(), conn) + s.mutex.Lock() + s.init() + s.mutex.Unlock() + + s.connCount.Add(1) + defer s.decreaseConnCount() + + return s.handleConn(conn) } // ServeListener serves an existing QUIC listener. 
@@ -290,19 +320,19 @@ func (s *Server) ServeListener(ln QUICEarlyListener) error { } func (s *Server) serveListener(ln QUICEarlyListener) error { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - for { - conn, err := ln.Accept(context.Background()) - if errors.Is(err, quic.ErrServerClosed) { + conn, err := ln.Accept(s.graceCtx) + // server closed + if errors.Is(err, quic.ErrServerClosed) || s.graceCtx.Err() != nil { return http.ErrServerClosed } if err != nil { return err } + s.connCount.Add(1) go func() { - if err := s.handleConn(ctx, conn); err != nil { + defer s.decreaseConnCount() + if err := s.handleConn(conn); err != nil { if s.Logger != nil { s.Logger.Debug("handling connection failed", "error", err) } @@ -435,6 +465,7 @@ func (s *Server) addListener(l *QUICEarlyListener) error { if s.listeners == nil { s.listeners = make(map[*QUICEarlyListener]listenerInfo) } + s.init() laddr := (*l).Addr() if port, err := extractPort(laddr.String()); err == nil { @@ -458,9 +489,12 @@ func (s *Server) removeListener(l *QUICEarlyListener) { s.generateAltSvcHeader() } -func (s *Server) handleConn(serverCtx context.Context, conn quic.Connection) error { - // send a SETTINGS frame - str, err := conn.OpenUniStream() +// handleConn handles the HTTP/3 exchange on a QUIC connection. +// It blocks until all HTTP handlers for all streams have returned. 
+func (s *Server) handleConn(conn quic.Connection) error { + // open the control stream and send a SETTINGS frame, it's also used to send a GOAWAY frame later + // when the server is gracefully closed + ctrlStr, err := conn.OpenUniStream() if err != nil { return fmt.Errorf("opening the control stream failed: %w", err) } @@ -471,7 +505,7 @@ func (s *Server) handleConn(serverCtx context.Context, conn quic.Connection) err ExtendedConnect: true, Other: s.AdditionalSettings, }).Append(b) - str.Write(b) + ctrlStr.Write(b) ctx := conn.Context() ctx = context.WithValue(ctx, ServerContextKey, s) @@ -494,23 +528,58 @@ func (s *Server) handleConn(serverCtx context.Context, conn quic.Connection) err ) go hconn.HandleUnidirectionalStreams(s.UniStreamHijacker) + var nextStreamID quic.StreamID + var wg sync.WaitGroup + var handleErr error // Process all requests immediately. // It's the client's responsibility to decide which requests are eligible for 0-RTT. for { - str, datagrams, err := hconn.acceptStream(serverCtx) + str, datagrams, err := hconn.acceptStream(s.graceCtx) if err != nil { - // close the connection if the server was closed - if errors.Is(err, context.Canceled) { + // server (not gracefully) closed, close the connection immediately + if s.closeCtx.Err() != nil { conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeNoError), "") + handleErr = http.ErrServerClosed + break } + + // gracefully closed, send GOAWAY frame and wait for requests to complete or grace period to end + // new requests will be rejected and shouldn't be sent + if s.graceCtx.Err() != nil { + b = (&goAwayFrame{StreamID: nextStreamID}).Append(b[:0]) + // set a deadline to send the GOAWAY frame + ctrlStr.SetWriteDeadline(time.Now().Add(goawayTimeout)) + ctrlStr.Write(b) + + select { + case <-hconn.Context().Done(): + // we expect the client to eventually close the connection after receiving the GOAWAY + case <-s.closeCtx.Done(): + // close the connection after graceful period + 
conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeNoError), "") + } + handleErr = http.ErrServerClosed + break + } + var appErr *quic.ApplicationError - if errors.As(err, &appErr) && appErr.ErrorCode == quic.ApplicationErrorCode(ErrCodeNoError) { - return nil + if !errors.As(err, &appErr) || appErr.ErrorCode != quic.ApplicationErrorCode(ErrCodeNoError) { + handleErr = fmt.Errorf("accepting stream failed: %w", err) } - return fmt.Errorf("accepting stream failed: %w", err) + break } - go s.handleRequest(hconn, str, datagrams, hconn.decoder) + + nextStreamID = str.StreamID() + 4 + wg.Add(1) + go func() { + // handleRequest will return once the request has been handled, + // or the underlying connection is closed + defer wg.Done() + s.handleRequest(hconn, str, datagrams, hconn.decoder) + }() } + wg.Wait() + return handleErr } func (s *Server) maxHeaderBytes() uint64 { @@ -652,11 +721,17 @@ func (s *Server) handleRequest(conn *connection, str quic.Stream, datagrams *dat // Close the server immediately, aborting requests and sending CONNECTION_CLOSE frames to connected clients. // Close in combination with ListenAndServe() (instead of Serve()) may race if it is called before a UDP socket is established. +// It is the caller's responsibility to close any connection passed to ServeQUICConn. func (s *Server) Close() error { s.mutex.Lock() defer s.mutex.Unlock() s.closed = true + // server is never used + if s.closeCtx == nil { + return nil + } + s.closeCancel() var err error for ln := range s.listeners { @@ -664,14 +739,44 @@ func (s *Server) Close() error { err = cerr } } + if s.connCount.Load() == 0 { + return err + } + // wait for all connections to be closed + <-s.connHandlingDone return err } -// CloseGracefully shuts down the server gracefully. The server sends a GOAWAY frame first, then waits for either timeout to trigger, or for all running requests to complete. +// CloseGracefully shuts down the server gracefully. 
+// The server sends a GOAWAY frame first, then waits for either the context to be canceled, or for all running requests to complete. // CloseGracefully in combination with ListenAndServe() (instead of Serve()) may race if it is called before a UDP socket is established. -func (s *Server) CloseGracefully(timeout time.Duration) error { - // TODO: implement - return nil +func (s *Server) CloseGracefully(ctx context.Context) error { + s.mutex.Lock() + s.closed = true + // server is never used + if s.closeCtx == nil { + s.mutex.Unlock() + return nil + } + s.graceCancel() + s.mutex.Unlock() + + if s.connCount.Load() == 0 { + return nil + } + select { + case <-s.connHandlingDone: // all connections were closed + // When receiving a GOAWAY frame, HTTP/3 clients are expected to close the connection + // once all requests were successfully handled... + return s.Close() + case <-ctx.Done(): + // ... however, clients handling long-lived requests (and misbehaving clients), + // might not do so before the context is cancelled. + // In this case, we close the server, which closes all existing connections + // (except those passed to ServeQUICConn). 
+ _ = s.Close() + return ctx.Err() + } } // ErrNoAltSvcPort is the error returned by SetQUICHeaders when no port was found diff --git a/http3/server_test.go b/http3/server_test.go index 848fe700c..49206d15e 100644 --- a/http3/server_test.go +++ b/http3/server_test.go @@ -75,6 +75,8 @@ var _ = Describe("Server", func() { return context.WithValue(ctx, testConnContextKey("test"), c) }, } + s.closeCtx, s.closeCancel = context.WithCancel(context.Background()) + s.graceCtx, s.graceCancel = context.WithCancel(s.closeCtx) origQuicListenAddr = quicListenAddr }) @@ -356,7 +358,7 @@ var _ = Describe("Server", func() { }) ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, id) conn.EXPECT().Context().Return(ctx).AnyTimes() - s.handleConn(context.Background(), conn) + s.handleConn(conn) Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError }) @@ -384,7 +386,7 @@ var _ = Describe("Server", func() { }) ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1234)) conn.EXPECT().Context().Return(ctx).AnyTimes() - s.handleConn(context.Background(), conn) + s.handleConn(conn) Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError }) @@ -412,7 +414,7 @@ var _ = Describe("Server", func() { }) ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1234)) conn.EXPECT().Context().Return(ctx).AnyTimes() - s.handleConn(context.Background(), conn) + s.handleConn(conn) Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError }) @@ -440,7 +442,7 @@ var _ = Describe("Server", func() { }) ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, 
quic.ConnectionTracingID(1234)) conn.EXPECT().Context().Return(ctx).AnyTimes() - s.handleConn(context.Background(), conn) + s.handleConn(conn) Eventually(done).Should(BeClosed()) time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError }) @@ -485,7 +487,7 @@ var _ = Describe("Server", func() { }) ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, id) conn.EXPECT().Context().Return(ctx).AnyTimes() - s.handleConn(context.Background(), conn) + s.handleConn(conn) Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54))) time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError }) @@ -510,7 +512,7 @@ var _ = Describe("Server", func() { }) ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1234)) conn.EXPECT().Context().Return(ctx).AnyTimes() - s.handleConn(context.Background(), conn) + s.handleConn(conn) Eventually(done).Should(BeClosed()) time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError }) @@ -537,7 +539,7 @@ var _ = Describe("Server", func() { }) ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1234)) conn.EXPECT().Context().Return(ctx).AnyTimes() - s.handleConn(context.Background(), conn) + s.handleConn(conn) Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54))) time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError }) @@ -585,7 +587,7 @@ var _ = Describe("Server", func() { str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeNoError)) str.EXPECT().Close().Do(func() error { close(done); return nil }) - s.handleConn(context.Background(), conn) + s.handleConn(conn) Eventually(done).Should(BeClosed()) hfs := decodeHeader(responseBuf) Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"})) @@ -607,7 +609,7 @@ var _ = Describe("Server", func() { var buf 
bytes.Buffer str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes() - s.handleConn(context.Background(), conn) + s.handleConn(conn) Eventually(handlerCalled).Should(BeClosed()) // The buffer is expected to contain: @@ -643,7 +645,7 @@ var _ = Describe("Server", func() { str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeFrameError)) str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeFrameError)).Do(func(quic.StreamErrorCode) { close(done) }) - s.handleConn(context.Background(), conn) + s.handleConn(conn) Eventually(done).Should(BeClosed()) }) @@ -659,7 +661,7 @@ var _ = Describe("Server", func() { str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeRequestIncomplete)) str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestIncomplete)).Do(func(quic.StreamErrorCode) { close(done) }) - s.handleConn(context.Background(), conn) + s.handleConn(conn) Consistently(handlerCalled).ShouldNot(BeClosed()) }) @@ -680,7 +682,7 @@ var _ = Describe("Server", func() { close(done) return nil }) - s.handleConn(context.Background(), conn) + s.handleConn(conn) Eventually(done).Should(BeClosed()) }) @@ -703,7 +705,7 @@ var _ = Describe("Server", func() { str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeFrameError)) str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeFrameError)).Do(func(quic.StreamErrorCode) { close(done) }) - s.handleConn(context.Background(), conn) + s.handleConn(conn) Eventually(done).Should(BeClosed()) }) }) @@ -1198,7 +1200,7 @@ var _ = Describe("Server", func() { }) It("closes gracefully", func() { - Expect(s.CloseGracefully(0)).To(Succeed()) + Expect(s.CloseGracefully(context.Background())).To(Succeed()) }) It("errors when listening fails", func() { diff --git a/integrationtests/self/hotswap_test.go b/integrationtests/self/hotswap_test.go index 4091cda6a..ef33a9327 100644 --- a/integrationtests/self/hotswap_test.go +++ b/integrationtests/self/hotswap_test.go @@ -46,7 +46,6 @@ type fakeClosingListener struct { } func (ln 
*fakeClosingListener) Accept(ctx context.Context) (quic.EarlyConnection, error) { - Expect(ctx).To(Equal(context.Background())) return ln.listenerWrapper.Accept(ln.ctx) } diff --git a/integrationtests/self/http_test.go b/integrationtests/self/http_test.go index 69309a9c0..65d1c847b 100644 --- a/integrationtests/self/http_test.go +++ b/integrationtests/self/http_test.go @@ -1071,8 +1071,11 @@ var _ = Describe("HTTP tests", func() { It("aborts requests on shutdown", func() { mux.HandleFunc("/shutdown", func(w http.ResponseWriter, r *http.Request) { - defer GinkgoRecover() - Expect(server.Close()).To(Succeed()) + go func() { + defer GinkgoRecover() + Expect(server.Close()).To(Succeed()) + }() + time.Sleep(scaleDuration(50 * time.Millisecond)) // make sure the server started shutting down }) _, err := client.Get(fmt.Sprintf("https://localhost:%d/shutdown", port)) @@ -1081,4 +1084,79 @@ var _ = Describe("HTTP tests", func() { Expect(errors.As(err, &appErr)).To(BeTrue()) Expect(appErr.ErrorCode).To(Equal(http3.ErrCodeNoError)) }) + + It("allows existing requests to complete on graceful shutdown", func() { + delay := scaleDuration(100 * time.Millisecond) + done := make(chan struct{}) + + mux.HandleFunc("/shutdown", func(w http.ResponseWriter, r *http.Request) { + go func() { + defer GinkgoRecover() + defer close(done) + Expect(server.CloseGracefully(context.Background())).To(Succeed()) + fmt.Println("close gracefully done") + }() + time.Sleep(delay) + w.Write([]byte("shutdown")) + }) + + ctx, cancel := context.WithTimeout(context.Background(), 3*delay) + defer cancel() + req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("https://localhost:%d/shutdown", port), nil) + Expect(err).ToNot(HaveOccurred()) + resp, err := client.Do(req) + Expect(err).ToNot(HaveOccurred()) + Expect(resp.StatusCode).To(Equal(http.StatusOK)) + body, err := io.ReadAll(resp.Body) + Expect(err).ToNot(HaveOccurred()) + Expect(body).To(Equal([]byte("shutdown"))) + // manually close 
the client, since we don't support handling GOAWAY frames on the client side yet + client.Transport.(*http3.RoundTripper).Close() + + // make sure that CloseGracefully returned + Eventually(done).Should(BeClosed()) + }) + + It("aborts long-lived requests on graceful shutdown", func() { + delay := scaleDuration(100 * time.Millisecond) + shutdownDone := make(chan struct{}) + requestChan := make(chan time.Duration, 1) + + mux.HandleFunc("/shutdown", func(w http.ResponseWriter, r *http.Request) { + start := time.Now() + w.WriteHeader(http.StatusOK) + w.(http.Flusher).Flush() + go func() { + defer GinkgoRecover() + ctx, cancel := context.WithTimeout(context.Background(), delay) + defer cancel() + defer close(shutdownDone) + Expect(server.CloseGracefully(ctx)).To(MatchError(context.DeadlineExceeded)) + }() + for t := range time.NewTicker(delay / 10).C { + if _, err := w.Write([]byte(t.String())); err != nil { + requestChan <- time.Since(start) + return + } + } + }) + + start := time.Now() + resp, err := client.Get(fmt.Sprintf("https://localhost:%d/shutdown", port)) + Expect(err).ToNot(HaveOccurred()) + Expect(resp.StatusCode).To(Equal(http.StatusOK)) + _, err = io.Copy(io.Discard, resp.Body) + Expect(err).To(HaveOccurred()) + var h3Err *http3.Error + Expect(errors.As(err, &h3Err)).To(BeTrue()) + Expect(h3Err.ErrorCode).To(Equal(http3.ErrCodeNoError)) + took := time.Since(start) + Expect(took).To(BeNumerically("~", delay, delay/2)) + var requestDuration time.Duration + Eventually(requestChan).Should(Receive(&requestDuration)) + Expect(requestDuration).To(BeNumerically("~", delay, delay/2)) + + // make sure that CloseGracefully returned + Eventually(shutdownDone).Should(BeClosed()) + }) })