Browse Source

Implement ECMP for unsafe_routes (#1332)

dioss-Machiel 4 months ago
parent
commit
f86953ca56

+ 22 - 1
examples/config.yml

@@ -239,7 +239,28 @@ tun:
 
   # Unsafe routes allows you to route traffic over nebula to non-nebula nodes
   # Unsafe routes should be avoided unless you have hosts/services that cannot run nebula
-  # NOTE: The nebula certificate of the "via" node *MUST* have the "route" defined as a subnet in its certificate
+  # Supports weighted ECMP if you define a list of gateways; this can be used for load balancing or redundancy to hosts outside of nebula
+  # NOTES:
+  # * Only a single gateway will appear in the system routing table if you are not on linux
+  # * If a gateway is not reachable through the overlay, another gateway will be selected to send the traffic through, ignoring weights
+  #
+  # unsafe_routes:
+  # # Gateways defined without a weight default to a weight of 1; this will balance traffic equally between the three gateways
+  # - route: 192.168.87.0/24
+  #   via:
+  #     - gateway: 10.0.0.1
+  #     - gateway: 10.0.0.2
+  #     - gateway: 10.0.0.3
+  # # Multiple gateways with weights; traffic will be balanced according to the weights
+  # - route: 192.168.87.0/24
+  #   via:
+  #     - gateway: 10.0.0.1
+  #       weight: 10
+  #     - gateway: 10.0.0.2
+  #       weight: 5
+  #
+  # NOTE: The nebula certificate of the "via" node(s) *MUST* have the "route" defined as a subnet in its certificate
+  # `via`: single node or list of gateways to use for this route
   # `mtu`: will default to tun mtu if this option is not specified
   # `metric`: will default to 0 if this option is not specified
   # `install`: will default to true, controls whether this route is installed in the systems routing table.

+ 83 - 10
inside.go

@@ -8,6 +8,7 @@ import (
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/noiseutil"
+	"github.com/slackhq/nebula/routing"
 )
 
 func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) {
@@ -49,7 +50,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
 		return
 	}
 
-	hostinfo, ready := f.getOrHandshake(fwPacket.RemoteAddr, func(hh *HandshakeHostInfo) {
+	hostinfo, ready := f.getOrHandshakeConsiderRouting(fwPacket, func(hh *HandshakeHostInfo) {
 		hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics)
 	})
 
@@ -121,22 +122,94 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo *
 	f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, q)
 }
 
+// Handshake will attempt to initiate a tunnel with the provided vpn address if it is within our vpn networks.
+// This is a no-op if the tunnel is already established or being established.
+// Unsafe routes are not consulted and no cache callback is supplied, so nothing is queued for the peer.
 func (f *Interface) Handshake(vpnAddr netip.Addr) {
-	f.getOrHandshake(vpnAddr, nil)
+	f.getOrHandshakeNoRouting(vpnAddr, nil)
 }
 
-// getOrHandshake returns nil if the vpnAddr is not routable.
+// getOrHandshakeNoRouting returns nil if the vpnAddr is not within our vpn networks; unsafe routes are never consulted.
 // If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel
-func (f *Interface) getOrHandshake(vpnAddr netip.Addr, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) {
+func (f *Interface) getOrHandshakeNoRouting(vpnAddr netip.Addr, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) {
 	_, found := f.myVpnNetworksTable.Lookup(vpnAddr)
-	if !found {
-		vpnAddr = f.inside.RouteFor(vpnAddr)
-		if !vpnAddr.IsValid() {
-			return nil, false
+	if found {
+		return f.handshakeManager.GetOrHandshake(vpnAddr, cacheCallback)
+	}
+
+	return nil, false
+}
+
+// getOrHandshakeConsiderRouting will try to find the HostInfo to handle this packet, starting a handshake if necessary.
+// If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel.
+//
+// The destination is first tried as a host inside our vpn networks. If that yields nothing, the gateways
+// returned by f.inside.RoutesFor are used: no gateway returns (nil, false), a single gateway is used
+// directly, and several gateways are balanced with routing.BalancePacket with a fallback to any other
+// gateway that is ready.
+//
+// cacheCallback is handed to the handshake manager on the direct and single gateway paths. On the multi
+// gateway path it is invoked here, with the HandshakeHostInfo captured for the originally chosen gateway,
+// and only when no gateway is ready. A nil cacheCallback is tolerated on that path.
+func (f *Interface) getOrHandshakeConsiderRouting(fwPacket *firewall.Packet, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) {
+
+	destinationAddr := fwPacket.RemoteAddr
+
+	hostinfo, ready := f.getOrHandshakeNoRouting(destinationAddr, cacheCallback)
+
+	// Host is inside the mesh, no routing required
+	if hostinfo != nil {
+		return hostinfo, ready
+	}
+
+	gateways := f.inside.RoutesFor(destinationAddr)
+
+	switch len(gateways) {
+	case 0:
+		return nil, false
+	case 1:
+		// Single gateway route
+		return f.handshakeManager.GetOrHandshake(gateways[0].Addr(), cacheCallback)
+	default:
+		// Multi gateway route, perform ECMP categorization
+		gatewayAddr, balancingOk := routing.BalancePacket(fwPacket, gateways)
+
+		if !balancingOk {
+			// This happens if the gateway buckets were not calculated, this _should_ never happen
+			f.l.Error("Gateway buckets not calculated, fallback from ECMP to random routing. Please report this bug.")
 		}
+
+		// Capture the HandshakeHostInfo instead of caching right away: the packet is only cached
+		// once we know that none of the gateways can carry it.
+		var handshakeInfoForChosenGateway *HandshakeHostInfo
+		var hhReceiver = func(hh *HandshakeHostInfo) {
+			handshakeInfoForChosenGateway = hh
+		}
+
+		// Store the handshakeHostInfo for later.
+		// If this node is not reachable we will attempt other nodes, if none are reachable we will
+		// cache the packet for this gateway.
+		if hostinfo, ready = f.handshakeManager.GetOrHandshake(gatewayAddr, hhReceiver); ready {
+			return hostinfo, true
+		}
+
+		// It appears the selected gateway cannot be reached, find another gateway to fall back on.
+		// The current implementation breaks ECMP but that seems better than no connectivity.
+		// If ECMP is also required when a gateway is down then connectivity status
+		// for each gateway needs to be kept and the weights recalculated when they go up or down.
+		// This would also need to interact with unsafe_route updates through reloading the config or
+		// use of the use_system_route_table option
+
+		if f.l.Level >= logrus.DebugLevel {
+			f.l.WithField("destination", destinationAddr).
+				WithField("originalGateway", gatewayAddr).
+				Debugln("Calculated gateway for ECMP not available, attempting other gateways")
+		}
+
+		for i := range gateways {
+			// Skip the gateway that failed previously
+			if gateways[i].Addr() == gatewayAddr {
+				continue
+			}
+
+			// We do not need the HandshakeHostInfo since we cache the packet in the originally chosen gateway
+			if hostinfo, ready = f.handshakeManager.GetOrHandshake(gateways[i].Addr(), nil); ready {
+				return hostinfo, true
+			}
+		}
+
+		// No gateways reachable, cache the packet in the originally chosen gateway.
+		// Guard both values: the caller may not want caching (nil callback) and the receiver above
+		// is only populated if the handshake manager actually invoked it.
+		if cacheCallback != nil && handshakeInfoForChosenGateway != nil {
+			cacheCallback(handshakeInfoForChosenGateway)
+		}
+		return hostinfo, false
 	}

-	return f.handshakeManager.GetOrHandshake(vpnAddr, cacheCallback)
 }
 
 func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) {
@@ -163,7 +236,7 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
 
 // SendMessageToVpnAddr handles real addr:port lookup and sends to the current best known address for vpnAddr
 func (f *Interface) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnAddr netip.Addr, p, nb, out []byte) {
-	hostInfo, ready := f.getOrHandshake(vpnAddr, func(hh *HandshakeHostInfo) {
+	hostInfo, ready := f.getOrHandshakeNoRouting(vpnAddr, func(hh *HandshakeHostInfo) {
 		hh.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics)
 	})
 

+ 3 - 1
overlay/device.go

@@ -3,6 +3,8 @@ package overlay
 import (
 	"io"
 	"net/netip"
+
+	"github.com/slackhq/nebula/routing"
 )
 
 type Device interface {
@@ -10,6 +12,6 @@ type Device interface {
 	Activate() error
 	Networks() []netip.Prefix
 	Name() string
-	RouteFor(netip.Addr) netip.Addr
+	RoutesFor(netip.Addr) routing.Gateways
 	NewMultiQueueReader() (io.ReadWriteCloser, error)
 }

+ 65 - 13
overlay/route.go

@@ -11,13 +11,14 @@ import (
 	"github.com/gaissmai/bart"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/routing"
 )
 
 type Route struct {
 	MTU     int
 	Metric  int
 	Cidr    netip.Prefix
-	Via     netip.Addr
+	Via     routing.Gateways // gateways traffic for Cidr is sent through; an empty list means the route has no via
 	Install bool
 }
 
@@ -47,15 +48,17 @@ func (r Route) String() string {
 	return s
 }
 
-func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*bart.Table[netip.Addr], error) {
-	routeTree := new(bart.Table[netip.Addr])
+// makeRouteTree builds a lookup table that maps each route's Cidr to its gateways.
+// Routes without gateways are left out of the table. A warning is logged for every route that
+// sets an MTU while allowMTU is false. The returned error is always nil.
+func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*bart.Table[routing.Gateways], error) {
+	routeTree := new(bart.Table[routing.Gateways])
 	for _, r := range routes {
 		if !allowMTU && r.MTU > 0 {
 			l.WithField("route", r).Warnf("route MTU is not supported in %s", runtime.GOOS)
 		}
 
-		if r.Via.IsValid() {
-			routeTree.Insert(r.Cidr, r.Via)
+		gateways := r.Via
+		if len(gateways) > 0 {
+			// routing.BalancePacket requires the buckets to be calculated before it is used on these gateways
+			routing.CalculateBucketsForGateways(gateways)
+			routeTree.Insert(r.Cidr, gateways)
 		}
 	}
 	return routeTree, nil
 }
@@ -201,14 +204,63 @@ func parseUnsafeRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) {
 			return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes is not present", i+1)
 		}
 
-		via, ok := rVia.(string)
-		if !ok {
-			return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes is not a string: found %T", i+1, rVia)
-		}
+		var gateways routing.Gateways
 
-		viaVpnIp, err := netip.ParseAddr(via)
-		if err != nil {
-			return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes failed to parse address: %v", i+1, err)
+		switch via := rVia.(type) {
+		case string:
+			viaIp, err := netip.ParseAddr(via)
+			if err != nil {
+				return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes failed to parse address: %v", i+1, err)
+			}
+
+			gateways = routing.Gateways{routing.NewGateway(viaIp, 1)}
+
+		case []interface{}:
+			gateways = make(routing.Gateways, len(via))
+			for ig, v := range via {
+				gatewayMap, ok := v.(map[interface{}]interface{})
+				if !ok {
+					return nil, fmt.Errorf("entry %v in tun.unsafe_routes[%v].via is invalid", i+1, ig+1)
+				}
+
+				rGateway, ok := gatewayMap["gateway"]
+				if !ok {
+					return nil, fmt.Errorf("entry .gateway in tun.unsafe_routes[%v].via[%v] is not present", i+1, ig+1)
+				}
+
+				parsedGateway, ok := rGateway.(string)
+				if !ok {
+					return nil, fmt.Errorf("entry .gateway in tun.unsafe_routes[%v].via[%v] is not a string", i+1, ig+1)
+				}
+
+				gatewayIp, err := netip.ParseAddr(parsedGateway)
+				if err != nil {
+					return nil, fmt.Errorf("entry .gateway in tun.unsafe_routes[%v].via[%v] failed to parse address: %v", i+1, ig+1, err)
+				}
+
+				rGatewayWeight, ok := gatewayMap["weight"]
+				if !ok {
+					rGatewayWeight = 1
+				}
+
+				gatewayWeight, ok := rGatewayWeight.(int)
+				if !ok {
+					_, err = strconv.ParseInt(rGatewayWeight.(string), 10, 32)
+					if err != nil {
+						return nil, fmt.Errorf("entry .weight in tun.unsafe_routes[%v].via[%v] is not an integer", i+1, ig+1)
+					}
+				}
+
+				if gatewayWeight < 1 || gatewayWeight > math.MaxInt32 {
+					return nil, fmt.Errorf("entry .weight in tun.unsafe_routes[%v].via[%v] is not in range (1-%d) : %v", i+1, ig+1, math.MaxInt32, gatewayWeight)
+				}
+
+				gateways[ig] = routing.NewGateway(gatewayIp, gatewayWeight)
+
+			}
+
+		default:
+			return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes is not a string or list of gateways: found %T", i+1, rVia)
 		}
 
 		rRoute, ok := m["route"]
@@ -226,7 +278,7 @@ func parseUnsafeRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) {
 		}
 
 		r := Route{
-			Via:     viaVpnIp,
+			Via:     gateways,
 			MTU:     mtu,
 			Metric:  metric,
 			Install: install,

+ 109 - 3
overlay/route_test.go

@@ -6,6 +6,7 @@ import (
 	"testing"
 
 	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/routing"
 	"github.com/slackhq/nebula/test"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
@@ -158,15 +159,39 @@ func Test_parseUnsafeRoutes(t *testing.T) {
 		c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": invalidValue}}}
 		routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 		assert.Nil(t, routes)
-		require.EqualError(t, err, fmt.Sprintf("entry 1.via in tun.unsafe_routes is not a string: found %T", invalidValue))
+		require.EqualError(t, err, fmt.Sprintf("entry 1.via in tun.unsafe_routes is not a string or list of gateways: found %T", invalidValue))
 	}
 
+	// Unparsable list of via
+	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": []string{"1", "2"}}}}
+	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
+	assert.Nil(t, routes)
+	require.EqualError(t, err, "entry 1.via in tun.unsafe_routes is not a string or list of gateways: found []string")
+
 	// unparsable via
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "via": "nope"}}}
 	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	require.EqualError(t, err, "entry 1.via in tun.unsafe_routes failed to parse address: ParseAddr(\"nope\"): unable to parse IP")
 
+	// unparsable gateway
+	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "via": []interface{}{map[interface{}]interface{}{"gateway": "1"}}}}}
+	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
+	assert.Nil(t, routes)
+	require.EqualError(t, err, "entry .gateway in tun.unsafe_routes[1].via[1] failed to parse address: ParseAddr(\"1\"): unable to parse IP")
+
+	// missing gateway element
+	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "via": []interface{}{map[interface{}]interface{}{"weight": "1"}}}}}
+	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
+	assert.Nil(t, routes)
+	require.EqualError(t, err, "entry .gateway in tun.unsafe_routes[1].via[1] is not present")
+
+	// unparsable weight element
+	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "via": []interface{}{map[interface{}]interface{}{"gateway": "10.0.0.1", "weight": "a"}}}}}
+	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
+	assert.Nil(t, routes)
+	require.EqualError(t, err, "entry .weight in tun.unsafe_routes[1].via[1] is not an integer")
+
 	// missing route
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500"}}}
 	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
@@ -280,7 +305,7 @@ func Test_makeRouteTree(t *testing.T) {
 
 	nip, err := netip.ParseAddr("192.168.0.1")
 	require.NoError(t, err)
-	assert.Equal(t, nip, r)
+	assert.Equal(t, nip, r[0].Addr())
 
 	ip, err = netip.ParseAddr("1.0.0.1")
 	require.NoError(t, err)
@@ -289,10 +314,91 @@ func Test_makeRouteTree(t *testing.T) {
 
 	nip, err = netip.ParseAddr("192.168.0.2")
 	require.NoError(t, err)
-	assert.Equal(t, nip, r)
+	assert.Equal(t, nip, r[0].Addr())
 
 	ip, err = netip.ParseAddr("1.1.0.1")
 	require.NoError(t, err)
 	r, ok = routeTree.Lookup(ip)
 	assert.False(t, ok)
 }
+
+// Test_makeMultipathUnsafeRouteTree parses three unsafe routes (a plain string via, an unweighted
+// gateway list and a weighted gateway list), builds the route tree from them and checks the
+// gateways found for one address inside each route.
+func Test_makeMultipathUnsafeRouteTree(t *testing.T) {
+	l := test.NewLogger()
+	c := config.NewC(l)
+	vpnNet, err := netip.ParsePrefix("10.0.0.0/24")
+	require.NoError(t, err)
+
+	// Alias only, the values keep the exact map type the parser asserts on.
+	type settings = map[interface{}]interface{}
+
+	singleVia := settings{
+		"route": "192.168.86.0/24",
+		"via":   "192.168.100.10",
+	}
+	unweightedVia := settings{
+		"route": "192.168.87.0/24",
+		"via": []interface{}{
+			settings{"gateway": "10.0.0.1"},
+			settings{"gateway": "10.0.0.2"},
+			settings{"gateway": "10.0.0.3"},
+		},
+	}
+	weightedVia := settings{
+		"route": "192.168.89.0/24",
+		"via": []interface{}{
+			settings{"gateway": "10.0.0.1", "weight": 10},
+			settings{"gateway": "10.0.0.2", "weight": 5},
+		},
+	}
+	c.Settings["tun"] = settings{
+		"unsafe_routes": []interface{}{singleVia, unweightedVia, weightedVia},
+	}
+
+	routes, err := parseUnsafeRoutes(c, []netip.Prefix{vpnNet})
+	require.NoError(t, err)
+	assert.Len(t, routes, 3)
+	routeTree, err := makeRouteTree(l, routes, true)
+	require.NoError(t, err)
+
+	// lookup resolves dst through the tree and asserts that a route was found.
+	lookup := func(dst string) routing.Gateways {
+		ip, err := netip.ParseAddr(dst)
+		require.NoError(t, err)
+		found, ok := routeTree.Lookup(ip)
+		assert.True(t, ok)
+		return found
+	}
+
+	// expectGateways calculates the buckets of want before comparing it with got.
+	expectGateways := func(want, got routing.Gateways) {
+		routing.CalculateBucketsForGateways(want)
+		assert.ElementsMatch(t, want, got)
+	}
+
+	// Plain string via
+	got := lookup("192.168.86.1")
+	wantAddr, err := netip.ParseAddr("192.168.100.10")
+	require.NoError(t, err)
+	assert.Equal(t, wantAddr, got[0].Addr())
+
+	// Gateway list without weights
+	got = lookup("192.168.87.1")
+	expectGateways(routing.Gateways{
+		routing.NewGateway(netip.MustParseAddr("10.0.0.1"), 1),
+		routing.NewGateway(netip.MustParseAddr("10.0.0.2"), 1),
+		routing.NewGateway(netip.MustParseAddr("10.0.0.3"), 1),
+	}, got)
+
+	// Gateway list with weights
+	got = lookup("192.168.89.1")
+	expectGateways(routing.Gateways{
+		routing.NewGateway(netip.MustParseAddr("10.0.0.1"), 10),
+		routing.NewGateway(netip.MustParseAddr("10.0.0.2"), 5),
+	}, got)
+}

+ 3 - 2
overlay/tun_android.go

@@ -13,6 +13,7 @@ import (
 	"github.com/gaissmai/bart"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/routing"
 	"github.com/slackhq/nebula/util"
 )
 
@@ -21,7 +22,7 @@ type tun struct {
 	fd          int
 	vpnNetworks []netip.Prefix
 	Routes      atomic.Pointer[[]Route]
-	routeTree   atomic.Pointer[bart.Table[netip.Addr]]
+	routeTree   atomic.Pointer[bart.Table[routing.Gateways]]
 	l           *logrus.Logger
 }
 
@@ -56,7 +57,7 @@ func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, erro
 	return nil, fmt.Errorf("newTun not supported in Android")
 }
 
-func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
+// RoutesFor looks ip up in the currently loaded route tree and returns the gateways stored for the match.
+// The found flag from Lookup is discarded, so a miss returns whatever value Lookup hands back with it.
+func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
 	r, _ := t.routeTree.Load().Lookup(ip)
 	return r
 }

+ 5 - 4
overlay/tun_darwin.go

@@ -17,6 +17,7 @@ import (
 	"github.com/gaissmai/bart"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/routing"
 	"github.com/slackhq/nebula/util"
 	netroute "golang.org/x/net/route"
 	"golang.org/x/sys/unix"
@@ -28,7 +29,7 @@ type tun struct {
 	vpnNetworks []netip.Prefix
 	DefaultMTU  int
 	Routes      atomic.Pointer[[]Route]
-	routeTree   atomic.Pointer[bart.Table[netip.Addr]]
+	routeTree   atomic.Pointer[bart.Table[routing.Gateways]]
 	linkAddr    *netroute.LinkAddr
 	l           *logrus.Logger
 
@@ -342,12 +343,12 @@ func (t *tun) reload(c *config.C, initial bool) error {
 	return nil
 }
 
-func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
+// RoutesFor looks ip up in the currently loaded route tree and returns the gateways stored for the match.
+// When no route matches, an empty (non-nil) gateway list is returned.
+func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
 	r, ok := t.routeTree.Load().Lookup(ip)
 	if ok {
 		return r
 	}
-	return netip.Addr{}
+	return routing.Gateways{}
 }
 
 // Get the LinkAddr for the interface of the given name
@@ -382,7 +383,7 @@ func (t *tun) addRoutes(logErrors bool) error {
 	routes := *t.Routes.Load()
 
 	for _, r := range routes {
-		if !r.Via.IsValid() || !r.Install {
+		if len(r.Via) == 0 || !r.Install {
 			// We don't allow route MTUs so only install routes with a via
 			continue
 		}

+ 3 - 2
overlay/tun_disabled.go

@@ -9,6 +9,7 @@ import (
 	"github.com/rcrowley/go-metrics"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/iputil"
+	"github.com/slackhq/nebula/routing"
 )
 
 type disabledTun struct {
@@ -43,8 +44,8 @@ func (*disabledTun) Activate() error {
 	return nil
 }
 
-func (*disabledTun) RouteFor(addr netip.Addr) netip.Addr {
-	return netip.Addr{}
+// RoutesFor always returns an empty gateway list; addr is ignored.
+func (*disabledTun) RoutesFor(addr netip.Addr) routing.Gateways {
+	return routing.Gateways{}
 }
 
 func (t *disabledTun) Networks() []netip.Prefix {

+ 4 - 3
overlay/tun_freebsd.go

@@ -20,6 +20,7 @@ import (
 	"github.com/gaissmai/bart"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/routing"
 	"github.com/slackhq/nebula/util"
 )
 
@@ -50,7 +51,7 @@ type tun struct {
 	vpnNetworks []netip.Prefix
 	MTU         int
 	Routes      atomic.Pointer[[]Route]
-	routeTree   atomic.Pointer[bart.Table[netip.Addr]]
+	routeTree   atomic.Pointer[bart.Table[routing.Gateways]]
 	l           *logrus.Logger
 
 	io.ReadWriteCloser
@@ -242,7 +243,7 @@ func (t *tun) reload(c *config.C, initial bool) error {
 	return nil
 }
 
-func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
+// RoutesFor looks ip up in the currently loaded route tree and returns the gateways stored for the match.
+// The found flag from Lookup is discarded, so a miss returns whatever value Lookup hands back with it.
+func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
 	r, _ := t.routeTree.Load().Lookup(ip)
 	return r
 }
@@ -262,7 +263,7 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
 func (t *tun) addRoutes(logErrors bool) error {
 	routes := *t.Routes.Load()
 	for _, r := range routes {
-		if !r.Via.IsValid() || !r.Install {
+		if len(r.Via) == 0 || !r.Install {
 			// We don't allow route MTUs so only install routes with a via
 			continue
 		}

+ 3 - 2
overlay/tun_ios.go

@@ -16,6 +16,7 @@ import (
 	"github.com/gaissmai/bart"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/routing"
 	"github.com/slackhq/nebula/util"
 )
 
@@ -23,7 +24,7 @@ type tun struct {
 	io.ReadWriteCloser
 	vpnNetworks []netip.Prefix
 	Routes      atomic.Pointer[[]Route]
-	routeTree   atomic.Pointer[bart.Table[netip.Addr]]
+	routeTree   atomic.Pointer[bart.Table[routing.Gateways]]
 	l           *logrus.Logger
 }
 
@@ -79,7 +80,7 @@ func (t *tun) reload(c *config.C, initial bool) error {
 	return nil
 }
 
-func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
+// RoutesFor looks ip up in the currently loaded route tree and returns the gateways stored for the match.
+// The found flag from Lookup is discarded, so a miss returns whatever value Lookup hands back with it.
+func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
 	r, _ := t.routeTree.Load().Lookup(ip)
 	return r
 }

+ 69 - 22
overlay/tun_linux.go

@@ -17,6 +17,7 @@ import (
 	"github.com/gaissmai/bart"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/routing"
 	"github.com/slackhq/nebula/util"
 	"github.com/vishvananda/netlink"
 	"golang.org/x/sys/unix"
@@ -34,7 +35,7 @@ type tun struct {
 	ioctlFd     uintptr
 
 	Routes          atomic.Pointer[[]Route]
-	routeTree       atomic.Pointer[bart.Table[netip.Addr]]
+	routeTree       atomic.Pointer[bart.Table[routing.Gateways]]
 	routeChan       chan struct{}
 	useSystemRoutes bool
 
@@ -231,7 +232,7 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
 	return file, nil
 }
 
-func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
+// RoutesFor looks ip up in the currently loaded route tree and returns the gateways stored for the match.
+// The found flag from Lookup is discarded, so a miss returns whatever value Lookup hands back with it.
+func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
 	r, _ := t.routeTree.Load().Lookup(ip)
 	return r
 }
@@ -550,20 +551,7 @@ func (t *tun) watchRoutes() {
 	}()
 }
 
-func (t *tun) updateRoutes(r netlink.RouteUpdate) {
-	if r.Gw == nil {
-		// Not a gateway route, ignore
-		t.l.WithField("route", r).Debug("Ignoring route update, not a gateway route")
-		return
-	}
-
-	gwAddr, ok := netip.AddrFromSlice(r.Gw)
-	if !ok {
-		t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway address")
-		return
-	}
-
-	gwAddr = gwAddr.Unmap()
+func (t *tun) isGatewayInVpnNetworks(gwAddr netip.Addr) bool {
 	withinNetworks := false
 	for i := range t.vpnNetworks {
 		if t.vpnNetworks[i].Contains(gwAddr) {
@@ -571,9 +559,68 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
 			break
 		}
 	}
-	if !withinNetworks {
-		// Gateway isn't in our overlay network, ignore
-		t.l.WithField("route", r).Debug("Ignoring route update, not in our networks")
+
+	return withinNetworks
+}
+
+// getGatewaysFromRoute extracts the gateways of a kernel route that are relevant to nebula.
+//
+// A gateway is kept when its next hop is bound to our tun device, carries a gateway address and that
+// address lies inside our vpn networks. The route's own gateway is added with weight 1, every
+// multipath next hop with weight Hops+1. Buckets are calculated before the list is returned.
+// The result is empty when nothing qualifies or when the link for t.Device cannot be resolved.
+func (t *tun) getGatewaysFromRoute(r *netlink.Route) routing.Gateways {
+
+	var gateways routing.Gateways
+
+	link, err := netlink.LinkByName(t.Device)
+	if err != nil {
+		// Keep the underlying error in the log entry, otherwise the cause of the failure is lost
+		t.l.WithError(err).WithField("Devicename", t.Device).Error("Ignoring route update: failed to get link by name")
+		return gateways
+	}
+	linkIndex := link.Attrs().Index
+
+	// overlayGateway converts a raw next hop address and reports whether it is usable as a nebula
+	// gateway. invalidMsg is logged when the bytes do not form an address.
+	overlayGateway := func(gw []byte, invalidMsg string) (netip.Addr, bool) {
+		gwAddr, ok := netip.AddrFromSlice(gw)
+		if !ok {
+			t.l.WithField("route", r).Debug(invalidMsg)
+			return netip.Addr{}, false
+		}
+
+		gwAddr = gwAddr.Unmap()
+		if !t.isGatewayInVpnNetworks(gwAddr) {
+			// Gateway isn't in our overlay network, ignore
+			t.l.WithField("route", r).Debug("Ignoring route update, not in our network")
+			return netip.Addr{}, false
+		}
+
+		return gwAddr, true
+	}
+
+	// If this route is relevant to our interface and there is a gateway then add it
+	if r.LinkIndex == linkIndex && len(r.Gw) > 0 {
+		if gwAddr, ok := overlayGateway(r.Gw, "Ignoring route update, invalid gateway address"); ok {
+			gateways = append(gateways, routing.NewGateway(gwAddr, 1))
+		}
+	}
+
+	for _, p := range r.MultiPath {
+		// Only next hops on our interface that carry a gateway are relevant
+		if p.LinkIndex != linkIndex || len(p.Gw) == 0 {
+			continue
+		}
+
+		if gwAddr, ok := overlayGateway(p.Gw, "Ignoring multipath route update, invalid gateway address"); ok {
+			// p.Hops+1 = weight of the route
+			gateways = append(gateways, routing.NewGateway(gwAddr, p.Hops+1))
+		}
+	}
+
+	// Same guard as makeRouteTree: only calculate buckets when there is something to balance
+	if len(gateways) > 0 {
+		routing.CalculateBucketsForGateways(gateways)
+	}
+	return gateways
+}
+
+func (t *tun) updateRoutes(r netlink.RouteUpdate) {
+
+	gateways := t.getGatewaysFromRoute(&r.Route)
+
+	if len(gateways) == 0 {
+		// No gateways relevant to our network, no routing changes required.
+		t.l.WithField("route", r).Debug("Ignoring route update, no gateways")
 		return
 	}
 
@@ -589,12 +636,12 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
 	newTree := t.routeTree.Load().Clone()
 
 	if r.Type == unix.RTM_NEWROUTE {
-		t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Adding route")
-		newTree.Insert(dst, gwAddr)
+		t.l.WithField("destination", dst).WithField("via", gateways).Info("Adding route")
+		newTree.Insert(dst, gateways)
 
 	} else {
+		t.l.WithField("destination", dst).WithField("via", gateways).Info("Removing route")
 		newTree.Delete(dst)
-		t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Removing route")
 	}
 	t.routeTree.Store(newTree)
 }

+ 4 - 3
overlay/tun_netbsd.go

@@ -18,6 +18,7 @@ import (
 	"github.com/gaissmai/bart"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/routing"
 	"github.com/slackhq/nebula/util"
 )
 
@@ -31,7 +32,7 @@ type tun struct {
 	vpnNetworks []netip.Prefix
 	MTU         int
 	Routes      atomic.Pointer[[]Route]
-	routeTree   atomic.Pointer[bart.Table[netip.Addr]]
+	routeTree   atomic.Pointer[bart.Table[routing.Gateways]]
 	l           *logrus.Logger
 
 	io.ReadWriteCloser
@@ -177,7 +178,7 @@ func (t *tun) reload(c *config.C, initial bool) error {
 	return nil
 }
 
-func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
+// RoutesFor looks ip up in the currently loaded route tree and returns the gateways stored for the match.
+// The found flag from Lookup is discarded, so a miss returns whatever value Lookup hands back with it.
+func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
 	r, _ := t.routeTree.Load().Lookup(ip)
 	return r
 }
@@ -197,7 +198,7 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
 func (t *tun) addRoutes(logErrors bool) error {
 	routes := *t.Routes.Load()
 	for _, r := range routes {
-		if !r.Via.IsValid() || !r.Install {
+		if len(r.Via) == 0 || !r.Install {
 			// We don't allow route MTUs so only install routes with a via
 			continue
 		}

+ 4 - 3
overlay/tun_openbsd.go

@@ -17,6 +17,7 @@ import (
 	"github.com/gaissmai/bart"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/routing"
 	"github.com/slackhq/nebula/util"
 )
 
@@ -25,7 +26,7 @@ type tun struct {
 	vpnNetworks []netip.Prefix
 	MTU         int
 	Routes      atomic.Pointer[[]Route]
-	routeTree   atomic.Pointer[bart.Table[netip.Addr]]
+	routeTree   atomic.Pointer[bart.Table[routing.Gateways]]
 	l           *logrus.Logger
 
 	io.ReadWriteCloser
@@ -158,7 +159,7 @@ func (t *tun) Activate() error {
 	return nil
 }
 
-func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
+// RoutesFor looks ip up in the currently loaded route tree and returns the gateways stored for the match.
+// The found flag from Lookup is discarded, so a miss returns whatever value Lookup hands back with it.
+func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
 	r, _ := t.routeTree.Load().Lookup(ip)
 	return r
 }
@@ -166,7 +167,7 @@ func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
 func (t *tun) addRoutes(logErrors bool) error {
 	routes := *t.Routes.Load()
 	for _, r := range routes {
-		if !r.Via.IsValid() || !r.Install {
+		if len(r.Via) == 0 || !r.Install {
 			// We don't allow route MTUs so only install routes with a via
 			continue
 		}

+ 3 - 2
overlay/tun_tester.go

@@ -13,13 +13,14 @@ import (
 	"github.com/gaissmai/bart"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/routing"
 )
 
 type TestTun struct {
 	Device      string
 	vpnNetworks []netip.Prefix
 	Routes      []Route
-	routeTree   *bart.Table[netip.Addr]
+	routeTree   *bart.Table[routing.Gateways]
 	l           *logrus.Logger
 
 	closed    atomic.Bool
@@ -86,7 +87,7 @@ func (t *TestTun) Get(block bool) []byte {
 // Below this is boilerplate implementation to make nebula actually work
 //********************************************************************************************************************//
 
-func (t *TestTun) RouteFor(ip netip.Addr) netip.Addr {
+// RoutesFor looks ip up in the route tree and returns the gateways stored for the match.
+// The found flag from Lookup is discarded, so a miss returns whatever value Lookup hands back with it.
+func (t *TestTun) RoutesFor(ip netip.Addr) routing.Gateways {
 	r, _ := t.routeTree.Lookup(ip)
 	return r
 }

+ 10 - 5
overlay/tun_windows.go

@@ -18,6 +18,7 @@ import (
 	"github.com/gaissmai/bart"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/routing"
 	"github.com/slackhq/nebula/util"
 	"github.com/slackhq/nebula/wintun"
 	"golang.org/x/sys/windows"
@@ -31,7 +32,7 @@ type winTun struct {
 	vpnNetworks []netip.Prefix
 	MTU         int
 	Routes      atomic.Pointer[[]Route]
-	routeTree   atomic.Pointer[bart.Table[netip.Addr]]
+	routeTree   atomic.Pointer[bart.Table[routing.Gateways]]
 	l           *logrus.Logger
 
 	tun *wintun.NativeTun
@@ -147,13 +148,16 @@ func (t *winTun) addRoutes(logErrors bool) error {
 	foundDefault4 := false
 
 	for _, r := range routes {
-		if !r.Via.IsValid() || !r.Install {
+		if len(r.Via) == 0 || !r.Install {
 			// We don't allow route MTUs so only install routes with a via
 			continue
 		}
 
 		// Add our unsafe route
-		err := luid.AddRoute(r.Cidr, r.Via, uint32(r.Metric))
+		// Windows does not support multipath routes natively, so we install only a single route.
+		// This is not a problem as traffic will always be sent to Nebula which handles the multipath routing internally.
+		// In effect this gives windows multipath routing support, including load balancing and redundancy.
+		err := luid.AddRoute(r.Cidr, r.Via[0].Addr(), uint32(r.Metric))
 		if err != nil {
 			retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err)
 			if logErrors {
@@ -198,7 +202,8 @@ func (t *winTun) removeRoutes(routes []Route) error {
 			continue
 		}
 
-		err := luid.DeleteRoute(r.Cidr, r.Via)
+		// See comment on luid.AddRoute
+		err := luid.DeleteRoute(r.Cidr, r.Via[0].Addr())
 		if err != nil {
 			t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
 		} else {
@@ -208,7 +213,7 @@ func (t *winTun) removeRoutes(routes []Route) error {
 	return nil
 }
 
-func (t *winTun) RouteFor(ip netip.Addr) netip.Addr {
+// RoutesFor looks ip up in the currently loaded route tree and returns the gateways stored for the match.
+// The found flag from Lookup is discarded, so a miss returns whatever value Lookup hands back with it.
+func (t *winTun) RoutesFor(ip netip.Addr) routing.Gateways {
 	r, _ := t.routeTree.Load().Lookup(ip)
 	return r
 }

+ 8 - 3
overlay/user.go

@@ -6,6 +6,7 @@ import (
 
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/routing"
 )
 
 func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
@@ -38,9 +39,13 @@ type UserDevice struct {
 func (d *UserDevice) Activate() error {
 	return nil
 }
-func (d *UserDevice) Networks() []netip.Prefix          { return d.vpnNetworks }
-func (d *UserDevice) Name() string                      { return "faketun0" }
-func (d *UserDevice) RouteFor(ip netip.Addr) netip.Addr { return ip }
+
+func (d *UserDevice) Networks() []netip.Prefix { return d.vpnNetworks }
+func (d *UserDevice) Name() string             { return "faketun0" }
+// RoutesFor returns ip itself as the only gateway, with a weight of 1.
+func (d *UserDevice) RoutesFor(ip netip.Addr) routing.Gateways {
+	self := routing.NewGateway(ip, 1)
+	return routing.Gateways{self}
+}
+
 func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) {
 	return d, nil
 }

+ 39 - 0
routing/balance.go

@@ -0,0 +1,39 @@
+package routing
+
+import (
+	"net/netip"
+
+	"github.com/slackhq/nebula/firewall"
+)
+
+// hashPacket mixes the local and remote port of the packet into a non-negative integer.
+// The mixing steps are based on 'Prospecting for Hash Functions'
+//   - https://nullprogram.com/blog/2018/07/31/
+//   - https://github.com/skeeto/hash-prospector
+//     [16 21f0aaad 15 d35a2d97 15] = 0.10760229515479501
+func hashPacket(p *firewall.Packet) int {
+	// Local port in the upper 16 bits, remote port in the lower 16 bits
+	h := uint32(p.RemotePort) | uint32(p.LocalPort)<<16
+	h = (h ^ (h >> 16)) * 0x21f0aaad
+	h = (h ^ (h >> 15)) * 0xd35a2d97
+	h ^= h >> 15
+
+	// Clear the top bit so the result is never negative
+	return int(h) & 0x7FFFFFFF
+}
+
+// BalancePacket selects the gateway that fwPacket should be sent through.
+//
+// The packet is hashed on its local and remote port and the first gateway whose
+// bucket upper bound is at or above that hash is returned, so the buckets for the
+// gateways must have been calculated for the weights to take effect.
+//
+// The second return value is true when a bucket matched. It is false when the
+// buckets are not properly calculated, in which case a gateway is picked by
+// hash modulo the number of gateways, and when gateways is empty, in which case
+// the zero netip.Addr is returned.
+func BalancePacket(fwPacket *firewall.Packet, gateways []Gateway) (netip.Addr, bool) {
+	if len(gateways) == 0 {
+		// Nothing to choose from; without this guard the modulo fallback below would panic
+		return netip.Addr{}, false
+	}
+
+	hash := hashPacket(fwPacket)
+
+	for i := range gateways {
+		if hash <= gateways[i].BucketUpperBound() {
+			return gateways[i].Addr(), true
+		}
+	}
+
+	// If you land here then the buckets for the gateways are not properly calculated
+	// Fallback to random routing and let the caller know
+	return gateways[hash%len(gateways)].Addr(), false
+}

+ 144 - 0
routing/balance_test.go

@@ -0,0 +1,144 @@
+package routing
+
+import (
+	"net/netip"
+	"testing"
+
+	"github.com/slackhq/nebula/firewall"
+	"github.com/stretchr/testify/assert"
+)
+
+// TestPacketsAreBalancedEqually sends 65535 packets with distinct port pairs over three
+// gateways of equal weight and checks that each gateway receives roughly a third of them.
+func TestPacketsAreBalancedEqually(t *testing.T) {
+
+	gateways := []Gateway{}
+
+	gw1Addr := netip.MustParseAddr("1.0.0.1")
+	gw2Addr := netip.MustParseAddr("1.0.0.2")
+	gw3Addr := netip.MustParseAddr("1.0.0.3")
+
+	gateways = append(gateways, NewGateway(gw1Addr, 1))
+	gateways = append(gateways, NewGateway(gw2Addr, 1))
+	gateways = append(gateways, NewGateway(gw3Addr, 1))
+
+	CalculateBucketsForGateways(gateways)
+
+	gw1count := 0
+	gw2count := 0
+	gw3count := 0
+
+	iterationCount := uint16(65535)
+	for i := uint16(0); i < iterationCount; i++ {
+		packet := firewall.Packet{
+			LocalAddr:  netip.MustParseAddr("192.168.1.1"),
+			RemoteAddr: netip.MustParseAddr("10.0.0.1"),
+			LocalPort:  i,
+			RemotePort: 65535 - i,
+			Protocol:   6, // TCP
+			Fragment:   false,
+		}
+
+		selectedGw, ok := BalancePacket(&packet, gateways)
+		assert.True(t, ok)
+
+		switch selectedGw {
+		case gw1Addr:
+			gw1count += 1
+		case gw2Addr:
+			gw2count += 1
+		case gw3Addr:
+			gw3count += 1
+		}
+
+	}
+
+	// Assert packets are balanced, allow variation of up to 100 packets per gateway
+	// Each failure message reports the count of the gateway that is being checked
+	assert.InDeltaf(t, iterationCount/3, gw1count, 100, "Expected %d +/- 100, but got %d", iterationCount/3, gw1count)
+	assert.InDeltaf(t, iterationCount/3, gw2count, 100, "Expected %d +/- 100, but got %d", iterationCount/3, gw2count)
+	assert.InDeltaf(t, iterationCount/3, gw3count, 100, "Expected %d +/- 100, but got %d", iterationCount/3, gw3count)
+
+}
+
+// TestPacketsAreBalancedByPriority sends 65535 packets with distinct port pairs over two
+// gateways weighted 10 and 5 and checks that they receive roughly two thirds and one third.
+func TestPacketsAreBalancedByPriority(t *testing.T) {
+
+	gateways := []Gateway{}
+
+	gw1Addr := netip.MustParseAddr("1.0.0.1")
+	gw2Addr := netip.MustParseAddr("1.0.0.2")
+
+	gateways = append(gateways, NewGateway(gw1Addr, 10))
+	gateways = append(gateways, NewGateway(gw2Addr, 5))
+
+	CalculateBucketsForGateways(gateways)
+
+	gw1count := 0
+	gw2count := 0
+
+	iterationCount := uint16(65535)
+	for i := uint16(0); i < iterationCount; i++ {
+		packet := firewall.Packet{
+			LocalAddr:  netip.MustParseAddr("192.168.1.1"),
+			RemoteAddr: netip.MustParseAddr("10.0.0.1"),
+			LocalPort:  i,
+			RemotePort: 65535 - i,
+			Protocol:   6, // TCP
+			Fragment:   false,
+		}
+
+		selectedGw, ok := BalancePacket(&packet, gateways)
+		assert.True(t, ok)
+
+		switch selectedGw {
+		case gw1Addr:
+			gw1count += 1
+		case gw2Addr:
+			gw2count += 1
+		}
+
+	}
+
+	iterationCountAsFloat := float32(iterationCount)
+
+	// The expected values are floats, so they are formatted with %.0f rather than %d
+	assert.InDeltaf(t, iterationCountAsFloat*(2.0/3.0), gw1count, 100, "Expected %.0f +/- 100, but got %d", iterationCountAsFloat*(2.0/3.0), gw1count)
+	assert.InDeltaf(t, iterationCountAsFloat*(1.0/3.0), gw2count, 100, "Expected %.0f +/- 100, but got %d", iterationCountAsFloat*(1.0/3.0), gw2count)
+}
+
+// TestBalancePacketDistributsRandomlyAndReturnsFalseIfBucketsNotCalculated exercises the
+// fallback of BalancePacket: the buckets of the gateways are never calculated, so every
+// call has to report false while the packets still end up spread over both gateways.
+func TestBalancePacketDistributsRandomlyAndReturnsFalseIfBucketsNotCalculated(t *testing.T) {
+	gw1Addr := netip.MustParseAddr("1.0.0.1")
+	gw2Addr := netip.MustParseAddr("1.0.0.2")
+
+	// CalculateBucketsForGateways is deliberately not called for these gateways
+	gateways := []Gateway{
+		NewGateway(gw1Addr, 10),
+		NewGateway(gw2Addr, 5),
+	}
+
+	localAddr := netip.MustParseAddr("192.168.1.1")
+	remoteAddr := netip.MustParseAddr("10.0.0.1")
+
+	// Number of packets that were handed to each gateway address
+	hits := map[netip.Addr]int{}
+
+	iterationCount := uint16(65535)
+	for port := uint16(0); port < iterationCount; port++ {
+		packet := firewall.Packet{
+			LocalAddr:  localAddr,
+			RemoteAddr: remoteAddr,
+			LocalPort:  port,
+			RemotePort: 65535 - port,
+			Protocol:   6, // TCP
+			Fragment:   false,
+		}
+
+		selectedGw, ok := BalancePacket(&packet, gateways)
+		assert.False(t, ok)
+
+		hits[selectedGw]++
+	}
+
+	gw1count := hits[gw1Addr]
+	gw2count := hits[gw2Addr]
+
+	assert.Equal(t, int(iterationCount), (gw1count + gw2count))
+	assert.NotEqual(t, 0, gw1count)
+	assert.NotEqual(t, 0, gw2count)
+}

+ 70 - 0
routing/gateway.go

@@ -0,0 +1,70 @@
+package routing
+
+import (
+	"fmt"
+	"net/netip"
+)
+
+const (
+	// BucketNotCalculated is the sentinel value NewGateway stores as bucketUpperBound
+	BucketNotCalculated = -1
+)
+
+// Gateways is a list of gateways.
+type Gateways []Gateway
+
+// String renders every gateway with Gateway.String, separated by ", ".
+func (g Gateways) String() string {
+	out := ""
+	for idx := range g {
+		if idx > 0 {
+			out += ", "
+		}
+		out += g[idx].String()
+	}
+	return out
+}
+
+// Gateway is a gateway address together with its weight and its hash bucket.
+type Gateway struct {
+	// addr is the address returned by Addr
+	addr netip.Addr
+	// weight is the value summed up by CalculateBucketsForGateways
+	weight int
+	// bucketUpperBound is set to BucketNotCalculated by NewGateway and
+	// assigned by CalculateBucketsForGateways
+	bucketUpperBound int
+}
+
+// NewGateway returns a Gateway for addr with the given weight whose bucket
+// upper bound is BucketNotCalculated.
+func NewGateway(addr netip.Addr, weight int) Gateway {
+	gw := Gateway{bucketUpperBound: BucketNotCalculated}
+	gw.addr, gw.weight = addr, weight
+	return gw
+}
+
+// BucketUpperBound returns the upper bound of the hash bucket of this gateway.
+func (g *Gateway) BucketUpperBound() int {
+	return g.bucketUpperBound
+}
+
+// Addr returns the address of this gateway.
+func (g *Gateway) Addr() netip.Addr {
+	return g.addr
+}
+
+// String formats the address and the weight of this gateway.
+func (g *Gateway) String() string {
+	return fmt.Sprintf("{addr: %s, weight: %d}", g.addr, g.weight)
+}
+
+// Divide and round to nearest integer
+func divideAndRound(v uint64, d uint64) uint64 {
+	var tmp uint64 = v + d/2
+	return tmp / d
+}
+
+// Implements Hash-Threshold mapping, equivalent to the implementation in the linux kernel.
+// After this function returns each gateway will have a
+// positive bucketUpperBound with a maximum value of 2147483647 (INT_MAX)
+func CalculateBucketsForGateways(gateways []Gateway) {
+
+	var totalWeight int = 0
+	for i := range gateways {
+		totalWeight += gateways[i].weight
+	}
+
+	var loopWeight int = 0
+	for i := range gateways {
+		loopWeight += gateways[i].weight
+		gateways[i].bucketUpperBound = int(divideAndRound(uint64(loopWeight)<<31, uint64(totalWeight))) - 1
+	}
+
+}

+ 34 - 0
routing/gateway_test.go

@@ -0,0 +1,34 @@
+package routing
+
+import (
+	"net/netip"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+// TestRebalance3_2Split checks the bucket bounds of two gateways weighted 10 and 5.
+func TestRebalance3_2Split(t *testing.T) {
+	gateways := []Gateway{
+		{addr: netip.Addr{}, weight: 10},
+		{addr: netip.Addr{}, weight: 5},
+	}
+
+	CalculateBucketsForGateways(gateways)
+
+	assert.Equal(t, 1431655764, gateways[0].bucketUpperBound) // INT_MAX/3*2
+	assert.Equal(t, 2147483647, gateways[1].bucketUpperBound) // INT_MAX
+}
+
+// TestRebalanceEqualSplit checks the bucket bounds of three gateways that all have weight 1.
+func TestRebalanceEqualSplit(t *testing.T) {
+	gateways := []Gateway{
+		{addr: netip.Addr{}, weight: 1},
+		{addr: netip.Addr{}, weight: 1},
+		{addr: netip.Addr{}, weight: 1},
+	}
+
+	CalculateBucketsForGateways(gateways)
+
+	assert.Equal(t, 715827882, gateways[0].bucketUpperBound)  // INT_MAX/3
+	assert.Equal(t, 1431655764, gateways[1].bucketUpperBound) // INT_MAX/3*2
+	assert.Equal(t, 2147483647, gateways[2].bucketUpperBound) // INT_MAX
+}

+ 4 - 2
test/tun.go

@@ -4,12 +4,14 @@ import (
 	"errors"
 	"io"
 	"net/netip"
+
+	"github.com/slackhq/nebula/routing"
 )
 
 type NoopTun struct{}
 
-func (NoopTun) RouteFor(addr netip.Addr) netip.Addr {
-	return netip.Addr{}
+func (NoopTun) RoutesFor(addr netip.Addr) routing.Gateways {
+	return routing.Gateways{}
 }
 
 func (NoopTun) Activate() error {