diff --git a/internal/handshake/updatable_aead_test.go b/internal/handshake/updatable_aead_test.go index 35a82ecd..b42f3dce 100644 --- a/internal/handshake/updatable_aead_test.go +++ b/internal/handshake/updatable_aead_test.go @@ -4,15 +4,15 @@ import ( "crypto/rand" "crypto/tls" "fmt" + "testing" "time" - "github.com/golang/mock/gomock" - mocklogging "github.com/quic-go/quic-go/internal/mocks/logging" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qerr" "github.com/quic-go/quic-go/internal/utils" + "github.com/golang/mock/gomock" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" ) @@ -525,3 +525,60 @@ var _ = Describe("Updatable AEAD", func() { }) } }) + +func getClientAndServer() (client, server *updatableAEAD) { + trafficSecret1 := make([]byte, 16) + trafficSecret2 := make([]byte, 16) + rand.Read(trafficSecret1) + rand.Read(trafficSecret2) + + cs := cipherSuites[0] + rttStats := utils.NewRTTStats() + client = newUpdatableAEAD(rttStats, nil, utils.DefaultLogger, protocol.Version1) + server = newUpdatableAEAD(rttStats, nil, utils.DefaultLogger, protocol.Version1) + client.SetReadKey(cs, trafficSecret2) + client.SetWriteKey(cs, trafficSecret1) + server.SetReadKey(cs, trafficSecret1) + server.SetWriteKey(cs, trafficSecret2) + return +} + +func BenchmarkPacketEncryption(b *testing.B) { + client, _ := getClientAndServer() + const l = 1200 + src := make([]byte, l) + rand.Read(src) + ad := make([]byte, 32) + rand.Read(ad) + + for i := 0; i < b.N; i++ { + src = client.Seal(src[:0], src[:l], protocol.PacketNumber(i), ad) + } +} + +func BenchmarkPacketDecryption(b *testing.B) { + client, server := getClientAndServer() + const l = 1200 + src := make([]byte, l) + dst := make([]byte, l) + rand.Read(src) + ad := make([]byte, 32) + rand.Read(ad) + src = client.Seal(src[:0], src[:l], 1337, ad) + + for i := 0; i < b.N; i++ { + if _, err := server.Open(dst[:0], src, time.Time{}, 1337, protocol.KeyPhaseZero, ad); err != nil { + b.Fatalf("opening failed: %v", err) + } + } +} + +func BenchmarkRollKeys(b *testing.B) { + client, _ := getClientAndServer() + for i := 0; i < b.N; i++ { + client.rollKeys() + } + if int(client.keyPhase) != b.N { + b.Fatal("didn't roll keys often enough") + } +}