diff --git a/codecov.yml b/codecov.yml index 59e4b58f..77e47fbe 100644 --- a/codecov.yml +++ b/codecov.yml @@ -6,6 +6,8 @@ coverage: - internal/handshake/cipher_suite.go - internal/utils/linkedlist/linkedlist.go - internal/testdata + - logging/connection_tracer_multiplexer.go + - logging/tracer_multiplexer.go - testutils/ - fuzzing/ - metrics/ diff --git a/go.mod b/go.mod index 5e1ee107..416f1f7b 100644 --- a/go.mod +++ b/go.mod @@ -16,6 +16,7 @@ require ( golang.org/x/sync v0.8.0 golang.org/x/sys v0.23.0 golang.org/x/time v0.5.0 + golang.org/x/tools v0.22.0 ) require ( @@ -30,9 +31,8 @@ require ( github.com/prometheus/client_model v0.5.0 // indirect github.com/prometheus/common v0.48.0 // indirect github.com/prometheus/procfs v0.12.0 // indirect - golang.org/x/mod v0.17.0 // indirect + golang.org/x/mod v0.18.0 // indirect golang.org/x/text v0.17.0 // indirect - golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect google.golang.org/protobuf v1.33.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 25fe9d11..022aa0f4 100644 --- a/go.sum +++ b/go.sum @@ -154,8 +154,8 @@ golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQ golang.org/x/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= -golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA= -golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/mod v0.18.0 h1:5+9lSbEzPSdWkH32vYPBwEpX8KwDbM52Ud9xBUvNlb0= +golang.org/x/mod v0.18.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -197,8 +197,8 @@ golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGm golang.org/x/tools v0.0.0-20181030000716-a0a13e073c7b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= -golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d h1:vU5i/LfpvrRCpgM/VPfJLg5KjxD3E+hfT1SH+d9zLwg= -golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= +golang.org/x/tools v0.22.0 h1:gqSGLZqv+AI9lIQzniJ0nZDRG5GBPsSi+DRNHWNz6yA= +golang.org/x/tools v0.22.0/go.mod h1:aCwcsjqvq7Yqt6TNyX7QMU2enbQ/Gt0bo6krSeEri+c= google.golang.org/api v0.0.0-20180910000450-7ca32eb868bf/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0= google.golang.org/api v0.0.0-20181030000543-1d582fd0359e/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0= google.golang.org/api v0.1.0/go.mod h1:UGEZY7KEX120AnNLIHFMKIo4obdJhkp2tPbaPlQx13Y= diff --git a/integrationtests/gomodvendor/go.mod b/integrationtests/gomodvendor/go.mod index 3411b837..f3ddd4e2 100644 --- a/integrationtests/gomodvendor/go.mod +++ b/integrationtests/gomodvendor/go.mod @@ -15,11 +15,11 @@ require ( go.uber.org/mock v0.4.0 // indirect golang.org/x/crypto v0.26.0 // indirect golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 // indirect - golang.org/x/mod v0.17.0 // indirect + golang.org/x/mod v0.18.0 // indirect golang.org/x/net v0.28.0 // indirect golang.org/x/sys v0.23.0 // indirect golang.org/x/text v0.17.0 // indirect - golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect + golang.org/x/tools v0.22.0 // indirect ) replace github.com/quic-go/quic-go => ../../ diff --git a/integrationtests/gomodvendor/go.sum b/integrationtests/gomodvendor/go.sum index 8d125919..bf1014b0 100644 --- a/integrationtests/gomodvendor/go.sum +++ b/integrationtests/gomodvendor/go.sum @@ -33,8 +33,8 @@ golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw= golang.org/x/crypto v0.26.0/go.mod h1:GY7jblb9wI+FOo5y8/S2oY4zWP07AkOJ4+jxCqdqn54= golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM= golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc= -golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA= -golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/mod v0.18.0 h1:5+9lSbEzPSdWkH32vYPBwEpX8KwDbM52Ud9xBUvNlb0= +golang.org/x/mod v0.18.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE= golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg= golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= @@ -46,8 +46,8 @@ golang.org/x/text v0.17.0 h1:XtiM5bkSOt+ewxlOE/aE/AKEHibwj/6gvWMl9Rsh0Qc= golang.org/x/text v0.17.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= -golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d h1:vU5i/LfpvrRCpgM/VPfJLg5KjxD3E+hfT1SH+d9zLwg= -golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= +golang.org/x/tools v0.22.0 h1:gqSGLZqv+AI9lIQzniJ0nZDRG5GBPsSi+DRNHWNz6yA= +golang.org/x/tools v0.22.0/go.mod h1:aCwcsjqvq7Yqt6TNyX7QMU2enbQ/Gt0bo6krSeEri+c= google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/logging/connection_tracer.go b/logging/connection_tracer.go index 96bf4617..f218e046 100644 --- a/logging/connection_tracer.go +++ b/logging/connection_tracer.go @@ -5,34 +5,36 @@ import ( "time" ) +//go:generate go run generate_multiplexer.go ConnectionTracer connection_tracer.go multiplexer.tmpl connection_tracer_multiplexer.go + // A ConnectionTracer records events. type ConnectionTracer struct { StartedConnection func(local, remote net.Addr, srcConnID, destConnID ConnectionID) NegotiatedVersion func(chosen Version, clientVersions, serverVersions []Version) - ClosedConnection func(error) - SentTransportParameters func(*TransportParameters) - ReceivedTransportParameters func(*TransportParameters) + ClosedConnection func(err error) + SentTransportParameters func(parameters *TransportParameters) + ReceivedTransportParameters func(parameters *TransportParameters) RestoredTransportParameters func(parameters *TransportParameters) // for 0-RTT - SentLongHeaderPacket func(*ExtendedHeader, ByteCount, ECN, *AckFrame, []Frame) - SentShortHeaderPacket func(*ShortHeader, ByteCount, ECN, *AckFrame, []Frame) - ReceivedVersionNegotiationPacket func(dest, src ArbitraryLenConnectionID, _ []Version) - ReceivedRetry func(*Header) - ReceivedLongHeaderPacket func(*ExtendedHeader, ByteCount, ECN, []Frame) - ReceivedShortHeaderPacket func(*ShortHeader, ByteCount, ECN, []Frame) - BufferedPacket func(PacketType, ByteCount) - DroppedPacket func(PacketType, PacketNumber, ByteCount, PacketDropReason) + SentLongHeaderPacket func(hdr *ExtendedHeader, size ByteCount, ecn ECN, ack *AckFrame, frames []Frame) + SentShortHeaderPacket func(hdr *ShortHeader, size ByteCount, ecn ECN, ack *AckFrame, frames []Frame) + ReceivedVersionNegotiationPacket func(dest, src ArbitraryLenConnectionID, versions []Version) + ReceivedRetry func(hdr *Header) + ReceivedLongHeaderPacket func(hdr *ExtendedHeader, size ByteCount, ecn ECN, frames []Frame) + ReceivedShortHeaderPacket func(hdr *ShortHeader, size ByteCount, ecn ECN, frames []Frame) + BufferedPacket func(packetType PacketType, size ByteCount) + DroppedPacket func(packetType PacketType, pn PacketNumber, size ByteCount, reason PacketDropReason) UpdatedMetrics func(rttStats *RTTStats, cwnd, bytesInFlight ByteCount, packetsInFlight int) - AcknowledgedPacket func(EncryptionLevel, PacketNumber) - LostPacket func(EncryptionLevel, PacketNumber, PacketLossReason) + AcknowledgedPacket func(encLevel EncryptionLevel, pn PacketNumber) + LostPacket func(encLevel EncryptionLevel, pn PacketNumber, reason PacketLossReason) UpdatedMTU func(mtu ByteCount, done bool) - UpdatedCongestionState func(CongestionState) + UpdatedCongestionState func(state CongestionState) UpdatedPTOCount func(value uint32) - UpdatedKeyFromTLS func(EncryptionLevel, Perspective) + UpdatedKeyFromTLS func(encLevel EncryptionLevel, p Perspective) UpdatedKey func(keyPhase KeyPhase, remote bool) - DroppedEncryptionLevel func(EncryptionLevel) + DroppedEncryptionLevel func(encLevel EncryptionLevel) DroppedKey func(keyPhase KeyPhase) - SetLossTimer func(TimerType, EncryptionLevel, time.Time) - LossTimerExpired func(TimerType, EncryptionLevel) + SetLossTimer func(timerType TimerType, encLevel EncryptionLevel, time time.Time) + LossTimerExpired func(timerType TimerType, encLevel EncryptionLevel) LossTimerCanceled func() ECNStateUpdated func(state ECNState, trigger ECNStateTrigger) ChoseALPN func(protocol string) @@ -40,232 +42,3 @@ type ConnectionTracer struct { Close func() Debug func(name, msg string) } - -// NewMultiplexedConnectionTracer creates a new connection tracer that multiplexes events to multiple tracers. -func NewMultiplexedConnectionTracer(tracers ...*ConnectionTracer) *ConnectionTracer { - if len(tracers) == 0 { - return nil - } - if len(tracers) == 1 { - return tracers[0] - } - return &ConnectionTracer{ - StartedConnection: func(local, remote net.Addr, srcConnID, destConnID ConnectionID) { - for _, t := range tracers { - if t.StartedConnection != nil { - t.StartedConnection(local, remote, srcConnID, destConnID) - } - } - }, - NegotiatedVersion: func(chosen Version, clientVersions, serverVersions []Version) { - for _, t := range tracers { - if t.NegotiatedVersion != nil { - t.NegotiatedVersion(chosen, clientVersions, serverVersions) - } - } - }, - ClosedConnection: func(e error) { - for _, t := range tracers { - if t.ClosedConnection != nil { - t.ClosedConnection(e) - } - } - }, - SentTransportParameters: func(tp *TransportParameters) { - for _, t := range tracers { - if t.SentTransportParameters != nil { - t.SentTransportParameters(tp) - } - } - }, - ReceivedTransportParameters: func(tp *TransportParameters) { - for _, t := range tracers { - if t.ReceivedTransportParameters != nil { - t.ReceivedTransportParameters(tp) - } - } - }, - RestoredTransportParameters: func(tp *TransportParameters) { - for _, t := range tracers { - if t.RestoredTransportParameters != nil { - t.RestoredTransportParameters(tp) - } - } - }, - SentLongHeaderPacket: func(hdr *ExtendedHeader, size ByteCount, ecn ECN, ack *AckFrame, frames []Frame) { - for _, t := range tracers { - if t.SentLongHeaderPacket != nil { - t.SentLongHeaderPacket(hdr, size, ecn, ack, frames) - } - } - }, - SentShortHeaderPacket: func(hdr *ShortHeader, size ByteCount, ecn ECN, ack *AckFrame, frames []Frame) { - for _, t := range tracers { - if t.SentShortHeaderPacket != nil { - t.SentShortHeaderPacket(hdr, size, ecn, ack, frames) - } - } - }, - ReceivedVersionNegotiationPacket: func(dest, src ArbitraryLenConnectionID, versions []Version) { - for _, t := range tracers { - if t.ReceivedVersionNegotiationPacket != nil { - t.ReceivedVersionNegotiationPacket(dest, src, versions) - } - } - }, - ReceivedRetry: func(hdr *Header) { - for _, t := range tracers { - if t.ReceivedRetry != nil { - t.ReceivedRetry(hdr) - } - } - }, - ReceivedLongHeaderPacket: func(hdr *ExtendedHeader, size ByteCount, ecn ECN, frames []Frame) { - for _, t := range tracers { - if t.ReceivedLongHeaderPacket != nil { - t.ReceivedLongHeaderPacket(hdr, size, ecn, frames) - } - } - }, - ReceivedShortHeaderPacket: func(hdr *ShortHeader, size ByteCount, ecn ECN, frames []Frame) { - for _, t := range tracers { - if t.ReceivedShortHeaderPacket != nil { - t.ReceivedShortHeaderPacket(hdr, size, ecn, frames) - } - } - }, - BufferedPacket: func(typ PacketType, size ByteCount) { - for _, t := range tracers { - if t.BufferedPacket != nil { - t.BufferedPacket(typ, size) - } - } - }, - DroppedPacket: func(typ PacketType, pn PacketNumber, size ByteCount, reason PacketDropReason) { - for _, t := range tracers { - if t.DroppedPacket != nil { - t.DroppedPacket(typ, pn, size, reason) - } - } - }, - UpdatedMetrics: func(rttStats *RTTStats, cwnd, bytesInFlight ByteCount, packetsInFlight int) { - for _, t := range tracers { - if t.UpdatedMetrics != nil { - t.UpdatedMetrics(rttStats, cwnd, bytesInFlight, packetsInFlight) - } - } - }, - AcknowledgedPacket: func(encLevel EncryptionLevel, pn PacketNumber) { - for _, t := range tracers { - if t.AcknowledgedPacket != nil { - t.AcknowledgedPacket(encLevel, pn) - } - } - }, - LostPacket: func(encLevel EncryptionLevel, pn PacketNumber, reason PacketLossReason) { - for _, t := range tracers { - if t.LostPacket != nil { - t.LostPacket(encLevel, pn, reason) - } - } - }, - UpdatedMTU: func(mtu ByteCount, done bool) { - for _, t := range tracers { - if t.UpdatedMTU != nil { - t.UpdatedMTU(mtu, done) - } - } - }, - UpdatedCongestionState: func(state CongestionState) { - for _, t := range tracers { - if t.UpdatedCongestionState != nil { - t.UpdatedCongestionState(state) - } - } - }, - UpdatedPTOCount: func(value uint32) { - for _, t := range tracers { - if t.UpdatedPTOCount != nil { - t.UpdatedPTOCount(value) - } - } - }, - UpdatedKeyFromTLS: func(encLevel EncryptionLevel, perspective Perspective) { - for _, t := range tracers { - if t.UpdatedKeyFromTLS != nil { - t.UpdatedKeyFromTLS(encLevel, perspective) - } - } - }, - UpdatedKey: func(generation KeyPhase, remote bool) { - for _, t := range tracers { - if t.UpdatedKey != nil { - t.UpdatedKey(generation, remote) - } - } - }, - DroppedEncryptionLevel: func(encLevel EncryptionLevel) { - for _, t := range tracers { - if t.DroppedEncryptionLevel != nil { - t.DroppedEncryptionLevel(encLevel) - } - } - }, - DroppedKey: func(generation KeyPhase) { - for _, t := range tracers { - if t.DroppedKey != nil { - t.DroppedKey(generation) - } - } - }, - SetLossTimer: func(typ TimerType, encLevel EncryptionLevel, exp time.Time) { - for _, t := range tracers { - if t.SetLossTimer != nil { - t.SetLossTimer(typ, encLevel, exp) - } - } - }, - LossTimerExpired: func(typ TimerType, encLevel EncryptionLevel) { - for _, t := range tracers { - if t.LossTimerExpired != nil { - t.LossTimerExpired(typ, encLevel) - } - } - }, - LossTimerCanceled: func() { - for _, t := range tracers { - if t.LossTimerCanceled != nil { - t.LossTimerCanceled() - } - } - }, - ECNStateUpdated: func(state ECNState, trigger ECNStateTrigger) { - for _, t := range tracers { - if t.ECNStateUpdated != nil { - t.ECNStateUpdated(state, trigger) - } - } - }, - ChoseALPN: func(protocol string) { - for _, t := range tracers { - if t.ChoseALPN != nil { - t.ChoseALPN(protocol) - } - } - }, - Close: func() { - for _, t := range tracers { - if t.Close != nil { - t.Close() - } - } - }, - Debug: func(name, msg string) { - for _, t := range tracers { - if t.Debug != nil { - t.Debug(name, msg) - } - } - }, - } -} diff --git a/logging/connection_tracer_multiplexer.go b/logging/connection_tracer_multiplexer.go new file mode 100644 index 00000000..3a87058c --- /dev/null +++ b/logging/connection_tracer_multiplexer.go @@ -0,0 +1,236 @@ +// Code generated by generate_multiplexer.go; DO NOT EDIT. + +package logging + +import ( + "net" + "time" +) + +func NewMultiplexedConnectionTracer(tracers ...*ConnectionTracer) *ConnectionTracer { + if len(tracers) == 0 { + return nil + } + if len(tracers) == 1 { + return tracers[0] + } + return &ConnectionTracer{ + StartedConnection: func(local net.Addr, remote net.Addr, srcConnID ConnectionID, destConnID ConnectionID) { + for _, t := range tracers { + if t.StartedConnection != nil { + t.StartedConnection(local, remote, srcConnID, destConnID) + } + } + }, + NegotiatedVersion: func(chosen Version, clientVersions []Version, serverVersions []Version) { + for _, t := range tracers { + if t.NegotiatedVersion != nil { + t.NegotiatedVersion(chosen, clientVersions, serverVersions) + } + } + }, + ClosedConnection: func(err error) { + for _, t := range tracers { + if t.ClosedConnection != nil { + t.ClosedConnection(err) + } + } + }, + SentTransportParameters: func(parameters *TransportParameters) { + for _, t := range tracers { + if t.SentTransportParameters != nil { + t.SentTransportParameters(parameters) + } + } + }, + ReceivedTransportParameters: func(parameters *TransportParameters) { + for _, t := range tracers { + if t.ReceivedTransportParameters != nil { + t.ReceivedTransportParameters(parameters) + } + } + }, + RestoredTransportParameters: func(parameters *TransportParameters) { + for _, t := range tracers { + if t.RestoredTransportParameters != nil { + t.RestoredTransportParameters(parameters) + } + } + }, + SentLongHeaderPacket: func(hdr *ExtendedHeader, size ByteCount, ecn ECN, ack *AckFrame, frames []Frame) { + for _, t := range tracers { + if t.SentLongHeaderPacket != nil { + t.SentLongHeaderPacket(hdr, size, ecn, ack, frames) + } + } + }, + SentShortHeaderPacket: func(hdr *ShortHeader, size ByteCount, ecn ECN, ack *AckFrame, frames []Frame) { + for _, t := range tracers { + if t.SentShortHeaderPacket != nil { + t.SentShortHeaderPacket(hdr, size, ecn, ack, frames) + } + } + }, + ReceivedVersionNegotiationPacket: func(dest ArbitraryLenConnectionID, src ArbitraryLenConnectionID, versions []Version) { + for _, t := range tracers { + if t.ReceivedVersionNegotiationPacket != nil { + t.ReceivedVersionNegotiationPacket(dest, src, versions) + } + } + }, + ReceivedRetry: func(hdr *Header) { + for _, t := range tracers { + if t.ReceivedRetry != nil { + t.ReceivedRetry(hdr) + } + } + }, + ReceivedLongHeaderPacket: func(hdr *ExtendedHeader, size ByteCount, ecn ECN, frames []Frame) { + for _, t := range tracers { + if t.ReceivedLongHeaderPacket != nil { + t.ReceivedLongHeaderPacket(hdr, size, ecn, frames) + } + } + }, + ReceivedShortHeaderPacket: func(hdr *ShortHeader, size ByteCount, ecn ECN, frames []Frame) { + for _, t := range tracers { + if t.ReceivedShortHeaderPacket != nil { + t.ReceivedShortHeaderPacket(hdr, size, ecn, frames) + } + } + }, + BufferedPacket: func(packetType PacketType, size ByteCount) { + for _, t := range tracers { + if t.BufferedPacket != nil { + t.BufferedPacket(packetType, size) + } + } + }, + DroppedPacket: func(packetType PacketType, pn PacketNumber, size ByteCount, reason PacketDropReason) { + for _, t := range tracers { + if t.DroppedPacket != nil { + t.DroppedPacket(packetType, pn, size, reason) + } + } + }, + UpdatedMetrics: func(rttStats *RTTStats, cwnd ByteCount, bytesInFlight ByteCount, packetsInFlight int) { + for _, t := range tracers { + if t.UpdatedMetrics != nil { + t.UpdatedMetrics(rttStats, cwnd, bytesInFlight, packetsInFlight) + } + } + }, + AcknowledgedPacket: func(encLevel EncryptionLevel, pn PacketNumber) { + for _, t := range tracers { + if t.AcknowledgedPacket != nil { + t.AcknowledgedPacket(encLevel, pn) + } + } + }, + LostPacket: func(encLevel EncryptionLevel, pn PacketNumber, reason PacketLossReason) { + for _, t := range tracers { + if t.LostPacket != nil { + t.LostPacket(encLevel, pn, reason) + } + } + }, + UpdatedMTU: func(mtu ByteCount, done bool) { + for _, t := range tracers { + if t.UpdatedMTU != nil { + t.UpdatedMTU(mtu, done) + } + } + }, + UpdatedCongestionState: func(state CongestionState) { + for _, t := range tracers { + if t.UpdatedCongestionState != nil { + t.UpdatedCongestionState(state) + } + } + }, + UpdatedPTOCount: func(value uint32) { + for _, t := range tracers { + if t.UpdatedPTOCount != nil { + t.UpdatedPTOCount(value) + } + } + }, + UpdatedKeyFromTLS: func(encLevel EncryptionLevel, p Perspective) { + for _, t := range tracers { + if t.UpdatedKeyFromTLS != nil { + t.UpdatedKeyFromTLS(encLevel, p) + } + } + }, + UpdatedKey: func(keyPhase KeyPhase, remote bool) { + for _, t := range tracers { + if t.UpdatedKey != nil { + t.UpdatedKey(keyPhase, remote) + } + } + }, + DroppedEncryptionLevel: func(encLevel EncryptionLevel) { + for _, t := range tracers { + if t.DroppedEncryptionLevel != nil { + t.DroppedEncryptionLevel(encLevel) + } + } + }, + DroppedKey: func(keyPhase KeyPhase) { + for _, t := range tracers { + if t.DroppedKey != nil { + t.DroppedKey(keyPhase) + } + } + }, + SetLossTimer: func(timerType TimerType, encLevel EncryptionLevel, time time.Time) { + for _, t := range tracers { + if t.SetLossTimer != nil { + t.SetLossTimer(timerType, encLevel, time) + } + } + }, + LossTimerExpired: func(timerType TimerType, encLevel EncryptionLevel) { + for _, t := range tracers { + if t.LossTimerExpired != nil { + t.LossTimerExpired(timerType, encLevel) + } + } + }, + LossTimerCanceled: func() { + for _, t := range tracers { + if t.LossTimerCanceled != nil { + t.LossTimerCanceled() + } + } + }, + ECNStateUpdated: func(state ECNState, trigger ECNStateTrigger) { + for _, t := range tracers { + if t.ECNStateUpdated != nil { + t.ECNStateUpdated(state, trigger) + } + } + }, + ChoseALPN: func(protocol string) { + for _, t := range tracers { + if t.ChoseALPN != nil { + t.ChoseALPN(protocol) + } + } + }, + Close: func() { + for _, t := range tracers { + if t.Close != nil { + t.Close() + } + } + }, + Debug: func(name string, msg string) { + for _, t := range tracers { + if t.Debug != nil { + t.Debug(name, msg) + } + } + }, + } +} diff --git a/logging/connection_tracer_test.go b/logging/connection_tracer_test.go index fa79f6a5..95229869 100644 --- a/logging/connection_tracer_test.go +++ b/logging/connection_tracer_test.go @@ -2,349 +2,20 @@ package logging_test import ( "errors" - "net" "testing" - "time" - mocklogging "github.com/quic-go/quic-go/internal/mocks/logging" - "github.com/quic-go/quic-go/internal/protocol" - "github.com/quic-go/quic-go/internal/wire" "github.com/quic-go/quic-go/logging" - - "go.uber.org/mock/gomock" + "github.com/stretchr/testify/require" ) -func TestConnectionTracerStartedConnection(t *testing.T) { - ctrl := gomock.NewController(t) - t1, tr1 := mocklogging.NewMockConnectionTracer(ctrl) - t2, tr2 := mocklogging.NewMockConnectionTracer(ctrl) - tracer := logging.NewMultiplexedConnectionTracer(t1, t2) - - local := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4)} - remote := &net.UDPAddr{IP: net.IPv4(4, 3, 2, 1)} - dest := protocol.ParseConnectionID([]byte{1, 2, 3, 4}) - src := protocol.ParseConnectionID([]byte{4, 3, 2, 1}) - tr1.EXPECT().StartedConnection(local, remote, src, dest) - tr2.EXPECT().StartedConnection(local, remote, src, dest) - tracer.StartedConnection(local, remote, src, dest) -} - -func TestConnectionTracerNegotiatedVersion(t *testing.T) { - ctrl := gomock.NewController(t) - t1, tr1 := mocklogging.NewMockConnectionTracer(ctrl) - t2, tr2 := mocklogging.NewMockConnectionTracer(ctrl) - tracer := logging.NewMultiplexedConnectionTracer(t1, t2) - - chosen := protocol.Version2 - client := []protocol.Version{protocol.Version1} - server := []protocol.Version{13, 37} - tr1.EXPECT().NegotiatedVersion(chosen, client, server) - tr2.EXPECT().NegotiatedVersion(chosen, client, server) - tracer.NegotiatedVersion(chosen, client, server) -} - -func TestConnectionTracerClosedConnection(t *testing.T) { - ctrl := gomock.NewController(t) - t1, tr1 := mocklogging.NewMockConnectionTracer(ctrl) - t2, tr2 := mocklogging.NewMockConnectionTracer(ctrl) +func TestConnectionTracerMultiplexing(t *testing.T) { + var err1, err2 error + t1 := &logging.ConnectionTracer{ClosedConnection: func(e error) { err1 = e }} + t2 := &logging.ConnectionTracer{ClosedConnection: func(e error) { err2 = e }} tracer := logging.NewMultiplexedConnectionTracer(t1, t2) e := errors.New("test err") - tr1.EXPECT().ClosedConnection(e) - tr2.EXPECT().ClosedConnection(e) tracer.ClosedConnection(e) -} - -func TestConnectionTracerSentTransportParameters(t *testing.T) { - ctrl := gomock.NewController(t) - t1, tr1 := mocklogging.NewMockConnectionTracer(ctrl) - t2, tr2 := mocklogging.NewMockConnectionTracer(ctrl) - tracer := logging.NewMultiplexedConnectionTracer(t1, t2) - - tp := &wire.TransportParameters{InitialMaxData: 1337} - tr1.EXPECT().SentTransportParameters(tp) - tr2.EXPECT().SentTransportParameters(tp) - tracer.SentTransportParameters(tp) -} - -func TestConnectionTracerReceivedTransportParameters(t *testing.T) { - ctrl := gomock.NewController(t) - t1, tr1 := mocklogging.NewMockConnectionTracer(ctrl) - t2, tr2 := mocklogging.NewMockConnectionTracer(ctrl) - tracer := logging.NewMultiplexedConnectionTracer(t1, t2) - - tp := &wire.TransportParameters{InitialMaxData: 1337} - tr1.EXPECT().ReceivedTransportParameters(tp) - tr2.EXPECT().ReceivedTransportParameters(tp) - tracer.ReceivedTransportParameters(tp) -} - -func TestConnectionTracerRestoredTransportParameters(t *testing.T) { - ctrl := gomock.NewController(t) - t1, tr1 := mocklogging.NewMockConnectionTracer(ctrl) - t2, tr2 := mocklogging.NewMockConnectionTracer(ctrl) - tracer := logging.NewMultiplexedConnectionTracer(t1, t2) - - tp := &wire.TransportParameters{InitialMaxData: 1337} - tr1.EXPECT().RestoredTransportParameters(tp) - tr2.EXPECT().RestoredTransportParameters(tp) - tracer.RestoredTransportParameters(tp) -} - -func TestConnectionTracerSentLongHeaderPacket(t *testing.T) { - ctrl := gomock.NewController(t) - t1, tr1 := mocklogging.NewMockConnectionTracer(ctrl) - t2, tr2 := mocklogging.NewMockConnectionTracer(ctrl) - tracer := logging.NewMultiplexedConnectionTracer(t1, t2) - - hdr := &logging.ExtendedHeader{Header: logging.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3})}} - ack := &logging.AckFrame{AckRanges: []logging.AckRange{{Smallest: 1, Largest: 10}}} - ping := &logging.PingFrame{} - tr1.EXPECT().SentLongHeaderPacket(hdr, logging.ByteCount(1337), logging.ECTNot, ack, []logging.Frame{ping}) - tr2.EXPECT().SentLongHeaderPacket(hdr, logging.ByteCount(1337), logging.ECTNot, ack, []logging.Frame{ping}) - tracer.SentLongHeaderPacket(hdr, 1337, logging.ECTNot, ack, []logging.Frame{ping}) -} - -func TestConnectionTracerSentShortHeaderPacket(t *testing.T) { - ctrl := gomock.NewController(t) - t1, tr1 := mocklogging.NewMockConnectionTracer(ctrl) - t2, tr2 := mocklogging.NewMockConnectionTracer(ctrl) - tracer := logging.NewMultiplexedConnectionTracer(t1, t2) - - hdr := &logging.ShortHeader{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3})} - ack := &logging.AckFrame{AckRanges: []logging.AckRange{{Smallest: 1, Largest: 10}}} - ping := &logging.PingFrame{} - tr1.EXPECT().SentShortHeaderPacket(hdr, logging.ByteCount(1337), logging.ECNCE, ack, []logging.Frame{ping}) - tr2.EXPECT().SentShortHeaderPacket(hdr, logging.ByteCount(1337), logging.ECNCE, ack, []logging.Frame{ping}) - tracer.SentShortHeaderPacket(hdr, 1337, logging.ECNCE, ack, []logging.Frame{ping}) -} - -func TestConnectionTracerReceivedVersionNegotiationPacket(t *testing.T) { - ctrl := gomock.NewController(t) - t1, tr1 := mocklogging.NewMockConnectionTracer(ctrl) - t2, tr2 := mocklogging.NewMockConnectionTracer(ctrl) - tracer := logging.NewMultiplexedConnectionTracer(t1, t2) - - src := logging.ArbitraryLenConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13} - dest := logging.ArbitraryLenConnectionID{1, 2, 3, 4} - tr1.EXPECT().ReceivedVersionNegotiationPacket(dest, src, []logging.Version{1337}) - tr2.EXPECT().ReceivedVersionNegotiationPacket(dest, src, []logging.Version{1337}) - tracer.ReceivedVersionNegotiationPacket(dest, src, []logging.Version{1337}) -} - -func TestConnectionTracerReceivedRetry(t *testing.T) { - ctrl := gomock.NewController(t) - t1, tr1 := mocklogging.NewMockConnectionTracer(ctrl) - t2, tr2 := mocklogging.NewMockConnectionTracer(ctrl) - tracer := logging.NewMultiplexedConnectionTracer(t1, t2) - - hdr := &logging.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3})} - tr1.EXPECT().ReceivedRetry(hdr) - tr2.EXPECT().ReceivedRetry(hdr) - tracer.ReceivedRetry(hdr) -} - -func TestConnectionTracerReceivedLongHeaderPacket(t *testing.T) { - ctrl := gomock.NewController(t) - t1, tr1 := mocklogging.NewMockConnectionTracer(ctrl) - t2, tr2 := mocklogging.NewMockConnectionTracer(ctrl) - tracer := logging.NewMultiplexedConnectionTracer(t1, t2) - - hdr := &logging.ExtendedHeader{Header: logging.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3})}} - ping := &logging.PingFrame{} - tr1.EXPECT().ReceivedLongHeaderPacket(hdr, logging.ByteCount(1337), logging.ECT1, []logging.Frame{ping}) - tr2.EXPECT().ReceivedLongHeaderPacket(hdr, logging.ByteCount(1337), logging.ECT1, []logging.Frame{ping}) - tracer.ReceivedLongHeaderPacket(hdr, 1337, logging.ECT1, []logging.Frame{ping}) -} - -func TestConnectionTracerReceivedShortHeaderPacket(t *testing.T) { - ctrl := gomock.NewController(t) - t1, tr1 := mocklogging.NewMockConnectionTracer(ctrl) - t2, tr2 := mocklogging.NewMockConnectionTracer(ctrl) - tracer := logging.NewMultiplexedConnectionTracer(t1, t2) - - hdr := &logging.ShortHeader{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3})} - ping := &logging.PingFrame{} - tr1.EXPECT().ReceivedShortHeaderPacket(hdr, logging.ByteCount(1337), logging.ECT0, []logging.Frame{ping}) - tr2.EXPECT().ReceivedShortHeaderPacket(hdr, logging.ByteCount(1337), logging.ECT0, []logging.Frame{ping}) - tracer.ReceivedShortHeaderPacket(hdr, 1337, logging.ECT0, []logging.Frame{ping}) -} - -func TestConnectionTracerBufferedPacket(t *testing.T) { - ctrl := gomock.NewController(t) - t1, tr1 := mocklogging.NewMockConnectionTracer(ctrl) - t2, tr2 := mocklogging.NewMockConnectionTracer(ctrl) - tracer := logging.NewMultiplexedConnectionTracer(t1, t2) - - tr1.EXPECT().BufferedPacket(logging.PacketTypeHandshake, logging.ByteCount(1337)) - tr2.EXPECT().BufferedPacket(logging.PacketTypeHandshake, logging.ByteCount(1337)) - tracer.BufferedPacket(logging.PacketTypeHandshake, 1337) -} - -func TestConnectionTracerDroppedPacket(t *testing.T) { - ctrl := gomock.NewController(t) - t1, tr1 := mocklogging.NewMockConnectionTracer(ctrl) - t2, tr2 := mocklogging.NewMockConnectionTracer(ctrl) - tracer := logging.NewMultiplexedConnectionTracer(t1, t2) - - tr1.EXPECT().DroppedPacket(logging.PacketTypeInitial, logging.PacketNumber(42), logging.ByteCount(1337), logging.PacketDropHeaderParseError) - tr2.EXPECT().DroppedPacket(logging.PacketTypeInitial, logging.PacketNumber(42), logging.ByteCount(1337), logging.PacketDropHeaderParseError) - tracer.DroppedPacket(logging.PacketTypeInitial, 42, 1337, logging.PacketDropHeaderParseError) -} - -func TestConnectionTracerUpdatedMTU(t *testing.T) { - ctrl := gomock.NewController(t) - t1, tr1 := mocklogging.NewMockConnectionTracer(ctrl) - t2, tr2 := mocklogging.NewMockConnectionTracer(ctrl) - tracer := logging.NewMultiplexedConnectionTracer(t1, t2) - - tr1.EXPECT().UpdatedMTU(logging.ByteCount(1337), true) - tr2.EXPECT().UpdatedMTU(logging.ByteCount(1337), true) - tracer.UpdatedMTU(1337, true) -} - -func TestConnectionTracerUpdatedCongestionState(t *testing.T) { - ctrl := gomock.NewController(t) - t1, tr1 := mocklogging.NewMockConnectionTracer(ctrl) - t2, tr2 := mocklogging.NewMockConnectionTracer(ctrl) - tracer := logging.NewMultiplexedConnectionTracer(t1, t2) - - tr1.EXPECT().UpdatedCongestionState(logging.CongestionStateRecovery) - tr2.EXPECT().UpdatedCongestionState(logging.CongestionStateRecovery) - tracer.UpdatedCongestionState(logging.CongestionStateRecovery) -} - -func TestConnectionTracerUpdatedMetrics(t *testing.T) { - ctrl := gomock.NewController(t) - t1, tr1 := mocklogging.NewMockConnectionTracer(ctrl) - t2, tr2 := mocklogging.NewMockConnectionTracer(ctrl) - tracer := logging.NewMultiplexedConnectionTracer(t1, t2) - - rttStats := &logging.RTTStats{} - rttStats.UpdateRTT(time.Second, 0, time.Now()) - tr1.EXPECT().UpdatedMetrics(rttStats, logging.ByteCount(1337), logging.ByteCount(42), 13) - tr2.EXPECT().UpdatedMetrics(rttStats, logging.ByteCount(1337), logging.ByteCount(42), 13) - tracer.UpdatedMetrics(rttStats, 1337, 42, 13) -} - -func TestConnectionTracerAcknowledgedPacket(t *testing.T) { - ctrl := gomock.NewController(t) - t1, tr1 := mocklogging.NewMockConnectionTracer(ctrl) - t2, tr2 := mocklogging.NewMockConnectionTracer(ctrl) - tracer := logging.NewMultiplexedConnectionTracer(t1, t2) - - tr1.EXPECT().AcknowledgedPacket(logging.EncryptionHandshake, logging.PacketNumber(42)) - tr2.EXPECT().AcknowledgedPacket(logging.EncryptionHandshake, logging.PacketNumber(42)) - tracer.AcknowledgedPacket(logging.EncryptionHandshake, 42) -} - -func TestConnectionTracerLostPacket(t *testing.T) { - ctrl := gomock.NewController(t) - t1, tr1 := mocklogging.NewMockConnectionTracer(ctrl) - t2, tr2 := mocklogging.NewMockConnectionTracer(ctrl) - tracer := logging.NewMultiplexedConnectionTracer(t1, t2) - - tr1.EXPECT().LostPacket(logging.EncryptionHandshake, logging.PacketNumber(42), logging.PacketLossReorderingThreshold) - tr2.EXPECT().LostPacket(logging.EncryptionHandshake, logging.PacketNumber(42), logging.PacketLossReorderingThreshold) - tracer.LostPacket(logging.EncryptionHandshake, 42, logging.PacketLossReorderingThreshold) -} - -func TestConnectionTracerUpdatedPTOCount(t *testing.T) { - ctrl := gomock.NewController(t) - t1, tr1 := mocklogging.NewMockConnectionTracer(ctrl) - t2, tr2 := mocklogging.NewMockConnectionTracer(ctrl) - tracer := logging.NewMultiplexedConnectionTracer(t1, t2) - - tr1.EXPECT().UpdatedPTOCount(uint32(88)) - tr2.EXPECT().UpdatedPTOCount(uint32(88)) - tracer.UpdatedPTOCount(88) -} - -func TestConnectionTracerUpdatedKeyFromTLS(t *testing.T) { - ctrl := gomock.NewController(t) - t1, tr1 := mocklogging.NewMockConnectionTracer(ctrl) - t2, tr2 := mocklogging.NewMockConnectionTracer(ctrl) - tracer := logging.NewMultiplexedConnectionTracer(t1, t2) - - tr1.EXPECT().UpdatedKeyFromTLS(logging.EncryptionHandshake, logging.PerspectiveClient) - tr2.EXPECT().UpdatedKeyFromTLS(logging.EncryptionHandshake, logging.PerspectiveClient) - tracer.UpdatedKeyFromTLS(logging.EncryptionHandshake, logging.PerspectiveClient) -} - -func TestConnectionTracerUpdatedKey(t *testing.T) { - ctrl := gomock.NewController(t) - t1, tr1 := mocklogging.NewMockConnectionTracer(ctrl) - t2, tr2 := mocklogging.NewMockConnectionTracer(ctrl) - tracer := logging.NewMultiplexedConnectionTracer(t1, t2) - - tr1.EXPECT().UpdatedKey(logging.KeyPhase(42), true) - tr2.EXPECT().UpdatedKey(logging.KeyPhase(42), true) - tracer.UpdatedKey(logging.KeyPhase(42), true) -} - -func TestConnectionTracerDroppedEncryptionLevel(t *testing.T) { - ctrl := gomock.NewController(t) - t1, tr1 := mocklogging.NewMockConnectionTracer(ctrl) - t2, tr2 := mocklogging.NewMockConnectionTracer(ctrl) - tracer := logging.NewMultiplexedConnectionTracer(t1, t2) - - tr1.EXPECT().DroppedEncryptionLevel(logging.EncryptionHandshake) - tr2.EXPECT().DroppedEncryptionLevel(logging.EncryptionHandshake) - tracer.DroppedEncryptionLevel(logging.EncryptionHandshake) -} - -func TestConnectionTracerDroppedKey(t *testing.T) { - ctrl := gomock.NewController(t) - t1, tr1 := mocklogging.NewMockConnectionTracer(ctrl) - t2, tr2 := mocklogging.NewMockConnectionTracer(ctrl) - tracer := logging.NewMultiplexedConnectionTracer(t1, t2) - - tr1.EXPECT().DroppedKey(logging.KeyPhase(123)) - tr2.EXPECT().DroppedKey(logging.KeyPhase(123)) - tracer.DroppedKey(123) -} - -func TestConnectionTracerSetLossTimer(t *testing.T) { - ctrl := gomock.NewController(t) - t1, tr1 := mocklogging.NewMockConnectionTracer(ctrl) - t2, tr2 := mocklogging.NewMockConnectionTracer(ctrl) - tracer := logging.NewMultiplexedConnectionTracer(t1, t2) - - now := time.Now() - tr1.EXPECT().SetLossTimer(logging.TimerTypePTO, logging.EncryptionHandshake, now) - tr2.EXPECT().SetLossTimer(logging.TimerTypePTO, logging.EncryptionHandshake, now) - tracer.SetLossTimer(logging.TimerTypePTO, logging.EncryptionHandshake, now) -} - -func TestConnectionTracerLossTimerExpired(t *testing.T) { - ctrl := gomock.NewController(t) - t1, tr1 := mocklogging.NewMockConnectionTracer(ctrl) - t2, tr2 := mocklogging.NewMockConnectionTracer(ctrl) - tracer := logging.NewMultiplexedConnectionTracer(t1, t2) - - tr1.EXPECT().LossTimerExpired(logging.TimerTypePTO, logging.EncryptionHandshake) - tr2.EXPECT().LossTimerExpired(logging.TimerTypePTO, logging.EncryptionHandshake) - tracer.LossTimerExpired(logging.TimerTypePTO, logging.EncryptionHandshake) -} - -func TestConnectionTracerLossTimerCanceled(t *testing.T) { - ctrl := gomock.NewController(t) - t1, tr1 := mocklogging.NewMockConnectionTracer(ctrl) - t2, tr2 := mocklogging.NewMockConnectionTracer(ctrl) - tracer := logging.NewMultiplexedConnectionTracer(t1, t2) - - tr1.EXPECT().LossTimerCanceled() - tr2.EXPECT().LossTimerCanceled() - tracer.LossTimerCanceled() -} - -func TestConnectionTracerClose(t *testing.T) { - ctrl := gomock.NewController(t) - t1, tr1 := mocklogging.NewMockConnectionTracer(ctrl) - t2, tr2 := mocklogging.NewMockConnectionTracer(ctrl) - tracer := logging.NewMultiplexedConnectionTracer(t1, t2) - - tr1.EXPECT().Close() - tr2.EXPECT().Close() - tracer.Close() + require.Equal(t, e, err1) + require.Equal(t, e, err2) } diff --git a/logging/generate_multiplexer.go b/logging/generate_multiplexer.go new file mode 100644 index 00000000..c152b846 --- /dev/null +++ b/logging/generate_multiplexer.go @@ -0,0 +1,161 @@ +//go:build generate + +package main + +import ( + "bytes" + "fmt" + "go/ast" + "go/parser" + "go/printer" + "go/token" + "log" + "os" + "strings" + "text/template" + + "golang.org/x/tools/imports" +) + +func main() { + if len(os.Args) != 5 { + log.Fatalf("Usage: %s ", os.Args[0]) + } + + structName := os.Args[1] + inputFile := os.Args[2] + templateFile := os.Args[3] + outputFile := os.Args[4] + + fset := token.NewFileSet() + + // Parse the input file containing the struct type + file, err := parser.ParseFile(fset, inputFile, nil, parser.AllErrors) + if err != nil { + log.Fatalf("Failed to parse file: %v", err) + } + + var fields []*ast.Field + + // Find the specified struct type in the AST + for _, decl := range file.Decls { + genDecl, ok := decl.(*ast.GenDecl) + if !ok || genDecl.Tok != token.TYPE { + continue + } + for _, spec := range genDecl.Specs { + typeSpec, ok := spec.(*ast.TypeSpec) + if !ok || typeSpec.Name.Name != structName { + continue + } + structType, ok := typeSpec.Type.(*ast.StructType) + if !ok { + log.Fatalf("%s is not a struct", structName) + } + fields = structType.Fields.List + break + } + } + + if fields == nil { + log.Fatalf("Could not find %s type", structName) + } + + // Prepare data for the template + type FieldData struct { + Name string + Params string + Args string + HasParams bool + ReturnTypes string + HasReturn bool + } + + var fieldDataList []FieldData + + for _, field := range fields { + funcType, ok := field.Type.(*ast.FuncType) + if !ok { + continue + } + for _, name := range field.Names { + fieldData := FieldData{Name: name.Name} + + // extract parameters + var params []string + var args []string + if funcType.Params != nil { + for i, param := range funcType.Params.List { + // We intentionally reject unnamed (and, further down, "_") function parameters. + // We could auto-generate parameter names, + // but having meaningful variable names will be more helpful for the user. + if len(param.Names) == 0 { + log.Fatalf("encountered unnamed parameter at position %d in function %s", i, fieldData.Name) + } + var buf bytes.Buffer + printer.Fprint(&buf, fset, param.Type) + paramType := buf.String() + for _, paramName := range param.Names { + if paramName.Name == "_" { + log.Fatalf("encountered underscore parameter at position %d in function %s", i, fieldData.Name) + } + params = append(params, fmt.Sprintf("%s %s", paramName.Name, paramType)) + args = append(args, paramName.Name) + } + } + } + fieldData.Params = strings.Join(params, ", ") + fieldData.Args = strings.Join(args, ", ") + fieldData.HasParams = len(params) > 0 + + // extract return types + if funcType.Results != nil && len(funcType.Results.List) > 0 { + fieldData.HasReturn = true + var returns []string + for _, result := range funcType.Results.List { + var buf bytes.Buffer + printer.Fprint(&buf, fset, result.Type) + returns = append(returns, buf.String()) + } + if len(returns) == 1 { + fieldData.ReturnTypes = fmt.Sprintf(" %s", returns[0]) + } else { + fieldData.ReturnTypes = fmt.Sprintf(" (%s)", strings.Join(returns, ", ")) + } + } + + fieldDataList = append(fieldDataList, fieldData) + } + } + + // Read the template from file + templateContent, err := os.ReadFile(templateFile) + if err != nil { + log.Fatalf("Failed to read template file: %v", err) + } + + // Generate the code using the template + tmpl, err := template.New("multiplexer").Funcs(template.FuncMap{"join": strings.Join}).Parse(string(templateContent)) + if err != nil { + log.Fatalf("Failed to parse template: %v", err) + } + + var generatedCode bytes.Buffer + generatedCode.WriteString("// Code generated by generate_multiplexer.go; DO NOT EDIT.\n\n") + if err = tmpl.Execute(&generatedCode, map[string]interface{}{ + "Fields": fieldDataList, + "StructName": structName, + }); err != nil { + log.Fatalf("Failed to execute template: %v", err) + } + + // Format the generated code and add imports + formattedCode, err := imports.Process(outputFile, generatedCode.Bytes(), nil) + if err != nil { + log.Fatalf("Failed to process imports: %v", err) + } + + if err := os.WriteFile(outputFile, formattedCode, 0o644); err != nil { + log.Fatalf("Failed to write output file: %v", err) + } +} diff --git a/logging/multiplexer.tmpl b/logging/multiplexer.tmpl new file mode 100644 index 00000000..9ba52e0f --- /dev/null +++ b/logging/multiplexer.tmpl @@ -0,0 +1,21 @@ +package logging + +func NewMultiplexed{{ .StructName }} (tracers ...*{{ .StructName }}) *{{ .StructName }} { + if len(tracers) == 0 { + return nil + } + if len(tracers) == 1 { + return tracers[0] + } + return &{{ .StructName }}{ + {{- range .Fields }} + {{ .Name }}: func({{ .Params }}){{ .ReturnTypes }} { + for _, t := range tracers { + if t.{{ .Name }} != nil { + t.{{ .Name }}({{ .Args }}) + } + } + }, + {{- end }} + } +} diff --git a/logging/tracer.go b/logging/tracer.go index 625a809e..4fe01462 100644 --- a/logging/tracer.go +++ b/logging/tracer.go @@ -2,58 +2,13 @@ package logging import "net" +//go:generate go run generate_multiplexer.go Tracer tracer.go multiplexer.tmpl tracer_multiplexer.go + // A Tracer traces events. type Tracer struct { - SentPacket func(net.Addr, *Header, ByteCount, []Frame) - SentVersionNegotiationPacket func(_ net.Addr, dest, src ArbitraryLenConnectionID, _ []Version) - DroppedPacket func(net.Addr, PacketType, ByteCount, PacketDropReason) + SentPacket func(dest net.Addr, hdr *Header, size ByteCount, frames []Frame) + SentVersionNegotiationPacket func(dest net.Addr, destConnID, srcConnID ArbitraryLenConnectionID, versions []Version) + DroppedPacket func(addr net.Addr, packetType PacketType, size ByteCount, reason PacketDropReason) Debug func(name, msg string) Close func() } - -// NewMultiplexedTracer creates a new tracer that multiplexes events to multiple tracers. -func NewMultiplexedTracer(tracers ...*Tracer) *Tracer { - if len(tracers) == 0 { - return nil - } - if len(tracers) == 1 { - return tracers[0] - } - return &Tracer{ - SentPacket: func(remote net.Addr, hdr *Header, size ByteCount, frames []Frame) { - for _, t := range tracers { - if t.SentPacket != nil { - t.SentPacket(remote, hdr, size, frames) - } - } - }, - SentVersionNegotiationPacket: func(remote net.Addr, dest, src ArbitraryLenConnectionID, versions []Version) { - for _, t := range tracers { - if t.SentVersionNegotiationPacket != nil { - t.SentVersionNegotiationPacket(remote, dest, src, versions) - } - } - }, - DroppedPacket: func(remote net.Addr, typ PacketType, size ByteCount, reason PacketDropReason) { - for _, t := range tracers { - if t.DroppedPacket != nil { - t.DroppedPacket(remote, typ, size, reason) - } - } - }, - Debug: func(name, msg string) { - for _, t := range tracers { - if t.Debug != nil { - t.Debug(name, msg) - } - } - }, - Close: func() { - for _, t := range tracers { - if t.Close != nil { - t.Close() - } - } - }, - } -} diff --git a/logging/tracer_multiplexer.go b/logging/tracer_multiplexer.go new file mode 100644 index 00000000..f0878cfe --- /dev/null +++ b/logging/tracer_multiplexer.go @@ -0,0 +1,51 @@ +// Code generated by generate_multiplexer.go; DO NOT EDIT. + +package logging + +import "net" + +func NewMultiplexedTracer(tracers ...*Tracer) *Tracer { + if len(tracers) == 0 { + return nil + } + if len(tracers) == 1 { + return tracers[0] + } + return &Tracer{ + SentPacket: func(dest net.Addr, hdr *Header, size ByteCount, frames []Frame) { + for _, t := range tracers { + if t.SentPacket != nil { + t.SentPacket(dest, hdr, size, frames) + } + } + }, + SentVersionNegotiationPacket: func(dest net.Addr, destConnID ArbitraryLenConnectionID, srcConnID ArbitraryLenConnectionID, versions []Version) { + for _, t := range tracers { + if t.SentVersionNegotiationPacket != nil { + t.SentVersionNegotiationPacket(dest, destConnID, srcConnID, versions) + } + } + }, + DroppedPacket: func(addr net.Addr, packetType PacketType, size ByteCount, reason PacketDropReason) { + for _, t := range tracers { + if t.DroppedPacket != nil { + t.DroppedPacket(addr, packetType, size, reason) + } + } + }, + Debug: func(name string, msg string) { + for _, t := range tracers { + if t.Debug != nil { + t.Debug(name, msg) + } + } + }, + Close: func() { + for _, t := range tracers { + if t.Close != nil { + t.Close() + } + } + }, + } +} diff --git a/logging/tracer_test.go b/logging/tracer_test.go index c3641e38..b1fbc39b 100644 --- a/logging/tracer_test.go +++ b/logging/tracer_test.go @@ -4,12 +4,10 @@ import ( "net" "testing" - mocklogging "github.com/quic-go/quic-go/internal/mocks/logging" "github.com/quic-go/quic-go/internal/protocol" . "github.com/quic-go/quic-go/logging" "github.com/stretchr/testify/require" - "go.uber.org/mock/gomock" ) func TestNilTracerWhenEmpty(t *testing.T) { @@ -23,69 +21,16 @@ func TestSingleTracer(t *testing.T) { } func TestTracerPacketSent(t *testing.T) { - ctrl := gomock.NewController(t) - - t1, tr1 := mocklogging.NewMockTracer(ctrl) - t2, tr2 := mocklogging.NewMockTracer(ctrl) + var s1, s2 ByteCount + t1 := &Tracer{SentPacket: func(_ net.Addr, _ *Header, s ByteCount, _ []Frame) { s1 = s }} + t2 := &Tracer{SentPacket: func(_ net.Addr, _ *Header, s ByteCount, _ []Frame) { s2 = s }} tracer := NewMultiplexedTracer(t1, t2, &Tracer{}) + const size ByteCount = 1024 remote := &net.UDPAddr{IP: net.IPv4(4, 3, 2, 1)} hdr := &Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3})} f := &MaxDataFrame{MaximumData: 1337} - tr1.EXPECT().SentPacket(remote, hdr, ByteCount(1024), []Frame{f}) - tr2.EXPECT().SentPacket(remote, hdr, ByteCount(1024), []Frame{f}) - tracer.SentPacket(remote, hdr, 1024, []Frame{f}) -} - -func TestTracerVersionNegotiationSent(t *testing.T) { - ctrl := gomock.NewController(t) - - t1, tr1 := mocklogging.NewMockTracer(ctrl) - t2, tr2 := mocklogging.NewMockTracer(ctrl) - tracer := NewMultiplexedTracer(t1, t2, &Tracer{}) - - remote := &net.UDPAddr{IP: net.IPv4(4, 3, 2, 1)} - src := ArbitraryLenConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13} - dest := ArbitraryLenConnectionID{1, 2, 3, 4} - versions := []Version{1, 2, 3} - tr1.EXPECT().SentVersionNegotiationPacket(remote, dest, src, versions) - tr2.EXPECT().SentVersionNegotiationPacket(remote, dest, src, versions) - tracer.SentVersionNegotiationPacket(remote, dest, src, versions) -} - -func TestTracerPacketDropped(t *testing.T) { - ctrl := gomock.NewController(t) - - t1, tr1 := mocklogging.NewMockTracer(ctrl) - t2, tr2 := mocklogging.NewMockTracer(ctrl) - tracer := NewMultiplexedTracer(t1, t2, &Tracer{}) - - remote := &net.UDPAddr{IP: net.IPv4(4, 3, 2, 1)} - tr1.EXPECT().DroppedPacket(remote, PacketTypeRetry, ByteCount(1024), PacketDropDuplicate) - tr2.EXPECT().DroppedPacket(remote, PacketTypeRetry, ByteCount(1024), PacketDropDuplicate) - tracer.DroppedPacket(remote, PacketTypeRetry, 1024, PacketDropDuplicate) -} - -func TestTracerDebug(t *testing.T) { - ctrl := gomock.NewController(t) - - t1, tr1 := mocklogging.NewMockTracer(ctrl) - t2, tr2 := mocklogging.NewMockTracer(ctrl) - tracer := NewMultiplexedTracer(t1, t2, &Tracer{}) - - tr1.EXPECT().Debug("foo", "bar") - tr2.EXPECT().Debug("foo", "bar") - tracer.Debug("foo", "bar") -} - -func TestTracerClose(t *testing.T) { - ctrl := gomock.NewController(t) - - t1, tr1 := mocklogging.NewMockTracer(ctrl) - t2, tr2 := mocklogging.NewMockTracer(ctrl) - tracer := NewMultiplexedTracer(t1, t2, &Tracer{}) - - tr1.EXPECT().Close() - tr2.EXPECT().Close() - tracer.Close() + tracer.SentPacket(remote, hdr, size, []Frame{f}) + require.Equal(t, size, s1) + require.Equal(t, size, s2) }