move the STK generation from the ServerConfig to a separate struct

This commit is contained in:
Marten Seemann
2017-05-13 11:47:36 +08:00
parent 6cc6d49a10
commit 9562df5838
7 changed files with 86 additions and 21 deletions

View File

@@ -58,7 +58,11 @@ const stkKeySize = 16
const stkNonceSize = 16
// NewStkSource creates a source for source address tokens
func NewStkSource(secret []byte) (StkSource, error) {
func NewStkSource() (StkSource, error) {
secret := make([]byte, 32)
if _, err := rand.Read(secret); err != nil {
return nil, err
}
key, err := deriveKey(secret)
if err != nil {
return nil, err

View File

@@ -43,7 +43,6 @@ var _ = Describe("Source Address Tokens", func() {
Context("source", func() {
var (
source *stkSource
secret []byte
ip4 net.IP
ip6 net.IP
)
@@ -56,8 +55,7 @@ var _ = Describe("Source Address Tokens", func() {
ip6 = net.ParseIP("2001:0db8:0000:0000:0000:ff00:0042:8329")
Expect(ip6).NotTo(BeEmpty())
secret = []byte("TESTING")
sourceI, err := NewStkSource(secret)
sourceI, err := NewStkSource()
source = sourceI.(*stkSource)
Expect(err).NotTo(HaveOccurred())
})

View File

@@ -27,6 +27,7 @@ type cryptoSetupServer struct {
connID protocol.ConnectionID
sourceAddr []byte
scfg *ServerConfig
stkGenerator *STKGenerator
diversificationNonce []byte
version protocol.VersionNumber
@@ -68,6 +69,11 @@ func NewCryptoSetup(
supportedVersions []protocol.VersionNumber,
aeadChanged chan<- protocol.EncryptionLevel,
) (CryptoSetup, error) {
stkGenerator, err := NewSTKGenerator()
if err != nil {
return nil, err
}
var sourceAddr []byte
if udpAddr, ok := remoteAddr.(*net.UDPAddr); ok {
sourceAddr = udpAddr.IP
@@ -81,6 +87,7 @@ func NewCryptoSetup(
version: version,
supportedVersions: supportedVersions,
scfg: scfg,
stkGenerator: stkGenerator,
keyDerivation: crypto.DeriveKeysAESGCM,
keyExchange: getEphermalKEX,
nullAEAD: crypto.NewNullAEAD(protocol.PerspectiveServer, version),
@@ -276,7 +283,7 @@ func (h *cryptoSetupServer) isInchoateCHLO(cryptoData map[Tag][]byte, cert []byt
}
func (h *cryptoSetupServer) verifySTK(stk []byte) bool {
stkTime, err := h.scfg.stkSource.VerifyToken(h.sourceAddr, stk)
stkTime, err := h.stkGenerator.VerifyToken(h.sourceAddr, stk)
if err != nil {
utils.Debugf("STK invalid: %s", err.Error())
return false
@@ -292,7 +299,7 @@ func (h *cryptoSetupServer) handleInchoateCHLO(sni string, chlo []byte, cryptoDa
return nil, qerr.Error(qerr.CryptoInvalidValueLength, "CHLO too small")
}
token, err := h.scfg.stkSource.NewToken(h.sourceAddr)
token, err := h.stkGenerator.NewToken(h.sourceAddr)
if err != nil {
return nil, err
}

View File

@@ -194,7 +194,6 @@ var _ = Describe("Server Crypto Setup", func() {
versionTag = make([]byte, 4)
binary.LittleEndian.PutUint32(versionTag, protocol.VersionNumberToTag(protocol.VersionWhatever))
Expect(err).NotTo(HaveOccurred())
scfg.stkSource = &mockStkSource{}
version = protocol.SupportedVersions[len(protocol.SupportedVersions)-1]
supportedVersions = []protocol.VersionNumber{version, 98, 99}
cpm = NewConnectionParamatersManager(protocol.PerspectiveServer, protocol.VersionWhatever)
@@ -210,6 +209,7 @@ var _ = Describe("Server Crypto Setup", func() {
)
Expect(err).NotTo(HaveOccurred())
cs = csInt.(*cryptoSetupServer)
cs.stkGenerator.stkSource = &mockStkSource{}
cs.keyDerivation = mockKeyDerivation
cs.keyExchange = func() crypto.KeyExchange { return &mockKEX{ephermal: true} }
})
@@ -436,12 +436,12 @@ var _ = Describe("Server Crypto Setup", func() {
It("recognizes inchoate CHLOs with an invalid STK", func() {
testErr := errors.New("STK invalid")
scfg.stkSource.(*mockStkSource).verifyErr = testErr
cs.stkGenerator.stkSource.(*mockStkSource).verifyErr = testErr
Expect(cs.isInchoateCHLO(fullCHLO, cert)).To(BeTrue())
})
It("REJ messages that have an expired STK", func() {
cs.scfg.stkSource.(*mockStkSource).stkTime = time.Now().Add(-protocol.STKExpiryTime).Add(-time.Second)
cs.stkGenerator.stkSource.(*mockStkSource).stkTime = time.Now().Add(-protocol.STKExpiryTime).Add(-time.Second)
Expect(cs.isInchoateCHLO(fullCHLO, cert)).To(BeTrue())
})

View File

@@ -13,7 +13,6 @@ type ServerConfig struct {
certChain crypto.CertChain
ID []byte
obit []byte
stkSource crypto.StkSource
}
// NewServerConfig creates a new server config
@@ -24,27 +23,16 @@ func NewServerConfig(kex crypto.KeyExchange, certChain crypto.CertChain) (*Serve
return nil, err
}
stkSecret := make([]byte, 32)
if _, err = rand.Read(stkSecret); err != nil {
return nil, err
}
obit := make([]byte, 8)
if _, err = rand.Read(obit); err != nil {
return nil, err
}
stkSource, err := crypto.NewStkSource(stkSecret)
if err != nil {
return nil, err
}
return &ServerConfig{
kex: kex,
certChain: certChain,
ID: id,
obit: obit,
stkSource: stkSource,
}, nil
}

View File

@@ -0,0 +1,33 @@
package handshake
import (
"time"
"github.com/lucas-clemente/quic-go/crypto"
)
// An STKGenerator generates STKs
type STKGenerator struct {
stkSource crypto.StkSource
}
// NewSTKGenerator initializes a new STKGenerator
func NewSTKGenerator() (*STKGenerator, error) {
stkSource, err := crypto.NewStkSource()
if err != nil {
return nil, err
}
return &STKGenerator{
stkSource: stkSource,
}, nil
}
// NewToken generates a new STK token for a given source address
func (g *STKGenerator) NewToken(sourceAddr []byte) ([]byte, error) {
return g.stkSource.NewToken(sourceAddr)
}
// VerifyToken verifies an STK token
func (g *STKGenerator) VerifyToken(sourceAddr []byte, token []byte) (time.Time, error) {
return g.stkSource.VerifyToken(sourceAddr, token)
}

View File

@@ -0,0 +1,35 @@
package handshake
import (
"net"
"time"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("STK Generator", func() {
var stkGen *STKGenerator
BeforeEach(func() {
var err error
stkGen, err = NewSTKGenerator()
Expect(err).ToNot(HaveOccurred())
})
It("generates an STK", func() {
ip := net.IPv4(127, 0, 0, 1)
token, err := stkGen.NewToken(ip)
Expect(err).ToNot(HaveOccurred())
Expect(token).ToNot(BeEmpty())
})
It("verifies an STK", func() {
ip := net.IPv4(192, 168, 0, 1)
token, err := stkGen.NewToken(ip)
Expect(err).ToNot(HaveOccurred())
t, err := stkGen.VerifyToken(ip, token)
Expect(err).ToNot(HaveOccurred())
Expect(t).To(BeTemporally("~", time.Now(), time.Second))
})
})