From 6eb72f712d35061e18418a11f4e9f7dffb6b2da6 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 6 Jul 2019 11:05:00 +0700 Subject: [PATCH 1/6] implement a session method to tell if the handshake is already complete --- session.go | 12 ++++++++++-- session_test.go | 42 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+), 2 deletions(-) diff --git a/session.go b/session.go index 509894f6..845e98e2 100644 --- a/session.go +++ b/session.go @@ -135,8 +135,10 @@ type session struct { connectionClosePacket *packedPacket packetsReceivedAfterClose int - ctx context.Context - ctxCancel context.CancelFunc + ctx context.Context + ctxCancel context.CancelFunc + handshakeCtx context.Context + handshakeCtxCancel context.CancelFunc undecryptablePackets []*receivedPacket @@ -356,6 +358,7 @@ func (s *session) postSetup() error { s.sendingScheduled = make(chan struct{}, 1) s.undecryptablePackets = make([]*receivedPacket, 0, protocol.MaxUndecryptablePackets) s.ctx, s.ctxCancel = context.WithCancel(context.Background()) + s.handshakeCtx, s.handshakeCtxCancel = context.WithCancel(context.Background()) s.timer = utils.NewTimer() now := time.Now() @@ -465,6 +468,10 @@ runLoop: return closeErr.err } +func (s *session) HandshakeComplete() context.Context { + return s.handshakeCtx +} + func (s *session) Context() context.Context { return s.ctx } @@ -505,6 +512,7 @@ func (s *session) idleTimeoutStartTime() time.Time { func (s *session) handleHandshakeComplete() { s.handshakeComplete = true s.handshakeCompleteChan = nil // prevent this case from ever being selected again + s.handshakeCtxCancel() s.sessionRunner.OnHandshakeComplete(s) // The client completes the handshake first (after sending the CFIN). diff --git a/session_test.go b/session_test.go index 3a5a4d02..f46b7581 100644 --- a/session_test.go +++ b/session_test.go @@ -1153,6 +1153,48 @@ var _ = Describe("Session", func() { }) }) + It("cancels the HandshakeComplete context when the handshake completes", func() { + sessionRunner.EXPECT().OnHandshakeComplete(gomock.Any()) + packer.EXPECT().PackPacket().AnyTimes() + finishHandshake := make(chan struct{}) + go func() { + defer GinkgoRecover() + <-finishHandshake + cryptoSetup.EXPECT().RunHandshake() + close(sess.handshakeCompleteChan) + sess.run() + }() + handshakeCtx := sess.HandshakeComplete() + Consistently(handshakeCtx.Done()).ShouldNot(BeClosed()) + close(finishHandshake) + Eventually(handshakeCtx.Done()).Should(BeClosed()) + //make sure the go routine returns + streamManager.EXPECT().CloseWithError(gomock.Any()) + sessionRunner.EXPECT().Retire(gomock.Any()) + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) + cryptoSetup.EXPECT().Close() + Expect(sess.Close()).To(Succeed()) + Eventually(sess.Context().Done()).Should(BeClosed()) + }) + + It("doesn't cancel the HandshakeComplete context when the handshake fails", func() { + packer.EXPECT().PackPacket().AnyTimes() + streamManager.EXPECT().CloseWithError(gomock.Any()) + sessionRunner.EXPECT().Retire(gomock.Any()) + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) + cryptoSetup.EXPECT().Close() + go func() { + defer GinkgoRecover() + cryptoSetup.EXPECT().RunHandshake() + sess.run() + }() + handshakeCtx := sess.HandshakeComplete() + Consistently(handshakeCtx.Done()).ShouldNot(BeClosed()) + sess.closeLocal(errors.New("handshake error")) + Consistently(handshakeCtx.Done()).ShouldNot(BeClosed()) + Eventually(sess.Context().Done()).Should(BeClosed()) + }) + It("sends a 1-RTT packet when the handshake completes", func() { done := make(chan struct{}) gomock.InOrder( From 3724f91f339af2665da568e765d7a41fb7e7a8a3 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sun, 7 Jul 2019 08:37:36 +0700 Subject: [PATCH 2/6] simplify setup of the server --- server.go | 23 ++++++++--------------- 1 file changed, 8 insertions(+), 15 deletions(-) diff --git a/server.go b/server.go index 6f035c41..66eab853 100644 --- a/server.go +++ b/server.go @@ -154,25 +154,21 @@ func listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (*server, if err != nil { return nil, err } + tokenGenerator, err := handshake.NewTokenGenerator() + if err != nil { + return nil, err + } s := &server{ conn: conn, tlsConf: tlsConf, config: config, + tokenGenerator: tokenGenerator, sessionHandler: sessionHandler, sessionQueue: make(chan Session), errorChan: make(chan struct{}), newSession: newSession, logger: utils.DefaultLogger.WithPrefix("server"), } - if err := s.setup(); err != nil { - return nil, err - } - sessionHandler.SetServer(s) - s.logger.Debugf("Listening for %s connections on %s", conn.LocalAddr().Network(), conn.LocalAddr().String()) - return s, nil -} - -func (s *server) setup() error { s.sessionRunner = &runner{ packetHandlerManager: s.sessionHandler, onHandshakeCompleteImpl: func(sess Session) { @@ -188,12 +184,9 @@ func (s *server) setup() error { }() }, } - tokenGenerator, err := handshake.NewTokenGenerator() - if err != nil { - return err - } - s.tokenGenerator = tokenGenerator - return nil + sessionHandler.SetServer(s) + s.logger.Debugf("Listening for %s connections on %s", conn.LocalAddr().Network(), conn.LocalAddr().String()) + return s, nil } var defaultAcceptToken = func(clientAddr net.Addr, token *Token) bool { From b06d7b0cba9b889118a9b8b77c665d14bb292987 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 6 Jul 2019 17:37:15 +0700 Subject: [PATCH 3/6] use the HandshakeComplete() context for accepting sessions in the server --- client.go | 8 ++---- client_test.go | 17 +++++++++--- mock_quic_session_test.go | 14 ++++++++++ mock_session_runner_test.go | 12 --------- server.go | 50 +++++++++++++++------------------- server_test.go | 54 ++++++++----------------------------- session.go | 1 - session_test.go | 21 ++++++--------- 8 files changed, 70 insertions(+), 107 deletions(-) diff --git a/client.go b/client.go index fc6ed290..5dac6a3a 100644 --- a/client.go +++ b/client.go @@ -292,7 +292,7 @@ func (c *client) establishSecureConnection(ctx context.Context) error { return ctx.Err() case err := <-errorChan: return err - case <-c.handshakeChan: + case <-c.session.HandshakeComplete().Done(): // handshake successfully completed return nil } @@ -371,13 +371,9 @@ func (c *client) createNewTLSSession(version protocol.VersionNumber) error { c.mutex.Lock() defer c.mutex.Unlock() - runner := &runner{ - packetHandlerManager: c.packetHandlers, - onHandshakeCompleteImpl: func(_ Session) { close(c.handshakeChan) }, - } sess, err := newClientSession( c.conn, - runner, + c.packetHandlers, c.destConnID, c.srcConnID, c.config, diff --git a/client_test.go b/client_test.go index 0bc1e2ef..fb326045 100644 --- a/client_test.go +++ b/client_test.go @@ -149,6 +149,7 @@ var _ = Describe("Client", func() { remoteAddrChan <- conn.RemoteAddr().String() sess := NewMockQuicSession(mockCtrl) sess.EXPECT().run() + sess.EXPECT().HandshakeComplete().Return(context.Background()) return sess, nil } _, err := DialAddr("localhost:17890", tlsConf, &Config{HandshakeTimeout: time.Millisecond}) @@ -179,6 +180,7 @@ var _ = Describe("Client", func() { hostnameChan <- tlsConf.ServerName sess := NewMockQuicSession(mockCtrl) sess.EXPECT().run() + sess.EXPECT().HandshakeComplete().Return(context.Background()) return sess, nil } tlsConf.ServerName = "foobar" @@ -208,6 +210,7 @@ var _ = Describe("Client", func() { ) (quicSession, error) { hostnameChan <- tlsConf.ServerName sess := NewMockQuicSession(mockCtrl) + sess.EXPECT().HandshakeComplete().Return(context.Background()) sess.EXPECT().run() return sess, nil } @@ -243,7 +246,9 @@ var _ = Describe("Client", func() { ) (quicSession, error) { sess := NewMockQuicSession(mockCtrl) sess.EXPECT().run().Do(func() { close(run) }) - runner.OnHandshakeComplete(sess) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + sess.EXPECT().HandshakeComplete().Return(ctx) return sess, nil } s, err := Dial( @@ -258,7 +263,7 @@ var _ = Describe("Client", func() { Eventually(run).Should(BeClosed()) }) - It("returns an error that occurs while waiting for the connection to become secure", func() { + It("returns an error that occurs while waiting for the handshake to complete", func() { manager := NewMockPacketHandlerManager(mockCtrl) manager.EXPECT().Add(gomock.Any(), gomock.Any()) mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any()).Return(manager, nil) @@ -279,6 +284,7 @@ var _ = Describe("Client", func() { ) (quicSession, error) { sess := NewMockQuicSession(mockCtrl) sess.EXPECT().run().Return(testErr) + sess.EXPECT().HandshakeComplete().Return(context.Background()) return sess, nil } packetConn.dataToRead <- acceptClientVersionPacket(cl.srcConnID) @@ -303,6 +309,7 @@ var _ = Describe("Client", func() { sess.EXPECT().run().Do(func() { <-sessionRunning }) + sess.EXPECT().HandshakeComplete().Return(context.Background()) newClientSession = func( _ connection, _ sessionRunner, @@ -366,6 +373,7 @@ var _ = Describe("Client", func() { sess.EXPECT().run().Do(func() { runner.Retire(connID) }) + sess.EXPECT().HandshakeComplete().Return(context.Background()) _, err := DialContext( context.Background(), @@ -411,6 +419,7 @@ var _ = Describe("Client", func() { sess.EXPECT().run().Do(func() { <-run }) + sess.EXPECT().HandshakeComplete().Return(context.Background()) done := make(chan struct{}) go func() { @@ -497,7 +506,7 @@ var _ = Describe("Client", func() { }) }) - It("creates new TLS sessions with the right parameters", func() { + It("creates new sessions with the right parameters", func() { manager := NewMockPacketHandlerManager(mockCtrl) manager.EXPECT().Add(connID, gomock.Any()) mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any()).Return(manager, nil) @@ -527,6 +536,7 @@ var _ = Describe("Client", func() { // TODO: check connection IDs? sess := NewMockQuicSession(mockCtrl) sess.EXPECT().run() + sess.EXPECT().HandshakeComplete().Return(context.Background()) return sess, nil } _, err := Dial(packetConn, addr, "localhost:1337", tlsConf, config) @@ -571,6 +581,7 @@ var _ = Describe("Client", func() { Expect(conn.Write([]byte("0 fake CHLO"))).To(Succeed()) sess := NewMockQuicSession(mockCtrl) sess.EXPECT().run().Return(testErr) + sess.EXPECT().HandshakeComplete().Return(context.Background()) return sess, nil } _, err := Dial( diff --git a/mock_quic_session_test.go b/mock_quic_session_test.go index 30bcd90b..99ba98ac 100644 --- a/mock_quic_session_test.go +++ b/mock_quic_session_test.go @@ -137,6 +137,20 @@ func (mr *MockQuicSessionMockRecorder) GetVersion() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetVersion", reflect.TypeOf((*MockQuicSession)(nil).GetVersion)) } +// HandshakeComplete mocks base method +func (m *MockQuicSession) HandshakeComplete() context.Context { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "HandshakeComplete") + ret0, _ := ret[0].(context.Context) + return ret0 +} + +// HandshakeComplete indicates an expected call of HandshakeComplete +func (mr *MockQuicSessionMockRecorder) HandshakeComplete() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandshakeComplete", reflect.TypeOf((*MockQuicSession)(nil).HandshakeComplete)) +} + // LocalAddr mocks base method func (m *MockQuicSession) LocalAddr() net.Addr { m.ctrl.T.Helper() diff --git a/mock_session_runner_test.go b/mock_session_runner_test.go index fad0e813..9bb54a05 100644 --- a/mock_session_runner_test.go +++ b/mock_session_runner_test.go @@ -46,18 +46,6 @@ func (mr *MockSessionRunnerMockRecorder) AddResetToken(arg0, arg1 interface{}) * return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddResetToken", reflect.TypeOf((*MockSessionRunner)(nil).AddResetToken), arg0, arg1) } -// OnHandshakeComplete mocks base method -func (m *MockSessionRunner) OnHandshakeComplete(arg0 Session) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "OnHandshakeComplete", arg0) -} - -// OnHandshakeComplete indicates an expected call of OnHandshakeComplete -func (mr *MockSessionRunnerMockRecorder) OnHandshakeComplete(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnHandshakeComplete", reflect.TypeOf((*MockSessionRunner)(nil).OnHandshakeComplete), arg0) -} - // Remove mocks base method func (m *MockSessionRunner) Remove(arg0 protocol.ConnectionID) { m.ctrl.T.Helper() diff --git a/server.go b/server.go index 66eab853..c0269ae4 100644 --- a/server.go +++ b/server.go @@ -46,6 +46,7 @@ type packetHandlerManager interface { type quicSession interface { Session + HandshakeComplete() context.Context handlePacket(*receivedPacket) GetVersion() protocol.VersionNumber getPerspective() protocol.Perspective @@ -56,23 +57,12 @@ type quicSession interface { } type sessionRunner interface { - OnHandshakeComplete(Session) Retire(protocol.ConnectionID) Remove(protocol.ConnectionID) AddResetToken([16]byte, packetHandler) RemoveResetToken([16]byte) } -type runner struct { - packetHandlerManager - - onHandshakeCompleteImpl func(Session) -} - -func (r *runner) OnHandshakeComplete(s Session) { r.onHandshakeCompleteImpl(s) } - -var _ sessionRunner = &runner{} - // A Listener of QUIC type server struct { mutex sync.Mutex @@ -99,8 +89,6 @@ type server struct { sessionQueue chan Session sessionQueueLen int32 // to be used as an atomic - sessionRunner sessionRunner - logger utils.Logger } @@ -169,21 +157,6 @@ func listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (*server, newSession: newSession, logger: utils.DefaultLogger.WithPrefix("server"), } - s.sessionRunner = &runner{ - packetHandlerManager: s.sessionHandler, - onHandshakeCompleteImpl: func(sess Session) { - go func() { - atomic.AddInt32(&s.sessionQueueLen, 1) - select { - case s.sessionQueue <- sess: - // blocks until the session is accepted - case <-sess.Context().Done(): - atomic.AddInt32(&s.sessionQueueLen, -1) - // don't pass sessions that were already closed to Accept() - } - }() - }, - } sessionHandler.SetServer(s) s.logger.Debugf("Listening for %s connections on %s", conn.LocalAddr().Network(), conn.LocalAddr().String()) return s, nil @@ -454,7 +427,7 @@ func (s *server) createNewSession( } sess, err := s.newSession( &conn{pconn: s.conn, currentAddr: remoteAddr}, - s.sessionRunner, + s.sessionHandler, clientDestConnID, destConnID, srcConnID, @@ -469,9 +442,28 @@ func (s *server) createNewSession( return nil, err } go sess.run() + go s.waitUntilHandshakeComplete(sess) return sess, nil } +func (s *server) waitUntilHandshakeComplete(sess quicSession) { + sessCtx := sess.Context() + select { + case <-sess.HandshakeComplete().Done(): + case <-sessCtx.Done(): + return + } + + atomic.AddInt32(&s.sessionQueueLen, 1) + select { + case s.sessionQueue <- sess: + // blocks until the session is accepted + case <-sessCtx.Done(): + atomic.AddInt32(&s.sessionQueueLen, -1) + // don't pass sessions that were already closed to Accept() + } +} + func (s *server) sendRetry(remoteAddr net.Addr, hdr *wire.Header) error { token, err := s.tokenGenerator.NewRetryToken(remoteAddr, hdr.DestConnectionID) if err != nil { diff --git a/server_test.go b/server_test.go index caa5935b..ca25b7fb 100644 --- a/server_test.go +++ b/server_test.go @@ -303,6 +303,8 @@ var _ = Describe("Server", func() { sess := NewMockQuicSession(mockCtrl) sess.EXPECT().handlePacket(p) sess.EXPECT().run().Do(func() { close(run) }) + sess.EXPECT().Context().Return(context.Background()) + sess.EXPECT().HandshakeComplete().Return(context.Background()) return sess, nil } @@ -349,7 +351,9 @@ var _ = Describe("Server", func() { sess.EXPECT().handlePacket(p) sess.EXPECT().run() sess.EXPECT().Context().Return(context.Background()) - runner.OnHandshakeComplete(sess) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + sess.EXPECT().HandshakeComplete().Return(ctx) return sess, nil } @@ -407,7 +411,9 @@ var _ = Describe("Server", func() { sess.EXPECT().handlePacket(p) sess.EXPECT().run() sess.EXPECT().Context().Return(ctx) - runner.OnHandshakeComplete(sess) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + sess.EXPECT().HandshakeComplete().Return(ctx) close(sessionCreated) return sess, nil } @@ -493,7 +499,7 @@ var _ = Describe("Server", func() { close(done) }() - completeHandshake := make(chan struct{}) + ctx, cancel := context.WithCancel(context.Background()) // handshake context serv.newSession = func( _ connection, runner sessionRunner, @@ -507,10 +513,7 @@ var _ = Describe("Server", func() { _ utils.Logger, _ protocol.VersionNumber, ) (quicSession, error) { - go func() { - <-completeHandshake - runner.OnHandshakeComplete(sess) - }() + sess.EXPECT().HandshakeComplete().Return(ctx) sess.EXPECT().run().Do(func() {}) sess.EXPECT().Context().Return(context.Background()) return sess, nil @@ -518,44 +521,9 @@ var _ = Describe("Server", func() { _, err := serv.createNewSession(&net.UDPAddr{}, nil, nil, nil, nil, protocol.VersionWhatever) Expect(err).ToNot(HaveOccurred()) Consistently(done).ShouldNot(BeClosed()) - close(completeHandshake) + cancel() // complete the handshake Eventually(done).Should(BeClosed()) }) - - It("never blocks when calling the onHandshakeComplete callback", func() { - const num = 50 - - runs := make(chan struct{}, num) - contexts := make(chan struct{}, num) - serv.newSession = func( - _ connection, - runner sessionRunner, - _ protocol.ConnectionID, - _ protocol.ConnectionID, - _ protocol.ConnectionID, - _ *Config, - _ *tls.Config, - _ *handshake.TransportParameters, - _ *handshake.TokenGenerator, - _ utils.Logger, - _ protocol.VersionNumber, - ) (quicSession, error) { - sess := NewMockQuicSession(mockCtrl) - sess.EXPECT().run().Do(func() { runs <- struct{}{} }) - sess.EXPECT().Context().Do(func() { contexts <- struct{}{} }).Return(context.Background()) - runner.OnHandshakeComplete(sess) - return sess, nil - } - - go func() { - for i := 0; i < num; i++ { - _, err := serv.createNewSession(&net.UDPAddr{}, nil, nil, nil, nil, protocol.VersionWhatever) - Expect(err).ToNot(HaveOccurred()) - } - }() - Eventually(runs).Should(HaveLen(num)) - Eventually(contexts).Should(HaveLen(num)) - }) }) }) diff --git a/session.go b/session.go index 845e98e2..793b9d8f 100644 --- a/session.go +++ b/session.go @@ -513,7 +513,6 @@ func (s *session) handleHandshakeComplete() { s.handshakeComplete = true s.handshakeCompleteChan = nil // prevent this case from ever being selected again s.handshakeCtxCancel() - s.sessionRunner.OnHandshakeComplete(s) // The client completes the handshake first (after sending the CFIN). // We need to make sure it learns about the server completing the handshake, diff --git a/session_test.go b/session_test.go index f46b7581..1a92f429 100644 --- a/session_test.go +++ b/session_test.go @@ -1154,7 +1154,6 @@ var _ = Describe("Session", func() { }) It("cancels the HandshakeComplete context when the handshake completes", func() { - sessionRunner.EXPECT().OnHandshakeComplete(gomock.Any()) packer.EXPECT().PackPacket().AnyTimes() finishHandshake := make(chan struct{}) go func() { @@ -1197,17 +1196,14 @@ var _ = Describe("Session", func() { It("sends a 1-RTT packet when the handshake completes", func() { done := make(chan struct{}) - gomock.InOrder( - sessionRunner.EXPECT().OnHandshakeComplete(gomock.Any()), - packer.EXPECT().PackPacket().DoAndReturn(func() (*packedPacket, error) { - defer close(done) - return &packedPacket{ - header: &wire.ExtendedHeader{}, - buffer: getPacketBuffer(), - }, nil - }), - packer.EXPECT().PackPacket().AnyTimes(), - ) + packer.EXPECT().PackPacket().DoAndReturn(func() (*packedPacket, error) { + defer close(done) + return &packedPacket{ + header: &wire.ExtendedHeader{}, + buffer: getPacketBuffer(), + }, nil + }) + packer.EXPECT().PackPacket().AnyTimes() go func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake() @@ -1451,7 +1447,6 @@ var _ = Describe("Session", func() { done := make(chan struct{}) go func() { defer GinkgoRecover() - sessionRunner.EXPECT().OnHandshakeComplete(sess) cryptoSetup.EXPECT().RunHandshake() close(sess.handshakeCompleteChan) err := sess.run() From cc76441539634601297311a62db3b52a6d6ce869 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 27 Jul 2019 01:06:43 -0400 Subject: [PATCH 4/6] implement a function that blocks until an early session can be used --- session.go | 10 ++++++++++ session_test.go | 3 +++ 2 files changed, 13 insertions(+) diff --git a/session.go b/session.go index 793b9d8f..6030c303 100644 --- a/session.go +++ b/session.go @@ -143,6 +143,7 @@ type session struct { undecryptablePackets []*receivedPacket clientHelloWritten <-chan struct{} + earlySessionReadyChan chan struct{} handshakeCompleteChan chan struct{} // is closed when the handshake completes handshakeComplete bool @@ -345,6 +346,7 @@ func (s *session) preSetup() { s.rttStats, s.logger, ) + s.earlySessionReadyChan = make(chan struct{}) if s.config.QuicTracer != nil { s.traceCallback = func(ev quictrace.Event) { s.config.QuicTracer.Trace(s.origDestConnID, ev) @@ -468,6 +470,11 @@ runLoop: return closeErr.err } +// blocks until the early session can be used +func (s *session) earlySessionReady() <-chan struct{} { + return s.earlySessionReadyChan +} + func (s *session) HandshakeComplete() context.Context { return s.handshakeCtx } @@ -1013,6 +1020,9 @@ func (s *session) processTransportParameters(data []byte) { if params.StatelessResetToken != nil { s.sessionRunner.AddResetToken(*params.StatelessResetToken, s) } + // On the server side, the early session is ready as soon as we processed + // the client's transport parameters. + close(s.earlySessionReadyChan) } func (s *session) processTransportParametersForClient(data []byte) (*handshake.TransportParameters, error) { diff --git a/session_test.go b/session_test.go index 1a92f429..82e606cc 100644 --- a/session_test.go +++ b/session_test.go @@ -1286,7 +1286,10 @@ var _ = Describe("Session", func() { } streamManager.EXPECT().UpdateLimits(params) packer.EXPECT().HandleTransportParameters(params) + Expect(sess.earlySessionReady()).ToNot(BeClosed()) sess.processTransportParameters(params.Marshal()) + Expect(sess.earlySessionReady()).To(BeClosed()) + // make the go routine return streamManager.EXPECT().CloseWithError(gomock.Any()) sessionRunner.EXPECT().Retire(gomock.Any()) From 5cbb8d6597eb1122885d9a4192067fad7674625c Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 27 Jul 2019 01:55:46 -0400 Subject: [PATCH 5/6] implement a listener that returns early sessions --- interface.go | 24 ++ mock_quic_session_test.go | 14 + server.go | 126 ++++--- server_test.go | 694 +++++++++++++++++++++++--------------- session.go | 1 + 5 files changed, 551 insertions(+), 308 deletions(-) diff --git a/interface.go b/interface.go index 8059e25b..f35b86dc 100644 --- a/interface.go +++ b/interface.go @@ -171,6 +171,19 @@ type Session interface { ConnectionState() tls.ConnectionState } +// An EarlySession is a session that is handshaking. +// Data sent during the handshake is encrypted using the forward secure keys. +// When using client certificates, the client's identity is only verified +// after completion of the handshake. +type EarlySession interface { + Session + + // Blocks until the handshake completes (or fails). + // Data sent before completion of the handshake is encrypted with 1-RTT keys. + // Note that the client's identity hasn't been verified yet. + HandshakeComplete() context.Context +} + // Config contains all configuration data needed for a QUIC server or client. type Config struct { // The QUIC versions that can be negotiated. @@ -234,3 +247,14 @@ type Listener interface { // Accept returns new sessions. It should be called in a loop. Accept(context.Context) (Session, error) } + +// An EarlyListener listens for incoming QUIC connections, +// and returns them before the handshake completes. +type EarlyListener interface { + // Close the server. All active sessions will be closed. + Close() error + // Addr returns the local network addr that the server is listening on. + Addr() net.Addr + // Accept returns new early sessions. It should be called in a loop. + Accept(context.Context) (EarlySession, error) +} diff --git a/mock_quic_session_test.go b/mock_quic_session_test.go index 99ba98ac..f611a5b6 100644 --- a/mock_quic_session_test.go +++ b/mock_quic_session_test.go @@ -277,6 +277,20 @@ func (mr *MockQuicSessionMockRecorder) destroy(arg0 interface{}) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "destroy", reflect.TypeOf((*MockQuicSession)(nil).destroy), arg0) } +// earlySessionReady mocks base method +func (m *MockQuicSession) earlySessionReady() <-chan struct{} { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "earlySessionReady") + ret0, _ := ret[0].(<-chan struct{}) + return ret0 +} + +// earlySessionReady indicates an expected call of earlySessionReady +func (mr *MockQuicSessionMockRecorder) earlySessionReady() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "earlySessionReady", reflect.TypeOf((*MockQuicSession)(nil).earlySessionReady)) +} + // getPerspective mocks base method func (m *MockQuicSession) getPerspective() protocol.Perspective { m.ctrl.T.Helper() diff --git a/server.go b/server.go index c0269ae4..b349478b 100644 --- a/server.go +++ b/server.go @@ -45,8 +45,8 @@ type packetHandlerManager interface { } type quicSession interface { - Session - HandshakeComplete() context.Context + EarlySession + earlySessionReady() <-chan struct{} handlePacket(*receivedPacket) GetVersion() protocol.VersionNumber getPerspective() protocol.Perspective @@ -64,9 +64,11 @@ type sessionRunner interface { } // A Listener of QUIC -type server struct { +type baseServer struct { mutex sync.Mutex + acceptEarlySessions bool + tlsConf *tls.Config config *Config @@ -86,19 +88,40 @@ type server struct { errorChan chan struct{} closed bool - sessionQueue chan Session + sessionQueue chan quicSession sessionQueueLen int32 // to be used as an atomic logger utils.Logger } -var _ Listener = &server{} -var _ unknownPacketHandler = &server{} +var _ Listener = &baseServer{} +var _ unknownPacketHandler = &baseServer{} + +type earlyServer struct{ *baseServer } + +var _ EarlyListener = &earlyServer{} + +func (s *earlyServer) Accept(ctx context.Context) (EarlySession, error) { + return s.baseServer.accept(ctx) +} // ListenAddr creates a QUIC server listening on a given address. // The tls.Config must not be nil and must contain a certificate configuration. // The quic.Config may be nil, in that case the default values will be used. func ListenAddr(addr string, tlsConf *tls.Config, config *Config) (Listener, error) { + return listenAddr(addr, tlsConf, config, false) +} + +// ListenAddrEarly works like ListenAddr, but it returns sessions before the handshake completes. +func ListenAddrEarly(addr string, tlsConf *tls.Config, config *Config) (EarlyListener, error) { + s, err := listenAddr(addr, tlsConf, config, true) + if err != nil { + return nil, err + } + return &earlyServer{s}, nil +} + +func listenAddr(addr string, tlsConf *tls.Config, config *Config, acceptEarly bool) (*baseServer, error) { udpAddr, err := net.ResolveUDPAddr("udp", addr) if err != nil { return nil, err @@ -107,7 +130,7 @@ func ListenAddr(addr string, tlsConf *tls.Config, config *Config) (Listener, err if err != nil { return nil, err } - serv, err := listen(conn, tlsConf, config) + serv, err := listen(conn, tlsConf, config, acceptEarly) if err != nil { return nil, err } @@ -123,10 +146,20 @@ func ListenAddr(addr string, tlsConf *tls.Config, config *Config) (Listener, err // Furthermore, it must define an application control (using NextProtos). // The quic.Config may be nil, in that case the default values will be used. func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener, error) { - return listen(conn, tlsConf, config) + return listen(conn, tlsConf, config, false) } -func listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (*server, error) { +// ListenEarly works like Listen, but it returns sessions before the handshake completes. +func ListenEarly(conn net.PacketConn, tlsConf *tls.Config, config *Config) (EarlyListener, error) { + s, err := listen(conn, tlsConf, config, true) + if err != nil { + return nil, err + } + s.acceptEarlySessions = true + return &earlyServer{s}, nil +} + +func listen(conn net.PacketConn, tlsConf *tls.Config, config *Config, acceptEarly bool) (*baseServer, error) { // TODO(#1655): only require that tls.Config.Certificates or tls.Config.GetCertificate is set if tlsConf == nil { return nil, errors.New("quic: tls.Config not set") @@ -146,16 +179,17 @@ func listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (*server, if err != nil { return nil, err } - s := &server{ - conn: conn, - tlsConf: tlsConf, - config: config, - tokenGenerator: tokenGenerator, - sessionHandler: sessionHandler, - sessionQueue: make(chan Session), - errorChan: make(chan struct{}), - newSession: newSession, - logger: utils.DefaultLogger.WithPrefix("server"), + s := &baseServer{ + conn: conn, + tlsConf: tlsConf, + config: config, + tokenGenerator: tokenGenerator, + sessionHandler: sessionHandler, + sessionQueue: make(chan quicSession), + errorChan: make(chan struct{}), + newSession: newSession, + logger: utils.DefaultLogger.WithPrefix("server"), + acceptEarlySessions: acceptEarly, } sessionHandler.SetServer(s) s.logger.Debugf("Listening for %s connections on %s", conn.LocalAddr().Network(), conn.LocalAddr().String()) @@ -248,13 +282,17 @@ func populateServerConfig(config *Config) *Config { } } -// Accept returns newly openend sessions -func (s *server) Accept(ctx context.Context) (Session, error) { - var sess Session +// Accept returns sessions that already completed the handshake. +// It is only valid if acceptEarlySessions is false. +func (s *baseServer) Accept(ctx context.Context) (Session, error) { + return s.accept(ctx) +} + +func (s *baseServer) accept(ctx context.Context) (quicSession, error) { select { case <-ctx.Done(): return nil, ctx.Err() - case sess = <-s.sessionQueue: + case sess := <-s.sessionQueue: atomic.AddInt32(&s.sessionQueueLen, -1) return sess, nil case <-s.errorChan: @@ -263,7 +301,7 @@ func (s *server) Accept(ctx context.Context) (Session, error) { } // Close the server -func (s *server) Close() error { +func (s *baseServer) Close() error { s.mutex.Lock() defer s.mutex.Unlock() if s.closed { @@ -284,7 +322,7 @@ func (s *server) Close() error { return err } -func (s *server) setCloseError(e error) { +func (s *baseServer) setCloseError(e error) { s.mutex.Lock() defer s.mutex.Unlock() if s.closed { @@ -296,11 +334,11 @@ func (s *server) setCloseError(e error) { } // Addr returns the server's network address -func (s *server) Addr() net.Addr { +func (s *baseServer) Addr() net.Addr { return s.conn.LocalAddr() } -func (s *server) handlePacket(p *receivedPacket) { +func (s *baseServer) handlePacket(p *receivedPacket) { go func() { if shouldReleaseBuffer := s.handlePacketImpl(p); !shouldReleaseBuffer { p.buffer.Release() @@ -308,7 +346,7 @@ func (s *server) handlePacket(p *receivedPacket) { }() } -func (s *server) handlePacketImpl(p *receivedPacket) bool /* was the packet passed on to a session */ { +func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* was the packet passed on to a session */ { if len(p.data) < protocol.MinInitialPacketSize { s.logger.Debugf("Dropping a packet that is too small to be a valid Initial (%d bytes)", len(p.data)) return false @@ -352,7 +390,7 @@ func (s *server) handlePacketImpl(p *receivedPacket) bool /* was the packet pass return true } -func (s *server) handleInitialImpl(p *receivedPacket, hdr *wire.Header) (quicSession, protocol.ConnectionID, error) { +func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) (quicSession, protocol.ConnectionID, error) { if len(hdr.Token) == 0 && hdr.DestConnectionID.Len() < protocol.MinConnectionIDLenInitial { return nil, nil, errors.New("too short connection ID") } @@ -402,7 +440,7 @@ func (s *server) handleInitialImpl(p *receivedPacket, hdr *wire.Header) (quicSes return sess, connID, nil } -func (s *server) createNewSession( +func (s *baseServer) createNewSession( remoteAddr net.Addr, origDestConnID protocol.ConnectionID, clientDestConnID protocol.ConnectionID, @@ -442,16 +480,26 @@ func (s *server) createNewSession( return nil, err } go sess.run() - go s.waitUntilHandshakeComplete(sess) + go s.handleNewSession(sess) return sess, nil } -func (s *server) waitUntilHandshakeComplete(sess quicSession) { +func (s *baseServer) handleNewSession(sess quicSession) { sessCtx := sess.Context() - select { - case <-sess.HandshakeComplete().Done(): - case <-sessCtx.Done(): - return + if s.acceptEarlySessions { + // wait until the early session is ready (or the handshake fails) + select { + case <-sess.earlySessionReady(): + case <-sessCtx.Done(): + return + } + } else { + // wait until the handshake is complete (or fails) + select { + case <-sess.HandshakeComplete().Done(): + case <-sessCtx.Done(): + return + } } atomic.AddInt32(&s.sessionQueueLen, 1) @@ -464,7 +512,7 @@ func (s *server) waitUntilHandshakeComplete(sess quicSession) { } } -func (s *server) sendRetry(remoteAddr net.Addr, hdr *wire.Header) error { +func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header) error { token, err := s.tokenGenerator.NewRetryToken(remoteAddr, hdr.DestConnectionID) if err != nil { return err @@ -494,7 +542,7 @@ func (s *server) sendRetry(remoteAddr net.Addr, hdr *wire.Header) error { return nil } -func (s *server) sendServerBusy(remoteAddr net.Addr, hdr *wire.Header) error { +func (s *baseServer) sendServerBusy(remoteAddr net.Addr, hdr *wire.Header) error { sealer, _, err := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveServer) if err != nil { return err @@ -541,7 +589,7 @@ func (s *server) sendServerBusy(remoteAddr net.Addr, hdr *wire.Header) error { return nil } -func (s *server) sendVersionNegotiationPacket(p *receivedPacket, hdr *wire.Header) { +func (s *baseServer) sendVersionNegotiationPacket(p *receivedPacket, hdr *wire.Header) { s.logger.Debugf("Client offered version %s, sending Version Negotiation", hdr.Version) data, err := wire.ComposeVersionNegotiation(hdr.SrcConnectionID, hdr.DestConnectionID, s.config.Versions) if err != nil { diff --git a/server_test.go b/server_test.go index ca25b7fb..b73824bf 100644 --- a/server_test.go +++ b/server_test.go @@ -39,6 +39,12 @@ var _ = Describe("Server", func() { } } + parseHeader := func(data []byte) *wire.Header { + hdr, _, _, err := wire.ParsePacket(data, 0) + Expect(err).ToNot(HaveOccurred()) + return hdr + } + BeforeEach(func() { conn = newMockPacketConn() conn.addr = &net.UDPAddr{} @@ -61,7 +67,7 @@ var _ = Describe("Server", func() { It("fills in default values if options are not set in the Config", func() { ln, err := Listen(conn, tlsConf, &Config{}) Expect(err).ToNot(HaveOccurred()) - server := ln.(*server) + server := ln.(*baseServer) Expect(server.config.Versions).To(Equal(protocol.SupportedVersions)) Expect(server.config.HandshakeTimeout).To(Equal(protocol.DefaultHandshakeTimeout)) Expect(server.config.IdleTimeout).To(Equal(protocol.DefaultIdleTimeout)) @@ -86,7 +92,7 @@ var _ = Describe("Server", func() { } ln, err := Listen(conn, tlsConf, &config) Expect(err).ToNot(HaveOccurred()) - server := ln.(*server) + server := ln.(*baseServer) Expect(server.sessionHandler).ToNot(BeNil()) Expect(server.config.Versions).To(Equal(supportedVersions)) Expect(server.config.HandshakeTimeout).To(Equal(1337 * time.Hour)) @@ -103,8 +109,7 @@ var _ = Describe("Server", func() { addr := "127.0.0.1:13579" ln, err := ListenAddr(addr, tlsConf, &Config{}) Expect(err).ToNot(HaveOccurred()) - serv := ln.(*server) - Expect(serv.Addr().String()).To(Equal(addr)) + Expect(ln.Addr().String()).To(Equal(addr)) // stop the listener Expect(ln.Close()).To(Succeed()) }) @@ -121,173 +126,427 @@ var _ = Describe("Server", func() { Expect(err).To(BeAssignableToTypeOf(&net.OpError{})) }) - Context("handling packets", func() { - var serv *server + Context("server accepting sessions that completed the handshake", func() { + var serv *baseServer BeforeEach(func() { ln, err := Listen(conn, tlsConf, nil) Expect(err).ToNot(HaveOccurred()) - serv = ln.(*server) + serv = ln.(*baseServer) }) - parseHeader := func(data []byte) *wire.Header { - hdr, _, _, err := wire.ParsePacket(data, 0) + Context("handling packets", func() { + It("drops Initial packets with a too short connection ID", func() { + serv.handlePacket(getPacket(&wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + Version: serv.config.Versions[0], + }, nil)) + Consistently(conn.dataWritten).ShouldNot(Receive()) + }) + + It("drops too small Initial", func() { + serv.handlePacket(getPacket(&wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + Version: serv.config.Versions[0], + }, make([]byte, protocol.MinInitialPacketSize-100), + )) + Consistently(conn.dataWritten).ShouldNot(Receive()) + }) + + It("drops packets with a too short connection ID", func() { + serv.handlePacket(getPacket(&wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + Version: serv.config.Versions[0], + }, make([]byte, protocol.MinInitialPacketSize))) + Consistently(conn.dataWritten).ShouldNot(Receive()) + }) + + It("drops non-Initial packets", func() { + serv.handlePacket(getPacket( + &wire.Header{ + Type: protocol.PacketTypeHandshake, + Version: serv.config.Versions[0], + }, + []byte("invalid"), + )) + }) + + It("decodes the token from the Token field", func() { + raddr := &net.UDPAddr{ + IP: net.IPv4(192, 168, 13, 37), + Port: 1337, + } + done := make(chan struct{}) + serv.config.AcceptToken = func(addr net.Addr, token *Token) bool { + Expect(addr).To(Equal(raddr)) + Expect(token).ToNot(BeNil()) + close(done) + return false + } + token, err := serv.tokenGenerator.NewRetryToken(raddr, nil) + Expect(err).ToNot(HaveOccurred()) + packet := getPacket(&wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + Token: token, + Version: serv.config.Versions[0], + }, make([]byte, protocol.MinInitialPacketSize)) + packet.remoteAddr = raddr + serv.handlePacket(packet) + Eventually(done).Should(BeClosed()) + }) + + It("passes an empty token to the callback, if decoding fails", func() { + raddr := &net.UDPAddr{ + IP: net.IPv4(192, 168, 13, 37), + Port: 1337, + } + done := make(chan struct{}) + serv.config.AcceptToken = func(addr net.Addr, token *Token) bool { + Expect(addr).To(Equal(raddr)) + Expect(token).To(BeNil()) + close(done) + return false + } + packet := getPacket(&wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + Token: []byte("foobar"), + Version: serv.config.Versions[0], + }, make([]byte, protocol.MinInitialPacketSize)) + packet.remoteAddr = raddr + serv.handlePacket(packet) + Eventually(done).Should(BeClosed()) + }) + + It("sends a Version Negotiation Packet for unsupported versions", func() { + srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5} + destConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6} + packet := getPacket(&wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + SrcConnectionID: srcConnID, + DestConnectionID: destConnID, + Version: 0x42, + }, make([]byte, protocol.MinInitialPacketSize)) + packet.remoteAddr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} + serv.handlePacket(packet) + var write mockPacketConnWrite + Eventually(conn.dataWritten).Should(Receive(&write)) + Expect(write.to.String()).To(Equal("127.0.0.1:1337")) + Expect(wire.IsVersionNegotiationPacket(write.data)).To(BeTrue()) + hdr := parseHeader(write.data) + Expect(hdr.DestConnectionID).To(Equal(srcConnID)) + Expect(hdr.SrcConnectionID).To(Equal(destConnID)) + Expect(hdr.SupportedVersions).ToNot(ContainElement(protocol.VersionNumber(0x42))) + }) + + It("replies with a Retry packet, if a Token is required", func() { + serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return false } + hdr := &wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + Version: protocol.VersionTLS, + } + packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) + packet.remoteAddr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} + serv.handlePacket(packet) + var write mockPacketConnWrite + Eventually(conn.dataWritten).Should(Receive(&write)) + Expect(write.to.String()).To(Equal("127.0.0.1:1337")) + replyHdr := parseHeader(write.data) + Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry)) + Expect(replyHdr.SrcConnectionID).ToNot(Equal(hdr.DestConnectionID)) + Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID)) + Expect(replyHdr.OrigDestConnectionID).To(Equal(hdr.DestConnectionID)) + Expect(replyHdr.Token).ToNot(BeEmpty()) + }) + + It("creates a session, if no Token is required", func() { + serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true } + hdr := &wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + Version: protocol.VersionTLS, + } + p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) + run := make(chan struct{}) + serv.newSession = func( + _ connection, + _ sessionRunner, + origConnID protocol.ConnectionID, + destConnID protocol.ConnectionID, + srcConnID protocol.ConnectionID, + _ *Config, + _ *tls.Config, + _ *handshake.TransportParameters, + _ *handshake.TokenGenerator, + _ utils.Logger, + _ protocol.VersionNumber, + ) (quicSession, error) { + Expect(origConnID).To(Equal(hdr.DestConnectionID)) + Expect(destConnID).To(Equal(hdr.SrcConnectionID)) + // make sure we're using a server-generated connection ID + Expect(srcConnID).ToNot(Equal(hdr.DestConnectionID)) + Expect(srcConnID).ToNot(Equal(hdr.SrcConnectionID)) + sess := NewMockQuicSession(mockCtrl) + sess.EXPECT().handlePacket(p) + sess.EXPECT().run().Do(func() { close(run) }) + sess.EXPECT().Context().Return(context.Background()) + sess.EXPECT().HandshakeComplete().Return(context.Background()) + return sess, nil + } + + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + serv.handlePacket(p) + // the Handshake packet is written by the session + Consistently(conn.dataWritten).ShouldNot(Receive()) + close(done) + }() + // make sure we're using a server-generated connection ID + Eventually(run).Should(BeClosed()) + Eventually(done).Should(BeClosed()) + }) + + It("rejects new connection attempts if the accept queue is full", func() { + serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true } + senderAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 42} + + hdr := &wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + Version: protocol.VersionTLS, + } + p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) + p.remoteAddr = senderAddr + serv.newSession = func( + _ connection, + runner sessionRunner, + _ protocol.ConnectionID, + _ protocol.ConnectionID, + _ protocol.ConnectionID, + _ *Config, + _ *tls.Config, + _ *handshake.TransportParameters, + _ *handshake.TokenGenerator, + _ utils.Logger, + _ protocol.VersionNumber, + ) (quicSession, error) { + sess := NewMockQuicSession(mockCtrl) + sess.EXPECT().handlePacket(p) + sess.EXPECT().run() + sess.EXPECT().Context().Return(context.Background()) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + sess.EXPECT().HandshakeComplete().Return(ctx) + return sess, nil + } + + var wg sync.WaitGroup + wg.Add(protocol.MaxAcceptQueueSize) + for i := 0; i < protocol.MaxAcceptQueueSize; i++ { + go func() { + defer GinkgoRecover() + defer wg.Done() + serv.handlePacket(p) + Consistently(conn.dataWritten).ShouldNot(Receive()) + }() + } + wg.Wait() + serv.handlePacket(p) + var reject mockPacketConnWrite + Eventually(conn.dataWritten).Should(Receive(&reject)) + Expect(reject.to).To(Equal(senderAddr)) + rejectHdr := parseHeader(reject.data) + Expect(rejectHdr.Type).To(Equal(protocol.PacketTypeInitial)) + Expect(rejectHdr.Version).To(Equal(hdr.Version)) + Expect(rejectHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID)) + Expect(rejectHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID)) + }) + + It("doesn't accept new sessions if they were closed in the mean time", func() { + serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true } + senderAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 42} + + hdr := &wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + Version: protocol.VersionTLS, + } + p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) + p.remoteAddr = senderAddr + ctx, cancel := context.WithCancel(context.Background()) + sessionCreated := make(chan struct{}) + sess := NewMockQuicSession(mockCtrl) + serv.newSession = func( + _ connection, + runner sessionRunner, + _ protocol.ConnectionID, + _ protocol.ConnectionID, + _ protocol.ConnectionID, + _ *Config, + _ *tls.Config, + _ *handshake.TransportParameters, + _ *handshake.TokenGenerator, + _ utils.Logger, + _ protocol.VersionNumber, + ) (quicSession, error) { + sess.EXPECT().handlePacket(p) + sess.EXPECT().run() + sess.EXPECT().Context().Return(ctx) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + sess.EXPECT().HandshakeComplete().Return(ctx) + close(sessionCreated) + return sess, nil + } + + serv.handlePacket(p) + Consistently(conn.dataWritten).ShouldNot(Receive()) + Eventually(sessionCreated).Should(BeClosed()) + cancel() + time.Sleep(scaleDuration(200 * time.Millisecond)) + + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + serv.Accept(context.Background()) + close(done) + }() + Consistently(done).ShouldNot(BeClosed()) + + // make the go routine return + sess.EXPECT().getPerspective() + Expect(serv.Close()).To(Succeed()) + Eventually(done).Should(BeClosed()) + }) + }) + + Context("accepting sessions", func() { + It("returns Accept when an error occurs", func() { + testErr := errors.New("test err") + + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := serv.Accept(context.Background()) + Expect(err).To(MatchError(testErr)) + close(done) + }() + + serv.setCloseError(testErr) + Eventually(done).Should(BeClosed()) + }) + + It("returns immediately, if an error occurred before", func() { + testErr := errors.New("test err") + serv.setCloseError(testErr) + for i := 0; i < 3; i++ { + _, err := serv.Accept(context.Background()) + Expect(err).To(MatchError(testErr)) + } + }) + + It("returns when the context is canceled", func() { + ctx, cancel := context.WithCancel(context.Background()) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := serv.Accept(ctx) + Expect(err).To(MatchError("context canceled")) + close(done) + }() + + Consistently(done).ShouldNot(BeClosed()) + cancel() + Eventually(done).Should(BeClosed()) + }) + + It("accepts new sessions when the handshake completes", func() { + sess := NewMockQuicSession(mockCtrl) + + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + s, err := serv.Accept(context.Background()) + Expect(err).ToNot(HaveOccurred()) + Expect(s).To(Equal(sess)) + close(done) + }() + + ctx, cancel := context.WithCancel(context.Background()) // handshake context + serv.newSession = func( + _ connection, + runner sessionRunner, + _ protocol.ConnectionID, + _ protocol.ConnectionID, + _ protocol.ConnectionID, + _ *Config, + _ *tls.Config, + _ *handshake.TransportParameters, + _ *handshake.TokenGenerator, + _ utils.Logger, + _ protocol.VersionNumber, + ) (quicSession, error) { + sess.EXPECT().HandshakeComplete().Return(ctx) + sess.EXPECT().run().Do(func() {}) + sess.EXPECT().Context().Return(context.Background()) + return sess, nil + } + _, err := serv.createNewSession(&net.UDPAddr{}, nil, nil, nil, nil, protocol.VersionWhatever) + Expect(err).ToNot(HaveOccurred()) + Consistently(done).ShouldNot(BeClosed()) + cancel() // complete the handshake + Eventually(done).Should(BeClosed()) + }) + }) + }) + + Context("server accepting sessions that haven't completed the handshake", func() { + var serv *earlyServer + + BeforeEach(func() { + ln, err := ListenEarly(conn, tlsConf, nil) Expect(err).ToNot(HaveOccurred()) - return hdr - } - - It("drops Initial packets with a too short connection ID", func() { - serv.handlePacket(getPacket(&wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeInitial, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4}, - Version: serv.config.Versions[0], - }, nil)) - Consistently(conn.dataWritten).ShouldNot(Receive()) + serv = ln.(*earlyServer) }) - It("drops too small Initial", func() { - serv.handlePacket(getPacket(&wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeInitial, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, - Version: serv.config.Versions[0], - }, make([]byte, protocol.MinInitialPacketSize-100), - )) - Consistently(conn.dataWritten).ShouldNot(Receive()) - }) + It("accepts new sessions when they become ready", func() { + sess := NewMockQuicSession(mockCtrl) - It("drops packets with a too short connection ID", func() { - serv.handlePacket(getPacket(&wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeInitial, - SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4}, - Version: serv.config.Versions[0], - }, make([]byte, protocol.MinInitialPacketSize))) - Consistently(conn.dataWritten).ShouldNot(Receive()) - }) - - It("drops non-Initial packets", func() { - serv.handlePacket(getPacket( - &wire.Header{ - Type: protocol.PacketTypeHandshake, - Version: serv.config.Versions[0], - }, - []byte("invalid"), - )) - }) - - It("decodes the token from the Token field", func() { - raddr := &net.UDPAddr{ - IP: net.IPv4(192, 168, 13, 37), - Port: 1337, - } done := make(chan struct{}) - serv.config.AcceptToken = func(addr net.Addr, token *Token) bool { - Expect(addr).To(Equal(raddr)) - Expect(token).ToNot(BeNil()) + go func() { + defer GinkgoRecover() + s, err := serv.Accept(context.Background()) + Expect(err).ToNot(HaveOccurred()) + Expect(s).To(Equal(sess)) close(done) - return false - } - token, err := serv.tokenGenerator.NewRetryToken(raddr, nil) - Expect(err).ToNot(HaveOccurred()) - packet := getPacket(&wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeInitial, - Token: token, - Version: serv.config.Versions[0], - }, make([]byte, protocol.MinInitialPacketSize)) - packet.remoteAddr = raddr - serv.handlePacket(packet) - Eventually(done).Should(BeClosed()) - }) + }() - It("passes an empty token to the callback, if decoding fails", func() { - raddr := &net.UDPAddr{ - IP: net.IPv4(192, 168, 13, 37), - Port: 1337, - } - done := make(chan struct{}) - serv.config.AcceptToken = func(addr net.Addr, token *Token) bool { - Expect(addr).To(Equal(raddr)) - Expect(token).To(BeNil()) - close(done) - return false - } - packet := getPacket(&wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeInitial, - Token: []byte("foobar"), - Version: serv.config.Versions[0], - }, make([]byte, protocol.MinInitialPacketSize)) - packet.remoteAddr = raddr - serv.handlePacket(packet) - Eventually(done).Should(BeClosed()) - }) - - It("sends a Version Negotiation Packet for unsupported versions", func() { - srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5} - destConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6} - packet := getPacket(&wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeInitial, - SrcConnectionID: srcConnID, - DestConnectionID: destConnID, - Version: 0x42, - }, make([]byte, protocol.MinInitialPacketSize)) - packet.remoteAddr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} - serv.handlePacket(packet) - var write mockPacketConnWrite - Eventually(conn.dataWritten).Should(Receive(&write)) - Expect(write.to.String()).To(Equal("127.0.0.1:1337")) - Expect(wire.IsVersionNegotiationPacket(write.data)).To(BeTrue()) - hdr := parseHeader(write.data) - Expect(hdr.DestConnectionID).To(Equal(srcConnID)) - Expect(hdr.SrcConnectionID).To(Equal(destConnID)) - Expect(hdr.SupportedVersions).ToNot(ContainElement(protocol.VersionNumber(0x42))) - }) - - It("replies with a Retry packet, if a Token is required", func() { - serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return false } - hdr := &wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeInitial, - SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, - Version: protocol.VersionTLS, - } - packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) - packet.remoteAddr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} - serv.handlePacket(packet) - var write mockPacketConnWrite - Eventually(conn.dataWritten).Should(Receive(&write)) - Expect(write.to.String()).To(Equal("127.0.0.1:1337")) - replyHdr := parseHeader(write.data) - Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry)) - Expect(replyHdr.SrcConnectionID).ToNot(Equal(hdr.DestConnectionID)) - Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID)) - Expect(replyHdr.OrigDestConnectionID).To(Equal(hdr.DestConnectionID)) - Expect(replyHdr.Token).ToNot(BeEmpty()) - }) - - It("creates a session, if no Token is required", func() { - serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true } - hdr := &wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeInitial, - SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, - Version: protocol.VersionTLS, - } - p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) - run := make(chan struct{}) + ready := make(chan struct{}) serv.newSession = func( _ connection, - _ sessionRunner, - origConnID protocol.ConnectionID, - destConnID protocol.ConnectionID, - srcConnID protocol.ConnectionID, + runner sessionRunner, + _ protocol.ConnectionID, + _ protocol.ConnectionID, + _ protocol.ConnectionID, _ *Config, _ *tls.Config, _ *handshake.TransportParameters, @@ -295,29 +554,15 @@ var _ = Describe("Server", func() { _ utils.Logger, _ protocol.VersionNumber, ) (quicSession, error) { - Expect(origConnID).To(Equal(hdr.DestConnectionID)) - Expect(destConnID).To(Equal(hdr.SrcConnectionID)) - // make sure we're using a server-generated connection ID - Expect(srcConnID).ToNot(Equal(hdr.DestConnectionID)) - Expect(srcConnID).ToNot(Equal(hdr.SrcConnectionID)) - sess := NewMockQuicSession(mockCtrl) - sess.EXPECT().handlePacket(p) - sess.EXPECT().run().Do(func() { close(run) }) + sess.EXPECT().run().Do(func() {}) + sess.EXPECT().earlySessionReady().Return(ready) sess.EXPECT().Context().Return(context.Background()) - sess.EXPECT().HandshakeComplete().Return(context.Background()) return sess, nil } - - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - serv.handlePacket(p) - // the Handshake packet is written by the session - Consistently(conn.dataWritten).ShouldNot(Receive()) - close(done) - }() - // make sure we're using a server-generated connection ID - Eventually(run).Should(BeClosed()) + _, err := serv.createNewSession(&net.UDPAddr{}, nil, nil, nil, nil, protocol.VersionWhatever) + Expect(err).ToNot(HaveOccurred()) + Consistently(done).ShouldNot(BeClosed()) + close(ready) Eventually(done).Should(BeClosed()) }) @@ -347,13 +592,13 @@ var _ = Describe("Server", func() { _ utils.Logger, _ protocol.VersionNumber, ) (quicSession, error) { + ready := make(chan struct{}) + close(ready) sess := NewMockQuicSession(mockCtrl) sess.EXPECT().handlePacket(p) sess.EXPECT().run() + sess.EXPECT().earlySessionReady().Return(ready) sess.EXPECT().Context().Return(context.Background()) - ctx, cancel := context.WithCancel(context.Background()) - cancel() - sess.EXPECT().HandshakeComplete().Return(ctx) return sess, nil } @@ -410,10 +655,8 @@ var _ = Describe("Server", func() { ) (quicSession, error) { sess.EXPECT().handlePacket(p) sess.EXPECT().run() + sess.EXPECT().earlySessionReady() sess.EXPECT().Context().Return(ctx) - ctx, cancel := context.WithCancel(context.Background()) - cancel() - sess.EXPECT().HandshakeComplete().Return(ctx) close(sessionCreated) return sess, nil } @@ -438,93 +681,6 @@ var _ = Describe("Server", func() { Eventually(done).Should(BeClosed()) }) }) - - Context("accepting sessions", func() { - var serv *server - - BeforeEach(func() { - ln, err := Listen(conn, tlsConf, nil) - Expect(err).ToNot(HaveOccurred()) - serv = ln.(*server) - }) - - It("returns Accept when an error occurs", func() { - testErr := errors.New("test err") - - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - _, err := serv.Accept(context.Background()) - Expect(err).To(MatchError(testErr)) - close(done) - }() - - serv.setCloseError(testErr) - Eventually(done).Should(BeClosed()) - }) - - It("returns immediately, if an error occurred before", func() { - testErr := errors.New("test err") - serv.setCloseError(testErr) - for i := 0; i < 3; i++ { - _, err := serv.Accept(context.Background()) - Expect(err).To(MatchError(testErr)) - } - }) - - It("returns when the context is canceled", func() { - ctx, cancel := context.WithCancel(context.Background()) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - _, err := serv.Accept(ctx) - Expect(err).To(MatchError("context canceled")) - close(done) - }() - - Consistently(done).ShouldNot(BeClosed()) - cancel() - Eventually(done).Should(BeClosed()) - }) - - It("accepts new sessions when the handshake completes", func() { - sess := NewMockQuicSession(mockCtrl) - - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - s, err := serv.Accept(context.Background()) - Expect(err).ToNot(HaveOccurred()) - Expect(s).To(Equal(sess)) - close(done) - }() - - ctx, cancel := context.WithCancel(context.Background()) // handshake context - serv.newSession = func( - _ connection, - runner sessionRunner, - _ protocol.ConnectionID, - _ protocol.ConnectionID, - _ protocol.ConnectionID, - _ *Config, - _ *tls.Config, - _ *handshake.TransportParameters, - _ *handshake.TokenGenerator, - _ utils.Logger, - _ protocol.VersionNumber, - ) (quicSession, error) { - sess.EXPECT().HandshakeComplete().Return(ctx) - sess.EXPECT().run().Do(func() {}) - sess.EXPECT().Context().Return(context.Background()) - return sess, nil - } - _, err := serv.createNewSession(&net.UDPAddr{}, nil, nil, nil, nil, protocol.VersionWhatever) - Expect(err).ToNot(HaveOccurred()) - Consistently(done).ShouldNot(BeClosed()) - cancel() // complete the handshake - Eventually(done).Should(BeClosed()) - }) - }) }) var _ = Describe("default source address verification", func() { diff --git a/session.go b/session.go index 6030c303..998e2dca 100644 --- a/session.go +++ b/session.go @@ -171,6 +171,7 @@ type session struct { } var _ Session = &session{} +var _ EarlySession = &session{} var _ streamSender = &session{} var newSession = func( From 8215e70de0358157ee73a33d9dc796292775bf47 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 27 Jul 2019 01:59:09 -0400 Subject: [PATCH 6/6] add an integration test for 0.5-RTT data --- integrationtests/self/early_data_test.go | 72 ++++++++++++++++++++++++ 1 file changed, 72 insertions(+) create mode 100644 integrationtests/self/early_data_test.go diff --git a/integrationtests/self/early_data_test.go b/integrationtests/self/early_data_test.go new file mode 100644 index 00000000..6b2b057e --- /dev/null +++ b/integrationtests/self/early_data_test.go @@ -0,0 +1,72 @@ +package self_test + +import ( + "context" + "fmt" + "io/ioutil" + "net" + "time" + + quic "github.com/lucas-clemente/quic-go" + quicproxy "github.com/lucas-clemente/quic-go/integrationtests/tools/proxy" + "github.com/lucas-clemente/quic-go/internal/protocol" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("early data", func() { + const rtt = 80 * time.Millisecond + for _, v := range protocol.SupportedVersions { + version := v + + Context(fmt.Sprintf("with QUIC version %s", version), func() { + It("sends 0.5-RTT data", func() { + ln, err := quic.ListenAddrEarly( + "localhost:0", + getTLSConfig(), + &quic.Config{Versions: []protocol.VersionNumber{version}}, + ) + Expect(err).ToNot(HaveOccurred()) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + sess, err := ln.Accept(context.Background()) + Expect(err).ToNot(HaveOccurred()) + str, err := sess.OpenUniStream() + Expect(err).ToNot(HaveOccurred()) + _, err = str.Write([]byte("early data")) + Expect(err).ToNot(HaveOccurred()) + Expect(str.Close()).To(Succeed()) + // make sure the Write finished before the handshake completed + Expect(sess.HandshakeComplete().Done()).ToNot(BeClosed()) + Eventually(sess.HandshakeComplete().Done()).Should(BeClosed()) + close(done) + }() + serverPort := ln.Addr().(*net.UDPAddr).Port + proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ + RemoteAddr: fmt.Sprintf("localhost:%d", serverPort), + DelayPacket: func(quicproxy.Direction, []byte) time.Duration { + return rtt / 2 + }, + }) + Expect(err).ToNot(HaveOccurred()) + defer proxy.Close() + + sess, err := quic.DialAddr( + fmt.Sprintf("localhost:%d", proxy.LocalPort()), + getTLSClientConfig(), + &quic.Config{Versions: []protocol.VersionNumber{version}}, + ) + Expect(err).ToNot(HaveOccurred()) + str, err := sess.AcceptUniStream(context.Background()) + Expect(err).ToNot(HaveOccurred()) + data, err := ioutil.ReadAll(str) + Expect(err).ToNot(HaveOccurred()) + Expect(data).To(Equal([]byte("early data"))) + sess.Close() + Eventually(done).Should(BeClosed()) + }) + }) + } +})