correctly handle errors when creating a new gQUIC key exchange

This commit is contained in:
Marten Seemann
2018-03-28 05:33:26 +07:00
parent 1f9ab3b65f
commit 48731221c0
4 changed files with 22 additions and 16 deletions

View File

@@ -19,7 +19,7 @@ import (
type QuicCryptoKeyDerivationFunction func(forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo []byte, scfg []byte, cert []byte, divNonce []byte, pers protocol.Perspective) (crypto.AEAD, error) type QuicCryptoKeyDerivationFunction func(forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo []byte, scfg []byte, cert []byte, divNonce []byte, pers protocol.Perspective) (crypto.AEAD, error)
// KeyExchangeFunction is used to make a new KEX // KeyExchangeFunction is used to make a new KEX
type KeyExchangeFunction func() crypto.KeyExchange type KeyExchangeFunction func() (crypto.KeyExchange, error)
// The CryptoSetupServer handles all things crypto for the Session // The CryptoSetupServer handles all things crypto for the Session
type cryptoSetupServer struct { type cryptoSetupServer struct {
@@ -405,7 +405,10 @@ func (h *cryptoSetupServer) handleCHLO(sni string, data []byte, cryptoData map[T
var fsNonce bytes.Buffer var fsNonce bytes.Buffer
fsNonce.Write(clientNonce) fsNonce.Write(clientNonce)
fsNonce.Write(serverNonce) fsNonce.Write(serverNonce)
ephermalKex := h.keyExchange() ephermalKex, err := h.keyExchange()
if err != nil {
return nil, err
}
ephermalSharedSecret, err := ephermalKex.CalculateSharedKey(cryptoData[TagPUBS]) ephermalSharedSecret, err := ephermalKex.CalculateSharedKey(cryptoData[TagPUBS])
if err != nil { if err != nil {
return nil, err return nil, err

View File

@@ -179,7 +179,7 @@ var _ = Describe("Server Crypto Setup", func() {
sourceAddrValid = true sourceAddrValid = true
cs.acceptSTKCallback = func(_ net.Addr, _ *Cookie) bool { return sourceAddrValid } cs.acceptSTKCallback = func(_ net.Addr, _ *Cookie) bool { return sourceAddrValid }
cs.keyDerivation = mockQuicCryptoKeyDerivation cs.keyDerivation = mockQuicCryptoKeyDerivation
cs.keyExchange = func() crypto.KeyExchange { return &mockKEX{ephermal: true} } cs.keyExchange = func() (crypto.KeyExchange, error) { return &mockKEX{ephermal: true}, nil }
cs.nullAEAD = mockcrypto.NewMockAEAD(mockCtrl) cs.nullAEAD = mockcrypto.NewMockAEAD(mockCtrl)
cs.cryptoStream = stream cs.cryptoStream = stream
}) })

View File

@@ -6,7 +6,6 @@ import (
"github.com/lucas-clemente/quic-go/internal/crypto" "github.com/lucas-clemente/quic-go/internal/crypto"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
) )
var ( var (
@@ -24,13 +23,13 @@ var (
// used for all connections for 60 seconds is negligible. Thus we can amortise // used for all connections for 60 seconds is negligible. Thus we can amortise
// the Diffie-Hellman key generation at the server over all the connections in a // the Diffie-Hellman key generation at the server over all the connections in a
// small time span. // small time span.
func getEphermalKEX() (res crypto.KeyExchange) { func getEphermalKEX() (crypto.KeyExchange, error) {
kexMutex.RLock() kexMutex.RLock()
res = kexCurrent res := kexCurrent
t := kexCurrentTime t := kexCurrentTime
kexMutex.RUnlock() kexMutex.RUnlock()
if res != nil && time.Since(t) < kexLifetime { if res != nil && time.Since(t) < kexLifetime {
return res return res, nil
} }
kexMutex.Lock() kexMutex.Lock()
@@ -39,12 +38,11 @@ func getEphermalKEX() (res crypto.KeyExchange) {
if kexCurrent == nil || time.Since(kexCurrentTime) > kexLifetime { if kexCurrent == nil || time.Since(kexCurrentTime) > kexLifetime {
kex, err := crypto.NewCurve25519KEX() kex, err := crypto.NewCurve25519KEX()
if err != nil { if err != nil {
utils.Errorf("could not set KEX: %s", err.Error()) return nil, err
return kexCurrent
} }
kexCurrent = kex kexCurrent = kex
kexCurrentTime = time.Now() kexCurrentTime = time.Now()
return kexCurrent return kexCurrent, nil
} }
return kexCurrent return kexCurrent, nil
} }

View File

@@ -3,7 +3,6 @@ package handshake
import ( import (
"time" "time"
"github.com/lucas-clemente/quic-go/internal/crypto"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
. "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
@@ -11,9 +10,11 @@ import (
var _ = Describe("Ephermal KEX", func() { var _ = Describe("Ephermal KEX", func() {
It("has a consistent KEX", func() { It("has a consistent KEX", func() {
kex1 := getEphermalKEX() kex1, err := getEphermalKEX()
Expect(err).ToNot(HaveOccurred())
Expect(kex1).ToNot(BeNil()) Expect(kex1).ToNot(BeNil())
kex2 := getEphermalKEX() kex2, err := getEphermalKEX()
Expect(err).ToNot(HaveOccurred())
Expect(kex2).ToNot(BeNil()) Expect(kex2).ToNot(BeNil())
Expect(kex1).To(Equal(kex2)) Expect(kex1).To(Equal(kex2))
}) })
@@ -23,8 +24,12 @@ var _ = Describe("Ephermal KEX", func() {
defer func() { defer func() {
kexLifetime = protocol.EphermalKeyLifetime kexLifetime = protocol.EphermalKeyLifetime
}() }()
kex := getEphermalKEX() kex, err := getEphermalKEX()
Expect(err).ToNot(HaveOccurred())
Expect(kex).ToNot(BeNil()) Expect(kex).ToNot(BeNil())
Eventually(func() crypto.KeyExchange { return getEphermalKEX() }).ShouldNot(Equal(kex)) time.Sleep(kexLifetime)
kex2, err := getEphermalKEX()
Expect(err).ToNot(HaveOccurred())
Expect(kex2).ToNot(Equal(kex))
}) })
}) })