diff --git a/client.go b/client.go index 6234a64c5..fc6ed290e 100644 --- a/client.go +++ b/client.go @@ -89,6 +89,7 @@ func DialAddrContext( // The same PacketConn can be used for multiple calls to Dial and Listen, // QUIC connection IDs are used for demultiplexing the different connections. // The host parameter is used for SNI. +// The tls.Config must define an application protocol (using NextProtos). func Dial( pconn net.PacketConn, remoteAddr net.Addr, @@ -121,8 +122,8 @@ func dialContext( config *Config, createdPacketConn bool, ) (Session, error) { - if tlsConf == nil || len(tlsConf.NextProtos) == 0 { - return nil, errors.New("quic: NextProtos not set in tls.Config") + if tlsConf == nil { + return nil, errors.New("quic: tls.Config not set") } config = populateClientConfig(config, createdPacketConn) packetHandlers, err := getMultiplexer().AddConn(pconn, config.ConnectionIDLength, config.StatelessResetKey) diff --git a/client_test.go b/client_test.go index 60965e69c..8c4cfcbde 100644 --- a/client_test.go +++ b/client_test.go @@ -463,11 +463,6 @@ var _ = Describe("Client", func() { Expect(err).To(MatchError("0x1234 is not a valid QUIC version")) }) - It("erros when the tls.Config doesn't contain NextProtos", func() { - _, err := Dial(packetConn, nil, "localhost:1234", &tls.Config{}, nil) - Expect(err).To(MatchError("quic: NextProtos not set in tls.Config")) - }) - It("disables bidirectional streams", func() { config := &Config{ MaxIncomingStreams: -1, diff --git a/server.go b/server.go index 2c7c451f0..5a267b36c 100644 --- a/server.go +++ b/server.go @@ -128,10 +128,11 @@ func ListenAddr(addr string, tlsConf *tls.Config, config *Config) (Listener, err } // Listen listens for QUIC connections on a given net.PacketConn. -// A single PacketConn only be used for a single call to Listen. +// A single net.PacketConn only be used for a single call to Listen. // The PacketConn can be used for simultaneous calls to Dial. // QUIC connection IDs are used for demultiplexing the different connections. // The tls.Config must not be nil and must contain a certificate configuration. +// Furthermore, it must define an application control (using NextProtos). // The quic.Config may be nil, in that case the default values will be used. func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener, error) { return listen(conn, tlsConf, config) @@ -139,11 +140,8 @@ func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener, func listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (*server, error) { // TODO(#1655): only require that tls.Config.Certificates or tls.Config.GetCertificate is set - if tlsConf == nil || len(tlsConf.Certificates) == 0 { - return nil, errors.New("quic: Certificates not set in tls.Config") - } - if len(tlsConf.NextProtos) == 0 { - return nil, errors.New("quic: NextProtos not set in tls.Config") + if tlsConf == nil { + return nil, errors.New("quic: tls.Config not set") } config = populateServerConfig(config) for _, v := range config.Versions { diff --git a/server_test.go b/server_test.go index 0a966688d..caa5935b9 100644 --- a/server_test.go +++ b/server_test.go @@ -49,20 +49,7 @@ var _ = Describe("Server", func() { It("errors when no tls.Config is given", func() { _, err := ListenAddr("localhost:0", nil, nil) Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(ContainSubstring("quic: Certificates not set in tls.Config")) - }) - - It("errors when no certificates are set in the tls.Config is given", func() { - _, err := ListenAddr("localhost:0", &tls.Config{}, nil) - Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(ContainSubstring("quic: Certificates not set in tls.Config")) - }) - - It("errors when NextProtos is not set in the tls.Config", func() { - tlsConf.NextProtos = nil - _, err := ListenAddr("localhost:0", tlsConf, nil) - Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(ContainSubstring("quic: NextProtos not set in tls.Config")) + Expect(err.Error()).To(ContainSubstring("quic: tls.Config not set")) }) It("errors when the Config contains an invalid version", func() {