diff --git a/Changelog.md b/Changelog.md index b06c9746b..91c3fd4ea 100644 --- a/Changelog.md +++ b/Changelog.md @@ -7,6 +7,7 @@ - Add a `quic.Config` option to request truncation of the connection ID from a server - Add a `quic.Config` option to configure the source address validation - Add a `quic.Config` option to configure the handshake timeout +- Add a `quic.Config` option to configure keep-alive - Remove the `tls.Config` from the `quic.Config`. The `tls.Config` must now be passed to the `Dial` and `Listen` functions as a separate parameter. See the [Godoc](https://godoc.org/github.com/lucas-clemente/quic-go) for details. - Changed the log level environment variable to only accept strings ("DEBUG", "INFO", "ERROR"), see [the wiki](https://github.com/lucas-clemente/quic-go/wiki/Logging) for more details. - Rename the `h2quic.QuicRoundTripper` to `h2quic.RoundTripper` diff --git a/client.go b/client.go index a4f614124..1938635c6 100644 --- a/client.go +++ b/client.go @@ -169,6 +169,7 @@ func populateClientConfig(config *Config) *Config { RequestConnectionIDTruncation: config.RequestConnectionIDTruncation, MaxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindow, MaxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow, + KeepAlive: config.KeepAlive, } } diff --git a/h2quic/client.go b/h2quic/client.go index 906253e08..813f466c5 100644 --- a/h2quic/client.go +++ b/h2quic/client.go @@ -59,6 +59,7 @@ func newClient(tlsConfig *tls.Config, hostname string, opts *roundTripperOpts) * tlsConf: tlsConfig, config: &quic.Config{ RequestConnectionIDTruncation: true, + KeepAlive: true, }, opts: opts, headerErrored: make(chan struct{}), diff --git a/interface.go b/interface.go index b812d37b4..4ba4cb42b 100644 --- a/interface.go +++ b/interface.go @@ -86,6 +86,8 @@ type Config struct { // MaxReceiveConnectionFlowControlWindow is the connection-level flow control window for receiving data. // If this value is zero, it will default to 1.5 MB for the server and 15 MB for the client. MaxReceiveConnectionFlowControlWindow protocol.ByteCount + // KeepAlive defines whether this peer will periodically send PING frames to keep the connection alive. + KeepAlive bool } // A Listener for incoming QUIC connections diff --git a/session.go b/session.go index ae9cf2155..e7e671836 100644 --- a/session.go +++ b/session.go @@ -111,6 +111,9 @@ type session struct { lastNetworkActivityTime time.Time timer *utils.Timer + // keepAlivePingSent stores whether a Ping frame was sent to the peer or not + // it is reset as soon as we receive a packet from the peer + keepAlivePingSent bool } var _ Session = &session{} @@ -302,6 +305,12 @@ runLoop: s.sentPacketHandler.OnAlarm() } + if s.config.KeepAlive && time.Since(s.lastNetworkActivityTime) >= s.idleTimeout()/2 { + // send the PING frame since there is no activity in the session + s.packer.QueueControlFrame(&frames.PingFrame{}) + s.keepAlivePingSent = true + } + if err := s.sendPacket(); err != nil { s.closeLocal(err) } @@ -333,7 +342,12 @@ func (s *session) WaitUntilClosed() { } func (s *session) maybeResetTimer() { - deadline := s.lastNetworkActivityTime.Add(s.idleTimeout()) + var deadline time.Time + if s.config.KeepAlive && !s.keepAlivePingSent { + deadline = s.lastNetworkActivityTime.Add(s.idleTimeout() / 2) + } else { + deadline = s.lastNetworkActivityTime.Add(s.idleTimeout()) + } if ackAlarm := s.receivedPacketHandler.GetAlarmTimeout(); !ackAlarm.IsZero() { deadline = utils.MinTime(deadline, ackAlarm) @@ -373,6 +387,7 @@ func (s *session) handlePacketImpl(p *receivedPacket) error { } s.lastNetworkActivityTime = p.rcvTime + s.keepAlivePingSent = false hdr := p.publicHeader data := p.data diff --git a/session_test.go b/session_test.go index fd4952eae..e3b911d89 100644 --- a/session_test.go +++ b/session_test.go @@ -1437,6 +1437,29 @@ var _ = Describe("Session", func() { close(done) }) + Context("keep-alives", func() { + It("sends a ping packet", func() { + sess.config.KeepAlive = true + sess.lastNetworkActivityTime = time.Now().Add(-(sess.idleTimeout() / 2)) + go sess.run() + defer sess.Close(nil) + time.Sleep(60 * time.Millisecond) + Eventually(func() [][]byte { return mconn.written }).ShouldNot(BeEmpty()) + Eventually(func() byte { + // -12 because of the crypto tag. This should be 7 (the frame id for a ping frame). + s := mconn.written[0] + return s[len(s)-12-1] + }).Should(Equal(byte(0x07))) + }) + + It("doesn't send a ping packet if keep-alive is disabled", func() { + sess.lastNetworkActivityTime = time.Now().Add(-(sess.idleTimeout() / 2)) + go sess.run() + defer sess.Close(nil) + Consistently(func() [][]byte { return mconn.written }).Should(BeEmpty()) + }) + }) + Context("timeouts", func() { It("times out due to no network activity", func(done Done) { sess.lastNetworkActivityTime = time.Now().Add(-time.Hour)