diff --git a/client.go b/client.go index 7818c2af1..6aa1a2ecb 100644 --- a/client.go +++ b/client.go @@ -214,7 +214,7 @@ func (c *client) establishSecureConnection() error { if ev.err != nil { return ev.err } - if c.version != protocol.VersionTLS && ev.encLevel != protocol.EncryptionSecure { + if c.version.UsesTLS() && ev.encLevel != protocol.EncryptionSecure { return fmt.Errorf("Client BUG: Expected encryption level to be secure, was %s", ev.encLevel) } return nil diff --git a/internal/crypto/null_aead.go b/internal/crypto/null_aead.go index fe1943438..e1cc61355 100644 --- a/internal/crypto/null_aead.go +++ b/internal/crypto/null_aead.go @@ -4,7 +4,7 @@ import "github.com/lucas-clemente/quic-go/internal/protocol" // NewNullAEAD creates a NullAEAD func NewNullAEAD(p protocol.Perspective, v protocol.VersionNumber) AEAD { - if v == protocol.VersionTLS { + if v.UsesTLS() { return &nullAEADFNV64a{} } return &nullAEADFNV128a{ diff --git a/internal/protocol/version.go b/internal/protocol/version.go index f4e4af9fc..aa1dbdfbd 100644 --- a/internal/protocol/version.go +++ b/internal/protocol/version.go @@ -26,6 +26,11 @@ var SupportedVersions = []VersionNumber{ Version35, } +// UsesTLS says if this QUIC version uses TLS 1.3 for the handshake +func (vn VersionNumber) UsesTLS() bool { + return vn == VersionTLS +} + // VersionNumberToTag maps version numbers ('32') to tags ('Q032') func VersionNumberToTag(vn VersionNumber) uint32 { v := uint32(vn) diff --git a/internal/protocol/version_test.go b/internal/protocol/version_test.go index 6b97d463f..ece2b2690 100644 --- a/internal/protocol/version_test.go +++ b/internal/protocol/version_test.go @@ -6,6 +6,15 @@ import ( ) var _ = Describe("Version", func() { + It("says if a version supports TLS", func() { + Expect(Version35.UsesTLS()).To(BeFalse()) + Expect(Version36.UsesTLS()).To(BeFalse()) + Expect(Version37.UsesTLS()).To(BeFalse()) + Expect(Version38.UsesTLS()).To(BeFalse()) + Expect(Version39.UsesTLS()).To(BeFalse()) + Expect(VersionTLS.UsesTLS()).To(BeTrue()) + }) + It("converts tags to numbers", func() { Expect(VersionTagToNumber('Q' + '1'<<8 + '2'<<16 + '3'<<24)).To(Equal(VersionNumber(123))) }) diff --git a/session.go b/session.go index 5dece8772..d34036f7e 100644 --- a/session.go +++ b/session.go @@ -200,7 +200,7 @@ func (s *session) setup( verifySourceAddr := func(clientAddr net.Addr, stk *STK) bool { return s.config.AcceptSTK(clientAddr, stk) } - if s.version == protocol.VersionTLS { + if s.version.UsesTLS() { s.cryptoSetup, err = handshake.NewCryptoSetupTLS( "", s.perspective, @@ -224,7 +224,7 @@ func (s *session) setup( } } else { cryptoStream, _ := s.OpenStream() - if s.version == protocol.VersionTLS { + if s.version.UsesTLS() { s.cryptoSetup, err = handshake.NewCryptoSetupTLS( hostname, s.perspective,