From 008615284ec4b66c0089dba1aef677b630a9ede5 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Thu, 29 Nov 2018 15:52:41 +0700 Subject: [PATCH] error when Listen is called without a tls.Config or certificates --- server.go | 4 ++++ server_test.go | 35 ++++++++++++++++++++++++++--------- 2 files changed, 30 insertions(+), 9 deletions(-) diff --git a/server.go b/server.go index 33a3123c9..6abde79be 100644 --- a/server.go +++ b/server.go @@ -128,6 +128,10 @@ func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener, } func listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (*server, error) { + // TODO(#1655): only require that tls.Config.Certificates or tls.Config.GetCertificate is set + if tlsConf == nil || len(tlsConf.Certificates) == 0 { + return nil, errors.New("quic: Certificates not set in tls.Config") + } config = populateServerConfig(config) for _, v := range config.Versions { if !protocol.IsValidVersion(v) { diff --git a/server_test.go b/server_test.go index e9336fb9f..479423ab2 100644 --- a/server_test.go +++ b/server_test.go @@ -10,6 +10,7 @@ import ( "github.com/lucas-clemente/quic-go/internal/handshake" "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/testdata" "github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/wire" @@ -18,21 +19,37 @@ import ( ) var _ = Describe("Server", func() { - var conn *mockPacketConn + var ( + conn *mockPacketConn + tlsConf *tls.Config + ) BeforeEach(func() { conn = newMockPacketConn() conn.addr = &net.UDPAddr{} + tlsConf = testdata.GetTLSConfig() + }) + + It("errors when no tls.Config is given", func() { + _, err := ListenAddr("localhost:0", nil, nil) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("quic: Certificates not set in tls.Config")) + }) + + It("errors when no certificates are set in the tls.Config is given", func() { + _, err := ListenAddr("localhost:0", &tls.Config{}, nil) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("quic: Certificates not set in tls.Config")) }) It("errors when the Config contains an invalid version", func() { version := protocol.VersionNumber(0x1234) - _, err := Listen(nil, &tls.Config{}, &Config{Versions: []protocol.VersionNumber{version}}) + _, err := Listen(nil, tlsConf, &Config{Versions: []protocol.VersionNumber{version}}) Expect(err).To(MatchError("0x1234 is not a valid QUIC version")) }) It("fills in default values if options are not set in the Config", func() { - ln, err := Listen(conn, &tls.Config{}, &Config{}) + ln, err := Listen(conn, tlsConf, &Config{}) Expect(err).ToNot(HaveOccurred()) server := ln.(*server) Expect(server.config.Versions).To(Equal(protocol.SupportedVersions)) @@ -54,7 +71,7 @@ var _ = Describe("Server", func() { IdleTimeout: 42 * time.Minute, KeepAlive: true, } - ln, err := Listen(conn, &tls.Config{}, &config) + ln, err := Listen(conn, tlsConf, &config) Expect(err).ToNot(HaveOccurred()) server := ln.(*server) Expect(server.sessionHandler).ToNot(BeNil()) @@ -69,7 +86,7 @@ var _ = Describe("Server", func() { It("listens on a given address", func() { addr := "127.0.0.1:13579" - ln, err := ListenAddr(addr, nil, &Config{}) + ln, err := ListenAddr(addr, tlsConf, &Config{}) Expect(err).ToNot(HaveOccurred()) serv := ln.(*server) Expect(serv.Addr().String()).To(Equal(addr)) @@ -79,13 +96,13 @@ var _ = Describe("Server", func() { It("errors if given an invalid address", func() { addr := "127.0.0.1" - _, err := ListenAddr(addr, nil, &Config{}) + _, err := ListenAddr(addr, tlsConf, &Config{}) Expect(err).To(BeAssignableToTypeOf(&net.AddrError{})) }) It("errors if given an invalid address", func() { addr := "1.1.1.1:1111" - _, err := ListenAddr(addr, nil, &Config{}) + _, err := ListenAddr(addr, tlsConf, &Config{}) Expect(err).To(BeAssignableToTypeOf(&net.OpError{})) }) @@ -93,7 +110,7 @@ var _ = Describe("Server", func() { var serv *server BeforeEach(func() { - ln, err := Listen(conn, nil, nil) + ln, err := Listen(conn, tlsConf, nil) Expect(err).ToNot(HaveOccurred()) serv = ln.(*server) }) @@ -306,7 +323,7 @@ var _ = Describe("Server", func() { var serv *server BeforeEach(func() { - ln, err := Listen(conn, nil, nil) + ln, err := Listen(conn, tlsConf, nil) Expect(err).ToNot(HaveOccurred()) serv = ln.(*server) })