forked from quic-go/quic-go
remove stream ID from OpenStream() method
This commit is contained in:
@@ -21,7 +21,7 @@ import (
|
||||
)
|
||||
|
||||
type quicClient interface {
|
||||
OpenStream(protocol.StreamID) (utils.Stream, error)
|
||||
OpenStream() (utils.Stream, error)
|
||||
Close(error) error
|
||||
Listen() error
|
||||
}
|
||||
@@ -36,11 +36,10 @@ type Client struct {
|
||||
hostname string
|
||||
encryptionLevel protocol.EncryptionLevel
|
||||
|
||||
client quicClient
|
||||
headerStream utils.Stream
|
||||
headerErr *qerr.QuicError
|
||||
highestOpenedStream protocol.StreamID
|
||||
requestWriter *requestWriter
|
||||
client quicClient
|
||||
headerStream utils.Stream
|
||||
headerErr *qerr.QuicError
|
||||
requestWriter *requestWriter
|
||||
|
||||
responses map[protocol.StreamID]chan *http.Response
|
||||
}
|
||||
@@ -50,10 +49,9 @@ var _ h2quicClient = &Client{}
|
||||
// NewClient creates a new client
|
||||
func NewClient(t *QuicRoundTripper, tlsConfig *tls.Config, hostname string) (*Client, error) {
|
||||
c := &Client{
|
||||
t: t,
|
||||
hostname: authorityAddr("https", hostname),
|
||||
highestOpenedStream: 3,
|
||||
responses: make(map[protocol.StreamID]chan *http.Response),
|
||||
t: t,
|
||||
hostname: authorityAddr("https", hostname),
|
||||
responses: make(map[protocol.StreamID]chan *http.Response),
|
||||
}
|
||||
c.cryptoChangedCond = sync.Cond{L: &c.mutex}
|
||||
|
||||
@@ -88,10 +86,13 @@ func (c *Client) cryptoChangeCallback(isForwardSecure bool) {
|
||||
func (c *Client) versionNegotiateCallback() error {
|
||||
var err error
|
||||
// once the version has been negotiated, open the header stream
|
||||
c.headerStream, err = c.client.OpenStream(3)
|
||||
c.headerStream, err = c.client.OpenStream()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if c.headerStream.StreamID() != 3 {
|
||||
return errors.New("h2quic Client BUG: StreamID of Header Stream is not 3")
|
||||
}
|
||||
c.requestWriter = newRequestWriter(c.headerStream)
|
||||
go c.handleHeaderStream()
|
||||
return nil
|
||||
@@ -160,21 +161,18 @@ func (c *Client) Do(req *http.Request) (*http.Response, error) {
|
||||
hasBody := (req.Body != nil)
|
||||
|
||||
c.mutex.Lock()
|
||||
c.highestOpenedStream += 2
|
||||
dataStreamID := c.highestOpenedStream
|
||||
for c.encryptionLevel != protocol.EncryptionForwardSecure {
|
||||
c.cryptoChangedCond.Wait()
|
||||
}
|
||||
hdrChan := make(chan *http.Response)
|
||||
c.responses[dataStreamID] = hdrChan
|
||||
c.mutex.Unlock()
|
||||
|
||||
// TODO: think about what to do with a TooManyOpenStreams error. Wait and retry?
|
||||
dataStream, err := c.client.OpenStream(dataStreamID)
|
||||
dataStream, err := c.client.OpenStream()
|
||||
if err != nil {
|
||||
c.Close(err)
|
||||
return nil, err
|
||||
}
|
||||
c.responses[dataStream.StreamID()] = hdrChan
|
||||
c.mutex.Unlock()
|
||||
|
||||
var requestedGzip bool
|
||||
if !c.t.disableCompression() && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" && req.Method != "HEAD" {
|
||||
@@ -182,7 +180,7 @@ func (c *Client) Do(req *http.Request) (*http.Response, error) {
|
||||
}
|
||||
// TODO: add support for trailers
|
||||
endStream := !hasBody
|
||||
err = c.requestWriter.WriteRequest(req, dataStreamID, endStream, requestedGzip)
|
||||
err = c.requestWriter.WriteRequest(req, dataStream.StreamID(), endStream, requestedGzip)
|
||||
if err != nil {
|
||||
c.Close(err)
|
||||
return nil, err
|
||||
@@ -209,7 +207,7 @@ func (c *Client) Do(req *http.Request) (*http.Response, error) {
|
||||
case res = <-hdrChan:
|
||||
receivedResponse = true
|
||||
c.mutex.Lock()
|
||||
delete(c.responses, dataStreamID)
|
||||
delete(c.responses, dataStream.StreamID())
|
||||
c.mutex.Unlock()
|
||||
if res == nil { // an error occured on the header stream
|
||||
c.Close(c.headerErr)
|
||||
|
||||
@@ -18,25 +18,25 @@ import (
|
||||
)
|
||||
|
||||
type mockQuicClient struct {
|
||||
streams map[protocol.StreamID]*mockStream
|
||||
closeErr error
|
||||
nextStream protocol.StreamID
|
||||
streams map[protocol.StreamID]*mockStream
|
||||
closeErr error
|
||||
}
|
||||
|
||||
func (m *mockQuicClient) Close(e error) error { m.closeErr = e; return nil }
|
||||
func (m *mockQuicClient) Listen() error { panic("not implemented") }
|
||||
func (m *mockQuicClient) OpenStream(id protocol.StreamID) (utils.Stream, error) {
|
||||
_, ok := m.streams[id]
|
||||
if ok {
|
||||
panic("Stream already exists")
|
||||
}
|
||||
func (m *mockQuicClient) OpenStream() (utils.Stream, error) {
|
||||
id := m.nextStream
|
||||
ms := &mockStream{id: id}
|
||||
m.streams[id] = ms
|
||||
m.nextStream += 2
|
||||
return ms, nil
|
||||
}
|
||||
|
||||
func newMockQuicClient() *mockQuicClient {
|
||||
return &mockQuicClient{
|
||||
streams: make(map[protocol.StreamID]*mockStream),
|
||||
streams: make(map[protocol.StreamID]*mockStream),
|
||||
nextStream: 5,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -77,6 +77,7 @@ var _ = Describe("Client", func() {
|
||||
// delete the headerStream openend in the BeforeEach
|
||||
client.headerStream = nil
|
||||
delete(qClient.streams, 3)
|
||||
qClient.nextStream = 3
|
||||
Expect(client.headerStream).To(BeNil()) // header stream not yet opened
|
||||
// now start the actual test
|
||||
err := client.versionNegotiateCallback()
|
||||
@@ -133,7 +134,6 @@ var _ = Describe("Client", func() {
|
||||
}()
|
||||
|
||||
Eventually(func() []byte { return headerStream.dataWritten.Bytes() }).ShouldNot(BeEmpty())
|
||||
Expect(client.highestOpenedStream).To(Equal(protocol.StreamID(5)))
|
||||
Expect(qClient.streams).Should(HaveKey(protocol.StreamID(5)))
|
||||
Expect(client.responses).To(HaveKey(protocol.StreamID(5)))
|
||||
rsp := &http.Response{
|
||||
|
||||
Reference in New Issue
Block a user