From 85abf5fc536e0f5ac5bfa5a325acc05a1f890e2e Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sun, 2 Feb 2020 14:15:03 +0700 Subject: [PATCH] save the RTT in the session ticket and use it when accepting 0-RTT --- internal/handshake/crypto_setup.go | 11 +++++++++-- internal/handshake/session_ticket.go | 10 +++++++++- internal/handshake/session_ticket_test.go | 19 +++++++++++++++---- 3 files changed, 33 insertions(+), 7 deletions(-) diff --git a/internal/handshake/crypto_setup.go b/internal/handshake/crypto_setup.go index eb235729e..1b76d7c43 100644 --- a/internal/handshake/crypto_setup.go +++ b/internal/handshake/crypto_setup.go @@ -87,6 +87,8 @@ type cryptoSetup struct { // for clients: to see if a ServerHello is a HelloRetryRequest writeRecord chan struct{} + rttStats *congestion.RTTStats + logger utils.Logger perspective protocol.Perspective @@ -198,6 +200,7 @@ func newCryptoSetup( runner: runner, ourParams: tp, paramsChan: extHandler.TransportParameters(), + rttStats: rttStats, logger: logger, perspective: perspective, handshakeDone: make(chan struct{}), @@ -464,7 +467,10 @@ func (h *cryptoSetup) GetSessionTicket() ([]byte, error) { var appData []byte // Save transport parameters to the session ticket if we're allowing 0-RTT. if h.tlsConf.MaxEarlyData > 0 { - appData = (&sessionTicket{Parameters: h.ourParams}).Marshal() + appData = (&sessionTicket{ + Parameters: h.ourParams, + RTT: h.rttStats.SmoothedRTT(), + }).Marshal() } return h.conn.GetSessionTicket(appData) } @@ -479,7 +485,8 @@ func (h *cryptoSetup) accept0RTT(sessionTicketData []byte) bool { } valid := h.ourParams.ValidFor0RTT(t.Parameters) if valid { - h.logger.Debugf("Accepting 0-RTT.") + h.logger.Debugf("Accepting 0-RTT. Restoring RTT from session ticket: %s", t.RTT) + h.rttStats.SetInitialRTT(t.RTT) } else { h.logger.Debugf("Transport parameters changed. Rejecting 0-RTT.") } diff --git a/internal/handshake/session_ticket.go b/internal/handshake/session_ticket.go index 41153af96..0ca726803 100644 --- a/internal/handshake/session_ticket.go +++ b/internal/handshake/session_ticket.go @@ -4,19 +4,22 @@ import ( "bytes" "errors" "fmt" + "time" "github.com/lucas-clemente/quic-go/internal/utils" ) -const sessionTicketRevision = 0 +const sessionTicketRevision = 1 type sessionTicket struct { Parameters *TransportParameters + RTT time.Duration // to be encoded in mus } func (t *sessionTicket) Marshal() []byte { b := &bytes.Buffer{} utils.WriteVarInt(b, sessionTicketRevision) + utils.WriteVarInt(b, uint64(t.RTT.Microseconds())) t.Parameters.MarshalForSessionTicket(b) return b.Bytes() } @@ -30,10 +33,15 @@ func (t *sessionTicket) Unmarshal(b []byte) error { if rev != sessionTicketRevision { return fmt.Errorf("unknown session ticket revision: %d", rev) } + rtt, err := utils.ReadVarInt(r) + if err != nil { + return errors.New("failed to read RTT") + } var tp TransportParameters if err := tp.UnmarshalFromSessionTicket(b[len(b)-r.Len():]); err != nil { return fmt.Errorf("unmarshaling transport parameters from session ticket failed: %s", err.Error()) } t.Parameters = &tp + t.RTT = time.Duration(rtt) * time.Microsecond return nil } diff --git a/internal/handshake/session_ticket_test.go b/internal/handshake/session_ticket_test.go index 682af0d4f..4633f00ef 100644 --- a/internal/handshake/session_ticket_test.go +++ b/internal/handshake/session_ticket_test.go @@ -2,6 +2,7 @@ package handshake import ( "bytes" + "time" "github.com/lucas-clemente/quic-go/internal/utils" @@ -11,14 +12,18 @@ import ( var _ = Describe("Session Ticket", func() { It("marshals and unmarshals a session ticket", func() { - params := &TransportParameters{ - InitialMaxStreamDataBidiLocal: 1, - InitialMaxStreamDataBidiRemote: 2, + ticket := &sessionTicket{ + Parameters: &TransportParameters{ + InitialMaxStreamDataBidiLocal: 1, + InitialMaxStreamDataBidiRemote: 2, + }, + RTT: 1337 * time.Microsecond, } var t sessionTicket - Expect(t.Unmarshal((&sessionTicket{Parameters: params}).Marshal())).To(Succeed()) + Expect(t.Unmarshal(ticket.Marshal())).To(Succeed()) Expect(t.Parameters.InitialMaxStreamDataBidiLocal).To(BeEquivalentTo(1)) Expect(t.Parameters.InitialMaxStreamDataBidiRemote).To(BeEquivalentTo(2)) + Expect(t.RTT).To(Equal(1337 * time.Microsecond)) }) It("refuses to unmarshal if the ticket is too short for the revision", func() { @@ -31,6 +36,12 @@ var _ = Describe("Session Ticket", func() { Expect((&sessionTicket{}).Unmarshal(b.Bytes())).To(MatchError("unknown session ticket revision: 1337")) }) + It("refuses to unmarshal if the RTT cannot be read", func() { + b := &bytes.Buffer{} + utils.WriteVarInt(b, sessionTicketRevision) + Expect((&sessionTicket{}).Unmarshal(b.Bytes())).To(MatchError("failed to read RTT")) + }) + It("refuses to unmarshal if unmarshaling the transport parameters fails", func() { b := &bytes.Buffer{} utils.WriteVarInt(b, sessionTicketRevision)