diff --git a/internal/handshake/cookie_handler.go b/internal/handshake/cookie_handler.go new file mode 100644 index 00000000..317f6e50 --- /dev/null +++ b/internal/handshake/cookie_handler.go @@ -0,0 +1,43 @@ +package handshake + +import ( + "net" + + "github.com/bifurcation/mint" + "github.com/lucas-clemente/quic-go/internal/utils" +) + +type cookieHandler struct { + callback func(net.Addr, *Cookie) bool + + cookieGenerator *CookieGenerator +} + +var _ mint.CookieHandler = &cookieHandler{} + +func newCookieHandler(callback func(net.Addr, *Cookie) bool) (*cookieHandler, error) { + cookieGenerator, err := NewCookieGenerator() + if err != nil { + return nil, err + } + return &cookieHandler{ + callback: callback, + cookieGenerator: cookieGenerator, + }, nil +} + +func (h *cookieHandler) Generate(conn *mint.Conn) ([]byte, error) { + if h.callback(conn.RemoteAddr(), nil) { + return nil, nil + } + return h.cookieGenerator.NewToken(conn.RemoteAddr()) +} + +func (h *cookieHandler) Validate(conn *mint.Conn, token []byte) bool { + data, err := h.cookieGenerator.DecodeToken(token) + if err != nil { + utils.Debugf("Couldn't decode cookie from %s: %s", conn.RemoteAddr(), err.Error()) + return false + } + return h.callback(conn.RemoteAddr(), data) +} diff --git a/internal/handshake/cookie_handler_test.go b/internal/handshake/cookie_handler_test.go new file mode 100644 index 00000000..9c95ea61 --- /dev/null +++ b/internal/handshake/cookie_handler_test.go @@ -0,0 +1,49 @@ +package handshake + +import ( + "net" + + "github.com/bifurcation/mint" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var callbackReturn bool +var mockCallback = func(net.Addr, *Cookie) bool { + return callbackReturn +} + +var _ = Describe("Cookie Handler", func() { + var ch *cookieHandler + var conn *mint.Conn + + BeforeEach(func() { + callbackReturn = false + var err error + ch, err = newCookieHandler(mockCallback) + Expect(err).ToNot(HaveOccurred()) + addr := &net.UDPAddr{IP: net.IPv4(42, 43, 44, 45), Port: 46} + conn = mint.NewConn(&fakeConn{remoteAddr: addr}, &mint.Config{}, false) + }) + + It("generates and validates a token", func() { + cookie, err := ch.Generate(conn) + Expect(err).ToNot(HaveOccurred()) + Expect(ch.Validate(conn, cookie)).To(BeFalse()) + callbackReturn = true + Expect(ch.Validate(conn, cookie)).To(BeTrue()) + }) + + It("doesn't generate a token if the callback says so", func() { + callbackReturn = true + cookie, err := ch.Generate(conn) + Expect(err).ToNot(HaveOccurred()) + Expect(cookie).To(BeNil()) + }) + + It("correctly handles a token that it can't decode", func() { + cookie := []byte("unparseable cookie") + Expect(ch.Validate(conn, cookie)).To(BeFalse()) + }) +}) diff --git a/internal/handshake/crypto_setup_tls.go b/internal/handshake/crypto_setup_tls.go index 294f466d..d8e4245c 100644 --- a/internal/handshake/crypto_setup_tls.go +++ b/internal/handshake/crypto_setup_tls.go @@ -4,6 +4,7 @@ import ( "crypto/tls" "fmt" "io" + "net" "sync" "github.com/bifurcation/mint" @@ -36,9 +37,11 @@ func NewCryptoSetupTLSServer( cryptoStream io.ReadWriter, connID protocol.ConnectionID, tlsConfig *tls.Config, + remoteAddr net.Addr, params *TransportParameters, paramsChan chan<- TransportParameters, aeadChanged chan<- protocol.EncryptionLevel, + checkCookie func(net.Addr, *Cookie) bool, supportedVersions []protocol.VersionNumber, version protocol.VersionNumber, ) (CryptoSetup, error) { @@ -46,7 +49,16 @@ func NewCryptoSetupTLSServer( if err != nil { return nil, err } - conn := &fakeConn{stream: cryptoStream, pers: protocol.PerspectiveServer} + mintConf.RequireCookie = true + mintConf.CookieHandler, err = newCookieHandler(checkCookie) + if err != nil { + return nil, err + } + conn := &fakeConn{ + stream: cryptoStream, + pers: protocol.PerspectiveServer, + remoteAddr: remoteAddr, + } mintConn := mint.Server(conn, mintConf) eh := newExtensionHandlerServer(params, paramsChan, supportedVersions, version) if err := mintConn.SetExtensionHandler(eh); err != nil { @@ -86,7 +98,10 @@ func NewCryptoSetupTLSClient( return nil, err } mintConf.ServerName = hostname - conn := &fakeConn{stream: cryptoStream, pers: protocol.PerspectiveClient} + conn := &fakeConn{ + stream: cryptoStream, + pers: protocol.PerspectiveClient, + } mintConn := mint.Client(conn, mintConf) eh := newExtensionHandlerClient(params, paramsChan, initialVersion, supportedVersions, version) if err := mintConn.SetExtensionHandler(eh); err != nil { diff --git a/internal/handshake/crypto_setup_tls_test.go b/internal/handshake/crypto_setup_tls_test.go index 42933a0b..84b77d96 100644 --- a/internal/handshake/crypto_setup_tls_test.go +++ b/internal/handshake/crypto_setup_tls_test.go @@ -34,10 +34,12 @@ var _ = Describe("TLS Crypto Setup", func() { nil, 1, testdata.GetTLSConfig(), + nil, &TransportParameters{}, paramsChan, aeadChanged, nil, + nil, protocol.VersionTLS, ) Expect(err).ToNot(HaveOccurred()) diff --git a/internal/handshake/mint_utils.go b/internal/handshake/mint_utils.go index a8bd2953..8c3a83bd 100644 --- a/internal/handshake/mint_utils.go +++ b/internal/handshake/mint_utils.go @@ -81,8 +81,9 @@ func (mc *mintController) State() mint.ConnectionState { // mint expects a net.Conn, but we're doing the handshake on a stream // so we wrap a stream such that implements a net.Conn type fakeConn struct { - stream io.ReadWriter - pers protocol.Perspective + stream io.ReadWriter + pers protocol.Perspective + remoteAddr net.Addr blockRead bool writeBuffer bytes.Buffer @@ -120,7 +121,7 @@ func (c *fakeConn) Continue() error { func (c *fakeConn) Close() error { return nil } func (c *fakeConn) LocalAddr() net.Addr { return nil } -func (c *fakeConn) RemoteAddr() net.Addr { return nil } +func (c *fakeConn) RemoteAddr() net.Addr { return c.remoteAddr } func (c *fakeConn) SetReadDeadline(time.Time) error { return nil } func (c *fakeConn) SetWriteDeadline(time.Time) error { return nil } func (c *fakeConn) SetDeadline(time.Time) error { return nil } diff --git a/internal/handshake/mint_utils_test.go b/internal/handshake/mint_utils_test.go index d1ba7a95..1bb1858a 100644 --- a/internal/handshake/mint_utils_test.go +++ b/internal/handshake/mint_utils_test.go @@ -2,6 +2,7 @@ package handshake import ( "bytes" + "net" "github.com/lucas-clemente/quic-go/internal/protocol" . "github.com/onsi/ginkgo" @@ -62,4 +63,10 @@ var _ = Describe("Fake Conn", func() { Expect(stream.Bytes()).To(Equal([]byte("foobar"))) }) }) + + It("returns its remote address", func() { + addr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} + c.remoteAddr = addr + Expect(c.RemoteAddr()).To(Equal(addr)) + }) }) diff --git a/session.go b/session.go index a1691362..a40b845f 100644 --- a/session.go +++ b/session.go @@ -214,9 +214,11 @@ func (s *session) setup( s.cryptoStream, s.connectionID, tlsConf, + s.conn.RemoteAddr(), transportParams, paramsChan, aeadChanged, + verifySourceAddr, s.config.Versions, s.version, )