diff --git a/internal/handshake/crypto_setup.go b/internal/handshake/crypto_setup.go index a41167fb..e0261fac 100644 --- a/internal/handshake/crypto_setup.go +++ b/internal/handshake/crypto_setup.go @@ -339,7 +339,7 @@ func (h *cryptoSetup) handleMessageForServer(msgType messageType) bool { h.logger.Debugf("Sending HelloRetryRequest") return false case data := <-h.paramsChan: - h.runner.OnReceivedParams(data) + h.handleTransportParameters(data) case <-h.handshakeDone: return false } @@ -404,7 +404,7 @@ func (h *cryptoSetup) handleMessageForClient(msgType messageType) bool { case typeEncryptedExtensions: select { case data := <-h.paramsChan: - h.runner.OnReceivedParams(data) + h.handleTransportParameters(data) case <-h.handshakeDone: return false } @@ -431,6 +431,14 @@ func (h *cryptoSetup) handleMessageForClient(msgType messageType) bool { } } +func (h *cryptoSetup) handleTransportParameters(data []byte) { + var tp TransportParameters + if err := tp.Unmarshal(data, h.perspective.Opposite()); err != nil { + h.runner.OnError(qerr.Error(qerr.TransportParameterError, err.Error())) + } + h.runner.OnReceivedParams(&tp) +} + // only valid for the server func (h *cryptoSetup) maybeSendSessionTicket() { ticket, err := h.conn.GetSessionTicket(h.ourParams.MarshalForSessionTicket()) diff --git a/internal/handshake/crypto_setup_test.go b/internal/handshake/crypto_setup_test.go index fe4afbe6..d58f61c3 100644 --- a/internal/handshake/crypto_setup_test.go +++ b/internal/handshake/crypto_setup_test.go @@ -430,11 +430,11 @@ var _ = Describe("Crypto Setup TLS", func() { }) It("receives transport parameters", func() { - var cTransportParametersRcvd, sTransportParametersRcvd []byte + var cTransportParametersRcvd, sTransportParametersRcvd *TransportParameters cChunkChan, cInitialStream, cHandshakeStream, _ := initStreams() cTransportParameters := &TransportParameters{MaxIdleTimeout: 0x42 * time.Second} cRunner := NewMockHandshakeRunner(mockCtrl) - cRunner.EXPECT().OnReceivedParams(gomock.Any()).Do(func(b []byte) { sTransportParametersRcvd = b }) + cRunner.EXPECT().OnReceivedParams(gomock.Any()).Do(func(tp *TransportParameters) { sTransportParametersRcvd = tp }) cRunner.EXPECT().OnHandshakeComplete() client, _ := NewCryptoSetupClient( cInitialStream, @@ -453,7 +453,7 @@ var _ = Describe("Crypto Setup TLS", func() { sChunkChan, sInitialStream, sHandshakeStream, _ := initStreams() var token [16]byte sRunner := NewMockHandshakeRunner(mockCtrl) - sRunner.EXPECT().OnReceivedParams(gomock.Any()).Do(func(b []byte) { cTransportParametersRcvd = b }) + sRunner.EXPECT().OnReceivedParams(gomock.Any()).Do(func(tp *TransportParameters) { cTransportParametersRcvd = tp }) sRunner.EXPECT().OnHandshakeComplete() sTransportParameters := &TransportParameters{ MaxIdleTimeout: 0x1337 * time.Second, @@ -480,14 +480,9 @@ var _ = Describe("Crypto Setup TLS", func() { close(done) }() Eventually(done).Should(BeClosed()) - Expect(cTransportParametersRcvd).ToNot(BeNil()) - clTP := &TransportParameters{} - Expect(clTP.Unmarshal(cTransportParametersRcvd, protocol.PerspectiveClient)).To(Succeed()) - Expect(clTP.MaxIdleTimeout).To(Equal(cTransportParameters.MaxIdleTimeout)) + Expect(cTransportParametersRcvd.MaxIdleTimeout).To(Equal(cTransportParameters.MaxIdleTimeout)) Expect(sTransportParametersRcvd).ToNot(BeNil()) - srvTP := &TransportParameters{} - Expect(srvTP.Unmarshal(sTransportParametersRcvd, protocol.PerspectiveServer)).To(Succeed()) - Expect(srvTP.MaxIdleTimeout).To(Equal(sTransportParameters.MaxIdleTimeout)) + Expect(sTransportParametersRcvd.MaxIdleTimeout).To(Equal(sTransportParameters.MaxIdleTimeout)) }) Context("with session tickets", func() { diff --git a/internal/handshake/interface.go b/internal/handshake/interface.go index 1915a71b..90aa4197 100644 --- a/internal/handshake/interface.go +++ b/internal/handshake/interface.go @@ -59,7 +59,7 @@ type tlsExtensionHandler interface { } type handshakeRunner interface { - OnReceivedParams([]byte) + OnReceivedParams(*TransportParameters) OnHandshakeComplete() OnError(error) DropKeys(protocol.EncryptionLevel) diff --git a/internal/handshake/mock_handshake_runner_test.go b/internal/handshake/mock_handshake_runner_test.go index 7911e69d..d1d82c65 100644 --- a/internal/handshake/mock_handshake_runner_test.go +++ b/internal/handshake/mock_handshake_runner_test.go @@ -71,7 +71,7 @@ func (mr *MockHandshakeRunnerMockRecorder) OnHandshakeComplete() *gomock.Call { } // OnReceivedParams mocks base method -func (m *MockHandshakeRunner) OnReceivedParams(arg0 []byte) { +func (m *MockHandshakeRunner) OnReceivedParams(arg0 *TransportParameters) { m.ctrl.T.Helper() m.ctrl.Call(m, "OnReceivedParams", arg0) } diff --git a/session.go b/session.go index d65bf9e0..8f58432f 100644 --- a/session.go +++ b/session.go @@ -84,16 +84,16 @@ type sessionRunner interface { } type handshakeRunner struct { - onReceivedParams func([]byte) + onReceivedParams func(*handshake.TransportParameters) onError func(error) dropKeys func(protocol.EncryptionLevel) onHandshakeComplete func() } -func (r *handshakeRunner) OnReceivedParams(b []byte) { r.onReceivedParams(b) } -func (r *handshakeRunner) OnError(e error) { r.onError(e) } -func (r *handshakeRunner) DropKeys(el protocol.EncryptionLevel) { r.dropKeys(el) } -func (r *handshakeRunner) OnHandshakeComplete() { r.onHandshakeComplete() } +func (r *handshakeRunner) OnReceivedParams(tp *handshake.TransportParameters) { r.onReceivedParams(tp) } +func (r *handshakeRunner) OnError(e error) { r.onError(e) } +func (r *handshakeRunner) DropKeys(el protocol.EncryptionLevel) { r.dropKeys(el) } +func (r *handshakeRunner) OnHandshakeComplete() { r.onHandshakeComplete() } type closeError struct { err error @@ -1092,19 +1092,13 @@ func (s *session) dropEncryptionLevel(encLevel protocol.EncryptionLevel) { s.receivedPacketHandler.DropPackets(encLevel) } -func (s *session) processTransportParameters(data []byte) { - var params *handshake.TransportParameters - var err error - switch s.perspective { - case protocol.PerspectiveClient: - params, err = s.processTransportParametersForClient(data) - case protocol.PerspectiveServer: - params, err = s.processTransportParametersForServer(data) - } - if err != nil { - s.closeLocal(err) +func (s *session) processTransportParameters(params *handshake.TransportParameters) { + // check the Retry token + if s.perspective == protocol.PerspectiveClient && !params.OriginalConnectionID.Equal(s.origDestConnID) { + s.closeLocal(qerr.Error(qerr.TransportParameterError, fmt.Sprintf("expected original_connection_id to equal %s, is %s", s.origDestConnID, params.OriginalConnectionID))) return } + s.logger.Debugf("Received Transport Parameters: %s", params) s.peerParams = params // Our local idle timeout will always be > 0. @@ -1122,36 +1116,15 @@ func (s *session) processTransportParameters(data []byte) { if params.StatelessResetToken != nil { s.connIDManager.SetStatelessResetToken(*params.StatelessResetToken) } - // On the server side, the early session is ready as soon as we processed - // the client's transport parameters. - close(s.earlySessionReadyChan) -} - -func (s *session) processTransportParametersForClient(data []byte) (*handshake.TransportParameters, error) { - params := &handshake.TransportParameters{} - if err := params.Unmarshal(data, s.perspective.Opposite()); err != nil { - return nil, err - } - - // check the Retry token - if !params.OriginalConnectionID.Equal(s.origDestConnID) { - return nil, qerr.Error(qerr.TransportParameterError, fmt.Sprintf("expected original_connection_id to equal %s, is %s", s.origDestConnID, params.OriginalConnectionID)) - } // We don't support connection migration yet, so we don't have any use for the preferred_address. if params.PreferredAddress != nil { s.logger.Debugf("Server sent preferred_address. Retiring the preferred_address connection ID.") // Retire the connection ID. s.framer.QueueControlFrame(&wire.RetireConnectionIDFrame{SequenceNumber: 1}) } - return params, nil -} - -func (s *session) processTransportParametersForServer(data []byte) (*handshake.TransportParameters, error) { - params := &handshake.TransportParameters{} - if err := params.Unmarshal(data, s.perspective.Opposite()); err != nil { - return nil, err - } - return params, nil + // On the server side, the early session is ready as soon as we processed + // the client's transport parameters. + close(s.earlySessionReadyChan) } func (s *session) sendPackets() error { diff --git a/session_test.go b/session_test.go index aa1a6a99..cdc5ef24 100644 --- a/session_test.go +++ b/session_test.go @@ -1313,23 +1313,7 @@ var _ = Describe("Session", func() { }) Context("transport parameters", func() { - It("errors if it can't unmarshal the TransportParameters", func() { - go func() { - defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - err := sess.run() - Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(ContainSubstring("transport parameter")) - }() - streamManager.EXPECT().CloseWithError(gomock.Any()) - expectReplaceWithClosed() - packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) - cryptoSetup.EXPECT().Close() - sess.processTransportParameters([]byte("invalid")) - Eventually(sess.Context().Done()).Should(BeClosed()) - }) - - It("processes transport parameters received from the client", func() { + It("process transport parameters received from the client", func() { go func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) @@ -1348,7 +1332,7 @@ var _ = Describe("Session", func() { packer.EXPECT().PackPacket().MaxTimes(3) Expect(sess.earlySessionReady()).ToNot(BeClosed()) sessionRunner.EXPECT().Add(gomock.Any(), sess).Times(2) - sess.processTransportParameters(params.Marshal()) + sess.processTransportParameters(params) Expect(sess.earlySessionReady()).To(BeClosed()) // make the go routine return @@ -1367,10 +1351,9 @@ var _ = Describe("Session", func() { Context("keep-alives", func() { setRemoteIdleTimeout := func(t time.Duration) { - tp := &handshake.TransportParameters{MaxIdleTimeout: t} streamManager.EXPECT().UpdateLimits(gomock.Any()) packer.EXPECT().HandleTransportParameters(gomock.Any()) - sess.processTransportParameters(tp.Marshal()) + sess.processTransportParameters(&handshake.TransportParameters{MaxIdleTimeout: t}) } runSession := func() { @@ -1814,27 +1797,40 @@ var _ = Describe("Client Session", func() { }) Context("transport parameters", func() { - It("errors if it can't unmarshal the TransportParameters", func() { + var ( + closed bool + errChan chan error + ) + + JustBeforeEach(func() { + errChan = make(chan error, 1) + closed = false go func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - err := sess.run() - Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(ContainSubstring("transport parameter")) + errChan <- sess.run() }() - expectReplaceWithClosed() - packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) - cryptoSetup.EXPECT().Close() - sess.processTransportParameters([]byte("invalid")) + }) + + expectClose := func() { + if !closed { + sessionRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) { + Expect(s).To(BeAssignableToTypeOf(&closedLocalSession{})) + Expect(s.Close()).To(Succeed()) + }) + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil).MaxTimes(1) + cryptoSetup.EXPECT().Close() + } + closed = true + } + + AfterEach(func() { + expectClose() + sess.Close() Eventually(sess.Context().Done()).Should(BeClosed()) }) It("immediately retires the preferred_address connection ID", func() { - go func() { - defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - sess.run() - }() params := &handshake.TransportParameters{ PreferredAddress: &handshake.PreferredAddress{ IPv4: net.IPv4(127, 0, 0, 1), @@ -1844,20 +1840,10 @@ var _ = Describe("Client Session", func() { } packer.EXPECT().HandleTransportParameters(gomock.Any()) packer.EXPECT().PackPacket().MaxTimes(1) - sess.processTransportParameters(params.Marshal()) + sess.processTransportParameters(params) cf, _ := sess.framer.AppendControlFrames(nil, protocol.MaxByteCount) Expect(cf).To(HaveLen(1)) Expect(cf[0].Frame).To(Equal(&wire.RetireConnectionIDFrame{SequenceNumber: 1})) - - // make the go routine return - sessionRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) { - Expect(s).To(BeAssignableToTypeOf(&closedLocalSession{})) - Expect(s.Close()).To(Succeed()) - }) - packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) - cryptoSetup.EXPECT().Close() - sess.Close() - Eventually(sess.Context().Done()).Should(BeClosed()) }) It("uses the minimum of the peers' idle timeouts", func() { @@ -1866,27 +1852,27 @@ var _ = Describe("Client Session", func() { MaxIdleTimeout: 18 * time.Second, } packer.EXPECT().HandleTransportParameters(gomock.Any()) - sess.processTransportParameters(params.Marshal()) + sess.processTransportParameters(params) Expect(sess.idleTimeout).To(Equal(18 * time.Second)) }) It("errors if the TransportParameters contain an original_connection_id, although no Retry was performed", func() { - params := &handshake.TransportParameters{ + expectClose() + sess.processTransportParameters(&handshake.TransportParameters{ OriginalConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, StatelessResetToken: &[16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, - } - _, err := sess.processTransportParametersForClient(params.Marshal()) - Expect(err).To(MatchError("TRANSPORT_PARAMETER_ERROR: expected original_connection_id to equal (empty), is 0xdecafbad")) + }) + Eventually(errChan).Should(Receive(MatchError("TRANSPORT_PARAMETER_ERROR: expected original_connection_id to equal (empty), is 0xdecafbad"))) }) It("errors if the TransportParameters contain a wrong original_connection_id", func() { sess.origDestConnID = protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} - params := &handshake.TransportParameters{ + expectClose() + sess.processTransportParameters(&handshake.TransportParameters{ OriginalConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, StatelessResetToken: &[16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, - } - _, err := sess.processTransportParametersForClient(params.Marshal()) - Expect(err).To(MatchError("TRANSPORT_PARAMETER_ERROR: expected original_connection_id to equal 0xdeadbeef, is 0xdecafbad")) + }) + Eventually(errChan).Should(Receive(MatchError("TRANSPORT_PARAMETER_ERROR: expected original_connection_id to equal 0xdeadbeef, is 0xdecafbad"))) }) })