pass remote address to cryptoSetupServer

This commit is contained in:
Marten Seemann
2017-03-30 17:46:09 +07:00
parent 8489c94f4d
commit e68e2d287a
4 changed files with 59 additions and 47 deletions

View File

@@ -6,6 +6,7 @@ import (
"encoding/binary" "encoding/binary"
"errors" "errors"
"io" "io"
"net"
"sync" "sync"
"github.com/lucas-clemente/quic-go/crypto" "github.com/lucas-clemente/quic-go/crypto"
@@ -58,7 +59,7 @@ var ErrHOLExperiment = qerr.Error(qerr.InvalidCryptoMessageParameter, "HOL exper
// NewCryptoSetup creates a new CryptoSetup instance for a server // NewCryptoSetup creates a new CryptoSetup instance for a server
func NewCryptoSetup( func NewCryptoSetup(
connID protocol.ConnectionID, connID protocol.ConnectionID,
sourceAddr []byte, remoteAddr net.Addr,
version protocol.VersionNumber, version protocol.VersionNumber,
scfg *ServerConfig, scfg *ServerConfig,
cryptoStream io.ReadWriter, cryptoStream io.ReadWriter,
@@ -66,6 +67,13 @@ func NewCryptoSetup(
supportedVersions []protocol.VersionNumber, supportedVersions []protocol.VersionNumber,
aeadChanged chan<- protocol.EncryptionLevel, aeadChanged chan<- protocol.EncryptionLevel,
) (CryptoSetup, error) { ) (CryptoSetup, error) {
var sourceAddr []byte
if udpAddr, ok := remoteAddr.(*net.UDPAddr); ok {
sourceAddr = udpAddr.IP
} else {
sourceAddr = []byte(remoteAddr.String())
}
return &cryptoSetupServer{ return &cryptoSetupServer{
connID: connID, connID: connID,
sourceAddr: sourceAddr, sourceAddr: sourceAddr,
@@ -263,7 +271,8 @@ func (h *cryptoSetupServer) isInchoateCHLO(cryptoData map[Tag][]byte, cert []byt
if crypto.HashCert(cert) != xlct { if crypto.HashCert(cert) != xlct {
return true return true
} }
if err := h.scfg.stkSource.VerifyToken(h.sourceAddr, cryptoData[TagSTK]); err != nil { stk := cryptoData[TagSTK]
if err := h.scfg.stkSource.VerifyToken(h.sourceAddr, stk); err != nil {
utils.Debugf("STK invalid: %s", err.Error()) utils.Debugf("STK invalid: %s", err.Error())
return true return true
} }

View File

@@ -161,7 +161,6 @@ var _ = Describe("Server Crypto Setup", func() {
aeadChanged chan protocol.EncryptionLevel aeadChanged chan protocol.EncryptionLevel
nonce32 []byte nonce32 []byte
versionTag []byte versionTag []byte
sourceAddr []byte
validSTK []byte validSTK []byte
aead []byte aead []byte
kexs []byte kexs []byte
@@ -171,8 +170,8 @@ var _ = Describe("Server Crypto Setup", func() {
BeforeEach(func() { BeforeEach(func() {
var err error var err error
sourceAddr = net.ParseIP("1.2.3.4") remoteAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}
validSTK, err = mockStkSource{}.NewToken(sourceAddr) validSTK, err = mockStkSource{}.NewToken(remoteAddr.IP)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
expectedInitialNonceLen = 32 expectedInitialNonceLen = 32
expectedFSNonceLen = 64 expectedFSNonceLen = 64
@@ -192,7 +191,16 @@ var _ = Describe("Server Crypto Setup", func() {
version = protocol.SupportedVersions[len(protocol.SupportedVersions)-1] version = protocol.SupportedVersions[len(protocol.SupportedVersions)-1]
supportedVersions = []protocol.VersionNumber{version, 98, 99} supportedVersions = []protocol.VersionNumber{version, 98, 99}
cpm = NewConnectionParamatersManager(protocol.PerspectiveServer, protocol.VersionWhatever) cpm = NewConnectionParamatersManager(protocol.PerspectiveServer, protocol.VersionWhatever)
csInt, err := NewCryptoSetup(protocol.ConnectionID(42), sourceAddr, version, scfg, stream, cpm, supportedVersions, aeadChanged) csInt, err := NewCryptoSetup(
protocol.ConnectionID(42),
remoteAddr,
version,
scfg,
stream,
cpm,
supportedVersions,
aeadChanged,
)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
cs = csInt.(*cryptoSetupServer) cs = csInt.(*cryptoSetupServer)
cs.keyDerivation = mockKeyDerivation cs.keyDerivation = mockKeyDerivation
@@ -219,6 +227,40 @@ var _ = Describe("Server Crypto Setup", func() {
}) })
}) })
Context("source address token", func() {
It("uses the IP address when the remote address is a UDP address", func() {
remoteAddr := &net.UDPAddr{IP: net.IPv4(1, 3, 3, 7), Port: 1337}
cs, err := NewCryptoSetup(
protocol.ConnectionID(42),
remoteAddr,
protocol.VersionWhatever,
scfg,
stream,
cpm,
supportedVersions,
aeadChanged,
)
Expect(err).ToNot(HaveOccurred())
Expect(cs.(*cryptoSetupServer).sourceAddr).To(BeEquivalentTo(remoteAddr.IP))
})
It("works with remote address that are not UDP", func() {
remoteAddr := &net.TCPAddr{IP: net.IPv4(1, 3, 3, 7), Port: 1337}
cs, err := NewCryptoSetup(
protocol.ConnectionID(42),
remoteAddr,
protocol.VersionWhatever,
scfg,
stream,
cpm,
supportedVersions,
aeadChanged,
)
Expect(err).ToNot(HaveOccurred())
Expect(cs.(*cryptoSetupServer).sourceAddr).To(BeEquivalentTo("1.3.3.7:1337"))
})
})
Context("when responding to client messages", func() { Context("when responding to client messages", func() {
var cert []byte var cert []byte
var xlct []byte var xlct []byte

View File

@@ -137,12 +137,6 @@ func newSession(
s.setup() s.setup()
cryptoStream, _ := s.GetOrOpenStream(1) cryptoStream, _ := s.GetOrOpenStream(1)
_, _ = s.AcceptStream() // don't expose the crypto stream _, _ = s.AcceptStream() // don't expose the crypto stream
var sourceAddr []byte
if udpAddr, ok := conn.RemoteAddr().(*net.UDPAddr); ok {
sourceAddr = udpAddr.IP
} else {
sourceAddr = []byte(conn.RemoteAddr().String())
}
aeadChanged := make(chan protocol.EncryptionLevel, 2) aeadChanged := make(chan protocol.EncryptionLevel, 2)
s.aeadChanged = aeadChanged s.aeadChanged = aeadChanged
handshakeChan := make(chan handshakeEvent, 3) handshakeChan := make(chan handshakeEvent, 3)
@@ -150,7 +144,7 @@ func newSession(
var err error var err error
s.cryptoSetup, err = newCryptoSetup( s.cryptoSetup, err = newCryptoSetup(
connectionID, connectionID,
sourceAddr, conn.RemoteAddr(),
v, v,
sCfg, sCfg,
cryptoStream, cryptoStream,

View File

@@ -131,8 +131,6 @@ var _ = Describe("Session", func() {
cryptoSetup *mockCryptoSetup cryptoSetup *mockCryptoSetup
handshakeChan <-chan handshakeEvent handshakeChan <-chan handshakeEvent
aeadChanged chan<- protocol.EncryptionLevel aeadChanged chan<- protocol.EncryptionLevel
cryptoSetupSourceAddr []byte
) )
BeforeEach(func() { BeforeEach(func() {
@@ -141,7 +139,7 @@ var _ = Describe("Session", func() {
cryptoSetup = &mockCryptoSetup{} cryptoSetup = &mockCryptoSetup{}
newCryptoSetup = func( newCryptoSetup = func(
_ protocol.ConnectionID, _ protocol.ConnectionID,
sourceAddr []byte, _ net.Addr,
_ protocol.VersionNumber, _ protocol.VersionNumber,
_ *handshake.ServerConfig, _ *handshake.ServerConfig,
_ io.ReadWriter, _ io.ReadWriter,
@@ -149,7 +147,6 @@ var _ = Describe("Session", func() {
_ []protocol.VersionNumber, _ []protocol.VersionNumber,
aeadChangedP chan<- protocol.EncryptionLevel, aeadChangedP chan<- protocol.EncryptionLevel,
) (handshake.CryptoSetup, error) { ) (handshake.CryptoSetup, error) {
cryptoSetupSourceAddr = sourceAddr
aeadChanged = aeadChangedP aeadChanged = aeadChangedP
return cryptoSetup, nil return cryptoSetup, nil
} }
@@ -183,36 +180,6 @@ var _ = Describe("Session", func() {
Eventually(areSessionsRunning).Should(BeFalse()) Eventually(areSessionsRunning).Should(BeFalse())
}) })
Context("source address", func() {
It("uses the IP address if given an UDP connection", func() {
conn := &conn{currentAddr: &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200)[12:], Port: 1337}}
_, _, err := newSession(
conn,
protocol.VersionWhatever,
0,
scfg,
populateServerConfig(&Config{}),
)
Expect(err).ToNot(HaveOccurred())
Expect(cryptoSetupSourceAddr).To(Equal([]byte{192, 168, 100, 200}))
})
It("uses the string representation of the remote addresses if not given a UDP connection", func() {
conn := &conn{
currentAddr: &net.TCPAddr{IP: net.IPv4(192, 168, 100, 200)[12:], Port: 1337},
}
_, _, err := newSession(
conn,
protocol.VersionWhatever,
0,
scfg,
populateServerConfig(&Config{}),
)
Expect(err).ToNot(HaveOccurred())
Expect(cryptoSetupSourceAddr).To(Equal([]byte("192.168.100.200:1337")))
})
})
Context("when handling stream frames", func() { Context("when handling stream frames", func() {
It("makes new streams", func() { It("makes new streams", func() {
sess.handleStreamFrame(&frames.StreamFrame{ sess.handleStreamFrame(&frames.StreamFrame{