parse hostname from address and pass it to the client cryptoSetup

This commit is contained in:
Marten Seemann
2016-11-18 19:38:10 +07:00
parent 4b8508c017
commit 014315d3c7
8 changed files with 114 additions and 20 deletions

View File

@@ -2,8 +2,10 @@ package quic
import (
"bytes"
"errors"
"math/rand"
"net"
"net/url"
"time"
"github.com/lucas-clemente/quic-go/protocol"
@@ -22,8 +24,26 @@ type Client struct {
session *Session
}
var errHostname = errors.New("Invalid hostname")
// NewClient makes a new client
func NewClient(addr *net.UDPAddr) (*Client, error) {
func NewClient(addr string) (*Client, error) {
hostname, err := utils.HostnameFromAddr(addr)
if err != nil || len(hostname) == 0 {
return nil, errHostname
}
p, err := url.Parse(addr)
if err != nil {
return nil, err
}
host := p.Host
udpAddr, err := net.ResolveUDPAddr("udp", host)
if err != nil {
return nil, err
}
conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
return nil, err
@@ -33,10 +53,10 @@ func NewClient(addr *net.UDPAddr) (*Client, error) {
rand.Seed(time.Now().UTC().UnixNano())
connectionID := protocol.ConnectionID(rand.Int63())
utils.Infof("Starting new connection to %s, connectionID %x", addr.String(), connectionID)
utils.Infof("Starting new connection to %s (%s), connectionID %x", host, udpAddr.String(), connectionID)
client := &Client{
addr: addr,
addr: udpAddr,
conn: conn,
version: protocol.Version36,
connectionID: connectionID,
@@ -44,7 +64,7 @@ func NewClient(addr *net.UDPAddr) (*Client, error) {
streamCallback := func(session *Session, stream utils.Stream) {}
client.session, err = newClientSession(conn, addr, client.version, client.connectionID, streamCallback, client.closeCallback)
client.session, err = newClientSession(conn, udpAddr, hostname, client.version, client.connectionID, streamCallback, client.closeCallback)
if err != nil {
return nil, err
}

View File

@@ -1,23 +1,16 @@
package main
import (
"net"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/utils"
)
func main() {
addr := "quic.clemente.io:6121"
addr := "https://quic.clemente.io:6121"
utils.SetLogLevel(utils.LogLevelDebug)
udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
panic(err)
}
client, err := quic.NewClient(udpAddr)
client, err := quic.NewClient(addr)
if err != nil {
panic(err)
}

View File

@@ -15,6 +15,7 @@ import (
)
type cryptoSetupClient struct {
hostname string
connID protocol.ConnectionID
version protocol.VersionNumber
@@ -48,11 +49,13 @@ var (
// NewCryptoSetupClient creates a new CryptoSetup instance for a client
func NewCryptoSetupClient(
hostname string,
connID protocol.ConnectionID,
version protocol.VersionNumber,
cryptoStream utils.Stream,
) (CryptoSetup, error) {
return &cryptoSetupClient{
hostname: hostname,
connID: connID,
version: version,
cryptoStream: cryptoStream,
@@ -282,7 +285,7 @@ func (h *cryptoSetupClient) sendCHLO() error {
func (h *cryptoSetupClient) getTags() map[Tag][]byte {
tags := make(map[Tag][]byte)
tags[TagSNI] = []byte("quic.clemente.io") // TODO: use real SNI here
tags[TagSNI] = []byte(h.hostname)
tags[TagPDMD] = []byte("X509")
versionTag := make([]byte, 4, 4)

View File

@@ -75,7 +75,7 @@ var _ = Describe("Crypto setup", func() {
BeforeEach(func() {
stream = &mockStream{}
certManager = &mockCertManager{}
csInt, err := NewCryptoSetupClient(0, protocol.Version36, stream)
csInt, err := NewCryptoSetupClient("hostname", 0, protocol.Version36, stream)
Expect(err).ToNot(HaveOccurred())
cs = csInt.(*cryptoSetupClient)
cs.certManager = certManager
@@ -302,8 +302,9 @@ var _ = Describe("Crypto setup", func() {
})
It("has the right values for an inchoate CHLO", func() {
cs.hostname = "sni-hostname"
tags := cs.getTags()
Expect(tags).To(HaveKey(TagSNI))
Expect(string(tags[TagSNI])).To(Equal(cs.hostname))
Expect(tags[TagPDMD]).To(Equal([]byte("X509")))
Expect(tags[TagVER]).To(Equal([]byte("Q036")))
})

View File

@@ -120,7 +120,7 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol
return session, err
}
func newClientSession(conn *net.UDPConn, addr *net.UDPAddr, v protocol.VersionNumber, connectionID protocol.ConnectionID, streamCallback StreamCallback, closeCallback closeCallback) (*Session, error) {
func newClientSession(conn *net.UDPConn, addr *net.UDPAddr, hostname string, v protocol.VersionNumber, connectionID protocol.ConnectionID, streamCallback StreamCallback, closeCallback closeCallback) (*Session, error) {
session := &Session{
conn: &udpConn{conn: conn, currentAddr: addr},
connectionID: connectionID,
@@ -136,7 +136,7 @@ func newClientSession(conn *net.UDPConn, addr *net.UDPAddr, v protocol.VersionNu
cryptoStream, _ := session.GetOrOpenStream(1)
var err error
session.cryptoSetup, err = handshake.NewCryptoSetupClient(connectionID, v, cryptoStream)
session.cryptoSetup, err = handshake.NewCryptoSetupClient(hostname, connectionID, v, cryptoStream)
if err != nil {
return nil, err
}

View File

@@ -152,6 +152,7 @@ var _ = Describe("Session", func() {
clientSession, err = newClientSession(
&net.UDPConn{},
&net.UDPAddr{},
"hostname",
protocol.Version35,
0,
func(*Session, utils.Stream) { streamCallbackCalled = true },

27
utils/host.go Normal file
View File

@@ -0,0 +1,27 @@
package utils
import (
"net/url"
"strings"
)
// HostnameFromAddr determines the hostname in an address string
func HostnameFromAddr(addr string) (string, error) {
p, err := url.Parse(addr)
if err != nil {
return "", err
}
h := p.Host
// copied from https://golang.org/src/net/http/transport.go
if hasPort(h) {
h = h[:strings.LastIndex(h, ":")]
}
return h, nil
}
// copied from https://golang.org/src/net/http/http.go
func hasPort(s string) bool {
return strings.LastIndex(s, ":") > strings.LastIndex(s, "]")
}

49
utils/host_test.go Normal file
View File

@@ -0,0 +1,49 @@
package utils
import (
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Hostname", func() {
It("gets the hostname from an URL", func() {
h, err := HostnameFromAddr("https://quic.clemente.io/file.dat?param=true&param2=false")
Expect(err).ToNot(HaveOccurred())
Expect(h).To(Equal("quic.clemente.io"))
})
It("gets the hostname from an URL with a port number", func() {
h, err := HostnameFromAddr("https://quic.clemente.io:6121/file.dat")
Expect(err).ToNot(HaveOccurred())
Expect(h).To(Equal("quic.clemente.io"))
})
It("gets the hostname from an URL containing username and password", func() {
h, err := HostnameFromAddr("https://user:password@quic.clemente.io:6121/file.dat")
Expect(err).ToNot(HaveOccurred())
Expect(h).To(Equal("quic.clemente.io"))
})
It("gets local hostnames", func() {
h, err := HostnameFromAddr("https://localhost/file.dat")
Expect(err).ToNot(HaveOccurred())
Expect(h).To(Equal("localhost"))
})
It("gets the hostname for other protocols", func() {
h, err := HostnameFromAddr("ftp://quic.clemente.io:6121/file.dat")
Expect(err).ToNot(HaveOccurred())
Expect(h).To(Equal("quic.clemente.io"))
})
It("gets an IP", func() {
h, err := HostnameFromAddr("https://1.3.3.7:6121/file.dat")
Expect(err).ToNot(HaveOccurred())
Expect(h).To(Equal("1.3.3.7"))
})
It("errors on malformed URLs", func() {
_, err := HostnameFromAddr("://quic.clemente.io:6121/file.dat")
Expect(err).To(HaveOccurred())
})
})