simplify the interaction with mint

This commit is contained in:
Marten Seemann
2017-10-21 19:28:51 +07:00
parent 282b423f7d
commit 9825ddb43a
5 changed files with 64 additions and 70 deletions

View File

@@ -12,7 +12,7 @@ import (
)
// KeyDerivationFunction is used for key derivation
type KeyDerivationFunction func(crypto.MintController, protocol.Perspective) (crypto.AEAD, error)
type KeyDerivationFunction func(crypto.TLSExporter, protocol.Perspective) (crypto.AEAD, error)
type cryptoSetupTLS struct {
mutex sync.RWMutex
@@ -21,8 +21,7 @@ type cryptoSetupTLS struct {
keyDerivation KeyDerivationFunction
conn *mint.Conn
extensionHandler mint.AppExtensionHandler
tls mintTLS
nullAEAD crypto.AEAD
aead crypto.AEAD
@@ -30,10 +29,6 @@ type cryptoSetupTLS struct {
aeadChanged chan<- protocol.EncryptionLevel
}
var newMintController = func(conn *mint.Conn) crypto.MintController {
return &mintController{conn}
}
// NewCryptoSetupTLSServer creates a new TLS CryptoSetup instance for a server
func NewCryptoSetupTLSServer(
cryptoStream io.ReadWriter,
@@ -48,14 +43,18 @@ func NewCryptoSetupTLSServer(
if err != nil {
return nil, err
}
mintConn := mint.Server(&fakeConn{cryptoStream}, mintConf)
eh := newExtensionHandlerServer(params, paramsChan, supportedVersions, version)
if err := mintConn.SetExtensionHandler(eh); err != nil {
return nil, err
}
return &cryptoSetupTLS{
perspective: protocol.PerspectiveServer,
conn: mint.Server(&fakeConn{cryptoStream}, mintConf),
nullAEAD: crypto.NewNullAEAD(protocol.PerspectiveServer, version),
keyDerivation: crypto.DeriveAESKeys,
aeadChanged: aeadChanged,
extensionHandler: newExtensionHandlerServer(params, paramsChan, supportedVersions, version),
perspective: protocol.PerspectiveServer,
tls: &mintController{mintConn},
nullAEAD: crypto.NewNullAEAD(protocol.PerspectiveServer, version),
keyDerivation: crypto.DeriveAESKeys,
aeadChanged: aeadChanged,
}, nil
}
@@ -76,28 +75,28 @@ func NewCryptoSetupTLSClient(
return nil, err
}
mintConf.ServerName = hostname
mintConn := mint.Client(&fakeConn{cryptoStream}, mintConf)
eh := newExtensionHandlerClient(params, paramsChan, initialVersion, supportedVersions, version)
if err := mintConn.SetExtensionHandler(eh); err != nil {
return nil, err
}
return &cryptoSetupTLS{
perspective: protocol.PerspectiveClient,
conn: mint.Client(&fakeConn{cryptoStream}, mintConf),
nullAEAD: crypto.NewNullAEAD(protocol.PerspectiveClient, version),
keyDerivation: crypto.DeriveAESKeys,
aeadChanged: aeadChanged,
extensionHandler: newExtensionHandlerClient(params, paramsChan, initialVersion, supportedVersions, version),
perspective: protocol.PerspectiveClient,
tls: &mintController{mintConn},
nullAEAD: crypto.NewNullAEAD(protocol.PerspectiveClient, version),
keyDerivation: crypto.DeriveAESKeys,
aeadChanged: aeadChanged,
}, nil
}
func (h *cryptoSetupTLS) HandleCryptoStream() error {
if err := h.conn.SetExtensionHandler(h.extensionHandler); err != nil {
return err
}
mc := newMintController(h.conn)
if alert := mc.Handshake(); alert != mint.AlertNoAlert {
if alert := h.tls.Handshake(); alert != mint.AlertNoAlert {
return fmt.Errorf("TLS handshake error: %s (Alert %d)", alert.String(), alert)
}
aead, err := h.keyDerivation(mc, h.perspective)
aead, err := h.keyDerivation(h.tls, h.perspective)
if err != nil {
return err
}

View File

@@ -12,21 +12,24 @@ import (
. "github.com/onsi/gomega"
)
type fakeMintController struct {
type fakeMintTLS struct {
result mint.Alert
}
var _ crypto.MintController = &fakeMintController{}
var _ mintTLS = &fakeMintTLS{}
func (h *fakeMintController) Handshake() mint.Alert {
func (h *fakeMintTLS) Handshake() mint.Alert {
return h.result
}
func (h *fakeMintController) GetCipherSuite() mint.CipherSuiteParams { panic("not implemented") }
func (h *fakeMintController) ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) {
func (h *fakeMintTLS) GetCipherSuite() mint.CipherSuiteParams { panic("not implemented") }
func (h *fakeMintTLS) ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) {
panic("not implemented")
}
func (h *fakeMintTLS) SetExtensionHandler(mint.AppExtensionHandler) error {
panic("not implemented")
}
func mockKeyDerivation(crypto.MintController, protocol.Perspective) (crypto.AEAD, error) {
func mockKeyDerivation(crypto.TLSExporter, protocol.Perspective) (crypto.AEAD, error) {
return &mockAEAD{encLevel: protocol.EncryptionForwardSecure}, nil
}
@@ -35,8 +38,6 @@ var _ = Describe("TLS Crypto Setup", func() {
cs *cryptoSetupTLS
paramsChan chan TransportParameters
aeadChanged chan protocol.EncryptionLevel
mintControllerConstructor = newMintController
)
BeforeEach(func() {
@@ -55,23 +56,15 @@ var _ = Describe("TLS Crypto Setup", func() {
cs = csInt.(*cryptoSetupTLS)
})
AfterEach(func() {
newMintController = mintControllerConstructor
})
It("errors when the handshake fails", func() {
alert := mint.AlertBadRecordMAC
newMintController = func(*mint.Conn) crypto.MintController {
return &fakeMintController{result: alert}
}
cs.tls = &fakeMintTLS{result: alert}
err := cs.HandleCryptoStream()
Expect(err).To(MatchError(fmt.Errorf("TLS handshake error: %s (Alert %d)", alert.String(), alert)))
})
It("derives keys", func() {
newMintController = func(*mint.Conn) crypto.MintController {
return &fakeMintController{result: mint.AlertNoAlert}
}
cs.tls = &fakeMintTLS{result: mint.AlertNoAlert}
cs.keyDerivation = mockKeyDerivation
err := cs.HandleCryptoStream()
Expect(err).ToNot(HaveOccurred())
@@ -83,9 +76,7 @@ var _ = Describe("TLS Crypto Setup", func() {
var foobarFNVSigned []byte // a "foobar", FNV signed
doHandshake := func() {
newMintController = func(*mint.Conn) crypto.MintController {
return &fakeMintController{result: mint.AlertNoAlert}
}
cs.tls = &fakeMintTLS{result: mint.AlertNoAlert}
cs.keyDerivation = mockKeyDerivation
err := cs.HandleCryptoStream()
Expect(err).ToNot(HaveOccurred())

View File

@@ -43,15 +43,16 @@ func tlsToMintConfig(tlsConf *tls.Config, pers protocol.Perspective) (*mint.Conf
return mconf, nil
}
type mintTLS interface {
crypto.TLSExporter
Handshake() mint.Alert
}
type mintController struct {
conn *mint.Conn
}
var _ crypto.MintController = &mintController{}
func (mc *mintController) Handshake() mint.Alert {
return mc.conn.Handshake()
}
var _ mintTLS = &mintController{}
func (mc *mintController) GetCipherSuite() mint.CipherSuiteParams {
return mc.conn.State().CipherSuite
@@ -61,6 +62,10 @@ func (mc *mintController) ComputeExporter(label string, context []byte, keyLengt
return mc.conn.ComputeExporter(label, context, keyLength)
}
func (mc *mintController) Handshake() mint.Alert {
return mc.conn.Handshake()
}
// mint expects a net.Conn, but we're doing the handshake on a stream
// so we wrap a stream such that implements a net.Conn
type fakeConn struct {