Browse Source

first try

Ryan Huber 1 month ago
parent
commit
fd1c52127f
2 changed files with 57 additions and 8 deletions
  1. 11 3
      main.go
  2. 46 5
      overlay/tun_linux.go

+ 11 - 3
main.go

@@ -162,9 +162,17 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 			listenHost = ips[0].Unmap()
 		}
 
-		for i := 0; i < routines; i++ {
-			l.Infof("listening on %v", netip.AddrPortFrom(listenHost, uint16(port)))
-			udpServer, err := udp.NewListener(l, listenHost, port, routines > 1, c.GetInt("listen.batch", 64))
+	useWG := c.GetBool("listen.use_wireguard_stack", false)
+	var mkListener func(*logrus.Logger, netip.Addr, int, bool, int) (udp.Conn, error)
+	if useWG {
+		mkListener = udp.NewWireguardListener
+	} else {
+		mkListener = udp.NewListener
+	}
+
+	for i := 0; i < routines; i++ {
+		l.Infof("listening on %v", netip.AddrPortFrom(listenHost, uint16(port)))
+		udpServer, err := mkListener(l, listenHost, port, routines > 1, c.GetInt("listen.batch", 64))
 			if err != nil {
 				return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err)
 			}

+ 46 - 5
overlay/tun_linux.go

@@ -19,6 +19,7 @@ import (
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/routing"
 	"github.com/slackhq/nebula/util"
+	wgtun "github.com/slackhq/nebula/wgstack/tun"
 	"github.com/vishvananda/netlink"
 	"golang.org/x/sys/unix"
 )
@@ -33,6 +34,7 @@ type tun struct {
 	TXQueueLen  int
 	deviceIndex int
 	ioctlFd     uintptr
+	wgDevice    wgtun.Device
 
 	Routes                    atomic.Pointer[[]Route]
 	routeTree                 atomic.Pointer[bart.Table[routing.Gateways]]
@@ -68,7 +70,8 @@ type ifreqQLEN struct {
 func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
 	file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
 
-	t, err := newTunGeneric(c, l, file, vpnNetworks)
+	useWG := c.GetBool("tun.use_wireguard_stack", c.GetBool("listen.use_wireguard_stack", false))
+	t, err := newTunGeneric(c, l, file, vpnNetworks, useWG)
 	if err != nil {
 		return nil, err
 	}
@@ -113,7 +116,8 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
 	name := strings.Trim(string(req.Name[:]), "\x00")
 
 	file := os.NewFile(uintptr(fd), "/dev/net/tun")
-	t, err := newTunGeneric(c, l, file, vpnNetworks)
+	useWG := c.GetBool("tun.use_wireguard_stack", c.GetBool("listen.use_wireguard_stack", false))
+	t, err := newTunGeneric(c, l, file, vpnNetworks, useWG)
 	if err != nil {
 		return nil, err
 	}
@@ -123,16 +127,45 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
 	return t, nil
 }
 
-func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix) (*tun, error) {
+func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix, useWireguard bool) (*tun, error) {
+	var (
+		rw    io.ReadWriteCloser = file
+		fd                       = int(file.Fd())
+		wgDev wgtun.Device
+	)
+
+	if useWireguard {
+		dev, err := wgtun.CreateTUNFromFile(file, c.GetInt("tun.mtu", DefaultMTU))
+		if err != nil {
+			return nil, fmt.Errorf("failed to initialize wireguard tun device: %w", err)
+		}
+		wgDev = dev
+		rw = newWireguardTunIO(dev, c.GetInt("tun.mtu", DefaultMTU))
+		fd = int(dev.File().Fd())
+	}
+
 	t := &tun{
-		ReadWriteCloser:           file,
-		fd:                        int(file.Fd()),
+		ReadWriteCloser:           rw,
+		fd:                        fd,
 		vpnNetworks:               vpnNetworks,
 		TXQueueLen:                c.GetInt("tun.tx_queue", 500),
 		useSystemRoutes:           c.GetBool("tun.use_system_route_table", false),
 		useSystemRoutesBufferSize: c.GetInt("tun.use_system_route_table_buffer_size", 0),
 		l:                         l,
 	}
+	if wgDev != nil {
+		t.wgDevice = wgDev
+	}
+	if wgDev != nil {
+		// replace ioctl fd with device file descriptor to keep route management working
+		file = wgDev.File()
+		t.fd = int(file.Fd())
+		t.ioctlFd = file.Fd()
+	}
+
+	if t.ioctlFd == 0 {
+		t.ioctlFd = file.Fd()
+	}
 
 	err := t.reload(c, true)
 	if err != nil {
@@ -678,6 +711,14 @@ func (t *tun) Close() error {
 		_ = t.ReadWriteCloser.Close()
 	}
 
+	if t.wgDevice != nil {
+		_ = t.wgDevice.Close()
+		if t.ioctlFd > 0 {
+			// underlying fd already closed by the device
+			t.ioctlFd = 0
+		}
+	}
+
 	if t.ioctlFd > 0 {
 		_ = os.NewFile(t.ioctlFd, "ioctlFd").Close()
 	}