forked from quic-go/quic-go
128 lines
3.3 KiB
Go
128 lines
3.3 KiB
Go
package handshake
|
|
|
|
import (
|
|
"bytes"
|
|
"crypto/tls"
|
|
"time"
|
|
"unsafe"
|
|
|
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
|
|
|
"github.com/marten-seemann/qtls"
|
|
|
|
"github.com/lucas-clemente/quic-go/internal/congestion"
|
|
|
|
. "github.com/onsi/ginkgo"
|
|
. "github.com/onsi/gomega"
|
|
)
|
|
|
|
var _ = Describe("ClientSessionCache", func() {
|
|
encodeIntoSessionTicket := func(data []byte) *tls.ClientSessionState {
|
|
session := &clientSessionState{nonce: data}
|
|
return (*tls.ClientSessionState)(unsafe.Pointer(session))
|
|
}
|
|
|
|
It("puts and gets", func() {
|
|
get := make(chan []byte, 100)
|
|
set := make(chan []byte, 100)
|
|
|
|
csc := newClientSessionCache(
|
|
tls.NewLRUClientSessionCache(100),
|
|
congestion.NewRTTStats(),
|
|
func() []byte { return <-get },
|
|
func(b []byte) { set <- b },
|
|
)
|
|
|
|
get <- []byte("foobar")
|
|
csc.Put("localhost", &qtls.ClientSessionState{})
|
|
Expect(set).To(BeEmpty())
|
|
state, ok := csc.Get("localhost")
|
|
Expect(ok).To(BeTrue())
|
|
Expect(state).ToNot(BeNil())
|
|
Expect(set).To(Receive(Equal([]byte("foobar"))))
|
|
})
|
|
|
|
It("saves the RTT", func() {
|
|
rttStatsOrig := congestion.NewRTTStats()
|
|
rttStatsOrig.UpdateRTT(10*time.Second, 0, time.Now())
|
|
Expect(rttStatsOrig.SmoothedRTT()).To(Equal(10 * time.Second))
|
|
cache := tls.NewLRUClientSessionCache(100)
|
|
csc1 := newClientSessionCache(
|
|
cache,
|
|
rttStatsOrig,
|
|
func() []byte { return nil },
|
|
func([]byte) {},
|
|
)
|
|
csc1.Put("localhost", &qtls.ClientSessionState{})
|
|
|
|
rttStats := congestion.NewRTTStats()
|
|
csc2 := newClientSessionCache(
|
|
cache,
|
|
rttStats,
|
|
func() []byte { return nil },
|
|
func([]byte) {},
|
|
)
|
|
Expect(rttStats.SmoothedRTT()).ToNot(Equal(10 * time.Second))
|
|
_, ok := csc2.Get("localhost")
|
|
Expect(ok).To(BeTrue())
|
|
Expect(rttStats.SmoothedRTT()).To(Equal(10 * time.Second))
|
|
})
|
|
|
|
It("refuses a session state that is too short for the revision", func() {
|
|
cache := tls.NewLRUClientSessionCache(1)
|
|
cache.Put("localhost", encodeIntoSessionTicket([]byte{}))
|
|
|
|
csc := newClientSessionCache(
|
|
cache,
|
|
congestion.NewRTTStats(),
|
|
func() []byte { return nil },
|
|
func([]byte) {},
|
|
)
|
|
_, ok := csc.Get("localhost")
|
|
Expect(ok).To(BeFalse())
|
|
})
|
|
|
|
It("refuses a session state with the wrong revision", func() {
|
|
cache := tls.NewLRUClientSessionCache(1)
|
|
b := &bytes.Buffer{}
|
|
utils.WriteVarInt(b, clientSessionStateRevision+1)
|
|
cache.Put("localhost", encodeIntoSessionTicket(b.Bytes()))
|
|
|
|
csc := newClientSessionCache(
|
|
cache,
|
|
congestion.NewRTTStats(),
|
|
func() []byte { return nil },
|
|
func([]byte) {},
|
|
)
|
|
_, ok := csc.Get("localhost")
|
|
Expect(ok).To(BeFalse())
|
|
})
|
|
|
|
It("refuses a session state when unmarshalling fails", func() {
|
|
rttStats := congestion.NewRTTStats()
|
|
rttStats.SetInitialRTT(10 * time.Second)
|
|
cache := tls.NewLRUClientSessionCache(1)
|
|
csc := newClientSessionCache(
|
|
cache,
|
|
rttStats,
|
|
func() []byte { return []byte("foobar") },
|
|
func(b []byte) {},
|
|
)
|
|
csc.Put("localhost", &qtls.ClientSessionState{})
|
|
state, ok := cache.Get("localhost")
|
|
Expect(ok).To(BeTrue())
|
|
session := (*clientSessionState)(unsafe.Pointer(state))
|
|
Expect(session.nonce).ToNot(BeEmpty())
|
|
|
|
_, ok = csc.Get("localhost")
|
|
Expect(ok).To(BeTrue())
|
|
nonce := session.nonce
|
|
|
|
for i := 0; i < len(nonce); i++ {
|
|
session.nonce = session.nonce[:i]
|
|
_, ok = csc.Get("localhost")
|
|
Expect(ok).To(BeFalse())
|
|
}
|
|
})
|
|
})
|