From 755a46d6e261df6b62b805c550e05358af4bdc8f Mon Sep 17 00:00:00 2001 From: spacewander Date: Sat, 6 Jul 2019 15:46:26 +0800 Subject: [PATCH] allow host without port passed as 'host' argument in Dial function. Previously, if the given host doesn't contain port, dial with it will result in error "missing port in address". --- client.go | 14 ++++++++++---- client_test.go | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+), 4 deletions(-) diff --git a/client.go b/client.go index 63d52393c..6234a64c5 100644 --- a/client.go +++ b/client.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "net" + "strings" "sync" "github.com/lucas-clemente/quic-go/internal/handshake" @@ -151,11 +152,16 @@ func newClient( tlsConf = &tls.Config{} } if tlsConf.ServerName == "" { - var err error - tlsConf.ServerName, _, err = net.SplitHostPort(host) - if err != nil { - return nil, err + sni := host + if strings.IndexByte(sni, ':') != -1 { + var err error + sni, _, err = net.SplitHostPort(sni) + if err != nil { + return nil, err + } } + + tlsConf.ServerName = sni } // check that all versions are actually supported diff --git a/client_test.go b/client_test.go index d17f50661..60965e69c 100644 --- a/client_test.go +++ b/client_test.go @@ -187,6 +187,41 @@ var _ = Describe("Client", func() { Eventually(hostnameChan).Should(Receive(Equal("foobar"))) }) + It("allows passing host without port as server name", func() { + manager := NewMockPacketHandlerManager(mockCtrl) + manager.EXPECT().Add(gomock.Any(), gomock.Any()) + mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any()).Return(manager, nil) + + hostnameChan := make(chan string, 1) + newClientSession = func( + _ connection, + _ sessionRunner, + _ protocol.ConnectionID, + _ protocol.ConnectionID, + _ *Config, + tlsConf *tls.Config, + _ protocol.PacketNumber, + _ *handshake.TransportParameters, + _ protocol.VersionNumber, + _ utils.Logger, + _ protocol.VersionNumber, + ) (quicSession, error) { + hostnameChan <- tlsConf.ServerName + sess := NewMockQuicSession(mockCtrl) + sess.EXPECT().run() + return sess, nil + } + _, err := Dial( + packetConn, + addr, + "test.com", + tlsConf, + &Config{}, + ) + Expect(err).ToNot(HaveOccurred()) + Eventually(hostnameChan).Should(Receive(Equal("test.com"))) + }) + It("returns after the handshake is complete", func() { manager := NewMockPacketHandlerManager(mockCtrl) manager.EXPECT().Add(gomock.Any(), gomock.Any())