forked from quic-go/quic-go
When the server is listening on multiple interfaces or interfaces with multiple IPs, the outgoing datagrams are sometime delivered with the wrong source IP address. In order to fix that, each quic connection needs to extract the destination IP (and optionally interface id) of the received datagrams, and set it as source IP (and interface) on the sent datagrams. On most platforms, this can be done using ancillary data with recvmsg() and sendmsg(). Some of the machinery for this is already there for ECN, this change extends it to read the destination IP info and write it to the outgoing packets. Fix #1736
240 lines
6.9 KiB
Go
240 lines
6.9 KiB
Go
// +build darwin linux freebsd
|
|
|
|
package quic
|
|
|
|
import (
|
|
"encoding/binary"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"runtime"
|
|
"syscall"
|
|
"time"
|
|
"unsafe"
|
|
|
|
"golang.org/x/sys/unix"
|
|
|
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
|
)
|
|
|
|
const ecnMask uint8 = 0x3
|
|
|
|
func inspectReadBuffer(c interface{}) (int, error) {
|
|
conn, ok := c.(interface {
|
|
SyscallConn() (syscall.RawConn, error)
|
|
})
|
|
if !ok {
|
|
return 0, errors.New("doesn't have a SyscallConn")
|
|
}
|
|
rawConn, err := conn.SyscallConn()
|
|
if err != nil {
|
|
return 0, fmt.Errorf("couldn't get syscall.RawConn: %w", err)
|
|
}
|
|
var size int
|
|
var serr error
|
|
if err := rawConn.Control(func(fd uintptr) {
|
|
size, serr = unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF)
|
|
}); err != nil {
|
|
return 0, err
|
|
}
|
|
return size, serr
|
|
}
|
|
|
|
type oobConn struct {
|
|
OOBCapablePacketConn
|
|
oobBuffer []byte
|
|
}
|
|
|
|
var _ connection = &oobConn{}
|
|
|
|
func newConn(c OOBCapablePacketConn) (*oobConn, error) {
|
|
rawConn, err := c.SyscallConn()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
needsPacketInfo := false
|
|
if udpAddr, ok := c.LocalAddr().(*net.UDPAddr); ok && udpAddr.IP.IsUnspecified() {
|
|
needsPacketInfo = true
|
|
}
|
|
// We don't know if this a IPv4-only, IPv6-only or a IPv4-and-IPv6 connection.
|
|
// Try enabling receiving of ECN and packet info for both IP versions.
|
|
// We expect at least one of those syscalls to succeed.
|
|
var errECNIPv4, errECNIPv6, errPIIPv4, errPIIPv6 error
|
|
if err := rawConn.Control(func(fd uintptr) {
|
|
errECNIPv4 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_RECVTOS, 1)
|
|
errECNIPv6 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVTCLASS, 1)
|
|
|
|
if needsPacketInfo {
|
|
errPIIPv4 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, ipv4RECVPKTINFO, 1)
|
|
errPIIPv6 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, ipv6RECVPKTINFO, 1)
|
|
}
|
|
}); err != nil {
|
|
return nil, err
|
|
}
|
|
switch {
|
|
case errECNIPv4 == nil && errECNIPv6 == nil:
|
|
utils.DefaultLogger.Debugf("Activating reading of ECN bits for IPv4 and IPv6.")
|
|
case errECNIPv4 == nil && errECNIPv6 != nil:
|
|
utils.DefaultLogger.Debugf("Activating reading of ECN bits for IPv4.")
|
|
case errECNIPv4 != nil && errECNIPv6 == nil:
|
|
utils.DefaultLogger.Debugf("Activating reading of ECN bits for IPv6.")
|
|
case errECNIPv4 != nil && errECNIPv6 != nil:
|
|
return nil, errors.New("activating ECN failed for both IPv4 and IPv6")
|
|
}
|
|
if needsPacketInfo {
|
|
switch {
|
|
case errPIIPv4 == nil && errPIIPv6 == nil:
|
|
utils.DefaultLogger.Debugf("Activating reading of packet info for IPv4 and IPv6.")
|
|
case errPIIPv4 == nil && errPIIPv6 != nil:
|
|
utils.DefaultLogger.Debugf("Activating reading of packet info bits for IPv4.")
|
|
case errPIIPv4 != nil && errPIIPv6 == nil:
|
|
utils.DefaultLogger.Debugf("Activating reading of packet info bits for IPv6.")
|
|
case errPIIPv4 != nil && errPIIPv6 != nil:
|
|
return nil, errors.New("activating packet info failed for both IPv4 and IPv6")
|
|
}
|
|
}
|
|
return &oobConn{
|
|
OOBCapablePacketConn: c,
|
|
oobBuffer: make([]byte, 128),
|
|
}, nil
|
|
}
|
|
|
|
func (c *oobConn) ReadPacket() (*receivedPacket, error) {
|
|
buffer := getPacketBuffer()
|
|
// The packet size should not exceed protocol.MaxPacketBufferSize bytes
|
|
// If it does, we only read a truncated packet, which will then end up undecryptable
|
|
buffer.Data = buffer.Data[:protocol.MaxPacketBufferSize]
|
|
c.oobBuffer = c.oobBuffer[:cap(c.oobBuffer)]
|
|
n, oobn, _, addr, err := c.OOBCapablePacketConn.ReadMsgUDP(buffer.Data, c.oobBuffer)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
ctrlMsgs, err := unix.ParseSocketControlMessage(c.oobBuffer[:oobn])
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
var ecn protocol.ECN
|
|
var destIP net.IP
|
|
var ifIndex uint32
|
|
for _, ctrlMsg := range ctrlMsgs {
|
|
if ctrlMsg.Header.Level == unix.IPPROTO_IP {
|
|
switch ctrlMsg.Header.Type {
|
|
case msgTypeIPTOS:
|
|
ecn = protocol.ECN(ctrlMsg.Data[0] & ecnMask)
|
|
case msgTypeIPv4PKTINFO:
|
|
// struct in_pktinfo {
|
|
// unsigned int ipi_ifindex; /* Interface index */
|
|
// struct in_addr ipi_spec_dst; /* Local address */
|
|
// struct in_addr ipi_addr; /* Header Destination
|
|
// address */
|
|
// };
|
|
if len(ctrlMsg.Data) == 12 {
|
|
ifIndex = binary.LittleEndian.Uint32(ctrlMsg.Data)
|
|
destIP = net.IP(ctrlMsg.Data[8:12])
|
|
} else if len(ctrlMsg.Data) == 4 {
|
|
// FreeBSD
|
|
destIP = net.IP(ctrlMsg.Data)
|
|
}
|
|
}
|
|
}
|
|
if ctrlMsg.Header.Level == unix.IPPROTO_IPV6 {
|
|
switch ctrlMsg.Header.Type {
|
|
case unix.IPV6_TCLASS:
|
|
ecn = protocol.ECN(ctrlMsg.Data[0] & ecnMask)
|
|
case msgTypeIPv6PKTINFO:
|
|
// struct in6_pktinfo {
|
|
// struct in6_addr ipi6_addr; /* src/dst IPv6 address */
|
|
// unsigned int ipi6_ifindex; /* send/recv interface index */
|
|
// };
|
|
if len(ctrlMsg.Data) == 20 {
|
|
destIP = net.IP(ctrlMsg.Data[:16])
|
|
ifIndex = binary.LittleEndian.Uint32(ctrlMsg.Data[16:])
|
|
}
|
|
}
|
|
}
|
|
}
|
|
var info *packetInfo
|
|
if destIP != nil {
|
|
info = &packetInfo{
|
|
addr: destIP,
|
|
ifIndex: ifIndex,
|
|
}
|
|
}
|
|
return &receivedPacket{
|
|
remoteAddr: addr,
|
|
rcvTime: time.Now(),
|
|
data: buffer.Data[:n],
|
|
ecn: ecn,
|
|
info: info,
|
|
buffer: buffer,
|
|
}, nil
|
|
}
|
|
|
|
func (c *oobConn) WritePacket(b []byte, addr net.Addr, info *packetInfo) (n int, err error) {
|
|
n, _, err = c.OOBCapablePacketConn.WriteMsgUDP(b, info.OOB(), addr.(*net.UDPAddr))
|
|
return n, err
|
|
}
|
|
|
|
func (info *packetInfo) OOB() []byte {
|
|
if info == nil {
|
|
return nil
|
|
}
|
|
info.once.Do(info.computeOOB)
|
|
return info.oob
|
|
}
|
|
|
|
func (info *packetInfo) computeOOB() {
|
|
if ip4 := info.addr.To4(); ip4 != nil {
|
|
// struct in_pktinfo {
|
|
// unsigned int ipi_ifindex; /* Interface index */
|
|
// struct in_addr ipi_spec_dst; /* Local address */
|
|
// struct in_addr ipi_addr; /* Header Destination address */
|
|
// };
|
|
msgLen := 12
|
|
if runtime.GOOS == "freebsd" {
|
|
msgLen = 4
|
|
}
|
|
cmsglen := cmsgLen(msgLen)
|
|
oob := make([]byte, cmsglen)
|
|
cmsg := (*syscall.Cmsghdr)(unsafe.Pointer(&oob[0]))
|
|
cmsg.Level = syscall.IPPROTO_TCP
|
|
cmsg.Type = msgTypeIPv4PKTINFO
|
|
cmsg.SetLen(cmsglen)
|
|
off := cmsgLen(0)
|
|
if runtime.GOOS != "freebsd" {
|
|
// FreeBSD does not support in_pktinfo, just an in_addr is sent
|
|
binary.LittleEndian.PutUint32(oob[off:], info.ifIndex)
|
|
off += 4
|
|
}
|
|
copy(oob[off:], ip4)
|
|
info.oob = oob
|
|
} else if len(info.addr) == 16 {
|
|
// struct in6_pktinfo {
|
|
// struct in6_addr ipi6_addr; /* src/dst IPv6 address */
|
|
// unsigned int ipi6_ifindex; /* send/recv interface index */
|
|
// };
|
|
const msgLen = 20
|
|
cmsglen := cmsgLen(msgLen)
|
|
oob := make([]byte, cmsglen)
|
|
cmsg := (*syscall.Cmsghdr)(unsafe.Pointer(&oob[0]))
|
|
cmsg.Level = syscall.IPPROTO_IPV6
|
|
cmsg.Type = msgTypeIPv6PKTINFO
|
|
cmsg.SetLen(cmsglen)
|
|
off := cmsgLen(0)
|
|
off += copy(oob[off:], info.addr)
|
|
binary.LittleEndian.PutUint32(oob[off:], info.ifIndex)
|
|
info.oob = oob
|
|
}
|
|
}
|
|
|
|
func cmsgLen(datalen int) int {
|
|
return cmsgAlign(syscall.SizeofCmsghdr) + datalen
|
|
}
|
|
|
|
func cmsgAlign(salen int) int {
|
|
const sizeOfPtr = 0x8
|
|
salign := sizeOfPtr
|
|
return (salen + salign - 1) & ^(salign - 1)
|
|
}
|