diff --git a/crypto/source_address_token.go b/crypto/source_address_token.go new file mode 100644 index 00000000..dea1d2b2 --- /dev/null +++ b/crypto/source_address_token.go @@ -0,0 +1,119 @@ +package crypto + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "crypto/sha256" + "crypto/subtle" + "encoding/binary" + "errors" + "fmt" + "io" + "net" + "time" + + "github.com/lucas-clemente/quic-go/protocol" + + "golang.org/x/crypto/hkdf" +) + +type sourceAddressToken struct { + ip net.IP + // unix timestamp in seconds + timestamp uint64 +} + +func (t *sourceAddressToken) serialize() []byte { + res := make([]byte, 8+len(t.ip)) + binary.LittleEndian.PutUint64(res, t.timestamp) + copy(res[8:], t.ip) + return res +} + +func parseToken(data []byte) (*sourceAddressToken, error) { + if len(data) != 8+4 && len(data) != 8+16 { + return nil, fmt.Errorf("invalid STK length %d", len(data)) + } + return &sourceAddressToken{ + ip: data[8:], + timestamp: binary.LittleEndian.Uint64(data), + }, nil +} + +type stkSource struct { + aead cipher.AEAD +} + +const stkKeySize = 16 + +// Chrome currently sets this to 12, but discusses changing it to 16. We start +// at 16 :) +const stkNonceSize = 16 + +func newStkSource(secret []byte) (*stkSource, error) { + key, err := deriveKey(secret) + if err != nil { + return nil, err + } + c, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + aead, err := cipher.NewGCMWithNonceSize(c, stkNonceSize) + if err != nil { + return nil, err + } + return &stkSource{aead: aead}, nil +} + +func (s *stkSource) NewToken(ip net.IP) ([]byte, error) { + return encryptToken(s.aead, &sourceAddressToken{ + ip: ip, + timestamp: uint64(time.Now().Unix()), + }) +} + +func (s *stkSource) VerifyToken(ip net.IP, data []byte) error { + if len(data) < stkNonceSize { + return errors.New("STK too short") + } + nonce := data[:stkNonceSize] + + res, err := s.aead.Open(nil, nonce, data[stkNonceSize:], nil) + if err != nil { + return err + } + + token, err := parseToken(res) + if err != nil { + return err + } + + if subtle.ConstantTimeCompare(token.ip, ip) != 1 { + return errors.New("invalid ip in STK") + } + + if time.Now().Unix() > int64(token.timestamp)+protocol.STKExpiryTimeSec { + return errors.New("STK expired") + } + + return nil +} + +func deriveKey(secret []byte) ([]byte, error) { + r := hkdf.New(sha256.New, secret, nil, []byte("QUIC source address token key")) + key := make([]byte, stkKeySize) + if _, err := io.ReadFull(r, key); err != nil { + return nil, err + } + return key, nil +} + +func encryptToken(aead cipher.AEAD, token *sourceAddressToken) ([]byte, error) { + nonce := make([]byte, stkNonceSize) + if _, err := io.ReadFull(rand.Reader, nonce); err != nil { + return nil, err + } + return aead.Seal(nonce, nonce, token.serialize(), nil), nil +} diff --git a/crypto/source_address_token_test.go b/crypto/source_address_token_test.go new file mode 100644 index 00000000..496eb1d6 --- /dev/null +++ b/crypto/source_address_token_test.go @@ -0,0 +1,112 @@ +package crypto + +import ( + "net" + "time" + + "github.com/lucas-clemente/quic-go/protocol" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Source Address Tokens", func() { + It("should generate the encryption key", func() { + Expect(deriveKey([]byte("TESTING"))).To(Equal([]byte{0xee, 0x71, 0x18, 0x9, 0xfd, 0xb8, 0x9a, 0x79, 0x19, 0xfc, 0x5e, 0x1a, 0x97, 0x20, 0xb2, 0x6})) + }) + + Context("tokens", func() { + It("serializes", func() { + ip := []byte{127, 0, 0, 1} + token := &sourceAddressToken{ip: ip, timestamp: 0xdeadbeef} + Expect(token.serialize()).To(Equal([]byte{ + 0xef, 0xbe, 0xad, 0xde, 0x00, 0x00, 0x00, 0x00, + 127, 0, 0, 1, + })) + }) + + It("reads", func() { + token, err := parseToken([]byte{ + 0xef, 0xbe, 0xad, 0xde, 0x00, 0x00, 0x00, 0x00, + 127, 0, 0, 1, + }) + Expect(err).NotTo(HaveOccurred()) + Expect(token.ip).To(Equal(net.IP{127, 0, 0, 1})) + Expect(token.timestamp).To(Equal(uint64(0xdeadbeef))) + }) + }) + + Context("source", func() { + var ( + source *stkSource + secret []byte + ip4 net.IP + ip6 net.IP + ) + + BeforeEach(func() { + var err error + + ip4 = net.ParseIP("1.2.3.4") + Expect(ip4).NotTo(BeEmpty()) + ip6 = net.ParseIP("2001:0db8:0000:0000:0000:ff00:0042:8329") + Expect(ip6).NotTo(BeEmpty()) + + secret = []byte("TESTING") + source, err = newStkSource(secret) + Expect(err).NotTo(HaveOccurred()) + }) + + It("should generate new tokens", func() { + token, err := source.NewToken(ip4) + Expect(err).NotTo(HaveOccurred()) + Expect(token).ToNot(BeEmpty()) + }) + + It("should generate and verify ipv4 tokens", func() { + stk, err := source.NewToken(ip4) + Expect(err).NotTo(HaveOccurred()) + Expect(stk).ToNot(BeEmpty()) + err = source.VerifyToken(ip4, stk) + Expect(err).NotTo(HaveOccurred()) + }) + + It("should generate and verify ipv6 tokens", func() { + stk, err := source.NewToken(ip6) + Expect(err).NotTo(HaveOccurred()) + Expect(stk).ToNot(BeEmpty()) + err = source.VerifyToken(ip6, stk) + Expect(err).NotTo(HaveOccurred()) + }) + + It("should reject empty tokens", func() { + err := source.VerifyToken(ip4, nil) + Expect(err).To(HaveOccurred()) + }) + + It("should reject invalid tokens", func() { + err := source.VerifyToken(ip4, []byte("foobar")) + Expect(err).To(HaveOccurred()) + }) + + It("should reject outdated tokens", func() { + stk, err := encryptToken(source.aead, &sourceAddressToken{ + ip: ip4, + timestamp: uint64(time.Now().Unix() - protocol.STKExpiryTimeSec - 1), + }) + Expect(err).NotTo(HaveOccurred()) + err = source.VerifyToken(ip4, stk) + Expect(err).To(MatchError("STK expired")) + }) + + It("should reject tokens with wrong IP addresses", func() { + otherIP := net.ParseIP("4.3.2.1") + stk, err := encryptToken(source.aead, &sourceAddressToken{ + ip: otherIP, + timestamp: uint64(time.Now().Unix()), + }) + Expect(err).NotTo(HaveOccurred()) + err = source.VerifyToken(ip4, stk) + Expect(err).To(MatchError("invalid ip in STK")) + }) + }) +}) diff --git a/protocol/server_parameters.go b/protocol/server_parameters.go index 6dc5940c..703f0392 100644 --- a/protocol/server_parameters.go +++ b/protocol/server_parameters.go @@ -48,3 +48,6 @@ const MaxSessionUnprocessedPackets = 128 // RetransmissionThreshold + 1 is the number of times a packet has to be NACKed so that it gets retransmitted const RetransmissionThreshold uint8 = 3 + +// STKExpiryTimeSec is the valid time of a source address token in seconds +const STKExpiryTimeSec = 24 * 60 * 60