simplify casting of the (q)tls.ClientSessionCache

This commit is contained in:
Marten Seemann
2020-02-26 16:13:57 +07:00
parent d31dcdaa7b
commit 70263249ee
2 changed files with 7 additions and 28 deletions

View File

@@ -48,10 +48,7 @@ func (c *clientSessionCache) Get(sessionKey string) (*qtls.ClientSessionState, b
// In order to allow users of quic-go to use a tls.Config, // In order to allow users of quic-go to use a tls.Config,
// we need this workaround to use the ClientSessionCache. // we need this workaround to use the ClientSessionCache.
// In unsafe.go we check that the two structs are actually identical. // In unsafe.go we check that the two structs are actually identical.
tlsSessBytes := (*[unsafe.Sizeof(*sess)]byte)(unsafe.Pointer(sess))[:] session := (*clientSessionState)(unsafe.Pointer(sess))
var session clientSessionState
sessBytes := (*[unsafe.Sizeof(session)]byte)(unsafe.Pointer(&session))[:]
copy(sessBytes, tlsSessBytes)
r := bytes.NewReader(session.nonce) r := bytes.NewReader(session.nonce)
rev, err := utils.ReadVarInt(r) rev, err := utils.ReadVarInt(r)
if err != nil { if err != nil {
@@ -83,10 +80,7 @@ func (c *clientSessionCache) Get(sessionKey string) (*qtls.ClientSessionState, b
c.setAppData(appData) c.setAppData(appData)
session.nonce = nonce session.nonce = nonce
c.rttStats.SetInitialRTT(time.Duration(rtt) * time.Microsecond) c.rttStats.SetInitialRTT(time.Duration(rtt) * time.Microsecond)
var qtlsSession qtls.ClientSessionState return (*qtls.ClientSessionState)(unsafe.Pointer(session)), ok
qtlsSessBytes := (*[unsafe.Sizeof(qtlsSession)]byte)(unsafe.Pointer(&qtlsSession))[:]
copy(qtlsSessBytes, sessBytes)
return &qtlsSession, ok
} }
func (c *clientSessionCache) Put(sessionKey string, cs *qtls.ClientSessionState) { func (c *clientSessionCache) Put(sessionKey string, cs *qtls.ClientSessionState) {
@@ -98,10 +92,7 @@ func (c *clientSessionCache) Put(sessionKey string, cs *qtls.ClientSessionState)
// In order to allow users of quic-go to use a tls.Config, // In order to allow users of quic-go to use a tls.Config,
// we need this workaround to use the ClientSessionCache. // we need this workaround to use the ClientSessionCache.
// In unsafe.go we check that the two structs are actually identical. // In unsafe.go we check that the two structs are actually identical.
qtlsSessBytes := (*[unsafe.Sizeof(*cs)]byte)(unsafe.Pointer(cs))[:] session := (*clientSessionState)(unsafe.Pointer(cs))
var session clientSessionState
sessBytes := (*[unsafe.Sizeof(session)]byte)(unsafe.Pointer(&session))[:]
copy(sessBytes, qtlsSessBytes)
appData := c.getAppData() appData := c.getAppData()
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
utils.WriteVarInt(buf, clientSessionStateRevision) utils.WriteVarInt(buf, clientSessionStateRevision)
@@ -111,8 +102,5 @@ func (c *clientSessionCache) Put(sessionKey string, cs *qtls.ClientSessionState)
utils.WriteVarInt(buf, uint64(len(session.nonce))) utils.WriteVarInt(buf, uint64(len(session.nonce)))
buf.Write(session.nonce) buf.Write(session.nonce)
session.nonce = buf.Bytes() session.nonce = buf.Bytes()
var tlsSession tls.ClientSessionState c.ClientSessionCache.Put(sessionKey, (*tls.ClientSessionState)(unsafe.Pointer(session)))
tlsSessBytes := (*[unsafe.Sizeof(tlsSession)]byte)(unsafe.Pointer(&tlsSession))[:]
copy(tlsSessBytes, sessBytes)
c.ClientSessionCache.Put(sessionKey, &tlsSession)
} }

View File

@@ -18,13 +18,8 @@ import (
var _ = Describe("ClientSessionCache", func() { var _ = Describe("ClientSessionCache", func() {
encodeIntoSessionTicket := func(data []byte) *tls.ClientSessionState { encodeIntoSessionTicket := func(data []byte) *tls.ClientSessionState {
var session clientSessionState session := &clientSessionState{nonce: data}
sessBytes := (*[unsafe.Sizeof(session)]byte)(unsafe.Pointer(&session))[:] return (*tls.ClientSessionState)(unsafe.Pointer(session))
session.nonce = data
var tlsSession tls.ClientSessionState
tlsSessBytes := (*[unsafe.Sizeof(tlsSession)]byte)(unsafe.Pointer(&tlsSession))[:]
copy(tlsSessBytes, sessBytes)
return &tlsSession
} }
It("puts and gets", func() { It("puts and gets", func() {
@@ -116,10 +111,7 @@ var _ = Describe("ClientSessionCache", func() {
csc.Put("localhost", &qtls.ClientSessionState{}) csc.Put("localhost", &qtls.ClientSessionState{})
state, ok := cache.Get("localhost") state, ok := cache.Get("localhost")
Expect(ok).To(BeTrue()) Expect(ok).To(BeTrue())
tlsSessBytes := (*[unsafe.Sizeof(*state)]byte)(unsafe.Pointer(state))[:] session := (*clientSessionState)(unsafe.Pointer(state))
var session clientSessionState
sessBytes := (*[unsafe.Sizeof(session)]byte)(unsafe.Pointer(&session))[:]
copy(sessBytes, tlsSessBytes)
Expect(session.nonce).ToNot(BeEmpty()) Expect(session.nonce).ToNot(BeEmpty())
_, ok = csc.Get("localhost") _, ok = csc.Get("localhost")
@@ -128,7 +120,6 @@ var _ = Describe("ClientSessionCache", func() {
for i := 0; i < len(nonce); i++ { for i := 0; i < len(nonce); i++ {
session.nonce = session.nonce[:i] session.nonce = session.nonce[:i]
copy(tlsSessBytes, sessBytes)
_, ok = csc.Get("localhost") _, ok = csc.Get("localhost")
Expect(ok).To(BeFalse()) Expect(ok).To(BeFalse())
} }