From 014315d3c7abe801e04ef36ae1f0cb56252c7074 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Fri, 18 Nov 2016 19:38:10 +0700 Subject: [PATCH] parse hostname from address and pass it to the client cryptoSetup --- client.go | 28 ++++++++++++--- example/client/client.go | 11 ++---- handshake/crypto_setup_client.go | 9 +++-- handshake/crypto_setup_client_test.go | 5 +-- session.go | 4 +-- session_test.go | 1 + utils/host.go | 27 +++++++++++++++ utils/host_test.go | 49 +++++++++++++++++++++++++++ 8 files changed, 114 insertions(+), 20 deletions(-) create mode 100644 utils/host.go create mode 100644 utils/host_test.go diff --git a/client.go b/client.go index 5fa62430c..16cdc897b 100644 --- a/client.go +++ b/client.go @@ -2,8 +2,10 @@ package quic import ( "bytes" + "errors" "math/rand" "net" + "net/url" "time" "github.com/lucas-clemente/quic-go/protocol" @@ -22,8 +24,26 @@ type Client struct { session *Session } +var errHostname = errors.New("Invalid hostname") + // NewClient makes a new client -func NewClient(addr *net.UDPAddr) (*Client, error) { +func NewClient(addr string) (*Client, error) { + hostname, err := utils.HostnameFromAddr(addr) + if err != nil || len(hostname) == 0 { + return nil, errHostname + } + + p, err := url.Parse(addr) + if err != nil { + return nil, err + } + host := p.Host + + udpAddr, err := net.ResolveUDPAddr("udp", host) + if err != nil { + return nil, err + } + conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) if err != nil { return nil, err @@ -33,10 +53,10 @@ func NewClient(addr *net.UDPAddr) (*Client, error) { rand.Seed(time.Now().UTC().UnixNano()) connectionID := protocol.ConnectionID(rand.Int63()) - utils.Infof("Starting new connection to %s, connectionID %x", addr.String(), connectionID) + utils.Infof("Starting new connection to %s (%s), connectionID %x", host, udpAddr.String(), connectionID) client := &Client{ - addr: addr, + addr: udpAddr, conn: conn, version: protocol.Version36, connectionID: connectionID, @@ -44,7 +64,7 @@ func NewClient(addr *net.UDPAddr) (*Client, error) { streamCallback := func(session *Session, stream utils.Stream) {} - client.session, err = newClientSession(conn, addr, client.version, client.connectionID, streamCallback, client.closeCallback) + client.session, err = newClientSession(conn, udpAddr, hostname, client.version, client.connectionID, streamCallback, client.closeCallback) if err != nil { return nil, err } diff --git a/example/client/client.go b/example/client/client.go index 4a466b832..0925033e3 100644 --- a/example/client/client.go +++ b/example/client/client.go @@ -1,23 +1,16 @@ package main import ( - "net" - quic "github.com/lucas-clemente/quic-go" "github.com/lucas-clemente/quic-go/utils" ) func main() { - addr := "quic.clemente.io:6121" + addr := "https://quic.clemente.io:6121" utils.SetLogLevel(utils.LogLevelDebug) - udpAddr, err := net.ResolveUDPAddr("udp", addr) - if err != nil { - panic(err) - } - - client, err := quic.NewClient(udpAddr) + client, err := quic.NewClient(addr) if err != nil { panic(err) } diff --git a/handshake/crypto_setup_client.go b/handshake/crypto_setup_client.go index 47cc835d4..9bd6a7889 100644 --- a/handshake/crypto_setup_client.go +++ b/handshake/crypto_setup_client.go @@ -15,8 +15,9 @@ import ( ) type cryptoSetupClient struct { - connID protocol.ConnectionID - version protocol.VersionNumber + hostname string + connID protocol.ConnectionID + version protocol.VersionNumber cryptoStream utils.Stream @@ -48,11 +49,13 @@ var ( // NewCryptoSetupClient creates a new CryptoSetup instance for a client func NewCryptoSetupClient( + hostname string, connID protocol.ConnectionID, version protocol.VersionNumber, cryptoStream utils.Stream, ) (CryptoSetup, error) { return &cryptoSetupClient{ + hostname: hostname, connID: connID, version: version, cryptoStream: cryptoStream, @@ -282,7 +285,7 @@ func (h *cryptoSetupClient) sendCHLO() error { func (h *cryptoSetupClient) getTags() map[Tag][]byte { tags := make(map[Tag][]byte) - tags[TagSNI] = []byte("quic.clemente.io") // TODO: use real SNI here + tags[TagSNI] = []byte(h.hostname) tags[TagPDMD] = []byte("X509") versionTag := make([]byte, 4, 4) diff --git a/handshake/crypto_setup_client_test.go b/handshake/crypto_setup_client_test.go index 1156c2964..d19bc4216 100644 --- a/handshake/crypto_setup_client_test.go +++ b/handshake/crypto_setup_client_test.go @@ -75,7 +75,7 @@ var _ = Describe("Crypto setup", func() { BeforeEach(func() { stream = &mockStream{} certManager = &mockCertManager{} - csInt, err := NewCryptoSetupClient(0, protocol.Version36, stream) + csInt, err := NewCryptoSetupClient("hostname", 0, protocol.Version36, stream) Expect(err).ToNot(HaveOccurred()) cs = csInt.(*cryptoSetupClient) cs.certManager = certManager @@ -302,8 +302,9 @@ var _ = Describe("Crypto setup", func() { }) It("has the right values for an inchoate CHLO", func() { + cs.hostname = "sni-hostname" tags := cs.getTags() - Expect(tags).To(HaveKey(TagSNI)) + Expect(string(tags[TagSNI])).To(Equal(cs.hostname)) Expect(tags[TagPDMD]).To(Equal([]byte("X509"))) Expect(tags[TagVER]).To(Equal([]byte("Q036"))) }) diff --git a/session.go b/session.go index 874b0faa4..8633eb524 100644 --- a/session.go +++ b/session.go @@ -120,7 +120,7 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol return session, err } -func newClientSession(conn *net.UDPConn, addr *net.UDPAddr, v protocol.VersionNumber, connectionID protocol.ConnectionID, streamCallback StreamCallback, closeCallback closeCallback) (*Session, error) { +func newClientSession(conn *net.UDPConn, addr *net.UDPAddr, hostname string, v protocol.VersionNumber, connectionID protocol.ConnectionID, streamCallback StreamCallback, closeCallback closeCallback) (*Session, error) { session := &Session{ conn: &udpConn{conn: conn, currentAddr: addr}, connectionID: connectionID, @@ -136,7 +136,7 @@ func newClientSession(conn *net.UDPConn, addr *net.UDPAddr, v protocol.VersionNu cryptoStream, _ := session.GetOrOpenStream(1) var err error - session.cryptoSetup, err = handshake.NewCryptoSetupClient(connectionID, v, cryptoStream) + session.cryptoSetup, err = handshake.NewCryptoSetupClient(hostname, connectionID, v, cryptoStream) if err != nil { return nil, err } diff --git a/session_test.go b/session_test.go index 9bf488e9b..0e9d2d34c 100644 --- a/session_test.go +++ b/session_test.go @@ -152,6 +152,7 @@ var _ = Describe("Session", func() { clientSession, err = newClientSession( &net.UDPConn{}, &net.UDPAddr{}, + "hostname", protocol.Version35, 0, func(*Session, utils.Stream) { streamCallbackCalled = true }, diff --git a/utils/host.go b/utils/host.go new file mode 100644 index 000000000..a1d6453b0 --- /dev/null +++ b/utils/host.go @@ -0,0 +1,27 @@ +package utils + +import ( + "net/url" + "strings" +) + +// HostnameFromAddr determines the hostname in an address string +func HostnameFromAddr(addr string) (string, error) { + p, err := url.Parse(addr) + if err != nil { + return "", err + } + h := p.Host + + // copied from https://golang.org/src/net/http/transport.go + if hasPort(h) { + h = h[:strings.LastIndex(h, ":")] + } + + return h, nil +} + +// copied from https://golang.org/src/net/http/http.go +func hasPort(s string) bool { + return strings.LastIndex(s, ":") > strings.LastIndex(s, "]") +} diff --git a/utils/host_test.go b/utils/host_test.go new file mode 100644 index 000000000..d7667eb3c --- /dev/null +++ b/utils/host_test.go @@ -0,0 +1,49 @@ +package utils + +import ( + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Hostname", func() { + It("gets the hostname from an URL", func() { + h, err := HostnameFromAddr("https://quic.clemente.io/file.dat?param=true¶m2=false") + Expect(err).ToNot(HaveOccurred()) + Expect(h).To(Equal("quic.clemente.io")) + }) + + It("gets the hostname from an URL with a port number", func() { + h, err := HostnameFromAddr("https://quic.clemente.io:6121/file.dat") + Expect(err).ToNot(HaveOccurred()) + Expect(h).To(Equal("quic.clemente.io")) + }) + + It("gets the hostname from an URL containing username and password", func() { + h, err := HostnameFromAddr("https://user:password@quic.clemente.io:6121/file.dat") + Expect(err).ToNot(HaveOccurred()) + Expect(h).To(Equal("quic.clemente.io")) + }) + + It("gets local hostnames", func() { + h, err := HostnameFromAddr("https://localhost/file.dat") + Expect(err).ToNot(HaveOccurred()) + Expect(h).To(Equal("localhost")) + }) + + It("gets the hostname for other protocols", func() { + h, err := HostnameFromAddr("ftp://quic.clemente.io:6121/file.dat") + Expect(err).ToNot(HaveOccurred()) + Expect(h).To(Equal("quic.clemente.io")) + }) + + It("gets an IP", func() { + h, err := HostnameFromAddr("https://1.3.3.7:6121/file.dat") + Expect(err).ToNot(HaveOccurred()) + Expect(h).To(Equal("1.3.3.7")) + }) + + It("errors on malformed URLs", func() { + _, err := HostnameFromAddr("://quic.clemente.io:6121/file.dat") + Expect(err).To(HaveOccurred()) + }) +})