forked from quic-go/quic-go
simplify the interaction with mint
This commit is contained in:
@@ -10,15 +10,14 @@ const (
|
||||
serverExporterLabel = "EXPORTER-QUIC server 1-RTT Secret"
|
||||
)
|
||||
|
||||
// MintController is an interface that bundles all methods needed to interact with mint
|
||||
type MintController interface {
|
||||
Handshake() mint.Alert
|
||||
// A TLSExporter gets the negotiated ciphersuite and computes exporter
|
||||
type TLSExporter interface {
|
||||
GetCipherSuite() mint.CipherSuiteParams
|
||||
ComputeExporter(label string, context []byte, keyLength int) ([]byte, error)
|
||||
}
|
||||
|
||||
// DeriveAESKeys derives the AES keys and creates a matching AES-GCM AEAD instance
|
||||
func DeriveAESKeys(mc MintController, pers protocol.Perspective) (AEAD, error) {
|
||||
func DeriveAESKeys(tls TLSExporter, pers protocol.Perspective) (AEAD, error) {
|
||||
var myLabel, otherLabel string
|
||||
if pers == protocol.PerspectiveClient {
|
||||
myLabel = clientExporterLabel
|
||||
@@ -27,20 +26,20 @@ func DeriveAESKeys(mc MintController, pers protocol.Perspective) (AEAD, error) {
|
||||
myLabel = serverExporterLabel
|
||||
otherLabel = clientExporterLabel
|
||||
}
|
||||
myKey, myIV, err := computeKeyAndIV(mc, myLabel)
|
||||
myKey, myIV, err := computeKeyAndIV(tls, myLabel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
otherKey, otherIV, err := computeKeyAndIV(mc, otherLabel)
|
||||
otherKey, otherIV, err := computeKeyAndIV(tls, otherLabel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewAEADAESGCM(otherKey, myKey, otherIV, myIV)
|
||||
}
|
||||
|
||||
func computeKeyAndIV(mc MintController, label string) (key, iv []byte, err error) {
|
||||
cs := mc.GetCipherSuite()
|
||||
secret, err := mc.ComputeExporter(label, nil, cs.Hash.Size())
|
||||
func computeKeyAndIV(tls TLSExporter, label string) (key, iv []byte, err error) {
|
||||
cs := tls.GetCipherSuite()
|
||||
secret, err := tls.ComputeExporter(label, nil, cs.Hash.Size())
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
@@ -10,16 +10,16 @@ import (
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
type mockMintController struct {
|
||||
type mockTLSExporter struct {
|
||||
hash crypto.Hash
|
||||
computerError error
|
||||
}
|
||||
|
||||
var _ MintController = &mockMintController{}
|
||||
var _ TLSExporter = &mockTLSExporter{}
|
||||
|
||||
func (c *mockMintController) Handshake() mint.Alert { panic("not implemented") }
|
||||
func (c *mockTLSExporter) Handshake() mint.Alert { panic("not implemented") }
|
||||
|
||||
func (c *mockMintController) GetCipherSuite() mint.CipherSuiteParams {
|
||||
func (c *mockTLSExporter) GetCipherSuite() mint.CipherSuiteParams {
|
||||
return mint.CipherSuiteParams{
|
||||
Hash: c.hash,
|
||||
KeyLen: 32,
|
||||
@@ -27,7 +27,7 @@ func (c *mockMintController) GetCipherSuite() mint.CipherSuiteParams {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *mockMintController) ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) {
|
||||
func (c *mockTLSExporter) ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) {
|
||||
if c.computerError != nil {
|
||||
return nil, c.computerError
|
||||
}
|
||||
@@ -36,9 +36,9 @@ func (c *mockMintController) ComputeExporter(label string, context []byte, keyLe
|
||||
|
||||
var _ = Describe("Key Derivation", func() {
|
||||
It("derives keys", func() {
|
||||
clientAEAD, err := DeriveAESKeys(&mockMintController{hash: crypto.SHA256}, protocol.PerspectiveClient)
|
||||
clientAEAD, err := DeriveAESKeys(&mockTLSExporter{hash: crypto.SHA256}, protocol.PerspectiveClient)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
serverAEAD, err := DeriveAESKeys(&mockMintController{hash: crypto.SHA256}, protocol.PerspectiveServer)
|
||||
serverAEAD, err := DeriveAESKeys(&mockTLSExporter{hash: crypto.SHA256}, protocol.PerspectiveServer)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
ciphertext := clientAEAD.Seal(nil, []byte("foobar"), 0, []byte("aad"))
|
||||
data, err := serverAEAD.Open(nil, ciphertext, 0, []byte("aad"))
|
||||
@@ -47,9 +47,9 @@ var _ = Describe("Key Derivation", func() {
|
||||
})
|
||||
|
||||
It("fails when different hash functions are used", func() {
|
||||
clientAEAD, err := DeriveAESKeys(&mockMintController{hash: crypto.SHA256}, protocol.PerspectiveClient)
|
||||
clientAEAD, err := DeriveAESKeys(&mockTLSExporter{hash: crypto.SHA256}, protocol.PerspectiveClient)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
serverAEAD, err := DeriveAESKeys(&mockMintController{hash: crypto.SHA512}, protocol.PerspectiveServer)
|
||||
serverAEAD, err := DeriveAESKeys(&mockTLSExporter{hash: crypto.SHA512}, protocol.PerspectiveServer)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
ciphertext := clientAEAD.Seal(nil, []byte("foobar"), 0, []byte("aad"))
|
||||
_, err = serverAEAD.Open(nil, ciphertext, 0, []byte("aad"))
|
||||
@@ -58,7 +58,7 @@ var _ = Describe("Key Derivation", func() {
|
||||
|
||||
It("fails when computing the exporter fails", func() {
|
||||
testErr := errors.New("test error")
|
||||
_, err := DeriveAESKeys(&mockMintController{hash: crypto.SHA256, computerError: testErr}, protocol.PerspectiveClient)
|
||||
_, err := DeriveAESKeys(&mockTLSExporter{hash: crypto.SHA256, computerError: testErr}, protocol.PerspectiveClient)
|
||||
Expect(err).To(MatchError(testErr))
|
||||
})
|
||||
})
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
)
|
||||
|
||||
// KeyDerivationFunction is used for key derivation
|
||||
type KeyDerivationFunction func(crypto.MintController, protocol.Perspective) (crypto.AEAD, error)
|
||||
type KeyDerivationFunction func(crypto.TLSExporter, protocol.Perspective) (crypto.AEAD, error)
|
||||
|
||||
type cryptoSetupTLS struct {
|
||||
mutex sync.RWMutex
|
||||
@@ -21,8 +21,7 @@ type cryptoSetupTLS struct {
|
||||
|
||||
keyDerivation KeyDerivationFunction
|
||||
|
||||
conn *mint.Conn
|
||||
extensionHandler mint.AppExtensionHandler
|
||||
tls mintTLS
|
||||
|
||||
nullAEAD crypto.AEAD
|
||||
aead crypto.AEAD
|
||||
@@ -30,10 +29,6 @@ type cryptoSetupTLS struct {
|
||||
aeadChanged chan<- protocol.EncryptionLevel
|
||||
}
|
||||
|
||||
var newMintController = func(conn *mint.Conn) crypto.MintController {
|
||||
return &mintController{conn}
|
||||
}
|
||||
|
||||
// NewCryptoSetupTLSServer creates a new TLS CryptoSetup instance for a server
|
||||
func NewCryptoSetupTLSServer(
|
||||
cryptoStream io.ReadWriter,
|
||||
@@ -48,14 +43,18 @@ func NewCryptoSetupTLSServer(
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mintConn := mint.Server(&fakeConn{cryptoStream}, mintConf)
|
||||
eh := newExtensionHandlerServer(params, paramsChan, supportedVersions, version)
|
||||
if err := mintConn.SetExtensionHandler(eh); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &cryptoSetupTLS{
|
||||
perspective: protocol.PerspectiveServer,
|
||||
conn: mint.Server(&fakeConn{cryptoStream}, mintConf),
|
||||
nullAEAD: crypto.NewNullAEAD(protocol.PerspectiveServer, version),
|
||||
keyDerivation: crypto.DeriveAESKeys,
|
||||
aeadChanged: aeadChanged,
|
||||
extensionHandler: newExtensionHandlerServer(params, paramsChan, supportedVersions, version),
|
||||
perspective: protocol.PerspectiveServer,
|
||||
tls: &mintController{mintConn},
|
||||
nullAEAD: crypto.NewNullAEAD(protocol.PerspectiveServer, version),
|
||||
keyDerivation: crypto.DeriveAESKeys,
|
||||
aeadChanged: aeadChanged,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -76,28 +75,28 @@ func NewCryptoSetupTLSClient(
|
||||
return nil, err
|
||||
}
|
||||
mintConf.ServerName = hostname
|
||||
mintConn := mint.Client(&fakeConn{cryptoStream}, mintConf)
|
||||
eh := newExtensionHandlerClient(params, paramsChan, initialVersion, supportedVersions, version)
|
||||
if err := mintConn.SetExtensionHandler(eh); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &cryptoSetupTLS{
|
||||
perspective: protocol.PerspectiveClient,
|
||||
conn: mint.Client(&fakeConn{cryptoStream}, mintConf),
|
||||
nullAEAD: crypto.NewNullAEAD(protocol.PerspectiveClient, version),
|
||||
keyDerivation: crypto.DeriveAESKeys,
|
||||
aeadChanged: aeadChanged,
|
||||
extensionHandler: newExtensionHandlerClient(params, paramsChan, initialVersion, supportedVersions, version),
|
||||
perspective: protocol.PerspectiveClient,
|
||||
tls: &mintController{mintConn},
|
||||
nullAEAD: crypto.NewNullAEAD(protocol.PerspectiveClient, version),
|
||||
keyDerivation: crypto.DeriveAESKeys,
|
||||
aeadChanged: aeadChanged,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetupTLS) HandleCryptoStream() error {
|
||||
if err := h.conn.SetExtensionHandler(h.extensionHandler); err != nil {
|
||||
return err
|
||||
}
|
||||
mc := newMintController(h.conn)
|
||||
|
||||
if alert := mc.Handshake(); alert != mint.AlertNoAlert {
|
||||
if alert := h.tls.Handshake(); alert != mint.AlertNoAlert {
|
||||
return fmt.Errorf("TLS handshake error: %s (Alert %d)", alert.String(), alert)
|
||||
}
|
||||
|
||||
aead, err := h.keyDerivation(mc, h.perspective)
|
||||
aead, err := h.keyDerivation(h.tls, h.perspective)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -12,21 +12,24 @@ import (
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
type fakeMintController struct {
|
||||
type fakeMintTLS struct {
|
||||
result mint.Alert
|
||||
}
|
||||
|
||||
var _ crypto.MintController = &fakeMintController{}
|
||||
var _ mintTLS = &fakeMintTLS{}
|
||||
|
||||
func (h *fakeMintController) Handshake() mint.Alert {
|
||||
func (h *fakeMintTLS) Handshake() mint.Alert {
|
||||
return h.result
|
||||
}
|
||||
func (h *fakeMintController) GetCipherSuite() mint.CipherSuiteParams { panic("not implemented") }
|
||||
func (h *fakeMintController) ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) {
|
||||
func (h *fakeMintTLS) GetCipherSuite() mint.CipherSuiteParams { panic("not implemented") }
|
||||
func (h *fakeMintTLS) ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) {
|
||||
panic("not implemented")
|
||||
}
|
||||
func (h *fakeMintTLS) SetExtensionHandler(mint.AppExtensionHandler) error {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func mockKeyDerivation(crypto.MintController, protocol.Perspective) (crypto.AEAD, error) {
|
||||
func mockKeyDerivation(crypto.TLSExporter, protocol.Perspective) (crypto.AEAD, error) {
|
||||
return &mockAEAD{encLevel: protocol.EncryptionForwardSecure}, nil
|
||||
}
|
||||
|
||||
@@ -35,8 +38,6 @@ var _ = Describe("TLS Crypto Setup", func() {
|
||||
cs *cryptoSetupTLS
|
||||
paramsChan chan TransportParameters
|
||||
aeadChanged chan protocol.EncryptionLevel
|
||||
|
||||
mintControllerConstructor = newMintController
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
@@ -55,23 +56,15 @@ var _ = Describe("TLS Crypto Setup", func() {
|
||||
cs = csInt.(*cryptoSetupTLS)
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
newMintController = mintControllerConstructor
|
||||
})
|
||||
|
||||
It("errors when the handshake fails", func() {
|
||||
alert := mint.AlertBadRecordMAC
|
||||
newMintController = func(*mint.Conn) crypto.MintController {
|
||||
return &fakeMintController{result: alert}
|
||||
}
|
||||
cs.tls = &fakeMintTLS{result: alert}
|
||||
err := cs.HandleCryptoStream()
|
||||
Expect(err).To(MatchError(fmt.Errorf("TLS handshake error: %s (Alert %d)", alert.String(), alert)))
|
||||
})
|
||||
|
||||
It("derives keys", func() {
|
||||
newMintController = func(*mint.Conn) crypto.MintController {
|
||||
return &fakeMintController{result: mint.AlertNoAlert}
|
||||
}
|
||||
cs.tls = &fakeMintTLS{result: mint.AlertNoAlert}
|
||||
cs.keyDerivation = mockKeyDerivation
|
||||
err := cs.HandleCryptoStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
@@ -83,9 +76,7 @@ var _ = Describe("TLS Crypto Setup", func() {
|
||||
var foobarFNVSigned []byte // a "foobar", FNV signed
|
||||
|
||||
doHandshake := func() {
|
||||
newMintController = func(*mint.Conn) crypto.MintController {
|
||||
return &fakeMintController{result: mint.AlertNoAlert}
|
||||
}
|
||||
cs.tls = &fakeMintTLS{result: mint.AlertNoAlert}
|
||||
cs.keyDerivation = mockKeyDerivation
|
||||
err := cs.HandleCryptoStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
@@ -43,15 +43,16 @@ func tlsToMintConfig(tlsConf *tls.Config, pers protocol.Perspective) (*mint.Conf
|
||||
return mconf, nil
|
||||
}
|
||||
|
||||
type mintTLS interface {
|
||||
crypto.TLSExporter
|
||||
Handshake() mint.Alert
|
||||
}
|
||||
|
||||
type mintController struct {
|
||||
conn *mint.Conn
|
||||
}
|
||||
|
||||
var _ crypto.MintController = &mintController{}
|
||||
|
||||
func (mc *mintController) Handshake() mint.Alert {
|
||||
return mc.conn.Handshake()
|
||||
}
|
||||
var _ mintTLS = &mintController{}
|
||||
|
||||
func (mc *mintController) GetCipherSuite() mint.CipherSuiteParams {
|
||||
return mc.conn.State().CipherSuite
|
||||
@@ -61,6 +62,10 @@ func (mc *mintController) ComputeExporter(label string, context []byte, keyLengt
|
||||
return mc.conn.ComputeExporter(label, context, keyLength)
|
||||
}
|
||||
|
||||
func (mc *mintController) Handshake() mint.Alert {
|
||||
return mc.conn.Handshake()
|
||||
}
|
||||
|
||||
// mint expects a net.Conn, but we're doing the handshake on a stream
|
||||
// so we wrap a stream such that implements a net.Conn
|
||||
type fakeConn struct {
|
||||
|
||||
Reference in New Issue
Block a user