diff --git a/session.go b/session.go index e319cfa0..265a785a 100644 --- a/session.go +++ b/session.go @@ -113,7 +113,6 @@ type session struct { receivedFirstPacket bool // since packet numbers start at 0, we can't use largestRcvdPacketNumber != 0 for this receivedFirstForwardSecurePacket bool - lastRcvdPacketNumber protocol.PacketNumber // Used to calculate the next packet number from the truncated wire // representation, and sent back in public reset packets largestRcvdPacketNumber protocol.PacketNumber @@ -530,7 +529,6 @@ func (s *session) handlePacketImpl(p *receivedPacket) error { } } - s.lastRcvdPacketNumber = hdr.PacketNumber // Only do this after decrypting, so we are sure the packet is not attacker-controlled s.largestRcvdPacketNumber = utils.MaxPacketNumber(s.largestRcvdPacketNumber, hdr.PacketNumber) @@ -543,10 +541,10 @@ func (s *session) handlePacketImpl(p *receivedPacket) error { } } - return s.handleFrames(packet.frames, packet.encryptionLevel) + return s.handleFrames(packet.frames, hdr.PacketNumber, packet.encryptionLevel) } -func (s *session) handleFrames(fs []wire.Frame, encLevel protocol.EncryptionLevel) error { +func (s *session) handleFrames(fs []wire.Frame, pn protocol.PacketNumber, encLevel protocol.EncryptionLevel) error { for _, ff := range fs { var err error wire.LogFrame(s.logger, ff, false) @@ -556,7 +554,7 @@ func (s *session) handleFrames(fs []wire.Frame, encLevel protocol.EncryptionLeve case *wire.StreamFrame: err = s.handleStreamFrame(frame, encLevel) case *wire.AckFrame: - err = s.handleAckFrame(frame, encLevel) + err = s.handleAckFrame(frame, pn, encLevel) case *wire.ConnectionCloseFrame: s.closeRemote(qerr.Error(frame.ErrorCode, frame.ReasonPhrase)) case *wire.ResetStreamFrame: @@ -702,8 +700,8 @@ func (s *session) handlePathChallengeFrame(frame *wire.PathChallengeFrame) { s.queueControlFrame(&wire.PathResponseFrame{Data: frame.Data}) } -func (s *session) handleAckFrame(frame *wire.AckFrame, encLevel protocol.EncryptionLevel) error { - if err := s.sentPacketHandler.ReceivedAck(frame, s.lastRcvdPacketNumber, encLevel, s.lastNetworkActivityTime); err != nil { +func (s *session) handleAckFrame(frame *wire.AckFrame, pn protocol.PacketNumber, encLevel protocol.EncryptionLevel) error { + if err := s.sentPacketHandler.ReceivedAck(frame, pn, encLevel, s.lastNetworkActivityTime); err != nil { return err } s.receivedPacketHandler.IgnoreBelow(s.sentPacketHandler.GetLowestPacketNotConfirmedAcked()) diff --git a/session_test.go b/session_test.go index b7ebac77..854fec94 100644 --- a/session_test.go +++ b/session_test.go @@ -156,8 +156,7 @@ var _ = Describe("Session", func() { sph.EXPECT().ReceivedAck(f, protocol.PacketNumber(42), protocol.EncryptionHandshake, gomock.Any()) sph.EXPECT().GetLowestPacketNotConfirmedAcked() sess.sentPacketHandler = sph - sess.lastRcvdPacketNumber = 42 - err := sess.handleAckFrame(f, protocol.EncryptionHandshake) + err := sess.handleAckFrame(f, 42, protocol.EncryptionHandshake) Expect(err).ToNot(HaveOccurred()) }) @@ -170,7 +169,7 @@ var _ = Describe("Session", func() { rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl) rph.EXPECT().IgnoreBelow(protocol.PacketNumber(0x42)) sess.receivedPacketHandler = rph - Expect(sess.handleAckFrame(ack, protocol.EncryptionInitial)).To(Succeed()) + Expect(sess.handleAckFrame(ack, 0, protocol.EncryptionInitial)).To(Succeed()) }) }) @@ -203,11 +202,10 @@ var _ = Describe("Session", func() { It("ignores RESET_STREAM frames for closed streams", func() { streamManager.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(3)).Return(nil, nil) - err := sess.handleFrames([]wire.Frame{&wire.ResetStreamFrame{ + Expect(sess.handleFrames([]wire.Frame{&wire.ResetStreamFrame{ StreamID: 3, ErrorCode: 42, - }}, protocol.EncryptionUnspecified) - Expect(err).NotTo(HaveOccurred()) + }}, 0, protocol.EncryptionUnspecified)).To(Succeed()) }) }) @@ -242,7 +240,7 @@ var _ = Describe("Session", func() { err := sess.handleFrames([]wire.Frame{&wire.MaxStreamDataFrame{ StreamID: 10, ByteOffset: 1337, - }}, protocol.EncryptionUnspecified) + }}, 0, protocol.EncryptionUnspecified) Expect(err).NotTo(HaveOccurred()) }) }) @@ -282,44 +280,43 @@ var _ = Describe("Session", func() { It("ignores STOP_SENDING frames for a closed stream", func() { streamManager.EXPECT().GetOrOpenSendStream(protocol.StreamID(3)).Return(nil, nil) - err := sess.handleFrames([]wire.Frame{&wire.StopSendingFrame{ + Expect(sess.handleFrames([]wire.Frame{&wire.StopSendingFrame{ StreamID: 3, ErrorCode: 1337, - }}, protocol.EncryptionUnspecified) - Expect(err).NotTo(HaveOccurred()) + }}, 0, protocol.EncryptionUnspecified)).To(Succeed()) }) }) It("handles PING frames", func() { - err := sess.handleFrames([]wire.Frame{&wire.PingFrame{}}, protocol.EncryptionUnspecified) + err := sess.handleFrames([]wire.Frame{&wire.PingFrame{}}, 0, protocol.EncryptionUnspecified) Expect(err).NotTo(HaveOccurred()) }) It("rejects PATH_RESPONSE frames", func() { - err := sess.handleFrames([]wire.Frame{&wire.PathResponseFrame{Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8}}}, protocol.EncryptionUnspecified) + err := sess.handleFrames([]wire.Frame{&wire.PathResponseFrame{Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8}}}, 0, protocol.EncryptionUnspecified) Expect(err).To(MatchError("unexpected PATH_RESPONSE frame")) }) It("handles PATH_CHALLENGE frames", func() { data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8} - err := sess.handleFrames([]wire.Frame{&wire.PathChallengeFrame{Data: data}}, protocol.EncryptionUnspecified) + err := sess.handleFrames([]wire.Frame{&wire.PathChallengeFrame{Data: data}}, 0, protocol.EncryptionUnspecified) Expect(err).ToNot(HaveOccurred()) frames, _ := sess.framer.AppendControlFrames(nil, 1000) Expect(frames).To(Equal([]wire.Frame{&wire.PathResponseFrame{Data: data}})) }) It("handles BLOCKED frames", func() { - err := sess.handleFrames([]wire.Frame{&wire.DataBlockedFrame{}}, protocol.EncryptionUnspecified) + err := sess.handleFrames([]wire.Frame{&wire.DataBlockedFrame{}}, 0, protocol.EncryptionUnspecified) Expect(err).NotTo(HaveOccurred()) }) It("handles STREAM_BLOCKED frames", func() { - err := sess.handleFrames([]wire.Frame{&wire.StreamDataBlockedFrame{}}, protocol.EncryptionUnspecified) + err := sess.handleFrames([]wire.Frame{&wire.StreamDataBlockedFrame{}}, 0, protocol.EncryptionUnspecified) Expect(err).NotTo(HaveOccurred()) }) It("handles STREAM_ID_BLOCKED frames", func() { - err := sess.handleFrames([]wire.Frame{&wire.StreamsBlockedFrame{}}, protocol.EncryptionUnspecified) + err := sess.handleFrames([]wire.Frame{&wire.StreamsBlockedFrame{}}, 0, protocol.EncryptionUnspecified) Expect(err).NotTo(HaveOccurred()) }) @@ -335,7 +332,7 @@ var _ = Describe("Session", func() { err := sess.run() Expect(err).To(MatchError(testErr)) }() - err := sess.handleFrames([]wire.Frame{&wire.ConnectionCloseFrame{ErrorCode: qerr.ProofInvalid, ReasonPhrase: "foobar"}}, protocol.EncryptionUnspecified) + err := sess.handleFrames([]wire.Frame{&wire.ConnectionCloseFrame{ErrorCode: qerr.ProofInvalid, ReasonPhrase: "foobar"}}, 0, protocol.EncryptionUnspecified) Expect(err).NotTo(HaveOccurred()) Eventually(sess.Context().Done()).Should(BeClosed()) }) @@ -456,13 +453,12 @@ var _ = Describe("Session", func() { hdr = &wire.Header{PacketNumberLen: protocol.PacketNumberLen4} }) - It("sets the {last,largest}RcvdPacketNumber", func() { + It("sets the largestRcvdPacketNumber", func() { hdr.PacketNumber = 5 hdr.Raw = []byte("raw header") unpacker.EXPECT().Unpack([]byte("raw header"), hdr, []byte("foobar")).Return(&unpackedPacket{}, nil) err := sess.handlePacketImpl(&receivedPacket{header: hdr, data: []byte("foobar")}) Expect(err).ToNot(HaveOccurred()) - Expect(sess.lastRcvdPacketNumber).To(Equal(protocol.PacketNumber(5))) Expect(sess.largestRcvdPacketNumber).To(Equal(protocol.PacketNumber(5))) }) @@ -509,27 +505,11 @@ var _ = Describe("Session", func() { Eventually(done).Should(BeClosed()) }) - It("sets the {last,largest}RcvdPacketNumber, for an out-of-order packet", func() { - unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(&unpackedPacket{}, nil).Times(2) - hdr.PacketNumber = 5 - err := sess.handlePacketImpl(&receivedPacket{header: hdr}) - Expect(err).ToNot(HaveOccurred()) - Expect(sess.lastRcvdPacketNumber).To(Equal(protocol.PacketNumber(5))) - Expect(sess.largestRcvdPacketNumber).To(Equal(protocol.PacketNumber(5))) - hdr.PacketNumber = 3 - err = sess.handlePacketImpl(&receivedPacket{header: hdr}) - Expect(err).ToNot(HaveOccurred()) - Expect(sess.lastRcvdPacketNumber).To(Equal(protocol.PacketNumber(3))) - Expect(sess.largestRcvdPacketNumber).To(Equal(protocol.PacketNumber(5))) - }) - It("handles duplicate packets", func() { unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(&unpackedPacket{}, nil).Times(2) hdr.PacketNumber = 5 - err := sess.handlePacketImpl(&receivedPacket{header: hdr}) - Expect(err).ToNot(HaveOccurred()) - err = sess.handlePacketImpl(&receivedPacket{header: hdr}) - Expect(err).ToNot(HaveOccurred()) + Expect(sess.handlePacketImpl(&receivedPacket{header: hdr})).To(Succeed()) + Expect(sess.handlePacketImpl(&receivedPacket{header: hdr})).To(Succeed()) }) It("ignores packets with a different source connection ID", func() {