forked from quic-go/quic-go
reject a tls.Config without NextProtos for dialing
This commit is contained in:
@@ -3,6 +3,7 @@ package quic
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
@@ -119,6 +120,9 @@ func dialContext(
|
||||
config *Config,
|
||||
createdPacketConn bool,
|
||||
) (Session, error) {
|
||||
if tlsConf == nil || len(tlsConf.NextProtos) == 0 {
|
||||
return nil, errors.New("quic: NextProtos not set in tls.Config")
|
||||
}
|
||||
config = populateClientConfig(config, createdPacketConn)
|
||||
packetHandlers, err := getMultiplexer().AddConn(pconn, config.ConnectionIDLength, config.StatelessResetKey)
|
||||
if err != nil {
|
||||
|
||||
@@ -27,6 +27,7 @@ var _ = Describe("Client", func() {
|
||||
connID protocol.ConnectionID
|
||||
mockMultiplexer *MockMultiplexer
|
||||
origMultiplexer multiplexer
|
||||
tlsConf *tls.Config
|
||||
|
||||
originalClientSessConstructor func(
|
||||
conn connection,
|
||||
@@ -65,6 +66,7 @@ var _ = Describe("Client", func() {
|
||||
}
|
||||
|
||||
BeforeEach(func() {
|
||||
tlsConf = &tls.Config{NextProtos: []string{"proto1"}}
|
||||
connID = protocol.ConnectionID{0, 0, 0, 0, 0, 0, 0x13, 0x37}
|
||||
originalClientSessConstructor = newClientSession
|
||||
Eventually(areSessionsRunning).Should(BeFalse())
|
||||
@@ -148,7 +150,7 @@ var _ = Describe("Client", func() {
|
||||
sess.EXPECT().run()
|
||||
return sess, nil
|
||||
}
|
||||
_, err := DialAddr("localhost:17890", nil, &Config{HandshakeTimeout: time.Millisecond})
|
||||
_, err := DialAddr("localhost:17890", tlsConf, &Config{HandshakeTimeout: time.Millisecond})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Eventually(remoteAddrChan).Should(Receive(Equal("127.0.0.1:17890")))
|
||||
})
|
||||
@@ -178,7 +180,8 @@ var _ = Describe("Client", func() {
|
||||
sess.EXPECT().run()
|
||||
return sess, nil
|
||||
}
|
||||
_, err := DialAddr("localhost:17890", &tls.Config{ServerName: "foobar"}, nil)
|
||||
tlsConf.ServerName = "foobar"
|
||||
_, err := DialAddr("localhost:17890", tlsConf, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Eventually(hostnameChan).Should(Receive(Equal("foobar")))
|
||||
})
|
||||
@@ -211,7 +214,7 @@ var _ = Describe("Client", func() {
|
||||
packetConn,
|
||||
addr,
|
||||
"localhost:1337",
|
||||
nil,
|
||||
tlsConf,
|
||||
&Config{},
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
@@ -247,7 +250,7 @@ var _ = Describe("Client", func() {
|
||||
packetConn,
|
||||
addr,
|
||||
"localhost:1337",
|
||||
nil,
|
||||
tlsConf,
|
||||
&Config{},
|
||||
)
|
||||
Expect(err).To(MatchError(testErr))
|
||||
@@ -288,7 +291,7 @@ var _ = Describe("Client", func() {
|
||||
packetConn,
|
||||
addr,
|
||||
"localhost:1337",
|
||||
nil,
|
||||
tlsConf,
|
||||
&Config{},
|
||||
)
|
||||
Expect(err).To(MatchError(context.Canceled))
|
||||
@@ -333,7 +336,7 @@ var _ = Describe("Client", func() {
|
||||
packetConn,
|
||||
addr,
|
||||
"localhost:1337",
|
||||
nil,
|
||||
tlsConf,
|
||||
&Config{},
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
@@ -376,7 +379,7 @@ var _ = Describe("Client", func() {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_, err := DialAddr("localhost:1337", nil, nil)
|
||||
_, err := DialAddr("localhost:1337", tlsConf, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
close(done)
|
||||
}()
|
||||
@@ -417,10 +420,15 @@ var _ = Describe("Client", func() {
|
||||
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any()).Return(manager, nil)
|
||||
|
||||
version := protocol.VersionNumber(0x1234)
|
||||
_, err := Dial(packetConn, nil, "localhost:1234", &tls.Config{}, &Config{Versions: []protocol.VersionNumber{version}})
|
||||
_, err := Dial(packetConn, nil, "localhost:1234", tlsConf, &Config{Versions: []protocol.VersionNumber{version}})
|
||||
Expect(err).To(MatchError("0x1234 is not a valid QUIC version"))
|
||||
})
|
||||
|
||||
It("erros when the tls.Config doesn't contain NextProtos", func() {
|
||||
_, err := Dial(packetConn, nil, "localhost:1234", &tls.Config{}, nil)
|
||||
Expect(err).To(MatchError("quic: NextProtos not set in tls.Config"))
|
||||
})
|
||||
|
||||
It("disables bidirectional streams", func() {
|
||||
config := &Config{
|
||||
MaxIncomingStreams: -1,
|
||||
@@ -487,7 +495,7 @@ var _ = Describe("Client", func() {
|
||||
sess.EXPECT().run()
|
||||
return sess, nil
|
||||
}
|
||||
_, err := Dial(packetConn, addr, "localhost:1337", nil, config)
|
||||
_, err := Dial(packetConn, addr, "localhost:1337", tlsConf, config)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Eventually(c).Should(BeClosed())
|
||||
Expect(cconn.(*conn).pconn).To(Equal(packetConn))
|
||||
@@ -535,7 +543,7 @@ var _ = Describe("Client", func() {
|
||||
packetConn,
|
||||
addr,
|
||||
"localhost:1337",
|
||||
nil,
|
||||
tlsConf,
|
||||
&Config{},
|
||||
)
|
||||
Expect(err).To(MatchError(testErr))
|
||||
|
||||
Reference in New Issue
Block a user