diff --git a/internal/handshake/client_session_cache.go b/internal/handshake/client_session_cache.go index f28bb5b30..43b816683 100644 --- a/internal/handshake/client_session_cache.go +++ b/internal/handshake/client_session_cache.go @@ -48,10 +48,7 @@ func (c *clientSessionCache) Get(sessionKey string) (*qtls.ClientSessionState, b // In order to allow users of quic-go to use a tls.Config, // we need this workaround to use the ClientSessionCache. // In unsafe.go we check that the two structs are actually identical. - tlsSessBytes := (*[unsafe.Sizeof(*sess)]byte)(unsafe.Pointer(sess))[:] - var session clientSessionState - sessBytes := (*[unsafe.Sizeof(session)]byte)(unsafe.Pointer(&session))[:] - copy(sessBytes, tlsSessBytes) + session := (*clientSessionState)(unsafe.Pointer(sess)) r := bytes.NewReader(session.nonce) rev, err := utils.ReadVarInt(r) if err != nil { @@ -83,10 +80,7 @@ func (c *clientSessionCache) Get(sessionKey string) (*qtls.ClientSessionState, b c.setAppData(appData) session.nonce = nonce c.rttStats.SetInitialRTT(time.Duration(rtt) * time.Microsecond) - var qtlsSession qtls.ClientSessionState - qtlsSessBytes := (*[unsafe.Sizeof(qtlsSession)]byte)(unsafe.Pointer(&qtlsSession))[:] - copy(qtlsSessBytes, sessBytes) - return &qtlsSession, ok + return (*qtls.ClientSessionState)(unsafe.Pointer(session)), ok } func (c *clientSessionCache) Put(sessionKey string, cs *qtls.ClientSessionState) { @@ -98,10 +92,7 @@ func (c *clientSessionCache) Put(sessionKey string, cs *qtls.ClientSessionState) // In order to allow users of quic-go to use a tls.Config, // we need this workaround to use the ClientSessionCache. // In unsafe.go we check that the two structs are actually identical. - qtlsSessBytes := (*[unsafe.Sizeof(*cs)]byte)(unsafe.Pointer(cs))[:] - var session clientSessionState - sessBytes := (*[unsafe.Sizeof(session)]byte)(unsafe.Pointer(&session))[:] - copy(sessBytes, qtlsSessBytes) + session := (*clientSessionState)(unsafe.Pointer(cs)) appData := c.getAppData() buf := &bytes.Buffer{} utils.WriteVarInt(buf, clientSessionStateRevision) @@ -111,8 +102,5 @@ func (c *clientSessionCache) Put(sessionKey string, cs *qtls.ClientSessionState) utils.WriteVarInt(buf, uint64(len(session.nonce))) buf.Write(session.nonce) session.nonce = buf.Bytes() - var tlsSession tls.ClientSessionState - tlsSessBytes := (*[unsafe.Sizeof(tlsSession)]byte)(unsafe.Pointer(&tlsSession))[:] - copy(tlsSessBytes, sessBytes) - c.ClientSessionCache.Put(sessionKey, &tlsSession) + c.ClientSessionCache.Put(sessionKey, (*tls.ClientSessionState)(unsafe.Pointer(session))) } diff --git a/internal/handshake/client_session_cache_test.go b/internal/handshake/client_session_cache_test.go index 85294438b..330d8e4d4 100644 --- a/internal/handshake/client_session_cache_test.go +++ b/internal/handshake/client_session_cache_test.go @@ -18,13 +18,8 @@ import ( var _ = Describe("ClientSessionCache", func() { encodeIntoSessionTicket := func(data []byte) *tls.ClientSessionState { - var session clientSessionState - sessBytes := (*[unsafe.Sizeof(session)]byte)(unsafe.Pointer(&session))[:] - session.nonce = data - var tlsSession tls.ClientSessionState - tlsSessBytes := (*[unsafe.Sizeof(tlsSession)]byte)(unsafe.Pointer(&tlsSession))[:] - copy(tlsSessBytes, sessBytes) - return &tlsSession + session := &clientSessionState{nonce: data} + return (*tls.ClientSessionState)(unsafe.Pointer(session)) } It("puts and gets", func() { @@ -116,10 +111,7 @@ var _ = Describe("ClientSessionCache", func() { csc.Put("localhost", &qtls.ClientSessionState{}) state, ok := cache.Get("localhost") Expect(ok).To(BeTrue()) - tlsSessBytes := (*[unsafe.Sizeof(*state)]byte)(unsafe.Pointer(state))[:] - var session clientSessionState - sessBytes := (*[unsafe.Sizeof(session)]byte)(unsafe.Pointer(&session))[:] - copy(sessBytes, tlsSessBytes) + session := (*clientSessionState)(unsafe.Pointer(state)) Expect(session.nonce).ToNot(BeEmpty()) _, ok = csc.Get("localhost") @@ -128,7 +120,6 @@ var _ = Describe("ClientSessionCache", func() { for i := 0; i < len(nonce); i++ { session.nonce = session.nonce[:i] - copy(tlsSessBytes, sessBytes) _, ok = csc.Get("localhost") Expect(ok).To(BeFalse()) }