forked from quic-go/quic-go
http3: improve the client API (#4693)
* http3: rename RoundTripper to Transport * http3: rename SingleDestinationRoundTripper to ClientConn * http3: construct the ClientConn via Transport.NewClientConn
This commit is contained in:
@@ -40,7 +40,7 @@ func main() {
|
||||
}
|
||||
testdata.AddRootCA(pool)
|
||||
|
||||
roundTripper := &http3.RoundTripper{
|
||||
roundTripper := &http3.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
RootCAs: pool,
|
||||
InsecureSkipVerify: *insecure,
|
||||
|
||||
153
http3/client.go
153
http3/client.go
@@ -9,7 +9,6 @@ import (
|
||||
"net/http"
|
||||
"net/http/httptrace"
|
||||
"net/textproto"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
@@ -38,102 +37,117 @@ var defaultQuicConfig = &quic.Config{
|
||||
KeepAlivePeriod: 10 * time.Second,
|
||||
}
|
||||
|
||||
// SingleDestinationRoundTripper is an HTTP/3 client doing requests to a single remote server.
|
||||
type SingleDestinationRoundTripper struct {
|
||||
Connection quic.Connection
|
||||
// ClientConn is an HTTP/3 client doing requests to a single remote server.
|
||||
type ClientConn struct {
|
||||
connection
|
||||
|
||||
// Enable support for HTTP/3 datagrams (RFC 9297).
|
||||
// If a QUICConfig is set, datagram support also needs to be enabled on the QUIC layer by setting EnableDatagrams.
|
||||
EnableDatagrams bool
|
||||
// If a QUICConfig is set, datagram support also needs to be enabled on the QUIC layer by setting enableDatagrams.
|
||||
enableDatagrams bool
|
||||
|
||||
// Additional HTTP/3 settings.
|
||||
// It is invalid to specify any settings defined by RFC 9114 (HTTP/3) and RFC 9297 (HTTP Datagrams).
|
||||
AdditionalSettings map[uint64]uint64
|
||||
StreamHijacker func(FrameType, quic.ConnectionTracingID, quic.Stream, error) (hijacked bool, err error)
|
||||
UniStreamHijacker func(StreamType, quic.ConnectionTracingID, quic.ReceiveStream, error) (hijacked bool)
|
||||
additionalSettings map[uint64]uint64
|
||||
|
||||
// MaxResponseHeaderBytes specifies a limit on how many response bytes are
|
||||
// maxResponseHeaderBytes specifies a limit on how many response bytes are
|
||||
// allowed in the server's response header.
|
||||
// Zero means to use a default limit.
|
||||
MaxResponseHeaderBytes int64
|
||||
maxResponseHeaderBytes uint64
|
||||
|
||||
// DisableCompression, if true, prevents the Transport from requesting compression with an
|
||||
// disableCompression, if true, prevents the Transport from requesting compression with an
|
||||
// "Accept-Encoding: gzip" request header when the Request contains no existing Accept-Encoding value.
|
||||
// If the Transport requests gzip on its own and gets a gzipped response, it's transparently
|
||||
// decoded in the Response.Body.
|
||||
// However, if the user explicitly requested gzip it is not automatically uncompressed.
|
||||
DisableCompression bool
|
||||
disableCompression bool
|
||||
|
||||
Logger *slog.Logger
|
||||
logger *slog.Logger
|
||||
|
||||
initOnce sync.Once
|
||||
hconn *connection
|
||||
requestWriter *requestWriter
|
||||
decoder *qpack.Decoder
|
||||
}
|
||||
|
||||
var _ http.RoundTripper = &SingleDestinationRoundTripper{}
|
||||
var _ http.RoundTripper = &ClientConn{}
|
||||
|
||||
func (c *SingleDestinationRoundTripper) Start() Connection {
|
||||
c.initOnce.Do(func() { c.init() })
|
||||
return c.hconn
|
||||
}
|
||||
|
||||
func (c *SingleDestinationRoundTripper) init() {
|
||||
func newClientConn(
|
||||
conn quic.Connection,
|
||||
enableDatagrams bool,
|
||||
additionalSettings map[uint64]uint64,
|
||||
streamHijacker func(FrameType, quic.ConnectionTracingID, quic.Stream, error) (hijacked bool, err error),
|
||||
uniStreamHijacker func(StreamType, quic.ConnectionTracingID, quic.ReceiveStream, error) (hijacked bool),
|
||||
maxResponseHeaderBytes int64,
|
||||
disableCompression bool,
|
||||
logger *slog.Logger,
|
||||
) *ClientConn {
|
||||
c := &ClientConn{
|
||||
enableDatagrams: enableDatagrams,
|
||||
additionalSettings: additionalSettings,
|
||||
disableCompression: disableCompression,
|
||||
logger: logger,
|
||||
}
|
||||
if maxResponseHeaderBytes <= 0 {
|
||||
c.maxResponseHeaderBytes = defaultMaxResponseHeaderBytes
|
||||
} else {
|
||||
c.maxResponseHeaderBytes = uint64(maxResponseHeaderBytes)
|
||||
}
|
||||
c.decoder = qpack.NewDecoder(func(hf qpack.HeaderField) {})
|
||||
c.requestWriter = newRequestWriter()
|
||||
c.hconn = newConnection(
|
||||
c.Connection.Context(),
|
||||
c.Connection,
|
||||
c.EnableDatagrams,
|
||||
c.connection = *newConnection(
|
||||
conn.Context(),
|
||||
conn,
|
||||
c.enableDatagrams,
|
||||
protocol.PerspectiveClient,
|
||||
c.Logger,
|
||||
c.logger,
|
||||
0,
|
||||
)
|
||||
// send the SETTINGs frame, using 0-RTT data, if possible
|
||||
go func() {
|
||||
if err := c.setupConn(c.hconn); err != nil {
|
||||
if c.Logger != nil {
|
||||
c.Logger.Debug("Setting up connection failed", "error", err)
|
||||
if err := c.setupConn(); err != nil {
|
||||
if c.logger != nil {
|
||||
c.logger.Debug("Setting up connection failed", "error", err)
|
||||
}
|
||||
c.hconn.CloseWithError(quic.ApplicationErrorCode(ErrCodeInternalError), "")
|
||||
c.connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeInternalError), "")
|
||||
}
|
||||
}()
|
||||
if c.StreamHijacker != nil {
|
||||
go c.handleBidirectionalStreams()
|
||||
if streamHijacker != nil {
|
||||
go c.handleBidirectionalStreams(streamHijacker)
|
||||
}
|
||||
go c.hconn.HandleUnidirectionalStreams(c.UniStreamHijacker)
|
||||
go c.connection.HandleUnidirectionalStreams(uniStreamHijacker)
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *SingleDestinationRoundTripper) setupConn(conn *connection) error {
|
||||
func (c *ClientConn) OpenRequestStream(ctx context.Context) (RequestStream, error) {
|
||||
return c.connection.openRequestStream(ctx, c.requestWriter, nil, c.disableCompression, c.maxResponseHeaderBytes)
|
||||
}
|
||||
|
||||
func (c *ClientConn) setupConn() error {
|
||||
// open the control stream
|
||||
str, err := conn.OpenUniStream()
|
||||
str, err := c.connection.OpenUniStream()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
b := make([]byte, 0, 64)
|
||||
b = quicvarint.Append(b, streamTypeControlStream)
|
||||
// send the SETTINGS frame
|
||||
b = (&settingsFrame{Datagram: c.EnableDatagrams, Other: c.AdditionalSettings}).Append(b)
|
||||
b = (&settingsFrame{Datagram: c.enableDatagrams, Other: c.additionalSettings}).Append(b)
|
||||
_, err = str.Write(b)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *SingleDestinationRoundTripper) handleBidirectionalStreams() {
|
||||
func (c *ClientConn) handleBidirectionalStreams(streamHijacker func(FrameType, quic.ConnectionTracingID, quic.Stream, error) (hijacked bool, err error)) {
|
||||
for {
|
||||
str, err := c.hconn.AcceptStream(context.Background())
|
||||
str, err := c.connection.AcceptStream(context.Background())
|
||||
if err != nil {
|
||||
if c.Logger != nil {
|
||||
c.Logger.Debug("accepting bidirectional stream failed", "error", err)
|
||||
if c.logger != nil {
|
||||
c.logger.Debug("accepting bidirectional stream failed", "error", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
fp := &frameParser{
|
||||
r: str,
|
||||
conn: c.hconn,
|
||||
conn: &c.connection,
|
||||
unknownFrameHandler: func(ft FrameType, e error) (processed bool, err error) {
|
||||
id := c.hconn.Context().Value(quic.ConnectionTracingKey).(quic.ConnectionTracingID)
|
||||
return c.StreamHijacker(ft, id, str, e)
|
||||
id := c.connection.Context().Value(quic.ConnectionTracingKey).(quic.ConnectionTracingID)
|
||||
return streamHijacker(ft, id, str, e)
|
||||
},
|
||||
}
|
||||
go func() {
|
||||
@@ -141,26 +155,17 @@ func (c *SingleDestinationRoundTripper) handleBidirectionalStreams() {
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
if c.Logger != nil {
|
||||
c.Logger.Debug("error handling stream", "error", err)
|
||||
if c.logger != nil {
|
||||
c.logger.Debug("error handling stream", "error", err)
|
||||
}
|
||||
}
|
||||
c.hconn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "received HTTP/3 frame on bidirectional stream")
|
||||
c.connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "received HTTP/3 frame on bidirectional stream")
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *SingleDestinationRoundTripper) maxHeaderBytes() uint64 {
|
||||
if c.MaxResponseHeaderBytes <= 0 {
|
||||
return defaultMaxResponseHeaderBytes
|
||||
}
|
||||
return uint64(c.MaxResponseHeaderBytes)
|
||||
}
|
||||
|
||||
// RoundTrip executes a request and returns a response
|
||||
func (c *SingleDestinationRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
c.initOnce.Do(func() { c.init() })
|
||||
|
||||
func (c *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
rsp, err := c.roundTrip(req)
|
||||
if err != nil && req.Context().Err() != nil {
|
||||
// if the context was canceled, return the context cancellation error
|
||||
@@ -169,7 +174,7 @@ func (c *SingleDestinationRoundTripper) RoundTrip(req *http.Request) (*http.Resp
|
||||
return rsp, err
|
||||
}
|
||||
|
||||
func (c *SingleDestinationRoundTripper) roundTrip(req *http.Request) (*http.Response, error) {
|
||||
func (c *ClientConn) roundTrip(req *http.Request) (*http.Response, error) {
|
||||
// Immediately send out this request, if this is a 0-RTT request.
|
||||
switch req.Method {
|
||||
case MethodGet0RTT:
|
||||
@@ -200,17 +205,23 @@ func (c *SingleDestinationRoundTripper) roundTrip(req *http.Request) (*http.Resp
|
||||
connCtx := c.Connection.Context()
|
||||
// wait for the server's SETTINGS frame to arrive
|
||||
select {
|
||||
case <-c.hconn.ReceivedSettings():
|
||||
case <-c.connection.ReceivedSettings():
|
||||
case <-connCtx.Done():
|
||||
return nil, context.Cause(connCtx)
|
||||
}
|
||||
if !c.hconn.Settings().EnableExtendedConnect {
|
||||
if !c.connection.Settings().EnableExtendedConnect {
|
||||
return nil, errors.New("http3: server didn't enable Extended CONNECT")
|
||||
}
|
||||
}
|
||||
|
||||
reqDone := make(chan struct{})
|
||||
str, err := c.hconn.openRequestStream(req.Context(), c.requestWriter, reqDone, c.DisableCompression, c.maxHeaderBytes())
|
||||
str, err := c.connection.openRequestStream(
|
||||
req.Context(),
|
||||
c.requestWriter,
|
||||
reqDone,
|
||||
c.disableCompression,
|
||||
c.maxResponseHeaderBytes,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -238,12 +249,6 @@ func (c *SingleDestinationRoundTripper) roundTrip(req *http.Request) (*http.Resp
|
||||
return rsp, maybeReplaceError(err)
|
||||
}
|
||||
|
||||
func (c *SingleDestinationRoundTripper) OpenRequestStream(ctx context.Context) (RequestStream, error) {
|
||||
c.initOnce.Do(func() { c.init() })
|
||||
|
||||
return c.hconn.openRequestStream(ctx, c.requestWriter, nil, c.DisableCompression, c.maxHeaderBytes())
|
||||
}
|
||||
|
||||
// cancelingReader reads from the io.Reader.
|
||||
// It cancels writing on the stream if any error other than io.EOF occurs.
|
||||
type cancelingReader struct {
|
||||
@@ -259,7 +264,7 @@ func (r *cancelingReader) Read(b []byte) (int, error) {
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (c *SingleDestinationRoundTripper) sendRequestBody(str Stream, body io.ReadCloser, contentLength int64) error {
|
||||
func (c *ClientConn) sendRequestBody(str Stream, body io.ReadCloser, contentLength int64) error {
|
||||
defer body.Close()
|
||||
buf := make([]byte, bodyCopyBufferSize)
|
||||
sr := &cancelingReader{str: str, r: body}
|
||||
@@ -283,7 +288,7 @@ func (c *SingleDestinationRoundTripper) sendRequestBody(str Stream, body io.Read
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *SingleDestinationRoundTripper) doRequest(req *http.Request, str *requestStream) (*http.Response, error) {
|
||||
func (c *ClientConn) doRequest(req *http.Request, str *requestStream) (*http.Response, error) {
|
||||
if err := str.SendRequestHeader(req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -299,8 +304,8 @@ func (c *SingleDestinationRoundTripper) doRequest(req *http.Request, str *reques
|
||||
contentLength = req.ContentLength
|
||||
}
|
||||
if err := c.sendRequestBody(str, req.Body, contentLength); err != nil {
|
||||
if c.Logger != nil {
|
||||
c.Logger.Debug("error writing request", "error", err)
|
||||
if c.logger != nil {
|
||||
c.logger.Debug("error writing request", "error", err)
|
||||
}
|
||||
}
|
||||
str.Close()
|
||||
@@ -337,7 +342,7 @@ func (c *SingleDestinationRoundTripper) doRequest(req *http.Request, str *reques
|
||||
}
|
||||
break
|
||||
}
|
||||
connState := c.hconn.ConnectionState().TLS
|
||||
connState := c.connection.ConnectionState().TLS
|
||||
res.TLS = &connState
|
||||
res.Request = req
|
||||
return res, nil
|
||||
|
||||
@@ -81,8 +81,7 @@ var _ = Describe("Client", func() {
|
||||
It("hijacks a bidirectional stream of unknown frame type", func() {
|
||||
id := quic.ConnectionTracingID(1234)
|
||||
frameTypeChan := make(chan FrameType, 1)
|
||||
rt := &SingleDestinationRoundTripper{
|
||||
Connection: conn,
|
||||
tr := &Transport{
|
||||
StreamHijacker: func(ft FrameType, connTracingID quic.ConnectionTracingID, _ quic.Stream, e error) (hijacked bool, err error) {
|
||||
Expect(e).ToNot(HaveOccurred())
|
||||
Expect(connTracingID).To(Equal(id))
|
||||
@@ -101,7 +100,8 @@ var _ = Describe("Client", func() {
|
||||
})
|
||||
ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, id)
|
||||
conn.EXPECT().Context().Return(ctx).AnyTimes()
|
||||
_, err := rt.RoundTrip(request)
|
||||
cc := tr.NewClientConn(conn)
|
||||
_, err := cc.RoundTrip(request)
|
||||
Expect(err).To(MatchError("done"))
|
||||
Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41)))
|
||||
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
|
||||
@@ -109,8 +109,7 @@ var _ = Describe("Client", func() {
|
||||
|
||||
It("closes the connection when hijacker didn't hijack a bidirectional stream", func() {
|
||||
frameTypeChan := make(chan FrameType, 1)
|
||||
rt := &SingleDestinationRoundTripper{
|
||||
Connection: conn,
|
||||
tr := &Transport{
|
||||
StreamHijacker: func(ft FrameType, _ quic.ConnectionTracingID, _ quic.Stream, e error) (hijacked bool, err error) {
|
||||
Expect(e).ToNot(HaveOccurred())
|
||||
frameTypeChan <- ft
|
||||
@@ -129,15 +128,15 @@ var _ = Describe("Client", func() {
|
||||
ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1234))
|
||||
conn.EXPECT().Context().Return(ctx).AnyTimes()
|
||||
conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any()).Return(nil).AnyTimes()
|
||||
_, err := rt.RoundTrip(request)
|
||||
cc := tr.NewClientConn(conn)
|
||||
_, err := cc.RoundTrip(request)
|
||||
Expect(err).To(MatchError("done"))
|
||||
Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41)))
|
||||
})
|
||||
|
||||
It("closes the connection when hijacker returned error", func() {
|
||||
frameTypeChan := make(chan FrameType, 1)
|
||||
rt := &SingleDestinationRoundTripper{
|
||||
Connection: conn,
|
||||
tr := &Transport{
|
||||
StreamHijacker: func(ft FrameType, _ quic.ConnectionTracingID, _ quic.Stream, e error) (hijacked bool, err error) {
|
||||
Expect(e).ToNot(HaveOccurred())
|
||||
frameTypeChan <- ft
|
||||
@@ -156,7 +155,8 @@ var _ = Describe("Client", func() {
|
||||
ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1234))
|
||||
conn.EXPECT().Context().Return(ctx).AnyTimes()
|
||||
conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any()).Return(nil).AnyTimes()
|
||||
_, err := rt.RoundTrip(request)
|
||||
cc := tr.NewClientConn(conn)
|
||||
_, err := cc.RoundTrip(request)
|
||||
Expect(err).To(MatchError("done"))
|
||||
Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41)))
|
||||
})
|
||||
@@ -165,8 +165,7 @@ var _ = Describe("Client", func() {
|
||||
testErr := errors.New("test error")
|
||||
unknownStr := mockquic.NewMockStream(mockCtrl)
|
||||
done := make(chan struct{})
|
||||
rt := &SingleDestinationRoundTripper{
|
||||
Connection: conn,
|
||||
tr := &Transport{
|
||||
StreamHijacker: func(ft FrameType, _ quic.ConnectionTracingID, str quic.Stream, e error) (hijacked bool, err error) {
|
||||
defer close(done)
|
||||
Expect(e).To(MatchError(testErr))
|
||||
@@ -185,7 +184,8 @@ var _ = Describe("Client", func() {
|
||||
ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1234))
|
||||
conn.EXPECT().Context().Return(ctx).AnyTimes()
|
||||
conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any()).Return(nil).AnyTimes()
|
||||
_, err := rt.RoundTrip(request)
|
||||
cc := tr.NewClientConn(conn)
|
||||
_, err := cc.RoundTrip(request)
|
||||
Expect(err).To(MatchError("done"))
|
||||
Eventually(done).Should(BeClosed())
|
||||
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
|
||||
@@ -226,8 +226,7 @@ var _ = Describe("Client", func() {
|
||||
It("hijacks an unidirectional stream of unknown stream type", func() {
|
||||
id := quic.ConnectionTracingID(100)
|
||||
streamTypeChan := make(chan StreamType, 1)
|
||||
rt := &SingleDestinationRoundTripper{
|
||||
Connection: conn,
|
||||
tr := &Transport{
|
||||
UniStreamHijacker: func(st StreamType, connTracingID quic.ConnectionTracingID, _ quic.ReceiveStream, err error) bool {
|
||||
Expect(connTracingID).To(Equal(id))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
@@ -235,7 +234,6 @@ var _ = Describe("Client", func() {
|
||||
return true
|
||||
},
|
||||
}
|
||||
|
||||
buf := bytes.NewBuffer(quicvarint.Append(nil, 0x54))
|
||||
unknownStr := mockquic.NewMockStream(mockCtrl)
|
||||
unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
|
||||
@@ -248,7 +246,8 @@ var _ = Describe("Client", func() {
|
||||
})
|
||||
ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, id)
|
||||
conn.EXPECT().Context().Return(ctx).AnyTimes()
|
||||
_, err := rt.RoundTrip(req)
|
||||
cc := tr.NewClientConn(conn)
|
||||
_, err := cc.RoundTrip(req)
|
||||
Expect(err).To(MatchError("done"))
|
||||
Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54)))
|
||||
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
|
||||
@@ -258,8 +257,7 @@ var _ = Describe("Client", func() {
|
||||
testErr := errors.New("test error")
|
||||
done := make(chan struct{})
|
||||
unknownStr := mockquic.NewMockStream(mockCtrl)
|
||||
rt := &SingleDestinationRoundTripper{
|
||||
Connection: conn,
|
||||
tr := &Transport{
|
||||
UniStreamHijacker: func(st StreamType, _ quic.ConnectionTracingID, str quic.ReceiveStream, err error) bool {
|
||||
defer close(done)
|
||||
Expect(st).To(BeZero())
|
||||
@@ -268,7 +266,6 @@ var _ = Describe("Client", func() {
|
||||
return true
|
||||
},
|
||||
}
|
||||
|
||||
unknownStr.EXPECT().Read(gomock.Any()).Return(0, testErr)
|
||||
conn.EXPECT().AcceptUniStream(gomock.Any()).Return(unknownStr, nil)
|
||||
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
|
||||
@@ -277,7 +274,8 @@ var _ = Describe("Client", func() {
|
||||
})
|
||||
ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1234))
|
||||
conn.EXPECT().Context().Return(ctx).AnyTimes()
|
||||
_, err := rt.RoundTrip(req)
|
||||
cc := tr.NewClientConn(conn)
|
||||
_, err := cc.RoundTrip(req)
|
||||
Expect(err).To(MatchError("done"))
|
||||
Eventually(done).Should(BeClosed())
|
||||
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
|
||||
@@ -285,15 +283,13 @@ var _ = Describe("Client", func() {
|
||||
|
||||
It("cancels reading when hijacker didn't hijack an unidirectional stream", func() {
|
||||
streamTypeChan := make(chan StreamType, 1)
|
||||
rt := &SingleDestinationRoundTripper{
|
||||
Connection: conn,
|
||||
tr := &Transport{
|
||||
UniStreamHijacker: func(st StreamType, _ quic.ConnectionTracingID, _ quic.ReceiveStream, err error) bool {
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
streamTypeChan <- st
|
||||
return false
|
||||
},
|
||||
}
|
||||
|
||||
buf := bytes.NewBuffer(quicvarint.Append(nil, 0x54))
|
||||
unknownStr := mockquic.NewMockStream(mockCtrl)
|
||||
unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
|
||||
@@ -307,7 +303,8 @@ var _ = Describe("Client", func() {
|
||||
})
|
||||
ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1234))
|
||||
conn.EXPECT().Context().Return(ctx).AnyTimes()
|
||||
_, err := rt.RoundTrip(req)
|
||||
cc := tr.NewClientConn(conn)
|
||||
_, err := cc.RoundTrip(req)
|
||||
Expect(err).To(MatchError("done"))
|
||||
Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54)))
|
||||
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
|
||||
@@ -337,13 +334,13 @@ var _ = Describe("Client", func() {
|
||||
return nil, errors.New("test done")
|
||||
}).AnyTimes()
|
||||
conn.EXPECT().HandshakeComplete().Return(handshakeChan)
|
||||
rt := &SingleDestinationRoundTripper{
|
||||
Connection: conn,
|
||||
tr := &Transport{
|
||||
EnableDatagrams: true,
|
||||
}
|
||||
cc := tr.NewClientConn(conn)
|
||||
req, err := http.NewRequest(http.MethodGet, "https://quic-go.net", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = rt.RoundTrip(req)
|
||||
_, err = cc.RoundTrip(req)
|
||||
Expect(err).To(MatchError("test done"))
|
||||
t, err := quicvarint.Read(&buf)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
@@ -373,10 +370,10 @@ var _ = Describe("Client", func() {
|
||||
return nil, errors.New("test done")
|
||||
})
|
||||
|
||||
rt := &SingleDestinationRoundTripper{Connection: conn}
|
||||
hconn := rt.Start()
|
||||
Eventually(hconn.ReceivedSettings()).Should(BeClosed())
|
||||
settings := hconn.Settings()
|
||||
tr := &Transport{}
|
||||
cc := tr.NewClientConn(conn)
|
||||
Eventually(cc.ReceivedSettings()).Should(BeClosed())
|
||||
settings := cc.Settings()
|
||||
Expect(settings.EnableExtendedConnect).To(BeTrue())
|
||||
// test shutdown
|
||||
conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).MaxTimes(1)
|
||||
@@ -410,8 +407,9 @@ var _ = Describe("Client", func() {
|
||||
conn.EXPECT().Context().Return(context.Background())
|
||||
conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("test error"))
|
||||
|
||||
rt := &SingleDestinationRoundTripper{Connection: conn}
|
||||
_, err := rt.RoundTrip(&http.Request{
|
||||
tr := &Transport{}
|
||||
cc := tr.NewClientConn(conn)
|
||||
_, err := cc.RoundTrip(&http.Request{
|
||||
Method: http.MethodConnect,
|
||||
Proto: "connect",
|
||||
Host: "localhost",
|
||||
@@ -450,8 +448,9 @@ var _ = Describe("Client", func() {
|
||||
conn.EXPECT().HandshakeComplete().Return(handshakeChan)
|
||||
conn.EXPECT().Context().Return(context.Background())
|
||||
|
||||
rt := &SingleDestinationRoundTripper{Connection: conn}
|
||||
_, err := rt.RoundTrip(&http.Request{
|
||||
tr := &Transport{}
|
||||
cc := tr.NewClientConn(conn)
|
||||
_, err := cc.RoundTrip(&http.Request{
|
||||
Method: http.MethodConnect,
|
||||
Proto: "connect",
|
||||
Host: "localhost",
|
||||
@@ -470,7 +469,6 @@ var _ = Describe("Client", func() {
|
||||
req *http.Request
|
||||
str *mockquic.MockStream
|
||||
conn *mockquic.MockEarlyConnection
|
||||
cl *SingleDestinationRoundTripper
|
||||
settingsFrameWritten chan struct{}
|
||||
)
|
||||
testDone := make(chan struct{})
|
||||
@@ -517,7 +515,6 @@ var _ = Describe("Client", func() {
|
||||
<-testDone
|
||||
return nil, errors.New("test done")
|
||||
})
|
||||
cl = &SingleDestinationRoundTripper{Connection: conn}
|
||||
var err error
|
||||
req, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
@@ -533,7 +530,9 @@ var _ = Describe("Client", func() {
|
||||
conn.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr)
|
||||
conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).MaxTimes(1)
|
||||
conn.EXPECT().HandshakeComplete().Return(handshakeChan)
|
||||
_, err := cl.RoundTrip(req)
|
||||
tr := &Transport{}
|
||||
cc := tr.NewClientConn(conn)
|
||||
_, err := cc.RoundTrip(req)
|
||||
Expect(err).To(MatchError(testErr))
|
||||
})
|
||||
|
||||
@@ -552,7 +551,9 @@ var _ = Describe("Client", func() {
|
||||
str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) {
|
||||
return 0, testErr
|
||||
})
|
||||
_, err := cl.RoundTrip(req)
|
||||
tr := &Transport{}
|
||||
cc := tr.NewClientConn(conn)
|
||||
_, err := cc.RoundTrip(req)
|
||||
Expect(err).To(MatchError(testErr))
|
||||
Expect(decodeHeader(buf)).To(HaveKeyWithValue(":method", serialized))
|
||||
// make sure the request wasn't modified
|
||||
@@ -572,7 +573,9 @@ var _ = Describe("Client", func() {
|
||||
str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
|
||||
str.EXPECT().Close()
|
||||
str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes()
|
||||
rsp, err := cl.RoundTrip(req)
|
||||
tr := &Transport{}
|
||||
cc := tr.NewClientConn(conn)
|
||||
rsp, err := cc.RoundTrip(req)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(rsp.Proto).To(Equal("HTTP/3.0"))
|
||||
Expect(rsp.ProtoMajor).To(Equal(3))
|
||||
@@ -599,7 +602,9 @@ var _ = Describe("Client", func() {
|
||||
str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
|
||||
str.EXPECT().Close()
|
||||
str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes()
|
||||
rsp, err := cl.RoundTrip(req)
|
||||
tr := &Transport{}
|
||||
cc := tr.NewClientConn(conn)
|
||||
rsp, err := cc.RoundTrip(req)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = io.ReadAll(rsp.Body)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
@@ -625,7 +630,9 @@ var _ = Describe("Client", func() {
|
||||
str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
|
||||
str.EXPECT().Close()
|
||||
str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes()
|
||||
rsp, err := cl.RoundTrip(req)
|
||||
tr := &Transport{}
|
||||
cc := tr.NewClientConn(conn)
|
||||
rsp, err := cc.RoundTrip(req)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = io.ReadAll(rsp.Body)
|
||||
Expect(err).To(HaveOccurred())
|
||||
@@ -662,7 +669,9 @@ var _ = Describe("Client", func() {
|
||||
str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
|
||||
str.EXPECT().Close()
|
||||
str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes()
|
||||
rsp, err := cl.RoundTrip(req)
|
||||
tr := &Transport{}
|
||||
cc := tr.NewClientConn(conn)
|
||||
rsp, err := cc.RoundTrip(req)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = io.ReadAll(rsp.Body)
|
||||
Expect(err).To(MatchError(errors.New("additional HEADERS frame received after trailers")))
|
||||
@@ -702,7 +711,9 @@ var _ = Describe("Client", func() {
|
||||
str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
|
||||
str.EXPECT().Close()
|
||||
str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes()
|
||||
rsp, err := cl.RoundTrip(req)
|
||||
tr := &Transport{}
|
||||
cc := tr.NewClientConn(conn)
|
||||
rsp, err := cc.RoundTrip(req)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = io.ReadAll(rsp.Body)
|
||||
Expect(err).To(MatchError(errors.New("DATA frame received after trailers")))
|
||||
@@ -744,7 +755,9 @@ var _ = Describe("Client", func() {
|
||||
<-done
|
||||
return 0, testErr
|
||||
})
|
||||
_, err := cl.RoundTrip(req)
|
||||
tr := &Transport{}
|
||||
cc := tr.NewClientConn(conn)
|
||||
_, err := cc.RoundTrip(req)
|
||||
Expect(err).To(MatchError(testErr))
|
||||
hfs := decodeHeader(strBuf)
|
||||
Expect(hfs).To(HaveKeyWithValue(":method", "POST"))
|
||||
@@ -770,7 +783,10 @@ var _ = Describe("Client", func() {
|
||||
<-done
|
||||
return 0, errors.New("done")
|
||||
})
|
||||
cl.RoundTrip(req)
|
||||
tr := &Transport{}
|
||||
cc := tr.NewClientConn(conn)
|
||||
_, err := cc.RoundTrip(req)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(strBuf.String()).To(ContainSubstring("request"))
|
||||
Expect(strBuf.String()).ToNot(ContainSubstring("request body"))
|
||||
})
|
||||
@@ -794,7 +810,9 @@ var _ = Describe("Client", func() {
|
||||
})
|
||||
closed := make(chan struct{})
|
||||
str.EXPECT().Close().Do(func() error { close(closed); return nil })
|
||||
_, err := cl.RoundTrip(req)
|
||||
tr := &Transport{}
|
||||
cc := tr.NewClientConn(conn)
|
||||
_, err := cc.RoundTrip(req)
|
||||
Expect(err).To(MatchError(testErr))
|
||||
Eventually(closed).Should(BeClosed())
|
||||
})
|
||||
@@ -806,7 +824,9 @@ var _ = Describe("Client", func() {
|
||||
r := bytes.NewReader(b)
|
||||
str.EXPECT().Close().Do(func() error { close(closed); return nil })
|
||||
str.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
|
||||
_, err := cl.RoundTrip(req)
|
||||
tr := &Transport{}
|
||||
cc := tr.NewClientConn(conn)
|
||||
_, err := cc.RoundTrip(req)
|
||||
Expect(err).To(MatchError("http3: expected first frame to be a HEADERS frame"))
|
||||
Eventually(closed).Should(BeClosed())
|
||||
})
|
||||
@@ -825,13 +845,16 @@ var _ = Describe("Client", func() {
|
||||
closed := make(chan struct{})
|
||||
str.EXPECT().Close().Do(func() error { close(closed); return nil })
|
||||
str.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
|
||||
_, err := cl.RoundTrip(req)
|
||||
tr := &Transport{}
|
||||
cc := tr.NewClientConn(conn)
|
||||
_, err := cc.RoundTrip(req)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Eventually(closed).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("cancels the stream when the HEADERS frame is too large", func() {
|
||||
cl.MaxResponseHeaderBytes = 1337
|
||||
tr := &Transport{MaxResponseHeaderBytes: 1337}
|
||||
cc := tr.NewClientConn(conn)
|
||||
b := (&headersFrame{Length: 1338}).Append(nil)
|
||||
r := bytes.NewReader(b)
|
||||
str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeFrameError))
|
||||
@@ -839,14 +862,16 @@ var _ = Describe("Client", func() {
|
||||
closed := make(chan struct{})
|
||||
str.EXPECT().Close().Do(func() error { close(closed); return nil })
|
||||
str.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
|
||||
_, err := cl.RoundTrip(req)
|
||||
_, err := cc.RoundTrip(req)
|
||||
Expect(err).To(MatchError("http3: HEADERS frame too large: 1338 bytes (max: 1337)"))
|
||||
Eventually(closed).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("opens a request stream", func() {
|
||||
cl.Connection.(quic.EarlyConnection).HandshakeComplete()
|
||||
str, err := cl.OpenRequestStream(context.Background())
|
||||
tr := &Transport{}
|
||||
cc := tr.NewClientConn(conn)
|
||||
conn.HandshakeComplete()
|
||||
str, err := cc.OpenRequestStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str.SendRequestHeader(req)).To(Succeed())
|
||||
str.Write([]byte("foobar"))
|
||||
@@ -863,9 +888,11 @@ var _ = Describe("Client", func() {
|
||||
req := req.WithContext(ctx)
|
||||
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
|
||||
|
||||
tr := &Transport{}
|
||||
cc := tr.NewClientConn(conn)
|
||||
errChan := make(chan error)
|
||||
go func() {
|
||||
_, err := cl.RoundTrip(req)
|
||||
_, err := cc.RoundTrip(req)
|
||||
errChan <- err
|
||||
}()
|
||||
Consistently(errChan).ShouldNot(Receive())
|
||||
@@ -895,7 +922,9 @@ var _ = Describe("Client", func() {
|
||||
<-canceled
|
||||
return 0, errors.New("test done")
|
||||
})
|
||||
_, err := cl.RoundTrip(req)
|
||||
tr := &Transport{}
|
||||
cc := tr.NewClientConn(conn)
|
||||
_, err := cc.RoundTrip(req)
|
||||
Expect(err).To(MatchError(context.Canceled))
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
@@ -916,7 +945,9 @@ var _ = Describe("Client", func() {
|
||||
str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes()
|
||||
str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled))
|
||||
str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeRequestCanceled)).Do(func(quic.StreamErrorCode) { close(done) })
|
||||
_, err := cl.RoundTrip(req)
|
||||
tr := &Transport{}
|
||||
cc := tr.NewClientConn(conn)
|
||||
_, err := cc.RoundTrip(req)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
cancel()
|
||||
Eventually(done).Should(BeClosed())
|
||||
@@ -940,17 +971,17 @@ var _ = Describe("Client", func() {
|
||||
)
|
||||
testErr := errors.New("test done")
|
||||
str.EXPECT().Read(gomock.Any()).Return(0, testErr)
|
||||
_, err := cl.RoundTrip(req)
|
||||
tr := &Transport{}
|
||||
cc := tr.NewClientConn(conn)
|
||||
_, err := cc.RoundTrip(req)
|
||||
Expect(err).To(MatchError(testErr))
|
||||
hfs := decodeHeader(buf)
|
||||
Expect(hfs).To(HaveKeyWithValue("accept-encoding", "gzip"))
|
||||
})
|
||||
|
||||
It("doesn't add gzip if the header disable it", func() {
|
||||
client := &SingleDestinationRoundTripper{
|
||||
Connection: conn,
|
||||
DisableCompression: true,
|
||||
}
|
||||
tr := &Transport{DisableCompression: true}
|
||||
client := tr.NewClientConn(conn)
|
||||
conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
|
||||
buf := &bytes.Buffer{}
|
||||
str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write)
|
||||
@@ -985,7 +1016,9 @@ var _ = Describe("Client", func() {
|
||||
str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
|
||||
str.EXPECT().Close()
|
||||
|
||||
rsp, err := cl.RoundTrip(req)
|
||||
tr := &Transport{}
|
||||
cc := tr.NewClientConn(conn)
|
||||
rsp, err := cc.RoundTrip(req)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
data, err := io.ReadAll(rsp.Body)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
@@ -1009,7 +1042,9 @@ var _ = Describe("Client", func() {
|
||||
str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
|
||||
str.EXPECT().Close()
|
||||
|
||||
rsp, err := cl.RoundTrip(req)
|
||||
tr := &Transport{}
|
||||
cc := tr.NewClientConn(conn)
|
||||
rsp, err := cc.RoundTrip(req)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
data, err := io.ReadAll(rsp.Body)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
@@ -1045,7 +1080,9 @@ var _ = Describe("Client", func() {
|
||||
str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
|
||||
str.EXPECT().Close()
|
||||
str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes()
|
||||
rsp, err := cl.RoundTrip(req)
|
||||
tr := &Transport{}
|
||||
cc := tr.NewClientConn(conn)
|
||||
rsp, err := cc.RoundTrip(req)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(rsp.Proto).To(Equal("HTTP/3.0"))
|
||||
Expect(rsp.ProtoMajor).To(Equal(3))
|
||||
@@ -1077,7 +1114,9 @@ var _ = Describe("Client", func() {
|
||||
str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
|
||||
str.EXPECT().Close()
|
||||
str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes()
|
||||
rsp, err := cl.RoundTrip(req)
|
||||
tr := &Transport{}
|
||||
cc := tr.NewClientConn(conn)
|
||||
rsp, err := cc.RoundTrip(req)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(rsp.Proto).To(Equal("HTTP/3.0"))
|
||||
Expect(rsp.ProtoMajor).To(Equal(3))
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
@@ -30,7 +31,7 @@ type Settings struct {
|
||||
|
||||
// RoundTripOpt are options for the Transport.RoundTripOpt method.
|
||||
type RoundTripOpt struct {
|
||||
// OnlyCachedConn controls whether the RoundTripper may create a new QUIC connection.
|
||||
// OnlyCachedConn controls whether the Transport may create a new QUIC connection.
|
||||
// If set true and no cached connection is available, RoundTripOpt will return ErrNoCachedConn.
|
||||
OnlyCachedConn bool
|
||||
}
|
||||
@@ -59,8 +60,8 @@ func (r *roundTripperWithCount) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// RoundTripper implements the http.RoundTripper interface
|
||||
type RoundTripper struct {
|
||||
// Transport implements the http.RoundTripper interface
|
||||
type Transport struct {
|
||||
mutex sync.Mutex
|
||||
|
||||
// TLSClientConfig specifies the TLS configuration to use with
|
||||
@@ -97,6 +98,11 @@ type RoundTripper struct {
|
||||
// However, if the user explicitly requested gzip it is not automatically uncompressed.
|
||||
DisableCompression bool
|
||||
|
||||
StreamHijacker func(FrameType, quic.ConnectionTracingID, quic.Stream, error) (hijacked bool, err error)
|
||||
UniStreamHijacker func(StreamType, quic.ConnectionTracingID, quic.ReceiveStream, error) (hijacked bool)
|
||||
|
||||
Logger *slog.Logger
|
||||
|
||||
initOnce sync.Once
|
||||
initErr error
|
||||
|
||||
@@ -107,18 +113,53 @@ type RoundTripper struct {
|
||||
}
|
||||
|
||||
var (
|
||||
_ http.RoundTripper = &RoundTripper{}
|
||||
_ io.Closer = &RoundTripper{}
|
||||
_ http.RoundTripper = &Transport{}
|
||||
_ io.Closer = &Transport{}
|
||||
)
|
||||
|
||||
// ErrNoCachedConn is returned when RoundTripper.OnlyCachedConn is set
|
||||
// ErrNoCachedConn is returned when Transport.OnlyCachedConn is set
|
||||
var ErrNoCachedConn = errors.New("http3: no cached connection was available")
|
||||
|
||||
func (t *Transport) init() error {
|
||||
if t.newClient == nil {
|
||||
t.newClient = func(conn quic.EarlyConnection) singleRoundTripper {
|
||||
return newClientConn(
|
||||
conn,
|
||||
t.EnableDatagrams,
|
||||
t.AdditionalSettings,
|
||||
t.StreamHijacker,
|
||||
t.UniStreamHijacker,
|
||||
t.MaxResponseHeaderBytes,
|
||||
t.DisableCompression,
|
||||
t.Logger,
|
||||
)
|
||||
}
|
||||
}
|
||||
if t.QUICConfig == nil {
|
||||
t.QUICConfig = defaultQuicConfig.Clone()
|
||||
t.QUICConfig.EnableDatagrams = t.EnableDatagrams
|
||||
}
|
||||
if t.EnableDatagrams && !t.QUICConfig.EnableDatagrams {
|
||||
return errors.New("HTTP Datagrams enabled, but QUIC Datagrams disabled")
|
||||
}
|
||||
if len(t.QUICConfig.Versions) == 0 {
|
||||
t.QUICConfig = t.QUICConfig.Clone()
|
||||
t.QUICConfig.Versions = []quic.Version{protocol.SupportedVersions[0]}
|
||||
}
|
||||
if len(t.QUICConfig.Versions) != 1 {
|
||||
return errors.New("can only use a single QUIC version for dialing a HTTP/3 connection")
|
||||
}
|
||||
if t.QUICConfig.MaxIncomingStreams == 0 {
|
||||
t.QUICConfig.MaxIncomingStreams = -1 // don't allow any bidirectional streams
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RoundTripOpt is like RoundTrip, but takes options.
|
||||
func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) {
|
||||
r.initOnce.Do(func() { r.initErr = r.init() })
|
||||
if r.initErr != nil {
|
||||
return nil, r.initErr
|
||||
func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) {
|
||||
t.initOnce.Do(func() { t.initErr = t.init() })
|
||||
if t.initErr != nil {
|
||||
return nil, t.initErr
|
||||
}
|
||||
|
||||
if req.URL == nil {
|
||||
@@ -154,7 +195,7 @@ func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.
|
||||
}
|
||||
|
||||
hostname := authorityAddr(hostnameFromURL(req.URL))
|
||||
cl, isReused, err := r.getClient(req.Context(), hostname, opt.OnlyCachedConn)
|
||||
cl, isReused, err := t.getClient(req.Context(), hostname, opt.OnlyCachedConn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -166,7 +207,7 @@ func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.
|
||||
}
|
||||
|
||||
if cl.dialErr != nil {
|
||||
r.removeClient(hostname)
|
||||
t.removeClient(hostname)
|
||||
return nil, cl.dialErr
|
||||
}
|
||||
defer cl.useCount.Add(-1)
|
||||
@@ -176,12 +217,12 @@ func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.
|
||||
// so we remove the client from the cache so that subsequent trips reconnect
|
||||
// context cancelation is excluded as is does not signify a connection error
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
r.removeClient(hostname)
|
||||
t.removeClient(hostname)
|
||||
}
|
||||
|
||||
if isReused {
|
||||
if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
|
||||
return r.RoundTripOpt(req, opt)
|
||||
return t.RoundTripOpt(req, opt)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -189,51 +230,19 @@ func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.
|
||||
}
|
||||
|
||||
// RoundTrip does a round trip.
|
||||
func (r *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return r.RoundTripOpt(req, RoundTripOpt{})
|
||||
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return t.RoundTripOpt(req, RoundTripOpt{})
|
||||
}
|
||||
|
||||
func (r *RoundTripper) init() error {
|
||||
if r.newClient == nil {
|
||||
r.newClient = func(conn quic.EarlyConnection) singleRoundTripper {
|
||||
return &SingleDestinationRoundTripper{
|
||||
Connection: conn,
|
||||
EnableDatagrams: r.EnableDatagrams,
|
||||
DisableCompression: r.DisableCompression,
|
||||
AdditionalSettings: r.AdditionalSettings,
|
||||
MaxResponseHeaderBytes: r.MaxResponseHeaderBytes,
|
||||
}
|
||||
}
|
||||
}
|
||||
if r.QUICConfig == nil {
|
||||
r.QUICConfig = defaultQuicConfig.Clone()
|
||||
r.QUICConfig.EnableDatagrams = r.EnableDatagrams
|
||||
}
|
||||
if r.EnableDatagrams && !r.QUICConfig.EnableDatagrams {
|
||||
return errors.New("HTTP Datagrams enabled, but QUIC Datagrams disabled")
|
||||
}
|
||||
if len(r.QUICConfig.Versions) == 0 {
|
||||
r.QUICConfig = r.QUICConfig.Clone()
|
||||
r.QUICConfig.Versions = []quic.Version{protocol.SupportedVersions[0]}
|
||||
}
|
||||
if len(r.QUICConfig.Versions) != 1 {
|
||||
return errors.New("can only use a single QUIC version for dialing a HTTP/3 connection")
|
||||
}
|
||||
if r.QUICConfig.MaxIncomingStreams == 0 {
|
||||
r.QUICConfig.MaxIncomingStreams = -1 // don't allow any bidirectional streams
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (t *Transport) getClient(ctx context.Context, hostname string, onlyCached bool) (rtc *roundTripperWithCount, isReused bool, err error) {
|
||||
t.mutex.Lock()
|
||||
defer t.mutex.Unlock()
|
||||
|
||||
func (r *RoundTripper) getClient(ctx context.Context, hostname string, onlyCached bool) (rtc *roundTripperWithCount, isReused bool, err error) {
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
|
||||
if r.clients == nil {
|
||||
r.clients = make(map[string]*roundTripperWithCount)
|
||||
if t.clients == nil {
|
||||
t.clients = make(map[string]*roundTripperWithCount)
|
||||
}
|
||||
|
||||
cl, ok := r.clients[hostname]
|
||||
cl, ok := t.clients[hostname]
|
||||
if !ok {
|
||||
if onlyCached {
|
||||
return nil, false, ErrNoCachedConn
|
||||
@@ -246,7 +255,7 @@ func (r *RoundTripper) getClient(ctx context.Context, hostname string, onlyCache
|
||||
go func() {
|
||||
defer close(cl.dialing)
|
||||
defer cancel()
|
||||
conn, rt, err := r.dial(ctx, hostname)
|
||||
conn, rt, err := t.dial(ctx, hostname)
|
||||
if err != nil {
|
||||
cl.dialErr = err
|
||||
return
|
||||
@@ -254,12 +263,12 @@ func (r *RoundTripper) getClient(ctx context.Context, hostname string, onlyCache
|
||||
cl.conn = conn
|
||||
cl.rt = rt
|
||||
}()
|
||||
r.clients[hostname] = cl
|
||||
t.clients[hostname] = cl
|
||||
}
|
||||
select {
|
||||
case <-cl.dialing:
|
||||
if cl.dialErr != nil {
|
||||
delete(r.clients, hostname)
|
||||
delete(t.clients, hostname)
|
||||
return nil, false, cl.dialErr
|
||||
}
|
||||
select {
|
||||
@@ -273,12 +282,12 @@ func (r *RoundTripper) getClient(ctx context.Context, hostname string, onlyCache
|
||||
return cl, isReused, nil
|
||||
}
|
||||
|
||||
func (r *RoundTripper) dial(ctx context.Context, hostname string) (quic.EarlyConnection, singleRoundTripper, error) {
|
||||
func (t *Transport) dial(ctx context.Context, hostname string) (quic.EarlyConnection, singleRoundTripper, error) {
|
||||
var tlsConf *tls.Config
|
||||
if r.TLSClientConfig == nil {
|
||||
if t.TLSClientConfig == nil {
|
||||
tlsConf = &tls.Config{}
|
||||
} else {
|
||||
tlsConf = r.TLSClientConfig.Clone()
|
||||
tlsConf = t.TLSClientConfig.Clone()
|
||||
}
|
||||
if tlsConf.ServerName == "" {
|
||||
sni, _, err := net.SplitHostPort(hostname)
|
||||
@@ -289,61 +298,74 @@ func (r *RoundTripper) dial(ctx context.Context, hostname string) (quic.EarlyCon
|
||||
tlsConf.ServerName = sni
|
||||
}
|
||||
// Replace existing ALPNs by H3
|
||||
tlsConf.NextProtos = []string{versionToALPN(r.QUICConfig.Versions[0])}
|
||||
tlsConf.NextProtos = []string{versionToALPN(t.QUICConfig.Versions[0])}
|
||||
|
||||
dial := r.Dial
|
||||
dial := t.Dial
|
||||
if dial == nil {
|
||||
if r.transport == nil {
|
||||
if t.transport == nil {
|
||||
udpConn, err := net.ListenUDP("udp", nil)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
r.transport = &quic.Transport{Conn: udpConn}
|
||||
t.transport = &quic.Transport{Conn: udpConn}
|
||||
}
|
||||
dial = func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
|
||||
udpAddr, err := net.ResolveUDPAddr("udp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return r.transport.DialEarly(ctx, udpAddr, tlsCfg, cfg)
|
||||
return t.transport.DialEarly(ctx, udpAddr, tlsCfg, cfg)
|
||||
}
|
||||
}
|
||||
|
||||
conn, err := dial(ctx, hostname, tlsConf, r.QUICConfig)
|
||||
conn, err := dial(ctx, hostname, tlsConf, t.QUICConfig)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return conn, r.newClient(conn), nil
|
||||
return conn, t.newClient(conn), nil
|
||||
}
|
||||
|
||||
func (r *RoundTripper) removeClient(hostname string) {
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
if r.clients == nil {
|
||||
func (t *Transport) removeClient(hostname string) {
|
||||
t.mutex.Lock()
|
||||
defer t.mutex.Unlock()
|
||||
if t.clients == nil {
|
||||
return
|
||||
}
|
||||
delete(r.clients, hostname)
|
||||
delete(t.clients, hostname)
|
||||
}
|
||||
|
||||
// Close closes the QUIC connections that this RoundTripper has used.
|
||||
func (t *Transport) NewClientConn(conn quic.Connection) *ClientConn {
|
||||
return newClientConn(
|
||||
conn,
|
||||
t.EnableDatagrams,
|
||||
t.AdditionalSettings,
|
||||
t.StreamHijacker,
|
||||
t.UniStreamHijacker,
|
||||
t.MaxResponseHeaderBytes,
|
||||
t.DisableCompression,
|
||||
t.Logger,
|
||||
)
|
||||
}
|
||||
|
||||
// Close closes the QUIC connections that this Transport has used.
|
||||
// It also closes the underlying UDPConn if it is not nil.
|
||||
func (r *RoundTripper) Close() error {
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
for _, cl := range r.clients {
|
||||
func (t *Transport) Close() error {
|
||||
t.mutex.Lock()
|
||||
defer t.mutex.Unlock()
|
||||
for _, cl := range t.clients {
|
||||
if err := cl.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
r.clients = nil
|
||||
if r.transport != nil {
|
||||
if err := r.transport.Close(); err != nil {
|
||||
t.clients = nil
|
||||
if t.transport != nil {
|
||||
if err := t.transport.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := r.transport.Conn.Close(); err != nil {
|
||||
if err := t.transport.Conn.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
r.transport = nil
|
||||
t.transport = nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -376,13 +398,13 @@ func isNotToken(r rune) bool {
|
||||
return !httpguts.IsTokenRune(r)
|
||||
}
|
||||
|
||||
func (r *RoundTripper) CloseIdleConnections() {
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
for hostname, cl := range r.clients {
|
||||
func (t *Transport) CloseIdleConnections() {
|
||||
t.mutex.Lock()
|
||||
defer t.mutex.Unlock()
|
||||
for hostname, cl := range t.clients {
|
||||
if cl.useCount.Load() == 0 {
|
||||
cl.Close()
|
||||
delete(r.clients, hostname)
|
||||
delete(t.clients, hostname)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -45,7 +45,7 @@ func (m *mockBody) Close() error {
|
||||
return m.closeErr
|
||||
}
|
||||
|
||||
var _ = Describe("RoundTripper", func() {
|
||||
var _ = Describe("Transport", func() {
|
||||
var req *http.Request
|
||||
|
||||
BeforeEach(func() {
|
||||
@@ -58,14 +58,14 @@ var _ = Describe("RoundTripper", func() {
|
||||
qconf := &quic.Config{
|
||||
Versions: []quic.Version{protocol.Version2, protocol.Version1},
|
||||
}
|
||||
rt := &RoundTripper{QUICConfig: qconf}
|
||||
_, err := rt.RoundTrip(req)
|
||||
tr := &Transport{QUICConfig: qconf}
|
||||
_, err := tr.RoundTrip(req)
|
||||
Expect(err).To(MatchError("can only use a single QUIC version for dialing a HTTP/3 connection"))
|
||||
})
|
||||
|
||||
It("uses the default QUIC and TLS config if none is give", func() {
|
||||
var dialAddrCalled bool
|
||||
rt := &RoundTripper{
|
||||
tr := &Transport{
|
||||
Dial: func(_ context.Context, _ string, tlsConf *tls.Config, quicConf *quic.Config) (quic.EarlyConnection, error) {
|
||||
defer GinkgoRecover()
|
||||
Expect(quicConf.MaxIncomingStreams).To(Equal(defaultQuicConfig.MaxIncomingStreams))
|
||||
@@ -75,14 +75,14 @@ var _ = Describe("RoundTripper", func() {
|
||||
return nil, errors.New("test done")
|
||||
},
|
||||
}
|
||||
_, err := rt.RoundTripOpt(req, RoundTripOpt{})
|
||||
_, err := tr.RoundTripOpt(req, RoundTripOpt{})
|
||||
Expect(err).To(MatchError("test done"))
|
||||
Expect(dialAddrCalled).To(BeTrue())
|
||||
})
|
||||
|
||||
It("adds the port to the hostname, if none is given", func() {
|
||||
var dialAddrCalled bool
|
||||
rt := &RoundTripper{
|
||||
tr := &Transport{
|
||||
Dial: func(_ context.Context, hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlyConnection, error) {
|
||||
defer GinkgoRecover()
|
||||
Expect(hostname).To(Equal("quic.clemente.io:443"))
|
||||
@@ -92,7 +92,7 @@ var _ = Describe("RoundTripper", func() {
|
||||
}
|
||||
req, err := http.NewRequest("GET", "https://quic.clemente.io:443", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = rt.RoundTripOpt(req, RoundTripOpt{})
|
||||
_, err = tr.RoundTripOpt(req, RoundTripOpt{})
|
||||
Expect(err).To(MatchError("test done"))
|
||||
Expect(dialAddrCalled).To(BeTrue())
|
||||
})
|
||||
@@ -100,7 +100,7 @@ var _ = Describe("RoundTripper", func() {
|
||||
It("sets the ServerName in the tls.Config, if not set", func() {
|
||||
const host = "foo.bar"
|
||||
var dialCalled bool
|
||||
rt := &RoundTripper{
|
||||
tr := &Transport{
|
||||
Dial: func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
|
||||
defer GinkgoRecover()
|
||||
Expect(tlsCfg.ServerName).To(Equal(host))
|
||||
@@ -110,7 +110,7 @@ var _ = Describe("RoundTripper", func() {
|
||||
}
|
||||
req, err := http.NewRequest("GET", "https://foo.bar", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = rt.RoundTripOpt(req, RoundTripOpt{})
|
||||
_, err = tr.RoundTripOpt(req, RoundTripOpt{})
|
||||
Expect(err).To(MatchError("test done"))
|
||||
Expect(dialCalled).To(BeTrue())
|
||||
})
|
||||
@@ -122,7 +122,7 @@ var _ = Describe("RoundTripper", func() {
|
||||
}
|
||||
quicConf := &quic.Config{MaxIdleTimeout: 3 * time.Nanosecond}
|
||||
var dialAddrCalled bool
|
||||
rt := &RoundTripper{
|
||||
tr := &Transport{
|
||||
Dial: func(_ context.Context, host string, tlsConfP *tls.Config, quicConfP *quic.Config) (quic.EarlyConnection, error) {
|
||||
defer GinkgoRecover()
|
||||
Expect(host).To(Equal("www.example.org:443"))
|
||||
@@ -135,7 +135,7 @@ var _ = Describe("RoundTripper", func() {
|
||||
QUICConfig: quicConf,
|
||||
TLSClientConfig: tlsConf,
|
||||
}
|
||||
_, err := rt.RoundTripOpt(req, RoundTripOpt{})
|
||||
_, err := tr.RoundTripOpt(req, RoundTripOpt{})
|
||||
Expect(err).To(MatchError("test done"))
|
||||
Expect(dialAddrCalled).To(BeTrue())
|
||||
// make sure the original tls.Config was not modified
|
||||
@@ -149,7 +149,7 @@ var _ = Describe("RoundTripper", func() {
|
||||
// nolint:staticcheck // This is a test.
|
||||
ctx := context.WithValue(context.Background(), "foo", "bar")
|
||||
var dialerCalled bool
|
||||
rt := &RoundTripper{
|
||||
tr := &Transport{
|
||||
Dial: func(ctxP context.Context, address string, tlsConfP *tls.Config, quicConfP *quic.Config) (quic.EarlyConnection, error) {
|
||||
defer GinkgoRecover()
|
||||
Expect(ctx.Value("foo").(string)).To(Equal("bar"))
|
||||
@@ -162,14 +162,14 @@ var _ = Describe("RoundTripper", func() {
|
||||
TLSClientConfig: tlsConf,
|
||||
QUICConfig: quicConf,
|
||||
}
|
||||
_, err := rt.RoundTripOpt(req.WithContext(ctx), RoundTripOpt{})
|
||||
_, err := tr.RoundTripOpt(req.WithContext(ctx), RoundTripOpt{})
|
||||
Expect(err).To(MatchError(testErr))
|
||||
Expect(dialerCalled).To(BeTrue())
|
||||
})
|
||||
|
||||
It("enables HTTP/3 Datagrams", func() {
|
||||
testErr := errors.New("handshake error")
|
||||
rt := &RoundTripper{
|
||||
tr := &Transport{
|
||||
EnableDatagrams: true,
|
||||
Dial: func(_ context.Context, _ string, _ *tls.Config, quicConf *quic.Config) (quic.EarlyConnection, error) {
|
||||
defer GinkgoRecover()
|
||||
@@ -177,19 +177,19 @@ var _ = Describe("RoundTripper", func() {
|
||||
return nil, testErr
|
||||
},
|
||||
}
|
||||
_, err := rt.RoundTripOpt(req, RoundTripOpt{})
|
||||
_, err := tr.RoundTripOpt(req, RoundTripOpt{})
|
||||
Expect(err).To(MatchError(testErr))
|
||||
})
|
||||
|
||||
It("requires quic.Config.EnableDatagrams if HTTP/3 datagrams are enabled", func() {
|
||||
rt := &RoundTripper{
|
||||
tr := &Transport{
|
||||
QUICConfig: &quic.Config{EnableDatagrams: false},
|
||||
EnableDatagrams: true,
|
||||
Dial: func(_ context.Context, _ string, _ *tls.Config, config *quic.Config) (quic.EarlyConnection, error) {
|
||||
return nil, errors.New("handshake error")
|
||||
},
|
||||
}
|
||||
_, err := rt.RoundTrip(req)
|
||||
_, err := tr.RoundTrip(req)
|
||||
Expect(err).To(MatchError("HTTP Datagrams enabled, but QUIC Datagrams disabled"))
|
||||
})
|
||||
|
||||
@@ -200,29 +200,29 @@ var _ = Describe("RoundTripper", func() {
|
||||
req2, err := http.NewRequest("GET", "https://example.com/foobar.html", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
var hostsDialed []string
|
||||
rt := &RoundTripper{
|
||||
tr := &Transport{
|
||||
Dial: func(_ context.Context, host string, _ *tls.Config, quicConf *quic.Config) (quic.EarlyConnection, error) {
|
||||
hostsDialed = append(hostsDialed, host)
|
||||
return nil, testErr
|
||||
},
|
||||
}
|
||||
_, err = rt.RoundTrip(req1)
|
||||
_, err = tr.RoundTrip(req1)
|
||||
Expect(err).To(MatchError(testErr))
|
||||
_, err = rt.RoundTrip(req2)
|
||||
_, err = tr.RoundTrip(req2)
|
||||
Expect(err).To(MatchError(testErr))
|
||||
Expect(hostsDialed).To(Equal([]string{"quic-go.net:443", "example.com:443"}))
|
||||
})
|
||||
|
||||
Context("reusing clients", func() {
|
||||
var (
|
||||
rt *RoundTripper
|
||||
tr *Transport
|
||||
req1, req2 *http.Request
|
||||
clientChan chan *MockSingleRoundTripper
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
clientChan = make(chan *MockSingleRoundTripper, 16)
|
||||
rt = &RoundTripper{
|
||||
tr = &Transport{
|
||||
newClient: func(quic.EarlyConnection) singleRoundTripper {
|
||||
select {
|
||||
case c := <-clientChan:
|
||||
@@ -252,14 +252,14 @@ var _ = Describe("RoundTripper", func() {
|
||||
cl.EXPECT().RoundTrip(req1).Return(&http.Response{Request: req1}, nil)
|
||||
cl.EXPECT().RoundTrip(req2).Return(&http.Response{Request: req2}, nil)
|
||||
var count int
|
||||
rt.Dial = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
|
||||
tr.Dial = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
|
||||
count++
|
||||
return conn, nil
|
||||
}
|
||||
rsp, err := rt.RoundTrip(req1)
|
||||
rsp, err := tr.RoundTrip(req1)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(rsp.Request).To(Equal(req1))
|
||||
rsp, err = rt.RoundTrip(req2)
|
||||
rsp, err = tr.RoundTrip(req2)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(rsp.Request).To(Equal(req2))
|
||||
Expect(count).To(Equal(1))
|
||||
@@ -277,7 +277,7 @@ var _ = Describe("RoundTripper", func() {
|
||||
testErr := errors.New("handshake error")
|
||||
conn := mockquic.NewMockEarlyConnection(mockCtrl)
|
||||
var count int
|
||||
rt.Dial = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
|
||||
tr.Dial = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
|
||||
count++
|
||||
if count == 1 {
|
||||
return nil, testErr
|
||||
@@ -288,9 +288,9 @@ var _ = Describe("RoundTripper", func() {
|
||||
close(handshakeChan)
|
||||
conn.EXPECT().HandshakeComplete().Return(handshakeChan).MaxTimes(2)
|
||||
cl1.EXPECT().RoundTrip(req2).Return(&http.Response{Request: req2}, nil)
|
||||
_, err = rt.RoundTrip(req1)
|
||||
_, err = tr.RoundTrip(req1)
|
||||
Expect(err).To(MatchError(testErr))
|
||||
rsp, err := rt.RoundTrip(req2)
|
||||
rsp, err := tr.RoundTrip(req2)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(rsp.Request).To(Equal(req2))
|
||||
Expect(count).To(Equal(2))
|
||||
@@ -309,7 +309,7 @@ var _ = Describe("RoundTripper", func() {
|
||||
|
||||
conn := mockquic.NewMockEarlyConnection(mockCtrl)
|
||||
var count int
|
||||
rt.Dial = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
|
||||
tr.Dial = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
|
||||
count++
|
||||
return conn, nil
|
||||
}
|
||||
@@ -319,9 +319,9 @@ var _ = Describe("RoundTripper", func() {
|
||||
conn.EXPECT().HandshakeComplete().Return(handshakeChan).MaxTimes(2)
|
||||
cl1.EXPECT().RoundTrip(req1).Return(nil, testErr)
|
||||
cl2.EXPECT().RoundTrip(req2).Return(&http.Response{Request: req2}, nil)
|
||||
_, err = rt.RoundTrip(req1)
|
||||
_, err = tr.RoundTrip(req1)
|
||||
Expect(err).To(MatchError(testErr))
|
||||
rsp, err := rt.RoundTrip(req2)
|
||||
rsp, err := tr.RoundTrip(req2)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(rsp.Request).To(Equal(req2))
|
||||
Expect(count).To(Equal(2))
|
||||
@@ -340,7 +340,7 @@ var _ = Describe("RoundTripper", func() {
|
||||
|
||||
conn := mockquic.NewMockEarlyConnection(mockCtrl)
|
||||
var count int
|
||||
rt.Dial = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
|
||||
tr.Dial = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
|
||||
count++
|
||||
return conn, nil
|
||||
}
|
||||
@@ -350,9 +350,9 @@ var _ = Describe("RoundTripper", func() {
|
||||
conn.EXPECT().HandshakeComplete().Return(handshakeChan).MaxTimes(2)
|
||||
cl1.EXPECT().RoundTrip(req1).Return(nil, testErr)
|
||||
cl1.EXPECT().RoundTrip(req2).Return(&http.Response{Request: req2}, nil)
|
||||
_, err = rt.RoundTrip(req1)
|
||||
_, err = tr.RoundTrip(req1)
|
||||
Expect(err).To(MatchError(testErr))
|
||||
rsp, err := rt.RoundTrip(req2)
|
||||
rsp, err := tr.RoundTrip(req2)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(rsp.Request).To(Equal(req2))
|
||||
Expect(count).To(Equal(1))
|
||||
@@ -383,30 +383,30 @@ var _ = Describe("RoundTripper", func() {
|
||||
close(handshakeChan)
|
||||
conn.EXPECT().HandshakeComplete().Return(handshakeChan).MaxTimes(2)
|
||||
var count int
|
||||
rt.Dial = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
|
||||
tr.Dial = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
|
||||
count++
|
||||
return conn, nil
|
||||
}
|
||||
rsp1, err := rt.RoundTrip(req1)
|
||||
rsp1, err := tr.RoundTrip(req1)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(rsp1.Request.RemoteAddr).To(Equal(req1.RemoteAddr))
|
||||
rsp2, err := rt.RoundTrip(req2)
|
||||
rsp2, err := tr.RoundTrip(req2)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(rsp2.Request.RemoteAddr).To(Equal(req2.RemoteAddr))
|
||||
})
|
||||
|
||||
It("only issues a request once, even if a timeout error occurs", func() {
|
||||
var count int
|
||||
rt.Dial = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
|
||||
tr.Dial = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
|
||||
count++
|
||||
return mockquic.NewMockEarlyConnection(mockCtrl), nil
|
||||
}
|
||||
rt.newClient = func(quic.EarlyConnection) singleRoundTripper {
|
||||
tr.newClient = func(quic.EarlyConnection) singleRoundTripper {
|
||||
cl := NewMockSingleRoundTripper(mockCtrl)
|
||||
cl.EXPECT().RoundTrip(gomock.Any()).Return(nil, &qerr.IdleTimeoutError{})
|
||||
return cl
|
||||
}
|
||||
_, err := rt.RoundTrip(req1)
|
||||
_, err := tr.RoundTrip(req1)
|
||||
Expect(err).To(MatchError(&qerr.IdleTimeoutError{}))
|
||||
Expect(count).To(Equal(1))
|
||||
})
|
||||
@@ -426,7 +426,7 @@ var _ = Describe("RoundTripper", func() {
|
||||
conn := mockquic.NewMockEarlyConnection(mockCtrl)
|
||||
conn.EXPECT().HandshakeComplete().Return(wait).AnyTimes()
|
||||
var count int
|
||||
rt.Dial = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
|
||||
tr.Dial = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
|
||||
count++
|
||||
return conn, nil
|
||||
}
|
||||
@@ -435,7 +435,7 @@ var _ = Describe("RoundTripper", func() {
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
defer func() { done <- struct{}{} }()
|
||||
_, err := rt.RoundTrip(req1)
|
||||
_, err := tr.RoundTrip(req1)
|
||||
Expect(err).To(MatchError(&qerr.IdleTimeoutError{}))
|
||||
}()
|
||||
// wait for the first requests to be issued
|
||||
@@ -443,7 +443,7 @@ var _ = Describe("RoundTripper", func() {
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
defer func() { done <- struct{}{} }()
|
||||
_, err := rt.RoundTrip(req2)
|
||||
_, err := tr.RoundTrip(req2)
|
||||
Expect(err).To(MatchError(&qerr.IdleTimeoutError{}))
|
||||
}()
|
||||
Eventually(reqs).Should(Receive())
|
||||
@@ -456,19 +456,19 @@ var _ = Describe("RoundTripper", func() {
|
||||
It("doesn't create new clients if RoundTripOpt.OnlyCachedConn is set", func() {
|
||||
req, err := http.NewRequest("GET", "https://quic-go.net/foobar.html", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = rt.RoundTripOpt(req, RoundTripOpt{OnlyCachedConn: true})
|
||||
_, err = tr.RoundTripOpt(req, RoundTripOpt{OnlyCachedConn: true})
|
||||
Expect(err).To(MatchError(ErrNoCachedConn))
|
||||
})
|
||||
})
|
||||
|
||||
Context("validating request", func() {
|
||||
var rt RoundTripper
|
||||
var tr Transport
|
||||
|
||||
It("rejects plain HTTP requests", func() {
|
||||
req, err := http.NewRequest("GET", "http://www.example.org/", nil)
|
||||
req.Body = &mockBody{}
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = rt.RoundTrip(req)
|
||||
_, err = tr.RoundTrip(req)
|
||||
Expect(err).To(MatchError("http3: unsupported protocol scheme: http"))
|
||||
Expect(req.Body.(*mockBody).closed).To(BeTrue())
|
||||
})
|
||||
@@ -476,7 +476,7 @@ var _ = Describe("RoundTripper", func() {
|
||||
It("rejects requests without a URL", func() {
|
||||
req.URL = nil
|
||||
req.Body = &mockBody{}
|
||||
_, err := rt.RoundTrip(req)
|
||||
_, err := tr.RoundTrip(req)
|
||||
Expect(err).To(MatchError("http3: nil Request.URL"))
|
||||
Expect(req.Body.(*mockBody).closed).To(BeTrue())
|
||||
})
|
||||
@@ -484,7 +484,7 @@ var _ = Describe("RoundTripper", func() {
|
||||
It("rejects request without a URL Host", func() {
|
||||
req.URL.Host = ""
|
||||
req.Body = &mockBody{}
|
||||
_, err := rt.RoundTrip(req)
|
||||
_, err := tr.RoundTrip(req)
|
||||
Expect(err).To(MatchError("http3: no Host in request URL"))
|
||||
Expect(req.Body.(*mockBody).closed).To(BeTrue())
|
||||
})
|
||||
@@ -492,34 +492,34 @@ var _ = Describe("RoundTripper", func() {
|
||||
It("doesn't try to close the body if the request doesn't have one", func() {
|
||||
req.URL = nil
|
||||
Expect(req.Body).To(BeNil())
|
||||
_, err := rt.RoundTrip(req)
|
||||
_, err := tr.RoundTrip(req)
|
||||
Expect(err).To(MatchError("http3: nil Request.URL"))
|
||||
})
|
||||
|
||||
It("rejects requests without a header", func() {
|
||||
req.Header = nil
|
||||
req.Body = &mockBody{}
|
||||
_, err := rt.RoundTrip(req)
|
||||
_, err := tr.RoundTrip(req)
|
||||
Expect(err).To(MatchError("http3: nil Request.Header"))
|
||||
Expect(req.Body.(*mockBody).closed).To(BeTrue())
|
||||
})
|
||||
|
||||
It("rejects requests with invalid header name fields", func() {
|
||||
req.Header.Add("foobär", "value")
|
||||
_, err := rt.RoundTrip(req)
|
||||
_, err := tr.RoundTrip(req)
|
||||
Expect(err).To(MatchError("http3: invalid http header field name \"foobär\""))
|
||||
})
|
||||
|
||||
It("rejects requests with invalid header name values", func() {
|
||||
req.Header.Add("foo", string([]byte{0x7}))
|
||||
_, err := rt.RoundTrip(req)
|
||||
_, err := tr.RoundTrip(req)
|
||||
Expect(err.Error()).To(ContainSubstring("http3: invalid http header field value"))
|
||||
})
|
||||
|
||||
It("rejects requests with an invalid request method", func() {
|
||||
req.Method = "foobär"
|
||||
req.Body = &mockBody{}
|
||||
_, err := rt.RoundTrip(req)
|
||||
_, err := tr.RoundTrip(req)
|
||||
Expect(err).To(MatchError("http3: invalid method \"foobär\""))
|
||||
Expect(req.Body.(*mockBody).closed).To(BeTrue())
|
||||
})
|
||||
@@ -528,7 +528,7 @@ var _ = Describe("RoundTripper", func() {
|
||||
Context("closing", func() {
|
||||
It("closes", func() {
|
||||
conn := mockquic.NewMockEarlyConnection(mockCtrl)
|
||||
rt := &RoundTripper{
|
||||
tr := &Transport{
|
||||
Dial: func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
|
||||
return conn, nil
|
||||
},
|
||||
@@ -540,14 +540,14 @@ var _ = Describe("RoundTripper", func() {
|
||||
}
|
||||
req, err := http.NewRequest("GET", "https://quic-go.net/foobar.html", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = rt.RoundTrip(req)
|
||||
_, err = tr.RoundTrip(req)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(0), "")
|
||||
Expect(rt.Close()).To(Succeed())
|
||||
Expect(tr.Close()).To(Succeed())
|
||||
})
|
||||
|
||||
It("closes while dialing", func() {
|
||||
rt := &RoundTripper{
|
||||
tr := &Transport{
|
||||
Dial: func(ctx context.Context, _ string, _ *tls.Config, _ *quic.Config) (quic.EarlyConnection, error) {
|
||||
defer GinkgoRecover()
|
||||
Eventually(ctx.Done()).Should(BeClosed())
|
||||
@@ -560,12 +560,12 @@ var _ = Describe("RoundTripper", func() {
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_, err := rt.RoundTrip(req)
|
||||
_, err := tr.RoundTrip(req)
|
||||
errChan <- err
|
||||
}()
|
||||
|
||||
Consistently(errChan, scaleDuration(30*time.Millisecond)).ShouldNot(Receive())
|
||||
Expect(rt.Close()).To(Succeed())
|
||||
Expect(tr.Close()).To(Succeed())
|
||||
var rtErr error
|
||||
Eventually(errChan).Should(Receive(&rtErr))
|
||||
Expect(rtErr).To(MatchError("cancelled"))
|
||||
@@ -574,7 +574,7 @@ var _ = Describe("RoundTripper", func() {
|
||||
It("closes idle connections", func() {
|
||||
conn1 := mockquic.NewMockEarlyConnection(mockCtrl)
|
||||
conn2 := mockquic.NewMockEarlyConnection(mockCtrl)
|
||||
rt := &RoundTripper{
|
||||
tr := &Transport{
|
||||
Dial: func(_ context.Context, hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlyConnection, error) {
|
||||
switch hostname {
|
||||
case "site1.com:443":
|
||||
@@ -598,7 +598,7 @@ var _ = Describe("RoundTripper", func() {
|
||||
req2 = req2.WithContext(ctx2)
|
||||
roundTripCalled := make(chan struct{})
|
||||
reqFinished := make(chan struct{})
|
||||
rt.newClient = func(quic.EarlyConnection) singleRoundTripper {
|
||||
tr.newClient = func(quic.EarlyConnection) singleRoundTripper {
|
||||
cl := NewMockSingleRoundTripper(mockCtrl)
|
||||
cl.EXPECT().RoundTrip(gomock.Any()).DoAndReturn(func(r *http.Request) (*http.Response, error) {
|
||||
roundTripCalled <- struct{}{}
|
||||
@@ -608,11 +608,11 @@ var _ = Describe("RoundTripper", func() {
|
||||
return cl
|
||||
}
|
||||
go func() {
|
||||
rt.RoundTrip(req1)
|
||||
tr.RoundTrip(req1)
|
||||
reqFinished <- struct{}{}
|
||||
}()
|
||||
go func() {
|
||||
rt.RoundTrip(req2)
|
||||
tr.RoundTrip(req2)
|
||||
reqFinished <- struct{}{}
|
||||
}()
|
||||
<-roundTripCalled
|
||||
@@ -622,12 +622,12 @@ var _ = Describe("RoundTripper", func() {
|
||||
<-reqFinished
|
||||
// req1 is finished
|
||||
conn1.EXPECT().CloseWithError(gomock.Any(), gomock.Any())
|
||||
rt.CloseIdleConnections()
|
||||
tr.CloseIdleConnections()
|
||||
cancel2()
|
||||
<-reqFinished
|
||||
// all requests are finished
|
||||
conn2.EXPECT().CloseWithError(gomock.Any(), gomock.Any())
|
||||
rt.CloseIdleConnections()
|
||||
tr.CloseIdleConnections()
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -64,7 +64,7 @@ var _ = Describe("HTTP3 Server hotswap test", func() {
|
||||
mux1 *http.ServeMux
|
||||
mux2 *http.ServeMux
|
||||
client *http.Client
|
||||
rt *http3.RoundTripper
|
||||
rt *http3.Transport
|
||||
server1 *http3.Server
|
||||
server2 *http3.Server
|
||||
ln *listenerWrapper
|
||||
@@ -106,7 +106,7 @@ var _ = Describe("HTTP3 Server hotswap test", func() {
|
||||
})
|
||||
|
||||
BeforeEach(func() {
|
||||
rt = &http3.RoundTripper{
|
||||
rt = &http3.Transport{
|
||||
TLSClientConfig: getTLSClientConfig(),
|
||||
DisableCompression: true,
|
||||
QUICConfig: getQuicConfig(&quic.Config{MaxIdleTimeout: 10 * time.Second}),
|
||||
@@ -151,7 +151,7 @@ var _ = Describe("HTTP3 Server hotswap test", func() {
|
||||
Expect(fake1.closed.Load()).To(BeTrue())
|
||||
Expect(fake2.closed.Load()).To(BeFalse())
|
||||
Expect(ln.listenerClosed).ToNot(BeTrue())
|
||||
Expect(client.Transport.(*http3.RoundTripper).Close()).NotTo(HaveOccurred())
|
||||
Expect(client.Transport.(*http3.Transport).Close()).NotTo(HaveOccurred())
|
||||
|
||||
// verify that new connections are being initiated from the second server now
|
||||
resp, err = client.Get("https://localhost:" + port + "/hello2")
|
||||
|
||||
@@ -46,7 +46,7 @@ var _ = Describe("HTTP tests", func() {
|
||||
var (
|
||||
mux *http.ServeMux
|
||||
client *http.Client
|
||||
rt *http3.RoundTripper
|
||||
tr *http3.Transport
|
||||
server *http3.Server
|
||||
stoppedServing chan struct{}
|
||||
port int
|
||||
@@ -112,20 +112,20 @@ var _ = Describe("HTTP tests", func() {
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
Expect(rt.Close()).NotTo(HaveOccurred())
|
||||
Expect(tr.Close()).NotTo(HaveOccurred())
|
||||
Expect(server.Close()).NotTo(HaveOccurred())
|
||||
Eventually(stoppedServing).Should(BeClosed())
|
||||
})
|
||||
|
||||
BeforeEach(func() {
|
||||
rt = &http3.RoundTripper{
|
||||
tr = &http3.Transport{
|
||||
TLSClientConfig: getTLSClientConfigWithoutServerName(),
|
||||
QUICConfig: getQuicConfig(&quic.Config{
|
||||
MaxIdleTimeout: 10 * time.Second,
|
||||
}),
|
||||
DisableCompression: true,
|
||||
}
|
||||
client = &http.Client{Transport: rt}
|
||||
client = &http.Client{Transport: tr}
|
||||
})
|
||||
|
||||
It("closes the connection after idle timeout", func() {
|
||||
@@ -166,7 +166,7 @@ var _ = Describe("HTTP tests", func() {
|
||||
var dialCounter int
|
||||
testErr := errors.New("test error")
|
||||
cl := http.Client{
|
||||
Transport: &http3.RoundTripper{
|
||||
Transport: &http3.Transport{
|
||||
TLSClientConfig: getTLSClientConfig(),
|
||||
Dial: func(ctx context.Context, addr string, tlsConf *tls.Config, conf *quic.Config) (quic.EarlyConnection, error) {
|
||||
dialCounter++
|
||||
@@ -355,7 +355,7 @@ var _ = Describe("HTTP tests", func() {
|
||||
gw.Write([]byte("Hello, World!\n"))
|
||||
})
|
||||
|
||||
client.Transport.(*http3.RoundTripper).DisableCompression = false
|
||||
client.Transport.(*http3.Transport).DisableCompression = false
|
||||
resp, err := client.Get(fmt.Sprintf("https://localhost:%d/gzipped/hello", port))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(resp.StatusCode).To(Equal(200))
|
||||
@@ -484,8 +484,9 @@ var _ = Describe("HTTP tests", func() {
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer conn.CloseWithError(0, "")
|
||||
rt := http3.SingleDestinationRoundTripper{Connection: conn}
|
||||
str, err := rt.OpenRequestStream(context.Background())
|
||||
tr := http3.Transport{}
|
||||
cc := tr.NewClientConn(conn)
|
||||
str, err := cc.OpenRequestStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str.SendRequestHeader(req)).To(Succeed())
|
||||
// make sure the request is received (and not stuck in some buffer, for example)
|
||||
@@ -677,10 +678,10 @@ var _ = Describe("HTTP tests", func() {
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer conn.CloseWithError(0, "")
|
||||
rt := http3.SingleDestinationRoundTripper{Connection: conn}
|
||||
hconn := rt.Start()
|
||||
Eventually(hconn.ReceivedSettings(), 5*time.Second, 10*time.Millisecond).Should(BeClosed())
|
||||
settings := hconn.Settings()
|
||||
var tr http3.Transport
|
||||
cc := tr.NewClientConn(conn)
|
||||
Eventually(cc.ReceivedSettings(), 5*time.Second, 10*time.Millisecond).Should(BeClosed())
|
||||
settings := cc.Settings()
|
||||
Expect(settings.EnableExtendedConnect).To(BeTrue())
|
||||
Expect(settings.EnableDatagrams).To(BeFalse())
|
||||
Expect(settings.Other).To(BeEmpty())
|
||||
@@ -696,7 +697,7 @@ var _ = Describe("HTTP tests", func() {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
rt = &http3.RoundTripper{
|
||||
tr = &http3.Transport{
|
||||
TLSClientConfig: getTLSClientConfigWithoutServerName(),
|
||||
QUICConfig: getQuicConfig(&quic.Config{
|
||||
MaxIdleTimeout: 10 * time.Second,
|
||||
@@ -708,7 +709,7 @@ var _ = Describe("HTTP tests", func() {
|
||||
req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("https://localhost:%d/settings", port), nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
_, err = rt.RoundTrip(req)
|
||||
_, err = tr.RoundTrip(req)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
var settings *http3.Settings
|
||||
Expect(settingsChan).To(Receive(&settings))
|
||||
@@ -803,11 +804,9 @@ var _ = Describe("HTTP tests", func() {
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
rt := &http3.SingleDestinationRoundTripper{
|
||||
Connection: conn,
|
||||
EnableDatagrams: true,
|
||||
}
|
||||
str, err := rt.OpenRequestStream(context.Background())
|
||||
tr := http3.Transport{EnableDatagrams: true}
|
||||
cc := tr.NewClientConn(conn)
|
||||
str, err := cc.OpenRequestStream(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
u, err := url.Parse(h)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
@@ -981,21 +980,21 @@ var _ = Describe("HTTP tests", func() {
|
||||
tlsConf := getTLSClientConfigWithoutServerName()
|
||||
puts := make(chan string, 10)
|
||||
tlsConf.ClientSessionCache = newClientSessionCache(tls.NewLRUClientSessionCache(10), nil, puts)
|
||||
rt := &http3.RoundTripper{
|
||||
tr := &http3.Transport{
|
||||
TLSClientConfig: tlsConf,
|
||||
QUICConfig: getQuicConfig(&quic.Config{
|
||||
MaxIdleTimeout: 10 * time.Second,
|
||||
}),
|
||||
DisableCompression: true,
|
||||
}
|
||||
defer rt.Close()
|
||||
defer tr.Close()
|
||||
|
||||
mux.HandleFunc("/0rtt", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte(strconv.FormatBool(!r.TLS.HandshakeComplete)))
|
||||
})
|
||||
req, err := http.NewRequest(http3.MethodGet0RTT, fmt.Sprintf("https://localhost:%d/0rtt", proxy.LocalPort()), nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
rsp, err := rt.RoundTrip(req)
|
||||
rsp, err := tr.RoundTrip(req)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(rsp.StatusCode).To(BeEquivalentTo(200))
|
||||
data, err := io.ReadAll(rsp.Body)
|
||||
@@ -1004,13 +1003,13 @@ var _ = Describe("HTTP tests", func() {
|
||||
Expect(num0RTTPackets.Load()).To(BeZero())
|
||||
Eventually(puts).Should(Receive())
|
||||
|
||||
rt2 := &http3.RoundTripper{
|
||||
TLSClientConfig: rt.TLSClientConfig,
|
||||
QUICConfig: rt.QUICConfig,
|
||||
tr2 := &http3.Transport{
|
||||
TLSClientConfig: tr.TLSClientConfig,
|
||||
QUICConfig: tr.QUICConfig,
|
||||
DisableCompression: true,
|
||||
}
|
||||
defer rt2.Close()
|
||||
rsp, err = rt2.RoundTrip(req)
|
||||
defer tr2.Close()
|
||||
rsp, err = tr2.RoundTrip(req)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(rsp.StatusCode).To(BeEquivalentTo(200))
|
||||
data, err = io.ReadAll(rsp.Body)
|
||||
@@ -1111,7 +1110,7 @@ var _ = Describe("HTTP tests", func() {
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(body).To(Equal([]byte("shutdown")))
|
||||
// manually close the client, since we don't support
|
||||
client.Transport.(*http3.RoundTripper).Close()
|
||||
client.Transport.(*http3.Transport).Close()
|
||||
|
||||
// make sure that CloseGracefully returned
|
||||
Eventually(done).Should(BeClosed())
|
||||
|
||||
@@ -67,7 +67,7 @@ func runTestcase(testcase string) error {
|
||||
quicConf := &quic.Config{Tracer: utils.NewQLOGConnectionTracer}
|
||||
|
||||
if testcase == "http3" {
|
||||
r := &http3.RoundTripper{
|
||||
r := &http3.Transport{
|
||||
TLSClientConfig: tlsConf,
|
||||
QUICConfig: quicConf,
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user