forked from quic-go/quic-go
Merge pull request #2298 from MarinX/http-remoteaddr-fix
set remote address on http request
This commit is contained in:
@@ -181,7 +181,7 @@ func (s *Server) handleConn(sess quic.Session) {
|
|||||||
}
|
}
|
||||||
go func() {
|
go func() {
|
||||||
defer ginkgo.GinkgoRecover()
|
defer ginkgo.GinkgoRecover()
|
||||||
rerr := s.handleRequest(str, decoder, func() {
|
rerr := s.handleRequest(sess, str, decoder, func() {
|
||||||
sess.CloseWithError(quic.ErrorCode(errorFrameUnexpected), "")
|
sess.CloseWithError(quic.ErrorCode(errorFrameUnexpected), "")
|
||||||
})
|
})
|
||||||
if rerr.err != nil || rerr.streamErr != 0 || rerr.connErr != 0 {
|
if rerr.err != nil || rerr.streamErr != 0 || rerr.connErr != 0 {
|
||||||
@@ -210,7 +210,7 @@ func (s *Server) maxHeaderBytes() uint64 {
|
|||||||
return uint64(s.Server.MaxHeaderBytes)
|
return uint64(s.Server.MaxHeaderBytes)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handleRequest(str quic.Stream, decoder *qpack.Decoder, onFrameError func()) requestError {
|
func (s *Server) handleRequest(sess quic.Session, str quic.Stream, decoder *qpack.Decoder, onFrameError func()) requestError {
|
||||||
frame, err := parseNextFrame(str)
|
frame, err := parseNextFrame(str)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return newStreamError(errorRequestIncomplete, err)
|
return newStreamError(errorRequestIncomplete, err)
|
||||||
@@ -236,6 +236,8 @@ func (s *Server) handleRequest(str quic.Stream, decoder *qpack.Decoder, onFrameE
|
|||||||
// TODO: use the right error code
|
// TODO: use the right error code
|
||||||
return newStreamError(errorGeneralProtocolError, err)
|
return newStreamError(errorGeneralProtocolError, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
req.RemoteAddr = sess.RemoteAddr().String()
|
||||||
req.Body = newRequestBody(str, onFrameError)
|
req.Body = newRequestBody(str, onFrameError)
|
||||||
|
|
||||||
if s.logger.Debug() {
|
if s.logger.Debug() {
|
||||||
|
|||||||
@@ -46,6 +46,7 @@ var _ = Describe("Server", func() {
|
|||||||
var (
|
var (
|
||||||
qpackDecoder *qpack.Decoder
|
qpackDecoder *qpack.Decoder
|
||||||
str *mockquic.MockStream
|
str *mockquic.MockStream
|
||||||
|
sess *mockquic.MockSession
|
||||||
exampleGetRequest *http.Request
|
exampleGetRequest *http.Request
|
||||||
examplePostRequest *http.Request
|
examplePostRequest *http.Request
|
||||||
)
|
)
|
||||||
@@ -103,6 +104,10 @@ var _ = Describe("Server", func() {
|
|||||||
|
|
||||||
qpackDecoder = qpack.NewDecoder(nil)
|
qpackDecoder = qpack.NewDecoder(nil)
|
||||||
str = mockquic.NewMockStream(mockCtrl)
|
str = mockquic.NewMockStream(mockCtrl)
|
||||||
|
|
||||||
|
sess = mockquic.NewMockSession(mockCtrl)
|
||||||
|
addr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
|
||||||
|
sess.EXPECT().RemoteAddr().Return(addr).AnyTimes()
|
||||||
})
|
})
|
||||||
|
|
||||||
It("calls the HTTP handler function", func() {
|
It("calls the HTTP handler function", func() {
|
||||||
@@ -117,10 +122,11 @@ var _ = Describe("Server", func() {
|
|||||||
return len(p), nil
|
return len(p), nil
|
||||||
}).AnyTimes()
|
}).AnyTimes()
|
||||||
|
|
||||||
Expect(s.handleRequest(str, qpackDecoder, nil)).To(Equal(requestError{}))
|
Expect(s.handleRequest(sess, str, qpackDecoder, nil)).To(Equal(requestError{}))
|
||||||
var req *http.Request
|
var req *http.Request
|
||||||
Eventually(requestChan).Should(Receive(&req))
|
Eventually(requestChan).Should(Receive(&req))
|
||||||
Expect(req.Host).To(Equal("www.example.com"))
|
Expect(req.Host).To(Equal("www.example.com"))
|
||||||
|
Expect(req.RemoteAddr).To(Equal("127.0.0.1:1337"))
|
||||||
})
|
})
|
||||||
|
|
||||||
It("returns 200 with an empty handler", func() {
|
It("returns 200 with an empty handler", func() {
|
||||||
@@ -133,7 +139,7 @@ var _ = Describe("Server", func() {
|
|||||||
return responseBuf.Write(p)
|
return responseBuf.Write(p)
|
||||||
}).AnyTimes()
|
}).AnyTimes()
|
||||||
|
|
||||||
serr := s.handleRequest(str, qpackDecoder, nil)
|
serr := s.handleRequest(sess, str, qpackDecoder, nil)
|
||||||
Expect(serr.err).ToNot(HaveOccurred())
|
Expect(serr.err).ToNot(HaveOccurred())
|
||||||
hfs := decodeHeader(responseBuf)
|
hfs := decodeHeader(responseBuf)
|
||||||
Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"}))
|
Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"}))
|
||||||
@@ -152,7 +158,7 @@ var _ = Describe("Server", func() {
|
|||||||
}).AnyTimes()
|
}).AnyTimes()
|
||||||
str.EXPECT().CancelRead(gomock.Any())
|
str.EXPECT().CancelRead(gomock.Any())
|
||||||
|
|
||||||
serr := s.handleRequest(str, qpackDecoder, nil)
|
serr := s.handleRequest(sess, str, qpackDecoder, nil)
|
||||||
Expect(serr.err).ToNot(HaveOccurred())
|
Expect(serr.err).ToNot(HaveOccurred())
|
||||||
hfs := decodeHeader(responseBuf)
|
hfs := decodeHeader(responseBuf)
|
||||||
Expect(hfs).To(HaveKeyWithValue(":status", []string{"500"}))
|
Expect(hfs).To(HaveKeyWithValue(":status", []string{"500"}))
|
||||||
@@ -162,12 +168,14 @@ var _ = Describe("Server", func() {
|
|||||||
var sess *mockquic.MockSession
|
var sess *mockquic.MockSession
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
|
addr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
|
||||||
sess = mockquic.NewMockSession(mockCtrl)
|
sess = mockquic.NewMockSession(mockCtrl)
|
||||||
controlStr := mockquic.NewMockStream(mockCtrl)
|
controlStr := mockquic.NewMockStream(mockCtrl)
|
||||||
controlStr.EXPECT().Write(gomock.Any())
|
controlStr.EXPECT().Write(gomock.Any())
|
||||||
sess.EXPECT().OpenUniStream().Return(controlStr, nil)
|
sess.EXPECT().OpenUniStream().Return(controlStr, nil)
|
||||||
sess.EXPECT().AcceptStream(gomock.Any()).Return(str, nil)
|
sess.EXPECT().AcceptStream(gomock.Any()).Return(str, nil)
|
||||||
sess.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done"))
|
sess.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done"))
|
||||||
|
sess.EXPECT().RemoteAddr().Return(addr).AnyTimes()
|
||||||
})
|
})
|
||||||
|
|
||||||
It("cancels reading when client sends a body in GET request", func() {
|
It("cancels reading when client sends a body in GET request", func() {
|
||||||
@@ -294,7 +302,7 @@ var _ = Describe("Server", func() {
|
|||||||
}).AnyTimes()
|
}).AnyTimes()
|
||||||
str.EXPECT().CancelRead(quic.ErrorCode(errorEarlyResponse))
|
str.EXPECT().CancelRead(quic.ErrorCode(errorEarlyResponse))
|
||||||
|
|
||||||
serr := s.handleRequest(str, qpackDecoder, nil)
|
serr := s.handleRequest(sess, str, qpackDecoder, nil)
|
||||||
Expect(serr.err).ToNot(HaveOccurred())
|
Expect(serr.err).ToNot(HaveOccurred())
|
||||||
Eventually(handlerCalled).Should(BeClosed())
|
Eventually(handlerCalled).Should(BeClosed())
|
||||||
})
|
})
|
||||||
@@ -317,7 +325,7 @@ var _ = Describe("Server", func() {
|
|||||||
}).AnyTimes()
|
}).AnyTimes()
|
||||||
str.EXPECT().CancelRead(quic.ErrorCode(errorEarlyResponse))
|
str.EXPECT().CancelRead(quic.ErrorCode(errorEarlyResponse))
|
||||||
|
|
||||||
serr := s.handleRequest(str, qpackDecoder, nil)
|
serr := s.handleRequest(sess, str, qpackDecoder, nil)
|
||||||
Expect(serr.err).ToNot(HaveOccurred())
|
Expect(serr.err).ToNot(HaveOccurred())
|
||||||
Eventually(handlerCalled).Should(BeClosed())
|
Eventually(handlerCalled).Should(BeClosed())
|
||||||
})
|
})
|
||||||
|
|||||||
Reference in New Issue
Block a user