move comparison of the source address in the STK to the STKGenerator

This commit is contained in:
Marten Seemann
2017-05-13 13:25:10 +08:00
parent 9562df5838
commit 8e01921495
5 changed files with 89 additions and 85 deletions

View File

@@ -5,7 +5,6 @@ import (
"crypto/cipher"
"crypto/rand"
"crypto/sha256"
"crypto/subtle"
"encoding/binary"
"errors"
"fmt"
@@ -20,20 +19,20 @@ import (
type StkSource interface {
// NewToken creates a new token for a given IP address
NewToken(sourceAddress []byte) ([]byte, error)
// VerifyToken verifies if a token matches a given IP address
VerifyToken(sourceAddress []byte, data []byte) (time.Time, error)
// DecodeToken decodes a token and returns the source address and the timestamp
DecodeToken(data []byte) ([]byte, time.Time, error)
}
type sourceAddressToken struct {
sourceAddr []byte
data []byte
// unix timestamp in seconds
timestamp uint64
}
func (t *sourceAddressToken) serialize() []byte {
res := make([]byte, 8+len(t.sourceAddr))
res := make([]byte, 8+len(t.data))
binary.LittleEndian.PutUint64(res, t.timestamp)
copy(res[8:], t.sourceAddr)
copy(res[8:], t.data)
return res
}
@@ -42,8 +41,8 @@ func parseToken(data []byte) (*sourceAddressToken, error) {
return nil, fmt.Errorf("invalid STK length: %d", len(data))
}
return &sourceAddressToken{
sourceAddr: data[8:],
timestamp: binary.LittleEndian.Uint64(data),
data: data[8:],
timestamp: binary.LittleEndian.Uint64(data),
}, nil
}
@@ -78,38 +77,31 @@ func NewStkSource() (StkSource, error) {
return &stkSource{aead: aead}, nil
}
func (s *stkSource) NewToken(sourceAddr []byte) ([]byte, error) {
func (s *stkSource) NewToken(data []byte) ([]byte, error) {
return encryptToken(s.aead, &sourceAddressToken{
sourceAddr: sourceAddr,
timestamp: uint64(time.Now().Unix()),
data: data,
timestamp: uint64(time.Now().Unix()),
})
}
func (s *stkSource) VerifyToken(sourceAddr []byte, data []byte) (time.Time, error) {
func (s *stkSource) DecodeToken(data []byte) ([]byte, time.Time, error) {
if len(data) < stkNonceSize {
return time.Time{}, errors.New("STK too short")
return nil, time.Time{}, errors.New("STK too short")
}
nonce := data[:stkNonceSize]
res, err := s.aead.Open(nil, nonce, data[stkNonceSize:], nil)
if err != nil {
return time.Time{}, err
return nil, time.Time{}, err
}
token, err := parseToken(res)
if err != nil {
return time.Time{}, err
return nil, time.Time{}, err
}
if subtle.ConstantTimeCompare(token.sourceAddr, sourceAddr) != 1 {
return time.Time{}, errors.New("invalid source address in STK")
}
if token.timestamp > math.MaxInt64 {
return time.Time{}, errors.New("invalid timestamp")
return nil, time.Time{}, errors.New("invalid timestamp")
}
return time.Unix(int64(token.timestamp), 0), nil
return token.data, time.Unix(int64(token.timestamp), 0), nil
}
func deriveKey(secret []byte) ([]byte, error) {

View File

@@ -17,7 +17,7 @@ var _ = Describe("Source Address Tokens", func() {
Context("tokens", func() {
It("serializes", func() {
ip := []byte{127, 0, 0, 1}
token := &sourceAddressToken{sourceAddr: ip, timestamp: 0xdeadbeef}
token := &sourceAddressToken{data: ip, timestamp: 0xdeadbeef}
Expect(token.serialize()).To(Equal([]byte{
0xef, 0xbe, 0xad, 0xde, 0x00, 0x00, 0x00, 0x00,
127, 0, 0, 1,
@@ -30,7 +30,7 @@ var _ = Describe("Source Address Tokens", func() {
127, 0, 0, 1,
})
Expect(err).NotTo(HaveOccurred())
Expect(token.sourceAddr).To(Equal([]byte{127, 0, 0, 1}))
Expect(token.data).To(Equal([]byte{127, 0, 0, 1}))
Expect(token.timestamp).To(Equal(uint64(0xdeadbeef)))
})
@@ -70,57 +70,48 @@ var _ = Describe("Source Address Tokens", func() {
stk, err := source.NewToken(ip4)
Expect(err).NotTo(HaveOccurred())
Expect(stk).ToNot(BeEmpty())
_, err = source.VerifyToken(ip4, stk)
decodedIP, _, err := source.DecodeToken(stk)
Expect(err).NotTo(HaveOccurred())
Expect(decodedIP).To(BeEquivalentTo(ip4))
})
It("generates and verify ipv6 tokens", func() {
stk, err := source.NewToken(ip6)
Expect(err).NotTo(HaveOccurred())
Expect(stk).ToNot(BeEmpty())
_, err = source.VerifyToken(ip6, stk)
decodedIP, _, err := source.DecodeToken(stk)
Expect(err).NotTo(HaveOccurred())
Expect(decodedIP).To(BeEquivalentTo(ip6))
})
It("rejects empty tokens", func() {
_, err := source.VerifyToken(ip4, nil)
Expect(err).To(HaveOccurred())
_, _, err := source.DecodeToken([]byte{})
Expect(err).To(MatchError("STK too short"))
})
It("rejects invalid tokens", func() {
_, err := source.VerifyToken(ip4, []byte("foobar"))
_, _, err := source.DecodeToken([]byte("foobar"))
Expect(err).To(HaveOccurred())
})
It("rejects tokens with wrong IP addresses", func() {
otherIP := net.ParseIP("4.3.2.1")
stk, err := encryptToken(source.aead, &sourceAddressToken{
sourceAddr: otherIP,
timestamp: uint64(time.Now().Unix()),
})
Expect(err).NotTo(HaveOccurred())
_, err = source.VerifyToken(ip4, stk)
Expect(err).To(MatchError("invalid source address in STK"))
})
It("rejects overflowing timestamps", func() {
stk, err := encryptToken(source.aead, &sourceAddressToken{
sourceAddr: ip4,
timestamp: math.MaxUint64,
data: ip4,
timestamp: math.MaxUint64,
})
Expect(err).NotTo(HaveOccurred())
_, err = source.VerifyToken(ip4, stk)
_, _, err = source.DecodeToken(stk)
Expect(err).To(MatchError("invalid timestamp"))
})
It("returns the timestamp encoded in the token", func() {
timestamp := time.Now().Add(-time.Hour)
stk, err := encryptToken(source.aead, &sourceAddressToken{
sourceAddr: ip4,
timestamp: uint64(timestamp.Unix()),
data: ip4,
timestamp: uint64(timestamp.Unix()),
})
Expect(err).NotTo(HaveOccurred())
t, err := source.VerifyToken(ip4, stk)
_, t, err := source.DecodeToken(stk)
Expect(err).ToNot(HaveOccurred())
Expect(t).To(BeTemporally("~", timestamp, time.Second))
})

View File

@@ -127,7 +127,7 @@ func (mockStream) CloseRemote(offset protocol.ByteCount) { panic("not implemente
func (s mockStream) StreamID() protocol.StreamID { panic("not implemented") }
type mockStkSource struct {
verifyErr error
decodeErr error
stkTime time.Time
}
@@ -137,24 +137,18 @@ func (mockStkSource) NewToken(sourceAddr []byte) ([]byte, error) {
return append([]byte("token "), sourceAddr...), nil
}
func (s mockStkSource) VerifyToken(sourceAddr []byte, token []byte) (time.Time, error) {
if s.verifyErr != nil {
return time.Time{}, s.verifyErr
func (s mockStkSource) DecodeToken(data []byte) ([]byte, time.Time, error) {
if s.decodeErr != nil {
return nil, time.Time{}, s.decodeErr
}
split := bytes.Split(token, []byte(" "))
if len(split) != 2 {
return time.Time{}, errors.New("stk required")
if len(data) < 6 {
return nil, time.Time{}, errors.New("token too short")
}
if !bytes.Equal(split[0], []byte("token")) {
return time.Time{}, errors.New("no prefix match")
t := s.stkTime
if t.IsZero() {
t = time.Now()
}
if !bytes.Equal(split[1], sourceAddr) {
return time.Time{}, errors.New("ip wrong")
}
if !s.stkTime.IsZero() {
return s.stkTime, nil
}
return time.Now(), nil
return data[6:], t, nil
}
var _ = Describe("Server Crypto Setup", func() {
@@ -436,7 +430,7 @@ var _ = Describe("Server Crypto Setup", func() {
It("recognizes inchoate CHLOs with an invalid STK", func() {
testErr := errors.New("STK invalid")
cs.stkGenerator.stkSource.(*mockStkSource).verifyErr = testErr
cs.stkGenerator.stkSource.(*mockStkSource).decodeErr = testErr
Expect(cs.isInchoateCHLO(fullCHLO, cert)).To(BeTrue())
})
@@ -730,33 +724,42 @@ var _ = Describe("Server Crypto Setup", func() {
Context("STK verification and creation", func() {
It("requires STK", func() {
done, err := cs.handleMessage(bytes.Repeat([]byte{'a'}, protocol.ClientHelloMinimumSize), map[Tag][]byte{
TagSNI: []byte("foo"),
TagVER: versionTag,
})
done, err := cs.handleMessage(
bytes.Repeat([]byte{'a'}, protocol.ClientHelloMinimumSize),
map[Tag][]byte{
TagSNI: []byte("foo"),
TagVER: versionTag,
},
)
Expect(err).ToNot(HaveOccurred())
Expect(done).To(BeFalse())
Expect(err).To(BeNil())
Expect(stream.dataWritten.Bytes()).To(ContainSubstring(string(validSTK)))
})
It("works with proper STK", func() {
done, err := cs.handleMessage(bytes.Repeat([]byte{'a'}, protocol.ClientHelloMinimumSize), map[Tag][]byte{
TagSTK: validSTK,
TagSNI: []byte("foo"),
TagVER: versionTag,
})
done, err := cs.handleMessage(
bytes.Repeat([]byte{'a'}, protocol.ClientHelloMinimumSize),
map[Tag][]byte{
TagSTK: validSTK,
TagSNI: []byte("foo"),
TagVER: versionTag,
},
)
Expect(err).ToNot(HaveOccurred())
Expect(done).To(BeFalse())
Expect(err).To(BeNil())
})
It("errors if IP does not match", func() {
done, err := cs.handleMessage(bytes.Repeat([]byte{'a'}, protocol.ClientHelloMinimumSize), map[Tag][]byte{
TagSNI: []byte("foo"),
TagSTK: []byte("token \x04\x03\x03\x01"),
TagVER: versionTag,
})
done, err := cs.handleMessage(
bytes.Repeat([]byte{'a'}, protocol.ClientHelloMinimumSize),
map[Tag][]byte{
TagSNI: []byte("foo"),
TagSTK: []byte("token \x04\x03\x03\x01"),
TagVER: versionTag,
},
)
Expect(err).ToNot(HaveOccurred())
Expect(done).To(BeFalse())
Expect(err).To(BeNil())
Expect(stream.dataWritten.Bytes()).To(ContainSubstring(string(validSTK)))
})
})

View File

@@ -1,6 +1,8 @@
package handshake
import (
"crypto/subtle"
"errors"
"time"
"github.com/lucas-clemente/quic-go/crypto"
@@ -28,6 +30,13 @@ func (g *STKGenerator) NewToken(sourceAddr []byte) ([]byte, error) {
}
// VerifyToken verifies an STK token
func (g *STKGenerator) VerifyToken(sourceAddr []byte, token []byte) (time.Time, error) {
return g.stkSource.VerifyToken(sourceAddr, token)
func (g *STKGenerator) VerifyToken(sourceAddr []byte, data []byte) (time.Time, error) {
tokenAddr, timestamp, err := g.stkSource.DecodeToken(data)
if err != nil {
return time.Time{}, err
}
if subtle.ConstantTimeCompare(sourceAddr, tokenAddr) != 1 {
return time.Time{}, errors.New("invalid source address in STK")
}
return timestamp, nil
}

View File

@@ -24,7 +24,7 @@ var _ = Describe("STK Generator", func() {
Expect(token).ToNot(BeEmpty())
})
It("verifies an STK", func() {
It("accepts a valid STK", func() {
ip := net.IPv4(192, 168, 0, 1)
token, err := stkGen.NewToken(ip)
Expect(err).ToNot(HaveOccurred())
@@ -32,4 +32,13 @@ var _ = Describe("STK Generator", func() {
Expect(err).ToNot(HaveOccurred())
Expect(t).To(BeTemporally("~", time.Now(), time.Second))
})
It("rejects an STK for the wrong address", func() {
ip := net.ParseIP("1.2.3.4")
otherIP := net.ParseIP("4.3.2.1")
token, err := stkGen.NewToken(ip)
Expect(err).NotTo(HaveOccurred())
_, err = stkGen.VerifyToken(otherIP, token)
Expect(err).To(MatchError("invalid source address in STK"))
})
})