From 8cde4ab638644164dcd0abde06cecfa6715f9068 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sun, 2 Feb 2020 13:45:52 +0700 Subject: [PATCH] refactor how session tickets are sent Previously, RunHandshake() would send the session tickets. Now, the session initiates the sending of the session ticket. This simplifies the setup a bit, and it will make it possible to include the RTT estimate in the session ticket without accessing the RTTStats concurrently. --- crypto_stream.go | 29 ------------- crypto_stream_test.go | 43 ------------------ internal/handshake/crypto_setup.go | 30 ++----------- internal/handshake/crypto_setup_test.go | 58 ++++++++++--------------- internal/handshake/interface.go | 1 + internal/mocks/crypto_setup.go | 15 +++++++ internal/protocol/params.go | 2 +- session.go | 21 ++++++--- session_test.go | 44 +++++++++++++++++++ 9 files changed, 103 insertions(+), 140 deletions(-) diff --git a/crypto_stream.go b/crypto_stream.go index 557c1d3d6..683ab0b1b 100644 --- a/crypto_stream.go +++ b/crypto_stream.go @@ -22,35 +22,6 @@ type cryptoStream interface { PopCryptoFrame(protocol.ByteCount) *wire.CryptoFrame } -type postHandshakeCryptoStream struct { - cryptoStream - - framer framer -} - -func newPostHandshakeCryptoStream(framer framer) cryptoStream { - return &postHandshakeCryptoStream{ - cryptoStream: newCryptoStream(), - framer: framer, - } -} - -// Write writes post-handshake messages. -// For simplicity, post-handshake crypto messages are treated as control frames. -// The framer functions as a stack (LIFO), so if there are multiple writes, -// they will be returned in the opposite order. -// This is acceptable, since post-handshake crypto messages are very rare. -func (s *postHandshakeCryptoStream) Write(p []byte) (int, error) { - n, err := s.cryptoStream.Write(p) - if err != nil { - return n, err - } - for s.cryptoStream.HasData() { - s.framer.QueueControlFrame(s.PopCryptoFrame(protocol.MaxPostHandshakeCryptoFrameSize)) - } - return n, nil -} - type cryptoStreamImpl struct { queue *frameSorter msgBuf []byte diff --git a/crypto_stream_test.go b/crypto_stream_test.go index 5e1b7e8b6..c962587b6 100644 --- a/crypto_stream_test.go +++ b/crypto_stream_test.go @@ -180,46 +180,3 @@ var _ = Describe("Crypto Stream", func() { }) }) }) - -var _ = Describe("Post Handshake Crypto Stream", func() { - var ( - cs cryptoStream - framer framer - ) - - BeforeEach(func() { - framer = newFramer(NewMockStreamGetter(mockCtrl), protocol.VersionTLS) - cs = newPostHandshakeCryptoStream(framer) - }) - - It("queues CRYPTO frames when writing data", func() { - n, err := cs.Write([]byte("foo")) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(3)) - n, err = cs.Write([]byte("bar")) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(3)) - frames, _ := framer.AppendControlFrames(nil, 1000) - Expect(frames).To(HaveLen(2)) - fs := []wire.Frame{frames[0].Frame, frames[1].Frame} - Expect(fs).To(ContainElement(&wire.CryptoFrame{Data: []byte("foo")})) - Expect(fs).To(ContainElement(&wire.CryptoFrame{Data: []byte("bar"), Offset: 3})) - }) - - It("splits large writes into multiple frames", func() { - size := 10 * protocol.MaxPostHandshakeCryptoFrameSize - n, err := cs.Write(make([]byte, size)) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(BeEquivalentTo(size)) - frames, _ := framer.AppendControlFrames(nil, protocol.MaxByteCount) - Expect(frames).To(HaveLen(11)) // one more for framing overhead - var dataLen int - for _, f := range frames { - Expect(f.Frame.Length(protocol.VersionTLS)).To(BeNumerically("<=", protocol.MaxPostHandshakeCryptoFrameSize)) - Expect(f.Frame).To(BeAssignableToTypeOf(&wire.CryptoFrame{})) - dataLen += len(f.Frame.(*wire.CryptoFrame).Data) - } - Expect(dataLen).To(BeEquivalentTo(size)) - }) - -}) diff --git a/internal/handshake/crypto_setup.go b/internal/handshake/crypto_setup.go index 80bfb6dd5..eb235729e 100644 --- a/internal/handshake/crypto_setup.go +++ b/internal/handshake/crypto_setup.go @@ -16,12 +16,8 @@ import ( "github.com/marten-seemann/qtls" ) -const ( - // TLS unexpected_message alert - alertUnexpectedMessage uint8 = 10 - // TLS internal error - alertInternalError uint8 = 80 -) +// TLS unexpected_message alert +const alertUnexpectedMessage uint8 = 10 type messageType uint8 @@ -111,7 +107,6 @@ type cryptoSetup struct { handshakeOpener LongHeaderOpener handshakeSealer LongHeaderSealer - oneRTTStream io.Writer aead *updatableAEAD has1RTTSealer bool has1RTTOpener bool @@ -124,7 +119,6 @@ var _ CryptoSetup = &cryptoSetup{} func NewCryptoSetupClient( initialStream io.Writer, handshakeStream io.Writer, - oneRTTStream io.Writer, connID protocol.ConnectionID, remoteAddr net.Addr, tp *TransportParameters, @@ -137,7 +131,6 @@ func NewCryptoSetupClient( cs, clientHelloWritten := newCryptoSetup( initialStream, handshakeStream, - oneRTTStream, connID, tp, runner, @@ -155,7 +148,6 @@ func NewCryptoSetupClient( func NewCryptoSetupServer( initialStream io.Writer, handshakeStream io.Writer, - oneRTTStream io.Writer, connID protocol.ConnectionID, remoteAddr net.Addr, tp *TransportParameters, @@ -168,7 +160,6 @@ func NewCryptoSetupServer( cs, _ := newCryptoSetup( initialStream, handshakeStream, - oneRTTStream, connID, tp, runner, @@ -185,7 +176,6 @@ func NewCryptoSetupServer( func newCryptoSetup( initialStream io.Writer, handshakeStream io.Writer, - oneRTTStream io.Writer, connID protocol.ConnectionID, tp *TransportParameters, runner handshakeRunner, @@ -202,7 +192,6 @@ func newCryptoSetup( initialSealer: initialSealer, initialOpener: initialOpener, handshakeStream: handshakeStream, - oneRTTStream: oneRTTStream, aead: newUpdatableAEAD(rttStats, logger), readEncLevel: protocol.EncryptionInitial, writeEncLevel: protocol.EncryptionInitial, @@ -251,10 +240,6 @@ func (h *cryptoSetup) RunHandshake() { select { case <-handshakeComplete: // return when the handshake is done h.runner.OnHandshakeComplete() - // send a session ticket - if h.perspective == protocol.PerspectiveServer { - h.maybeSendSessionTicket() - } case <-h.closeChan: close(h.messageChan) // wait until the Handshake() go routine has returned @@ -475,20 +460,13 @@ func (h *cryptoSetup) handlePeerParamsFromSessionStateImpl(data []byte) (*Transp } // only valid for the server -func (h *cryptoSetup) maybeSendSessionTicket() { +func (h *cryptoSetup) GetSessionTicket() ([]byte, error) { var appData []byte // Save transport parameters to the session ticket if we're allowing 0-RTT. if h.tlsConf.MaxEarlyData > 0 { appData = (&sessionTicket{Parameters: h.ourParams}).Marshal() } - ticket, err := h.conn.GetSessionTicket(appData) - if err != nil { - h.onError(alertInternalError, err.Error()) - return - } - if ticket != nil { - h.oneRTTStream.Write(ticket) - } + return h.conn.GetSessionTicket(appData) } // accept0RTT is called for the server when receiving the client's session ticket. diff --git a/internal/handshake/crypto_setup_test.go b/internal/handshake/crypto_setup_test.go index 954b9fc59..ee585adde 100644 --- a/internal/handshake/crypto_setup_test.go +++ b/internal/handshake/crypto_setup_test.go @@ -8,7 +8,6 @@ import ( "crypto/x509" "crypto/x509/pkix" "errors" - "io/ioutil" "math/big" "time" @@ -55,12 +54,11 @@ func (s *stream) Write(b []byte) (int, error) { var _ = Describe("Crypto Setup TLS", func() { var clientConf, serverConf *tls.Config - initStreams := func() (chan chunk, *stream /* initial */, *stream /* handshake */, *stream /* 1-RTT */) { + initStreams := func() (chan chunk, *stream /* initial */, *stream /* handshake */) { chunkChan := make(chan chunk, 100) initialStream := newStream(chunkChan, protocol.EncryptionInitial) handshakeStream := newStream(chunkChan, protocol.EncryptionHandshake) - oneRTTStream := newStream(chunkChan, protocol.Encryption1RTT) - return chunkChan, initialStream, handshakeStream, oneRTTStream + return chunkChan, initialStream, handshakeStream } BeforeEach(func() { @@ -89,7 +87,6 @@ var _ = Describe("Crypto Setup TLS", func() { server := NewCryptoSetupServer( &bytes.Buffer{}, &bytes.Buffer{}, - ioutil.Discard, protocol.ConnectionID{}, nil, &TransportParameters{}, @@ -117,11 +114,10 @@ var _ = Describe("Crypto Setup TLS", func() { sErrChan := make(chan error, 1) runner := NewMockHandshakeRunner(mockCtrl) runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e }) - _, sInitialStream, sHandshakeStream, _ := initStreams() + _, sInitialStream, sHandshakeStream := initStreams() server := NewCryptoSetupServer( sInitialStream, sHandshakeStream, - ioutil.Discard, protocol.ConnectionID{}, nil, &TransportParameters{}, @@ -153,13 +149,12 @@ var _ = Describe("Crypto Setup TLS", func() { It("errors when a message is received at the wrong encryption level", func() { sErrChan := make(chan error, 1) - _, sInitialStream, sHandshakeStream, _ := initStreams() + _, sInitialStream, sHandshakeStream := initStreams() runner := NewMockHandshakeRunner(mockCtrl) runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e }) server := NewCryptoSetupServer( sInitialStream, sHandshakeStream, - ioutil.Discard, protocol.ConnectionID{}, nil, &TransportParameters{}, @@ -194,13 +189,12 @@ var _ = Describe("Crypto Setup TLS", func() { It("returns Handshake() when handling a message fails", func() { sErrChan := make(chan error, 1) - _, sInitialStream, sHandshakeStream, _ := initStreams() + _, sInitialStream, sHandshakeStream := initStreams() runner := NewMockHandshakeRunner(mockCtrl) runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e }) server := NewCryptoSetupServer( sInitialStream, sHandshakeStream, - ioutil.Discard, protocol.ConnectionID{}, nil, &TransportParameters{}, @@ -230,11 +224,10 @@ var _ = Describe("Crypto Setup TLS", func() { }) It("returns Handshake() when it is closed", func() { - _, sInitialStream, sHandshakeStream, _ := initStreams() + _, sInitialStream, sHandshakeStream := initStreams() server := NewCryptoSetupServer( sInitialStream, sHandshakeStream, - ioutil.Discard, protocol.ConnectionID{}, nil, &TransportParameters{}, @@ -305,6 +298,11 @@ var _ = Describe("Crypto Setup TLS", func() { go func() { defer GinkgoRecover() server.RunHandshake() + ticket, err := server.GetSessionTicket() + Expect(err).ToNot(HaveOccurred()) + if ticket != nil { + client.HandleMessage(ticket, protocol.Encryption1RTT) + } close(done) }() @@ -314,7 +312,7 @@ var _ = Describe("Crypto Setup TLS", func() { handshakeWithTLSConf := func(clientConf, serverConf *tls.Config, enable0RTT bool) (CryptoSetup /* client */, error /* client error */, CryptoSetup /* server */, error /* server error */) { var cHandshakeComplete bool - cChunkChan, cInitialStream, cHandshakeStream, cOneRTTStream := initStreams() + cChunkChan, cInitialStream, cHandshakeStream := initStreams() cErrChan := make(chan error, 1) cRunner := NewMockHandshakeRunner(mockCtrl) cRunner.EXPECT().OnReceivedParams(gomock.Any()) @@ -323,7 +321,6 @@ var _ = Describe("Crypto Setup TLS", func() { client, _ := NewCryptoSetupClient( cInitialStream, cHandshakeStream, - cOneRTTStream, protocol.ConnectionID{}, nil, &TransportParameters{}, @@ -335,7 +332,7 @@ var _ = Describe("Crypto Setup TLS", func() { ) var sHandshakeComplete bool - sChunkChan, sInitialStream, sHandshakeStream, sOneRTTStream := initStreams() + sChunkChan, sInitialStream, sHandshakeStream := initStreams() sErrChan := make(chan error, 1) sRunner := NewMockHandshakeRunner(mockCtrl) sRunner.EXPECT().OnReceivedParams(gomock.Any()) @@ -345,7 +342,6 @@ var _ = Describe("Crypto Setup TLS", func() { server := NewCryptoSetupServer( sInitialStream, sHandshakeStream, - sOneRTTStream, protocol.ConnectionID{}, nil, &TransportParameters{StatelessResetToken: &token}, @@ -394,11 +390,10 @@ var _ = Describe("Crypto Setup TLS", func() { It("signals when it has written the ClientHello", func() { runner := NewMockHandshakeRunner(mockCtrl) - cChunkChan, cInitialStream, cHandshakeStream, _ := initStreams() + cChunkChan, cInitialStream, cHandshakeStream := initStreams() client, chChan := NewCryptoSetupClient( cInitialStream, cHandshakeStream, - ioutil.Discard, protocol.ConnectionID{}, nil, &TransportParameters{}, @@ -431,7 +426,7 @@ var _ = Describe("Crypto Setup TLS", func() { It("receives transport parameters", func() { var cTransportParametersRcvd, sTransportParametersRcvd *TransportParameters - cChunkChan, cInitialStream, cHandshakeStream, _ := initStreams() + cChunkChan, cInitialStream, cHandshakeStream := initStreams() cTransportParameters := &TransportParameters{MaxIdleTimeout: 0x42 * time.Second} cRunner := NewMockHandshakeRunner(mockCtrl) cRunner.EXPECT().OnReceivedParams(gomock.Any()).Do(func(tp *TransportParameters) { sTransportParametersRcvd = tp }) @@ -439,7 +434,6 @@ var _ = Describe("Crypto Setup TLS", func() { client, _ := NewCryptoSetupClient( cInitialStream, cHandshakeStream, - ioutil.Discard, protocol.ConnectionID{}, nil, cTransportParameters, @@ -450,7 +444,7 @@ var _ = Describe("Crypto Setup TLS", func() { utils.DefaultLogger.WithPrefix("client"), ) - sChunkChan, sInitialStream, sHandshakeStream, _ := initStreams() + sChunkChan, sInitialStream, sHandshakeStream := initStreams() var token [16]byte sRunner := NewMockHandshakeRunner(mockCtrl) sRunner.EXPECT().OnReceivedParams(gomock.Any()).Do(func(tp *TransportParameters) { cTransportParametersRcvd = tp }) @@ -462,7 +456,6 @@ var _ = Describe("Crypto Setup TLS", func() { server := NewCryptoSetupServer( sInitialStream, sHandshakeStream, - ioutil.Discard, protocol.ConnectionID{}, nil, sTransportParameters, @@ -487,14 +480,13 @@ var _ = Describe("Crypto Setup TLS", func() { Context("with session tickets", func() { It("errors when the NewSessionTicket is sent at the wrong encryption level", func() { - cChunkChan, cInitialStream, cHandshakeStream, _ := initStreams() + cChunkChan, cInitialStream, cHandshakeStream := initStreams() cRunner := NewMockHandshakeRunner(mockCtrl) cRunner.EXPECT().OnReceivedParams(gomock.Any()) cRunner.EXPECT().OnHandshakeComplete() client, _ := NewCryptoSetupClient( cInitialStream, cHandshakeStream, - ioutil.Discard, protocol.ConnectionID{}, nil, &TransportParameters{}, @@ -505,14 +497,13 @@ var _ = Describe("Crypto Setup TLS", func() { utils.DefaultLogger.WithPrefix("client"), ) - sChunkChan, sInitialStream, sHandshakeStream, _ := initStreams() + sChunkChan, sInitialStream, sHandshakeStream := initStreams() sRunner := NewMockHandshakeRunner(mockCtrl) sRunner.EXPECT().OnReceivedParams(gomock.Any()) sRunner.EXPECT().OnHandshakeComplete() server := NewCryptoSetupServer( sInitialStream, sHandshakeStream, - ioutil.Discard, protocol.ConnectionID{}, nil, &TransportParameters{}, @@ -544,14 +535,13 @@ var _ = Describe("Crypto Setup TLS", func() { }) It("errors when handling the NewSessionTicket fails", func() { - cChunkChan, cInitialStream, cHandshakeStream, _ := initStreams() + cChunkChan, cInitialStream, cHandshakeStream := initStreams() cRunner := NewMockHandshakeRunner(mockCtrl) cRunner.EXPECT().OnReceivedParams(gomock.Any()) cRunner.EXPECT().OnHandshakeComplete() client, _ := NewCryptoSetupClient( cInitialStream, cHandshakeStream, - ioutil.Discard, protocol.ConnectionID{}, nil, &TransportParameters{}, @@ -562,14 +552,13 @@ var _ = Describe("Crypto Setup TLS", func() { utils.DefaultLogger.WithPrefix("client"), ) - sChunkChan, sInitialStream, sHandshakeStream, _ := initStreams() + sChunkChan, sInitialStream, sHandshakeStream := initStreams() sRunner := NewMockHandshakeRunner(mockCtrl) sRunner.EXPECT().OnReceivedParams(gomock.Any()) sRunner.EXPECT().OnHandshakeComplete() server := NewCryptoSetupServer( sInitialStream, sHandshakeStream, - ioutil.Discard, protocol.ConnectionID{}, nil, &TransportParameters{}, @@ -670,16 +659,16 @@ var _ = Describe("Crypto Setup TLS", func() { Expect(client.ConnectionState().DidResume).To(BeFalse()) csc.EXPECT().Get(gomock.Any()).Return(state, true) + csc.EXPECT().Put(gomock.Any(), nil) csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1) - cChunkChan, cInitialStream, cHandshakeStream, _ := initStreams() + cChunkChan, cInitialStream, cHandshakeStream := initStreams() cRunner := NewMockHandshakeRunner(mockCtrl) cRunner.EXPECT().OnReceivedParams(gomock.Any()) cRunner.EXPECT().OnHandshakeComplete() client, clientHelloChan := NewCryptoSetupClient( cInitialStream, cHandshakeStream, - ioutil.Discard, protocol.ConnectionID{}, nil, &TransportParameters{}, @@ -690,14 +679,13 @@ var _ = Describe("Crypto Setup TLS", func() { utils.DefaultLogger.WithPrefix("client"), ) - sChunkChan, sInitialStream, sHandshakeStream, _ := initStreams() + sChunkChan, sInitialStream, sHandshakeStream := initStreams() sRunner := NewMockHandshakeRunner(mockCtrl) sRunner.EXPECT().OnReceivedParams(gomock.Any()) sRunner.EXPECT().OnHandshakeComplete() server = NewCryptoSetupServer( sInitialStream, sHandshakeStream, - ioutil.Discard, protocol.ConnectionID{}, nil, &TransportParameters{}, diff --git a/internal/handshake/interface.go b/internal/handshake/interface.go index 7183a11ac..07fb64c7c 100644 --- a/internal/handshake/interface.go +++ b/internal/handshake/interface.go @@ -72,6 +72,7 @@ type CryptoSetup interface { RunHandshake() io.Closer ChangeConnectionID(protocol.ConnectionID) + GetSessionTicket() ([]byte, error) HandleMessage([]byte, protocol.EncryptionLevel) bool SetLargest1RTTAcked(protocol.PacketNumber) diff --git a/internal/mocks/crypto_setup.go b/internal/mocks/crypto_setup.go index 382c1dc0a..ed03c433a 100644 --- a/internal/mocks/crypto_setup.go +++ b/internal/mocks/crypto_setup.go @@ -208,6 +208,21 @@ func (mr *MockCryptoSetupMockRecorder) GetInitialSealer() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInitialSealer", reflect.TypeOf((*MockCryptoSetup)(nil).GetInitialSealer)) } +// GetSessionTicket mocks base method +func (m *MockCryptoSetup) GetSessionTicket() ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetSessionTicket") + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetSessionTicket indicates an expected call of GetSessionTicket +func (mr *MockCryptoSetupMockRecorder) GetSessionTicket() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSessionTicket", reflect.TypeOf((*MockCryptoSetup)(nil).GetSessionTicket)) +} + // HandleMessage mocks base method func (m *MockCryptoSetup) HandleMessage(arg0 []byte, arg1 protocol.EncryptionLevel) bool { m.ctrl.T.Helper() diff --git a/internal/protocol/params.go b/internal/protocol/params.go index 000da23ca..fe88b495b 100644 --- a/internal/protocol/params.go +++ b/internal/protocol/params.go @@ -110,7 +110,7 @@ const MinStreamFrameSize ByteCount = 128 // MaxPostHandshakeCryptoFrameSize is the maximum size of CRYPTO frames // we send after the handshake completes. -const MaxPostHandshakeCryptoFrameSize ByteCount = 1000 +const MaxPostHandshakeCryptoFrameSize = 1000 // MaxAckFrameSize is the maximum size for an ACK frame that we write // Due to the varint encoding, ACK frames can grow (almost) indefinitely large. diff --git a/session.go b/session.go index ebd14ab1f..4d8ec6f63 100644 --- a/session.go +++ b/session.go @@ -52,6 +52,7 @@ type cryptoStreamHandler interface { ChangeConnectionID(protocol.ConnectionID) SetLargest1RTTAcked(protocol.PacketNumber) DropHandshakeKeys() + GetSessionTicket() ([]byte, error) io.Closer ConnectionState() handshake.ConnectionState } @@ -141,6 +142,7 @@ type session struct { frameParser wire.FrameParser packer packer + oneRTTStream cryptoStream // only set for the server cryptoStreamHandler cryptoStreamHandler receivedPackets chan *receivedPacket @@ -213,6 +215,7 @@ var newSession = func( handshakeDestConnID: destConnID, srcConnIDLen: srcConnID.Len(), tokenGenerator: tokenGenerator, + oneRTTStream: newCryptoStream(), perspective: protocol.PerspectiveServer, handshakeCompleteChan: make(chan struct{}), logger: logger, @@ -244,7 +247,6 @@ var newSession = func( s.sentPacketHandler = ackhandler.NewSentPacketHandler(0, s.rttStats, s.traceCallback, s.logger) initialStream := newCryptoStream() handshakeStream := newCryptoStream() - oneRTTStream := newPostHandshakeCryptoStream(s.framer) params := &handshake.TransportParameters{ InitialMaxStreamDataBidiLocal: protocol.InitialMaxStreamData, InitialMaxStreamDataBidiRemote: protocol.InitialMaxStreamData, @@ -263,7 +265,6 @@ var newSession = func( cs := handshake.NewCryptoSetupServer( initialStream, handshakeStream, - oneRTTStream, clientDestConnID, conn.RemoteAddr(), params, @@ -297,7 +298,7 @@ var newSession = func( s.version, ) s.unpacker = newPacketUnpacker(cs, s.version) - s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream, oneRTTStream) + s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream, s.oneRTTStream) return s } @@ -348,7 +349,6 @@ var newClientSession = func( s.sentPacketHandler = ackhandler.NewSentPacketHandler(initialPacketNumber, s.rttStats, s.traceCallback, s.logger) initialStream := newCryptoStream() handshakeStream := newCryptoStream() - oneRTTStream := newPostHandshakeCryptoStream(s.framer) params := &handshake.TransportParameters{ InitialMaxStreamDataBidiRemote: protocol.InitialMaxStreamData, InitialMaxStreamDataBidiLocal: protocol.InitialMaxStreamData, @@ -365,7 +365,6 @@ var newClientSession = func( cs, clientHelloWritten := handshake.NewCryptoSetupClient( initialStream, handshakeStream, - oneRTTStream, destConnID, conn.RemoteAddr(), params, @@ -382,7 +381,7 @@ var newClientSession = func( ) s.clientHelloWritten = clientHelloWritten s.cryptoStreamHandler = cs - s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream, oneRTTStream) + s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream, newCryptoStream()) s.unpacker = newPacketUnpacker(cs, s.version) s.packer = newPacketPacker( srcConnID, @@ -633,6 +632,16 @@ func (s *session) handleHandshakeComplete() { s.sentPacketHandler.SetHandshakeComplete() if s.perspective == protocol.PerspectiveServer { + ticket, err := s.cryptoStreamHandler.GetSessionTicket() + if err != nil { + s.closeLocal(err) + } + if ticket != nil { + s.oneRTTStream.Write(ticket) + for s.oneRTTStream.HasData() { + s.queueControlFrame(s.oneRTTStream.PopCryptoFrame(protocol.MaxPostHandshakeCryptoFrameSize)) + } + } token, err := s.tokenGenerator.NewToken(s.conn.RemoteAddr()) if err != nil { s.closeLocal(err) diff --git a/session_test.go b/session_test.go index 5d5a21931..6163648ff 100644 --- a/session_test.go +++ b/session_test.go @@ -1191,6 +1191,7 @@ var _ = Describe("Session", func() { <-finishHandshake cryptoSetup.EXPECT().RunHandshake() cryptoSetup.EXPECT().DropHandshakeKeys() + cryptoSetup.EXPECT().GetSessionTicket() close(sess.handshakeCompleteChan) sess.run() }() @@ -1210,6 +1211,47 @@ var _ = Describe("Session", func() { Eventually(sess.Context().Done()).Should(BeClosed()) }) + It("sends a session ticket when the handshake completes", func() { + const size = protocol.MaxPostHandshakeCryptoFrameSize * 3 / 2 + packer.EXPECT().PackPacket().AnyTimes() + finishHandshake := make(chan struct{}) + sessionRunner.EXPECT().Retire(clientDestConnID) + go func() { + defer GinkgoRecover() + <-finishHandshake + cryptoSetup.EXPECT().RunHandshake() + cryptoSetup.EXPECT().DropHandshakeKeys() + cryptoSetup.EXPECT().GetSessionTicket().Return(make([]byte, size), nil) + close(sess.handshakeCompleteChan) + sess.run() + }() + + handshakeCtx := sess.HandshakeComplete() + Consistently(handshakeCtx.Done()).ShouldNot(BeClosed()) + mconn.EXPECT().RemoteAddr().Return(&net.UDPAddr{}) // the remote addr is needed for the token + close(finishHandshake) + Eventually(handshakeCtx.Done()).Should(BeClosed()) + frames, _ := sess.framer.AppendControlFrames(nil, protocol.MaxByteCount) + var count int + var s int + for _, f := range frames { + if cf, ok := f.Frame.(*wire.CryptoFrame); ok { + count++ + s += len(cf.Data) + Expect(f.Length(sess.version)).To(BeNumerically("<=", protocol.MaxPostHandshakeCryptoFrameSize)) + } + } + Expect(size).To(BeEquivalentTo(s)) + // make sure the go routine returns + streamManager.EXPECT().CloseWithError(gomock.Any()) + expectReplaceWithClosed() + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) + cryptoSetup.EXPECT().Close() + mconn.EXPECT().Write(gomock.Any()) + sess.shutdown() + Eventually(sess.Context().Done()).Should(BeClosed()) + }) + It("doesn't cancel the HandshakeComplete context when the handshake fails", func() { packer.EXPECT().PackPacket().AnyTimes() streamManager.EXPECT().CloseWithError(gomock.Any()) @@ -1247,6 +1289,7 @@ var _ = Describe("Session", func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake() cryptoSetup.EXPECT().DropHandshakeKeys() + cryptoSetup.EXPECT().GetSessionTicket() mconn.EXPECT().RemoteAddr().Return(&net.UDPAddr{}) // the remote addr is needed for the token mconn.EXPECT().Write(gomock.Any()) close(sess.handshakeCompleteChan) @@ -1492,6 +1535,7 @@ var _ = Describe("Session", func() { go func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().GetSessionTicket().MaxTimes(1) cryptoSetup.EXPECT().DropHandshakeKeys().MaxTimes(1) mconn.EXPECT().RemoteAddr().Return(&net.UDPAddr{}) close(sess.handshakeCompleteChan)