diff --git a/client.go b/client.go index 9cbfd4908..63db2cdbb 100644 --- a/client.go +++ b/client.go @@ -165,7 +165,7 @@ func dialContext( return nil, errors.New("quic: tls.Config not set") } config = populateClientConfig(config, createdPacketConn) - packetHandlers, err := getMultiplexer().AddConn(pconn, config.ConnectionIDLength, config.StatelessResetKey) + packetHandlers, err := getMultiplexer().AddConn(pconn, config.ConnectionIDLength, config.StatelessResetKey, config.Tracer) if err != nil { return nil, err } diff --git a/client_test.go b/client_test.go index 61506f01f..ad339521f 100644 --- a/client_test.go +++ b/client_test.go @@ -131,7 +131,7 @@ var _ = Describe("Client", func() { manager := NewMockPacketHandlerManager(mockCtrl) manager.EXPECT().Add(gomock.Any(), gomock.Any()) manager.EXPECT().Destroy() - mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) + mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) remoteAddrChan := make(chan string, 1) newClientSession = func( @@ -164,7 +164,7 @@ var _ = Describe("Client", func() { manager := NewMockPacketHandlerManager(mockCtrl) manager.EXPECT().Add(gomock.Any(), gomock.Any()) manager.EXPECT().Destroy() - mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) + mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) hostnameChan := make(chan string, 1) newClientSession = func( @@ -197,7 +197,7 @@ var _ = Describe("Client", func() { It("allows passing host without port as server name", func() { manager := NewMockPacketHandlerManager(mockCtrl) manager.EXPECT().Add(gomock.Any(), gomock.Any()) - mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any()).Return(manager, nil) + mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) hostnameChan := make(chan string, 1) newClientSession = func( @@ -236,7 +236,7 @@ var _ = Describe("Client", func() { It("returns after the handshake is complete", func() { manager := NewMockPacketHandlerManager(mockCtrl) manager.EXPECT().Add(gomock.Any(), gomock.Any()) - mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any()).Return(manager, nil) + mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) run := make(chan struct{}) newClientSession = func( @@ -278,7 +278,7 @@ var _ = Describe("Client", func() { It("returns early sessions", func() { manager := NewMockPacketHandlerManager(mockCtrl) manager.EXPECT().Add(gomock.Any(), gomock.Any()) - mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any()).Return(manager, nil) + mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) readyChan := make(chan struct{}) done := make(chan struct{}) @@ -327,7 +327,7 @@ var _ = Describe("Client", func() { It("returns an error that occurs while waiting for the handshake to complete", func() { manager := NewMockPacketHandlerManager(mockCtrl) manager.EXPECT().Add(gomock.Any(), gomock.Any()) - mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any()).Return(manager, nil) + mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) testErr := errors.New("early handshake error") newClientSession = func( @@ -365,7 +365,7 @@ var _ = Describe("Client", func() { It("closes the session when the context is canceled", func() { manager := NewMockPacketHandlerManager(mockCtrl) manager.EXPECT().Add(gomock.Any(), gomock.Any()) - mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any()).Return(manager, nil) + mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) sessionRunning := make(chan struct{}) defer close(sessionRunning) @@ -419,7 +419,7 @@ var _ = Describe("Client", func() { } manager := NewMockPacketHandlerManager(mockCtrl) - mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) + mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) manager.EXPECT().Add(gomock.Any(), gomock.Any()) var conn connection @@ -497,7 +497,7 @@ var _ = Describe("Client", func() { It("errors when the Config contains an invalid version", func() { manager := NewMockPacketHandlerManager(mockCtrl) - mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any()).Return(manager, nil) + mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) version := protocol.VersionNumber(0x1234) _, err := Dial(packetConn, nil, "localhost:1234", tlsConf, &Config{Versions: []protocol.VersionNumber{version}}) @@ -540,7 +540,7 @@ var _ = Describe("Client", func() { It("creates new sessions with the right parameters", func() { manager := NewMockPacketHandlerManager(mockCtrl) manager.EXPECT().Add(connID, gomock.Any()) - mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any()).Return(manager, nil) + mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) config := &Config{Versions: []protocol.VersionNumber{protocol.VersionTLS}} c := make(chan struct{}) @@ -584,7 +584,7 @@ var _ = Describe("Client", func() { manager := NewMockPacketHandlerManager(mockCtrl) manager.EXPECT().Add(connID, gomock.Any()).Times(2) manager.EXPECT().Destroy() - mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) + mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) initialVersion := cl.version diff --git a/integrationtests/self/self_suite_test.go b/integrationtests/self/self_suite_test.go index fd77c0a55..68f18f9aa 100644 --- a/integrationtests/self/self_suite_test.go +++ b/integrationtests/self/self_suite_test.go @@ -19,10 +19,9 @@ import ( "testing" "time" - "github.com/lucas-clemente/quic-go/logging" - "github.com/lucas-clemente/quic-go" "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/logging" "github.com/lucas-clemente/quic-go/qlog" . "github.com/onsi/ginkgo" @@ -91,6 +90,19 @@ var ( tlsConfig *tls.Config tlsConfigLongChain *tls.Config tlsClientConfig *tls.Config + + tracer = qlog.NewTracer(func(p logging.Perspective, connectionID []byte) io.WriteCloser { + role := "server" + if p == logging.PerspectiveClient { + role = "client" + } + filename := fmt.Sprintf("log_%x_%s.qlog", connectionID, role) + fmt.Fprintf(GinkgoWriter, "Creating %s.\n", filename) + f, err := os.Create(filename) + Expect(err).ToNot(HaveOccurred()) + bw := bufio.NewWriter(f) + return utils.NewBufferedWriteCloser(bw, f) + }) ) // read the logfile command line flag @@ -254,18 +266,7 @@ func getQuicConfig(conf *quic.Config) *quic.Config { if !enableQlog { return conf } - conf.Tracer = qlog.NewTracer(func(p logging.Perspective, connectionID []byte) io.WriteCloser { - role := "server" - if p == logging.PerspectiveClient { - role = "client" - } - filename := fmt.Sprintf("log_%x_%s.qlog", connectionID, role) - fmt.Fprintf(GinkgoWriter, "Creating %s.\n", filename) - f, err := os.Create(filename) - Expect(err).ToNot(HaveOccurred()) - bw := bufio.NewWriter(f) - return utils.NewBufferedWriteCloser(bw, f) - }) + conf.Tracer = tracer return conf } diff --git a/mock_multiplexer_test.go b/mock_multiplexer_test.go index 7ca5a2547..bcdcdcf80 100644 --- a/mock_multiplexer_test.go +++ b/mock_multiplexer_test.go @@ -9,6 +9,7 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" + logging "github.com/lucas-clemente/quic-go/logging" ) // MockMultiplexer is a mock of Multiplexer interface @@ -35,18 +36,18 @@ func (m *MockMultiplexer) EXPECT() *MockMultiplexerMockRecorder { } // AddConn mocks base method -func (m *MockMultiplexer) AddConn(arg0 net.PacketConn, arg1 int, arg2 []byte) (packetHandlerManager, error) { +func (m *MockMultiplexer) AddConn(arg0 net.PacketConn, arg1 int, arg2 []byte, arg3 logging.Tracer) (packetHandlerManager, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AddConn", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "AddConn", arg0, arg1, arg2, arg3) ret0, _ := ret[0].(packetHandlerManager) ret1, _ := ret[1].(error) return ret0, ret1 } // AddConn indicates an expected call of AddConn -func (mr *MockMultiplexerMockRecorder) AddConn(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockMultiplexerMockRecorder) AddConn(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddConn", reflect.TypeOf((*MockMultiplexer)(nil).AddConn), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddConn", reflect.TypeOf((*MockMultiplexer)(nil).AddConn), arg0, arg1, arg2, arg3) } // RemoveConn mocks base method diff --git a/multiplexer.go b/multiplexer.go index 824cb84ea..51eadfe08 100644 --- a/multiplexer.go +++ b/multiplexer.go @@ -7,6 +7,7 @@ import ( "sync" "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/logging" ) var ( @@ -15,13 +16,14 @@ var ( ) type multiplexer interface { - AddConn(c net.PacketConn, connIDLen int, statelessResetKey []byte) (packetHandlerManager, error) + AddConn(c net.PacketConn, connIDLen int, statelessResetKey []byte, tracer logging.Tracer) (packetHandlerManager, error) RemoveConn(net.PacketConn) error } type connManager struct { connIDLen int statelessResetKey []byte + tracer logging.Tracer manager packetHandlerManager } @@ -31,7 +33,7 @@ type connMultiplexer struct { mutex sync.Mutex conns map[string] /* LocalAddr().String() */ connManager - newPacketHandlerManager func(net.PacketConn, int, []byte, utils.Logger) packetHandlerManager // so it can be replaced in the tests + newPacketHandlerManager func(net.PacketConn, int, []byte, logging.Tracer, utils.Logger) packetHandlerManager // so it can be replaced in the tests logger utils.Logger } @@ -53,6 +55,7 @@ func (m *connMultiplexer) AddConn( c net.PacketConn, connIDLen int, statelessResetKey []byte, + tracer logging.Tracer, ) (packetHandlerManager, error) { m.mutex.Lock() defer m.mutex.Unlock() @@ -60,19 +63,24 @@ func (m *connMultiplexer) AddConn( connIndex := c.LocalAddr().Network() + " " + c.LocalAddr().String() p, ok := m.conns[connIndex] if !ok { - manager := m.newPacketHandlerManager(c, connIDLen, statelessResetKey, m.logger) + manager := m.newPacketHandlerManager(c, connIDLen, statelessResetKey, tracer, m.logger) p = connManager{ connIDLen: connIDLen, statelessResetKey: statelessResetKey, manager: manager, + tracer: tracer, } m.conns[connIndex] = p - } - if p.connIDLen != connIDLen { - return nil, fmt.Errorf("cannot use %d byte connection IDs on a connection that is already using %d byte connction IDs", connIDLen, p.connIDLen) - } - if statelessResetKey != nil && !bytes.Equal(p.statelessResetKey, statelessResetKey) { - return nil, fmt.Errorf("cannot use different stateless reset keys on the same packet conn") + } else { + if p.connIDLen != connIDLen { + return nil, fmt.Errorf("cannot use %d byte connection IDs on a connection that is already using %d byte connction IDs", connIDLen, p.connIDLen) + } + if statelessResetKey != nil && !bytes.Equal(p.statelessResetKey, statelessResetKey) { + return nil, fmt.Errorf("cannot use different stateless reset keys on the same packet conn") + } + if tracer != p.tracer { + return nil, fmt.Errorf("cannot use different tracers on the same packet conn") + } } return p.manager, nil } diff --git a/multiplexer_test.go b/multiplexer_test.go index f92d2aa9f..4141c46cd 100644 --- a/multiplexer_test.go +++ b/multiplexer_test.go @@ -3,6 +3,8 @@ package quic import ( "net" + "github.com/lucas-clemente/quic-go/internal/mocks" + . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) @@ -15,7 +17,7 @@ type testConn struct { var _ = Describe("Client Multiplexer", func() { It("adds a new packet conn ", func() { conn := newMockPacketConn() - _, err := getMultiplexer().AddConn(conn, 8, nil) + _, err := getMultiplexer().AddConn(conn, 8, nil, nil) Expect(err).ToNot(HaveOccurred()) }) @@ -23,27 +25,36 @@ var _ = Describe("Client Multiplexer", func() { pconn := newMockPacketConn() pconn.addr = &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 4321} conn := testConn{PacketConn: pconn} - _, err := getMultiplexer().AddConn(conn, 8, nil) + tracer := mocks.NewMockTracer(mockCtrl) + _, err := getMultiplexer().AddConn(conn, 8, []byte("foobar"), tracer) Expect(err).ToNot(HaveOccurred()) conn.counter++ - _, err = getMultiplexer().AddConn(conn, 8, nil) + _, err = getMultiplexer().AddConn(conn, 8, []byte("foobar"), tracer) Expect(err).ToNot(HaveOccurred()) Expect(getMultiplexer().(*connMultiplexer).conns).To(HaveLen(1)) }) It("errors when adding an existing conn with a different connection ID length", func() { conn := newMockPacketConn() - _, err := getMultiplexer().AddConn(conn, 5, nil) + _, err := getMultiplexer().AddConn(conn, 5, nil, nil) Expect(err).ToNot(HaveOccurred()) - _, err = getMultiplexer().AddConn(conn, 6, nil) + _, err = getMultiplexer().AddConn(conn, 6, nil, nil) Expect(err).To(MatchError("cannot use 6 byte connection IDs on a connection that is already using 5 byte connction IDs")) }) It("errors when adding an existing conn with a different stateless rest key", func() { conn := newMockPacketConn() - _, err := getMultiplexer().AddConn(conn, 7, []byte("foobar")) + _, err := getMultiplexer().AddConn(conn, 7, []byte("foobar"), nil) Expect(err).ToNot(HaveOccurred()) - _, err = getMultiplexer().AddConn(conn, 7, []byte("raboof")) + _, err = getMultiplexer().AddConn(conn, 7, []byte("raboof"), nil) Expect(err).To(MatchError("cannot use different stateless reset keys on the same packet conn")) }) + + It("errors when adding an existing conn with different tracers", func() { + conn := newMockPacketConn() + _, err := getMultiplexer().AddConn(conn, 7, nil, mocks.NewMockTracer(mockCtrl)) + Expect(err).ToNot(HaveOccurred()) + _, err = getMultiplexer().AddConn(conn, 7, nil, mocks.NewMockTracer(mockCtrl)) + Expect(err).To(MatchError("cannot use different tracers on the same packet conn")) + }) }) diff --git a/packet_handler_map.go b/packet_handler_map.go index bd9b6f7cd..809cfea0c 100644 --- a/packet_handler_map.go +++ b/packet_handler_map.go @@ -13,6 +13,7 @@ import ( "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/wire" + "github.com/lucas-clemente/quic-go/logging" ) type statelessResetErr struct { @@ -46,6 +47,7 @@ type packetHandlerMap struct { statelessResetMutex sync.Mutex statelessResetHasher hash.Hash + tracer logging.Tracer logger utils.Logger } @@ -55,6 +57,7 @@ func newPacketHandlerMap( conn net.PacketConn, connIDLen int, statelessResetKey []byte, + tracer logging.Tracer, logger utils.Logger, ) packetHandlerManager { m := &packetHandlerMap{ @@ -66,6 +69,7 @@ func newPacketHandlerMap( deleteRetiredSessionsAfter: protocol.RetiredConnectionIDDeleteTimeout, statelessResetEnabled: len(statelessResetKey) > 0, statelessResetHasher: hmac.New(sha256.New, statelessResetKey), + tracer: tracer, logger: logger, } go m.listen() diff --git a/packet_handler_map_test.go b/packet_handler_map_test.go index 08f6c86c2..130b969f4 100644 --- a/packet_handler_map_test.go +++ b/packet_handler_map_test.go @@ -50,7 +50,7 @@ var _ = Describe("Packet Handler Map", func() { JustBeforeEach(func() { conn = newMockPacketConn() - handler = newPacketHandlerMap(conn, connIDLen, statelessResetKey, utils.DefaultLogger).(*packetHandlerMap) + handler = newPacketHandlerMap(conn, connIDLen, statelessResetKey, nil, utils.DefaultLogger).(*packetHandlerMap) }) AfterEach(func() { diff --git a/server.go b/server.go index 512b82590..e297d3368 100644 --- a/server.go +++ b/server.go @@ -178,7 +178,7 @@ func listen(conn net.PacketConn, tlsConf *tls.Config, config *Config, acceptEarl } } - sessionHandler, err := getMultiplexer().AddConn(conn, config.ConnectionIDLength, config.StatelessResetKey) + sessionHandler, err := getMultiplexer().AddConn(conn, config.ConnectionIDLength, config.StatelessResetKey, config.Tracer) if err != nil { return nil, err }