implement the new Retry mechanism

This commit is contained in:
Marten Seemann
2018-08-12 10:41:56 +07:00
parent d8aa49d0f9
commit f88b7bb175
23 changed files with 401 additions and 1082 deletions

View File

@@ -1,13 +1,13 @@
package handshake
import (
"bytes"
"errors"
"fmt"
"github.com/bifurcation/mint"
"github.com/lucas-clemente/quic-go/internal/crypto"
"github.com/lucas-clemente/quic-go/internal/mocks/crypto"
"github.com/lucas-clemente/quic-go/internal/mocks/handshake"
"github.com/lucas-clemente/quic-go/internal/protocol"
. "github.com/onsi/ginkgo"
@@ -26,28 +26,30 @@ var _ = Describe("TLS Crypto Setup", func() {
BeforeEach(func() {
handshakeEvent = make(chan struct{}, 2)
cs = NewCryptoSetupTLSServer(
nil,
NewCryptoStreamConn(nil),
nil, // AEAD
css, err := NewCryptoSetupTLSServer(
newCryptoStreamConn(bytes.NewBuffer([]byte{})),
protocol.ConnectionID{},
&mint.Config{},
handshakeEvent,
protocol.VersionTLS,
).(*cryptoSetupTLS)
)
Expect(err).ToNot(HaveOccurred())
cs = css.(*cryptoSetupTLS)
cs.nullAEAD = mockcrypto.NewMockAEAD(mockCtrl)
})
It("errors when the handshake fails", func() {
alert := mint.AlertBadRecordMAC
cs.tls = mockhandshake.NewMockMintTLS(mockCtrl)
cs.tls.(*mockhandshake.MockMintTLS).EXPECT().Handshake().Return(alert)
cs.tls = NewMockMintTLS(mockCtrl)
cs.tls.(*MockMintTLS).EXPECT().Handshake().Return(alert)
err := cs.HandleCryptoStream()
Expect(err).To(MatchError(fmt.Errorf("TLS handshake error: %s (Alert %d)", alert.String(), alert)))
})
It("derives keys", func() {
cs.tls = mockhandshake.NewMockMintTLS(mockCtrl)
cs.tls.(*mockhandshake.MockMintTLS).EXPECT().Handshake().Return(mint.AlertNoAlert)
cs.tls.(*mockhandshake.MockMintTLS).EXPECT().State().Return(mint.StateServerConnected)
cs.tls = NewMockMintTLS(mockCtrl)
cs.tls.(*MockMintTLS).EXPECT().Handshake().Return(mint.AlertNoAlert)
cs.tls.(*MockMintTLS).EXPECT().ConnectionState().Return(mint.ConnectionState{HandshakeState: mint.StateServerConnected})
cs.keyDerivation = mockKeyDerivation
err := cs.HandleCryptoStream()
Expect(err).ToNot(HaveOccurred())
@@ -56,10 +58,10 @@ var _ = Describe("TLS Crypto Setup", func() {
})
It("handshakes until it is connected", func() {
cs.tls = mockhandshake.NewMockMintTLS(mockCtrl)
cs.tls.(*mockhandshake.MockMintTLS).EXPECT().Handshake().Return(mint.AlertNoAlert).Times(10)
cs.tls.(*mockhandshake.MockMintTLS).EXPECT().State().Return(mint.StateServerNegotiated).Times(9)
cs.tls.(*mockhandshake.MockMintTLS).EXPECT().State().Return(mint.StateServerConnected)
cs.tls = NewMockMintTLS(mockCtrl)
cs.tls.(*MockMintTLS).EXPECT().Handshake().Return(mint.AlertNoAlert).Times(10)
cs.tls.(*MockMintTLS).EXPECT().ConnectionState().Return(mint.ConnectionState{HandshakeState: mint.StateServerNegotiated}).Times(9)
cs.tls.(*MockMintTLS).EXPECT().ConnectionState().Return(mint.ConnectionState{HandshakeState: mint.StateServerConnected})
cs.keyDerivation = mockKeyDerivation
err := cs.HandleCryptoStream()
Expect(err).ToNot(HaveOccurred())
@@ -68,18 +70,17 @@ var _ = Describe("TLS Crypto Setup", func() {
Context("reporting the handshake state", func() {
It("reports before the handshake compeletes", func() {
cs.tls = mockhandshake.NewMockMintTLS(mockCtrl)
cs.tls.(*mockhandshake.MockMintTLS).EXPECT().ConnectionState().Return(mint.ConnectionState{})
cs.tls = NewMockMintTLS(mockCtrl)
cs.tls.(*MockMintTLS).EXPECT().ConnectionState().Return(mint.ConnectionState{})
state := cs.ConnectionState()
Expect(state.HandshakeComplete).To(BeFalse())
Expect(state.PeerCertificates).To(BeNil())
})
It("reports after the handshake completes", func() {
cs.tls = mockhandshake.NewMockMintTLS(mockCtrl)
cs.tls.(*mockhandshake.MockMintTLS).EXPECT().ConnectionState().Return(mint.ConnectionState{})
cs.tls.(*mockhandshake.MockMintTLS).EXPECT().Handshake().Return(mint.AlertNoAlert)
cs.tls.(*mockhandshake.MockMintTLS).EXPECT().State().Return(mint.StateServerConnected)
cs.tls = NewMockMintTLS(mockCtrl)
cs.tls.(*MockMintTLS).EXPECT().Handshake().Return(mint.AlertNoAlert)
cs.tls.(*MockMintTLS).EXPECT().ConnectionState().Return(mint.ConnectionState{HandshakeState: mint.StateServerConnected}).Times(2)
cs.keyDerivation = mockKeyDerivation
err := cs.HandleCryptoStream()
Expect(err).ToNot(HaveOccurred())
@@ -91,9 +92,9 @@ var _ = Describe("TLS Crypto Setup", func() {
Context("escalating crypto", func() {
doHandshake := func() {
cs.tls = mockhandshake.NewMockMintTLS(mockCtrl)
cs.tls.(*mockhandshake.MockMintTLS).EXPECT().Handshake().Return(mint.AlertNoAlert)
cs.tls.(*mockhandshake.MockMintTLS).EXPECT().State().Return(mint.StateServerConnected)
cs.tls = NewMockMintTLS(mockCtrl)
cs.tls.(*MockMintTLS).EXPECT().Handshake().Return(mint.AlertNoAlert)
cs.tls.(*MockMintTLS).EXPECT().ConnectionState().Return(mint.ConnectionState{HandshakeState: mint.StateServerConnected})
cs.keyDerivation = mockKeyDerivation
err := cs.HandleCryptoStream()
Expect(err).ToNot(HaveOccurred())
@@ -189,33 +190,3 @@ var _ = Describe("TLS Crypto Setup", func() {
})
})
})
var _ = Describe("TLS Crypto Setup, for the client", func() {
var (
cs *cryptoSetupTLS
handshakeEvent chan struct{}
)
BeforeEach(func() {
handshakeEvent = make(chan struct{})
csInt, err := NewCryptoSetupTLSClient(
nil,
protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8},
"quic.clemente.io",
handshakeEvent,
nil, // mintTLS
protocol.VersionTLS,
)
Expect(err).ToNot(HaveOccurred())
cs = csInt.(*cryptoSetupTLS)
})
It("returns when a retry is performed", func() {
cs.tls = mockhandshake.NewMockMintTLS(mockCtrl)
cs.tls.(*mockhandshake.MockMintTLS).EXPECT().Handshake().Return(mint.AlertNoAlert)
cs.tls.(*mockhandshake.MockMintTLS).EXPECT().State().Return(mint.StateClientStart)
err := cs.HandleCryptoStream()
Expect(err).To(MatchError(ErrCloseSessionForRetry))
})
})