diff --git a/crypto_stream.go b/crypto_stream.go new file mode 100644 index 000000000..bc75a18a7 --- /dev/null +++ b/crypto_stream.go @@ -0,0 +1,41 @@ +package quic + +import ( + "io" + + "github.com/lucas-clemente/quic-go/internal/flowcontrol" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" +) + +type cryptoStreamI interface { + io.Reader + io.Writer + AddStreamFrame(*wire.StreamFrame) error + HasDataForWriting() bool + GetDataForWriting(maxBytes protocol.ByteCount) (data []byte, shouldSendFin bool) + GetWriteOffset() protocol.ByteCount + Cancel(error) + SetReadOffset(protocol.ByteCount) + // methods needed for flow control + GetWindowUpdate() protocol.ByteCount + UpdateSendWindow(protocol.ByteCount) + IsFlowControlBlocked() bool +} + +type cryptoStream struct { + *stream +} + +func newCryptoStream(onData func(), flowController flowcontrol.StreamFlowController, version protocol.VersionNumber) cryptoStreamI { + str := newStream(version.CryptoStreamID(), onData, nil, flowController, version) + return &cryptoStream{str} +} + +// SetReadOffset sets the read offset. +// It is only needed for the crypto stream. +// It must not be called concurrently with any other stream methods, especially Read and Write. +func (s *cryptoStream) SetReadOffset(offset protocol.ByteCount) { + s.readOffset = offset + s.frameQueue.readPosition = offset +} diff --git a/crypto_stream_test.go b/crypto_stream_test.go new file mode 100644 index 000000000..010db2038 --- /dev/null +++ b/crypto_stream_test.go @@ -0,0 +1,20 @@ +package quic + +import ( + "github.com/lucas-clemente/quic-go/internal/protocol" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Stream", func() { + var str *cryptoStream + + str = newCryptoStream(nil, nil, protocol.VersionWhatever).(*cryptoStream) + + It("sets the read offset", func() { + str.SetReadOffset(0x42) + Expect(str.readOffset).To(Equal(protocol.ByteCount(0x42))) + Expect(str.frameQueue.readPosition).To(Equal(protocol.ByteCount(0x42))) + }) +}) diff --git a/packet_packer_test.go b/packet_packer_test.go index 0771f038b..97057703f 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -56,12 +56,12 @@ var _ = Describe("Packet packer", func() { publicHeaderLen protocol.ByteCount maxFrameSize protocol.ByteCount streamFramer *streamFramer - cryptoStream *stream + cryptoStream cryptoStreamI ) BeforeEach(func() { version := versionGQUICFrames - cryptoStream = &stream{streamID: version.CryptoStreamID(), flowController: flowcontrol.NewStreamFlowController(version.CryptoStreamID(), false, flowcontrol.NewConnectionFlowController(1000, 1000, nil), 1000, 1000, 1000, nil)} + cryptoStream = newCryptoStream(func() {}, flowcontrol.NewStreamFlowController(version.CryptoStreamID(), false, flowcontrol.NewConnectionFlowController(1000, 1000, nil), 1000, 1000, 1000, nil), version) streamsMap := newStreamsMap(nil, protocol.PerspectiveServer, versionGQUICFrames) streamFramer = newStreamFramer(cryptoStream, streamsMap, nil, versionGQUICFrames) @@ -585,29 +585,55 @@ var _ = Describe("Packet packer", func() { }) It("sends unencrypted stream data on the crypto stream", func() { + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := cryptoStream.Write([]byte("foobar")) + Expect(err).ToNot(HaveOccurred()) + close(done) + }() packer.cryptoSetup.(*mockCryptoSetup).encLevelSealCrypto = protocol.EncryptionUnencrypted - cryptoStream.dataForWriting = []byte("foobar") - p, err := packer.PackPacket() - Expect(err).ToNot(HaveOccurred()) + var p *packedPacket + Eventually(func() *packedPacket { + defer GinkgoRecover() + var err error + p, err = packer.PackPacket() + Expect(err).ToNot(HaveOccurred()) + return p + }).ShouldNot(BeNil()) Expect(p.encryptionLevel).To(Equal(protocol.EncryptionUnencrypted)) Expect(p.frames).To(HaveLen(1)) Expect(p.frames[0]).To(Equal(&wire.StreamFrame{ StreamID: packer.version.CryptoStreamID(), Data: []byte("foobar"), })) + Eventually(done).Should(BeClosed()) }) It("sends encrypted stream data on the crypto stream", func() { + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := cryptoStream.Write([]byte("foobar")) + Expect(err).ToNot(HaveOccurred()) + close(done) + }() packer.cryptoSetup.(*mockCryptoSetup).encLevelSealCrypto = protocol.EncryptionSecure - cryptoStream.dataForWriting = []byte("foobar") - p, err := packer.PackPacket() - Expect(err).ToNot(HaveOccurred()) + var p *packedPacket + Eventually(func() *packedPacket { + defer GinkgoRecover() + var err error + p, err = packer.PackPacket() + Expect(err).ToNot(HaveOccurred()) + return p + }).ShouldNot(BeNil()) Expect(p.encryptionLevel).To(Equal(protocol.EncryptionSecure)) Expect(p.frames).To(HaveLen(1)) Expect(p.frames[0]).To(Equal(&wire.StreamFrame{ StreamID: packer.version.CryptoStreamID(), Data: []byte("foobar"), })) + Eventually(done).Should(BeClosed()) }) It("does not pack stream frames if not allowed", func() { @@ -766,14 +792,27 @@ var _ = Describe("Packet packer", func() { packer.hasSentPacket = false packer.perspective = protocol.PerspectiveClient packer.cryptoSetup.(*mockCryptoSetup).encLevelSealCrypto = protocol.EncryptionUnencrypted - cryptoStream.dataForWriting = []byte("foobar") - packet, err := packer.PackPacket() - Expect(err).ToNot(HaveOccurred()) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := cryptoStream.Write([]byte("foobar")) + Expect(err).ToNot(HaveOccurred()) + close(done) + }() + var packet *packedPacket + Eventually(func() *packedPacket { + defer GinkgoRecover() + var err error + packet, err = packer.PackPacket() + Expect(err).ToNot(HaveOccurred()) + return packet + }).ShouldNot(BeNil()) Expect(packet.raw).To(HaveLen(protocol.MinInitialPacketSize)) Expect(packet.frames).To(HaveLen(1)) sf := packet.frames[0].(*wire.StreamFrame) Expect(sf.Data).To(Equal([]byte("foobar"))) Expect(sf.DataLenPresent).To(BeTrue()) + Eventually(done).Should(BeClosed()) }) It("refuses to retransmit packets that were sent with forward-secure encryption", func() { diff --git a/server.go b/server.go index 4ca25a269..03766f3e3 100644 --- a/server.go +++ b/server.go @@ -19,7 +19,7 @@ import ( // packetHandler handles packets type packetHandler interface { Session - getCryptoStream() cryptoStream + getCryptoStream() cryptoStreamI handshakeStatus() <-chan handshakeEvent handlePacket(*receivedPacket) GetVersion() protocol.VersionNumber diff --git a/server_test.go b/server_test.go index 7e95acc49..a84d10ac3 100644 --- a/server_test.go +++ b/server_test.go @@ -68,7 +68,7 @@ func (s *mockSession) RemoteAddr() net.Addr { panic("not imple func (*mockSession) Context() context.Context { panic("not implemented") } func (*mockSession) GetVersion() protocol.VersionNumber { return protocol.VersionWhatever } func (s *mockSession) handshakeStatus() <-chan handshakeEvent { return s.handshakeChan } -func (*mockSession) getCryptoStream() cryptoStream { panic("not implemented") } +func (*mockSession) getCryptoStream() cryptoStreamI { panic("not implemented") } var _ Session = &mockSession{} var _ NonFWSession = &mockSession{} diff --git a/session.go b/session.go index e282c69ad..91aa76056 100644 --- a/session.go +++ b/session.go @@ -56,7 +56,7 @@ type session struct { conn connection streamsMap *streamsMap - cryptoStream cryptoStream + cryptoStream cryptoStreamI rttStats *congestion.RTTStats @@ -294,7 +294,7 @@ func (s *session) preSetup() { protocol.ByteCount(s.config.MaxReceiveConnectionFlowControlWindow), s.rttStats, ) - s.cryptoStream = s.newStream(s.version.CryptoStreamID()).(cryptoStream) + s.cryptoStream = s.newCryptoStream() } func (s *session) postSetup(initialPacketNumber protocol.PacketNumber) error { @@ -880,6 +880,20 @@ func (s *session) newStream(id protocol.StreamID) streamI { return newStream(id, s.scheduleSending, s.queueResetStreamFrame, flowController, s.version) } +func (s *session) newCryptoStream() cryptoStreamI { + id := s.version.CryptoStreamID() + flowController := flowcontrol.NewStreamFlowController( + id, + s.version.StreamContributesToConnectionFlowControl(id), + s.connFlowController, + protocol.ReceiveStreamFlowControlWindow, + protocol.ByteCount(s.config.MaxReceiveStreamFlowControlWindow), + 0, + s.rttStats, + ) + return newCryptoStream(s.scheduleSending, flowController, s.version) +} + func (s *session) sendPublicReset(rejectedPacketNumber protocol.PacketNumber) error { utils.Infof("Sending public reset for connection %x, packet number %d", s.connectionID, rejectedPacketNumber) return s.conn.Write(wire.WritePublicReset(s.connectionID, rejectedPacketNumber, 0)) @@ -949,7 +963,7 @@ func (s *session) handshakeStatus() <-chan handshakeEvent { return s.handshakeChan } -func (s *session) getCryptoStream() cryptoStream { +func (s *session) getCryptoStream() cryptoStreamI { return s.cryptoStream } diff --git a/session_test.go b/session_test.go index d2a6efe4f..82aad5263 100644 --- a/session_test.go +++ b/session_test.go @@ -414,7 +414,7 @@ var _ = Describe("Session", func() { fc := mocks.NewMockStreamFlowController(mockCtrl) offset := protocol.ByteCount(0x4321) fc.EXPECT().UpdateSendWindow(offset) - sess.cryptoStream.(*stream).flowController = fc + sess.cryptoStream.(*cryptoStream).flowController = fc err := sess.handleMaxStreamDataFrame(&wire.MaxStreamDataFrame{ StreamID: sess.version.CryptoStreamID(), ByteOffset: offset, diff --git a/stream.go b/stream.go index 0e4f34e71..d1b95f032 100644 --- a/stream.go +++ b/stream.go @@ -30,11 +30,6 @@ type streamI interface { IsFlowControlBlocked() bool } -type cryptoStream interface { - streamI - SetReadOffset(protocol.ByteCount) -} - // A Stream assembles the data from StreamFrames and provides a super-convenient Read-Interface // // Read() and Write() may be called concurrently, but multiple calls to Read() or Write() individually must be synchronized manually. @@ -481,11 +476,3 @@ func (s *stream) IsFlowControlBlocked() bool { func (s *stream) GetWindowUpdate() protocol.ByteCount { return s.flowController.GetWindowUpdate() } - -// SetReadOffset sets the read offset. -// It is only needed for the crypto stream. -// It must not be called concurrently with any other stream methods, especially Read and Write. -func (s *stream) SetReadOffset(offset protocol.ByteCount) { - s.readOffset = offset - s.frameQueue.readPosition = offset -} diff --git a/stream_framer.go b/stream_framer.go index 1c478aa4a..ded71ec61 100644 --- a/stream_framer.go +++ b/stream_framer.go @@ -8,7 +8,7 @@ import ( type streamFramer struct { streamsMap *streamsMap - cryptoStream streamI + cryptoStream cryptoStreamI version protocol.VersionNumber connFlowController flowcontrol.ConnectionFlowController @@ -18,7 +18,7 @@ type streamFramer struct { } func newStreamFramer( - cryptoStream streamI, + cryptoStream cryptoStreamI, streamsMap *streamsMap, cfc flowcontrol.ConnectionFlowController, v protocol.VersionNumber, @@ -63,7 +63,7 @@ func (f *streamFramer) PopCryptoStreamFrame(maxLen protocol.ByteCount) *wire.Str return nil } frame := &wire.StreamFrame{ - StreamID: f.cryptoStream.StreamID(), + StreamID: f.version.CryptoStreamID(), Offset: f.cryptoStream.GetWriteOffset(), } frame.Data, frame.FinBit = f.cryptoStream.GetDataForWriting(maxLen - frame.MinLength(f.version)) diff --git a/stream_test.go b/stream_test.go index ce2ab483f..72e990a33 100644 --- a/stream_test.go +++ b/stream_test.go @@ -267,12 +267,6 @@ var _ = Describe("Stream", func() { Expect(onDataCalled).To(BeTrue()) }) - It("sets the read offset", func() { - str.SetReadOffset(0x42) - Expect(str.readOffset).To(Equal(protocol.ByteCount(0x42))) - Expect(str.frameQueue.readPosition).To(Equal(protocol.ByteCount(0x42))) - }) - Context("deadlines", func() { It("the deadline error has the right net.Error properties", func() { Expect(errDeadline.Temporary()).To(BeTrue())