diff --git a/internal/mocks/logging/internal/tracer.go b/internal/mocks/logging/internal/tracer.go index 88c3b645..f8f3ae0e 100644 --- a/internal/mocks/logging/internal/tracer.go +++ b/internal/mocks/logging/internal/tracer.go @@ -41,6 +41,42 @@ func (m *MockTracer) EXPECT() *MockTracerMockRecorder { return m.recorder } +// Close mocks base method. +func (m *MockTracer) Close() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Close") +} + +// Close indicates an expected call of Close. +func (mr *MockTracerMockRecorder) Close() *TracerCloseCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockTracer)(nil).Close)) + return &TracerCloseCall{Call: call} +} + +// TracerCloseCall wrap *gomock.Call +type TracerCloseCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *TracerCloseCall) Return() *TracerCloseCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *TracerCloseCall) Do(f func()) *TracerCloseCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *TracerCloseCall) DoAndReturn(f func()) *TracerCloseCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + // Debug mocks base method. func (m *MockTracer) Debug(arg0, arg1 string) { m.ctrl.T.Helper() diff --git a/internal/mocks/logging/mockgen.go b/internal/mocks/logging/mockgen.go index fb58e117..a9850e64 100644 --- a/internal/mocks/logging/mockgen.go +++ b/internal/mocks/logging/mockgen.go @@ -15,6 +15,7 @@ type Tracer interface { SentVersionNegotiationPacket(_ net.Addr, dest, src logging.ArbitraryLenConnectionID, _ []logging.VersionNumber) DroppedPacket(net.Addr, logging.PacketType, logging.ByteCount, logging.PacketDropReason) Debug(name, msg string) + Close() } //go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package internal -destination internal/connection_tracer.go github.com/quic-go/quic-go/internal/mocks/logging ConnectionTracer" diff --git a/internal/mocks/logging/tracer.go b/internal/mocks/logging/tracer.go index 66e210a5..a4b081bf 100644 --- a/internal/mocks/logging/tracer.go +++ b/internal/mocks/logging/tracer.go @@ -28,5 +28,8 @@ func NewMockTracer(ctrl *gomock.Controller) (*logging.Tracer, *MockTracer) { Debug: func(name, msg string) { t.Debug(name, msg) }, + Close: func() { + t.Close() + }, }, t } diff --git a/logging/multiplex_test.go b/logging/multiplex_test.go index 96cd1185..4a9fe677 100644 --- a/logging/multiplex_test.go +++ b/logging/multiplex_test.go @@ -70,6 +70,12 @@ var _ = Describe("Tracing", func() { tr2.EXPECT().Debug("foo", "bar") tracer.Debug("foo", "bar") }) + + It("traces the Close event", func() { + tr1.EXPECT().Close() + tr2.EXPECT().Close() + tracer.Close() + }) }) }) diff --git a/logging/tracer.go b/logging/tracer.go index 735ec3de..edd85dba 100644 --- a/logging/tracer.go +++ b/logging/tracer.go @@ -8,6 +8,7 @@ type Tracer struct { SentVersionNegotiationPacket func(_ net.Addr, dest, src ArbitraryLenConnectionID, _ []VersionNumber) DroppedPacket func(net.Addr, PacketType, ByteCount, PacketDropReason) Debug func(name, msg string) + Close func() } // NewMultiplexedTracer creates a new tracer that multiplexes events to multiple tracers. @@ -47,5 +48,12 @@ func NewMultiplexedTracer(tracers ...*Tracer) *Tracer { } } }, + Close: func() { + for _, t := range tracers { + if t.Close != nil { + t.Close() + } + } + }, } } diff --git a/server_test.go b/server_test.go index 58b15921..cbc16c19 100644 --- a/server_test.go +++ b/server_test.go @@ -203,6 +203,7 @@ var _ = Describe("Server", func() { }) AfterEach(func() { + tracer.EXPECT().Close() tr.Close() }) @@ -1429,7 +1430,10 @@ var _ = Describe("Server", func() { serv.connHandler = phm }) - AfterEach(func() { tr.Close() }) + AfterEach(func() { + tracer.EXPECT().Close() + tr.Close() + }) It("passes packets to existing connections", func() { connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) diff --git a/transport.go b/transport.go index 4988bcfe..fdfa1b6b 100644 --- a/transport.go +++ b/transport.go @@ -111,6 +111,7 @@ type Transport struct { MaxHandshakes int // A Tracer traces events that don't belong to a single QUIC connection. + // Tracer.Close is called when the transport is closed. Tracer *logging.Tracer handlerMap packetHandlerManager @@ -366,6 +367,9 @@ func (t *Transport) close(e error) { if t.server != nil { t.server.close(e, false) } + if t.Tracer != nil && t.Tracer.Close != nil { + t.Tracer.Close() + } t.closed = true } diff --git a/transport_test.go b/transport_test.go index c93d1da9..2865ae94 100644 --- a/transport_test.go +++ b/transport_test.go @@ -141,6 +141,7 @@ var _ = Describe("Transport", func() { Eventually(dropped).Should(BeClosed()) // shutdown + tracer.EXPECT().Close() close(packetChan) tr.Close() }) @@ -391,6 +392,7 @@ var _ = Describe("Transport", func() { Eventually(done).Should(BeClosed()) // shutdown + tracer.EXPECT().Close() close(packetChan) tr.Close() })