diff --git a/internal/handshake/qtls.go b/internal/handshake/qtls.go index 8cd2fe89..a5769649 100644 --- a/internal/handshake/qtls.go +++ b/internal/handshake/qtls.go @@ -2,10 +2,45 @@ package handshake import ( "crypto/tls" + "unsafe" "github.com/marten-seemann/qtls" ) +type clientSessionCache struct { + tls.ClientSessionCache +} + +var _ qtls.ClientSessionCache = &clientSessionCache{} + +func (c *clientSessionCache) Get(sessionKey string) (*qtls.ClientSessionState, bool) { + sess, ok := c.ClientSessionCache.Get(sessionKey) + if sess == nil { + return nil, ok + } + // qtls.ClientSessionState is identical to the tls.ClientSessionState. + // In order to allow users of quic-go to use a tls.Config, + // we need this workaround to use the ClientSessionCache. + // In unsafe.go we check that the two structs are actually identical. + usess := (*[unsafe.Sizeof(*sess)]byte)(unsafe.Pointer(sess))[:] + var session qtls.ClientSessionState + usession := (*[unsafe.Sizeof(session)]byte)(unsafe.Pointer(&session))[:] + copy(usession, usess) + return &session, ok +} + +func (c *clientSessionCache) Put(sessionKey string, cs *qtls.ClientSessionState) { + // qtls.ClientSessionState is identical to the tls.ClientSessionState. + // In order to allow users of quic-go to use a tls.Config, + // we need this workaround to use the ClientSessionCache. + // In unsafe.go we check that the two structs are actually identical. + usess := (*[unsafe.Sizeof(*cs)]byte)(unsafe.Pointer(cs))[:] + var session tls.ClientSessionState + usession := (*[unsafe.Sizeof(session)]byte)(unsafe.Pointer(&session))[:] + copy(usession, usess) + c.ClientSessionCache.Put(sessionKey, &session) +} + func tlsConfigToQtlsConfig( c *tls.Config, recordLayer qtls.RecordLayer, @@ -36,6 +71,10 @@ func tlsConfigToQtlsConfig( return tlsConfigToQtlsConfig(tlsConf, recordLayer, extHandler), nil } } + var csc qtls.ClientSessionCache + if c.ClientSessionCache != nil { + csc = &clientSessionCache{c.ClientSessionCache} + } return &qtls.Config{ Rand: c.Rand, Time: c.Time, @@ -55,6 +94,7 @@ func tlsConfigToQtlsConfig( PreferServerCipherSuites: c.PreferServerCipherSuites, SessionTicketsDisabled: c.SessionTicketsDisabled, SessionTicketKey: c.SessionTicketKey, + ClientSessionCache: csc, MinVersion: minVersion, MaxVersion: maxVersion, CurvePreferences: c.CurvePreferences, diff --git a/internal/handshake/qtls_test.go b/internal/handshake/qtls_test.go index 4c036812..3e95a7fe 100644 --- a/internal/handshake/qtls_test.go +++ b/internal/handshake/qtls_test.go @@ -24,6 +24,21 @@ func (h *mockExtensionHandler) ReceivedExtensions(msgType uint8, exts []qtls.Ext } func (*mockExtensionHandler) TransportParameters() <-chan []byte { panic("not implemented") } +type mockClientSessionCache struct { + get, put string + session *tls.ClientSessionState +} + +var _ tls.ClientSessionCache = &mockClientSessionCache{} + +func (c *mockClientSessionCache) Get(sessionKey string) (session *tls.ClientSessionState, ok bool) { + c.get = sessionKey + return c.session, false +} +func (c *mockClientSessionCache) Put(sessionKey string, cs *tls.ClientSessionState) { + c.put = sessionKey +} + var _ = Describe("qtls.Config generation", func() { It("sets MinVersion and MaxVersion", func() { tlsConf := &tls.Config{MinVersion: tls.VersionTLS11, MaxVersion: tls.VersionTLS12} @@ -96,4 +111,39 @@ var _ = Describe("qtls.Config generation", func() { Expect(qtlsConf.GetConfigForClient(nil)).To(BeNil()) }) }) + + Context("ClientSessionCache", func() { + It("doesn't set if absent", func() { + qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, nil, &mockExtensionHandler{}) + Expect(qtlsConf.ClientSessionCache).To(BeNil()) + }) + + It("sets it, and puts and gets session states", func() { + state := &qtls.ClientSessionState{} + csc := &mockClientSessionCache{session: &tls.ClientSessionState{}} + tlsConf := &tls.Config{ClientSessionCache: csc} + qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}) + Expect(qtlsConf.ClientSessionCache).ToNot(BeNil()) + Expect(csc.put).To(BeEmpty()) + qtlsConf.ClientSessionCache.Put("foobar", state) + Expect(csc.put).To(Equal("foobar")) + Expect(csc.get).To(BeEmpty()) + sess, _ := qtlsConf.ClientSessionCache.Get("raboof") + Expect(csc.get).To(Equal("raboof")) + Expect(sess).To(Equal(state)) + }) + + It("sets it, and gets nil session states for unknown keys", func() { + csc := &mockClientSessionCache{} + tlsConf := &tls.Config{ + ClientSessionCache: csc, + } + qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}) + Expect(qtlsConf.ClientSessionCache).ToNot(BeNil()) + Expect(csc.get).To(BeEmpty()) + sess, _ := qtlsConf.ClientSessionCache.Get("raboof") + Expect(csc.get).To(Equal("raboof")) + Expect(sess).To(BeNil()) + }) + }) }) diff --git a/internal/handshake/unsafe.go b/internal/handshake/unsafe.go index df9c0db8..fb051aeb 100644 --- a/internal/handshake/unsafe.go +++ b/internal/handshake/unsafe.go @@ -1,6 +1,8 @@ package handshake -// This package uses unsafe to convert between qtls.ConnectionState and tls.ConnectionState. +// This package uses unsafe to convert between: +// * qtls.ConnectionState and tls.ConnectionState +// * qtls.ClientSessionState and tls.ClientSessionState // We check in init() that this conversion actually is safe. import ( @@ -14,6 +16,9 @@ func init() { if !structsEqual(&tls.ConnectionState{}, &qtls.ConnectionState{}) { panic("qtls.ConnectionState not compatible with tls.ConnectionState") } + if !structsEqual(&tls.ClientSessionState{}, &qtls.ClientSessionState{}) { + panic("qtls.ClientSessionState not compatible with tls.ClientSessionState") + } } func structsEqual(a, b interface{}) bool {