diff --git a/internal/handshake/crypto_setup_tls.go b/internal/handshake/crypto_setup_tls.go index 398bebb44..d6b15adaa 100644 --- a/internal/handshake/crypto_setup_tls.go +++ b/internal/handshake/crypto_setup_tls.go @@ -21,7 +21,8 @@ type cryptoSetupTLS struct { keyDerivation KeyDerivationFunction - tls mintTLS + tls mintTLS + conn *fakeConn nullAEAD crypto.AEAD aead crypto.AEAD @@ -44,7 +45,8 @@ func NewCryptoSetupTLSServer( if err != nil { return nil, err } - mintConn := mint.Server(&fakeConn{cryptoStream}, mintConf) + conn := &fakeConn{stream: cryptoStream, pers: protocol.PerspectiveServer} + mintConn := mint.Server(conn, mintConf) eh := newExtensionHandlerServer(params, paramsChan, supportedVersions, version) if err := mintConn.SetExtensionHandler(eh); err != nil { return nil, err @@ -58,6 +60,7 @@ func NewCryptoSetupTLSServer( return &cryptoSetupTLS{ perspective: protocol.PerspectiveServer, tls: &mintController{mintConn}, + conn: conn, nullAEAD: nullAEAD, keyDerivation: crypto.DeriveAESKeys, aeadChanged: aeadChanged, @@ -82,7 +85,8 @@ func NewCryptoSetupTLSClient( return nil, err } mintConf.ServerName = hostname - mintConn := mint.Client(&fakeConn{cryptoStream}, mintConf) + conn := &fakeConn{stream: cryptoStream, pers: protocol.PerspectiveClient} + mintConn := mint.Client(conn, mintConf) eh := newExtensionHandlerClient(params, paramsChan, initialVersion, supportedVersions, version) if err := mintConn.SetExtensionHandler(eh); err != nil { return nil, err @@ -94,6 +98,7 @@ func NewCryptoSetupTLSClient( } return &cryptoSetupTLS{ + conn: conn, perspective: protocol.PerspectiveClient, tls: &mintController{mintConn}, nullAEAD: nullAEAD, @@ -103,9 +108,16 @@ func NewCryptoSetupTLSClient( } func (h *cryptoSetupTLS) HandleCryptoStream() error { - - if alert := h.tls.Handshake(); alert != mint.AlertNoAlert { - return fmt.Errorf("TLS handshake error: %s (Alert %d)", alert.String(), alert) +handshakeLoop: + for { + switch alert := h.tls.Handshake(); alert { + case mint.AlertNoAlert: // handshake complete + break handshakeLoop + case mint.AlertWouldBlock: + h.conn.UnblockRead() + default: + return fmt.Errorf("TLS handshake error: %s (Alert %d)", alert.String(), alert) + } } aead, err := h.keyDerivation(h.tls, h.perspective) diff --git a/internal/handshake/crypto_setup_tls_test.go b/internal/handshake/crypto_setup_tls_test.go index d950c11a6..92f4c8c02 100644 --- a/internal/handshake/crypto_setup_tls_test.go +++ b/internal/handshake/crypto_setup_tls_test.go @@ -7,6 +7,7 @@ import ( "github.com/bifurcation/mint" "github.com/lucas-clemente/quic-go/internal/crypto" "github.com/lucas-clemente/quic-go/internal/mocks/crypto" + "github.com/lucas-clemente/quic-go/internal/mocks/handshake" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/testdata" @@ -14,23 +15,6 @@ import ( . "github.com/onsi/gomega" ) -type fakeMintTLS struct { - result mint.Alert -} - -var _ mintTLS = &fakeMintTLS{} - -func (h *fakeMintTLS) Handshake() mint.Alert { - return h.result -} -func (h *fakeMintTLS) GetCipherSuite() mint.CipherSuiteParams { panic("not implemented") } -func (h *fakeMintTLS) ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) { - panic("not implemented") -} -func (h *fakeMintTLS) SetExtensionHandler(mint.AppExtensionHandler) error { - panic("not implemented") -} - func mockKeyDerivation(crypto.TLSExporter, protocol.Perspective) (crypto.AEAD, error) { return mockcrypto.NewMockAEAD(mockCtrl), nil } @@ -62,13 +46,24 @@ var _ = Describe("TLS Crypto Setup", func() { It("errors when the handshake fails", func() { alert := mint.AlertBadRecordMAC - cs.tls = &fakeMintTLS{result: alert} + cs.tls = mockhandshake.NewMockmintTLS(mockCtrl) + cs.tls.(*mockhandshake.MockmintTLS).EXPECT().Handshake().Return(alert) err := cs.HandleCryptoStream() Expect(err).To(MatchError(fmt.Errorf("TLS handshake error: %s (Alert %d)", alert.String(), alert))) }) + It("continues shaking hands when mint says that it would block", func() { + cs.tls = mockhandshake.NewMockmintTLS(mockCtrl) + cs.tls.(*mockhandshake.MockmintTLS).EXPECT().Handshake().Return(mint.AlertWouldBlock) + cs.tls.(*mockhandshake.MockmintTLS).EXPECT().Handshake().Return(mint.AlertNoAlert) + cs.keyDerivation = mockKeyDerivation + err := cs.HandleCryptoStream() + Expect(err).ToNot(HaveOccurred()) + }) + It("derives keys", func() { - cs.tls = &fakeMintTLS{result: mint.AlertNoAlert} + cs.tls = mockhandshake.NewMockmintTLS(mockCtrl) + cs.tls.(*mockhandshake.MockmintTLS).EXPECT().Handshake().Return(mint.AlertNoAlert) cs.keyDerivation = mockKeyDerivation err := cs.HandleCryptoStream() Expect(err).ToNot(HaveOccurred()) @@ -78,7 +73,8 @@ var _ = Describe("TLS Crypto Setup", func() { Context("escalating crypto", func() { doHandshake := func() { - cs.tls = &fakeMintTLS{result: mint.AlertNoAlert} + cs.tls = mockhandshake.NewMockmintTLS(mockCtrl) + cs.tls.(*mockhandshake.MockmintTLS).EXPECT().Handshake().Return(mint.AlertNoAlert) cs.keyDerivation = mockKeyDerivation err := cs.HandleCryptoStream() Expect(err).ToNot(HaveOccurred()) diff --git a/internal/handshake/mint_utils.go b/internal/handshake/mint_utils.go index 63e8d8c4d..3792cfc44 100644 --- a/internal/handshake/mint_utils.go +++ b/internal/handshake/mint_utils.go @@ -44,10 +44,16 @@ func tlsToMintConfig(tlsConf *tls.Config, pers protocol.Perspective) (*mint.Conf } type mintTLS interface { - crypto.TLSExporter + // These two methods are the same as the crypto.TLSExporter interface. + // Cannot use embedding here, because mockgen source mode refuses to generate mocks then. + GetCipherSuite() mint.CipherSuiteParams + ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) + // additional methods Handshake() mint.Alert } +var _ crypto.TLSExporter = (mintTLS)(nil) + type mintController struct { conn *mint.Conn } @@ -69,11 +75,30 @@ func (mc *mintController) Handshake() mint.Alert { // mint expects a net.Conn, but we're doing the handshake on a stream // so we wrap a stream such that implements a net.Conn type fakeConn struct { - io.ReadWriter + stream io.ReadWriter + pers protocol.Perspective + + blockRead bool } var _ net.Conn = &fakeConn{} +func (c *fakeConn) Read(b []byte) (int, error) { + if c.blockRead { // this causes mint.Conn.Handshake() to return a mint.AlertWouldBlock + return 0, nil + } + c.blockRead = true // block the next Read call + return c.stream.Read(b) +} + +func (c *fakeConn) Write(p []byte) (int, error) { + return c.stream.Write(p) +} + +func (c *fakeConn) UnblockRead() { + c.blockRead = false +} + func (c *fakeConn) Close() error { return nil } func (c *fakeConn) LocalAddr() net.Addr { return nil } func (c *fakeConn) RemoteAddr() net.Addr { return nil } diff --git a/internal/handshake/mint_utils_test.go b/internal/handshake/mint_utils_test.go new file mode 100644 index 000000000..92525b83b --- /dev/null +++ b/internal/handshake/mint_utils_test.go @@ -0,0 +1,44 @@ +package handshake + +import ( + "bytes" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Fake Conn", func() { + var ( + c *fakeConn + stream *bytes.Buffer + ) + + BeforeEach(func() { + stream = &bytes.Buffer{} + c = &fakeConn{stream: stream} + }) + + Context("Reading", func() { + It("doesn't return any new data after one Read call", func() { + stream.Write([]byte("foobar")) + b := make([]byte, 3) + _, err := c.Read(b) + Expect(err).ToNot(HaveOccurred()) + Expect(b).To(Equal([]byte("foo"))) + n, err := c.Read(b) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(BeZero()) + }) + + It("allows more Read calls after unblocking", func() { + stream.Write([]byte("foobar")) + b := make([]byte, 3) + _, err := c.Read(b) + Expect(err).ToNot(HaveOccurred()) + c.UnblockRead() + _, err = c.Read(b) + Expect(err).ToNot(HaveOccurred()) + Expect(b).To(Equal([]byte("bar"))) + }) + }) +}) diff --git a/internal/mocks/gen.go b/internal/mocks/gen.go index 098c1666e..e44c04d08 100644 --- a/internal/mocks/gen.go +++ b/internal/mocks/gen.go @@ -1,5 +1,6 @@ package mocks +//go:generate sh -c "mockgen -source=../handshake/mint_utils.go -package mockhandshake -destination handshake/mint_tls.go" //go:generate sh -c "./mockgen_internal.sh mocks stream_flow_controller.go github.com/lucas-clemente/quic-go/internal/flowcontrol StreamFlowController" //go:generate sh -c "./mockgen_internal.sh mocks connection_flow_controller.go github.com/lucas-clemente/quic-go/internal/flowcontrol ConnectionFlowController" //go:generate sh -c "./mockgen_internal.sh mockcrypto crypto/aead.go github.com/lucas-clemente/quic-go/internal/crypto AEAD" diff --git a/internal/mocks/handshake/mint_tls.go b/internal/mocks/handshake/mint_tls.go new file mode 100644 index 000000000..086979262 --- /dev/null +++ b/internal/mocks/handshake/mint_tls.go @@ -0,0 +1,71 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: ../handshake/mint_utils.go + +package mockhandshake + +import ( + reflect "reflect" + + mint "github.com/bifurcation/mint" + gomock "github.com/golang/mock/gomock" +) + +// MockmintTLS is a mock of mintTLS interface +type MockmintTLS struct { + ctrl *gomock.Controller + recorder *MockmintTLSMockRecorder +} + +// MockmintTLSMockRecorder is the mock recorder for MockmintTLS +type MockmintTLSMockRecorder struct { + mock *MockmintTLS +} + +// NewMockmintTLS creates a new mock instance +func NewMockmintTLS(ctrl *gomock.Controller) *MockmintTLS { + mock := &MockmintTLS{ctrl: ctrl} + mock.recorder = &MockmintTLSMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (_m *MockmintTLS) EXPECT() *MockmintTLSMockRecorder { + return _m.recorder +} + +// GetCipherSuite mocks base method +func (_m *MockmintTLS) GetCipherSuite() mint.CipherSuiteParams { + ret := _m.ctrl.Call(_m, "GetCipherSuite") + ret0, _ := ret[0].(mint.CipherSuiteParams) + return ret0 +} + +// GetCipherSuite indicates an expected call of GetCipherSuite +func (_mr *MockmintTLSMockRecorder) GetCipherSuite() *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "GetCipherSuite", reflect.TypeOf((*MockmintTLS)(nil).GetCipherSuite)) +} + +// ComputeExporter mocks base method +func (_m *MockmintTLS) ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) { + ret := _m.ctrl.Call(_m, "ComputeExporter", label, context, keyLength) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ComputeExporter indicates an expected call of ComputeExporter +func (_mr *MockmintTLSMockRecorder) ComputeExporter(arg0, arg1, arg2 interface{}) *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "ComputeExporter", reflect.TypeOf((*MockmintTLS)(nil).ComputeExporter), arg0, arg1, arg2) +} + +// Handshake mocks base method +func (_m *MockmintTLS) Handshake() mint.Alert { + ret := _m.ctrl.Call(_m, "Handshake") + ret0, _ := ret[0].(mint.Alert) + return ret0 +} + +// Handshake indicates an expected call of Handshake +func (_mr *MockmintTLSMockRecorder) Handshake() *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "Handshake", reflect.TypeOf((*MockmintTLS)(nil).Handshake)) +}