From edb34b17658186e3b3be1453691283d1673e1ff0 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 19 Dec 2016 14:41:44 +0700 Subject: [PATCH] close the quic client on protocol errors in h2quic client --- h2quic/client.go | 9 +++++++++ h2quic/client_test.go | 29 ++++++++++++++++++++++++++--- 2 files changed, 35 insertions(+), 3 deletions(-) diff --git a/h2quic/client.go b/h2quic/client.go index 585347113..65bca65b5 100644 --- a/h2quic/client.go +++ b/h2quic/client.go @@ -162,12 +162,15 @@ func (c *Client) Do(req *http.Request) (*http.Response, error) { c.responses[dataStreamID] = hdrChan c.mutex.Unlock() + // TODO: think about what to do with a TooManyOpenStreams error. Wait and retry? dataStream, err := c.client.OpenStream(dataStreamID) if err != nil { + c.Close(err) return nil, err } err = c.requestWriter.WriteRequest(req, dataStreamID) if err != nil { + c.Close(err) return nil, err } @@ -181,6 +184,7 @@ func (c *Client) Do(req *http.Request) (*http.Response, error) { // if an error occured on the header stream if res == nil { + c.Close(c.headerErr) return nil, c.headerErr } @@ -203,6 +207,11 @@ func (c *Client) Do(req *http.Request) (*http.Response, error) { return res, nil } +// Close closes the client +func (c *Client) Close(e error) { + _ = c.client.Close(e) +} + // copied from net/transport.go // authorityAddr returns a given authority (a host/IP, or host:port / ip:port) diff --git a/h2quic/client_test.go b/h2quic/client_test.go index 2e8aeec52..59ab85ea4 100644 --- a/h2quic/client_test.go +++ b/h2quic/client_test.go @@ -14,11 +14,12 @@ import ( ) type mockQuicClient struct { - streams map[protocol.StreamID]*mockStream + streams map[protocol.StreamID]*mockStream + closeErr error } -func (m *mockQuicClient) Close(error) error { panic("not implemented") } -func (m *mockQuicClient) Listen() error { panic("not implemented") } +func (m *mockQuicClient) Close(e error) error { m.closeErr = e; return nil } +func (m *mockQuicClient) Listen() error { panic("not implemented") } func (m *mockQuicClient) OpenStream(id protocol.StreamID) (utils.Stream, error) { _, ok := m.streams[id] if ok { @@ -121,6 +122,28 @@ var _ = Describe("Client", func() { close(done) }) + It("closes the quic client when encountering an error on the header stream", func() { + req, err := http.NewRequest("https", "https://quic.clemente.io:1337/file1.dat", nil) + Expect(err).ToNot(HaveOccurred()) + + headerStream.dataToRead.Write([]byte("invalid response")) + go client.handleHeaderStream() + + var doRsp *http.Response + var doErr error + var doReturned bool + go func() { + doRsp, doErr = client.Do(req) + doReturned = true + }() + + Eventually(func() bool { return doReturned }).Should(BeTrue()) + Expect(client.headerErr).To(HaveOccurred()) + Expect(doErr).To(MatchError(client.headerErr)) + Expect(doRsp).To(BeNil()) + Expect(client.client.(*mockQuicClient).closeErr).To(MatchError(client.headerErr)) + }) + Context("validating the address", func() { It("refuses to do requests for the wrong host", func() { req, err := http.NewRequest("https", "https://quic.clemente.io:1336/foobar.html", nil)