forked from quic-go/quic-go
move the STK expiration check to the cryptoSetup
This commit is contained in:
@@ -10,10 +10,9 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
|
||||
"golang.org/x/crypto/hkdf"
|
||||
)
|
||||
|
||||
@@ -21,8 +20,8 @@ import (
|
||||
type StkSource interface {
|
||||
// NewToken creates a new token for a given IP address
|
||||
NewToken(sourceAddress []byte) ([]byte, error)
|
||||
// VerifyToken verifies if a token matches a given IP address and is not outdated
|
||||
VerifyToken(sourceAddress []byte, data []byte) error
|
||||
// VerifyToken verifies if a token matches a given IP address
|
||||
VerifyToken(sourceAddress []byte, data []byte) (time.Time, error)
|
||||
}
|
||||
|
||||
type sourceAddressToken struct {
|
||||
@@ -82,31 +81,31 @@ func (s *stkSource) NewToken(sourceAddr []byte) ([]byte, error) {
|
||||
})
|
||||
}
|
||||
|
||||
func (s *stkSource) VerifyToken(sourceAddr []byte, data []byte) error {
|
||||
func (s *stkSource) VerifyToken(sourceAddr []byte, data []byte) (time.Time, error) {
|
||||
if len(data) < stkNonceSize {
|
||||
return errors.New("STK too short")
|
||||
return time.Time{}, errors.New("STK too short")
|
||||
}
|
||||
nonce := data[:stkNonceSize]
|
||||
|
||||
res, err := s.aead.Open(nil, nonce, data[stkNonceSize:], nil)
|
||||
if err != nil {
|
||||
return err
|
||||
return time.Time{}, err
|
||||
}
|
||||
|
||||
token, err := parseToken(res)
|
||||
if err != nil {
|
||||
return err
|
||||
return time.Time{}, err
|
||||
}
|
||||
|
||||
if subtle.ConstantTimeCompare(token.sourceAddr, sourceAddr) != 1 {
|
||||
return errors.New("invalid source address in STK")
|
||||
return time.Time{}, errors.New("invalid source address in STK")
|
||||
}
|
||||
|
||||
if time.Now().Unix() > int64(token.timestamp)+protocol.STKExpiryTimeSec {
|
||||
return errors.New("STK expired")
|
||||
if token.timestamp > math.MaxInt64 {
|
||||
return time.Time{}, errors.New("invalid timestamp")
|
||||
}
|
||||
|
||||
return nil
|
||||
return time.Unix(int64(token.timestamp), 0), nil
|
||||
}
|
||||
|
||||
func deriveKey(secret []byte) ([]byte, error) {
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"math"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
@@ -62,57 +62,69 @@ var _ = Describe("Source Address Tokens", func() {
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
})
|
||||
|
||||
It("should generate new tokens", func() {
|
||||
It("generates new tokens", func() {
|
||||
token, err := source.NewToken(ip4)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(token).ToNot(BeEmpty())
|
||||
})
|
||||
|
||||
It("should generate and verify ipv4 tokens", func() {
|
||||
It("generates and verifies ipv4 tokens", func() {
|
||||
stk, err := source.NewToken(ip4)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(stk).ToNot(BeEmpty())
|
||||
err = source.VerifyToken(ip4, stk)
|
||||
_, err = source.VerifyToken(ip4, stk)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
})
|
||||
|
||||
It("should generate and verify ipv6 tokens", func() {
|
||||
It("generates and verify ipv6 tokens", func() {
|
||||
stk, err := source.NewToken(ip6)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(stk).ToNot(BeEmpty())
|
||||
err = source.VerifyToken(ip6, stk)
|
||||
_, err = source.VerifyToken(ip6, stk)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
})
|
||||
|
||||
It("should reject empty tokens", func() {
|
||||
err := source.VerifyToken(ip4, nil)
|
||||
It("rejects empty tokens", func() {
|
||||
_, err := source.VerifyToken(ip4, nil)
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
|
||||
It("should reject invalid tokens", func() {
|
||||
err := source.VerifyToken(ip4, []byte("foobar"))
|
||||
It("rejects invalid tokens", func() {
|
||||
_, err := source.VerifyToken(ip4, []byte("foobar"))
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
|
||||
It("should reject outdated tokens", func() {
|
||||
stk, err := encryptToken(source.aead, &sourceAddressToken{
|
||||
sourceAddr: ip4,
|
||||
timestamp: uint64(time.Now().Unix() - protocol.STKExpiryTimeSec - 1),
|
||||
})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
err = source.VerifyToken(ip4, stk)
|
||||
Expect(err).To(MatchError("STK expired"))
|
||||
})
|
||||
|
||||
It("should reject tokens with wrong IP addresses", func() {
|
||||
It("rejects tokens with wrong IP addresses", func() {
|
||||
otherIP := net.ParseIP("4.3.2.1")
|
||||
stk, err := encryptToken(source.aead, &sourceAddressToken{
|
||||
sourceAddr: otherIP,
|
||||
timestamp: uint64(time.Now().Unix()),
|
||||
})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
err = source.VerifyToken(ip4, stk)
|
||||
_, err = source.VerifyToken(ip4, stk)
|
||||
Expect(err).To(MatchError("invalid source address in STK"))
|
||||
})
|
||||
|
||||
It("rejects overflowing timestamps", func() {
|
||||
stk, err := encryptToken(source.aead, &sourceAddressToken{
|
||||
sourceAddr: ip4,
|
||||
timestamp: math.MaxUint64,
|
||||
})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
_, err = source.VerifyToken(ip4, stk)
|
||||
Expect(err).To(MatchError("invalid timestamp"))
|
||||
})
|
||||
|
||||
It("returns the timestamp encoded in the token", func() {
|
||||
timestamp := time.Now().Add(-time.Hour)
|
||||
stk, err := encryptToken(source.aead, &sourceAddressToken{
|
||||
sourceAddr: ip4,
|
||||
timestamp: uint64(timestamp.Unix()),
|
||||
})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
t, err := source.VerifyToken(ip4, stk)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(t).To(BeTemporally("~", timestamp, time.Second))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/crypto"
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
@@ -271,12 +272,19 @@ func (h *cryptoSetupServer) isInchoateCHLO(cryptoData map[Tag][]byte, cert []byt
|
||||
if crypto.HashCert(cert) != xlct {
|
||||
return true
|
||||
}
|
||||
stk := cryptoData[TagSTK]
|
||||
if err := h.scfg.stkSource.VerifyToken(h.sourceAddr, stk); err != nil {
|
||||
return !h.verifySTK(cryptoData[TagSTK])
|
||||
}
|
||||
|
||||
func (h *cryptoSetupServer) verifySTK(stk []byte) bool {
|
||||
stkTime, err := h.scfg.stkSource.VerifyToken(h.sourceAddr, stk)
|
||||
if err != nil {
|
||||
utils.Debugf("STK invalid: %s", err.Error())
|
||||
return true
|
||||
return false
|
||||
}
|
||||
return false
|
||||
if time.Now().After(stkTime.Add(protocol.STKExpiryTimeSec * time.Second)) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (h *cryptoSetupServer) handleInchoateCHLO(sni string, chlo []byte, cryptoData map[Tag][]byte) ([]byte, error) {
|
||||
@@ -295,7 +303,7 @@ func (h *cryptoSetupServer) handleInchoateCHLO(sni string, chlo []byte, cryptoDa
|
||||
TagSVID: []byte("quic-go"),
|
||||
}
|
||||
|
||||
if h.scfg.stkSource.VerifyToken(h.sourceAddr, cryptoData[TagSTK]) == nil {
|
||||
if h.verifySTK(cryptoData[TagSTK]) {
|
||||
proof, err := h.scfg.Sign(sni, chlo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/crypto"
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
@@ -127,27 +128,33 @@ func (s mockStream) StreamID() protocol.StreamID { panic("not implemente
|
||||
|
||||
type mockStkSource struct {
|
||||
verifyErr error
|
||||
stkTime time.Time
|
||||
}
|
||||
|
||||
var _ crypto.StkSource = &mockStkSource{}
|
||||
|
||||
func (mockStkSource) NewToken(sourceAddr []byte) ([]byte, error) {
|
||||
return append([]byte("token "), sourceAddr...), nil
|
||||
}
|
||||
|
||||
func (s mockStkSource) VerifyToken(sourceAddr []byte, token []byte) error {
|
||||
func (s mockStkSource) VerifyToken(sourceAddr []byte, token []byte) (time.Time, error) {
|
||||
if s.verifyErr != nil {
|
||||
return s.verifyErr
|
||||
return time.Time{}, s.verifyErr
|
||||
}
|
||||
split := bytes.Split(token, []byte(" "))
|
||||
if len(split) != 2 {
|
||||
return errors.New("stk required")
|
||||
return time.Time{}, errors.New("stk required")
|
||||
}
|
||||
if !bytes.Equal(split[0], []byte("token")) {
|
||||
return errors.New("no prefix match")
|
||||
return time.Time{}, errors.New("no prefix match")
|
||||
}
|
||||
if !bytes.Equal(split[1], sourceAddr) {
|
||||
return errors.New("ip wrong")
|
||||
return time.Time{}, errors.New("ip wrong")
|
||||
}
|
||||
return nil
|
||||
if !s.stkTime.IsZero() {
|
||||
return s.stkTime, nil
|
||||
}
|
||||
return time.Now(), nil
|
||||
}
|
||||
|
||||
var _ = Describe("Server Crypto Setup", func() {
|
||||
@@ -433,6 +440,11 @@ var _ = Describe("Server Crypto Setup", func() {
|
||||
Expect(cs.isInchoateCHLO(fullCHLO, cert)).To(BeTrue())
|
||||
})
|
||||
|
||||
It("REJ messages that have an expired STK", func() {
|
||||
cs.scfg.stkSource.(*mockStkSource).stkTime = time.Now().Add(-protocol.STKExpiryTimeSec * time.Second).Add(-time.Second)
|
||||
Expect(cs.isInchoateCHLO(fullCHLO, cert)).To(BeTrue())
|
||||
})
|
||||
|
||||
It("recognizes proper CHLOs", func() {
|
||||
Expect(cs.isInchoateCHLO(fullCHLO, cert)).To(BeFalse())
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user