diff --git a/session.go b/session.go index 7c68f65c9..9f3f5f991 100644 --- a/session.go +++ b/session.go @@ -751,10 +751,10 @@ func (s *session) maybeResetTimer() { } else { deadline = s.idleTimeoutStartTime().Add(s.idleTimeout) } - if !s.config.DisablePathMTUDiscovery { - if probeTime := s.mtuDiscoverer.NextProbeTime(); !probeTime.IsZero() { - deadline = utils.MinTime(deadline, probeTime) - } + } + if s.handshakeConfirmed && !s.config.DisablePathMTUDiscovery { + if probeTime := s.mtuDiscoverer.NextProbeTime(); !probeTime.IsZero() { + deadline = utils.MinTime(deadline, probeTime) } } @@ -786,6 +786,36 @@ func (s *session) handleHandshakeComplete() { s.connIDManager.SetHandshakeComplete() s.connIDGenerator.SetHandshakeComplete() + if s.perspective == protocol.PerspectiveClient { + s.applyTransportParameters() + return + } + + s.handleHandshakeConfirmed() + + ticket, err := s.cryptoStreamHandler.GetSessionTicket() + if err != nil { + s.closeLocal(err) + } + if ticket != nil { + s.oneRTTStream.Write(ticket) + for s.oneRTTStream.HasData() { + s.queueControlFrame(s.oneRTTStream.PopCryptoFrame(protocol.MaxPostHandshakeCryptoFrameSize)) + } + } + token, err := s.tokenGenerator.NewToken(s.conn.RemoteAddr()) + if err != nil { + s.closeLocal(err) + } + s.queueControlFrame(&wire.NewTokenFrame{Token: token}) + s.queueControlFrame(&wire.HandshakeDoneFrame{}) +} + +func (s *session) handleHandshakeConfirmed() { + s.handshakeConfirmed = true + s.sentPacketHandler.SetHandshakeConfirmed() + s.cryptoStreamHandler.SetHandshakeConfirmed() + if !s.config.DisablePathMTUDiscovery { maxPacketSize := s.peerParams.MaxUDPPayloadSize if maxPacketSize == 0 { @@ -802,31 +832,6 @@ func (s *session) handleHandshakeComplete() { }, ) } - - if s.perspective == protocol.PerspectiveClient { - s.applyTransportParameters() - return - } - - s.handshakeConfirmed = true - s.sentPacketHandler.SetHandshakeConfirmed() - ticket, err := s.cryptoStreamHandler.GetSessionTicket() - if err != nil { - s.closeLocal(err) - } - if ticket != nil { - s.oneRTTStream.Write(ticket) - for s.oneRTTStream.HasData() { - s.queueControlFrame(s.oneRTTStream.PopCryptoFrame(protocol.MaxPostHandshakeCryptoFrameSize)) - } - } - token, err := s.tokenGenerator.NewToken(s.conn.RemoteAddr()) - if err != nil { - s.closeLocal(err) - } - s.queueControlFrame(&wire.NewTokenFrame{Token: token}) - s.cryptoStreamHandler.SetHandshakeConfirmed() - s.queueControlFrame(&wire.HandshakeDoneFrame{}) } func (s *session) handlePacketImpl(rp *receivedPacket) bool { @@ -1353,9 +1358,9 @@ func (s *session) handleHandshakeDoneFrame() error { if s.perspective == protocol.PerspectiveServer { return qerr.NewError(qerr.ProtocolViolation, "received a HANDSHAKE_DONE frame") } - s.handshakeConfirmed = true - s.sentPacketHandler.SetHandshakeConfirmed() - s.cryptoStreamHandler.SetHandshakeConfirmed() + if !s.handshakeConfirmed { + s.handleHandshakeConfirmed() + } return nil } @@ -1718,7 +1723,7 @@ func (s *session) sendPacket() (bool, error) { s.sendQueue.Send(packet.buffer) return true, nil } - if !s.config.DisablePathMTUDiscovery && s.handshakeComplete && s.mtuDiscoverer.ShouldSendProbe(now) { + if !s.config.DisablePathMTUDiscovery && s.mtuDiscoverer.ShouldSendProbe(now) { packet, err := s.packer.PackMTUProbePacket(s.mtuDiscoverer.GetPing()) if err != nil { return false, err diff --git a/session_test.go b/session_test.go index 31681ca7d..4cba0ace8 100644 --- a/session_test.go +++ b/session_test.go @@ -2555,6 +2555,7 @@ var _ = Describe("Client Session", func() { }) It("handles HANDSHAKE_DONE frames", func() { + sess.peerParams = &wire.TransportParameters{} sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sess.sentPacketHandler = sph sph.EXPECT().SetHandshakeConfirmed()