From 14fae7b6d3b5d7c82308ba4ef243f9c3f07860c5 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 11 Sep 2017 18:10:58 +0200 Subject: [PATCH] rename the STKGenerator to CookieGenerator --- interface.go | 2 +- internal/handshake/cookie_generator.go | 101 +++++++++++++++++ ...rator_test.go => cookie_generator_test.go} | 66 +++++------ internal/handshake/crypto_setup_server.go | 8 +- .../handshake/crypto_setup_server_test.go | 14 +-- internal/handshake/stk_generator.go | 106 ------------------ session_test.go | 8 +- 7 files changed, 150 insertions(+), 155 deletions(-) create mode 100644 internal/handshake/cookie_generator.go rename internal/handshake/{stk_generator_test.go => cookie_generator_test.go} (51%) delete mode 100644 internal/handshake/stk_generator.go diff --git a/interface.go b/interface.go index 2903bc73..34226422 100644 --- a/interface.go +++ b/interface.go @@ -17,7 +17,7 @@ type StreamID = protocol.StreamID type VersionNumber = protocol.VersionNumber // An STK can be used to verify the ownership of the client address. -type STK = handshake.STK +type STK = handshake.Cookie // Stream is the interface implemented by QUIC streams type Stream interface { diff --git a/internal/handshake/cookie_generator.go b/internal/handshake/cookie_generator.go new file mode 100644 index 00000000..10281fa6 --- /dev/null +++ b/internal/handshake/cookie_generator.go @@ -0,0 +1,101 @@ +package handshake + +import ( + "encoding/asn1" + "fmt" + "net" + "time" + + "github.com/lucas-clemente/quic-go/internal/crypto" +) + +const ( + cookiePrefixIP byte = iota + cookiePrefixString +) + +// A Cookie is derived from the client address and can be used to verify the ownership of this address. +type Cookie struct { + RemoteAddr string + // The time that the STK was issued (resolution 1 second) + SentTime time.Time +} + +// token is the struct that is used for ASN1 serialization and deserialization +type token struct { + Data []byte + Timestamp int64 +} + +// A CookieGenerator generates Cookies +type CookieGenerator struct { + cookieSource crypto.StkSource +} + +// NewCookieGenerator initializes a new CookieGenerator +func NewCookieGenerator() (*CookieGenerator, error) { + stkSource, err := crypto.NewStkSource() + if err != nil { + return nil, err + } + return &CookieGenerator{ + cookieSource: stkSource, + }, nil +} + +// NewToken generates a new Cookie for a given source address +func (g *CookieGenerator) NewToken(raddr net.Addr) ([]byte, error) { + data, err := asn1.Marshal(token{ + Data: encodeRemoteAddr(raddr), + Timestamp: time.Now().Unix(), + }) + if err != nil { + return nil, err + } + return g.cookieSource.NewToken(data) +} + +// DecodeToken decodes a Cookie +func (g *CookieGenerator) DecodeToken(encrypted []byte) (*Cookie, error) { + // if the client didn't send any Cookie, DecodeToken will be called with a nil-slice + if len(encrypted) == 0 { + return nil, nil + } + + data, err := g.cookieSource.DecodeToken(encrypted) + if err != nil { + return nil, err + } + t := &token{} + rest, err := asn1.Unmarshal(data, t) + if err != nil { + return nil, err + } + if len(rest) != 0 { + return nil, fmt.Errorf("rest when unpacking token: %d", len(rest)) + } + return &Cookie{ + RemoteAddr: decodeRemoteAddr(t.Data), + SentTime: time.Unix(t.Timestamp, 0), + }, nil +} + +// encodeRemoteAddr encodes a remote address such that it can be saved in the Cookie +func encodeRemoteAddr(remoteAddr net.Addr) []byte { + if udpAddr, ok := remoteAddr.(*net.UDPAddr); ok { + return append([]byte{cookiePrefixIP}, udpAddr.IP...) + } + return append([]byte{cookiePrefixString}, []byte(remoteAddr.String())...) +} + +// decodeRemoteAddr decodes the remote address saved in the Cookie +func decodeRemoteAddr(data []byte) string { + // data will never be empty for a Cookie that we generated. Check it to be on the safe side + if len(data) == 0 { + return "" + } + if data[0] == cookiePrefixIP { + return net.IP(data[1:]).String() + } + return string(data[1:]) +} diff --git a/internal/handshake/stk_generator_test.go b/internal/handshake/cookie_generator_test.go similarity index 51% rename from internal/handshake/stk_generator_test.go rename to internal/handshake/cookie_generator_test.go index e224d312..7f908e6b 100644 --- a/internal/handshake/stk_generator_test.go +++ b/internal/handshake/cookie_generator_test.go @@ -9,49 +9,49 @@ import ( . "github.com/onsi/gomega" ) -var _ = Describe("STK Generator", func() { - var stkGen *STKGenerator +var _ = Describe("Cookie Generator", func() { + var cookieGen *CookieGenerator BeforeEach(func() { var err error - stkGen, err = NewSTKGenerator() + cookieGen, err = NewCookieGenerator() Expect(err).ToNot(HaveOccurred()) }) - It("generates an STK", func() { + It("generates a Cookie", func() { ip := net.IPv4(127, 0, 0, 1) - token, err := stkGen.NewToken(&net.UDPAddr{IP: ip, Port: 1337}) + token, err := cookieGen.NewToken(&net.UDPAddr{IP: ip, Port: 1337}) Expect(err).ToNot(HaveOccurred()) Expect(token).ToNot(BeEmpty()) }) It("works with nil tokens", func() { - stk, err := stkGen.DecodeToken(nil) + cookie, err := cookieGen.DecodeToken(nil) Expect(err).ToNot(HaveOccurred()) - Expect(stk).To(BeNil()) + Expect(cookie).To(BeNil()) }) - It("accepts a valid STK", func() { + It("accepts a valid cookie", func() { ip := net.IPv4(192, 168, 0, 1) - token, err := stkGen.NewToken(&net.UDPAddr{IP: ip, Port: 1337}) + token, err := cookieGen.NewToken(&net.UDPAddr{IP: ip, Port: 1337}) Expect(err).ToNot(HaveOccurred()) - stk, err := stkGen.DecodeToken(token) + cookie, err := cookieGen.DecodeToken(token) Expect(err).ToNot(HaveOccurred()) - Expect(stk.RemoteAddr).To(Equal("192.168.0.1")) - // the time resolution of the STK is just 1 second - // if STK generation and this check happen in "different seconds", the difference will be between 1 and 2 seconds - Expect(stk.SentTime).To(BeTemporally("~", time.Now(), 2*time.Second)) + Expect(cookie.RemoteAddr).To(Equal("192.168.0.1")) + // the time resolution of the Cookie is just 1 second + // if Cookie generation and this check happen in "different seconds", the difference will be between 1 and 2 seconds + Expect(cookie.SentTime).To(BeTemporally("~", time.Now(), 2*time.Second)) }) It("rejects invalid tokens", func() { - _, err := stkGen.DecodeToken([]byte("invalid token")) + _, err := cookieGen.DecodeToken([]byte("invalid token")) Expect(err).To(HaveOccurred()) }) It("rejects tokens that cannot be decoded", func() { - token, err := stkGen.stkSource.NewToken([]byte("foobar")) + token, err := cookieGen.cookieSource.NewToken([]byte("foobar")) Expect(err).ToNot(HaveOccurred()) - _, err = stkGen.DecodeToken(token) + _, err = cookieGen.DecodeToken(token) Expect(err).To(HaveOccurred()) }) @@ -59,9 +59,9 @@ var _ = Describe("STK Generator", func() { t, err := asn1.Marshal(token{Data: []byte("foobar")}) Expect(err).ToNot(HaveOccurred()) t = append(t, []byte("rest")...) - enc, err := stkGen.stkSource.NewToken(t) + enc, err := cookieGen.cookieSource.NewToken(t) Expect(err).ToNot(HaveOccurred()) - _, err = stkGen.DecodeToken(enc) + _, err = cookieGen.DecodeToken(enc) Expect(err).To(MatchError("rest when unpacking token: 4")) }) @@ -69,9 +69,9 @@ var _ = Describe("STK Generator", func() { It("doesn't panic if a tokens has no data", func() { t, err := asn1.Marshal(token{Data: []byte("")}) Expect(err).ToNot(HaveOccurred()) - enc, err := stkGen.stkSource.NewToken(t) + enc, err := cookieGen.cookieSource.NewToken(t) Expect(err).ToNot(HaveOccurred()) - _, err = stkGen.DecodeToken(enc) + _, err = cookieGen.DecodeToken(enc) Expect(err).ToNot(HaveOccurred()) }) @@ -86,26 +86,26 @@ var _ = Describe("STK Generator", func() { ip := net.ParseIP(addr) Expect(ip).ToNot(BeNil()) raddr := &net.UDPAddr{IP: ip, Port: 1337} - token, err := stkGen.NewToken(raddr) + token, err := cookieGen.NewToken(raddr) Expect(err).ToNot(HaveOccurred()) - stk, err := stkGen.DecodeToken(token) + cookie, err := cookieGen.DecodeToken(token) Expect(err).ToNot(HaveOccurred()) - Expect(stk.RemoteAddr).To(Equal(ip.String())) - // the time resolution of the STK is just 1 second - // if STK generation and this check happen in "different seconds", the difference will be between 1 and 2 seconds - Expect(stk.SentTime).To(BeTemporally("~", time.Now(), 2*time.Second)) + Expect(cookie.RemoteAddr).To(Equal(ip.String())) + // the time resolution of the Cookie is just 1 second + // if Cookie generation and this check happen in "different seconds", the difference will be between 1 and 2 seconds + Expect(cookie.SentTime).To(BeTemporally("~", time.Now(), 2*time.Second)) } }) It("uses the string representation an address that is not a UDP address", func() { raddr := &net.TCPAddr{IP: net.IPv4(192, 168, 13, 37), Port: 1337} - token, err := stkGen.NewToken(raddr) + token, err := cookieGen.NewToken(raddr) Expect(err).ToNot(HaveOccurred()) - stk, err := stkGen.DecodeToken(token) + cookie, err := cookieGen.DecodeToken(token) Expect(err).ToNot(HaveOccurred()) - Expect(stk.RemoteAddr).To(Equal("192.168.13.37:1337")) - // the time resolution of the STK is just 1 second - // if STK generation and this check happen in "different seconds", the difference will be between 1 and 2 seconds - Expect(stk.SentTime).To(BeTemporally("~", time.Now(), 2*time.Second)) + Expect(cookie.RemoteAddr).To(Equal("192.168.13.37:1337")) + // the time resolution of the Cookie is just 1 second + // if Cookie generation and this check happen in "different seconds", the difference will be between 1 and 2 seconds + Expect(cookie.SentTime).To(BeTemporally("~", time.Now(), 2*time.Second)) }) }) diff --git a/internal/handshake/crypto_setup_server.go b/internal/handshake/crypto_setup_server.go index 8093bdd9..be55080c 100644 --- a/internal/handshake/crypto_setup_server.go +++ b/internal/handshake/crypto_setup_server.go @@ -26,13 +26,13 @@ type cryptoSetupServer struct { connID protocol.ConnectionID remoteAddr net.Addr scfg *ServerConfig - stkGenerator *STKGenerator + stkGenerator *CookieGenerator diversificationNonce []byte version protocol.VersionNumber supportedVersions []protocol.VersionNumber - acceptSTKCallback func(net.Addr, *STK) bool + acceptSTKCallback func(net.Addr, *Cookie) bool nullAEAD crypto.AEAD secureAEAD crypto.AEAD @@ -72,10 +72,10 @@ func NewCryptoSetup( cryptoStream io.ReadWriter, connectionParametersManager ConnectionParametersManager, supportedVersions []protocol.VersionNumber, - acceptSTK func(net.Addr, *STK) bool, + acceptSTK func(net.Addr, *Cookie) bool, aeadChanged chan<- protocol.EncryptionLevel, ) (CryptoSetup, error) { - stkGenerator, err := NewSTKGenerator() + stkGenerator, err := NewCookieGenerator() if err != nil { return nil, err } diff --git a/internal/handshake/crypto_setup_server_test.go b/internal/handshake/crypto_setup_server_test.go index 3c5b72aa..724caa10 100644 --- a/internal/handshake/crypto_setup_server_test.go +++ b/internal/handshake/crypto_setup_server_test.go @@ -131,18 +131,18 @@ func (s *mockStream) Reset(error) { panic("not implemente func (mockStream) CloseRemote(offset protocol.ByteCount) { panic("not implemented") } func (s mockStream) StreamID() protocol.StreamID { panic("not implemented") } -type mockStkSource struct { +type mockCookieSource struct { data []byte decodeErr error } -var _ crypto.StkSource = &mockStkSource{} +var _ crypto.StkSource = &mockCookieSource{} -func (mockStkSource) NewToken(sourceAddr []byte) ([]byte, error) { +func (mockCookieSource) NewToken(sourceAddr []byte) ([]byte, error) { return append([]byte("token "), sourceAddr...), nil } -func (s mockStkSource) DecodeToken(data []byte) ([]byte, error) { +func (s mockCookieSource) DecodeToken(data []byte) ([]byte, error) { if s.decodeErr != nil { return nil, s.decodeErr } @@ -209,11 +209,11 @@ var _ = Describe("Server Crypto Setup", func() { ) Expect(err).NotTo(HaveOccurred()) cs = csInt.(*cryptoSetupServer) - cs.stkGenerator.stkSource = &mockStkSource{} + cs.stkGenerator.cookieSource = &mockCookieSource{} validSTK, err = cs.stkGenerator.NewToken(remoteAddr) Expect(err).NotTo(HaveOccurred()) sourceAddrValid = true - cs.acceptSTKCallback = func(_ net.Addr, _ *STK) bool { return sourceAddrValid } + cs.acceptSTKCallback = func(_ net.Addr, _ *Cookie) bool { return sourceAddrValid } cs.keyDerivation = mockQuicCryptoKeyDerivation cs.keyExchange = func() crypto.KeyExchange { return &mockKEX{ephermal: true} } }) @@ -422,7 +422,7 @@ var _ = Describe("Server Crypto Setup", func() { It("recognizes inchoate CHLOs with an invalid STK", func() { testErr := errors.New("STK invalid") - cs.stkGenerator.stkSource.(*mockStkSource).decodeErr = testErr + cs.stkGenerator.cookieSource.(*mockCookieSource).decodeErr = testErr Expect(cs.isInchoateCHLO(fullCHLO, cert)).To(BeTrue()) }) diff --git a/internal/handshake/stk_generator.go b/internal/handshake/stk_generator.go deleted file mode 100644 index 2b84268c..00000000 --- a/internal/handshake/stk_generator.go +++ /dev/null @@ -1,106 +0,0 @@ -package handshake - -import ( - "encoding/asn1" - "fmt" - "net" - "time" - - "github.com/lucas-clemente/quic-go/internal/crypto" -) - -const ( - stkPrefixIP byte = iota - stkPrefixString -) - -// An STK is a Source Address token. -// It is issued by the server and sent to the client. For the client, it is an opaque blob. -// The client can send the STK in subsequent handshakes to prove ownership of its IP address. -type STK struct { - // The remote address this token was issued for. - // If the server is run on a net.UDPConn, this is the string representation of the IP address (net.IP.String()) - // Otherwise, this is the string representation of the net.Addr (net.Addr.String()) - RemoteAddr string - // The time that the STK was issued (resolution 1 second) - SentTime time.Time -} - -// token is the struct that is used for ASN1 serialization and deserialization -type token struct { - Data []byte - Timestamp int64 -} - -// An STKGenerator generates STKs -type STKGenerator struct { - stkSource crypto.StkSource -} - -// NewSTKGenerator initializes a new STKGenerator -func NewSTKGenerator() (*STKGenerator, error) { - stkSource, err := crypto.NewStkSource() - if err != nil { - return nil, err - } - return &STKGenerator{ - stkSource: stkSource, - }, nil -} - -// NewToken generates a new STK token for a given source address -func (g *STKGenerator) NewToken(raddr net.Addr) ([]byte, error) { - data, err := asn1.Marshal(token{ - Data: encodeRemoteAddr(raddr), - Timestamp: time.Now().Unix(), - }) - if err != nil { - return nil, err - } - return g.stkSource.NewToken(data) -} - -// DecodeToken decodes an STK token -func (g *STKGenerator) DecodeToken(encrypted []byte) (*STK, error) { - // if the client didn't send any STK, DecodeToken will be called with a nil-slice - if len(encrypted) == 0 { - return nil, nil - } - - data, err := g.stkSource.DecodeToken(encrypted) - if err != nil { - return nil, err - } - t := &token{} - rest, err := asn1.Unmarshal(data, t) - if err != nil { - return nil, err - } - if len(rest) != 0 { - return nil, fmt.Errorf("rest when unpacking token: %d", len(rest)) - } - return &STK{ - RemoteAddr: decodeRemoteAddr(t.Data), - SentTime: time.Unix(t.Timestamp, 0), - }, nil -} - -// encodeRemoteAddr encodes a remote address such that it can be saved in the STK -func encodeRemoteAddr(remoteAddr net.Addr) []byte { - if udpAddr, ok := remoteAddr.(*net.UDPAddr); ok { - return append([]byte{stkPrefixIP}, udpAddr.IP...) - } - return append([]byte{stkPrefixString}, []byte(remoteAddr.String())...) -} - -// decodeRemoteAddr decodes the remote address saved in the STK -func decodeRemoteAddr(data []byte) string { - // data will never be empty for an STK that we generated. Check it to be on the safe side - if len(data) == 0 { - return "" - } - if data[0] == stkPrefixIP { - return net.IP(data[1:]).String() - } - return string(data[1:]) -} diff --git a/session_test.go b/session_test.go index 190eee26..1c7a9fc7 100644 --- a/session_test.go +++ b/session_test.go @@ -166,7 +166,7 @@ var _ = Describe("Session", func() { _ io.ReadWriter, _ handshake.ConnectionParametersManager, _ []protocol.VersionNumber, - _ func(net.Addr, *handshake.STK) bool, + _ func(net.Addr, *STK) bool, aeadChangedP chan<- protocol.EncryptionLevel, ) (handshake.CryptoSetup, error) { aeadChanged = aeadChangedP @@ -204,7 +204,7 @@ var _ = Describe("Session", func() { Context("source address validation", func() { var ( - stkVerify func(net.Addr, *handshake.STK) bool + stkVerify func(net.Addr, *STK) bool paramClientAddr net.Addr paramSTK *STK ) @@ -219,7 +219,7 @@ var _ = Describe("Session", func() { _ io.ReadWriter, _ handshake.ConnectionParametersManager, _ []protocol.VersionNumber, - stkFunc func(net.Addr, *handshake.STK) bool, + stkFunc func(net.Addr, *STK) bool, _ chan<- protocol.EncryptionLevel, ) (handshake.CryptoSetup, error) { stkVerify = stkFunc @@ -253,7 +253,7 @@ var _ = Describe("Session", func() { It("calls the callback with the STK when the client sent an STK", func() { stkAddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} sentTime := time.Now().Add(-time.Hour) - stkVerify(remoteAddr, &handshake.STK{SentTime: sentTime, RemoteAddr: stkAddr.String()}) + stkVerify(remoteAddr, &STK{SentTime: sentTime, RemoteAddr: stkAddr.String()}) Expect(paramClientAddr).To(Equal(remoteAddr)) Expect(paramSTK).ToNot(BeNil()) Expect(paramSTK.RemoteAddr).To(Equal(stkAddr.String()))