diff --git a/http3/server.go b/http3/server.go index 09ab44961..856ec9d4e 100644 --- a/http3/server.go +++ b/http3/server.go @@ -181,7 +181,7 @@ func (s *Server) handleConn(sess quic.Session) { } go func() { defer ginkgo.GinkgoRecover() - rerr := s.handleRequest(str, decoder, func() { + rerr := s.handleRequest(sess, str, decoder, func() { sess.CloseWithError(quic.ErrorCode(errorFrameUnexpected), "") }) if rerr.err != nil || rerr.streamErr != 0 || rerr.connErr != 0 { @@ -210,7 +210,7 @@ func (s *Server) maxHeaderBytes() uint64 { return uint64(s.Server.MaxHeaderBytes) } -func (s *Server) handleRequest(str quic.Stream, decoder *qpack.Decoder, onFrameError func()) requestError { +func (s *Server) handleRequest(sess quic.Session, str quic.Stream, decoder *qpack.Decoder, onFrameError func()) requestError { frame, err := parseNextFrame(str) if err != nil { return newStreamError(errorRequestIncomplete, err) @@ -236,6 +236,8 @@ func (s *Server) handleRequest(str quic.Stream, decoder *qpack.Decoder, onFrameE // TODO: use the right error code return newStreamError(errorGeneralProtocolError, err) } + + req.RemoteAddr = sess.RemoteAddr().String() req.Body = newRequestBody(str, onFrameError) if s.logger.Debug() { diff --git a/http3/server_test.go b/http3/server_test.go index 8ede3e1d6..418f9fe9b 100644 --- a/http3/server_test.go +++ b/http3/server_test.go @@ -46,6 +46,7 @@ var _ = Describe("Server", func() { var ( qpackDecoder *qpack.Decoder str *mockquic.MockStream + sess *mockquic.MockSession exampleGetRequest *http.Request examplePostRequest *http.Request ) @@ -103,6 +104,10 @@ var _ = Describe("Server", func() { qpackDecoder = qpack.NewDecoder(nil) str = mockquic.NewMockStream(mockCtrl) + + sess = mockquic.NewMockSession(mockCtrl) + addr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} + sess.EXPECT().RemoteAddr().Return(addr).AnyTimes() }) It("calls the HTTP handler function", func() { @@ -117,10 +122,11 @@ var _ = Describe("Server", func() { return len(p), nil }).AnyTimes() - Expect(s.handleRequest(str, qpackDecoder, nil)).To(Equal(requestError{})) + Expect(s.handleRequest(sess, str, qpackDecoder, nil)).To(Equal(requestError{})) var req *http.Request Eventually(requestChan).Should(Receive(&req)) Expect(req.Host).To(Equal("www.example.com")) + Expect(req.RemoteAddr).To(Equal("127.0.0.1:1337")) }) It("returns 200 with an empty handler", func() { @@ -133,7 +139,7 @@ var _ = Describe("Server", func() { return responseBuf.Write(p) }).AnyTimes() - serr := s.handleRequest(str, qpackDecoder, nil) + serr := s.handleRequest(sess, str, qpackDecoder, nil) Expect(serr.err).ToNot(HaveOccurred()) hfs := decodeHeader(responseBuf) Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"})) @@ -152,7 +158,7 @@ var _ = Describe("Server", func() { }).AnyTimes() str.EXPECT().CancelRead(gomock.Any()) - serr := s.handleRequest(str, qpackDecoder, nil) + serr := s.handleRequest(sess, str, qpackDecoder, nil) Expect(serr.err).ToNot(HaveOccurred()) hfs := decodeHeader(responseBuf) Expect(hfs).To(HaveKeyWithValue(":status", []string{"500"})) @@ -162,12 +168,14 @@ var _ = Describe("Server", func() { var sess *mockquic.MockSession BeforeEach(func() { + addr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} sess = mockquic.NewMockSession(mockCtrl) controlStr := mockquic.NewMockStream(mockCtrl) controlStr.EXPECT().Write(gomock.Any()) sess.EXPECT().OpenUniStream().Return(controlStr, nil) sess.EXPECT().AcceptStream(gomock.Any()).Return(str, nil) sess.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done")) + sess.EXPECT().RemoteAddr().Return(addr).AnyTimes() }) It("cancels reading when client sends a body in GET request", func() { @@ -294,7 +302,7 @@ var _ = Describe("Server", func() { }).AnyTimes() str.EXPECT().CancelRead(quic.ErrorCode(errorEarlyResponse)) - serr := s.handleRequest(str, qpackDecoder, nil) + serr := s.handleRequest(sess, str, qpackDecoder, nil) Expect(serr.err).ToNot(HaveOccurred()) Eventually(handlerCalled).Should(BeClosed()) }) @@ -317,7 +325,7 @@ var _ = Describe("Server", func() { }).AnyTimes() str.EXPECT().CancelRead(quic.ErrorCode(errorEarlyResponse)) - serr := s.handleRequest(str, qpackDecoder, nil) + serr := s.handleRequest(sess, str, qpackDecoder, nil) Expect(serr.err).ToNot(HaveOccurred()) Eventually(handlerCalled).Should(BeClosed()) })