diff --git a/session.go b/session.go index 76161d10..eae77601 100644 --- a/session.go +++ b/session.go @@ -44,6 +44,7 @@ type Session struct { packer *packetPacker receivedPackets chan receivedPacket + closeChan chan struct{} // Used to calculate the next packet number from the truncated wire representation lastRcvdPacketNumber protocol.PacketNumber @@ -58,6 +59,7 @@ func NewSession(conn connection, v protocol.VersionNumber, connectionID protocol sentPacketHandler: ackhandler.NewSentPacketHandler(), receivedPacketHandler: ackhandler.NewReceivedPacketHandler(), receivedPackets: make(chan receivedPacket), + closeChan: make(chan struct{}), } cryptoStream, _ := session.NewStream(1) @@ -76,6 +78,8 @@ func (s *Session) Run() { for { var err error select { + case <-s.closeChan: + return case p := <-s.receivedPackets: err = s.handlePacket(p.remoteAddr, p.publicHeader, p.r) case <-time.After(sendTimeout): @@ -190,6 +194,7 @@ func (s *Session) handleRstStreamFrame(frame *frames.RstStreamFrame) error { // Close the connection by sending a ConnectionClose frame func (s *Session) Close(e error) error { + s.closeChan <- struct{}{} if e == nil { e = protocol.NewQuicError(errorcodes.QUIC_PEER_GOING_AWAY, "peer going away") } diff --git a/session_test.go b/session_test.go index abf68f3b..227be291 100644 --- a/session_test.go +++ b/session_test.go @@ -172,17 +172,18 @@ var _ = Describe("Session", func() { ) BeforeEach(func() { + time.Sleep(1 * time.Millisecond) // Wait for old goroutines to finish nGoRoutinesBefore = runtime.NumGoroutine() path := os.Getenv("GOPATH") + "/src/github.com/lucas-clemente/quic-go/example/" signer, err := crypto.NewRSASigner(path+"cert.der", path+"key.der") Expect(err).ToNot(HaveOccurred()) scfg := handshake.NewServerConfig(crypto.NewCurve25519KEX(), signer) session = NewSession(&mockConnection{}, 0, 0, scfg, nil).(*Session) + go session.Run() + Expect(runtime.NumGoroutine()).To(Equal(nGoRoutinesBefore + 2)) }) It("shuts down without error", func() { - // crypto stream is running in separate go routine - Expect(runtime.NumGoroutine()).To(Equal(nGoRoutinesBefore + 1)) session.Close(nil) time.Sleep(1 * time.Millisecond) Expect(runtime.NumGoroutine()).To(Equal(nGoRoutinesBefore))