forked from quic-go/quic-go
move the STK generation from the ServerConfig to a separate struct
This commit is contained in:
@@ -58,7 +58,11 @@ const stkKeySize = 16
|
|||||||
const stkNonceSize = 16
|
const stkNonceSize = 16
|
||||||
|
|
||||||
// NewStkSource creates a source for source address tokens
|
// NewStkSource creates a source for source address tokens
|
||||||
func NewStkSource(secret []byte) (StkSource, error) {
|
func NewStkSource() (StkSource, error) {
|
||||||
|
secret := make([]byte, 32)
|
||||||
|
if _, err := rand.Read(secret); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
key, err := deriveKey(secret)
|
key, err := deriveKey(secret)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@@ -43,7 +43,6 @@ var _ = Describe("Source Address Tokens", func() {
|
|||||||
Context("source", func() {
|
Context("source", func() {
|
||||||
var (
|
var (
|
||||||
source *stkSource
|
source *stkSource
|
||||||
secret []byte
|
|
||||||
ip4 net.IP
|
ip4 net.IP
|
||||||
ip6 net.IP
|
ip6 net.IP
|
||||||
)
|
)
|
||||||
@@ -56,8 +55,7 @@ var _ = Describe("Source Address Tokens", func() {
|
|||||||
ip6 = net.ParseIP("2001:0db8:0000:0000:0000:ff00:0042:8329")
|
ip6 = net.ParseIP("2001:0db8:0000:0000:0000:ff00:0042:8329")
|
||||||
Expect(ip6).NotTo(BeEmpty())
|
Expect(ip6).NotTo(BeEmpty())
|
||||||
|
|
||||||
secret = []byte("TESTING")
|
sourceI, err := NewStkSource()
|
||||||
sourceI, err := NewStkSource(secret)
|
|
||||||
source = sourceI.(*stkSource)
|
source = sourceI.(*stkSource)
|
||||||
Expect(err).NotTo(HaveOccurred())
|
Expect(err).NotTo(HaveOccurred())
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ type cryptoSetupServer struct {
|
|||||||
connID protocol.ConnectionID
|
connID protocol.ConnectionID
|
||||||
sourceAddr []byte
|
sourceAddr []byte
|
||||||
scfg *ServerConfig
|
scfg *ServerConfig
|
||||||
|
stkGenerator *STKGenerator
|
||||||
diversificationNonce []byte
|
diversificationNonce []byte
|
||||||
|
|
||||||
version protocol.VersionNumber
|
version protocol.VersionNumber
|
||||||
@@ -68,6 +69,11 @@ func NewCryptoSetup(
|
|||||||
supportedVersions []protocol.VersionNumber,
|
supportedVersions []protocol.VersionNumber,
|
||||||
aeadChanged chan<- protocol.EncryptionLevel,
|
aeadChanged chan<- protocol.EncryptionLevel,
|
||||||
) (CryptoSetup, error) {
|
) (CryptoSetup, error) {
|
||||||
|
stkGenerator, err := NewSTKGenerator()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
var sourceAddr []byte
|
var sourceAddr []byte
|
||||||
if udpAddr, ok := remoteAddr.(*net.UDPAddr); ok {
|
if udpAddr, ok := remoteAddr.(*net.UDPAddr); ok {
|
||||||
sourceAddr = udpAddr.IP
|
sourceAddr = udpAddr.IP
|
||||||
@@ -81,6 +87,7 @@ func NewCryptoSetup(
|
|||||||
version: version,
|
version: version,
|
||||||
supportedVersions: supportedVersions,
|
supportedVersions: supportedVersions,
|
||||||
scfg: scfg,
|
scfg: scfg,
|
||||||
|
stkGenerator: stkGenerator,
|
||||||
keyDerivation: crypto.DeriveKeysAESGCM,
|
keyDerivation: crypto.DeriveKeysAESGCM,
|
||||||
keyExchange: getEphermalKEX,
|
keyExchange: getEphermalKEX,
|
||||||
nullAEAD: crypto.NewNullAEAD(protocol.PerspectiveServer, version),
|
nullAEAD: crypto.NewNullAEAD(protocol.PerspectiveServer, version),
|
||||||
@@ -276,7 +283,7 @@ func (h *cryptoSetupServer) isInchoateCHLO(cryptoData map[Tag][]byte, cert []byt
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *cryptoSetupServer) verifySTK(stk []byte) bool {
|
func (h *cryptoSetupServer) verifySTK(stk []byte) bool {
|
||||||
stkTime, err := h.scfg.stkSource.VerifyToken(h.sourceAddr, stk)
|
stkTime, err := h.stkGenerator.VerifyToken(h.sourceAddr, stk)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
utils.Debugf("STK invalid: %s", err.Error())
|
utils.Debugf("STK invalid: %s", err.Error())
|
||||||
return false
|
return false
|
||||||
@@ -292,7 +299,7 @@ func (h *cryptoSetupServer) handleInchoateCHLO(sni string, chlo []byte, cryptoDa
|
|||||||
return nil, qerr.Error(qerr.CryptoInvalidValueLength, "CHLO too small")
|
return nil, qerr.Error(qerr.CryptoInvalidValueLength, "CHLO too small")
|
||||||
}
|
}
|
||||||
|
|
||||||
token, err := h.scfg.stkSource.NewToken(h.sourceAddr)
|
token, err := h.stkGenerator.NewToken(h.sourceAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -194,7 +194,6 @@ var _ = Describe("Server Crypto Setup", func() {
|
|||||||
versionTag = make([]byte, 4)
|
versionTag = make([]byte, 4)
|
||||||
binary.LittleEndian.PutUint32(versionTag, protocol.VersionNumberToTag(protocol.VersionWhatever))
|
binary.LittleEndian.PutUint32(versionTag, protocol.VersionNumberToTag(protocol.VersionWhatever))
|
||||||
Expect(err).NotTo(HaveOccurred())
|
Expect(err).NotTo(HaveOccurred())
|
||||||
scfg.stkSource = &mockStkSource{}
|
|
||||||
version = protocol.SupportedVersions[len(protocol.SupportedVersions)-1]
|
version = protocol.SupportedVersions[len(protocol.SupportedVersions)-1]
|
||||||
supportedVersions = []protocol.VersionNumber{version, 98, 99}
|
supportedVersions = []protocol.VersionNumber{version, 98, 99}
|
||||||
cpm = NewConnectionParamatersManager(protocol.PerspectiveServer, protocol.VersionWhatever)
|
cpm = NewConnectionParamatersManager(protocol.PerspectiveServer, protocol.VersionWhatever)
|
||||||
@@ -210,6 +209,7 @@ var _ = Describe("Server Crypto Setup", func() {
|
|||||||
)
|
)
|
||||||
Expect(err).NotTo(HaveOccurred())
|
Expect(err).NotTo(HaveOccurred())
|
||||||
cs = csInt.(*cryptoSetupServer)
|
cs = csInt.(*cryptoSetupServer)
|
||||||
|
cs.stkGenerator.stkSource = &mockStkSource{}
|
||||||
cs.keyDerivation = mockKeyDerivation
|
cs.keyDerivation = mockKeyDerivation
|
||||||
cs.keyExchange = func() crypto.KeyExchange { return &mockKEX{ephermal: true} }
|
cs.keyExchange = func() crypto.KeyExchange { return &mockKEX{ephermal: true} }
|
||||||
})
|
})
|
||||||
@@ -436,12 +436,12 @@ var _ = Describe("Server Crypto Setup", func() {
|
|||||||
|
|
||||||
It("recognizes inchoate CHLOs with an invalid STK", func() {
|
It("recognizes inchoate CHLOs with an invalid STK", func() {
|
||||||
testErr := errors.New("STK invalid")
|
testErr := errors.New("STK invalid")
|
||||||
scfg.stkSource.(*mockStkSource).verifyErr = testErr
|
cs.stkGenerator.stkSource.(*mockStkSource).verifyErr = testErr
|
||||||
Expect(cs.isInchoateCHLO(fullCHLO, cert)).To(BeTrue())
|
Expect(cs.isInchoateCHLO(fullCHLO, cert)).To(BeTrue())
|
||||||
})
|
})
|
||||||
|
|
||||||
It("REJ messages that have an expired STK", func() {
|
It("REJ messages that have an expired STK", func() {
|
||||||
cs.scfg.stkSource.(*mockStkSource).stkTime = time.Now().Add(-protocol.STKExpiryTime).Add(-time.Second)
|
cs.stkGenerator.stkSource.(*mockStkSource).stkTime = time.Now().Add(-protocol.STKExpiryTime).Add(-time.Second)
|
||||||
Expect(cs.isInchoateCHLO(fullCHLO, cert)).To(BeTrue())
|
Expect(cs.isInchoateCHLO(fullCHLO, cert)).To(BeTrue())
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ type ServerConfig struct {
|
|||||||
certChain crypto.CertChain
|
certChain crypto.CertChain
|
||||||
ID []byte
|
ID []byte
|
||||||
obit []byte
|
obit []byte
|
||||||
stkSource crypto.StkSource
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewServerConfig creates a new server config
|
// NewServerConfig creates a new server config
|
||||||
@@ -24,27 +23,16 @@ func NewServerConfig(kex crypto.KeyExchange, certChain crypto.CertChain) (*Serve
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
stkSecret := make([]byte, 32)
|
|
||||||
if _, err = rand.Read(stkSecret); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
obit := make([]byte, 8)
|
obit := make([]byte, 8)
|
||||||
if _, err = rand.Read(obit); err != nil {
|
if _, err = rand.Read(obit); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
stkSource, err := crypto.NewStkSource(stkSecret)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &ServerConfig{
|
return &ServerConfig{
|
||||||
kex: kex,
|
kex: kex,
|
||||||
certChain: certChain,
|
certChain: certChain,
|
||||||
ID: id,
|
ID: id,
|
||||||
obit: obit,
|
obit: obit,
|
||||||
stkSource: stkSource,
|
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
33
handshake/stk_generator.go
Normal file
33
handshake/stk_generator.go
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
package handshake
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/lucas-clemente/quic-go/crypto"
|
||||||
|
)
|
||||||
|
|
||||||
|
// An STKGenerator generates STKs
|
||||||
|
type STKGenerator struct {
|
||||||
|
stkSource crypto.StkSource
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSTKGenerator initializes a new STKGenerator
|
||||||
|
func NewSTKGenerator() (*STKGenerator, error) {
|
||||||
|
stkSource, err := crypto.NewStkSource()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &STKGenerator{
|
||||||
|
stkSource: stkSource,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewToken generates a new STK token for a given source address
|
||||||
|
func (g *STKGenerator) NewToken(sourceAddr []byte) ([]byte, error) {
|
||||||
|
return g.stkSource.NewToken(sourceAddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// VerifyToken verifies an STK token
|
||||||
|
func (g *STKGenerator) VerifyToken(sourceAddr []byte, token []byte) (time.Time, error) {
|
||||||
|
return g.stkSource.VerifyToken(sourceAddr, token)
|
||||||
|
}
|
||||||
35
handshake/stk_generator_test.go
Normal file
35
handshake/stk_generator_test.go
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
package handshake
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
. "github.com/onsi/ginkgo"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ = Describe("STK Generator", func() {
|
||||||
|
var stkGen *STKGenerator
|
||||||
|
|
||||||
|
BeforeEach(func() {
|
||||||
|
var err error
|
||||||
|
stkGen, err = NewSTKGenerator()
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("generates an STK", func() {
|
||||||
|
ip := net.IPv4(127, 0, 0, 1)
|
||||||
|
token, err := stkGen.NewToken(ip)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(token).ToNot(BeEmpty())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("verifies an STK", func() {
|
||||||
|
ip := net.IPv4(192, 168, 0, 1)
|
||||||
|
token, err := stkGen.NewToken(ip)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
t, err := stkGen.VerifyToken(ip, token)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(t).To(BeTemporally("~", time.Now(), time.Second))
|
||||||
|
})
|
||||||
|
})
|
||||||
Reference in New Issue
Block a user