forked from quic-go/quic-go
move the STK generation from the ServerConfig to a separate struct
This commit is contained in:
@@ -58,7 +58,11 @@ const stkKeySize = 16
|
||||
const stkNonceSize = 16
|
||||
|
||||
// NewStkSource creates a source for source address tokens
|
||||
func NewStkSource(secret []byte) (StkSource, error) {
|
||||
func NewStkSource() (StkSource, error) {
|
||||
secret := make([]byte, 32)
|
||||
if _, err := rand.Read(secret); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
key, err := deriveKey(secret)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -43,7 +43,6 @@ var _ = Describe("Source Address Tokens", func() {
|
||||
Context("source", func() {
|
||||
var (
|
||||
source *stkSource
|
||||
secret []byte
|
||||
ip4 net.IP
|
||||
ip6 net.IP
|
||||
)
|
||||
@@ -56,8 +55,7 @@ var _ = Describe("Source Address Tokens", func() {
|
||||
ip6 = net.ParseIP("2001:0db8:0000:0000:0000:ff00:0042:8329")
|
||||
Expect(ip6).NotTo(BeEmpty())
|
||||
|
||||
secret = []byte("TESTING")
|
||||
sourceI, err := NewStkSource(secret)
|
||||
sourceI, err := NewStkSource()
|
||||
source = sourceI.(*stkSource)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
})
|
||||
|
||||
@@ -27,6 +27,7 @@ type cryptoSetupServer struct {
|
||||
connID protocol.ConnectionID
|
||||
sourceAddr []byte
|
||||
scfg *ServerConfig
|
||||
stkGenerator *STKGenerator
|
||||
diversificationNonce []byte
|
||||
|
||||
version protocol.VersionNumber
|
||||
@@ -68,6 +69,11 @@ func NewCryptoSetup(
|
||||
supportedVersions []protocol.VersionNumber,
|
||||
aeadChanged chan<- protocol.EncryptionLevel,
|
||||
) (CryptoSetup, error) {
|
||||
stkGenerator, err := NewSTKGenerator()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var sourceAddr []byte
|
||||
if udpAddr, ok := remoteAddr.(*net.UDPAddr); ok {
|
||||
sourceAddr = udpAddr.IP
|
||||
@@ -81,6 +87,7 @@ func NewCryptoSetup(
|
||||
version: version,
|
||||
supportedVersions: supportedVersions,
|
||||
scfg: scfg,
|
||||
stkGenerator: stkGenerator,
|
||||
keyDerivation: crypto.DeriveKeysAESGCM,
|
||||
keyExchange: getEphermalKEX,
|
||||
nullAEAD: crypto.NewNullAEAD(protocol.PerspectiveServer, version),
|
||||
@@ -276,7 +283,7 @@ func (h *cryptoSetupServer) isInchoateCHLO(cryptoData map[Tag][]byte, cert []byt
|
||||
}
|
||||
|
||||
func (h *cryptoSetupServer) verifySTK(stk []byte) bool {
|
||||
stkTime, err := h.scfg.stkSource.VerifyToken(h.sourceAddr, stk)
|
||||
stkTime, err := h.stkGenerator.VerifyToken(h.sourceAddr, stk)
|
||||
if err != nil {
|
||||
utils.Debugf("STK invalid: %s", err.Error())
|
||||
return false
|
||||
@@ -292,7 +299,7 @@ func (h *cryptoSetupServer) handleInchoateCHLO(sni string, chlo []byte, cryptoDa
|
||||
return nil, qerr.Error(qerr.CryptoInvalidValueLength, "CHLO too small")
|
||||
}
|
||||
|
||||
token, err := h.scfg.stkSource.NewToken(h.sourceAddr)
|
||||
token, err := h.stkGenerator.NewToken(h.sourceAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -194,7 +194,6 @@ var _ = Describe("Server Crypto Setup", func() {
|
||||
versionTag = make([]byte, 4)
|
||||
binary.LittleEndian.PutUint32(versionTag, protocol.VersionNumberToTag(protocol.VersionWhatever))
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
scfg.stkSource = &mockStkSource{}
|
||||
version = protocol.SupportedVersions[len(protocol.SupportedVersions)-1]
|
||||
supportedVersions = []protocol.VersionNumber{version, 98, 99}
|
||||
cpm = NewConnectionParamatersManager(protocol.PerspectiveServer, protocol.VersionWhatever)
|
||||
@@ -210,6 +209,7 @@ var _ = Describe("Server Crypto Setup", func() {
|
||||
)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
cs = csInt.(*cryptoSetupServer)
|
||||
cs.stkGenerator.stkSource = &mockStkSource{}
|
||||
cs.keyDerivation = mockKeyDerivation
|
||||
cs.keyExchange = func() crypto.KeyExchange { return &mockKEX{ephermal: true} }
|
||||
})
|
||||
@@ -436,12 +436,12 @@ var _ = Describe("Server Crypto Setup", func() {
|
||||
|
||||
It("recognizes inchoate CHLOs with an invalid STK", func() {
|
||||
testErr := errors.New("STK invalid")
|
||||
scfg.stkSource.(*mockStkSource).verifyErr = testErr
|
||||
cs.stkGenerator.stkSource.(*mockStkSource).verifyErr = testErr
|
||||
Expect(cs.isInchoateCHLO(fullCHLO, cert)).To(BeTrue())
|
||||
})
|
||||
|
||||
It("REJ messages that have an expired STK", func() {
|
||||
cs.scfg.stkSource.(*mockStkSource).stkTime = time.Now().Add(-protocol.STKExpiryTime).Add(-time.Second)
|
||||
cs.stkGenerator.stkSource.(*mockStkSource).stkTime = time.Now().Add(-protocol.STKExpiryTime).Add(-time.Second)
|
||||
Expect(cs.isInchoateCHLO(fullCHLO, cert)).To(BeTrue())
|
||||
})
|
||||
|
||||
|
||||
@@ -13,7 +13,6 @@ type ServerConfig struct {
|
||||
certChain crypto.CertChain
|
||||
ID []byte
|
||||
obit []byte
|
||||
stkSource crypto.StkSource
|
||||
}
|
||||
|
||||
// NewServerConfig creates a new server config
|
||||
@@ -24,27 +23,16 @@ func NewServerConfig(kex crypto.KeyExchange, certChain crypto.CertChain) (*Serve
|
||||
return nil, err
|
||||
}
|
||||
|
||||
stkSecret := make([]byte, 32)
|
||||
if _, err = rand.Read(stkSecret); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
obit := make([]byte, 8)
|
||||
if _, err = rand.Read(obit); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
stkSource, err := crypto.NewStkSource(stkSecret)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &ServerConfig{
|
||||
kex: kex,
|
||||
certChain: certChain,
|
||||
ID: id,
|
||||
obit: obit,
|
||||
stkSource: stkSource,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
33
handshake/stk_generator.go
Normal file
33
handshake/stk_generator.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/crypto"
|
||||
)
|
||||
|
||||
// An STKGenerator generates STKs
|
||||
type STKGenerator struct {
|
||||
stkSource crypto.StkSource
|
||||
}
|
||||
|
||||
// NewSTKGenerator initializes a new STKGenerator
|
||||
func NewSTKGenerator() (*STKGenerator, error) {
|
||||
stkSource, err := crypto.NewStkSource()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &STKGenerator{
|
||||
stkSource: stkSource,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// NewToken generates a new STK token for a given source address
|
||||
func (g *STKGenerator) NewToken(sourceAddr []byte) ([]byte, error) {
|
||||
return g.stkSource.NewToken(sourceAddr)
|
||||
}
|
||||
|
||||
// VerifyToken verifies an STK token
|
||||
func (g *STKGenerator) VerifyToken(sourceAddr []byte, token []byte) (time.Time, error) {
|
||||
return g.stkSource.VerifyToken(sourceAddr, token)
|
||||
}
|
||||
35
handshake/stk_generator_test.go
Normal file
35
handshake/stk_generator_test.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"net"
|
||||
"time"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("STK Generator", func() {
|
||||
var stkGen *STKGenerator
|
||||
|
||||
BeforeEach(func() {
|
||||
var err error
|
||||
stkGen, err = NewSTKGenerator()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("generates an STK", func() {
|
||||
ip := net.IPv4(127, 0, 0, 1)
|
||||
token, err := stkGen.NewToken(ip)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(token).ToNot(BeEmpty())
|
||||
})
|
||||
|
||||
It("verifies an STK", func() {
|
||||
ip := net.IPv4(192, 168, 0, 1)
|
||||
token, err := stkGen.NewToken(ip)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
t, err := stkGen.VerifyToken(ip, token)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(t).To(BeTemporally("~", time.Now(), time.Second))
|
||||
})
|
||||
})
|
||||
Reference in New Issue
Block a user