From f92b0ec74af16be1493cf0228469dd9c77391730 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 21 Dec 2020 16:43:49 +0700 Subject: [PATCH] make the HTTP/3 client request tests more strict --- http3/client_test.go | 82 +++++++++++++++++++++++--------------------- 1 file changed, 43 insertions(+), 39 deletions(-) diff --git a/http3/client_test.go b/http3/client_test.go index 29b6909e1..4af17fb4e 100644 --- a/http3/client_test.go +++ b/http3/client_test.go @@ -147,34 +147,34 @@ var _ = Describe("Client", func() { Expect(err).To(MatchError(testErr)) }) - It("errors if it can't open a stream", func() { - testErr := errors.New("stream open error") - client, err := newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil) - Expect(err).ToNot(HaveOccurred()) - session := mockquic.NewMockEarlySession(mockCtrl) - session.EXPECT().OpenUniStream().Return(nil, testErr).MaxTimes(1) - session.EXPECT().HandshakeComplete().Return(handshakeCtx).MaxTimes(1) - session.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr).MaxTimes(1) - session.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).MaxTimes(1) - dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlySession, error) { - return session, nil - } - defer GinkgoRecover() - _, err = client.RoundTrip(req) - Expect(err).To(MatchError(testErr)) - }) - It("closes correctly if session was not created", func() { client, err := newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil) Expect(err).ToNot(HaveOccurred()) Expect(client.Close()).To(Succeed()) }) + Context("validating the address", func() { + It("refuses to do requests for the wrong host", func() { + req, err := http.NewRequest("https", "https://quic.clemente.io:1336/foobar.html", nil) + Expect(err).ToNot(HaveOccurred()) + _, err = client.RoundTrip(req) + Expect(err).To(MatchError("http3 client BUG: RoundTrip called for the wrong client (expected quic.clemente.io:1337, got quic.clemente.io:1336)")) + }) + + It("refuses to do plain HTTP requests", func() { + req, err := http.NewRequest("https", "http://quic.clemente.io:1337/foobar.html", nil) + Expect(err).ToNot(HaveOccurred()) + _, err = client.RoundTrip(req) + Expect(err).To(MatchError("http3: unsupported scheme")) + }) + }) + Context("Doing requests", func() { var ( - request *http.Request - str *mockquic.MockStream - sess *mockquic.MockEarlySession + request *http.Request + str *mockquic.MockStream + sess *mockquic.MockEarlySession + settingsFrameWritten chan struct{} ) decodeHeader := func(str io.Reader) map[string]string { @@ -197,12 +197,19 @@ var _ = Describe("Client", func() { } BeforeEach(func() { + settingsFrameWritten = make(chan struct{}) controlStr := mockquic.NewMockStream(mockCtrl) - controlStr.EXPECT().Write([]byte{0x0}).Return(1, nil).MaxTimes(1) - controlStr.EXPECT().Write(gomock.Any()).MaxTimes(1) // SETTINGS frame + controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) { + defer GinkgoRecover() + r := bytes.NewReader(b) + streamType, err := utils.ReadVarInt(r) + Expect(err).ToNot(HaveOccurred()) + Expect(streamType).To(BeEquivalentTo(streamTypeControlStream)) + close(settingsFrameWritten) + }) // SETTINGS frame str = mockquic.NewMockStream(mockCtrl) sess = mockquic.NewMockEarlySession(mockCtrl) - sess.EXPECT().OpenUniStream().Return(controlStr, nil).MaxTimes(1) + sess.EXPECT().OpenUniStream().Return(controlStr, nil) dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlySession, error) { return sess, nil } @@ -211,6 +218,19 @@ var _ = Describe("Client", func() { Expect(err).ToNot(HaveOccurred()) }) + AfterEach(func() { + Eventually(settingsFrameWritten).Should(BeClosed()) + }) + + It("errors if it can't open a stream", func() { + testErr := errors.New("stream open error") + sess.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr) + sess.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).MaxTimes(1) + sess.EXPECT().HandshakeComplete().Return(handshakeCtx) + _, err := client.RoundTrip(request) + Expect(err).To(MatchError(testErr)) + }) + It("performs a 0-RTT request", func() { testErr := errors.New("stream open error") request.Method = MethodGet0RTT @@ -253,22 +273,6 @@ var _ = Describe("Client", func() { Expect(rsp.StatusCode).To(Equal(418)) }) - Context("validating the address", func() { - It("refuses to do requests for the wrong host", func() { - req, err := http.NewRequest("https", "https://quic.clemente.io:1336/foobar.html", nil) - Expect(err).ToNot(HaveOccurred()) - _, err = client.RoundTrip(req) - Expect(err).To(MatchError("http3 client BUG: RoundTrip called for the wrong client (expected quic.clemente.io:1337, got quic.clemente.io:1336)")) - }) - - It("refuses to do plain HTTP requests", func() { - req, err := http.NewRequest("https", "http://quic.clemente.io:1337/foobar.html", nil) - Expect(err).ToNot(HaveOccurred()) - _, err = client.RoundTrip(req) - Expect(err).To(MatchError("http3: unsupported scheme")) - }) - }) - Context("requests containing a Body", func() { var strBuf *bytes.Buffer