From f034e8ba192acb228f9feca8d8e49317faf92feb Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sun, 16 Feb 2020 14:10:30 +0700 Subject: [PATCH] set the LocalAddr that is used in the tls.ClientHelloInfo.Conn --- internal/handshake/crypto_setup.go | 6 ++++-- internal/handshake/crypto_setup_test.go | 16 ++++++++++++++++ internal/handshake/qtls.go | 11 +++++++---- session.go | 2 ++ session_test.go | 2 ++ 5 files changed, 31 insertions(+), 6 deletions(-) diff --git a/internal/handshake/crypto_setup.go b/internal/handshake/crypto_setup.go index 1b76d7c43..d2226c707 100644 --- a/internal/handshake/crypto_setup.go +++ b/internal/handshake/crypto_setup.go @@ -122,6 +122,7 @@ func NewCryptoSetupClient( initialStream io.Writer, handshakeStream io.Writer, connID protocol.ConnectionID, + localAddr net.Addr, remoteAddr net.Addr, tp *TransportParameters, runner handshakeRunner, @@ -142,7 +143,7 @@ func NewCryptoSetupClient( logger, protocol.PerspectiveClient, ) - cs.conn = qtls.Client(newConn(remoteAddr), cs.tlsConf) + cs.conn = qtls.Client(newConn(localAddr, remoteAddr), cs.tlsConf) return cs, clientHelloWritten } @@ -151,6 +152,7 @@ func NewCryptoSetupServer( initialStream io.Writer, handshakeStream io.Writer, connID protocol.ConnectionID, + localAddr net.Addr, remoteAddr net.Addr, tp *TransportParameters, runner handshakeRunner, @@ -171,7 +173,7 @@ func NewCryptoSetupServer( logger, protocol.PerspectiveServer, ) - cs.conn = qtls.Server(newConn(remoteAddr), cs.tlsConf) + cs.conn = qtls.Server(newConn(localAddr, remoteAddr), cs.tlsConf) return cs } diff --git a/internal/handshake/crypto_setup_test.go b/internal/handshake/crypto_setup_test.go index ee585adde..56a6723a5 100644 --- a/internal/handshake/crypto_setup_test.go +++ b/internal/handshake/crypto_setup_test.go @@ -89,6 +89,7 @@ var _ = Describe("Crypto Setup TLS", func() { &bytes.Buffer{}, protocol.ConnectionID{}, nil, + nil, &TransportParameters{}, NewMockHandshakeRunner(mockCtrl), tlsConf, @@ -120,6 +121,7 @@ var _ = Describe("Crypto Setup TLS", func() { sHandshakeStream, protocol.ConnectionID{}, nil, + nil, &TransportParameters{}, runner, testdata.GetTLSConfig(), @@ -157,6 +159,7 @@ var _ = Describe("Crypto Setup TLS", func() { sHandshakeStream, protocol.ConnectionID{}, nil, + nil, &TransportParameters{}, runner, testdata.GetTLSConfig(), @@ -197,6 +200,7 @@ var _ = Describe("Crypto Setup TLS", func() { sHandshakeStream, protocol.ConnectionID{}, nil, + nil, &TransportParameters{}, runner, serverConf, @@ -230,6 +234,7 @@ var _ = Describe("Crypto Setup TLS", func() { sHandshakeStream, protocol.ConnectionID{}, nil, + nil, &TransportParameters{}, NewMockHandshakeRunner(mockCtrl), serverConf, @@ -323,6 +328,7 @@ var _ = Describe("Crypto Setup TLS", func() { cHandshakeStream, protocol.ConnectionID{}, nil, + nil, &TransportParameters{}, cRunner, clientConf, @@ -344,6 +350,7 @@ var _ = Describe("Crypto Setup TLS", func() { sHandshakeStream, protocol.ConnectionID{}, nil, + nil, &TransportParameters{StatelessResetToken: &token}, sRunner, serverConf, @@ -396,6 +403,7 @@ var _ = Describe("Crypto Setup TLS", func() { cHandshakeStream, protocol.ConnectionID{}, nil, + nil, &TransportParameters{}, runner, &tls.Config{InsecureSkipVerify: true}, @@ -436,6 +444,7 @@ var _ = Describe("Crypto Setup TLS", func() { cHandshakeStream, protocol.ConnectionID{}, nil, + nil, cTransportParameters, cRunner, clientConf, @@ -458,6 +467,7 @@ var _ = Describe("Crypto Setup TLS", func() { sHandshakeStream, protocol.ConnectionID{}, nil, + nil, sTransportParameters, sRunner, serverConf, @@ -489,6 +499,7 @@ var _ = Describe("Crypto Setup TLS", func() { cHandshakeStream, protocol.ConnectionID{}, nil, + nil, &TransportParameters{}, cRunner, clientConf, @@ -506,6 +517,7 @@ var _ = Describe("Crypto Setup TLS", func() { sHandshakeStream, protocol.ConnectionID{}, nil, + nil, &TransportParameters{}, sRunner, serverConf, @@ -544,6 +556,7 @@ var _ = Describe("Crypto Setup TLS", func() { cHandshakeStream, protocol.ConnectionID{}, nil, + nil, &TransportParameters{}, cRunner, clientConf, @@ -561,6 +574,7 @@ var _ = Describe("Crypto Setup TLS", func() { sHandshakeStream, protocol.ConnectionID{}, nil, + nil, &TransportParameters{}, sRunner, serverConf, @@ -671,6 +685,7 @@ var _ = Describe("Crypto Setup TLS", func() { cHandshakeStream, protocol.ConnectionID{}, nil, + nil, &TransportParameters{}, cRunner, clientConf, @@ -688,6 +703,7 @@ var _ = Describe("Crypto Setup TLS", func() { sHandshakeStream, protocol.ConnectionID{}, nil, + nil, &TransportParameters{}, sRunner, serverConf, diff --git a/internal/handshake/qtls.go b/internal/handshake/qtls.go index 9f25e9869..0798a1d08 100644 --- a/internal/handshake/qtls.go +++ b/internal/handshake/qtls.go @@ -11,11 +11,14 @@ import ( ) type conn struct { - remoteAddr net.Addr + localAddr, remoteAddr net.Addr } -func newConn(remote net.Addr) net.Conn { - return &conn{remoteAddr: remote} +func newConn(local, remote net.Addr) net.Conn { + return &conn{ + localAddr: local, + remoteAddr: remote, + } } var _ net.Conn = &conn{} @@ -24,7 +27,7 @@ func (c *conn) Read([]byte) (int, error) { return 0, nil } func (c *conn) Write([]byte) (int, error) { return 0, nil } func (c *conn) Close() error { return nil } func (c *conn) RemoteAddr() net.Addr { return c.remoteAddr } -func (c *conn) LocalAddr() net.Addr { return nil } +func (c *conn) LocalAddr() net.Addr { return c.localAddr } func (c *conn) SetReadDeadline(time.Time) error { return nil } func (c *conn) SetWriteDeadline(time.Time) error { return nil } func (c *conn) SetDeadline(time.Time) error { return nil } diff --git a/session.go b/session.go index d3dfb0c98..1b065c305 100644 --- a/session.go +++ b/session.go @@ -270,6 +270,7 @@ var newSession = func( initialStream, handshakeStream, clientDestConnID, + conn.LocalAddr(), conn.RemoteAddr(), params, &handshakeRunner{ @@ -372,6 +373,7 @@ var newClientSession = func( initialStream, handshakeStream, destConnID, + conn.LocalAddr(), conn.RemoteAddr(), params, &handshakeRunner{ diff --git a/session_test.go b/session_test.go index dd759e5a1..6e744ded6 100644 --- a/session_test.go +++ b/session_test.go @@ -78,6 +78,7 @@ var _ = Describe("Session", func() { sessionRunner = NewMockSessionRunner(mockCtrl) mconn = NewMockConnection(mockCtrl) mconn.EXPECT().RemoteAddr().Return(&net.UDPAddr{}).Times(2) + mconn.EXPECT().LocalAddr().Return(&net.UDPAddr{}) tokenGenerator, err := handshake.NewTokenGenerator() Expect(err).ToNot(HaveOccurred()) sess = newSession( @@ -1684,6 +1685,7 @@ var _ = Describe("Client Session", func() { mconn = NewMockConnection(mockCtrl) mconn.EXPECT().RemoteAddr().Return(&net.UDPAddr{}).Times(2) + mconn.EXPECT().LocalAddr().Return(&net.UDPAddr{}) if tlsConf == nil { mconn.EXPECT().RemoteAddr().Return(&net.UDPAddr{}) tlsConf = &tls.Config{}