diff --git a/handshake/crypto_setup_server.go b/handshake/crypto_setup_server.go index 312e64e6..070250b2 100644 --- a/handshake/crypto_setup_server.go +++ b/handshake/crypto_setup_server.go @@ -6,6 +6,7 @@ import ( "encoding/binary" "errors" "io" + "net" "sync" "github.com/lucas-clemente/quic-go/crypto" @@ -58,7 +59,7 @@ var ErrHOLExperiment = qerr.Error(qerr.InvalidCryptoMessageParameter, "HOL exper // NewCryptoSetup creates a new CryptoSetup instance for a server func NewCryptoSetup( connID protocol.ConnectionID, - sourceAddr []byte, + remoteAddr net.Addr, version protocol.VersionNumber, scfg *ServerConfig, cryptoStream io.ReadWriter, @@ -66,6 +67,13 @@ func NewCryptoSetup( supportedVersions []protocol.VersionNumber, aeadChanged chan<- protocol.EncryptionLevel, ) (CryptoSetup, error) { + var sourceAddr []byte + if udpAddr, ok := remoteAddr.(*net.UDPAddr); ok { + sourceAddr = udpAddr.IP + } else { + sourceAddr = []byte(remoteAddr.String()) + } + return &cryptoSetupServer{ connID: connID, sourceAddr: sourceAddr, @@ -263,7 +271,8 @@ func (h *cryptoSetupServer) isInchoateCHLO(cryptoData map[Tag][]byte, cert []byt if crypto.HashCert(cert) != xlct { return true } - if err := h.scfg.stkSource.VerifyToken(h.sourceAddr, cryptoData[TagSTK]); err != nil { + stk := cryptoData[TagSTK] + if err := h.scfg.stkSource.VerifyToken(h.sourceAddr, stk); err != nil { utils.Debugf("STK invalid: %s", err.Error()) return true } diff --git a/handshake/crypto_setup_server_test.go b/handshake/crypto_setup_server_test.go index b43b77a6..2cb22450 100644 --- a/handshake/crypto_setup_server_test.go +++ b/handshake/crypto_setup_server_test.go @@ -161,7 +161,6 @@ var _ = Describe("Server Crypto Setup", func() { aeadChanged chan protocol.EncryptionLevel nonce32 []byte versionTag []byte - sourceAddr []byte validSTK []byte aead []byte kexs []byte @@ -171,8 +170,8 @@ var _ = Describe("Server Crypto Setup", func() { BeforeEach(func() { var err error - sourceAddr = net.ParseIP("1.2.3.4") - validSTK, err = mockStkSource{}.NewToken(sourceAddr) + remoteAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234} + validSTK, err = mockStkSource{}.NewToken(remoteAddr.IP) Expect(err).NotTo(HaveOccurred()) expectedInitialNonceLen = 32 expectedFSNonceLen = 64 @@ -192,7 +191,16 @@ var _ = Describe("Server Crypto Setup", func() { version = protocol.SupportedVersions[len(protocol.SupportedVersions)-1] supportedVersions = []protocol.VersionNumber{version, 98, 99} cpm = NewConnectionParamatersManager(protocol.PerspectiveServer, protocol.VersionWhatever) - csInt, err := NewCryptoSetup(protocol.ConnectionID(42), sourceAddr, version, scfg, stream, cpm, supportedVersions, aeadChanged) + csInt, err := NewCryptoSetup( + protocol.ConnectionID(42), + remoteAddr, + version, + scfg, + stream, + cpm, + supportedVersions, + aeadChanged, + ) Expect(err).NotTo(HaveOccurred()) cs = csInt.(*cryptoSetupServer) cs.keyDerivation = mockKeyDerivation @@ -219,6 +227,40 @@ var _ = Describe("Server Crypto Setup", func() { }) }) + Context("source address token", func() { + It("uses the IP address when the remote address is a UDP address", func() { + remoteAddr := &net.UDPAddr{IP: net.IPv4(1, 3, 3, 7), Port: 1337} + cs, err := NewCryptoSetup( + protocol.ConnectionID(42), + remoteAddr, + protocol.VersionWhatever, + scfg, + stream, + cpm, + supportedVersions, + aeadChanged, + ) + Expect(err).ToNot(HaveOccurred()) + Expect(cs.(*cryptoSetupServer).sourceAddr).To(BeEquivalentTo(remoteAddr.IP)) + }) + + It("works with remote address that are not UDP", func() { + remoteAddr := &net.TCPAddr{IP: net.IPv4(1, 3, 3, 7), Port: 1337} + cs, err := NewCryptoSetup( + protocol.ConnectionID(42), + remoteAddr, + protocol.VersionWhatever, + scfg, + stream, + cpm, + supportedVersions, + aeadChanged, + ) + Expect(err).ToNot(HaveOccurred()) + Expect(cs.(*cryptoSetupServer).sourceAddr).To(BeEquivalentTo("1.3.3.7:1337")) + }) + }) + Context("when responding to client messages", func() { var cert []byte var xlct []byte diff --git a/session.go b/session.go index 9eea8e26..2c225aed 100644 --- a/session.go +++ b/session.go @@ -137,12 +137,6 @@ func newSession( s.setup() cryptoStream, _ := s.GetOrOpenStream(1) _, _ = s.AcceptStream() // don't expose the crypto stream - var sourceAddr []byte - if udpAddr, ok := conn.RemoteAddr().(*net.UDPAddr); ok { - sourceAddr = udpAddr.IP - } else { - sourceAddr = []byte(conn.RemoteAddr().String()) - } aeadChanged := make(chan protocol.EncryptionLevel, 2) s.aeadChanged = aeadChanged handshakeChan := make(chan handshakeEvent, 3) @@ -150,7 +144,7 @@ func newSession( var err error s.cryptoSetup, err = newCryptoSetup( connectionID, - sourceAddr, + conn.RemoteAddr(), v, sCfg, cryptoStream, diff --git a/session_test.go b/session_test.go index ce43c1f4..b39ed38c 100644 --- a/session_test.go +++ b/session_test.go @@ -131,8 +131,6 @@ var _ = Describe("Session", func() { cryptoSetup *mockCryptoSetup handshakeChan <-chan handshakeEvent aeadChanged chan<- protocol.EncryptionLevel - - cryptoSetupSourceAddr []byte ) BeforeEach(func() { @@ -141,7 +139,7 @@ var _ = Describe("Session", func() { cryptoSetup = &mockCryptoSetup{} newCryptoSetup = func( _ protocol.ConnectionID, - sourceAddr []byte, + _ net.Addr, _ protocol.VersionNumber, _ *handshake.ServerConfig, _ io.ReadWriter, @@ -149,7 +147,6 @@ var _ = Describe("Session", func() { _ []protocol.VersionNumber, aeadChangedP chan<- protocol.EncryptionLevel, ) (handshake.CryptoSetup, error) { - cryptoSetupSourceAddr = sourceAddr aeadChanged = aeadChangedP return cryptoSetup, nil } @@ -183,36 +180,6 @@ var _ = Describe("Session", func() { Eventually(areSessionsRunning).Should(BeFalse()) }) - Context("source address", func() { - It("uses the IP address if given an UDP connection", func() { - conn := &conn{currentAddr: &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200)[12:], Port: 1337}} - _, _, err := newSession( - conn, - protocol.VersionWhatever, - 0, - scfg, - populateServerConfig(&Config{}), - ) - Expect(err).ToNot(HaveOccurred()) - Expect(cryptoSetupSourceAddr).To(Equal([]byte{192, 168, 100, 200})) - }) - - It("uses the string representation of the remote addresses if not given a UDP connection", func() { - conn := &conn{ - currentAddr: &net.TCPAddr{IP: net.IPv4(192, 168, 100, 200)[12:], Port: 1337}, - } - _, _, err := newSession( - conn, - protocol.VersionWhatever, - 0, - scfg, - populateServerConfig(&Config{}), - ) - Expect(err).ToNot(HaveOccurred()) - Expect(cryptoSetupSourceAddr).To(Equal([]byte("192.168.100.200:1337"))) - }) - }) - Context("when handling stream frames", func() { It("makes new streams", func() { sess.handleStreamFrame(&frames.StreamFrame{