From dc751238363e6d5244e322bf56cfc32ce83f7694 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Fri, 10 May 2019 09:47:52 +0900 Subject: [PATCH] fix race condition when closing the server after a Read failed --- mock_unknown_packet_handler_test.go | 26 ++++++++++++-------------- packet_handler_map.go | 4 ++-- server.go | 13 +++++-------- server_test.go | 4 ++-- 4 files changed, 21 insertions(+), 26 deletions(-) diff --git a/mock_unknown_packet_handler_test.go b/mock_unknown_packet_handler_test.go index ca4771d6b..a0adf5564 100644 --- a/mock_unknown_packet_handler_test.go +++ b/mock_unknown_packet_handler_test.go @@ -33,20 +33,6 @@ func (m *MockUnknownPacketHandler) EXPECT() *MockUnknownPacketHandlerMockRecorde return m.recorder } -// closeWithError mocks base method -func (m *MockUnknownPacketHandler) closeWithError(arg0 error) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "closeWithError", arg0) - ret0, _ := ret[0].(error) - return ret0 -} - -// closeWithError indicates an expected call of closeWithError -func (mr *MockUnknownPacketHandlerMockRecorder) closeWithError(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "closeWithError", reflect.TypeOf((*MockUnknownPacketHandler)(nil).closeWithError), arg0) -} - // handlePacket mocks base method func (m *MockUnknownPacketHandler) handlePacket(arg0 *receivedPacket) { m.ctrl.T.Helper() @@ -58,3 +44,15 @@ func (mr *MockUnknownPacketHandlerMockRecorder) handlePacket(arg0 interface{}) * mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handlePacket", reflect.TypeOf((*MockUnknownPacketHandler)(nil).handlePacket), arg0) } + +// setCloseError mocks base method +func (m *MockUnknownPacketHandler) setCloseError(arg0 error) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "setCloseError", arg0) +} + +// setCloseError indicates an expected call of setCloseError +func (mr *MockUnknownPacketHandlerMockRecorder) setCloseError(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "setCloseError", reflect.TypeOf((*MockUnknownPacketHandler)(nil).setCloseError), arg0) +} diff --git a/packet_handler_map.go b/packet_handler_map.go index 0af6a4074..35265b59b 100644 --- a/packet_handler_map.go +++ b/packet_handler_map.go @@ -141,7 +141,6 @@ func (h *packetHandlerMap) close(e error) error { h.mutex.Unlock() return nil } - h.closed = true var wg sync.WaitGroup for _, handler := range h.handlers { @@ -153,8 +152,9 @@ func (h *packetHandlerMap) close(e error) error { } if h.server != nil { - h.server.closeWithError(e) + h.server.setCloseError(e) } + h.closed = true h.mutex.Unlock() wg.Wait() return getMultiplexer().RemoveConn(h.conn) diff --git a/server.go b/server.go index fe2f9388a..9c959ec6a 100644 --- a/server.go +++ b/server.go @@ -28,7 +28,7 @@ type packetHandler interface { type unknownPacketHandler interface { handlePacket(*receivedPacket) - closeWithError(error) error + setCloseError(error) } type packetHandlerManager interface { @@ -293,10 +293,6 @@ func (s *server) Close() error { if s.closed { return nil } - return s.closeWithMutex() -} - -func (s *server) closeWithMutex() error { s.sessionHandler.CloseServer() if s.serverError == nil { s.serverError = errors.New("server closed") @@ -312,14 +308,15 @@ func (s *server) closeWithMutex() error { return err } -func (s *server) closeWithError(e error) error { +func (s *server) setCloseError(e error) { s.mutex.Lock() defer s.mutex.Unlock() if s.closed { - return nil + return } + s.closed = true s.serverError = e - return s.closeWithMutex() + close(s.errorChan) } // Addr returns the server's network address diff --git a/server_test.go b/server_test.go index fb8067949..dd6787faf 100644 --- a/server_test.go +++ b/server_test.go @@ -451,13 +451,13 @@ var _ = Describe("Server", func() { close(done) }() - Expect(serv.closeWithError(testErr)).To(Succeed()) + serv.setCloseError(testErr) Eventually(done).Should(BeClosed()) }) It("returns immediately, if an error occurred before", func() { testErr := errors.New("test err") - Expect(serv.closeWithError(testErr)).To(Succeed()) + serv.setCloseError(testErr) for i := 0; i < 3; i++ { _, err := serv.Accept() Expect(err).To(MatchError(testErr))