From 12922bdec9e6683f535e81fd8c809e988251f113 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Fri, 22 Nov 2019 11:06:34 +0800 Subject: [PATCH] drop Handshake keys when receiving HANDSHAKE_DONE (as a client) --- internal/handshake/crypto_setup.go | 21 ++++++++++----------- session.go | 9 +++++++++ session_test.go | 9 +++++++++ 3 files changed, 28 insertions(+), 11 deletions(-) diff --git a/internal/handshake/crypto_setup.go b/internal/handshake/crypto_setup.go index 52446af99..9c28683ec 100644 --- a/internal/handshake/crypto_setup.go +++ b/internal/handshake/crypto_setup.go @@ -221,13 +221,6 @@ func (h *cryptoSetup) ChangeConnectionID(id protocol.ConnectionID) { func (h *cryptoSetup) SetLargest1RTTAcked(pn protocol.PacketNumber) { h.aead.SetLargestAcked(pn) - // drop handshake keys - if h.handshakeOpener != nil { - h.handshakeOpener = nil - h.handshakeSealer = nil - h.logger.Debugf("Dropping Handshake keys.") - h.runner.DropKeys(protocol.EncryptionHandshake) - } } func (h *cryptoSetup) RunHandshake() { @@ -564,12 +557,18 @@ func (h *cryptoSetup) dropInitialKeys() { } func (h *cryptoSetup) DropHandshakeKeys() { + var dropped bool h.mutex.Lock() - h.handshakeOpener = nil - h.handshakeSealer = nil + if h.handshakeOpener != nil { + h.handshakeOpener = nil + h.handshakeSealer = nil + dropped = true + } h.mutex.Unlock() - h.runner.DropKeys(protocol.EncryptionHandshake) - h.logger.Debugf("Dropping Handshake keys.") + if dropped { + h.runner.DropKeys(protocol.EncryptionHandshake) + h.logger.Debugf("Dropping Handshake keys.") + } } func (h *cryptoSetup) GetInitialSealer() (LongHeaderSealer, error) { diff --git a/session.go b/session.go index 4665b641b..3b4d49c63 100644 --- a/session.go +++ b/session.go @@ -856,6 +856,7 @@ func (s *session) handleFrame(f wire.Frame, pn protocol.PacketNumber, encLevel p case *wire.RetireConnectionIDFrame: err = s.handleRetireConnectionIDFrame(frame) case *wire.HandshakeDoneFrame: + err = s.handleHandshakeDoneFrame() default: err = fmt.Errorf("unexpected frame type: %s", reflect.ValueOf(&frame).Elem().Type().Name()) } @@ -974,6 +975,14 @@ func (s *session) handleRetireConnectionIDFrame(f *wire.RetireConnectionIDFrame) return s.connIDGenerator.Retire(f.SequenceNumber) } +func (s *session) handleHandshakeDoneFrame() error { + if s.perspective == protocol.PerspectiveServer { + return qerr.Error(qerr.ProtocolViolation, "received a HANDSHAKE_DONE frame") + } + s.cryptoStreamHandler.DropHandshakeKeys() + return nil +} + func (s *session) handleAckFrame(frame *wire.AckFrame, pn protocol.PacketNumber, encLevel protocol.EncryptionLevel) error { if err := s.sentPacketHandler.ReceivedAck(frame, pn, encLevel, s.lastPacketReceivedTime); err != nil { return err diff --git a/session_test.go b/session_test.go index c48cd4a80..848227e09 100644 --- a/session_test.go +++ b/session_test.go @@ -404,6 +404,10 @@ var _ = Describe("Session", func() { Expect(sess.handleFrame(ccf, 0, protocol.EncryptionUnspecified)).To(Succeed()) Eventually(sess.Context().Done()).Should(BeClosed()) }) + + It("errors on HANDSHAKE_DONE frames", func() { + Expect(sess.handleHandshakeDoneFrame()).To(MatchError("PROTOCOL_VIOLATION: received a HANDSHAKE_DONE frame")) + }) }) It("tells its versions", func() { @@ -1734,6 +1738,11 @@ var _ = Describe("Client Session", func() { Expect(sess.handleSinglePacket(&receivedPacket{buffer: getPacketBuffer()}, hdr)).To(BeTrue()) }) + It("handles HANDSHAKE_DONE frames", func() { + cryptoSetup.EXPECT().DropHandshakeKeys() + Expect(sess.handleHandshakeDoneFrame()).To(Succeed()) + }) + Context("handling tokens", func() { var mockTokenStore *MockTokenStore