From e5693d0ad7ddc8a3971808379773fd4940260cc0 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Wed, 9 Oct 2024 15:24:13 -0500 Subject: [PATCH] http3: immediately close all connections on Server.Close (#4689) * http3: immediately close all connections on Server.Close * http3: document connection closing when using ServeQUICConn --- http3/server.go | 19 +++++++++++---- http3/server_test.go | 38 +++++++++++++++--------------- integrationtests/self/http_test.go | 13 ++++++++++ 3 files changed, 46 insertions(+), 24 deletions(-) diff --git a/http3/server.go b/http3/server.go index 9f285b6e7..dd4970bff 100644 --- a/http3/server.go +++ b/http3/server.go @@ -266,8 +266,10 @@ func (s *Server) Serve(conn net.PacketConn) error { } // ServeQUICConn serves a single QUIC connection. +// It is the caller's responsibility to close the connection. +// Specifically, closing the server does not close the connection. func (s *Server) ServeQUICConn(conn quic.Connection) error { - return s.handleConn(conn) + return s.handleConn(context.Background(), conn) } // ServeListener serves an existing QUIC listener. @@ -288,16 +290,19 @@ func (s *Server) ServeListener(ln QUICEarlyListener) error { } func (s *Server) serveListener(ln QUICEarlyListener) error { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + for { conn, err := ln.Accept(context.Background()) - if err == quic.ErrServerClosed { + if errors.Is(err, quic.ErrServerClosed) { return http.ErrServerClosed } if err != nil { return err } go func() { - if err := s.handleConn(conn); err != nil { + if err := s.handleConn(ctx, conn); err != nil { if s.Logger != nil { s.Logger.Debug("handling connection failed", "error", err) } @@ -453,7 +458,7 @@ func (s *Server) removeListener(l *QUICEarlyListener) { s.generateAltSvcHeader() } -func (s *Server) handleConn(conn quic.Connection) error { +func (s *Server) handleConn(serverCtx context.Context, conn quic.Connection) error { // send a SETTINGS frame str, err := conn.OpenUniStream() if err != nil { @@ -492,8 +497,12 @@ func (s *Server) handleConn(conn quic.Connection) error { // Process all requests immediately. // It's the client's responsibility to decide which requests are eligible for 0-RTT. for { - str, datagrams, err := hconn.acceptStream(context.Background()) + str, datagrams, err := hconn.acceptStream(serverCtx) if err != nil { + // close the connection if the server was closed + if errors.Is(err, context.Canceled) { + conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeNoError), "") + } var appErr *quic.ApplicationError if errors.As(err, &appErr) && appErr.ErrorCode == quic.ApplicationErrorCode(ErrCodeNoError) { return nil diff --git a/http3/server_test.go b/http3/server_test.go index f7a4c9ce3..848fe700c 100644 --- a/http3/server_test.go +++ b/http3/server_test.go @@ -356,7 +356,7 @@ var _ = Describe("Server", func() { }) ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, id) conn.EXPECT().Context().Return(ctx).AnyTimes() - s.handleConn(conn) + s.handleConn(context.Background(), conn) Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError }) @@ -384,7 +384,7 @@ var _ = Describe("Server", func() { }) ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1234)) conn.EXPECT().Context().Return(ctx).AnyTimes() - s.handleConn(conn) + s.handleConn(context.Background(), conn) Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError }) @@ -412,7 +412,7 @@ var _ = Describe("Server", func() { }) ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1234)) conn.EXPECT().Context().Return(ctx).AnyTimes() - s.handleConn(conn) + s.handleConn(context.Background(), conn) Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError }) @@ -440,7 +440,7 @@ var _ = Describe("Server", func() { }) ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1234)) conn.EXPECT().Context().Return(ctx).AnyTimes() - s.handleConn(conn) + s.handleConn(context.Background(), conn) Eventually(done).Should(BeClosed()) time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError }) @@ -485,7 +485,7 @@ var _ = Describe("Server", func() { }) ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, id) conn.EXPECT().Context().Return(ctx).AnyTimes() - s.handleConn(conn) + s.handleConn(context.Background(), conn) Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54))) time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError }) @@ -510,7 +510,7 @@ var _ = Describe("Server", func() { }) ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1234)) conn.EXPECT().Context().Return(ctx).AnyTimes() - s.handleConn(conn) + s.handleConn(context.Background(), conn) Eventually(done).Should(BeClosed()) time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError }) @@ -537,7 +537,7 @@ var _ = Describe("Server", func() { }) ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1234)) conn.EXPECT().Context().Return(ctx).AnyTimes() - s.handleConn(conn) + s.handleConn(context.Background(), conn) Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54))) time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError }) @@ -585,7 +585,7 @@ var _ = Describe("Server", func() { str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeNoError)) str.EXPECT().Close().Do(func() error { close(done); return nil }) - s.handleConn(conn) + s.handleConn(context.Background(), conn) Eventually(done).Should(BeClosed()) hfs := decodeHeader(responseBuf) Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"})) @@ -607,7 +607,7 @@ var _ = Describe("Server", func() { var buf bytes.Buffer str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes() - s.handleConn(conn) + s.handleConn(context.Background(), conn) Eventually(handlerCalled).Should(BeClosed()) // The buffer is expected to contain: @@ -643,7 +643,7 @@ var _ = Describe("Server", func() { str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeFrameError)) str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeFrameError)).Do(func(quic.StreamErrorCode) { close(done) }) - s.handleConn(conn) + s.handleConn(context.Background(), conn) Eventually(done).Should(BeClosed()) }) @@ -659,7 +659,7 @@ var _ = Describe("Server", func() { str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeRequestIncomplete)) str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestIncomplete)).Do(func(quic.StreamErrorCode) { close(done) }) - s.handleConn(conn) + s.handleConn(context.Background(), conn) Consistently(handlerCalled).ShouldNot(BeClosed()) }) @@ -680,7 +680,7 @@ var _ = Describe("Server", func() { close(done) return nil }) - s.handleConn(conn) + s.handleConn(context.Background(), conn) Eventually(done).Should(BeClosed()) }) @@ -703,7 +703,7 @@ var _ = Describe("Server", func() { str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeFrameError)) str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeFrameError)).Do(func(quic.StreamErrorCode) { close(done) }) - s.handleConn(conn) + s.handleConn(context.Background(), conn) Eventually(done).Should(BeClosed()) }) }) @@ -1068,10 +1068,10 @@ var _ = Describe("Server", func() { }) It("serves a listener", func() { - var called int32 + var called atomic.Bool ln := newMockAddrListener(":443") quicListen = func(conn net.PacketConn, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) { - atomic.StoreInt32(&called, 1) + called.Store(true) return ln, nil } @@ -1090,7 +1090,7 @@ var _ = Describe("Server", func() { s.ServeListener(ln) }() - Consistently(func() int32 { return atomic.LoadInt32(&called) }).Should(Equal(int32(0))) + Consistently(called.Load).Should(BeFalse()) Consistently(done).ShouldNot(BeClosed()) ln.EXPECT().Close().Do(func() error { close(stopAccept); return nil }) Expect(s.Close()).To(Succeed()) @@ -1098,14 +1098,14 @@ var _ = Describe("Server", func() { }) It("serves two listeners", func() { - var called int32 + var called atomic.Bool ln1 := newMockAddrListener(":443") ln2 := newMockAddrListener(":8443") lns := make(chan QUICEarlyListener, 2) lns <- ln1 lns <- ln2 quicListen = func(c net.PacketConn, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) { - atomic.StoreInt32(&called, 1) + called.Store(true) return <-lns, nil } @@ -1137,7 +1137,7 @@ var _ = Describe("Server", func() { s.ServeListener(ln2) }() - Consistently(func() int32 { return atomic.LoadInt32(&called) }).Should(Equal(int32(0))) + Consistently(called.Load).Should(BeFalse()) Consistently(done1).ShouldNot(BeClosed()) Expect(done2).ToNot(BeClosed()) ln1.EXPECT().Close().Do(func() error { close(stopAccept1); return nil }) diff --git a/integrationtests/self/http_test.go b/integrationtests/self/http_test.go index 305fa6039..69309a9c0 100644 --- a/integrationtests/self/http_test.go +++ b/integrationtests/self/http_test.go @@ -1068,4 +1068,17 @@ var _ = Describe("HTTP tests", func() { "Unannounced": {"Surprise!"}, }))) }) + + It("aborts requests on shutdown", func() { + mux.HandleFunc("/shutdown", func(w http.ResponseWriter, r *http.Request) { + defer GinkgoRecover() + Expect(server.Close()).To(Succeed()) + }) + + _, err := client.Get(fmt.Sprintf("https://localhost:%d/shutdown", port)) + Expect(err).To(HaveOccurred()) + var appErr *http3.Error + Expect(errors.As(err, &appErr)).To(BeTrue()) + Expect(appErr.ErrorCode).To(Equal(http3.ErrCodeNoError)) + }) })