reject a tls.Config without NextProtos for dialing

This commit is contained in:
Marten Seemann
2019-06-02 01:08:11 +08:00
parent 8eeddeb9c0
commit 09111b45f9
2 changed files with 22 additions and 10 deletions

View File

@@ -3,6 +3,7 @@ package quic
import ( import (
"context" "context"
"crypto/tls" "crypto/tls"
"errors"
"fmt" "fmt"
"net" "net"
"sync" "sync"
@@ -119,6 +120,9 @@ func dialContext(
config *Config, config *Config,
createdPacketConn bool, createdPacketConn bool,
) (Session, error) { ) (Session, error) {
if tlsConf == nil || len(tlsConf.NextProtos) == 0 {
return nil, errors.New("quic: NextProtos not set in tls.Config")
}
config = populateClientConfig(config, createdPacketConn) config = populateClientConfig(config, createdPacketConn)
packetHandlers, err := getMultiplexer().AddConn(pconn, config.ConnectionIDLength, config.StatelessResetKey) packetHandlers, err := getMultiplexer().AddConn(pconn, config.ConnectionIDLength, config.StatelessResetKey)
if err != nil { if err != nil {

View File

@@ -27,6 +27,7 @@ var _ = Describe("Client", func() {
connID protocol.ConnectionID connID protocol.ConnectionID
mockMultiplexer *MockMultiplexer mockMultiplexer *MockMultiplexer
origMultiplexer multiplexer origMultiplexer multiplexer
tlsConf *tls.Config
originalClientSessConstructor func( originalClientSessConstructor func(
conn connection, conn connection,
@@ -65,6 +66,7 @@ var _ = Describe("Client", func() {
} }
BeforeEach(func() { BeforeEach(func() {
tlsConf = &tls.Config{NextProtos: []string{"proto1"}}
connID = protocol.ConnectionID{0, 0, 0, 0, 0, 0, 0x13, 0x37} connID = protocol.ConnectionID{0, 0, 0, 0, 0, 0, 0x13, 0x37}
originalClientSessConstructor = newClientSession originalClientSessConstructor = newClientSession
Eventually(areSessionsRunning).Should(BeFalse()) Eventually(areSessionsRunning).Should(BeFalse())
@@ -148,7 +150,7 @@ var _ = Describe("Client", func() {
sess.EXPECT().run() sess.EXPECT().run()
return sess, nil return sess, nil
} }
_, err := DialAddr("localhost:17890", nil, &Config{HandshakeTimeout: time.Millisecond}) _, err := DialAddr("localhost:17890", tlsConf, &Config{HandshakeTimeout: time.Millisecond})
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Eventually(remoteAddrChan).Should(Receive(Equal("127.0.0.1:17890"))) Eventually(remoteAddrChan).Should(Receive(Equal("127.0.0.1:17890")))
}) })
@@ -178,7 +180,8 @@ var _ = Describe("Client", func() {
sess.EXPECT().run() sess.EXPECT().run()
return sess, nil return sess, nil
} }
_, err := DialAddr("localhost:17890", &tls.Config{ServerName: "foobar"}, nil) tlsConf.ServerName = "foobar"
_, err := DialAddr("localhost:17890", tlsConf, nil)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Eventually(hostnameChan).Should(Receive(Equal("foobar"))) Eventually(hostnameChan).Should(Receive(Equal("foobar")))
}) })
@@ -211,7 +214,7 @@ var _ = Describe("Client", func() {
packetConn, packetConn,
addr, addr,
"localhost:1337", "localhost:1337",
nil, tlsConf,
&Config{}, &Config{},
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
@@ -247,7 +250,7 @@ var _ = Describe("Client", func() {
packetConn, packetConn,
addr, addr,
"localhost:1337", "localhost:1337",
nil, tlsConf,
&Config{}, &Config{},
) )
Expect(err).To(MatchError(testErr)) Expect(err).To(MatchError(testErr))
@@ -288,7 +291,7 @@ var _ = Describe("Client", func() {
packetConn, packetConn,
addr, addr,
"localhost:1337", "localhost:1337",
nil, tlsConf,
&Config{}, &Config{},
) )
Expect(err).To(MatchError(context.Canceled)) Expect(err).To(MatchError(context.Canceled))
@@ -333,7 +336,7 @@ var _ = Describe("Client", func() {
packetConn, packetConn,
addr, addr,
"localhost:1337", "localhost:1337",
nil, tlsConf,
&Config{}, &Config{},
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
@@ -376,7 +379,7 @@ var _ = Describe("Client", func() {
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
_, err := DialAddr("localhost:1337", nil, nil) _, err := DialAddr("localhost:1337", tlsConf, nil)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
close(done) close(done)
}() }()
@@ -417,10 +420,15 @@ var _ = Describe("Client", func() {
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any()).Return(manager, nil) mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any()).Return(manager, nil)
version := protocol.VersionNumber(0x1234) version := protocol.VersionNumber(0x1234)
_, err := Dial(packetConn, nil, "localhost:1234", &tls.Config{}, &Config{Versions: []protocol.VersionNumber{version}}) _, err := Dial(packetConn, nil, "localhost:1234", tlsConf, &Config{Versions: []protocol.VersionNumber{version}})
Expect(err).To(MatchError("0x1234 is not a valid QUIC version")) Expect(err).To(MatchError("0x1234 is not a valid QUIC version"))
}) })
It("erros when the tls.Config doesn't contain NextProtos", func() {
_, err := Dial(packetConn, nil, "localhost:1234", &tls.Config{}, nil)
Expect(err).To(MatchError("quic: NextProtos not set in tls.Config"))
})
It("disables bidirectional streams", func() { It("disables bidirectional streams", func() {
config := &Config{ config := &Config{
MaxIncomingStreams: -1, MaxIncomingStreams: -1,
@@ -487,7 +495,7 @@ var _ = Describe("Client", func() {
sess.EXPECT().run() sess.EXPECT().run()
return sess, nil return sess, nil
} }
_, err := Dial(packetConn, addr, "localhost:1337", nil, config) _, err := Dial(packetConn, addr, "localhost:1337", tlsConf, config)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Eventually(c).Should(BeClosed()) Eventually(c).Should(BeClosed())
Expect(cconn.(*conn).pconn).To(Equal(packetConn)) Expect(cconn.(*conn).pconn).To(Equal(packetConn))
@@ -535,7 +543,7 @@ var _ = Describe("Client", func() {
packetConn, packetConn,
addr, addr,
"localhost:1337", "localhost:1337",
nil, tlsConf,
&Config{}, &Config{},
) )
Expect(err).To(MatchError(testErr)) Expect(err).To(MatchError(testErr))