Browse Source

overlay: fix tun.RouteFor getting *net.IP (#595)

tun.RouteFor expects the routeTree to have an iputil.VpnIp inside of it
instead of a *net.IP.
Wade Simmons 3 years ago
parent
commit
f60ed2b36d
3 changed files with 40 additions and 4 deletions
  1. 6 3
      overlay/route.go
  2. 33 0
      overlay/route_test.go
  3. 1 1
      overlay/tun_wintun_windows.go

+ 6 - 3
overlay/route.go

@@ -9,13 +9,14 @@ import (
 
 
 	"github.com/slackhq/nebula/cidr"
 	"github.com/slackhq/nebula/cidr"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/iputil"
 )
 )
 
 
 type Route struct {
 type Route struct {
 	MTU    int
 	MTU    int
 	Metric int
 	Metric int
 	Cidr   *net.IPNet
 	Cidr   *net.IPNet
-	Via    *net.IP
+	Via    *iputil.VpnIp
 }
 }
 
 
 func makeRouteTree(routes []Route, allowMTU bool) (*cidr.Tree4, error) {
 func makeRouteTree(routes []Route, allowMTU bool) (*cidr.Tree4, error) {
@@ -26,7 +27,7 @@ func makeRouteTree(routes []Route, allowMTU bool) (*cidr.Tree4, error) {
 		}
 		}
 
 
 		if r.Via != nil {
 		if r.Via != nil {
-			routeTree.AddCIDR(r.Cidr, r.Via)
+			routeTree.AddCIDR(r.Cidr, *r.Via)
 		}
 		}
 	}
 	}
 	return routeTree, nil
 	return routeTree, nil
@@ -180,8 +181,10 @@ func parseUnsafeRoutes(c *config.C, network *net.IPNet) ([]Route, error) {
 			return nil, fmt.Errorf("entry %v.route in tun.unsafe_routes is not present", i+1)
 			return nil, fmt.Errorf("entry %v.route in tun.unsafe_routes is not present", i+1)
 		}
 		}
 
 
+		viaVpnIp := iputil.Ip2VpnIp(nVia)
+
 		r := Route{
 		r := Route{
-			Via:    &nVia,
+			Via:    &viaVpnIp,
 			MTU:    mtu,
 			MTU:    mtu,
 			Metric: metric,
 			Metric: metric,
 		}
 		}

+ 33 - 0
overlay/route_test.go

@@ -6,6 +6,7 @@ import (
 	"testing"
 	"testing"
 
 
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/test"
 	"github.com/slackhq/nebula/test"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/assert"
 )
 )
@@ -235,3 +236,35 @@ func Test_parseUnsafeRoutes(t *testing.T) {
 		t.Fatal("Did not see both unsafe_routes")
 		t.Fatal("Did not see both unsafe_routes")
 	}
 	}
 }
 }
+
+func Test_makeRouteTree(t *testing.T) {
+	l := test.NewLogger()
+	c := config.NewC(l)
+	_, n, _ := net.ParseCIDR("10.0.0.0/24")
+
+	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{
+		map[interface{}]interface{}{"via": "192.168.0.1", "route": "1.0.0.0/28"},
+		map[interface{}]interface{}{"via": "192.168.0.2", "route": "1.0.0.1/32"},
+	}}
+	routes, err := parseUnsafeRoutes(c, n)
+	assert.NoError(t, err)
+	assert.Len(t, routes, 2)
+	routeTree, err := makeRouteTree(routes, true)
+	assert.NoError(t, err)
+
+	ip := iputil.Ip2VpnIp(net.ParseIP("1.0.0.2"))
+	r := routeTree.MostSpecificContains(ip)
+	assert.NotNil(t, r)
+	assert.IsType(t, iputil.VpnIp(0), r)
+	assert.EqualValues(t, iputil.Ip2VpnIp(net.ParseIP("192.168.0.1")), r)
+
+	ip = iputil.Ip2VpnIp(net.ParseIP("1.0.0.1"))
+	r = routeTree.MostSpecificContains(ip)
+	assert.NotNil(t, r)
+	assert.IsType(t, iputil.VpnIp(0), r)
+	assert.EqualValues(t, iputil.Ip2VpnIp(net.ParseIP("192.168.0.2")), r)
+
+	ip = iputil.Ip2VpnIp(net.ParseIP("1.1.0.1"))
+	r = routeTree.MostSpecificContains(ip)
+	assert.Nil(t, r)
+}

+ 1 - 1
overlay/tun_wintun_windows.go

@@ -97,7 +97,7 @@ func (t *winTun) Activate() error {
 		// Add our unsafe route
 		// Add our unsafe route
 		routes = append(routes, &winipcfg.RouteData{
 		routes = append(routes, &winipcfg.RouteData{
 			Destination: *r.Cidr,
 			Destination: *r.Cidr,
-			NextHop:     *r.Via,
+			NextHop:     r.Via.ToIP(),
 			Metric:      uint32(r.Metric),
 			Metric:      uint32(r.Metric),
 		})
 		})
 	}
 	}