Browse Source

add new files for compat layer

Ryan Huber 1 month ago
parent
commit
608904b9dd

+ 102 - 0
overlay/wireguard_tun_linux.go

@@ -0,0 +1,102 @@
+//go:build linux && !android && !e2e_testing
+
+package overlay
+
+import (
+	"fmt"
+	"sync"
+
+	wgtun "github.com/slackhq/nebula/wgstack/tun"
+)
+
+type wireguardTunIO struct {
+	dev       wgtun.Device
+	mtu       int
+	batchSize int
+
+	readMu   sync.Mutex
+	readBufs [][]byte
+	readLens []int
+	pending  [][]byte
+	pendIdx  int
+
+	writeMu   sync.Mutex
+	writeBuf  []byte
+	writeWrap [][]byte
+}
+
+func newWireguardTunIO(dev wgtun.Device, mtu int) *wireguardTunIO {
+	batch := dev.BatchSize()
+	if batch <= 0 {
+		batch = 1
+	}
+	if mtu <= 0 {
+		mtu = DefaultMTU
+	}
+	bufs := make([][]byte, batch)
+	for i := range bufs {
+		bufs[i] = make([]byte, wgtun.VirtioNetHdrLen+mtu)
+	}
+	return &wireguardTunIO{
+		dev:       dev,
+		mtu:       mtu,
+		batchSize: batch,
+		readBufs:  bufs,
+		readLens:  make([]int, batch),
+		pending:   make([][]byte, 0, batch),
+		writeBuf:  make([]byte, wgtun.VirtioNetHdrLen+mtu),
+		writeWrap: make([][]byte, 1),
+	}
+}
+
+func (w *wireguardTunIO) Read(p []byte) (int, error) {
+	w.readMu.Lock()
+	defer w.readMu.Unlock()
+
+	for {
+		if w.pendIdx < len(w.pending) {
+			segment := w.pending[w.pendIdx]
+			w.pendIdx++
+			n := copy(p, segment)
+			return n, nil
+		}
+
+		n, err := w.dev.Read(w.readBufs, w.readLens, wgtun.VirtioNetHdrLen)
+		if err != nil {
+			return 0, err
+		}
+		w.pending = w.pending[:0]
+		w.pendIdx = 0
+		for i := 0; i < n; i++ {
+			length := w.readLens[i]
+			if length == 0 {
+				continue
+			}
+			segment := w.readBufs[i][wgtun.VirtioNetHdrLen : wgtun.VirtioNetHdrLen+length]
+			w.pending = append(w.pending, segment)
+		}
+	}
+}
+
+func (w *wireguardTunIO) Write(p []byte) (int, error) {
+	if len(p) > w.mtu {
+		return 0, fmt.Errorf("wireguard tun: payload exceeds MTU (%d > %d)", len(p), w.mtu)
+	}
+	w.writeMu.Lock()
+	defer w.writeMu.Unlock()
+	buf := w.writeBuf[:wgtun.VirtioNetHdrLen+len(p)]
+	for i := 0; i < wgtun.VirtioNetHdrLen; i++ {
+		buf[i] = 0
+	}
+	copy(buf[wgtun.VirtioNetHdrLen:], p)
+	w.writeWrap[0] = buf
+	n, err := w.dev.Write(w.writeWrap, wgtun.VirtioNetHdrLen)
+	if err != nil {
+		return n, err
+	}
+	return len(p), nil
+}
+
+func (w *wireguardTunIO) Close() error {
+	return nil
+}

+ 132 - 0
udp/wireguard_conn_linux.go

@@ -0,0 +1,132 @@
+//go:build linux && !android && !e2e_testing
+
+package udp
+
+import (
+	"errors"
+	"net"
+	"net/netip"
+	"sync"
+	"sync/atomic"
+
+	"github.com/sirupsen/logrus"
+	"github.com/slackhq/nebula/config"
+	wgconn "github.com/slackhq/nebula/wgstack/conn"
+)
+
+// WGConn adapts WireGuard's batched UDP bind implementation to Nebula's udp.Conn interface.
+type WGConn struct {
+	l         *logrus.Logger
+	bind      *wgconn.StdNetBind
+	recvers   []wgconn.ReceiveFunc
+	batch     int
+	localIP   netip.Addr
+	localPort uint16
+	closed    atomic.Bool
+
+	closeOnce sync.Once
+}
+
+// NewWireguardListener creates a UDP listener backed by WireGuard's StdNetBind.
+func NewWireguardListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
+	bind := wgconn.NewStdNetBindForAddr(ip, multi)
+	recvers, actualPort, err := bind.Open(uint16(port))
+	if err != nil {
+		return nil, err
+	}
+	if batch <= 0 || batch > bind.BatchSize() {
+		batch = bind.BatchSize()
+	}
+	return &WGConn{
+		l:         l,
+		bind:      bind,
+		recvers:   recvers,
+		batch:     batch,
+		localIP:   ip,
+		localPort: actualPort,
+	}, nil
+}
+
+func (c *WGConn) Rebind() error {
+	// WireGuard's bind does not support rebinding in place.
+	return nil
+}
+
+func (c *WGConn) LocalAddr() (netip.AddrPort, error) {
+	if !c.localIP.IsValid() || c.localIP.IsUnspecified() {
+		// Fallback to wildcard IPv4 for display purposes.
+		return netip.AddrPortFrom(netip.IPv4Unspecified(), c.localPort), nil
+	}
+	return netip.AddrPortFrom(c.localIP, c.localPort), nil
+}
+
+func (c *WGConn) listen(fn wgconn.ReceiveFunc, r EncReader) {
+	batchSize := c.batch
+	packets := make([][]byte, batchSize)
+	for i := range packets {
+		packets[i] = make([]byte, MTU)
+	}
+	sizes := make([]int, batchSize)
+	endpoints := make([]wgconn.Endpoint, batchSize)
+
+	for {
+		if c.closed.Load() {
+			return
+		}
+		n, err := fn(packets, sizes, endpoints)
+		if err != nil {
+			if errors.Is(err, net.ErrClosed) {
+				return
+			}
+			if c.l != nil {
+				c.l.WithError(err).Debug("wireguard UDP listener receive error")
+			}
+			continue
+		}
+		for i := 0; i < n; i++ {
+			if sizes[i] == 0 {
+				continue
+			}
+			stdEp, ok := endpoints[i].(*wgconn.StdNetEndpoint)
+			if !ok {
+				if c.l != nil {
+					c.l.Warn("wireguard UDP listener received unexpected endpoint type")
+				}
+				continue
+			}
+			addr := stdEp.AddrPort
+			r(addr, packets[i][:sizes[i]])
+			endpoints[i] = nil
+		}
+	}
+}
+
+func (c *WGConn) ListenOut(r EncReader) {
+	for _, fn := range c.recvers {
+		go c.listen(fn, r)
+	}
+}
+
+func (c *WGConn) WriteTo(b []byte, addr netip.AddrPort) error {
+	if len(b) == 0 {
+		return nil
+	}
+	if c.closed.Load() {
+		return net.ErrClosed
+	}
+	ep := &wgconn.StdNetEndpoint{AddrPort: addr}
+	return c.bind.Send([][]byte{b}, ep)
+}
+
+func (c *WGConn) ReloadConfig(*config.C) {
+	// WireGuard bind currently does not expose runtime configuration knobs.
+}
+
+func (c *WGConn) Close() error {
+	var err error
+	c.closeOnce.Do(func() {
+		c.closed.Store(true)
+		err = c.bind.Close()
+	})
+	return err
+}

+ 15 - 0
udp/wireguard_conn_unsupported.go

@@ -0,0 +1,15 @@
+//go:build !linux || android || e2e_testing
+
+package udp
+
+import (
+	"fmt"
+	"net/netip"
+
+	"github.com/sirupsen/logrus"
+)
+
+// NewWireguardListener is only available on Linux builds.
+func NewWireguardListener(*logrus.Logger, netip.Addr, int, bool, int) (Conn, error) {
+	return nil, fmt.Errorf("wireguard experimental UDP listener is only supported on Linux")
+}

+ 513 - 0
wgstack/conn/bind_std.go

@@ -0,0 +1,513 @@
+// SPDX-License-Identifier: MIT
+//
+// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+
+package conn
+
+import (
+	"context"
+	"errors"
+	"net"
+	"net/netip"
+	"runtime"
+	"strconv"
+	"sync"
+	"syscall"
+
+	"golang.org/x/net/ipv4"
+	"golang.org/x/net/ipv6"
+	"golang.org/x/sys/unix"
+)
+
+var (
+	_ Bind = (*StdNetBind)(nil)
+)
+
+// StdNetBind implements Bind for all platforms. While Windows has its own Bind
+// (see bind_windows.go), it may fall back to StdNetBind.
+// TODO: Remove usage of ipv{4,6}.PacketConn when net.UDPConn has comparable
+// methods for sending and receiving multiple datagrams per-syscall. See the
+// proposal in https://github.com/golang/go/issues/45886#issuecomment-1218301564.
+type StdNetBind struct {
+	mu     sync.Mutex // protects all fields except as specified
+	ipv4   *net.UDPConn
+	ipv6   *net.UDPConn
+	ipv4PC *ipv4.PacketConn // will be nil on non-Linux
+	ipv6PC *ipv6.PacketConn // will be nil on non-Linux
+
+	// these three fields are not guarded by mu
+	udpAddrPool  sync.Pool
+	ipv4MsgsPool sync.Pool
+	ipv6MsgsPool sync.Pool
+
+	blackhole4 bool
+	blackhole6 bool
+
+	listenAddr4 string
+	listenAddr6 string
+	bindV4      bool
+	bindV6      bool
+	reusePort   bool
+}
+
+func newStdNetBind() *StdNetBind {
+	return &StdNetBind{
+		udpAddrPool: sync.Pool{
+			New: func() any {
+				return &net.UDPAddr{
+					IP: make([]byte, 16),
+				}
+			},
+		},
+
+		ipv4MsgsPool: sync.Pool{
+			New: func() any {
+				msgs := make([]ipv4.Message, IdealBatchSize)
+				for i := range msgs {
+					msgs[i].Buffers = make(net.Buffers, 1)
+					msgs[i].OOB = make([]byte, srcControlSize)
+				}
+				return &msgs
+			},
+		},
+
+		ipv6MsgsPool: sync.Pool{
+			New: func() any {
+				msgs := make([]ipv6.Message, IdealBatchSize)
+				for i := range msgs {
+					msgs[i].Buffers = make(net.Buffers, 1)
+					msgs[i].OOB = make([]byte, srcControlSize)
+				}
+				return &msgs
+			},
+		},
+		bindV4:    true,
+		bindV6:    true,
+		reusePort: false,
+	}
+}
+
+// NewStdNetBind creates a bind that listens on all interfaces.
+func NewStdNetBind() *StdNetBind {
+	return newStdNetBind()
+}
+
+// NewStdNetBindForAddr creates a bind that listens on a specific address.
+// If addr is IPv4, only the IPv4 socket will be created. For IPv6, only the
+// IPv6 socket will be created.
+func NewStdNetBindForAddr(addr netip.Addr, reusePort bool) *StdNetBind {
+	b := newStdNetBind()
+	if addr.IsValid() {
+		if addr.Is4() {
+			b.listenAddr4 = addr.Unmap().String()
+			b.bindV4 = true
+			b.bindV6 = false
+		} else {
+			b.listenAddr6 = addr.Unmap().String()
+			b.bindV6 = true
+			b.bindV4 = false
+		}
+	}
+	b.reusePort = reusePort
+	return b
+}
+
+type StdNetEndpoint struct {
+	// AddrPort is the endpoint destination.
+	netip.AddrPort
+	// src is the current sticky source address and interface index, if supported.
+	src struct {
+		netip.Addr
+		ifidx int32
+	}
+}
+
+var (
+	_ Bind     = (*StdNetBind)(nil)
+	_ Endpoint = &StdNetEndpoint{}
+)
+
+func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) {
+	e, err := netip.ParseAddrPort(s)
+	if err != nil {
+		return nil, err
+	}
+	return &StdNetEndpoint{
+		AddrPort: e,
+	}, nil
+}
+
+func (e *StdNetEndpoint) ClearSrc() {
+	e.src.ifidx = 0
+	e.src.Addr = netip.Addr{}
+}
+
+func (e *StdNetEndpoint) DstIP() netip.Addr {
+	return e.AddrPort.Addr()
+}
+
+func (e *StdNetEndpoint) SrcIP() netip.Addr {
+	return e.src.Addr
+}
+
+func (e *StdNetEndpoint) SrcIfidx() int32 {
+	return e.src.ifidx
+}
+
+func (e *StdNetEndpoint) DstToBytes() []byte {
+	b, _ := e.AddrPort.MarshalBinary()
+	return b
+}
+
+func (e *StdNetEndpoint) DstToString() string {
+	return e.AddrPort.String()
+}
+
+func (e *StdNetEndpoint) SrcToString() string {
+	return e.src.Addr.String()
+}
+
+func (s *StdNetBind) listenNet(network string, host string, port int) (*net.UDPConn, int, error) {
+	lc := listenConfig()
+	if s.reusePort {
+		base := lc.Control
+		lc.Control = func(network, address string, c syscall.RawConn) error {
+			if base != nil {
+				if err := base(network, address, c); err != nil {
+					return err
+				}
+			}
+			return c.Control(func(fd uintptr) {
+				_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT, 1)
+			})
+		}
+	}
+
+	addr := ":" + strconv.Itoa(port)
+	if host != "" {
+		addr = net.JoinHostPort(host, strconv.Itoa(port))
+	}
+
+	conn, err := lc.ListenPacket(context.Background(), network, addr)
+	if err != nil {
+		return nil, 0, err
+	}
+
+	// Retrieve port.
+	laddr := conn.LocalAddr()
+	uaddr, err := net.ResolveUDPAddr(
+		laddr.Network(),
+		laddr.String(),
+	)
+	if err != nil {
+		return nil, 0, err
+	}
+	return conn.(*net.UDPConn), uaddr.Port, nil
+}
+
+func (s *StdNetBind) openIPv4(port int) (*net.UDPConn, *ipv4.PacketConn, int, error) {
+	if !s.bindV4 {
+		return nil, nil, port, nil
+	}
+	host := s.listenAddr4
+	conn, actualPort, err := s.listenNet("udp4", host, port)
+	if err != nil {
+		if errors.Is(err, syscall.EAFNOSUPPORT) {
+			return nil, nil, port, nil
+		}
+		return nil, nil, port, err
+	}
+	if runtime.GOOS != "linux" {
+		return conn, nil, actualPort, nil
+	}
+	pc := ipv4.NewPacketConn(conn)
+	return conn, pc, actualPort, nil
+}
+
+func (s *StdNetBind) openIPv6(port int) (*net.UDPConn, *ipv6.PacketConn, int, error) {
+	if !s.bindV6 {
+		return nil, nil, port, nil
+	}
+	host := s.listenAddr6
+	conn, actualPort, err := s.listenNet("udp6", host, port)
+	if err != nil {
+		if errors.Is(err, syscall.EAFNOSUPPORT) {
+			return nil, nil, port, nil
+		}
+		return nil, nil, port, err
+	}
+	if runtime.GOOS != "linux" {
+		return conn, nil, actualPort, nil
+	}
+	pc := ipv6.NewPacketConn(conn)
+	return conn, pc, actualPort, nil
+}
+
+func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	var err error
+	var tries int
+
+	if s.ipv4 != nil || s.ipv6 != nil {
+		return nil, 0, ErrBindAlreadyOpen
+	}
+
+	// Attempt to open ipv4 and ipv6 listeners on the same port.
+	// If uport is 0, we can retry on failure.
+again:
+	port := int(uport)
+	var v4conn *net.UDPConn
+	var v6conn *net.UDPConn
+	var v4pc *ipv4.PacketConn
+	var v6pc *ipv6.PacketConn
+
+	v4conn, v4pc, port, err = s.openIPv4(port)
+	if err != nil {
+		return nil, 0, err
+	}
+
+	// Listen on the same port as we're using for ipv4.
+	v6conn, v6pc, port, err = s.openIPv6(port)
+	if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
+		if v4conn != nil {
+			v4conn.Close()
+		}
+		tries++
+		goto again
+	}
+	if err != nil {
+		if v4conn != nil {
+			v4conn.Close()
+		}
+		return nil, 0, err
+	}
+
+	var fns []ReceiveFunc
+	if v4conn != nil {
+		s.ipv4 = v4conn
+		if v4pc != nil {
+			s.ipv4PC = v4pc
+		}
+		fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn))
+	}
+	if v6conn != nil {
+		s.ipv6 = v6conn
+		if v6pc != nil {
+			s.ipv6PC = v6pc
+		}
+		fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn))
+	}
+	if len(fns) == 0 {
+		return nil, 0, syscall.EAFNOSUPPORT
+	}
+
+	return fns, uint16(port), nil
+}
+
+func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn) ReceiveFunc {
+	return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
+		msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message)
+		defer s.ipv4MsgsPool.Put(msgs)
+		for i := range bufs {
+			(*msgs)[i].Buffers[0] = bufs[i]
+		}
+		var numMsgs int
+		if runtime.GOOS == "linux" && pc != nil {
+			numMsgs, err = pc.ReadBatch(*msgs, 0)
+			if err != nil {
+				return 0, err
+			}
+		} else {
+			msg := &(*msgs)[0]
+			msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
+			if err != nil {
+				return 0, err
+			}
+			numMsgs = 1
+		}
+		for i := 0; i < numMsgs; i++ {
+			msg := &(*msgs)[i]
+			sizes[i] = msg.N
+			addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
+			ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
+			getSrcFromControl(msg.OOB[:msg.NN], ep)
+			eps[i] = ep
+		}
+		return numMsgs, nil
+	}
+}
+
+func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn) ReceiveFunc {
+	return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
+		msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message)
+		defer s.ipv6MsgsPool.Put(msgs)
+		for i := range bufs {
+			(*msgs)[i].Buffers[0] = bufs[i]
+		}
+		var numMsgs int
+		if runtime.GOOS == "linux" && pc != nil {
+			numMsgs, err = pc.ReadBatch(*msgs, 0)
+			if err != nil {
+				return 0, err
+			}
+		} else {
+			msg := &(*msgs)[0]
+			msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
+			if err != nil {
+				return 0, err
+			}
+			numMsgs = 1
+		}
+		for i := 0; i < numMsgs; i++ {
+			msg := &(*msgs)[i]
+			sizes[i] = msg.N
+			addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
+			ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
+			getSrcFromControl(msg.OOB[:msg.NN], ep)
+			eps[i] = ep
+		}
+		return numMsgs, nil
+	}
+}
+
+// TODO: When all Binds handle IdealBatchSize, remove this dynamic function and
+// rename the IdealBatchSize constant to BatchSize.
+func (s *StdNetBind) BatchSize() int {
+	if runtime.GOOS == "linux" {
+		return IdealBatchSize
+	}
+	return 1
+}
+
+func (s *StdNetBind) Close() error {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	var err1, err2 error
+	if s.ipv4 != nil {
+		err1 = s.ipv4.Close()
+		s.ipv4 = nil
+		s.ipv4PC = nil
+	}
+	if s.ipv6 != nil {
+		err2 = s.ipv6.Close()
+		s.ipv6 = nil
+		s.ipv6PC = nil
+	}
+	s.blackhole4 = false
+	s.blackhole6 = false
+	if err1 != nil {
+		return err1
+	}
+	return err2
+}
+
+func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error {
+	s.mu.Lock()
+	blackhole := s.blackhole4
+	conn := s.ipv4
+	var (
+		pc4 *ipv4.PacketConn
+		pc6 *ipv6.PacketConn
+	)
+	is6 := false
+	if endpoint.DstIP().Is6() {
+		blackhole = s.blackhole6
+		conn = s.ipv6
+		pc6 = s.ipv6PC
+		is6 = true
+	} else {
+		pc4 = s.ipv4PC
+	}
+	s.mu.Unlock()
+
+	if blackhole {
+		return nil
+	}
+	if conn == nil {
+		return syscall.EAFNOSUPPORT
+	}
+	if is6 {
+		return s.send6(conn, pc6, endpoint, bufs)
+	} else {
+		return s.send4(conn, pc4, endpoint, bufs)
+	}
+}
+
+func (s *StdNetBind) send4(conn *net.UDPConn, pc *ipv4.PacketConn, ep Endpoint, bufs [][]byte) error {
+	ua := s.udpAddrPool.Get().(*net.UDPAddr)
+	as4 := ep.DstIP().As4()
+	copy(ua.IP, as4[:])
+	ua.IP = ua.IP[:4]
+	ua.Port = int(ep.(*StdNetEndpoint).Port())
+	msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message)
+	for i, buf := range bufs {
+		(*msgs)[i].Buffers[0] = buf
+		(*msgs)[i].Addr = ua
+		setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint))
+	}
+	var (
+		n     int
+		err   error
+		start int
+	)
+	if runtime.GOOS == "linux" && pc != nil {
+		for {
+			n, err = pc.WriteBatch((*msgs)[start:len(bufs)], 0)
+			if err != nil || n == len((*msgs)[start:len(bufs)]) {
+				break
+			}
+			start += n
+		}
+	} else {
+		for i, buf := range bufs {
+			_, _, err = conn.WriteMsgUDP(buf, (*msgs)[i].OOB, ua)
+			if err != nil {
+				break
+			}
+		}
+	}
+	s.udpAddrPool.Put(ua)
+	s.ipv4MsgsPool.Put(msgs)
+	return err
+}
+
+func (s *StdNetBind) send6(conn *net.UDPConn, pc *ipv6.PacketConn, ep Endpoint, bufs [][]byte) error {
+	ua := s.udpAddrPool.Get().(*net.UDPAddr)
+	as16 := ep.DstIP().As16()
+	copy(ua.IP, as16[:])
+	ua.IP = ua.IP[:16]
+	ua.Port = int(ep.(*StdNetEndpoint).Port())
+	msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message)
+	for i, buf := range bufs {
+		(*msgs)[i].Buffers[0] = buf
+		(*msgs)[i].Addr = ua
+		setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint))
+	}
+	var (
+		n     int
+		err   error
+		start int
+	)
+	if runtime.GOOS == "linux" && pc != nil {
+		for {
+			n, err = pc.WriteBatch((*msgs)[start:len(bufs)], 0)
+			if err != nil || n == len((*msgs)[start:len(bufs)]) {
+				break
+			}
+			start += n
+		}
+	} else {
+		for i, buf := range bufs {
+			_, _, err = conn.WriteMsgUDP(buf, (*msgs)[i].OOB, ua)
+			if err != nil {
+				break
+			}
+		}
+	}
+	s.udpAddrPool.Put(ua)
+	s.ipv6MsgsPool.Put(msgs)
+	return err
+}

+ 131 - 0
wgstack/conn/conn.go

@@ -0,0 +1,131 @@
+// SPDX-License-Identifier: MIT
+//
+// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+
+package conn
+
+import (
+	"errors"
+	"fmt"
+	"net/netip"
+	"reflect"
+	"runtime"
+	"strings"
+)
+
+const (
+	IdealBatchSize = 128 // maximum number of packets handled per read and write
+)
+
+// A ReceiveFunc receives at least one packet from the network and writes them
+// into packets. On a successful read it returns the number of elements of
+// sizes, packets, and endpoints that should be evaluated. Some elements of
+// sizes may be zero, and callers should ignore them. Callers must pass a sizes
+// and eps slice with a length greater than or equal to the length of packets.
+// These lengths must not exceed the length of the associated Bind.BatchSize().
+type ReceiveFunc func(packets [][]byte, sizes []int, eps []Endpoint) (n int, err error)
+
+// A Bind listens on a port for both IPv6 and IPv4 UDP traffic.
+//
+// A Bind interface may also be a PeekLookAtSocketFd or BindSocketToInterface,
+// depending on the platform-specific implementation.
+type Bind interface {
+	// Open puts the Bind into a listening state on a given port and reports the actual
+	// port that it bound to. Passing zero results in a random selection.
+	// fns is the set of functions that will be called to receive packets.
+	Open(port uint16) (fns []ReceiveFunc, actualPort uint16, err error)
+
+	// Close closes the Bind listener.
+	// All fns returned by Open must return net.ErrClosed after a call to Close.
+	Close() error
+
+	// SetMark sets the mark for each packet sent through this Bind.
+	// This mark is passed to the kernel as the socket option SO_MARK.
+	SetMark(mark uint32) error
+
+	// Send writes one or more packets in bufs to address ep. The length of
+	// bufs must not exceed BatchSize().
+	Send(bufs [][]byte, ep Endpoint) error
+
+	// ParseEndpoint creates a new endpoint from a string.
+	ParseEndpoint(s string) (Endpoint, error)
+
+	// BatchSize is the number of buffers expected to be passed to
+	// the ReceiveFuncs, and the maximum expected to be passed to SendBatch.
+	BatchSize() int
+}
+
+// BindSocketToInterface is implemented by Bind objects that support being
+// tied to a single network interface. Used by wireguard-windows.
+type BindSocketToInterface interface {
+	BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error
+	BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error
+}
+
+// PeekLookAtSocketFd is implemented by Bind objects that support having their
+// file descriptor peeked at. Used by wireguard-android.
+type PeekLookAtSocketFd interface {
+	PeekLookAtSocketFd4() (fd int, err error)
+	PeekLookAtSocketFd6() (fd int, err error)
+}
+
+// An Endpoint maintains the source/destination caching for a peer.
+//
+//	dst: the remote address of a peer ("endpoint" in uapi terminology)
+//	src: the local address from which datagrams originate going to the peer
+type Endpoint interface {
+	ClearSrc()           // clears the source address
+	SrcToString() string // returns the local source address (ip:port)
+	DstToString() string // returns the destination address (ip:port)
+	DstToBytes() []byte  // used for mac2 cookie calculations
+	DstIP() netip.Addr
+	SrcIP() netip.Addr
+}
+
+var (
+	ErrBindAlreadyOpen   = errors.New("bind is already open")
+	ErrWrongEndpointType = errors.New("endpoint type does not correspond with bind type")
+)
+
+func (fn ReceiveFunc) PrettyName() string {
+	name := runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name()
+	// 0. cheese/taco.beansIPv6.func12.func21218-fm
+	name = strings.TrimSuffix(name, "-fm")
+	// 1. cheese/taco.beansIPv6.func12.func21218
+	if idx := strings.LastIndexByte(name, '/'); idx != -1 {
+		name = name[idx+1:]
+		// 2. taco.beansIPv6.func12.func21218
+	}
+	for {
+		var idx int
+		for idx = len(name) - 1; idx >= 0; idx-- {
+			if name[idx] < '0' || name[idx] > '9' {
+				break
+			}
+		}
+		if idx == len(name)-1 {
+			break
+		}
+		const dotFunc = ".func"
+		if !strings.HasSuffix(name[:idx+1], dotFunc) {
+			break
+		}
+		name = name[:idx+1-len(dotFunc)]
+		// 3. taco.beansIPv6.func12
+		// 4. taco.beansIPv6
+	}
+	if idx := strings.LastIndexByte(name, '.'); idx != -1 {
+		name = name[idx+1:]
+		// 5. beansIPv6
+	}
+	if name == "" {
+		return fmt.Sprintf("%p", fn)
+	}
+	if strings.HasSuffix(name, "IPv4") {
+		return "v4"
+	}
+	if strings.HasSuffix(name, "IPv6") {
+		return "v6"
+	}
+	return name
+}

+ 42 - 0
wgstack/conn/controlfns.go

@@ -0,0 +1,42 @@
+// SPDX-License-Identifier: MIT
+//
+// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+
+package conn
+
+import (
+	"net"
+	"syscall"
+)
+
+// UDP socket read/write buffer size (7MB). The value of 7MB is chosen as it is
+// the max supported by a default configuration of macOS. Some platforms will
+// silently clamp the value to other maximums, such as linux clamping to
+// net.core.{r,w}mem_max (see _linux.go for additional implementation that works
+// around this limitation)
+const socketBufferSize = 7 << 20
+
+// controlFn is the callback function signature from net.ListenConfig.Control.
+// It is used to apply platform specific configuration to the socket prior to
+// bind.
+type controlFn func(network, address string, c syscall.RawConn) error
+
+// controlFns is a list of functions that are called from the listen config
+// that can apply socket options.
+var controlFns = []controlFn{}
+
+// listenConfig returns a net.ListenConfig that applies the controlFns to the
+// socket prior to bind. This is used to apply socket buffer sizing and packet
+// information OOB configuration for sticky sockets.
+func listenConfig() *net.ListenConfig {
+	return &net.ListenConfig{
+		Control: func(network, address string, c syscall.RawConn) error {
+			for _, fn := range controlFns {
+				if err := fn(network, address, c); err != nil {
+					return err
+				}
+			}
+			return nil
+		},
+	}
+}

+ 62 - 0
wgstack/conn/controlfns_linux.go

@@ -0,0 +1,62 @@
+//go:build linux
+
+// SPDX-License-Identifier: MIT
+//
+// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+
+package conn
+
+import (
+	"fmt"
+	"runtime"
+	"syscall"
+
+	"golang.org/x/sys/unix"
+)
+
+func init() {
+	controlFns = append(controlFns,
+
+		// Attempt to set the socket buffer size beyond net.core.{r,w}mem_max by
+		// using SO_*BUFFORCE. This requires CAP_NET_ADMIN, and is allowed here to
+		// fail silently - the result of failure is lower performance on very fast
+		// links or high latency links.
+		func(network, address string, c syscall.RawConn) error {
+			return c.Control(func(fd uintptr) {
+				// Set up to *mem_max
+				_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF, socketBufferSize)
+				_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUF, socketBufferSize)
+				// Set beyond *mem_max if CAP_NET_ADMIN
+				_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, socketBufferSize)
+				_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUFFORCE, socketBufferSize)
+			})
+		},
+
+		// Enable receiving of the packet information (IP_PKTINFO for IPv4,
+		// IPV6_PKTINFO for IPv6) that is used to implement sticky socket support.
+		func(network, address string, c syscall.RawConn) error {
+			var err error
+			switch network {
+			case "udp4":
+				if runtime.GOOS != "android" {
+					c.Control(func(fd uintptr) {
+						err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO, 1)
+					})
+				}
+			case "udp6":
+				c.Control(func(fd uintptr) {
+					if runtime.GOOS != "android" {
+						err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO, 1)
+						if err != nil {
+							return
+						}
+					}
+					err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, 1)
+				})
+			default:
+				err = fmt.Errorf("unhandled network: %s: %w", network, unix.EINVAL)
+			}
+			return err
+		},
+	)
+}

+ 9 - 0
wgstack/conn/default.go

@@ -0,0 +1,9 @@
+//go:build !windows
+
+// SPDX-License-Identifier: MIT
+//
+// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+
+package conn
+
+func NewDefaultBind() Bind { return NewStdNetBind() }

+ 64 - 0
wgstack/conn/mark_unix.go

@@ -0,0 +1,64 @@
+//go:build linux || openbsd || freebsd
+
+// SPDX-License-Identifier: MIT
+//
+// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+
+package conn
+
+import (
+	"runtime"
+
+	"golang.org/x/sys/unix"
+)
+
+var fwmarkIoctl int
+
+func init() {
+	switch runtime.GOOS {
+	case "linux", "android":
+		fwmarkIoctl = 36 /* unix.SO_MARK */
+	case "freebsd":
+		fwmarkIoctl = 0x1015 /* unix.SO_USER_COOKIE */
+	case "openbsd":
+		fwmarkIoctl = 0x1021 /* unix.SO_RTABLE */
+	}
+}
+
+func (s *StdNetBind) SetMark(mark uint32) error {
+	var operr error
+	if fwmarkIoctl == 0 {
+		return nil
+	}
+	if s.ipv4 != nil {
+		fd, err := s.ipv4.SyscallConn()
+		if err != nil {
+			return err
+		}
+		err = fd.Control(func(fd uintptr) {
+			operr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, fwmarkIoctl, int(mark))
+		})
+		if err == nil {
+			err = operr
+		}
+		if err != nil {
+			return err
+		}
+	}
+	if s.ipv6 != nil {
+		fd, err := s.ipv6.SyscallConn()
+		if err != nil {
+			return err
+		}
+		err = fd.Control(func(fd uintptr) {
+			operr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, fwmarkIoctl, int(mark))
+		})
+		if err == nil {
+			err = operr
+		}
+		if err != nil {
+			return err
+		}
+	}
+	return nil
+}

+ 116 - 0
wgstack/conn/sticky_linux.go

@@ -0,0 +1,116 @@
+//go:build linux && !android
+
+// SPDX-License-Identifier: MIT
+//
+// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+
+package conn
+
+import (
+	"net/netip"
+	"unsafe"
+
+	"golang.org/x/sys/unix"
+)
+
+// getSrcFromControl parses the control for PKTINFO and if found updates ep with
+// the source information found.
+func getSrcFromControl(control []byte, ep *StdNetEndpoint) {
+	ep.ClearSrc()
+
+	var (
+		hdr  unix.Cmsghdr
+		data []byte
+		rem  []byte = control
+		err  error
+	)
+
+	for len(rem) > unix.SizeofCmsghdr {
+		hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem)
+		if err != nil {
+			return
+		}
+
+		if hdr.Level == unix.IPPROTO_IP &&
+			hdr.Type == unix.IP_PKTINFO {
+
+			info := pktInfoFromBuf[unix.Inet4Pktinfo](data)
+			ep.src.Addr = netip.AddrFrom4(info.Spec_dst)
+			ep.src.ifidx = info.Ifindex
+
+			return
+		}
+
+		if hdr.Level == unix.IPPROTO_IPV6 &&
+			hdr.Type == unix.IPV6_PKTINFO {
+
+			info := pktInfoFromBuf[unix.Inet6Pktinfo](data)
+			ep.src.Addr = netip.AddrFrom16(info.Addr)
+			ep.src.ifidx = int32(info.Ifindex)
+
+			return
+		}
+	}
+}
+
+// pktInfoFromBuf returns type T populated from the provided buf via copy(). It
+// panics if buf is of insufficient size.
+func pktInfoFromBuf[T unix.Inet4Pktinfo | unix.Inet6Pktinfo](buf []byte) (t T) {
+	size := int(unsafe.Sizeof(t))
+	if len(buf) < size {
+		panic("pktInfoFromBuf: buffer too small")
+	}
+	copy(unsafe.Slice((*byte)(unsafe.Pointer(&t)), size), buf)
+	return t
+}
+
+// setSrcControl sets an IP{V6}_PKTINFO in control based on the source address
+// and source ifindex found in ep. control's len will be set to 0 in the event
+// that ep is a default value.
+func setSrcControl(control *[]byte, ep *StdNetEndpoint) {
+	*control = (*control)[:cap(*control)]
+	if len(*control) < int(unsafe.Sizeof(unix.Cmsghdr{})) {
+		*control = (*control)[:0]
+		return
+	}
+
+	if ep.src.ifidx == 0 && !ep.SrcIP().IsValid() {
+		*control = (*control)[:0]
+		return
+	}
+
+	if len(*control) < srcControlSize {
+		*control = (*control)[:0]
+		return
+	}
+
+	hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(*control)[0]))
+	if ep.SrcIP().Is4() {
+		hdr.Level = unix.IPPROTO_IP
+		hdr.Type = unix.IP_PKTINFO
+		hdr.SetLen(unix.CmsgLen(unix.SizeofInet4Pktinfo))
+
+		info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&(*control)[unix.SizeofCmsghdr]))
+		info.Ifindex = ep.src.ifidx
+		if ep.SrcIP().IsValid() {
+			info.Spec_dst = ep.SrcIP().As4()
+		}
+		*control = (*control)[:unix.CmsgSpace(unix.SizeofInet4Pktinfo)]
+	} else {
+		hdr.Level = unix.IPPROTO_IPV6
+		hdr.Type = unix.IPV6_PKTINFO
+		hdr.SetLen(unix.CmsgLen(unix.SizeofInet6Pktinfo))
+
+		info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&(*control)[unix.SizeofCmsghdr]))
+		info.Ifindex = uint32(ep.src.ifidx)
+		if ep.SrcIP().IsValid() {
+			info.Addr = ep.SrcIP().As16()
+		}
+		*control = (*control)[:unix.CmsgSpace(unix.SizeofInet6Pktinfo)]
+	}
+
+}
+
+var srcControlSize = unix.CmsgSpace(unix.SizeofInet6Pktinfo)
+
+const StdNetSupportsStickySockets = true

+ 42 - 0
wgstack/tun/checksum.go

@@ -0,0 +1,42 @@
+package tun
+
+import "encoding/binary"
+
+// TODO: Explore SIMD and/or other assembly optimizations.
+func checksumNoFold(b []byte, initial uint64) uint64 {
+	ac := initial
+	i := 0
+	n := len(b)
+	for n >= 4 {
+		ac += uint64(binary.BigEndian.Uint32(b[i : i+4]))
+		n -= 4
+		i += 4
+	}
+	for n >= 2 {
+		ac += uint64(binary.BigEndian.Uint16(b[i : i+2]))
+		n -= 2
+		i += 2
+	}
+	if n == 1 {
+		ac += uint64(b[i]) << 8
+	}
+	return ac
+}
+
+func checksum(b []byte, initial uint64) uint16 {
+	ac := checksumNoFold(b, initial)
+	ac = (ac >> 16) + (ac & 0xffff)
+	ac = (ac >> 16) + (ac & 0xffff)
+	ac = (ac >> 16) + (ac & 0xffff)
+	ac = (ac >> 16) + (ac & 0xffff)
+	return uint16(ac)
+}
+
+func pseudoHeaderChecksumNoFold(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint64 {
+	sum := checksumNoFold(srcAddr, 0)
+	sum = checksumNoFold(dstAddr, sum)
+	sum = checksumNoFold([]byte{0, protocol}, sum)
+	tmp := make([]byte, 2)
+	binary.BigEndian.PutUint16(tmp, totalLen)
+	return checksumNoFold(tmp, sum)
+}

+ 3 - 0
wgstack/tun/export.go

@@ -0,0 +1,3 @@
+package tun
+
+const VirtioNetHdrLen = virtioNetHdrLen

+ 630 - 0
wgstack/tun/tcp_offload_linux.go

@@ -0,0 +1,630 @@
+//go:build linux
+
+// SPDX-License-Identifier: MIT
+//
+// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+
+package tun
+
+import (
+	"bytes"
+	"encoding/binary"
+	"errors"
+	"io"
+	"unsafe"
+
+	wgconn "github.com/slackhq/nebula/wgstack/conn"
+	"golang.org/x/sys/unix"
+)
+
+var ErrTooManySegments = errors.New("tun: too many segments for TSO")
+
+const tcpFlagsOffset = 13
+
+const (
+	tcpFlagFIN uint8 = 0x01
+	tcpFlagPSH uint8 = 0x08
+	tcpFlagACK uint8 = 0x10
+)
+
+// virtioNetHdr is defined in the kernel in include/uapi/linux/virtio_net.h. The
+// kernel symbol is virtio_net_hdr.
+type virtioNetHdr struct {
+	flags      uint8
+	gsoType    uint8
+	hdrLen     uint16
+	gsoSize    uint16
+	csumStart  uint16
+	csumOffset uint16
+}
+
+func (v *virtioNetHdr) decode(b []byte) error {
+	if len(b) < virtioNetHdrLen {
+		return io.ErrShortBuffer
+	}
+	copy(unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen), b[:virtioNetHdrLen])
+	return nil
+}
+
+func (v *virtioNetHdr) encode(b []byte) error {
+	if len(b) < virtioNetHdrLen {
+		return io.ErrShortBuffer
+	}
+	copy(b[:virtioNetHdrLen], unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen))
+	return nil
+}
+
+const (
+	// virtioNetHdrLen is the length in bytes of virtioNetHdr. This matches the
+	// shape of the C ABI for its kernel counterpart -- sizeof(virtio_net_hdr).
+	virtioNetHdrLen = int(unsafe.Sizeof(virtioNetHdr{}))
+)
+
+// flowKey represents the key for a flow.
+type flowKey struct {
+	srcAddr, dstAddr [16]byte
+	srcPort, dstPort uint16
+	rxAck            uint32 // varying ack values should not be coalesced. Treat them as separate flows.
+}
+
+// tcpGROTable holds flow and coalescing information for the purposes of GRO.
+type tcpGROTable struct {
+	itemsByFlow map[flowKey][]tcpGROItem
+	itemsPool   [][]tcpGROItem
+}
+
+func newTCPGROTable() *tcpGROTable {
+	t := &tcpGROTable{
+		itemsByFlow: make(map[flowKey][]tcpGROItem, wgconn.IdealBatchSize),
+		itemsPool:   make([][]tcpGROItem, wgconn.IdealBatchSize),
+	}
+	for i := range t.itemsPool {
+		t.itemsPool[i] = make([]tcpGROItem, 0, wgconn.IdealBatchSize)
+	}
+	return t
+}
+
+func newFlowKey(pkt []byte, srcAddr, dstAddr, tcphOffset int) flowKey {
+	key := flowKey{}
+	addrSize := dstAddr - srcAddr
+	copy(key.srcAddr[:], pkt[srcAddr:dstAddr])
+	copy(key.dstAddr[:], pkt[dstAddr:dstAddr+addrSize])
+	key.srcPort = binary.BigEndian.Uint16(pkt[tcphOffset:])
+	key.dstPort = binary.BigEndian.Uint16(pkt[tcphOffset+2:])
+	key.rxAck = binary.BigEndian.Uint32(pkt[tcphOffset+8:])
+	return key
+}
+
+// lookupOrInsert looks up a flow for the provided packet and metadata,
+// returning the packets found for the flow, or inserting a new one if none
+// is found.
+func (t *tcpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) ([]tcpGROItem, bool) {
+	key := newFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset)
+	items, ok := t.itemsByFlow[key]
+	if ok {
+		return items, ok
+	}
+	// TODO: insert() performs another map lookup. This could be rearranged to avoid.
+	t.insert(pkt, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex)
+	return nil, false
+}
+
+// insert an item in the table for the provided packet and packet metadata.
+func (t *tcpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) {
+	key := newFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset)
+	item := tcpGROItem{
+		key:       key,
+		bufsIndex: uint16(bufsIndex),
+		gsoSize:   uint16(len(pkt[tcphOffset+tcphLen:])),
+		iphLen:    uint8(tcphOffset),
+		tcphLen:   uint8(tcphLen),
+		sentSeq:   binary.BigEndian.Uint32(pkt[tcphOffset+4:]),
+		pshSet:    pkt[tcphOffset+tcpFlagsOffset]&tcpFlagPSH != 0,
+	}
+	items, ok := t.itemsByFlow[key]
+	if !ok {
+		items = t.newItems()
+	}
+	items = append(items, item)
+	t.itemsByFlow[key] = items
+}
+
+func (t *tcpGROTable) updateAt(item tcpGROItem, i int) {
+	items, _ := t.itemsByFlow[item.key]
+	items[i] = item
+}
+
+func (t *tcpGROTable) deleteAt(key flowKey, i int) {
+	items, _ := t.itemsByFlow[key]
+	items = append(items[:i], items[i+1:]...)
+	t.itemsByFlow[key] = items
+}
+
+// tcpGROItem represents bookkeeping data for a TCP packet during the lifetime
+// of a GRO evaluation across a vector of packets.
+type tcpGROItem struct {
+	key       flowKey
+	sentSeq   uint32 // the sequence number
+	bufsIndex uint16 // the index into the original bufs slice
+	numMerged uint16 // the number of packets merged into this item
+	gsoSize   uint16 // payload size
+	iphLen    uint8  // ip header len
+	tcphLen   uint8  // tcp header len
+	pshSet    bool   // psh flag is set
+}
+
+func (t *tcpGROTable) newItems() []tcpGROItem {
+	var items []tcpGROItem
+	items, t.itemsPool = t.itemsPool[len(t.itemsPool)-1], t.itemsPool[:len(t.itemsPool)-1]
+	return items
+}
+
+func (t *tcpGROTable) reset() {
+	for k, items := range t.itemsByFlow {
+		items = items[:0]
+		t.itemsPool = append(t.itemsPool, items)
+		delete(t.itemsByFlow, k)
+	}
+}
+
+// canCoalesce represents the outcome of checking if two TCP packets are
+// candidates for coalescing.
+type canCoalesce int
+
+const (
+	coalescePrepend     canCoalesce = -1
+	coalesceUnavailable canCoalesce = 0
+	coalesceAppend      canCoalesce = 1
+)
+
+// tcpPacketsCanCoalesce evaluates if pkt can be coalesced with the packet
+// described by item. This function makes considerations that match the kernel's
+// GRO self tests, which can be found in tools/testing/selftests/net/gro.c.
+func tcpPacketsCanCoalesce(pkt []byte, iphLen, tcphLen uint8, seq uint32, pshSet bool, gsoSize uint16, item tcpGROItem, bufs [][]byte, bufsOffset int) canCoalesce {
+	pktTarget := bufs[item.bufsIndex][bufsOffset:]
+	if tcphLen != item.tcphLen {
+		// cannot coalesce with unequal tcp options len
+		return coalesceUnavailable
+	}
+	if tcphLen > 20 {
+		if !bytes.Equal(pkt[iphLen+20:iphLen+tcphLen], pktTarget[item.iphLen+20:iphLen+tcphLen]) {
+			// cannot coalesce with unequal tcp options
+			return coalesceUnavailable
+		}
+	}
+	if pkt[0]>>4 == 6 {
+		if pkt[0] != pktTarget[0] || pkt[1]>>4 != pktTarget[1]>>4 {
+			// cannot coalesce with unequal Traffic class values
+			return coalesceUnavailable
+		}
+		if pkt[7] != pktTarget[7] {
+			// cannot coalesce with unequal Hop limit values
+			return coalesceUnavailable
+		}
+	} else {
+		if pkt[1] != pktTarget[1] {
+			// cannot coalesce with unequal ToS values
+			return coalesceUnavailable
+		}
+		if pkt[6]>>5 != pktTarget[6]>>5 {
+			// cannot coalesce with unequal DF or reserved bits. MF is checked
+			// further up the stack.
+			return coalesceUnavailable
+		}
+		if pkt[8] != pktTarget[8] {
+			// cannot coalesce with unequal TTL values
+			return coalesceUnavailable
+		}
+	}
+	// seq adjacency
+	lhsLen := item.gsoSize
+	lhsLen += item.numMerged * item.gsoSize
+	if seq == item.sentSeq+uint32(lhsLen) { // pkt aligns following item from a seq num perspective
+		if item.pshSet {
+			// We cannot append to a segment that has the PSH flag set, PSH
+			// can only be set on the final segment in a reassembled group.
+			return coalesceUnavailable
+		}
+		if len(pktTarget[iphLen+tcphLen:])%int(item.gsoSize) != 0 {
+			// A smaller than gsoSize packet has been appended previously.
+			// Nothing can come after a smaller packet on the end.
+			return coalesceUnavailable
+		}
+		if gsoSize > item.gsoSize {
+			// We cannot have a larger packet following a smaller one.
+			return coalesceUnavailable
+		}
+		return coalesceAppend
+	} else if seq+uint32(gsoSize) == item.sentSeq { // pkt aligns in front of item from a seq num perspective
+		if pshSet {
+			// We cannot prepend with a segment that has the PSH flag set, PSH
+			// can only be set on the final segment in a reassembled group.
+			return coalesceUnavailable
+		}
+		if gsoSize < item.gsoSize {
+			// We cannot have a larger packet following a smaller one.
+			return coalesceUnavailable
+		}
+		if gsoSize > item.gsoSize && item.numMerged > 0 {
+			// There's at least one previous merge, and we're larger than all
+			// previous. This would put multiple smaller packets on the end.
+			return coalesceUnavailable
+		}
+		return coalescePrepend
+	}
+	return coalesceUnavailable
+}
+
+func tcpChecksumValid(pkt []byte, iphLen uint8, isV6 bool) bool {
+	srcAddrAt := ipv4SrcAddrOffset
+	addrSize := 4
+	if isV6 {
+		srcAddrAt = ipv6SrcAddrOffset
+		addrSize = 16
+	}
+	tcpTotalLen := uint16(len(pkt) - int(iphLen))
+	tcpCSumNoFold := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, pkt[srcAddrAt:srcAddrAt+addrSize], pkt[srcAddrAt+addrSize:srcAddrAt+addrSize*2], tcpTotalLen)
+	return ^checksum(pkt[iphLen:], tcpCSumNoFold) == 0
+}
+
+// coalesceResult represents the result of attempting to coalesce two TCP
+// packets.
+type coalesceResult int
+
+const (
+	coalesceInsufficientCap coalesceResult = 0
+	coalescePSHEnding       coalesceResult = 1
+	coalesceItemInvalidCSum coalesceResult = 2
+	coalescePktInvalidCSum  coalesceResult = 3
+	coalesceSuccess         coalesceResult = 4
+)
+
+// coalesceTCPPackets attempts to coalesce pkt with the packet described by
+// item, returning the outcome. This function may swap bufs elements in the
+// event of a prepend as item's bufs index is already being tracked for writing
+// to a Device.
+func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize uint16, seq uint32, pshSet bool, item *tcpGROItem, bufs [][]byte, bufsOffset int, isV6 bool) coalesceResult {
+	var pktHead []byte // the packet that will end up at the front
+	headersLen := item.iphLen + item.tcphLen
+	coalescedLen := len(bufs[item.bufsIndex][bufsOffset:]) + len(pkt) - int(headersLen)
+
+	// Copy data
+	if mode == coalescePrepend {
+		pktHead = pkt
+		if cap(pkt)-bufsOffset < coalescedLen {
+			// We don't want to allocate a new underlying array if capacity is
+			// too small.
+			return coalesceInsufficientCap
+		}
+		if pshSet {
+			return coalescePSHEnding
+		}
+		if item.numMerged == 0 {
+			if !tcpChecksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, isV6) {
+				return coalesceItemInvalidCSum
+			}
+		}
+		if !tcpChecksumValid(pkt, item.iphLen, isV6) {
+			return coalescePktInvalidCSum
+		}
+		item.sentSeq = seq
+		extendBy := coalescedLen - len(pktHead)
+		bufs[pktBuffsIndex] = append(bufs[pktBuffsIndex], make([]byte, extendBy)...)
+		copy(bufs[pktBuffsIndex][bufsOffset+len(pkt):], bufs[item.bufsIndex][bufsOffset+int(headersLen):])
+		// Flip the slice headers in bufs as part of prepend. The index of item
+		// is already being tracked for writing.
+		bufs[item.bufsIndex], bufs[pktBuffsIndex] = bufs[pktBuffsIndex], bufs[item.bufsIndex]
+	} else {
+		pktHead = bufs[item.bufsIndex][bufsOffset:]
+		if cap(pktHead)-bufsOffset < coalescedLen {
+			// We don't want to allocate a new underlying array if capacity is
+			// too small.
+			return coalesceInsufficientCap
+		}
+		if item.numMerged == 0 {
+			if !tcpChecksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, isV6) {
+				return coalesceItemInvalidCSum
+			}
+		}
+		if !tcpChecksumValid(pkt, item.iphLen, isV6) {
+			return coalescePktInvalidCSum
+		}
+		if pshSet {
+			// We are appending a segment with PSH set.
+			item.pshSet = pshSet
+			pktHead[item.iphLen+tcpFlagsOffset] |= tcpFlagPSH
+		}
+		extendBy := len(pkt) - int(headersLen)
+		bufs[item.bufsIndex] = append(bufs[item.bufsIndex], make([]byte, extendBy)...)
+		copy(bufs[item.bufsIndex][bufsOffset+len(pktHead):], pkt[headersLen:])
+	}
+
+	if gsoSize > item.gsoSize {
+		item.gsoSize = gsoSize
+	}
+	hdr := virtioNetHdr{
+		flags:      unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb
+		hdrLen:     uint16(headersLen),
+		gsoSize:    uint16(item.gsoSize),
+		csumStart:  uint16(item.iphLen),
+		csumOffset: 16,
+	}
+
+	// Recalculate the total len (IPv4) or payload len (IPv6). Recalculate the
+	// (IPv4) header checksum.
+	if isV6 {
+		hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV6
+		binary.BigEndian.PutUint16(pktHead[4:], uint16(coalescedLen)-uint16(item.iphLen)) // set new payload len
+	} else {
+		hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV4
+		pktHead[10], pktHead[11] = 0, 0                               // clear checksum field
+		binary.BigEndian.PutUint16(pktHead[2:], uint16(coalescedLen)) // set new total length
+		iphCSum := ^checksum(pktHead[:item.iphLen], 0)                // compute checksum
+		binary.BigEndian.PutUint16(pktHead[10:], iphCSum)             // set checksum field
+	}
+	hdr.encode(bufs[item.bufsIndex][bufsOffset-virtioNetHdrLen:])
+
+	// Calculate the pseudo header checksum and place it at the TCP checksum
+	// offset. Downstream checksum offloading will combine this with computation
+	// of the tcp header and payload checksum.
+	addrLen := 4
+	addrOffset := ipv4SrcAddrOffset
+	if isV6 {
+		addrLen = 16
+		addrOffset = ipv6SrcAddrOffset
+	}
+	srcAddrAt := bufsOffset + addrOffset
+	srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen]
+	dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2]
+	psum := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(coalescedLen-int(item.iphLen)))
+	binary.BigEndian.PutUint16(pktHead[hdr.csumStart+hdr.csumOffset:], checksum([]byte{}, psum))
+
+	item.numMerged++
+	return coalesceSuccess
+}
+
+const (
+	ipv4FlagMoreFragments uint8 = 0x20
+)
+
+const (
+	ipv4SrcAddrOffset = 12
+	ipv6SrcAddrOffset = 8
+	maxUint16         = 1<<16 - 1
+)
+
+// tcpGRO evaluates the TCP packet at pktI in bufs for coalescing with
+// existing packets tracked in table. It will return false when pktI is not
+// coalesced, otherwise true. This indicates to the caller if bufs[pktI]
+// should be written to the Device.
+func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) (pktCoalesced bool) {
+	pkt := bufs[pktI][offset:]
+	if len(pkt) > maxUint16 {
+		// A valid IPv4 or IPv6 packet will never exceed this.
+		return false
+	}
+	iphLen := int((pkt[0] & 0x0F) * 4)
+	if isV6 {
+		iphLen = 40
+		ipv6HPayloadLen := int(binary.BigEndian.Uint16(pkt[4:]))
+		if ipv6HPayloadLen != len(pkt)-iphLen {
+			return false
+		}
+	} else {
+		totalLen := int(binary.BigEndian.Uint16(pkt[2:]))
+		if totalLen != len(pkt) {
+			return false
+		}
+	}
+	if len(pkt) < iphLen {
+		return false
+	}
+	tcphLen := int((pkt[iphLen+12] >> 4) * 4)
+	if tcphLen < 20 || tcphLen > 60 {
+		return false
+	}
+	if len(pkt) < iphLen+tcphLen {
+		return false
+	}
+	if !isV6 {
+		if pkt[6]&ipv4FlagMoreFragments != 0 || pkt[6]<<3 != 0 || pkt[7] != 0 {
+			// no GRO support for fragmented segments for now
+			return false
+		}
+	}
+	tcpFlags := pkt[iphLen+tcpFlagsOffset]
+	var pshSet bool
+	// not a candidate if any non-ACK flags (except PSH+ACK) are set
+	if tcpFlags != tcpFlagACK {
+		if pkt[iphLen+tcpFlagsOffset] != tcpFlagACK|tcpFlagPSH {
+			return false
+		}
+		pshSet = true
+	}
+	gsoSize := uint16(len(pkt) - tcphLen - iphLen)
+	// not a candidate if payload len is 0
+	if gsoSize < 1 {
+		return false
+	}
+	seq := binary.BigEndian.Uint32(pkt[iphLen+4:])
+	srcAddrOffset := ipv4SrcAddrOffset
+	addrLen := 4
+	if isV6 {
+		srcAddrOffset = ipv6SrcAddrOffset
+		addrLen = 16
+	}
+	items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI)
+	if !existing {
+		return false
+	}
+	for i := len(items) - 1; i >= 0; i-- {
+		// In the best case of packets arriving in order iterating in reverse is
+		// more efficient if there are multiple items for a given flow. This
+		// also enables a natural table.deleteAt() in the
+		// coalesceItemInvalidCSum case without the need for index tracking.
+		// This algorithm makes a best effort to coalesce in the event of
+		// unordered packets, where pkt may land anywhere in items from a
+		// sequence number perspective, however once an item is inserted into
+		// the table it is never compared across other items later.
+		item := items[i]
+		can := tcpPacketsCanCoalesce(pkt, uint8(iphLen), uint8(tcphLen), seq, pshSet, gsoSize, item, bufs, offset)
+		if can != coalesceUnavailable {
+			result := coalesceTCPPackets(can, pkt, pktI, gsoSize, seq, pshSet, &item, bufs, offset, isV6)
+			switch result {
+			case coalesceSuccess:
+				table.updateAt(item, i)
+				return true
+			case coalesceItemInvalidCSum:
+				// delete the item with an invalid csum
+				table.deleteAt(item.key, i)
+			case coalescePktInvalidCSum:
+				// no point in inserting an item that we can't coalesce
+				return false
+			default:
+			}
+		}
+	}
+	// failed to coalesce with any other packets; store the item in the flow
+	table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI)
+	return false
+}
+
+func isTCP4NoIPOptions(b []byte) bool {
+	if len(b) < 40 {
+		return false
+	}
+	if b[0]>>4 != 4 {
+		return false
+	}
+	if b[0]&0x0F != 5 {
+		return false
+	}
+	if b[9] != unix.IPPROTO_TCP {
+		return false
+	}
+	return true
+}
+
+func isTCP6NoEH(b []byte) bool {
+	if len(b) < 60 {
+		return false
+	}
+	if b[0]>>4 != 6 {
+		return false
+	}
+	if b[6] != unix.IPPROTO_TCP {
+		return false
+	}
+	return true
+}
+
+// handleGRO evaluates bufs for GRO, and writes the indices of the resulting
+// packets into toWrite. toWrite, tcp4Table, and tcp6Table should initially be
+// empty (but non-nil), and are passed in to save allocs as the caller may reset
+// and recycle them across vectors of packets.
+func handleGRO(bufs [][]byte, offset int, tcp4Table, tcp6Table *tcpGROTable, toWrite *[]int) error {
+	for i := range bufs {
+		if offset < virtioNetHdrLen || offset > len(bufs[i])-1 {
+			return errors.New("invalid offset")
+		}
+		var coalesced bool
+		switch {
+		case isTCP4NoIPOptions(bufs[i][offset:]): // ipv4 packets w/IP options do not coalesce
+			coalesced = tcpGRO(bufs, offset, i, tcp4Table, false)
+		case isTCP6NoEH(bufs[i][offset:]): // ipv6 packets w/extension headers do not coalesce
+			coalesced = tcpGRO(bufs, offset, i, tcp6Table, true)
+		}
+		if !coalesced {
+			hdr := virtioNetHdr{}
+			err := hdr.encode(bufs[i][offset-virtioNetHdrLen:])
+			if err != nil {
+				return err
+			}
+			*toWrite = append(*toWrite, i)
+		}
+	}
+	return nil
+}
+
+// tcpTSO splits packets from in into outBuffs, writing the size of each
+// element into sizes. It returns the number of buffers populated, and/or an
+// error.
+func tcpTSO(in []byte, hdr virtioNetHdr, outBuffs [][]byte, sizes []int, outOffset int) (int, error) {
+	iphLen := int(hdr.csumStart)
+	srcAddrOffset := ipv6SrcAddrOffset
+	addrLen := 16
+	if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 {
+		in[10], in[11] = 0, 0 // clear ipv4 header checksum
+		srcAddrOffset = ipv4SrcAddrOffset
+		addrLen = 4
+	}
+	tcpCSumAt := int(hdr.csumStart + hdr.csumOffset)
+	in[tcpCSumAt], in[tcpCSumAt+1] = 0, 0 // clear tcp checksum
+	firstTCPSeqNum := binary.BigEndian.Uint32(in[hdr.csumStart+4:])
+	nextSegmentDataAt := int(hdr.hdrLen)
+	i := 0
+	for ; nextSegmentDataAt < len(in); i++ {
+		if i == len(outBuffs) {
+			return i - 1, ErrTooManySegments
+		}
+		nextSegmentEnd := nextSegmentDataAt + int(hdr.gsoSize)
+		if nextSegmentEnd > len(in) {
+			nextSegmentEnd = len(in)
+		}
+		segmentDataLen := nextSegmentEnd - nextSegmentDataAt
+		totalLen := int(hdr.hdrLen) + segmentDataLen
+		sizes[i] = totalLen
+		out := outBuffs[i][outOffset:]
+
+		copy(out, in[:iphLen])
+		if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 {
+			// For IPv4 we are responsible for incrementing the ID field,
+			// updating the total len field, and recalculating the header
+			// checksum.
+			if i > 0 {
+				id := binary.BigEndian.Uint16(out[4:])
+				id += uint16(i)
+				binary.BigEndian.PutUint16(out[4:], id)
+			}
+			binary.BigEndian.PutUint16(out[2:], uint16(totalLen))
+			ipv4CSum := ^checksum(out[:iphLen], 0)
+			binary.BigEndian.PutUint16(out[10:], ipv4CSum)
+		} else {
+			// For IPv6 we are responsible for updating the payload length field.
+			binary.BigEndian.PutUint16(out[4:], uint16(totalLen-iphLen))
+		}
+
+		// TCP header
+		copy(out[hdr.csumStart:hdr.hdrLen], in[hdr.csumStart:hdr.hdrLen])
+		tcpSeq := firstTCPSeqNum + uint32(hdr.gsoSize*uint16(i))
+		binary.BigEndian.PutUint32(out[hdr.csumStart+4:], tcpSeq)
+		if nextSegmentEnd != len(in) {
+			// FIN and PSH should only be set on last segment
+			clearFlags := tcpFlagFIN | tcpFlagPSH
+			out[hdr.csumStart+tcpFlagsOffset] &^= clearFlags
+		}
+
+		// payload
+		copy(out[hdr.hdrLen:], in[nextSegmentDataAt:nextSegmentEnd])
+
+		// TCP checksum
+		tcpHLen := int(hdr.hdrLen - hdr.csumStart)
+		tcpLenForPseudo := uint16(tcpHLen + segmentDataLen)
+		tcpCSumNoFold := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, in[srcAddrOffset:srcAddrOffset+addrLen], in[srcAddrOffset+addrLen:srcAddrOffset+addrLen*2], tcpLenForPseudo)
+		tcpCSum := ^checksum(out[hdr.csumStart:totalLen], tcpCSumNoFold)
+		binary.BigEndian.PutUint16(out[hdr.csumStart+hdr.csumOffset:], tcpCSum)
+
+		nextSegmentDataAt += int(hdr.gsoSize)
+	}
+	return i, nil
+}
+
+func gsoNoneChecksum(in []byte, cSumStart, cSumOffset uint16) error {
+	cSumAt := cSumStart + cSumOffset
+	// The initial value at the checksum offset should be summed with the
+	// checksum we compute. This is typically the pseudo-header checksum.
+	initial := binary.BigEndian.Uint16(in[cSumAt:])
+	in[cSumAt], in[cSumAt+1] = 0, 0
+	binary.BigEndian.PutUint16(in[cSumAt:], ^checksum(in[cSumStart:], uint64(initial)))
+	return nil
+}

+ 52 - 0
wgstack/tun/tun.go

@@ -0,0 +1,52 @@
+// SPDX-License-Identifier: MIT
+//
+// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+
+package tun
+
+import (
+	"os"
+)
+
+type Event int
+
+const (
+	EventUp = 1 << iota
+	EventDown
+	EventMTUUpdate
+)
+
+type Device interface {
+	// File returns the file descriptor of the device.
+	File() *os.File
+
+	// Read one or more packets from the Device (without any additional headers).
+	// On a successful read it returns the number of packets read, and sets
+	// packet lengths within the sizes slice. len(sizes) must be >= len(bufs).
+	// A nonzero offset can be used to instruct the Device on where to begin
+	// reading into each element of the bufs slice.
+	Read(bufs [][]byte, sizes []int, offset int) (n int, err error)
+
+	// Write one or more packets to the device (without any additional headers).
+	// On a successful write it returns the number of packets written. A nonzero
+	// offset can be used to instruct the Device on where to begin writing from
+	// each packet contained within the bufs slice.
+	Write(bufs [][]byte, offset int) (int, error)
+
+	// MTU returns the MTU of the Device.
+	MTU() (int, error)
+
+	// Name returns the current name of the Device.
+	Name() (string, error)
+
+	// Events returns a channel of type Event, which is fed Device events.
+	Events() <-chan Event
+
+	// Close stops the Device and closes the Event channel.
+	Close() error
+
+	// BatchSize returns the preferred/max number of packets that can be read or
+	// written in a single read/write call. BatchSize must not change over the
+	// lifetime of a Device.
+	BatchSize() int
+}

+ 652 - 0
wgstack/tun/tun_linux.go

@@ -0,0 +1,652 @@
+//go:build linux
+
+// SPDX-License-Identifier: MIT
+//
+// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+
+package tun
+
+/* Implementation of the TUN device interface for linux
+ */
+
+import (
+	"errors"
+	"fmt"
+	"os"
+	"sync"
+	"syscall"
+	"time"
+	"unsafe"
+
+	wgconn "github.com/slackhq/nebula/wgstack/conn"
+	"golang.org/x/sys/unix"
+	"golang.zx2c4.com/wireguard/rwcancel"
+)
+
+const (
+	cloneDevicePath = "/dev/net/tun"
+	ifReqSize       = unix.IFNAMSIZ + 64
+)
+
+type NativeTun struct {
+	tunFile                 *os.File
+	index                   int32      // if index
+	errors                  chan error // async error handling
+	events                  chan Event // device related events
+	netlinkSock             int
+	netlinkCancel           *rwcancel.RWCancel
+	hackListenerClosed      sync.Mutex
+	statusListenersShutdown chan struct{}
+	batchSize               int
+	vnetHdr                 bool
+
+	closeOnce sync.Once
+
+	nameOnce  sync.Once // guards calling initNameCache, which sets following fields
+	nameCache string    // name of interface
+	nameErr   error
+
+	readOpMu sync.Mutex                    // readOpMu guards readBuff
+	readBuff [virtioNetHdrLen + 65535]byte // if vnetHdr every read() is prefixed by virtioNetHdr
+
+	writeOpMu                  sync.Mutex // writeOpMu guards toWrite, tcp4GROTable, tcp6GROTable
+	toWrite                    []int
+	tcp4GROTable, tcp6GROTable *tcpGROTable
+}
+
+func (tun *NativeTun) File() *os.File {
+	return tun.tunFile
+}
+
+func (tun *NativeTun) routineHackListener() {
+	defer tun.hackListenerClosed.Unlock()
+	/* This is needed for the detection to work across network namespaces
+	 * If you are reading this and know a better method, please get in touch.
+	 */
+	last := 0
+	const (
+		up   = 1
+		down = 2
+	)
+	for {
+		sysconn, err := tun.tunFile.SyscallConn()
+		if err != nil {
+			return
+		}
+		err2 := sysconn.Control(func(fd uintptr) {
+			_, err = unix.Write(int(fd), nil)
+		})
+		if err2 != nil {
+			return
+		}
+		switch err {
+		case unix.EINVAL:
+			if last != up {
+				// If the tunnel is up, it reports that write() is
+				// allowed but we provided invalid data.
+				tun.events <- EventUp
+				last = up
+			}
+		case unix.EIO:
+			if last != down {
+				// If the tunnel is down, it reports that no I/O
+				// is possible, without checking our provided data.
+				tun.events <- EventDown
+				last = down
+			}
+		default:
+			return
+		}
+		select {
+		case <-time.After(time.Second):
+			// nothing
+		case <-tun.statusListenersShutdown:
+			return
+		}
+	}
+}
+
+func createNetlinkSocket() (int, error) {
+	sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW|unix.SOCK_CLOEXEC, unix.NETLINK_ROUTE)
+	if err != nil {
+		return -1, err
+	}
+	saddr := &unix.SockaddrNetlink{
+		Family: unix.AF_NETLINK,
+		Groups: unix.RTMGRP_LINK | unix.RTMGRP_IPV4_IFADDR | unix.RTMGRP_IPV6_IFADDR,
+	}
+	err = unix.Bind(sock, saddr)
+	if err != nil {
+		return -1, err
+	}
+	return sock, nil
+}
+
+func (tun *NativeTun) routineNetlinkListener() {
+	defer func() {
+		unix.Close(tun.netlinkSock)
+		tun.hackListenerClosed.Lock()
+		close(tun.events)
+		tun.netlinkCancel.Close()
+	}()
+
+	for msg := make([]byte, 1<<16); ; {
+		var err error
+		var msgn int
+		for {
+			msgn, _, _, _, err = unix.Recvmsg(tun.netlinkSock, msg[:], nil, 0)
+			if err == nil || !rwcancel.RetryAfterError(err) {
+				break
+			}
+			if !tun.netlinkCancel.ReadyRead() {
+				tun.errors <- fmt.Errorf("netlink socket closed: %w", err)
+				return
+			}
+		}
+		if err != nil {
+			tun.errors <- fmt.Errorf("failed to receive netlink message: %w", err)
+			return
+		}
+
+		select {
+		case <-tun.statusListenersShutdown:
+			return
+		default:
+		}
+
+		wasEverUp := false
+		for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; {
+
+			hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0]))
+
+			if int(hdr.Len) > len(remain) {
+				break
+			}
+
+			switch hdr.Type {
+			case unix.NLMSG_DONE:
+				remain = []byte{}
+
+			case unix.RTM_NEWLINK:
+				info := *(*unix.IfInfomsg)(unsafe.Pointer(&remain[unix.SizeofNlMsghdr]))
+				remain = remain[hdr.Len:]
+
+				if info.Index != tun.index {
+					// not our interface
+					continue
+				}
+
+				if info.Flags&unix.IFF_RUNNING != 0 {
+					tun.events <- EventUp
+					wasEverUp = true
+				}
+
+				if info.Flags&unix.IFF_RUNNING == 0 {
+					// Don't emit EventDown before we've ever emitted EventUp.
+					// This avoids a startup race with HackListener, which
+					// might detect Up before we have finished reporting Down.
+					if wasEverUp {
+						tun.events <- EventDown
+					}
+				}
+
+				tun.events <- EventMTUUpdate
+
+			default:
+				remain = remain[hdr.Len:]
+			}
+		}
+	}
+}
+
+func getIFIndex(name string) (int32, error) {
+	fd, err := unix.Socket(
+		unix.AF_INET,
+		unix.SOCK_DGRAM|unix.SOCK_CLOEXEC,
+		0,
+	)
+	if err != nil {
+		return 0, err
+	}
+
+	defer unix.Close(fd)
+
+	var ifr [ifReqSize]byte
+	copy(ifr[:], name)
+	_, _, errno := unix.Syscall(
+		unix.SYS_IOCTL,
+		uintptr(fd),
+		uintptr(unix.SIOCGIFINDEX),
+		uintptr(unsafe.Pointer(&ifr[0])),
+	)
+
+	if errno != 0 {
+		return 0, errno
+	}
+
+	return *(*int32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])), nil
+}
+
+func (tun *NativeTun) setMTU(n int) error {
+	name, err := tun.Name()
+	if err != nil {
+		return err
+	}
+
+	// open datagram socket
+	fd, err := unix.Socket(
+		unix.AF_INET,
+		unix.SOCK_DGRAM|unix.SOCK_CLOEXEC,
+		0,
+	)
+	if err != nil {
+		return err
+	}
+	defer unix.Close(fd)
+
+	var ifr [ifReqSize]byte
+	copy(ifr[:], name)
+	*(*uint32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])) = uint32(n)
+
+	_, _, errno := unix.Syscall(
+		unix.SYS_IOCTL,
+		uintptr(fd),
+		uintptr(unix.SIOCSIFMTU),
+		uintptr(unsafe.Pointer(&ifr[0])),
+	)
+
+	if errno != 0 {
+		return errno
+	}
+	return nil
+}
+
+func (tun *NativeTun) routineNetlinkRead() {
+	defer func() {
+		unix.Close(tun.netlinkSock)
+		tun.hackListenerClosed.Lock()
+		close(tun.events)
+		tun.netlinkCancel.Close()
+	}()
+
+	for msg := make([]byte, 1<<16); ; {
+		var err error
+		var msgn int
+		for {
+			msgn, _, _, _, err = unix.Recvmsg(tun.netlinkSock, msg[:], nil, 0)
+			if err == nil || !rwcancel.RetryAfterError(err) {
+				break
+			}
+			if !tun.netlinkCancel.ReadyRead() {
+				tun.errors <- fmt.Errorf("netlink socket closed: %w", err)
+				return
+			}
+		}
+		if err != nil {
+			tun.errors <- fmt.Errorf("failed to receive netlink message: %w", err)
+			return
+		}
+
+		wasEverUp := false
+		for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; {
+
+			hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0]))
+
+			if int(hdr.Len) > len(remain) {
+				break
+			}
+
+			switch hdr.Type {
+			case unix.NLMSG_DONE:
+				remain = []byte{}
+
+			case unix.RTM_NEWLINK:
+				info := *(*unix.IfInfomsg)(unsafe.Pointer(&remain[unix.SizeofNlMsghdr]))
+				remain = remain[hdr.Len:]
+
+				if info.Index != tun.index {
+					continue
+				}
+
+				if info.Flags&unix.IFF_RUNNING != 0 {
+					tun.events <- EventUp
+					wasEverUp = true
+				}
+
+				if info.Flags&unix.IFF_RUNNING == 0 {
+					if wasEverUp {
+						tun.events <- EventDown
+					}
+				}
+				tun.events <- EventMTUUpdate
+
+			default:
+				remain = remain[hdr.Len:]
+			}
+		}
+	}
+}
+
+func (tun *NativeTun) routineNetlink() {
+	var err error
+
+	tun.netlinkSock, err = createNetlinkSocket()
+	if err != nil {
+		tun.errors <- fmt.Errorf("failed to create netlink socket: %w", err)
+		return
+	}
+
+	tun.netlinkCancel, err = rwcancel.NewRWCancel(tun.netlinkSock)
+	if err != nil {
+		tun.errors <- fmt.Errorf("failed to create netlink cancel: %w", err)
+		return
+	}
+
+	go tun.routineNetlinkListener()
+}
+
+func (tun *NativeTun) Close() error {
+	var err1, err2 error
+	tun.closeOnce.Do(func() {
+		if tun.statusListenersShutdown != nil {
+			close(tun.statusListenersShutdown)
+			if tun.netlinkCancel != nil {
+				err1 = tun.netlinkCancel.Cancel()
+			}
+		} else if tun.events != nil {
+			close(tun.events)
+		}
+		err2 = tun.tunFile.Close()
+	})
+	if err1 != nil {
+		return err1
+	}
+	return err2
+}
+
+func (tun *NativeTun) BatchSize() int {
+	return tun.batchSize
+}
+
+const (
+	// TODO: support TSO with ECN bits
+	tunOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6
+)
+
+func (tun *NativeTun) initFromFlags(name string) error {
+	sc, err := tun.tunFile.SyscallConn()
+	if err != nil {
+		return err
+	}
+	if e := sc.Control(func(fd uintptr) {
+		var (
+			ifr *unix.Ifreq
+		)
+		ifr, err = unix.NewIfreq(name)
+		if err != nil {
+			return
+		}
+		err = unix.IoctlIfreq(int(fd), unix.TUNGETIFF, ifr)
+		if err != nil {
+			return
+		}
+		got := ifr.Uint16()
+		if got&unix.IFF_VNET_HDR != 0 {
+			err = unix.IoctlSetInt(int(fd), unix.TUNSETOFFLOAD, tunOffloads)
+			if err != nil {
+				return
+			}
+			tun.vnetHdr = true
+			tun.batchSize = wgconn.IdealBatchSize
+		} else {
+			tun.batchSize = 1
+		}
+	}); e != nil {
+		return e
+	}
+	return err
+}
+
+// CreateTUN creates a Device with the provided name and MTU.
+func CreateTUN(name string, mtu int) (Device, error) {
+	nfd, err := unix.Open(cloneDevicePath, unix.O_RDWR|unix.O_CLOEXEC, 0)
+	if err != nil {
+		return nil, fmt.Errorf("CreateTUN(%q) failed; %s does not exist", name, cloneDevicePath)
+	}
+	fd := os.NewFile(uintptr(nfd), cloneDevicePath)
+	tun, err := CreateTUNFromFile(fd, mtu)
+	if err != nil {
+		return nil, err
+	}
+	if name != "tun" {
+		if err := tun.(*NativeTun).initFromFlags(name); err != nil {
+			tun.Close()
+			return nil, fmt.Errorf("CreateTUN(%q) failed to set flags: %w", name, err)
+		}
+	}
+	return tun, nil
+}
+
+// CreateTUNFromFile creates a Device from an os.File with the provided MTU.
+func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
+	tun := &NativeTun{
+		tunFile: file,
+		errors:  make(chan error, 5),
+		events:  make(chan Event, 5),
+	}
+
+	var err error
+	tun.index, err = getIFIndex("tun")
+	if err != nil {
+		return nil, fmt.Errorf("failed to get TUN index: %w", err)
+	}
+
+	if err = tun.setMTU(mtu); err != nil {
+		return nil, fmt.Errorf("failed to set MTU: %w", err)
+	}
+
+	tun.statusListenersShutdown = make(chan struct{})
+	go tun.routineNetlink()
+
+	if tun.batchSize == 0 {
+		tun.batchSize = 1
+	}
+
+	tun.tcp4GROTable = newTCPGROTable()
+	tun.tcp6GROTable = newTCPGROTable()
+
+	return tun, nil
+}
+
+func (tun *NativeTun) Name() (string, error) {
+	tun.nameOnce.Do(tun.initNameCache)
+	return tun.nameCache, tun.nameErr
+}
+
+func (tun *NativeTun) initNameCache() {
+	sysconn, err := tun.tunFile.SyscallConn()
+	if err != nil {
+		tun.nameErr = err
+		return
+	}
+	err = sysconn.Control(func(fd uintptr) {
+		var ifr [ifReqSize]byte
+		_, _, errno := unix.Syscall(
+			unix.SYS_IOCTL,
+			fd,
+			uintptr(unix.TUNGETIFF),
+			uintptr(unsafe.Pointer(&ifr[0])),
+		)
+		if errno != 0 {
+			tun.nameErr = errno
+			return
+		}
+		tun.nameCache = unix.ByteSliceToString(ifr[:])
+	})
+	if err != nil && tun.nameErr == nil {
+		tun.nameErr = err
+	}
+}
+
+func (tun *NativeTun) MTU() (int, error) {
+	name, err := tun.Name()
+	if err != nil {
+		return 0, err
+	}
+
+	// open datagram socket
+	fd, err := unix.Socket(
+		unix.AF_INET,
+		unix.SOCK_DGRAM|unix.SOCK_CLOEXEC,
+		0,
+	)
+	if err != nil {
+		return 0, err
+	}
+	defer unix.Close(fd)
+
+	var ifr [ifReqSize]byte
+	copy(ifr[:], name)
+
+	_, _, errno := unix.Syscall(
+		unix.SYS_IOCTL,
+		uintptr(fd),
+		uintptr(unix.SIOCGIFMTU),
+		uintptr(unsafe.Pointer(&ifr[0])),
+	)
+
+	if errno != 0 {
+		return 0, errno
+	}
+
+	return int(*(*uint32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ]))), nil
+}
+
+func (tun *NativeTun) Events() <-chan Event {
+	return tun.events
+}
+
+func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) {
+	tun.writeOpMu.Lock()
+	defer func() {
+		tun.tcp4GROTable.reset()
+		tun.tcp6GROTable.reset()
+		tun.writeOpMu.Unlock()
+	}()
+	var (
+		errs  error
+		total int
+	)
+	tun.toWrite = tun.toWrite[:0]
+	if tun.vnetHdr {
+		err := handleGRO(bufs, offset, tun.tcp4GROTable, tun.tcp6GROTable, &tun.toWrite)
+		if err != nil {
+			return 0, err
+		}
+		offset -= virtioNetHdrLen
+	} else {
+		for i := range bufs {
+			tun.toWrite = append(tun.toWrite, i)
+		}
+	}
+	for _, bufsI := range tun.toWrite {
+		n, err := tun.tunFile.Write(bufs[bufsI][offset:])
+		if errors.Is(err, syscall.EBADFD) {
+			return total, os.ErrClosed
+		}
+		if err != nil {
+			errs = errors.Join(errs, err)
+		} else {
+			total += n
+		}
+	}
+	return total, errs
+}
+
+// handleVirtioRead splits in into bufs, leaving offset bytes at the front of
+// each buffer. It mutates sizes to reflect the size of each element of bufs,
+// and returns the number of packets read.
+func handleVirtioRead(in []byte, bufs [][]byte, sizes []int, offset int) (int, error) {
+	var hdr virtioNetHdr
+	if err := hdr.decode(in); err != nil {
+		return 0, err
+	}
+	in = in[virtioNetHdrLen:]
+	if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_NONE {
+		if hdr.flags&unix.VIRTIO_NET_HDR_F_NEEDS_CSUM != 0 {
+			if err := gsoNoneChecksum(in, hdr.csumStart, hdr.csumOffset); err != nil {
+				return 0, err
+			}
+		}
+		if len(in) > len(bufs[0][offset:]) {
+			return 0, fmt.Errorf("read len %d overflows bufs element len %d", len(in), len(bufs[0][offset:]))
+		}
+		n := copy(bufs[0][offset:], in)
+		sizes[0] = n
+		return 1, nil
+	}
+	if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 {
+		return 0, fmt.Errorf("unsupported virtio GSO type: %d", hdr.gsoType)
+	}
+
+	ipVersion := in[0] >> 4
+	switch ipVersion {
+	case 4:
+		if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 {
+			return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType)
+		}
+	case 6:
+		if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 {
+			return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType)
+		}
+	default:
+		return 0, fmt.Errorf("invalid ip header version: %d", ipVersion)
+	}
+
+	if len(in) <= int(hdr.csumStart+12) {
+		return 0, errors.New("packet is too short")
+	}
+	tcpHLen := uint16(in[hdr.csumStart+12] >> 4 * 4)
+	if tcpHLen < 20 || tcpHLen > 60 {
+		return 0, fmt.Errorf("tcp header len is invalid: %d", tcpHLen)
+	}
+	hdr.hdrLen = hdr.csumStart + tcpHLen
+	if len(in) < int(hdr.hdrLen) {
+		return 0, fmt.Errorf("length of packet (%d) < virtioNetHdr.hdrLen (%d)", len(in), hdr.hdrLen)
+	}
+	if hdr.hdrLen < hdr.csumStart {
+		return 0, fmt.Errorf("virtioNetHdr.hdrLen (%d) < virtioNetHdr.csumStart (%d)", hdr.hdrLen, hdr.csumStart)
+	}
+	cSumAt := int(hdr.csumStart + hdr.csumOffset)
+	if cSumAt+1 >= len(in) {
+		return 0, fmt.Errorf("end of checksum offset (%d) exceeds packet length (%d)", cSumAt+1, len(in))
+	}
+
+	return tcpTSO(in, hdr, bufs, sizes, offset)
+}
+
+func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) {
+	tun.readOpMu.Lock()
+	defer tun.readOpMu.Unlock()
+	select {
+	case err := <-tun.errors:
+		return 0, err
+	default:
+		readInto := bufs[0][offset:]
+		if tun.vnetHdr {
+			readInto = tun.readBuff[:]
+		}
+		n, err := tun.tunFile.Read(readInto)
+		if errors.Is(err, syscall.EBADFD) {
+			err = os.ErrClosed
+		}
+		if err != nil {
+			return 0, err
+		}
+		if tun.vnetHdr {
+			return handleVirtioRead(readInto[:n], bufs, sizes, offset)
+		}
+		sizes[0] = n
+		return 1, nil
+	}
+}