Files
quic-go/http3/server_test.go

827 lines
24 KiB
Go

package http3
import (
"bytes"
"context"
"crypto/tls"
"fmt"
"io"
"log/slog"
"net"
"net/http"
"net/http/httptest"
"runtime"
"testing"
"time"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/internal/testdata"
"github.com/quic-go/quic-go/quicvarint"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestConfigureTLSConfig checks that ConfigureTLSConfig sets the ALPN list to
// NextProtoH3, both on the returned config and on configs produced by a
// GetConfigForClient callback, without modifying the callback's own config.
func TestConfigureTLSConfig(t *testing.T) {
	t.Run("basic config", func(t *testing.T) {
		conf := ConfigureTLSConfig(&tls.Config{})
		// expected value first, so that a failure labels both values correctly
		require.Equal(t, []string{NextProtoH3}, conf.NextProtos)
	})
	t.Run("ALPN set", func(t *testing.T) {
		conf := ConfigureTLSConfig(&tls.Config{NextProtos: []string{"foo", "bar"}})
		require.Equal(t, []string{NextProtoH3}, conf.NextProtos)
	})
	// for configs that define GetConfigForClient, the ALPN is set to h3
	t.Run("GetConfigForClient", func(t *testing.T) {
		staticConf := &tls.Config{NextProtos: []string{"foo", "bar"}}
		conf := ConfigureTLSConfig(&tls.Config{
			GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) {
				return staticConf, nil
			},
		})
		innerConf, err := conf.GetConfigForClient(&tls.ClientHelloInfo{ServerName: "example.com"})
		require.NoError(t, err)
		require.NotNil(t, innerConf)
		require.Equal(t, []string{NextProtoH3}, innerConf.NextProtos)
		// make sure the original config was not modified
		require.Equal(t, []string{"foo", "bar"}, staticConf.NextProtos)
	})
	// GetConfigForClient might return a nil tls.Config
	t.Run("GetConfigForClient returns nil", func(t *testing.T) {
		conf := ConfigureTLSConfig(&tls.Config{
			GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) {
				return nil, nil
			},
		})
		innerConf, err := conf.GetConfigForClient(&tls.ClientHelloInfo{ServerName: "example.com"})
		require.NoError(t, err)
		require.Nil(t, innerConf)
	})
}
// TestServerSettings runs testServerSettings once with datagrams enabled and
// once with an additional setting.
func TestServerSettings(t *testing.T) {
	for _, tc := range []struct {
		name            string
		enableDatagrams bool
		other           map[uint64]uint64
	}{
		{name: "enable datagrams", enableDatagrams: true},
		{name: "additional settings", other: map[uint64]uint64{13: 37}},
	} {
		t.Run(tc.name, func(t *testing.T) {
			testServerSettings(t, tc.enableDatagrams, tc.other)
		})
	}
}
// testServerSettings runs a Server with the given datagram setting and
// additional settings on one end of a connection pair. It then reads the first
// unidirectional stream accepted on the client side and checks the stream type
// and the SETTINGS frame found there.
func testServerSettings(t *testing.T, enableDatagrams bool, other map[uint64]uint64) {
	s := Server{
		EnableDatagrams:    enableDatagrams,
		AdditionalSettings: other,
	}
	s.init()

	clientConn, serverConn := newConnPair(t)
	go s.handleConn(serverConn)

	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	settingsStr, err := clientConn.AcceptUniStream(ctx)
	require.NoError(t, err)
	// bound the read below, so that a missing SETTINGS frame fails the test instead of hanging
	require.NoError(t, settingsStr.SetReadDeadline(time.Now().Add(time.Second)))
	b := make([]byte, 1024)
	n, err := settingsStr.Read(b)
	require.NoError(t, err)
	b = b[:n]

	// the stream starts with the varint-encoded stream type, the frame follows
	typ, l, err := quicvarint.Parse(b)
	require.NoError(t, err)
	require.EqualValues(t, streamTypeControlStream, typ)
	fp := &frameParser{r: bytes.NewReader(b[l:])}
	f, err := fp.ParseNext()
	require.NoError(t, err)
	require.IsType(t, &settingsFrame{}, f)
	sf := f.(*settingsFrame)
	// Extended CONNECT is always supported
	require.True(t, sf.ExtendedConnect)
	require.Equal(t, enableDatagrams, sf.Datagram)
	require.Equal(t, other, sf.Other)
}
// TestServerRequestHandling sends single requests through testServerRequestHandling
// and checks the response header fields and the response body for a number of
// different handlers and request methods.
func TestServerRequestHandling(t *testing.T) {
	t.Run("200 with an empty handler", func(t *testing.T) {
		hfs, body := testServerRequestHandling(t,
			func(w http.ResponseWriter, r *http.Request) {},
			httptest.NewRequest(http.MethodGet, "https://www.example.com", nil),
		)
		require.Equal(t, []string{"200"}, hfs[":status"])
		require.Empty(t, body)
	})
	t.Run("content-length", func(t *testing.T) {
		hfs, body := testServerRequestHandling(t,
			func(w http.ResponseWriter, r *http.Request) {
				w.WriteHeader(http.StatusTeapot)
				w.Write([]byte("foobar"))
			},
			httptest.NewRequest(http.MethodGet, "https://www.example.com", nil),
		)
		require.Equal(t, []string{"418"}, hfs[":status"])
		require.Equal(t, []string{"6"}, hfs["content-length"])
		require.Equal(t, []byte("foobar"), body)
	})
	t.Run("no content-length when flushed", func(t *testing.T) {
		hfs, body := testServerRequestHandling(t,
			func(w http.ResponseWriter, r *http.Request) {
				w.Write([]byte("foo"))
				w.(http.Flusher).Flush()
				w.Write([]byte("bar"))
			},
			httptest.NewRequest(http.MethodGet, "https://www.example.com", nil),
		)
		require.Equal(t, []string{"200"}, hfs[":status"])
		require.NotContains(t, hfs, "content-length")
		require.Equal(t, []byte("foobar"), body)
	})
	t.Run("HEAD request", func(t *testing.T) {
		hfs, body := testServerRequestHandling(t,
			func(w http.ResponseWriter, r *http.Request) {
				w.Write([]byte("foobar"))
			},
			httptest.NewRequest(http.MethodHead, "https://www.example.com", nil),
		)
		require.Equal(t, []string{"200"}, hfs[":status"])
		// the handler writes a body, but none is expected in the response
		require.Empty(t, body)
	})
	t.Run("POST request", func(t *testing.T) {
		hfs, body := testServerRequestHandling(t,
			func(w http.ResponseWriter, r *http.Request) {
				w.WriteHeader(http.StatusTeapot)
				// echo the request body
				data, _ := io.ReadAll(r.Body)
				w.Write(data)
			},
			httptest.NewRequest(http.MethodPost, "https://www.example.com", bytes.NewBuffer([]byte("foobar"))),
		)
		require.Equal(t, []string{"418"}, hfs[":status"])
		require.Equal(t, []byte("foobar"), body)
	})
}
// testServerRequestHandling writes req to a new stream, closes the stream's
// send side, and serves the connection with a Server using handler.
// It returns the decoded response header fields and the concatenated payload
// of all DATA frames read until the stream returns io.EOF.
func testServerRequestHandling(t *testing.T,
	handler http.HandlerFunc,
	req *http.Request,
) (responseHeaders map[string][]string, body []byte) {
	clientConn, serverConn := newConnPair(t)

	reqStr, err := clientConn.OpenStream()
	require.NoError(t, err)
	_, err = reqStr.Write(encodeRequest(t, req))
	require.NoError(t, err)
	require.NoError(t, reqStr.Close())

	srv := &Server{Handler: handler}
	go srv.ServeQUICConn(serverConn)

	responseHeaders = decodeHeader(t, reqStr)
	parser := frameParser{r: reqStr}
	for {
		f, parseErr := parser.ParseNext()
		if parseErr == io.EOF {
			// end of the response
			return responseHeaders, body
		}
		require.NoError(t, parseErr)
		require.IsType(t, &dataFrame{}, f)
		// the frame's payload is read directly from the stream
		chunk := make([]byte, f.(*dataFrame).Length)
		_, readErr := io.ReadFull(reqStr, chunk)
		require.NoError(t, readErr)
		body = append(body, chunk...)
	}
}
// TestServerFirstFrameNotHeaders opens a request stream that starts with a
// DATA frame and expects the connection to be closed with an application
// error carrying ErrCodeFrameUnexpected.
func TestServerFirstFrameNotHeaders(t *testing.T) {
	clientConn, serverConn := newConnPair(t)

	reqStr, err := clientConn.OpenStream()
	require.NoError(t, err)
	// a DATA frame header with length 6, followed by the 6 byte payload
	payload := (&dataFrame{Length: 6}).Append(nil)
	payload = append(payload, []byte("foobar")...)
	_, err = reqStr.Write(payload)
	require.NoError(t, err)
	require.NoError(t, reqStr.Close())

	srv := &Server{}
	go srv.ServeQUICConn(serverConn)

	connCtx := clientConn.Context()
	select {
	case <-connCtx.Done():
		var appErr *quic.ApplicationError
		require.ErrorAs(t, context.Cause(connCtx), &appErr)
		require.Equal(t, quic.ApplicationErrorCode(ErrCodeFrameUnexpected), appErr.ErrorCode)
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
}
// TestServerHandlerBodyNotRead runs testServerHandlerBodyNotRead for requests
// that carry a body which the handler never reads.
func TestServerHandlerBodyNotRead(t *testing.T) {
	// doesn't touch the request at all
	noop := func(http.ResponseWriter, *http.Request) {}
	// swaps the request body for a zero-value reader / closer pair
	replaceBody := func(_ http.ResponseWriter, r *http.Request) {
		r.Body = struct {
			io.Reader
			io.Closer
		}{}
	}

	for _, tc := range []struct {
		name    string
		method  string
		handler http.HandlerFunc
	}{
		{name: "GET request with a body", method: http.MethodGet, handler: noop},
		{name: "POST body not read", method: http.MethodPost, handler: noop},
		{name: "POST request, with a replaced body", method: http.MethodPost, handler: replaceBody},
	} {
		t.Run(tc.name, func(t *testing.T) {
			req := httptest.NewRequest(tc.method, "https://www.example.com", bytes.NewBuffer([]byte("foobar")))
			testServerHandlerBodyNotRead(t, req, tc.handler)
		})
	}
}
// testServerHandlerBodyNotRead writes req to a new stream and serves the
// connection with a Server that wraps handler. It fails if the wrapped handler
// hasn't returned within a second.
func testServerHandlerBodyNotRead(t *testing.T, req *http.Request, handler http.HandlerFunc) {
	clientConn, serverConn := newConnPair(t)

	reqStr, err := clientConn.OpenStream()
	require.NoError(t, err)
	_, err = reqStr.Write(encodeRequest(t, req))
	require.NoError(t, err)
	// The send side of the stream is left open, the close call is disabled:
	// require.NoError(t, str.Close())

	handlerReturned := make(chan struct{})
	wrapped := func(w http.ResponseWriter, r *http.Request) {
		defer close(handlerReturned)
		handler(w, r)
	}
	srv := &Server{Handler: http.HandlerFunc(wrapped)}
	go srv.ServeQUICConn(serverConn)

	select {
	case <-handlerReturned:
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
}
// TestServerStreamResetByClient cancels writing on a fresh request stream
// before anything was sent. It expects the read side to be reset with
// ErrCodeRequestIncomplete, and the handler not to be called.
func TestServerStreamResetByClient(t *testing.T) {
	clientConn, serverConn := newConnPair(t)

	reqStr, err := clientConn.OpenStream()
	require.NoError(t, err)
	reqStr.CancelWrite(1337)

	handlerCalled := false
	srv := &Server{
		Handler: http.HandlerFunc(func(http.ResponseWriter, *http.Request) {
			handlerCalled = true
		}),
	}
	go srv.ServeQUICConn(serverConn)

	expectStreamReadReset(t, reqStr, quic.StreamErrorCode(ErrCodeRequestIncomplete))
	require.False(t, handlerCalled)
}
// TestServerPanickingHandler checks what ends up in the server's log when a
// handler panics: a regular panic is logged, http.ErrAbortHandler is not.
func TestServerPanickingHandler(t *testing.T) {
	const panicLogPrefix = "http3: panic serving"

	t.Run("panicking handler", func(t *testing.T) {
		handler := func(http.ResponseWriter, *http.Request) { panic("foobar") }
		logged := testServerPanickingHandler(t, handler)
		require.Contains(t, logged, panicLogPrefix)
		require.Contains(t, logged, "foobar")
	})
	t.Run("http.ErrAbortHandler", func(t *testing.T) {
		handler := func(http.ResponseWriter, *http.Request) { panic(http.ErrAbortHandler) }
		logged := testServerPanickingHandler(t, handler)
		require.NotContains(t, logged, panicLogPrefix)
		require.NotContains(t, logged, "http.ErrAbortHandler")
	})
}
// testServerPanickingHandler sends a HEAD request to a Server using handler,
// expects the request stream to be reset with ErrCodeInternalError, closes the
// server, and returns everything the server's logger wrote.
func testServerPanickingHandler(t *testing.T, handler http.HandlerFunc) (logOutput string) {
	clientConn, serverConn := newConnPair(t)

	reqStr, err := clientConn.OpenStream()
	require.NoError(t, err)
	req := httptest.NewRequest(http.MethodHead, "https://www.example.com", nil)
	_, err = reqStr.Write(encodeRequest(t, req))
	require.NoError(t, err)
	require.NoError(t, reqStr.Close())

	logs := new(bytes.Buffer)
	srv := &Server{
		Handler: handler,
		Logger:  slog.New(slog.NewTextHandler(logs, nil)),
	}
	go srv.ServeQUICConn(serverConn)

	expectStreamReadReset(t, reqStr, quic.StreamErrorCode(ErrCodeInternalError))
	// the server is closed before the log buffer is read
	srv.Close()
	logOutput = logs.String()
	return logOutput
}
// TestServerRequestHeaderTooLarge runs testServerRequestHeaderTooLarge with
// the default header size limit (MaxHeaderBytes set to 0), and with a limit of
// 20 bytes.
func TestServerRequestHeaderTooLarge(t *testing.T) {
	t.Run("default value", func(t *testing.T) {
		// The host name is 2*DefaultMaxHeaderBytes long. qpack will compress the
		// request, but it will still end up larger than DefaultMaxHeaderBytes.
		longHost := string(bytes.Repeat([]byte{'a'}, http.DefaultMaxHeaderBytes*2))
		req := httptest.NewRequest(http.MethodGet, "https://"+longHost, nil)
		testServerRequestHeaderTooLarge(t, req, 0)
	})
	t.Run("custom value", func(t *testing.T) {
		req := httptest.NewRequest(http.MethodGet, "https://www.example.com", nil)
		testServerRequestHeaderTooLarge(t, req, 20)
	})
}
// testServerRequestHeaderTooLarge sends req to a Server configured with
// maxHeaderBytes. It expects both directions of the request stream to be reset
// with ErrCodeFrameError, and the handler not to be called.
func testServerRequestHeaderTooLarge(t *testing.T, req *http.Request, maxHeaderBytes int) {
	handlerCalled := false
	srv := &Server{
		MaxHeaderBytes: maxHeaderBytes,
		Handler: http.HandlerFunc(func(http.ResponseWriter, *http.Request) {
			handlerCalled = true
		}),
	}
	srv.init()

	clientConn, serverConn := newConnPair(t)
	reqStr, err := clientConn.OpenStream()
	require.NoError(t, err)
	_, err = reqStr.Write(encodeRequest(t, req))
	require.NoError(t, err)
	require.NoError(t, reqStr.Close())

	go srv.ServeQUICConn(serverConn)

	resetCode := quic.StreamErrorCode(ErrCodeFrameError)
	expectStreamReadReset(t, reqStr, resetCode)
	expectStreamWriteReset(t, reqStr, resetCode)
	require.False(t, handlerCalled)
}
// TestServerRequestContext inspects the context of a request while its handler
// is still running: it checks the values stored under ServerContextKey,
// http.LocalAddrContextKey and RemoteAddrContextKey, and that the context is
// only canceled once the client calls CancelRead on the request stream.
func TestServerRequestContext(t *testing.T) {
	clientConn, serverConn := newConnPair(t)
	str, err := clientConn.OpenStream()
	require.NoError(t, err)
	_, err = str.Write(encodeRequest(t, httptest.NewRequest(http.MethodHead, "https://www.example.com", nil)))
	require.NoError(t, err)

	ctxChan := make(chan context.Context, 1)
	block := make(chan struct{})
	// Unblock the handler when the test returns, including when an assertion
	// fails. Otherwise the handler goroutine would stay blocked forever.
	defer close(block)
	s := &Server{
		Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			ctxChan <- r.Context()
			<-block
		}),
	}
	go s.ServeQUICConn(serverConn)

	var requestContext context.Context
	select {
	case requestContext = <-ctxChan:
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
	require.Equal(t, s, requestContext.Value(ServerContextKey))
	require.Equal(t, serverConn.LocalAddr(), requestContext.Value(http.LocalAddrContextKey))
	require.Equal(t, serverConn.RemoteAddr(), requestContext.Value(RemoteAddrContextKey))

	// the context must still be alive while the handler is running
	select {
	case <-requestContext.Done():
		t.Fatal("request context was canceled")
	case <-time.After(scaleDuration(10 * time.Millisecond)):
	}

	str.CancelRead(1337)
	select {
	case <-requestContext.Done():
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
	require.Equal(t, context.Canceled, requestContext.Err())
}
// TestServerHTTPStreamHijacking uses a handler that obtains the HTTP stream via
// the HTTPStreamer interface, writes to it and closes it. The client expects a
// response with status 200 and a single DATA frame carrying the written bytes.
func TestServerHTTPStreamHijacking(t *testing.T) {
	clientConn, serverConn := newConnPair(t)
	str, err := clientConn.OpenStream()
	require.NoError(t, err)
	_, err = str.Write(encodeRequest(t, httptest.NewRequest(http.MethodHead, "https://www.example.com", nil)))
	require.NoError(t, err)
	require.NoError(t, str.Close())

	s := &Server{
		Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			str := w.(HTTPStreamer).HTTPStream()
			str.Write([]byte("foobar"))
			str.Close()
		}),
	}
	go s.ServeQUICConn(serverConn)

	// bound the read below, so that a missing response fails the test instead of hanging
	require.NoError(t, str.SetReadDeadline(time.Now().Add(time.Second)))
	rsp, err := io.ReadAll(str)
	require.NoError(t, err)

	r := bytes.NewReader(rsp)
	hfs := decodeHeader(t, r)
	require.Equal(t, []string{"200"}, hfs[":status"])

	fp := frameParser{r: r}
	frame, err := fp.ParseNext()
	require.NoError(t, err)
	require.IsType(t, &dataFrame{}, frame)
	df := frame.(*dataFrame)
	require.Equal(t, uint64(6), df.Length)
	// everything left in the reader is the payload of the DATA frame
	data, err := io.ReadAll(r)
	require.NoError(t, err)
	require.Equal(t, []byte("foobar"), data)
}
// TestServerStreamHijacking runs the hijacking scenarios (hijack, don't hijack,
// hijacker error) for bidirectional and for unidirectional streams.
func TestServerStreamHijacking(t *testing.T) {
	for _, tc := range []struct {
		name          string
		bidirectional bool
	}{
		{name: "bidirectional", bidirectional: true},
		{name: "unidirectional", bidirectional: false},
	} {
		t.Run(tc.name, func(t *testing.T) {
			t.Run("hijack", func(t *testing.T) {
				testServerHijackBidirectionalStream(t, tc.bidirectional, true, nil)
			})
			t.Run("don't hijack", func(t *testing.T) {
				testServerHijackBidirectionalStream(t, tc.bidirectional, false, nil)
			})
			t.Run("hijacker error", func(t *testing.T) {
				testServerHijackBidirectionalStream(t, tc.bidirectional, false, assert.AnError)
			})
		})
	}
}
// testServerHijackBidirectionalStream opens a stream that starts with the
// varint 0x41 and checks how the Server's stream hijackers are invoked.
// Despite its name, it covers both stream kinds: when bidirectional is true, a
// bidirectional stream is opened, otherwise a unidirectional one.
// doHijack is the value the hijacker callbacks return, hijackErr is the error
// returned from the StreamHijacker.
func testServerHijackBidirectionalStream(t *testing.T, bidirectional bool, doHijack bool, hijackErr error) {
	type hijackCall struct {
		ft            FrameType  // for bidirectional streams
		st            StreamType // for unidirectional streams
		connTracingID quic.ConnectionTracingID
		e             error
	}

	hijackChan := make(chan hijackCall, 1)
	// testDone is closed by whichever hijacker runs. Nothing in this function waits for it.
	testDone := make(chan struct{})
	s := &Server{
		StreamHijacker: func(ft FrameType, connTracingID quic.ConnectionTracingID, _ quic.Stream, e error) (hijacked bool, err error) {
			defer close(testDone)
			hijackChan <- hijackCall{ft: ft, connTracingID: connTracingID, e: e}
			return doHijack, hijackErr
		},
		UniStreamHijacker: func(st StreamType, connTracingID quic.ConnectionTracingID, _ quic.ReceiveStream, err error) bool {
			defer close(testDone)
			hijackChan <- hijackCall{st: st, connTracingID: connTracingID, e: err}
			return doHijack
		},
	}

	clientConn, serverConn := newConnPair(t)
	go s.ServeQUICConn(serverConn)

	buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41))
	if bidirectional {
		str, err := clientConn.OpenStream()
		require.NoError(t, err)
		_, err = str.Write(buf.Bytes())
		require.NoError(t, err)

		if hijackErr != nil {
			expectStreamReadReset(t, str, quic.StreamErrorCode(ErrCodeRequestIncomplete))
			expectStreamWriteReset(t, str, quic.StreamErrorCode(ErrCodeRequestIncomplete))
		}
		// if the stream is not hijacked, the frame parser will skip the frame
	} else {
		str, err := clientConn.OpenUniStream()
		require.NoError(t, err)
		_, err = str.Write(buf.Bytes())
		require.NoError(t, err)

		if !doHijack || hijackErr != nil {
			expectStreamWriteReset(t, str, quic.StreamErrorCode(ErrCodeStreamCreationError))
		}
	}

	select {
	case call := <-hijackChan:
		// only the field matching the stream kind is set by the hijacker
		if bidirectional {
			assert.Zero(t, call.st)
			assert.Equal(t, FrameType(0x41), call.ft)
		} else {
			assert.Equal(t, StreamType(0x41), call.st)
			assert.Zero(t, call.ft)
		}
		assert.Equal(t, serverConn.Context().Value(quic.ConnectionTracingKey), call.connTracingID)
		assert.NoError(t, call.e)
	case <-time.After(time.Second):
		t.Fatal("hijack call not received")
	}

	time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
}
// getAltSvc lets the server set its QUIC headers on an empty header map.
// It returns the first Alt-Svc value and true if the header was set,
// and an empty string and false otherwise.
func getAltSvc(s *Server) (string, bool) {
	headers := make(http.Header)
	s.SetQUICHeaders(headers)
	values, found := headers["Alt-Svc"]
	if !found {
		return "", false
	}
	return values[0], true
}
func TestServerAltSvcFromListenersAndConns(t *testing.T) {
t.Run("default", func(t *testing.T) {
testServerAltSvcFromListenersAndConns(t, []quic.Version{})
})
t.Run("v1", func(t *testing.T) {
testServerAltSvcFromListenersAndConns(t, []quic.Version{quic.Version1})
})
t.Run("v1 and v2", func(t *testing.T) {
testServerAltSvcFromListenersAndConns(t, []quic.Version{quic.Version1, quic.Version2})
})
}
// testServerAltSvcFromListenersAndConns serves a single Server on a QUIC
// listener and on a UDP connection. It checks that the Alt-Svc header lists
// the port of everything the server is currently serving on, and that entries
// disappear once the listener and the connection are closed.
func testServerAltSvcFromListenersAndConns(t *testing.T, versions []quic.Version) {
	// requireAltSvc asserts that srv announces exactly the value want.
	requireAltSvc := func(srv *Server, want string) {
		got, announced := getAltSvc(srv)
		require.True(t, announced)
		require.Equal(t, want, got)
	}
	// awaitClosed fails the test if ch isn't closed within a second.
	awaitClosed := func(ch <-chan struct{}) {
		select {
		case <-ch:
		case <-time.After(time.Second):
			t.Fatal("timeout")
		}
	}

	ln1, err := quic.ListenEarly(newUDPConnLocalhost(t), getTLSConfig(), nil)
	require.NoError(t, err)
	port1 := ln1.Addr().(*net.UDPAddr).Port

	srv := &Server{
		Addr:       ":1337", // will be ignored since we're using listeners
		TLSConfig:  getTLSConfig(),
		QUICConfig: &quic.Config{Versions: versions},
	}
	listenerDone := make(chan struct{})
	go func() {
		defer close(listenerDone)
		srv.ServeListener(ln1)
	}()
	time.Sleep(scaleDuration(10 * time.Millisecond))
	requireAltSvc(srv, fmt.Sprintf(`h3=":%d"; ma=2592000`, port1))

	udpConn := newUDPConnLocalhost(t)
	port2 := udpConn.LocalAddr().(*net.UDPAddr).Port
	connDone := make(chan struct{})
	go func() {
		defer close(connDone)
		srv.Serve(udpConn)
	}()
	time.Sleep(scaleDuration(10 * time.Millisecond))
	requireAltSvc(srv, fmt.Sprintf(`h3=":%d"; ma=2592000,h3=":%d"; ma=2592000`, port1, port2))

	// Close the first listener.
	// This should remove the associated Alt-Svc entry.
	require.NoError(t, ln1.Close())
	awaitClosed(listenerDone)
	requireAltSvc(srv, fmt.Sprintf(`h3=":%d"; ma=2592000`, port2))

	// Close the second listener.
	// This should remove the Alt-Svc entry altogether.
	require.NoError(t, udpConn.Close())
	awaitClosed(connDone)
	_, announced := getAltSvc(srv)
	require.False(t, announced)
}
// TestServerAltSvcFromPort checks that a Server with Port set announces that
// port in the Alt-Svc header while it is serving a listener, and announces
// nothing before serving starts and after the listener is closed.
func TestServerAltSvcFromPort(t *testing.T) {
	srv := &Server{Port: 1337}
	_, announced := getAltSvc(srv)
	require.False(t, announced)

	ln, err := quic.ListenEarly(newUDPConnLocalhost(t), getTLSConfig(), nil)
	require.NoError(t, err)
	served := make(chan struct{})
	go func() {
		defer close(served)
		srv.ServeListener(ln)
	}()
	time.Sleep(scaleDuration(10 * time.Millisecond))

	value, announced := getAltSvc(srv)
	require.True(t, announced)
	require.Equal(t, `h3=":1337"; ma=2592000`, value)

	require.NoError(t, ln.Close())
	select {
	case <-served:
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
	_, announced = getAltSvc(srv)
	require.False(t, announced)
}
// unixSocketListener embeds a quic.EarlyListener, but reports a Unix socket
// address from Addr.
type unixSocketListener struct {
	*quic.EarlyListener
}

// Addr returns a fixed Unix socket address, hiding the Addr method of the
// embedded listener.
func (*unixSocketListener) Addr() net.Addr {
	addr := net.UnixAddr{Name: "/tmp/quic.sock", Net: "unix"}
	return &addr
}
// TestServerAltSvcFromUnixSocket checks the Alt-Svc header of a server whose
// listener reports a Unix socket address: nothing is announced unless
// Server.Addr is set, in which case the port from Server.Addr is announced.
func TestServerAltSvcFromUnixSocket(t *testing.T) {
	t.Run("with Server.Addr not set", func(t *testing.T) {
		_, announced := testServerAltSvcFromUnixSocket(t, "")
		require.False(t, announced)
	})
	t.Run("with Server.Addr set", func(t *testing.T) {
		const expected = `h3=":1337"; ma=2592000`
		value, announced := testServerAltSvcFromUnixSocket(t, ":1337")
		require.True(t, announced)
		require.Equal(t, expected, value)
	})
}
// testServerAltSvcFromUnixSocket serves a Server with the given Addr on a
// unixSocketListener and returns the result of getAltSvc taken while the
// server was running. It also requires that the server logged that it was
// unable to extract the port from the listener.
func testServerAltSvcFromUnixSocket(t *testing.T, addr string) (altSvc string, ok bool) {
	ln, err := quic.ListenEarly(newUDPConnLocalhost(t), testdata.GetTLSConfig(), nil)
	require.NoError(t, err)

	logs := new(bytes.Buffer)
	srv := &Server{
		Addr:   addr,
		Logger: slog.New(slog.NewTextHandler(logs, nil)),
	}
	served := make(chan struct{})
	go func() {
		defer close(served)
		srv.ServeListener(&unixSocketListener{EarlyListener: ln})
	}()
	time.Sleep(scaleDuration(10 * time.Millisecond))

	altSvc, ok = getAltSvc(srv)

	require.NoError(t, ln.Close())
	select {
	case <-served:
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
	// the log buffer is only read after ServeListener has returned
	require.Contains(t, logs.String(), "Unable to extract port from listener, will not be announced using SetQUICHeaders")
	return altSvc, ok
}
func TestServerListenAndServeErrors(t *testing.T) {
require.EqualError(t, (&Server{}).ListenAndServe(), "use of http3.Server without TLSConfig")
s := &Server{
Addr: ":123456",
TLSConfig: testdata.GetTLSConfig(),
}
require.ErrorContains(t, s.ListenAndServe(), "invalid port")
}
// TestServerClosing checks that a server can be closed more than once, and
// that every way of serving returns http.ErrServerClosed afterwards.
func TestServerClosing(t *testing.T) {
	srv := &Server{TLSConfig: getTLSConfig()}
	// duplicate calls are ok
	for range 2 {
		require.NoError(t, srv.Close())
	}

	err := srv.ListenAndServe()
	require.ErrorIs(t, err, http.ErrServerClosed)
	err = srv.ListenAndServeTLS(testdata.GetCertificatePaths())
	require.ErrorIs(t, err, http.ErrServerClosed)
	err = srv.Serve(nil)
	require.ErrorIs(t, err, http.ErrServerClosed)
	err = srv.ServeListener(nil)
	require.ErrorIs(t, err, http.ErrServerClosed)
	err = srv.ServeQUICConn(nil)
	require.ErrorIs(t, err, http.ErrServerClosed)
}
// TestServerConcurrentServeAndClose starts serving on a UDP connection in a
// goroutine and closes the server right away. It requires Serve to return
// within a second.
func TestServerConcurrentServeAndClose(t *testing.T) {
	addr, err := net.ResolveUDPAddr("udp", "localhost:0")
	require.NoError(t, err)
	c, err := net.ListenUDP("udp", addr)
	require.NoError(t, err)
	// Make sure the UDP socket is released when the test ends, whether or not
	// the server closes it. The error of a redundant close is irrelevant here.
	t.Cleanup(func() { c.Close() })

	done := make(chan struct{})
	s := &Server{TLSConfig: testdata.GetTLSConfig()}
	go func() {
		defer close(done)
		s.Serve(c)
	}()
	// give the Serve goroutine a chance to run before closing the server
	runtime.Gosched()
	s.Close()

	select {
	case <-done:
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
}
// TestServerImmediateGracefulShutdown calls Shutdown on a server that never
// served anything and requires it to return without an error within a second.
func TestServerImmediateGracefulShutdown(t *testing.T) {
	srv := &Server{TLSConfig: testdata.GetTLSConfig()}

	result := make(chan error, 1)
	go func() {
		result <- srv.Shutdown(context.Background())
	}()

	timeout := time.NewTimer(time.Second)
	defer timeout.Stop()
	select {
	case err := <-result:
		require.NoError(t, err)
	case <-timeout.C:
		t.Fatal("timeout")
	}
}
// TestServerGracefulShutdown calls Shutdown after a first request has reached
// the handler. It expects a GOAWAY frame with stream ID 4 on the server's
// control stream, Shutdown not to return right away, new request streams to be
// reset with ErrCodeRequestRejected, and Shutdown to return context.Canceled
// once its context is canceled.
func TestServerGracefulShutdown(t *testing.T) {
	requestChan := make(chan struct{}, 1)
	s := &Server{Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		requestChan <- struct{}{}
	})}

	clientConn, serverConn := newConnPair(t)
	go s.ServeQUICConn(serverConn)

	firstStream, err := clientConn.OpenStream()
	require.NoError(t, err)
	_, err = firstStream.Write(encodeRequest(t, httptest.NewRequest(http.MethodGet, "https://www.example.com", nil)))
	require.NoError(t, err)

	select {
	case <-requestChan:
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}

	// read the stream type and the SETTINGS frame from the control stream
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	controlStr, err := clientConn.AcceptUniStream(ctx)
	require.NoError(t, err)
	typ, err := quicvarint.Read(quicvarint.NewReader(controlStr))
	require.NoError(t, err)
	require.EqualValues(t, streamTypeControlStream, typ)
	fp := &frameParser{r: controlStr}
	f, err := fp.ParseNext()
	require.NoError(t, err)
	require.IsType(t, &settingsFrame{}, f)

	shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
	// Cancel on every exit path, so that Shutdown is released if an assertion fails.
	defer shutdownCancel()
	// Buffered, so that the goroutine can always deliver its result and exit,
	// even if the test has already failed and nobody receives anymore.
	errChan := make(chan error, 1)
	go func() {
		errChan <- s.Shutdown(shutdownCtx)
	}()

	f, err = fp.ParseNext()
	require.NoError(t, err)
	require.Equal(t, &goAwayFrame{StreamID: 4}, f)

	select {
	case <-errChan:
		t.Fatal("didn't expect Shutdown to return")
	case <-time.After(scaleDuration(10 * time.Millisecond)):
	}

	// all further streams are getting rejected
	for range 3 {
		str, err := clientConn.OpenStream()
		require.NoError(t, err)
		// the result of the write is not checked
		_, _ = str.Write(encodeRequest(t, httptest.NewRequest(http.MethodGet, "https://www.example.com", nil)))
		expectStreamReadReset(t, str, quic.StreamErrorCode(ErrCodeRequestRejected))
		expectStreamWriteReset(t, str, quic.StreamErrorCode(ErrCodeRequestRejected))
	}

	// cancel the context passed to Shutdown
	shutdownCancel()
	select {
	case err := <-errChan:
		require.ErrorIs(t, err, context.Canceled)
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
}