remove stream ID from OpenStream() method

This commit is contained in:
Marten Seemann
2017-02-09 17:05:58 +07:00
parent 8cd1e4484c
commit f47142eaac
9 changed files with 424 additions and 423 deletions

View File

@@ -21,7 +21,7 @@ import (
)
type quicClient interface {
OpenStream(protocol.StreamID) (utils.Stream, error)
OpenStream() (utils.Stream, error)
Close(error) error
Listen() error
}
@@ -36,11 +36,10 @@ type Client struct {
hostname string
encryptionLevel protocol.EncryptionLevel
client quicClient
headerStream utils.Stream
headerErr *qerr.QuicError
highestOpenedStream protocol.StreamID
requestWriter *requestWriter
client quicClient
headerStream utils.Stream
headerErr *qerr.QuicError
requestWriter *requestWriter
responses map[protocol.StreamID]chan *http.Response
}
@@ -50,10 +49,9 @@ var _ h2quicClient = &Client{}
// NewClient creates a new client
func NewClient(t *QuicRoundTripper, tlsConfig *tls.Config, hostname string) (*Client, error) {
c := &Client{
t: t,
hostname: authorityAddr("https", hostname),
highestOpenedStream: 3,
responses: make(map[protocol.StreamID]chan *http.Response),
t: t,
hostname: authorityAddr("https", hostname),
responses: make(map[protocol.StreamID]chan *http.Response),
}
c.cryptoChangedCond = sync.Cond{L: &c.mutex}
@@ -88,10 +86,13 @@ func (c *Client) cryptoChangeCallback(isForwardSecure bool) {
func (c *Client) versionNegotiateCallback() error {
var err error
// once the version has been negotiated, open the header stream
c.headerStream, err = c.client.OpenStream(3)
c.headerStream, err = c.client.OpenStream()
if err != nil {
return err
}
if c.headerStream.StreamID() != 3 {
return errors.New("h2quic Client BUG: StreamID of Header Stream is not 3")
}
c.requestWriter = newRequestWriter(c.headerStream)
go c.handleHeaderStream()
return nil
@@ -160,21 +161,18 @@ func (c *Client) Do(req *http.Request) (*http.Response, error) {
hasBody := (req.Body != nil)
c.mutex.Lock()
c.highestOpenedStream += 2
dataStreamID := c.highestOpenedStream
for c.encryptionLevel != protocol.EncryptionForwardSecure {
c.cryptoChangedCond.Wait()
}
hdrChan := make(chan *http.Response)
c.responses[dataStreamID] = hdrChan
c.mutex.Unlock()
// TODO: think about what to do with a TooManyOpenStreams error. Wait and retry?
dataStream, err := c.client.OpenStream(dataStreamID)
dataStream, err := c.client.OpenStream()
if err != nil {
c.Close(err)
return nil, err
}
c.responses[dataStream.StreamID()] = hdrChan
c.mutex.Unlock()
var requestedGzip bool
if !c.t.disableCompression() && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" && req.Method != "HEAD" {
@@ -182,7 +180,7 @@ func (c *Client) Do(req *http.Request) (*http.Response, error) {
}
// TODO: add support for trailers
endStream := !hasBody
err = c.requestWriter.WriteRequest(req, dataStreamID, endStream, requestedGzip)
err = c.requestWriter.WriteRequest(req, dataStream.StreamID(), endStream, requestedGzip)
if err != nil {
c.Close(err)
return nil, err
@@ -209,7 +207,7 @@ func (c *Client) Do(req *http.Request) (*http.Response, error) {
case res = <-hdrChan:
receivedResponse = true
c.mutex.Lock()
delete(c.responses, dataStreamID)
delete(c.responses, dataStream.StreamID())
c.mutex.Unlock()
if res == nil { // an error occured on the header stream
c.Close(c.headerErr)

View File

@@ -18,25 +18,25 @@ import (
)
type mockQuicClient struct {
streams map[protocol.StreamID]*mockStream
closeErr error
nextStream protocol.StreamID
streams map[protocol.StreamID]*mockStream
closeErr error
}
func (m *mockQuicClient) Close(e error) error { m.closeErr = e; return nil }
func (m *mockQuicClient) Listen() error { panic("not implemented") }
func (m *mockQuicClient) OpenStream(id protocol.StreamID) (utils.Stream, error) {
_, ok := m.streams[id]
if ok {
panic("Stream already exists")
}
func (m *mockQuicClient) OpenStream() (utils.Stream, error) {
id := m.nextStream
ms := &mockStream{id: id}
m.streams[id] = ms
m.nextStream += 2
return ms, nil
}
func newMockQuicClient() *mockQuicClient {
return &mockQuicClient{
streams: make(map[protocol.StreamID]*mockStream),
streams: make(map[protocol.StreamID]*mockStream),
nextStream: 5,
}
}
@@ -77,6 +77,7 @@ var _ = Describe("Client", func() {
// delete the headerStream openend in the BeforeEach
client.headerStream = nil
delete(qClient.streams, 3)
qClient.nextStream = 3
Expect(client.headerStream).To(BeNil()) // header stream not yet opened
// now start the actual test
err := client.versionNegotiateCallback()
@@ -133,7 +134,6 @@ var _ = Describe("Client", func() {
}()
Eventually(func() []byte { return headerStream.dataWritten.Bytes() }).ShouldNot(BeEmpty())
Expect(client.highestOpenedStream).To(Equal(protocol.StreamID(5)))
Expect(qClient.streams).Should(HaveKey(protocol.StreamID(5)))
Expect(client.responses).To(HaveKey(protocol.StreamID(5)))
rsp := &http.Response{