diff --git a/session.go b/session.go index ab814554b..ab7f9930c 100644 --- a/session.go +++ b/session.go @@ -48,9 +48,10 @@ type Session struct { unpacker *packetUnpacker packer *packetPacker - receivedPackets chan receivedPacket - closeChan chan struct{} - closed bool + receivedPackets chan receivedPacket + sendingScheduled chan struct{} + closeChan chan struct{} + closed bool connectionParametersManager *handshake.ConnectionParametersManager @@ -74,6 +75,7 @@ func NewSession(conn connection, v protocol.VersionNumber, connectionID protocol stopWaitingManager: stopWaitingManager, receivedPackets: make(chan receivedPacket, 1000), // TODO: What if server receives many packets and connection is already closed?! closeChan: make(chan struct{}, 1), + sendingScheduled: make(chan struct{}, 1), rttStats: congestion.RTTStats{}, connectionParametersManager: handshake.NewConnectionParamatersManager(), } @@ -95,15 +97,22 @@ func NewSession(conn connection, v protocol.VersionNumber, connectionID protocol // Run the session main loop func (s *Session) Run() { - sendTimeout := 1 * time.Millisecond for { + // Close immediately if requested + select { + case <-s.closeChan: + return + default: + } + var err error select { case <-s.closeChan: return case p := <-s.receivedPackets: err = s.handlePacket(p.remoteAddr, p.publicHeader, p.r) - case <-time.After(sendTimeout): + s.scheduleSending() + case <-s.sendingScheduled: err = s.sendPacket() } @@ -360,12 +369,18 @@ func (s *Session) sendPacket() error { if err != nil { return err } + + if !s.packer.Empty() { + s.scheduleSending() + } + return nil } // QueueStreamFrame queues a frame for sending to the client func (s *Session) QueueStreamFrame(frame *frames.StreamFrame) error { s.packer.AddStreamFrame(*frame) + s.scheduleSending() return nil } @@ -409,3 +424,11 @@ func (s *Session) sendPublicReset(rejectedPacketNumber protocol.PacketNumber) er packet.Write(&b) return s.conn.write(b.Bytes()) } + +// scheduleSending signals that we have data for sending +func (s *Session) scheduleSending() { + select { + case s.sendingScheduled <- struct{}{}: + default: + } +} diff --git a/session_test.go b/session_test.go index a71c07b54..de6fe3744 100644 --- a/session_test.go +++ b/session_test.go @@ -29,6 +29,7 @@ func (m *mockConnection) write(p []byte) error { func (*mockConnection) setCurrentRemoteAddr(addr interface{}) {} +// TODO: Reorganize var _ = Describe("Session", func() { var ( session *Session @@ -271,6 +272,38 @@ var _ = Describe("Session", func() { }) }) + Context("scheduling sending", func() { + BeforeEach(func() { + signer, err := crypto.NewRSASigner(testdata.GetTLSConfig()) + Expect(err).ToNot(HaveOccurred()) + scfg := handshake.NewServerConfig(crypto.NewCurve25519KEX(), signer) + session = NewSession(conn, 0, 0, scfg, nil).(*Session) + }) + + It("sends after queuing a stream frame", func() { + Expect(session.sendingScheduled).NotTo(Receive()) + err := session.QueueStreamFrame(&frames.StreamFrame{StreamID: 1}) + Expect(err).ToNot(HaveOccurred()) + // Try again, so that we detect blocking scheduleSending + err = session.QueueStreamFrame(&frames.StreamFrame{StreamID: 1}) + Expect(err).ToNot(HaveOccurred()) + Expect(session.sendingScheduled).To(Receive()) + }) + + It("sends after receiving a packet", func() { + Expect(session.sendingScheduled).NotTo(Receive()) + session.receivedPackets <- receivedPacket{ + publicHeader: &PublicHeader{}, + r: bytes.NewReader([]byte{ + // FNV hash + "foobar" + 0x18, 0x6f, 0x44, 0xba, 0x97, 0x35, 0xd, 0x6f, 0xbf, 0x64, 0x3c, 0x79, 0x66, 0x6f, 0x6f, 0x62, 0x61, 0x72, + }), + } + session.Run() + Expect(session.sendingScheduled).To(Receive()) + }) + }) + It("closes when crypto stream errors", func() { signer, err := crypto.NewRSASigner(testdata.GetTLSConfig()) Expect(err).ToNot(HaveOccurred()) @@ -289,6 +322,7 @@ var _ = Describe("Session", func() { Expect(err).To(MatchError("CryptoSetup: expected CHLO")) }) + // See https://github.com/lucas-clemente/quic-go/issues/38 PIt("sends public reset when receiving invalid message", func() { signer, err := crypto.NewRSASigner(testdata.GetTLSConfig()) Expect(err).ToNot(HaveOccurred())