diff --git a/integrationtests/self/timeout_test.go b/integrationtests/self/timeout_test.go index fc0a3771..6143d993 100644 --- a/integrationtests/self/timeout_test.go +++ b/integrationtests/self/timeout_test.go @@ -1,6 +1,7 @@ package self_test import ( + "bytes" "context" "fmt" "io" @@ -8,7 +9,10 @@ import ( mrand "math/rand" "net" "os" + "runtime/pprof" "strconv" + "strings" + "sync/atomic" "time" quic "github.com/lucas-clemente/quic-go" @@ -20,23 +24,34 @@ import ( type faultyConn struct { net.PacketConn - Timeout time.Time + + MaxPackets int32 + counter int32 } func (c *faultyConn) ReadFrom(p []byte) (int, net.Addr, error) { - if time.Now().Before(c.Timeout) { - return c.PacketConn.ReadFrom(p) + n, addr, err := c.PacketConn.ReadFrom(p) + counter := atomic.AddInt32(&c.counter, 1) + if counter <= c.MaxPackets { + return n, addr, err } return 0, nil, io.ErrClosedPipe } func (c *faultyConn) WriteTo(p []byte, addr net.Addr) (int, error) { - if time.Now().Before(c.Timeout) { + counter := atomic.AddInt32(&c.counter, 1) + if counter <= c.MaxPackets { return c.PacketConn.WriteTo(p, addr) } return 0, io.ErrClosedPipe } +func areHandshakesRunning() bool { + var b bytes.Buffer + pprof.Lookup("goroutine").WriteTo(&b, 1) + return strings.Contains(b.String(), "RunHandshake") +} + var _ = Describe("Timeout tests", func() { checkTimeoutError := func(err error) { ExpectWithOffset(1, err).To(HaveOccurred()) @@ -333,8 +348,18 @@ var _ = Describe("Timeout tests", func() { }) Context("faulty packet conns", func() { - runServer := func(ctx context.Context, ln quic.Listener) error { - sess, err := ln.Accept(ctx) + const handshakeTimeout = time.Second / 2 + + BeforeEach(func() { + Expect(areHandshakesRunning()).To(BeFalse()) + }) + + AfterEach(func() { + Expect(areHandshakesRunning()).To(BeFalse()) + }) + + runServer := func(ln quic.Listener) error { + sess, err := ln.Accept(context.Background()) if err != nil { return err } @@ -365,37 +390,30 @@ var _ = Describe("Timeout tests", func() { Expect(err).ToNot(HaveOccurred()) conn, err := net.ListenUDP("udp", addr) Expect(err).ToNot(HaveOccurred()) - timeout := time.Duration(mrand.Intn(150)) * time.Millisecond - fmt.Fprintf(GinkgoWriter, "Timeout: %s\n", timeout) + maxPackets := mrand.Int31n(25) + fmt.Fprintf(GinkgoWriter, "blocking connection after %d packets\n", maxPackets) ln, err := quic.Listen( - &faultyConn{PacketConn: conn, Timeout: time.Now().Add(timeout)}, + &faultyConn{PacketConn: conn, MaxPackets: maxPackets}, getTLSConfig(), getQuicConfig(nil), ) Expect(err).ToNot(HaveOccurred()) - proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ - RemoteAddr: fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port), - DelayPacket: func(quicproxy.Direction, []byte) time.Duration { return 10 * time.Millisecond }, - }) - Expect(err).ToNot(HaveOccurred()) - defer proxy.Close() - serverErrChan := make(chan error, 1) go func() { defer GinkgoRecover() - serverErrChan <- runServer(context.Background(), ln) + serverErrChan <- runServer(ln) }() clientErrChan := make(chan error, 1) go func() { defer GinkgoRecover() sess, err := quic.DialAddr( - fmt.Sprintf("localhost:%d", proxy.LocalPort()), + fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), getQuicConfig(&quic.Config{ - HandshakeTimeout: time.Second, - MaxIdleTimeout: time.Second, + HandshakeTimeout: handshakeTimeout, + MaxIdleTimeout: handshakeTimeout, }), ) if err != nil { @@ -405,16 +423,20 @@ var _ = Describe("Timeout tests", func() { clientErrChan <- runClient(sess) }() - var serverErr, clientErr error - Eventually(serverErrChan, 5*time.Second).Should(Receive(&serverErr)) - if serverErr != nil { + var clientErr error + Eventually(clientErrChan, 5*handshakeTimeout).Should(Receive(&clientErr)) + Expect(clientErr).To(HaveOccurred()) + nErr, ok := clientErr.(net.Error) + Expect(ok).To(BeTrue()) + Expect(nErr.Timeout()).To(BeTrue()) + + select { + case serverErr := <-serverErrChan: + Expect(serverErr).To(HaveOccurred()) Expect(serverErr.Error()).To(ContainSubstring(io.ErrClosedPipe.Error())) - } - Eventually(clientErrChan, 5*time.Second).Should(Receive(&clientErr)) - if clientErr != nil { - nErr, ok := clientErr.(net.Error) - Expect(ok).To(BeTrue()) - Expect(nErr.Timeout()).To(BeTrue()) + default: + Expect(ln.Close()).To(Succeed()) + Eventually(serverErrChan).Should(Receive()) } }) @@ -423,44 +445,31 @@ var _ = Describe("Timeout tests", func() { "localhost:0", getTLSConfig(), getQuicConfig(&quic.Config{ - HandshakeTimeout: time.Second, - MaxIdleTimeout: time.Second, + HandshakeTimeout: handshakeTimeout, + MaxIdleTimeout: handshakeTimeout, KeepAlive: true, }), ) Expect(err).ToNot(HaveOccurred()) - proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ - RemoteAddr: fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port), - DelayPacket: func(quicproxy.Direction, []byte) time.Duration { return 10 * time.Millisecond }, - }) - Expect(err).ToNot(HaveOccurred()) - defer proxy.Close() - - // If the connection errors before the handshake completes, the handshake will fail with a - // handshake error on the server side. This means that the session will never be returned - // on ln.Accept(). - // By using this context for ln.Accept(), we make sure that the runServer() still returns. - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() serverErrChan := make(chan error, 1) go func() { defer GinkgoRecover() - serverErrChan <- runServer(ctx, ln) + serverErrChan <- runServer(ln) }() addr, err := net.ResolveUDPAddr("udp", "localhost:0") Expect(err).ToNot(HaveOccurred()) conn, err := net.ListenUDP("udp", addr) Expect(err).ToNot(HaveOccurred()) - timeout := time.Duration(mrand.Intn(150)) * time.Millisecond - fmt.Fprintf(GinkgoWriter, "Timeout: %s\n", timeout) + maxPackets := mrand.Int31n(25) + fmt.Fprintf(GinkgoWriter, "blocking connection after %d packets\n", maxPackets) clientErrChan := make(chan error, 1) go func() { defer GinkgoRecover() sess, err := quic.Dial( - &faultyConn{PacketConn: conn, Timeout: time.Now().Add(timeout)}, - proxy.LocalAddr(), + &faultyConn{PacketConn: conn, MaxPackets: maxPackets}, + ln.Addr(), "localhost", getTLSClientConfig(), getQuicConfig(nil), @@ -472,19 +481,21 @@ var _ = Describe("Timeout tests", func() { clientErrChan <- runClient(sess) }() - var serverErr, clientErr error - Eventually(clientErrChan, 5*time.Second).Should(Receive(&clientErr)) - if clientErr != nil { - Expect(clientErr.Error()).To(ContainSubstring(io.ErrClosedPipe.Error())) - cancel() - } - Eventually(serverErrChan, 5*time.Second).Should(Receive(&serverErr)) - if serverErr != nil && serverErr != context.Canceled { + var clientErr error + Eventually(clientErrChan, 5*handshakeTimeout).Should(Receive(&clientErr)) + Expect(clientErr).To(HaveOccurred()) + Expect(clientErr.Error()).To(ContainSubstring(io.ErrClosedPipe.Error())) + Eventually(areHandshakesRunning, 5*handshakeTimeout).Should(BeFalse()) + select { + case serverErr := <-serverErrChan: // The handshake completed on the server side. + Expect(serverErr).To(HaveOccurred()) nErr, ok := serverErr.(net.Error) Expect(ok).To(BeTrue()) Expect(nErr.Timeout()).To(BeTrue()) + default: // The handshake didn't complete + Expect(ln.Close()).To(Succeed()) + Eventually(serverErrChan).Should(Receive()) } - }) }) })