Merge pull request #1895 from lucas-clemente/fix-server-close-deadlock

fix race condition when closing the server after a Read failed
This commit is contained in:
Marten Seemann
2019-05-12 22:19:03 +08:00
committed by GitHub
4 changed files with 21 additions and 26 deletions

View File

@@ -33,20 +33,6 @@ func (m *MockUnknownPacketHandler) EXPECT() *MockUnknownPacketHandlerMockRecorde
return m.recorder return m.recorder
} }
// closeWithError mocks base method
func (m *MockUnknownPacketHandler) closeWithError(arg0 error) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "closeWithError", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// closeWithError indicates an expected call of closeWithError
func (mr *MockUnknownPacketHandlerMockRecorder) closeWithError(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "closeWithError", reflect.TypeOf((*MockUnknownPacketHandler)(nil).closeWithError), arg0)
}
// handlePacket mocks base method // handlePacket mocks base method
func (m *MockUnknownPacketHandler) handlePacket(arg0 *receivedPacket) { func (m *MockUnknownPacketHandler) handlePacket(arg0 *receivedPacket) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@@ -58,3 +44,15 @@ func (mr *MockUnknownPacketHandlerMockRecorder) handlePacket(arg0 interface{}) *
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handlePacket", reflect.TypeOf((*MockUnknownPacketHandler)(nil).handlePacket), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handlePacket", reflect.TypeOf((*MockUnknownPacketHandler)(nil).handlePacket), arg0)
} }
// setCloseError mocks base method
func (m *MockUnknownPacketHandler) setCloseError(arg0 error) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "setCloseError", arg0)
}
// setCloseError indicates an expected call of setCloseError
func (mr *MockUnknownPacketHandlerMockRecorder) setCloseError(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "setCloseError", reflect.TypeOf((*MockUnknownPacketHandler)(nil).setCloseError), arg0)
}

View File

@@ -141,7 +141,6 @@ func (h *packetHandlerMap) close(e error) error {
h.mutex.Unlock() h.mutex.Unlock()
return nil return nil
} }
h.closed = true
var wg sync.WaitGroup var wg sync.WaitGroup
for _, handler := range h.handlers { for _, handler := range h.handlers {
@@ -153,8 +152,9 @@ func (h *packetHandlerMap) close(e error) error {
} }
if h.server != nil { if h.server != nil {
h.server.closeWithError(e) h.server.setCloseError(e)
} }
h.closed = true
h.mutex.Unlock() h.mutex.Unlock()
wg.Wait() wg.Wait()
return getMultiplexer().RemoveConn(h.conn) return getMultiplexer().RemoveConn(h.conn)

View File

@@ -28,7 +28,7 @@ type packetHandler interface {
type unknownPacketHandler interface { type unknownPacketHandler interface {
handlePacket(*receivedPacket) handlePacket(*receivedPacket)
closeWithError(error) error setCloseError(error)
} }
type packetHandlerManager interface { type packetHandlerManager interface {
@@ -293,10 +293,6 @@ func (s *server) Close() error {
if s.closed { if s.closed {
return nil return nil
} }
return s.closeWithMutex()
}
func (s *server) closeWithMutex() error {
s.sessionHandler.CloseServer() s.sessionHandler.CloseServer()
if s.serverError == nil { if s.serverError == nil {
s.serverError = errors.New("server closed") s.serverError = errors.New("server closed")
@@ -312,14 +308,15 @@ func (s *server) closeWithMutex() error {
return err return err
} }
func (s *server) closeWithError(e error) error { func (s *server) setCloseError(e error) {
s.mutex.Lock() s.mutex.Lock()
defer s.mutex.Unlock() defer s.mutex.Unlock()
if s.closed { if s.closed {
return nil return
} }
s.closed = true
s.serverError = e s.serverError = e
return s.closeWithMutex() close(s.errorChan)
} }
// Addr returns the server's network address // Addr returns the server's network address

View File

@@ -451,13 +451,13 @@ var _ = Describe("Server", func() {
close(done) close(done)
}() }()
Expect(serv.closeWithError(testErr)).To(Succeed()) serv.setCloseError(testErr)
Eventually(done).Should(BeClosed()) Eventually(done).Should(BeClosed())
}) })
It("returns immediately, if an error occurred before", func() { It("returns immediately, if an error occurred before", func() {
testErr := errors.New("test err") testErr := errors.New("test err")
Expect(serv.closeWithError(testErr)).To(Succeed()) serv.setCloseError(testErr)
for i := 0; i < 3; i++ { for i := 0; i < 3; i++ {
_, err := serv.Accept() _, err := serv.Accept()
Expect(err).To(MatchError(testErr)) Expect(err).To(MatchError(testErr))