forked from quic-go/quic-go
make the ClientSessionCache work with qtls
This commit is contained in:
@@ -2,10 +2,45 @@ package handshake
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"unsafe"
|
||||
|
||||
"github.com/marten-seemann/qtls"
|
||||
)
|
||||
|
||||
type clientSessionCache struct {
|
||||
tls.ClientSessionCache
|
||||
}
|
||||
|
||||
var _ qtls.ClientSessionCache = &clientSessionCache{}
|
||||
|
||||
func (c *clientSessionCache) Get(sessionKey string) (*qtls.ClientSessionState, bool) {
|
||||
sess, ok := c.ClientSessionCache.Get(sessionKey)
|
||||
if sess == nil {
|
||||
return nil, ok
|
||||
}
|
||||
// qtls.ClientSessionState is identical to the tls.ClientSessionState.
|
||||
// In order to allow users of quic-go to use a tls.Config,
|
||||
// we need this workaround to use the ClientSessionCache.
|
||||
// In unsafe.go we check that the two structs are actually identical.
|
||||
usess := (*[unsafe.Sizeof(*sess)]byte)(unsafe.Pointer(sess))[:]
|
||||
var session qtls.ClientSessionState
|
||||
usession := (*[unsafe.Sizeof(session)]byte)(unsafe.Pointer(&session))[:]
|
||||
copy(usession, usess)
|
||||
return &session, ok
|
||||
}
|
||||
|
||||
func (c *clientSessionCache) Put(sessionKey string, cs *qtls.ClientSessionState) {
|
||||
// qtls.ClientSessionState is identical to the tls.ClientSessionState.
|
||||
// In order to allow users of quic-go to use a tls.Config,
|
||||
// we need this workaround to use the ClientSessionCache.
|
||||
// In unsafe.go we check that the two structs are actually identical.
|
||||
usess := (*[unsafe.Sizeof(*cs)]byte)(unsafe.Pointer(cs))[:]
|
||||
var session tls.ClientSessionState
|
||||
usession := (*[unsafe.Sizeof(session)]byte)(unsafe.Pointer(&session))[:]
|
||||
copy(usession, usess)
|
||||
c.ClientSessionCache.Put(sessionKey, &session)
|
||||
}
|
||||
|
||||
func tlsConfigToQtlsConfig(
|
||||
c *tls.Config,
|
||||
recordLayer qtls.RecordLayer,
|
||||
@@ -36,6 +71,10 @@ func tlsConfigToQtlsConfig(
|
||||
return tlsConfigToQtlsConfig(tlsConf, recordLayer, extHandler), nil
|
||||
}
|
||||
}
|
||||
var csc qtls.ClientSessionCache
|
||||
if c.ClientSessionCache != nil {
|
||||
csc = &clientSessionCache{c.ClientSessionCache}
|
||||
}
|
||||
return &qtls.Config{
|
||||
Rand: c.Rand,
|
||||
Time: c.Time,
|
||||
@@ -55,6 +94,7 @@ func tlsConfigToQtlsConfig(
|
||||
PreferServerCipherSuites: c.PreferServerCipherSuites,
|
||||
SessionTicketsDisabled: c.SessionTicketsDisabled,
|
||||
SessionTicketKey: c.SessionTicketKey,
|
||||
ClientSessionCache: csc,
|
||||
MinVersion: minVersion,
|
||||
MaxVersion: maxVersion,
|
||||
CurvePreferences: c.CurvePreferences,
|
||||
|
||||
@@ -24,6 +24,21 @@ func (h *mockExtensionHandler) ReceivedExtensions(msgType uint8, exts []qtls.Ext
|
||||
}
|
||||
func (*mockExtensionHandler) TransportParameters() <-chan []byte { panic("not implemented") }
|
||||
|
||||
type mockClientSessionCache struct {
|
||||
get, put string
|
||||
session *tls.ClientSessionState
|
||||
}
|
||||
|
||||
var _ tls.ClientSessionCache = &mockClientSessionCache{}
|
||||
|
||||
func (c *mockClientSessionCache) Get(sessionKey string) (session *tls.ClientSessionState, ok bool) {
|
||||
c.get = sessionKey
|
||||
return c.session, false
|
||||
}
|
||||
func (c *mockClientSessionCache) Put(sessionKey string, cs *tls.ClientSessionState) {
|
||||
c.put = sessionKey
|
||||
}
|
||||
|
||||
var _ = Describe("qtls.Config generation", func() {
|
||||
It("sets MinVersion and MaxVersion", func() {
|
||||
tlsConf := &tls.Config{MinVersion: tls.VersionTLS11, MaxVersion: tls.VersionTLS12}
|
||||
@@ -96,4 +111,39 @@ var _ = Describe("qtls.Config generation", func() {
|
||||
Expect(qtlsConf.GetConfigForClient(nil)).To(BeNil())
|
||||
})
|
||||
})
|
||||
|
||||
Context("ClientSessionCache", func() {
|
||||
It("doesn't set if absent", func() {
|
||||
qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, nil, &mockExtensionHandler{})
|
||||
Expect(qtlsConf.ClientSessionCache).To(BeNil())
|
||||
})
|
||||
|
||||
It("sets it, and puts and gets session states", func() {
|
||||
state := &qtls.ClientSessionState{}
|
||||
csc := &mockClientSessionCache{session: &tls.ClientSessionState{}}
|
||||
tlsConf := &tls.Config{ClientSessionCache: csc}
|
||||
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{})
|
||||
Expect(qtlsConf.ClientSessionCache).ToNot(BeNil())
|
||||
Expect(csc.put).To(BeEmpty())
|
||||
qtlsConf.ClientSessionCache.Put("foobar", state)
|
||||
Expect(csc.put).To(Equal("foobar"))
|
||||
Expect(csc.get).To(BeEmpty())
|
||||
sess, _ := qtlsConf.ClientSessionCache.Get("raboof")
|
||||
Expect(csc.get).To(Equal("raboof"))
|
||||
Expect(sess).To(Equal(state))
|
||||
})
|
||||
|
||||
It("sets it, and gets nil session states for unknown keys", func() {
|
||||
csc := &mockClientSessionCache{}
|
||||
tlsConf := &tls.Config{
|
||||
ClientSessionCache: csc,
|
||||
}
|
||||
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{})
|
||||
Expect(qtlsConf.ClientSessionCache).ToNot(BeNil())
|
||||
Expect(csc.get).To(BeEmpty())
|
||||
sess, _ := qtlsConf.ClientSessionCache.Get("raboof")
|
||||
Expect(csc.get).To(Equal("raboof"))
|
||||
Expect(sess).To(BeNil())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package handshake
|
||||
|
||||
// This package uses unsafe to convert between qtls.ConnectionState and tls.ConnectionState.
|
||||
// This package uses unsafe to convert between:
|
||||
// * qtls.ConnectionState and tls.ConnectionState
|
||||
// * qtls.ClientSessionState and tls.ClientSessionState
|
||||
// We check in init() that this conversion actually is safe.
|
||||
|
||||
import (
|
||||
@@ -14,6 +16,9 @@ func init() {
|
||||
if !structsEqual(&tls.ConnectionState{}, &qtls.ConnectionState{}) {
|
||||
panic("qtls.ConnectionState not compatible with tls.ConnectionState")
|
||||
}
|
||||
if !structsEqual(&tls.ClientSessionState{}, &qtls.ClientSessionState{}) {
|
||||
panic("qtls.ClientSessionState not compatible with tls.ClientSessionState")
|
||||
}
|
||||
}
|
||||
|
||||
func structsEqual(a, b interface{}) bool {
|
||||
|
||||
Reference in New Issue
Block a user