diff --git a/benchmark_test.go b/benchmark_test.go index 9342bc98..a62b2dab 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -45,7 +45,7 @@ func (c *linkedConnection) write(p []byte) error { } func (*linkedConnection) setCurrentRemoteAddr(addr interface{}) {} -func (*linkedConnection) IP() net.IP { return nil } +func (*linkedConnection) RemoteAddr() *net.UDPAddr { return &net.UDPAddr{} } func setAEAD(cs *handshake.CryptoSetup, aead crypto.AEAD) { *(*bool)(unsafe.Pointer(reflect.ValueOf(cs).Elem().FieldByName("receivedForwardSecurePacket").UnsafeAddr())) = true diff --git a/h2quic/server.go b/h2quic/server.go index 254786ee..a7c6ae0b 100644 --- a/h2quic/server.go +++ b/h2quic/server.go @@ -22,6 +22,7 @@ import ( type streamCreator interface { GetOrOpenStream(protocol.StreamID) (utils.Stream, error) Close(error) error + RemoteAddr() *net.UDPAddr } // Server is a HTTP2 server listening for QUIC connections. @@ -137,6 +138,8 @@ func (s *Server) handleRequest(session streamCreator, headerStream utils.Stream, return err } + req.RemoteAddr = session.RemoteAddr().String() + if utils.Debug() { utils.Infof("%s %s%s, on data stream %d", req.Method, req.Host, req.RequestURI, h2headersFrame.StreamID) } else { diff --git a/h2quic/server_test.go b/h2quic/server_test.go index 9a81ce19..4d32c3a9 100644 --- a/h2quic/server_test.go +++ b/h2quic/server_test.go @@ -28,8 +28,10 @@ type mockSession struct { func (s *mockSession) GetOrOpenStream(id protocol.StreamID) (utils.Stream, error) { return s.dataStream, nil } - func (s *mockSession) Close(error) error { s.closed = true; return nil } +func (s *mockSession) RemoteAddr() *net.UDPAddr { + return &net.UDPAddr{IP: []byte{127, 0, 0, 1}, Port: 42} +} var _ = Describe("H2 server", func() { certPath := os.Getenv("GOPATH") @@ -67,7 +69,9 @@ var _ = Describe("H2 server", func() { It("handles a sample GET request", func() { var handlerCalled bool s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer GinkgoRecover() Expect(r.Host).To(Equal("www.example.com")) + Expect(r.RemoteAddr).To(Equal("127.0.0.1:42")) handlerCalled = true }) headerStream.Write([]byte{ diff --git a/session.go b/session.go index 318c4d29..1c0c213d 100644 --- a/session.go +++ b/session.go @@ -3,6 +3,7 @@ package quic import ( "errors" "fmt" + "net" "sync/atomic" "time" @@ -128,7 +129,7 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol cryptoStream, _ := session.GetOrOpenStream(1) var err error - session.cryptoSetup, err = handshake.NewCryptoSetup(connectionID, conn.IP(), v, sCfg, cryptoStream, session.connectionParametersManager, session.aeadChanged) + session.cryptoSetup, err = handshake.NewCryptoSetup(connectionID, conn.RemoteAddr().IP, v, sCfg, cryptoStream, session.connectionParametersManager, session.aeadChanged) if err != nil { return nil, err } @@ -652,3 +653,8 @@ func (s *Session) getWindowUpdateFrames() ([]*frames.WindowUpdateFrame, error) { return res, nil } + +// RemoteAddr returns the net.UDPAddr of the client +func (s *Session) RemoteAddr() *net.UDPAddr { + return s.conn.RemoteAddr() +} diff --git a/session_test.go b/session_test.go index 01dcc94b..433b850c 100644 --- a/session_test.go +++ b/session_test.go @@ -35,7 +35,7 @@ func (m *mockConnection) write(p []byte) error { } func (*mockConnection) setCurrentRemoteAddr(addr interface{}) {} -func (*mockConnection) IP() net.IP { return nil } +func (*mockConnection) RemoteAddr() *net.UDPAddr { return &net.UDPAddr{} } type mockUnpacker struct{} diff --git a/udp_conn.go b/udp_conn.go index ea1cc7ed..2c1bafe4 100644 --- a/udp_conn.go +++ b/udp_conn.go @@ -5,7 +5,7 @@ import "net" type connection interface { write([]byte) error setCurrentRemoteAddr(interface{}) - IP() net.IP + RemoteAddr() *net.UDPAddr } type udpConn struct { @@ -24,6 +24,6 @@ func (c *udpConn) setCurrentRemoteAddr(addr interface{}) { c.currentAddr = addr.(*net.UDPAddr) } -func (c *udpConn) IP() net.IP { - return c.currentAddr.IP +func (c *udpConn) RemoteAddr() *net.UDPAddr { + return c.currentAddr }