From 81985f44bd6871ab9eeac141abfa4120fec48eda Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 1 Apr 2017 12:10:09 +0700 Subject: [PATCH] move the STK expiration check to the cryptoSetup --- crypto/source_address_token.go | 23 ++++++----- crypto/source_address_token_test.go | 56 ++++++++++++++++----------- handshake/crypto_setup_server.go | 18 ++++++--- handshake/crypto_setup_server_test.go | 24 +++++++++--- 4 files changed, 76 insertions(+), 45 deletions(-) diff --git a/crypto/source_address_token.go b/crypto/source_address_token.go index cde9cd1e..cad04531 100644 --- a/crypto/source_address_token.go +++ b/crypto/source_address_token.go @@ -10,10 +10,9 @@ import ( "errors" "fmt" "io" + "math" "time" - "github.com/lucas-clemente/quic-go/protocol" - "golang.org/x/crypto/hkdf" ) @@ -21,8 +20,8 @@ import ( type StkSource interface { // NewToken creates a new token for a given IP address NewToken(sourceAddress []byte) ([]byte, error) - // VerifyToken verifies if a token matches a given IP address and is not outdated - VerifyToken(sourceAddress []byte, data []byte) error + // VerifyToken verifies if a token matches a given IP address + VerifyToken(sourceAddress []byte, data []byte) (time.Time, error) } type sourceAddressToken struct { @@ -82,31 +81,31 @@ func (s *stkSource) NewToken(sourceAddr []byte) ([]byte, error) { }) } -func (s *stkSource) VerifyToken(sourceAddr []byte, data []byte) error { +func (s *stkSource) VerifyToken(sourceAddr []byte, data []byte) (time.Time, error) { if len(data) < stkNonceSize { - return errors.New("STK too short") + return time.Time{}, errors.New("STK too short") } nonce := data[:stkNonceSize] res, err := s.aead.Open(nil, nonce, data[stkNonceSize:], nil) if err != nil { - return err + return time.Time{}, err } token, err := parseToken(res) if err != nil { - return err + return time.Time{}, err } if subtle.ConstantTimeCompare(token.sourceAddr, sourceAddr) != 1 { - return errors.New("invalid source address in STK") + return time.Time{}, errors.New("invalid source address in STK") } - if time.Now().Unix() > int64(token.timestamp)+protocol.STKExpiryTimeSec { - return errors.New("STK expired") + if token.timestamp > math.MaxInt64 { + return time.Time{}, errors.New("invalid timestamp") } - return nil + return time.Unix(int64(token.timestamp), 0), nil } func deriveKey(secret []byte) ([]byte, error) { diff --git a/crypto/source_address_token_test.go b/crypto/source_address_token_test.go index 126a1203..c246bfca 100644 --- a/crypto/source_address_token_test.go +++ b/crypto/source_address_token_test.go @@ -1,10 +1,10 @@ package crypto import ( + "math" "net" "time" - "github.com/lucas-clemente/quic-go/protocol" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) @@ -62,57 +62,69 @@ var _ = Describe("Source Address Tokens", func() { Expect(err).NotTo(HaveOccurred()) }) - It("should generate new tokens", func() { + It("generates new tokens", func() { token, err := source.NewToken(ip4) Expect(err).NotTo(HaveOccurred()) Expect(token).ToNot(BeEmpty()) }) - It("should generate and verify ipv4 tokens", func() { + It("generates and verifies ipv4 tokens", func() { stk, err := source.NewToken(ip4) Expect(err).NotTo(HaveOccurred()) Expect(stk).ToNot(BeEmpty()) - err = source.VerifyToken(ip4, stk) + _, err = source.VerifyToken(ip4, stk) Expect(err).NotTo(HaveOccurred()) }) - It("should generate and verify ipv6 tokens", func() { + It("generates and verify ipv6 tokens", func() { stk, err := source.NewToken(ip6) Expect(err).NotTo(HaveOccurred()) Expect(stk).ToNot(BeEmpty()) - err = source.VerifyToken(ip6, stk) + _, err = source.VerifyToken(ip6, stk) Expect(err).NotTo(HaveOccurred()) }) - It("should reject empty tokens", func() { - err := source.VerifyToken(ip4, nil) + It("rejects empty tokens", func() { + _, err := source.VerifyToken(ip4, nil) Expect(err).To(HaveOccurred()) }) - It("should reject invalid tokens", func() { - err := source.VerifyToken(ip4, []byte("foobar")) + It("rejects invalid tokens", func() { + _, err := source.VerifyToken(ip4, []byte("foobar")) Expect(err).To(HaveOccurred()) }) - It("should reject outdated tokens", func() { - stk, err := encryptToken(source.aead, &sourceAddressToken{ - sourceAddr: ip4, - timestamp: uint64(time.Now().Unix() - protocol.STKExpiryTimeSec - 1), - }) - Expect(err).NotTo(HaveOccurred()) - err = source.VerifyToken(ip4, stk) - Expect(err).To(MatchError("STK expired")) - }) - - It("should reject tokens with wrong IP addresses", func() { + It("rejects tokens with wrong IP addresses", func() { otherIP := net.ParseIP("4.3.2.1") stk, err := encryptToken(source.aead, &sourceAddressToken{ sourceAddr: otherIP, timestamp: uint64(time.Now().Unix()), }) Expect(err).NotTo(HaveOccurred()) - err = source.VerifyToken(ip4, stk) + _, err = source.VerifyToken(ip4, stk) Expect(err).To(MatchError("invalid source address in STK")) }) + + It("rejects overflowing timestamps", func() { + stk, err := encryptToken(source.aead, &sourceAddressToken{ + sourceAddr: ip4, + timestamp: math.MaxUint64, + }) + Expect(err).NotTo(HaveOccurred()) + _, err = source.VerifyToken(ip4, stk) + Expect(err).To(MatchError("invalid timestamp")) + }) + + It("returns the timestamp encoded in the token", func() { + timestamp := time.Now().Add(-time.Hour) + stk, err := encryptToken(source.aead, &sourceAddressToken{ + sourceAddr: ip4, + timestamp: uint64(timestamp.Unix()), + }) + Expect(err).NotTo(HaveOccurred()) + t, err := source.VerifyToken(ip4, stk) + Expect(err).ToNot(HaveOccurred()) + Expect(t).To(BeTemporally("~", timestamp, time.Second)) + }) }) }) diff --git a/handshake/crypto_setup_server.go b/handshake/crypto_setup_server.go index 070250b2..b9e4304f 100644 --- a/handshake/crypto_setup_server.go +++ b/handshake/crypto_setup_server.go @@ -8,6 +8,7 @@ import ( "io" "net" "sync" + "time" "github.com/lucas-clemente/quic-go/crypto" "github.com/lucas-clemente/quic-go/protocol" @@ -271,12 +272,19 @@ func (h *cryptoSetupServer) isInchoateCHLO(cryptoData map[Tag][]byte, cert []byt if crypto.HashCert(cert) != xlct { return true } - stk := cryptoData[TagSTK] - if err := h.scfg.stkSource.VerifyToken(h.sourceAddr, stk); err != nil { + return !h.verifySTK(cryptoData[TagSTK]) +} + +func (h *cryptoSetupServer) verifySTK(stk []byte) bool { + stkTime, err := h.scfg.stkSource.VerifyToken(h.sourceAddr, stk) + if err != nil { utils.Debugf("STK invalid: %s", err.Error()) - return true + return false } - return false + if time.Now().After(stkTime.Add(protocol.STKExpiryTimeSec * time.Second)) { + return false + } + return true } func (h *cryptoSetupServer) handleInchoateCHLO(sni string, chlo []byte, cryptoData map[Tag][]byte) ([]byte, error) { @@ -295,7 +303,7 @@ func (h *cryptoSetupServer) handleInchoateCHLO(sni string, chlo []byte, cryptoDa TagSVID: []byte("quic-go"), } - if h.scfg.stkSource.VerifyToken(h.sourceAddr, cryptoData[TagSTK]) == nil { + if h.verifySTK(cryptoData[TagSTK]) { proof, err := h.scfg.Sign(sni, chlo) if err != nil { return nil, err diff --git a/handshake/crypto_setup_server_test.go b/handshake/crypto_setup_server_test.go index 2cb22450..1c31a319 100644 --- a/handshake/crypto_setup_server_test.go +++ b/handshake/crypto_setup_server_test.go @@ -5,6 +5,7 @@ import ( "encoding/binary" "errors" "net" + "time" "github.com/lucas-clemente/quic-go/crypto" "github.com/lucas-clemente/quic-go/protocol" @@ -127,27 +128,33 @@ func (s mockStream) StreamID() protocol.StreamID { panic("not implemente type mockStkSource struct { verifyErr error + stkTime time.Time } +var _ crypto.StkSource = &mockStkSource{} + func (mockStkSource) NewToken(sourceAddr []byte) ([]byte, error) { return append([]byte("token "), sourceAddr...), nil } -func (s mockStkSource) VerifyToken(sourceAddr []byte, token []byte) error { +func (s mockStkSource) VerifyToken(sourceAddr []byte, token []byte) (time.Time, error) { if s.verifyErr != nil { - return s.verifyErr + return time.Time{}, s.verifyErr } split := bytes.Split(token, []byte(" ")) if len(split) != 2 { - return errors.New("stk required") + return time.Time{}, errors.New("stk required") } if !bytes.Equal(split[0], []byte("token")) { - return errors.New("no prefix match") + return time.Time{}, errors.New("no prefix match") } if !bytes.Equal(split[1], sourceAddr) { - return errors.New("ip wrong") + return time.Time{}, errors.New("ip wrong") } - return nil + if !s.stkTime.IsZero() { + return s.stkTime, nil + } + return time.Now(), nil } var _ = Describe("Server Crypto Setup", func() { @@ -433,6 +440,11 @@ var _ = Describe("Server Crypto Setup", func() { Expect(cs.isInchoateCHLO(fullCHLO, cert)).To(BeTrue()) }) + It("REJ messages that have an expired STK", func() { + cs.scfg.stkSource.(*mockStkSource).stkTime = time.Now().Add(-protocol.STKExpiryTimeSec * time.Second).Add(-time.Second) + Expect(cs.isInchoateCHLO(fullCHLO, cert)).To(BeTrue()) + }) + It("recognizes proper CHLOs", func() { Expect(cs.isInchoateCHLO(fullCHLO, cert)).To(BeFalse()) })