From 25db2166dd825ce1ad1d951ba9fed5a254182ae1 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 6 Jul 2020 16:27:29 +0700 Subject: [PATCH] enforce that a conn ID is not retired in a packet that uses that conn ID --- conn_id_generator.go | 5 ++++- conn_id_generator_test.go | 20 +++++++++++++++----- session.go | 12 ++++++------ session_test.go | 24 ++++++++++++------------ 4 files changed, 37 insertions(+), 24 deletions(-) diff --git a/conn_id_generator.go b/conn_id_generator.go index 24571650..7bcc3581 100644 --- a/conn_id_generator.go +++ b/conn_id_generator.go @@ -67,7 +67,7 @@ func (m *connIDGenerator) SetMaxActiveConnIDs(limit uint64) error { return nil } -func (m *connIDGenerator) Retire(seq uint64) error { +func (m *connIDGenerator) Retire(seq uint64, sentWithDestConnID protocol.ConnectionID) error { if seq > m.highestSeq { return qerr.NewError(qerr.ProtocolViolation, fmt.Sprintf("tried to retire connection ID %d. Highest issued: %d", seq, m.highestSeq)) } @@ -76,6 +76,9 @@ func (m *connIDGenerator) Retire(seq uint64) error { if !ok { return nil } + if connID.Equal(sentWithDestConnID) { + return qerr.NewError(qerr.FrameEncodingError, fmt.Sprintf("tried to retire connection ID %d (%s), which was used as the Destination Connection ID on this packet", seq, connID)) + } m.retireConnectionID(connID) delete(m.activeSrcConnIDs, seq) // Don't issue a replacement for the initial connection ID. diff --git a/conn_id_generator_test.go b/conn_id_generator_test.go index 1eeda3f6..7b51b480 100644 --- a/conn_id_generator_test.go +++ b/conn_id_generator_test.go @@ -1,6 +1,8 @@ package quic import ( + "fmt" + "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/wire" @@ -68,14 +70,22 @@ var _ = Describe("Connection ID Generator", func() { }) It("errors if the peers tries to retire a connection ID that wasn't yet issued", func() { - Expect(g.Retire(1)).To(MatchError("PROTOCOL_VIOLATION: tried to retire connection ID 1. Highest issued: 0")) + Expect(g.Retire(1, protocol.ConnectionID{})).To(MatchError("PROTOCOL_VIOLATION: tried to retire connection ID 1. Highest issued: 0")) + }) + + It("errors if the peers tries to retire a connection ID in a packet with that connection ID", func() { + Expect(g.SetMaxActiveConnIDs(4)).To(Succeed()) + Expect(queuedFrames).ToNot(BeEmpty()) + Expect(queuedFrames[0]).To(BeAssignableToTypeOf(&wire.NewConnectionIDFrame{})) + f := queuedFrames[0].(*wire.NewConnectionIDFrame) + Expect(g.Retire(f.SequenceNumber, f.ConnectionID)).To(MatchError(fmt.Sprintf("FRAME_ENCODING_ERROR: tried to retire connection ID %d (%s), which was used as the Destination Connection ID on this packet", f.SequenceNumber, f.ConnectionID))) }) It("issues new connection IDs, when old ones are retired", func() { Expect(g.SetMaxActiveConnIDs(5)).To(Succeed()) queuedFrames = nil Expect(retiredConnIDs).To(BeEmpty()) - Expect(g.Retire(3)).To(Succeed()) + Expect(g.Retire(3, protocol.ConnectionID{})).To(Succeed()) Expect(queuedFrames).To(HaveLen(1)) Expect(queuedFrames[0]).To(BeAssignableToTypeOf(&wire.NewConnectionIDFrame{})) nf := queuedFrames[0].(*wire.NewConnectionIDFrame) @@ -84,7 +94,7 @@ var _ = Describe("Connection ID Generator", func() { }) It("retires the initial connection ID", func() { - Expect(g.Retire(0)).To(Succeed()) + Expect(g.Retire(0, protocol.ConnectionID{})).To(Succeed()) Expect(removedConnIDs).To(BeEmpty()) Expect(retiredConnIDs).To(HaveLen(1)) Expect(retiredConnIDs[0]).To(Equal(initialConnID)) @@ -95,10 +105,10 @@ var _ = Describe("Connection ID Generator", func() { Expect(g.SetMaxActiveConnIDs(11)).To(Succeed()) queuedFrames = nil Expect(retiredConnIDs).To(BeEmpty()) - Expect(g.Retire(5)).To(Succeed()) + Expect(g.Retire(5, protocol.ConnectionID{})).To(Succeed()) Expect(retiredConnIDs).To(HaveLen(1)) Expect(queuedFrames).To(HaveLen(1)) - Expect(g.Retire(5)).To(Succeed()) + Expect(g.Retire(5, protocol.ConnectionID{})).To(Succeed()) Expect(retiredConnIDs).To(HaveLen(1)) Expect(queuedFrames).To(HaveLen(1)) }) diff --git a/session.go b/session.go index 83c55b9f..e35b49cd 100644 --- a/session.go +++ b/session.go @@ -1024,7 +1024,7 @@ func (s *session) handleUnpackedPacket( // Only process frames now if we're not logging. // If we're logging, we need to make sure that the packet_received event is logged first. if s.tracer == nil { - if err := s.handleFrame(frame, packet.encryptionLevel); err != nil { + if err := s.handleFrame(frame, packet.encryptionLevel, packet.hdr.DestConnectionID); err != nil { return err } } @@ -1049,7 +1049,7 @@ func (s *session) handleUnpackedPacket( } s.tracer.ReceivedPacket(packet.hdr, packetSize, fs) for _, frame := range frames { - if err := s.handleFrame(frame, packet.encryptionLevel); err != nil { + if err := s.handleFrame(frame, packet.encryptionLevel, packet.hdr.DestConnectionID); err != nil { return err } } @@ -1058,7 +1058,7 @@ func (s *session) handleUnpackedPacket( return s.receivedPacketHandler.ReceivedPacket(packet.packetNumber, packet.encryptionLevel, rcvTime, isAckEliciting) } -func (s *session) handleFrame(f wire.Frame, encLevel protocol.EncryptionLevel) error { +func (s *session) handleFrame(f wire.Frame, encLevel protocol.EncryptionLevel, destConnID protocol.ConnectionID) error { var err error wire.LogFrame(s.logger, f, false) switch frame := f.(type) { @@ -1094,7 +1094,7 @@ func (s *session) handleFrame(f wire.Frame, encLevel protocol.EncryptionLevel) e case *wire.NewConnectionIDFrame: err = s.handleNewConnectionIDFrame(frame) case *wire.RetireConnectionIDFrame: - err = s.handleRetireConnectionIDFrame(frame) + err = s.handleRetireConnectionIDFrame(frame, destConnID) case *wire.HandshakeDoneFrame: err = s.handleHandshakeDoneFrame() default: @@ -1211,8 +1211,8 @@ func (s *session) handleNewConnectionIDFrame(f *wire.NewConnectionIDFrame) error return s.connIDManager.Add(f) } -func (s *session) handleRetireConnectionIDFrame(f *wire.RetireConnectionIDFrame) error { - return s.connIDGenerator.Retire(f.SequenceNumber) +func (s *session) handleRetireConnectionIDFrame(f *wire.RetireConnectionIDFrame, destConnID protocol.ConnectionID) error { + return s.connIDGenerator.Retire(f.SequenceNumber, destConnID) } func (s *session) handleHandshakeDoneFrame() error { diff --git a/session_test.go b/session_test.go index 329ca1ad..89612fc4 100644 --- a/session_test.go +++ b/session_test.go @@ -198,7 +198,7 @@ var _ = Describe("Session", func() { Expect(sess.handleFrame(&wire.ResetStreamFrame{ StreamID: 3, ErrorCode: 42, - }, protocol.EncryptionUnspecified)).To(Succeed()) + }, protocol.EncryptionUnspecified, protocol.ConnectionID{})).To(Succeed()) }) }) @@ -233,7 +233,7 @@ var _ = Describe("Session", func() { Expect(sess.handleFrame(&wire.MaxStreamDataFrame{ StreamID: 10, MaximumStreamData: 1337, - }, protocol.EncryptionUnspecified)).To(Succeed()) + }, protocol.EncryptionUnspecified, protocol.ConnectionID{})).To(Succeed()) }) }) @@ -275,7 +275,7 @@ var _ = Describe("Session", func() { Expect(sess.handleFrame(&wire.StopSendingFrame{ StreamID: 3, ErrorCode: 1337, - }, protocol.EncryptionUnspecified)).To(Succeed()) + }, protocol.EncryptionUnspecified, protocol.ConnectionID{})).To(Succeed()) }) }) @@ -283,23 +283,23 @@ var _ = Describe("Session", func() { Expect(sess.handleFrame(&wire.NewConnectionIDFrame{ SequenceNumber: 10, ConnectionID: protocol.ConnectionID{1, 2, 3, 4}, - }, protocol.Encryption1RTT)).To(Succeed()) + }, protocol.Encryption1RTT, protocol.ConnectionID{})).To(Succeed()) Expect(sess.connIDManager.queue.Back().Value.ConnectionID).To(Equal(protocol.ConnectionID{1, 2, 3, 4})) }) It("handles PING frames", func() { - err := sess.handleFrame(&wire.PingFrame{}, protocol.EncryptionUnspecified) + err := sess.handleFrame(&wire.PingFrame{}, protocol.EncryptionUnspecified, protocol.ConnectionID{}) Expect(err).NotTo(HaveOccurred()) }) It("rejects PATH_RESPONSE frames", func() { - err := sess.handleFrame(&wire.PathResponseFrame{Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8}}, protocol.EncryptionUnspecified) + err := sess.handleFrame(&wire.PathResponseFrame{Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8}}, protocol.EncryptionUnspecified, protocol.ConnectionID{}) Expect(err).To(MatchError("unexpected PATH_RESPONSE frame")) }) It("handles PATH_CHALLENGE frames", func() { data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8} - err := sess.handleFrame(&wire.PathChallengeFrame{Data: data}, protocol.EncryptionUnspecified) + err := sess.handleFrame(&wire.PathChallengeFrame{Data: data}, protocol.EncryptionUnspecified, protocol.ConnectionID{}) Expect(err).ToNot(HaveOccurred()) frames, _ := sess.framer.AppendControlFrames(nil, 1000) Expect(frames).To(Equal([]ackhandler.Frame{{Frame: &wire.PathResponseFrame{Data: data}}})) @@ -313,17 +313,17 @@ var _ = Describe("Session", func() { }) It("handles BLOCKED frames", func() { - err := sess.handleFrame(&wire.DataBlockedFrame{}, protocol.EncryptionUnspecified) + err := sess.handleFrame(&wire.DataBlockedFrame{}, protocol.EncryptionUnspecified, protocol.ConnectionID{}) Expect(err).NotTo(HaveOccurred()) }) It("handles STREAM_BLOCKED frames", func() { - err := sess.handleFrame(&wire.StreamDataBlockedFrame{}, protocol.EncryptionUnspecified) + err := sess.handleFrame(&wire.StreamDataBlockedFrame{}, protocol.EncryptionUnspecified, protocol.ConnectionID{}) Expect(err).NotTo(HaveOccurred()) }) It("handles STREAM_ID_BLOCKED frames", func() { - err := sess.handleFrame(&wire.StreamsBlockedFrame{}, protocol.EncryptionUnspecified) + err := sess.handleFrame(&wire.StreamsBlockedFrame{}, protocol.EncryptionUnspecified, protocol.ConnectionID{}) Expect(err).NotTo(HaveOccurred()) }) @@ -347,7 +347,7 @@ var _ = Describe("Session", func() { Expect(sess.handleFrame(&wire.ConnectionCloseFrame{ ErrorCode: qerr.StreamLimitError, ReasonPhrase: "foobar", - }, protocol.EncryptionUnspecified)).To(Succeed()) + }, protocol.EncryptionUnspecified, protocol.ConnectionID{})).To(Succeed()) Eventually(sess.Context().Done()).Should(BeClosed()) }) @@ -373,7 +373,7 @@ var _ = Describe("Session", func() { ReasonPhrase: "foobar", IsApplicationError: true, } - Expect(sess.handleFrame(ccf, protocol.EncryptionUnspecified)).To(Succeed()) + Expect(sess.handleFrame(ccf, protocol.EncryptionUnspecified, protocol.ConnectionID{})).To(Succeed()) Eventually(sess.Context().Done()).Should(BeClosed()) })