From bb185a3ad2f98f7526fec7ecb60d18205b4c37cd Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Thu, 24 Jan 2019 18:09:46 +0700 Subject: [PATCH] add a function to close the packet handler map Close will close the underlying connection and wait until listen has returned. While not strictly necessary in production use, this will fix a few race conditions in our tests. --- client.go | 2 +- client_test.go | 7 +++---- mock_packet_handler_manager_test.go | 12 ++++++++++++ packet_handler_map.go | 15 ++++++++++++++- server.go | 3 ++- 5 files changed, 32 insertions(+), 7 deletions(-) diff --git a/client.go b/client.go index 787fa89c1..e4e819ed2 100644 --- a/client.go +++ b/client.go @@ -271,7 +271,7 @@ func (c *client) establishSecureConnection(ctx context.Context) error { go func() { err := c.session.run() // returns as soon as the session is closed if err != errCloseForRecreating && c.createdPacketConn { - c.conn.Close() + c.packetHandlers.Close() } errorChan <- err }() diff --git a/client_test.go b/client_test.go index 875abcd6e..a29b2cbdd 100644 --- a/client_test.go +++ b/client_test.go @@ -131,6 +131,7 @@ var _ = Describe("Client", func() { manager := NewMockPacketHandlerManager(mockCtrl) manager.EXPECT().Add(gomock.Any(), gomock.Any()) + manager.EXPECT().Close() mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any()).Return(manager, nil) remoteAddrChan := make(chan string, 1) @@ -162,6 +163,7 @@ var _ = Describe("Client", func() { It("uses the tls.Config.ServerName as the hostname, if present", func() { manager := NewMockPacketHandlerManager(mockCtrl) manager.EXPECT().Add(gomock.Any(), gomock.Any()) + manager.EXPECT().Close() mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any()).Return(manager, nil) hostnameChan := make(chan string, 1) @@ -403,12 +405,9 @@ var _ = Describe("Client", func() { // check that the connection is not closed Expect(conn.Write([]byte("foobar"))).To(Succeed()) + manager.EXPECT().Close() close(run) time.Sleep(50 * time.Millisecond) - // check that the connection is closed - err := conn.Write([]byte("foobar")) - Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(ContainSubstring("use of closed network connection")) Eventually(done).Should(BeClosed()) }) diff --git a/mock_packet_handler_manager_test.go b/mock_packet_handler_manager_test.go index b1906b35a..056fa28e8 100644 --- a/mock_packet_handler_manager_test.go +++ b/mock_packet_handler_manager_test.go @@ -44,6 +44,18 @@ func (mr *MockPacketHandlerManagerMockRecorder) Add(arg0, arg1 interface{}) *gom return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockPacketHandlerManager)(nil).Add), arg0, arg1) } +// Close mocks base method +func (m *MockPacketHandlerManager) Close() error { + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close +func (mr *MockPacketHandlerManagerMockRecorder) Close() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockPacketHandlerManager)(nil).Close)) +} + // CloseServer mocks base method func (m *MockPacketHandlerManager) CloseServer() { m.ctrl.Call(m, "CloseServer") diff --git a/packet_handler_map.go b/packet_handler_map.go index 6ce4e7fd3..b526b6d95 100644 --- a/packet_handler_map.go +++ b/packet_handler_map.go @@ -31,7 +31,9 @@ type packetHandlerMap struct { handlers map[string] /* string(ConnectionID)*/ packetHandlerEntry resetTokens map[[16]byte] /* stateless reset token */ packetHandler server unknownPacketHandler - closed bool + + listening chan struct{} // is closed when listen returns + closed bool deleteRetiredSessionsAfter time.Duration @@ -44,6 +46,7 @@ func newPacketHandlerMap(conn net.PacketConn, connIDLen int, logger utils.Logger m := &packetHandlerMap{ conn: conn, connIDLen: connIDLen, + listening: make(chan struct{}), handlers: make(map[string]packetHandlerEntry), resetTokens: make(map[[16]byte]packetHandler), deleteRetiredSessionsAfter: protocol.RetiredConnectionIDDeleteTimeout, @@ -117,6 +120,15 @@ func (h *packetHandlerMap) CloseServer() { wg.Wait() } +// Close the underlying connection and wait until listen() has returned. +func (h *packetHandlerMap) Close() error { + if err := h.conn.Close(); err != nil { + return err + } + <-h.listening // wait until listening returns + return nil +} + func (h *packetHandlerMap) close(e error) error { h.mutex.Lock() if h.closed { @@ -143,6 +155,7 @@ func (h *packetHandlerMap) close(e error) error { } func (h *packetHandlerMap) listen() { + defer close(h.listening) for { buffer := getPacketBuffer() data := buffer.Slice diff --git a/server.go b/server.go index f1277838f..9b106d2e8 100644 --- a/server.go +++ b/server.go @@ -32,6 +32,7 @@ type unknownPacketHandler interface { } type packetHandlerManager interface { + io.Closer Add(protocol.ConnectionID, packetHandler) Retire(protocol.ConnectionID) Remove(protocol.ConnectionID) @@ -300,7 +301,7 @@ func (s *server) closeWithMutex() error { // If the server was started with ListenAddr, we created the packet conn. // We need to close it in order to make the go routine reading from that conn return. if s.createdPacketConn { - err = s.conn.Close() + err = s.sessionHandler.Close() } s.closed = true close(s.errorChan)