forked from quic-go/quic-go
implement the TLS Cookie extension
This commit is contained in:
43
internal/handshake/cookie_handler.go
Normal file
43
internal/handshake/cookie_handler.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/bifurcation/mint"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
type cookieHandler struct {
|
||||
callback func(net.Addr, *Cookie) bool
|
||||
|
||||
cookieGenerator *CookieGenerator
|
||||
}
|
||||
|
||||
var _ mint.CookieHandler = &cookieHandler{}
|
||||
|
||||
func newCookieHandler(callback func(net.Addr, *Cookie) bool) (*cookieHandler, error) {
|
||||
cookieGenerator, err := NewCookieGenerator()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &cookieHandler{
|
||||
callback: callback,
|
||||
cookieGenerator: cookieGenerator,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (h *cookieHandler) Generate(conn *mint.Conn) ([]byte, error) {
|
||||
if h.callback(conn.RemoteAddr(), nil) {
|
||||
return nil, nil
|
||||
}
|
||||
return h.cookieGenerator.NewToken(conn.RemoteAddr())
|
||||
}
|
||||
|
||||
func (h *cookieHandler) Validate(conn *mint.Conn, token []byte) bool {
|
||||
data, err := h.cookieGenerator.DecodeToken(token)
|
||||
if err != nil {
|
||||
utils.Debugf("Couldn't decode cookie from %s: %s", conn.RemoteAddr(), err.Error())
|
||||
return false
|
||||
}
|
||||
return h.callback(conn.RemoteAddr(), data)
|
||||
}
|
||||
49
internal/handshake/cookie_handler_test.go
Normal file
49
internal/handshake/cookie_handler_test.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/bifurcation/mint"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var callbackReturn bool
|
||||
var mockCallback = func(net.Addr, *Cookie) bool {
|
||||
return callbackReturn
|
||||
}
|
||||
|
||||
var _ = Describe("Cookie Handler", func() {
|
||||
var ch *cookieHandler
|
||||
var conn *mint.Conn
|
||||
|
||||
BeforeEach(func() {
|
||||
callbackReturn = false
|
||||
var err error
|
||||
ch, err = newCookieHandler(mockCallback)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
addr := &net.UDPAddr{IP: net.IPv4(42, 43, 44, 45), Port: 46}
|
||||
conn = mint.NewConn(&fakeConn{remoteAddr: addr}, &mint.Config{}, false)
|
||||
})
|
||||
|
||||
It("generates and validates a token", func() {
|
||||
cookie, err := ch.Generate(conn)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(ch.Validate(conn, cookie)).To(BeFalse())
|
||||
callbackReturn = true
|
||||
Expect(ch.Validate(conn, cookie)).To(BeTrue())
|
||||
})
|
||||
|
||||
It("doesn't generate a token if the callback says so", func() {
|
||||
callbackReturn = true
|
||||
cookie, err := ch.Generate(conn)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(cookie).To(BeNil())
|
||||
})
|
||||
|
||||
It("correctly handles a token that it can't decode", func() {
|
||||
cookie := []byte("unparseable cookie")
|
||||
Expect(ch.Validate(conn, cookie)).To(BeFalse())
|
||||
})
|
||||
})
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/bifurcation/mint"
|
||||
@@ -36,9 +37,11 @@ func NewCryptoSetupTLSServer(
|
||||
cryptoStream io.ReadWriter,
|
||||
connID protocol.ConnectionID,
|
||||
tlsConfig *tls.Config,
|
||||
remoteAddr net.Addr,
|
||||
params *TransportParameters,
|
||||
paramsChan chan<- TransportParameters,
|
||||
aeadChanged chan<- protocol.EncryptionLevel,
|
||||
checkCookie func(net.Addr, *Cookie) bool,
|
||||
supportedVersions []protocol.VersionNumber,
|
||||
version protocol.VersionNumber,
|
||||
) (CryptoSetup, error) {
|
||||
@@ -46,7 +49,16 @@ func NewCryptoSetupTLSServer(
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conn := &fakeConn{stream: cryptoStream, pers: protocol.PerspectiveServer}
|
||||
mintConf.RequireCookie = true
|
||||
mintConf.CookieHandler, err = newCookieHandler(checkCookie)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conn := &fakeConn{
|
||||
stream: cryptoStream,
|
||||
pers: protocol.PerspectiveServer,
|
||||
remoteAddr: remoteAddr,
|
||||
}
|
||||
mintConn := mint.Server(conn, mintConf)
|
||||
eh := newExtensionHandlerServer(params, paramsChan, supportedVersions, version)
|
||||
if err := mintConn.SetExtensionHandler(eh); err != nil {
|
||||
@@ -86,7 +98,10 @@ func NewCryptoSetupTLSClient(
|
||||
return nil, err
|
||||
}
|
||||
mintConf.ServerName = hostname
|
||||
conn := &fakeConn{stream: cryptoStream, pers: protocol.PerspectiveClient}
|
||||
conn := &fakeConn{
|
||||
stream: cryptoStream,
|
||||
pers: protocol.PerspectiveClient,
|
||||
}
|
||||
mintConn := mint.Client(conn, mintConf)
|
||||
eh := newExtensionHandlerClient(params, paramsChan, initialVersion, supportedVersions, version)
|
||||
if err := mintConn.SetExtensionHandler(eh); err != nil {
|
||||
|
||||
@@ -34,10 +34,12 @@ var _ = Describe("TLS Crypto Setup", func() {
|
||||
nil,
|
||||
1,
|
||||
testdata.GetTLSConfig(),
|
||||
nil,
|
||||
&TransportParameters{},
|
||||
paramsChan,
|
||||
aeadChanged,
|
||||
nil,
|
||||
nil,
|
||||
protocol.VersionTLS,
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
@@ -81,8 +81,9 @@ func (mc *mintController) State() mint.ConnectionState {
|
||||
// mint expects a net.Conn, but we're doing the handshake on a stream
|
||||
// so we wrap a stream such that implements a net.Conn
|
||||
type fakeConn struct {
|
||||
stream io.ReadWriter
|
||||
pers protocol.Perspective
|
||||
stream io.ReadWriter
|
||||
pers protocol.Perspective
|
||||
remoteAddr net.Addr
|
||||
|
||||
blockRead bool
|
||||
writeBuffer bytes.Buffer
|
||||
@@ -120,7 +121,7 @@ func (c *fakeConn) Continue() error {
|
||||
|
||||
func (c *fakeConn) Close() error { return nil }
|
||||
func (c *fakeConn) LocalAddr() net.Addr { return nil }
|
||||
func (c *fakeConn) RemoteAddr() net.Addr { return nil }
|
||||
func (c *fakeConn) RemoteAddr() net.Addr { return c.remoteAddr }
|
||||
func (c *fakeConn) SetReadDeadline(time.Time) error { return nil }
|
||||
func (c *fakeConn) SetWriteDeadline(time.Time) error { return nil }
|
||||
func (c *fakeConn) SetDeadline(time.Time) error { return nil }
|
||||
|
||||
@@ -2,6 +2,7 @@ package handshake
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
. "github.com/onsi/ginkgo"
|
||||
@@ -62,4 +63,10 @@ var _ = Describe("Fake Conn", func() {
|
||||
Expect(stream.Bytes()).To(Equal([]byte("foobar")))
|
||||
})
|
||||
})
|
||||
|
||||
It("returns its remote address", func() {
|
||||
addr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
|
||||
c.remoteAddr = addr
|
||||
Expect(c.RemoteAddr()).To(Equal(addr))
|
||||
})
|
||||
})
|
||||
|
||||
@@ -214,9 +214,11 @@ func (s *session) setup(
|
||||
s.cryptoStream,
|
||||
s.connectionID,
|
||||
tlsConf,
|
||||
s.conn.RemoteAddr(),
|
||||
transportParams,
|
||||
paramsChan,
|
||||
aeadChanged,
|
||||
verifySourceAddr,
|
||||
s.config.Versions,
|
||||
s.version,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user