diff --git a/client.go b/client.go index e1f03c87d..3d0bca3df 100644 --- a/client.go +++ b/client.go @@ -146,11 +146,7 @@ func dial( c.tracingID = nextConnTracingID() if c.config.Tracer != nil { - c.tracer = c.config.Tracer.TracerForConnection( - context.WithValue(ctx, ConnectionTracingKey, c.tracingID), - protocol.PerspectiveClient, - c.destConnID, - ) + c.tracer = c.config.Tracer(context.WithValue(ctx, ConnectionTracingKey, c.tracingID), protocol.PerspectiveClient, c.destConnID) } if c.tracer != nil { c.tracer.StartedConnection(c.sendConn.LocalAddr(), c.sendConn.RemoteAddr(), c.srcConnID, c.destConnID) diff --git a/client_test.go b/client_test.go index a3c31e1db..23847329c 100644 --- a/client_test.go +++ b/client_test.go @@ -56,10 +56,12 @@ var _ = Describe("Client", func() { connID = protocol.ParseConnectionID([]byte{0, 0, 0, 0, 0, 0, 0x13, 0x37}) originalClientConnConstructor = newClientConnection tracer = mocklogging.NewMockConnectionTracer(mockCtrl) - tr := mocklogging.NewMockTracer(mockCtrl) - tr.EXPECT().DroppedPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() - tr.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveClient, gomock.Any()).Return(tracer).MaxTimes(1) - config = &Config{Tracer: tr, Versions: []protocol.VersionNumber{protocol.Version1}} + config = &Config{ + Tracer: func(ctx context.Context, perspective logging.Perspective, id ConnectionID) logging.ConnectionTracer { + return tracer + }, + Versions: []protocol.VersionNumber{protocol.Version1}, + } Eventually(areConnsRunning).Should(BeFalse()) packetConn = NewMockSendConn(mockCtrl) packetConn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes() diff --git a/config_test.go b/config_test.go index d796b3643..53482ca48 100644 --- a/config_test.go +++ b/config_test.go @@ -1,14 +1,15 @@ package quic import ( + "context" "errors" "fmt" "net" "reflect" "time" - mocklogging "github.com/quic-go/quic-go/internal/mocks/logging" "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/logging" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" @@ -46,7 +47,7 @@ var _ = Describe("Config", func() { } switch fn := typ.Field(i).Name; fn { - case "GetConfigForClient", "RequireAddressValidation", "GetLogWriter", "AllowConnectionWindowIncrease": + case "GetConfigForClient", "RequireAddressValidation", "GetLogWriter", "AllowConnectionWindowIncrease", "Tracer": // Can't compare functions. case "Versions": f.Set(reflect.ValueOf([]VersionNumber{1, 2, 3})) @@ -88,8 +89,6 @@ var _ = Describe("Config", func() { f.Set(reflect.ValueOf(true)) case "Allow0RTT": f.Set(reflect.ValueOf(true)) - case "Tracer": - f.Set(reflect.ValueOf(mocklogging.NewMockTracer(mockCtrl))) default: Fail(fmt.Sprintf("all fields must be accounted for, but saw unknown field %q", fn)) } @@ -109,11 +108,15 @@ var _ = Describe("Config", func() { Context("cloning", func() { It("clones function fields", func() { - var calledAddrValidation, calledAllowConnectionWindowIncrease bool + var calledAddrValidation, calledAllowConnectionWindowIncrease, calledTracer bool c1 := &Config{ GetConfigForClient: func(info *ClientHelloInfo) (*Config, error) { return nil, errors.New("nope") }, AllowConnectionWindowIncrease: func(Connection, uint64) bool { calledAllowConnectionWindowIncrease = true; return true }, RequireAddressValidation: func(net.Addr) bool { calledAddrValidation = true; return true }, + Tracer: func(context.Context, logging.Perspective, ConnectionID) logging.ConnectionTracer { + calledTracer = true + return nil + }, } c2 := c1.Clone() c2.RequireAddressValidation(&net.UDPAddr{}) @@ -122,6 +125,8 @@ var _ = Describe("Config", func() { Expect(calledAllowConnectionWindowIncrease).To(BeTrue()) _, err := c2.GetConfigForClient(&ClientHelloInfo{}) Expect(err).To(MatchError("nope")) + c2.Tracer(context.Background(), logging.PerspectiveClient, protocol.ConnectionID{}) + Expect(calledTracer).To(BeTrue()) }) It("clones non-function fields", func() { diff --git a/example/client/main.go b/example/client/main.go index 19da87262..83f810fd1 100644 --- a/example/client/main.go +++ b/example/client/main.go @@ -3,6 +3,7 @@ package main import ( "bufio" "bytes" + "context" "crypto/tls" "crypto/x509" "flag" @@ -57,15 +58,15 @@ func main() { var qconf quic.Config if *enableQlog { - qconf.Tracer = qlog.NewTracer(func(_ logging.Perspective, connID []byte) io.WriteCloser { + qconf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer { filename := fmt.Sprintf("client_%x.qlog", connID) f, err := os.Create(filename) if err != nil { log.Fatal(err) } log.Printf("Creating qlog file %s.\n", filename) - return utils.NewBufferedWriteCloser(bufio.NewWriter(f), f) - }) + return qlog.NewConnectionTracer(utils.NewBufferedWriteCloser(bufio.NewWriter(f), f), p, connID) + } } roundTripper := &http3.RoundTripper{ TLSClientConfig: &tls.Config{ diff --git a/example/main.go b/example/main.go index ea2a0babe..058144050 100644 --- a/example/main.go +++ b/example/main.go @@ -2,6 +2,7 @@ package main import ( "bufio" + "context" "crypto/md5" "errors" "flag" @@ -162,15 +163,15 @@ func main() { handler := setupHandler(*www) quicConf := &quic.Config{} if *enableQlog { - quicConf.Tracer = qlog.NewTracer(func(_ logging.Perspective, connID []byte) io.WriteCloser { + quicConf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer { filename := fmt.Sprintf("server_%x.qlog", connID) f, err := os.Create(filename) if err != nil { log.Fatal(err) } log.Printf("Creating qlog file %s.\n", filename) - return utils.NewBufferedWriteCloser(bufio.NewWriter(f), f) - }) + return qlog.NewConnectionTracer(utils.NewBufferedWriteCloser(bufio.NewWriter(f), f), p, connID) + } } var wg sync.WaitGroup diff --git a/integrationtests/self/key_update_test.go b/integrationtests/self/key_update_test.go index 51ae6be01..c24bef27d 100644 --- a/integrationtests/self/key_update_test.go +++ b/integrationtests/self/key_update_test.go @@ -75,7 +75,9 @@ var _ = Describe("Key Update tests", func() { context.Background(), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), - getQuicConfig(&quic.Config{Tracer: newTracer(func() logging.ConnectionTracer { return &keyUpdateConnTracer{} })}), + getQuicConfig(&quic.Config{Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer { + return &keyUpdateConnTracer{} + }}), ) Expect(err).ToNot(HaveOccurred()) str, err := conn.AcceptUniStream(context.Background()) diff --git a/integrationtests/self/packetization_test.go b/integrationtests/self/packetization_test.go index 86062bd56..740956c54 100644 --- a/integrationtests/self/packetization_test.go +++ b/integrationtests/self/packetization_test.go @@ -27,7 +27,7 @@ var _ = Describe("Packetization", func() { getTLSConfig(), getQuicConfig(&quic.Config{ DisablePathMTUDiscovery: true, - Tracer: newTracer(func() logging.ConnectionTracer { return serverTracer }), + Tracer: newTracer(serverTracer), }), ) Expect(err).ToNot(HaveOccurred()) @@ -50,7 +50,7 @@ var _ = Describe("Packetization", func() { getTLSClientConfig(), getQuicConfig(&quic.Config{ DisablePathMTUDiscovery: true, - Tracer: newTracer(func() logging.ConnectionTracer { return clientTracer }), + Tracer: newTracer(clientTracer), }), ) Expect(err).ToNot(HaveOccurred()) diff --git a/integrationtests/self/self_suite_test.go b/integrationtests/self/self_suite_test.go index 966a57158..10adc856b 100644 --- a/integrationtests/self/self_suite_test.go +++ b/integrationtests/self/self_suite_test.go @@ -87,7 +87,7 @@ var ( logBuf *syncedBuffer versionParam string - qlogTracer logging.Tracer + qlogTracer func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer enableQlog bool version quic.VersionNumber @@ -175,7 +175,13 @@ func getQuicConfig(conf *quic.Config) *quic.Config { if conf.Tracer == nil { conf.Tracer = qlogTracer } else if qlogTracer != nil { - conf.Tracer = logging.NewMultiplexedTracer(qlogTracer, conf.Tracer) + origTracer := conf.Tracer + conf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer { + return logging.NewMultiplexedConnectionTracer( + qlogTracer(ctx, p, connID), + origTracer(ctx, p, connID), + ) + } } } return conf @@ -232,19 +238,8 @@ func scaleDuration(d time.Duration) time.Duration { return time.Duration(scaleFactor) * d } -type tracer struct { - logging.NullTracer - createNewConnTracer func() logging.ConnectionTracer -} - -var _ logging.Tracer = &tracer{} - -func newTracer(c func() logging.ConnectionTracer) logging.Tracer { - return &tracer{createNewConnTracer: c} -} - -func (t *tracer) TracerForConnection(context.Context, logging.Perspective, logging.ConnectionID) logging.ConnectionTracer { - return t.createNewConnTracer() +func newTracer(tracer logging.ConnectionTracer) func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer { + return func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer { return tracer } } type packet struct { diff --git a/integrationtests/self/timeout_test.go b/integrationtests/self/timeout_test.go index abc05dd79..5a01270d3 100644 --- a/integrationtests/self/timeout_test.go +++ b/integrationtests/self/timeout_test.go @@ -201,7 +201,7 @@ var _ = Describe("Timeout tests", func() { getTLSClientConfig(), getQuicConfig(&quic.Config{ MaxIdleTimeout: idleTimeout, - Tracer: newTracer(func() logging.ConnectionTracer { return tr }), + Tracer: newTracer(tr), DisablePathMTUDiscovery: true, }), ) diff --git a/integrationtests/self/tracer_test.go b/integrationtests/self/tracer_test.go index 377fbe1ed..3bfae3c6b 100644 --- a/integrationtests/self/tracer_test.go +++ b/integrationtests/self/tracer_test.go @@ -19,14 +19,6 @@ import ( . "github.com/onsi/gomega" ) -type customTracer struct{ logging.NullTracer } - -func (t *customTracer) TracerForConnection(context.Context, logging.Perspective, logging.ConnectionID) logging.ConnectionTracer { - return &customConnTracer{} -} - -type customConnTracer struct{ logging.NullConnectionTracer } - var _ = Describe("Handshake tests", func() { addTracers := func(pers protocol.Perspective, conf *quic.Config) *quic.Config { enableQlog := mrand.Int()%3 != 0 @@ -34,22 +26,32 @@ var _ = Describe("Handshake tests", func() { fmt.Fprintf(GinkgoWriter, "%s using qlog: %t, custom: %t\n", pers, enableQlog, enableCustomTracer) - var tracers []logging.Tracer + var tracerConstructors []func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer if enableQlog { - tracers = append(tracers, qlog.NewTracer(func(p logging.Perspective, connectionID []byte) io.WriteCloser { + tracerConstructors = append(tracerConstructors, func(_ context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer { if mrand.Int()%2 == 0 { // simulate that a qlog collector might only want to log some connections - fmt.Fprintf(GinkgoWriter, "%s qlog tracer deciding to not trace connection %x\n", p, connectionID) + fmt.Fprintf(GinkgoWriter, "%s qlog tracer deciding to not trace connection %x\n", p, connID) return nil } - fmt.Fprintf(GinkgoWriter, "%s qlog tracing connection %x\n", p, connectionID) - return utils.NewBufferedWriteCloser(bufio.NewWriter(&bytes.Buffer{}), io.NopCloser(nil)) - })) + fmt.Fprintf(GinkgoWriter, "%s qlog tracing connection %x\n", p, connID) + return qlog.NewConnectionTracer(utils.NewBufferedWriteCloser(bufio.NewWriter(&bytes.Buffer{}), io.NopCloser(nil)), p, connID) + }) } if enableCustomTracer { - tracers = append(tracers, &customTracer{}) + tracerConstructors = append(tracerConstructors, func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer { + return logging.NullConnectionTracer{} + }) } c := conf.Clone() - c.Tracer = logging.NewMultiplexedTracer(tracers...) + c.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer { + tracers := make([]logging.ConnectionTracer, 0, len(tracerConstructors)) + for _, c := range tracerConstructors { + if tr := c(ctx, p, connID); tr != nil { + tracers = append(tracers, tr) + } + } + return logging.NewMultiplexedConnectionTracer(tracers...) + } return c } diff --git a/integrationtests/self/zero_rtt_test.go b/integrationtests/self/zero_rtt_test.go index 684cfe6ae..001ea5da2 100644 --- a/integrationtests/self/zero_rtt_test.go +++ b/integrationtests/self/zero_rtt_test.go @@ -223,7 +223,7 @@ var _ = Describe("0-RTT", func() { tlsConf, getQuicConfig(&quic.Config{ Allow0RTT: true, - Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), + Tracer: newTracer(tracer), }), ) Expect(err).ToNot(HaveOccurred()) @@ -277,7 +277,7 @@ var _ = Describe("0-RTT", func() { tlsConf, getQuicConfig(&quic.Config{ Allow0RTT: true, - Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), + Tracer: newTracer(tracer), }), ) Expect(err).ToNot(HaveOccurred()) @@ -359,7 +359,7 @@ var _ = Describe("0-RTT", func() { tlsConf, getQuicConfig(&quic.Config{ Allow0RTT: true, - Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), + Tracer: newTracer(tracer), }), ) Expect(err).ToNot(HaveOccurred()) @@ -435,7 +435,7 @@ var _ = Describe("0-RTT", func() { getQuicConfig(&quic.Config{ RequireAddressValidation: func(net.Addr) bool { return true }, Allow0RTT: true, - Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), + Tracer: newTracer(tracer), }), ) Expect(err).ToNot(HaveOccurred()) @@ -496,7 +496,7 @@ var _ = Describe("0-RTT", func() { getQuicConfig(&quic.Config{ MaxIncomingUniStreams: maxStreams + 1, Allow0RTT: true, - Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), + Tracer: newTracer(tracer), }), ) Expect(err).ToNot(HaveOccurred()) @@ -541,7 +541,7 @@ var _ = Describe("0-RTT", func() { getQuicConfig(&quic.Config{ MaxIncomingStreams: maxStreams - 1, Allow0RTT: true, - Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), + Tracer: newTracer(tracer), }), ) Expect(err).ToNot(HaveOccurred()) @@ -569,7 +569,7 @@ var _ = Describe("0-RTT", func() { tlsConf, getQuicConfig(&quic.Config{ Allow0RTT: true, - Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), + Tracer: newTracer(tracer), }), ) Expect(err).ToNot(HaveOccurred()) @@ -596,7 +596,7 @@ var _ = Describe("0-RTT", func() { tlsConf, getQuicConfig(&quic.Config{ Allow0RTT: false, // application rejects 0-RTT - Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), + Tracer: newTracer(tracer), }), ) Expect(err).ToNot(HaveOccurred()) @@ -622,7 +622,7 @@ var _ = Describe("0-RTT", func() { secondConf := getQuicConfig(&quic.Config{ Allow0RTT: true, - Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), + Tracer: newTracer(tracer), }) addFlowControlLimit(secondConf, 100) ln, err := quic.ListenAddrEarly( @@ -699,7 +699,7 @@ var _ = Describe("0-RTT", func() { tlsConf, getQuicConfig(&quic.Config{ MaxIncomingUniStreams: 1, - Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), + Tracer: newTracer(tracer), }), ) Expect(err).ToNot(HaveOccurred()) @@ -775,7 +775,7 @@ var _ = Describe("0-RTT", func() { tlsConf, getQuicConfig(&quic.Config{ Allow0RTT: true, - Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), + Tracer: newTracer(tracer), }), ) Expect(err).ToNot(HaveOccurred()) diff --git a/integrationtests/tools/qlog.go b/integrationtests/tools/qlog.go index a0854260c..352e0a613 100644 --- a/integrationtests/tools/qlog.go +++ b/integrationtests/tools/qlog.go @@ -2,23 +2,25 @@ package tools import ( "bufio" + "context" "fmt" "io" "log" "os" + "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/internal/utils" "github.com/quic-go/quic-go/logging" "github.com/quic-go/quic-go/qlog" ) -func NewQlogger(logger io.Writer) logging.Tracer { - return qlog.NewTracer(func(p logging.Perspective, connectionID []byte) io.WriteCloser { +func NewQlogger(logger io.Writer) func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer { + return func(_ context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer { role := "server" if p == logging.PerspectiveClient { role = "client" } - filename := fmt.Sprintf("log_%x_%s.qlog", connectionID, role) + filename := fmt.Sprintf("log_%x_%s.qlog", connID.Bytes(), role) fmt.Fprintf(logger, "Creating %s.\n", filename) f, err := os.Create(filename) if err != nil { @@ -26,6 +28,6 @@ func NewQlogger(logger io.Writer) logging.Tracer { return nil } bw := bufio.NewWriter(f) - return utils.NewBufferedWriteCloser(bw, f) - }) + return qlog.NewConnectionTracer(utils.NewBufferedWriteCloser(bw, f), p, connID) + } } diff --git a/integrationtests/versionnegotiation/handshake_test.go b/integrationtests/versionnegotiation/handshake_test.go index ad3b7f7a8..965700c15 100644 --- a/integrationtests/versionnegotiation/handshake_test.go +++ b/integrationtests/versionnegotiation/handshake_test.go @@ -85,7 +85,9 @@ var _ = Describe("Handshake tests", func() { serverConfig := &quic.Config{} serverConfig.Versions = []protocol.VersionNumber{7, 8, protocol.SupportedVersions[0], 9} serverTracer := &versionNegotiationTracer{} - serverConfig.Tracer = newTracer(func() logging.ConnectionTracer { return serverTracer }) + serverConfig.Tracer = func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer { + return serverTracer + } server, cl := startServer(getTLSConfig(), serverConfig) defer cl() clientTracer := &versionNegotiationTracer{} @@ -93,7 +95,9 @@ var _ = Describe("Handshake tests", func() { context.Background(), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), - maybeAddQlogTracer(&quic.Config{Tracer: newTracer(func() logging.ConnectionTracer { return clientTracer })}), + maybeAddQLOGTracer(&quic.Config{Tracer: func(ctx context.Context, perspective logging.Perspective, id quic.ConnectionID) logging.ConnectionTracer { + return clientTracer + }}), ) Expect(err).ToNot(HaveOccurred()) Expect(conn.(versioner).GetVersion()).To(Equal(expectedVersion)) @@ -111,10 +115,12 @@ var _ = Describe("Handshake tests", func() { expectedVersion := protocol.SupportedVersions[0] // the server doesn't support the highest supported version, which is the first one the client will try // but it supports a bunch of versions that the client doesn't speak + serverTracer := &versionNegotiationTracer{} serverConfig := &quic.Config{} serverConfig.Versions = supportedVersions - serverTracer := &versionNegotiationTracer{} - serverConfig.Tracer = newTracer(func() logging.ConnectionTracer { return serverTracer }) + serverConfig.Tracer = func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer { + return serverTracer + } server, cl := startServer(getTLSConfig(), serverConfig) defer cl() clientVersions := []protocol.VersionNumber{7, 8, 9, protocol.SupportedVersions[0], 10} @@ -123,9 +129,11 @@ var _ = Describe("Handshake tests", func() { context.Background(), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), - maybeAddQlogTracer(&quic.Config{ + maybeAddQLOGTracer(&quic.Config{ Versions: clientVersions, - Tracer: newTracer(func() logging.ConnectionTracer { return clientTracer }), + Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer { + return clientTracer + }, }), ) Expect(err).ToNot(HaveOccurred()) diff --git a/integrationtests/versionnegotiation/rtt_test.go b/integrationtests/versionnegotiation/rtt_test.go index 2c8868dbd..ce9cba1be 100644 --- a/integrationtests/versionnegotiation/rtt_test.go +++ b/integrationtests/versionnegotiation/rtt_test.go @@ -47,7 +47,7 @@ var _ = Describe("Handshake RTT tests", func() { context.Background(), proxy.LocalAddr().String(), getTLSClientConfig(), - maybeAddQlogTracer(&quic.Config{Versions: protocol.SupportedVersions[1:2]}), + maybeAddQLOGTracer(&quic.Config{Versions: protocol.SupportedVersions[1:2]}), ) Expect(err).To(HaveOccurred()) expectDurationInRTTs(startTime, 1) diff --git a/integrationtests/versionnegotiation/versionnegotiation_suite_test.go b/integrationtests/versionnegotiation/versionnegotiation_suite_test.go index 96a98fa8e..a01ac1f8a 100644 --- a/integrationtests/versionnegotiation/versionnegotiation_suite_test.go +++ b/integrationtests/versionnegotiation/versionnegotiation_suite_test.go @@ -58,7 +58,7 @@ func TestQuicVersionNegotiation(t *testing.T) { RunSpecs(t, "Version Negotiation Suite") } -func maybeAddQlogTracer(c *quic.Config) *quic.Config { +func maybeAddQLOGTracer(c *quic.Config) *quic.Config { if c == nil { c = &quic.Config{} } @@ -69,22 +69,13 @@ func maybeAddQlogTracer(c *quic.Config) *quic.Config { if c.Tracer == nil { c.Tracer = qlogger } else if qlogger != nil { - c.Tracer = logging.NewMultiplexedTracer(qlogger, c.Tracer) + origTracer := c.Tracer + c.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer { + return logging.NewMultiplexedConnectionTracer( + qlogger(ctx, p, connID), + origTracer(ctx, p, connID), + ) + } } return c } - -type tracer struct { - logging.NullTracer - createNewConnTracer func() logging.ConnectionTracer -} - -var _ logging.Tracer = &tracer{} - -func newTracer(c func() logging.ConnectionTracer) logging.Tracer { - return &tracer{createNewConnTracer: c} -} - -func (t *tracer) TracerForConnection(context.Context, logging.Perspective, logging.ConnectionID) logging.ConnectionTracer { - return t.createNewConnTracer() -} diff --git a/interface.go b/interface.go index 4d2a91976..7022cac4e 100644 --- a/interface.go +++ b/interface.go @@ -322,7 +322,7 @@ type Config struct { Allow0RTT bool // Enable QUIC datagram support (RFC 9221). EnableDatagrams bool - Tracer logging.Tracer + Tracer func(context.Context, logging.Perspective, ConnectionID) logging.ConnectionTracer } type ClientHelloInfo struct { diff --git a/internal/mocks/logging/tracer.go b/internal/mocks/logging/tracer.go index f269e7d65..56741ef29 100644 --- a/internal/mocks/logging/tracer.go +++ b/internal/mocks/logging/tracer.go @@ -5,7 +5,6 @@ package mocklogging import ( - context "context" net "net" reflect "reflect" @@ -73,17 +72,3 @@ func (mr *MockTracerMockRecorder) SentVersionNegotiationPacket(arg0, arg1, arg2, mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentVersionNegotiationPacket", reflect.TypeOf((*MockTracer)(nil).SentVersionNegotiationPacket), arg0, arg1, arg2, arg3) } - -// TracerForConnection mocks base method. -func (m *MockTracer) TracerForConnection(arg0 context.Context, arg1 protocol.Perspective, arg2 protocol.ConnectionID) logging.ConnectionTracer { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "TracerForConnection", arg0, arg1, arg2) - ret0, _ := ret[0].(logging.ConnectionTracer) - return ret0 -} - -// TracerForConnection indicates an expected call of TracerForConnection. -func (mr *MockTracerMockRecorder) TracerForConnection(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TracerForConnection", reflect.TypeOf((*MockTracer)(nil).TracerForConnection), arg0, arg1, arg2) -} diff --git a/interop/client/main.go b/interop/client/main.go index eaec9e67a..747619a09 100644 --- a/interop/client/main.go +++ b/interop/client/main.go @@ -21,7 +21,6 @@ import ( "github.com/quic-go/quic-go/internal/qtls" "github.com/quic-go/quic-go/interop/http09" "github.com/quic-go/quic-go/interop/utils" - "github.com/quic-go/quic-go/qlog" ) var errUnsupported = errors.New("unsupported test case") @@ -65,11 +64,7 @@ func runTestcase(testcase string) error { flag.Parse() urls := flag.Args() - getLogWriter, err := utils.GetQLOGWriter() - if err != nil { - return err - } - quicConf := &quic.Config{Tracer: qlog.NewTracer(getLogWriter)} + quicConf := &quic.Config{Tracer: utils.NewQLOGConnectionTracer} if testcase == "http3" { r := &http3.RoundTripper{ diff --git a/interop/server/main.go b/interop/server/main.go index 6c607c5c0..df7044624 100644 --- a/interop/server/main.go +++ b/interop/server/main.go @@ -13,7 +13,6 @@ import ( "github.com/quic-go/quic-go/internal/qtls" "github.com/quic-go/quic-go/interop/http09" "github.com/quic-go/quic-go/interop/utils" - "github.com/quic-go/quic-go/qlog" ) var tlsConf *tls.Config @@ -38,16 +37,10 @@ func main() { testcase := os.Getenv("TESTCASE") - getLogWriter, err := utils.GetQLOGWriter() - if err != nil { - fmt.Println(err.Error()) - os.Exit(1) - } - // a quic.Config that doesn't do a Retry quicConf := &quic.Config{ RequireAddressValidation: func(net.Addr) bool { return testcase == "retry" }, Allow0RTT: testcase == "zerortt", - Tracer: qlog.NewTracer(getLogWriter), + Tracer: utils.NewQLOGConnectionTracer, } cert, err := tls.LoadX509KeyPair("/certs/cert.pem", "/certs/priv.key") if err != nil { diff --git a/interop/utils/logging.go b/interop/utils/logging.go index 3a6940ad9..30e3f663f 100644 --- a/interop/utils/logging.go +++ b/interop/utils/logging.go @@ -2,14 +2,17 @@ package utils import ( "bufio" + "context" "fmt" "io" "log" "os" "strings" + "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/internal/utils" "github.com/quic-go/quic-go/logging" + "github.com/quic-go/quic-go/qlog" ) // GetSSLKeyLog creates a file for the TLS key log @@ -25,25 +28,23 @@ func GetSSLKeyLog() (io.WriteCloser, error) { return f, nil } -// GetQLOGWriter creates the QLOGDIR and returns the GetLogWriter callback -func GetQLOGWriter() (func(perspective logging.Perspective, connID []byte) io.WriteCloser, error) { +// NewQLOGConnectionTracer create a qlog file in QLOGDIR +func NewQLOGConnectionTracer(_ context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer { qlogDir := os.Getenv("QLOGDIR") if len(qlogDir) == 0 { - return nil, nil + return nil } if _, err := os.Stat(qlogDir); os.IsNotExist(err) { if err := os.MkdirAll(qlogDir, 0o666); err != nil { - return nil, fmt.Errorf("failed to create qlog dir %s: %s", qlogDir, err.Error()) + log.Fatalf("failed to create qlog dir %s: %v", qlogDir, err) } } - return func(_ logging.Perspective, connID []byte) io.WriteCloser { - path := fmt.Sprintf("%s/%x.qlog", strings.TrimRight(qlogDir, "/"), connID) - f, err := os.Create(path) - if err != nil { - log.Printf("Failed to create qlog file %s: %s", path, err.Error()) - return nil - } - log.Printf("Created qlog file: %s\n", path) - return utils.NewBufferedWriteCloser(bufio.NewWriter(f), f) - }, nil + path := fmt.Sprintf("%s/%x.qlog", strings.TrimRight(qlogDir, "/"), connID) + f, err := os.Create(path) + if err != nil { + log.Printf("Failed to create qlog file %s: %s", path, err.Error()) + return nil + } + log.Printf("Created qlog file: %s\n", path) + return qlog.NewConnectionTracer(utils.NewBufferedWriteCloser(bufio.NewWriter(f), f), p, connID) } diff --git a/logging/interface.go b/logging/interface.go index efcef151e..2ce8582ec 100644 --- a/logging/interface.go +++ b/logging/interface.go @@ -3,7 +3,6 @@ package logging import ( - "context" "net" "time" @@ -101,12 +100,6 @@ type ShortHeader struct { // A Tracer traces events. type Tracer interface { - // TracerForConnection requests a new tracer for a connection. - // The ODCID is the original destination connection ID: - // The destination connection ID that the client used on the first Initial packet it sent on this connection. - // If nil is returned, tracing will be disabled for this connection. - TracerForConnection(ctx context.Context, p Perspective, odcid ConnectionID) ConnectionTracer - SentPacket(net.Addr, *Header, ByteCount, []Frame) SentVersionNegotiationPacket(_ net.Addr, dest, src ArbitraryLenConnectionID, _ []VersionNumber) DroppedPacket(net.Addr, PacketType, ByteCount, PacketDropReason) diff --git a/logging/mock_tracer_test.go b/logging/mock_tracer_test.go index 935d899bd..8526cd3a3 100644 --- a/logging/mock_tracer_test.go +++ b/logging/mock_tracer_test.go @@ -5,7 +5,6 @@ package logging import ( - context "context" net "net" reflect "reflect" @@ -72,17 +71,3 @@ func (mr *MockTracerMockRecorder) SentVersionNegotiationPacket(arg0, arg1, arg2, mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentVersionNegotiationPacket", reflect.TypeOf((*MockTracer)(nil).SentVersionNegotiationPacket), arg0, arg1, arg2, arg3) } - -// TracerForConnection mocks base method. -func (m *MockTracer) TracerForConnection(arg0 context.Context, arg1 protocol.Perspective, arg2 protocol.ConnectionID) ConnectionTracer { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "TracerForConnection", arg0, arg1, arg2) - ret0, _ := ret[0].(ConnectionTracer) - return ret0 -} - -// TracerForConnection indicates an expected call of TracerForConnection. -func (mr *MockTracerMockRecorder) TracerForConnection(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TracerForConnection", reflect.TypeOf((*MockTracer)(nil).TracerForConnection), arg0, arg1, arg2) -} diff --git a/logging/multiplex.go b/logging/multiplex.go index 8e85db494..672a5cdbd 100644 --- a/logging/multiplex.go +++ b/logging/multiplex.go @@ -1,7 +1,6 @@ package logging import ( - "context" "net" "time" ) @@ -23,16 +22,6 @@ func NewMultiplexedTracer(tracers ...Tracer) Tracer { return &tracerMultiplexer{tracers} } -func (m *tracerMultiplexer) TracerForConnection(ctx context.Context, p Perspective, odcid ConnectionID) ConnectionTracer { - var connTracers []ConnectionTracer - for _, t := range m.tracers { - if ct := t.TracerForConnection(ctx, p, odcid); ct != nil { - connTracers = append(connTracers, ct) - } - } - return NewMultiplexedConnectionTracer(connTracers...) -} - func (m *tracerMultiplexer) SentPacket(remote net.Addr, hdr *Header, size ByteCount, frames []Frame) { for _, t := range m.tracers { t.SentPacket(remote, hdr, size, frames) diff --git a/logging/multiplex_test.go b/logging/multiplex_test.go index 9b8515509..d22204999 100644 --- a/logging/multiplex_test.go +++ b/logging/multiplex_test.go @@ -1,7 +1,6 @@ package logging import ( - "context" "errors" "net" "time" @@ -37,46 +36,6 @@ var _ = Describe("Tracing", func() { tracer = NewMultiplexedTracer(tr1, tr2) }) - It("multiplexes the TracerForConnection call", func() { - ctx := context.Background() - connID := protocol.ParseConnectionID([]byte{1, 2, 3}) - tr1.EXPECT().TracerForConnection(ctx, PerspectiveClient, connID) - tr2.EXPECT().TracerForConnection(ctx, PerspectiveClient, connID) - tracer.TracerForConnection(ctx, PerspectiveClient, connID) - }) - - It("uses multiple connection tracers", func() { - ctx := context.Background() - ctr1 := NewMockConnectionTracer(mockCtrl) - ctr2 := NewMockConnectionTracer(mockCtrl) - connID := protocol.ParseConnectionID([]byte{1, 2, 3}) - tr1.EXPECT().TracerForConnection(ctx, PerspectiveServer, connID).Return(ctr1) - tr2.EXPECT().TracerForConnection(ctx, PerspectiveServer, connID).Return(ctr2) - tr := tracer.TracerForConnection(ctx, PerspectiveServer, connID) - ctr1.EXPECT().LossTimerCanceled() - ctr2.EXPECT().LossTimerCanceled() - tr.LossTimerCanceled() - }) - - It("handles tracers that return a nil ConnectionTracer", func() { - ctx := context.Background() - ctr1 := NewMockConnectionTracer(mockCtrl) - connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4}) - tr1.EXPECT().TracerForConnection(ctx, PerspectiveServer, connID).Return(ctr1) - tr2.EXPECT().TracerForConnection(ctx, PerspectiveServer, connID) - tr := tracer.TracerForConnection(ctx, PerspectiveServer, connID) - ctr1.EXPECT().LossTimerCanceled() - tr.LossTimerCanceled() - }) - - It("returns nil when all tracers return a nil ConnectionTracer", func() { - ctx := context.Background() - connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5}) - tr1.EXPECT().TracerForConnection(ctx, PerspectiveClient, connID) - tr2.EXPECT().TracerForConnection(ctx, PerspectiveClient, connID) - Expect(tracer.TracerForConnection(ctx, PerspectiveClient, connID)).To(BeNil()) - }) - It("traces the PacketSent event", func() { remote := &net.UDPAddr{IP: net.IPv4(4, 3, 2, 1)} hdr := &Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3})} diff --git a/logging/null_tracer.go b/logging/null_tracer.go index 38052ae3b..de9703857 100644 --- a/logging/null_tracer.go +++ b/logging/null_tracer.go @@ -1,7 +1,6 @@ package logging import ( - "context" "net" "time" ) @@ -12,9 +11,6 @@ type NullTracer struct{} var _ Tracer = &NullTracer{} -func (n NullTracer) TracerForConnection(context.Context, Perspective, ConnectionID) ConnectionTracer { - return NullConnectionTracer{} -} func (n NullTracer) SentPacket(net.Addr, *Header, ByteCount, []Frame) {} func (n NullTracer) SentVersionNegotiationPacket(_ net.Addr, dest, src ArbitraryLenConnectionID, _ []VersionNumber) { } diff --git a/qlog/qlog.go b/qlog/qlog.go index bc2bb233d..4c480e260 100644 --- a/qlog/qlog.go +++ b/qlog/qlog.go @@ -2,7 +2,6 @@ package qlog import ( "bytes" - "context" "fmt" "io" "log" @@ -49,26 +48,6 @@ func init() { const eventChanSize = 50 -type tracer struct { - logging.NullTracer - - getLogWriter func(p logging.Perspective, connectionID []byte) io.WriteCloser -} - -var _ logging.Tracer = &tracer{} - -// NewTracer creates a new qlog tracer. -func NewTracer(getLogWriter func(p logging.Perspective, connectionID []byte) io.WriteCloser) logging.Tracer { - return &tracer{getLogWriter: getLogWriter} -} - -func (t *tracer) TracerForConnection(_ context.Context, p logging.Perspective, odcid protocol.ConnectionID) logging.ConnectionTracer { - if w := t.getLogWriter(p, odcid.Bytes()); w != nil { - return NewConnectionTracer(w, p, odcid) - } - return nil -} - type connectionTracer struct { mutex sync.Mutex diff --git a/qlog/qlog_test.go b/qlog/qlog_test.go index b4105c051..dc0d2dc18 100644 --- a/qlog/qlog_test.go +++ b/qlog/qlog_test.go @@ -2,7 +2,6 @@ package qlog import ( "bytes" - "context" "encoding/json" "errors" "io" @@ -51,17 +50,6 @@ type entry struct { } var _ = Describe("Tracing", func() { - Context("tracer", func() { - It("returns nil when there's no io.WriteCloser", func() { - t := NewTracer(func(logging.Perspective, []byte) io.WriteCloser { return nil }) - Expect(t.TracerForConnection( - context.Background(), - logging.PerspectiveClient, - protocol.ParseConnectionID([]byte{1, 2, 3, 4}), - )).To(BeNil()) - }) - }) - It("stops writing when encountering an error", func() { buf := &bytes.Buffer{} t := NewConnectionTracer( @@ -88,9 +76,8 @@ var _ = Describe("Tracing", func() { BeforeEach(func() { buf = &bytes.Buffer{} - t := NewTracer(func(logging.Perspective, []byte) io.WriteCloser { return nopWriteCloser(buf) }) - tracer = t.TracerForConnection( - context.Background(), + tracer = NewConnectionTracer( + nopWriteCloser(buf), logging.PerspectiveServer, protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}), ) diff --git a/server.go b/server.go index a38a57848..f5b04549e 100644 --- a/server.go +++ b/server.go @@ -106,6 +106,8 @@ type baseServer struct { connQueue chan quicConn connQueueLen int32 // to be used as an atomic + tracer logging.Tracer + logger utils.Logger } @@ -212,7 +214,16 @@ func ListenEarly(conn net.PacketConn, tlsConf *tls.Config, config *Config) (*Ear return tr.ListenEarly(tlsConf, config) } -func newServer(conn rawConn, connHandler packetHandlerManager, connIDGenerator ConnectionIDGenerator, tlsConf *tls.Config, config *Config, onClose func(), acceptEarly bool) (*baseServer, error) { +func newServer( + conn rawConn, + connHandler packetHandlerManager, + connIDGenerator ConnectionIDGenerator, + tlsConf *tls.Config, + config *Config, + tracer logging.Tracer, + onClose func(), + acceptEarly bool, +) (*baseServer, error) { tokenGenerator, err := handshake.NewTokenGenerator(rand.Reader) if err != nil { return nil, err @@ -229,6 +240,7 @@ func newServer(conn rawConn, connHandler packetHandlerManager, connIDGenerator C running: make(chan struct{}), receivedPackets: make(chan *receivedPacket, protocol.MaxServerUnprocessedPackets), newConn: newConnection, + tracer: tracer, logger: utils.DefaultLogger.WithPrefix("server"), acceptEarlyConns: acceptEarly, onClose: onClose, @@ -318,8 +330,8 @@ func (s *baseServer) handlePacket(p *receivedPacket) { case s.receivedPackets <- p: default: s.logger.Debugf("Dropping packet from %s (%d bytes). Server receive queue full.", p.remoteAddr, p.Size()) - if s.config.Tracer != nil { - s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention) + if s.tracer != nil { + s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention) } } } @@ -331,8 +343,8 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer s if wire.IsVersionNegotiationPacket(p.data) { s.logger.Debugf("Dropping Version Negotiation packet.") - if s.config.Tracer != nil { - s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeVersionNegotiation, p.Size(), logging.PacketDropUnexpectedPacket) + if s.tracer != nil { + s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeVersionNegotiation, p.Size(), logging.PacketDropUnexpectedPacket) } return false } @@ -345,16 +357,16 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer s if err != nil || !protocol.IsSupportedVersion(s.config.Versions, v) { if err != nil || p.Size() < protocol.MinUnknownVersionPacketSize { s.logger.Debugf("Dropping a packet with an unknown version that is too small (%d bytes)", p.Size()) - if s.config.Tracer != nil { - s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket) + if s.tracer != nil { + s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket) } return false } _, src, dest, err := wire.ParseArbitraryLenConnectionIDs(p.data) if err != nil { // should never happen s.logger.Debugf("Dropping a packet with an unknown version for which we failed to parse connection IDs") - if s.config.Tracer != nil { - s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket) + if s.tracer != nil { + s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket) } return false } @@ -366,8 +378,8 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer s if wire.Is0RTTPacket(p.data) { if !s.acceptEarlyConns { - if s.config.Tracer != nil { - s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropUnexpectedPacket) + if s.tracer != nil { + s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropUnexpectedPacket) } return false } @@ -378,16 +390,16 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer s // The header will then be parsed again. hdr, _, _, err := wire.ParsePacket(p.data) if err != nil { - if s.config.Tracer != nil { - s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError) + if s.tracer != nil { + s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError) } s.logger.Debugf("Error parsing packet: %s", err) return false } if hdr.Type == protocol.PacketTypeInitial && p.Size() < protocol.MinInitialPacketSize { s.logger.Debugf("Dropping a packet that is too small to be a valid Initial (%d bytes)", p.Size()) - if s.config.Tracer != nil { - s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket) + if s.tracer != nil { + s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket) } return false } @@ -397,8 +409,8 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer s // There's little point in sending a Stateless Reset, since the client // might not have received the token yet. s.logger.Debugf("Dropping long header packet of type %s (%d bytes)", hdr.Type, len(p.data)) - if s.config.Tracer != nil { - s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeFromHeader(hdr), p.Size(), logging.PacketDropUnexpectedPacket) + if s.tracer != nil { + s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeFromHeader(hdr), p.Size(), logging.PacketDropUnexpectedPacket) } return false } @@ -416,8 +428,8 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer s func (s *baseServer) handle0RTTPacket(p *receivedPacket) bool { connID, err := wire.ParseConnectionID(p.data, 0) if err != nil { - if s.config.Tracer != nil { - s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropHeaderParseError) + if s.tracer != nil { + s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropHeaderParseError) } return false } @@ -430,8 +442,8 @@ func (s *baseServer) handle0RTTPacket(p *receivedPacket) bool { if q, ok := s.zeroRTTQueues[connID]; ok { if len(q.packets) >= protocol.Max0RTTQueueLen { - if s.config.Tracer != nil { - s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention) + if s.tracer != nil { + s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention) } return false } @@ -440,8 +452,8 @@ func (s *baseServer) handle0RTTPacket(p *receivedPacket) bool { } if len(s.zeroRTTQueues) >= protocol.Max0RTTQueues { - if s.config.Tracer != nil { - s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention) + if s.tracer != nil { + s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention) } return false } @@ -468,8 +480,8 @@ func (s *baseServer) cleanupZeroRTTQueues(now time.Time) { continue } for _, p := range q.packets { - if s.config.Tracer != nil { - s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention) + if s.tracer != nil { + s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention) } p.buffer.Release() } @@ -504,8 +516,8 @@ func (s *baseServer) validateToken(token *handshake.Token, addr net.Addr) bool { func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) error { if len(hdr.Token) == 0 && hdr.DestConnectionID.Len() < protocol.MinConnectionIDLenInitial { p.buffer.Release() - if s.config.Tracer != nil { - s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket) + if s.tracer != nil { + s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket) } return errors.New("too short connection ID") } @@ -585,19 +597,6 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro var conn quicConn tracingID := nextConnTracingID() if added := s.connHandler.AddWithConnID(hdr.DestConnectionID, connID, func() (packetHandler, bool) { - var tracer logging.ConnectionTracer - if s.config.Tracer != nil { - // Use the same connection ID that is passed to the client's GetLogWriter callback. - connID := hdr.DestConnectionID - if origDestConnID.Len() > 0 { - connID = origDestConnID - } - tracer = s.config.Tracer.TracerForConnection( - context.WithValue(context.Background(), ConnectionTracingKey, tracingID), - protocol.PerspectiveServer, - connID, - ) - } config := s.config if s.config.GetConfigForClient != nil { conf, err := s.config.GetConfigForClient(&ClientHelloInfo{RemoteAddr: p.remoteAddr}) @@ -607,6 +606,15 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro } config = populateConfig(conf) } + var tracer logging.ConnectionTracer + if config.Tracer != nil { + // Use the same connection ID that is passed to the client's GetLogWriter callback. + connID := hdr.DestConnectionID + if origDestConnID.Len() > 0 { + connID = origDestConnID + } + tracer = config.Tracer(context.WithValue(context.Background(), ConnectionTracingKey, tracingID), protocol.PerspectiveServer, connID) + } conn = s.newConn( newSendConn(s.conn, p.remoteAddr, p.info), s.connHandler, @@ -715,8 +723,8 @@ func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header, info *pack // append the Retry integrity tag tag := handshake.GetRetryIntegrityTag(buf.Data, hdr.DestConnectionID, hdr.Version) buf.Data = append(buf.Data, tag[:]...) - if s.config.Tracer != nil { - s.config.Tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(buf.Data)), nil) + if s.tracer != nil { + s.tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(buf.Data)), nil) } _, err = s.conn.WritePacket(buf.Data, remoteAddr, info.OOB()) return err @@ -729,8 +737,8 @@ func (s *baseServer) maybeSendInvalidToken(p *receivedPacket, hdr *wire.Header) data := p.data[:hdr.ParsedLen()+hdr.Length] extHdr, err := unpackLongHeader(opener, hdr, data, hdr.Version) if err != nil { - if s.config.Tracer != nil { - s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropHeaderParseError) + if s.tracer != nil { + s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropHeaderParseError) } // don't return the error here. Just drop the packet. return nil @@ -738,8 +746,8 @@ func (s *baseServer) maybeSendInvalidToken(p *receivedPacket, hdr *wire.Header) hdrLen := extHdr.ParsedLen() if _, err := opener.Open(data[hdrLen:hdrLen], data[hdrLen:], extHdr.PacketNumber, data[:hdrLen]); err != nil { // don't return the error here. Just drop the packet. - if s.config.Tracer != nil { - s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropPayloadDecryptError) + if s.tracer != nil { + s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropPayloadDecryptError) } return nil } @@ -792,8 +800,8 @@ func (s *baseServer) sendError(remoteAddr net.Addr, hdr *wire.Header, sealer han replyHdr.Log(s.logger) wire.LogFrame(s.logger, ccf, true) - if s.config.Tracer != nil { - s.config.Tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(b.Data)), []logging.Frame{ccf}) + if s.tracer != nil { + s.tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(b.Data)), []logging.Frame{ccf}) } _, err = s.conn.WritePacket(b.Data, remoteAddr, info.OOB()) return err @@ -803,8 +811,8 @@ func (s *baseServer) sendVersionNegotiationPacket(remote net.Addr, src, dest pro s.logger.Debugf("Client offered version %s, sending Version Negotiation", v) data := wire.ComposeVersionNegotiation(dest, src, s.config.Versions) - if s.config.Tracer != nil { - s.config.Tracer.SentVersionNegotiationPacket(remote, src, dest, s.config.Versions) + if s.tracer != nil { + s.tracer.SentVersionNegotiationPacket(remote, src, dest, s.config.Versions) } if _, err := s.conn.WritePacket(data, remote, oob); err != nil { s.logger.Debugf("Error sending Version Negotiation: %s", err) diff --git a/server_test.go b/server_test.go index e78356c81..bf0adfbb3 100644 --- a/server_test.go +++ b/server_test.go @@ -170,7 +170,7 @@ var _ = Describe("Server", func() { Context("server accepting connections that completed the handshake", func() { var ( - ln *Listener + tr *Transport serv *baseServer phm *MockPacketHandlerManager tracer *mocklogging.MockTracer @@ -178,8 +178,8 @@ var _ = Describe("Server", func() { BeforeEach(func() { tracer = mocklogging.NewMockTracer(mockCtrl) - var err error - ln, err = Listen(conn, tlsConf, &Config{Tracer: tracer}) + tr = &Transport{Conn: conn, Tracer: tracer} + ln, err := tr.Listen(tlsConf, nil) Expect(err).ToNot(HaveOccurred()) serv = ln.baseServer phm = NewMockPacketHandlerManager(mockCtrl) @@ -187,7 +187,7 @@ var _ = Describe("Server", func() { }) AfterEach(func() { - ln.Close() + tr.Close() }) Context("handling packets", func() { @@ -276,7 +276,6 @@ var _ = Describe("Server", func() { _, ok := fn() return ok }) - tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde})) conn := NewMockQUICConn(mockCtrl) serv.newConn = func( _ sendConn, @@ -478,7 +477,6 @@ var _ = Describe("Server", func() { return ok }), ) - tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})) conn := NewMockQUICConn(mockCtrl) serv.newConn = func( @@ -537,7 +535,6 @@ var _ = Describe("Server", func() { _, ok := fn() return ok }).AnyTimes() - tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any()).AnyTimes() acceptConn := make(chan struct{}) var counter uint32 // to be used as an atomic, so we query it in Eventually @@ -665,7 +662,6 @@ var _ = Describe("Server", func() { _, ok := fn() return ok }).Times(protocol.MaxAcceptQueueSize) - tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any()).Times(protocol.MaxAcceptQueueSize) var wg sync.WaitGroup wg.Add(protocol.MaxAcceptQueueSize) @@ -737,7 +733,6 @@ var _ = Describe("Server", func() { _, ok := fn() return ok }) - tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any()) serv.handlePacket(p) // make sure there are no Write calls on the packet conn @@ -1062,6 +1057,7 @@ var _ = Describe("Server", func() { return ok }) done := make(chan struct{}) + tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { defer close(done) rejectHdr := parseHeader(b) @@ -1119,7 +1115,6 @@ var _ = Describe("Server", func() { _, ok := fn() return ok }) - tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any()) serv.handleInitialImpl( &receivedPacket{buffer: getPacketBuffer()}, &wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})}, @@ -1326,6 +1321,7 @@ var _ = Describe("Server", func() { Context("0-RTT", func() { var ( + tr *Transport serv *baseServer phm *MockPacketHandlerManager tracer *mocklogging.MockTracer @@ -1333,7 +1329,8 @@ var _ = Describe("Server", func() { BeforeEach(func() { tracer = mocklogging.NewMockTracer(mockCtrl) - ln, err := ListenEarly(conn, tlsConf, &Config{Tracer: tracer}) + tr = &Transport{Conn: conn, Tracer: tracer} + ln, err := tr.ListenEarly(tlsConf, nil) Expect(err).ToNot(HaveOccurred()) phm = NewMockPacketHandlerManager(mockCtrl) serv = ln.baseServer @@ -1342,7 +1339,7 @@ var _ = Describe("Server", func() { AfterEach(func() { phm.EXPECT().CloseServer().MaxTimes(1) - serv.Close() + tr.Close() }) It("passes packets to existing connections", func() { @@ -1425,7 +1422,6 @@ var _ = Describe("Server", func() { return conn } - tracer.EXPECT().TracerForConnection(gomock.Any(), gomock.Any(), gomock.Any()) phm.EXPECT().Get(connID) phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool { phm.EXPECT().GetStatelessResetToken(gomock.Any()) diff --git a/transport.go b/transport.go index baeb592b7..153675da6 100644 --- a/transport.go +++ b/transport.go @@ -95,10 +95,10 @@ func (t *Transport) Listen(tlsConf *tls.Config, conf *Config) (*Listener, error) return nil, errListenerAlreadySet } conf = populateServerConfig(conf) - if err := t.init(conf, true); err != nil { + if err := t.init(true); err != nil { return nil, err } - s, err := newServer(t.conn, t.handlerMap, t.connIDGenerator, tlsConf, conf, t.closeServer, false) + s, err := newServer(t.conn, t.handlerMap, t.connIDGenerator, tlsConf, conf, t.Tracer, t.closeServer, false) if err != nil { return nil, err } @@ -124,10 +124,10 @@ func (t *Transport) ListenEarly(tlsConf *tls.Config, conf *Config) (*EarlyListen return nil, errListenerAlreadySet } conf = populateServerConfig(conf) - if err := t.init(conf, true); err != nil { + if err := t.init(true); err != nil { return nil, err } - s, err := newServer(t.conn, t.handlerMap, t.connIDGenerator, tlsConf, conf, t.closeServer, true) + s, err := newServer(t.conn, t.handlerMap, t.connIDGenerator, tlsConf, conf, t.Tracer, t.closeServer, true) if err != nil { return nil, err } @@ -141,7 +141,7 @@ func (t *Transport) Dial(ctx context.Context, addr net.Addr, tlsConf *tls.Config return nil, err } conf = populateConfig(conf) - if err := t.init(conf, false); err != nil { + if err := t.init(false); err != nil { return nil, err } var onClose func() @@ -157,7 +157,7 @@ func (t *Transport) DialEarly(ctx context.Context, addr net.Addr, tlsConf *tls.C return nil, err } conf = populateConfig(conf) - if err := t.init(conf, false); err != nil { + if err := t.init(false); err != nil { return nil, err } var onClose func() @@ -200,7 +200,7 @@ func setReceiveBuffer(c net.PacketConn, logger utils.Logger) error { // only print warnings about the UDP receive buffer size once var receiveBufferWarningOnce sync.Once -func (t *Transport) init(conf *Config, isServer bool) error { +func (t *Transport) init(isServer bool) error { t.initOnce.Do(func() { getMultiplexer().AddConn(t.Conn) @@ -210,7 +210,6 @@ func (t *Transport) init(conf *Config, isServer bool) error { return } - t.Tracer = conf.Tracer t.logger = utils.DefaultLogger // TODO: make this configurable t.conn = conn t.handlerMap = newPacketHandlerMap(t.StatelessResetKey, t.enqueueClosePacket, t.logger) diff --git a/transport_test.go b/transport_test.go index 19ef54f20..9441e4dec 100644 --- a/transport_test.go +++ b/transport_test.go @@ -61,7 +61,7 @@ var _ = Describe("Transport", func() { It("handles packets for different packet handlers on the same packet conn", func() { packetChan := make(chan packetToRead) tr := &Transport{Conn: newMockPacketConn(packetChan)} - tr.init(&Config{}, true) + tr.init(true) phm := NewMockPacketHandlerManager(mockCtrl) tr.handlerMap = phm connID1 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) @@ -128,8 +128,9 @@ var _ = Describe("Transport", func() { tr := &Transport{ Conn: newMockPacketConn(packetChan), ConnectionIDLength: 10, + Tracer: tracer, } - tr.init(&Config{Tracer: tracer}, true) + tr.init(true) dropped := make(chan struct{}) tracer.EXPECT().DroppedPacket(addr, logging.PacketTypeNotDetermined, protocol.ByteCount(4), logging.PacketDropHeaderParseError).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { close(dropped) }) packetChan <- packetToRead{ @@ -148,7 +149,7 @@ var _ = Describe("Transport", func() { tr := Transport{Conn: newMockPacketConn(packetChan)} defer tr.Close() phm := NewMockPacketHandlerManager(mockCtrl) - tr.init(&Config{}, true) + tr.init(true) tr.handlerMap = phm done := make(chan struct{}) @@ -166,7 +167,7 @@ var _ = Describe("Transport", func() { tr := Transport{Conn: newMockPacketConn(packetChan)} defer tr.Close() phm := NewMockPacketHandlerManager(mockCtrl) - tr.init(&Config{}, true) + tr.init(true) tr.handlerMap = phm tempErr := deadlineError{} @@ -188,7 +189,7 @@ var _ = Describe("Transport", func() { Conn: newMockPacketConn(packetChan), ConnectionIDLength: connID.Len(), } - tr.init(&Config{}, true) + tr.init(true) defer tr.Close() phm := NewMockPacketHandlerManager(mockCtrl) tr.handlerMap = phm @@ -221,7 +222,7 @@ var _ = Describe("Transport", func() { connID := protocol.ParseConnectionID([]byte{2, 3, 4, 5}) packetChan := make(chan packetToRead) tr := Transport{Conn: newMockPacketConn(packetChan)} - tr.init(&Config{}, true) + tr.init(true) defer tr.Close() phm := NewMockPacketHandlerManager(mockCtrl) tr.handlerMap = phm @@ -257,7 +258,7 @@ var _ = Describe("Transport", func() { StatelessResetKey: &StatelessResetKey{1, 2, 3, 4}, ConnectionIDLength: connID.Len(), } - tr.init(&Config{}, true) + tr.init(true) defer tr.Close() phm := NewMockPacketHandlerManager(mockCtrl) tr.handlerMap = phm