diff --git a/internal/handshake/crypto_setup_server.go b/internal/handshake/crypto_setup_server.go index 54c2b9bd4..8d4d0935f 100644 --- a/internal/handshake/crypto_setup_server.go +++ b/internal/handshake/crypto_setup_server.go @@ -19,7 +19,7 @@ import ( type QuicCryptoKeyDerivationFunction func(forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo []byte, scfg []byte, cert []byte, divNonce []byte, pers protocol.Perspective) (crypto.AEAD, error) // KeyExchangeFunction is used to make a new KEX -type KeyExchangeFunction func() crypto.KeyExchange +type KeyExchangeFunction func() (crypto.KeyExchange, error) // The CryptoSetupServer handles all things crypto for the Session type cryptoSetupServer struct { @@ -405,7 +405,10 @@ func (h *cryptoSetupServer) handleCHLO(sni string, data []byte, cryptoData map[T var fsNonce bytes.Buffer fsNonce.Write(clientNonce) fsNonce.Write(serverNonce) - ephermalKex := h.keyExchange() + ephermalKex, err := h.keyExchange() + if err != nil { + return nil, err + } ephermalSharedSecret, err := ephermalKex.CalculateSharedKey(cryptoData[TagPUBS]) if err != nil { return nil, err diff --git a/internal/handshake/crypto_setup_server_test.go b/internal/handshake/crypto_setup_server_test.go index dc653e19c..86f108847 100644 --- a/internal/handshake/crypto_setup_server_test.go +++ b/internal/handshake/crypto_setup_server_test.go @@ -179,7 +179,7 @@ var _ = Describe("Server Crypto Setup", func() { sourceAddrValid = true cs.acceptSTKCallback = func(_ net.Addr, _ *Cookie) bool { return sourceAddrValid } cs.keyDerivation = mockQuicCryptoKeyDerivation - cs.keyExchange = func() crypto.KeyExchange { return &mockKEX{ephermal: true} } + cs.keyExchange = func() (crypto.KeyExchange, error) { return &mockKEX{ephermal: true}, nil } cs.nullAEAD = mockcrypto.NewMockAEAD(mockCtrl) cs.cryptoStream = stream }) diff --git a/internal/handshake/ephermal_cache.go b/internal/handshake/ephermal_cache.go index 3bccbef06..728715835 100644 --- a/internal/handshake/ephermal_cache.go +++ b/internal/handshake/ephermal_cache.go @@ -6,7 +6,6 @@ import ( "github.com/lucas-clemente/quic-go/internal/crypto" "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/utils" ) var ( @@ -24,13 +23,13 @@ var ( // used for all connections for 60 seconds is negligible. Thus we can amortise // the Diffie-Hellman key generation at the server over all the connections in a // small time span. -func getEphermalKEX() (res crypto.KeyExchange) { +func getEphermalKEX() (crypto.KeyExchange, error) { kexMutex.RLock() - res = kexCurrent + res := kexCurrent t := kexCurrentTime kexMutex.RUnlock() if res != nil && time.Since(t) < kexLifetime { - return res + return res, nil } kexMutex.Lock() @@ -39,12 +38,11 @@ func getEphermalKEX() (res crypto.KeyExchange) { if kexCurrent == nil || time.Since(kexCurrentTime) > kexLifetime { kex, err := crypto.NewCurve25519KEX() if err != nil { - utils.Errorf("could not set KEX: %s", err.Error()) - return kexCurrent + return nil, err } kexCurrent = kex kexCurrentTime = time.Now() - return kexCurrent + return kexCurrent, nil } - return kexCurrent + return kexCurrent, nil } diff --git a/internal/handshake/ephermal_cache_test.go b/internal/handshake/ephermal_cache_test.go index 88943b840..c0c0e5eff 100644 --- a/internal/handshake/ephermal_cache_test.go +++ b/internal/handshake/ephermal_cache_test.go @@ -3,7 +3,6 @@ package handshake import ( "time" - "github.com/lucas-clemente/quic-go/internal/crypto" "github.com/lucas-clemente/quic-go/internal/protocol" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" @@ -11,9 +10,11 @@ import ( var _ = Describe("Ephermal KEX", func() { It("has a consistent KEX", func() { - kex1 := getEphermalKEX() + kex1, err := getEphermalKEX() + Expect(err).ToNot(HaveOccurred()) Expect(kex1).ToNot(BeNil()) - kex2 := getEphermalKEX() + kex2, err := getEphermalKEX() + Expect(err).ToNot(HaveOccurred()) Expect(kex2).ToNot(BeNil()) Expect(kex1).To(Equal(kex2)) }) @@ -23,8 +24,12 @@ var _ = Describe("Ephermal KEX", func() { defer func() { kexLifetime = protocol.EphermalKeyLifetime }() - kex := getEphermalKEX() + kex, err := getEphermalKEX() + Expect(err).ToNot(HaveOccurred()) Expect(kex).ToNot(BeNil()) - Eventually(func() crypto.KeyExchange { return getEphermalKEX() }).ShouldNot(Equal(kex)) + time.Sleep(kexLifetime) + kex2, err := getEphermalKEX() + Expect(err).ToNot(HaveOccurred()) + Expect(kex2).ToNot(Equal(kex)) }) })