JackDoan 6 日 前
コミット
f5c46c43ce
6 ファイル変更303 行追加187 行削除
  1. 2 1
      control_tester.go
  2. 174 170
      outside.go
  3. 61 15
      overlay/tun_tester.go
  4. 34 0
      packet/packet.go
  5. 1 0
      udp/udp_linux.go
  6. 31 1
      udp/udp_tester.go

+ 2 - 1
control_tester.go

@@ -80,7 +80,8 @@ func (c *Control) GetFromTun(block bool) []byte {
 
 // GetFromUDP will pull a udp packet off the udp side of nebula
 func (c *Control) GetFromUDP(block bool) *udp.Packet {
-	return c.f.outside.(*udp.TesterConn).Get(block)
+	out := c.f.outside.(*udp.TesterConn).Get(block)
+	return out
 }
 
 func (c *Control) GetUDPTxChan() <-chan *udp.Packet {

+ 174 - 170
outside.go

@@ -166,212 +166,216 @@ func (f *Interface) readOutsidePacketFromRelay(via ViaSender, out []byte, packet
 	f.connectionManager.In(hostinfo)
 }
 
-func (f *Interface) readOutsidePacketsMany(packets []*packet.Packet, out []*packet.OutPacket, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache, now time.Time) {
-	for i, pkt := range packets {
-		out[i].Scratch = out[i].Scratch[:0]
-		via := ViaSender{UdpAddr: pkt.AddrPort()}
+func (f *Interface) readOutsideSegment(via ViaSender, segment []byte, out *packet.OutPacket, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache, now time.Time) {
+	err := h.Parse(segment)
+	if err != nil {
+		// Hole punch packets are 0 or 1 byte big, so let's ignore printing those errors
+		if len(segment) > 1 {
+			f.l.WithField("packet", segment).Infof("Error while parsing inbound packet from %s: %s", via, err)
+		}
+		return
+	}
 
-		//l.Error("in packet ", header, packet[HeaderLen:])
-		if !via.IsRelayed {
-			if f.myVpnNetworksTable.Contains(via.UdpAddr.Addr()) {
-				if f.l.Level >= logrus.DebugLevel {
-					f.l.WithField("from", via).Debug("Refusing to process double encrypted packet")
-				}
-				return
-			}
+	var hostinfo *HostInfo
+	// verify if we've seen this index before, otherwise respond to the handshake initiation
+	if h.Type == header.Message && h.Subtype == header.MessageRelay {
+		hostinfo = f.hostMap.QueryRelayIndex(h.RemoteIndex)
+	} else {
+		hostinfo = f.hostMap.QueryIndex(h.RemoteIndex)
+	}
+
+	var ci *ConnectionState
+	if hostinfo != nil {
+		ci = hostinfo.ConnectionState
+	}
+
+	switch h.Type {
+	case header.Message:
+		if !f.handleEncrypted(ci, via, h) {
+			return
 		}
 
-		for segment := range pkt.Segments() {
-			err := h.Parse(segment)
-			if err != nil {
-				// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
-				if len(segment) > 1 {
-					f.l.WithField("packet", pkt).Infof("Error while parsing inbound packet from %s: %s", via, err)
-				}
+		switch h.Subtype {
+		case header.MessageNone:
+			if !f.decryptToTunDelayWrite(hostinfo, h.MessageCounter, out, segment, fwPacket, nb, q, localCache, now) {
 				return
 			}
-
-			var hostinfo *HostInfo
-			// verify if we've seen this index before, otherwise respond to the handshake initiation
-			if h.Type == header.Message && h.Subtype == header.MessageRelay {
-				hostinfo = f.hostMap.QueryRelayIndex(h.RemoteIndex)
-			} else {
-				hostinfo = f.hostMap.QueryIndex(h.RemoteIndex)
+		case header.MessageRelay:
+			// The entire body is sent as AD, not encrypted.
+			// The packet consists of a 16-byte parsed Nebula header, Associated Data-protected payload, and a trailing 16-byte AEAD signature value.
+			// The packet is guaranteed to be at least 16 bytes at this point, b/c it got past the h.Parse() call above. If it's
+			// otherwise malformed (meaning, there is no trailing 16 byte AEAD value), then this will result in at worst a 0-length slice
+			// which will gracefully fail in the DecryptDanger call.
+			signedPayload := segment[:len(segment)-hostinfo.ConnectionState.dKey.Overhead()]
+			signatureValue := segment[len(segment)-hostinfo.ConnectionState.dKey.Overhead():]
+			out.Scratch, err = hostinfo.ConnectionState.dKey.DecryptDanger(out.Scratch, signedPayload, signatureValue, h.MessageCounter, nb)
+			if err != nil {
+				return
 			}
+			// Successfully validated the thing. Get rid of the Relay header.
+			signedPayload = signedPayload[header.Len:]
+			// Pull the Roaming parts up here, and return in all call paths.
+			f.handleHostRoaming(hostinfo, via)
+			// Track usage of both the HostInfo and the Relay for the received & authenticated packet
+			f.connectionManager.In(hostinfo)
+			f.connectionManager.RelayUsed(h.RemoteIndex)
 
-			var ci *ConnectionState
-			if hostinfo != nil {
-				ci = hostinfo.ConnectionState
+			relay, ok := hostinfo.relayState.QueryRelayForByIdx(h.RemoteIndex)
+			if !ok {
+				// The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing
+				// its internal mapping. This should never happen.
+				hostinfo.logger(f.l).WithFields(logrus.Fields{"vpnAddrs": hostinfo.vpnAddrs, "remoteIndex": h.RemoteIndex}).Error("HostInfo missing remote relay index")
+				return
 			}
 
-			switch h.Type {
-			case header.Message:
-				if !f.handleEncrypted(ci, via, h) {
+			switch relay.Type {
+			case TerminalType:
+				// If I am the target of this relay, process the unwrapped packet
+				// From this recursive point, all these variables are 'burned'. We shouldn't rely on them again.
+				via = ViaSender{
+					UdpAddr:   via.UdpAddr,
+					relayHI:   hostinfo,
+					remoteIdx: relay.RemoteIndex,
+					relay:     relay,
+					IsRelayed: true,
+				}
+				f.readOutsidePacketFromRelay(via, out.Scratch[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache, now)
+				return
+			case ForwardingType:
+				// Find the target HostInfo relay object
+				targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr)
+				if err != nil {
+					hostinfo.logger(f.l).WithField("relayTo", relay.PeerAddr).WithError(err).WithField("hostinfo.vpnAddrs", hostinfo.vpnAddrs).Info("Failed to find target host info by ip")
 					return
 				}
 
-				switch h.Subtype {
-				case header.MessageNone:
-					if !f.decryptToTunDelayWrite(hostinfo, h.MessageCounter, out[i], pkt, segment, fwPacket, nb, q, localCache, now) {
-						return
-					}
-				case header.MessageRelay:
-					// The entire body is sent as AD, not encrypted.
-					// The packet consists of a 16-byte parsed Nebula header, Associated Data-protected payload, and a trailing 16-byte AEAD signature value.
-					// The packet is guaranteed to be at least 16 bytes at this point, b/c it got past the h.Parse() call above. If it's
-					// otherwise malformed (meaning, there is no trailing 16 byte AEAD value), then this will result in at worst a 0-length slice
-					// which will gracefully fail in the DecryptDanger call.
-					signedPayload := segment[:len(segment)-hostinfo.ConnectionState.dKey.Overhead()]
-					signatureValue := segment[len(segment)-hostinfo.ConnectionState.dKey.Overhead():]
-					out[i].Scratch, err = hostinfo.ConnectionState.dKey.DecryptDanger(out[i].Scratch, signedPayload, signatureValue, h.MessageCounter, nb)
-					if err != nil {
-						return
-					}
-					// Successfully validated the thing. Get rid of the Relay header.
-					signedPayload = signedPayload[header.Len:]
-					// Pull the Roaming parts up here, and return in all call paths.
-					f.handleHostRoaming(hostinfo, via)
-					// Track usage of both the HostInfo and the Relay for the received & authenticated packet
-					f.connectionManager.In(hostinfo)
-					f.connectionManager.RelayUsed(h.RemoteIndex)
-
-					relay, ok := hostinfo.relayState.QueryRelayForByIdx(h.RemoteIndex)
-					if !ok {
-						// The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing
-						// its internal mapping. This should never happen.
-						hostinfo.logger(f.l).WithFields(logrus.Fields{"vpnAddrs": hostinfo.vpnAddrs, "remoteIndex": h.RemoteIndex}).Error("HostInfo missing remote relay index")
+				// If that relay is Established, forward the payload through it
+				if targetRelay.State == Established {
+					switch targetRelay.Type {
+					case ForwardingType:
+						// Forward this packet through the relay tunnel
+						// Find the target HostInfo
+						f.SendVia(targetHI, targetRelay, signedPayload, nb, out.Scratch, false)
 						return
-					}
-
-					switch relay.Type {
 					case TerminalType:
-						// If I am the target of this relay, process the unwrapped packet
-						// From this recursive point, all these variables are 'burned'. We shouldn't rely on them again.
-						via = ViaSender{
-							UdpAddr:   via.UdpAddr,
-							relayHI:   hostinfo,
-							remoteIdx: relay.RemoteIndex,
-							relay:     relay,
-							IsRelayed: true,
-						}
-						f.readOutsidePacketFromRelay(via, out[i].Scratch[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache, now)
-						return
-					case ForwardingType:
-						// Find the target HostInfo relay object
-						targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr)
-						if err != nil {
-							hostinfo.logger(f.l).WithField("relayTo", relay.PeerAddr).WithError(err).WithField("hostinfo.vpnAddrs", hostinfo.vpnAddrs).Info("Failed to find target host info by ip")
-							return
-						}
-
-						// If that relay is Established, forward the payload through it
-						if targetRelay.State == Established {
-							switch targetRelay.Type {
-							case ForwardingType:
-								// Forward this packet through the relay tunnel
-								// Find the target HostInfo
-								f.SendVia(targetHI, targetRelay, signedPayload, nb, out[i].Scratch, false)
-								return
-							case TerminalType:
-								hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal")
-							}
-						} else {
-							hostinfo.logger(f.l).WithFields(logrus.Fields{"relayTo": relay.PeerAddr, "relayFrom": hostinfo.vpnAddrs[0], "targetRelayState": targetRelay.State}).Info("Unexpected target relay state")
-							return
-						}
+						hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal")
 					}
-				}
-
-			case header.LightHouse:
-				f.messageMetrics.Rx(h.Type, h.Subtype, 1)
-				if !f.handleEncrypted(ci, via, h) {
+				} else {
+					hostinfo.logger(f.l).WithFields(logrus.Fields{"relayTo": relay.PeerAddr, "relayFrom": hostinfo.vpnAddrs[0], "targetRelayState": targetRelay.State}).Info("Unexpected target relay state")
 					return
 				}
+			}
+		}
 
-				d, err := f.decrypt(hostinfo, h.MessageCounter, out[i].Scratch, segment, h, nb)
-				if err != nil {
-					hostinfo.logger(f.l).WithError(err).WithField("udpAddr", via.UdpAddr).
-						WithField("packet", segment).
-						Error("Failed to decrypt lighthouse packet")
-					return
-				}
+	case header.LightHouse:
+		f.messageMetrics.Rx(h.Type, h.Subtype, 1)
+		if !f.handleEncrypted(ci, via, h) {
+			return
+		}
 
-				lhf.HandleRequest(via.UdpAddr, hostinfo.vpnAddrs, d, f)
+		d, err := f.decrypt(hostinfo, h.MessageCounter, out.Scratch, segment, h, nb)
+		if err != nil {
+			hostinfo.logger(f.l).WithError(err).WithField("udpAddr", via.UdpAddr).
+				WithField("packet", segment).
+				Error("Failed to decrypt lighthouse packet")
+			return
+		}
 
-				// Fallthrough to the bottom to record incoming traffic
+		lhf.HandleRequest(via.UdpAddr, hostinfo.vpnAddrs, d, f)
 
-			case header.Test:
-				f.messageMetrics.Rx(h.Type, h.Subtype, 1)
-				if !f.handleEncrypted(ci, via, h) {
-					return
-				}
+		// Fallthrough to the bottom to record incoming traffic
 
-				d, err := f.decrypt(hostinfo, h.MessageCounter, out[i].Scratch, segment, h, nb)
-				if err != nil {
-					hostinfo.logger(f.l).WithError(err).WithField("udpAddr", via).
-						WithField("packet", segment).
-						Error("Failed to decrypt test packet")
-					return
-				}
+	case header.Test:
+		f.messageMetrics.Rx(h.Type, h.Subtype, 1)
+		if !f.handleEncrypted(ci, via, h) {
+			return
+		}
 
-				if h.Subtype == header.TestRequest {
-					// This testRequest might be from TryPromoteBest, so we should roam
-					// to the new IP address before responding
-					f.handleHostRoaming(hostinfo, via)
-					f.send(header.Test, header.TestReply, ci, hostinfo, d, nb, out[i].Scratch)
-				}
+		d, err := f.decrypt(hostinfo, h.MessageCounter, out.Scratch, segment, h, nb)
+		if err != nil {
+			hostinfo.logger(f.l).WithError(err).WithField("udpAddr", via).
+				WithField("packet", segment).
+				Error("Failed to decrypt test packet")
+			return
+		}
 
-				// Fallthrough to the bottom to record incoming traffic
+		if h.Subtype == header.TestRequest {
+			// This testRequest might be from TryPromoteBest, so we should roam
+			// to the new IP address before responding
+			f.handleHostRoaming(hostinfo, via)
+			f.send(header.Test, header.TestReply, ci, hostinfo, d, nb, out.Scratch)
+		}
 
-				// Non encrypted messages below here, they should not fall through to avoid tracking incoming traffic since they
-				// are unauthenticated
+		// Fallthrough to the bottom to record incoming traffic
 
-			case header.Handshake:
-				f.messageMetrics.Rx(h.Type, h.Subtype, 1)
-				f.handshakeManager.HandleIncoming(via, segment, h)
-				return
+		// Non encrypted messages below here, they should not fall through to avoid tracking incoming traffic since they
+		// are unauthenticated
 
-			case header.RecvError:
-				f.messageMetrics.Rx(h.Type, h.Subtype, 1)
-				f.handleRecvError(via.UdpAddr, h)
-				return
+	case header.Handshake:
+		f.messageMetrics.Rx(h.Type, h.Subtype, 1)
+		f.handshakeManager.HandleIncoming(via, segment, h)
+		return
 
-			case header.CloseTunnel:
-				f.messageMetrics.Rx(h.Type, h.Subtype, 1)
-				if !f.handleEncrypted(ci, via, h) {
-					return
-				}
+	case header.RecvError:
+		f.messageMetrics.Rx(h.Type, h.Subtype, 1)
+		f.handleRecvError(via.UdpAddr, h)
+		return
 
-				hostinfo.logger(f.l).WithField("udpAddr", via).
-					Info("Close tunnel received, tearing down.")
+	case header.CloseTunnel:
+		f.messageMetrics.Rx(h.Type, h.Subtype, 1)
+		if !f.handleEncrypted(ci, via, h) {
+			return
+		}
 
-				f.closeTunnel(hostinfo)
-				return
+		hostinfo.logger(f.l).WithField("udpAddr", via).
+			Info("Close tunnel received, tearing down.")
 
-			case header.Control:
-				if !f.handleEncrypted(ci, via, h) {
-					return
-				}
+		f.closeTunnel(hostinfo)
+		return
 
-				d, err := f.decrypt(hostinfo, h.MessageCounter, out[i].Scratch, segment, h, nb)
-				if err != nil {
-					hostinfo.logger(f.l).WithError(err).WithField("udpAddr", via).
-						WithField("packet", segment).
-						Error("Failed to decrypt Control packet")
-					return
-				}
+	case header.Control:
+		if !f.handleEncrypted(ci, via, h) {
+			return
+		}
+
+		d, err := f.decrypt(hostinfo, h.MessageCounter, out.Scratch, segment, h, nb)
+		if err != nil {
+			hostinfo.logger(f.l).WithError(err).WithField("udpAddr", via).
+				WithField("packet", segment).
+				Error("Failed to decrypt Control packet")
+			return
+		}
+
+		f.relayManager.HandleControlMsg(hostinfo, d, f)
+
+	default:
+		f.messageMetrics.Rx(h.Type, h.Subtype, 1)
+		hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", via)
+		return
+	}
 
-				f.relayManager.HandleControlMsg(hostinfo, d, f)
+	f.handleHostRoaming(hostinfo, via)
+
+	f.connectionManager.In(hostinfo)
+}
+
+func (f *Interface) readOutsidePacketsMany(packets []*packet.Packet, out []*packet.OutPacket, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache, now time.Time) {
+	for i, pkt := range packets {
+		out[i].Scratch = out[i].Scratch[:0]
+		via := ViaSender{UdpAddr: pkt.AddrPort()}
 
-			default:
-				f.messageMetrics.Rx(h.Type, h.Subtype, 1)
-				hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", via)
+		//l.Error("in packet ", header, packet[HeaderLen:])
+		if !via.IsRelayed {
+			if f.myVpnNetworksTable.Contains(via.UdpAddr.Addr()) {
+				if f.l.Level >= logrus.DebugLevel {
+					f.l.WithField("from", via).Debug("Refusing to process double encrypted packet")
+				}
 				return
 			}
+		}
 
-			f.handleHostRoaming(hostinfo, via)
-
-			f.connectionManager.In(hostinfo)
+		for segment := range pkt.Segments() {
+			f.readOutsideSegment(via, segment, out[i], h, fwPacket, lhf, nb, q, localCache, now)
 
 		}
 		_, err := f.readers[q].WriteOne(out[i], false, q)
@@ -630,7 +634,7 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []
 	return out, nil
 }
 
-func (f *Interface) decryptToTunDelayWrite(hostinfo *HostInfo, messageCounter uint64, out *packet.OutPacket, pkt *packet.Packet, inSegment []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache, now time.Time) bool {
+func (f *Interface) decryptToTunDelayWrite(hostinfo *HostInfo, messageCounter uint64, out *packet.OutPacket, inSegment []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache, now time.Time) bool {
 	var err error
 
 	seg, err := f.readers[q].AllocSeg(out, q)

+ 61 - 15
overlay/tun_tester.go

@@ -13,6 +13,7 @@ import (
 	"github.com/gaissmai/bart"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/packet"
 	"github.com/slackhq/nebula/routing"
 )
 
@@ -26,6 +27,7 @@ type TestTun struct {
 	closed    atomic.Bool
 	rxPackets chan []byte // Packets to receive into nebula
 	TxPackets chan []byte // Packets transmitted outside by nebula
+	buffers   [][]byte
 }
 
 func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*TestTun, error) {
@@ -104,15 +106,68 @@ func (t *TestTun) Name() string {
 	return t.Device
 }
 
-func (t *TestTun) Write(b []byte) (n int, err error) {
+func (t *TestTun) ReadMany(x []*packet.VirtIOPacket, q int) (int, error) {
+	p, ok := <-t.rxPackets
+	if !ok {
+		return 0, os.ErrClosed
+	}
+	x[0].Payload = p
+	return 1, nil
+}
+
+func (t *TestTun) AllocSeg(pkt *packet.OutPacket, q int) (int, error) {
+	buf := make([]byte, 9000)
+	t.buffers = append(t.buffers, buf)
+	idx := len(t.buffers) - 1
+	isV6 := false //todo?
+	x := pkt.UseSegment(uint16(idx), buf, isV6)
+	return x, nil
+}
+
+func (t *TestTun) Write(b []byte) (int, error) {
+	//todo garbagey
+	out := packet.NewOut()
+	x, err := t.AllocSeg(out, 0)
+	if err != nil {
+		return 0, err
+	}
+	copy(out.SegmentPayloads[x], b)
+	return t.WriteOne(out, true, 0)
+}
+
+func (t *TestTun) WriteOne(x *packet.OutPacket, kick bool, q int) (int, error) {
 	if t.closed.Load() {
 		return 0, io.ErrClosedPipe
 	}
+	if len(x.SegmentIDs) == 0 {
+		return 0, nil
+	}
+	for i, _ := range x.SegmentIDs {
+		t.TxPackets <- x.SegmentPayloads[i]
+	}
+	//todo if kick, delete alloced seg
 
-	packet := make([]byte, len(b), len(b))
-	copy(packet, b)
-	t.TxPackets <- packet
-	return len(b), nil
+	return 1, nil
+}
+
+func (t *TestTun) WriteMany(x []*packet.OutPacket, q int) (int, error) {
+	if len(x) == 0 {
+		return 0, nil
+	}
+
+	for _, pkt := range x {
+		_, err := t.WriteOne(pkt, true, q)
+		if err != nil {
+			return 0, err
+		}
+	}
+
+	return len(x), nil
+}
+
+func (t *TestTun) RecycleRxSeg(pkt *packet.VirtIOPacket, kick bool, q int) error {
+	//todo this ought to maybe track something
+	return nil
 }
 
 func (t *TestTun) Close() error {
@@ -123,19 +178,10 @@ func (t *TestTun) Close() error {
 	return nil
 }
 
-func (t *TestTun) Read(b []byte) (int, error) {
-	p, ok := <-t.rxPackets
-	if !ok {
-		return 0, os.ErrClosed
-	}
-	copy(b, p)
-	return len(p), nil
-}
-
 func (t *TestTun) SupportsMultiqueue() bool {
 	return false
 }
 
-func (t *TestTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
+func (t *TestTun) NewMultiQueueReader() (TunDev, error) {
 	return nil, fmt.Errorf("TODO: multiqueue not implemented")
 }

+ 34 - 0
packet/packet.go

@@ -2,6 +2,7 @@ package packet
 
 import (
 	"encoding/binary"
+	"fmt"
 	"iter"
 	"net/netip"
 	"slices"
@@ -44,6 +45,39 @@ func (p *Packet) AddrPort() netip.AddrPort {
 	return netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(p.Name[2:4]))
 }
 
+func (p *Packet) encodeSockaddr(dst []byte, addr netip.AddrPort) (uint32, error) {
+	//todo no chance this works on windows?
+	if p.isV4 {
+		if !addr.Addr().Is4() {
+			return 0, fmt.Errorf("Listener is IPv4, but writing to IPv6 remote")
+		}
+		var sa unix.RawSockaddrInet4
+		sa.Family = unix.AF_INET
+		sa.Addr = addr.Addr().As4()
+		binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], addr.Port())
+		size := unix.SizeofSockaddrInet4
+		copy(dst[:size], (*(*[unix.SizeofSockaddrInet4]byte)(unsafe.Pointer(&sa)))[:])
+		return uint32(size), nil
+	}
+
+	var sa unix.RawSockaddrInet6
+	sa.Family = unix.AF_INET6
+	sa.Addr = addr.Addr().As16()
+	binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], addr.Port())
+	size := unix.SizeofSockaddrInet6
+	copy(dst[:size], (*(*[unix.SizeofSockaddrInet6]byte)(unsafe.Pointer(&sa)))[:])
+	return uint32(size), nil
+}
+
+func (p *Packet) SetAddrPort(addr netip.AddrPort) error {
+	nl, err := p.encodeSockaddr(p.Name, addr)
+	if err != nil {
+		return err
+	}
+	p.Name = p.Name[:nl]
+	return nil
+}
+
 func (p *Packet) updateCtrl(ctrlLen int) {
 	p.SegSize = len(p.Payload)
 	p.wasSegmented = false

+ 1 - 0
udp/udp_linux.go

@@ -216,6 +216,7 @@ func (u *StdConn) WriteToBatch(b []byte, ip netip.AddrPort) error {
 }
 
 func (u *StdConn) Prep(pkt *packet.Packet, addr netip.AddrPort) error {
+	//todo move this into pkt
 	nl, err := u.encodeSockaddr(pkt.Name, addr)
 	if err != nil {
 		return err

+ 31 - 1
udp/udp_tester.go

@@ -11,6 +11,7 @@ import (
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/header"
+	"github.com/slackhq/nebula/packet"
 )
 
 type Packet struct {
@@ -40,6 +41,11 @@ type TesterConn struct {
 	l      *logrus.Logger
 }
 
+func (u *TesterConn) Prep(pkt *packet.Packet, addr netip.AddrPort) error {
+	pkt.ReadyToSend = true
+	return pkt.SetAddrPort(addr)
+}
+
 func NewListener(l *logrus.Logger, ip netip.Addr, port int, _ bool, _ int) (Conn, error) {
 	return &TesterConn{
 		Addr:      netip.AddrPortFrom(ip, uint16(port)),
@@ -90,6 +96,19 @@ func (u *TesterConn) Get(block bool) *Packet {
 // Below this is boilerplate implementation to make nebula actually work
 //********************************************************************************************************************//
 
+func (u *TesterConn) WriteBatch(pkts []*packet.Packet) (int, error) {
+	for _, pkt := range pkts {
+		if !pkt.ReadyToSend {
+			continue
+		}
+		err := u.WriteTo(pkt.Payload, pkt.AddrPort())
+		if err != nil {
+			return 0, err
+		}
+	}
+	return len(pkts), nil
+}
+
 func (u *TesterConn) WriteTo(b []byte, addr netip.AddrPort) error {
 	if u.closed.Load() {
 		return io.ErrClosedPipe
@@ -100,6 +119,9 @@ func (u *TesterConn) WriteTo(b []byte, addr netip.AddrPort) error {
 		From: u.Addr,
 		To:   addr,
 	}
+	if addr.Addr().IsUnspecified() {
+		panic("invalid address")
+	}
 
 	copy(p.Data, b)
 	u.TxPackets <- p
@@ -112,7 +134,15 @@ func (u *TesterConn) ListenOut(r EncReader) {
 		if !ok {
 			return
 		}
-		r(p.From, p.Data)
+		x := packet.New(p.From.Addr().Is4())
+		x.Payload = p.Data
+		x.SetSegSizeForTX()
+		err := x.SetAddrPort(p.From)
+		if err != nil {
+			panic(err)
+		}
+		y := []*packet.Packet{x}
+		r(y)
 	}
 }