forked from quic-go/quic-go
29
h2quic/request_body.go
Normal file
29
h2quic/request_body.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package h2quic
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/utils"
|
||||
)
|
||||
|
||||
type requestBody struct {
|
||||
requestRead bool
|
||||
dataStream utils.Stream
|
||||
}
|
||||
|
||||
// make sure the requestBody can be used as a http.Request.Body
|
||||
var _ io.ReadCloser = &requestBody{}
|
||||
|
||||
func newRequestBody(stream utils.Stream) *requestBody {
|
||||
return &requestBody{dataStream: stream}
|
||||
}
|
||||
|
||||
func (b *requestBody) Read(p []byte) (int, error) {
|
||||
b.requestRead = true
|
||||
return b.dataStream.Read(p)
|
||||
}
|
||||
|
||||
func (b *requestBody) Close() error {
|
||||
// stream's Close() closes the write side, not the read side
|
||||
return nil
|
||||
}
|
||||
39
h2quic/request_body_test.go
Normal file
39
h2quic/request_body_test.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package h2quic
|
||||
|
||||
import (
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Request body", func() {
|
||||
var (
|
||||
stream *mockStream
|
||||
rb *requestBody
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
stream = &mockStream{}
|
||||
stream.Write([]byte("foobar")) // provides data to be read
|
||||
rb = newRequestBody(stream)
|
||||
})
|
||||
|
||||
It("reads from the stream", func() {
|
||||
b := make([]byte, 10)
|
||||
n, _ := stream.Read(b)
|
||||
Expect(n).To(Equal(6))
|
||||
Expect(b[0:6]).To(Equal([]byte("foobar")))
|
||||
})
|
||||
|
||||
It("saves if the stream was read from", func() {
|
||||
Expect(rb.requestRead).To(BeFalse())
|
||||
rb.Read(make([]byte, 1))
|
||||
Expect(rb.requestRead).To(BeTrue())
|
||||
})
|
||||
|
||||
It("doesn't close the stream when closing the request body", func() {
|
||||
Expect(stream.closed).To(BeFalse())
|
||||
err := rb.Close()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(stream.closed).To(BeFalse())
|
||||
})
|
||||
})
|
||||
@@ -14,10 +14,12 @@ type mockStream struct {
|
||||
id protocol.StreamID
|
||||
bytes.Buffer
|
||||
remoteClosed bool
|
||||
reset bool
|
||||
closed bool
|
||||
}
|
||||
|
||||
func (mockStream) Close() error { return nil }
|
||||
func (mockStream) Reset(error) { panic("not implemented") }
|
||||
func (s *mockStream) Close() error { s.closed = true; return nil }
|
||||
func (s *mockStream) Reset(error) { s.reset = true }
|
||||
func (s *mockStream) CloseRemote(offset protocol.ByteCount) { s.remoteClosed = true }
|
||||
func (s mockStream) StreamID() protocol.StreamID { return s.id }
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"runtime"
|
||||
@@ -155,13 +154,14 @@ func (s *Server) handleRequest(session streamCreator, headerStream utils.Stream,
|
||||
return err
|
||||
}
|
||||
|
||||
var streamEnded bool
|
||||
if h2headersFrame.StreamEnded() {
|
||||
dataStream.CloseRemote(0)
|
||||
streamEnded = true
|
||||
_, _ = dataStream.Read([]byte{0}) // read the eof
|
||||
}
|
||||
|
||||
// stream's Close() closes the write side, not the read side
|
||||
req.Body = ioutil.NopCloser(dataStream)
|
||||
req.Body = newRequestBody(dataStream)
|
||||
|
||||
responseWriter := newResponseWriter(headerStream, headerStreamMutex, dataStream, protocol.StreamID(h2headersFrame.StreamID))
|
||||
|
||||
@@ -190,6 +190,9 @@ func (s *Server) handleRequest(session streamCreator, headerStream utils.Stream,
|
||||
responseWriter.WriteHeader(200)
|
||||
}
|
||||
if responseWriter.dataStream != nil {
|
||||
if !streamEnded && !req.Body.(*requestBody).requestRead {
|
||||
responseWriter.dataStream.Reset(nil)
|
||||
}
|
||||
responseWriter.dataStream.Close()
|
||||
}
|
||||
if s.CloseAfterFirstRequest {
|
||||
|
||||
@@ -79,6 +79,7 @@ var _ = Describe("H2 server", func() {
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Eventually(func() bool { return handlerCalled }).Should(BeTrue())
|
||||
Expect(dataStream.remoteClosed).To(BeTrue())
|
||||
Expect(dataStream.reset).To(BeFalse())
|
||||
})
|
||||
|
||||
It("returns 200 with an empty handler", func() {
|
||||
@@ -111,7 +112,7 @@ var _ = Describe("H2 server", func() {
|
||||
}).Should(Equal([]byte{0x0, 0x0, 0x1, 0x1, 0x4, 0x0, 0x0, 0x0, 0x5, 0x8e})) // 0x82 is 500
|
||||
})
|
||||
|
||||
It("does not close the dataStream when end of stream is not set", func() {
|
||||
It("resets the dataStream when client sends a body in GET request", func() {
|
||||
var handlerCalled bool
|
||||
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
Expect(r.Host).To(Equal("www.example.com"))
|
||||
@@ -126,6 +127,41 @@ var _ = Describe("H2 server", func() {
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Eventually(func() bool { return handlerCalled }).Should(BeTrue())
|
||||
Expect(dataStream.remoteClosed).To(BeFalse())
|
||||
Expect(dataStream.reset).To(BeTrue())
|
||||
})
|
||||
|
||||
It("resets the dataStream when the body of POST request is not read", func() {
|
||||
var handlerCalled bool
|
||||
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
Expect(r.Host).To(Equal("www.example.com"))
|
||||
Expect(r.Method).To(Equal("POST"))
|
||||
handlerCalled = true
|
||||
})
|
||||
headerStream.Write([]byte{0x0, 0x0, 0x20, 0x1, 0x24, 0x0, 0x0, 0x0, 0x5, 0x0, 0x0, 0x0, 0x0, 0xff, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff, 0x83, 0x84, 0x87, 0x5c, 0x1, 0x37, 0x7a, 0x85, 0xed, 0x69, 0x88, 0xb4, 0xc7})
|
||||
err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Eventually(func() bool { return handlerCalled }).Should(BeTrue())
|
||||
Expect(dataStream.remoteClosed).To(BeFalse())
|
||||
Expect(dataStream.reset).To(BeTrue())
|
||||
})
|
||||
|
||||
It("closes the dataStream if the body of POST request was read", func() {
|
||||
var handlerCalled bool
|
||||
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
Expect(r.Host).To(Equal("www.example.com"))
|
||||
Expect(r.Method).To(Equal("POST"))
|
||||
handlerCalled = true
|
||||
// read the request body
|
||||
b := make([]byte, 1000)
|
||||
n, _ := r.Body.Read(b)
|
||||
Expect(n).ToNot(BeZero())
|
||||
})
|
||||
headerStream.Write([]byte{0x0, 0x0, 0x20, 0x1, 0x24, 0x0, 0x0, 0x0, 0x5, 0x0, 0x0, 0x0, 0x0, 0xff, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff, 0x83, 0x84, 0x87, 0x5c, 0x1, 0x37, 0x7a, 0x85, 0xed, 0x69, 0x88, 0xb4, 0xc7})
|
||||
dataStream.Write([]byte("foo=bar"))
|
||||
err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Eventually(func() bool { return handlerCalled }).Should(BeTrue())
|
||||
Expect(dataStream.reset).To(BeFalse())
|
||||
})
|
||||
|
||||
It("errors when non-header frames are received", func() {
|
||||
|
||||
Reference in New Issue
Block a user