forked from quic-go/quic-go
http3: immediately close all connections on Server.Close (#4689)
* http3: immediately close all connections on Server.Close * http3: document connection closing when using ServeQUICConn
This commit is contained in:
@@ -266,8 +266,10 @@ func (s *Server) Serve(conn net.PacketConn) error {
|
||||
}
|
||||
|
||||
// ServeQUICConn serves a single QUIC connection.
|
||||
// It is the caller's responsibility to close the connection.
|
||||
// Specifically, closing the server does not close the connection.
|
||||
func (s *Server) ServeQUICConn(conn quic.Connection) error {
|
||||
return s.handleConn(conn)
|
||||
return s.handleConn(context.Background(), conn)
|
||||
}
|
||||
|
||||
// ServeListener serves an existing QUIC listener.
|
||||
@@ -288,16 +290,19 @@ func (s *Server) ServeListener(ln QUICEarlyListener) error {
|
||||
}
|
||||
|
||||
func (s *Server) serveListener(ln QUICEarlyListener) error {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
for {
|
||||
conn, err := ln.Accept(context.Background())
|
||||
if err == quic.ErrServerClosed {
|
||||
if errors.Is(err, quic.ErrServerClosed) {
|
||||
return http.ErrServerClosed
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
go func() {
|
||||
if err := s.handleConn(conn); err != nil {
|
||||
if err := s.handleConn(ctx, conn); err != nil {
|
||||
if s.Logger != nil {
|
||||
s.Logger.Debug("handling connection failed", "error", err)
|
||||
}
|
||||
@@ -453,7 +458,7 @@ func (s *Server) removeListener(l *QUICEarlyListener) {
|
||||
s.generateAltSvcHeader()
|
||||
}
|
||||
|
||||
func (s *Server) handleConn(conn quic.Connection) error {
|
||||
func (s *Server) handleConn(serverCtx context.Context, conn quic.Connection) error {
|
||||
// send a SETTINGS frame
|
||||
str, err := conn.OpenUniStream()
|
||||
if err != nil {
|
||||
@@ -492,8 +497,12 @@ func (s *Server) handleConn(conn quic.Connection) error {
|
||||
// Process all requests immediately.
|
||||
// It's the client's responsibility to decide which requests are eligible for 0-RTT.
|
||||
for {
|
||||
str, datagrams, err := hconn.acceptStream(context.Background())
|
||||
str, datagrams, err := hconn.acceptStream(serverCtx)
|
||||
if err != nil {
|
||||
// close the connection if the server was closed
|
||||
if errors.Is(err, context.Canceled) {
|
||||
conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeNoError), "")
|
||||
}
|
||||
var appErr *quic.ApplicationError
|
||||
if errors.As(err, &appErr) && appErr.ErrorCode == quic.ApplicationErrorCode(ErrCodeNoError) {
|
||||
return nil
|
||||
|
||||
@@ -356,7 +356,7 @@ var _ = Describe("Server", func() {
|
||||
})
|
||||
ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, id)
|
||||
conn.EXPECT().Context().Return(ctx).AnyTimes()
|
||||
s.handleConn(conn)
|
||||
s.handleConn(context.Background(), conn)
|
||||
Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41)))
|
||||
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
|
||||
})
|
||||
@@ -384,7 +384,7 @@ var _ = Describe("Server", func() {
|
||||
})
|
||||
ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1234))
|
||||
conn.EXPECT().Context().Return(ctx).AnyTimes()
|
||||
s.handleConn(conn)
|
||||
s.handleConn(context.Background(), conn)
|
||||
Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41)))
|
||||
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
|
||||
})
|
||||
@@ -412,7 +412,7 @@ var _ = Describe("Server", func() {
|
||||
})
|
||||
ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1234))
|
||||
conn.EXPECT().Context().Return(ctx).AnyTimes()
|
||||
s.handleConn(conn)
|
||||
s.handleConn(context.Background(), conn)
|
||||
Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41)))
|
||||
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
|
||||
})
|
||||
@@ -440,7 +440,7 @@ var _ = Describe("Server", func() {
|
||||
})
|
||||
ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1234))
|
||||
conn.EXPECT().Context().Return(ctx).AnyTimes()
|
||||
s.handleConn(conn)
|
||||
s.handleConn(context.Background(), conn)
|
||||
Eventually(done).Should(BeClosed())
|
||||
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
|
||||
})
|
||||
@@ -485,7 +485,7 @@ var _ = Describe("Server", func() {
|
||||
})
|
||||
ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, id)
|
||||
conn.EXPECT().Context().Return(ctx).AnyTimes()
|
||||
s.handleConn(conn)
|
||||
s.handleConn(context.Background(), conn)
|
||||
Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54)))
|
||||
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
|
||||
})
|
||||
@@ -510,7 +510,7 @@ var _ = Describe("Server", func() {
|
||||
})
|
||||
ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1234))
|
||||
conn.EXPECT().Context().Return(ctx).AnyTimes()
|
||||
s.handleConn(conn)
|
||||
s.handleConn(context.Background(), conn)
|
||||
Eventually(done).Should(BeClosed())
|
||||
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
|
||||
})
|
||||
@@ -537,7 +537,7 @@ var _ = Describe("Server", func() {
|
||||
})
|
||||
ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1234))
|
||||
conn.EXPECT().Context().Return(ctx).AnyTimes()
|
||||
s.handleConn(conn)
|
||||
s.handleConn(context.Background(), conn)
|
||||
Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54)))
|
||||
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
|
||||
})
|
||||
@@ -585,7 +585,7 @@ var _ = Describe("Server", func() {
|
||||
str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeNoError))
|
||||
str.EXPECT().Close().Do(func() error { close(done); return nil })
|
||||
|
||||
s.handleConn(conn)
|
||||
s.handleConn(context.Background(), conn)
|
||||
Eventually(done).Should(BeClosed())
|
||||
hfs := decodeHeader(responseBuf)
|
||||
Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"}))
|
||||
@@ -607,7 +607,7 @@ var _ = Describe("Server", func() {
|
||||
var buf bytes.Buffer
|
||||
str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes()
|
||||
|
||||
s.handleConn(conn)
|
||||
s.handleConn(context.Background(), conn)
|
||||
Eventually(handlerCalled).Should(BeClosed())
|
||||
|
||||
// The buffer is expected to contain:
|
||||
@@ -643,7 +643,7 @@ var _ = Describe("Server", func() {
|
||||
str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeFrameError))
|
||||
str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeFrameError)).Do(func(quic.StreamErrorCode) { close(done) })
|
||||
|
||||
s.handleConn(conn)
|
||||
s.handleConn(context.Background(), conn)
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
@@ -659,7 +659,7 @@ var _ = Describe("Server", func() {
|
||||
str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeRequestIncomplete))
|
||||
str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestIncomplete)).Do(func(quic.StreamErrorCode) { close(done) })
|
||||
|
||||
s.handleConn(conn)
|
||||
s.handleConn(context.Background(), conn)
|
||||
Consistently(handlerCalled).ShouldNot(BeClosed())
|
||||
})
|
||||
|
||||
@@ -680,7 +680,7 @@ var _ = Describe("Server", func() {
|
||||
close(done)
|
||||
return nil
|
||||
})
|
||||
s.handleConn(conn)
|
||||
s.handleConn(context.Background(), conn)
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
@@ -703,7 +703,7 @@ var _ = Describe("Server", func() {
|
||||
str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeFrameError))
|
||||
str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeFrameError)).Do(func(quic.StreamErrorCode) { close(done) })
|
||||
|
||||
s.handleConn(conn)
|
||||
s.handleConn(context.Background(), conn)
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
})
|
||||
@@ -1068,10 +1068,10 @@ var _ = Describe("Server", func() {
|
||||
})
|
||||
|
||||
It("serves a listener", func() {
|
||||
var called int32
|
||||
var called atomic.Bool
|
||||
ln := newMockAddrListener(":443")
|
||||
quicListen = func(conn net.PacketConn, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) {
|
||||
atomic.StoreInt32(&called, 1)
|
||||
called.Store(true)
|
||||
return ln, nil
|
||||
}
|
||||
|
||||
@@ -1090,7 +1090,7 @@ var _ = Describe("Server", func() {
|
||||
s.ServeListener(ln)
|
||||
}()
|
||||
|
||||
Consistently(func() int32 { return atomic.LoadInt32(&called) }).Should(Equal(int32(0)))
|
||||
Consistently(called.Load).Should(BeFalse())
|
||||
Consistently(done).ShouldNot(BeClosed())
|
||||
ln.EXPECT().Close().Do(func() error { close(stopAccept); return nil })
|
||||
Expect(s.Close()).To(Succeed())
|
||||
@@ -1098,14 +1098,14 @@ var _ = Describe("Server", func() {
|
||||
})
|
||||
|
||||
It("serves two listeners", func() {
|
||||
var called int32
|
||||
var called atomic.Bool
|
||||
ln1 := newMockAddrListener(":443")
|
||||
ln2 := newMockAddrListener(":8443")
|
||||
lns := make(chan QUICEarlyListener, 2)
|
||||
lns <- ln1
|
||||
lns <- ln2
|
||||
quicListen = func(c net.PacketConn, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) {
|
||||
atomic.StoreInt32(&called, 1)
|
||||
called.Store(true)
|
||||
return <-lns, nil
|
||||
}
|
||||
|
||||
@@ -1137,7 +1137,7 @@ var _ = Describe("Server", func() {
|
||||
s.ServeListener(ln2)
|
||||
}()
|
||||
|
||||
Consistently(func() int32 { return atomic.LoadInt32(&called) }).Should(Equal(int32(0)))
|
||||
Consistently(called.Load).Should(BeFalse())
|
||||
Consistently(done1).ShouldNot(BeClosed())
|
||||
Expect(done2).ToNot(BeClosed())
|
||||
ln1.EXPECT().Close().Do(func() error { close(stopAccept1); return nil })
|
||||
|
||||
@@ -1068,4 +1068,17 @@ var _ = Describe("HTTP tests", func() {
|
||||
"Unannounced": {"Surprise!"},
|
||||
})))
|
||||
})
|
||||
|
||||
It("aborts requests on shutdown", func() {
|
||||
mux.HandleFunc("/shutdown", func(w http.ResponseWriter, r *http.Request) {
|
||||
defer GinkgoRecover()
|
||||
Expect(server.Close()).To(Succeed())
|
||||
})
|
||||
|
||||
_, err := client.Get(fmt.Sprintf("https://localhost:%d/shutdown", port))
|
||||
Expect(err).To(HaveOccurred())
|
||||
var appErr *http3.Error
|
||||
Expect(errors.As(err, &appErr)).To(BeTrue())
|
||||
Expect(appErr.ErrorCode).To(Equal(http3.ErrCodeNoError))
|
||||
})
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user