diff --git a/benchmark_test.go b/benchmark_test.go index 9319dee1..d80727dc 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -63,6 +63,7 @@ func (c *linkedConnection) Write(p []byte) error { func (c *linkedConnection) Read(p []byte) (int, net.Addr, error) { panic("not implemented") } func (*linkedConnection) SetCurrentRemoteAddr(addr net.Addr) {} +func (*linkedConnection) LocalAddr() net.Addr { panic("not implemented") } func (*linkedConnection) RemoteAddr() net.Addr { return &net.UDPAddr{} } func (c *linkedConnection) Close() error { return nil } diff --git a/conn.go b/conn.go index 0fede0ba..700c1471 100644 --- a/conn.go +++ b/conn.go @@ -9,6 +9,7 @@ type connection interface { Write([]byte) error Read([]byte) (int, net.Addr, error) Close() error + LocalAddr() net.Addr RemoteAddr() net.Addr SetCurrentRemoteAddr(net.Addr) } @@ -37,6 +38,10 @@ func (c *conn) SetCurrentRemoteAddr(addr net.Addr) { c.mutex.Unlock() } +func (c *conn) LocalAddr() net.Addr { + return c.pconn.LocalAddr() +} + func (c *conn) RemoteAddr() net.Addr { c.mutex.RLock() addr := c.currentAddr diff --git a/conn_test.go b/conn_test.go index 0417d536..764ca5fd 100644 --- a/conn_test.go +++ b/conn_test.go @@ -82,6 +82,15 @@ var _ = Describe("Connection", func() { Expect(c.RemoteAddr().String()).To(Equal("192.168.100.200:1337")) }) + It("gets the local address", func() { + addr := &net.UDPAddr{ + IP: net.IPv4(192, 168, 0, 1), + Port: 1234, + } + packetConn.addr = addr + Expect(c.LocalAddr()).To(Equal(addr)) + }) + It("changes the remote address", func() { addr := &net.UDPAddr{ IP: net.IPv4(127, 0, 0, 1), diff --git a/h2quic/server_test.go b/h2quic/server_test.go index b1e30c01..e2c99a8f 100644 --- a/h2quic/server_test.go +++ b/h2quic/server_test.go @@ -56,6 +56,9 @@ func (s *mockSession) Close(e error) error { s.closedWithError = e return nil } +func (s *mockSession) LocalAddr() net.Addr { + panic("not implemented") +} func (s *mockSession) RemoteAddr() net.Addr { return &net.UDPAddr{IP: []byte{127, 0, 0, 1}, Port: 42} } diff --git a/interface.go b/interface.go index 0567b89d..e32ef7d7 100644 --- a/interface.go +++ b/interface.go @@ -30,6 +30,8 @@ type Session interface { // OpenStreamSync opens a new QUIC stream, blocking until the peer's concurrent stream limit allows a new stream to be opened. // It always picks the smallest possible stream ID. OpenStreamSync() (Stream, error) + // LocalAddr returns the local address. + LocalAddr() net.Addr // RemoteAddr returns the address of the peer. RemoteAddr() net.Addr // Close closes the connection. The error will be sent to the remote peer in a CONNECTION_CLOSE frame. An error value of nil is allowed and will cause a normal PeerGoingAway to be sent. diff --git a/server_test.go b/server_test.go index 63001269..6cc68f1e 100644 --- a/server_test.go +++ b/server_test.go @@ -42,6 +42,9 @@ func (s *mockSession) OpenStream() (Stream, error) { func (s *mockSession) OpenStreamSync() (Stream, error) { panic("not implemented") } +func (s *mockSession) LocalAddr() net.Addr { + panic("not implemented") +} func (s *mockSession) RemoteAddr() net.Addr { panic("not implemented") } diff --git a/session.go b/session.go index 7454ff7b..0e69155c 100644 --- a/session.go +++ b/session.go @@ -814,6 +814,10 @@ func (s *session) ackAlarmChanged(t time.Time) { s.maybeResetTimer() } +func (s *session) LocalAddr() net.Addr { + return s.conn.LocalAddr() +} + // RemoteAddr returns the net.Addr of the client func (s *session) RemoteAddr() net.Addr { return s.conn.RemoteAddr() diff --git a/session_test.go b/session_test.go index e4c7aa7c..93d6a7b7 100644 --- a/session_test.go +++ b/session_test.go @@ -25,6 +25,7 @@ import ( type mockConnection struct { remoteAddr net.Addr + localAddr net.Addr written [][]byte } @@ -39,8 +40,9 @@ func (m *mockConnection) Read([]byte) (int, net.Addr, error) { panic("not implem func (m *mockConnection) SetCurrentRemoteAddr(addr net.Addr) { m.remoteAddr = addr } -func (*mockConnection) RemoteAddr() net.Addr { return &net.UDPAddr{} } -func (*mockConnection) Close() error { panic("not implemented") } +func (m *mockConnection) LocalAddr() net.Addr { return m.localAddr } +func (m *mockConnection) RemoteAddr() net.Addr { return m.remoteAddr } +func (*mockConnection) Close() error { panic("not implemented") } type mockUnpacker struct { unpackErr error @@ -124,7 +126,9 @@ var _ = Describe("Session", func() { ) BeforeEach(func() { - mconn = &mockConnection{} + mconn = &mockConnection{ + remoteAddr: &net.UDPAddr{}, + } closeCallbackCalled = false certChain := crypto.NewCertChain(testdata.GetTLSConfig()) @@ -1426,4 +1430,16 @@ var _ = Describe("Session", func() { Expect(frames[0].ByteOffset).To(Equal(protocol.ReceiveConnectionFlowControlWindow * 2)) }) }) + + It("returns the local address", func() { + addr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} + mconn.localAddr = addr + Expect(sess.LocalAddr()).To(Equal(addr)) + }) + + It("returns the remote address", func() { + addr := &net.UDPAddr{IP: net.IPv4(1, 2, 7, 1), Port: 7331} + mconn.remoteAddr = addr + Expect(sess.RemoteAddr()).To(Equal(addr)) + }) })