diff --git a/integrationtests/self/mitm_test.go b/integrationtests/self/mitm_test.go index c5672f01..a5f07f77 100644 --- a/integrationtests/self/mitm_test.go +++ b/integrationtests/self/mitm_test.go @@ -26,9 +26,10 @@ var _ = Describe("MITM test", func() { const connIDLen = 6 // explicitly set the connection ID length, so the proxy can parse it var ( - serverUDPConn, clientUDPConn net.PacketConn - serverConn quic.Connection - serverConfig *quic.Config + clientUDPConn net.PacketConn + serverTransport, clientTransport *quic.Transport + serverConn quic.Connection + serverConfig *quic.Config ) startServerAndProxy := func(delayCb quicproxy.DelayCallback, dropCb quicproxy.DropCallback) (proxyPort int, closeFn func()) { @@ -36,13 +37,11 @@ var _ = Describe("MITM test", func() { Expect(err).ToNot(HaveOccurred()) c, err := net.ListenUDP("udp", addr) Expect(err).ToNot(HaveOccurred()) - serverUDPConn, err = quic.OptimizeConn(c) - Expect(err).ToNot(HaveOccurred()) - tr := &quic.Transport{ - Conn: serverUDPConn, + serverTransport = &quic.Transport{ + Conn: c, ConnectionIDLength: connIDLen, } - ln, err := tr.Listen(getTLSConfig(), serverConfig) + ln, err := serverTransport.Listen(getTLSConfig(), serverConfig) Expect(err).ToNot(HaveOccurred()) done := make(chan struct{}) go func() { @@ -69,7 +68,7 @@ var _ = Describe("MITM test", func() { return proxy.LocalPort(), func() { proxy.Close() ln.Close() - serverUDPConn.Close() + serverTransport.Close() <-done } } @@ -78,10 +77,12 @@ var _ = Describe("MITM test", func() { serverConfig = getQuicConfig(nil) addr, err := net.ResolveUDPAddr("udp", "localhost:0") Expect(err).ToNot(HaveOccurred()) - c, err := net.ListenUDP("udp", addr) - Expect(err).ToNot(HaveOccurred()) - clientUDPConn, err = quic.OptimizeConn(c) + clientUDPConn, err = net.ListenUDP("udp", addr) Expect(err).ToNot(HaveOccurred()) + clientTransport = &quic.Transport{ + Conn: clientUDPConn, + ConnectionIDLength: connIDLen, + } }) Context("unsuccessful attacks", func() { @@ -90,12 +91,13 @@ var _ = Describe("MITM test", func() { // Test shutdown is tricky due to the proxy. Just wait for a bit. time.Sleep(50 * time.Millisecond) Expect(clientUDPConn.Close()).To(Succeed()) + Expect(clientTransport.Close()).To(Succeed()) }) Context("injecting invalid packets", func() { const rtt = 20 * time.Millisecond - sendRandomPacketsOfSameType := func(conn net.PacketConn, remoteAddr net.Addr, raw []byte) { + sendRandomPacketsOfSameType := func(conn *quic.Transport, remoteAddr net.Addr, raw []byte) { defer GinkgoRecover() const numPackets = 10 ticker := time.NewTicker(rtt / numPackets) @@ -155,11 +157,7 @@ var _ = Describe("MITM test", func() { defer closeFn() raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort)) Expect(err).ToNot(HaveOccurred()) - tr := &quic.Transport{ - Conn: clientUDPConn, - ConnectionIDLength: connIDLen, - } - conn, err := tr.Dial( + conn, err := clientTransport.Dial( context.Background(), raddr, getTLSClientConfig(), @@ -178,7 +176,7 @@ var _ = Describe("MITM test", func() { delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration { if dir == quicproxy.DirectionIncoming { defer GinkgoRecover() - go sendRandomPacketsOfSameType(clientUDPConn, serverUDPConn.LocalAddr(), raw) + go sendRandomPacketsOfSameType(clientTransport, serverTransport.Conn.LocalAddr(), raw) } return rtt / 2 } @@ -189,7 +187,7 @@ var _ = Describe("MITM test", func() { delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration { if dir == quicproxy.DirectionOutgoing { defer GinkgoRecover() - go sendRandomPacketsOfSameType(serverUDPConn, clientUDPConn.LocalAddr(), raw) + go sendRandomPacketsOfSameType(serverTransport, clientTransport.Conn.LocalAddr(), raw) } return rtt / 2 } @@ -202,11 +200,7 @@ var _ = Describe("MITM test", func() { defer closeFn() raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort)) Expect(err).ToNot(HaveOccurred()) - tr := &quic.Transport{ - Conn: clientUDPConn, - ConnectionIDLength: connIDLen, - } - conn, err := tr.Dial( + conn, err := clientTransport.Dial( context.Background(), raddr, getTLSClientConfig(), @@ -226,7 +220,7 @@ var _ = Describe("MITM test", func() { dropCb := func(dir quicproxy.Direction, raw []byte) bool { defer GinkgoRecover() if dir == quicproxy.DirectionIncoming { - _, err := clientUDPConn.WriteTo(raw, serverUDPConn.LocalAddr()) + _, err := clientTransport.WriteTo(raw, serverTransport.Conn.LocalAddr()) Expect(err).ToNot(HaveOccurred()) } return false @@ -238,7 +232,7 @@ var _ = Describe("MITM test", func() { dropCb := func(dir quicproxy.Direction, raw []byte) bool { defer GinkgoRecover() if dir == quicproxy.DirectionOutgoing { - _, err := serverUDPConn.WriteTo(raw, clientUDPConn.LocalAddr()) + _, err := serverTransport.WriteTo(raw, clientTransport.Conn.LocalAddr()) Expect(err).ToNot(HaveOccurred()) } return false @@ -276,7 +270,7 @@ var _ = Describe("MITM test", func() { if rand.Intn(interval) == 0 { pos := rand.Intn(len(raw)) raw[pos] = byte(rand.Intn(256)) - _, err := clientUDPConn.WriteTo(raw, serverUDPConn.LocalAddr()) + _, err := clientTransport.WriteTo(raw, serverTransport.Conn.LocalAddr()) Expect(err).ToNot(HaveOccurred()) atomic.AddInt32(&numCorrupted, 1) return true @@ -296,7 +290,7 @@ var _ = Describe("MITM test", func() { if rand.Intn(interval) == 0 { pos := rand.Intn(len(raw)) raw[pos] = byte(rand.Intn(256)) - _, err := serverUDPConn.WriteTo(raw, clientUDPConn.LocalAddr()) + _, err := serverTransport.WriteTo(raw, clientTransport.Conn.LocalAddr()) Expect(err).ToNot(HaveOccurred()) atomic.AddInt32(&numCorrupted, 1) return true @@ -320,17 +314,13 @@ var _ = Describe("MITM test", func() { proxyPort, serverCloseFn := startServerAndProxy(delayCb, nil) raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort)) Expect(err).ToNot(HaveOccurred()) - tr := &quic.Transport{ - Conn: clientUDPConn, - ConnectionIDLength: connIDLen, - } - _, err = tr.Dial( + _, err = clientTransport.Dial( context.Background(), raddr, getTLSClientConfig(), getQuicConfig(&quic.Config{HandshakeIdleTimeout: 2 * time.Second}), ) - return func() { tr.Close(); serverCloseFn() }, err + return func() { clientTransport.Close(); serverCloseFn() }, err } // fails immediately because client connection closes when it can't find compatible version @@ -356,7 +346,7 @@ var _ = Describe("MITM test", func() { ) // Send the packet - _, err = serverUDPConn.WriteTo(packet, clientUDPConn.LocalAddr()) + _, err = serverTransport.WriteTo(packet, clientTransport.Conn.LocalAddr()) Expect(err).ToNot(HaveOccurred()) close(done) } @@ -393,7 +383,7 @@ var _ = Describe("MITM test", func() { fakeSrcConnID := protocol.ParseConnectionID([]byte{0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12}) retryPacket := testutils.ComposeRetryPacket(fakeSrcConnID, hdr.SrcConnectionID, hdr.DestConnectionID, []byte("token"), hdr.Version) - _, err = serverUDPConn.WriteTo(retryPacket, clientUDPConn.LocalAddr()) + _, err = serverTransport.WriteTo(retryPacket, clientTransport.Conn.LocalAddr()) Expect(err).ToNot(HaveOccurred()) } return rtt / 2 @@ -423,7 +413,7 @@ var _ = Describe("MITM test", func() { defer close(done) injected = true initialPacket := testutils.ComposeInitialPacket(hdr.DestConnectionID, hdr.SrcConnectionID, hdr.Version, hdr.DestConnectionID, nil) - _, err = serverUDPConn.WriteTo(initialPacket, clientUDPConn.LocalAddr()) + _, err = serverTransport.WriteTo(initialPacket, clientTransport.Conn.LocalAddr()) Expect(err).ToNot(HaveOccurred()) } return rtt @@ -453,7 +443,7 @@ var _ = Describe("MITM test", func() { // Fake Initial with ACK for packet 2 (unsent) ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 2}}} initialPacket := testutils.ComposeInitialPacket(hdr.DestConnectionID, hdr.SrcConnectionID, hdr.Version, hdr.DestConnectionID, []wire.Frame{ack}) - _, err = serverUDPConn.WriteTo(initialPacket, clientUDPConn.LocalAddr()) + _, err = serverTransport.WriteTo(initialPacket, clientTransport.Conn.LocalAddr()) Expect(err).ToNot(HaveOccurred()) } return rtt diff --git a/sys_conn.go b/sys_conn.go index c68477c3..414472e7 100644 --- a/sys_conn.go +++ b/sys_conn.go @@ -27,27 +27,7 @@ type OOBCapablePacketConn interface { var _ OOBCapablePacketConn = &net.UDPConn{} -// OptimizeConn takes a net.PacketConn and attempts to enable various optimizations that will improve QUIC performance: -// 1. It enables the Don't Fragment (DF) bit on the IP header. -// This is required to run DPLPMTUD (Path MTU Discovery, RFC 8899). -// 2. It enables reading of the ECN bits from the IP header. -// This allows the remote node to speed up its loss detection and recovery. -// 3. It uses batched syscalls (recvmmsg) to more efficiently receive packets from the socket. -// 4. It uses Generic Segmentation Offload (GSO) to efficiently send batches of packets (on Linux). -// -// In order for this to work, the connection needs to implement the OOBCapablePacketConn interface (as a *net.UDPConn does). -// -// It's only necessary to call this function explicitly if the application calls WriteTo -// after passing the connection to the Transport. -func OptimizeConn(c net.PacketConn) (net.PacketConn, error) { - return wrapConn(c) -} - -func wrapConn(pc net.PacketConn) (interface { - net.PacketConn - rawConn -}, error, -) { +func wrapConn(pc net.PacketConn) (rawConn, error) { if err := setReceiveBuffer(pc); err != nil { if !strings.Contains(err.Error(), "use of closed network connection") { setBufferWarningOnce.Do(func() { diff --git a/sys_conn_oob.go b/sys_conn_oob.go index 45774f2c..84d5e7e6 100644 --- a/sys_conn_oob.go +++ b/sys_conn_oob.go @@ -230,13 +230,6 @@ func (c *oobConn) ReadPacket() (receivedPacket, error) { return p, nil } -// WriteTo (re)implements the net.PacketConn method. -// This is needed for users who call OptimizeConn to be able to send (non-QUIC) packets on the underlying connection. -// With GSO enabled, this would otherwise not be needed, as the kernel requires the UDP_SEGMENT message to be set. -func (c *oobConn) WriteTo(p []byte, addr net.Addr) (int, error) { - return c.WritePacket(p, uint16(len(p)), addr, nil) -} - // WritePacket writes a new packet. // If the connection supports GSO (and we activated GSO support before), // it appends the UDP_SEGMENT size message to oob. diff --git a/transport.go b/transport.go index ace443c5..3c985965 100644 --- a/transport.go +++ b/transport.go @@ -26,9 +26,16 @@ type Transport struct { // A single net.PacketConn can only be handled by one Transport. // Bad things will happen if passed to multiple Transports. // - // If not done by the user, the connection is passed through OptimizeConn to enable a number of optimizations. - // After passing the connection to the Transport, it's invalid to call ReadFrom on the connection. - // Calling WriteTo is only valid on the connection returned by OptimizeConn. + // A number of optimizations will be enabled if the connections implements the OOBCapablePacketConn interface, + // as a *net.UDPConn does. + // 1. It enables the Don't Fragment (DF) bit on the IP header. + // This is required to run DPLPMTUD (Path MTU Discovery, RFC 8899). + // 2. It enables reading of the ECN bits from the IP header. + // This allows the remote node to speed up its loss detection and recovery. + // 3. It uses batched syscalls (recvmmsg) to more efficiently receive packets from the socket. + // 4. It uses Generic Segmentation Offload (GSO) to efficiently send batches of packets (on Linux). + // + // After passing the connection to the Transport, it's invalid to call ReadFrom or WriteTo on the connection. Conn net.PacketConn // The length of the connection ID in bytes. @@ -99,7 +106,7 @@ func (t *Transport) Listen(tlsConf *tls.Config, conf *Config) (*Listener, error) return nil, errListenerAlreadySet } conf = populateServerConfig(conf) - if err := t.init(true); err != nil { + if err := t.init(false); err != nil { return nil, err } s, err := newServer(t.conn, t.handlerMap, t.connIDGenerator, tlsConf, conf, t.Tracer, t.closeServer, false) @@ -128,7 +135,7 @@ func (t *Transport) ListenEarly(tlsConf *tls.Config, conf *Config) (*EarlyListen return nil, errListenerAlreadySet } conf = populateServerConfig(conf) - if err := t.init(true); err != nil { + if err := t.init(false); err != nil { return nil, err } s, err := newServer(t.conn, t.handlerMap, t.connIDGenerator, tlsConf, conf, t.Tracer, t.closeServer, true) @@ -145,7 +152,7 @@ func (t *Transport) Dial(ctx context.Context, addr net.Addr, tlsConf *tls.Config return nil, err } conf = populateConfig(conf) - if err := t.init(false); err != nil { + if err := t.init(t.isSingleUse); err != nil { return nil, err } var onClose func() @@ -163,7 +170,7 @@ func (t *Transport) DialEarly(ctx context.Context, addr net.Addr, tlsConf *tls.C return nil, err } conf = populateConfig(conf) - if err := t.init(false); err != nil { + if err := t.init(t.isSingleUse); err != nil { return nil, err } var onClose func() @@ -175,7 +182,7 @@ func (t *Transport) DialEarly(ctx context.Context, addr net.Addr, tlsConf *tls.C return dial(ctx, newSendConn(t.conn, addr), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, true) } -func (t *Transport) init(isServer bool) error { +func (t *Transport) init(allowZeroLengthConnIDs bool) error { t.initOnce.Do(func() { var conn rawConn if c, ok := t.Conn.(rawConn); ok { @@ -203,7 +210,7 @@ func (t *Transport) init(isServer bool) error { t.connIDLen = t.ConnectionIDGenerator.ConnectionIDLen() } else { connIDLen := t.ConnectionIDLength - if t.ConnectionIDLength == 0 && (!t.isSingleUse || isServer) { + if t.ConnectionIDLength == 0 && !allowZeroLengthConnIDs { connIDLen = protocol.DefaultConnectionIDLength } t.connIDLen = connIDLen @@ -217,6 +224,14 @@ func (t *Transport) init(isServer bool) error { return t.initErr } +// WriteTo sends a packet on the underlying connection. +func (t *Transport) WriteTo(b []byte, addr net.Addr) (int, error) { + if err := t.init(false); err != nil { + return 0, err + } + return t.conn.WritePacket(b, uint16(len(b)), addr, nil) +} + func (t *Transport) enqueueClosePacket(p closePacket) { select { case t.closeQueue <- p: