diff --git a/internal/handshake/crypto_setup.go b/internal/handshake/crypto_setup.go index f698db66..9c28683e 100644 --- a/internal/handshake/crypto_setup.go +++ b/internal/handshake/crypto_setup.go @@ -221,13 +221,6 @@ func (h *cryptoSetup) ChangeConnectionID(id protocol.ConnectionID) { func (h *cryptoSetup) SetLargest1RTTAcked(pn protocol.PacketNumber) { h.aead.SetLargestAcked(pn) - // drop handshake keys - if h.handshakeOpener != nil { - h.handshakeOpener = nil - h.handshakeSealer = nil - h.logger.Debugf("Dropping Handshake keys.") - h.runner.DropKeys(protocol.EncryptionHandshake) - } } func (h *cryptoSetup) RunHandshake() { @@ -563,6 +556,21 @@ func (h *cryptoSetup) dropInitialKeys() { h.logger.Debugf("Dropping Initial keys.") } +func (h *cryptoSetup) DropHandshakeKeys() { + var dropped bool + h.mutex.Lock() + if h.handshakeOpener != nil { + h.handshakeOpener = nil + h.handshakeSealer = nil + dropped = true + } + h.mutex.Unlock() + if dropped { + h.runner.DropKeys(protocol.EncryptionHandshake) + h.logger.Debugf("Dropping Handshake keys.") + } +} + func (h *cryptoSetup) GetInitialSealer() (LongHeaderSealer, error) { h.mutex.Lock() defer h.mutex.Unlock() diff --git a/internal/handshake/interface.go b/internal/handshake/interface.go index 1baee25a..1159266a 100644 --- a/internal/handshake/interface.go +++ b/internal/handshake/interface.go @@ -73,6 +73,7 @@ type CryptoSetup interface { HandleMessage([]byte, protocol.EncryptionLevel) bool SetLargest1RTTAcked(protocol.PacketNumber) + DropHandshakeKeys() ConnectionState() tls.ConnectionState GetInitialOpener() (LongHeaderOpener, error) diff --git a/internal/mocks/crypto_setup.go b/internal/mocks/crypto_setup.go index af24ed46..411bbd04 100644 --- a/internal/mocks/crypto_setup.go +++ b/internal/mocks/crypto_setup.go @@ -76,6 +76,18 @@ func (mr *MockCryptoSetupMockRecorder) ConnectionState() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConnectionState", reflect.TypeOf((*MockCryptoSetup)(nil).ConnectionState)) } +// DropHandshakeKeys mocks base method +func (m *MockCryptoSetup) DropHandshakeKeys() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "DropHandshakeKeys") +} + +// DropHandshakeKeys indicates an expected call of DropHandshakeKeys +func (mr *MockCryptoSetupMockRecorder) DropHandshakeKeys() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropHandshakeKeys", reflect.TypeOf((*MockCryptoSetup)(nil).DropHandshakeKeys)) +} + // Get1RTTOpener mocks base method func (m *MockCryptoSetup) Get1RTTOpener() (handshake.ShortHeaderOpener, error) { m.ctrl.T.Helper() diff --git a/internal/wire/frame_parser.go b/internal/wire/frame_parser.go index fbd66d44..4488ecc2 100644 --- a/internal/wire/frame_parser.go +++ b/internal/wire/frame_parser.go @@ -85,6 +85,8 @@ func (p *frameParser) parseFrame(r *bytes.Reader, typeByte byte, encLevel protoc frame, err = parsePathResponseFrame(r, p.version) case 0x1c, 0x1d: frame, err = parseConnectionCloseFrame(r, p.version) + case 0x1e: + frame, err = parseHandshakeDoneFrame(r, p.version) default: err = errors.New("unknown frame type") } diff --git a/internal/wire/frame_parser_test.go b/internal/wire/frame_parser_test.go index 1d204bdf..f3990b72 100644 --- a/internal/wire/frame_parser_test.go +++ b/internal/wire/frame_parser_test.go @@ -271,6 +271,15 @@ var _ = Describe("Frame parsing", func() { Expect(frame).To(Equal(f)) }) + It("unpacks HANDSHAKE_DONE frames", func() { + f := &HandshakeDoneFrame{} + buf := &bytes.Buffer{} + Expect(f.Write(buf, versionIETFFrames)).To(Succeed()) + frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).To(Equal(f)) + }) + It("errors on invalid type", func() { _, err := parser.ParseNext(bytes.NewReader([]byte{0x42}), protocol.Encryption1RTT) Expect(err).To(MatchError("FRAME_ENCODING_ERROR (frame type: 0x42): unknown frame type")) @@ -308,6 +317,7 @@ var _ = Describe("Frame parsing", func() { &PathChallengeFrame{}, &PathResponseFrame{}, &ConnectionCloseFrame{}, + &HandshakeDoneFrame{}, } var framesSerialized [][]byte diff --git a/internal/wire/handshake_done_frame.go b/internal/wire/handshake_done_frame.go new file mode 100644 index 00000000..158d659f --- /dev/null +++ b/internal/wire/handshake_done_frame.go @@ -0,0 +1,28 @@ +package wire + +import ( + "bytes" + + "github.com/lucas-clemente/quic-go/internal/protocol" +) + +// A HandshakeDoneFrame is a HANDSHAKE_DONE frame +type HandshakeDoneFrame struct{} + +// ParseHandshakeDoneFrame parses a HandshakeDone frame +func parseHandshakeDoneFrame(r *bytes.Reader, _ protocol.VersionNumber) (*HandshakeDoneFrame, error) { + if _, err := r.ReadByte(); err != nil { + return nil, err + } + return &HandshakeDoneFrame{}, nil +} + +func (f *HandshakeDoneFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { + b.WriteByte(0x1e) + return nil +} + +// Length of a written frame +func (f *HandshakeDoneFrame) Length(_ protocol.VersionNumber) protocol.ByteCount { + return 1 +} diff --git a/session.go b/session.go index aaf770c2..6756eb73 100644 --- a/session.go +++ b/session.go @@ -51,6 +51,7 @@ type cryptoStreamHandler interface { RunHandshake() ChangeConnectionID(protocol.ConnectionID) SetLargest1RTTAcked(protocol.PacketNumber) + DropHandshakeKeys() io.Closer ConnectionState() tls.ConnectionState } @@ -605,16 +606,15 @@ func (s *session) handleHandshakeComplete() { s.connIDGenerator.SetHandshakeComplete() s.sentPacketHandler.SetHandshakeComplete() - // The client completes the handshake first (after sending the CFIN). - // We need to make sure it learns about the server completing the handshake, - // in order to stop retransmitting handshake packets. - // They will stop retransmitting handshake packets when receiving the first 1-RTT packet. + if s.perspective == protocol.PerspectiveServer { token, err := s.tokenGenerator.NewToken(s.conn.RemoteAddr()) if err != nil { s.closeLocal(err) } s.queueControlFrame(&wire.NewTokenFrame{Token: token}) + s.cryptoStreamHandler.DropHandshakeKeys() + s.queueControlFrame(&wire.HandshakeDoneFrame{}) } } @@ -857,6 +857,8 @@ func (s *session) handleFrame(f wire.Frame, pn protocol.PacketNumber, encLevel p err = s.handleNewConnectionIDFrame(frame) case *wire.RetireConnectionIDFrame: err = s.handleRetireConnectionIDFrame(frame) + case *wire.HandshakeDoneFrame: + err = s.handleHandshakeDoneFrame() default: err = fmt.Errorf("unexpected frame type: %s", reflect.ValueOf(&frame).Elem().Type().Name()) } @@ -975,6 +977,14 @@ func (s *session) handleRetireConnectionIDFrame(f *wire.RetireConnectionIDFrame) return s.connIDGenerator.Retire(f.SequenceNumber) } +func (s *session) handleHandshakeDoneFrame() error { + if s.perspective == protocol.PerspectiveServer { + return qerr.Error(qerr.ProtocolViolation, "received a HANDSHAKE_DONE frame") + } + s.cryptoStreamHandler.DropHandshakeKeys() + return nil +} + func (s *session) handleAckFrame(frame *wire.AckFrame, pn protocol.PacketNumber, encLevel protocol.EncryptionLevel) error { if err := s.sentPacketHandler.ReceivedAck(frame, pn, encLevel, s.lastPacketReceivedTime); err != nil { return err diff --git a/session_test.go b/session_test.go index ff65253a..5e98b051 100644 --- a/session_test.go +++ b/session_test.go @@ -406,6 +406,10 @@ var _ = Describe("Session", func() { Expect(sess.handleFrame(ccf, 0, protocol.EncryptionUnspecified)).To(Succeed()) Eventually(sess.Context().Done()).Should(BeClosed()) }) + + It("errors on HANDSHAKE_DONE frames", func() { + Expect(sess.handleHandshakeDoneFrame()).To(MatchError("PROTOCOL_VIOLATION: received a HANDSHAKE_DONE frame")) + }) }) It("tells its versions", func() { @@ -1207,6 +1211,7 @@ var _ = Describe("Session", func() { defer GinkgoRecover() <-finishHandshake cryptoSetup.EXPECT().RunHandshake() + cryptoSetup.EXPECT().DropHandshakeKeys() close(sess.handshakeCompleteChan) sess.run() }() @@ -1242,10 +1247,13 @@ var _ = Describe("Session", func() { Eventually(sess.Context().Done()).Should(BeClosed()) }) - It("sends a 1-RTT packet when the handshake completes", func() { + It("sends a HANDSHAKE_DONE frame when the handshake completes", func() { done := make(chan struct{}) sessionRunner.EXPECT().Retire(clientDestConnID) packer.EXPECT().PackPacket().DoAndReturn(func() (*packedPacket, error) { + frames, _ := sess.framer.AppendControlFrames(nil, protocol.MaxByteCount) + Expect(frames).ToNot(BeEmpty()) + Expect(frames[0].Frame).To(BeEquivalentTo(&wire.HandshakeDoneFrame{})) defer close(done) return &packedPacket{ header: &wire.ExtendedHeader{}, @@ -1256,6 +1264,7 @@ var _ = Describe("Session", func() { go func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake() + cryptoSetup.EXPECT().DropHandshakeKeys() close(sess.handshakeCompleteChan) sess.run() }() @@ -1508,6 +1517,7 @@ var _ = Describe("Session", func() { go func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().DropHandshakeKeys().MaxTimes(1) close(sess.handshakeCompleteChan) err := sess.run() nerr, ok := err.(net.Error) @@ -1732,6 +1742,11 @@ var _ = Describe("Client Session", func() { Expect(sess.handleSinglePacket(&receivedPacket{buffer: getPacketBuffer()}, hdr)).To(BeTrue()) }) + It("handles HANDSHAKE_DONE frames", func() { + cryptoSetup.EXPECT().DropHandshakeKeys() + Expect(sess.handleHandshakeDoneFrame()).To(Succeed()) + }) + Context("handling tokens", func() { var mockTokenStore *MockTokenStore