Browse Source

try with sendmmsg merged back

Ryan 1 month ago
parent
commit
c9a695c2bf
10 changed files with 318 additions and 44 deletions
  1. 52 42
      inside.go
  2. 81 2
      interface.go
  3. 9 0
      udp/conn.go
  4. 11 0
      udp/udp_darwin.go
  5. 11 0
      udp/udp_generic.go
  6. 112 0
      udp/udp_linux.go
  7. 10 0
      udp/udp_linux_32.go
  8. 10 0
      udp/udp_linux_64.go
  9. 11 0
      udp/udp_rio_windows.go
  10. 11 0
      udp/udp_tester.go

+ 52 - 42
inside.go

@@ -11,19 +11,19 @@ import (
 	"github.com/slackhq/nebula/routing"
 	"github.com/slackhq/nebula/routing"
 )
 )
 
 
-func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) {
+func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, queue func(netip.AddrPort, int), q int, localCache firewall.ConntrackCache) bool {
 	err := newPacket(packet, false, fwPacket)
 	err := newPacket(packet, false, fwPacket)
 	if err != nil {
 	if err != nil {
 		if f.l.Level >= logrus.DebugLevel {
 		if f.l.Level >= logrus.DebugLevel {
 			f.l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err)
 			f.l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err)
 		}
 		}
-		return
+		return false
 	}
 	}
 
 
 	// Ignore local broadcast packets
 	// Ignore local broadcast packets
 	if f.dropLocalBroadcast {
 	if f.dropLocalBroadcast {
 		if f.myBroadcastAddrsTable.Contains(fwPacket.RemoteAddr) {
 		if f.myBroadcastAddrsTable.Contains(fwPacket.RemoteAddr) {
-			return
+			return false
 		}
 		}
 	}
 	}
 
 
@@ -40,12 +40,12 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
 		}
 		}
 		// Otherwise, drop. On linux, we should never see these packets - Linux
 		// Otherwise, drop. On linux, we should never see these packets - Linux
 		// routes packets from the nebula addr to the nebula addr through the loopback device.
 		// routes packets from the nebula addr to the nebula addr through the loopback device.
-		return
+		return false
 	}
 	}
 
 
 	// Ignore multicast packets
 	// Ignore multicast packets
 	if f.dropMulticast && fwPacket.RemoteAddr.IsMulticast() {
 	if f.dropMulticast && fwPacket.RemoteAddr.IsMulticast() {
-		return
+		return false
 	}
 	}
 
 
 	hostinfo, ready := f.getOrHandshakeConsiderRouting(fwPacket, func(hh *HandshakeHostInfo) {
 	hostinfo, ready := f.getOrHandshakeConsiderRouting(fwPacket, func(hh *HandshakeHostInfo) {
@@ -59,26 +59,26 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
 				WithField("fwPacket", fwPacket).
 				WithField("fwPacket", fwPacket).
 				Debugln("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks")
 				Debugln("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks")
 		}
 		}
-		return
+		return false
 	}
 	}
 
 
 	if !ready {
 	if !ready {
-		return
+		return false
 	}
 	}
 
 
 	dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache)
 	dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache)
 	if dropReason == nil {
 	if dropReason == nil {
-		f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, q)
+		return f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, queue, q)
+	}
 
 
-	} else {
-		f.rejectInside(packet, out, q)
-		if f.l.Level >= logrus.DebugLevel {
-			hostinfo.logger(f.l).
-				WithField("fwPacket", fwPacket).
-				WithField("reason", dropReason).
-				Debugln("dropping outbound packet")
-		}
+	f.rejectInside(packet, out, q)
+	if f.l.Level >= logrus.DebugLevel {
+		hostinfo.logger(f.l).
+			WithField("fwPacket", fwPacket).
+			WithField("reason", dropReason).
+			Debugln("dropping outbound packet")
 	}
 	}
+	return false
 }
 }
 
 
 func (f *Interface) rejectInside(packet []byte, out []byte, q int) {
 func (f *Interface) rejectInside(packet []byte, out []byte, q int) {
@@ -117,7 +117,7 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo *
 		return
 		return
 	}
 	}
 
 
-	f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, q)
+	_ = f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, nil, q)
 }
 }
 
 
 // Handshake will attempt to initiate a tunnel with the provided vpn address if it is within our vpn networks. This is a no-op if the tunnel is already established or being established
 // Handshake will attempt to initiate a tunnel with the provided vpn address if it is within our vpn networks. This is a no-op if the tunnel is already established or being established
@@ -228,7 +228,7 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
 		return
 		return
 	}
 	}
 
 
-	f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, p, nb, out, 0)
+	_ = f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, p, nb, out, nil, 0)
 }
 }
 
 
 // SendMessageToVpnAddr handles real addr:port lookup and sends to the current best known address for vpnAddr
 // SendMessageToVpnAddr handles real addr:port lookup and sends to the current best known address for vpnAddr
@@ -258,12 +258,12 @@ func (f *Interface) SendMessageToHostInfo(t header.MessageType, st header.Messag
 
 
 func (f *Interface) send(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, p, nb, out []byte) {
 func (f *Interface) send(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, p, nb, out []byte) {
 	f.messageMetrics.Tx(t, st, 1)
 	f.messageMetrics.Tx(t, st, 1)
-	f.sendNoMetrics(t, st, ci, hostinfo, netip.AddrPort{}, p, nb, out, 0)
+	_ = f.sendNoMetrics(t, st, ci, hostinfo, netip.AddrPort{}, p, nb, out, nil, 0)
 }
 }
 
 
 func (f *Interface) sendTo(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb, out []byte) {
 func (f *Interface) sendTo(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb, out []byte) {
 	f.messageMetrics.Tx(t, st, 1)
 	f.messageMetrics.Tx(t, st, 1)
-	f.sendNoMetrics(t, st, ci, hostinfo, remote, p, nb, out, 0)
+	_ = f.sendNoMetrics(t, st, ci, hostinfo, remote, p, nb, out, nil, 0)
 }
 }
 
 
 // SendVia sends a payload through a Relay tunnel. No authentication or encryption is done
 // SendVia sends a payload through a Relay tunnel. No authentication or encryption is done
@@ -331,9 +331,12 @@ func (f *Interface) SendVia(via *HostInfo,
 	f.connectionManager.RelayUsed(relay.LocalIndex)
 	f.connectionManager.RelayUsed(relay.LocalIndex)
 }
 }
 
 
-func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb, out []byte, q int) {
+// sendNoMetrics encrypts and writes/queues an outbound packet. It returns true
+// when the payload has been handed to a caller-provided queue (meaning the
+// caller is responsible for flushing it later).
+func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb, out []byte, queue func(netip.AddrPort, int), q int) bool {
 	if ci.eKey == nil {
 	if ci.eKey == nil {
-		return
+		return false
 	}
 	}
 	useRelay := !remote.IsValid() && !hostinfo.remote.IsValid()
 	useRelay := !remote.IsValid() && !hostinfo.remote.IsValid()
 	fullOut := out
 	fullOut := out
@@ -380,32 +383,39 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
 			WithField("udpAddr", remote).WithField("counter", c).
 			WithField("udpAddr", remote).WithField("counter", c).
 			WithField("attemptedCounter", c).
 			WithField("attemptedCounter", c).
 			Error("Failed to encrypt outgoing packet")
 			Error("Failed to encrypt outgoing packet")
-		return
+		return false
 	}
 	}
 
 
-	if remote.IsValid() {
-		err = f.writers[q].WriteTo(out, remote)
-		if err != nil {
-			hostinfo.logger(f.l).WithError(err).
-				WithField("udpAddr", remote).Error("Failed to write outgoing packet")
+	dest := remote
+	if !dest.IsValid() {
+		dest = hostinfo.remote
+	}
+
+	if dest.IsValid() {
+		if queue != nil {
+			queue(dest, len(out))
+			return true
 		}
 		}
-	} else if hostinfo.remote.IsValid() {
-		err = f.writers[q].WriteTo(out, hostinfo.remote)
+
+		err = f.writers[q].WriteTo(out, dest)
 		if err != nil {
 		if err != nil {
 			hostinfo.logger(f.l).WithError(err).
 			hostinfo.logger(f.l).WithError(err).
-				WithField("udpAddr", remote).Error("Failed to write outgoing packet")
+				WithField("udpAddr", dest).Error("Failed to write outgoing packet")
 		}
 		}
-	} else {
-		// Try to send via a relay
-		for _, relayIP := range hostinfo.relayState.CopyRelayIps() {
-			relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP)
-			if err != nil {
-				hostinfo.relayState.DeleteRelay(relayIP)
-				hostinfo.logger(f.l).WithField("relay", relayIP).WithError(err).Info("sendNoMetrics failed to find HostInfo")
-				continue
-			}
-			f.SendVia(relayHostInfo, relay, out, nb, fullOut[:header.Len+len(out)], true)
-			break
+		return false
+	}
+
+	// Try to send via a relay
+	for _, relayIP := range hostinfo.relayState.CopyRelayIps() {
+		relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP)
+		if err != nil {
+			hostinfo.relayState.DeleteRelay(relayIP)
+			hostinfo.logger(f.l).WithField("relay", relayIP).WithError(err).Info("sendNoMetrics failed to find HostInfo")
+			continue
 		}
 		}
+		f.SendVia(relayHostInfo, relay, out, nb, fullOut[:header.Len+len(out)], true)
+		break
 	}
 	}
+
+	return false
 }
 }

+ 81 - 2
interface.go

@@ -29,6 +29,7 @@ const (
 	outboundBatchSizeDefault     = 32
 	outboundBatchSizeDefault     = 32
 	batchFlushIntervalDefault    = 50 * time.Microsecond
 	batchFlushIntervalDefault    = 50 * time.Microsecond
 	maxOutstandingBatchesDefault = 1028
 	maxOutstandingBatchesDefault = 1028
+	sendBatchSizeDefault         = 32
 )
 )
 
 
 type InterfaceConfig struct {
 type InterfaceConfig struct {
@@ -120,12 +121,21 @@ type Interface struct {
 	packetBatchPool   sync.Pool
 	packetBatchPool   sync.Pool
 	outboundBatchPool sync.Pool
 	outboundBatchPool sync.Pool
 
 
+	sendPool      sync.Pool
+	sendBatchSize int
+
 	inboundBatchSize      int
 	inboundBatchSize      int
 	outboundBatchSize     int
 	outboundBatchSize     int
 	batchFlushInterval    time.Duration
 	batchFlushInterval    time.Duration
 	maxOutstandingPerChan int
 	maxOutstandingPerChan int
 }
 }
 
 
+type outboundSend struct {
+	buf    *[]byte
+	length int
+	addr   netip.AddrPort
+}
+
 type packetBatch struct {
 type packetBatch struct {
 	packets []*packet.Packet
 	packets []*packet.Packet
 }
 }
@@ -194,6 +204,48 @@ func (f *Interface) releaseOutboundBatch(b *outboundBatch) {
 	f.outboundBatchPool.Put(b)
 	f.outboundBatchPool.Put(b)
 }
 }
 
 
+func (f *Interface) getSendBuffer() *[]byte {
+	if v := f.sendPool.Get(); v != nil {
+		buf := v.(*[]byte)
+		*buf = (*buf)[:0]
+		return buf
+	}
+	b := make([]byte, mtu)
+	return &b
+}
+
+func (f *Interface) releaseSendBuffer(buf *[]byte) {
+	if buf == nil {
+		return
+	}
+	*buf = (*buf)[:0]
+	f.sendPool.Put(buf)
+}
+
+func (f *Interface) flushSendQueue(q int, pending *[]outboundSend) {
+	if len(*pending) == 0 {
+		return
+	}
+
+	batch := make([]udp.BatchPacket, len(*pending))
+	for i, entry := range *pending {
+		batch[i] = udp.BatchPacket{
+			Payload: (*entry.buf)[:entry.length],
+			Addr:    entry.addr,
+		}
+	}
+
+	sent, err := f.writers[q].WriteBatch(batch)
+	if err != nil {
+		f.l.WithError(err).WithField("sent", sent).Error("Failed to batch send packets")
+	}
+
+	for _, entry := range *pending {
+		f.releaseSendBuffer(entry.buf)
+	}
+	*pending = (*pending)[:0]
+}
+
 type EncWriter interface {
 type EncWriter interface {
 	SendVia(via *HostInfo,
 	SendVia(via *HostInfo,
 		relay *Relay,
 		relay *Relay,
@@ -316,6 +368,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
 		outboundBatchSize:     bc.OutboundBatchSize,
 		outboundBatchSize:     bc.OutboundBatchSize,
 		batchFlushInterval:    bc.FlushInterval,
 		batchFlushInterval:    bc.FlushInterval,
 		maxOutstandingPerChan: bc.MaxOutstandingPerChan,
 		maxOutstandingPerChan: bc.MaxOutstandingPerChan,
+		sendBatchSize:         bc.OutboundBatchSize,
 	}
 	}
 
 
 	for i := 0; i < c.routines; i++ {
 	for i := 0; i < c.routines; i++ {
@@ -340,6 +393,11 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
 		return newOutboundBatch(ifce.outboundBatchSize)
 		return newOutboundBatch(ifce.outboundBatchSize)
 	}}
 	}}
 
 
+	ifce.sendPool = sync.Pool{New: func() any {
+		buf := make([]byte, mtu)
+		return &buf
+	}}
+
 	ifce.tryPromoteEvery.Store(c.tryPromoteEvery)
 	ifce.tryPromoteEvery.Store(c.tryPromoteEvery)
 	ifce.reQueryEvery.Store(c.reQueryEvery)
 	ifce.reQueryEvery.Store(c.reQueryEvery)
 	ifce.reQueryWait.Store(int64(c.reQueryWait))
 	ifce.reQueryWait.Store(int64(c.reQueryWait))
@@ -539,18 +597,39 @@ func (f *Interface) workerOut(i int, ctx context.Context) {
 	conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
 	conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
 	fwPacket1 := &firewall.Packet{}
 	fwPacket1 := &firewall.Packet{}
 	nb1 := make([]byte, 12, 12)
 	nb1 := make([]byte, 12, 12)
-	result1 := make([]byte, mtu)
+	pending := make([]outboundSend, 0, f.sendBatchSize)
 
 
 	for {
 	for {
 		select {
 		select {
 		case batch := <-f.outbound[i]:
 		case batch := <-f.outbound[i]:
 			for _, data := range batch.payloads {
 			for _, data := range batch.payloads {
-				f.consumeInsidePacket(*data, fwPacket1, nb1, result1, i, conntrackCache.Get(f.l))
+				sendBuf := f.getSendBuffer()
+				buf := (*sendBuf)[:0]
+				queue := func(addr netip.AddrPort, length int) {
+					pending = append(pending, outboundSend{
+						buf:    sendBuf,
+						length: length,
+						addr:   addr,
+					})
+					if len(pending) >= f.sendBatchSize {
+						f.flushSendQueue(i, &pending)
+					}
+				}
+				sent := f.consumeInsidePacket(*data, fwPacket1, nb1, buf, queue, i, conntrackCache.Get(f.l))
+				if !sent {
+					f.releaseSendBuffer(sendBuf)
+				}
 				*data = (*data)[:mtu]
 				*data = (*data)[:mtu]
 				f.outPool.Put(data)
 				f.outPool.Put(data)
 			}
 			}
 			f.releaseOutboundBatch(batch)
 			f.releaseOutboundBatch(batch)
+			if len(pending) > 0 {
+				f.flushSendQueue(i, &pending)
+			}
 		case <-ctx.Done():
 		case <-ctx.Done():
+			if len(pending) > 0 {
+				f.flushSendQueue(i, &pending)
+			}
 			f.wg.Done()
 			f.wg.Done()
 			return
 			return
 		}
 		}

+ 9 - 0
udp/conn.go

@@ -18,10 +18,16 @@ type Conn interface {
 	LocalAddr() (netip.AddrPort, error)
 	LocalAddr() (netip.AddrPort, error)
 	ListenOut(r EncReader) error
 	ListenOut(r EncReader) error
 	WriteTo(b []byte, addr netip.AddrPort) error
 	WriteTo(b []byte, addr netip.AddrPort) error
+	WriteBatch(pkts []BatchPacket) (int, error)
 	ReloadConfig(c *config.C)
 	ReloadConfig(c *config.C)
 	Close() error
 	Close() error
 }
 }
 
 
+type BatchPacket struct {
+	Payload []byte
+	Addr    netip.AddrPort
+}
+
 type NoopConn struct{}
 type NoopConn struct{}
 
 
 func (NoopConn) Rebind() error {
 func (NoopConn) Rebind() error {
@@ -36,6 +42,9 @@ func (NoopConn) ListenOut(_ EncReader) error {
 func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error {
 func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error {
 	return nil
 	return nil
 }
 }
+func (NoopConn) WriteBatch(_ []BatchPacket) (int, error) {
+	return 0, nil
+}
 func (NoopConn) ReloadConfig(_ *config.C) {
 func (NoopConn) ReloadConfig(_ *config.C) {
 	return
 	return
 }
 }

+ 11 - 0
udp/udp_darwin.go

@@ -140,6 +140,17 @@ func (u *StdConn) WriteTo(b []byte, ap netip.AddrPort) error {
 	}
 	}
 }
 }
 
 
+func (u *StdConn) WriteBatch(pkts []BatchPacket) (int, error) {
+	sent := 0
+	for _, pkt := range pkts {
+		if err := u.WriteTo(pkt.Payload, pkt.Addr); err != nil {
+			return sent, err
+		}
+		sent++
+	}
+	return sent, nil
+}
+
 func (u *StdConn) LocalAddr() (netip.AddrPort, error) {
 func (u *StdConn) LocalAddr() (netip.AddrPort, error) {
 	a := u.UDPConn.LocalAddr()
 	a := u.UDPConn.LocalAddr()
 
 

+ 11 - 0
udp/udp_generic.go

@@ -42,6 +42,17 @@ func (u *GenericConn) WriteTo(b []byte, addr netip.AddrPort) error {
 	return err
 	return err
 }
 }
 
 
+func (u *GenericConn) WriteBatch(pkts []BatchPacket) (int, error) {
+	sent := 0
+	for _, pkt := range pkts {
+		if err := u.WriteTo(pkt.Payload, pkt.Addr); err != nil {
+			return sent, err
+		}
+		sent++
+	}
+	return sent, nil
+}
+
 func (u *GenericConn) LocalAddr() (netip.AddrPort, error) {
 func (u *GenericConn) LocalAddr() (netip.AddrPort, error) {
 	a := u.UDPConn.LocalAddr()
 	a := u.UDPConn.LocalAddr()
 
 

+ 112 - 0
udp/udp_linux.go

@@ -343,6 +343,118 @@ func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error {
 	return u.writeTo6(b, ip)
 	return u.writeTo6(b, ip)
 }
 }
 
 
+func (u *StdConn) WriteBatch(pkts []BatchPacket) (int, error) {
+	if len(pkts) == 0 {
+		return 0, nil
+	}
+
+	msgs := make([]rawMessage, 0, len(pkts))
+	iovs := make([]iovec, 0, len(pkts))
+	names := make([][unix.SizeofSockaddrInet6]byte, 0, len(pkts))
+
+	sent := 0
+
+	for _, pkt := range pkts {
+		if len(pkt.Payload) == 0 {
+			sent++
+			continue
+		}
+
+		if u.enableGSO && pkt.Addr.IsValid() {
+			if err := u.queueGSOPacket(pkt.Payload, pkt.Addr); err == nil {
+				sent++
+				continue
+			} else if !errors.Is(err, errGSOFallback) {
+				return sent, err
+			}
+		}
+
+		if !pkt.Addr.IsValid() {
+			if err := u.WriteTo(pkt.Payload, pkt.Addr); err != nil {
+				return sent, err
+			}
+			sent++
+			continue
+		}
+
+		msgs = append(msgs, rawMessage{})
+		iovs = append(iovs, iovec{})
+		names = append(names, [unix.SizeofSockaddrInet6]byte{})
+
+		idx := len(msgs) - 1
+		msg := &msgs[idx]
+		iov := &iovs[idx]
+		name := &names[idx]
+
+		setIovecSlice(iov, pkt.Payload)
+		msg.Hdr.Iov = iov
+		msg.Hdr.Iovlen = 1
+		setRawMessageControl(msg, nil)
+		msg.Hdr.Flags = 0
+
+		nameLen, err := u.encodeSockaddr(name[:], pkt.Addr)
+		if err != nil {
+			return sent, err
+		}
+		msg.Hdr.Name = &name[0]
+		msg.Hdr.Namelen = nameLen
+	}
+
+	if len(msgs) == 0 {
+		return sent, nil
+	}
+
+	offset := 0
+	for offset < len(msgs) {
+		n, _, errno := unix.Syscall6(
+			unix.SYS_SENDMMSG,
+			uintptr(u.sysFd),
+			uintptr(unsafe.Pointer(&msgs[offset])),
+			uintptr(len(msgs)-offset),
+			0,
+			0,
+			0,
+		)
+
+		if errno != 0 {
+			if errno == unix.EINTR {
+				continue
+			}
+			return sent + offset, &net.OpError{Op: "sendmmsg", Err: errno}
+		}
+
+		if n == 0 {
+			break
+		}
+		offset += int(n)
+	}
+
+	return sent + len(msgs), nil
+}
+
+func (u *StdConn) encodeSockaddr(dst []byte, addr netip.AddrPort) (uint32, error) {
+	if u.isV4 {
+		if !addr.Addr().Is4() {
+			return 0, fmt.Errorf("Listener is IPv4, but writing to IPv6 remote")
+		}
+		var sa unix.RawSockaddrInet4
+		sa.Family = unix.AF_INET
+		sa.Addr = addr.Addr().As4()
+		binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], addr.Port())
+		size := unix.SizeofSockaddrInet4
+		copy(dst[:size], (*(*[unix.SizeofSockaddrInet4]byte)(unsafe.Pointer(&sa)))[:])
+		return uint32(size), nil
+	}
+
+	var sa unix.RawSockaddrInet6
+	sa.Family = unix.AF_INET6
+	sa.Addr = addr.Addr().As16()
+	binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], addr.Port())
+	size := unix.SizeofSockaddrInet6
+	copy(dst[:size], (*(*[unix.SizeofSockaddrInet6]byte)(unsafe.Pointer(&sa)))[:])
+	return uint32(size), nil
+}
+
 func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error {
 func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error {
 	var rsa unix.RawSockaddrInet6
 	var rsa unix.RawSockaddrInet6
 	rsa.Family = unix.AF_INET6
 	rsa.Family = unix.AF_INET6

+ 10 - 0
udp/udp_linux_32.go

@@ -77,3 +77,13 @@ func getRawMessageFlags(msg *rawMessage) int {
 func setCmsgLen(h *unix.Cmsghdr, l int) {
 func setCmsgLen(h *unix.Cmsghdr, l int) {
 	h.Len = uint32(l)
 	h.Len = uint32(l)
 }
 }
+
+func setIovecSlice(iov *iovec, b []byte) {
+	if len(b) == 0 {
+		iov.Base = nil
+		iov.Len = 0
+		return
+	}
+	iov.Base = &b[0]
+	iov.Len = uint32(len(b))
+}

+ 10 - 0
udp/udp_linux_64.go

@@ -80,3 +80,13 @@ func getRawMessageFlags(msg *rawMessage) int {
 func setCmsgLen(h *unix.Cmsghdr, l int) {
 func setCmsgLen(h *unix.Cmsghdr, l int) {
 	h.Len = uint64(l)
 	h.Len = uint64(l)
 }
 }
+
+func setIovecSlice(iov *iovec, b []byte) {
+	if len(b) == 0 {
+		iov.Base = nil
+		iov.Len = 0
+		return
+	}
+	iov.Base = &b[0]
+	iov.Len = uint64(len(b))
+}

+ 11 - 0
udp/udp_rio_windows.go

@@ -304,6 +304,17 @@ func (u *RIOConn) WriteTo(buf []byte, ip netip.AddrPort) error {
 	return winrio.SendEx(u.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0)
 	return winrio.SendEx(u.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0)
 }
 }
 
 
+func (u *RIOConn) WriteBatch(pkts []BatchPacket) (int, error) {
+	sent := 0
+	for _, pkt := range pkts {
+		if err := u.WriteTo(pkt.Payload, pkt.Addr); err != nil {
+			return sent, err
+		}
+		sent++
+	}
+	return sent, nil
+}
+
 func (u *RIOConn) LocalAddr() (netip.AddrPort, error) {
 func (u *RIOConn) LocalAddr() (netip.AddrPort, error) {
 	sa, err := windows.Getsockname(u.sock)
 	sa, err := windows.Getsockname(u.sock)
 	if err != nil {
 	if err != nil {

+ 11 - 0
udp/udp_tester.go

@@ -106,6 +106,17 @@ func (u *TesterConn) WriteTo(b []byte, addr netip.AddrPort) error {
 	return nil
 	return nil
 }
 }
 
 
+func (u *TesterConn) WriteBatch(pkts []BatchPacket) (int, error) {
+	sent := 0
+	for _, pkt := range pkts {
+		if err := u.WriteTo(pkt.Payload, pkt.Addr); err != nil {
+			return sent, err
+		}
+		sent++
+	}
+	return sent, nil
+}
+
 func (u *TesterConn) ListenOut(r EncReader) {
 func (u *TesterConn) ListenOut(r EncReader) {
 	for {
 	for {
 		p, ok := <-u.RxPackets
 		p, ok := <-u.RxPackets