From 090e505aa9072a4125765e59b0399fc7850aa3d9 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Thu, 31 Aug 2023 14:49:27 +0700 Subject: [PATCH] move GSO control message handling to the oobConn (#4056) * move GSO control message handling to the oobConn * disable OOB test on Windows * improve GSO tests * update ooConn.WritePacket comment --- connection.go | 12 ++++----- connection_test.go | 36 +++++++++++++------------- mock_raw_conn_test.go | 8 +++--- mock_send_conn_test.go | 3 +-- mock_sender_test.go | 3 +-- packet_handler_map.go | 5 ++-- send_conn.go | 40 +++++++++-------------------- send_conn_test.go | 58 ++++++++++++++++++++---------------------- send_queue.go | 15 +++++------ send_queue_test.go | 8 +++--- server.go | 6 ++--- sys_conn.go | 5 +++- sys_conn_oob.go | 10 ++++++-- sys_conn_oob_test.go | 34 +++++++++++++++++++++++++ transport.go | 6 ++--- 15 files changed, 134 insertions(+), 115 deletions(-) diff --git a/connection.go b/connection.go index 877f2d03..eb80cd13 100644 --- a/connection.go +++ b/connection.go @@ -1832,7 +1832,7 @@ func (s *connection) sendPackets(now time.Time) error { } s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, buf.Len(), false) s.registerPackedShortHeaderPacket(p, now) - s.sendQueue.Send(buf, buf.Len()) + s.sendQueue.Send(buf, 0) // This is kind of a hack. We need to trigger sending again somehow. s.pacingDeadline = deadlineSendImmediately return nil @@ -1881,7 +1881,7 @@ func (s *connection) sendPacketsWithoutGSO(now time.Time) error { return err } - s.sendQueue.Send(buf, buf.Len()) + s.sendQueue.Send(buf, 0) if s.sendQueue.WouldBlock() { return nil @@ -1938,7 +1938,7 @@ func (s *connection) sendPacketsWithGSO(now time.Time) error { continue } - s.sendQueue.Send(buf, maxSize) + s.sendQueue.Send(buf, uint16(maxSize)) if dontSendMore { return nil @@ -1986,7 +1986,7 @@ func (s *connection) maybeSendAckOnlyPacket(now time.Time) error { } s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, buf.Len(), false) s.registerPackedShortHeaderPacket(p, now) - s.sendQueue.Send(buf, buf.Len()) + s.sendQueue.Send(buf, 0) return nil } @@ -2078,7 +2078,7 @@ func (s *connection) sendPackedCoalescedPacket(packet *coalescedPacket, now time s.sentPacketHandler.SentPacket(now, p.PacketNumber, largestAcked, p.StreamFrames, p.Frames, protocol.Encryption1RTT, p.Length, p.IsPathMTUProbePacket) } s.connIDManager.SentPacket() - s.sendQueue.Send(packet.buffer, packet.buffer.Len()) + s.sendQueue.Send(packet.buffer, 0) return nil } @@ -2101,7 +2101,7 @@ func (s *connection) sendConnectionClose(e error) ([]byte, error) { return nil, err } s.logCoalescedPacket(packet) - return packet.buffer.Data, s.conn.Write(packet.buffer.Data, packet.buffer.Len()) + return packet.buffer.Data, s.conn.Write(packet.buffer.Data, 0) } func (s *connection) logLongHeaderPacket(p *longHeaderPacket) { diff --git a/connection_test.go b/connection_test.go index 05a0f20d..eb512513 100644 --- a/connection_test.go +++ b/connection_test.go @@ -1243,7 +1243,7 @@ var _ = Describe("Connection", func() { packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack).AnyTimes() sent := make(chan struct{}) sender.EXPECT().WouldBlock().AnyTimes() - sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(*packetBuffer, protocol.ByteCount) { close(sent) }) + sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(*packetBuffer, uint16) { close(sent) }) tracer.EXPECT().SentShortHeaderPacket(&logging.ShortHeader{ DestConnectionID: p.DestConnID, PacketNumber: p.PacketNumber, @@ -1291,7 +1291,7 @@ var _ = Describe("Connection", func() { conn.connFlowController = fc runConn() sent := make(chan struct{}) - sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(*packetBuffer, protocol.ByteCount) { close(sent) }) + sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(*packetBuffer, uint16) { close(sent) }) tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), gomock.Any(), nil, []logging.Frame{}) conn.scheduleSending() Eventually(sent).Should(BeClosed()) @@ -1347,7 +1347,7 @@ var _ = Describe("Connection", func() { conn.sentPacketHandler = sph runConn() sent := make(chan struct{}) - sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(*packetBuffer, protocol.ByteCount) { close(sent) }) + sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(*packetBuffer, uint16) { close(sent) }) if enc == protocol.Encryption1RTT { tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), p.shortHdrPacket.Length, gomock.Any(), gomock.Any()) } else { @@ -1372,7 +1372,7 @@ var _ = Describe("Connection", func() { conn.sentPacketHandler = sph runConn() sent := make(chan struct{}) - sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(*packetBuffer, protocol.ByteCount) { close(sent) }) + sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(*packetBuffer, uint16) { close(sent) }) if enc == protocol.Encryption1RTT { tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), p.shortHdrPacket.Length, gomock.Any(), gomock.Any()) } else { @@ -1428,10 +1428,10 @@ var _ = Describe("Connection", func() { expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 10}, []byte("packet10")) expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 11}, []byte("packet11")) sender.EXPECT().WouldBlock().AnyTimes() - sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(b *packetBuffer, _ protocol.ByteCount) { + sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(b *packetBuffer, _ uint16) { Expect(b.Data).To(Equal([]byte("packet10"))) }) - sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(b *packetBuffer, _ protocol.ByteCount) { + sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(b *packetBuffer, _ uint16) { Expect(b.Data).To(Equal([]byte("packet11"))) }) go func() { @@ -1456,7 +1456,7 @@ var _ = Describe("Connection", func() { expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 11}, payload2) packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any()).Return(shortHeaderPacket{}, errNothingToPack) sender.EXPECT().WouldBlock().AnyTimes() - sender.EXPECT().Send(gomock.Any(), conn.mtuDiscoverer.CurrentSize()).Do(func(b *packetBuffer, l protocol.ByteCount) { + sender.EXPECT().Send(gomock.Any(), uint16(conn.mtuDiscoverer.CurrentSize())).Do(func(b *packetBuffer, l uint16) { Expect(b.Data).To(Equal(append(payload1, payload2...))) }) go func() { @@ -1481,7 +1481,7 @@ var _ = Describe("Connection", func() { expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 10}, payload1) expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 11}, payload2) sender.EXPECT().WouldBlock().AnyTimes() - sender.EXPECT().Send(gomock.Any(), conn.mtuDiscoverer.CurrentSize()).Do(func(b *packetBuffer, l protocol.ByteCount) { + sender.EXPECT().Send(gomock.Any(), uint16(conn.mtuDiscoverer.CurrentSize())).Do(func(b *packetBuffer, l uint16) { Expect(b.Data).To(Equal(append(payload1, payload2...))) }) go func() { @@ -1564,7 +1564,7 @@ var _ = Describe("Connection", func() { ) written := make(chan struct{}, 2) sender.EXPECT().WouldBlock().AnyTimes() - sender.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, protocol.ByteCount) { written <- struct{}{} }).Times(2) + sender.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16) { written <- struct{}{} }).Times(2) go func() { defer GinkgoRecover() cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) @@ -1587,7 +1587,7 @@ var _ = Describe("Connection", func() { } written := make(chan struct{}, 3) sender.EXPECT().WouldBlock().AnyTimes() - sender.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, protocol.ByteCount) { written <- struct{}{} }).Times(3) + sender.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16) { written <- struct{}{} }).Times(3) go func() { defer GinkgoRecover() cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) @@ -1622,7 +1622,7 @@ var _ = Describe("Connection", func() { sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 1000}, []byte("packet1000")) packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack) - sender.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, protocol.ByteCount) { close(written) }) + sender.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16) { close(written) }) available <- struct{}{} Eventually(written).Should(BeClosed()) }) @@ -1646,7 +1646,7 @@ var _ = Describe("Connection", func() { sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes() expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 10}, []byte("packet10")) packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack) - sender.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, protocol.ByteCount) { close(written) }) + sender.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16) { close(written) }) conn.scheduleSending() time.Sleep(scaleDuration(50 * time.Millisecond)) @@ -1661,7 +1661,7 @@ var _ = Describe("Connection", func() { written := make(chan struct{}, 1) sender.EXPECT().WouldBlock() sender.EXPECT().WouldBlock().Return(true).Times(2) - sender.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, protocol.ByteCount) { written <- struct{}{} }) + sender.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16) { written <- struct{}{} }) go func() { defer GinkgoRecover() cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) @@ -1680,7 +1680,7 @@ var _ = Describe("Connection", func() { sender.EXPECT().WouldBlock().AnyTimes() expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 1001}, []byte("packet1001")) packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack) - sender.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, protocol.ByteCount) { written <- struct{}{} }) + sender.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16) { written <- struct{}{} }) available <- struct{}{} Eventually(written).Should(Receive()) @@ -1713,7 +1713,7 @@ var _ = Describe("Connection", func() { sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendNone) written := make(chan struct{}, 1) sender.EXPECT().WouldBlock().AnyTimes() - sender.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, protocol.ByteCount) { written <- struct{}{} }) + sender.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, uint16) { written <- struct{}{} }) mtuDiscoverer.EXPECT().ShouldSendProbe(gomock.Any()).Return(true) ping := ackhandler.Frame{Frame: &wire.PingFrame{}} mtuDiscoverer.EXPECT().GetPing().Return(ping, protocol.ByteCount(1234)) @@ -1776,7 +1776,7 @@ var _ = Describe("Connection", func() { time.Sleep(50 * time.Millisecond) // only EXPECT calls after scheduleSending is called written := make(chan struct{}) - sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(*packetBuffer, protocol.ByteCount) { close(written) }) + sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(*packetBuffer, uint16) { close(written) }) tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() conn.scheduleSending() Eventually(written).Should(BeClosed()) @@ -1799,7 +1799,7 @@ var _ = Describe("Connection", func() { conn.receivedPacketHandler = rph written := make(chan struct{}) - sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(*packetBuffer, protocol.ByteCount) { close(written) }) + sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(*packetBuffer, uint16) { close(written) }) tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() go func() { defer GinkgoRecover() @@ -1864,7 +1864,7 @@ var _ = Describe("Connection", func() { ) sent := make(chan struct{}) - mconn.EXPECT().Write([]byte("foobar"), protocol.ByteCount(6)).Do(func([]byte, protocol.ByteCount) { close(sent) }) + mconn.EXPECT().Write([]byte("foobar"), uint16(0)).Do(func([]byte, uint16) { close(sent) }) go func() { defer GinkgoRecover() diff --git a/mock_raw_conn_test.go b/mock_raw_conn_test.go index 0a1a0f3a..84d8f276 100644 --- a/mock_raw_conn_test.go +++ b/mock_raw_conn_test.go @@ -93,18 +93,18 @@ func (mr *MockRawConnMockRecorder) SetReadDeadline(arg0 interface{}) *gomock.Cal } // WritePacket mocks base method. -func (m *MockRawConn) WritePacket(arg0 []byte, arg1 net.Addr, arg2 []byte) (int, error) { +func (m *MockRawConn) WritePacket(arg0 []byte, arg1 net.Addr, arg2 []byte, arg3 uint16) (int, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "WritePacket", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "WritePacket", arg0, arg1, arg2, arg3) ret0, _ := ret[0].(int) ret1, _ := ret[1].(error) return ret0, ret1 } // WritePacket indicates an expected call of WritePacket. -func (mr *MockRawConnMockRecorder) WritePacket(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockRawConnMockRecorder) WritePacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WritePacket", reflect.TypeOf((*MockRawConn)(nil).WritePacket), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WritePacket", reflect.TypeOf((*MockRawConn)(nil).WritePacket), arg0, arg1, arg2, arg3) } // capabilities mocks base method. diff --git a/mock_send_conn_test.go b/mock_send_conn_test.go index 04df8763..529b9c58 100644 --- a/mock_send_conn_test.go +++ b/mock_send_conn_test.go @@ -8,7 +8,6 @@ import ( net "net" reflect "reflect" - protocol "github.com/quic-go/quic-go/internal/protocol" gomock "go.uber.org/mock/gomock" ) @@ -78,7 +77,7 @@ func (mr *MockSendConnMockRecorder) RemoteAddr() *gomock.Call { } // Write mocks base method. -func (m *MockSendConn) Write(arg0 []byte, arg1 protocol.ByteCount) error { +func (m *MockSendConn) Write(arg0 []byte, arg1 uint16) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Write", arg0, arg1) ret0, _ := ret[0].(error) diff --git a/mock_sender_test.go b/mock_sender_test.go index c2a0fa8f..40671562 100644 --- a/mock_sender_test.go +++ b/mock_sender_test.go @@ -7,7 +7,6 @@ package quic import ( reflect "reflect" - protocol "github.com/quic-go/quic-go/internal/protocol" gomock "go.uber.org/mock/gomock" ) @@ -75,7 +74,7 @@ func (mr *MockSenderMockRecorder) Run() *gomock.Call { } // Send mocks base method. -func (m *MockSender) Send(arg0 *packetBuffer, arg1 protocol.ByteCount) { +func (m *MockSender) Send(arg0 *packetBuffer, arg1 uint16) { m.ctrl.T.Helper() m.ctrl.Call(m, "Send", arg0, arg1) } diff --git a/packet_handler_map.go b/packet_handler_map.go index e0f0567d..60b7cef9 100644 --- a/packet_handler_map.go +++ b/packet_handler_map.go @@ -27,8 +27,9 @@ type connCapabilities struct { type rawConn interface { ReadPacket() (receivedPacket, error) // WritePacket writes a packet on the wire. - // If GSO is enabled, it's the caller's responsibility to set the correct control message. - WritePacket(b []byte, addr net.Addr, oob []byte) (int, error) + // gsoSize is the size of a single packet, or 0 to disable GSO. + // It is invalid to set gsoSize if capabilities.GSO is not set. + WritePacket(b []byte, addr net.Addr, packetInfoOOB []byte, gsoSize uint16) (int, error) LocalAddr() net.Addr SetReadDeadline(time.Time) error io.Closer diff --git a/send_conn.go b/send_conn.go index d8ddbc87..272c61b1 100644 --- a/send_conn.go +++ b/send_conn.go @@ -1,17 +1,14 @@ package quic import ( - "fmt" - "math" "net" - "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/utils" ) // A sendConn allows sending using a simple Write() on a non-connected packet conn. type sendConn interface { - Write(b []byte, size protocol.ByteCount) error + Write(b []byte, gsoSize uint16) error Close() error LocalAddr() net.Addr RemoteAddr() net.Addr @@ -27,8 +24,7 @@ type sconn struct { logger utils.Logger - info packetInfo - oob []byte + packetInfoOOB []byte // If GSO enabled, and we receive a GSO error for this remote address, GSO is disabled. gotGSOError bool } @@ -51,28 +47,16 @@ func newSendConn(c rawConn, remote net.Addr, info packetInfo, logger utils.Logge oob = append(oob, make([]byte, 32)...) oob = oob[:l] return &sconn{ - rawConn: c, - localAddr: localAddr, - remoteAddr: remote, - info: info, - oob: oob, - logger: logger, + rawConn: c, + localAddr: localAddr, + remoteAddr: remote, + packetInfoOOB: oob, + logger: logger, } } -func (c *sconn) Write(p []byte, size protocol.ByteCount) error { - if !c.capabilities().GSO { - if protocol.ByteCount(len(p)) != size { - panic(fmt.Sprintf("inconsistent packet size (%d vs %d)", len(p), size)) - } - _, err := c.WritePacket(p, c.remoteAddr, c.oob) - return err - } - // GSO is supported. Append the control message and send. - if size > math.MaxUint16 { - panic("size overflow") - } - _, err := c.WritePacket(p, c.remoteAddr, appendUDPSegmentSizeMsg(c.oob, uint16(size))) +func (c *sconn) Write(p []byte, gsoSize uint16) error { + _, err := c.WritePacket(p, c.remoteAddr, c.packetInfoOOB, gsoSize) if err != nil && isGSOError(err) { // disable GSO for future calls c.gotGSOError = true @@ -82,10 +66,10 @@ func (c *sconn) Write(p []byte, size protocol.ByteCount) error { // send out the packets one by one for len(p) > 0 { l := len(p) - if l > int(size) { - l = int(size) + if l > int(gsoSize) { + l = int(gsoSize) } - if _, err := c.WritePacket(p[:l], c.remoteAddr, c.oob); err != nil { + if _, err := c.WritePacket(p[:l], c.remoteAddr, c.packetInfoOOB, 0); err != nil { return err } p = p[l:] diff --git a/send_conn_test.go b/send_conn_test.go index 7f072430..963f2482 100644 --- a/send_conn_test.go +++ b/send_conn_test.go @@ -3,6 +3,7 @@ package quic import ( "net" "net/netip" + "runtime" "github.com/quic-go/quic-go/internal/utils" @@ -35,48 +36,43 @@ var _ = Describe("Connection (for sending packets)", func() { Expect(c.LocalAddr().String()).To(Equal("127.0.0.42:1234")) }) - if platformSupportsGSO { - It("writes with GSO", func() { + // We're not using an OOB conn on windows, and packetInfo.OOB() always returns an empty slice. + if runtime.GOOS != "windows" { + It("sets the OOB", func() { rawConn := NewMockRawConn(mockCtrl) rawConn.EXPECT().LocalAddr() - rawConn.EXPECT().capabilities().Return(connCapabilities{GSO: true}).AnyTimes() - c := newSendConn(rawConn, remoteAddr, packetInfo{}, utils.DefaultLogger) - rawConn.EXPECT().WritePacket([]byte("foobar"), remoteAddr, gomock.Any()).Do(func(_ []byte, _ net.Addr, oob []byte) { - msg := appendUDPSegmentSizeMsg([]byte{}, 3) - Expect(oob).To(Equal(msg)) - }) - Expect(c.Write([]byte("foobar"), 3)).To(Succeed()) + rawConn.EXPECT().capabilities().AnyTimes() + pi := packetInfo{addr: netip.IPv6Loopback()} + Expect(pi.OOB()).ToNot(BeEmpty()) + c := newSendConn(rawConn, remoteAddr, pi, utils.DefaultLogger) + rawConn.EXPECT().WritePacket([]byte("foobar"), remoteAddr, pi.OOB(), uint16(0)) + Expect(c.Write([]byte("foobar"), 0)).To(Succeed()) }) + } - It("disables GSO if writing fails", func() { + It("writes", func() { + rawConn := NewMockRawConn(mockCtrl) + rawConn.EXPECT().LocalAddr() + rawConn.EXPECT().capabilities().AnyTimes() + c := newSendConn(rawConn, remoteAddr, packetInfo{}, utils.DefaultLogger) + rawConn.EXPECT().WritePacket([]byte("foobar"), remoteAddr, gomock.Any(), uint16(3)) + Expect(c.Write([]byte("foobar"), 3)).To(Succeed()) + }) + + if platformSupportsGSO { + It("disables GSO if sending fails", func() { rawConn := NewMockRawConn(mockCtrl) rawConn.EXPECT().LocalAddr() rawConn.EXPECT().capabilities().Return(connCapabilities{GSO: true}).AnyTimes() c := newSendConn(rawConn, remoteAddr, packetInfo{}, utils.DefaultLogger) Expect(c.capabilities().GSO).To(BeTrue()) gomock.InOrder( - rawConn.EXPECT().WritePacket([]byte("foobar"), remoteAddr, gomock.Any()).DoAndReturn(func(_ []byte, _ net.Addr, oob []byte) (int, error) { - msg := appendUDPSegmentSizeMsg([]byte{}, 3) - Expect(oob).To(Equal(msg)) - return 0, errGSO - }), - rawConn.EXPECT().WritePacket([]byte("foo"), remoteAddr, gomock.Len(0)).Return(3, nil), - rawConn.EXPECT().WritePacket([]byte("bar"), remoteAddr, gomock.Len(0)).Return(3, nil), + rawConn.EXPECT().WritePacket([]byte("foobar"), remoteAddr, gomock.Any(), uint16(4)).Return(0, errGSO), + rawConn.EXPECT().WritePacket([]byte("foob"), remoteAddr, gomock.Any(), uint16(0)).Return(4, nil), + rawConn.EXPECT().WritePacket([]byte("ar"), remoteAddr, gomock.Any(), uint16(0)).Return(2, nil), ) - Expect(c.Write([]byte("foobar"), 3)).To(Succeed()) - Expect(c.capabilities().GSO).To(BeFalse()) // GSO support is now disabled - // make sure we actually enforce that - Expect(func() { c.Write([]byte("foobar"), 3) }).To(PanicWith("inconsistent packet size (6 vs 3)")) - }) - } else { - It("writes without GSO", func() { - remoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337} - rawConn := NewMockRawConn(mockCtrl) - rawConn.EXPECT().LocalAddr() - rawConn.EXPECT().capabilities() - c := newSendConn(rawConn, remoteAddr, packetInfo{}, utils.DefaultLogger) - rawConn.EXPECT().WritePacket([]byte("foobar"), remoteAddr, gomock.Len(0)) - Expect(c.Write([]byte("foobar"), 6)).To(Succeed()) + Expect(c.Write([]byte("foobar"), 4)).To(Succeed()) + Expect(c.capabilities().GSO).To(BeFalse()) }) } }) diff --git a/send_queue.go b/send_queue.go index a9f7ca1a..2da546e5 100644 --- a/send_queue.go +++ b/send_queue.go @@ -1,9 +1,8 @@ package quic -import "github.com/quic-go/quic-go/internal/protocol" - type sender interface { - Send(p *packetBuffer, packetSize protocol.ByteCount) + // Send sends a packet. GSO is only used if gsoSize > 0. + Send(p *packetBuffer, gsoSize uint16) Run() error WouldBlock() bool Available() <-chan struct{} @@ -11,8 +10,8 @@ type sender interface { } type queueEntry struct { - buf *packetBuffer - size protocol.ByteCount + buf *packetBuffer + gsoSize uint16 } type sendQueue struct { @@ -40,9 +39,9 @@ func newSendQueue(conn sendConn) sender { // Send sends out a packet. It's guaranteed to not block. // Callers need to make sure that there's actually space in the send queue by calling WouldBlock. // Otherwise Send will panic. -func (h *sendQueue) Send(p *packetBuffer, size protocol.ByteCount) { +func (h *sendQueue) Send(p *packetBuffer, gsoSize uint16) { select { - case h.queue <- queueEntry{buf: p, size: size}: + case h.queue <- queueEntry{buf: p, gsoSize: gsoSize}: // clear available channel if we've reached capacity if len(h.queue) == sendQueueCapacity { select { @@ -77,7 +76,7 @@ func (h *sendQueue) Run() error { // make sure that all queued packets are actually sent out shouldClose = true case e := <-h.queue: - if err := h.conn.Write(e.buf.Data, e.size); err != nil { + if err := h.conn.Write(e.buf.Data, e.gsoSize); err != nil { // This additional check enables: // 1. Checking for "datagram too large" message from the kernel, as such, // 2. Path MTU discovery,and diff --git a/send_queue_test.go b/send_queue_test.go index 69562c58..0ed7bea5 100644 --- a/send_queue_test.go +++ b/send_queue_test.go @@ -3,8 +3,6 @@ package quic import ( "errors" - "github.com/quic-go/quic-go/internal/protocol" - . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" "go.uber.org/mock/gomock" @@ -31,7 +29,7 @@ var _ = Describe("Send Queue", func() { q.Send(p, 10) // make sure the packet size is passed through to the conn written := make(chan struct{}) - c.EXPECT().Write([]byte("foobar"), protocol.ByteCount(10)).Do(func([]byte, protocol.ByteCount) { close(written) }) + c.EXPECT().Write([]byte("foobar"), uint16(10)).Do(func([]byte, uint16) { close(written) }) done := make(chan struct{}) go func() { defer GinkgoRecover() @@ -79,7 +77,7 @@ var _ = Describe("Send Queue", func() { write := make(chan struct{}, 1) written := make(chan struct{}, 100) // now start sending out packets. This should free up queue space. - c.EXPECT().Write(gomock.Any(), gomock.Any()).DoAndReturn(func([]byte, protocol.ByteCount) error { + c.EXPECT().Write(gomock.Any(), gomock.Any()).DoAndReturn(func([]byte, uint16) error { written <- struct{}{} <-write return nil @@ -149,7 +147,7 @@ var _ = Describe("Send Queue", func() { It("blocks Close() until the packet has been sent out", func() { written := make(chan []byte) - c.EXPECT().Write(gomock.Any(), gomock.Any()).Do(func(p []byte, _ protocol.ByteCount) { written <- p }) + c.EXPECT().Write(gomock.Any(), gomock.Any()).Do(func(p []byte, _ uint16) { written <- p }) done := make(chan struct{}) go func() { defer GinkgoRecover() diff --git a/server.go b/server.go index 14cc6f82..495ca65a 100644 --- a/server.go +++ b/server.go @@ -745,7 +745,7 @@ func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header, info packe if s.tracer != nil { s.tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(buf.Data)), nil) } - _, err = s.conn.WritePacket(buf.Data, remoteAddr, info.OOB()) + _, err = s.conn.WritePacket(buf.Data, remoteAddr, info.OOB(), 0) return err } @@ -844,7 +844,7 @@ func (s *baseServer) sendError(remoteAddr net.Addr, hdr *wire.Header, sealer han if s.tracer != nil { s.tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(b.Data)), []logging.Frame{ccf}) } - _, err = s.conn.WritePacket(b.Data, remoteAddr, info.OOB()) + _, err = s.conn.WritePacket(b.Data, remoteAddr, info.OOB(), 0) return err } @@ -882,7 +882,7 @@ func (s *baseServer) maybeSendVersionNegotiationPacket(p receivedPacket) { if s.tracer != nil { s.tracer.SentVersionNegotiationPacket(p.remoteAddr, src, dest, s.config.Versions) } - if _, err := s.conn.WritePacket(data, p.remoteAddr, p.info.OOB()); err != nil { + if _, err := s.conn.WritePacket(data, p.remoteAddr, p.info.OOB(), 0); err != nil { s.logger.Debugf("Error sending Version Negotiation: %s", err) } } diff --git a/sys_conn.go b/sys_conn.go index f2224e4c..a72aead5 100644 --- a/sys_conn.go +++ b/sys_conn.go @@ -104,7 +104,10 @@ func (c *basicConn) ReadPacket() (receivedPacket, error) { }, nil } -func (c *basicConn) WritePacket(b []byte, addr net.Addr, _ []byte) (n int, err error) { +func (c *basicConn) WritePacket(b []byte, addr net.Addr, _ []byte, gsoSize uint16) (n int, err error) { + if gsoSize != 0 { + panic("cannot use GSO with a basicConn") + } return c.PacketConn.WriteTo(b, addr) } diff --git a/sys_conn_oob.go b/sys_conn_oob.go index 4026a7b3..24b73e9c 100644 --- a/sys_conn_oob.go +++ b/sys_conn_oob.go @@ -228,8 +228,14 @@ func (c *oobConn) ReadPacket() (receivedPacket, error) { } // WritePacket writes a new packet. -// If the connection supports GSO, it's the caller's responsibility to append the right control mesage. -func (c *oobConn) WritePacket(b []byte, addr net.Addr, oob []byte) (int, error) { +func (c *oobConn) WritePacket(b []byte, addr net.Addr, packetInfoOOB []byte, gsoSize uint16) (int, error) { + oob := packetInfoOOB + if gsoSize > 0 { + if !c.capabilities().GSO { + panic("GSO disabled") + } + oob = appendUDPSegmentSizeMsg(oob, gsoSize) + } n, _, err := c.OOBCapablePacketConn.WriteMsgUDP(b, oob, addr.(*net.UDPAddr)) return n, err } diff --git a/sys_conn_oob_test.go b/sys_conn_oob_test.go index 3ae97ed9..54dac82d 100644 --- a/sys_conn_oob_test.go +++ b/sys_conn_oob_test.go @@ -18,6 +18,16 @@ import ( "go.uber.org/mock/gomock" ) +type oobRecordingConn struct { + *net.UDPConn + oobs [][]byte +} + +func (c *oobRecordingConn) WriteMsgUDP(b, oob []byte, addr *net.UDPAddr) (n, oobn int, err error) { + c.oobs = append(c.oobs, oob) + return c.UDPConn.WriteMsgUDP(b, oob, addr) +} + var _ = Describe("OOB Conn Test", func() { runServer := func(network, address string) (*net.UDPConn, <-chan receivedPacket) { addr, err := net.ResolveUDPAddr(network, address) @@ -242,4 +252,28 @@ var _ = Describe("OOB Conn Test", func() { } }) }) + + if platformSupportsGSO { + Context("GSO", func() { + It("appends the GSO control message", func() { + addr, err := net.ResolveUDPAddr("udp", "localhost:0") + Expect(err).ToNot(HaveOccurred()) + udpConn, err := net.ListenUDP("udp", addr) + Expect(err).ToNot(HaveOccurred()) + c := &oobRecordingConn{UDPConn: udpConn} + oobConn, err := newConn(c, true) + Expect(err).ToNot(HaveOccurred()) + Expect(oobConn.capabilities().GSO).To(BeTrue()) + + oob := make([]byte, 0, 42) + oobConn.WritePacket([]byte("foobar"), addr, oob, 3) + Expect(c.oobs).To(HaveLen(1)) + oobMsg := c.oobs[0] + Expect(oobMsg).ToNot(BeEmpty()) + Expect(oobMsg).To(HaveCap(cap(oob))) // check that it appended to oob + expected := appendUDPSegmentSizeMsg([]byte{}, 3) + Expect(oobMsg).To(Equal(expected)) + }) + }) + } }) diff --git a/transport.go b/transport.go index 9f93cbb9..41c97347 100644 --- a/transport.go +++ b/transport.go @@ -228,7 +228,7 @@ func (t *Transport) WriteTo(b []byte, addr net.Addr) (int, error) { if err := t.init(false); err != nil { return 0, err } - return t.conn.WritePacket(b, addr, nil) + return t.conn.WritePacket(b, addr, nil, 0) } func (t *Transport) enqueueClosePacket(p closePacket) { @@ -246,7 +246,7 @@ func (t *Transport) runSendQueue() { case <-t.listening: return case p := <-t.closeQueue: - t.conn.WritePacket(p.payload, p.addr, p.info.OOB()) + t.conn.WritePacket(p.payload, p.addr, p.info.OOB(), 0) case p := <-t.statelessResetQueue: t.sendStatelessReset(p) } @@ -414,7 +414,7 @@ func (t *Transport) sendStatelessReset(p receivedPacket) { rand.Read(data) data[0] = (data[0] & 0x7f) | 0x40 data = append(data, token[:]...) - if _, err := t.conn.WritePacket(data, p.remoteAddr, p.info.OOB()); err != nil { + if _, err := t.conn.WritePacket(data, p.remoteAddr, p.info.OOB(), 0); err != nil { t.logger.Debugf("Error sending Stateless Reset to %s: %s", p.remoteAddr, err) } }