forked from quic-go/quic-go
reject a tls.Config without NextProtos for dialing
This commit is contained in:
@@ -3,6 +3,7 @@ package quic
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -119,6 +120,9 @@ func dialContext(
|
|||||||
config *Config,
|
config *Config,
|
||||||
createdPacketConn bool,
|
createdPacketConn bool,
|
||||||
) (Session, error) {
|
) (Session, error) {
|
||||||
|
if tlsConf == nil || len(tlsConf.NextProtos) == 0 {
|
||||||
|
return nil, errors.New("quic: NextProtos not set in tls.Config")
|
||||||
|
}
|
||||||
config = populateClientConfig(config, createdPacketConn)
|
config = populateClientConfig(config, createdPacketConn)
|
||||||
packetHandlers, err := getMultiplexer().AddConn(pconn, config.ConnectionIDLength, config.StatelessResetKey)
|
packetHandlers, err := getMultiplexer().AddConn(pconn, config.ConnectionIDLength, config.StatelessResetKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ var _ = Describe("Client", func() {
|
|||||||
connID protocol.ConnectionID
|
connID protocol.ConnectionID
|
||||||
mockMultiplexer *MockMultiplexer
|
mockMultiplexer *MockMultiplexer
|
||||||
origMultiplexer multiplexer
|
origMultiplexer multiplexer
|
||||||
|
tlsConf *tls.Config
|
||||||
|
|
||||||
originalClientSessConstructor func(
|
originalClientSessConstructor func(
|
||||||
conn connection,
|
conn connection,
|
||||||
@@ -65,6 +66,7 @@ var _ = Describe("Client", func() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
|
tlsConf = &tls.Config{NextProtos: []string{"proto1"}}
|
||||||
connID = protocol.ConnectionID{0, 0, 0, 0, 0, 0, 0x13, 0x37}
|
connID = protocol.ConnectionID{0, 0, 0, 0, 0, 0, 0x13, 0x37}
|
||||||
originalClientSessConstructor = newClientSession
|
originalClientSessConstructor = newClientSession
|
||||||
Eventually(areSessionsRunning).Should(BeFalse())
|
Eventually(areSessionsRunning).Should(BeFalse())
|
||||||
@@ -148,7 +150,7 @@ var _ = Describe("Client", func() {
|
|||||||
sess.EXPECT().run()
|
sess.EXPECT().run()
|
||||||
return sess, nil
|
return sess, nil
|
||||||
}
|
}
|
||||||
_, err := DialAddr("localhost:17890", nil, &Config{HandshakeTimeout: time.Millisecond})
|
_, err := DialAddr("localhost:17890", tlsConf, &Config{HandshakeTimeout: time.Millisecond})
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Eventually(remoteAddrChan).Should(Receive(Equal("127.0.0.1:17890")))
|
Eventually(remoteAddrChan).Should(Receive(Equal("127.0.0.1:17890")))
|
||||||
})
|
})
|
||||||
@@ -178,7 +180,8 @@ var _ = Describe("Client", func() {
|
|||||||
sess.EXPECT().run()
|
sess.EXPECT().run()
|
||||||
return sess, nil
|
return sess, nil
|
||||||
}
|
}
|
||||||
_, err := DialAddr("localhost:17890", &tls.Config{ServerName: "foobar"}, nil)
|
tlsConf.ServerName = "foobar"
|
||||||
|
_, err := DialAddr("localhost:17890", tlsConf, nil)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Eventually(hostnameChan).Should(Receive(Equal("foobar")))
|
Eventually(hostnameChan).Should(Receive(Equal("foobar")))
|
||||||
})
|
})
|
||||||
@@ -211,7 +214,7 @@ var _ = Describe("Client", func() {
|
|||||||
packetConn,
|
packetConn,
|
||||||
addr,
|
addr,
|
||||||
"localhost:1337",
|
"localhost:1337",
|
||||||
nil,
|
tlsConf,
|
||||||
&Config{},
|
&Config{},
|
||||||
)
|
)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
@@ -247,7 +250,7 @@ var _ = Describe("Client", func() {
|
|||||||
packetConn,
|
packetConn,
|
||||||
addr,
|
addr,
|
||||||
"localhost:1337",
|
"localhost:1337",
|
||||||
nil,
|
tlsConf,
|
||||||
&Config{},
|
&Config{},
|
||||||
)
|
)
|
||||||
Expect(err).To(MatchError(testErr))
|
Expect(err).To(MatchError(testErr))
|
||||||
@@ -288,7 +291,7 @@ var _ = Describe("Client", func() {
|
|||||||
packetConn,
|
packetConn,
|
||||||
addr,
|
addr,
|
||||||
"localhost:1337",
|
"localhost:1337",
|
||||||
nil,
|
tlsConf,
|
||||||
&Config{},
|
&Config{},
|
||||||
)
|
)
|
||||||
Expect(err).To(MatchError(context.Canceled))
|
Expect(err).To(MatchError(context.Canceled))
|
||||||
@@ -333,7 +336,7 @@ var _ = Describe("Client", func() {
|
|||||||
packetConn,
|
packetConn,
|
||||||
addr,
|
addr,
|
||||||
"localhost:1337",
|
"localhost:1337",
|
||||||
nil,
|
tlsConf,
|
||||||
&Config{},
|
&Config{},
|
||||||
)
|
)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
@@ -376,7 +379,7 @@ var _ = Describe("Client", func() {
|
|||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
defer GinkgoRecover()
|
defer GinkgoRecover()
|
||||||
_, err := DialAddr("localhost:1337", nil, nil)
|
_, err := DialAddr("localhost:1337", tlsConf, nil)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
close(done)
|
close(done)
|
||||||
}()
|
}()
|
||||||
@@ -417,10 +420,15 @@ var _ = Describe("Client", func() {
|
|||||||
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any()).Return(manager, nil)
|
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any()).Return(manager, nil)
|
||||||
|
|
||||||
version := protocol.VersionNumber(0x1234)
|
version := protocol.VersionNumber(0x1234)
|
||||||
_, err := Dial(packetConn, nil, "localhost:1234", &tls.Config{}, &Config{Versions: []protocol.VersionNumber{version}})
|
_, err := Dial(packetConn, nil, "localhost:1234", tlsConf, &Config{Versions: []protocol.VersionNumber{version}})
|
||||||
Expect(err).To(MatchError("0x1234 is not a valid QUIC version"))
|
Expect(err).To(MatchError("0x1234 is not a valid QUIC version"))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
It("erros when the tls.Config doesn't contain NextProtos", func() {
|
||||||
|
_, err := Dial(packetConn, nil, "localhost:1234", &tls.Config{}, nil)
|
||||||
|
Expect(err).To(MatchError("quic: NextProtos not set in tls.Config"))
|
||||||
|
})
|
||||||
|
|
||||||
It("disables bidirectional streams", func() {
|
It("disables bidirectional streams", func() {
|
||||||
config := &Config{
|
config := &Config{
|
||||||
MaxIncomingStreams: -1,
|
MaxIncomingStreams: -1,
|
||||||
@@ -487,7 +495,7 @@ var _ = Describe("Client", func() {
|
|||||||
sess.EXPECT().run()
|
sess.EXPECT().run()
|
||||||
return sess, nil
|
return sess, nil
|
||||||
}
|
}
|
||||||
_, err := Dial(packetConn, addr, "localhost:1337", nil, config)
|
_, err := Dial(packetConn, addr, "localhost:1337", tlsConf, config)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Eventually(c).Should(BeClosed())
|
Eventually(c).Should(BeClosed())
|
||||||
Expect(cconn.(*conn).pconn).To(Equal(packetConn))
|
Expect(cconn.(*conn).pconn).To(Equal(packetConn))
|
||||||
@@ -535,7 +543,7 @@ var _ = Describe("Client", func() {
|
|||||||
packetConn,
|
packetConn,
|
||||||
addr,
|
addr,
|
||||||
"localhost:1337",
|
"localhost:1337",
|
||||||
nil,
|
tlsConf,
|
||||||
&Config{},
|
&Config{},
|
||||||
)
|
)
|
||||||
Expect(err).To(MatchError(testErr))
|
Expect(err).To(MatchError(testErr))
|
||||||
|
|||||||
Reference in New Issue
Block a user