From 182a25d2b93fde206e6173f45764cd91f3fbb2cb Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 8 Feb 2020 21:41:58 +0100 Subject: [PATCH] don't use ASN.1 for marshalling data for the session ticket --- internal/handshake/client_session_cache.go | 57 +++++++++++-------- .../handshake/client_session_cache_test.go | 34 +++++++---- 2 files changed, 56 insertions(+), 35 deletions(-) diff --git a/internal/handshake/client_session_cache.go b/internal/handshake/client_session_cache.go index 8158b3970..f28bb5b30 100644 --- a/internal/handshake/client_session_cache.go +++ b/internal/handshake/client_session_cache.go @@ -3,7 +3,7 @@ package handshake import ( "bytes" "crypto/tls" - "encoding/asn1" + "io" "time" "unsafe" @@ -13,13 +13,7 @@ import ( "github.com/lucas-clemente/quic-go/internal/utils" ) -const clientSessionStateRevision = 0 - -type nonceField struct { - Nonce []byte - AppData []byte - RTT int64 // in ns -} +const clientSessionStateRevision = 1 type clientSessionCache struct { tls.ClientSessionCache @@ -66,13 +60,29 @@ func (c *clientSessionCache) Get(sessionKey string) (*qtls.ClientSessionState, b if rev != clientSessionStateRevision { return nil, false } - var nf nonceField - if rest, err := asn1.Unmarshal(session.nonce[len(session.nonce)-r.Len():], &nf); err != nil || len(rest) != 0 { + rtt, err := utils.ReadVarInt(r) + if err != nil { return nil, false } - c.setAppData(nf.AppData) - session.nonce = nf.Nonce - c.rttStats.SetInitialRTT(time.Duration(nf.RTT) * time.Nanosecond) + appDataLen, err := utils.ReadVarInt(r) + if err != nil { + return nil, false + } + appData := make([]byte, appDataLen) + if _, err := io.ReadFull(r, appData); err != nil { + return nil, false + } + nonceLen, err := utils.ReadVarInt(r) + if err != nil { + return nil, false + } + nonce := make([]byte, nonceLen) + if _, err := io.ReadFull(r, nonce); err != nil { + return nil, false + } + c.setAppData(appData) + session.nonce = nonce + c.rttStats.SetInitialRTT(time.Duration(rtt) * time.Microsecond) var qtlsSession qtls.ClientSessionState qtlsSessBytes := (*[unsafe.Sizeof(qtlsSession)]byte)(unsafe.Pointer(&qtlsSession))[:] copy(qtlsSessBytes, sessBytes) @@ -92,18 +102,15 @@ func (c *clientSessionCache) Put(sessionKey string, cs *qtls.ClientSessionState) var session clientSessionState sessBytes := (*[unsafe.Sizeof(session)]byte)(unsafe.Pointer(&session))[:] copy(sessBytes, qtlsSessBytes) - data, err := asn1.Marshal(nonceField{ - Nonce: session.nonce, - AppData: c.getAppData(), - RTT: c.rttStats.SmoothedRTT().Nanoseconds(), - }) - b := bytes.NewBuffer(make([]byte, 0, int(utils.VarIntLen(clientSessionStateRevision))+len(data))) - utils.WriteVarInt(b, clientSessionStateRevision) - b.Write(data) - if err != nil { // marshaling - panic(err) - } - session.nonce = b.Bytes() + appData := c.getAppData() + buf := &bytes.Buffer{} + utils.WriteVarInt(buf, clientSessionStateRevision) + utils.WriteVarInt(buf, uint64(c.rttStats.SmoothedRTT().Microseconds())) + utils.WriteVarInt(buf, uint64(len(appData))) + buf.Write(appData) + utils.WriteVarInt(buf, uint64(len(session.nonce))) + buf.Write(session.nonce) + session.nonce = buf.Bytes() var tlsSession tls.ClientSessionState tlsSessBytes := (*[unsafe.Sizeof(tlsSession)]byte)(unsafe.Pointer(&tlsSession))[:] copy(tlsSessBytes, sessBytes) diff --git a/internal/handshake/client_session_cache_test.go b/internal/handshake/client_session_cache_test.go index c5b3a372e..85294438b 100644 --- a/internal/handshake/client_session_cache_test.go +++ b/internal/handshake/client_session_cache_test.go @@ -104,19 +104,33 @@ var _ = Describe("ClientSessionCache", func() { }) It("refuses a session state when unmarshalling fails", func() { + rttStats := congestion.NewRTTStats() + rttStats.SetInitialRTT(10 * time.Second) cache := tls.NewLRUClientSessionCache(1) - b := &bytes.Buffer{} - utils.WriteVarInt(b, clientSessionStateRevision) - b.Write([]byte("foobar")) - cache.Put("localhost", encodeIntoSessionTicket(b.Bytes())) - csc := newClientSessionCache( cache, - congestion.NewRTTStats(), - func() []byte { return nil }, - func([]byte) {}, + rttStats, + func() []byte { return []byte("foobar") }, + func(b []byte) {}, ) - _, ok := csc.Get("localhost") - Expect(ok).To(BeFalse()) + csc.Put("localhost", &qtls.ClientSessionState{}) + state, ok := cache.Get("localhost") + Expect(ok).To(BeTrue()) + tlsSessBytes := (*[unsafe.Sizeof(*state)]byte)(unsafe.Pointer(state))[:] + var session clientSessionState + sessBytes := (*[unsafe.Sizeof(session)]byte)(unsafe.Pointer(&session))[:] + copy(sessBytes, tlsSessBytes) + Expect(session.nonce).ToNot(BeEmpty()) + + _, ok = csc.Get("localhost") + Expect(ok).To(BeTrue()) + nonce := session.nonce + + for i := 0; i < len(nonce); i++ { + session.nonce = session.nonce[:i] + copy(tlsSessBytes, sessBytes) + _, ok = csc.Get("localhost") + Expect(ok).To(BeFalse()) + } }) })