Ver Fonte

try with sendmmsg merged back

Ryan há 1 mês atrás
pai
commit
c9a695c2bf
10 ficheiros alterados com 318 adições e 44 exclusões
  1. 52 42
      inside.go
  2. 81 2
      interface.go
  3. 9 0
      udp/conn.go
  4. 11 0
      udp/udp_darwin.go
  5. 11 0
      udp/udp_generic.go
  6. 112 0
      udp/udp_linux.go
  7. 10 0
      udp/udp_linux_32.go
  8. 10 0
      udp/udp_linux_64.go
  9. 11 0
      udp/udp_rio_windows.go
  10. 11 0
      udp/udp_tester.go

+ 52 - 42
inside.go

@@ -11,19 +11,19 @@ import (
 	"github.com/slackhq/nebula/routing"
 )
 
-func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) {
+func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, queue func(netip.AddrPort, int), q int, localCache firewall.ConntrackCache) bool {
 	err := newPacket(packet, false, fwPacket)
 	if err != nil {
 		if f.l.Level >= logrus.DebugLevel {
 			f.l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err)
 		}
-		return
+		return false
 	}
 
 	// Ignore local broadcast packets
 	if f.dropLocalBroadcast {
 		if f.myBroadcastAddrsTable.Contains(fwPacket.RemoteAddr) {
-			return
+			return false
 		}
 	}
 
@@ -40,12 +40,12 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
 		}
 		// Otherwise, drop. On linux, we should never see these packets - Linux
 		// routes packets from the nebula addr to the nebula addr through the loopback device.
-		return
+		return false
 	}
 
 	// Ignore multicast packets
 	if f.dropMulticast && fwPacket.RemoteAddr.IsMulticast() {
-		return
+		return false
 	}
 
 	hostinfo, ready := f.getOrHandshakeConsiderRouting(fwPacket, func(hh *HandshakeHostInfo) {
@@ -59,26 +59,26 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
 				WithField("fwPacket", fwPacket).
 				Debugln("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks")
 		}
-		return
+		return false
 	}
 
 	if !ready {
-		return
+		return false
 	}
 
 	dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache)
 	if dropReason == nil {
-		f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, q)
+		return f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, queue, q)
+	}
 
-	} else {
-		f.rejectInside(packet, out, q)
-		if f.l.Level >= logrus.DebugLevel {
-			hostinfo.logger(f.l).
-				WithField("fwPacket", fwPacket).
-				WithField("reason", dropReason).
-				Debugln("dropping outbound packet")
-		}
+	f.rejectInside(packet, out, q)
+	if f.l.Level >= logrus.DebugLevel {
+		hostinfo.logger(f.l).
+			WithField("fwPacket", fwPacket).
+			WithField("reason", dropReason).
+			Debugln("dropping outbound packet")
 	}
+	return false
 }
 
 func (f *Interface) rejectInside(packet []byte, out []byte, q int) {
@@ -117,7 +117,7 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo *
 		return
 	}
 
-	f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, q)
+	_ = f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, nil, q)
 }
 
 // Handshake will attempt to initiate a tunnel with the provided vpn address if it is within our vpn networks. This is a no-op if the tunnel is already established or being established
@@ -228,7 +228,7 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
 		return
 	}
 
-	f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, p, nb, out, 0)
+	_ = f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, p, nb, out, nil, 0)
 }
 
 // SendMessageToVpnAddr handles real addr:port lookup and sends to the current best known address for vpnAddr
@@ -258,12 +258,12 @@ func (f *Interface) SendMessageToHostInfo(t header.MessageType, st header.Messag
 
 func (f *Interface) send(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, p, nb, out []byte) {
 	f.messageMetrics.Tx(t, st, 1)
-	f.sendNoMetrics(t, st, ci, hostinfo, netip.AddrPort{}, p, nb, out, 0)
+	_ = f.sendNoMetrics(t, st, ci, hostinfo, netip.AddrPort{}, p, nb, out, nil, 0)
 }
 
 func (f *Interface) sendTo(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb, out []byte) {
 	f.messageMetrics.Tx(t, st, 1)
-	f.sendNoMetrics(t, st, ci, hostinfo, remote, p, nb, out, 0)
+	_ = f.sendNoMetrics(t, st, ci, hostinfo, remote, p, nb, out, nil, 0)
 }
 
 // SendVia sends a payload through a Relay tunnel. No authentication or encryption is done
@@ -331,9 +331,12 @@ func (f *Interface) SendVia(via *HostInfo,
 	f.connectionManager.RelayUsed(relay.LocalIndex)
 }
 
-func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb, out []byte, q int) {
+// sendNoMetrics encrypts and writes/queues an outbound packet. It returns true
+// when the payload has been handed to a caller-provided queue (meaning the
+// caller is responsible for flushing it later).
+func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb, out []byte, queue func(netip.AddrPort, int), q int) bool {
 	if ci.eKey == nil {
-		return
+		return false
 	}
 	useRelay := !remote.IsValid() && !hostinfo.remote.IsValid()
 	fullOut := out
@@ -380,32 +383,39 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
 			WithField("udpAddr", remote).WithField("counter", c).
 			WithField("attemptedCounter", c).
 			Error("Failed to encrypt outgoing packet")
-		return
+		return false
 	}
 
-	if remote.IsValid() {
-		err = f.writers[q].WriteTo(out, remote)
-		if err != nil {
-			hostinfo.logger(f.l).WithError(err).
-				WithField("udpAddr", remote).Error("Failed to write outgoing packet")
+	dest := remote
+	if !dest.IsValid() {
+		dest = hostinfo.remote
+	}
+
+	if dest.IsValid() {
+		if queue != nil {
+			queue(dest, len(out))
+			return true
 		}
-	} else if hostinfo.remote.IsValid() {
-		err = f.writers[q].WriteTo(out, hostinfo.remote)
+
+		err = f.writers[q].WriteTo(out, dest)
 		if err != nil {
 			hostinfo.logger(f.l).WithError(err).
-				WithField("udpAddr", remote).Error("Failed to write outgoing packet")
+				WithField("udpAddr", dest).Error("Failed to write outgoing packet")
 		}
-	} else {
-		// Try to send via a relay
-		for _, relayIP := range hostinfo.relayState.CopyRelayIps() {
-			relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP)
-			if err != nil {
-				hostinfo.relayState.DeleteRelay(relayIP)
-				hostinfo.logger(f.l).WithField("relay", relayIP).WithError(err).Info("sendNoMetrics failed to find HostInfo")
-				continue
-			}
-			f.SendVia(relayHostInfo, relay, out, nb, fullOut[:header.Len+len(out)], true)
-			break
+		return false
+	}
+
+	// Try to send via a relay
+	for _, relayIP := range hostinfo.relayState.CopyRelayIps() {
+		relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP)
+		if err != nil {
+			hostinfo.relayState.DeleteRelay(relayIP)
+			hostinfo.logger(f.l).WithField("relay", relayIP).WithError(err).Info("sendNoMetrics failed to find HostInfo")
+			continue
 		}
+		f.SendVia(relayHostInfo, relay, out, nb, fullOut[:header.Len+len(out)], true)
+		break
 	}
+
+	return false
 }

+ 81 - 2
interface.go

@@ -29,6 +29,7 @@ const (
 	outboundBatchSizeDefault     = 32
 	batchFlushIntervalDefault    = 50 * time.Microsecond
 	maxOutstandingBatchesDefault = 1028
+	sendBatchSizeDefault         = 32
 )
 
 type InterfaceConfig struct {
@@ -120,12 +121,21 @@ type Interface struct {
 	packetBatchPool   sync.Pool
 	outboundBatchPool sync.Pool
 
+	sendPool      sync.Pool
+	sendBatchSize int
+
 	inboundBatchSize      int
 	outboundBatchSize     int
 	batchFlushInterval    time.Duration
 	maxOutstandingPerChan int
 }
 
+type outboundSend struct {
+	buf    *[]byte
+	length int
+	addr   netip.AddrPort
+}
+
 type packetBatch struct {
 	packets []*packet.Packet
 }
@@ -194,6 +204,48 @@ func (f *Interface) releaseOutboundBatch(b *outboundBatch) {
 	f.outboundBatchPool.Put(b)
 }
 
+func (f *Interface) getSendBuffer() *[]byte {
+	if v := f.sendPool.Get(); v != nil {
+		buf := v.(*[]byte)
+		*buf = (*buf)[:0]
+		return buf
+	}
+	b := make([]byte, mtu)
+	return &b
+}
+
+func (f *Interface) releaseSendBuffer(buf *[]byte) {
+	if buf == nil {
+		return
+	}
+	*buf = (*buf)[:0]
+	f.sendPool.Put(buf)
+}
+
+func (f *Interface) flushSendQueue(q int, pending *[]outboundSend) {
+	if len(*pending) == 0 {
+		return
+	}
+
+	batch := make([]udp.BatchPacket, len(*pending))
+	for i, entry := range *pending {
+		batch[i] = udp.BatchPacket{
+			Payload: (*entry.buf)[:entry.length],
+			Addr:    entry.addr,
+		}
+	}
+
+	sent, err := f.writers[q].WriteBatch(batch)
+	if err != nil {
+		f.l.WithError(err).WithField("sent", sent).Error("Failed to batch send packets")
+	}
+
+	for _, entry := range *pending {
+		f.releaseSendBuffer(entry.buf)
+	}
+	*pending = (*pending)[:0]
+}
+
 type EncWriter interface {
 	SendVia(via *HostInfo,
 		relay *Relay,
@@ -316,6 +368,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
 		outboundBatchSize:     bc.OutboundBatchSize,
 		batchFlushInterval:    bc.FlushInterval,
 		maxOutstandingPerChan: bc.MaxOutstandingPerChan,
+		sendBatchSize:         bc.OutboundBatchSize,
 	}
 
 	for i := 0; i < c.routines; i++ {
@@ -340,6 +393,11 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
 		return newOutboundBatch(ifce.outboundBatchSize)
 	}}
 
+	ifce.sendPool = sync.Pool{New: func() any {
+		buf := make([]byte, mtu)
+		return &buf
+	}}
+
 	ifce.tryPromoteEvery.Store(c.tryPromoteEvery)
 	ifce.reQueryEvery.Store(c.reQueryEvery)
 	ifce.reQueryWait.Store(int64(c.reQueryWait))
@@ -539,18 +597,39 @@ func (f *Interface) workerOut(i int, ctx context.Context) {
 	conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
 	fwPacket1 := &firewall.Packet{}
 	nb1 := make([]byte, 12, 12)
-	result1 := make([]byte, mtu)
+	pending := make([]outboundSend, 0, f.sendBatchSize)
 
 	for {
 		select {
 		case batch := <-f.outbound[i]:
 			for _, data := range batch.payloads {
-				f.consumeInsidePacket(*data, fwPacket1, nb1, result1, i, conntrackCache.Get(f.l))
+				sendBuf := f.getSendBuffer()
+				buf := (*sendBuf)[:0]
+				queue := func(addr netip.AddrPort, length int) {
+					pending = append(pending, outboundSend{
+						buf:    sendBuf,
+						length: length,
+						addr:   addr,
+					})
+					if len(pending) >= f.sendBatchSize {
+						f.flushSendQueue(i, &pending)
+					}
+				}
+				sent := f.consumeInsidePacket(*data, fwPacket1, nb1, buf, queue, i, conntrackCache.Get(f.l))
+				if !sent {
+					f.releaseSendBuffer(sendBuf)
+				}
 				*data = (*data)[:mtu]
 				f.outPool.Put(data)
 			}
 			f.releaseOutboundBatch(batch)
+			if len(pending) > 0 {
+				f.flushSendQueue(i, &pending)
+			}
 		case <-ctx.Done():
+			if len(pending) > 0 {
+				f.flushSendQueue(i, &pending)
+			}
 			f.wg.Done()
 			return
 		}

+ 9 - 0
udp/conn.go

@@ -18,10 +18,16 @@ type Conn interface {
 	LocalAddr() (netip.AddrPort, error)
 	ListenOut(r EncReader) error
 	WriteTo(b []byte, addr netip.AddrPort) error
+	WriteBatch(pkts []BatchPacket) (int, error)
 	ReloadConfig(c *config.C)
 	Close() error
 }
 
+type BatchPacket struct {
+	Payload []byte
+	Addr    netip.AddrPort
+}
+
 type NoopConn struct{}
 
 func (NoopConn) Rebind() error {
@@ -36,6 +42,9 @@ func (NoopConn) ListenOut(_ EncReader) error {
 func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error {
 	return nil
 }
+func (NoopConn) WriteBatch(_ []BatchPacket) (int, error) {
+	return 0, nil
+}
 func (NoopConn) ReloadConfig(_ *config.C) {
 	return
 }

+ 11 - 0
udp/udp_darwin.go

@@ -140,6 +140,17 @@ func (u *StdConn) WriteTo(b []byte, ap netip.AddrPort) error {
 	}
 }
 
+func (u *StdConn) WriteBatch(pkts []BatchPacket) (int, error) {
+	sent := 0
+	for _, pkt := range pkts {
+		if err := u.WriteTo(pkt.Payload, pkt.Addr); err != nil {
+			return sent, err
+		}
+		sent++
+	}
+	return sent, nil
+}
+
 func (u *StdConn) LocalAddr() (netip.AddrPort, error) {
 	a := u.UDPConn.LocalAddr()
 

+ 11 - 0
udp/udp_generic.go

@@ -42,6 +42,17 @@ func (u *GenericConn) WriteTo(b []byte, addr netip.AddrPort) error {
 	return err
 }
 
+func (u *GenericConn) WriteBatch(pkts []BatchPacket) (int, error) {
+	sent := 0
+	for _, pkt := range pkts {
+		if err := u.WriteTo(pkt.Payload, pkt.Addr); err != nil {
+			return sent, err
+		}
+		sent++
+	}
+	return sent, nil
+}
+
 func (u *GenericConn) LocalAddr() (netip.AddrPort, error) {
 	a := u.UDPConn.LocalAddr()
 

+ 112 - 0
udp/udp_linux.go

@@ -343,6 +343,118 @@ func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error {
 	return u.writeTo6(b, ip)
 }
 
+func (u *StdConn) WriteBatch(pkts []BatchPacket) (int, error) {
+	if len(pkts) == 0 {
+		return 0, nil
+	}
+
+	msgs := make([]rawMessage, 0, len(pkts))
+	iovs := make([]iovec, 0, len(pkts))
+	names := make([][unix.SizeofSockaddrInet6]byte, 0, len(pkts))
+
+	sent := 0
+
+	for _, pkt := range pkts {
+		if len(pkt.Payload) == 0 {
+			sent++
+			continue
+		}
+
+		if u.enableGSO && pkt.Addr.IsValid() {
+			if err := u.queueGSOPacket(pkt.Payload, pkt.Addr); err == nil {
+				sent++
+				continue
+			} else if !errors.Is(err, errGSOFallback) {
+				return sent, err
+			}
+		}
+
+		if !pkt.Addr.IsValid() {
+			if err := u.WriteTo(pkt.Payload, pkt.Addr); err != nil {
+				return sent, err
+			}
+			sent++
+			continue
+		}
+
+		msgs = append(msgs, rawMessage{})
+		iovs = append(iovs, iovec{})
+		names = append(names, [unix.SizeofSockaddrInet6]byte{})
+
+		idx := len(msgs) - 1
+		msg := &msgs[idx]
+		iov := &iovs[idx]
+		name := &names[idx]
+
+		setIovecSlice(iov, pkt.Payload)
+		msg.Hdr.Iov = iov
+		msg.Hdr.Iovlen = 1
+		setRawMessageControl(msg, nil)
+		msg.Hdr.Flags = 0
+
+		nameLen, err := u.encodeSockaddr(name[:], pkt.Addr)
+		if err != nil {
+			return sent, err
+		}
+		msg.Hdr.Name = &name[0]
+		msg.Hdr.Namelen = nameLen
+	}
+
+	if len(msgs) == 0 {
+		return sent, nil
+	}
+
+	offset := 0
+	for offset < len(msgs) {
+		n, _, errno := unix.Syscall6(
+			unix.SYS_SENDMMSG,
+			uintptr(u.sysFd),
+			uintptr(unsafe.Pointer(&msgs[offset])),
+			uintptr(len(msgs)-offset),
+			0,
+			0,
+			0,
+		)
+
+		if errno != 0 {
+			if errno == unix.EINTR {
+				continue
+			}
+			return sent + offset, &net.OpError{Op: "sendmmsg", Err: errno}
+		}
+
+		if n == 0 {
+			break
+		}
+		offset += int(n)
+	}
+
+	return sent + len(msgs), nil
+}
+
+func (u *StdConn) encodeSockaddr(dst []byte, addr netip.AddrPort) (uint32, error) {
+	if u.isV4 {
+		if !addr.Addr().Is4() {
+			return 0, fmt.Errorf("Listener is IPv4, but writing to IPv6 remote")
+		}
+		var sa unix.RawSockaddrInet4
+		sa.Family = unix.AF_INET
+		sa.Addr = addr.Addr().As4()
+		binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], addr.Port())
+		size := unix.SizeofSockaddrInet4
+		copy(dst[:size], (*(*[unix.SizeofSockaddrInet4]byte)(unsafe.Pointer(&sa)))[:])
+		return uint32(size), nil
+	}
+
+	var sa unix.RawSockaddrInet6
+	sa.Family = unix.AF_INET6
+	sa.Addr = addr.Addr().As16()
+	binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], addr.Port())
+	size := unix.SizeofSockaddrInet6
+	copy(dst[:size], (*(*[unix.SizeofSockaddrInet6]byte)(unsafe.Pointer(&sa)))[:])
+	return uint32(size), nil
+}
+
 func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error {
 	var rsa unix.RawSockaddrInet6
 	rsa.Family = unix.AF_INET6

+ 10 - 0
udp/udp_linux_32.go

@@ -77,3 +77,13 @@ func getRawMessageFlags(msg *rawMessage) int {
 func setCmsgLen(h *unix.Cmsghdr, l int) {
 	h.Len = uint32(l)
 }
+
+func setIovecSlice(iov *iovec, b []byte) {
+	if len(b) == 0 {
+		iov.Base = nil
+		iov.Len = 0
+		return
+	}
+	iov.Base = &b[0]
+	iov.Len = uint32(len(b))
+}

+ 10 - 0
udp/udp_linux_64.go

@@ -80,3 +80,13 @@ func getRawMessageFlags(msg *rawMessage) int {
 func setCmsgLen(h *unix.Cmsghdr, l int) {
 	h.Len = uint64(l)
 }
+
+func setIovecSlice(iov *iovec, b []byte) {
+	if len(b) == 0 {
+		iov.Base = nil
+		iov.Len = 0
+		return
+	}
+	iov.Base = &b[0]
+	iov.Len = uint64(len(b))
+}

+ 11 - 0
udp/udp_rio_windows.go

@@ -304,6 +304,17 @@ func (u *RIOConn) WriteTo(buf []byte, ip netip.AddrPort) error {
 	return winrio.SendEx(u.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0)
 }
 
+func (u *RIOConn) WriteBatch(pkts []BatchPacket) (int, error) {
+	sent := 0
+	for _, pkt := range pkts {
+		if err := u.WriteTo(pkt.Payload, pkt.Addr); err != nil {
+			return sent, err
+		}
+		sent++
+	}
+	return sent, nil
+}
+
 func (u *RIOConn) LocalAddr() (netip.AddrPort, error) {
 	sa, err := windows.Getsockname(u.sock)
 	if err != nil {

+ 11 - 0
udp/udp_tester.go

@@ -106,6 +106,17 @@ func (u *TesterConn) WriteTo(b []byte, addr netip.AddrPort) error {
 	return nil
 }
 
+func (u *TesterConn) WriteBatch(pkts []BatchPacket) (int, error) {
+	sent := 0
+	for _, pkt := range pkts {
+		if err := u.WriteTo(pkt.Payload, pkt.Addr); err != nil {
+			return sent, err
+		}
+		sent++
+	}
+	return sent, nil
+}
+
 func (u *TesterConn) ListenOut(r EncReader) {
 	for {
 		p, ok := <-u.RxPackets