Browse Source

hmmmmmm it works i guess maybe

Ryan 1 month ago
parent
commit
aa44f4c7c9

+ 164 - 0
batch_pipeline.go

@@ -0,0 +1,164 @@
+package nebula
+
+import (
+	"net/netip"
+
+	"github.com/slackhq/nebula/overlay"
+	"github.com/slackhq/nebula/udp"
+)
+
+// batchPipelines tracks whether the inside device can operate on packet batches
+// and, if so, holds the shared packet pool sized for the virtio headroom and
+// payload limits advertised by the device. It also owns the fan-in/fan-out
+// queues between the TUN readers, encrypt/decrypt workers, and the UDP writers.
+type batchPipelines struct {
+	enabled    bool
+	inside     overlay.BatchCapableDevice
+	headroom   int
+	payloadCap int
+	pool       *overlay.PacketPool
+	batchSize  int
+	routines   int
+	rxQueues   []chan *overlay.Packet
+	txQueues   []chan queuedDatagram
+	tunQueues  []chan *overlay.Packet
+}
+
+type queuedDatagram struct {
+	packet *overlay.Packet
+	addr   netip.AddrPort
+}
+
+func (bp *batchPipelines) init(device overlay.Device, routines int, queueDepth int, maxSegments int) {
+	if device == nil || routines <= 0 {
+		return
+	}
+	bcap, ok := device.(overlay.BatchCapableDevice)
+	if !ok {
+		return
+	}
+	headroom := bcap.BatchHeadroom()
+	payload := bcap.BatchPayloadCap()
+	if maxSegments < 1 {
+		maxSegments = 1
+	}
+	requiredPayload := udp.MTU * maxSegments
+	if payload < requiredPayload {
+		payload = requiredPayload
+	}
+	batchSize := bcap.BatchSize()
+	if headroom <= 0 || payload <= 0 || batchSize <= 0 {
+		return
+	}
+	bp.enabled = true
+	bp.inside = bcap
+	bp.headroom = headroom
+	bp.payloadCap = payload
+	bp.batchSize = batchSize
+	bp.routines = routines
+	bp.pool = overlay.NewPacketPool(headroom, payload)
+	queueCap := batchSize * defaultBatchQueueDepthFactor
+	if queueDepth > 0 {
+		queueCap = queueDepth
+	}
+	if queueCap < batchSize {
+		queueCap = batchSize
+	}
+	bp.rxQueues = make([]chan *overlay.Packet, routines)
+	bp.txQueues = make([]chan queuedDatagram, routines)
+	bp.tunQueues = make([]chan *overlay.Packet, routines)
+	for i := 0; i < routines; i++ {
+		bp.rxQueues[i] = make(chan *overlay.Packet, queueCap)
+		bp.txQueues[i] = make(chan queuedDatagram, queueCap)
+		bp.tunQueues[i] = make(chan *overlay.Packet, queueCap)
+	}
+}
+
+func (bp *batchPipelines) Pool() *overlay.PacketPool {
+	if bp == nil || !bp.enabled {
+		return nil
+	}
+	return bp.pool
+}
+
+func (bp *batchPipelines) Enabled() bool {
+	return bp != nil && bp.enabled
+}
+
+func (bp *batchPipelines) batchSizeHint() int {
+	if bp == nil || bp.batchSize <= 0 {
+		return 1
+	}
+	return bp.batchSize
+}
+
+func (bp *batchPipelines) rxQueue(i int) chan *overlay.Packet {
+	if bp == nil || !bp.enabled || i < 0 || i >= len(bp.rxQueues) {
+		return nil
+	}
+	return bp.rxQueues[i]
+}
+
+func (bp *batchPipelines) txQueue(i int) chan queuedDatagram {
+	if bp == nil || !bp.enabled || i < 0 || i >= len(bp.txQueues) {
+		return nil
+	}
+	return bp.txQueues[i]
+}
+
+func (bp *batchPipelines) tunQueue(i int) chan *overlay.Packet {
+	if bp == nil || !bp.enabled || i < 0 || i >= len(bp.tunQueues) {
+		return nil
+	}
+	return bp.tunQueues[i]
+}
+
+func (bp *batchPipelines) txQueueLen(i int) int {
+	q := bp.txQueue(i)
+	if q == nil {
+		return 0
+	}
+	return len(q)
+}
+
+func (bp *batchPipelines) tunQueueLen(i int) int {
+	q := bp.tunQueue(i)
+	if q == nil {
+		return 0
+	}
+	return len(q)
+}
+
+func (bp *batchPipelines) enqueueRx(i int, pkt *overlay.Packet) bool {
+	q := bp.rxQueue(i)
+	if q == nil {
+		return false
+	}
+	q <- pkt
+	return true
+}
+
+func (bp *batchPipelines) enqueueTx(i int, pkt *overlay.Packet, addr netip.AddrPort) bool {
+	q := bp.txQueue(i)
+	if q == nil {
+		return false
+	}
+	q <- queuedDatagram{packet: pkt, addr: addr}
+	return true
+}
+
+func (bp *batchPipelines) enqueueTun(i int, pkt *overlay.Packet) bool {
+	q := bp.tunQueue(i)
+	if q == nil {
+		return false
+	}
+	q <- pkt
+	return true
+}
+
+func (bp *batchPipelines) newPacket() *overlay.Packet {
+	if bp == nil || !bp.enabled || bp.pool == nil {
+		return nil
+	}
+	return bp.pool.Get()
+}

+ 61 - 23
inside.go

@@ -8,6 +8,7 @@ import (
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/noiseutil"
 	"github.com/slackhq/nebula/noiseutil"
+	"github.com/slackhq/nebula/overlay"
 	"github.com/slackhq/nebula/routing"
 	"github.com/slackhq/nebula/routing"
 )
 )
 
 
@@ -335,9 +336,21 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
 	if ci.eKey == nil {
 	if ci.eKey == nil {
 		return
 		return
 	}
 	}
-	useRelay := !remote.IsValid() && !hostinfo.remote.IsValid()
+	target := remote
+	if !target.IsValid() {
+		target = hostinfo.remote
+	}
+	useRelay := !target.IsValid()
 	fullOut := out
 	fullOut := out
 
 
+	var pkt *overlay.Packet
+	if !useRelay && f.batches.Enabled() {
+		pkt = f.batches.newPacket()
+		if pkt != nil {
+			out = pkt.Payload()[:0]
+		}
+	}
+
 	if useRelay {
 	if useRelay {
 		if len(out) < header.Len {
 		if len(out) < header.Len {
 			// out always has a capacity of mtu, but not always a length greater than the header.Len.
 			// out always has a capacity of mtu, but not always a length greater than the header.Len.
@@ -376,36 +389,61 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
 		ci.writeLock.Unlock()
 		ci.writeLock.Unlock()
 	}
 	}
 	if err != nil {
 	if err != nil {
+		if pkt != nil {
+			pkt.Release()
+		}
 		hostinfo.logger(f.l).WithError(err).
 		hostinfo.logger(f.l).WithError(err).
-			WithField("udpAddr", remote).WithField("counter", c).
+			WithField("udpAddr", target).WithField("counter", c).
 			WithField("attemptedCounter", c).
 			WithField("attemptedCounter", c).
 			Error("Failed to encrypt outgoing packet")
 			Error("Failed to encrypt outgoing packet")
 		return
 		return
 	}
 	}
 
 
-	if remote.IsValid() {
-		err = f.writers[q].WriteTo(out, remote)
-		if err != nil {
-			hostinfo.logger(f.l).WithError(err).
-				WithField("udpAddr", remote).Error("Failed to write outgoing packet")
+	if target.IsValid() {
+		if pkt != nil {
+			pkt.Len = len(out)
+			if f.l.Level >= logrus.DebugLevel {
+				f.l.WithFields(logrus.Fields{
+					"queue":        q,
+					"dest":         target,
+					"payload_len":  pkt.Len,
+					"use_batches":  true,
+					"remote_index": hostinfo.remoteIndexId,
+				}).Debug("enqueueing packet to UDP batch queue")
+			}
+			if f.tryQueuePacket(q, pkt, target) {
+				return
+			}
+			if f.l.Level >= logrus.DebugLevel {
+				f.l.WithFields(logrus.Fields{
+					"queue": q,
+					"dest":  target,
+				}).Debug("failed to enqueue packet; falling back to immediate send")
+			}
+			f.writeImmediatePacket(q, pkt, target, hostinfo)
+			return
 		}
 		}
-	} else if hostinfo.remote.IsValid() {
-		err = f.writers[q].WriteTo(out, hostinfo.remote)
-		if err != nil {
-			hostinfo.logger(f.l).WithError(err).
-				WithField("udpAddr", remote).Error("Failed to write outgoing packet")
+		if f.tryQueueDatagram(q, out, target) {
+			return
 		}
 		}
-	} else {
-		// Try to send via a relay
-		for _, relayIP := range hostinfo.relayState.CopyRelayIps() {
-			relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP)
-			if err != nil {
-				hostinfo.relayState.DeleteRelay(relayIP)
-				hostinfo.logger(f.l).WithField("relay", relayIP).WithError(err).Info("sendNoMetrics failed to find HostInfo")
-				continue
-			}
-			f.SendVia(relayHostInfo, relay, out, nb, fullOut[:header.Len+len(out)], true)
-			break
+		f.writeImmediate(q, out, target, hostinfo)
+		return
+	}
+
+	// fall back to relay path
+	if pkt != nil {
+		pkt.Release()
+	}
+
+	// Try to send via a relay
+	for _, relayIP := range hostinfo.relayState.CopyRelayIps() {
+		relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP)
+		if err != nil {
+			hostinfo.relayState.DeleteRelay(relayIP)
+			hostinfo.logger(f.l).WithField("relay", relayIP).WithError(err).Info("sendNoMetrics failed to find HostInfo")
+			continue
 		}
 		}
+		f.SendVia(relayHostInfo, relay, out, nb, fullOut[:header.Len+len(out)], true)
+		break
 	}
 	}
 }
 }

+ 602 - 1
interface.go

@@ -21,7 +21,13 @@ import (
 	"github.com/slackhq/nebula/udp"
 	"github.com/slackhq/nebula/udp"
 )
 )
 
 
-const mtu = 9001
+const (
+	mtu                          = 9001
+	defaultGSOFlushInterval      = 150 * time.Microsecond
+	defaultBatchQueueDepthFactor = 4
+	defaultGSOMaxSegments        = 8
+	maxKernelGSOSegments         = 64
+)
 
 
 type InterfaceConfig struct {
 type InterfaceConfig struct {
 	HostMap            *HostMap
 	HostMap            *HostMap
@@ -36,6 +42,9 @@ type InterfaceConfig struct {
 	connectionManager  *connectionManager
 	connectionManager  *connectionManager
 	DropLocalBroadcast bool
 	DropLocalBroadcast bool
 	DropMulticast      bool
 	DropMulticast      bool
+	EnableGSO          bool
+	EnableGRO          bool
+	GSOMaxSegments     int
 	routines           int
 	routines           int
 	MessageMetrics     *MessageMetrics
 	MessageMetrics     *MessageMetrics
 	version            string
 	version            string
@@ -47,6 +56,8 @@ type InterfaceConfig struct {
 	reQueryWait     time.Duration
 	reQueryWait     time.Duration
 
 
 	ConntrackCacheTimeout time.Duration
 	ConntrackCacheTimeout time.Duration
+	BatchFlushInterval    time.Duration
+	BatchQueueDepth       int
 	l                     *logrus.Logger
 	l                     *logrus.Logger
 }
 }
 
 
@@ -84,9 +95,20 @@ type Interface struct {
 	version     string
 	version     string
 
 
 	conntrackCacheTimeout time.Duration
 	conntrackCacheTimeout time.Duration
+	batchQueueDepth       int
+	enableGSO             bool
+	enableGRO             bool
+	gsoMaxSegments        int
+	batchUDPQueueGauge    metrics.Gauge
+	batchUDPFlushCounter  metrics.Counter
+	batchTunQueueGauge    metrics.Gauge
+	batchTunFlushCounter  metrics.Counter
+	batchFlushInterval    atomic.Int64
+	sendSem               chan struct{}
 
 
 	writers []udp.Conn
 	writers []udp.Conn
 	readers []io.ReadWriteCloser
 	readers []io.ReadWriteCloser
+	batches batchPipelines
 
 
 	metricHandshakes    metrics.Histogram
 	metricHandshakes    metrics.Histogram
 	messageMetrics      *MessageMetrics
 	messageMetrics      *MessageMetrics
@@ -161,6 +183,22 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
 		return nil, errors.New("no connection manager")
 		return nil, errors.New("no connection manager")
 	}
 	}
 
 
+	if c.GSOMaxSegments <= 0 {
+		c.GSOMaxSegments = defaultGSOMaxSegments
+	}
+	if c.GSOMaxSegments > maxKernelGSOSegments {
+		c.GSOMaxSegments = maxKernelGSOSegments
+	}
+	if c.BatchQueueDepth <= 0 {
+		c.BatchQueueDepth = c.routines * defaultBatchQueueDepthFactor
+	}
+	if c.BatchFlushInterval < 0 {
+		c.BatchFlushInterval = 0
+	}
+	if c.BatchFlushInterval == 0 && c.EnableGSO {
+		c.BatchFlushInterval = defaultGSOFlushInterval
+	}
+
 	cs := c.pki.getCertState()
 	cs := c.pki.getCertState()
 	ifce := &Interface{
 	ifce := &Interface{
 		pki:                   c.pki,
 		pki:                   c.pki,
@@ -186,6 +224,10 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
 		relayManager:          c.relayManager,
 		relayManager:          c.relayManager,
 		connectionManager:     c.connectionManager,
 		connectionManager:     c.connectionManager,
 		conntrackCacheTimeout: c.ConntrackCacheTimeout,
 		conntrackCacheTimeout: c.ConntrackCacheTimeout,
+		batchQueueDepth:       c.BatchQueueDepth,
+		enableGSO:             c.EnableGSO,
+		enableGRO:             c.EnableGRO,
+		gsoMaxSegments:        c.GSOMaxSegments,
 
 
 		metricHandshakes: metrics.GetOrRegisterHistogram("handshakes", nil, metrics.NewExpDecaySample(1028, 0.015)),
 		metricHandshakes: metrics.GetOrRegisterHistogram("handshakes", nil, metrics.NewExpDecaySample(1028, 0.015)),
 		messageMetrics:   c.MessageMetrics,
 		messageMetrics:   c.MessageMetrics,
@@ -198,8 +240,25 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
 	}
 	}
 
 
 	ifce.tryPromoteEvery.Store(c.tryPromoteEvery)
 	ifce.tryPromoteEvery.Store(c.tryPromoteEvery)
+	ifce.batchUDPQueueGauge = metrics.GetOrRegisterGauge("batch.udp.queue_depth", nil)
+	ifce.batchUDPFlushCounter = metrics.GetOrRegisterCounter("batch.udp.flushes", nil)
+	ifce.batchTunQueueGauge = metrics.GetOrRegisterGauge("batch.tun.queue_depth", nil)
+	ifce.batchTunFlushCounter = metrics.GetOrRegisterCounter("batch.tun.flushes", nil)
+	ifce.batchFlushInterval.Store(int64(c.BatchFlushInterval))
+	ifce.sendSem = make(chan struct{}, c.routines)
+	ifce.batches.init(c.Inside, c.routines, c.BatchQueueDepth, c.GSOMaxSegments)
 	ifce.reQueryEvery.Store(c.reQueryEvery)
 	ifce.reQueryEvery.Store(c.reQueryEvery)
 	ifce.reQueryWait.Store(int64(c.reQueryWait))
 	ifce.reQueryWait.Store(int64(c.reQueryWait))
+	if c.l.Level >= logrus.DebugLevel {
+		c.l.WithFields(logrus.Fields{
+			"enableGSO":       c.EnableGSO,
+			"enableGRO":       c.EnableGRO,
+			"gsoMaxSegments":  c.GSOMaxSegments,
+			"batchQueueDepth": c.BatchQueueDepth,
+			"batchFlush":      c.BatchFlushInterval,
+			"batching":        ifce.batches.Enabled(),
+		}).Debug("initialized batch pipelines")
+	}
 
 
 	ifce.connectionManager.intf = ifce
 	ifce.connectionManager.intf = ifce
 
 
@@ -248,6 +307,18 @@ func (f *Interface) run() {
 		go f.listenOut(i)
 		go f.listenOut(i)
 	}
 	}
 
 
+	if f.l.Level >= logrus.DebugLevel {
+		f.l.WithField("batching", f.batches.Enabled()).Debug("starting interface run loops")
+	}
+
+	if f.batches.Enabled() {
+		for i := 0; i < f.routines; i++ {
+			go f.runInsideBatchWorker(i)
+			go f.runTunWriteQueue(i)
+			go f.runSendQueue(i)
+		}
+	}
+
 	// Launch n queues to read packets from tun dev
 	// Launch n queues to read packets from tun dev
 	for i := 0; i < f.routines; i++ {
 	for i := 0; i < f.routines; i++ {
 		go f.listenIn(f.readers[i], i)
 		go f.listenIn(f.readers[i], i)
@@ -279,6 +350,17 @@ func (f *Interface) listenOut(i int) {
 func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
 func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
 	runtime.LockOSThread()
 	runtime.LockOSThread()
 
 
+	if f.batches.Enabled() {
+		if br, ok := reader.(overlay.BatchReader); ok {
+			f.listenInBatchLocked(reader, br, i)
+			return
+		}
+	}
+
+	f.listenInLegacyLocked(reader, i)
+}
+
+func (f *Interface) listenInLegacyLocked(reader io.ReadWriteCloser, i int) {
 	packet := make([]byte, mtu)
 	packet := make([]byte, mtu)
 	out := make([]byte, mtu)
 	out := make([]byte, mtu)
 	fwPacket := &firewall.Packet{}
 	fwPacket := &firewall.Packet{}
@@ -302,6 +384,489 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
 	}
 	}
 }
 }
 
 
+func (f *Interface) listenInBatchLocked(raw io.ReadWriteCloser, reader overlay.BatchReader, i int) {
+	pool := f.batches.Pool()
+	if pool == nil {
+		f.l.Warn("batch pipeline enabled without an allocated pool; falling back to single-packet reads")
+		f.listenInLegacyLocked(raw, i)
+		return
+	}
+
+	for {
+		packets, err := reader.ReadIntoBatch(pool)
+		if err != nil {
+			if errors.Is(err, os.ErrClosed) && f.closed.Load() {
+				return
+			}
+
+			f.l.WithError(err).Error("Error while reading outbound packet batch")
+			os.Exit(2)
+		}
+
+		if len(packets) == 0 {
+			continue
+		}
+
+		for _, pkt := range packets {
+			if pkt == nil {
+				continue
+			}
+			if !f.batches.enqueueRx(i, pkt) {
+				pkt.Release()
+			}
+		}
+	}
+}
+
+func (f *Interface) runInsideBatchWorker(i int) {
+	queue := f.batches.rxQueue(i)
+	if queue == nil {
+		return
+	}
+
+	out := make([]byte, mtu)
+	fwPacket := &firewall.Packet{}
+	nb := make([]byte, 12, 12)
+	conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
+
+	for pkt := range queue {
+		if pkt == nil {
+			continue
+		}
+		f.consumeInsidePacket(pkt.Payload(), fwPacket, nb, out, i, conntrackCache.Get(f.l))
+		pkt.Release()
+	}
+}
+
+func (f *Interface) runSendQueue(i int) {
+	queue := f.batches.txQueue(i)
+	if queue == nil {
+		if f.l.Level >= logrus.DebugLevel {
+			f.l.WithField("queue", i).Debug("tx queue not initialized; batching disabled for writer")
+		}
+		return
+	}
+	writer := f.writerForIndex(i)
+	if writer == nil {
+		if f.l.Level >= logrus.DebugLevel {
+			f.l.WithField("queue", i).Debug("no UDP writer for batch queue")
+		}
+		return
+	}
+	if f.l.Level >= logrus.DebugLevel {
+		f.l.WithField("queue", i).Debug("send queue worker started")
+	}
+	defer func() {
+		if f.l.Level >= logrus.WarnLevel {
+			f.l.WithField("queue", i).Warn("send queue worker exited")
+		}
+	}()
+
+	batchCap := f.batches.batchSizeHint()
+	if batchCap <= 0 {
+		batchCap = 1
+	}
+	gsoLimit := f.effectiveGSOMaxSegments()
+	if gsoLimit > batchCap {
+		batchCap = gsoLimit
+	}
+	pending := make([]queuedDatagram, 0, batchCap)
+	var (
+		flushTimer *time.Timer
+		flushC     <-chan time.Time
+	)
+	dispatch := func(reason string, timerFired bool) {
+		if len(pending) == 0 {
+			return
+		}
+		batch := pending
+		f.flushAndReleaseBatch(i, writer, batch, reason)
+		for idx := range batch {
+			batch[idx] = queuedDatagram{}
+		}
+		pending = pending[:0]
+		if flushTimer != nil {
+			if !timerFired {
+				if !flushTimer.Stop() {
+					select {
+					case <-flushTimer.C:
+					default:
+					}
+				}
+			}
+			flushTimer = nil
+			flushC = nil
+		}
+	}
+	armTimer := func() {
+		delay := f.currentBatchFlushInterval()
+		if delay <= 0 {
+			dispatch("nogso", false)
+			return
+		}
+		if flushTimer == nil {
+			flushTimer = time.NewTimer(delay)
+			flushC = flushTimer.C
+		}
+	}
+
+	for {
+		select {
+		case d := <-queue:
+			if d.packet == nil {
+				continue
+			}
+			if f.l.Level >= logrus.DebugLevel {
+				f.l.WithFields(logrus.Fields{
+					"queue":       i,
+					"payload_len": d.packet.Len,
+					"dest":        d.addr,
+				}).Debug("send queue received packet")
+			}
+			pending = append(pending, d)
+			if gsoLimit > 0 && len(pending) >= gsoLimit {
+				dispatch("gso", false)
+				continue
+			}
+			if len(pending) >= cap(pending) {
+				dispatch("cap", false)
+				continue
+			}
+			armTimer()
+			f.observeUDPQueueLen(i)
+		case <-flushC:
+			dispatch("timer", true)
+		}
+	}
+}
+
+func (f *Interface) runTunWriteQueue(i int) {
+	queue := f.batches.tunQueue(i)
+	if queue == nil {
+		return
+	}
+	writer := f.batches.inside
+	if writer == nil {
+		return
+	}
+
+	batchCap := f.batches.batchSizeHint()
+	if batchCap <= 0 {
+		batchCap = 1
+	}
+	pending := make([]*overlay.Packet, 0, batchCap)
+	var (
+		flushTimer *time.Timer
+		flushC     <-chan time.Time
+	)
+	flush := func(reason string, timerFired bool) {
+		if len(pending) == 0 {
+			return
+		}
+		if _, err := writer.WriteBatch(pending); err != nil {
+			f.l.WithError(err).
+				WithField("queue", i).
+				WithField("reason", reason).
+				Warn("Failed to write tun batch")
+		}
+		for idx := range pending {
+			if pending[idx] != nil {
+				pending[idx].Release()
+			}
+		}
+		pending = pending[:0]
+		if flushTimer != nil {
+			if !timerFired {
+				if !flushTimer.Stop() {
+					select {
+					case <-flushTimer.C:
+					default:
+					}
+				}
+			}
+			flushTimer = nil
+			flushC = nil
+		}
+	}
+	armTimer := func() {
+		delay := f.currentBatchFlushInterval()
+		if delay <= 0 {
+			return
+		}
+		if flushTimer == nil {
+			flushTimer = time.NewTimer(delay)
+			flushC = flushTimer.C
+		}
+	}
+
+	for {
+		select {
+		case pkt := <-queue:
+			if pkt == nil {
+				continue
+			}
+			pending = append(pending, pkt)
+			if len(pending) >= cap(pending) {
+				flush("cap", false)
+				continue
+			}
+			armTimer()
+			f.observeTunQueueLen(i)
+		case <-flushC:
+			flush("timer", true)
+		}
+	}
+}
+
+func (f *Interface) flushAndReleaseBatch(index int, writer udp.Conn, batch []queuedDatagram, reason string) {
+	if len(batch) == 0 {
+		return
+	}
+	f.flushDatagrams(index, writer, batch, reason)
+	for idx := range batch {
+		if batch[idx].packet != nil {
+			batch[idx].packet.Release()
+			batch[idx].packet = nil
+		}
+	}
+	if f.batchUDPFlushCounter != nil {
+		f.batchUDPFlushCounter.Inc(int64(len(batch)))
+	}
+}
+
+func (f *Interface) flushDatagrams(index int, writer udp.Conn, batch []queuedDatagram, reason string) {
+	if len(batch) == 0 {
+		return
+	}
+	if f.l.Level >= logrus.DebugLevel {
+		f.l.WithFields(logrus.Fields{
+			"writer":  index,
+			"reason":  reason,
+			"pending": len(batch),
+		}).Debug("udp batch flush summary")
+	}
+	maxSeg := f.effectiveGSOMaxSegments()
+	if bw, ok := writer.(udp.BatchConn); ok {
+		chunkCap := maxSeg
+		if chunkCap <= 0 {
+			chunkCap = len(batch)
+		}
+		chunk := make([]udp.Datagram, 0, chunkCap)
+		var (
+			currentAddr netip.AddrPort
+			segments    int
+		)
+		flushChunk := func() {
+			if len(chunk) == 0 {
+				return
+			}
+			if f.l.Level >= logrus.DebugLevel {
+				f.l.WithFields(logrus.Fields{
+					"writer":        index,
+					"segments":      len(chunk),
+					"dest":          chunk[0].Addr,
+					"reason":        reason,
+					"pending_total": len(batch),
+				}).Debug("flushing UDP batch")
+			}
+			if err := bw.WriteBatch(chunk); err != nil {
+				f.l.WithError(err).
+					WithField("writer", index).
+					WithField("reason", reason).
+					Warn("Failed to write UDP batch")
+			}
+			chunk = chunk[:0]
+			segments = 0
+		}
+		for _, item := range batch {
+			if item.packet == nil || !item.addr.IsValid() {
+				continue
+			}
+			payload := item.packet.Payload()[:item.packet.Len]
+			if segments == 0 {
+				currentAddr = item.addr
+			}
+			if item.addr != currentAddr || (maxSeg > 0 && segments >= maxSeg) {
+				flushChunk()
+				currentAddr = item.addr
+			}
+			chunk = append(chunk, udp.Datagram{Payload: payload, Addr: item.addr})
+			segments++
+		}
+		flushChunk()
+		return
+	}
+	for _, item := range batch {
+		if item.packet == nil || !item.addr.IsValid() {
+			continue
+		}
+		if f.l.Level >= logrus.DebugLevel {
+			f.l.WithFields(logrus.Fields{
+				"writer":   index,
+				"reason":   reason,
+				"dest":     item.addr,
+				"segments": 1,
+			}).Debug("flushing UDP batch")
+		}
+		if err := writer.WriteTo(item.packet.Payload()[:item.packet.Len], item.addr); err != nil {
+			f.l.WithError(err).
+				WithField("writer", index).
+				WithField("udpAddr", item.addr).
+				WithField("reason", reason).
+				Warn("Failed to write UDP packet")
+		}
+	}
+}
+
+func (f *Interface) tryQueueDatagram(q int, buf []byte, addr netip.AddrPort) bool {
+	if !addr.IsValid() || !f.batches.Enabled() {
+		return false
+	}
+	pkt := f.batches.newPacket()
+	if pkt == nil {
+		return false
+	}
+	payload := pkt.Payload()
+	if len(payload) < len(buf) {
+		pkt.Release()
+		return false
+	}
+	copy(payload, buf)
+	pkt.Len = len(buf)
+	if f.batches.enqueueTx(q, pkt, addr) {
+		f.observeUDPQueueLen(q)
+		return true
+	}
+	pkt.Release()
+	return false
+}
+
+func (f *Interface) writerForIndex(i int) udp.Conn {
+	if i < 0 || i >= len(f.writers) {
+		return nil
+	}
+	return f.writers[i]
+}
+
+func (f *Interface) writeImmediate(q int, buf []byte, addr netip.AddrPort, hostinfo *HostInfo) {
+	writer := f.writerForIndex(q)
+	if writer == nil {
+		f.l.WithField("udpAddr", addr).
+			WithField("writer", q).
+			Error("Failed to write outgoing packet: no writer available")
+		return
+	}
+	if err := writer.WriteTo(buf, addr); err != nil {
+		hostinfo.logger(f.l).
+			WithError(err).
+			WithField("udpAddr", addr).
+			Error("Failed to write outgoing packet")
+	}
+}
+
+func (f *Interface) tryQueuePacket(q int, pkt *overlay.Packet, addr netip.AddrPort) bool {
+	if pkt == nil || !addr.IsValid() || !f.batches.Enabled() {
+		return false
+	}
+	if f.batches.enqueueTx(q, pkt, addr) {
+		f.observeUDPQueueLen(q)
+		return true
+	}
+	return false
+}
+
+func (f *Interface) writeImmediatePacket(q int, pkt *overlay.Packet, addr netip.AddrPort, hostinfo *HostInfo) {
+	if pkt == nil {
+		return
+	}
+	writer := f.writerForIndex(q)
+	if writer == nil {
+		f.l.WithField("udpAddr", addr).
+			WithField("writer", q).
+			Error("Failed to write outgoing packet: no writer available")
+		pkt.Release()
+		return
+	}
+	if err := writer.WriteTo(pkt.Payload()[:pkt.Len], addr); err != nil {
+		hostinfo.logger(f.l).
+			WithError(err).
+			WithField("udpAddr", addr).
+			Error("Failed to write outgoing packet")
+	}
+	pkt.Release()
+}
+
+func (f *Interface) writePacketToTun(q int, pkt *overlay.Packet) {
+	if pkt == nil {
+		return
+	}
+	writer := f.readers[q]
+	if writer == nil {
+		pkt.Release()
+		return
+	}
+	if _, err := writer.Write(pkt.Payload()[:pkt.Len]); err != nil {
+		f.l.WithError(err).Error("Failed to write to tun")
+	}
+	pkt.Release()
+}
+
+func (f *Interface) observeUDPQueueLen(i int) {
+	if f.batchUDPQueueGauge == nil {
+		return
+	}
+	f.batchUDPQueueGauge.Update(int64(f.batches.txQueueLen(i)))
+}
+
+func (f *Interface) observeTunQueueLen(i int) {
+	if f.batchTunQueueGauge == nil {
+		return
+	}
+	f.batchTunQueueGauge.Update(int64(f.batches.tunQueueLen(i)))
+}
+
+func (f *Interface) currentBatchFlushInterval() time.Duration {
+	if v := f.batchFlushInterval.Load(); v > 0 {
+		return time.Duration(v)
+	}
+	return 0
+}
+
+func (f *Interface) effectiveGSOMaxSegments() int {
+	max := f.gsoMaxSegments
+	if max <= 0 {
+		max = defaultGSOMaxSegments
+	}
+	if max > maxKernelGSOSegments {
+		max = maxKernelGSOSegments
+	}
+	if !f.enableGSO {
+		return 1
+	}
+	return max
+}
+
+type udpOffloadConfigurator interface {
+	ConfigureOffload(enableGSO, enableGRO bool, maxSegments int)
+}
+
+func (f *Interface) applyOffloadConfig(enableGSO, enableGRO bool, maxSegments int) {
+	if maxSegments <= 0 {
+		maxSegments = defaultGSOMaxSegments
+	}
+	if maxSegments > maxKernelGSOSegments {
+		maxSegments = maxKernelGSOSegments
+	}
+	f.enableGSO = enableGSO
+	f.enableGRO = enableGRO
+	f.gsoMaxSegments = maxSegments
+	for _, writer := range f.writers {
+		if cfg, ok := writer.(udpOffloadConfigurator); ok {
+			cfg.ConfigureOffload(enableGSO, enableGRO, maxSegments)
+		}
+	}
+}
+
 func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) {
 func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) {
 	c.RegisterReloadCallback(f.reloadFirewall)
 	c.RegisterReloadCallback(f.reloadFirewall)
 	c.RegisterReloadCallback(f.reloadSendRecvError)
 	c.RegisterReloadCallback(f.reloadSendRecvError)
@@ -404,6 +969,42 @@ func (f *Interface) reloadMisc(c *config.C) {
 		f.reQueryWait.Store(int64(n))
 		f.reQueryWait.Store(int64(n))
 		f.l.Info("timers.requery_wait_duration has changed")
 		f.l.Info("timers.requery_wait_duration has changed")
 	}
 	}
+
+	if c.HasChanged("listen.gso_flush_timeout") {
+		d := c.GetDuration("listen.gso_flush_timeout", defaultGSOFlushInterval)
+		if d < 0 {
+			d = 0
+		}
+		f.batchFlushInterval.Store(int64(d))
+		f.l.WithField("duration", d).Info("listen.gso_flush_timeout has changed")
+	} else if c.HasChanged("batch.flush_interval") {
+		d := c.GetDuration("batch.flush_interval", defaultGSOFlushInterval)
+		if d < 0 {
+			d = 0
+		}
+		f.batchFlushInterval.Store(int64(d))
+		f.l.WithField("duration", d).Warn("batch.flush_interval is deprecated; use listen.gso_flush_timeout")
+	}
+
+	if c.HasChanged("batch.queue_depth") {
+		n := c.GetInt("batch.queue_depth", f.batchQueueDepth)
+		if n != f.batchQueueDepth {
+			f.batchQueueDepth = n
+			f.l.Warn("batch.queue_depth changes require a restart to take effect")
+		}
+	}
+
+	if c.HasChanged("listen.enable_gso") || c.HasChanged("listen.enable_gro") || c.HasChanged("listen.gso_max_segments") {
+		enableGSO := c.GetBool("listen.enable_gso", f.enableGSO)
+		enableGRO := c.GetBool("listen.enable_gro", f.enableGRO)
+		maxSeg := c.GetInt("listen.gso_max_segments", f.gsoMaxSegments)
+		f.applyOffloadConfig(enableGSO, enableGRO, maxSeg)
+		f.l.WithFields(logrus.Fields{
+			"enableGSO":      enableGSO,
+			"enableGRO":      enableGRO,
+			"gsoMaxSegments": maxSeg,
+		}).Info("listen GSO/GRO configuration updated")
+	}
 }
 }
 
 
 func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
 func (f *Interface) emitStats(ctx context.Context, i time.Duration) {

+ 25 - 0
main.go

@@ -144,6 +144,20 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 	// set up our UDP listener
 	// set up our UDP listener
 	udpConns := make([]udp.Conn, routines)
 	udpConns := make([]udp.Conn, routines)
 	port := c.GetInt("listen.port", 0)
 	port := c.GetInt("listen.port", 0)
+	enableGSO := c.GetBool("listen.enable_gso", true)
+	enableGRO := c.GetBool("listen.enable_gro", true)
+	gsoMaxSegments := c.GetInt("listen.gso_max_segments", defaultGSOMaxSegments)
+	if gsoMaxSegments <= 0 {
+		gsoMaxSegments = defaultGSOMaxSegments
+	}
+	if gsoMaxSegments > maxKernelGSOSegments {
+		gsoMaxSegments = maxKernelGSOSegments
+	}
+	gsoFlushTimeout := c.GetDuration("listen.gso_flush_timeout", defaultGSOFlushInterval)
+	if gsoFlushTimeout < 0 {
+		gsoFlushTimeout = 0
+	}
+	batchQueueDepth := c.GetInt("batch.queue_depth", 0)
 
 
 	if !configTest {
 	if !configTest {
 		rawListenHost := c.GetString("listen.host", "0.0.0.0")
 		rawListenHost := c.GetString("listen.host", "0.0.0.0")
@@ -179,6 +193,11 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 				return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err)
 				return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err)
 			}
 			}
 			udpServer.ReloadConfig(c)
 			udpServer.ReloadConfig(c)
+			if cfg, ok := udpServer.(interface {
+				ConfigureOffload(bool, bool, int)
+			}); ok {
+				cfg.ConfigureOffload(enableGSO, enableGRO, gsoMaxSegments)
+			}
 			udpConns[i] = udpServer
 			udpConns[i] = udpServer
 
 
 			// If port is dynamic, discover it before the next pass through the for loop
 			// If port is dynamic, discover it before the next pass through the for loop
@@ -246,12 +265,17 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 		reQueryWait:           c.GetDuration("timers.requery_wait_duration", defaultReQueryWait),
 		reQueryWait:           c.GetDuration("timers.requery_wait_duration", defaultReQueryWait),
 		DropLocalBroadcast:    c.GetBool("tun.drop_local_broadcast", false),
 		DropLocalBroadcast:    c.GetBool("tun.drop_local_broadcast", false),
 		DropMulticast:         c.GetBool("tun.drop_multicast", false),
 		DropMulticast:         c.GetBool("tun.drop_multicast", false),
+		EnableGSO:             enableGSO,
+		EnableGRO:             enableGRO,
+		GSOMaxSegments:        gsoMaxSegments,
 		routines:              routines,
 		routines:              routines,
 		MessageMetrics:        messageMetrics,
 		MessageMetrics:        messageMetrics,
 		version:               buildVersion,
 		version:               buildVersion,
 		relayManager:          NewRelayManager(ctx, l, hostMap, c),
 		relayManager:          NewRelayManager(ctx, l, hostMap, c),
 		punchy:                punchy,
 		punchy:                punchy,
 		ConntrackCacheTimeout: conntrackCacheTimeout,
 		ConntrackCacheTimeout: conntrackCacheTimeout,
+		BatchFlushInterval:    gsoFlushTimeout,
+		BatchQueueDepth:       batchQueueDepth,
 		l:                     l,
 		l:                     l,
 	}
 	}
 
 
@@ -263,6 +287,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 		}
 		}
 
 
 		ifce.writers = udpConns
 		ifce.writers = udpConns
+		ifce.applyOffloadConfig(enableGSO, enableGRO, gsoMaxSegments)
 		lightHouse.ifce = ifce
 		lightHouse.ifce = ifce
 
 
 		ifce.RegisterConfigChangeCallbacks(c)
 		ifce.RegisterConfigChangeCallbacks(c)

+ 35 - 3
outside.go

@@ -12,6 +12,7 @@ import (
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/firewall"
 	"github.com/slackhq/nebula/firewall"
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/header"
+	"github.com/slackhq/nebula/overlay"
 	"golang.org/x/net/ipv4"
 	"golang.org/x/net/ipv4"
 )
 )
 
 
@@ -466,22 +467,41 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []
 }
 }
 
 
 func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache) bool {
 func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache) bool {
-	var err error
+	var (
+		err error
+		pkt *overlay.Packet
+	)
+
+	if f.batches.tunQueue(q) != nil {
+		pkt = f.batches.newPacket()
+		if pkt != nil {
+			out = pkt.Payload()[:0]
+		}
+	}
 
 
 	out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb)
 	out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb)
 	if err != nil {
 	if err != nil {
+		if pkt != nil {
+			pkt.Release()
+		}
 		hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet")
 		hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet")
 		return false
 		return false
 	}
 	}
 
 
 	err = newPacket(out, true, fwPacket)
 	err = newPacket(out, true, fwPacket)
 	if err != nil {
 	if err != nil {
+		if pkt != nil {
+			pkt.Release()
+		}
 		hostinfo.logger(f.l).WithError(err).WithField("packet", out).
 		hostinfo.logger(f.l).WithError(err).WithField("packet", out).
 			Warnf("Error while validating inbound packet")
 			Warnf("Error while validating inbound packet")
 		return false
 		return false
 	}
 	}
 
 
 	if !hostinfo.ConnectionState.window.Update(f.l, messageCounter) {
 	if !hostinfo.ConnectionState.window.Update(f.l, messageCounter) {
+		if pkt != nil {
+			pkt.Release()
+		}
 		hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
 		hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
 			Debugln("dropping out of window packet")
 			Debugln("dropping out of window packet")
 		return false
 		return false
@@ -489,6 +509,9 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
 
 
 	dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache)
 	dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache)
 	if dropReason != nil {
 	if dropReason != nil {
+		if pkt != nil {
+			pkt.Release()
+		}
 		// NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore
 		// NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore
 		// This gives us a buffer to build the reject packet in
 		// This gives us a buffer to build the reject packet in
 		f.rejectOutside(out, hostinfo.ConnectionState, hostinfo, nb, packet, q)
 		f.rejectOutside(out, hostinfo.ConnectionState, hostinfo, nb, packet, q)
@@ -501,8 +524,17 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
 	}
 	}
 
 
 	f.connectionManager.In(hostinfo)
 	f.connectionManager.In(hostinfo)
-	_, err = f.readers[q].Write(out)
-	if err != nil {
+	if pkt != nil {
+		pkt.Len = len(out)
+		if f.batches.enqueueTun(q, pkt) {
+			f.observeTunQueueLen(q)
+			return true
+		}
+		f.writePacketToTun(q, pkt)
+		return true
+	}
+
+	if _, err = f.readers[q].Write(out); err != nil {
 		f.l.WithError(err).Error("Failed to write to tun")
 		f.l.WithError(err).Error("Failed to write to tun")
 	}
 	}
 	return true
 	return true

+ 82 - 0
overlay/device.go

@@ -3,6 +3,7 @@ package overlay
 import (
 import (
 	"io"
 	"io"
 	"net/netip"
 	"net/netip"
+	"sync"
 
 
 	"github.com/slackhq/nebula/routing"
 	"github.com/slackhq/nebula/routing"
 )
 )
@@ -15,3 +16,84 @@ type Device interface {
 	RoutesFor(netip.Addr) routing.Gateways
 	RoutesFor(netip.Addr) routing.Gateways
 	NewMultiQueueReader() (io.ReadWriteCloser, error)
 	NewMultiQueueReader() (io.ReadWriteCloser, error)
 }
 }
+
+// Packet represents a single packet buffer with optional headroom to carry
+// metadata (for example virtio-net headers).
+type Packet struct {
+	Buf     []byte
+	Offset  int
+	Len     int
+	release func()
+}
+
+func (p *Packet) Payload() []byte {
+	return p.Buf[p.Offset : p.Offset+p.Len]
+}
+
+func (p *Packet) Reset() {
+	p.Len = 0
+	p.Offset = 0
+	p.release = nil
+}
+
+func (p *Packet) Release() {
+	if p.release != nil {
+		p.release()
+		p.release = nil
+	}
+}
+
+func (p *Packet) Capacity() int {
+	return len(p.Buf) - p.Offset
+}
+
+// PacketPool manages reusable buffers with headroom.
+type PacketPool struct {
+	headroom int
+	blksz    int
+	pool     sync.Pool
+}
+
+func NewPacketPool(headroom, payload int) *PacketPool {
+	p := &PacketPool{headroom: headroom, blksz: headroom + payload}
+	p.pool.New = func() any {
+		buf := make([]byte, p.blksz)
+		return &Packet{Buf: buf, Offset: headroom}
+	}
+	return p
+}
+
+func (p *PacketPool) Get() *Packet {
+	pkt := p.pool.Get().(*Packet)
+	pkt.Offset = p.headroom
+	pkt.Len = 0
+	pkt.release = func() { p.put(pkt) }
+	return pkt
+}
+
+func (p *PacketPool) put(pkt *Packet) {
+	pkt.Reset()
+	p.pool.Put(pkt)
+}
+
+// BatchReader allows reading multiple packets into a shared pool with
+// preallocated headroom (e.g. virtio-net headers).
+type BatchReader interface {
+	ReadIntoBatch(pool *PacketPool) ([]*Packet, error)
+}
+
+// BatchWriter writes a slice of packets that carry their own metadata.
+type BatchWriter interface {
+	WriteBatch(packets []*Packet) (int, error)
+}
+
+// BatchCapableDevice describes a device that can efficiently read and write
+// batches of packets with virtio headroom.
+type BatchCapableDevice interface {
+	Device
+	BatchReader
+	BatchWriter
+	BatchHeadroom() int
+	BatchPayloadCap() int
+	BatchSize() int
+}

+ 56 - 0
overlay/tun_linux_batch.go

@@ -0,0 +1,56 @@
+//go:build linux && !android && !e2e_testing
+
+package overlay
+
+import "fmt"
+
+func (t *tun) batchIO() (*wireguardTunIO, bool) {
+	io, ok := t.ReadWriteCloser.(*wireguardTunIO)
+	return io, ok
+}
+
+func (t *tun) ReadIntoBatch(pool *PacketPool) ([]*Packet, error) {
+	io, ok := t.batchIO()
+	if !ok {
+		return nil, fmt.Errorf("wireguard batch I/O not enabled")
+	}
+	return io.ReadIntoBatch(pool)
+}
+
+func (t *tun) WriteBatch(packets []*Packet) (int, error) {
+	io, ok := t.batchIO()
+	if ok {
+		return io.WriteBatch(packets)
+	}
+	for _, pkt := range packets {
+		if pkt == nil {
+			continue
+		}
+		if _, err := t.Write(pkt.Payload()[:pkt.Len]); err != nil {
+			return 0, err
+		}
+		pkt.Release()
+	}
+	return len(packets), nil
+}
+
+func (t *tun) BatchHeadroom() int {
+	if io, ok := t.batchIO(); ok {
+		return io.BatchHeadroom()
+	}
+	return 0
+}
+
+func (t *tun) BatchPayloadCap() int {
+	if io, ok := t.batchIO(); ok {
+		return io.BatchPayloadCap()
+	}
+	return 0
+}
+
+func (t *tun) BatchSize() int {
+	if io, ok := t.batchIO(); ok {
+		return io.BatchSize()
+	}
+	return 1
+}

+ 152 - 37
overlay/wireguard_tun_linux.go

@@ -14,15 +14,15 @@ type wireguardTunIO struct {
 	mtu       int
 	mtu       int
 	batchSize int
 	batchSize int
 
 
-	readMu   sync.Mutex
-	readBufs [][]byte
-	readLens []int
-	pending  [][]byte
-	pendIdx  int
-
-	writeMu   sync.Mutex
-	writeBuf  []byte
-	writeWrap [][]byte
+	readMu      sync.Mutex
+	readBuffers [][]byte
+	readLens    []int
+	legacyBuf   []byte
+
+	writeMu      sync.Mutex
+	writeBuf     []byte
+	writeWrap    [][]byte
+	writeBuffers [][]byte
 }
 }
 
 
 func newWireguardTunIO(dev wgtun.Device, mtu int) *wireguardTunIO {
 func newWireguardTunIO(dev wgtun.Device, mtu int) *wireguardTunIO {
@@ -33,17 +33,12 @@ func newWireguardTunIO(dev wgtun.Device, mtu int) *wireguardTunIO {
 	if mtu <= 0 {
 	if mtu <= 0 {
 		mtu = DefaultMTU
 		mtu = DefaultMTU
 	}
 	}
-	bufs := make([][]byte, batch)
-	for i := range bufs {
-		bufs[i] = make([]byte, wgtun.VirtioNetHdrLen+mtu)
-	}
 	return &wireguardTunIO{
 	return &wireguardTunIO{
 		dev:       dev,
 		dev:       dev,
 		mtu:       mtu,
 		mtu:       mtu,
 		batchSize: batch,
 		batchSize: batch,
-		readBufs:  bufs,
 		readLens:  make([]int, batch),
 		readLens:  make([]int, batch),
-		pending:   make([][]byte, 0, batch),
+		legacyBuf: make([]byte, wgtun.VirtioNetHdrLen+mtu),
 		writeBuf:  make([]byte, wgtun.VirtioNetHdrLen+mtu),
 		writeBuf:  make([]byte, wgtun.VirtioNetHdrLen+mtu),
 		writeWrap: make([][]byte, 1),
 		writeWrap: make([][]byte, 1),
 	}
 	}
@@ -53,29 +48,21 @@ func (w *wireguardTunIO) Read(p []byte) (int, error) {
 	w.readMu.Lock()
 	w.readMu.Lock()
 	defer w.readMu.Unlock()
 	defer w.readMu.Unlock()
 
 
-	for {
-		if w.pendIdx < len(w.pending) {
-			segment := w.pending[w.pendIdx]
-			w.pendIdx++
-			n := copy(p, segment)
-			return n, nil
-		}
-
-		n, err := w.dev.Read(w.readBufs, w.readLens, wgtun.VirtioNetHdrLen)
-		if err != nil {
-			return 0, err
-		}
-		w.pending = w.pending[:0]
-		w.pendIdx = 0
-		for i := 0; i < n; i++ {
-			length := w.readLens[i]
-			if length == 0 {
-				continue
-			}
-			segment := w.readBufs[i][wgtun.VirtioNetHdrLen : wgtun.VirtioNetHdrLen+length]
-			w.pending = append(w.pending, segment)
-		}
+	bufs := w.readBuffers
+	if len(bufs) == 0 {
+		bufs = [][]byte{w.legacyBuf}
+		w.readBuffers = bufs
 	}
 	}
+	n, err := w.dev.Read(bufs[:1], w.readLens[:1], wgtun.VirtioNetHdrLen)
+	if err != nil {
+		return 0, err
+	}
+	if n == 0 {
+		return 0, nil
+	}
+	length := w.readLens[0]
+	copy(p, w.legacyBuf[wgtun.VirtioNetHdrLen:wgtun.VirtioNetHdrLen+length])
+	return length, nil
 }
 }
 
 
 func (w *wireguardTunIO) Write(p []byte) (int, error) {
 func (w *wireguardTunIO) Write(p []byte) (int, error) {
@@ -97,6 +84,134 @@ func (w *wireguardTunIO) Write(p []byte) (int, error) {
 	return len(p), nil
 	return len(p), nil
 }
 }
 
 
+func (w *wireguardTunIO) ReadIntoBatch(pool *PacketPool) ([]*Packet, error) {
+	if pool == nil {
+		return nil, fmt.Errorf("wireguard tun: packet pool is nil")
+	}
+
+	w.readMu.Lock()
+	defer w.readMu.Unlock()
+
+	if len(w.readBuffers) < w.batchSize {
+		w.readBuffers = make([][]byte, w.batchSize)
+	}
+	if len(w.readLens) < w.batchSize {
+		w.readLens = make([]int, w.batchSize)
+	}
+
+	packets := make([]*Packet, w.batchSize)
+	requiredHeadroom := w.BatchHeadroom()
+	requiredPayload := w.BatchPayloadCap()
+	headroom := 0
+	for i := 0; i < w.batchSize; i++ {
+		pkt := pool.Get()
+		if pkt == nil {
+			releasePackets(packets[:i])
+			return nil, fmt.Errorf("wireguard tun: packet pool returned nil packet")
+		}
+		if pkt.Capacity() < requiredPayload {
+			pkt.Release()
+			releasePackets(packets[:i])
+			return nil, fmt.Errorf("wireguard tun: packet capacity %d below required %d", pkt.Capacity(), requiredPayload)
+		}
+		if i == 0 {
+			headroom = pkt.Offset
+			if headroom < requiredHeadroom {
+				pkt.Release()
+				releasePackets(packets[:i])
+				return nil, fmt.Errorf("wireguard tun: packet headroom %d below virtio requirement %d", headroom, requiredHeadroom)
+			}
+		} else if pkt.Offset != headroom {
+			pkt.Release()
+			releasePackets(packets[:i])
+			return nil, fmt.Errorf("wireguard tun: inconsistent packet headroom (%d != %d)", pkt.Offset, headroom)
+		}
+		packets[i] = pkt
+		w.readBuffers[i] = pkt.Buf
+	}
+
+	n, err := w.dev.Read(w.readBuffers[:w.batchSize], w.readLens[:w.batchSize], headroom)
+	if err != nil {
+		releasePackets(packets)
+		return nil, err
+	}
+	if n == 0 {
+		releasePackets(packets)
+		return nil, nil
+	}
+	for i := 0; i < n; i++ {
+		packets[i].Len = w.readLens[i]
+	}
+	for i := n; i < w.batchSize; i++ {
+		packets[i].Release()
+		packets[i] = nil
+	}
+	return packets[:n], nil
+}
+
+func (w *wireguardTunIO) WriteBatch(packets []*Packet) (int, error) {
+	if len(packets) == 0 {
+		return 0, nil
+	}
+	requiredHeadroom := w.BatchHeadroom()
+	offset := packets[0].Offset
+	if offset < requiredHeadroom {
+		releasePackets(packets)
+		return 0, fmt.Errorf("wireguard tun: packet offset %d smaller than required headroom %d", offset, requiredHeadroom)
+	}
+	for _, pkt := range packets {
+		if pkt == nil {
+			continue
+		}
+		if pkt.Offset != offset {
+			releasePackets(packets)
+			return 0, fmt.Errorf("wireguard tun: mixed packet offsets not supported")
+		}
+		limit := pkt.Offset + pkt.Len
+		if limit > len(pkt.Buf) {
+			releasePackets(packets)
+			return 0, fmt.Errorf("wireguard tun: packet length %d exceeds buffer capacity %d", pkt.Len, len(pkt.Buf)-pkt.Offset)
+		}
+	}
+	w.writeMu.Lock()
+	defer w.writeMu.Unlock()
+
+	if len(w.writeBuffers) < len(packets) {
+		w.writeBuffers = make([][]byte, len(packets))
+	}
+	for i, pkt := range packets {
+		if pkt == nil {
+			w.writeBuffers[i] = nil
+			continue
+		}
+		limit := pkt.Offset + pkt.Len
+		w.writeBuffers[i] = pkt.Buf[:limit]
+	}
+	n, err := w.dev.Write(w.writeBuffers[:len(packets)], offset)
+	releasePackets(packets)
+	return n, err
+}
+
+func (w *wireguardTunIO) BatchHeadroom() int {
+	return wgtun.VirtioNetHdrLen
+}
+
+func (w *wireguardTunIO) BatchPayloadCap() int {
+	return w.mtu
+}
+
+func (w *wireguardTunIO) BatchSize() int {
+	return w.batchSize
+}
+
 func (w *wireguardTunIO) Close() error {
 func (w *wireguardTunIO) Close() error {
 	return nil
 	return nil
 }
 }
+
+func releasePackets(pkts []*Packet) {
+	for _, pkt := range pkts {
+		if pkt != nil {
+			pkt.Release()
+		}
+	}
+}

+ 12 - 0
udp/conn.go

@@ -22,6 +22,18 @@ type Conn interface {
 	Close() error
 	Close() error
 }
 }
 
 
+// Datagram represents a UDP payload destined to a specific address.
+type Datagram struct {
+	Payload []byte
+	Addr    netip.AddrPort
+}
+
+// BatchConn can send multiple datagrams in one syscall.
+type BatchConn interface {
+	Conn
+	WriteBatch(pkts []Datagram) error
+}
+
 type NoopConn struct{}
 type NoopConn struct{}
 
 
 func (NoopConn) Rebind() error {
 func (NoopConn) Rebind() error {

+ 94 - 1
udp/wireguard_conn_linux.go

@@ -20,8 +20,12 @@ type WGConn struct {
 	bind      *wgconn.StdNetBind
 	bind      *wgconn.StdNetBind
 	recvers   []wgconn.ReceiveFunc
 	recvers   []wgconn.ReceiveFunc
 	batch     int
 	batch     int
+	reqBatch  int
 	localIP   netip.Addr
 	localIP   netip.Addr
 	localPort uint16
 	localPort uint16
+	enableGSO bool
+	enableGRO bool
+	gsoMaxSeg int
 	closed    atomic.Bool
 	closed    atomic.Bool
 
 
 	closeOnce sync.Once
 	closeOnce sync.Once
@@ -34,7 +38,9 @@ func NewWireguardListener(l *logrus.Logger, ip netip.Addr, port int, multi bool,
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
-	if batch <= 0 || batch > bind.BatchSize() {
+	if batch <= 0 {
+		batch = bind.BatchSize()
+	} else if batch > bind.BatchSize() {
 		batch = bind.BatchSize()
 		batch = bind.BatchSize()
 	}
 	}
 	return &WGConn{
 	return &WGConn{
@@ -42,6 +48,7 @@ func NewWireguardListener(l *logrus.Logger, ip netip.Addr, port int, multi bool,
 		bind:      bind,
 		bind:      bind,
 		recvers:   recvers,
 		recvers:   recvers,
 		batch:     batch,
 		batch:     batch,
+		reqBatch:  batch,
 		localIP:   ip,
 		localIP:   ip,
 		localPort: actualPort,
 		localPort: actualPort,
 	}, nil
 	}, nil
@@ -118,6 +125,92 @@ func (c *WGConn) WriteTo(b []byte, addr netip.AddrPort) error {
 	return c.bind.Send([][]byte{b}, ep)
 	return c.bind.Send([][]byte{b}, ep)
 }
 }
 
 
+func (c *WGConn) WriteBatch(datagrams []Datagram) error {
+	if len(datagrams) == 0 {
+		return nil
+	}
+	if c.closed.Load() {
+		return net.ErrClosed
+	}
+	max := c.batch
+	if max <= 0 {
+		max = len(datagrams)
+		if max == 0 {
+			max = 1
+		}
+	}
+	bufs := make([][]byte, 0, max)
+	var (
+		current  netip.AddrPort
+		endpoint *wgconn.StdNetEndpoint
+		haveAddr bool
+	)
+	flush := func() error {
+		if len(bufs) == 0 || endpoint == nil {
+			bufs = bufs[:0]
+			return nil
+		}
+		err := c.bind.Send(bufs, endpoint)
+		bufs = bufs[:0]
+		return err
+	}
+
+	for _, d := range datagrams {
+		if len(d.Payload) == 0 || !d.Addr.IsValid() {
+			continue
+		}
+		if !haveAddr || d.Addr != current {
+			if err := flush(); err != nil {
+				return err
+			}
+			current = d.Addr
+			endpoint = &wgconn.StdNetEndpoint{AddrPort: current}
+			haveAddr = true
+		}
+		bufs = append(bufs, d.Payload)
+		if len(bufs) >= max {
+			if err := flush(); err != nil {
+				return err
+			}
+		}
+	}
+	return flush()
+}
+
+func (c *WGConn) ConfigureOffload(enableGSO, enableGRO bool, maxSegments int) {
+	c.enableGSO = enableGSO
+	c.enableGRO = enableGRO
+	if maxSegments <= 0 {
+		maxSegments = 1
+	} else if maxSegments > wgconn.IdealBatchSize {
+		maxSegments = wgconn.IdealBatchSize
+	}
+	c.gsoMaxSeg = maxSegments
+
+	effectiveBatch := c.reqBatch
+	if enableGSO && c.bind != nil {
+		bindBatch := c.bind.BatchSize()
+		if effectiveBatch < bindBatch {
+			if c.l != nil {
+				c.l.WithFields(logrus.Fields{
+					"requested": c.reqBatch,
+					"effective": bindBatch,
+				}).Warn("listen.batch below wireguard minimum; using bind batch size for UDP GSO support")
+			}
+			effectiveBatch = bindBatch
+		}
+	}
+	c.batch = effectiveBatch
+
+	if c.l != nil {
+		c.l.WithFields(logrus.Fields{
+			"enableGSO":      enableGSO,
+			"enableGRO":      enableGRO,
+			"gsoMaxSegments": maxSegments,
+		}).Debug("configured wireguard UDP offload")
+	}
+}
+
 func (c *WGConn) ReloadConfig(*config.C) {
 func (c *WGConn) ReloadConfig(*config.C) {
 	// WireGuard bind currently does not expose runtime configuration knobs.
 	// WireGuard bind currently does not expose runtime configuration knobs.
 }
 }

+ 12 - 0
wgstack/conn/errors_default.go

@@ -0,0 +1,12 @@
+//go:build !linux
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+func errShouldDisableUDPGSO(err error) bool {
+	return false
+}

+ 26 - 0
wgstack/conn/errors_linux.go

@@ -0,0 +1,26 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+	"errors"
+	"os"
+
+	"golang.org/x/sys/unix"
+)
+
+func errShouldDisableUDPGSO(err error) bool {
+	var serr *os.SyscallError
+	if errors.As(err, &serr) {
+		// EIO is returned by udp_send_skb() if the device driver does not have
+		// tx checksumming enabled, which is a hard requirement of UDP_SEGMENT.
+		// See:
+		// https://git.kernel.org/pub/scm/docs/man-pages/man-pages.git/tree/man7/udp.7?id=806eabd74910447f21005160e90957bde4db0183#n228
+		// https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/net/ipv4/udp.c?h=v6.2&id=c9c3395d5e3dcc6daee66c6908354d47bf98cb0c#n942
+		return serr.Err == unix.EIO
+	}
+	return false
+}

+ 15 - 0
wgstack/conn/features_default.go

@@ -0,0 +1,15 @@
+//go:build !linux
+// +build !linux
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import "net"
+
+func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) {
+	return
+}

+ 29 - 0
wgstack/conn/features_linux.go

@@ -0,0 +1,29 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+	"net"
+
+	"golang.org/x/sys/unix"
+)
+
+func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) {
+	rc, err := conn.SyscallConn()
+	if err != nil {
+		return
+	}
+	err = rc.Control(func(fd uintptr) {
+		_, errSyscall := unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_SEGMENT)
+		txOffload = errSyscall == nil
+		opt, errSyscall := unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO)
+		rxOffload = errSyscall == nil && opt == 1
+	})
+	if err != nil {
+		return false, false
+	}
+	return txOffload, rxOffload
+}

+ 21 - 0
wgstack/conn/gso_default.go

@@ -0,0 +1,21 @@
+//go:build !linux
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+// getGSOSize parses control for UDP_GRO and if found returns its GSO size data.
+func getGSOSize(control []byte) (int, error) {
+	return 0, nil
+}
+
+// setGSOSize sets a UDP_SEGMENT in control based on gsoSize.
+func setGSOSize(control *[]byte, gsoSize uint16) {
+}
+
+// gsoControlSize returns the recommended buffer size for pooling sticky and UDP
+// offloading control data.
+const gsoControlSize = 0

+ 65 - 0
wgstack/conn/gso_linux.go

@@ -0,0 +1,65 @@
+//go:build linux
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+	"fmt"
+	"unsafe"
+
+	"golang.org/x/sys/unix"
+)
+
+const (
+	sizeOfGSOData = 2
+)
+
+// getGSOSize parses control for UDP_GRO and if found returns its GSO size data.
+func getGSOSize(control []byte) (int, error) {
+	var (
+		hdr  unix.Cmsghdr
+		data []byte
+		rem  = control
+		err  error
+	)
+
+	for len(rem) > unix.SizeofCmsghdr {
+		hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem)
+		if err != nil {
+			return 0, fmt.Errorf("error parsing socket control message: %w", err)
+		}
+		if hdr.Level == unix.SOL_UDP && hdr.Type == unix.UDP_GRO && len(data) >= sizeOfGSOData {
+			var gso uint16
+			copy(unsafe.Slice((*byte)(unsafe.Pointer(&gso)), sizeOfGSOData), data[:sizeOfGSOData])
+			return int(gso), nil
+		}
+	}
+	return 0, nil
+}
+
+// setGSOSize sets a UDP_SEGMENT in control based on gsoSize. It leaves existing
+// data in control untouched.
+func setGSOSize(control *[]byte, gsoSize uint16) {
+	existingLen := len(*control)
+	avail := cap(*control) - existingLen
+	space := unix.CmsgSpace(sizeOfGSOData)
+	if avail < space {
+		return
+	}
+	*control = (*control)[:cap(*control)]
+	gsoControl := (*control)[existingLen:]
+	hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(gsoControl)[0]))
+	hdr.Level = unix.SOL_UDP
+	hdr.Type = unix.UDP_SEGMENT
+	hdr.SetLen(unix.CmsgLen(sizeOfGSOData))
+	copy((gsoControl)[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&gsoSize)), sizeOfGSOData))
+	*control = (*control)[:existingLen+space]
+}
+
+// gsoControlSize returns the recommended buffer size for pooling UDP
+// offloading control data.
+var gsoControlSize = unix.CmsgSpace(sizeOfGSOData)

+ 42 - 0
wgstack/conn/sticky_default.go

@@ -0,0 +1,42 @@
+//go:build !linux || android
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import "net/netip"
+
+func (e *StdNetEndpoint) SrcIP() netip.Addr {
+	return netip.Addr{}
+}
+
+func (e *StdNetEndpoint) SrcIfidx() int32 {
+	return 0
+}
+
+func (e *StdNetEndpoint) SrcToString() string {
+	return ""
+}
+
+// TODO: macOS, FreeBSD and other BSDs likely do support the sticky sockets
+// {get,set}srcControl feature set, but use alternatively named flags and need
+// ports and require testing.
+
+// getSrcFromControl parses the control for PKTINFO and if found updates ep with
+// the source information found.
+func getSrcFromControl(control []byte, ep *StdNetEndpoint) {
+}
+
+// setSrcControl parses the control for PKTINFO and if found updates ep with
+// the source information found.
+func setSrcControl(control *[]byte, ep *StdNetEndpoint) {
+}
+
+// stickyControlSize returns the recommended buffer size for pooling sticky
+// offloading control data.
+const stickyControlSize = 0
+
+const StdNetSupportsStickySockets = false