http3: immediately close all connections on Server.Close (#4689)

* http3: immediately close all connections on Server.Close

* http3: document connection closing when using ServeQUICConn
This commit is contained in:
Marten Seemann
2024-10-09 15:24:13 -05:00
committed by GitHub
parent b2233591ad
commit e5693d0ad7
3 changed files with 46 additions and 24 deletions

View File

@@ -266,8 +266,10 @@ func (s *Server) Serve(conn net.PacketConn) error {
}
// ServeQUICConn serves a single QUIC connection.
// It is the caller's responsibility to close the connection.
// Specifically, closing the server does not close the connection.
func (s *Server) ServeQUICConn(conn quic.Connection) error {
return s.handleConn(conn)
return s.handleConn(context.Background(), conn)
}
// ServeListener serves an existing QUIC listener.
@@ -288,16 +290,19 @@ func (s *Server) ServeListener(ln QUICEarlyListener) error {
}
func (s *Server) serveListener(ln QUICEarlyListener) error {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
for {
conn, err := ln.Accept(context.Background())
if err == quic.ErrServerClosed {
if errors.Is(err, quic.ErrServerClosed) {
return http.ErrServerClosed
}
if err != nil {
return err
}
go func() {
if err := s.handleConn(conn); err != nil {
if err := s.handleConn(ctx, conn); err != nil {
if s.Logger != nil {
s.Logger.Debug("handling connection failed", "error", err)
}
@@ -453,7 +458,7 @@ func (s *Server) removeListener(l *QUICEarlyListener) {
s.generateAltSvcHeader()
}
func (s *Server) handleConn(conn quic.Connection) error {
func (s *Server) handleConn(serverCtx context.Context, conn quic.Connection) error {
// send a SETTINGS frame
str, err := conn.OpenUniStream()
if err != nil {
@@ -492,8 +497,12 @@ func (s *Server) handleConn(conn quic.Connection) error {
// Process all requests immediately.
// It's the client's responsibility to decide which requests are eligible for 0-RTT.
for {
str, datagrams, err := hconn.acceptStream(context.Background())
str, datagrams, err := hconn.acceptStream(serverCtx)
if err != nil {
// close the connection if the server was closed
if errors.Is(err, context.Canceled) {
conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeNoError), "")
}
var appErr *quic.ApplicationError
if errors.As(err, &appErr) && appErr.ErrorCode == quic.ApplicationErrorCode(ErrCodeNoError) {
return nil

View File

@@ -356,7 +356,7 @@ var _ = Describe("Server", func() {
})
ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, id)
conn.EXPECT().Context().Return(ctx).AnyTimes()
s.handleConn(conn)
s.handleConn(context.Background(), conn)
Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41)))
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
})
@@ -384,7 +384,7 @@ var _ = Describe("Server", func() {
})
ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1234))
conn.EXPECT().Context().Return(ctx).AnyTimes()
s.handleConn(conn)
s.handleConn(context.Background(), conn)
Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41)))
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
})
@@ -412,7 +412,7 @@ var _ = Describe("Server", func() {
})
ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1234))
conn.EXPECT().Context().Return(ctx).AnyTimes()
s.handleConn(conn)
s.handleConn(context.Background(), conn)
Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41)))
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
})
@@ -440,7 +440,7 @@ var _ = Describe("Server", func() {
})
ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1234))
conn.EXPECT().Context().Return(ctx).AnyTimes()
s.handleConn(conn)
s.handleConn(context.Background(), conn)
Eventually(done).Should(BeClosed())
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
})
@@ -485,7 +485,7 @@ var _ = Describe("Server", func() {
})
ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, id)
conn.EXPECT().Context().Return(ctx).AnyTimes()
s.handleConn(conn)
s.handleConn(context.Background(), conn)
Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54)))
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
})
@@ -510,7 +510,7 @@ var _ = Describe("Server", func() {
})
ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1234))
conn.EXPECT().Context().Return(ctx).AnyTimes()
s.handleConn(conn)
s.handleConn(context.Background(), conn)
Eventually(done).Should(BeClosed())
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
})
@@ -537,7 +537,7 @@ var _ = Describe("Server", func() {
})
ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1234))
conn.EXPECT().Context().Return(ctx).AnyTimes()
s.handleConn(conn)
s.handleConn(context.Background(), conn)
Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54)))
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
})
@@ -585,7 +585,7 @@ var _ = Describe("Server", func() {
str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeNoError))
str.EXPECT().Close().Do(func() error { close(done); return nil })
s.handleConn(conn)
s.handleConn(context.Background(), conn)
Eventually(done).Should(BeClosed())
hfs := decodeHeader(responseBuf)
Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"}))
@@ -607,7 +607,7 @@ var _ = Describe("Server", func() {
var buf bytes.Buffer
str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes()
s.handleConn(conn)
s.handleConn(context.Background(), conn)
Eventually(handlerCalled).Should(BeClosed())
// The buffer is expected to contain:
@@ -643,7 +643,7 @@ var _ = Describe("Server", func() {
str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeFrameError))
str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeFrameError)).Do(func(quic.StreamErrorCode) { close(done) })
s.handleConn(conn)
s.handleConn(context.Background(), conn)
Eventually(done).Should(BeClosed())
})
@@ -659,7 +659,7 @@ var _ = Describe("Server", func() {
str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeRequestIncomplete))
str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestIncomplete)).Do(func(quic.StreamErrorCode) { close(done) })
s.handleConn(conn)
s.handleConn(context.Background(), conn)
Consistently(handlerCalled).ShouldNot(BeClosed())
})
@@ -680,7 +680,7 @@ var _ = Describe("Server", func() {
close(done)
return nil
})
s.handleConn(conn)
s.handleConn(context.Background(), conn)
Eventually(done).Should(BeClosed())
})
@@ -703,7 +703,7 @@ var _ = Describe("Server", func() {
str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeFrameError))
str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeFrameError)).Do(func(quic.StreamErrorCode) { close(done) })
s.handleConn(conn)
s.handleConn(context.Background(), conn)
Eventually(done).Should(BeClosed())
})
})
@@ -1068,10 +1068,10 @@ var _ = Describe("Server", func() {
})
It("serves a listener", func() {
var called int32
var called atomic.Bool
ln := newMockAddrListener(":443")
quicListen = func(conn net.PacketConn, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) {
atomic.StoreInt32(&called, 1)
called.Store(true)
return ln, nil
}
@@ -1090,7 +1090,7 @@ var _ = Describe("Server", func() {
s.ServeListener(ln)
}()
Consistently(func() int32 { return atomic.LoadInt32(&called) }).Should(Equal(int32(0)))
Consistently(called.Load).Should(BeFalse())
Consistently(done).ShouldNot(BeClosed())
ln.EXPECT().Close().Do(func() error { close(stopAccept); return nil })
Expect(s.Close()).To(Succeed())
@@ -1098,14 +1098,14 @@ var _ = Describe("Server", func() {
})
It("serves two listeners", func() {
var called int32
var called atomic.Bool
ln1 := newMockAddrListener(":443")
ln2 := newMockAddrListener(":8443")
lns := make(chan QUICEarlyListener, 2)
lns <- ln1
lns <- ln2
quicListen = func(c net.PacketConn, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) {
atomic.StoreInt32(&called, 1)
called.Store(true)
return <-lns, nil
}
@@ -1137,7 +1137,7 @@ var _ = Describe("Server", func() {
s.ServeListener(ln2)
}()
Consistently(func() int32 { return atomic.LoadInt32(&called) }).Should(Equal(int32(0)))
Consistently(called.Load).Should(BeFalse())
Consistently(done1).ShouldNot(BeClosed())
Expect(done2).ToNot(BeClosed())
ln1.EXPECT().Close().Do(func() error { close(stopAccept1); return nil })

View File

@@ -1068,4 +1068,17 @@ var _ = Describe("HTTP tests", func() {
"Unannounced": {"Surprise!"},
})))
})
It("aborts requests on shutdown", func() {
mux.HandleFunc("/shutdown", func(w http.ResponseWriter, r *http.Request) {
defer GinkgoRecover()
Expect(server.Close()).To(Succeed())
})
_, err := client.Get(fmt.Sprintf("https://localhost:%d/shutdown", port))
Expect(err).To(HaveOccurred())
var appErr *http3.Error
Expect(errors.As(err, &appErr)).To(BeTrue())
Expect(appErr.ErrorCode).To(Equal(http3.ErrCodeNoError))
})
})