diff --git a/server.go b/server.go index dc1fb1fad..660d13188 100644 --- a/server.go +++ b/server.go @@ -86,6 +86,7 @@ func (s *server) Serve() error { // If it does, we only read a truncated packet, which will then end up undecryptable n, remoteAddr, err := s.conn.ReadFrom(data) if err != nil { + _ = s.Close() return err } data = data[:n] diff --git a/server_test.go b/server_test.go index 42de5ada5..b38f84ba3 100644 --- a/server_test.go +++ b/server_test.go @@ -243,6 +243,17 @@ var _ = Describe("Server", func() { Expect(err).To(MatchError(testErr)) }) + It("closes all sessions when encountering a connection error", func() { + err := serv.handlePacket(nil, nil, firstPacket) + Expect(err).ToNot(HaveOccurred()) + Expect(serv.sessions).To(HaveKey(connID)) + Expect(serv.sessions[connID].(*mockSession).closed).To(BeFalse()) + testErr := errors.New("connection error") + conn.readErr = testErr + _ = serv.Serve() + Expect(serv.sessions[connID].(*mockSession).closed).To(BeTrue()) + }) + It("ignores delayed packets with mismatching versions", func() { err := serv.handlePacket(nil, nil, firstPacket) Expect(err).ToNot(HaveOccurred())