forked from quic-go/quic-go
Merge pull request #698 from twdkeule/implement-closeNotifier
implement the context for the server http.Request
This commit is contained in:
@@ -3,6 +3,7 @@ package h2quic
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"compress/gzip"
|
"compress/gzip"
|
||||||
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
@@ -36,6 +37,7 @@ var _ = Describe("Client", func() {
|
|||||||
client = newClient(hostname, nil, &roundTripperOpts{}, nil)
|
client = newClient(hostname, nil, &roundTripperOpts{}, nil)
|
||||||
Expect(client.hostname).To(Equal(hostname))
|
Expect(client.hostname).To(Equal(hostname))
|
||||||
session = &mockSession{}
|
session = &mockSession{}
|
||||||
|
session.ctx, session.ctxCancel = context.WithCancel(context.Background())
|
||||||
client.session = session
|
client.session = session
|
||||||
|
|
||||||
headerStream = newMockStream(3)
|
headerStream = newMockStream(3)
|
||||||
|
|||||||
@@ -83,7 +83,7 @@ func (w *responseWriter) Write(p []byte) (int, error) {
|
|||||||
|
|
||||||
func (w *responseWriter) Flush() {}
|
func (w *responseWriter) Flush() {}
|
||||||
|
|
||||||
// TODO: Implement a functional CloseNotify method.
|
// This is a NOP. Use http.Request.Context
|
||||||
func (w *responseWriter) CloseNotify() <-chan bool { return make(<-chan bool) }
|
func (w *responseWriter) CloseNotify() <-chan bool { return make(<-chan bool) }
|
||||||
|
|
||||||
// test that we implement http.Flusher
|
// test that we implement http.Flusher
|
||||||
|
|||||||
@@ -25,20 +25,24 @@ type mockStream struct {
|
|||||||
remoteClosed bool
|
remoteClosed bool
|
||||||
|
|
||||||
unblockRead chan struct{}
|
unblockRead chan struct{}
|
||||||
|
ctx context.Context
|
||||||
|
ctxCancel context.CancelFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
func newMockStream(id protocol.StreamID) *mockStream {
|
func newMockStream(id protocol.StreamID) *mockStream {
|
||||||
return &mockStream{
|
s := &mockStream{
|
||||||
id: id,
|
id: id,
|
||||||
unblockRead: make(chan struct{}),
|
unblockRead: make(chan struct{}),
|
||||||
}
|
}
|
||||||
|
s.ctx, s.ctxCancel = context.WithCancel(context.Background())
|
||||||
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *mockStream) Close() error { s.closed = true; return nil }
|
func (s *mockStream) Close() error { s.closed = true; s.ctxCancel(); return nil }
|
||||||
func (s *mockStream) Reset(error) { s.reset = true }
|
func (s *mockStream) Reset(error) { s.reset = true }
|
||||||
func (s *mockStream) CloseRemote(offset protocol.ByteCount) { s.remoteClosed = true }
|
func (s *mockStream) CloseRemote(offset protocol.ByteCount) { s.remoteClosed = true; s.ctxCancel() }
|
||||||
func (s mockStream) StreamID() protocol.StreamID { return s.id }
|
func (s mockStream) StreamID() protocol.StreamID { return s.id }
|
||||||
func (s *mockStream) Context() context.Context { panic("not implemented") }
|
func (s *mockStream) Context() context.Context { return s.ctx }
|
||||||
func (s *mockStream) SetDeadline(time.Time) error { panic("not implemented") }
|
func (s *mockStream) SetDeadline(time.Time) error { panic("not implemented") }
|
||||||
func (s *mockStream) SetReadDeadline(time.Time) error { panic("not implemented") }
|
func (s *mockStream) SetReadDeadline(time.Time) error { panic("not implemented") }
|
||||||
func (s *mockStream) SetWriteDeadline(time.Time) error { panic("not implemented") }
|
func (s *mockStream) SetWriteDeadline(time.Time) error { panic("not implemented") }
|
||||||
|
|||||||
@@ -194,6 +194,7 @@ func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream,
|
|||||||
_, _ = dataStream.Read([]byte{0}) // read the eof
|
_, _ = dataStream.Read([]byte{0}) // read the eof
|
||||||
}
|
}
|
||||||
|
|
||||||
|
req = req.WithContext(dataStream.Context())
|
||||||
reqBody := newRequestBody(dataStream)
|
reqBody := newRequestBody(dataStream)
|
||||||
req.Body = reqBody
|
req.Body = reqBody
|
||||||
|
|
||||||
|
|||||||
@@ -34,6 +34,8 @@ type mockSession struct {
|
|||||||
streamsToOpen []quic.Stream
|
streamsToOpen []quic.Stream
|
||||||
blockOpenStreamSync bool
|
blockOpenStreamSync bool
|
||||||
streamOpenErr error
|
streamOpenErr error
|
||||||
|
ctx context.Context
|
||||||
|
ctxCancel context.CancelFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *mockSession) GetOrOpenStream(id protocol.StreamID) (quic.Stream, error) {
|
func (s *mockSession) GetOrOpenStream(id protocol.StreamID) (quic.Stream, error) {
|
||||||
@@ -57,6 +59,7 @@ func (s *mockSession) OpenStreamSync() (quic.Stream, error) {
|
|||||||
func (s *mockSession) Close(e error) error {
|
func (s *mockSession) Close(e error) error {
|
||||||
s.closed = true
|
s.closed = true
|
||||||
s.closedWithError = e
|
s.closedWithError = e
|
||||||
|
s.ctxCancel()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
func (s *mockSession) LocalAddr() net.Addr {
|
func (s *mockSession) LocalAddr() net.Addr {
|
||||||
@@ -65,7 +68,9 @@ func (s *mockSession) LocalAddr() net.Addr {
|
|||||||
func (s *mockSession) RemoteAddr() net.Addr {
|
func (s *mockSession) RemoteAddr() net.Addr {
|
||||||
return &net.UDPAddr{IP: []byte{127, 0, 0, 1}, Port: 42}
|
return &net.UDPAddr{IP: []byte{127, 0, 0, 1}, Port: 42}
|
||||||
}
|
}
|
||||||
func (s *mockSession) Context() context.Context { panic("not implemented") }
|
func (s *mockSession) Context() context.Context {
|
||||||
|
return s.ctx
|
||||||
|
}
|
||||||
|
|
||||||
var _ = Describe("H2 server", func() {
|
var _ = Describe("H2 server", func() {
|
||||||
var (
|
var (
|
||||||
@@ -84,6 +89,7 @@ var _ = Describe("H2 server", func() {
|
|||||||
dataStream = newMockStream(0)
|
dataStream = newMockStream(0)
|
||||||
close(dataStream.unblockRead)
|
close(dataStream.unblockRead)
|
||||||
session = &mockSession{dataStream: dataStream}
|
session = &mockSession{dataStream: dataStream}
|
||||||
|
session.ctx, session.ctxCancel = context.WithCancel(context.Background())
|
||||||
origQuicListenAddr = quicListenAddr
|
origQuicListenAddr = quicListenAddr
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -247,6 +253,29 @@ var _ = Describe("H2 server", func() {
|
|||||||
err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer)
|
err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer)
|
||||||
Expect(err).To(MatchError("InvalidHeadersStreamData: expected a header frame"))
|
Expect(err).To(MatchError("InvalidHeadersStreamData: expected a header frame"))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
It("Cancels the request context when the datstream is closed", func() {
|
||||||
|
var handlerCalled bool
|
||||||
|
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
defer GinkgoRecover()
|
||||||
|
err := r.Context().Err()
|
||||||
|
Expect(err).To(HaveOccurred())
|
||||||
|
Expect(err.Error()).To(Equal("context canceled"))
|
||||||
|
handlerCalled = true
|
||||||
|
})
|
||||||
|
headerStream.dataToRead.Write([]byte{
|
||||||
|
0x0, 0x0, 0x11, 0x1, 0x5, 0x0, 0x0, 0x0, 0x5,
|
||||||
|
// Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding
|
||||||
|
0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff,
|
||||||
|
})
|
||||||
|
dataStream.Close()
|
||||||
|
err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer)
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
Eventually(func() bool { return handlerCalled }).Should(BeTrue())
|
||||||
|
Expect(dataStream.remoteClosed).To(BeTrue())
|
||||||
|
Expect(dataStream.reset).To(BeFalse())
|
||||||
|
})
|
||||||
|
|
||||||
})
|
})
|
||||||
|
|
||||||
It("handles the header stream", func() {
|
It("handles the header stream", func() {
|
||||||
|
|||||||
Reference in New Issue
Block a user