Browse Source

Use x/net/route to manage routes directly (#1488)

Nate Brown 5 days ago
parent
commit
fb7f0c3657
3 changed files with 158 additions and 38 deletions
  1. 11 0
      overlay/tun.go
  2. 0 11
      overlay/tun_darwin.go
  3. 147 27
      overlay/tun_freebsd.go

+ 11 - 0
overlay/tun.go

@@ -1,6 +1,7 @@
 package overlay
 package overlay
 
 
 import (
 import (
+	"net"
 	"net/netip"
 	"net/netip"
 
 
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
@@ -70,3 +71,13 @@ func findRemovedRoutes(newRoutes, oldRoutes []Route) []Route {
 
 
 	return removed
 	return removed
 }
 }
+
+func prefixToMask(prefix netip.Prefix) netip.Addr {
+	pLen := 128
+	if prefix.Addr().Is4() {
+		pLen = 32
+	}
+
+	addr, _ := netip.AddrFromSlice(net.CIDRMask(prefix.Bits(), pLen))
+	return addr
+}

+ 0 - 11
overlay/tun_darwin.go

@@ -7,7 +7,6 @@ import (
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
 	"io"
 	"io"
-	"net"
 	"net/netip"
 	"net/netip"
 	"os"
 	"os"
 	"sync/atomic"
 	"sync/atomic"
@@ -554,13 +553,3 @@ func (t *tun) Name() string {
 func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
 func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
 	return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin")
 	return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin")
 }
 }
-
-func prefixToMask(prefix netip.Prefix) netip.Addr {
-	pLen := 128
-	if prefix.Addr().Is4() {
-		pLen = 32
-	}
-
-	addr, _ := netip.AddrFromSlice(net.CIDRMask(prefix.Bits(), pLen))
-	return addr
-}

+ 147 - 27
overlay/tun_freebsd.go

@@ -9,9 +9,7 @@ import (
 	"fmt"
 	"fmt"
 	"io"
 	"io"
 	"io/fs"
 	"io/fs"
-	"net"
 	"net/netip"
 	"net/netip"
-	"os/exec"
 	"sync/atomic"
 	"sync/atomic"
 	"syscall"
 	"syscall"
 	"time"
 	"time"
@@ -22,6 +20,7 @@ import (
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/routing"
 	"github.com/slackhq/nebula/routing"
 	"github.com/slackhq/nebula/util"
 	"github.com/slackhq/nebula/util"
+	netroute "golang.org/x/net/route"
 	"golang.org/x/sys/unix"
 	"golang.org/x/sys/unix"
 )
 )
 
 
@@ -92,6 +91,7 @@ type tun struct {
 	MTU         int
 	MTU         int
 	Routes      atomic.Pointer[[]Route]
 	Routes      atomic.Pointer[[]Route]
 	routeTree   atomic.Pointer[bart.Table[routing.Gateways]]
 	routeTree   atomic.Pointer[bart.Table[routing.Gateways]]
+	linkAddr    *netroute.LinkAddr
 	l           *logrus.Logger
 	l           *logrus.Logger
 	devFd       int
 	devFd       int
 }
 }
@@ -162,6 +162,7 @@ func (t *tun) Write(from []byte) (int, error) {
 	} else {
 	} else {
 		err = nil
 		err = nil
 	}
 	}
+
 	return int(n) - 4, err
 	return int(n) - 4, err
 }
 }
 
 
@@ -308,7 +309,7 @@ func (t *tun) addIp(cidr netip.Prefix) error {
 			MaskAddr: unix.RawSockaddrInet4{
 			MaskAddr: unix.RawSockaddrInet4{
 				Len:    unix.SizeofSockaddrInet4,
 				Len:    unix.SizeofSockaddrInet4,
 				Family: unix.AF_INET,
 				Family: unix.AF_INET,
-				Addr:   getNetmask(cidr).As4(),
+				Addr:   prefixToMask(cidr).As4(),
 			},
 			},
 			VHid: 0,
 			VHid: 0,
 		}
 		}
@@ -321,7 +322,10 @@ func (t *tun) addIp(cidr netip.Prefix) error {
 		if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&ifr))); err != nil {
 		if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&ifr))); err != nil {
 			return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err)
 			return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err)
 		}
 		}
-	} else if cidr.Addr().Is6() {
+		return nil
+	}
+
+	if cidr.Addr().Is6() {
 		ifr := ifreqAlias6{
 		ifr := ifreqAlias6{
 			Name: t.deviceBytes(),
 			Name: t.deviceBytes(),
 			Addr: unix.RawSockaddrInet6{
 			Addr: unix.RawSockaddrInet6{
@@ -332,7 +336,7 @@ func (t *tun) addIp(cidr netip.Prefix) error {
 			PrefixMask: unix.RawSockaddrInet6{
 			PrefixMask: unix.RawSockaddrInet6{
 				Len:    unix.SizeofSockaddrInet6,
 				Len:    unix.SizeofSockaddrInet6,
 				Family: unix.AF_INET6,
 				Family: unix.AF_INET6,
-				Addr:   getNetmask(cidr).As16(),
+				Addr:   prefixToMask(cidr).As16(),
 			},
 			},
 			Lifetime: addrLifetime{
 			Lifetime: addrLifetime{
 				Expire:    0,
 				Expire:    0,
@@ -351,11 +355,10 @@ func (t *tun) addIp(cidr netip.Prefix) error {
 		if err := ioctl(uintptr(s), OSIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&ifr))); err != nil {
 		if err := ioctl(uintptr(s), OSIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&ifr))); err != nil {
 			return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err)
 			return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err)
 		}
 		}
-	} else {
-		return fmt.Errorf("Unknown address type")
+		return nil
 	}
 	}
 
 
-	return t.addRoutes(false)
+	return fmt.Errorf("unknown address type %v", cidr)
 }
 }
 
 
 func (t *tun) Activate() error {
 func (t *tun) Activate() error {
@@ -365,13 +368,23 @@ func (t *tun) Activate() error {
 		return err
 		return err
 	}
 	}
 
 
+	linkAddr, err := getLinkAddr(t.Device)
+	if err != nil {
+		return err
+	}
+	if linkAddr == nil {
+		return fmt.Errorf("unable to discover link_addr for tun interface")
+	}
+	t.linkAddr = linkAddr
+
 	for i := range t.vpnNetworks {
 	for i := range t.vpnNetworks {
 		err := t.addIp(t.vpnNetworks[i])
 		err := t.addIp(t.vpnNetworks[i])
 		if err != nil {
 		if err != nil {
 			return err
 			return err
 		}
 		}
 	}
 	}
-	return nil
+
+	return t.addRoutes(false)
 }
 }
 
 
 func (t *tun) setMTU() error {
 func (t *tun) setMTU() error {
@@ -449,15 +462,16 @@ func (t *tun) addRoutes(logErrors bool) error {
 			continue
 			continue
 		}
 		}
 
 
-		cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), "-interface", t.Device)
-		t.l.Debug("command: ", cmd.String())
-		if err := cmd.Run(); err != nil {
-			retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]any{"route": r}, err)
+		err := addRoute(r.Cidr, t.linkAddr)
+		if err != nil {
+			retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
 			if logErrors {
 			if logErrors {
 				retErr.Log(t.l)
 				retErr.Log(t.l)
 			} else {
 			} else {
 				return retErr
 				return retErr
 			}
 			}
+		} else {
+			t.l.WithField("route", r).Info("Added route")
 		}
 		}
 	}
 	}
 
 
@@ -470,9 +484,8 @@ func (t *tun) removeRoutes(routes []Route) error {
 			continue
 			continue
 		}
 		}
 
 
-		cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), "-interface", t.Device)
-		t.l.Debug("command: ", cmd.String())
-		if err := cmd.Run(); err != nil {
+		err := delRoute(r.Cidr, t.linkAddr)
+		if err != nil {
 			t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
 			t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
 		} else {
 		} else {
 			t.l.WithField("route", r).Info("Removed route")
 			t.l.WithField("route", r).Info("Removed route")
@@ -502,22 +515,129 @@ func orBytes(a []byte, b []byte) []byte {
 	return ret
 	return ret
 }
 }
 
 
-func getNetmask(cidr netip.Prefix) netip.Addr {
-	pLen := 128
-	if cidr.Addr().Is4() {
-		pLen = 32
-	}
-
-	addr, _ := netip.AddrFromSlice(net.CIDRMask(cidr.Bits(), pLen))
-	return addr
-}
-
 func getBroadcast(cidr netip.Prefix) netip.Addr {
 func getBroadcast(cidr netip.Prefix) netip.Addr {
 	broadcast, _ := netip.AddrFromSlice(
 	broadcast, _ := netip.AddrFromSlice(
 		orBytes(
 		orBytes(
 			cidr.Addr().AsSlice(),
 			cidr.Addr().AsSlice(),
-			flipBytes(getNetmask(cidr).AsSlice()),
+			flipBytes(prefixToMask(cidr).AsSlice()),
 		),
 		),
 	)
 	)
 	return broadcast
 	return broadcast
 }
 }
+
+func addRoute(prefix netip.Prefix, gateway netroute.Addr) error {
+	sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
+	if err != nil {
+		return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
+	}
+	defer unix.Close(sock)
+
+	route := &netroute.RouteMessage{
+		Version: unix.RTM_VERSION,
+		Type:    unix.RTM_ADD,
+		Flags:   unix.RTF_UP,
+		Seq:     1,
+	}
+
+	if prefix.Addr().Is4() {
+		route.Addrs = []netroute.Addr{
+			unix.RTAX_DST:     &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
+			unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
+			unix.RTAX_GATEWAY: gateway,
+		}
+	} else {
+		route.Addrs = []netroute.Addr{
+			unix.RTAX_DST:     &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
+			unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
+			unix.RTAX_GATEWAY: gateway,
+		}
+	}
+
+	data, err := route.Marshal()
+	if err != nil {
+		return fmt.Errorf("failed to create route.RouteMessage: %w", err)
+	}
+
+	_, err = unix.Write(sock, data[:])
+	if err != nil {
+		if errors.Is(err, unix.EEXIST) {
+			// Try to do a change
+			route.Type = unix.RTM_CHANGE
+			data, err = route.Marshal()
+			if err != nil {
+				return fmt.Errorf("failed to create route.RouteMessage for change: %w", err)
+			}
+			_, err = unix.Write(sock, data[:])
+			fmt.Println("DOING CHANGE")
+			return err
+		}
+		return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
+	}
+
+	return nil
+}
+
+func delRoute(prefix netip.Prefix, gateway netroute.Addr) error {
+	sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
+	if err != nil {
+		return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
+	}
+	defer unix.Close(sock)
+
+	route := netroute.RouteMessage{
+		Version: unix.RTM_VERSION,
+		Type:    unix.RTM_DELETE,
+		Seq:     1,
+	}
+
+	if prefix.Addr().Is4() {
+		route.Addrs = []netroute.Addr{
+			unix.RTAX_DST:     &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
+			unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
+			unix.RTAX_GATEWAY: gateway,
+		}
+	} else {
+		route.Addrs = []netroute.Addr{
+			unix.RTAX_DST:     &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
+			unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
+			unix.RTAX_GATEWAY: gateway,
+		}
+	}
+
+	data, err := route.Marshal()
+	if err != nil {
+		return fmt.Errorf("failed to create route.RouteMessage: %w", err)
+	}
+	_, err = unix.Write(sock, data[:])
+	if err != nil {
+		return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
+	}
+
+	return nil
+}
+
+// getLinkAddr Gets the link address for the interface of the given name
+func getLinkAddr(name string) (*netroute.LinkAddr, error) {
+	rib, err := netroute.FetchRIB(unix.AF_UNSPEC, unix.NET_RT_IFLIST, 0)
+	if err != nil {
+		return nil, err
+	}
+	msgs, err := netroute.ParseRIB(unix.NET_RT_IFLIST, rib)
+	if err != nil {
+		return nil, err
+	}
+
+	for _, m := range msgs {
+		switch m := m.(type) {
+		case *netroute.InterfaceMessage:
+			if m.Name == name {
+				sa, ok := m.Addrs[unix.RTAX_IFP].(*netroute.LinkAddr)
+				if ok {
+					return sa, nil
+				}
+			}
+		}
+	}
+
+	return nil, nil
+}