diff --git a/benchmark_test.go b/benchmark_test.go index d80727dcb..c4ce9485a 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -2,172 +2,76 @@ package quic import ( "bytes" - "crypto/rand" - "encoding/binary" + "crypto/tls" "fmt" "io" - mrand "math/rand" - "net" - "reflect" - "runtime" - "runtime/debug" + "math/rand" "time" - "unsafe" + "github.com/lucas-clemente/quic-go/protocol" + "github.com/lucas-clemente/quic-go/testdata" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" - - "github.com/lucas-clemente/quic-go/crypto" - "github.com/lucas-clemente/quic-go/handshake" - "github.com/lucas-clemente/quic-go/protocol" ) -type linkedConnection struct { - other *session - c chan []byte -} - -func newLinkedConnection(other *session) *linkedConnection { - c := make(chan []byte, 500) - conn := &linkedConnection{ - c: c, - other: other, - } - go func() { - for packet := range c { - if packet == nil { - return - } - r := bytes.NewReader(packet) - hdr, err := ParsePublicHeader(r, protocol.PerspectiveClient) - if err != nil { - Expect(err).NotTo(HaveOccurred()) - } - hdr.Raw = packet[:len(packet)-r.Len()] - conn.other.handlePacket(&receivedPacket{publicHeader: hdr, data: packet[len(packet)-r.Len():]}) - } - }() - return conn -} - -func (c *linkedConnection) Write(p []byte) error { - packet := getPacketBuffer() - packet = packet[:len(p)] - copy(packet, p) - select { - case c.c <- packet: - default: - } - return nil -} - -func (c *linkedConnection) Read(p []byte) (int, net.Addr, error) { panic("not implemented") } -func (*linkedConnection) SetCurrentRemoteAddr(addr net.Addr) {} -func (*linkedConnection) LocalAddr() net.Addr { panic("not implemented") } -func (*linkedConnection) RemoteAddr() net.Addr { return &net.UDPAddr{} } -func (c *linkedConnection) Close() error { return nil } - -func setAEAD(cs handshake.CryptoSetup, aead crypto.AEAD) { - *(*bool)(unsafe.Pointer(reflect.ValueOf(cs).Elem().FieldByName("receivedForwardSecurePacket").UnsafeAddr())) = true - *(*bool)(unsafe.Pointer(reflect.ValueOf(cs).Elem().FieldByName("sentSHLO").UnsafeAddr())) = true - *(*crypto.AEAD)(unsafe.Pointer(reflect.ValueOf(cs).Elem().FieldByName("forwardSecureAEAD").UnsafeAddr())) = aead -} - -func setFlowControlParameters(mgr handshake.ConnectionParametersManager) { - sfcw := make([]byte, 4) - cfcw := make([]byte, 4) - binary.LittleEndian.PutUint32(sfcw, uint32(protocol.ReceiveStreamFlowControlWindow)) - binary.LittleEndian.PutUint32(cfcw, uint32(protocol.ReceiveConnectionFlowControlWindow)) - mgr.SetFromMap(map[handshake.Tag][]byte{ - handshake.TagSFCW: sfcw, - handshake.TagCFCW: cfcw, - }) -} - var _ = Describe("Benchmarks", func() { dataLen := 50 /* MB */ * (1 << 20) data := make([]byte, dataLen) + rand.Seed(time.Now().UnixNano()) + for i := range protocol.SupportedVersions { version := protocol.SupportedVersions[i] Context(fmt.Sprintf("with version %d", version), func() { + Measure("transferring a file", func(b Benchmarker) { + rand.Read(data) // no need to check for an error. math.Rand.Read never errors - Measure("two linked sessions", func(b Benchmarker) { - if runtime.GOOS == "windows" { - Skip("benchmark tests disabled on windows, see #325") + // start the server + sconf := &Config{ + TLSConfig: testdata.GetTLSConfig(), + ConnState: func(sess Session, cs ConnState) { + if cs != ConnStateForwardSecure { + return + } + + defer GinkgoRecover() + str, err := sess.OpenStream() + Expect(err).ToNot(HaveOccurred()) + _, err = str.Write(data) + Expect(err).ToNot(HaveOccurred()) + err = str.Close() + Expect(err).ToNot(HaveOccurred()) + }, } + ln, err := ListenAddr("localhost:0", sconf) + Expect(err).ToNot(HaveOccurred()) + // Serve will error as soon as ln is closed. Ignore all errors here + go ln.Serve() - connID := protocol.ConnectionID(mrand.Uint32()) - - c1 := newLinkedConnection(nil) - session1I, err := newSession(c1, version, connID, nil, func(id protocol.ConnectionID) {}, func(Session, bool) {}) - if err != nil { - Expect(err).NotTo(HaveOccurred()) + // start the client + cconf := &Config{ + TLSConfig: &tls.Config{InsecureSkipVerify: true}, } - session1 := session1I.(*session) + sess, err := DialAddr(ln.Addr().String(), cconf) + Expect(err).ToNot(HaveOccurred()) + str, err := sess.AcceptStream() + Expect(err).ToNot(HaveOccurred()) - c2 := newLinkedConnection(session1) - session2I, err := newSession(c2, version, connID, nil, func(id protocol.ConnectionID) {}, func(Session, bool) {}) - if err != nil { - Expect(err).NotTo(HaveOccurred()) - } - session2 := session2I.(*session) - c1.other = session2 - - key := make([]byte, 16) - iv := make([]byte, 4) - rand.Read(key) - rand.Read(iv) - aead, err := crypto.NewAEADAESGCM(key, key, iv, iv) - Expect(err).NotTo(HaveOccurred()) - setAEAD(session1.cryptoSetup, aead) - setAEAD(session2.cryptoSetup, aead) - - setFlowControlParameters(session1.connectionParameters) - setFlowControlParameters(session2.connectionParameters) - - go session1.run() - go session2.run() - - s1stream, err := session1.GetOrOpenStream(5) - Expect(err).NotTo(HaveOccurred()) - s2stream, err := session2.GetOrOpenStream(5) - Expect(err).NotTo(HaveOccurred()) - - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - buf := make([]byte, 1500) - dataRead := 0 - for dataRead < dataLen { - n, err := s2stream.Read(buf) - Expect(err).NotTo(HaveOccurred()) - dataRead += n - } - done <- struct{}{} - }() - - time.Sleep(time.Millisecond) + buf := &bytes.Buffer{} + // measure the time it takes to download the dataLen bytes + // note we're measuring the time for the transfer, i.e. excluding the handshake runtime := b.Time("transfer time", func() { - _, err := io.Copy(s1stream, bytes.NewReader(data)) + _, err := io.Copy(buf, str) Expect(err).NotTo(HaveOccurred()) - select { - case <-done: - case <-time.After(10 * time.Second): - debug.PrintStack() - Fail("timeout") - } }) - - session1.Close(nil) - session2.Close(nil) - - // Signal connections to close - c1.c <- nil - c2.c <- nil + // this is *a lot* faster than Expect(buf.Bytes()).To(Equal(data)) + Expect(bytes.Equal(buf.Bytes(), data)).To(BeTrue()) b.RecordValue("transfer rate [MB/s]", float64(dataLen)/1e6/runtime.Seconds()) + + ln.Close() + sess.Close(nil) }, 6) }) }