diff --git a/session.go b/session.go index 7146aa2a6..cc9b10f0e 100644 --- a/session.go +++ b/session.go @@ -313,7 +313,7 @@ var newSession = func( conn.RemoteAddr(), params, &handshakeRunner{ - onReceivedParams: s.processTransportParameters, + onReceivedParams: s.handleTransportParameters, onError: s.closeLocal, dropKeys: s.dropEncryptionLevel, onHandshakeComplete: func() { @@ -434,7 +434,7 @@ var newClientSession = func( conn.RemoteAddr(), params, &handshakeRunner{ - onReceivedParams: s.processTransportParameters, + onReceivedParams: s.handleTransportParameters, onError: s.closeLocal, dropKeys: s.dropEncryptionLevel, onHandshakeComplete: func() { close(s.handshakeCompleteChan) }, @@ -759,7 +759,7 @@ func (s *session) idleTimeoutStartTime() time.Time { func (s *session) handleHandshakeComplete() { s.handshakeComplete = true s.handshakeCompleteChan = nil // prevent this case from ever being selected again - s.handshakeCtxCancel() + defer s.handshakeCtxCancel() // Once the handshake completes, we have derived 1-RTT keys. // There's no point in queueing undecryptable packets for later decryption any more. s.undecryptablePackets = nil @@ -784,27 +784,30 @@ func (s *session) handleHandshakeComplete() { ) } - if s.perspective == protocol.PerspectiveServer { - s.handshakeConfirmed = true - s.sentPacketHandler.SetHandshakeConfirmed() - ticket, err := s.cryptoStreamHandler.GetSessionTicket() - if err != nil { - s.closeLocal(err) - } - if ticket != nil { - s.oneRTTStream.Write(ticket) - for s.oneRTTStream.HasData() { - s.queueControlFrame(s.oneRTTStream.PopCryptoFrame(protocol.MaxPostHandshakeCryptoFrameSize)) - } - } - token, err := s.tokenGenerator.NewToken(s.conn.RemoteAddr()) - if err != nil { - s.closeLocal(err) - } - s.queueControlFrame(&wire.NewTokenFrame{Token: token}) - s.cryptoStreamHandler.SetHandshakeConfirmed() - s.queueControlFrame(&wire.HandshakeDoneFrame{}) + if s.perspective == protocol.PerspectiveClient { + s.applyTransportParameters() + return } + + s.handshakeConfirmed = true + s.sentPacketHandler.SetHandshakeConfirmed() + ticket, err := s.cryptoStreamHandler.GetSessionTicket() + if err != nil { + s.closeLocal(err) + } + if ticket != nil { + s.oneRTTStream.Write(ticket) + for s.oneRTTStream.HasData() { + s.queueControlFrame(s.oneRTTStream.PopCryptoFrame(protocol.MaxPostHandshakeCryptoFrameSize)) + } + } + token, err := s.tokenGenerator.NewToken(s.conn.RemoteAddr()) + if err != nil { + s.closeLocal(err) + } + s.queueControlFrame(&wire.NewTokenFrame{Token: token}) + s.cryptoStreamHandler.SetHandshakeConfirmed() + s.queueControlFrame(&wire.HandshakeDoneFrame{}) } func (s *session) handlePacketImpl(rp *receivedPacket) bool { @@ -1470,13 +1473,22 @@ func (s *session) restoreTransportParameters(params *wire.TransportParameters) { s.streamsMap.UpdateLimits(params) } -func (s *session) processTransportParameters(params *wire.TransportParameters) { - if err := s.processTransportParametersImpl(params); err != nil { +func (s *session) handleTransportParameters(params *wire.TransportParameters) { + if err := s.checkTransportParameters(params); err != nil { s.closeLocal(err) } + s.peerParams = params + // On the client side we have to wait for handshake completion. + // During a 0-RTT connection, we are only allowed to use the new transport parameters for 1-RTT packets. + if s.perspective == protocol.PerspectiveServer { + s.applyTransportParameters() + // On the server side, the early session is ready as soon as we processed + // the client's transport parameters. + close(s.earlySessionReadyChan) + } } -func (s *session) processTransportParametersImpl(params *wire.TransportParameters) error { +func (s *session) checkTransportParameters(params *wire.TransportParameters) error { if s.logger.Debug() { s.logger.Debugf("Processed Transport Parameters: %s", params) } @@ -1489,24 +1501,28 @@ func (s *session) processTransportParametersImpl(params *wire.TransportParameter return qerr.NewError(qerr.TransportParameterError, fmt.Sprintf("expected initial_source_connection_id to equal %s, is %s", s.handshakeDestConnID, params.InitialSourceConnectionID)) } - if s.perspective == protocol.PerspectiveClient { - // check the original_destination_connection_id - if !params.OriginalDestinationConnectionID.Equal(s.origDestConnID) { - return qerr.NewError(qerr.TransportParameterError, fmt.Sprintf("expected original_destination_connection_id to equal %s, is %s", s.origDestConnID, params.OriginalDestinationConnectionID)) - } - if s.retrySrcConnID != nil { // a Retry was performed - if params.RetrySourceConnectionID == nil { - return qerr.NewError(qerr.TransportParameterError, "missing retry_source_connection_id") - } - if !(*params.RetrySourceConnectionID).Equal(*s.retrySrcConnID) { - return qerr.NewError(qerr.TransportParameterError, fmt.Sprintf("expected retry_source_connection_id to equal %s, is %s", s.retrySrcConnID, *params.RetrySourceConnectionID)) - } - } else if params.RetrySourceConnectionID != nil { - return qerr.NewError(qerr.TransportParameterError, "received retry_source_connection_id, although no Retry was performed") - } + if s.perspective == protocol.PerspectiveServer { + return nil } + // check the original_destination_connection_id + if !params.OriginalDestinationConnectionID.Equal(s.origDestConnID) { + return qerr.NewError(qerr.TransportParameterError, fmt.Sprintf("expected original_destination_connection_id to equal %s, is %s", s.origDestConnID, params.OriginalDestinationConnectionID)) + } + if s.retrySrcConnID != nil { // a Retry was performed + if params.RetrySourceConnectionID == nil { + return qerr.NewError(qerr.TransportParameterError, "missing retry_source_connection_id") + } + if !(*params.RetrySourceConnectionID).Equal(*s.retrySrcConnID) { + return qerr.NewError(qerr.TransportParameterError, fmt.Sprintf("expected retry_source_connection_id to equal %s, is %s", s.retrySrcConnID, *params.RetrySourceConnectionID)) + } + } else if params.RetrySourceConnectionID != nil { + return qerr.NewError(qerr.TransportParameterError, "received retry_source_connection_id, although no Retry was performed") + } + return nil +} - s.peerParams = params +func (s *session) applyTransportParameters() { + params := s.peerParams // Our local idle timeout will always be > 0. s.idleTimeout = utils.MinNonZeroDuration(s.config.MaxIdleTimeout, params.MaxIdleTimeout) s.keepAliveInterval = utils.MinDuration(s.idleTimeout/2, protocol.MaxKeepAliveInterval) @@ -1524,12 +1540,6 @@ func (s *session) processTransportParametersImpl(params *wire.TransportParameter // Retire the connection ID. s.connIDManager.AddFromPreferredAddress(params.PreferredAddress.ConnectionID, params.PreferredAddress.StatelessResetToken) } - // On the server side, the early session is ready as soon as we processed - // the client's transport parameters. - if s.perspective == protocol.PerspectiveServer { - close(s.earlySessionReadyChan) - } - return nil } func (s *session) sendPackets() error { diff --git a/session_test.go b/session_test.go index 60a4aa23a..d28908d33 100644 --- a/session_test.go +++ b/session_test.go @@ -2081,7 +2081,7 @@ var _ = Describe("Session", func() { sessionRunner.EXPECT().GetStatelessResetToken(gomock.Any()).Times(2) sessionRunner.EXPECT().Add(gomock.Any(), sess).Times(2) tracer.EXPECT().ReceivedTransportParameters(params) - sess.processTransportParameters(params) + sess.handleTransportParameters(params) Expect(sess.earlySessionReady()).To(BeClosed()) }) }) @@ -2091,7 +2091,7 @@ var _ = Describe("Session", func() { streamManager.EXPECT().UpdateLimits(gomock.Any()) packer.EXPECT().HandleTransportParameters(gomock.Any()) tracer.EXPECT().ReceivedTransportParameters(gomock.Any()) - sess.processTransportParameters(&wire.TransportParameters{ + sess.handleTransportParameters(&wire.TransportParameters{ MaxIdleTimeout: t, InitialSourceConnectionID: destConnID, }) @@ -2457,10 +2457,9 @@ var _ = Describe("Client Session", func() { Eventually(areSessionsRunning).Should(BeFalse()) mconn = NewMockSendConn(mockCtrl) - mconn.EXPECT().RemoteAddr().Return(&net.UDPAddr{}).Times(2) - mconn.EXPECT().LocalAddr().Return(&net.UDPAddr{}) + mconn.EXPECT().RemoteAddr().Return(&net.UDPAddr{}).AnyTimes() + mconn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes() if tlsConf == nil { - mconn.EXPECT().RemoteAddr().Return(&net.UDPAddr{}) tlsConf = &tls.Config{} } sessionRunner = NewMockSessionRunner(mockCtrl) @@ -2774,8 +2773,8 @@ var _ = Describe("Client Session", func() { packer.EXPECT().HandleTransportParameters(gomock.Any()) packer.EXPECT().PackCoalescedPacket().MaxTimes(1) tracer.EXPECT().ReceivedTransportParameters(params) - sess.processTransportParameters(params) - sess.connIDManager.SetHandshakeComplete() + sess.handleTransportParameters(params) + sess.handleHandshakeComplete() // make sure the connection ID is not retired cf, _ := sess.framer.AppendControlFrames(nil, protocol.MaxByteCount) Expect(cf).To(BeEmpty()) @@ -2795,7 +2794,8 @@ var _ = Describe("Client Session", func() { } packer.EXPECT().HandleTransportParameters(gomock.Any()) tracer.EXPECT().ReceivedTransportParameters(params) - sess.processTransportParameters(params) + sess.handleTransportParameters(params) + sess.handleHandshakeComplete() Expect(sess.idleTimeout).To(Equal(18 * time.Second)) }) @@ -2808,7 +2808,7 @@ var _ = Describe("Client Session", func() { } expectClose() tracer.EXPECT().ReceivedTransportParameters(params) - sess.processTransportParameters(params) + sess.handleTransportParameters(params) Eventually(errChan).Should(Receive(MatchError("TRANSPORT_PARAMETER_ERROR: expected initial_source_connection_id to equal deadbeef, is decafbad"))) }) @@ -2821,7 +2821,7 @@ var _ = Describe("Client Session", func() { } expectClose() tracer.EXPECT().ReceivedTransportParameters(params) - sess.processTransportParameters(params) + sess.handleTransportParameters(params) Eventually(errChan).Should(Receive(MatchError("TRANSPORT_PARAMETER_ERROR: missing retry_source_connection_id"))) }) @@ -2835,7 +2835,7 @@ var _ = Describe("Client Session", func() { } expectClose() tracer.EXPECT().ReceivedTransportParameters(params) - sess.processTransportParameters(params) + sess.handleTransportParameters(params) Eventually(errChan).Should(Receive(MatchError("TRANSPORT_PARAMETER_ERROR: expected retry_source_connection_id to equal deadbeef, is deadc0de"))) }) @@ -2848,7 +2848,7 @@ var _ = Describe("Client Session", func() { } expectClose() tracer.EXPECT().ReceivedTransportParameters(params) - sess.processTransportParameters(params) + sess.handleTransportParameters(params) Eventually(errChan).Should(Receive(MatchError("TRANSPORT_PARAMETER_ERROR: received retry_source_connection_id, although no Retry was performed"))) }) @@ -2861,7 +2861,7 @@ var _ = Describe("Client Session", func() { } expectClose() tracer.EXPECT().ReceivedTransportParameters(params) - sess.processTransportParameters(params) + sess.handleTransportParameters(params) Eventually(errChan).Should(Receive(MatchError("TRANSPORT_PARAMETER_ERROR: expected original_destination_connection_id to equal deadbeef, is decafbad"))) }) })