forked from quic-go/quic-go
http3: allow re-dialing of connection after a dial error (#4573)
* http3: do not cache dial error * add an integration test * http3: add a unit test for dial failures --------- Co-authored-by: 世界 <i@sekai.icu>
This commit is contained in:
@@ -166,6 +166,7 @@ func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.
|
|||||||
}
|
}
|
||||||
|
|
||||||
if cl.dialErr != nil {
|
if cl.dialErr != nil {
|
||||||
|
r.removeClient(hostname)
|
||||||
return nil, cl.dialErr
|
return nil, cl.dialErr
|
||||||
}
|
}
|
||||||
defer cl.useCount.Add(-1)
|
defer cl.useCount.Add(-1)
|
||||||
@@ -258,6 +259,7 @@ func (r *RoundTripper) getClient(ctx context.Context, hostname string, onlyCache
|
|||||||
select {
|
select {
|
||||||
case <-cl.dialing:
|
case <-cl.dialing:
|
||||||
if cl.dialErr != nil {
|
if cl.dialErr != nil {
|
||||||
|
delete(r.clients, hostname)
|
||||||
return nil, false, cl.dialErr
|
return nil, false, cl.dialErr
|
||||||
}
|
}
|
||||||
select {
|
select {
|
||||||
|
|||||||
@@ -265,6 +265,37 @@ var _ = Describe("RoundTripper", func() {
|
|||||||
Expect(count).To(Equal(1))
|
Expect(count).To(Equal(1))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
It("redials a connection if dialing failed", func() {
|
||||||
|
cl1 := NewMockSingleRoundTripper(mockCtrl)
|
||||||
|
clientChan <- cl1
|
||||||
|
|
||||||
|
req1, err := http.NewRequest("GET", "https://quic-go.net/foo.html", nil)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
req2, err := http.NewRequest("GET", "https://quic-go.net/bar.html", nil)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
|
testErr := errors.New("handshake error")
|
||||||
|
conn := mockquic.NewMockEarlyConnection(mockCtrl)
|
||||||
|
var count int
|
||||||
|
rt.Dial = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
|
||||||
|
count++
|
||||||
|
if count == 1 {
|
||||||
|
return nil, testErr
|
||||||
|
}
|
||||||
|
return conn, nil
|
||||||
|
}
|
||||||
|
handshakeChan := make(chan struct{})
|
||||||
|
close(handshakeChan)
|
||||||
|
conn.EXPECT().HandshakeComplete().Return(handshakeChan).MaxTimes(2)
|
||||||
|
cl1.EXPECT().RoundTrip(req2).Return(&http.Response{Request: req2}, nil)
|
||||||
|
_, err = rt.RoundTrip(req1)
|
||||||
|
Expect(err).To(MatchError(testErr))
|
||||||
|
rsp, err := rt.RoundTrip(req2)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(rsp.Request).To(Equal(req2))
|
||||||
|
Expect(count).To(Equal(2))
|
||||||
|
})
|
||||||
|
|
||||||
It("immediately removes a clients when a request errored", func() {
|
It("immediately removes a clients when a request errored", func() {
|
||||||
cl1 := NewMockSingleRoundTripper(mockCtrl)
|
cl1 := NewMockSingleRoundTripper(mockCtrl)
|
||||||
clientChan <- cl1
|
clientChan <- cl1
|
||||||
|
|||||||
@@ -150,6 +150,29 @@ var _ = Describe("HTTP tests", func() {
|
|||||||
Expect(resp.Header.Get("Content-Length")).To(Equal("6"))
|
Expect(resp.Header.Get("Content-Length")).To(Equal("6"))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
It("re-establishes a QUIC connection after a dial error", func() {
|
||||||
|
var dialCounter int
|
||||||
|
testErr := errors.New("test error")
|
||||||
|
cl := http.Client{
|
||||||
|
Transport: &http3.RoundTripper{
|
||||||
|
TLSClientConfig: getTLSClientConfig(),
|
||||||
|
Dial: func(ctx context.Context, addr string, tlsConf *tls.Config, conf *quic.Config) (quic.EarlyConnection, error) {
|
||||||
|
dialCounter++
|
||||||
|
if dialCounter == 1 { // make the first dial fail
|
||||||
|
return nil, testErr
|
||||||
|
}
|
||||||
|
return quic.DialAddrEarly(ctx, addr, tlsConf, conf)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
defer cl.Transport.(io.Closer).Close()
|
||||||
|
_, err := cl.Get(fmt.Sprintf("https://localhost:%d/hello", port))
|
||||||
|
Expect(err).To(MatchError(testErr))
|
||||||
|
resp, err := cl.Get(fmt.Sprintf("https://localhost:%d/hello", port))
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(resp.StatusCode).To(Equal(http.StatusOK))
|
||||||
|
})
|
||||||
|
|
||||||
It("detects stream errors when server panics when writing response", func() {
|
It("detects stream errors when server panics when writing response", func() {
|
||||||
respChan := make(chan struct{})
|
respChan := make(chan struct{})
|
||||||
mux.HandleFunc("/writing_and_panicking", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/writing_and_panicking", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|||||||
Reference in New Issue
Block a user