diff --git a/h2quic/client.go b/h2quic/client.go index 58a0e21fd..c1ff76c5f 100644 --- a/h2quic/client.go +++ b/h2quic/client.go @@ -24,7 +24,8 @@ import ( type Client struct { mutex sync.RWMutex - config *quic.Config + dialAddr func(hostname string, config *quic.Config) (quic.Session, error) + config *quic.Config t *QuicRoundTripper @@ -46,6 +47,7 @@ var _ h2quicClient = &Client{} func NewClient(t *QuicRoundTripper, tlsConfig *tls.Config, hostname string) *Client { return &Client{ t: t, + dialAddr: quic.DialAddr, hostname: authorityAddr("https", hostname), responses: make(map[protocol.StreamID]chan *http.Response), encryptionLevel: protocol.EncryptionUnencrypted, @@ -60,7 +62,7 @@ func NewClient(t *QuicRoundTripper, tlsConfig *tls.Config, hostname string) *Cli // Dial dials the connection func (c *Client) Dial() error { var err error - c.session, err = quic.DialAddr(c.hostname, c.config) + c.session, err = c.dialAddr(c.hostname, c.config) if err != nil { return err } diff --git a/h2quic/client_test.go b/h2quic/client_test.go index f73ce686c..07664957b 100644 --- a/h2quic/client_test.go +++ b/h2quic/client_test.go @@ -5,7 +5,6 @@ import ( "compress/gzip" "crypto/tls" "errors" - "net" "net/http" "golang.org/x/net/http2" @@ -52,16 +51,45 @@ var _ = Describe("Client", func() { }) It("dials", func() { - udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}) + client = NewClient(quicTransport, nil, "localhost") + session.streamToOpen = &mockStream{id: 3} + client.dialAddr = func(hostname string, conf *quic.Config) (quic.Session, error) { + return session, nil + } + err := client.Dial() Expect(err).ToNot(HaveOccurred()) - client = NewClient(quicTransport, nil, udpConn.LocalAddr().String()) - go client.Dial() - data := make([]byte, 100) - _, err = udpConn.Read(data) - hdr, err := quic.ParsePublicHeader(bytes.NewReader(data), protocol.PerspectiveClient) - Expect(err).ToNot(HaveOccurred()) - Expect(hdr.VersionFlag).To(BeTrue()) - Expect(hdr.ConnectionID).ToNot(BeNil()) + Expect(client.session).To(Equal(session)) + }) + + It("errors when dialing fails", func() { + testErr := errors.New("handshake error") + client = NewClient(quicTransport, nil, "localhost") + client.dialAddr = func(hostname string, conf *quic.Config) (quic.Session, error) { + return nil, testErr + } + err := client.Dial() + Expect(err).To(MatchError(testErr)) + }) + + It("errors if the header stream has the wrong stream ID", func() { + client = NewClient(quicTransport, nil, "localhost") + session.streamToOpen = &mockStream{id: 2} + client.dialAddr = func(hostname string, conf *quic.Config) (quic.Session, error) { + return session, nil + } + err := client.Dial() + Expect(err).To(MatchError("h2quic Client BUG: StreamID of Header Stream is not 3")) + }) + + It("errors if it can't open a stream", func() { + testErr := errors.New("you shall not pass") + client = NewClient(quicTransport, nil, "localhost") + session.streamOpenErr = testErr + client.dialAddr = func(hostname string, conf *quic.Config) (quic.Session, error) { + return session, nil + } + err := client.Dial() + Expect(err).To(MatchError(testErr)) }) Context("Doing requests", func() {