forked from quic-go/quic-go
don't use ASN.1 for marshalling data for the session ticket
This commit is contained in:
@@ -3,7 +3,7 @@ package handshake
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"encoding/asn1"
|
||||
"io"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
@@ -13,13 +13,7 @@ import (
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
const clientSessionStateRevision = 0
|
||||
|
||||
type nonceField struct {
|
||||
Nonce []byte
|
||||
AppData []byte
|
||||
RTT int64 // in ns
|
||||
}
|
||||
const clientSessionStateRevision = 1
|
||||
|
||||
type clientSessionCache struct {
|
||||
tls.ClientSessionCache
|
||||
@@ -66,13 +60,29 @@ func (c *clientSessionCache) Get(sessionKey string) (*qtls.ClientSessionState, b
|
||||
if rev != clientSessionStateRevision {
|
||||
return nil, false
|
||||
}
|
||||
var nf nonceField
|
||||
if rest, err := asn1.Unmarshal(session.nonce[len(session.nonce)-r.Len():], &nf); err != nil || len(rest) != 0 {
|
||||
rtt, err := utils.ReadVarInt(r)
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
c.setAppData(nf.AppData)
|
||||
session.nonce = nf.Nonce
|
||||
c.rttStats.SetInitialRTT(time.Duration(nf.RTT) * time.Nanosecond)
|
||||
appDataLen, err := utils.ReadVarInt(r)
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
appData := make([]byte, appDataLen)
|
||||
if _, err := io.ReadFull(r, appData); err != nil {
|
||||
return nil, false
|
||||
}
|
||||
nonceLen, err := utils.ReadVarInt(r)
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
nonce := make([]byte, nonceLen)
|
||||
if _, err := io.ReadFull(r, nonce); err != nil {
|
||||
return nil, false
|
||||
}
|
||||
c.setAppData(appData)
|
||||
session.nonce = nonce
|
||||
c.rttStats.SetInitialRTT(time.Duration(rtt) * time.Microsecond)
|
||||
var qtlsSession qtls.ClientSessionState
|
||||
qtlsSessBytes := (*[unsafe.Sizeof(qtlsSession)]byte)(unsafe.Pointer(&qtlsSession))[:]
|
||||
copy(qtlsSessBytes, sessBytes)
|
||||
@@ -92,18 +102,15 @@ func (c *clientSessionCache) Put(sessionKey string, cs *qtls.ClientSessionState)
|
||||
var session clientSessionState
|
||||
sessBytes := (*[unsafe.Sizeof(session)]byte)(unsafe.Pointer(&session))[:]
|
||||
copy(sessBytes, qtlsSessBytes)
|
||||
data, err := asn1.Marshal(nonceField{
|
||||
Nonce: session.nonce,
|
||||
AppData: c.getAppData(),
|
||||
RTT: c.rttStats.SmoothedRTT().Nanoseconds(),
|
||||
})
|
||||
b := bytes.NewBuffer(make([]byte, 0, int(utils.VarIntLen(clientSessionStateRevision))+len(data)))
|
||||
utils.WriteVarInt(b, clientSessionStateRevision)
|
||||
b.Write(data)
|
||||
if err != nil { // marshaling
|
||||
panic(err)
|
||||
}
|
||||
session.nonce = b.Bytes()
|
||||
appData := c.getAppData()
|
||||
buf := &bytes.Buffer{}
|
||||
utils.WriteVarInt(buf, clientSessionStateRevision)
|
||||
utils.WriteVarInt(buf, uint64(c.rttStats.SmoothedRTT().Microseconds()))
|
||||
utils.WriteVarInt(buf, uint64(len(appData)))
|
||||
buf.Write(appData)
|
||||
utils.WriteVarInt(buf, uint64(len(session.nonce)))
|
||||
buf.Write(session.nonce)
|
||||
session.nonce = buf.Bytes()
|
||||
var tlsSession tls.ClientSessionState
|
||||
tlsSessBytes := (*[unsafe.Sizeof(tlsSession)]byte)(unsafe.Pointer(&tlsSession))[:]
|
||||
copy(tlsSessBytes, sessBytes)
|
||||
|
||||
@@ -104,19 +104,33 @@ var _ = Describe("ClientSessionCache", func() {
|
||||
})
|
||||
|
||||
It("refuses a session state when unmarshalling fails", func() {
|
||||
rttStats := congestion.NewRTTStats()
|
||||
rttStats.SetInitialRTT(10 * time.Second)
|
||||
cache := tls.NewLRUClientSessionCache(1)
|
||||
b := &bytes.Buffer{}
|
||||
utils.WriteVarInt(b, clientSessionStateRevision)
|
||||
b.Write([]byte("foobar"))
|
||||
cache.Put("localhost", encodeIntoSessionTicket(b.Bytes()))
|
||||
|
||||
csc := newClientSessionCache(
|
||||
cache,
|
||||
congestion.NewRTTStats(),
|
||||
func() []byte { return nil },
|
||||
func([]byte) {},
|
||||
rttStats,
|
||||
func() []byte { return []byte("foobar") },
|
||||
func(b []byte) {},
|
||||
)
|
||||
_, ok := csc.Get("localhost")
|
||||
Expect(ok).To(BeFalse())
|
||||
csc.Put("localhost", &qtls.ClientSessionState{})
|
||||
state, ok := cache.Get("localhost")
|
||||
Expect(ok).To(BeTrue())
|
||||
tlsSessBytes := (*[unsafe.Sizeof(*state)]byte)(unsafe.Pointer(state))[:]
|
||||
var session clientSessionState
|
||||
sessBytes := (*[unsafe.Sizeof(session)]byte)(unsafe.Pointer(&session))[:]
|
||||
copy(sessBytes, tlsSessBytes)
|
||||
Expect(session.nonce).ToNot(BeEmpty())
|
||||
|
||||
_, ok = csc.Get("localhost")
|
||||
Expect(ok).To(BeTrue())
|
||||
nonce := session.nonce
|
||||
|
||||
for i := 0; i < len(nonce); i++ {
|
||||
session.nonce = session.nonce[:i]
|
||||
copy(tlsSessBytes, sessBytes)
|
||||
_, ok = csc.Get("localhost")
|
||||
Expect(ok).To(BeFalse())
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user