From 09111b45f93ddce315aa7326027b1c55207b9b58 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sun, 2 Jun 2019 01:08:11 +0800 Subject: [PATCH] reject a tls.Config without NextProtos for dialing --- client.go | 4 ++++ client_test.go | 28 ++++++++++++++++++---------- 2 files changed, 22 insertions(+), 10 deletions(-) diff --git a/client.go b/client.go index 78d0d70b..a9d8133b 100644 --- a/client.go +++ b/client.go @@ -3,6 +3,7 @@ package quic import ( "context" "crypto/tls" + "errors" "fmt" "net" "sync" @@ -119,6 +120,9 @@ func dialContext( config *Config, createdPacketConn bool, ) (Session, error) { + if tlsConf == nil || len(tlsConf.NextProtos) == 0 { + return nil, errors.New("quic: NextProtos not set in tls.Config") + } config = populateClientConfig(config, createdPacketConn) packetHandlers, err := getMultiplexer().AddConn(pconn, config.ConnectionIDLength, config.StatelessResetKey) if err != nil { diff --git a/client_test.go b/client_test.go index 87bd1a9e..6c1ce156 100644 --- a/client_test.go +++ b/client_test.go @@ -27,6 +27,7 @@ var _ = Describe("Client", func() { connID protocol.ConnectionID mockMultiplexer *MockMultiplexer origMultiplexer multiplexer + tlsConf *tls.Config originalClientSessConstructor func( conn connection, @@ -65,6 +66,7 @@ var _ = Describe("Client", func() { } BeforeEach(func() { + tlsConf = &tls.Config{NextProtos: []string{"proto1"}} connID = protocol.ConnectionID{0, 0, 0, 0, 0, 0, 0x13, 0x37} originalClientSessConstructor = newClientSession Eventually(areSessionsRunning).Should(BeFalse()) @@ -148,7 +150,7 @@ var _ = Describe("Client", func() { sess.EXPECT().run() return sess, nil } - _, err := DialAddr("localhost:17890", nil, &Config{HandshakeTimeout: time.Millisecond}) + _, err := DialAddr("localhost:17890", tlsConf, &Config{HandshakeTimeout: time.Millisecond}) Expect(err).ToNot(HaveOccurred()) Eventually(remoteAddrChan).Should(Receive(Equal("127.0.0.1:17890"))) }) @@ -178,7 +180,8 @@ var _ = Describe("Client", func() { sess.EXPECT().run() return sess, nil } - _, err := DialAddr("localhost:17890", &tls.Config{ServerName: "foobar"}, nil) + tlsConf.ServerName = "foobar" + _, err := DialAddr("localhost:17890", tlsConf, nil) Expect(err).ToNot(HaveOccurred()) Eventually(hostnameChan).Should(Receive(Equal("foobar"))) }) @@ -211,7 +214,7 @@ var _ = Describe("Client", func() { packetConn, addr, "localhost:1337", - nil, + tlsConf, &Config{}, ) Expect(err).ToNot(HaveOccurred()) @@ -247,7 +250,7 @@ var _ = Describe("Client", func() { packetConn, addr, "localhost:1337", - nil, + tlsConf, &Config{}, ) Expect(err).To(MatchError(testErr)) @@ -288,7 +291,7 @@ var _ = Describe("Client", func() { packetConn, addr, "localhost:1337", - nil, + tlsConf, &Config{}, ) Expect(err).To(MatchError(context.Canceled)) @@ -333,7 +336,7 @@ var _ = Describe("Client", func() { packetConn, addr, "localhost:1337", - nil, + tlsConf, &Config{}, ) Expect(err).ToNot(HaveOccurred()) @@ -376,7 +379,7 @@ var _ = Describe("Client", func() { done := make(chan struct{}) go func() { defer GinkgoRecover() - _, err := DialAddr("localhost:1337", nil, nil) + _, err := DialAddr("localhost:1337", tlsConf, nil) Expect(err).ToNot(HaveOccurred()) close(done) }() @@ -417,10 +420,15 @@ var _ = Describe("Client", func() { mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any()).Return(manager, nil) version := protocol.VersionNumber(0x1234) - _, err := Dial(packetConn, nil, "localhost:1234", &tls.Config{}, &Config{Versions: []protocol.VersionNumber{version}}) + _, err := Dial(packetConn, nil, "localhost:1234", tlsConf, &Config{Versions: []protocol.VersionNumber{version}}) Expect(err).To(MatchError("0x1234 is not a valid QUIC version")) }) + It("erros when the tls.Config doesn't contain NextProtos", func() { + _, err := Dial(packetConn, nil, "localhost:1234", &tls.Config{}, nil) + Expect(err).To(MatchError("quic: NextProtos not set in tls.Config")) + }) + It("disables bidirectional streams", func() { config := &Config{ MaxIncomingStreams: -1, @@ -487,7 +495,7 @@ var _ = Describe("Client", func() { sess.EXPECT().run() return sess, nil } - _, err := Dial(packetConn, addr, "localhost:1337", nil, config) + _, err := Dial(packetConn, addr, "localhost:1337", tlsConf, config) Expect(err).ToNot(HaveOccurred()) Eventually(c).Should(BeClosed()) Expect(cconn.(*conn).pconn).To(Equal(packetConn)) @@ -535,7 +543,7 @@ var _ = Describe("Client", func() { packetConn, addr, "localhost:1337", - nil, + tlsConf, &Config{}, ) Expect(err).To(MatchError(testErr))