From 8e019214953cd0230d3ef913bd9c481323d4f821 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 13 May 2017 13:25:10 +0800 Subject: [PATCH] move comparison of the source address in the STK to the STKGenerator --- crypto/source_address_token.go | 40 ++++++--------- crypto/source_address_token_test.go | 39 ++++++--------- handshake/crypto_setup_server_test.go | 71 ++++++++++++++------------- handshake/stk_generator.go | 13 ++++- handshake/stk_generator_test.go | 11 ++++- 5 files changed, 89 insertions(+), 85 deletions(-) diff --git a/crypto/source_address_token.go b/crypto/source_address_token.go index e3e9ec87..5bc5c1e6 100644 --- a/crypto/source_address_token.go +++ b/crypto/source_address_token.go @@ -5,7 +5,6 @@ import ( "crypto/cipher" "crypto/rand" "crypto/sha256" - "crypto/subtle" "encoding/binary" "errors" "fmt" @@ -20,20 +19,20 @@ import ( type StkSource interface { // NewToken creates a new token for a given IP address NewToken(sourceAddress []byte) ([]byte, error) - // VerifyToken verifies if a token matches a given IP address - VerifyToken(sourceAddress []byte, data []byte) (time.Time, error) + // DecodeToken decodes a token and returns the source address and the timestamp + DecodeToken(data []byte) ([]byte, time.Time, error) } type sourceAddressToken struct { - sourceAddr []byte + data []byte // unix timestamp in seconds timestamp uint64 } func (t *sourceAddressToken) serialize() []byte { - res := make([]byte, 8+len(t.sourceAddr)) + res := make([]byte, 8+len(t.data)) binary.LittleEndian.PutUint64(res, t.timestamp) - copy(res[8:], t.sourceAddr) + copy(res[8:], t.data) return res } @@ -42,8 +41,8 @@ func parseToken(data []byte) (*sourceAddressToken, error) { return nil, fmt.Errorf("invalid STK length: %d", len(data)) } return &sourceAddressToken{ - sourceAddr: data[8:], - timestamp: binary.LittleEndian.Uint64(data), + data: data[8:], + timestamp: binary.LittleEndian.Uint64(data), }, nil } @@ -78,38 +77,31 @@ func NewStkSource() (StkSource, error) { return &stkSource{aead: aead}, nil } -func (s *stkSource) NewToken(sourceAddr []byte) ([]byte, error) { +func (s *stkSource) NewToken(data []byte) ([]byte, error) { return encryptToken(s.aead, &sourceAddressToken{ - sourceAddr: sourceAddr, - timestamp: uint64(time.Now().Unix()), + data: data, + timestamp: uint64(time.Now().Unix()), }) } -func (s *stkSource) VerifyToken(sourceAddr []byte, data []byte) (time.Time, error) { +func (s *stkSource) DecodeToken(data []byte) ([]byte, time.Time, error) { if len(data) < stkNonceSize { - return time.Time{}, errors.New("STK too short") + return nil, time.Time{}, errors.New("STK too short") } nonce := data[:stkNonceSize] - res, err := s.aead.Open(nil, nonce, data[stkNonceSize:], nil) if err != nil { - return time.Time{}, err + return nil, time.Time{}, err } token, err := parseToken(res) if err != nil { - return time.Time{}, err + return nil, time.Time{}, err } - - if subtle.ConstantTimeCompare(token.sourceAddr, sourceAddr) != 1 { - return time.Time{}, errors.New("invalid source address in STK") - } - if token.timestamp > math.MaxInt64 { - return time.Time{}, errors.New("invalid timestamp") + return nil, time.Time{}, errors.New("invalid timestamp") } - - return time.Unix(int64(token.timestamp), 0), nil + return token.data, time.Unix(int64(token.timestamp), 0), nil } func deriveKey(secret []byte) ([]byte, error) { diff --git a/crypto/source_address_token_test.go b/crypto/source_address_token_test.go index 748f421d..843b6ed5 100644 --- a/crypto/source_address_token_test.go +++ b/crypto/source_address_token_test.go @@ -17,7 +17,7 @@ var _ = Describe("Source Address Tokens", func() { Context("tokens", func() { It("serializes", func() { ip := []byte{127, 0, 0, 1} - token := &sourceAddressToken{sourceAddr: ip, timestamp: 0xdeadbeef} + token := &sourceAddressToken{data: ip, timestamp: 0xdeadbeef} Expect(token.serialize()).To(Equal([]byte{ 0xef, 0xbe, 0xad, 0xde, 0x00, 0x00, 0x00, 0x00, 127, 0, 0, 1, @@ -30,7 +30,7 @@ var _ = Describe("Source Address Tokens", func() { 127, 0, 0, 1, }) Expect(err).NotTo(HaveOccurred()) - Expect(token.sourceAddr).To(Equal([]byte{127, 0, 0, 1})) + Expect(token.data).To(Equal([]byte{127, 0, 0, 1})) Expect(token.timestamp).To(Equal(uint64(0xdeadbeef))) }) @@ -70,57 +70,48 @@ var _ = Describe("Source Address Tokens", func() { stk, err := source.NewToken(ip4) Expect(err).NotTo(HaveOccurred()) Expect(stk).ToNot(BeEmpty()) - _, err = source.VerifyToken(ip4, stk) + decodedIP, _, err := source.DecodeToken(stk) Expect(err).NotTo(HaveOccurred()) + Expect(decodedIP).To(BeEquivalentTo(ip4)) }) It("generates and verify ipv6 tokens", func() { stk, err := source.NewToken(ip6) Expect(err).NotTo(HaveOccurred()) Expect(stk).ToNot(BeEmpty()) - _, err = source.VerifyToken(ip6, stk) + decodedIP, _, err := source.DecodeToken(stk) Expect(err).NotTo(HaveOccurred()) + Expect(decodedIP).To(BeEquivalentTo(ip6)) }) It("rejects empty tokens", func() { - _, err := source.VerifyToken(ip4, nil) - Expect(err).To(HaveOccurred()) + _, _, err := source.DecodeToken([]byte{}) + Expect(err).To(MatchError("STK too short")) }) It("rejects invalid tokens", func() { - _, err := source.VerifyToken(ip4, []byte("foobar")) + _, _, err := source.DecodeToken([]byte("foobar")) Expect(err).To(HaveOccurred()) }) - It("rejects tokens with wrong IP addresses", func() { - otherIP := net.ParseIP("4.3.2.1") - stk, err := encryptToken(source.aead, &sourceAddressToken{ - sourceAddr: otherIP, - timestamp: uint64(time.Now().Unix()), - }) - Expect(err).NotTo(HaveOccurred()) - _, err = source.VerifyToken(ip4, stk) - Expect(err).To(MatchError("invalid source address in STK")) - }) - It("rejects overflowing timestamps", func() { stk, err := encryptToken(source.aead, &sourceAddressToken{ - sourceAddr: ip4, - timestamp: math.MaxUint64, + data: ip4, + timestamp: math.MaxUint64, }) Expect(err).NotTo(HaveOccurred()) - _, err = source.VerifyToken(ip4, stk) + _, _, err = source.DecodeToken(stk) Expect(err).To(MatchError("invalid timestamp")) }) It("returns the timestamp encoded in the token", func() { timestamp := time.Now().Add(-time.Hour) stk, err := encryptToken(source.aead, &sourceAddressToken{ - sourceAddr: ip4, - timestamp: uint64(timestamp.Unix()), + data: ip4, + timestamp: uint64(timestamp.Unix()), }) Expect(err).NotTo(HaveOccurred()) - t, err := source.VerifyToken(ip4, stk) + _, t, err := source.DecodeToken(stk) Expect(err).ToNot(HaveOccurred()) Expect(t).To(BeTemporally("~", timestamp, time.Second)) }) diff --git a/handshake/crypto_setup_server_test.go b/handshake/crypto_setup_server_test.go index c371fa78..df2ebd88 100644 --- a/handshake/crypto_setup_server_test.go +++ b/handshake/crypto_setup_server_test.go @@ -127,7 +127,7 @@ func (mockStream) CloseRemote(offset protocol.ByteCount) { panic("not implemente func (s mockStream) StreamID() protocol.StreamID { panic("not implemented") } type mockStkSource struct { - verifyErr error + decodeErr error stkTime time.Time } @@ -137,24 +137,18 @@ func (mockStkSource) NewToken(sourceAddr []byte) ([]byte, error) { return append([]byte("token "), sourceAddr...), nil } -func (s mockStkSource) VerifyToken(sourceAddr []byte, token []byte) (time.Time, error) { - if s.verifyErr != nil { - return time.Time{}, s.verifyErr +func (s mockStkSource) DecodeToken(data []byte) ([]byte, time.Time, error) { + if s.decodeErr != nil { + return nil, time.Time{}, s.decodeErr } - split := bytes.Split(token, []byte(" ")) - if len(split) != 2 { - return time.Time{}, errors.New("stk required") + if len(data) < 6 { + return nil, time.Time{}, errors.New("token too short") } - if !bytes.Equal(split[0], []byte("token")) { - return time.Time{}, errors.New("no prefix match") + t := s.stkTime + if t.IsZero() { + t = time.Now() } - if !bytes.Equal(split[1], sourceAddr) { - return time.Time{}, errors.New("ip wrong") - } - if !s.stkTime.IsZero() { - return s.stkTime, nil - } - return time.Now(), nil + return data[6:], t, nil } var _ = Describe("Server Crypto Setup", func() { @@ -436,7 +430,7 @@ var _ = Describe("Server Crypto Setup", func() { It("recognizes inchoate CHLOs with an invalid STK", func() { testErr := errors.New("STK invalid") - cs.stkGenerator.stkSource.(*mockStkSource).verifyErr = testErr + cs.stkGenerator.stkSource.(*mockStkSource).decodeErr = testErr Expect(cs.isInchoateCHLO(fullCHLO, cert)).To(BeTrue()) }) @@ -730,33 +724,42 @@ var _ = Describe("Server Crypto Setup", func() { Context("STK verification and creation", func() { It("requires STK", func() { - done, err := cs.handleMessage(bytes.Repeat([]byte{'a'}, protocol.ClientHelloMinimumSize), map[Tag][]byte{ - TagSNI: []byte("foo"), - TagVER: versionTag, - }) + done, err := cs.handleMessage( + bytes.Repeat([]byte{'a'}, protocol.ClientHelloMinimumSize), + map[Tag][]byte{ + TagSNI: []byte("foo"), + TagVER: versionTag, + }, + ) + Expect(err).ToNot(HaveOccurred()) Expect(done).To(BeFalse()) - Expect(err).To(BeNil()) Expect(stream.dataWritten.Bytes()).To(ContainSubstring(string(validSTK))) }) It("works with proper STK", func() { - done, err := cs.handleMessage(bytes.Repeat([]byte{'a'}, protocol.ClientHelloMinimumSize), map[Tag][]byte{ - TagSTK: validSTK, - TagSNI: []byte("foo"), - TagVER: versionTag, - }) + done, err := cs.handleMessage( + bytes.Repeat([]byte{'a'}, protocol.ClientHelloMinimumSize), + map[Tag][]byte{ + TagSTK: validSTK, + TagSNI: []byte("foo"), + TagVER: versionTag, + }, + ) + Expect(err).ToNot(HaveOccurred()) Expect(done).To(BeFalse()) - Expect(err).To(BeNil()) }) It("errors if IP does not match", func() { - done, err := cs.handleMessage(bytes.Repeat([]byte{'a'}, protocol.ClientHelloMinimumSize), map[Tag][]byte{ - TagSNI: []byte("foo"), - TagSTK: []byte("token \x04\x03\x03\x01"), - TagVER: versionTag, - }) + done, err := cs.handleMessage( + bytes.Repeat([]byte{'a'}, protocol.ClientHelloMinimumSize), + map[Tag][]byte{ + TagSNI: []byte("foo"), + TagSTK: []byte("token \x04\x03\x03\x01"), + TagVER: versionTag, + }, + ) + Expect(err).ToNot(HaveOccurred()) Expect(done).To(BeFalse()) - Expect(err).To(BeNil()) Expect(stream.dataWritten.Bytes()).To(ContainSubstring(string(validSTK))) }) }) diff --git a/handshake/stk_generator.go b/handshake/stk_generator.go index eca677b6..1b33496b 100644 --- a/handshake/stk_generator.go +++ b/handshake/stk_generator.go @@ -1,6 +1,8 @@ package handshake import ( + "crypto/subtle" + "errors" "time" "github.com/lucas-clemente/quic-go/crypto" @@ -28,6 +30,13 @@ func (g *STKGenerator) NewToken(sourceAddr []byte) ([]byte, error) { } // VerifyToken verifies an STK token -func (g *STKGenerator) VerifyToken(sourceAddr []byte, token []byte) (time.Time, error) { - return g.stkSource.VerifyToken(sourceAddr, token) +func (g *STKGenerator) VerifyToken(sourceAddr []byte, data []byte) (time.Time, error) { + tokenAddr, timestamp, err := g.stkSource.DecodeToken(data) + if err != nil { + return time.Time{}, err + } + if subtle.ConstantTimeCompare(sourceAddr, tokenAddr) != 1 { + return time.Time{}, errors.New("invalid source address in STK") + } + return timestamp, nil } diff --git a/handshake/stk_generator_test.go b/handshake/stk_generator_test.go index cad4b243..bcfe82dc 100644 --- a/handshake/stk_generator_test.go +++ b/handshake/stk_generator_test.go @@ -24,7 +24,7 @@ var _ = Describe("STK Generator", func() { Expect(token).ToNot(BeEmpty()) }) - It("verifies an STK", func() { + It("accepts a valid STK", func() { ip := net.IPv4(192, 168, 0, 1) token, err := stkGen.NewToken(ip) Expect(err).ToNot(HaveOccurred()) @@ -32,4 +32,13 @@ var _ = Describe("STK Generator", func() { Expect(err).ToNot(HaveOccurred()) Expect(t).To(BeTemporally("~", time.Now(), time.Second)) }) + + It("rejects an STK for the wrong address", func() { + ip := net.ParseIP("1.2.3.4") + otherIP := net.ParseIP("4.3.2.1") + token, err := stkGen.NewToken(ip) + Expect(err).NotTo(HaveOccurred()) + _, err = stkGen.VerifyToken(otherIP, token) + Expect(err).To(MatchError("invalid source address in STK")) + }) })