implement source address token generation and validation

ref #121
This commit is contained in:
Lucas Clemente
2016-05-23 18:13:39 +02:00
parent c3f8837dfe
commit 9539169fa4
3 changed files with 234 additions and 0 deletions

View File

@@ -0,0 +1,119 @@
package crypto
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"crypto/sha256"
"crypto/subtle"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"time"
"github.com/lucas-clemente/quic-go/protocol"
"golang.org/x/crypto/hkdf"
)
type sourceAddressToken struct {
ip net.IP
// unix timestamp in seconds
timestamp uint64
}
func (t *sourceAddressToken) serialize() []byte {
res := make([]byte, 8+len(t.ip))
binary.LittleEndian.PutUint64(res, t.timestamp)
copy(res[8:], t.ip)
return res
}
func parseToken(data []byte) (*sourceAddressToken, error) {
if len(data) != 8+4 && len(data) != 8+16 {
return nil, fmt.Errorf("invalid STK length %d", len(data))
}
return &sourceAddressToken{
ip: data[8:],
timestamp: binary.LittleEndian.Uint64(data),
}, nil
}
type stkSource struct {
aead cipher.AEAD
}
const stkKeySize = 16
// Chrome currently sets this to 12, but discusses changing it to 16. We start
// at 16 :)
const stkNonceSize = 16
func newStkSource(secret []byte) (*stkSource, error) {
key, err := deriveKey(secret)
if err != nil {
return nil, err
}
c, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
aead, err := cipher.NewGCMWithNonceSize(c, stkNonceSize)
if err != nil {
return nil, err
}
return &stkSource{aead: aead}, nil
}
func (s *stkSource) NewToken(ip net.IP) ([]byte, error) {
return encryptToken(s.aead, &sourceAddressToken{
ip: ip,
timestamp: uint64(time.Now().Unix()),
})
}
func (s *stkSource) VerifyToken(ip net.IP, data []byte) error {
if len(data) < stkNonceSize {
return errors.New("STK too short")
}
nonce := data[:stkNonceSize]
res, err := s.aead.Open(nil, nonce, data[stkNonceSize:], nil)
if err != nil {
return err
}
token, err := parseToken(res)
if err != nil {
return err
}
if subtle.ConstantTimeCompare(token.ip, ip) != 1 {
return errors.New("invalid ip in STK")
}
if time.Now().Unix() > int64(token.timestamp)+protocol.STKExpiryTimeSec {
return errors.New("STK expired")
}
return nil
}
func deriveKey(secret []byte) ([]byte, error) {
r := hkdf.New(sha256.New, secret, nil, []byte("QUIC source address token key"))
key := make([]byte, stkKeySize)
if _, err := io.ReadFull(r, key); err != nil {
return nil, err
}
return key, nil
}
func encryptToken(aead cipher.AEAD, token *sourceAddressToken) ([]byte, error) {
nonce := make([]byte, stkNonceSize)
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
return nil, err
}
return aead.Seal(nonce, nonce, token.serialize(), nil), nil
}

View File

@@ -0,0 +1,112 @@
package crypto
import (
"net"
"time"
"github.com/lucas-clemente/quic-go/protocol"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Source Address Tokens", func() {
It("should generate the encryption key", func() {
Expect(deriveKey([]byte("TESTING"))).To(Equal([]byte{0xee, 0x71, 0x18, 0x9, 0xfd, 0xb8, 0x9a, 0x79, 0x19, 0xfc, 0x5e, 0x1a, 0x97, 0x20, 0xb2, 0x6}))
})
Context("tokens", func() {
It("serializes", func() {
ip := []byte{127, 0, 0, 1}
token := &sourceAddressToken{ip: ip, timestamp: 0xdeadbeef}
Expect(token.serialize()).To(Equal([]byte{
0xef, 0xbe, 0xad, 0xde, 0x00, 0x00, 0x00, 0x00,
127, 0, 0, 1,
}))
})
It("reads", func() {
token, err := parseToken([]byte{
0xef, 0xbe, 0xad, 0xde, 0x00, 0x00, 0x00, 0x00,
127, 0, 0, 1,
})
Expect(err).NotTo(HaveOccurred())
Expect(token.ip).To(Equal(net.IP{127, 0, 0, 1}))
Expect(token.timestamp).To(Equal(uint64(0xdeadbeef)))
})
})
Context("source", func() {
var (
source *stkSource
secret []byte
ip4 net.IP
ip6 net.IP
)
BeforeEach(func() {
var err error
ip4 = net.ParseIP("1.2.3.4")
Expect(ip4).NotTo(BeEmpty())
ip6 = net.ParseIP("2001:0db8:0000:0000:0000:ff00:0042:8329")
Expect(ip6).NotTo(BeEmpty())
secret = []byte("TESTING")
source, err = newStkSource(secret)
Expect(err).NotTo(HaveOccurred())
})
It("should generate new tokens", func() {
token, err := source.NewToken(ip4)
Expect(err).NotTo(HaveOccurred())
Expect(token).ToNot(BeEmpty())
})
It("should generate and verify ipv4 tokens", func() {
stk, err := source.NewToken(ip4)
Expect(err).NotTo(HaveOccurred())
Expect(stk).ToNot(BeEmpty())
err = source.VerifyToken(ip4, stk)
Expect(err).NotTo(HaveOccurred())
})
It("should generate and verify ipv6 tokens", func() {
stk, err := source.NewToken(ip6)
Expect(err).NotTo(HaveOccurred())
Expect(stk).ToNot(BeEmpty())
err = source.VerifyToken(ip6, stk)
Expect(err).NotTo(HaveOccurred())
})
It("should reject empty tokens", func() {
err := source.VerifyToken(ip4, nil)
Expect(err).To(HaveOccurred())
})
It("should reject invalid tokens", func() {
err := source.VerifyToken(ip4, []byte("foobar"))
Expect(err).To(HaveOccurred())
})
It("should reject outdated tokens", func() {
stk, err := encryptToken(source.aead, &sourceAddressToken{
ip: ip4,
timestamp: uint64(time.Now().Unix() - protocol.STKExpiryTimeSec - 1),
})
Expect(err).NotTo(HaveOccurred())
err = source.VerifyToken(ip4, stk)
Expect(err).To(MatchError("STK expired"))
})
It("should reject tokens with wrong IP addresses", func() {
otherIP := net.ParseIP("4.3.2.1")
stk, err := encryptToken(source.aead, &sourceAddressToken{
ip: otherIP,
timestamp: uint64(time.Now().Unix()),
})
Expect(err).NotTo(HaveOccurred())
err = source.VerifyToken(ip4, stk)
Expect(err).To(MatchError("invalid ip in STK"))
})
})
})

View File

@@ -48,3 +48,6 @@ const MaxSessionUnprocessedPackets = 128
// RetransmissionThreshold + 1 is the number of times a packet has to be NACKed so that it gets retransmitted
const RetransmissionThreshold uint8 = 3
// STKExpiryTimeSec is the valid time of a source address token in seconds
const STKExpiryTimeSec = 24 * 60 * 60