diff --git a/client.go b/client.go index 5576c84bf..aeda2b16c 100644 --- a/client.go +++ b/client.go @@ -46,7 +46,8 @@ type client struct { session quicSession - logger utils.Logger + qlogger qlog.Tracer + logger utils.Logger } var _ packetHandler = &client{} @@ -55,6 +56,8 @@ var ( // make it possible to mock connection ID generation in the tests generateConnectionID = protocol.GenerateConnectionID generateConnectionIDForInitial = protocol.GenerateConnectionIDForInitial + // make it possible to the qlogger + newQlogger = qlog.NewTracer ) // DialAddr establishes a new QUIC connection to a server. @@ -178,13 +181,12 @@ func dialContext( } c.packetHandlers = packetHandlers - var qlogger qlog.Tracer if c.config.GetLogWriter != nil { if w := c.config.GetLogWriter(c.destConnID); w != nil { - qlogger = qlog.NewTracer(w, protocol.PerspectiveClient, c.destConnID) + c.qlogger = newQlogger(w, protocol.PerspectiveClient, c.destConnID) } } - if err := c.dial(ctx, qlogger); err != nil { + if err := c.dial(ctx); err != nil { return nil, err } return c.session, nil @@ -247,10 +249,10 @@ func newClient( return c, nil } -func (c *client) dial(ctx context.Context, qlogger qlog.Tracer) error { +func (c *client) dial(ctx context.Context) error { c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.conn.LocalAddr(), c.conn.RemoteAddr(), c.srcConnID, c.destConnID, c.version) - if qlogger != nil { - qlogger.StartedConnection(c.conn.LocalAddr(), c.conn.LocalAddr(), c.version, c.srcConnID, c.destConnID) + if c.qlogger != nil { + c.qlogger.StartedConnection(c.conn.LocalAddr(), c.conn.LocalAddr(), c.version, c.srcConnID, c.destConnID) } c.mutex.Lock() @@ -264,7 +266,7 @@ func (c *client) dial(ctx context.Context, qlogger qlog.Tracer) error { c.initialPacketNumber, c.initialVersion, c.use0RTT, - qlogger, + c.qlogger, c.logger, c.version, ) @@ -295,7 +297,7 @@ func (c *client) dial(ctx context.Context, qlogger qlog.Tracer) error { return ctx.Err() case err := <-errorChan: if err == errCloseForRecreating { - return c.dial(ctx, qlogger) + return c.dial(ctx) } return err case <-earlySessionChan: @@ -328,18 +330,27 @@ func (c *client) handleVersionNegotiationPacket(p *receivedPacket) { hdr, _, _, err := wire.ParsePacket(p.data, 0) if err != nil { + if c.qlogger != nil { + c.qlogger.DroppedPacket(qlog.PacketTypeVersionNegotiation, protocol.ByteCount(len(p.data)), qlog.PacketDropHeaderParseError) + } c.logger.Debugf("Error parsing Version Negotiation packet: %s", err) return } // ignore delayed / duplicated version negotiation packets if c.receivedVersionNegotiationPacket || c.versionNegotiated.Get() { + if c.qlogger != nil { + c.qlogger.DroppedPacket(qlog.PacketTypeVersionNegotiation, protocol.ByteCount(len(p.data)), qlog.PacketDropUnexpectedPacket) + } c.logger.Debugf("Received a delayed Version Negotiation packet.") return } for _, v := range hdr.SupportedVersions { if v == c.version { + if c.qlogger != nil { + c.qlogger.DroppedPacket(qlog.PacketTypeVersionNegotiation, protocol.ByteCount(len(p.data)), qlog.PacketDropUnexpectedVersion) + } // The Version Negotiation packet contains the version that we offered. // This might be a packet sent by an attacker (or by a terribly broken server implementation). return @@ -347,6 +358,9 @@ func (c *client) handleVersionNegotiationPacket(p *receivedPacket) { } c.logger.Infof("Received a Version Negotiation packet. Supported Versions: %s", hdr.SupportedVersions) + if c.qlogger != nil { + c.qlogger.ReceivedVersionNegotiationPacket(hdr) + } newVersion, ok := protocol.ChooseSupportedVersion(c.config.Versions, hdr.SupportedVersions) if !ok { //nolint:stylecheck diff --git a/client_test.go b/client_test.go index 3db990927..0a9dc35f7 100644 --- a/client_test.go +++ b/client_test.go @@ -1,10 +1,13 @@ package quic import ( + "bufio" "bytes" "context" "crypto/tls" "errors" + "io" + "io/ioutil" "net" "os" "time" @@ -12,6 +15,7 @@ import ( "github.com/lucas-clemente/quic-go/qlog" "github.com/golang/mock/gomock" + "github.com/lucas-clemente/quic-go/internal/mocks" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/wire" @@ -30,6 +34,8 @@ var _ = Describe("Client", func() { mockMultiplexer *MockMultiplexer origMultiplexer multiplexer tlsConf *tls.Config + qlogger *mocks.MockTracer + config *Config originalClientSessConstructor func( conn connection, @@ -45,6 +51,7 @@ var _ = Describe("Client", func() { logger utils.Logger, v protocol.VersionNumber, ) quicSession + originalQlogConstructor func(io.WriteCloser, protocol.Perspective, protocol.ConnectionID) qlog.Tracer ) // generate a packet sent by the server that accepts the QUIC version suggested by the client @@ -72,6 +79,21 @@ var _ = Describe("Client", func() { tlsConf = &tls.Config{NextProtos: []string{"proto1"}} connID = protocol.ConnectionID{0, 0, 0, 0, 0, 0, 0x13, 0x37} originalClientSessConstructor = newClientSession + originalQlogConstructor = newQlogger + qlogger = mocks.NewMockTracer(mockCtrl) + newQlogger = func(io.WriteCloser, protocol.Perspective, protocol.ConnectionID) qlog.Tracer { + return qlogger + } + config = &Config{ + GetLogWriter: func([]byte) io.WriteCloser { + // Since we're mocking the qlogger, it doesn't matter what we return here, + // as long as it's not nil. + return utils.NewBufferedWriteCloser( + bufio.NewWriter(&bytes.Buffer{}), + ioutil.NopCloser(&bytes.Buffer{}), + ) + }, + } Eventually(areSessionsRunning).Should(BeFalse()) // sess = NewMockQuicSession(mockCtrl) addr = &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337} @@ -83,6 +105,7 @@ var _ = Describe("Client", func() { destConnID: connID, version: protocol.SupportedVersions[0], conn: &conn{pconn: packetConn, currentAddr: addr}, + qlogger: qlogger, logger: utils.DefaultLogger, } getMultiplexer() // make the sync.Once execute @@ -95,6 +118,7 @@ var _ = Describe("Client", func() { AfterEach(func() { connMuxer = origMultiplexer newClientSession = originalClientSessConstructor + newQlogger = originalQlogConstructor }) AfterEach(func() { @@ -219,12 +243,13 @@ var _ = Describe("Client", func() { sess.EXPECT().run() return sess } + qlogger.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), protocol.VersionTLS, gomock.Any(), gomock.Any()) _, err := Dial( packetConn, addr, "test.com", tlsConf, - &Config{}, + config, ) Expect(err).ToNot(HaveOccurred()) Eventually(hostnameChan).Should(Receive(Equal("test.com"))) @@ -258,12 +283,13 @@ var _ = Describe("Client", func() { sess.EXPECT().HandshakeComplete().Return(ctx) return sess } + qlogger.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), protocol.VersionTLS, gomock.Any(), gomock.Any()) s, err := Dial( packetConn, addr, "localhost:1337", tlsConf, - &Config{}, + config, ) Expect(err).ToNot(HaveOccurred()) Expect(s).ToNot(BeNil()) @@ -302,12 +328,13 @@ var _ = Describe("Client", func() { go func() { defer GinkgoRecover() defer close(done) + qlogger.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), protocol.VersionTLS, gomock.Any(), gomock.Any()) s, err := DialEarly( packetConn, addr, "localhost:1337", tlsConf, - &Config{}, + config, ) Expect(err).ToNot(HaveOccurred()) Expect(s).ToNot(BeNil()) @@ -343,12 +370,13 @@ var _ = Describe("Client", func() { return sess } packetConn.dataToRead <- acceptClientVersionPacket(cl.srcConnID) + qlogger.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), protocol.VersionTLS, gomock.Any(), gomock.Any()) _, err := Dial( packetConn, addr, "localhost:1337", tlsConf, - &Config{}, + config, ) Expect(err).To(MatchError(testErr)) }) @@ -385,13 +413,14 @@ var _ = Describe("Client", func() { dialed := make(chan struct{}) go func() { defer GinkgoRecover() + qlogger.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), protocol.VersionTLS, gomock.Any(), gomock.Any()) _, err := DialContext( ctx, packetConn, addr, "localhost:1337", tlsConf, - &Config{}, + config, ) Expect(err).To(MatchError(context.Canceled)) close(dialed) @@ -513,8 +542,7 @@ var _ = Describe("Client", func() { }) It("uses 0-byte connection IDs when dialing an address", func() { - config := &Config{} - c := populateClientConfig(config, true) + c := populateClientConfig(&Config{}, true) Expect(c.ConnectionIDLength).To(BeZero()) }) @@ -606,12 +634,13 @@ var _ = Describe("Client", func() { sess.EXPECT().HandshakeComplete().Return(context.Background()) return sess } + qlogger.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), protocol.VersionTLS, gomock.Any(), gomock.Any()) _, err := Dial( packetConn, addr, "localhost:1337", tlsConf, - &Config{}, + config, ) Expect(err).To(MatchError(testErr)) }) @@ -620,7 +649,7 @@ var _ = Describe("Client", func() { sess := NewMockQuicSession(mockCtrl) sess.EXPECT().handlePacket(gomock.Any()) cl.session = sess - cl.config = &Config{} + cl.config = config buf := &bytes.Buffer{} Expect((&wire.ExtendedHeader{ Header: wire.Header{ @@ -647,7 +676,11 @@ var _ = Describe("Client", func() { }) cl.session = sess cl.config = &Config{Versions: protocol.SupportedVersions} - cl.handlePacket(composeVersionNegotiationPacket(connID, []protocol.VersionNumber{1337})) + p := composeVersionNegotiationPacket(connID, []protocol.VersionNumber{1337}) + hdr, _, _, err := wire.ParsePacket(p.data, 0) + Expect(err).ToNot(HaveOccurred()) + qlogger.EXPECT().ReceivedVersionNegotiationPacket(hdr) + cl.handlePacket(p) Eventually(done).Should(BeClosed()) }) @@ -664,6 +697,7 @@ var _ = Describe("Client", func() { v := protocol.VersionNumber(1234) Expect(v).ToNot(Equal(cl.version)) cl.config = &Config{Versions: protocol.SupportedVersions} + qlogger.EXPECT().ReceivedVersionNegotiationPacket(gomock.Any()) cl.handlePacket(composeVersionNegotiationPacket(connID, []protocol.VersionNumber{v})) Eventually(done).Should(BeClosed()) }) @@ -680,15 +714,63 @@ var _ = Describe("Client", func() { cl.session = sess versions := []protocol.VersionNumber{1234, 4321} cl.config = &Config{Versions: versions} + qlogger.EXPECT().ReceivedVersionNegotiationPacket(gomock.Any()) cl.handlePacket(composeVersionNegotiationPacket(connID, versions)) Eventually(destroyed).Should(BeClosed()) Expect(cl.version).To(Equal(protocol.VersionNumber(1234))) }) - It("drops version negotiation packets that contain the offered version", func() { - cl.config = &Config{} + It("drops unparseable version negotiation packets", func() { + cl.config = config ver := cl.version - cl.handlePacket(composeVersionNegotiationPacket(connID, []protocol.VersionNumber{ver})) + p := composeVersionNegotiationPacket(connID, []protocol.VersionNumber{ver}) + p.data = p.data[:len(p.data)-1] + done := make(chan struct{}) + qlogger.EXPECT().DroppedPacket(qlog.PacketTypeVersionNegotiation, protocol.ByteCount(len(p.data)), qlog.PacketDropHeaderParseError).Do(func(qlog.PacketType, protocol.ByteCount, qlog.PacketDropReason) { + close(done) + }) + cl.handlePacket(p) + Eventually(done).Should(BeClosed()) + Expect(cl.version).To(Equal(ver)) + }) + + It("drops version negotiation packets if any other packet was received before", func() { + sess := NewMockQuicSession(mockCtrl) + sess.EXPECT().handlePacket(gomock.Any()) + cl.session = sess + cl.config = config + buf := &bytes.Buffer{} + Expect((&wire.ExtendedHeader{ + Header: wire.Header{ + DestConnectionID: connID, + SrcConnectionID: connID, + Version: cl.version, + }, + PacketNumberLen: protocol.PacketNumberLen3, + }).Write(buf, protocol.VersionTLS)).To(Succeed()) + cl.handlePacket(&receivedPacket{data: buf.Bytes()}) + + ver := cl.version + p := composeVersionNegotiationPacket(connID, []protocol.VersionNumber{1234}) + done := make(chan struct{}) + qlogger.EXPECT().DroppedPacket(qlog.PacketTypeVersionNegotiation, protocol.ByteCount(len(p.data)), qlog.PacketDropUnexpectedPacket).Do(func(qlog.PacketType, protocol.ByteCount, qlog.PacketDropReason) { + close(done) + }) + cl.handlePacket(p) + Eventually(done).Should(BeClosed()) + Expect(cl.version).To(Equal(ver)) + }) + + It("drops version negotiation packets that contain the offered version", func() { + cl.config = config + ver := cl.version + p := composeVersionNegotiationPacket(connID, []protocol.VersionNumber{ver}) + done := make(chan struct{}) + qlogger.EXPECT().DroppedPacket(qlog.PacketTypeVersionNegotiation, protocol.ByteCount(len(p.data)), qlog.PacketDropUnexpectedVersion).Do(func(qlog.PacketType, protocol.ByteCount, qlog.PacketDropReason) { + close(done) + }) + cl.handlePacket(p) + Eventually(done).Should(BeClosed()) Expect(cl.version).To(Equal(ver)) }) }) @@ -698,34 +780,4 @@ var _ = Describe("Client", func() { Expect(cl.version).ToNot(BeZero()) Expect(cl.GetVersion()).To(Equal(cl.version)) }) - - Context("handling potentially injected packets", func() { - // NOTE: We hope these tests as written will fail once mitigations for injection adversaries are put in place. - - // Illustrates that adversary who injects any packet quickly can - // cause a real version negotiation packet to be ignored. - It("version negotiation packets ignored if any other packet is received", func() { - // Copy of existing test "recognizes that a non Version Negotiation packet means that the server accepted the suggested version" - sess := NewMockQuicSession(mockCtrl) - sess.EXPECT().handlePacket(gomock.Any()) - cl.session = sess - cl.config = &Config{} - buf := &bytes.Buffer{} - Expect((&wire.ExtendedHeader{ - Header: wire.Header{ - DestConnectionID: connID, - SrcConnectionID: connID, - Version: cl.version, - }, - PacketNumberLen: protocol.PacketNumberLen3, - }).Write(buf, protocol.VersionTLS)).To(Succeed()) - cl.handlePacket(&receivedPacket{data: buf.Bytes()}) - - // Version negotiation is now ignored - cl.config = &Config{} - ver := cl.version - cl.handlePacket(composeVersionNegotiationPacket(connID, []protocol.VersionNumber{1234})) - Expect(cl.version).To(Equal(ver)) - }) - }) }) diff --git a/internal/mocks/qlog.go b/internal/mocks/qlog.go index f8f845a8a..ba4a7474b 100644 --- a/internal/mocks/qlog.go +++ b/internal/mocks/qlog.go @@ -173,6 +173,18 @@ func (mr *MockTracerMockRecorder) ReceivedTransportParameters(arg0 interface{}) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedTransportParameters", reflect.TypeOf((*MockTracer)(nil).ReceivedTransportParameters), arg0) } +// ReceivedVersionNegotiationPacket mocks base method +func (m *MockTracer) ReceivedVersionNegotiationPacket(arg0 *wire.Header) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ReceivedVersionNegotiationPacket", arg0) +} + +// ReceivedVersionNegotiationPacket indicates an expected call of ReceivedVersionNegotiationPacket +func (mr *MockTracerMockRecorder) ReceivedVersionNegotiationPacket(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedVersionNegotiationPacket", reflect.TypeOf((*MockTracer)(nil).ReceivedVersionNegotiationPacket), arg0) +} + // SentPacket mocks base method func (m *MockTracer) SentPacket(arg0 *wire.ExtendedHeader, arg1 protocol.ByteCount, arg2 *wire.AckFrame, arg3 []wire.Frame) { m.ctrl.T.Helper() diff --git a/qlog/event.go b/qlog/event.go index 17eb7d24e..b77076789 100644 --- a/qlog/event.go +++ b/qlog/event.go @@ -47,6 +47,15 @@ func (e event) MarshalJSONArray(enc *gojay.Encoder) { enc.Object(e.eventDetails) } +type versions []versionNumber + +func (v versions) IsNil() bool { return false } +func (v versions) MarshalJSONArray(enc *gojay.Encoder) { + for _, e := range v { + enc.AddString(e.String()) + } +} + type eventConnectionStarted struct { SrcAddr *net.UDPAddr DestAddr *net.UDPAddr @@ -140,6 +149,21 @@ func (e eventRetryReceived) MarshalJSONObject(enc *gojay.Encoder) { enc.ObjectKey("header", e.Header) } +type eventVersionNegotiationReceived struct { + Header packetHeader + SupportedVersions []versionNumber +} + +func (e eventVersionNegotiationReceived) Category() category { return categoryTransport } +func (e eventVersionNegotiationReceived) Name() string { return "packet_received" } +func (e eventVersionNegotiationReceived) IsNil() bool { return false } + +func (e eventVersionNegotiationReceived) MarshalJSONObject(enc *gojay.Encoder) { + enc.StringKey("packet_type", PacketTypeVersionNegotiation.String()) + enc.ObjectKey("header", e.Header) + enc.ArrayKey("supported_versions", versions(e.SupportedVersions)) +} + type eventStatelessResetReceived struct { Token *[16]byte } diff --git a/qlog/qlog.go b/qlog/qlog.go index 2f1f3933d..9276b070c 100644 --- a/qlog/qlog.go +++ b/qlog/qlog.go @@ -24,6 +24,7 @@ type Tracer interface { SentTransportParameters(*wire.TransportParameters) ReceivedTransportParameters(*wire.TransportParameters) SentPacket(hdr *wire.ExtendedHeader, packetSize protocol.ByteCount, ack *wire.AckFrame, frames []wire.Frame) + ReceivedVersionNegotiationPacket(*wire.Header) ReceivedRetry(*wire.Header) ReceivedPacket(hdr *wire.ExtendedHeader, packetSize protocol.ByteCount, frames []wire.Frame) ReceivedStatelessReset(token *[16]byte) @@ -230,6 +231,19 @@ func (t *tracer) ReceivedRetry(hdr *wire.Header) { t.mutex.Unlock() } +func (t *tracer) ReceivedVersionNegotiationPacket(hdr *wire.Header) { + versions := make([]versionNumber, len(hdr.SupportedVersions)) + for i, v := range hdr.SupportedVersions { + versions[i] = versionNumber(v) + } + t.mutex.Lock() + t.recordEvent(time.Now(), &eventVersionNegotiationReceived{ + Header: *transformHeader(hdr), + SupportedVersions: versions, + }) + t.mutex.Unlock() +} + func (t *tracer) ReceivedStatelessReset(token *[16]byte) { t.mutex.Lock() t.recordEvent(time.Now(), &eventStatelessResetReceived{ diff --git a/qlog/qlog_test.go b/qlog/qlog_test.go index 2a7a49320..ccf8083aa 100644 --- a/qlog/qlog_test.go +++ b/qlog/qlog_test.go @@ -321,6 +321,33 @@ var _ = Describe("Tracer", func() { Expect(ev).ToNot(HaveKey("frames")) }) + It("records a received Version Negotiation packet", func() { + tracer.ReceivedVersionNegotiationPacket( + &wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeRetry, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + SrcConnectionID: protocol.ConnectionID{4, 3, 2, 1}, + SupportedVersions: []protocol.VersionNumber{0xdeadbeef, 0xdecafbad}, + }, + ) + entry := exportAndParseSingle() + Expect(entry.Time).To(BeTemporally("~", time.Now(), 10*time.Millisecond)) + Expect(entry.Category).To(Equal("transport")) + Expect(entry.Name).To(Equal("packet_received")) + ev := entry.Event + Expect(ev).To(HaveKeyWithValue("packet_type", "version_negotiation")) + Expect(ev).To(HaveKey("header")) + Expect(ev).ToNot(HaveKey("frames")) + Expect(ev).To(HaveKey("supported_versions")) + Expect(ev["supported_versions"].([]interface{})).To(Equal([]interface{}{"deadbeef", "decafbad"})) + header := ev["header"] + Expect(header).ToNot(HaveKey("packet_number")) + Expect(header).ToNot(HaveKey("version")) + Expect(header).To(HaveKey("dcid")) + Expect(header).To(HaveKey("scid")) + }) + It("records a received Retry packet", func() { tracer.ReceivedStatelessReset(&[16]byte{0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff}) entry := exportAndParseSingle()