forked from quic-go/quic-go
move comparison of the source address in the STK to the STKGenerator
This commit is contained in:
@@ -5,7 +5,6 @@ import (
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"crypto/subtle"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -20,20 +19,20 @@ import (
|
||||
type StkSource interface {
|
||||
// NewToken creates a new token for a given IP address
|
||||
NewToken(sourceAddress []byte) ([]byte, error)
|
||||
// VerifyToken verifies if a token matches a given IP address
|
||||
VerifyToken(sourceAddress []byte, data []byte) (time.Time, error)
|
||||
// DecodeToken decodes a token and returns the source address and the timestamp
|
||||
DecodeToken(data []byte) ([]byte, time.Time, error)
|
||||
}
|
||||
|
||||
type sourceAddressToken struct {
|
||||
sourceAddr []byte
|
||||
data []byte
|
||||
// unix timestamp in seconds
|
||||
timestamp uint64
|
||||
}
|
||||
|
||||
func (t *sourceAddressToken) serialize() []byte {
|
||||
res := make([]byte, 8+len(t.sourceAddr))
|
||||
res := make([]byte, 8+len(t.data))
|
||||
binary.LittleEndian.PutUint64(res, t.timestamp)
|
||||
copy(res[8:], t.sourceAddr)
|
||||
copy(res[8:], t.data)
|
||||
return res
|
||||
}
|
||||
|
||||
@@ -42,8 +41,8 @@ func parseToken(data []byte) (*sourceAddressToken, error) {
|
||||
return nil, fmt.Errorf("invalid STK length: %d", len(data))
|
||||
}
|
||||
return &sourceAddressToken{
|
||||
sourceAddr: data[8:],
|
||||
timestamp: binary.LittleEndian.Uint64(data),
|
||||
data: data[8:],
|
||||
timestamp: binary.LittleEndian.Uint64(data),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -78,38 +77,31 @@ func NewStkSource() (StkSource, error) {
|
||||
return &stkSource{aead: aead}, nil
|
||||
}
|
||||
|
||||
func (s *stkSource) NewToken(sourceAddr []byte) ([]byte, error) {
|
||||
func (s *stkSource) NewToken(data []byte) ([]byte, error) {
|
||||
return encryptToken(s.aead, &sourceAddressToken{
|
||||
sourceAddr: sourceAddr,
|
||||
timestamp: uint64(time.Now().Unix()),
|
||||
data: data,
|
||||
timestamp: uint64(time.Now().Unix()),
|
||||
})
|
||||
}
|
||||
|
||||
func (s *stkSource) VerifyToken(sourceAddr []byte, data []byte) (time.Time, error) {
|
||||
func (s *stkSource) DecodeToken(data []byte) ([]byte, time.Time, error) {
|
||||
if len(data) < stkNonceSize {
|
||||
return time.Time{}, errors.New("STK too short")
|
||||
return nil, time.Time{}, errors.New("STK too short")
|
||||
}
|
||||
nonce := data[:stkNonceSize]
|
||||
|
||||
res, err := s.aead.Open(nil, nonce, data[stkNonceSize:], nil)
|
||||
if err != nil {
|
||||
return time.Time{}, err
|
||||
return nil, time.Time{}, err
|
||||
}
|
||||
|
||||
token, err := parseToken(res)
|
||||
if err != nil {
|
||||
return time.Time{}, err
|
||||
return nil, time.Time{}, err
|
||||
}
|
||||
|
||||
if subtle.ConstantTimeCompare(token.sourceAddr, sourceAddr) != 1 {
|
||||
return time.Time{}, errors.New("invalid source address in STK")
|
||||
}
|
||||
|
||||
if token.timestamp > math.MaxInt64 {
|
||||
return time.Time{}, errors.New("invalid timestamp")
|
||||
return nil, time.Time{}, errors.New("invalid timestamp")
|
||||
}
|
||||
|
||||
return time.Unix(int64(token.timestamp), 0), nil
|
||||
return token.data, time.Unix(int64(token.timestamp), 0), nil
|
||||
}
|
||||
|
||||
func deriveKey(secret []byte) ([]byte, error) {
|
||||
|
||||
@@ -17,7 +17,7 @@ var _ = Describe("Source Address Tokens", func() {
|
||||
Context("tokens", func() {
|
||||
It("serializes", func() {
|
||||
ip := []byte{127, 0, 0, 1}
|
||||
token := &sourceAddressToken{sourceAddr: ip, timestamp: 0xdeadbeef}
|
||||
token := &sourceAddressToken{data: ip, timestamp: 0xdeadbeef}
|
||||
Expect(token.serialize()).To(Equal([]byte{
|
||||
0xef, 0xbe, 0xad, 0xde, 0x00, 0x00, 0x00, 0x00,
|
||||
127, 0, 0, 1,
|
||||
@@ -30,7 +30,7 @@ var _ = Describe("Source Address Tokens", func() {
|
||||
127, 0, 0, 1,
|
||||
})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(token.sourceAddr).To(Equal([]byte{127, 0, 0, 1}))
|
||||
Expect(token.data).To(Equal([]byte{127, 0, 0, 1}))
|
||||
Expect(token.timestamp).To(Equal(uint64(0xdeadbeef)))
|
||||
})
|
||||
|
||||
@@ -70,57 +70,48 @@ var _ = Describe("Source Address Tokens", func() {
|
||||
stk, err := source.NewToken(ip4)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(stk).ToNot(BeEmpty())
|
||||
_, err = source.VerifyToken(ip4, stk)
|
||||
decodedIP, _, err := source.DecodeToken(stk)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(decodedIP).To(BeEquivalentTo(ip4))
|
||||
})
|
||||
|
||||
It("generates and verify ipv6 tokens", func() {
|
||||
stk, err := source.NewToken(ip6)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(stk).ToNot(BeEmpty())
|
||||
_, err = source.VerifyToken(ip6, stk)
|
||||
decodedIP, _, err := source.DecodeToken(stk)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(decodedIP).To(BeEquivalentTo(ip6))
|
||||
})
|
||||
|
||||
It("rejects empty tokens", func() {
|
||||
_, err := source.VerifyToken(ip4, nil)
|
||||
Expect(err).To(HaveOccurred())
|
||||
_, _, err := source.DecodeToken([]byte{})
|
||||
Expect(err).To(MatchError("STK too short"))
|
||||
})
|
||||
|
||||
It("rejects invalid tokens", func() {
|
||||
_, err := source.VerifyToken(ip4, []byte("foobar"))
|
||||
_, _, err := source.DecodeToken([]byte("foobar"))
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
|
||||
It("rejects tokens with wrong IP addresses", func() {
|
||||
otherIP := net.ParseIP("4.3.2.1")
|
||||
stk, err := encryptToken(source.aead, &sourceAddressToken{
|
||||
sourceAddr: otherIP,
|
||||
timestamp: uint64(time.Now().Unix()),
|
||||
})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
_, err = source.VerifyToken(ip4, stk)
|
||||
Expect(err).To(MatchError("invalid source address in STK"))
|
||||
})
|
||||
|
||||
It("rejects overflowing timestamps", func() {
|
||||
stk, err := encryptToken(source.aead, &sourceAddressToken{
|
||||
sourceAddr: ip4,
|
||||
timestamp: math.MaxUint64,
|
||||
data: ip4,
|
||||
timestamp: math.MaxUint64,
|
||||
})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
_, err = source.VerifyToken(ip4, stk)
|
||||
_, _, err = source.DecodeToken(stk)
|
||||
Expect(err).To(MatchError("invalid timestamp"))
|
||||
})
|
||||
|
||||
It("returns the timestamp encoded in the token", func() {
|
||||
timestamp := time.Now().Add(-time.Hour)
|
||||
stk, err := encryptToken(source.aead, &sourceAddressToken{
|
||||
sourceAddr: ip4,
|
||||
timestamp: uint64(timestamp.Unix()),
|
||||
data: ip4,
|
||||
timestamp: uint64(timestamp.Unix()),
|
||||
})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
t, err := source.VerifyToken(ip4, stk)
|
||||
_, t, err := source.DecodeToken(stk)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(t).To(BeTemporally("~", timestamp, time.Second))
|
||||
})
|
||||
|
||||
@@ -127,7 +127,7 @@ func (mockStream) CloseRemote(offset protocol.ByteCount) { panic("not implemente
|
||||
func (s mockStream) StreamID() protocol.StreamID { panic("not implemented") }
|
||||
|
||||
type mockStkSource struct {
|
||||
verifyErr error
|
||||
decodeErr error
|
||||
stkTime time.Time
|
||||
}
|
||||
|
||||
@@ -137,24 +137,18 @@ func (mockStkSource) NewToken(sourceAddr []byte) ([]byte, error) {
|
||||
return append([]byte("token "), sourceAddr...), nil
|
||||
}
|
||||
|
||||
func (s mockStkSource) VerifyToken(sourceAddr []byte, token []byte) (time.Time, error) {
|
||||
if s.verifyErr != nil {
|
||||
return time.Time{}, s.verifyErr
|
||||
func (s mockStkSource) DecodeToken(data []byte) ([]byte, time.Time, error) {
|
||||
if s.decodeErr != nil {
|
||||
return nil, time.Time{}, s.decodeErr
|
||||
}
|
||||
split := bytes.Split(token, []byte(" "))
|
||||
if len(split) != 2 {
|
||||
return time.Time{}, errors.New("stk required")
|
||||
if len(data) < 6 {
|
||||
return nil, time.Time{}, errors.New("token too short")
|
||||
}
|
||||
if !bytes.Equal(split[0], []byte("token")) {
|
||||
return time.Time{}, errors.New("no prefix match")
|
||||
t := s.stkTime
|
||||
if t.IsZero() {
|
||||
t = time.Now()
|
||||
}
|
||||
if !bytes.Equal(split[1], sourceAddr) {
|
||||
return time.Time{}, errors.New("ip wrong")
|
||||
}
|
||||
if !s.stkTime.IsZero() {
|
||||
return s.stkTime, nil
|
||||
}
|
||||
return time.Now(), nil
|
||||
return data[6:], t, nil
|
||||
}
|
||||
|
||||
var _ = Describe("Server Crypto Setup", func() {
|
||||
@@ -436,7 +430,7 @@ var _ = Describe("Server Crypto Setup", func() {
|
||||
|
||||
It("recognizes inchoate CHLOs with an invalid STK", func() {
|
||||
testErr := errors.New("STK invalid")
|
||||
cs.stkGenerator.stkSource.(*mockStkSource).verifyErr = testErr
|
||||
cs.stkGenerator.stkSource.(*mockStkSource).decodeErr = testErr
|
||||
Expect(cs.isInchoateCHLO(fullCHLO, cert)).To(BeTrue())
|
||||
})
|
||||
|
||||
@@ -730,33 +724,42 @@ var _ = Describe("Server Crypto Setup", func() {
|
||||
|
||||
Context("STK verification and creation", func() {
|
||||
It("requires STK", func() {
|
||||
done, err := cs.handleMessage(bytes.Repeat([]byte{'a'}, protocol.ClientHelloMinimumSize), map[Tag][]byte{
|
||||
TagSNI: []byte("foo"),
|
||||
TagVER: versionTag,
|
||||
})
|
||||
done, err := cs.handleMessage(
|
||||
bytes.Repeat([]byte{'a'}, protocol.ClientHelloMinimumSize),
|
||||
map[Tag][]byte{
|
||||
TagSNI: []byte("foo"),
|
||||
TagVER: versionTag,
|
||||
},
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(done).To(BeFalse())
|
||||
Expect(err).To(BeNil())
|
||||
Expect(stream.dataWritten.Bytes()).To(ContainSubstring(string(validSTK)))
|
||||
})
|
||||
|
||||
It("works with proper STK", func() {
|
||||
done, err := cs.handleMessage(bytes.Repeat([]byte{'a'}, protocol.ClientHelloMinimumSize), map[Tag][]byte{
|
||||
TagSTK: validSTK,
|
||||
TagSNI: []byte("foo"),
|
||||
TagVER: versionTag,
|
||||
})
|
||||
done, err := cs.handleMessage(
|
||||
bytes.Repeat([]byte{'a'}, protocol.ClientHelloMinimumSize),
|
||||
map[Tag][]byte{
|
||||
TagSTK: validSTK,
|
||||
TagSNI: []byte("foo"),
|
||||
TagVER: versionTag,
|
||||
},
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(done).To(BeFalse())
|
||||
Expect(err).To(BeNil())
|
||||
})
|
||||
|
||||
It("errors if IP does not match", func() {
|
||||
done, err := cs.handleMessage(bytes.Repeat([]byte{'a'}, protocol.ClientHelloMinimumSize), map[Tag][]byte{
|
||||
TagSNI: []byte("foo"),
|
||||
TagSTK: []byte("token \x04\x03\x03\x01"),
|
||||
TagVER: versionTag,
|
||||
})
|
||||
done, err := cs.handleMessage(
|
||||
bytes.Repeat([]byte{'a'}, protocol.ClientHelloMinimumSize),
|
||||
map[Tag][]byte{
|
||||
TagSNI: []byte("foo"),
|
||||
TagSTK: []byte("token \x04\x03\x03\x01"),
|
||||
TagVER: versionTag,
|
||||
},
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(done).To(BeFalse())
|
||||
Expect(err).To(BeNil())
|
||||
Expect(stream.dataWritten.Bytes()).To(ContainSubstring(string(validSTK)))
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"crypto/subtle"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/crypto"
|
||||
@@ -28,6 +30,13 @@ func (g *STKGenerator) NewToken(sourceAddr []byte) ([]byte, error) {
|
||||
}
|
||||
|
||||
// VerifyToken verifies an STK token
|
||||
func (g *STKGenerator) VerifyToken(sourceAddr []byte, token []byte) (time.Time, error) {
|
||||
return g.stkSource.VerifyToken(sourceAddr, token)
|
||||
func (g *STKGenerator) VerifyToken(sourceAddr []byte, data []byte) (time.Time, error) {
|
||||
tokenAddr, timestamp, err := g.stkSource.DecodeToken(data)
|
||||
if err != nil {
|
||||
return time.Time{}, err
|
||||
}
|
||||
if subtle.ConstantTimeCompare(sourceAddr, tokenAddr) != 1 {
|
||||
return time.Time{}, errors.New("invalid source address in STK")
|
||||
}
|
||||
return timestamp, nil
|
||||
}
|
||||
|
||||
@@ -24,7 +24,7 @@ var _ = Describe("STK Generator", func() {
|
||||
Expect(token).ToNot(BeEmpty())
|
||||
})
|
||||
|
||||
It("verifies an STK", func() {
|
||||
It("accepts a valid STK", func() {
|
||||
ip := net.IPv4(192, 168, 0, 1)
|
||||
token, err := stkGen.NewToken(ip)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
@@ -32,4 +32,13 @@ var _ = Describe("STK Generator", func() {
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(t).To(BeTemporally("~", time.Now(), time.Second))
|
||||
})
|
||||
|
||||
It("rejects an STK for the wrong address", func() {
|
||||
ip := net.ParseIP("1.2.3.4")
|
||||
otherIP := net.ParseIP("4.3.2.1")
|
||||
token, err := stkGen.NewToken(ip)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
_, err = stkGen.VerifyToken(otherIP, token)
|
||||
Expect(err).To(MatchError("invalid source address in STK"))
|
||||
})
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user