diff --git a/integrationtests/self/zero_rtt_oldgo_test.go b/integrationtests/self/zero_rtt_oldgo_test.go index aea54271..af5400d1 100644 --- a/integrationtests/self/zero_rtt_oldgo_test.go +++ b/integrationtests/self/zero_rtt_oldgo_test.go @@ -641,6 +641,49 @@ var _ = Describe("0-RTT", func() { Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty()) }) + It("doesn't use 0-RTT, if the server didn't enable it", func() { + server, err := quic.ListenAddr("localhost:0", getTLSConfig(), getQuicConfig(nil)) + Expect(err).ToNot(HaveOccurred()) + defer server.Close() + + gets := make(chan string, 100) + puts := make(chan string, 100) + cache := newClientSessionCache(tls.NewLRUClientSessionCache(10), gets, puts) + tlsConf := getTLSClientConfig() + tlsConf.ClientSessionCache = cache + conn1, err := quic.DialAddr( + context.Background(), + fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), + tlsConf, + getQuicConfig(nil), + ) + Expect(err).ToNot(HaveOccurred()) + defer conn1.CloseWithError(0, "") + var sessionKey string + Eventually(puts).Should(Receive(&sessionKey)) + Expect(conn1.ConnectionState().TLS.DidResume).To(BeFalse()) + + serverConn, err := server.Accept(context.Background()) + Expect(err).ToNot(HaveOccurred()) + Expect(serverConn.ConnectionState().TLS.DidResume).To(BeFalse()) + + conn2, err := quic.DialAddrEarly( + context.Background(), + fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), + tlsConf, + getQuicConfig(nil), + ) + Expect(err).ToNot(HaveOccurred()) + Expect(gets).To(Receive(Equal(sessionKey))) + Expect(conn2.ConnectionState().TLS.DidResume).To(BeTrue()) + + serverConn, err = server.Accept(context.Background()) + Expect(err).ToNot(HaveOccurred()) + Expect(serverConn.ConnectionState().TLS.DidResume).To(BeTrue()) + Expect(serverConn.ConnectionState().Used0RTT).To(BeFalse()) + conn2.CloseWithError(0, "") + }) + DescribeTable("flow control limits", func(addFlowControlLimit func(*quic.Config, uint64)) { counter, tracer := newPacketTracer() diff --git a/integrationtests/self/zero_rtt_test.go b/integrationtests/self/zero_rtt_test.go index c0fc5325..7bfa66ce 100644 --- a/integrationtests/self/zero_rtt_test.go +++ b/integrationtests/self/zero_rtt_test.go @@ -692,6 +692,49 @@ var _ = Describe("0-RTT", func() { Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty()) }) + It("doesn't use 0-RTT, if the server didn't enable it", func() { + server, err := quic.ListenAddr("localhost:0", getTLSConfig(), getQuicConfig(nil)) + Expect(err).ToNot(HaveOccurred()) + defer server.Close() + + gets := make(chan string, 100) + puts := make(chan string, 100) + cache := newClientSessionCache(tls.NewLRUClientSessionCache(10), gets, puts) + tlsConf := getTLSClientConfig() + tlsConf.ClientSessionCache = cache + conn1, err := quic.DialAddr( + context.Background(), + fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), + tlsConf, + getQuicConfig(nil), + ) + Expect(err).ToNot(HaveOccurred()) + defer conn1.CloseWithError(0, "") + var sessionKey string + Eventually(puts).Should(Receive(&sessionKey)) + Expect(conn1.ConnectionState().TLS.DidResume).To(BeFalse()) + + serverConn, err := server.Accept(context.Background()) + Expect(err).ToNot(HaveOccurred()) + Expect(serverConn.ConnectionState().TLS.DidResume).To(BeFalse()) + + conn2, err := quic.DialAddrEarly( + context.Background(), + fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), + tlsConf, + getQuicConfig(nil), + ) + Expect(err).ToNot(HaveOccurred()) + Expect(gets).To(Receive(Equal(sessionKey))) + Expect(conn2.ConnectionState().TLS.DidResume).To(BeTrue()) + + serverConn, err = server.Accept(context.Background()) + Expect(err).ToNot(HaveOccurred()) + Expect(serverConn.ConnectionState().TLS.DidResume).To(BeTrue()) + Expect(serverConn.ConnectionState().Used0RTT).To(BeFalse()) + conn2.CloseWithError(0, "") + }) + DescribeTable("flow control limits", func(addFlowControlLimit func(*quic.Config, uint64)) { counter, tracer := newPacketTracer() diff --git a/internal/handshake/crypto_setup.go b/internal/handshake/crypto_setup.go index c5787e86..34f29887 100644 --- a/internal/handshake/crypto_setup.go +++ b/internal/handshake/crypto_setup.go @@ -25,7 +25,7 @@ type quicVersionContextKey struct{} var QUICVersionContextKey = &quicVersionContextKey{} -const clientSessionStateRevision = 3 +const clientSessionStateRevision = 4 type cryptoSetup struct { tlsConf *tls.Config @@ -313,19 +313,24 @@ func (h *cryptoSetup) handleTransportParameters(data []byte) error { } // must be called after receiving the transport parameters -func (h *cryptoSetup) marshalDataForSessionState() []byte { +func (h *cryptoSetup) marshalDataForSessionState(earlyData bool) []byte { b := make([]byte, 0, 256) b = quicvarint.Append(b, clientSessionStateRevision) b = quicvarint.Append(b, uint64(h.rttStats.SmoothedRTT().Microseconds())) - return h.peerParams.MarshalForSessionTicket(b) + if earlyData { + // only save the transport parameters for 0-RTT enabled session tickets + return h.peerParams.MarshalForSessionTicket(b) + } + return b } -func (h *cryptoSetup) handleDataFromSessionState(data []byte) (allowEarlyData bool) { - tp, err := h.handleDataFromSessionStateImpl(data) +func (h *cryptoSetup) handleDataFromSessionState(data []byte, earlyData bool) (allowEarlyData bool) { + rtt, tp, err := decodeDataFromSessionState(data, earlyData) if err != nil { h.logger.Debugf("Restoring of transport parameters from session ticket failed: %s", err.Error()) return } + h.rttStats.SetInitialRTT(rtt) // The session ticket might have been saved from a connection that allowed 0-RTT, // and therefore contain transport parameters. // Only use them if 0-RTT is actually used on the new connection. @@ -336,25 +341,28 @@ func (h *cryptoSetup) handleDataFromSessionState(data []byte) (allowEarlyData bo return false } -func (h *cryptoSetup) handleDataFromSessionStateImpl(data []byte) (*wire.TransportParameters, error) { +func decodeDataFromSessionState(data []byte, earlyData bool) (time.Duration, *wire.TransportParameters, error) { r := bytes.NewReader(data) ver, err := quicvarint.Read(r) if err != nil { - return nil, err + return 0, nil, err } if ver != clientSessionStateRevision { - return nil, fmt.Errorf("mismatching version. Got %d, expected %d", ver, clientSessionStateRevision) + return 0, nil, fmt.Errorf("mismatching version. Got %d, expected %d", ver, clientSessionStateRevision) } - rtt, err := quicvarint.Read(r) + rttEncoded, err := quicvarint.Read(r) if err != nil { - return nil, err + return 0, nil, err + } + rtt := time.Duration(rttEncoded) * time.Microsecond + if !earlyData { + return rtt, nil, nil } - h.rttStats.SetInitialRTT(time.Duration(rtt) * time.Microsecond) var tp wire.TransportParameters if err := tp.UnmarshalFromSessionTicket(r); err != nil { - return nil, err + return 0, nil, err } - return &tp, nil + return rtt, &tp, nil } func (h *cryptoSetup) getDataForSessionTicket() []byte { diff --git a/internal/qtls/client_session_cache.go b/internal/qtls/client_session_cache.go index 336d6035..d81eb8c3 100644 --- a/internal/qtls/client_session_cache.go +++ b/internal/qtls/client_session_cache.go @@ -7,8 +7,8 @@ import ( ) type clientSessionCache struct { - getData func() []byte - setData func([]byte) (allowEarlyData bool) + getData func(earlyData bool) []byte + setData func(data []byte, earlyData bool) (allowEarlyData bool) wrapped tls.ClientSessionCache } @@ -24,7 +24,7 @@ func (c clientSessionCache) Put(key string, cs *tls.ClientSessionState) { c.wrapped.Put(key, cs) return } - state.Extra = append(state.Extra, addExtraPrefix(c.getData())) + state.Extra = append(state.Extra, addExtraPrefix(c.getData(state.EarlyData))) newCS, err := tls.NewResumptionState(ticket, state) if err != nil { // It's not clear why this would error. Just save the original state. @@ -46,12 +46,13 @@ func (c clientSessionCache) Get(key string) (*tls.ClientSessionState, bool) { c.wrapped.Put(key, nil) return nil, false } - var earlyData bool // restore QUIC transport parameters and RTT stored in state.Extra if extra := findExtraData(state.Extra); extra != nil { - earlyData = c.setData(extra) + earlyData := c.setData(extra, state.EarlyData) + if state.EarlyData { + state.EarlyData = earlyData + } } - state.EarlyData = earlyData session, err := tls.NewResumptionState(ticket, state) if err != nil { // It's not clear why this would error. diff --git a/internal/qtls/client_session_cache_test.go b/internal/qtls/client_session_cache_test.go index 6af19293..fdb0aa06 100644 --- a/internal/qtls/client_session_cache_test.go +++ b/internal/qtls/client_session_cache_test.go @@ -40,8 +40,9 @@ var _ = Describe("Client Session Cache", func() { RootCAs: testdata.GetRootCA(), ClientSessionCache: &clientSessionCache{ wrapped: tls.NewLRUClientSessionCache(10), - getData: func() []byte { return []byte("session") }, - setData: func(data []byte) bool { + getData: func(bool) []byte { return []byte("session") }, + setData: func(data []byte, earlyData bool) bool { + Expect(earlyData).To(BeFalse()) // running on top of TCP, we can only test non-0-RTT here restored <- data return true }, diff --git a/internal/qtls/go120.go b/internal/qtls/go120.go index 7e7eee1e..554aeaf4 100644 --- a/internal/qtls/go120.go +++ b/internal/qtls/go120.go @@ -52,10 +52,20 @@ func SetupConfigForServer(conf *QUICConfig, enable0RTT bool, getDataForSessionTi } } -func SetupConfigForClient(conf *QUICConfig, getDataForSessionState func() []byte, setDataFromSessionState func([]byte) bool) { +func SetupConfigForClient( + conf *QUICConfig, + getDataForSessionState func(earlyData bool) []byte, + setDataFromSessionState func(data []byte, earlyData bool) (allowEarlyData bool), +) { conf.ExtraConfig = &qtls.ExtraConfig{ - GetAppDataForSessionState: getDataForSessionState, - SetAppDataFromSessionState: setDataFromSessionState, + GetAppDataForSessionState: func() []byte { + // qtls only calls the GetAppDataForSessionState when doing 0-RTT + return getDataForSessionState(true) + }, + SetAppDataFromSessionState: func(data []byte) (allowEarlyData bool) { + // qtls only calls the SetAppDataFromSessionState for 0-RTT enabled tickets + return setDataFromSessionState(data, true) + }, } } diff --git a/internal/qtls/go121.go b/internal/qtls/go121.go index 35a52ce0..66e289b5 100644 --- a/internal/qtls/go121.go +++ b/internal/qtls/go121.go @@ -93,7 +93,11 @@ func SetupConfigForServer(qconf *QUICConfig, _ bool, getData func() []byte, hand } } -func SetupConfigForClient(qconf *QUICConfig, getData func() []byte, setData func([]byte) bool) { +func SetupConfigForClient( + qconf *QUICConfig, + getData func(earlyData bool) []byte, + setData func(data []byte, earlyData bool) (allowEarlyData bool), +) { conf := qconf.TLSConfig if conf.ClientSessionCache != nil { origCache := conf.ClientSessionCache