diff --git a/client.go b/client.go index 1b6756407..788aa123a 100644 --- a/client.go +++ b/client.go @@ -74,7 +74,7 @@ func NewClient(host string, cryptoChangeCallback CryptoChangeCallback, versionNe utils.Infof("Starting new connection to %s (%s), connectionID %x, version %d", host, udpAddr.String(), connectionID, client.version) - err = client.createNewSession() + err = client.createNewSession(nil) if err != nil { return nil, err } @@ -170,7 +170,7 @@ func (c *Client) handlePacket(packet []byte) error { c.versionNegotiated = true c.session.Close(errCloseSessionForNewVersion) - err = c.createNewSession() + err = c.createNewSession(hdr.SupportedVersions) if err != nil { return err } @@ -191,9 +191,9 @@ func (c *Client) handlePacket(packet []byte) error { return nil } -func (c *Client) createNewSession() error { +func (c *Client) createNewSession(negotiatedVersions []protocol.VersionNumber) error { var err error - c.session, err = newClientSession(c.conn, c.addr, c.hostname, c.version, c.connectionID, c.streamCallback, c.closeCallback, c.cryptoChangeCallback) + c.session, err = newClientSession(c.conn, c.addr, c.hostname, c.version, c.connectionID, c.streamCallback, c.closeCallback, c.cryptoChangeCallback, negotiatedVersions) if err != nil { return err } diff --git a/client_test.go b/client_test.go index 4941c30ed..9e790431d 100644 --- a/client_test.go +++ b/client_test.go @@ -5,8 +5,10 @@ import ( "encoding/binary" "errors" "net" + "reflect" "runtime" "time" + "unsafe" "github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/qerr" @@ -44,11 +46,12 @@ var _ = Describe("Client", func() { Expect(err).NotTo(HaveOccurred()) } - It("sets the correct hostname for new clients", func() { + It("creates a new client", func() { var err error client, err = NewClient("quic.clemente.io:1337", nil, nil) Expect(err).ToNot(HaveOccurred()) Expect(client.hostname).To(Equal("quic.clemente.io")) + Expect(*(*[]protocol.VersionNumber)(unsafe.Pointer(reflect.ValueOf(client.session.(*Session).cryptoSetup).Elem().FieldByName("negotiatedVersions").UnsafeAddr()))).To(BeNil()) }) It("errors on invalid public header", func() { @@ -95,7 +98,7 @@ var _ = Describe("Client", func() { startUDPConn() client.session = nil client.hostname = "hostname" - err := client.createNewSession() + err := client.createNewSession(nil) Expect(err).ToNot(HaveOccurred()) Expect(client.session).ToNot(BeNil()) Expect(client.session.(*Session).connectionID).To(Equal(client.connectionID)) @@ -217,6 +220,7 @@ var _ = Describe("Client", func() { Expect(err).ToNot(HaveOccurred()) // it didn't pass the version negoation packet to the session (since it has no payload) Expect(session.packetCount).To(BeZero()) + Expect(*(*[]protocol.VersionNumber)(unsafe.Pointer(reflect.ValueOf(client.session.(*Session).cryptoSetup).Elem().FieldByName("negotiatedVersions").UnsafeAddr()))).To(Equal([]protocol.VersionNumber{35})) err = client.Close(nil) Expect(err).ToNot(HaveOccurred()) diff --git a/handshake/crypto_setup_client.go b/handshake/crypto_setup_client.go index 4e550eeda..b4cd19992 100644 --- a/handshake/crypto_setup_client.go +++ b/handshake/crypto_setup_client.go @@ -16,9 +16,10 @@ import ( ) type cryptoSetupClient struct { - hostname string - connID protocol.ConnectionID - version protocol.VersionNumber + hostname string + connID protocol.ConnectionID + version protocol.VersionNumber + negotiatedVersions []protocol.VersionNumber cryptoStream utils.Stream @@ -62,6 +63,7 @@ func NewCryptoSetupClient( cryptoStream utils.Stream, connectionParameters ConnectionParametersManager, aeadChanged chan struct{}, + negotiatedVersions []protocol.VersionNumber, ) (CryptoSetup, error) { return &cryptoSetupClient{ hostname: hostname, @@ -72,6 +74,7 @@ func NewCryptoSetupClient( connectionParameters: connectionParameters, keyDerivation: crypto.DeriveKeysAESGCM, aeadChanged: aeadChanged, + negotiatedVersions: negotiatedVersions, }, nil } @@ -196,11 +199,13 @@ func (h *cryptoSetupClient) handleSHLOMessage(cryptoData map[Tag][]byte) error { return qerr.Error(qerr.CryptoMessageParameterNotFound, "PUBS") } - _, ok = cryptoData[TagVER] + verTag, ok := cryptoData[TagVER] if !ok { return qerr.Error(qerr.InvalidCryptoMessageParameter, "server hello missing version list") } - // TODO: verify versions + if !h.validateVersionList(verTag) { + return qerr.Error(qerr.VersionNegotiationMismatch, "Downgrade attack detected") + } nonce := append(h.nonc, h.sno...) @@ -236,6 +241,31 @@ func (h *cryptoSetupClient) handleSHLOMessage(cryptoData map[Tag][]byte) error { return nil } +func (h *cryptoSetupClient) validateVersionList(verTags []byte) bool { + if len(h.negotiatedVersions) == 0 { + return true + } + if len(verTags)%4 != 0 || len(verTags)/4 != len(h.negotiatedVersions) { + return false + } + + b := bytes.NewReader(verTags) + for _, negotiatedVersion := range h.negotiatedVersions { + verTag, err := utils.ReadUint32(b) + if err != nil { // should never occur, since the length was already checked + return false + } + ver := protocol.VersionTagToNumber(verTag) + if !protocol.IsSupportedVersion(ver) { + ver = protocol.VersionUnsupported + } + if ver != negotiatedVersion { + return false + } + } + return true +} + func (h *cryptoSetupClient) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) { if h.forwardSecureAEAD != nil { data, err := h.forwardSecureAEAD.Open(dst, src, packetNumber, associatedData) diff --git a/handshake/crypto_setup_client_test.go b/handshake/crypto_setup_client_test.go index 2a1bcacec..168b85f82 100644 --- a/handshake/crypto_setup_client_test.go +++ b/handshake/crypto_setup_client_test.go @@ -15,6 +15,7 @@ import ( "github.com/lucas-clemente/quic-go/crypto" "github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/qerr" + "github.com/lucas-clemente/quic-go/utils" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) @@ -121,7 +122,7 @@ var _ = Describe("Crypto setup", func() { stream = &mockStream{} certManager = &mockCertManager{} version := protocol.Version36 - csInt, err := NewCryptoSetupClient("hostname", 0, version, stream, NewConnectionParamatersManager(protocol.PerspectiveClient, version), make(chan struct{}, 1)) + csInt, err := NewCryptoSetupClient("hostname", 0, version, stream, NewConnectionParamatersManager(protocol.PerspectiveClient, version), make(chan struct{}, 1), nil) Expect(err).ToNot(HaveOccurred()) cs = csInt.(*cryptoSetupClient) cs.certManager = certManager @@ -189,6 +190,48 @@ var _ = Describe("Crypto setup", func() { Expect(cs.sno).To(Equal(nonc)) }) + Context("validating the Version list", func() { + It("doesn't care about the version list if there was no version negotiation", func() { + Expect(cs.validateVersionList([]byte{0})).To(BeTrue()) + }) + + It("detects a downgrade attack if the number of versions is unequal", func() { + cs.negotiatedVersions = []protocol.VersionNumber{protocol.VersionWhatever} + Expect(cs.validateVersionList(bytes.Repeat([]byte{'f'}, 8))).To(BeFalse()) + }) + + It("detects a downgrade attack", func() { + cs.negotiatedVersions = []protocol.VersionNumber{protocol.Version36} + b := &bytes.Buffer{} + utils.WriteUint32(b, protocol.VersionNumberToTag(protocol.Version35)) + Expect(cs.validateVersionList(b.Bytes())).To(BeFalse()) + }) + + It("errors if the version tags are invalid", func() { + cs.negotiatedVersions = []protocol.VersionNumber{protocol.VersionWhatever} + Expect(cs.validateVersionList([]byte{0, 1, 2})).To(BeFalse()) + }) + + It("doesn't care about unsupported versions", func() { + cs.negotiatedVersions = []protocol.VersionNumber{protocol.VersionUnsupported, protocol.Version36, protocol.VersionUnsupported} + b := &bytes.Buffer{} + b.Write([]byte{0, 0, 0, 0}) + utils.WriteUint32(b, protocol.VersionNumberToTag(protocol.Version36)) + b.Write([]byte{0x13, 0x37, 0x13, 0x37}) + Expect(cs.validateVersionList(b.Bytes())).To(BeTrue()) + }) + + It("returns the right error when detecting a downgrade attack", func() { + cs.negotiatedVersions = []protocol.VersionNumber{protocol.VersionWhatever} + cs.receivedSecurePacket = true + err := cs.handleSHLOMessage(map[Tag][]byte{ + TagPUBS: []byte{0}, + TagVER: []byte{0, 1}, + }) + Expect(err).To(MatchError(qerr.Error(qerr.VersionNegotiationMismatch, "Downgrade attack detected"))) + }) + }) + Context("Certificates", func() { BeforeEach(func() { cs.serverConfig = &serverConfigClient{} @@ -370,6 +413,16 @@ var _ = Describe("Crypto setup", func() { Expect(cs.HandshakeComplete()).To(BeFalse()) }) + It("accepts a SHLO after a version negotiation", func() { + cs.negotiatedVersions = []protocol.VersionNumber{protocol.Version36} + cs.receivedSecurePacket = true + b := &bytes.Buffer{} + utils.WriteUint32(b, protocol.VersionNumberToTag(protocol.Version36)) + tagMap[TagVER] = b.Bytes() + err := cs.handleSHLOMessage(tagMap) + Expect(err).ToNot(HaveOccurred()) + }) + It("reads the server nonce, if set", func() { tagMap[TagSNO] = []byte("server nonce") err := cs.handleSHLOMessage(tagMap) diff --git a/session.go b/session.go index 33bfa5b43..346b2d3d0 100644 --- a/session.go +++ b/session.go @@ -127,7 +127,7 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol return session, err } -func newClientSession(conn *net.UDPConn, addr *net.UDPAddr, hostname string, v protocol.VersionNumber, connectionID protocol.ConnectionID, streamCallback StreamCallback, closeCallback closeCallback, cryptoChangeCallback CryptoChangeCallback) (*Session, error) { +func newClientSession(conn *net.UDPConn, addr *net.UDPAddr, hostname string, v protocol.VersionNumber, connectionID protocol.ConnectionID, streamCallback StreamCallback, closeCallback closeCallback, cryptoChangeCallback CryptoChangeCallback, negotiatedVersions []protocol.VersionNumber) (*Session, error) { session := &Session{ conn: &udpConn{conn: conn, currentAddr: addr}, connectionID: connectionID, @@ -145,7 +145,7 @@ func newClientSession(conn *net.UDPConn, addr *net.UDPAddr, hostname string, v p cryptoStream, _ := session.OpenStream(1) var err error - session.cryptoSetup, err = handshake.NewCryptoSetupClient(hostname, connectionID, v, cryptoStream, session.connectionParameters, session.aeadChanged) + session.cryptoSetup, err = handshake.NewCryptoSetupClient(hostname, connectionID, v, cryptoStream, session.connectionParameters, session.aeadChanged, negotiatedVersions) if err != nil { return nil, err } diff --git a/session_test.go b/session_test.go index 5c0bf3945..2abe94adf 100644 --- a/session_test.go +++ b/session_test.go @@ -158,6 +158,7 @@ var _ = Describe("Session", func() { func(*Session, utils.Stream) { streamCallbackCalled = true }, func(protocol.ConnectionID) { closeCallbackCalled = true }, func(isForwardSecure bool) {}, + nil, ) Expect(err).ToNot(HaveOccurred()) Expect(clientSession.streamsMap.openStreams).To(HaveLen(1)) // Crypto stream