http3: improve the client API (#4693)

* http3: rename RoundTripper to Transport

* http3: rename SingleDestinationRoundTripper to ClientConn

* http3: construct the ClientConn via Transport.NewClientConn
This commit is contained in:
Marten Seemann
2024-10-14 00:17:50 -05:00
committed by GitHub
parent eaa879f32f
commit 1db805ce4f
8 changed files with 389 additions and 324 deletions

View File

@@ -40,7 +40,7 @@ func main() {
}
testdata.AddRootCA(pool)
roundTripper := &http3.RoundTripper{
roundTripper := &http3.Transport{
TLSClientConfig: &tls.Config{
RootCAs: pool,
InsecureSkipVerify: *insecure,

View File

@@ -9,7 +9,6 @@ import (
"net/http"
"net/http/httptrace"
"net/textproto"
"sync"
"time"
"github.com/quic-go/quic-go"
@@ -38,102 +37,117 @@ var defaultQuicConfig = &quic.Config{
KeepAlivePeriod: 10 * time.Second,
}
// SingleDestinationRoundTripper is an HTTP/3 client doing requests to a single remote server.
type SingleDestinationRoundTripper struct {
Connection quic.Connection
// ClientConn is an HTTP/3 client doing requests to a single remote server.
type ClientConn struct {
connection
// Enable support for HTTP/3 datagrams (RFC 9297).
// If a QUICConfig is set, datagram support also needs to be enabled on the QUIC layer by setting EnableDatagrams.
EnableDatagrams bool
// If a QUICConfig is set, datagram support also needs to be enabled on the QUIC layer by setting enableDatagrams.
enableDatagrams bool
// Additional HTTP/3 settings.
// It is invalid to specify any settings defined by RFC 9114 (HTTP/3) and RFC 9297 (HTTP Datagrams).
AdditionalSettings map[uint64]uint64
StreamHijacker func(FrameType, quic.ConnectionTracingID, quic.Stream, error) (hijacked bool, err error)
UniStreamHijacker func(StreamType, quic.ConnectionTracingID, quic.ReceiveStream, error) (hijacked bool)
additionalSettings map[uint64]uint64
// MaxResponseHeaderBytes specifies a limit on how many response bytes are
// maxResponseHeaderBytes specifies a limit on how many response bytes are
// allowed in the server's response header.
// Zero means to use a default limit.
MaxResponseHeaderBytes int64
maxResponseHeaderBytes uint64
// DisableCompression, if true, prevents the Transport from requesting compression with an
// disableCompression, if true, prevents the Transport from requesting compression with an
// "Accept-Encoding: gzip" request header when the Request contains no existing Accept-Encoding value.
// If the Transport requests gzip on its own and gets a gzipped response, it's transparently
// decoded in the Response.Body.
// However, if the user explicitly requested gzip it is not automatically uncompressed.
DisableCompression bool
disableCompression bool
Logger *slog.Logger
logger *slog.Logger
initOnce sync.Once
hconn *connection
requestWriter *requestWriter
decoder *qpack.Decoder
}
var _ http.RoundTripper = &SingleDestinationRoundTripper{}
var _ http.RoundTripper = &ClientConn{}
func (c *SingleDestinationRoundTripper) Start() Connection {
c.initOnce.Do(func() { c.init() })
return c.hconn
}
func (c *SingleDestinationRoundTripper) init() {
func newClientConn(
conn quic.Connection,
enableDatagrams bool,
additionalSettings map[uint64]uint64,
streamHijacker func(FrameType, quic.ConnectionTracingID, quic.Stream, error) (hijacked bool, err error),
uniStreamHijacker func(StreamType, quic.ConnectionTracingID, quic.ReceiveStream, error) (hijacked bool),
maxResponseHeaderBytes int64,
disableCompression bool,
logger *slog.Logger,
) *ClientConn {
c := &ClientConn{
enableDatagrams: enableDatagrams,
additionalSettings: additionalSettings,
disableCompression: disableCompression,
logger: logger,
}
if maxResponseHeaderBytes <= 0 {
c.maxResponseHeaderBytes = defaultMaxResponseHeaderBytes
} else {
c.maxResponseHeaderBytes = uint64(maxResponseHeaderBytes)
}
c.decoder = qpack.NewDecoder(func(hf qpack.HeaderField) {})
c.requestWriter = newRequestWriter()
c.hconn = newConnection(
c.Connection.Context(),
c.Connection,
c.EnableDatagrams,
c.connection = *newConnection(
conn.Context(),
conn,
c.enableDatagrams,
protocol.PerspectiveClient,
c.Logger,
c.logger,
0,
)
// send the SETTINGs frame, using 0-RTT data, if possible
go func() {
if err := c.setupConn(c.hconn); err != nil {
if c.Logger != nil {
c.Logger.Debug("Setting up connection failed", "error", err)
if err := c.setupConn(); err != nil {
if c.logger != nil {
c.logger.Debug("Setting up connection failed", "error", err)
}
c.hconn.CloseWithError(quic.ApplicationErrorCode(ErrCodeInternalError), "")
c.connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeInternalError), "")
}
}()
if c.StreamHijacker != nil {
go c.handleBidirectionalStreams()
if streamHijacker != nil {
go c.handleBidirectionalStreams(streamHijacker)
}
go c.hconn.HandleUnidirectionalStreams(c.UniStreamHijacker)
go c.connection.HandleUnidirectionalStreams(uniStreamHijacker)
return c
}
func (c *SingleDestinationRoundTripper) setupConn(conn *connection) error {
func (c *ClientConn) OpenRequestStream(ctx context.Context) (RequestStream, error) {
return c.connection.openRequestStream(ctx, c.requestWriter, nil, c.disableCompression, c.maxResponseHeaderBytes)
}
func (c *ClientConn) setupConn() error {
// open the control stream
str, err := conn.OpenUniStream()
str, err := c.connection.OpenUniStream()
if err != nil {
return err
}
b := make([]byte, 0, 64)
b = quicvarint.Append(b, streamTypeControlStream)
// send the SETTINGS frame
b = (&settingsFrame{Datagram: c.EnableDatagrams, Other: c.AdditionalSettings}).Append(b)
b = (&settingsFrame{Datagram: c.enableDatagrams, Other: c.additionalSettings}).Append(b)
_, err = str.Write(b)
return err
}
func (c *SingleDestinationRoundTripper) handleBidirectionalStreams() {
func (c *ClientConn) handleBidirectionalStreams(streamHijacker func(FrameType, quic.ConnectionTracingID, quic.Stream, error) (hijacked bool, err error)) {
for {
str, err := c.hconn.AcceptStream(context.Background())
str, err := c.connection.AcceptStream(context.Background())
if err != nil {
if c.Logger != nil {
c.Logger.Debug("accepting bidirectional stream failed", "error", err)
if c.logger != nil {
c.logger.Debug("accepting bidirectional stream failed", "error", err)
}
return
}
fp := &frameParser{
r: str,
conn: c.hconn,
conn: &c.connection,
unknownFrameHandler: func(ft FrameType, e error) (processed bool, err error) {
id := c.hconn.Context().Value(quic.ConnectionTracingKey).(quic.ConnectionTracingID)
return c.StreamHijacker(ft, id, str, e)
id := c.connection.Context().Value(quic.ConnectionTracingKey).(quic.ConnectionTracingID)
return streamHijacker(ft, id, str, e)
},
}
go func() {
@@ -141,26 +155,17 @@ func (c *SingleDestinationRoundTripper) handleBidirectionalStreams() {
return
}
if err != nil {
if c.Logger != nil {
c.Logger.Debug("error handling stream", "error", err)
if c.logger != nil {
c.logger.Debug("error handling stream", "error", err)
}
}
c.hconn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "received HTTP/3 frame on bidirectional stream")
c.connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "received HTTP/3 frame on bidirectional stream")
}()
}
}
func (c *SingleDestinationRoundTripper) maxHeaderBytes() uint64 {
if c.MaxResponseHeaderBytes <= 0 {
return defaultMaxResponseHeaderBytes
}
return uint64(c.MaxResponseHeaderBytes)
}
// RoundTrip executes a request and returns a response
func (c *SingleDestinationRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
c.initOnce.Do(func() { c.init() })
func (c *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
rsp, err := c.roundTrip(req)
if err != nil && req.Context().Err() != nil {
// if the context was canceled, return the context cancellation error
@@ -169,7 +174,7 @@ func (c *SingleDestinationRoundTripper) RoundTrip(req *http.Request) (*http.Resp
return rsp, err
}
func (c *SingleDestinationRoundTripper) roundTrip(req *http.Request) (*http.Response, error) {
func (c *ClientConn) roundTrip(req *http.Request) (*http.Response, error) {
// Immediately send out this request, if this is a 0-RTT request.
switch req.Method {
case MethodGet0RTT:
@@ -200,17 +205,23 @@ func (c *SingleDestinationRoundTripper) roundTrip(req *http.Request) (*http.Resp
connCtx := c.Connection.Context()
// wait for the server's SETTINGS frame to arrive
select {
case <-c.hconn.ReceivedSettings():
case <-c.connection.ReceivedSettings():
case <-connCtx.Done():
return nil, context.Cause(connCtx)
}
if !c.hconn.Settings().EnableExtendedConnect {
if !c.connection.Settings().EnableExtendedConnect {
return nil, errors.New("http3: server didn't enable Extended CONNECT")
}
}
reqDone := make(chan struct{})
str, err := c.hconn.openRequestStream(req.Context(), c.requestWriter, reqDone, c.DisableCompression, c.maxHeaderBytes())
str, err := c.connection.openRequestStream(
req.Context(),
c.requestWriter,
reqDone,
c.disableCompression,
c.maxResponseHeaderBytes,
)
if err != nil {
return nil, err
}
@@ -238,12 +249,6 @@ func (c *SingleDestinationRoundTripper) roundTrip(req *http.Request) (*http.Resp
return rsp, maybeReplaceError(err)
}
func (c *SingleDestinationRoundTripper) OpenRequestStream(ctx context.Context) (RequestStream, error) {
c.initOnce.Do(func() { c.init() })
return c.hconn.openRequestStream(ctx, c.requestWriter, nil, c.DisableCompression, c.maxHeaderBytes())
}
// cancelingReader reads from the io.Reader.
// It cancels writing on the stream if any error other than io.EOF occurs.
type cancelingReader struct {
@@ -259,7 +264,7 @@ func (r *cancelingReader) Read(b []byte) (int, error) {
return n, err
}
func (c *SingleDestinationRoundTripper) sendRequestBody(str Stream, body io.ReadCloser, contentLength int64) error {
func (c *ClientConn) sendRequestBody(str Stream, body io.ReadCloser, contentLength int64) error {
defer body.Close()
buf := make([]byte, bodyCopyBufferSize)
sr := &cancelingReader{str: str, r: body}
@@ -283,7 +288,7 @@ func (c *SingleDestinationRoundTripper) sendRequestBody(str Stream, body io.Read
return err
}
func (c *SingleDestinationRoundTripper) doRequest(req *http.Request, str *requestStream) (*http.Response, error) {
func (c *ClientConn) doRequest(req *http.Request, str *requestStream) (*http.Response, error) {
if err := str.SendRequestHeader(req); err != nil {
return nil, err
}
@@ -299,8 +304,8 @@ func (c *SingleDestinationRoundTripper) doRequest(req *http.Request, str *reques
contentLength = req.ContentLength
}
if err := c.sendRequestBody(str, req.Body, contentLength); err != nil {
if c.Logger != nil {
c.Logger.Debug("error writing request", "error", err)
if c.logger != nil {
c.logger.Debug("error writing request", "error", err)
}
}
str.Close()
@@ -337,7 +342,7 @@ func (c *SingleDestinationRoundTripper) doRequest(req *http.Request, str *reques
}
break
}
connState := c.hconn.ConnectionState().TLS
connState := c.connection.ConnectionState().TLS
res.TLS = &connState
res.Request = req
return res, nil

View File

@@ -81,8 +81,7 @@ var _ = Describe("Client", func() {
It("hijacks a bidirectional stream of unknown frame type", func() {
id := quic.ConnectionTracingID(1234)
frameTypeChan := make(chan FrameType, 1)
rt := &SingleDestinationRoundTripper{
Connection: conn,
tr := &Transport{
StreamHijacker: func(ft FrameType, connTracingID quic.ConnectionTracingID, _ quic.Stream, e error) (hijacked bool, err error) {
Expect(e).ToNot(HaveOccurred())
Expect(connTracingID).To(Equal(id))
@@ -101,7 +100,8 @@ var _ = Describe("Client", func() {
})
ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, id)
conn.EXPECT().Context().Return(ctx).AnyTimes()
_, err := rt.RoundTrip(request)
cc := tr.NewClientConn(conn)
_, err := cc.RoundTrip(request)
Expect(err).To(MatchError("done"))
Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41)))
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
@@ -109,8 +109,7 @@ var _ = Describe("Client", func() {
It("closes the connection when hijacker didn't hijack a bidirectional stream", func() {
frameTypeChan := make(chan FrameType, 1)
rt := &SingleDestinationRoundTripper{
Connection: conn,
tr := &Transport{
StreamHijacker: func(ft FrameType, _ quic.ConnectionTracingID, _ quic.Stream, e error) (hijacked bool, err error) {
Expect(e).ToNot(HaveOccurred())
frameTypeChan <- ft
@@ -129,15 +128,15 @@ var _ = Describe("Client", func() {
ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1234))
conn.EXPECT().Context().Return(ctx).AnyTimes()
conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any()).Return(nil).AnyTimes()
_, err := rt.RoundTrip(request)
cc := tr.NewClientConn(conn)
_, err := cc.RoundTrip(request)
Expect(err).To(MatchError("done"))
Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41)))
})
It("closes the connection when hijacker returned error", func() {
frameTypeChan := make(chan FrameType, 1)
rt := &SingleDestinationRoundTripper{
Connection: conn,
tr := &Transport{
StreamHijacker: func(ft FrameType, _ quic.ConnectionTracingID, _ quic.Stream, e error) (hijacked bool, err error) {
Expect(e).ToNot(HaveOccurred())
frameTypeChan <- ft
@@ -156,7 +155,8 @@ var _ = Describe("Client", func() {
ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1234))
conn.EXPECT().Context().Return(ctx).AnyTimes()
conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any()).Return(nil).AnyTimes()
_, err := rt.RoundTrip(request)
cc := tr.NewClientConn(conn)
_, err := cc.RoundTrip(request)
Expect(err).To(MatchError("done"))
Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41)))
})
@@ -165,8 +165,7 @@ var _ = Describe("Client", func() {
testErr := errors.New("test error")
unknownStr := mockquic.NewMockStream(mockCtrl)
done := make(chan struct{})
rt := &SingleDestinationRoundTripper{
Connection: conn,
tr := &Transport{
StreamHijacker: func(ft FrameType, _ quic.ConnectionTracingID, str quic.Stream, e error) (hijacked bool, err error) {
defer close(done)
Expect(e).To(MatchError(testErr))
@@ -185,7 +184,8 @@ var _ = Describe("Client", func() {
ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1234))
conn.EXPECT().Context().Return(ctx).AnyTimes()
conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any()).Return(nil).AnyTimes()
_, err := rt.RoundTrip(request)
cc := tr.NewClientConn(conn)
_, err := cc.RoundTrip(request)
Expect(err).To(MatchError("done"))
Eventually(done).Should(BeClosed())
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
@@ -226,8 +226,7 @@ var _ = Describe("Client", func() {
It("hijacks an unidirectional stream of unknown stream type", func() {
id := quic.ConnectionTracingID(100)
streamTypeChan := make(chan StreamType, 1)
rt := &SingleDestinationRoundTripper{
Connection: conn,
tr := &Transport{
UniStreamHijacker: func(st StreamType, connTracingID quic.ConnectionTracingID, _ quic.ReceiveStream, err error) bool {
Expect(connTracingID).To(Equal(id))
Expect(err).ToNot(HaveOccurred())
@@ -235,7 +234,6 @@ var _ = Describe("Client", func() {
return true
},
}
buf := bytes.NewBuffer(quicvarint.Append(nil, 0x54))
unknownStr := mockquic.NewMockStream(mockCtrl)
unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
@@ -248,7 +246,8 @@ var _ = Describe("Client", func() {
})
ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, id)
conn.EXPECT().Context().Return(ctx).AnyTimes()
_, err := rt.RoundTrip(req)
cc := tr.NewClientConn(conn)
_, err := cc.RoundTrip(req)
Expect(err).To(MatchError("done"))
Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54)))
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
@@ -258,8 +257,7 @@ var _ = Describe("Client", func() {
testErr := errors.New("test error")
done := make(chan struct{})
unknownStr := mockquic.NewMockStream(mockCtrl)
rt := &SingleDestinationRoundTripper{
Connection: conn,
tr := &Transport{
UniStreamHijacker: func(st StreamType, _ quic.ConnectionTracingID, str quic.ReceiveStream, err error) bool {
defer close(done)
Expect(st).To(BeZero())
@@ -268,7 +266,6 @@ var _ = Describe("Client", func() {
return true
},
}
unknownStr.EXPECT().Read(gomock.Any()).Return(0, testErr)
conn.EXPECT().AcceptUniStream(gomock.Any()).Return(unknownStr, nil)
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
@@ -277,7 +274,8 @@ var _ = Describe("Client", func() {
})
ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1234))
conn.EXPECT().Context().Return(ctx).AnyTimes()
_, err := rt.RoundTrip(req)
cc := tr.NewClientConn(conn)
_, err := cc.RoundTrip(req)
Expect(err).To(MatchError("done"))
Eventually(done).Should(BeClosed())
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
@@ -285,15 +283,13 @@ var _ = Describe("Client", func() {
It("cancels reading when hijacker didn't hijack an unidirectional stream", func() {
streamTypeChan := make(chan StreamType, 1)
rt := &SingleDestinationRoundTripper{
Connection: conn,
tr := &Transport{
UniStreamHijacker: func(st StreamType, _ quic.ConnectionTracingID, _ quic.ReceiveStream, err error) bool {
Expect(err).ToNot(HaveOccurred())
streamTypeChan <- st
return false
},
}
buf := bytes.NewBuffer(quicvarint.Append(nil, 0x54))
unknownStr := mockquic.NewMockStream(mockCtrl)
unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
@@ -307,7 +303,8 @@ var _ = Describe("Client", func() {
})
ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1234))
conn.EXPECT().Context().Return(ctx).AnyTimes()
_, err := rt.RoundTrip(req)
cc := tr.NewClientConn(conn)
_, err := cc.RoundTrip(req)
Expect(err).To(MatchError("done"))
Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54)))
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
@@ -337,13 +334,13 @@ var _ = Describe("Client", func() {
return nil, errors.New("test done")
}).AnyTimes()
conn.EXPECT().HandshakeComplete().Return(handshakeChan)
rt := &SingleDestinationRoundTripper{
Connection: conn,
tr := &Transport{
EnableDatagrams: true,
}
cc := tr.NewClientConn(conn)
req, err := http.NewRequest(http.MethodGet, "https://quic-go.net", nil)
Expect(err).ToNot(HaveOccurred())
_, err = rt.RoundTrip(req)
_, err = cc.RoundTrip(req)
Expect(err).To(MatchError("test done"))
t, err := quicvarint.Read(&buf)
Expect(err).ToNot(HaveOccurred())
@@ -373,10 +370,10 @@ var _ = Describe("Client", func() {
return nil, errors.New("test done")
})
rt := &SingleDestinationRoundTripper{Connection: conn}
hconn := rt.Start()
Eventually(hconn.ReceivedSettings()).Should(BeClosed())
settings := hconn.Settings()
tr := &Transport{}
cc := tr.NewClientConn(conn)
Eventually(cc.ReceivedSettings()).Should(BeClosed())
settings := cc.Settings()
Expect(settings.EnableExtendedConnect).To(BeTrue())
// test shutdown
conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).MaxTimes(1)
@@ -410,8 +407,9 @@ var _ = Describe("Client", func() {
conn.EXPECT().Context().Return(context.Background())
conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("test error"))
rt := &SingleDestinationRoundTripper{Connection: conn}
_, err := rt.RoundTrip(&http.Request{
tr := &Transport{}
cc := tr.NewClientConn(conn)
_, err := cc.RoundTrip(&http.Request{
Method: http.MethodConnect,
Proto: "connect",
Host: "localhost",
@@ -450,8 +448,9 @@ var _ = Describe("Client", func() {
conn.EXPECT().HandshakeComplete().Return(handshakeChan)
conn.EXPECT().Context().Return(context.Background())
rt := &SingleDestinationRoundTripper{Connection: conn}
_, err := rt.RoundTrip(&http.Request{
tr := &Transport{}
cc := tr.NewClientConn(conn)
_, err := cc.RoundTrip(&http.Request{
Method: http.MethodConnect,
Proto: "connect",
Host: "localhost",
@@ -470,7 +469,6 @@ var _ = Describe("Client", func() {
req *http.Request
str *mockquic.MockStream
conn *mockquic.MockEarlyConnection
cl *SingleDestinationRoundTripper
settingsFrameWritten chan struct{}
)
testDone := make(chan struct{})
@@ -517,7 +515,6 @@ var _ = Describe("Client", func() {
<-testDone
return nil, errors.New("test done")
})
cl = &SingleDestinationRoundTripper{Connection: conn}
var err error
req, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil)
Expect(err).ToNot(HaveOccurred())
@@ -533,7 +530,9 @@ var _ = Describe("Client", func() {
conn.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr)
conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).MaxTimes(1)
conn.EXPECT().HandshakeComplete().Return(handshakeChan)
_, err := cl.RoundTrip(req)
tr := &Transport{}
cc := tr.NewClientConn(conn)
_, err := cc.RoundTrip(req)
Expect(err).To(MatchError(testErr))
})
@@ -552,7 +551,9 @@ var _ = Describe("Client", func() {
str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) {
return 0, testErr
})
_, err := cl.RoundTrip(req)
tr := &Transport{}
cc := tr.NewClientConn(conn)
_, err := cc.RoundTrip(req)
Expect(err).To(MatchError(testErr))
Expect(decodeHeader(buf)).To(HaveKeyWithValue(":method", serialized))
// make sure the request wasn't modified
@@ -572,7 +573,9 @@ var _ = Describe("Client", func() {
str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
str.EXPECT().Close()
str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes()
rsp, err := cl.RoundTrip(req)
tr := &Transport{}
cc := tr.NewClientConn(conn)
rsp, err := cc.RoundTrip(req)
Expect(err).ToNot(HaveOccurred())
Expect(rsp.Proto).To(Equal("HTTP/3.0"))
Expect(rsp.ProtoMajor).To(Equal(3))
@@ -599,7 +602,9 @@ var _ = Describe("Client", func() {
str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
str.EXPECT().Close()
str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes()
rsp, err := cl.RoundTrip(req)
tr := &Transport{}
cc := tr.NewClientConn(conn)
rsp, err := cc.RoundTrip(req)
Expect(err).ToNot(HaveOccurred())
_, err = io.ReadAll(rsp.Body)
Expect(err).ToNot(HaveOccurred())
@@ -625,7 +630,9 @@ var _ = Describe("Client", func() {
str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
str.EXPECT().Close()
str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes()
rsp, err := cl.RoundTrip(req)
tr := &Transport{}
cc := tr.NewClientConn(conn)
rsp, err := cc.RoundTrip(req)
Expect(err).ToNot(HaveOccurred())
_, err = io.ReadAll(rsp.Body)
Expect(err).To(HaveOccurred())
@@ -662,7 +669,9 @@ var _ = Describe("Client", func() {
str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
str.EXPECT().Close()
str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes()
rsp, err := cl.RoundTrip(req)
tr := &Transport{}
cc := tr.NewClientConn(conn)
rsp, err := cc.RoundTrip(req)
Expect(err).ToNot(HaveOccurred())
_, err = io.ReadAll(rsp.Body)
Expect(err).To(MatchError(errors.New("additional HEADERS frame received after trailers")))
@@ -702,7 +711,9 @@ var _ = Describe("Client", func() {
str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
str.EXPECT().Close()
str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes()
rsp, err := cl.RoundTrip(req)
tr := &Transport{}
cc := tr.NewClientConn(conn)
rsp, err := cc.RoundTrip(req)
Expect(err).ToNot(HaveOccurred())
_, err = io.ReadAll(rsp.Body)
Expect(err).To(MatchError(errors.New("DATA frame received after trailers")))
@@ -744,7 +755,9 @@ var _ = Describe("Client", func() {
<-done
return 0, testErr
})
_, err := cl.RoundTrip(req)
tr := &Transport{}
cc := tr.NewClientConn(conn)
_, err := cc.RoundTrip(req)
Expect(err).To(MatchError(testErr))
hfs := decodeHeader(strBuf)
Expect(hfs).To(HaveKeyWithValue(":method", "POST"))
@@ -770,7 +783,10 @@ var _ = Describe("Client", func() {
<-done
return 0, errors.New("done")
})
cl.RoundTrip(req)
tr := &Transport{}
cc := tr.NewClientConn(conn)
_, err := cc.RoundTrip(req)
Expect(err).To(HaveOccurred())
Expect(strBuf.String()).To(ContainSubstring("request"))
Expect(strBuf.String()).ToNot(ContainSubstring("request body"))
})
@@ -794,7 +810,9 @@ var _ = Describe("Client", func() {
})
closed := make(chan struct{})
str.EXPECT().Close().Do(func() error { close(closed); return nil })
_, err := cl.RoundTrip(req)
tr := &Transport{}
cc := tr.NewClientConn(conn)
_, err := cc.RoundTrip(req)
Expect(err).To(MatchError(testErr))
Eventually(closed).Should(BeClosed())
})
@@ -806,7 +824,9 @@ var _ = Describe("Client", func() {
r := bytes.NewReader(b)
str.EXPECT().Close().Do(func() error { close(closed); return nil })
str.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
_, err := cl.RoundTrip(req)
tr := &Transport{}
cc := tr.NewClientConn(conn)
_, err := cc.RoundTrip(req)
Expect(err).To(MatchError("http3: expected first frame to be a HEADERS frame"))
Eventually(closed).Should(BeClosed())
})
@@ -825,13 +845,16 @@ var _ = Describe("Client", func() {
closed := make(chan struct{})
str.EXPECT().Close().Do(func() error { close(closed); return nil })
str.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
_, err := cl.RoundTrip(req)
tr := &Transport{}
cc := tr.NewClientConn(conn)
_, err := cc.RoundTrip(req)
Expect(err).To(HaveOccurred())
Eventually(closed).Should(BeClosed())
})
It("cancels the stream when the HEADERS frame is too large", func() {
cl.MaxResponseHeaderBytes = 1337
tr := &Transport{MaxResponseHeaderBytes: 1337}
cc := tr.NewClientConn(conn)
b := (&headersFrame{Length: 1338}).Append(nil)
r := bytes.NewReader(b)
str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeFrameError))
@@ -839,14 +862,16 @@ var _ = Describe("Client", func() {
closed := make(chan struct{})
str.EXPECT().Close().Do(func() error { close(closed); return nil })
str.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
_, err := cl.RoundTrip(req)
_, err := cc.RoundTrip(req)
Expect(err).To(MatchError("http3: HEADERS frame too large: 1338 bytes (max: 1337)"))
Eventually(closed).Should(BeClosed())
})
It("opens a request stream", func() {
cl.Connection.(quic.EarlyConnection).HandshakeComplete()
str, err := cl.OpenRequestStream(context.Background())
tr := &Transport{}
cc := tr.NewClientConn(conn)
conn.HandshakeComplete()
str, err := cc.OpenRequestStream(context.Background())
Expect(err).ToNot(HaveOccurred())
Expect(str.SendRequestHeader(req)).To(Succeed())
str.Write([]byte("foobar"))
@@ -863,9 +888,11 @@ var _ = Describe("Client", func() {
req := req.WithContext(ctx)
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
tr := &Transport{}
cc := tr.NewClientConn(conn)
errChan := make(chan error)
go func() {
_, err := cl.RoundTrip(req)
_, err := cc.RoundTrip(req)
errChan <- err
}()
Consistently(errChan).ShouldNot(Receive())
@@ -895,7 +922,9 @@ var _ = Describe("Client", func() {
<-canceled
return 0, errors.New("test done")
})
_, err := cl.RoundTrip(req)
tr := &Transport{}
cc := tr.NewClientConn(conn)
_, err := cc.RoundTrip(req)
Expect(err).To(MatchError(context.Canceled))
Eventually(done).Should(BeClosed())
})
@@ -916,7 +945,9 @@ var _ = Describe("Client", func() {
str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes()
str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled))
str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeRequestCanceled)).Do(func(quic.StreamErrorCode) { close(done) })
_, err := cl.RoundTrip(req)
tr := &Transport{}
cc := tr.NewClientConn(conn)
_, err := cc.RoundTrip(req)
Expect(err).ToNot(HaveOccurred())
cancel()
Eventually(done).Should(BeClosed())
@@ -940,17 +971,17 @@ var _ = Describe("Client", func() {
)
testErr := errors.New("test done")
str.EXPECT().Read(gomock.Any()).Return(0, testErr)
_, err := cl.RoundTrip(req)
tr := &Transport{}
cc := tr.NewClientConn(conn)
_, err := cc.RoundTrip(req)
Expect(err).To(MatchError(testErr))
hfs := decodeHeader(buf)
Expect(hfs).To(HaveKeyWithValue("accept-encoding", "gzip"))
})
It("doesn't add gzip if the header disable it", func() {
client := &SingleDestinationRoundTripper{
Connection: conn,
DisableCompression: true,
}
tr := &Transport{DisableCompression: true}
client := tr.NewClientConn(conn)
conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
buf := &bytes.Buffer{}
str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write)
@@ -985,7 +1016,9 @@ var _ = Describe("Client", func() {
str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
str.EXPECT().Close()
rsp, err := cl.RoundTrip(req)
tr := &Transport{}
cc := tr.NewClientConn(conn)
rsp, err := cc.RoundTrip(req)
Expect(err).ToNot(HaveOccurred())
data, err := io.ReadAll(rsp.Body)
Expect(err).ToNot(HaveOccurred())
@@ -1009,7 +1042,9 @@ var _ = Describe("Client", func() {
str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
str.EXPECT().Close()
rsp, err := cl.RoundTrip(req)
tr := &Transport{}
cc := tr.NewClientConn(conn)
rsp, err := cc.RoundTrip(req)
Expect(err).ToNot(HaveOccurred())
data, err := io.ReadAll(rsp.Body)
Expect(err).ToNot(HaveOccurred())
@@ -1045,7 +1080,9 @@ var _ = Describe("Client", func() {
str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
str.EXPECT().Close()
str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes()
rsp, err := cl.RoundTrip(req)
tr := &Transport{}
cc := tr.NewClientConn(conn)
rsp, err := cc.RoundTrip(req)
Expect(err).ToNot(HaveOccurred())
Expect(rsp.Proto).To(Equal("HTTP/3.0"))
Expect(rsp.ProtoMajor).To(Equal(3))
@@ -1077,7 +1114,9 @@ var _ = Describe("Client", func() {
str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
str.EXPECT().Close()
str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes()
rsp, err := cl.RoundTrip(req)
tr := &Transport{}
cc := tr.NewClientConn(conn)
rsp, err := cc.RoundTrip(req)
Expect(err).ToNot(HaveOccurred())
Expect(rsp.Proto).To(Equal("HTTP/3.0"))
Expect(rsp.ProtoMajor).To(Equal(3))

View File

@@ -6,6 +6,7 @@ import (
"errors"
"fmt"
"io"
"log/slog"
"net"
"net/http"
"strings"
@@ -30,7 +31,7 @@ type Settings struct {
// RoundTripOpt are options for the Transport.RoundTripOpt method.
type RoundTripOpt struct {
// OnlyCachedConn controls whether the RoundTripper may create a new QUIC connection.
// OnlyCachedConn controls whether the Transport may create a new QUIC connection.
// If set true and no cached connection is available, RoundTripOpt will return ErrNoCachedConn.
OnlyCachedConn bool
}
@@ -59,8 +60,8 @@ func (r *roundTripperWithCount) Close() error {
return nil
}
// RoundTripper implements the http.RoundTripper interface
type RoundTripper struct {
// Transport implements the http.RoundTripper interface
type Transport struct {
mutex sync.Mutex
// TLSClientConfig specifies the TLS configuration to use with
@@ -97,6 +98,11 @@ type RoundTripper struct {
// However, if the user explicitly requested gzip it is not automatically uncompressed.
DisableCompression bool
StreamHijacker func(FrameType, quic.ConnectionTracingID, quic.Stream, error) (hijacked bool, err error)
UniStreamHijacker func(StreamType, quic.ConnectionTracingID, quic.ReceiveStream, error) (hijacked bool)
Logger *slog.Logger
initOnce sync.Once
initErr error
@@ -107,18 +113,53 @@ type RoundTripper struct {
}
var (
_ http.RoundTripper = &RoundTripper{}
_ io.Closer = &RoundTripper{}
_ http.RoundTripper = &Transport{}
_ io.Closer = &Transport{}
)
// ErrNoCachedConn is returned when RoundTripper.OnlyCachedConn is set
// ErrNoCachedConn is returned when Transport.OnlyCachedConn is set
var ErrNoCachedConn = errors.New("http3: no cached connection was available")
func (t *Transport) init() error {
if t.newClient == nil {
t.newClient = func(conn quic.EarlyConnection) singleRoundTripper {
return newClientConn(
conn,
t.EnableDatagrams,
t.AdditionalSettings,
t.StreamHijacker,
t.UniStreamHijacker,
t.MaxResponseHeaderBytes,
t.DisableCompression,
t.Logger,
)
}
}
if t.QUICConfig == nil {
t.QUICConfig = defaultQuicConfig.Clone()
t.QUICConfig.EnableDatagrams = t.EnableDatagrams
}
if t.EnableDatagrams && !t.QUICConfig.EnableDatagrams {
return errors.New("HTTP Datagrams enabled, but QUIC Datagrams disabled")
}
if len(t.QUICConfig.Versions) == 0 {
t.QUICConfig = t.QUICConfig.Clone()
t.QUICConfig.Versions = []quic.Version{protocol.SupportedVersions[0]}
}
if len(t.QUICConfig.Versions) != 1 {
return errors.New("can only use a single QUIC version for dialing a HTTP/3 connection")
}
if t.QUICConfig.MaxIncomingStreams == 0 {
t.QUICConfig.MaxIncomingStreams = -1 // don't allow any bidirectional streams
}
return nil
}
// RoundTripOpt is like RoundTrip, but takes options.
func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) {
r.initOnce.Do(func() { r.initErr = r.init() })
if r.initErr != nil {
return nil, r.initErr
func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) {
t.initOnce.Do(func() { t.initErr = t.init() })
if t.initErr != nil {
return nil, t.initErr
}
if req.URL == nil {
@@ -154,7 +195,7 @@ func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.
}
hostname := authorityAddr(hostnameFromURL(req.URL))
cl, isReused, err := r.getClient(req.Context(), hostname, opt.OnlyCachedConn)
cl, isReused, err := t.getClient(req.Context(), hostname, opt.OnlyCachedConn)
if err != nil {
return nil, err
}
@@ -166,7 +207,7 @@ func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.
}
if cl.dialErr != nil {
r.removeClient(hostname)
t.removeClient(hostname)
return nil, cl.dialErr
}
defer cl.useCount.Add(-1)
@@ -176,12 +217,12 @@ func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.
// so we remove the client from the cache so that subsequent trips reconnect
// context cancelation is excluded as is does not signify a connection error
if !errors.Is(err, context.Canceled) {
r.removeClient(hostname)
t.removeClient(hostname)
}
if isReused {
if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
return r.RoundTripOpt(req, opt)
return t.RoundTripOpt(req, opt)
}
}
}
@@ -189,51 +230,19 @@ func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.
}
// RoundTrip does a round trip.
func (r *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
return r.RoundTripOpt(req, RoundTripOpt{})
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
return t.RoundTripOpt(req, RoundTripOpt{})
}
func (r *RoundTripper) init() error {
if r.newClient == nil {
r.newClient = func(conn quic.EarlyConnection) singleRoundTripper {
return &SingleDestinationRoundTripper{
Connection: conn,
EnableDatagrams: r.EnableDatagrams,
DisableCompression: r.DisableCompression,
AdditionalSettings: r.AdditionalSettings,
MaxResponseHeaderBytes: r.MaxResponseHeaderBytes,
}
}
}
if r.QUICConfig == nil {
r.QUICConfig = defaultQuicConfig.Clone()
r.QUICConfig.EnableDatagrams = r.EnableDatagrams
}
if r.EnableDatagrams && !r.QUICConfig.EnableDatagrams {
return errors.New("HTTP Datagrams enabled, but QUIC Datagrams disabled")
}
if len(r.QUICConfig.Versions) == 0 {
r.QUICConfig = r.QUICConfig.Clone()
r.QUICConfig.Versions = []quic.Version{protocol.SupportedVersions[0]}
}
if len(r.QUICConfig.Versions) != 1 {
return errors.New("can only use a single QUIC version for dialing a HTTP/3 connection")
}
if r.QUICConfig.MaxIncomingStreams == 0 {
r.QUICConfig.MaxIncomingStreams = -1 // don't allow any bidirectional streams
}
return nil
}
func (t *Transport) getClient(ctx context.Context, hostname string, onlyCached bool) (rtc *roundTripperWithCount, isReused bool, err error) {
t.mutex.Lock()
defer t.mutex.Unlock()
func (r *RoundTripper) getClient(ctx context.Context, hostname string, onlyCached bool) (rtc *roundTripperWithCount, isReused bool, err error) {
r.mutex.Lock()
defer r.mutex.Unlock()
if r.clients == nil {
r.clients = make(map[string]*roundTripperWithCount)
if t.clients == nil {
t.clients = make(map[string]*roundTripperWithCount)
}
cl, ok := r.clients[hostname]
cl, ok := t.clients[hostname]
if !ok {
if onlyCached {
return nil, false, ErrNoCachedConn
@@ -246,7 +255,7 @@ func (r *RoundTripper) getClient(ctx context.Context, hostname string, onlyCache
go func() {
defer close(cl.dialing)
defer cancel()
conn, rt, err := r.dial(ctx, hostname)
conn, rt, err := t.dial(ctx, hostname)
if err != nil {
cl.dialErr = err
return
@@ -254,12 +263,12 @@ func (r *RoundTripper) getClient(ctx context.Context, hostname string, onlyCache
cl.conn = conn
cl.rt = rt
}()
r.clients[hostname] = cl
t.clients[hostname] = cl
}
select {
case <-cl.dialing:
if cl.dialErr != nil {
delete(r.clients, hostname)
delete(t.clients, hostname)
return nil, false, cl.dialErr
}
select {
@@ -273,12 +282,12 @@ func (r *RoundTripper) getClient(ctx context.Context, hostname string, onlyCache
return cl, isReused, nil
}
func (r *RoundTripper) dial(ctx context.Context, hostname string) (quic.EarlyConnection, singleRoundTripper, error) {
func (t *Transport) dial(ctx context.Context, hostname string) (quic.EarlyConnection, singleRoundTripper, error) {
var tlsConf *tls.Config
if r.TLSClientConfig == nil {
if t.TLSClientConfig == nil {
tlsConf = &tls.Config{}
} else {
tlsConf = r.TLSClientConfig.Clone()
tlsConf = t.TLSClientConfig.Clone()
}
if tlsConf.ServerName == "" {
sni, _, err := net.SplitHostPort(hostname)
@@ -289,61 +298,74 @@ func (r *RoundTripper) dial(ctx context.Context, hostname string) (quic.EarlyCon
tlsConf.ServerName = sni
}
// Replace existing ALPNs by H3
tlsConf.NextProtos = []string{versionToALPN(r.QUICConfig.Versions[0])}
tlsConf.NextProtos = []string{versionToALPN(t.QUICConfig.Versions[0])}
dial := r.Dial
dial := t.Dial
if dial == nil {
if r.transport == nil {
if t.transport == nil {
udpConn, err := net.ListenUDP("udp", nil)
if err != nil {
return nil, nil, err
}
r.transport = &quic.Transport{Conn: udpConn}
t.transport = &quic.Transport{Conn: udpConn}
}
dial = func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
}
return r.transport.DialEarly(ctx, udpAddr, tlsCfg, cfg)
return t.transport.DialEarly(ctx, udpAddr, tlsCfg, cfg)
}
}
conn, err := dial(ctx, hostname, tlsConf, r.QUICConfig)
conn, err := dial(ctx, hostname, tlsConf, t.QUICConfig)
if err != nil {
return nil, nil, err
}
return conn, r.newClient(conn), nil
return conn, t.newClient(conn), nil
}
func (r *RoundTripper) removeClient(hostname string) {
r.mutex.Lock()
defer r.mutex.Unlock()
if r.clients == nil {
func (t *Transport) removeClient(hostname string) {
t.mutex.Lock()
defer t.mutex.Unlock()
if t.clients == nil {
return
}
delete(r.clients, hostname)
delete(t.clients, hostname)
}
// Close closes the QUIC connections that this RoundTripper has used.
func (t *Transport) NewClientConn(conn quic.Connection) *ClientConn {
return newClientConn(
conn,
t.EnableDatagrams,
t.AdditionalSettings,
t.StreamHijacker,
t.UniStreamHijacker,
t.MaxResponseHeaderBytes,
t.DisableCompression,
t.Logger,
)
}
// Close closes the QUIC connections that this Transport has used.
// It also closes the underlying UDPConn if it is not nil.
func (r *RoundTripper) Close() error {
r.mutex.Lock()
defer r.mutex.Unlock()
for _, cl := range r.clients {
func (t *Transport) Close() error {
t.mutex.Lock()
defer t.mutex.Unlock()
for _, cl := range t.clients {
if err := cl.Close(); err != nil {
return err
}
}
r.clients = nil
if r.transport != nil {
if err := r.transport.Close(); err != nil {
t.clients = nil
if t.transport != nil {
if err := t.transport.Close(); err != nil {
return err
}
if err := r.transport.Conn.Close(); err != nil {
if err := t.transport.Conn.Close(); err != nil {
return err
}
r.transport = nil
t.transport = nil
}
return nil
}
@@ -376,13 +398,13 @@ func isNotToken(r rune) bool {
return !httpguts.IsTokenRune(r)
}
func (r *RoundTripper) CloseIdleConnections() {
r.mutex.Lock()
defer r.mutex.Unlock()
for hostname, cl := range r.clients {
func (t *Transport) CloseIdleConnections() {
t.mutex.Lock()
defer t.mutex.Unlock()
for hostname, cl := range t.clients {
if cl.useCount.Load() == 0 {
cl.Close()
delete(r.clients, hostname)
delete(t.clients, hostname)
}
}
}

View File

@@ -45,7 +45,7 @@ func (m *mockBody) Close() error {
return m.closeErr
}
var _ = Describe("RoundTripper", func() {
var _ = Describe("Transport", func() {
var req *http.Request
BeforeEach(func() {
@@ -58,14 +58,14 @@ var _ = Describe("RoundTripper", func() {
qconf := &quic.Config{
Versions: []quic.Version{protocol.Version2, protocol.Version1},
}
rt := &RoundTripper{QUICConfig: qconf}
_, err := rt.RoundTrip(req)
tr := &Transport{QUICConfig: qconf}
_, err := tr.RoundTrip(req)
Expect(err).To(MatchError("can only use a single QUIC version for dialing a HTTP/3 connection"))
})
It("uses the default QUIC and TLS config if none is give", func() {
var dialAddrCalled bool
rt := &RoundTripper{
tr := &Transport{
Dial: func(_ context.Context, _ string, tlsConf *tls.Config, quicConf *quic.Config) (quic.EarlyConnection, error) {
defer GinkgoRecover()
Expect(quicConf.MaxIncomingStreams).To(Equal(defaultQuicConfig.MaxIncomingStreams))
@@ -75,14 +75,14 @@ var _ = Describe("RoundTripper", func() {
return nil, errors.New("test done")
},
}
_, err := rt.RoundTripOpt(req, RoundTripOpt{})
_, err := tr.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError("test done"))
Expect(dialAddrCalled).To(BeTrue())
})
It("adds the port to the hostname, if none is given", func() {
var dialAddrCalled bool
rt := &RoundTripper{
tr := &Transport{
Dial: func(_ context.Context, hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlyConnection, error) {
defer GinkgoRecover()
Expect(hostname).To(Equal("quic.clemente.io:443"))
@@ -92,7 +92,7 @@ var _ = Describe("RoundTripper", func() {
}
req, err := http.NewRequest("GET", "https://quic.clemente.io:443", nil)
Expect(err).ToNot(HaveOccurred())
_, err = rt.RoundTripOpt(req, RoundTripOpt{})
_, err = tr.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError("test done"))
Expect(dialAddrCalled).To(BeTrue())
})
@@ -100,7 +100,7 @@ var _ = Describe("RoundTripper", func() {
It("sets the ServerName in the tls.Config, if not set", func() {
const host = "foo.bar"
var dialCalled bool
rt := &RoundTripper{
tr := &Transport{
Dial: func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
defer GinkgoRecover()
Expect(tlsCfg.ServerName).To(Equal(host))
@@ -110,7 +110,7 @@ var _ = Describe("RoundTripper", func() {
}
req, err := http.NewRequest("GET", "https://foo.bar", nil)
Expect(err).ToNot(HaveOccurred())
_, err = rt.RoundTripOpt(req, RoundTripOpt{})
_, err = tr.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError("test done"))
Expect(dialCalled).To(BeTrue())
})
@@ -122,7 +122,7 @@ var _ = Describe("RoundTripper", func() {
}
quicConf := &quic.Config{MaxIdleTimeout: 3 * time.Nanosecond}
var dialAddrCalled bool
rt := &RoundTripper{
tr := &Transport{
Dial: func(_ context.Context, host string, tlsConfP *tls.Config, quicConfP *quic.Config) (quic.EarlyConnection, error) {
defer GinkgoRecover()
Expect(host).To(Equal("www.example.org:443"))
@@ -135,7 +135,7 @@ var _ = Describe("RoundTripper", func() {
QUICConfig: quicConf,
TLSClientConfig: tlsConf,
}
_, err := rt.RoundTripOpt(req, RoundTripOpt{})
_, err := tr.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError("test done"))
Expect(dialAddrCalled).To(BeTrue())
// make sure the original tls.Config was not modified
@@ -149,7 +149,7 @@ var _ = Describe("RoundTripper", func() {
// nolint:staticcheck // This is a test.
ctx := context.WithValue(context.Background(), "foo", "bar")
var dialerCalled bool
rt := &RoundTripper{
tr := &Transport{
Dial: func(ctxP context.Context, address string, tlsConfP *tls.Config, quicConfP *quic.Config) (quic.EarlyConnection, error) {
defer GinkgoRecover()
Expect(ctx.Value("foo").(string)).To(Equal("bar"))
@@ -162,14 +162,14 @@ var _ = Describe("RoundTripper", func() {
TLSClientConfig: tlsConf,
QUICConfig: quicConf,
}
_, err := rt.RoundTripOpt(req.WithContext(ctx), RoundTripOpt{})
_, err := tr.RoundTripOpt(req.WithContext(ctx), RoundTripOpt{})
Expect(err).To(MatchError(testErr))
Expect(dialerCalled).To(BeTrue())
})
It("enables HTTP/3 Datagrams", func() {
testErr := errors.New("handshake error")
rt := &RoundTripper{
tr := &Transport{
EnableDatagrams: true,
Dial: func(_ context.Context, _ string, _ *tls.Config, quicConf *quic.Config) (quic.EarlyConnection, error) {
defer GinkgoRecover()
@@ -177,19 +177,19 @@ var _ = Describe("RoundTripper", func() {
return nil, testErr
},
}
_, err := rt.RoundTripOpt(req, RoundTripOpt{})
_, err := tr.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError(testErr))
})
It("requires quic.Config.EnableDatagrams if HTTP/3 datagrams are enabled", func() {
rt := &RoundTripper{
tr := &Transport{
QUICConfig: &quic.Config{EnableDatagrams: false},
EnableDatagrams: true,
Dial: func(_ context.Context, _ string, _ *tls.Config, config *quic.Config) (quic.EarlyConnection, error) {
return nil, errors.New("handshake error")
},
}
_, err := rt.RoundTrip(req)
_, err := tr.RoundTrip(req)
Expect(err).To(MatchError("HTTP Datagrams enabled, but QUIC Datagrams disabled"))
})
@@ -200,29 +200,29 @@ var _ = Describe("RoundTripper", func() {
req2, err := http.NewRequest("GET", "https://example.com/foobar.html", nil)
Expect(err).ToNot(HaveOccurred())
var hostsDialed []string
rt := &RoundTripper{
tr := &Transport{
Dial: func(_ context.Context, host string, _ *tls.Config, quicConf *quic.Config) (quic.EarlyConnection, error) {
hostsDialed = append(hostsDialed, host)
return nil, testErr
},
}
_, err = rt.RoundTrip(req1)
_, err = tr.RoundTrip(req1)
Expect(err).To(MatchError(testErr))
_, err = rt.RoundTrip(req2)
_, err = tr.RoundTrip(req2)
Expect(err).To(MatchError(testErr))
Expect(hostsDialed).To(Equal([]string{"quic-go.net:443", "example.com:443"}))
})
Context("reusing clients", func() {
var (
rt *RoundTripper
tr *Transport
req1, req2 *http.Request
clientChan chan *MockSingleRoundTripper
)
BeforeEach(func() {
clientChan = make(chan *MockSingleRoundTripper, 16)
rt = &RoundTripper{
tr = &Transport{
newClient: func(quic.EarlyConnection) singleRoundTripper {
select {
case c := <-clientChan:
@@ -252,14 +252,14 @@ var _ = Describe("RoundTripper", func() {
cl.EXPECT().RoundTrip(req1).Return(&http.Response{Request: req1}, nil)
cl.EXPECT().RoundTrip(req2).Return(&http.Response{Request: req2}, nil)
var count int
rt.Dial = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
tr.Dial = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
count++
return conn, nil
}
rsp, err := rt.RoundTrip(req1)
rsp, err := tr.RoundTrip(req1)
Expect(err).ToNot(HaveOccurred())
Expect(rsp.Request).To(Equal(req1))
rsp, err = rt.RoundTrip(req2)
rsp, err = tr.RoundTrip(req2)
Expect(err).ToNot(HaveOccurred())
Expect(rsp.Request).To(Equal(req2))
Expect(count).To(Equal(1))
@@ -277,7 +277,7 @@ var _ = Describe("RoundTripper", func() {
testErr := errors.New("handshake error")
conn := mockquic.NewMockEarlyConnection(mockCtrl)
var count int
rt.Dial = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
tr.Dial = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
count++
if count == 1 {
return nil, testErr
@@ -288,9 +288,9 @@ var _ = Describe("RoundTripper", func() {
close(handshakeChan)
conn.EXPECT().HandshakeComplete().Return(handshakeChan).MaxTimes(2)
cl1.EXPECT().RoundTrip(req2).Return(&http.Response{Request: req2}, nil)
_, err = rt.RoundTrip(req1)
_, err = tr.RoundTrip(req1)
Expect(err).To(MatchError(testErr))
rsp, err := rt.RoundTrip(req2)
rsp, err := tr.RoundTrip(req2)
Expect(err).ToNot(HaveOccurred())
Expect(rsp.Request).To(Equal(req2))
Expect(count).To(Equal(2))
@@ -309,7 +309,7 @@ var _ = Describe("RoundTripper", func() {
conn := mockquic.NewMockEarlyConnection(mockCtrl)
var count int
rt.Dial = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
tr.Dial = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
count++
return conn, nil
}
@@ -319,9 +319,9 @@ var _ = Describe("RoundTripper", func() {
conn.EXPECT().HandshakeComplete().Return(handshakeChan).MaxTimes(2)
cl1.EXPECT().RoundTrip(req1).Return(nil, testErr)
cl2.EXPECT().RoundTrip(req2).Return(&http.Response{Request: req2}, nil)
_, err = rt.RoundTrip(req1)
_, err = tr.RoundTrip(req1)
Expect(err).To(MatchError(testErr))
rsp, err := rt.RoundTrip(req2)
rsp, err := tr.RoundTrip(req2)
Expect(err).ToNot(HaveOccurred())
Expect(rsp.Request).To(Equal(req2))
Expect(count).To(Equal(2))
@@ -340,7 +340,7 @@ var _ = Describe("RoundTripper", func() {
conn := mockquic.NewMockEarlyConnection(mockCtrl)
var count int
rt.Dial = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
tr.Dial = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
count++
return conn, nil
}
@@ -350,9 +350,9 @@ var _ = Describe("RoundTripper", func() {
conn.EXPECT().HandshakeComplete().Return(handshakeChan).MaxTimes(2)
cl1.EXPECT().RoundTrip(req1).Return(nil, testErr)
cl1.EXPECT().RoundTrip(req2).Return(&http.Response{Request: req2}, nil)
_, err = rt.RoundTrip(req1)
_, err = tr.RoundTrip(req1)
Expect(err).To(MatchError(testErr))
rsp, err := rt.RoundTrip(req2)
rsp, err := tr.RoundTrip(req2)
Expect(err).ToNot(HaveOccurred())
Expect(rsp.Request).To(Equal(req2))
Expect(count).To(Equal(1))
@@ -383,30 +383,30 @@ var _ = Describe("RoundTripper", func() {
close(handshakeChan)
conn.EXPECT().HandshakeComplete().Return(handshakeChan).MaxTimes(2)
var count int
rt.Dial = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
tr.Dial = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
count++
return conn, nil
}
rsp1, err := rt.RoundTrip(req1)
rsp1, err := tr.RoundTrip(req1)
Expect(err).ToNot(HaveOccurred())
Expect(rsp1.Request.RemoteAddr).To(Equal(req1.RemoteAddr))
rsp2, err := rt.RoundTrip(req2)
rsp2, err := tr.RoundTrip(req2)
Expect(err).ToNot(HaveOccurred())
Expect(rsp2.Request.RemoteAddr).To(Equal(req2.RemoteAddr))
})
It("only issues a request once, even if a timeout error occurs", func() {
var count int
rt.Dial = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
tr.Dial = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
count++
return mockquic.NewMockEarlyConnection(mockCtrl), nil
}
rt.newClient = func(quic.EarlyConnection) singleRoundTripper {
tr.newClient = func(quic.EarlyConnection) singleRoundTripper {
cl := NewMockSingleRoundTripper(mockCtrl)
cl.EXPECT().RoundTrip(gomock.Any()).Return(nil, &qerr.IdleTimeoutError{})
return cl
}
_, err := rt.RoundTrip(req1)
_, err := tr.RoundTrip(req1)
Expect(err).To(MatchError(&qerr.IdleTimeoutError{}))
Expect(count).To(Equal(1))
})
@@ -426,7 +426,7 @@ var _ = Describe("RoundTripper", func() {
conn := mockquic.NewMockEarlyConnection(mockCtrl)
conn.EXPECT().HandshakeComplete().Return(wait).AnyTimes()
var count int
rt.Dial = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
tr.Dial = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
count++
return conn, nil
}
@@ -435,7 +435,7 @@ var _ = Describe("RoundTripper", func() {
go func() {
defer GinkgoRecover()
defer func() { done <- struct{}{} }()
_, err := rt.RoundTrip(req1)
_, err := tr.RoundTrip(req1)
Expect(err).To(MatchError(&qerr.IdleTimeoutError{}))
}()
// wait for the first requests to be issued
@@ -443,7 +443,7 @@ var _ = Describe("RoundTripper", func() {
go func() {
defer GinkgoRecover()
defer func() { done <- struct{}{} }()
_, err := rt.RoundTrip(req2)
_, err := tr.RoundTrip(req2)
Expect(err).To(MatchError(&qerr.IdleTimeoutError{}))
}()
Eventually(reqs).Should(Receive())
@@ -456,19 +456,19 @@ var _ = Describe("RoundTripper", func() {
It("doesn't create new clients if RoundTripOpt.OnlyCachedConn is set", func() {
req, err := http.NewRequest("GET", "https://quic-go.net/foobar.html", nil)
Expect(err).ToNot(HaveOccurred())
_, err = rt.RoundTripOpt(req, RoundTripOpt{OnlyCachedConn: true})
_, err = tr.RoundTripOpt(req, RoundTripOpt{OnlyCachedConn: true})
Expect(err).To(MatchError(ErrNoCachedConn))
})
})
Context("validating request", func() {
var rt RoundTripper
var tr Transport
It("rejects plain HTTP requests", func() {
req, err := http.NewRequest("GET", "http://www.example.org/", nil)
req.Body = &mockBody{}
Expect(err).ToNot(HaveOccurred())
_, err = rt.RoundTrip(req)
_, err = tr.RoundTrip(req)
Expect(err).To(MatchError("http3: unsupported protocol scheme: http"))
Expect(req.Body.(*mockBody).closed).To(BeTrue())
})
@@ -476,7 +476,7 @@ var _ = Describe("RoundTripper", func() {
It("rejects requests without a URL", func() {
req.URL = nil
req.Body = &mockBody{}
_, err := rt.RoundTrip(req)
_, err := tr.RoundTrip(req)
Expect(err).To(MatchError("http3: nil Request.URL"))
Expect(req.Body.(*mockBody).closed).To(BeTrue())
})
@@ -484,7 +484,7 @@ var _ = Describe("RoundTripper", func() {
It("rejects request without a URL Host", func() {
req.URL.Host = ""
req.Body = &mockBody{}
_, err := rt.RoundTrip(req)
_, err := tr.RoundTrip(req)
Expect(err).To(MatchError("http3: no Host in request URL"))
Expect(req.Body.(*mockBody).closed).To(BeTrue())
})
@@ -492,34 +492,34 @@ var _ = Describe("RoundTripper", func() {
It("doesn't try to close the body if the request doesn't have one", func() {
req.URL = nil
Expect(req.Body).To(BeNil())
_, err := rt.RoundTrip(req)
_, err := tr.RoundTrip(req)
Expect(err).To(MatchError("http3: nil Request.URL"))
})
It("rejects requests without a header", func() {
req.Header = nil
req.Body = &mockBody{}
_, err := rt.RoundTrip(req)
_, err := tr.RoundTrip(req)
Expect(err).To(MatchError("http3: nil Request.Header"))
Expect(req.Body.(*mockBody).closed).To(BeTrue())
})
It("rejects requests with invalid header name fields", func() {
req.Header.Add("foobär", "value")
_, err := rt.RoundTrip(req)
_, err := tr.RoundTrip(req)
Expect(err).To(MatchError("http3: invalid http header field name \"foobär\""))
})
It("rejects requests with invalid header name values", func() {
req.Header.Add("foo", string([]byte{0x7}))
_, err := rt.RoundTrip(req)
_, err := tr.RoundTrip(req)
Expect(err.Error()).To(ContainSubstring("http3: invalid http header field value"))
})
It("rejects requests with an invalid request method", func() {
req.Method = "foobär"
req.Body = &mockBody{}
_, err := rt.RoundTrip(req)
_, err := tr.RoundTrip(req)
Expect(err).To(MatchError("http3: invalid method \"foobär\""))
Expect(req.Body.(*mockBody).closed).To(BeTrue())
})
@@ -528,7 +528,7 @@ var _ = Describe("RoundTripper", func() {
Context("closing", func() {
It("closes", func() {
conn := mockquic.NewMockEarlyConnection(mockCtrl)
rt := &RoundTripper{
tr := &Transport{
Dial: func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
return conn, nil
},
@@ -540,14 +540,14 @@ var _ = Describe("RoundTripper", func() {
}
req, err := http.NewRequest("GET", "https://quic-go.net/foobar.html", nil)
Expect(err).ToNot(HaveOccurred())
_, err = rt.RoundTrip(req)
_, err = tr.RoundTrip(req)
Expect(err).ToNot(HaveOccurred())
conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(0), "")
Expect(rt.Close()).To(Succeed())
Expect(tr.Close()).To(Succeed())
})
It("closes while dialing", func() {
rt := &RoundTripper{
tr := &Transport{
Dial: func(ctx context.Context, _ string, _ *tls.Config, _ *quic.Config) (quic.EarlyConnection, error) {
defer GinkgoRecover()
Eventually(ctx.Done()).Should(BeClosed())
@@ -560,12 +560,12 @@ var _ = Describe("RoundTripper", func() {
errChan := make(chan error, 1)
go func() {
defer GinkgoRecover()
_, err := rt.RoundTrip(req)
_, err := tr.RoundTrip(req)
errChan <- err
}()
Consistently(errChan, scaleDuration(30*time.Millisecond)).ShouldNot(Receive())
Expect(rt.Close()).To(Succeed())
Expect(tr.Close()).To(Succeed())
var rtErr error
Eventually(errChan).Should(Receive(&rtErr))
Expect(rtErr).To(MatchError("cancelled"))
@@ -574,7 +574,7 @@ var _ = Describe("RoundTripper", func() {
It("closes idle connections", func() {
conn1 := mockquic.NewMockEarlyConnection(mockCtrl)
conn2 := mockquic.NewMockEarlyConnection(mockCtrl)
rt := &RoundTripper{
tr := &Transport{
Dial: func(_ context.Context, hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlyConnection, error) {
switch hostname {
case "site1.com:443":
@@ -598,7 +598,7 @@ var _ = Describe("RoundTripper", func() {
req2 = req2.WithContext(ctx2)
roundTripCalled := make(chan struct{})
reqFinished := make(chan struct{})
rt.newClient = func(quic.EarlyConnection) singleRoundTripper {
tr.newClient = func(quic.EarlyConnection) singleRoundTripper {
cl := NewMockSingleRoundTripper(mockCtrl)
cl.EXPECT().RoundTrip(gomock.Any()).DoAndReturn(func(r *http.Request) (*http.Response, error) {
roundTripCalled <- struct{}{}
@@ -608,11 +608,11 @@ var _ = Describe("RoundTripper", func() {
return cl
}
go func() {
rt.RoundTrip(req1)
tr.RoundTrip(req1)
reqFinished <- struct{}{}
}()
go func() {
rt.RoundTrip(req2)
tr.RoundTrip(req2)
reqFinished <- struct{}{}
}()
<-roundTripCalled
@@ -622,12 +622,12 @@ var _ = Describe("RoundTripper", func() {
<-reqFinished
// req1 is finished
conn1.EXPECT().CloseWithError(gomock.Any(), gomock.Any())
rt.CloseIdleConnections()
tr.CloseIdleConnections()
cancel2()
<-reqFinished
// all requests are finished
conn2.EXPECT().CloseWithError(gomock.Any(), gomock.Any())
rt.CloseIdleConnections()
tr.CloseIdleConnections()
})
})
})

View File

@@ -64,7 +64,7 @@ var _ = Describe("HTTP3 Server hotswap test", func() {
mux1 *http.ServeMux
mux2 *http.ServeMux
client *http.Client
rt *http3.RoundTripper
rt *http3.Transport
server1 *http3.Server
server2 *http3.Server
ln *listenerWrapper
@@ -106,7 +106,7 @@ var _ = Describe("HTTP3 Server hotswap test", func() {
})
BeforeEach(func() {
rt = &http3.RoundTripper{
rt = &http3.Transport{
TLSClientConfig: getTLSClientConfig(),
DisableCompression: true,
QUICConfig: getQuicConfig(&quic.Config{MaxIdleTimeout: 10 * time.Second}),
@@ -151,7 +151,7 @@ var _ = Describe("HTTP3 Server hotswap test", func() {
Expect(fake1.closed.Load()).To(BeTrue())
Expect(fake2.closed.Load()).To(BeFalse())
Expect(ln.listenerClosed).ToNot(BeTrue())
Expect(client.Transport.(*http3.RoundTripper).Close()).NotTo(HaveOccurred())
Expect(client.Transport.(*http3.Transport).Close()).NotTo(HaveOccurred())
// verify that new connections are being initiated from the second server now
resp, err = client.Get("https://localhost:" + port + "/hello2")

View File

@@ -46,7 +46,7 @@ var _ = Describe("HTTP tests", func() {
var (
mux *http.ServeMux
client *http.Client
rt *http3.RoundTripper
tr *http3.Transport
server *http3.Server
stoppedServing chan struct{}
port int
@@ -112,20 +112,20 @@ var _ = Describe("HTTP tests", func() {
})
AfterEach(func() {
Expect(rt.Close()).NotTo(HaveOccurred())
Expect(tr.Close()).NotTo(HaveOccurred())
Expect(server.Close()).NotTo(HaveOccurred())
Eventually(stoppedServing).Should(BeClosed())
})
BeforeEach(func() {
rt = &http3.RoundTripper{
tr = &http3.Transport{
TLSClientConfig: getTLSClientConfigWithoutServerName(),
QUICConfig: getQuicConfig(&quic.Config{
MaxIdleTimeout: 10 * time.Second,
}),
DisableCompression: true,
}
client = &http.Client{Transport: rt}
client = &http.Client{Transport: tr}
})
It("closes the connection after idle timeout", func() {
@@ -166,7 +166,7 @@ var _ = Describe("HTTP tests", func() {
var dialCounter int
testErr := errors.New("test error")
cl := http.Client{
Transport: &http3.RoundTripper{
Transport: &http3.Transport{
TLSClientConfig: getTLSClientConfig(),
Dial: func(ctx context.Context, addr string, tlsConf *tls.Config, conf *quic.Config) (quic.EarlyConnection, error) {
dialCounter++
@@ -355,7 +355,7 @@ var _ = Describe("HTTP tests", func() {
gw.Write([]byte("Hello, World!\n"))
})
client.Transport.(*http3.RoundTripper).DisableCompression = false
client.Transport.(*http3.Transport).DisableCompression = false
resp, err := client.Get(fmt.Sprintf("https://localhost:%d/gzipped/hello", port))
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(200))
@@ -484,8 +484,9 @@ var _ = Describe("HTTP tests", func() {
)
Expect(err).ToNot(HaveOccurred())
defer conn.CloseWithError(0, "")
rt := http3.SingleDestinationRoundTripper{Connection: conn}
str, err := rt.OpenRequestStream(context.Background())
tr := http3.Transport{}
cc := tr.NewClientConn(conn)
str, err := cc.OpenRequestStream(context.Background())
Expect(err).ToNot(HaveOccurred())
Expect(str.SendRequestHeader(req)).To(Succeed())
// make sure the request is received (and not stuck in some buffer, for example)
@@ -677,10 +678,10 @@ var _ = Describe("HTTP tests", func() {
)
Expect(err).ToNot(HaveOccurred())
defer conn.CloseWithError(0, "")
rt := http3.SingleDestinationRoundTripper{Connection: conn}
hconn := rt.Start()
Eventually(hconn.ReceivedSettings(), 5*time.Second, 10*time.Millisecond).Should(BeClosed())
settings := hconn.Settings()
var tr http3.Transport
cc := tr.NewClientConn(conn)
Eventually(cc.ReceivedSettings(), 5*time.Second, 10*time.Millisecond).Should(BeClosed())
settings := cc.Settings()
Expect(settings.EnableExtendedConnect).To(BeTrue())
Expect(settings.EnableDatagrams).To(BeFalse())
Expect(settings.Other).To(BeEmpty())
@@ -696,7 +697,7 @@ var _ = Describe("HTTP tests", func() {
w.WriteHeader(http.StatusOK)
})
rt = &http3.RoundTripper{
tr = &http3.Transport{
TLSClientConfig: getTLSClientConfigWithoutServerName(),
QUICConfig: getQuicConfig(&quic.Config{
MaxIdleTimeout: 10 * time.Second,
@@ -708,7 +709,7 @@ var _ = Describe("HTTP tests", func() {
req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("https://localhost:%d/settings", port), nil)
Expect(err).ToNot(HaveOccurred())
_, err = rt.RoundTrip(req)
_, err = tr.RoundTrip(req)
Expect(err).ToNot(HaveOccurred())
var settings *http3.Settings
Expect(settingsChan).To(Receive(&settings))
@@ -803,11 +804,9 @@ var _ = Describe("HTTP tests", func() {
)
Expect(err).ToNot(HaveOccurred())
rt := &http3.SingleDestinationRoundTripper{
Connection: conn,
EnableDatagrams: true,
}
str, err := rt.OpenRequestStream(context.Background())
tr := http3.Transport{EnableDatagrams: true}
cc := tr.NewClientConn(conn)
str, err := cc.OpenRequestStream(context.Background())
Expect(err).ToNot(HaveOccurred())
u, err := url.Parse(h)
Expect(err).ToNot(HaveOccurred())
@@ -981,21 +980,21 @@ var _ = Describe("HTTP tests", func() {
tlsConf := getTLSClientConfigWithoutServerName()
puts := make(chan string, 10)
tlsConf.ClientSessionCache = newClientSessionCache(tls.NewLRUClientSessionCache(10), nil, puts)
rt := &http3.RoundTripper{
tr := &http3.Transport{
TLSClientConfig: tlsConf,
QUICConfig: getQuicConfig(&quic.Config{
MaxIdleTimeout: 10 * time.Second,
}),
DisableCompression: true,
}
defer rt.Close()
defer tr.Close()
mux.HandleFunc("/0rtt", func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(strconv.FormatBool(!r.TLS.HandshakeComplete)))
})
req, err := http.NewRequest(http3.MethodGet0RTT, fmt.Sprintf("https://localhost:%d/0rtt", proxy.LocalPort()), nil)
Expect(err).ToNot(HaveOccurred())
rsp, err := rt.RoundTrip(req)
rsp, err := tr.RoundTrip(req)
Expect(err).ToNot(HaveOccurred())
Expect(rsp.StatusCode).To(BeEquivalentTo(200))
data, err := io.ReadAll(rsp.Body)
@@ -1004,13 +1003,13 @@ var _ = Describe("HTTP tests", func() {
Expect(num0RTTPackets.Load()).To(BeZero())
Eventually(puts).Should(Receive())
rt2 := &http3.RoundTripper{
TLSClientConfig: rt.TLSClientConfig,
QUICConfig: rt.QUICConfig,
tr2 := &http3.Transport{
TLSClientConfig: tr.TLSClientConfig,
QUICConfig: tr.QUICConfig,
DisableCompression: true,
}
defer rt2.Close()
rsp, err = rt2.RoundTrip(req)
defer tr2.Close()
rsp, err = tr2.RoundTrip(req)
Expect(err).ToNot(HaveOccurred())
Expect(rsp.StatusCode).To(BeEquivalentTo(200))
data, err = io.ReadAll(rsp.Body)
@@ -1111,7 +1110,7 @@ var _ = Describe("HTTP tests", func() {
Expect(err).ToNot(HaveOccurred())
Expect(body).To(Equal([]byte("shutdown")))
// manually close the client, since we don't support
client.Transport.(*http3.RoundTripper).Close()
client.Transport.(*http3.Transport).Close()
// make sure that CloseGracefully returned
Eventually(done).Should(BeClosed())

View File

@@ -67,7 +67,7 @@ func runTestcase(testcase string) error {
quicConf := &quic.Config{Tracer: utils.NewQLOGConnectionTracer}
if testcase == "http3" {
r := &http3.RoundTripper{
r := &http3.Transport{
TLSClientConfig: tlsConf,
QUICConfig: quicConf,
}