Files
quic-go/http3/client_test.go
Lorenzo Saino 8db2288382 Make http3.client.Close() succeed if session was not started
Invoking http3.client.Close() before client.dial() is invoked
causes a segmentation fault. That occurs because, in this
circumstance, invoking client.Close() results in invoking
client.session.CloseWithError(...) while client.session is nil.

This commit changes the behavior of
http3.client.Close() to return nil if client.session
is nil and adds an associated test case.
2020-02-23 00:21:19 +00:00

496 lines
18 KiB
Go

package http3
import (
"bytes"
"compress/gzip"
"context"
"crypto/tls"
"errors"
"io"
"io/ioutil"
"net/http"
"time"
"github.com/golang/mock/gomock"
quic "github.com/lucas-clemente/quic-go"
mockquic "github.com/lucas-clemente/quic-go/internal/mocks/quic"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/marten-seemann/qpack"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Client", func() {
var (
client *client
req *http.Request
origDialAddr = dialAddr
handshakeCtx context.Context // an already canceled context
)
BeforeEach(func() {
origDialAddr = dialAddr
hostname := "quic.clemente.io:1337"
client = newClient(hostname, nil, &roundTripperOpts{MaxHeaderBytes: 1337}, nil, nil)
Expect(client.hostname).To(Equal(hostname))
var err error
req, err = http.NewRequest("GET", "https://localhost:1337", nil)
Expect(err).ToNot(HaveOccurred())
ctx, cancel := context.WithCancel(context.Background())
cancel()
handshakeCtx = ctx
})
AfterEach(func() {
dialAddr = origDialAddr
})
It("uses the default QUIC and TLS config if none is give", func() {
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
var dialAddrCalled bool
dialAddr = func(_ string, tlsConf *tls.Config, quicConf *quic.Config) (quic.EarlySession, error) {
Expect(quicConf).To(Equal(defaultQuicConfig))
Expect(tlsConf.NextProtos).To(Equal([]string{nextProtoH3}))
dialAddrCalled = true
return nil, errors.New("test done")
}
client.RoundTrip(req)
Expect(dialAddrCalled).To(BeTrue())
})
It("adds the port to the hostname, if none is given", func() {
client = newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil, nil)
var dialAddrCalled bool
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlySession, error) {
Expect(hostname).To(Equal("quic.clemente.io:443"))
dialAddrCalled = true
return nil, errors.New("test done")
}
req, err := http.NewRequest("GET", "https://quic.clemente.io:443", nil)
Expect(err).ToNot(HaveOccurred())
client.RoundTrip(req)
Expect(dialAddrCalled).To(BeTrue())
})
It("uses the TLS config and QUIC config", func() {
tlsConf := &tls.Config{
ServerName: "foo.bar",
NextProtos: []string{"proto foo", "proto bar"},
}
quicConf := &quic.Config{MaxIdleTimeout: time.Nanosecond}
client = newClient("localhost:1337", tlsConf, &roundTripperOpts{}, quicConf, nil)
var dialAddrCalled bool
dialAddr = func(
hostname string,
tlsConfP *tls.Config,
quicConfP *quic.Config,
) (quic.EarlySession, error) {
Expect(hostname).To(Equal("localhost:1337"))
Expect(tlsConfP.ServerName).To(Equal(tlsConf.ServerName))
Expect(tlsConfP.NextProtos).To(Equal([]string{nextProtoH3}))
Expect(quicConfP.MaxIdleTimeout).To(Equal(quicConf.MaxIdleTimeout))
dialAddrCalled = true
return nil, errors.New("test done")
}
client.RoundTrip(req)
Expect(dialAddrCalled).To(BeTrue())
// make sure the original tls.Config was not modified
Expect(tlsConf.NextProtos).To(Equal([]string{"proto foo", "proto bar"}))
})
It("uses the custom dialer, if provided", func() {
testErr := errors.New("test done")
tlsConf := &tls.Config{ServerName: "foo.bar"}
quicConf := &quic.Config{MaxIdleTimeout: 1337 * time.Second}
var dialerCalled bool
dialer := func(network, address string, tlsConfP *tls.Config, quicConfP *quic.Config) (quic.EarlySession, error) {
Expect(network).To(Equal("udp"))
Expect(address).To(Equal("localhost:1337"))
Expect(tlsConfP.ServerName).To(Equal("foo.bar"))
Expect(quicConfP.MaxIdleTimeout).To(Equal(quicConf.MaxIdleTimeout))
dialerCalled = true
return nil, testErr
}
client = newClient("localhost:1337", tlsConf, &roundTripperOpts{}, quicConf, dialer)
_, err := client.RoundTrip(req)
Expect(err).To(MatchError(testErr))
Expect(dialerCalled).To(BeTrue())
})
It("errors when dialing fails", func() {
testErr := errors.New("handshake error")
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlySession, error) {
return nil, testErr
}
_, err := client.RoundTrip(req)
Expect(err).To(MatchError(testErr))
})
It("errors if it can't open a stream", func() {
testErr := errors.New("stream open error")
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
session := mockquic.NewMockEarlySession(mockCtrl)
session.EXPECT().OpenUniStream().Return(nil, testErr).MaxTimes(1)
session.EXPECT().HandshakeComplete().Return(handshakeCtx).MaxTimes(1)
session.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr).MaxTimes(1)
session.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).MaxTimes(1)
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlySession, error) {
return session, nil
}
defer GinkgoRecover()
_, err := client.RoundTrip(req)
Expect(err).To(MatchError(testErr))
})
It("closes correctly if session was not created", func() {
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
err := client.Close()
Expect(err).ToNot(HaveOccurred())
})
Context("Doing requests", func() {
var (
request *http.Request
str *mockquic.MockStream
sess *mockquic.MockEarlySession
)
decodeHeader := func(str io.Reader) map[string]string {
fields := make(map[string]string)
decoder := qpack.NewDecoder(nil)
frame, err := parseNextFrame(str)
ExpectWithOffset(1, err).ToNot(HaveOccurred())
ExpectWithOffset(1, frame).To(BeAssignableToTypeOf(&headersFrame{}))
headersFrame := frame.(*headersFrame)
data := make([]byte, headersFrame.Length)
_, err = io.ReadFull(str, data)
ExpectWithOffset(1, err).ToNot(HaveOccurred())
hfs, err := decoder.DecodeFull(data)
ExpectWithOffset(1, err).ToNot(HaveOccurred())
for _, p := range hfs {
fields[p.Name] = p.Value
}
return fields
}
BeforeEach(func() {
controlStr := mockquic.NewMockStream(mockCtrl)
controlStr.EXPECT().Write([]byte{0x0}).Return(1, nil).MaxTimes(1)
controlStr.EXPECT().Write(gomock.Any()).MaxTimes(1) // SETTINGS frame
str = mockquic.NewMockStream(mockCtrl)
sess = mockquic.NewMockEarlySession(mockCtrl)
sess.EXPECT().OpenUniStream().Return(controlStr, nil).MaxTimes(1)
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlySession, error) {
return sess, nil
}
var err error
request, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil)
Expect(err).ToNot(HaveOccurred())
})
It("performs a 0-RTT request", func() {
testErr := errors.New("stream open error")
request.Method = MethodGet0RTT
// don't EXPECT any calls to HandshakeComplete()
sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
buf := &bytes.Buffer{}
str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
return buf.Write(p)
}).AnyTimes()
str.EXPECT().Close()
str.EXPECT().CancelWrite(gomock.Any())
str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) {
return 0, testErr
})
_, err := client.RoundTrip(request)
Expect(err).To(MatchError(testErr))
Expect(decodeHeader(buf)).To(HaveKeyWithValue(":method", "GET"))
})
It("returns a response", func() {
rspBuf := &bytes.Buffer{}
rw := newResponseWriter(rspBuf, utils.DefaultLogger)
rw.WriteHeader(418)
gomock.InOrder(
sess.EXPECT().HandshakeComplete().Return(handshakeCtx),
sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil),
)
str.EXPECT().Write(gomock.Any()).AnyTimes()
str.EXPECT().Close()
str.EXPECT().Read(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
return rspBuf.Read(p)
}).AnyTimes()
rsp, err := client.RoundTrip(request)
Expect(err).ToNot(HaveOccurred())
Expect(rsp.Proto).To(Equal("HTTP/3"))
Expect(rsp.ProtoMajor).To(Equal(3))
Expect(rsp.StatusCode).To(Equal(418))
})
Context("validating the address", func() {
It("refuses to do requests for the wrong host", func() {
req, err := http.NewRequest("https", "https://quic.clemente.io:1336/foobar.html", nil)
Expect(err).ToNot(HaveOccurred())
_, err = client.RoundTrip(req)
Expect(err).To(MatchError("http3 client BUG: RoundTrip called for the wrong client (expected quic.clemente.io:1337, got quic.clemente.io:1336)"))
})
It("refuses to do plain HTTP requests", func() {
req, err := http.NewRequest("https", "http://quic.clemente.io:1337/foobar.html", nil)
Expect(err).ToNot(HaveOccurred())
_, err = client.RoundTrip(req)
Expect(err).To(MatchError("http3: unsupported scheme"))
})
})
Context("requests containing a Body", func() {
var strBuf *bytes.Buffer
BeforeEach(func() {
strBuf = &bytes.Buffer{}
gomock.InOrder(
sess.EXPECT().HandshakeComplete().Return(handshakeCtx),
sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil),
)
body := &mockBody{}
body.SetData([]byte("request body"))
var err error
request, err = http.NewRequest("POST", "https://quic.clemente.io:1337/upload", body)
Expect(err).ToNot(HaveOccurred())
str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
return strBuf.Write(p)
}).AnyTimes()
})
It("sends a request", func() {
done := make(chan struct{})
gomock.InOrder(
str.EXPECT().Close().Do(func() { close(done) }),
str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1), // when reading the response errors
)
// the response body is sent asynchronously, while already reading the response
str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) {
<-done
return 0, errors.New("test done")
})
_, err := client.RoundTrip(request)
Expect(err).To(MatchError("test done"))
hfs := decodeHeader(strBuf)
Expect(hfs).To(HaveKeyWithValue(":method", "POST"))
Expect(hfs).To(HaveKeyWithValue(":path", "/upload"))
})
It("returns the error that occurred when reading the body", func() {
request.Body.(*mockBody).readErr = errors.New("testErr")
done := make(chan struct{})
gomock.InOrder(
str.EXPECT().CancelWrite(quic.ErrorCode(errorRequestCanceled)).Do(func(quic.ErrorCode) {
close(done)
}),
str.EXPECT().CancelWrite(gomock.Any()),
)
// the response body is sent asynchronously, while already reading the response
str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) {
<-done
return 0, errors.New("test done")
})
_, err := client.RoundTrip(request)
Expect(err).To(MatchError("test done"))
})
It("closes the connection when the first frame is not a HEADERS frame", func() {
buf := &bytes.Buffer{}
(&dataFrame{Length: 0x42}).Write(buf)
sess.EXPECT().CloseWithError(quic.ErrorCode(errorFrameUnexpected), gomock.Any())
closed := make(chan struct{})
str.EXPECT().Close().Do(func() { close(closed) })
str.EXPECT().Read(gomock.Any()).DoAndReturn(func(b []byte) (int, error) {
return buf.Read(b)
}).AnyTimes()
_, err := client.RoundTrip(request)
Expect(err).To(MatchError("expected first frame to be a HEADERS frame"))
Eventually(closed).Should(BeClosed())
})
It("cancels the stream when the HEADERS frame is too large", func() {
buf := &bytes.Buffer{}
(&headersFrame{Length: 1338}).Write(buf)
str.EXPECT().CancelWrite(quic.ErrorCode(errorFrameError))
closed := make(chan struct{})
str.EXPECT().Close().Do(func() { close(closed) })
str.EXPECT().Read(gomock.Any()).DoAndReturn(func(b []byte) (int, error) {
return buf.Read(b)
}).AnyTimes()
_, err := client.RoundTrip(request)
Expect(err).To(MatchError("HEADERS frame too large: 1338 bytes (max: 1337)"))
Eventually(closed).Should(BeClosed())
})
})
Context("request cancellations", func() {
It("cancels a request while waiting for the handshake to complete", func() {
ctx, cancel := context.WithCancel(context.Background())
req := request.WithContext(ctx)
sess.EXPECT().HandshakeComplete().Return(context.Background())
errChan := make(chan error)
go func() {
_, err := client.RoundTrip(req)
errChan <- err
}()
Consistently(errChan).ShouldNot(Receive())
cancel()
Eventually(errChan).Should(Receive(MatchError("context canceled")))
})
It("cancels a request while the request is still in flight", func() {
ctx, cancel := context.WithCancel(context.Background())
req := request.WithContext(ctx)
sess.EXPECT().HandshakeComplete().Return(handshakeCtx)
sess.EXPECT().OpenStreamSync(ctx).Return(str, nil)
buf := &bytes.Buffer{}
str.EXPECT().Close().MaxTimes(1)
str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
return buf.Write(p)
})
done := make(chan struct{})
canceled := make(chan struct{})
gomock.InOrder(
str.EXPECT().CancelWrite(quic.ErrorCode(errorRequestCanceled)).Do(func(quic.ErrorCode) { close(canceled) }),
str.EXPECT().CancelRead(quic.ErrorCode(errorRequestCanceled)).Do(func(quic.ErrorCode) { close(done) }),
)
str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1)
str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) {
cancel()
<-canceled
return 0, errors.New("test done")
})
_, err := client.RoundTrip(req)
Expect(err).To(MatchError("test done"))
Eventually(done).Should(BeClosed())
})
It("cancels a request after the response arrived", func() {
rspBuf := &bytes.Buffer{}
rw := newResponseWriter(rspBuf, utils.DefaultLogger)
rw.WriteHeader(418)
ctx, cancel := context.WithCancel(context.Background())
req := request.WithContext(ctx)
sess.EXPECT().HandshakeComplete().Return(handshakeCtx)
sess.EXPECT().OpenStreamSync(ctx).Return(str, nil)
buf := &bytes.Buffer{}
str.EXPECT().Close().MaxTimes(1)
done := make(chan struct{})
str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
return buf.Write(p)
})
str.EXPECT().Read(gomock.Any()).DoAndReturn(func(b []byte) (int, error) {
return rspBuf.Read(b)
}).AnyTimes()
str.EXPECT().CancelWrite(quic.ErrorCode(errorRequestCanceled))
str.EXPECT().CancelRead(quic.ErrorCode(errorRequestCanceled)).Do(func(quic.ErrorCode) { close(done) })
_, err := client.RoundTrip(req)
Expect(err).ToNot(HaveOccurred())
cancel()
Eventually(done).Should(BeClosed())
})
})
Context("gzip compression", func() {
BeforeEach(func() {
sess.EXPECT().HandshakeComplete().Return(handshakeCtx)
})
It("adds the gzip header to requests", func() {
sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
buf := &bytes.Buffer{}
str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
return buf.Write(p)
})
gomock.InOrder(
str.EXPECT().Close(),
str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1), // when the Read errors
)
str.EXPECT().Read(gomock.Any()).Return(0, errors.New("test done"))
_, err := client.RoundTrip(request)
Expect(err).To(MatchError("test done"))
hfs := decodeHeader(buf)
Expect(hfs).To(HaveKeyWithValue("accept-encoding", "gzip"))
})
It("doesn't add gzip if the header disable it", func() {
client = newClient("quic.clemente.io:1337", nil, &roundTripperOpts{DisableCompression: true}, nil, nil)
sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
buf := &bytes.Buffer{}
str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
return buf.Write(p)
})
gomock.InOrder(
str.EXPECT().Close(),
str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1), // when the Read errors
)
str.EXPECT().Read(gomock.Any()).Return(0, errors.New("test done"))
_, err := client.RoundTrip(request)
Expect(err).To(MatchError("test done"))
hfs := decodeHeader(buf)
Expect(hfs).ToNot(HaveKey("accept-encoding"))
})
It("decompresses the response", func() {
sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
buf := &bytes.Buffer{}
rw := newResponseWriter(buf, utils.DefaultLogger)
rw.Header().Set("Content-Encoding", "gzip")
gz := gzip.NewWriter(rw)
gz.Write([]byte("gzipped response"))
gz.Close()
str.EXPECT().Write(gomock.Any()).AnyTimes()
str.EXPECT().Read(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
return buf.Read(p)
}).AnyTimes()
str.EXPECT().Close()
rsp, err := client.RoundTrip(request)
Expect(err).ToNot(HaveOccurred())
data, err := ioutil.ReadAll(rsp.Body)
Expect(err).ToNot(HaveOccurred())
Expect(rsp.ContentLength).To(BeEquivalentTo(-1))
Expect(string(data)).To(Equal("gzipped response"))
Expect(rsp.Header.Get("Content-Encoding")).To(BeEmpty())
Expect(rsp.Uncompressed).To(BeTrue())
})
It("only decompresses the response if the response contains the right content-encoding header", func() {
sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
buf := &bytes.Buffer{}
rw := newResponseWriter(buf, utils.DefaultLogger)
rw.Write([]byte("not gzipped"))
str.EXPECT().Write(gomock.Any()).AnyTimes()
str.EXPECT().Read(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
return buf.Read(p)
}).AnyTimes()
str.EXPECT().Close()
rsp, err := client.RoundTrip(request)
Expect(err).ToNot(HaveOccurred())
data, err := ioutil.ReadAll(rsp.Body)
Expect(err).ToNot(HaveOccurred())
Expect(rsp.ContentLength).ToNot(BeEquivalentTo(-1))
Expect(string(data)).To(Equal("not gzipped"))
Expect(rsp.Header.Get("Content-Encoding")).To(BeEmpty())
})
})
})
})