use the non-blocking of mint to cycle through the handshake

This commit is contained in:
Marten Seemann
2017-10-26 18:32:44 +07:00
parent fcc380187a
commit 3e39991e1e
6 changed files with 177 additions and 28 deletions

View File

@@ -21,7 +21,8 @@ type cryptoSetupTLS struct {
keyDerivation KeyDerivationFunction
tls mintTLS
tls mintTLS
conn *fakeConn
nullAEAD crypto.AEAD
aead crypto.AEAD
@@ -44,7 +45,8 @@ func NewCryptoSetupTLSServer(
if err != nil {
return nil, err
}
mintConn := mint.Server(&fakeConn{cryptoStream}, mintConf)
conn := &fakeConn{stream: cryptoStream, pers: protocol.PerspectiveServer}
mintConn := mint.Server(conn, mintConf)
eh := newExtensionHandlerServer(params, paramsChan, supportedVersions, version)
if err := mintConn.SetExtensionHandler(eh); err != nil {
return nil, err
@@ -58,6 +60,7 @@ func NewCryptoSetupTLSServer(
return &cryptoSetupTLS{
perspective: protocol.PerspectiveServer,
tls: &mintController{mintConn},
conn: conn,
nullAEAD: nullAEAD,
keyDerivation: crypto.DeriveAESKeys,
aeadChanged: aeadChanged,
@@ -82,7 +85,8 @@ func NewCryptoSetupTLSClient(
return nil, err
}
mintConf.ServerName = hostname
mintConn := mint.Client(&fakeConn{cryptoStream}, mintConf)
conn := &fakeConn{stream: cryptoStream, pers: protocol.PerspectiveClient}
mintConn := mint.Client(conn, mintConf)
eh := newExtensionHandlerClient(params, paramsChan, initialVersion, supportedVersions, version)
if err := mintConn.SetExtensionHandler(eh); err != nil {
return nil, err
@@ -94,6 +98,7 @@ func NewCryptoSetupTLSClient(
}
return &cryptoSetupTLS{
conn: conn,
perspective: protocol.PerspectiveClient,
tls: &mintController{mintConn},
nullAEAD: nullAEAD,
@@ -103,9 +108,16 @@ func NewCryptoSetupTLSClient(
}
func (h *cryptoSetupTLS) HandleCryptoStream() error {
if alert := h.tls.Handshake(); alert != mint.AlertNoAlert {
return fmt.Errorf("TLS handshake error: %s (Alert %d)", alert.String(), alert)
handshakeLoop:
for {
switch alert := h.tls.Handshake(); alert {
case mint.AlertNoAlert: // handshake complete
break handshakeLoop
case mint.AlertWouldBlock:
h.conn.UnblockRead()
default:
return fmt.Errorf("TLS handshake error: %s (Alert %d)", alert.String(), alert)
}
}
aead, err := h.keyDerivation(h.tls, h.perspective)

View File

@@ -7,6 +7,7 @@ import (
"github.com/bifurcation/mint"
"github.com/lucas-clemente/quic-go/internal/crypto"
"github.com/lucas-clemente/quic-go/internal/mocks/crypto"
"github.com/lucas-clemente/quic-go/internal/mocks/handshake"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/testdata"
@@ -14,23 +15,6 @@ import (
. "github.com/onsi/gomega"
)
type fakeMintTLS struct {
result mint.Alert
}
var _ mintTLS = &fakeMintTLS{}
func (h *fakeMintTLS) Handshake() mint.Alert {
return h.result
}
func (h *fakeMintTLS) GetCipherSuite() mint.CipherSuiteParams { panic("not implemented") }
func (h *fakeMintTLS) ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) {
panic("not implemented")
}
func (h *fakeMintTLS) SetExtensionHandler(mint.AppExtensionHandler) error {
panic("not implemented")
}
func mockKeyDerivation(crypto.TLSExporter, protocol.Perspective) (crypto.AEAD, error) {
return mockcrypto.NewMockAEAD(mockCtrl), nil
}
@@ -62,13 +46,24 @@ var _ = Describe("TLS Crypto Setup", func() {
It("errors when the handshake fails", func() {
alert := mint.AlertBadRecordMAC
cs.tls = &fakeMintTLS{result: alert}
cs.tls = mockhandshake.NewMockmintTLS(mockCtrl)
cs.tls.(*mockhandshake.MockmintTLS).EXPECT().Handshake().Return(alert)
err := cs.HandleCryptoStream()
Expect(err).To(MatchError(fmt.Errorf("TLS handshake error: %s (Alert %d)", alert.String(), alert)))
})
It("continues shaking hands when mint says that it would block", func() {
cs.tls = mockhandshake.NewMockmintTLS(mockCtrl)
cs.tls.(*mockhandshake.MockmintTLS).EXPECT().Handshake().Return(mint.AlertWouldBlock)
cs.tls.(*mockhandshake.MockmintTLS).EXPECT().Handshake().Return(mint.AlertNoAlert)
cs.keyDerivation = mockKeyDerivation
err := cs.HandleCryptoStream()
Expect(err).ToNot(HaveOccurred())
})
It("derives keys", func() {
cs.tls = &fakeMintTLS{result: mint.AlertNoAlert}
cs.tls = mockhandshake.NewMockmintTLS(mockCtrl)
cs.tls.(*mockhandshake.MockmintTLS).EXPECT().Handshake().Return(mint.AlertNoAlert)
cs.keyDerivation = mockKeyDerivation
err := cs.HandleCryptoStream()
Expect(err).ToNot(HaveOccurred())
@@ -78,7 +73,8 @@ var _ = Describe("TLS Crypto Setup", func() {
Context("escalating crypto", func() {
doHandshake := func() {
cs.tls = &fakeMintTLS{result: mint.AlertNoAlert}
cs.tls = mockhandshake.NewMockmintTLS(mockCtrl)
cs.tls.(*mockhandshake.MockmintTLS).EXPECT().Handshake().Return(mint.AlertNoAlert)
cs.keyDerivation = mockKeyDerivation
err := cs.HandleCryptoStream()
Expect(err).ToNot(HaveOccurred())

View File

@@ -44,10 +44,16 @@ func tlsToMintConfig(tlsConf *tls.Config, pers protocol.Perspective) (*mint.Conf
}
type mintTLS interface {
crypto.TLSExporter
// These two methods are the same as the crypto.TLSExporter interface.
// Cannot use embedding here, because mockgen source mode refuses to generate mocks then.
GetCipherSuite() mint.CipherSuiteParams
ComputeExporter(label string, context []byte, keyLength int) ([]byte, error)
// additional methods
Handshake() mint.Alert
}
var _ crypto.TLSExporter = (mintTLS)(nil)
type mintController struct {
conn *mint.Conn
}
@@ -69,11 +75,30 @@ func (mc *mintController) Handshake() mint.Alert {
// mint expects a net.Conn, but we're doing the handshake on a stream
// so we wrap a stream such that implements a net.Conn
type fakeConn struct {
io.ReadWriter
stream io.ReadWriter
pers protocol.Perspective
blockRead bool
}
var _ net.Conn = &fakeConn{}
func (c *fakeConn) Read(b []byte) (int, error) {
if c.blockRead { // this causes mint.Conn.Handshake() to return a mint.AlertWouldBlock
return 0, nil
}
c.blockRead = true // block the next Read call
return c.stream.Read(b)
}
func (c *fakeConn) Write(p []byte) (int, error) {
return c.stream.Write(p)
}
func (c *fakeConn) UnblockRead() {
c.blockRead = false
}
func (c *fakeConn) Close() error { return nil }
func (c *fakeConn) LocalAddr() net.Addr { return nil }
func (c *fakeConn) RemoteAddr() net.Addr { return nil }

View File

@@ -0,0 +1,44 @@
package handshake
import (
"bytes"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Fake Conn", func() {
var (
c *fakeConn
stream *bytes.Buffer
)
BeforeEach(func() {
stream = &bytes.Buffer{}
c = &fakeConn{stream: stream}
})
Context("Reading", func() {
It("doesn't return any new data after one Read call", func() {
stream.Write([]byte("foobar"))
b := make([]byte, 3)
_, err := c.Read(b)
Expect(err).ToNot(HaveOccurred())
Expect(b).To(Equal([]byte("foo")))
n, err := c.Read(b)
Expect(err).ToNot(HaveOccurred())
Expect(n).To(BeZero())
})
It("allows more Read calls after unblocking", func() {
stream.Write([]byte("foobar"))
b := make([]byte, 3)
_, err := c.Read(b)
Expect(err).ToNot(HaveOccurred())
c.UnblockRead()
_, err = c.Read(b)
Expect(err).ToNot(HaveOccurred())
Expect(b).To(Equal([]byte("bar")))
})
})
})

View File

@@ -1,5 +1,6 @@
package mocks
//go:generate sh -c "mockgen -source=../handshake/mint_utils.go -package mockhandshake -destination handshake/mint_tls.go"
//go:generate sh -c "./mockgen_internal.sh mocks stream_flow_controller.go github.com/lucas-clemente/quic-go/internal/flowcontrol StreamFlowController"
//go:generate sh -c "./mockgen_internal.sh mocks connection_flow_controller.go github.com/lucas-clemente/quic-go/internal/flowcontrol ConnectionFlowController"
//go:generate sh -c "./mockgen_internal.sh mockcrypto crypto/aead.go github.com/lucas-clemente/quic-go/internal/crypto AEAD"

View File

@@ -0,0 +1,71 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: ../handshake/mint_utils.go
package mockhandshake
import (
reflect "reflect"
mint "github.com/bifurcation/mint"
gomock "github.com/golang/mock/gomock"
)
// MockmintTLS is a mock of mintTLS interface
type MockmintTLS struct {
ctrl *gomock.Controller
recorder *MockmintTLSMockRecorder
}
// MockmintTLSMockRecorder is the mock recorder for MockmintTLS
type MockmintTLSMockRecorder struct {
mock *MockmintTLS
}
// NewMockmintTLS creates a new mock instance
func NewMockmintTLS(ctrl *gomock.Controller) *MockmintTLS {
mock := &MockmintTLS{ctrl: ctrl}
mock.recorder = &MockmintTLSMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (_m *MockmintTLS) EXPECT() *MockmintTLSMockRecorder {
return _m.recorder
}
// GetCipherSuite mocks base method
func (_m *MockmintTLS) GetCipherSuite() mint.CipherSuiteParams {
ret := _m.ctrl.Call(_m, "GetCipherSuite")
ret0, _ := ret[0].(mint.CipherSuiteParams)
return ret0
}
// GetCipherSuite indicates an expected call of GetCipherSuite
func (_mr *MockmintTLSMockRecorder) GetCipherSuite() *gomock.Call {
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "GetCipherSuite", reflect.TypeOf((*MockmintTLS)(nil).GetCipherSuite))
}
// ComputeExporter mocks base method
func (_m *MockmintTLS) ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) {
ret := _m.ctrl.Call(_m, "ComputeExporter", label, context, keyLength)
ret0, _ := ret[0].([]byte)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ComputeExporter indicates an expected call of ComputeExporter
func (_mr *MockmintTLSMockRecorder) ComputeExporter(arg0, arg1, arg2 interface{}) *gomock.Call {
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "ComputeExporter", reflect.TypeOf((*MockmintTLS)(nil).ComputeExporter), arg0, arg1, arg2)
}
// Handshake mocks base method
func (_m *MockmintTLS) Handshake() mint.Alert {
ret := _m.ctrl.Call(_m, "Handshake")
ret0, _ := ret[0].(mint.Alert)
return ret0
}
// Handshake indicates an expected call of Handshake
func (_mr *MockmintTLSMockRecorder) Handshake() *gomock.Call {
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "Handshake", reflect.TypeOf((*MockmintTLS)(nil).Handshake))
}