forked from quic-go/quic-go
http3: implement client-side GOAWAY handling (#5143)
When receiving a GOAWAY frame, the client: * immediately closes the connection if there are no active requests * refuses to open streams with stream IDs larger than the stream ID in the GOAWAY frame * closes the connection once the stream count drops to zero
This commit is contained in:
@@ -306,6 +306,9 @@ func testClientExtendedConnect(t *testing.T, enabled bool) {
|
||||
controlStr := mockquic.NewMockStream(mockCtrl)
|
||||
controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(func(b []byte) (int, error) {
|
||||
<-allowSettings
|
||||
if r.Len() == 0 {
|
||||
<-done
|
||||
}
|
||||
return r.Read(b)
|
||||
}).AnyTimes()
|
||||
conn.EXPECT().AcceptUniStream(gomock.Any()).Return(controlStr, nil)
|
||||
|
||||
148
http3/conn.go
148
http3/conn.go
@@ -2,6 +2,7 @@ package http3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
@@ -19,6 +20,8 @@ import (
|
||||
"github.com/quic-go/qpack"
|
||||
)
|
||||
|
||||
var errGoAway = errors.New("connection in graceful shutdown")
|
||||
|
||||
// Connection is an HTTP/3 connection.
|
||||
// It has all methods from the quic.Connection expect for AcceptStream, AcceptUniStream,
|
||||
// SendDatagram and ReceiveDatagram.
|
||||
@@ -50,8 +53,10 @@ type connection struct {
|
||||
|
||||
decoder *qpack.Decoder
|
||||
|
||||
streamMx sync.Mutex
|
||||
streams map[protocol.StreamID]*datagrammer
|
||||
streamMx sync.Mutex
|
||||
streams map[protocol.StreamID]*datagrammer
|
||||
lastStreamID protocol.StreamID
|
||||
maxStreamID protocol.StreamID
|
||||
|
||||
settings *Settings
|
||||
receivedSettings chan struct{}
|
||||
@@ -78,6 +83,8 @@ func newConnection(
|
||||
decoder: qpack.NewDecoder(func(hf qpack.HeaderField) {}),
|
||||
receivedSettings: make(chan struct{}),
|
||||
streams: make(map[protocol.StreamID]*datagrammer),
|
||||
maxStreamID: protocol.InvalidStreamID,
|
||||
lastStreamID: protocol.InvalidStreamID,
|
||||
}
|
||||
if idleTimeout > 0 {
|
||||
c.idleTimer = time.AfterFunc(idleTimeout, c.onIdleTimer)
|
||||
@@ -97,6 +104,13 @@ func (c *connection) clearStream(id quic.StreamID) {
|
||||
if c.idleTimeout > 0 && len(c.streams) == 0 {
|
||||
c.idleTimer.Reset(c.idleTimeout)
|
||||
}
|
||||
// The server is performing a graceful shutdown.
|
||||
// If no more streams are remaining, close the connection.
|
||||
if c.maxStreamID != protocol.InvalidStreamID {
|
||||
if len(c.streams) == 0 {
|
||||
c.CloseWithError(quic.ApplicationErrorCode(ErrCodeNoError), "")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *connection) openRequestStream(
|
||||
@@ -106,6 +120,14 @@ func (c *connection) openRequestStream(
|
||||
disableCompression bool,
|
||||
maxHeaderBytes uint64,
|
||||
) (*requestStream, error) {
|
||||
c.streamMx.Lock()
|
||||
maxStreamID := c.maxStreamID
|
||||
lastStreamID := c.lastStreamID
|
||||
c.streamMx.Unlock()
|
||||
if maxStreamID != protocol.InvalidStreamID && lastStreamID >= maxStreamID {
|
||||
return nil, errGoAway
|
||||
}
|
||||
|
||||
str, err := c.OpenStreamSync(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -113,6 +135,7 @@ func (c *connection) openRequestStream(
|
||||
datagrams := newDatagrammer(func(b []byte) error { return c.sendDatagram(str.StreamID(), b) })
|
||||
c.streamMx.Lock()
|
||||
c.streams[str.StreamID()] = datagrams
|
||||
c.lastStreamID = str.StreamID()
|
||||
c.streamMx.Unlock()
|
||||
qstr := newStateTrackingStream(str, c, datagrams)
|
||||
rsp := &http.Response{}
|
||||
@@ -244,44 +267,97 @@ func (c *connection) handleUnidirectionalStreams(hijack func(StreamType, quic.Co
|
||||
c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeStreamCreationError), "duplicate control stream")
|
||||
return
|
||||
}
|
||||
fp := &frameParser{conn: c.Connection, r: str}
|
||||
f, err := fp.ParseNext()
|
||||
if err != nil {
|
||||
c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameError), "")
|
||||
return
|
||||
}
|
||||
sf, ok := f.(*settingsFrame)
|
||||
if !ok {
|
||||
c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeMissingSettings), "")
|
||||
return
|
||||
}
|
||||
c.settings = &Settings{
|
||||
EnableDatagrams: sf.Datagram,
|
||||
EnableExtendedConnect: sf.ExtendedConnect,
|
||||
Other: sf.Other,
|
||||
}
|
||||
close(c.receivedSettings)
|
||||
if !sf.Datagram {
|
||||
return
|
||||
}
|
||||
// If datagram support was enabled on our side as well as on the server side,
|
||||
// we can expect it to have been negotiated both on the transport and on the HTTP/3 layer.
|
||||
// Note: ConnectionState() will block until the handshake is complete (relevant when using 0-RTT).
|
||||
if c.enableDatagrams && !c.ConnectionState().SupportsDatagrams {
|
||||
c.CloseWithError(quic.ApplicationErrorCode(ErrCodeSettingsError), "missing QUIC Datagram support")
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
if err := c.receiveDatagrams(); err != nil {
|
||||
if c.logger != nil {
|
||||
c.logger.Debug("receiving datagrams failed", "error", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
c.handleControlStream(str)
|
||||
}(str)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *connection) handleControlStream(str quic.ReceiveStream) {
|
||||
fp := &frameParser{conn: c.Connection, r: str}
|
||||
f, err := fp.ParseNext()
|
||||
if err != nil {
|
||||
var serr *quic.StreamError
|
||||
if err == io.EOF || errors.As(err, &serr) {
|
||||
c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeClosedCriticalStream), "")
|
||||
return
|
||||
}
|
||||
c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameError), "")
|
||||
return
|
||||
}
|
||||
sf, ok := f.(*settingsFrame)
|
||||
if !ok {
|
||||
c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeMissingSettings), "")
|
||||
return
|
||||
}
|
||||
c.settings = &Settings{
|
||||
EnableDatagrams: sf.Datagram,
|
||||
EnableExtendedConnect: sf.ExtendedConnect,
|
||||
Other: sf.Other,
|
||||
}
|
||||
close(c.receivedSettings)
|
||||
if sf.Datagram {
|
||||
// If datagram support was enabled on our side as well as on the server side,
|
||||
// we can expect it to have been negotiated both on the transport and on the HTTP/3 layer.
|
||||
// Note: ConnectionState() will block until the handshake is complete (relevant when using 0-RTT).
|
||||
if c.enableDatagrams && !c.ConnectionState().SupportsDatagrams {
|
||||
c.CloseWithError(quic.ApplicationErrorCode(ErrCodeSettingsError), "missing QUIC Datagram support")
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
if err := c.receiveDatagrams(); err != nil {
|
||||
if c.logger != nil {
|
||||
c.logger.Debug("receiving datagrams failed", "error", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// we don't support server push, hence we don't expect any GOAWAY frames from the client
|
||||
if c.perspective == protocol.PerspectiveServer {
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
f, err := fp.ParseNext()
|
||||
if err != nil {
|
||||
var serr *quic.StreamError
|
||||
if err == io.EOF || errors.As(err, &serr) {
|
||||
c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeClosedCriticalStream), "")
|
||||
return
|
||||
}
|
||||
c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameError), "")
|
||||
return
|
||||
}
|
||||
// GOAWAY is the only frame allowed at this point:
|
||||
// * unexpected frames are ignored by the frame parser
|
||||
// * we don't support any extension that might add support for more frames
|
||||
goaway, ok := f.(*goAwayFrame)
|
||||
if !ok {
|
||||
c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "")
|
||||
return
|
||||
}
|
||||
if goaway.StreamID%4 != 0 { // client-initiated, bidirectional streams
|
||||
c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeIDError), "")
|
||||
return
|
||||
}
|
||||
c.streamMx.Lock()
|
||||
if c.maxStreamID != protocol.InvalidStreamID && goaway.StreamID > c.maxStreamID {
|
||||
c.streamMx.Unlock()
|
||||
c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeIDError), "")
|
||||
return
|
||||
}
|
||||
c.maxStreamID = goaway.StreamID
|
||||
hasActiveStreams := len(c.streams) > 0
|
||||
c.streamMx.Unlock()
|
||||
|
||||
// immediately close the connection if there are currently no active requests
|
||||
if !hasActiveStreams {
|
||||
c.CloseWithError(quic.ApplicationErrorCode(ErrCodeNoError), "")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *connection) sendDatagram(streamID protocol.StreamID, b []byte) error {
|
||||
// TODO: this creates a lot of garbage and an additional copy
|
||||
data := make([]byte, 0, len(b)+8)
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -149,35 +150,96 @@ func TestConnResetUnknownUniStream(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConnControlStreamFailures(t *testing.T) {
|
||||
t.Run("missing settings", func(t *testing.T) {
|
||||
testConnControlStreamFailures(t, (&dataFrame{}).Append(nil), ErrCodeMissingSettings)
|
||||
t.Run("missing SETTINGS", func(t *testing.T) {
|
||||
testConnControlStreamFailures(t, (&dataFrame{}).Append(nil), nil, ErrCodeMissingSettings)
|
||||
})
|
||||
t.Run("frame error", func(t *testing.T) {
|
||||
testConnControlStreamFailures(t,
|
||||
// 1337 is invalid value for the Extended CONNECT setting
|
||||
(&settingsFrame{Other: map[uint64]uint64{settingExtendedConnect: 1337}}).Append(nil),
|
||||
nil,
|
||||
ErrCodeFrameError,
|
||||
)
|
||||
})
|
||||
t.Run("control stream closed before SETTINGS", func(t *testing.T) {
|
||||
testConnControlStreamFailures(t, nil, io.EOF, ErrCodeClosedCriticalStream)
|
||||
})
|
||||
t.Run("control stream reset before SETTINGS", func(t *testing.T) {
|
||||
testConnControlStreamFailures(t,
|
||||
nil,
|
||||
&quic.StreamError{Remote: true, ErrorCode: 42},
|
||||
ErrCodeClosedCriticalStream,
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
func testConnControlStreamFailures(t *testing.T, data []byte, expectedErr ErrCode) {
|
||||
func TestConnGoAwayFailures(t *testing.T) {
|
||||
t.Run("invalid frame", func(t *testing.T) {
|
||||
b := (&settingsFrame{}).Append(nil)
|
||||
// 1337 is invalid value for the Extended CONNECT setting
|
||||
b = (&settingsFrame{Other: map[uint64]uint64{settingExtendedConnect: 1337}}).Append(b)
|
||||
testConnControlStreamFailures(t, b, nil, ErrCodeFrameError)
|
||||
})
|
||||
t.Run("not a GOAWAY", func(t *testing.T) {
|
||||
b := (&settingsFrame{}).Append(nil)
|
||||
// GOAWAY is the only allowed frame type after SETTINGS
|
||||
b = (&headersFrame{}).Append(b)
|
||||
testConnControlStreamFailures(t, b, nil, ErrCodeFrameUnexpected)
|
||||
})
|
||||
t.Run("stream closed before GOAWAY", func(t *testing.T) {
|
||||
testConnControlStreamFailures(t, (&settingsFrame{}).Append(nil), io.EOF, ErrCodeClosedCriticalStream)
|
||||
})
|
||||
t.Run("stream reset before GOAWAY", func(t *testing.T) {
|
||||
testConnControlStreamFailures(t,
|
||||
(&settingsFrame{}).Append(nil),
|
||||
&quic.StreamError{Remote: true, ErrorCode: 42},
|
||||
ErrCodeClosedCriticalStream,
|
||||
)
|
||||
})
|
||||
t.Run("invalid stream ID", func(t *testing.T) {
|
||||
data := (&settingsFrame{}).Append(nil)
|
||||
data = (&goAwayFrame{StreamID: 1}).Append(data)
|
||||
testConnControlStreamFailures(t, data, nil, ErrCodeIDError)
|
||||
})
|
||||
t.Run("increased stream ID", func(t *testing.T) {
|
||||
data := (&settingsFrame{}).Append(nil)
|
||||
data = (&goAwayFrame{StreamID: 4}).Append(data)
|
||||
data = (&goAwayFrame{StreamID: 8}).Append(data)
|
||||
testConnControlStreamFailures(t, data, nil, ErrCodeIDError)
|
||||
})
|
||||
}
|
||||
|
||||
func testConnControlStreamFailures(t *testing.T, data []byte, readErr error, expectedErr ErrCode) {
|
||||
mockCtrl := gomock.NewController(t)
|
||||
qconn := mockquic.NewMockEarlyConnection(mockCtrl)
|
||||
conn := newConnection(
|
||||
context.Background(),
|
||||
qconn,
|
||||
false,
|
||||
protocol.PerspectiveServer,
|
||||
protocol.PerspectiveClient,
|
||||
nil,
|
||||
0,
|
||||
)
|
||||
b := quicvarint.Append(nil, streamTypeControlStream)
|
||||
b = append(b, data...)
|
||||
r := bytes.NewReader(b)
|
||||
controlStr := mockquic.NewMockStream(mockCtrl)
|
||||
controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(bytes.NewReader(b).Read).AnyTimes()
|
||||
controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(func(b []byte) (int, error) {
|
||||
if r.Len() == 0 {
|
||||
return 0, readErr
|
||||
}
|
||||
return r.Read(b)
|
||||
}).AnyTimes()
|
||||
qconn.EXPECT().AcceptUniStream(gomock.Any()).Return(controlStr, nil)
|
||||
qconn.EXPECT().AcceptUniStream(gomock.Any()).Return(nil, errors.New("test done"))
|
||||
closed := make(chan struct{})
|
||||
|
||||
str := mockquic.NewMockStream(mockCtrl)
|
||||
str.EXPECT().StreamID().Return(4).AnyTimes()
|
||||
str.EXPECT().Context().Return(context.Background()).AnyTimes()
|
||||
qconn.EXPECT().OpenStreamSync(gomock.Any()).Return(str, nil)
|
||||
conn.openRequestStream(context.Background(), nil, nil, true, 1000)
|
||||
|
||||
qconn.EXPECT().CloseWithError(quic.ApplicationErrorCode(expectedErr), gomock.Any()).Do(func(quic.ApplicationErrorCode, string) error {
|
||||
close(closed)
|
||||
return nil
|
||||
@@ -199,6 +261,88 @@ func testConnControlStreamFailures(t *testing.T, data []byte, expectedErr ErrCod
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnGoAway(t *testing.T) {
|
||||
t.Run("no active streams", func(t *testing.T) {
|
||||
testConnGoAway(t, false)
|
||||
})
|
||||
t.Run("active stream", func(t *testing.T) {
|
||||
testConnGoAway(t, true)
|
||||
})
|
||||
}
|
||||
|
||||
func testConnGoAway(t *testing.T, withStream bool) {
|
||||
mockCtrl := gomock.NewController(t)
|
||||
qconn := mockquic.NewMockEarlyConnection(mockCtrl)
|
||||
conn := newConnection(
|
||||
context.Background(),
|
||||
qconn,
|
||||
false,
|
||||
protocol.PerspectiveClient,
|
||||
nil,
|
||||
0,
|
||||
)
|
||||
b := quicvarint.Append(nil, streamTypeControlStream)
|
||||
b = (&settingsFrame{}).Append(b)
|
||||
b = (&goAwayFrame{StreamID: 4}).Append(b)
|
||||
|
||||
var mockStr *mockquic.MockStream
|
||||
var str quic.Stream
|
||||
if withStream {
|
||||
mockStr = mockquic.NewMockStream(mockCtrl)
|
||||
mockStr.EXPECT().StreamID().Return(4).AnyTimes()
|
||||
mockStr.EXPECT().Context().Return(context.Background()).AnyTimes()
|
||||
qconn.EXPECT().OpenStreamSync(gomock.Any()).Return(mockStr, nil)
|
||||
s, err := conn.openRequestStream(context.Background(), nil, nil, true, 1000)
|
||||
require.NoError(t, err)
|
||||
str = s
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
defer close(done)
|
||||
r := bytes.NewReader(b)
|
||||
controlStr := mockquic.NewMockStream(mockCtrl)
|
||||
controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(func(b []byte) (int, error) {
|
||||
if r.Len() == 0 {
|
||||
<-done
|
||||
return 0, errors.New("test done")
|
||||
}
|
||||
return r.Read(b)
|
||||
}).AnyTimes()
|
||||
qconn.EXPECT().AcceptUniStream(gomock.Any()).Return(controlStr, nil)
|
||||
qconn.EXPECT().AcceptUniStream(gomock.Any()).Return(nil, errors.New("test done"))
|
||||
closed := make(chan struct{})
|
||||
qconn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeNoError), gomock.Any()).Do(func(quic.ApplicationErrorCode, string) error {
|
||||
close(closed)
|
||||
return nil
|
||||
})
|
||||
// duplicate calls to CloseWithError are a no-op
|
||||
qconn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).AnyTimes()
|
||||
go conn.handleUnidirectionalStreams(nil)
|
||||
|
||||
// the connection should be closed after the stream is closed
|
||||
if withStream {
|
||||
select {
|
||||
case <-closed:
|
||||
t.Fatal("connection closed")
|
||||
case <-time.After(scaleDuration(10 * time.Millisecond)):
|
||||
}
|
||||
|
||||
_, err := conn.openRequestStream(context.Background(), nil, nil, true, 1000)
|
||||
require.ErrorIs(t, err, errGoAway)
|
||||
|
||||
mockStr.EXPECT().Close()
|
||||
str.Close()
|
||||
mockStr.EXPECT().CancelRead(gomock.Any())
|
||||
str.CancelRead(1337)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-closed:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timeout waiting for close")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnRejectPushStream(t *testing.T) {
|
||||
t.Run("client", func(t *testing.T) {
|
||||
testConnRejectPushStream(t, protocol.PerspectiveClient, ErrCodeStreamCreationError)
|
||||
@@ -299,8 +443,16 @@ func TestConnSendAndReceiveDatagram(t *testing.T) {
|
||||
)
|
||||
b := quicvarint.Append(nil, streamTypeControlStream)
|
||||
b = (&settingsFrame{Datagram: true}).Append(b)
|
||||
r := bytes.NewReader(b)
|
||||
done := make(chan struct{})
|
||||
defer close(done)
|
||||
controlStr := mockquic.NewMockStream(mockCtrl)
|
||||
controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(bytes.NewReader(b).Read).AnyTimes()
|
||||
controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(func(b []byte) (int, error) {
|
||||
if r.Len() == 0 {
|
||||
<-done
|
||||
}
|
||||
return r.Read(b)
|
||||
}).AnyTimes()
|
||||
qconn.EXPECT().AcceptUniStream(gomock.Any()).Return(controlStr, nil).MaxTimes(1)
|
||||
qconn.EXPECT().AcceptUniStream(gomock.Any()).Return(nil, errors.New("test done")).MaxTimes(1)
|
||||
qconn.EXPECT().ConnectionState().Return(quic.ConnectionState{SupportsDatagrams: true}).MaxTimes(1)
|
||||
@@ -351,6 +503,8 @@ func TestConnSendAndReceiveDatagram(t *testing.T) {
|
||||
expected = append(expected, []byte("foobaz")...)
|
||||
qconn.EXPECT().SendDatagram(expected).Return(assert.AnError)
|
||||
require.ErrorIs(t, conn.sendDatagram(strID2, []byte("foobaz")), assert.AnError)
|
||||
|
||||
qconn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).AnyTimes()
|
||||
}
|
||||
|
||||
func TestConnDatagramFailures(t *testing.T) {
|
||||
@@ -376,18 +530,24 @@ func testConnDatagramFailures(t *testing.T, datagram []byte) {
|
||||
b := quicvarint.Append(nil, streamTypeControlStream)
|
||||
b = (&settingsFrame{Datagram: true}).Append(b)
|
||||
r := bytes.NewReader(b)
|
||||
done := make(chan struct{})
|
||||
controlStr := mockquic.NewMockStream(mockCtrl)
|
||||
controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
|
||||
controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(func(b []byte) (int, error) {
|
||||
if r.Len() == 0 {
|
||||
<-done
|
||||
}
|
||||
return r.Read(b)
|
||||
}).AnyTimes()
|
||||
qconn.EXPECT().AcceptUniStream(gomock.Any()).Return(controlStr, nil).MaxTimes(1)
|
||||
qconn.EXPECT().AcceptUniStream(gomock.Any()).Return(nil, errors.New("test done")).MaxTimes(1)
|
||||
qconn.EXPECT().ConnectionState().Return(quic.ConnectionState{SupportsDatagrams: true}).MaxTimes(1)
|
||||
|
||||
qconn.EXPECT().ReceiveDatagram(gomock.Any()).Return(datagram, nil)
|
||||
done := make(chan struct{})
|
||||
qconn.EXPECT().CloseWithError(qerr.ApplicationErrorCode(ErrCodeDatagramError), gomock.Any()).Do(func(qerr.ApplicationErrorCode, string) error {
|
||||
close(done)
|
||||
return nil
|
||||
})
|
||||
qconn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).AnyTimes() // further calls to CloseWithError are a no-op
|
||||
go func() { conn.handleUnidirectionalStreams(nil) }()
|
||||
select {
|
||||
case <-done:
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"github.com/quic-go/quic-go/http3"
|
||||
quicproxy "github.com/quic-go/quic-go/integrationtests/tools/proxy"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@@ -39,33 +40,93 @@ func TestHTTPShutdown(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestGracefulShutdownShortRequest(t *testing.T) {
|
||||
delay := scaleDuration(25 * time.Millisecond)
|
||||
|
||||
var server *http3.Server
|
||||
mux := http.NewServeMux()
|
||||
port := startHTTPServer(t, mux, func(s *http3.Server) { server = s })
|
||||
errChan := make(chan error, 1)
|
||||
proceed := make(chan struct{})
|
||||
mux.HandleFunc("/shutdown", func(w http.ResponseWriter, r *http.Request) {
|
||||
go func() {
|
||||
defer close(errChan)
|
||||
errChan <- server.Shutdown(context.Background())
|
||||
}()
|
||||
time.Sleep(delay)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.(http.Flusher).Flush()
|
||||
<-proceed
|
||||
w.Write([]byte("shutdown"))
|
||||
})
|
||||
|
||||
client := newHTTP3Client(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*delay)
|
||||
connChan := make(chan quic.EarlyConnection, 1)
|
||||
tr := &http3.Transport{
|
||||
TLSClientConfig: getTLSClientConfigWithoutServerName(),
|
||||
Dial: func(ctx context.Context, a string, tlsConf *tls.Config, conf *quic.Config) (quic.EarlyConnection, error) {
|
||||
addr, err := net.ResolveUDPAddr("udp", a)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conn, err := quic.DialEarly(ctx, newUDPConnLocalhost(t), addr, tlsConf, conf)
|
||||
connChan <- conn
|
||||
return conn, err
|
||||
},
|
||||
}
|
||||
t.Cleanup(func() { tr.Close() })
|
||||
|
||||
client := &http.Client{Transport: tr}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("https://localhost:%d/shutdown", port), nil)
|
||||
require.NoError(t, err)
|
||||
resp, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []byte("shutdown"), body)
|
||||
client.Transport.(*http3.Transport).Close() // manually close the client
|
||||
|
||||
var conn quic.EarlyConnection
|
||||
select {
|
||||
case conn = <-connChan:
|
||||
default:
|
||||
t.Fatal("expected a connection")
|
||||
}
|
||||
|
||||
type result struct {
|
||||
body []byte
|
||||
err error
|
||||
}
|
||||
resultChan := make(chan result, 1)
|
||||
go func() {
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
resultChan <- result{body: body, err: err}
|
||||
}()
|
||||
select {
|
||||
case <-resultChan:
|
||||
t.Fatal("request body shouldn't have been read yet")
|
||||
case <-time.After(scaleDuration(10 * time.Millisecond)):
|
||||
}
|
||||
select {
|
||||
case <-conn.Context().Done():
|
||||
t.Fatal("connection shouldn't have been closed")
|
||||
default:
|
||||
}
|
||||
|
||||
// allow the request to proceed
|
||||
close(proceed)
|
||||
select {
|
||||
case res := <-resultChan:
|
||||
require.NoError(t, res.err)
|
||||
require.Equal(t, []byte("shutdown"), res.body)
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timeout")
|
||||
}
|
||||
|
||||
// now that the stream count dropped to 0, the client should close the connection
|
||||
select {
|
||||
case <-conn.Context().Done():
|
||||
var appErr *quic.ApplicationError
|
||||
require.ErrorAs(t, context.Cause(conn.Context()), &appErr)
|
||||
assert.False(t, appErr.Remote)
|
||||
assert.Equal(t, quic.ApplicationErrorCode(http3.ErrCodeNoError), appErr.ErrorCode)
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timeout")
|
||||
}
|
||||
|
||||
select {
|
||||
case err := <-errChan:
|
||||
@@ -75,6 +136,63 @@ func TestGracefulShutdownShortRequest(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGracefulShutdownIdleConnection(t *testing.T) {
|
||||
var server *http3.Server
|
||||
port := startHTTPServer(t, http.NewServeMux(), func(s *http3.Server) { server = s })
|
||||
|
||||
connChan := make(chan quic.EarlyConnection, 1)
|
||||
tr := &http3.Transport{
|
||||
TLSClientConfig: getTLSClientConfigWithoutServerName(),
|
||||
Dial: func(ctx context.Context, a string, tlsConf *tls.Config, conf *quic.Config) (quic.EarlyConnection, error) {
|
||||
addr, err := net.ResolveUDPAddr("udp", a)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conn, err := quic.DialEarly(ctx, newUDPConnLocalhost(t), addr, tlsConf, conf)
|
||||
connChan <- conn
|
||||
return conn, err
|
||||
},
|
||||
}
|
||||
t.Cleanup(func() { tr.Close() })
|
||||
|
||||
client := &http.Client{Transport: tr}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("https://localhost:%d/", port), nil)
|
||||
require.NoError(t, err)
|
||||
resp, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusNotFound, resp.StatusCode)
|
||||
require.NoError(t, resp.Body.Close())
|
||||
|
||||
var conn quic.EarlyConnection
|
||||
select {
|
||||
case conn = <-connChan:
|
||||
default:
|
||||
t.Fatal("expected a connection")
|
||||
}
|
||||
// the connection should still be alive (and idle)
|
||||
select {
|
||||
case <-conn.Context().Done():
|
||||
t.Fatal("connection shouldn't have been closed")
|
||||
default:
|
||||
}
|
||||
|
||||
shutdownChan := make(chan error, 1)
|
||||
go func() { shutdownChan <- server.Shutdown(context.Background()) }()
|
||||
|
||||
// since the connection is idle, the client should close it immediately
|
||||
select {
|
||||
case <-conn.Context().Done():
|
||||
var appErr *quic.ApplicationError
|
||||
require.ErrorAs(t, context.Cause(conn.Context()), &appErr)
|
||||
assert.False(t, appErr.Remote)
|
||||
assert.Equal(t, quic.ApplicationErrorCode(http3.ErrCodeNoError), appErr.ErrorCode)
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timeout")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGracefulShutdownLongLivedRequest(t *testing.T) {
|
||||
delay := scaleDuration(25 * time.Millisecond)
|
||||
errChan := make(chan error, 1)
|
||||
@@ -88,6 +206,8 @@ func TestGracefulShutdownLongLivedRequest(t *testing.T) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.(http.Flusher).Flush()
|
||||
|
||||
// The request simulated here takes longer than the server's graceful shutdown period.
|
||||
// We expect it to be terminated once the server shuts down.
|
||||
go func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), delay)
|
||||
defer cancel()
|
||||
|
||||
@@ -1099,3 +1099,115 @@ func testHTTPRequestRetryAfterIdleTimeout(t *testing.T, onlyCachedConn bool) {
|
||||
require.Equal(t, 2, headersCount)
|
||||
require.Empty(t, conns) // make sure we dialed 2 connections
|
||||
}
|
||||
|
||||
func TestHTTPRequestAfterGracefulShutdown(t *testing.T) {
|
||||
t.Run("Request.GetBody set", func(t *testing.T) {
|
||||
testHTTPRequestAfterGracefulShutdown(t, true)
|
||||
})
|
||||
t.Run("Request.GetBody not set", func(t *testing.T) {
|
||||
testHTTPRequestAfterGracefulShutdown(t, false)
|
||||
})
|
||||
}
|
||||
|
||||
func testHTTPRequestAfterGracefulShutdown(t *testing.T, setGetBody bool) {
|
||||
t.Setenv("QUIC_GO_DISABLE_RECEIVE_BUFFER_WARNING", "true")
|
||||
|
||||
ln, err := quic.ListenEarly(
|
||||
newUDPConnLocalhost(t),
|
||||
http3.ConfigureTLSConfig(getTLSConfig()),
|
||||
getQuicConfig(nil),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
var inShutdown atomic.Bool
|
||||
proxy := quicproxy.Proxy{
|
||||
Conn: newUDPConnLocalhost(t),
|
||||
ServerAddr: ln.Addr().(*net.UDPAddr),
|
||||
DelayPacket: func(_ quicproxy.Direction, _, _ net.Addr, data []byte) time.Duration {
|
||||
if inShutdown.Load() {
|
||||
return scaleDuration(10 * time.Millisecond)
|
||||
}
|
||||
return scaleDuration(2 * time.Millisecond)
|
||||
},
|
||||
}
|
||||
require.NoError(t, proxy.Start())
|
||||
defer proxy.Close()
|
||||
|
||||
mux2 := http.NewServeMux()
|
||||
mux2.HandleFunc("/echo", func(w http.ResponseWriter, r *http.Request) {
|
||||
data, _ := io.ReadAll(r.Body)
|
||||
w.Write(data)
|
||||
})
|
||||
server2 := &http3.Server{Handler: mux2}
|
||||
|
||||
done := make(chan struct{})
|
||||
defer close(done)
|
||||
server1 := &http3.Server{Handler: http.NewServeMux()}
|
||||
|
||||
go server1.ServeListener(ln)
|
||||
|
||||
tlsConf := getTLSClientConfigWithoutServerName()
|
||||
tlsConf.NextProtos = []string{http3.NextProtoH3}
|
||||
var dialCount int
|
||||
tr := &http3.Transport{
|
||||
TLSClientConfig: tlsConf,
|
||||
Dial: func(ctx context.Context, a string, tlsConf *tls.Config, conf *quic.Config) (quic.EarlyConnection, error) {
|
||||
addr, err := net.ResolveUDPAddr("udp", a)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dialCount++
|
||||
return quic.DialEarly(ctx, newUDPConnLocalhost(t), addr, tlsConf, conf)
|
||||
},
|
||||
}
|
||||
t.Cleanup(func() { tr.Close() })
|
||||
cl := &http.Client{Transport: tr}
|
||||
|
||||
// first request to establish the connection
|
||||
resp, err := cl.Get(fmt.Sprintf("https://localhost:%d/", ln.Addr().(*net.UDPAddr).Port))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusNotFound, resp.StatusCode)
|
||||
|
||||
// If the body is a strings.Reader, http.NewRequest automatically sets the GetBody callback.
|
||||
// This can be prevented by using a different kind of reader, e.g. the io.LimitReader.
|
||||
var headersCount int
|
||||
req, err := http.NewRequestWithContext(
|
||||
httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
|
||||
WroteHeaders: func() { headersCount++ },
|
||||
}),
|
||||
http.MethodGet,
|
||||
fmt.Sprintf("https://localhost:%d/echo", ln.Addr().(*net.UDPAddr).Port),
|
||||
io.LimitReader(strings.NewReader("foobar"), 1000),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
if setGetBody {
|
||||
req.GetBody = func() (io.ReadCloser, error) {
|
||||
return io.NopCloser(strings.NewReader("foobaz")), nil
|
||||
}
|
||||
} else {
|
||||
require.Nil(t, req.GetBody)
|
||||
}
|
||||
|
||||
inShutdown.Store(true)
|
||||
go server1.Shutdown(context.Background())
|
||||
go server2.ServeListener(ln)
|
||||
defer server2.Close()
|
||||
|
||||
// so that graceful shutdown can actually start
|
||||
time.Sleep(scaleDuration(10 * time.Millisecond))
|
||||
|
||||
resp, err = cl.Do(req)
|
||||
if !setGetBody {
|
||||
require.ErrorContains(t, err, "after Request.Body was written; define Request.GetBody to avoid this error")
|
||||
require.Equal(t, 1, dialCount)
|
||||
require.Equal(t, 1, headersCount)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "foobaz", string(body))
|
||||
require.Equal(t, 2, dialCount)
|
||||
require.Equal(t, 2, headersCount)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user