diff --git a/crypto/source_address_token.go b/crypto/source_address_token.go index df220221..3dcb26a7 100644 --- a/crypto/source_address_token.go +++ b/crypto/source_address_token.go @@ -5,45 +5,18 @@ import ( "crypto/cipher" "crypto/rand" "crypto/sha256" - "encoding/binary" - "errors" "fmt" "io" - "math" - "time" "golang.org/x/crypto/hkdf" ) // StkSource is used to create and verify source address tokens type StkSource interface { - // NewToken creates a new token for a given IP address - NewToken(sourceAddress []byte) ([]byte, error) - // DecodeToken decodes a token and returns the source address and the timestamp - DecodeToken(data []byte) ([]byte, time.Time, error) -} - -type sourceAddressToken struct { - data []byte - // unix timestamp in seconds - timestamp uint64 -} - -func (t *sourceAddressToken) serialize() []byte { - res := make([]byte, 8+len(t.data)) - binary.LittleEndian.PutUint64(res, t.timestamp) - copy(res[8:], t.data) - return res -} - -func parseToken(data []byte) (*sourceAddressToken, error) { - if len(data) < 8 { - return nil, fmt.Errorf("STK too short: %d", len(data)) - } - return &sourceAddressToken{ - data: data[8:], - timestamp: binary.LittleEndian.Uint64(data), - }, nil + // NewToken creates a new token + NewToken([]byte) ([]byte, error) + // DecodeToken decodes a token + DecodeToken([]byte) ([]byte, error) } type stkSource struct { @@ -78,30 +51,19 @@ func NewStkSource() (StkSource, error) { } func (s *stkSource) NewToken(data []byte) ([]byte, error) { - return encryptToken(s.aead, &sourceAddressToken{ - data: data, - timestamp: uint64(time.Now().Unix()), - }) + nonce := make([]byte, stkNonceSize) + if _, err := rand.Read(nonce); err != nil { + return nil, err + } + return s.aead.Seal(nonce, nonce, data, nil), nil } -func (s *stkSource) DecodeToken(data []byte) ([]byte, time.Time, error) { - if len(data) < stkNonceSize { - return nil, time.Time{}, errors.New("STK too short") +func (s *stkSource) DecodeToken(p []byte) ([]byte, error) { + if len(p) < stkNonceSize { + return nil, fmt.Errorf("STK too short: %d", len(p)) } - nonce := data[:stkNonceSize] - res, err := s.aead.Open(nil, nonce, data[stkNonceSize:], nil) - if err != nil { - return nil, time.Time{}, err - } - - token, err := parseToken(res) - if err != nil { - return nil, time.Time{}, err - } - if token.timestamp > math.MaxInt64 { - return nil, time.Time{}, errors.New("invalid timestamp") - } - return token.data, time.Unix(int64(token.timestamp), 0), nil + nonce := p[:stkNonceSize] + return s.aead.Open(nil, nonce, p[stkNonceSize:], nil) } func deriveKey(secret []byte) ([]byte, error) { @@ -112,11 +74,3 @@ func deriveKey(secret []byte) ([]byte, error) { } return key, nil } - -func encryptToken(aead cipher.AEAD, token *sourceAddressToken) ([]byte, error) { - nonce := make([]byte, stkNonceSize) - if _, err := rand.Read(nonce); err != nil { - return nil, err - } - return aead.Seal(nonce, nonce, token.serialize(), nil), nil -} diff --git a/crypto/source_address_token_test.go b/crypto/source_address_token_test.go index e0101925..d25a2ba9 100644 --- a/crypto/source_address_token_test.go +++ b/crypto/source_address_token_test.go @@ -1,10 +1,6 @@ package crypto import ( - "math" - "net" - "time" - . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) @@ -15,105 +11,31 @@ var _ = Describe("Source Address Tokens", func() { }) Context("tokens", func() { - It("serializes", func() { - ip := []byte{127, 0, 0, 1} - token := &sourceAddressToken{data: ip, timestamp: 0xdeadbeef} - Expect(token.serialize()).To(Equal([]byte{ - 0xef, 0xbe, 0xad, 0xde, 0x00, 0x00, 0x00, 0x00, - 127, 0, 0, 1, - })) - }) - - It("reads", func() { - token, err := parseToken([]byte{ - 0xef, 0xbe, 0xad, 0xde, 0x00, 0x00, 0x00, 0x00, - 127, 0, 0, 1, - }) - Expect(err).NotTo(HaveOccurred()) - Expect(token.data).To(Equal([]byte{127, 0, 0, 1})) - Expect(token.timestamp).To(Equal(uint64(0xdeadbeef))) - }) - - It("rejects tokens of wrong size", func() { - _, err := parseToken(nil) - Expect(err).To(MatchError("STK too short: 0")) - }) - }) - - Context("source", func() { - var ( - source *stkSource - ip4 net.IP - ip6 net.IP - ) + var source *stkSource BeforeEach(func() { - var err error - - ip4 = net.ParseIP("1.2.3.4") - Expect(ip4).NotTo(BeEmpty()) - ip6 = net.ParseIP("2001:0db8:0000:0000:0000:ff00:0042:8329") - Expect(ip6).NotTo(BeEmpty()) - sourceI, err := NewStkSource() source = sourceI.(*stkSource) Expect(err).NotTo(HaveOccurred()) }) - It("generates new tokens", func() { - token, err := source.NewToken(ip4) - Expect(err).NotTo(HaveOccurred()) - Expect(token).ToNot(BeEmpty()) - }) - - It("generates and verifies ipv4 tokens", func() { - stk, err := source.NewToken(ip4) - Expect(err).NotTo(HaveOccurred()) - Expect(stk).ToNot(BeEmpty()) - decodedIP, _, err := source.DecodeToken(stk) - Expect(err).NotTo(HaveOccurred()) - Expect(decodedIP).To(BeEquivalentTo(ip4)) - }) - - It("generates and verify ipv6 tokens", func() { - stk, err := source.NewToken(ip6) - Expect(err).NotTo(HaveOccurred()) - Expect(stk).ToNot(BeEmpty()) - decodedIP, _, err := source.DecodeToken(stk) - Expect(err).NotTo(HaveOccurred()) - Expect(decodedIP).To(BeEquivalentTo(ip6)) - }) - - It("rejects empty tokens", func() { - _, _, err := source.DecodeToken([]byte{}) - Expect(err).To(MatchError("STK too short")) + It("serializes", func() { + token, err := source.NewToken([]byte("foobar")) + Expect(err).ToNot(HaveOccurred()) + data, err := source.DecodeToken(token) + Expect(err).ToNot(HaveOccurred()) + Expect(data).To(Equal([]byte("foobar"))) }) It("rejects invalid tokens", func() { - _, _, err := source.DecodeToken([]byte("foobar")) + _, err := source.DecodeToken([]byte("invalid source address token")) Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("message authentication failed")) }) - It("rejects overflowing timestamps", func() { - stk, err := encryptToken(source.aead, &sourceAddressToken{ - data: ip4, - timestamp: math.MaxUint64, - }) - Expect(err).NotTo(HaveOccurred()) - _, _, err = source.DecodeToken(stk) - Expect(err).To(MatchError("invalid timestamp")) - }) - - It("returns the timestamp encoded in the token", func() { - timestamp := time.Now().Add(-time.Hour) - stk, err := encryptToken(source.aead, &sourceAddressToken{ - data: ip4, - timestamp: uint64(timestamp.Unix()), - }) - Expect(err).NotTo(HaveOccurred()) - _, t, err := source.DecodeToken(stk) - Expect(err).ToNot(HaveOccurred()) - Expect(t).To(BeTemporally("~", timestamp, time.Second)) + It("rejects tokens of wrong size", func() { + _, err := source.DecodeToken(nil) + Expect(err).To(MatchError("STK too short: 0")) }) }) }) diff --git a/handshake/crypto_setup_server_test.go b/handshake/crypto_setup_server_test.go index 764f262b..ea5b04dd 100644 --- a/handshake/crypto_setup_server_test.go +++ b/handshake/crypto_setup_server_test.go @@ -5,7 +5,6 @@ import ( "encoding/binary" "errors" "net" - "time" "github.com/lucas-clemente/quic-go/crypto" "github.com/lucas-clemente/quic-go/protocol" @@ -127,8 +126,8 @@ func (mockStream) CloseRemote(offset protocol.ByteCount) { panic("not implemente func (s mockStream) StreamID() protocol.StreamID { panic("not implemented") } type mockStkSource struct { + data []byte decodeErr error - stkTime time.Time } var _ crypto.StkSource = &mockStkSource{} @@ -137,18 +136,14 @@ func (mockStkSource) NewToken(sourceAddr []byte) ([]byte, error) { return append([]byte("token "), sourceAddr...), nil } -func (s mockStkSource) DecodeToken(data []byte) ([]byte, time.Time, error) { +func (s mockStkSource) DecodeToken(data []byte) ([]byte, error) { if s.decodeErr != nil { - return nil, time.Time{}, s.decodeErr + return nil, s.decodeErr } if len(data) < 6 { - return nil, time.Time{}, errors.New("token too short") + return nil, errors.New("token too short") } - t := s.stkTime - if t.IsZero() { - t = time.Now() - } - return data[6:], t, nil + return data[6:], nil } var _ = Describe("Server Crypto Setup", func() { diff --git a/handshake/stk_generator.go b/handshake/stk_generator.go index cf803082..497ef213 100644 --- a/handshake/stk_generator.go +++ b/handshake/stk_generator.go @@ -1,6 +1,8 @@ package handshake import ( + "encoding/asn1" + "fmt" "net" "time" @@ -18,6 +20,12 @@ type STK struct { SentTime time.Time } +// token is the struct that is used for ASN1 serialization and deserialization +type token struct { + Data []byte + Timestamp int64 +} + // An STKGenerator generates STKs type STKGenerator struct { stkSource crypto.StkSource @@ -36,22 +44,38 @@ func NewSTKGenerator() (*STKGenerator, error) { // NewToken generates a new STK token for a given source address func (g *STKGenerator) NewToken(raddr net.Addr) ([]byte, error) { - return g.stkSource.NewToken(encodeRemoteAddr(raddr)) -} - -// DecodeToken decodes an STK token -func (g *STKGenerator) DecodeToken(data []byte) (*STK, error) { - // if the client didn't send any STK, DecodeToken will be called with a nil-slice - if len(data) == 0 { - return nil, nil - } - remote, timestamp, err := g.stkSource.DecodeToken(data) + data, err := asn1.Marshal(token{ + Data: encodeRemoteAddr(raddr), + Timestamp: time.Now().Unix(), + }) if err != nil { return nil, err } + return g.stkSource.NewToken(data) +} + +// DecodeToken decodes an STK token +func (g *STKGenerator) DecodeToken(encrypted []byte) (*STK, error) { + // if the client didn't send any STK, DecodeToken will be called with a nil-slice + if len(encrypted) == 0 { + return nil, nil + } + + data, err := g.stkSource.DecodeToken(encrypted) + if err != nil { + return nil, err + } + t := &token{} + rest, err := asn1.Unmarshal(data, t) + if err != nil { + return nil, err + } + if len(rest) != 0 { + return nil, fmt.Errorf("rest when unpacking token: %d", len(rest)) + } return &STK{ - RemoteAddr: decodeRemoteAddr(remote), - SentTime: timestamp, + RemoteAddr: decodeRemoteAddr(t.Data), + SentTime: time.Unix(t.Timestamp, 0), }, nil } diff --git a/handshake/stk_generator_test.go b/handshake/stk_generator_test.go index 99c26502..c3dd3aa7 100644 --- a/handshake/stk_generator_test.go +++ b/handshake/stk_generator_test.go @@ -1,6 +1,7 @@ package handshake import ( + "encoding/asn1" "net" "time" @@ -40,6 +41,28 @@ var _ = Describe("STK Generator", func() { Expect(stk.SentTime).To(BeTemporally("~", time.Now(), time.Second)) }) + It("rejects invalid tokens", func() { + _, err := stkGen.DecodeToken([]byte("invalid token")) + Expect(err).To(HaveOccurred()) + }) + + It("rejects tokens that cannot be decoded", func() { + token, err := stkGen.stkSource.NewToken([]byte("foobar")) + Expect(err).ToNot(HaveOccurred()) + _, err = stkGen.DecodeToken(token) + Expect(err).To(HaveOccurred()) + }) + + It("rejects tokens that can be decoded, but have additional payload", func() { + t, err := asn1.Marshal(token{Data: []byte("foobar")}) + Expect(err).ToNot(HaveOccurred()) + t = append(t, []byte("rest")...) + enc, err := stkGen.stkSource.NewToken(t) + Expect(err).ToNot(HaveOccurred()) + _, err = stkGen.DecodeToken(enc) + Expect(err).To(MatchError("rest when unpacking token: 4")) + }) + It("works with an IPv6 addresses ", func() { addresses := []string{ "2001:db8::68",