JackDoan 1 month ago
parent
commit
ac5382928e
50 changed files with 4013 additions and 123 deletions
  1. 7 0
      cmd/nebula/main.go
  2. 1 1
      connection_state.go
  3. 14 14
      firewall.go
  4. 1 1
      go.mod
  5. 2 2
      go.sum
  6. 86 7
      inside.go
  7. 76 15
      interface.go
  8. 265 5
      outside.go
  9. 2 3
      overlay/device.go
  10. 91 0
      overlay/eventfd/eventfd.go
  11. 18 5
      overlay/tun.go
  12. 27 1
      overlay/tun_disabled.go
  13. 121 34
      overlay/tun_linux.go
  14. 23 1
      overlay/user.go
  15. 23 0
      overlay/vhost/README.md
  16. 4 0
      overlay/vhost/doc.go
  17. 218 0
      overlay/vhost/ioctl.go
  18. 21 0
      overlay/vhost/ioctl_test.go
  19. 73 0
      overlay/vhost/memory.go
  20. 42 0
      overlay/vhost/memory_internal_test.go
  21. 23 0
      overlay/vhostnet/README.md
  22. 372 0
      overlay/vhostnet/device.go
  23. 3 0
      overlay/vhostnet/doc.go
  24. 31 0
      overlay/vhostnet/ioctl.go
  25. 69 0
      overlay/vhostnet/options.go
  26. 23 0
      overlay/virtqueue/README.md
  27. 140 0
      overlay/virtqueue/available_ring.go
  28. 71 0
      overlay/virtqueue/available_ring_internal_test.go
  29. 43 0
      overlay/virtqueue/descriptor.go
  30. 12 0
      overlay/virtqueue/descriptor_internal_test.go
  31. 465 0
      overlay/virtqueue/descriptor_table.go
  32. 7 0
      overlay/virtqueue/doc.go
  33. 45 0
      overlay/virtqueue/eventfd_test.go
  34. 33 0
      overlay/virtqueue/size.go
  35. 59 0
      overlay/virtqueue/size_test.go
  36. 421 0
      overlay/virtqueue/split_virtqueue.go
  37. 21 0
      overlay/virtqueue/used_element.go
  38. 12 0
      overlay/virtqueue/used_element_internal_test.go
  39. 184 0
      overlay/virtqueue/used_ring.go
  40. 136 0
      overlay/virtqueue/used_ring_internal_test.go
  41. 70 0
      packet/outpacket.go
  42. 119 0
      packet/packet.go
  43. 37 0
      packet/virtio.go
  44. 4 2
      udp/conn.go
  45. 195 23
      udp/udp_linux.go
  46. 44 9
      udp/udp_linux_64.go
  47. 3 0
      util/virtio/doc.go
  48. 136 0
      util/virtio/features.go
  49. 77 0
      util/virtio/net_hdr.go
  50. 43 0
      util/virtio/net_hdr_test.go

+ 7 - 0
cmd/nebula/main.go

@@ -3,6 +3,9 @@ package main
 import (
 	"flag"
 	"fmt"
+	"log"
+	"net/http"
+	_ "net/http/pprof"
 	"os"
 	"runtime/debug"
 	"strings"
@@ -71,6 +74,10 @@ func main() {
 		os.Exit(1)
 	}
 
+	go func() {
+		log.Println(http.ListenAndServe("0.0.0.0:6060", nil))
+	}()
+
 	if !*configTest {
 		ctrl.Start()
 		notifyReady(l)

+ 1 - 1
connection_state.go

@@ -13,7 +13,7 @@ import (
 	"github.com/slackhq/nebula/noiseutil"
 )
 
-const ReplayWindow = 1024
+const ReplayWindow = 4096
 
 type ConnectionState struct {
 	eKey           *NebulaCipherState

+ 14 - 14
firewall.go

@@ -403,9 +403,9 @@ var ErrNoMatchingRule = errors.New("no matching rule in firewall table")
 
 // Drop returns an error if the packet should be dropped, explaining why. It
 // returns nil if the packet should not be dropped.
-func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) error {
+func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache, now time.Time) error {
 	// Check if we spoke to this tuple, if we did then allow this packet
-	if f.inConns(fp, h, caPool, localCache) {
+	if f.inConns(fp, h, caPool, localCache, now) {
 		return nil
 	}
 
@@ -454,7 +454,7 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
 	}
 
 	// We always want to conntrack since it is a faster operation
-	f.addConn(fp, incoming)
+	f.addConn(fp, incoming, now)
 
 	return nil
 }
@@ -483,7 +483,7 @@ func (f *Firewall) EmitStats() {
 	metrics.GetOrRegisterGauge("firewall.rules.hash", nil).Update(int64(f.GetRuleHashFNV()))
 }
 
-func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) bool {
+func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache, now time.Time) bool {
 	if localCache != nil {
 		if _, ok := localCache[fp]; ok {
 			return true
@@ -495,7 +495,7 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
 	// Purge every time we test
 	ep, has := conntrack.TimerWheel.Purge()
 	if has {
-		f.evict(ep)
+		f.evict(ep, now)
 	}
 
 	c, ok := conntrack.Conns[fp]
@@ -542,11 +542,11 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
 
 	switch fp.Protocol {
 	case firewall.ProtoTCP:
-		c.Expires = time.Now().Add(f.TCPTimeout)
+		c.Expires = now.Add(f.TCPTimeout)
 	case firewall.ProtoUDP:
-		c.Expires = time.Now().Add(f.UDPTimeout)
+		c.Expires = now.Add(f.UDPTimeout)
 	default:
-		c.Expires = time.Now().Add(f.DefaultTimeout)
+		c.Expires = now.Add(f.DefaultTimeout)
 	}
 
 	conntrack.Unlock()
@@ -558,7 +558,7 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
 	return true
 }
 
-func (f *Firewall) addConn(fp firewall.Packet, incoming bool) {
+func (f *Firewall) addConn(fp firewall.Packet, incoming bool, now time.Time) {
 	var timeout time.Duration
 	c := &conn{}
 
@@ -574,7 +574,7 @@ func (f *Firewall) addConn(fp firewall.Packet, incoming bool) {
 	conntrack := f.Conntrack
 	conntrack.Lock()
 	if _, ok := conntrack.Conns[fp]; !ok {
-		conntrack.TimerWheel.Advance(time.Now())
+		conntrack.TimerWheel.Advance(now)
 		conntrack.TimerWheel.Add(fp, timeout)
 	}
 
@@ -582,14 +582,14 @@ func (f *Firewall) addConn(fp firewall.Packet, incoming bool) {
 	// firewall reload
 	c.incoming = incoming
 	c.rulesVersion = f.rulesVersion
-	c.Expires = time.Now().Add(timeout)
+	c.Expires = now.Add(timeout)
 	conntrack.Conns[fp] = c
 	conntrack.Unlock()
 }
 
 // Evict checks if a conntrack entry has expired, if so it is removed, if not it is re-added to the wheel
 // Caller must own the connMutex lock!
-func (f *Firewall) evict(p firewall.Packet) {
+func (f *Firewall) evict(p firewall.Packet, now time.Time) {
 	// Are we still tracking this conn?
 	conntrack := f.Conntrack
 	t, ok := conntrack.Conns[p]
@@ -597,11 +597,11 @@ func (f *Firewall) evict(p firewall.Packet) {
 		return
 	}
 
-	newT := t.Expires.Sub(time.Now())
+	newT := t.Expires.Sub(now)
 
 	// Timeout is in the future, re-add the timer
 	if newT > 0 {
-		conntrack.TimerWheel.Advance(time.Now())
+		conntrack.TimerWheel.Advance(now)
 		conntrack.TimerWheel.Add(p, newT)
 		return
 	}

+ 1 - 1
go.mod

@@ -50,6 +50,6 @@ require (
 	github.com/vishvananda/netns v0.0.5 // indirect
 	go.yaml.in/yaml/v2 v2.4.2 // indirect
 	golang.org/x/mod v0.24.0 // indirect
-	golang.org/x/time v0.5.0 // indirect
+	golang.org/x/time v0.7.0 // indirect
 	golang.org/x/tools v0.33.0 // indirect
 )

+ 2 - 2
go.sum

@@ -217,8 +217,8 @@ golang.org/x/term v0.38.0/go.mod h1:bSEAKrOT1W+VSu9TSCMtoGEOUcKxOKgl3LE5QEF/xVg=
 golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
 golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
 golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
-golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
-golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
+golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ=
+golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
 golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
 golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
 golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=

+ 86 - 7
inside.go

@@ -2,16 +2,18 @@ package nebula
 
 import (
 	"net/netip"
+	"time"
 
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/firewall"
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/noiseutil"
+	"github.com/slackhq/nebula/packet"
 	"github.com/slackhq/nebula/routing"
 )
 
-func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) {
+func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb []byte, out *packet.Packet, q int, localCache firewall.ConntrackCache, now time.Time) {
 	err := newPacket(packet, false, fwPacket)
 	if err != nil {
 		if f.l.Level >= logrus.DebugLevel {
@@ -53,7 +55,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
 	})
 
 	if hostinfo == nil {
-		f.rejectInside(packet, out, q)
+		f.rejectInside(packet, out.Payload, q) //todo vector?
 		if f.l.Level >= logrus.DebugLevel {
 			f.l.WithField("vpnAddr", fwPacket.RemoteAddr).
 				WithField("fwPacket", fwPacket).
@@ -66,12 +68,11 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
 		return
 	}
 
-	dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache)
+	dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache, now)
 	if dropReason == nil {
-		f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, q)
-
+		f.sendNoMetricsDelayed(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, q)
 	} else {
-		f.rejectInside(packet, out, q)
+		f.rejectInside(packet, out.Payload, q) //todo vector?
 		if f.l.Level >= logrus.DebugLevel {
 			hostinfo.logger(f.l).
 				WithField("fwPacket", fwPacket).
@@ -218,7 +219,7 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
 	}
 
 	// check if packet is in outbound fw rules
-	dropReason := f.firewall.Drop(*fp, false, hostinfo, f.pki.GetCAPool(), nil)
+	dropReason := f.firewall.Drop(*fp, false, hostinfo, f.pki.GetCAPool(), nil, time.Now())
 	if dropReason != nil {
 		if f.l.Level >= logrus.DebugLevel {
 			f.l.WithField("fwPacket", fp).
@@ -410,3 +411,81 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
 		}
 	}
 }
+
+func (f *Interface) sendNoMetricsDelayed(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb []byte, out *packet.Packet, q int) {
+	if ci.eKey == nil {
+		return
+	}
+	useRelay := !remote.IsValid() && !hostinfo.remote.IsValid()
+	fullOut := out.Payload
+
+	if useRelay {
+		if len(out.Payload) < header.Len {
+			// out always has a capacity of mtu, but not always a length greater than the header.Len.
+			// Grow it to make sure the next operation works.
+			out.Payload = out.Payload[:header.Len]
+		}
+		// Save a header's worth of data at the front of the 'out' buffer.
+		out.Payload = out.Payload[header.Len:]
+	}
+
+	if noiseutil.EncryptLockNeeded {
+		// NOTE: for goboring AESGCMTLS we need to lock because of the nonce check
+		ci.writeLock.Lock()
+	}
+	c := ci.messageCounter.Add(1)
+
+	//l.WithField("trace", string(debug.Stack())).Error("out Header ", &Header{Version, t, st, 0, hostinfo.remoteIndexId, c}, p)
+	out.Payload = header.Encode(out.Payload, header.Version, t, st, hostinfo.remoteIndexId, c)
+	f.connectionManager.Out(hostinfo)
+
+	// Query our LH if we haven't since the last time we've been rebound, this will cause the remote to punch against
+	// all our addrs and enable a faster roaming.
+	if t != header.CloseTunnel && hostinfo.lastRebindCount != f.rebindCount {
+		//NOTE: there is an update hole if a tunnel isn't used and exactly 256 rebinds occur before the tunnel is
+		// finally used again. This tunnel would eventually be torn down and recreated if this action didn't help.
+		f.lightHouse.QueryServer(hostinfo.vpnAddrs[0])
+		hostinfo.lastRebindCount = f.rebindCount
+		if f.l.Level >= logrus.DebugLevel {
+			f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).Debug("Lighthouse update triggered for punch due to rebind counter")
+		}
+	}
+
+	var err error
+	out.Payload, err = ci.eKey.EncryptDanger(out.Payload, out.Payload, p, c, nb)
+	if noiseutil.EncryptLockNeeded {
+		ci.writeLock.Unlock()
+	}
+	if err != nil {
+		hostinfo.logger(f.l).WithError(err).
+			WithField("udpAddr", remote).WithField("counter", c).
+			WithField("attemptedCounter", c).
+			Error("Failed to encrypt outgoing packet")
+		return
+	}
+
+	if remote.IsValid() {
+		err = f.writers[q].Prep(out, remote)
+		if err != nil {
+			hostinfo.logger(f.l).WithError(err).WithField("udpAddr", remote).Error("Failed to write outgoing packet")
+		}
+	} else if hostinfo.remote.IsValid() {
+		err = f.writers[q].Prep(out, hostinfo.remote)
+		if err != nil {
+			hostinfo.logger(f.l).WithError(err).WithField("udpAddr", remote).Error("Failed to write outgoing packet")
+		}
+	} else {
+		// Try to send via a relay
+		for _, relayIP := range hostinfo.relayState.CopyRelayIps() {
+			relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP)
+			if err != nil {
+				hostinfo.relayState.DeleteRelay(relayIP)
+				hostinfo.logger(f.l).WithField("relay", relayIP).WithError(err).Info("sendNoMetrics failed to find HostInfo")
+				continue
+			}
+			//todo vector!!
+			f.SendVia(relayHostInfo, relay, out.Payload, nb, fullOut[:header.Len+len(out.Payload)], true)
+			break
+		}
+	}
+}

+ 76 - 15
interface.go

@@ -4,7 +4,6 @@ import (
 	"context"
 	"errors"
 	"fmt"
-	"io"
 	"net/netip"
 	"os"
 	"runtime"
@@ -18,10 +17,12 @@ import (
 	"github.com/slackhq/nebula/firewall"
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/overlay"
+	"github.com/slackhq/nebula/packet"
 	"github.com/slackhq/nebula/udp"
 )
 
 const mtu = 9001
+const batch = 1024 //todo config!
 
 type InterfaceConfig struct {
 	HostMap            *HostMap
@@ -86,12 +87,18 @@ type Interface struct {
 	conntrackCacheTimeout time.Duration
 
 	writers []udp.Conn
-	readers []io.ReadWriteCloser
+	readers []overlay.TunDev
 
 	metricHandshakes    metrics.Histogram
 	messageMetrics      *MessageMetrics
 	cachedPacketMetrics *cachedPacketMetrics
 
+	listenInN  int
+	listenOutN int
+
+	listenInMetric  metrics.Histogram
+	listenOutMetric metrics.Histogram
+
 	l *logrus.Logger
 }
 
@@ -177,7 +184,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
 		routines:              c.routines,
 		version:               c.version,
 		writers:               make([]udp.Conn, c.routines),
-		readers:               make([]io.ReadWriteCloser, c.routines),
+		readers:               make([]overlay.TunDev, c.routines),
 		myVpnNetworks:         cs.myVpnNetworks,
 		myVpnNetworksTable:    cs.myVpnNetworksTable,
 		myVpnAddrs:            cs.myVpnAddrs,
@@ -196,6 +203,8 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
 
 		l: c.l,
 	}
+	ifce.listenInMetric = metrics.GetOrRegisterHistogram("vhost.listenIn.n", nil, metrics.NewExpDecaySample(1028, 0.015))
+	ifce.listenOutMetric = metrics.GetOrRegisterHistogram("vhost.listenOut.n", nil, metrics.NewExpDecaySample(1028, 0.015))
 
 	ifce.tryPromoteEvery.Store(c.tryPromoteEvery)
 	ifce.reQueryEvery.Store(c.reQueryEvery)
@@ -232,7 +241,7 @@ func (f *Interface) activate() {
 	metrics.GetOrRegisterGauge("routines", nil).Update(int64(f.routines))
 
 	// Prepare n tun queues
-	var reader io.ReadWriteCloser = f.inside
+	var reader overlay.TunDev = f.inside
 	for i := 0; i < f.routines; i++ {
 		if i > 0 {
 			reader, err = f.inside.NewMultiQueueReader()
@@ -261,40 +270,72 @@ func (f *Interface) run() {
 	}
 }
 
-func (f *Interface) listenOut(i int) {
+func (f *Interface) listenOut(q int) {
 	runtime.LockOSThread()
 
 	var li udp.Conn
-	if i > 0 {
-		li = f.writers[i]
+	if q > 0 {
+		li = f.writers[q]
 	} else {
 		li = f.outside
 	}
 
 	ctCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
 	lhh := f.lightHouse.NewRequestHandler()
-	plaintext := make([]byte, udp.MTU)
+
+	outPackets := make([]*packet.OutPacket, batch)
+	for i := 0; i < batch; i++ {
+		outPackets[i] = packet.NewOut()
+	}
+
 	h := &header.H{}
 	fwPacket := &firewall.Packet{}
 	nb := make([]byte, 12, 12)
 
-	li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
-		f.readOutsidePackets(ViaSender{UdpAddr: fromUdpAddr}, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
+	toSend := make([][]byte, batch)
+
+	li.ListenOut(func(pkts []*packet.Packet) {
+		toSend = toSend[:0]
+		for i := range outPackets {
+			outPackets[i].Valid = false
+			outPackets[i].SegCounter = 0
+		}
+
+		//todo f.readOutsidePackets(ViaSender{UdpAddr: fromUdpAddr}, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
+		f.readOutsidePacketsMany(pkts, outPackets, h, fwPacket, lhh, nb, q, ctCache.Get(f.l), time.Now())
+		//we opportunistically tx, but try to also send stragglers
+		if _, err := f.readers[q].WriteMany(outPackets, q); err != nil {
+			f.l.WithError(err).Error("Failed to send packets")
+		}
+		//todo I broke this
+		//n := len(toSend)
+		//if f.l.Level == logrus.DebugLevel {
+		//	f.listenOutMetric.Update(int64(n))
+		//}
+		//f.listenOutN = n
+
 	})
 }
 
-func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
+func (f *Interface) listenIn(reader overlay.TunDev, queueNum int) {
 	runtime.LockOSThread()
 
-	packet := make([]byte, mtu)
-	out := make([]byte, mtu)
 	fwPacket := &firewall.Packet{}
 	nb := make([]byte, 12, 12)
 
 	conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
 
+	packets := make([]*packet.VirtIOPacket, batch)
+	outPackets := make([]*packet.Packet, batch)
+	for i := 0; i < batch; i++ {
+		packets[i] = packet.NewVIO()
+		outPackets[i] = packet.New(false) //todo?
+	}
+
 	for {
-		n, err := reader.Read(packet)
+		n, err := reader.ReadMany(packets, queueNum)
+
+		//todo!!
 		if err != nil {
 			if errors.Is(err, os.ErrClosed) && f.closed.Load() {
 				return
@@ -305,7 +346,22 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
 			os.Exit(2)
 		}
 
-		f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l))
+		if f.l.Level == logrus.DebugLevel {
+			f.listenInMetric.Update(int64(n))
+		}
+		f.listenInN = n
+
+		now := time.Now()
+		for i, pkt := range packets[:n] {
+			outPackets[i].OutLen = -1
+			f.consumeInsidePacket(pkt.Payload, fwPacket, nb, outPackets[i], queueNum, conntrackCache.Get(f.l), now)
+			reader.RecycleRxSeg(pkt, i == (n-1), queueNum) //todo handle err?
+			pkt.Reset()
+		}
+		_, err = f.writers[queueNum].WriteBatch(outPackets[:n])
+		if err != nil {
+			f.l.WithError(err).Error("Error while writing outbound packets")
+		}
 	}
 }
 
@@ -443,6 +499,11 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
 			} else {
 				certMaxVersion.Update(int64(certState.v1Cert.Version()))
 			}
+			if f.l.Level != logrus.DebugLevel {
+				f.listenInMetric.Update(int64(f.listenInN))
+				f.listenOutMetric.Update(int64(f.listenOutN))
+			}
+
 		}
 	}
 }

+ 265 - 5
outside.go

@@ -7,6 +7,7 @@ import (
 	"time"
 
 	"github.com/google/gopacket/layers"
+	"github.com/slackhq/nebula/packet"
 	"golang.org/x/net/ipv6"
 
 	"github.com/sirupsen/logrus"
@@ -19,7 +20,7 @@ const (
 	minFwPacketLen = 4
 )
 
-func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) {
+func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache, now time.Time) {
 	err := h.Parse(packet)
 	if err != nil {
 		// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
@@ -60,7 +61,7 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
 
 		switch h.Subtype {
 		case header.MessageNone:
-			if !f.decryptToTun(hostinfo, h.MessageCounter, out, packet, fwPacket, nb, q, localCache) {
+			if !f.decryptToTun(hostinfo, h.MessageCounter, out, packet, fwPacket, nb, q, localCache, now) {
 				return
 			}
 		case header.MessageRelay:
@@ -102,7 +103,7 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
 					relay:     relay,
 					IsRelayed: true,
 				}
-				f.readOutsidePackets(via, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache)
+				f.readOutsidePackets(via, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache, now)
 				return
 			case ForwardingType:
 				// Find the target HostInfo relay object
@@ -223,6 +224,217 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
 	f.connectionManager.In(hostinfo)
 }
 
+func (f *Interface) readOutsidePacketsMany(packets []*packet.Packet, out []*packet.OutPacket, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache, now time.Time) {
+	for i, pkt := range packets {
+		out[i].Scratch = out[i].Scratch[:0]
+		ip := pkt.AddrPort()
+
+		//l.Error("in packet ", header, packet[HeaderLen:])
+		if ip.IsValid() {
+			if f.myVpnNetworksTable.Contains(ip.Addr()) {
+				if f.l.Level >= logrus.DebugLevel {
+					f.l.WithField("udpAddr", ip).Debug("Refusing to process double encrypted packet")
+				}
+				return
+			}
+		}
+
+		//todo per-segment!
+		for segment := range pkt.Segments() {
+
+			err := h.Parse(segment)
+			if err != nil {
+				// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
+				if len(segment) > 1 {
+					f.l.WithField("packet", pkt).Infof("Error while parsing inbound packet from %s: %s", ip, err)
+				}
+				return
+			}
+
+			var hostinfo *HostInfo
+			// verify if we've seen this index before, otherwise respond to the handshake initiation
+			if h.Type == header.Message && h.Subtype == header.MessageRelay {
+				hostinfo = f.hostMap.QueryRelayIndex(h.RemoteIndex)
+			} else {
+				hostinfo = f.hostMap.QueryIndex(h.RemoteIndex)
+			}
+
+			var ci *ConnectionState
+			if hostinfo != nil {
+				ci = hostinfo.ConnectionState
+			}
+
+			switch h.Type {
+			case header.Message:
+				// TODO handleEncrypted sends directly to addr on error. Handle this in the tunneling case.
+				if !f.handleEncrypted(ci, ip, h) {
+					return
+				}
+
+				switch h.Subtype {
+				case header.MessageNone:
+					if !f.decryptToTunDelayWrite(hostinfo, h.MessageCounter, out[i], pkt, segment, fwPacket, nb, q, localCache, now) {
+						return
+					}
+				case header.MessageRelay:
+					// The entire body is sent as AD, not encrypted.
+					// The packet consists of a 16-byte parsed Nebula header, Associated Data-protected payload, and a trailing 16-byte AEAD signature value.
+					// The packet is guaranteed to be at least 16 bytes at this point, b/c it got past the h.Parse() call above. If it's
+					// otherwise malformed (meaning, there is no trailing 16 byte AEAD value), then this will result in at worst a 0-length slice
+					// which will gracefully fail in the DecryptDanger call.
+					signedPayload := segment[:len(segment)-hostinfo.ConnectionState.dKey.Overhead()]
+					signatureValue := segment[len(segment)-hostinfo.ConnectionState.dKey.Overhead():]
+					out[i].Scratch, err = hostinfo.ConnectionState.dKey.DecryptDanger(out[i].Scratch, signedPayload, signatureValue, h.MessageCounter, nb)
+					if err != nil {
+						return
+					}
+					// Successfully validated the thing. Get rid of the Relay header.
+					signedPayload = signedPayload[header.Len:]
+					// Pull the Roaming parts up here, and return in all call paths.
+					f.handleHostRoaming(hostinfo, ip)
+					// Track usage of both the HostInfo and the Relay for the received & authenticated packet
+					f.connectionManager.In(hostinfo)
+					f.connectionManager.RelayUsed(h.RemoteIndex)
+
+					relay, ok := hostinfo.relayState.QueryRelayForByIdx(h.RemoteIndex)
+					if !ok {
+						// The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing
+						// its internal mapping. This should never happen.
+						hostinfo.logger(f.l).WithFields(logrus.Fields{"vpnAddrs": hostinfo.vpnAddrs, "remoteIndex": h.RemoteIndex}).Error("HostInfo missing remote relay index")
+						return
+					}
+
+					switch relay.Type {
+					case TerminalType:
+						// If I am the target of this relay, process the unwrapped packet
+						// From this recursive point, all these variables are 'burned'. We shouldn't rely on them again.
+						f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[i].Scratch[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache, now)
+						return
+					case ForwardingType:
+						// Find the target HostInfo relay object
+						targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr)
+						if err != nil {
+							hostinfo.logger(f.l).WithField("relayTo", relay.PeerAddr).WithError(err).WithField("hostinfo.vpnAddrs", hostinfo.vpnAddrs).Info("Failed to find target host info by ip")
+							return
+						}
+
+						// If that relay is Established, forward the payload through it
+						if targetRelay.State == Established {
+							switch targetRelay.Type {
+							case ForwardingType:
+								// Forward this packet through the relay tunnel
+								// Find the target HostInfo
+								f.SendVia(targetHI, targetRelay, signedPayload, nb, out[i].Scratch, false)
+								return
+							case TerminalType:
+								hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal")
+							}
+						} else {
+							hostinfo.logger(f.l).WithFields(logrus.Fields{"relayTo": relay.PeerAddr, "relayFrom": hostinfo.vpnAddrs[0], "targetRelayState": targetRelay.State}).Info("Unexpected target relay state")
+							return
+						}
+					}
+				}
+
+			case header.LightHouse:
+				f.messageMetrics.Rx(h.Type, h.Subtype, 1)
+				if !f.handleEncrypted(ci, ip, h) {
+					return
+				}
+
+				d, err := f.decrypt(hostinfo, h.MessageCounter, out[i].Scratch, segment, h, nb)
+				if err != nil {
+					hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
+						WithField("packet", segment).
+						Error("Failed to decrypt lighthouse packet")
+					return
+				}
+
+				lhf.HandleRequest(ip, hostinfo.vpnAddrs, d, f)
+
+				// Fallthrough to the bottom to record incoming traffic
+
+			case header.Test:
+				f.messageMetrics.Rx(h.Type, h.Subtype, 1)
+				if !f.handleEncrypted(ci, ip, h) {
+					return
+				}
+
+				d, err := f.decrypt(hostinfo, h.MessageCounter, out[i].Scratch, segment, h, nb)
+				if err != nil {
+					hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
+						WithField("packet", segment).
+						Error("Failed to decrypt test packet")
+					return
+				}
+
+				if h.Subtype == header.TestRequest {
+					// This testRequest might be from TryPromoteBest, so we should roam
+					// to the new IP address before responding
+					f.handleHostRoaming(hostinfo, ip)
+					f.send(header.Test, header.TestReply, ci, hostinfo, d, nb, out[i].Scratch)
+				}
+
+				// Fallthrough to the bottom to record incoming traffic
+
+				// Non encrypted messages below here, they should not fall through to avoid tracking incoming traffic since they
+				// are unauthenticated
+
+			case header.Handshake:
+				f.messageMetrics.Rx(h.Type, h.Subtype, 1)
+				f.handshakeManager.HandleIncoming(ip, nil, segment, h)
+				return
+
+			case header.RecvError:
+				f.messageMetrics.Rx(h.Type, h.Subtype, 1)
+				f.handleRecvError(ip, h)
+				return
+
+			case header.CloseTunnel:
+				f.messageMetrics.Rx(h.Type, h.Subtype, 1)
+				if !f.handleEncrypted(ci, ip, h) {
+					return
+				}
+
+				hostinfo.logger(f.l).WithField("udpAddr", ip).
+					Info("Close tunnel received, tearing down.")
+
+				f.closeTunnel(hostinfo)
+				return
+
+			case header.Control:
+				if !f.handleEncrypted(ci, ip, h) {
+					return
+				}
+
+				d, err := f.decrypt(hostinfo, h.MessageCounter, out[i].Scratch, segment, h, nb)
+				if err != nil {
+					hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
+						WithField("packet", segment).
+						Error("Failed to decrypt Control packet")
+					return
+				}
+
+				f.relayManager.HandleControlMsg(hostinfo, d, f)
+
+			default:
+				f.messageMetrics.Rx(h.Type, h.Subtype, 1)
+				hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", ip)
+				return
+			}
+
+			f.handleHostRoaming(hostinfo, ip)
+
+			f.connectionManager.In(hostinfo)
+
+		}
+		_, err := f.readers[q].WriteOne(out[i], false, q)
+		if err != nil {
+			f.l.WithError(err).Error("Failed to write packet")
+		}
+	}
+}
+
 // closeTunnel closes a tunnel locally, it does not send a closeTunnel packet to the remote
 func (f *Interface) closeTunnel(hostInfo *HostInfo) {
 	final := f.hostMap.DeleteHostInfo(hostInfo)
@@ -472,7 +684,55 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []
 	return out, nil
 }
 
-func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache) bool {
+func (f *Interface) decryptToTunDelayWrite(hostinfo *HostInfo, messageCounter uint64, out *packet.OutPacket, pkt *packet.Packet, inSegment []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache, now time.Time) bool {
+	var err error
+
+	seg, err := f.readers[q].AllocSeg(out, q)
+	if err != nil {
+		f.l.WithError(err).Errorln("decryptToTunDelayWrite: failed to allocate segment")
+		return false
+	}
+
+	out.SegmentPayloads[seg] = out.SegmentPayloads[seg][:0]
+	out.SegmentPayloads[seg], err = hostinfo.ConnectionState.dKey.DecryptDanger(out.SegmentPayloads[seg], inSegment[:header.Len], inSegment[header.Len:], messageCounter, nb)
+	if err != nil {
+		hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet")
+		return false
+	}
+
+	err = newPacket(out.SegmentPayloads[seg], true, fwPacket)
+	if err != nil {
+		hostinfo.logger(f.l).WithError(err).WithField("packet", out).
+			Warnf("Error while validating inbound packet")
+		return false
+	}
+
+	if !hostinfo.ConnectionState.window.Update(f.l, messageCounter) {
+		hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
+			Debugln("dropping out of window packet")
+		return false
+	}
+
+	dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache, now)
+	if dropReason != nil {
+		// NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore
+		// This gives us a buffer to build the reject packet in
+		f.rejectOutside(out.SegmentPayloads[seg], hostinfo.ConnectionState, hostinfo, nb, inSegment, q)
+		if f.l.Level >= logrus.DebugLevel {
+			hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
+				WithField("reason", dropReason).
+				Debugln("dropping inbound packet")
+		}
+		return false
+	}
+
+	f.connectionManager.In(hostinfo)
+	pkt.OutLen += len(inSegment)
+	out.Segments[seg] = out.Segments[seg][:len(out.SegmentHeaders[seg])+len(out.SegmentPayloads[seg])]
+	return true
+}
+
+func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache, now time.Time) bool {
 	var err error
 
 	out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb)
@@ -494,7 +754,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
 		return false
 	}
 
-	dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache)
+	dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache, now)
 	if dropReason != nil {
 		// NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore
 		// This gives us a buffer to build the reject packet in

+ 2 - 3
overlay/device.go

@@ -1,18 +1,17 @@
 package overlay
 
 import (
-	"io"
 	"net/netip"
 
 	"github.com/slackhq/nebula/routing"
 )
 
 type Device interface {
-	io.ReadWriteCloser
+	TunDev
 	Activate() error
 	Networks() []netip.Prefix
 	Name() string
 	RoutesFor(netip.Addr) routing.Gateways
 	SupportsMultiqueue() bool
-	NewMultiQueueReader() (io.ReadWriteCloser, error)
+	NewMultiQueueReader() (TunDev, error)
 }

+ 91 - 0
overlay/eventfd/eventfd.go

@@ -0,0 +1,91 @@
+package eventfd
+
+import (
+	"encoding/binary"
+	"syscall"
+
+	"golang.org/x/sys/unix"
+)
+
+type EventFD struct {
+	fd  int
+	buf [8]byte
+}
+
+func New() (EventFD, error) {
+	fd, err := unix.Eventfd(0, unix.EFD_NONBLOCK)
+	if err != nil {
+		return EventFD{}, err
+	}
+	return EventFD{
+		fd:  fd,
+		buf: [8]byte{},
+	}, nil
+}
+
+func (e *EventFD) Kick() error {
+	binary.LittleEndian.PutUint64(e.buf[:], 1) //is this right???
+	_, err := syscall.Write(int(e.fd), e.buf[:])
+	return err
+}
+
+func (e *EventFD) Close() error {
+	if e.fd != 0 {
+		return unix.Close(e.fd)
+	}
+	return nil
+}
+
+func (e *EventFD) FD() int {
+	return e.fd
+}
+
+type Epoll struct {
+	fd     int
+	buf    [8]byte
+	events []syscall.EpollEvent
+}
+
+func NewEpoll() (Epoll, error) {
+	fd, err := unix.EpollCreate1(0)
+	if err != nil {
+		return Epoll{}, err
+	}
+	return Epoll{
+		fd:     fd,
+		buf:    [8]byte{},
+		events: make([]syscall.EpollEvent, 1),
+	}, nil
+}
+
+func (ep *Epoll) AddEvent(fdToAdd int) error {
+	event := syscall.EpollEvent{
+		Events: syscall.EPOLLIN,
+		Fd:     int32(fdToAdd),
+	}
+	return syscall.EpollCtl(ep.fd, syscall.EPOLL_CTL_ADD, fdToAdd, &event)
+}
+
+func (ep *Epoll) Block() (int, error) {
+	n, err := syscall.EpollWait(ep.fd, ep.events, -1)
+	if err != nil {
+		//goland:noinspection GoDirectComparisonOfErrors
+		if err == syscall.EINTR {
+			return 0, nil //??
+		}
+		return -1, err
+	}
+	return n, nil
+}
+
+func (ep *Epoll) Clear() error {
+	_, err := syscall.Read(int(ep.events[0].Fd), ep.buf[:])
+	return err
+}
+
+func (ep *Epoll) Close() error {
+	if ep.fd != 0 {
+		return unix.Close(ep.fd)
+	}
+	return nil
+}

+ 18 - 5
overlay/tun.go

@@ -2,16 +2,29 @@ package overlay
 
 import (
 	"fmt"
+	"io"
 	"net"
 	"net/netip"
 
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/packet"
 	"github.com/slackhq/nebula/util"
 )
 
 const DefaultMTU = 1300
 
+type TunDev interface {
+	io.WriteCloser
+	ReadMany(x []*packet.VirtIOPacket, q int) (int, error)
+
+	//todo this interface sux
+	AllocSeg(pkt *packet.OutPacket, q int) (int, error)
+	WriteOne(x *packet.OutPacket, kick bool, q int) (int, error)
+	WriteMany(x []*packet.OutPacket, q int) (int, error)
+	RecycleRxSeg(pkt *packet.VirtIOPacket, kick bool, q int) error
+}
+
 // TODO: We may be able to remove routines
 type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error)
 
@@ -26,11 +39,11 @@ func NewDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Pref
 	}
 }
 
-func NewFdDeviceFromConfig(fd *int) DeviceFactory {
-	return func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
-		return newTunFromFd(c, l, *fd, vpnNetworks)
-	}
-}
+//func NewFdDeviceFromConfig(fd *int) DeviceFactory {
+//	return func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
+//		return newTunFromFd(c, l, *fd, vpnNetworks)
+//	}
+//}
 
 func getAllRoutesFromConfig(c *config.C, vpnNetworks []netip.Prefix, initial bool) (bool, []Route, error) {
 	if !initial && !c.HasChanged("tun.routes") && !c.HasChanged("tun.unsafe_routes") {

+ 27 - 1
overlay/tun_disabled.go

@@ -9,6 +9,8 @@ import (
 	"github.com/rcrowley/go-metrics"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/iputil"
+	"github.com/slackhq/nebula/overlay/virtqueue"
+	"github.com/slackhq/nebula/packet"
 	"github.com/slackhq/nebula/routing"
 )
 
@@ -22,6 +24,10 @@ type disabledTun struct {
 	l  *logrus.Logger
 }
 
+func (*disabledTun) RecycleRxSeg(pkt *packet.VirtIOPacket, kick bool, q int) error {
+	return nil
+}
+
 func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun {
 	tun := &disabledTun{
 		vpnNetworks: vpnNetworks,
@@ -40,6 +46,10 @@ func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled boo
 	return tun
 }
 
+func (*disabledTun) GetQueues() []*virtqueue.SplitQueue {
+	return nil
+}
+
 func (*disabledTun) Activate() error {
 	return nil
 }
@@ -109,7 +119,23 @@ func (t *disabledTun) SupportsMultiqueue() bool {
 	return true
 }
 
-func (t *disabledTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
+func (t *disabledTun) AllocSeg(pkt *packet.OutPacket, q int) (int, error) {
+	return 0, fmt.Errorf("tun_disabled: AllocSeg not implemented")
+}
+
+func (t *disabledTun) WriteOne(x *packet.OutPacket, kick bool, q int) (int, error) {
+	return 0, fmt.Errorf("tun_disabled: WriteOne not implemented")
+}
+
+func (t *disabledTun) WriteMany(x []*packet.OutPacket, q int) (int, error) {
+	return 0, fmt.Errorf("tun_disabled: WriteMany not implemented")
+}
+
+func (t *disabledTun) ReadMany(b []*packet.VirtIOPacket, _ int) (int, error) {
+	return t.Read(b[0].Payload)
+}
+
+func (t *disabledTun) NewMultiQueueReader() (TunDev, error) {
 	return t, nil
 }
 

+ 121 - 34
overlay/tun_linux.go

@@ -5,7 +5,6 @@ package overlay
 
 import (
 	"fmt"
-	"io"
 	"net"
 	"net/netip"
 	"os"
@@ -17,15 +16,19 @@ import (
 	"github.com/gaissmai/bart"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/overlay/vhostnet"
+	"github.com/slackhq/nebula/packet"
 	"github.com/slackhq/nebula/routing"
 	"github.com/slackhq/nebula/util"
+	"github.com/slackhq/nebula/util/virtio"
 	"github.com/vishvananda/netlink"
 	"golang.org/x/sys/unix"
 )
 
 type tun struct {
-	io.ReadWriteCloser
+	file        *os.File
 	fd          int
+	vdev        []*vhostnet.Device
 	Device      string
 	vpnNetworks []netip.Prefix
 	MaxMTU      int
@@ -40,7 +43,8 @@ type tun struct {
 	useSystemRoutes           bool
 	useSystemRoutesBufferSize int
 
-	l *logrus.Logger
+	isV6 bool
+	l    *logrus.Logger
 }
 
 func (t *tun) Networks() []netip.Prefix {
@@ -102,7 +106,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
 	}
 
 	var req ifReq
-	req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI)
+	req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_TUN_EXCL | unix.IFF_VNET_HDR | unix.IFF_NAPI)
 	if multiqueue {
 		req.Flags |= unix.IFF_MULTI_QUEUE
 	}
@@ -112,20 +116,47 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
 	}
 	name := strings.Trim(string(req.Name[:]), "\x00")
 
+	if err = unix.SetNonblock(fd, true); err != nil {
+		_ = unix.Close(fd)
+		return nil, fmt.Errorf("make file descriptor non-blocking: %w", err)
+	}
+
 	file := os.NewFile(uintptr(fd), "/dev/net/tun")
+
+	err = unix.IoctlSetPointerInt(fd, unix.TUNSETVNETHDRSZ, virtio.NetHdrSize)
+	if err != nil {
+		return nil, fmt.Errorf("set vnethdr size: %w", err)
+	}
+
+	flags := 0
+	//flags = //unix.TUN_F_CSUM //| unix.TUN_F_TSO4 | unix.TUN_F_USO4 | unix.TUN_F_TSO6 | unix.TUN_F_USO6
+	err = unix.IoctlSetInt(fd, unix.TUNSETOFFLOAD, flags)
+	if err != nil {
+		return nil, fmt.Errorf("set offloads: %w", err)
+	}
+
 	t, err := newTunGeneric(c, l, file, vpnNetworks)
 	if err != nil {
 		return nil, err
 	}
-
+	t.fd = fd
 	t.Device = name
 
+	vdev, err := vhostnet.NewDevice(
+		vhostnet.WithBackendFD(fd),
+		vhostnet.WithQueueSize(8192), //todo config
+	)
+	if err != nil {
+		return nil, err
+	}
+	t.vdev = []*vhostnet.Device{vdev}
+
 	return t, nil
 }
 
 func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix) (*tun, error) {
 	t := &tun{
-		ReadWriteCloser:           file,
+		file:                      file,
 		fd:                        int(file.Fd()),
 		vpnNetworks:               vpnNetworks,
 		TXQueueLen:                c.GetInt("tun.tx_queue", 500),
@@ -133,6 +164,9 @@ func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []n
 		useSystemRoutesBufferSize: c.GetInt("tun.use_system_route_table_buffer_size", 0),
 		l:                         l,
 	}
+	if len(vpnNetworks) != 0 {
+		t.isV6 = vpnNetworks[0].Addr().Is6() //todo what about multi-IP?
+	}
 
 	err := t.reload(c, true)
 	if err != nil {
@@ -220,7 +254,7 @@ func (t *tun) SupportsMultiqueue() bool {
 	return true
 }
 
-func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
+func (t *tun) NewMultiQueueReader() (TunDev, error) {
 	fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
 	if err != nil {
 		return nil, err
@@ -233,9 +267,17 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
 		return nil, err
 	}
 
-	file := os.NewFile(uintptr(fd), "/dev/net/tun")
+	vdev, err := vhostnet.NewDevice(
+		vhostnet.WithBackendFD(fd),
+		vhostnet.WithQueueSize(8192), //todo config
+	)
+	if err != nil {
+		return nil, err
+	}
 
-	return file, nil
+	t.vdev = append(t.vdev, vdev)
+
+	return t, nil
 }
 
 func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
@@ -243,29 +285,6 @@ func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
 	return r
 }
 
-func (t *tun) Write(b []byte) (int, error) {
-	var nn int
-	maximum := len(b)
-
-	for {
-		n, err := unix.Write(t.fd, b[nn:maximum])
-		if n > 0 {
-			nn += n
-		}
-		if nn == len(b) {
-			return nn, err
-		}
-
-		if err != nil {
-			return nn, err
-		}
-
-		if n == 0 {
-			return nn, io.ErrUnexpectedEOF
-		}
-	}
-}
-
 func (t *tun) deviceBytes() (o [16]byte) {
 	for i, c := range t.Device {
 		o[i] = byte(c)
@@ -689,8 +708,14 @@ func (t *tun) Close() error {
 		close(t.routeChan)
 	}
 
-	if t.ReadWriteCloser != nil {
-		_ = t.ReadWriteCloser.Close()
+	for _, v := range t.vdev {
+		if v != nil {
+			_ = v.Close()
+		}
+	}
+
+	if t.file != nil {
+		_ = t.file.Close()
 	}
 
 	if t.ioctlFd > 0 {
@@ -699,3 +724,65 @@ func (t *tun) Close() error {
 
 	return nil
 }
+
+func (t *tun) ReadMany(p []*packet.VirtIOPacket, q int) (int, error) {
+	n, err := t.vdev[q].ReceivePackets(p) //we are TXing
+	if err != nil {
+		return 0, err
+	}
+	return n, nil
+}
+
+func (t *tun) Write(b []byte) (int, error) {
+	maximum := len(b) //we are RXing
+
+	//todo garbagey
+	out := packet.NewOut()
+	x, err := t.AllocSeg(out, 0)
+	if err != nil {
+		return 0, err
+	}
+	copy(out.SegmentPayloads[x], b)
+	err = t.vdev[0].TransmitPacket(out, true)
+
+	if err != nil {
+		t.l.WithError(err).Error("Transmitting packet")
+		return 0, err
+	}
+	return maximum, nil
+}
+
+func (t *tun) AllocSeg(pkt *packet.OutPacket, q int) (int, error) {
+	idx, buf, err := t.vdev[q].GetPacketForTx()
+	if err != nil {
+		return 0, err
+	}
+	x := pkt.UseSegment(idx, buf, t.isV6)
+	return x, nil
+}
+
+func (t *tun) WriteOne(x *packet.OutPacket, kick bool, q int) (int, error) {
+	if err := t.vdev[q].TransmitPacket(x, kick); err != nil {
+		t.l.WithError(err).Error("Transmitting packet")
+		return 0, err
+	}
+	return 1, nil
+}
+
+func (t *tun) WriteMany(x []*packet.OutPacket, q int) (int, error) {
+	maximum := len(x) //we are RXing
+	if maximum == 0 {
+		return 0, nil
+	}
+
+	err := t.vdev[q].TransmitPackets(x)
+	if err != nil {
+		t.l.WithError(err).Error("Transmitting packet")
+		return 0, err
+	}
+	return maximum, nil
+}
+
+func (t *tun) RecycleRxSeg(pkt *packet.VirtIOPacket, kick bool, q int) error {
+	return t.vdev[q].ReceiveQueue.OfferDescriptorChains(pkt.Chains, kick)
+}

+ 23 - 1
overlay/user.go

@@ -1,11 +1,13 @@
 package overlay
 
 import (
+	"fmt"
 	"io"
 	"net/netip"
 
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/packet"
 	"github.com/slackhq/nebula/routing"
 )
 
@@ -36,6 +38,10 @@ type UserDevice struct {
 	inboundWriter *io.PipeWriter
 }
 
+func (d *UserDevice) RecycleRxSeg(pkt *packet.VirtIOPacket, kick bool, q int) error {
+	return nil
+}
+
 func (d *UserDevice) Activate() error {
 	return nil
 }
@@ -50,7 +56,7 @@ func (d *UserDevice) SupportsMultiqueue() bool {
 	return true
 }
 
-func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) {
+func (d *UserDevice) NewMultiQueueReader() (TunDev, error) {
 	return d, nil
 }
 
@@ -69,3 +75,19 @@ func (d *UserDevice) Close() error {
 	d.outboundWriter.Close()
 	return nil
 }
+
+func (d *UserDevice) ReadMany(b []*packet.VirtIOPacket, _ int) (int, error) {
+	return d.Read(b[0].Payload)
+}
+
+func (d *UserDevice) AllocSeg(pkt *packet.OutPacket, q int) (int, error) {
+	return 0, fmt.Errorf("user: AllocSeg not implemented")
+}
+
+func (d *UserDevice) WriteOne(x *packet.OutPacket, kick bool, q int) (int, error) {
+	return 0, fmt.Errorf("user: WriteOne not implemented")
+}
+
+func (d *UserDevice) WriteMany(x []*packet.OutPacket, q int) (int, error) {
+	return 0, fmt.Errorf("user: WriteMany not implemented")
+}

+ 23 - 0
overlay/vhost/README.md

@@ -0,0 +1,23 @@
+Significant portions of this code are derived from https://pkg.go.dev/github.com/hetznercloud/virtio-go
+
+MIT License
+
+Copyright (c) 2025 Hetzner Cloud GmbH
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.

+ 4 - 0
overlay/vhost/doc.go

@@ -0,0 +1,4 @@
+// Package vhost implements the basic ioctl requests needed to interact with the
+// kernel-level virtio server that provides accelerated virtio devices for
+// networking and more.
+package vhost

+ 218 - 0
overlay/vhost/ioctl.go

@@ -0,0 +1,218 @@
+package vhost
+
+import (
+	"fmt"
+	"unsafe"
+
+	"github.com/slackhq/nebula/overlay/virtqueue"
+	"github.com/slackhq/nebula/util/virtio"
+	"golang.org/x/sys/unix"
+)
+
+const (
+	// vhostIoctlGetFeatures can be used to retrieve the features supported by
+	// the vhost implementation in the kernel.
+	//
+	// Response payload: [virtio.Feature]
+	// Kernel name: VHOST_GET_FEATURES
+	vhostIoctlGetFeatures = 0x8008af00
+
+	// vhostIoctlSetFeatures can be used to communicate the features supported
+	// by this virtio implementation to the kernel.
+	//
+	// Request payload: [virtio.Feature]
+	// Kernel name: VHOST_SET_FEATURES
+	vhostIoctlSetFeatures = 0x4008af00
+
+	// vhostIoctlSetOwner can be used to set the current process as the
+	// exclusive owner of a control file descriptor.
+	//
+	// Request payload: none
+	// Kernel name: VHOST_SET_OWNER
+	vhostIoctlSetOwner = 0x0000af01
+
+	// vhostIoctlSetMemoryLayout can be used to set up or modify the memory
+	// layout which describes the IOTLB mappings in the kernel.
+	//
+	// Request payload: [MemoryLayout] with custom serialization
+	// Kernel name: VHOST_SET_MEM_TABLE
+	vhostIoctlSetMemoryLayout = 0x4008af03
+
+	// vhostIoctlSetQueueSize can be used to set the size of the virtqueue.
+	//
+	// Request payload: [QueueState]
+	// Kernel name: VHOST_SET_VRING_NUM
+	vhostIoctlSetQueueSize = 0x4008af10
+
+	// vhostIoctlSetQueueAddress can be used to set the addresses of the
+	// different parts of the virtqueue.
+	//
+	// Request payload: [QueueAddresses]
+	// Kernel name: VHOST_SET_VRING_ADDR
+	vhostIoctlSetQueueAddress = 0x4028af11
+
+	// vhostIoctlSetAvailableRingBase can be used to set the index of the next
+	// available ring entry the device will process.
+	//
+	// Request payload: [QueueState]
+	// Kernel name: VHOST_SET_VRING_BASE
+	vhostIoctlSetAvailableRingBase = 0x4008af12
+
+	// vhostIoctlSetQueueKickEventFD can be used to set the event file
+	// descriptor to signal the device when descriptor chains were added to the
+	// available ring.
+	//
+	// Request payload: [QueueFile]
+	// Kernel name: VHOST_SET_VRING_KICK
+	vhostIoctlSetQueueKickEventFD = 0x4008af20
+
+	// vhostIoctlSetQueueCallEventFD can be used to set the event file
+	// descriptor that gets signaled by the device when descriptor chains have
+	// been used by it.
+	//
+	// Request payload: [QueueFile]
+	// Kernel name: VHOST_SET_VRING_CALL
+	vhostIoctlSetQueueCallEventFD = 0x4008af21
+)
+
+// QueueState is an ioctl request payload that can hold a queue index and any
+// 32-bit number.
+//
+// Kernel name: vhost_vring_state
+type QueueState struct {
+	// QueueIndex is the index of the virtqueue.
+	QueueIndex uint32
+	// Num is any 32-bit number, depending on the request.
+	Num uint32
+}
+
+// QueueAddresses is an ioctl request payload that can hold the addresses of the
+// different parts of a virtqueue.
+//
+// Kernel name: vhost_vring_addr
+type QueueAddresses struct {
+	// QueueIndex is the index of the virtqueue.
+	QueueIndex uint32
+	// Flags that are not used in this implementation.
+	Flags uint32
+	// DescriptorTableAddress is the address of the descriptor table in user
+	// space memory. It must be 16-byte aligned.
+	DescriptorTableAddress uintptr
+	// UsedRingAddress is the address of the used ring in user space memory. It
+	// must be 4-byte aligned.
+	UsedRingAddress uintptr
+	// AvailableRingAddress is the address of the available ring in user space
+	// memory. It must be 2-byte aligned.
+	AvailableRingAddress uintptr
+	// LogAddress is used for an optional logging support, not supported by this
+	// implementation.
+	LogAddress uintptr
+}
+
+// QueueFile is an ioctl request payload that can hold a queue index and a file
+// descriptor.
+//
+// Kernel name: vhost_vring_file
+type QueueFile struct {
+	// QueueIndex is the index of the virtqueue.
+	QueueIndex uint32
+	// FD is the file descriptor of the file. Pass -1 to unbind from a file.
+	FD int32
+}
+
+// IoctlPtr is a copy of the similarly named unexported function from the Go
+// unix package. This is needed to do custom ioctl requests not supported by the
+// standard library.
+func IoctlPtr(fd int, req uint, arg unsafe.Pointer) error {
+	_, _, err := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), uintptr(req), uintptr(arg))
+	if err != 0 {
+		return fmt.Errorf("ioctl request %d: %w", req, err)
+	}
+	return nil
+}
+
+// GetFeatures requests the supported feature bits from the virtio device
+// associated with the given control file descriptor.
+func GetFeatures(controlFD int) (virtio.Feature, error) {
+	var features virtio.Feature
+	if err := IoctlPtr(controlFD, vhostIoctlGetFeatures, unsafe.Pointer(&features)); err != nil {
+		return 0, fmt.Errorf("get features: %w", err)
+	}
+	return features, nil
+}
+
+// SetFeatures communicates the feature bits supported by this implementation
+// to the virtio device associated with the given control file descriptor.
+func SetFeatures(controlFD int, features virtio.Feature) error {
+	if err := IoctlPtr(controlFD, vhostIoctlSetFeatures, unsafe.Pointer(&features)); err != nil {
+		return fmt.Errorf("set features: %w", err)
+	}
+	return nil
+}
+
+// OwnControlFD sets the current process as the exclusive owner for the
+// given control file descriptor. This must be called before interacting with
+// the control file descriptor in any other way.
+func OwnControlFD(controlFD int) error {
+	if err := IoctlPtr(controlFD, vhostIoctlSetOwner, unsafe.Pointer(nil)); err != nil {
+		return fmt.Errorf("set control file descriptor owner: %w", err)
+	}
+	return nil
+}
+
+// SetMemoryLayout sets up or modifies the memory layout for the kernel-level
+// virtio device associated with the given control file descriptor.
+func SetMemoryLayout(controlFD int, layout MemoryLayout) error {
+	payload := layout.serializePayload()
+	if err := IoctlPtr(controlFD, vhostIoctlSetMemoryLayout, unsafe.Pointer(&payload[0])); err != nil {
+		return fmt.Errorf("set memory layout: %w", err)
+	}
+	return nil
+}
+
+// RegisterQueue registers a virtio queue with the kernel-level virtio server.
+// The virtqueue will be linked to the given control file descriptor and will
+// have the given index. The kernel will use this queue until the control file
+// descriptor is closed.
+func RegisterQueue(controlFD int, queueIndex uint32, queue *virtqueue.SplitQueue) error {
+	if err := IoctlPtr(controlFD, vhostIoctlSetQueueSize, unsafe.Pointer(&QueueState{
+		QueueIndex: queueIndex,
+		Num:        uint32(queue.Size()),
+	})); err != nil {
+		return fmt.Errorf("set queue size: %w", err)
+	}
+
+	if err := IoctlPtr(controlFD, vhostIoctlSetQueueAddress, unsafe.Pointer(&QueueAddresses{
+		QueueIndex:             queueIndex,
+		Flags:                  0,
+		DescriptorTableAddress: queue.DescriptorTable().Address(),
+		UsedRingAddress:        queue.UsedRing().Address(),
+		AvailableRingAddress:   queue.AvailableRing().Address(),
+		LogAddress:             0,
+	})); err != nil {
+		return fmt.Errorf("set queue addresses: %w", err)
+	}
+
+	if err := IoctlPtr(controlFD, vhostIoctlSetAvailableRingBase, unsafe.Pointer(&QueueState{
+		QueueIndex: queueIndex,
+		Num:        0,
+	})); err != nil {
+		return fmt.Errorf("set available ring base: %w", err)
+	}
+
+	if err := IoctlPtr(controlFD, vhostIoctlSetQueueKickEventFD, unsafe.Pointer(&QueueFile{
+		QueueIndex: queueIndex,
+		FD:         int32(queue.KickEventFD()),
+	})); err != nil {
+		return fmt.Errorf("set kick event file descriptor: %w", err)
+	}
+
+	if err := IoctlPtr(controlFD, vhostIoctlSetQueueCallEventFD, unsafe.Pointer(&QueueFile{
+		QueueIndex: queueIndex,
+		FD:         int32(queue.CallEventFD()),
+	})); err != nil {
+		return fmt.Errorf("set call event file descriptor: %w", err)
+	}
+
+	return nil
+}

+ 21 - 0
overlay/vhost/ioctl_test.go

@@ -0,0 +1,21 @@
+package vhost_test
+
+import (
+	"testing"
+	"unsafe"
+
+	"github.com/slackhq/nebula/overlay/vhost"
+	"github.com/stretchr/testify/assert"
+)
+
+func TestQueueState_Size(t *testing.T) {
+	assert.EqualValues(t, 8, unsafe.Sizeof(vhost.QueueState{}))
+}
+
+func TestQueueAddresses_Size(t *testing.T) {
+	assert.EqualValues(t, 40, unsafe.Sizeof(vhost.QueueAddresses{}))
+}
+
+func TestQueueFile_Size(t *testing.T) {
+	assert.EqualValues(t, 8, unsafe.Sizeof(vhost.QueueFile{}))
+}

+ 73 - 0
overlay/vhost/memory.go

@@ -0,0 +1,73 @@
+package vhost
+
+import (
+	"encoding/binary"
+	"fmt"
+	"unsafe"
+
+	"github.com/slackhq/nebula/overlay/virtqueue"
+)
+
+// MemoryRegion describes a region of userspace memory which is being made
+// accessible to a vhost device.
+//
+// Kernel name: vhost_memory_region
+type MemoryRegion struct {
+	// GuestPhysicalAddress is the physical address of the memory region within
+	// the guest, when virtualization is used. When no virtualization is used,
+	// this should be the same as UserspaceAddress.
+	GuestPhysicalAddress uintptr
+	// Size is the size of the memory region.
+	Size uint64
+	// UserspaceAddress is the virtual address in the userspace of the host
+	// where the memory region can be found.
+	UserspaceAddress uintptr
+	// Padding and room for flags. Currently unused.
+	_ uint64
+}
+
+// MemoryLayout is a list of [MemoryRegion]s.
+type MemoryLayout []MemoryRegion
+
+// NewMemoryLayoutForQueues returns a new [MemoryLayout] that describes the
+// memory pages used by the descriptor tables of the given queues.
+func NewMemoryLayoutForQueues(queues []*virtqueue.SplitQueue) MemoryLayout {
+	regions := make([]MemoryRegion, 0)
+	for _, queue := range queues {
+		for address, size := range queue.DescriptorTable().BufferAddresses() {
+			regions = append(regions, MemoryRegion{
+				// There is no virtualization in play here, so the guest address
+				// is the same as in the host's userspace.
+				GuestPhysicalAddress: address,
+				Size:                 uint64(size),
+				UserspaceAddress:     address,
+			})
+		}
+	}
+	return regions
+}
+
+// serializePayload serializes the list of memory regions into a format that is
+// compatible to the vhost_memory kernel struct. The returned byte slice can be
+// used as a payload for the vhostIoctlSetMemoryLayout ioctl.
+func (regions MemoryLayout) serializePayload() []byte {
+	regionCount := len(regions)
+	regionSize := int(unsafe.Sizeof(MemoryRegion{}))
+	payload := make([]byte, 8+regionCount*regionSize)
+
+	// The first 32 bits contain the number of memory regions. The following 32
+	// bits are padding.
+	binary.LittleEndian.PutUint32(payload[0:4], uint32(regionCount))
+
+	if regionCount > 0 {
+		// The underlying byte array of the slice should already have the correct
+		// format, so just copy that.
+		copied := copy(payload[8:], unsafe.Slice((*byte)(unsafe.Pointer(&regions[0])), regionCount*regionSize))
+		if copied != regionCount*regionSize {
+			panic(fmt.Sprintf("copied only %d bytes of the memory regions, but expected %d",
+				copied, regionCount*regionSize))
+		}
+	}
+
+	return payload
+}

+ 42 - 0
overlay/vhost/memory_internal_test.go

@@ -0,0 +1,42 @@
+package vhost
+
+import (
+	"testing"
+	"unsafe"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestMemoryRegion_Size(t *testing.T) {
+	assert.EqualValues(t, 32, unsafe.Sizeof(MemoryRegion{}))
+}
+
+func TestMemoryLayout_SerializePayload(t *testing.T) {
+	layout := MemoryLayout([]MemoryRegion{
+		{
+			GuestPhysicalAddress: 42,
+			Size:                 100,
+			UserspaceAddress:     142,
+		}, {
+			GuestPhysicalAddress: 99,
+			Size:                 100,
+			UserspaceAddress:     99,
+		},
+	})
+	payload := layout.serializePayload()
+
+	assert.Equal(t, []byte{
+		0x02, 0x00, 0x00, 0x00, // nregions
+		0x00, 0x00, 0x00, 0x00, // padding
+		// region 0
+		0x2a, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // guest_phys_addr
+		0x64, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // memory_size
+		0x8e, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // userspace_addr
+		0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // flags_padding
+		// region 1
+		0x63, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // guest_phys_addr
+		0x64, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // memory_size
+		0x63, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // userspace_addr
+		0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // flags_padding
+	}, payload)
+}

+ 23 - 0
overlay/vhostnet/README.md

@@ -0,0 +1,23 @@
+Significant portions of this code are derived from https://pkg.go.dev/github.com/hetznercloud/virtio-go
+
+MIT License
+
+Copyright (c) 2025 Hetzner Cloud GmbH
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.

+ 372 - 0
overlay/vhostnet/device.go

@@ -0,0 +1,372 @@
+package vhostnet
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"os"
+	"runtime"
+
+	"github.com/slackhq/nebula/overlay/vhost"
+	"github.com/slackhq/nebula/overlay/virtqueue"
+	"github.com/slackhq/nebula/packet"
+	"github.com/slackhq/nebula/util/virtio"
+	"golang.org/x/sys/unix"
+)
+
+// ErrDeviceClosed is returned when the [Device] is closed while operations are
+// still running.
+var ErrDeviceClosed = errors.New("device was closed")
+
+// The indexes for the receive and transmit queues.
+const (
+	receiveQueueIndex  = 0
+	transmitQueueIndex = 1
+)
+
+// Device represents a vhost networking device within the kernel-level virtio
+// implementation and provides methods to interact with it.
+type Device struct {
+	initialized bool
+	controlFD   int
+
+	fullTable     bool
+	ReceiveQueue  *virtqueue.SplitQueue
+	TransmitQueue *virtqueue.SplitQueue
+}
+
+// NewDevice initializes a new vhost networking device within the
+// kernel-level virtio implementation, sets up the virtqueues and returns a
+// [Device] instance that can be used to communicate with that vhost device.
+//
+// There are multiple options that can be passed to this constructor to
+// influence device creation:
+//   - [WithQueueSize]
+//   - [WithBackendFD]
+//   - [WithBackendDevice]
+//
+// Remember to call [Device.Close] after use to free up resources.
+func NewDevice(options ...Option) (*Device, error) {
+	var err error
+	opts := optionDefaults
+	opts.apply(options)
+	if err = opts.validate(); err != nil {
+		return nil, fmt.Errorf("invalid options: %w", err)
+	}
+
+	dev := Device{
+		controlFD: -1,
+	}
+
+	// Clean up a partially initialized device when something fails.
+	defer func() {
+		if err != nil {
+			_ = dev.Close()
+		}
+	}()
+
+	// Retrieve a new control file descriptor. This will be used to configure
+	// the vhost networking device in the kernel.
+	dev.controlFD, err = unix.Open("/dev/vhost-net", os.O_RDWR, 0666)
+	if err != nil {
+		return nil, fmt.Errorf("get control file descriptor: %w", err)
+	}
+	if err = vhost.OwnControlFD(dev.controlFD); err != nil {
+		return nil, fmt.Errorf("own control file descriptor: %w", err)
+	}
+
+	// Advertise the supported features. This isn't much for now.
+	// TODO: Add feature options and implement proper feature negotiation.
+	getFeatures, err := vhost.GetFeatures(dev.controlFD) //0x1033D008000 but why
+	if err != nil {
+		return nil, fmt.Errorf("get features: %w", err)
+	}
+	if getFeatures == 0 {
+
+	}
+	//const funky = virtio.Feature(1 << 27)
+	//features := virtio.FeatureVersion1 | funky // | todo virtio.FeatureNetMergeRXBuffers
+	features := virtio.FeatureVersion1 | virtio.FeatureNetMergeRXBuffers
+	if err = vhost.SetFeatures(dev.controlFD, features); err != nil {
+		return nil, fmt.Errorf("set features: %w", err)
+	}
+
+	itemSize := os.Getpagesize() * 4 //todo config
+
+	// Initialize and register the queues needed for the networking device.
+	if dev.ReceiveQueue, err = createQueue(dev.controlFD, receiveQueueIndex, opts.queueSize, itemSize); err != nil {
+		return nil, fmt.Errorf("create receive queue: %w", err)
+	}
+	if dev.TransmitQueue, err = createQueue(dev.controlFD, transmitQueueIndex, opts.queueSize, itemSize); err != nil {
+		return nil, fmt.Errorf("create transmit queue: %w", err)
+	}
+
+	// Set up memory mappings for all buffers used by the queues. This has to
+	// happen before a backend for the queues can be registered.
+	memoryLayout := vhost.NewMemoryLayoutForQueues(
+		[]*virtqueue.SplitQueue{dev.ReceiveQueue, dev.TransmitQueue},
+	)
+	if err = vhost.SetMemoryLayout(dev.controlFD, memoryLayout); err != nil {
+		return nil, fmt.Errorf("setup memory layout: %w", err)
+	}
+
+	// Set the queue backends. This activates the queues within the kernel.
+	if err = SetQueueBackend(dev.controlFD, receiveQueueIndex, opts.backendFD); err != nil {
+		return nil, fmt.Errorf("set receive queue backend: %w", err)
+	}
+	if err = SetQueueBackend(dev.controlFD, transmitQueueIndex, opts.backendFD); err != nil {
+		return nil, fmt.Errorf("set transmit queue backend: %w", err)
+	}
+
+	// Fully populate the receive queue with available buffers which the device
+	// can write new packets into.
+	if err = dev.refillReceiveQueue(); err != nil {
+		return nil, fmt.Errorf("refill receive queue: %w", err)
+	}
+
+	dev.initialized = true
+
+	// Make sure to clean up even when the device gets garbage collected without
+	// Close being called first.
+	devPtr := &dev
+	runtime.SetFinalizer(devPtr, (*Device).Close)
+
+	return devPtr, nil
+}
+
+// refillReceiveQueue offers as many new device-writable buffers to the device
+// as the queue can fit. The device will then use these to write received
+// packets.
+func (dev *Device) refillReceiveQueue() error {
+	for {
+		_, err := dev.ReceiveQueue.OfferInDescriptorChains()
+		if err != nil {
+			if errors.Is(err, virtqueue.ErrNotEnoughFreeDescriptors) {
+				// Queue is full, job is done.
+				return nil
+			}
+			return fmt.Errorf("offer descriptor chain: %w", err)
+		}
+	}
+}
+
+// Close cleans up the vhost networking device within the kernel and releases
+// all resources used for it.
+// The implementation will try to release as many resources as possible and
+// collect potential errors before returning them.
+func (dev *Device) Close() error {
+	dev.initialized = false
+
+	// Closing the control file descriptor will unregister all queues from the
+	// kernel.
+	if dev.controlFD >= 0 {
+		if err := unix.Close(dev.controlFD); err != nil {
+			// Return an error and do not continue, because the memory used for
+			// the queues should not be released before they were unregistered
+			// from the kernel.
+			return fmt.Errorf("close control file descriptor: %w", err)
+		}
+		dev.controlFD = -1
+	}
+
+	var errs []error
+
+	if dev.ReceiveQueue != nil {
+		if err := dev.ReceiveQueue.Close(); err == nil {
+			dev.ReceiveQueue = nil
+		} else {
+			errs = append(errs, fmt.Errorf("close receive queue: %w", err))
+		}
+	}
+
+	if dev.TransmitQueue != nil {
+		if err := dev.TransmitQueue.Close(); err == nil {
+			dev.TransmitQueue = nil
+		} else {
+			errs = append(errs, fmt.Errorf("close transmit queue: %w", err))
+		}
+	}
+
+	if len(errs) == 0 {
+		// Everything was cleaned up. No need to run the finalizer anymore.
+		runtime.SetFinalizer(dev, nil)
+	}
+
+	return errors.Join(errs...)
+}
+
+// createQueue creates a new virtqueue and registers it with the vhost device
+// using the given index.
+func createQueue(controlFD int, queueIndex int, queueSize int, itemSize int) (*virtqueue.SplitQueue, error) {
+	var (
+		queue *virtqueue.SplitQueue
+		err   error
+	)
+	if queue, err = virtqueue.NewSplitQueue(queueSize, itemSize); err != nil {
+		return nil, fmt.Errorf("create virtqueue: %w", err)
+	}
+	if err = vhost.RegisterQueue(controlFD, uint32(queueIndex), queue); err != nil {
+		return nil, fmt.Errorf("register virtqueue with index %d: %w", queueIndex, err)
+	}
+	return queue, nil
+}
+
+func (dev *Device) GetPacketForTx() (uint16, []byte, error) {
+	var err error
+	var idx uint16
+	if !dev.fullTable {
+		idx, err = dev.TransmitQueue.DescriptorTable().CreateDescriptorForOutputs()
+		if err == virtqueue.ErrNotEnoughFreeDescriptors {
+			dev.fullTable = true
+			idx, err = dev.TransmitQueue.TakeSingle(context.TODO())
+		}
+	} else {
+		idx, err = dev.TransmitQueue.TakeSingle(context.TODO())
+	}
+	if err != nil {
+		return 0, nil, fmt.Errorf("transmit queue: %w", err)
+	}
+	buf, err := dev.TransmitQueue.GetDescriptorItem(idx)
+	if err != nil {
+		return 0, nil, fmt.Errorf("get descriptor chain: %w", err)
+	}
+	return idx, buf, nil
+}
+
+func (dev *Device) TransmitPacket(pkt *packet.OutPacket, kick bool) error {
+	if len(pkt.SegmentIDs) == 0 {
+		return nil
+	}
+	for idx := range pkt.SegmentIDs {
+		segmentID := pkt.SegmentIDs[idx]
+		dev.TransmitQueue.SetDescSize(segmentID, len(pkt.Segments[idx]))
+	}
+	err := dev.TransmitQueue.OfferDescriptorChains(pkt.SegmentIDs, false)
+	if err != nil {
+		return fmt.Errorf("offer descriptor chains: %w", err)
+	}
+	pkt.Reset()
+	if kick {
+		if err := dev.TransmitQueue.Kick(); err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
+
+func (dev *Device) TransmitPackets(pkts []*packet.OutPacket) error {
+	if len(pkts) == 0 {
+		return nil
+	}
+
+	for i := range pkts {
+		if err := dev.TransmitPacket(pkts[i], false); err != nil {
+			return err
+		}
+	}
+	if err := dev.TransmitQueue.Kick(); err != nil {
+		return err
+	}
+	return nil
+}
+
+// TODO: Make above methods cancelable by taking a context.Context argument?
+// TODO: Implement zero-copy variants to transmit and receive packets?
+
+// processChains processes as many chains as needed to create one packet. The number of processed chains is returned.
+func (dev *Device) processChains(pkt *packet.VirtIOPacket, chains []virtqueue.UsedElement) (int, error) {
+	//read first element to see how many descriptors we need:
+	pkt.Reset()
+
+	err := dev.ReceiveQueue.GetDescriptorInbuffers(uint16(chains[0].DescriptorIndex), &pkt.ChainRefs)
+	if err != nil {
+		return 0, fmt.Errorf("get descriptor chain: %w", err)
+	}
+	if len(pkt.ChainRefs) == 0 {
+		return 1, nil
+	}
+
+	// The specification requires that the first descriptor chain starts
+	// with a virtio-net header. It is not clear, whether it is also
+	// required to be fully contained in the first buffer of that
+	// descriptor chain, but it is reasonable to assume that this is
+	// always the case.
+	// The decode method already does the buffer length check.
+	if err = pkt.Header.Decode(pkt.ChainRefs[0][0:]); err != nil {
+		// The device misbehaved. There is no way we can gracefully
+		// recover from this, because we don't know how many of the
+		// following descriptor chains belong to this packet.
+		return 0, fmt.Errorf("decode vnethdr: %w", err)
+	}
+
+	//we have the header now: what do we need to do?
+	if int(pkt.Header.NumBuffers) > len(chains) {
+		return 0, fmt.Errorf("number of buffers is greater than number of chains %d", len(chains))
+	}
+	if int(pkt.Header.NumBuffers) != 1 {
+		return 0, fmt.Errorf("too smol-brain to handle more than one chain right now: %d chains", len(chains))
+	}
+	if chains[0].Length > 16000 {
+		//todo!
+		return 1, fmt.Errorf("too big packet length: %d", chains[0].Length)
+	}
+
+	//shift the buffer out of out:
+	pkt.Payload = pkt.ChainRefs[0][virtio.NetHdrSize:chains[0].Length]
+	pkt.Chains = append(pkt.Chains, uint16(chains[0].DescriptorIndex))
+	return 1, nil
+
+	//cursor := n - virtio.NetHdrSize
+	//
+	//if uint32(n) >= chains[0].Length && pkt.Header.NumBuffers == 1 {
+	//	pkt.Payload = pkt.Payload[:chains[0].Length-virtio.NetHdrSize]
+	//	return 1, nil
+	//}
+	//
+	//i := 1
+	//// we used chain 0 already
+	//for i = 1; i < len(chains); i++ {
+	//	n, err = dev.ReceiveQueue.GetDescriptorChainContents(uint16(chains[i].DescriptorIndex), pkt.Payload[cursor:], int(chains[i].Length))
+	//	if err != nil {
+	//		// When this fails we may miss to free some descriptor chains. We
+	//		// could try to mitigate this by deferring the freeing somehow, but
+	//		// it's not worth the hassle. When this method fails, the queue will
+	//		// be in a broken state anyway.
+	//		return i, fmt.Errorf("get descriptor chain: %w", err)
+	//	}
+	//	cursor += n
+	//}
+	////todo this has to be wrong
+	//pkt.Payload = pkt.Payload[:cursor]
+	//return i, nil
+}
+
+func (dev *Device) ReceivePackets(out []*packet.VirtIOPacket) (int, error) {
+	//todo optimize?
+	var chains []virtqueue.UsedElement
+	var err error
+
+	chains, err = dev.ReceiveQueue.BlockAndGetHeadsCapped(context.TODO(), len(out))
+	if err != nil {
+		return 0, err
+	}
+	if len(chains) == 0 {
+		return 0, nil
+	}
+
+	numPackets := 0
+	chainsIdx := 0
+	for numPackets = 0; chainsIdx < len(chains); numPackets++ {
+		if numPackets >= len(out) {
+			return numPackets, fmt.Errorf("dropping %d packets, no room", len(chains)-numPackets)
+		}
+		numChains, err := dev.processChains(out[numPackets], chains[chainsIdx:])
+		if err != nil {
+			return 0, err
+		}
+		chainsIdx += numChains
+	}
+
+	return numPackets, nil
+}

+ 3 - 0
overlay/vhostnet/doc.go

@@ -0,0 +1,3 @@
+// Package vhostnet implements methods to initialize vhost networking devices
+// within the kernel-level virtio implementation and communicate with them.
+package vhostnet

+ 31 - 0
overlay/vhostnet/ioctl.go

@@ -0,0 +1,31 @@
+package vhostnet
+
+import (
+	"fmt"
+	"unsafe"
+
+	"github.com/slackhq/nebula/overlay/vhost"
+)
+
+const (
+	// vhostNetIoctlSetBackend can be used to attach a virtqueue to a RAW socket
+	// or TAP device.
+	//
+	// Request payload: [vhost.QueueFile]
+	// Kernel name: VHOST_NET_SET_BACKEND
+	vhostNetIoctlSetBackend = 0x4008af30
+)
+
+// SetQueueBackend attaches a virtqueue of the vhost networking device
+// described by controlFD to the given backend file descriptor.
+// The backend file descriptor can either be a RAW socket or a TAP device. When
+// it is -1, the queue will be detached.
+func SetQueueBackend(controlFD int, queueIndex uint32, backendFD int) error {
+	if err := vhost.IoctlPtr(controlFD, vhostNetIoctlSetBackend, unsafe.Pointer(&vhost.QueueFile{
+		QueueIndex: queueIndex,
+		FD:         int32(backendFD),
+	})); err != nil {
+		return fmt.Errorf("set queue backend file descriptor: %w", err)
+	}
+	return nil
+}

+ 69 - 0
overlay/vhostnet/options.go

@@ -0,0 +1,69 @@
+package vhostnet
+
+import (
+	"errors"
+
+	"github.com/slackhq/nebula/overlay/virtqueue"
+)
+
+type optionValues struct {
+	queueSize int
+	backendFD int
+}
+
+func (o *optionValues) apply(options []Option) {
+	for _, option := range options {
+		option(o)
+	}
+}
+
+func (o *optionValues) validate() error {
+	if o.queueSize == -1 {
+		return errors.New("queue size is required")
+	}
+	if err := virtqueue.CheckQueueSize(o.queueSize); err != nil {
+		return err
+	}
+	if o.backendFD == -1 {
+		return errors.New("backend file descriptor is required")
+	}
+	return nil
+}
+
+var optionDefaults = optionValues{
+	// Required.
+	queueSize: -1,
+	// Required.
+	backendFD: -1,
+}
+
+// Option can be passed to [NewDevice] to influence device creation.
+type Option func(*optionValues)
+
+// WithQueueSize returns an [Option] that sets the size of the TX and RX queues
+// that are to be created for the device. It specifies the number of
+// entries/buffers each queue can hold. This also affects the memory
+// consumption.
+// This is required and must be an integer from 1 to 32768 that is also a power
+// of 2.
+func WithQueueSize(queueSize int) Option {
+	return func(o *optionValues) { o.queueSize = queueSize }
+}
+
+// WithBackendFD returns an [Option] that sets the file descriptor of the
+// backend that will be used for the queues of the device. The device will write
+// and read packets to/from that backend. The file descriptor can either be of a
+// RAW socket or TUN/TAP device.
+// Either this or [WithBackendDevice] is required.
+func WithBackendFD(backendFD int) Option {
+	return func(o *optionValues) { o.backendFD = backendFD }
+}
+
+//// WithBackendDevice returns an [Option] that sets the given TAP device as the
+//// backend that will be used for the queues of the device. The device will
+//// write and read packets to/from that backend. The TAP device should have been
+//// created with the [tuntap.WithVirtioNetHdr] option enabled.
+//// Either this or [WithBackendFD] is required.
+//func WithBackendDevice(dev *tuntap.Device) Option {
+//	return func(o *optionValues) { o.backendFD = int(dev.File().Fd()) }
+//}

+ 23 - 0
overlay/virtqueue/README.md

@@ -0,0 +1,23 @@
+Significant portions of this code are derived from https://pkg.go.dev/github.com/hetznercloud/virtio-go
+
+MIT License
+
+Copyright (c) 2025 Hetzner Cloud GmbH
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.

+ 140 - 0
overlay/virtqueue/available_ring.go

@@ -0,0 +1,140 @@
+package virtqueue
+
+import (
+	"fmt"
+	"unsafe"
+)
+
+// availableRingFlag is a flag that describes an [AvailableRing].
+type availableRingFlag uint16
+
+const (
+	// availableRingFlagNoInterrupt is used by the guest to advise the host to
+	// not interrupt it when consuming a buffer. It's unreliable, so it's simply
+	// an optimization.
+	availableRingFlagNoInterrupt availableRingFlag = 1 << iota
+)
+
+// availableRingSize is the number of bytes needed to store an [AvailableRing]
+// with the given queue size in memory.
+func availableRingSize(queueSize int) int {
+	return 6 + 2*queueSize
+}
+
+// availableRingAlignment is the minimum alignment of an [AvailableRing]
+// in memory, as required by the virtio spec.
+const availableRingAlignment = 2
+
+// AvailableRing is used by the driver to offer descriptor chains to the device.
+// Each ring entry refers to the head of a descriptor chain. It is only written
+// to by the driver and read by the device.
+//
+// Because the size of the ring depends on the queue size, we cannot define a
+// Go struct with a static size that maps to the memory of the ring. Instead,
+// this struct only contains pointers to the corresponding memory areas.
+type AvailableRing struct {
+	initialized bool
+
+	// flags that describe this ring.
+	flags *availableRingFlag
+	// ringIndex indicates where the driver would put the next entry into the
+	// ring (modulo the queue size).
+	ringIndex *uint16
+	// ring references buffers using the index of the head of the descriptor
+	// chain in the [DescriptorTable]. It wraps around at queue size.
+	ring []uint16
+	// usedEvent is not used by this implementation, but we reserve it anyway to
+	// avoid issues in case a device may try to access it, contrary to the
+	// virtio specification.
+	usedEvent *uint16
+}
+
+// newAvailableRing creates an available ring that uses the given underlying
+// memory. The length of the memory slice must match the size needed for the
+// ring (see [availableRingSize]) for the given queue size.
+func newAvailableRing(queueSize int, mem []byte) *AvailableRing {
+	ringSize := availableRingSize(queueSize)
+	if len(mem) != ringSize {
+		panic(fmt.Sprintf("memory size (%v) does not match required size "+
+			"for available ring: %v", len(mem), ringSize))
+	}
+
+	return &AvailableRing{
+		initialized: true,
+		flags:       (*availableRingFlag)(unsafe.Pointer(&mem[0])),
+		ringIndex:   (*uint16)(unsafe.Pointer(&mem[2])),
+		ring:        unsafe.Slice((*uint16)(unsafe.Pointer(&mem[4])), queueSize),
+		usedEvent:   (*uint16)(unsafe.Pointer(&mem[ringSize-2])),
+	}
+}
+
+// Address returns the pointer to the beginning of the ring in memory.
+// Do not modify the memory directly to not interfere with this implementation.
+func (r *AvailableRing) Address() uintptr {
+	if !r.initialized {
+		panic("available ring is not initialized")
+	}
+	return uintptr(unsafe.Pointer(r.flags))
+}
+
+// offer adds the given descriptor chain heads to the available ring and
+// advances the ring index accordingly to make the device process the new
+// descriptor chains.
+func (r *AvailableRing) offerElements(chains []UsedElement) {
+	//always called under lock
+	//r.mu.Lock()
+	//defer r.mu.Unlock()
+
+	// Add descriptor chain heads to the ring.
+	for offset, x := range chains {
+		// The 16-bit ring index may overflow. This is expected and is not an
+		// issue because the size of the ring array (which equals the queue
+		// size) is always a power of 2 and smaller than the highest possible
+		// 16-bit value.
+		insertIndex := int(*r.ringIndex+uint16(offset)) % len(r.ring)
+		r.ring[insertIndex] = x.GetHead()
+	}
+
+	// Increase the ring index by the number of descriptor chains added to the
+	// ring.
+	*r.ringIndex += uint16(len(chains))
+}
+
+func (r *AvailableRing) offer(chains []uint16) {
+	//always called under lock
+	//r.mu.Lock()
+	//defer r.mu.Unlock()
+
+	// Add descriptor chain heads to the ring.
+	for offset, x := range chains {
+		// The 16-bit ring index may overflow. This is expected and is not an
+		// issue because the size of the ring array (which equals the queue
+		// size) is always a power of 2 and smaller than the highest possible
+		// 16-bit value.
+		insertIndex := int(*r.ringIndex+uint16(offset)) % len(r.ring)
+		r.ring[insertIndex] = x
+	}
+
+	// Increase the ring index by the number of descriptor chains added to the
+	// ring.
+	*r.ringIndex += uint16(len(chains))
+}
+
+func (r *AvailableRing) offerSingle(x uint16) {
+	//always called under lock
+	//r.mu.Lock()
+	//defer r.mu.Unlock()
+
+	offset := 0
+	// Add descriptor chain heads to the ring.
+
+	// The 16-bit ring index may overflow. This is expected and is not an
+	// issue because the size of the ring array (which equals the queue
+	// size) is always a power of 2 and smaller than the highest possible
+	// 16-bit value.
+	insertIndex := int(*r.ringIndex+uint16(offset)) % len(r.ring)
+	r.ring[insertIndex] = x
+
+	// Increase the ring index by the number of descriptor chains added to the ring.
+	*r.ringIndex += 1
+}

+ 71 - 0
overlay/virtqueue/available_ring_internal_test.go

@@ -0,0 +1,71 @@
+package virtqueue
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestAvailableRing_MemoryLayout(t *testing.T) {
+	const queueSize = 2
+
+	memory := make([]byte, availableRingSize(queueSize))
+	r := newAvailableRing(queueSize, memory)
+
+	*r.flags = 0x01ff
+	*r.ringIndex = 1
+	r.ring[0] = 0x1234
+	r.ring[1] = 0x5678
+
+	assert.Equal(t, []byte{
+		0xff, 0x01,
+		0x01, 0x00,
+		0x34, 0x12,
+		0x78, 0x56,
+		0x00, 0x00,
+	}, memory)
+}
+
+func TestAvailableRing_Offer(t *testing.T) {
+	const queueSize = 8
+
+	chainHeads := []uint16{42, 33, 69}
+
+	tests := []struct {
+		name              string
+		startRingIndex    uint16
+		expectedRingIndex uint16
+		expectedRing      []uint16
+	}{
+		{
+			name:              "no overflow",
+			startRingIndex:    0,
+			expectedRingIndex: 3,
+			expectedRing:      []uint16{42, 33, 69, 0, 0, 0, 0, 0},
+		},
+		{
+			name:              "ring overflow",
+			startRingIndex:    6,
+			expectedRingIndex: 9,
+			expectedRing:      []uint16{69, 0, 0, 0, 0, 0, 42, 33},
+		},
+		{
+			name:              "index overflow",
+			startRingIndex:    65535,
+			expectedRingIndex: 2,
+			expectedRing:      []uint16{33, 69, 0, 0, 0, 0, 0, 42},
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			memory := make([]byte, availableRingSize(queueSize))
+			r := newAvailableRing(queueSize, memory)
+			*r.ringIndex = tt.startRingIndex
+
+			r.offer(chainHeads)
+
+			assert.Equal(t, tt.expectedRingIndex, *r.ringIndex)
+			assert.Equal(t, tt.expectedRing, r.ring)
+		})
+	}
+}

+ 43 - 0
overlay/virtqueue/descriptor.go

@@ -0,0 +1,43 @@
+package virtqueue
+
+// descriptorFlag is a flag that describes a [Descriptor].
+type descriptorFlag uint16
+
+const (
+	// descriptorFlagHasNext marks a descriptor chain as continuing via the next
+	// field.
+	descriptorFlagHasNext descriptorFlag = 1 << iota
+	// descriptorFlagWritable marks a buffer as device write-only (otherwise
+	// device read-only).
+	descriptorFlagWritable
+	// descriptorFlagIndirect means the buffer contains a list of buffer
+	// descriptors to provide an additional layer of indirection.
+	// Only allowed when the [virtio.FeatureIndirectDescriptors] feature was
+	// negotiated.
+	descriptorFlagIndirect
+)
+
+// descriptorSize is the number of bytes needed to store a [Descriptor] in
+// memory.
+const descriptorSize = 16
+
+// Descriptor describes (a part of) a buffer which is either read-only for the
+// device or write-only for the device (depending on [descriptorFlagWritable]).
+// Multiple descriptors can be chained to produce a "descriptor chain" that can
+// contain both device-readable and device-writable buffers. Device-readable
+// descriptors always come first in a chain. A single, large buffer may be
+// split up by chaining multiple similar descriptors that reference different
+// memory pages. This is required, because buffers may exceed a single page size
+// and the memory accessed by the device is expected to be continuous.
+type Descriptor struct {
+	// address is the address to the continuous memory holding the data for this
+	// descriptor.
+	address uintptr
+	// length is the amount of bytes stored at address.
+	length uint32
+	// flags that describe this descriptor.
+	flags descriptorFlag
+	// next contains the index of the next descriptor continuing this descriptor
+	// chain when the [descriptorFlagHasNext] flag is set.
+	next uint16
+}

+ 12 - 0
overlay/virtqueue/descriptor_internal_test.go

@@ -0,0 +1,12 @@
+package virtqueue
+
+import (
+	"testing"
+	"unsafe"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestDescriptor_Size(t *testing.T) {
+	assert.EqualValues(t, descriptorSize, unsafe.Sizeof(Descriptor{}))
+}

+ 465 - 0
overlay/virtqueue/descriptor_table.go

@@ -0,0 +1,465 @@
+package virtqueue
+
+import (
+	"errors"
+	"fmt"
+	"math"
+	"unsafe"
+
+	"golang.org/x/sys/unix"
+)
+
+var (
+	// ErrDescriptorChainEmpty is returned when a descriptor chain would contain
+	// no buffers, which is not allowed.
+	ErrDescriptorChainEmpty = errors.New("empty descriptor chains are not allowed")
+
+	// ErrNotEnoughFreeDescriptors is returned when the free descriptors are
+	// exhausted, meaning that the queue is full.
+	ErrNotEnoughFreeDescriptors = errors.New("not enough free descriptors, queue is full")
+
+	// ErrInvalidDescriptorChain is returned when a descriptor chain is not
+	// valid for a given operation.
+	ErrInvalidDescriptorChain = errors.New("invalid descriptor chain")
+)
+
+// noFreeHead is used to mark when all descriptors are in use and we have no
+// free chain. This value is impossible to occur as an index naturally, because
+// it exceeds the maximum queue size.
+const noFreeHead = uint16(math.MaxUint16)
+
+// descriptorTableSize is the number of bytes needed to store a
+// [DescriptorTable] with the given queue size in memory.
+func descriptorTableSize(queueSize int) int {
+	return descriptorSize * queueSize
+}
+
+// descriptorTableAlignment is the minimum alignment of a [DescriptorTable]
+// in memory, as required by the virtio spec.
+const descriptorTableAlignment = 16
+
+// DescriptorTable is a table that holds [Descriptor]s, addressed via their
+// index in the slice.
+type DescriptorTable struct {
+	descriptors []Descriptor
+
+	// freeHeadIndex is the index of the head of the descriptor chain which
+	// contains all currently unused descriptors. When all descriptors are in
+	// use, this has the special value of noFreeHead.
+	freeHeadIndex uint16
+	// freeNum tracks the number of descriptors which are currently not in use.
+	freeNum uint16
+
+	bufferBase uintptr
+	bufferSize int
+	itemSize   int
+}
+
+// newDescriptorTable creates a descriptor table that uses the given underlying
+// memory. The Length of the memory slice must match the size needed for the
+// descriptor table (see [descriptorTableSize]) for the given queue size.
+//
+// Before this descriptor table can be used, [initialize] must be called.
+func newDescriptorTable(queueSize int, mem []byte, itemSize int) *DescriptorTable {
+	dtSize := descriptorTableSize(queueSize)
+	if len(mem) != dtSize {
+		panic(fmt.Sprintf("memory size (%v) does not match required size "+
+			"for descriptor table: %v", len(mem), dtSize))
+	}
+
+	return &DescriptorTable{
+		descriptors: unsafe.Slice((*Descriptor)(unsafe.Pointer(&mem[0])), queueSize),
+		// We have no free descriptors until they were initialized.
+		freeHeadIndex: noFreeHead,
+		freeNum:       0,
+		itemSize:      itemSize, //todo configurable? needs to be page-aligned
+	}
+}
+
+// Address returns the pointer to the beginning of the descriptor table in
+// memory. Do not modify the memory directly to not interfere with this
+// implementation.
+func (dt *DescriptorTable) Address() uintptr {
+	if dt.descriptors == nil {
+		panic("descriptor table is not initialized")
+	}
+	//should be same as dt.bufferBase
+	return uintptr(unsafe.Pointer(&dt.descriptors[0]))
+}
+
+func (dt *DescriptorTable) Size() uintptr {
+	if dt.descriptors == nil {
+		panic("descriptor table is not initialized")
+	}
+	return uintptr(dt.bufferSize)
+}
+
+// BufferAddresses returns a map of pointer->size for all allocations used by the table
+func (dt *DescriptorTable) BufferAddresses() map[uintptr]int {
+	if dt.descriptors == nil {
+		panic("descriptor table is not initialized")
+	}
+
+	return map[uintptr]int{dt.bufferBase: dt.bufferSize}
+}
+
+// initializeDescriptors allocates buffers with the size of a full memory page
+// for each descriptor in the table. While this may be a bit wasteful, it makes
+// dealing with descriptors way easier. Without this preallocation, we would
+// have to allocate and free memory on demand, increasing complexity.
+//
+// All descriptors will be marked as free and will form a free chain. The
+// addresses of all descriptors will be populated while their length remains
+// zero.
+func (dt *DescriptorTable) initializeDescriptors() error {
+	numDescriptors := len(dt.descriptors)
+
+	// Allocate ONE large region for all buffers
+	totalSize := dt.itemSize * numDescriptors
+	basePtr, err := unix.MmapPtr(-1, 0, nil, uintptr(totalSize),
+		unix.PROT_READ|unix.PROT_WRITE,
+		unix.MAP_PRIVATE|unix.MAP_ANONYMOUS)
+	if err != nil {
+		return fmt.Errorf("allocate buffer memory for descriptors: %w", err)
+	}
+
+	// Store the base for cleanup later
+	dt.bufferBase = uintptr(basePtr)
+	dt.bufferSize = totalSize
+
+	for i := range dt.descriptors {
+		dt.descriptors[i] = Descriptor{
+			address: dt.bufferBase + uintptr(i*dt.itemSize),
+			length:  0,
+			// All descriptors should form a free chain that loops around.
+			flags: descriptorFlagHasNext,
+			next:  uint16((i + 1) % len(dt.descriptors)),
+		}
+	}
+
+	// All descriptors are free to use now.
+	dt.freeHeadIndex = 0
+	dt.freeNum = uint16(len(dt.descriptors))
+
+	return nil
+}
+
+// releaseBuffers releases all allocated buffers for this descriptor table.
+// The implementation will try to release as many buffers as possible and
+// collect potential errors before returning them.
+// The descriptor table should no longer be used after calling this.
+func (dt *DescriptorTable) releaseBuffers() error {
+	for i := range dt.descriptors {
+		descriptor := &dt.descriptors[i]
+		descriptor.address = 0
+	}
+
+	// As a safety measure, make sure no descriptors can be used anymore.
+	dt.freeHeadIndex = noFreeHead
+	dt.freeNum = 0
+
+	if dt.bufferBase != 0 {
+		// The pointer points to memory not managed by Go, so this conversion
+		// is safe. See https://github.com/golang/go/issues/58625
+		dt.bufferBase = 0
+		//goland:noinspection GoVetUnsafePointer
+		err := unix.MunmapPtr(unsafe.Pointer(dt.bufferBase), uintptr(dt.bufferSize))
+		if err != nil {
+			return fmt.Errorf("release buffer memory: %w", err)
+		}
+	}
+
+	return nil
+}
+
+func (dt *DescriptorTable) CreateDescriptorForOutputs() (uint16, error) {
+	//todo just fill the damn table
+	// Do we still have enough free descriptors?
+
+	if 1 > dt.freeNum {
+		return 0, ErrNotEnoughFreeDescriptors
+	}
+
+	// Above validation ensured that there is at least one free descriptor, so
+	// the free descriptor chain head should be valid.
+	if dt.freeHeadIndex == noFreeHead {
+		panic("free descriptor chain head is unset but there should be free descriptors")
+	}
+
+	// To avoid having to iterate over the whole table to find the descriptor
+	// pointing to the head just to replace the free head, we instead always
+	// create descriptor chains from the descriptors coming after the head.
+	// This way we only have to touch the head as a last resort, when all other
+	// descriptors are already used.
+	head := dt.descriptors[dt.freeHeadIndex].next
+	desc := &dt.descriptors[head]
+	next := desc.next
+
+	checkUnusedDescriptorLength(head, desc)
+
+	// Give the device the maximum available number of bytes to write into.
+	desc.length = uint32(dt.itemSize)
+	desc.flags = 0 // descriptorFlagWritable
+	desc.next = 0  // Not necessary to clear this, it's just for looks.
+
+	dt.freeNum -= 1
+
+	if dt.freeNum == 0 {
+		// The last descriptor in the chain should be the free chain head
+		// itself.
+		if next != dt.freeHeadIndex {
+			panic("descriptor chain takes up all free descriptors but does not end with the free chain head")
+		}
+
+		// When this new chain takes up all remaining descriptors, we no longer
+		// have a free chain.
+		dt.freeHeadIndex = noFreeHead
+	} else {
+		// We took some descriptors out of the free chain, so make sure to close
+		// the circle again.
+		dt.descriptors[dt.freeHeadIndex].next = next
+	}
+
+	return head, nil
+}
+
+func (dt *DescriptorTable) createDescriptorForInputs() (uint16, error) {
+	// Do we still have enough free descriptors?
+	if 1 > dt.freeNum {
+		return 0, ErrNotEnoughFreeDescriptors
+	}
+
+	// Above validation ensured that there is at least one free descriptor, so
+	// the free descriptor chain head should be valid.
+	if dt.freeHeadIndex == noFreeHead {
+		panic("free descriptor chain head is unset but there should be free descriptors")
+	}
+
+	// To avoid having to iterate over the whole table to find the descriptor
+	// pointing to the head just to replace the free head, we instead always
+	// create descriptor chains from the descriptors coming after the head.
+	// This way we only have to touch the head as a last resort, when all other
+	// descriptors are already used.
+	head := dt.descriptors[dt.freeHeadIndex].next
+	desc := &dt.descriptors[head]
+	next := desc.next
+
+	checkUnusedDescriptorLength(head, desc)
+
+	// Give the device the maximum available number of bytes to write into.
+	desc.length = uint32(dt.itemSize)
+	desc.flags = descriptorFlagWritable
+	desc.next = 0 // Not necessary to clear this, it's just for looks.
+
+	dt.freeNum -= 1
+
+	if dt.freeNum == 0 {
+		// The last descriptor in the chain should be the free chain head
+		// itself.
+		if next != dt.freeHeadIndex {
+			panic("descriptor chain takes up all free descriptors but does not end with the free chain head")
+		}
+
+		// When this new chain takes up all remaining descriptors, we no longer
+		// have a free chain.
+		dt.freeHeadIndex = noFreeHead
+	} else {
+		// We took some descriptors out of the free chain, so make sure to close
+		// the circle again.
+		dt.descriptors[dt.freeHeadIndex].next = next
+	}
+
+	return head, nil
+}
+
+// TODO: Implement a zero-copy variant of createDescriptorChain?
+
+// getDescriptorChain returns the device-readable buffers (out buffers) and
+// device-writable buffers (in buffers) of the descriptor chain that starts with
+// the given head index. The descriptor chain must have been created using
+// [createDescriptorChain] and must not have been freed yet (meaning that the
+// head index must not be contained in the free chain).
+//
+// Be careful to only access the returned buffer slices when the device has not
+// yet or is no longer using them. They must not be accessed after
+// [freeDescriptorChain] has been called.
+func (dt *DescriptorTable) getDescriptorChain(head uint16) (outBuffers, inBuffers [][]byte, err error) {
+	if int(head) > len(dt.descriptors) {
+		return nil, nil, fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain)
+	}
+
+	// Iterate over the chain. The iteration is limited to the queue size to
+	// avoid ending up in an endless loop when things go very wrong.
+	next := head
+	for range len(dt.descriptors) {
+		if next == dt.freeHeadIndex {
+			return nil, nil, fmt.Errorf("%w: must not be part of the free chain", ErrInvalidDescriptorChain)
+		}
+
+		desc := &dt.descriptors[next]
+
+		// The descriptor address points to memory not managed by Go, so this
+		// conversion is safe. See https://github.com/golang/go/issues/58625
+		//goland:noinspection GoVetUnsafePointer
+		bs := unsafe.Slice((*byte)(unsafe.Pointer(desc.address)), desc.length)
+
+		if desc.flags&descriptorFlagWritable == 0 {
+			outBuffers = append(outBuffers, bs)
+		} else {
+			inBuffers = append(inBuffers, bs)
+		}
+
+		// Is this the tail of the chain?
+		if desc.flags&descriptorFlagHasNext == 0 {
+			break
+		}
+
+		// Detect loops.
+		if desc.next == head {
+			return nil, nil, fmt.Errorf("%w: contains a loop", ErrInvalidDescriptorChain)
+		}
+
+		next = desc.next
+	}
+
+	return
+}
+
+func (dt *DescriptorTable) getDescriptorItem(head uint16) ([]byte, error) {
+	if int(head) > len(dt.descriptors) {
+		return nil, fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain)
+	}
+
+	desc := &dt.descriptors[head] //todo this is a pretty nasty hack with no checks
+
+	// The descriptor address points to memory not managed by Go, so this
+	// conversion is safe. See https://github.com/golang/go/issues/58625
+	//goland:noinspection GoVetUnsafePointer
+	bs := unsafe.Slice((*byte)(unsafe.Pointer(desc.address)), desc.length)
+	return bs, nil
+}
+
+func (dt *DescriptorTable) getDescriptorInbuffers(head uint16, inBuffers *[][]byte) error {
+	if int(head) > len(dt.descriptors) {
+		return fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain)
+	}
+
+	// Iterate over the chain. The iteration is limited to the queue size to
+	// avoid ending up in an endless loop when things go very wrong.
+	next := head
+	for range len(dt.descriptors) {
+		if next == dt.freeHeadIndex {
+			return fmt.Errorf("%w: must not be part of the free chain", ErrInvalidDescriptorChain)
+		}
+
+		desc := &dt.descriptors[next]
+
+		// The descriptor address points to memory not managed by Go, so this
+		// conversion is safe. See https://github.com/golang/go/issues/58625
+		//goland:noinspection GoVetUnsafePointer
+		bs := unsafe.Slice((*byte)(unsafe.Pointer(desc.address)), desc.length)
+
+		if desc.flags&descriptorFlagWritable == 0 {
+			return fmt.Errorf("there should not be an outbuffer in %d", head)
+		} else {
+			*inBuffers = append(*inBuffers, bs)
+		}
+
+		// Is this the tail of the chain?
+		if desc.flags&descriptorFlagHasNext == 0 {
+			break
+		}
+
+		// Detect loops.
+		if desc.next == head {
+			return fmt.Errorf("%w: contains a loop", ErrInvalidDescriptorChain)
+		}
+
+		next = desc.next
+	}
+
+	return nil
+}
+
+// freeDescriptorChain can be used to free a descriptor chain when it is no
+// longer in use. The descriptor chain that starts with the given index will be
+// put back into the free chain, so the descriptors can be used for later calls
+// of [createDescriptorChain].
+// The descriptor chain must have been created using [createDescriptorChain] and
+// must not have been freed yet (meaning that the head index must not be
+// contained in the free chain).
+func (dt *DescriptorTable) freeDescriptorChain(head uint16) error {
+	if int(head) > len(dt.descriptors) {
+		return fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain)
+	}
+
+	// Iterate over the chain. The iteration is limited to the queue size to
+	// avoid ending up in an endless loop when things go very wrong.
+	next := head
+	var tailDesc *Descriptor
+	var chainLen uint16
+	for range len(dt.descriptors) {
+		if next == dt.freeHeadIndex {
+			return fmt.Errorf("%w: must not be part of the free chain", ErrInvalidDescriptorChain)
+		}
+
+		desc := &dt.descriptors[next]
+		chainLen++
+
+		// Set the length of all unused descriptors back to zero.
+		desc.length = 0
+
+		// Unset all flags except the next flag.
+		desc.flags &= descriptorFlagHasNext
+
+		// Is this the tail of the chain?
+		if desc.flags&descriptorFlagHasNext == 0 {
+			tailDesc = desc
+			break
+		}
+
+		// Detect loops.
+		if desc.next == head {
+			return fmt.Errorf("%w: contains a loop", ErrInvalidDescriptorChain)
+		}
+
+		next = desc.next
+	}
+	if tailDesc == nil {
+		// A descriptor chain longer than the queue size but without loops
+		// should be impossible.
+		panic(fmt.Sprintf("could not find a tail for descriptor chain starting at %d", head))
+	}
+
+	// The tail descriptor does not have the next flag set, but when it comes
+	// back into the free chain, it should have.
+	tailDesc.flags = descriptorFlagHasNext
+
+	if dt.freeHeadIndex == noFreeHead {
+		// The whole free chain was used up, so we turn this returned descriptor
+		// chain into the new free chain by completing the circle and using its
+		// head.
+		tailDesc.next = head
+		dt.freeHeadIndex = head
+	} else {
+		// Attach the returned chain at the beginning of the free chain but
+		// right after the free chain head.
+		freeHeadDesc := &dt.descriptors[dt.freeHeadIndex]
+		tailDesc.next = freeHeadDesc.next
+		freeHeadDesc.next = head
+	}
+
+	dt.freeNum += chainLen
+
+	return nil
+}
+
+// checkUnusedDescriptorLength asserts that the length of an unused descriptor
+// is zero, as it should be.
+// This is not a requirement by the virtio spec but rather a thing we do to
+// notice when our algorithm goes sideways.
+func checkUnusedDescriptorLength(index uint16, desc *Descriptor) {
+	if desc.length != 0 {
+		panic(fmt.Sprintf("descriptor %d should be unused but has a non-zero length", index))
+	}
+}

+ 7 - 0
overlay/virtqueue/doc.go

@@ -0,0 +1,7 @@
+// Package virtqueue implements the driver-side for a virtio queue as described
+// in the specification:
+// https://docs.oasis-open.org/virtio/virtio/v1.2/csd01/virtio-v1.2-csd01.html#x1-270006
+// This package does not make assumptions about the device that consumes the
+// queue. It rather just allocates the queue structures in memory and provides
+// methods to interact with it.
+package virtqueue

+ 45 - 0
overlay/virtqueue/eventfd_test.go

@@ -0,0 +1,45 @@
+package virtqueue
+
+import (
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+	"gvisor.dev/gvisor/pkg/eventfd"
+)
+
+// Tests how an eventfd and a waiting goroutine can be gracefully closed.
+// Extends the eventfd test suite:
+// https://github.com/google/gvisor/blob/0799336d64be65eb97d330606c30162dc3440cab/pkg/eventfd/eventfd_test.go
+func TestEventFD_CancelWait(t *testing.T) {
+	efd, err := eventfd.Create()
+	require.NoError(t, err)
+	t.Cleanup(func() {
+		assert.NoError(t, efd.Close())
+	})
+
+	var stop bool
+
+	done := make(chan struct{})
+	go func() {
+		for !stop {
+			_ = efd.Wait()
+		}
+		close(done)
+	}()
+	select {
+	case <-done:
+		t.Fatalf("goroutine ended early")
+	case <-time.After(500 * time.Millisecond):
+	}
+
+	stop = true
+	assert.NoError(t, efd.Notify())
+	select {
+	case <-done:
+		break
+	case <-time.After(5 * time.Second):
+		t.Error("goroutine did not end")
+	}
+}

+ 33 - 0
overlay/virtqueue/size.go

@@ -0,0 +1,33 @@
+package virtqueue
+
+import (
+	"errors"
+	"fmt"
+)
+
+// ErrQueueSizeInvalid is returned when a queue size is invalid.
+var ErrQueueSizeInvalid = errors.New("queue size is invalid")
+
+// CheckQueueSize checks if the given value would be a valid size for a
+// virtqueue and returns an [ErrQueueSizeInvalid], if not.
+func CheckQueueSize(queueSize int) error {
+	if queueSize <= 0 {
+		return fmt.Errorf("%w: %d is too small", ErrQueueSizeInvalid, queueSize)
+	}
+
+	// The queue size must always be a power of 2.
+	// This ensures that ring indexes wrap correctly when the 16-bit integers
+	// overflow.
+	if queueSize&(queueSize-1) != 0 {
+		return fmt.Errorf("%w: %d is not a power of 2", ErrQueueSizeInvalid, queueSize)
+	}
+
+	// The largest power of 2 that fits into a 16-bit integer is 32768.
+	// 2 * 32768 would be 65536 which no longer fits.
+	if queueSize > 32768 {
+		return fmt.Errorf("%w: %d is larger than the maximum possible queue size 32768",
+			ErrQueueSizeInvalid, queueSize)
+	}
+
+	return nil
+}

+ 59 - 0
overlay/virtqueue/size_test.go

@@ -0,0 +1,59 @@
+package virtqueue
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestCheckQueueSize(t *testing.T) {
+	tests := []struct {
+		name        string
+		queueSize   int
+		containsErr string
+	}{
+		{
+			name:        "negative",
+			queueSize:   -1,
+			containsErr: "too small",
+		},
+		{
+			name:        "zero",
+			queueSize:   0,
+			containsErr: "too small",
+		},
+		{
+			name:        "not a power of 2",
+			queueSize:   24,
+			containsErr: "not a power of 2",
+		},
+		{
+			name:        "too large",
+			queueSize:   65536,
+			containsErr: "larger than the maximum",
+		},
+		{
+			name:      "valid 1",
+			queueSize: 1,
+		},
+		{
+			name:      "valid 256",
+			queueSize: 256,
+		},
+
+		{
+			name:      "valid 32768",
+			queueSize: 32768,
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			err := CheckQueueSize(tt.queueSize)
+			if tt.containsErr != "" {
+				assert.ErrorContains(t, err, tt.containsErr)
+			} else {
+				assert.NoError(t, err)
+			}
+		})
+	}
+}

+ 421 - 0
overlay/virtqueue/split_virtqueue.go

@@ -0,0 +1,421 @@
+package virtqueue
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"os"
+
+	"github.com/slackhq/nebula/overlay/eventfd"
+	"golang.org/x/sys/unix"
+)
+
+// SplitQueue is a virtqueue that consists of several parts, where each part is
+// writeable by either the driver or the device, but not both.
+type SplitQueue struct {
+	// size is the size of the queue.
+	size int
+	// buf is the underlying memory used for the queue.
+	buf []byte
+
+	descriptorTable *DescriptorTable
+	availableRing   *AvailableRing
+	usedRing        *UsedRing
+
+	// kickEventFD is used to signal the device when descriptor chains were
+	// added to the available ring.
+	kickEventFD eventfd.EventFD
+	// callEventFD is used by the device to signal when it has used descriptor
+	// chains and put them in the used ring.
+	callEventFD eventfd.EventFD
+
+	// stop is used by [SplitQueue.Close] to cancel the goroutine that handles
+	// used buffer notifications. It blocks until the goroutine ended.
+	stop func() error
+
+	itemSize int
+
+	epoll eventfd.Epoll
+	more  int
+}
+
+// NewSplitQueue allocates a new [SplitQueue] in memory. The given queue size
+// specifies the number of entries/buffers the queue can hold. This also affects
+// the memory consumption.
+func NewSplitQueue(queueSize int, itemSize int) (_ *SplitQueue, err error) {
+	if err = CheckQueueSize(queueSize); err != nil {
+		return nil, err
+	}
+
+	if itemSize%os.Getpagesize() != 0 {
+		return nil, errors.New("split queue size must be multiple of os.Getpagesize()")
+	}
+
+	sq := SplitQueue{
+		size:     queueSize,
+		itemSize: itemSize,
+	}
+
+	// Clean up a partially initialized queue when something fails.
+	defer func() {
+		if err != nil {
+			_ = sq.Close()
+		}
+	}()
+
+	// There are multiple ways for how the memory for the virtqueue could be
+	// allocated. We could use Go native structs with arrays inside them, but
+	// this wouldn't allow us to make the queue size configurable. And including
+	// a slice in the Go structs wouldn't work, because this would just put the
+	// Go slice descriptor into the memory region which the virtio device will
+	// not understand.
+	// Additionally, Go does not allow us to ensure a correct alignment of the
+	// parts of the virtqueue, as it is required by the virtio specification.
+	//
+	// To resolve this, let's just allocate the memory manually by allocating
+	// one or more memory pages, depending on the queue size. Making the
+	// virtqueue start at the beginning of a page is not strictly necessary, as
+	// the virtio specification does not require it to be continuous in the
+	// physical memory of the host (e.g. the vhost implementation in the kernel
+	// always uses copy_from_user to access it), but this makes it very easy to
+	// guarantee the alignment. Also, it is not required for the virtqueue parts
+	// to be in the same memory region, as we pass separate pointers to them to
+	// the device, but this design just makes things easier to implement.
+	//
+	// One added benefit of allocating the memory manually is, that we have full
+	// control over its lifetime and don't risk the garbage collector to collect
+	// our valuable structures while the device still works with them.
+
+	// The descriptor table is at the start of the page, so alignment is not an
+	// issue here.
+	descriptorTableStart := 0
+	descriptorTableEnd := descriptorTableStart + descriptorTableSize(queueSize)
+	availableRingStart := align(descriptorTableEnd, availableRingAlignment)
+	availableRingEnd := availableRingStart + availableRingSize(queueSize)
+	usedRingStart := align(availableRingEnd, usedRingAlignment)
+	usedRingEnd := usedRingStart + usedRingSize(queueSize)
+
+	sq.buf, err = unix.Mmap(-1, 0, usedRingEnd,
+		unix.PROT_READ|unix.PROT_WRITE,
+		unix.MAP_PRIVATE|unix.MAP_ANONYMOUS)
+	if err != nil {
+		return nil, fmt.Errorf("allocate virtqueue buffer: %w", err)
+	}
+
+	sq.descriptorTable = newDescriptorTable(queueSize, sq.buf[descriptorTableStart:descriptorTableEnd], sq.itemSize)
+	sq.availableRing = newAvailableRing(queueSize, sq.buf[availableRingStart:availableRingEnd])
+	sq.usedRing = newUsedRing(queueSize, sq.buf[usedRingStart:usedRingEnd])
+
+	sq.kickEventFD, err = eventfd.New()
+	if err != nil {
+		return nil, fmt.Errorf("create kick event file descriptor: %w", err)
+	}
+	sq.callEventFD, err = eventfd.New()
+	if err != nil {
+		return nil, fmt.Errorf("create call event file descriptor: %w", err)
+	}
+
+	if err = sq.descriptorTable.initializeDescriptors(); err != nil {
+		return nil, fmt.Errorf("initialize descriptors: %w", err)
+	}
+
+	sq.epoll, err = eventfd.NewEpoll()
+	if err != nil {
+		return nil, err
+	}
+	err = sq.epoll.AddEvent(sq.callEventFD.FD())
+	if err != nil {
+		return nil, err
+	}
+
+	// Consume used buffer notifications in the background.
+	sq.stop = sq.startConsumeUsedRing()
+
+	return &sq, nil
+}
+
+// Size returns the size of this queue, which is the number of entries/buffers
+// this queue can hold.
+func (sq *SplitQueue) Size() int {
+	return sq.size
+}
+
+// DescriptorTable returns the [DescriptorTable] behind this queue.
+func (sq *SplitQueue) DescriptorTable() *DescriptorTable {
+	return sq.descriptorTable
+}
+
+// AvailableRing returns the [AvailableRing] behind this queue.
+func (sq *SplitQueue) AvailableRing() *AvailableRing {
+	return sq.availableRing
+}
+
+// UsedRing returns the [UsedRing] behind this queue.
+func (sq *SplitQueue) UsedRing() *UsedRing {
+	return sq.usedRing
+}
+
+// KickEventFD returns the kick event file descriptor behind this queue.
+// The returned file descriptor should be used with great care to not interfere
+// with this implementation.
+func (sq *SplitQueue) KickEventFD() int {
+	return sq.kickEventFD.FD()
+}
+
+// CallEventFD returns the call event file descriptor behind this queue.
+// The returned file descriptor should be used with great care to not interfere
+// with this implementation.
+func (sq *SplitQueue) CallEventFD() int {
+	return sq.callEventFD.FD()
+}
+
+// startConsumeUsedRing starts a goroutine that runs [consumeUsedRing].
+// A function is returned that can be used to gracefully cancel it. todo rename
+func (sq *SplitQueue) startConsumeUsedRing() func() error {
+	return func() error {
+
+		// The goroutine blocks until it receives a signal on the event file
+		// descriptor, so it will never notice the context being canceled.
+		// To resolve this, we can just produce a fake-signal ourselves to wake
+		// it up.
+		if err := sq.callEventFD.Kick(); err != nil {
+			return fmt.Errorf("wake up goroutine: %w", err)
+		}
+		return nil
+	}
+}
+
+func (sq *SplitQueue) TakeSingle(ctx context.Context) (uint16, error) {
+	var n int
+	var err error
+	for ctx.Err() == nil {
+		out, ok := sq.usedRing.takeOne()
+		if ok {
+			return out, nil
+		}
+		// Wait for a signal from the device.
+		if n, err = sq.epoll.Block(); err != nil {
+			return 0, fmt.Errorf("wait: %w", err)
+		}
+
+		if n > 0 {
+			out, ok = sq.usedRing.takeOne()
+			if ok {
+				_ = sq.epoll.Clear() //???
+				return out, nil
+			} else {
+				continue //???
+			}
+		}
+	}
+	return 0, ctx.Err()
+}
+
+func (sq *SplitQueue) BlockAndGetHeadsCapped(ctx context.Context, maxToTake int) ([]UsedElement, error) {
+	var n int
+	var err error
+	for ctx.Err() == nil {
+
+		//we have leftovers in the fridge
+		if sq.more > 0 {
+			stillNeedToTake, out := sq.usedRing.take(maxToTake)
+			sq.more = stillNeedToTake
+			return out, nil
+		}
+		//look inside the fridge
+		stillNeedToTake, out := sq.usedRing.take(maxToTake)
+		if len(out) > 0 {
+			sq.more = stillNeedToTake
+			return out, nil
+		}
+		//fridge is empty I guess
+
+		// Wait for a signal from the device.
+		if n, err = sq.epoll.Block(); err != nil {
+			return nil, fmt.Errorf("wait: %w", err)
+		}
+		if n > 0 {
+			_ = sq.epoll.Clear() //???
+			stillNeedToTake, out = sq.usedRing.take(maxToTake)
+			sq.more = stillNeedToTake
+			return out, nil
+		}
+	}
+
+	return nil, ctx.Err()
+}
+
+// OfferDescriptorChain offers a descriptor chain to the device which contains a
+// number of device-readable buffers (out buffers) and device-writable buffers
+// (in buffers).
+//
+// All buffers in the outBuffers slice will be concatenated by chaining
+// descriptors, one for each buffer in the slice. When a buffer is too large to
+// fit into a single descriptor (limited by the system's page size), it will be
+// split up into multiple descriptors within the chain.
+// When numInBuffers is greater than zero, the given number of device-writable
+// descriptors will be appended to the end of the chain, each referencing a
+// whole memory page (see [os.Getpagesize]).
+//
+// When the queue is full and no more descriptor chains can be added, a wrapped
+// [ErrNotEnoughFreeDescriptors] will be returned. If you set waitFree to true,
+// this method will handle this error and will block instead until there are
+// enough free descriptors again.
+//
+// After defining the descriptor chain in the [DescriptorTable], the index of
+// the head of the chain will be made available to the device using the
+// [AvailableRing] and will be returned by this method.
+// Callers should read from the [SplitQueue.UsedDescriptorChains] channel to be
+// notified when the descriptor chain was used by the device and should free the
+// used descriptor chains again using [SplitQueue.FreeDescriptorChain] when
+// they're done with them. When this does not happen, the queue will run full
+// and any further calls to [SplitQueue.OfferDescriptorChain] will stall.
+
+func (sq *SplitQueue) OfferInDescriptorChains() (uint16, error) {
+	// Create a descriptor chain for the given buffers.
+	var (
+		head uint16
+		err  error
+	)
+	for {
+		head, err = sq.descriptorTable.createDescriptorForInputs()
+		if err == nil {
+			break
+		}
+
+		// I don't wanna use errors.Is, it's slow
+		//goland:noinspection GoDirectComparisonOfErrors
+		if err == ErrNotEnoughFreeDescriptors {
+			return 0, err
+		} else {
+			return 0, fmt.Errorf("create descriptor chain: %w", err)
+		}
+	}
+
+	// Make the descriptor chain available to the device.
+	sq.availableRing.offerSingle(head)
+
+	// Notify the device to make it process the updated available ring.
+	if err := sq.kickEventFD.Kick(); err != nil {
+		return head, fmt.Errorf("notify device: %w", err)
+	}
+
+	return head, nil
+}
+
+// GetDescriptorChain returns the device-readable buffers (out buffers) and
+// device-writable buffers (in buffers) of the descriptor chain with the given
+// head index.
+// The head index must be one that was returned by a previous call to
+// [SplitQueue.OfferDescriptorChain] and the descriptor chain must not have been
+// freed yet.
+//
+// Be careful to only access the returned buffer slices when the device is no
+// longer using them. They must not be accessed after
+// [SplitQueue.FreeDescriptorChain] has been called.
+func (sq *SplitQueue) GetDescriptorChain(head uint16) (outBuffers, inBuffers [][]byte, err error) {
+	return sq.descriptorTable.getDescriptorChain(head)
+}
+
+func (sq *SplitQueue) GetDescriptorItem(head uint16) ([]byte, error) {
+	sq.descriptorTable.descriptors[head].length = uint32(sq.descriptorTable.itemSize)
+	return sq.descriptorTable.getDescriptorItem(head)
+}
+
+func (sq *SplitQueue) GetDescriptorInbuffers(head uint16, inBuffers *[][]byte) error {
+	return sq.descriptorTable.getDescriptorInbuffers(head, inBuffers)
+}
+
+// FreeDescriptorChain frees the descriptor chain with the given head index.
+// The head index must be one that was returned by a previous call to
+// [SplitQueue.OfferDescriptorChain] and the descriptor chain must not have been
+// freed yet.
+//
+// This creates new room in the queue which can be used by following
+// [SplitQueue.OfferDescriptorChain] calls.
+// When there are outstanding calls for [SplitQueue.OfferDescriptorChain] that
+// are waiting for free room in the queue, they may become unblocked by this.
+func (sq *SplitQueue) FreeDescriptorChain(head uint16) error {
+	//not called under lock
+	if err := sq.descriptorTable.freeDescriptorChain(head); err != nil {
+		return fmt.Errorf("free: %w", err)
+	}
+
+	return nil
+}
+
+func (sq *SplitQueue) SetDescSize(head uint16, sz int) {
+	//not called under lock
+	sq.descriptorTable.descriptors[int(head)].length = uint32(sz)
+}
+
+func (sq *SplitQueue) OfferDescriptorChains(chains []uint16, kick bool) error {
+	//todo not doing this may break eventually?
+	//not called under lock
+	//if err := sq.descriptorTable.freeDescriptorChain(head); err != nil {
+	//	return fmt.Errorf("free: %w", err)
+	//}
+
+	// Make the descriptor chain available to the device.
+	sq.availableRing.offer(chains)
+
+	// Notify the device to make it process the updated available ring.
+	if kick {
+		return sq.Kick()
+	}
+
+	return nil
+}
+
+func (sq *SplitQueue) Kick() error {
+	if err := sq.kickEventFD.Kick(); err != nil {
+		return fmt.Errorf("notify device: %w", err)
+	}
+	return nil
+}
+
+// Close releases all resources used for this queue.
+// The implementation will try to release as many resources as possible and
+// collect potential errors before returning them.
+func (sq *SplitQueue) Close() error {
+	var errs []error
+
+	if sq.stop != nil {
+		// This has to happen before the event file descriptors may be closed.
+		if err := sq.stop(); err != nil {
+			errs = append(errs, fmt.Errorf("stop consume used ring: %w", err))
+		}
+
+		// Make sure that this code block is executed only once.
+		sq.stop = nil
+	}
+
+	if err := sq.kickEventFD.Close(); err != nil {
+		errs = append(errs, fmt.Errorf("close kick event file descriptor: %w", err))
+	}
+	if err := sq.callEventFD.Close(); err != nil {
+		errs = append(errs, fmt.Errorf("close call event file descriptor: %w", err))
+	}
+
+	if err := sq.descriptorTable.releaseBuffers(); err != nil {
+		errs = append(errs, fmt.Errorf("release descriptor buffers: %w", err))
+	}
+
+	if sq.buf != nil {
+		if err := unix.Munmap(sq.buf); err == nil {
+			sq.buf = nil
+		} else {
+			errs = append(errs, fmt.Errorf("unmap virtqueue buffer: %w", err))
+		}
+	}
+
+	return errors.Join(errs...)
+}
+
+func align(index, alignment int) int {
+	remainder := index % alignment
+	if remainder == 0 {
+		return index
+	}
+	return index + alignment - remainder
+}

+ 21 - 0
overlay/virtqueue/used_element.go

@@ -0,0 +1,21 @@
+package virtqueue
+
+// usedElementSize is the number of bytes needed to store a [UsedElement] in
+// memory.
+const usedElementSize = 8
+
+// UsedElement is an element of the [UsedRing] and describes a descriptor chain
+// that was used by the device.
+type UsedElement struct {
+	// DescriptorIndex is the index of the head of the used descriptor chain in
+	// the [DescriptorTable].
+	// The index is 32-bit here for padding reasons.
+	DescriptorIndex uint32
+	// Length is the number of bytes written into the device writable portion of
+	// the buffer described by the descriptor chain.
+	Length uint32
+}
+
+func (u *UsedElement) GetHead() uint16 {
+	return uint16(u.DescriptorIndex)
+}

+ 12 - 0
overlay/virtqueue/used_element_internal_test.go

@@ -0,0 +1,12 @@
+package virtqueue
+
+import (
+	"testing"
+	"unsafe"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestUsedElement_Size(t *testing.T) {
+	assert.EqualValues(t, usedElementSize, unsafe.Sizeof(UsedElement{}))
+}

+ 184 - 0
overlay/virtqueue/used_ring.go

@@ -0,0 +1,184 @@
+package virtqueue
+
+import (
+	"fmt"
+	"unsafe"
+)
+
+// usedRingFlag is a flag that describes a [UsedRing].
+type usedRingFlag uint16
+
+const (
+	// usedRingFlagNoNotify is used by the host to advise the guest to not
+	// kick it when adding a buffer. It's unreliable, so it's simply an
+	// optimization. Guest will still kick when it's out of buffers.
+	usedRingFlagNoNotify usedRingFlag = 1 << iota
+)
+
+// usedRingSize is the number of bytes needed to store a [UsedRing] with the
+// given queue size in memory.
+func usedRingSize(queueSize int) int {
+	return 6 + usedElementSize*queueSize
+}
+
+// usedRingAlignment is the minimum alignment of a [UsedRing] in memory, as
+// required by the virtio spec.
+const usedRingAlignment = 4
+
+// UsedRing is where the device returns descriptor chains once it is done with
+// them. Each ring entry is a [UsedElement]. It is only written to by the device
+// and read by the driver.
+//
+// Because the size of the ring depends on the queue size, we cannot define a
+// Go struct with a static size that maps to the memory of the ring. Instead,
+// this struct only contains pointers to the corresponding memory areas.
+type UsedRing struct {
+	initialized bool
+
+	// flags that describe this ring.
+	flags *usedRingFlag
+	// ringIndex indicates where the device would put the next entry into the
+	// ring (modulo the queue size).
+	ringIndex *uint16
+	// ring contains the [UsedElement]s. It wraps around at queue size.
+	ring []UsedElement
+	// availableEvent is not used by this implementation, but we reserve it
+	// anyway to avoid issues in case a device may try to write to it, contrary
+	// to the virtio specification.
+	availableEvent *uint16
+
+	// lastIndex is the internal ringIndex up to which all [UsedElement]s were
+	// processed.
+	lastIndex uint16
+
+	//mu sync.Mutex
+}
+
+// newUsedRing creates a used ring that uses the given underlying memory. The
+// length of the memory slice must match the size needed for the ring (see
+// [usedRingSize]) for the given queue size.
+func newUsedRing(queueSize int, mem []byte) *UsedRing {
+	ringSize := usedRingSize(queueSize)
+	if len(mem) != ringSize {
+		panic(fmt.Sprintf("memory size (%v) does not match required size "+
+			"for used ring: %v", len(mem), ringSize))
+	}
+
+	r := UsedRing{
+		initialized:    true,
+		flags:          (*usedRingFlag)(unsafe.Pointer(&mem[0])),
+		ringIndex:      (*uint16)(unsafe.Pointer(&mem[2])),
+		ring:           unsafe.Slice((*UsedElement)(unsafe.Pointer(&mem[4])), queueSize),
+		availableEvent: (*uint16)(unsafe.Pointer(&mem[ringSize-2])),
+	}
+	r.lastIndex = *r.ringIndex
+	return &r
+}
+
+// Address returns the pointer to the beginning of the ring in memory.
+// Do not modify the memory directly to not interfere with this implementation.
+func (r *UsedRing) Address() uintptr {
+	if !r.initialized {
+		panic("used ring is not initialized")
+	}
+	return uintptr(unsafe.Pointer(r.flags))
+}
+
+// take returns all new [UsedElement]s that the device put into the ring and
+// that weren't already returned by a previous call to this method.
+// had a lock, I removed it
+func (r *UsedRing) take(maxToTake int) (int, []UsedElement) {
+	//r.mu.Lock()
+	//defer r.mu.Unlock()
+
+	ringIndex := *r.ringIndex
+	if ringIndex == r.lastIndex {
+		// Nothing new.
+		return 0, nil
+	}
+
+	// Calculate the number new used elements that we can read from the ring.
+	// The ring index may wrap, so special handling for that case is needed.
+	count := int(ringIndex - r.lastIndex)
+	if count < 0 {
+		count += 0xffff
+	}
+
+	stillNeedToTake := 0
+
+	if maxToTake > 0 {
+		stillNeedToTake = count - maxToTake
+		if stillNeedToTake < 0 {
+			stillNeedToTake = 0
+		}
+		count = min(count, maxToTake)
+	}
+
+	// The number of new elements can never exceed the queue size.
+	if count > len(r.ring) {
+		panic("used ring contains more new elements than the ring is long")
+	}
+
+	elems := make([]UsedElement, count)
+	for i := range count {
+		elems[i] = r.ring[r.lastIndex%uint16(len(r.ring))]
+		r.lastIndex++
+	}
+
+	return stillNeedToTake, elems
+}
+
+func (r *UsedRing) takeOne() (uint16, bool) {
+	//r.mu.Lock()
+	//defer r.mu.Unlock()
+
+	ringIndex := *r.ringIndex
+	if ringIndex == r.lastIndex {
+		// Nothing new.
+		return 0xffff, false
+	}
+
+	// Calculate the number new used elements that we can read from the ring.
+	// The ring index may wrap, so special handling for that case is needed.
+	count := int(ringIndex - r.lastIndex)
+	if count < 0 {
+		count += 0xffff
+	}
+
+	// The number of new elements can never exceed the queue size.
+	if count > len(r.ring) {
+		panic("used ring contains more new elements than the ring is long")
+	}
+
+	if count == 0 {
+		return 0xffff, false
+	}
+
+	out := r.ring[r.lastIndex%uint16(len(r.ring))].GetHead()
+	r.lastIndex++
+
+	return out, true
+}
+
+// InitOfferSingle is only used to pre-fill the used queue at startup, and should not be used if the device is running!
+func (r *UsedRing) InitOfferSingle(x uint16, size int) {
+	//always called under lock
+	//r.mu.Lock()
+	//defer r.mu.Unlock()
+
+	offset := 0
+	// Add descriptor chain heads to the ring.
+
+	// The 16-bit ring index may overflow. This is expected and is not an
+	// issue because the size of the ring array (which equals the queue
+	// size) is always a power of 2 and smaller than the highest possible
+	// 16-bit value.
+	insertIndex := int(*r.ringIndex+uint16(offset)) % len(r.ring)
+	r.ring[insertIndex] = UsedElement{
+		DescriptorIndex: uint32(x),
+		Length:          uint32(size),
+	}
+
+	// Increase the ring index by the number of descriptor chains added to the ring.
+	*r.ringIndex += 1
+}

+ 136 - 0
overlay/virtqueue/used_ring_internal_test.go

@@ -0,0 +1,136 @@
+package virtqueue
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestUsedRing_MemoryLayout(t *testing.T) {
+	const queueSize = 2
+
+	memory := make([]byte, usedRingSize(queueSize))
+	r := newUsedRing(queueSize, memory)
+
+	*r.flags = 0x01ff
+	*r.ringIndex = 1
+	r.ring[0] = UsedElement{
+		DescriptorIndex: 0x0123,
+		Length:          0x4567,
+	}
+	r.ring[1] = UsedElement{
+		DescriptorIndex: 0x89ab,
+		Length:          0xcdef,
+	}
+
+	assert.Equal(t, []byte{
+		0xff, 0x01,
+		0x01, 0x00,
+		0x23, 0x01, 0x00, 0x00,
+		0x67, 0x45, 0x00, 0x00,
+		0xab, 0x89, 0x00, 0x00,
+		0xef, 0xcd, 0x00, 0x00,
+		0x00, 0x00,
+	}, memory)
+}
+
+//func TestUsedRing_Take(t *testing.T) {
+//	const queueSize = 8
+//
+//	tests := []struct {
+//		name      string
+//		ring      []UsedElement
+//		ringIndex uint16
+//		lastIndex uint16
+//		expected  []UsedElement
+//	}{
+//		{
+//			name: "nothing new",
+//			ring: []UsedElement{
+//				{DescriptorIndex: 1},
+//				{DescriptorIndex: 2},
+//				{DescriptorIndex: 3},
+//				{DescriptorIndex: 4},
+//				{},
+//				{},
+//				{},
+//				{},
+//			},
+//			ringIndex: 4,
+//			lastIndex: 4,
+//			expected:  nil,
+//		},
+//		{
+//			name: "no overflow",
+//			ring: []UsedElement{
+//				{DescriptorIndex: 1},
+//				{DescriptorIndex: 2},
+//				{DescriptorIndex: 3},
+//				{DescriptorIndex: 4},
+//				{},
+//				{},
+//				{},
+//				{},
+//			},
+//			ringIndex: 4,
+//			lastIndex: 1,
+//			expected: []UsedElement{
+//				{DescriptorIndex: 2},
+//				{DescriptorIndex: 3},
+//				{DescriptorIndex: 4},
+//			},
+//		},
+//		{
+//			name: "ring overflow",
+//			ring: []UsedElement{
+//				{DescriptorIndex: 9},
+//				{DescriptorIndex: 10},
+//				{DescriptorIndex: 3},
+//				{DescriptorIndex: 4},
+//				{DescriptorIndex: 5},
+//				{DescriptorIndex: 6},
+//				{DescriptorIndex: 7},
+//				{DescriptorIndex: 8},
+//			},
+//			ringIndex: 10,
+//			lastIndex: 7,
+//			expected: []UsedElement{
+//				{DescriptorIndex: 8},
+//				{DescriptorIndex: 9},
+//				{DescriptorIndex: 10},
+//			},
+//		},
+//		{
+//			name: "index overflow",
+//			ring: []UsedElement{
+//				{DescriptorIndex: 9},
+//				{DescriptorIndex: 10},
+//				{DescriptorIndex: 3},
+//				{DescriptorIndex: 4},
+//				{DescriptorIndex: 5},
+//				{DescriptorIndex: 6},
+//				{DescriptorIndex: 7},
+//				{DescriptorIndex: 8},
+//			},
+//			ringIndex: 2,
+//			lastIndex: 65535,
+//			expected: []UsedElement{
+//				{DescriptorIndex: 8},
+//				{DescriptorIndex: 9},
+//				{DescriptorIndex: 10},
+//			},
+//		},
+//	}
+//	for _, tt := range tests {
+//		t.Run(tt.name, func(t *testing.T) {
+//			memory := make([]byte, usedRingSize(queueSize))
+//			r := newUsedRing(queueSize, memory)
+//
+//			copy(r.ring, tt.ring)
+//			*r.ringIndex = tt.ringIndex
+//			r.lastIndex = tt.lastIndex
+//
+//			assert.Equal(t, tt.expected, r.take())
+//		})
+//	}
+//}

+ 70 - 0
packet/outpacket.go

@@ -0,0 +1,70 @@
+package packet
+
+import (
+	"github.com/slackhq/nebula/util/virtio"
+	"golang.org/x/sys/unix"
+)
+
+type OutPacket struct {
+	Segments        [][]byte
+	SegmentPayloads [][]byte
+	SegmentHeaders  [][]byte
+	SegmentIDs      []uint16
+	//todo virtio header?
+	SegSize      int
+	SegCounter   int
+	Valid        bool
+	wasSegmented bool
+
+	Scratch []byte
+}
+
+func NewOut() *OutPacket {
+	out := new(OutPacket)
+	out.Segments = make([][]byte, 0, 64)
+	out.SegmentHeaders = make([][]byte, 0, 64)
+	out.SegmentPayloads = make([][]byte, 0, 64)
+	out.SegmentIDs = make([]uint16, 0, 64)
+	out.Scratch = make([]byte, Size)
+	return out
+}
+
+func (pkt *OutPacket) Reset() {
+	pkt.Segments = pkt.Segments[:0]
+	pkt.SegmentPayloads = pkt.SegmentPayloads[:0]
+	pkt.SegmentHeaders = pkt.SegmentHeaders[:0]
+	pkt.SegmentIDs = pkt.SegmentIDs[:0]
+	pkt.SegSize = 0
+	pkt.Valid = false
+	pkt.wasSegmented = false
+}
+
+func (pkt *OutPacket) UseSegment(segID uint16, seg []byte, isV6 bool) int {
+	pkt.Valid = true
+	pkt.SegmentIDs = append(pkt.SegmentIDs, segID)
+	pkt.Segments = append(pkt.Segments, seg) //todo do we need this?
+
+	vhdr := virtio.NetHdr{ //todo
+		Flags:      unix.VIRTIO_NET_HDR_F_DATA_VALID,
+		GSOType:    unix.VIRTIO_NET_HDR_GSO_NONE,
+		HdrLen:     0,
+		GSOSize:    0,
+		CsumStart:  0,
+		CsumOffset: 0,
+		NumBuffers: 0,
+	}
+
+	hdr := seg[0 : virtio.NetHdrSize+14]
+	_ = vhdr.Encode(hdr)
+	if isV6 {
+		hdr[virtio.NetHdrSize+14-2] = 0x86
+		hdr[virtio.NetHdrSize+14-1] = 0xdd
+	} else {
+		hdr[virtio.NetHdrSize+14-2] = 0x08
+		hdr[virtio.NetHdrSize+14-1] = 0x00
+	}
+
+	pkt.SegmentHeaders = append(pkt.SegmentHeaders, hdr)
+	pkt.SegmentPayloads = append(pkt.SegmentPayloads, seg[virtio.NetHdrSize+14:])
+	return len(pkt.SegmentIDs) - 1
+}

+ 119 - 0
packet/packet.go

@@ -0,0 +1,119 @@
+package packet
+
+import (
+	"encoding/binary"
+	"iter"
+	"net/netip"
+	"slices"
+	"syscall"
+	"unsafe"
+
+	"golang.org/x/sys/unix"
+)
+
+const Size = 0xffff
+
+type Packet struct {
+	Payload []byte
+	Control []byte
+	Name    []byte
+	SegSize int
+
+	//todo should this hold out as well?
+	OutLen int
+
+	wasSegmented bool
+	isV4         bool
+}
+
+func New(isV4 bool) *Packet {
+	return &Packet{
+		Payload: make([]byte, Size),
+		Control: make([]byte, unix.CmsgSpace(2)),
+		Name:    make([]byte, unix.SizeofSockaddrInet6),
+		isV4:    isV4,
+	}
+}
+
+func (p *Packet) AddrPort() netip.AddrPort {
+	var ip netip.Addr
+	// Its ok to skip the ok check here, the slicing is the only error that can occur and it will panic
+	if p.isV4 {
+		ip, _ = netip.AddrFromSlice(p.Name[4:8])
+	} else {
+		ip, _ = netip.AddrFromSlice(p.Name[8:24])
+	}
+	return netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(p.Name[2:4]))
+}
+
+func (p *Packet) updateCtrl(ctrlLen int) {
+	p.SegSize = len(p.Payload)
+	p.wasSegmented = false
+	if ctrlLen == 0 {
+		return
+	}
+	if len(p.Control) == 0 {
+		return
+	}
+	cmsgs, err := unix.ParseSocketControlMessage(p.Control)
+	if err != nil {
+		return // oh well
+	}
+
+	for _, c := range cmsgs {
+		if c.Header.Level == unix.SOL_UDP && c.Header.Type == unix.UDP_GRO && len(c.Data) >= 2 {
+			p.wasSegmented = true
+			p.SegSize = int(binary.LittleEndian.Uint16(c.Data[:2]))
+			return
+		}
+	}
+}
+
+// Update sets a Packet into "just received, not processed" state
+func (p *Packet) Update(ctrlLen int) {
+	p.OutLen = -1
+	p.updateCtrl(ctrlLen)
+}
+
+func (p *Packet) SetSegSizeForTX() {
+	p.SegSize = len(p.Payload)
+	hdr := (*unix.Cmsghdr)(unsafe.Pointer(&p.Control[0]))
+	hdr.Level = unix.SOL_UDP
+	hdr.Type = unix.UDP_SEGMENT
+	hdr.SetLen(syscall.CmsgLen(2))
+	binary.NativeEndian.PutUint16(p.Control[unix.CmsgLen(0):unix.CmsgLen(0)+2], uint16(p.SegSize))
+}
+
+func (p *Packet) CompatibleForSegmentationWith(otherP *Packet, currentTotalSize int) bool {
+	//same dest
+	if !slices.Equal(p.Name, otherP.Name) {
+		return false
+	}
+
+	//don't get too big
+	if len(p.Payload)+currentTotalSize >= 0xffff {
+		return false
+	}
+
+	//same body len
+	//todo allow single different size at end
+	if len(p.Payload) != len(otherP.Payload) {
+		return false //todo technically you can cram one extra in
+	}
+	return true
+}
+
+func (p *Packet) Segments() iter.Seq[[]byte] {
+	return func(yield func([]byte) bool) {
+		//cursor := 0
+		for offset := 0; offset < len(p.Payload); offset += p.SegSize {
+			end := offset + p.SegSize
+			if end > len(p.Payload) {
+				end = len(p.Payload)
+			}
+			if !yield(p.Payload[offset:end]) {
+				return
+			}
+		}
+	}
+}

+ 37 - 0
packet/virtio.go

@@ -0,0 +1,37 @@
+package packet
+
+import (
+	"github.com/slackhq/nebula/util/virtio"
+)
+
+type VirtIOPacket struct {
+	Payload   []byte
+	Header    virtio.NetHdr
+	Chains    []uint16
+	ChainRefs [][]byte
+	// OfferDescriptorChains(chains []uint16, kick bool) error
+}
+
+func NewVIO() *VirtIOPacket {
+	out := new(VirtIOPacket)
+	out.Payload = nil
+	out.ChainRefs = make([][]byte, 0, 4)
+	out.Chains = make([]uint16, 0, 8)
+	return out
+}
+
+func (v *VirtIOPacket) Reset() {
+	v.Payload = nil
+	v.ChainRefs = v.ChainRefs[:0]
+	v.Chains = v.Chains[:0]
+}
+
+type VirtIOTXPacket struct {
+	VirtIOPacket
+}
+
+func NewVIOTX(isV4 bool) *VirtIOTXPacket {
+	out := new(VirtIOTXPacket)
+	out.VirtIOPacket = *NewVIO()
+	return out
+}

+ 4 - 2
udp/conn.go

@@ -4,13 +4,13 @@ import (
 	"net/netip"
 
 	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/packet"
 )
 
 const MTU = 9001
 
 type EncReader func(
-	addr netip.AddrPort,
-	payload []byte,
+	[]*packet.Packet,
 )
 
 type Conn interface {
@@ -19,6 +19,8 @@ type Conn interface {
 	ListenOut(r EncReader)
 	WriteTo(b []byte, addr netip.AddrPort) error
 	ReloadConfig(c *config.C)
+	Prep(pkt *packet.Packet, addr netip.AddrPort) error
+	WriteBatch(pkt []*packet.Packet) (int, error)
 	SupportsMultipleReaders() bool
 	Close() error
 }

+ 195 - 23
udp/udp_linux.go

@@ -14,22 +14,22 @@ import (
 	"github.com/rcrowley/go-metrics"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/packet"
 	"golang.org/x/sys/unix"
 )
 
-type StdConn struct {
-	sysFd int
-	isV4  bool
-	l     *logrus.Logger
-	batch int
-}
+const iovMax = 128 //1024 //no unix constant for this? from limits.h
+//todo I'd like this to be 1024 but we seem to hit errors around ~130?
 
-func maybeIPV4(ip net.IP) (net.IP, bool) {
-	ip4 := ip.To4()
-	if ip4 != nil {
-		return ip4, true
-	}
-	return ip, false
+type StdConn struct {
+	sysFd     int
+	isV4      bool
+	l         *logrus.Logger
+	batch     int
+	enableGRO bool
+
+	msgs []rawMessage
+	iovs [][]iovec
 }
 
 func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
@@ -69,7 +69,20 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch in
 		return nil, fmt.Errorf("unable to bind to socket: %s", err)
 	}
 
-	return &StdConn{sysFd: fd, isV4: ip.Is4(), l: l, batch: batch}, err
+	const batchSize = 8192
+	msgs := make([]rawMessage, 0, batchSize) //todo configure
+	iovs := make([][]iovec, batchSize)
+	for i := range iovs {
+		iovs[i] = make([]iovec, iovMax)
+	}
+	return &StdConn{
+		sysFd: fd,
+		isV4:  ip.Is4(),
+		l:     l,
+		batch: batch,
+		msgs:  msgs,
+		iovs:  iovs,
+	}, err
 }
 
 func (u *StdConn) SupportsMultipleReaders() bool {
@@ -123,9 +136,7 @@ func (u *StdConn) LocalAddr() (netip.AddrPort, error) {
 }
 
 func (u *StdConn) ListenOut(r EncReader) {
-	var ip netip.Addr
-
-	msgs, buffers, names := u.PrepareRawMessages(u.batch)
+	msgs, packets := u.PrepareRawMessages(u.batch, u.isV4)
 	read := u.ReadMulti
 	if u.batch == 1 {
 		read = u.ReadSingle
@@ -139,13 +150,12 @@ func (u *StdConn) ListenOut(r EncReader) {
 		}
 
 		for i := 0; i < n; i++ {
-			// Its ok to skip the ok check here, the slicing is the only error that can occur and it will panic
-			if u.isV4 {
-				ip, _ = netip.AddrFromSlice(names[i][4:8])
-			} else {
-				ip, _ = netip.AddrFromSlice(names[i][8:24])
-			}
-			r(netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])), buffers[i][:msgs[i].Len])
+			packets[i].Payload = packets[i].Payload[:msgs[i].Len]
+			packets[i].Update(getRawMessageControlLen(&msgs[i]))
+		}
+		r(packets[:n])
+		for i := 0; i < n; i++ { //todo reset this in prev loop, but this makes debug ez
+			msgs[i].Hdr.Controllen = uint64(unix.CmsgSpace(2))
 		}
 	}
 }
@@ -198,6 +208,147 @@ func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error {
 	return u.writeTo6(b, ip)
 }
 
+func (u *StdConn) WriteToBatch(b []byte, ip netip.AddrPort) error {
+	if u.isV4 {
+		return u.writeTo4(b, ip)
+	}
+	return u.writeTo6(b, ip)
+}
+
+func (u *StdConn) Prep(pkt *packet.Packet, addr netip.AddrPort) error {
+	nl, err := u.encodeSockaddr(pkt.Name, addr)
+	if err != nil {
+		return err
+	}
+	pkt.Name = pkt.Name[:nl]
+	pkt.OutLen = len(pkt.Payload)
+	return nil
+}
+
+func (u *StdConn) WriteBatch(pkts []*packet.Packet) (int, error) {
+	if len(pkts) == 0 {
+		return 0, nil
+	}
+
+	u.msgs = u.msgs[:0]
+	//u.iovs = u.iovs[:0]
+
+	sent := 0
+	var mostRecentPkt *packet.Packet
+	mostRecentPktSize := 0
+	//segmenting := false
+	idx := 0
+	for _, pkt := range pkts {
+		if len(pkt.Payload) == 0 || pkt.OutLen == -1 {
+			sent++
+			continue
+		}
+		lastIdx := idx - 1
+		if mostRecentPkt != nil && pkt.CompatibleForSegmentationWith(mostRecentPkt, mostRecentPktSize) && u.msgs[lastIdx].Hdr.Iovlen < iovMax {
+			u.msgs[lastIdx].Hdr.Controllen = uint64(len(mostRecentPkt.Control))
+			u.msgs[lastIdx].Hdr.Control = &mostRecentPkt.Control[0]
+
+			u.iovs[lastIdx][u.msgs[lastIdx].Hdr.Iovlen].Base = &pkt.Payload[0]
+			u.iovs[lastIdx][u.msgs[lastIdx].Hdr.Iovlen].Len = uint64(len(pkt.Payload))
+			u.msgs[lastIdx].Hdr.Iovlen++
+
+			mostRecentPktSize += len(pkt.Payload)
+			mostRecentPkt.SetSegSizeForTX()
+		} else {
+			u.msgs = append(u.msgs, rawMessage{})
+			u.iovs[idx][0] = iovec{
+				Base: &pkt.Payload[0],
+				Len:  uint64(len(pkt.Payload)),
+			}
+
+			msg := &u.msgs[idx]
+			iov := &u.iovs[idx][0]
+			idx++
+
+			msg.Hdr.Iov = iov
+			msg.Hdr.Iovlen = 1
+			setRawMessageControl(msg, nil)
+			msg.Hdr.Flags = 0
+
+			msg.Hdr.Name = &pkt.Name[0]
+			msg.Hdr.Namelen = uint32(len(pkt.Name))
+			mostRecentPkt = pkt
+			mostRecentPktSize = len(pkt.Payload)
+		}
+	}
+
+	if len(u.msgs) == 0 {
+		return sent, nil
+	}
+
+	offset := 0
+	for offset < len(u.msgs) {
+		n, _, errno := unix.Syscall6(
+			unix.SYS_SENDMMSG,
+			uintptr(u.sysFd),
+			uintptr(unsafe.Pointer(&u.msgs[offset])),
+			uintptr(len(u.msgs)-offset),
+			0,
+			0,
+			0,
+		)
+
+		if errno != 0 {
+			if errno == unix.EINTR {
+				continue
+			}
+			//for i := 0; i < len(u.msgs); i++ {
+			//	for j := 0; j < int(u.msgs[i].Hdr.Iovlen); j++ {
+			//		u.l.WithFields(logrus.Fields{
+			//			"msg_index": i,
+			//			"iov idx":   j,
+			//			"iov":       fmt.Sprintf("%+v", u.iovs[i][j]),
+			//		}).Warn("failed to send message")
+			//	}
+			//
+			//}
+			u.l.WithFields(logrus.Fields{
+				"errno":   errno,
+				"idx":     idx,
+				"len":     len(u.msgs),
+				"deets":   fmt.Sprintf("%+v", u.msgs),
+				"lastIOV": fmt.Sprintf("%+v", u.iovs[len(u.msgs)-1][u.msgs[len(u.msgs)-1].Hdr.Iovlen-1]),
+			}).Error("failed to send message")
+			return sent + offset, &net.OpError{Op: "sendmmsg", Err: errno}
+		}
+
+		if n == 0 {
+			break
+		}
+		offset += int(n)
+	}
+
+	return sent + len(u.msgs), nil
+}
+
+func (u *StdConn) encodeSockaddr(dst []byte, addr netip.AddrPort) (uint32, error) {
+	if u.isV4 {
+		if !addr.Addr().Is4() {
+			return 0, fmt.Errorf("Listener is IPv4, but writing to IPv6 remote")
+		}
+		var sa unix.RawSockaddrInet4
+		sa.Family = unix.AF_INET
+		sa.Addr = addr.Addr().As4()
+		binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], addr.Port())
+		size := unix.SizeofSockaddrInet4
+		copy(dst[:size], (*(*[unix.SizeofSockaddrInet4]byte)(unsafe.Pointer(&sa)))[:])
+		return uint32(size), nil
+	}
+
+	var sa unix.RawSockaddrInet6
+	sa.Family = unix.AF_INET6
+	sa.Addr = addr.Addr().As16()
+	binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], addr.Port())
+	size := unix.SizeofSockaddrInet6
+	copy(dst[:size], (*(*[unix.SizeofSockaddrInet6]byte)(unsafe.Pointer(&sa)))[:])
+	return uint32(size), nil
+}
+
 func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error {
 	var rsa unix.RawSockaddrInet6
 	rsa.Family = unix.AF_INET6
@@ -298,6 +449,27 @@ func (u *StdConn) ReloadConfig(c *config.C) {
 			u.l.WithError(err).Error("Failed to set listen.so_mark")
 		}
 	}
+	u.configureGRO(true)
+}
+
+func (u *StdConn) configureGRO(enable bool) {
+	if enable == u.enableGRO {
+		return
+	}
+
+	if enable {
+		if err := unix.SetsockoptInt(u.sysFd, unix.SOL_UDP, unix.UDP_GRO, 1); err != nil {
+			u.l.WithError(err).Warn("Failed to enable UDP GRO")
+			return
+		}
+		u.enableGRO = true
+		u.l.Info("UDP GRO enabled")
+	} else {
+		if err := unix.SetsockoptInt(u.sysFd, unix.SOL_UDP, unix.UDP_GRO, 0); err != nil && err != unix.ENOPROTOOPT {
+			u.l.WithError(err).Warn("Failed to disable UDP GRO")
+		}
+		u.enableGRO = false
+	}
 }
 
 func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error {

+ 44 - 9
udp/udp_linux_64.go

@@ -7,6 +7,7 @@
 package udp
 
 import (
+	"github.com/slackhq/nebula/packet"
 	"golang.org/x/sys/unix"
 )
 
@@ -33,25 +34,59 @@ type rawMessage struct {
 	Pad0 [4]byte
 }
 
-func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
+func setRawMessageControl(msg *rawMessage, buf []byte) {
+	if len(buf) == 0 {
+		msg.Hdr.Control = nil
+		msg.Hdr.Controllen = 0
+		return
+	}
+	msg.Hdr.Control = &buf[0]
+	msg.Hdr.Controllen = uint64(len(buf))
+}
+
+func getRawMessageControlLen(msg *rawMessage) int {
+	return int(msg.Hdr.Controllen)
+}
+
+func setCmsgLen(h *unix.Cmsghdr, l int) {
+	h.Len = uint64(l)
+}
+
+func (u *StdConn) PrepareRawMessages(n int, isV4 bool) ([]rawMessage, []*packet.Packet) {
 	msgs := make([]rawMessage, n)
-	buffers := make([][]byte, n)
-	names := make([][]byte, n)
+	packets := make([]*packet.Packet, n)
 
 	for i := range msgs {
-		buffers[i] = make([]byte, MTU)
-		names[i] = make([]byte, unix.SizeofSockaddrInet6)
+		packets[i] = packet.New(isV4)
 
 		vs := []iovec{
-			{Base: &buffers[i][0], Len: uint64(len(buffers[i]))},
+			{Base: &packets[i].Payload[0], Len: uint64(packet.Size)},
 		}
 
 		msgs[i].Hdr.Iov = &vs[0]
 		msgs[i].Hdr.Iovlen = uint64(len(vs))
 
-		msgs[i].Hdr.Name = &names[i][0]
-		msgs[i].Hdr.Namelen = uint32(len(names[i]))
+		msgs[i].Hdr.Name = &packets[i].Name[0]
+		msgs[i].Hdr.Namelen = uint32(len(packets[i].Name))
+
+		if u.enableGRO {
+			msgs[i].Hdr.Control = &packets[i].Control[0]
+			msgs[i].Hdr.Controllen = uint64(len(packets[i].Control))
+		} else {
+			msgs[i].Hdr.Control = nil
+			msgs[i].Hdr.Controllen = 0
+		}
 	}
 
-	return msgs, buffers, names
+	return msgs, packets
+}
+
+func setIovecSlice(iov *iovec, b []byte) {
+	if len(b) == 0 {
+		iov.Base = nil
+		iov.Len = 0
+		return
+	}
+	iov.Base = &b[0]
+	iov.Len = uint64(len(b))
 }

+ 3 - 0
util/virtio/doc.go

@@ -0,0 +1,3 @@
+// Package virtio contains some generic types and concepts related to the virtio
+// protocol.
+package virtio

+ 136 - 0
util/virtio/features.go

@@ -0,0 +1,136 @@
+package virtio
+
+// Feature contains feature bits that describe a virtio device or driver.
+type Feature uint64
+
+// Device-independent feature bits.
+//
+// Source: https://docs.oasis-open.org/virtio/virtio/v1.2/csd01/virtio-v1.2-csd01.html#x1-6600006
+const (
+	// FeatureIndirectDescriptors indicates that the driver can use descriptors
+	// with an additional layer of indirection.
+	FeatureIndirectDescriptors Feature = 1 << 28
+
+	// FeatureVersion1 indicates compliance with version 1.0 of the virtio
+	// specification.
+	FeatureVersion1 Feature = 1 << 32
+)
+
+// Feature bits for networking devices.
+//
+// Source: https://docs.oasis-open.org/virtio/virtio/v1.2/csd01/virtio-v1.2-csd01.html#x1-2200003
+const (
+	// FeatureNetDeviceCsum indicates that the device can handle packets with
+	// partial checksum (checksum offload).
+	FeatureNetDeviceCsum Feature = 1 << 0
+
+	// FeatureNetDriverCsum indicates that the driver can handle packets with
+	// partial checksum.
+	FeatureNetDriverCsum Feature = 1 << 1
+
+	// FeatureNetCtrlDriverOffloads indicates support for dynamic offload state
+	// reconfiguration.
+	FeatureNetCtrlDriverOffloads Feature = 1 << 2
+
+	// FeatureNetMTU indicates that the device reports a maximum MTU value.
+	FeatureNetMTU Feature = 1 << 3
+
+	// FeatureNetMAC indicates that the device provides a MAC address.
+	FeatureNetMAC Feature = 1 << 5
+
+	// FeatureNetDriverTSO4 indicates that the driver supports the TCP
+	// segmentation offload for received IPv4 packets.
+	FeatureNetDriverTSO4 Feature = 1 << 7
+
+	// FeatureNetDriverTSO6 indicates that the driver supports the TCP
+	// segmentation offload for received IPv6 packets.
+	FeatureNetDriverTSO6 Feature = 1 << 8
+
+	// FeatureNetDriverECN indicates that the driver supports the TCP
+	// segmentation offload with ECN for received packets.
+	FeatureNetDriverECN Feature = 1 << 9
+
+	// FeatureNetDriverUFO indicates that the driver supports the UDP
+	// fragmentation offload for received packets.
+	FeatureNetDriverUFO Feature = 1 << 10
+
+	// FeatureNetDeviceTSO4 indicates that the device supports the TCP
+	// segmentation offload for received IPv4 packets.
+	FeatureNetDeviceTSO4 Feature = 1 << 11
+
+	// FeatureNetDeviceTSO6 indicates that the device supports the TCP
+	// segmentation offload for received IPv6 packets.
+	FeatureNetDeviceTSO6 Feature = 1 << 12
+
+	// FeatureNetDeviceECN indicates that the device supports the TCP
+	// segmentation offload with ECN for received packets.
+	FeatureNetDeviceECN Feature = 1 << 13
+
+	// FeatureNetDeviceUFO indicates that the device supports the UDP
+	// fragmentation offload for received packets.
+	FeatureNetDeviceUFO Feature = 1 << 14
+
+	// FeatureNetMergeRXBuffers indicates that the driver can handle merged
+	// receive buffers.
+	// When this feature is negotiated, devices may merge multiple descriptor
+	// chains together to transport large received packets. [NetHdr.NumBuffers]
+	// will then contain the number of merged descriptor chains.
+	FeatureNetMergeRXBuffers Feature = 1 << 15
+
+	// FeatureNetStatus indicates that the device configuration status field is
+	// available.
+	FeatureNetStatus Feature = 1 << 16
+
+	// FeatureNetCtrlVQ indicates that a control channel virtqueue is
+	// available.
+	FeatureNetCtrlVQ Feature = 1 << 17
+
+	// FeatureNetCtrlRX indicates support for RX mode control (e.g. promiscuous
+	// or all-multicast) for packet receive filtering.
+	FeatureNetCtrlRX Feature = 1 << 18
+
+	// FeatureNetCtrlVLAN indicates support for VLAN filtering through the
+	// control channel.
+	FeatureNetCtrlVLAN Feature = 1 << 19
+
+	// FeatureNetDriverAnnounce indicates that the driver can send gratuitous
+	// packets.
+	FeatureNetDriverAnnounce Feature = 1 << 21
+
+	// FeatureNetMQ indicates that the device supports multiqueue with automatic
+	// receive steering.
+	FeatureNetMQ Feature = 1 << 22
+
+	// FeatureNetCtrlMACAddr indicates that the MAC address can be set through
+	// the control channel.
+	FeatureNetCtrlMACAddr Feature = 1 << 23
+
+	// FeatureNetDeviceUSO indicates that the device supports the UDP
+	// segmentation offload for received packets.
+	FeatureNetDeviceUSO Feature = 1 << 56
+
+	// FeatureNetHashReport indicates that the device can report a per-packet
+	// hash value and type.
+	FeatureNetHashReport Feature = 1 << 57
+
+	// FeatureNetDriverHdrLen indicates that the driver can provide the exact
+	// header length value (see [NetHdr.HdrLen]).
+	// Devices may benefit from knowing the exact header length.
+	FeatureNetDriverHdrLen Feature = 1 << 59
+
+	// FeatureNetRSS indicates that the device supports RSS (receive-side
+	// scaling) with configurable hash parameters.
+	FeatureNetRSS Feature = 1 << 60
+
+	// FeatureNetRSCExt indicates that the device can process duplicated ACKs
+	// and report the number of coalesced segments and duplicated ACKs.
+	FeatureNetRSCExt Feature = 1 << 61
+
+	// FeatureNetStandby indicates that the device may act as a standby for a
+	// primary device with the same MAC address.
+	FeatureNetStandby Feature = 1 << 62
+
+	// FeatureNetSpeedDuplex indicates that the device can report link speed and
+	// duplex mode.
+	FeatureNetSpeedDuplex Feature = 1 << 63
+)

+ 77 - 0
util/virtio/net_hdr.go

@@ -0,0 +1,77 @@
+package virtio
+
+import (
+	"errors"
+	"unsafe"
+
+	"golang.org/x/sys/unix"
+)
+
+// Workaround to make Go doc links work.
+var _ unix.Errno
+
+// NetHdrSize is the number of bytes needed to store a [NetHdr] in memory.
+const NetHdrSize = 12
+
+// ErrNetHdrBufferTooSmall is returned when a buffer is too small to fit a
+// virtio_net_hdr.
+var ErrNetHdrBufferTooSmall = errors.New("the buffer is too small to fit a virtio_net_hdr")
+
+// NetHdr defines the virtio_net_hdr as described by the virtio specification.
+type NetHdr struct {
+	// Flags that describe the packet.
+	// Possible values are:
+	//   - [unix.VIRTIO_NET_HDR_F_NEEDS_CSUM]
+	//   - [unix.VIRTIO_NET_HDR_F_DATA_VALID]
+	//   - [unix.VIRTIO_NET_HDR_F_RSC_INFO]
+	Flags uint8
+	// GSOType contains the type of segmentation offload that should be used for
+	// the packet.
+	// Possible values are:
+	//   - [unix.VIRTIO_NET_HDR_GSO_NONE]
+	//   - [unix.VIRTIO_NET_HDR_GSO_TCPV4]
+	//   - [unix.VIRTIO_NET_HDR_GSO_UDP]
+	//   - [unix.VIRTIO_NET_HDR_GSO_TCPV6]
+	//   - [unix.VIRTIO_NET_HDR_GSO_UDP_L4]
+	//   - [unix.VIRTIO_NET_HDR_GSO_ECN]
+	GSOType uint8
+	// HdrLen contains the length of the headers that need to be replicated by
+	// segmentation offloads. It's the number of bytes from the beginning of the
+	// packet to the beginning of the transport payload.
+	// Only used when [FeatureNetDriverHdrLen] is negotiated.
+	HdrLen uint16
+	// GSOSize contains the maximum size of each segmented packet beyond the
+	// header (payload size). In case of TCP, this is the MSS.
+	GSOSize uint16
+	// CsumStart contains the offset within the packet from which on the
+	// checksum should be computed.
+	CsumStart uint16
+	// CsumOffset specifies how many bytes after [NetHdr.CsumStart] the computed
+	// 16-bit checksum should be inserted.
+	CsumOffset uint16
+	// NumBuffers contains the number of merged descriptor chains when
+	// [FeatureNetMergeRXBuffers] is negotiated.
+	// This field is only used for packets received by the driver and should be
+	// zero for transmitted packets.
+	NumBuffers uint16
+}
+
+// Decode decodes the [NetHdr] from the given byte slice. The slice must contain
+// at least [NetHdrSize] bytes.
+func (v *NetHdr) Decode(data []byte) error {
+	if len(data) < NetHdrSize {
+		return ErrNetHdrBufferTooSmall
+	}
+	copy(unsafe.Slice((*byte)(unsafe.Pointer(v)), NetHdrSize), data[:NetHdrSize])
+	return nil
+}
+
+// Encode encodes the [NetHdr] into the given byte slice. The slice must have
+// room for at least [NetHdrSize] bytes.
+func (v *NetHdr) Encode(data []byte) error {
+	if len(data) < NetHdrSize {
+		return ErrNetHdrBufferTooSmall
+	}
+	copy(data[:NetHdrSize], unsafe.Slice((*byte)(unsafe.Pointer(v)), NetHdrSize))
+	return nil
+}

+ 43 - 0
util/virtio/net_hdr_test.go

@@ -0,0 +1,43 @@
+package virtio
+
+import (
+	"testing"
+	"unsafe"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+	"golang.org/x/sys/unix"
+)
+
+func TestNetHdr_Size(t *testing.T) {
+	assert.EqualValues(t, NetHdrSize, unsafe.Sizeof(NetHdr{}))
+}
+
+func TestNetHdr_Encoding(t *testing.T) {
+	vnethdr := NetHdr{
+		Flags:      unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
+		GSOType:    unix.VIRTIO_NET_HDR_GSO_UDP_L4,
+		HdrLen:     42,
+		GSOSize:    1472,
+		CsumStart:  34,
+		CsumOffset: 6,
+		NumBuffers: 16,
+	}
+
+	buf := make([]byte, NetHdrSize)
+	require.NoError(t, vnethdr.Encode(buf))
+
+	assert.Equal(t, []byte{
+		0x01, 0x05,
+		0x2a, 0x00,
+		0xc0, 0x05,
+		0x22, 0x00,
+		0x06, 0x00,
+		0x10, 0x00,
+	}, buf)
+
+	var decoded NetHdr
+	require.NoError(t, decoded.Decode(buf))
+
+	assert.Equal(t, vnethdr, decoded)
+}