From 65cca7b7262ecbcaa3c7cccbf39cd23cef3f7235 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Thu, 15 Nov 2018 19:14:34 +0700 Subject: [PATCH] immediately remove sessions that were closed remotely We only need to keep sessions that were closed locally, in order to retransmit the CONNECTION_CLOSE packet. For sessions that were closed by the peer, we don't need to wait for any more packets, and there's also no CONNECTION_CLOSE to retransmit. The same applies for sessions that were destroyed when receiving a Version Negotiation or a Retry packet. --- client.go | 1 + mock_packet_handler_manager_test.go | 10 ++++++++++ mock_session_runner_test.go | 10 ++++++++++ packet_handler_map.go | 6 ++++++ packet_handler_map_test.go | 10 +++++++++- server.go | 5 +++++ session.go | 4 +++- session_test.go | 4 ++-- 8 files changed, 46 insertions(+), 4 deletions(-) diff --git a/client.go b/client.go index f3bad6a00..c5d667322 100644 --- a/client.go +++ b/client.go @@ -407,6 +407,7 @@ func (c *client) createNewTLSSession(version protocol.VersionNumber) error { runner := &runner{ onHandshakeCompleteImpl: func(_ Session) { close(c.handshakeChan) }, retireConnectionIDImpl: c.packetHandlers.Retire, + removeConnectionIDImpl: c.packetHandlers.Remove, } sess, err := newClientSession( c.conn, diff --git a/mock_packet_handler_manager_test.go b/mock_packet_handler_manager_test.go index 8d4fd1c95..b1906b35a 100644 --- a/mock_packet_handler_manager_test.go +++ b/mock_packet_handler_manager_test.go @@ -54,6 +54,16 @@ func (mr *MockPacketHandlerManagerMockRecorder) CloseServer() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseServer", reflect.TypeOf((*MockPacketHandlerManager)(nil).CloseServer)) } +// Remove mocks base method +func (m *MockPacketHandlerManager) Remove(arg0 protocol.ConnectionID) { + m.ctrl.Call(m, "Remove", arg0) +} + +// Remove indicates an expected call of Remove +func (mr *MockPacketHandlerManagerMockRecorder) Remove(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockPacketHandlerManager)(nil).Remove), arg0) +} + // Retire mocks base method func (m *MockPacketHandlerManager) Retire(arg0 protocol.ConnectionID) { m.ctrl.Call(m, "Retire", arg0) diff --git a/mock_session_runner_test.go b/mock_session_runner_test.go index 8bfccf67c..aeba646e2 100644 --- a/mock_session_runner_test.go +++ b/mock_session_runner_test.go @@ -44,6 +44,16 @@ func (mr *MockSessionRunnerMockRecorder) onHandshakeComplete(arg0 interface{}) * return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "onHandshakeComplete", reflect.TypeOf((*MockSessionRunner)(nil).onHandshakeComplete), arg0) } +// removeConnectionID mocks base method +func (m *MockSessionRunner) removeConnectionID(arg0 protocol.ConnectionID) { + m.ctrl.Call(m, "removeConnectionID", arg0) +} + +// removeConnectionID indicates an expected call of removeConnectionID +func (mr *MockSessionRunnerMockRecorder) removeConnectionID(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "removeConnectionID", reflect.TypeOf((*MockSessionRunner)(nil).removeConnectionID), arg0) +} + // retireConnectionID mocks base method func (m *MockSessionRunner) retireConnectionID(arg0 protocol.ConnectionID) { m.ctrl.Call(m, "retireConnectionID", arg0) diff --git a/packet_handler_map.go b/packet_handler_map.go index 05ad3e0fc..7d6150769 100644 --- a/packet_handler_map.go +++ b/packet_handler_map.go @@ -51,6 +51,12 @@ func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler) h.mutex.Unlock() } +func (h *packetHandlerMap) Remove(id protocol.ConnectionID) { + h.mutex.Lock() + delete(h.handlers, string(id)) + h.mutex.Unlock() +} + func (h *packetHandlerMap) Retire(id protocol.ConnectionID) { h.retireByConnectionIDAsString(string(id)) } diff --git a/packet_handler_map_test.go b/packet_handler_map_test.go index 13d5b2d7f..e3d17eb4d 100644 --- a/packet_handler_map_test.go +++ b/packet_handler_map_test.go @@ -88,7 +88,15 @@ var _ = Describe("Packet Handler Map", func() { Expect(err.Error()).To(ContainSubstring("error parsing invariant header:")) }) - It("deletes closed session entries after a wait time", func() { + It("deletes removed session immediately", func() { + handler.deleteRetiredSessionsAfter = time.Hour + connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + handler.Add(connID, NewMockPacketHandler(mockCtrl)) + handler.Remove(connID) + Expect(handler.handlePacket(nil, getPacket(connID))).To(MatchError("received a packet with an unexpected connection ID 0x0102030405060708")) + }) + + It("deletes retired session entries after a wait time", func() { handler.deleteRetiredSessionsAfter = 10 * time.Millisecond connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} handler.Add(connID, NewMockPacketHandler(mockCtrl)) diff --git a/server.go b/server.go index 20a58a415..3a5c83709 100644 --- a/server.go +++ b/server.go @@ -33,6 +33,7 @@ type unknownPacketHandler interface { type packetHandlerManager interface { Add(protocol.ConnectionID, packetHandler) Retire(protocol.ConnectionID) + Remove(protocol.ConnectionID) SetServer(unknownPacketHandler) CloseServer() } @@ -49,15 +50,18 @@ type quicSession interface { type sessionRunner interface { onHandshakeComplete(Session) retireConnectionID(protocol.ConnectionID) + removeConnectionID(protocol.ConnectionID) } type runner struct { onHandshakeCompleteImpl func(Session) retireConnectionIDImpl func(protocol.ConnectionID) + removeConnectionIDImpl func(protocol.ConnectionID) } func (r *runner) onHandshakeComplete(s Session) { r.onHandshakeCompleteImpl(s) } func (r *runner) retireConnectionID(c protocol.ConnectionID) { r.retireConnectionIDImpl(c) } +func (r *runner) removeConnectionID(c protocol.ConnectionID) { r.removeConnectionIDImpl(c) } var _ sessionRunner = &runner{} @@ -153,6 +157,7 @@ func (s *server) setup() error { s.sessionRunner = &runner{ onHandshakeCompleteImpl: func(sess Session) { s.sessionQueue <- sess }, retireConnectionIDImpl: s.sessionHandler.Retire, + removeConnectionIDImpl: s.sessionHandler.Remove, } cookieGenerator, err := handshake.NewCookieGenerator() if err != nil { diff --git a/session.go b/session.go index 1de985b66..c6d782907 100644 --- a/session.go +++ b/session.go @@ -424,7 +424,6 @@ runLoop: } s.closed.Set(true) s.logger.Infof("Connection %s closed.", s.srcConnID) - s.sessionRunner.retireConnectionID(s.srcConnID) s.cryptoStreamHandler.Close() return closeErr.err } @@ -718,6 +717,7 @@ func (s *session) handleAckFrame(frame *wire.AckFrame, encLevel protocol.Encrypt // closeLocal closes the session and send a CONNECTION_CLOSE containing the error func (s *session) closeLocal(e error) { s.closeOnce.Do(func() { + s.sessionRunner.retireConnectionID(s.srcConnID) s.closeChan <- closeError{err: e, sendClose: true, remote: false} }) } @@ -725,12 +725,14 @@ func (s *session) closeLocal(e error) { // destroy closes the session without sending the error on the wire func (s *session) destroy(e error) { s.closeOnce.Do(func() { + s.sessionRunner.removeConnectionID(s.srcConnID) s.closeChan <- closeError{err: e, sendClose: false, remote: false} }) } func (s *session) closeRemote(e error) { s.closeOnce.Do(func() { + s.sessionRunner.removeConnectionID(s.srcConnID) s.closeChan <- closeError{err: e, remote: true} }) } diff --git a/session_test.go b/session_test.go index d63a447a4..675e2eb9d 100644 --- a/session_test.go +++ b/session_test.go @@ -326,7 +326,7 @@ var _ = Describe("Session", func() { It("handles CONNECTION_CLOSE frames", func() { testErr := qerr.Error(qerr.ProofInvalid, "foobar") streamManager.EXPECT().CloseWithError(testErr) - sessionRunner.EXPECT().retireConnectionID(gomock.Any()) + sessionRunner.EXPECT().removeConnectionID(gomock.Any()) cryptoSetup.EXPECT().Close() go func() { @@ -402,7 +402,7 @@ var _ = Describe("Session", func() { It("closes the session in order to replace it with another QUIC version", func() { streamManager.EXPECT().CloseWithError(gomock.Any()) - sessionRunner.EXPECT().retireConnectionID(gomock.Any()) + sessionRunner.EXPECT().removeConnectionID(gomock.Any()) cryptoSetup.EXPECT().Close() sess.destroy(errCloseSessionForNewVersion) Eventually(areSessionsRunning).Should(BeFalse())