|
@@ -19,6 +19,7 @@ import (
|
|
|
"github.com/slackhq/nebula/config"
|
|
"github.com/slackhq/nebula/config"
|
|
|
"github.com/slackhq/nebula/routing"
|
|
"github.com/slackhq/nebula/routing"
|
|
|
"github.com/slackhq/nebula/util"
|
|
"github.com/slackhq/nebula/util"
|
|
|
|
|
+ wgtun "github.com/slackhq/nebula/wgstack/tun"
|
|
|
"github.com/vishvananda/netlink"
|
|
"github.com/vishvananda/netlink"
|
|
|
"golang.org/x/sys/unix"
|
|
"golang.org/x/sys/unix"
|
|
|
)
|
|
)
|
|
@@ -33,6 +34,7 @@ type tun struct {
|
|
|
TXQueueLen int
|
|
TXQueueLen int
|
|
|
deviceIndex int
|
|
deviceIndex int
|
|
|
ioctlFd uintptr
|
|
ioctlFd uintptr
|
|
|
|
|
+ wgDevice wgtun.Device
|
|
|
|
|
|
|
|
Routes atomic.Pointer[[]Route]
|
|
Routes atomic.Pointer[[]Route]
|
|
|
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
|
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
|
@@ -68,7 +70,8 @@ type ifreqQLEN struct {
|
|
|
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
|
|
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
|
|
|
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
|
|
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
|
|
|
|
|
|
|
|
- t, err := newTunGeneric(c, l, file, vpnNetworks)
|
|
|
|
|
|
|
+ useWG := c.GetBool("tun.use_wireguard_stack", c.GetBool("listen.use_wireguard_stack", false))
|
|
|
|
|
+ t, err := newTunGeneric(c, l, file, vpnNetworks, useWG)
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
return nil, err
|
|
return nil, err
|
|
|
}
|
|
}
|
|
@@ -113,7 +116,8 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
|
|
|
name := strings.Trim(string(req.Name[:]), "\x00")
|
|
name := strings.Trim(string(req.Name[:]), "\x00")
|
|
|
|
|
|
|
|
file := os.NewFile(uintptr(fd), "/dev/net/tun")
|
|
file := os.NewFile(uintptr(fd), "/dev/net/tun")
|
|
|
- t, err := newTunGeneric(c, l, file, vpnNetworks)
|
|
|
|
|
|
|
+ useWG := c.GetBool("tun.use_wireguard_stack", c.GetBool("listen.use_wireguard_stack", false))
|
|
|
|
|
+ t, err := newTunGeneric(c, l, file, vpnNetworks, useWG)
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
return nil, err
|
|
return nil, err
|
|
|
}
|
|
}
|
|
@@ -123,16 +127,45 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
|
|
|
return t, nil
|
|
return t, nil
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix) (*tun, error) {
|
|
|
|
|
|
|
+func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix, useWireguard bool) (*tun, error) {
|
|
|
|
|
+ var (
|
|
|
|
|
+ rw io.ReadWriteCloser = file
|
|
|
|
|
+ fd = int(file.Fd())
|
|
|
|
|
+ wgDev wgtun.Device
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ if useWireguard {
|
|
|
|
|
+ dev, err := wgtun.CreateTUNFromFile(file, c.GetInt("tun.mtu", DefaultMTU))
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return nil, fmt.Errorf("failed to initialize wireguard tun device: %w", err)
|
|
|
|
|
+ }
|
|
|
|
|
+ wgDev = dev
|
|
|
|
|
+ rw = newWireguardTunIO(dev, c.GetInt("tun.mtu", DefaultMTU))
|
|
|
|
|
+ fd = int(dev.File().Fd())
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
t := &tun{
|
|
t := &tun{
|
|
|
- ReadWriteCloser: file,
|
|
|
|
|
- fd: int(file.Fd()),
|
|
|
|
|
|
|
+ ReadWriteCloser: rw,
|
|
|
|
|
+ fd: fd,
|
|
|
vpnNetworks: vpnNetworks,
|
|
vpnNetworks: vpnNetworks,
|
|
|
TXQueueLen: c.GetInt("tun.tx_queue", 500),
|
|
TXQueueLen: c.GetInt("tun.tx_queue", 500),
|
|
|
useSystemRoutes: c.GetBool("tun.use_system_route_table", false),
|
|
useSystemRoutes: c.GetBool("tun.use_system_route_table", false),
|
|
|
useSystemRoutesBufferSize: c.GetInt("tun.use_system_route_table_buffer_size", 0),
|
|
useSystemRoutesBufferSize: c.GetInt("tun.use_system_route_table_buffer_size", 0),
|
|
|
l: l,
|
|
l: l,
|
|
|
}
|
|
}
|
|
|
|
|
+ if wgDev != nil {
|
|
|
|
|
+ t.wgDevice = wgDev
|
|
|
|
|
+ }
|
|
|
|
|
+ if wgDev != nil {
|
|
|
|
|
+ // replace ioctl fd with device file descriptor to keep route management working
|
|
|
|
|
+ file = wgDev.File()
|
|
|
|
|
+ t.fd = int(file.Fd())
|
|
|
|
|
+ t.ioctlFd = file.Fd()
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ if t.ioctlFd == 0 {
|
|
|
|
|
+ t.ioctlFd = file.Fd()
|
|
|
|
|
+ }
|
|
|
|
|
|
|
|
err := t.reload(c, true)
|
|
err := t.reload(c, true)
|
|
|
if err != nil {
|
|
if err != nil {
|
|
@@ -678,6 +711,14 @@ func (t *tun) Close() error {
|
|
|
_ = t.ReadWriteCloser.Close()
|
|
_ = t.ReadWriteCloser.Close()
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+ if t.wgDevice != nil {
|
|
|
|
|
+ _ = t.wgDevice.Close()
|
|
|
|
|
+ if t.ioctlFd > 0 {
|
|
|
|
|
+ // underlying fd already closed by the device
|
|
|
|
|
+ t.ioctlFd = 0
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
if t.ioctlFd > 0 {
|
|
if t.ioctlFd > 0 {
|
|
|
_ = os.NewFile(t.ioctlFd, "ioctlFd").Close()
|
|
_ = os.NewFile(t.ioctlFd, "ioctlFd").Close()
|
|
|
}
|
|
}
|