diff --git a/internal/handshake/crypto_setup.go b/internal/handshake/crypto_setup.go index 52446af9..9c28683e 100644 --- a/internal/handshake/crypto_setup.go +++ b/internal/handshake/crypto_setup.go @@ -221,13 +221,6 @@ func (h *cryptoSetup) ChangeConnectionID(id protocol.ConnectionID) { func (h *cryptoSetup) SetLargest1RTTAcked(pn protocol.PacketNumber) { h.aead.SetLargestAcked(pn) - // drop handshake keys - if h.handshakeOpener != nil { - h.handshakeOpener = nil - h.handshakeSealer = nil - h.logger.Debugf("Dropping Handshake keys.") - h.runner.DropKeys(protocol.EncryptionHandshake) - } } func (h *cryptoSetup) RunHandshake() { @@ -564,12 +557,18 @@ func (h *cryptoSetup) dropInitialKeys() { } func (h *cryptoSetup) DropHandshakeKeys() { + var dropped bool h.mutex.Lock() - h.handshakeOpener = nil - h.handshakeSealer = nil + if h.handshakeOpener != nil { + h.handshakeOpener = nil + h.handshakeSealer = nil + dropped = true + } h.mutex.Unlock() - h.runner.DropKeys(protocol.EncryptionHandshake) - h.logger.Debugf("Dropping Handshake keys.") + if dropped { + h.runner.DropKeys(protocol.EncryptionHandshake) + h.logger.Debugf("Dropping Handshake keys.") + } } func (h *cryptoSetup) GetInitialSealer() (LongHeaderSealer, error) { diff --git a/session.go b/session.go index 4665b641..3b4d49c6 100644 --- a/session.go +++ b/session.go @@ -856,6 +856,7 @@ func (s *session) handleFrame(f wire.Frame, pn protocol.PacketNumber, encLevel p case *wire.RetireConnectionIDFrame: err = s.handleRetireConnectionIDFrame(frame) case *wire.HandshakeDoneFrame: + err = s.handleHandshakeDoneFrame() default: err = fmt.Errorf("unexpected frame type: %s", reflect.ValueOf(&frame).Elem().Type().Name()) } @@ -974,6 +975,14 @@ func (s *session) handleRetireConnectionIDFrame(f *wire.RetireConnectionIDFrame) return s.connIDGenerator.Retire(f.SequenceNumber) } +func (s *session) handleHandshakeDoneFrame() error { + if s.perspective == protocol.PerspectiveServer { + return qerr.Error(qerr.ProtocolViolation, "received a HANDSHAKE_DONE frame") + } + s.cryptoStreamHandler.DropHandshakeKeys() + return nil +} + func (s *session) handleAckFrame(frame *wire.AckFrame, pn protocol.PacketNumber, encLevel protocol.EncryptionLevel) error { if err := s.sentPacketHandler.ReceivedAck(frame, pn, encLevel, s.lastPacketReceivedTime); err != nil { return err diff --git a/session_test.go b/session_test.go index c48cd4a8..848227e0 100644 --- a/session_test.go +++ b/session_test.go @@ -404,6 +404,10 @@ var _ = Describe("Session", func() { Expect(sess.handleFrame(ccf, 0, protocol.EncryptionUnspecified)).To(Succeed()) Eventually(sess.Context().Done()).Should(BeClosed()) }) + + It("errors on HANDSHAKE_DONE frames", func() { + Expect(sess.handleHandshakeDoneFrame()).To(MatchError("PROTOCOL_VIOLATION: received a HANDSHAKE_DONE frame")) + }) }) It("tells its versions", func() { @@ -1734,6 +1738,11 @@ var _ = Describe("Client Session", func() { Expect(sess.handleSinglePacket(&receivedPacket{buffer: getPacketBuffer()}, hdr)).To(BeTrue()) }) + It("handles HANDSHAKE_DONE frames", func() { + cryptoSetup.EXPECT().DropHandshakeKeys() + Expect(sess.handleHandshakeDoneFrame()).To(Succeed()) + }) + Context("handling tokens", func() { var mockTokenStore *MockTokenStore