Browse source

refactoring a bit

JackDoan, 5 days ago
Parent commit 41c9a3b2eb

+ 2 - 2
inside.go

@@ -13,7 +13,7 @@ import (
 	"github.com/slackhq/nebula/routing"
 )
 
-func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb []byte, out *packet.Packet, q int, localCache firewall.ConntrackCache, now time.Time) {
+func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb []byte, out *packet.UDPPacket, q int, localCache firewall.ConntrackCache, now time.Time) {
 	err := newPacket(packet, false, fwPacket)
 	if err != nil {
 		if f.l.Level >= logrus.DebugLevel {
@@ -412,7 +412,7 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
 	}
 }
 
-func (f *Interface) sendNoMetricsDelayed(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb []byte, out *packet.Packet, q int) {
+func (f *Interface) sendNoMetricsDelayed(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb []byte, out *packet.UDPPacket, q int) {
 	if ci.eKey == nil {
 		return
 	}

+ 6 - 7
interface.go

@@ -294,7 +294,7 @@ func (f *Interface) listenOut(q int) {
 
 	toSend := make([][]byte, batch)
 
-	li.ListenOut(func(pkts []*packet.Packet) {
+	li.ListenOut(func(pkts []*packet.UDPPacket) {
 		toSend = toSend[:0]
 		for i := range outPackets {
 			outPackets[i].SegCounter = 0
@@ -323,11 +323,11 @@ func (f *Interface) listenIn(reader overlay.TunDev, queueNum int) {
 
 	conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
 
-	packets := make([]*packet.VirtIOPacket, batch)
-	outPackets := make([]*packet.Packet, batch)
+	packets := reader.NewPacketArrays(batch)
+
+	outPackets := make([]*packet.UDPPacket, batch)
 	for i := 0; i < batch; i++ {
-		packets[i] = packet.NewVIO()
-		outPackets[i] = packet.New(false) //todo?
+		outPackets[i] = packet.New(false) //todo isv4?
 	}
 
 	for {
@@ -352,9 +352,8 @@ func (f *Interface) listenIn(reader overlay.TunDev, queueNum int) {
 		now := time.Now()
 		for i, pkt := range packets[:n] {
 			outPackets[i].ReadyToSend = false
-			f.consumeInsidePacket(pkt.Payload, fwPacket, nb, outPackets[i], queueNum, conntrackCache.Get(f.l), now)
+			f.consumeInsidePacket(pkt.GetPayload(), fwPacket, nb, outPackets[i], queueNum, conntrackCache.Get(f.l), now)
 			reader.RecycleRxSeg(pkt, i == (n-1), queueNum) //todo handle err?
-			pkt.Reset()
 		}
 		_, err = f.writers[queueNum].WriteBatch(outPackets[:n])
 		if err != nil {
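
The reworked read loop above now asks the device for its own packet representations via `NewPacketArrays`, and recycles each rx segment back to the device (which is what replaced the old `pkt.Reset()` call). A minimal sketch of that shape, assuming the `TunDev`/`TunPacket` interfaces introduced in this commit; `consume` is a hypothetical stand-in for `consumeInsidePacket`:

```go
package example

import "github.com/slackhq/nebula/overlay"

// readLoop is an illustrative per-queue read loop, not the real listenIn.
func readLoop(reader overlay.TunDev, q, batch int, consume func(payload []byte)) error {
	packets := reader.NewPacketArrays(batch) // device-specific TunPacket implementations
	for {
		n, err := reader.ReadMany(packets, q)
		if err != nil {
			return err
		}
		for i, pkt := range packets[:n] {
			consume(pkt.GetPayload())
			// Return the rx descriptors to the device; kick only on the last
			// packet so the device is notified once per batch.
			_ = reader.RecycleRxSeg(pkt, i == n-1, q)
		}
	}
}
```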

+ 1 - 1
outside.go

@@ -359,7 +359,7 @@ func (f *Interface) readOutsideSegment(via ViaSender, segment []byte, out *packe
 	f.connectionManager.In(hostinfo)
 }
 
-func (f *Interface) readOutsidePacketsMany(packets []*packet.Packet, out []*packet.OutPacket, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache, now time.Time) {
+func (f *Interface) readOutsidePacketsMany(packets []*packet.UDPPacket, out []*packet.OutPacket, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache, now time.Time) {
 	for i, pkt := range packets {
 		out[i].Scratch = out[i].Scratch[:0]
 		via := ViaSender{UdpAddr: pkt.AddrPort()}

+ 36 - 0
overlay/packets.go

@@ -0,0 +1,36 @@
+package overlay
+
+//import (
+//	"github.com/slackhq/nebula/util/virtio"
+//)
+
+//type VirtIOPacket struct {
+//	Payload   []byte
+//	Header    virtio.NetHdr
+//	Chains    []uint16
+//	ChainRefs [][]byte
+//}
+//
+//func NewVIO() *VirtIOPacket {
+//	out := new(VirtIOPacket)
+//	out.Payload = nil
+//	out.ChainRefs = make([][]byte, 0, 4)
+//	out.Chains = make([]uint16, 0, 8)
+//	return out
+//}
+//
+//func (v *VirtIOPacket) Reset() {
+//	v.Payload = nil
+//	v.ChainRefs = v.ChainRefs[:0]
+//	v.Chains = v.Chains[:0]
+//}
+
+// TunPacket is formerly VirtIOPacket
+type TunPacket interface {
+	SetPayload([]byte)
+	GetPayload() []byte
+}
+type OutPacket interface {
+	SetPayload([]byte)
+	GetPayload() []byte
+}
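
These interfaces only require payload accessors, so each backend can expose its native buffer type (the vhost backend uses `vhostnet.VirtIOPacket`, further down). Purely as an illustration of the contract (not part of this commit), a slice-backed implementation could look like:

```go
package overlay

// slicePacket is a hypothetical TunPacket/OutPacket implementation backed by a
// plain byte slice, for backends without descriptor bookkeeping to carry around.
type slicePacket struct {
	buf []byte
}

func (p *slicePacket) SetPayload(b []byte) { p.buf = b }
func (p *slicePacket) GetPayload() []byte  { return p.buf }
```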

+ 6 - 4
overlay/tun.go

@@ -16,13 +16,15 @@ const DefaultMTU = 1300
 
 type TunDev interface {
 	io.WriteCloser
-	ReadMany(x []*packet.VirtIOPacket, q int) (int, error)
+	NewPacketArrays(batchSize int) []TunPacket
+
+	ReadMany(x []TunPacket, q int) (int, error)
+	RecycleRxSeg(pkt TunPacket, kick bool, q int) error
 
 	//todo this interface sux
 	AllocSeg(pkt *packet.OutPacket, q int) (int, error)
 	WriteOne(x *packet.OutPacket, kick bool, q int) (int, error)
 	WriteMany(x []*packet.OutPacket, q int) (int, error)
-	RecycleRxSeg(pkt *packet.VirtIOPacket, kick bool, q int) error
 }
 
 // TODO: We may be able to remove routines
@@ -31,8 +33,8 @@ type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefi
 func NewDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
 	switch {
 	case c.GetBool("tun.disabled", false):
-		tun := newDisabledTun(vpnNetworks, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l)
-		return tun, nil
+		t := newDisabledTun(vpnNetworks, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l)
+		return t, nil
 
 	default:
 		return newTun(c, l, vpnNetworks, routines > 1)
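
Because `TunDev` gains `NewPacketArrays` and changes `RecycleRxSeg` to take the `TunPacket` interface, every backend (`tun`, `disabledTun`, `UserDevice`, `TestTun`) has to move in lockstep; a compile-time assertion like the one below (illustrative, not something this commit adds) turns a missed backend into a build error. Note that the `TestTun` change further down takes `pkt *TunPacket` (a pointer to the interface) rather than `TunPacket`, which a check like this would flag.

```go
package overlay

// Illustrative compile-time interface checks, assuming package overlay.
var (
	_ TunDev = (*tun)(nil) // Linux-only type
	_ TunDev = (*disabledTun)(nil)
)
```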

+ 7 - 3
overlay/tun_disabled.go

@@ -24,7 +24,11 @@ type disabledTun struct {
 	l  *logrus.Logger
 }
 
-func (*disabledTun) RecycleRxSeg(pkt *packet.VirtIOPacket, kick bool, q int) error {
+func (t *disabledTun) NewPacketArrays(batchSize int) []TunPacket {
+	panic("implement me") //TODO
+}
+
+func (*disabledTun) RecycleRxSeg(pkt TunPacket, kick bool, q int) error {
 	return nil
 }
 
@@ -131,8 +135,8 @@ func (t *disabledTun) WriteMany(x []*packet.OutPacket, q int) (int, error) {
 	return 0, fmt.Errorf("tun_disabled: WriteMany not implemented")
 }
 
-func (t *disabledTun) ReadMany(b []*packet.VirtIOPacket, _ int) (int, error) {
-	return t.Read(b[0].Payload)
+func (t *disabledTun) ReadMany(b []TunPacket, _ int) (int, error) {
+	return t.Read(b[0].GetPayload())
 }
 
 func (t *disabledTun) NewMultiQueueReader() (TunDev, error) {
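
`NewPacketArrays` is left as a TODO here. Since `disabledTun.ReadMany` only reads into `b[0].GetPayload()`, one plausible fill-in (an assumption, not what the commit does) is to preallocate MTU-sized, slice-backed packets like the `slicePacket` sketch shown earlier:

```go
package overlay

// Hypothetical implementation of the TODO above, reusing the slice-backed
// TunPacket sketch from overlay/packets.go; the DefaultMTU buffer size is a guess.
func (t *disabledTun) NewPacketArrays(batchSize int) []TunPacket {
	pkts := make([]TunPacket, batchSize)
	for i := range pkts {
		p := &slicePacket{}
		p.SetPayload(make([]byte, DefaultMTU))
		pkts[i] = p
	}
	return pkts
}
```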

+ 30 - 5
overlay/tun_linux.go

@@ -4,6 +4,7 @@
 package overlay
 
 import (
+	"context"
 	"fmt"
 	"net"
 	"net/netip"
@@ -183,6 +184,14 @@ func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []n
 	return t, nil
 }
 
+func (t *tun) NewPacketArrays(batchSize int) []TunPacket {
+	inPackets := make([]TunPacket, batchSize)
+	for i := 0; i < batchSize; i++ {
+		inPackets[i] = vhostnet.NewVIO()
+	}
+	return inPackets
+}
+
 func (t *tun) reload(c *config.C, initial bool) error {
 	routeChange, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
 	if err != nil {
@@ -725,12 +734,25 @@ func (t *tun) Close() error {
 	return nil
 }
 
-func (t *tun) ReadMany(p []*packet.VirtIOPacket, q int) (int, error) {
-	n, err := t.vdev[q].ReceivePackets(p) //we are TXing
+func (t *tun) ReadMany(p []TunPacket, q int) (int, error) {
+	err := t.vdev[q].ReceiveQueue.WaitForUsedElements(context.TODO())
 	if err != nil {
 		return 0, err
 	}
-	return n, nil
+	i := 0
+	for i = 0; i < len(p); i++ {
+		item, ok := t.vdev[q].ReceiveQueue.TakeSingleNoBlock()
+		if !ok {
+			break
+		}
+		pkt := p[i].(*vhostnet.VirtIOPacket) //todo I'm not happy about this but I don't want to change how memory is "owned" rn
+		_, err = t.vdev[q].ProcessRxChain(pkt, item)
+		if err != nil {
+			return i, err
+		}
+		i++
+	}
+	return i, nil
 }
 
 func (t *tun) Write(b []byte) (int, error) {
@@ -783,6 +805,9 @@ func (t *tun) WriteMany(x []*packet.OutPacket, q int) (int, error) {
 	return maximum, nil
 }
 
-func (t *tun) RecycleRxSeg(pkt *packet.VirtIOPacket, kick bool, q int) error {
-	return t.vdev[q].ReceiveQueue.OfferDescriptorChains(pkt.Chains, kick)
+func (t *tun) RecycleRxSeg(pkt TunPacket, kick bool, q int) error {
+	vpkt := pkt.(*vhostnet.VirtIOPacket)
+	err := t.vdev[q].ReceiveQueue.OfferDescriptorChains(vpkt.Chains, kick)
+	vpkt.Reset() //intentionally ignoring err!
+	return err
 }
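
With `ReceivePackets` folded into the caller, the vhost receive path becomes: wait until the used ring has entries, drain them one at a time into `VirtIOPacket`s via `ProcessRxChain`, and hand the descriptor chains back later through `RecycleRxSeg` (which also resets the packet). A condensed sketch of a single pass over one queue, using only the API names that appear in this diff, with error handling trimmed:

```go
package example

import (
	"context"

	"github.com/slackhq/nebula/overlay/vhostnet"
)

// drainOnce is an illustrative single pass over the receive queue, mirroring
// the new ReadMany above.
func drainOnce(dev *vhostnet.Device, pkts []*vhostnet.VirtIOPacket) (int, error) {
	// Block until the device has published at least one used element.
	if err := dev.ReceiveQueue.WaitForUsedElements(context.TODO()); err != nil {
		return 0, err
	}
	n := 0
	for n < len(pkts) {
		elem, ok := dev.ReceiveQueue.TakeSingleNoBlock()
		if !ok {
			break // used ring drained for now
		}
		// Decode the virtio-net header and alias the payload into the
		// descriptor buffer; the chain head is recorded in pkts[n].Chains.
		if _, err := dev.ProcessRxChain(pkts[n], elem); err != nil {
			return n, err
		}
		n++
	}
	return n, nil
}
```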

+ 2 - 2
overlay/tun_tester.go

@@ -106,7 +106,7 @@ func (t *TestTun) Name() string {
 	return t.Device
 }
 
-func (t *TestTun) ReadMany(x []*packet.VirtIOPacket, q int) (int, error) {
+func (t *TestTun) ReadMany(x []TunPacket, q int) (int, error) {
 	p, ok := <-t.rxPackets
 	if !ok {
 		return 0, os.ErrClosed
@@ -165,7 +165,7 @@ func (t *TestTun) WriteMany(x []*packet.OutPacket, q int) (int, error) {
 	return len(x), nil
 }
 
-func (t *TestTun) RecycleRxSeg(pkt *packet.VirtIOPacket, kick bool, q int) error {
+func (t *TestTun) RecycleRxSeg(pkt *TunPacket, kick bool, q int) error {
 	//todo this ought to maybe track something
 	return nil
 }

+ 18 - 3
overlay/user.go

@@ -38,7 +38,18 @@ type UserDevice struct {
 	inboundWriter *io.PipeWriter
 }
 
-func (d *UserDevice) RecycleRxSeg(pkt *packet.VirtIOPacket, kick bool, q int) error {
+func (d *UserDevice) NewPacketArrays(batchSize int) []TunPacket {
+	//inPackets := make([]TunPacket, batchSize)
+	//outPackets := make([]OutPacket, batchSize)
+	panic("not implemented") //todo!
+	//for i := 0; i < batchSize; i++ {
+	//	inPackets[i] = vhostnet.NewVIO()
+	//	outPackets[i] = packet.New(false)
+	//}
+	//return inPackets, outPackets
+}
+
+func (d *UserDevice) RecycleRxSeg(pkt TunPacket, kick bool, q int) error {
 	return nil
 }
 
@@ -76,8 +87,12 @@ func (d *UserDevice) Close() error {
 	return nil
 }
 
-func (d *UserDevice) ReadMany(b []*packet.VirtIOPacket, _ int) (int, error) {
-	return d.Read(b[0].Payload)
+func (d *UserDevice) ReadMany(b []TunPacket, _ int) (int, error) {
+	_, err := d.Read(b[0].GetPayload())
+	if err != nil {
+		return 0, err
+	}
+	return 1, nil
 }
 
 func (d *UserDevice) AllocSeg(pkt *packet.OutPacket, q int) (int, error) {

+ 38 - 72
overlay/vhostnet/device.go

@@ -118,7 +118,7 @@ func NewDevice(options ...Option) (*Device, error) {
 		return nil, fmt.Errorf("set transmit queue backend: %w", err)
 	}
 
-	// Fully populate the receive queue with available buffers which the device
+	// Fully populate the rx queue with available buffers which the device
 	// can write new packets into.
 	if err = dev.refillReceiveQueue(); err != nil {
 		return nil, fmt.Errorf("refill receive queue: %w", err)
@@ -198,11 +198,8 @@ func (dev *Device) Close() error {
 // createQueue creates a new virtqueue and registers it with the vhost device
 // using the given index.
 func createQueue(controlFD int, queueIndex int, queueSize int, itemSize int) (*virtqueue.SplitQueue, error) {
-	var (
-		queue *virtqueue.SplitQueue
-		err   error
-	)
-	if queue, err = virtqueue.NewSplitQueue(queueSize, itemSize); err != nil {
+	queue, err := virtqueue.NewSplitQueue(queueSize, itemSize)
+	if err != nil {
 		return nil, fmt.Errorf("create virtqueue: %w", err)
 	}
 	if err = vhost.RegisterQueue(controlFD, uint32(queueIndex), queue); err != nil {
@@ -218,10 +215,10 @@ func (dev *Device) GetPacketForTx() (uint16, []byte, error) {
 		idx, err = dev.TransmitQueue.DescriptorTable().CreateDescriptorForOutputs()
 		if err == virtqueue.ErrNotEnoughFreeDescriptors {
 			dev.fullTable = true
-			idx, err = dev.TransmitQueue.TakeSingle(context.TODO())
+			idx, err = dev.TransmitQueue.TakeSingleIndex(context.TODO())
 		}
 	} else {
-		idx, err = dev.TransmitQueue.TakeSingle(context.TODO())
+		idx, err = dev.TransmitQueue.TakeSingleIndex(context.TODO())
 	}
 	if err != nil {
 		return 0, nil, fmt.Errorf("transmit queue: %w", err)
@@ -271,18 +268,15 @@ func (dev *Device) TransmitPackets(pkts []*packet.OutPacket) error {
 	return nil
 }
 
-// processChains processes as many chains as needed to create one packet. The number of processed chains is returned.
-func (dev *Device) processChains(pkt *packet.VirtIOPacket, chains []virtqueue.UsedElement) (int, error) {
+// ProcessRxChain processes a single chain to create one packet. The number of processed chains is returned.
+func (dev *Device) ProcessRxChain(pkt *VirtIOPacket, chain virtqueue.UsedElement) (int, error) {
 	//read first element to see how many descriptors we need:
 	pkt.Reset()
-
-	err := dev.ReceiveQueue.GetDescriptorInbuffers(uint16(chains[0].DescriptorIndex), &pkt.ChainRefs)
+	idx := uint16(chain.DescriptorIndex)
+	buf, err := dev.ReceiveQueue.GetDescriptorItem(idx)
 	if err != nil {
 		return 0, fmt.Errorf("get descriptor chain: %w", err)
 	}
-	if len(pkt.ChainRefs) == 0 {
-		return 1, nil
-	}
 
 	// The specification requires that the first descriptor chain starts
 	// with a virtio-net header. It is not clear, whether it is also
@@ -290,7 +284,7 @@ func (dev *Device) processChains(pkt *packet.VirtIOPacket, chains []virtqueue.Us
 	// descriptor chain, but it is reasonable to assume that this is
 	// always the case.
 	// The decode method already does the buffer length check.
-	if err = pkt.Header.Decode(pkt.ChainRefs[0][0:]); err != nil {
+	if err = pkt.header.Decode(buf); err != nil {
 		// The device misbehaved. There is no way we can gracefully
 		// recover from this, because we don't know how many of the
 		// following descriptor chains belong to this packet.
@@ -298,72 +292,44 @@ func (dev *Device) processChains(pkt *packet.VirtIOPacket, chains []virtqueue.Us
 	}
 
 	//we have the header now: what do we need to do?
-	if int(pkt.Header.NumBuffers) > len(chains) {
-		return 0, fmt.Errorf("number of buffers is greater than number of chains %d", len(chains))
+	if int(pkt.header.NumBuffers) > 1 {
+		return 0, fmt.Errorf("number of buffers is greater than number of chains %d", 1)
 	}
-	if int(pkt.Header.NumBuffers) != 1 {
-		return 0, fmt.Errorf("too smol-brain to handle more than one chain right now: %d chains", len(chains))
+	if int(pkt.header.NumBuffers) != 1 {
+		return 0, fmt.Errorf("too smol-brain to handle more than one buffer per chain item right now: %d chains, %d bufs", 1, int(pkt.header.NumBuffers))
 	}
-	if chains[0].Length > 16000 {
+	if chain.Length > 16000 {
 		//todo!
-		return 1, fmt.Errorf("too big packet length: %d", chains[0].Length)
+		return 1, fmt.Errorf("too big packet length: %d", chain.Length)
 	}
 
 	//shift the buffer out of out:
-	pkt.Payload = pkt.ChainRefs[0][virtio.NetHdrSize:chains[0].Length]
-	pkt.Chains = append(pkt.Chains, uint16(chains[0].DescriptorIndex))
+	pkt.payload = buf[virtio.NetHdrSize:chain.Length]
+	pkt.Chains = append(pkt.Chains, idx)
 	return 1, nil
-
-	//cursor := n - virtio.NetHdrSize
-	//
-	//if uint32(n) >= chains[0].Length && pkt.Header.NumBuffers == 1 {
-	//	pkt.Payload = pkt.Payload[:chains[0].Length-virtio.NetHdrSize]
-	//	return 1, nil
-	//}
-	//
-	//i := 1
-	//// we used chain 0 already
-	//for i = 1; i < len(chains); i++ {
-	//	n, err = dev.ReceiveQueue.GetDescriptorChainContents(uint16(chains[i].DescriptorIndex), pkt.Payload[cursor:], int(chains[i].Length))
-	//	if err != nil {
-	//		// When this fails we may miss to free some descriptor chains. We
-	//		// could try to mitigate this by deferring the freeing somehow, but
-	//		// it's not worth the hassle. When this method fails, the queue will
-	//		// be in a broken state anyway.
-	//		return i, fmt.Errorf("get descriptor chain: %w", err)
-	//	}
-	//	cursor += n
-	//}
-	////todo this has to be wrong
-	//pkt.Payload = pkt.Payload[:cursor]
-	//return i, nil
 }
 
-func (dev *Device) ReceivePackets(out []*packet.VirtIOPacket) (int, error) {
-	//todo optimize?
-	var chains []virtqueue.UsedElement
-	var err error
+type VirtIOPacket struct {
+	payload []byte
+	header  virtio.NetHdr
+	Chains  []uint16
+}
 
-	chains, err = dev.ReceiveQueue.BlockAndGetHeadsCapped(context.TODO(), len(out))
-	if err != nil {
-		return 0, err
-	}
-	if len(chains) == 0 {
-		return 0, nil
-	}
+func NewVIO() *VirtIOPacket {
+	out := new(VirtIOPacket)
+	out.payload = nil
+	out.Chains = make([]uint16, 0, 8)
+	return out
+}
 
-	numPackets := 0
-	chainsIdx := 0
-	for numPackets = 0; chainsIdx < len(chains); numPackets++ {
-		if numPackets >= len(out) {
-			return numPackets, fmt.Errorf("dropping %d packets, no room", len(chains)-numPackets)
-		}
-		numChains, err := dev.processChains(out[numPackets], chains[chainsIdx:])
-		if err != nil {
-			return 0, err
-		}
-		chainsIdx += numChains
-	}
+func (v *VirtIOPacket) Reset() {
+	v.payload = nil
+	v.Chains = v.Chains[:0]
+}
 
-	return numPackets, nil
+func (v *VirtIOPacket) GetPayload() []byte {
+	return v.payload
+}
+func (v *VirtIOPacket) SetPayload(x []byte) {
+	v.payload = x //todo?
 }
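
One thing worth spelling out about the relocated `VirtIOPacket`: `ProcessRxChain` points `payload` directly into the descriptor buffer (just past the virtio-net header), so the data is only valid until the chain is offered back to the receive queue. A hedged usage sketch, assuming the API shown above; the copy is only needed if the payload must outlive the recycle:

```go
package example

import (
	"github.com/slackhq/nebula/overlay/vhostnet"
	"github.com/slackhq/nebula/overlay/virtqueue"
)

// receiveAndCopy consumes one used element and returns a copy of its payload.
// The copy matters because GetPayload aliases descriptor memory, which is
// reused as soon as the chain is re-offered to the device.
func receiveAndCopy(dev *vhostnet.Device, elem virtqueue.UsedElement) ([]byte, error) {
	pkt := vhostnet.NewVIO()
	if _, err := dev.ProcessRxChain(pkt, elem); err != nil {
		return nil, err
	}
	out := append([]byte(nil), pkt.GetPayload()...) // defensive copy

	// Give the chain back to the device (kick=true) and reset the packet for
	// reuse; this mirrors what RecycleRxSeg does in tun_linux.go.
	err := dev.ReceiveQueue.OfferDescriptorChains(pkt.Chains, true)
	pkt.Reset()
	return out, err
}
```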

+ 0 - 172
overlay/virtqueue/descriptor_table.go

@@ -10,10 +10,6 @@ import (
 )
 
 var (
-	// ErrDescriptorChainEmpty is returned when a descriptor chain would contain
-	// no buffers, which is not allowed.
-	ErrDescriptorChainEmpty = errors.New("empty descriptor chains are not allowed")
-
 	// ErrNotEnoughFreeDescriptors is returned when the free descriptors are
 	// exhausted, meaning that the queue is full.
 	ErrNotEnoughFreeDescriptors = errors.New("not enough free descriptors, queue is full")
@@ -272,59 +268,6 @@ func (dt *DescriptorTable) createDescriptorForInputs() (uint16, error) {
 	return head, nil
 }
 
-// TODO: Implement a zero-copy variant of createDescriptorChain?
-
-// getDescriptorChain returns the device-readable buffers (out buffers) and
-// device-writable buffers (in buffers) of the descriptor chain that starts with
-// the given head index. The descriptor chain must have been created using
-// [createDescriptorChain] and must not have been freed yet (meaning that the
-// head index must not be contained in the free chain).
-//
-// Be careful to only access the returned buffer slices when the device has not
-// yet or is no longer using them. They must not be accessed after
-// [freeDescriptorChain] has been called.
-func (dt *DescriptorTable) getDescriptorChain(head uint16) (outBuffers, inBuffers [][]byte, err error) {
-	if int(head) > len(dt.descriptors) {
-		return nil, nil, fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain)
-	}
-
-	// Iterate over the chain. The iteration is limited to the queue size to
-	// avoid ending up in an endless loop when things go very wrong.
-	next := head
-	for range len(dt.descriptors) {
-		if next == dt.freeHeadIndex {
-			return nil, nil, fmt.Errorf("%w: must not be part of the free chain", ErrInvalidDescriptorChain)
-		}
-
-		desc := &dt.descriptors[next]
-
-		// The descriptor address points to memory not managed by Go, so this
-		// conversion is safe. See https://github.com/golang/go/issues/58625
-		//goland:noinspection GoVetUnsafePointer
-		bs := unsafe.Slice((*byte)(unsafe.Pointer(desc.address)), desc.length)
-
-		if desc.flags&descriptorFlagWritable == 0 {
-			outBuffers = append(outBuffers, bs)
-		} else {
-			inBuffers = append(inBuffers, bs)
-		}
-
-		// Is this the tail of the chain?
-		if desc.flags&descriptorFlagHasNext == 0 {
-			break
-		}
-
-		// Detect loops.
-		if desc.next == head {
-			return nil, nil, fmt.Errorf("%w: contains a loop", ErrInvalidDescriptorChain)
-		}
-
-		next = desc.next
-	}
-
-	return
-}
-
 func (dt *DescriptorTable) getDescriptorItem(head uint16) ([]byte, error) {
 	if int(head) > len(dt.descriptors) {
 		return nil, fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain)
@@ -339,121 +282,6 @@ func (dt *DescriptorTable) getDescriptorItem(head uint16) ([]byte, error) {
 	return bs, nil
 }
 
-func (dt *DescriptorTable) getDescriptorInbuffers(head uint16, inBuffers *[][]byte) error {
-	if int(head) > len(dt.descriptors) {
-		return fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain)
-	}
-
-	// Iterate over the chain. The iteration is limited to the queue size to
-	// avoid ending up in an endless loop when things go very wrong.
-	next := head
-	for range len(dt.descriptors) {
-		if next == dt.freeHeadIndex {
-			return fmt.Errorf("%w: must not be part of the free chain", ErrInvalidDescriptorChain)
-		}
-
-		desc := &dt.descriptors[next]
-
-		// The descriptor address points to memory not managed by Go, so this
-		// conversion is safe. See https://github.com/golang/go/issues/58625
-		//goland:noinspection GoVetUnsafePointer
-		bs := unsafe.Slice((*byte)(unsafe.Pointer(desc.address)), desc.length)
-
-		if desc.flags&descriptorFlagWritable == 0 {
-			return fmt.Errorf("there should not be an outbuffer in %d", head)
-		} else {
-			*inBuffers = append(*inBuffers, bs)
-		}
-
-		// Is this the tail of the chain?
-		if desc.flags&descriptorFlagHasNext == 0 {
-			break
-		}
-
-		// Detect loops.
-		if desc.next == head {
-			return fmt.Errorf("%w: contains a loop", ErrInvalidDescriptorChain)
-		}
-
-		next = desc.next
-	}
-
-	return nil
-}
-
-// freeDescriptorChain can be used to free a descriptor chain when it is no
-// longer in use. The descriptor chain that starts with the given index will be
-// put back into the free chain, so the descriptors can be used for later calls
-// of [createDescriptorChain].
-// The descriptor chain must have been created using [createDescriptorChain] and
-// must not have been freed yet (meaning that the head index must not be
-// contained in the free chain).
-func (dt *DescriptorTable) freeDescriptorChain(head uint16) error {
-	if int(head) > len(dt.descriptors) {
-		return fmt.Errorf("%w: index out of range", ErrInvalidDescriptorChain)
-	}
-
-	// Iterate over the chain. The iteration is limited to the queue size to
-	// avoid ending up in an endless loop when things go very wrong.
-	next := head
-	var tailDesc *Descriptor
-	var chainLen uint16
-	for range len(dt.descriptors) {
-		if next == dt.freeHeadIndex {
-			return fmt.Errorf("%w: must not be part of the free chain", ErrInvalidDescriptorChain)
-		}
-
-		desc := &dt.descriptors[next]
-		chainLen++
-
-		// Set the length of all unused descriptors back to zero.
-		desc.length = 0
-
-		// Unset all flags except the next flag.
-		desc.flags &= descriptorFlagHasNext
-
-		// Is this the tail of the chain?
-		if desc.flags&descriptorFlagHasNext == 0 {
-			tailDesc = desc
-			break
-		}
-
-		// Detect loops.
-		if desc.next == head {
-			return fmt.Errorf("%w: contains a loop", ErrInvalidDescriptorChain)
-		}
-
-		next = desc.next
-	}
-	if tailDesc == nil {
-		// A descriptor chain longer than the queue size but without loops
-		// should be impossible.
-		panic(fmt.Sprintf("could not find a tail for descriptor chain starting at %d", head))
-	}
-
-	// The tail descriptor does not have the next flag set, but when it comes
-	// back into the free chain, it should have.
-	tailDesc.flags = descriptorFlagHasNext
-
-	if dt.freeHeadIndex == noFreeHead {
-		// The whole free chain was used up, so we turn this returned descriptor
-		// chain into the new free chain by completing the circle and using its
-		// head.
-		tailDesc.next = head
-		dt.freeHeadIndex = head
-	} else {
-		// Attach the returned chain at the beginning of the free chain but
-		// right after the free chain head.
-		freeHeadDesc := &dt.descriptors[dt.freeHeadIndex]
-		tailDesc.next = freeHeadDesc.next
-		freeHeadDesc.next = head
-	}
-
-	dt.freeNum += chainLen
-
-	return nil
-}
-
 // checkUnusedDescriptorLength asserts that the length of an unused descriptor
 // is zero, as it should be.
 // This is not a requirement by the virtio spec but rather a thing we do to

+ 40 - 39
overlay/virtqueue/split_virtqueue.go

@@ -128,8 +128,7 @@ func NewSplitQueue(queueSize int, itemSize int) (_ *SplitQueue, err error) {
 		return nil, err
 	}
 
-	// Consume used buffer notifications in the background.
-	sq.stop = sq.startConsumeUsedRing()
+	sq.stop = sq.kickSelfToExit()
 
 	return &sq, nil
 }
@@ -169,9 +168,7 @@ func (sq *SplitQueue) CallEventFD() int {
 	return sq.callEventFD.FD()
 }
 
-// startConsumeUsedRing starts a goroutine that runs [consumeUsedRing].
-// A function is returned that can be used to gracefully cancel it. todo rename
-func (sq *SplitQueue) startConsumeUsedRing() func() error {
+func (sq *SplitQueue) kickSelfToExit() func() error {
 	return func() error {
 
 		// The goroutine blocks until it receives a signal on the event file
@@ -185,7 +182,15 @@ func (sq *SplitQueue) startConsumeUsedRing() func() error {
 	}
 }
 
-func (sq *SplitQueue) TakeSingle(ctx context.Context) (uint16, error) {
+func (sq *SplitQueue) TakeSingleIndex(ctx context.Context) (uint16, error) {
+	element, err := sq.TakeSingle(ctx)
+	if err != nil {
+		return 0xffff, err
+	}
+	return element.GetHead(), nil
+}
+
+func (sq *SplitQueue) TakeSingle(ctx context.Context) (UsedElement, error) {
 	var n int
 	var err error
 	for ctx.Err() == nil {
@@ -195,7 +200,7 @@ func (sq *SplitQueue) TakeSingle(ctx context.Context) (uint16, error) {
 		}
 		// Wait for a signal from the device.
 		if n, err = sq.epoll.Block(); err != nil {
-			return 0, fmt.Errorf("wait: %w", err)
+			return UsedElement{}, fmt.Errorf("wait: %w", err)
 		}
 
 		if n > 0 {
@@ -208,7 +213,31 @@ func (sq *SplitQueue) TakeSingle(ctx context.Context) (uint16, error) {
 			}
 		}
 	}
-	return 0, ctx.Err()
+	return UsedElement{}, ctx.Err()
+}
+
+func (sq *SplitQueue) TakeSingleNoBlock() (UsedElement, bool) {
+	return sq.usedRing.takeOne()
+}
+
+func (sq *SplitQueue) WaitForUsedElements(ctx context.Context) error {
+	if sq.usedRing.availableToTake() != 0 {
+		return nil
+	}
+	for ctx.Err() == nil {
+		// Wait for a signal from the device.
+		n, err := sq.epoll.Block()
+		if err != nil {
+			return fmt.Errorf("wait: %w", err)
+		}
+		if n > 0 {
+			_ = sq.epoll.Clear()
+			if sq.usedRing.availableToTake() != 0 {
+				return nil
+			}
+		}
+	}
+	return ctx.Err()
 }
 
 func (sq *SplitQueue) BlockAndGetHeadsCapped(ctx context.Context, maxToTake int) ([]UsedElement, error) {
@@ -235,7 +264,7 @@ func (sq *SplitQueue) BlockAndGetHeadsCapped(ctx context.Context, maxToTake int)
 			return nil, fmt.Errorf("wait: %w", err)
 		}
 		if n > 0 {
-			_ = sq.epoll.Clear() //???
+			_ = sq.epoll.Clear()
 			stillNeedToTake, out = sq.usedRing.take(maxToTake)
 			sq.more = stillNeedToTake
 			return out, nil
@@ -296,16 +325,14 @@ func (sq *SplitQueue) OfferInDescriptorChains() (uint16, error) {
 	sq.availableRing.offerSingle(head)
 
 	// Notify the device to make it process the updated available ring.
-	if err := sq.kickEventFD.Kick(); err != nil {
+	if err = sq.kickEventFD.Kick(); err != nil {
 		return head, fmt.Errorf("notify device: %w", err)
 	}
 
 	return head, nil
 }
 
-// GetDescriptorChain returns the device-readable buffers (out buffers) and
-// device-writable buffers (in buffers) of the descriptor chain with the given
-// head index.
+// GetDescriptorItem returns the buffer of a given index
 // The head index must be one that was returned by a previous call to
 // [SplitQueue.OfferDescriptorChain] and the descriptor chain must not have been
 // freed yet.
@@ -313,37 +340,11 @@ func (sq *SplitQueue) OfferInDescriptorChains() (uint16, error) {
 // Be careful to only access the returned buffer slices when the device is no
 // longer using them. They must not be accessed after
 // [SplitQueue.FreeDescriptorChain] has been called.
-func (sq *SplitQueue) GetDescriptorChain(head uint16) (outBuffers, inBuffers [][]byte, err error) {
-	return sq.descriptorTable.getDescriptorChain(head)
-}
-
 func (sq *SplitQueue) GetDescriptorItem(head uint16) ([]byte, error) {
 	sq.descriptorTable.descriptors[head].length = uint32(sq.descriptorTable.itemSize)
 	return sq.descriptorTable.getDescriptorItem(head)
 }
 
-func (sq *SplitQueue) GetDescriptorInbuffers(head uint16, inBuffers *[][]byte) error {
-	return sq.descriptorTable.getDescriptorInbuffers(head, inBuffers)
-}
-
-// FreeDescriptorChain frees the descriptor chain with the given head index.
-// The head index must be one that was returned by a previous call to
-// [SplitQueue.OfferDescriptorChain] and the descriptor chain must not have been
-// freed yet.
-//
-// This creates new room in the queue which can be used by following
-// [SplitQueue.OfferDescriptorChain] calls.
-// When there are outstanding calls for [SplitQueue.OfferDescriptorChain] that
-// are waiting for free room in the queue, they may become unblocked by this.
-func (sq *SplitQueue) FreeDescriptorChain(head uint16) error {
-	//not called under lock
-	if err := sq.descriptorTable.freeDescriptorChain(head); err != nil {
-		return fmt.Errorf("free: %w", err)
-	}
-
-	return nil
-}
-
 func (sq *SplitQueue) SetDescSize(head uint16, sz int) {
 	//not called under lock
 	sq.descriptorTable.descriptors[int(head)].length = uint32(sz)
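
`TakeSingle` now returns the full `UsedElement`, `TakeSingleIndex` wraps it for callers that only need the head index, and `WaitForUsedElements` plus `TakeSingleNoBlock` covers the wait-then-drain receive path shown earlier. The transmit path uses `TakeSingleIndex` to reclaim a descriptor when the table is exhausted; a simplified sketch of that pattern (it omits the `fullTable` bookkeeping that `GetPacketForTx` keeps):

```go
package example

import (
	"context"

	"github.com/slackhq/nebula/overlay/virtqueue"
)

// nextTxDescriptor tries to create a fresh output descriptor; if the table is
// full, it blocks until the device finishes with one and reuses its head.
func nextTxDescriptor(ctx context.Context, sq *virtqueue.SplitQueue) (uint16, error) {
	head, err := sq.DescriptorTable().CreateDescriptorForOutputs()
	if err == virtqueue.ErrNotEnoughFreeDescriptors {
		return sq.TakeSingleIndex(ctx)
	}
	return head, err
}
```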

+ 17 - 25
overlay/virtqueue/used_ring.go

@@ -84,17 +84,11 @@ func (r *UsedRing) Address() uintptr {
 	return uintptr(unsafe.Pointer(r.flags))
 }
 
-// take returns all new [UsedElement]s that the device put into the ring and
-// that weren't already returned by a previous call to this method.
-// had a lock, I removed it
-func (r *UsedRing) take(maxToTake int) (int, []UsedElement) {
-	//r.mu.Lock()
-	//defer r.mu.Unlock()
-
+func (r *UsedRing) availableToTake() int {
 	ringIndex := *r.ringIndex
 	if ringIndex == r.lastIndex {
 		// Nothing new.
-		return 0, nil
+		return 0
 	}
 
 	// Calculate the number new used elements that we can read from the ring.
@@ -103,6 +97,16 @@ func (r *UsedRing) take(maxToTake int) (int, []UsedElement) {
 	if count < 0 {
 		count += 0xffff
 	}
+	return count
+}
+
+// take returns all new [UsedElement]s that the device put into the ring and
+// that weren't already returned by a previous call to this method.
+func (r *UsedRing) take(maxToTake int) (int, []UsedElement) {
+	count := r.availableToTake()
+	if count == 0 {
+		return 0, nil
+	}
 
 	stillNeedToTake := 0
 
@@ -128,21 +132,13 @@ func (r *UsedRing) take(maxToTake int) (int, []UsedElement) {
 	return stillNeedToTake, elems
 }
 
-func (r *UsedRing) takeOne() (uint16, bool) {
+func (r *UsedRing) takeOne() (UsedElement, bool) {
 	//r.mu.Lock()
 	//defer r.mu.Unlock()
 
-	ringIndex := *r.ringIndex
-	if ringIndex == r.lastIndex {
-		// Nothing new.
-		return 0xffff, false
-	}
-
-	// Calculate the number new used elements that we can read from the ring.
-	// The ring index may wrap, so special handling for that case is needed.
-	count := int(ringIndex - r.lastIndex)
-	if count < 0 {
-		count += 0xffff
+	count := r.availableToTake()
+	if count == 0 {
+		return UsedElement{}, false
 	}
 
 	// The number of new elements can never exceed the queue size.
@@ -150,11 +146,7 @@ func (r *UsedRing) takeOne() (uint16, bool) {
 		panic("used ring contains more new elements than the ring is long")
 	}
 
-	if count == 0 {
-		return 0xffff, false
-	}
-
-	out := r.ring[r.lastIndex%uint16(len(r.ring))].GetHead()
+	out := r.ring[r.lastIndex%uint16(len(r.ring))]
 	r.lastIndex++
 
 	return out, true
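
`availableToTake` factors out the index math that both `take` and `takeOne` relied on. The two indexes are free-running `uint16` counters, so the subtraction stays meaningful across a wrap; a worked example:

```go
package main

import "fmt"

// Worked example of the index arithmetic behind availableToTake.
func main() {
	ringIndex := uint16(3)     // device has published used elements up to (not including) index 3
	lastIndex := uint16(65534) // we last consumed up to index 65534
	count := int(ringIndex - lastIndex)
	fmt.Println(count) // 5: ring slots 65534, 65535, 0, 1 and 2 are waiting
}
```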

+ 12 - 12
packet/packet.go

@@ -14,7 +14,7 @@ import (
 
 const Size = 0xffff
 
-type Packet struct {
+type UDPPacket struct {
 	Payload []byte
 	Control []byte
 	Name    []byte
@@ -25,8 +25,8 @@ type Packet struct {
 	isV4         bool
 }
 
-func New(isV4 bool) *Packet {
-	return &Packet{
+func New(isV4 bool) *UDPPacket {
+	return &UDPPacket{
 		Payload: make([]byte, Size),
 		Control: make([]byte, unix.CmsgSpace(2)),
 		Name:    make([]byte, unix.SizeofSockaddrInet6),
@@ -34,7 +34,7 @@ func New(isV4 bool) *Packet {
 	}
 }
 
-func (p *Packet) AddrPort() netip.AddrPort {
+func (p *UDPPacket) AddrPort() netip.AddrPort {
 	var ip netip.Addr
 	// Its ok to skip the ok check here, the slicing is the only error that can occur and it will panic
 	if p.isV4 {
@@ -45,7 +45,7 @@ func (p *Packet) AddrPort() netip.AddrPort {
 	return netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(p.Name[2:4]))
 }
 
-func (p *Packet) encodeSockaddr(dst []byte, addr netip.AddrPort) (uint32, error) {
+func (p *UDPPacket) encodeSockaddr(dst []byte, addr netip.AddrPort) (uint32, error) {
 	//todo no chance this works on windows?
 	if p.isV4 {
 		if !addr.Addr().Is4() {
@@ -69,7 +69,7 @@ func (p *Packet) encodeSockaddr(dst []byte, addr netip.AddrPort) (uint32, error)
 	return uint32(size), nil
 }
 
-func (p *Packet) SetAddrPort(addr netip.AddrPort) error {
+func (p *UDPPacket) SetAddrPort(addr netip.AddrPort) error {
 	nl, err := p.encodeSockaddr(p.Name, addr)
 	if err != nil {
 		return err
@@ -78,7 +78,7 @@ func (p *Packet) SetAddrPort(addr netip.AddrPort) error {
 	return nil
 }
 
-func (p *Packet) updateCtrl(ctrlLen int) {
+func (p *UDPPacket) updateCtrl(ctrlLen int) {
 	p.SegSize = len(p.Payload)
 	p.wasSegmented = false
 	if ctrlLen == 0 {
@@ -101,12 +101,12 @@ func (p *Packet) updateCtrl(ctrlLen int) {
 	}
 }
 
-// Update sets a Packet into "just received, not processed" state
-func (p *Packet) Update(ctrlLen int) {
+// Update sets a UDPPacket into "just received, not processed" state
+func (p *UDPPacket) Update(ctrlLen int) {
 	p.updateCtrl(ctrlLen)
 }
 
-func (p *Packet) SetSegSizeForTX() {
+func (p *UDPPacket) SetSegSizeForTX() {
 	p.SegSize = len(p.Payload)
 	hdr := (*unix.Cmsghdr)(unsafe.Pointer(&p.Control[0]))
 	hdr.Level = unix.SOL_UDP
@@ -115,7 +115,7 @@ func (p *Packet) SetSegSizeForTX() {
 	binary.NativeEndian.PutUint16(p.Control[unix.CmsgLen(0):unix.CmsgLen(0)+2], uint16(p.SegSize))
 }
 
-func (p *Packet) CompatibleForSegmentationWith(otherP *Packet, currentTotalSize int) bool {
+func (p *UDPPacket) CompatibleForSegmentationWith(otherP *UDPPacket, currentTotalSize int) bool {
 	//same dest
 	if !slices.Equal(p.Name, otherP.Name) {
 		return false
@@ -134,7 +134,7 @@ func (p *Packet) CompatibleForSegmentationWith(otherP *Packet, currentTotalSize
 	return true
 }
 
-func (p *Packet) Segments() iter.Seq[[]byte] {
+func (p *UDPPacket) Segments() iter.Seq[[]byte] {
 	return func(yield func([]byte) bool) {
 		//cursor := 0
 		for offset := 0; offset < len(p.Payload); offset += p.SegSize {
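
`Segments` iterates a coalesced (GRO/GSO-style) payload in `SegSize` chunks, which is how a single `UDPPacket` can carry several wire packets. The standalone loop below mirrors that arithmetic rather than calling the method; clamping the short final segment is an assumption, since the method body is cut off in this view:

```go
package main

import "fmt"

// Mirrors the stepping that Segments() performs over a coalesced payload.
func main() {
	payload := []byte("aaaabbbbcc") // 10 bytes delivered as one coalesced buffer
	segSize := 4                    // segment size reported by the kernel (or set for TX)

	for off := 0; off < len(payload); off += segSize {
		end := min(off+segSize, len(payload))
		fmt.Printf("%q\n", payload[off:end]) // "aaaa", "bbbb", "cc"
	}
}
```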

+ 0 - 26
packet/virtio.go

@@ -1,26 +0,0 @@
-package packet
-
-import (
-	"github.com/slackhq/nebula/util/virtio"
-)
-
-type VirtIOPacket struct {
-	Payload   []byte
-	Header    virtio.NetHdr
-	Chains    []uint16
-	ChainRefs [][]byte
-}
-
-func NewVIO() *VirtIOPacket {
-	out := new(VirtIOPacket)
-	out.Payload = nil
-	out.ChainRefs = make([][]byte, 0, 4)
-	out.Chains = make([]uint16, 0, 8)
-	return out
-}
-
-func (v *VirtIOPacket) Reset() {
-	v.Payload = nil
-	v.ChainRefs = v.ChainRefs[:0]
-	v.Chains = v.Chains[:0]
-}

+ 3 - 3
udp/conn.go

@@ -10,7 +10,7 @@ import (
 const MTU = 9001
 
 type EncReader func(
-	[]*packet.Packet,
+	[]*packet.UDPPacket,
 )
 
 type Conn interface {
@@ -19,8 +19,8 @@ type Conn interface {
 	ListenOut(r EncReader)
 	WriteTo(b []byte, addr netip.AddrPort) error
 	ReloadConfig(c *config.C)
-	Prep(pkt *packet.Packet, addr netip.AddrPort) error
-	WriteBatch(pkt []*packet.Packet) (int, error)
+	Prep(pkt *packet.UDPPacket, addr netip.AddrPort) error
+	WriteBatch(pkt []*packet.UDPPacket) (int, error)
 	SupportsMultipleReaders() bool
 	Close() error
 }
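
The renamed interface implies a two-step send contract: `Prep` encodes the destination into each `UDPPacket` (and, at least in the tester implementation, marks it ready to send), then `WriteBatch` flushes the whole slice in one call. A hedged sketch of a caller, assuming only the interface above; this is not code from the commit:

```go
package example

import (
	"net/netip"

	"github.com/slackhq/nebula/packet"
	"github.com/slackhq/nebula/udp"
)

// sendBatch prepares every packet for the same destination, then writes the
// batch in a single call.
func sendBatch(c udp.Conn, pkts []*packet.UDPPacket, addr netip.AddrPort) error {
	for _, p := range pkts {
		if err := c.Prep(p, addr); err != nil {
			return err
		}
	}
	_, err := c.WriteBatch(pkts)
	return err
}
```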

+ 3 - 3
udp/udp_linux.go

@@ -215,7 +215,7 @@ func (u *StdConn) WriteToBatch(b []byte, ip netip.AddrPort) error {
 	return u.writeTo6(b, ip)
 }
 
-func (u *StdConn) Prep(pkt *packet.Packet, addr netip.AddrPort) error {
+func (u *StdConn) Prep(pkt *packet.UDPPacket, addr netip.AddrPort) error {
 	//todo move this into pkt
 	nl, err := u.encodeSockaddr(pkt.Name, addr)
 	if err != nil {
@@ -226,7 +226,7 @@ func (u *StdConn) Prep(pkt *packet.Packet, addr netip.AddrPort) error {
 	return nil
 }
 
-func (u *StdConn) WriteBatch(pkts []*packet.Packet) (int, error) {
+func (u *StdConn) WriteBatch(pkts []*packet.UDPPacket) (int, error) {
 	if len(pkts) == 0 {
 		return 0, nil
 	}
@@ -235,7 +235,7 @@ func (u *StdConn) WriteBatch(pkts []*packet.Packet) (int, error) {
 	//u.iovs = u.iovs[:0]
 
 	sent := 0
-	var mostRecentPkt *packet.Packet
+	var mostRecentPkt *packet.UDPPacket
 	mostRecentPktSize := 0
 	//segmenting := false
 	idx := 0

+ 2 - 2
udp/udp_linux_64.go

@@ -52,9 +52,9 @@ func setCmsgLen(h *unix.Cmsghdr, l int) {
 	h.Len = uint64(l)
 }
 
-func (u *StdConn) PrepareRawMessages(n int, isV4 bool) ([]rawMessage, []*packet.Packet) {
+func (u *StdConn) PrepareRawMessages(n int, isV4 bool) ([]rawMessage, []*packet.UDPPacket) {
 	msgs := make([]rawMessage, n)
-	packets := make([]*packet.Packet, n)
+	packets := make([]*packet.UDPPacket, n)
 
 	for i := range msgs {
 		packets[i] = packet.New(isV4)

+ 3 - 3
udp/udp_tester.go

@@ -41,7 +41,7 @@ type TesterConn struct {
 	l      *logrus.Logger
 }
 
-func (u *TesterConn) Prep(pkt *packet.Packet, addr netip.AddrPort) error {
+func (u *TesterConn) Prep(pkt *packet.UDPPacket, addr netip.AddrPort) error {
 	pkt.ReadyToSend = true
 	return pkt.SetAddrPort(addr)
 }
@@ -96,7 +96,7 @@ func (u *TesterConn) Get(block bool) *Packet {
 // Below this is boilerplate implementation to make nebula actually work
 //********************************************************************************************************************//
 
-func (u *TesterConn) WriteBatch(pkts []*packet.Packet) (int, error) {
+func (u *TesterConn) WriteBatch(pkts []*packet.UDPPacket) (int, error) {
 	for _, pkt := range pkts {
 		if !pkt.ReadyToSend {
 			continue
@@ -141,7 +141,7 @@ func (u *TesterConn) ListenOut(r EncReader) {
 		if err != nil {
 			panic(err)
 		}
-		y := []*packet.Packet{x}
+		y := []*packet.UDPPacket{x}
 		r(y)
 	}
 }