From f6628224861b20fd1b23b58effffd52313515185 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Wed, 1 Nov 2017 13:51:47 +0700 Subject: [PATCH] use stream 0 for the crypto stream when using TLS --- h2quic/client.go | 3 -- h2quic/client_test.go | 10 ------- h2quic/server.go | 4 --- h2quic/server_test.go | 13 --------- internal/protocol/version.go | 16 +++++++++-- internal/protocol/version_test.go | 6 ++++ packet_packer_test.go | 12 ++++++-- packet_unpacker.go | 2 +- packet_unpacker_test.go | 8 +++--- session.go | 10 +++---- stream.go | 5 +++- stream_framer.go | 2 +- stream_framer_test.go | 2 +- stream_test.go | 2 +- streams_map.go | 10 +++++-- streams_map_test.go | 47 ++++++++++++++++++++++++++----- 16 files changed, 93 insertions(+), 59 deletions(-) diff --git a/h2quic/client.go b/h2quic/client.go index 0297486c..cd7f65c7 100644 --- a/h2quic/client.go +++ b/h2quic/client.go @@ -90,9 +90,6 @@ func (c *client) dial() error { if err != nil { return err } - if c.headerStream.StreamID() != 3 { - return errors.New("h2quic Client BUG: StreamID of Header Stream is not 3") - } c.requestWriter = newRequestWriter(c.headerStream) go c.handleHeaderStream() return nil diff --git a/h2quic/client_test.go b/h2quic/client_test.go index ad8df477..db988301 100644 --- a/h2quic/client_test.go +++ b/h2quic/client_test.go @@ -97,16 +97,6 @@ var _ = Describe("Client", func() { Expect(err).To(MatchError(testErr)) }) - It("errors if the header stream has the wrong stream ID", func() { - client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil) - session.streamsToOpen = []quic.Stream{&mockStream{id: 2}} - dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) { - return session, nil - } - _, err := client.RoundTrip(req) - Expect(err).To(MatchError("h2quic Client BUG: StreamID of Header Stream is not 3")) - }) - It("errors if it can't open a stream", func() { testErr := errors.New("you shall not pass") client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil) diff --git a/h2quic/server.go b/h2quic/server.go index 586c781d..4f2039d0 100644 --- a/h2quic/server.go +++ b/h2quic/server.go @@ -122,10 +122,6 @@ func (s *Server) handleHeaderStream(session streamCreator) { session.Close(qerr.Error(qerr.InvalidHeadersStreamData, err.Error())) return } - if stream.StreamID() != 3 { - session.Close(qerr.Error(qerr.InternalError, "h2quic server BUG: header stream does not have stream ID 3")) - return - } hpackDecoder := hpack.NewDecoder(4096, nil) h2framer := http2.NewFramer(nil, stream) diff --git a/h2quic/server_test.go b/h2quic/server_test.go index 94e2c7bf..242652cf 100644 --- a/h2quic/server_test.go +++ b/h2quic/server_test.go @@ -308,19 +308,6 @@ var _ = Describe("H2 server", func() { Expect(session.closedWithError).To(MatchError(qerr.Error(qerr.HeadersStreamDataDecompressFailure, "cannot read frame"))) }) - It("errors if the accepted header stream has the wrong stream ID", func() { - headerStream := &mockStream{id: 1} - headerStream.dataToRead.Write([]byte{ - 0x0, 0x0, 0x11, 0x1, 0x4, 0x0, 0x0, 0x0, 0x5, - // Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding - 0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff, - }) - session.streamToAccept = headerStream - go s.handleHeaderStream(session) - Eventually(func() bool { return session.closed }).Should(BeTrue()) - Expect(session.closedWithError).To(MatchError(qerr.Error(qerr.InternalError, "h2quic server BUG: header stream does not have stream ID 3"))) - }) - It("supports closing after first request", func() { s.CloseAfterFirstRequest = true s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) diff --git a/internal/protocol/version.go b/internal/protocol/version.go index a25fbbbc..f9bfef36 100644 --- a/internal/protocol/version.go +++ b/internal/protocol/version.go @@ -45,7 +45,7 @@ func (vn VersionNumber) String() string { case VersionTLS: return "TLS dev version (WIP)" default: - if vn > gquicVersion0 && vn <= maxGquicVersion { + if vn.isGQUIC() { return fmt.Sprintf("gQUIC %d", vn.toGQUICVersion()) } return fmt.Sprintf("%d", vn) @@ -54,12 +54,24 @@ func (vn VersionNumber) String() string { // ToAltSvc returns the representation of the version for the H2 Alt-Svc parameters func (vn VersionNumber) ToAltSvc() string { - if vn > gquicVersion0 && vn <= maxGquicVersion { + if vn.isGQUIC() { return fmt.Sprintf("%d", vn.toGQUICVersion()) } return fmt.Sprintf("%d", vn) } +// CryptoStreamID gets the Stream ID of the crypto stream +func (vn VersionNumber) CryptoStreamID() StreamID { + if vn.isGQUIC() { + return 1 + } + return 0 +} + +func (vn VersionNumber) isGQUIC() bool { + return vn > gquicVersion0 && vn <= maxGquicVersion +} + func (vn VersionNumber) toGQUICVersion() int { return int(10*(vn-gquicVersion0)/0x100) + int(vn%0x10) } diff --git a/internal/protocol/version_test.go b/internal/protocol/version_test.go index 4fa8d4c4..5a23aea4 100644 --- a/internal/protocol/version_test.go +++ b/internal/protocol/version_test.go @@ -43,7 +43,13 @@ var _ = Describe("Version", func() { Expect(VersionNumber(0x51303133).ToAltSvc()).To(Equal("13")) Expect(VersionNumber(0x51303235).ToAltSvc()).To(Equal("25")) Expect(VersionNumber(0x51303438).ToAltSvc()).To(Equal("48")) + }) + It("tells the Stream ID of the crypto stream", func() { + Expect(Version37.CryptoStreamID()).To(Equal(StreamID(1))) + Expect(Version38.CryptoStreamID()).To(Equal(StreamID(1))) + Expect(Version39.CryptoStreamID()).To(Equal(StreamID(1))) + Expect(VersionTLS.CryptoStreamID()).To(Equal(StreamID(0))) }) It("recognizes supported versions", func() { diff --git a/packet_packer_test.go b/packet_packer_test.go index 960f92c6..d4a62378 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -63,7 +63,7 @@ var _ = Describe("Packet packer", func() { BeforeEach(func() { cryptoStream = &stream{flowController: flowcontrol.NewStreamFlowController(1, false, flowcontrol.NewConnectionFlowController(1000, 1000, nil), 1000, 1000, 1000, nil)} - streamsMap := newStreamsMap(nil, protocol.PerspectiveServer) + streamsMap := newStreamsMap(nil, protocol.PerspectiveServer, protocol.VersionWhatever) streamFramer = newStreamFramer(cryptoStream, streamsMap, nil) packer = &packetPacker{ @@ -574,7 +574,10 @@ var _ = Describe("Packet packer", func() { Expect(err).ToNot(HaveOccurred()) Expect(p.encryptionLevel).To(Equal(protocol.EncryptionUnencrypted)) Expect(p.frames).To(HaveLen(1)) - Expect(p.frames[0]).To(Equal(&wire.StreamFrame{StreamID: 1, Data: []byte("foobar")})) + Expect(p.frames[0]).To(Equal(&wire.StreamFrame{ + StreamID: packer.version.CryptoStreamID(), + Data: []byte("foobar"), + })) }) It("sends encrypted stream data on the crypto stream", func() { @@ -584,7 +587,10 @@ var _ = Describe("Packet packer", func() { Expect(err).ToNot(HaveOccurred()) Expect(p.encryptionLevel).To(Equal(protocol.EncryptionSecure)) Expect(p.frames).To(HaveLen(1)) - Expect(p.frames[0]).To(Equal(&wire.StreamFrame{StreamID: 1, Data: []byte("foobar")})) + Expect(p.frames[0]).To(Equal(&wire.StreamFrame{ + StreamID: packer.version.CryptoStreamID(), + Data: []byte("foobar"), + })) }) It("does not pack stream frames if not allowed", func() { diff --git a/packet_unpacker.go b/packet_unpacker.go index bf1e0cfe..02610111 100644 --- a/packet_unpacker.go +++ b/packet_unpacker.go @@ -55,7 +55,7 @@ func (u *packetUnpacker) Unpack(headerBinary []byte, hdr *wire.Header, data []by err = qerr.Error(qerr.InvalidStreamData, err.Error()) } else { streamID := frame.(*wire.StreamFrame).StreamID - if streamID != 1 && encryptionLevel <= protocol.EncryptionUnencrypted { + if streamID != u.version.CryptoStreamID() && encryptionLevel <= protocol.EncryptionUnencrypted { err = qerr.Error(qerr.UnencryptedStreamData, fmt.Sprintf("received unencrypted stream data on stream %d", streamID)) } } diff --git a/packet_unpacker_test.go b/packet_unpacker_test.go index 5187ada9..3d4c2219 100644 --- a/packet_unpacker_test.go +++ b/packet_unpacker_test.go @@ -224,10 +224,10 @@ var _ = Describe("Packet unpacker", func() { }) Context("unpacking STREAM frames", func() { - It("unpacks unencrypted STREAM frames on stream 1", func() { + It("unpacks unencrypted STREAM frames on the crypto stream", func() { unpacker.aead.(*mockAEAD).encLevelOpen = protocol.EncryptionUnencrypted f := &wire.StreamFrame{ - StreamID: 1, + StreamID: unpacker.version.CryptoStreamID(), Data: []byte("foobar"), } err := f.Write(buf, 0) @@ -238,10 +238,10 @@ var _ = Describe("Packet unpacker", func() { Expect(packet.frames).To(Equal([]wire.Frame{f})) }) - It("unpacks encrypted STREAM frames on stream 1", func() { + It("unpacks encrypted STREAM frames on the crypto stream", func() { unpacker.aead.(*mockAEAD).encLevelOpen = protocol.EncryptionSecure f := &wire.StreamFrame{ - StreamID: 1, + StreamID: unpacker.version.CryptoStreamID(), Data: []byte("foobar"), } err := f.Write(buf, 0) diff --git a/session.go b/session.go index d6dc5c4b..b9211012 100644 --- a/session.go +++ b/session.go @@ -200,8 +200,8 @@ func (s *session) setup( protocol.ByteCount(s.config.MaxReceiveConnectionFlowControlWindow), s.rttStats, ) - s.streamsMap = newStreamsMap(s.newStream, s.perspective) - s.cryptoStream = s.newStream(1) + s.streamsMap = newStreamsMap(s.newStream, s.perspective, s.version) + s.cryptoStream = s.newStream(s.version.CryptoStreamID()) s.streamFramer = newStreamFramer(s.cryptoStream, s.streamsMap, s.connFlowController) var err error @@ -527,7 +527,7 @@ func (s *session) handlePacket(p *receivedPacket) { } func (s *session) handleStreamFrame(frame *wire.StreamFrame) error { - if frame.StreamID == 1 { + if frame.StreamID == s.version.CryptoStreamID() { return s.cryptoStream.AddStreamFrame(frame) } str, err := s.streamsMap.GetOrOpenStream(frame.StreamID) @@ -820,7 +820,7 @@ func (s *session) queueResetStreamFrame(id protocol.StreamID, offset protocol.By func (s *session) newStream(id protocol.StreamID) streamI { // TODO: find a better solution for determining which streams contribute to connection level flow control var contributesToConnection bool - if id != 1 && id != 3 { + if id != 0 && id != 1 && id != 3 { contributesToConnection = true } var initialSendWindow protocol.ByteCount @@ -836,7 +836,7 @@ func (s *session) newStream(id protocol.StreamID) streamI { initialSendWindow, s.rttStats, ) - return newStream(id, s.scheduleSending, s.queueResetStreamFrame, flowController) + return newStream(id, s.scheduleSending, s.queueResetStreamFrame, flowController, s.version) } func (s *session) sendPublicReset(rejectedPacketNumber protocol.PacketNumber) error { diff --git a/stream.go b/stream.go index 6af8526f..806e7fc9 100644 --- a/stream.go +++ b/stream.go @@ -75,6 +75,7 @@ type stream struct { writeDeadline time.Time flowController flowcontrol.StreamFlowController + version protocol.VersionNumber } var _ Stream = &stream{} @@ -93,6 +94,7 @@ func newStream(StreamID protocol.StreamID, onData func(), onReset func(protocol.StreamID, protocol.ByteCount), flowController flowcontrol.StreamFlowController, + version protocol.VersionNumber, ) *stream { s := &stream{ onData: onData, @@ -102,6 +104,7 @@ func newStream(StreamID protocol.StreamID, frameQueue: newStreamFrameSorter(), readChan: make(chan struct{}, 1), writeChan: make(chan struct{}, 1), + version: version, } s.ctx, s.ctxCancel = context.WithCancel(context.Background()) return s @@ -274,7 +277,7 @@ func (s *stream) GetDataForWriting(maxBytes protocol.ByteCount) []byte { } // TODO(#657): Flow control for the crypto stream - if s.streamID != 1 { + if s.streamID != s.version.CryptoStreamID() { maxBytes = utils.MinByteCount(maxBytes, s.flowController.SendWindowSize()) } if maxBytes == 0 { diff --git a/stream_framer.go b/stream_framer.go index ae1b556c..00755b58 100644 --- a/stream_framer.go +++ b/stream_framer.go @@ -60,7 +60,7 @@ func (f *streamFramer) PopCryptoStreamFrame(maxLen protocol.ByteCount) *wire.Str return nil } frame := &wire.StreamFrame{ - StreamID: 1, + StreamID: f.cryptoStream.StreamID(), Offset: f.cryptoStream.GetWriteOffset(), } frameHeaderBytes, _ := frame.MinLength(protocol.VersionWhatever) // can never error diff --git a/stream_framer_test.go b/stream_framer_test.go index 670a4d98..054147cc 100644 --- a/stream_framer_test.go +++ b/stream_framer_test.go @@ -41,7 +41,7 @@ var _ = Describe("Stream Framer", func() { stream2 = mocks.NewMockStreamI(mockCtrl) stream2.EXPECT().StreamID().Return(protocol.StreamID(6)).AnyTimes() - streamsMap = newStreamsMap(nil, protocol.PerspectiveServer) + streamsMap = newStreamsMap(nil, protocol.PerspectiveServer, protocol.VersionWhatever) streamsMap.putStream(stream1) streamsMap.putStream(stream2) diff --git a/stream_test.go b/stream_test.go index b8187f50..691f910a 100644 --- a/stream_test.go +++ b/stream_test.go @@ -59,7 +59,7 @@ var _ = Describe("Stream", func() { onDataCalled = false resetCalled = false mockFC = mocks.NewMockStreamFlowController(mockCtrl) - str = newStream(streamID, onData, onReset, mockFC) + str = newStream(streamID, onData, onReset, mockFC, protocol.VersionWhatever) timeout := scaleDuration(250 * time.Millisecond) strWithTimeout = struct { diff --git a/streams_map.go b/streams_map.go index a342023c..df5b4c99 100644 --- a/streams_map.go +++ b/streams_map.go @@ -41,7 +41,7 @@ type newStreamLambda func(protocol.StreamID) streamI var errMapAccess = errors.New("streamsMap: Error accessing the streams map") -func newStreamsMap(newStream newStreamLambda, pers protocol.Perspective) *streamsMap { +func newStreamsMap(newStream newStreamLambda, pers protocol.Perspective, ver protocol.VersionNumber) *streamsMap { // add some tolerance to the maximum incoming streams value maxStreams := uint32(protocol.MaxIncomingStreams) maxIncomingStreams := utils.MaxUint32( @@ -58,12 +58,16 @@ func newStreamsMap(newStream newStreamLambda, pers protocol.Perspective) *stream sm.nextStreamOrErrCond.L = &sm.mutex sm.openStreamOrErrCond.L = &sm.mutex + nextOddStream := protocol.StreamID(1) + if ver.CryptoStreamID() == protocol.StreamID(1) { + nextOddStream = 3 + } if pers == protocol.PerspectiveClient { - sm.nextStream = 3 + sm.nextStream = nextOddStream sm.nextStreamToAccept = 2 } else { sm.nextStream = 2 - sm.nextStreamToAccept = 3 + sm.nextStreamToAccept = nextOddStream } return &sm diff --git a/streams_map_test.go b/streams_map_test.go index 9e102e99..b4c83a39 100644 --- a/streams_map_test.go +++ b/streams_map_test.go @@ -14,6 +14,11 @@ import ( ) var _ = Describe("Streams Map", func() { + const ( + versionCryptoStream1 = protocol.Version39 + versionCryptoStream0 = protocol.VersionTLS + ) + var ( m *streamsMap finishedStreams map[protocol.StreamID]*gomock.Call @@ -27,11 +32,13 @@ var _ = Describe("Streams Map", func() { return str } - setNewStreamsMap := func(p protocol.Perspective) { - m = newStreamsMap(newStream, p) + setNewStreamsMap := func(p protocol.Perspective, v protocol.VersionNumber) { + m = newStreamsMap(newStream, p, v) } BeforeEach(func() { + Expect(versionCryptoStream0.CryptoStreamID()).To(Equal(protocol.StreamID(0))) + Expect(versionCryptoStream1.CryptoStreamID()).To(Equal(protocol.StreamID(1))) finishedStreams = make(map[protocol.StreamID]*gomock.Call) }) @@ -50,7 +57,7 @@ var _ = Describe("Streams Map", func() { Context("getting and creating streams", func() { Context("as a server", func() { BeforeEach(func() { - setNewStreamsMap(protocol.PerspectiveServer) + setNewStreamsMap(protocol.PerspectiveServer, versionCryptoStream1) }) Context("client-side streams", func() { @@ -280,7 +287,22 @@ var _ = Describe("Streams Map", func() { Consistently(func() bool { return accepted }).Should(BeFalse()) }) - It("start with stream 3", func() { + It("starts with stream 1, if the crypto stream is stream 0", func() { + setNewStreamsMap(protocol.PerspectiveServer, versionCryptoStream0) + var str streamI + go func() { + defer GinkgoRecover() + var err error + str, err = m.AcceptStream() + Expect(err).ToNot(HaveOccurred()) + }() + _, err := m.GetOrOpenStream(1) + Expect(err).ToNot(HaveOccurred()) + Eventually(func() Stream { return str }).ShouldNot(BeNil()) + Expect(str.StreamID()).To(Equal(protocol.StreamID(1))) + }) + + It("starts with stream 3, if the crypto stream is stream 1", func() { var str streamI go func() { defer GinkgoRecover() @@ -399,7 +421,7 @@ var _ = Describe("Streams Map", func() { Context("as a client", func() { BeforeEach(func() { - setNewStreamsMap(protocol.PerspectiveClient) + setNewStreamsMap(protocol.PerspectiveClient, versionCryptoStream1) m.UpdateMaxStreamLimit(100) }) @@ -445,7 +467,18 @@ var _ = Describe("Streams Map", func() { }) Context("server-side streams", func() { - It("starts with stream 3", func() { + It("starts with stream 1, if the crypto stream is stream 0", func() { + setNewStreamsMap(protocol.PerspectiveClient, versionCryptoStream0) + m.UpdateMaxStreamLimit(100) + s, err := m.OpenStream() + Expect(err).ToNot(HaveOccurred()) + Expect(s).ToNot(BeNil()) + Expect(s.StreamID()).To(BeEquivalentTo(1)) + Expect(m.numOutgoingStreams).To(BeZero()) + Expect(m.numIncomingStreams).To(BeEquivalentTo(1)) + }) + + It("starts with stream 3, if the crypto stream is stream 1", func() { s, err := m.OpenStream() Expect(err).ToNot(HaveOccurred()) Expect(s).ToNot(BeNil()) @@ -493,7 +526,7 @@ var _ = Describe("Streams Map", func() { Context("DoS mitigation, iterating and deleting", func() { BeforeEach(func() { - setNewStreamsMap(protocol.PerspectiveServer) + setNewStreamsMap(protocol.PerspectiveServer, versionCryptoStream1) }) closeStream := func(id protocol.StreamID) {