diff --git a/client.go b/client.go index 4691d6d6b..e56329f4c 100644 --- a/client.go +++ b/client.go @@ -12,7 +12,6 @@ import ( "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/logging" - "github.com/lucas-clemente/quic-go/qlog" ) type client struct { @@ -51,8 +50,6 @@ var ( // make it possible to mock connection ID generation in the tests generateConnectionID = protocol.GenerateConnectionID generateConnectionIDForInitial = protocol.GenerateConnectionIDForInitial - // make it possible to the tracer - newTracer = qlog.NewTracer ) // DialAddr establishes a new QUIC connection to a server. @@ -180,10 +177,8 @@ func dialContext( } c.packetHandlers = packetHandlers - if c.config.GetLogWriter != nil { - if w := c.config.GetLogWriter(c.destConnID); w != nil { - c.tracer = newTracer(w, protocol.PerspectiveClient, c.destConnID) - } + if c.config.Tracer != nil { + c.tracer = c.config.Tracer.TracerForClient(c.destConnID) } if err := c.dial(ctx); err != nil { return nil, err diff --git a/client_test.go b/client_test.go index 03698732d..b44a79709 100644 --- a/client_test.go +++ b/client_test.go @@ -1,13 +1,10 @@ package quic import ( - "bufio" "bytes" "context" "crypto/tls" "errors" - "io" - "io/ioutil" "net" "os" "time" @@ -52,7 +49,6 @@ var _ = Describe("Client", func() { logger utils.Logger, v protocol.VersionNumber, ) quicSession - originalTracerConstructor func(io.WriteCloser, protocol.Perspective, protocol.ConnectionID) logging.ConnectionTracer ) // generate a packet sent by the server that accepts the QUIC version suggested by the client @@ -70,21 +66,10 @@ var _ = Describe("Client", func() { tlsConf = &tls.Config{NextProtos: []string{"proto1"}} connID = protocol.ConnectionID{0, 0, 0, 0, 0, 0, 0x13, 0x37} originalClientSessConstructor = newClientSession - originalTracerConstructor = newTracer tracer = mocks.NewMockConnectionTracer(mockCtrl) - newTracer = func(io.WriteCloser, protocol.Perspective, protocol.ConnectionID) logging.ConnectionTracer { - return tracer - } - config = &Config{ - GetLogWriter: func([]byte) io.WriteCloser { - // Since we're mocking the tracer, it doesn't matter what we return here, - // as long as it's not nil. - return utils.NewBufferedWriteCloser( - bufio.NewWriter(&bytes.Buffer{}), - ioutil.NopCloser(&bytes.Buffer{}), - ) - }, - } + tr := mocks.NewMockTracer(mockCtrl) + tr.EXPECT().TracerForClient(gomock.Any()).Return(tracer).MaxTimes(1) + config = &Config{Tracer: tr} Eventually(areSessionsRunning).Should(BeFalse()) // sess = NewMockQuicSession(mockCtrl) addr = &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337} @@ -109,7 +94,6 @@ var _ = Describe("Client", func() { AfterEach(func() { connMuxer = origMultiplexer newClientSession = originalClientSessConstructor - newTracer = originalTracerConstructor }) AfterEach(func() { diff --git a/config.go b/config.go index 729b77799..1790bbb18 100644 --- a/config.go +++ b/config.go @@ -82,6 +82,6 @@ func populateConfig(config *Config) *Config { StatelessResetKey: config.StatelessResetKey, TokenStore: config.TokenStore, QuicTracer: config.QuicTracer, - GetLogWriter: config.GetLogWriter, + Tracer: config.Tracer, } } diff --git a/config_test.go b/config_test.go index 10a4423ab..fe1ac9840 100644 --- a/config_test.go +++ b/config_test.go @@ -2,11 +2,11 @@ package quic import ( "fmt" - "io" "net" "reflect" "time" + "github.com/lucas-clemente/quic-go/internal/mocks" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/quictrace" @@ -54,6 +54,8 @@ var _ = Describe("Config", func() { f.Set(reflect.ValueOf(true)) case "QuicTracer": f.Set(reflect.ValueOf(quictrace.NewTracer())) + case "Tracer": + f.Set(reflect.ValueOf(mocks.NewMockTracer(mockCtrl))) default: Fail(fmt.Sprintf("all fields must be accounted for, but saw unknown field %q", fn)) } @@ -62,16 +64,13 @@ var _ = Describe("Config", func() { } Context("cloning", func() { It("clones function fields", func() { - var calledAcceptToken, calledGetLogWriter bool + var calledAcceptToken bool c1 := &Config{ - AcceptToken: func(_ net.Addr, _ *Token) bool { calledAcceptToken = true; return true }, - GetLogWriter: func(connectionID []byte) io.WriteCloser { calledGetLogWriter = true; return nil }, + AcceptToken: func(_ net.Addr, _ *Token) bool { calledAcceptToken = true; return true }, } c2 := c1.Clone() c2.AcceptToken(&net.UDPAddr{}, &Token{}) - c2.GetLogWriter([]byte{1, 2, 3}) Expect(calledAcceptToken).To(BeTrue()) - Expect(calledGetLogWriter).To(BeTrue()) }) It("clones non-function fields", func() { @@ -95,16 +94,13 @@ var _ = Describe("Config", func() { Context("populating", func() { It("populates function fields", func() { - var calledAcceptToken, calledGetLogWriter bool + var calledAcceptToken bool c1 := &Config{ - AcceptToken: func(_ net.Addr, _ *Token) bool { calledAcceptToken = true; return true }, - GetLogWriter: func(connectionID []byte) io.WriteCloser { calledGetLogWriter = true; return nil }, + AcceptToken: func(_ net.Addr, _ *Token) bool { calledAcceptToken = true; return true }, } c2 := populateConfig(c1) c2.AcceptToken(&net.UDPAddr{}, &Token{}) - c2.GetLogWriter([]byte{1, 2, 3}) Expect(calledAcceptToken).To(BeTrue()) - Expect(calledGetLogWriter).To(BeTrue()) }) It("copies non-function fields", func() { diff --git a/example/client/main.go b/example/client/main.go index c058c6a8a..5d18dca9f 100644 --- a/example/client/main.go +++ b/example/client/main.go @@ -17,6 +17,7 @@ import ( "github.com/lucas-clemente/quic-go/http3" "github.com/lucas-clemente/quic-go/internal/testdata" "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/qlog" ) func main() { @@ -24,7 +25,7 @@ func main() { quiet := flag.Bool("q", false, "don't print the data") keyLogFile := flag.String("keylog", "", "key log file") insecure := flag.Bool("insecure", false, "skip certificate verification") - qlog := flag.Bool("qlog", false, "output a qlog (in the same directory)") + enableQlog := flag.Bool("qlog", false, "output a qlog (in the same directory)") flag.Parse() urls := flag.Args() @@ -54,8 +55,8 @@ func main() { testdata.AddRootCA(pool) var qconf quic.Config - if *qlog { - qconf.GetLogWriter = func(connID []byte) io.WriteCloser { + if *enableQlog { + qconf.Tracer = qlog.NewTracer(func(connID []byte) io.WriteCloser { filename := fmt.Sprintf("client_%x.qlog", connID) f, err := os.Create(filename) if err != nil { @@ -63,7 +64,7 @@ func main() { } log.Printf("Creating qlog file %s.\n", filename) return utils.NewBufferedWriteCloser(bufio.NewWriter(f), f) - } + }) } roundTripper := &http3.RoundTripper{ TLSClientConfig: &tls.Config{ diff --git a/example/main.go b/example/main.go index ee0d440a9..b9b3df37d 100644 --- a/example/main.go +++ b/example/main.go @@ -22,6 +22,7 @@ import ( "github.com/lucas-clemente/quic-go/http3" "github.com/lucas-clemente/quic-go/internal/testdata" "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/qlog" "github.com/lucas-clemente/quic-go/quictrace" ) @@ -187,7 +188,7 @@ func main() { www := flag.String("www", "", "www data") tcp := flag.Bool("tcp", false, "also listen on TCP") trace := flag.Bool("trace", false, "enable quic-trace") - qlog := flag.Bool("qlog", false, "output a qlog (in the same directory)") + enableQlog := flag.Bool("qlog", false, "output a qlog (in the same directory)") flag.Parse() logger := utils.DefaultLogger @@ -208,8 +209,8 @@ func main() { if *trace { quicConf.QuicTracer = tracer } - if *qlog { - quicConf.GetLogWriter = func(connID []byte) io.WriteCloser { + if *enableQlog { + quicConf.Tracer = qlog.NewTracer(func(connID []byte) io.WriteCloser { filename := fmt.Sprintf("server_%x.qlog", connID) f, err := os.Create(filename) if err != nil { @@ -217,7 +218,7 @@ func main() { } log.Printf("Creating qlog file %s.\n", filename) return utils.NewBufferedWriteCloser(bufio.NewWriter(f), f) - } + }) } var wg sync.WaitGroup diff --git a/integrationtests/self/self_suite_test.go b/integrationtests/self/self_suite_test.go index cd0821476..371909a67 100644 --- a/integrationtests/self/self_suite_test.go +++ b/integrationtests/self/self_suite_test.go @@ -21,6 +21,7 @@ import ( "github.com/lucas-clemente/quic-go" "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/qlog" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" @@ -259,14 +260,14 @@ func getQuicConfigForRole(role string, conf *quic.Config) *quic.Config { if !enableQlog { return conf } - conf.GetLogWriter = func(connectionID []byte) io.WriteCloser { + conf.Tracer = qlog.NewTracer(func(connectionID []byte) io.WriteCloser { filename := fmt.Sprintf("log_%x_%s.qlog", connectionID, role) fmt.Fprintf(GinkgoWriter, "Creating %s.\n", filename) f, err := os.Create(filename) Expect(err).ToNot(HaveOccurred()) bw := bufio.NewWriter(f) return utils.NewBufferedWriteCloser(bw, f) - } + }) return conf } diff --git a/interface.go b/interface.go index 429e1caab..1abd89ece 100644 --- a/interface.go +++ b/interface.go @@ -6,6 +6,8 @@ import ( "net" "time" + "github.com/lucas-clemente/quic-go/logging" + "github.com/lucas-clemente/quic-go/internal/handshake" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/quictrace" @@ -245,11 +247,7 @@ type Config struct { // QUIC Event Tracer. // Warning: Experimental. This API should not be considered stable and will change soon. QuicTracer quictrace.Tracer - // GetLogWriter is used to pass in a writer for the qlog. - // If it is nil, no qlog will be collected and exported. - // If it returns nil, no qlog will be collected and exported for the respective connection. - // It is recommended to use a buffered writer here. - GetLogWriter func(connectionID []byte) io.WriteCloser + Tracer logging.Tracer } // A Listener for incoming QUIC connections diff --git a/internal/mocks/mockgen.go b/internal/mocks/mockgen.go index 5e9494295..9d84e3d22 100644 --- a/internal/mocks/mockgen.go +++ b/internal/mocks/mockgen.go @@ -3,6 +3,7 @@ package mocks //go:generate sh -c "mockgen -package mockquic -destination quic/stream.go github.com/lucas-clemente/quic-go Stream && goimports -w quic/stream.go" //go:generate sh -c "mockgen -package mockquic -destination quic/early_session.go github.com/lucas-clemente/quic-go EarlySession && goimports -w quic/early_session.go" //go:generate sh -c "mockgen -package mockquic -destination quic/early_listener.go github.com/lucas-clemente/quic-go EarlyListener && goimports -w quic/early_listener.go" +//go:generate sh -c "mockgen -package mocks -destination tracer.go github.com/lucas-clemente/quic-go/logging Tracer && goimports -w tracer.go" //go:generate sh -c "mockgen -package mocks -destination connection_tracer.go github.com/lucas-clemente/quic-go/logging ConnectionTracer && goimports -w connection_tracer.go" //go:generate sh -c "mockgen -package mocks -destination short_header_sealer.go github.com/lucas-clemente/quic-go/internal/handshake ShortHeaderSealer && goimports -w short_header_sealer.go" //go:generate sh -c "mockgen -package mocks -destination short_header_opener.go github.com/lucas-clemente/quic-go/internal/handshake ShortHeaderOpener && goimports -w short_header_opener.go" diff --git a/internal/mocks/tracer.go b/internal/mocks/tracer.go new file mode 100644 index 000000000..41aba518f --- /dev/null +++ b/internal/mocks/tracer.go @@ -0,0 +1,64 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/lucas-clemente/quic-go/logging (interfaces: Tracer) + +// Package mocks is a generated GoMock package. +package mocks + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + protocol "github.com/lucas-clemente/quic-go/internal/protocol" + logging "github.com/lucas-clemente/quic-go/logging" +) + +// MockTracer is a mock of Tracer interface +type MockTracer struct { + ctrl *gomock.Controller + recorder *MockTracerMockRecorder +} + +// MockTracerMockRecorder is the mock recorder for MockTracer +type MockTracerMockRecorder struct { + mock *MockTracer +} + +// NewMockTracer creates a new mock instance +func NewMockTracer(ctrl *gomock.Controller) *MockTracer { + mock := &MockTracer{ctrl: ctrl} + mock.recorder = &MockTracerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockTracer) EXPECT() *MockTracerMockRecorder { + return m.recorder +} + +// TracerForClient mocks base method +func (m *MockTracer) TracerForClient(arg0 protocol.ConnectionID) logging.ConnectionTracer { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "TracerForClient", arg0) + ret0, _ := ret[0].(logging.ConnectionTracer) + return ret0 +} + +// TracerForClient indicates an expected call of TracerForClient +func (mr *MockTracerMockRecorder) TracerForClient(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TracerForClient", reflect.TypeOf((*MockTracer)(nil).TracerForClient), arg0) +} + +// TracerForServer mocks base method +func (m *MockTracer) TracerForServer(arg0 protocol.ConnectionID) logging.ConnectionTracer { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "TracerForServer", arg0) + ret0, _ := ret[0].(logging.ConnectionTracer) + return ret0 +} + +// TracerForServer indicates an expected call of TracerForServer +func (mr *MockTracerMockRecorder) TracerForServer(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TracerForServer", reflect.TypeOf((*MockTracer)(nil).TracerForServer), arg0) +} diff --git a/interop/client/main.go b/interop/client/main.go index 649bc6167..9d7a23a10 100644 --- a/interop/client/main.go +++ b/interop/client/main.go @@ -19,6 +19,7 @@ import ( "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/interop/http09" "github.com/lucas-clemente/quic-go/interop/utils" + "github.com/lucas-clemente/quic-go/qlog" ) var errUnsupported = errors.New("unsupported test case") @@ -66,7 +67,7 @@ func runTestcase(testcase string) error { if err != nil { return err } - quicConf := &quic.Config{GetLogWriter: getLogWriter} + quicConf := &quic.Config{Tracer: qlog.NewTracer(getLogWriter)} if testcase == "http3" { r := &http3.RoundTripper{ diff --git a/interop/server/main.go b/interop/server/main.go index f830ad30f..3eaefbb01 100644 --- a/interop/server/main.go +++ b/interop/server/main.go @@ -13,6 +13,7 @@ import ( "github.com/lucas-clemente/quic-go/internal/testdata" "github.com/lucas-clemente/quic-go/interop/http09" "github.com/lucas-clemente/quic-go/interop/utils" + "github.com/lucas-clemente/quic-go/qlog" ) var tlsConf *tls.Config @@ -44,8 +45,8 @@ func main() { } // a quic.Config that doesn't do a Retry quicConf := &quic.Config{ - AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true }, - GetLogWriter: getLogWriter, + AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true }, + Tracer: qlog.NewTracer(getLogWriter), } tlsConf = testdata.GetTLSConfig() tlsConf.KeyLogWriter = keyLog diff --git a/logging/interface.go b/logging/interface.go index 9d38b162e..e925ba0d3 100644 --- a/logging/interface.go +++ b/logging/interface.go @@ -11,6 +11,11 @@ import ( "github.com/lucas-clemente/quic-go/internal/wire" ) +type Tracer interface { + TracerForServer(odcid protocol.ConnectionID) ConnectionTracer + TracerForClient(odcid protocol.ConnectionID) ConnectionTracer +} + // A ConnectionTracer records events. type ConnectionTracer interface { StartedConnection(local, remote net.Addr, version protocol.VersionNumber, srcConnID, destConnID protocol.ConnectionID) diff --git a/qlog/qlog.go b/qlog/qlog.go index e3e809f2a..9d8090f0c 100644 --- a/qlog/qlog.go +++ b/qlog/qlog.go @@ -19,6 +19,25 @@ import ( const eventChanSize = 50 +type tracer struct { + getLogWriter func(connectionID []byte) io.WriteCloser +} + +var _ logging.Tracer = &tracer{} + +// NewTracer creates a new qlog tracer. +func NewTracer(getLogWriter func(connectionID []byte) io.WriteCloser) logging.Tracer { + return &tracer{getLogWriter: getLogWriter} +} + +func (t *tracer) TracerForServer(odcid protocol.ConnectionID) logging.ConnectionTracer { + return newTracer(t.getLogWriter(odcid.Bytes()), protocol.PerspectiveServer, odcid) +} + +func (t *tracer) TracerForClient(odcid protocol.ConnectionID) logging.ConnectionTracer { + return newTracer(t.getLogWriter(odcid.Bytes()), protocol.PerspectiveClient, odcid) +} + type connectionTracer struct { mutex sync.Mutex @@ -37,8 +56,8 @@ type connectionTracer struct { var _ logging.ConnectionTracer = &connectionTracer{} -// NewTracer creates a new connectionTracer to record a qlog. -func NewTracer(w io.WriteCloser, p protocol.Perspective, odcid protocol.ConnectionID) logging.ConnectionTracer { +// newTracer creates a new connectionTracer to record a qlog. +func newTracer(w io.WriteCloser, p protocol.Perspective, odcid protocol.ConnectionID) logging.ConnectionTracer { t := &connectionTracer{ w: w, perspective: p, diff --git a/qlog/qlog_test.go b/qlog/qlog_test.go index a95c870e5..e2bf7b230 100644 --- a/qlog/qlog_test.go +++ b/qlog/qlog_test.go @@ -57,7 +57,7 @@ var _ = Describe("Tracer", func() { BeforeEach(func() { buf = &bytes.Buffer{} - tracer = NewTracer( + tracer = newTracer( nopWriteCloser(buf), protocol.PerspectiveServer, protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, @@ -92,7 +92,7 @@ var _ = Describe("Tracer", func() { }) It("stops writing when encountering an error", func() { - tracer = NewTracer( + tracer = newTracer( &limitedWriter{WriteCloser: nopWriteCloser(buf), N: 250}, protocol.PerspectiveServer, protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, diff --git a/server.go b/server.go index a16ebe4f8..83e356874 100644 --- a/server.go +++ b/server.go @@ -17,7 +17,6 @@ import ( "github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/wire" "github.com/lucas-clemente/quic-go/logging" - "github.com/lucas-clemente/quic-go/qlog" ) // packetHandler handles packets @@ -448,15 +447,13 @@ func (s *baseServer) createNewSession( var sess quicSession if added := s.sessionHandler.AddWithConnID(clientDestConnID, srcConnID, func() packetHandler { var tracer logging.ConnectionTracer - if s.config.GetLogWriter != nil { + if s.config.Tracer != nil { // Use the same connection ID that is passed to the client's GetLogWriter callback. connID := clientDestConnID if origDestConnID.Len() > 0 { connID = origDestConnID } - if w := s.config.GetLogWriter(connID); w != nil { - tracer = qlog.NewTracer(w, protocol.PerspectiveServer, connID) - } + tracer = s.config.Tracer.TracerForServer(connID) } sess = s.newSession( &conn{pconn: s.conn, currentAddr: remoteAddr},