Files
quic-go/internal/handshake/client_session_cache_test.go
2020-02-26 16:13:57 +07:00

128 lines
3.3 KiB
Go

package handshake
import (
"bytes"
"crypto/tls"
"time"
"unsafe"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/marten-seemann/qtls"
"github.com/lucas-clemente/quic-go/internal/congestion"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("ClientSessionCache", func() {
encodeIntoSessionTicket := func(data []byte) *tls.ClientSessionState {
session := &clientSessionState{nonce: data}
return (*tls.ClientSessionState)(unsafe.Pointer(session))
}
It("puts and gets", func() {
get := make(chan []byte, 100)
set := make(chan []byte, 100)
csc := newClientSessionCache(
tls.NewLRUClientSessionCache(100),
congestion.NewRTTStats(),
func() []byte { return <-get },
func(b []byte) { set <- b },
)
get <- []byte("foobar")
csc.Put("localhost", &qtls.ClientSessionState{})
Expect(set).To(BeEmpty())
state, ok := csc.Get("localhost")
Expect(ok).To(BeTrue())
Expect(state).ToNot(BeNil())
Expect(set).To(Receive(Equal([]byte("foobar"))))
})
It("saves the RTT", func() {
rttStatsOrig := congestion.NewRTTStats()
rttStatsOrig.UpdateRTT(10*time.Second, 0, time.Now())
Expect(rttStatsOrig.SmoothedRTT()).To(Equal(10 * time.Second))
cache := tls.NewLRUClientSessionCache(100)
csc1 := newClientSessionCache(
cache,
rttStatsOrig,
func() []byte { return nil },
func([]byte) {},
)
csc1.Put("localhost", &qtls.ClientSessionState{})
rttStats := congestion.NewRTTStats()
csc2 := newClientSessionCache(
cache,
rttStats,
func() []byte { return nil },
func([]byte) {},
)
Expect(rttStats.SmoothedRTT()).ToNot(Equal(10 * time.Second))
_, ok := csc2.Get("localhost")
Expect(ok).To(BeTrue())
Expect(rttStats.SmoothedRTT()).To(Equal(10 * time.Second))
})
It("refuses a session state that is too short for the revision", func() {
cache := tls.NewLRUClientSessionCache(1)
cache.Put("localhost", encodeIntoSessionTicket([]byte{}))
csc := newClientSessionCache(
cache,
congestion.NewRTTStats(),
func() []byte { return nil },
func([]byte) {},
)
_, ok := csc.Get("localhost")
Expect(ok).To(BeFalse())
})
It("refuses a session state with the wrong revision", func() {
cache := tls.NewLRUClientSessionCache(1)
b := &bytes.Buffer{}
utils.WriteVarInt(b, clientSessionStateRevision+1)
cache.Put("localhost", encodeIntoSessionTicket(b.Bytes()))
csc := newClientSessionCache(
cache,
congestion.NewRTTStats(),
func() []byte { return nil },
func([]byte) {},
)
_, ok := csc.Get("localhost")
Expect(ok).To(BeFalse())
})
It("refuses a session state when unmarshalling fails", func() {
rttStats := congestion.NewRTTStats()
rttStats.SetInitialRTT(10 * time.Second)
cache := tls.NewLRUClientSessionCache(1)
csc := newClientSessionCache(
cache,
rttStats,
func() []byte { return []byte("foobar") },
func(b []byte) {},
)
csc.Put("localhost", &qtls.ClientSessionState{})
state, ok := cache.Get("localhost")
Expect(ok).To(BeTrue())
session := (*clientSessionState)(unsafe.Pointer(state))
Expect(session.nonce).ToNot(BeEmpty())
_, ok = csc.Get("localhost")
Expect(ok).To(BeTrue())
nonce := session.nonce
for i := 0; i < len(nonce); i++ {
session.nonce = session.nonce[:i]
_, ok = csc.Get("localhost")
Expect(ok).To(BeFalse())
}
})
})