forked from quic-go/quic-go
use the non-blocking of mint to cycle through the handshake
This commit is contained in:
@@ -21,7 +21,8 @@ type cryptoSetupTLS struct {
|
||||
|
||||
keyDerivation KeyDerivationFunction
|
||||
|
||||
tls mintTLS
|
||||
tls mintTLS
|
||||
conn *fakeConn
|
||||
|
||||
nullAEAD crypto.AEAD
|
||||
aead crypto.AEAD
|
||||
@@ -44,7 +45,8 @@ func NewCryptoSetupTLSServer(
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mintConn := mint.Server(&fakeConn{cryptoStream}, mintConf)
|
||||
conn := &fakeConn{stream: cryptoStream, pers: protocol.PerspectiveServer}
|
||||
mintConn := mint.Server(conn, mintConf)
|
||||
eh := newExtensionHandlerServer(params, paramsChan, supportedVersions, version)
|
||||
if err := mintConn.SetExtensionHandler(eh); err != nil {
|
||||
return nil, err
|
||||
@@ -58,6 +60,7 @@ func NewCryptoSetupTLSServer(
|
||||
return &cryptoSetupTLS{
|
||||
perspective: protocol.PerspectiveServer,
|
||||
tls: &mintController{mintConn},
|
||||
conn: conn,
|
||||
nullAEAD: nullAEAD,
|
||||
keyDerivation: crypto.DeriveAESKeys,
|
||||
aeadChanged: aeadChanged,
|
||||
@@ -82,7 +85,8 @@ func NewCryptoSetupTLSClient(
|
||||
return nil, err
|
||||
}
|
||||
mintConf.ServerName = hostname
|
||||
mintConn := mint.Client(&fakeConn{cryptoStream}, mintConf)
|
||||
conn := &fakeConn{stream: cryptoStream, pers: protocol.PerspectiveClient}
|
||||
mintConn := mint.Client(conn, mintConf)
|
||||
eh := newExtensionHandlerClient(params, paramsChan, initialVersion, supportedVersions, version)
|
||||
if err := mintConn.SetExtensionHandler(eh); err != nil {
|
||||
return nil, err
|
||||
@@ -94,6 +98,7 @@ func NewCryptoSetupTLSClient(
|
||||
}
|
||||
|
||||
return &cryptoSetupTLS{
|
||||
conn: conn,
|
||||
perspective: protocol.PerspectiveClient,
|
||||
tls: &mintController{mintConn},
|
||||
nullAEAD: nullAEAD,
|
||||
@@ -103,9 +108,16 @@ func NewCryptoSetupTLSClient(
|
||||
}
|
||||
|
||||
func (h *cryptoSetupTLS) HandleCryptoStream() error {
|
||||
|
||||
if alert := h.tls.Handshake(); alert != mint.AlertNoAlert {
|
||||
return fmt.Errorf("TLS handshake error: %s (Alert %d)", alert.String(), alert)
|
||||
handshakeLoop:
|
||||
for {
|
||||
switch alert := h.tls.Handshake(); alert {
|
||||
case mint.AlertNoAlert: // handshake complete
|
||||
break handshakeLoop
|
||||
case mint.AlertWouldBlock:
|
||||
h.conn.UnblockRead()
|
||||
default:
|
||||
return fmt.Errorf("TLS handshake error: %s (Alert %d)", alert.String(), alert)
|
||||
}
|
||||
}
|
||||
|
||||
aead, err := h.keyDerivation(h.tls, h.perspective)
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"github.com/bifurcation/mint"
|
||||
"github.com/lucas-clemente/quic-go/internal/crypto"
|
||||
"github.com/lucas-clemente/quic-go/internal/mocks/crypto"
|
||||
"github.com/lucas-clemente/quic-go/internal/mocks/handshake"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/testdata"
|
||||
|
||||
@@ -14,23 +15,6 @@ import (
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
type fakeMintTLS struct {
|
||||
result mint.Alert
|
||||
}
|
||||
|
||||
var _ mintTLS = &fakeMintTLS{}
|
||||
|
||||
func (h *fakeMintTLS) Handshake() mint.Alert {
|
||||
return h.result
|
||||
}
|
||||
func (h *fakeMintTLS) GetCipherSuite() mint.CipherSuiteParams { panic("not implemented") }
|
||||
func (h *fakeMintTLS) ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) {
|
||||
panic("not implemented")
|
||||
}
|
||||
func (h *fakeMintTLS) SetExtensionHandler(mint.AppExtensionHandler) error {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func mockKeyDerivation(crypto.TLSExporter, protocol.Perspective) (crypto.AEAD, error) {
|
||||
return mockcrypto.NewMockAEAD(mockCtrl), nil
|
||||
}
|
||||
@@ -62,13 +46,24 @@ var _ = Describe("TLS Crypto Setup", func() {
|
||||
|
||||
It("errors when the handshake fails", func() {
|
||||
alert := mint.AlertBadRecordMAC
|
||||
cs.tls = &fakeMintTLS{result: alert}
|
||||
cs.tls = mockhandshake.NewMockmintTLS(mockCtrl)
|
||||
cs.tls.(*mockhandshake.MockmintTLS).EXPECT().Handshake().Return(alert)
|
||||
err := cs.HandleCryptoStream()
|
||||
Expect(err).To(MatchError(fmt.Errorf("TLS handshake error: %s (Alert %d)", alert.String(), alert)))
|
||||
})
|
||||
|
||||
It("continues shaking hands when mint says that it would block", func() {
|
||||
cs.tls = mockhandshake.NewMockmintTLS(mockCtrl)
|
||||
cs.tls.(*mockhandshake.MockmintTLS).EXPECT().Handshake().Return(mint.AlertWouldBlock)
|
||||
cs.tls.(*mockhandshake.MockmintTLS).EXPECT().Handshake().Return(mint.AlertNoAlert)
|
||||
cs.keyDerivation = mockKeyDerivation
|
||||
err := cs.HandleCryptoStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("derives keys", func() {
|
||||
cs.tls = &fakeMintTLS{result: mint.AlertNoAlert}
|
||||
cs.tls = mockhandshake.NewMockmintTLS(mockCtrl)
|
||||
cs.tls.(*mockhandshake.MockmintTLS).EXPECT().Handshake().Return(mint.AlertNoAlert)
|
||||
cs.keyDerivation = mockKeyDerivation
|
||||
err := cs.HandleCryptoStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
@@ -78,7 +73,8 @@ var _ = Describe("TLS Crypto Setup", func() {
|
||||
|
||||
Context("escalating crypto", func() {
|
||||
doHandshake := func() {
|
||||
cs.tls = &fakeMintTLS{result: mint.AlertNoAlert}
|
||||
cs.tls = mockhandshake.NewMockmintTLS(mockCtrl)
|
||||
cs.tls.(*mockhandshake.MockmintTLS).EXPECT().Handshake().Return(mint.AlertNoAlert)
|
||||
cs.keyDerivation = mockKeyDerivation
|
||||
err := cs.HandleCryptoStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
@@ -44,10 +44,16 @@ func tlsToMintConfig(tlsConf *tls.Config, pers protocol.Perspective) (*mint.Conf
|
||||
}
|
||||
|
||||
type mintTLS interface {
|
||||
crypto.TLSExporter
|
||||
// These two methods are the same as the crypto.TLSExporter interface.
|
||||
// Cannot use embedding here, because mockgen source mode refuses to generate mocks then.
|
||||
GetCipherSuite() mint.CipherSuiteParams
|
||||
ComputeExporter(label string, context []byte, keyLength int) ([]byte, error)
|
||||
// additional methods
|
||||
Handshake() mint.Alert
|
||||
}
|
||||
|
||||
var _ crypto.TLSExporter = (mintTLS)(nil)
|
||||
|
||||
type mintController struct {
|
||||
conn *mint.Conn
|
||||
}
|
||||
@@ -69,11 +75,30 @@ func (mc *mintController) Handshake() mint.Alert {
|
||||
// mint expects a net.Conn, but we're doing the handshake on a stream
|
||||
// so we wrap a stream such that implements a net.Conn
|
||||
type fakeConn struct {
|
||||
io.ReadWriter
|
||||
stream io.ReadWriter
|
||||
pers protocol.Perspective
|
||||
|
||||
blockRead bool
|
||||
}
|
||||
|
||||
var _ net.Conn = &fakeConn{}
|
||||
|
||||
func (c *fakeConn) Read(b []byte) (int, error) {
|
||||
if c.blockRead { // this causes mint.Conn.Handshake() to return a mint.AlertWouldBlock
|
||||
return 0, nil
|
||||
}
|
||||
c.blockRead = true // block the next Read call
|
||||
return c.stream.Read(b)
|
||||
}
|
||||
|
||||
func (c *fakeConn) Write(p []byte) (int, error) {
|
||||
return c.stream.Write(p)
|
||||
}
|
||||
|
||||
func (c *fakeConn) UnblockRead() {
|
||||
c.blockRead = false
|
||||
}
|
||||
|
||||
func (c *fakeConn) Close() error { return nil }
|
||||
func (c *fakeConn) LocalAddr() net.Addr { return nil }
|
||||
func (c *fakeConn) RemoteAddr() net.Addr { return nil }
|
||||
|
||||
44
internal/handshake/mint_utils_test.go
Normal file
44
internal/handshake/mint_utils_test.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Fake Conn", func() {
|
||||
var (
|
||||
c *fakeConn
|
||||
stream *bytes.Buffer
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
stream = &bytes.Buffer{}
|
||||
c = &fakeConn{stream: stream}
|
||||
})
|
||||
|
||||
Context("Reading", func() {
|
||||
It("doesn't return any new data after one Read call", func() {
|
||||
stream.Write([]byte("foobar"))
|
||||
b := make([]byte, 3)
|
||||
_, err := c.Read(b)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(b).To(Equal([]byte("foo")))
|
||||
n, err := c.Read(b)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(n).To(BeZero())
|
||||
})
|
||||
|
||||
It("allows more Read calls after unblocking", func() {
|
||||
stream.Write([]byte("foobar"))
|
||||
b := make([]byte, 3)
|
||||
_, err := c.Read(b)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
c.UnblockRead()
|
||||
_, err = c.Read(b)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(b).To(Equal([]byte("bar")))
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -1,5 +1,6 @@
|
||||
package mocks
|
||||
|
||||
//go:generate sh -c "mockgen -source=../handshake/mint_utils.go -package mockhandshake -destination handshake/mint_tls.go"
|
||||
//go:generate sh -c "./mockgen_internal.sh mocks stream_flow_controller.go github.com/lucas-clemente/quic-go/internal/flowcontrol StreamFlowController"
|
||||
//go:generate sh -c "./mockgen_internal.sh mocks connection_flow_controller.go github.com/lucas-clemente/quic-go/internal/flowcontrol ConnectionFlowController"
|
||||
//go:generate sh -c "./mockgen_internal.sh mockcrypto crypto/aead.go github.com/lucas-clemente/quic-go/internal/crypto AEAD"
|
||||
|
||||
71
internal/mocks/handshake/mint_tls.go
Normal file
71
internal/mocks/handshake/mint_tls.go
Normal file
@@ -0,0 +1,71 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: ../handshake/mint_utils.go
|
||||
|
||||
package mockhandshake
|
||||
|
||||
import (
|
||||
reflect "reflect"
|
||||
|
||||
mint "github.com/bifurcation/mint"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
)
|
||||
|
||||
// MockmintTLS is a mock of mintTLS interface
|
||||
type MockmintTLS struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockmintTLSMockRecorder
|
||||
}
|
||||
|
||||
// MockmintTLSMockRecorder is the mock recorder for MockmintTLS
|
||||
type MockmintTLSMockRecorder struct {
|
||||
mock *MockmintTLS
|
||||
}
|
||||
|
||||
// NewMockmintTLS creates a new mock instance
|
||||
func NewMockmintTLS(ctrl *gomock.Controller) *MockmintTLS {
|
||||
mock := &MockmintTLS{ctrl: ctrl}
|
||||
mock.recorder = &MockmintTLSMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (_m *MockmintTLS) EXPECT() *MockmintTLSMockRecorder {
|
||||
return _m.recorder
|
||||
}
|
||||
|
||||
// GetCipherSuite mocks base method
|
||||
func (_m *MockmintTLS) GetCipherSuite() mint.CipherSuiteParams {
|
||||
ret := _m.ctrl.Call(_m, "GetCipherSuite")
|
||||
ret0, _ := ret[0].(mint.CipherSuiteParams)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetCipherSuite indicates an expected call of GetCipherSuite
|
||||
func (_mr *MockmintTLSMockRecorder) GetCipherSuite() *gomock.Call {
|
||||
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "GetCipherSuite", reflect.TypeOf((*MockmintTLS)(nil).GetCipherSuite))
|
||||
}
|
||||
|
||||
// ComputeExporter mocks base method
|
||||
func (_m *MockmintTLS) ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) {
|
||||
ret := _m.ctrl.Call(_m, "ComputeExporter", label, context, keyLength)
|
||||
ret0, _ := ret[0].([]byte)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// ComputeExporter indicates an expected call of ComputeExporter
|
||||
func (_mr *MockmintTLSMockRecorder) ComputeExporter(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "ComputeExporter", reflect.TypeOf((*MockmintTLS)(nil).ComputeExporter), arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// Handshake mocks base method
|
||||
func (_m *MockmintTLS) Handshake() mint.Alert {
|
||||
ret := _m.ctrl.Call(_m, "Handshake")
|
||||
ret0, _ := ret[0].(mint.Alert)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Handshake indicates an expected call of Handshake
|
||||
func (_mr *MockmintTLSMockRecorder) Handshake() *gomock.Call {
|
||||
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "Handshake", reflect.TypeOf((*MockmintTLS)(nil).Handshake))
|
||||
}
|
||||
Reference in New Issue
Block a user