forked from quic-go/quic-go
@@ -2,7 +2,6 @@ package h2quic
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
@@ -12,12 +11,7 @@ import (
|
||||
"golang.org/x/net/http2/hpack"
|
||||
)
|
||||
|
||||
type streamCreator interface {
|
||||
NewStream(protocol.StreamID) (utils.Stream, error)
|
||||
}
|
||||
|
||||
type responseWriter struct {
|
||||
session streamCreator
|
||||
dataStreamID protocol.StreamID
|
||||
headerStream utils.Stream
|
||||
dataStream utils.Stream
|
||||
@@ -26,12 +20,12 @@ type responseWriter struct {
|
||||
headerWritten bool
|
||||
}
|
||||
|
||||
func newResponseWriter(headerStream utils.Stream, dataStreamID protocol.StreamID, session streamCreator) *responseWriter {
|
||||
func newResponseWriter(headerStream, dataStream utils.Stream, dataStreamID protocol.StreamID) *responseWriter {
|
||||
return &responseWriter{
|
||||
header: http.Header{},
|
||||
headerStream: headerStream,
|
||||
dataStream: dataStream,
|
||||
dataStreamID: dataStreamID,
|
||||
session: session,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -66,16 +60,5 @@ func (w *responseWriter) Write(p []byte) (int, error) {
|
||||
if !w.headerWritten {
|
||||
w.WriteHeader(200)
|
||||
}
|
||||
|
||||
if len(p) != 0 {
|
||||
if w.dataStream == nil {
|
||||
var err error
|
||||
w.dataStream, err = w.session.NewStream(w.dataStreamID)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("error creating data stream: %s", err.Error())
|
||||
}
|
||||
}
|
||||
return w.dataStream.Write(p)
|
||||
}
|
||||
return 0, nil
|
||||
return w.dataStream.Write(p)
|
||||
}
|
||||
|
||||
@@ -4,8 +4,6 @@ import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
"github.com/lucas-clemente/quic-go/utils"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
@@ -16,28 +14,17 @@ type mockStream struct {
|
||||
|
||||
func (mockStream) Close() error { return nil }
|
||||
|
||||
type mockSession struct {
|
||||
stream *mockStream
|
||||
}
|
||||
|
||||
func (s *mockSession) NewStream(id protocol.StreamID) (utils.Stream, error) {
|
||||
Expect(id).To(Equal(protocol.StreamID(5)))
|
||||
return s.stream, nil
|
||||
}
|
||||
|
||||
var _ = Describe("Response Writer", func() {
|
||||
var (
|
||||
w *responseWriter
|
||||
headerStream *mockStream
|
||||
dataStream *mockStream
|
||||
s *mockSession
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
headerStream = &mockStream{}
|
||||
dataStream = &mockStream{}
|
||||
s = &mockSession{stream: dataStream}
|
||||
w = newResponseWriter(headerStream, 5, s)
|
||||
w = newResponseWriter(headerStream, dataStream, 5)
|
||||
})
|
||||
|
||||
It("writes status", func() {
|
||||
|
||||
@@ -3,6 +3,7 @@ package h2quic
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
@@ -12,6 +13,10 @@ import (
|
||||
"golang.org/x/net/http2/hpack"
|
||||
)
|
||||
|
||||
type streamCreator interface {
|
||||
GetOrCreateStream(protocol.StreamID) (utils.Stream, error)
|
||||
}
|
||||
|
||||
// Server is a HTTP2 server listening for QUIC connections
|
||||
type Server struct {
|
||||
server *quic.Server
|
||||
@@ -23,7 +28,7 @@ func NewServer(tlsConfig *tls.Config) (*Server, error) {
|
||||
s := &Server{}
|
||||
|
||||
var err error
|
||||
s.server, err = quic.NewServer(tlsConfig, s.handleStream)
|
||||
s.server, err = quic.NewServer(tlsConfig, s.handleStreamCb)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -41,7 +46,11 @@ func (s *Server) ListenAndServe(addr string, handler http.Handler) error {
|
||||
return s.server.ListenAndServe(addr)
|
||||
}
|
||||
|
||||
func (s *Server) handleStream(session *quic.Session, headerStream utils.Stream) {
|
||||
func (s *Server) handleStreamCb(session *quic.Session, headerStream utils.Stream) {
|
||||
s.handleStream(session, headerStream)
|
||||
}
|
||||
|
||||
func (s *Server) handleStream(session streamCreator, headerStream utils.Stream) {
|
||||
hpackDecoder := hpack.NewDecoder(4096, nil)
|
||||
h2framer := http2.NewFramer(nil, headerStream)
|
||||
|
||||
@@ -55,7 +64,7 @@ func (s *Server) handleStream(session *quic.Session, headerStream utils.Stream)
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *Server) handleRequest(session *quic.Session, headerStream utils.Stream, hpackDecoder *hpack.Decoder, h2framer *http2.Framer) error {
|
||||
func (s *Server) handleRequest(session streamCreator, headerStream utils.Stream, hpackDecoder *hpack.Decoder, h2framer *http2.Framer) error {
|
||||
h2frame, err := h2framer.ReadFrame()
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -76,13 +85,16 @@ func (s *Server) handleRequest(session *quic.Session, headerStream utils.Stream,
|
||||
}
|
||||
utils.Infof("%s %s%s", req.Method, req.Host, req.RequestURI)
|
||||
|
||||
responseWriter := &responseWriter{
|
||||
header: http.Header{},
|
||||
headerStream: headerStream,
|
||||
dataStreamID: protocol.StreamID(h2headersFrame.StreamID),
|
||||
session: session,
|
||||
dataStream, err := session.GetOrCreateStream(protocol.StreamID(h2headersFrame.StreamID))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// stream's Close() closes the write side, not the read side
|
||||
req.Body = ioutil.NopCloser(dataStream)
|
||||
|
||||
responseWriter := newResponseWriter(headerStream, dataStream, protocol.StreamID(h2headersFrame.StreamID))
|
||||
|
||||
go func() {
|
||||
s.handler.ServeHTTP(responseWriter, req)
|
||||
if responseWriter.dataStream != nil {
|
||||
|
||||
@@ -6,15 +6,24 @@ import (
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/http2/hpack"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
"github.com/lucas-clemente/quic-go/testdata"
|
||||
"github.com/lucas-clemente/quic-go/utils"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
type mockSession struct{}
|
||||
|
||||
func (s *mockSession) GetOrCreateStream(id protocol.StreamID) (utils.Stream, error) {
|
||||
return &mockStream{}, nil
|
||||
}
|
||||
|
||||
var _ = Describe("Response Writer", func() {
|
||||
var (
|
||||
s *Server
|
||||
s *Server
|
||||
session *mockSession
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
@@ -22,6 +31,7 @@ var _ = Describe("Response Writer", func() {
|
||||
s, err = NewServer(testdata.GetTLSConfig())
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(s).NotTo(BeNil())
|
||||
session = &mockSession{}
|
||||
})
|
||||
|
||||
It("uses default handler", func() {
|
||||
@@ -63,11 +73,10 @@ var _ = Describe("Response Writer", func() {
|
||||
// Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding
|
||||
0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff,
|
||||
})
|
||||
err := s.handleRequest(nil, headerStream, hpackDecoder, h2framer)
|
||||
err := s.handleRequest(session, headerStream, hpackDecoder, h2framer)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Eventually(func() bool { return handlerCalled }).Should(BeTrue())
|
||||
})
|
||||
|
||||
})
|
||||
|
||||
It("handles the header stream", func() {
|
||||
@@ -82,7 +91,7 @@ var _ = Describe("Response Writer", func() {
|
||||
// Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding
|
||||
0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff,
|
||||
})
|
||||
s.handleStream(nil, headerStream)
|
||||
s.handleStream(session, headerStream)
|
||||
Eventually(func() bool { return handlerCalled }).Should(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user