forked from quic-go/quic-go
make the ClientSessionCache work with qtls
This commit is contained in:
@@ -2,10 +2,45 @@ package handshake
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
"github.com/marten-seemann/qtls"
|
"github.com/marten-seemann/qtls"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type clientSessionCache struct {
|
||||||
|
tls.ClientSessionCache
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ qtls.ClientSessionCache = &clientSessionCache{}
|
||||||
|
|
||||||
|
func (c *clientSessionCache) Get(sessionKey string) (*qtls.ClientSessionState, bool) {
|
||||||
|
sess, ok := c.ClientSessionCache.Get(sessionKey)
|
||||||
|
if sess == nil {
|
||||||
|
return nil, ok
|
||||||
|
}
|
||||||
|
// qtls.ClientSessionState is identical to the tls.ClientSessionState.
|
||||||
|
// In order to allow users of quic-go to use a tls.Config,
|
||||||
|
// we need this workaround to use the ClientSessionCache.
|
||||||
|
// In unsafe.go we check that the two structs are actually identical.
|
||||||
|
usess := (*[unsafe.Sizeof(*sess)]byte)(unsafe.Pointer(sess))[:]
|
||||||
|
var session qtls.ClientSessionState
|
||||||
|
usession := (*[unsafe.Sizeof(session)]byte)(unsafe.Pointer(&session))[:]
|
||||||
|
copy(usession, usess)
|
||||||
|
return &session, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *clientSessionCache) Put(sessionKey string, cs *qtls.ClientSessionState) {
|
||||||
|
// qtls.ClientSessionState is identical to the tls.ClientSessionState.
|
||||||
|
// In order to allow users of quic-go to use a tls.Config,
|
||||||
|
// we need this workaround to use the ClientSessionCache.
|
||||||
|
// In unsafe.go we check that the two structs are actually identical.
|
||||||
|
usess := (*[unsafe.Sizeof(*cs)]byte)(unsafe.Pointer(cs))[:]
|
||||||
|
var session tls.ClientSessionState
|
||||||
|
usession := (*[unsafe.Sizeof(session)]byte)(unsafe.Pointer(&session))[:]
|
||||||
|
copy(usession, usess)
|
||||||
|
c.ClientSessionCache.Put(sessionKey, &session)
|
||||||
|
}
|
||||||
|
|
||||||
func tlsConfigToQtlsConfig(
|
func tlsConfigToQtlsConfig(
|
||||||
c *tls.Config,
|
c *tls.Config,
|
||||||
recordLayer qtls.RecordLayer,
|
recordLayer qtls.RecordLayer,
|
||||||
@@ -36,6 +71,10 @@ func tlsConfigToQtlsConfig(
|
|||||||
return tlsConfigToQtlsConfig(tlsConf, recordLayer, extHandler), nil
|
return tlsConfigToQtlsConfig(tlsConf, recordLayer, extHandler), nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
var csc qtls.ClientSessionCache
|
||||||
|
if c.ClientSessionCache != nil {
|
||||||
|
csc = &clientSessionCache{c.ClientSessionCache}
|
||||||
|
}
|
||||||
return &qtls.Config{
|
return &qtls.Config{
|
||||||
Rand: c.Rand,
|
Rand: c.Rand,
|
||||||
Time: c.Time,
|
Time: c.Time,
|
||||||
@@ -55,6 +94,7 @@ func tlsConfigToQtlsConfig(
|
|||||||
PreferServerCipherSuites: c.PreferServerCipherSuites,
|
PreferServerCipherSuites: c.PreferServerCipherSuites,
|
||||||
SessionTicketsDisabled: c.SessionTicketsDisabled,
|
SessionTicketsDisabled: c.SessionTicketsDisabled,
|
||||||
SessionTicketKey: c.SessionTicketKey,
|
SessionTicketKey: c.SessionTicketKey,
|
||||||
|
ClientSessionCache: csc,
|
||||||
MinVersion: minVersion,
|
MinVersion: minVersion,
|
||||||
MaxVersion: maxVersion,
|
MaxVersion: maxVersion,
|
||||||
CurvePreferences: c.CurvePreferences,
|
CurvePreferences: c.CurvePreferences,
|
||||||
|
|||||||
@@ -24,6 +24,21 @@ func (h *mockExtensionHandler) ReceivedExtensions(msgType uint8, exts []qtls.Ext
|
|||||||
}
|
}
|
||||||
func (*mockExtensionHandler) TransportParameters() <-chan []byte { panic("not implemented") }
|
func (*mockExtensionHandler) TransportParameters() <-chan []byte { panic("not implemented") }
|
||||||
|
|
||||||
|
type mockClientSessionCache struct {
|
||||||
|
get, put string
|
||||||
|
session *tls.ClientSessionState
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ tls.ClientSessionCache = &mockClientSessionCache{}
|
||||||
|
|
||||||
|
func (c *mockClientSessionCache) Get(sessionKey string) (session *tls.ClientSessionState, ok bool) {
|
||||||
|
c.get = sessionKey
|
||||||
|
return c.session, false
|
||||||
|
}
|
||||||
|
func (c *mockClientSessionCache) Put(sessionKey string, cs *tls.ClientSessionState) {
|
||||||
|
c.put = sessionKey
|
||||||
|
}
|
||||||
|
|
||||||
var _ = Describe("qtls.Config generation", func() {
|
var _ = Describe("qtls.Config generation", func() {
|
||||||
It("sets MinVersion and MaxVersion", func() {
|
It("sets MinVersion and MaxVersion", func() {
|
||||||
tlsConf := &tls.Config{MinVersion: tls.VersionTLS11, MaxVersion: tls.VersionTLS12}
|
tlsConf := &tls.Config{MinVersion: tls.VersionTLS11, MaxVersion: tls.VersionTLS12}
|
||||||
@@ -96,4 +111,39 @@ var _ = Describe("qtls.Config generation", func() {
|
|||||||
Expect(qtlsConf.GetConfigForClient(nil)).To(BeNil())
|
Expect(qtlsConf.GetConfigForClient(nil)).To(BeNil())
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
Context("ClientSessionCache", func() {
|
||||||
|
It("doesn't set if absent", func() {
|
||||||
|
qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, nil, &mockExtensionHandler{})
|
||||||
|
Expect(qtlsConf.ClientSessionCache).To(BeNil())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("sets it, and puts and gets session states", func() {
|
||||||
|
state := &qtls.ClientSessionState{}
|
||||||
|
csc := &mockClientSessionCache{session: &tls.ClientSessionState{}}
|
||||||
|
tlsConf := &tls.Config{ClientSessionCache: csc}
|
||||||
|
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{})
|
||||||
|
Expect(qtlsConf.ClientSessionCache).ToNot(BeNil())
|
||||||
|
Expect(csc.put).To(BeEmpty())
|
||||||
|
qtlsConf.ClientSessionCache.Put("foobar", state)
|
||||||
|
Expect(csc.put).To(Equal("foobar"))
|
||||||
|
Expect(csc.get).To(BeEmpty())
|
||||||
|
sess, _ := qtlsConf.ClientSessionCache.Get("raboof")
|
||||||
|
Expect(csc.get).To(Equal("raboof"))
|
||||||
|
Expect(sess).To(Equal(state))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("sets it, and gets nil session states for unknown keys", func() {
|
||||||
|
csc := &mockClientSessionCache{}
|
||||||
|
tlsConf := &tls.Config{
|
||||||
|
ClientSessionCache: csc,
|
||||||
|
}
|
||||||
|
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{})
|
||||||
|
Expect(qtlsConf.ClientSessionCache).ToNot(BeNil())
|
||||||
|
Expect(csc.get).To(BeEmpty())
|
||||||
|
sess, _ := qtlsConf.ClientSessionCache.Get("raboof")
|
||||||
|
Expect(csc.get).To(Equal("raboof"))
|
||||||
|
Expect(sess).To(BeNil())
|
||||||
|
})
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
package handshake
|
package handshake
|
||||||
|
|
||||||
// This package uses unsafe to convert between qtls.ConnectionState and tls.ConnectionState.
|
// This package uses unsafe to convert between:
|
||||||
|
// * qtls.ConnectionState and tls.ConnectionState
|
||||||
|
// * qtls.ClientSessionState and tls.ClientSessionState
|
||||||
// We check in init() that this conversion actually is safe.
|
// We check in init() that this conversion actually is safe.
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@@ -14,6 +16,9 @@ func init() {
|
|||||||
if !structsEqual(&tls.ConnectionState{}, &qtls.ConnectionState{}) {
|
if !structsEqual(&tls.ConnectionState{}, &qtls.ConnectionState{}) {
|
||||||
panic("qtls.ConnectionState not compatible with tls.ConnectionState")
|
panic("qtls.ConnectionState not compatible with tls.ConnectionState")
|
||||||
}
|
}
|
||||||
|
if !structsEqual(&tls.ClientSessionState{}, &qtls.ClientSessionState{}) {
|
||||||
|
panic("qtls.ClientSessionState not compatible with tls.ClientSessionState")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func structsEqual(a, b interface{}) bool {
|
func structsEqual(a, b interface{}) bool {
|
||||||
|
|||||||
Reference in New Issue
Block a user