diff --git a/client.go b/client.go index fc6ed290..5dac6a3a 100644 --- a/client.go +++ b/client.go @@ -292,7 +292,7 @@ func (c *client) establishSecureConnection(ctx context.Context) error { return ctx.Err() case err := <-errorChan: return err - case <-c.handshakeChan: + case <-c.session.HandshakeComplete().Done(): // handshake successfully completed return nil } @@ -371,13 +371,9 @@ func (c *client) createNewTLSSession(version protocol.VersionNumber) error { c.mutex.Lock() defer c.mutex.Unlock() - runner := &runner{ - packetHandlerManager: c.packetHandlers, - onHandshakeCompleteImpl: func(_ Session) { close(c.handshakeChan) }, - } sess, err := newClientSession( c.conn, - runner, + c.packetHandlers, c.destConnID, c.srcConnID, c.config, diff --git a/client_test.go b/client_test.go index 0bc1e2ef..fb326045 100644 --- a/client_test.go +++ b/client_test.go @@ -149,6 +149,7 @@ var _ = Describe("Client", func() { remoteAddrChan <- conn.RemoteAddr().String() sess := NewMockQuicSession(mockCtrl) sess.EXPECT().run() + sess.EXPECT().HandshakeComplete().Return(context.Background()) return sess, nil } _, err := DialAddr("localhost:17890", tlsConf, &Config{HandshakeTimeout: time.Millisecond}) @@ -179,6 +180,7 @@ var _ = Describe("Client", func() { hostnameChan <- tlsConf.ServerName sess := NewMockQuicSession(mockCtrl) sess.EXPECT().run() + sess.EXPECT().HandshakeComplete().Return(context.Background()) return sess, nil } tlsConf.ServerName = "foobar" @@ -208,6 +210,7 @@ var _ = Describe("Client", func() { ) (quicSession, error) { hostnameChan <- tlsConf.ServerName sess := NewMockQuicSession(mockCtrl) + sess.EXPECT().HandshakeComplete().Return(context.Background()) sess.EXPECT().run() return sess, nil } @@ -243,7 +246,9 @@ var _ = Describe("Client", func() { ) (quicSession, error) { sess := NewMockQuicSession(mockCtrl) sess.EXPECT().run().Do(func() { close(run) }) - runner.OnHandshakeComplete(sess) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + sess.EXPECT().HandshakeComplete().Return(ctx) return sess, nil } s, err := Dial( @@ -258,7 +263,7 @@ var _ = Describe("Client", func() { Eventually(run).Should(BeClosed()) }) - It("returns an error that occurs while waiting for the connection to become secure", func() { + It("returns an error that occurs while waiting for the handshake to complete", func() { manager := NewMockPacketHandlerManager(mockCtrl) manager.EXPECT().Add(gomock.Any(), gomock.Any()) mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any()).Return(manager, nil) @@ -279,6 +284,7 @@ var _ = Describe("Client", func() { ) (quicSession, error) { sess := NewMockQuicSession(mockCtrl) sess.EXPECT().run().Return(testErr) + sess.EXPECT().HandshakeComplete().Return(context.Background()) return sess, nil } packetConn.dataToRead <- acceptClientVersionPacket(cl.srcConnID) @@ -303,6 +309,7 @@ var _ = Describe("Client", func() { sess.EXPECT().run().Do(func() { <-sessionRunning }) + sess.EXPECT().HandshakeComplete().Return(context.Background()) newClientSession = func( _ connection, _ sessionRunner, @@ -366,6 +373,7 @@ var _ = Describe("Client", func() { sess.EXPECT().run().Do(func() { runner.Retire(connID) }) + sess.EXPECT().HandshakeComplete().Return(context.Background()) _, err := DialContext( context.Background(), @@ -411,6 +419,7 @@ var _ = Describe("Client", func() { sess.EXPECT().run().Do(func() { <-run }) + sess.EXPECT().HandshakeComplete().Return(context.Background()) done := make(chan struct{}) go func() { @@ -497,7 +506,7 @@ var _ = Describe("Client", func() { }) }) - It("creates new TLS sessions with the right parameters", func() { + It("creates new sessions with the right parameters", func() { manager := NewMockPacketHandlerManager(mockCtrl) manager.EXPECT().Add(connID, gomock.Any()) mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any()).Return(manager, nil) @@ -527,6 +536,7 @@ var _ = Describe("Client", func() { // TODO: check connection IDs? sess := NewMockQuicSession(mockCtrl) sess.EXPECT().run() + sess.EXPECT().HandshakeComplete().Return(context.Background()) return sess, nil } _, err := Dial(packetConn, addr, "localhost:1337", tlsConf, config) @@ -571,6 +581,7 @@ var _ = Describe("Client", func() { Expect(conn.Write([]byte("0 fake CHLO"))).To(Succeed()) sess := NewMockQuicSession(mockCtrl) sess.EXPECT().run().Return(testErr) + sess.EXPECT().HandshakeComplete().Return(context.Background()) return sess, nil } _, err := Dial( diff --git a/mock_quic_session_test.go b/mock_quic_session_test.go index 30bcd90b..99ba98ac 100644 --- a/mock_quic_session_test.go +++ b/mock_quic_session_test.go @@ -137,6 +137,20 @@ func (mr *MockQuicSessionMockRecorder) GetVersion() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetVersion", reflect.TypeOf((*MockQuicSession)(nil).GetVersion)) } +// HandshakeComplete mocks base method +func (m *MockQuicSession) HandshakeComplete() context.Context { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "HandshakeComplete") + ret0, _ := ret[0].(context.Context) + return ret0 +} + +// HandshakeComplete indicates an expected call of HandshakeComplete +func (mr *MockQuicSessionMockRecorder) HandshakeComplete() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandshakeComplete", reflect.TypeOf((*MockQuicSession)(nil).HandshakeComplete)) +} + // LocalAddr mocks base method func (m *MockQuicSession) LocalAddr() net.Addr { m.ctrl.T.Helper() diff --git a/mock_session_runner_test.go b/mock_session_runner_test.go index fad0e813..9bb54a05 100644 --- a/mock_session_runner_test.go +++ b/mock_session_runner_test.go @@ -46,18 +46,6 @@ func (mr *MockSessionRunnerMockRecorder) AddResetToken(arg0, arg1 interface{}) * return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddResetToken", reflect.TypeOf((*MockSessionRunner)(nil).AddResetToken), arg0, arg1) } -// OnHandshakeComplete mocks base method -func (m *MockSessionRunner) OnHandshakeComplete(arg0 Session) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "OnHandshakeComplete", arg0) -} - -// OnHandshakeComplete indicates an expected call of OnHandshakeComplete -func (mr *MockSessionRunnerMockRecorder) OnHandshakeComplete(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnHandshakeComplete", reflect.TypeOf((*MockSessionRunner)(nil).OnHandshakeComplete), arg0) -} - // Remove mocks base method func (m *MockSessionRunner) Remove(arg0 protocol.ConnectionID) { m.ctrl.T.Helper() diff --git a/server.go b/server.go index 66eab853..c0269ae4 100644 --- a/server.go +++ b/server.go @@ -46,6 +46,7 @@ type packetHandlerManager interface { type quicSession interface { Session + HandshakeComplete() context.Context handlePacket(*receivedPacket) GetVersion() protocol.VersionNumber getPerspective() protocol.Perspective @@ -56,23 +57,12 @@ type quicSession interface { } type sessionRunner interface { - OnHandshakeComplete(Session) Retire(protocol.ConnectionID) Remove(protocol.ConnectionID) AddResetToken([16]byte, packetHandler) RemoveResetToken([16]byte) } -type runner struct { - packetHandlerManager - - onHandshakeCompleteImpl func(Session) -} - -func (r *runner) OnHandshakeComplete(s Session) { r.onHandshakeCompleteImpl(s) } - -var _ sessionRunner = &runner{} - // A Listener of QUIC type server struct { mutex sync.Mutex @@ -99,8 +89,6 @@ type server struct { sessionQueue chan Session sessionQueueLen int32 // to be used as an atomic - sessionRunner sessionRunner - logger utils.Logger } @@ -169,21 +157,6 @@ func listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (*server, newSession: newSession, logger: utils.DefaultLogger.WithPrefix("server"), } - s.sessionRunner = &runner{ - packetHandlerManager: s.sessionHandler, - onHandshakeCompleteImpl: func(sess Session) { - go func() { - atomic.AddInt32(&s.sessionQueueLen, 1) - select { - case s.sessionQueue <- sess: - // blocks until the session is accepted - case <-sess.Context().Done(): - atomic.AddInt32(&s.sessionQueueLen, -1) - // don't pass sessions that were already closed to Accept() - } - }() - }, - } sessionHandler.SetServer(s) s.logger.Debugf("Listening for %s connections on %s", conn.LocalAddr().Network(), conn.LocalAddr().String()) return s, nil @@ -454,7 +427,7 @@ func (s *server) createNewSession( } sess, err := s.newSession( &conn{pconn: s.conn, currentAddr: remoteAddr}, - s.sessionRunner, + s.sessionHandler, clientDestConnID, destConnID, srcConnID, @@ -469,9 +442,28 @@ func (s *server) createNewSession( return nil, err } go sess.run() + go s.waitUntilHandshakeComplete(sess) return sess, nil } +func (s *server) waitUntilHandshakeComplete(sess quicSession) { + sessCtx := sess.Context() + select { + case <-sess.HandshakeComplete().Done(): + case <-sessCtx.Done(): + return + } + + atomic.AddInt32(&s.sessionQueueLen, 1) + select { + case s.sessionQueue <- sess: + // blocks until the session is accepted + case <-sessCtx.Done(): + atomic.AddInt32(&s.sessionQueueLen, -1) + // don't pass sessions that were already closed to Accept() + } +} + func (s *server) sendRetry(remoteAddr net.Addr, hdr *wire.Header) error { token, err := s.tokenGenerator.NewRetryToken(remoteAddr, hdr.DestConnectionID) if err != nil { diff --git a/server_test.go b/server_test.go index caa5935b..ca25b7fb 100644 --- a/server_test.go +++ b/server_test.go @@ -303,6 +303,8 @@ var _ = Describe("Server", func() { sess := NewMockQuicSession(mockCtrl) sess.EXPECT().handlePacket(p) sess.EXPECT().run().Do(func() { close(run) }) + sess.EXPECT().Context().Return(context.Background()) + sess.EXPECT().HandshakeComplete().Return(context.Background()) return sess, nil } @@ -349,7 +351,9 @@ var _ = Describe("Server", func() { sess.EXPECT().handlePacket(p) sess.EXPECT().run() sess.EXPECT().Context().Return(context.Background()) - runner.OnHandshakeComplete(sess) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + sess.EXPECT().HandshakeComplete().Return(ctx) return sess, nil } @@ -407,7 +411,9 @@ var _ = Describe("Server", func() { sess.EXPECT().handlePacket(p) sess.EXPECT().run() sess.EXPECT().Context().Return(ctx) - runner.OnHandshakeComplete(sess) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + sess.EXPECT().HandshakeComplete().Return(ctx) close(sessionCreated) return sess, nil } @@ -493,7 +499,7 @@ var _ = Describe("Server", func() { close(done) }() - completeHandshake := make(chan struct{}) + ctx, cancel := context.WithCancel(context.Background()) // handshake context serv.newSession = func( _ connection, runner sessionRunner, @@ -507,10 +513,7 @@ var _ = Describe("Server", func() { _ utils.Logger, _ protocol.VersionNumber, ) (quicSession, error) { - go func() { - <-completeHandshake - runner.OnHandshakeComplete(sess) - }() + sess.EXPECT().HandshakeComplete().Return(ctx) sess.EXPECT().run().Do(func() {}) sess.EXPECT().Context().Return(context.Background()) return sess, nil @@ -518,44 +521,9 @@ var _ = Describe("Server", func() { _, err := serv.createNewSession(&net.UDPAddr{}, nil, nil, nil, nil, protocol.VersionWhatever) Expect(err).ToNot(HaveOccurred()) Consistently(done).ShouldNot(BeClosed()) - close(completeHandshake) + cancel() // complete the handshake Eventually(done).Should(BeClosed()) }) - - It("never blocks when calling the onHandshakeComplete callback", func() { - const num = 50 - - runs := make(chan struct{}, num) - contexts := make(chan struct{}, num) - serv.newSession = func( - _ connection, - runner sessionRunner, - _ protocol.ConnectionID, - _ protocol.ConnectionID, - _ protocol.ConnectionID, - _ *Config, - _ *tls.Config, - _ *handshake.TransportParameters, - _ *handshake.TokenGenerator, - _ utils.Logger, - _ protocol.VersionNumber, - ) (quicSession, error) { - sess := NewMockQuicSession(mockCtrl) - sess.EXPECT().run().Do(func() { runs <- struct{}{} }) - sess.EXPECT().Context().Do(func() { contexts <- struct{}{} }).Return(context.Background()) - runner.OnHandshakeComplete(sess) - return sess, nil - } - - go func() { - for i := 0; i < num; i++ { - _, err := serv.createNewSession(&net.UDPAddr{}, nil, nil, nil, nil, protocol.VersionWhatever) - Expect(err).ToNot(HaveOccurred()) - } - }() - Eventually(runs).Should(HaveLen(num)) - Eventually(contexts).Should(HaveLen(num)) - }) }) }) diff --git a/session.go b/session.go index 845e98e2..793b9d8f 100644 --- a/session.go +++ b/session.go @@ -513,7 +513,6 @@ func (s *session) handleHandshakeComplete() { s.handshakeComplete = true s.handshakeCompleteChan = nil // prevent this case from ever being selected again s.handshakeCtxCancel() - s.sessionRunner.OnHandshakeComplete(s) // The client completes the handshake first (after sending the CFIN). // We need to make sure it learns about the server completing the handshake, diff --git a/session_test.go b/session_test.go index f46b7581..1a92f429 100644 --- a/session_test.go +++ b/session_test.go @@ -1154,7 +1154,6 @@ var _ = Describe("Session", func() { }) It("cancels the HandshakeComplete context when the handshake completes", func() { - sessionRunner.EXPECT().OnHandshakeComplete(gomock.Any()) packer.EXPECT().PackPacket().AnyTimes() finishHandshake := make(chan struct{}) go func() { @@ -1197,17 +1196,14 @@ var _ = Describe("Session", func() { It("sends a 1-RTT packet when the handshake completes", func() { done := make(chan struct{}) - gomock.InOrder( - sessionRunner.EXPECT().OnHandshakeComplete(gomock.Any()), - packer.EXPECT().PackPacket().DoAndReturn(func() (*packedPacket, error) { - defer close(done) - return &packedPacket{ - header: &wire.ExtendedHeader{}, - buffer: getPacketBuffer(), - }, nil - }), - packer.EXPECT().PackPacket().AnyTimes(), - ) + packer.EXPECT().PackPacket().DoAndReturn(func() (*packedPacket, error) { + defer close(done) + return &packedPacket{ + header: &wire.ExtendedHeader{}, + buffer: getPacketBuffer(), + }, nil + }) + packer.EXPECT().PackPacket().AnyTimes() go func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake() @@ -1451,7 +1447,6 @@ var _ = Describe("Session", func() { done := make(chan struct{}) go func() { defer GinkgoRecover() - sessionRunner.EXPECT().OnHandshakeComplete(sess) cryptoSetup.EXPECT().RunHandshake() close(sess.handshakeCompleteChan) err := sess.run()