From 9562df583883c02de5cfda7200e3fe7f717885a8 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 13 May 2017 11:47:36 +0800 Subject: [PATCH] move the STK generation from the ServerConfig to a separate struct --- crypto/source_address_token.go | 6 ++++- crypto/source_address_token_test.go | 4 +-- handshake/crypto_setup_server.go | 11 +++++++-- handshake/crypto_setup_server_test.go | 6 ++--- handshake/server_config.go | 12 --------- handshake/stk_generator.go | 33 +++++++++++++++++++++++++ handshake/stk_generator_test.go | 35 +++++++++++++++++++++++++++ 7 files changed, 86 insertions(+), 21 deletions(-) create mode 100644 handshake/stk_generator.go create mode 100644 handshake/stk_generator_test.go diff --git a/crypto/source_address_token.go b/crypto/source_address_token.go index cad04531..e3e9ec87 100644 --- a/crypto/source_address_token.go +++ b/crypto/source_address_token.go @@ -58,7 +58,11 @@ const stkKeySize = 16 const stkNonceSize = 16 // NewStkSource creates a source for source address tokens -func NewStkSource(secret []byte) (StkSource, error) { +func NewStkSource() (StkSource, error) { + secret := make([]byte, 32) + if _, err := rand.Read(secret); err != nil { + return nil, err + } key, err := deriveKey(secret) if err != nil { return nil, err diff --git a/crypto/source_address_token_test.go b/crypto/source_address_token_test.go index c246bfca..748f421d 100644 --- a/crypto/source_address_token_test.go +++ b/crypto/source_address_token_test.go @@ -43,7 +43,6 @@ var _ = Describe("Source Address Tokens", func() { Context("source", func() { var ( source *stkSource - secret []byte ip4 net.IP ip6 net.IP ) @@ -56,8 +55,7 @@ var _ = Describe("Source Address Tokens", func() { ip6 = net.ParseIP("2001:0db8:0000:0000:0000:ff00:0042:8329") Expect(ip6).NotTo(BeEmpty()) - secret = []byte("TESTING") - sourceI, err := NewStkSource(secret) + sourceI, err := NewStkSource() source = sourceI.(*stkSource) Expect(err).NotTo(HaveOccurred()) }) diff --git a/handshake/crypto_setup_server.go b/handshake/crypto_setup_server.go index 7558cfe7..a5f85766 100644 --- a/handshake/crypto_setup_server.go +++ b/handshake/crypto_setup_server.go @@ -27,6 +27,7 @@ type cryptoSetupServer struct { connID protocol.ConnectionID sourceAddr []byte scfg *ServerConfig + stkGenerator *STKGenerator diversificationNonce []byte version protocol.VersionNumber @@ -68,6 +69,11 @@ func NewCryptoSetup( supportedVersions []protocol.VersionNumber, aeadChanged chan<- protocol.EncryptionLevel, ) (CryptoSetup, error) { + stkGenerator, err := NewSTKGenerator() + if err != nil { + return nil, err + } + var sourceAddr []byte if udpAddr, ok := remoteAddr.(*net.UDPAddr); ok { sourceAddr = udpAddr.IP @@ -81,6 +87,7 @@ func NewCryptoSetup( version: version, supportedVersions: supportedVersions, scfg: scfg, + stkGenerator: stkGenerator, keyDerivation: crypto.DeriveKeysAESGCM, keyExchange: getEphermalKEX, nullAEAD: crypto.NewNullAEAD(protocol.PerspectiveServer, version), @@ -276,7 +283,7 @@ func (h *cryptoSetupServer) isInchoateCHLO(cryptoData map[Tag][]byte, cert []byt } func (h *cryptoSetupServer) verifySTK(stk []byte) bool { - stkTime, err := h.scfg.stkSource.VerifyToken(h.sourceAddr, stk) + stkTime, err := h.stkGenerator.VerifyToken(h.sourceAddr, stk) if err != nil { utils.Debugf("STK invalid: %s", err.Error()) return false @@ -292,7 +299,7 @@ func (h *cryptoSetupServer) handleInchoateCHLO(sni string, chlo []byte, cryptoDa return nil, qerr.Error(qerr.CryptoInvalidValueLength, "CHLO too small") } - token, err := h.scfg.stkSource.NewToken(h.sourceAddr) + token, err := h.stkGenerator.NewToken(h.sourceAddr) if err != nil { return nil, err } diff --git a/handshake/crypto_setup_server_test.go b/handshake/crypto_setup_server_test.go index 60bdce0b..c371fa78 100644 --- a/handshake/crypto_setup_server_test.go +++ b/handshake/crypto_setup_server_test.go @@ -194,7 +194,6 @@ var _ = Describe("Server Crypto Setup", func() { versionTag = make([]byte, 4) binary.LittleEndian.PutUint32(versionTag, protocol.VersionNumberToTag(protocol.VersionWhatever)) Expect(err).NotTo(HaveOccurred()) - scfg.stkSource = &mockStkSource{} version = protocol.SupportedVersions[len(protocol.SupportedVersions)-1] supportedVersions = []protocol.VersionNumber{version, 98, 99} cpm = NewConnectionParamatersManager(protocol.PerspectiveServer, protocol.VersionWhatever) @@ -210,6 +209,7 @@ var _ = Describe("Server Crypto Setup", func() { ) Expect(err).NotTo(HaveOccurred()) cs = csInt.(*cryptoSetupServer) + cs.stkGenerator.stkSource = &mockStkSource{} cs.keyDerivation = mockKeyDerivation cs.keyExchange = func() crypto.KeyExchange { return &mockKEX{ephermal: true} } }) @@ -436,12 +436,12 @@ var _ = Describe("Server Crypto Setup", func() { It("recognizes inchoate CHLOs with an invalid STK", func() { testErr := errors.New("STK invalid") - scfg.stkSource.(*mockStkSource).verifyErr = testErr + cs.stkGenerator.stkSource.(*mockStkSource).verifyErr = testErr Expect(cs.isInchoateCHLO(fullCHLO, cert)).To(BeTrue()) }) It("REJ messages that have an expired STK", func() { - cs.scfg.stkSource.(*mockStkSource).stkTime = time.Now().Add(-protocol.STKExpiryTime).Add(-time.Second) + cs.stkGenerator.stkSource.(*mockStkSource).stkTime = time.Now().Add(-protocol.STKExpiryTime).Add(-time.Second) Expect(cs.isInchoateCHLO(fullCHLO, cert)).To(BeTrue()) }) diff --git a/handshake/server_config.go b/handshake/server_config.go index 24195608..fce66efd 100644 --- a/handshake/server_config.go +++ b/handshake/server_config.go @@ -13,7 +13,6 @@ type ServerConfig struct { certChain crypto.CertChain ID []byte obit []byte - stkSource crypto.StkSource } // NewServerConfig creates a new server config @@ -24,27 +23,16 @@ func NewServerConfig(kex crypto.KeyExchange, certChain crypto.CertChain) (*Serve return nil, err } - stkSecret := make([]byte, 32) - if _, err = rand.Read(stkSecret); err != nil { - return nil, err - } - obit := make([]byte, 8) if _, err = rand.Read(obit); err != nil { return nil, err } - stkSource, err := crypto.NewStkSource(stkSecret) - if err != nil { - return nil, err - } - return &ServerConfig{ kex: kex, certChain: certChain, ID: id, obit: obit, - stkSource: stkSource, }, nil } diff --git a/handshake/stk_generator.go b/handshake/stk_generator.go new file mode 100644 index 00000000..eca677b6 --- /dev/null +++ b/handshake/stk_generator.go @@ -0,0 +1,33 @@ +package handshake + +import ( + "time" + + "github.com/lucas-clemente/quic-go/crypto" +) + +// An STKGenerator generates STKs +type STKGenerator struct { + stkSource crypto.StkSource +} + +// NewSTKGenerator initializes a new STKGenerator +func NewSTKGenerator() (*STKGenerator, error) { + stkSource, err := crypto.NewStkSource() + if err != nil { + return nil, err + } + return &STKGenerator{ + stkSource: stkSource, + }, nil +} + +// NewToken generates a new STK token for a given source address +func (g *STKGenerator) NewToken(sourceAddr []byte) ([]byte, error) { + return g.stkSource.NewToken(sourceAddr) +} + +// VerifyToken verifies an STK token +func (g *STKGenerator) VerifyToken(sourceAddr []byte, token []byte) (time.Time, error) { + return g.stkSource.VerifyToken(sourceAddr, token) +} diff --git a/handshake/stk_generator_test.go b/handshake/stk_generator_test.go new file mode 100644 index 00000000..cad4b243 --- /dev/null +++ b/handshake/stk_generator_test.go @@ -0,0 +1,35 @@ +package handshake + +import ( + "net" + "time" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("STK Generator", func() { + var stkGen *STKGenerator + + BeforeEach(func() { + var err error + stkGen, err = NewSTKGenerator() + Expect(err).ToNot(HaveOccurred()) + }) + + It("generates an STK", func() { + ip := net.IPv4(127, 0, 0, 1) + token, err := stkGen.NewToken(ip) + Expect(err).ToNot(HaveOccurred()) + Expect(token).ToNot(BeEmpty()) + }) + + It("verifies an STK", func() { + ip := net.IPv4(192, 168, 0, 1) + token, err := stkGen.NewToken(ip) + Expect(err).ToNot(HaveOccurred()) + t, err := stkGen.VerifyToken(ip, token) + Expect(err).ToNot(HaveOccurred()) + Expect(t).To(BeTemporally("~", time.Now(), time.Second)) + }) +})