diff --git a/server.go b/server.go index 3e2f30a0..990075e1 100644 --- a/server.go +++ b/server.go @@ -70,6 +70,8 @@ type baseServer struct { sessionHandler packetHandlerManager + receivedPackets chan *receivedPacket + // set as a member, so they can be set in the tests newSession func(connection, sessionRunner, protocol.ConnectionID /* original connection ID */, protocol.ConnectionID /* client dest connection ID */, protocol.ConnectionID /* destination connection ID */, protocol.ConnectionID /* source connection ID */, [16]byte, *Config, *tls.Config, *handshake.TokenGenerator, bool /* enable 0-RTT */, utils.Logger, protocol.VersionNumber) quicSession @@ -174,15 +176,35 @@ func listen(conn net.PacketConn, tlsConf *tls.Config, config *Config, acceptEarl sessionHandler: sessionHandler, sessionQueue: make(chan quicSession), errorChan: make(chan struct{}), + receivedPackets: make(chan *receivedPacket, 1000), newSession: newSession, logger: utils.DefaultLogger.WithPrefix("server"), acceptEarlySessions: acceptEarly, } + go s.run() sessionHandler.SetServer(s) s.logger.Debugf("Listening for %s connections on %s", conn.LocalAddr().Network(), conn.LocalAddr().String()) return s, nil } +func (s *baseServer) run() { + for { + select { + case <-s.errorChan: + return + default: + } + select { + case <-s.errorChan: + return + case p := <-s.receivedPackets: + if shouldReleaseBuffer := s.handlePacketImpl(p); !shouldReleaseBuffer { + p.buffer.Release() + } + } + } +} + var defaultAcceptToken = func(clientAddr net.Addr, token *Token) bool { if token == nil { return false @@ -327,11 +349,7 @@ func (s *baseServer) Addr() net.Addr { } func (s *baseServer) handlePacket(p *receivedPacket) { - go func() { - if shouldReleaseBuffer := s.handlePacketImpl(p); !shouldReleaseBuffer { - p.buffer.Release() - } - }() + s.receivedPackets <- p } func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* was the packet passed on to a session */ { diff --git a/server_test.go b/server_test.go index 58846b49..f04f52ae 100644 --- a/server_test.go +++ b/server_test.go @@ -8,6 +8,8 @@ import ( "errors" "net" "reflect" + "runtime/pprof" + "strings" "sync" "time" @@ -23,6 +25,12 @@ import ( . "github.com/onsi/gomega" ) +func areServersRunning() bool { + var b bytes.Buffer + pprof.Lookup("goroutine").WriteTo(&b, 1) + return strings.Contains(b.String(), "quic-go.(*baseServer).run") +} + var _ = Describe("Server", func() { var ( conn *mockPacketConn @@ -77,6 +85,10 @@ var _ = Describe("Server", func() { tlsConf.NextProtos = []string{"proto1"} }) + AfterEach(func() { + Eventually(areServersRunning).Should(BeFalse()) + }) + It("errors when no tls.Config is given", func() { _, err := ListenAddr("localhost:0", nil, nil) Expect(err).To(HaveOccurred()) @@ -165,6 +177,11 @@ var _ = Describe("Server", func() { serv.sessionHandler = phm }) + AfterEach(func() { + phm.EXPECT().CloseServer().MaxTimes(1) + serv.Close() + }) + Context("handling packets", func() { It("drops Initial packets with a too short connection ID", func() { serv.handlePacket(getPacket(&wire.Header{ @@ -611,6 +628,11 @@ var _ = Describe("Server", func() { serv.sessionHandler = phm }) + AfterEach(func() { + phm.EXPECT().CloseServer().MaxTimes(1) + serv.Close() + }) + It("accepts new sessions when they become ready", func() { sess := NewMockQuicSession(mockCtrl)