Răsfoiți Sursa

Push route handling into overlay, a few more nits fixed (#581)

Nate Brown 3 ani în urmă
părinte
comite
467e605d5e

+ 0 - 19
hostmap.go

@@ -15,7 +15,6 @@ import (
 	"github.com/slackhq/nebula/cidr"
 	"github.com/slackhq/nebula/cidr"
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/iputil"
-	"github.com/slackhq/nebula/overlay"
 	"github.com/slackhq/nebula/udp"
 	"github.com/slackhq/nebula/udp"
 )
 )
 
 
@@ -36,7 +35,6 @@ type HostMap struct {
 	Hosts           map[iputil.VpnIp]*HostInfo
 	Hosts           map[iputil.VpnIp]*HostInfo
 	preferredRanges []*net.IPNet
 	preferredRanges []*net.IPNet
 	vpnCIDR         *net.IPNet
 	vpnCIDR         *net.IPNet
-	unsafeRoutes    *cidr.Tree4
 	metricsEnabled  bool
 	metricsEnabled  bool
 	l               *logrus.Logger
 	l               *logrus.Logger
 }
 }
@@ -99,7 +97,6 @@ func NewHostMap(l *logrus.Logger, name string, vpnCIDR *net.IPNet, preferredRang
 		Hosts:           h,
 		Hosts:           h,
 		preferredRanges: preferredRanges,
 		preferredRanges: preferredRanges,
 		vpnCIDR:         vpnCIDR,
 		vpnCIDR:         vpnCIDR,
-		unsafeRoutes:    cidr.NewTree4(),
 		l:               l,
 		l:               l,
 	}
 	}
 	return &m
 	return &m
@@ -333,15 +330,6 @@ func (hm *HostMap) queryVpnIp(vpnIp iputil.VpnIp, promoteIfce *Interface) (*Host
 	return nil, errors.New("unable to find host")
 	return nil, errors.New("unable to find host")
 }
 }
 
 
-func (hm *HostMap) queryUnsafeRoute(ip iputil.VpnIp) iputil.VpnIp {
-	r := hm.unsafeRoutes.MostSpecificContains(ip)
-	if r != nil {
-		return r.(iputil.VpnIp)
-	} else {
-		return 0
-	}
-}
-
 // We already have the hm Lock when this is called, so make sure to not call
 // We already have the hm Lock when this is called, so make sure to not call
 // any other methods that might try to grab it again
 // any other methods that might try to grab it again
 func (hm *HostMap) addHostInfo(hostinfo *HostInfo, f *Interface) {
 func (hm *HostMap) addHostInfo(hostinfo *HostInfo, f *Interface) {
@@ -409,13 +397,6 @@ func (hm *HostMap) Punchy(ctx context.Context, conn *udp.Conn) {
 	}
 	}
 }
 }
 
 
-func (hm *HostMap) addUnsafeRoutes(routes *[]overlay.Route) {
-	for _, r := range *routes {
-		hm.l.WithField("cidr", r.Cidr).WithField("via", r.Via).Warn("Adding UNSAFE Route")
-		hm.unsafeRoutes.AddCIDR(r.Cidr, iputil.Ip2VpnIp(*r.Via))
-	}
-}
-
 func (i *HostInfo) BindConnectionState(cs *ConnectionState) {
 func (i *HostInfo) BindConnectionState(cs *ConnectionState) {
 	i.ConnectionState = cs
 	i.ConnectionState = cs
 }
 }

+ 1 - 1
inside.go

@@ -72,7 +72,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
 func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp) *HostInfo {
 func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp) *HostInfo {
 	//TODO: we can find contains without converting back to bytes
 	//TODO: we can find contains without converting back to bytes
 	if f.hostMap.vpnCIDR.Contains(vpnIp.ToIP()) == false {
 	if f.hostMap.vpnCIDR.Contains(vpnIp.ToIP()) == false {
-		vpnIp = f.hostMap.queryUnsafeRoute(vpnIp)
+		vpnIp = f.inside.RouteFor(vpnIp)
 		if vpnIp == 0 {
 		if vpnIp == 0 {
 			return nil
 			return nil
 		}
 		}

+ 1 - 11
main.go

@@ -78,14 +78,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 
 
 	// TODO: make sure mask is 4 bytes
 	// TODO: make sure mask is 4 bytes
 	tunCidr := cs.certificate.Details.Ips[0]
 	tunCidr := cs.certificate.Details.Ips[0]
-	routes, err := overlay.ParseRoutes(c, tunCidr)
-	if err != nil {
-		return nil, util.NewContextualError("Could not parse tun.routes", nil, err)
-	}
-	unsafeRoutes, err := overlay.ParseUnsafeRoutes(c, tunCidr)
-	if err != nil {
-		return nil, util.NewContextualError("Could not parse tun.unsafe_routes", nil, err)
-	}
 
 
 	ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
 	ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
 	wireSSHReload(l, ssh, c)
 	wireSSHReload(l, ssh, c)
@@ -142,7 +134,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 	if !configTest {
 	if !configTest {
 		c.CatchHUP(ctx)
 		c.CatchHUP(ctx)
 
 
-		tun, err = overlay.NewDeviceFromConfig(c, l, tunCidr, routes, unsafeRoutes, tunFd, routines)
+		tun, err = overlay.NewDeviceFromConfig(c, l, tunCidr, tunFd, routines)
 		if err != nil {
 		if err != nil {
 			return nil, util.NewContextualError("Failed to get a tun/tap device", nil, err)
 			return nil, util.NewContextualError("Failed to get a tun/tap device", nil, err)
 		}
 		}
@@ -217,8 +209,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 	}
 	}
 
 
 	hostMap := NewHostMap(l, "main", tunCidr, preferredRanges)
 	hostMap := NewHostMap(l, "main", tunCidr, preferredRanges)
-
-	hostMap.addUnsafeRoutes(&unsafeRoutes)
 	hostMap.metricsEnabled = c.GetBool("stats.message_metrics", false)
 	hostMap.metricsEnabled = c.GetBool("stats.message_metrics", false)
 
 
 	l.WithField("network", hostMap.vpnCIDR).WithField("preferredRanges", hostMap.preferredRanges).Info("Main HostMap created")
 	l.WithField("network", hostMap.vpnCIDR).WithField("preferredRanges", hostMap.preferredRanges).Info("Main HostMap created")

+ 3 - 0
overlay/device.go

@@ -3,6 +3,8 @@ package overlay
 import (
 import (
 	"io"
 	"io"
 	"net"
 	"net"
+
+	"github.com/slackhq/nebula/iputil"
 )
 )
 
 
 type Device interface {
 type Device interface {
@@ -11,5 +13,6 @@ type Device interface {
 	CidrNet() *net.IPNet
 	CidrNet() *net.IPNet
 	DeviceName() string
 	DeviceName() string
 	WriteRaw([]byte) error
 	WriteRaw([]byte) error
+	RouteFor(iputil.VpnIp) iputil.VpnIp
 	NewMultiQueueReader() (io.ReadWriteCloser, error)
 	NewMultiQueueReader() (io.ReadWriteCloser, error)
 }
 }

+ 2 - 2
overlay/route.go

@@ -16,7 +16,7 @@ type Route struct {
 	Via    *net.IP
 	Via    *net.IP
 }
 }
 
 
-func ParseRoutes(c *config.C, network *net.IPNet) ([]Route, error) {
+func parseRoutes(c *config.C, network *net.IPNet) ([]Route, error) {
 	var err error
 	var err error
 
 
 	r := c.Get("tun.routes")
 	r := c.Get("tun.routes")
@@ -86,7 +86,7 @@ func ParseRoutes(c *config.C, network *net.IPNet) ([]Route, error) {
 	return routes, nil
 	return routes, nil
 }
 }
 
 
-func ParseUnsafeRoutes(c *config.C, network *net.IPNet) ([]Route, error) {
+func parseUnsafeRoutes(c *config.C, network *net.IPNet) ([]Route, error) {
 	var err error
 	var err error
 
 
 	r := c.Get("tun.unsafe_routes")
 	r := c.Get("tun.unsafe_routes")

+ 30 - 30
overlay/tun_test.go → overlay/route_test.go

@@ -10,73 +10,73 @@ import (
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/assert"
 )
 )
 
 
-func Test_ParseRoutes(t *testing.T) {
+func Test_parseRoutes(t *testing.T) {
 	l := test.NewLogger()
 	l := test.NewLogger()
 	c := config.NewC(l)
 	c := config.NewC(l)
 	_, n, _ := net.ParseCIDR("10.0.0.0/24")
 	_, n, _ := net.ParseCIDR("10.0.0.0/24")
 
 
 	// test no routes config
 	// test no routes config
-	routes, err := ParseRoutes(c, n)
+	routes, err := parseRoutes(c, n)
 	assert.Nil(t, err)
 	assert.Nil(t, err)
 	assert.Len(t, routes, 0)
 	assert.Len(t, routes, 0)
 
 
 	// not an array
 	// not an array
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": "hi"}
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": "hi"}
-	routes, err = ParseRoutes(c, n)
+	routes, err = parseRoutes(c, n)
 	assert.Nil(t, routes)
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "tun.routes is not an array")
 	assert.EqualError(t, err, "tun.routes is not an array")
 
 
 	// no routes
 	// no routes
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{}}
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{}}
-	routes, err = ParseRoutes(c, n)
+	routes, err = parseRoutes(c, n)
 	assert.Nil(t, err)
 	assert.Nil(t, err)
 	assert.Len(t, routes, 0)
 	assert.Len(t, routes, 0)
 
 
 	// weird route
 	// weird route
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{"asdf"}}
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{"asdf"}}
-	routes, err = ParseRoutes(c, n)
+	routes, err = parseRoutes(c, n)
 	assert.Nil(t, routes)
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "entry 1 in tun.routes is invalid")
 	assert.EqualError(t, err, "entry 1 in tun.routes is invalid")
 
 
 	// no mtu
 	// no mtu
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{}}}
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{}}}
-	routes, err = ParseRoutes(c, n)
+	routes, err = parseRoutes(c, n)
 	assert.Nil(t, routes)
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "entry 1.mtu in tun.routes is not present")
 	assert.EqualError(t, err, "entry 1.mtu in tun.routes is not present")
 
 
 	// bad mtu
 	// bad mtu
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "nope"}}}
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "nope"}}}
-	routes, err = ParseRoutes(c, n)
+	routes, err = parseRoutes(c, n)
 	assert.Nil(t, routes)
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "entry 1.mtu in tun.routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax")
 	assert.EqualError(t, err, "entry 1.mtu in tun.routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax")
 
 
 	// low mtu
 	// low mtu
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "499"}}}
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "499"}}}
-	routes, err = ParseRoutes(c, n)
+	routes, err = parseRoutes(c, n)
 	assert.Nil(t, routes)
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "entry 1.mtu in tun.routes is below 500: 499")
 	assert.EqualError(t, err, "entry 1.mtu in tun.routes is below 500: 499")
 
 
 	// missing route
 	// missing route
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500"}}}
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500"}}}
-	routes, err = ParseRoutes(c, n)
+	routes, err = parseRoutes(c, n)
 	assert.Nil(t, routes)
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "entry 1.route in tun.routes is not present")
 	assert.EqualError(t, err, "entry 1.route in tun.routes is not present")
 
 
 	// unparsable route
 	// unparsable route
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "nope"}}}
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "nope"}}}
-	routes, err = ParseRoutes(c, n)
+	routes, err = parseRoutes(c, n)
 	assert.Nil(t, routes)
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "entry 1.route in tun.routes failed to parse: invalid CIDR address: nope")
 	assert.EqualError(t, err, "entry 1.route in tun.routes failed to parse: invalid CIDR address: nope")
 
 
 	// below network range
 	// below network range
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "1.0.0.0/8"}}}
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "1.0.0.0/8"}}}
-	routes, err = ParseRoutes(c, n)
+	routes, err = parseRoutes(c, n)
 	assert.Nil(t, routes)
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the network attached to the certificate; route: 1.0.0.0/8, network: 10.0.0.0/24")
 	assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the network attached to the certificate; route: 1.0.0.0/8, network: 10.0.0.0/24")
 
 
 	// above network range
 	// above network range
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "10.0.1.0/24"}}}
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "10.0.1.0/24"}}}
-	routes, err = ParseRoutes(c, n)
+	routes, err = parseRoutes(c, n)
 	assert.Nil(t, routes)
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the network attached to the certificate; route: 10.0.1.0/24, network: 10.0.0.0/24")
 	assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the network attached to the certificate; route: 10.0.1.0/24, network: 10.0.0.0/24")
 
 
@@ -85,7 +85,7 @@ func Test_ParseRoutes(t *testing.T) {
 		map[interface{}]interface{}{"mtu": "9000", "route": "10.0.0.0/29"},
 		map[interface{}]interface{}{"mtu": "9000", "route": "10.0.0.0/29"},
 		map[interface{}]interface{}{"mtu": "8000", "route": "10.0.0.1/32"},
 		map[interface{}]interface{}{"mtu": "8000", "route": "10.0.0.1/32"},
 	}}
 	}}
-	routes, err = ParseRoutes(c, n)
+	routes, err = parseRoutes(c, n)
 	assert.Nil(t, err)
 	assert.Nil(t, err)
 	assert.Len(t, routes, 2)
 	assert.Len(t, routes, 2)
 
 
@@ -106,37 +106,37 @@ func Test_ParseRoutes(t *testing.T) {
 	}
 	}
 }
 }
 
 
-func Test_ParseUnsafeRoutes(t *testing.T) {
+func Test_parseUnsafeRoutes(t *testing.T) {
 	l := test.NewLogger()
 	l := test.NewLogger()
 	c := config.NewC(l)
 	c := config.NewC(l)
 	_, n, _ := net.ParseCIDR("10.0.0.0/24")
 	_, n, _ := net.ParseCIDR("10.0.0.0/24")
 
 
 	// test no routes config
 	// test no routes config
-	routes, err := ParseUnsafeRoutes(c, n)
+	routes, err := parseUnsafeRoutes(c, n)
 	assert.Nil(t, err)
 	assert.Nil(t, err)
 	assert.Len(t, routes, 0)
 	assert.Len(t, routes, 0)
 
 
 	// not an array
 	// not an array
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": "hi"}
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": "hi"}
-	routes, err = ParseUnsafeRoutes(c, n)
+	routes, err = parseUnsafeRoutes(c, n)
 	assert.Nil(t, routes)
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "tun.unsafe_routes is not an array")
 	assert.EqualError(t, err, "tun.unsafe_routes is not an array")
 
 
 	// no routes
 	// no routes
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{}}
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{}}
-	routes, err = ParseUnsafeRoutes(c, n)
+	routes, err = parseUnsafeRoutes(c, n)
 	assert.Nil(t, err)
 	assert.Nil(t, err)
 	assert.Len(t, routes, 0)
 	assert.Len(t, routes, 0)
 
 
 	// weird route
 	// weird route
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{"asdf"}}
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{"asdf"}}
-	routes, err = ParseUnsafeRoutes(c, n)
+	routes, err = parseUnsafeRoutes(c, n)
 	assert.Nil(t, routes)
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "entry 1 in tun.unsafe_routes is invalid")
 	assert.EqualError(t, err, "entry 1 in tun.unsafe_routes is invalid")
 
 
 	// no via
 	// no via
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{}}}
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{}}}
-	routes, err = ParseUnsafeRoutes(c, n)
+	routes, err = parseUnsafeRoutes(c, n)
 	assert.Nil(t, routes)
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes is not present")
 	assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes is not present")
 
 
@@ -145,62 +145,62 @@ func Test_ParseUnsafeRoutes(t *testing.T) {
 		127, false, nil, 1.0, []string{"1", "2"},
 		127, false, nil, 1.0, []string{"1", "2"},
 	} {
 	} {
 		c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": invalidValue}}}
 		c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": invalidValue}}}
-		routes, err = ParseUnsafeRoutes(c, n)
+		routes, err = parseUnsafeRoutes(c, n)
 		assert.Nil(t, routes)
 		assert.Nil(t, routes)
 		assert.EqualError(t, err, fmt.Sprintf("entry 1.via in tun.unsafe_routes is not a string: found %T", invalidValue))
 		assert.EqualError(t, err, fmt.Sprintf("entry 1.via in tun.unsafe_routes is not a string: found %T", invalidValue))
 	}
 	}
 
 
 	// unparsable via
 	// unparsable via
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "via": "nope"}}}
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "via": "nope"}}}
-	routes, err = ParseUnsafeRoutes(c, n)
+	routes, err = parseUnsafeRoutes(c, n)
 	assert.Nil(t, routes)
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes failed to parse address: nope")
 	assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes failed to parse address: nope")
 
 
 	// missing route
 	// missing route
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500"}}}
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500"}}}
-	routes, err = ParseUnsafeRoutes(c, n)
+	routes, err = parseUnsafeRoutes(c, n)
 	assert.Nil(t, routes)
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes is not present")
 	assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes is not present")
 
 
 	// unparsable route
 	// unparsable route
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500", "route": "nope"}}}
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500", "route": "nope"}}}
-	routes, err = ParseUnsafeRoutes(c, n)
+	routes, err = parseUnsafeRoutes(c, n)
 	assert.Nil(t, routes)
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes failed to parse: invalid CIDR address: nope")
 	assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes failed to parse: invalid CIDR address: nope")
 
 
 	// within network range
 	// within network range
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.0.0/24"}}}
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.0.0/24"}}}
-	routes, err = ParseUnsafeRoutes(c, n)
+	routes, err = parseUnsafeRoutes(c, n)
 	assert.Nil(t, routes)
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes is contained within the network attached to the certificate; route: 10.0.0.0/24, network: 10.0.0.0/24")
 	assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes is contained within the network attached to the certificate; route: 10.0.0.0/24, network: 10.0.0.0/24")
 
 
 	// below network range
 	// below network range
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}}
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}}
-	routes, err = ParseUnsafeRoutes(c, n)
+	routes, err = parseUnsafeRoutes(c, n)
 	assert.Len(t, routes, 1)
 	assert.Len(t, routes, 1)
 	assert.Nil(t, err)
 	assert.Nil(t, err)
 
 
 	// above network range
 	// above network range
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.1.0/24"}}}
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.1.0/24"}}}
-	routes, err = ParseUnsafeRoutes(c, n)
+	routes, err = parseUnsafeRoutes(c, n)
 	assert.Len(t, routes, 1)
 	assert.Len(t, routes, 1)
 	assert.Nil(t, err)
 	assert.Nil(t, err)
 
 
 	// no mtu
 	// no mtu
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}}
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}}
-	routes, err = ParseUnsafeRoutes(c, n)
+	routes, err = parseUnsafeRoutes(c, n)
 	assert.Len(t, routes, 1)
 	assert.Len(t, routes, 1)
 	assert.Equal(t, DefaultMTU, routes[0].MTU)
 	assert.Equal(t, DefaultMTU, routes[0].MTU)
 
 
 	// bad mtu
 	// bad mtu
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "nope"}}}
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "nope"}}}
-	routes, err = ParseUnsafeRoutes(c, n)
+	routes, err = parseUnsafeRoutes(c, n)
 	assert.Nil(t, routes)
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax")
 	assert.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax")
 
 
 	// low mtu
 	// low mtu
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "499"}}}
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "499"}}}
-	routes, err = ParseUnsafeRoutes(c, n)
+	routes, err = parseUnsafeRoutes(c, n)
 	assert.Nil(t, routes)
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is below 500: 499")
 	assert.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is below 500: 499")
 
 
@@ -210,7 +210,7 @@ func Test_ParseUnsafeRoutes(t *testing.T) {
 		map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "8000", "route": "1.0.0.1/32"},
 		map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "8000", "route": "1.0.0.1/32"},
 		map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32"},
 		map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32"},
 	}}
 	}}
-	routes, err = ParseUnsafeRoutes(c, n)
+	routes, err = parseUnsafeRoutes(c, n)
 	assert.Nil(t, err)
 	assert.Nil(t, err)
 	assert.Len(t, routes, 3)
 	assert.Len(t, routes, 3)
 
 

+ 30 - 3
overlay/tun.go

@@ -1,15 +1,30 @@
 package overlay
 package overlay
 
 
 import (
 import (
+	"fmt"
 	"net"
 	"net"
+	"runtime"
 
 
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
+	"github.com/slackhq/nebula/cidr"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/util"
 )
 )
 
 
 const DefaultMTU = 1300
 const DefaultMTU = 1300
 
 
-func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routes, unsafeRoutes []Route, fd *int, routines int) (Device, error) {
+func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, fd *int, routines int) (Device, error) {
+	routes, err := parseRoutes(c, tunCidr)
+	if err != nil {
+		return nil, util.NewContextualError("Could not parse tun.routes", nil, err)
+	}
+
+	unsafeRoutes, err := parseUnsafeRoutes(c, tunCidr)
+	if err != nil {
+		return nil, util.NewContextualError("Could not parse tun.unsafe_routes", nil, err)
+	}
+	routes = append(routes, unsafeRoutes...)
+
 	switch {
 	switch {
 	case c.GetBool("tun.disabled", false):
 	case c.GetBool("tun.disabled", false):
 		tun := newDisabledTun(tunCidr, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l)
 		tun := newDisabledTun(tunCidr, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l)
@@ -22,7 +37,6 @@ func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, rout
 			tunCidr,
 			tunCidr,
 			c.GetInt("tun.mtu", DefaultMTU),
 			c.GetInt("tun.mtu", DefaultMTU),
 			routes,
 			routes,
-			unsafeRoutes,
 			c.GetInt("tun.tx_queue", 500),
 			c.GetInt("tun.tx_queue", 500),
 		)
 		)
 
 
@@ -33,9 +47,22 @@ func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, rout
 			tunCidr,
 			tunCidr,
 			c.GetInt("tun.mtu", DefaultMTU),
 			c.GetInt("tun.mtu", DefaultMTU),
 			routes,
 			routes,
-			unsafeRoutes,
 			c.GetInt("tun.tx_queue", 500),
 			c.GetInt("tun.tx_queue", 500),
 			routines > 1,
 			routines > 1,
 		)
 		)
 	}
 	}
 }
 }
+
+func makeCidrTree(routes []Route, allowMTU bool) (*cidr.Tree4, error) {
+	cidrTree := cidr.NewTree4()
+	for _, r := range routes {
+		if !allowMTU && r.MTU > 0 {
+			return nil, fmt.Errorf("route MTU is not supported in %s", runtime.GOOS)
+		}
+
+		if r.Via != nil {
+			cidrTree.AddCIDR(r.Cidr, r.Via)
+		}
+	}
+	return cidrTree, nil
+}

+ 16 - 17
overlay/tun_android.go

@@ -8,44 +8,43 @@ import (
 	"io"
 	"io"
 	"net"
 	"net"
 	"os"
 	"os"
+	"runtime"
 
 
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
+	"github.com/slackhq/nebula/iputil"
 	"golang.org/x/sys/unix"
 	"golang.org/x/sys/unix"
 )
 )
 
 
 type tun struct {
 type tun struct {
 	io.ReadWriteCloser
 	io.ReadWriteCloser
-	fd           int
-	Device       string
-	Cidr         *net.IPNet
-	MaxMTU       int
-	DefaultMTU   int
-	TXQueueLen   int
-	Routes       []Route
-	UnsafeRoutes []Route
-	l            *logrus.Logger
+	fd   int
+	Cidr *net.IPNet
+	l    *logrus.Logger
 }
 }
 
 
-func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []Route, unsafeRoutes []Route, txQueueLen int) (*tun, error) {
+func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, _ int, routes []Route, _ int) (*tun, error) {
+	if len(routes) > 0 {
+		return nil, fmt.Errorf("routes are not supported in %s", runtime.GOOS)
+	}
+
 	file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
 	file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
 
 
 	return &tun{
 	return &tun{
 		ReadWriteCloser: file,
 		ReadWriteCloser: file,
 		fd:              int(file.Fd()),
 		fd:              int(file.Fd()),
-		Device:          "android",
 		Cidr:            cidr,
 		Cidr:            cidr,
-		DefaultMTU:      defaultMTU,
-		TXQueueLen:      txQueueLen,
-		Routes:          routes,
-		UnsafeRoutes:    unsafeRoutes,
 		l:               l,
 		l:               l,
 	}, nil
 	}, nil
 }
 }
 
 
-func newTun(_ *logrus.Logger, _ string, _ *net.IPNet, _ int, _ []Route, _ []Route, _ int, _ bool) (*tun, error) {
+func newTun(_ *logrus.Logger, _ string, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (*tun, error) {
 	return nil, fmt.Errorf("newTun not supported in Android")
 	return nil, fmt.Errorf("newTun not supported in Android")
 }
 }
 
 
+func (t *tun) RouteFor(iputil.VpnIp) iputil.VpnIp {
+	return 0
+}
+
 func (t *tun) WriteRaw(b []byte) error {
 func (t *tun) WriteRaw(b []byte) error {
 	var nn int
 	var nn int
 	for {
 	for {
@@ -77,7 +76,7 @@ func (t *tun) CidrNet() *net.IPNet {
 }
 }
 
 
 func (t *tun) DeviceName() string {
 func (t *tun) DeviceName() string {
-	return t.Device
+	return "android"
 }
 }
 
 
 func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
 func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {

+ 30 - 13
overlay/tun_darwin.go

@@ -12,18 +12,20 @@ import (
 	"unsafe"
 	"unsafe"
 
 
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
+	"github.com/slackhq/nebula/cidr"
+	"github.com/slackhq/nebula/iputil"
 	netroute "golang.org/x/net/route"
 	netroute "golang.org/x/net/route"
 	"golang.org/x/sys/unix"
 	"golang.org/x/sys/unix"
 )
 )
 
 
 type tun struct {
 type tun struct {
 	io.ReadWriteCloser
 	io.ReadWriteCloser
-	Device       string
-	Cidr         *net.IPNet
-	DefaultMTU   int
-	TXQueueLen   int
-	UnsafeRoutes []Route
-	l            *logrus.Logger
+	Device     string
+	Cidr       *net.IPNet
+	DefaultMTU int
+	Routes     []Route
+	cidrTree   *cidr.Tree4
+	l          *logrus.Logger
 
 
 	// cache out buffer since we need to prepend 4 bytes for tun metadata
 	// cache out buffer since we need to prepend 4 bytes for tun metadata
 	out []byte
 	out []byte
@@ -74,9 +76,10 @@ type ifreqMTU struct {
 	pad  [8]byte
 	pad  [8]byte
 }
 }
 
 
-func newTun(l *logrus.Logger, name string, cidr *net.IPNet, defaultMTU int, routes []Route, unsafeRoutes []Route, txQueueLen int, _ bool) (*tun, error) {
-	if len(routes) > 0 {
-		return nil, fmt.Errorf("route MTU not supported in Darwin")
+func newTun(l *logrus.Logger, name string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool) (*tun, error) {
+	cidrTree, err := makeCidrTree(routes, false)
+	if err != nil {
+		return nil, err
 	}
 	}
 
 
 	ifIndex := -1
 	ifIndex := -1
@@ -151,8 +154,8 @@ func newTun(l *logrus.Logger, name string, cidr *net.IPNet, defaultMTU int, rout
 		Device:          name,
 		Device:          name,
 		Cidr:            cidr,
 		Cidr:            cidr,
 		DefaultMTU:      defaultMTU,
 		DefaultMTU:      defaultMTU,
-		TXQueueLen:      txQueueLen,
-		UnsafeRoutes:    unsafeRoutes,
+		Routes:          routes,
+		cidrTree:        cidrTree,
 		l:               l,
 		l:               l,
 	}
 	}
 
 
@@ -166,7 +169,7 @@ func (t *tun) deviceBytes() (o [16]byte) {
 	return
 	return
 }
 }
 
 
-func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ []Route, _ int) (*tun, error) {
+func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int) (*tun, error) {
 	return nil, fmt.Errorf("newTunFromFd not supported in Darwin")
 	return nil, fmt.Errorf("newTunFromFd not supported in Darwin")
 }
 }
 
 
@@ -279,7 +282,12 @@ func (t *tun) Activate() error {
 	}
 	}
 
 
 	// Unsafe path routes
 	// Unsafe path routes
-	for _, r := range t.UnsafeRoutes {
+	for _, r := range t.Routes {
+		if r.Via == nil {
+			// We don't allow route MTUs so only install routes with a via
+			continue
+		}
+
 		copy(routeAddr.IP[:], r.Cidr.IP.To4())
 		copy(routeAddr.IP[:], r.Cidr.IP.To4())
 		copy(maskAddr.IP[:], net.IP(r.Cidr.Mask).To4())
 		copy(maskAddr.IP[:], net.IP(r.Cidr.Mask).To4())
 
 
@@ -294,6 +302,15 @@ func (t *tun) Activate() error {
 	return nil
 	return nil
 }
 }
 
 
+func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
+	r := t.cidrTree.MostSpecificContains(ip)
+	if r != nil {
+		return r.(iputil.VpnIp)
+	}
+
+	return 0
+}
+
 // Get the LinkAddr for the interface of the given name
 // Get the LinkAddr for the interface of the given name
 // TODO: Is there an easier way to fetch this when we create the interface?
 // TODO: Is there an easier way to fetch this when we create the interface?
 // Maybe SIOCGIFINDEX? but this doesn't appear to exist in the darwin headers.
 // Maybe SIOCGIFINDEX? but this doesn't appear to exist in the darwin headers.

+ 5 - 0
overlay/tun_disabled.go

@@ -9,6 +9,7 @@ import (
 
 
 	"github.com/rcrowley/go-metrics"
 	"github.com/rcrowley/go-metrics"
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
+	"github.com/slackhq/nebula/iputil"
 )
 )
 
 
 type disabledTun struct {
 type disabledTun struct {
@@ -43,6 +44,10 @@ func (*disabledTun) Activate() error {
 	return nil
 	return nil
 }
 }
 
 
+func (*disabledTun) RouteFor(iputil.VpnIp) iputil.VpnIp {
+	return 0
+}
+
 func (t *disabledTun) CidrNet() *net.IPNet {
 func (t *disabledTun) CidrNet() *net.IPNet {
 	return t.cidr
 	return t.cidr
 }
 }

+ 35 - 15
overlay/tun_freebsd.go

@@ -14,16 +14,19 @@ import (
 	"strings"
 	"strings"
 
 
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
+	"github.com/slackhq/nebula/cidr"
+	"github.com/slackhq/nebula/iputil"
 )
 )
 
 
 var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
 var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
 
 
 type tun struct {
 type tun struct {
-	Device       string
-	Cidr         *net.IPNet
-	MTU          int
-	UnsafeRoutes []Route
-	l            *logrus.Logger
+	Device   string
+	Cidr     *net.IPNet
+	MTU      int
+	Routes   []Route
+	cidrTree *cidr.Tree4
+	l        *logrus.Logger
 
 
 	io.ReadWriteCloser
 	io.ReadWriteCloser
 }
 }
@@ -35,14 +38,16 @@ func (t *tun) Close() error {
 	return nil
 	return nil
 }
 }
 
 
-func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ []Route, _ int) (*tun, error) {
+func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int) (*tun, error) {
 	return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD")
 	return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD")
 }
 }
 
 
-func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, unsafeRoutes []Route, _ int, _ bool) (*tun, error) {
-	if len(routes) > 0 {
-		return nil, fmt.Errorf("route MTU not supported in FreeBSD")
+func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool) (*tun, error) {
+	cidrTree, err := makeCidrTree(routes, false)
+	if err != nil {
+		return nil, err
 	}
 	}
+
 	if strings.HasPrefix(deviceName, "/dev/") {
 	if strings.HasPrefix(deviceName, "/dev/") {
 		deviceName = strings.TrimPrefix(deviceName, "/dev/")
 		deviceName = strings.TrimPrefix(deviceName, "/dev/")
 	}
 	}
@@ -50,11 +55,12 @@ func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int
 		return nil, fmt.Errorf("tun.dev must match `tun[0-9]+`")
 		return nil, fmt.Errorf("tun.dev must match `tun[0-9]+`")
 	}
 	}
 	return &tun{
 	return &tun{
-		Device:       deviceName,
-		Cidr:         cidr,
-		MTU:          defaultMTU,
-		UnsafeRoutes: unsafeRoutes,
-		l:            l,
+		Device:   deviceName,
+		Cidr:     cidr,
+		MTU:      defaultMTU,
+		Routes:   routes,
+		cidrTree: cidrTree,
+		l:        l,
 	}, nil
 	}, nil
 }
 }
 
 
@@ -79,7 +85,12 @@ func (t *tun) Activate() error {
 		return fmt.Errorf("failed to run 'ifconfig': %s", err)
 		return fmt.Errorf("failed to run 'ifconfig': %s", err)
 	}
 	}
 	// Unsafe path routes
 	// Unsafe path routes
-	for _, r := range t.UnsafeRoutes {
+	for _, r := range t.Routes {
+		if r.Via == nil {
+			// We don't allow route MTUs so only install routes with a via
+			continue
+		}
+
 		t.l.Debug("command: route", "-n", "add", "-net", r.Cidr.String(), "-interface", t.Device)
 		t.l.Debug("command: route", "-n", "add", "-net", r.Cidr.String(), "-interface", t.Device)
 		if err = exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), "-interface", t.Device).Run(); err != nil {
 		if err = exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), "-interface", t.Device).Run(); err != nil {
 			return fmt.Errorf("failed to run 'route add' for unsafe_route %s: %s", r.Cidr.String(), err)
 			return fmt.Errorf("failed to run 'route add' for unsafe_route %s: %s", r.Cidr.String(), err)
@@ -89,6 +100,15 @@ func (t *tun) Activate() error {
 	return nil
 	return nil
 }
 }
 
 
+func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
+	r := t.cidrTree.MostSpecificContains(ip)
+	if r != nil {
+		return r.(iputil.VpnIp)
+	}
+
+	return 0
+}
+
 func (t *tun) CidrNet() *net.IPNet {
 func (t *tun) CidrNet() *net.IPNet {
 	return t.Cidr
 	return t.Cidr
 }
 }

+ 11 - 7
overlay/tun_ios.go

@@ -9,25 +9,26 @@ import (
 	"io"
 	"io"
 	"net"
 	"net"
 	"os"
 	"os"
+	"runtime"
 	"sync"
 	"sync"
 	"syscall"
 	"syscall"
 
 
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
+	"github.com/slackhq/nebula/iputil"
 )
 )
 
 
 type tun struct {
 type tun struct {
 	io.ReadWriteCloser
 	io.ReadWriteCloser
-	Device string
-	Cidr   *net.IPNet
+	Cidr *net.IPNet
 }
 }
 
 
-func newTun(_ *logrus.Logger, _ string, _ *net.IPNet, _ int, _ []Route, _ []Route, _ int, _ bool) (*tun, error) {
+func newTun(_ *logrus.Logger, _ string, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (*tun, error) {
 	return nil, fmt.Errorf("newTun not supported in iOS")
 	return nil, fmt.Errorf("newTun not supported in iOS")
 }
 }
 
 
-func newTunFromFd(_ *logrus.Logger, deviceFd int, cidr *net.IPNet, _ int, routes []Route, _ []Route, _ int) (*tun, error) {
+func newTunFromFd(_ *logrus.Logger, deviceFd int, cidr *net.IPNet, _ int, routes []Route, _ int) (*tun, error) {
 	if len(routes) > 0 {
 	if len(routes) > 0 {
-		return nil, fmt.Errorf("route MTU not supported in Darwin")
+		return nil, fmt.Errorf("routes are not supported in %s", runtime.GOOS)
 	}
 	}
 
 
 	file := os.NewFile(uintptr(deviceFd), "/dev/tun")
 	file := os.NewFile(uintptr(deviceFd), "/dev/tun")
@@ -42,6 +43,10 @@ func (t *tun) Activate() error {
 	return nil
 	return nil
 }
 }
 
 
+func (t *tun) RouteFor(iputil.VpnIp) iputil.VpnIp {
+	return 0
+}
+
 func (t *tun) WriteRaw(b []byte) error {
 func (t *tun) WriteRaw(b []byte) error {
 	_, err := t.Write(b)
 	_, err := t.Write(b)
 	return err
 	return err
@@ -73,7 +78,6 @@ func (tr *tunReadCloser) Read(to []byte) (int, error) {
 }
 }
 
 
 func (tr *tunReadCloser) Write(from []byte) (int, error) {
 func (tr *tunReadCloser) Write(from []byte) (int, error) {
-
 	if len(from) == 0 {
 	if len(from) == 0 {
 		return 0, syscall.EIO
 		return 0, syscall.EIO
 	}
 	}
@@ -111,7 +115,7 @@ func (t *tun) CidrNet() *net.IPNet {
 }
 }
 
 
 func (t *tun) DeviceName() string {
 func (t *tun) DeviceName() string {
-	return t.Device
+	return "iOS"
 }
 }
 
 
 func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
 func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {

+ 35 - 28
overlay/tun_linux.go

@@ -12,21 +12,23 @@ import (
 	"unsafe"
 	"unsafe"
 
 
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
+	"github.com/slackhq/nebula/cidr"
+	"github.com/slackhq/nebula/iputil"
 	"github.com/vishvananda/netlink"
 	"github.com/vishvananda/netlink"
 	"golang.org/x/sys/unix"
 	"golang.org/x/sys/unix"
 )
 )
 
 
 type tun struct {
 type tun struct {
 	io.ReadWriteCloser
 	io.ReadWriteCloser
-	fd           int
-	Device       string
-	Cidr         *net.IPNet
-	MaxMTU       int
-	DefaultMTU   int
-	TXQueueLen   int
-	Routes       []Route
-	UnsafeRoutes []Route
-	l            *logrus.Logger
+	fd         int
+	Device     string
+	Cidr       *net.IPNet
+	MaxMTU     int
+	DefaultMTU int
+	TXQueueLen int
+	Routes     []Route
+	cidrTree   *cidr.Tree4
+	l          *logrus.Logger
 }
 }
 
 
 type ifReq struct {
 type ifReq struct {
@@ -61,7 +63,11 @@ type ifreqQLEN struct {
 	pad   [8]byte
 	pad   [8]byte
 }
 }
 
 
-func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []Route, unsafeRoutes []Route, txQueueLen int) (*tun, error) {
+func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []Route, txQueueLen int) (*tun, error) {
+	cidrTree, err := makeCidrTree(routes, true)
+	if err != nil {
+		return nil, err
+	}
 
 
 	file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
 	file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
 
 
@@ -73,12 +79,12 @@ func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU in
 		DefaultMTU:      defaultMTU,
 		DefaultMTU:      defaultMTU,
 		TXQueueLen:      txQueueLen,
 		TXQueueLen:      txQueueLen,
 		Routes:          routes,
 		Routes:          routes,
-		UnsafeRoutes:    unsafeRoutes,
+		cidrTree:        cidrTree,
 		l:               l,
 		l:               l,
 	}, nil
 	}, nil
 }
 }
 
 
-func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, unsafeRoutes []Route, txQueueLen int, multiqueue bool) (*tun, error) {
+func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, txQueueLen int, multiqueue bool) (*tun, error) {
 	fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
 	fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
@@ -104,6 +110,11 @@ func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int
 		}
 		}
 	}
 	}
 
 
+	cidrTree, err := makeCidrTree(routes, true)
+	if err != nil {
+		return nil, err
+	}
+
 	return &tun{
 	return &tun{
 		ReadWriteCloser: file,
 		ReadWriteCloser: file,
 		fd:              int(file.Fd()),
 		fd:              int(file.Fd()),
@@ -113,7 +124,7 @@ func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int
 		DefaultMTU:      defaultMTU,
 		DefaultMTU:      defaultMTU,
 		TXQueueLen:      txQueueLen,
 		TXQueueLen:      txQueueLen,
 		Routes:          routes,
 		Routes:          routes,
-		UnsafeRoutes:    unsafeRoutes,
+		cidrTree:        cidrTree,
 		l:               l,
 		l:               l,
 	}, nil
 	}, nil
 }
 }
@@ -136,6 +147,15 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
 	return file, nil
 	return file, nil
 }
 }
 
 
+func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
+	r := t.cidrTree.MostSpecificContains(ip)
+	if r != nil {
+		return r.(iputil.VpnIp)
+	}
+
+	return 0
+}
+
 func (t *tun) WriteRaw(b []byte) error {
 func (t *tun) WriteRaw(b []byte) error {
 	var nn int
 	var nn int
 	for {
 	for {
@@ -266,21 +286,8 @@ func (t tun) Activate() error {
 			Scope:     unix.RT_SCOPE_LINK,
 			Scope:     unix.RT_SCOPE_LINK,
 		}
 		}
 
 
-		err = netlink.RouteAdd(&nr)
-		if err != nil {
-			return fmt.Errorf("failed to set mtu %v on route %v; %v", r.MTU, r.Cidr, err)
-		}
-	}
-
-	// Unsafe path routes
-	for _, r := range t.UnsafeRoutes {
-		nr := netlink.Route{
-			LinkIndex: link.Attrs().Index,
-			Dst:       r.Cidr,
-			MTU:       r.MTU,
-			Priority:  r.Metric,
-			AdvMSS:    t.advMSS(r),
-			Scope:     unix.RT_SCOPE_LINK,
+		if r.Metric > 0 {
+			nr.Priority = r.Metric
 		}
 		}
 
 
 		err = netlink.RouteAdd(&nr)
 		err = netlink.RouteAdd(&nr)

+ 30 - 14
overlay/tun_tester.go

@@ -9,32 +9,39 @@ import (
 	"net"
 	"net"
 
 
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
+	"github.com/slackhq/nebula/cidr"
+	"github.com/slackhq/nebula/iputil"
 )
 )
 
 
 type TestTun struct {
 type TestTun struct {
-	Device       string
-	Cidr         *net.IPNet
-	MTU          int
-	UnsafeRoutes []Route
-	l            *logrus.Logger
+	Device   string
+	Cidr     *net.IPNet
+	Routes   []Route
+	cidrTree *cidr.Tree4
+	l        *logrus.Logger
 
 
 	rxPackets chan []byte // Packets to receive into nebula
 	rxPackets chan []byte // Packets to receive into nebula
 	TxPackets chan []byte // Packets transmitted outside by nebula
 	TxPackets chan []byte // Packets transmitted outside by nebula
 }
 }
 
 
-func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, _ []Route, unsafeRoutes []Route, _ int, _ bool) (*TestTun, error) {
+func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, _ int, routes []Route, _ int, _ bool) (*TestTun, error) {
+	cidrTree, err := makeCidrTree(routes, false)
+	if err != nil {
+		return nil, err
+	}
+
 	return &TestTun{
 	return &TestTun{
-		Device:       deviceName,
-		Cidr:         cidr,
-		MTU:          defaultMTU,
-		UnsafeRoutes: unsafeRoutes,
-		l:            l,
-		rxPackets:    make(chan []byte, 1),
-		TxPackets:    make(chan []byte, 1),
+		Device:    deviceName,
+		Cidr:      cidr,
+		Routes:    routes,
+		cidrTree:  cidrTree,
+		l:         l,
+		rxPackets: make(chan []byte, 1),
+		TxPackets: make(chan []byte, 1),
 	}, nil
 	}, nil
 }
 }
 
 
-func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ []Route, _ int) (*TestTun, error) {
+func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int) (*TestTun, error) {
 	return nil, fmt.Errorf("newTunFromFd not supported")
 	return nil, fmt.Errorf("newTunFromFd not supported")
 }
 }
 
 
@@ -66,6 +73,15 @@ func (t *TestTun) Get(block bool) []byte {
 // Below this is boilerplate implementation to make nebula actually work
 // Below this is boilerplate implementation to make nebula actually work
 //********************************************************************************************************************//
 //********************************************************************************************************************//
 
 
+func (t *TestTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
+	r := t.cidrTree.MostSpecificContains(ip)
+	if r != nil {
+		return r.(iputil.VpnIp)
+	}
+
+	return 0
+}
+
 func (t *TestTun) Activate() error {
 func (t *TestTun) Activate() error {
 	return nil
 	return nil
 }
 }

+ 32 - 9
overlay/tun_water_windows.go

@@ -7,24 +7,33 @@ import (
 	"os/exec"
 	"os/exec"
 	"strconv"
 	"strconv"
 
 
+	"github.com/slackhq/nebula/cidr"
+	"github.com/slackhq/nebula/iputil"
 	"github.com/songgao/water"
 	"github.com/songgao/water"
 )
 )
 
 
 type waterTun struct {
 type waterTun struct {
-	Device       string
-	Cidr         *net.IPNet
-	MTU          int
-	UnsafeRoutes []Route
+	Device   string
+	Cidr     *net.IPNet
+	MTU      int
+	Routes   []Route
+	cidrTree *cidr.Tree4
 
 
 	*water.Interface
 	*water.Interface
 }
 }
 
 
-func newWaterTun(cidr *net.IPNet, defaultMTU int, unsafeRoutes []Route) (*waterTun, error) {
+func newWaterTun(cidr *net.IPNet, defaultMTU int, routes []Route) (*waterTun, error) {
+	cidrTree, err := makeCidrTree(routes, false)
+	if err != nil {
+		return nil, err
+	}
+
 	// NOTE: You cannot set the deviceName under Windows, so you must check tun.Device after calling .Activate()
 	// NOTE: You cannot set the deviceName under Windows, so you must check tun.Device after calling .Activate()
 	return &waterTun{
 	return &waterTun{
-		Cidr:         cidr,
-		MTU:          defaultMTU,
-		UnsafeRoutes: unsafeRoutes,
+		Cidr:     cidr,
+		MTU:      defaultMTU,
+		Routes:   routes,
+		cidrTree: cidrTree,
 	}, nil
 	}, nil
 }
 }
 
 
@@ -69,7 +78,12 @@ func (t *waterTun) Activate() error {
 		return fmt.Errorf("failed to find interface named %s: %v", t.Device, err)
 		return fmt.Errorf("failed to find interface named %s: %v", t.Device, err)
 	}
 	}
 
 
-	for _, r := range t.UnsafeRoutes {
+	for _, r := range t.Routes {
+		if r.Via == nil {
+			// We don't allow route MTUs so only install routes with a via
+			continue
+		}
+
 		err = exec.Command(
 		err = exec.Command(
 			"C:\\Windows\\System32\\route.exe", "add", r.Cidr.String(), r.Via.String(), "IF", strconv.Itoa(iface.Index), "METRIC", strconv.Itoa(r.Metric),
 			"C:\\Windows\\System32\\route.exe", "add", r.Cidr.String(), r.Via.String(), "IF", strconv.Itoa(iface.Index), "METRIC", strconv.Itoa(r.Metric),
 		).Run()
 		).Run()
@@ -81,6 +95,15 @@ func (t *waterTun) Activate() error {
 	return nil
 	return nil
 }
 }
 
 
+func (t *waterTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
+	r := t.cidrTree.MostSpecificContains(ip)
+	if r != nil {
+		return r.(iputil.VpnIp)
+	}
+
+	return 0
+}
+
 func (t *waterTun) CidrNet() *net.IPNet {
 func (t *waterTun) CidrNet() *net.IPNet {
 	return t.Cidr
 	return t.Cidr
 }
 }

+ 4 - 4
overlay/tun_windows.go

@@ -14,11 +14,11 @@ import (
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
 )
 )
 
 
-func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ []Route, _ int) (Device, error) {
+func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int) (Device, error) {
 	return nil, fmt.Errorf("newTunFromFd not supported in Windows")
 	return nil, fmt.Errorf("newTunFromFd not supported in Windows")
 }
 }
 
 
-func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, unsafeRoutes []Route, _ int, _ bool) (Device, error) {
+func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool) (Device, error) {
 	if len(routes) > 0 {
 	if len(routes) > 0 {
 		return nil, fmt.Errorf("route MTU not supported in Windows")
 		return nil, fmt.Errorf("route MTU not supported in Windows")
 	}
 	}
@@ -30,14 +30,14 @@ func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int
 	}
 	}
 
 
 	if useWintun {
 	if useWintun {
-		device, err := newWinTun(deviceName, cidr, defaultMTU, unsafeRoutes)
+		device, err := newWinTun(deviceName, cidr, defaultMTU, routes)
 		if err != nil {
 		if err != nil {
 			return nil, fmt.Errorf("create Wintun interface failed, %w", err)
 			return nil, fmt.Errorf("create Wintun interface failed, %w", err)
 		}
 		}
 		return device, nil
 		return device, nil
 	}
 	}
 
 
-	device, err := newWaterTun(cidr, defaultMTU, unsafeRoutes)
+	device, err := newWaterTun(cidr, defaultMTU, routes)
 	if err != nil {
 	if err != nil {
 		return nil, fmt.Errorf("create wintap driver failed, %w", err)
 		return nil, fmt.Errorf("create wintap driver failed, %w", err)
 	}
 	}

+ 35 - 12
overlay/tun_wintun_windows.go

@@ -7,6 +7,8 @@ import (
 	"net"
 	"net"
 	"unsafe"
 	"unsafe"
 
 
+	"github.com/slackhq/nebula/cidr"
+	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/wintun"
 	"github.com/slackhq/nebula/wintun"
 	"golang.org/x/sys/windows"
 	"golang.org/x/sys/windows"
 	"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
 	"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
@@ -15,10 +17,11 @@ import (
 const tunGUIDLabel = "Fixed Nebula Windows GUID v1"
 const tunGUIDLabel = "Fixed Nebula Windows GUID v1"
 
 
 type winTun struct {
 type winTun struct {
-	Device       string
-	Cidr         *net.IPNet
-	MTU          int
-	UnsafeRoutes []Route
+	Device   string
+	Cidr     *net.IPNet
+	MTU      int
+	Routes   []Route
+	cidrTree *cidr.Tree4
 
 
 	tun *wintun.NativeTun
 	tun *wintun.NativeTun
 }
 }
@@ -42,7 +45,7 @@ func generateGUIDByDeviceName(name string) (*windows.GUID, error) {
 	return (*windows.GUID)(unsafe.Pointer(&sum[0])), nil
 	return (*windows.GUID)(unsafe.Pointer(&sum[0])), nil
 }
 }
 
 
-func newWinTun(deviceName string, cidr *net.IPNet, defaultMTU int, unsafeRoutes []Route) (*winTun, error) {
+func newWinTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route) (*winTun, error) {
 	guid, err := generateGUIDByDeviceName(deviceName)
 	guid, err := generateGUIDByDeviceName(deviceName)
 	if err != nil {
 	if err != nil {
 		return nil, fmt.Errorf("generate GUID failed: %w", err)
 		return nil, fmt.Errorf("generate GUID failed: %w", err)
@@ -53,11 +56,17 @@ func newWinTun(deviceName string, cidr *net.IPNet, defaultMTU int, unsafeRoutes
 		return nil, fmt.Errorf("create TUN device failed: %w", err)
 		return nil, fmt.Errorf("create TUN device failed: %w", err)
 	}
 	}
 
 
+	cidrTree, err := makeCidrTree(routes, false)
+	if err != nil {
+		return nil, err
+	}
+
 	return &winTun{
 	return &winTun{
-		Device:       deviceName,
-		Cidr:         cidr,
-		MTU:          defaultMTU,
-		UnsafeRoutes: unsafeRoutes,
+		Device:   deviceName,
+		Cidr:     cidr,
+		MTU:      defaultMTU,
+		Routes:   routes,
+		cidrTree: cidrTree,
 
 
 		tun: tunDevice.(*wintun.NativeTun),
 		tun: tunDevice.(*wintun.NativeTun),
 	}, nil
 	}, nil
@@ -71,11 +80,16 @@ func (t *winTun) Activate() error {
 	}
 	}
 
 
 	foundDefault4 := false
 	foundDefault4 := false
-	routes := make([]*winipcfg.RouteData, 0, len(t.UnsafeRoutes)+1)
+	routes := make([]*winipcfg.RouteData, 0, len(t.Routes)+1)
+
+	for _, r := range t.Routes {
+		if r.Via == nil {
+			// We don't allow route MTUs so only install routes with a via
+			continue
+		}
 
 
-	for _, r := range t.UnsafeRoutes {
 		if !foundDefault4 {
 		if !foundDefault4 {
-			if cidr, bits := r.Cidr.Mask.Size(); cidr == 0 && bits != 0 {
+			if ones, bits := r.Cidr.Mask.Size(); ones == 0 && bits != 0 {
 				foundDefault4 = true
 				foundDefault4 = true
 			}
 			}
 		}
 		}
@@ -110,6 +124,15 @@ func (t *winTun) Activate() error {
 	return nil
 	return nil
 }
 }
 
 
+func (t *winTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
+	r := t.cidrTree.MostSpecificContains(ip)
+	if r != nil {
+		return r.(iputil.VpnIp)
+	}
+
+	return 0
+}
+
 func (t *winTun) CidrNet() *net.IPNet {
 func (t *winTun) CidrNet() *net.IPNet {
 	return t.Cidr
 	return t.Cidr
 }
 }