Add test for request context

This commit is contained in:
Thomas De Keulenaer
2017-08-27 20:04:58 +02:00
parent 51be240488
commit 4f6eef8b08
3 changed files with 38 additions and 5 deletions

View File

@@ -3,6 +3,7 @@ package h2quic
import (
"bytes"
"compress/gzip"
"context"
"crypto/tls"
"errors"
"io"
@@ -36,6 +37,7 @@ var _ = Describe("Client", func() {
client = newClient(hostname, nil, &roundTripperOpts{}, nil)
Expect(client.hostname).To(Equal(hostname))
session = &mockSession{}
session.ctx, session.ctxCancel = context.WithCancel(context.Background())
client.session = session
headerStream = newMockStream(3)

View File

@@ -25,20 +25,24 @@ type mockStream struct {
remoteClosed bool
unblockRead chan struct{}
ctx context.Context
ctxCancel context.CancelFunc
}
func newMockStream(id protocol.StreamID) *mockStream {
return &mockStream{
s := &mockStream{
id: id,
unblockRead: make(chan struct{}),
}
s.ctx, s.ctxCancel = context.WithCancel(context.Background())
return s
}
func (s *mockStream) Close() error { s.closed = true; return nil }
func (s *mockStream) Close() error { s.closed = true; s.ctxCancel(); return nil }
func (s *mockStream) Reset(error) { s.reset = true }
func (s *mockStream) CloseRemote(offset protocol.ByteCount) { s.remoteClosed = true }
func (s *mockStream) CloseRemote(offset protocol.ByteCount) { s.remoteClosed = true; s.ctxCancel() }
func (s mockStream) StreamID() protocol.StreamID { return s.id }
func (s *mockStream) Context() context.Context { return context.Background() }
func (s *mockStream) Context() context.Context { return s.ctx }
func (s *mockStream) SetDeadline(time.Time) error { panic("not implemented") }
func (s *mockStream) SetReadDeadline(time.Time) error { panic("not implemented") }
func (s *mockStream) SetWriteDeadline(time.Time) error { panic("not implemented") }

View File

@@ -34,6 +34,8 @@ type mockSession struct {
streamsToOpen []quic.Stream
blockOpenStreamSync bool
streamOpenErr error
ctx context.Context
ctxCancel context.CancelFunc
}
func (s *mockSession) GetOrOpenStream(id protocol.StreamID) (quic.Stream, error) {
@@ -57,6 +59,7 @@ func (s *mockSession) OpenStreamSync() (quic.Stream, error) {
func (s *mockSession) Close(e error) error {
s.closed = true
s.closedWithError = e
s.ctxCancel()
return nil
}
func (s *mockSession) LocalAddr() net.Addr {
@@ -66,7 +69,7 @@ func (s *mockSession) RemoteAddr() net.Addr {
return &net.UDPAddr{IP: []byte{127, 0, 0, 1}, Port: 42}
}
func (s *mockSession) Context() context.Context {
return context.Background()
return s.ctx
}
var _ = Describe("H2 server", func() {
@@ -86,6 +89,7 @@ var _ = Describe("H2 server", func() {
dataStream = newMockStream(0)
close(dataStream.unblockRead)
session = &mockSession{dataStream: dataStream}
session.ctx, session.ctxCancel = context.WithCancel(context.Background())
origQuicListenAddr = quicListenAddr
})
@@ -249,6 +253,29 @@ var _ = Describe("H2 server", func() {
err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer)
Expect(err).To(MatchError("InvalidHeadersStreamData: expected a header frame"))
})
It("Cancels the request context when the datstream is closed", func() {
var handlerCalled bool
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer GinkgoRecover()
err := r.Context().Err()
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(Equal("context canceled"))
handlerCalled = true
})
headerStream.dataToRead.Write([]byte{
0x0, 0x0, 0x11, 0x1, 0x5, 0x0, 0x0, 0x0, 0x5,
// Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding
0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff,
})
dataStream.Close()
err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer)
Expect(err).NotTo(HaveOccurred())
Eventually(func() bool { return handlerCalled }).Should(BeTrue())
Expect(dataStream.remoteClosed).To(BeTrue())
Expect(dataStream.reset).To(BeFalse())
})
})
It("handles the header stream", func() {